diff --git a/crates/nu-parser/src/parser.rs b/crates/nu-parser/src/parser.rs index cf4b00f3ef..628178a5d9 100644 --- a/crates/nu-parser/src/parser.rs +++ b/crates/nu-parser/src/parser.rs @@ -3197,6 +3197,7 @@ pub fn parse_block( *expr = wrap_expr_with_collect(working_set, expr); } } + Statement::Pipeline(Pipeline { expressions: output, }) diff --git a/crates/nu-protocol/src/ast/expression.rs b/crates/nu-protocol/src/ast/expression.rs index 1d4eebeb4a..6b2d0084e3 100644 --- a/crates/nu-protocol/src/ast/expression.rs +++ b/crates/nu-protocol/src/ast/expression.rs @@ -217,13 +217,35 @@ impl Expression { right.replace_in_variable(working_set, new_var_id); } Expr::Block(block_id) => { + let block = working_set.get_block(*block_id); + + let new_expr = if let Some(Statement::Pipeline(pipeline)) = block.stmts.get(0) { + if let Some(expr) = pipeline.expressions.get(0) { + let mut new_expr = expr.clone(); + new_expr.replace_in_variable(working_set, new_var_id); + Some(new_expr) + } else { + None + } + } else { + None + }; + let block = working_set.get_block_mut(*block_id); - if let Some(Statement::Pipeline(pipeline)) = block.stmts.get_mut(0) { - if let Some(expr) = pipeline.expressions.get_mut(0) { - expr.clone().replace_in_variable(working_set, new_var_id) + if let Some(new_expr) = new_expr { + if let Some(Statement::Pipeline(pipeline)) = block.stmts.get_mut(0) { + if let Some(expr) = pipeline.expressions.get_mut(0) { + *expr = new_expr + } } } + + block.captures = block + .captures + .iter() + .map(|x| if *x != IN_VARIABLE_ID { *x } else { new_var_id }) + .collect(); } Expr::Bool(_) => {} Expr::Call(call) => { @@ -274,13 +296,35 @@ impl Expression { Expr::Signature(_) => {} Expr::String(_) => {} Expr::Subexpression(block_id) => { + let block = working_set.get_block(*block_id); + + let new_expr = if let Some(Statement::Pipeline(pipeline)) = block.stmts.get(0) { + if let Some(expr) = pipeline.expressions.get(0) { + let mut new_expr = expr.clone(); + new_expr.replace_in_variable(working_set, new_var_id); + Some(new_expr) + } else { + None + } + } else { + None + }; + let block = working_set.get_block_mut(*block_id); - if let Some(Statement::Pipeline(pipeline)) = block.stmts.get_mut(0) { - if let Some(expr) = pipeline.expressions.get_mut(0) { - expr.clone().replace_in_variable(working_set, new_var_id) + if let Some(new_expr) = new_expr { + if let Some(Statement::Pipeline(pipeline)) = block.stmts.get_mut(0) { + if let Some(expr) = pipeline.expressions.get_mut(0) { + *expr = new_expr + } } } + + block.captures = block + .captures + .iter() + .map(|x| if *x != IN_VARIABLE_ID { *x } else { new_var_id }) + .collect(); } Expr::Table(headers, cells) => { for header in headers { diff --git a/src/tests.rs b/src/tests.rs index 99350b9c0e..2895a78347 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -869,3 +869,23 @@ fn in_variable_1() -> TestResult { fn in_variable_2() -> TestResult { run_test(r#"3 | if $in > 2 { "yay!" } else { "boo" }"#, "yay!") } + +#[test] +fn in_variable_3() -> TestResult { + run_test(r#"3 | if $in > 4 { "yay!" } else { $in }"#, "3") +} + +#[test] +fn in_variable_4() -> TestResult { + run_test(r#"3 | do { $in }"#, "3") +} + +#[test] +fn in_variable_5() -> TestResult { + run_test(r#"3 | if $in > 2 { $in - 10 } else { $in * 10 }"#, "-7") +} + +#[test] +fn in_variable_6() -> TestResult { + run_test(r#"3 | if $in > 6 { $in - 10 } else { $in * 10 }"#, "30") +}