feat: support cross module imports

This commit is contained in:
Ryan Mehri 2023-09-09 11:40:29 -07:00
parent 136a9dbe36
commit 2e13aed3bc

View file

@ -1,9 +1,13 @@
use hir::ModuleDef;
use ide_db::{ use ide_db::{
assists::{AssistId, AssistKind}, assists::{AssistId, AssistKind},
defs::Definition, defs::Definition,
search::{FileReference, SearchScope, UsageSearchResult}, helpers::mod_path_to_ast,
imports::insert_use::{insert_use, ImportScope},
search::{FileReference, UsageSearchResult},
source_change::SourceChangeBuilder, source_change::SourceChangeBuilder,
}; };
use itertools::Itertools;
use syntax::{ use syntax::{
ast::{ ast::{
self, self,
@ -48,6 +52,7 @@ use crate::assist_context::{AssistContext, Assists};
pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
let BoolNodeData { target_node, name, ty_annotation, initializer, definition } = let BoolNodeData { target_node, name, ty_annotation, initializer, definition } =
find_bool_node(ctx)?; find_bool_node(ctx)?;
let target_module = ctx.sema.scope(&target_node)?.module();
let target = name.syntax().text_range(); let target = name.syntax().text_range();
acc.add( acc.add(
@ -64,13 +69,10 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
replace_bool_expr(edit, initializer); replace_bool_expr(edit, initializer);
} }
let usages = definition let usages = definition.usages(&ctx.sema).all();
.usages(&ctx.sema)
.in_scope(&SearchScope::single_file(ctx.file_id()))
.all();
replace_usages(edit, &usages);
add_enum_def(edit, ctx, &usages, target_node); add_enum_def(edit, ctx, &usages, target_node, &target_module);
replace_usages(edit, ctx, &usages, &target_module);
}, },
) )
} }
@ -186,8 +188,45 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
} }
/// Replaces all usages of the target identifier, both when read and written to. /// Replaces all usages of the target identifier, both when read and written to.
fn replace_usages(edit: &mut SourceChangeBuilder, usages: &UsageSearchResult) { fn replace_usages(
for (_, references) in usages.iter() { edit: &mut SourceChangeBuilder,
ctx: &AssistContext<'_>,
usages: &UsageSearchResult,
target_module: &hir::Module,
) {
for (file_id, references) in usages.iter() {
edit.edit_file(*file_id);
// add imports across modules where needed
references
.iter()
.filter_map(|FileReference { name, .. }| {
ctx.sema.scope(name.syntax()).map(|scope| (name, scope.module()))
})
.unique_by(|name_and_module| name_and_module.1)
.filter(|(_, module)| module != target_module)
.filter_map(|(name, module)| {
let import_scope = ImportScope::find_insert_use_container(name.syntax(), &ctx.sema);
let mod_path = module.find_use_path_prefixed(
ctx.sema.db,
ModuleDef::Module(*target_module),
ctx.config.insert_use.prefix_kind,
ctx.config.prefer_no_std,
);
import_scope.zip(mod_path)
})
.for_each(|(import_scope, mod_path)| {
let import_scope = match import_scope {
ImportScope::File(it) => ImportScope::File(edit.make_mut(it)),
ImportScope::Module(it) => ImportScope::Module(edit.make_mut(it)),
ImportScope::Block(it) => ImportScope::Block(edit.make_mut(it)),
};
let path =
make::path_concat(mod_path_to_ast(&mod_path), make::path_from_text("Bool"));
insert_use(&import_scope, path, &ctx.config.insert_use);
});
// replace the usages in expressions
references references
.into_iter() .into_iter()
.filter_map(|FileReference { range, name, .. }| match name { .filter_map(|FileReference { range, name, .. }| match name {
@ -213,7 +252,7 @@ fn replace_usages(edit: &mut SourceChangeBuilder, usages: &UsageSearchResult) {
let record_field = edit.make_mut(record_field); let record_field = edit.make_mut(record_field);
let enum_expr = bool_expr_to_enum_expr(initializer); let enum_expr = bool_expr_to_enum_expr(initializer);
record_field.replace_expr(enum_expr); record_field.replace_expr(enum_expr);
} else if name_ref.syntax().ancestors().find_map(ast::Expr::cast).is_some() { } else if name_ref.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
// for any other usage in an expression, replace it with a check that it is the true variant // for any other usage in an expression, replace it with a check that it is the true variant
edit.replace(range, format!("{} == Bool::True", name_ref.text())); edit.replace(range, format!("{} == Bool::True", name_ref.text()));
} }
@ -255,8 +294,15 @@ fn add_enum_def(
ctx: &AssistContext<'_>, ctx: &AssistContext<'_>,
usages: &UsageSearchResult, usages: &UsageSearchResult,
target_node: SyntaxNode, target_node: SyntaxNode,
target_module: &hir::Module,
) { ) {
let make_enum_pub = usages.iter().any(|(file_id, _)| file_id != &ctx.file_id()); let make_enum_pub = usages
.iter()
.flat_map(|(_, refs)| refs)
.filter_map(|FileReference { name, .. }| {
ctx.sema.scope(name.syntax()).map(|scope| scope.module())
})
.any(|module| &module != target_module);
let enum_def = make_bool_enum(make_enum_pub); let enum_def = make_bool_enum(make_enum_pub);
let indent = IndentLevel::from_node(&target_node); let indent = IndentLevel::from_node(&target_node);
@ -649,7 +695,7 @@ fn main() {
"#, "#,
r#" r#"
#[derive(PartialEq, Eq)] #[derive(PartialEq, Eq)]
enum $0Bool { True, False } enum Bool { True, False }
struct Foo { struct Foo {
bar: Bool, bar: Bool,
@ -713,6 +759,162 @@ fn main() {
) )
} }
#[test]
fn const_in_module() {
check_assist(
bool_to_enum,
r#"
fn main() {
if foo::FOO {
println!("foo");
}
}
mod foo {
pub const $0FOO: bool = true;
}
"#,
r#"
use foo::Bool;
fn main() {
if foo::FOO == Bool::True {
println!("foo");
}
}
mod foo {
#[derive(PartialEq, Eq)]
pub enum Bool { True, False }
pub const FOO: Bool = Bool::True;
}
"#,
)
}
#[test]
fn const_in_module_with_import() {
check_assist(
bool_to_enum,
r#"
fn main() {
use foo::FOO;
if FOO {
println!("foo");
}
}
mod foo {
pub const $0FOO: bool = true;
}
"#,
r#"
use crate::foo::Bool;
fn main() {
use foo::FOO;
if FOO == Bool::True {
println!("foo");
}
}
mod foo {
#[derive(PartialEq, Eq)]
pub enum Bool { True, False }
pub const FOO: Bool = Bool::True;
}
"#,
)
}
#[test]
fn const_cross_file() {
check_assist(
bool_to_enum,
r#"
//- /main.rs
mod foo;
fn main() {
if foo::FOO {
println!("foo");
}
}
//- /foo.rs
pub const $0FOO: bool = true;
"#,
r#"
//- /main.rs
use foo::Bool;
mod foo;
fn main() {
if foo::FOO == Bool::True {
println!("foo");
}
}
//- /foo.rs
#[derive(PartialEq, Eq)]
pub enum Bool { True, False }
pub const FOO: Bool = Bool::True;
"#,
)
}
#[test]
fn const_cross_file_and_module() {
check_assist(
bool_to_enum,
r#"
//- /main.rs
mod foo;
fn main() {
use foo::bar;
if bar::BAR {
println!("foo");
}
}
//- /foo.rs
pub mod bar {
pub const $0BAR: bool = false;
}
"#,
r#"
//- /main.rs
use crate::foo::bar::Bool;
mod foo;
fn main() {
use foo::bar;
if bar::BAR == Bool::True {
println!("foo");
}
}
//- /foo.rs
pub mod bar {
#[derive(PartialEq, Eq)]
pub enum Bool { True, False }
pub const BAR: Bool = Bool::False;
}
"#,
)
}
#[test] #[test]
fn const_non_bool() { fn const_non_bool() {
cov_mark::check!(not_applicable_non_bool_const); cov_mark::check!(not_applicable_non_bool_const);