diff --git a/crates/hir_ty/src/infer/coerce.rs b/crates/hir_ty/src/infer/coerce.rs index 27f59c8bb8..4d80b4a082 100644 --- a/crates/hir_ty/src/infer/coerce.rs +++ b/crates/hir_ty/src/infer/coerce.rs @@ -45,6 +45,10 @@ impl<'a> InferenceContext<'a> { /// - if we were concerned with lifetime subtyping, we'd need to look for a /// least upper bound. pub(super) fn coerce_merge_branch(&mut self, ty1: &Ty, ty2: &Ty) -> Ty { + let ty1 = self.resolve_ty_shallow(ty1); + let ty1 = ty1.as_ref(); + let ty2 = self.resolve_ty_shallow(ty2); + let ty2 = ty2.as_ref(); // Special case: two function types. Try to coerce both to // pointers to have a chance at getting a match. See // https://github.com/rust-lang/rust/blob/7b805396bf46dce972692a6846ce2ad8481c5f85/src/librustc_typeck/check/coercion.rs#L877-L916 @@ -71,12 +75,17 @@ impl<'a> InferenceContext<'a> { } } - if self.coerce(ty1, ty2) { - ty2.clone() - } else if self.coerce(ty2, ty1) { + // It might not seem like it, but order is important here: ty1 is our + // "previous" type, ty2 is the "new" one being added. If the previous + // type is a type variable and the new one is `!`, trying it the other + // way around first would mean we make the type variable `!`, instead of + // just marking it as possibly diverging. + if self.coerce(ty2, ty1) { ty1.clone() + } else if self.coerce(ty1, ty2) { + ty2.clone() } else { - // FIXME record a type mismatch + // TODO record a type mismatch cov_mark::hit!(coerce_merge_fail_fallback); ty1.clone() } diff --git a/crates/hir_ty/src/tests/coercion.rs b/crates/hir_ty/src/tests/coercion.rs index 67295b663b..bb568ea372 100644 --- a/crates/hir_ty/src/tests/coercion.rs +++ b/crates/hir_ty/src/tests/coercion.rs @@ -873,3 +873,42 @@ fn foo(c: i32) { "#, ) } + +#[test] +fn infer_match_diverging_branch_1() { + check_types( + r#" +enum Result { Ok(T), Err } +fn parse() -> T { loop {} } + +fn test() -> i32 { + let a = match parse() { + Ok(val) => val, + Err => return 0, + }; + a + //^ i32 +} + "#, + ) +} + +#[test] +fn infer_match_diverging_branch_2() { + // same as 1 except for order of branches + check_types( + r#" +enum Result { Ok(T), Err } +fn parse() -> T { loop {} } + +fn test() -> i32 { + let a = match parse() { + Err => return 0, + Ok(val) => val, + }; + a + //^ i32 +} + "#, + ) +}