diff --git a/crates/rust-analyzer/src/global_state.rs b/crates/rust-analyzer/src/global_state.rs index 446207e9e5..de6b956866 100644 --- a/crates/rust-analyzer/src/global_state.rs +++ b/crates/rust-analyzer/src/global_state.rs @@ -20,8 +20,9 @@ use crate::{ diagnostics::{CheckFixes, DiagnosticCollection}, from_proto, line_endings::LineEndings, - main_loop::ReqQueue, + main_loop::{ReqQueue, Task}, request_metrics::{LatestRequests, RequestMetrics}, + thread_pool::TaskPool, to_proto::url_from_abs_path, Result, }; @@ -66,6 +67,7 @@ impl Default for Status { /// incremental salsa database. pub(crate) struct GlobalState { pub(crate) config: Config, + pub(crate) task_pool: (TaskPool, Receiver), pub(crate) analysis_host: AnalysisHost, pub(crate) loader: Box, pub(crate) task_receiver: Receiver, @@ -153,8 +155,15 @@ impl GlobalState { let mut analysis_host = AnalysisHost::new(lru_capacity); analysis_host.apply_change(change); + + let task_pool = { + let (sender, receiver) = unbounded(); + (TaskPool::new(sender), receiver) + }; + let mut res = GlobalState { config, + task_pool, analysis_host, loader, task_receiver, diff --git a/crates/rust-analyzer/src/lib.rs b/crates/rust-analyzer/src/lib.rs index 7942866726..ca788dd3cf 100644 --- a/crates/rust-analyzer/src/lib.rs +++ b/crates/rust-analyzer/src/lib.rs @@ -30,6 +30,7 @@ mod diagnostics; mod line_endings; mod request_metrics; mod lsp_utils; +mod thread_pool; pub mod lsp_ext; pub mod config; diff --git a/crates/rust-analyzer/src/main_loop.rs b/crates/rust-analyzer/src/main_loop.rs index a7a7d2eb7e..1a9c5ee2ce 100644 --- a/crates/rust-analyzer/src/main_loop.rs +++ b/crates/rust-analyzer/src/main_loop.rs @@ -2,11 +2,10 @@ //! requests/replies and notifications back to the client. use std::{ env, fmt, panic, - sync::Arc, time::{Duration, Instant}, }; -use crossbeam_channel::{never, select, unbounded, RecvError, Sender}; +use crossbeam_channel::{never, select, RecvError, Sender}; use lsp_server::{Connection, ErrorCode, Notification, Request, RequestId, Response}; use lsp_types::{request::Request as _, NumberOrString}; use ra_db::VfsPath; @@ -14,7 +13,6 @@ use ra_ide::{Canceled, FileId}; use ra_prof::profile; use ra_project_model::{PackageRoot, ProjectWorkspace}; use serde::{de::DeserializeOwned, Serialize}; -use threadpool::ThreadPool; use crate::{ config::{Config, FilesWatcher, LinkedProject}, @@ -118,12 +116,8 @@ pub fn main_loop(config: Config, connection: Connection) -> Result<()> { GlobalState::new(workspaces, config.lru_capacity, config, req_queue) }; - let pool = ThreadPool::default(); - let (task_sender, task_receiver) = unbounded::(); - log::info!("server initialized, serving requests"); { - let task_sender = task_sender; loop { log::trace!("selecting"); let event = select! { @@ -131,7 +125,7 @@ pub fn main_loop(config: Config, connection: Connection) -> Result<()> { Ok(msg) => Event::Lsp(msg), Err(RecvError) => return Err("client exited without shutdown".into()), }, - recv(task_receiver) -> task => Event::Task(task.unwrap()), + recv(&global_state.task_pool.1) -> task => Event::Task(task.unwrap()), recv(global_state.task_receiver) -> task => match task { Ok(task) => Event::Vfs(task), Err(RecvError) => return Err("vfs died".into()), @@ -147,29 +141,19 @@ pub fn main_loop(config: Config, connection: Connection) -> Result<()> { }; } assert!(!global_state.vfs.read().0.has_changes()); - loop_turn(&pool, &task_sender, &connection, &mut global_state, event)?; + loop_turn(&connection, &mut global_state, event)?; assert!(!global_state.vfs.read().0.has_changes()); } } global_state.analysis_host.request_cancellation(); - log::info!("waiting for tasks to finish..."); - task_receiver.into_iter().for_each(|task| on_task(task, &connection.sender, &mut global_state)); - log::info!("...tasks have finished"); - log::info!("joining threadpool..."); - pool.join(); - drop(pool); - log::info!("...threadpool has finished"); - - let vfs = Arc::try_unwrap(global_state.vfs).expect("all snapshots should be dead"); - drop(vfs); - Ok(()) } #[derive(Debug)] -enum Task { +pub(crate) enum Task { Respond(Response), - Diagnostic(DiagnosticTask), + Diagnostics(Vec), + Unit, } enum Event { @@ -215,19 +199,13 @@ pub(crate) type ReqHandler = fn(&mut GlobalState, Response); pub(crate) type ReqQueue = lsp_server::ReqQueue<(&'static str, Instant), ReqHandler>; const DO_NOTHING: ReqHandler = |_, _| (); -fn loop_turn( - pool: &ThreadPool, - task_sender: &Sender, - connection: &Connection, - global_state: &mut GlobalState, - event: Event, -) -> Result<()> { +fn loop_turn(connection: &Connection, global_state: &mut GlobalState, event: Event) -> Result<()> { let loop_start = Instant::now(); // NOTE: don't count blocking select! call as a loop-turn time let _p = profile("main_loop_inner/loop-turn"); log::info!("loop turn = {:?}", event); - let queue_count = pool.queued_count(); + let queue_count = global_state.task_pool.0.len(); if queue_count > 0 { log::info!("queued count = {}", queue_count); } @@ -269,12 +247,10 @@ fn loop_turn( ) } }, - Event::Flycheck(task) => { - on_check_task(task, global_state, task_sender, &connection.sender)? - } + Event::Flycheck(task) => on_check_task(task, global_state, &connection.sender)?, Event::Lsp(msg) => match msg { lsp_server::Message::Request(req) => { - on_request(global_state, pool, task_sender, &connection.sender, loop_start, req)? + on_request(global_state, &connection.sender, loop_start, req)? } lsp_server::Message::Notification(not) => { on_notification(&connection.sender, global_state, not)?; @@ -301,16 +277,14 @@ fn loop_turn( .map(|path| global_state.vfs.read().0.file_id(&path).unwrap()) .collect::>(); - update_file_notifications_on_threadpool( - pool, - global_state.snapshot(), - task_sender.clone(), - subscriptions.clone(), - ); - pool.execute({ + update_file_notifications_on_threadpool(global_state, subscriptions.clone()); + global_state.task_pool.0.spawn({ let subs = subscriptions; let snap = global_state.snapshot(); - move || snap.analysis.prime_caches(subs).unwrap_or_else(|_: Canceled| ()) + move || { + snap.analysis.prime_caches(subs).unwrap_or_else(|_: Canceled| ()); + Task::Unit + } }); } @@ -345,26 +319,21 @@ fn on_task(task: Task, msg_sender: &Sender, global_state: & msg_sender.send(response.into()).unwrap(); } } - Task::Diagnostic(task) => on_diagnostic_task(task, msg_sender, global_state), + Task::Diagnostics(tasks) => { + tasks.into_iter().for_each(|task| on_diagnostic_task(task, msg_sender, global_state)) + } + Task::Unit => (), } } fn on_request( global_state: &mut GlobalState, - pool: &ThreadPool, - task_sender: &Sender, msg_sender: &Sender, request_received: Instant, req: Request, ) -> Result<()> { - let mut pool_dispatcher = PoolDispatcher { - req: Some(req), - pool, - global_state, - task_sender, - msg_sender, - request_received, - }; + let mut pool_dispatcher = + PoolDispatcher { req: Some(req), global_state, msg_sender, request_received }; pool_dispatcher .on_sync::(|s, ()| Ok(s.collect_garbage()))? .on_sync::(|s, p| handlers::handle_join_lines(s.snapshot(), p))? @@ -552,12 +521,11 @@ fn on_notification( fn on_check_task( task: flycheck::Message, global_state: &mut GlobalState, - task_sender: &Sender, msg_sender: &Sender, ) -> Result<()> { match task { flycheck::Message::ClearDiagnostics => { - task_sender.send(Task::Diagnostic(DiagnosticTask::ClearCheck))?; + on_diagnostic_task(DiagnosticTask::ClearCheck, msg_sender, global_state) } flycheck::Message::AddDiagnostic { workspace_root, diagnostic } => { @@ -576,11 +544,15 @@ fn on_check_task( } }; - task_sender.send(Task::Diagnostic(DiagnosticTask::AddCheck( - file_id, - diag.diagnostic, - diag.fixes.into_iter().map(|it| it.into()).collect(), - )))?; + on_diagnostic_task( + DiagnosticTask::AddCheck( + file_id, + diag.diagnostic, + diag.fixes.into_iter().map(|it| it.into()).collect(), + ), + msg_sender, + global_state, + ) } } @@ -674,10 +646,8 @@ fn report_progress( struct PoolDispatcher<'a> { req: Option, - pool: &'a ThreadPool, global_state: &'a mut GlobalState, msg_sender: &'a Sender, - task_sender: &'a Sender, request_received: Instant, } @@ -725,13 +695,11 @@ impl<'a> PoolDispatcher<'a> { } }; - self.pool.execute({ + self.global_state.task_pool.0.spawn({ let world = self.global_state.snapshot(); - let sender = self.task_sender.clone(); move || { let result = f(world, params); - let task = result_to_task::(id, result); - sender.send(task).unwrap(); + result_to_task::(id, result) } }); @@ -801,26 +769,27 @@ where } fn update_file_notifications_on_threadpool( - pool: &ThreadPool, - world: GlobalStateSnapshot, - task_sender: Sender, + global_state: &mut GlobalState, subscriptions: Vec, ) { log::trace!("updating notifications for {:?}", subscriptions); - if world.config.publish_diagnostics { - pool.execute(move || { - for file_id in subscriptions { - match handlers::publish_diagnostics(&world, file_id) { - Err(e) => { - if !is_canceled(&*e) { - log::error!("failed to compute diagnostics: {:?}", e); - } - } - Ok(task) => { - task_sender.send(Task::Diagnostic(task)).unwrap(); - } - } - } + if global_state.config.publish_diagnostics { + let snapshot = global_state.snapshot(); + global_state.task_pool.0.spawn(move || { + let diagnostics = subscriptions + .into_iter() + .filter_map(|file_id| { + handlers::publish_diagnostics(&snapshot, file_id) + .map_err(|err| { + if !is_canceled(&*err) { + log::error!("failed to compute diagnostics: {:?}", err); + } + () + }) + .ok() + }) + .collect::>(); + Task::Diagnostics(diagnostics) }) } } diff --git a/crates/rust-analyzer/src/thread_pool.rs b/crates/rust-analyzer/src/thread_pool.rs new file mode 100644 index 0000000000..4fa5029253 --- /dev/null +++ b/crates/rust-analyzer/src/thread_pool.rs @@ -0,0 +1,35 @@ +//! A thin wrapper around `ThreadPool` to make sure that we join all things +//! properly. +use crossbeam_channel::Sender; + +pub(crate) struct TaskPool { + sender: Sender, + inner: threadpool::ThreadPool, +} + +impl TaskPool { + pub(crate) fn new(sender: Sender) -> TaskPool { + TaskPool { sender, inner: threadpool::ThreadPool::default() } + } + + pub(crate) fn spawn(&mut self, task: F) + where + F: FnOnce() -> T + Send + 'static, + T: Send + 'static, + { + self.inner.execute({ + let sender = self.sender.clone(); + move || sender.send(task()).unwrap() + }) + } + + pub(crate) fn len(&self) -> usize { + self.inner.queued_count() + } +} + +impl Drop for TaskPool { + fn drop(&mut self) { + self.inner.join() + } +}