diff --git a/Cargo.lock b/Cargo.lock index 18d9194675..e12545ccfc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -234,6 +234,30 @@ dependencies = [ "wait-timeout", ] +[[package]] +name = "async-channel" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28243a43d821d11341ab73c80bed182dc015c514b951616cf79bd4af39af0c3" +dependencies = [ + "concurrent-queue", + "event-listener 5.3.0", + "event-listener-strategy 0.5.1", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-lock" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d034b430882f8381900d3fe6f0aaa3ad94f2cb4ac519b429692a1bc2dda4ae7b" +dependencies = [ + "event-listener 4.0.3", + "event-listener-strategy 0.4.0", + "pin-project-lite", +] + [[package]] name = "async-stream" version = "0.3.5" @@ -256,6 +280,12 @@ dependencies = [ "syn 2.0.58", ] +[[package]] +name = "async-task" +version = "4.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbb36e985947064623dbd357f727af08ffd077f93d696782f3c56365fa2e2799" + [[package]] name = "async-trait" version = "0.1.79" @@ -282,6 +312,12 @@ version = "0.15.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.2.0" @@ -430,6 +466,22 @@ dependencies = [ "generic-array", ] +[[package]] +name = "blocking" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a37913e8dc4ddcc604f0c6d3bf2887c995153af3611de9e23c352b44c1b9118" +dependencies = [ + "async-channel", + "async-lock", + "async-task", + "fastrand", + "futures-io", + "futures-lite", + "piper", + "tracing", +] + [[package]] name = "borsh" version = "1.4.0" @@ -826,6 +878,15 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "concurrent-queue" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d16048cd947b08fa32c24458a22f5dc5e835264f689f4f5653210c69fd107363" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "condtype" version = "1.3.0" @@ -1362,6 +1423,48 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" +[[package]] +name = "event-listener" +version = "4.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b215c49b2b248c855fb73579eb1f4f26c38ffdc12973e20e07b91d78d5646e" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d9944b8ca13534cdfb2800775f8dd4902ff3fc75a50101466decadfdf322a24" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" +dependencies = [ + "event-listener 4.0.3", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "332f51cb23d20b0de8458b86580878211da09bcd4503cb579c225b3d124cabb3" +dependencies = [ + "event-listener 5.3.0", + "pin-project-lite", +] + [[package]] name = "fallible-iterator" version = "0.3.0" @@ -1578,6 +1681,16 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +[[package]] +name = "futures-lite" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52527eb5074e35e9339c6b4e8d12600c7128b68fb25dcb9fa9dec18f7c25f3a5" +dependencies = [ + "futures-core", + "pin-project-lite", +] + [[package]] name = "futures-macro" version = "0.3.30" @@ -1997,6 +2110,32 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "interprocess" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81f2533f3be42fffe3b5e63b71aeca416c1c3bc33e4e27be018521e76b1f38fb" +dependencies = [ + "blocking", + "cfg-if", + "futures-core", + "futures-io", + "intmap", + "libc", + "once_cell", + "rustc_version", + "spinning", + "thiserror", + "to_method", + "winapi", +] + +[[package]] +name = "intmap" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae52f28f45ac2bc96edb7714de995cffc174a395fb0abf5bff453587c980d7b9" + [[package]] name = "inventory" version = "0.3.15" @@ -3077,10 +3216,13 @@ name = "nu-plugin" version = "0.92.3" dependencies = [ "bincode", + "interprocess", "log", "miette", + "nix", "nu-engine", "nu-protocol", + "nu-system", "nu-utils", "rmp-serde", "semver", @@ -3311,6 +3453,15 @@ dependencies = [ "sxd-xpath", ] +[[package]] +name = "nu_plugin_stress_internals" +version = "0.92.3" +dependencies = [ + "interprocess", + "serde", + "serde_json", +] + [[package]] name = "num" version = "0.4.1" @@ -3591,6 +3742,12 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "parking" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" + [[package]] name = "parking_lot" version = "0.12.1" @@ -3814,6 +3971,17 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "piper" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "668d31b1c4eba19242f2088b2bf3316b82ca31082a8335764db4e083db7485d4" +dependencies = [ + "atomic-waker", + "fastrand", + "futures-io", +] + [[package]] name = "pkg-config" version = "0.3.30" @@ -5309,6 +5477,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "spinning" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d4f0e86297cad2658d92a707320d87bf4e6ae1050287f51d19b67ef3f153a7b" +dependencies = [ + "lock_api", +] + [[package]] name = "sqlparser" version = "0.39.0" @@ -5746,6 +5923,12 @@ dependencies = [ "regex", ] +[[package]] +name = "to_method" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c4ceeeca15c8384bbc3e011dbd8fccb7f068a440b752b7d9b32ceb0ca0e2e8" + [[package]] name = "tokio" version = "1.37.0" diff --git a/Cargo.toml b/Cargo.toml index 8b56585678..ee4062fcb9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ members = [ "crates/nu_plugin_custom_values", "crates/nu_plugin_formats", "crates/nu_plugin_polars", + "crates/nu_plugin_stress_internals", "crates/nu-std", "crates/nu-table", "crates/nu-term-grid", diff --git a/crates/nu-parser/src/parse_keywords.rs b/crates/nu-parser/src/parse_keywords.rs index bac41108d7..b46a9f06f7 100644 --- a/crates/nu-parser/src/parse_keywords.rs +++ b/crates/nu-parser/src/parse_keywords.rs @@ -3687,7 +3687,7 @@ pub fn parse_register(working_set: &mut StateWorkingSet, lite_command: &LiteComm // Add it to the working set let plugin = working_set.find_or_create_plugin(&identity, || { - Arc::new(PersistentPlugin::new(identity.clone(), gc_config)) + Arc::new(PersistentPlugin::new(identity.clone(), gc_config.clone())) }); // Downcast the plugin to `PersistentPlugin` - we generally expect this to succeed. The @@ -3706,7 +3706,7 @@ pub fn parse_register(working_set: &mut StateWorkingSet, lite_command: &LiteComm // // The user would expect that `register` would always run the binary to get new // signatures, in case it was replaced with an updated binary - plugin.stop().map_err(|err| { + plugin.reset().map_err(|err| { ParseError::LabeledError( "Failed to restart plugin to get new signatures".into(), err.to_string(), @@ -3714,7 +3714,10 @@ pub fn parse_register(working_set: &mut StateWorkingSet, lite_command: &LiteComm ) })?; + plugin.set_gc_config(&gc_config); + let signatures = get_signature(plugin.clone(), get_envs).map_err(|err| { + log::warn!("Error getting signatures: {err:?}"); ParseError::LabeledError( "Error getting signatures".into(), err.to_string(), diff --git a/crates/nu-plugin-test-support/src/fake_persistent_plugin.rs b/crates/nu-plugin-test-support/src/fake_persistent_plugin.rs index 40bba951c2..b1215faa04 100644 --- a/crates/nu-plugin-test-support/src/fake_persistent_plugin.rs +++ b/crates/nu-plugin-test-support/src/fake_persistent_plugin.rs @@ -51,6 +51,11 @@ impl RegisteredPlugin for FakePersistentPlugin { Ok(()) } + fn reset(&self) -> Result<(), ShellError> { + // We can't stop + Ok(()) + } + fn as_any(self: Arc) -> Arc { self } diff --git a/crates/nu-plugin-test-support/src/spawn_fake_plugin.rs b/crates/nu-plugin-test-support/src/spawn_fake_plugin.rs index e29dea5e1c..0b8e34ae19 100644 --- a/crates/nu-plugin-test-support/src/spawn_fake_plugin.rs +++ b/crates/nu-plugin-test-support/src/spawn_fake_plugin.rs @@ -47,7 +47,9 @@ pub(crate) fn spawn_fake_plugin( let identity = PluginIdentity::new_fake(name); let reg_plugin = Arc::new(FakePersistentPlugin::new(identity.clone())); let source = Arc::new(PluginSource::new(reg_plugin.clone())); - let mut manager = PluginInterfaceManager::new(source, input_write); + + // The fake plugin has no process ID, and we also don't set the garbage collector + let mut manager = PluginInterfaceManager::new(source, None, input_write); // Set up the persistent plugin with the interface before continuing let interface = manager.get_interface(); diff --git a/crates/nu-plugin/Cargo.toml b/crates/nu-plugin/Cargo.toml index 0e2f755f31..1780ff002b 100644 --- a/crates/nu-plugin/Cargo.toml +++ b/crates/nu-plugin/Cargo.toml @@ -13,6 +13,7 @@ bench = false [dependencies] nu-engine = { path = "../nu-engine", version = "0.92.3" } nu-protocol = { path = "../nu-protocol", version = "0.92.3" } +nu-system = { path = "../nu-system", version = "0.92.3" } nu-utils = { path = "../nu-utils", version = "0.92.3" } bincode = "1.3" @@ -24,6 +25,15 @@ miette = { workspace = true } semver = "1.0" typetag = "0.2" thiserror = "1.0" +interprocess = { version = "1.2.1", optional = true } + +[features] +default = ["local-socket"] +local-socket = ["interprocess"] + +[target.'cfg(target_family = "unix")'.dependencies] +# For setting the process group ID (EnterForeground / LeaveForeground) +nix = { workspace = true, default-features = false, features = ["process"] } [target.'cfg(target_os = "windows")'.dependencies] windows = { workspace = true, features = [ diff --git a/crates/nu-plugin/src/plugin/communication_mode/local_socket/mod.rs b/crates/nu-plugin/src/plugin/communication_mode/local_socket/mod.rs new file mode 100644 index 0000000000..e550892fe1 --- /dev/null +++ b/crates/nu-plugin/src/plugin/communication_mode/local_socket/mod.rs @@ -0,0 +1,84 @@ +use std::ffi::OsString; + +#[cfg(test)] +pub(crate) mod tests; + +/// Generate a name to be used for a local socket specific to this `nu` process, described by the +/// given `unique_id`, which should be unique to the purpose of the socket. +/// +/// On Unix, this is a path, which should generally be 100 characters or less for compatibility. On +/// Windows, this is a name within the `\\.\pipe` namespace. +#[cfg(unix)] +pub fn make_local_socket_name(unique_id: &str) -> OsString { + // Prefer to put it in XDG_RUNTIME_DIR if set, since that's user-local + let mut base = if let Some(runtime_dir) = std::env::var_os("XDG_RUNTIME_DIR") { + std::path::PathBuf::from(runtime_dir) + } else { + // Use std::env::temp_dir() for portability, especially since on Android this is probably + // not `/tmp` + std::env::temp_dir() + }; + let socket_name = format!("nu.{}.{}.sock", std::process::id(), unique_id); + base.push(socket_name); + base.into() +} + +/// Generate a name to be used for a local socket specific to this `nu` process, described by the +/// given `unique_id`, which should be unique to the purpose of the socket. +/// +/// On Unix, this is a path, which should generally be 100 characters or less for compatibility. On +/// Windows, this is a name within the `\\.\pipe` namespace. +#[cfg(windows)] +pub fn make_local_socket_name(unique_id: &str) -> OsString { + format!("nu.{}.{}", std::process::id(), unique_id).into() +} + +/// Determine if the error is just due to the listener not being ready yet in asynchronous mode +#[cfg(not(windows))] +pub fn is_would_block_err(err: &std::io::Error) -> bool { + err.kind() == std::io::ErrorKind::WouldBlock +} + +/// Determine if the error is just due to the listener not being ready yet in asynchronous mode +#[cfg(windows)] +pub fn is_would_block_err(err: &std::io::Error) -> bool { + err.kind() == std::io::ErrorKind::WouldBlock + || err.raw_os_error().is_some_and(|e| { + // Windows returns this error when trying to accept a pipe in non-blocking mode + e as i64 == windows::Win32::Foundation::ERROR_PIPE_LISTENING.0 as i64 + }) +} + +/// Wraps the `interprocess` local socket stream for greater compatibility +#[derive(Debug)] +pub struct LocalSocketStream(pub interprocess::local_socket::LocalSocketStream); + +impl From for LocalSocketStream { + fn from(value: interprocess::local_socket::LocalSocketStream) -> Self { + LocalSocketStream(value) + } +} + +impl std::io::Read for LocalSocketStream { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.0.read(buf) + } +} + +impl std::io::Write for LocalSocketStream { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + // We don't actually flush the underlying socket on Windows. The flush operation on a + // Windows named pipe actually synchronizes with read on the other side, and won't finish + // until the other side is empty. This isn't how most of our other I/O methods work, so we + // just won't do it. The BufWriter above this will have still made a write call with the + // contents of the buffer, which should be good enough. + if cfg!(not(windows)) { + self.0.flush()?; + } + Ok(()) + } +} diff --git a/crates/nu-plugin/src/plugin/communication_mode/local_socket/tests.rs b/crates/nu-plugin/src/plugin/communication_mode/local_socket/tests.rs new file mode 100644 index 0000000000..a15d7a5294 --- /dev/null +++ b/crates/nu-plugin/src/plugin/communication_mode/local_socket/tests.rs @@ -0,0 +1,19 @@ +use super::make_local_socket_name; + +#[test] +fn local_socket_path_contains_pid() { + let name = make_local_socket_name("test-string") + .to_string_lossy() + .into_owned(); + println!("{}", name); + assert!(name.to_string().contains(&std::process::id().to_string())); +} + +#[test] +fn local_socket_path_contains_provided_name() { + let name = make_local_socket_name("test-string") + .to_string_lossy() + .into_owned(); + println!("{}", name); + assert!(name.to_string().contains("test-string")); +} diff --git a/crates/nu-plugin/src/plugin/communication_mode/mod.rs b/crates/nu-plugin/src/plugin/communication_mode/mod.rs new file mode 100644 index 0000000000..ca7d5e2b41 --- /dev/null +++ b/crates/nu-plugin/src/plugin/communication_mode/mod.rs @@ -0,0 +1,233 @@ +use std::ffi::OsStr; +use std::io::{Stdin, Stdout}; +use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio}; + +use nu_protocol::ShellError; + +#[cfg(feature = "local-socket")] +use interprocess::local_socket::LocalSocketListener; + +#[cfg(feature = "local-socket")] +mod local_socket; + +#[cfg(feature = "local-socket")] +use local_socket::*; + +#[derive(Debug, Clone)] +pub(crate) enum CommunicationMode { + /// Communicate using `stdin` and `stdout`. + Stdio, + /// Communicate using an operating system-specific local socket. + #[cfg(feature = "local-socket")] + LocalSocket(std::ffi::OsString), +} + +impl CommunicationMode { + /// Generate a new local socket communication mode based on the given plugin exe path. + #[cfg(feature = "local-socket")] + pub fn local_socket(plugin_exe: &std::path::Path) -> CommunicationMode { + use std::hash::{Hash, Hasher}; + use std::time::SystemTime; + + // Generate the unique ID based on the plugin path and the current time. The actual + // algorithm here is not very important, we just want this to be relatively unique very + // briefly. Using the default hasher in the stdlib means zero extra dependencies. + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + + plugin_exe.hash(&mut hasher); + SystemTime::now().hash(&mut hasher); + + let unique_id = format!("{:016x}", hasher.finish()); + + CommunicationMode::LocalSocket(make_local_socket_name(&unique_id)) + } + + pub fn args(&self) -> Vec<&OsStr> { + match self { + CommunicationMode::Stdio => vec![OsStr::new("--stdio")], + #[cfg(feature = "local-socket")] + CommunicationMode::LocalSocket(path) => { + vec![OsStr::new("--local-socket"), path.as_os_str()] + } + } + } + + pub fn setup_command_io(&self, command: &mut Command) { + match self { + CommunicationMode::Stdio => { + // Both stdout and stdin are piped so we can receive information from the plugin + command.stdin(Stdio::piped()); + command.stdout(Stdio::piped()); + } + #[cfg(feature = "local-socket")] + CommunicationMode::LocalSocket(_) => { + // Stdio can be used by the plugin to talk to the terminal in local socket mode, + // which is the big benefit + command.stdin(Stdio::inherit()); + command.stdout(Stdio::inherit()); + } + } + } + + pub fn serve(&self) -> Result { + match self { + // Nothing to set up for stdio - we just take it from the child. + CommunicationMode::Stdio => Ok(PreparedServerCommunication::Stdio), + // For sockets: we need to create the server so that the child won't fail to connect. + #[cfg(feature = "local-socket")] + CommunicationMode::LocalSocket(name) => { + let listener = LocalSocketListener::bind(name.as_os_str()).map_err(|err| { + ShellError::IOError { + msg: format!("failed to open socket for plugin: {err}"), + } + })?; + Ok(PreparedServerCommunication::LocalSocket { + name: name.clone(), + listener, + }) + } + } + } + + pub fn connect_as_client(&self) -> Result { + match self { + CommunicationMode::Stdio => Ok(ClientCommunicationIo::Stdio( + std::io::stdin(), + std::io::stdout(), + )), + #[cfg(feature = "local-socket")] + CommunicationMode::LocalSocket(name) => { + // Connect to the specified socket. + let get_socket = || { + use interprocess::local_socket as ls; + ls::LocalSocketStream::connect(name.as_os_str()) + .map_err(|err| ShellError::IOError { + msg: format!("failed to connect to socket: {err}"), + }) + .map(LocalSocketStream::from) + }; + // Reverse order from the server: read in, write out + let read_in = get_socket()?; + let write_out = get_socket()?; + Ok(ClientCommunicationIo::LocalSocket { read_in, write_out }) + } + } + } +} + +pub(crate) enum PreparedServerCommunication { + Stdio, + #[cfg(feature = "local-socket")] + LocalSocket { + #[cfg_attr(windows, allow(dead_code))] // not used on Windows + name: std::ffi::OsString, + listener: LocalSocketListener, + }, +} + +impl PreparedServerCommunication { + pub fn connect(&self, child: &mut Child) -> Result { + match self { + PreparedServerCommunication::Stdio => { + let stdin = child + .stdin + .take() + .ok_or_else(|| ShellError::PluginFailedToLoad { + msg: "Plugin missing stdin writer".into(), + })?; + + let stdout = child + .stdout + .take() + .ok_or_else(|| ShellError::PluginFailedToLoad { + msg: "Plugin missing stdout writer".into(), + })?; + + Ok(ServerCommunicationIo::Stdio(stdin, stdout)) + } + #[cfg(feature = "local-socket")] + PreparedServerCommunication::LocalSocket { listener, .. } => { + use std::time::{Duration, Instant}; + + const RETRY_PERIOD: Duration = Duration::from_millis(1); + const TIMEOUT: Duration = Duration::from_secs(10); + + let start = Instant::now(); + + // Use a loop to try to get two clients from the listener: one for read (the plugin + // output) and one for write (the plugin input) + listener.set_nonblocking(true)?; + let mut get_socket = || { + let mut result = None; + while let Ok(None) = child.try_wait() { + match listener.accept() { + Ok(stream) => { + // Success! But make sure the stream is in blocking mode. + stream.set_nonblocking(false)?; + result = Some(stream); + break; + } + Err(err) => { + if !is_would_block_err(&err) { + // `WouldBlock` is ok, just means it's not ready yet, but some other + // kind of error should be reported + return Err(err.into()); + } + } + } + if Instant::now().saturating_duration_since(start) > TIMEOUT { + return Err(ShellError::PluginFailedToLoad { + msg: "Plugin timed out while waiting to connect to socket".into(), + }); + } else { + std::thread::sleep(RETRY_PERIOD); + } + } + if let Some(stream) = result { + Ok(LocalSocketStream(stream)) + } else { + // The process may have exited + Err(ShellError::PluginFailedToLoad { + msg: "Plugin exited without connecting".into(), + }) + } + }; + // Input stream always comes before output + let write_in = get_socket()?; + let read_out = get_socket()?; + Ok(ServerCommunicationIo::LocalSocket { read_out, write_in }) + } + } + } +} + +impl Drop for PreparedServerCommunication { + fn drop(&mut self) { + match self { + #[cfg(all(unix, feature = "local-socket"))] + PreparedServerCommunication::LocalSocket { name: path, .. } => { + // Just try to remove the socket file, it's ok if this fails + let _ = std::fs::remove_file(path); + } + _ => (), + } + } +} + +pub(crate) enum ServerCommunicationIo { + Stdio(ChildStdin, ChildStdout), + #[cfg(feature = "local-socket")] + LocalSocket { + read_out: LocalSocketStream, + write_in: LocalSocketStream, + }, +} + +pub(crate) enum ClientCommunicationIo { + Stdio(Stdin, Stdout), + #[cfg(feature = "local-socket")] + LocalSocket { + read_in: LocalSocketStream, + write_out: LocalSocketStream, + }, +} diff --git a/crates/nu-plugin/src/plugin/context.rs b/crates/nu-plugin/src/plugin/context.rs index 61fdfe662a..75315520e1 100644 --- a/crates/nu-plugin/src/plugin/context.rs +++ b/crates/nu-plugin/src/plugin/context.rs @@ -8,7 +8,10 @@ use nu_protocol::{ use std::{ borrow::Cow, collections::HashMap, - sync::{atomic::AtomicBool, Arc}, + sync::{ + atomic::{AtomicBool, AtomicU32}, + Arc, + }, }; /// Object safe trait for abstracting operations required of the plugin context. @@ -16,8 +19,12 @@ use std::{ /// This is not a public API. #[doc(hidden)] pub trait PluginExecutionContext: Send + Sync { + /// A span pointing to the command being executed + fn span(&self) -> Span; /// The interrupt signal, if present fn ctrlc(&self) -> Option<&Arc>; + /// The pipeline externals state, for tracking the foreground process group, if present + fn pipeline_externals_state(&self) -> Option<&Arc<(AtomicU32, AtomicU32)>>; /// Get engine configuration fn get_config(&self) -> Result; /// Get plugin configuration @@ -75,10 +82,18 @@ impl<'a> PluginExecutionCommandContext<'a> { } impl<'a> PluginExecutionContext for PluginExecutionCommandContext<'a> { + fn span(&self) -> Span { + self.call.head + } + fn ctrlc(&self) -> Option<&Arc> { self.engine_state.ctrlc.as_ref() } + fn pipeline_externals_state(&self) -> Option<&Arc<(AtomicU32, AtomicU32)>> { + Some(&self.engine_state.pipeline_externals_state) + } + fn get_config(&self) -> Result { Ok(nu_engine::get_config(&self.engine_state, &self.stack)) } @@ -237,10 +252,18 @@ pub(crate) struct PluginExecutionBogusContext; #[cfg(test)] impl PluginExecutionContext for PluginExecutionBogusContext { + fn span(&self) -> Span { + Span::test_data() + } + fn ctrlc(&self) -> Option<&Arc> { None } + fn pipeline_externals_state(&self) -> Option<&Arc<(AtomicU32, AtomicU32)>> { + None + } + fn get_config(&self) -> Result { Err(ShellError::NushellFailed { msg: "get_config not implemented on bogus".into(), diff --git a/crates/nu-plugin/src/plugin/interface.rs b/crates/nu-plugin/src/plugin/interface.rs index 45fbd75c12..19e70fb99f 100644 --- a/crates/nu-plugin/src/plugin/interface.rs +++ b/crates/nu-plugin/src/plugin/interface.rs @@ -80,6 +80,11 @@ pub trait PluginWrite: Send + Sync { /// Flush any internal buffers, if applicable. fn flush(&self) -> Result<(), ShellError>; + + /// True if this output is stdout, so that plugins can avoid using stdout for their own purpose + fn is_stdout(&self) -> bool { + false + } } impl PluginWrite for (std::io::Stdout, E) @@ -96,6 +101,10 @@ where msg: err.to_string(), }) } + + fn is_stdout(&self) -> bool { + true + } } impl PluginWrite for (Mutex, E) @@ -131,6 +140,10 @@ where fn flush(&self) -> Result<(), ShellError> { (**self).flush() } + + fn is_stdout(&self) -> bool { + (**self).is_stdout() + } } /// An interface manager handles I/O and state management for communication between a plugin and the diff --git a/crates/nu-plugin/src/plugin/interface/engine.rs b/crates/nu-plugin/src/plugin/interface/engine.rs index e98eed06fb..e394b561e8 100644 --- a/crates/nu-plugin/src/plugin/interface/engine.rs +++ b/crates/nu-plugin/src/plugin/interface/engine.rs @@ -4,10 +4,13 @@ use super::{ stream::{StreamManager, StreamManagerHandle}, Interface, InterfaceManager, PipelineDataWriter, PluginRead, PluginWrite, Sequence, }; -use crate::protocol::{ - CallInfo, CustomValueOp, EngineCall, EngineCallId, EngineCallResponse, Ordering, PluginCall, - PluginCallId, PluginCallResponse, PluginCustomValue, PluginInput, PluginOption, PluginOutput, - ProtocolInfo, +use crate::{ + protocol::{ + CallInfo, CustomValueOp, EngineCall, EngineCallId, EngineCallResponse, Ordering, + PluginCall, PluginCallId, PluginCallResponse, PluginCustomValue, PluginInput, PluginOption, + PluginOutput, ProtocolInfo, + }, + util::Waitable, }; use nu_protocol::{ engine::Closure, Config, IntoInterruptiblePipelineData, LabeledError, ListStream, PipelineData, @@ -47,6 +50,8 @@ mod tests; /// Internal shared state between the manager and each interface. struct EngineInterfaceState { + /// Protocol version info, set after `Hello` received + protocol_info: Waitable>, /// Sequence for generating engine call ids engine_call_id_sequence: Sequence, /// Sequence for generating stream ids @@ -61,6 +66,7 @@ struct EngineInterfaceState { impl std::fmt::Debug for EngineInterfaceState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("EngineInterfaceState") + .field("protocol_info", &self.protocol_info) .field("engine_call_id_sequence", &self.engine_call_id_sequence) .field("stream_id_sequence", &self.stream_id_sequence) .field( @@ -91,8 +97,6 @@ pub struct EngineInterfaceManager { mpsc::Receiver<(EngineCallId, mpsc::Sender>)>, /// Manages stream messages and state stream_manager: StreamManager, - /// Protocol version info, set after `Hello` received - protocol_info: Option, } impl EngineInterfaceManager { @@ -102,6 +106,7 @@ impl EngineInterfaceManager { EngineInterfaceManager { state: Arc::new(EngineInterfaceState { + protocol_info: Waitable::new(), engine_call_id_sequence: Sequence::default(), stream_id_sequence: Sequence::default(), engine_call_subscription_sender: subscription_tx, @@ -112,7 +117,6 @@ impl EngineInterfaceManager { engine_call_subscriptions: BTreeMap::new(), engine_call_subscription_receiver: subscription_rx, stream_manager: StreamManager::new(), - protocol_info: None, } } @@ -228,12 +232,13 @@ impl InterfaceManager for EngineInterfaceManager { match input { PluginInput::Hello(info) => { + let info = Arc::new(info); + self.state.protocol_info.set(info.clone())?; + let local_info = ProtocolInfo::default(); if local_info.is_compatible_with(&info)? { - self.protocol_info = Some(info); Ok(()) } else { - self.protocol_info = None; Err(ShellError::PluginFailedToLoad { msg: format!( "Plugin is compiled for nushell version {}, \ @@ -243,7 +248,7 @@ impl InterfaceManager for EngineInterfaceManager { }) } } - _ if self.protocol_info.is_none() => { + _ if !self.state.protocol_info.is_set() => { // Must send protocol info first Err(ShellError::PluginFailedToLoad { msg: "Failed to receive initial Hello message. This engine might be too old" @@ -477,6 +482,15 @@ impl EngineInterface { }) } + /// Returns `true` if the plugin is communicating on stdio. When this is the case, stdin and + /// stdout should not be used by the plugin for other purposes. + /// + /// If the plugin can not be used without access to stdio, an error should be presented to the + /// user instead. + pub fn is_using_stdio(&self) -> bool { + self.state.writer.is_stdout() + } + /// Get the full shell configuration from the engine. As this is quite a large object, it is /// provided on request only. /// @@ -656,6 +670,43 @@ impl EngineInterface { } } + /// Returns a guard that will keep the plugin in the foreground as long as the guard is alive. + /// + /// Moving the plugin to the foreground is necessary for plugins that need to receive input and + /// signals directly from the terminal. + /// + /// The exact implementation is operating system-specific. On Unix, this ensures that the + /// plugin process becomes part of the process group controlling the terminal. + pub fn enter_foreground(&self) -> Result { + match self.engine_call(EngineCall::EnterForeground)? { + EngineCallResponse::Error(error) => Err(error), + EngineCallResponse::PipelineData(PipelineData::Value( + Value::Int { val: pgrp, .. }, + _, + )) => { + set_pgrp_from_enter_foreground(pgrp)?; + Ok(ForegroundGuard(Some(self.clone()))) + } + EngineCallResponse::PipelineData(PipelineData::Empty) => { + Ok(ForegroundGuard(Some(self.clone()))) + } + _ => Err(ShellError::PluginFailedToDecode { + msg: "Received unexpected response type for EngineCall::SetForeground".into(), + }), + } + } + + /// Internal: for exiting the foreground after `enter_foreground()`. Called from the guard. + fn leave_foreground(&self) -> Result<(), ShellError> { + match self.engine_call(EngineCall::LeaveForeground)? { + EngineCallResponse::Error(error) => Err(error), + EngineCallResponse::PipelineData(PipelineData::Empty) => Ok(()), + _ => Err(ShellError::PluginFailedToDecode { + msg: "Received unexpected response type for EngineCall::LeaveForeground".into(), + }), + } + } + /// Get the contents of a [`Span`] from the engine. /// /// This method returns `Vec` as it's possible for the matched span to not be a valid UTF-8 @@ -869,3 +920,69 @@ impl Interface for EngineInterface { } } } + +/// Keeps the plugin in the foreground as long as it is alive. +/// +/// Use [`.leave()`] to leave the foreground without ignoring the error. +pub struct ForegroundGuard(Option); + +impl ForegroundGuard { + // Should be called only once + fn leave_internal(&mut self) -> Result<(), ShellError> { + if let Some(interface) = self.0.take() { + // On Unix, we need to put ourselves back in our own process group + #[cfg(unix)] + { + use nix::unistd::{setpgid, Pid}; + // This should always succeed, frankly, but handle the error just in case + setpgid(Pid::from_raw(0), Pid::from_raw(0)).map_err(|err| ShellError::IOError { + msg: err.to_string(), + })?; + } + interface.leave_foreground()?; + } + Ok(()) + } + + /// Leave the foreground. In contrast to dropping the guard, this preserves the error (if any). + pub fn leave(mut self) -> Result<(), ShellError> { + let result = self.leave_internal(); + std::mem::forget(self); + result + } +} + +impl Drop for ForegroundGuard { + fn drop(&mut self) { + let _ = self.leave_internal(); + } +} + +#[cfg(unix)] +fn set_pgrp_from_enter_foreground(pgrp: i64) -> Result<(), ShellError> { + use nix::unistd::{setpgid, Pid}; + if let Ok(pgrp) = pgrp.try_into() { + setpgid(Pid::from_raw(0), Pid::from_raw(pgrp)).map_err(|err| ShellError::GenericError { + error: "Failed to set process group for foreground".into(), + msg: "".into(), + span: None, + help: Some(err.to_string()), + inner: vec![], + }) + } else { + Err(ShellError::NushellFailed { + msg: "Engine returned an invalid process group ID".into(), + }) + } +} + +#[cfg(not(unix))] +fn set_pgrp_from_enter_foreground(_pgrp: i64) -> Result<(), ShellError> { + Err(ShellError::NushellFailed { + msg: concat!( + "EnterForeground asked plugin to join process group, but not supported on ", + cfg!(target_os) + ) + .into(), + }) +} diff --git a/crates/nu-plugin/src/plugin/interface/engine/tests.rs b/crates/nu-plugin/src/plugin/interface/engine/tests.rs index 4693b1774b..572f6a39fe 100644 --- a/crates/nu-plugin/src/plugin/interface/engine/tests.rs +++ b/crates/nu-plugin/src/plugin/interface/engine/tests.rs @@ -15,9 +15,21 @@ use nu_protocol::{ }; use std::{ collections::HashMap, - sync::mpsc::{self, TryRecvError}, + sync::{ + mpsc::{self, TryRecvError}, + Arc, + }, }; +#[test] +fn is_using_stdio_is_false_for_test() { + let test = TestCase::new(); + let manager = test.engine(); + let interface = manager.get_interface(); + + assert!(!interface.is_using_stdio()); +} + #[test] fn manager_consume_all_consumes_messages() -> Result<(), ShellError> { let mut test = TestCase::new(); @@ -247,8 +259,9 @@ fn manager_consume_sets_protocol_info_on_hello() -> Result<(), ShellError> { manager.consume(PluginInput::Hello(info.clone()))?; let set_info = manager + .state .protocol_info - .as_ref() + .try_get()? .expect("protocol info not set"); assert_eq!(info.version, set_info.version); Ok(()) @@ -275,7 +288,7 @@ fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(), let mut manager = TestCase::new().engine(); // hello not set - assert!(manager.protocol_info.is_none()); + assert!(!manager.state.protocol_info.is_set()); let error = manager .consume(PluginInput::Drop(0)) @@ -285,10 +298,17 @@ fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(), Ok(()) } +fn set_default_protocol_info(manager: &mut EngineInterfaceManager) -> Result<(), ShellError> { + manager + .state + .protocol_info + .set(Arc::new(ProtocolInfo::default())) +} + #[test] fn manager_consume_goodbye_closes_plugin_call_channel() -> Result<(), ShellError> { let mut manager = TestCase::new().engine(); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = manager .take_plugin_call_receiver() @@ -307,7 +327,7 @@ fn manager_consume_goodbye_closes_plugin_call_channel() -> Result<(), ShellError #[test] fn manager_consume_call_signature_forwards_to_receiver_with_context() -> Result<(), ShellError> { let mut manager = TestCase::new().engine(); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = manager .take_plugin_call_receiver() @@ -327,7 +347,7 @@ fn manager_consume_call_signature_forwards_to_receiver_with_context() -> Result< #[test] fn manager_consume_call_run_forwards_to_receiver_with_context() -> Result<(), ShellError> { let mut manager = TestCase::new().engine(); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = manager .take_plugin_call_receiver() @@ -361,7 +381,7 @@ fn manager_consume_call_run_forwards_to_receiver_with_context() -> Result<(), Sh #[test] fn manager_consume_call_run_forwards_to_receiver_with_pipeline_data() -> Result<(), ShellError> { let mut manager = TestCase::new().engine(); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = manager .take_plugin_call_receiver() @@ -403,7 +423,7 @@ fn manager_consume_call_run_forwards_to_receiver_with_pipeline_data() -> Result< #[test] fn manager_consume_call_run_deserializes_custom_values_in_args() -> Result<(), ShellError> { let mut manager = TestCase::new().engine(); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = manager .take_plugin_call_receiver() @@ -469,7 +489,7 @@ fn manager_consume_call_run_deserializes_custom_values_in_args() -> Result<(), S fn manager_consume_call_custom_value_op_forwards_to_receiver_with_context() -> Result<(), ShellError> { let mut manager = TestCase::new().engine(); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = manager .take_plugin_call_receiver() @@ -509,7 +529,7 @@ fn manager_consume_call_custom_value_op_forwards_to_receiver_with_context() -> R fn manager_consume_engine_call_response_forwards_to_subscriber_with_pipeline_data( ) -> Result<(), ShellError> { let mut manager = TestCase::new().engine(); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = fake_engine_call(&mut manager, 0); diff --git a/crates/nu-plugin/src/plugin/interface/plugin.rs b/crates/nu-plugin/src/plugin/interface/plugin.rs index 93b04c8ef9..5be513a2f1 100644 --- a/crates/nu-plugin/src/plugin/interface/plugin.rs +++ b/crates/nu-plugin/src/plugin/interface/plugin.rs @@ -5,14 +5,14 @@ use super::{ Interface, InterfaceManager, PipelineDataWriter, PluginRead, PluginWrite, }; use crate::{ - plugin::{context::PluginExecutionContext, gc::PluginGc, PluginSource}, + plugin::{context::PluginExecutionContext, gc::PluginGc, process::PluginProcess, PluginSource}, protocol::{ CallInfo, CustomValueOp, EngineCall, EngineCallId, EngineCallResponse, Ordering, PluginCall, PluginCallId, PluginCallResponse, PluginCustomValue, PluginInput, PluginOption, PluginOutput, ProtocolInfo, StreamId, StreamMessage, }, sequence::Sequence, - util::with_custom_values_in, + util::{with_custom_values_in, Waitable}, }; use nu_protocol::{ ast::Operator, CustomValue, IntoInterruptiblePipelineData, IntoSpanned, ListStream, @@ -62,6 +62,10 @@ impl std::ops::Deref for Context { struct PluginInterfaceState { /// The source to be used for custom values coming from / going to the plugin source: Arc, + /// The plugin process being managed + process: Option, + /// Protocol version info, set after `Hello` received + protocol_info: Waitable>, /// Sequence for generating plugin call ids plugin_call_id_sequence: Sequence, /// Sequence for generating stream ids @@ -78,12 +82,14 @@ impl std::fmt::Debug for PluginInterfaceState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PluginInterfaceState") .field("source", &self.source) + .field("protocol_info", &self.protocol_info) .field("plugin_call_id_sequence", &self.plugin_call_id_sequence) .field("stream_id_sequence", &self.stream_id_sequence) .field( "plugin_call_subscription_sender", &self.plugin_call_subscription_sender, ) + .field("error", &self.error) .finish_non_exhaustive() } } @@ -132,8 +138,6 @@ pub struct PluginInterfaceManager { state: Arc, /// Manages stream messages and state stream_manager: StreamManager, - /// Protocol version info, set after `Hello` received - protocol_info: Option, /// State related to plugin calls plugin_call_states: BTreeMap, /// Receiver for plugin call subscriptions @@ -149,6 +153,7 @@ pub struct PluginInterfaceManager { impl PluginInterfaceManager { pub fn new( source: Arc, + pid: Option, writer: impl PluginWrite + 'static, ) -> PluginInterfaceManager { let (subscription_tx, subscription_rx) = mpsc::channel(); @@ -156,6 +161,8 @@ impl PluginInterfaceManager { PluginInterfaceManager { state: Arc::new(PluginInterfaceState { source, + process: pid.map(PluginProcess::new), + protocol_info: Waitable::new(), plugin_call_id_sequence: Sequence::default(), stream_id_sequence: Sequence::default(), plugin_call_subscription_sender: subscription_tx, @@ -163,7 +170,6 @@ impl PluginInterfaceManager { writer: Box::new(writer), }), stream_manager: StreamManager::new(), - protocol_info: None, plugin_call_states: BTreeMap::new(), plugin_call_subscription_receiver: subscription_rx, plugin_call_input_streams: BTreeMap::new(), @@ -289,9 +295,10 @@ impl PluginInterfaceManager { })?; // Generate the state needed to handle engine calls - let current_call_state = CurrentCallState { + let mut current_call_state = CurrentCallState { context_tx: None, keep_plugin_custom_values_tx: Some(state.keep_plugin_custom_values.0.clone()), + entered_foreground: false, }; let handler = move || { @@ -308,7 +315,7 @@ impl PluginInterfaceManager { if let Err(err) = interface.handle_engine_call( engine_call_id, engine_call, - ¤t_call_state, + &mut current_call_state, context.as_deref_mut(), ) { log::warn!( @@ -453,12 +460,13 @@ impl InterfaceManager for PluginInterfaceManager { match input { PluginOutput::Hello(info) => { + let info = Arc::new(info); + self.state.protocol_info.set(info.clone())?; + let local_info = ProtocolInfo::default(); if local_info.is_compatible_with(&info)? { - self.protocol_info = Some(info); Ok(()) } else { - self.protocol_info = None; Err(ShellError::PluginFailedToLoad { msg: format!( "Plugin `{}` is compiled for nushell version {}, \ @@ -470,7 +478,7 @@ impl InterfaceManager for PluginInterfaceManager { }) } } - _ if self.protocol_info.is_none() => { + _ if !self.state.protocol_info.is_set() => { // Must send protocol info first Err(ShellError::PluginFailedToLoad { msg: format!( @@ -613,6 +621,16 @@ pub struct PluginInterface { } impl PluginInterface { + /// Get the process ID for the plugin, if known. + pub fn pid(&self) -> Option { + self.state.process.as_ref().map(|p| p.pid()) + } + + /// Get the protocol info for the plugin. Will block to receive `Hello` if not received yet. + pub fn protocol_info(&self) -> Result, ShellError> { + self.state.protocol_info.get() + } + /// Write the protocol info. This should be done after initialization pub fn hello(&self) -> Result<(), ShellError> { self.write(PluginInput::Hello(ProtocolInfo::default()))?; @@ -673,6 +691,7 @@ impl PluginInterface { let state = CurrentCallState { context_tx: Some(context_tx), keep_plugin_custom_values_tx: Some(keep_plugin_custom_values.0.clone()), + entered_foreground: false, }; // Prepare the call with the state. @@ -767,12 +786,22 @@ impl PluginInterface { &self, rx: mpsc::Receiver, mut context: Option<&mut (dyn PluginExecutionContext + '_)>, - state: CurrentCallState, + mut state: CurrentCallState, ) -> Result, ShellError> { // Handle message from receiver for msg in rx { match msg { ReceivedPluginCallMessage::Response(resp) => { + if state.entered_foreground { + // Make the plugin leave the foreground on return, even if it's a stream + if let Some(context) = context.as_deref_mut() { + if let Err(err) = + set_foreground(self.state.process.as_ref(), context, false) + { + log::warn!("Failed to leave foreground state on exit: {err:?}"); + } + } + } if resp.has_stream() { // If the response has a stream, we need to register the context if let Some(context) = context { @@ -790,7 +819,7 @@ impl PluginInterface { self.handle_engine_call( engine_call_id, engine_call, - &state, + &mut state, context.as_deref_mut(), )?; } @@ -807,11 +836,12 @@ impl PluginInterface { &self, engine_call_id: EngineCallId, engine_call: EngineCall, - state: &CurrentCallState, + state: &mut CurrentCallState, context: Option<&mut (dyn PluginExecutionContext + '_)>, ) -> Result<(), ShellError> { - let resp = - handle_engine_call(engine_call, context).unwrap_or_else(EngineCallResponse::Error); + let process = self.state.process.as_ref(); + let resp = handle_engine_call(engine_call, state, context, process) + .unwrap_or_else(EngineCallResponse::Error); // Handle stream let mut writer = None; let resp = resp @@ -1057,6 +1087,9 @@ pub struct CurrentCallState { /// Sender for a channel that retains plugin custom values that need to stay alive for the /// duration of a plugin call. keep_plugin_custom_values_tx: Option>, + /// The plugin call entered the foreground: this should be cleaned up automatically when the + /// plugin call returns. + entered_foreground: bool, } impl CurrentCallState { @@ -1144,7 +1177,9 @@ impl CurrentCallState { /// Handle an engine call. pub(crate) fn handle_engine_call( call: EngineCall, + state: &mut CurrentCallState, context: Option<&mut (dyn PluginExecutionContext + '_)>, + process: Option<&PluginProcess>, ) -> Result, ShellError> { let call_name = call.name(); @@ -1190,6 +1225,16 @@ pub(crate) fn handle_engine_call( help.item, help.span, ))) } + EngineCall::EnterForeground => { + let resp = set_foreground(process, context, true)?; + state.entered_foreground = true; + Ok(resp) + } + EngineCall::LeaveForeground => { + let resp = set_foreground(process, context, false)?; + state.entered_foreground = false; + Ok(resp) + } EngineCall::GetSpanContents(span) => { let contents = context.get_span_contents(span)?; Ok(EngineCallResponse::value(Value::binary( @@ -1208,3 +1253,39 @@ pub(crate) fn handle_engine_call( .map(EngineCallResponse::PipelineData), } } + +/// Implements enter/exit foreground +fn set_foreground( + process: Option<&PluginProcess>, + context: &mut dyn PluginExecutionContext, + enter: bool, +) -> Result, ShellError> { + if let Some(process) = process { + if let Some(pipeline_externals_state) = context.pipeline_externals_state() { + if enter { + let pgrp = process.enter_foreground(context.span(), pipeline_externals_state)?; + Ok(pgrp.map_or_else(EngineCallResponse::empty, |id| { + EngineCallResponse::value(Value::int(id as i64, context.span())) + })) + } else { + process.exit_foreground()?; + Ok(EngineCallResponse::empty()) + } + } else { + // This should always be present on a real context + Err(ShellError::NushellFailed { + msg: "missing required pipeline_externals_state from context \ + for entering foreground" + .into(), + }) + } + } else { + Err(ShellError::GenericError { + error: "Can't manage plugin process to enter foreground".into(), + msg: "the process ID for this plugin is unknown".into(), + span: Some(context.span()), + help: Some("the plugin may be running in a test".into()), + inner: vec![], + }) + } +} diff --git a/crates/nu-plugin/src/plugin/interface/plugin/tests.rs b/crates/nu-plugin/src/plugin/interface/plugin/tests.rs index 290b04a3ce..77dcc6d4d2 100644 --- a/crates/nu-plugin/src/plugin/interface/plugin/tests.rs +++ b/crates/nu-plugin/src/plugin/interface/plugin/tests.rs @@ -279,8 +279,9 @@ fn manager_consume_sets_protocol_info_on_hello() -> Result<(), ShellError> { manager.consume(PluginOutput::Hello(info.clone()))?; let set_info = manager + .state .protocol_info - .as_ref() + .try_get()? .expect("protocol info not set"); assert_eq!(info.version, set_info.version); Ok(()) @@ -307,7 +308,7 @@ fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(), let mut manager = TestCase::new().plugin("test"); // hello not set - assert!(manager.protocol_info.is_none()); + assert!(!manager.state.protocol_info.is_set()); let error = manager .consume(PluginOutput::Drop(0)) @@ -317,11 +318,18 @@ fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(), Ok(()) } +fn set_default_protocol_info(manager: &mut PluginInterfaceManager) -> Result<(), ShellError> { + manager + .state + .protocol_info + .set(Arc::new(ProtocolInfo::default())) +} + #[test] fn manager_consume_call_response_forwards_to_subscriber_with_pipeline_data( ) -> Result<(), ShellError> { let mut manager = TestCase::new().plugin("test"); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = fake_plugin_call(&mut manager, 0); @@ -359,7 +367,7 @@ fn manager_consume_call_response_forwards_to_subscriber_with_pipeline_data( #[test] fn manager_consume_call_response_registers_streams() -> Result<(), ShellError> { let mut manager = TestCase::new().plugin("test"); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; for n in [0, 1] { fake_plugin_call(&mut manager, n); @@ -428,7 +436,7 @@ fn manager_consume_call_response_registers_streams() -> Result<(), ShellError> { fn manager_consume_engine_call_forwards_to_subscriber_with_pipeline_data() -> Result<(), ShellError> { let mut manager = TestCase::new().plugin("test"); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = fake_plugin_call(&mut manager, 37); @@ -480,7 +488,7 @@ fn manager_consume_engine_call_forwards_to_subscriber_with_pipeline_data() -> Re fn manager_handle_engine_call_after_response_received() -> Result<(), ShellError> { let test = TestCase::new(); let mut manager = test.plugin("test"); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let (context_tx, context_rx) = mpsc::channel(); @@ -584,7 +592,7 @@ fn manager_send_plugin_call_response_removes_context_only_if_no_streams_to_read( #[test] fn manager_consume_stream_end_removes_context_only_if_last_stream() -> Result<(), ShellError> { let mut manager = TestCase::new().plugin("test"); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; for n in [1, 2] { manager.plugin_call_states.insert( @@ -1274,8 +1282,8 @@ fn prepare_custom_value_sends_to_keep_channel_if_drop_notify() -> Result<(), She let source = Arc::new(PluginSource::new_fake("test")); let (tx, rx) = mpsc::channel(); let state = CurrentCallState { - context_tx: None, keep_plugin_custom_values_tx: Some(tx), + ..Default::default() }; // Try with a custom val that has drop check set let mut drop_val = PluginCustomValue::serialize_from_custom_value(&DropCustomVal, span)? diff --git a/crates/nu-plugin/src/plugin/interface/test_util.rs b/crates/nu-plugin/src/plugin/interface/test_util.rs index b701ff7792..a2acec1e7e 100644 --- a/crates/nu-plugin/src/plugin/interface/test_util.rs +++ b/crates/nu-plugin/src/plugin/interface/test_util.rs @@ -128,7 +128,7 @@ impl TestCase { impl TestCase { /// Create a new [`PluginInterfaceManager`] that writes to this test case. pub(crate) fn plugin(&self, name: &str) -> PluginInterfaceManager { - PluginInterfaceManager::new(PluginSource::new_fake(name).into(), self.clone()) + PluginInterfaceManager::new(PluginSource::new_fake(name).into(), None, self.clone()) } } diff --git a/crates/nu-plugin/src/plugin/mod.rs b/crates/nu-plugin/src/plugin/mod.rs index d314f3c54a..6ab62eb079 100644 --- a/crates/nu-plugin/src/plugin/mod.rs +++ b/crates/nu-plugin/src/plugin/mod.rs @@ -8,12 +8,11 @@ use std::{ cmp::Ordering, collections::HashMap, env, - ffi::OsStr, - fmt::Write, - io::{BufReader, BufWriter, Read, Write as WriteTrait}, + ffi::OsString, + io::{BufReader, BufWriter}, ops::Deref, path::Path, - process::{Child, ChildStdout, Command as CommandSys, Stdio}, + process::{Child, Command as CommandSys}, sync::{ mpsc::{self, TrySendError}, Arc, Mutex, @@ -34,14 +33,23 @@ use std::os::unix::process::CommandExt; use std::os::windows::process::CommandExt; pub use self::interface::{PluginRead, PluginWrite}; -use self::{command::render_examples, gc::PluginGc}; +use self::{ + command::render_examples, + communication_mode::{ + ClientCommunicationIo, CommunicationMode, PreparedServerCommunication, + ServerCommunicationIo, + }, + gc::PluginGc, +}; mod command; +mod communication_mode; mod context; mod declaration; mod gc; mod interface; mod persistent; +mod process; mod source; pub use command::{create_plugin_signature, PluginCommand, SimplePluginCommand}; @@ -84,62 +92,51 @@ pub trait PluginEncoder: Encoder + Encoder { fn name(&self) -> &str; } -fn create_command(path: &Path, shell: Option<&Path>) -> CommandSys { - log::trace!("Starting plugin: {path:?}, shell = {shell:?}"); +fn create_command(path: &Path, mut shell: Option<&Path>, mode: &CommunicationMode) -> CommandSys { + log::trace!("Starting plugin: {path:?}, shell = {shell:?}, mode = {mode:?}"); - // There is only one mode supported at the moment, but the idea is that future - // communication methods could be supported if desirable - let mut input_arg = Some("--stdio"); + let mut shell_args = vec![]; - let mut process = match (path.extension(), shell) { - (_, Some(shell)) => { - let mut process = std::process::Command::new(shell); - process.arg(path); - - process - } - (Some(extension), None) => { - let (shell, command_switch) = match extension.to_str() { - Some("cmd") | Some("bat") => (Some("cmd"), Some("/c")), - Some("sh") => (Some("sh"), Some("-c")), - Some("py") => (Some("python"), None), - _ => (None, None), - }; - - match (shell, command_switch) { - (Some(shell), Some(command_switch)) => { - let mut process = std::process::Command::new(shell); - process.arg(command_switch); - // If `command_switch` is set, we need to pass the path + arg as one argument - // e.g. sh -c "nu_plugin_inc --stdio" - let mut combined = path.as_os_str().to_owned(); - if let Some(arg) = input_arg.take() { - combined.push(OsStr::new(" ")); - combined.push(OsStr::new(arg)); - } - process.arg(combined); - - process + if shell.is_none() { + // We only have to do this for things that are not executable by Rust's Command API on + // Windows. They do handle bat/cmd files for us, helpfully. + // + // Also include anything that wouldn't be executable with a shebang, like JAR files. + shell = match path.extension().and_then(|e| e.to_str()) { + Some("sh") => { + if cfg!(unix) { + // We don't want to override what might be in the shebang if this is Unix, since + // some scripts will have a shebang specifying bash even if they're .sh + None + } else { + Some(Path::new("sh")) } - (Some(shell), None) => { - let mut process = std::process::Command::new(shell); - process.arg(path); - - process - } - _ => std::process::Command::new(path), } - } - (None, None) => std::process::Command::new(path), - }; - - // Pass input_arg, unless we consumed it already - if let Some(input_arg) = input_arg { - process.arg(input_arg); + Some("nu") => Some(Path::new("nu")), + Some("py") => Some(Path::new("python")), + Some("rb") => Some(Path::new("ruby")), + Some("jar") => { + shell_args.push("-jar"); + Some(Path::new("java")) + } + _ => None, + }; } - // Both stdout and stdin are piped so we can receive information from the plugin - process.stdout(Stdio::piped()).stdin(Stdio::piped()); + let mut process = if let Some(shell) = shell { + let mut process = std::process::Command::new(shell); + process.args(shell_args); + process.arg(path); + + process + } else { + std::process::Command::new(path) + }; + + process.args(mode.args()); + + // Setup I/O according to the communication mode + mode.setup_command_io(&mut process); // The plugin should be run in a new process group to prevent Ctrl-C from stopping it #[cfg(unix)] @@ -158,29 +155,53 @@ fn create_command(path: &Path, shell: Option<&Path>) -> CommandSys { fn make_plugin_interface( mut child: Child, + comm: PreparedServerCommunication, source: Arc, + pid: Option, gc: Option, ) -> Result { - let stdin = child - .stdin - .take() - .ok_or_else(|| ShellError::PluginFailedToLoad { - msg: "Plugin missing stdin writer".into(), - })?; + match comm.connect(&mut child)? { + ServerCommunicationIo::Stdio(stdin, stdout) => make_plugin_interface_with_streams( + stdout, + stdin, + move || { + let _ = child.wait(); + }, + source, + pid, + gc, + ), + #[cfg(feature = "local-socket")] + ServerCommunicationIo::LocalSocket { read_out, write_in } => { + make_plugin_interface_with_streams( + read_out, + write_in, + move || { + let _ = child.wait(); + }, + source, + pid, + gc, + ) + } + } +} - let mut stdout = child - .stdout - .take() - .ok_or_else(|| ShellError::PluginFailedToLoad { - msg: "Plugin missing stdout writer".into(), - })?; +fn make_plugin_interface_with_streams( + mut reader: impl std::io::Read + Send + 'static, + writer: impl std::io::Write + Send + 'static, + after_close: impl FnOnce() + Send + 'static, + source: Arc, + pid: Option, + gc: Option, +) -> Result { + let encoder = get_plugin_encoding(&mut reader)?; - let encoder = get_plugin_encoding(&mut stdout)?; + let reader = BufReader::with_capacity(OUTPUT_BUFFER_SIZE, reader); + let writer = BufWriter::with_capacity(OUTPUT_BUFFER_SIZE, writer); - let reader = BufReader::with_capacity(OUTPUT_BUFFER_SIZE, stdout); - let writer = BufWriter::with_capacity(OUTPUT_BUFFER_SIZE, stdin); - - let mut manager = PluginInterfaceManager::new(source.clone(), (Mutex::new(writer), encoder)); + let mut manager = + PluginInterfaceManager::new(source.clone(), pid, (Mutex::new(writer), encoder)); manager.set_garbage_collector(gc); let interface = manager.get_interface(); @@ -198,10 +219,10 @@ fn make_plugin_interface( if let Err(err) = manager.consume_all((reader, encoder)) { log::warn!("Error in PluginInterfaceManager: {err}"); } - // If the loop has ended, drop the manager so everyone disconnects and then wait for the - // child to exit + // If the loop has ended, drop the manager so everyone disconnects and then run + // after_close drop(manager); - let _ = child.wait(); + after_close(); }) .map_err(|err| ShellError::PluginFailedToLoad { msg: format!("Failed to spawn thread for plugin: {err}"), @@ -211,15 +232,10 @@ fn make_plugin_interface( } #[doc(hidden)] // Note: not for plugin authors / only used in nu-parser -pub fn get_signature( +pub fn get_signature( plugin: Arc, - envs: impl FnOnce() -> Result, -) -> Result, ShellError> -where - E: IntoIterator, - K: AsRef, - V: AsRef, -{ + envs: impl FnOnce() -> Result, ShellError>, +) -> Result, ShellError> { plugin.get(envs)?.get_signature() } @@ -412,9 +428,7 @@ pub trait Plugin: Sync { /// } /// ``` pub fn serve_plugin(plugin: &impl Plugin, encoder: impl PluginEncoder + 'static) { - let mut args = env::args().skip(1); - let number_of_args = args.len(); - let first_arg = args.next(); + let args: Vec = env::args_os().skip(1).collect(); // Determine the plugin name, for errors let exe = std::env::current_exe().ok(); @@ -430,18 +444,26 @@ pub fn serve_plugin(plugin: &impl Plugin, encoder: impl PluginEncoder + 'static) }) .unwrap_or_else(|| "(unknown)".into()); - if number_of_args == 0 - || first_arg - .as_ref() - .is_some_and(|arg| arg == "-h" || arg == "--help") - { + if args.is_empty() || args[0] == "-h" || args[0] == "--help" { print_help(plugin, encoder); std::process::exit(0) } - // Must pass --stdio for plugin execution. Any other arg is an error to give us options in the - // future. - if number_of_args > 1 || !first_arg.is_some_and(|arg| arg == "--stdio") { + // Implement different communication modes: + let mode = if args[0] == "--stdio" && args.len() == 1 { + // --stdio always supported. + CommunicationMode::Stdio + } else if args[0] == "--local-socket" && args.len() == 2 { + #[cfg(feature = "local-socket")] + { + CommunicationMode::LocalSocket((&args[1]).into()) + } + #[cfg(not(feature = "local-socket"))] + { + eprintln!("{plugin_name}: local socket mode is not supported"); + std::process::exit(1); + } + } else { eprintln!( "{}: This plugin must be run from within Nushell.", env::current_exe() @@ -453,34 +475,42 @@ pub fn serve_plugin(plugin: &impl Plugin, encoder: impl PluginEncoder + 'static) version of nushell you are using." ); std::process::exit(1) - } - - // tell nushell encoding. - // - // 1 byte - // encoding format: | content-length | content | - let mut stdout = std::io::stdout(); - { - let encoding = encoder.name(); - let length = encoding.len() as u8; - let mut encoding_content: Vec = encoding.as_bytes().to_vec(); - encoding_content.insert(0, length); - stdout - .write_all(&encoding_content) - .expect("Failed to tell nushell my encoding"); - stdout - .flush() - .expect("Failed to tell nushell my encoding when flushing stdout"); - } + }; let encoder_clone = encoder.clone(); - let result = serve_plugin_io( - plugin, - &plugin_name, - move || (std::io::stdin().lock(), encoder_clone), - move || (std::io::stdout(), encoder), - ); + let result = match mode.connect_as_client() { + Ok(ClientCommunicationIo::Stdio(stdin, mut stdout)) => { + tell_nushell_encoding(&mut stdout, &encoder).expect("failed to tell nushell encoding"); + serve_plugin_io( + plugin, + &plugin_name, + move || (stdin.lock(), encoder_clone), + move || (stdout, encoder), + ) + } + #[cfg(feature = "local-socket")] + Ok(ClientCommunicationIo::LocalSocket { + read_in, + mut write_out, + }) => { + tell_nushell_encoding(&mut write_out, &encoder) + .expect("failed to tell nushell encoding"); + + let read = BufReader::with_capacity(OUTPUT_BUFFER_SIZE, read_in); + let write = Mutex::new(BufWriter::with_capacity(OUTPUT_BUFFER_SIZE, write_out)); + serve_plugin_io( + plugin, + &plugin_name, + move || (read, encoder_clone), + move || (write, encoder), + ) + } + Err(err) => { + eprintln!("{plugin_name}: failed to connect: {err:?}"); + std::process::exit(1); + } + }; match result { Ok(()) => (), @@ -493,6 +523,22 @@ pub fn serve_plugin(plugin: &impl Plugin, encoder: impl PluginEncoder + 'static) } } +fn tell_nushell_encoding( + writer: &mut impl std::io::Write, + encoder: &impl PluginEncoder, +) -> Result<(), std::io::Error> { + // tell nushell encoding. + // + // 1 byte + // encoding format: | content-length | content | + let encoding = encoder.name(); + let length = encoding.len() as u8; + let mut encoding_content: Vec = encoding.as_bytes().to_vec(); + encoding_content.insert(0, length); + writer.write_all(&encoding_content)?; + writer.flush() +} + /// An error from [`serve_plugin_io()`] #[derive(Debug, Error)] pub enum ServePluginError { @@ -765,6 +811,8 @@ fn custom_value_op( } fn print_help(plugin: &impl Plugin, encoder: impl PluginEncoder) { + use std::fmt::Write; + println!("Nushell Plugin"); println!("Encoder: {}", encoder.name()); @@ -831,7 +879,9 @@ fn print_help(plugin: &impl Plugin, encoder: impl PluginEncoder) { println!("{help}") } -pub fn get_plugin_encoding(child_stdout: &mut ChildStdout) -> Result { +pub fn get_plugin_encoding( + child_stdout: &mut impl std::io::Read, +) -> Result { let mut length_buf = [0u8; 1]; child_stdout .read_exact(&mut length_buf) diff --git a/crates/nu-plugin/src/plugin/persistent.rs b/crates/nu-plugin/src/plugin/persistent.rs index a2f3cf0733..5c08add437 100644 --- a/crates/nu-plugin/src/plugin/persistent.rs +++ b/crates/nu-plugin/src/plugin/persistent.rs @@ -1,10 +1,13 @@ -use super::{create_command, gc::PluginGc, make_plugin_interface, PluginInterface, PluginSource}; +use super::{ + communication_mode::CommunicationMode, create_command, gc::PluginGc, make_plugin_interface, + PluginInterface, PluginSource, +}; use nu_protocol::{ engine::{EngineState, Stack}, PluginGcConfig, PluginIdentity, RegisteredPlugin, ShellError, }; use std::{ - ffi::OsStr, + collections::HashMap, sync::{Arc, Mutex}, }; @@ -28,14 +31,21 @@ pub struct PersistentPlugin { struct MutableState { /// Reference to the plugin if running running: Option, + /// Plugin's preferred communication mode (if known) + preferred_mode: Option, /// Garbage collector config gc_config: PluginGcConfig, } +#[derive(Debug, Clone, Copy)] +enum PreferredCommunicationMode { + Stdio, + #[cfg(feature = "local-socket")] + LocalSocket, +} + #[derive(Debug)] struct RunningPlugin { - /// Process ID of the running plugin - pid: u32, /// Interface (which can be cloned) to the running plugin interface: PluginInterface, /// Garbage collector for the plugin @@ -49,6 +59,7 @@ impl PersistentPlugin { identity, mutable: Mutex::new(MutableState { running: None, + preferred_mode: None, gc_config, }), } @@ -58,15 +69,10 @@ impl PersistentPlugin { /// /// Will call `envs` to get environment variables to spawn the plugin if the plugin needs to be /// spawned. - pub(crate) fn get( + pub(crate) fn get( self: Arc, - envs: impl FnOnce() -> Result, - ) -> Result - where - E: IntoIterator, - K: AsRef, - V: AsRef, - { + envs: impl FnOnce() -> Result, ShellError>, + ) -> Result { let mut mutable = self.mutable.lock().map_err(|_| ShellError::NushellFailed { msg: format!( "plugin `{}` mutex poisoned, probably panic during spawn", @@ -78,28 +84,70 @@ impl PersistentPlugin { // It exists, so just clone the interface Ok(running.interface.clone()) } else { - // Try to spawn, and then store the spawned plugin if we were successful. + // Try to spawn. On success, `mutable.running` should have been set to the new running + // plugin by `spawn()` so we just then need to clone the interface from there. // // We hold the lock the whole time to prevent others from trying to spawn and ending // up with duplicate plugins // // TODO: We should probably store the envs somewhere, in case we have to launch without // envs (e.g. from a custom value) - let new_running = self.clone().spawn(envs()?, &mutable.gc_config)?; - let interface = new_running.interface.clone(); - mutable.running = Some(new_running); - Ok(interface) + let envs = envs()?; + let result = self.clone().spawn(&envs, &mut mutable); + + // Check if we were using an alternate communication mode and may need to fall back to + // stdio. + if result.is_err() + && !matches!( + mutable.preferred_mode, + Some(PreferredCommunicationMode::Stdio) + ) + { + log::warn!("{}: Trying again with stdio communication because mode {:?} failed with {result:?}", + self.identity.name(), + mutable.preferred_mode); + // Reset to stdio and try again, but this time don't catch any error + mutable.preferred_mode = Some(PreferredCommunicationMode::Stdio); + self.clone().spawn(&envs, &mut mutable)?; + } + + Ok(mutable + .running + .as_ref() + .ok_or_else(|| ShellError::NushellFailed { + msg: "spawn() succeeded but didn't set interface".into(), + })? + .interface + .clone()) } } - /// Run the plugin command, then set up and return [`RunningPlugin`]. + /// Run the plugin command, then set up and set `mutable.running` to the new running plugin. fn spawn( self: Arc, - envs: impl IntoIterator, impl AsRef)>, - gc_config: &PluginGcConfig, - ) -> Result { + envs: &HashMap, + mutable: &mut MutableState, + ) -> Result<(), ShellError> { + // Make sure `running` is set to None to begin + if let Some(running) = mutable.running.take() { + // Stop the GC if there was a running plugin + running.gc.stop_tracking(); + } + let source_file = self.identity.filename(); - let mut plugin_cmd = create_command(source_file, self.identity.shell()); + + // Determine the mode to use based on the preferred mode + let mode = match mutable.preferred_mode { + // If not set, we try stdio first and then might retry if another mode is supported + Some(PreferredCommunicationMode::Stdio) | None => CommunicationMode::Stdio, + // Local socket only if enabled + #[cfg(feature = "local-socket")] + Some(PreferredCommunicationMode::LocalSocket) => { + CommunicationMode::local_socket(source_file) + } + }; + + let mut plugin_cmd = create_command(source_file, self.identity.shell(), &mode); // We need the current environment variables for `python` based plugins // Or we'll likely have a problem when a plugin is implemented in a virtual Python environment. @@ -107,6 +155,9 @@ impl PersistentPlugin { let program_name = plugin_cmd.get_program().to_os_string().into_string(); + // Before running the command, prepare communication + let comm = mode.serve()?; + // Run the plugin command let child = plugin_cmd.spawn().map_err(|err| { let error_msg = match err.kind() { @@ -126,13 +177,64 @@ impl PersistentPlugin { })?; // Start the plugin garbage collector - let gc = PluginGc::new(gc_config.clone(), &self)?; + let gc = PluginGc::new(mutable.gc_config.clone(), &self)?; let pid = child.id(); - let interface = - make_plugin_interface(child, Arc::new(PluginSource::new(self)), Some(gc.clone()))?; + let interface = make_plugin_interface( + child, + comm, + Arc::new(PluginSource::new(self.clone())), + Some(pid), + Some(gc.clone()), + )?; - Ok(RunningPlugin { pid, interface, gc }) + // If our current preferred mode is None, check to see if the plugin might support another + // mode. If so, retry spawn() with that mode + #[cfg(feature = "local-socket")] + if mutable.preferred_mode.is_none() + && interface + .protocol_info()? + .supports_feature(&crate::protocol::Feature::LocalSocket) + { + log::trace!( + "{}: Attempting to upgrade to local socket mode", + self.identity.name() + ); + // Stop the GC we just created from tracking so that we don't accidentally try to + // stop the new plugin + gc.stop_tracking(); + // Set the mode and try again + mutable.preferred_mode = Some(PreferredCommunicationMode::LocalSocket); + return self.spawn(envs, mutable); + } + + mutable.running = Some(RunningPlugin { interface, gc }); + Ok(()) + } + + fn stop_internal(&self, reset: bool) -> Result<(), ShellError> { + let mut mutable = self.mutable.lock().map_err(|_| ShellError::NushellFailed { + msg: format!( + "plugin `{}` mutable mutex poisoned, probably panic during spawn", + self.identity.name() + ), + })?; + + // If the plugin is running, stop its GC, so that the GC doesn't accidentally try to stop + // a future plugin + if let Some(ref running) = mutable.running { + running.gc.stop_tracking(); + } + + // We don't try to kill the process or anything, we just drop the RunningPlugin. It should + // exit soon after + mutable.running = None; + + // If this is a reset, we should also reset other learned attributes like preferred_mode + if reset { + mutable.preferred_mode = None; + } + Ok(()) } } @@ -155,27 +257,15 @@ impl RegisteredPlugin for PersistentPlugin { self.mutable .lock() .ok() - .and_then(|r| r.running.as_ref().map(|r| r.pid)) + .and_then(|r| r.running.as_ref().and_then(|r| r.interface.pid())) } fn stop(&self) -> Result<(), ShellError> { - let mut mutable = self.mutable.lock().map_err(|_| ShellError::NushellFailed { - msg: format!( - "plugin `{}` mutable mutex poisoned, probably panic during spawn", - self.identity.name() - ), - })?; + self.stop_internal(false) + } - // If the plugin is running, stop its GC, so that the GC doesn't accidentally try to stop - // a future plugin - if let Some(ref running) = mutable.running { - running.gc.stop_tracking(); - } - - // We don't try to kill the process or anything, we just drop the RunningPlugin. It should - // exit soon after - mutable.running = None; - Ok(()) + fn reset(&self) -> Result<(), ShellError> { + self.stop_internal(true) } fn set_gc_config(&self, gc_config: &PluginGcConfig) { @@ -214,11 +304,12 @@ pub trait GetPlugin: RegisteredPlugin { impl GetPlugin for PersistentPlugin { fn get_plugin( self: Arc, - context: Option<(&EngineState, &mut Stack)>, + mut context: Option<(&EngineState, &mut Stack)>, ) -> Result { self.get(|| { // Get envs from the context if provided. let envs = context + .as_mut() .map(|(engine_state, stack)| { // We need the current environment variables for `python` based plugins. Or // we'll likely have a problem when a plugin is implemented in a virtual Python @@ -228,7 +319,7 @@ impl GetPlugin for PersistentPlugin { }) .transpose()?; - Ok(envs.into_iter().flatten()) + Ok(envs.unwrap_or_default()) }) } } diff --git a/crates/nu-plugin/src/plugin/process.rs b/crates/nu-plugin/src/plugin/process.rs new file mode 100644 index 0000000000..a87e2ea22b --- /dev/null +++ b/crates/nu-plugin/src/plugin/process.rs @@ -0,0 +1,90 @@ +use std::sync::{atomic::AtomicU32, Arc, Mutex, MutexGuard}; + +use nu_protocol::{ShellError, Span}; +use nu_system::ForegroundGuard; + +/// Provides a utility interface for a plugin interface to manage the process the plugin is running +/// in. +#[derive(Debug)] +pub(crate) struct PluginProcess { + pid: u32, + mutable: Mutex, +} + +#[derive(Debug)] +struct MutablePart { + foreground_guard: Option, +} + +impl PluginProcess { + /// Manage a plugin process. + pub(crate) fn new(pid: u32) -> PluginProcess { + PluginProcess { + pid, + mutable: Mutex::new(MutablePart { + foreground_guard: None, + }), + } + } + + /// The process ID of the plugin. + pub(crate) fn pid(&self) -> u32 { + self.pid + } + + fn lock_mutable(&self) -> Result, ShellError> { + self.mutable.lock().map_err(|_| ShellError::NushellFailed { + msg: "the PluginProcess mutable lock has been poisoned".into(), + }) + } + + /// Move the plugin process to the foreground. See [`ForegroundGuard::new`]. + /// + /// This produces an error if the plugin process was already in the foreground. + /// + /// Returns `Some()` on Unix with the process group ID if the plugin process will need to join + /// another process group to be part of the foreground. + pub(crate) fn enter_foreground( + &self, + span: Span, + pipeline_state: &Arc<(AtomicU32, AtomicU32)>, + ) -> Result, ShellError> { + let pid = self.pid; + let mut mutable = self.lock_mutable()?; + if mutable.foreground_guard.is_none() { + let guard = ForegroundGuard::new(pid, pipeline_state).map_err(|err| { + ShellError::GenericError { + error: "Failed to enter foreground".into(), + msg: err.to_string(), + span: Some(span), + help: None, + inner: vec![], + } + })?; + let pgrp = guard.pgrp(); + mutable.foreground_guard = Some(guard); + Ok(pgrp) + } else { + Err(ShellError::GenericError { + error: "Can't enter foreground".into(), + msg: "this plugin is already running in the foreground".into(), + span: Some(span), + help: Some( + "you may be trying to run the command in parallel, or this may be a bug in \ + the plugin" + .into(), + ), + inner: vec![], + }) + } + } + + /// Move the plugin process out of the foreground. See [`ForegroundGuard::reset`]. + /// + /// This is a no-op if the plugin process was already in the background. + pub(crate) fn exit_foreground(&self) -> Result<(), ShellError> { + let mut mutable = self.lock_mutable()?; + drop(mutable.foreground_guard.take()); + Ok(()) + } +} diff --git a/crates/nu-plugin/src/protocol/mod.rs b/crates/nu-plugin/src/protocol/mod.rs index 6f71efc4c3..c716e37901 100644 --- a/crates/nu-plugin/src/protocol/mod.rs +++ b/crates/nu-plugin/src/protocol/mod.rs @@ -17,9 +17,8 @@ use std::collections::HashMap; pub use evaluated_call::EvaluatedCall; pub use plugin_custom_value::PluginCustomValue; -#[cfg(test)] -pub use protocol_info::Protocol; -pub use protocol_info::ProtocolInfo; +#[allow(unused_imports)] // may be unused by compile flags +pub use protocol_info::{Feature, Protocol, ProtocolInfo}; /// A sequential identifier for a stream pub type StreamId = usize; @@ -485,6 +484,10 @@ pub enum EngineCall { AddEnvVar(String, Value), /// Get help for the current command GetHelp, + /// Move the plugin into the foreground for terminal interaction + EnterForeground, + /// Move the plugin out of the foreground once terminal interaction has finished + LeaveForeground, /// Get the contents of a span. Response is a binary which may not parse to UTF-8 GetSpanContents(Span), /// Evaluate a closure with stream input/output @@ -515,6 +518,8 @@ impl EngineCall { EngineCall::GetCurrentDir => "GetCurrentDir", EngineCall::AddEnvVar(..) => "AddEnvVar", EngineCall::GetHelp => "GetHelp", + EngineCall::EnterForeground => "EnterForeground", + EngineCall::LeaveForeground => "LeaveForeground", EngineCall::GetSpanContents(_) => "GetSpanContents", EngineCall::EvalClosure { .. } => "EvalClosure", } @@ -534,6 +539,8 @@ impl EngineCall { EngineCall::GetCurrentDir => EngineCall::GetCurrentDir, EngineCall::AddEnvVar(name, value) => EngineCall::AddEnvVar(name, value), EngineCall::GetHelp => EngineCall::GetHelp, + EngineCall::EnterForeground => EngineCall::EnterForeground, + EngineCall::LeaveForeground => EngineCall::LeaveForeground, EngineCall::GetSpanContents(span) => EngineCall::GetSpanContents(span), EngineCall::EvalClosure { closure, diff --git a/crates/nu-plugin/src/protocol/protocol_info.rs b/crates/nu-plugin/src/protocol/protocol_info.rs index e7f40234b5..922feb64b6 100644 --- a/crates/nu-plugin/src/protocol/protocol_info.rs +++ b/crates/nu-plugin/src/protocol/protocol_info.rs @@ -22,12 +22,13 @@ impl Default for ProtocolInfo { ProtocolInfo { protocol: Protocol::NuPlugin, version: env!("CARGO_PKG_VERSION").into(), - features: vec![], + features: default_features(), } } } impl ProtocolInfo { + /// True if the version specified in `self` is compatible with the version specified in `other`. pub fn is_compatible_with(&self, other: &ProtocolInfo) -> Result { fn parse_failed(error: semver::Error) -> ShellError { ShellError::PluginFailedToLoad { @@ -52,6 +53,11 @@ impl ProtocolInfo { } .matches(&versions[1])) } + + /// True if the protocol info contains a feature compatible with the given feature. + pub fn supports_feature(&self, feature: &Feature) -> bool { + self.features.iter().any(|f| feature.is_compatible_with(f)) + } } /// Indicates the protocol in use. Only one protocol is supported. @@ -72,9 +78,29 @@ pub enum Protocol { #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(tag = "name")] pub enum Feature { + /// The plugin supports running with a local socket passed via `--local-socket` instead of + /// stdio. + LocalSocket, + /// A feature that was not recognized on deserialization. Attempting to serialize this feature /// is an error. Matching against it may only be used if necessary to determine whether /// unsupported features are present. #[serde(other, skip_serializing)] Unknown, } + +impl Feature { + /// True if the feature is considered to be compatible with another feature. + pub fn is_compatible_with(&self, other: &Feature) -> bool { + matches!((self, other), (Feature::LocalSocket, Feature::LocalSocket)) + } +} + +/// Protocol features compiled into this version of `nu-plugin`. +pub fn default_features() -> Vec { + vec![ + // Only available if compiled with the `local-socket` feature flag (enabled by default). + #[cfg(feature = "local-socket")] + Feature::LocalSocket, + ] +} diff --git a/crates/nu-plugin/src/util/mod.rs b/crates/nu-plugin/src/util/mod.rs index 6a4fe5e5d4..ae861705b3 100644 --- a/crates/nu-plugin/src/util/mod.rs +++ b/crates/nu-plugin/src/util/mod.rs @@ -1,5 +1,7 @@ mod mutable_cow; +mod waitable; mod with_custom_values_in; pub(crate) use mutable_cow::*; +pub use waitable::Waitable; pub use with_custom_values_in::*; diff --git a/crates/nu-plugin/src/util/waitable.rs b/crates/nu-plugin/src/util/waitable.rs new file mode 100644 index 0000000000..9793c93b69 --- /dev/null +++ b/crates/nu-plugin/src/util/waitable.rs @@ -0,0 +1,100 @@ +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Condvar, Mutex, MutexGuard, PoisonError, +}; + +use nu_protocol::ShellError; + +/// A container that may be empty, and allows threads to block until it has a value. +#[derive(Debug)] +pub struct Waitable { + is_set: AtomicBool, + mutex: Mutex>, + condvar: Condvar, +} + +#[track_caller] +fn fail_if_poisoned<'a, T>( + result: Result, PoisonError>>, +) -> Result, ShellError> { + match result { + Ok(guard) => Ok(guard), + Err(_) => Err(ShellError::NushellFailedHelp { + msg: "Waitable mutex poisoned".into(), + help: std::panic::Location::caller().to_string(), + }), + } +} + +impl Waitable { + /// Create a new empty `Waitable`. + pub fn new() -> Waitable { + Waitable { + is_set: AtomicBool::new(false), + mutex: Mutex::new(None), + condvar: Condvar::new(), + } + } + + /// Wait for a value to be available and then clone it. + #[track_caller] + pub fn get(&self) -> Result { + let guard = fail_if_poisoned(self.mutex.lock())?; + if let Some(value) = (*guard).clone() { + Ok(value) + } else { + let guard = fail_if_poisoned(self.condvar.wait_while(guard, |g| g.is_none()))?; + Ok((*guard) + .clone() + .expect("checked already for Some but it was None")) + } + } + + /// Clone the value if one is available, but don't wait if not. + #[track_caller] + pub fn try_get(&self) -> Result, ShellError> { + let guard = fail_if_poisoned(self.mutex.lock())?; + Ok((*guard).clone()) + } + + /// Returns true if value is available. + #[track_caller] + pub fn is_set(&self) -> bool { + self.is_set.load(Ordering::SeqCst) + } + + /// Set the value and let waiting threads know. + #[track_caller] + pub fn set(&self, value: T) -> Result<(), ShellError> { + let mut guard = fail_if_poisoned(self.mutex.lock())?; + self.is_set.store(true, Ordering::SeqCst); + *guard = Some(value); + self.condvar.notify_all(); + Ok(()) + } +} + +impl Default for Waitable { + fn default() -> Self { + Self::new() + } +} + +#[test] +fn set_from_other_thread() -> Result<(), ShellError> { + use std::sync::Arc; + + let waitable = Arc::new(Waitable::new()); + let waitable_clone = waitable.clone(); + + assert!(!waitable.is_set()); + + std::thread::spawn(move || { + waitable_clone.set(42).expect("error on set"); + }); + + assert_eq!(42, waitable.get()?); + assert_eq!(Some(42), waitable.try_get()?); + assert!(waitable.is_set()); + Ok(()) +} diff --git a/crates/nu-protocol/src/plugin/registered.rs b/crates/nu-protocol/src/plugin/registered.rs index cb0c893728..46d65b41d1 100644 --- a/crates/nu-protocol/src/plugin/registered.rs +++ b/crates/nu-protocol/src/plugin/registered.rs @@ -19,6 +19,10 @@ pub trait RegisteredPlugin: Send + Sync { /// Stop the plugin. fn stop(&self) -> Result<(), ShellError>; + /// Stop the plugin and reset any state so that we don't make any assumptions about the plugin + /// next time it launches. This is used on `register`. + fn reset(&self) -> Result<(), ShellError>; + /// Cast the pointer to an [`Any`] so that its concrete type can be retrieved. /// /// This is necessary in order to allow `nu_plugin` to handle the implementation details of diff --git a/crates/nu-system/src/foreground.rs b/crates/nu-system/src/foreground.rs index f7d675461e..d54cab1f19 100644 --- a/crates/nu-system/src/foreground.rs +++ b/crates/nu-system/src/foreground.rs @@ -1,16 +1,11 @@ use std::{ io, process::{Child, Command}, + sync::{atomic::AtomicU32, Arc}, }; #[cfg(unix)] -use std::{ - io::IsTerminal, - sync::{ - atomic::{AtomicU32, Ordering}, - Arc, - }, -}; +use std::{io::IsTerminal, sync::atomic::Ordering}; #[cfg(unix)] pub use foreground_pgroup::stdin_fd; @@ -97,6 +92,139 @@ impl Drop for ForegroundChild { } } +/// Keeps a specific already existing process in the foreground as long as the [`ForegroundGuard`]. +/// If the process needs to be spawned in the foreground, use [`ForegroundChild`] instead. This is +/// used to temporarily bring plugin processes into the foreground. +/// +/// # OS-specific behavior +/// ## Unix +/// +/// If there is already a foreground external process running, spawned with [`ForegroundChild`], +/// this expects the process ID to remain in the process group created by the [`ForegroundChild`] +/// for the lifetime of the guard, and keeps the terminal controlling process group set to that. If +/// there is no foreground external process running, this sets the foreground process group to the +/// plugin's process ID. The process group that is expected can be retrieved with [`.pgrp()`] if +/// different from the plugin process ID. +/// +/// ## Other systems +/// +/// It does nothing special on non-unix systems. +#[derive(Debug)] +pub struct ForegroundGuard { + #[cfg(unix)] + pgrp: Option, + #[cfg(unix)] + pipeline_state: Arc<(AtomicU32, AtomicU32)>, +} + +impl ForegroundGuard { + /// Move the given process to the foreground. + #[cfg(unix)] + pub fn new( + pid: u32, + pipeline_state: &Arc<(AtomicU32, AtomicU32)>, + ) -> std::io::Result { + use nix::unistd::{self, Pid}; + + let pid_nix = Pid::from_raw(pid as i32); + let (pgrp, pcnt) = pipeline_state.as_ref(); + + // Might have to retry due to race conditions on the atomics + loop { + // Try to give control to the child, if there isn't currently a foreground group + if pgrp + .compare_exchange(0, pid, Ordering::SeqCst, Ordering::SeqCst) + .is_ok() + { + let _ = pcnt.fetch_add(1, Ordering::SeqCst); + + // We don't need the child to change process group. Make the guard now so that if there + // is an error, it will be cleaned up + let guard = ForegroundGuard { + pgrp: None, + pipeline_state: pipeline_state.clone(), + }; + + log::trace!("Giving control of the terminal to the plugin group, pid={pid}"); + + // Set the terminal controlling process group to the child process + unistd::tcsetpgrp(unsafe { stdin_fd() }, pid_nix)?; + + return Ok(guard); + } else if pcnt + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |count| { + // Avoid a race condition: only increment if count is > 0 + if count > 0 { + Some(count + 1) + } else { + None + } + }) + .is_ok() + { + // We successfully added another count to the foreground process group, which means + // we only need to tell the child process to join this one + let pgrp = pgrp.load(Ordering::SeqCst); + log::trace!( + "Will ask the plugin pid={pid} to join pgrp={pgrp} for control of the \ + terminal" + ); + return Ok(ForegroundGuard { + pgrp: Some(pgrp), + pipeline_state: pipeline_state.clone(), + }); + } else { + // The state has changed, we'll have to retry + continue; + } + } + } + + /// Move the given process to the foreground. + #[cfg(not(unix))] + pub fn new( + pid: u32, + pipeline_state: &Arc<(AtomicU32, AtomicU32)>, + ) -> std::io::Result { + let _ = (pid, pipeline_state); + Ok(ForegroundGuard {}) + } + + /// If the child process is expected to join a different process group to be in the foreground, + /// this returns `Some(pgrp)`. This only ever returns `Some` on Unix. + pub fn pgrp(&self) -> Option { + #[cfg(unix)] + { + self.pgrp + } + #[cfg(not(unix))] + { + None + } + } + + /// This should only be called once by `Drop` + fn reset_internal(&mut self) { + #[cfg(unix)] + { + log::trace!("Leaving the foreground group"); + + let (pgrp, pcnt) = self.pipeline_state.as_ref(); + if pcnt.fetch_sub(1, Ordering::SeqCst) == 1 { + // Clean up if we are the last one around + pgrp.store(0, Ordering::SeqCst); + foreground_pgroup::reset() + } + } + } +} + +impl Drop for ForegroundGuard { + fn drop(&mut self) { + self.reset_internal(); + } +} + // It's a simpler version of fish shell's external process handling. #[cfg(unix)] mod foreground_pgroup { diff --git a/crates/nu-system/src/lib.rs b/crates/nu-system/src/lib.rs index bee037cb8b..6058ed4fcf 100644 --- a/crates/nu-system/src/lib.rs +++ b/crates/nu-system/src/lib.rs @@ -9,7 +9,7 @@ mod windows; #[cfg(unix)] pub use self::foreground::stdin_fd; -pub use self::foreground::ForegroundChild; +pub use self::foreground::{ForegroundChild, ForegroundGuard}; #[cfg(any(target_os = "android", target_os = "linux"))] pub use self::linux::*; #[cfg(target_os = "macos")] diff --git a/crates/nu-test-support/src/macros.rs b/crates/nu-test-support/src/macros.rs index 386127f0b3..0b689ea107 100644 --- a/crates/nu-test-support/src/macros.rs +++ b/crates/nu-test-support/src/macros.rs @@ -203,15 +203,38 @@ macro_rules! nu_with_std { #[macro_export] macro_rules! nu_with_plugins { (cwd: $cwd:expr, plugins: [$(($plugin_name:expr)),+$(,)?], $command:expr) => {{ - $crate::macros::nu_with_plugin_run_test($cwd, &[$($plugin_name),+], $command) + nu_with_plugins!( + cwd: $cwd, + envs: Vec::<(&str, &str)>::new(), + plugins: [$(($plugin_name)),+], + $command + ) }}; (cwd: $cwd:expr, plugin: ($plugin_name:expr), $command:expr) => {{ - $crate::macros::nu_with_plugin_run_test($cwd, &[$plugin_name], $command) + nu_with_plugins!( + cwd: $cwd, + envs: Vec::<(&str, &str)>::new(), + plugin: ($plugin_name), + $command + ) + }}; + + ( + cwd: $cwd:expr, + envs: $envs:expr, + plugins: [$(($plugin_name:expr)),+$(,)?], + $command:expr + ) => {{ + $crate::macros::nu_with_plugin_run_test($cwd, $envs, &[$($plugin_name),+], $command) + }}; + (cwd: $cwd:expr, envs: $envs:expr, plugin: ($plugin_name:expr), $command:expr) => {{ + $crate::macros::nu_with_plugin_run_test($cwd, $envs, &[$plugin_name], $command) }}; } use crate::{Outcome, NATIVE_PATH_ENV_VAR}; +use std::ffi::OsStr; use std::fmt::Write; use std::{ path::Path, @@ -285,7 +308,17 @@ pub fn nu_run_test(opts: NuOpts, commands: impl AsRef, with_std: bool) -> O Outcome::new(out, err.into_owned(), output.status) } -pub fn nu_with_plugin_run_test(cwd: impl AsRef, plugins: &[&str], command: &str) -> Outcome { +pub fn nu_with_plugin_run_test( + cwd: impl AsRef, + envs: E, + plugins: &[&str], + command: &str, +) -> Outcome +where + E: IntoIterator, + K: AsRef, + V: AsRef, +{ let test_bins = crate::fs::binaries(); let test_bins = nu_path::canonicalize_with(&test_bins, ".").unwrap_or_else(|e| { panic!( @@ -325,6 +358,7 @@ pub fn nu_with_plugin_run_test(cwd: impl AsRef, plugins: &[&str], command: executable_path = crate::fs::installed_nu_path(); } let process = match setup_command(&executable_path, &target_cwd) + .envs(envs) .arg("--commands") .arg(commands) .arg("--config") diff --git a/crates/nu_plugin_python/nu_plugin_python_example.py b/crates/nu_plugin_python/nu_plugin_python_example.py index c2601214a6..6aaf1a5000 100755 --- a/crates/nu_plugin_python/nu_plugin_python_example.py +++ b/crates/nu_plugin_python/nu_plugin_python_example.py @@ -133,6 +133,13 @@ def process_call(id, plugin_call): span = plugin_call["Run"]["call"]["head"] # Creates a Value of type List that will be encoded and sent to Nushell + f = lambda x, y: { + "Int": { + "val": x * y, + "span": span + } + } + value = { "Value": { "List": { @@ -140,15 +147,9 @@ def process_call(id, plugin_call): { "Record": { "val": { - "cols": ["one", "two", "three"], - "vals": [ - { - "Int": { - "val": x * y, - "span": span - } - } for y in [0, 1, 2] - ] + "one": f(x, 0), + "two": f(x, 1), + "three": f(x, 2), }, "span": span } diff --git a/crates/nu_plugin_stress_internals/Cargo.toml b/crates/nu_plugin_stress_internals/Cargo.toml new file mode 100644 index 0000000000..0062c627fb --- /dev/null +++ b/crates/nu_plugin_stress_internals/Cargo.toml @@ -0,0 +1,19 @@ +[package] +authors = ["The Nushell Project Developers"] +description = "A test plugin for Nushell to stress aspects of the internals" +repository = "https://github.com/nushell/nushell/tree/main/crates/nu_plugin_stress_internals" +edition = "2021" +license = "MIT" +name = "nu_plugin_stress_internals" +version = "0.92.3" + +[[bin]] +name = "nu_plugin_stress_internals" +bench = false + +[dependencies] +# Intentionally not using the nu-protocol / nu-plugin crates, to check behavior against our +# assumptions about the serialized format +serde = { workspace = true } +serde_json = { workspace = true } +interprocess = "1.2.1" diff --git a/crates/nu_plugin_stress_internals/LICENSE b/crates/nu_plugin_stress_internals/LICENSE new file mode 100644 index 0000000000..ae174e8595 --- /dev/null +++ b/crates/nu_plugin_stress_internals/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 - 2023 The Nushell Project Developers + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/crates/nu_plugin_stress_internals/src/main.rs b/crates/nu_plugin_stress_internals/src/main.rs new file mode 100644 index 0000000000..3f95d5ef02 --- /dev/null +++ b/crates/nu_plugin_stress_internals/src/main.rs @@ -0,0 +1,213 @@ +use std::{ + error::Error, + io::{BufRead, BufReader, Write}, +}; + +use interprocess::local_socket::LocalSocketStream; +use serde::Deserialize; +use serde_json::{json, Value}; + +#[derive(Debug)] +struct Options { + refuse_local_socket: bool, + advertise_local_socket: bool, + exit_early: bool, + wrong_version: bool, + local_socket_path: Option, +} + +pub fn main() -> Result<(), Box> { + let args: Vec = std::env::args().collect(); + + eprintln!("stress_internals: args: {args:?}"); + + // Parse options from environment variables + fn has_env(var: &str) -> bool { + std::env::var(var).is_ok() + } + let mut opts = Options { + refuse_local_socket: has_env("STRESS_REFUSE_LOCAL_SOCKET"), + advertise_local_socket: has_env("STRESS_ADVERTISE_LOCAL_SOCKET"), + exit_early: has_env("STRESS_EXIT_EARLY"), + wrong_version: has_env("STRESS_WRONG_VERSION"), + local_socket_path: None, + }; + + #[allow(unused_mut)] + let mut should_flush = true; + + let (mut input, mut output): (Box, Box) = + match args.get(1).map(|s| s.as_str()) { + Some("--stdio") => ( + Box::new(std::io::stdin().lock()), + Box::new(std::io::stdout()), + ), + Some("--local-socket") => { + opts.local_socket_path = Some(args[2].clone()); + if opts.refuse_local_socket { + std::process::exit(1) + } else { + let in_socket = LocalSocketStream::connect(args[2].as_str())?; + let out_socket = LocalSocketStream::connect(args[2].as_str())?; + + #[cfg(windows)] + { + // Flushing on a socket on Windows is weird and waits for the other side + should_flush = false; + } + + (Box::new(BufReader::new(in_socket)), Box::new(out_socket)) + } + } + None => { + eprintln!("Run nu_plugin_stress_internals as a plugin from inside nushell"); + std::process::exit(1) + } + _ => { + eprintln!("Received args I don't understand: {args:?}"); + std::process::exit(1) + } + }; + + // Send encoding format + output.write_all(b"\x04json")?; + if should_flush { + output.flush()?; + } + + // Send `Hello` message + write( + &mut output, + should_flush, + &json!({ + "Hello": { + "protocol": "nu-plugin", + "version": if opts.wrong_version { + "0.0.0" + } else { + env!("CARGO_PKG_VERSION") + }, + "features": if opts.advertise_local_socket { + vec![json!({"name": "LocalSocket"})] + } else { + vec![] + }, + } + }), + )?; + + // Read `Hello` message + let mut de = serde_json::Deserializer::from_reader(&mut input); + let hello: Value = Value::deserialize(&mut de)?; + + assert!(hello.get("Hello").is_some()); + + if opts.exit_early { + // Exit without handling anything other than Hello + std::process::exit(0); + } + + // Parse incoming messages + loop { + match Value::deserialize(&mut de) { + Ok(message) => handle_message(&mut output, should_flush, &opts, &message)?, + Err(err) => { + if err.is_eof() { + break; + } else { + return Err(err.into()); + } + } + } + } + + Ok(()) +} + +fn handle_message( + output: &mut impl Write, + should_flush: bool, + opts: &Options, + message: &Value, +) -> Result<(), Box> { + if let Some(plugin_call) = message.get("Call") { + let (id, plugin_call) = (&plugin_call[0], &plugin_call[1]); + if plugin_call.as_str() == Some("Signature") { + write( + output, + should_flush, + &json!({ + "CallResponse": [ + id, + { + "Signature": signatures(), + } + ] + }), + ) + } else if let Some(call_info) = plugin_call.get("Run") { + if call_info["name"].as_str() == Some("stress_internals") { + // Just return debug of opts + let return_value = json!({ + "String": { + "val": format!("{opts:?}"), + "span": &call_info["call"]["head"], + } + }); + write( + output, + should_flush, + &json!({ + "CallResponse": [ + id, + { + "PipelineData": { + "Value": return_value + } + } + ] + }), + ) + } else { + Err(format!("unknown call name: {call_info}").into()) + } + } else { + Err(format!("unknown plugin call: {plugin_call}").into()) + } + } else if message.as_str() == Some("Goodbye") { + std::process::exit(0); + } else { + Err(format!("unknown message: {message}").into()) + } +} + +fn signatures() -> Vec { + vec![json!({ + "sig": { + "name": "stress_internals", + "usage": "Used to test behavior of plugin protocol", + "extra_usage": "", + "search_terms": [], + "required_positional": [], + "optional_positional": [], + "rest_positional": null, + "named": [], + "input_output_types": [], + "allow_variants_without_examples": false, + "is_filter": false, + "creates_scope": false, + "allows_unknown_args": false, + "category": "Experimental", + }, + "examples": [], + })] +} + +fn write(output: &mut impl Write, should_flush: bool, value: &Value) -> Result<(), Box> { + serde_json::to_writer(&mut *output, value)?; + output.write_all(b"\n")?; + if should_flush { + output.flush()?; + } + Ok(()) +} diff --git a/tests/plugins/mod.rs b/tests/plugins/mod.rs index 244af5cbef..f52006d3ad 100644 --- a/tests/plugins/mod.rs +++ b/tests/plugins/mod.rs @@ -5,3 +5,4 @@ mod env; mod formats; mod register; mod stream; +mod stress_internals; diff --git a/tests/plugins/stress_internals.rs b/tests/plugins/stress_internals.rs new file mode 100644 index 0000000000..1207c15252 --- /dev/null +++ b/tests/plugins/stress_internals.rs @@ -0,0 +1,123 @@ +use nu_test_support::nu_with_plugins; + +fn ensure_stress_env_vars_unset() { + for (key, _) in std::env::vars_os() { + if key.to_string_lossy().starts_with("STRESS_") { + panic!("Test is running in a dirty environment: {key:?} is set"); + } + } +} + +#[test] +fn test_stdio() { + ensure_stress_env_vars_unset(); + let result = nu_with_plugins!( + cwd: ".", + plugin: ("nu_plugin_stress_internals"), + "stress_internals" + ); + assert!(result.status.success()); + assert!(result.out.contains("local_socket_path: None")); +} + +#[test] +fn test_local_socket() { + ensure_stress_env_vars_unset(); + let result = nu_with_plugins!( + cwd: ".", + envs: vec![ + ("STRESS_ADVERTISE_LOCAL_SOCKET", "1"), + ], + plugin: ("nu_plugin_stress_internals"), + "stress_internals" + ); + assert!(result.status.success()); + // Should be run once in stdio mode + assert!(result.err.contains("--stdio")); + // And then in local socket mode + assert!(result.err.contains("--local-socket")); + assert!(result.out.contains("local_socket_path: Some")); +} + +#[test] +fn test_failing_local_socket_fallback() { + ensure_stress_env_vars_unset(); + let result = nu_with_plugins!( + cwd: ".", + envs: vec![ + ("STRESS_ADVERTISE_LOCAL_SOCKET", "1"), + ("STRESS_REFUSE_LOCAL_SOCKET", "1"), + ], + plugin: ("nu_plugin_stress_internals"), + "stress_internals" + ); + assert!(result.status.success()); + + // Count the number of times we do stdio/local socket + let mut count_stdio = 0; + let mut count_local_socket = 0; + + for line in result.err.split('\n') { + if line.contains("--stdio") { + count_stdio += 1; + } + if line.contains("--local-socket") { + count_local_socket += 1; + } + } + + // Should be run once in local socket mode + assert_eq!(1, count_local_socket, "count of --local-socket"); + // Should be run twice in stdio mode, due to the fallback + assert_eq!(2, count_stdio, "count of --stdio"); + + // In the end it should not be running in local socket mode, but should succeed + assert!(result.out.contains("local_socket_path: None")); +} + +#[test] +fn test_exit_early_stdio() { + ensure_stress_env_vars_unset(); + let result = nu_with_plugins!( + cwd: ".", + envs: vec![ + ("STRESS_EXIT_EARLY", "1"), + ], + plugin: ("nu_plugin_stress_internals"), + "stress_internals" + ); + assert!(!result.status.success()); + assert!(result.err.contains("--stdio")); +} + +#[test] +fn test_exit_early_local_socket() { + ensure_stress_env_vars_unset(); + let result = nu_with_plugins!( + cwd: ".", + envs: vec![ + ("STRESS_ADVERTISE_LOCAL_SOCKET", "1"), + ("STRESS_EXIT_EARLY", "1"), + ], + plugin: ("nu_plugin_stress_internals"), + "stress_internals" + ); + assert!(!result.status.success()); + assert!(result.err.contains("--local-socket")); +} + +#[test] +fn test_wrong_version() { + ensure_stress_env_vars_unset(); + let result = nu_with_plugins!( + cwd: ".", + envs: vec![ + ("STRESS_WRONG_VERSION", "1"), + ], + plugin: ("nu_plugin_stress_internals"), + "stress_internals" + ); + assert!(!result.status.success()); + assert!(result.err.contains("version")); + assert!(result.err.contains("0.0.0")); +}