diff --git a/crates/ra_hir_def/src/body/lower.rs b/crates/ra_hir_def/src/body/lower.rs index 853e17bae7..be5d17d85e 100644 --- a/crates/ra_hir_def/src/body/lower.rs +++ b/crates/ra_hir_def/src/body/lower.rs @@ -372,8 +372,9 @@ where arg_types.push(type_ref); } } + let ret_type = e.ret_type().and_then(|r| r.type_ref()).map(TypeRef::from_ast); let body = self.collect_expr_opt(e.body()); - self.alloc_expr(Expr::Lambda { args, arg_types, body }, syntax_ptr) + self.alloc_expr(Expr::Lambda { args, arg_types, ret_type, body }, syntax_ptr) } ast::Expr::BinExpr(e) => { let lhs = self.collect_expr_opt(e.lhs()); diff --git a/crates/ra_hir_def/src/expr.rs b/crates/ra_hir_def/src/expr.rs index 6fad80a8d4..a75ef9970d 100644 --- a/crates/ra_hir_def/src/expr.rs +++ b/crates/ra_hir_def/src/expr.rs @@ -143,6 +143,7 @@ pub enum Expr { Lambda { args: Vec, arg_types: Vec>, + ret_type: Option, body: ExprId, }, Tuple { diff --git a/crates/ra_hir_ty/src/infer.rs b/crates/ra_hir_ty/src/infer.rs index bbbc391c4f..9f2ed830ed 100644 --- a/crates/ra_hir_ty/src/infer.rs +++ b/crates/ra_hir_ty/src/infer.rs @@ -196,7 +196,12 @@ struct InferenceContext<'a, D: HirDatabase> { trait_env: Arc, obligations: Vec, result: InferenceResult, - /// The return type of the function being inferred. + /// The return type of the function being inferred, or the closure if we're + /// currently within one. + /// + /// We might consider using a nested inference context for checking + /// closures, but currently this is the only field that will change there, + /// so it doesn't make sense. return_ty: Ty, /// Impls of `CoerceUnsized` used in coercion. diff --git a/crates/ra_hir_ty/src/infer/expr.rs b/crates/ra_hir_ty/src/infer/expr.rs index 8be5679177..253332c30e 100644 --- a/crates/ra_hir_ty/src/infer/expr.rs +++ b/crates/ra_hir_ty/src/infer/expr.rs @@ -102,7 +102,7 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { self.infer_expr(*body, &Expectation::has_type(Ty::unit())); Ty::unit() } - Expr::Lambda { body, args, arg_types } => { + Expr::Lambda { body, args, ret_type, arg_types } => { assert_eq!(args.len(), arg_types.len()); let mut sig_tys = Vec::new(); @@ -118,7 +118,10 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { } // add return type - let ret_ty = self.table.new_type_var(); + let ret_ty = match ret_type { + Some(type_ref) => self.make_ty(type_ref), + None => self.table.new_type_var(), + }; sig_tys.push(ret_ty.clone()); let sig_ty = Ty::apply( TypeCtor::FnPtr { num_args: sig_tys.len() as u16 - 1 }, @@ -134,7 +137,12 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { // infer the body. self.coerce(&closure_ty, &expected.ty); - self.infer_expr(*body, &Expectation::has_type(ret_ty)); + let prev_ret_ty = std::mem::replace(&mut self.return_ty, ret_ty.clone()); + + self.infer_expr_coerce(*body, &Expectation::has_type(ret_ty)); + + self.return_ty = prev_ret_ty; + closure_ty } Expr::Call { callee, args } => { @@ -192,6 +200,9 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { Expr::Return { expr } => { if let Some(expr) = expr { self.infer_expr_coerce(*expr, &Expectation::has_type(self.return_ty.clone())); + } else { + let unit = Ty::unit(); + self.coerce(&unit, &self.return_ty.clone()); } Ty::simple(TypeCtor::Never) } diff --git a/crates/ra_hir_ty/src/tests/coercion.rs b/crates/ra_hir_ty/src/tests/coercion.rs index ac9e3872a5..33d6ca4034 100644 --- a/crates/ra_hir_ty/src/tests/coercion.rs +++ b/crates/ra_hir_ty/src/tests/coercion.rs @@ -440,3 +440,34 @@ fn test() { "### ); } + +#[test] +fn closure_return_coerce() { + assert_snapshot!( + infer_with_mismatches(r#" +fn foo() { + let x = || { + if true { + return &1u32; + } + &&1u32 + }; +} +"#, true), + @r###" + [10; 106) '{ ... }; }': () + [20; 21) 'x': || -> &u32 + [24; 103) '|| { ... }': || -> &u32 + [27; 103) '{ ... }': &u32 + [37; 82) 'if tru... }': () + [40; 44) 'true': bool + [45; 82) '{ ... }': ! + [59; 71) 'return &1u32': ! + [66; 71) '&1u32': &u32 + [67; 71) '1u32': u32 + [91; 97) '&&1u32': &&u32 + [92; 97) '&1u32': &u32 + [93; 97) '1u32': u32 + "### + ); +} diff --git a/crates/ra_hir_ty/src/tests/simple.rs b/crates/ra_hir_ty/src/tests/simple.rs index 18976c9aee..6fe647a5e3 100644 --- a/crates/ra_hir_ty/src/tests/simple.rs +++ b/crates/ra_hir_ty/src/tests/simple.rs @@ -1606,3 +1606,58 @@ fn main() { ); assert_eq!(t, "u32"); } + +#[test] +fn closure_return() { + assert_snapshot!( + infer(r#" +fn foo() -> u32 { + let x = || -> usize { return 1; }; +} +"#), + @r###" + [17; 59) '{ ...; }; }': () + [27; 28) 'x': || -> usize + [31; 56) '|| -> ...n 1; }': || -> usize + [43; 56) '{ return 1; }': ! + [45; 53) 'return 1': ! + [52; 53) '1': usize + "### + ); +} + +#[test] +fn closure_return_unit() { + assert_snapshot!( + infer(r#" +fn foo() -> u32 { + let x = || { return; }; +} +"#), + @r###" + [17; 48) '{ ...; }; }': () + [27; 28) 'x': || -> () + [31; 45) '|| { return; }': || -> () + [34; 45) '{ return; }': ! + [36; 42) 'return': ! + "### + ); +} + +#[test] +fn closure_return_inferred() { + assert_snapshot!( + infer(r#" +fn foo() -> u32 { + let x = || { "test" }; +} +"#), + @r###" + [17; 47) '{ ..." }; }': () + [27; 28) 'x': || -> &str + [31; 44) '|| { "test" }': || -> &str + [34; 44) '{ "test" }': &str + [36; 42) '"test"': &str + "### + ); +} diff --git a/crates/ra_syntax/src/ast/generated.rs b/crates/ra_syntax/src/ast/generated.rs index 9dd6bd3eac..8d65e2e08e 100644 --- a/crates/ra_syntax/src/ast/generated.rs +++ b/crates/ra_syntax/src/ast/generated.rs @@ -1426,6 +1426,9 @@ impl LambdaExpr { pub fn param_list(&self) -> Option { AstChildren::new(&self.syntax).next() } + pub fn ret_type(&self) -> Option { + AstChildren::new(&self.syntax).next() + } pub fn body(&self) -> Option { AstChildren::new(&self.syntax).next() } diff --git a/crates/ra_syntax/src/grammar.ron b/crates/ra_syntax/src/grammar.ron index 9ffa9095bc..a228fa9d6e 100644 --- a/crates/ra_syntax/src/grammar.ron +++ b/crates/ra_syntax/src/grammar.ron @@ -426,7 +426,7 @@ Grammar( "PathExpr": (options: ["Path"]), "LambdaExpr": ( options: [ - "ParamList", + "ParamList", "RetType", ["body", "Expr"], ] ),