diff --git a/crates/ide-diagnostics/src/handlers/type_mismatch.rs b/crates/ide-diagnostics/src/handlers/type_mismatch.rs index 93fe9374a3..8994ab50e6 100644 --- a/crates/ide-diagnostics/src/handlers/type_mismatch.rs +++ b/crates/ide-diagnostics/src/handlers/type_mismatch.rs @@ -1,14 +1,17 @@ use either::Either; -use hir::{db::ExpandDatabase, ClosureStyle, HirDisplay, HirFileIdExt, InFile, Type}; -use ide_db::text_edit::TextEdit; -use ide_db::{famous_defs::FamousDefs, source_change::SourceChange}; +use hir::{db::ExpandDatabase, CallableKind, ClosureStyle, HirDisplay, HirFileIdExt, InFile, Type}; +use ide_db::{ + famous_defs::FamousDefs, + source_change::{SourceChange, SourceChangeBuilder}, + text_edit::TextEdit, +}; use syntax::{ ast::{ self, edit::{AstNodeEdit, IndentLevel}, - BlockExpr, Expr, ExprStmt, + make, BlockExpr, Expr, ExprStmt, HasArgList, }, - AstNode, AstPtr, TextSize, + ted, AstNode, AstPtr, TextSize, }; use crate::{adjusted_display_range, fix, Assist, Diagnostic, DiagnosticCode, DiagnosticsContext}; @@ -63,6 +66,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch) -> Option, + d: &hir::TypeMismatch, + expr_ptr: &InFile>, + acc: &mut Vec, +) -> Option<()> { + let db = ctx.sema.db; + let root = db.parse_or_expand(expr_ptr.file_id); + let expr = expr_ptr.value.to_node(&root); + let expr = ctx.sema.original_ast_node(expr.clone())?; + + let Expr::CallExpr(call_expr) = expr else { + return None; + }; + + let callable = ctx.sema.resolve_expr_as_callable(&call_expr.expr()?)?; + let CallableKind::TupleEnumVariant(variant) = callable.kind() else { + return None; + }; + + let actual_enum = d.actual.as_adt()?.as_enum()?; + let famous_defs = FamousDefs(&ctx.sema, ctx.sema.scope(call_expr.syntax())?.krate()); + let core_option = famous_defs.core_option_Option(); + let core_result = famous_defs.core_result_Result(); + if Some(actual_enum) != core_option && Some(actual_enum) != core_result { + return None; + } + + let inner_type = variant.fields(db).first()?.ty_with_args(db, d.actual.type_arguments()); + if !d.expected.could_unify_with(db, &inner_type) { + return None; + } + + let inner_arg = call_expr.arg_list()?.args().next()?; + + let mut builder = SourceChangeBuilder::new(expr_ptr.file_id.original_file(ctx.sema.db)); + + match inner_arg { + // We're returning `()` + Expr::TupleExpr(tup) if tup.fields().next().is_none() => { + let parent = call_expr + .syntax() + .parent() + .and_then(Either::::cast)?; + + match parent { + Either::Left(ret_expr) => { + let old = builder.make_mut(ret_expr); + let new = make::expr_return(None).clone_for_update(); + + ted::replace(old.syntax(), new.syntax()); + } + Either::Right(stmt_list) => { + if stmt_list.statements().count() == 0 { + let block = stmt_list.syntax().parent().and_then(ast::BlockExpr::cast)?; + let old = builder.make_mut(block); + let new = make::expr_empty_block().clone_for_update(); + + ted::replace(old.syntax(), new.syntax()); + } else { + let old = builder.make_syntax_mut(stmt_list.syntax().parent()?); + let new = make::block_expr(stmt_list.statements(), None).clone_for_update(); + + ted::replace(old, new.syntax()); + } + } + } + } + _ => { + let call_mut = builder.make_mut(call_expr.clone()); + ted::replace(call_mut.syntax(), inner_arg.clone_for_update().syntax()); + } + } + + let name = format!("Remove unnecessary {}() wrapper", variant.name(db).as_str()); + acc.push(fix( + "remove_unnecessary_wrapper", + &name, + builder.finish(), + call_expr.syntax().text_range(), + )); + Some(()) +} + fn remove_semicolon( ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch, @@ -243,7 +331,7 @@ fn str_ref_to_owned( #[cfg(test)] mod tests { use crate::tests::{ - check_diagnostics, check_diagnostics_with_disabled, check_fix, check_no_fix, + check_diagnostics, check_diagnostics_with_disabled, check_fix, check_has_fix, check_no_fix, }; #[test] @@ -260,7 +348,7 @@ fn test(_arg: &i32) {} } #[test] - fn test_add_reference_to_int() { + fn add_reference_to_int() { check_fix( r#" fn main() { @@ -278,7 +366,7 @@ fn test(_arg: &i32) {} } #[test] - fn test_add_mutable_reference_to_int() { + fn add_mutable_reference_to_int() { check_fix( r#" fn main() { @@ -296,7 +384,7 @@ fn test(_arg: &mut i32) {} } #[test] - fn test_add_reference_to_array() { + fn add_reference_to_array() { check_fix( r#" //- minicore: coerce_unsized @@ -315,7 +403,7 @@ fn test(_arg: &[i32]) {} } #[test] - fn test_add_reference_with_autoderef() { + fn add_reference_with_autoderef() { check_fix( r#" //- minicore: coerce_unsized, deref @@ -348,7 +436,7 @@ fn test(_arg: &Bar) {} } #[test] - fn test_add_reference_to_method_call() { + fn add_reference_to_method_call() { check_fix( r#" fn main() { @@ -372,7 +460,7 @@ impl Test { } #[test] - fn test_add_reference_to_let_stmt() { + fn add_reference_to_let_stmt() { check_fix( r#" fn main() { @@ -388,7 +476,7 @@ fn main() { } #[test] - fn test_add_reference_to_macro_call() { + fn add_reference_to_macro_call() { check_fix( r#" macro_rules! thousand { @@ -416,7 +504,7 @@ fn main() { } #[test] - fn test_add_mutable_reference_to_let_stmt() { + fn add_mutable_reference_to_let_stmt() { check_fix( r#" fn main() { @@ -431,29 +519,6 @@ fn main() { ); } - #[test] - fn test_wrap_return_type_option() { - check_fix( - r#" -//- minicore: option, result -fn div(x: i32, y: i32) -> Option { - if y == 0 { - return None; - } - x / y$0 -} -"#, - r#" -fn div(x: i32, y: i32) -> Option { - if y == 0 { - return None; - } - Some(x / y) -} -"#, - ); - } - #[test] fn const_generic_type_mismatch() { check_diagnostics( @@ -487,7 +552,53 @@ fn div(x: i32, y: i32) -> Option { } #[test] - fn test_wrap_return_type_option_tails() { + fn wrap_return_type() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> Result { + if y == 0 { + return Err(()); + } + x / y$0 +} +"#, + r#" +fn div(x: i32, y: i32) -> Result { + if y == 0 { + return Err(()); + } + Ok(x / y) +} +"#, + ); + } + + #[test] + fn wrap_return_type_option() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> Option { + if y == 0 { + return None; + } + x / y$0 +} +"#, + r#" +fn div(x: i32, y: i32) -> Option { + if y == 0 { + return None; + } + Some(x / y) +} +"#, + ); + } + + #[test] + fn wrap_return_type_option_tails() { check_fix( r#" //- minicore: option, result @@ -516,30 +627,7 @@ fn div(x: i32, y: i32) -> Option { } #[test] - fn test_wrap_return_type() { - check_fix( - r#" -//- minicore: option, result -fn div(x: i32, y: i32) -> Result { - if y == 0 { - return Err(()); - } - x / y$0 -} -"#, - r#" -fn div(x: i32, y: i32) -> Result { - if y == 0 { - return Err(()); - } - Ok(x / y) -} -"#, - ); - } - - #[test] - fn test_wrap_return_type_handles_generic_functions() { + fn wrap_return_type_handles_generic_functions() { check_fix( r#" //- minicore: option, result @@ -562,7 +650,7 @@ fn div(x: T) -> Result { } #[test] - fn test_wrap_return_type_handles_type_aliases() { + fn wrap_return_type_handles_type_aliases() { check_fix( r#" //- minicore: option, result @@ -589,7 +677,7 @@ fn div(x: i32, y: i32) -> MyResult { } #[test] - fn test_wrapped_unit_as_block_tail_expr() { + fn wrapped_unit_as_block_tail_expr() { check_fix( r#" //- minicore: result @@ -619,7 +707,7 @@ fn foo() -> Result<(), ()> { } #[test] - fn test_wrapped_unit_as_return_expr() { + fn wrapped_unit_as_return_expr() { check_fix( r#" //- minicore: result @@ -642,7 +730,7 @@ fn foo(b: bool) -> Result<(), String> { } #[test] - fn test_in_const_and_static() { + fn wrap_in_const_and_static() { check_fix( r#" //- minicore: option, result @@ -664,7 +752,7 @@ const _: Option<()> = {Some(())}; } #[test] - fn test_wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() { + fn wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() { check_no_fix( r#" //- minicore: option, result @@ -674,7 +762,7 @@ fn foo() -> Result<(), i32> { 0$0 } } #[test] - fn test_wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() { + fn wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() { check_no_fix( r#" //- minicore: option, result @@ -685,6 +773,231 @@ fn foo() -> SomeOtherEnum { 0$0 } ); } + #[test] + fn unwrap_return_type() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + panic!(); + } + Ok(x / y)$0 +} +"#, + r#" +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + panic!(); + } + x / y +} +"#, + ); + } + + #[test] + fn unwrap_return_type_option() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + panic!(); + } + Some(x / y)$0 +} +"#, + r#" +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + panic!(); + } + x / y +} +"#, + ); + } + + #[test] + fn unwrap_return_type_option_tails() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + 42 + } else if true { + Some(100)$0 + } else { + 0 + } +} +"#, + r#" +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + 42 + } else if true { + 100 + } else { + 0 + } +} +"#, + ); + } + + #[test] + fn unwrap_return_type_handles_generic_functions() { + check_fix( + r#" +//- minicore: option, result +fn div(x: T) -> T { + if x == 0 { + panic!(); + } + $0Ok(x) +} +"#, + r#" +fn div(x: T) -> T { + if x == 0 { + panic!(); + } + x +} +"#, + ); + } + + #[test] + fn unwrap_return_type_handles_type_aliases() { + check_fix( + r#" +//- minicore: option, result +type MyResult = T; + +fn div(x: i32, y: i32) -> MyResult { + if y == 0 { + panic!(); + } + Ok(x $0/ y) +} +"#, + r#" +type MyResult = T; + +fn div(x: i32, y: i32) -> MyResult { + if y == 0 { + panic!(); + } + x / y +} +"#, + ); + } + + #[test] + fn unwrap_tail_expr() { + check_fix( + r#" +//- minicore: result +fn foo() -> () { + println!("Hello, world!"); + Ok(())$0 +} + "#, + r#" +fn foo() -> () { + println!("Hello, world!"); +} + "#, + ); + } + + #[test] + fn unwrap_to_empty_block() { + check_fix( + r#" +//- minicore: result +fn foo() -> () { + Ok(())$0 +} + "#, + r#" +fn foo() -> () {} + "#, + ); + } + + #[test] + fn unwrap_to_return_expr() { + check_has_fix( + r#" +//- minicore: result +fn foo(b: bool) -> () { + if b { + return $0Ok(()); + } + + panic!("oh dear"); +}"#, + r#" +fn foo(b: bool) -> () { + if b { + return; + } + + panic!("oh dear"); +}"#, + ); + } + + #[test] + fn unwrap_in_const_and_static() { + check_fix( + r#" +//- minicore: option, result +static A: () = {Some(($0))}; + "#, + r#" +static A: () = {}; + "#, + ); + check_fix( + r#" +//- minicore: option, result +const _: () = {Some(($0))}; + "#, + r#" +const _: () = {}; + "#, + ); + } + + #[test] + fn unwrap_return_type_not_applicable_when_inner_type_does_not_match_return_type() { + check_no_fix( + r#" +//- minicore: result +fn foo() -> i32 { $0Ok(()) } +"#, + ); + } + + #[test] + fn unwrap_return_type_not_applicable_when_wrapper_type_is_not_result_or_option() { + check_no_fix( + r#" +//- minicore: option, result +enum SomeOtherEnum { Ok(i32), Err(String) } + +fn foo() -> i32 { SomeOtherEnum::Ok($042) } +"#, + ); + } + #[test] fn remove_semicolon() { check_fix(r#"fn f() -> i32 { 92$0; }"#, r#"fn f() -> i32 { 92 }"#);