Support on_thread_spawn and on_thread_destroy for TaskPoolPlugin (#13045)

# Objective

- Allow to configure `on_thread_spawn` and `on_thread_destroy` when
using `TaskPoolPlugin` of bevy.

## Solution

- In `TaskPoolThreadAssignmentPolicy`, two options `on_thread_spawn` and
`on_thread_destroy` are added, which will be passed to two new methods
motioned above when creating corresponding task pool using builder.
- Due to lack of debug derive for these two options, manually implement
the debug for `TaskPoolThreadAssignmentPolicy`.

---

## Changelog

### Added
- `on_thread_spawn` option and `on_thread_destroy` option to the
`TaskPoolPlugin`, allow user to customize them as needed.

## Migration Guide

- `TaskPooolThreadAssignmentPolicy` now has two additional fields:
`on_thread_spawn` and `on_thread_destroy`. Please consider defaulting
them to `None`.

---------

Co-authored-by: François Mockers <mockersf@gmail.com>
Co-authored-by: François Mockers <francois.mockers@vleue.com>
This commit is contained in:
Logic 2024-11-11 15:00:01 -05:00 committed by GitHub
parent ef23f465ce
commit bafb9a25fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 78 additions and 10 deletions

View file

@ -1,9 +1,12 @@
use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder}; use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder};
use bevy_utils::tracing::trace; use bevy_utils::tracing::trace;
use alloc::sync::Arc;
use core::fmt::Debug;
/// Defines a simple way to determine how many threads to use given the number of remaining cores /// Defines a simple way to determine how many threads to use given the number of remaining cores
/// and number of total cores /// and number of total cores
#[derive(Clone, Debug)] #[derive(Clone)]
pub struct TaskPoolThreadAssignmentPolicy { pub struct TaskPoolThreadAssignmentPolicy {
/// Force using at least this many threads /// Force using at least this many threads
pub min_threads: usize, pub min_threads: usize,
@ -12,6 +15,22 @@ pub struct TaskPoolThreadAssignmentPolicy {
/// Target using this percentage of total cores, clamped by `min_threads` and `max_threads`. It is /// Target using this percentage of total cores, clamped by `min_threads` and `max_threads`. It is
/// permitted to use 1.0 to try to use all remaining threads /// permitted to use 1.0 to try to use all remaining threads
pub percent: f32, pub percent: f32,
/// Callback that is invoked once for every created thread as it starts.
/// This configuration will be ignored under wasm platform.
pub on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
/// Callback that is invoked once for every created thread as it terminates
/// This configuration will be ignored under wasm platform.
pub on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
}
impl Debug for TaskPoolThreadAssignmentPolicy {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("TaskPoolThreadAssignmentPolicy")
.field("min_threads", &self.min_threads)
.field("max_threads", &self.max_threads)
.field("percent", &self.percent)
.finish()
}
} }
impl TaskPoolThreadAssignmentPolicy { impl TaskPoolThreadAssignmentPolicy {
@ -61,6 +80,8 @@ impl Default for TaskPoolOptions {
min_threads: 1, min_threads: 1,
max_threads: 4, max_threads: 4,
percent: 0.25, percent: 0.25,
on_thread_spawn: None,
on_thread_destroy: None,
}, },
// Use 25% of cores for async compute, at least 1, no more than 4 // Use 25% of cores for async compute, at least 1, no more than 4
@ -68,6 +89,8 @@ impl Default for TaskPoolOptions {
min_threads: 1, min_threads: 1,
max_threads: 4, max_threads: 4,
percent: 0.25, percent: 0.25,
on_thread_spawn: None,
on_thread_destroy: None,
}, },
// Use all remaining cores for compute (at least 1) // Use all remaining cores for compute (at least 1)
@ -75,6 +98,8 @@ impl Default for TaskPoolOptions {
min_threads: 1, min_threads: 1,
max_threads: usize::MAX, max_threads: usize::MAX,
percent: 1.0, // This 1.0 here means "whatever is left over" percent: 1.0, // This 1.0 here means "whatever is left over"
on_thread_spawn: None,
on_thread_destroy: None,
}, },
} }
} }
@ -108,10 +133,21 @@ impl TaskPoolOptions {
remaining_threads = remaining_threads.saturating_sub(io_threads); remaining_threads = remaining_threads.saturating_sub(io_threads);
IoTaskPool::get_or_init(|| { IoTaskPool::get_or_init(|| {
TaskPoolBuilder::default() let mut builder = TaskPoolBuilder::default()
.num_threads(io_threads) .num_threads(io_threads)
.thread_name("IO Task Pool".to_string()) .thread_name("IO Task Pool".to_string());
.build()
#[cfg(not(target_arch = "wasm32"))]
{
if let Some(f) = self.io.on_thread_spawn.clone() {
builder = builder.on_thread_spawn(move || f());
}
if let Some(f) = self.io.on_thread_destroy.clone() {
builder = builder.on_thread_destroy(move || f());
}
}
builder.build()
}); });
} }
@ -125,10 +161,21 @@ impl TaskPoolOptions {
remaining_threads = remaining_threads.saturating_sub(async_compute_threads); remaining_threads = remaining_threads.saturating_sub(async_compute_threads);
AsyncComputeTaskPool::get_or_init(|| { AsyncComputeTaskPool::get_or_init(|| {
TaskPoolBuilder::default() let mut builder = TaskPoolBuilder::default()
.num_threads(async_compute_threads) .num_threads(async_compute_threads)
.thread_name("Async Compute Task Pool".to_string()) .thread_name("Async Compute Task Pool".to_string());
.build()
#[cfg(not(target_arch = "wasm32"))]
{
if let Some(f) = self.async_compute.on_thread_spawn.clone() {
builder = builder.on_thread_spawn(move || f());
}
if let Some(f) = self.async_compute.on_thread_destroy.clone() {
builder = builder.on_thread_destroy(move || f());
}
}
builder.build()
}); });
} }
@ -142,10 +189,21 @@ impl TaskPoolOptions {
trace!("Compute Threads: {}", compute_threads); trace!("Compute Threads: {}", compute_threads);
ComputeTaskPool::get_or_init(|| { ComputeTaskPool::get_or_init(|| {
TaskPoolBuilder::default() let mut builder = TaskPoolBuilder::default()
.num_threads(compute_threads) .num_threads(compute_threads)
.thread_name("Compute Task Pool".to_string()) .thread_name("Compute Task Pool".to_string());
.build()
#[cfg(not(target_arch = "wasm32"))]
{
if let Some(f) = self.compute.on_thread_spawn.clone() {
builder = builder.on_thread_spawn(move || f());
}
if let Some(f) = self.compute.on_thread_destroy.clone() {
builder = builder.on_thread_destroy(move || f());
}
}
builder.build()
}); });
} }
} }

View file

@ -46,6 +46,16 @@ impl TaskPoolBuilder {
self self
} }
/// No op on the single threaded task pool
pub fn on_thread_spawn(self, _f: impl Fn() + Send + Sync + 'static) -> Self {
self
}
/// No op on the single threaded task pool
pub fn on_thread_destroy(self, _f: impl Fn() + Send + Sync + 'static) -> Self {
self
}
/// Creates a new [`TaskPool`] /// Creates a new [`TaskPool`]
pub fn build(self) -> TaskPool { pub fn build(self) -> TaskPool {
TaskPool::new_internal() TaskPool::new_internal()