fix: make bool_to_enum assist create enum at top-level

This commit is contained in:
Ryan Mehri 2023-09-25 21:01:54 -07:00
parent d3cc3bc00e
commit bce4be9478
2 changed files with 163 additions and 36 deletions

View file

@ -16,7 +16,7 @@ use syntax::{
edit_in_place::{AttrsOwnerEdit, Indent}, edit_in_place::{AttrsOwnerEdit, Indent},
make, HasName, make, HasName,
}, },
ted, AstNode, NodeOrToken, SyntaxNode, T, match_ast, ted, AstNode, NodeOrToken, SyntaxNode, T,
}; };
use text_edit::TextRange; use text_edit::TextRange;
@ -40,10 +40,10 @@ use crate::assist_context::{AssistContext, Assists};
// ``` // ```
// -> // ->
// ``` // ```
// fn main() { // #[derive(PartialEq, Eq)]
// #[derive(PartialEq, Eq)] // enum Bool { True, False }
// enum Bool { True, False }
// //
// fn main() {
// let bool = Bool::True; // let bool = Bool::True;
// //
// if bool == Bool::True { // if bool == Bool::True {
@ -270,6 +270,10 @@ fn replace_usages(
} }
_ => (), _ => (),
} }
} else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&new_name)
{
edit.replace(ty_annotation.syntax().text_range(), "Bool");
replace_bool_expr(edit, initializer);
} else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() { } else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
// for any other usage in an expression, replace it with a check that it is the true variant // for any other usage in an expression, replace it with a check that it is the true variant
if let Some((record_field, expr)) = new_name if let Some((record_field, expr)) = new_name
@ -413,6 +417,15 @@ fn find_record_pat_field_usage(name: &ast::NameLike) -> Option<ast::Pat> {
} }
} }
fn find_assoc_const_usage(name: &ast::NameLike) -> Option<(ast::Type, ast::Expr)> {
let const_ = name.syntax().parent().and_then(ast::Const::cast)?;
if const_.syntax().parent().and_then(ast::AssocItemList::cast).is_none() {
return None;
}
Some((const_.ty()?, const_.body()?))
}
/// Adds the definition of the new enum before the target node. /// Adds the definition of the new enum before the target node.
fn add_enum_def( fn add_enum_def(
edit: &mut SourceChangeBuilder, edit: &mut SourceChangeBuilder,
@ -430,11 +443,12 @@ fn add_enum_def(
.any(|module| module.nearest_non_block_module(ctx.db()) != *target_module); .any(|module| module.nearest_non_block_module(ctx.db()) != *target_module);
let enum_def = make_bool_enum(make_enum_pub); let enum_def = make_bool_enum(make_enum_pub);
let indent = IndentLevel::from_node(&target_node); let insert_before = node_to_insert_before(target_node);
let indent = IndentLevel::from_node(&insert_before);
enum_def.reindent_to(indent); enum_def.reindent_to(indent);
ted::insert_all( ted::insert_all(
ted::Position::before(&edit.make_syntax_mut(target_node)), ted::Position::before(&edit.make_syntax_mut(insert_before)),
vec![ vec![
enum_def.syntax().clone().into(), enum_def.syntax().clone().into(),
make::tokens::whitespace(&format!("\n\n{indent}")).into(), make::tokens::whitespace(&format!("\n\n{indent}")).into(),
@ -442,6 +456,35 @@ fn add_enum_def(
); );
} }
/// Finds where to put the new enum definition, at the nearest module or at top-level.
fn node_to_insert_before(mut target_node: SyntaxNode) -> SyntaxNode {
let mut ancestors = target_node.ancestors();
while let Some(ancestor) = ancestors.next() {
match_ast! {
match ancestor {
ast::Item(item) => {
if item
.syntax()
.parent()
.and_then(|item_list| item_list.parent())
.and_then(ast::Module::cast)
.is_some()
{
return ancestor;
}
},
ast::SourceFile(_) => break,
_ => (),
}
}
target_node = ancestor;
}
target_node
}
fn make_bool_enum(make_pub: bool) -> ast::Enum { fn make_bool_enum(make_pub: bool) -> ast::Enum {
let enum_def = make::enum_( let enum_def = make::enum_(
if make_pub { Some(make::visibility_pub()) } else { None }, if make_pub { Some(make::visibility_pub()) } else { None },
@ -491,10 +534,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo = Bool::True; let foo = Bool::True;
if foo == Bool::True { if foo == Bool::True {
@ -520,10 +563,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo = Bool::True; let foo = Bool::True;
if foo == Bool::False { if foo == Bool::False {
@ -545,10 +588,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo: Bool = Bool::False; let foo: Bool = Bool::False;
} }
"#, "#,
@ -565,10 +608,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo = if 1 == 2 { Bool::True } else { Bool::False }; let foo = if 1 == 2 { Bool::True } else { Bool::False };
} }
"#, "#,
@ -590,10 +633,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo = Bool::False; let foo = Bool::False;
let bar = true; let bar = true;
@ -619,10 +662,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo = Bool::True; let foo = Bool::True;
if *&foo == Bool::True { if *&foo == Bool::True {
@ -645,10 +688,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo: Bool; let foo: Bool;
foo = Bool::True; foo = Bool::True;
} }
@ -671,10 +714,10 @@ fn main() {
} }
"#, "#,
r#" r#"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let foo = Bool::True; let foo = Bool::True;
let bar = foo == Bool::False; let bar = foo == Bool::False;
@ -702,11 +745,11 @@ fn main() {
} }
"#, "#,
r#" r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }
fn main() { fn main() {
if !"foo".chars().any(|c| { if !"foo".chars().any(|c| {
#[derive(PartialEq, Eq)]
enum Bool { True, False }
let foo = Bool::True; let foo = Bool::True;
foo == Bool::True foo == Bool::True
}) { }) {
@ -1445,6 +1488,90 @@ pub mod bar {
) )
} }
#[test]
fn const_in_impl_cross_file() {
check_assist(
bool_to_enum,
r#"
//- /main.rs
mod foo;
struct Foo;
impl Foo {
pub const $0BOOL: bool = true;
}
//- /foo.rs
use crate::Foo;
fn foo() -> bool {
Foo::BOOL
}
"#,
r#"
//- /main.rs
mod foo;
struct Foo;
#[derive(PartialEq, Eq)]
pub enum Bool { True, False }
impl Foo {
pub const BOOL: Bool = Bool::True;
}
//- /foo.rs
use crate::{Foo, Bool};
fn foo() -> bool {
Foo::BOOL == Bool::True
}
"#,
)
}
#[test]
fn const_in_trait() {
check_assist(
bool_to_enum,
r#"
trait Foo {
const $0BOOL: bool;
}
impl Foo for usize {
const BOOL: bool = true;
}
fn main() {
if <usize as Foo>::BOOL {
println!("foo");
}
}
"#,
r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }
trait Foo {
const BOOL: Bool;
}
impl Foo for usize {
const BOOL: Bool = Bool::True;
}
fn main() {
if <usize as Foo>::BOOL == Bool::True {
println!("foo");
}
}
"#,
)
}
#[test] #[test]
fn const_non_bool() { fn const_non_bool() {
cov_mark::check!(not_applicable_non_bool_const); cov_mark::check!(not_applicable_non_bool_const);

View file

@ -294,10 +294,10 @@ fn main() {
} }
"#####, "#####,
r#####" r#####"
fn main() { #[derive(PartialEq, Eq)]
#[derive(PartialEq, Eq)] enum Bool { True, False }
enum Bool { True, False }
fn main() {
let bool = Bool::True; let bool = Bool::True;
if bool == Bool::True { if bool == Bool::True {