Add helper method to check whether ctrl+c was pressed, adopt it (#7482)

I've been working on streaming and pipeline interruption lately. It was
bothering me that checking ctrl+c (something we want to do often) always
requires a bunch of boilerplate like:
```rust
use std::sync::atomic::Ordering;

if let Some(ctrlc) = &engine_state.ctrlc {
     if ctrlc.load(Ordering::SeqCst) {
          ...
```
I added a helper method to cut that down to:

```rust
if nu_utils::ctrl_c::was_pressed(&engine_state.ctrlc) {
    ...
```
This commit is contained in:
Reilly Wood 2022-12-15 09:39:24 -08:00 committed by GitHub
parent 33aea56ccd
commit e215fbbd08
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 69 additions and 137 deletions

View file

@ -1,5 +1,3 @@
use std::sync::atomic::Ordering;
use nu_engine::{eval_block, CallExt}; use nu_engine::{eval_block, CallExt};
use nu_protocol::ast::Call; use nu_protocol::ast::Call;
use nu_protocol::engine::{Block, Command, EngineState, Stack}; use nu_protocol::engine::{Block, Command, EngineState, Stack};
@ -44,10 +42,8 @@ impl Command for Loop {
let block: Block = call.req(engine_state, stack, 0)?; let block: Block = call.req(engine_state, stack, 0)?;
loop { loop {
if let Some(ctrlc) = &engine_state.ctrlc { if nu_utils::ctrl_c::was_pressed(&engine_state.ctrlc) {
if ctrlc.load(Ordering::SeqCst) { break;
break;
}
} }
let block = engine_state.get_block(block.block_id); let block = engine_state.get_block(block.block_id);

View file

@ -1,5 +1,3 @@
use std::sync::atomic::Ordering;
use nu_engine::{eval_block, eval_expression, CallExt}; use nu_engine::{eval_block, eval_expression, CallExt};
use nu_protocol::ast::Call; use nu_protocol::ast::Call;
use nu_protocol::engine::{Block, Command, EngineState, Stack}; use nu_protocol::engine::{Block, Command, EngineState, Stack};
@ -48,10 +46,8 @@ impl Command for While {
let block: Block = call.req(engine_state, stack, 1)?; let block: Block = call.req(engine_state, stack, 1)?;
loop { loop {
if let Some(ctrlc) = &engine_state.ctrlc { if nu_utils::ctrl_c::was_pressed(&engine_state.ctrlc) {
if ctrlc.load(Ordering::SeqCst) { break;
break;
}
} }
let result = eval_expression(engine_state, stack, cond)?; let result = eval_expression(engine_state, stack, cond)?;

View file

@ -10,10 +10,7 @@ use std::{
fs::File, fs::File,
io::Read, io::Read,
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::{ sync::{atomic::AtomicBool, Arc},
atomic::{AtomicBool, Ordering},
Arc,
},
}; };
const SQLITE_MAGIC_BYTES: &[u8] = "SQLite format 3\0".as_bytes(); const SQLITE_MAGIC_BYTES: &[u8] = "SQLite format 3\0".as_bytes();
@ -399,14 +396,12 @@ fn prepared_statement_to_nu_list(
let mut row_values = vec![]; let mut row_values = vec![];
for row_result in row_results { for row_result in row_results {
if let Some(ctrlc) = &ctrlc { if nu_utils::ctrl_c::was_pressed(&ctrlc) {
if ctrlc.load(Ordering::SeqCst) { // return whatever we have so far, let the caller decide whether to use it
// return whatever we have so far, let the caller decide whether to use it return Ok(Value::List {
return Ok(Value::List { vals: row_values,
vals: row_values, span: call_span,
span: call_span, });
});
}
} }
if let Ok(row_value) = row_result { if let Ok(row_value) = row_result {

View file

@ -1,5 +1,4 @@
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::atomic::Ordering;
use std::sync::mpsc::{channel, RecvTimeoutError}; use std::sync::mpsc::{channel, RecvTimeoutError};
use std::time::Duration; use std::time::Duration;
@ -252,10 +251,8 @@ impl Command for Watch {
} }
Err(RecvTimeoutError::Timeout) => {} Err(RecvTimeoutError::Timeout) => {}
} }
if let Some(ctrlc) = ctrlc_ref { if nu_utils::ctrl_c::was_pressed(ctrlc_ref) {
if ctrlc.load(Ordering::SeqCst) { break;
break;
}
} }
} }

View file

@ -1,5 +1,3 @@
use std::sync::atomic::Ordering;
use nu_engine::{eval_block, CallExt}; use nu_engine::{eval_block, CallExt};
use nu_protocol::ast::Call; use nu_protocol::ast::Call;
@ -217,10 +215,8 @@ impl Command for Reduce {
)? )?
.into_value(span); .into_value(span);
if let Some(ctrlc) = &ctrlc { if nu_utils::ctrl_c::was_pressed(&ctrlc) {
if ctrlc.load(Ordering::SeqCst) { break;
break;
}
} }
} }

View file

@ -1,4 +1,3 @@
use crate::input_handler::ctrl_c_was_pressed;
use nu_protocol::ast::Call; use nu_protocol::ast::Call;
use nu_protocol::engine::{Command, EngineState, Stack}; use nu_protocol::engine::{Command, EngineState, Stack};
use nu_protocol::{ use nu_protocol::{
@ -231,7 +230,7 @@ pub fn uniq(
let mut uniq_values = input let mut uniq_values = input
.into_iter() .into_iter()
.map_while(|item| { .map_while(|item| {
if ctrl_c_was_pressed(&ctrlc) { if nu_utils::ctrl_c::was_pressed(&ctrlc) {
return None; return None;
} }
Some(item_mapper(ItemMapperState { Some(item_mapper(ItemMapperState {

View file

@ -1,6 +1,6 @@
use nu_protocol::ast::CellPath; use nu_protocol::ast::CellPath;
use nu_protocol::{PipelineData, ShellError, Span, Value}; use nu_protocol::{PipelineData, ShellError, Span, Value};
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::AtomicBool;
use std::sync::Arc; use std::sync::Arc;
pub trait CmdArgument { pub trait CmdArgument {
@ -71,12 +71,3 @@ where
} }
} }
} }
// Helper method to avoid boilerplate every time we check ctrl+c
pub fn ctrl_c_was_pressed(ctrlc: &Option<Arc<AtomicBool>>) -> bool {
if let Some(ctrlc) = ctrlc {
ctrlc.load(Ordering::SeqCst)
} else {
false
}
}

View file

@ -2,7 +2,7 @@ use filesize::file_real_size_fast;
use nu_glob::Pattern; use nu_glob::Pattern;
use nu_protocol::{ShellError, Span, Value}; use nu_protocol::{ShellError, Span, Value};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::AtomicBool;
use std::sync::Arc; use std::sync::Arc;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -106,13 +106,8 @@ impl DirInfo {
match std::fs::read_dir(&s.path) { match std::fs::read_dir(&s.path) {
Ok(d) => { Ok(d) => {
for f in d { for f in d {
match ctrl_c { if nu_utils::ctrl_c::was_pressed(&ctrl_c) {
Some(ref cc) => { break;
if cc.load(Ordering::SeqCst) {
break;
}
}
None => continue,
} }
match f { match f {

View file

@ -6,7 +6,6 @@ use nu_protocol::{
Type, Value, Type, Value,
}; };
use std::{ use std::{
sync::atomic::Ordering,
thread, thread,
time::{Duration, Instant}, time::{Duration, Instant},
}; };
@ -62,10 +61,8 @@ impl Command for Sleep {
break; break;
} }
if let Some(ctrlc) = ctrlc_ref { if nu_utils::ctrl_c::was_pressed(ctrlc_ref) {
if ctrlc.load(Ordering::SeqCst) { break;
break;
}
} }
} }

View file

@ -13,7 +13,7 @@ use std::collections::HashMap;
use std::io::{BufRead, BufReader, Read, Write}; use std::io::{BufRead, BufReader, Read, Write};
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::process::{Command as CommandSys, Stdio}; use std::process::{Command as CommandSys, Stdio};
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::AtomicBool;
use std::sync::mpsc::{self, SyncSender}; use std::sync::mpsc::{self, SyncSender};
use std::sync::Arc; use std::sync::Arc;
@ -745,10 +745,8 @@ fn read_and_redirect_message<R>(
let length = bytes.len(); let length = bytes.len();
buf_read.consume(length); buf_read.consume(length);
if let Some(ctrlc) = &ctrlc { if nu_utils::ctrl_c::was_pressed(&ctrlc) {
if ctrlc.load(Ordering::SeqCst) { break;
break;
}
} }
match sender.send(bytes) { match sender.send(bytes) {

View file

@ -12,12 +12,7 @@ use nu_table::{string_width, Alignment, Table as NuTable, TableConfig, TableThem
use nu_utils::get_ls_colors; use nu_utils::get_ls_colors;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use std::{ use std::{cmp::max, collections::HashMap, path::PathBuf, sync::atomic::AtomicBool};
cmp::max,
collections::HashMap,
path::PathBuf,
sync::atomic::{AtomicBool, Ordering},
};
use terminal_size::{Height, Width}; use terminal_size::{Height, Width};
use url::Url; use url::Url;
@ -304,13 +299,8 @@ fn handle_table_command(
let result = strip_output_color(result, config); let result = strip_output_color(result, config);
let ctrl_c_was_triggered = || match &ctrlc {
Some(ctrlc) => ctrlc.load(Ordering::SeqCst),
None => false,
};
let result = result.unwrap_or_else(|| { let result = result.unwrap_or_else(|| {
if ctrl_c_was_triggered() { if nu_utils::ctrl_c::was_pressed(&ctrlc) {
"".into() "".into()
} else { } else {
// assume this failed because the table was too wide // assume this failed because the table was too wide
@ -397,11 +387,8 @@ fn build_general_table2(
) -> Result<Option<String>, ShellError> { ) -> Result<Option<String>, ShellError> {
let mut data = Vec::with_capacity(vals.len()); let mut data = Vec::with_capacity(vals.len());
for (column, value) in cols.into_iter().zip(vals.into_iter()) { for (column, value) in cols.into_iter().zip(vals.into_iter()) {
// handle CTRLC event if nu_utils::ctrl_c::was_pressed(&ctrlc) {
if let Some(ctrlc) = &ctrlc { return Ok(None);
if ctrlc.load(Ordering::SeqCst) {
return Ok(None);
}
} }
let row = vec![ let row = vec![
@ -460,11 +447,8 @@ fn build_expanded_table(
let mut data = Vec::with_capacity(cols.len()); let mut data = Vec::with_capacity(cols.len());
for (key, value) in cols.into_iter().zip(vals) { for (key, value) in cols.into_iter().zip(vals) {
// handle CTRLC event if nu_utils::ctrl_c::was_pressed(&ctrlc) {
if let Some(ctrlc) = &ctrlc { return Ok(None);
if ctrlc.load(Ordering::SeqCst) {
return Ok(None);
}
} }
let is_limited = matches!(expand_limit, Some(0)); let is_limited = matches!(expand_limit, Some(0));
@ -810,10 +794,8 @@ fn convert_to_table(
}; };
for (row_num, item) in input.enumerate() { for (row_num, item) in input.enumerate() {
if let Some(ctrlc) = &ctrlc { if nu_utils::ctrl_c::was_pressed(&ctrlc) {
if ctrlc.load(Ordering::SeqCst) { return Ok(None);
return Ok(None);
}
} }
if let Value::Error { error } = item { if let Value::Error { error } = item {
@ -925,10 +907,8 @@ fn convert_to_table2<'a>(
} }
for (row, item) in input.clone().into_iter().enumerate() { for (row, item) in input.clone().into_iter().enumerate() {
if let Some(ctrlc) = &ctrlc { if nu_utils::ctrl_c::was_pressed(&ctrlc) {
if ctrlc.load(Ordering::SeqCst) { return Ok(None);
return Ok(None);
}
} }
if let Value::Error { error } = item { if let Value::Error { error } = item {
@ -960,10 +940,8 @@ fn convert_to_table2<'a>(
if !with_header { if !with_header {
for (row, item) in input.into_iter().enumerate() { for (row, item) in input.into_iter().enumerate() {
if let Some(ctrlc) = &ctrlc { if nu_utils::ctrl_c::was_pressed(&ctrlc) {
if ctrlc.load(Ordering::SeqCst) { return Ok(None);
return Ok(None);
}
} }
if let Value::Error { error } = item { if let Value::Error { error } = item {
@ -1019,10 +997,8 @@ fn convert_to_table2<'a>(
data[0].push(NuTable::create_cell(&header, header_style(color_hm))); data[0].push(NuTable::create_cell(&header, header_style(color_hm)));
for (row, item) in input.clone().into_iter().enumerate() { for (row, item) in input.clone().into_iter().enumerate() {
if let Some(ctrlc) = &ctrlc { if nu_utils::ctrl_c::was_pressed(&ctrlc) {
if ctrlc.load(Ordering::SeqCst) { return Ok(None);
return Ok(None);
}
} }
if let Value::Error { error } = item { if let Value::Error { error } = item {
@ -1059,10 +1035,8 @@ fn convert_to_table2<'a>(
column_width = string_width(&header); column_width = string_width(&header);
for (row, item) in input.clone().into_iter().enumerate() { for (row, item) in input.clone().into_iter().enumerate() {
if let Some(ctrlc) = &ctrlc { if nu_utils::ctrl_c::was_pressed(&ctrlc) {
if ctrlc.load(Ordering::SeqCst) { return Ok(None);
return Ok(None);
}
} }
let value = create_table2_entry_basic(item, &header, head, config, color_hm); let value = create_table2_entry_basic(item, &header, head, config, color_hm);
@ -1086,10 +1060,8 @@ fn convert_to_table2<'a>(
column_width = string_width(&header); column_width = string_width(&header);
for (row, item) in input.clone().into_iter().enumerate() { for (row, item) in input.clone().into_iter().enumerate() {
if let Some(ctrlc) = &ctrlc { if nu_utils::ctrl_c::was_pressed(&ctrlc) {
if ctrlc.load(Ordering::SeqCst) { return Ok(None);
return Ok(None);
}
} }
let value = create_table2_entry_basic(item, &header, head, config, color_hm); let value = create_table2_entry_basic(item, &header, head, config, color_hm);
@ -1593,10 +1565,8 @@ impl Iterator for PagingTableCreator {
break; break;
} }
if let Some(ctrlc) = &self.ctrlc { if nu_utils::ctrl_c::was_pressed(&self.ctrlc) {
if ctrlc.load(Ordering::SeqCst) { break;
break;
}
} }
} }

View file

@ -31,10 +31,8 @@ pub fn eval_call(
call: &Call, call: &Call,
input: PipelineData, input: PipelineData,
) -> Result<PipelineData, ShellError> { ) -> Result<PipelineData, ShellError> {
if let Some(ctrlc) = &engine_state.ctrlc { if nu_utils::ctrl_c::was_pressed(&engine_state.ctrlc) {
if ctrlc.load(core::sync::atomic::Ordering::SeqCst) { return Ok(Value::Nothing { span: call.head }.into_pipeline_data());
return Ok(Value::Nothing { span: call.head }.into_pipeline_data());
}
} }
let decl = engine_state.get_decl(call.decl_id); let decl = engine_state.get_decl(call.decl_id);

View file

@ -203,10 +203,8 @@ impl Iterator for RangeIterator {
return None; return None;
} }
if let Some(ctrlc) = &self.ctrlc { if nu_utils::ctrl_c::was_pressed(&self.ctrlc) {
if ctrlc.load(core::sync::atomic::Ordering::SeqCst) { return None;
return None;
}
} }
let ordering = if matches!(self.end, Value::Nothing { .. }) { let ordering = if matches!(self.end, Value::Nothing { .. }) {

View file

@ -1,10 +1,7 @@
use crate::*; use crate::*;
use std::{ use std::{
fmt::Debug, fmt::Debug,
sync::{ sync::{atomic::AtomicBool, Arc},
atomic::{AtomicBool, Ordering},
Arc,
},
}; };
pub struct RawStream { pub struct RawStream {
@ -77,10 +74,8 @@ impl Iterator for RawStream {
type Item = Result<Value, ShellError>; type Item = Result<Value, ShellError>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
if let Some(ctrlc) = &self.ctrlc { if nu_utils::ctrl_c::was_pressed(&self.ctrlc) {
if ctrlc.load(Ordering::SeqCst) { return None;
return None;
}
} }
// If we know we're already binary, just output that // If we know we're already binary, just output that
@ -223,12 +218,8 @@ impl Iterator for ListStream {
type Item = Value; type Item = Value;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
if let Some(ctrlc) = &self.ctrlc { if nu_utils::ctrl_c::was_pressed(&self.ctrlc) {
if ctrlc.load(Ordering::SeqCst) { None
None
} else {
self.stream.next()
}
} else { } else {
self.stream.next() self.stream.next()
} }

View file

@ -0,0 +1,13 @@
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
/// Returns true if Nu has received a SIGINT signal / ctrl+c event
pub fn was_pressed(ctrlc: &Option<Arc<AtomicBool>>) -> bool {
if let Some(ctrlc) = ctrlc {
ctrlc.load(Ordering::SeqCst)
} else {
false
}
}

View file

@ -1,7 +1,9 @@
pub mod ctrl_c;
mod deansi; mod deansi;
pub mod locale; pub mod locale;
pub mod utils; pub mod utils;
pub use ctrl_c::was_pressed;
pub use locale::get_system_locale; pub use locale::get_system_locale;
pub use utils::{ pub use utils::{
enable_vt_processing, get_default_config, get_default_env, get_ls_colors, enable_vt_processing, get_default_config, get_default_env, get_ls_colors,