Improve type inference in DynSystemParam::downcast() by making the type parameter match the return value. (#15103)

# Objective

Right now, `DynSystemParam::downcast()` always requires the type
parameter to be specified with a turbofish. Make it so that it can be
inferred from the use of the return value, like:

```rust
fn expects_res_a(mut param: DynSystemParam) {
    let res: Res<A> = param.downcast().unwrap();
}
```

## Solution

The reason this doesn't currently work is that the type parameter is a
`'static` version of the `SystemParam` so that it can be used with
`Any::downcast_mut()`. Change the method signature so that the type
parameter matches the return type, and use `T::Item<'static, 'static>`
to get the `'static` version. That means we wind up returning a
`T::Item<'static, 'static>::Item<'w, 's>`, so constrain that to be equal
to `T`. That works with every `SystemParam` implementation, since they
have `T::Item == T` up to lifetimes.
This commit is contained in:
Chris Russell 2024-09-16 18:56:57 -04:00 committed by GitHub
parent db525e660e
commit 382917fbb3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1894,7 +1894,9 @@ unsafe impl<T: ?Sized> ReadOnlySystemParam for PhantomData<T> {}
/// assert!(param.is::<Res<A>>()); /// assert!(param.is::<Res<A>>());
/// assert!(!param.is::<Res<B>>()); /// assert!(!param.is::<Res<B>>());
/// assert!(param.downcast_mut::<Res<B>>().is_none()); /// assert!(param.downcast_mut::<Res<B>>().is_none());
/// let foo: Res<A> = param.downcast::<Res<A>>().unwrap(); /// let res = param.downcast_mut::<Res<A>>().unwrap();
/// // The type parameter can be left out if it can be determined from use.
/// let res: Res<A> = param.downcast().unwrap();
/// } /// }
/// ///
/// let system = ( /// let system = (
@ -1942,13 +1944,21 @@ impl<'w, 's> DynSystemParam<'w, 's> {
} }
/// Returns `true` if the inner system param is the same as `T`. /// Returns `true` if the inner system param is the same as `T`.
pub fn is<T: SystemParam + 'static>(&self) -> bool { pub fn is<T: SystemParam>(&self) -> bool
self.state.is::<ParamState<T>>() // See downcast() function for an explanation of the where clause
where
T::Item<'static, 'static>: SystemParam<Item<'w, 's> = T> + 'static,
{
self.state.is::<ParamState<T::Item<'static, 'static>>>()
} }
/// Returns the inner system param if it is the correct type. /// Returns the inner system param if it is the correct type.
/// This consumes the dyn param, so the returned param can have its original world and state lifetimes. /// This consumes the dyn param, so the returned param can have its original world and state lifetimes.
pub fn downcast<T: SystemParam + 'static>(self) -> Option<T::Item<'w, 's>> { pub fn downcast<T: SystemParam>(self) -> Option<T>
// See downcast() function for an explanation of the where clause
where
T::Item<'static, 'static>: SystemParam<Item<'w, 's> = T> + 'static,
{
// SAFETY: // SAFETY:
// - `DynSystemParam::new()` ensures `state` is a `ParamState<T>`, that the world matches, // - `DynSystemParam::new()` ensures `state` is a `ParamState<T>`, that the world matches,
// and that it has access required by the inner system param. // and that it has access required by the inner system param.
@ -1958,7 +1968,11 @@ impl<'w, 's> DynSystemParam<'w, 's> {
/// Returns the inner system parameter if it is the correct type. /// Returns the inner system parameter if it is the correct type.
/// This borrows the dyn param, so the returned param is only valid for the duration of that borrow. /// This borrows the dyn param, so the returned param is only valid for the duration of that borrow.
pub fn downcast_mut<T: SystemParam + 'static>(&mut self) -> Option<T::Item<'_, '_>> { pub fn downcast_mut<'a, T: SystemParam>(&'a mut self) -> Option<T>
// See downcast() function for an explanation of the where clause
where
T::Item<'static, 'static>: SystemParam<Item<'a, 'a> = T> + 'static,
{
// SAFETY: // SAFETY:
// - `DynSystemParam::new()` ensures `state` is a `ParamState<T>`, that the world matches, // - `DynSystemParam::new()` ensures `state` is a `ParamState<T>`, that the world matches,
// and that it has access required by the inner system param. // and that it has access required by the inner system param.
@ -1971,9 +1985,11 @@ impl<'w, 's> DynSystemParam<'w, 's> {
/// but since it only performs read access it can keep the original world lifetime. /// but since it only performs read access it can keep the original world lifetime.
/// This can be useful with methods like [`Query::iter_inner()`] or [`Res::into_inner()`] /// This can be useful with methods like [`Query::iter_inner()`] or [`Res::into_inner()`]
/// to obtain references with the original world lifetime. /// to obtain references with the original world lifetime.
pub fn downcast_mut_inner<T: ReadOnlySystemParam + 'static>( pub fn downcast_mut_inner<'a, T: ReadOnlySystemParam>(&'a mut self) -> Option<T>
&mut self, // See downcast() function for an explanation of the where clause
) -> Option<T::Item<'w, '_>> { where
T::Item<'static, 'static>: SystemParam<Item<'w, 'a> = T> + 'static,
{
// SAFETY: // SAFETY:
// - `DynSystemParam::new()` ensures `state` is a `ParamState<T>`, that the world matches, // - `DynSystemParam::new()` ensures `state` is a `ParamState<T>`, that the world matches,
// and that it has access required by the inner system param. // and that it has access required by the inner system param.
@ -1988,19 +2004,32 @@ impl<'w, 's> DynSystemParam<'w, 's> {
/// in [`init_state`](SystemParam::init_state) for the inner system param. /// in [`init_state`](SystemParam::init_state) for the inner system param.
/// - `world` must be the same `World` that was used to initialize /// - `world` must be the same `World` that was used to initialize
/// [`state`](SystemParam::init_state) for the inner system param. /// [`state`](SystemParam::init_state) for the inner system param.
unsafe fn downcast<'w, 's, T: SystemParam + 'static>( unsafe fn downcast<'w, 's, T: SystemParam>(
state: &'s mut dyn Any, state: &'s mut dyn Any,
system_meta: &SystemMeta, system_meta: &SystemMeta,
world: UnsafeWorldCell<'w>, world: UnsafeWorldCell<'w>,
change_tick: Tick, change_tick: Tick,
) -> Option<T::Item<'w, 's>> { ) -> Option<T>
state.downcast_mut::<ParamState<T>>().map(|state| { // We need a 'static version of the SystemParam to use with `Any::downcast_mut()`,
// SAFETY: // and we need a <'w, 's> version to actually return.
// - The caller ensures the world has access for the underlying system param, // The type parameter T must be the one we return in order to get type inference from the return value.
// and since the downcast succeeded, the underlying system param is T. // So we use `T::Item<'static, 'static>` as the 'static version, and require that it be 'static.
// - The caller ensures the `world` matches. // That means the return value will be T::Item<'static, 'static>::Item<'w, 's>,
unsafe { T::get_param(&mut state.0, system_meta, world, change_tick) } // so we constrain that to be equal to T.
}) // Every actual `SystemParam` implementation has `T::Item == T` up to lifetimes,
// so they should all work with this constraint.
where
T::Item<'static, 'static>: SystemParam<Item<'w, 's> = T> + 'static,
{
state
.downcast_mut::<ParamState<T::Item<'static, 'static>>>()
.map(|state| {
// SAFETY:
// - The caller ensures the world has access for the underlying system param,
// and since the downcast succeeded, the underlying system param is T.
// - The caller ensures the `world` matches.
unsafe { T::Item::get_param(&mut state.0, system_meta, world, change_tick) }
})
} }
/// The [`SystemParam::State`] for a [`DynSystemParam`]. /// The [`SystemParam::State`] for a [`DynSystemParam`].
@ -2323,4 +2352,12 @@ mod tests {
schedule.add_systems((non_send_param_set, non_send_param_set, non_send_param_set)); schedule.add_systems((non_send_param_set, non_send_param_set, non_send_param_set));
schedule.run(&mut world); schedule.run(&mut world);
} }
fn _dyn_system_param_type_inference(mut p: DynSystemParam) {
// Make sure the downcast() methods are able to infer their type parameters from the use of the return type.
// This is just a compilation test, so there is nothing to run.
let _query: Query<()> = p.downcast_mut().unwrap();
let _query: Query<()> = p.downcast_mut_inner().unwrap();
let _query: Query<()> = p.downcast().unwrap();
}
} }