diff --git a/src/ebml/element_reader.rs b/src/ebml/element_reader.rs index a4a717a3..f5f75f41 100644 --- a/src/ebml/element_reader.rs +++ b/src/ebml/element_reader.rs @@ -3,6 +3,7 @@ use crate::error::Result; use crate::macros::{decode_err, try_vec}; use std::io::Read; +use std::ops::{Deref, DerefMut}; use byteorder::{BigEndian, ReadBytesExt}; use lofty_attr::ebml_master_elements; @@ -229,7 +230,11 @@ where self.ctx.previous_master_length = self.ctx.master_length; } - fn next_master(&mut self) -> Result { + fn goto_next_master(&mut self) -> Result { + if self.ctx.master_length != 0 { + self.skip(self.ctx.master_length)?; + } + let header = ElementHeader::read( &mut self.reader, self.ctx.max_id_length, @@ -249,15 +254,6 @@ where ))) } - /// Lock the reader to the current master element - pub(crate) fn lock(&mut self) { - self.ctx.locked = true; - } - - pub(crate) fn unlock(&mut self) { - self.ctx.locked = false; - } - pub(crate) fn goto_previous_master(&mut self) -> Result<()> { if let Some(previous_master) = self.ctx.previous_master { self.ctx.current_master = Some(previous_master); @@ -270,7 +266,7 @@ where pub(crate) fn next(&mut self) -> Result { let Some(current_master) = self.ctx.current_master else { - return self.next_master(); + return self.goto_next_master(); }; if self.ctx.master_length == 0 { @@ -278,7 +274,7 @@ where return Ok(ElementReaderYield::Eof); } - return self.next_master(); + return self.goto_next_master(); } let header = ElementHeader::read(self, self.ctx.max_id_length, self.ctx.max_size_length)?; @@ -310,11 +306,33 @@ where Ok(ElementReaderYield::Child((*child, header.size.value()))) } + fn lock(&mut self) { + self.ctx.locked = true; + } + + fn unlock(&mut self) { + self.ctx.locked = false; + } + + pub(crate) fn children(&mut self) -> ElementChildIterator<'_, R> { + self.lock(); + ElementChildIterator::new(self) + } + pub(crate) fn skip(&mut self, length: u64) -> Result<()> { std::io::copy(&mut self.by_ref().take(length), &mut std::io::sink())?; Ok(()) } + pub(crate) fn skip_element(&mut self, element_header: ElementHeader) -> Result<()> { + log::debug!( + "Encountered unknown EBML element in header: {:X}", + element_header.id.0 + ); + self.skip(element_header.size.value())?; + Ok(()) + } + pub(crate) fn read_signed_int(&mut self, element_length: u64) -> Result { // https://www.rfc-editor.org/rfc/rfc8794.html#section-7.1 // A Signed Integer Element MUST declare a length from zero to eight octets @@ -390,3 +408,68 @@ where todo!() } } + +pub(crate) struct ElementChildIterator<'a, R> +where + R: Read, +{ + reader: &'a mut ElementReader, +} + +impl<'a, R> ElementChildIterator<'a, R> +where + R: Read, +{ + pub(crate) fn new(reader: &'a mut ElementReader) -> Self { + Self { reader } + } + + pub(crate) fn next(&mut self) -> Result> { + match self.reader.next() { + Ok(ElementReaderYield::Unknown(header)) => { + self.reader.skip_element(header)?; + self.next() + }, + Ok(ElementReaderYield::Eof) => Ok(None), + Err(e) => Err(e), + element => element.map(Some), + } + } + + pub(crate) fn master_exhausted(&self) -> bool { + self.reader.ctx.master_length == 0 + } + + pub(crate) fn inner(&mut self) -> &mut ElementReader { + self.reader + } +} + +impl<'a, R> Deref for ElementChildIterator<'a, R> +where + R: Read, +{ + type Target = ElementReader; + + fn deref(&self) -> &Self::Target { + self.reader + } +} + +impl<'a, R> DerefMut for ElementChildIterator<'a, R> +where + R: Read, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + self.reader + } +} + +impl<'a, R> Drop for ElementChildIterator<'a, R> +where + R: Read, +{ + fn drop(&mut self) { + self.reader.unlock(); + } +} diff --git a/src/ebml/read.rs b/src/ebml/read.rs index 9510ffed..73abb7ab 100644 --- a/src/ebml/read.rs +++ b/src/ebml/read.rs @@ -69,15 +69,13 @@ where Err(e) => return Err(e), } - element_reader.lock(); - - loop { + let mut child_reader = element_reader.children(); + while let Some(child) = child_reader.next()? { let ident; let data_ty; let size; - let res = element_reader.next()?; - match res { + match child { // The only expected master element in the header is `DocTypeExtension` ElementReaderYield::Master((ElementIdent::DocTypeExtension, _)) => continue, ElementReaderYield::Child((child, size_)) => { @@ -85,53 +83,46 @@ where data_ty = child.data_type; size = size_; }, - ElementReaderYield::Unknown(element) => { - log::debug!( - "Encountered unknown EBML element in header: {:X}", - element.id.0 - ); - element_reader.skip(element.size.value())?; - continue; - }, _ => break, } if ident == ElementIdent::EBMLMaxIDLength { - properties.header.max_id_length = element_reader.read_unsigned_int(size)? as u8; - element_reader.set_max_id_length(properties.header.max_id_length); + properties.header.max_id_length = child_reader.read_unsigned_int(size)? as u8; + child_reader.set_max_id_length(properties.header.max_id_length); continue; } if ident == ElementIdent::EBMLMaxSizeLength { - properties.header.max_size_length = element_reader.read_unsigned_int(size)? as u8; - element_reader.set_max_size_length(properties.header.max_size_length); + properties.header.max_size_length = child_reader.read_unsigned_int(size)? as u8; + child_reader.set_max_size_length(properties.header.max_size_length); continue; } // Anything else in the header is unnecessary, and only read for the properties // struct if !parse_options.read_properties { - element_reader.skip(size)?; + child_reader.skip(size)?; continue; } match ident { ElementIdent::EBMLVersion => { - properties.header.version = element_reader.read_unsigned_int(size)? + properties.header.version = child_reader.read_unsigned_int(size)? }, ElementIdent::EBMLReadVersion => { - properties.header.read_version = element_reader.read_unsigned_int(size)? - }, - ElementIdent::DocType => { - properties.header.doc_type = element_reader.read_string(size)? + properties.header.read_version = child_reader.read_unsigned_int(size)? }, + ElementIdent::DocType => properties.header.doc_type = child_reader.read_string(size)?, ElementIdent::DocTypeVersion => { - properties.header.doc_type_version = element_reader.read_unsigned_int(size)? + properties.header.doc_type_version = child_reader.read_unsigned_int(size)? }, - _ => element_reader.skip(size)?, + _ => child_reader.skip(size)?, } } - element_reader.unlock(); + debug_assert!( + child_reader.master_exhausted(), + "There should be no remaining elements in the header" + ); Ok(()) } diff --git a/src/ebml/read/segment.rs b/src/ebml/read/segment.rs index 98a99a27..af73f0c9 100644 --- a/src/ebml/read/segment.rs +++ b/src/ebml/read/segment.rs @@ -16,16 +16,14 @@ pub(super) fn read_from( where R: Read + Seek, { - element_reader.lock(); - let mut tags = None; - loop { - let res = element_reader.next()?; - match res { + let mut children_reader = element_reader.children(); + while let Some(child) = children_reader.next()? { + match child { ElementReaderYield::Master((id, size)) => match id { ElementIdent::Info => { - segment_info::read_from(element_reader, parse_options, properties)? + segment_info::read_from(children_reader.inner(), parse_options, properties)? }, ElementIdent::Cluster => todo!("Support segment.Cluster"), ElementIdent::Tracks => todo!("Support segment.Tracks"), @@ -37,23 +35,15 @@ where // elements, so we can just skip any useless ones. log::debug!("Skipping EBML master element: {:?}", id); - element_reader.skip(size)?; - element_reader.goto_previous_master()?; + children_reader.skip(size)?; + children_reader.goto_previous_master()?; continue; }, }, - ElementReaderYield::Unknown(element) => { - log::debug!("Skipping unknown EBML element: {:X}", element.id.0); - element_reader.skip(element.size.value())?; - continue; - }, - ElementReaderYield::Eof => { - element_reader.unlock(); - break; - }, - _ => { + ElementReaderYield::Child(_) => { decode_err!(@BAIL Ebml, "Segment element should only contain master elements") }, + _ => break, } } diff --git a/src/ebml/read/segment_info.rs b/src/ebml/read/segment_info.rs index e3eac2e7..891f01be 100644 --- a/src/ebml/read/segment_info.rs +++ b/src/ebml/read/segment_info.rs @@ -14,25 +14,24 @@ pub(super) fn read_from( where R: Read + Seek, { - element_reader.lock(); + let mut children_reader = element_reader.children(); - loop { - let res = element_reader.next()?; - match res { + while let Some(child) = children_reader.next()? { + match child { ElementReaderYield::Master((id, size)) => { // We do not end up using information from any of the nested master // elements, so we can just skip them. log::debug!("Skipping EBML master element: {:?}", id); - element_reader.skip(size)?; - element_reader.goto_previous_master()?; + children_reader.skip(size)?; + children_reader.goto_previous_master()?; continue; }, ElementReaderYield::Child((child, size)) => { match child.ident { ElementIdent::TimecodeScale => { properties.segment_info.timestamp_scale = - element_reader.read_unsigned_int(size)?; + children_reader.read_unsigned_int(size)?; if properties.segment_info.timestamp_scale == 0 { log::warn!("Segment.Info.TimecodeScale is 0, which is invalid"); @@ -42,33 +41,26 @@ where } }, ElementIdent::MuxingApp => { - properties.segment_info.muxing_app = element_reader.read_utf8(size)? + properties.segment_info.muxing_app = children_reader.read_utf8(size)? }, ElementIdent::WritingApp => { - properties.segment_info.writing_app = element_reader.read_utf8(size)? + properties.segment_info.writing_app = children_reader.read_utf8(size)? }, _ => { // We do not end up using information from all of the segment // elements, so we can just skip any useless ones. log::debug!("Skipping EBML child element: {:?}", child.ident); - element_reader.skip(size)?; + children_reader.skip(size)?; continue; }, } }, - ElementReaderYield::Unknown(element) => { - log::debug!("Skipping unknown EBML element: {:X}", element.id.0); - element_reader.skip(element.size.value())?; - continue; - }, - ElementReaderYield::Eof => { - element_reader.unlock(); - break; - }, + _ => break, } } + drop(children_reader); element_reader.goto_previous_master()?; Ok(()) }