Reject recursive calls in inline_call

This commit is contained in:
Lukas Wirth 2021-09-26 14:38:58 +02:00
parent 1ccb21a0ca
commit 1a50f904ef
2 changed files with 67 additions and 21 deletions

View file

@ -1,7 +1,10 @@
use ast::make; use ast::make;
use hir::{db::HirDatabase, HasSource, PathResolution, Semantics, TypeInfo}; use hir::{db::HirDatabase, HasSource, PathResolution, Semantics, TypeInfo};
use ide_db::{ use ide_db::{
base_db::FileId, defs::Definition, path_transform::PathTransform, search::FileReference, base_db::{FileId, FileRange},
defs::Definition,
path_transform::PathTransform,
search::{FileReference, SearchScope},
RootDatabase, RootDatabase,
}; };
use itertools::izip; use itertools::izip;
@ -54,11 +57,14 @@ use crate::{
// } // }
// ``` // ```
pub(crate) fn inline_into_callers(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { pub(crate) fn inline_into_callers(acc: &mut Assists, ctx: &AssistContext) -> Option<()> {
let def_file = ctx.frange.file_id;
let name = ctx.find_node_at_offset::<ast::Name>()?; let name = ctx.find_node_at_offset::<ast::Name>()?;
let func_syn = name.syntax().parent().and_then(ast::Fn::cast)?; let ast_func = name.syntax().parent().and_then(ast::Fn::cast)?;
let func_body = func_syn.body()?; let func_body = ast_func.body()?;
let param_list = func_syn.param_list()?; let param_list = ast_func.param_list()?;
let function = ctx.sema.to_def(&func_syn)?;
let function = ctx.sema.to_def(&ast_func)?;
let params = get_fn_params(ctx.sema.db, function, &param_list)?; let params = get_fn_params(ctx.sema.db, function, &param_list)?;
let usages = Definition::ModuleDef(hir::ModuleDef::Function(function)).usages(&ctx.sema); let usages = Definition::ModuleDef(hir::ModuleDef::Function(function)).usages(&ctx.sema);
@ -66,19 +72,28 @@ pub(crate) fn inline_into_callers(acc: &mut Assists, ctx: &AssistContext) -> Opt
return None; return None;
} }
let is_recursive_fn = usages
.clone()
.in_scope(SearchScope::file_range(FileRange {
file_id: def_file,
range: func_body.syntax().text_range(),
}))
.at_least_one();
if is_recursive_fn {
cov_mark::hit!(inline_into_callers_recursive);
return None;
}
acc.add( acc.add(
AssistId("inline_into_callers", AssistKind::RefactorInline), AssistId("inline_into_callers", AssistKind::RefactorInline),
"Inline into all callers", "Inline into all callers",
name.syntax().text_range(), name.syntax().text_range(),
|builder| { |builder| {
let def_file = ctx.frange.file_id;
let usages =
Definition::ModuleDef(hir::ModuleDef::Function(function)).usages(&ctx.sema);
let mut usages = usages.all(); let mut usages = usages.all();
let current_file_usage = usages.references.remove(&def_file); let current_file_usage = usages.references.remove(&def_file);
let mut can_remove = true; let mut remove_def = true;
let mut inline_refs = |file_id, refs: Vec<FileReference>| { let mut inline_refs_for_file = |file_id, refs: Vec<FileReference>| {
builder.edit_file(file_id); builder.edit_file(file_id);
let count = refs.len(); let count = refs.len();
let name_refs = refs.into_iter().filter_map(|file_ref| match file_ref.name { let name_refs = refs.into_iter().filter_map(|file_ref| match file_ref.name {
@ -124,18 +139,18 @@ pub(crate) fn inline_into_callers(acc: &mut Assists, ctx: &AssistContext) -> Opt
); );
}) })
.count(); .count();
can_remove &= replaced == count; remove_def &= replaced == count;
}; };
for (file_id, refs) in usages.into_iter() { for (file_id, refs) in usages.into_iter() {
inline_refs(file_id, refs); inline_refs_for_file(file_id, refs);
} }
if let Some(refs) = current_file_usage { if let Some(refs) = current_file_usage {
inline_refs(def_file, refs); inline_refs_for_file(def_file, refs);
} else { } else {
builder.edit_file(def_file); builder.edit_file(def_file);
} }
if can_remove { if remove_def {
builder.delete(func_syn.syntax().text_range()); builder.delete(ast_func.syntax().text_range());
} }
}, },
) )
@ -201,10 +216,15 @@ pub(crate) fn inline_call(acc: &mut Assists, ctx: &AssistContext) -> Option<()>
) )
}; };
let hir::InFile { value: function_source, file_id } = function.source(ctx.db())?; let fn_source = function.source(ctx.db())?;
let fn_body = function_source.body()?; let fn_body = fn_source.value.body()?;
let param_list = function_source.param_list()?; let param_list = fn_source.value.param_list()?;
let FileRange { file_id, range } = fn_source.syntax().original_file_range(ctx.sema.db);
if file_id == ctx.frange.file_id && range.contains(ctx.frange.range.start()) {
cov_mark::hit!(inline_call_recursive);
return None;
}
let params = get_fn_params(ctx.sema.db, function, &param_list)?; let params = get_fn_params(ctx.sema.db, function, &param_list)?;
if call_info.arguments.len() != params.len() { if call_info.arguments.len() != params.len() {
@ -220,7 +240,6 @@ pub(crate) fn inline_call(acc: &mut Assists, ctx: &AssistContext) -> Option<()>
label, label,
syntax.text_range(), syntax.text_range(),
|builder| { |builder| {
let file_id = file_id.original_file(ctx.sema.db);
let replacement = inline(&ctx.sema, file_id, function, &fn_body, &params, &call_info); let replacement = inline(&ctx.sema, file_id, function, &fn_body, &params, &call_info);
builder.replace_ast( builder.replace_ast(
@ -967,6 +986,32 @@ fn foo() {
foo * 0 + foo foo * 0 + foo
}; };
} }
"#,
);
}
#[test]
fn inline_callers_recursive() {
cov_mark::check!(inline_into_callers_recursive);
check_assist_not_applicable(
inline_into_callers,
r#"
fn foo$0() {
foo();
}
"#,
);
}
#[test]
fn inline_call_recursive() {
cov_mark::check!(inline_call_recursive);
check_assist_not_applicable(
inline_call,
r#"
fn foo() {
foo$0();
}
"#, "#,
); );
} }

View file

@ -315,6 +315,7 @@ impl Definition {
} }
} }
#[derive(Clone)]
pub struct FindUsages<'a> { pub struct FindUsages<'a> {
def: Definition, def: Definition,
sema: &'a Semantics<'a, RootDatabase>, sema: &'a Semantics<'a, RootDatabase>,
@ -341,7 +342,7 @@ impl<'a> FindUsages<'a> {
self self
} }
pub fn at_least_one(self) -> bool { pub fn at_least_one(&self) -> bool {
let mut found = false; let mut found = false;
self.search(&mut |_, _| { self.search(&mut |_, _| {
found = true; found = true;
@ -359,7 +360,7 @@ impl<'a> FindUsages<'a> {
res res
} }
fn search(self, sink: &mut dyn FnMut(FileId, FileReference) -> bool) { fn search(&self, sink: &mut dyn FnMut(FileId, FileReference) -> bool) {
let _p = profile::span("FindUsages:search"); let _p = profile::span("FindUsages:search");
let sema = self.sema; let sema = self.sema;