From 3ce5c66ca1dbe99f107b4a84f1f8bf37db831740 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Sun, 27 Jun 2021 01:11:57 +0200 Subject: [PATCH] Deduplicate ast expression walking logic --- crates/ide/src/highlight_related.rs | 162 +-------- .../handlers/wrap_return_type_in_result.rs | 333 +++--------------- crates/ide_db/src/helpers.rs | 112 +++++- crates/syntax/src/ast/node_ext.rs | 52 ++- 4 files changed, 218 insertions(+), 441 deletions(-) diff --git a/crates/ide/src/highlight_related.rs b/crates/ide/src/highlight_related.rs index 42ebfce373..2d27fb45e3 100644 --- a/crates/ide/src/highlight_related.rs +++ b/crates/ide/src/highlight_related.rs @@ -2,13 +2,13 @@ use hir::Semantics; use ide_db::{ base_db::FilePosition, defs::Definition, - helpers::pick_best_token, + helpers::{for_each_break_expr, for_each_tail_expr, pick_best_token}, search::{FileReference, ReferenceAccess, SearchScope}, RootDatabase, }; use syntax::{ ast::{self, LoopBodyOwner}, - match_ast, AstNode, SyntaxNode, SyntaxToken, TextRange, WalkEvent, T, + match_ast, AstNode, SyntaxNode, SyntaxToken, TextRange, T, }; use crate::{display::TryToNav, references, NavigationTarget}; @@ -95,7 +95,7 @@ fn highlight_exit_points( ) -> Option> { let mut highlights = Vec::new(); let body = body?; - walk(&body, &mut |expr| match expr { + body.walk(&mut |expr| match expr { ast::Expr::ReturnExpr(expr) => { if let Some(token) = expr.return_token() { highlights.push(HighlightedRange { access: None, range: token.text_range() }); @@ -120,7 +120,7 @@ fn highlight_exit_points( }; if let Some(tail) = tail { - for_each_inner_tail(&tail, &mut |tail| { + for_each_tail_expr(&tail, &mut |tail| { let range = match tail { ast::Expr::BreakExpr(b) => b .break_token() @@ -161,7 +161,7 @@ fn highlight_break_points(token: SyntaxToken) -> Option> { label.as_ref().map(|it| it.syntax().text_range()), ); highlights.extend(range.map(|range| HighlightedRange { access: None, range })); - for_each_break(label, body, &mut |break_| { + for_each_break_expr(label, body, &mut |break_| { let range = cover_range( break_.break_token().map(|it| it.text_range()), break_.lifetime().map(|it| it.syntax().text_range()), @@ -216,7 +216,7 @@ fn highlight_yield_points(token: SyntaxToken) -> Option> { let mut highlights = Vec::new(); highlights.push(HighlightedRange { access: None, range: async_token?.text_range() }); if let Some(body) = body { - walk(&body, &mut |expr| { + body.walk(&mut |expr| { if let ast::Expr::AwaitExpr(expr) = expr { if let Some(token) = expr.await_token() { highlights @@ -240,156 +240,6 @@ fn highlight_yield_points(token: SyntaxToken) -> Option> { None } -/// Preorder walk all the expression's child expressions -fn walk(expr: &ast::Expr, cb: &mut dyn FnMut(ast::Expr)) { - let mut preorder = expr.syntax().preorder(); - while let Some(event) = preorder.next() { - let node = match event { - WalkEvent::Enter(node) => node, - WalkEvent::Leave(_) => continue, - }; - match ast::Stmt::cast(node.clone()) { - // recursively walk the initializer, skipping potential const pat expressions - // lets statements aren't usually nested too deeply so this is fine to recurse on - Some(ast::Stmt::LetStmt(l)) => { - if let Some(expr) = l.initializer() { - walk(&expr, cb); - } - preorder.skip_subtree(); - } - // Don't skip subtree since we want to process the expression child next - Some(ast::Stmt::ExprStmt(_)) => (), - // skip inner items which might have their own expressions - Some(ast::Stmt::Item(_)) => preorder.skip_subtree(), - None => { - if let Some(expr) = ast::Expr::cast(node) { - let is_different_context = match &expr { - ast::Expr::EffectExpr(effect) => { - matches!( - effect.effect(), - ast::Effect::Async(_) | ast::Effect::Try(_) | ast::Effect::Const(_) - ) - } - ast::Expr::ClosureExpr(__) => true, - _ => false, - }; - cb(expr); - if is_different_context { - preorder.skip_subtree(); - } - } else { - preorder.skip_subtree(); - } - } - } - } -} - -// FIXME: doesn't account for labeled breaks in labeled blocks -fn for_each_inner_tail(expr: &ast::Expr, cb: &mut dyn FnMut(&ast::Expr)) { - match expr { - ast::Expr::BlockExpr(b) => { - if let Some(e) = b.tail_expr() { - for_each_inner_tail(&e, cb); - } - } - ast::Expr::EffectExpr(e) => match e.effect() { - ast::Effect::Label(label) => { - for_each_break(Some(label), e.block_expr(), &mut |b| cb(&ast::Expr::BreakExpr(b))); - if let Some(b) = e.block_expr() { - for_each_inner_tail(&ast::Expr::BlockExpr(b), cb); - } - } - ast::Effect::Unsafe(_) => { - if let Some(e) = e.block_expr().and_then(|b| b.tail_expr()) { - for_each_inner_tail(&e, cb); - } - } - ast::Effect::Async(_) | ast::Effect::Try(_) | ast::Effect::Const(_) => cb(expr), - }, - ast::Expr::IfExpr(if_) => { - if_.blocks().for_each(|block| for_each_inner_tail(&ast::Expr::BlockExpr(block), cb)) - } - ast::Expr::LoopExpr(l) => { - for_each_break(l.label(), l.loop_body(), &mut |b| cb(&ast::Expr::BreakExpr(b))) - } - ast::Expr::MatchExpr(m) => { - if let Some(arms) = m.match_arm_list() { - arms.arms().filter_map(|arm| arm.expr()).for_each(|e| for_each_inner_tail(&e, cb)); - } - } - ast::Expr::ArrayExpr(_) - | ast::Expr::AwaitExpr(_) - | ast::Expr::BinExpr(_) - | ast::Expr::BoxExpr(_) - | ast::Expr::BreakExpr(_) - | ast::Expr::CallExpr(_) - | ast::Expr::CastExpr(_) - | ast::Expr::ClosureExpr(_) - | ast::Expr::ContinueExpr(_) - | ast::Expr::FieldExpr(_) - | ast::Expr::ForExpr(_) - | ast::Expr::IndexExpr(_) - | ast::Expr::Literal(_) - | ast::Expr::MacroCall(_) - | ast::Expr::MacroStmts(_) - | ast::Expr::MethodCallExpr(_) - | ast::Expr::ParenExpr(_) - | ast::Expr::PathExpr(_) - | ast::Expr::PrefixExpr(_) - | ast::Expr::RangeExpr(_) - | ast::Expr::RecordExpr(_) - | ast::Expr::RefExpr(_) - | ast::Expr::ReturnExpr(_) - | ast::Expr::TryExpr(_) - | ast::Expr::TupleExpr(_) - | ast::Expr::WhileExpr(_) - | ast::Expr::YieldExpr(_) => cb(expr), - } -} - -fn for_each_break( - label: Option, - body: Option, - cb: &mut dyn FnMut(ast::BreakExpr), -) { - let label = label.and_then(|lbl| lbl.lifetime()); - let mut depth = 0; - if let Some(b) = body { - let preorder = &mut b.syntax().preorder(); - let ev_as_expr = |ev| match ev { - WalkEvent::Enter(it) => Some(WalkEvent::Enter(ast::Expr::cast(it)?)), - WalkEvent::Leave(it) => Some(WalkEvent::Leave(ast::Expr::cast(it)?)), - }; - let eq_label = |lt: Option| { - lt.zip(label.as_ref()).map_or(false, |(lt, lbl)| lt.text() == lbl.text()) - }; - while let Some(node) = preorder.find_map(ev_as_expr) { - match node { - WalkEvent::Enter(expr) => match expr { - ast::Expr::LoopExpr(_) | ast::Expr::WhileExpr(_) | ast::Expr::ForExpr(_) => { - depth += 1 - } - ast::Expr::EffectExpr(e) if e.label().is_some() => depth += 1, - ast::Expr::BreakExpr(b) - if (depth == 0 && b.lifetime().is_none()) || eq_label(b.lifetime()) => - { - cb(b); - } - _ => (), - }, - WalkEvent::Leave(expr) => match expr { - ast::Expr::LoopExpr(_) | ast::Expr::WhileExpr(_) | ast::Expr::ForExpr(_) => { - depth -= 1 - } - ast::Expr::EffectExpr(e) if e.label().is_some() => depth -= 1, - _ => (), - }, - } - } - } -} - fn cover_range(r0: Option, r1: Option) -> Option { match (r0, r1) { (Some(r0), Some(r1)) => Some(r0.cover(r1)), diff --git a/crates/ide_assists/src/handlers/wrap_return_type_in_result.rs b/crates/ide_assists/src/handlers/wrap_return_type_in_result.rs index 140e27356f..65d0640a38 100644 --- a/crates/ide_assists/src/handlers/wrap_return_type_in_result.rs +++ b/crates/ide_assists/src/handlers/wrap_return_type_in_result.rs @@ -1,8 +1,9 @@ use std::iter; +use ide_db::helpers::for_each_tail_expr; use syntax::{ - ast::{self, make, BlockExpr, Expr, LoopBodyOwner}, - match_ast, AstNode, SyntaxNode, + ast::{self, make, Expr}, + match_ast, AstNode, }; use crate::{AssistContext, AssistId, AssistKind, Assists}; @@ -21,7 +22,7 @@ use crate::{AssistContext, AssistId, AssistKind, Assists}; pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { let ret_type = ctx.find_node_at_offset::()?; let parent = ret_type.syntax().parent()?; - let block_expr = match_ast! { + let body = match_ast! { match parent { ast::Fn(func) => func.body()?, ast::ClosureExpr(closure) => match closure.body()? { @@ -32,6 +33,7 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext) _ => return None, } }; + let body = ast::Expr::BlockExpr(body); let type_ref = &ret_type.ty()?; let ret_type_str = type_ref.syntax().text().to_string(); @@ -48,11 +50,18 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext) "Wrap return type in Result", type_ref.syntax().text_range(), |builder| { - let mut tail_return_expr_collector = TailReturnCollector::new(); - tail_return_expr_collector.collect_jump_exprs(&block_expr, false); - tail_return_expr_collector.collect_tail_exprs(&block_expr); + let mut exprs_to_wrap = Vec::new(); + let tail_cb = &mut |e: &_| tail_cb_impl(&mut exprs_to_wrap, e); + body.walk(&mut |expr| { + if let Expr::ReturnExpr(ret_expr) = expr { + if let Some(ret_expr_arg) = &ret_expr.expr() { + for_each_tail_expr(ret_expr_arg, tail_cb); + } + } + }); + for_each_tail_expr(&body, tail_cb); - for ret_expr_arg in tail_return_expr_collector.exprs_to_wrap { + for ret_expr_arg in exprs_to_wrap { let ok_wrapped = make::expr_call( make::expr_path(make::ext::ident_path("Ok")), make::arg_list(iter::once(ret_expr_arg.clone())), @@ -72,199 +81,14 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext) ) } -struct TailReturnCollector { - exprs_to_wrap: Vec, -} - -impl TailReturnCollector { - fn new() -> Self { - Self { exprs_to_wrap: vec![] } - } - /// Collect all`return` expression - fn collect_jump_exprs(&mut self, block_expr: &BlockExpr, collect_break: bool) { - let statements = block_expr.statements(); - for stmt in statements { - let expr = match &stmt { - ast::Stmt::ExprStmt(stmt) => stmt.expr(), - ast::Stmt::LetStmt(stmt) => stmt.initializer(), - ast::Stmt::Item(_) => continue, - }; - if let Some(expr) = &expr { - self.handle_exprs(expr, collect_break); +fn tail_cb_impl(acc: &mut Vec, e: &ast::Expr) { + match e { + Expr::BreakExpr(break_expr) => { + if let Some(break_expr_arg) = break_expr.expr() { + for_each_tail_expr(&break_expr_arg, &mut |e| tail_cb_impl(acc, e)) } } - - // Browse tail expressions for each block - if let Some(expr) = block_expr.tail_expr() { - if let Some(last_exprs) = get_tail_expr_from_block(&expr) { - for last_expr in last_exprs { - let last_expr = match last_expr { - NodeType::Node(expr) => expr, - NodeType::Leaf(expr) => expr.syntax().clone(), - }; - - if let Some(last_expr) = Expr::cast(last_expr.clone()) { - self.handle_exprs(&last_expr, collect_break); - } else if let Some(expr_stmt) = ast::Stmt::cast(last_expr) { - let expr_stmt = match &expr_stmt { - ast::Stmt::ExprStmt(stmt) => stmt.expr(), - ast::Stmt::LetStmt(stmt) => stmt.initializer(), - ast::Stmt::Item(_) => None, - }; - if let Some(expr) = &expr_stmt { - self.handle_exprs(expr, collect_break); - } - } - } - } - } - } - - fn handle_exprs(&mut self, expr: &Expr, collect_break: bool) { - match expr { - Expr::BlockExpr(block_expr) => { - self.collect_jump_exprs(block_expr, collect_break); - } - Expr::ReturnExpr(ret_expr) => { - if let Some(ret_expr_arg) = &ret_expr.expr() { - self.exprs_to_wrap.push(ret_expr_arg.clone()); - } - } - Expr::BreakExpr(break_expr) if collect_break => { - if let Some(break_expr_arg) = &break_expr.expr() { - self.exprs_to_wrap.push(break_expr_arg.clone()); - } - } - Expr::IfExpr(if_expr) => { - for block in if_expr.blocks() { - self.collect_jump_exprs(&block, collect_break); - } - } - Expr::LoopExpr(loop_expr) => { - if let Some(block_expr) = loop_expr.loop_body() { - self.collect_jump_exprs(&block_expr, collect_break); - } - } - Expr::ForExpr(for_expr) => { - if let Some(block_expr) = for_expr.loop_body() { - self.collect_jump_exprs(&block_expr, collect_break); - } - } - Expr::WhileExpr(while_expr) => { - if let Some(block_expr) = while_expr.loop_body() { - self.collect_jump_exprs(&block_expr, collect_break); - } - } - Expr::MatchExpr(match_expr) => { - if let Some(arm_list) = match_expr.match_arm_list() { - arm_list.arms().filter_map(|match_arm| match_arm.expr()).for_each(|expr| { - self.handle_exprs(&expr, collect_break); - }); - } - } - _ => {} - } - } - - fn collect_tail_exprs(&mut self, block: &BlockExpr) { - if let Some(expr) = block.tail_expr() { - self.handle_exprs(&expr, true); - self.fetch_tail_exprs(&expr); - } - } - - fn fetch_tail_exprs(&mut self, expr: &Expr) { - if let Some(exprs) = get_tail_expr_from_block(expr) { - for node_type in &exprs { - match node_type { - NodeType::Leaf(expr) => { - self.exprs_to_wrap.push(expr.clone()); - } - NodeType::Node(expr) => { - if let Some(last_expr) = Expr::cast(expr.clone()) { - self.fetch_tail_exprs(&last_expr); - } - } - } - } - } - } -} - -#[derive(Debug)] -enum NodeType { - Leaf(ast::Expr), - Node(SyntaxNode), -} - -/// Get a tail expression inside a block -fn get_tail_expr_from_block(expr: &Expr) -> Option> { - match expr { - Expr::IfExpr(if_expr) => { - let mut nodes = vec![]; - for block in if_expr.blocks() { - if let Some(block_expr) = block.tail_expr() { - if let Some(tail_exprs) = get_tail_expr_from_block(&block_expr) { - nodes.extend(tail_exprs); - } - } else if let Some(last_expr) = block.syntax().last_child() { - nodes.push(NodeType::Node(last_expr)); - } else { - nodes.push(NodeType::Node(block.syntax().clone())); - } - } - Some(nodes) - } - Expr::LoopExpr(loop_expr) => { - loop_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)]) - } - Expr::ForExpr(for_expr) => { - for_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)]) - } - Expr::WhileExpr(while_expr) => { - while_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)]) - } - Expr::BlockExpr(block_expr) => { - block_expr.tail_expr().map(|lc| vec![NodeType::Node(lc.syntax().clone())]) - } - Expr::MatchExpr(match_expr) => { - let arm_list = match_expr.match_arm_list()?; - let arms: Vec = arm_list - .arms() - .filter_map(|match_arm| match_arm.expr()) - .map(|expr| match expr { - Expr::ReturnExpr(ret_expr) => NodeType::Node(ret_expr.syntax().clone()), - Expr::BreakExpr(break_expr) => NodeType::Node(break_expr.syntax().clone()), - _ => match expr.syntax().last_child() { - Some(last_expr) => NodeType::Node(last_expr), - None => NodeType::Node(expr.syntax().clone()), - }, - }) - .collect(); - - Some(arms) - } - Expr::BreakExpr(expr) => expr.expr().map(|e| vec![NodeType::Leaf(e)]), - Expr::ReturnExpr(ret_expr) => Some(vec![NodeType::Node(ret_expr.syntax().clone())]), - - Expr::CallExpr(_) - | Expr::Literal(_) - | Expr::TupleExpr(_) - | Expr::ArrayExpr(_) - | Expr::ParenExpr(_) - | Expr::PathExpr(_) - | Expr::RecordExpr(_) - | Expr::IndexExpr(_) - | Expr::MethodCallExpr(_) - | Expr::AwaitExpr(_) - | Expr::CastExpr(_) - | Expr::RefExpr(_) - | Expr::PrefixExpr(_) - | Expr::RangeExpr(_) - | Expr::BinExpr(_) - | Expr::MacroCall(_) - | Expr::BoxExpr(_) => Some(vec![NodeType::Leaf(expr.clone())]), - _ => None, + e => acc.push(e.clone()), } } @@ -293,6 +117,35 @@ fn foo() -> Result { ); } + #[test] + fn wrap_return_type_break_split_tail() { + check_assist( + wrap_return_type_in_result, + r#" +fn foo() -> i3$02 { + loop { + break if true { + 1 + } else { + 0 + }; + } +} +"#, + r#" +fn foo() -> Result { + loop { + break if true { + Ok(1) + } else { + Ok(0) + }; + } +} +"#, + ); + } + #[test] fn wrap_return_type_in_result_simple_closure() { check_assist( @@ -940,90 +793,6 @@ fn foo() -> Result { "#, ); - check_assist( - wrap_return_type_in_result, - r#" -fn foo() -> i32$0 { - let test = "test"; - if test == "test" { - return 24i32; - } - let mut i = 0; - loop { - loop { - if i == 1 { - break 55; - } - i += 1; - } - } -} -"#, - r#" -fn foo() -> Result { - let test = "test"; - if test == "test" { - return Ok(24i32); - } - let mut i = 0; - loop { - loop { - if i == 1 { - break Ok(55); - } - i += 1; - } - } -} -"#, - ); - - check_assist( - wrap_return_type_in_result, - r#" -fn foo() -> i3$02 { - let test = "test"; - let other = 5; - if test == "test" { - let res = match other { - 5 => 43, - _ => return 56, - }; - } - let mut i = 0; - loop { - loop { - if i == 1 { - break 55; - } - i += 1; - } - } -} -"#, - r#" -fn foo() -> Result { - let test = "test"; - let other = 5; - if test == "test" { - let res = match other { - 5 => 43, - _ => return Ok(56), - }; - } - let mut i = 0; - loop { - loop { - if i == 1 { - break Ok(55); - } - i += 1; - } - } -} -"#, - ); - check_assist( wrap_return_type_in_result, r#" diff --git a/crates/ide_db/src/helpers.rs b/crates/ide_db/src/helpers.rs index bc21977e36..632fd36590 100644 --- a/crates/ide_db/src/helpers.rs +++ b/crates/ide_db/src/helpers.rs @@ -11,8 +11,8 @@ use base_db::FileId; use either::Either; use hir::{Crate, Enum, ItemInNs, MacroDef, Module, ModuleDef, Name, ScopeDef, Semantics, Trait}; use syntax::{ - ast::{self, make}, - SyntaxKind, SyntaxToken, TokenAtOffset, + ast::{self, make, LoopBodyOwner}, + AstNode, SyntaxKind, SyntaxToken, TokenAtOffset, WalkEvent, }; use crate::RootDatabase; @@ -204,3 +204,111 @@ impl SnippetCap { } } } + +/// Calls `cb` on each expression inside `expr` that is at "tail position". +pub fn for_each_tail_expr(expr: &ast::Expr, cb: &mut dyn FnMut(&ast::Expr)) { + match expr { + ast::Expr::BlockExpr(b) => { + if let Some(e) = b.tail_expr() { + for_each_tail_expr(&e, cb); + } + } + ast::Expr::EffectExpr(e) => match e.effect() { + ast::Effect::Label(label) => { + for_each_break_expr(Some(label), e.block_expr(), &mut |b| { + cb(&ast::Expr::BreakExpr(b)) + }); + if let Some(b) = e.block_expr() { + for_each_tail_expr(&ast::Expr::BlockExpr(b), cb); + } + } + ast::Effect::Unsafe(_) => { + if let Some(e) = e.block_expr().and_then(|b| b.tail_expr()) { + for_each_tail_expr(&e, cb); + } + } + ast::Effect::Async(_) | ast::Effect::Try(_) | ast::Effect::Const(_) => cb(expr), + }, + ast::Expr::IfExpr(if_) => { + if_.blocks().for_each(|block| for_each_tail_expr(&ast::Expr::BlockExpr(block), cb)) + } + ast::Expr::LoopExpr(l) => { + for_each_break_expr(l.label(), l.loop_body(), &mut |b| cb(&ast::Expr::BreakExpr(b))) + } + ast::Expr::MatchExpr(m) => { + if let Some(arms) = m.match_arm_list() { + arms.arms().filter_map(|arm| arm.expr()).for_each(|e| for_each_tail_expr(&e, cb)); + } + } + ast::Expr::ArrayExpr(_) + | ast::Expr::AwaitExpr(_) + | ast::Expr::BinExpr(_) + | ast::Expr::BoxExpr(_) + | ast::Expr::BreakExpr(_) + | ast::Expr::CallExpr(_) + | ast::Expr::CastExpr(_) + | ast::Expr::ClosureExpr(_) + | ast::Expr::ContinueExpr(_) + | ast::Expr::FieldExpr(_) + | ast::Expr::ForExpr(_) + | ast::Expr::IndexExpr(_) + | ast::Expr::Literal(_) + | ast::Expr::MacroCall(_) + | ast::Expr::MacroStmts(_) + | ast::Expr::MethodCallExpr(_) + | ast::Expr::ParenExpr(_) + | ast::Expr::PathExpr(_) + | ast::Expr::PrefixExpr(_) + | ast::Expr::RangeExpr(_) + | ast::Expr::RecordExpr(_) + | ast::Expr::RefExpr(_) + | ast::Expr::ReturnExpr(_) + | ast::Expr::TryExpr(_) + | ast::Expr::TupleExpr(_) + | ast::Expr::WhileExpr(_) + | ast::Expr::YieldExpr(_) => cb(expr), + } +} + +/// Calls `cb` on each break expr inside of `body` that is applicable for the given label. +pub fn for_each_break_expr( + label: Option, + body: Option, + cb: &mut dyn FnMut(ast::BreakExpr), +) { + let label = label.and_then(|lbl| lbl.lifetime()); + let mut depth = 0; + if let Some(b) = body { + let preorder = &mut b.syntax().preorder(); + let ev_as_expr = |ev| match ev { + WalkEvent::Enter(it) => Some(WalkEvent::Enter(ast::Expr::cast(it)?)), + WalkEvent::Leave(it) => Some(WalkEvent::Leave(ast::Expr::cast(it)?)), + }; + let eq_label = |lt: Option| { + lt.zip(label.as_ref()).map_or(false, |(lt, lbl)| lt.text() == lbl.text()) + }; + while let Some(node) = preorder.find_map(ev_as_expr) { + match node { + WalkEvent::Enter(expr) => match expr { + ast::Expr::LoopExpr(_) | ast::Expr::WhileExpr(_) | ast::Expr::ForExpr(_) => { + depth += 1 + } + ast::Expr::EffectExpr(e) if e.label().is_some() => depth += 1, + ast::Expr::BreakExpr(b) + if (depth == 0 && b.lifetime().is_none()) || eq_label(b.lifetime()) => + { + cb(b); + } + _ => (), + }, + WalkEvent::Leave(expr) => match expr { + ast::Expr::LoopExpr(_) | ast::Expr::WhileExpr(_) | ast::Expr::ForExpr(_) => { + depth -= 1 + } + ast::Expr::EffectExpr(e) if e.label().is_some() => depth -= 1, + _ => (), + }, + } + } + } +} diff --git a/crates/syntax/src/ast/node_ext.rs b/crates/syntax/src/ast/node_ext.rs index e33e5bb037..826efdfe87 100644 --- a/crates/syntax/src/ast/node_ext.rs +++ b/crates/syntax/src/ast/node_ext.rs @@ -5,7 +5,7 @@ use std::{borrow::Cow, fmt, iter::successors}; use itertools::Itertools; use parser::SyntaxKind; -use rowan::{GreenNodeData, GreenTokenData}; +use rowan::{GreenNodeData, GreenTokenData, WalkEvent}; use crate::{ ast::{self, support, AstChildren, AstNode, AstToken, AttrsOwner, NameOwner, SyntaxNode}, @@ -51,6 +51,56 @@ impl ast::BlockExpr { } } +impl ast::Expr { + /// Preorder walk all the expression's child expressions. + pub fn walk(&self, cb: &mut dyn FnMut(ast::Expr)) { + let mut preorder = self.syntax().preorder(); + while let Some(event) = preorder.next() { + let node = match event { + WalkEvent::Enter(node) => node, + WalkEvent::Leave(_) => continue, + }; + match ast::Stmt::cast(node.clone()) { + // recursively walk the initializer, skipping potential const pat expressions + // let statements aren't usually nested too deeply so this is fine to recurse on + Some(ast::Stmt::LetStmt(l)) => { + if let Some(expr) = l.initializer() { + expr.walk(cb); + } + preorder.skip_subtree(); + } + // Don't skip subtree since we want to process the expression child next + Some(ast::Stmt::ExprStmt(_)) => (), + // skip inner items which might have their own expressions + Some(ast::Stmt::Item(_)) => preorder.skip_subtree(), + None => { + // skip const args, those expressions are a different context + if ast::GenericArg::can_cast(node.kind()) { + preorder.skip_subtree(); + } else if let Some(expr) = ast::Expr::cast(node) { + let is_different_context = match &expr { + ast::Expr::EffectExpr(effect) => { + matches!( + effect.effect(), + ast::Effect::Async(_) + | ast::Effect::Try(_) + | ast::Effect::Const(_) + ) + } + ast::Expr::ClosureExpr(__) => true, + _ => false, + }; + cb(expr); + if is_different_context { + preorder.skip_subtree(); + } + } + } + } + } + } +} + #[derive(Debug, PartialEq, Eq, Clone)] pub enum Macro { MacroRules(ast::MacroRules),