Move TaskPool into GlobalState

This commit is contained in:
Aleksey Kladov 2020-06-25 15:35:42 +02:00
parent 9be0094b5c
commit dd20c2ec5b
4 changed files with 98 additions and 84 deletions

View file

@ -20,8 +20,9 @@ use crate::{
diagnostics::{CheckFixes, DiagnosticCollection}, diagnostics::{CheckFixes, DiagnosticCollection},
from_proto, from_proto,
line_endings::LineEndings, line_endings::LineEndings,
main_loop::ReqQueue, main_loop::{ReqQueue, Task},
request_metrics::{LatestRequests, RequestMetrics}, request_metrics::{LatestRequests, RequestMetrics},
thread_pool::TaskPool,
to_proto::url_from_abs_path, to_proto::url_from_abs_path,
Result, Result,
}; };
@ -66,6 +67,7 @@ impl Default for Status {
/// incremental salsa database. /// incremental salsa database.
pub(crate) struct GlobalState { pub(crate) struct GlobalState {
pub(crate) config: Config, pub(crate) config: Config,
pub(crate) task_pool: (TaskPool<Task>, Receiver<Task>),
pub(crate) analysis_host: AnalysisHost, pub(crate) analysis_host: AnalysisHost,
pub(crate) loader: Box<dyn vfs::loader::Handle>, pub(crate) loader: Box<dyn vfs::loader::Handle>,
pub(crate) task_receiver: Receiver<vfs::loader::Message>, pub(crate) task_receiver: Receiver<vfs::loader::Message>,
@ -153,8 +155,15 @@ impl GlobalState {
let mut analysis_host = AnalysisHost::new(lru_capacity); let mut analysis_host = AnalysisHost::new(lru_capacity);
analysis_host.apply_change(change); analysis_host.apply_change(change);
let task_pool = {
let (sender, receiver) = unbounded();
(TaskPool::new(sender), receiver)
};
let mut res = GlobalState { let mut res = GlobalState {
config, config,
task_pool,
analysis_host, analysis_host,
loader, loader,
task_receiver, task_receiver,

View file

@ -30,6 +30,7 @@ mod diagnostics;
mod line_endings; mod line_endings;
mod request_metrics; mod request_metrics;
mod lsp_utils; mod lsp_utils;
mod thread_pool;
pub mod lsp_ext; pub mod lsp_ext;
pub mod config; pub mod config;

View file

@ -2,11 +2,10 @@
//! requests/replies and notifications back to the client. //! requests/replies and notifications back to the client.
use std::{ use std::{
env, fmt, panic, env, fmt, panic,
sync::Arc,
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use crossbeam_channel::{never, select, unbounded, RecvError, Sender}; use crossbeam_channel::{never, select, RecvError, Sender};
use lsp_server::{Connection, ErrorCode, Notification, Request, RequestId, Response}; use lsp_server::{Connection, ErrorCode, Notification, Request, RequestId, Response};
use lsp_types::{request::Request as _, NumberOrString}; use lsp_types::{request::Request as _, NumberOrString};
use ra_db::VfsPath; use ra_db::VfsPath;
@ -14,7 +13,6 @@ use ra_ide::{Canceled, FileId};
use ra_prof::profile; use ra_prof::profile;
use ra_project_model::{PackageRoot, ProjectWorkspace}; use ra_project_model::{PackageRoot, ProjectWorkspace};
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use threadpool::ThreadPool;
use crate::{ use crate::{
config::{Config, FilesWatcher, LinkedProject}, config::{Config, FilesWatcher, LinkedProject},
@ -118,12 +116,8 @@ pub fn main_loop(config: Config, connection: Connection) -> Result<()> {
GlobalState::new(workspaces, config.lru_capacity, config, req_queue) GlobalState::new(workspaces, config.lru_capacity, config, req_queue)
}; };
let pool = ThreadPool::default();
let (task_sender, task_receiver) = unbounded::<Task>();
log::info!("server initialized, serving requests"); log::info!("server initialized, serving requests");
{ {
let task_sender = task_sender;
loop { loop {
log::trace!("selecting"); log::trace!("selecting");
let event = select! { let event = select! {
@ -131,7 +125,7 @@ pub fn main_loop(config: Config, connection: Connection) -> Result<()> {
Ok(msg) => Event::Lsp(msg), Ok(msg) => Event::Lsp(msg),
Err(RecvError) => return Err("client exited without shutdown".into()), Err(RecvError) => return Err("client exited without shutdown".into()),
}, },
recv(task_receiver) -> task => Event::Task(task.unwrap()), recv(&global_state.task_pool.1) -> task => Event::Task(task.unwrap()),
recv(global_state.task_receiver) -> task => match task { recv(global_state.task_receiver) -> task => match task {
Ok(task) => Event::Vfs(task), Ok(task) => Event::Vfs(task),
Err(RecvError) => return Err("vfs died".into()), Err(RecvError) => return Err("vfs died".into()),
@ -147,29 +141,19 @@ pub fn main_loop(config: Config, connection: Connection) -> Result<()> {
}; };
} }
assert!(!global_state.vfs.read().0.has_changes()); assert!(!global_state.vfs.read().0.has_changes());
loop_turn(&pool, &task_sender, &connection, &mut global_state, event)?; loop_turn(&connection, &mut global_state, event)?;
assert!(!global_state.vfs.read().0.has_changes()); assert!(!global_state.vfs.read().0.has_changes());
} }
} }
global_state.analysis_host.request_cancellation(); global_state.analysis_host.request_cancellation();
log::info!("waiting for tasks to finish...");
task_receiver.into_iter().for_each(|task| on_task(task, &connection.sender, &mut global_state));
log::info!("...tasks have finished");
log::info!("joining threadpool...");
pool.join();
drop(pool);
log::info!("...threadpool has finished");
let vfs = Arc::try_unwrap(global_state.vfs).expect("all snapshots should be dead");
drop(vfs);
Ok(()) Ok(())
} }
#[derive(Debug)] #[derive(Debug)]
enum Task { pub(crate) enum Task {
Respond(Response), Respond(Response),
Diagnostic(DiagnosticTask), Diagnostics(Vec<DiagnosticTask>),
Unit,
} }
enum Event { enum Event {
@ -215,19 +199,13 @@ pub(crate) type ReqHandler = fn(&mut GlobalState, Response);
pub(crate) type ReqQueue = lsp_server::ReqQueue<(&'static str, Instant), ReqHandler>; pub(crate) type ReqQueue = lsp_server::ReqQueue<(&'static str, Instant), ReqHandler>;
const DO_NOTHING: ReqHandler = |_, _| (); const DO_NOTHING: ReqHandler = |_, _| ();
fn loop_turn( fn loop_turn(connection: &Connection, global_state: &mut GlobalState, event: Event) -> Result<()> {
pool: &ThreadPool,
task_sender: &Sender<Task>,
connection: &Connection,
global_state: &mut GlobalState,
event: Event,
) -> Result<()> {
let loop_start = Instant::now(); let loop_start = Instant::now();
// NOTE: don't count blocking select! call as a loop-turn time // NOTE: don't count blocking select! call as a loop-turn time
let _p = profile("main_loop_inner/loop-turn"); let _p = profile("main_loop_inner/loop-turn");
log::info!("loop turn = {:?}", event); log::info!("loop turn = {:?}", event);
let queue_count = pool.queued_count(); let queue_count = global_state.task_pool.0.len();
if queue_count > 0 { if queue_count > 0 {
log::info!("queued count = {}", queue_count); log::info!("queued count = {}", queue_count);
} }
@ -269,12 +247,10 @@ fn loop_turn(
) )
} }
}, },
Event::Flycheck(task) => { Event::Flycheck(task) => on_check_task(task, global_state, &connection.sender)?,
on_check_task(task, global_state, task_sender, &connection.sender)?
}
Event::Lsp(msg) => match msg { Event::Lsp(msg) => match msg {
lsp_server::Message::Request(req) => { lsp_server::Message::Request(req) => {
on_request(global_state, pool, task_sender, &connection.sender, loop_start, req)? on_request(global_state, &connection.sender, loop_start, req)?
} }
lsp_server::Message::Notification(not) => { lsp_server::Message::Notification(not) => {
on_notification(&connection.sender, global_state, not)?; on_notification(&connection.sender, global_state, not)?;
@ -301,16 +277,14 @@ fn loop_turn(
.map(|path| global_state.vfs.read().0.file_id(&path).unwrap()) .map(|path| global_state.vfs.read().0.file_id(&path).unwrap())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
update_file_notifications_on_threadpool( update_file_notifications_on_threadpool(global_state, subscriptions.clone());
pool, global_state.task_pool.0.spawn({
global_state.snapshot(),
task_sender.clone(),
subscriptions.clone(),
);
pool.execute({
let subs = subscriptions; let subs = subscriptions;
let snap = global_state.snapshot(); let snap = global_state.snapshot();
move || snap.analysis.prime_caches(subs).unwrap_or_else(|_: Canceled| ()) move || {
snap.analysis.prime_caches(subs).unwrap_or_else(|_: Canceled| ());
Task::Unit
}
}); });
} }
@ -345,26 +319,21 @@ fn on_task(task: Task, msg_sender: &Sender<lsp_server::Message>, global_state: &
msg_sender.send(response.into()).unwrap(); msg_sender.send(response.into()).unwrap();
} }
} }
Task::Diagnostic(task) => on_diagnostic_task(task, msg_sender, global_state), Task::Diagnostics(tasks) => {
tasks.into_iter().for_each(|task| on_diagnostic_task(task, msg_sender, global_state))
}
Task::Unit => (),
} }
} }
fn on_request( fn on_request(
global_state: &mut GlobalState, global_state: &mut GlobalState,
pool: &ThreadPool,
task_sender: &Sender<Task>,
msg_sender: &Sender<lsp_server::Message>, msg_sender: &Sender<lsp_server::Message>,
request_received: Instant, request_received: Instant,
req: Request, req: Request,
) -> Result<()> { ) -> Result<()> {
let mut pool_dispatcher = PoolDispatcher { let mut pool_dispatcher =
req: Some(req), PoolDispatcher { req: Some(req), global_state, msg_sender, request_received };
pool,
global_state,
task_sender,
msg_sender,
request_received,
};
pool_dispatcher pool_dispatcher
.on_sync::<lsp_ext::CollectGarbage>(|s, ()| Ok(s.collect_garbage()))? .on_sync::<lsp_ext::CollectGarbage>(|s, ()| Ok(s.collect_garbage()))?
.on_sync::<lsp_ext::JoinLines>(|s, p| handlers::handle_join_lines(s.snapshot(), p))? .on_sync::<lsp_ext::JoinLines>(|s, p| handlers::handle_join_lines(s.snapshot(), p))?
@ -552,12 +521,11 @@ fn on_notification(
fn on_check_task( fn on_check_task(
task: flycheck::Message, task: flycheck::Message,
global_state: &mut GlobalState, global_state: &mut GlobalState,
task_sender: &Sender<Task>,
msg_sender: &Sender<lsp_server::Message>, msg_sender: &Sender<lsp_server::Message>,
) -> Result<()> { ) -> Result<()> {
match task { match task {
flycheck::Message::ClearDiagnostics => { flycheck::Message::ClearDiagnostics => {
task_sender.send(Task::Diagnostic(DiagnosticTask::ClearCheck))?; on_diagnostic_task(DiagnosticTask::ClearCheck, msg_sender, global_state)
} }
flycheck::Message::AddDiagnostic { workspace_root, diagnostic } => { flycheck::Message::AddDiagnostic { workspace_root, diagnostic } => {
@ -576,11 +544,15 @@ fn on_check_task(
} }
}; };
task_sender.send(Task::Diagnostic(DiagnosticTask::AddCheck( on_diagnostic_task(
DiagnosticTask::AddCheck(
file_id, file_id,
diag.diagnostic, diag.diagnostic,
diag.fixes.into_iter().map(|it| it.into()).collect(), diag.fixes.into_iter().map(|it| it.into()).collect(),
)))?; ),
msg_sender,
global_state,
)
} }
} }
@ -674,10 +646,8 @@ fn report_progress(
struct PoolDispatcher<'a> { struct PoolDispatcher<'a> {
req: Option<Request>, req: Option<Request>,
pool: &'a ThreadPool,
global_state: &'a mut GlobalState, global_state: &'a mut GlobalState,
msg_sender: &'a Sender<lsp_server::Message>, msg_sender: &'a Sender<lsp_server::Message>,
task_sender: &'a Sender<Task>,
request_received: Instant, request_received: Instant,
} }
@ -725,13 +695,11 @@ impl<'a> PoolDispatcher<'a> {
} }
}; };
self.pool.execute({ self.global_state.task_pool.0.spawn({
let world = self.global_state.snapshot(); let world = self.global_state.snapshot();
let sender = self.task_sender.clone();
move || { move || {
let result = f(world, params); let result = f(world, params);
let task = result_to_task::<R>(id, result); result_to_task::<R>(id, result)
sender.send(task).unwrap();
} }
}); });
@ -801,26 +769,27 @@ where
} }
fn update_file_notifications_on_threadpool( fn update_file_notifications_on_threadpool(
pool: &ThreadPool, global_state: &mut GlobalState,
world: GlobalStateSnapshot,
task_sender: Sender<Task>,
subscriptions: Vec<FileId>, subscriptions: Vec<FileId>,
) { ) {
log::trace!("updating notifications for {:?}", subscriptions); log::trace!("updating notifications for {:?}", subscriptions);
if world.config.publish_diagnostics { if global_state.config.publish_diagnostics {
pool.execute(move || { let snapshot = global_state.snapshot();
for file_id in subscriptions { global_state.task_pool.0.spawn(move || {
match handlers::publish_diagnostics(&world, file_id) { let diagnostics = subscriptions
Err(e) => { .into_iter()
if !is_canceled(&*e) { .filter_map(|file_id| {
log::error!("failed to compute diagnostics: {:?}", e); handlers::publish_diagnostics(&snapshot, file_id)
} .map_err(|err| {
} if !is_canceled(&*err) {
Ok(task) => { log::error!("failed to compute diagnostics: {:?}", err);
task_sender.send(Task::Diagnostic(task)).unwrap();
}
}
} }
()
})
.ok()
})
.collect::<Vec<_>>();
Task::Diagnostics(diagnostics)
}) })
} }
} }

View file

@ -0,0 +1,35 @@
//! A thin wrapper around `ThreadPool` to make sure that we join all things
//! properly.
use crossbeam_channel::Sender;
pub(crate) struct TaskPool<T> {
sender: Sender<T>,
inner: threadpool::ThreadPool,
}
impl<T> TaskPool<T> {
pub(crate) fn new(sender: Sender<T>) -> TaskPool<T> {
TaskPool { sender, inner: threadpool::ThreadPool::default() }
}
pub(crate) fn spawn<F>(&mut self, task: F)
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
{
self.inner.execute({
let sender = self.sender.clone();
move || sender.send(task()).unwrap()
})
}
pub(crate) fn len(&self) -> usize {
self.inner.queued_count()
}
}
impl<T> Drop for TaskPool<T> {
fn drop(&mut self) {
self.inner.join()
}
}