style: clean up bool_to_enum assist

This commit is contained in:
Ryan Mehri 2023-10-27 07:59:15 -07:00 committed by Lukas Wirth
parent 2034556f81
commit b5e0edf427

View file

@ -16,7 +16,7 @@ use syntax::{
edit_in_place::{AttrsOwnerEdit, Indent}, edit_in_place::{AttrsOwnerEdit, Indent},
make, HasName, make, HasName,
}, },
ted, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, T,
}; };
use text_edit::TextRange; use text_edit::TextRange;
@ -73,7 +73,7 @@ pub(crate) fn bool_to_enum(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
let usages = definition.usages(&ctx.sema).all(); let usages = definition.usages(&ctx.sema).all();
add_enum_def(edit, ctx, &usages, target_node, &target_module); add_enum_def(edit, ctx, &usages, target_node, &target_module);
replace_usages(edit, ctx, &usages, definition, &target_module); replace_usages(edit, ctx, usages, definition, &target_module);
}, },
) )
} }
@ -192,21 +192,19 @@ fn bool_expr_to_enum_expr(expr: ast::Expr) -> ast::Expr {
fn replace_usages( fn replace_usages(
edit: &mut SourceChangeBuilder, edit: &mut SourceChangeBuilder,
ctx: &AssistContext<'_>, ctx: &AssistContext<'_>,
usages: &UsageSearchResult, usages: UsageSearchResult,
target_definition: Definition, target_definition: Definition,
target_module: &hir::Module, target_module: &hir::Module,
) { ) {
for (file_id, references) in usages.iter() { for (file_id, references) in usages {
edit.edit_file(*file_id); edit.edit_file(file_id);
let refs_with_imports = let refs_with_imports = augment_references_with_imports(ctx, references, target_module);
augment_references_with_imports(edit, ctx, references, target_module);
refs_with_imports.into_iter().rev().for_each( refs_with_imports.into_iter().rev().for_each(
|FileReferenceWithImport { range, old_name, new_name, import_data }| { |FileReferenceWithImport { range, name, import_data }| {
// replace the usages in patterns and expressions // replace the usages in patterns and expressions
if let Some(ident_pat) = old_name.syntax().ancestors().find_map(ast::IdentPat::cast) if let Some(ident_pat) = name.syntax().ancestors().find_map(ast::IdentPat::cast) {
{
cov_mark::hit!(replaces_record_pat_shorthand); cov_mark::hit!(replaces_record_pat_shorthand);
let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local); let definition = ctx.sema.to_def(&ident_pat).map(Definition::Local);
@ -214,36 +212,35 @@ fn replace_usages(
replace_usages( replace_usages(
edit, edit,
ctx, ctx,
&def.usages(&ctx.sema).all(), def.usages(&ctx.sema).all(),
target_definition, target_definition,
target_module, target_module,
) )
} }
} else if let Some(initializer) = find_assignment_usage(&new_name) { } else if let Some(initializer) = find_assignment_usage(&name) {
cov_mark::hit!(replaces_assignment); cov_mark::hit!(replaces_assignment);
replace_bool_expr(edit, initializer); replace_bool_expr(edit, initializer);
} else if let Some((prefix_expr, inner_expr)) = find_negated_usage(&new_name) { } else if let Some((prefix_expr, inner_expr)) = find_negated_usage(&name) {
cov_mark::hit!(replaces_negation); cov_mark::hit!(replaces_negation);
edit.replace( edit.replace(
prefix_expr.syntax().text_range(), prefix_expr.syntax().text_range(),
format!("{} == Bool::False", inner_expr), format!("{} == Bool::False", inner_expr),
); );
} else if let Some((record_field, initializer)) = old_name } else if let Some((record_field, initializer)) = name
.as_name_ref() .as_name_ref()
.and_then(ast::RecordExprField::for_field_name) .and_then(ast::RecordExprField::for_field_name)
.and_then(|record_field| ctx.sema.resolve_record_field(&record_field)) .and_then(|record_field| ctx.sema.resolve_record_field(&record_field))
.and_then(|(got_field, _, _)| { .and_then(|(got_field, _, _)| {
find_record_expr_usage(&new_name, got_field, target_definition) find_record_expr_usage(&name, got_field, target_definition)
}) })
{ {
cov_mark::hit!(replaces_record_expr); cov_mark::hit!(replaces_record_expr);
let record_field = edit.make_mut(record_field);
let enum_expr = bool_expr_to_enum_expr(initializer); let enum_expr = bool_expr_to_enum_expr(initializer);
record_field.replace_expr(enum_expr); replace_record_field_expr(edit, record_field, enum_expr);
} else if let Some(pat) = find_record_pat_field_usage(&old_name) { } else if let Some(pat) = find_record_pat_field_usage(&name) {
match pat { match pat {
ast::Pat::IdentPat(ident_pat) => { ast::Pat::IdentPat(ident_pat) => {
cov_mark::hit!(replaces_record_pat); cov_mark::hit!(replaces_record_pat);
@ -253,7 +250,7 @@ fn replace_usages(
replace_usages( replace_usages(
edit, edit,
ctx, ctx,
&def.usages(&ctx.sema).all(), def.usages(&ctx.sema).all(),
target_definition, target_definition,
target_module, target_module,
) )
@ -270,79 +267,94 @@ fn replace_usages(
} }
_ => (), _ => (),
} }
} else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&new_name) } else if let Some((ty_annotation, initializer)) = find_assoc_const_usage(&name) {
{
edit.replace(ty_annotation.syntax().text_range(), "Bool"); edit.replace(ty_annotation.syntax().text_range(), "Bool");
replace_bool_expr(edit, initializer); replace_bool_expr(edit, initializer);
} else if let Some(receiver) = find_method_call_expr_usage(&new_name) { } else if let Some(receiver) = find_method_call_expr_usage(&name) {
edit.replace( edit.replace(
receiver.syntax().text_range(), receiver.syntax().text_range(),
format!("({} == Bool::True)", receiver), format!("({} == Bool::True)", receiver),
); );
} else if new_name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() { } else if name.syntax().ancestors().find_map(ast::UseTree::cast).is_none() {
// for any other usage in an expression, replace it with a check that it is the true variant // for any other usage in an expression, replace it with a check that it is the true variant
if let Some((record_field, expr)) = new_name if let Some((record_field, expr)) =
.as_name_ref() name.as_name_ref().and_then(ast::RecordExprField::for_field_name).and_then(
.and_then(ast::RecordExprField::for_field_name) |record_field| record_field.expr().map(|expr| (record_field, expr)),
.and_then(|record_field| { )
record_field.expr().map(|expr| (record_field, expr))
})
{ {
record_field.replace_expr( replace_record_field_expr(
edit,
record_field,
make::expr_bin_op( make::expr_bin_op(
expr, expr,
ast::BinaryOp::CmpOp(ast::CmpOp::Eq { negated: false }), ast::BinaryOp::CmpOp(ast::CmpOp::Eq { negated: false }),
make::expr_path(make::path_from_text("Bool::True")), make::expr_path(make::path_from_text("Bool::True")),
) ),
.clone_for_update(),
); );
} else { } else {
edit.replace(range, format!("{} == Bool::True", new_name.text())); edit.replace(range, format!("{} == Bool::True", name.text()));
} }
} }
// add imports across modules where needed // add imports across modules where needed
if let Some((import_scope, path)) = import_data { if let Some((import_scope, path)) = import_data {
insert_use(&import_scope, path, &ctx.config.insert_use); let scope = match import_scope.clone() {
ImportScope::File(it) => ImportScope::File(edit.make_mut(it)),
ImportScope::Module(it) => ImportScope::Module(edit.make_mut(it)),
ImportScope::Block(it) => ImportScope::Block(edit.make_mut(it)),
};
insert_use(&scope, path, &ctx.config.insert_use);
} }
}, },
) )
} }
} }
/// Replaces the record expression, handling field shorthands.
fn replace_record_field_expr(
edit: &mut SourceChangeBuilder,
record_field: ast::RecordExprField,
initializer: ast::Expr,
) {
if let Some(ast::Expr::PathExpr(path_expr)) = record_field.expr() {
// replace field shorthand
edit.insert(
path_expr.syntax().text_range().end(),
format!(": {}", initializer.syntax().text()),
)
} else if let Some(expr) = record_field.expr() {
// just replace expr
edit.replace_ast(expr, initializer);
}
}
struct FileReferenceWithImport { struct FileReferenceWithImport {
range: TextRange, range: TextRange,
old_name: ast::NameLike, name: ast::NameLike,
new_name: ast::NameLike,
import_data: Option<(ImportScope, ast::Path)>, import_data: Option<(ImportScope, ast::Path)>,
} }
fn augment_references_with_imports( fn augment_references_with_imports(
edit: &mut SourceChangeBuilder,
ctx: &AssistContext<'_>, ctx: &AssistContext<'_>,
references: &[FileReference], references: Vec<FileReference>,
target_module: &hir::Module, target_module: &hir::Module,
) -> Vec<FileReferenceWithImport> { ) -> Vec<FileReferenceWithImport> {
let mut visited_modules = FxHashSet::default(); let mut visited_modules = FxHashSet::default();
references references
.iter() .into_iter()
.filter_map(|FileReference { range, name, .. }| { .filter_map(|FileReference { range, name, .. }| {
let name = name.clone().into_name_like()?; let name = name.clone().into_name_like()?;
ctx.sema.scope(name.syntax()).map(|scope| (*range, name, scope.module())) ctx.sema.scope(name.syntax()).map(|scope| (range, name, scope.module()))
}) })
.map(|(range, name, ref_module)| { .map(|(range, name, ref_module)| {
let old_name = name.clone();
let new_name = edit.make_mut(name.clone());
// if the referenced module is not the same as the target one and has not been seen before, add an import // if the referenced module is not the same as the target one and has not been seen before, add an import
let import_data = if ref_module.nearest_non_block_module(ctx.db()) != *target_module let import_data = if ref_module.nearest_non_block_module(ctx.db()) != *target_module
&& !visited_modules.contains(&ref_module) && !visited_modules.contains(&ref_module)
{ {
visited_modules.insert(ref_module); visited_modules.insert(ref_module);
let import_scope = let import_scope = ImportScope::find_insert_use_container(name.syntax(), &ctx.sema);
ImportScope::find_insert_use_container(new_name.syntax(), &ctx.sema);
let path = ref_module let path = ref_module
.find_use_path_prefixed( .find_use_path_prefixed(
ctx.sema.db, ctx.sema.db,
@ -360,7 +372,7 @@ fn augment_references_with_imports(
None None
}; };
FileReferenceWithImport { range, old_name, new_name, import_data } FileReferenceWithImport { range, name, import_data }
}) })
.collect() .collect()
} }
@ -465,12 +477,9 @@ fn add_enum_def(
let indent = IndentLevel::from_node(&insert_before); let indent = IndentLevel::from_node(&insert_before);
enum_def.reindent_to(indent); enum_def.reindent_to(indent);
ted::insert_all( edit.insert(
ted::Position::before(&edit.make_syntax_mut(insert_before)), insert_before.text_range().start(),
vec![ format!("{}\n\n{indent}", enum_def.syntax().text()),
enum_def.syntax().clone().into(),
make::tokens::whitespace(&format!("\n\n{indent}")).into(),
],
); );
} }