diff --git a/TODO.md b/TODO.md index f28633e825..3731cefeb3 100644 --- a/TODO.md +++ b/TODO.md @@ -34,7 +34,7 @@ - [x] improved history and config paths - [x] ctrl-c support - [x] operator overflow -- [ ] Support for `$in` +- [x] Support for `$in` - [ ] shells - [ ] plugins - [ ] dataframes diff --git a/crates/nu-command/src/default_context.rs b/crates/nu-command/src/default_context.rs index 193b8bcb0d..900457641d 100644 --- a/crates/nu-command/src/default_context.rs +++ b/crates/nu-command/src/default_context.rs @@ -27,6 +27,7 @@ pub fn create_default_context() -> EngineState { BuildString, CamelCase, Cd, + Collect, Cp, Date, DateFormat, diff --git a/crates/nu-command/src/filters/collect.rs b/crates/nu-command/src/filters/collect.rs new file mode 100644 index 0000000000..07074c0601 --- /dev/null +++ b/crates/nu-command/src/filters/collect.rs @@ -0,0 +1,75 @@ +use nu_engine::eval_block; +use nu_protocol::ast::Call; +use nu_protocol::engine::{Command, EngineState, Stack}; +use nu_protocol::{Example, PipelineData, Signature, SyntaxShape, Value}; + +#[derive(Clone)] +pub struct Collect; + +impl Command for Collect { + fn name(&self) -> &str { + "collect" + } + + fn signature(&self) -> Signature { + Signature::build("collect").required( + "block", + SyntaxShape::Block(Some(vec![SyntaxShape::Any])), + "the block to run once the stream is collected", + ) + } + + fn usage(&self) -> &str { + "Collect the stream and pass it to a block." + } + + fn run( + &self, + engine_state: &EngineState, + stack: &mut Stack, + call: &Call, + input: PipelineData, + ) -> Result { + let block_id = call.positional[0] + .as_block() + .expect("internal error: expected block"); + + let block = engine_state.get_block(block_id).clone(); + let mut stack = stack.collect_captures(&block.captures); + + let input: Value = input.into_value(call.head); + + if let Some(var) = block.signature.get_positional(0) { + if let Some(var_id) = &var.var_id { + stack.add_var(*var_id, input); + } + } + + eval_block( + engine_state, + &mut stack, + &block, + PipelineData::new(call.head), + ) + } + + fn examples(&self) -> Vec { + vec![Example { + description: "Use the second value in the stream", + example: "echo 1 2 3 | collect { |x| echo $x.1 }", + result: Some(Value::test_int(2)), + }] + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_examples() { + use crate::test_examples; + + test_examples(Collect {}) + } +} diff --git a/crates/nu-command/src/filters/mod.rs b/crates/nu-command/src/filters/mod.rs index 4b30539b5f..ea9495a1fc 100644 --- a/crates/nu-command/src/filters/mod.rs +++ b/crates/nu-command/src/filters/mod.rs @@ -1,3 +1,4 @@ +mod collect; mod each; mod first; mod get; @@ -13,6 +14,7 @@ mod where_; mod wrap; mod zip; +pub use collect::Collect; pub use each::Each; pub use first::First; pub use get::Get; diff --git a/crates/nu-parser/src/parser.rs b/crates/nu-parser/src/parser.rs index 51976f553d..cf4b00f3ef 100644 --- a/crates/nu-parser/src/parser.rs +++ b/crates/nu-parser/src/parser.rs @@ -1197,6 +1197,16 @@ pub fn parse_variable_expr( }, None, ); + } else if contents == b"$in" { + return ( + Expression { + expr: Expr::Var(nu_protocol::IN_VARIABLE_ID), + span, + ty: Type::Unknown, + custom_completion: None, + }, + None, + ); } let (id, err) = parse_variable(working_set, span); @@ -3168,7 +3178,7 @@ pub fn parse_block( .iter() .map(|pipeline| { if pipeline.commands.len() > 1 { - let output = pipeline + let mut output = pipeline .commands .iter() .map(|command| { @@ -3182,6 +3192,11 @@ pub fn parse_block( }) .collect::>(); + for expr in output.iter_mut().skip(1) { + if expr.has_in_variable(working_set) { + *expr = wrap_expr_with_collect(working_set, expr); + } + } Statement::Pipeline(Pipeline { expressions: output, }) @@ -3375,6 +3390,62 @@ pub fn find_captures_in_expr( output } +fn wrap_expr_with_collect(working_set: &mut StateWorkingSet, expr: &Expression) -> Expression { + let span = expr.span; + + if let Some(decl_id) = working_set.find_decl(b"collect") { + let mut output = vec![]; + + let var_id = working_set.next_var_id(); + let mut signature = Signature::new(""); + signature.required_positional.push(PositionalArg { + var_id: Some(var_id), + name: "$it".into(), + desc: String::new(), + shape: SyntaxShape::Any, + }); + + let mut expr = expr.clone(); + expr.replace_in_variable(working_set, var_id); + + let mut block = Block { + stmts: vec![Statement::Pipeline(Pipeline { + expressions: vec![expr], + })], + signature: Box::new(signature), + ..Default::default() + }; + + let mut seen = vec![]; + let captures = find_captures_in_block(working_set, &block, &mut seen); + + block.captures = captures; + + let block_id = working_set.add_block(block); + + output.push(Expression { + expr: Expr::Block(block_id), + span, + ty: Type::Unknown, + custom_completion: None, + }); + + Expression { + expr: Expr::Call(Box::new(Call { + head: Span::unknown(), + named: vec![], + positional: output, + decl_id, + })), + span, + ty: Type::String, + custom_completion: None, + } + } else { + Expression::garbage(span) + } +} + // Parses a vector of u8 to create an AST Block. If a file name is given, then // the name is stored in the working set. When parsing a source without a file // name, the source of bytes is stored as "source" diff --git a/crates/nu-protocol/src/ast/expression.rs b/crates/nu-protocol/src/ast/expression.rs index 7278cdb374..1d4eebeb4a 100644 --- a/crates/nu-protocol/src/ast/expression.rs +++ b/crates/nu-protocol/src/ast/expression.rs @@ -1,5 +1,5 @@ -use super::{Expr, Operator}; -use crate::{BlockId, Signature, Span, Type, VarId}; +use super::{Expr, Operator, Statement}; +use crate::{engine::StateWorkingSet, BlockId, Signature, Span, Type, VarId, IN_VARIABLE_ID}; #[derive(Debug, Clone)] pub struct Expression { @@ -88,4 +88,219 @@ impl Expression { _ => None, } } + + pub fn has_in_variable(&self, working_set: &StateWorkingSet) -> bool { + match &self.expr { + Expr::BinaryOp(left, _, right) => { + left.has_in_variable(working_set) || right.has_in_variable(working_set) + } + Expr::Block(block_id) => { + let block = working_set.get_block(*block_id); + + if let Some(Statement::Pipeline(pipeline)) = block.stmts.get(0) { + match pipeline.expressions.get(0) { + Some(expr) => expr.has_in_variable(working_set), + None => false, + } + } else { + false + } + } + Expr::Bool(_) => false, + Expr::Call(call) => { + for positional in &call.positional { + if positional.has_in_variable(working_set) { + return true; + } + } + for named in &call.named { + if let Some(expr) = &named.1 { + if expr.has_in_variable(working_set) { + return true; + } + } + } + false + } + Expr::CellPath(_) => false, + Expr::ExternalCall(_, _, args) => { + for arg in args { + if arg.has_in_variable(working_set) { + return true; + } + } + false + } + Expr::Filepath(_) => false, + Expr::Float(_) => false, + Expr::FullCellPath(full_cell_path) => { + if full_cell_path.head.has_in_variable(working_set) { + return true; + } + false + } + Expr::Garbage => false, + Expr::GlobPattern(_) => false, + Expr::Int(_) => false, + Expr::Keyword(_, _, expr) => expr.has_in_variable(working_set), + Expr::List(list) => { + for l in list { + if l.has_in_variable(working_set) { + return true; + } + } + false + } + Expr::Operator(_) => false, + Expr::Range(left, middle, right, ..) => { + if let Some(left) = &left { + if left.has_in_variable(working_set) { + return true; + } + } + if let Some(middle) = &middle { + if middle.has_in_variable(working_set) { + return true; + } + } + if let Some(right) = &right { + if right.has_in_variable(working_set) { + return true; + } + } + false + } + Expr::RowCondition(_, expr) => expr.has_in_variable(working_set), + Expr::Signature(_) => false, + Expr::String(_) => false, + Expr::Subexpression(block_id) => { + let block = working_set.get_block(*block_id); + + if let Some(Statement::Pipeline(pipeline)) = block.stmts.get(0) { + if let Some(expr) = pipeline.expressions.get(0) { + expr.has_in_variable(working_set) + } else { + false + } + } else { + false + } + } + Expr::Table(headers, cells) => { + for header in headers { + if header.has_in_variable(working_set) { + return true; + } + } + + for row in cells { + for cell in row.iter() { + if cell.has_in_variable(working_set) { + return true; + } + } + } + + false + } + + Expr::ValueWithUnit(expr, _) => expr.has_in_variable(working_set), + Expr::Var(var_id) => *var_id == IN_VARIABLE_ID, + Expr::VarDecl(_) => false, + } + } + + pub fn replace_in_variable(&mut self, working_set: &mut StateWorkingSet, new_var_id: VarId) { + match &mut self.expr { + Expr::BinaryOp(left, _, right) => { + left.replace_in_variable(working_set, new_var_id); + right.replace_in_variable(working_set, new_var_id); + } + Expr::Block(block_id) => { + let block = working_set.get_block_mut(*block_id); + + if let Some(Statement::Pipeline(pipeline)) = block.stmts.get_mut(0) { + if let Some(expr) = pipeline.expressions.get_mut(0) { + expr.clone().replace_in_variable(working_set, new_var_id) + } + } + } + Expr::Bool(_) => {} + Expr::Call(call) => { + for positional in &mut call.positional { + positional.replace_in_variable(working_set, new_var_id); + } + for named in &mut call.named { + if let Some(expr) = &mut named.1 { + expr.replace_in_variable(working_set, new_var_id) + } + } + } + Expr::CellPath(_) => {} + Expr::ExternalCall(_, _, args) => { + for arg in args { + arg.replace_in_variable(working_set, new_var_id) + } + } + Expr::Filepath(_) => {} + Expr::Float(_) => {} + Expr::FullCellPath(full_cell_path) => { + full_cell_path + .head + .replace_in_variable(working_set, new_var_id); + } + Expr::Garbage => {} + Expr::GlobPattern(_) => {} + Expr::Int(_) => {} + Expr::Keyword(_, _, expr) => expr.replace_in_variable(working_set, new_var_id), + Expr::List(list) => { + for l in list { + l.replace_in_variable(working_set, new_var_id) + } + } + Expr::Operator(_) => {} + Expr::Range(left, middle, right, ..) => { + if let Some(left) = left { + left.replace_in_variable(working_set, new_var_id) + } + if let Some(middle) = middle { + middle.replace_in_variable(working_set, new_var_id) + } + if let Some(right) = right { + right.replace_in_variable(working_set, new_var_id) + } + } + Expr::RowCondition(_, expr) => expr.replace_in_variable(working_set, new_var_id), + Expr::Signature(_) => {} + Expr::String(_) => {} + Expr::Subexpression(block_id) => { + let block = working_set.get_block_mut(*block_id); + + if let Some(Statement::Pipeline(pipeline)) = block.stmts.get_mut(0) { + if let Some(expr) = pipeline.expressions.get_mut(0) { + expr.clone().replace_in_variable(working_set, new_var_id) + } + } + } + Expr::Table(headers, cells) => { + for header in headers { + header.replace_in_variable(working_set, new_var_id) + } + + for row in cells { + for cell in row.iter_mut() { + cell.replace_in_variable(working_set, new_var_id) + } + } + } + + Expr::ValueWithUnit(expr, _) => expr.replace_in_variable(working_set, new_var_id), + Expr::Var(x) => { + if *x == IN_VARIABLE_ID { + *x = new_var_id + } + } + Expr::VarDecl(_) => {} + } + } } diff --git a/crates/nu-protocol/src/engine/engine_state.rs b/crates/nu-protocol/src/engine/engine_state.rs index 7464354ea2..cbe3548448 100644 --- a/crates/nu-protocol/src/engine/engine_state.rs +++ b/crates/nu-protocol/src/engine/engine_state.rs @@ -135,13 +135,14 @@ pub struct EngineState { pub const NU_VARIABLE_ID: usize = 0; pub const SCOPE_VARIABLE_ID: usize = 1; +pub const IN_VARIABLE_ID: usize = 2; impl EngineState { pub fn new() -> Self { Self { files: im::vector![], file_contents: im::vector![], - vars: im::vector![Type::Unknown, Type::Unknown], + vars: im::vector![Type::Unknown, Type::Unknown, Type::Unknown], decls: im::vector![], blocks: im::vector![], scope: im::vector![ScopeFrame::new()], @@ -857,6 +858,18 @@ impl<'a> StateWorkingSet<'a> { } } + pub fn get_block_mut(&mut self, block_id: BlockId) -> &mut Block { + let num_permanent_blocks = self.permanent_state.num_blocks(); + if block_id < num_permanent_blocks { + panic!("Attempt to mutate a block that is in the permanent (immutable) state") + } else { + self.delta + .blocks + .get_mut(block_id - num_permanent_blocks) + .expect("internal error: missing block") + } + } + pub fn render(self) -> StateDelta { self.delta } diff --git a/crates/nu-protocol/src/lib.rs b/crates/nu-protocol/src/lib.rs index 6a7bf45c11..bb1e775ecc 100644 --- a/crates/nu-protocol/src/lib.rs +++ b/crates/nu-protocol/src/lib.rs @@ -11,7 +11,7 @@ mod ty; mod value; pub use value::Value; -pub use engine::{NU_VARIABLE_ID, SCOPE_VARIABLE_ID}; +pub use engine::{IN_VARIABLE_ID, NU_VARIABLE_ID, SCOPE_VARIABLE_ID}; pub use example::*; pub use id::*; pub use pipeline_data::*; diff --git a/src/tests.rs b/src/tests.rs index 1c58272a46..99350b9c0e 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -859,3 +859,13 @@ fn where_on_ranges() -> TestResult { fn index_on_list() -> TestResult { run_test(r#"[1, 2, 3].1"#, "2") } + +#[test] +fn in_variable_1() -> TestResult { + run_test(r#"[3] | if $in.0 > 4 { "yay!" } else { "boo" }"#, "boo") +} + +#[test] +fn in_variable_2() -> TestResult { + run_test(r#"3 | if $in > 2 { "yay!" } else { "boo" }"#, "yay!") +}