Remove task_pool parameter from par_for_each(_mut) (#4705)

# Objective
Fixes #3183. Requiring a `&TaskPool` parameter is sort of meaningless if the only correct one is to use the one provided by `Res<ComputeTaskPool>` all the time.

## Solution
Have `QueryState` save a clone of the `ComputeTaskPool` which is used for all `par_for_each` functions.

~~Adds a small overhead of the internal `Arc` clone as a part of the startup, but the ergonomics win should be well worth this hardly-noticable overhead.~~

Updated the docs to note that it will panic the task pool is not present as a resource.

# Future Work
If https://github.com/bevyengine/rfcs/pull/54 is approved, we can replace these resource lookups with a static function call instead to get the `ComputeTaskPool`.

---

## Changelog
Removed: The `task_pool` parameter of `Query(State)::par_for_each(_mut)`. These calls will use the `World`'s `ComputeTaskPool` resource instead.

## Migration Guide
The `task_pool` parameter for `Query(State)::par_for_each(_mut)` has been removed. Remove these parameters from all calls to these functions.

Before:
```rust
fn parallel_system(
   task_pool: Res<ComputeTaskPool>,
   query: Query<&MyComponent>,
) {
   query.par_for_each(&task_pool, 32, |comp| {
        ...
   });
}
```

After:

```rust
fn parallel_system(query: Query<&MyComponent>) {
   query.par_for_each(32, |comp| {
        ...
   });
}
```

If using `Query(State)` outside of a system run by the scheduler, you may need to manually configure and initialize a `ComputeTaskPool` as a resource in the `World`.
This commit is contained in:
James Liu 2022-05-30 16:59:38 +00:00
parent f59ea7e6e8
commit c5e89894f4
5 changed files with 151 additions and 127 deletions

View file

@ -1,5 +1,5 @@
use bevy_ecs::prelude::*;
use bevy_tasks::TaskPool;
use bevy_tasks::{ComputeTaskPool, TaskPool};
use glam::*;
#[derive(Component, Copy, Clone)]
@ -29,8 +29,8 @@ impl Benchmark {
)
}));
fn sys(task_pool: Res<TaskPool>, mut query: Query<(&mut Position, &mut Transform)>) {
query.par_for_each_mut(&task_pool, 128, |(mut pos, mut mat)| {
fn sys(mut query: Query<(&mut Position, &mut Transform)>) {
query.par_for_each_mut(128, |(mut pos, mut mat)| {
for _ in 0..100 {
mat.0 = mat.0.inverse();
}
@ -39,7 +39,7 @@ impl Benchmark {
});
}
world.insert_resource(TaskPool::default());
world.insert_resource(ComputeTaskPool(TaskPool::default()));
let mut system = IntoSystem::into_system(sys);
system.initialize(&mut world);
system.update_archetype_component_access(&world);

View file

@ -59,7 +59,7 @@ mod tests {
query::{Added, ChangeTrackers, Changed, FilteredAccess, With, Without, WorldQuery},
world::{Mut, World},
};
use bevy_tasks::TaskPool;
use bevy_tasks::{ComputeTaskPool, TaskPool};
use std::{
any::TypeId,
sync::{
@ -376,7 +376,7 @@ mod tests {
#[test]
fn par_for_each_dense() {
let mut world = World::new();
let task_pool = TaskPool::default();
world.insert_resource(ComputeTaskPool(TaskPool::default()));
let e1 = world.spawn().insert(A(1)).id();
let e2 = world.spawn().insert(A(2)).id();
let e3 = world.spawn().insert(A(3)).id();
@ -385,7 +385,7 @@ mod tests {
let results = Arc::new(Mutex::new(Vec::new()));
world
.query::<(Entity, &A)>()
.par_for_each(&world, &task_pool, 2, |(e, &A(i))| {
.par_for_each(&world, 2, |(e, &A(i))| {
results.lock().unwrap().push((e, i));
});
results.lock().unwrap().sort();
@ -398,8 +398,7 @@ mod tests {
#[test]
fn par_for_each_sparse() {
let mut world = World::new();
let task_pool = TaskPool::default();
world.insert_resource(ComputeTaskPool(TaskPool::default()));
let e1 = world.spawn().insert(SparseStored(1)).id();
let e2 = world.spawn().insert(SparseStored(2)).id();
let e3 = world.spawn().insert(SparseStored(3)).id();
@ -408,7 +407,6 @@ mod tests {
let results = Arc::new(Mutex::new(Vec::new()));
world.query::<(Entity, &SparseStored)>().par_for_each(
&world,
&task_pool,
2,
|(e, &SparseStored(i))| results.lock().unwrap().push((e, i)),
);

View file

@ -10,17 +10,18 @@ use crate::{
storage::TableId,
world::{World, WorldId},
};
use bevy_tasks::TaskPool;
use bevy_tasks::{ComputeTaskPool, TaskPool};
#[cfg(feature = "trace")]
use bevy_utils::tracing::Instrument;
use fixedbitset::FixedBitSet;
use std::fmt;
use std::{fmt, ops::Deref};
use super::{QueryFetch, QueryItem, ROQueryFetch, ROQueryItem};
/// Provides scoped access to a [`World`] state according to a given [`WorldQuery`] and query filter.
pub struct QueryState<Q: WorldQuery, F: WorldQuery = ()> {
world_id: WorldId,
task_pool: Option<TaskPool>,
pub(crate) archetype_generation: ArchetypeGeneration,
pub(crate) matched_tables: FixedBitSet,
pub(crate) matched_archetypes: FixedBitSet,
@ -61,6 +62,9 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
let mut state = Self {
world_id: world.id(),
task_pool: world
.get_resource::<ComputeTaskPool>()
.map(|task_pool| task_pool.deref().clone()),
archetype_generation: ArchetypeGeneration::initial(),
matched_table_ids: Vec::new(),
matched_archetype_ids: Vec::new(),
@ -689,15 +693,18 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
);
}
/// Runs `func` on each query result in parallel using the given `task_pool`.
/// Runs `func` on each query result in parallel.
///
/// This can only be called for read-only queries, see [`Self::par_for_each_mut`] for
/// write-queries.
///
/// # Panics
/// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query
/// that is being initialized and run from the ECS scheduler, this should never panic.
#[inline]
pub fn par_for_each<'w, FN: Fn(ROQueryItem<'w, Q>) + Send + Sync + Clone>(
&mut self,
world: &'w World,
task_pool: &TaskPool,
batch_size: usize,
func: FN,
) {
@ -706,7 +713,6 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
self.update_archetypes(world);
self.par_for_each_unchecked_manual::<ROQueryFetch<Q>, FN>(
world,
task_pool,
batch_size,
func,
world.last_change_tick(),
@ -715,12 +721,15 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
}
}
/// Runs `func` on each query result in parallel using the given `task_pool`.
/// Runs `func` on each query result in parallel.
///
/// # Panics
/// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query
/// that is being initialized and run from the ECS scheduler, this should never panic.
#[inline]
pub fn par_for_each_mut<'w, FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>(
&mut self,
world: &'w mut World,
task_pool: &TaskPool,
batch_size: usize,
func: FN,
) {
@ -729,7 +738,6 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
self.update_archetypes(world);
self.par_for_each_unchecked_manual::<QueryFetch<Q>, FN>(
world,
task_pool,
batch_size,
func,
world.last_change_tick(),
@ -738,10 +746,14 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
}
}
/// Runs `func` on each query result in parallel using the given `task_pool`.
/// Runs `func` on each query result in parallel.
///
/// This can only be called for read-only queries.
///
/// # Panics
/// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query
/// that is being initialized and run from the ECS scheduler, this should never panic.
///
/// # Safety
///
/// This does not check for mutable query correctness. To be safe, make sure mutable queries
@ -750,14 +762,12 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
pub unsafe fn par_for_each_unchecked<'w, FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>(
&mut self,
world: &'w World,
task_pool: &TaskPool,
batch_size: usize,
func: FN,
) {
self.update_archetypes(world);
self.par_for_each_unchecked_manual::<QueryFetch<Q>, FN>(
world,
task_pool,
batch_size,
func,
world.last_change_tick(),
@ -833,6 +843,10 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
/// the current change tick are given. This is faster than the equivalent
/// iter() method, but cannot be chained like a normal [`Iterator`].
///
/// # Panics
/// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query
/// that is being initialized and run from the ECS scheduler, this should never panic.
///
/// # Safety
///
/// This does not check for mutable query correctness. To be safe, make sure mutable queries
@ -846,7 +860,6 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
>(
&self,
world: &'w World,
task_pool: &TaskPool,
batch_size: usize,
func: FN,
last_change_tick: u32,
@ -854,95 +867,106 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
) {
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
// QueryIter, QueryIterationCursor, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
task_pool.scope(|scope| {
if QF::IS_DENSE && <QueryFetch<'static, F>>::IS_DENSE {
let tables = &world.storages().tables;
for table_id in &self.matched_table_ids {
let table = &tables[*table_id];
let mut offset = 0;
while offset < table.len() {
let func = func.clone();
let len = batch_size.min(table.len() - offset);
let task = async move {
let mut fetch =
QF::init(world, &self.fetch_state, last_change_tick, change_tick);
let mut filter = <QueryFetch<F> as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
);
let tables = &world.storages().tables;
let table = &tables[*table_id];
fetch.set_table(&self.fetch_state, table);
filter.set_table(&self.filter_state, table);
for table_index in offset..offset + len {
if !filter.table_filter_fetch(table_index) {
continue;
self.task_pool
.as_ref()
.expect("Cannot iterate query in parallel. No ComputeTaskPool initialized.")
.scope(|scope| {
if QF::IS_DENSE && <QueryFetch<'static, F>>::IS_DENSE {
let tables = &world.storages().tables;
for table_id in &self.matched_table_ids {
let table = &tables[*table_id];
let mut offset = 0;
while offset < table.len() {
let func = func.clone();
let len = batch_size.min(table.len() - offset);
let task = async move {
let mut fetch = QF::init(
world,
&self.fetch_state,
last_change_tick,
change_tick,
);
let mut filter = <QueryFetch<F> as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
);
let tables = &world.storages().tables;
let table = &tables[*table_id];
fetch.set_table(&self.fetch_state, table);
filter.set_table(&self.filter_state, table);
for table_index in offset..offset + len {
if !filter.table_filter_fetch(table_index) {
continue;
}
let item = fetch.table_fetch(table_index);
func(item);
}
let item = fetch.table_fetch(table_index);
func(item);
}
};
#[cfg(feature = "trace")]
let span = bevy_utils::tracing::info_span!(
"par_for_each",
query = std::any::type_name::<Q>(),
filter = std::any::type_name::<F>(),
count = len,
);
#[cfg(feature = "trace")]
let task = task.instrument(span);
scope.spawn(task);
offset += batch_size;
};
#[cfg(feature = "trace")]
let span = bevy_utils::tracing::info_span!(
"par_for_each",
query = std::any::type_name::<Q>(),
filter = std::any::type_name::<F>(),
count = len,
);
#[cfg(feature = "trace")]
let task = task.instrument(span);
scope.spawn(task);
offset += batch_size;
}
}
} else {
let archetypes = &world.archetypes;
for archetype_id in &self.matched_archetype_ids {
let mut offset = 0;
let archetype = &archetypes[*archetype_id];
while offset < archetype.len() {
let func = func.clone();
let len = batch_size.min(archetype.len() - offset);
let task = async move {
let mut fetch = QF::init(
world,
&self.fetch_state,
last_change_tick,
change_tick,
);
let mut filter = <QueryFetch<F> as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
);
let tables = &world.storages().tables;
let archetype = &world.archetypes[*archetype_id];
fetch.set_archetype(&self.fetch_state, archetype, tables);
filter.set_archetype(&self.filter_state, archetype, tables);
for archetype_index in offset..offset + len {
if !filter.archetype_filter_fetch(archetype_index) {
continue;
}
func(fetch.archetype_fetch(archetype_index));
}
};
#[cfg(feature = "trace")]
let span = bevy_utils::tracing::info_span!(
"par_for_each",
query = std::any::type_name::<Q>(),
filter = std::any::type_name::<F>(),
count = len,
);
#[cfg(feature = "trace")]
let task = task.instrument(span);
scope.spawn(task);
offset += batch_size;
}
}
}
} else {
let archetypes = &world.archetypes;
for archetype_id in &self.matched_archetype_ids {
let mut offset = 0;
let archetype = &archetypes[*archetype_id];
while offset < archetype.len() {
let func = func.clone();
let len = batch_size.min(archetype.len() - offset);
let task = async move {
let mut fetch =
QF::init(world, &self.fetch_state, last_change_tick, change_tick);
let mut filter = <QueryFetch<F> as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
);
let tables = &world.storages().tables;
let archetype = &world.archetypes[*archetype_id];
fetch.set_archetype(&self.fetch_state, archetype, tables);
filter.set_archetype(&self.filter_state, archetype, tables);
for archetype_index in offset..offset + len {
if !filter.archetype_filter_fetch(archetype_index) {
continue;
}
func(fetch.archetype_fetch(archetype_index));
}
};
#[cfg(feature = "trace")]
let span = bevy_utils::tracing::info_span!(
"par_for_each",
query = std::any::type_name::<Q>(),
filter = std::any::type_name::<F>(),
count = len,
);
#[cfg(feature = "trace")]
let task = task.instrument(span);
scope.spawn(task);
offset += batch_size;
}
}
}
});
});
}
/// Returns a single immutable query result when there is exactly one entity matching

View file

@ -7,7 +7,6 @@ use crate::{
},
world::{Mut, World},
};
use bevy_tasks::TaskPool;
use std::{any::TypeId, fmt::Debug};
/// Provides scoped access to components in a [`World`].
@ -493,7 +492,7 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> {
};
}
/// Runs `f` on each query result in parallel using the given [`TaskPool`].
/// Runs `f` on each query result in parallel using the [`World`]'s [`ComputeTaskPool`].
///
/// This can only be called for immutable data, see [`Self::par_for_each_mut`] for
/// mutable access.
@ -502,7 +501,7 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> {
///
/// The items in the query get sorted into batches.
/// Internally, this function spawns a group of futures that each take on a `batch_size` sized section of the items (or less if the division is not perfect).
/// Then, the tasks in the [`TaskPool`] work through these futures.
/// Then, the tasks in the [`ComputeTaskPool`] work through these futures.
///
/// You can use this value to tune between maximum multithreading ability (many small batches) and minimum parallelization overhead (few big batches).
/// Rule of thumb: If the function body is (mostly) computationally expensive but there are not many items, a small batch size (=more batches) may help to even out the load.
@ -510,13 +509,17 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> {
///
/// # Arguments
///
///* `task_pool` - The [`TaskPool`] to use
///* `batch_size` - The number of batches to spawn
///* `f` - The function to run on each item in the query
///
/// # Panics
/// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query
/// that is being initialized and run from the ECS scheduler, this should never panic.
///
/// [`ComputeTaskPool`]: bevy_tasks::prelude::ComputeTaskPool
#[inline]
pub fn par_for_each<'this>(
&'this self,
task_pool: &TaskPool,
batch_size: usize,
f: impl Fn(ROQueryItem<'this, Q>) + Send + Sync + Clone,
) {
@ -526,7 +529,6 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> {
self.state
.par_for_each_unchecked_manual::<ROQueryFetch<Q>, _>(
self.world,
task_pool,
batch_size,
f,
self.last_change_tick,
@ -535,12 +537,17 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> {
};
}
/// Runs `f` on each query result in parallel using the given [`TaskPool`].
/// Runs `f` on each query result in parallel using the [`World`]'s [`ComputeTaskPool`].
/// See [`Self::par_for_each`] for more details.
///
/// # Panics
/// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query
/// that is being initialized and run from the ECS scheduler, this should never panic.
///
/// [`ComputeTaskPool`]: bevy_tasks::prelude::ComputeTaskPool
#[inline]
pub fn par_for_each_mut<'a, FN: Fn(QueryItem<'a, Q>) + Send + Sync + Clone>(
&'a mut self,
task_pool: &TaskPool,
batch_size: usize,
f: FN,
) {
@ -550,7 +557,6 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> {
self.state
.par_for_each_unchecked_manual::<QueryFetch<Q>, FN>(
self.world,
task_pool,
batch_size,
f,
self.last_change_tick,

View file

@ -1,6 +1,6 @@
//! Illustrates parallel queries with `ParallelIterator`.
use bevy::{prelude::*, tasks::prelude::*};
use bevy::prelude::*;
use rand::random;
#[derive(Component, Deref)]
@ -23,26 +23,22 @@ fn spawn_system(mut commands: Commands, asset_server: Res<AssetServer>) {
}
// Move sprites according to their velocity
fn move_system(pool: Res<ComputeTaskPool>, mut sprites: Query<(&mut Transform, &Velocity)>) {
fn move_system(mut sprites: Query<(&mut Transform, &Velocity)>) {
// Compute the new location of each sprite in parallel on the
// ComputeTaskPool using batches of 32 sprites
//
// This example is only for demonstrative purposes. Using a
// This example is only for demonstrative purposes. Using a
// ParallelIterator for an inexpensive operation like addition on only 128
// elements will not typically be faster than just using a normal Iterator.
// See the ParallelIterator documentation for more information on when
// to use or not use ParallelIterator over a normal Iterator.
sprites.par_for_each_mut(&pool, 32, |(mut transform, velocity)| {
sprites.par_for_each_mut(32, |(mut transform, velocity)| {
transform.translation += velocity.extend(0.0);
});
}
// Bounce sprites outside the window
fn bounce_system(
pool: Res<ComputeTaskPool>,
windows: Res<Windows>,
mut sprites: Query<(&Transform, &mut Velocity)>,
) {
fn bounce_system(windows: Res<Windows>, mut sprites: Query<(&Transform, &mut Velocity)>) {
let window = windows.primary();
let width = window.width();
let height = window.height();
@ -53,7 +49,7 @@ fn bounce_system(
sprites
// Batch size of 32 is chosen to limit the overhead of
// ParallelIterator, since negating a vector is very inexpensive.
.par_for_each_mut(&pool, 32, |(transform, mut v)| {
.par_for_each_mut(32, |(transform, mut v)| {
if !(left < transform.translation.x
&& transform.translation.x < right
&& bottom < transform.translation.y