From d881208d1bacb999d198f7a7a0d3ea6d5b24cdba Mon Sep 17 00:00:00 2001 From: Giga Bowser <45986823+Giga-Bowser@users.noreply.github.com> Date: Fri, 1 Nov 2024 10:40:19 -0400 Subject: [PATCH 1/3] Add diagnostic fix to remove unnecessary wrapper in type mismatch I also reorganized the tests in a more logical order, and removed the redundant `test_` prefix from their names. --- .../src/handlers/type_mismatch.rs | 449 +++++++++++++++--- 1 file changed, 381 insertions(+), 68 deletions(-) diff --git a/crates/ide-diagnostics/src/handlers/type_mismatch.rs b/crates/ide-diagnostics/src/handlers/type_mismatch.rs index 93fe9374a3..8994ab50e6 100644 --- a/crates/ide-diagnostics/src/handlers/type_mismatch.rs +++ b/crates/ide-diagnostics/src/handlers/type_mismatch.rs @@ -1,14 +1,17 @@ use either::Either; -use hir::{db::ExpandDatabase, ClosureStyle, HirDisplay, HirFileIdExt, InFile, Type}; -use ide_db::text_edit::TextEdit; -use ide_db::{famous_defs::FamousDefs, source_change::SourceChange}; +use hir::{db::ExpandDatabase, CallableKind, ClosureStyle, HirDisplay, HirFileIdExt, InFile, Type}; +use ide_db::{ + famous_defs::FamousDefs, + source_change::{SourceChange, SourceChangeBuilder}, + text_edit::TextEdit, +}; use syntax::{ ast::{ self, edit::{AstNodeEdit, IndentLevel}, - BlockExpr, Expr, ExprStmt, + make, BlockExpr, Expr, ExprStmt, HasArgList, }, - AstNode, AstPtr, TextSize, + ted, AstNode, AstPtr, TextSize, }; use crate::{adjusted_display_range, fix, Assist, Diagnostic, DiagnosticCode, DiagnosticsContext}; @@ -63,6 +66,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch) -> Option, + d: &hir::TypeMismatch, + expr_ptr: &InFile>, + acc: &mut Vec, +) -> Option<()> { + let db = ctx.sema.db; + let root = db.parse_or_expand(expr_ptr.file_id); + let expr = expr_ptr.value.to_node(&root); + let expr = ctx.sema.original_ast_node(expr.clone())?; + + let Expr::CallExpr(call_expr) = expr else { + return None; + }; + + let callable = ctx.sema.resolve_expr_as_callable(&call_expr.expr()?)?; + let CallableKind::TupleEnumVariant(variant) = callable.kind() else { + return None; + }; + + let actual_enum = d.actual.as_adt()?.as_enum()?; + let famous_defs = FamousDefs(&ctx.sema, ctx.sema.scope(call_expr.syntax())?.krate()); + let core_option = famous_defs.core_option_Option(); + let core_result = famous_defs.core_result_Result(); + if Some(actual_enum) != core_option && Some(actual_enum) != core_result { + return None; + } + + let inner_type = variant.fields(db).first()?.ty_with_args(db, d.actual.type_arguments()); + if !d.expected.could_unify_with(db, &inner_type) { + return None; + } + + let inner_arg = call_expr.arg_list()?.args().next()?; + + let mut builder = SourceChangeBuilder::new(expr_ptr.file_id.original_file(ctx.sema.db)); + + match inner_arg { + // We're returning `()` + Expr::TupleExpr(tup) if tup.fields().next().is_none() => { + let parent = call_expr + .syntax() + .parent() + .and_then(Either::::cast)?; + + match parent { + Either::Left(ret_expr) => { + let old = builder.make_mut(ret_expr); + let new = make::expr_return(None).clone_for_update(); + + ted::replace(old.syntax(), new.syntax()); + } + Either::Right(stmt_list) => { + if stmt_list.statements().count() == 0 { + let block = stmt_list.syntax().parent().and_then(ast::BlockExpr::cast)?; + let old = builder.make_mut(block); + let new = make::expr_empty_block().clone_for_update(); + + ted::replace(old.syntax(), new.syntax()); + } else { + let old = builder.make_syntax_mut(stmt_list.syntax().parent()?); + let new = make::block_expr(stmt_list.statements(), None).clone_for_update(); + + ted::replace(old, new.syntax()); + } + } + } + } + _ => { + let call_mut = builder.make_mut(call_expr.clone()); + ted::replace(call_mut.syntax(), inner_arg.clone_for_update().syntax()); + } + } + + let name = format!("Remove unnecessary {}() wrapper", variant.name(db).as_str()); + acc.push(fix( + "remove_unnecessary_wrapper", + &name, + builder.finish(), + call_expr.syntax().text_range(), + )); + Some(()) +} + fn remove_semicolon( ctx: &DiagnosticsContext<'_>, d: &hir::TypeMismatch, @@ -243,7 +331,7 @@ fn str_ref_to_owned( #[cfg(test)] mod tests { use crate::tests::{ - check_diagnostics, check_diagnostics_with_disabled, check_fix, check_no_fix, + check_diagnostics, check_diagnostics_with_disabled, check_fix, check_has_fix, check_no_fix, }; #[test] @@ -260,7 +348,7 @@ fn test(_arg: &i32) {} } #[test] - fn test_add_reference_to_int() { + fn add_reference_to_int() { check_fix( r#" fn main() { @@ -278,7 +366,7 @@ fn test(_arg: &i32) {} } #[test] - fn test_add_mutable_reference_to_int() { + fn add_mutable_reference_to_int() { check_fix( r#" fn main() { @@ -296,7 +384,7 @@ fn test(_arg: &mut i32) {} } #[test] - fn test_add_reference_to_array() { + fn add_reference_to_array() { check_fix( r#" //- minicore: coerce_unsized @@ -315,7 +403,7 @@ fn test(_arg: &[i32]) {} } #[test] - fn test_add_reference_with_autoderef() { + fn add_reference_with_autoderef() { check_fix( r#" //- minicore: coerce_unsized, deref @@ -348,7 +436,7 @@ fn test(_arg: &Bar) {} } #[test] - fn test_add_reference_to_method_call() { + fn add_reference_to_method_call() { check_fix( r#" fn main() { @@ -372,7 +460,7 @@ impl Test { } #[test] - fn test_add_reference_to_let_stmt() { + fn add_reference_to_let_stmt() { check_fix( r#" fn main() { @@ -388,7 +476,7 @@ fn main() { } #[test] - fn test_add_reference_to_macro_call() { + fn add_reference_to_macro_call() { check_fix( r#" macro_rules! thousand { @@ -416,7 +504,7 @@ fn main() { } #[test] - fn test_add_mutable_reference_to_let_stmt() { + fn add_mutable_reference_to_let_stmt() { check_fix( r#" fn main() { @@ -431,29 +519,6 @@ fn main() { ); } - #[test] - fn test_wrap_return_type_option() { - check_fix( - r#" -//- minicore: option, result -fn div(x: i32, y: i32) -> Option { - if y == 0 { - return None; - } - x / y$0 -} -"#, - r#" -fn div(x: i32, y: i32) -> Option { - if y == 0 { - return None; - } - Some(x / y) -} -"#, - ); - } - #[test] fn const_generic_type_mismatch() { check_diagnostics( @@ -487,7 +552,53 @@ fn div(x: i32, y: i32) -> Option { } #[test] - fn test_wrap_return_type_option_tails() { + fn wrap_return_type() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> Result { + if y == 0 { + return Err(()); + } + x / y$0 +} +"#, + r#" +fn div(x: i32, y: i32) -> Result { + if y == 0 { + return Err(()); + } + Ok(x / y) +} +"#, + ); + } + + #[test] + fn wrap_return_type_option() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> Option { + if y == 0 { + return None; + } + x / y$0 +} +"#, + r#" +fn div(x: i32, y: i32) -> Option { + if y == 0 { + return None; + } + Some(x / y) +} +"#, + ); + } + + #[test] + fn wrap_return_type_option_tails() { check_fix( r#" //- minicore: option, result @@ -516,30 +627,7 @@ fn div(x: i32, y: i32) -> Option { } #[test] - fn test_wrap_return_type() { - check_fix( - r#" -//- minicore: option, result -fn div(x: i32, y: i32) -> Result { - if y == 0 { - return Err(()); - } - x / y$0 -} -"#, - r#" -fn div(x: i32, y: i32) -> Result { - if y == 0 { - return Err(()); - } - Ok(x / y) -} -"#, - ); - } - - #[test] - fn test_wrap_return_type_handles_generic_functions() { + fn wrap_return_type_handles_generic_functions() { check_fix( r#" //- minicore: option, result @@ -562,7 +650,7 @@ fn div(x: T) -> Result { } #[test] - fn test_wrap_return_type_handles_type_aliases() { + fn wrap_return_type_handles_type_aliases() { check_fix( r#" //- minicore: option, result @@ -589,7 +677,7 @@ fn div(x: i32, y: i32) -> MyResult { } #[test] - fn test_wrapped_unit_as_block_tail_expr() { + fn wrapped_unit_as_block_tail_expr() { check_fix( r#" //- minicore: result @@ -619,7 +707,7 @@ fn foo() -> Result<(), ()> { } #[test] - fn test_wrapped_unit_as_return_expr() { + fn wrapped_unit_as_return_expr() { check_fix( r#" //- minicore: result @@ -642,7 +730,7 @@ fn foo(b: bool) -> Result<(), String> { } #[test] - fn test_in_const_and_static() { + fn wrap_in_const_and_static() { check_fix( r#" //- minicore: option, result @@ -664,7 +752,7 @@ const _: Option<()> = {Some(())}; } #[test] - fn test_wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() { + fn wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() { check_no_fix( r#" //- minicore: option, result @@ -674,7 +762,7 @@ fn foo() -> Result<(), i32> { 0$0 } } #[test] - fn test_wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() { + fn wrap_return_type_not_applicable_when_return_type_is_not_result_or_option() { check_no_fix( r#" //- minicore: option, result @@ -685,6 +773,231 @@ fn foo() -> SomeOtherEnum { 0$0 } ); } + #[test] + fn unwrap_return_type() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + panic!(); + } + Ok(x / y)$0 +} +"#, + r#" +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + panic!(); + } + x / y +} +"#, + ); + } + + #[test] + fn unwrap_return_type_option() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + panic!(); + } + Some(x / y)$0 +} +"#, + r#" +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + panic!(); + } + x / y +} +"#, + ); + } + + #[test] + fn unwrap_return_type_option_tails() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + 42 + } else if true { + Some(100)$0 + } else { + 0 + } +} +"#, + r#" +fn div(x: i32, y: i32) -> i32 { + if y == 0 { + 42 + } else if true { + 100 + } else { + 0 + } +} +"#, + ); + } + + #[test] + fn unwrap_return_type_handles_generic_functions() { + check_fix( + r#" +//- minicore: option, result +fn div(x: T) -> T { + if x == 0 { + panic!(); + } + $0Ok(x) +} +"#, + r#" +fn div(x: T) -> T { + if x == 0 { + panic!(); + } + x +} +"#, + ); + } + + #[test] + fn unwrap_return_type_handles_type_aliases() { + check_fix( + r#" +//- minicore: option, result +type MyResult = T; + +fn div(x: i32, y: i32) -> MyResult { + if y == 0 { + panic!(); + } + Ok(x $0/ y) +} +"#, + r#" +type MyResult = T; + +fn div(x: i32, y: i32) -> MyResult { + if y == 0 { + panic!(); + } + x / y +} +"#, + ); + } + + #[test] + fn unwrap_tail_expr() { + check_fix( + r#" +//- minicore: result +fn foo() -> () { + println!("Hello, world!"); + Ok(())$0 +} + "#, + r#" +fn foo() -> () { + println!("Hello, world!"); +} + "#, + ); + } + + #[test] + fn unwrap_to_empty_block() { + check_fix( + r#" +//- minicore: result +fn foo() -> () { + Ok(())$0 +} + "#, + r#" +fn foo() -> () {} + "#, + ); + } + + #[test] + fn unwrap_to_return_expr() { + check_has_fix( + r#" +//- minicore: result +fn foo(b: bool) -> () { + if b { + return $0Ok(()); + } + + panic!("oh dear"); +}"#, + r#" +fn foo(b: bool) -> () { + if b { + return; + } + + panic!("oh dear"); +}"#, + ); + } + + #[test] + fn unwrap_in_const_and_static() { + check_fix( + r#" +//- minicore: option, result +static A: () = {Some(($0))}; + "#, + r#" +static A: () = {}; + "#, + ); + check_fix( + r#" +//- minicore: option, result +const _: () = {Some(($0))}; + "#, + r#" +const _: () = {}; + "#, + ); + } + + #[test] + fn unwrap_return_type_not_applicable_when_inner_type_does_not_match_return_type() { + check_no_fix( + r#" +//- minicore: result +fn foo() -> i32 { $0Ok(()) } +"#, + ); + } + + #[test] + fn unwrap_return_type_not_applicable_when_wrapper_type_is_not_result_or_option() { + check_no_fix( + r#" +//- minicore: option, result +enum SomeOtherEnum { Ok(i32), Err(String) } + +fn foo() -> i32 { SomeOtherEnum::Ok($042) } +"#, + ); + } + #[test] fn remove_semicolon() { check_fix(r#"fn f() -> i32 { 92$0; }"#, r#"fn f() -> i32 { 92 }"#); From 59cd717602359254ab8a6f720d9dcaf3b2416d1d Mon Sep 17 00:00:00 2001 From: Giga Bowser <45986823+Giga-Bowser@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:57:58 -0500 Subject: [PATCH 2/3] fix: Handle the final statement in `SyntaxFactory::block_expr` properly This caused a bug that was rather tricky to hunt down! --- .../src/ast/syntax_factory/constructors.rs | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/crates/syntax/src/ast/syntax_factory/constructors.rs b/crates/syntax/src/ast/syntax_factory/constructors.rs index 54f17bd721..44f67d83dc 100644 --- a/crates/syntax/src/ast/syntax_factory/constructors.rs +++ b/crates/syntax/src/ast/syntax_factory/constructors.rs @@ -58,22 +58,31 @@ impl SyntaxFactory { tail_expr: Option, ) -> ast::BlockExpr { let stmts = stmts.into_iter().collect_vec(); - let input = stmts.iter().map(|it| it.syntax().clone()).collect_vec(); + let mut input = stmts.iter().map(|it| it.syntax().clone()).collect_vec(); let ast = make::block_expr(stmts, tail_expr.clone()).clone_for_update(); - if let Some((mut mapping, stmt_list)) = self.mappings().zip(ast.stmt_list()) { + if let Some(mut mapping) = self.mappings() { + let stmt_list = ast.stmt_list().unwrap(); let mut builder = SyntaxMappingBuilder::new(stmt_list.syntax().clone()); + if let Some(input) = tail_expr { + builder.map_node( + input.syntax().clone(), + stmt_list.tail_expr().unwrap().syntax().clone(), + ); + } else if let Some(ast_tail) = stmt_list.tail_expr() { + // The parser interpreted the last statement (probably a statement with a block) as an Expr + let last_stmt = input.pop().unwrap(); + + builder.map_node(last_stmt, ast_tail.syntax().clone()); + } + builder.map_children( input.into_iter(), stmt_list.statements().map(|it| it.syntax().clone()), ); - if let Some((input, output)) = tail_expr.zip(stmt_list.tail_expr()) { - builder.map_node(input.syntax().clone(), output.syntax().clone()); - } - builder.finish(&mut mapping); } From 68b85ce66f049a38e5cece797e1a527838eb466b Mon Sep 17 00:00:00 2001 From: Giga Bowser <45986823+Giga-Bowser@users.noreply.github.com> Date: Thu, 14 Nov 2024 21:15:50 -0500 Subject: [PATCH 3/3] minor: Migrate `remove_unnecessary_wrapper` to `SyntaxEditor` --- .../src/handlers/type_mismatch.rs | 63 +++++++++++++------ .../src/ast/syntax_factory/constructors.rs | 20 ++++++ 2 files changed, 63 insertions(+), 20 deletions(-) diff --git a/crates/ide-diagnostics/src/handlers/type_mismatch.rs b/crates/ide-diagnostics/src/handlers/type_mismatch.rs index 8994ab50e6..bfdda53740 100644 --- a/crates/ide-diagnostics/src/handlers/type_mismatch.rs +++ b/crates/ide-diagnostics/src/handlers/type_mismatch.rs @@ -9,9 +9,10 @@ use syntax::{ ast::{ self, edit::{AstNodeEdit, IndentLevel}, - make, BlockExpr, Expr, ExprStmt, HasArgList, + syntax_factory::SyntaxFactory, + BlockExpr, Expr, ExprStmt, HasArgList, }, - ted, AstNode, AstPtr, TextSize, + AstNode, AstPtr, TextSize, }; use crate::{adjusted_display_range, fix, Assist, Diagnostic, DiagnosticCode, DiagnosticsContext}; @@ -223,8 +224,9 @@ fn remove_unnecessary_wrapper( let inner_arg = call_expr.arg_list()?.args().next()?; - let mut builder = SourceChangeBuilder::new(expr_ptr.file_id.original_file(ctx.sema.db)); - + let file_id = expr_ptr.file_id.original_file(db); + let mut builder = SourceChangeBuilder::new(file_id); + let mut editor; match inner_arg { // We're returning `()` Expr::TupleExpr(tup) if tup.fields().next().is_none() => { @@ -233,35 +235,33 @@ fn remove_unnecessary_wrapper( .parent() .and_then(Either::::cast)?; + editor = builder.make_editor(parent.syntax()); + let make = SyntaxFactory::new(); + match parent { Either::Left(ret_expr) => { - let old = builder.make_mut(ret_expr); - let new = make::expr_return(None).clone_for_update(); - - ted::replace(old.syntax(), new.syntax()); + editor.replace(ret_expr.syntax(), make.expr_return(None).syntax()); } Either::Right(stmt_list) => { - if stmt_list.statements().count() == 0 { - let block = stmt_list.syntax().parent().and_then(ast::BlockExpr::cast)?; - let old = builder.make_mut(block); - let new = make::expr_empty_block().clone_for_update(); - - ted::replace(old.syntax(), new.syntax()); + let new_block = if stmt_list.statements().next().is_none() { + make.expr_empty_block() } else { - let old = builder.make_syntax_mut(stmt_list.syntax().parent()?); - let new = make::block_expr(stmt_list.statements(), None).clone_for_update(); + make.block_expr(stmt_list.statements(), None) + }; - ted::replace(old, new.syntax()); - } + editor.replace(stmt_list.syntax().parent()?, new_block.syntax()); } } + + editor.add_mappings(make.finish_with_mappings()); } _ => { - let call_mut = builder.make_mut(call_expr.clone()); - ted::replace(call_mut.syntax(), inner_arg.clone_for_update().syntax()); + editor = builder.make_editor(call_expr.syntax()); + editor.replace(call_expr.syntax(), inner_arg.syntax()); } } + builder.add_file_edits(file_id, editor); let name = format!("Remove unnecessary {}() wrapper", variant.name(db).as_str()); acc.push(fix( "remove_unnecessary_wrapper", @@ -848,6 +848,29 @@ fn div(x: i32, y: i32) -> i32 { ); } + #[test] + fn unwrap_return_type_option_tail_unit() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) { + if y == 0 { + panic!(); + } + + Ok(())$0 +} +"#, + r#" +fn div(x: i32, y: i32) { + if y == 0 { + panic!(); + } +} +"#, + ); + } + #[test] fn unwrap_return_type_handles_generic_functions() { check_fix( diff --git a/crates/syntax/src/ast/syntax_factory/constructors.rs b/crates/syntax/src/ast/syntax_factory/constructors.rs index 44f67d83dc..f6ec18ef30 100644 --- a/crates/syntax/src/ast/syntax_factory/constructors.rs +++ b/crates/syntax/src/ast/syntax_factory/constructors.rs @@ -89,6 +89,10 @@ impl SyntaxFactory { ast } + pub fn expr_empty_block(&self) -> ast::BlockExpr { + ast::BlockExpr { syntax: make::expr_empty_block().syntax().clone_for_update() } + } + pub fn expr_bin(&self, lhs: ast::Expr, op: ast::BinaryOp, rhs: ast::Expr) -> ast::BinExpr { let ast::Expr::BinExpr(ast) = make::expr_bin_op(lhs.clone(), op, rhs.clone()).clone_for_update() @@ -135,6 +139,22 @@ impl SyntaxFactory { ast.into() } + pub fn expr_return(&self, expr: Option) -> ast::ReturnExpr { + let ast::Expr::ReturnExpr(ast) = make::expr_return(expr.clone()).clone_for_update() else { + unreachable!() + }; + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + if let Some(input) = expr { + builder.map_node(input.syntax().clone(), ast.expr().unwrap().syntax().clone()); + } + builder.finish(&mut mapping); + } + + ast + } + pub fn let_stmt( &self, pattern: ast::Pat,