From bafb9a25fbd4adb153a883c845bac66e2714cce7 Mon Sep 17 00:00:00 2001 From: Logic <38597904+LogicFan@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:00:01 -0500 Subject: [PATCH] Support `on_thread_spawn` and `on_thread_destroy` for `TaskPoolPlugin` (#13045) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Objective - Allow to configure `on_thread_spawn` and `on_thread_destroy` when using `TaskPoolPlugin` of bevy. ## Solution - In `TaskPoolThreadAssignmentPolicy`, two options `on_thread_spawn` and `on_thread_destroy` are added, which will be passed to two new methods motioned above when creating corresponding task pool using builder. - Due to lack of debug derive for these two options, manually implement the debug for `TaskPoolThreadAssignmentPolicy`. --- ## Changelog ### Added - `on_thread_spawn` option and `on_thread_destroy` option to the `TaskPoolPlugin`, allow user to customize them as needed. ## Migration Guide - `TaskPooolThreadAssignmentPolicy` now has two additional fields: `on_thread_spawn` and `on_thread_destroy`. Please consider defaulting them to `None`. --------- Co-authored-by: François Mockers Co-authored-by: François Mockers --- crates/bevy_core/src/task_pool_options.rs | 78 ++++++++++++++++--- .../src/single_threaded_task_pool.rs | 10 +++ 2 files changed, 78 insertions(+), 10 deletions(-) diff --git a/crates/bevy_core/src/task_pool_options.rs b/crates/bevy_core/src/task_pool_options.rs index 276902fb49..cdb0418a35 100644 --- a/crates/bevy_core/src/task_pool_options.rs +++ b/crates/bevy_core/src/task_pool_options.rs @@ -1,9 +1,12 @@ use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder}; use bevy_utils::tracing::trace; +use alloc::sync::Arc; +use core::fmt::Debug; + /// Defines a simple way to determine how many threads to use given the number of remaining cores /// and number of total cores -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct TaskPoolThreadAssignmentPolicy { /// Force using at least this many threads pub min_threads: usize, @@ -12,6 +15,22 @@ pub struct TaskPoolThreadAssignmentPolicy { /// Target using this percentage of total cores, clamped by `min_threads` and `max_threads`. It is /// permitted to use 1.0 to try to use all remaining threads pub percent: f32, + /// Callback that is invoked once for every created thread as it starts. + /// This configuration will be ignored under wasm platform. + pub on_thread_spawn: Option>, + /// Callback that is invoked once for every created thread as it terminates + /// This configuration will be ignored under wasm platform. + pub on_thread_destroy: Option>, +} + +impl Debug for TaskPoolThreadAssignmentPolicy { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("TaskPoolThreadAssignmentPolicy") + .field("min_threads", &self.min_threads) + .field("max_threads", &self.max_threads) + .field("percent", &self.percent) + .finish() + } } impl TaskPoolThreadAssignmentPolicy { @@ -61,6 +80,8 @@ impl Default for TaskPoolOptions { min_threads: 1, max_threads: 4, percent: 0.25, + on_thread_spawn: None, + on_thread_destroy: None, }, // Use 25% of cores for async compute, at least 1, no more than 4 @@ -68,6 +89,8 @@ impl Default for TaskPoolOptions { min_threads: 1, max_threads: 4, percent: 0.25, + on_thread_spawn: None, + on_thread_destroy: None, }, // Use all remaining cores for compute (at least 1) @@ -75,6 +98,8 @@ impl Default for TaskPoolOptions { min_threads: 1, max_threads: usize::MAX, percent: 1.0, // This 1.0 here means "whatever is left over" + on_thread_spawn: None, + on_thread_destroy: None, }, } } @@ -108,10 +133,21 @@ impl TaskPoolOptions { remaining_threads = remaining_threads.saturating_sub(io_threads); IoTaskPool::get_or_init(|| { - TaskPoolBuilder::default() + let mut builder = TaskPoolBuilder::default() .num_threads(io_threads) - .thread_name("IO Task Pool".to_string()) - .build() + .thread_name("IO Task Pool".to_string()); + + #[cfg(not(target_arch = "wasm32"))] + { + if let Some(f) = self.io.on_thread_spawn.clone() { + builder = builder.on_thread_spawn(move || f()); + } + if let Some(f) = self.io.on_thread_destroy.clone() { + builder = builder.on_thread_destroy(move || f()); + } + } + + builder.build() }); } @@ -125,10 +161,21 @@ impl TaskPoolOptions { remaining_threads = remaining_threads.saturating_sub(async_compute_threads); AsyncComputeTaskPool::get_or_init(|| { - TaskPoolBuilder::default() + let mut builder = TaskPoolBuilder::default() .num_threads(async_compute_threads) - .thread_name("Async Compute Task Pool".to_string()) - .build() + .thread_name("Async Compute Task Pool".to_string()); + + #[cfg(not(target_arch = "wasm32"))] + { + if let Some(f) = self.async_compute.on_thread_spawn.clone() { + builder = builder.on_thread_spawn(move || f()); + } + if let Some(f) = self.async_compute.on_thread_destroy.clone() { + builder = builder.on_thread_destroy(move || f()); + } + } + + builder.build() }); } @@ -142,10 +189,21 @@ impl TaskPoolOptions { trace!("Compute Threads: {}", compute_threads); ComputeTaskPool::get_or_init(|| { - TaskPoolBuilder::default() + let mut builder = TaskPoolBuilder::default() .num_threads(compute_threads) - .thread_name("Compute Task Pool".to_string()) - .build() + .thread_name("Compute Task Pool".to_string()); + + #[cfg(not(target_arch = "wasm32"))] + { + if let Some(f) = self.compute.on_thread_spawn.clone() { + builder = builder.on_thread_spawn(move || f()); + } + if let Some(f) = self.compute.on_thread_destroy.clone() { + builder = builder.on_thread_destroy(move || f()); + } + } + + builder.build() }); } } diff --git a/crates/bevy_tasks/src/single_threaded_task_pool.rs b/crates/bevy_tasks/src/single_threaded_task_pool.rs index d7f994026a..054d22260e 100644 --- a/crates/bevy_tasks/src/single_threaded_task_pool.rs +++ b/crates/bevy_tasks/src/single_threaded_task_pool.rs @@ -46,6 +46,16 @@ impl TaskPoolBuilder { self } + /// No op on the single threaded task pool + pub fn on_thread_spawn(self, _f: impl Fn() + Send + Sync + 'static) -> Self { + self + } + + /// No op on the single threaded task pool + pub fn on_thread_destroy(self, _f: impl Fn() + Send + Sync + 'static) -> Self { + self + } + /// Creates a new [`TaskPool`] pub fn build(self) -> TaskPool { TaskPool::new_internal()