diff --git a/crates/ide-assists/src/handlers/bool_to_enum.rs b/crates/ide-assists/src/handlers/bool_to_enum.rs index f59b052813..784a0d3559 100644 --- a/crates/ide-assists/src/handlers/bool_to_enum.rs +++ b/crates/ide-assists/src/handlers/bool_to_enum.rs @@ -263,6 +263,11 @@ fn replace_usages( fn find_assignment_usage(name_ref: &ast::NameRef) -> Option { let bin_expr = name_ref.syntax().ancestors().find_map(ast::BinExpr::cast)?; + if !bin_expr.lhs()?.syntax().descendants().contains(name_ref.syntax()) { + cov_mark::hit!(dont_assign_incorrect_ref); + return None; + } + if let Some(ast::BinaryOp::Assignment { op: None }) = bin_expr.op_kind() { bin_expr.rhs() } else { @@ -273,6 +278,11 @@ fn find_assignment_usage(name_ref: &ast::NameRef) -> Option { fn find_negated_usage(name_ref: &ast::NameRef) -> Option<(ast::PrefixExpr, ast::Expr)> { let prefix_expr = name_ref.syntax().ancestors().find_map(ast::PrefixExpr::cast)?; + if !matches!(prefix_expr.expr()?, ast::Expr::PathExpr(_) | ast::Expr::FieldExpr(_)) { + cov_mark::hit!(dont_overwrite_expression_inside_negation); + return None; + } + if let Some(ast::UnaryOp::Not) = prefix_expr.op_kind() { let inner_expr = prefix_expr.expr()?; Some((prefix_expr, inner_expr)) @@ -285,7 +295,12 @@ fn find_record_expr_usage(name_ref: &ast::NameRef) -> Option<(ast::RecordExprFie let record_field = name_ref.syntax().ancestors().find_map(ast::RecordExprField::cast)?; let initializer = record_field.expr()?; - Some((record_field, initializer)) + if record_field.field_name()?.syntax().descendants().contains(name_ref.syntax()) { + Some((record_field, initializer)) + } else { + cov_mark::hit!(dont_overwrite_wrong_record_field); + None + } } /// Adds the definition of the new enum before the target node. @@ -561,6 +576,37 @@ fn main() { ) } + #[test] + fn local_variable_nested_in_negation() { + cov_mark::check!(dont_overwrite_expression_inside_negation); + check_assist( + bool_to_enum, + r#" +fn main() { + if !"foo".chars().any(|c| { + let $0foo = true; + foo + }) { + println!("foo"); + } +} +"#, + r#" +fn main() { + if !"foo".chars().any(|c| { + #[derive(PartialEq, Eq)] + enum Bool { True, False } + + let foo = Bool::True; + foo == Bool::True + }) { + println!("foo"); + } +} +"#, + ) + } + #[test] fn local_variable_non_bool() { cov_mark::check!(not_applicable_non_bool_local); @@ -638,6 +684,42 @@ fn main() { ) } + #[test] + fn field_negated() { + check_assist( + bool_to_enum, + r#" +struct Foo { + $0bar: bool, +} + +fn main() { + let foo = Foo { bar: false }; + + if !foo.bar { + println!("foo"); + } +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +struct Foo { + bar: Bool, +} + +fn main() { + let foo = Foo { bar: Bool::False }; + + if foo.bar == Bool::False { + println!("foo"); + } +} +"#, + ) + } + #[test] fn field_in_mod_properly_indented() { check_assist( @@ -714,6 +796,88 @@ fn main() { ) } + #[test] + fn field_assigned_to_another() { + cov_mark::check!(dont_assign_incorrect_ref); + check_assist( + bool_to_enum, + r#" +struct Foo { + $0foo: bool, +} + +struct Bar { + bar: bool, +} + +fn main() { + let foo = Foo { foo: true }; + let mut bar = Bar { bar: true }; + + bar.bar = foo.foo; +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +struct Foo { + foo: Bool, +} + +struct Bar { + bar: bool, +} + +fn main() { + let foo = Foo { foo: Bool::True }; + let mut bar = Bar { bar: true }; + + bar.bar = foo.foo == Bool::True; +} +"#, + ) + } + + #[test] + fn field_initialized_with_other() { + cov_mark::check!(dont_overwrite_wrong_record_field); + check_assist( + bool_to_enum, + r#" +struct Foo { + $0foo: bool, +} + +struct Bar { + bar: bool, +} + +fn main() { + let foo = Foo { foo: true }; + let bar = Bar { bar: foo.foo }; +} +"#, + r#" +#[derive(PartialEq, Eq)] +enum Bool { True, False } + +struct Foo { + foo: Bool, +} + +struct Bar { + bar: bool, +} + +fn main() { + let foo = Foo { foo: Bool::True }; + let bar = Bar { bar: foo.foo == Bool::True }; +} +"#, + ) + } + #[test] fn field_non_bool() { cov_mark::check!(not_applicable_non_bool_field);