Implement better handling of divergence

Divergence here means that for some reason, the end of a block will not be
reached. We tried to model this just using the never type, but that doesn't work
fully (e.g. in `let x = { loop {}; "foo" };` x should still have type `&str`);
so this introduces a `diverges` flag that the type checker keeps track of, like
rustc does.
This commit is contained in:
Florian Diebold 2020-05-08 17:36:11 +02:00
parent d3eb9d8eaf
commit fe7bf993aa
7 changed files with 200 additions and 23 deletions

View file

@ -210,6 +210,7 @@ struct InferenceContext<'a> {
/// closures, but currently this is the only field that will change there,
/// so it doesn't make sense.
return_ty: Ty,
diverges: Diverges,
}
impl<'a> InferenceContext<'a> {
@ -224,6 +225,7 @@ impl<'a> InferenceContext<'a> {
owner,
body: db.body(owner),
resolver,
diverges: Diverges::Maybe,
}
}
@ -666,6 +668,44 @@ impl Expectation {
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Diverges {
Maybe,
Always,
}
impl Diverges {
fn is_always(self) -> bool {
self == Diverges::Always
}
}
impl std::ops::BitAnd for Diverges {
type Output = Self;
fn bitand(self, other: Self) -> Self {
std::cmp::min(self, other)
}
}
impl std::ops::BitOr for Diverges {
type Output = Self;
fn bitor(self, other: Self) -> Self {
std::cmp::max(self, other)
}
}
impl std::ops::BitAndAssign for Diverges {
fn bitand_assign(&mut self, other: Self) {
*self = *self & other;
}
}
impl std::ops::BitOrAssign for Diverges {
fn bitor_assign(&mut self, other: Self) {
*self = *self | other;
}
}
mod diagnostics {
use hir_def::{expr::ExprId, FunctionId};
use hir_expand::diagnostics::DiagnosticSink;

View file

@ -1,7 +1,7 @@
//! Type inference for expressions.
use std::iter::{repeat, repeat_with};
use std::sync::Arc;
use std::{mem, sync::Arc};
use hir_def::{
builtin_type::Signedness,
@ -21,11 +21,15 @@ use crate::{
Ty, TypeCtor, Uncertain,
};
use super::{BindingMode, Expectation, InferenceContext, InferenceDiagnostic, TypeMismatch};
use super::{BindingMode, Expectation, InferenceContext, InferenceDiagnostic, TypeMismatch, Diverges};
impl<'a> InferenceContext<'a> {
pub(super) fn infer_expr(&mut self, tgt_expr: ExprId, expected: &Expectation) -> Ty {
let ty = self.infer_expr_inner(tgt_expr, expected);
if ty.is_never() {
// Any expression that produces a value of type `!` must have diverged
self.diverges = Diverges::Always;
}
let could_unify = self.unify(&ty, &expected.ty);
if !could_unify {
self.result.type_mismatches.insert(
@ -64,11 +68,18 @@ impl<'a> InferenceContext<'a> {
// if let is desugared to match, so this is always simple if
self.infer_expr(*condition, &Expectation::has_type(Ty::simple(TypeCtor::Bool)));
let condition_diverges = mem::replace(&mut self.diverges, Diverges::Maybe);
let mut both_arms_diverge = Diverges::Always;
let then_ty = self.infer_expr_inner(*then_branch, &expected);
both_arms_diverge &= mem::replace(&mut self.diverges, Diverges::Maybe);
let else_ty = match else_branch {
Some(else_branch) => self.infer_expr_inner(*else_branch, &expected),
None => Ty::unit(),
};
both_arms_diverge &= self.diverges;
self.diverges = condition_diverges | both_arms_diverge;
self.coerce_merge_branch(&then_ty, &else_ty)
}
@ -132,10 +143,12 @@ impl<'a> InferenceContext<'a> {
// infer the body.
self.coerce(&closure_ty, &expected.ty);
let prev_ret_ty = std::mem::replace(&mut self.return_ty, ret_ty.clone());
let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe);
let prev_ret_ty = mem::replace(&mut self.return_ty, ret_ty.clone());
self.infer_expr_coerce(*body, &Expectation::has_type(ret_ty));
self.diverges = prev_diverges;
self.return_ty = prev_ret_ty;
closure_ty
@ -165,7 +178,11 @@ impl<'a> InferenceContext<'a> {
self.table.new_type_var()
};
let matchee_diverges = self.diverges;
let mut all_arms_diverge = Diverges::Always;
for arm in arms {
self.diverges = Diverges::Maybe;
let _pat_ty = self.infer_pat(arm.pat, &input_ty, BindingMode::default());
if let Some(guard_expr) = arm.guard {
self.infer_expr(
@ -175,9 +192,12 @@ impl<'a> InferenceContext<'a> {
}
let arm_ty = self.infer_expr_inner(arm.expr, &expected);
all_arms_diverge &= self.diverges;
result_ty = self.coerce_merge_branch(&result_ty, &arm_ty);
}
self.diverges = matchee_diverges | all_arms_diverge;
result_ty
}
Expr::Path(p) => {
@ -522,7 +542,6 @@ impl<'a> InferenceContext<'a> {
tail: Option<ExprId>,
expected: &Expectation,
) -> Ty {
let mut diverges = false;
for stmt in statements {
match stmt {
Statement::Let { pat, type_ref, initializer } => {
@ -544,9 +563,7 @@ impl<'a> InferenceContext<'a> {
self.infer_pat(*pat, &ty, BindingMode::default());
}
Statement::Expr(expr) => {
if let ty_app!(TypeCtor::Never) = self.infer_expr(*expr, &Expectation::none()) {
diverges = true;
}
self.infer_expr(*expr, &Expectation::none());
}
}
}
@ -554,14 +571,22 @@ impl<'a> InferenceContext<'a> {
let ty = if let Some(expr) = tail {
self.infer_expr_coerce(expr, expected)
} else {
self.coerce(&Ty::unit(), expected.coercion_target());
Ty::unit()
// Citing rustc: if there is no explicit tail expression,
// that is typically equivalent to a tail expression
// of `()` -- except if the block diverges. In that
// case, there is no value supplied from the tail
// expression (assuming there are no other breaks,
// this implies that the type of the block will be
// `!`).
if self.diverges.is_always() {
// we don't even make an attempt at coercion
self.table.new_maybe_never_type_var()
} else {
self.coerce(&Ty::unit(), expected.coercion_target());
Ty::unit()
}
};
if diverges {
Ty::simple(TypeCtor::Never)
} else {
ty
}
ty
}
fn infer_method_call(

View file

@ -730,6 +730,13 @@ impl Ty {
}
}
pub fn is_never(&self) -> bool {
match self {
Ty::Apply(ApplicationTy { ctor: TypeCtor::Never, .. }) => true,
_ => false,
}
}
/// If this is a `dyn Trait` type, this returns the `Trait` part.
pub fn dyn_trait_ref(&self) -> Option<&TraitRef> {
match self {

View file

@ -384,7 +384,7 @@ fn foo() -> u32 {
}
"#, true),
@r###"
17..40 '{ ...own; }': !
17..40 '{ ...own; }': u32
23..37 'return unknown': !
30..37 'unknown': u32
"###
@ -514,7 +514,7 @@ fn foo() {
27..103 '{ ... }': &u32
37..82 'if tru... }': ()
40..44 'true': bool
45..82 '{ ... }': !
45..82 '{ ... }': ()
59..71 'return &1u32': !
66..71 '&1u32': &u32
67..71 '1u32': u32

View file

@ -197,7 +197,7 @@ fn spam() {
!0..6 '1isize': isize
!0..6 '1isize': isize
!0..6 '1isize': isize
54..457 '{ ...!(); }': !
54..457 '{ ...!(); }': ()
88..109 'spam!(...am!())': {unknown}
115..134 'for _ ...!() {}': ()
119..120 '_': {unknown}

View file

@ -1,4 +1,6 @@
use super::type_at;
use insta::assert_snapshot;
use super::{infer_with_mismatches, type_at};
#[test]
fn infer_never1() {
@ -261,3 +263,106 @@ fn test(a: i32) {
);
assert_eq!(t, "f64");
}
#[test]
fn diverging_expression_1() {
let t = infer_with_mismatches(
r#"
//- /main.rs
fn test1() {
let x: u32 = return;
}
fn test2() {
let x: u32 = { return; };
}
fn test3() {
let x: u32 = loop {};
}
fn test4() {
let x: u32 = { loop {} };
}
fn test5() {
let x: u32 = { if true { loop {}; } else { loop {}; } };
}
fn test6() {
let x: u32 = { let y: u32 = { loop {}; }; };
}
"#,
true,
);
assert_snapshot!(t, @r###"
25..53 '{ ...urn; }': ()
35..36 'x': u32
44..50 'return': !
65..98 '{ ...; }; }': ()
75..76 'x': u32
84..95 '{ return; }': u32
86..92 'return': !
110..139 '{ ... {}; }': ()
120..121 'x': u32
129..136 'loop {}': !
134..136 '{}': ()
151..184 '{ ...} }; }': ()
161..162 'x': u32
170..181 '{ loop {} }': u32
172..179 'loop {}': !
177..179 '{}': ()
196..260 '{ ...} }; }': ()
206..207 'x': u32
215..257 '{ if t...}; } }': u32
217..255 'if tru... {}; }': u32
220..224 'true': bool
225..237 '{ loop {}; }': u32
227..234 'loop {}': !
232..234 '{}': ()
243..255 '{ loop {}; }': u32
245..252 'loop {}': !
250..252 '{}': ()
272..324 '{ ...; }; }': ()
282..283 'x': u32
291..321 '{ let ...; }; }': u32
297..298 'y': u32
306..318 '{ loop {}; }': u32
308..315 'loop {}': !
313..315 '{}': ()
"###);
}
#[test]
fn diverging_expression_2() {
let t = infer_with_mismatches(
r#"
//- /main.rs
fn test1() {
// should give type mismatch
let x: u32 = { loop {}; "foo" };
}
"#,
true,
);
assert_snapshot!(t, @r###"
25..98 '{ ..." }; }': ()
68..69 'x': u32
77..95 '{ loop...foo" }': &str
79..86 'loop {}': !
84..86 '{}': ()
88..93 '"foo"': &str
77..95: expected u32, got &str
88..93: expected u32, got &str
"###);
}
#[test]
fn diverging_expression_3_break() {
let t = infer_with_mismatches(
r#"
//- /main.rs
fn test1() {
// should give type mismatch
let x: u32 = { loop { break; } };
}
"#,
true,
);
assert_snapshot!(t, @r###""###);
}

View file

@ -179,7 +179,7 @@ fn test(a: u32, b: isize, c: !, d: &str) {
17..18 'b': isize
27..28 'c': !
33..34 'd': &str
42..121 '{ ...f32; }': !
42..121 '{ ...f32; }': ()
48..49 'a': u32
55..56 'b': isize
62..63 'c': !
@ -935,7 +935,7 @@ fn foo() {
29..33 'true': bool
34..51 '{ ... }': i32
44..45 '1': i32
57..80 '{ ... }': !
57..80 '{ ... }': i32
67..73 'return': !
90..93 '_x2': i32
96..149 'if tru... }': i32
@ -951,7 +951,7 @@ fn foo() {
186..190 'true': bool
194..195 '3': i32
205..206 '_': bool
210..241 '{ ... }': !
210..241 '{ ... }': i32
224..230 'return': !
257..260 '_x4': i32
263..320 'match ... }': i32
@ -1687,7 +1687,7 @@ fn foo() -> u32 {
17..59 '{ ...; }; }': ()
27..28 'x': || -> usize
31..56 '|| -> ...n 1; }': || -> usize
43..56 '{ return 1; }': !
43..56 '{ return 1; }': usize
45..53 'return 1': !
52..53 '1': usize
"###
@ -1706,7 +1706,7 @@ fn foo() -> u32 {
17..48 '{ ...; }; }': ()
27..28 'x': || -> ()
31..45 '|| { return; }': || -> ()
34..45 '{ return; }': !
34..45 '{ return; }': ()
36..42 'return': !
"###
);