From 2e13aed3bc235d47d92f9ce3b8fd4fa3c5f87939 Mon Sep 17 00:00:00 2001 From: Ryan Mehri Date: Sat, 9 Sep 2023 11:40:29 -0700 Subject: [PATCH] feat: support cross module imports --- .../ide-assists/src/handlers/bool_to_enum.rs | 226 +++++++++++++++++- 1 file changed, 214 insertions(+), 12 deletions(-) diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index 9752264844..f59b052813 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -1,9 +1,13 @@ +use hir::ModuleDef; use ide_db::{ assists::{AssistId, AssistKind}, defs::Definition, - search::{FileReference, SearchScope, UsageSearchResult}, + helpers::mod_path_to_ast, + imports::insert_use::{insert_use, ImportScope}, + search::{FileReference, UsageSearchResult}, source_change::SourceChangeBuilder, }; +use itertools::Itertools; use syntax::{ ast::{ self, @@ -48,6 +52,7 @@ use crate::assist_context::{AssistContext, Assists}; pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let BoolNodeData { target_node, name, ty_annotation, initializer, definition } = find_bool_node(ctx)?; + let target_module = ctx.sema.scope(&target_node)?.module(); let target = name.syntax().text_range(); acc.add( @@ -64,13 +69,10 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option replace_bool_expr(edit, initializer); } - let usages = definition - .usages(&ctx.sema) - .in_scope(&SearchScope::single_file(ctx.file_id())) - .all(); - replace_usages(edit, &usages); + let usages = definition.usages(&ctx.sema).all(); - add_enum_def(edit, ctx, &usages, target_node); + add_enum_def(edit, ctx, &usages, target_node, &target_module); + replace_usages(edit, ctx, &usages, &target_module); }, ) } @@ -186,8 +188,45 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr { } /// Replaces all usages of the target identifier, both when read and written to. -fn replace_usages(edit: &mut SourceChangeBuilder, usages: &UsageSearchResult) { - for (_, references) in usages.iter() { +fn replace_usages( + edit: &mut SourceChangeBuilder, + ctx: &AssistContext<'_>, + usages: &UsageSearchResult, + target_module: &hir::Module, +) { + for (file_id, references) in usages.iter() { + edit.edit_file(*file_id); + + // add imports across modules where needed + references + .iter() + .filter_map(|FileReference { name, .. }| { + ctx.sema.scope(name.syntax()).map(|scope| (name, scope.module())) + }) + .unique_by(|name_and_module| name_and_module.1) + .filter(|(_, module)| module != target_module) + .filter_map(|(name, module)| { + let import_scope = ImportScope::find_insert_use_container(name.syntax(), &ctx.sema); + let mod_path = module.find_use_path_prefixed( + ctx.sema.db, + ModuleDef::Module(*target_module), + ctx.config.insert_use.prefix_kind, + ctx.config.prefer_no_std, + ); + import_scope.zip(mod_path) + }) + .for_each(|(import_scope, mod_path)| { + let import_scope = match import_scope { + ImportScope::File(it) => ImportScope::File(edit.make_mut(it)), + ImportScope::Module(it) => ImportScope::Module(edit.make_mut(it)), + ImportScope::Block(it) => ImportScope::Block(edit.make_mut(it)), + }; + let path = + make::path_concat(mod_path_to_ast(&mod_path), make::path_from_text("Bool")); + insert_use(&import_scope, path, &ctx.config.insert_use); + }); + + // replace the usages in expressions references .into_iter() .filter_map(|FileReference { range, name, .. }| match name { @@ -213,7 +252,7 @@ fn replace_usages(edit: &mut SourceChangeBuilder, usages: &UsageSearchResult) { let record_field = edit.make_mut(record_field); let enum_expr = bool_expr_to_enum_expr(initializer); record_field.replace_expr(enum_expr); - } else if name_ref.syntax().ancestors().find_map(ast::Expr::cast).is_some() { + } else if name_ref.syntax().ancestors().find_map(ast::UseTree::cast).is_none() { // for any other usage in an expression, replace it with a check that it is the true variant edit.replace(range, format!("{} == Bool::True", name_ref.text())); } @@ -255,8 +294,15 @@ fn add_enum_def( ctx: &AssistContext<'_>, usages: &UsageSearchResult, target_node: SyntaxNode, + target_module: &hir::Module, ) { - let make_enum_pub = usages.iter().any(|(file_id, _)| file_id != &ctx.file_id()); + let make_enum_pub = usages + .iter() + .flat_map(|(_, refs)| refs) + .filter_map(|FileReference { name, .. }| { + ctx.sema.scope(name.syntax()).map(|scope| scope.module()) + }) + .any(|module| &module != target_module); let enum_def = make_bool_enum(make_enum_pub); let indent = IndentLevel::from_node(&target_node); @@ -649,7 +695,7 @@ fn main() { "#, r#" #[derive(PartialEq, Eq)] -enum $0Bool { True, False } +enum Bool { True, False } struct Foo { bar: Bool, @@ -713,6 +759,162 @@ fn main() { ) } + #[test] + fn const_in_module() { + check_assist( + bool_to_enum, + r#" +fn main() { + if foo::FOO { + println!("foo"); + } +} + +mod foo { + pub const $0FOO: bool = true; +} +"#, + r#" +use foo::Bool; + +fn main() { + if foo::FOO == Bool::True { + println!("foo"); + } +} + +mod foo { + #[derive(PartialEq, Eq)] + pub enum Bool { True, False } + + pub const FOO: Bool = Bool::True; +} +"#, + ) + } + + #[test] + fn const_in_module_with_import() { + check_assist( + bool_to_enum, + r#" +fn main() { + use foo::FOO; + + if FOO { + println!("foo"); + } +} + +mod foo { + pub const $0FOO: bool = true; +} +"#, + r#" +use crate::foo::Bool; + +fn main() { + use foo::FOO; + + if FOO == Bool::True { + println!("foo"); + } +} + +mod foo { + #[derive(PartialEq, Eq)] + pub enum Bool { True, False } + + pub const FOO: Bool = Bool::True; +} +"#, + ) + } + + #[test] + fn const_cross_file() { + check_assist( + bool_to_enum, + r#" +//- /main.rs +mod foo; + +fn main() { + if foo::FOO { + println!("foo"); + } +} + +//- /foo.rs +pub const $0FOO: bool = true; +"#, + r#" +//- /main.rs +use foo::Bool; + +mod foo; + +fn main() { + if foo::FOO == Bool::True { + println!("foo"); + } +} + +//- /foo.rs +#[derive(PartialEq, Eq)] +pub enum Bool { True, False } + +pub const FOO: Bool = Bool::True; +"#, + ) + } + + #[test] + fn const_cross_file_and_module() { + check_assist( + bool_to_enum, + r#" +//- /main.rs +mod foo; + +fn main() { + use foo::bar; + + if bar::BAR { + println!("foo"); + } +} + +//- /foo.rs +pub mod bar { + pub const $0BAR: bool = false; +} +"#, + r#" +//- /main.rs +use crate::foo::bar::Bool; + +mod foo; + +fn main() { + use foo::bar; + + if bar::BAR == Bool::True { + println!("foo"); + } +} + +//- /foo.rs +pub mod bar { + #[derive(PartialEq, Eq)] + pub enum Bool { True, False } + + pub const BAR: Bool = Bool::False; +} +"#, + ) + } + #[test] fn const_non_bool() { cov_mark::check!(not_applicable_non_bool_const);