Move TaskPool into GlobalState

This commit is contained in:
Aleksey Kladov 2020-06-25 15:35:42 +02:00
parent 9be0094b5c
commit dd20c2ec5b
4 changed files with 98 additions and 84 deletions

View file

@ -20,8 +20,9 @@ use crate::{
diagnostics::{CheckFixes, DiagnosticCollection},
from_proto,
line_endings::LineEndings,
main_loop::ReqQueue,
main_loop::{ReqQueue, Task},
request_metrics::{LatestRequests, RequestMetrics},
thread_pool::TaskPool,
to_proto::url_from_abs_path,
Result,
};
@ -66,6 +67,7 @@ impl Default for Status {
/// incremental salsa database.
pub(crate) struct GlobalState {
pub(crate) config: Config,
pub(crate) task_pool: (TaskPool<Task>, Receiver<Task>),
pub(crate) analysis_host: AnalysisHost,
pub(crate) loader: Box<dyn vfs::loader::Handle>,
pub(crate) task_receiver: Receiver<vfs::loader::Message>,
@ -153,8 +155,15 @@ impl GlobalState {
let mut analysis_host = AnalysisHost::new(lru_capacity);
analysis_host.apply_change(change);
let task_pool = {
let (sender, receiver) = unbounded();
(TaskPool::new(sender), receiver)
};
let mut res = GlobalState {
config,
task_pool,
analysis_host,
loader,
task_receiver,

View file

@ -30,6 +30,7 @@ mod diagnostics;
mod line_endings;
mod request_metrics;
mod lsp_utils;
mod thread_pool;
pub mod lsp_ext;
pub mod config;

View file

@ -2,11 +2,10 @@
//! requests/replies and notifications back to the client.
use std::{
env, fmt, panic,
sync::Arc,
time::{Duration, Instant},
};
use crossbeam_channel::{never, select, unbounded, RecvError, Sender};
use crossbeam_channel::{never, select, RecvError, Sender};
use lsp_server::{Connection, ErrorCode, Notification, Request, RequestId, Response};
use lsp_types::{request::Request as _, NumberOrString};
use ra_db::VfsPath;
@ -14,7 +13,6 @@ use ra_ide::{Canceled, FileId};
use ra_prof::profile;
use ra_project_model::{PackageRoot, ProjectWorkspace};
use serde::{de::DeserializeOwned, Serialize};
use threadpool::ThreadPool;
use crate::{
config::{Config, FilesWatcher, LinkedProject},
@ -118,12 +116,8 @@ pub fn main_loop(config: Config, connection: Connection) -> Result<()> {
GlobalState::new(workspaces, config.lru_capacity, config, req_queue)
};
let pool = ThreadPool::default();
let (task_sender, task_receiver) = unbounded::<Task>();
log::info!("server initialized, serving requests");
{
let task_sender = task_sender;
loop {
log::trace!("selecting");
let event = select! {
@ -131,7 +125,7 @@ pub fn main_loop(config: Config, connection: Connection) -> Result<()> {
Ok(msg) => Event::Lsp(msg),
Err(RecvError) => return Err("client exited without shutdown".into()),
},
recv(task_receiver) -> task => Event::Task(task.unwrap()),
recv(&global_state.task_pool.1) -> task => Event::Task(task.unwrap()),
recv(global_state.task_receiver) -> task => match task {
Ok(task) => Event::Vfs(task),
Err(RecvError) => return Err("vfs died".into()),
@ -147,29 +141,19 @@ pub fn main_loop(config: Config, connection: Connection) -> Result<()> {
};
}
assert!(!global_state.vfs.read().0.has_changes());
loop_turn(&pool, &task_sender, &connection, &mut global_state, event)?;
loop_turn(&connection, &mut global_state, event)?;
assert!(!global_state.vfs.read().0.has_changes());
}
}
global_state.analysis_host.request_cancellation();
log::info!("waiting for tasks to finish...");
task_receiver.into_iter().for_each(|task| on_task(task, &connection.sender, &mut global_state));
log::info!("...tasks have finished");
log::info!("joining threadpool...");
pool.join();
drop(pool);
log::info!("...threadpool has finished");
let vfs = Arc::try_unwrap(global_state.vfs).expect("all snapshots should be dead");
drop(vfs);
Ok(())
}
#[derive(Debug)]
enum Task {
pub(crate) enum Task {
Respond(Response),
Diagnostic(DiagnosticTask),
Diagnostics(Vec<DiagnosticTask>),
Unit,
}
enum Event {
@ -215,19 +199,13 @@ pub(crate) type ReqHandler = fn(&mut GlobalState, Response);
pub(crate) type ReqQueue = lsp_server::ReqQueue<(&'static str, Instant), ReqHandler>;
const DO_NOTHING: ReqHandler = |_, _| ();
fn loop_turn(
pool: &ThreadPool,
task_sender: &Sender<Task>,
connection: &Connection,
global_state: &mut GlobalState,
event: Event,
) -> Result<()> {
fn loop_turn(connection: &Connection, global_state: &mut GlobalState, event: Event) -> Result<()> {
let loop_start = Instant::now();
// NOTE: don't count blocking select! call as a loop-turn time
let _p = profile("main_loop_inner/loop-turn");
log::info!("loop turn = {:?}", event);
let queue_count = pool.queued_count();
let queue_count = global_state.task_pool.0.len();
if queue_count > 0 {
log::info!("queued count = {}", queue_count);
}
@ -269,12 +247,10 @@ fn loop_turn(
)
}
},
Event::Flycheck(task) => {
on_check_task(task, global_state, task_sender, &connection.sender)?
}
Event::Flycheck(task) => on_check_task(task, global_state, &connection.sender)?,
Event::Lsp(msg) => match msg {
lsp_server::Message::Request(req) => {
on_request(global_state, pool, task_sender, &connection.sender, loop_start, req)?
on_request(global_state, &connection.sender, loop_start, req)?
}
lsp_server::Message::Notification(not) => {
on_notification(&connection.sender, global_state, not)?;
@ -301,16 +277,14 @@ fn loop_turn(
.map(|path| global_state.vfs.read().0.file_id(&path).unwrap())
.collect::<Vec<_>>();
update_file_notifications_on_threadpool(
pool,
global_state.snapshot(),
task_sender.clone(),
subscriptions.clone(),
);
pool.execute({
update_file_notifications_on_threadpool(global_state, subscriptions.clone());
global_state.task_pool.0.spawn({
let subs = subscriptions;
let snap = global_state.snapshot();
move || snap.analysis.prime_caches(subs).unwrap_or_else(|_: Canceled| ())
move || {
snap.analysis.prime_caches(subs).unwrap_or_else(|_: Canceled| ());
Task::Unit
}
});
}
@ -345,26 +319,21 @@ fn on_task(task: Task, msg_sender: &Sender<lsp_server::Message>, global_state: &
msg_sender.send(response.into()).unwrap();
}
}
Task::Diagnostic(task) => on_diagnostic_task(task, msg_sender, global_state),
Task::Diagnostics(tasks) => {
tasks.into_iter().for_each(|task| on_diagnostic_task(task, msg_sender, global_state))
}
Task::Unit => (),
}
}
fn on_request(
global_state: &mut GlobalState,
pool: &ThreadPool,
task_sender: &Sender<Task>,
msg_sender: &Sender<lsp_server::Message>,
request_received: Instant,
req: Request,
) -> Result<()> {
let mut pool_dispatcher = PoolDispatcher {
req: Some(req),
pool,
global_state,
task_sender,
msg_sender,
request_received,
};
let mut pool_dispatcher =
PoolDispatcher { req: Some(req), global_state, msg_sender, request_received };
pool_dispatcher
.on_sync::<lsp_ext::CollectGarbage>(|s, ()| Ok(s.collect_garbage()))?
.on_sync::<lsp_ext::JoinLines>(|s, p| handlers::handle_join_lines(s.snapshot(), p))?
@ -552,12 +521,11 @@ fn on_notification(
fn on_check_task(
task: flycheck::Message,
global_state: &mut GlobalState,
task_sender: &Sender<Task>,
msg_sender: &Sender<lsp_server::Message>,
) -> Result<()> {
match task {
flycheck::Message::ClearDiagnostics => {
task_sender.send(Task::Diagnostic(DiagnosticTask::ClearCheck))?;
on_diagnostic_task(DiagnosticTask::ClearCheck, msg_sender, global_state)
}
flycheck::Message::AddDiagnostic { workspace_root, diagnostic } => {
@ -576,11 +544,15 @@ fn on_check_task(
}
};
task_sender.send(Task::Diagnostic(DiagnosticTask::AddCheck(
file_id,
diag.diagnostic,
diag.fixes.into_iter().map(|it| it.into()).collect(),
)))?;
on_diagnostic_task(
DiagnosticTask::AddCheck(
file_id,
diag.diagnostic,
diag.fixes.into_iter().map(|it| it.into()).collect(),
),
msg_sender,
global_state,
)
}
}
@ -674,10 +646,8 @@ fn report_progress(
struct PoolDispatcher<'a> {
req: Option<Request>,
pool: &'a ThreadPool,
global_state: &'a mut GlobalState,
msg_sender: &'a Sender<lsp_server::Message>,
task_sender: &'a Sender<Task>,
request_received: Instant,
}
@ -725,13 +695,11 @@ impl<'a> PoolDispatcher<'a> {
}
};
self.pool.execute({
self.global_state.task_pool.0.spawn({
let world = self.global_state.snapshot();
let sender = self.task_sender.clone();
move || {
let result = f(world, params);
let task = result_to_task::<R>(id, result);
sender.send(task).unwrap();
result_to_task::<R>(id, result)
}
});
@ -801,26 +769,27 @@ where
}
fn update_file_notifications_on_threadpool(
pool: &ThreadPool,
world: GlobalStateSnapshot,
task_sender: Sender<Task>,
global_state: &mut GlobalState,
subscriptions: Vec<FileId>,
) {
log::trace!("updating notifications for {:?}", subscriptions);
if world.config.publish_diagnostics {
pool.execute(move || {
for file_id in subscriptions {
match handlers::publish_diagnostics(&world, file_id) {
Err(e) => {
if !is_canceled(&*e) {
log::error!("failed to compute diagnostics: {:?}", e);
}
}
Ok(task) => {
task_sender.send(Task::Diagnostic(task)).unwrap();
}
}
}
if global_state.config.publish_diagnostics {
let snapshot = global_state.snapshot();
global_state.task_pool.0.spawn(move || {
let diagnostics = subscriptions
.into_iter()
.filter_map(|file_id| {
handlers::publish_diagnostics(&snapshot, file_id)
.map_err(|err| {
if !is_canceled(&*err) {
log::error!("failed to compute diagnostics: {:?}", err);
}
()
})
.ok()
})
.collect::<Vec<_>>();
Task::Diagnostics(diagnostics)
})
}
}

View file

@ -0,0 +1,35 @@
//! A thin wrapper around `ThreadPool` to make sure that we join all things
//! properly.
use crossbeam_channel::Sender;
pub(crate) struct TaskPool<T> {
sender: Sender<T>,
inner: threadpool::ThreadPool,
}
impl<T> TaskPool<T> {
pub(crate) fn new(sender: Sender<T>) -> TaskPool<T> {
TaskPool { sender, inner: threadpool::ThreadPool::default() }
}
pub(crate) fn spawn<F>(&mut self, task: F)
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
{
self.inner.execute({
let sender = self.sender.clone();
move || sender.send(task()).unwrap()
})
}
pub(crate) fn len(&self) -> usize {
self.inner.queued_count()
}
}
impl<T> Drop for TaskPool<T> {
fn drop(&mut self) {
self.inner.join()
}
}