Implement custom QoS-aware thread pool

This code replaces the thread pool implementation we were using
previously (from the `threadpool` crate). By making the thread pool
aware of QoS, each job spawned on the thread pool can have a different
QoS class.

This commit also replaces every QoS class used previously with Default
as a temporary measure so that each usage can be chosen deliberately.
This commit is contained in:
Luna Razzaghipour 2023-05-25 17:04:51 +10:00
parent f6e3a87bf9
commit 2924fd2213
No known key found for this signature in database
14 changed files with 184 additions and 93 deletions

11
Cargo.lock generated
View file

@ -1518,7 +1518,6 @@ dependencies = [
"syntax",
"test-utils",
"thiserror",
"threadpool",
"tikv-jemallocator",
"toolchain",
"tracing",
@ -1712,6 +1711,7 @@ version = "0.0.0"
dependencies = [
"always-assert",
"backtrace",
"crossbeam-channel",
"jod-thread",
"libc",
"miow",
@ -1823,15 +1823,6 @@ dependencies = [
"once_cell",
]
[[package]]
name = "threadpool"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d050e60b33d41c19108b32cea32164033a9013fe3b46cbd4457559bfbf77afaa"
dependencies = [
"num_cpus",
]
[[package]]
name = "tikv-jemalloc-ctl"
version = "0.5.0"

View file

@ -90,7 +90,7 @@ impl FlycheckHandle {
) -> FlycheckHandle {
let actor = FlycheckActor::new(id, sender, config, workspace_root);
let (sender, receiver) = unbounded::<StateChange>();
let thread = stdx::thread::Builder::new(stdx::thread::QoSClass::Utility)
let thread = stdx::thread::Builder::new(stdx::thread::QoSClass::Default)
.name("Flycheck".to_owned())
.spawn(move || actor.run(receiver))
.expect("failed to spawn thread");
@ -409,7 +409,7 @@ impl CargoHandle {
let (sender, receiver) = unbounded();
let actor = CargoActor::new(sender, stdout, stderr);
let thread = stdx::thread::Builder::new(stdx::thread::QoSClass::Utility)
let thread = stdx::thread::Builder::new(stdx::thread::QoSClass::Default)
.name("CargoHandle".to_owned())
.spawn(move || actor.run())
.expect("failed to spawn thread");

View file

@ -81,7 +81,7 @@ pub(crate) fn parallel_prime_caches(
let worker = prime_caches_worker.clone();
let db = db.snapshot();
stdx::thread::Builder::new(stdx::thread::QoSClass::Utility)
stdx::thread::Builder::new(stdx::thread::QoSClass::Default)
.allow_leak(true)
.spawn(move || Cancelled::catch(|| worker(db)))
.expect("failed to spawn thread");

View file

@ -31,7 +31,6 @@ oorandom = "11.1.3"
rustc-hash = "1.1.0"
serde_json = { workspace = true, features = ["preserve_order"] }
serde.workspace = true
threadpool = "1.8.1"
rayon = "1.6.1"
num_cpus = "1.15.0"
mimalloc = { version = "0.1.30", default-features = false, optional = true }

View file

@ -85,7 +85,7 @@ fn try_main(flags: flags::RustAnalyzer) -> Result<()> {
// will make actions like hitting enter in the editor slow.
// rust-analyzer does not block the editors render loop,
// so we dont use User Interactive.
with_extra_thread("LspServer", stdx::thread::QoSClass::UserInitiated, run_server)?;
with_extra_thread("LspServer", stdx::thread::QoSClass::Default, run_server)?;
}
flags::RustAnalyzerCmd::Parse(cmd) => cmd.run()?,
flags::RustAnalyzerCmd::Symbols(cmd) => cmd.run()?,

View file

@ -4,6 +4,7 @@ use std::{fmt, panic, thread};
use ide::Cancelled;
use lsp_server::ExtractError;
use serde::{de::DeserializeOwned, Serialize};
use stdx::thread::QoSClass;
use crate::{
global_state::{GlobalState, GlobalStateSnapshot},
@ -102,7 +103,7 @@ impl<'a> RequestDispatcher<'a> {
None => return self,
};
self.global_state.task_pool.handle.spawn({
self.global_state.task_pool.handle.spawn(QoSClass::Default, {
let world = self.global_state.snapshot();
move || {
let result = panic::catch_unwind(move || {
@ -128,6 +129,44 @@ impl<'a> RequestDispatcher<'a> {
&mut self,
f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>,
) -> &mut Self
where
R: lsp_types::request::Request + 'static,
R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
R::Result: Serialize,
{
self.on_with_qos::<R>(QoSClass::Default, f)
}
/// Dispatches a latency-sensitive request onto the thread pool.
pub(crate) fn on_latency_sensitive<R>(
&mut self,
f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>,
) -> &mut Self
where
R: lsp_types::request::Request + 'static,
R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
R::Result: Serialize,
{
self.on_with_qos::<R>(QoSClass::Default, f)
}
pub(crate) fn finish(&mut self) {
if let Some(req) = self.req.take() {
tracing::error!("unknown request: {:?}", req);
let response = lsp_server::Response::new_err(
req.id,
lsp_server::ErrorCode::MethodNotFound as i32,
"unknown request".to_string(),
);
self.global_state.respond(response);
}
}
fn on_with_qos<R>(
&mut self,
qos_class: QoSClass,
f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>,
) -> &mut Self
where
R: lsp_types::request::Request + 'static,
R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug,
@ -138,7 +177,7 @@ impl<'a> RequestDispatcher<'a> {
None => return self,
};
self.global_state.task_pool.handle.spawn({
self.global_state.task_pool.handle.spawn(qos_class, {
let world = self.global_state.snapshot();
move || {
let result = panic::catch_unwind(move || {
@ -155,18 +194,6 @@ impl<'a> RequestDispatcher<'a> {
self
}
pub(crate) fn finish(&mut self) {
if let Some(req) = self.req.take() {
tracing::error!("unknown request: {:?}", req);
let response = lsp_server::Response::new_err(
req.id,
lsp_server::ErrorCode::MethodNotFound as i32,
"unknown request".to_string(),
);
self.global_state.respond(response);
}
}
fn parse<R>(&mut self) -> Option<(lsp_server::Request, R::Params, String)>
where
R: lsp_types::request::Request,

View file

@ -291,7 +291,7 @@ fn run_flycheck(state: &mut GlobalState, vfs_path: VfsPath) -> bool {
}
Ok(())
};
state.task_pool.handle.spawn_with_sender(move |_| {
state.task_pool.handle.spawn_with_sender(stdx::thread::QoSClass::Default, move |_| {
if let Err(e) = std::panic::catch_unwind(task) {
tracing::error!("flycheck task panicked: {e:?}")
}

View file

@ -397,7 +397,7 @@ impl GlobalState {
tracing::debug!(%cause, "will prime caches");
let num_worker_threads = self.config.prime_caches_num_threads();
self.task_pool.handle.spawn_with_sender({
self.task_pool.handle.spawn_with_sender(stdx::thread::QoSClass::Default, {
let analysis = self.snapshot().analysis;
move |sender| {
sender.send(Task::PrimeCaches(PrimeCachesProgress::Begin)).unwrap();
@ -678,7 +678,24 @@ impl GlobalState {
.on_sync::<lsp_types::request::SelectionRangeRequest>(handlers::handle_selection_range)
.on_sync::<lsp_ext::MatchingBrace>(handlers::handle_matching_brace)
.on_sync::<lsp_ext::OnTypeFormatting>(handlers::handle_on_type_formatting)
// All other request handlers:
// We cant run latency-sensitive request handlers which do semantic
// analysis on the main thread because that would block other
// requests. Instead, we run these request handlers on higher QoS
// threads in the threadpool.
.on_latency_sensitive::<lsp_types::request::Completion>(handlers::handle_completion)
.on_latency_sensitive::<lsp_types::request::ResolveCompletionItem>(
handlers::handle_completion_resolve,
)
.on_latency_sensitive::<lsp_types::request::SemanticTokensFullRequest>(
handlers::handle_semantic_tokens_full,
)
.on_latency_sensitive::<lsp_types::request::SemanticTokensFullDeltaRequest>(
handlers::handle_semantic_tokens_full_delta,
)
.on_latency_sensitive::<lsp_types::request::SemanticTokensRangeRequest>(
handlers::handle_semantic_tokens_range,
)
// All other request handlers
.on::<lsp_ext::FetchDependencyList>(handlers::fetch_dependency_list)
.on::<lsp_ext::AnalyzerStatus>(handlers::handle_analyzer_status)
.on::<lsp_ext::SyntaxTree>(handlers::handle_syntax_tree)
@ -706,8 +723,6 @@ impl GlobalState {
.on::<lsp_types::request::GotoTypeDefinition>(handlers::handle_goto_type_definition)
.on_no_retry::<lsp_types::request::InlayHintRequest>(handlers::handle_inlay_hints)
.on::<lsp_types::request::InlayHintResolveRequest>(handlers::handle_inlay_hints_resolve)
.on::<lsp_types::request::Completion>(handlers::handle_completion)
.on::<lsp_types::request::ResolveCompletionItem>(handlers::handle_completion_resolve)
.on::<lsp_types::request::CodeLensRequest>(handlers::handle_code_lens)
.on::<lsp_types::request::CodeLensResolve>(handlers::handle_code_lens_resolve)
.on::<lsp_types::request::FoldingRangeRequest>(handlers::handle_folding_range)
@ -725,15 +740,6 @@ impl GlobalState {
.on::<lsp_types::request::CallHierarchyOutgoingCalls>(
handlers::handle_call_hierarchy_outgoing,
)
.on::<lsp_types::request::SemanticTokensFullRequest>(
handlers::handle_semantic_tokens_full,
)
.on::<lsp_types::request::SemanticTokensFullDeltaRequest>(
handlers::handle_semantic_tokens_full_delta,
)
.on::<lsp_types::request::SemanticTokensRangeRequest>(
handlers::handle_semantic_tokens_range,
)
.on::<lsp_types::request::WillRenameFiles>(handlers::handle_will_rename_files)
.on::<lsp_ext::Ssr>(handlers::handle_ssr)
.finish();
@ -781,7 +787,7 @@ impl GlobalState {
tracing::trace!("updating notifications for {:?}", subscriptions);
let snapshot = self.snapshot();
self.task_pool.handle.spawn(move || {
self.task_pool.handle.spawn(stdx::thread::QoSClass::Default, move || {
let _p = profile::span("publish_diagnostics");
let diagnostics = subscriptions
.into_iter()

View file

@ -185,7 +185,7 @@ impl GlobalState {
pub(crate) fn fetch_workspaces(&mut self, cause: Cause) {
tracing::info!(%cause, "will fetch workspaces");
self.task_pool.handle.spawn_with_sender({
self.task_pool.handle.spawn_with_sender(stdx::thread::QoSClass::Default, {
let linked_projects = self.config.linked_projects();
let detached_files = self.config.detached_files().to_vec();
let cargo_config = self.config.cargo();
@ -260,7 +260,7 @@ impl GlobalState {
tracing::info!(%cause, "will fetch build data");
let workspaces = Arc::clone(&self.workspaces);
let config = self.config.cargo();
self.task_pool.handle.spawn_with_sender(move |sender| {
self.task_pool.handle.spawn_with_sender(stdx::thread::QoSClass::Default, move |sender| {
sender.send(Task::FetchBuildData(BuildDataProgress::Begin)).unwrap();
let progress = {
@ -280,7 +280,7 @@ impl GlobalState {
let dummy_replacements = self.config.dummy_replacements().clone();
let proc_macro_clients = self.proc_macro_clients.clone();
self.task_pool.handle.spawn_with_sender(move |sender| {
self.task_pool.handle.spawn_with_sender(stdx::thread::QoSClass::Default, move |sender| {
sender.send(Task::LoadProcMacros(ProcMacroProgress::Begin)).unwrap();
let dummy_replacements = &dummy_replacements;

View file

@ -1,76 +1,42 @@
//! A thin wrapper around `ThreadPool` to make sure that we join all things
//! properly.
use std::sync::{Arc, Barrier};
//! A thin wrapper around [`stdx::thread::Pool`] which threads a sender through spawned jobs.
//! It is used in [`crate::global_state::GlobalState`] throughout the main loop.
use crossbeam_channel::Sender;
use stdx::thread::{Pool, QoSClass};
pub(crate) struct TaskPool<T> {
sender: Sender<T>,
inner: threadpool::ThreadPool,
pool: Pool,
}
impl<T> TaskPool<T> {
pub(crate) fn new_with_threads(sender: Sender<T>, threads: usize) -> TaskPool<T> {
const STACK_SIZE: usize = 8 * 1024 * 1024;
let inner = threadpool::Builder::new()
.thread_name("Worker".into())
.thread_stack_size(STACK_SIZE)
.num_threads(threads)
.build();
// Set QoS of all threads in threadpool.
let barrier = Arc::new(Barrier::new(threads + 1));
for _ in 0..threads {
let barrier = barrier.clone();
inner.execute(move || {
stdx::thread::set_current_thread_qos_class(stdx::thread::QoSClass::Utility);
barrier.wait();
});
}
barrier.wait();
TaskPool { sender, inner }
TaskPool { sender, pool: Pool::new(threads) }
}
pub(crate) fn spawn<F>(&mut self, task: F)
pub(crate) fn spawn<F>(&mut self, qos_class: QoSClass, task: F)
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
{
self.inner.execute({
self.pool.spawn(qos_class, {
let sender = self.sender.clone();
move || {
if stdx::thread::IS_QOS_AVAILABLE {
debug_assert_eq!(
stdx::thread::get_current_thread_qos_class(),
Some(stdx::thread::QoSClass::Utility)
);
}
sender.send(task()).unwrap()
}
move || sender.send(task()).unwrap()
})
}
pub(crate) fn spawn_with_sender<F>(&mut self, task: F)
pub(crate) fn spawn_with_sender<F>(&mut self, qos_class: QoSClass, task: F)
where
F: FnOnce(Sender<T>) + Send + 'static,
T: Send + 'static,
{
self.inner.execute({
self.pool.spawn(qos_class, {
let sender = self.sender.clone();
move || task(sender)
})
}
pub(crate) fn len(&self) -> usize {
self.inner.queued_count()
}
}
impl<T> Drop for TaskPool<T> {
fn drop(&mut self) {
self.inner.join()
self.pool.len()
}
}

View file

@ -16,6 +16,7 @@ libc = "0.2.135"
backtrace = { version = "0.3.65", optional = true }
always-assert = { version = "0.1.2", features = ["log"] }
jod-thread = "0.1.2"
crossbeam-channel = "0.5.5"
# Think twice before adding anything here
[target.'cfg(windows)'.dependencies]

View file

@ -13,6 +13,9 @@
use std::fmt;
mod pool;
pub use pool::Pool;
pub fn spawn<F, T>(qos_class: QoSClass, f: F) -> JoinHandle<T>
where
F: FnOnce() -> T,
@ -152,6 +155,8 @@ pub enum QoSClass {
/// performance, responsiveness and efficiency.
Utility,
Default,
/// TLDR: tasks that block using your app
///
/// Contract:
@ -229,6 +234,7 @@ mod imp {
let c = match class {
QoSClass::UserInteractive => libc::qos_class_t::QOS_CLASS_USER_INTERACTIVE,
QoSClass::UserInitiated => libc::qos_class_t::QOS_CLASS_USER_INITIATED,
QoSClass::Default => libc::qos_class_t::QOS_CLASS_DEFAULT,
QoSClass::Utility => libc::qos_class_t::QOS_CLASS_UTILITY,
QoSClass::Background => libc::qos_class_t::QOS_CLASS_BACKGROUND,
};

View file

@ -0,0 +1,95 @@
//! [`Pool`] implements a basic custom thread pool
//! inspired by the [`threadpool` crate](http://docs.rs/threadpool).
//! It allows the spawning of tasks under different QoS classes.
//! rust-analyzer uses this to prioritize work based on latency requirements.
//!
//! The thread pool is implemented entirely using
//! the threading utilities in [`crate::thread`].
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use crossbeam_channel::{Receiver, Sender};
use super::{
get_current_thread_qos_class, set_current_thread_qos_class, Builder, JoinHandle, QoSClass,
IS_QOS_AVAILABLE,
};
pub struct Pool {
// `_handles` is never read: the field is present
// only for its `Drop` impl.
// The worker threads exit once the channel closes;
// make sure to keep `job_sender` above `handles`
// so that the channel is actually closed
// before we join the worker threads!
job_sender: Sender<Job>,
_handles: Vec<JoinHandle>,
extant_tasks: Arc<AtomicUsize>,
}
struct Job {
requested_qos_class: QoSClass,
f: Box<dyn FnOnce() + Send + 'static>,
}
impl Pool {
pub fn new(threads: usize) -> Pool {
const STACK_SIZE: usize = 8 * 1024 * 1024;
const INITIAL_QOS_CLASS: QoSClass = QoSClass::Utility;
let (job_sender, job_receiver) = crossbeam_channel::unbounded();
let extant_tasks = Arc::new(AtomicUsize::new(0));
let mut handles = Vec::with_capacity(threads);
for _ in 0..threads {
let handle = Builder::new(INITIAL_QOS_CLASS)
.stack_size(STACK_SIZE)
.name("Worker".into())
.spawn({
let extant_tasks = Arc::clone(&extant_tasks);
let job_receiver: Receiver<Job> = job_receiver.clone();
move || {
let mut current_qos_class = INITIAL_QOS_CLASS;
for job in job_receiver {
if job.requested_qos_class != current_qos_class {
set_current_thread_qos_class(job.requested_qos_class);
current_qos_class = job.requested_qos_class;
}
extant_tasks.fetch_add(1, Ordering::SeqCst);
(job.f)();
extant_tasks.fetch_sub(1, Ordering::SeqCst);
}
}
})
.expect("failed to spawn thread");
handles.push(handle);
}
Pool { _handles: handles, extant_tasks, job_sender }
}
pub fn spawn<F>(&self, qos_class: QoSClass, f: F)
where
F: FnOnce() + Send + 'static,
{
let f = Box::new(move || {
if IS_QOS_AVAILABLE {
debug_assert_eq!(get_current_thread_qos_class(), Some(qos_class));
}
f()
});
let job = Job { requested_qos_class: qos_class, f };
self.job_sender.send(job).unwrap();
}
pub fn len(&self) -> usize {
self.extant_tasks.load(Ordering::SeqCst)
}
}

View file

@ -34,7 +34,7 @@ impl loader::Handle for NotifyHandle {
fn spawn(sender: loader::Sender) -> NotifyHandle {
let actor = NotifyActor::new(sender);
let (sender, receiver) = unbounded::<Message>();
let thread = stdx::thread::Builder::new(stdx::thread::QoSClass::Utility)
let thread = stdx::thread::Builder::new(stdx::thread::QoSClass::Default)
.name("VfsLoader".to_owned())
.spawn(move || actor.run(receiver))
.expect("failed to spawn thread");