diff --git a/crates/nu-cmd-lang/src/core_commands/match_.rs b/crates/nu-cmd-lang/src/core_commands/match_.rs index b270ec7983..8f5032d15c 100644 --- a/crates/nu-cmd-lang/src/core_commands/match_.rs +++ b/crates/nu-cmd-lang/src/core_commands/match_.rs @@ -1,4 +1,4 @@ -use nu_engine::{eval_block, eval_expression_with_input, CallExt}; +use nu_engine::{eval_block, eval_expression, eval_expression_with_input, CallExt}; use nu_protocol::ast::{Call, Expr, Expression}; use nu_protocol::engine::{Command, EngineState, Matcher, Stack}; use nu_protocol::{ @@ -52,26 +52,38 @@ impl Command for Match { stack.add_var(match_variable.0, match_variable.1); } - if let Some(block_id) = match_.1.as_block() { - let block = engine_state.get_block(block_id); - return eval_block( - engine_state, - stack, - block, - input, - call.redirect_stdout, - call.redirect_stderr, - ); + let guard_matches = if let Some(guard) = &match_.0.guard { + let Value::Bool { val, .. } = eval_expression(engine_state, stack, guard)? else { + return Err(ShellError::MatchGuardNotBool { span: guard.span}); + }; + + val } else { - return eval_expression_with_input( - engine_state, - stack, - &match_.1, - input, - call.redirect_stdout, - call.redirect_stderr, - ) - .map(|x| x.0); + true + }; + + if guard_matches { + return if let Some(block_id) = match_.1.as_block() { + let block = engine_state.get_block(block_id); + eval_block( + engine_state, + stack, + block, + input, + call.redirect_stdout, + call.redirect_stderr, + ) + } else { + eval_expression_with_input( + engine_state, + stack, + &match_.1, + input, + call.redirect_stdout, + call.redirect_stderr, + ) + .map(|x| x.0) + }; } } } @@ -107,6 +119,16 @@ impl Command for Match { example: "{a: {b: 3}} | match $in {{a: { $b }} => ($b + 10) }", result: Some(Value::test_int(13)), }, + Example { + description: "Match with a guard", + example: " + match [1 2 3] { + [$x, ..$y] if $x == 1 => { 'good list' }, + _ => { 'not a very good list' } + } + ", + result: Some(Value::test_string("good list")), + }, ] } } diff --git a/crates/nu-command/tests/commands/match_.rs b/crates/nu-command/tests/commands/match_.rs index 0540d3d622..7cb6c3f1a8 100644 --- a/crates/nu-command/tests/commands/match_.rs +++ b/crates/nu-command/tests/commands/match_.rs @@ -197,3 +197,54 @@ fn match_doesnt_overwrite_variable() { // As we do not auto-print loops anymore assert_eq!(actual.out, "100"); } + +#[test] +fn match_with_guard() { + let actual = nu!( + cwd: ".", + "match [1 2 3] { [$x, ..] if $x mod 2 == 0 => { $x }, $x => { 2 } }" + ); + + assert_eq!(actual.out, "2"); +} + +#[test] +fn match_with_guard_block_as_guard() { + // this should work? + let actual = nu!( + cwd: ".", + "match 4 { $x if { $x + 20 > 25 } => { 'good num' }, _ => { 'terrible num' } }" + ); + + assert!(actual.err.contains("Match guard not bool")); +} + +#[test] +fn match_with_guard_parens_expr_as_guard() { + let actual = nu!( + cwd: ".", + "match 4 { $x if ($x + 20 > 25) => { 'good num' }, _ => { 'terrible num' } }" + ); + + assert_eq!(actual.out, "terrible num"); +} + +#[test] +fn match_with_guard_not_bool() { + let actual = nu!( + cwd: ".", + "match 4 { $x if $x + 1 => { 'err!()' }, _ => { 'unreachable!()' } }" + ); + + assert!(actual.err.contains("Match guard not bool")); +} + +#[test] +fn match_with_guard_no_expr_after_if() { + let actual = nu!( + cwd: ".", + "match 4 { $x if => { 'err!()' }, _ => { 'unreachable!()' } }" + ); + + assert!(actual.err.contains("Match guard without an expression")); +} diff --git a/crates/nu-parser/src/parse_patterns.rs b/crates/nu-parser/src/parse_patterns.rs index d9e801872c..661bd43d8b 100644 --- a/crates/nu-parser/src/parse_patterns.rs +++ b/crates/nu-parser/src/parse_patterns.rs @@ -13,6 +13,7 @@ use crate::{ pub fn garbage(span: Span) -> MatchPattern { MatchPattern { pattern: Pattern::Garbage, + guard: None, span, } } @@ -45,6 +46,7 @@ pub fn parse_pattern(working_set: &mut StateWorkingSet, span: Span) -> MatchPatt } else if bytes == b"_" { MatchPattern { pattern: Pattern::IgnoreValue, + guard: None, span, } } else { @@ -53,6 +55,7 @@ pub fn parse_pattern(working_set: &mut StateWorkingSet, span: Span) -> MatchPatt MatchPattern { pattern: Pattern::Value(value), + guard: None, span, } } @@ -78,6 +81,7 @@ pub fn parse_variable_pattern(working_set: &mut StateWorkingSet, span: Span) -> if let Some(var_id) = parse_variable_pattern_helper(working_set, span) { MatchPattern { pattern: Pattern::Variable(var_id), + guard: None, span, } } else { @@ -126,6 +130,7 @@ pub fn parse_list_pattern(working_set: &mut StateWorkingSet, span: Span) -> Matc if contents == b".." { args.push(MatchPattern { pattern: Pattern::IgnoreRest, + guard: None, span: command.parts[spans_idx], }); break; @@ -139,6 +144,7 @@ pub fn parse_list_pattern(working_set: &mut StateWorkingSet, span: Span) -> Matc ) { args.push(MatchPattern { pattern: Pattern::Rest(var_id), + guard: None, span: command.parts[spans_idx], }); break; @@ -163,6 +169,7 @@ pub fn parse_list_pattern(working_set: &mut StateWorkingSet, span: Span) -> Matc MatchPattern { pattern: Pattern::List(args), + guard: None, span, } } @@ -232,6 +239,7 @@ pub fn parse_record_pattern(working_set: &mut StateWorkingSet, span: Span) -> Ma MatchPattern { pattern: Pattern::Record(output), + guard: None, span, } } diff --git a/crates/nu-parser/src/parser.rs b/crates/nu-parser/src/parser.rs index ef738b5d00..2cc48d3122 100644 --- a/crates/nu-parser/src/parser.rs +++ b/crates/nu-parser/src/parser.rs @@ -4270,8 +4270,9 @@ pub fn parse_match_block_expression(working_set: &mut StateWorkingSet, span: Spa break; } - // Multiple patterns connected by '|' let mut connector = working_set.get_span_contents(output[position].span); + + // Multiple patterns connected by '|' if connector == b"|" && position < output.len() { let mut or_pattern = vec![pattern]; @@ -4322,10 +4323,56 @@ pub fn parse_match_block_expression(working_set: &mut StateWorkingSet, span: Spa pattern = MatchPattern { pattern: Pattern::Or(or_pattern), + guard: None, span: Span::new(start, end), } - } + // A match guard + } else if connector == b"if" { + let if_end = { + let end = output[position].span.end; + Span::new(end, end) + }; + position += 1; + + let mk_err = || ParseError::LabeledErrorWithHelp { + error: "Match guard without an expression".into(), + label: "expected an expression".into(), + help: "The `if` keyword must be followed with an expression".into(), + span: if_end, + }; + + if output.get(position).is_none() { + working_set.error(mk_err()); + return garbage(span); + }; + + let (tokens, found) = if let Some((pos, _)) = output[position..] + .iter() + .find_position(|t| working_set.get_span_contents(t.span) == b"=>") + { + if position + pos == position { + working_set.error(mk_err()); + return garbage(span); + } + + (&output[position..position + pos], true) + } else { + (&output[position..], false) + }; + + let mut start = 0; + let guard = parse_multispan_value( + working_set, + &tokens.iter().map(|tok| tok.span).collect_vec(), + &mut start, + &SyntaxShape::MathExpression, + ); + + pattern.guard = Some(guard); + position += if found { start + 1 } else { start }; + connector = working_set.get_span_contents(output[position].span); + } // Then the `=>` arrow if connector != b"=>" { working_set.error(ParseError::Mismatch( diff --git a/crates/nu-protocol/src/ast/match_pattern.rs b/crates/nu-protocol/src/ast/match_pattern.rs index c04bf9e9e3..b8f87c3f63 100644 --- a/crates/nu-protocol/src/ast/match_pattern.rs +++ b/crates/nu-protocol/src/ast/match_pattern.rs @@ -1,12 +1,11 @@ -use serde::{Deserialize, Serialize}; - -use crate::{Span, VarId}; - use super::Expression; +use crate::{Span, VarId}; +use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct MatchPattern { pub pattern: Pattern, + pub guard: Option, pub span: Span, } diff --git a/crates/nu-protocol/src/shell_error.rs b/crates/nu-protocol/src/shell_error.rs index 210fe80788..60278041ce 100644 --- a/crates/nu-protocol/src/shell_error.rs +++ b/crates/nu-protocol/src/shell_error.rs @@ -1073,6 +1073,18 @@ pub enum ShellError { #[label("This operation was interrupted")] span: Option, }, + + /// An attempt to use, as a match guard, an expression that + /// does not resolve into a boolean + #[error("Match guard not bool")] + #[diagnostic( + code(nu::shell::match_guard_not_bool), + help("Match guards should evaluate to a boolean") + )] + MatchGuardNotBool { + #[label("not a boolean expression")] + span: Span, + }, } impl From for ShellError {