EBML: Improve ElementReader locking

This commit is contained in:
Serial 2024-08-10 12:00:52 -04:00
parent 2ff8dc6b6e
commit 09f170e08e
No known key found for this signature in database
GPG key ID: DA95198DC17C4568
5 changed files with 296 additions and 91 deletions

View file

@ -214,41 +214,49 @@ ebml_master_elements! {
},
}
/// One entry in the reader's stack of open master elements.
#[derive(Debug)]
struct MasterElementContext {
/// The master element this entry describes
element: MasterElement,
/// Length of this master element that has not been consumed yet.
///
/// The reader lowers this as data is read while this entry is the current master.
remaining_length: VInt,
}
/// Deepest master element nesting the reader accepts.
///
/// `push_new_master` bails with an error once `depth` has reached this value, and the
/// master stack is pre-allocated with this capacity.
const MAX_DEPTH: u8 = 16;
/// Depth at which the current master is a top-level element (for example `\EBML` or `\Segment`).
const ROOT_DEPTH: u8 = 1;
/// Bookkeeping shared by the element reader: which master elements are open,
/// the EBML header limits, and the state of the lock.
struct ElementReaderContext {
/// Previous master element
previous_master: Option<MasterElement>,
/// Remaining length that was recorded for `previous_master`
previous_master_length: VInt,
/// Current master element
current_master: Option<MasterElement>,
/// Remaining length of the master element
master_length: VInt,
/// Number of master elements currently open.
///
/// `0` means no master element has been entered; otherwise the current
/// master is `masters[depth - 1]`.
depth: u8,
/// Stack of open master elements, outermost first
masters: Vec<MasterElementContext>,
/// Maximum size in octets of all element IDs
max_id_length: u8,
/// Maximum size in octets of all element data sizes
max_size_length: u8,
/// Whether the reader is locked to the current master element
/// Whether the reader is locked to the master element at `lock_depth`
///
/// This is set with [`ElementReader::lock`], and is used to prevent
/// the reader from reading past the end of the current master element.
locked: bool,
/// The depth at which we are locked to
lock_depth: u8,
/// Length that may still be consumed while locked.
///
/// Captured from the current master's remaining length by `lock`, and overwritten
/// every time the current master's remaining length is updated while `locked` is set.
lock_len: VInt,
}
/// The initial reader state: no master element entered, unlocked, and the
/// ID/size length limits set to the values referenced from RFC 8794 below.
impl Default for ElementReaderContext {
fn default() -> Self {
Self {
previous_master: None,
previous_master_length: VInt::ZERO,
current_master: None,
master_length: VInt::ZERO,
depth: 0,
// Reserve the whole stack up front, nesting is capped at `MAX_DEPTH`
masters: Vec::with_capacity(MAX_DEPTH as usize),
// https://www.rfc-editor.org/rfc/rfc8794.html#name-ebmlmaxidlength-element
max_id_length: 4,
// https://www.rfc-editor.org/rfc/rfc8794.html#name-ebmlmaxsizelength-element
max_size_length: 8,
locked: false,
lock_depth: 0,
lock_len: VInt::ZERO,
}
}
}
#[derive(Debug)]
pub(crate) enum ElementReaderYield {
Master((ElementIdent, VInt)),
Child((ChildElementDescriptor, VInt)),
@ -268,8 +276,9 @@ impl ElementReaderYield {
pub fn size(&self) -> Option<u64> {
match self {
ElementReaderYield::Master((_, size)) => Some(size.value()),
ElementReaderYield::Child((_, size)) => Some(size.value()),
ElementReaderYield::Master((_, size)) | ElementReaderYield::Child((_, size)) => {
Some(size.value())
},
ElementReaderYield::Unknown(header) => Some(header.size.value()),
_ => None,
}
@ -286,8 +295,19 @@ where
R: Read,
{
// Reads from the underlying reader and charges the bytes read against the
// remaining length of the current master element.
//
// While the reader is locked, a request whose buffer is longer than `lock_len`
// fails with `UnexpectedEof` before the underlying reader is touched.
//
// NOTE(review): an oversized buffer is rejected outright rather than shortened to
// the remaining locked length — confirm no caller hands in a buffer larger than
// what is left in the locked element.
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
if self.ctx.locked {
let lock_len = self.ctx.lock_len.value();
// NOTE(review): `as usize` narrows a `u64` on targets where `usize` is smaller — confirm acceptable
if buf.len() > lock_len as usize {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"Cannot read past the end of the current master element",
));
}
}
let ret = self.reader.read(buf)?;
self.ctx.master_length = self.ctx.master_length.saturating_sub(ret as u64);
// Only the current (innermost) master's remaining length is lowered here;
// saturating so that over-reading never wraps around.
let len = self.current_master_length();
self.set_current_master_length(len.saturating_sub(ret as u64));
Ok(ret)
}
}
@ -311,15 +331,63 @@ where
self.ctx.max_size_length = len
}
fn store_previous_master(&mut self) {
self.ctx.previous_master = self.ctx.current_master;
self.ctx.previous_master_length = self.ctx.master_length;
/// Returns the master element on top of the master stack, or `None` when no
/// master element has been entered (`depth == 0`).
///
/// # Panics
///
/// Panics if `depth` is `0` while the master stack is not empty, or if `depth`
/// points past the end of the stack.
fn current_master(&self) -> Option<MasterElement> {
if self.ctx.depth == 0 {
// Depth and stack must agree: no depth means nothing on the stack
assert!(self.ctx.masters.is_empty());
return None;
}
Some(self.ctx.masters[(self.ctx.depth - 1) as usize].element)
}
/// Remaining length of the master element on top of the master stack.
///
/// Yields [`VInt::ZERO`] when no master element has been entered.
///
/// # Panics
///
/// Panics if `depth` is `0` while the master stack is not empty, or if `depth`
/// points past the end of the stack.
fn current_master_length(&self) -> VInt {
    match self.ctx.depth.checked_sub(1) {
        Some(top) => self.ctx.masters[usize::from(top)].remaining_length,
        None => {
            // Depth and stack must agree: no depth means nothing on the stack
            assert!(self.ctx.masters.is_empty());
            VInt::ZERO
        },
    }
}
/// Overwrites the remaining length of the master element on top of the master stack.
///
/// Does nothing when no master element has been entered. While the reader is
/// locked, `lock_len` is overwritten with the same value.
///
/// # Panics
///
/// Panics if `depth` is `0` while the master stack is not empty, or if `depth`
/// points past the end of the stack.
fn set_current_master_length(&mut self, length: VInt) {
    let Some(top) = self.ctx.depth.checked_sub(1) else {
        // Depth and stack must agree: no depth means nothing on the stack
        assert!(self.ctx.masters.is_empty());
        return;
    };

    // Keep the lock budget in step with the master being consumed
    if self.ctx.locked {
        self.ctx.lock_len = length;
    }

    let current = &mut self.ctx.masters[usize::from(top)];
    current.remaining_length = length;
}
/// Makes `master`, with `size` octets left to read, the current master element.
///
/// A fully consumed master at [`ROOT_DEPTH`] is replaced in place (for example when
/// moving from `\EBML` to `\Segment`); in every other case the new master is
/// nested one level deeper.
///
/// # Errors
///
/// Fails when the reader is already at [`MAX_DEPTH`].
fn push_new_master(&mut self, master: MasterElement, size: VInt) -> Result<()> {
    if self.ctx.depth == MAX_DEPTH {
        decode_err!(@BAIL Ebml, "Maximum depth reached");
    }

    // A spent top-level master means the new one is its sibling, not its child,
    // so the depth stays where it is.
    let replaces_top_level =
        self.ctx.depth == ROOT_DEPTH && self.current_master_length() == 0;

    if !replaces_top_level {
        self.ctx.depth += 1;
    } else {
        assert_eq!(self.ctx.masters.len(), 1);
        self.ctx.masters.clear();
    }

    let entry = MasterElementContext {
        element: master,
        remaining_length: size,
    };
    self.ctx.masters.push(entry);

    Ok(())
}
fn goto_next_master(&mut self) -> Result<ElementReaderYield> {
if self.ctx.master_length != 0 {
self.skip(self.ctx.master_length.value())?;
}
self.exhaust_current_master()?;
let header = ElementHeader::read(
&mut self.reader,
@ -331,35 +399,34 @@ where
return Ok(ElementReaderYield::Unknown(header));
};
self.store_previous_master();
self.ctx.current_master = Some(*master);
self.ctx.master_length = header.size;
Ok(ElementReaderYield::Master((
master.id,
self.ctx.master_length,
)))
self.push_new_master(*master, header.size)?;
Ok(ElementReaderYield::Master((master.id, header.size)))
}
pub(crate) fn goto_previous_master(&mut self) -> Result<()> {
if let Some(previous_master) = self.ctx.previous_master {
self.ctx.current_master = Some(previous_master);
self.ctx.master_length = self.ctx.previous_master_length;
Ok(())
} else {
decode_err!(@BAIL Ebml, "Expected a parent element to be available")
/// Leaves the current master element: skips whatever is left of it, then pops
/// it off the master stack so its parent becomes current again.
///
/// # Errors
///
/// Fails when no master element is open, when `depth` equals `lock_depth`, or
/// when skipping the rest of the current master fails.
fn goto_previous_master(&mut self) -> Result<()> {
// NOTE(review): the message says "already at root", but the same error is
// raised when `depth == lock_depth` — confirm the wording is intended.
if self.ctx.depth == 0 || self.ctx.depth == self.ctx.lock_depth {
decode_err!(@BAIL Ebml, "Cannot go to previous master element, already at root")
}
self.exhaust_current_master()?;
self.ctx.depth -= 1;
// NOTE(review): the parent's `remaining_length` is not adjusted here for the
// bytes consumed inside the popped child — confirm that is accounted for elsewhere.
let _ = self.ctx.masters.pop();
Ok(())
}
pub(crate) fn next(&mut self) -> Result<ElementReaderYield> {
let Some(current_master) = self.ctx.current_master else {
let Some(current_master) = self.current_master() else {
return self.goto_next_master();
};
if self.ctx.master_length == 0 {
if self.ctx.locked {
return Ok(ElementReaderYield::Eof);
}
if self.ctx.locked && self.ctx.lock_len == 0 {
return Ok(ElementReaderYield::Eof);
}
if self.current_master_length() == 0 {
return self.goto_next_master();
}
@ -374,13 +441,12 @@ where
};
if child.data_type == ElementDataType::Master {
self.store_previous_master();
self.ctx.current_master = Some(
self.push_new_master(
*master_elements()
.get(&header.id)
.expect("Nested master elements should be defined at this level."),
);
self.ctx.master_length = header.size;
header.size,
)?;
// We encountered a nested master element
return Ok(ElementReaderYield::Master((child.ident, header.size)));
@ -389,11 +455,22 @@ where
Ok(ElementReaderYield::Child((*child, header.size)))
}
fn lock(&mut self) {
self.ctx.locked = true;
/// Skips everything that is left of the current master element.
///
/// Returns immediately when nothing is left (or no master element is open).
///
/// # Errors
///
/// Propagates any error from `skip`.
pub(crate) fn exhaust_current_master(&mut self) -> Result<()> {
let master_length = self.current_master_length().value();
if master_length == 0 {
return Ok(());
}
self.skip(master_length)?;
Ok(())
}
fn unlock(&mut self) {
/// Locks the reader, capturing the current master's remaining length as the
/// amount that may still be read.
///
/// NOTE(review): `lock_depth` is not updated here although `goto_previous_master`
/// and `master_exhausted` read it — confirm it is set elsewhere.
pub(crate) fn lock(&mut self) {
self.ctx.locked = true;
self.ctx.lock_len = self.current_master_length();
}
/// Clears the `locked` flag.
///
/// Only the flag is reset; `lock_len` and `lock_depth` keep their last values.
pub(crate) fn unlock(&mut self) {
self.ctx.locked = false;
}
@ -403,13 +480,20 @@ where
}
/// Discards the next `length` bytes of the current master element.
///
/// The bytes are pulled through this reader's own `Read` implementation and
/// thrown away.
///
/// # Errors
///
/// Fails when `length` is larger than what is left of the current master
/// element, or when reading fails.
pub(crate) fn skip(&mut self, length: u64) -> Result<()> {
    log::trace!("Skipping {} bytes", length);

    let remaining = self.current_master_length().value();
    if length > remaining {
        decode_err!(@BAIL Ebml, "Cannot skip past the end of the current master element")
    }

    // Drain exactly `length` bytes into a sink
    let mut limited = self.by_ref().take(length);
    let mut discard = std::io::sink();
    std::io::copy(&mut limited, &mut discard)?;

    Ok(())
}
pub(crate) fn skip_element(&mut self, element_header: ElementHeader) -> Result<()> {
log::debug!(
"Encountered unknown EBML element in header: {:X}",
"Encountered unknown EBML element: {:X}, skipping",
element_header.id.0
);
self.skip(element_header.size.value())?;
@ -445,6 +529,16 @@ where
Ok(u64::from_be_bytes(buf))
}
/// Reads an unsigned integer element and interprets it as a boolean flag.
///
/// `0` is `false` and `1` is `true`. Any larger value is also treated as `true`,
/// after logging a warning that it is out of range.
///
/// # Errors
///
/// Propagates any error from `read_unsigned_int`.
pub(crate) fn read_flag(&mut self, element_length: u64) -> Result<bool> {
    let raw = self.read_unsigned_int(element_length)?;

    match raw {
        0 => Ok(false),
        1 => Ok(true),
        _ => {
            log::warn!("Flag value `{}` is out of range, assuming true", raw);
            Ok(true)
        },
    }
}
pub(crate) fn read_float(&mut self, element_length: u64) -> Result<f64> {
// https://www.rfc-editor.org/rfc/rfc8794.html#section-7.3
// A Float Element MUST declare a length of either zero octets (0 bit),
@ -513,18 +607,16 @@ where
self.reader.skip_element(header)?;
self.next()
},
Ok(ElementReaderYield::Eof) => Ok(None),
Err(e) => Err(e),
element => element.map(Some),
}
}
/// Reports whether the master stack entry at index `lock_depth` has nothing
/// left to read.
///
/// # Panics
///
/// Panics when `lock_depth` is not strictly below the reader's current `depth`.
pub(crate) fn master_exhausted(&self) -> bool {
self.reader.ctx.master_length == 0
}
let lock_depth = self.reader.ctx.lock_depth;
assert!(lock_depth < self.reader.ctx.depth);
pub(crate) fn inner(&mut self) -> &mut ElementReader<R> {
self.reader
// NOTE(review): elsewhere the current master is `masters[depth - 1]`, here
// `lock_depth` is used as the index directly — confirm the off-by-one is intended.
self.reader.ctx.masters[lock_depth as usize].remaining_length == 0
}
}

View file

@ -77,18 +77,29 @@ where
let mut child_reader = element_reader.children();
while let Some(child) = child_reader.next()? {
let ident;
let data_ty;
let size;
match child {
// The only expected master element in the header is `DocTypeExtension`
ElementReaderYield::Master((ElementIdent::DocTypeExtension, _)) => continue,
ElementReaderYield::Master((ElementIdent::DocTypeExtension, size)) => {
child_reader.skip(size.value())?;
continue;
},
ElementReaderYield::Master(_) => {
decode_err!(
@BAIL Ebml,
"Unexpected master element in the EBML header"
);
},
ElementReaderYield::Child((child, size_)) => {
ident = child.ident;
data_ty = child.data_type;
size = size_;
},
_ => break,
ElementReaderYield::Unknown(header) => {
child_reader.skip_element(header)?;
continue;
},
ElementReaderYield::Eof => break,
}
if ident == ElementIdent::EBMLMaxIDLength {

View file

@ -3,9 +3,10 @@ use super::{
segment_tracks,
};
use crate::config::ParseOptions;
use crate::ebml::element_reader::{ElementIdent, ElementReader, ElementReaderYield};
use crate::ebml::element_reader::{ElementHeader, ElementIdent, ElementReader, ElementReaderYield};
use crate::ebml::properties::EbmlProperties;
use crate::ebml::tag::EbmlTag;
use crate::ebml::VInt;
use crate::error::Result;
use crate::macros::decode_err;
@ -21,30 +22,29 @@ where
{
let mut tags = None;
let mut children_reader = element_reader.children();
while let Some(child) = children_reader.next()? {
element_reader.lock();
loop {
let child = element_reader.next()?;
match child {
ElementReaderYield::Master((id, size)) => match id {
ElementIdent::Info if parse_options.read_properties => {
segment_info::read_from(children_reader.inner(), parse_options, properties)?
segment_info::read_from(element_reader, parse_options, properties)?
},
ElementIdent::Cluster if parse_options.read_properties => {
segment_cluster::read_from(children_reader.inner(), parse_options, properties)?
segment_cluster::read_from(element_reader, parse_options, properties)?
},
ElementIdent::Tracks if parse_options.read_properties => {
segment_tracks::read_from(children_reader.inner(), parse_options, properties)?
segment_tracks::read_from(element_reader, parse_options, properties)?
},
ElementIdent::Tags | ElementIdent::Chapters if parse_options.read_tags => {
let mut tag = tags.unwrap_or_default();
if id == ElementIdent::Tags {
segment_tags::read_from(children_reader.inner(), parse_options, &mut tag)?
segment_tags::read_from(element_reader, parse_options, &mut tag)?
} else {
segment_chapters::read_from(
children_reader.inner(),
parse_options,
&mut tag,
)?
segment_chapters::read_from(element_reader, parse_options, &mut tag)?
}
tags = Some(tag);
@ -52,11 +52,7 @@ where
ElementIdent::Attachments if parse_options.read_cover_art => {
let mut tag = tags.unwrap_or_default();
segment_attachments::read_from(
children_reader.inner(),
parse_options,
&mut tag,
)?;
segment_attachments::read_from(element_reader, parse_options, &mut tag)?;
tags = Some(tag);
},
@ -64,16 +60,24 @@ where
// We do not end up using information from all of the segment
// elements, so we can just skip any useless ones.
log::debug!("Skipping EBML master element: {:?}", id);
children_reader.skip(size.value())?;
children_reader.goto_previous_master()?;
element_reader.skip_element(ElementHeader {
id: VInt(id as u64),
size,
})?;
continue;
},
},
ElementReaderYield::Unknown(header) => {
element_reader.skip_element(header)?;
continue;
},
ElementReaderYield::Child(_) => {
decode_err!(@BAIL Ebml, "Segment element should only contain master elements")
},
_ => break,
ElementReaderYield::Eof => {
element_reader.unlock();
break;
},
}
}

View file

@ -24,7 +24,6 @@ where
log::debug!("Skipping EBML master element: {:?}", id);
children_reader.skip(size.value())?;
children_reader.goto_previous_master()?;
continue;
},
ElementReaderYield::Child((child, size)) => {
@ -58,11 +57,13 @@ where
},
}
},
ElementReaderYield::Unknown(header) => {
children_reader.skip_element(header)?;
continue;
},
_ => break,
}
}
drop(children_reader);
element_reader.goto_previous_master()?;
Ok(())
}

View file

@ -3,6 +3,7 @@ use crate::ebml::element_reader::{
ChildElementDescriptor, ElementHeader, ElementIdent, ElementReader, ElementReaderYield,
};
use crate::ebml::properties::EbmlProperties;
use crate::ebml::VInt;
use crate::error::Result;
use crate::macros::decode_err;
@ -16,15 +17,21 @@ pub(super) fn read_from<R>(
where
R: Read + Seek,
{
let mut children_reader = element_reader.children();
element_reader.lock();
let mut audio_tracks = Vec::new();
loop {
let child = element_reader.next()?;
if let ElementReaderYield::Eof = child {
break;
}
while let Some(child) = children_reader.next()? {
match child {
ElementReaderYield::Master((ElementIdent::TrackEntry, size)) => {
read_track_entry(children_reader.inner(), parse_options, properties)?;
},
ElementReaderYield::Eof => {
break;
element_reader.unlock();
read_track_entry(element_reader, parse_options, &mut audio_tracks)?;
element_reader.lock();
},
_ => {
let id = child
@ -36,22 +43,112 @@ where
"Unexpected child element in \\EBML\\Segment\\Tracks: {:?}, skipping",
id
);
children_reader.skip(size)?;
element_reader.skip(size)?;
continue;
},
}
}
element_reader.unlock();
Ok(())
}
/// Fields collected from a single `TrackEntry` while it is being read.
///
/// NOTE(review): the derived `Default` starts both flags as `false`, so an entry
/// that carries no `FlagEnabled` element ends up treated as disabled — confirm
/// this matches the defaults the Matroska specification gives these flags.
#[derive(Default)]
struct AudioTrack {
/// Value of the `FlagDefault` element
default: bool,
/// Value of the `FlagEnabled` element
enabled: bool,
/// Value of the `CodecID` element
codec_id: String,
/// Value of the `CodecName` element
codec_name: String,
}
/// `TrackType` value a track entry must carry to be kept; entries with any other value are abandoned.
const AUDIO_TRACK_TYPE: u64 = 2;
// Reads one `TrackEntry` master element and, if it is kept, appends the
// collected `AudioTrack` to `audio_tracks`.
//
// The reader is locked for the duration of the entry. Child elements are read
// until the reader yields `Eof`; a `TrackType` other than `AUDIO_TRACK_TYPE`
// makes the rest of the entry get skipped.
//
// Errors from the reader are propagated.
fn read_track_entry<R>(
_element_reader: &mut ElementReader<R>,
_parse_options: ParseOptions,
_properties: &mut EbmlProperties,
element_reader: &mut ElementReader<R>,
parse_options: ParseOptions,
audio_tracks: &mut Vec<AudioTrack>,
) -> Result<()>
where
R: Read + Seek,
{
element_reader.lock();
let mut track = AudioTrack::default();
loop {
let child = element_reader.next()?;
if let ElementReaderYield::Eof = child {
break;
}
match child {
ElementReaderYield::Child((ChildElementDescriptor { ident, .. }, size)) => {
match ident {
ElementIdent::TrackType => {
let track_type = element_reader.read_unsigned_int(size.value())?;
log::trace!("Encountered new track of type: {}", track_type);
// Not an audio track: discard the rest of the entry and stop reading it.
// NOTE(review): after this `break` control still reaches the `enabled`
// check and the push below, so a non-audio entry whose `FlagEnabled`
// was already read as true would be pushed — confirm.
if track_type != AUDIO_TRACK_TYPE {
element_reader.exhaust_current_master()?;
break;
}
},
ElementIdent::FlagEnabled => {
let enabled = element_reader.read_flag(size.value())?;
track.enabled = enabled;
},
ElementIdent::FlagDefault => {
let default = element_reader.read_flag(size.value())?;
track.default = default;
},
// The following values are read (which consumes them) but not stored yet
ElementIdent::DefaultDuration => {
let _default_duration = element_reader.read_unsigned_int(size.value())?;
},
ElementIdent::TrackTimecodeScale => {
let _timecode_scale = element_reader.read_float(size.value())?;
},
ElementIdent::Language => {
let _language = element_reader.read_string(size.value())?;
},
ElementIdent::CodecID => {
let codec_id = element_reader.read_string(size.value())?;
track.codec_id = codec_id;
},
ElementIdent::CodecDelay => {
let _codec_delay = element_reader.read_unsigned_int(size.value())?;
},
ElementIdent::CodecName => {
let codec_name = element_reader.read_utf8(size.value())?;
track.codec_name = codec_name;
},
ElementIdent::SeekPreRoll => {
let _seek_pre_roll = element_reader.read_unsigned_int(size.value())?;
},
// NOTE(review): this panics if the reader can yield any other known
// `TrackEntry` child — confirm the list above is exhaustive.
_ => unreachable!("Unhandled child element in TrackEntry: {:?}", ident),
}
},
ElementReaderYield::Master((id, size)) => match id {
// The nested `Audio` master is skipped wholesale for now
ElementIdent::Audio => {
element_reader.skip(size.value())?;
},
// NOTE(review): panics on any other nested master element — confirm none can occur
_ => {
unreachable!("Unhandled master element in TrackEntry: {:?}", id);
},
},
ElementReaderYield::Unknown(ElementHeader { size, id }) => {
element_reader.skip(size.value())?;
},
_ => {},
}
}
// NOTE(review): this early return leaves the reader locked, the `unlock` call
// below is never reached on this path — confirm callers re-lock/unlock as needed.
if !track.enabled {
log::debug!("Skipping disabled track");
return Ok(());
}
audio_tracks.push(track);
element_reader.unlock();
Ok(())
}