diff --git a/crates/ra_hir/src/expr/validation.rs b/crates/ra_hir/src/expr/validation.rs index ca7db61bc4..339a7b8486 100644 --- a/crates/ra_hir/src/expr/validation.rs +++ b/crates/ra_hir/src/expr/validation.rs @@ -6,10 +6,13 @@ use ra_syntax::ast::{AstNode, RecordLit}; use super::{Expr, ExprId, RecordLitField}; use crate::{ adt::AdtDef, + code_model::Enum, diagnostics::{DiagnosticSink, MissingFields, MissingOkInTailExpr}, expr::AstPtr, + name, + path::{PathKind, PathSegment}, ty::{InferenceResult, Ty, TypeCtor}, - Function, HasSource, HirDatabase, Name, Path, + Function, HasSource, HirDatabase, ModuleDef, Name, Path, PerNs, Resolution }; use ra_syntax::ast; @@ -106,18 +109,45 @@ impl<'a, 'b> ExprValidator<'a, 'b> { Some(m) => m, None => return, }; + + let std_result_path = Path { + kind: PathKind::Abs, + segments: vec![ + PathSegment { name: name::STD, args_and_bindings: None }, + PathSegment { name: name::RESULT_MOD, args_and_bindings: None }, + PathSegment { name: name::RESULT_TYPE, args_and_bindings: None }, + ] + }; + + let resolver = self.func.resolver(db); + let std_result_enum = match resolver.resolve_path_segments(db, &std_result_path).into_fully_resolved() { + PerNs { types: Some(Resolution::Def(ModuleDef::Enum(e))), .. } => e, + _ => return, + }; + + let std_result_type = std_result_enum.ty(db); + + fn enum_from_type(ty: &Ty) -> Option { + match ty { + Ty::Apply(t) => { + match t.ctor { + TypeCtor::Adt(AdtDef::Enum(e)) => Some(e), + _ => None, + } + } + _ => None + } + } + + if enum_from_type(&mismatch.expected) != enum_from_type(&std_result_type) { + return; + } + let ret = match &mismatch.expected { Ty::Apply(t) => t, _ => return, }; - let ret_enum = match ret.ctor { - TypeCtor::Adt(AdtDef::Enum(e)) => e, - _ => return, - }; - let enum_name = ret_enum.name(db); - if enum_name.is_none() || enum_name.unwrap().to_string() != "Result" { - return; - } + let params = &ret.parameters; if params.len() == 2 && ¶ms[0] == &mismatch.actual { let source_map = self.func.body_source_map(db); diff --git a/crates/ra_hir/src/name.rs b/crates/ra_hir/src/name.rs index 6d14eea8ec..9c4822d917 100644 --- a/crates/ra_hir/src/name.rs +++ b/crates/ra_hir/src/name.rs @@ -120,6 +120,8 @@ pub(crate) const TRY: Name = Name::new(SmolStr::new_inline_from_ascii(3, b"Try") pub(crate) const OK: Name = Name::new(SmolStr::new_inline_from_ascii(2, b"Ok")); pub(crate) const FUTURE_MOD: Name = Name::new(SmolStr::new_inline_from_ascii(6, b"future")); pub(crate) const FUTURE_TYPE: Name = Name::new(SmolStr::new_inline_from_ascii(6, b"Future")); +pub(crate) const RESULT_MOD: Name = Name::new(SmolStr::new_inline_from_ascii(6, b"result")); +pub(crate) const RESULT_TYPE: Name = Name::new(SmolStr::new_inline_from_ascii(6, b"Result")); pub(crate) const OUTPUT: Name = Name::new(SmolStr::new_inline_from_ascii(6, b"Output")); fn resolve_name(text: &SmolStr) -> SmolStr { diff --git a/crates/ra_ide_api/src/diagnostics.rs b/crates/ra_ide_api/src/diagnostics.rs index 5e25991c62..57454719c2 100644 --- a/crates/ra_ide_api/src/diagnostics.rs +++ b/crates/ra_ide_api/src/diagnostics.rs @@ -281,6 +281,43 @@ fn div(x: i32, y: i32) -> Result { check_apply_diagnostic_fix_for_target_file("/main.rs", before, after); } + #[test] + fn test_wrap_return_type_handles_generic_functions() { + let before = r#" + //- /main.rs + use std::{default::Default, result::Result::{self, Ok, Err}}; + + fn div(x: i32) -> Result { + if x == 0 { + return Err(7); + } + T::default() + } + + //- /std/lib.rs + pub mod result { + pub enum Result { Ok(T), Err(E) } + } + pub mod default { + pub trait Default { + fn default() -> Self; + } + } + "#; +// The formatting here is a bit odd due to how the parse_fixture function works in test_utils - +// it strips empty lines and leading whitespace. The important part of this test is that the final +// `x / y` expr is now wrapped in `Ok(..)` + let after = r#"use std::{default::Default, result::Result::{self, Ok, Err}}; +fn div(x: i32) -> Result { + if x == 0 { + return Err(7); + } + Ok(T::default()) +} +"#; + check_apply_diagnostic_fix_for_target_file("/main.rs", before, after); + } + #[test] fn test_wrap_return_type_handles_type_aliases() { let before = r#"