Fix unsoundness in FilteredEntity{Ref,Mut} various get methods (#13554)

# Objective

- `FilteredEntity{Ref,Mut}` various `get` methods never checked that the
given component was present on the entity, only the access allowed
reading/writing them, which is always the case when it is constructed
from a `EntityRef`/`EntityMut`/`EntityWorldMut` (and I guess can also
happen with queries containing `Option<T>` that get transmuted).
- In those cases the various `get` methods were calling
`debug_checked_unwrap` on `None`s, which is UB when debug assertions are
not enabled;
- The goal is thus to fix this soundness issue.

## Solution

- Don't call `debug_checked_unwrap` on those `None` and instead
`flatten` them.

## Testing

- This PR includes regression tests for each combination of
`FilteredEntityRef`/`FilteredEntityMut` and component
present/not-present. The two tests for the not-present cases fail on
`main` but success with this PR changes.
This commit is contained in:
Giacomo Stevanato 2024-05-28 16:01:23 +02:00 committed by GitHub
parent bc102d41de
commit d98d6d8d00
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -4,7 +4,7 @@ use crate::{
change_detection::MutUntyped,
component::{Component, ComponentId, ComponentTicks, Components, StorageType},
entity::{Entities, Entity, EntityLocation},
query::{Access, DebugCheckedUnwrap},
query::Access,
removal_detection::RemovedComponentEvents,
storage::Storages,
world::{Mut, World},
@ -1824,8 +1824,9 @@ impl<'w> FilteredEntityRef<'w> {
let id = self.entity.world().components().get_id(TypeId::of::<T>())?;
self.access
.has_read(id)
// SAFETY: We have read access so we must have the component
.then(|| unsafe { self.entity.get().debug_checked_unwrap() })
// SAFETY: We have read access
.then(|| unsafe { self.entity.get() })
.flatten()
}
/// Gets access to the component of type `T` for the current entity,
@ -1837,8 +1838,9 @@ impl<'w> FilteredEntityRef<'w> {
let id = self.entity.world().components().get_id(TypeId::of::<T>())?;
self.access
.has_read(id)
// SAFETY: We have read access so we must have the component
.then(|| unsafe { self.entity.get_ref().debug_checked_unwrap() })
// SAFETY: We have read access
.then(|| unsafe { self.entity.get_ref() })
.flatten()
}
/// Retrieves the change ticks for the given component. This can be useful for implementing change
@ -1848,8 +1850,9 @@ impl<'w> FilteredEntityRef<'w> {
let id = self.entity.world().components().get_id(TypeId::of::<T>())?;
self.access
.has_read(id)
// SAFETY: We have read access so we must have the component
.then(|| unsafe { self.entity.get_change_ticks::<T>().debug_checked_unwrap() })
// SAFETY: We have read access
.then(|| unsafe { self.entity.get_change_ticks::<T>() })
.flatten()
}
/// Retrieves the change ticks for the given [`ComponentId`]. This can be useful for implementing change
@ -1860,12 +1863,11 @@ impl<'w> FilteredEntityRef<'w> {
/// compile time.**
#[inline]
pub fn get_change_ticks_by_id(&self, component_id: ComponentId) -> Option<ComponentTicks> {
// SAFETY: We have read access so we must have the component
self.access.has_read(component_id).then(|| unsafe {
self.entity
.get_change_ticks_by_id(component_id)
.debug_checked_unwrap()
})
self.access
.has_read(component_id)
// SAFETY: We have read access
.then(|| unsafe { self.entity.get_change_ticks_by_id(component_id) })
.flatten()
}
/// Gets the component of the given [`ComponentId`] from the entity.
@ -1880,8 +1882,9 @@ impl<'w> FilteredEntityRef<'w> {
pub fn get_by_id(&self, component_id: ComponentId) -> Option<Ptr<'w>> {
self.access
.has_read(component_id)
// SAFETY: We have read access so we must have the component
.then(|| unsafe { self.entity.get_by_id(component_id).debug_checked_unwrap() })
// SAFETY: We have read access
.then(|| unsafe { self.entity.get_by_id(component_id) })
.flatten()
}
}
@ -2094,8 +2097,9 @@ impl<'w> FilteredEntityMut<'w> {
let id = self.entity.world().components().get_id(TypeId::of::<T>())?;
self.access
.has_write(id)
// SAFETY: We have write access so we must have the component
.then(|| unsafe { self.entity.get_mut().debug_checked_unwrap() })
// SAFETY: We have write access
.then(|| unsafe { self.entity.get_mut() })
.flatten()
}
/// Retrieves the change ticks for the given component. This can be useful for implementing change
@ -2139,12 +2143,11 @@ impl<'w> FilteredEntityMut<'w> {
/// which is only valid while the [`FilteredEntityMut`] is alive.
#[inline]
pub fn get_mut_by_id(&mut self, component_id: ComponentId) -> Option<MutUntyped<'_>> {
// SAFETY: We have write access so we must have the component
self.access.has_write(component_id).then(|| unsafe {
self.entity
.get_mut_by_id(component_id)
.debug_checked_unwrap()
})
self.access
.has_write(component_id)
// SAFETY: We have write access
.then(|| unsafe { self.entity.get_mut_by_id(component_id) })
.flatten()
}
}
@ -2416,6 +2419,7 @@ mod tests {
use bevy_ptr::OwningPtr;
use std::panic::AssertUnwindSafe;
use crate::world::{FilteredEntityMut, FilteredEntityRef};
use crate::{self as bevy_ecs, component::ComponentId, prelude::*, system::assert_is_system};
#[test]
@ -2918,4 +2922,64 @@ mod tests {
assert_is_system(incompatible_system);
}
#[test]
fn filtered_entity_ref_normal() {
let mut world = World::new();
let a_id = world.init_component::<A>();
let e: FilteredEntityRef = world.spawn(A).into();
assert!(e.get::<A>().is_some());
assert!(e.get_ref::<A>().is_some());
assert!(e.get_change_ticks::<A>().is_some());
assert!(e.get_by_id(a_id).is_some());
assert!(e.get_change_ticks_by_id(a_id).is_some());
}
#[test]
fn filtered_entity_ref_missing() {
let mut world = World::new();
let a_id = world.init_component::<A>();
let e: FilteredEntityRef = world.spawn(()).into();
assert!(e.get::<A>().is_none());
assert!(e.get_ref::<A>().is_none());
assert!(e.get_change_ticks::<A>().is_none());
assert!(e.get_by_id(a_id).is_none());
assert!(e.get_change_ticks_by_id(a_id).is_none());
}
#[test]
fn filtered_entity_mut_normal() {
let mut world = World::new();
let a_id = world.init_component::<A>();
let mut e: FilteredEntityMut = world.spawn(A).into();
assert!(e.get::<A>().is_some());
assert!(e.get_ref::<A>().is_some());
assert!(e.get_mut::<A>().is_some());
assert!(e.get_change_ticks::<A>().is_some());
assert!(e.get_by_id(a_id).is_some());
assert!(e.get_mut_by_id(a_id).is_some());
assert!(e.get_change_ticks_by_id(a_id).is_some());
}
#[test]
fn filtered_entity_mut_missing() {
let mut world = World::new();
let a_id = world.init_component::<A>();
let mut e: FilteredEntityMut = world.spawn(()).into();
assert!(e.get::<A>().is_none());
assert!(e.get_ref::<A>().is_none());
assert!(e.get_mut::<A>().is_none());
assert!(e.get_change_ticks::<A>().is_none());
assert!(e.get_by_id(a_id).is_none());
assert!(e.get_mut_by_id(a_id).is_none());
assert!(e.get_change_ticks_by_id(a_id).is_none());
}
}