Fix query transmute from table to archetype iteration unsoundness (#14615)

# Objective

- Fixes #14348 
- Fixes #14528
- Less complex (but also likely less performant) alternative to #14611

## Solution

- Add a `is_dense` field flag to `QueryIter` indicating whether it is
dense or not, that is whether it can perform dense iteration or not;
- Check this flag any time iteration over a query is performed.

---

It would be nice if someone could try benching this change to see if it
actually matters.

~Note that this not 100% ready for mergin, since there are a bunch of
safety comments on the use of the various `IS_DENSE` for checks that
still need to be updated.~ This is ready modulo benchmarks

---------

Co-authored-by: Alice Cecile <alice.i.cecile@gmail.com>
This commit is contained in:
Giacomo Stevanato 2024-08-27 02:58:40 +02:00 committed by GitHub
parent f06cd448db
commit e320fa0738
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 154 additions and 35 deletions

View file

@ -1,5 +1,6 @@
use std::marker::PhantomData;
use crate::component::StorageType;
use crate::{component::ComponentId, prelude::*};
use super::{FilteredAccess, QueryData, QueryFilter};
@ -68,6 +69,26 @@ impl<'w, D: QueryData, F: QueryFilter> QueryBuilder<'w, D, F> {
}
}
pub(super) fn is_dense(&self) -> bool {
// Note: `component_id` comes from the user in safe code, so we cannot trust it to
// exist. If it doesn't exist we pessimistically assume it's sparse.
let is_dense = |component_id| {
self.world()
.components()
.get_info(component_id)
.map_or(false, |info| info.storage_type() == StorageType::Table)
};
self.access
.access()
.component_reads_and_writes()
.all(is_dense)
&& self.access.access().archetypal().all(is_dense)
&& !self.access.access().has_read_all_components()
&& self.access.with_filters().all(is_dense)
&& self.access.without_filters().all(is_dense)
}
/// Returns a reference to the world passed to [`Self::new`].
pub fn world(&self) -> &World {
self.world
@ -396,4 +417,27 @@ mod tests {
assert_eq!(1, b.deref::<B>().0);
}
}
/// Regression test for issue #14348
#[test]
fn builder_static_dense_dynamic_sparse() {
#[derive(Component)]
struct Dense;
#[derive(Component)]
#[component(storage = "SparseSet")]
struct Sparse;
let mut world = World::new();
world.spawn(Dense);
world.spawn((Dense, Sparse));
let mut query = QueryBuilder::<&Dense>::new(&mut world)
.with::<Sparse>()
.build();
let matched = query.iter(&world).count();
assert_eq!(matched, 1);
}
}

View file

