From c06ef201b72b3cbe901820417106b7d65c6f01e1 Mon Sep 17 00:00:00 2001 From: Devyn Cairns Date: Mon, 15 Apr 2024 11:28:18 -0700 Subject: [PATCH] Local socket mode and foreground terminal control for plugins (#12448) # Description Adds support for running plugins using local socket communication instead of stdio. This will be an optional thing that not all plugins have to support. This frees up stdio for use to make plugins that use stdio to create terminal UIs, cc @amtoine, @fdncred. This uses the [`interprocess`](https://crates.io/crates/interprocess) crate (298 stars, MIT license, actively maintained), which seems to be the best option for cross-platform local socket support in Rust. On Windows, a local socket name is provided. On Unixes, it's a path. The socket name is kept to a relatively small size because some operating systems have pretty strict limits on the whole path (~100 chars), so on macOS for example we prefer `/tmp/nu.{pid}.{hash64}.sock` where the hash includes the plugin filename and timestamp to be unique enough. This also adds an API for moving plugins in and out of the foreground group, which is relevant for Unixes where direct terminal control depends on that. TODO: - [x] Generate local socket path according to OS conventions - [x] Add support for passing `--local-socket` to the plugin executable instead of `--stdio`, and communicating over that instead - [x] Test plugins that were broken, including [amtoine/nu_plugin_explore](https://github.com/amtoine/nu_plugin_explore) - [x] Automatically upgrade to using local sockets when supported, falling back if it doesn't work, transparently to the user without any visible error messages - Added protocol feature: `LocalSocket` - [x] Reset preferred mode to `None` on `register` - [x] Allow plugins to detect whether they're running on a local socket and can use stdio freely, so that TUI plugins can just produce an error message otherwise - Implemented via `EngineInterface::is_using_stdio()` - [x] Clean up foreground state when plugin command exits on the engine side too, not just whole plugin - [x] Make sure tests for failure cases work as intended - `nu_plugin_stress_internals` added # User-Facing Changes - TUI plugins work - Non-Rust plugins could optionally choose to use this - This might behave differently, so will need to test it carefully across different operating systems # Tests + Formatting - :green_circle: `toolkit fmt` - :green_circle: `toolkit clippy` - :green_circle: `toolkit test` - :green_circle: `toolkit test stdlib` # After Submitting - [ ] Document local socket option in plugin contrib docs - [ ] Document how to do a terminal UI plugin in plugin contrib docs - [ ] Document: `EnterForeground` engine call - [ ] Document: `LeaveForeground` engine call - [ ] Document: `LocalSocket` protocol feature --- Cargo.lock | 183 +++++++++++ Cargo.toml | 1 + crates/nu-parser/src/parse_keywords.rs | 7 +- .../src/fake_persistent_plugin.rs | 5 + .../src/spawn_fake_plugin.rs | 4 +- crates/nu-plugin/Cargo.toml | 10 + .../communication_mode/local_socket/mod.rs | 84 +++++ .../communication_mode/local_socket/tests.rs | 19 ++ .../src/plugin/communication_mode/mod.rs | 233 ++++++++++++++ crates/nu-plugin/src/plugin/context.rs | 25 +- crates/nu-plugin/src/plugin/interface.rs | 13 + .../nu-plugin/src/plugin/interface/engine.rs | 137 ++++++++- .../src/plugin/interface/engine/tests.rs | 40 ++- .../nu-plugin/src/plugin/interface/plugin.rs | 111 ++++++- .../src/plugin/interface/plugin/tests.rs | 24 +- .../src/plugin/interface/test_util.rs | 2 +- crates/nu-plugin/src/plugin/mod.rs | 290 ++++++++++-------- crates/nu-plugin/src/plugin/persistent.rs | 181 ++++++++--- crates/nu-plugin/src/plugin/process.rs | 90 ++++++ crates/nu-plugin/src/protocol/mod.rs | 13 +- .../nu-plugin/src/protocol/protocol_info.rs | 28 +- crates/nu-plugin/src/util/mod.rs | 2 + crates/nu-plugin/src/util/waitable.rs | 100 ++++++ crates/nu-protocol/src/plugin/registered.rs | 4 + crates/nu-system/src/foreground.rs | 142 ++++++++- crates/nu-system/src/lib.rs | 2 +- crates/nu-test-support/src/macros.rs | 40 ++- .../nu_plugin_python_example.py | 19 +- crates/nu_plugin_stress_internals/Cargo.toml | 19 ++ crates/nu_plugin_stress_internals/LICENSE | 21 ++ crates/nu_plugin_stress_internals/src/main.rs | 213 +++++++++++++ tests/plugins/mod.rs | 1 + tests/plugins/stress_internals.rs | 123 ++++++++ 33 files changed, 1949 insertions(+), 237 deletions(-) create mode 100644 crates/nu-plugin/src/plugin/communication_mode/local_socket/mod.rs create mode 100644 crates/nu-plugin/src/plugin/communication_mode/local_socket/tests.rs create mode 100644 crates/nu-plugin/src/plugin/communication_mode/mod.rs create mode 100644 crates/nu-plugin/src/plugin/process.rs create mode 100644 crates/nu-plugin/src/util/waitable.rs create mode 100644 crates/nu_plugin_stress_internals/Cargo.toml create mode 100644 crates/nu_plugin_stress_internals/LICENSE create mode 100644 crates/nu_plugin_stress_internals/src/main.rs create mode 100644 tests/plugins/stress_internals.rs diff --git a/Cargo.lock b/Cargo.lock index 18d9194675..e12545ccfc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -234,6 +234,30 @@ dependencies = [ "wait-timeout", ] +[[package]] +name = "async-channel" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28243a43d821d11341ab73c80bed182dc015c514b951616cf79bd4af39af0c3" +dependencies = [ + "concurrent-queue", + "event-listener 5.3.0", + "event-listener-strategy 0.5.1", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-lock" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d034b430882f8381900d3fe6f0aaa3ad94f2cb4ac519b429692a1bc2dda4ae7b" +dependencies = [ + "event-listener 4.0.3", + "event-listener-strategy 0.4.0", + "pin-project-lite", +] + [[package]] name = "async-stream" version = "0.3.5" @@ -256,6 +280,12 @@ dependencies = [ "syn 2.0.58", ] +[[package]] +name = "async-task" +version = "4.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbb36e985947064623dbd357f727af08ffd077f93d696782f3c56365fa2e2799" + [[package]] name = "async-trait" version = "0.1.79" @@ -282,6 +312,12 @@ version = "0.15.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.2.0" @@ -430,6 +466,22 @@ dependencies = [ "generic-array", ] +[[package]] +name = "blocking" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a37913e8dc4ddcc604f0c6d3bf2887c995153af3611de9e23c352b44c1b9118" +dependencies = [ + "async-channel", + "async-lock", + "async-task", + "fastrand", + "futures-io", + "futures-lite", + "piper", + "tracing", +] + [[package]] name = "borsh" version = "1.4.0" @@ -826,6 +878,15 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "concurrent-queue" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d16048cd947b08fa32c24458a22f5dc5e835264f689f4f5653210c69fd107363" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "condtype" version = "1.3.0" @@ -1362,6 +1423,48 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" +[[package]] +name = "event-listener" +version = "4.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b215c49b2b248c855fb73579eb1f4f26c38ffdc12973e20e07b91d78d5646e" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d9944b8ca13534cdfb2800775f8dd4902ff3fc75a50101466decadfdf322a24" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" +dependencies = [ + "event-listener 4.0.3", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "332f51cb23d20b0de8458b86580878211da09bcd4503cb579c225b3d124cabb3" +dependencies = [ + "event-listener 5.3.0", + "pin-project-lite", +] + [[package]] name = "fallible-iterator" version = "0.3.0" @@ -1578,6 +1681,16 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +[[package]] +name = "futures-lite" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52527eb5074e35e9339c6b4e8d12600c7128b68fb25dcb9fa9dec18f7c25f3a5" +dependencies = [ + "futures-core", + "pin-project-lite", +] + [[package]] name = "futures-macro" version = "0.3.30" @@ -1997,6 +2110,32 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "interprocess" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81f2533f3be42fffe3b5e63b71aeca416c1c3bc33e4e27be018521e76b1f38fb" +dependencies = [ + "blocking", + "cfg-if", + "futures-core", + "futures-io", + "intmap", + "libc", + "once_cell", + "rustc_version", + "spinning", + "thiserror", + "to_method", + "winapi", +] + +[[package]] +name = "intmap" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae52f28f45ac2bc96edb7714de995cffc174a395fb0abf5bff453587c980d7b9" + [[package]] name = "inventory" version = "0.3.15" @@ -3077,10 +3216,13 @@ name = "nu-plugin" version = "0.92.3" dependencies = [ "bincode", + "interprocess", "log", "miette", + "nix", "nu-engine", "nu-protocol", + "nu-system", "nu-utils", "rmp-serde", "semver", @@ -3311,6 +3453,15 @@ dependencies = [ "sxd-xpath", ] +[[package]] +name = "nu_plugin_stress_internals" +version = "0.92.3" +dependencies = [ + "interprocess", + "serde", + "serde_json", +] + [[package]] name = "num" version = "0.4.1" @@ -3591,6 +3742,12 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "parking" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" + [[package]] name = "parking_lot" version = "0.12.1" @@ -3814,6 +3971,17 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "piper" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "668d31b1c4eba19242f2088b2bf3316b82ca31082a8335764db4e083db7485d4" +dependencies = [ + "atomic-waker", + "fastrand", + "futures-io", +] + [[package]] name = "pkg-config" version = "0.3.30" @@ -5309,6 +5477,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "spinning" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d4f0e86297cad2658d92a707320d87bf4e6ae1050287f51d19b67ef3f153a7b" +dependencies = [ + "lock_api", +] + [[package]] name = "sqlparser" version = "0.39.0" @@ -5746,6 +5923,12 @@ dependencies = [ "regex", ] +[[package]] +name = "to_method" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c4ceeeca15c8384bbc3e011dbd8fccb7f068a440b752b7d9b32ceb0ca0e2e8" + [[package]] name = "tokio" version = "1.37.0" diff --git a/Cargo.toml b/Cargo.toml index 8b56585678..ee4062fcb9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ members = [ "crates/nu_plugin_custom_values", "crates/nu_plugin_formats", "crates/nu_plugin_polars", + "crates/nu_plugin_stress_internals", "crates/nu-std", "crates/nu-table", "crates/nu-term-grid", diff --git a/crates/nu-parser/src/parse_keywords.rs b/crates/nu-parser/src/parse_keywords.rs index bac41108d7..b46a9f06f7 100644 --- a/crates/nu-parser/src/parse_keywords.rs +++ b/crates/nu-parser/src/parse_keywords.rs @@ -3687,7 +3687,7 @@ pub fn parse_register(working_set: &mut StateWorkingSet, lite_command: &LiteComm // Add it to the working set let plugin = working_set.find_or_create_plugin(&identity, || { - Arc::new(PersistentPlugin::new(identity.clone(), gc_config)) + Arc::new(PersistentPlugin::new(identity.clone(), gc_config.clone())) }); // Downcast the plugin to `PersistentPlugin` - we generally expect this to succeed. The @@ -3706,7 +3706,7 @@ pub fn parse_register(working_set: &mut StateWorkingSet, lite_command: &LiteComm // // The user would expect that `register` would always run the binary to get new // signatures, in case it was replaced with an updated binary - plugin.stop().map_err(|err| { + plugin.reset().map_err(|err| { ParseError::LabeledError( "Failed to restart plugin to get new signatures".into(), err.to_string(), @@ -3714,7 +3714,10 @@ pub fn parse_register(working_set: &mut StateWorkingSet, lite_command: &LiteComm ) })?; + plugin.set_gc_config(&gc_config); + let signatures = get_signature(plugin.clone(), get_envs).map_err(|err| { + log::warn!("Error getting signatures: {err:?}"); ParseError::LabeledError( "Error getting signatures".into(), err.to_string(), diff --git a/crates/nu-plugin-test-support/src/fake_persistent_plugin.rs b/crates/nu-plugin-test-support/src/fake_persistent_plugin.rs index 40bba951c2..b1215faa04 100644 --- a/crates/nu-plugin-test-support/src/fake_persistent_plugin.rs +++ b/crates/nu-plugin-test-support/src/fake_persistent_plugin.rs @@ -51,6 +51,11 @@ impl RegisteredPlugin for FakePersistentPlugin { Ok(()) } + fn reset(&self) -> Result<(), ShellError> { + // We can't stop + Ok(()) + } + fn as_any(self: Arc) -> Arc { self } diff --git a/crates/nu-plugin-test-support/src/spawn_fake_plugin.rs b/crates/nu-plugin-test-support/src/spawn_fake_plugin.rs index e29dea5e1c..0b8e34ae19 100644 --- a/crates/nu-plugin-test-support/src/spawn_fake_plugin.rs +++ b/crates/nu-plugin-test-support/src/spawn_fake_plugin.rs @@ -47,7 +47,9 @@ pub(crate) fn spawn_fake_plugin( let identity = PluginIdentity::new_fake(name); let reg_plugin = Arc::new(FakePersistentPlugin::new(identity.clone())); let source = Arc::new(PluginSource::new(reg_plugin.clone())); - let mut manager = PluginInterfaceManager::new(source, input_write); + + // The fake plugin has no process ID, and we also don't set the garbage collector + let mut manager = PluginInterfaceManager::new(source, None, input_write); // Set up the persistent plugin with the interface before continuing let interface = manager.get_interface(); diff --git a/crates/nu-plugin/Cargo.toml b/crates/nu-plugin/Cargo.toml index 0e2f755f31..1780ff002b 100644 --- a/crates/nu-plugin/Cargo.toml +++ b/crates/nu-plugin/Cargo.toml @@ -13,6 +13,7 @@ bench = false [dependencies] nu-engine = { path = "../nu-engine", version = "0.92.3" } nu-protocol = { path = "../nu-protocol", version = "0.92.3" } +nu-system = { path = "../nu-system", version = "0.92.3" } nu-utils = { path = "../nu-utils", version = "0.92.3" } bincode = "1.3" @@ -24,6 +25,15 @@ miette = { workspace = true } semver = "1.0" typetag = "0.2" thiserror = "1.0" +interprocess = { version = "1.2.1", optional = true } + +[features] +default = ["local-socket"] +local-socket = ["interprocess"] + +[target.'cfg(target_family = "unix")'.dependencies] +# For setting the process group ID (EnterForeground / LeaveForeground) +nix = { workspace = true, default-features = false, features = ["process"] } [target.'cfg(target_os = "windows")'.dependencies] windows = { workspace = true, features = [ diff --git a/crates/nu-plugin/src/plugin/communication_mode/local_socket/mod.rs b/crates/nu-plugin/src/plugin/communication_mode/local_socket/mod.rs new file mode 100644 index 0000000000..e550892fe1 --- /dev/null +++ b/crates/nu-plugin/src/plugin/communication_mode/local_socket/mod.rs @@ -0,0 +1,84 @@ +use std::ffi::OsString; + +#[cfg(test)] +pub(crate) mod tests; + +/// Generate a name to be used for a local socket specific to this `nu` process, described by the +/// given `unique_id`, which should be unique to the purpose of the socket. +/// +/// On Unix, this is a path, which should generally be 100 characters or less for compatibility. On +/// Windows, this is a name within the `\\.\pipe` namespace. +#[cfg(unix)] +pub fn make_local_socket_name(unique_id: &str) -> OsString { + // Prefer to put it in XDG_RUNTIME_DIR if set, since that's user-local + let mut base = if let Some(runtime_dir) = std::env::var_os("XDG_RUNTIME_DIR") { + std::path::PathBuf::from(runtime_dir) + } else { + // Use std::env::temp_dir() for portability, especially since on Android this is probably + // not `/tmp` + std::env::temp_dir() + }; + let socket_name = format!("nu.{}.{}.sock", std::process::id(), unique_id); + base.push(socket_name); + base.into() +} + +/// Generate a name to be used for a local socket specific to this `nu` process, described by the +/// given `unique_id`, which should be unique to the purpose of the socket. +/// +/// On Unix, this is a path, which should generally be 100 characters or less for compatibility. On +/// Windows, this is a name within the `\\.\pipe` namespace. +#[cfg(windows)] +pub fn make_local_socket_name(unique_id: &str) -> OsString { + format!("nu.{}.{}", std::process::id(), unique_id).into() +} + +/// Determine if the error is just due to the listener not being ready yet in asynchronous mode +#[cfg(not(windows))] +pub fn is_would_block_err(err: &std::io::Error) -> bool { + err.kind() == std::io::ErrorKind::WouldBlock +} + +/// Determine if the error is just due to the listener not being ready yet in asynchronous mode +#[cfg(windows)] +pub fn is_would_block_err(err: &std::io::Error) -> bool { + err.kind() == std::io::ErrorKind::WouldBlock + || err.raw_os_error().is_some_and(|e| { + // Windows returns this error when trying to accept a pipe in non-blocking mode + e as i64 == windows::Win32::Foundation::ERROR_PIPE_LISTENING.0 as i64 + }) +} + +/// Wraps the `interprocess` local socket stream for greater compatibility +#[derive(Debug)] +pub struct LocalSocketStream(pub interprocess::local_socket::LocalSocketStream); + +impl From for LocalSocketStream { + fn from(value: interprocess::local_socket::LocalSocketStream) -> Self { + LocalSocketStream(value) + } +} + +impl std::io::Read for LocalSocketStream { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.0.read(buf) + } +} + +impl std::io::Write for LocalSocketStream { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.0.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + // We don't actually flush the underlying socket on Windows. The flush operation on a + // Windows named pipe actually synchronizes with read on the other side, and won't finish + // until the other side is empty. This isn't how most of our other I/O methods work, so we + // just won't do it. The BufWriter above this will have still made a write call with the + // contents of the buffer, which should be good enough. + if cfg!(not(windows)) { + self.0.flush()?; + } + Ok(()) + } +} diff --git a/crates/nu-plugin/src/plugin/communication_mode/local_socket/tests.rs b/crates/nu-plugin/src/plugin/communication_mode/local_socket/tests.rs new file mode 100644 index 0000000000..a15d7a5294 --- /dev/null +++ b/crates/nu-plugin/src/plugin/communication_mode/local_socket/tests.rs @@ -0,0 +1,19 @@ +use super::make_local_socket_name; + +#[test] +fn local_socket_path_contains_pid() { + let name = make_local_socket_name("test-string") + .to_string_lossy() + .into_owned(); + println!("{}", name); + assert!(name.to_string().contains(&std::process::id().to_string())); +} + +#[test] +fn local_socket_path_contains_provided_name() { + let name = make_local_socket_name("test-string") + .to_string_lossy() + .into_owned(); + println!("{}", name); + assert!(name.to_string().contains("test-string")); +} diff --git a/crates/nu-plugin/src/plugin/communication_mode/mod.rs b/crates/nu-plugin/src/plugin/communication_mode/mod.rs new file mode 100644 index 0000000000..ca7d5e2b41 --- /dev/null +++ b/crates/nu-plugin/src/plugin/communication_mode/mod.rs @@ -0,0 +1,233 @@ +use std::ffi::OsStr; +use std::io::{Stdin, Stdout}; +use std::process::{Child, ChildStdin, ChildStdout, Command, Stdio}; + +use nu_protocol::ShellError; + +#[cfg(feature = "local-socket")] +use interprocess::local_socket::LocalSocketListener; + +#[cfg(feature = "local-socket")] +mod local_socket; + +#[cfg(feature = "local-socket")] +use local_socket::*; + +#[derive(Debug, Clone)] +pub(crate) enum CommunicationMode { + /// Communicate using `stdin` and `stdout`. + Stdio, + /// Communicate using an operating system-specific local socket. + #[cfg(feature = "local-socket")] + LocalSocket(std::ffi::OsString), +} + +impl CommunicationMode { + /// Generate a new local socket communication mode based on the given plugin exe path. + #[cfg(feature = "local-socket")] + pub fn local_socket(plugin_exe: &std::path::Path) -> CommunicationMode { + use std::hash::{Hash, Hasher}; + use std::time::SystemTime; + + // Generate the unique ID based on the plugin path and the current time. The actual + // algorithm here is not very important, we just want this to be relatively unique very + // briefly. Using the default hasher in the stdlib means zero extra dependencies. + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + + plugin_exe.hash(&mut hasher); + SystemTime::now().hash(&mut hasher); + + let unique_id = format!("{:016x}", hasher.finish()); + + CommunicationMode::LocalSocket(make_local_socket_name(&unique_id)) + } + + pub fn args(&self) -> Vec<&OsStr> { + match self { + CommunicationMode::Stdio => vec![OsStr::new("--stdio")], + #[cfg(feature = "local-socket")] + CommunicationMode::LocalSocket(path) => { + vec![OsStr::new("--local-socket"), path.as_os_str()] + } + } + } + + pub fn setup_command_io(&self, command: &mut Command) { + match self { + CommunicationMode::Stdio => { + // Both stdout and stdin are piped so we can receive information from the plugin + command.stdin(Stdio::piped()); + command.stdout(Stdio::piped()); + } + #[cfg(feature = "local-socket")] + CommunicationMode::LocalSocket(_) => { + // Stdio can be used by the plugin to talk to the terminal in local socket mode, + // which is the big benefit + command.stdin(Stdio::inherit()); + command.stdout(Stdio::inherit()); + } + } + } + + pub fn serve(&self) -> Result { + match self { + // Nothing to set up for stdio - we just take it from the child. + CommunicationMode::Stdio => Ok(PreparedServerCommunication::Stdio), + // For sockets: we need to create the server so that the child won't fail to connect. + #[cfg(feature = "local-socket")] + CommunicationMode::LocalSocket(name) => { + let listener = LocalSocketListener::bind(name.as_os_str()).map_err(|err| { + ShellError::IOError { + msg: format!("failed to open socket for plugin: {err}"), + } + })?; + Ok(PreparedServerCommunication::LocalSocket { + name: name.clone(), + listener, + }) + } + } + } + + pub fn connect_as_client(&self) -> Result { + match self { + CommunicationMode::Stdio => Ok(ClientCommunicationIo::Stdio( + std::io::stdin(), + std::io::stdout(), + )), + #[cfg(feature = "local-socket")] + CommunicationMode::LocalSocket(name) => { + // Connect to the specified socket. + let get_socket = || { + use interprocess::local_socket as ls; + ls::LocalSocketStream::connect(name.as_os_str()) + .map_err(|err| ShellError::IOError { + msg: format!("failed to connect to socket: {err}"), + }) + .map(LocalSocketStream::from) + }; + // Reverse order from the server: read in, write out + let read_in = get_socket()?; + let write_out = get_socket()?; + Ok(ClientCommunicationIo::LocalSocket { read_in, write_out }) + } + } + } +} + +pub(crate) enum PreparedServerCommunication { + Stdio, + #[cfg(feature = "local-socket")] + LocalSocket { + #[cfg_attr(windows, allow(dead_code))] // not used on Windows + name: std::ffi::OsString, + listener: LocalSocketListener, + }, +} + +impl PreparedServerCommunication { + pub fn connect(&self, child: &mut Child) -> Result { + match self { + PreparedServerCommunication::Stdio => { + let stdin = child + .stdin + .take() + .ok_or_else(|| ShellError::PluginFailedToLoad { + msg: "Plugin missing stdin writer".into(), + })?; + + let stdout = child + .stdout + .take() + .ok_or_else(|| ShellError::PluginFailedToLoad { + msg: "Plugin missing stdout writer".into(), + })?; + + Ok(ServerCommunicationIo::Stdio(stdin, stdout)) + } + #[cfg(feature = "local-socket")] + PreparedServerCommunication::LocalSocket { listener, .. } => { + use std::time::{Duration, Instant}; + + const RETRY_PERIOD: Duration = Duration::from_millis(1); + const TIMEOUT: Duration = Duration::from_secs(10); + + let start = Instant::now(); + + // Use a loop to try to get two clients from the listener: one for read (the plugin + // output) and one for write (the plugin input) + listener.set_nonblocking(true)?; + let mut get_socket = || { + let mut result = None; + while let Ok(None) = child.try_wait() { + match listener.accept() { + Ok(stream) => { + // Success! But make sure the stream is in blocking mode. + stream.set_nonblocking(false)?; + result = Some(stream); + break; + } + Err(err) => { + if !is_would_block_err(&err) { + // `WouldBlock` is ok, just means it's not ready yet, but some other + // kind of error should be reported + return Err(err.into()); + } + } + } + if Instant::now().saturating_duration_since(start) > TIMEOUT { + return Err(ShellError::PluginFailedToLoad { + msg: "Plugin timed out while waiting to connect to socket".into(), + }); + } else { + std::thread::sleep(RETRY_PERIOD); + } + } + if let Some(stream) = result { + Ok(LocalSocketStream(stream)) + } else { + // The process may have exited + Err(ShellError::PluginFailedToLoad { + msg: "Plugin exited without connecting".into(), + }) + } + }; + // Input stream always comes before output + let write_in = get_socket()?; + let read_out = get_socket()?; + Ok(ServerCommunicationIo::LocalSocket { read_out, write_in }) + } + } + } +} + +impl Drop for PreparedServerCommunication { + fn drop(&mut self) { + match self { + #[cfg(all(unix, feature = "local-socket"))] + PreparedServerCommunication::LocalSocket { name: path, .. } => { + // Just try to remove the socket file, it's ok if this fails + let _ = std::fs::remove_file(path); + } + _ => (), + } + } +} + +pub(crate) enum ServerCommunicationIo { + Stdio(ChildStdin, ChildStdout), + #[cfg(feature = "local-socket")] + LocalSocket { + read_out: LocalSocketStream, + write_in: LocalSocketStream, + }, +} + +pub(crate) enum ClientCommunicationIo { + Stdio(Stdin, Stdout), + #[cfg(feature = "local-socket")] + LocalSocket { + read_in: LocalSocketStream, + write_out: LocalSocketStream, + }, +} diff --git a/crates/nu-plugin/src/plugin/context.rs b/crates/nu-plugin/src/plugin/context.rs index 61fdfe662a..75315520e1 100644 --- a/crates/nu-plugin/src/plugin/context.rs +++ b/crates/nu-plugin/src/plugin/context.rs @@ -8,7 +8,10 @@ use nu_protocol::{ use std::{ borrow::Cow, collections::HashMap, - sync::{atomic::AtomicBool, Arc}, + sync::{ + atomic::{AtomicBool, AtomicU32}, + Arc, + }, }; /// Object safe trait for abstracting operations required of the plugin context. @@ -16,8 +19,12 @@ use std::{ /// This is not a public API. #[doc(hidden)] pub trait PluginExecutionContext: Send + Sync { + /// A span pointing to the command being executed + fn span(&self) -> Span; /// The interrupt signal, if present fn ctrlc(&self) -> Option<&Arc>; + /// The pipeline externals state, for tracking the foreground process group, if present + fn pipeline_externals_state(&self) -> Option<&Arc<(AtomicU32, AtomicU32)>>; /// Get engine configuration fn get_config(&self) -> Result; /// Get plugin configuration @@ -75,10 +82,18 @@ impl<'a> PluginExecutionCommandContext<'a> { } impl<'a> PluginExecutionContext for PluginExecutionCommandContext<'a> { + fn span(&self) -> Span { + self.call.head + } + fn ctrlc(&self) -> Option<&Arc> { self.engine_state.ctrlc.as_ref() } + fn pipeline_externals_state(&self) -> Option<&Arc<(AtomicU32, AtomicU32)>> { + Some(&self.engine_state.pipeline_externals_state) + } + fn get_config(&self) -> Result { Ok(nu_engine::get_config(&self.engine_state, &self.stack)) } @@ -237,10 +252,18 @@ pub(crate) struct PluginExecutionBogusContext; #[cfg(test)] impl PluginExecutionContext for PluginExecutionBogusContext { + fn span(&self) -> Span { + Span::test_data() + } + fn ctrlc(&self) -> Option<&Arc> { None } + fn pipeline_externals_state(&self) -> Option<&Arc<(AtomicU32, AtomicU32)>> { + None + } + fn get_config(&self) -> Result { Err(ShellError::NushellFailed { msg: "get_config not implemented on bogus".into(), diff --git a/crates/nu-plugin/src/plugin/interface.rs b/crates/nu-plugin/src/plugin/interface.rs index 45fbd75c12..19e70fb99f 100644 --- a/crates/nu-plugin/src/plugin/interface.rs +++ b/crates/nu-plugin/src/plugin/interface.rs @@ -80,6 +80,11 @@ pub trait PluginWrite: Send + Sync { /// Flush any internal buffers, if applicable. fn flush(&self) -> Result<(), ShellError>; + + /// True if this output is stdout, so that plugins can avoid using stdout for their own purpose + fn is_stdout(&self) -> bool { + false + } } impl PluginWrite for (std::io::Stdout, E) @@ -96,6 +101,10 @@ where msg: err.to_string(), }) } + + fn is_stdout(&self) -> bool { + true + } } impl PluginWrite for (Mutex, E) @@ -131,6 +140,10 @@ where fn flush(&self) -> Result<(), ShellError> { (**self).flush() } + + fn is_stdout(&self) -> bool { + (**self).is_stdout() + } } /// An interface manager handles I/O and state management for communication between a plugin and the diff --git a/crates/nu-plugin/src/plugin/interface/engine.rs b/crates/nu-plugin/src/plugin/interface/engine.rs index e98eed06fb..e394b561e8 100644 --- a/crates/nu-plugin/src/plugin/interface/engine.rs +++ b/crates/nu-plugin/src/plugin/interface/engine.rs @@ -4,10 +4,13 @@ use super::{ stream::{StreamManager, StreamManagerHandle}, Interface, InterfaceManager, PipelineDataWriter, PluginRead, PluginWrite, Sequence, }; -use crate::protocol::{ - CallInfo, CustomValueOp, EngineCall, EngineCallId, EngineCallResponse, Ordering, PluginCall, - PluginCallId, PluginCallResponse, PluginCustomValue, PluginInput, PluginOption, PluginOutput, - ProtocolInfo, +use crate::{ + protocol::{ + CallInfo, CustomValueOp, EngineCall, EngineCallId, EngineCallResponse, Ordering, + PluginCall, PluginCallId, PluginCallResponse, PluginCustomValue, PluginInput, PluginOption, + PluginOutput, ProtocolInfo, + }, + util::Waitable, }; use nu_protocol::{ engine::Closure, Config, IntoInterruptiblePipelineData, LabeledError, ListStream, PipelineData, @@ -47,6 +50,8 @@ mod tests; /// Internal shared state between the manager and each interface. struct EngineInterfaceState { + /// Protocol version info, set after `Hello` received + protocol_info: Waitable>, /// Sequence for generating engine call ids engine_call_id_sequence: Sequence, /// Sequence for generating stream ids @@ -61,6 +66,7 @@ struct EngineInterfaceState { impl std::fmt::Debug for EngineInterfaceState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("EngineInterfaceState") + .field("protocol_info", &self.protocol_info) .field("engine_call_id_sequence", &self.engine_call_id_sequence) .field("stream_id_sequence", &self.stream_id_sequence) .field( @@ -91,8 +97,6 @@ pub struct EngineInterfaceManager { mpsc::Receiver<(EngineCallId, mpsc::Sender>)>, /// Manages stream messages and state stream_manager: StreamManager, - /// Protocol version info, set after `Hello` received - protocol_info: Option, } impl EngineInterfaceManager { @@ -102,6 +106,7 @@ impl EngineInterfaceManager { EngineInterfaceManager { state: Arc::new(EngineInterfaceState { + protocol_info: Waitable::new(), engine_call_id_sequence: Sequence::default(), stream_id_sequence: Sequence::default(), engine_call_subscription_sender: subscription_tx, @@ -112,7 +117,6 @@ impl EngineInterfaceManager { engine_call_subscriptions: BTreeMap::new(), engine_call_subscription_receiver: subscription_rx, stream_manager: StreamManager::new(), - protocol_info: None, } } @@ -228,12 +232,13 @@ impl InterfaceManager for EngineInterfaceManager { match input { PluginInput::Hello(info) => { + let info = Arc::new(info); + self.state.protocol_info.set(info.clone())?; + let local_info = ProtocolInfo::default(); if local_info.is_compatible_with(&info)? { - self.protocol_info = Some(info); Ok(()) } else { - self.protocol_info = None; Err(ShellError::PluginFailedToLoad { msg: format!( "Plugin is compiled for nushell version {}, \ @@ -243,7 +248,7 @@ impl InterfaceManager for EngineInterfaceManager { }) } } - _ if self.protocol_info.is_none() => { + _ if !self.state.protocol_info.is_set() => { // Must send protocol info first Err(ShellError::PluginFailedToLoad { msg: "Failed to receive initial Hello message. This engine might be too old" @@ -477,6 +482,15 @@ impl EngineInterface { }) } + /// Returns `true` if the plugin is communicating on stdio. When this is the case, stdin and + /// stdout should not be used by the plugin for other purposes. + /// + /// If the plugin can not be used without access to stdio, an error should be presented to the + /// user instead. + pub fn is_using_stdio(&self) -> bool { + self.state.writer.is_stdout() + } + /// Get the full shell configuration from the engine. As this is quite a large object, it is /// provided on request only. /// @@ -656,6 +670,43 @@ impl EngineInterface { } } + /// Returns a guard that will keep the plugin in the foreground as long as the guard is alive. + /// + /// Moving the plugin to the foreground is necessary for plugins that need to receive input and + /// signals directly from the terminal. + /// + /// The exact implementation is operating system-specific. On Unix, this ensures that the + /// plugin process becomes part of the process group controlling the terminal. + pub fn enter_foreground(&self) -> Result { + match self.engine_call(EngineCall::EnterForeground)? { + EngineCallResponse::Error(error) => Err(error), + EngineCallResponse::PipelineData(PipelineData::Value( + Value::Int { val: pgrp, .. }, + _, + )) => { + set_pgrp_from_enter_foreground(pgrp)?; + Ok(ForegroundGuard(Some(self.clone()))) + } + EngineCallResponse::PipelineData(PipelineData::Empty) => { + Ok(ForegroundGuard(Some(self.clone()))) + } + _ => Err(ShellError::PluginFailedToDecode { + msg: "Received unexpected response type for EngineCall::SetForeground".into(), + }), + } + } + + /// Internal: for exiting the foreground after `enter_foreground()`. Called from the guard. + fn leave_foreground(&self) -> Result<(), ShellError> { + match self.engine_call(EngineCall::LeaveForeground)? { + EngineCallResponse::Error(error) => Err(error), + EngineCallResponse::PipelineData(PipelineData::Empty) => Ok(()), + _ => Err(ShellError::PluginFailedToDecode { + msg: "Received unexpected response type for EngineCall::LeaveForeground".into(), + }), + } + } + /// Get the contents of a [`Span`] from the engine. /// /// This method returns `Vec` as it's possible for the matched span to not be a valid UTF-8 @@ -869,3 +920,69 @@ impl Interface for EngineInterface { } } } + +/// Keeps the plugin in the foreground as long as it is alive. +/// +/// Use [`.leave()`] to leave the foreground without ignoring the error. +pub struct ForegroundGuard(Option); + +impl ForegroundGuard { + // Should be called only once + fn leave_internal(&mut self) -> Result<(), ShellError> { + if let Some(interface) = self.0.take() { + // On Unix, we need to put ourselves back in our own process group + #[cfg(unix)] + { + use nix::unistd::{setpgid, Pid}; + // This should always succeed, frankly, but handle the error just in case + setpgid(Pid::from_raw(0), Pid::from_raw(0)).map_err(|err| ShellError::IOError { + msg: err.to_string(), + })?; + } + interface.leave_foreground()?; + } + Ok(()) + } + + /// Leave the foreground. In contrast to dropping the guard, this preserves the error (if any). + pub fn leave(mut self) -> Result<(), ShellError> { + let result = self.leave_internal(); + std::mem::forget(self); + result + } +} + +impl Drop for ForegroundGuard { + fn drop(&mut self) { + let _ = self.leave_internal(); + } +} + +#[cfg(unix)] +fn set_pgrp_from_enter_foreground(pgrp: i64) -> Result<(), ShellError> { + use nix::unistd::{setpgid, Pid}; + if let Ok(pgrp) = pgrp.try_into() { + setpgid(Pid::from_raw(0), Pid::from_raw(pgrp)).map_err(|err| ShellError::GenericError { + error: "Failed to set process group for foreground".into(), + msg: "".into(), + span: None, + help: Some(err.to_string()), + inner: vec![], + }) + } else { + Err(ShellError::NushellFailed { + msg: "Engine returned an invalid process group ID".into(), + }) + } +} + +#[cfg(not(unix))] +fn set_pgrp_from_enter_foreground(_pgrp: i64) -> Result<(), ShellError> { + Err(ShellError::NushellFailed { + msg: concat!( + "EnterForeground asked plugin to join process group, but not supported on ", + cfg!(target_os) + ) + .into(), + }) +} diff --git a/crates/nu-plugin/src/plugin/interface/engine/tests.rs b/crates/nu-plugin/src/plugin/interface/engine/tests.rs index 4693b1774b..572f6a39fe 100644 --- a/crates/nu-plugin/src/plugin/interface/engine/tests.rs +++ b/crates/nu-plugin/src/plugin/interface/engine/tests.rs @@ -15,9 +15,21 @@ use nu_protocol::{ }; use std::{ collections::HashMap, - sync::mpsc::{self, TryRecvError}, + sync::{ + mpsc::{self, TryRecvError}, + Arc, + }, }; +#[test] +fn is_using_stdio_is_false_for_test() { + let test = TestCase::new(); + let manager = test.engine(); + let interface = manager.get_interface(); + + assert!(!interface.is_using_stdio()); +} + #[test] fn manager_consume_all_consumes_messages() -> Result<(), ShellError> { let mut test = TestCase::new(); @@ -247,8 +259,9 @@ fn manager_consume_sets_protocol_info_on_hello() -> Result<(), ShellError> { manager.consume(PluginInput::Hello(info.clone()))?; let set_info = manager + .state .protocol_info - .as_ref() + .try_get()? .expect("protocol info not set"); assert_eq!(info.version, set_info.version); Ok(()) @@ -275,7 +288,7 @@ fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(), let mut manager = TestCase::new().engine(); // hello not set - assert!(manager.protocol_info.is_none()); + assert!(!manager.state.protocol_info.is_set()); let error = manager .consume(PluginInput::Drop(0)) @@ -285,10 +298,17 @@ fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(), Ok(()) } +fn set_default_protocol_info(manager: &mut EngineInterfaceManager) -> Result<(), ShellError> { + manager + .state + .protocol_info + .set(Arc::new(ProtocolInfo::default())) +} + #[test] fn manager_consume_goodbye_closes_plugin_call_channel() -> Result<(), ShellError> { let mut manager = TestCase::new().engine(); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = manager .take_plugin_call_receiver() @@ -307,7 +327,7 @@ fn manager_consume_goodbye_closes_plugin_call_channel() -> Result<(), ShellError #[test] fn manager_consume_call_signature_forwards_to_receiver_with_context() -> Result<(), ShellError> { let mut manager = TestCase::new().engine(); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = manager .take_plugin_call_receiver() @@ -327,7 +347,7 @@ fn manager_consume_call_signature_forwards_to_receiver_with_context() -> Result< #[test] fn manager_consume_call_run_forwards_to_receiver_with_context() -> Result<(), ShellError> { let mut manager = TestCase::new().engine(); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = manager .take_plugin_call_receiver() @@ -361,7 +381,7 @@ fn manager_consume_call_run_forwards_to_receiver_with_context() -> Result<(), Sh #[test] fn manager_consume_call_run_forwards_to_receiver_with_pipeline_data() -> Result<(), ShellError> { let mut manager = TestCase::new().engine(); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = manager .take_plugin_call_receiver() @@ -403,7 +423,7 @@ fn manager_consume_call_run_forwards_to_receiver_with_pipeline_data() -> Result< #[test] fn manager_consume_call_run_deserializes_custom_values_in_args() -> Result<(), ShellError> { let mut manager = TestCase::new().engine(); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = manager .take_plugin_call_receiver() @@ -469,7 +489,7 @@ fn manager_consume_call_run_deserializes_custom_values_in_args() -> Result<(), S fn manager_consume_call_custom_value_op_forwards_to_receiver_with_context() -> Result<(), ShellError> { let mut manager = TestCase::new().engine(); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = manager .take_plugin_call_receiver() @@ -509,7 +529,7 @@ fn manager_consume_call_custom_value_op_forwards_to_receiver_with_context() -> R fn manager_consume_engine_call_response_forwards_to_subscriber_with_pipeline_data( ) -> Result<(), ShellError> { let mut manager = TestCase::new().engine(); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = fake_engine_call(&mut manager, 0); diff --git a/crates/nu-plugin/src/plugin/interface/plugin.rs b/crates/nu-plugin/src/plugin/interface/plugin.rs index 93b04c8ef9..5be513a2f1 100644 --- a/crates/nu-plugin/src/plugin/interface/plugin.rs +++ b/crates/nu-plugin/src/plugin/interface/plugin.rs @@ -5,14 +5,14 @@ use super::{ Interface, InterfaceManager, PipelineDataWriter, PluginRead, PluginWrite, }; use crate::{ - plugin::{context::PluginExecutionContext, gc::PluginGc, PluginSource}, + plugin::{context::PluginExecutionContext, gc::PluginGc, process::PluginProcess, PluginSource}, protocol::{ CallInfo, CustomValueOp, EngineCall, EngineCallId, EngineCallResponse, Ordering, PluginCall, PluginCallId, PluginCallResponse, PluginCustomValue, PluginInput, PluginOption, PluginOutput, ProtocolInfo, StreamId, StreamMessage, }, sequence::Sequence, - util::with_custom_values_in, + util::{with_custom_values_in, Waitable}, }; use nu_protocol::{ ast::Operator, CustomValue, IntoInterruptiblePipelineData, IntoSpanned, ListStream, @@ -62,6 +62,10 @@ impl std::ops::Deref for Context { struct PluginInterfaceState { /// The source to be used for custom values coming from / going to the plugin source: Arc, + /// The plugin process being managed + process: Option, + /// Protocol version info, set after `Hello` received + protocol_info: Waitable>, /// Sequence for generating plugin call ids plugin_call_id_sequence: Sequence, /// Sequence for generating stream ids @@ -78,12 +82,14 @@ impl std::fmt::Debug for PluginInterfaceState { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PluginInterfaceState") .field("source", &self.source) + .field("protocol_info", &self.protocol_info) .field("plugin_call_id_sequence", &self.plugin_call_id_sequence) .field("stream_id_sequence", &self.stream_id_sequence) .field( "plugin_call_subscription_sender", &self.plugin_call_subscription_sender, ) + .field("error", &self.error) .finish_non_exhaustive() } } @@ -132,8 +138,6 @@ pub struct PluginInterfaceManager { state: Arc, /// Manages stream messages and state stream_manager: StreamManager, - /// Protocol version info, set after `Hello` received - protocol_info: Option, /// State related to plugin calls plugin_call_states: BTreeMap, /// Receiver for plugin call subscriptions @@ -149,6 +153,7 @@ pub struct PluginInterfaceManager { impl PluginInterfaceManager { pub fn new( source: Arc, + pid: Option, writer: impl PluginWrite + 'static, ) -> PluginInterfaceManager { let (subscription_tx, subscription_rx) = mpsc::channel(); @@ -156,6 +161,8 @@ impl PluginInterfaceManager { PluginInterfaceManager { state: Arc::new(PluginInterfaceState { source, + process: pid.map(PluginProcess::new), + protocol_info: Waitable::new(), plugin_call_id_sequence: Sequence::default(), stream_id_sequence: Sequence::default(), plugin_call_subscription_sender: subscription_tx, @@ -163,7 +170,6 @@ impl PluginInterfaceManager { writer: Box::new(writer), }), stream_manager: StreamManager::new(), - protocol_info: None, plugin_call_states: BTreeMap::new(), plugin_call_subscription_receiver: subscription_rx, plugin_call_input_streams: BTreeMap::new(), @@ -289,9 +295,10 @@ impl PluginInterfaceManager { })?; // Generate the state needed to handle engine calls - let current_call_state = CurrentCallState { + let mut current_call_state = CurrentCallState { context_tx: None, keep_plugin_custom_values_tx: Some(state.keep_plugin_custom_values.0.clone()), + entered_foreground: false, }; let handler = move || { @@ -308,7 +315,7 @@ impl PluginInterfaceManager { if let Err(err) = interface.handle_engine_call( engine_call_id, engine_call, - ¤t_call_state, + &mut current_call_state, context.as_deref_mut(), ) { log::warn!( @@ -453,12 +460,13 @@ impl InterfaceManager for PluginInterfaceManager { match input { PluginOutput::Hello(info) => { + let info = Arc::new(info); + self.state.protocol_info.set(info.clone())?; + let local_info = ProtocolInfo::default(); if local_info.is_compatible_with(&info)? { - self.protocol_info = Some(info); Ok(()) } else { - self.protocol_info = None; Err(ShellError::PluginFailedToLoad { msg: format!( "Plugin `{}` is compiled for nushell version {}, \ @@ -470,7 +478,7 @@ impl InterfaceManager for PluginInterfaceManager { }) } } - _ if self.protocol_info.is_none() => { + _ if !self.state.protocol_info.is_set() => { // Must send protocol info first Err(ShellError::PluginFailedToLoad { msg: format!( @@ -613,6 +621,16 @@ pub struct PluginInterface { } impl PluginInterface { + /// Get the process ID for the plugin, if known. + pub fn pid(&self) -> Option { + self.state.process.as_ref().map(|p| p.pid()) + } + + /// Get the protocol info for the plugin. Will block to receive `Hello` if not received yet. + pub fn protocol_info(&self) -> Result, ShellError> { + self.state.protocol_info.get() + } + /// Write the protocol info. This should be done after initialization pub fn hello(&self) -> Result<(), ShellError> { self.write(PluginInput::Hello(ProtocolInfo::default()))?; @@ -673,6 +691,7 @@ impl PluginInterface { let state = CurrentCallState { context_tx: Some(context_tx), keep_plugin_custom_values_tx: Some(keep_plugin_custom_values.0.clone()), + entered_foreground: false, }; // Prepare the call with the state. @@ -767,12 +786,22 @@ impl PluginInterface { &self, rx: mpsc::Receiver, mut context: Option<&mut (dyn PluginExecutionContext + '_)>, - state: CurrentCallState, + mut state: CurrentCallState, ) -> Result, ShellError> { // Handle message from receiver for msg in rx { match msg { ReceivedPluginCallMessage::Response(resp) => { + if state.entered_foreground { + // Make the plugin leave the foreground on return, even if it's a stream + if let Some(context) = context.as_deref_mut() { + if let Err(err) = + set_foreground(self.state.process.as_ref(), context, false) + { + log::warn!("Failed to leave foreground state on exit: {err:?}"); + } + } + } if resp.has_stream() { // If the response has a stream, we need to register the context if let Some(context) = context { @@ -790,7 +819,7 @@ impl PluginInterface { self.handle_engine_call( engine_call_id, engine_call, - &state, + &mut state, context.as_deref_mut(), )?; } @@ -807,11 +836,12 @@ impl PluginInterface { &self, engine_call_id: EngineCallId, engine_call: EngineCall, - state: &CurrentCallState, + state: &mut CurrentCallState, context: Option<&mut (dyn PluginExecutionContext + '_)>, ) -> Result<(), ShellError> { - let resp = - handle_engine_call(engine_call, context).unwrap_or_else(EngineCallResponse::Error); + let process = self.state.process.as_ref(); + let resp = handle_engine_call(engine_call, state, context, process) + .unwrap_or_else(EngineCallResponse::Error); // Handle stream let mut writer = None; let resp = resp @@ -1057,6 +1087,9 @@ pub struct CurrentCallState { /// Sender for a channel that retains plugin custom values that need to stay alive for the /// duration of a plugin call. keep_plugin_custom_values_tx: Option>, + /// The plugin call entered the foreground: this should be cleaned up automatically when the + /// plugin call returns. + entered_foreground: bool, } impl CurrentCallState { @@ -1144,7 +1177,9 @@ impl CurrentCallState { /// Handle an engine call. pub(crate) fn handle_engine_call( call: EngineCall, + state: &mut CurrentCallState, context: Option<&mut (dyn PluginExecutionContext + '_)>, + process: Option<&PluginProcess>, ) -> Result, ShellError> { let call_name = call.name(); @@ -1190,6 +1225,16 @@ pub(crate) fn handle_engine_call( help.item, help.span, ))) } + EngineCall::EnterForeground => { + let resp = set_foreground(process, context, true)?; + state.entered_foreground = true; + Ok(resp) + } + EngineCall::LeaveForeground => { + let resp = set_foreground(process, context, false)?; + state.entered_foreground = false; + Ok(resp) + } EngineCall::GetSpanContents(span) => { let contents = context.get_span_contents(span)?; Ok(EngineCallResponse::value(Value::binary( @@ -1208,3 +1253,39 @@ pub(crate) fn handle_engine_call( .map(EngineCallResponse::PipelineData), } } + +/// Implements enter/exit foreground +fn set_foreground( + process: Option<&PluginProcess>, + context: &mut dyn PluginExecutionContext, + enter: bool, +) -> Result, ShellError> { + if let Some(process) = process { + if let Some(pipeline_externals_state) = context.pipeline_externals_state() { + if enter { + let pgrp = process.enter_foreground(context.span(), pipeline_externals_state)?; + Ok(pgrp.map_or_else(EngineCallResponse::empty, |id| { + EngineCallResponse::value(Value::int(id as i64, context.span())) + })) + } else { + process.exit_foreground()?; + Ok(EngineCallResponse::empty()) + } + } else { + // This should always be present on a real context + Err(ShellError::NushellFailed { + msg: "missing required pipeline_externals_state from context \ + for entering foreground" + .into(), + }) + } + } else { + Err(ShellError::GenericError { + error: "Can't manage plugin process to enter foreground".into(), + msg: "the process ID for this plugin is unknown".into(), + span: Some(context.span()), + help: Some("the plugin may be running in a test".into()), + inner: vec![], + }) + } +} diff --git a/crates/nu-plugin/src/plugin/interface/plugin/tests.rs b/crates/nu-plugin/src/plugin/interface/plugin/tests.rs index 290b04a3ce..77dcc6d4d2 100644 --- a/crates/nu-plugin/src/plugin/interface/plugin/tests.rs +++ b/crates/nu-plugin/src/plugin/interface/plugin/tests.rs @@ -279,8 +279,9 @@ fn manager_consume_sets_protocol_info_on_hello() -> Result<(), ShellError> { manager.consume(PluginOutput::Hello(info.clone()))?; let set_info = manager + .state .protocol_info - .as_ref() + .try_get()? .expect("protocol info not set"); assert_eq!(info.version, set_info.version); Ok(()) @@ -307,7 +308,7 @@ fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(), let mut manager = TestCase::new().plugin("test"); // hello not set - assert!(manager.protocol_info.is_none()); + assert!(!manager.state.protocol_info.is_set()); let error = manager .consume(PluginOutput::Drop(0)) @@ -317,11 +318,18 @@ fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(), Ok(()) } +fn set_default_protocol_info(manager: &mut PluginInterfaceManager) -> Result<(), ShellError> { + manager + .state + .protocol_info + .set(Arc::new(ProtocolInfo::default())) +} + #[test] fn manager_consume_call_response_forwards_to_subscriber_with_pipeline_data( ) -> Result<(), ShellError> { let mut manager = TestCase::new().plugin("test"); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = fake_plugin_call(&mut manager, 0); @@ -359,7 +367,7 @@ fn manager_consume_call_response_forwards_to_subscriber_with_pipeline_data( #[test] fn manager_consume_call_response_registers_streams() -> Result<(), ShellError> { let mut manager = TestCase::new().plugin("test"); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; for n in [0, 1] { fake_plugin_call(&mut manager, n); @@ -428,7 +436,7 @@ fn manager_consume_call_response_registers_streams() -> Result<(), ShellError> { fn manager_consume_engine_call_forwards_to_subscriber_with_pipeline_data() -> Result<(), ShellError> { let mut manager = TestCase::new().plugin("test"); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let rx = fake_plugin_call(&mut manager, 37); @@ -480,7 +488,7 @@ fn manager_consume_engine_call_forwards_to_subscriber_with_pipeline_data() -> Re fn manager_handle_engine_call_after_response_received() -> Result<(), ShellError> { let test = TestCase::new(); let mut manager = test.plugin("test"); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; let (context_tx, context_rx) = mpsc::channel(); @@ -584,7 +592,7 @@ fn manager_send_plugin_call_response_removes_context_only_if_no_streams_to_read( #[test] fn manager_consume_stream_end_removes_context_only_if_last_stream() -> Result<(), ShellError> { let mut manager = TestCase::new().plugin("test"); - manager.protocol_info = Some(ProtocolInfo::default()); + set_default_protocol_info(&mut manager)?; for n in [1, 2] { manager.plugin_call_states.insert( @@ -1274,8 +1282,8 @@ fn prepare_custom_value_sends_to_keep_channel_if_drop_notify() -> Result<(), She let source = Arc::new(PluginSource::new_fake("test")); let (tx, rx) = mpsc::channel(); let state = CurrentCallState { - context_tx: None, keep_plugin_custom_values_tx: Some(tx), + ..Default::default() }; // Try with a custom val that has drop check set let mut drop_val = PluginCustomValue::serialize_from_custom_value(&DropCustomVal, span)? diff --git a/crates/nu-plugin/src/plugin/interface/test_util.rs b/crates/nu-plugin/src/plugin/interface/test_util.rs index b701ff7792..a2acec1e7e 100644 --- a/crates/nu-plugin/src/plugin/interface/test_util.rs +++ b/crates/nu-plugin/src/plugin/interface/test_util.rs @@ -128,7 +128,7 @@ impl TestCase { impl TestCase { /// Create a new [`PluginInterfaceManager`] that writes to this test case. pub(crate) fn plugin(&self, name: &str) -> PluginInterfaceManager { - PluginInterfaceManager::new(PluginSource::new_fake(name).into(), self.clone()) + PluginInterfaceManager::new(PluginSource::new_fake(name).into(), None, self.clone()) } } diff --git a/crates/nu-plugin/src/plugin/mod.rs b/crates/nu-plugin/src/plugin/mod.rs index d314f3c54a..6ab62eb079 100644 --- a/crates/nu-plugin/src/plugin/mod.rs +++ b/crates/nu-plugin/src/plugin/mod.rs @@ -8,12 +8,11 @@ use std::{ cmp::Ordering, collections::HashMap, env, - ffi::OsStr, - fmt::Write, - io::{BufReader, BufWriter, Read, Write as WriteTrait}, + ffi::OsString, + io::{BufReader, BufWriter}, ops::Deref, path::Path, - process::{Child, ChildStdout, Command as CommandSys, Stdio}, + process::{Child, Command as CommandSys}, sync::{ mpsc::{self, TrySendError}, Arc, Mutex, @@ -34,14 +33,23 @@ use std::os::unix::process::CommandExt; use std::os::windows::process::CommandExt; pub use self::interface::{PluginRead, PluginWrite}; -use self::{command::render_examples, gc::PluginGc}; +use self::{ + command::render_examples, + communication_mode::{ + ClientCommunicationIo, CommunicationMode, PreparedServerCommunication, + ServerCommunicationIo, + }, + gc::PluginGc, +}; mod command; +mod communication_mode; mod context; mod declaration; mod gc; mod interface; mod persistent; +mod process; mod source; pub use command::{create_plugin_signature, PluginCommand, SimplePluginCommand}; @@ -84,62 +92,51 @@ pub trait PluginEncoder: Encoder + Encoder { fn name(&self) -> &str; } -fn create_command(path: &Path, shell: Option<&Path>) -> CommandSys { - log::trace!("Starting plugin: {path:?}, shell = {shell:?}"); +fn create_command(path: &Path, mut shell: Option<&Path>, mode: &CommunicationMode) -> CommandSys { + log::trace!("Starting plugin: {path:?}, shell = {shell:?}, mode = {mode:?}"); - // There is only one mode supported at the moment, but the idea is that future - // communication methods could be supported if desirable - let mut input_arg = Some("--stdio"); + let mut shell_args = vec![]; - let mut process = match (path.extension(), shell) { - (_, Some(shell)) => { - let mut process = std::process::Command::new(shell); - process.arg(path); - - process - } - (Some(extension), None) => { - let (shell, command_switch) = match extension.to_str() { - Some("cmd") | Some("bat") => (Some("cmd"), Some("/c")), - Some("sh") => (Some("sh"), Some("-c")), - Some("py") => (Some("python"), None), - _ => (None, None), - }; - - match (shell, command_switch) { - (Some(shell), Some(command_switch)) => { - let mut process = std::process::Command::new(shell); - process.arg(command_switch); - // If `command_switch` is set, we need to pass the path + arg as one argument - // e.g. sh -c "nu_plugin_inc --stdio" - let mut combined = path.as_os_str().to_owned(); - if let Some(arg) = input_arg.take() { - combined.push(OsStr::new(" ")); - combined.push(OsStr::new(arg)); - } - process.arg(combined); - - process + if shell.is_none() { + // We only have to do this for things that are not executable by Rust's Command API on + // Windows. They do handle bat/cmd files for us, helpfully. + // + // Also include anything that wouldn't be executable with a shebang, like JAR files. + shell = match path.extension().and_then(|e| e.to_str()) { + Some("sh") => { + if cfg!(unix) { + // We don't want to override what might be in the shebang if this is Unix, since + // some scripts will have a shebang specifying bash even if they're .sh + None + } else { + Some(Path::new("sh")) } - (Some(shell), None) => { - let mut process = std::process::Command::new(shell); - process.arg(path); - - process - } - _ => std::process::Command::new(path), } - } - (None, None) => std::process::Command::new(path), - }; - - // Pass input_arg, unless we consumed it already - if let Some(input_arg) = input_arg { - process.arg(input_arg); + Some("nu") => Some(Path::new("nu")), + Some("py") => Some(Path::new("python")), + Some("rb") => Some(Path::new("ruby")), + Some("jar") => { + shell_args.push("-jar"); + Some(Path::new("java")) + } + _ => None, + }; } - // Both stdout and stdin are piped so we can receive information from the plugin - process.stdout(Stdio::piped()).stdin(Stdio::piped()); + let mut process = if let Some(shell) = shell { + let mut process = std::process::Command::new(shell); + process.args(shell_args); + process.arg(path); + + process + } else { + std::process::Command::new(path) + }; + + process.args(mode.args()); + + // Setup I/O according to the communication mode + mode.setup_command_io(&mut process); // The plugin should be run in a new process group to prevent Ctrl-C from stopping it #[cfg(unix)] @@ -158,29 +155,53 @@ fn create_command(path: &Path, shell: Option<&Path>) -> CommandSys { fn make_plugin_interface( mut child: Child, + comm: PreparedServerCommunication, source: Arc, + pid: Option, gc: Option, ) -> Result { - let stdin = child - .stdin - .take() - .ok_or_else(|| ShellError::PluginFailedToLoad { - msg: "Plugin missing stdin writer".into(), - })?; + match comm.connect(&mut child)? { + ServerCommunicationIo::Stdio(stdin, stdout) => make_plugin_interface_with_streams( + stdout, + stdin, + move || { + let _ = child.wait(); + }, + source, + pid, + gc, + ), + #[cfg(feature = "local-socket")] + ServerCommunicationIo::LocalSocket { read_out, write_in } => { + make_plugin_interface_with_streams( + read_out, + write_in, + move || { + let _ = child.wait(); + }, + source, + pid, + gc, + ) + } + } +} - let mut stdout = child - .stdout - .take() - .ok_or_else(|| ShellError::PluginFailedToLoad { - msg: "Plugin missing stdout writer".into(), - })?; +fn make_plugin_interface_with_streams( + mut reader: impl std::io::Read + Send + 'static, + writer: impl std::io::Write + Send + 'static, + after_close: impl FnOnce() + Send + 'static, + source: Arc, + pid: Option, + gc: Option, +) -> Result { + let encoder = get_plugin_encoding(&mut reader)?; - let encoder = get_plugin_encoding(&mut stdout)?; + let reader = BufReader::with_capacity(OUTPUT_BUFFER_SIZE, reader); + let writer = BufWriter::with_capacity(OUTPUT_BUFFER_SIZE, writer); - let reader = BufReader::with_capacity(OUTPUT_BUFFER_SIZE, stdout); - let writer = BufWriter::with_capacity(OUTPUT_BUFFER_SIZE, stdin); - - let mut manager = PluginInterfaceManager::new(source.clone(), (Mutex::new(writer), encoder)); + let mut manager = + PluginInterfaceManager::new(source.clone(), pid, (Mutex::new(writer), encoder)); manager.set_garbage_collector(gc); let interface = manager.get_interface(); @@ -198,10 +219,10 @@ fn make_plugin_interface( if let Err(err) = manager.consume_all((reader, encoder)) { log::warn!("Error in PluginInterfaceManager: {err}"); } - // If the loop has ended, drop the manager so everyone disconnects and then wait for the - // child to exit + // If the loop has ended, drop the manager so everyone disconnects and then run + // after_close drop(manager); - let _ = child.wait(); + after_close(); }) .map_err(|err| ShellError::PluginFailedToLoad { msg: format!("Failed to spawn thread for plugin: {err}"), @@ -211,15 +232,10 @@ fn make_plugin_interface( } #[doc(hidden)] // Note: not for plugin authors / only used in nu-parser -pub fn get_signature( +pub fn get_signature( plugin: Arc, - envs: impl FnOnce() -> Result, -) -> Result, ShellError> -where - E: IntoIterator, - K: AsRef, - V: AsRef, -{ + envs: impl FnOnce() -> Result, ShellError>, +) -> Result, ShellError> { plugin.get(envs)?.get_signature() } @@ -412,9 +428,7 @@ pub trait Plugin: Sync { /// } /// ``` pub fn serve_plugin(plugin: &impl Plugin, encoder: impl PluginEncoder + 'static) { - let mut args = env::args().skip(1); - let number_of_args = args.len(); - let first_arg = args.next(); + let args: Vec = env::args_os().skip(1).collect(); // Determine the plugin name, for errors let exe = std::env::current_exe().ok(); @@ -430,18 +444,26 @@ pub fn serve_plugin(plugin: &impl Plugin, encoder: impl PluginEncoder + 'static) }) .unwrap_or_else(|| "(unknown)".into()); - if number_of_args == 0 - || first_arg - .as_ref() - .is_some_and(|arg| arg == "-h" || arg == "--help") - { + if args.is_empty() || args[0] == "-h" || args[0] == "--help" { print_help(plugin, encoder); std::process::exit(0) } - // Must pass --stdio for plugin execution. Any other arg is an error to give us options in the - // future. - if number_of_args > 1 || !first_arg.is_some_and(|arg| arg == "--stdio") { + // Implement different communication modes: + let mode = if args[0] == "--stdio" && args.len() == 1 { + // --stdio always supported. + CommunicationMode::Stdio + } else if args[0] == "--local-socket" && args.len() == 2 { + #[cfg(feature = "local-socket")] + { + CommunicationMode::LocalSocket((&args[1]).into()) + } + #[cfg(not(feature = "local-socket"))] + { + eprintln!("{plugin_name}: local socket mode is not supported"); + std::process::exit(1); + } + } else { eprintln!( "{}: This plugin must be run from within Nushell.", env::current_exe() @@ -453,34 +475,42 @@ pub fn serve_plugin(plugin: &impl Plugin, encoder: impl PluginEncoder + 'static) version of nushell you are using." ); std::process::exit(1) - } - - // tell nushell encoding. - // - // 1 byte - // encoding format: | content-length | content | - let mut stdout = std::io::stdout(); - { - let encoding = encoder.name(); - let length = encoding.len() as u8; - let mut encoding_content: Vec = encoding.as_bytes().to_vec(); - encoding_content.insert(0, length); - stdout - .write_all(&encoding_content) - .expect("Failed to tell nushell my encoding"); - stdout - .flush() - .expect("Failed to tell nushell my encoding when flushing stdout"); - } + }; let encoder_clone = encoder.clone(); - let result = serve_plugin_io( - plugin, - &plugin_name, - move || (std::io::stdin().lock(), encoder_clone), - move || (std::io::stdout(), encoder), - ); + let result = match mode.connect_as_client() { + Ok(ClientCommunicationIo::Stdio(stdin, mut stdout)) => { + tell_nushell_encoding(&mut stdout, &encoder).expect("failed to tell nushell encoding"); + serve_plugin_io( + plugin, + &plugin_name, + move || (stdin.lock(), encoder_clone), + move || (stdout, encoder), + ) + } + #[cfg(feature = "local-socket")] + Ok(ClientCommunicationIo::LocalSocket { + read_in, + mut write_out, + }) => { + tell_nushell_encoding(&mut write_out, &encoder) + .expect("failed to tell nushell encoding"); + + let read = BufReader::with_capacity(OUTPUT_BUFFER_SIZE, read_in); + let write = Mutex::new(BufWriter::with_capacity(OUTPUT_BUFFER_SIZE, write_out)); + serve_plugin_io( + plugin, + &plugin_name, + move || (read, encoder_clone), + move || (write, encoder), + ) + } + Err(err) => { + eprintln!("{plugin_name}: failed to connect: {err:?}"); + std::process::exit(1); + } + }; match result { Ok(()) => (), @@ -493,6 +523,22 @@ pub fn serve_plugin(plugin: &impl Plugin, encoder: impl PluginEncoder + 'static) } } +fn tell_nushell_encoding( + writer: &mut impl std::io::Write, + encoder: &impl PluginEncoder, +) -> Result<(), std::io::Error> { + // tell nushell encoding. + // + // 1 byte + // encoding format: | content-length | content | + let encoding = encoder.name(); + let length = encoding.len() as u8; + let mut encoding_content: Vec = encoding.as_bytes().to_vec(); + encoding_content.insert(0, length); + writer.write_all(&encoding_content)?; + writer.flush() +} + /// An error from [`serve_plugin_io()`] #[derive(Debug, Error)] pub enum ServePluginError { @@ -765,6 +811,8 @@ fn custom_value_op( } fn print_help(plugin: &impl Plugin, encoder: impl PluginEncoder) { + use std::fmt::Write; + println!("Nushell Plugin"); println!("Encoder: {}", encoder.name()); @@ -831,7 +879,9 @@ fn print_help(plugin: &impl Plugin, encoder: impl PluginEncoder) { println!("{help}") } -pub fn get_plugin_encoding(child_stdout: &mut ChildStdout) -> Result { +pub fn get_plugin_encoding( + child_stdout: &mut impl std::io::Read, +) -> Result { let mut length_buf = [0u8; 1]; child_stdout .read_exact(&mut length_buf) diff --git a/crates/nu-plugin/src/plugin/persistent.rs b/crates/nu-plugin/src/plugin/persistent.rs index a2f3cf0733..5c08add437 100644 --- a/crates/nu-plugin/src/plugin/persistent.rs +++ b/crates/nu-plugin/src/plugin/persistent.rs @@ -1,10 +1,13 @@ -use super::{create_command, gc::PluginGc, make_plugin_interface, PluginInterface, PluginSource}; +use super::{ + communication_mode::CommunicationMode, create_command, gc::PluginGc, make_plugin_interface, + PluginInterface, PluginSource, +}; use nu_protocol::{ engine::{EngineState, Stack}, PluginGcConfig, PluginIdentity, RegisteredPlugin, ShellError, }; use std::{ - ffi::OsStr, + collections::HashMap, sync::{Arc, Mutex}, }; @@ -28,14 +31,21 @@ pub struct PersistentPlugin { struct MutableState { /// Reference to the plugin if running running: Option, + /// Plugin's preferred communication mode (if known) + preferred_mode: Option, /// Garbage collector config gc_config: PluginGcConfig, } +#[derive(Debug, Clone, Copy)] +enum PreferredCommunicationMode { + Stdio, + #[cfg(feature = "local-socket")] + LocalSocket, +} + #[derive(Debug)] struct RunningPlugin { - /// Process ID of the running plugin - pid: u32, /// Interface (which can be cloned) to the running plugin interface: PluginInterface, /// Garbage collector for the plugin @@ -49,6 +59,7 @@ impl PersistentPlugin { identity, mutable: Mutex::new(MutableState { running: None, + preferred_mode: None, gc_config, }), } @@ -58,15 +69,10 @@ impl PersistentPlugin { /// /// Will call `envs` to get environment variables to spawn the plugin if the plugin needs to be /// spawned. - pub(crate) fn get( + pub(crate) fn get( self: Arc, - envs: impl FnOnce() -> Result, - ) -> Result - where - E: IntoIterator, - K: AsRef, - V: AsRef, - { + envs: impl FnOnce() -> Result, ShellError>, + ) -> Result { let mut mutable = self.mutable.lock().map_err(|_| ShellError::NushellFailed { msg: format!( "plugin `{}` mutex poisoned, probably panic during spawn", @@ -78,28 +84,70 @@ impl PersistentPlugin { // It exists, so just clone the interface Ok(running.interface.clone()) } else { - // Try to spawn, and then store the spawned plugin if we were successful. + // Try to spawn. On success, `mutable.running` should have been set to the new running + // plugin by `spawn()` so we just then need to clone the interface from there. // // We hold the lock the whole time to prevent others from trying to spawn and ending // up with duplicate plugins // // TODO: We should probably store the envs somewhere, in case we have to launch without // envs (e.g. from a custom value) - let new_running = self.clone().spawn(envs()?, &mutable.gc_config)?; - let interface = new_running.interface.clone(); - mutable.running = Some(new_running); - Ok(interface) + let envs = envs()?; + let result = self.clone().spawn(&envs, &mut mutable); + + // Check if we were using an alternate communication mode and may need to fall back to + // stdio. + if result.is_err() + && !matches!( + mutable.preferred_mode, + Some(PreferredCommunicationMode::Stdio) + ) + { + log::warn!("{}: Trying again with stdio communication because mode {:?} failed with {result:?}", + self.identity.name(), + mutable.preferred_mode); + // Reset to stdio and try again, but this time don't catch any error + mutable.preferred_mode = Some(PreferredCommunicationMode::Stdio); + self.clone().spawn(&envs, &mut mutable)?; + } + + Ok(mutable + .running + .as_ref() + .ok_or_else(|| ShellError::NushellFailed { + msg: "spawn() succeeded but didn't set interface".into(), + })? + .interface + .clone()) } } - /// Run the plugin command, then set up and return [`RunningPlugin`]. + /// Run the plugin command, then set up and set `mutable.running` to the new running plugin. fn spawn( self: Arc, - envs: impl IntoIterator, impl AsRef)>, - gc_config: &PluginGcConfig, - ) -> Result { + envs: &HashMap, + mutable: &mut MutableState, + ) -> Result<(), ShellError> { + // Make sure `running` is set to None to begin + if let Some(running) = mutable.running.take() { + // Stop the GC if there was a running plugin + running.gc.stop_tracking(); + } + let source_file = self.identity.filename(); - let mut plugin_cmd = create_command(source_file, self.identity.shell()); + + // Determine the mode to use based on the preferred mode + let mode = match mutable.preferred_mode { + // If not set, we try stdio first and then might retry if another mode is supported + Some(PreferredCommunicationMode::Stdio) | None => CommunicationMode::Stdio, + // Local socket only if enabled + #[cfg(feature = "local-socket")] + Some(PreferredCommunicationMode::LocalSocket) => { + CommunicationMode::local_socket(source_file) + } + }; + + let mut plugin_cmd = create_command(source_file, self.identity.shell(), &mode); // We need the current environment variables for `python` based plugins // Or we'll likely have a problem when a plugin is implemented in a virtual Python environment. @@ -107,6 +155,9 @@ impl PersistentPlugin { let program_name = plugin_cmd.get_program().to_os_string().into_string(); + // Before running the command, prepare communication + let comm = mode.serve()?; + // Run the plugin command let child = plugin_cmd.spawn().map_err(|err| { let error_msg = match err.kind() { @@ -126,13 +177,64 @@ impl PersistentPlugin { })?; // Start the plugin garbage collector - let gc = PluginGc::new(gc_config.clone(), &self)?; + let gc = PluginGc::new(mutable.gc_config.clone(), &self)?; let pid = child.id(); - let interface = - make_plugin_interface(child, Arc::new(PluginSource::new(self)), Some(gc.clone()))?; + let interface = make_plugin_interface( + child, + comm, + Arc::new(PluginSource::new(self.clone())), + Some(pid), + Some(gc.clone()), + )?; - Ok(RunningPlugin { pid, interface, gc }) + // If our current preferred mode is None, check to see if the plugin might support another + // mode. If so, retry spawn() with that mode + #[cfg(feature = "local-socket")] + if mutable.preferred_mode.is_none() + && interface + .protocol_info()? + .supports_feature(&crate::protocol::Feature::LocalSocket) + { + log::trace!( + "{}: Attempting to upgrade to local socket mode", + self.identity.name() + ); + // Stop the GC we just created from tracking so that we don't accidentally try to + // stop the new plugin + gc.stop_tracking(); + // Set the mode and try again + mutable.preferred_mode = Some(PreferredCommunicationMode::LocalSocket); + return self.spawn(envs, mutable); + } + + mutable.running = Some(RunningPlugin { interface, gc }); + Ok(()) + } + + fn stop_internal(&self, reset: bool) -> Result<(), ShellError> { + let mut mutable = self.mutable.lock().map_err(|_| ShellError::NushellFailed { + msg: format!( + "plugin `{}` mutable mutex poisoned, probably panic during spawn", + self.identity.name() + ), + })?; + + // If the plugin is running, stop its GC, so that the GC doesn't accidentally try to stop + // a future plugin + if let Some(ref running) = mutable.running { + running.gc.stop_tracking(); + } + + // We don't try to kill the process or anything, we just drop the RunningPlugin. It should + // exit soon after + mutable.running = None; + + // If this is a reset, we should also reset other learned attributes like preferred_mode + if reset { + mutable.preferred_mode = None; + } + Ok(()) } } @@ -155,27 +257,15 @@ impl RegisteredPlugin for PersistentPlugin { self.mutable .lock() .ok() - .and_then(|r| r.running.as_ref().map(|r| r.pid)) + .and_then(|r| r.running.as_ref().and_then(|r| r.interface.pid())) } fn stop(&self) -> Result<(), ShellError> { - let mut mutable = self.mutable.lock().map_err(|_| ShellError::NushellFailed { - msg: format!( - "plugin `{}` mutable mutex poisoned, probably panic during spawn", - self.identity.name() - ), - })?; + self.stop_internal(false) + } - // If the plugin is running, stop its GC, so that the GC doesn't accidentally try to stop - // a future plugin - if let Some(ref running) = mutable.running { - running.gc.stop_tracking(); - } - - // We don't try to kill the process or anything, we just drop the RunningPlugin. It should - // exit soon after - mutable.running = None; - Ok(()) + fn reset(&self) -> Result<(), ShellError> { + self.stop_internal(true) } fn set_gc_config(&self, gc_config: &PluginGcConfig) { @@ -214,11 +304,12 @@ pub trait GetPlugin: RegisteredPlugin { impl GetPlugin for PersistentPlugin { fn get_plugin( self: Arc, - context: Option<(&EngineState, &mut Stack)>, + mut context: Option<(&EngineState, &mut Stack)>, ) -> Result { self.get(|| { // Get envs from the context if provided. let envs = context + .as_mut() .map(|(engine_state, stack)| { // We need the current environment variables for `python` based plugins. Or // we'll likely have a problem when a plugin is implemented in a virtual Python @@ -228,7 +319,7 @@ impl GetPlugin for PersistentPlugin { }) .transpose()?; - Ok(envs.into_iter().flatten()) + Ok(envs.unwrap_or_default()) }) } } diff --git a/crates/nu-plugin/src/plugin/process.rs b/crates/nu-plugin/src/plugin/process.rs new file mode 100644 index 0000000000..a87e2ea22b --- /dev/null +++ b/crates/nu-plugin/src/plugin/process.rs @@ -0,0 +1,90 @@ +use std::sync::{atomic::AtomicU32, Arc, Mutex, MutexGuard}; + +use nu_protocol::{ShellError, Span}; +use nu_system::ForegroundGuard; + +/// Provides a utility interface for a plugin interface to manage the process the plugin is running +/// in. +#[derive(Debug)] +pub(crate) struct PluginProcess { + pid: u32, + mutable: Mutex, +} + +#[derive(Debug)] +struct MutablePart { + foreground_guard: Option, +} + +impl PluginProcess { + /// Manage a plugin process. + pub(crate) fn new(pid: u32) -> PluginProcess { + PluginProcess { + pid, + mutable: Mutex::new(MutablePart { + foreground_guard: None, + }), + } + } + + /// The process ID of the plugin. + pub(crate) fn pid(&self) -> u32 { + self.pid + } + + fn lock_mutable(&self) -> Result, ShellError> { + self.mutable.lock().map_err(|_| ShellError::NushellFailed { + msg: "the PluginProcess mutable lock has been poisoned".into(), + }) + } + + /// Move the plugin process to the foreground. See [`ForegroundGuard::new`]. + /// + /// This produces an error if the plugin process was already in the foreground. + /// + /// Returns `Some()` on Unix with the process group ID if the plugin process will need to join + /// another process group to be part of the foreground. + pub(crate) fn enter_foreground( + &self, + span: Span, + pipeline_state: &Arc<(AtomicU32, AtomicU32)>, + ) -> Result, ShellError> { + let pid = self.pid; + let mut mutable = self.lock_mutable()?; + if mutable.foreground_guard.is_none() { + let guard = ForegroundGuard::new(pid, pipeline_state).map_err(|err| { + ShellError::GenericError { + error: "Failed to enter foreground".into(), + msg: err.to_string(), + span: Some(span), + help: None, + inner: vec![], + } + })?; + let pgrp = guard.pgrp(); + mutable.foreground_guard = Some(guard); + Ok(pgrp) + } else { + Err(ShellError::GenericError { + error: "Can't enter foreground".into(), + msg: "this plugin is already running in the foreground".into(), + span: Some(span), + help: Some( + "you may be trying to run the command in parallel, or this may be a bug in \ + the plugin" + .into(), + ), + inner: vec![], + }) + } + } + + /// Move the plugin process out of the foreground. See [`ForegroundGuard::reset`]. + /// + /// This is a no-op if the plugin process was already in the background. + pub(crate) fn exit_foreground(&self) -> Result<(), ShellError> { + let mut mutable = self.lock_mutable()?; + drop(mutable.foreground_guard.take()); + Ok(()) + } +} diff --git a/crates/nu-plugin/src/protocol/mod.rs b/crates/nu-plugin/src/protocol/mod.rs index 6f71efc4c3..c716e37901 100644 --- a/crates/nu-plugin/src/protocol/mod.rs +++ b/crates/nu-plugin/src/protocol/mod.rs @@ -17,9 +17,8 @@ use std::collections::HashMap; pub use evaluated_call::EvaluatedCall; pub use plugin_custom_value::PluginCustomValue; -#[cfg(test)] -pub use protocol_info::Protocol; -pub use protocol_info::ProtocolInfo; +#[allow(unused_imports)] // may be unused by compile flags +pub use protocol_info::{Feature, Protocol, ProtocolInfo}; /// A sequential identifier for a stream pub type StreamId = usize; @@ -485,6 +484,10 @@ pub enum EngineCall { AddEnvVar(String, Value), /// Get help for the current command GetHelp, + /// Move the plugin into the foreground for terminal interaction + EnterForeground, + /// Move the plugin out of the foreground once terminal interaction has finished + LeaveForeground, /// Get the contents of a span. Response is a binary which may not parse to UTF-8 GetSpanContents(Span), /// Evaluate a closure with stream input/output @@ -515,6 +518,8 @@ impl EngineCall { EngineCall::GetCurrentDir => "GetCurrentDir", EngineCall::AddEnvVar(..) => "AddEnvVar", EngineCall::GetHelp => "GetHelp", + EngineCall::EnterForeground => "EnterForeground", + EngineCall::LeaveForeground => "LeaveForeground", EngineCall::GetSpanContents(_) => "GetSpanContents", EngineCall::EvalClosure { .. } => "EvalClosure", } @@ -534,6 +539,8 @@ impl EngineCall { EngineCall::GetCurrentDir => EngineCall::GetCurrentDir, EngineCall::AddEnvVar(name, value) => EngineCall::AddEnvVar(name, value), EngineCall::GetHelp => EngineCall::GetHelp, + EngineCall::EnterForeground => EngineCall::EnterForeground, + EngineCall::LeaveForeground => EngineCall::LeaveForeground, EngineCall::GetSpanContents(span) => EngineCall::GetSpanContents(span), EngineCall::EvalClosure { closure, diff --git a/crates/nu-plugin/src/protocol/protocol_info.rs b/crates/nu-plugin/src/protocol/protocol_info.rs index e7f40234b5..922feb64b6 100644 --- a/crates/nu-plugin/src/protocol/protocol_info.rs +++ b/crates/nu-plugin/src/protocol/protocol_info.rs @@ -22,12 +22,13 @@ impl Default for ProtocolInfo { ProtocolInfo { protocol: Protocol::NuPlugin, version: env!("CARGO_PKG_VERSION").into(), - features: vec![], + features: default_features(), } } } impl ProtocolInfo { + /// True if the version specified in `self` is compatible with the version specified in `other`. pub fn is_compatible_with(&self, other: &ProtocolInfo) -> Result { fn parse_failed(error: semver::Error) -> ShellError { ShellError::PluginFailedToLoad { @@ -52,6 +53,11 @@ impl ProtocolInfo { } .matches(&versions[1])) } + + /// True if the protocol info contains a feature compatible with the given feature. + pub fn supports_feature(&self, feature: &Feature) -> bool { + self.features.iter().any(|f| feature.is_compatible_with(f)) + } } /// Indicates the protocol in use. Only one protocol is supported. @@ -72,9 +78,29 @@ pub enum Protocol { #[derive(Serialize, Deserialize, Debug, Clone)] #[serde(tag = "name")] pub enum Feature { + /// The plugin supports running with a local socket passed via `--local-socket` instead of + /// stdio. + LocalSocket, + /// A feature that was not recognized on deserialization. Attempting to serialize this feature /// is an error. Matching against it may only be used if necessary to determine whether /// unsupported features are present. #[serde(other, skip_serializing)] Unknown, } + +impl Feature { + /// True if the feature is considered to be compatible with another feature. + pub fn is_compatible_with(&self, other: &Feature) -> bool { + matches!((self, other), (Feature::LocalSocket, Feature::LocalSocket)) + } +} + +/// Protocol features compiled into this version of `nu-plugin`. +pub fn default_features() -> Vec { + vec![ + // Only available if compiled with the `local-socket` feature flag (enabled by default). + #[cfg(feature = "local-socket")] + Feature::LocalSocket, + ] +} diff --git a/crates/nu-plugin/src/util/mod.rs b/crates/nu-plugin/src/util/mod.rs index 6a4fe5e5d4..ae861705b3 100644 --- a/crates/nu-plugin/src/util/mod.rs +++ b/crates/nu-plugin/src/util/mod.rs @@ -1,5 +1,7 @@ mod mutable_cow; +mod waitable; mod with_custom_values_in; pub(crate) use mutable_cow::*; +pub use waitable::Waitable; pub use with_custom_values_in::*; diff --git a/crates/nu-plugin/src/util/waitable.rs b/crates/nu-plugin/src/util/waitable.rs new file mode 100644 index 0000000000..9793c93b69 --- /dev/null +++ b/crates/nu-plugin/src/util/waitable.rs @@ -0,0 +1,100 @@ +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Condvar, Mutex, MutexGuard, PoisonError, +}; + +use nu_protocol::ShellError; + +/// A container that may be empty, and allows threads to block until it has a value. +#[derive(Debug)] +pub struct Waitable { + is_set: AtomicBool, + mutex: Mutex>, + condvar: Condvar, +} + +#[track_caller] +fn fail_if_poisoned<'a, T>( + result: Result, PoisonError>>, +) -> Result, ShellError> { + match result { + Ok(guard) => Ok(guard), + Err(_) => Err(ShellError::NushellFailedHelp { + msg: "Waitable mutex poisoned".into(), + help: std::panic::Location::caller().to_string(), + }), + } +} + +impl Waitable { + /// Create a new empty `Waitable`. + pub fn new() -> Waitable { + Waitable { + is_set: AtomicBool::new(false), + mutex: Mutex::new(None), + condvar: Condvar::new(), + } + } + + /// Wait for a value to be available and then clone it. + #[track_caller] + pub fn get(&self) -> Result { + let guard = fail_if_poisoned(self.mutex.lock())?; + if let Some(value) = (*guard).clone() { + Ok(value) + } else { + let guard = fail_if_poisoned(self.condvar.wait_while(guard, |g| g.is_none()))?; + Ok((*guard) + .clone() + .expect("checked already for Some but it was None")) + } + } + + /// Clone the value if one is available, but don't wait if not. + #[track_caller] + pub fn try_get(&self) -> Result, ShellError> { + let guard = fail_if_poisoned(self.mutex.lock())?; + Ok((*guard).clone()) + } + + /// Returns true if value is available. + #[track_caller] + pub fn is_set(&self) -> bool { + self.is_set.load(Ordering::SeqCst) + } + + /// Set the value and let waiting threads know. + #[track_caller] + pub fn set(&self, value: T) -> Result<(), ShellError> { + let mut guard = fail_if_poisoned(self.mutex.lock())?; + self.is_set.store(true, Ordering::SeqCst); + *guard = Some(value); + self.condvar.notify_all(); + Ok(()) + } +} + +impl Default for Waitable { + fn default() -> Self { + Self::new() + } +} + +#[test] +fn set_from_other_thread() -> Result<(), ShellError> { + use std::sync::Arc; + + let waitable = Arc::new(Waitable::new()); + let waitable_clone = waitable.clone(); + + assert!(!waitable.is_set()); + + std::thread::spawn(move || { + waitable_clone.set(42).expect("error on set"); + }); + + assert_eq!(42, waitable.get()?); + assert_eq!(Some(42), waitable.try_get()?); + assert!(waitable.is_set()); + Ok(()) +} diff --git a/crates/nu-protocol/src/plugin/registered.rs b/crates/nu-protocol/src/plugin/registered.rs index cb0c893728..46d65b41d1 100644 --- a/crates/nu-protocol/src/plugin/registered.rs +++ b/crates/nu-protocol/src/plugin/registered.rs @@ -19,6 +19,10 @@ pub trait RegisteredPlugin: Send + Sync { /// Stop the plugin. fn stop(&self) -> Result<(), ShellError>; + /// Stop the plugin and reset any state so that we don't make any assumptions about the plugin + /// next time it launches. This is used on `register`. + fn reset(&self) -> Result<(), ShellError>; + /// Cast the pointer to an [`Any`] so that its concrete type can be retrieved. /// /// This is necessary in order to allow `nu_plugin` to handle the implementation details of diff --git a/crates/nu-system/src/foreground.rs b/crates/nu-system/src/foreground.rs index f7d675461e..d54cab1f19 100644 --- a/crates/nu-system/src/foreground.rs +++ b/crates/nu-system/src/foreground.rs @@ -1,16 +1,11 @@ use std::{ io, process::{Child, Command}, + sync::{atomic::AtomicU32, Arc}, }; #[cfg(unix)] -use std::{ - io::IsTerminal, - sync::{ - atomic::{AtomicU32, Ordering}, - Arc, - }, -}; +use std::{io::IsTerminal, sync::atomic::Ordering}; #[cfg(unix)] pub use foreground_pgroup::stdin_fd; @@ -97,6 +92,139 @@ impl Drop for ForegroundChild { } } +/// Keeps a specific already existing process in the foreground as long as the [`ForegroundGuard`]. +/// If the process needs to be spawned in the foreground, use [`ForegroundChild`] instead. This is +/// used to temporarily bring plugin processes into the foreground. +/// +/// # OS-specific behavior +/// ## Unix +/// +/// If there is already a foreground external process running, spawned with [`ForegroundChild`], +/// this expects the process ID to remain in the process group created by the [`ForegroundChild`] +/// for the lifetime of the guard, and keeps the terminal controlling process group set to that. If +/// there is no foreground external process running, this sets the foreground process group to the +/// plugin's process ID. The process group that is expected can be retrieved with [`.pgrp()`] if +/// different from the plugin process ID. +/// +/// ## Other systems +/// +/// It does nothing special on non-unix systems. +#[derive(Debug)] +pub struct ForegroundGuard { + #[cfg(unix)] + pgrp: Option, + #[cfg(unix)] + pipeline_state: Arc<(AtomicU32, AtomicU32)>, +} + +impl ForegroundGuard { + /// Move the given process to the foreground. + #[cfg(unix)] + pub fn new( + pid: u32, + pipeline_state: &Arc<(AtomicU32, AtomicU32)>, + ) -> std::io::Result { + use nix::unistd::{self, Pid}; + + let pid_nix = Pid::from_raw(pid as i32); + let (pgrp, pcnt) = pipeline_state.as_ref(); + + // Might have to retry due to race conditions on the atomics + loop { + // Try to give control to the child, if there isn't currently a foreground group + if pgrp + .compare_exchange(0, pid, Ordering::SeqCst, Ordering::SeqCst) + .is_ok() + { + let _ = pcnt.fetch_add(1, Ordering::SeqCst); + + // We don't need the child to change process group. Make the guard now so that if there + // is an error, it will be cleaned up + let guard = ForegroundGuard { + pgrp: None, + pipeline_state: pipeline_state.clone(), + }; + + log::trace!("Giving control of the terminal to the plugin group, pid={pid}"); + + // Set the terminal controlling process group to the child process + unistd::tcsetpgrp(unsafe { stdin_fd() }, pid_nix)?; + + return Ok(guard); + } else if pcnt + .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |count| { + // Avoid a race condition: only increment if count is > 0 + if count > 0 { + Some(count + 1) + } else { + None + } + }) + .is_ok() + { + // We successfully added another count to the foreground process group, which means + // we only need to tell the child process to join this one + let pgrp = pgrp.load(Ordering::SeqCst); + log::trace!( + "Will ask the plugin pid={pid} to join pgrp={pgrp} for control of the \ + terminal" + ); + return Ok(ForegroundGuard { + pgrp: Some(pgrp), + pipeline_state: pipeline_state.clone(), + }); + } else { + // The state has changed, we'll have to retry + continue; + } + } + } + + /// Move the given process to the foreground. + #[cfg(not(unix))] + pub fn new( + pid: u32, + pipeline_state: &Arc<(AtomicU32, AtomicU32)>, + ) -> std::io::Result { + let _ = (pid, pipeline_state); + Ok(ForegroundGuard {}) + } + + /// If the child process is expected to join a different process group to be in the foreground, + /// this returns `Some(pgrp)`. This only ever returns `Some` on Unix. + pub fn pgrp(&self) -> Option { + #[cfg(unix)] + { + self.pgrp + } + #[cfg(not(unix))] + { + None + } + } + + /// This should only be called once by `Drop` + fn reset_internal(&mut self) { + #[cfg(unix)] + { + log::trace!("Leaving the foreground group"); + + let (pgrp, pcnt) = self.pipeline_state.as_ref(); + if pcnt.fetch_sub(1, Ordering::SeqCst) == 1 { + // Clean up if we are the last one around + pgrp.store(0, Ordering::SeqCst); + foreground_pgroup::reset() + } + } + } +} + +impl Drop for ForegroundGuard { + fn drop(&mut self) { + self.reset_internal(); + } +} + // It's a simpler version of fish shell's external process handling. #[cfg(unix)] mod foreground_pgroup { diff --git a/crates/nu-system/src/lib.rs b/crates/nu-system/src/lib.rs index bee037cb8b..6058ed4fcf 100644 --- a/crates/nu-system/src/lib.rs +++ b/crates/nu-system/src/lib.rs @@ -9,7 +9,7 @@ mod windows; #[cfg(unix)] pub use self::foreground::stdin_fd; -pub use self::foreground::ForegroundChild; +pub use self::foreground::{ForegroundChild, ForegroundGuard}; #[cfg(any(target_os = "android", target_os = "linux"))] pub use self::linux::*; #[cfg(target_os = "macos")] diff --git a/crates/nu-test-support/src/macros.rs b/crates/nu-test-support/src/macros.rs index 386127f0b3..0b689ea107 100644 --- a/crates/nu-test-support/src/macros.rs +++ b/crates/nu-test-support/src/macros.rs @@ -203,15 +203,38 @@ macro_rules! nu_with_std { #[macro_export] macro_rules! nu_with_plugins { (cwd: $cwd:expr, plugins: [$(($plugin_name:expr)),+$(,)?], $command:expr) => {{ - $crate::macros::nu_with_plugin_run_test($cwd, &[$($plugin_name),+], $command) + nu_with_plugins!( + cwd: $cwd, + envs: Vec::<(&str, &str)>::new(), + plugins: [$(($plugin_name)),+], + $command + ) }}; (cwd: $cwd:expr, plugin: ($plugin_name:expr), $command:expr) => {{ - $crate::macros::nu_with_plugin_run_test($cwd, &[$plugin_name], $command) + nu_with_plugins!( + cwd: $cwd, + envs: Vec::<(&str, &str)>::new(), + plugin: ($plugin_name), + $command + ) + }}; + + ( + cwd: $cwd:expr, + envs: $envs:expr, + plugins: [$(($plugin_name:expr)),+$(,)?], + $command:expr + ) => {{ + $crate::macros::nu_with_plugin_run_test($cwd, $envs, &[$($plugin_name),+], $command) + }}; + (cwd: $cwd:expr, envs: $envs:expr, plugin: ($plugin_name:expr), $command:expr) => {{ + $crate::macros::nu_with_plugin_run_test($cwd, $envs, &[$plugin_name], $command) }}; } use crate::{Outcome, NATIVE_PATH_ENV_VAR}; +use std::ffi::OsStr; use std::fmt::Write; use std::{ path::Path, @@ -285,7 +308,17 @@ pub fn nu_run_test(opts: NuOpts, commands: impl AsRef, with_std: bool) -> O Outcome::new(out, err.into_owned(), output.status) } -pub fn nu_with_plugin_run_test(cwd: impl AsRef, plugins: &[&str], command: &str) -> Outcome { +pub fn nu_with_plugin_run_test( + cwd: impl AsRef, + envs: E, + plugins: &[&str], + command: &str, +) -> Outcome +where + E: IntoIterator, + K: AsRef, + V: AsRef, +{ let test_bins = crate::fs::binaries(); let test_bins = nu_path::canonicalize_with(&test_bins, ".").unwrap_or_else(|e| { panic!( @@ -325,6 +358,7 @@ pub fn nu_with_plugin_run_test(cwd: impl AsRef, plugins: &[&str], command: executable_path = crate::fs::installed_nu_path(); } let process = match setup_command(&executable_path, &target_cwd) + .envs(envs) .arg("--commands") .arg(commands) .arg("--config") diff --git a/crates/nu_plugin_python/nu_plugin_python_example.py b/crates/nu_plugin_python/nu_plugin_python_example.py index c2601214a6..6aaf1a5000 100755 --- a/crates/nu_plugin_python/nu_plugin_python_example.py +++ b/crates/nu_plugin_python/nu_plugin_python_example.py @@ -133,6 +133,13 @@ def process_call(id, plugin_call): span = plugin_call["Run"]["call"]["head"] # Creates a Value of type List that will be encoded and sent to Nushell + f = lambda x, y: { + "Int": { + "val": x * y, + "span": span + } + } + value = { "Value": { "List": { @@ -140,15 +147,9 @@ def process_call(id, plugin_call): { "Record": { "val": { - "cols": ["one", "two", "three"], - "vals": [ - { - "Int": { - "val": x * y, - "span": span - } - } for y in [0, 1, 2] - ] + "one": f(x, 0), + "two": f(x, 1), + "three": f(x, 2), }, "span": span } diff --git a/crates/nu_plugin_stress_internals/Cargo.toml b/crates/nu_plugin_stress_internals/Cargo.toml new file mode 100644 index 0000000000..0062c627fb --- /dev/null +++ b/crates/nu_plugin_stress_internals/Cargo.toml @@ -0,0 +1,19 @@ +[package] +authors = ["The Nushell Project Developers"] +description = "A test plugin for Nushell to stress aspects of the internals" +repository = "https://github.com/nushell/nushell/tree/main/crates/nu_plugin_stress_internals" +edition = "2021" +license = "MIT" +name = "nu_plugin_stress_internals" +version = "0.92.3" + +[[bin]] +name = "nu_plugin_stress_internals" +bench = false + +[dependencies] +# Intentionally not using the nu-protocol / nu-plugin crates, to check behavior against our +# assumptions about the serialized format +serde = { workspace = true } +serde_json = { workspace = true } +interprocess = "1.2.1" diff --git a/crates/nu_plugin_stress_internals/LICENSE b/crates/nu_plugin_stress_internals/LICENSE new file mode 100644 index 0000000000..ae174e8595 --- /dev/null +++ b/crates/nu_plugin_stress_internals/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 - 2023 The Nushell Project Developers + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/crates/nu_plugin_stress_internals/src/main.rs b/crates/nu_plugin_stress_internals/src/main.rs new file mode 100644 index 0000000000..3f95d5ef02 --- /dev/null +++ b/crates/nu_plugin_stress_internals/src/main.rs @@ -0,0 +1,213 @@ +use std::{ + error::Error, + io::{BufRead, BufReader, Write}, +}; + +use interprocess::local_socket::LocalSocketStream; +use serde::Deserialize; +use serde_json::{json, Value}; + +#[derive(Debug)] +struct Options { + refuse_local_socket: bool, + advertise_local_socket: bool, + exit_early: bool, + wrong_version: bool, + local_socket_path: Option, +} + +pub fn main() -> Result<(), Box> { + let args: Vec = std::env::args().collect(); + + eprintln!("stress_internals: args: {args:?}"); + + // Parse options from environment variables + fn has_env(var: &str) -> bool { + std::env::var(var).is_ok() + } + let mut opts = Options { + refuse_local_socket: has_env("STRESS_REFUSE_LOCAL_SOCKET"), + advertise_local_socket: has_env("STRESS_ADVERTISE_LOCAL_SOCKET"), + exit_early: has_env("STRESS_EXIT_EARLY"), + wrong_version: has_env("STRESS_WRONG_VERSION"), + local_socket_path: None, + }; + + #[allow(unused_mut)] + let mut should_flush = true; + + let (mut input, mut output): (Box, Box) = + match args.get(1).map(|s| s.as_str()) { + Some("--stdio") => ( + Box::new(std::io::stdin().lock()), + Box::new(std::io::stdout()), + ), + Some("--local-socket") => { + opts.local_socket_path = Some(args[2].clone()); + if opts.refuse_local_socket { + std::process::exit(1) + } else { + let in_socket = LocalSocketStream::connect(args[2].as_str())?; + let out_socket = LocalSocketStream::connect(args[2].as_str())?; + + #[cfg(windows)] + { + // Flushing on a socket on Windows is weird and waits for the other side + should_flush = false; + } + + (Box::new(BufReader::new(in_socket)), Box::new(out_socket)) + } + } + None => { + eprintln!("Run nu_plugin_stress_internals as a plugin from inside nushell"); + std::process::exit(1) + } + _ => { + eprintln!("Received args I don't understand: {args:?}"); + std::process::exit(1) + } + }; + + // Send encoding format + output.write_all(b"\x04json")?; + if should_flush { + output.flush()?; + } + + // Send `Hello` message + write( + &mut output, + should_flush, + &json!({ + "Hello": { + "protocol": "nu-plugin", + "version": if opts.wrong_version { + "0.0.0" + } else { + env!("CARGO_PKG_VERSION") + }, + "features": if opts.advertise_local_socket { + vec![json!({"name": "LocalSocket"})] + } else { + vec![] + }, + } + }), + )?; + + // Read `Hello` message + let mut de = serde_json::Deserializer::from_reader(&mut input); + let hello: Value = Value::deserialize(&mut de)?; + + assert!(hello.get("Hello").is_some()); + + if opts.exit_early { + // Exit without handling anything other than Hello + std::process::exit(0); + } + + // Parse incoming messages + loop { + match Value::deserialize(&mut de) { + Ok(message) => handle_message(&mut output, should_flush, &opts, &message)?, + Err(err) => { + if err.is_eof() { + break; + } else { + return Err(err.into()); + } + } + } + } + + Ok(()) +} + +fn handle_message( + output: &mut impl Write, + should_flush: bool, + opts: &Options, + message: &Value, +) -> Result<(), Box> { + if let Some(plugin_call) = message.get("Call") { + let (id, plugin_call) = (&plugin_call[0], &plugin_call[1]); + if plugin_call.as_str() == Some("Signature") { + write( + output, + should_flush, + &json!({ + "CallResponse": [ + id, + { + "Signature": signatures(), + } + ] + }), + ) + } else if let Some(call_info) = plugin_call.get("Run") { + if call_info["name"].as_str() == Some("stress_internals") { + // Just return debug of opts + let return_value = json!({ + "String": { + "val": format!("{opts:?}"), + "span": &call_info["call"]["head"], + } + }); + write( + output, + should_flush, + &json!({ + "CallResponse": [ + id, + { + "PipelineData": { + "Value": return_value + } + } + ] + }), + ) + } else { + Err(format!("unknown call name: {call_info}").into()) + } + } else { + Err(format!("unknown plugin call: {plugin_call}").into()) + } + } else if message.as_str() == Some("Goodbye") { + std::process::exit(0); + } else { + Err(format!("unknown message: {message}").into()) + } +} + +fn signatures() -> Vec { + vec![json!({ + "sig": { + "name": "stress_internals", + "usage": "Used to test behavior of plugin protocol", + "extra_usage": "", + "search_terms": [], + "required_positional": [], + "optional_positional": [], + "rest_positional": null, + "named": [], + "input_output_types": [], + "allow_variants_without_examples": false, + "is_filter": false, + "creates_scope": false, + "allows_unknown_args": false, + "category": "Experimental", + }, + "examples": [], + })] +} + +fn write(output: &mut impl Write, should_flush: bool, value: &Value) -> Result<(), Box> { + serde_json::to_writer(&mut *output, value)?; + output.write_all(b"\n")?; + if should_flush { + output.flush()?; + } + Ok(()) +} diff --git a/tests/plugins/mod.rs b/tests/plugins/mod.rs index 244af5cbef..f52006d3ad 100644 --- a/tests/plugins/mod.rs +++ b/tests/plugins/mod.rs @@ -5,3 +5,4 @@ mod env; mod formats; mod register; mod stream; +mod stress_internals; diff --git a/tests/plugins/stress_internals.rs b/tests/plugins/stress_internals.rs new file mode 100644 index 0000000000..1207c15252 --- /dev/null +++ b/tests/plugins/stress_internals.rs @@ -0,0 +1,123 @@ +use nu_test_support::nu_with_plugins; + +fn ensure_stress_env_vars_unset() { + for (key, _) in std::env::vars_os() { + if key.to_string_lossy().starts_with("STRESS_") { + panic!("Test is running in a dirty environment: {key:?} is set"); + } + } +} + +#[test] +fn test_stdio() { + ensure_stress_env_vars_unset(); + let result = nu_with_plugins!( + cwd: ".", + plugin: ("nu_plugin_stress_internals"), + "stress_internals" + ); + assert!(result.status.success()); + assert!(result.out.contains("local_socket_path: None")); +} + +#[test] +fn test_local_socket() { + ensure_stress_env_vars_unset(); + let result = nu_with_plugins!( + cwd: ".", + envs: vec![ + ("STRESS_ADVERTISE_LOCAL_SOCKET", "1"), + ], + plugin: ("nu_plugin_stress_internals"), + "stress_internals" + ); + assert!(result.status.success()); + // Should be run once in stdio mode + assert!(result.err.contains("--stdio")); + // And then in local socket mode + assert!(result.err.contains("--local-socket")); + assert!(result.out.contains("local_socket_path: Some")); +} + +#[test] +fn test_failing_local_socket_fallback() { + ensure_stress_env_vars_unset(); + let result = nu_with_plugins!( + cwd: ".", + envs: vec![ + ("STRESS_ADVERTISE_LOCAL_SOCKET", "1"), + ("STRESS_REFUSE_LOCAL_SOCKET", "1"), + ], + plugin: ("nu_plugin_stress_internals"), + "stress_internals" + ); + assert!(result.status.success()); + + // Count the number of times we do stdio/local socket + let mut count_stdio = 0; + let mut count_local_socket = 0; + + for line in result.err.split('\n') { + if line.contains("--stdio") { + count_stdio += 1; + } + if line.contains("--local-socket") { + count_local_socket += 1; + } + } + + // Should be run once in local socket mode + assert_eq!(1, count_local_socket, "count of --local-socket"); + // Should be run twice in stdio mode, due to the fallback + assert_eq!(2, count_stdio, "count of --stdio"); + + // In the end it should not be running in local socket mode, but should succeed + assert!(result.out.contains("local_socket_path: None")); +} + +#[test] +fn test_exit_early_stdio() { + ensure_stress_env_vars_unset(); + let result = nu_with_plugins!( + cwd: ".", + envs: vec![ + ("STRESS_EXIT_EARLY", "1"), + ], + plugin: ("nu_plugin_stress_internals"), + "stress_internals" + ); + assert!(!result.status.success()); + assert!(result.err.contains("--stdio")); +} + +#[test] +fn test_exit_early_local_socket() { + ensure_stress_env_vars_unset(); + let result = nu_with_plugins!( + cwd: ".", + envs: vec![ + ("STRESS_ADVERTISE_LOCAL_SOCKET", "1"), + ("STRESS_EXIT_EARLY", "1"), + ], + plugin: ("nu_plugin_stress_internals"), + "stress_internals" + ); + assert!(!result.status.success()); + assert!(result.err.contains("--local-socket")); +} + +#[test] +fn test_wrong_version() { + ensure_stress_env_vars_unset(); + let result = nu_with_plugins!( + cwd: ".", + envs: vec![ + ("STRESS_WRONG_VERSION", "1"), + ], + plugin: ("nu_plugin_stress_internals"), + "stress_internals" + ); + assert!(!result.status.success()); + assert!(result.err.contains("version")); + assert!(result.err.contains("0.0.0")); +}