diff --git a/crates/rust-analyzer/src/global_state.rs b/crates/rust-analyzer/src/global_state.rs index 1527c9947d..d1897bf505 100644 --- a/crates/rust-analyzer/src/global_state.rs +++ b/crates/rust-analyzer/src/global_state.rs @@ -20,7 +20,7 @@ use stdx::format_to; use crate::{ config::{Config, FilesWatcher}, diagnostics::{CheckFixes, DiagnosticCollection}, - main_loop::pending_requests::{CompletedRequest, LatestRequests}, + main_loop::req_queue::{CompletedInRequest, LatestRequests}, to_proto::url_from_abs_path, vfs_glob::{Glob, RustPackageFilterBuilder}, LspError, Result, @@ -236,7 +236,7 @@ impl GlobalState { self.analysis_host.collect_garbage() } - pub fn complete_request(&mut self, request: CompletedRequest) { + pub(crate) fn complete_request(&mut self, request: CompletedInRequest) { self.latest_requests.write().record(request) } } diff --git a/crates/rust-analyzer/src/main_loop.rs b/crates/rust-analyzer/src/main_loop.rs index f0aaaa21ea..fd40b2443f 100644 --- a/crates/rust-analyzer/src/main_loop.rs +++ b/crates/rust-analyzer/src/main_loop.rs @@ -3,7 +3,7 @@ mod handlers; mod subscriptions; -pub(crate) mod pending_requests; +pub(crate) mod req_queue; use std::{ borrow::Cow, @@ -28,7 +28,6 @@ use ra_ide::{Canceled, FileId, LineIndex}; use ra_prof::profile; use ra_project_model::{PackageRoot, ProjectWorkspace}; use ra_vfs::VfsTask; -use rustc_hash::FxHashSet; use serde::{de::DeserializeOwned, Serialize}; use threadpool::ThreadPool; @@ -38,12 +37,10 @@ use crate::{ from_proto, global_state::{file_id_to_url, GlobalState, GlobalStateSnapshot}, lsp_ext, - main_loop::{ - pending_requests::{PendingRequest, PendingRequests}, - subscriptions::Subscriptions, - }, + main_loop::subscriptions::Subscriptions, Result, }; +use req_queue::ReqQueue; #[derive(Debug)] pub struct LspError { @@ -153,10 +150,10 @@ pub fn main_loop(config: Config, connection: Connection) -> Result<()> { register_options: Some(serde_json::to_value(registration_options).unwrap()), }; let params = lsp_types::RegistrationParams { registrations: vec![registration] }; - let request = request_new::( - loop_state.next_request_id(), - params, - ); + let request = loop_state + .req_queue + .outgoing + .register::(params, |_, _| ()); connection.sender.send(request.into()).unwrap(); } @@ -199,7 +196,7 @@ pub fn main_loop(config: Config, connection: Connection) -> Result<()> { global_state.analysis_host.request_cancellation(); log::info!("waiting for tasks to finish..."); task_receiver.into_iter().for_each(|task| { - on_task(task, &connection.sender, &mut loop_state.pending_requests, &mut global_state) + on_task(task, &connection.sender, &mut loop_state.req_queue.incoming, &mut global_state) }); log::info!("...tasks have finished"); log::info!("joining threadpool..."); @@ -264,27 +261,14 @@ impl fmt::Debug for Event { } } -#[derive(Debug, Default)] +#[derive(Default)] struct LoopState { - next_request_id: u64, - pending_responses: FxHashSet, - pending_requests: PendingRequests, + req_queue: ReqQueue, subscriptions: Subscriptions, workspace_loaded: bool, roots_progress_reported: Option, roots_scanned: usize, roots_total: usize, - configuration_request_id: Option, -} - -impl LoopState { - fn next_request_id(&mut self) -> RequestId { - self.next_request_id += 1; - let res: RequestId = self.next_request_id.into(); - let inserted = self.pending_responses.insert(res.clone()); - assert!(inserted); - res - } } fn loop_turn( @@ -307,7 +291,7 @@ fn loop_turn( match event { Event::Task(task) => { - on_task(task, &connection.sender, &mut loop_state.pending_requests, global_state); + on_task(task, &connection.sender, &mut loop_state.req_queue.incoming, global_state); global_state.maybe_collect_garbage(); } Event::Vfs(task) => { @@ -317,7 +301,7 @@ fn loop_turn( Event::Msg(msg) => match msg { Message::Request(req) => on_request( global_state, - &mut loop_state.pending_requests, + &mut loop_state.req_queue.incoming, pool, task_sender, &connection.sender, @@ -328,32 +312,8 @@ fn loop_turn( on_notification(&connection.sender, global_state, loop_state, not)?; } Message::Response(resp) => { - let removed = loop_state.pending_responses.remove(&resp.id); - if !removed { - log::error!("unexpected response: {:?}", resp) - } - - if Some(&resp.id) == loop_state.configuration_request_id.as_ref() { - loop_state.configuration_request_id = None; - log::debug!("config update response: '{:?}", resp); - let Response { error, result, .. } = resp; - - match (error, result) { - (Some(err), _) => { - log::error!("failed to fetch the server settings: {:?}", err) - } - (None, Some(configs)) => { - if let Some(new_config) = configs.get(0) { - let mut config = global_state.config.clone(); - config.update(&new_config); - global_state.update_configuration(config); - } - } - (None, None) => { - log::error!("received empty server settings response from the client") - } - } - } + let handler = loop_state.req_queue.outgoing.complete(resp.id.clone()); + handler(global_state, resp) } }, }; @@ -407,12 +367,12 @@ fn loop_turn( fn on_task( task: Task, msg_sender: &Sender, - pending_requests: &mut PendingRequests, + incoming_requests: &mut req_queue::Incoming, state: &mut GlobalState, ) { match task { Task::Respond(response) => { - if let Some(completed) = pending_requests.finish(&response.id) { + if let Some(completed) = incoming_requests.complete(response.id.clone()) { log::info!("handled req#{} in {:?}", completed.id, completed.duration); state.complete_request(completed); msg_sender.send(response.into()).unwrap(); @@ -427,7 +387,7 @@ fn on_task( fn on_request( global_state: &mut GlobalState, - pending_requests: &mut PendingRequests, + incoming_requests: &mut req_queue::Incoming, pool: &ThreadPool, task_sender: &Sender, msg_sender: &Sender, @@ -440,7 +400,7 @@ fn on_request( global_state, task_sender, msg_sender, - pending_requests, + incoming_requests, request_received, }; pool_dispatcher @@ -504,12 +464,7 @@ fn on_notification( NumberOrString::Number(id) => id.into(), NumberOrString::String(id) => id.into(), }; - if loop_state.pending_requests.cancel(&id) { - let response = Response::new_err( - id, - ErrorCode::RequestCanceled as i32, - "canceled by client".to_string(), - ); + if let Some(response) = loop_state.req_queue.incoming.cancel(id) { msg_sender.send(response.into()).unwrap() } return Ok(()); @@ -572,18 +527,38 @@ fn on_notification( Ok(_) => { // As stated in https://github.com/microsoft/language-server-protocol/issues/676, // this notification's parameters should be ignored and the actual config queried separately. - let request_id = loop_state.next_request_id(); - let request = request_new::( - request_id.clone(), - lsp_types::ConfigurationParams { - items: vec![lsp_types::ConfigurationItem { - scope_uri: None, - section: Some("rust-analyzer".to_string()), - }], - }, - ); + let request = loop_state + .req_queue + .outgoing + .register::( + lsp_types::ConfigurationParams { + items: vec![lsp_types::ConfigurationItem { + scope_uri: None, + section: Some("rust-analyzer".to_string()), + }], + }, + |global_state, resp| { + log::debug!("config update response: '{:?}", resp); + let Response { error, result, .. } = resp; + + match (error, result) { + (Some(err), _) => { + log::error!("failed to fetch the server settings: {:?}", err) + } + (None, Some(configs)) => { + if let Some(new_config) = configs.get(0) { + let mut config = global_state.config.clone(); + config.update(&new_config); + global_state.update_configuration(config); + } + } + (None, None) => log::error!( + "received empty server settings response from the client" + ), + } + }, + ); msg_sender.send(request.into())?; - loop_state.configuration_request_id = Some(request_id); return Ok(()); } @@ -752,13 +727,16 @@ fn send_startup_progress(sender: &Sender, loop_state: &mut LoopState) { match (prev, loop_state.workspace_loaded) { (None, false) => { - let work_done_progress_create = request_new::( - loop_state.next_request_id(), - WorkDoneProgressCreateParams { - token: lsp_types::ProgressToken::String("rustAnalyzer/startup".into()), - }, - ); - sender.send(work_done_progress_create.into()).unwrap(); + let request = loop_state + .req_queue + .outgoing + .register::( + WorkDoneProgressCreateParams { + token: lsp_types::ProgressToken::String("rustAnalyzer/startup".into()), + }, + |_, _| (), + ); + sender.send(request.into()).unwrap(); send_startup_progress_notif( sender, WorkDoneProgress::Begin(WorkDoneProgressBegin { @@ -800,7 +778,7 @@ struct PoolDispatcher<'a> { req: Option, pool: &'a ThreadPool, global_state: &'a mut GlobalState, - pending_requests: &'a mut PendingRequests, + incoming_requests: &'a mut req_queue::Incoming, msg_sender: &'a Sender, task_sender: &'a Sender, request_received: Instant, @@ -829,7 +807,7 @@ impl<'a> PoolDispatcher<'a> { result_to_task::(id, result) }) .map_err(|_| format!("sync task {:?} panicked", R::METHOD))?; - on_task(task, self.msg_sender, self.pending_requests, self.global_state); + on_task(task, self.msg_sender, self.incoming_requests, self.global_state); Ok(self) } @@ -876,7 +854,7 @@ impl<'a> PoolDispatcher<'a> { return None; } }; - self.pending_requests.start(PendingRequest { + self.incoming_requests.register(req_queue::PendingInRequest { id: id.clone(), method: R::METHOD.to_string(), received: self.request_received, @@ -993,14 +971,6 @@ where Notification::new(N::METHOD.to_string(), params) } -fn request_new(id: RequestId, params: R::Params) -> Request -where - R: lsp_types::request::Request, - R::Params: Serialize, -{ - Request::new(id, R::METHOD.to_string(), params) -} - #[cfg(test)] mod tests { use std::borrow::Cow; diff --git a/crates/rust-analyzer/src/main_loop/pending_requests.rs b/crates/rust-analyzer/src/main_loop/pending_requests.rs deleted file mode 100644 index 73b33e4194..0000000000 --- a/crates/rust-analyzer/src/main_loop/pending_requests.rs +++ /dev/null @@ -1,75 +0,0 @@ -//! Data structures that keep track of inflight requests. - -use std::time::{Duration, Instant}; - -use lsp_server::RequestId; -use rustc_hash::FxHashMap; - -#[derive(Debug)] -pub struct CompletedRequest { - pub id: RequestId, - pub method: String, - pub duration: Duration, -} - -#[derive(Debug)] -pub(crate) struct PendingRequest { - pub(crate) id: RequestId, - pub(crate) method: String, - pub(crate) received: Instant, -} - -impl From for CompletedRequest { - fn from(pending: PendingRequest) -> CompletedRequest { - CompletedRequest { - id: pending.id, - method: pending.method, - duration: pending.received.elapsed(), - } - } -} - -#[derive(Debug, Default)] -pub(crate) struct PendingRequests { - map: FxHashMap, -} - -impl PendingRequests { - pub(crate) fn start(&mut self, request: PendingRequest) { - let id = request.id.clone(); - let prev = self.map.insert(id.clone(), request); - assert!(prev.is_none(), "duplicate request with id {}", id); - } - pub(crate) fn cancel(&mut self, id: &RequestId) -> bool { - self.map.remove(id).is_some() - } - pub(crate) fn finish(&mut self, id: &RequestId) -> Option { - self.map.remove(id).map(CompletedRequest::from) - } -} - -const N_COMPLETED_REQUESTS: usize = 10; - -#[derive(Debug, Default)] -pub struct LatestRequests { - // hand-rolling VecDeque here to print things in a nicer way - buf: [Option; N_COMPLETED_REQUESTS], - idx: usize, -} - -impl LatestRequests { - pub(crate) fn record(&mut self, request: CompletedRequest) { - // special case: don't track status request itself - if request.method == "rust-analyzer/analyzerStatus" { - return; - } - let idx = self.idx; - self.buf[idx] = Some(request); - self.idx = (idx + 1) % N_COMPLETED_REQUESTS; - } - - pub(crate) fn iter(&self) -> impl Iterator { - let idx = self.idx; - self.buf.iter().enumerate().filter_map(move |(i, req)| Some((i == idx, req.as_ref()?))) - } -} diff --git a/crates/rust-analyzer/src/main_loop/req_queue.rs b/crates/rust-analyzer/src/main_loop/req_queue.rs new file mode 100644 index 0000000000..5cf6d916b7 --- /dev/null +++ b/crates/rust-analyzer/src/main_loop/req_queue.rs @@ -0,0 +1,123 @@ +//! Manages the set of in-flight requests in both directions. +use std::time::{Duration, Instant}; + +use lsp_server::RequestId; +use rustc_hash::FxHashMap; +use serde::Serialize; + +#[derive(Debug)] +pub(crate) struct ReqQueue { + pub(crate) incoming: Incoming, + pub(crate) outgoing: Outgoing, +} + +impl Default for ReqQueue { + fn default() -> Self { + ReqQueue { incoming: Incoming::default(), outgoing: Outgoing::default() } + } +} + +#[derive(Debug)] +pub(crate) struct Outgoing { + next: u64, + pending: FxHashMap, +} + +impl Default for Outgoing { + fn default() -> Self { + Outgoing { next: 0, pending: FxHashMap::default() } + } +} + +impl Outgoing { + pub(crate) fn register(&mut self, params: R::Params, handler: H) -> lsp_server::Request + where + R: lsp_types::request::Request, + R::Params: Serialize, + { + let id = RequestId::from(self.next); + self.next += 1; + self.pending.insert(id.clone(), handler); + lsp_server::Request::new(id, R::METHOD.to_string(), params) + } + pub(crate) fn complete(&mut self, id: RequestId) -> H { + self.pending.remove(&id).unwrap() + } +} + +#[derive(Debug)] +pub(crate) struct CompletedInRequest { + pub(crate) id: RequestId, + pub(crate) method: String, + pub(crate) duration: Duration, +} + +#[derive(Debug)] +pub(crate) struct PendingInRequest { + pub(crate) id: RequestId, + pub(crate) method: String, + pub(crate) received: Instant, +} + +impl From for CompletedInRequest { + fn from(pending: PendingInRequest) -> CompletedInRequest { + CompletedInRequest { + id: pending.id, + method: pending.method, + duration: pending.received.elapsed(), + } + } +} + +#[derive(Debug, Default)] +pub(crate) struct Incoming { + pending: FxHashMap, +} + +impl Incoming { + pub(crate) fn register(&mut self, request: PendingInRequest) { + let id = request.id.clone(); + let prev = self.pending.insert(id.clone(), request); + assert!(prev.is_none(), "duplicate request with id {}", id); + } + pub(crate) fn cancel(&mut self, id: RequestId) -> Option { + if self.pending.remove(&id).is_some() { + Some(lsp_server::Response::new_err( + id, + lsp_server::ErrorCode::RequestCanceled as i32, + "canceled by client".to_string(), + )) + } else { + None + } + } + pub(crate) fn complete(&mut self, id: RequestId) -> Option { + self.pending.remove(&id).map(CompletedInRequest::from) + } +} + +const N_COMPLETED_REQUESTS: usize = 10; + +#[derive(Debug, Default)] +pub struct LatestRequests { + // hand-rolling VecDeque here to print things in a nicer way + buf: [Option; N_COMPLETED_REQUESTS], + idx: usize, +} + +impl LatestRequests { + pub(crate) fn record(&mut self, request: CompletedInRequest) { + // special case: don't track status request itself + if request.method == "rust-analyzer/analyzerStatus" { + return; + } + let idx = self.idx; + self.buf[idx] = Some(request); + self.idx = (idx + 1) % N_COMPLETED_REQUESTS; + } + + pub(crate) fn iter(&self) -> impl Iterator { + let idx = self.idx; + self.buf.iter().enumerate().filter_map(move |(i, req)| Some((i == idx, req.as_ref()?))) + } +}