Simplify parallel iteration methods (#8854)

# Objective

The `QueryParIter::for_each_mut` function is required when doing
parallel iteration with mutable queries.
This results in an unfortunate stutter:
`query.par_iter_mut().par_for_each_mut()` ('mut' is repeated).

## Solution

- Make `for_each` compatible with mutable queries, and deprecate
`for_each_mut`. In order to prevent `for_each` from being called
multiple times in parallel, we take ownership of the QueryParIter.

---

## Changelog

- `QueryParIter::for_each` is now compatible with mutable queries.
`for_each_mut` has been deprecated as it is now redundant.

## Migration Guide

The method `QueryParIter::for_each_mut` has been deprecated and is no
longer functional. Use `for_each` instead, which now supports mutable
queries.

```rust
// Before:
query.par_iter_mut().for_each_mut(|x| ...);

// After:
query.par_iter_mut().for_each(|x| ...);
```

The method `QueryParIter::for_each` now takes ownership of the
`QueryParIter`, rather than taking a shared reference.

```rust
// Before:
let par_iter = my_query.par_iter().batching_strategy(my_batching_strategy);
par_iter.for_each(|x| {
    // ...Do stuff with x...
    par_iter.for_each(|y| {
        // ...Do nested stuff with y...
    });
});

// After:
my_query.par_iter().batching_strategy(my_batching_strategy).for_each(|x| {
    // ...Do stuff with x...
    my_query.par_iter().batching_strategy(my_batching_strategy).for_each(|y| {
        // ...Do nested stuff with y...
    });
});
```
This commit is contained in:
Joseph 2023-07-23 07:09:24 -04:00 committed by GitHub
parent 5b0e6a5321
commit ddbfa48711
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 57 additions and 66 deletions

View file

@ -34,7 +34,7 @@ pub fn heavy_compute(c: &mut Criterion) {
})); }));
fn sys(mut query: Query<(&mut Position, &mut Transform)>) { fn sys(mut query: Query<(&mut Position, &mut Transform)>) {
query.par_iter_mut().for_each_mut(|(mut pos, mut mat)| { query.par_iter_mut().for_each(|(mut pos, mut mat)| {
for _ in 0..100 { for _ in 0..100 {
mat.0 = mat.0.inverse(); mat.0 = mat.0.inverse();
} }

View file

@ -364,7 +364,7 @@ pub fn animation_player(
) { ) {
animation_players animation_players
.par_iter_mut() .par_iter_mut()
.for_each_mut(|(root, maybe_parent, mut player)| { .for_each(|(root, maybe_parent, mut player)| {
update_transitions(&mut player, &time); update_transitions(&mut player, &time);
run_animation_player( run_animation_player(
root, root,

View file

@ -778,7 +778,7 @@ mod tests {
world.spawn((A(1), B(1))); world.spawn((A(1), B(1)));
fn propagate_system(mut query: Query<(&A, &mut B), Changed<A>>) { fn propagate_system(mut query: Query<(&A, &mut B), Changed<A>>) {
query.par_iter_mut().for_each_mut(|(a, mut b)| { query.par_iter_mut().for_each(|(a, mut b)| {
b.0 = a.0; b.0 = a.0;
}); });
} }

View file

@ -1,7 +1,7 @@
use crate::{component::Tick, world::unsafe_world_cell::UnsafeWorldCell}; use crate::{component::Tick, world::unsafe_world_cell::UnsafeWorldCell};
use std::ops::Range; use std::ops::Range;
use super::{QueryItem, QueryState, ROQueryItem, ReadOnlyWorldQuery, WorldQuery}; use super::{QueryItem, QueryState, ReadOnlyWorldQuery, WorldQuery};
/// Dictates how a parallel query chunks up large tables/archetypes /// Dictates how a parallel query chunks up large tables/archetypes
/// during iteration. /// during iteration.
@ -90,26 +90,6 @@ pub struct QueryParIter<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> {
pub(crate) batching_strategy: BatchingStrategy, pub(crate) batching_strategy: BatchingStrategy,
} }
impl<'w, 's, Q: ReadOnlyWorldQuery, F: ReadOnlyWorldQuery> QueryParIter<'w, 's, Q, F> {
/// Runs `func` on each query result in parallel.
///
/// This can only be called for read-only queries, see [`Self::for_each_mut`] for
/// write-queries.
///
/// # Panics
/// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being
/// initialized and run from the ECS scheduler, this should never panic.
///
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
#[inline]
pub fn for_each<FN: Fn(ROQueryItem<'w, Q>) + Send + Sync + Clone>(&self, func: FN) {
// SAFETY: query is read only
unsafe {
self.for_each_unchecked(func);
}
}
}
impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> QueryParIter<'w, 's, Q, F> { impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> QueryParIter<'w, 's, Q, F> {
/// Changes the batching strategy used when iterating. /// Changes the batching strategy used when iterating.
/// ///
@ -123,61 +103,72 @@ impl<'w, 's, Q: WorldQuery, F: ReadOnlyWorldQuery> QueryParIter<'w, 's, Q, F> {
/// Runs `func` on each query result in parallel. /// Runs `func` on each query result in parallel.
/// ///
/// # Panics /// # Panics
/// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being /// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being
/// initialized and run from the ECS scheduler, this should never panic. /// initialized and run from the ECS scheduler, this should never panic.
/// ///
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool /// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
#[inline] #[inline]
pub fn for_each_mut<FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>(&mut self, func: FN) { pub fn for_each<FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>(self, func: FN) {
// SAFETY: query has unique world access
unsafe {
self.for_each_unchecked(func);
}
}
/// Runs `func` on each query result in parallel.
///
/// # Panics
/// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being
/// initialized and run from the ECS scheduler, this should never panic.
///
/// # Safety
///
/// This does not check for mutable query correctness. To be safe, make sure mutable queries
/// have unique access to the components they query.
///
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
#[inline]
unsafe fn for_each_unchecked<FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>(&self, func: FN) {
#[cfg(any(target = "wasm32", not(feature = "multi-threaded")))] #[cfg(any(target = "wasm32", not(feature = "multi-threaded")))]
{ {
self.state // SAFETY:
.for_each_unchecked_manual(self.world, func, self.last_run, self.this_run); // This method can only be called once per instance of QueryParIter,
} // which ensures that mutable queries cannot be executed multiple times at once.
#[cfg(all(not(target = "wasm32"), feature = "multi-threaded"))] // Mutable instances of QueryParIter can only be created via an exclusive borrow of a
{ // Query or a World, which ensures that multiple aliasing QueryParIters cannot exist
let thread_count = bevy_tasks::ComputeTaskPool::get().thread_num(); // at the same time.
if thread_count <= 1 { unsafe {
self.state.for_each_unchecked_manual( self.state.for_each_unchecked_manual(
self.world, self.world,
func, func,
self.last_run, self.last_run,
self.this_run, self.this_run,
); );
}
}
#[cfg(all(not(target = "wasm32"), feature = "multi-threaded"))]
{
let thread_count = bevy_tasks::ComputeTaskPool::get().thread_num();
if thread_count <= 1 {
// SAFETY: See the safety comment above.
unsafe {
self.state.for_each_unchecked_manual(
self.world,
func,
self.last_run,
self.this_run,
);
}
} else { } else {
// Need a batch size of at least 1. // Need a batch size of at least 1.
let batch_size = self.get_batch_size(thread_count).max(1); let batch_size = self.get_batch_size(thread_count).max(1);
self.state.par_for_each_unchecked_manual( // SAFETY: See the safety comment above.
self.world, unsafe {
batch_size, self.state.par_for_each_unchecked_manual(
func, self.world,
self.last_run, batch_size,
self.this_run, func,
); self.last_run,
self.this_run,
);
}
} }
} }
} }
/// Runs `func` on each query result in parallel.
///
/// # Panics
/// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being
/// initialized and run from the ECS scheduler, this should never panic.
///
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
#[inline]
#[deprecated = "use `.for_each(...)` instead."]
pub fn for_each_mut<FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>(self, func: FN) {
self.for_each(func);
}
#[cfg(all(not(target = "wasm32"), feature = "multi-threaded"))] #[cfg(all(not(target = "wasm32"), feature = "multi-threaded"))]
fn get_batch_size(&self, thread_count: usize) -> usize { fn get_batch_size(&self, thread_count: usize) -> usize {
if self.batching_strategy.batch_size_limits.is_empty() { if self.batching_strategy.batch_size_limits.is_empty() {

View file

@ -367,7 +367,7 @@ pub fn check_visibility(
let view_mask = maybe_view_mask.copied().unwrap_or_default(); let view_mask = maybe_view_mask.copied().unwrap_or_default();
visible_entities.entities.clear(); visible_entities.entities.clear();
visible_aabb_query.par_iter_mut().for_each_mut( visible_aabb_query.par_iter_mut().for_each(
|( |(
entity, entity,
mut computed_visibility, mut computed_visibility,
@ -412,7 +412,7 @@ pub fn check_visibility(
}, },
); );
visible_no_aabb_query.par_iter_mut().for_each_mut( visible_no_aabb_query.par_iter_mut().for_each(
|(entity, mut computed_visibility, maybe_entity_mask)| { |(entity, mut computed_visibility, maybe_entity_mask)| {
// skip computing visibility for entities that are configured to be hidden. is_visible_in_view has already been set to false // skip computing visibility for entities that are configured to be hidden. is_visible_in_view has already been set to false
// in visibility_propagate_system // in visibility_propagate_system

View file

@ -29,7 +29,7 @@ pub fn sync_simple_transforms(
query query
.p0() .p0()
.par_iter_mut() .par_iter_mut()
.for_each_mut(|(transform, mut global_transform)| { .for_each(|(transform, mut global_transform)| {
*global_transform = GlobalTransform::from(*transform); *global_transform = GlobalTransform::from(*transform);
}); });
// Update orphaned entities. // Update orphaned entities.
@ -59,7 +59,7 @@ pub fn propagate_transforms(
orphaned_entities.clear(); orphaned_entities.clear();
orphaned_entities.extend(orphaned.iter()); orphaned_entities.extend(orphaned.iter());
orphaned_entities.sort_unstable(); orphaned_entities.sort_unstable();
root_query.par_iter_mut().for_each_mut( root_query.par_iter_mut().for_each(
|(entity, children, transform, mut global_transform)| { |(entity, children, transform, mut global_transform)| {
let changed = transform.is_changed() || global_transform.is_added() || orphaned_entities.binary_search(&entity).is_ok(); let changed = transform.is_changed() || global_transform.is_added() || orphaned_entities.binary_search(&entity).is_ok();
if changed { if changed {

View file

@ -34,7 +34,7 @@ fn move_system(mut sprites: Query<(&mut Transform, &Velocity)>) {
// to use or not use ParallelIterator over a normal Iterator. // to use or not use ParallelIterator over a normal Iterator.
sprites sprites
.par_iter_mut() .par_iter_mut()
.for_each_mut(|(mut transform, velocity)| { .for_each(|(mut transform, velocity)| {
transform.translation += velocity.extend(0.0); transform.translation += velocity.extend(0.0);
}); });
} }
@ -54,7 +54,7 @@ fn bounce_system(windows: Query<&Window>, mut sprites: Query<(&Transform, &mut V
sprites sprites
.par_iter_mut() .par_iter_mut()
.batching_strategy(BatchingStrategy::fixed(32)) .batching_strategy(BatchingStrategy::fixed(32))
.for_each_mut(|(transform, mut v)| { .for_each(|(transform, mut v)| {
if !(left < transform.translation.x if !(left < transform.translation.x
&& transform.translation.x < right && transform.translation.x < right
&& bottom < transform.translation.y && bottom < transform.translation.y