From bb6781a3b10d4bae65bdf8758a6ce1a4f284abbb Mon Sep 17 00:00:00 2001 From: JT Date: Fri, 10 Sep 2021 09:47:20 +1200 Subject: [PATCH] Add row conditions --- crates/nu-command/src/default_context.rs | 6 +- crates/nu-command/src/if_.rs | 2 +- crates/nu-command/src/lib.rs | 1 + crates/nu-command/src/where_.rs | 92 ++++++++++++++++++++++++ crates/nu-engine/src/eval.rs | 1 + crates/nu-parser/src/flatten.rs | 1 + crates/nu-parser/src/parser.rs | 89 ++++++++++++++++++----- crates/nu-protocol/src/ast/expr.rs | 1 + crates/nu-protocol/src/value/mod.rs | 4 ++ src/main.rs | 4 +- src/tests.rs | 16 +++++ 11 files changed, 195 insertions(+), 22 deletions(-) create mode 100644 crates/nu-command/src/where_.rs diff --git a/crates/nu-command/src/default_context.rs b/crates/nu-command/src/default_context.rs index fbaaaf537b..c3f41be6f4 100644 --- a/crates/nu-command/src/default_context.rs +++ b/crates/nu-command/src/default_context.rs @@ -5,7 +5,9 @@ use nu_protocol::{ Signature, SyntaxShape, }; -use crate::{Alias, Benchmark, BuildString, Def, Do, Each, For, If, Length, Let, LetEnv}; +use crate::{ + where_::Where, Alias, Benchmark, BuildString, Def, Do, Each, For, If, Length, Let, LetEnv, +}; pub fn create_default_context() -> Rc> { let engine_state = Rc::new(RefCell::new(EngineState::new())); @@ -33,6 +35,8 @@ pub fn create_default_context() -> Rc> { working_set.add_decl(Box::new(Each)); + working_set.add_decl(Box::new(Where)); + working_set.add_decl(Box::new(Do)); working_set.add_decl(Box::new(Benchmark)); diff --git a/crates/nu-command/src/if_.rs b/crates/nu-command/src/if_.rs index bb785c4c2c..2b8df48671 100644 --- a/crates/nu-command/src/if_.rs +++ b/crates/nu-command/src/if_.rs @@ -11,7 +11,7 @@ impl Command for If { } fn usage(&self) -> &str { - "Create a variable and give it a value." + "Conditionally run a block." } fn signature(&self) -> nu_protocol::Signature { diff --git a/crates/nu-command/src/lib.rs b/crates/nu-command/src/lib.rs index 3e7d99e7d7..e9c29e17bd 100644 --- a/crates/nu-command/src/lib.rs +++ b/crates/nu-command/src/lib.rs @@ -10,6 +10,7 @@ mod if_; mod length; mod let_; mod let_env; +mod where_; pub use alias::Alias; pub use benchmark::Benchmark; diff --git a/crates/nu-command/src/where_.rs b/crates/nu-command/src/where_.rs new file mode 100644 index 0000000000..b876277bb0 --- /dev/null +++ b/crates/nu-command/src/where_.rs @@ -0,0 +1,92 @@ +use nu_engine::eval_expression; +use nu_protocol::ast::{Call, Expr, Expression}; +use nu_protocol::engine::{Command, EvaluationContext}; +use nu_protocol::{IntoValueStream, ShellError, Signature, SyntaxShape, Value}; + +pub struct Where; + +impl Command for Where { + fn name(&self) -> &str { + "where" + } + + fn usage(&self) -> &str { + "Filter values based on a condition." + } + + fn signature(&self) -> nu_protocol::Signature { + Signature::build("where").required("cond", SyntaxShape::RowCondition, "condition") + } + + fn run( + &self, + context: &EvaluationContext, + call: &Call, + input: Value, + ) -> Result { + let cond = call.positional[0].clone(); + + let context = context.enter_scope(); + + let (var_id, cond) = match cond { + Expression { + expr: Expr::RowCondition(var_id, expr), + .. + } => (var_id, expr), + _ => return Err(ShellError::InternalError("Expected row condition".into())), + }; + + match input { + Value::Stream { stream, span } => { + let output_stream = stream + .filter(move |value| { + context.add_var(var_id, value.clone()); + + let result = eval_expression(&context, &cond); + + match result { + Ok(result) => result.is_true(), + _ => false, + } + }) + .into_value_stream(); + + Ok(Value::Stream { + stream: output_stream, + span, + }) + } + Value::List { vals, span } => { + let output_stream = vals + .into_iter() + .filter(move |value| { + context.add_var(var_id, value.clone()); + + let result = eval_expression(&context, &cond); + + match result { + Ok(result) => result.is_true(), + _ => false, + } + }) + .into_value_stream(); + + Ok(Value::Stream { + stream: output_stream, + span, + }) + } + x => { + context.add_var(var_id, x.clone()); + + let result = eval_expression(&context, &cond)?; + + if result.is_true() { + Ok(x) + } else { + Ok(Value::Nothing { span: call.head }) + } + } + } + } +} diff --git a/crates/nu-engine/src/eval.rs b/crates/nu-engine/src/eval.rs index d04c393484..b3da15c87a 100644 --- a/crates/nu-engine/src/eval.rs +++ b/crates/nu-engine/src/eval.rs @@ -135,6 +135,7 @@ pub fn eval_expression( value.follow_cell_path(&column_path.tail) } + Expr::RowCondition(_, expr) => eval_expression(context, expr), Expr::Call(call) => eval_call(context, call, Value::nothing()), Expr::ExternalCall(_, _) => Err(ShellError::ExternalNotSupported(expr.span)), Expr::Operator(_) => Ok(Value::Nothing { span: expr.span }), diff --git a/crates/nu-parser/src/flatten.rs b/crates/nu-parser/src/flatten.rs index 7439da6181..1f92794546 100644 --- a/crates/nu-parser/src/flatten.rs +++ b/crates/nu-parser/src/flatten.rs @@ -114,6 +114,7 @@ pub fn flatten_expression( Expr::String(_) => { vec![(expr.span, FlatShape::String)] } + Expr::RowCondition(_, expr) => flatten_expression(working_set, expr), Expr::Subexpression(block_id) => { flatten_block(working_set, working_set.get_block(*block_id)) } diff --git a/crates/nu-parser/src/parser.rs b/crates/nu-parser/src/parser.rs index 57506b7fd2..0bea2fe27f 100644 --- a/crates/nu-parser/src/parser.rs +++ b/crates/nu-parser/src/parser.rs @@ -851,7 +851,7 @@ pub(crate) fn parse_dollar_expr( } else if let (expr, None) = parse_range(working_set, span) { (expr, None) } else { - parse_full_column_path(working_set, span) + parse_full_column_path(working_set, None, span) } } @@ -922,7 +922,7 @@ pub fn parse_string_interpolation( end: b + 1, }; - let (expr, err) = parse_full_column_path(working_set, span); + let (expr, err) = parse_full_column_path(working_set, None, span); error = error.or(err); output.push(expr); } @@ -957,7 +957,7 @@ pub fn parse_string_interpolation( end, }; - let (expr, err) = parse_full_column_path(working_set, span); + let (expr, err) = parse_full_column_path(working_set, None, span); error = error.or(err); output.push(expr); } @@ -1047,6 +1047,7 @@ pub fn parse_variable_expr( pub fn parse_full_column_path( working_set: &mut StateWorkingSet, + implicit_head: Option, span: Span, ) -> (Expression, Option) { // FIXME: assume for now a paren expr, but needs more @@ -1057,10 +1058,10 @@ pub fn parse_full_column_path( let (tokens, err) = lex(source, span.start, &[b'\n'], &[b'.']); error = error.or(err); - let mut tokens = tokens.into_iter(); - if let Some(head) = tokens.next() { + let mut tokens = tokens.into_iter().peekable(); + if let Some(head) = tokens.peek() { let bytes = working_set.get_span_contents(head.span); - let head = if bytes.starts_with(b"(") { + let (head, mut expect_dot) = if bytes.starts_with(b"(") { let mut start = head.span.start; let mut end = head.span.end; @@ -1085,27 +1086,42 @@ pub fn parse_full_column_path( let source = working_set.get_span_contents(span); - let (tokens, err) = lex(source, span.start, &[b'\n'], &[]); + let (output, err) = lex(source, span.start, &[b'\n'], &[]); error = error.or(err); - let (output, err) = lite_parse(&tokens); + let (output, err) = lite_parse(&output); error = error.or(err); let (output, err) = parse_block(working_set, &output, true); error = error.or(err); let block_id = working_set.add_block(output); + tokens.next(); - Expression { - expr: Expr::Subexpression(block_id), - span, - ty: Type::Unknown, // FIXME - } + ( + Expression { + expr: Expr::Subexpression(block_id), + span, + ty: Type::Unknown, // FIXME + }, + true, + ) } else if bytes.starts_with(b"$") { let (out, err) = parse_variable_expr(working_set, head.span); error = error.or(err); - out + tokens.next(); + + (out, true) + } else if let Some(var_id) = implicit_head { + ( + Expression { + expr: Expr::Var(var_id), + span: Span::unknown(), + ty: Type::Unknown, + }, + false, + ) } else { return ( garbage(span), @@ -1119,7 +1135,6 @@ pub fn parse_full_column_path( let mut tail = vec![]; - let mut expect_dot = true; for path_element in tokens { let bytes = working_set.get_span_contents(path_element.span); @@ -1293,11 +1308,40 @@ pub fn parse_var_with_opt_type( ) } } + +pub fn expand_to_cell_path( + working_set: &mut StateWorkingSet, + expression: &mut Expression, + var_id: VarId, +) { + if let Expression { + expr: Expr::String(_), + span, + .. + } = expression + { + // Re-parse the string as if it were a cell-path + let (new_expression, _err) = parse_full_column_path(working_set, Some(var_id), *span); + + *expression = new_expression; + } +} + pub fn parse_row_condition( working_set: &mut StateWorkingSet, spans: &[Span], ) -> (Expression, Option) { - parse_math_expression(working_set, spans) + let var_id = working_set.add_variable(b"$it".to_vec(), Type::Unknown); + let (expression, err) = parse_math_expression(working_set, spans, Some(var_id)); + let span = span(spans); + ( + Expression { + ty: Type::Bool, + span, + expr: Expr::RowCondition(var_id, Box::new(expression)), + }, + err, + ) } pub fn parse_signature( @@ -1995,7 +2039,7 @@ pub fn parse_value( if let (expr, None) = parse_range(working_set, span) { return (expr, None); } else { - return parse_full_column_path(working_set, span); + return parse_full_column_path(working_set, None, span); } } else if bytes.starts_with(b"{") { if matches!(shape, SyntaxShape::Block) || matches!(shape, SyntaxShape::Any) { @@ -2142,6 +2186,7 @@ pub fn parse_operator( pub fn parse_math_expression( working_set: &mut StateWorkingSet, spans: &[Span], + lhs_row_var_id: Option, ) -> (Expression, Option) { // As the expr_stack grows, we increase the required precedence to grow larger // If, at any time, the operator we're looking at is the same or lower precedence @@ -2200,6 +2245,10 @@ pub fn parse_math_expression( .pop() .expect("internal error: expression stack empty"); + if let Some(row_var_id) = lhs_row_var_id { + expand_to_cell_path(working_set, &mut lhs, row_var_id); + } + let (result_ty, err) = math_result_type(working_set, &mut lhs, &mut op, &mut rhs); error = error.or(err); @@ -2230,6 +2279,10 @@ pub fn parse_math_expression( .pop() .expect("internal error: expression stack empty"); + if let Some(row_var_id) = lhs_row_var_id { + expand_to_cell_path(working_set, &mut lhs, row_var_id); + } + let (result_ty, err) = math_result_type(working_set, &mut lhs, &mut op, &mut rhs); error = error.or(err); @@ -2256,7 +2309,7 @@ pub fn parse_expression( match bytes[0] { b'0' | b'1' | b'2' | b'3' | b'4' | b'5' | b'6' | b'7' | b'8' | b'9' | b'(' | b'{' - | b'[' | b'$' | b'"' | b'\'' | b'-' => parse_math_expression(working_set, spans), + | b'[' | b'$' | b'"' | b'\'' | b'-' => parse_math_expression(working_set, spans, None), _ => parse_call(working_set, spans, true), } } diff --git a/crates/nu-protocol/src/ast/expr.rs b/crates/nu-protocol/src/ast/expr.rs index ac29cd24ba..5a873631e8 100644 --- a/crates/nu-protocol/src/ast/expr.rs +++ b/crates/nu-protocol/src/ast/expr.rs @@ -15,6 +15,7 @@ pub enum Expr { Call(Box), ExternalCall(Vec, Vec>), Operator(Operator), + RowCondition(VarId, Box), BinaryOp(Box, Box, Box), //lhs, op, rhs Subexpression(BlockId), Block(BlockId), diff --git a/crates/nu-protocol/src/value/mod.rs b/crates/nu-protocol/src/value/mod.rs index d02f1ce2d4..23c1b89a90 100644 --- a/crates/nu-protocol/src/value/mod.rs +++ b/crates/nu-protocol/src/value/mod.rs @@ -277,6 +277,10 @@ impl Value { Ok(current) } + + pub fn is_true(&self) -> bool { + matches!(self, Value::Bool { val: true, .. }) + } } impl PartialEq for Value { diff --git a/src/main.rs b/src/main.rs index 09b78404c2..a8015245c9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::{arch::x86_64::_CMP_EQ_OQ, cell::RefCell, rc::Rc}; +use std::{cell::RefCell, rc::Rc}; use nu_cli::{report_parsing_error, report_shell_error, NuHighlighter}; use nu_command::create_default_context; @@ -164,7 +164,7 @@ impl Completer for EQCompleter { let mut working_set = StateWorkingSet::new(&*engine_state); let offset = working_set.next_span_start(); let pos = offset + pos; - let (output, err) = parse(&mut working_set, Some("completer"), line.as_bytes(), false); + let (output, _err) = parse(&mut working_set, Some("completer"), line.as_bytes(), false); let flattened = flatten_block(&working_set, &output); diff --git a/src/tests.rs b/src/tests.rs index d5cb6d3f04..df619f2b91 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -292,3 +292,19 @@ fn row_iteration() -> TestResult { fn record_iteration() -> TestResult { run_test("([[name, level]; [aa, 100], [bb, 200]] | each { $it | each { |x| if $x.column == \"level\" { $x.value + 100 } else { $x.value } } }).level", "[200, 300]") } + +#[test] +fn row_condition1() -> TestResult { + run_test( + "([[name, size]; [a, 1], [b, 2], [c, 3]] | where size < 3).name", + "[a, b]", + ) +} + +#[test] +fn row_condition2() -> TestResult { + run_test( + "[[name, size]; [a, 1], [b, 2], [c, 3]] | where $it.size > 2 | length", + "1", + ) +}