diff --git a/Cargo.lock b/Cargo.lock index 701e36d74a..11283034ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -216,7 +216,7 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5080df6b0f0ecb76cab30808f00d937ba725cebe266a3da8cd89dff92f2a9916" dependencies = [ - "nix", + "nix 0.26.2", "winapi", ] @@ -289,6 +289,16 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "ctrlc" +version = "3.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82e95fbd621905b854affdc67943b043a0fbb6ed7385fd5a25650d19a8a6cfdf" +dependencies = [ + "nix 0.27.1", + "windows-sys 0.48.0", +] + [[package]] name = "dashmap" version = "5.4.0" @@ -961,6 +971,7 @@ name = "lsp-server" version = "0.7.4" dependencies = [ "crossbeam-channel", + "ctrlc", "log", "lsp-types", "serde", @@ -1100,6 +1111,17 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "nix" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" +dependencies = [ + "bitflags 2.4.1", + "cfg-if", + "libc", +] + [[package]] name = "nohash-hasher" version = "0.2.0" @@ -1701,18 +1723,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.192" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bca2a08484b285dcb282d0f67b26cadc0df8b19f8c12502c13d966bf9482f001" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.192" +version = "1.0.193" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6c7207fbec9faa48073f3e3074cbe553af6ea512d7c21ba46e434e70ea9fbc1" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", diff --git a/lib/lsp-server/Cargo.toml b/lib/lsp-server/Cargo.toml index 8d00813b0d..be1573913f 100644 --- a/lib/lsp-server/Cargo.toml +++ b/lib/lsp-server/Cargo.toml @@ -14,3 +14,4 @@ crossbeam-channel = "0.5.6" [dev-dependencies] lsp-types = "=0.94" +ctrlc = "3.4.1" diff --git a/lib/lsp-server/src/lib.rs b/lib/lsp-server/src/lib.rs index affab60a22..b190c0af73 100644 --- a/lib/lsp-server/src/lib.rs +++ b/lib/lsp-server/src/lib.rs @@ -17,7 +17,7 @@ use std::{ net::{TcpListener, TcpStream, ToSocketAddrs}, }; -use crossbeam_channel::{Receiver, Sender}; +use crossbeam_channel::{Receiver, RecvTimeoutError, Sender}; pub use crate::{ error::{ExtractError, ProtocolError}, @@ -113,11 +113,62 @@ impl Connection { /// } /// ``` pub fn initialize_start(&self) -> Result<(RequestId, serde_json::Value), ProtocolError> { - loop { - break match self.receiver.recv() { - Ok(Message::Request(req)) if req.is_initialize() => Ok((req.id, req.params)), + self.initialize_start_while(|| true) + } + + /// Starts the initialization process by waiting for an initialize as described in + /// [`Self::initialize_start`] as long as `running` returns + /// `true` while the return value can be changed through a sig handler such as `CTRL + C`. + /// + /// # Example + /// + /// ```rust + /// use std::sync::atomic::{AtomicBool, Ordering}; + /// use std::sync::Arc; + /// # use std::error::Error; + /// # use lsp_types::{ClientCapabilities, InitializeParams, ServerCapabilities}; + /// # use lsp_server::{Connection, Message, Request, RequestId, Response}; + /// # fn main() -> Result<(), Box> { + /// let running = Arc::new(AtomicBool::new(true)); + /// # running.store(true, Ordering::SeqCst); + /// let r = running.clone(); + /// + /// ctrlc::set_handler(move || { + /// r.store(false, Ordering::SeqCst); + /// }).expect("Error setting Ctrl-C handler"); + /// + /// let (connection, io_threads) = Connection::stdio(); + /// + /// let res = connection.initialize_start_while(|| running.load(Ordering::SeqCst)); + /// # assert!(res.is_err()); + /// + /// # Ok(()) + /// # } + /// ``` + pub fn initialize_start_while( + &self, + running: C, + ) -> Result<(RequestId, serde_json::Value), ProtocolError> + where + C: Fn() -> bool, + { + while running() { + let msg = match self.receiver.recv_timeout(std::time::Duration::from_secs(1)) { + Ok(msg) => msg, + Err(RecvTimeoutError::Timeout) => { + continue; + } + Err(e) => { + return Err(ProtocolError(format!( + "expected initialize request, got error: {e}" + ))) + } + }; + + match msg { + Message::Request(req) if req.is_initialize() => return Ok((req.id, req.params)), // Respond to non-initialize requests with ServerNotInitialized - Ok(Message::Request(req)) => { + Message::Request(req) => { let resp = Response::new_err( req.id.clone(), ErrorCode::ServerNotInitialized as i32, @@ -126,15 +177,18 @@ impl Connection { self.sender.send(resp.into()).unwrap(); continue; } - Ok(Message::Notification(n)) if !n.is_exit() => { + Message::Notification(n) if !n.is_exit() => { continue; } - Ok(msg) => Err(ProtocolError(format!("expected initialize request, got {msg:?}"))), - Err(e) => { - Err(ProtocolError(format!("expected initialize request, got error: {e}"))) + msg => { + return Err(ProtocolError(format!("expected initialize request, got {msg:?}"))); } }; } + + return Err(ProtocolError(String::from( + "Initialization has been aborted during initialization", + ))); } /// Finishes the initialization process by sending an `InitializeResult` to the client @@ -156,6 +210,51 @@ impl Connection { } } + /// Finishes the initialization process as described in [`Self::initialize_finish`] as + /// long as `running` returns `true` while the return value can be changed through a sig + /// handler such as `CTRL + C`. + pub fn initialize_finish_while( + &self, + initialize_id: RequestId, + initialize_result: serde_json::Value, + running: C, + ) -> Result<(), ProtocolError> + where + C: Fn() -> bool, + { + let resp = Response::new_ok(initialize_id, initialize_result); + self.sender.send(resp.into()).unwrap(); + + while running() { + let msg = match self.receiver.recv_timeout(std::time::Duration::from_secs(1)) { + Ok(msg) => msg, + Err(RecvTimeoutError::Timeout) => { + continue; + } + Err(e) => { + return Err(ProtocolError(format!( + "expected initialized notification, got error: {e}", + ))); + } + }; + + match msg { + Message::Notification(n) if n.is_initialized() => { + return Ok(()); + } + msg => { + return Err(ProtocolError(format!( + r#"expected initialized notification, got: {msg:?}"# + ))); + } + } + } + + return Err(ProtocolError(String::from( + "Initialization has been aborted during initialization", + ))); + } + /// Initialize the connection. Sends the server capabilities /// to the client and returns the serialized client capabilities /// on success. If more fine-grained initialization is required use @@ -198,6 +297,58 @@ impl Connection { Ok(params) } + /// Initialize the connection as described in [`Self::initialize`] as long as `running` returns + /// `true` while the return value can be changed through a sig handler such as `CTRL + C`. + /// + /// # Example + /// + /// ```rust + /// use std::sync::atomic::{AtomicBool, Ordering}; + /// use std::sync::Arc; + /// # use std::error::Error; + /// # use lsp_types::ServerCapabilities; + /// # use lsp_server::{Connection, Message, Request, RequestId, Response}; + /// + /// # fn main() -> Result<(), Box> { + /// let running = Arc::new(AtomicBool::new(true)); + /// # running.store(true, Ordering::SeqCst); + /// let r = running.clone(); + /// + /// ctrlc::set_handler(move || { + /// r.store(false, Ordering::SeqCst); + /// }).expect("Error setting Ctrl-C handler"); + /// + /// let (connection, io_threads) = Connection::stdio(); + /// + /// let server_capabilities = serde_json::to_value(&ServerCapabilities::default()).unwrap(); + /// let initialization_params = connection.initialize_while( + /// server_capabilities, + /// || running.load(Ordering::SeqCst) + /// ); + /// + /// # assert!(initialization_params.is_err()); + /// # Ok(()) + /// # } + /// ``` + pub fn initialize_while( + &self, + server_capabilities: serde_json::Value, + running: C, + ) -> Result + where + C: Fn() -> bool, + { + let (id, params) = self.initialize_start_while(&running)?; + + let initialize_data = serde_json::json!({ + "capabilities": server_capabilities, + }); + + self.initialize_finish_while(id, initialize_data, running)?; + + Ok(params) + } + /// If `req` is `Shutdown`, respond to it and return `true`, otherwise return `false` pub fn handle_shutdown(&self, req: &Request) -> Result { if !req.is_shutdown() {