Handle closure return types

Fixes #2547.
This commit is contained in:
Florian Diebold 2019-12-20 16:41:32 +01:00
parent cfc50ff160
commit 2a8c9100bf
8 changed files with 113 additions and 6 deletions

View file

@ -372,8 +372,9 @@ where
arg_types.push(type_ref); arg_types.push(type_ref);
} }
} }
let ret_type = e.ret_type().and_then(|r| r.type_ref()).map(TypeRef::from_ast);
let body = self.collect_expr_opt(e.body()); let body = self.collect_expr_opt(e.body());
self.alloc_expr(Expr::Lambda { args, arg_types, body }, syntax_ptr) self.alloc_expr(Expr::Lambda { args, arg_types, ret_type, body }, syntax_ptr)
} }
ast::Expr::BinExpr(e) => { ast::Expr::BinExpr(e) => {
let lhs = self.collect_expr_opt(e.lhs()); let lhs = self.collect_expr_opt(e.lhs());

View file

@ -143,6 +143,7 @@ pub enum Expr {
Lambda { Lambda {
args: Vec<PatId>, args: Vec<PatId>,
arg_types: Vec<Option<TypeRef>>, arg_types: Vec<Option<TypeRef>>,
ret_type: Option<TypeRef>,
body: ExprId, body: ExprId,
}, },
Tuple { Tuple {

View file

@ -196,7 +196,12 @@ struct InferenceContext<'a, D: HirDatabase> {
trait_env: Arc<TraitEnvironment>, trait_env: Arc<TraitEnvironment>,
obligations: Vec<Obligation>, obligations: Vec<Obligation>,
result: InferenceResult, result: InferenceResult,
/// The return type of the function being inferred. /// The return type of the function being inferred, or the closure if we're
/// currently within one.
///
/// We might consider using a nested inference context for checking
/// closures, but currently this is the only field that will change there,
/// so it doesn't make sense.
return_ty: Ty, return_ty: Ty,
/// Impls of `CoerceUnsized` used in coercion. /// Impls of `CoerceUnsized` used in coercion.

View file

@ -102,7 +102,7 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
self.infer_expr(*body, &Expectation::has_type(Ty::unit())); self.infer_expr(*body, &Expectation::has_type(Ty::unit()));
Ty::unit() Ty::unit()
} }
Expr::Lambda { body, args, arg_types } => { Expr::Lambda { body, args, ret_type, arg_types } => {
assert_eq!(args.len(), arg_types.len()); assert_eq!(args.len(), arg_types.len());
let mut sig_tys = Vec::new(); let mut sig_tys = Vec::new();
@ -118,7 +118,10 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
} }
// add return type // add return type
let ret_ty = self.table.new_type_var(); let ret_ty = match ret_type {
Some(type_ref) => self.make_ty(type_ref),
None => self.table.new_type_var(),
};
sig_tys.push(ret_ty.clone()); sig_tys.push(ret_ty.clone());
let sig_ty = Ty::apply( let sig_ty = Ty::apply(
TypeCtor::FnPtr { num_args: sig_tys.len() as u16 - 1 }, TypeCtor::FnPtr { num_args: sig_tys.len() as u16 - 1 },
@ -134,7 +137,12 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
// infer the body. // infer the body.
self.coerce(&closure_ty, &expected.ty); self.coerce(&closure_ty, &expected.ty);
self.infer_expr(*body, &Expectation::has_type(ret_ty)); let prev_ret_ty = std::mem::replace(&mut self.return_ty, ret_ty.clone());
self.infer_expr_coerce(*body, &Expectation::has_type(ret_ty));
self.return_ty = prev_ret_ty;
closure_ty closure_ty
} }
Expr::Call { callee, args } => { Expr::Call { callee, args } => {
@ -192,6 +200,9 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
Expr::Return { expr } => { Expr::Return { expr } => {
if let Some(expr) = expr { if let Some(expr) = expr {
self.infer_expr_coerce(*expr, &Expectation::has_type(self.return_ty.clone())); self.infer_expr_coerce(*expr, &Expectation::has_type(self.return_ty.clone()));
} else {
let unit = Ty::unit();
self.coerce(&unit, &self.return_ty.clone());
} }
Ty::simple(TypeCtor::Never) Ty::simple(TypeCtor::Never)
} }

View file

@ -440,3 +440,34 @@ fn test() {
"### "###
); );
} }
#[test]
fn closure_return_coerce() {
assert_snapshot!(
infer_with_mismatches(r#"
fn foo() {
let x = || {
if true {
return &1u32;
}
&&1u32
};
}
"#, true),
@r###"
[10; 106) '{ ... }; }': ()
[20; 21) 'x': || -> &u32
[24; 103) '|| { ... }': || -> &u32
[27; 103) '{ ... }': &u32
[37; 82) 'if tru... }': ()
[40; 44) 'true': bool
[45; 82) '{ ... }': !
[59; 71) 'return &1u32': !
[66; 71) '&1u32': &u32
[67; 71) '1u32': u32
[91; 97) '&&1u32': &&u32
[92; 97) '&1u32': &u32
[93; 97) '1u32': u32
"###
);
}

View file

@ -1606,3 +1606,58 @@ fn main() {
); );
assert_eq!(t, "u32"); assert_eq!(t, "u32");
} }
#[test]
fn closure_return() {
assert_snapshot!(
infer(r#"
fn foo() -> u32 {
let x = || -> usize { return 1; };
}
"#),
@r###"
[17; 59) '{ ...; }; }': ()
[27; 28) 'x': || -> usize
[31; 56) '|| -> ...n 1; }': || -> usize
[43; 56) '{ return 1; }': !
[45; 53) 'return 1': !
[52; 53) '1': usize
"###
);
}
#[test]
fn closure_return_unit() {
assert_snapshot!(
infer(r#"
fn foo() -> u32 {
let x = || { return; };
}
"#),
@r###"
[17; 48) '{ ...; }; }': ()
[27; 28) 'x': || -> ()
[31; 45) '|| { return; }': || -> ()
[34; 45) '{ return; }': !
[36; 42) 'return': !
"###
);
}
#[test]
fn closure_return_inferred() {
assert_snapshot!(
infer(r#"
fn foo() -> u32 {
let x = || { "test" };
}
"#),
@r###"
[17; 47) '{ ..." }; }': ()
[27; 28) 'x': || -> &str
[31; 44) '|| { "test" }': || -> &str
[34; 44) '{ "test" }': &str
[36; 42) '"test"': &str
"###
);
}

View file

@ -1426,6 +1426,9 @@ impl LambdaExpr {
pub fn param_list(&self) -> Option<ParamList> { pub fn param_list(&self) -> Option<ParamList> {
AstChildren::new(&self.syntax).next() AstChildren::new(&self.syntax).next()
} }
pub fn ret_type(&self) -> Option<RetType> {
AstChildren::new(&self.syntax).next()
}
pub fn body(&self) -> Option<Expr> { pub fn body(&self) -> Option<Expr> {
AstChildren::new(&self.syntax).next() AstChildren::new(&self.syntax).next()
} }

View file

@ -426,7 +426,7 @@ Grammar(
"PathExpr": (options: ["Path"]), "PathExpr": (options: ["Path"]),
"LambdaExpr": ( "LambdaExpr": (
options: [ options: [
"ParamList", "ParamList", "RetType",
["body", "Expr"], ["body", "Expr"],
] ]
), ),