diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index 85b0b87d0c..3303a2dd3c 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -16,7 +16,7 @@ use syntax::{ edit_in_place::{AttrsOwnerEdit, Indent}, make, HasName, }, - ted, AstNode, NodeOrToken, SyntaxNode, T, + match_ast, ted, AstNode, NodeOrToken, SyntaxNode, T, }; use text_edit::TextRange; @@ -40,10 +40,10 @@ use crate::assist_context::{AssistContext, Assists}; // ``` // -> // ``` -// fn main() { -// #[derive(PartialEq, Eq)] -// enum Bool { True, False } +// #[derive(PartialEq, Eq)] +// enum Bool { True, False } // +// fn main() { // let bool = Bool::True; // // if bool == Bool::True { @@ -270,6 +270,10 @@ fn replace_usages( } _ => (), } + } else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&new_name) + { + edit.replace(ty_annotation.syntax().text_range(), "Bool"); + replace_bool_expr(edit, initializer); } else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() { // for any other usage in an expression, replace it with a check that it is the true variant if let Some((record_field, expr)) = new_name @@ -413,6 +417,15 @@ fn find_record_pat_field_usage(name: &ast::NameLike) -> Option { } } +fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)> { + let const_ = name.syntax().parent().and_then(ast::Const::cast)?; + if const_.syntax().parent().and_then(ast::AssocItemList::cast).is_none() { + return None; + } + + Some((const_.ty()?, const_.body()?)) +} + /// Adds the definition of the new enum before the target node. fn add_enum_def( edit: &mut SourceChangeBuilder, @@ -430,11 +443,12 @@ fn add_enum_def( .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module); let enum_def = make_bool_enum(make_enum_pub); - let indent = IndentLevel::from_node(&target_node); + let insert_before = node_to_insert_before(target_node); + let indent = IndentLevel::from_node(&insert_before); enum_def.reindent_to(indent); ted::insert_all( - ted::Position::before(&edit.make_syntax_mut(target_node)), + ted::Position::before(&edit.make_syntax_mut(insert_before)), vec![ enum_def.syntax().clone().into(), make::tokens::whitespace(&format!("\n\n{indent}")).into(), @@ -442,6 +456,35 @@ fn add_enum_def( ); } +/// Finds where to put the new enum definition, at the nearest module or at top-level. +fn node_to_insert_before(mut target_node: SyntaxNode) -> SyntaxNode { + let mut ancestors = target_node.ancestors(); + + while let Some(ancestor) = ancestors.next() { + match_ast! { + match ancestor { + ast::Item(item) => { + if item + .syntax() + .parent() + .and_then(|item_list| item_list.parent()) + .and_then(ast::Module::cast) + .is_some() + { + return ancestor; + } + }, + ast::SourceFile(_) => break, + _ => (), + } + } + + target_node = ancestor; + } + + target_node +} + fn make_bool_enum(make_pub: bool) -> ast::Enum { let enum_def = make::enum_( if make_pub { Some(make::visibility_pub()) } else { None }, @@ -491,10 +534,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = Bool::True; if foo == Bool::True { @@ -520,10 +563,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = Bool::True; if foo == Bool::False { @@ -545,10 +588,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo: Bool = Bool::False; } "#, @@ -565,10 +608,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = if 1 == 2 { Bool::True } else { Bool::False }; } "#, @@ -590,10 +633,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = Bool::False; let bar = true; @@ -619,10 +662,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = Bool::True; if *&foo == Bool::True { @@ -645,10 +688,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo: Bool; foo = Bool::True; } @@ -671,10 +714,10 @@ fn main() { } "#, r#" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let foo = Bool::True; let bar = foo == Bool::False; @@ -702,11 +745,11 @@ fn main() { } "#, r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + fn main() { if !"foo".chars().any(|c| { - #[derive(PartialEq, Eq)] - enum Bool { True, False } - let foo = Bool::True; foo == Bool::True }) { @@ -1445,6 +1488,90 @@ pub mod bar { ) } + #[test] + fn const_in_impl_cross_file() { + check_assist( + bool_to_enum, + r#" +//- /main.rs +mod foo; + +struct Foo; + +impl Foo { + pub const $0BOOL: bool = true; +} + +//- /foo.rs +use crate::Foo; + +fn foo() -> bool { + Foo::BOOL +} +"#, + r#" +//- /main.rs +mod foo; + +struct Foo; + +#[derive(PartialEq, Eq)] +pub enum Bool { True, False } + +impl Foo { + pub const BOOL: Bool = Bool::True; +} + +//- /foo.rs +use crate::{Foo, Bool}; + +fn foo() -> bool { + Foo::BOOL == Bool::True +} +"#, + ) + } + + #[test] + fn const_in_trait() { + check_assist( + bool_to_enum, + r#" +trait Foo { + const $0BOOL: bool; +} + +impl Foo for usize { + const BOOL: bool = true; +} + +fn main() { + if ::BOOL { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +trait Foo { + const BOOL: Bool; +} + +impl Foo for usize { + const BOOL: Bool = Bool::True; +} + +fn main() { + if ::BOOL == Bool::True { + println!("foo"); + } +} +"#, + ) + } + #[test] fn const_non_bool() { cov_mark::check!(not_applicable_non_bool_const); diff --git a/crates/ide-assists/src/tests/generated.rs b/crates/ide-assists/src/tests/generated.rs index 63a08a0e56..5a815d5c6a 100644 --- a/crates/ide-assists/src/tests/generated.rs +++ b/crates/ide-assists/src/tests/generated.rs @@ -294,10 +294,10 @@ fn main() { } "#####, r#####" -fn main() { - #[derive(PartialEq, Eq)] - enum Bool { True, False } +#[derive(PartialEq, Eq)] +enum Bool { True, False } +fn main() { let bool = Bool::True; if bool == Bool::True {