Improve par_iter and Parallel (#12904)

# Objective

- Bevy usually uses `Parallel::scope` to collect items from `par_iter`, but `scope` is called once for every matching item, and each call performs a thread-local lookup. This causes a lot of unnecessary lookups; the old pattern is sketched below.
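For context, here is a minimal sketch of that old pattern, assuming a hypothetical `collect_old` system over a `Query<Entity>` (the shape mirrors call sites like `check_visibility` in the diff below):

```rust
use bevy_ecs::prelude::*;
use bevy_utils::Parallel;

// Hypothetical system illustrating the old per-item pattern.
fn collect_old(query: Query<Entity>, thread_queues: Local<Parallel<Vec<Entity>>>) {
    query.par_iter().for_each(|entity| {
        // `scope` looks up (or lazily initializes) the thread-local queue
        // for every single matching item.
        thread_queues.scope(|queue| queue.push(entity));
    });
}
```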

## Solution

- Similar to Rayon, we introduce `for_each_init` for `par_iter`, whose `init` closure is invoked only once per task spawned for a group of items (see the sketch below).
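A minimal sketch of the new pattern under the same assumptions (the `collect_new` name is illustrative; the calls mirror the doc example and call sites in this diff). The thread-local queue is borrowed once per spawned task via the new `Parallel::borrow_local_mut` and reused for that task's whole batch:

```rust
use bevy_ecs::prelude::*;
use bevy_utils::Parallel;

// Hypothetical system illustrating the new per-task pattern.
fn collect_new(query: Query<Entity>, mut thread_queues: Local<Parallel<Vec<Entity>>>) {
    query.par_iter().for_each_init(
        // Invoked once per spawned task (one batch of items), not once per item.
        || thread_queues.borrow_local_mut(),
        |queue, entity| queue.push(entity),
    );

    // Merge the per-thread results afterwards.
    let _entities: Vec<Entity> = thread_queues
        .iter_mut()
        .flat_map(|queue| queue.drain(..))
        .collect();
}
```

Note that `init` may run multiple times per thread and its values may be discarded between tasks, so this is not a parallel `fold`.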

---

## Changelog

- Added `for_each_init` to `QueryParIter`.

## Performance
`check_visibility` in `many_foxes`

![image](https://github.com/bevyengine/bevy/assets/45868716/030c41cf-0d2f-4a36-a071-35097d93e494)
 
~40% performance gain in `check_visibility`.

---------

Co-authored-by: James Liu <contact@jamessliu.com>
re0312 2024-04-23 20:05:34 +08:00 committed by GitHub
parent b1ab036329
commit 0f27500e46
6 changed files with 152 additions and 138 deletions

crates/bevy_ecs/src/query/iter.rs

@@ -41,54 +41,6 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIter<'w, 's, D, F> {
}
}
/// Executes the equivalent of [`Iterator::for_each`] over a contiguous segment
/// from a table.
///
/// # Safety
/// - all `rows` must be in `[0, table.entity_count)`.
/// - `table` must match D and F
/// - Both `D::IS_DENSE` and `F::IS_DENSE` must be true.
#[inline]
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
pub(super) unsafe fn for_each_in_table_range<Func>(
&mut self,
func: &mut Func,
table: &'w Table,
rows: Range<usize>,
) where
Func: FnMut(D::Item<'w>),
{
// SAFETY: Caller assures that D::IS_DENSE and F::IS_DENSE are true, that table matches D and F
// and all indices in rows are in range.
unsafe {
self.fold_over_table_range((), &mut |_, item| func(item), table, rows);
}
}
/// Executes the equivalent of [`Iterator::for_each`] over a contiguous segment
/// from an archetype.
///
/// # Safety
/// - all `indices` must be in `[0, archetype.len())`.
/// - `archetype` must match D and F
/// - Either `D::IS_DENSE` or `F::IS_DENSE` must be false.
#[inline]
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
pub(super) unsafe fn for_each_in_archetype_range<Func>(
&mut self,
func: &mut Func,
archetype: &'w Archetype,
rows: Range<usize>,
) where
Func: FnMut(D::Item<'w>),
{
// SAFETY: Caller assures that either D::IS_DENSE or F::IS_DENSE are false, that archetype matches D and F
// and all indices in rows are in range.
unsafe {
self.fold_over_archetype_range((), &mut |_, item| func(item), archetype, rows);
}
}
/// Executes the equivalent of [`Iterator::fold`] over a contiguous segment
/// from a table.
///
@@ -752,7 +704,7 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryIterationCursor<'w, 's, D, F> {
}
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::par_for_each_unchecked_manual
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::par_fold_init_unchecked_manual
/// # Safety
/// `tables` and `archetypes` must belong to the same world that the [`QueryIterationCursor`]
/// was initialized for.

crates/bevy_ecs/src/query/par_iter.rs

@@ -35,8 +35,52 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> {
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
#[inline]
pub fn for_each<FN: Fn(QueryItem<'w, D>) + Send + Sync + Clone>(self, func: FN) {
self.for_each_init(|| {}, |_, item| func(item));
}
/// Runs `func` on each query result in parallel on a value returned by `init`.
///
/// `init` may be called multiple times per thread, and the values returned may be discarded between tasks on any given thread.
/// Callers should avoid using this function as if it were a parallel version
/// of [`Iterator::fold`].
///
/// # Example
///
/// ```
/// use bevy_utils::Parallel;
/// use crate::{bevy_ecs::prelude::Component, bevy_ecs::system::Query};
/// #[derive(Component)]
/// struct T;
/// fn system(query: Query<&T>){
/// let mut queue: Parallel<usize> = Parallel::default();
/// // `queue.borrow_local_mut()` will get or create a thread-local queue for each task/thread.
/// query.par_iter().for_each_init(|| queue.borrow_local_mut(), |local_queue, item| {
/// **local_queue += 1;
/// });
///
/// // collect the values from every thread
/// let entity_count: usize = queue.iter_mut().map(|v| *v).sum();
/// }
/// ```
///
/// # Panics
/// If the [`ComputeTaskPool`] is not initialized. If using this from a query that is being
/// initialized and run from the ECS scheduler, this should never panic.
///
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
#[inline]
pub fn for_each_init<FN, INIT, T>(self, init: INIT, func: FN)
where
FN: Fn(&mut T, QueryItem<'w, D>) + Send + Sync + Clone,
INIT: Fn() -> T + Sync + Send + Clone,
{
let func = |mut init, item| {
func(&mut init, item);
init
};
#[cfg(any(target_arch = "wasm32", not(feature = "multi-threaded")))]
{
let init = init();
// SAFETY:
// This method can only be called once per instance of QueryParIter,
// which ensures that mutable queries cannot be executed multiple times at once.
@@ -46,25 +90,27 @@ impl<'w, 's, D: QueryData, F: QueryFilter> QueryParIter<'w, 's, D, F> {
unsafe {
self.state
.iter_unchecked_manual(self.world, self.last_run, self.this_run)
.for_each(func);
.fold(init, func);
}
}
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
{
let thread_count = bevy_tasks::ComputeTaskPool::get().thread_num();
if thread_count <= 1 {
let init = init();
// SAFETY: See the safety comment above.
unsafe {
self.state
.iter_unchecked_manual(self.world, self.last_run, self.this_run)
.for_each(func);
.fold(init, func);
}
} else {
// Need a batch size of at least 1.
let batch_size = self.get_batch_size(thread_count).max(1);
// SAFETY: See the safety comment above.
unsafe {
self.state.par_for_each_unchecked_manual(
self.state.par_fold_init_unchecked_manual(
init,
self.world,
batch_size,
func,

crates/bevy_ecs/src/query/state.rs

@@ -1394,19 +1394,20 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
///
/// [`ComputeTaskPool`]: bevy_tasks::ComputeTaskPool
#[cfg(all(not(target_arch = "wasm32"), feature = "multi-threaded"))]
pub(crate) unsafe fn par_for_each_unchecked_manual<
'w,
FN: Fn(D::Item<'w>) + Send + Sync + Clone,
>(
pub(crate) unsafe fn par_fold_init_unchecked_manual<'w, T, FN, INIT>(
&self,
init_accum: INIT,
world: UnsafeWorldCell<'w>,
batch_size: usize,
func: FN,
last_run: Tick,
this_run: Tick,
) {
) where
FN: Fn(T, D::Item<'w>) -> T + Send + Sync + Clone,
INIT: Fn() -> T + Sync + Send + Clone,
{
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
// QueryIter, QueryIterationCursor, QueryManyIter, QueryCombinationIter, QueryState::par_fold_init_unchecked_manual
use arrayvec::ArrayVec;
bevy_tasks::ComputeTaskPool::get().scope(|scope| {
@@ -1423,19 +1424,27 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
}
let queue = std::mem::take(queue);
let mut func = func.clone();
let init_accum = init_accum.clone();
scope.spawn(async move {
#[cfg(feature = "trace")]
let _span = self.par_iter_span.enter();
let mut iter = self.iter_unchecked_manual(world, last_run, this_run);
let mut accum = init_accum();
for storage_id in queue {
if D::IS_DENSE && F::IS_DENSE {
let id = storage_id.table_id;
let table = &world.storages().tables.get(id).debug_checked_unwrap();
iter.for_each_in_table_range(&mut func, table, 0..table.entity_count());
accum = iter.fold_over_table_range(
accum,
&mut func,
table,
0..table.entity_count(),
);
} else {
let id = storage_id.archetype_id;
let archetype = world.archetypes().get(id).debug_checked_unwrap();
iter.for_each_in_archetype_range(
accum = iter.fold_over_archetype_range(
accum,
&mut func,
archetype,
0..archetype.len(),
@@ -1449,21 +1458,23 @@ impl<D: QueryData, F: QueryFilter> QueryState<D, F> {
let submit_single = |count, storage_id: StorageId| {
for offset in (0..count).step_by(batch_size) {
let mut func = func.clone();
let init_accum = init_accum.clone();
let len = batch_size.min(count - offset);
let batch = offset..offset + len;
scope.spawn(async move {
#[cfg(feature = "trace")]
let _span = self.par_iter_span.enter();
let accum = init_accum();
if D::IS_DENSE && F::IS_DENSE {
let id = storage_id.table_id;
let table = world.storages().tables.get(id).debug_checked_unwrap();
self.iter_unchecked_manual(world, last_run, this_run)
.for_each_in_table_range(&mut func, table, batch);
.fold_over_table_range(accum, &mut func, table, batch);
} else {
let id = storage_id.archetype_id;
let archetype = world.archetypes().get(id).debug_checked_unwrap();
self.iter_unchecked_manual(world, last_run, this_run)
.for_each_in_archetype_range(&mut func, archetype, batch);
.fold_over_archetype_range(accum, &mut func, archetype, batch);
}
});
}

crates/bevy_pbr/src/render/mesh.rs

@@ -595,8 +595,10 @@ pub fn extract_meshes_for_cpu_building(
)>,
>,
) {
meshes_query.par_iter().for_each(
|(
meshes_query.par_iter().for_each_init(
|| render_mesh_instance_queues.borrow_local_mut(),
|queue,
(
entity,
view_visibility,
transform,
@@ -621,23 +623,19 @@ pub fn extract_meshes_for_cpu_building(
no_automatic_batching,
);
render_mesh_instance_queues.scope(|queue| {
let transform = transform.affine();
queue.push((
entity,
RenderMeshInstanceCpu {
transforms: MeshTransforms {
transform: (&transform).into(),
previous_transform: (&previous_transform
.map(|t| t.0)
.unwrap_or(transform))
.into(),
flags: mesh_flags.bits(),
},
shared,
let transform = transform.affine();
queue.push((
entity,
RenderMeshInstanceCpu {
transforms: MeshTransforms {
transform: (&transform).into(),
previous_transform: (&previous_transform.map(|t| t.0).unwrap_or(transform))
.into(),
flags: mesh_flags.bits(),
},
));
});
shared,
},
));
},
);
@@ -683,8 +681,10 @@ pub fn extract_meshes_for_gpu_building(
)>,
>,
) {
meshes_query.par_iter().for_each(
|(
meshes_query.par_iter().for_each_init(
|| render_mesh_instance_queues.borrow_local_mut(),
|queue,
(
entity,
view_visibility,
transform,
@@ -713,17 +713,15 @@ pub fn extract_meshes_for_gpu_building(
let lightmap_uv_rect =
lightmap::pack_lightmap_uv_rect(lightmap.map(|lightmap| lightmap.uv_rect));
render_mesh_instance_queues.scope(|queue| {
queue.push((
entity,
RenderMeshInstanceGpuBuilder {
shared,
transform: (&transform.affine()).into(),
lightmap_uv_rect,
mesh_flags,
},
));
});
queue.push((
entity,
RenderMeshInstanceGpuBuilder {
shared,
transform: (&transform.affine()).into(),
lightmap_uv_rect,
mesh_flags,
},
));
},
);

crates/bevy_render/src/view/visibility/mod.rs

@@ -453,52 +453,53 @@ pub fn check_visibility<QF>(
let view_mask = maybe_view_mask.copied().unwrap_or_default();
visible_aabb_query.par_iter_mut().for_each(|query_item| {
let (
entity,
inherited_visibility,
mut view_visibility,
maybe_entity_mask,
maybe_model_aabb,
transform,
no_frustum_culling,
) = query_item;
visible_aabb_query.par_iter_mut().for_each_init(
|| thread_queues.borrow_local_mut(),
|queue, query_item| {
let (
entity,
inherited_visibility,
mut view_visibility,
maybe_entity_mask,
maybe_model_aabb,
transform,
no_frustum_culling,
) = query_item;
// Skip computing visibility for entities that are configured to be hidden.
// ViewVisibility has already been reset in `reset_view_visibility`.
if !inherited_visibility.get() {
return;
}
// Skip computing visibility for entities that are configured to be hidden.
// ViewVisibility has already been reset in `reset_view_visibility`.
if !inherited_visibility.get() {
return;
}
let entity_mask = maybe_entity_mask.copied().unwrap_or_default();
if !view_mask.intersects(&entity_mask) {
return;
}
let entity_mask = maybe_entity_mask.copied().unwrap_or_default();
if !view_mask.intersects(&entity_mask) {
return;
}
// If we have an aabb, do frustum culling
if !no_frustum_culling {
if let Some(model_aabb) = maybe_model_aabb {
let model = transform.affine();
let model_sphere = Sphere {
center: model.transform_point3a(model_aabb.center),
radius: transform.radius_vec3a(model_aabb.half_extents),
};
// Do quick sphere-based frustum culling
if !frustum.intersects_sphere(&model_sphere, false) {
return;
}
// Do aabb-based frustum culling
if !frustum.intersects_obb(model_aabb, &model, true, false) {
return;
// If we have an aabb, do frustum culling
if !no_frustum_culling {
if let Some(model_aabb) = maybe_model_aabb {
let model = transform.affine();
let model_sphere = Sphere {
center: model.transform_point3a(model_aabb.center),
radius: transform.radius_vec3a(model_aabb.half_extents),
};
// Do quick sphere-based frustum culling
if !frustum.intersects_sphere(&model_sphere, false) {
return;
}
// Do aabb-based frustum culling
if !frustum.intersects_obb(model_aabb, &model, true, false) {
return;
}
}
}
}
view_visibility.set();
thread_queues.scope(|queue| {
view_visibility.set();
queue.push(entity);
});
});
},
);
visible_entities.clear::<QF>();
thread_queues.drain_into(visible_entities.get_mut::<QF>());

crates/bevy_utils/src/parallel_queue.rs

@@ -1,4 +1,4 @@
use core::cell::Cell;
use std::{cell::RefCell, ops::DerefMut};
use thread_local::ThreadLocal;
/// A cohesive set of thread-local values of a given type.
@@ -6,9 +6,10 @@ use thread_local::ThreadLocal;
/// Mutable references can be fetched if `T: Default` via [`Parallel::scope`].
#[derive(Default)]
pub struct Parallel<T: Send> {
locals: ThreadLocal<Cell<T>>,
locals: ThreadLocal<RefCell<T>>,
}
/// A scope guard of a `Parallel`; when this struct is dropped, the value will be written back to its `Parallel`.
impl<T: Send> Parallel<T> {
/// Gets a mutable iterator over all of the per-thread queues.
pub fn iter_mut(&mut self) -> impl Iterator<Item = &'_ mut T> {
@@ -26,12 +27,17 @@ impl<T: Default + Send> Parallel<T> {
///
/// If there is no thread-local value, it will be initialized to its default.
pub fn scope<R>(&self, f: impl FnOnce(&mut T) -> R) -> R {
let cell = self.locals.get_or_default();
let mut value = cell.take();
let ret = f(&mut value);
cell.set(value);
let mut cell = self.locals.get_or_default().borrow_mut();
let ret = f(cell.deref_mut());
ret
}
/// Mutably borrows the thread-local value.
///
/// If there is no thread-local value, it will be initialized to its default.
pub fn borrow_local_mut(&self) -> impl DerefMut<Target = T> + '_ {
self.locals.get_or_default().borrow_mut()
}
}
impl<T, I> Parallel<I>