diff --git a/crates/ide-diagnostics/src/handlers/type_mismatch.rs b/crates/ide-diagnostics/src/handlers/type_mismatch.rs index 8994ab50e6..bfdda53740 100644 --- a/crates/ide-diagnostics/src/handlers/type_mismatch.rs +++ b/crates/ide-diagnostics/src/handlers/type_mismatch.rs @@ -9,9 +9,10 @@ use syntax::{ ast::{ self, edit::{AstNodeEdit, IndentLevel}, - make, BlockExpr, Expr, ExprStmt, HasArgList, + syntax_factory::SyntaxFactory, + BlockExpr, Expr, ExprStmt, HasArgList, }, - ted, AstNode, AstPtr, TextSize, + AstNode, AstPtr, TextSize, }; use crate::{adjusted_display_range, fix, Assist, Diagnostic, DiagnosticCode, DiagnosticsContext}; @@ -223,8 +224,9 @@ fn remove_unnecessary_wrapper( let inner_arg = call_expr.arg_list()?.args().next()?; - let mut builder = SourceChangeBuilder::new(expr_ptr.file_id.original_file(ctx.sema.db)); - + let file_id = expr_ptr.file_id.original_file(db); + let mut builder = SourceChangeBuilder::new(file_id); + let mut editor; match inner_arg { // We're returning `()` Expr::TupleExpr(tup) if tup.fields().next().is_none() => { @@ -233,35 +235,33 @@ fn remove_unnecessary_wrapper( .parent() .and_then(Either::::cast)?; + editor = builder.make_editor(parent.syntax()); + let make = SyntaxFactory::new(); + match parent { Either::Left(ret_expr) => { - let old = builder.make_mut(ret_expr); - let new = make::expr_return(None).clone_for_update(); - - ted::replace(old.syntax(), new.syntax()); + editor.replace(ret_expr.syntax(), make.expr_return(None).syntax()); } Either::Right(stmt_list) => { - if stmt_list.statements().count() == 0 { - let block = stmt_list.syntax().parent().and_then(ast::BlockExpr::cast)?; - let old = builder.make_mut(block); - let new = make::expr_empty_block().clone_for_update(); - - ted::replace(old.syntax(), new.syntax()); + let new_block = if stmt_list.statements().next().is_none() { + make.expr_empty_block() } else { - let old = builder.make_syntax_mut(stmt_list.syntax().parent()?); - let new = make::block_expr(stmt_list.statements(), None).clone_for_update(); + make.block_expr(stmt_list.statements(), None) + }; - ted::replace(old, new.syntax()); - } + editor.replace(stmt_list.syntax().parent()?, new_block.syntax()); } } + + editor.add_mappings(make.finish_with_mappings()); } _ => { - let call_mut = builder.make_mut(call_expr.clone()); - ted::replace(call_mut.syntax(), inner_arg.clone_for_update().syntax()); + editor = builder.make_editor(call_expr.syntax()); + editor.replace(call_expr.syntax(), inner_arg.syntax()); } } + builder.add_file_edits(file_id, editor); let name = format!("Remove unnecessary {}() wrapper", variant.name(db).as_str()); acc.push(fix( "remove_unnecessary_wrapper", @@ -848,6 +848,29 @@ fn div(x: i32, y: i32) -> i32 { ); } + #[test] + fn unwrap_return_type_option_tail_unit() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) { + if y == 0 { + panic!(); + } + + Ok(())$0 +} +"#, + r#" +fn div(x: i32, y: i32) { + if y == 0 { + panic!(); + } +} +"#, + ); + } + #[test] fn unwrap_return_type_handles_generic_functions() { check_fix( diff --git a/crates/syntax/src/ast/syntax_factory/constructors.rs b/crates/syntax/src/ast/syntax_factory/constructors.rs index 44f67d83dc..f6ec18ef30 100644 --- a/crates/syntax/src/ast/syntax_factory/constructors.rs +++ b/crates/syntax/src/ast/syntax_factory/constructors.rs @@ -89,6 +89,10 @@ impl SyntaxFactory { ast } + pub fn expr_empty_block(&self) -> ast::BlockExpr { + ast::BlockExpr { syntax: make::expr_empty_block().syntax().clone_for_update() } + } + pub fn expr_bin(&self, lhs: ast::Expr, op: ast::BinaryOp, rhs: ast::Expr) -> ast::BinExpr { let ast::Expr::BinExpr(ast) = make::expr_bin_op(lhs.clone(), op, rhs.clone()).clone_for_update() @@ -135,6 +139,22 @@ impl SyntaxFactory { ast.into() } + pub fn expr_return(&self, expr: Option) -> ast::ReturnExpr { + let ast::Expr::ReturnExpr(ast) = make::expr_return(expr.clone()).clone_for_update() else { + unreachable!() + }; + + if let Some(mut mapping) = self.mappings() { + let mut builder = SyntaxMappingBuilder::new(ast.syntax().clone()); + if let Some(input) = expr { + builder.map_node(input.syntax().clone(), ast.expr().unwrap().syntax().clone()); + } + builder.finish(&mut mapping); + } + + ast + } + pub fn let_stmt( &self, pattern: ast::Pat,