From 2924fd22138080ebf15b2aa05d31458d9fe3907d Mon Sep 17 00:00:00 2001 From: Luna Razzaghipour Date: Thu, 25 May 2023 17:04:51 +1000 Subject: [PATCH] Implement custom QoS-aware thread pool This code replaces the thread pool implementation we were using previously (from the `threadpool` crate). By making the thread pool aware of QoS, each job spawned on the thread pool can have a different QoS class. This commit also replaces every QoS class used previously with Default as a temporary measure so that each usage can be chosen deliberately. --- Cargo.lock | 11 +-- crates/flycheck/src/lib.rs | 4 +- crates/ide/src/prime_caches.rs | 2 +- crates/rust-analyzer/Cargo.toml | 1 - crates/rust-analyzer/src/bin/main.rs | 2 +- crates/rust-analyzer/src/dispatch.rs | 55 ++++++++--- .../src/handlers/notification.rs | 2 +- crates/rust-analyzer/src/main_loop.rs | 34 ++++--- crates/rust-analyzer/src/reload.rs | 6 +- crates/rust-analyzer/src/task_pool.rs | 56 +++-------- crates/stdx/Cargo.toml | 1 + crates/stdx/src/thread.rs | 6 ++ crates/stdx/src/thread/pool.rs | 95 +++++++++++++++++++ crates/vfs-notify/src/lib.rs | 2 +- 14 files changed, 184 insertions(+), 93 deletions(-) create mode 100644 crates/stdx/src/thread/pool.rs diff --git a/Cargo.lock b/Cargo.lock index f9c5417ffb..322a67383b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1518,7 +1518,6 @@ dependencies = [ "syntax", "test-utils", "thiserror", - "threadpool", "tikv-jemallocator", "toolchain", "tracing", @@ -1712,6 +1711,7 @@ version = "0.0.0" dependencies = [ "always-assert", "backtrace", + "crossbeam-channel", "jod-thread", "libc", "miow", @@ -1823,15 +1823,6 @@ dependencies = [ "once_cell", ] -[[package]] -name = "threadpool" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa" -dependencies = [ - "num_cpus", -] - [[package]] name = "tikv-jemalloc-ctl" version = "0.5.0" diff --git a/crates/flycheck/src/lib.rs b/crates/flycheck/src/lib.rs index e40257c58f..190205a2cd 100644 --- a/crates/flycheck/src/lib.rs +++ b/crates/flycheck/src/lib.rs @@ -90,7 +90,7 @@ impl FlycheckHandle { ) -> FlycheckHandle { let actor = FlycheckActor::new(id, sender, config, workspace_root); let (sender, receiver) = unbounded::(); - let thread = stdx::thread::Builder::new(stdx::thread::QoSClass::Utility) + let thread = stdx::thread::Builder::new(stdx::thread::QoSClass::Default) .name("Flycheck".to_owned()) .spawn(move || actor.run(receiver)) .expect("failed to spawn thread"); @@ -409,7 +409,7 @@ impl CargoHandle { let (sender, receiver) = unbounded(); let actor = CargoActor::new(sender, stdout, stderr); - let thread = stdx::thread::Builder::new(stdx::thread::QoSClass::Utility) + let thread = stdx::thread::Builder::new(stdx::thread::QoSClass::Default) .name("CargoHandle".to_owned()) .spawn(move || actor.run()) .expect("failed to spawn thread"); diff --git a/crates/ide/src/prime_caches.rs b/crates/ide/src/prime_caches.rs index f049a225f0..8c8a93bcb8 100644 --- a/crates/ide/src/prime_caches.rs +++ b/crates/ide/src/prime_caches.rs @@ -81,7 +81,7 @@ pub(crate) fn parallel_prime_caches( let worker = prime_caches_worker.clone(); let db = db.snapshot(); - stdx::thread::Builder::new(stdx::thread::QoSClass::Utility) + stdx::thread::Builder::new(stdx::thread::QoSClass::Default) .allow_leak(true) .spawn(move || Cancelled::catch(|| worker(db))) .expect("failed to spawn thread"); diff --git a/crates/rust-analyzer/Cargo.toml b/crates/rust-analyzer/Cargo.toml index 3f795340b2..97bd920920 100644 --- a/crates/rust-analyzer/Cargo.toml +++ b/crates/rust-analyzer/Cargo.toml @@ -31,7 +31,6 @@ oorandom = "11.1.3" rustc-hash = "1.1.0" serde_json = { workspace = true, features = ["preserve_order"] } serde.workspace = true -threadpool = "1.8.1" rayon = "1.6.1" num_cpus = "1.15.0" mimalloc = { version = "0.1.30", default-features = false, optional = true } diff --git a/crates/rust-analyzer/src/bin/main.rs b/crates/rust-analyzer/src/bin/main.rs index 3224aeae56..eba1933311 100644 --- a/crates/rust-analyzer/src/bin/main.rs +++ b/crates/rust-analyzer/src/bin/main.rs @@ -85,7 +85,7 @@ fn try_main(flags: flags::RustAnalyzer) -> Result<()> { // will make actions like hitting enter in the editor slow. // rust-analyzer does not block the editor’s render loop, // so we don’t use User Interactive. - with_extra_thread("LspServer", stdx::thread::QoSClass::UserInitiated, run_server)?; + with_extra_thread("LspServer", stdx::thread::QoSClass::Default, run_server)?; } flags::RustAnalyzerCmd::Parse(cmd) => cmd.run()?, flags::RustAnalyzerCmd::Symbols(cmd) => cmd.run()?, diff --git a/crates/rust-analyzer/src/dispatch.rs b/crates/rust-analyzer/src/dispatch.rs index 313bb2ec8d..c4731340ba 100644 --- a/crates/rust-analyzer/src/dispatch.rs +++ b/crates/rust-analyzer/src/dispatch.rs @@ -4,6 +4,7 @@ use std::{fmt, panic, thread}; use ide::Cancelled; use lsp_server::ExtractError; use serde::{de::DeserializeOwned, Serialize}; +use stdx::thread::QoSClass; use crate::{ global_state::{GlobalState, GlobalStateSnapshot}, @@ -102,7 +103,7 @@ impl<'a> RequestDispatcher<'a> { None => return self, }; - self.global_state.task_pool.handle.spawn({ + self.global_state.task_pool.handle.spawn(QoSClass::Default, { let world = self.global_state.snapshot(); move || { let result = panic::catch_unwind(move || { @@ -128,6 +129,44 @@ impl<'a> RequestDispatcher<'a> { &mut self, f: fn(GlobalStateSnapshot, R::Params) -> Result, ) -> &mut Self + where + R: lsp_types::request::Request + 'static, + R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug, + R::Result: Serialize, + { + self.on_with_qos::(QoSClass::Default, f) + } + + /// Dispatches a latency-sensitive request onto the thread pool. + pub(crate) fn on_latency_sensitive( + &mut self, + f: fn(GlobalStateSnapshot, R::Params) -> Result, + ) -> &mut Self + where + R: lsp_types::request::Request + 'static, + R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug, + R::Result: Serialize, + { + self.on_with_qos::(QoSClass::Default, f) + } + + pub(crate) fn finish(&mut self) { + if let Some(req) = self.req.take() { + tracing::error!("unknown request: {:?}", req); + let response = lsp_server::Response::new_err( + req.id, + lsp_server::ErrorCode::MethodNotFound as i32, + "unknown request".to_string(), + ); + self.global_state.respond(response); + } + } + + fn on_with_qos( + &mut self, + qos_class: QoSClass, + f: fn(GlobalStateSnapshot, R::Params) -> Result, + ) -> &mut Self where R: lsp_types::request::Request + 'static, R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug, @@ -138,7 +177,7 @@ impl<'a> RequestDispatcher<'a> { None => return self, }; - self.global_state.task_pool.handle.spawn({ + self.global_state.task_pool.handle.spawn(qos_class, { let world = self.global_state.snapshot(); move || { let result = panic::catch_unwind(move || { @@ -155,18 +194,6 @@ impl<'a> RequestDispatcher<'a> { self } - pub(crate) fn finish(&mut self) { - if let Some(req) = self.req.take() { - tracing::error!("unknown request: {:?}", req); - let response = lsp_server::Response::new_err( - req.id, - lsp_server::ErrorCode::MethodNotFound as i32, - "unknown request".to_string(), - ); - self.global_state.respond(response); - } - } - fn parse(&mut self) -> Option<(lsp_server::Request, R::Params, String)> where R: lsp_types::request::Request, diff --git a/crates/rust-analyzer/src/handlers/notification.rs b/crates/rust-analyzer/src/handlers/notification.rs index 7074ef018a..2d871748c3 100644 --- a/crates/rust-analyzer/src/handlers/notification.rs +++ b/crates/rust-analyzer/src/handlers/notification.rs @@ -291,7 +291,7 @@ fn run_flycheck(state: &mut GlobalState, vfs_path: VfsPath) -> bool { } Ok(()) }; - state.task_pool.handle.spawn_with_sender(move |_| { + state.task_pool.handle.spawn_with_sender(stdx::thread::QoSClass::Default, move |_| { if let Err(e) = std::panic::catch_unwind(task) { tracing::error!("flycheck task panicked: {e:?}") } diff --git a/crates/rust-analyzer/src/main_loop.rs b/crates/rust-analyzer/src/main_loop.rs index a28edde2f4..ae9f6ff7ee 100644 --- a/crates/rust-analyzer/src/main_loop.rs +++ b/crates/rust-analyzer/src/main_loop.rs @@ -397,7 +397,7 @@ impl GlobalState { tracing::debug!(%cause, "will prime caches"); let num_worker_threads = self.config.prime_caches_num_threads(); - self.task_pool.handle.spawn_with_sender({ + self.task_pool.handle.spawn_with_sender(stdx::thread::QoSClass::Default, { let analysis = self.snapshot().analysis; move |sender| { sender.send(Task::PrimeCaches(PrimeCachesProgress::Begin)).unwrap(); @@ -678,7 +678,24 @@ impl GlobalState { .on_sync::(handlers::handle_selection_range) .on_sync::(handlers::handle_matching_brace) .on_sync::(handlers::handle_on_type_formatting) - // All other request handlers: + // We can’t run latency-sensitive request handlers which do semantic + // analysis on the main thread because that would block other + // requests. Instead, we run these request handlers on higher QoS + // threads in the threadpool. + .on_latency_sensitive::(handlers::handle_completion) + .on_latency_sensitive::( + handlers::handle_completion_resolve, + ) + .on_latency_sensitive::( + handlers::handle_semantic_tokens_full, + ) + .on_latency_sensitive::( + handlers::handle_semantic_tokens_full_delta, + ) + .on_latency_sensitive::( + handlers::handle_semantic_tokens_range, + ) + // All other request handlers .on::(handlers::fetch_dependency_list) .on::(handlers::handle_analyzer_status) .on::(handlers::handle_syntax_tree) @@ -706,8 +723,6 @@ impl GlobalState { .on::(handlers::handle_goto_type_definition) .on_no_retry::(handlers::handle_inlay_hints) .on::(handlers::handle_inlay_hints_resolve) - .on::(handlers::handle_completion) - .on::(handlers::handle_completion_resolve) .on::(handlers::handle_code_lens) .on::(handlers::handle_code_lens_resolve) .on::(handlers::handle_folding_range) @@ -725,15 +740,6 @@ impl GlobalState { .on::( handlers::handle_call_hierarchy_outgoing, ) - .on::( - handlers::handle_semantic_tokens_full, - ) - .on::( - handlers::handle_semantic_tokens_full_delta, - ) - .on::( - handlers::handle_semantic_tokens_range, - ) .on::(handlers::handle_will_rename_files) .on::(handlers::handle_ssr) .finish(); @@ -781,7 +787,7 @@ impl GlobalState { tracing::trace!("updating notifications for {:?}", subscriptions); let snapshot = self.snapshot(); - self.task_pool.handle.spawn(move || { + self.task_pool.handle.spawn(stdx::thread::QoSClass::Default, move || { let _p = profile::span("publish_diagnostics"); let diagnostics = subscriptions .into_iter() diff --git a/crates/rust-analyzer/src/reload.rs b/crates/rust-analyzer/src/reload.rs index 4e29485573..7070950638 100644 --- a/crates/rust-analyzer/src/reload.rs +++ b/crates/rust-analyzer/src/reload.rs @@ -185,7 +185,7 @@ impl GlobalState { pub(crate) fn fetch_workspaces(&mut self, cause: Cause) { tracing::info!(%cause, "will fetch workspaces"); - self.task_pool.handle.spawn_with_sender({ + self.task_pool.handle.spawn_with_sender(stdx::thread::QoSClass::Default, { let linked_projects = self.config.linked_projects(); let detached_files = self.config.detached_files().to_vec(); let cargo_config = self.config.cargo(); @@ -260,7 +260,7 @@ impl GlobalState { tracing::info!(%cause, "will fetch build data"); let workspaces = Arc::clone(&self.workspaces); let config = self.config.cargo(); - self.task_pool.handle.spawn_with_sender(move |sender| { + self.task_pool.handle.spawn_with_sender(stdx::thread::QoSClass::Default, move |sender| { sender.send(Task::FetchBuildData(BuildDataProgress::Begin)).unwrap(); let progress = { @@ -280,7 +280,7 @@ impl GlobalState { let dummy_replacements = self.config.dummy_replacements().clone(); let proc_macro_clients = self.proc_macro_clients.clone(); - self.task_pool.handle.spawn_with_sender(move |sender| { + self.task_pool.handle.spawn_with_sender(stdx::thread::QoSClass::Default, move |sender| { sender.send(Task::LoadProcMacros(ProcMacroProgress::Begin)).unwrap(); let dummy_replacements = &dummy_replacements; diff --git a/crates/rust-analyzer/src/task_pool.rs b/crates/rust-analyzer/src/task_pool.rs index 0c5a4f3055..f055de40d0 100644 --- a/crates/rust-analyzer/src/task_pool.rs +++ b/crates/rust-analyzer/src/task_pool.rs @@ -1,76 +1,42 @@ -//! A thin wrapper around `ThreadPool` to make sure that we join all things -//! properly. -use std::sync::{Arc, Barrier}; +//! A thin wrapper around [`stdx::thread::Pool`] which threads a sender through spawned jobs. +//! It is used in [`crate::global_state::GlobalState`] throughout the main loop. use crossbeam_channel::Sender; +use stdx::thread::{Pool, QoSClass}; pub(crate) struct TaskPool { sender: Sender, - inner: threadpool::ThreadPool, + pool: Pool, } impl TaskPool { pub(crate) fn new_with_threads(sender: Sender, threads: usize) -> TaskPool { - const STACK_SIZE: usize = 8 * 1024 * 1024; - - let inner = threadpool::Builder::new() - .thread_name("Worker".into()) - .thread_stack_size(STACK_SIZE) - .num_threads(threads) - .build(); - - // Set QoS of all threads in threadpool. - let barrier = Arc::new(Barrier::new(threads + 1)); - for _ in 0..threads { - let barrier = barrier.clone(); - inner.execute(move || { - stdx::thread::set_current_thread_qos_class(stdx::thread::QoSClass::Utility); - barrier.wait(); - }); - } - barrier.wait(); - - TaskPool { sender, inner } + TaskPool { sender, pool: Pool::new(threads) } } - pub(crate) fn spawn(&mut self, task: F) + pub(crate) fn spawn(&mut self, qos_class: QoSClass, task: F) where F: FnOnce() -> T + Send + 'static, T: Send + 'static, { - self.inner.execute({ + self.pool.spawn(qos_class, { let sender = self.sender.clone(); - move || { - if stdx::thread::IS_QOS_AVAILABLE { - debug_assert_eq!( - stdx::thread::get_current_thread_qos_class(), - Some(stdx::thread::QoSClass::Utility) - ); - } - - sender.send(task()).unwrap() - } + move || sender.send(task()).unwrap() }) } - pub(crate) fn spawn_with_sender(&mut self, task: F) + pub(crate) fn spawn_with_sender(&mut self, qos_class: QoSClass, task: F) where F: FnOnce(Sender) + Send + 'static, T: Send + 'static, { - self.inner.execute({ + self.pool.spawn(qos_class, { let sender = self.sender.clone(); move || task(sender) }) } pub(crate) fn len(&self) -> usize { - self.inner.queued_count() - } -} - -impl Drop for TaskPool { - fn drop(&mut self) { - self.inner.join() + self.pool.len() } } diff --git a/crates/stdx/Cargo.toml b/crates/stdx/Cargo.toml index 986e3fcdcf..a67f36ae90 100644 --- a/crates/stdx/Cargo.toml +++ b/crates/stdx/Cargo.toml @@ -16,6 +16,7 @@ libc = "0.2.135" backtrace = { version = "0.3.65", optional = true } always-assert = { version = "0.1.2", features = ["log"] } jod-thread = "0.1.2" +crossbeam-channel = "0.5.5" # Think twice before adding anything here [target.'cfg(windows)'.dependencies] diff --git a/crates/stdx/src/thread.rs b/crates/stdx/src/thread.rs index 5042f00143..8630961f85 100644 --- a/crates/stdx/src/thread.rs +++ b/crates/stdx/src/thread.rs @@ -13,6 +13,9 @@ use std::fmt; +mod pool; +pub use pool::Pool; + pub fn spawn(qos_class: QoSClass, f: F) -> JoinHandle where F: FnOnce() -> T, @@ -152,6 +155,8 @@ pub enum QoSClass { /// performance, responsiveness and efficiency. Utility, + Default, + /// TLDR: tasks that block using your app /// /// Contract: @@ -229,6 +234,7 @@ mod imp { let c = match class { QoSClass::UserInteractive => libc::qos_class_t::QOS_CLASS_USER_INTERACTIVE, QoSClass::UserInitiated => libc::qos_class_t::QOS_CLASS_USER_INITIATED, + QoSClass::Default => libc::qos_class_t::QOS_CLASS_DEFAULT, QoSClass::Utility => libc::qos_class_t::QOS_CLASS_UTILITY, QoSClass::Background => libc::qos_class_t::QOS_CLASS_BACKGROUND, }; diff --git a/crates/stdx/src/thread/pool.rs b/crates/stdx/src/thread/pool.rs new file mode 100644 index 0000000000..b4ab9cb292 --- /dev/null +++ b/crates/stdx/src/thread/pool.rs @@ -0,0 +1,95 @@ +//! [`Pool`] implements a basic custom thread pool +//! inspired by the [`threadpool` crate](http://docs.rs/threadpool). +//! It allows the spawning of tasks under different QoS classes. +//! rust-analyzer uses this to prioritize work based on latency requirements. +//! +//! The thread pool is implemented entirely using +//! the threading utilities in [`crate::thread`]. + +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +use crossbeam_channel::{Receiver, Sender}; + +use super::{ + get_current_thread_qos_class, set_current_thread_qos_class, Builder, JoinHandle, QoSClass, + IS_QOS_AVAILABLE, +}; + +pub struct Pool { + // `_handles` is never read: the field is present + // only for its `Drop` impl. + + // The worker threads exit once the channel closes; + // make sure to keep `job_sender` above `handles` + // so that the channel is actually closed + // before we join the worker threads! + job_sender: Sender, + _handles: Vec, + extant_tasks: Arc, +} + +struct Job { + requested_qos_class: QoSClass, + f: Box, +} + +impl Pool { + pub fn new(threads: usize) -> Pool { + const STACK_SIZE: usize = 8 * 1024 * 1024; + const INITIAL_QOS_CLASS: QoSClass = QoSClass::Utility; + + let (job_sender, job_receiver) = crossbeam_channel::unbounded(); + let extant_tasks = Arc::new(AtomicUsize::new(0)); + + let mut handles = Vec::with_capacity(threads); + for _ in 0..threads { + let handle = Builder::new(INITIAL_QOS_CLASS) + .stack_size(STACK_SIZE) + .name("Worker".into()) + .spawn({ + let extant_tasks = Arc::clone(&extant_tasks); + let job_receiver: Receiver = job_receiver.clone(); + move || { + let mut current_qos_class = INITIAL_QOS_CLASS; + for job in job_receiver { + if job.requested_qos_class != current_qos_class { + set_current_thread_qos_class(job.requested_qos_class); + current_qos_class = job.requested_qos_class; + } + extant_tasks.fetch_add(1, Ordering::SeqCst); + (job.f)(); + extant_tasks.fetch_sub(1, Ordering::SeqCst); + } + } + }) + .expect("failed to spawn thread"); + + handles.push(handle); + } + + Pool { _handles: handles, extant_tasks, job_sender } + } + + pub fn spawn(&self, qos_class: QoSClass, f: F) + where + F: FnOnce() + Send + 'static, + { + let f = Box::new(move || { + if IS_QOS_AVAILABLE { + debug_assert_eq!(get_current_thread_qos_class(), Some(qos_class)); + } + + f() + }); + + let job = Job { requested_qos_class: qos_class, f }; + self.job_sender.send(job).unwrap(); + } + + pub fn len(&self) -> usize { + self.extant_tasks.load(Ordering::SeqCst) + } +} diff --git a/crates/vfs-notify/src/lib.rs b/crates/vfs-notify/src/lib.rs index 26f7a9fc42..90a7d7d6c0 100644 --- a/crates/vfs-notify/src/lib.rs +++ b/crates/vfs-notify/src/lib.rs @@ -34,7 +34,7 @@ impl loader::Handle for NotifyHandle { fn spawn(sender: loader::Sender) -> NotifyHandle { let actor = NotifyActor::new(sender); let (sender, receiver) = unbounded::(); - let thread = stdx::thread::Builder::new(stdx::thread::QoSClass::Utility) + let thread = stdx::thread::Builder::new(stdx::thread::QoSClass::Default) .name("VfsLoader".to_owned()) .spawn(move || actor.run(receiver)) .expect("failed to spawn thread");