@ -128,7 +128,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIter<'w, 's, D, F> {
/// # Safety
/// - all `rows` must be in `[0, table.entity_count)`.
/// - `table` must match D and F
/// - Both `D::IS_DENSE` and `F::IS_DENSE` must be true.
/// - The query iteration must be dense (i.e. `self.query_state.is_dense` must be true).
#[inline]
pub(super) unsafe fn fold_over_table_range<B, Func>(
&mut self,
@ -183,7 +183,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIter<'w, 's, D, F> {
/// # Safety
/// - all `indices` must be in `[0, archetype.len())`.
/// - `archetype` must match D and F
/// - Either `D::IS_DENSE` or `F::IS_DENSE` must be false.
/// - The query iteration must not be dense (i.e. `self.query_state.is_dense` must be false).
#[inline]
pub(super) unsafe fn fold_over_archetype_range<B, Func>(
&mut self,
@ -252,7 +252,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIter<'w, 's, D, F> {
/// - all `indices` must be in `[0, archetype.len())`.
/// - `archetype` must match D and F
/// - `archetype` must have the same length with it's table.
/// - Either `D::IS_DENSE` or `F::IS_DENSE` must be false.
/// - The query iteration must not be dense (i.e. `self.query_state.is_dense` must be false).
#[inline]
pub(super) unsafe fn fold_over_dense_archetype_range<B, Func>(
&mut self,
@ -1031,20 +1031,27 @@ impl<'w, 's, D: QueryData, F: QueryFilter> Iterator for QueryIter<'w, 's, D, F>
let Some(item) = self.next() else { break };
accum = func(accum, item);
}
for id in self.cursor.storage_id_iter.clone() {
if D::IS_DENSE && F::IS_DENSE {
if self.cursor.is_dense {
for id in self.cursor.storage_id_iter.clone() {
// SAFETY: `self.cursor.is_dense` is true, so storage ids are guaranteed to be table ids.
let table_id = unsafe { id.table_id };
// SAFETY: Matched table IDs are guaranteed to still exist.
let table = unsafe { self.tables.get(id.table_id).debug_checked_unwrap() };
let table = unsafe { self.tables.get(table_id).debug_checked_unwrap() };
accum =
// SAFETY:
// - The fetched table matches both D and F
// - The provided range is equivalent to [0, table.entity_count)
// - The if block ensures that D::IS_DENSE and F::IS_DENSE are both true
// - The if block ensures that the query iteration is dense
unsafe { self.fold_over_table_range(accum, &mut func, table, 0..table.entity_count()) };
} else {
let archetype =
// SAFETY: Matched archetype IDs are guaranteed to still exist.
unsafe { self.archetypes.get(id.archetype_id).debug_checked_unwrap() };
}
} else {
for id in self.cursor.storage_id_iter.clone() {
// SAFETY: `self.cursor.is_dense` is false, so storage ids are guaranteed to be archetype ids.
let archetype_id = unsafe { id.archetype_id };
// SAFETY: Matched archetype IDs are guaranteed to still exist.
let archetype = unsafe { self.archetypes.get(archetype_id).debug_checked_unwrap() };
// SAFETY: Matched table IDs are guaranteed to still exist.
let table = unsafe { self.tables.get(archetype.table_id()).debug_checked_unwrap() };
@ -1052,19 +1059,19 @@ impl<'w, 's, D: QueryData, F: QueryFilter> Iterator for QueryIter<'w, 's, D, F>
// this leverages cache locality to optimize performance.
if table.entity_count() == archetype.len() {
accum =
// SAFETY:
// - The fetched archetype matches both D and F
// - The provided archetype and its' table have the same length.
// - The provided range is equivalent to [0, archetype.len)
// - The if block ensures that ether D::IS_DENSE or F::IS_DENSE are false
unsafe { self.fold_over_dense_archetype_range(accum, &mut func, archetype,0..archetype.len()) };
// SAFETY:
// - The fetched archetype matches both D and F
// - The provided archetype and its' table have the same length.
// - The provided range is equivalent to [0, archetype.len)
// - The if block ensures that the query iteration is not dense.
unsafe { self.fold_over_dense_archetype_range(accum, &mut func, archetype, 0..archetype.len()) };
} else {
accum =
// SAFETY:
// - The fetched archetype matches both D and F
// - The provided range is equivalent to [0, archetype.len)
// - The if block ensures that ether D::IS_DENSE or F::IS_DENSE are false
unsafe { self.fold_over_archetype_range(accum, &mut func, archetype,0..archetype.len()) };
// SAFETY:
// - The fetched archetype matches both D and F
// - The provided range is equivalent to [0, archetype.len)
// - The if block ensures that the query iteration is not dense.
unsafe { self.fold_over_archetype_range(accum, &mut func, archetype, 0..archetype.len()) };
}
}
}
@ -1675,6 +1682,8 @@ impl<'w, 's, D: QueryData, F: QueryFilter, const K: usize> Debug
}
struct QueryIterationCursor<'w, 's, D: QueryData, F: QueryFilter> {
// whether the query iteration is dense or not. Mirrors QueryState's `is_dense` field.
is_dense: bool,
storage_id_iter: std::slice::Iter<'s, StorageId>,
table_entities: &'w [Entity],
archetype_entities: &'w [ArchetypeEntity],
@ -1689,6 +1698,7 @@ struct QueryIterationCursor<'w, 's, D: QueryData, F: QueryFilter> {
impl<D: QueryData, F: QueryFilter> Clone for QueryIterationCursor<'_, '_, D, F> {
fn clone(&self) -> Self {
Self {
is_dense: self.is_dense,
storage_id_iter: self.storage_id_iter.clone(),
table_entities: self.table_entities,
archetype_entities: self.archetype_entities,
@ -1701,8 +1711,6 @@ impl<D: QueryData, F: QueryFilter> Clone for QueryIterationCursor<'_, '_, D, F>
}
impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
const IS_DENSE: bool = D::IS_DENSE && F::IS_DENSE;
unsafe fn init_empty(
world: UnsafeWorldCell<'w>,
query_state: &'s QueryState<D, F>,
@ -1732,6 +1740,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
table_entities: &[],
archetype_entities: &[],
storage_id_iter: query_state.matched_storage_ids.iter(),
is_dense: query_state.is_dense,
current_len: 0,
current_row: 0,
}
@ -1739,6 +1748,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
fn reborrow(&mut self) -> QueryIterationCursor<'_, 's, D, F> {
QueryIterationCursor {
is_dense: self.is_dense,
fetch: D::shrink_fetch(self.fetch.clone()),
filter: F::shrink_fetch(self.filter.clone()),
table_entities: self.table_entities,
@ -1754,7 +1764,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
unsafe fn peek_last(&mut self) -> Option<D::Item<'w>> {
if self.current_row > 0 {
let index = self.current_row - 1;
if Self::IS_DENSE {
if self.is_dense {
let entity = self.table_entities.get_unchecked(index);
Some(D::fetch(
&mut self.fetch,
@ -1780,7 +1790,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
/// will be **the exact count of remaining values**.
fn max_remaining(&self, tables: &'w Tables, archetypes: &'w Archetypes) -> usize {
let ids = self.storage_id_iter.clone();
let remaining_matched: usize = if Self::IS_DENSE {
let remaining_matched: usize = if self.is_dense {
// SAFETY: The if check ensures that storage_id_iter stores TableIds
unsafe { ids.map(|id| tables[id.table_id].entity_count()).sum() }
} else {
@ -1803,7 +1813,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
archetypes: &'w Archetypes,
query_state: &'s QueryState<D, F>,
) -> Option<D::Item<'w>> {
if Self::IS_DENSE {
if self.is_dense {
loop {
// we are on the beginning of the query, or finished processing a table, so skip to the next
if self.current_row == self.current_len {

View file

@ -126,7 +126,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> {
fn get_batch_size(&self, thread_count: usize) -> usize {
let max_items = || {
let id_iter = self.state.matched_storage_ids.iter();
if D::IS_DENSE && F::IS_DENSE {
if self.state.is_dense {
// SAFETY: We only access table metadata.
let tables = unsafe { &self.world.world_metadata().storages().tables };
id_iter

View file

@ -24,7 +24,10 @@ use super::{
/// An ID for either a table or an archetype. Used for Query iteration.
///
/// Query iteration is exclusively dense (over tables) or archetypal (over archetypes) based on whether
/// both `D::IS_DENSE` and `F::IS_DENSE` are true or not.
/// the query filters are dense or not. This is represented by the [`QueryState::is_dense`] field.
///
/// Note that `D::IS_DENSE` and `F::IS_DENSE` have no relationship with `QueryState::is_dense` and
/// any combination of their values can happen.
///
/// This is a union instead of an enum as the usage is determined at compile time, as all [`StorageId`]s for
/// a [`QueryState`] will be all [`TableId`]s or all [`ArchetypeId`]s, and not a mixture of both. This
@ -68,6 +71,9 @@ pub struct QueryState<D: QueryData, F: QueryFilter = ()> {
pub(crate) component_access: FilteredAccess<ComponentId>,
// NOTE: we maintain both a bitset and a vec because iterating the vec is faster
pub(super) matched_storage_ids: Vec<StorageId>,
// Represents whether this query iteration is dense or not. When this is true
// `matched_storage_ids` stores `TableId`s, otherwise it stores `ArchetypeId`s.
pub(super) is_dense: bool,
pub(crate) fetch_state: D::State,
pub(crate) filter_state: F::State,
#[cfg(feature = "trace")]
@ -194,10 +200,15 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
// properly considered in a global "cross-query" context (both within systems and across systems).
component_access.extend(&filter_component_access);
// For queries without dynamic filters the dense-ness of the query is equal to the dense-ness
// of its static type parameters.
let is_dense = D::IS_DENSE && F::IS_DENSE;
Self {
world_id: world.id(),
archetype_generation: ArchetypeGeneration::initial(),
matched_storage_ids: Vec::new(),
is_dense,
fetch_state,
filter_state,
component_access,
@ -222,6 +233,8 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
world_id: builder.world().id(),
archetype_generation: ArchetypeGeneration::initial(),
matched_storage_ids: Vec::new(),
// For dynamic queries the dense-ness is given by the query builder.
is_dense: builder.is_dense(),
fetch_state,
filter_state,
component_access: builder.access().clone(),
@ -450,7 +463,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
let archetype_index = archetype.id().index();
if !self.matched_archetypes.contains(archetype_index) {
self.matched_archetypes.grow_and_insert(archetype_index);
if !D::IS_DENSE || !F::IS_DENSE {
if !self.is_dense {
self.matched_storage_ids.push(StorageId {
archetype_id: archetype.id(),
});
@ -459,7 +472,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
let table_index = archetype.table_id().as_usize();
if !self.matched_tables.contains(table_index) {
self.matched_tables.grow_and_insert(table_index);
if D::IS_DENSE && F::IS_DENSE {
if self.is_dense {
self.matched_storage_ids.push(StorageId {
table_id: archetype.table_id(),
});
@ -560,6 +573,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
world_id: self.world_id,
archetype_generation: self.archetype_generation,
matched_storage_ids: self.matched_storage_ids.clone(),
is_dense: self.is_dense,
fetch_state,
filter_state,
component_access: self.component_access.clone(),
@ -653,12 +667,15 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
warn!("You have tried to join queries with different archetype_generations. This could lead to unpredictable results.");
}
// the join is dense of both the queries were dense.
let is_dense = self.is_dense && other.is_dense;
// take the intersection of the matched ids
let mut matched_tables = self.matched_tables.clone();
let mut matched_archetypes = self.matched_archetypes.clone();
matched_tables.intersect_with(&other.matched_tables);
matched_archetypes.intersect_with(&other.matched_archetypes);
let matched_storage_ids = if NewD::IS_DENSE && NewF::IS_DENSE {
let matched_storage_ids = if is_dense {
matched_tables
.ones()
.map(|id| StorageId {
@ -678,6 +695,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
world_id: self.world_id,
archetype_generation: self.archetype_generation,
matched_storage_ids,
is_dense,
fetch_state: new_fetch_state,
filter_state: new_filter_state,
component_access: joined_component_access,
@ -1487,7 +1505,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
let mut iter = self.iter_unchecked_manual(world, last_run, this_run);
let mut accum = init_accum();
for storage_id in queue {
if D::IS_DENSE && F::IS_DENSE {
if self.is_dense {
let id = storage_id.table_id;
let table = &world.storages().tables.get(id).debug_checked_unwrap();
accum = iter.fold_over_table_range(
@ -1521,7 +1539,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
#[cfg(feature = "trace")]
let _span = self.par_iter_span.enter();
let accum = init_accum();
if D::IS_DENSE && F::IS_DENSE {
if self.is_dense {
let id = storage_id.table_id;
let table = world.storages().tables.get(id).debug_checked_unwrap();
self.iter_unchecked_manual(world, last_run, this_run)
@ -1537,7 +1555,7 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
};
let storage_entity_count = |storage_id: StorageId| -> usize {
if D::IS_DENSE && F::IS_DENSE {
if self.is_dense {
tables[storage_id.table_id].entity_count()
} else {
archetypes[storage_id.archetype_id].len()
@ -2042,6 +2060,53 @@ mod tests {
world.query::<(&A, &B)>().transmute::<&B>(&world2);
}
/// Regression test for issue #14528
#[test]
fn transmute_from_sparse_to_dense() {
#[derive(Component)]
struct Dense;
#[derive(Component)]
#[component(storage = "SparseSet")]
struct Sparse;
let mut world = World::new();
world.spawn(Dense);
world.spawn((Dense, Sparse));
let mut query = world
.query_filtered::<&Dense, With<Sparse>>()
.transmute::<&Dense>(&world);
let matched = query.iter(&world).count();
assert_eq!(matched, 1);
}
#[test]
fn transmute_from_dense_to_sparse() {
#[derive(Component)]
struct Dense;
#[derive(Component)]
#[component(storage = "SparseSet")]
struct Sparse;
let mut world = World::new();
world.spawn(Dense);
world.spawn((Dense, Sparse));
let mut query = world
.query::<&Dense>()
.transmute_filtered::<&Dense, With<Sparse>>(&world);
// Note: `transmute_filtered` is supposed to keep the same matched tables/archetypes,
// so it doesn't actually filter out those entities without `Sparse` and the iteration
// remains dense.
let matched = query.iter(&world).count();
assert_eq!(matched, 2);
}
#[test]
fn join() {
let mut world = World::new();