Implement type variables

This will really become necessary when we implement generics, but even now, it
allows us to reason 'backwards' to infer types of expressions that we didn't
understand for some reason.

We use ena, the union-find implementation extracted from rustc, to keep track of
type variables.
This commit is contained in:
Florian Diebold 2018-12-26 17:00:42 +01:00
parent f3f073804c
commit cfa1de72eb
8 changed files with 395 additions and 118 deletions

10
Cargo.lock generated
View file

@ -271,6 +271,14 @@ name = "either"
version = "1.5.0" version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "ena"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]] [[package]]
name = "error-chain" name = "error-chain"
version = "0.12.0" version = "0.12.0"
@ -737,6 +745,7 @@ name = "ra_hir"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"arrayvec 0.4.10 (registry+https://github.com/rust-lang/crates.io-index)", "arrayvec 0.4.10 (registry+https://github.com/rust-lang/crates.io-index)",
"ena 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)",
"flexi_logger 0.10.3 (registry+https://github.com/rust-lang/crates.io-index)", "flexi_logger 0.10.3 (registry+https://github.com/rust-lang/crates.io-index)",
"id-arena 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "id-arena 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)",
@ -1546,6 +1555,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum digest 0.7.6 (registry+https://github.com/rust-lang/crates.io-index)" = "03b072242a8cbaf9c145665af9d250c59af3b958f83ed6824e13533cf76d5b90" "checksum digest 0.7.6 (registry+https://github.com/rust-lang/crates.io-index)" = "03b072242a8cbaf9c145665af9d250c59af3b958f83ed6824e13533cf76d5b90"
"checksum drop_bomb 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "69b26e475fd29098530e709294e94e661974c851aed42512793f120fed4e199f" "checksum drop_bomb 0.1.4 (registry+https://github.com/rust-lang/crates.io-index)" = "69b26e475fd29098530e709294e94e661974c851aed42512793f120fed4e199f"
"checksum either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3be565ca5c557d7f59e7cfcf1844f9e3033650c929c6566f511e8005f205c1d0" "checksum either 1.5.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3be565ca5c557d7f59e7cfcf1844f9e3033650c929c6566f511e8005f205c1d0"
"checksum ena 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "f56c93cc076508c549d9bb747f79aa9b4eb098be7b8cad8830c3137ef52d1e00"
"checksum error-chain 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "07e791d3be96241c77c43846b665ef1384606da2cd2a48730abe606a12906e02" "checksum error-chain 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)" = "07e791d3be96241c77c43846b665ef1384606da2cd2a48730abe606a12906e02"
"checksum failure 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6dd377bcc1b1b7ce911967e3ec24fa19c3224394ec05b54aa7b083d498341ac7" "checksum failure 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "6dd377bcc1b1b7ce911967e3ec24fa19c3224394ec05b54aa7b083d498341ac7"
"checksum failure_derive 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "64c2d913fe8ed3b6c6518eedf4538255b989945c14c2a7d5cbff62a5e2120596" "checksum failure_derive 0.1.3 (registry+https://github.com/rust-lang/crates.io-index)" = "64c2d913fe8ed3b6c6518eedf4538255b989945c14c2a7d5cbff62a5e2120596"

View file

@ -12,6 +12,7 @@ salsa = "0.9.0"
rustc-hash = "1.0" rustc-hash = "1.0"
parking_lot = "0.7.0" parking_lot = "0.7.0"
id-arena = "2.0" id-arena = "2.0"
ena = "0.11"
ra_syntax = { path = "../ra_syntax" } ra_syntax = { path = "../ra_syntax" }
ra_editor = { path = "../ra_editor" } ra_editor = { path = "../ra_editor" }
ra_db = { path = "../ra_db" } ra_db = { path = "../ra_db" }

View file

@ -3,10 +3,11 @@ mod primitive;
mod tests; mod tests;
use std::sync::Arc; use std::sync::Arc;
use std::fmt; use std::{fmt, mem};
use log; use log;
use rustc_hash::{FxHashMap}; use rustc_hash::FxHashMap;
use ena::unify::{InPlaceUnificationTable, UnifyKey, UnifyValue, NoError};
use ra_db::{LocalSyntaxPtr, Cancelable}; use ra_db::{LocalSyntaxPtr, Cancelable};
use ra_syntax::{ use ra_syntax::{
@ -17,10 +18,89 @@ use ra_syntax::{
use crate::{ use crate::{
Def, DefId, FnScopes, Module, Function, Struct, Enum, Path, Name, AsName, Def, DefId, FnScopes, Module, Function, Struct, Enum, Path, Name, AsName,
db::HirDatabase, db::HirDatabase,
adt::VariantData,
type_ref::{TypeRef, Mutability}, type_ref::{TypeRef, Mutability},
}; };
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub struct TypeVarId(u32);
impl UnifyKey for TypeVarId {
type Value = TypeVarValue;
fn index(&self) -> u32 {
self.0
}
fn from_index(i: u32) -> Self {
TypeVarId(i)
}
fn tag() -> &'static str {
"TypeVarId"
}
}
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum TypeVarValue {
Known(Ty),
Unknown,
}
impl TypeVarValue {
pub fn known(&self) -> Option<&Ty> {
match self {
TypeVarValue::Known(ty) => Some(ty),
TypeVarValue::Unknown => None,
}
}
}
impl UnifyValue for TypeVarValue {
type Error = NoError;
fn unify_values(value1: &Self, value2: &Self) -> Result<Self, NoError> {
match (value1, value2) {
// We should never equate two type variables, both of which have
// known types. Instead, we recursively equate those types.
(TypeVarValue::Known(..), TypeVarValue::Known(..)) => {
panic!("equating two type variables, both of which have known types")
}
// If one side is known, prefer that one.
(TypeVarValue::Known(..), TypeVarValue::Unknown) => Ok(value1.clone()),
(TypeVarValue::Unknown, TypeVarValue::Known(..)) => Ok(value2.clone()),
(TypeVarValue::Unknown, TypeVarValue::Unknown) => Ok(TypeVarValue::Unknown),
}
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum InferTy {
TypeVar(TypeVarId),
// later we'll have IntVar and FloatVar as well
}
/// When inferring an expression, we propagate downward whatever type hint we
/// are able in the form of an `Expectation`.
#[derive(Clone, PartialEq, Eq, Debug)]
struct Expectation {
ty: Ty,
// TODO: In some cases, we need to be aware whether the expectation is that
// the type match exactly what we passed, or whether it just needs to be
// coercible to the expected type. See Expectation::rvalue_hint in rustc.
}
impl Expectation {
fn has_type(ty: Ty) -> Self {
Expectation { ty }
}
fn none() -> Self {
Expectation { ty: Ty::Unknown }
}
}
#[derive(Clone, PartialEq, Eq, Hash, Debug)] #[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum Ty { pub enum Ty {
/// The primitive boolean type. Written as `bool`. /// The primitive boolean type. Written as `bool`.
@ -75,23 +155,22 @@ pub enum Ty {
// A trait, defined with `dyn trait`. // A trait, defined with `dyn trait`.
// Dynamic(), // Dynamic(),
/// The anonymous type of a closure. Used to represent the type of // The anonymous type of a closure. Used to represent the type of
/// `|a| a`. // `|a| a`.
// Closure(DefId, ClosureSubsts<'tcx>), // Closure(DefId, ClosureSubsts<'tcx>),
/// The anonymous type of a generator. Used to represent the type of // The anonymous type of a generator. Used to represent the type of
/// `|a| yield a`. // `|a| yield a`.
// Generator(DefId, GeneratorSubsts<'tcx>, hir::GeneratorMovability), // Generator(DefId, GeneratorSubsts<'tcx>, hir::GeneratorMovability),
/// A type representin the types stored inside a generator. // A type representin the types stored inside a generator.
/// This should only appear in GeneratorInteriors. // This should only appear in GeneratorInteriors.
// GeneratorWitness(Binder<&'tcx List<Ty<'tcx>>>), // GeneratorWitness(Binder<&'tcx List<Ty<'tcx>>>),
/// The never type `!` /// The never type `!`
Never, Never,
/// A tuple type. For example, `(i32, bool)`. /// A tuple type. For example, `(i32, bool)`.
Tuple(Vec<Ty>), Tuple(Arc<[Ty]>),
// The projection of an associated type. For example, // The projection of an associated type. For example,
// `<T as Trait<..>>::N`.pub // `<T as Trait<..>>::N`.pub
@ -106,14 +185,14 @@ pub enum Ty {
// A type parameter; for example, `T` in `fn f<T>(x: T) {} // A type parameter; for example, `T` in `fn f<T>(x: T) {}
// Param(ParamTy), // Param(ParamTy),
/// A type variable used during type checking. Not to be confused with a
/// type parameter.
Infer(InferTy),
// A placeholder type - universally quantified higher-ranked type. /// A placeholder for a type which could not be computed; this is propagated
// Placeholder(ty::PlaceholderType), /// to avoid useless error messages. Doubles as a placeholder where type
/// variables are inserted before type checking, since we want to try to
// A type variable used during type checking. /// infer a better type here anyway.
// Infer(InferTy),
/// A placeholder for a type which could not be computed; this is
/// propagated to avoid useless error messages.
Unknown, Unknown,
} }
@ -137,8 +216,8 @@ impl Ty {
let inner_tys = inner let inner_tys = inner
.iter() .iter()
.map(|tr| Ty::from_hir(db, module, tr)) .map(|tr| Ty::from_hir(db, module, tr))
.collect::<Cancelable<_>>()?; .collect::<Cancelable<Vec<_>>>()?;
Ty::Tuple(inner_tys) Ty::Tuple(inner_tys.into())
} }
TypeRef::Path(path) => Ty::from_hir_path(db, module, path)?, TypeRef::Path(path) => Ty::from_hir_path(db, module, path)?,
TypeRef::RawPtr(inner, mutability) => { TypeRef::RawPtr(inner, mutability) => {
@ -154,7 +233,7 @@ impl Ty {
let inner_ty = Ty::from_hir(db, module, inner)?; let inner_ty = Ty::from_hir(db, module, inner)?;
Ty::Ref(Arc::new(inner_ty), *mutability) Ty::Ref(Arc::new(inner_ty), *mutability)
} }
TypeRef::Placeholder => Ty::Unknown, // TODO TypeRef::Placeholder => Ty::Unknown,
TypeRef::Fn(params) => { TypeRef::Fn(params) => {
let mut inner_tys = params let mut inner_tys = params
.iter() .iter()
@ -217,7 +296,41 @@ impl Ty {
} }
pub fn unit() -> Self { pub fn unit() -> Self {
Ty::Tuple(Vec::new()) Ty::Tuple(Arc::new([]))
}
fn walk_mut(&mut self, f: &mut impl FnMut(&mut Ty)) {
f(self);
match self {
Ty::Slice(t) => Arc::make_mut(t).walk_mut(f),
Ty::RawPtr(t, _) => Arc::make_mut(t).walk_mut(f),
Ty::Ref(t, _) => Arc::make_mut(t).walk_mut(f),
Ty::Tuple(ts) => {
// Without an Arc::make_mut_slice, we can't avoid the clone here:
let mut v: Vec<_> = ts.iter().cloned().collect();
for t in &mut v {
t.walk_mut(f);
}
*ts = v.into();
}
Ty::FnPtr(sig) => {
let sig_mut = Arc::make_mut(sig);
for input in &mut sig_mut.input {
input.walk_mut(f);
}
sig_mut.output.walk_mut(f);
}
Ty::Adt { .. } => {} // need to walk type parameters later
_ => {}
}
}
fn fold(mut self, f: &mut impl FnMut(Ty) -> Ty) -> Ty {
self.walk_mut(&mut |ty_mut| {
let ty = mem::replace(ty_mut, Ty::Unknown);
*ty_mut = f(ty);
});
self
} }
} }
@ -236,7 +349,7 @@ impl fmt::Display for Ty {
Ty::Never => write!(f, "!"), Ty::Never => write!(f, "!"),
Ty::Tuple(ts) => { Ty::Tuple(ts) => {
write!(f, "(")?; write!(f, "(")?;
for t in ts { for t in ts.iter() {
write!(f, "{},", t)?; write!(f, "{},", t)?;
} }
write!(f, ")") write!(f, ")")
@ -250,6 +363,7 @@ impl fmt::Display for Ty {
} }
Ty::Adt { name, .. } => write!(f, "{}", name), Ty::Adt { name, .. } => write!(f, "{}", name),
Ty::Unknown => write!(f, "[unknown]"), Ty::Unknown => write!(f, "[unknown]"),
Ty::Infer(..) => write!(f, "_"),
} }
} }
} }
@ -342,7 +456,7 @@ pub struct InferenceContext<'a, D: HirDatabase> {
db: &'a D, db: &'a D,
scopes: Arc<FnScopes>, scopes: Arc<FnScopes>,
module: Module, module: Module,
// TODO unification tables... var_unification_table: InPlaceUnificationTable<TypeVarId>,
type_of: FxHashMap<LocalSyntaxPtr, Ty>, type_of: FxHashMap<LocalSyntaxPtr, Ty>,
} }
@ -350,33 +464,116 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
fn new(db: &'a D, scopes: Arc<FnScopes>, module: Module) -> Self { fn new(db: &'a D, scopes: Arc<FnScopes>, module: Module) -> Self {
InferenceContext { InferenceContext {
type_of: FxHashMap::default(), type_of: FxHashMap::default(),
var_unification_table: InPlaceUnificationTable::new(),
db, db,
scopes, scopes,
module, module,
} }
} }
fn resolve_all(mut self) -> InferenceResult {
let mut types = mem::replace(&mut self.type_of, FxHashMap::default());
for ty in types.values_mut() {
let resolved = self.resolve_ty_completely(mem::replace(ty, Ty::Unknown));
*ty = resolved;
}
InferenceResult { type_of: types }
}
fn write_ty(&mut self, node: SyntaxNodeRef, ty: Ty) { fn write_ty(&mut self, node: SyntaxNodeRef, ty: Ty) {
self.type_of.insert(LocalSyntaxPtr::new(node), ty); self.type_of.insert(LocalSyntaxPtr::new(node), ty);
} }
fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> Option<Ty> { fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool {
if *ty1 == Ty::Unknown { match (ty1, ty2) {
return Some(ty2.clone()); (Ty::Unknown, ..) => true,
(.., Ty::Unknown) => true,
(Ty::Bool, _)
| (Ty::Str, _)
| (Ty::Never, _)
| (Ty::Char, _)
| (Ty::Int(..), Ty::Int(..))
| (Ty::Uint(..), Ty::Uint(..))
| (Ty::Float(..), Ty::Float(..)) => ty1 == ty2,
(
Ty::Adt {
def_id: def_id1, ..
},
Ty::Adt {
def_id: def_id2, ..
},
) if def_id1 == def_id2 => true,
(Ty::Slice(t1), Ty::Slice(t2)) => self.unify(t1, t2),
(Ty::RawPtr(t1, m1), Ty::RawPtr(t2, m2)) if m1 == m2 => self.unify(t1, t2),
(Ty::Ref(t1, m1), Ty::Ref(t2, m2)) if m1 == m2 => self.unify(t1, t2),
(Ty::FnPtr(sig1), Ty::FnPtr(sig2)) if sig1 == sig2 => true,
(Ty::Tuple(ts1), Ty::Tuple(ts2)) if ts1.len() == ts2.len() => ts1
.iter()
.zip(ts2.iter())
.all(|(t1, t2)| self.unify(t1, t2)),
(Ty::Infer(InferTy::TypeVar(tv1)), Ty::Infer(InferTy::TypeVar(tv2))) => {
self.var_unification_table.union(*tv1, *tv2);
true
}
(Ty::Infer(InferTy::TypeVar(tv)), other) | (other, Ty::Infer(InferTy::TypeVar(tv))) => {
self.var_unification_table
.union_value(*tv, TypeVarValue::Known(other.clone()));
true
}
_ => false,
} }
if *ty2 == Ty::Unknown {
return Some(ty1.clone());
}
if ty1 == ty2 {
return Some(ty1.clone());
}
// TODO implement actual unification
return None;
} }
fn unify_with_coercion(&mut self, ty1: &Ty, ty2: &Ty) -> Option<Ty> { fn new_type_var(&mut self) -> Ty {
// TODO implement coercion Ty::Infer(InferTy::TypeVar(
self.unify(ty1, ty2) self.var_unification_table.new_key(TypeVarValue::Unknown),
))
}
/// Replaces Ty::Unknown by a new type var, so we can maybe still infer it.
fn insert_type_vars_shallow(&mut self, ty: Ty) -> Ty {
match ty {
Ty::Unknown => self.new_type_var(),
_ => ty,
}
}
fn insert_type_vars(&mut self, ty: Ty) -> Ty {
ty.fold(&mut |ty| self.insert_type_vars_shallow(ty))
}
/// Resolves the type as far as currently possible, replacing type variables
/// by their known types. All types returned by the infer_* functions should
/// be resolved as far as possible, i.e. contain no type variables with
/// known type.
fn resolve_ty_as_possible(&mut self, ty: Ty) -> Ty {
ty.fold(&mut |ty| match ty {
Ty::Infer(InferTy::TypeVar(tv)) => {
if let Some(known_ty) = self.var_unification_table.probe_value(tv).known() {
// known_ty may contain other variables that are known by now
self.resolve_ty_as_possible(known_ty.clone())
} else {
Ty::Infer(InferTy::TypeVar(tv))
}
}
_ => ty,
})
}
/// Resolves the type completely; type variables without known type are
/// replaced by Ty::Unknown.
fn resolve_ty_completely(&mut self, ty: Ty) -> Ty {
ty.fold(&mut |ty| match ty {
Ty::Infer(InferTy::TypeVar(tv)) => {
if let Some(known_ty) = self.var_unification_table.probe_value(tv).known() {
// known_ty may contain other variables that are known by now
self.resolve_ty_completely(known_ty.clone())
} else {
Ty::Unknown
}
}
_ => ty,
})
} }
fn infer_path_expr(&mut self, expr: ast::PathExpr) -> Cancelable<Option<Ty>> { fn infer_path_expr(&mut self, expr: ast::PathExpr) -> Cancelable<Option<Ty>> {
@ -387,21 +584,19 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
let name = ctry!(ast_path.segment().and_then(|s| s.name_ref())); let name = ctry!(ast_path.segment().and_then(|s| s.name_ref()));
if let Some(scope_entry) = self.scopes.resolve_local_name(name) { if let Some(scope_entry) = self.scopes.resolve_local_name(name) {
let ty = ctry!(self.type_of.get(&scope_entry.ptr())); let ty = ctry!(self.type_of.get(&scope_entry.ptr()));
return Ok(Some(ty.clone())); let ty = self.resolve_ty_as_possible(ty.clone());
return Ok(Some(ty));
}; };
}; };
// resolve in module // resolve in module
let resolved = ctry!(self.module.resolve_path(self.db, &path)?.take_values()); let resolved = ctry!(self.module.resolve_path(self.db, &path)?.take_values());
let ty = self.db.type_for_def(resolved)?; let ty = self.db.type_for_def(resolved)?;
// TODO we will need to add type variables for type parameters etc. here let ty = self.insert_type_vars(ty);
Ok(Some(ty)) Ok(Some(ty))
} }
fn resolve_variant( fn resolve_variant(&self, path: Option<ast::Path>) -> Cancelable<(Ty, Option<DefId>)> {
&self,
path: Option<ast::Path>,
) -> Cancelable<(Ty, Option<Arc<VariantData>>)> {
let path = if let Some(path) = path.and_then(Path::from_ast) { let path = if let Some(path) = path.and_then(Path::from_ast) {
path path
} else { } else {
@ -414,102 +609,116 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
}; };
Ok(match def_id.resolve(self.db)? { Ok(match def_id.resolve(self.db)? {
Def::Struct(s) => { Def::Struct(s) => {
let struct_data = self.db.struct_data(def_id)?;
let ty = type_for_struct(self.db, s)?; let ty = type_for_struct(self.db, s)?;
(ty, Some(struct_data.variant_data().clone())) (ty, Some(def_id))
} }
_ => (Ty::Unknown, None), _ => (Ty::Unknown, None),
}) })
} }
fn infer_expr_opt(&mut self, expr: Option<ast::Expr>) -> Cancelable<Ty> { fn infer_expr_opt(
&mut self,
expr: Option<ast::Expr>,
expected: &Expectation,
) -> Cancelable<Ty> {
if let Some(e) = expr { if let Some(e) = expr {
self.infer_expr(e) self.infer_expr(e, expected)
} else { } else {
Ok(Ty::Unknown) Ok(Ty::Unknown)
} }
} }
fn infer_expr(&mut self, expr: ast::Expr) -> Cancelable<Ty> { fn infer_expr(&mut self, expr: ast::Expr, expected: &Expectation) -> Cancelable<Ty> {
let ty = match expr { let ty = match expr {
ast::Expr::IfExpr(e) => { ast::Expr::IfExpr(e) => {
if let Some(condition) = e.condition() { if let Some(condition) = e.condition() {
// TODO if no pat, this should be bool let expected = if condition.pat().is_none() {
self.infer_expr_opt(condition.expr())?; Expectation::has_type(Ty::Bool)
} else {
Expectation::none()
};
self.infer_expr_opt(condition.expr(), &expected)?;
// TODO write type for pat // TODO write type for pat
}; };
let if_ty = self.infer_block_opt(e.then_branch())?; let if_ty = self.infer_block_opt(e.then_branch(), expected)?;
let else_ty = self.infer_block_opt(e.else_branch())?; if let Some(else_branch) = e.else_branch() {
if let Some(ty) = self.unify(&if_ty, &else_ty) { self.infer_block(else_branch, expected)?;
ty
} else { } else {
// TODO report diagnostic // no else branch -> unit
Ty::Unknown self.unify(&expected.ty, &Ty::unit()); // actually coerce
} }
if_ty
} }
ast::Expr::BlockExpr(e) => self.infer_block_opt(e.block())?, ast::Expr::BlockExpr(e) => self.infer_block_opt(e.block(), expected)?,
ast::Expr::LoopExpr(e) => { ast::Expr::LoopExpr(e) => {
self.infer_block_opt(e.loop_body())?; self.infer_block_opt(e.loop_body(), &Expectation::has_type(Ty::unit()))?;
// TODO never, or the type of the break param // TODO never, or the type of the break param
Ty::Unknown Ty::Unknown
} }
ast::Expr::WhileExpr(e) => { ast::Expr::WhileExpr(e) => {
if let Some(condition) = e.condition() { if let Some(condition) = e.condition() {
// TODO if no pat, this should be bool let expected = if condition.pat().is_none() {
self.infer_expr_opt(condition.expr())?; Expectation::has_type(Ty::Bool)
} else {
Expectation::none()
};
self.infer_expr_opt(condition.expr(), &expected)?;
// TODO write type for pat // TODO write type for pat
}; };
self.infer_block_opt(e.loop_body())?; self.infer_block_opt(e.loop_body(), &Expectation::has_type(Ty::unit()))?;
// TODO always unit? // TODO always unit?
Ty::Unknown Ty::unit()
} }
ast::Expr::ForExpr(e) => { ast::Expr::ForExpr(e) => {
let _iterable_ty = self.infer_expr_opt(e.iterable()); let _iterable_ty = self.infer_expr_opt(e.iterable(), &Expectation::none());
if let Some(_pat) = e.pat() { if let Some(_pat) = e.pat() {
// TODO write type for pat // TODO write type for pat
} }
self.infer_block_opt(e.loop_body())?; self.infer_block_opt(e.loop_body(), &Expectation::has_type(Ty::unit()))?;
// TODO always unit? // TODO always unit?
Ty::Unknown Ty::unit()
} }
ast::Expr::LambdaExpr(e) => { ast::Expr::LambdaExpr(e) => {
let _body_ty = self.infer_expr_opt(e.body())?; let _body_ty = self.infer_expr_opt(e.body(), &Expectation::none())?;
Ty::Unknown Ty::Unknown
} }
ast::Expr::CallExpr(e) => { ast::Expr::CallExpr(e) => {
let callee_ty = self.infer_expr_opt(e.expr())?; let callee_ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
if let Some(arg_list) = e.arg_list() { let (arg_tys, ret_ty) = match &callee_ty {
for arg in arg_list.args() { Ty::FnPtr(sig) => (&sig.input[..], sig.output.clone()),
// TODO unify / expect argument type
self.infer_expr(arg)?;
}
}
match callee_ty {
Ty::FnPtr(sig) => sig.output.clone(),
_ => { _ => {
// not callable // not callable
// TODO report an error? // TODO report an error?
Ty::Unknown (&[][..], Ty::Unknown)
}
};
if let Some(arg_list) = e.arg_list() {
for (i, arg) in arg_list.args().enumerate() {
self.infer_expr(
arg,
&Expectation::has_type(arg_tys.get(i).cloned().unwrap_or(Ty::Unknown)),
)?;
} }
} }
ret_ty
} }
ast::Expr::MethodCallExpr(e) => { ast::Expr::MethodCallExpr(e) => {
let _receiver_ty = self.infer_expr_opt(e.expr())?; let _receiver_ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
if let Some(arg_list) = e.arg_list() { if let Some(arg_list) = e.arg_list() {
for arg in arg_list.args() { for arg in arg_list.args() {
// TODO unify / expect argument type // TODO unify / expect argument type
self.infer_expr(arg)?; self.infer_expr(arg, &Expectation::none())?;
} }
} }
Ty::Unknown Ty::Unknown
} }
ast::Expr::MatchExpr(e) => { ast::Expr::MatchExpr(e) => {
let _ty = self.infer_expr_opt(e.expr())?; let _ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
if let Some(match_arm_list) = e.match_arm_list() { if let Some(match_arm_list) = e.match_arm_list() {
for arm in match_arm_list.arms() { for arm in match_arm_list.arms() {
// TODO type the bindings in pat // TODO type the bindings in pat
// TODO type the guard // TODO type the guard
let _ty = self.infer_expr_opt(arm.expr())?; let _ty = self.infer_expr_opt(arm.expr(), &Expectation::none())?;
} }
// TODO unify all the match arm types // TODO unify all the match arm types
Ty::Unknown Ty::Unknown
@ -522,10 +731,10 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
ast::Expr::PathExpr(e) => self.infer_path_expr(e)?.unwrap_or(Ty::Unknown), ast::Expr::PathExpr(e) => self.infer_path_expr(e)?.unwrap_or(Ty::Unknown),
ast::Expr::ContinueExpr(_e) => Ty::Never, ast::Expr::ContinueExpr(_e) => Ty::Never,
ast::Expr::BreakExpr(_e) => Ty::Never, ast::Expr::BreakExpr(_e) => Ty::Never,
ast::Expr::ParenExpr(e) => self.infer_expr_opt(e.expr())?, ast::Expr::ParenExpr(e) => self.infer_expr_opt(e.expr(), expected)?,
ast::Expr::Label(_e) => Ty::Unknown, ast::Expr::Label(_e) => Ty::Unknown,
ast::Expr::ReturnExpr(e) => { ast::Expr::ReturnExpr(e) => {
self.infer_expr_opt(e.expr())?; self.infer_expr_opt(e.expr(), &Expectation::none())?;
Ty::Never Ty::Never
} }
ast::Expr::MatchArmList(_) | ast::Expr::MatchArm(_) | ast::Expr::MatchGuard(_) => { ast::Expr::MatchArmList(_) | ast::Expr::MatchArm(_) | ast::Expr::MatchGuard(_) => {
@ -533,11 +742,16 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
Ty::Unknown Ty::Unknown
} }
ast::Expr::StructLit(e) => { ast::Expr::StructLit(e) => {
let (ty, _variant_data) = self.resolve_variant(e.path())?; let (ty, def_id) = self.resolve_variant(e.path())?;
if let Some(nfl) = e.named_field_list() { if let Some(nfl) = e.named_field_list() {
for field in nfl.fields() { for field in nfl.fields() {
// TODO unify with / expect field type let field_ty = if let (Some(def_id), Some(nr)) = (def_id, field.name_ref())
self.infer_expr_opt(field.expr())?; {
self.db.type_for_field(def_id, nr.as_name())?
} else {
Ty::Unknown
};
self.infer_expr_opt(field.expr(), &Expectation::has_type(field_ty))?;
} }
} }
ty ty
@ -548,9 +762,9 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
} }
ast::Expr::IndexExpr(_e) => Ty::Unknown, ast::Expr::IndexExpr(_e) => Ty::Unknown,
ast::Expr::FieldExpr(e) => { ast::Expr::FieldExpr(e) => {
let receiver_ty = self.infer_expr_opt(e.expr())?; let receiver_ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
if let Some(nr) = e.name_ref() { if let Some(nr) = e.name_ref() {
match receiver_ty { let ty = match receiver_ty {
Ty::Tuple(fields) => { Ty::Tuple(fields) => {
let i = nr.text().parse::<usize>().ok(); let i = nr.text().parse::<usize>().ok();
i.and_then(|i| fields.get(i).cloned()) i.and_then(|i| fields.get(i).cloned())
@ -558,29 +772,32 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
} }
Ty::Adt { def_id, .. } => self.db.type_for_field(def_id, nr.as_name())?, Ty::Adt { def_id, .. } => self.db.type_for_field(def_id, nr.as_name())?,
_ => Ty::Unknown, _ => Ty::Unknown,
} };
self.insert_type_vars(ty)
} else { } else {
Ty::Unknown Ty::Unknown
} }
} }
ast::Expr::TryExpr(e) => { ast::Expr::TryExpr(e) => {
let _inner_ty = self.infer_expr_opt(e.expr())?; let _inner_ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
Ty::Unknown Ty::Unknown
} }
ast::Expr::CastExpr(e) => { ast::Expr::CastExpr(e) => {
let _inner_ty = self.infer_expr_opt(e.expr())?; let _inner_ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
let cast_ty = Ty::from_ast_opt(self.db, &self.module, e.type_ref())?; let cast_ty = Ty::from_ast_opt(self.db, &self.module, e.type_ref())?;
let cast_ty = self.insert_type_vars(cast_ty);
// TODO do the coercion... // TODO do the coercion...
cast_ty cast_ty
} }
ast::Expr::RefExpr(e) => { ast::Expr::RefExpr(e) => {
let inner_ty = self.infer_expr_opt(e.expr())?; // TODO pass the expectation down
let inner_ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
let m = Mutability::from_mutable(e.is_mut()); let m = Mutability::from_mutable(e.is_mut());
// TODO reference coercions etc. // TODO reference coercions etc.
Ty::Ref(Arc::new(inner_ty), m) Ty::Ref(Arc::new(inner_ty), m)
} }
ast::Expr::PrefixExpr(e) => { ast::Expr::PrefixExpr(e) => {
let inner_ty = self.infer_expr_opt(e.expr())?; let inner_ty = self.infer_expr_opt(e.expr(), &Expectation::none())?;
match e.op() { match e.op() {
Some(PrefixOp::Deref) => { Some(PrefixOp::Deref) => {
match inner_ty { match inner_ty {
@ -598,28 +815,34 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
ast::Expr::BinExpr(_e) => Ty::Unknown, ast::Expr::BinExpr(_e) => Ty::Unknown,
ast::Expr::Literal(_e) => Ty::Unknown, ast::Expr::Literal(_e) => Ty::Unknown,
}; };
// use a new type variable if we got Ty::Unknown here
let ty = self.insert_type_vars_shallow(ty);
self.unify(&ty, &expected.ty);
self.write_ty(expr.syntax(), ty.clone()); self.write_ty(expr.syntax(), ty.clone());
Ok(ty) Ok(ty)
} }
fn infer_block_opt(&mut self, node: Option<ast::Block>) -> Cancelable<Ty> { fn infer_block_opt(
&mut self,
node: Option<ast::Block>,
expected: &Expectation,
) -> Cancelable<Ty> {
if let Some(b) = node { if let Some(b) = node {
self.infer_block(b) self.infer_block(b, expected)
} else { } else {
Ok(Ty::Unknown) Ok(Ty::Unknown)
} }
} }
fn infer_block(&mut self, node: ast::Block) -> Cancelable<Ty> { fn infer_block(&mut self, node: ast::Block, expected: &Expectation) -> Cancelable<Ty> {
for stmt in node.statements() { for stmt in node.statements() {
match stmt { match stmt {
ast::Stmt::LetStmt(stmt) => { ast::Stmt::LetStmt(stmt) => {
let decl_ty = Ty::from_ast_opt(self.db, &self.module, stmt.type_ref())?; let decl_ty = Ty::from_ast_opt(self.db, &self.module, stmt.type_ref())?;
let decl_ty = self.insert_type_vars(decl_ty);
let ty = if let Some(expr) = stmt.initializer() { let ty = if let Some(expr) = stmt.initializer() {
// TODO pass expectation let expr_ty = self.infer_expr(expr, &Expectation::has_type(decl_ty))?;
let expr_ty = self.infer_expr(expr)?; expr_ty
self.unify_with_coercion(&expr_ty, &decl_ty)
.unwrap_or(decl_ty)
} else { } else {
decl_ty decl_ty
}; };
@ -629,12 +852,12 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
}; };
} }
ast::Stmt::ExprStmt(expr_stmt) => { ast::Stmt::ExprStmt(expr_stmt) => {
self.infer_expr_opt(expr_stmt.expr())?; self.infer_expr_opt(expr_stmt.expr(), &Expectation::none())?;
} }
} }
} }
let ty = if let Some(expr) = node.expr() { let ty = if let Some(expr) = node.expr() {
self.infer_expr(expr)? self.infer_expr(expr, expected)?
} else { } else {
Ty::unit() Ty::unit()
}; };
@ -660,25 +883,27 @@ pub fn infer(db: &impl HirDatabase, function: Function) -> Cancelable<InferenceR
}; };
if let Some(type_ref) = param.type_ref() { if let Some(type_ref) = param.type_ref() {
let ty = Ty::from_ast(db, &ctx.module, type_ref)?; let ty = Ty::from_ast(db, &ctx.module, type_ref)?;
let ty = ctx.insert_type_vars(ty);
ctx.type_of.insert(LocalSyntaxPtr::new(pat.syntax()), ty); ctx.type_of.insert(LocalSyntaxPtr::new(pat.syntax()), ty);
} else { } else {
// TODO self param // TODO self param
let type_var = ctx.new_type_var();
ctx.type_of ctx.type_of
.insert(LocalSyntaxPtr::new(pat.syntax()), Ty::Unknown); .insert(LocalSyntaxPtr::new(pat.syntax()), type_var);
}; };
} }
} }
// TODO get Ty for node.ret_type() and pass that to infer_block as expectation let ret_ty = if let Some(type_ref) = node.ret_type().and_then(|n| n.type_ref()) {
// (see Expectation in rustc_typeck) let ty = Ty::from_ast(db, &ctx.module, type_ref)?;
ctx.insert_type_vars(ty)
} else {
Ty::unit()
};
if let Some(block) = node.body() { if let Some(block) = node.body() {
ctx.infer_block(block)?; ctx.infer_block(block, &Expectation::has_type(ret_ty))?;
} }
// TODO 'resolve' the types: replace inference variables by their inferred results Ok(ctx.resolve_all())
Ok(InferenceResult {
type_of: ctx.type_of,
})
} }

View file

@ -113,6 +113,27 @@ fn test(a: &u32, b: &mut u32, c: *const u32, d: *mut u32) {
); );
} }
#[test]
fn infer_backwards() {
check_inference(
r#"
fn takes_u32(x: u32) {}
struct S { i32_field: i32 }
fn test() -> &mut &f64 {
let a = unknown_function();
takes_u32(a);
let b = unknown_function();
S { i32_field: b };
let c = unknown_function();
&mut &c
}
"#,
"0006_backwards.txt",
);
}
fn infer(content: &str) -> String { fn infer(content: &str) -> String {
let (db, _, file_id) = MockDatabase::with_single_file(content); let (db, _, file_id) = MockDatabase::with_single_file(content);
let source_file = db.source_file(file_id); let source_file = db.source_file(file_id);

View file

@ -1,5 +1,5 @@
[21; 22) 'a': [unknown] [21; 22) 'a': [unknown]
[52; 53) '1': [unknown] [52; 53) '1': usize
[11; 71) '{ ...= b; }': () [11; 71) '{ ...= b; }': ()
[63; 64) 'c': usize [63; 64) 'c': usize
[25; 31) '1isize': [unknown] [25; 31) '1isize': [unknown]

View file

@ -1,7 +1,7 @@
[15; 20) '{ 1 }': [unknown] [15; 20) '{ 1 }': u32
[17; 18) '1': [unknown] [17; 18) '1': u32
[50; 51) '1': [unknown] [50; 51) '1': u32
[48; 53) '{ 1 }': [unknown] [48; 53) '{ 1 }': u32
[82; 88) 'b::c()': u32 [82; 88) 'b::c()': u32
[67; 91) '{ ...c(); }': () [67; 91) '{ ...c(); }': ()
[73; 74) 'a': fn() -> u32 [73; 74) 'a': fn() -> u32

View file

@ -1,5 +1,5 @@
[86; 90) 'C(1)': [unknown] [86; 90) 'C(1)': [unknown]
[121; 122) 'B': [unknown] [121; 122) 'B': B
[86; 87) 'C': [unknown] [86; 87) 'C': [unknown]
[129; 130) '1': [unknown] [129; 130) '1': [unknown]
[107; 108) 'a': A [107; 108) 'a': A
@ -13,4 +13,4 @@
[96; 97) 'B': [unknown] [96; 97) 'B': [unknown]
[88; 89) '1': [unknown] [88; 89) '1': [unknown]
[82; 83) 'c': [unknown] [82; 83) 'c': [unknown]
[127; 131) 'C(1)': [unknown] [127; 131) 'C(1)': C

View file

@ -0,0 +1,20 @@
[22; 24) '{}': ()
[14; 15) 'x': u32
[142; 158) 'unknow...nction': [unknown]
[126; 127) 'a': u32
[198; 216) 'unknow...tion()': f64
[228; 229) 'c': f64
[198; 214) 'unknow...nction': [unknown]
[166; 184) 'S { i3...d: b }': S
[222; 229) '&mut &c': &mut &f64
[194; 195) 'c': f64
[92; 110) 'unknow...tion()': u32
[142; 160) 'unknow...tion()': i32
[92; 108) 'unknow...nction': [unknown]
[116; 128) 'takes_u32(a)': [unknown]
[78; 231) '{ ...t &c }': &mut &f64
[227; 229) '&c': &f64
[88; 89) 'a': u32
[181; 182) 'b': i32
[116; 125) 'takes_u32': fn(u32,) -> [unknown]
[138; 139) 'b': i32