diff --git a/Cargo.lock b/Cargo.lock index 62b2ac86f3..917eed6d12 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -866,11 +866,10 @@ dependencies = [ [[package]] name = "lsp-server" version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f70570c1c29cf6654029b8fe201a5507c153f0d85be6f234d471d756bc36775a" dependencies = [ "crossbeam-channel", "log", + "lsp-types", "serde", "serde_json", ] diff --git a/crates/hir-def/Cargo.toml b/crates/hir-def/Cargo.toml index d369c3ed28..8222212936 100644 --- a/crates/hir-def/Cargo.toml +++ b/crates/hir-def/Cargo.toml @@ -26,7 +26,7 @@ itertools = "0.10.3" indexmap = "1.8.0" smallvec = "1.8.0" arrayvec = "0.7.2" -la-arena = { version = "0.3.0", path = "../../lib/arena" } +la-arena = { version = "0.3.0", path = "../../lib/la-arena" } stdx = { path = "../stdx", version = "0.0.0" } base-db = { path = "../base-db", version = "0.0.0" } diff --git a/crates/hir-expand/Cargo.toml b/crates/hir-expand/Cargo.toml index f4770a6f9f..2a7e26fa2e 100644 --- a/crates/hir-expand/Cargo.toml +++ b/crates/hir-expand/Cargo.toml @@ -14,7 +14,7 @@ cov-mark = "2.0.0-pre.1" tracing = "0.1.32" either = "1.6.1" rustc-hash = "1.1.0" -la-arena = { version = "0.3.0", path = "../../lib/arena" } +la-arena = { version = "0.3.0", path = "../../lib/la-arena" } itertools = "0.10.3" hashbrown = { version = "0.12.0", features = [ "inline-more", diff --git a/crates/hir-ty/Cargo.toml b/crates/hir-ty/Cargo.toml index 06cdb1e4e0..10362f390b 100644 --- a/crates/hir-ty/Cargo.toml +++ b/crates/hir-ty/Cargo.toml @@ -21,7 +21,7 @@ scoped-tls = "1.0.0" chalk-solve = { version = "0.82.0", default-features = false } chalk-ir = "0.82.0" chalk-recursive = { version = "0.82.0", default-features = false } -la-arena = { version = "0.3.0", path = "../../lib/arena" } +la-arena = { version = "0.3.0", path = "../../lib/la-arena" } once_cell = "1.10.0" typed-arena = "2.0.1" diff --git a/crates/profile/Cargo.toml b/crates/profile/Cargo.toml index 8324f48de9..f37b362ebd 100644 --- a/crates/profile/Cargo.toml +++ b/crates/profile/Cargo.toml @@ -13,7 +13,7 @@ doctest = false once_cell = "1.10.0" cfg-if = "1.0.0" libc = "0.2.121" -la-arena = { version = "0.3.0", path = "../../lib/arena" } +la-arena = { version = "0.3.0", path = "../../lib/la-arena" } countme = { version = "3.0.1", features = ["enable"] } jemalloc-ctl = { version = "0.4.2", package = "tikv-jemalloc-ctl", optional = true } diff --git a/crates/project-model/Cargo.toml b/crates/project-model/Cargo.toml index 8fafdb1850..ecf3bcc8ab 100644 --- a/crates/project-model/Cargo.toml +++ b/crates/project-model/Cargo.toml @@ -18,7 +18,7 @@ serde = { version = "1.0.136", features = ["derive"] } serde_json = "1.0.79" anyhow = "1.0.56" expect-test = "1.2.2" -la-arena = { version = "0.3.0", path = "../../lib/arena" } +la-arena = { version = "0.3.0", path = "../../lib/la-arena" } cfg = { path = "../cfg", version = "0.0.0" } base-db = { path = "../base-db", version = "0.0.0" } diff --git a/crates/rust-analyzer/Cargo.toml b/crates/rust-analyzer/Cargo.toml index 767f59e61c..133459bf1d 100644 --- a/crates/rust-analyzer/Cargo.toml +++ b/crates/rust-analyzer/Cargo.toml @@ -33,7 +33,7 @@ threadpool = "1.8.1" rayon = "1.5.1" num_cpus = "1.13.1" mimalloc = { version = "0.1.28", default-features = false, optional = true } -lsp-server = "0.6.0" +lsp-server = { version = "0.6.0", path = "../../lib/lsp-server" } tracing = "0.1.32" tracing-subscriber = { version = "0.3.9", default-features = false, features = [ "env-filter", diff --git a/lib/arena/Cargo.toml b/lib/la-arena/Cargo.toml similarity index 77% rename from lib/arena/Cargo.toml rename to lib/la-arena/Cargo.toml index 2d3243a29b..ec5ba8ba00 100644 --- a/lib/arena/Cargo.toml +++ b/lib/la-arena/Cargo.toml @@ -3,7 +3,7 @@ name = "la-arena" version = "0.3.0" description = "Simple index-based arena without deletion." license = "MIT OR Apache-2.0" -repository = "https://github.com/rust-lang/rust-analyzer" +repository = "https://github.com/rust-lang/rust-analyzer/tree/master/lib/la-arena" documentation = "https://docs.rs/la-arena" categories = ["data-structures", "memory-management", "rust-patterns"] edition = "2021" diff --git a/lib/arena/src/lib.rs b/lib/la-arena/src/lib.rs similarity index 100% rename from lib/arena/src/lib.rs rename to lib/la-arena/src/lib.rs diff --git a/lib/arena/src/map.rs b/lib/la-arena/src/map.rs similarity index 100% rename from lib/arena/src/map.rs rename to lib/la-arena/src/map.rs diff --git a/lib/lsp-server/Cargo.toml b/lib/lsp-server/Cargo.toml new file mode 100644 index 0000000000..fd92cbe195 --- /dev/null +++ b/lib/lsp-server/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "lsp-server" +version = "0.6.0" +description = "Generic LSP server scaffold." +license = "MIT OR Apache-2.0" +repository = "https://github.com/rust-analyzer/rust-analyzer/tree/master/lib/lsp-server" +edition = "2021" + +[dependencies] +log = "0.4.3" +serde_json = "1.0.34" +serde = { version = "1.0.83", features = ["derive"] } +crossbeam-channel = "0.5.4" + +[dev-dependencies] +lsp-types = "0.93.0" diff --git a/lib/lsp-server/examples/goto_def.rs b/lib/lsp-server/examples/goto_def.rs new file mode 100644 index 0000000000..ca7ad0b536 --- /dev/null +++ b/lib/lsp-server/examples/goto_def.rs @@ -0,0 +1,121 @@ +//! A minimal example LSP server that can only respond to the `gotoDefinition` request. To use +//! this example, execute it and then send an `initialize` request. +//! +//! ```no_run +//! Content-Length: 85 +//! +//! {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {"capabilities": {}}} +//! ``` +//! +//! This will respond with a server response. Then send it a `initialized` notification which will +//! have no response. +//! +//! ```no_run +//! Content-Length: 59 +//! +//! {"jsonrpc": "2.0", "method": "initialized", "params": {}} +//! ``` +//! +//! Once these two are sent, then we enter the main loop of the server. The only request this +//! example can handle is `gotoDefinition`: +//! +//! ```no_run +//! Content-Length: 159 +//! +//! {"jsonrpc": "2.0", "method": "textDocument/definition", "id": 2, "params": {"textDocument": {"uri": "file://temp"}, "position": {"line": 1, "character": 1}}} +//! ``` +//! +//! To finish up without errors, send a shutdown request: +//! +//! ```no_run +//! Content-Length: 67 +//! +//! {"jsonrpc": "2.0", "method": "shutdown", "id": 3, "params": null} +//! ``` +//! +//! The server will exit the main loop and finally we send a `shutdown` notification to stop +//! the server. +//! +//! ``` +//! Content-Length: 54 +//! +//! {"jsonrpc": "2.0", "method": "exit", "params": null} +//! ``` +use std::error::Error; + +use lsp_types::OneOf; +use lsp_types::{ + request::GotoDefinition, GotoDefinitionResponse, InitializeParams, ServerCapabilities, +}; + +use lsp_server::{Connection, ExtractError, Message, Request, RequestId, Response}; + +fn main() -> Result<(), Box> { + // Note that we must have our logging only write out to stderr. + eprintln!("starting generic LSP server"); + + // Create the transport. Includes the stdio (stdin and stdout) versions but this could + // also be implemented to use sockets or HTTP. + let (connection, io_threads) = Connection::stdio(); + + // Run the server and wait for the two threads to end (typically by trigger LSP Exit event). + let server_capabilities = serde_json::to_value(&ServerCapabilities { + definition_provider: Some(OneOf::Left(true)), + ..Default::default() + }) + .unwrap(); + let initialization_params = connection.initialize(server_capabilities)?; + main_loop(connection, initialization_params)?; + io_threads.join()?; + + // Shut down gracefully. + eprintln!("shutting down server"); + Ok(()) +} + +fn main_loop( + connection: Connection, + params: serde_json::Value, +) -> Result<(), Box> { + let _params: InitializeParams = serde_json::from_value(params).unwrap(); + eprintln!("starting example main loop"); + for msg in &connection.receiver { + eprintln!("got msg: {:?}", msg); + match msg { + Message::Request(req) => { + if connection.handle_shutdown(&req)? { + return Ok(()); + } + eprintln!("got request: {:?}", req); + match cast::(req) { + Ok((id, params)) => { + eprintln!("got gotoDefinition request #{}: {:?}", id, params); + let result = Some(GotoDefinitionResponse::Array(Vec::new())); + let result = serde_json::to_value(&result).unwrap(); + let resp = Response { id, result: Some(result), error: None }; + connection.sender.send(Message::Response(resp))?; + continue; + } + Err(err @ ExtractError::JsonError { .. }) => panic!("{:?}", err), + Err(ExtractError::MethodMismatch(req)) => req, + }; + // ... + } + Message::Response(resp) => { + eprintln!("got response: {:?}", resp); + } + Message::Notification(not) => { + eprintln!("got notification: {:?}", not); + } + } + } + Ok(()) +} + +fn cast(req: Request) -> Result<(RequestId, R::Params), ExtractError> +where + R: lsp_types::request::Request, + R::Params: serde::de::DeserializeOwned, +{ + req.extract(R::METHOD) +} diff --git a/lib/lsp-server/src/error.rs b/lib/lsp-server/src/error.rs new file mode 100644 index 0000000000..4c934d9ecc --- /dev/null +++ b/lib/lsp-server/src/error.rs @@ -0,0 +1,50 @@ +use std::fmt; + +use crate::{Notification, Request}; + +#[derive(Debug, Clone)] +pub struct ProtocolError(pub(crate) String); + +impl std::error::Error for ProtocolError {} + +impl fmt::Display for ProtocolError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +#[derive(Debug)] +pub enum ExtractError { + /// The extracted message was of a different method than expected. + MethodMismatch(T), + /// Failed to deserialize the message. + JsonError { method: String, error: serde_json::Error }, +} + +impl std::error::Error for ExtractError {} +impl fmt::Display for ExtractError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ExtractError::MethodMismatch(req) => { + write!(f, "Method mismatch for request '{}'", req.method) + } + ExtractError::JsonError { method, error } => { + write!(f, "Invalid request\nMethod: {method}\n error: {error}",) + } + } + } +} + +impl std::error::Error for ExtractError {} +impl fmt::Display for ExtractError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ExtractError::MethodMismatch(req) => { + write!(f, "Method mismatch for notification '{}'", req.method) + } + ExtractError::JsonError { method, error } => { + write!(f, "Invalid notification\nMethod: {method}\n error: {error}") + } + } + } +} diff --git a/lib/lsp-server/src/lib.rs b/lib/lsp-server/src/lib.rs new file mode 100644 index 0000000000..1aaf327da0 --- /dev/null +++ b/lib/lsp-server/src/lib.rs @@ -0,0 +1,229 @@ +//! A language server scaffold, exposing a synchronous crossbeam-channel based API. +//! This crate handles protocol handshaking and parsing messages, while you +//! control the message dispatch loop yourself. +//! +//! Run with `RUST_LOG=lsp_server=debug` to see all the messages. +mod msg; +mod stdio; +mod error; +mod socket; +mod req_queue; + +use std::{ + io, + net::{TcpListener, TcpStream, ToSocketAddrs}, +}; + +use crossbeam_channel::{Receiver, Sender}; + +pub use crate::{ + error::{ExtractError, ProtocolError}, + msg::{ErrorCode, Message, Notification, Request, RequestId, Response, ResponseError}, + req_queue::{Incoming, Outgoing, ReqQueue}, + stdio::IoThreads, +}; + +/// Connection is just a pair of channels of LSP messages. +pub struct Connection { + pub sender: Sender, + pub receiver: Receiver, +} + +impl Connection { + /// Create connection over standard in/standard out. + /// + /// Use this to create a real language server. + pub fn stdio() -> (Connection, IoThreads) { + let (sender, receiver, io_threads) = stdio::stdio_transport(); + (Connection { sender, receiver }, io_threads) + } + + /// Open a connection over tcp. + /// This call blocks until a connection is established. + /// + /// Use this to create a real language server. + pub fn connect(addr: A) -> io::Result<(Connection, IoThreads)> { + let stream = TcpStream::connect(addr)?; + let (sender, receiver, io_threads) = socket::socket_transport(stream); + Ok((Connection { sender, receiver }, io_threads)) + } + + /// Listen for a connection over tcp. + /// This call blocks until a connection is established. + /// + /// Use this to create a real language server. + pub fn listen(addr: A) -> io::Result<(Connection, IoThreads)> { + let listener = TcpListener::bind(addr)?; + let (stream, _) = listener.accept()?; + let (sender, receiver, io_threads) = socket::socket_transport(stream); + Ok((Connection { sender, receiver }, io_threads)) + } + + /// Creates a pair of connected connections. + /// + /// Use this for testing. + pub fn memory() -> (Connection, Connection) { + let (s1, r1) = crossbeam_channel::unbounded(); + let (s2, r2) = crossbeam_channel::unbounded(); + (Connection { sender: s1, receiver: r2 }, Connection { sender: s2, receiver: r1 }) + } + + /// Starts the initialization process by waiting for an initialize + /// request from the client. Use this for more advanced customization than + /// `initialize` can provide. + /// + /// Returns the request id and serialized `InitializeParams` from the client. + /// + /// # Example + /// + /// ```no_run + /// use std::error::Error; + /// use lsp_types::{ClientCapabilities, InitializeParams, ServerCapabilities}; + /// + /// use lsp_server::{Connection, Message, Request, RequestId, Response}; + /// + /// fn main() -> Result<(), Box> { + /// // Create the transport. Includes the stdio (stdin and stdout) versions but this could + /// // also be implemented to use sockets or HTTP. + /// let (connection, io_threads) = Connection::stdio(); + /// + /// // Run the server + /// let (id, params) = connection.initialize_start()?; + /// + /// let init_params: InitializeParams = serde_json::from_value(params).unwrap(); + /// let client_capabilities: ClientCapabilities = init_params.capabilities; + /// let server_capabilities = ServerCapabilities::default(); + /// + /// let initialize_data = serde_json::json!({ + /// "capabilities": server_capabilities, + /// "serverInfo": { + /// "name": "lsp-server-test", + /// "version": "0.1" + /// } + /// }); + /// + /// connection.initialize_finish(id, initialize_data)?; + /// + /// // ... Run main loop ... + /// + /// Ok(()) + /// } + /// ``` + pub fn initialize_start(&self) -> Result<(RequestId, serde_json::Value), ProtocolError> { + loop { + match self.receiver.recv() { + Ok(Message::Request(req)) if req.is_initialize() => { + return Ok((req.id, req.params)) + } + // Respond to non-initialize requests with ServerNotInitialized + Ok(Message::Request(req)) => { + let resp = Response::new_err( + req.id.clone(), + ErrorCode::ServerNotInitialized as i32, + format!("expected initialize request, got {:?}", req), + ); + self.sender.send(resp.into()).unwrap(); + } + Ok(msg) => { + return Err(ProtocolError(format!( + "expected initialize request, got {:?}", + msg + ))) + } + Err(e) => { + return Err(ProtocolError(format!( + "expected initialize request, got error: {}", + e + ))) + } + }; + } + } + + /// Finishes the initialization process by sending an `InitializeResult` to the client + pub fn initialize_finish( + &self, + initialize_id: RequestId, + initialize_result: serde_json::Value, + ) -> Result<(), ProtocolError> { + let resp = Response::new_ok(initialize_id, initialize_result); + self.sender.send(resp.into()).unwrap(); + match &self.receiver.recv() { + Ok(Message::Notification(n)) if n.is_initialized() => (), + Ok(msg) => { + return Err(ProtocolError(format!( + "expected Message::Notification, got: {:?}", + msg, + ))) + } + Err(e) => { + return Err(ProtocolError(format!( + "expected initialized notification, got error: {}", + e, + ))) + } + } + Ok(()) + } + + /// Initialize the connection. Sends the server capabilities + /// to the client and returns the serialized client capabilities + /// on success. If more fine-grained initialization is required use + /// `initialize_start`/`initialize_finish`. + /// + /// # Example + /// + /// ```no_run + /// use std::error::Error; + /// use lsp_types::ServerCapabilities; + /// + /// use lsp_server::{Connection, Message, Request, RequestId, Response}; + /// + /// fn main() -> Result<(), Box> { + /// // Create the transport. Includes the stdio (stdin and stdout) versions but this could + /// // also be implemented to use sockets or HTTP. + /// let (connection, io_threads) = Connection::stdio(); + /// + /// // Run the server + /// let server_capabilities = serde_json::to_value(&ServerCapabilities::default()).unwrap(); + /// let initialization_params = connection.initialize(server_capabilities)?; + /// + /// // ... Run main loop ... + /// + /// Ok(()) + /// } + /// ``` + pub fn initialize( + &self, + server_capabilities: serde_json::Value, + ) -> Result { + let (id, params) = self.initialize_start()?; + + let initialize_data = serde_json::json!({ + "capabilities": server_capabilities, + }); + + self.initialize_finish(id, initialize_data)?; + + Ok(params) + } + + /// If `req` is `Shutdown`, respond to it and return `true`, otherwise return `false` + pub fn handle_shutdown(&self, req: &Request) -> Result { + if !req.is_shutdown() { + return Ok(false); + } + let resp = Response::new_ok(req.id.clone(), ()); + let _ = self.sender.send(resp.into()); + match &self.receiver.recv_timeout(std::time::Duration::from_secs(30)) { + Ok(Message::Notification(n)) if n.is_exit() => (), + Ok(msg) => { + return Err(ProtocolError(format!("unexpected message during shutdown: {:?}", msg))) + } + Err(e) => { + return Err(ProtocolError(format!("unexpected error during shutdown: {}", e))) + } + } + Ok(true) + } +} diff --git a/lib/lsp-server/src/msg.rs b/lib/lsp-server/src/msg.rs new file mode 100644 index 0000000000..97e5bd35ce --- /dev/null +++ b/lib/lsp-server/src/msg.rs @@ -0,0 +1,343 @@ +use std::{ + fmt, + io::{self, BufRead, Write}, +}; + +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +use crate::error::ExtractError; + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum Message { + Request(Request), + Response(Response), + Notification(Notification), +} + +impl From for Message { + fn from(request: Request) -> Message { + Message::Request(request) + } +} + +impl From for Message { + fn from(response: Response) -> Message { + Message::Response(response) + } +} + +impl From for Message { + fn from(notification: Notification) -> Message { + Message::Notification(notification) + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[serde(transparent)] +pub struct RequestId(IdRepr); + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[serde(untagged)] +enum IdRepr { + I32(i32), + String(String), +} + +impl From for RequestId { + fn from(id: i32) -> RequestId { + RequestId(IdRepr::I32(id)) + } +} + +impl From for RequestId { + fn from(id: String) -> RequestId { + RequestId(IdRepr::String(id)) + } +} + +impl fmt::Display for RequestId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + IdRepr::I32(it) => fmt::Display::fmt(it, f), + // Use debug here, to make it clear that `92` and `"92"` are + // different, and to reduce WTF factor if the sever uses `" "` as an + // ID. + IdRepr::String(it) => fmt::Debug::fmt(it, f), + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Request { + pub id: RequestId, + pub method: String, + #[serde(default = "serde_json::Value::default")] + #[serde(skip_serializing_if = "serde_json::Value::is_null")] + pub params: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Response { + // JSON RPC allows this to be null if it was impossible + // to decode the request's id. Ignore this special case + // and just die horribly. + pub id: RequestId, + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseError { + pub code: i32, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +#[derive(Clone, Copy, Debug)] +#[allow(unused)] +pub enum ErrorCode { + // Defined by JSON RPC: + ParseError = -32700, + InvalidRequest = -32600, + MethodNotFound = -32601, + InvalidParams = -32602, + InternalError = -32603, + ServerErrorStart = -32099, + ServerErrorEnd = -32000, + + /// Error code indicating that a server received a notification or + /// request before the server has received the `initialize` request. + ServerNotInitialized = -32002, + UnknownErrorCode = -32001, + + // Defined by the protocol: + /// The client has canceled a request and a server has detected + /// the cancel. + RequestCanceled = -32800, + + /// The server detected that the content of a document got + /// modified outside normal conditions. A server should + /// NOT send this error code if it detects a content change + /// in it unprocessed messages. The result even computed + /// on an older state might still be useful for the client. + /// + /// If a client decides that a result is not of any use anymore + /// the client should cancel the request. + ContentModified = -32801, + + /// The server cancelled the request. This error code should + /// only be used for requests that explicitly support being + /// server cancellable. + /// + /// @since 3.17.0 + ServerCancelled = -32802, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Notification { + pub method: String, + #[serde(default = "serde_json::Value::default")] + #[serde(skip_serializing_if = "serde_json::Value::is_null")] + pub params: serde_json::Value, +} + +impl Message { + pub fn read(r: &mut impl BufRead) -> io::Result> { + Message::_read(r) + } + fn _read(r: &mut dyn BufRead) -> io::Result> { + let text = match read_msg_text(r)? { + None => return Ok(None), + Some(text) => text, + }; + let msg = serde_json::from_str(&text)?; + Ok(Some(msg)) + } + pub fn write(self, w: &mut impl Write) -> io::Result<()> { + self._write(w) + } + fn _write(self, w: &mut dyn Write) -> io::Result<()> { + #[derive(Serialize)] + struct JsonRpc { + jsonrpc: &'static str, + #[serde(flatten)] + msg: Message, + } + let text = serde_json::to_string(&JsonRpc { jsonrpc: "2.0", msg: self })?; + write_msg_text(w, &text) + } +} + +impl Response { + pub fn new_ok(id: RequestId, result: R) -> Response { + Response { id, result: Some(serde_json::to_value(result).unwrap()), error: None } + } + pub fn new_err(id: RequestId, code: i32, message: String) -> Response { + let error = ResponseError { code, message, data: None }; + Response { id, result: None, error: Some(error) } + } +} + +impl Request { + pub fn new(id: RequestId, method: String, params: P) -> Request { + Request { id, method, params: serde_json::to_value(params).unwrap() } + } + pub fn extract( + self, + method: &str, + ) -> Result<(RequestId, P), ExtractError> { + if self.method != method { + return Err(ExtractError::MethodMismatch(self)); + } + match serde_json::from_value(self.params) { + Ok(params) => Ok((self.id, params)), + Err(error) => Err(ExtractError::JsonError { method: self.method, error }), + } + } + + pub(crate) fn is_shutdown(&self) -> bool { + self.method == "shutdown" + } + pub(crate) fn is_initialize(&self) -> bool { + self.method == "initialize" + } +} + +impl Notification { + pub fn new(method: String, params: impl Serialize) -> Notification { + Notification { method, params: serde_json::to_value(params).unwrap() } + } + pub fn extract( + self, + method: &str, + ) -> Result> { + if self.method != method { + return Err(ExtractError::MethodMismatch(self)); + } + match serde_json::from_value(self.params) { + Ok(params) => Ok(params), + Err(error) => Err(ExtractError::JsonError { method: self.method, error }), + } + } + pub(crate) fn is_exit(&self) -> bool { + self.method == "exit" + } + pub(crate) fn is_initialized(&self) -> bool { + self.method == "initialized" + } +} + +fn read_msg_text(inp: &mut dyn BufRead) -> io::Result> { + fn invalid_data(error: impl Into>) -> io::Error { + io::Error::new(io::ErrorKind::InvalidData, error) + } + macro_rules! invalid_data { + ($($tt:tt)*) => (invalid_data(format!($($tt)*))) + } + + let mut size = None; + let mut buf = String::new(); + loop { + buf.clear(); + if inp.read_line(&mut buf)? == 0 { + return Ok(None); + } + if !buf.ends_with("\r\n") { + return Err(invalid_data!("malformed header: {:?}", buf)); + } + let buf = &buf[..buf.len() - 2]; + if buf.is_empty() { + break; + } + let mut parts = buf.splitn(2, ": "); + let header_name = parts.next().unwrap(); + let header_value = + parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?; + if header_name == "Content-Length" { + size = Some(header_value.parse::().map_err(invalid_data)?); + } + } + let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?; + let mut buf = buf.into_bytes(); + buf.resize(size, 0); + inp.read_exact(&mut buf)?; + let buf = String::from_utf8(buf).map_err(invalid_data)?; + log::debug!("< {}", buf); + Ok(Some(buf)) +} + +fn write_msg_text(out: &mut dyn Write, msg: &str) -> io::Result<()> { + log::debug!("> {}", msg); + write!(out, "Content-Length: {}\r\n\r\n", msg.len())?; + out.write_all(msg.as_bytes())?; + out.flush()?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::{Message, Notification, Request, RequestId}; + + #[test] + fn shutdown_with_explicit_null() { + let text = "{\"jsonrpc\": \"2.0\",\"id\": 3,\"method\": \"shutdown\", \"params\": null }"; + let msg: Message = serde_json::from_str(text).unwrap(); + + assert!( + matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown") + ); + } + + #[test] + fn shutdown_with_no_params() { + let text = "{\"jsonrpc\": \"2.0\",\"id\": 3,\"method\": \"shutdown\"}"; + let msg: Message = serde_json::from_str(text).unwrap(); + + assert!( + matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown") + ); + } + + #[test] + fn notification_with_explicit_null() { + let text = "{\"jsonrpc\": \"2.0\",\"method\": \"exit\", \"params\": null }"; + let msg: Message = serde_json::from_str(text).unwrap(); + + assert!(matches!(msg, Message::Notification(not) if not.method == "exit")); + } + + #[test] + fn notification_with_no_params() { + let text = "{\"jsonrpc\": \"2.0\",\"method\": \"exit\"}"; + let msg: Message = serde_json::from_str(text).unwrap(); + + assert!(matches!(msg, Message::Notification(not) if not.method == "exit")); + } + + #[test] + fn serialize_request_with_null_params() { + let msg = Message::Request(Request { + id: RequestId::from(3), + method: "shutdown".into(), + params: serde_json::Value::Null, + }); + let serialized = serde_json::to_string(&msg).unwrap(); + + assert_eq!("{\"id\":3,\"method\":\"shutdown\"}", serialized); + } + + #[test] + fn serialize_notification_with_null_params() { + let msg = Message::Notification(Notification { + method: "exit".into(), + params: serde_json::Value::Null, + }); + let serialized = serde_json::to_string(&msg).unwrap(); + + assert_eq!("{\"method\":\"exit\"}", serialized); + } +} diff --git a/lib/lsp-server/src/req_queue.rs b/lib/lsp-server/src/req_queue.rs new file mode 100644 index 0000000000..1f3d447153 --- /dev/null +++ b/lib/lsp-server/src/req_queue.rs @@ -0,0 +1,62 @@ +use std::collections::HashMap; + +use serde::Serialize; + +use crate::{ErrorCode, Request, RequestId, Response, ResponseError}; + +/// Manages the set of pending requests, both incoming and outgoing. +#[derive(Debug)] +pub struct ReqQueue { + pub incoming: Incoming, + pub outgoing: Outgoing, +} + +impl Default for ReqQueue { + fn default() -> ReqQueue { + ReqQueue { + incoming: Incoming { pending: HashMap::default() }, + outgoing: Outgoing { next_id: 0, pending: HashMap::default() }, + } + } +} + +#[derive(Debug)] +pub struct Incoming { + pending: HashMap, +} + +#[derive(Debug)] +pub struct Outgoing { + next_id: i32, + pending: HashMap, +} + +impl Incoming { + pub fn register(&mut self, id: RequestId, data: I) { + self.pending.insert(id, data); + } + pub fn cancel(&mut self, id: RequestId) -> Option { + let _data = self.complete(id.clone())?; + let error = ResponseError { + code: ErrorCode::RequestCanceled as i32, + message: "canceled by client".to_string(), + data: None, + }; + Some(Response { id, result: None, error: Some(error) }) + } + pub fn complete(&mut self, id: RequestId) -> Option { + self.pending.remove(&id) + } +} + +impl Outgoing { + pub fn register(&mut self, method: String, params: P, data: O) -> Request { + let id = RequestId::from(self.next_id); + self.pending.insert(id.clone(), data); + self.next_id += 1; + Request::new(id, method, params) + } + pub fn complete(&mut self, id: RequestId) -> Option { + self.pending.remove(&id) + } +} diff --git a/lib/lsp-server/src/socket.rs b/lib/lsp-server/src/socket.rs new file mode 100644 index 0000000000..4a59c4c0fa --- /dev/null +++ b/lib/lsp-server/src/socket.rs @@ -0,0 +1,46 @@ +use std::{ + io::{self, BufReader}, + net::TcpStream, + thread, +}; + +use crossbeam_channel::{bounded, Receiver, Sender}; + +use crate::{ + stdio::{make_io_threads, IoThreads}, + Message, +}; + +pub(crate) fn socket_transport( + stream: TcpStream, +) -> (Sender, Receiver, IoThreads) { + let (reader_receiver, reader) = make_reader(stream.try_clone().unwrap()); + let (writer_sender, writer) = make_write(stream.try_clone().unwrap()); + let io_threads = make_io_threads(reader, writer); + (writer_sender, reader_receiver, io_threads) +} + +fn make_reader(stream: TcpStream) -> (Receiver, thread::JoinHandle>) { + let (reader_sender, reader_receiver) = bounded::(0); + let reader = thread::spawn(move || { + let mut buf_read = BufReader::new(stream); + while let Some(msg) = Message::read(&mut buf_read).unwrap() { + let is_exit = matches!(&msg, Message::Notification(n) if n.is_exit()); + reader_sender.send(msg).unwrap(); + if is_exit { + break; + } + } + Ok(()) + }); + (reader_receiver, reader) +} + +fn make_write(mut stream: TcpStream) -> (Sender, thread::JoinHandle>) { + let (writer_sender, writer_receiver) = bounded::(0); + let writer = thread::spawn(move || { + writer_receiver.into_iter().try_for_each(|it| it.write(&mut stream)).unwrap(); + Ok(()) + }); + (writer_sender, writer) +} diff --git a/lib/lsp-server/src/stdio.rs b/lib/lsp-server/src/stdio.rs new file mode 100644 index 0000000000..cdee6432df --- /dev/null +++ b/lib/lsp-server/src/stdio.rs @@ -0,0 +1,71 @@ +use std::{ + io::{self, stdin, stdout}, + thread, +}; + +use crossbeam_channel::{bounded, Receiver, Sender}; + +use crate::Message; + +/// Creates an LSP connection via stdio. +pub(crate) fn stdio_transport() -> (Sender, Receiver, IoThreads) { + let (writer_sender, writer_receiver) = bounded::(0); + let writer = thread::spawn(move || { + let stdout = stdout(); + let mut stdout = stdout.lock(); + writer_receiver.into_iter().try_for_each(|it| it.write(&mut stdout))?; + Ok(()) + }); + let (reader_sender, reader_receiver) = bounded::(0); + let reader = thread::spawn(move || { + let stdin = stdin(); + let mut stdin = stdin.lock(); + while let Some(msg) = Message::read(&mut stdin)? { + let is_exit = match &msg { + Message::Notification(n) => n.is_exit(), + _ => false, + }; + + reader_sender.send(msg).unwrap(); + + if is_exit { + break; + } + } + Ok(()) + }); + let threads = IoThreads { reader, writer }; + (writer_sender, reader_receiver, threads) +} + +// Creates an IoThreads +pub(crate) fn make_io_threads( + reader: thread::JoinHandle>, + writer: thread::JoinHandle>, +) -> IoThreads { + IoThreads { reader, writer } +} + +pub struct IoThreads { + reader: thread::JoinHandle>, + writer: thread::JoinHandle>, +} + +impl IoThreads { + pub fn join(self) -> io::Result<()> { + match self.reader.join() { + Ok(r) => r?, + Err(err) => { + println!("reader panicked!"); + std::panic::panic_any(err) + } + } + match self.writer.join() { + Ok(r) => r, + Err(err) => { + println!("writer panicked!"); + std::panic::panic_any(err); + } + } + } +}