Add diagnostic fix to remove unnecessary wrapper in type mismatch

I also reorganized the tests in a more logical order, and removed the redundant `test_` prefix from their names.
This commit is contained in:
Giga Bowser 2024-11-01 10:40:19 -04:00
parent 99a6ecd41e
commit d881208d1b

View file

@ -1,14 +1,17 @@
use either::Either; use either::Either;
use hir::{db::ExpandDatabase, ClosureStyle, HirDisplay, HirFileIdExt, InFile, Type}; use hir::{db::ExpandDatabase, CallableKind, ClosureStyle, HirDisplay, HirFileIdExt, InFile, Type};
use ide_db::text_edit::TextEdit; use ide_db::{
use ide_db::{famous_defs::FamousDefs, source_change::SourceChange}; famous_defs::FamousDefs,
source_change::{SourceChange, SourceChangeBuilder},
text_edit::TextEdit,
};
use syntax::{ use syntax::{
ast::{ ast::{
self, self,
edit::{AstNodeEdit, IndentLevel}, edit::{AstNodeEdit, IndentLevel},
BlockExpr, Expr, ExprStmt, make, BlockExpr, Expr, ExprStmt, HasArgList,
}, },
AstNode, AstPtr, TextSize, ted, AstNode, AstPtr, TextSize,
}; };
use crate::{adjusted_display_range, fix, Assist, Diagnostic, DiagnosticCode, DiagnosticsContext}; use crate::{adjusted_display_range, fix, Assist, Diagnostic, DiagnosticCode, DiagnosticsContext};
@ -63,6 +66,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch) -> Option<Vec<Assi
let expr_ptr = &InFile { file_id: d.expr_or_pat.file_id, value: expr_ptr }; let expr_ptr = &InFile { file_id: d.expr_or_pat.file_id, value: expr_ptr };
add_reference(ctx, d, expr_ptr, &mut fixes); add_reference(ctx, d, expr_ptr, &mut fixes);
add_missing_ok_or_some(ctx, d, expr_ptr, &mut fixes); add_missing_ok_or_some(ctx, d, expr_ptr, &mut fixes);
remove_unnecessary_wrapper(ctx, d, expr_ptr, &mut fixes);
remove_semicolon(ctx, d, expr_ptr, &mut fixes); remove_semicolon(ctx, d, expr_ptr, &mut fixes);
str_ref_to_owned(ctx, d, expr_ptr, &mut fixes); str_ref_to_owned(ctx, d, expr_ptr, &mut fixes);
} }
@ -184,6 +188,90 @@ fn add_missing_ok_or_some(
Some(()) Some(())
} }
fn remove_unnecessary_wrapper(
ctx: &DiagnosticsContext<'_>,
d: &hir::TypeMismatch,
expr_ptr: &InFile<AstPtr<ast::Expr>>,
acc: &mut Vec<Assist>,
) -> Option<()> {
let db = ctx.sema.db;
let root = db.parse_or_expand(expr_ptr.file_id);
let expr = expr_ptr.value.to_node(&root);
let expr = ctx.sema.original_ast_node(expr.clone())?;
let Expr::CallExpr(call_expr) = expr else {
return None;
};
let callable = ctx.sema.resolve_expr_as_callable(&call_expr.expr()?)?;
let CallableKind::TupleEnumVariant(variant) = callable.kind() else {
return None;
};
let actual_enum = d.actual.as_adt()?.as_enum()?;
let famous_defs = FamousDefs(&ctx.sema, ctx.sema.scope(call_expr.syntax())?.krate());
let core_option = famous_defs.core_option_Option();
let core_result = famous_defs.core_result_Result();
if Some(actual_enum) != core_option && Some(actual_enum) != core_result {
return None;
}
let inner_type = variant.fields(db).first()?.ty_with_args(db, d.actual.type_arguments());
if !d.expected.could_unify_with(db, &inner_type) {
return None;
}
let inner_arg = call_expr.arg_list()?.args().next()?;
let mut builder = SourceChangeBuilder::new(expr_ptr.file_id.original_file(ctx.sema.db));
match inner_arg {
// We're returning `()`
Expr::TupleExpr(tup) if tup.fields().next().is_none() => {
let parent = call_expr
.syntax()
.parent()
.and_then(Either::<ast::ReturnExpr, ast::StmtList>::cast)?;
match parent {
Either::Left(ret_expr) => {
let old = builder.make_mut(ret_expr);
let new = make::expr_return(None).clone_for_update();
ted::replace(old.syntax(), new.syntax());
}
Either::Right(stmt_list) => {
if stmt_list.statements().count() == 0 {
let block = stmt_list.syntax().parent().and_then(ast::BlockExpr::cast)?;
let old = builder.make_mut(block);
let new = make::expr_empty_block().clone_for_update();
ted::replace(old.syntax(), new.syntax());
} else {
let old = builder.make_syntax_mut(stmt_list.syntax().parent()?);
let new = make::block_expr(stmt_list.statements(), None).clone_for_update();
ted::replace(old, new.syntax());
}
}
}
}
_ => {
let call_mut = builder.make_mut(call_expr.clone());
ted::replace(call_mut.syntax(), inner_arg.clone_for_update().syntax());
}
}
let name = format!("Remove unnecessary {}() wrapper", variant.name(db).as_str());
acc.push(fix(
"remove_unnecessary_wrapper",
&name,
builder.finish(),
call_expr.syntax().text_range(),
));
Some(())
}
fn remove_semicolon( fn remove_semicolon(
ctx: &DiagnosticsContext<'_>, ctx: &DiagnosticsContext<'_>,
d: &hir::TypeMismatch, d: &hir::TypeMismatch,
@ -243,7 +331,7 @@ fn str_ref_to_owned(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::tests::{ use crate::tests::{
check_diagnostics, check_diagnostics_with_disabled, check_fix, check_no_fix, check_diagnostics, check_diagnostics_with_disabled, check_fix, check_has_fix, check_no_fix,
}; };
#[test] #[test]
@ -260,7 +348,7 @@ fn test(_arg: &i32) {}
} }
#[test] #[test]
fn test_add_reference_to_int() { fn add_reference_to_int() {
check_fix( check_fix(
r#" r#"
fn main() { fn main() {
@ -278,7 +366,7 @@ fn test(_arg: &i32) {}
} }
#[test] #[test]
fn test_add_mutable_reference_to_int() { fn add_mutable_reference_to_int() {
check_fix( check_fix(
r#" r#"
fn main() { fn main() {
@ -296,7 +384,7 @@ fn test(_arg: &mut i32) {}
} }
#[test] #[test]
fn test_add_reference_to_array() { fn add_reference_to_array() {
check_fix( check_fix(
r#" r#"
//- minicore: coerce_unsized //- minicore: coerce_unsized
@ -315,7 +403,7 @@ fn test(_arg: &[i32]) {}
} }
#[test] #[test]
fn test_add_reference_with_autoderef() { fn add_reference_with_autoderef() {
check_fix( check_fix(
r#" r#"
//- minicore: coerce_unsized, deref //- minicore: coerce_unsized, deref
@ -348,7 +436,7 @@ fn test(_arg: &Bar) {}
} }
#[test] #[test]
fn test_add_reference_to_method_call() { fn add_reference_to_method_call() {
check_fix( check_fix(
r#" r#"
fn main() { fn main() {
@ -372,7 +460,7 @@ impl Test {
} }
#[test] #[test]
fn test_add_reference_to_let_stmt() { fn add_reference_to_let_stmt() {
check_fix( check_fix(
r#" r#"
fn main() { fn main() {
@ -388,7 +476,7 @@ fn main() {
} }
#[test] #[test]
fn test_add_reference_to_macro_call() { fn add_reference_to_macro_call() {
check_fix( check_fix(
r#" r#"
macro_rules! thousand { macro_rules! thousand {
@ -416,7 +504,7 @@ fn main() {
} }
#[test] #[test]
fn test_add_mutable_reference_to_let_stmt() { fn add_mutable_reference_to_let_stmt() {
check_fix( check_fix(
r#" r#"
fn main() { fn main() {
@ -431,29 +519,6 @@ fn main() {
); );
} }
#[test]
fn test_wrap_return_type_option() {
check_fix(
r#"
//- minicore: option, result
fn div(x: i32, y: i32) -> Option<i32> {
if y == 0 {
return None;
}
x / y$0
}
"#,
r#"
fn div(x: i32, y: i32) -> Option<i32> {
if y == 0 {
return None;
}
Some(x / y)
}
"#,
);
}
#[test] #[test]
fn const_generic_type_mismatch() { fn const_generic_type_mismatch() {
check_diagnostics( check_diagnostics(
@ -487,7 +552,53 @@ fn div(x: i32, y: i32) -> Option<i32> {
} }
#[test] #[test]
fn test_wrap_return_type_option_tails() { fn wrap_return_type() {
check_fix(
r#"
//- minicore: option, result
fn div(x: i32, y: i32) -> Result<i32, ()> {
if y == 0 {
return Err(());
}
x / y$0
}
"#,
r#"
fn div(x: i32, y: i32) -> Result<i32, ()> {
if y == 0 {
return Err(());
}
Ok(x / y)
}
"#,
);
}
#[test]
fn wrap_return_type_option() {
check_fix(
r#"
//- minicore: option, result
fn div(x: i32, y: i32) -> Option<i32> {
if y == 0 {
return None;
}
x / y$0
}
"#,
r#"
fn div(x: i32, y: i32) -> Option<i32> {
if y == 0 {
return None;
}
Some(x / y)
}
"#,
);
}
#[test]
fn wrap_return_type_option_tails() {
check_fix( check_fix(
r#" r#"
//- minicore: option, result //- minicore: option, result
@ -516,30 +627,7 @@ fn div(x: i32, y: i32) -> Option<i32> {
} }
#[test] #[test]
fn test_wrap_return_type() { fn wrap_return_type_handles_generic_functions() {
check_fix(
r#"
//- minicore: option, result
fn div(x: i32, y: i32) -> Result<i32, ()> {
if y == 0 {
return Err(());
}
x / y$0
}
"#,
r#"
fn div(x: i32, y: i32) -> Result<i32, ()> {
if y == 0 {
return Err(());
}
Ok(x / y)
}
"#,
);
}
#[test]
fn test_wrap_return_type_handles_generic_functions() {
check_fix( check_fix(
r#" r#"
//- minicore: option, result //- minicore: option, result
@ -562,7 +650,7 @@ fn div<T>(x: T) -> Result<T, i32> {
} }
#[test] #[test]
fn test_wrap_return_type_handles_type_aliases() { fn wrap_return_type_handles_type_aliases() {
check_fix( check_fix(
r#" r#"
//- minicore: option, result //- minicore: option, result
@ -589,7 +677,7 @@ fn div(x: i32, y: i32) -> MyResult<i32> {
} }
#[test] #[test]
fn test_wrapped_unit_as_block_tail_expr() { fn wrapped_unit_as_block_tail_expr() {
check_fix( check_fix(
r#" r#"
//- minicore: result //- minicore: result
@ -619,7 +707,7 @@ fn foo() -> Result<(), ()> {
} }
#[test] #[test]
fn test_wrapped_unit_as_return_expr() { fn wrapped_unit_as_return_expr() {
check_fix( check_fix(
r#" r#"
//- minicore: result //- minicore: result
@ -642,7 +730,7 @@ fn foo(b: bool) -> Result<(), String> {
} }
#[test] #[test]
fn test_in_const_and_static() { fn wrap_in_const_and_static() {
check_fix( check_fix(
r#" r#"
//- minicore: option, result //- minicore: option, result
@ -664,7 +752,7 @@ const _: Option<()> = {Some(())};
} }
#[test] #[test]
fn test_wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() { fn wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() {
check_no_fix( check_no_fix(
r#" r#"
//- minicore: option, result //- minicore: option, result
@ -674,7 +762,7 @@ fn foo() -> Result<(), i32> { 0$0 }
} }
#[test] #[test]
fn test_wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() { fn wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() {
check_no_fix( check_no_fix(
r#" r#"
//- minicore: option, result //- minicore: option, result
@ -685,6 +773,231 @@ fn foo() -> SomeOtherEnum { 0$0 }
); );
} }
#[test]
fn unwrap_return_type() {
check_fix(
r#"
//- minicore: option, result
fn div(x: i32, y: i32) -> i32 {
if y == 0 {
panic!();
}
Ok(x / y)$0
}
"#,
r#"
fn div(x: i32, y: i32) -> i32 {
if y == 0 {
panic!();
}
x / y
}
"#,
);
}
#[test]
fn unwrap_return_type_option() {
check_fix(
r#"
//- minicore: option, result
fn div(x: i32, y: i32) -> i32 {
if y == 0 {
panic!();
}
Some(x / y)$0
}
"#,
r#"
fn div(x: i32, y: i32) -> i32 {
if y == 0 {
panic!();
}
x / y
}
"#,
);
}
#[test]
fn unwrap_return_type_option_tails() {
check_fix(
r#"
//- minicore: option, result
fn div(x: i32, y: i32) -> i32 {
if y == 0 {
42
} else if true {
Some(100)$0
} else {
0
}
}
"#,
r#"
fn div(x: i32, y: i32) -> i32 {
if y == 0 {
42
} else if true {
100
} else {
0
}
}
"#,
);
}
#[test]
fn unwrap_return_type_handles_generic_functions() {
check_fix(
r#"
//- minicore: option, result
fn div<T>(x: T) -> T {
if x == 0 {
panic!();
}
$0Ok(x)
}
"#,
r#"
fn div<T>(x: T) -> T {
if x == 0 {
panic!();
}
x
}
"#,
);
}
#[test]
fn unwrap_return_type_handles_type_aliases() {
check_fix(
r#"
//- minicore: option, result
type MyResult<T> = T;
fn div(x: i32, y: i32) -> MyResult<i32> {
if y == 0 {
panic!();
}
Ok(x $0/ y)
}
"#,
r#"
type MyResult<T> = T;
fn div(x: i32, y: i32) -> MyResult<i32> {
if y == 0 {
panic!();
}
x / y
}
"#,
);
}
#[test]
fn unwrap_tail_expr() {
check_fix(
r#"
//- minicore: result
fn foo() -> () {
println!("Hello, world!");
Ok(())$0
}
"#,
r#"
fn foo() -> () {
println!("Hello, world!");
}
"#,
);
}
#[test]
fn unwrap_to_empty_block() {
check_fix(
r#"
//- minicore: result
fn foo() -> () {
Ok(())$0
}
"#,
r#"
fn foo() -> () {}
"#,
);
}
#[test]
fn unwrap_to_return_expr() {
check_has_fix(
r#"
//- minicore: result
fn foo(b: bool) -> () {
if b {
return $0Ok(());
}
panic!("oh dear");
}"#,
r#"
fn foo(b: bool) -> () {
if b {
return;
}
panic!("oh dear");
}"#,
);
}
#[test]
fn unwrap_in_const_and_static() {
check_fix(
r#"
//- minicore: option, result
static A: () = {Some(($0))};
"#,
r#"
static A: () = {};
"#,
);
check_fix(
r#"
//- minicore: option, result
const _: () = {Some(($0))};
"#,
r#"
const _: () = {};
"#,
);
}
#[test]
fn unwrap_return_type_not_applicable_when_inner_type_does_not_match_return_type() {
check_no_fix(
r#"
//- minicore: result
fn foo() -> i32 { $0Ok(()) }
"#,
);
}
#[test]
fn unwrap_return_type_not_applicable_when_wrapper_type_is_not_result_or_option() {
check_no_fix(
r#"
//- minicore: option, result
enum SomeOtherEnum { Ok(i32), Err(String) }
fn foo() -> i32 { SomeOtherEnum::Ok($042) }
"#,
);
}
#[test] #[test]
fn remove_semicolon() { fn remove_semicolon() {
check_fix(r#"fn f() -> i32 { 92$0; }"#, r#"fn f() -> i32 { 92 }"#); check_fix(r#"fn f() -> i32 { 92$0; }"#, r#"fn f() -> i32 { 92 }"#);