diff --git a/crates/bevy_tasks/src/lib.rs b/crates/bevy_tasks/src/lib.rs
index 802f6c267b..ae9b0a3dbd 100644
--- a/crates/bevy_tasks/src/lib.rs
+++ b/crates/bevy_tasks/src/lib.rs
@@ -22,6 +22,11 @@ mod usages;
 pub use usages::tick_global_task_pools_on_main_thread;
 pub use usages::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool};
 
+#[cfg(not(target_arch = "wasm32"))]
+mod thread_executor;
+#[cfg(not(target_arch = "wasm32"))]
+pub use thread_executor::{ThreadExecutor, ThreadExecutorTicker};
+
 mod iter;
 pub use iter::ParallelIterator;
 
diff --git a/crates/bevy_tasks/src/task_pool.rs b/crates/bevy_tasks/src/task_pool.rs
index 958e35aafd..7bca59e7db 100644
--- a/crates/bevy_tasks/src/task_pool.rs
+++ b/crates/bevy_tasks/src/task_pool.rs
@@ -10,7 +10,7 @@ use async_task::FallibleTask;
 use concurrent_queue::ConcurrentQueue;
 use futures_lite::{future, pin, FutureExt};
 
-use crate::Task;
+use crate::{thread_executor::ThreadExecutor, Task};
 
 struct CallOnDrop(Option<Arc<dyn Fn() + Send + Sync + 'static>>);
 
@@ -108,6 +108,7 @@ pub struct TaskPool {
 impl TaskPool {
     thread_local! {
         static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = async_executor::LocalExecutor::new();
+        static THREAD_EXECUTOR: ThreadExecutor<'static> = ThreadExecutor::new();
     }
 
     /// Create a `TaskPool` with the default configuration.
@@ -271,59 +272,61 @@ impl TaskPool {
         F: for<'scope> FnOnce(&'scope Scope<'scope, 'env, T>),
         T: Send + 'static,
     {
-        // SAFETY: This safety comment applies to all references transmuted to 'env.
-        // Any futures spawned with these references need to return before this function completes.
-        // This is guaranteed because we drive all the futures spawned onto the Scope
-        // to completion in this function. However, rust has no way of knowing this so we
-        // transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety.
-        let executor: &async_executor::Executor = &self.executor;
-        let executor: &'env async_executor::Executor = unsafe { mem::transmute(executor) };
-        let task_scope_executor = &async_executor::Executor::default();
-        let task_scope_executor: &'env async_executor::Executor =
-            unsafe { mem::transmute(task_scope_executor) };
-        let spawned: ConcurrentQueue<FallibleTask<T>> = ConcurrentQueue::unbounded();
-        let spawned_ref: &'env ConcurrentQueue<FallibleTask<T>> =
-            unsafe { mem::transmute(&spawned) };
+        Self::THREAD_EXECUTOR.with(|thread_executor| {
+            // SAFETY: This safety comment applies to all references transmuted to 'env.
+            // Any futures spawned with these references need to return before this function completes.
+            // This is guaranteed because we drive all the futures spawned onto the Scope
+            // to completion in this function. However, rust has no way of knowing this so we
+            // transmute the lifetimes to 'env here to appease the compiler as it is unable to validate safety.
+            let executor: &async_executor::Executor = &self.executor;
+            let executor: &'env async_executor::Executor = unsafe { mem::transmute(executor) };
+            let thread_executor: &'env ThreadExecutor<'env> =
+                unsafe { mem::transmute(thread_executor) };
+            let spawned: ConcurrentQueue<FallibleTask<T>> = ConcurrentQueue::unbounded();
+            let spawned_ref: &'env ConcurrentQueue<FallibleTask<T>> =
+                unsafe { mem::transmute(&spawned) };
 
-        let scope = Scope {
-            executor,
-            task_scope_executor,
-            spawned: spawned_ref,
-            scope: PhantomData,
-            env: PhantomData,
-        };
-
-        let scope_ref: &'env Scope<'_, 'env, T> = unsafe { mem::transmute(&scope) };
-
-        f(scope_ref);
-
-        if spawned.is_empty() {
-            Vec::new()
-        } else {
-            let get_results = async {
-                let mut results = Vec::with_capacity(spawned_ref.len());
-                while let Ok(task) = spawned_ref.pop() {
-                    results.push(task.await.unwrap());
-                }
-
-                results
+            let scope = Scope {
+                executor,
+                thread_executor,
+                spawned: spawned_ref,
+                scope: PhantomData,
+                env: PhantomData,
             };
 
-            // Pin the futures on the stack.
-            pin!(get_results);
+            let scope_ref: &'env Scope<'_, 'env, T> = unsafe { mem::transmute(&scope) };
 
-            loop {
-                if let Some(result) = future::block_on(future::poll_once(&mut get_results)) {
-                    break result;
+            f(scope_ref);
+
+            if spawned.is_empty() {
+                Vec::new()
+            } else {
+                let get_results = async {
+                    let mut results = Vec::with_capacity(spawned_ref.len());
+                    while let Ok(task) = spawned_ref.pop() {
+                        results.push(task.await.unwrap());
+                    }
+
+                    results
                 };
 
-                std::panic::catch_unwind(|| {
-                    executor.try_tick();
-                    task_scope_executor.try_tick();
-                })
-                .ok();
+                // Pin the futures on the stack.
+                pin!(get_results);
+
+                let thread_ticker = thread_executor.ticker().unwrap();
+                loop {
+                    if let Some(result) = future::block_on(future::poll_once(&mut get_results)) {
+                        break result;
+                    };
+
+                    std::panic::catch_unwind(|| {
+                        executor.try_tick();
+                        thread_ticker.try_tick();
+                    })
+                    .ok();
+                }
             }
-        }
+        })
     }
 
     /// Spawns a static future onto the thread pool. The returned Task is a future. It can also be
@@ -395,7 +398,7 @@ impl Drop for TaskPool {
 #[derive(Debug)]
 pub struct Scope<'scope, 'env: 'scope, T> {
     executor: &'scope async_executor::Executor<'scope>,
-    task_scope_executor: &'scope async_executor::Executor<'scope>,
+    thread_executor: &'scope ThreadExecutor<'scope>,
     spawned: &'scope ConcurrentQueue<FallibleTask<T>>,
     // make `Scope` invariant over 'scope and 'env
     scope: PhantomData<&'scope mut &'scope ()>,
@@ -425,7 +428,7 @@ impl<'scope, 'env, T: Send + 'scope> Scope<'scope, 'env, T> {
     ///
     /// For more information, see [`TaskPool::scope`].
     pub fn spawn_on_scope<Fut: Future<Output = T> + 'scope + Send>(&self, f: Fut) {
-        let task = self.task_scope_executor.spawn(f).fallible();
+        let task = self.thread_executor.spawn(f).fallible();
         // ConcurrentQueue only errors when closed or full, but we never
         // close and use an unbounded queue, so it is safe to unwrap
         self.spawned.push(task).unwrap();
diff --git a/crates/bevy_tasks/src/thread_executor.rs b/crates/bevy_tasks/src/thread_executor.rs
new file mode 100644
index 0000000000..0ba66571db
--- /dev/null
+++ b/crates/bevy_tasks/src/thread_executor.rs
@@ -0,0 +1,128 @@
+use std::{
+    marker::PhantomData,
+    thread::{self, ThreadId},
+};
+
+use async_executor::{Executor, Task};
+use futures_lite::Future;
+
+/// An executor that can only be ticked on the thread it was instantiated on, but
+/// can spawn `Send` tasks from other threads.
+///
+/// # Example
+/// ```rust
+/// # use std::sync::{Arc, atomic::{AtomicI32, Ordering}};
+/// use bevy_tasks::ThreadExecutor;
+///
+/// let thread_executor = ThreadExecutor::new();
+/// let count = Arc::new(AtomicI32::new(0));
+///
+/// // create some owned values that can be moved into another thread
+/// let count_clone = count.clone();
+///
+/// std::thread::scope(|scope| {
+///     scope.spawn(|| {
+///         // we cannot get the ticker from another thread
+///         let not_thread_ticker = thread_executor.ticker();
+///         assert!(not_thread_ticker.is_none());
+///
+///         // but we can spawn tasks from another thread
+///         thread_executor.spawn(async move {
+///             count_clone.fetch_add(1, Ordering::Relaxed);
+///         }).detach();
+///     });
+/// });
+///
+/// // the tasks do not make progress unless the executor is manually ticked
+/// assert_eq!(count.load(Ordering::Relaxed), 0);
+///
+/// // tick the ticker until the task finishes
+/// let thread_ticker = thread_executor.ticker().unwrap();
+/// thread_ticker.try_tick();
+/// assert_eq!(count.load(Ordering::Relaxed), 1);
+/// ```
+#[derive(Debug)]
+pub struct ThreadExecutor<'task> {
+    executor: Executor<'task>,
+    thread_id: ThreadId,
+}
+
+impl<'task> Default for ThreadExecutor<'task> {
+    fn default() -> Self {
+        Self {
+            executor: Executor::new(),
+            thread_id: thread::current().id(),
+        }
+    }
+}
+
+impl<'task> ThreadExecutor<'task> {
+    /// Creates a new [`ThreadExecutor`].
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Spawn a task on the thread executor.
+    pub fn spawn<T: Send + 'task>(
+        &self,
+        future: impl Future<Output = T> + Send + 'task,
+    ) -> Task<T> {
+        self.executor.spawn(future)
+    }
+
+    /// Gets the [`ThreadExecutorTicker`] for this executor.
+    /// Use this to tick the executor.
+    /// It only returns the ticker if called on the thread the executor was created on,
+    /// and returns `None` otherwise.
+    pub fn ticker<'ticker>(&'ticker self) -> Option<ThreadExecutorTicker<'task, 'ticker>> {
+        if thread::current().id() == self.thread_id {
+            return Some(ThreadExecutorTicker {
+                executor: &self.executor,
+                _marker: PhantomData::default(),
+            });
+        }
+        None
+    }
+}
+
+/// Used to tick the [`ThreadExecutor`]. The executor does not
+/// make progress unless it is manually ticked on the thread it was
+/// created on.
+#[derive(Debug)]
+pub struct ThreadExecutorTicker<'task, 'ticker> {
+    executor: &'ticker Executor<'task>,
+    // make type not send or sync
+    _marker: PhantomData<*const ()>,
+}
+impl<'task, 'ticker> ThreadExecutorTicker<'task, 'ticker> {
+    /// Tick the thread executor.
+    pub async fn tick(&self) {
+        self.executor.tick().await;
+    }
+
+    /// Synchronously try to tick a task on the executor.
+    /// Returns `false` if it does not find a task to tick.
+    pub fn try_tick(&self) -> bool {
+        self.executor.try_tick()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_ticker() {
+        let executor = Arc::new(ThreadExecutor::new());
+        let ticker = executor.ticker();
+        assert!(ticker.is_some());
+
+        std::thread::scope(|s| {
+            s.spawn(|| {
+                let ticker = executor.ticker();
+                assert!(ticker.is_none());
+            });
+        });
+    }
+}
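
Reviewer note: below is a minimal, illustrative sketch (not part of the patch) of how the reworked `TaskPool::scope` is exercised after this change, using only the crate's existing public API (`TaskPool::new`, `TaskPool::scope`, `Scope::spawn`, and the `Scope::spawn_on_scope` shown above); the `main` wrapper and the values spawned are made up for the example. `spawn` still lands on the pool's shared executor, while `spawn_on_scope` now goes to the calling thread's thread-local `ThreadExecutor`, which the loop at the bottom of `scope` drives via `ThreadExecutorTicker::try_tick`.

```rust
use bevy_tasks::TaskPool;

fn main() {
    let pool = TaskPool::new();

    // `scope` blocks until every spawned future has finished and
    // collects their outputs into a Vec.
    let results: Vec<i32> = pool.scope(|scope| {
        // Queued on the pool's shared `async_executor::Executor`;
        // may run on any worker thread.
        scope.spawn(async { 1 });

        // With this patch, queued on the thread-local `ThreadExecutor`
        // and ticked on the thread that called `scope`.
        scope.spawn_on_scope(async { 2 });
    });

    // Summing avoids asserting anything about result ordering.
    assert_eq!(results.iter().sum::<i32>(), 3);
}
```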