diff --git a/src/logic/iff/wav/read.rs b/src/logic/iff/wav/read.rs index 152f88c0..dd279513 100644 --- a/src/logic/iff/wav/read.rs +++ b/src/logic/iff/wav/read.rs @@ -14,15 +14,19 @@ const PCM: u16 = 0x0001; const IEEE_FLOAT: u16 = 0x0003; const EXTENSIBLE: u16 = 0xfffe; -pub(in crate::logic::iff) fn verify_riff(data: &mut T) -> Result<()> +pub(in crate::logic::iff) fn verify_wav(data: &mut T) -> Result<()> where T: Read + Seek, { - let mut id = [0; 4]; + let mut id = [0; 12]; data.read_exact(&mut id)?; - if &id != b"RIFF" { - return Err(LoftyError::Wav("RIFF file doesn't contain a RIFF chunk")); + if &id[..4] != b"RIFF" { + return Err(LoftyError::Wav("WAV file doesn't contain a RIFF chunk")); + } + + if &id[8..] != b"WAVE" { + return Err(LoftyError::Wav("Found RIFF file, format is not WAVE")); } Ok(()) @@ -100,13 +104,11 @@ fn read_properties(fmt: &mut &[u8], total_samples: u32, stream_len: u32) -> Resu )) } -pub(in crate::logic) fn read_from(data: &mut T) -> Result +pub(in crate::logic) fn read_from(data: &mut R) -> Result where - T: Read + Seek, + R: Read + Seek, { - verify_riff(data)?; - - data.seek(SeekFrom::Current(8))?; + verify_wav(data)?; let mut stream_len = 0_u32; let mut total_samples = 0_u32; @@ -154,31 +156,7 @@ where if &list_type == b"INFO" { let end = data.seek(SeekFrom::Current(0))? + u64::from(size - 4); - - while data.seek(SeekFrom::Current(0))? != end { - let mut key = [0; 4]; - data.read_exact(&mut key)?; - - let key_str = std::str::from_utf8(&key) - .map_err(|_| LoftyError::Wav("Non UTF-8 key found in RIFF INFO"))?; - - let item_key = ItemKey::from_key(&TagType::RiffInfo, key_str) - .unwrap_or_else(|| ItemKey::Unknown(key_str.to_string())); - - let size = data.read_u32::()?; - - let mut buf = vec![0; size as usize]; - data.read_exact(&mut buf)?; - - let val = String::from_utf8(buf)?; - - let item = TagItem::new( - item_key, - ItemValue::Text(val.trim_matches('\0').to_string()), - ); - - riff_info.insert_item(item); - } + parse_riff_info(data, end, &mut riff_info)?; } else { data.seek(SeekFrom::Current(i64::from(size)))?; } @@ -220,3 +198,38 @@ where id3v2: id3, }) } + +fn parse_riff_info(data: &mut R, end: u64, tag: &mut Tag) -> Result<()> +where + R: Read + Seek, +{ + while data.seek(SeekFrom::Current(0))? != end { + let mut key = [0; 4]; + data.read_exact(&mut key)?; + + let key_str = std::str::from_utf8(&key) + .map_err(|_| LoftyError::Wav("Non UTF-8 key found in RIFF INFO"))?; + + if !key_str.is_ascii() { + return Err(LoftyError::Wav("Non ascii key found in RIFF INFO")); + } + + let item_key = ItemKey::from_key(&TagType::RiffInfo, key_str) + .unwrap_or_else(|| ItemKey::Unknown(key_str.to_string())); + + let size = data.read_u32::()?; + + let mut value = vec![0; size as usize]; + data.read_exact(&mut value)?; + + let value_str = std::str::from_utf8(&value) + .map_err(|_| LoftyError::Wav("Non UTF-8 value found in RIFF INFO"))?; + + tag.insert_item_unchecked(TagItem::new( + item_key, + ItemValue::Text(value_str.trim_matches('\0').to_string()), + )); + } + + Ok(()) +} diff --git a/src/logic/iff/wav/write.rs b/src/logic/iff/wav/write.rs index 22be8446..055c771b 100644 --- a/src/logic/iff/wav/write.rs +++ b/src/logic/iff/wav/write.rs @@ -1,37 +1,41 @@ -use super::read::verify_riff; +use super::read::verify_wav; use crate::error::{LoftyError, Result}; use crate::types::tag::{ItemValue, Tag, TagType}; use std::fs::File; use std::io::{Read, Seek, SeekFrom, Write}; -use byteorder::{LittleEndian, ReadBytesExt}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; -fn find_info_list(data: &mut T) -> Result<()> +fn find_info_list(data: &mut T) -> Result where T: Read + Seek, { - loop { - let mut chunk_name = [0; 4]; - data.read_exact(&mut chunk_name)?; + let mut fourcc = [0; 4]; - if &chunk_name == b"LIST" { - data.seek(SeekFrom::Current(4))?; + let mut found_info = false; + while let (Ok(()), Ok(size)) = ( + data.read_exact(&mut fourcc), + data.read_u32::(), + ) { + if &fourcc == b"LIST" { let mut list_type = [0; 4]; data.read_exact(&mut list_type)?; if &list_type == b"INFO" { data.seek(SeekFrom::Current(-8))?; - return Ok(()); + found_info = true; + break; } data.seek(SeekFrom::Current(-8))?; } - let size = data.read_u32::()?; data.seek(SeekFrom::Current(i64::from(size)))?; } + + Ok(found_info) } // TODO: ID3v2 @@ -40,12 +44,52 @@ pub(crate) fn write_to(data: &mut File, tag: &Tag) -> Result<()> { return Err(LoftyError::UnsupportedTag); } - verify_riff(data)?; + verify_wav(data)?; - let mut packet = Vec::new(); + let mut riff_info_bytes = Vec::new(); + create_riff_info(tag, &mut riff_info_bytes)?; - packet.extend(b"LIST".iter()); - packet.extend(b"INFO".iter()); + if find_info_list(data)? { + let info_list_size = data.read_u32::()? as usize; + data.seek(SeekFrom::Current(-8))?; + + let info_list_start = data.seek(SeekFrom::Current(0))? as usize; + let info_list_end = info_list_start + 8 + info_list_size; + + data.seek(SeekFrom::Start(0))?; + + let mut file_bytes = Vec::new(); + data.read_to_end(&mut file_bytes)?; + + let _ = file_bytes.splice(info_list_start..info_list_end, riff_info_bytes); + + let total_size = (file_bytes.len() - 8) as u32; + let _ = file_bytes.splice(4..8, total_size.to_le_bytes()); + + data.seek(SeekFrom::Start(0))?; + data.set_len(0)?; + data.write_all(&*file_bytes)?; + } else { + data.seek(SeekFrom::End(0))?; + + data.write_all(&riff_info_bytes)?; + + let len = (data.seek(SeekFrom::Current(0))? - 8) as u32; + + data.seek(SeekFrom::Start(4))?; + data.write_u32::(len)?; + } + + Ok(()) +} + +fn create_riff_info(tag: &Tag, bytes: &mut Vec) -> Result<()> { + if tag.item_count() == 0 { + return Ok(()); + } + + bytes.extend(b"LIST".iter()); + bytes.extend(b"INFO".iter()); for item in tag.items() { if let Some(key) = item.key().map_key(&TagType::RiffInfo) { @@ -66,16 +110,16 @@ pub(crate) fn write_to(data: &mut File, tag: &Tag) -> Result<()> { ((len + 1) as u32, &[0, 0]) }; - packet.extend(key.as_bytes().iter()); - packet.extend(size.to_le_bytes().iter()); - packet.extend(val_b.iter()); - packet.extend(terminator.iter()); + bytes.extend(key.as_bytes().iter()); + bytes.extend(size.to_le_bytes().iter()); + bytes.extend(val_b.iter()); + bytes.extend(terminator.iter()); } } } } - let packet_size = packet.len() - 4; + let packet_size = bytes.len() - 4; if packet_size > u32::MAX as usize { return Err(LoftyError::TooMuchData); @@ -85,31 +129,8 @@ pub(crate) fn write_to(data: &mut File, tag: &Tag) -> Result<()> { #[allow(clippy::needless_range_loop)] for i in 0..4 { - packet.insert(i + 4, size[i]); + bytes.insert(i + 4, size[i]); } - data.seek(SeekFrom::Current(8))?; - - find_info_list(data)?; - - let info_list_size = data.read_u32::()? as usize; - data.seek(SeekFrom::Current(-8))?; - - let info_list_start = data.seek(SeekFrom::Current(0))? as usize; - let info_list_end = info_list_start + 8 + info_list_size; - - data.seek(SeekFrom::Start(0))?; - let mut file_bytes = Vec::new(); - data.read_to_end(&mut file_bytes)?; - - let _ = file_bytes.splice(info_list_start..info_list_end, packet); - - let total_size = (file_bytes.len() - 8) as u32; - let _ = file_bytes.splice(4..8, total_size.to_le_bytes().to_vec()); - - data.seek(SeekFrom::Start(0))?; - data.set_len(0)?; - data.write_all(&*file_bytes)?; - Ok(()) }