fix: add checks for overwriting incorrect ancestor

This commit is contained in:
Ryan Mehri 2023-09-09 23:54:25 -07:00
parent 2e13aed3bc
commit 7ba2e130b9

View file

@ -263,6 +263,11 @@ fn replace_usages(
fn find_assignment_usage(name_ref: &ast::NameRef) -> Option<ast::Expr> { fn find_assignment_usage(name_ref: &ast::NameRef) -> Option<ast::Expr> {
let bin_expr = name_ref.syntax().ancestors().find_map(ast::BinExpr::cast)?; let bin_expr = name_ref.syntax().ancestors().find_map(ast::BinExpr::cast)?;
if !bin_expr.lhs()?.syntax().descendants().contains(name_ref.syntax()) {
cov_mark::hit!(dont_assign_incorrect_ref);
return None;
}
if let Some(ast::BinaryOp::Assignment { op: None }) = bin_expr.op_kind() { if let Some(ast::BinaryOp::Assignment { op: None }) = bin_expr.op_kind() {
bin_expr.rhs() bin_expr.rhs()
} else { } else {
@ -273,6 +278,11 @@ fn find_assignment_usage(name_ref: &ast::NameRef) -> Option<ast::Expr> {
fn find_negated_usage(name_ref: &ast::NameRef) -> Option<(ast::PrefixExpr, ast::Expr)> { fn find_negated_usage(name_ref: &ast::NameRef) -> Option<(ast::PrefixExpr, ast::Expr)> {
let prefix_expr = name_ref.syntax().ancestors().find_map(ast::PrefixExpr::cast)?; let prefix_expr = name_ref.syntax().ancestors().find_map(ast::PrefixExpr::cast)?;
if !matches!(prefix_expr.expr()?, ast::Expr::PathExpr(_) | ast::Expr::FieldExpr(_)) {
cov_mark::hit!(dont_overwrite_expression_inside_negation);
return None;
}
if let Some(ast::UnaryOp::Not) = prefix_expr.op_kind() { if let Some(ast::UnaryOp::Not) = prefix_expr.op_kind() {
let inner_expr = prefix_expr.expr()?; let inner_expr = prefix_expr.expr()?;
Some((prefix_expr, inner_expr)) Some((prefix_expr, inner_expr))
@ -285,7 +295,12 @@ fn find_record_expr_usage(name_ref: &ast::NameRef) -> Option<(ast::RecordExprFie
let record_field = name_ref.syntax().ancestors().find_map(ast::RecordExprField::cast)?; let record_field = name_ref.syntax().ancestors().find_map(ast::RecordExprField::cast)?;
let initializer = record_field.expr()?; let initializer = record_field.expr()?;
if record_field.field_name()?.syntax().descendants().contains(name_ref.syntax()) {
Some((record_field, initializer)) Some((record_field, initializer))
} else {
cov_mark::hit!(dont_overwrite_wrong_record_field);
None
}
} }
/// Adds the definition of the new enum before the target node. /// Adds the definition of the new enum before the target node.
@ -561,6 +576,37 @@ fn main() {
) )
} }
#[test]
fn local_variable_nested_in_negation() {
cov_mark::check!(dont_overwrite_expression_inside_negation);
check_assist(
bool_to_enum,
r#"
fn main() {
if !"foo".chars().any(|c| {
let $0foo = true;
foo
}) {
println!("foo");
}
}
"#,
r#"
fn main() {
if !"foo".chars().any(|c| {
#[derive(PartialEq, Eq)]
enum Bool { True, False }
let foo = Bool::True;
foo == Bool::True
}) {
println!("foo");
}
}
"#,
)
}
#[test] #[test]
fn local_variable_non_bool() { fn local_variable_non_bool() {
cov_mark::check!(not_applicable_non_bool_local); cov_mark::check!(not_applicable_non_bool_local);
@ -638,6 +684,42 @@ fn main() {
) )
} }
#[test]
fn field_negated() {
check_assist(
bool_to_enum,
r#"
struct Foo {
$0bar: bool,
}
fn main() {
let foo = Foo { bar: false };
if !foo.bar {
println!("foo");
}
}
"#,
r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }
struct Foo {
bar: Bool,
}
fn main() {
let foo = Foo { bar: Bool::False };
if foo.bar == Bool::False {
println!("foo");
}
}
"#,
)
}
#[test] #[test]
fn field_in_mod_properly_indented() { fn field_in_mod_properly_indented() {
check_assist( check_assist(
@ -714,6 +796,88 @@ fn main() {
) )
} }
#[test]
fn field_assigned_to_another() {
cov_mark::check!(dont_assign_incorrect_ref);
check_assist(
bool_to_enum,
r#"
struct Foo {
$0foo: bool,
}
struct Bar {
bar: bool,
}
fn main() {
let foo = Foo { foo: true };
let mut bar = Bar { bar: true };
bar.bar = foo.foo;
}
"#,
r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }
struct Foo {
foo: Bool,
}
struct Bar {
bar: bool,
}
fn main() {
let foo = Foo { foo: Bool::True };
let mut bar = Bar { bar: true };
bar.bar = foo.foo == Bool::True;
}
"#,
)
}
#[test]
fn field_initialized_with_other() {
cov_mark::check!(dont_overwrite_wrong_record_field);
check_assist(
bool_to_enum,
r#"
struct Foo {
$0foo: bool,
}
struct Bar {
bar: bool,
}
fn main() {
let foo = Foo { foo: true };
let bar = Bar { bar: foo.foo };
}
"#,
r#"
#[derive(PartialEq, Eq)]
enum Bool { True, False }
struct Foo {
foo: Bool,
}
struct Bar {
bar: bool,
}
fn main() {
let foo = Foo { foo: Bool::True };
let bar = Bar { bar: foo.foo == Bool::True };
}
"#,
)
}
#[test] #[test]
fn field_non_bool() { fn field_non_bool() {
cov_mark::check!(not_applicable_non_bool_field); cov_mark::check!(not_applicable_non_bool_field);