Implement type variables

This will really become necessary when we implement generics, but even now, it
allows us to reason 'backwards' to infer types of expressions that we didn't
understand for some reason.

We use ena, the union-find implementation extracted from rustc, to keep track of
type variables.
This commit is contained in:
Florian Diebold 2018-12-26 17:00:42 +01:00
parent f3f073804c
commit cfa1de72eb
8 changed files with 395 additions and 118 deletions

10
Cargo.lock generated
View file

@ -271,6 +271,14 @@ name = "either"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "ena"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "error-chain"
version = "0.12.0"
@ -737,6 +745,7 @@ name = "ra_hir"
version = "0.1.0"
dependencies = [
"arrayvec 0.4.10 (registry+https://github.com/rust-lang/crates.io-index)",
"ena 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)",
"flexi_logger 0.10.3 (registry+https://github.com/rust-lang/crates.io-index)",
"id-arena 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
@ -1546,6 +1555,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum digest 0.7.6 (registry+https://github.com/rust-lang/crates.io-index)" = "03b072242a8cbaf9c145665af9d250c59af3b958f83ed6824e13533cf76d5b90"
"checksum drop_bomb 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "69b26e475fd29098530e709294e94e661974c851aed42512793f120fed4e199f"
"checksum either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3be565ca5c557d7f59e7cfcf1844f9e3033650c929c6566f511e8005f205c1d0"
"checksum ena 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "f56c93cc076508c549d9bb747f79aa9b4eb098be7b8cad8830c3137ef52d1e00"
"checksum error-chain 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "07e791d3be96241c77c43846b665ef1384606da2cd2a48730abe606a12906e02"
"checksum failure 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6dd377bcc1b1b7ce911967e3ec24fa19c3224394ec05b54aa7b083d498341ac7"
"checksum failure_derive 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "64c2d913fe8ed3b6c6518eedf4538255b989945c14c2a7d5cbff62a5e2120596"

View file

@ -12,6 +12,7 @@ salsa = "0.9.0"
rustc-hash = "1.0"
parking_lot = "0.7.0"
id-arena = "2.0"
ena = "0.11"
ra_syntax = { path = "../ra_syntax" }
ra_editor = { path = "../ra_editor" }
ra_db = { path = "../ra_db" }

View file

@ -3,10 +3,11 @@ mod primitive;
mod tests;
use std::sync::Arc;
use std::fmt;
use std::{fmt, mem};
use log;
use rustc_hash::{FxHashMap};
use rustc_hash::FxHashMap;
use ena::unify::{InPlaceUnificationTable, UnifyKey, UnifyValue, NoError};
use ra_db::{LocalSyntaxPtr, Cancelable};
use ra_syntax::{
@ -17,10 +18,89 @@ use ra_syntax::{
use crate::{
Def, DefId, FnScopes, Module, Function, Struct, Enum, Path, Name, AsName,
db::HirDatabase,
adt::VariantData,
type_ref::{TypeRef, Mutability},
};
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub struct TypeVarId(u32);
impl UnifyKey for TypeVarId {
type Value = TypeVarValue;
fn index(&self) -> u32 {
self.0
}
fn from_index(i: u32) -> Self {
TypeVarId(i)
}
fn tag() -> &'static str {
"TypeVarId"
}
}
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum TypeVarValue {
Known(Ty),
Unknown,
}
impl TypeVarValue {
pub fn known(&self) -> Option<&Ty> {
match self {
TypeVarValue::Known(ty) => Some(ty),
TypeVarValue::Unknown => None,
}
}
}
impl UnifyValue for TypeVarValue {
type Error = NoError;
fn unify_values(value1: &Self, value2: &Self) -> Result<Self, NoError> {
match (value1, value2) {
// We should never equate two type variables, both of which have
// known types. Instead, we recursively equate those types.
(TypeVarValue::Known(..), TypeVarValue::Known(..)) => {
panic!("equating two type variables, both of which have known types")
}
// If one side is known, prefer that one.
(TypeVarValue::Known(..), TypeVarValue::Unknown) => Ok(value1.clone()),
(TypeVarValue::Unknown, TypeVarValue::Known(..)) => Ok(value2.clone()),
(TypeVarValue::Unknown, TypeVarValue::Unknown) => Ok(TypeVarValue::Unknown),
}
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum InferTy {
TypeVar(TypeVarId),
// later we'll have IntVar and FloatVar as well
}
/// When inferring an expression, we propagate downward whatever type hint we
/// are able in the form of an `Expectation`.
#[derive(Clone, PartialEq, Eq, Debug)]
struct Expectation {
ty: Ty,
// TODO: In some cases, we need to be aware whether the expectation is that
// the type match exactly what we passed, or whether it just needs to be
// coercible to the expected type. See Expectation::rvalue_hint in rustc.
}
impl Expectation {
fn has_type(ty: Ty) -> Self {
Expectation { ty }
}
fn none() -> Self {
Expectation { ty: Ty::Unknown }
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum Ty {
/// The primitive boolean type. Written as `bool`.
@ -75,23 +155,22 @@ pub enum Ty {
// A trait, defined with `dyn trait`.
// Dynamic(),
/// The anonymous type of a closure. Used to represent the type of
/// `|a| a`.
// The anonymous type of a closure. Used to represent the type of
// `|a| a`.
// Closure(DefId, ClosureSubsts<'tcx>),
/// The anonymous type of a generator. Used to represent the type of
/// `|a| yield a`.
// The anonymous type of a generator. Used to represent the type of
// `|a| yield a`.
// Generator(DefId, GeneratorSubsts<'tcx>, hir::GeneratorMovability),
/// A type representin the types stored inside a generator.
/// This should only appear in GeneratorInteriors.
// A type representin the types stored inside a generator.
// This should only appear in GeneratorInteriors.
// GeneratorWitness(Binder<&'tcx List<Ty<'tcx>>>),
/// The never type `!`
Never,
/// A tuple type. For example, `(i32, bool)`.
Tuple(Vec<Ty>),
Tuple(Arc<[Ty]>),
// The projection of an associated type. For example,
// `<T as Trait<..>>::N`.pub
@ -106,14 +185,14 @@ pub enum Ty {
// A type parameter; for example, `T` in `fn f<T>(x: T) {}
// Param(ParamTy),
/// A type variable used during type checking. Not to be confused with a
/// type parameter.
Infer(InferTy),
// A placeholder type - universally quantified higher-ranked type.
// Placeholder(ty::PlaceholderType),
// A type variable used during type checking.
// Infer(InferTy),
/// A placeholder for a type which could not be computed; this is
/// propagated to avoid useless error messages.
/// A placeholder for a type which could not be computed; this is propagated
/// to avoid useless error messages. Doubles as a placeholder where type
/// variables are inserted before type checking, since we want to try to
/// infer a better type here anyway.
Unknown,
}
@ -137,8 +216,8 @@ impl Ty {
let inner_tys = inner
.iter()
.map(|tr| Ty::from_hir(db, module, tr))
.collect::<Cancelable<_>>()?;
Ty::Tuple(inner_tys)
.collect::<Cancelable<Vec<_>>>()?;
Ty::Tuple(inner_tys.into())
}
TypeRef::Path(path) => Ty::from_hir_path(db, module, path)?,
TypeRef::RawPtr(inner, mutability) => {
@ -154,7 +233,7 @@ impl Ty {
let inner_ty = Ty::from_hir(db, module, inner)?;
Ty::Ref(Arc::new(inner_ty), *mutability)
}
TypeRef::Placeholder => Ty::Unknown, // TODO
TypeRef::Placeholder => Ty::Unknown,
TypeRef::Fn(params) => {
let mut inner_tys = params
.iter()
@ -217,7 +296,41 @@ impl Ty {
}
pub fn unit() -> Self {
Ty::Tuple(Vec::new())
Ty::Tuple(Arc::new([]))
}
fn walk_mut(&mut self, f: &mut impl FnMut(&mut Ty)) {
f(self);
match self {
Ty::Slice(t) => Arc::make_mut(t).walk_mut(f),
Ty::RawPtr(t, _) => Arc::make_mut(t).walk_mut(f),
Ty::Ref(t, _) => Arc::make_mut(t).walk_mut(f),
Ty::Tuple(ts) => {
// Without an Arc::make_mut_slice, we can't avoid the clone here:
let mut v: Vec<_> = ts.iter().cloned().collect();
for t in &mut v {
t.walk_mut(f);
}
*ts = v.into();
}
Ty::FnPtr(sig) => {
let sig_mut = Arc::make_mut(sig);
for input in &mut sig_mut.input {
input.walk_mut(f);
}
sig_mut.output.walk_mut(f);
}
Ty::Adt { .. } => {} // need to walk type parameters later
_ => {}
}
}
fn fold(mut self, f: &mut impl FnMut(Ty) -> Ty) -> Ty {
self.walk_mut(&mut |ty_mut| {
let ty = mem::replace(ty_mut, Ty::Unknown);
*ty_mut = f(ty);
});
self
}
}
@ -236,7 +349,7 @@ impl fmt::Display for Ty {
Ty::Never => write!(f, "!"),
Ty::Tuple(ts) => {
write!(f, "(")?;
for t in ts {
for t in ts.iter() {
write!(f, "{},", t)?;
}
write!(f, ")")
@ -250,6 +363,7 @@ impl fmt::Display for Ty {
}
Ty::Adt { name, .. } => write!(f, "{}", name),
Ty::Unknown => write!(f, "[unknown]"),
Ty::Infer(..) => write!(f, "_"),
}
}
}
@ -342,7 +456,7 @@ pub struct InferenceContext<'a, D: HirDatabase> {
db: &'a D,
scopes: Arc<FnScopes>,
module: Module,
// TODO unification tables...
var_unification_table: InPlaceUnificationTable<TypeVarId>,
type_of: FxHashMap<LocalSyntaxPtr, Ty>,
}
@ -350,33 +464,116 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
fn new(db: &'a D, scopes: Arc<FnScopes>, module: Module) -> Self {
InferenceContext {
type_of: FxHashMap::default(),
var_unification_table: InPlaceUnificationTable::new(),
db,
scopes,
module,
}
}
fn resolve_all(mut self) -> InferenceResult {
let mut types = mem::replace(&mut self.type_of, FxHashMap::default());
for ty in types.values_mut() {
let resolved = self.resolve_ty_completely(mem::replace(ty, Ty::Unknown));
*ty = resolved;
}
InferenceResult { type_of: types }
}
fn write_ty(&mut self, node: SyntaxNodeRef, ty: Ty) {
self.type_of.insert(LocalSyntaxPtr::new(node), ty);
}
fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> Option<Ty> {
if *ty1 == Ty::Unknown {
return Some(ty2.clone());
fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool {
match (ty1, ty2) {
(Ty::Unknown, ..) => true,
(.., Ty::Unknown) => true,
(Ty::Bool, _)
| (Ty::Str, _)
| (Ty::Never, _)
| (Ty::Char, _)
| (Ty::Int(..), Ty::Int(..))
| (Ty::Uint(..), Ty::Uint(..))
| (Ty::Float(..), Ty::Float(..)) => ty1 == ty2,
(
Ty::Adt {
def_id: def_id1, ..
},
Ty::Adt {
def_id: def_id2, ..
},
) if def_id1 == def_id2 => true,
(Ty::Slice(t1), Ty::Slice(t2)) => self.unify(t1, t2),
(Ty::RawPtr(t1, m1), Ty::RawPtr(t2, m2)) if m1 == m2 => self.unify(t1, t2),
(Ty::Ref(t1, m1), Ty::Ref(t2, m2)) if m1 == m2 => self.unify(t1, t2),
(Ty::FnPtr(sig1), Ty::FnPtr(sig2)) if sig1 == sig2 => true,
(Ty::Tuple(ts1), Ty::Tuple(ts2)) if ts1.len() == ts2.len() => ts1
.iter()
.zip(ts2.iter())
.all(|(t1, t2)| self.unify(t1, t2)),
(Ty::Infer(InferTy::TypeVar(tv1)), Ty::Infer(InferTy::TypeVar(tv2))) => {
self.var_unification_table.union(*tv1, *tv2);
true
}
if *ty2 == Ty::Unknown {
return Some(ty1.clone());
(Ty::Infer(InferTy::TypeVar(tv)), other) | (other, Ty::Infer(InferTy::TypeVar(tv))) => {
self.var_unification_table
.union_value(*tv, TypeVarValue::Known(other.clone()));
true
}
if ty1 == ty2 {
return Some(ty1.clone());
_ => false,
}
// TODO implement actual unification
return None;
}
fn unify_with_coercion(&mut self, ty1: &Ty, ty2: &Ty) -> Option<Ty> {
// TODO implement coercion
self.unify(ty1, ty2)
fn new_type_var(&mut self) -> Ty {
Ty::Infer(InferTy::TypeVar(
self.var_unification_table.new_key(TypeVarValue::Unknown),
))
}
/// Replaces Ty::Unknown by a new type var, so we can maybe still infer it.
fn insert_type_vars_shallow(&mut self, ty: Ty) -> Ty {
match ty {
Ty::Unknown => self.new_type_var(),
_ => ty,
}
}
fn insert_type_vars(&mut self, ty: Ty) -> Ty {
ty.fold(&mut |ty| self.insert_type_vars_shallow(ty))
}
/// Resolves the type as far as currently possible, replacing type variables
/// by their known types. All types returned by the infer_* functions should
/// be resolved as far as possible, i.e. contain no type variables with
/// known type.
fn resolve_ty_as_possible(&mut self, ty: Ty) -> Ty {
ty.fold(&mut |ty| match ty {
Ty::Infer(InferTy::TypeVar(tv)) => {
if let Some(known_ty) = self.var_unification_table.probe_value(tv).known() {
// known_ty may contain other variables that are known by now
self.resolve_ty_as_possible(known_ty.clone())
} else {
Ty::Infer(InferTy::TypeVar(tv))
}
}
_ => ty,
})
}
/// Resolves the type completely; type variables without known type are
/// replaced by Ty::Unknown.
fn resolve_ty_completely(&mut self, ty: Ty) -> Ty {
ty.fold(&mut |ty| match ty {
Ty::Infer(InferTy::TypeVar(tv)) => {
if let Some(known_ty) = self.var_unification_table.probe_value(tv).known() {
// known_ty may contain other variables that are known by now
self.resolve_ty_completely(known_ty.clone())
} else {
Ty::Unknown
}
}
_ => ty,
})
}
fn infer_path_expr(&mut self, expr: ast::PathExpr) -> Cancelable<Option<Ty>> {
@ -387,21 +584,19 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
let name = ctry!(ast_path.segment().and_then(|s| s.name_ref()));
if let Some(scope_entry) = self.scopes.resolve_local_name(name) {
let ty = ctry!(self.type_of.get(&scope_entry.ptr()));
return Ok(Some(ty.clone()));
let ty = self.resolve_ty_as_possible(ty.clone());
return Ok(Some(ty));
};
};
// resolve in module
let resolved = ctry!(self.module.resolve_path(self.db, &path)?.take_values());
let ty = self.db.type_for_def(resolved)?;
// TODO we will need to add type variables for type parameters etc. here
let ty = self.insert_type_vars(ty);
Ok(Some(ty))
}
fn resolve_variant(
&self,
path: Option<ast::Path>,
) -> Cancelable<(Ty, Option<Arc<VariantData>>)> {
fn resolve_variant(&self, path: Option<ast::Path>) -> Cancelable<(Ty, Option<DefId>)> {
let path = if let Some(path) = path.and_then(Path::from_ast) {
path
} else {
@ -414,102 +609,116 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
};
Ok(match def_id.resolve(self.db)? {
Def::Struct(s) => {
let struct_data = self.db.struct_data(def_id)?;
let ty = type_for_struct(self.db, s)?;
(ty, Some(struct_data.variant_data().clone()))
(ty, Some(def_id))
}
_ => (Ty::Unknown, None),
})
}
fn infer_expr_opt(&mut self, expr: Option<ast::Expr>) -> Cancelable<Ty> {
fn infer_expr_opt(
&mut self,
expr: Option<ast::Expr>,
expected: &Expectation,
) -> Cancelable<Ty> {
if let Some(e) = expr {
self.infer_expr(e)
self.infer_expr(e, expected)
} else {
Ok(Ty::Unknown)
}
}
fn infer_expr(&mut self, expr: ast::Expr) -> Cancelable<Ty> {
fn infer_expr(&mut self, expr: ast::Expr, expected: &Expectation) -> Cancelable<Ty> {
let ty = match expr {
ast::Expr::IfExpr(e) => {
if let Some(condition) = e.condition() {
// TODO if no pat, this should be bool
self.infer_expr_opt(condition.expr())?;
let expected = if condition.pat().is_none() {
Expectation::has_type(Ty::Bool)
} else {
Expectation::none()
};
self.infer_expr_opt(condition.expr(), &expected)?;
// TODO write type for pat
};
let if_ty = self.infer_block_opt(e.then_branch())?;
let else_ty = self.infer_block_opt(e.else_branch())?;
if let Some(ty) = self.unify(&if_ty, &else_ty) {
ty
let if_ty = self.infer_block_opt(e.then_branch(), expected)?;
if let Some(else_branch) = e.else_branch() {
self.infer_block(else_branch, expected)?;
} else {
// TODO report diagnostic
Ty::Unknown
// no else branch -> unit
self.unify(&expected.ty, &Ty::unit()); // actually coerce
}
if_ty
}
ast::Expr::BlockExpr(e) => self.infer_block_opt(e.block())?,
ast::Expr::BlockExpr(e) => self.infer_block_opt(e.block(), expected)?,
ast::Expr::LoopExpr(e) => {
self.infer_block_opt(e.loop_body())?;
self.infer_block_opt(e.loop_body(), &Expectation::has_type(Ty::unit()))?;
// TODO never, or the type of the break param
Ty::Unknown
}
ast::Expr::WhileExpr(e) => {
if let Some(condition) = e.condition() {
// TODO if no pat, this should be bool
self.infer_expr_opt(condition.expr())?;
let expected = if condition.pat().is_none() {
Expectation::has_type(Ty::Bool)
} else {
Expectation::none()
};
self.infer_expr_opt(condition.expr(), &expected)?;
// TODO write type for pat
};
self.infer_block_opt(e.loop_body())?;
self.infer_block_opt(e.loop_body(), &Expectation::has_type(Ty::unit()))?;
// TODO always unit?
Ty::Unknown
Ty::unit()
}
ast::Expr::ForExpr(e) => {
let _iterable_ty = self.infer_expr_opt(e.iterable());
let _iterable_ty = self.infer_expr_opt(e.iterable(), &Expectation::none());
if let Some(_pat) = e.pat() {
// TODO write type for pat
}
self.infer_block_opt(e.loop_body())?;
self.infer_block_opt(e.loop_body(), &Expectation::has_type(Ty::unit()))?;
// TODO always unit?
Ty::Unknown
Ty::unit()
}
ast::Expr::LambdaExpr(e) => {
let _body_ty = self.infer_expr_opt(e.body())?;
let _body_ty = self.infer_expr_opt(e.body(), &Expectation::none())?;
Ty::Unknown
}
ast::Expr::CallExpr(e) => {
let callee_ty = self.infer_expr_opt(e.expr())?;
if let Some(arg_list) = e.arg_list() {
for arg in arg_list.args() {
// TODO unify / expect argument type
self.infer_expr(arg)?;
}
}
match callee_ty {
Ty::FnPtr(sig) => sig.output.clone(),
let callee_ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
let (arg_tys, ret_ty) = match &callee_ty {
Ty::FnPtr(sig) => (&sig.input[..], sig.output.clone()),
_ => {
// not callable
// TODO report an error?
Ty::Unknown
(&[][..], Ty::Unknown)
}
};
if let Some(arg_list) = e.arg_list() {
for (i, arg) in arg_list.args().enumerate() {
self.infer_expr(
arg,
&Expectation::has_type(arg_tys.get(i).cloned().unwrap_or(Ty::Unknown)),
)?;
}
}
ret_ty
}
ast::Expr::MethodCallExpr(e) => {
let _receiver_ty = self.infer_expr_opt(e.expr())?;
let _receiver_ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
if let Some(arg_list) = e.arg_list() {
for arg in arg_list.args() {
// TODO unify / expect argument type
self.infer_expr(arg)?;
self.infer_expr(arg, &Expectation::none())?;
}
}
Ty::Unknown
}
ast::Expr::MatchExpr(e) => {
let _ty = self.infer_expr_opt(e.expr())?;
let _ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
if let Some(match_arm_list) = e.match_arm_list() {
for arm in match_arm_list.arms() {
// TODO type the bindings in pat
// TODO type the guard
let _ty = self.infer_expr_opt(arm.expr())?;
let _ty = self.infer_expr_opt(arm.expr(), &Expectation::none())?;
}
// TODO unify all the match arm types
Ty::Unknown
@ -522,10 +731,10 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
ast::Expr::PathExpr(e) => self.infer_path_expr(e)?.unwrap_or(Ty::Unknown),
ast::Expr::ContinueExpr(_e) => Ty::Never,
ast::Expr::BreakExpr(_e) => Ty::Never,
ast::Expr::ParenExpr(e) => self.infer_expr_opt(e.expr())?,
ast::Expr::ParenExpr(e) => self.infer_expr_opt(e.expr(), expected)?,
ast::Expr::Label(_e) => Ty::Unknown,
ast::Expr::ReturnExpr(e) => {
self.infer_expr_opt(e.expr())?;
self.infer_expr_opt(e.expr(), &Expectation::none())?;
Ty::Never
}
ast::Expr::MatchArmList(_) | ast::Expr::MatchArm(_) | ast::Expr::MatchGuard(_) => {
@ -533,11 +742,16 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
Ty::Unknown
}
ast::Expr::StructLit(e) => {
let (ty, _variant_data) = self.resolve_variant(e.path())?;
let (ty, def_id) = self.resolve_variant(e.path())?;
if let Some(nfl) = e.named_field_list() {
for field in nfl.fields() {
// TODO unify with / expect field type
self.infer_expr_opt(field.expr())?;
let field_ty = if let (Some(def_id), Some(nr)) = (def_id, field.name_ref())
{
self.db.type_for_field(def_id, nr.as_name())?
} else {
Ty::Unknown
};
self.infer_expr_opt(field.expr(), &Expectation::has_type(field_ty))?;
}
}
ty
@ -548,9 +762,9 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
}
ast::Expr::IndexExpr(_e) => Ty::Unknown,
ast::Expr::FieldExpr(e) => {
let receiver_ty = self.infer_expr_opt(e.expr())?;
let receiver_ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
if let Some(nr) = e.name_ref() {
match receiver_ty {
let ty = match receiver_ty {
Ty::Tuple(fields) => {
let i = nr.text().parse::<usize>().ok();
i.and_then(|i| fields.get(i).cloned())
@ -558,29 +772,32 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
}
Ty::Adt { def_id, .. } => self.db.type_for_field(def_id, nr.as_name())?,
_ => Ty::Unknown,
}
};
self.insert_type_vars(ty)
} else {
Ty::Unknown
}
}
ast::Expr::TryExpr(e) => {
let _inner_ty = self.infer_expr_opt(e.expr())?;
let _inner_ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
Ty::Unknown
}
ast::Expr::CastExpr(e) => {
let _inner_ty = self.infer_expr_opt(e.expr())?;
let _inner_ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
let cast_ty = Ty::from_ast_opt(self.db, &self.module, e.type_ref())?;
let cast_ty = self.insert_type_vars(cast_ty);
// TODO do the coercion...
cast_ty
}
ast::Expr::RefExpr(e) => {
let inner_ty = self.infer_expr_opt(e.expr())?;
// TODO pass the expectation down
let inner_ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
let m = Mutability::from_mutable(e.is_mut());
// TODO reference coercions etc.
Ty::Ref(Arc::new(inner_ty), m)
}
ast::Expr::PrefixExpr(e) => {
let inner_ty = self.infer_expr_opt(e.expr())?;
let inner_ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
match e.op() {
Some(PrefixOp::Deref) => {
match inner_ty {
@ -598,28 +815,34 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
ast::Expr::BinExpr(_e) => Ty::Unknown,
ast::Expr::Literal(_e) => Ty::Unknown,
};
// use a new type variable if we got Ty::Unknown here
let ty = self.insert_type_vars_shallow(ty);
self.unify(&ty, &expected.ty);
self.write_ty(expr.syntax(), ty.clone());
Ok(ty)
}
fn infer_block_opt(&mut self, node: Option<ast::Block>) -> Cancelable<Ty> {
fn infer_block_opt(
&mut self,
node: Option<ast::Block>,
expected: &Expectation,
) -> Cancelable<Ty> {
if let Some(b) = node {
self.infer_block(b)
self.infer_block(b, expected)
} else {
Ok(Ty::Unknown)
}
}
fn infer_block(&mut self, node: ast::Block) -> Cancelable<Ty> {
fn infer_block(&mut self, node: ast::Block, expected: &Expectation) -> Cancelable<Ty> {
for stmt in node.statements() {
match stmt {
ast::Stmt::LetStmt(stmt) => {
let decl_ty = Ty::from_ast_opt(self.db, &self.module, stmt.type_ref())?;
let decl_ty = self.insert_type_vars(decl_ty);
let ty = if let Some(expr) = stmt.initializer() {
// TODO pass expectation
let expr_ty = self.infer_expr(expr)?;
self.unify_with_coercion(&expr_ty, &decl_ty)
.unwrap_or(decl_ty)
let expr_ty = self.infer_expr(expr, &Expectation::has_type(decl_ty))?;
expr_ty
} else {
decl_ty
};
@ -629,12 +852,12 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
};
}
ast::Stmt::ExprStmt(expr_stmt) => {
self.infer_expr_opt(expr_stmt.expr())?;
self.infer_expr_opt(expr_stmt.expr(), &Expectation::none())?;
}
}
}
let ty = if let Some(expr) = node.expr() {
self.infer_expr(expr)?
self.infer_expr(expr, expected)?
} else {
Ty::unit()
};
@ -660,25 +883,27 @@ pub fn infer(db: &impl HirDatabase, function: Function) -> Cancelable<InferenceR
};
if let Some(type_ref) = param.type_ref() {
let ty = Ty::from_ast(db, &ctx.module, type_ref)?;
let ty = ctx.insert_type_vars(ty);
ctx.type_of.insert(LocalSyntaxPtr::new(pat.syntax()), ty);
} else {
// TODO self param
let type_var = ctx.new_type_var();
ctx.type_of
.insert(LocalSyntaxPtr::new(pat.syntax()), Ty::Unknown);
.insert(LocalSyntaxPtr::new(pat.syntax()), type_var);
};
}
}
// TODO get Ty for node.ret_type() and pass that to infer_block as expectation
// (see Expectation in rustc_typeck)
let ret_ty = if let Some(type_ref) = node.ret_type().and_then(|n| n.type_ref()) {
let ty = Ty::from_ast(db, &ctx.module, type_ref)?;
ctx.insert_type_vars(ty)
} else {
Ty::unit()
};
if let Some(block) = node.body() {
ctx.infer_block(block)?;
ctx.infer_block(block, &Expectation::has_type(ret_ty))?;
}
// TODO 'resolve' the types: replace inference variables by their inferred results
Ok(InferenceResult {
type_of: ctx.type_of,
})
Ok(ctx.resolve_all())
}

View file

@ -113,6 +113,27 @@ fn test(a: &u32, b: &mut u32, c: *const u32, d: *mut u32) {
);
}
#[test]
fn infer_backwards() {
check_inference(
r#"
fn takes_u32(x: u32) {}
struct S { i32_field: i32 }
fn test() -> &mut &f64 {
let a = unknown_function();
takes_u32(a);
let b = unknown_function();
S { i32_field: b };
let c = unknown_function();
&mut &c
}
"#,
"0006_backwards.txt",
);
}
fn infer(content: &str) -> String {
let (db, _, file_id) = MockDatabase::with_single_file(content);
let source_file = db.source_file(file_id);

View file

@ -1,5 +1,5 @@
[21; 22) 'a': [unknown]
[52; 53) '1': [unknown]
[52; 53) '1': usize
[11; 71) '{ ...= b; }': ()
[63; 64) 'c': usize
[25; 31) '1isize': [unknown]

View file

@ -1,7 +1,7 @@
[15; 20) '{ 1 }': [unknown]
[17; 18) '1': [unknown]
[50; 51) '1': [unknown]
[48; 53) '{ 1 }': [unknown]
[15; 20) '{ 1 }': u32
[17; 18) '1': u32
[50; 51) '1': u32
[48; 53) '{ 1 }': u32
[82; 88) 'b::c()': u32
[67; 91) '{ ...c(); }': ()
[73; 74) 'a': fn() -> u32

View file

@ -1,5 +1,5 @@
[86; 90) 'C(1)': [unknown]
[121; 122) 'B': [unknown]
[121; 122) 'B': B
[86; 87) 'C': [unknown]
[129; 130) '1': [unknown]
[107; 108) 'a': A
@ -13,4 +13,4 @@
[96; 97) 'B': [unknown]
[88; 89) '1': [unknown]
[82; 83) 'c': [unknown]
[127; 131) 'C(1)': [unknown]
[127; 131) 'C(1)': C

View file

@ -0,0 +1,20 @@
[22; 24) '{}': ()
[14; 15) 'x': u32
[142; 158) 'unknow...nction': [unknown]
[126; 127) 'a': u32
[198; 216) 'unknow...tion()': f64
[228; 229) 'c': f64
[198; 214) 'unknow...nction': [unknown]
[166; 184) 'S { i3...d: b }': S
[222; 229) '&mut &c': &mut &f64
[194; 195) 'c': f64
[92; 110) 'unknow...tion()': u32
[142; 160) 'unknow...tion()': i32
[92; 108) 'unknow...nction': [unknown]
[116; 128) 'takes_u32(a)': [unknown]
[78; 231) '{ ...t &c }': &mut &f64
[227; 229) '&c': &f64
[88; 89) 'a': u32
[181; 182) 'b': i32
[116; 125) 'takes_u32': fn(u32,) -> [unknown]
[138; 139) 'b': i32