Migrate extract_function to mutable ast

This commit is contained in:
DropDemBits 2023-12-11 17:37:45 -05:00
parent 3924a0ef7c
commit 0e39257e5b
No known key found for this signature in database
GPG key ID: 7FE02A6C1EDFA075
2 changed files with 193 additions and 112 deletions

View file

@ -1,4 +1,4 @@
use std::iter;
use std::{iter, ops::RangeInclusive};
use ast::make;
use either::Either;
@ -12,27 +12,25 @@ use ide_db::{
helpers::mod_path_to_ast,
imports::insert_use::{insert_use, ImportScope},
search::{FileReference, ReferenceCategory, SearchScope},
source_change::SourceChangeBuilder,
syntax_helpers::node_ext::{
for_each_tail_expr, preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr,
},
FxIndexSet, RootDatabase,
};
use itertools::Itertools;
use stdx::format_to;
use syntax::{
ast::{
self,
edit::{AstNodeEdit, IndentLevel},
AstNode, HasGenericParams,
self, edit::IndentLevel, edit_in_place::Indent, AstNode, AstToken, HasGenericParams,
HasName,
},
match_ast, ted, AstToken, SyntaxElement,
match_ast, ted, SyntaxElement,
SyntaxKind::{self, COMMENT},
SyntaxNode, SyntaxToken, TextRange, TextSize, TokenAtOffset, WalkEvent, T,
};
use crate::{
assist_context::{AssistContext, Assists, TreeMutator},
utils::generate_impl_text,
utils::generate_impl,
AssistId,
};
@ -134,17 +132,65 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
let new_indent = IndentLevel::from_node(&insert_after);
let old_indent = fun.body.indent_level();
builder.replace(target_range, make_call(ctx, &fun, old_indent));
let insert_after = builder.make_syntax_mut(insert_after);
let call_expr = make_call(ctx, &fun, old_indent);
// Map the element range to replace into the mutable version
let elements = match &fun.body {
FunctionBody::Expr(expr) => {
// expr itself becomes the replacement target
let expr = &builder.make_mut(expr.clone());
let node = SyntaxElement::Node(expr.syntax().clone());
node.clone()..=node
}
FunctionBody::Span { parent, elements, .. } => {
// Map the element range into the mutable versions
let parent = builder.make_mut(parent.clone());
let start = parent
.syntax()
.children_with_tokens()
.nth(elements.start().index())
.expect("should be able to find mutable start element");
let end = parent
.syntax()
.children_with_tokens()
.nth(elements.end().index())
.expect("should be able to find mutable end element");
start..=end
}
};
let has_impl_wrapper =
insert_after.ancestors().any(|a| a.kind() == SyntaxKind::IMPL && a != insert_after);
let fn_def = format_function(ctx, module, &fun, old_indent).clone_for_update();
if let Some(cap) = ctx.config.snippet_cap {
if let Some(name) = fn_def.name() {
builder.add_tabstop_before(cap, name);
}
}
let fn_def = match fun.self_param_adt(ctx) {
Some(adt) if anchor == Anchor::Method && !has_impl_wrapper => {
let fn_def = format_function(ctx, module, &fun, old_indent, new_indent + 1);
generate_impl_text(&adt, &fn_def).replace("{\n\n", "{")
fn_def.indent(1.into());
let impl_ = generate_impl(&adt);
impl_.indent(new_indent);
impl_.get_or_create_assoc_item_list().add_item(fn_def.into());
impl_.syntax().clone()
}
_ => {
fn_def.indent(new_indent.into());
fn_def.syntax().clone()
}
_ => format_function(ctx, module, &fun, old_indent, new_indent),
};
// There are external control flows
@ -177,12 +223,15 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
}
}
let insert_offset = insert_after.text_range().end();
// Replace the call site with the call to the new function
fixup_call_site(builder, &fun.body);
ted::replace_all(elements, vec![call_expr.into()]);
match ctx.config.snippet_cap {
Some(cap) => builder.insert_snippet(cap, insert_offset, fn_def),
None => builder.insert(insert_offset, fn_def),
};
// Insert the newly extracted function (or impl)
ted::insert_all_raw(
ted::Position::after(insert_after),
vec![make::tokens::whitespace(&format!("\n\n{new_indent}")).into(), fn_def.into()],
);
},
)
}
@ -225,10 +274,10 @@ fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option<Fu
if let Some(stmt) = ast::Stmt::cast(node.clone()) {
return match stmt {
ast::Stmt::Item(_) => None,
ast::Stmt::ExprStmt(_) | ast::Stmt::LetStmt(_) => Some(FunctionBody::from_range(
ast::Stmt::ExprStmt(_) | ast::Stmt::LetStmt(_) => FunctionBody::from_range(
node.parent().and_then(ast::StmtList::cast)?,
node.text_range(),
)),
),
};
}
@ -241,7 +290,7 @@ fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option<Fu
}
// Extract the full statements.
return Some(FunctionBody::from_range(stmt_list, selection_range));
return FunctionBody::from_range(stmt_list, selection_range);
}
let expr = ast::Expr::cast(node.clone())?;
@ -371,7 +420,7 @@ impl RetType {
#[derive(Debug)]
enum FunctionBody {
Expr(ast::Expr),
Span { parent: ast::StmtList, text_range: TextRange },
Span { parent: ast::StmtList, elements: RangeInclusive<SyntaxElement>, text_range: TextRange },
}
#[derive(Debug)]
@ -569,26 +618,38 @@ impl FunctionBody {
}
}
fn from_range(parent: ast::StmtList, selected: TextRange) -> FunctionBody {
fn from_range(parent: ast::StmtList, selected: TextRange) -> Option<FunctionBody> {
let full_body = parent.syntax().children_with_tokens();
let mut text_range = full_body
// Get all of the elements intersecting with the selection
let mut stmts_in_selection = full_body
.filter(|it| ast::Stmt::can_cast(it.kind()) || it.kind() == COMMENT)
.map(|element| element.text_range())
.filter(|&range| selected.intersect(range).filter(|it| !it.is_empty()).is_some())
.reduce(|acc, stmt| acc.cover(stmt));
.filter(|it| selected.intersect(it.text_range()).filter(|it| !it.is_empty()).is_some());
if let Some(tail_range) = parent
.tail_expr()
.map(|it| it.syntax().text_range())
.filter(|&it| selected.intersect(it).is_some())
let first_element = stmts_in_selection.next();
// If the tail expr is part of the selection too, make that the last element
// Otherwise use the last stmt
let last_element = if let Some(tail_expr) =
parent.tail_expr().filter(|it| selected.intersect(it.syntax().text_range()).is_some())
{
text_range = Some(match text_range {
Some(text_range) => text_range.cover(tail_range),
None => tail_range,
});
}
Self::Span { parent, text_range: text_range.unwrap_or(selected) }
Some(tail_expr.syntax().clone().into())
} else {
stmts_in_selection.last()
};
let elements = match (first_element, last_element) {
(None, _) => {
cov_mark::hit!(extract_function_empty_selection_is_not_applicable);
return None;
}
(Some(first), None) => first.clone()..=first,
(Some(first), Some(last)) => first..=last,
};
let text_range = elements.start().text_range().cover(elements.end().text_range());
Some(Self::Span { parent, elements, text_range })
}
fn indent_level(&self) -> IndentLevel {
@ -601,7 +662,7 @@ impl FunctionBody {
fn tail_expr(&self) -> Option<ast::Expr> {
match &self {
FunctionBody::Expr(expr) => Some(expr.clone()),
FunctionBody::Span { parent, text_range } => {
FunctionBody::Span { parent, text_range, .. } => {
let tail_expr = parent.tail_expr()?;
text_range.contains_range(tail_expr.syntax().text_range()).then_some(tail_expr)
}
@ -611,7 +672,7 @@ impl FunctionBody {
fn walk_expr(&self, cb: &mut dyn FnMut(ast::Expr)) {
match self {
FunctionBody::Expr(expr) => walk_expr(expr, cb),
FunctionBody::Span { parent, text_range } => {
FunctionBody::Span { parent, text_range, .. } => {
parent
.statements()
.filter(|stmt| text_range.contains_range(stmt.syntax().text_range()))
@ -634,7 +695,7 @@ impl FunctionBody {
fn preorder_expr(&self, cb: &mut dyn FnMut(WalkEvent<ast::Expr>) -> bool) {
match self {
FunctionBody::Expr(expr) => preorder_expr(expr, cb),
FunctionBody::Span { parent, text_range } => {
FunctionBody::Span { parent, text_range, .. } => {
parent
.statements()
.filter(|stmt| text_range.contains_range(stmt.syntax().text_range()))
@ -657,7 +718,7 @@ impl FunctionBody {
fn walk_pat(&self, cb: &mut dyn FnMut(ast::Pat)) {
match self {
FunctionBody::Expr(expr) => walk_patterns_in_expr(expr, cb),
FunctionBody::Span { parent, text_range } => {
FunctionBody::Span { parent, text_range, .. } => {
parent
.statements()
.filter(|stmt| text_range.contains_range(stmt.syntax().text_range()))
@ -1151,7 +1212,7 @@ impl HasTokenAtOffset for FunctionBody {
fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset<SyntaxToken> {
match self {
FunctionBody::Expr(expr) => expr.syntax().token_at_offset(offset),
FunctionBody::Span { parent, text_range } => {
FunctionBody::Span { parent, text_range, .. } => {
match parent.syntax().token_at_offset(offset) {
TokenAtOffset::None => TokenAtOffset::None,
TokenAtOffset::Single(t) => {
@ -1316,7 +1377,19 @@ fn impl_type_name(impl_node: &ast::Impl) -> Option<String> {
Some(impl_node.self_ty()?.to_string())
}
fn make_call(ctx: &AssistContext<'_>, fun: &Function, indent: IndentLevel) -> String {
/// Fixes up the call site before the target expressions are replaced with the call expression
fn fixup_call_site(builder: &mut SourceChangeBuilder, body: &FunctionBody) {
let parent_match_arm = body.parent().and_then(ast::MatchArm::cast);
if let Some(parent_match_arm) = parent_match_arm {
if parent_match_arm.comma_token().is_none() {
let parent_match_arm = builder.make_mut(parent_match_arm);
ted::append_child_raw(parent_match_arm.syntax(), make::token(T![,]));
}
}
}
fn make_call(ctx: &AssistContext<'_>, fun: &Function, indent: IndentLevel) -> SyntaxNode {
let ret_ty = fun.return_type(ctx);
let args = make::arg_list(fun.params.iter().map(|param| param.to_arg(ctx)));
@ -1334,44 +1407,49 @@ fn make_call(ctx: &AssistContext<'_>, fun: &Function, indent: IndentLevel) -> St
if fun.control_flow.is_async {
call_expr = make::expr_await(call_expr);
}
let expr = handler.make_call_expr(call_expr).indent(indent);
let mut_modifier = |var: &OutlivedLocal| if var.mut_usage_outside_body { "mut " } else { "" };
let expr = handler.make_call_expr(call_expr).clone_for_update();
expr.indent(indent);
let mut buf = String::new();
match fun.outliving_locals.as_slice() {
[] => {}
let outliving_bindings = match fun.outliving_locals.as_slice() {
[] => None,
[var] => {
let modifier = mut_modifier(var);
let name = var.local.name(ctx.db());
format_to!(buf, "let {modifier}{} = ", name.display(ctx.db()))
let name = make::name(&name.display(ctx.db()).to_string());
Some(ast::Pat::IdentPat(make::ident_pat(
false,
var.mut_usage_outside_body,
name.into(),
)))
}
vars => {
buf.push_str("let (");
let bindings = vars.iter().format_with(", ", |local, f| {
let modifier = mut_modifier(local);
let name = local.local.name(ctx.db());
f(&format_args!("{modifier}{}", name.display(ctx.db())))?;
Ok(())
let binding_pats = vars.iter().map(|var| {
let name = var.local.name(ctx.db());
let name = make::name(&name.display(ctx.db()).to_string());
make::ident_pat(false, var.mut_usage_outside_body, name.into()).into()
});
format_to!(buf, "{bindings}");
buf.push_str(") = ");
Some(ast::Pat::TuplePat(make::tuple_pat(binding_pats)))
}
}
};
format_to!(buf, "{expr}");
let parent_match_arm = fun.body.parent().and_then(ast::MatchArm::cast);
let insert_comma = parent_match_arm.as_ref().is_some_and(|it| it.comma_token().is_none());
if insert_comma {
buf.push(',');
} else if parent_match_arm.is_none()
if let Some(bindings) = outliving_bindings {
// with bindings that outlive it
make::let_stmt(bindings, None, Some(expr)).syntax().clone_for_update()
} else if parent_match_arm.as_ref().is_some() {
// as a tail expr for a match arm
expr.syntax().clone()
} else if parent_match_arm.as_ref().is_none()
&& fun.ret_ty.is_unit()
&& (!fun.outliving_locals.is_empty() || !expr.is_block_like())
{
buf.push(';');
// as an expr stmt
make::expr_stmt(expr).syntax().clone_for_update()
} else {
// as a tail expr, or a block
expr.syntax().clone()
}
buf
}
enum FlowHandler {
@ -1500,42 +1578,25 @@ fn format_function(
module: hir::Module,
fun: &Function,
old_indent: IndentLevel,
new_indent: IndentLevel,
) -> String {
let mut fn_def = String::new();
let fun_name = &fun.name;
) -> ast::Fn {
let fun_name = make::name(&fun.name.text());
let params = fun.make_param_list(ctx, module);
let ret_ty = fun.make_ret_ty(ctx, module);
let body = make_body(ctx, old_indent, new_indent, fun);
let const_kw = if fun.mods.is_const { "const " } else { "" };
let async_kw = if fun.control_flow.is_async { "async " } else { "" };
let unsafe_kw = if fun.control_flow.is_unsafe { "unsafe " } else { "" };
let body = make_body(ctx, old_indent, fun);
let (generic_params, where_clause) = make_generic_params_and_where_clause(ctx, fun);
format_to!(fn_def, "\n\n{new_indent}{const_kw}{async_kw}{unsafe_kw}");
match ctx.config.snippet_cap {
Some(_) => format_to!(fn_def, "fn $0{fun_name}"),
None => format_to!(fn_def, "fn {fun_name}"),
}
if let Some(generic_params) = generic_params {
format_to!(fn_def, "{generic_params}");
}
format_to!(fn_def, "{params}");
if let Some(ret_ty) = ret_ty {
format_to!(fn_def, " {ret_ty}");
}
if let Some(where_clause) = where_clause {
format_to!(fn_def, " {where_clause}");
}
format_to!(fn_def, " {body}");
fn_def
make::fn_(
None,
fun_name,
generic_params,
where_clause,
params,
body,
ret_ty,
fun.control_flow.is_async,
fun.mods.is_const,
fun.control_flow.is_unsafe,
)
}
fn make_generic_params_and_where_clause(
@ -1716,12 +1777,7 @@ impl FunType {
}
}
fn make_body(
ctx: &AssistContext<'_>,
old_indent: IndentLevel,
new_indent: IndentLevel,
fun: &Function,
) -> ast::BlockExpr {
fn make_body(ctx: &AssistContext<'_>, old_indent: IndentLevel, fun: &Function) -> ast::BlockExpr {
let ret_ty = fun.return_type(ctx);
let handler = FlowHandler::from_ret_ty(fun, &ret_ty);
@ -1732,7 +1788,7 @@ fn make_body(
match expr {
ast::Expr::BlockExpr(block) => {
// If the extracted expression is itself a block, there is no need to wrap it inside another block.
let block = block.dedent(old_indent);
block.dedent(old_indent);
let elements = block.stmt_list().map_or_else(
|| Either::Left(iter::empty()),
|stmt_list| {
@ -1752,13 +1808,13 @@ fn make_body(
make::hacky_block_expr(elements, block.tail_expr())
}
_ => {
let expr = expr.dedent(old_indent).indent(IndentLevel(1));
expr.reindent_to(1.into());
make::block_expr(Vec::new(), Some(expr))
}
}
}
FunctionBody::Span { parent, text_range } => {
FunctionBody::Span { parent, text_range, .. } => {
let mut elements: Vec<_> = parent
.syntax()
.children_with_tokens()
@ -1801,8 +1857,8 @@ fn make_body(
.map(|node_or_token| match &node_or_token {
syntax::NodeOrToken::Node(node) => match ast::Stmt::cast(node.clone()) {
Some(stmt) => {
let indented = stmt.dedent(old_indent).indent(body_indent);
let ast_node = indented.syntax().clone_subtree();
stmt.reindent_to(body_indent);
let ast_node = stmt.syntax().clone_subtree();
syntax::NodeOrToken::Node(ast_node)
}
_ => node_or_token,
@ -1810,7 +1866,9 @@ fn make_body(
_ => node_or_token,
})
.collect::<Vec<SyntaxElement>>();
let tail_expr = tail_expr.map(|expr| expr.dedent(old_indent).indent(body_indent));
if let Some(tail_expr) = &mut tail_expr {
tail_expr.reindent_to(body_indent);
}
make::hacky_block_expr(elements, tail_expr)
}
@ -1853,7 +1911,7 @@ fn make_body(
}),
};
block.indent(new_indent)
block
}
fn map_tail_expr(block: ast::BlockExpr, f: impl FnOnce(ast::Expr) -> ast::Expr) -> ast::BlockExpr {
@ -2551,6 +2609,20 @@ fn $0fun_name(n: u32) -> u32 {
check_assist_not_applicable(extract_function, r"fn main() { 1 + /* $0comment$0 */ 1; }");
}
#[test]
fn empty_selection_is_not_applicable() {
cov_mark::check!(extract_function_empty_selection_is_not_applicable);
check_assist_not_applicable(
extract_function,
r#"
fn main() {
$0
$0
}"#,
);
}
#[test]
fn part_of_expr_stmt() {
check_assist(

View file

@ -687,12 +687,21 @@ pub fn test_some_range(a: int) -> bool {
delete: 59..60,
},
Indel {
insert: "\n\nfn $0fun_name() -> i32 {\n 5\n}",
insert: "\n\nfn fun_name() -> i32 {\n 5\n}",
delete: 110..110,
},
],
},
None,
Some(
SnippetEdit(
[
(
0,
124..124,
),
],
),
),
),
},
file_system_edits: [],