diff --git a/crates/nu-plugin/src/plugin/interface/engine.rs b/crates/nu-plugin/src/plugin/interface/engine.rs index e394b561e8..d4513e9033 100644 --- a/crates/nu-plugin/src/plugin/interface/engine.rs +++ b/crates/nu-plugin/src/plugin/interface/engine.rs @@ -10,7 +10,7 @@ use crate::{ PluginCall, PluginCallId, PluginCallResponse, PluginCustomValue, PluginInput, PluginOption, PluginOutput, ProtocolInfo, }, - util::Waitable, + util::{Waitable, WaitableMut}, }; use nu_protocol::{ engine::Closure, Config, IntoInterruptiblePipelineData, LabeledError, ListStream, PipelineData, @@ -85,6 +85,8 @@ impl std::fmt::Debug for EngineInterfaceState { pub struct EngineInterfaceManager { /// Shared state state: Arc, + /// The writer for protocol info + protocol_info_mut: WaitableMut>, /// Channel to send received PluginCalls to. This is removed after `Goodbye` is received. plugin_call_sender: Option>, /// Receiver for PluginCalls. This is usually taken after initialization @@ -103,15 +105,17 @@ impl EngineInterfaceManager { pub(crate) fn new(writer: impl PluginWrite + 'static) -> EngineInterfaceManager { let (plug_tx, plug_rx) = mpsc::channel(); let (subscription_tx, subscription_rx) = mpsc::channel(); + let protocol_info_mut = WaitableMut::new(); EngineInterfaceManager { state: Arc::new(EngineInterfaceState { - protocol_info: Waitable::new(), + protocol_info: protocol_info_mut.reader(), engine_call_id_sequence: Sequence::default(), stream_id_sequence: Sequence::default(), engine_call_subscription_sender: subscription_tx, writer: Box::new(writer), }), + protocol_info_mut, plugin_call_sender: Some(plug_tx), plugin_call_receiver: Some(plug_rx), engine_call_subscriptions: BTreeMap::new(), @@ -233,7 +237,7 @@ impl InterfaceManager for EngineInterfaceManager { match input { PluginInput::Hello(info) => { let info = Arc::new(info); - self.state.protocol_info.set(info.clone())?; + self.protocol_info_mut.set(info.clone())?; let local_info = ProtocolInfo::default(); if local_info.is_compatible_with(&info)? { diff --git a/crates/nu-plugin/src/plugin/interface/engine/tests.rs b/crates/nu-plugin/src/plugin/interface/engine/tests.rs index 572f6a39fe..03387f1527 100644 --- a/crates/nu-plugin/src/plugin/interface/engine/tests.rs +++ b/crates/nu-plugin/src/plugin/interface/engine/tests.rs @@ -300,8 +300,7 @@ fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(), fn set_default_protocol_info(manager: &mut EngineInterfaceManager) -> Result<(), ShellError> { manager - .state - .protocol_info + .protocol_info_mut .set(Arc::new(ProtocolInfo::default())) } diff --git a/crates/nu-plugin/src/plugin/interface/plugin.rs b/crates/nu-plugin/src/plugin/interface/plugin.rs index 67efbf1cdb..a57d79b284 100644 --- a/crates/nu-plugin/src/plugin/interface/plugin.rs +++ b/crates/nu-plugin/src/plugin/interface/plugin.rs @@ -12,7 +12,7 @@ use crate::{ PluginOutput, ProtocolInfo, StreamId, StreamMessage, }, sequence::Sequence, - util::{with_custom_values_in, Waitable}, + util::{with_custom_values_in, Waitable, WaitableMut}, }; use nu_protocol::{ ast::Operator, CustomValue, IntoInterruptiblePipelineData, IntoSpanned, ListStream, @@ -138,6 +138,8 @@ impl Drop for PluginCallState { pub struct PluginInterfaceManager { /// Shared state state: Arc, + /// The writer for protocol info + protocol_info_mut: WaitableMut>, /// Manages stream messages and state stream_manager: StreamManager, /// State related to plugin calls @@ -159,18 +161,20 @@ impl PluginInterfaceManager { writer: impl PluginWrite + 'static, ) -> PluginInterfaceManager { let (subscription_tx, subscription_rx) = mpsc::channel(); + let protocol_info_mut = WaitableMut::new(); PluginInterfaceManager { state: Arc::new(PluginInterfaceState { source, process: pid.map(PluginProcess::new), - protocol_info: Waitable::new(), + protocol_info: protocol_info_mut.reader(), plugin_call_id_sequence: Sequence::default(), stream_id_sequence: Sequence::default(), plugin_call_subscription_sender: subscription_tx, error: OnceLock::new(), writer: Box::new(writer), }), + protocol_info_mut, stream_manager: StreamManager::new(), plugin_call_states: BTreeMap::new(), plugin_call_subscription_receiver: subscription_rx, @@ -464,7 +468,7 @@ impl InterfaceManager for PluginInterfaceManager { match input { PluginOutput::Hello(info) => { let info = Arc::new(info); - self.state.protocol_info.set(info.clone())?; + self.protocol_info_mut.set(info.clone())?; let local_info = ProtocolInfo::default(); if local_info.is_compatible_with(&info)? { @@ -631,7 +635,14 @@ impl PluginInterface { /// Get the protocol info for the plugin. Will block to receive `Hello` if not received yet. pub fn protocol_info(&self) -> Result, ShellError> { - self.state.protocol_info.get() + self.state.protocol_info.get().and_then(|info| { + info.ok_or_else(|| ShellError::PluginFailedToLoad { + msg: format!( + "Failed to get protocol info (`Hello` message) from the `{}` plugin", + self.state.source.identity.name() + ), + }) + }) } /// Write the protocol info. This should be done after initialization diff --git a/crates/nu-plugin/src/plugin/interface/plugin/tests.rs b/crates/nu-plugin/src/plugin/interface/plugin/tests.rs index d40e0887e6..beda84041d 100644 --- a/crates/nu-plugin/src/plugin/interface/plugin/tests.rs +++ b/crates/nu-plugin/src/plugin/interface/plugin/tests.rs @@ -321,8 +321,7 @@ fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(), fn set_default_protocol_info(manager: &mut PluginInterfaceManager) -> Result<(), ShellError> { manager - .state - .protocol_info + .protocol_info_mut .set(Arc::new(ProtocolInfo::default())) } diff --git a/crates/nu-plugin/src/util/mod.rs b/crates/nu-plugin/src/util/mod.rs index ae861705b3..5d226cdfbd 100644 --- a/crates/nu-plugin/src/util/mod.rs +++ b/crates/nu-plugin/src/util/mod.rs @@ -3,5 +3,5 @@ mod waitable; mod with_custom_values_in; pub(crate) use mutable_cow::*; -pub use waitable::Waitable; +pub use waitable::*; pub use with_custom_values_in::*; diff --git a/crates/nu-plugin/src/util/waitable.rs b/crates/nu-plugin/src/util/waitable.rs index 9793c93b69..aaefa6f1b5 100644 --- a/crates/nu-plugin/src/util/waitable.rs +++ b/crates/nu-plugin/src/util/waitable.rs @@ -1,18 +1,36 @@ use std::sync::{ atomic::{AtomicBool, Ordering}, - Condvar, Mutex, MutexGuard, PoisonError, + Arc, Condvar, Mutex, MutexGuard, PoisonError, }; use nu_protocol::ShellError; -/// A container that may be empty, and allows threads to block until it has a value. -#[derive(Debug)] +/// A shared container that may be empty, and allows threads to block until it has a value. +/// +/// This side is read-only - use [`WaitableMut`] on threads that might write a value. +#[derive(Debug, Clone)] pub struct Waitable { + shared: Arc>, +} + +#[derive(Debug)] +pub struct WaitableMut { + shared: Arc>, +} + +#[derive(Debug)] +struct WaitableShared { is_set: AtomicBool, - mutex: Mutex>, + mutex: Mutex>, condvar: Condvar, } +#[derive(Debug)] +struct SyncState { + writers: usize, + value: Option, +} + #[track_caller] fn fail_if_poisoned<'a, T>( result: Result, PoisonError>>, @@ -26,75 +44,138 @@ fn fail_if_poisoned<'a, T>( } } -impl Waitable { - /// Create a new empty `Waitable`. - pub fn new() -> Waitable { - Waitable { - is_set: AtomicBool::new(false), - mutex: Mutex::new(None), - condvar: Condvar::new(), +impl WaitableMut { + /// Create a new empty `WaitableMut`. Call [`.reader()`] to get [`Waitable`]. + pub fn new() -> WaitableMut { + WaitableMut { + shared: Arc::new(WaitableShared { + is_set: AtomicBool::new(false), + mutex: Mutex::new(SyncState { + writers: 1, + value: None, + }), + condvar: Condvar::new(), + }), } } - /// Wait for a value to be available and then clone it. + pub fn reader(&self) -> Waitable { + Waitable { + shared: self.shared.clone(), + } + } + + /// Set the value and let waiting threads know. #[track_caller] - pub fn get(&self) -> Result { - let guard = fail_if_poisoned(self.mutex.lock())?; - if let Some(value) = (*guard).clone() { - Ok(value) + pub fn set(&self, value: T) -> Result<(), ShellError> { + let mut sync_state = fail_if_poisoned(self.shared.mutex.lock())?; + self.shared.is_set.store(true, Ordering::SeqCst); + sync_state.value = Some(value); + self.shared.condvar.notify_all(); + Ok(()) + } +} + +impl Default for WaitableMut { + fn default() -> Self { + Self::new() + } +} + +impl Clone for WaitableMut { + fn clone(&self) -> Self { + let shared = self.shared.clone(); + shared + .mutex + .lock() + .expect("failed to lock mutex to increment writers") + .writers += 1; + WaitableMut { shared } + } +} + +impl Drop for WaitableMut { + fn drop(&mut self) { + // Decrement writers... + if let Ok(mut sync_state) = self.shared.mutex.lock() { + sync_state.writers = sync_state + .writers + .checked_sub(1) + .expect("would decrement writers below zero"); + } + // and notify waiting threads so they have a chance to see it. + self.shared.condvar.notify_all(); + } +} + +impl Waitable { + /// Wait for a value to be available and then clone it. + /// + /// Returns `Ok(None)` if there are no writers left that could possibly place a value. + #[track_caller] + pub fn get(&self) -> Result, ShellError> { + let sync_state = fail_if_poisoned(self.shared.mutex.lock())?; + if let Some(value) = sync_state.value.clone() { + Ok(Some(value)) + } else if sync_state.writers == 0 { + // There can't possibly be a value written, so no point in waiting. + Ok(None) } else { - let guard = fail_if_poisoned(self.condvar.wait_while(guard, |g| g.is_none()))?; - Ok((*guard) - .clone() - .expect("checked already for Some but it was None")) + let sync_state = fail_if_poisoned( + self.shared + .condvar + .wait_while(sync_state, |g| g.writers > 0 && g.value.is_none()), + )?; + Ok(sync_state.value.clone()) } } /// Clone the value if one is available, but don't wait if not. #[track_caller] pub fn try_get(&self) -> Result, ShellError> { - let guard = fail_if_poisoned(self.mutex.lock())?; - Ok((*guard).clone()) + let sync_state = fail_if_poisoned(self.shared.mutex.lock())?; + Ok(sync_state.value.clone()) } /// Returns true if value is available. #[track_caller] pub fn is_set(&self) -> bool { - self.is_set.load(Ordering::SeqCst) - } - - /// Set the value and let waiting threads know. - #[track_caller] - pub fn set(&self, value: T) -> Result<(), ShellError> { - let mut guard = fail_if_poisoned(self.mutex.lock())?; - self.is_set.store(true, Ordering::SeqCst); - *guard = Some(value); - self.condvar.notify_all(); - Ok(()) - } -} - -impl Default for Waitable { - fn default() -> Self { - Self::new() + self.shared.is_set.load(Ordering::SeqCst) } } #[test] fn set_from_other_thread() -> Result<(), ShellError> { - use std::sync::Arc; - - let waitable = Arc::new(Waitable::new()); - let waitable_clone = waitable.clone(); + let waitable_mut = WaitableMut::new(); + let waitable = waitable_mut.reader(); assert!(!waitable.is_set()); std::thread::spawn(move || { - waitable_clone.set(42).expect("error on set"); + waitable_mut.set(42).expect("error on set"); }); - assert_eq!(42, waitable.get()?); + assert_eq!(Some(42), waitable.get()?); assert_eq!(Some(42), waitable.try_get()?); assert!(waitable.is_set()); Ok(()) } + +#[test] +fn dont_deadlock_if_waiting_without_writer() { + use std::time::Duration; + + let (tx, rx) = std::sync::mpsc::channel(); + let writer = WaitableMut::<()>::new(); + let waitable = writer.reader(); + // Ensure there are no writers + drop(writer); + std::thread::spawn(move || { + let _ = tx.send(waitable.get()); + }); + let result = rx + .recv_timeout(Duration::from_secs(10)) + .expect("timed out") + .expect("error"); + assert!(result.is_none()); +} diff --git a/crates/nu_plugin_stress_internals/src/main.rs b/crates/nu_plugin_stress_internals/src/main.rs index b05defe278..78a94db5db 100644 --- a/crates/nu_plugin_stress_internals/src/main.rs +++ b/crates/nu_plugin_stress_internals/src/main.rs @@ -11,6 +11,7 @@ use serde_json::{json, Value}; struct Options { refuse_local_socket: bool, advertise_local_socket: bool, + exit_before_hello: bool, exit_early: bool, wrong_version: bool, local_socket_path: Option, @@ -28,6 +29,7 @@ pub fn main() -> Result<(), Box> { let mut opts = Options { refuse_local_socket: has_env("STRESS_REFUSE_LOCAL_SOCKET"), advertise_local_socket: has_env("STRESS_ADVERTISE_LOCAL_SOCKET"), + exit_before_hello: has_env("STRESS_EXIT_BEFORE_HELLO"), exit_early: has_env("STRESS_EXIT_EARLY"), wrong_version: has_env("STRESS_WRONG_VERSION"), local_socket_path: None, @@ -75,6 +77,11 @@ pub fn main() -> Result<(), Box> { output.flush()?; } + // Test exiting without `Hello` + if opts.exit_before_hello { + std::process::exit(1) + } + // Send `Hello` message write( &mut output, diff --git a/tests/plugins/stress_internals.rs b/tests/plugins/stress_internals.rs index 1207c15252..0b8f94fde2 100644 --- a/tests/plugins/stress_internals.rs +++ b/tests/plugins/stress_internals.rs @@ -1,3 +1,5 @@ +use std::{sync::mpsc, time::Duration}; + use nu_test_support::nu_with_plugins; fn ensure_stress_env_vars_unset() { @@ -75,6 +77,30 @@ fn test_failing_local_socket_fallback() { assert!(result.out.contains("local_socket_path: None")); } +#[test] +fn test_exit_before_hello_stdio() { + ensure_stress_env_vars_unset(); + // This can deadlock if not handled properly, so we try several times and timeout + for _ in 0..5 { + let (tx, rx) = mpsc::channel(); + std::thread::spawn(move || { + let result = nu_with_plugins!( + cwd: ".", + envs: vec![ + ("STRESS_EXIT_BEFORE_HELLO", "1"), + ], + plugin: ("nu_plugin_stress_internals"), + "stress_internals" + ); + let _ = tx.send(result); + }); + let result = rx + .recv_timeout(Duration::from_secs(15)) + .expect("timed out. probably a deadlock"); + assert!(!result.status.success()); + } +} + #[test] fn test_exit_early_stdio() { ensure_stress_env_vars_unset();