mirror of
https://github.com/nushell/nushell
synced 2025-01-13 13:49:21 +00:00
Fix (and test) for a deadlock that can happen while waiting for protocol info (#12633)
# Description

The local socket PR introduced a `Waitable` type, which could either hold a value or be waited on until a value is available. Unlike a channel, it would always return that value once set. However, one issue with this design was that there was no way to detect whether a value would ever be written. This PR splits the writer into a separate type, `WaitableMut`, so that when it is dropped, waiting threads can fail instead of blocking forever (because they will never get a value).

# Tests + Formatting

A test has been added to `stress_internals` to make sure this fails in the right way.

- 🟢 `toolkit fmt`
- 🟢 `toolkit clippy`
- 🟢 `toolkit test`
- 🟢 `toolkit test stdlib`
This commit is contained in:
parent
0f645b3bb6
commit
c52884b3c8
8 changed files with 184 additions and 57 deletions
|
@ -10,7 +10,7 @@ use crate::{
|
|||
PluginCall, PluginCallId, PluginCallResponse, PluginCustomValue, PluginInput, PluginOption,
|
||||
PluginOutput, ProtocolInfo,
|
||||
},
|
||||
util::Waitable,
|
||||
util::{Waitable, WaitableMut},
|
||||
};
|
||||
use nu_protocol::{
|
||||
engine::Closure, Config, IntoInterruptiblePipelineData, LabeledError, ListStream, PipelineData,
|
||||
|
@ -85,6 +85,8 @@ impl std::fmt::Debug for EngineInterfaceState {
|
|||
pub struct EngineInterfaceManager {
|
||||
/// Shared state
|
||||
state: Arc<EngineInterfaceState>,
|
||||
/// The writer for protocol info
|
||||
protocol_info_mut: WaitableMut<Arc<ProtocolInfo>>,
|
||||
/// Channel to send received PluginCalls to. This is removed after `Goodbye` is received.
|
||||
plugin_call_sender: Option<mpsc::Sender<ReceivedPluginCall>>,
|
||||
/// Receiver for PluginCalls. This is usually taken after initialization
|
||||
|
@ -103,15 +105,17 @@ impl EngineInterfaceManager {
|
|||
pub(crate) fn new(writer: impl PluginWrite<PluginOutput> + 'static) -> EngineInterfaceManager {
|
||||
let (plug_tx, plug_rx) = mpsc::channel();
|
||||
let (subscription_tx, subscription_rx) = mpsc::channel();
|
||||
let protocol_info_mut = WaitableMut::new();
|
||||
|
||||
EngineInterfaceManager {
|
||||
state: Arc::new(EngineInterfaceState {
|
||||
protocol_info: Waitable::new(),
|
||||
protocol_info: protocol_info_mut.reader(),
|
||||
engine_call_id_sequence: Sequence::default(),
|
||||
stream_id_sequence: Sequence::default(),
|
||||
engine_call_subscription_sender: subscription_tx,
|
||||
writer: Box::new(writer),
|
||||
}),
|
||||
protocol_info_mut,
|
||||
plugin_call_sender: Some(plug_tx),
|
||||
plugin_call_receiver: Some(plug_rx),
|
||||
engine_call_subscriptions: BTreeMap::new(),
|
||||
|
@ -233,7 +237,7 @@ impl InterfaceManager for EngineInterfaceManager {
|
|||
match input {
|
||||
PluginInput::Hello(info) => {
|
||||
let info = Arc::new(info);
|
||||
self.state.protocol_info.set(info.clone())?;
|
||||
self.protocol_info_mut.set(info.clone())?;
|
||||
|
||||
let local_info = ProtocolInfo::default();
|
||||
if local_info.is_compatible_with(&info)? {
|
||||
|
|
|
@ -300,8 +300,7 @@ fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(),
|
|||
|
||||
fn set_default_protocol_info(manager: &mut EngineInterfaceManager) -> Result<(), ShellError> {
|
||||
manager
|
||||
.state
|
||||
.protocol_info
|
||||
.protocol_info_mut
|
||||
.set(Arc::new(ProtocolInfo::default()))
|
||||
}
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ use crate::{
|
|||
PluginOutput, ProtocolInfo, StreamId, StreamMessage,
|
||||
},
|
||||
sequence::Sequence,
|
||||
util::{with_custom_values_in, Waitable},
|
||||
util::{with_custom_values_in, Waitable, WaitableMut},
|
||||
};
|
||||
use nu_protocol::{
|
||||
ast::Operator, CustomValue, IntoInterruptiblePipelineData, IntoSpanned, ListStream,
|
||||
|
@ -138,6 +138,8 @@ impl Drop for PluginCallState {
|
|||
pub struct PluginInterfaceManager {
|
||||
/// Shared state
|
||||
state: Arc<PluginInterfaceState>,
|
||||
/// The writer for protocol info
|
||||
protocol_info_mut: WaitableMut<Arc<ProtocolInfo>>,
|
||||
/// Manages stream messages and state
|
||||
stream_manager: StreamManager,
|
||||
/// State related to plugin calls
|
||||
|
@ -159,18 +161,20 @@ impl PluginInterfaceManager {
|
|||
writer: impl PluginWrite<PluginInput> + 'static,
|
||||
) -> PluginInterfaceManager {
|
||||
let (subscription_tx, subscription_rx) = mpsc::channel();
|
||||
let protocol_info_mut = WaitableMut::new();
|
||||
|
||||
PluginInterfaceManager {
|
||||
state: Arc::new(PluginInterfaceState {
|
||||
source,
|
||||
process: pid.map(PluginProcess::new),
|
||||
protocol_info: Waitable::new(),
|
||||
protocol_info: protocol_info_mut.reader(),
|
||||
plugin_call_id_sequence: Sequence::default(),
|
||||
stream_id_sequence: Sequence::default(),
|
||||
plugin_call_subscription_sender: subscription_tx,
|
||||
error: OnceLock::new(),
|
||||
writer: Box::new(writer),
|
||||
}),
|
||||
protocol_info_mut,
|
||||
stream_manager: StreamManager::new(),
|
||||
plugin_call_states: BTreeMap::new(),
|
||||
plugin_call_subscription_receiver: subscription_rx,
|
||||
|
@ -464,7 +468,7 @@ impl InterfaceManager for PluginInterfaceManager {
|
|||
match input {
|
||||
PluginOutput::Hello(info) => {
|
||||
let info = Arc::new(info);
|
||||
self.state.protocol_info.set(info.clone())?;
|
||||
self.protocol_info_mut.set(info.clone())?;
|
||||
|
||||
let local_info = ProtocolInfo::default();
|
||||
if local_info.is_compatible_with(&info)? {
|
||||
|
@ -631,7 +635,14 @@ impl PluginInterface {
|
|||
|
||||
/// Get the protocol info for the plugin. Will block to receive `Hello` if not received yet.
|
||||
pub fn protocol_info(&self) -> Result<Arc<ProtocolInfo>, ShellError> {
|
||||
self.state.protocol_info.get()
|
||||
self.state.protocol_info.get().and_then(|info| {
|
||||
info.ok_or_else(|| ShellError::PluginFailedToLoad {
|
||||
msg: format!(
|
||||
"Failed to get protocol info (`Hello` message) from the `{}` plugin",
|
||||
self.state.source.identity.name()
|
||||
),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Write the protocol info. This should be done after initialization
|
||||
|
|
|
@ -321,8 +321,7 @@ fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(),
|
|||
|
||||
fn set_default_protocol_info(manager: &mut PluginInterfaceManager) -> Result<(), ShellError> {
|
||||
manager
|
||||
.state
|
||||
.protocol_info
|
||||
.protocol_info_mut
|
||||
.set(Arc::new(ProtocolInfo::default()))
|
||||
}
|
||||
|
||||
|
|
|
@ -3,5 +3,5 @@ mod waitable;
|
|||
mod with_custom_values_in;
|
||||
|
||||
pub(crate) use mutable_cow::*;
|
||||
pub use waitable::Waitable;
|
||||
pub use waitable::*;
|
||||
pub use with_custom_values_in::*;
|
||||
|
|
|
@ -1,18 +1,36 @@
|
|||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Condvar, Mutex, MutexGuard, PoisonError,
|
||||
Arc, Condvar, Mutex, MutexGuard, PoisonError,
|
||||
};
|
||||
|
||||
use nu_protocol::ShellError;
|
||||
|
||||
/// A container that may be empty, and allows threads to block until it has a value.
|
||||
#[derive(Debug)]
|
||||
/// A shared container that may be empty, and allows threads to block until it has a value.
|
||||
///
|
||||
/// This side is read-only - use [`WaitableMut`] on threads that might write a value.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Waitable<T: Clone + Send> {
|
||||
shared: Arc<WaitableShared<T>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct WaitableMut<T: Clone + Send> {
|
||||
shared: Arc<WaitableShared<T>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct WaitableShared<T: Clone + Send> {
|
||||
is_set: AtomicBool,
|
||||
mutex: Mutex<Option<T>>,
|
||||
mutex: Mutex<SyncState<T>>,
|
||||
condvar: Condvar,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SyncState<T: Clone + Send> {
|
||||
writers: usize,
|
||||
value: Option<T>,
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn fail_if_poisoned<'a, T>(
|
||||
result: Result<MutexGuard<'a, T>, PoisonError<MutexGuard<'a, T>>>,
|
||||
|
@ -26,75 +44,138 @@ fn fail_if_poisoned<'a, T>(
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + Send> Waitable<T> {
|
||||
/// Create a new empty `Waitable`.
|
||||
pub fn new() -> Waitable<T> {
|
||||
Waitable {
|
||||
impl<T: Clone + Send> WaitableMut<T> {
|
||||
/// Create a new empty `WaitableMut`. Call [`.reader()`] to get [`Waitable`].
|
||||
pub fn new() -> WaitableMut<T> {
|
||||
WaitableMut {
|
||||
shared: Arc::new(WaitableShared {
|
||||
is_set: AtomicBool::new(false),
|
||||
mutex: Mutex::new(None),
|
||||
mutex: Mutex::new(SyncState {
|
||||
writers: 1,
|
||||
value: None,
|
||||
}),
|
||||
condvar: Condvar::new(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Wait for a value to be available and then clone it.
|
||||
pub fn reader(&self) -> Waitable<T> {
|
||||
Waitable {
|
||||
shared: self.shared.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the value and let waiting threads know.
|
||||
#[track_caller]
|
||||
pub fn get(&self) -> Result<T, ShellError> {
|
||||
let guard = fail_if_poisoned(self.mutex.lock())?;
|
||||
if let Some(value) = (*guard).clone() {
|
||||
Ok(value)
|
||||
pub fn set(&self, value: T) -> Result<(), ShellError> {
|
||||
let mut sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
|
||||
self.shared.is_set.store(true, Ordering::SeqCst);
|
||||
sync_state.value = Some(value);
|
||||
self.shared.condvar.notify_all();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + Send> Default for WaitableMut<T> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + Send> Clone for WaitableMut<T> {
|
||||
fn clone(&self) -> Self {
|
||||
let shared = self.shared.clone();
|
||||
shared
|
||||
.mutex
|
||||
.lock()
|
||||
.expect("failed to lock mutex to increment writers")
|
||||
.writers += 1;
|
||||
WaitableMut { shared }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + Send> Drop for WaitableMut<T> {
|
||||
fn drop(&mut self) {
|
||||
// Decrement writers...
|
||||
if let Ok(mut sync_state) = self.shared.mutex.lock() {
|
||||
sync_state.writers = sync_state
|
||||
.writers
|
||||
.checked_sub(1)
|
||||
.expect("would decrement writers below zero");
|
||||
}
|
||||
// and notify waiting threads so they have a chance to see it.
|
||||
self.shared.condvar.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + Send> Waitable<T> {
|
||||
/// Wait for a value to be available and then clone it.
|
||||
///
|
||||
/// Returns `Ok(None)` if there are no writers left that could possibly place a value.
|
||||
#[track_caller]
|
||||
pub fn get(&self) -> Result<Option<T>, ShellError> {
|
||||
let sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
|
||||
if let Some(value) = sync_state.value.clone() {
|
||||
Ok(Some(value))
|
||||
} else if sync_state.writers == 0 {
|
||||
// There can't possibly be a value written, so no point in waiting.
|
||||
Ok(None)
|
||||
} else {
|
||||
let guard = fail_if_poisoned(self.condvar.wait_while(guard, |g| g.is_none()))?;
|
||||
Ok((*guard)
|
||||
.clone()
|
||||
.expect("checked already for Some but it was None"))
|
||||
let sync_state = fail_if_poisoned(
|
||||
self.shared
|
||||
.condvar
|
||||
.wait_while(sync_state, |g| g.writers > 0 && g.value.is_none()),
|
||||
)?;
|
||||
Ok(sync_state.value.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Clone the value if one is available, but don't wait if not.
|
||||
#[track_caller]
|
||||
pub fn try_get(&self) -> Result<Option<T>, ShellError> {
|
||||
let guard = fail_if_poisoned(self.mutex.lock())?;
|
||||
Ok((*guard).clone())
|
||||
let sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
|
||||
Ok(sync_state.value.clone())
|
||||
}
|
||||
|
||||
/// Returns true if value is available.
|
||||
#[track_caller]
|
||||
pub fn is_set(&self) -> bool {
|
||||
self.is_set.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Set the value and let waiting threads know.
|
||||
#[track_caller]
|
||||
pub fn set(&self, value: T) -> Result<(), ShellError> {
|
||||
let mut guard = fail_if_poisoned(self.mutex.lock())?;
|
||||
self.is_set.store(true, Ordering::SeqCst);
|
||||
*guard = Some(value);
|
||||
self.condvar.notify_all();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + Send> Default for Waitable<T> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
self.shared.is_set.load(Ordering::SeqCst)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_from_other_thread() -> Result<(), ShellError> {
|
||||
use std::sync::Arc;
|
||||
|
||||
let waitable = Arc::new(Waitable::new());
|
||||
let waitable_clone = waitable.clone();
|
||||
let waitable_mut = WaitableMut::new();
|
||||
let waitable = waitable_mut.reader();
|
||||
|
||||
assert!(!waitable.is_set());
|
||||
|
||||
std::thread::spawn(move || {
|
||||
waitable_clone.set(42).expect("error on set");
|
||||
waitable_mut.set(42).expect("error on set");
|
||||
});
|
||||
|
||||
assert_eq!(42, waitable.get()?);
|
||||
assert_eq!(Some(42), waitable.get()?);
|
||||
assert_eq!(Some(42), waitable.try_get()?);
|
||||
assert!(waitable.is_set());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dont_deadlock_if_waiting_without_writer() {
|
||||
use std::time::Duration;
|
||||
|
||||
let (tx, rx) = std::sync::mpsc::channel();
|
||||
let writer = WaitableMut::<()>::new();
|
||||
let waitable = writer.reader();
|
||||
// Ensure there are no writers
|
||||
drop(writer);
|
||||
std::thread::spawn(move || {
|
||||
let _ = tx.send(waitable.get());
|
||||
});
|
||||
let result = rx
|
||||
.recv_timeout(Duration::from_secs(10))
|
||||
.expect("timed out")
|
||||
.expect("error");
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ use serde_json::{json, Value};
|
|||
struct Options {
|
||||
refuse_local_socket: bool,
|
||||
advertise_local_socket: bool,
|
||||
exit_before_hello: bool,
|
||||
exit_early: bool,
|
||||
wrong_version: bool,
|
||||
local_socket_path: Option<String>,
|
||||
|
@ -28,6 +29,7 @@ pub fn main() -> Result<(), Box<dyn Error>> {
|
|||
let mut opts = Options {
|
||||
refuse_local_socket: has_env("STRESS_REFUSE_LOCAL_SOCKET"),
|
||||
advertise_local_socket: has_env("STRESS_ADVERTISE_LOCAL_SOCKET"),
|
||||
exit_before_hello: has_env("STRESS_EXIT_BEFORE_HELLO"),
|
||||
exit_early: has_env("STRESS_EXIT_EARLY"),
|
||||
wrong_version: has_env("STRESS_WRONG_VERSION"),
|
||||
local_socket_path: None,
|
||||
|
@ -75,6 +77,11 @@ pub fn main() -> Result<(), Box<dyn Error>> {
|
|||
output.flush()?;
|
||||
}
|
||||
|
||||
// Test exiting without `Hello`
|
||||
if opts.exit_before_hello {
|
||||
std::process::exit(1)
|
||||
}
|
||||
|
||||
// Send `Hello` message
|
||||
write(
|
||||
&mut output,
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use std::{sync::mpsc, time::Duration};
|
||||
|
||||
use nu_test_support::nu_with_plugins;
|
||||
|
||||
fn ensure_stress_env_vars_unset() {
|
||||
|
@ -75,6 +77,30 @@ fn test_failing_local_socket_fallback() {
|
|||
assert!(result.out.contains("local_socket_path: None"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exit_before_hello_stdio() {
|
||||
ensure_stress_env_vars_unset();
|
||||
// This can deadlock if not handled properly, so we try several times and timeout
|
||||
for _ in 0..5 {
|
||||
let (tx, rx) = mpsc::channel();
|
||||
std::thread::spawn(move || {
|
||||
let result = nu_with_plugins!(
|
||||
cwd: ".",
|
||||
envs: vec![
|
||||
("STRESS_EXIT_BEFORE_HELLO", "1"),
|
||||
],
|
||||
plugin: ("nu_plugin_stress_internals"),
|
||||
"stress_internals"
|
||||
);
|
||||
let _ = tx.send(result);
|
||||
});
|
||||
let result = rx
|
||||
.recv_timeout(Duration::from_secs(15))
|
||||
.expect("timed out. probably a deadlock");
|
||||
assert!(!result.status.success());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exit_early_stdio() {
|
||||
ensure_stress_env_vars_unset();
|
||||
|
|
Loading…
Reference in a new issue