Use more strictly typed syntax nodes for analysis in extract_function assist

This commit is contained in:
Lukas Wirth 2021-07-29 17:17:45 +02:00
parent 2b461c50d7
commit b537cb186e
4 changed files with 200 additions and 89 deletions

1
Cargo.lock generated
View file

@ -600,6 +600,7 @@ dependencies = [
"expect-test",
"hir",
"ide_db",
"indexmap",
"itertools",
"profile",
"rustc-hash",

View file

@ -13,6 +13,7 @@ cov-mark = "2.0.0-pre.1"
rustc-hash = "1.1.0"
itertools = "0.10.0"
either = "1.6.1"
indexmap = "1.6.2"
stdx = { path = "../stdx", version = "0.0.0" }
syntax = { path = "../syntax", version = "0.0.0" }

View file

@ -1,13 +1,14 @@
use std::iter;
use std::{hash::BuildHasherDefault, iter};
use ast::make;
use either::Either;
use hir::{HirDisplay, Local};
use hir::{HirDisplay, Local, Semantics};
use ide_db::{
defs::{Definition, NameRefClass},
search::{FileReference, ReferenceAccess, SearchScope},
RootDatabase,
};
use itertools::Itertools;
use rustc_hash::FxHasher;
use stdx::format_to;
use syntax::{
ast::{
@ -25,6 +26,8 @@ use crate::{
AssistId,
};
type FxIndexSet<T> = indexmap::IndexSet<T, BuildHasherDefault<FxHasher>>;
// Assist: extract_function
//
// Extracts selected statements into new function.
@ -51,7 +54,8 @@ use crate::{
// }
// ```
pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
if ctx.frange.range.is_empty() {
let range = ctx.frange.range;
if range.is_empty() {
return None;
}
@ -65,11 +69,9 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
syntax::NodeOrToken::Node(n) => n,
syntax::NodeOrToken::Token(t) => t.parent()?,
};
let body = extraction_target(&node, range)?;
let body = extraction_target(&node, ctx.frange.range)?;
let vars_used_in_body = vars_used_in_body(ctx, &body);
let self_param = self_param_from_usages(ctx, &body, &vars_used_in_body);
let (locals_used, has_await, self_param) = analyze_body(&ctx.sema, &body);
let anchor = if self_param.is_some() { Anchor::Method } else { Anchor::Freestanding };
let insert_after = scope_for_fn_insertion(&body, anchor)?;
@ -95,7 +97,7 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
"Extract into function",
target_range,
move |builder| {
let params = extracted_function_params(ctx, &body, &vars_used_in_body);
let params = extracted_function_params(ctx, &body, locals_used.iter().copied());
let fun = Function {
name: "fun_name".to_string(),
@ -109,15 +111,10 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext) -> Option
let new_indent = IndentLevel::from_node(&insert_after);
let old_indent = fun.body.indent_level();
let body_contains_await = body_contains_await(&fun.body);
builder.replace(
target_range,
format_replacement(ctx, &fun, old_indent, body_contains_await),
);
builder.replace(target_range, format_replacement(ctx, &fun, old_indent, has_await));
let fn_def =
format_function(ctx, module, &fun, old_indent, new_indent, body_contains_await);
let fn_def = format_function(ctx, module, &fun, old_indent, new_indent, has_await);
let insert_offset = insert_after.text_range().end();
match ctx.config.snippet_cap {
Some(cap) => builder.insert_snippet(cap, insert_offset, fn_def),
@ -500,15 +497,59 @@ impl FunctionBody {
}
}
fn descendants(&self) -> impl Iterator<Item = SyntaxNode> + '_ {
fn walk_expr(&self, cb: &mut dyn FnMut(ast::Expr)) {
match self {
FunctionBody::Expr(expr) => Either::Right(expr.syntax().descendants()),
FunctionBody::Span { parent, text_range } => Either::Left(
FunctionBody::Expr(expr) => expr.walk(cb),
FunctionBody::Span { parent, text_range } => {
parent
.syntax()
.descendants()
.filter(move |it| text_range.contains_range(it.text_range())),
),
.statements()
.filter(|stmt| text_range.contains_range(stmt.syntax().text_range()))
.filter_map(|stmt| match stmt {
ast::Stmt::ExprStmt(expr_stmt) => expr_stmt.expr(),
ast::Stmt::Item(_) => None,
ast::Stmt::LetStmt(stmt) => stmt.initializer(),
})
.for_each(|expr| expr.walk(cb));
if let Some(expr) = parent
.tail_expr()
.filter(|it| text_range.contains_range(it.syntax().text_range()))
{
expr.walk(cb);
}
}
}
}
fn walk_pat(&self, cb: &mut dyn FnMut(ast::Pat)) {
match self {
FunctionBody::Expr(expr) => expr.walk_patterns(cb),
FunctionBody::Span { parent, text_range } => {
parent
.statements()
.filter(|stmt| text_range.contains_range(stmt.syntax().text_range()))
.for_each(|stmt| match stmt {
ast::Stmt::ExprStmt(expr_stmt) => {
if let Some(expr) = expr_stmt.expr() {
expr.walk_patterns(cb)
}
}
ast::Stmt::Item(_) => (),
ast::Stmt::LetStmt(stmt) => {
if let Some(pat) = stmt.pat() {
pat.walk(cb);
}
if let Some(expr) = stmt.initializer() {
expr.walk_patterns(cb);
}
}
});
if let Some(expr) = parent
.tail_expr()
.filter(|it| text_range.contains_range(it.syntax().text_range()))
{
expr.walk_patterns(cb);
}
}
}
}
@ -622,58 +663,48 @@ fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option<Fu
node.ancestors().find_map(ast::Expr::cast).and_then(FunctionBody::from_expr)
}
/// list local variables that are referenced in `body`
fn vars_used_in_body(ctx: &AssistContext, body: &FunctionBody) -> Vec<Local> {
// FIXME: currently usages inside macros are not found
body.descendants()
.filter_map(ast::NameRef::cast)
.filter_map(|name_ref| NameRefClass::classify(&ctx.sema, &name_ref))
.map(|name_kind| match name_kind {
NameRefClass::Definition(def) => def,
NameRefClass::FieldShorthand { local_ref, field_ref: _ } => {
Definition::Local(local_ref)
}
})
.filter_map(|definition| match definition {
Definition::Local(local) => Some(local),
_ => None,
})
.unique()
.collect()
}
fn body_contains_await(body: &FunctionBody) -> bool {
body.descendants().any(|d| matches!(d.kind(), SyntaxKind::AWAIT_EXPR))
}
/// find `self` param, that was not defined inside `body`
///
/// It should skip `self` params from impls inside `body`
fn self_param_from_usages(
ctx: &AssistContext,
/// Analyzes a function body, returning the used local variables that are referenced in it as well as
/// whether it contains an await expression.
fn analyze_body(
sema: &Semantics<RootDatabase>,
body: &FunctionBody,
vars_used_in_body: &[Local],
) -> Option<(Local, ast::SelfParam)> {
let mut iter = vars_used_in_body
.iter()
.filter(|var| var.is_self(ctx.db()))
.map(|var| (var, var.source(ctx.db())))
.filter(|(_, src)| is_defined_before(ctx, body, src))
.filter_map(|(&node, src)| match src.value {
Either::Right(it) => Some((node, it)),
) -> (FxIndexSet<Local>, bool, Option<(Local, ast::SelfParam)>) {
// FIXME: currently usages inside macros are not found
let mut has_await = false;
let mut self_param = None;
let mut res = FxIndexSet::default();
body.walk_expr(&mut |expr| {
has_await |= matches!(expr, ast::Expr::AwaitExpr(_));
let name_ref = match expr {
ast::Expr::PathExpr(path_expr) => {
path_expr.path().and_then(|it| it.as_single_name_ref())
}
_ => return,
};
if let Some(name_ref) = name_ref {
if let Some(
NameRefClass::Definition(Definition::Local(local_ref))
| NameRefClass::FieldShorthand { local_ref, field_ref: _ },
) = NameRefClass::classify(sema, &name_ref)
{
res.insert(local_ref);
if local_ref.is_self(sema.db) {
match local_ref.source(sema.db).value {
Either::Right(it) => {
stdx::always!(
self_param.replace((local_ref, it)).is_none(),
"body references two different self params"
);
}
Either::Left(_) => {
stdx::never!(false, "Local::is_self returned true, but source is IdentPat");
None
stdx::never!("Local::is_self returned true, but source is IdentPat");
}
}
}
}
}
});
let self_param = iter.next();
stdx::always!(
iter.next().is_none(),
"body references two different self params, both defined outside"
);
self_param
(res, has_await, self_param)
}
/// find variables that should be extracted as params
@ -682,16 +713,15 @@ fn self_param_from_usages(
fn extracted_function_params(
ctx: &AssistContext,
body: &FunctionBody,
vars_used_in_body: &[Local],
locals: impl Iterator<Item = Local>,
) -> Vec<Param> {
vars_used_in_body
.iter()
.filter(|var| !var.is_self(ctx.db()))
.map(|node| (node, node.source(ctx.db())))
.filter(|(_, src)| is_defined_before(ctx, body, src))
.filter_map(|(&node, src)| {
locals
.filter(|local| !local.is_self(ctx.db()))
.map(|local| (local, local.source(ctx.db())))
.filter(|(_, src)| is_defined_outside_of_body(ctx, body, src))
.filter_map(|(local, src)| {
if src.value.is_left() {
Some(node)
Some(local)
} else {
stdx::never!(false, "Local::is_self returned false, but source is SelfParam");
None
@ -838,14 +868,18 @@ fn path_element_of_reference(
}
/// list local variables defined inside `body`
fn vars_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> Vec<Local> {
fn locals_defined_in_body(body: &FunctionBody, ctx: &AssistContext) -> FxIndexSet<Local> {
// FIXME: this doesn't work well with macros
// see https://github.com/rust-analyzer/rust-analyzer/pull/7535#discussion_r570048550
body.descendants()
.filter_map(ast::IdentPat::cast)
.filter_map(|let_stmt| ctx.sema.to_def(&let_stmt))
.unique()
.collect()
let mut res = FxIndexSet::default();
body.walk_pat(&mut |pat| {
if let ast::Pat::IdentPat(pat) = pat {
if let Some(local) = ctx.sema.to_def(&pat) {
res.insert(local);
}
}
});
res
}
/// list local variables defined inside `body` that should be returned from extracted function
@ -854,7 +888,7 @@ fn vars_defined_in_body_and_outlive(
body: &FunctionBody,
parent: &SyntaxNode,
) -> Vec<OutlivedLocal> {
let vars_defined_in_body = vars_defined_in_body(body, ctx);
let vars_defined_in_body = locals_defined_in_body(body, ctx);
vars_defined_in_body
.into_iter()
.filter_map(|var| var_outlives_body(ctx, body, var, parent))
@ -862,7 +896,7 @@ fn vars_defined_in_body_and_outlive(
}
/// checks if the relevant local was defined before(outside of) body
fn is_defined_before(
fn is_defined_outside_of_body(
ctx: &AssistContext,
body: &FunctionBody,
src: &hir::InFile<Either<ast::IdentPat, ast::SelfParam>>,

View file

@ -103,6 +103,81 @@ impl ast::Expr {
}
}
}
/// Preorder walk all the expression's child patterns.
pub fn walk_patterns(&self, cb: &mut dyn FnMut(ast::Pat)) {
let mut preorder = self.syntax().preorder();
while let Some(event) = preorder.next() {
let node = match event {
WalkEvent::Enter(node) => node,
WalkEvent::Leave(_) => continue,
};
match ast::Stmt::cast(node.clone()) {
Some(ast::Stmt::LetStmt(l)) => {
if let Some(pat) = l.pat() {
pat.walk(cb);
}
if let Some(expr) = l.initializer() {
expr.walk_patterns(cb);
}
preorder.skip_subtree();
}
// Don't skip subtree since we want to process the expression child next
Some(ast::Stmt::ExprStmt(_)) => (),
// skip inner items which might have their own patterns
Some(ast::Stmt::Item(_)) => preorder.skip_subtree(),
None => {
// skip const args, those are a different context
if ast::GenericArg::can_cast(node.kind()) {
preorder.skip_subtree();
} else if let Some(expr) = ast::Expr::cast(node.clone()) {
let is_different_context = match &expr {
ast::Expr::EffectExpr(effect) => {
matches!(
effect.effect(),
ast::Effect::Async(_)
| ast::Effect::Try(_)
| ast::Effect::Const(_)
)
}
ast::Expr::ClosureExpr(_) => true,
_ => false,
};
if is_different_context {
preorder.skip_subtree();
}
} else if let Some(pat) = ast::Pat::cast(node) {
preorder.skip_subtree();
pat.walk(cb);
}
}
}
}
}
}
impl ast::Pat {
/// Preorder walk all the pattern's sub patterns.
pub fn walk(&self, cb: &mut dyn FnMut(ast::Pat)) {
let mut preorder = self.syntax().preorder();
while let Some(event) = preorder.next() {
let node = match event {
WalkEvent::Enter(node) => node,
WalkEvent::Leave(_) => continue,
};
match ast::Pat::cast(node.clone()) {
Some(ast::Pat::ConstBlockPat(_)) => preorder.skip_subtree(),
Some(pat) => {
cb(pat);
}
// skip const args
None if ast::GenericArg::can_cast(node.kind()) => {
preorder.skip_subtree();
}
None => (),
}
}
}
}
#[derive(Debug, PartialEq, Eq, Clone)]