2020-08-29 19:35:41 +00:00
|
|
|
use std::{
|
|
|
|
future::Future,
|
|
|
|
mem,
|
|
|
|
pin::Pin,
|
2020-09-09 20:12:50 +00:00
|
|
|
sync::Arc,
|
2020-08-29 19:35:41 +00:00
|
|
|
thread::{self, JoinHandle},
|
|
|
|
};
|
|
|
|
|
2020-09-21 20:13:40 +00:00
|
|
|
use futures_lite::{future, pin};
|
2020-09-10 19:54:24 +00:00
|
|
|
|
2020-09-22 03:23:09 +00:00
|
|
|
use crate::Task;
|
|
|
|
|
2020-08-29 19:35:41 +00:00
|
|
|
/// Used to create a [`TaskPool`].
///
/// Every option is optional; an unset option falls back to the default described on its field.
#[derive(Debug, Default, Clone)]
pub struct TaskPoolBuilder {
    /// If set, the pool spawns exactly this many worker threads. Otherwise it uses
    /// the logical core count of the system.
    num_threads: Option<usize>,
    /// If set, worker threads use the given stack size rather than the system default.
    stack_size: Option<usize>,
    /// Allows customizing the name of the threads - helpful for debugging. If set, threads will
    /// be named `<thread_name> (<thread_index>)`, e.g. `"MyThreadPool (2)"`.
    thread_name: Option<String>,
}
|
|
|
|
|
|
|
|
impl TaskPoolBuilder {
|
|
|
|
/// Creates a new TaskPoolBuilder instance
|
|
|
|
pub fn new() -> Self {
|
|
|
|
Self::default()
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Override the number of threads created for the pool. If unset, we default to the number
|
|
|
|
/// of logical cores of the system
|
|
|
|
pub fn num_threads(mut self, num_threads: usize) -> Self {
|
|
|
|
self.num_threads = Some(num_threads);
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Override the stack size of the threads created for the pool
|
|
|
|
pub fn stack_size(mut self, stack_size: usize) -> Self {
|
|
|
|
self.stack_size = Some(stack_size);
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Override the name of the threads created for the pool. If set, threads will
|
|
|
|
/// be named <thread_name> (<thread_index>), i.e. "MyThreadPool (2)"
|
|
|
|
pub fn thread_name(mut self, thread_name: String) -> Self {
|
|
|
|
self.thread_name = Some(thread_name);
|
|
|
|
self
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Creates a new ThreadPoolBuilder based on the current options.
|
|
|
|
pub fn build(self) -> TaskPool {
|
|
|
|
TaskPool::new_internal(
|
|
|
|
self.num_threads,
|
|
|
|
self.stack_size,
|
|
|
|
self.thread_name.as_deref(),
|
|
|
|
)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-10-08 18:43:01 +00:00
|
|
|
#[derive(Debug)]
|
2020-08-29 19:35:41 +00:00
|
|
|
struct TaskPoolInner {
|
2020-09-09 20:12:50 +00:00
|
|
|
threads: Vec<JoinHandle<()>>,
|
|
|
|
shutdown_tx: async_channel::Sender<()>,
|
2020-08-29 19:35:41 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
impl Drop for TaskPoolInner {
|
|
|
|
fn drop(&mut self) {
|
2020-09-09 20:12:50 +00:00
|
|
|
self.shutdown_tx.close();
|
2020-08-29 19:35:41 +00:00
|
|
|
|
2020-12-22 20:21:21 +00:00
|
|
|
let panicking = thread::panicking();
|
2020-09-09 20:12:50 +00:00
|
|
|
for join_handle in self.threads.drain(..) {
|
2020-12-22 20:21:21 +00:00
|
|
|
let res = join_handle.join();
|
|
|
|
if !panicking {
|
|
|
|
res.expect("Task thread panicked while executing.");
|
|
|
|
}
|
2020-08-29 19:35:41 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// A thread pool for executing tasks. Tasks are futures that are being automatically driven by
|
|
|
|
/// the pool on threads owned by the pool.
|
2020-10-08 18:43:01 +00:00
|
|
|
#[derive(Debug, Clone)]
|
2020-08-29 19:35:41 +00:00
|
|
|
pub struct TaskPool {
|
|
|
|
/// The executor for the pool
|
|
|
|
///
|
|
|
|
/// This has to be separate from TaskPoolInner because we have to create an Arc<Executor> to
|
|
|
|
/// pass into the worker threads, and we must create the worker threads before we can create the
|
|
|
|
/// Vec<Task<T>> contained within TaskPoolInner
|
2020-09-20 18:27:24 +00:00
|
|
|
executor: Arc<async_executor::Executor<'static>>,
|
2020-08-29 19:35:41 +00:00
|
|
|
|
|
|
|
/// Inner state of the pool
|
|
|
|
inner: Arc<TaskPoolInner>,
|
|
|
|
}
|
|
|
|
|
|
|
|
impl TaskPool {
|
2021-01-18 21:48:28 +00:00
|
|
|
thread_local! {
|
|
|
|
static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = async_executor::LocalExecutor::new();
|
|
|
|
}
|
|
|
|
|
2020-08-29 19:35:41 +00:00
|
|
|
/// Create a `TaskPool` with the default configuration.
|
|
|
|
pub fn new() -> Self {
|
|
|
|
TaskPoolBuilder::new().build()
|
|
|
|
}
|
|
|
|
|
|
|
|
fn new_internal(
|
|
|
|
num_threads: Option<usize>,
|
|
|
|
stack_size: Option<usize>,
|
|
|
|
thread_name: Option<&str>,
|
|
|
|
) -> Self {
|
2020-09-09 20:12:50 +00:00
|
|
|
let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();
|
|
|
|
|
|
|
|
let executor = Arc::new(async_executor::Executor::new());
|
2020-08-29 19:35:41 +00:00
|
|
|
|
|
|
|
let num_threads = num_threads.unwrap_or_else(num_cpus::get);
|
|
|
|
|
|
|
|
let threads = (0..num_threads)
|
|
|
|
.map(|i| {
|
|
|
|
let ex = Arc::clone(&executor);
|
2020-09-09 20:12:50 +00:00
|
|
|
let shutdown_rx = shutdown_rx.clone();
|
2020-08-29 19:35:41 +00:00
|
|
|
|
|
|
|
let thread_name = if let Some(thread_name) = thread_name {
|
|
|
|
format!("{} ({})", thread_name, i)
|
|
|
|
} else {
|
|
|
|
format!("TaskPool ({})", i)
|
|
|
|
};
|
|
|
|
|
|
|
|
let mut thread_builder = thread::Builder::new().name(thread_name);
|
|
|
|
|
|
|
|
if let Some(stack_size) = stack_size {
|
|
|
|
thread_builder = thread_builder.stack_size(stack_size);
|
|
|
|
}
|
|
|
|
|
2020-09-09 20:12:50 +00:00
|
|
|
thread_builder
|
2020-08-29 19:35:41 +00:00
|
|
|
.spawn(move || {
|
2020-09-09 20:12:50 +00:00
|
|
|
let shutdown_future = ex.run(shutdown_rx.recv());
|
|
|
|
// Use unwrap_err because we expect a Closed error
|
2020-09-10 19:54:24 +00:00
|
|
|
future::block_on(shutdown_future).unwrap_err();
|
2020-08-29 19:35:41 +00:00
|
|
|
})
|
2020-12-02 19:31:16 +00:00
|
|
|
.expect("Failed to spawn thread.")
|
2020-08-29 19:35:41 +00:00
|
|
|
})
|
|
|
|
.collect();
|
|
|
|
|
|
|
|
Self {
|
|
|
|
executor,
|
|
|
|
inner: Arc::new(TaskPoolInner {
|
|
|
|
threads,
|
2020-09-09 20:12:50 +00:00
|
|
|
shutdown_tx,
|
2020-08-29 19:35:41 +00:00
|
|
|
}),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Return the number of threads owned by the task pool
|
|
|
|
pub fn thread_num(&self) -> usize {
|
|
|
|
self.inner.threads.len()
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Allows spawning non-`static futures on the thread pool. The function takes a callback,
|
|
|
|
/// passing a scope object into it. The scope object provided to the callback can be used
|
|
|
|
/// to spawn tasks. This function will await the completion of all tasks before returning.
|
|
|
|
///
|
|
|
|
/// This is similar to `rayon::scope` and `crossbeam::scope`
|
|
|
|
pub fn scope<'scope, F, T>(&self, f: F) -> Vec<T>
|
|
|
|
where
|
|
|
|
F: FnOnce(&mut Scope<'scope, T>) + 'scope + Send,
|
|
|
|
T: Send + 'static,
|
|
|
|
{
|
2021-01-18 21:48:28 +00:00
|
|
|
TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
|
|
|
|
// SAFETY: This function blocks until all futures complete, so this future must return
|
|
|
|
// before this function returns. However, rust has no way of knowing
|
|
|
|
// this so we must convert to 'static here to appease the compiler as it is unable to
|
|
|
|
// validate safety.
|
|
|
|
let executor: &async_executor::Executor = &*self.executor;
|
|
|
|
let executor: &'scope async_executor::Executor = unsafe { mem::transmute(executor) };
|
|
|
|
let local_executor: &'scope async_executor::LocalExecutor =
|
|
|
|
unsafe { mem::transmute(local_executor) };
|
|
|
|
let mut scope = Scope {
|
|
|
|
executor,
|
|
|
|
local_executor,
|
|
|
|
spawned: Vec::new(),
|
2020-11-27 20:14:44 +00:00
|
|
|
};
|
2020-08-29 19:35:41 +00:00
|
|
|
|
2021-01-18 21:48:28 +00:00
|
|
|
f(&mut scope);
|
|
|
|
|
|
|
|
if scope.spawned.is_empty() {
|
|
|
|
Vec::default()
|
|
|
|
} else if scope.spawned.len() == 1 {
|
|
|
|
vec![future::block_on(&mut scope.spawned[0])]
|
|
|
|
} else {
|
|
|
|
let fut = async move {
|
|
|
|
let mut results = Vec::with_capacity(scope.spawned.len());
|
|
|
|
for task in scope.spawned {
|
|
|
|
results.push(task.await);
|
|
|
|
}
|
2020-11-27 20:14:44 +00:00
|
|
|
|
2021-01-18 21:48:28 +00:00
|
|
|
results
|
|
|
|
};
|
2020-08-29 19:35:41 +00:00
|
|
|
|
2021-01-18 21:48:28 +00:00
|
|
|
// Pin the futures on the stack.
|
|
|
|
pin!(fut);
|
|
|
|
|
|
|
|
// SAFETY: This function blocks until all futures complete, so we do not read/write the
|
|
|
|
// data from futures outside of the 'scope lifetime. However, rust has no way of knowing
|
|
|
|
// this so we must convert to 'static here to appease the compiler as it is unable to
|
|
|
|
// validate safety.
|
|
|
|
let fut: Pin<&mut (dyn Future<Output = Vec<T>>)> = fut;
|
|
|
|
let fut: Pin<&'static mut (dyn Future<Output = Vec<T>> + 'static)> =
|
|
|
|
unsafe { mem::transmute(fut) };
|
|
|
|
|
|
|
|
// The thread that calls scope() will participate in driving tasks in the pool forward
|
|
|
|
// until the tasks that are spawned by this scope() call complete. (If the caller of scope()
|
|
|
|
// happens to be a thread in this thread pool, and we only have one thread in the pool, then
|
|
|
|
// simply calling future::block_on(spawned) would deadlock.)
|
|
|
|
let mut spawned = local_executor.spawn(fut);
|
|
|
|
loop {
|
|
|
|
if let Some(result) = future::block_on(future::poll_once(&mut spawned)) {
|
|
|
|
break result;
|
|
|
|
};
|
|
|
|
|
|
|
|
self.executor.try_tick();
|
|
|
|
local_executor.try_tick();
|
|
|
|
}
|
2020-11-26 02:05:55 +00:00
|
|
|
}
|
2021-01-18 21:48:28 +00:00
|
|
|
})
|
2020-08-29 19:35:41 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Spawns a static future onto the thread pool. The returned Task is a future. It can also be
|
|
|
|
/// cancelled and "detached" allowing it to continue running without having to be polled by the
|
|
|
|
/// end-user.
|
2020-09-22 03:23:09 +00:00
|
|
|
pub fn spawn<T>(&self, future: impl Future<Output = T> + Send + 'static) -> Task<T>
|
2020-08-29 19:35:41 +00:00
|
|
|
where
|
|
|
|
T: Send + 'static,
|
|
|
|
{
|
2020-09-22 03:23:09 +00:00
|
|
|
Task::new(self.executor.spawn(future))
|
2020-08-29 19:35:41 +00:00
|
|
|
}
|
2021-01-18 21:48:28 +00:00
|
|
|
|
|
|
|
pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
|
|
|
|
where
|
|
|
|
T: 'static,
|
|
|
|
{
|
|
|
|
Task::new(TaskPool::LOCAL_EXECUTOR.with(|executor| executor.spawn(future)))
|
|
|
|
}
|
2020-08-29 19:35:41 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
impl Default for TaskPool {
|
|
|
|
fn default() -> Self {
|
|
|
|
Self::new()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-10-08 18:43:01 +00:00
|
|
|
#[derive(Debug)]
|
2020-08-29 19:35:41 +00:00
|
|
|
pub struct Scope<'scope, T> {
|
2020-09-20 18:27:24 +00:00
|
|
|
executor: &'scope async_executor::Executor<'scope>,
|
2021-01-18 21:48:28 +00:00
|
|
|
local_executor: &'scope async_executor::LocalExecutor<'scope>,
|
2020-09-09 20:12:50 +00:00
|
|
|
spawned: Vec<async_executor::Task<T>>,
|
2020-08-29 19:35:41 +00:00
|
|
|
}
|
|
|
|
|
2020-09-20 18:27:24 +00:00
|
|
|
impl<'scope, T: Send + 'scope> Scope<'scope, T> {
|
2020-08-29 19:35:41 +00:00
|
|
|
pub fn spawn<Fut: Future<Output = T> + 'scope + Send>(&mut self, f: Fut) {
|
2020-09-20 18:27:24 +00:00
|
|
|
let task = self.executor.spawn(f);
|
2020-08-29 19:35:41 +00:00
|
|
|
self.spawned.push(task);
|
|
|
|
}
|
2021-01-18 21:48:28 +00:00
|
|
|
|
|
|
|
pub fn spawn_local<Fut: Future<Output = T> + 'scope>(&mut self, f: Fut) {
|
|
|
|
let task = self.local_executor.spawn(f);
|
|
|
|
self.spawned.push(task);
|
|
|
|
}
|
2020-08-29 19:35:41 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        atomic::{AtomicBool, AtomicI32, Ordering},
        Barrier,
    };

    /// Scoped tasks can borrow stack data, all run, and all outputs are collected.
    #[test]
    pub fn test_spawn() {
        let pool = TaskPool::new();

        let boxed = Box::new(42);
        let answer = &*boxed;

        let counter = Arc::new(AtomicI32::new(0));

        let results = pool.scope(|scope| {
            (0..100).for_each(|_| {
                let counter = Arc::clone(&counter);
                scope.spawn(async move {
                    assert!(*answer == 42, "not 42!?!?");
                    counter.fetch_add(1, Ordering::Relaxed);
                    *answer
                });
            });
        });

        assert!(results.iter().all(|&value| value == 42));
        assert_eq!(results.len(), 100);
        assert_eq!(counter.load(Ordering::Relaxed), 100);
    }

    /// Interleaving `spawn` and `spawn_local` in one scope runs every task exactly once.
    #[test]
    pub fn test_mixed_spawn_local_and_spawn() {
        let pool = TaskPool::new();

        let boxed = Box::new(42);
        let answer = &*boxed;

        let local_count = Arc::new(AtomicI32::new(0));
        let non_local_count = Arc::new(AtomicI32::new(0));

        let results = pool.scope(|scope| {
            for i in 0..100 {
                match i % 2 {
                    0 => {
                        let shared = Arc::clone(&non_local_count);
                        scope.spawn(async move {
                            assert!(*answer == 42, "not 42!?!?");
                            shared.fetch_add(1, Ordering::Relaxed);
                            *answer
                        });
                    }
                    _ => {
                        let local = Arc::clone(&local_count);
                        scope.spawn_local(async move {
                            assert!(*answer == 42, "not 42!?!?");
                            local.fetch_add(1, Ordering::Relaxed);
                            *answer
                        });
                    }
                }
            }
        });

        assert!(results.iter().all(|&value| value == 42));
        assert_eq!(results.len(), 100);
        assert_eq!(local_count.load(Ordering::Relaxed), 50);
        assert_eq!(non_local_count.load(Ordering::Relaxed), 50);
    }

    /// `spawn_local` tasks run on the thread that called `scope`, even when many threads
    /// share the pool concurrently.
    #[test]
    pub fn test_thread_locality() {
        let pool = Arc::new(TaskPool::new());
        let count = Arc::new(AtomicI32::new(0));
        let barrier = Arc::new(Barrier::new(101));
        let thread_check_failed = Arc::new(AtomicBool::new(false));

        for _ in 0..100 {
            let barrier = Arc::clone(&barrier);
            let count = Arc::clone(&count);
            let pool = Arc::clone(&pool);
            let thread_check_failed = Arc::clone(&thread_check_failed);

            std::thread::spawn(move || {
                pool.scope(|scope| {
                    let global_count = Arc::clone(&count);
                    scope.spawn(async move {
                        global_count.fetch_add(1, Ordering::Release);
                    });

                    let spawner = std::thread::current().id();
                    let local_count = Arc::clone(&count);
                    scope.spawn_local(async move {
                        local_count.fetch_add(1, Ordering::Release);
                        if std::thread::current().id() != spawner {
                            // Record the failure in an atomic instead of panicking, so a failing
                            // thread cannot leave the barrier waiting forever.
                            thread_check_failed.store(true, Ordering::Release);
                        }
                    });
                });
                barrier.wait();
            });
        }

        barrier.wait();
        assert!(!thread_check_failed.load(Ordering::Acquire));
        assert_eq!(count.load(Ordering::Acquire), 200);
    }
}
|