diff --git a/src/ebml/element_reader.rs b/src/ebml/element_reader.rs index 0fb389cd..f07ec742 100644 --- a/src/ebml/element_reader.rs +++ b/src/ebml/element_reader.rs @@ -148,6 +148,8 @@ ebml_master_elements! { } struct ElementReaderContext { + /// Previous master element + previous_master: Option, /// Current master element current_master: Option, /// Remaining length of the master element @@ -156,17 +158,24 @@ struct ElementReaderContext { max_id_length: u8, /// Maximum size in octets of all element data sizes max_size_length: u8, + /// Whether the reader is locked to the current master element + /// + /// This is set with [`ElementReader::lock`], and is used to prevent + /// the reader from reading past the end of the current master element. + locked: bool, } impl Default for ElementReaderContext { fn default() -> Self { Self { + previous_master: None, current_master: None, master_length: 0, // https://www.rfc-editor.org/rfc/rfc8794.html#name-ebmlmaxidlength-element max_id_length: 4, // https://www.rfc-editor.org/rfc/rfc8794.html#name-ebmlmaxsizelength-element max_size_length: 8, + locked: false, } } } @@ -213,6 +222,7 @@ where return Ok(ElementReaderYield::Unknown(header)); }; + self.ctx.previous_master = self.ctx.current_master; self.ctx.current_master = Some(*master); self.ctx.master_length = header.size.value(); Ok(ElementReaderYield::Master(( @@ -221,12 +231,34 @@ where ))) } + /// Lock the reader to the current master element + pub(crate) fn lock(&mut self) { + self.ctx.locked = true; + } + + pub(crate) fn unlock(&mut self) { + self.ctx.locked = false; + } + + pub(crate) fn goto_previous_master(&mut self) -> Result<()> { + if let Some(previous_master) = self.ctx.previous_master { + self.ctx.current_master = Some(previous_master); + Ok(()) + } else { + decode_err!(@BAIL Ebml, "Expected a parent element to be available") + } + } + pub(crate) fn next(&mut self) -> Result { let Some(current_master) = self.ctx.current_master else { return self.next_master(); }; if self.ctx.master_length == 0 { + if self.ctx.locked { + return Ok(ElementReaderYield::Eof); + } + return self.next_master(); } @@ -244,6 +276,21 @@ where return Ok(ElementReaderYield::Unknown(header)); }; + if child.data_type == ElementDataType::Master { + self.ctx.current_master = Some( + *MASTER_ELEMENTS + .get(&header.id) + .expect("Nested master elements should be defined at this level."), + ); + self.ctx.master_length = header.size.value(); + + // We encountered a nested master element + return Ok(ElementReaderYield::Master(( + child.ident, + header.size.value(), + ))); + } + Ok(ElementReaderYield::Child((*child, header.size.value()))) }