Enforce call stack depth limit for all calls (#11729)

# Description
Previously, only directly-recursive calls were checked for recursion
depth. But most recursive calls in nushell are mutually recursive since
expressions like `for`, `where`, `try` and `do` all execute a separate
block.

```nushell
def f [] {
    do { f }
}
```
Calling `f` would crash nushell with a stack overflow.

I think the only general way to prevent such a stack overflow is to
enforce a maximum call stack depth instead of only disallowing directly
recursive calls.

This commit also moves that logic into `eval_call()` instead of
`eval_block()` because the recursion limit is tracked in the `Stack`,
but not all blocks are evaluated in a new stack. Incrementing the
recursion depth of the caller's stack would permanently increment that
for all future calls.

Fixes #11667

# User-Facing Changes
Any function call can now fail with `recursion_limit_reached` instead of
just directly recursive calls. Mutually-recursive calls no longer crash
nushell.

# After Submitting
<!-- If your PR had any user-facing changes, update [the
documentation](https://github.com/nushell/nushell.github.io) after the
PR is merged, if necessary. This will help us keep the docs up to date.
-->
This commit is contained in:
TrMen 2024-02-07 23:42:24 +01:00 committed by GitHub
parent 366348dea0
commit 4b91ed57dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 40 additions and 63 deletions

View file

@ -81,3 +81,15 @@ fn catch_block_can_use_error_object() {
let output = nu!("try {1 / 0} catch {|err| print ($err | get msg)}"); let output = nu!("try {1 / 0} catch {|err| print ($err | get msg)}");
assert_eq!(output.out, "Division by zero.") assert_eq!(output.out, "Division by zero.")
} }
// Runs a nu script in which `bang` calls itself from inside a `try` block, and
// asserts that the script's `out` equals the string produced by the `catch` block.
//
// This test is disabled on Windows because it causes a stack overflow in CI (but not locally!).
// For reasons we don't understand, the Windows CI runners are prone to stack overflow.
// TODO: investigate so we can enable on Windows
#[cfg(not(target_os = "windows"))]
#[test]
fn can_catch_infinite_recursion() {
let actual = nu!(r#"
def bang [] { try { bang } catch { "Caught infinite recursion" } }; bang
"#);
assert_eq!(actual.out, "Caught infinite recursion");
}

View file

@ -42,6 +42,21 @@ pub fn eval_call(
let mut callee_stack = caller_stack.gather_captures(engine_state, &block.captures); let mut callee_stack = caller_stack.gather_captures(engine_state, &block.captures);
// Rust does not check recursion limits outside of const evaluation.
// But nu programs run in the same process as the shell.
// To prevent a stack overflow in user code from crashing the shell,
// we limit the recursion depth of function calls.
// Picked 50 arbitrarily, should work on all architectures.
const MAXIMUM_CALL_STACK_DEPTH: u64 = 50;
callee_stack.recursion_count += 1;
if callee_stack.recursion_count > MAXIMUM_CALL_STACK_DEPTH {
callee_stack.recursion_count = 0;
return Err(ShellError::RecursionLimitReached {
recursion_limit: MAXIMUM_CALL_STACK_DEPTH,
span: block.span,
});
}
for (param_idx, (param, required)) in decl for (param_idx, (param, required)) in decl
.signature() .signature()
.required_positional .required_positional
@ -635,22 +650,6 @@ pub fn eval_block(
redirect_stdout: bool, redirect_stdout: bool,
redirect_stderr: bool, redirect_stderr: bool,
) -> Result<PipelineData, ShellError> { ) -> Result<PipelineData, ShellError> {
// if Block contains recursion, make sure we don't recurse too deeply (to avoid stack overflow)
if let Some(recursive) = block.recursive {
// picked 50 arbitrarily, should work on all architectures
const RECURSION_LIMIT: u64 = 50;
if recursive {
if stack.recursion_count >= RECURSION_LIMIT {
stack.recursion_count = 0;
return Err(ShellError::RecursionLimitReached {
recursion_limit: RECURSION_LIMIT,
span: block.span,
});
}
stack.recursion_count += 1;
}
}
let num_pipelines = block.len(); let num_pipelines = block.len();
for (pipeline_idx, pipeline) in block.pipelines.iter().enumerate() { for (pipeline_idx, pipeline) in block.pipelines.iter().enumerate() {

View file

@ -598,8 +598,6 @@ pub fn parse_def(
*declaration = signature.clone().into_block_command(block_id); *declaration = signature.clone().into_block_command(block_id);
let block = working_set.get_block_mut(block_id); let block = working_set.get_block_mut(block_id);
let calls_itself = block_calls_itself(block, decl_id);
block.recursive = Some(calls_itself);
block.signature = signature; block.signature = signature;
block.redirect_env = has_env; block.redirect_env = has_env;
@ -758,10 +756,7 @@ pub fn parse_extern(
} else { } else {
*declaration = signature.clone().into_block_command(block_id); *declaration = signature.clone().into_block_command(block_id);
let block = working_set.get_block_mut(block_id); working_set.get_block_mut(block_id).signature = signature;
let calls_itself = block_calls_itself(block, decl_id);
block.recursive = Some(calls_itself);
block.signature = signature;
} }
} else { } else {
let decl = KnownExternal { let decl = KnownExternal {
@ -799,43 +794,6 @@ pub fn parse_extern(
}]) }])
} }
fn block_calls_itself(block: &Block, decl_id: usize) -> bool {
block.pipelines.iter().any(|pipeline| {
pipeline
.elements
.iter()
.any(|pipe_element| match pipe_element {
PipelineElement::Expression(
_,
Expression {
expr: Expr::Call(call_expr),
..
},
) => {
if call_expr.decl_id == decl_id {
return true;
}
call_expr.arguments.iter().any(|arg| match arg {
Argument::Positional(Expression { expr, .. }) => match expr {
Expr::Keyword(.., expr) => {
let expr = expr.as_ref();
let Expression { expr, .. } = expr;
match expr {
Expr::Call(call_expr2) => call_expr2.decl_id == decl_id,
_ => false,
}
}
Expr::Call(call_expr2) => call_expr2.decl_id == decl_id,
_ => false,
},
_ => false,
})
}
_ => false,
})
})
}
pub fn parse_alias( pub fn parse_alias(
working_set: &mut StateWorkingSet, working_set: &mut StateWorkingSet,
lite_command: &LiteCommand, lite_command: &LiteCommand,

View file

@ -10,7 +10,6 @@ pub struct Block {
pub captures: Vec<VarId>, pub captures: Vec<VarId>,
pub redirect_env: bool, pub redirect_env: bool,
pub span: Option<Span>, // None option encodes no span to avoid using test_span() pub span: Option<Span>, // None option encodes no span to avoid using test_span()
pub recursive: Option<bool>, // does the block call itself?
} }
impl Block { impl Block {
@ -51,7 +50,6 @@ impl Block {
captures: vec![], captures: vec![],
redirect_env: false, redirect_env: false,
span: None, span: None,
recursive: None,
} }
} }
@ -62,7 +60,6 @@ impl Block {
captures: vec![], captures: vec![],
redirect_env: false, redirect_env: false,
span: None, span: None,
recursive: None,
} }
} }
@ -97,7 +94,6 @@ where
captures: vec![], captures: vec![],
redirect_env: false, redirect_env: false,
span: None, span: None,
recursive: None,
} }
} }
} }

View file

@ -214,6 +214,18 @@ fn infinite_recursion_does_not_panic() {
assert!(actual.err.contains("Recursion limit (50) reached")); assert!(actual.err.contains("Recursion limit (50) reached"));
} }
// Runs a nu script in which `bang` defines and calls `boom`, which calls `bang`
// again, and asserts that the script's `err` contains the recursion-limit message.
//
// This test is disabled on Windows because it causes a stack overflow in CI (but not locally!).
// For reasons we don't understand, the Windows CI runners are prone to stack overflow.
// TODO: investigate so we can enable on Windows
#[cfg(not(target_os = "windows"))]
#[test]
fn infinite_mutual_recursion_does_not_panic() {
let actual = nu!(r#"
def bang [] { def boom [] { bang }; boom }; bang
"#);
assert!(actual.err.contains("Recursion limit (50) reached"));
}
#[test] #[test]
fn type_check_for_during_eval() -> TestResult { fn type_check_for_during_eval() -> TestResult {
fail_test( fail_test(