diff --git a/bevy_legion/legion_core/src/filter.rs b/bevy_legion/legion_core/src/filter.rs index 96782c4fa9..f4ea211e1b 100644 --- a/bevy_legion/legion_core/src/filter.rs +++ b/bevy_legion/legion_core/src/filter.rs @@ -95,6 +95,9 @@ impl FilterResult for Option { pub trait Filter: Send + Sync + Sized { type Iter: Iterator + Send + Sync; + // Called when a query is about to begin execution. + fn init(&self) {} + /// Pulls iterator data out of the source. fn collect(&self, source: T) -> Self::Iter; @@ -168,6 +171,9 @@ pub trait EntityFilter: Send + Clone { type ChunksetFilter: for<'a> Filter> + Clone; type ChunkFilter: for<'a> Filter> + Clone; + /// Initializes the entity filter for iteration. + fn init(&self); + /// Gets mutable references to both inner filters. fn filters( &self, @@ -239,6 +245,12 @@ where type ChunksetFilter = S; type ChunkFilter = C; + fn init(&self) { + self.arch_filter.init(); + self.chunkset_filter.init(); + self.chunk_filter.init(); + } + fn filters( &self, ) -> ( @@ -380,7 +392,7 @@ impl<'a, 'b, F: Filter>> Iterator for FilterArchIter<'a, } } -/// An iterator which yields the index of chunks that match a filter. +/// An iterator which yields the index of chuinks that match a filter. pub struct FilterChunkIter<'a, 'b, F: Filter>> { filter: &'b F, chunks: Enumerate, @@ -458,6 +470,9 @@ pub struct Passthrough; impl<'a> Filter> for Passthrough { type Iter = Take>; + #[inline] + fn init(&self) {} + #[inline] fn collect(&self, arch: ArchetypeFilterData<'a>) -> Self::Iter { std::iter::repeat(()).take(arch.component_types.len()) @@ -470,6 +485,9 @@ impl<'a> Filter> for Passthrough { impl<'a> Filter> for Passthrough { type Iter = Take>; + #[inline] + fn init(&self) {} + #[inline] fn collect(&self, sets: ChunksetFilterData<'a>) -> Self::Iter { std::iter::repeat(()).take(sets.archetype_data.len()) @@ -482,6 +500,9 @@ impl<'a> Filter> for Passthrough { impl<'a> Filter> for Passthrough { type Iter = Take>; + #[inline] + fn init(&self) {} + #[inline] fn collect(&self, chunk: ChunkFilterData<'a>) -> Self::Iter { std::iter::repeat(()).take(chunk.chunks.len()) @@ -520,6 +541,9 @@ impl ActiveFilter for Any {} impl<'a> Filter> for Any { type Iter = Take>; + #[inline] + fn init(&self) {} + #[inline] fn collect(&self, arch: ArchetypeFilterData<'a>) -> Self::Iter { std::iter::repeat(()).take(arch.component_types.len()) @@ -532,6 +556,9 @@ impl<'a> Filter> for Any { impl<'a> Filter> for Any { type Iter = Take>; + #[inline] + fn init(&self) {} + #[inline] fn collect(&self, sets: ChunksetFilterData<'a>) -> Self::Iter { std::iter::repeat(()).take(sets.archetype_data.len()) @@ -544,6 +571,9 @@ impl<'a> Filter> for Any { impl<'a> Filter> for Any { type Iter = Take>; + #[inline] + fn init(&self) {} + #[inline] fn collect(&self, chunk: ChunkFilterData<'a>) -> Self::Iter { std::iter::repeat(()).take(chunk.chunks.len()) @@ -592,6 +622,9 @@ impl ActiveFilter for Not {} impl<'a, T: Copy, F: Filter> Filter for Not { type Iter = F::Iter; + #[inline] + fn init(&self) { self.filter.init(); } + #[inline] fn collect(&self, source: T) -> Self::Iter { self.filter.collect(source) } @@ -648,6 +681,9 @@ impl ActiveFilter for And<(T,)> {} impl<'a, T: Copy, F: Filter> Filter for And<(F,)> { type Iter = F::Iter; + #[inline] + fn init(&self) { self.filters.0.init(); } + #[inline] fn collect(&self, source: T) -> Self::Iter { self.filters.0.collect(source) } @@ -717,6 +753,13 @@ macro_rules! impl_and_filter { // type Iter = crate::zip::Zip<( $( $ty::Iter ),* )>; type Iter = recursive_zip!(@type $($ty::Iter),*); + #[inline] + fn init(&self) { + #![allow(non_snake_case)] + let ($( $ty, )*) = &self.filters; + $( $ty.init(); )* + } + fn collect(&self, source: T) -> Self::Iter { #![allow(non_snake_case)] let ($( $ty, )*) = &self.filters; @@ -818,6 +861,13 @@ macro_rules! impl_or_filter { // type Iter = crate::zip::Zip<( $( $ty::Iter ),* )>; type Iter = recursive_zip!(@type $($ty::Iter),*); + #[inline] + fn init(&self) { + #![allow(non_snake_case)] + let ($( $ty, )*) = &self.filters; + $( $ty.init(); )* + } + fn collect(&self, source: T) -> Self::Iter { #![allow(non_snake_case)] let ($( $ty, )*) = &self.filters; @@ -905,7 +955,7 @@ impl_or_filter!(A => a, B => b, C => c, D => d, E => e, F => f, G => g, H => h, impl_or_filter!(A => a, B => b, C => c, D => d, E => e, F => f, G => g, H => h, I => i, J => j, K => k); impl_or_filter!(A => a, B => b, C => c, D => d, E => e, F => f, G => g, H => h, I => i, J => j, K => k, L => l); -/// A filter which requires that all chunks contain entity data components of type `T`. +/// A filter qhich requires that all chunks contain entity data components of type `T`. #[derive(Debug)] pub struct ComponentFilter(PhantomData); @@ -923,6 +973,9 @@ impl Clone for ComponentFilter { impl<'a, T: Component> Filter> for ComponentFilter { type Iter = SliceVecIter<'a, ComponentTypeId>; + #[inline] + fn init(&self) {} + #[inline] fn collect(&self, source: ArchetypeFilterData<'a>) -> Self::Iter { source.component_types.iter() @@ -995,6 +1048,9 @@ impl Clone for TagFilter { impl<'a, T: Tag> Filter> for TagFilter { type Iter = SliceVecIter<'a, TagTypeId>; + #[inline] + fn init(&self) {} + #[inline] fn collect(&self, source: ArchetypeFilterData<'a>) -> Self::Iter { source.tag_types.iter() } @@ -1067,6 +1123,9 @@ impl<'a, T> Clone for TagValueFilter<'a, T> { impl<'a, 'b, T: Tag> Filter> for TagValueFilter<'b, T> { type Iter = Iter<'a, T>; + #[inline] + fn init(&self) {} + fn collect(&self, source: ChunksetFilterData<'a>) -> Self::Iter { unsafe { source @@ -1132,14 +1191,16 @@ impl<'a, T> std::ops::BitOr for TagValueFilter<'a, T> { /// chunk since the last time the filter was executed. #[derive(Debug)] pub struct ComponentChangedFilter { - last_read_version: AtomicU64, + high_water_mark: AtomicU64, + version_threshold: AtomicU64, phantom: PhantomData, } impl ComponentChangedFilter { fn new() -> ComponentChangedFilter { ComponentChangedFilter { - last_read_version: AtomicU64::new(0), + high_water_mark: AtomicU64::new(0), + version_threshold: AtomicU64::new(0), phantom: PhantomData, } } @@ -1150,29 +1211,63 @@ impl ActiveFilter for ComponentChangedFilter {} impl Clone for ComponentChangedFilter { fn clone(&self) -> Self { Self { - last_read_version: AtomicU64::new(self.last_read_version.load(Ordering::Relaxed)), + high_water_mark: AtomicU64::new(self.high_water_mark.load(Ordering::Relaxed)), + version_threshold: AtomicU64::new(self.version_threshold.load(Ordering::Relaxed)), phantom: PhantomData, } } } impl<'a, T: Component> Filter> for ComponentChangedFilter { - type Iter = Iter<'a, ComponentStorage>; + type Iter = ComponentChangedState<'a, ComponentStorage>; - fn collect(&self, source: ChunkFilterData<'a>) -> Self::Iter { source.chunks.iter() } + #[inline] + fn init(&self) { + let version = self.high_water_mark.load(Ordering::Relaxed); + let mut threshold = self.version_threshold.load(Ordering::Relaxed); + if threshold < version { + loop { + match self.version_threshold.compare_exchange_weak( + threshold, + version, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(stored_last_read) => { + threshold = stored_last_read; + if threshold >= version { + // matched version is already considered visited, update no longer needed + break; + } + } + } + } + } + } + + fn collect(&self, source: ChunkFilterData<'a>) -> Self::Iter { + let compare_version = self.version_threshold.load(Ordering::Relaxed); + ComponentChangedState { + iter: source.chunks.iter(), + version_threshold: compare_version, + } + } #[inline] fn is_match(&self, item: &::Item) -> Option { - let components = item.components(ComponentTypeId::of::()); + let (version_threshold, storage) = item; + + let components = storage.components(ComponentTypeId::of::()); if components.is_none() { return Some(false); } let version = components.unwrap().version(); - let mut last_read = self.last_read_version.load(Ordering::Relaxed); + let mut last_read = self.high_water_mark.load(Ordering::Relaxed); if last_read < version { loop { - match self.last_read_version.compare_exchange_weak( + match self.high_water_mark.compare_exchange_weak( last_read, version, Ordering::Relaxed, @@ -1181,13 +1276,16 @@ impl<'a, T: Component> Filter> for ComponentChangedFilter Ok(_) => break, Err(stored_last_read) => { last_read = stored_last_read; - if last_read < version { + if last_read >= version { // matched version is already considered visited, update no longer needed break; } } } } + } + + if version > *version_threshold { Some(true) } else { Some(false) @@ -1195,6 +1293,19 @@ impl<'a, T: Component> Filter> for ComponentChangedFilter } } +pub struct ComponentChangedState<'a, T: Component> { + iter: Iter<'a, T>, + version_threshold: u64, +} + +impl<'a, T: Component> Iterator for ComponentChangedState<'a, T> { + type Item = (u64, &'a T); + + fn next(&mut self) -> Option { + self.iter.next().map(|c| (self.version_threshold, c)) + } +} + impl<'a, T: Component> std::ops::Not for ComponentChangedFilter { type Output = Not; @@ -1241,6 +1352,7 @@ impl<'a, T: Component> std::ops::BitOr for ComponentChangedFilter() | tag_value(&5isize); tracing::trace!(?filter); } + + #[test] + fn component_changed_filter() { + let _ = tracing_subscriber::fmt::try_init(); + + let universe = Universe::new(); + let mut world = universe.create_world(); + + let entity1 = world.insert((), vec![(1usize,)])[0]; + let entity2 = world.insert((), vec![(2usize, false)])[0]; + + let query = >::query().filter(changed::()); + + assert_eq!(2, query.iter_chunks(&world).collect::>().len()); + + *world.get_component_mut::(entity1).unwrap() = 3usize; + + assert_eq!(1, query.iter_chunks(&world).collect::>().len()); + + *world.get_component_mut::(entity1).unwrap() = 4usize; + *world.get_component_mut::(entity2).unwrap() = 5usize; + + assert_eq!(2, query.iter_chunks(&world).collect::>().len()); + + *world.get_component_mut::(entity1).unwrap() = 6usize; + *world.get_component_mut::(entity1).unwrap() = 7usize; + *world.get_component_mut::(entity2).unwrap() = 8usize; + + assert_eq!(2, query.iter_chunks(&world).collect::>().len()); + + *world.get_component_mut::(entity2).unwrap() = 6usize; + *world.get_component_mut::(entity2).unwrap() = 7usize; + *world.get_component_mut::(entity1).unwrap() = 8usize; + + assert_eq!(2, query.iter_chunks(&world).collect::>().len()); + } } diff --git a/bevy_legion/legion_core/src/query.rs b/bevy_legion/legion_core/src/query.rs index 292f4a11c5..04a4cbdacf 100644 --- a/bevy_legion/legion_core/src/query.rs +++ b/bevy_legion/legion_core/src/query.rs @@ -949,6 +949,7 @@ where &'a self, world: &'data World, ) -> ChunkViewIter<'data, 'a, V, F::ArchetypeFilter, F::ChunksetFilter, F::ChunkFilter> { + self.filter.init(); let (arch_filter, chunkset_filter, chunk_filter) = self.filter.filters(); let storage = world.storage(); let archetypes = arch_filter @@ -1190,6 +1191,7 @@ where >>::Iter: FissileIterator, >>::Iter: FissileIterator, { + self.filter.init(); let (arch_filter, chunkset_filter, chunk_filter) = self.filter.filters(); let storage = world.storage(); let archetypes = FissileEnumerate::new(arch_filter.collect(ArchetypeFilterData { diff --git a/bevy_legion/legion_core/src/storage.rs b/bevy_legion/legion_core/src/storage.rs index d54302dfaf..276d6f2a49 100644 --- a/bevy_legion/legion_core/src/storage.rs +++ b/bevy_legion/legion_core/src/storage.rs @@ -1425,7 +1425,10 @@ impl ComponentStorage { // free component memory unsafe { let ptr = self.component_data.take().unwrap(); - std::alloc::dealloc(ptr.as_ptr(), self.component_layout); + + if self.component_layout.size() > 0 { + std::alloc::dealloc(ptr.as_ptr(), self.component_layout); + } } self.update_mem_gauge(); @@ -1446,13 +1449,18 @@ impl ComponentStorage { unsafe { // allocating backing store - let ptr = std::alloc::alloc(self.component_layout); - self.component_data = Some(NonNull::new_unchecked(ptr)); + if self.component_layout.size() > 0 { + let ptr = std::alloc::alloc(self.component_layout); + self.component_data = Some(NonNull::new_unchecked(ptr)); - // update accessor pointers - for (type_id, component) in (&mut *self.component_info.get()).iter_mut() { - let &offset = self.component_offsets.get(type_id).unwrap(); - *component.ptr.get_mut() = ptr.add(offset); + // update accessor pointers + for (type_id, component) in (&mut *self.component_info.get()).iter_mut() { + let &offset = self.component_offsets.get(type_id).unwrap(); + *component.ptr.get_mut() = ptr.add(offset); + } + } else { + self.component_data = + Some(NonNull::new(self.component_layout.align() as *mut u8).unwrap()); } } @@ -1516,8 +1524,10 @@ impl Drop for ComponentStorage { self.update_count_gauge(); // free the chunk's memory - unsafe { - std::alloc::dealloc(ptr.as_ptr(), self.component_layout); + if self.component_layout.size() > 0 { + unsafe { + std::alloc::dealloc(ptr.as_ptr(), self.component_layout); + } } } } diff --git a/bevy_legion/legion_core/src/world.rs b/bevy_legion/legion_core/src/world.rs index 23acc68de9..ac56bb6d5b 100644 --- a/bevy_legion/legion_core/src/world.rs +++ b/bevy_legion/legion_core/src/world.rs @@ -462,8 +462,6 @@ impl World { /// This function has the overhead of moving the entity to either an existing or new archetype, /// causing a memory copy of the entity to a new location. This function should not be used /// multiple times in successive order. - /// - /// `World::add_components` should be used for adding multiple omponents to an entity at once. pub fn add_component( &mut self, entity: Entity, diff --git a/bevy_legion/legion_systems/src/schedule.rs b/bevy_legion/legion_systems/src/schedule.rs index 45bb7cf3dc..823261d7ac 100644 --- a/bevy_legion/legion_systems/src/schedule.rs +++ b/bevy_legion/legion_systems/src/schedule.rs @@ -168,8 +168,13 @@ impl Executor { 64, Default::default(), ); - let mut component_mutated = - FxHashMap::>::with_capacity_and_hasher( + let mut component_last_mutated = + FxHashMap::::with_capacity_and_hasher( + 64, + Default::default(), + ); + let mut component_last_read = + FxHashMap::::with_capacity_and_hasher( 64, Default::default(), ); @@ -220,23 +225,28 @@ impl Executor { // find component access dependencies let mut comp_dependencies = FxHashSet::default(); - for comp in read_comp { - if let Some(ns) = component_mutated.get(comp) { - for n in ns { - comp_dependencies.insert(*n); - } - } - } for comp in write_comp { - if let Some(ns) = component_mutated.get(comp) { - for n in ns { - comp_dependencies.insert(*n); - } + // Writes have to be exclusive, so we are dependent on reads too + trace!(component = ?comp, "Write component"); + if let Some(n) = component_last_read.get(comp) { + trace!(system_index = n, "Added read dependency"); + comp_dependencies.insert(*n); } - component_mutated - .entry(*comp) - .or_insert_with(Vec::new) - .push(i); + if let Some(n) = component_last_mutated.get(comp) { + trace!(system_index = n, "Added write dependency"); + comp_dependencies.insert(*n); + } + component_last_mutated.insert(*comp, i); + } + + // Do reads after writes to ensure we don't overwrite last_read + for comp in read_comp { + trace!(component = ?comp, "Read component"); + if let Some(n) = component_last_mutated.get(comp) { + trace!(system_index = n, "Added write dependency"); + comp_dependencies.insert(*n); + } + component_last_read.insert(*comp, i); } trace!(depentants = ?comp_dependencies, "Computed dynamic dependants"); diff --git a/bevy_legion/legion_systems/src/system.rs b/bevy_legion/legion_systems/src/system.rs index 80e721ab44..ee14f14984 100644 --- a/bevy_legion/legion_systems/src/system.rs +++ b/bevy_legion/legion_systems/src/system.rs @@ -1627,11 +1627,11 @@ mod tests { struct Comp2(f32, f32, f32); let components = vec![ - (Pos(69., 69., 69.), Vel(69., 69., 69.)), - (Pos(69., 69., 69.), Vel(69., 69., 69.)), + (Comp1(69., 69., 69.), Comp2(69., 69., 69.)), + (Comp1(69., 69., 69.), Comp2(69., 69., 69.)), ]; - let mut expected = HashMap::::new(); + let mut expected = HashMap::::new(); for (i, e) in world.insert((), components.clone()).iter().enumerate() { if let Some((pos, rot)) = components.get(i) { @@ -1670,7 +1670,7 @@ mod tests { let system3 = SystemBuilder::<()>::new("TestSystem3") .with_query(<(Write, Write)>::query()) .build(move |_, world, _, query| { - query.iter_mut(world).for_each(|(mut one, two)| { + query.iter_mut(world).for_each(|(mut one, mut two)| { assert_eq!(one.0, 456.); assert_eq!(one.1, 456.); assert_eq!(one.2, 456.); @@ -1683,9 +1683,9 @@ mod tests { one.1 = 789.; one.2 = 789.; - one.0 = 789.; - one.1 = 789.; - one.2 = 789.; + two.0 = 789.; + two.1 = 789.; + two.2 = 789.; }); });