Limit recursion to avoid stack overflow (#7657)

Add recursion limit to `def` and `block`.
Summary of this PR , it will detect if `def` call itself or not .
Then execute by using `stack` which I think best choice to use with this
design and core as it is available in all crates and mutable and
calculate the recursion limit on calling `def`.
Set 50 as recursion limit on `Config`.
Add some tests too .

Fixes #5899

Co-authored-by: Reilly Wood <reilly.wood@icloud.com>
This commit is contained in:
Amirhossein Akhlaghpour 2023-01-04 21:38:50 -05:00 committed by GitHub
parent 9bc4e6794d
commit 00469de93e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 85 additions and 0 deletions

View file

@ -876,6 +876,21 @@ pub fn eval_block(
redirect_stdout: bool, redirect_stdout: bool,
redirect_stderr: bool, redirect_stderr: bool,
) -> Result<PipelineData, ShellError> { ) -> Result<PipelineData, ShellError> {
// if Block contains recursion, make sure we don't recurse too deeply (to avoid stack overflow)
if let Some(recursive) = block.recursive {
// picked 50 arbitrarily, should work on all architectures
const RECURSION_LIMIT: u64 = 50;
if recursive {
if *stack.recursion_count >= RECURSION_LIMIT {
stack.recursion_count = Box::new(0);
return Err(ShellError::RecursionLimitReached {
recursion_limit: RECURSION_LIMIT,
span: block.span,
});
}
*stack.recursion_count += 1;
}
}
let num_pipelines = block.len(); let num_pipelines = block.len();
for (pipeline_idx, pipeline) in block.pipelines.iter().enumerate() { for (pipeline_idx, pipeline) in block.pipelines.iter().enumerate() {
let mut i = 0; let mut i = 0;

View file

@ -351,6 +351,41 @@ pub fn parse_def(
*declaration = signature.clone().into_block_command(block_id); *declaration = signature.clone().into_block_command(block_id);
let mut block = working_set.get_block_mut(block_id); let mut block = working_set.get_block_mut(block_id);
let calls_itself = block.pipelines.iter().any(|pipeline| {
pipeline
.elements
.iter()
.any(|pipe_element| match pipe_element {
PipelineElement::Expression(
_,
Expression {
expr: Expr::Call(call_expr),
..
},
) => {
if call_expr.decl_id == decl_id {
return true;
}
call_expr.arguments.iter().any(|arg| match arg {
Argument::Positional(Expression { expr, .. }) => match expr {
Expr::Keyword(.., expr) => {
let expr = expr.as_ref();
let Expression { expr, .. } = expr;
match expr {
Expr::Call(call_expr2) => call_expr2.decl_id == decl_id,
_ => false,
}
}
Expr::Call(call_expr2) => call_expr2.decl_id == decl_id,
_ => false,
},
_ => false,
})
}
_ => false,
})
});
block.recursive = Some(calls_itself);
block.signature = signature; block.signature = signature;
block.redirect_env = def_call == b"def-env"; block.redirect_env = def_call == b"def-env";
} else { } else {

View file

@ -11,6 +11,7 @@ pub struct Block {
pub captures: Vec<VarId>, pub captures: Vec<VarId>,
pub redirect_env: bool, pub redirect_env: bool,
pub span: Option<Span>, // None option encodes no span to avoid using test_span() pub span: Option<Span>, // None option encodes no span to avoid using test_span()
pub recursive: Option<bool>, // does the block call itself?
} }
impl Block { impl Block {
@ -51,6 +52,7 @@ impl Block {
captures: vec![], captures: vec![],
redirect_env: false, redirect_env: false,
span: None, span: None,
recursive: None,
} }
} }
} }
@ -66,6 +68,7 @@ where
captures: vec![], captures: vec![],
redirect_env: false, redirect_env: false,
span: None, span: None,
recursive: None,
} }
} }
} }

View file

@ -34,6 +34,7 @@ pub struct Stack {
pub env_hidden: HashMap<String, HashSet<String>>, pub env_hidden: HashMap<String, HashSet<String>>,
/// List of active overlays /// List of active overlays
pub active_overlays: Vec<String>, pub active_overlays: Vec<String>,
pub recursion_count: Box<u64>,
} }
impl Stack { impl Stack {
@ -43,6 +44,7 @@ impl Stack {
env_vars: vec![], env_vars: vec![],
env_hidden: HashMap::new(), env_hidden: HashMap::new(),
active_overlays: vec![DEFAULT_OVERLAY_NAME.to_string()], active_overlays: vec![DEFAULT_OVERLAY_NAME.to_string()],
recursion_count: Box::new(0),
} }
} }
@ -123,6 +125,7 @@ impl Stack {
env_vars, env_vars,
env_hidden: HashMap::new(), env_hidden: HashMap::new(),
active_overlays: self.active_overlays.clone(), active_overlays: self.active_overlays.clone(),
recursion_count: self.recursion_count.to_owned(),
} }
} }
@ -147,6 +150,7 @@ impl Stack {
env_vars, env_vars,
env_hidden: HashMap::new(), env_hidden: HashMap::new(),
active_overlays: self.active_overlays.clone(), active_overlays: self.active_overlays.clone(),
recursion_count: self.recursion_count.to_owned(),
} }
} }

View file

@ -903,6 +903,19 @@ Either make sure {0} is a string, or add a 'to_string' entry for it in ENV_CONVE
/// Return event, which may become an error if used outside of a function /// Return event, which may become an error if used outside of a function
#[error("Return used outside of function")] #[error("Return used outside of function")]
Return(#[label = "used outside of function"] Span, Box<Value>), Return(#[label = "used outside of function"] Span, Box<Value>),
/// The code being executed called itself too many times.
///
/// ## Resolution
///
/// Adjust your Nu code to
#[error("Recursion limit ({recursion_limit}) reached")]
#[diagnostic(code(nu::shell::recursion_limit_reached), url(docsrs))]
RecursionLimitReached {
recursion_limit: u64,
#[label("This called itself too many times")]
span: Option<Span>,
},
} }
impl From<std::io::Error> for ShellError { impl From<std::io::Error> for ShellError {

View file

@ -145,3 +145,18 @@ fn override_table_eval_file() {
let actual = nu!(cwd: ".", r#"def table [] { "hi" }; table"#); let actual = nu!(cwd: ".", r#"def table [] { "hi" }; table"#);
assert_eq!(actual.out, "hi"); assert_eq!(actual.out, "hi");
} }
// This test is disabled on Windows because they cause a stack overflow in CI (but not locally!).
// For reasons we don't understand, the Windows CI runners are prone to stack overflow.
// TODO: investigate so we can enable on Windows
#[cfg(not(target_os = "windows"))]
#[test]
fn infinite_recursion_does_not_panic() {
let actual = nu!(
cwd: ".",
r#"
def bang [] { bang }; bang
"#
);
assert!(actual.err.contains("Recursion limit (50) reached"));
}