diff --git a/crates/hir/src/diagnostics.rs b/crates/hir/src/diagnostics.rs index b4c505898e..3016c92a0c 100644 --- a/crates/hir/src/diagnostics.rs +++ b/crates/hir/src/diagnostics.rs @@ -9,6 +9,8 @@ use hir_def::path::ModPath; use hir_expand::{name::Name, HirFileId, InFile}; use syntax::{ast, AstPtr, SyntaxNodePtr, TextRange}; +use crate::Type; + macro_rules! diagnostics { ($($diag:ident,)*) => { pub enum AnyDiagnostic {$( @@ -142,6 +144,7 @@ pub struct MissingOkOrSomeInTailExpr { pub expr: InFile>, // `Some` or `Ok` depending on whether the return type is Result or Option pub required: String, + pub expected: Type, } #[derive(Debug)] diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index 5960b3cc1c..4428d0644f 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -1216,7 +1216,14 @@ impl Function { } BodyValidationDiagnostic::MissingOkOrSomeInTailExpr { expr, required } => { match source_map.expr_syntax(expr) { - Ok(expr) => acc.push(MissingOkOrSomeInTailExpr { expr, required }.into()), + Ok(expr) => acc.push( + MissingOkOrSomeInTailExpr { + expr, + required, + expected: self.ret_type(db), + } + .into(), + ), Err(SyntheticSyntax) => (), } } diff --git a/crates/hir_ty/src/diagnostics/match_check.rs b/crates/hir_ty/src/diagnostics/match_check.rs index 7838bbe5c1..cdcb3ed5e7 100644 --- a/crates/hir_ty/src/diagnostics/match_check.rs +++ b/crates/hir_ty/src/diagnostics/match_check.rs @@ -7,6 +7,7 @@ mod deconstruct_pat; mod pat_util; + pub(crate) mod usefulness; use hir_def::{body::Body, EnumVariantId, LocalFieldId, VariantId}; diff --git a/crates/ide/src/highlight_related.rs b/crates/ide/src/highlight_related.rs index c7f2dd95f5..fd154ede44 100644 --- a/crates/ide/src/highlight_related.rs +++ b/crates/ide/src/highlight_related.rs @@ -533,6 +533,9 @@ fn foo() ->$0 u32 { 5 // ^ } + } else if false { + 0 + // ^ } else { match 5 { 6 => 100, diff --git a/crates/ide_completion/src/context.rs b/crates/ide_completion/src/context.rs index 36611cfc2b..06f16f9338 100644 --- a/crates/ide_completion/src/context.rs +++ b/crates/ide_completion/src/context.rs @@ -624,11 +624,7 @@ impl<'a> CompletionContext<'a> { fn classify_name(&mut self, name: ast::Name) { if let Some(bind_pat) = name.syntax().parent().and_then(ast::IdentPat::cast) { self.is_pat_or_const = Some(PatternRefutability::Refutable); - // if any of these is here our bind pat can't be a const pat anymore - let complex_ident_pat = bind_pat.at_token().is_some() - || bind_pat.ref_token().is_some() - || bind_pat.mut_token().is_some(); - if complex_ident_pat { + if !bind_pat.is_simple_ident() { self.is_pat_or_const = None; } else { let irrefutable_pat = bind_pat.syntax().ancestors().find_map(|node| { diff --git a/crates/ide_db/src/helpers.rs b/crates/ide_db/src/helpers.rs index f9da39f833..8ff569a741 100644 --- a/crates/ide_db/src/helpers.rs +++ b/crates/ide_db/src/helpers.rs @@ -271,7 +271,20 @@ pub fn for_each_tail_expr(expr: &ast::Expr, cb: &mut dyn FnMut(&ast::Expr)) { ast::Effect::Async(_) | ast::Effect::Try(_) | ast::Effect::Const(_) => cb(expr), }, ast::Expr::IfExpr(if_) => { - if_.blocks().for_each(|block| for_each_tail_expr(&ast::Expr::BlockExpr(block), cb)) + let mut if_ = if_.clone(); + loop { + if let Some(block) = if_.then_branch() { + for_each_tail_expr(&ast::Expr::BlockExpr(block), cb); + } + match if_.else_branch() { + Some(ast::ElseBranch::IfExpr(it)) => if_ = it, + Some(ast::ElseBranch::Block(block)) => { + for_each_tail_expr(&ast::Expr::BlockExpr(block), cb); + break; + } + None => break, + } + } } ast::Expr::LoopExpr(l) => { for_each_break_expr(l.label(), l.loop_body(), &mut |b| cb(&ast::Expr::BreakExpr(b))) diff --git a/crates/ide_diagnostics/src/handlers/missing_ok_or_some_in_tail_expr.rs b/crates/ide_diagnostics/src/handlers/missing_ok_or_some_in_tail_expr.rs index c0edcd7d39..469ab21d3c 100644 --- a/crates/ide_diagnostics/src/handlers/missing_ok_or_some_in_tail_expr.rs +++ b/crates/ide_diagnostics/src/handlers/missing_ok_or_some_in_tail_expr.rs @@ -1,5 +1,5 @@ use hir::db::AstDatabase; -use ide_db::{assists::Assist, source_change::SourceChange}; +use ide_db::{assists::Assist, helpers::for_each_tail_expr, source_change::SourceChange}; use syntax::AstNode; use text_edit::TextEdit; @@ -33,10 +33,15 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingOkOrSomeInTailExpr) -> Op let root = ctx.sema.db.parse_or_expand(d.expr.file_id)?; let tail_expr = d.expr.value.to_node(&root); let tail_expr_range = tail_expr.syntax().text_range(); - let replacement = format!("{}({})", d.required, tail_expr.syntax()); - let edit = TextEdit::replace(tail_expr_range, replacement); + let mut builder = TextEdit::builder(); + for_each_tail_expr(&tail_expr, &mut |expr| { + if ctx.sema.type_of_expr(expr).as_ref() != Some(&d.expected) { + builder.insert(expr.syntax().text_range().start(), format!("{}(", d.required)); + builder.insert(expr.syntax().text_range().end(), ")".to_string()); + } + }); let source_change = - SourceChange::from_text_edit(d.expr.file_id.original_file(ctx.sema.db), edit); + SourceChange::from_text_edit(d.expr.file_id.original_file(ctx.sema.db), builder.finish()); let name = if d.required == "Ok" { "Wrap with Ok" } else { "Wrap with Some" }; Some(vec![fix("wrap_tail_expr", name, source_change, tail_expr_range)]) } @@ -68,6 +73,35 @@ fn div(x: i32, y: i32) -> Option { ); } + #[test] + fn test_wrap_return_type_option_tails() { + check_fix( + r#" +//- minicore: option, result +fn div(x: i32, y: i32) -> Option { + if y == 0 { + 0 + } else if true { + 100 + } else { + None + }$0 +} +"#, + r#" +fn div(x: i32, y: i32) -> Option { + if y == 0 { + Some(0) + } else if true { + Some(100) + } else { + None + } +} +"#, + ); + } + #[test] fn test_wrap_return_type() { check_fix( diff --git a/crates/syntax/src/ast/expr_ext.rs b/crates/syntax/src/ast/expr_ext.rs index 5307516ecb..aad5b08e9f 100644 --- a/crates/syntax/src/ast/expr_ext.rs +++ b/crates/syntax/src/ast/expr_ext.rs @@ -164,6 +164,7 @@ impl ast::IfExpr { pub fn then_branch(&self) -> Option { self.blocks().next() } + pub fn else_branch(&self) -> Option { let res = match self.blocks().nth(1) { Some(block) => ElseBranch::Block(block),