diff --git a/crates/nu-cmd-lang/src/core_commands/if_.rs b/crates/nu-cmd-lang/src/core_commands/if_.rs index 43c97bbe7b..407afafbab 100644 --- a/crates/nu-cmd-lang/src/core_commands/if_.rs +++ b/crates/nu-cmd-lang/src/core_commands/if_.rs @@ -1,6 +1,7 @@ use nu_engine::{eval_block, eval_expression, eval_expression_with_input, CallExt}; use nu_protocol::ast::Call; -use nu_protocol::engine::{Block, Command, EngineState, Stack}; +use nu_protocol::engine::{Block, Command, EngineState, Stack, StateWorkingSet}; +use nu_protocol::eval_const::{eval_const_subexpression, eval_constant, eval_constant_with_input}; use nu_protocol::{ Category, Example, PipelineData, ShellError, Signature, SyntaxShape, Type, Value, }; @@ -40,6 +41,60 @@ impl Command for If { .category(Category::Core) } + fn is_const(&self) -> bool { + true + } + + fn run_const( + &self, + working_set: &StateWorkingSet, + call: &Call, + input: PipelineData, + ) -> Result { + let cond = call.positional_nth(0).expect("checked through parser"); + let then_block: Block = call.req_const(working_set, 1)?; + let else_case = call.positional_nth(2); + + let result = eval_constant(working_set, cond)?; + match &result { + Value::Bool { val, .. } => { + if *val { + let block = working_set.get_block(then_block.block_id); + eval_const_subexpression( + working_set, + block, + input, + block.span.unwrap_or(call.head), + ) + } else if let Some(else_case) = else_case { + if let Some(else_expr) = else_case.as_keyword() { + if let Some(block_id) = else_expr.as_block() { + let block = working_set.get_block(block_id); + eval_const_subexpression( + working_set, + block, + input, + block.span.unwrap_or(call.head), + ) + } else { + eval_constant_with_input(working_set, else_expr, input) + } + } else { + eval_constant_with_input(working_set, else_case, input) + } + } else { + Ok(PipelineData::empty()) + } + } + x => Err(ShellError::CantConvert { + to_type: "bool".into(), + from_type: x.get_type().to_string(), + span: result.span(), + help: None, + }), + } + } + fn run( &self, engine_state: &EngineState, diff --git a/crates/nu-protocol/src/eval_const.rs b/crates/nu-protocol/src/eval_const.rs index 10a70e003a..4ca2129e4e 100644 --- a/crates/nu-protocol/src/eval_const.rs +++ b/crates/nu-protocol/src/eval_const.rs @@ -207,16 +207,16 @@ fn eval_const_call( decl.run_const(working_set, call, input) } -fn eval_const_subexpression( +pub fn eval_const_subexpression( working_set: &StateWorkingSet, - expr: &Expression, block: &Block, mut input: PipelineData, + span: Span, ) -> Result { for pipeline in block.pipelines.iter() { for element in pipeline.elements.iter() { let PipelineElement::Expression(_, expr) = element else { - return Err(ShellError::NotAConstant(expr.span)); + return Err(ShellError::NotAConstant(span)); }; input = eval_constant_with_input(working_set, expr, input)? @@ -226,7 +226,7 @@ fn eval_const_subexpression( Ok(input) } -fn eval_constant_with_input( +pub fn eval_constant_with_input( working_set: &StateWorkingSet, expr: &Expression, input: PipelineData, @@ -235,7 +235,7 @@ fn eval_constant_with_input( Expr::Call(call) => eval_const_call(working_set, call, input), Expr::Subexpression(block_id) => { let block = working_set.get_block(*block_id); - eval_const_subexpression(working_set, expr, block, input) + eval_const_subexpression(working_set, block, input, expr.span) } _ => eval_constant(working_set, expr).map(|v| PipelineData::Value(v, None)), } @@ -341,7 +341,7 @@ pub fn eval_constant( Expr::Subexpression(block_id) => { let block = working_set.get_block(*block_id); Ok( - eval_const_subexpression(working_set, expr, block, PipelineData::empty())? + eval_const_subexpression(working_set, block, PipelineData::empty(), expr.span)? .into_value(expr.span), ) } @@ -471,6 +471,7 @@ pub fn eval_constant( Operator::Assignment(_) => Err(ShellError::NotAConstant(expr.span)), } } + Expr::Block(block_id) => Ok(Value::block(*block_id, expr.span)), _ => Err(ShellError::NotAConstant(expr.span)), } } diff --git a/tests/const_/mod.rs b/tests/const_/mod.rs index 7e033d6412..cc35d5f6f4 100644 --- a/tests/const_/mod.rs +++ b/tests/const_/mod.rs @@ -342,3 +342,16 @@ fn version_const() { let actual = nu!("const x = (version); $x"); assert!(actual.err.is_empty()); } + +#[test] +fn if_const() { + let actual = nu!("const x = (if 2 < 3 { 'yes!' }); $x"); + assert_eq!(actual.out, "yes!"); + + let actual = nu!("const x = (if 5 < 3 { 'yes!' } else { 'no!' }); $x"); + assert_eq!(actual.out, "no!"); + + let actual = + nu!("const x = (if 5 < 3 { 'yes!' } else if 4 < 5 { 'no!' } else { 'okay!' }); $x"); + assert_eq!(actual.out, "no!"); +}