Implement parameter variance inference

This commit is contained in:
Lukas Wirth 2024-12-28 15:08:26 +01:00
parent 17b3662755
commit eee2761140
8 changed files with 1271 additions and 37 deletions

View file

@ -950,11 +950,18 @@ pub(crate) fn fn_def_datum_query(db: &dyn HirDatabase, fn_def_id: FnDefId) -> Ar
pub(crate) fn fn_def_variance_query(db: &dyn HirDatabase, fn_def_id: FnDefId) -> Variances {
let callable_def: CallableDefId = from_chalk(db, fn_def_id);
let generic_params =
generics(db.upcast(), GenericDefId::from_callable(db.upcast(), callable_def));
Variances::from_iter(
Interner,
std::iter::repeat(chalk_ir::Variance::Invariant).take(generic_params.len()),
db.variances_of(GenericDefId::from_callable(db.upcast(), callable_def))
.as_deref()
.unwrap_or_default()
.iter()
.map(|v| match v {
crate::variance::Variance::Covariant => chalk_ir::Variance::Covariant,
crate::variance::Variance::Invariant => chalk_ir::Variance::Invariant,
crate::variance::Variance::Contravariant => chalk_ir::Variance::Contravariant,
crate::variance::Variance::Bivariant => chalk_ir::Variance::Invariant,
}),
)
}
@ -962,10 +969,14 @@ pub(crate) fn adt_variance_query(
db: &dyn HirDatabase,
chalk_ir::AdtId(adt_id): AdtId,
) -> Variances {
let generic_params = generics(db.upcast(), adt_id.into());
Variances::from_iter(
Interner,
std::iter::repeat(chalk_ir::Variance::Invariant).take(generic_params.len()),
db.variances_of(adt_id.into()).as_deref().unwrap_or_default().iter().map(|v| match v {
crate::variance::Variance::Covariant => chalk_ir::Variance::Covariant,
crate::variance::Variance::Invariant => chalk_ir::Variance::Invariant,
crate::variance::Variance::Contravariant => chalk_ir::Variance::Contravariant,
crate::variance::Variance::Bivariant => chalk_ir::Variance::Invariant,
}),
)
}

View file

@ -271,6 +271,10 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
#[ra_salsa::invoke(chalk_db::adt_variance_query)]
fn adt_variance(&self, adt_id: chalk_db::AdtId) -> chalk_db::Variances;
#[ra_salsa::invoke(crate::variance::variances_of)]
#[ra_salsa::cycle(crate::variance::variances_of_cycle)]
fn variances_of(&self, def: GenericDefId) -> Option<Arc<[crate::variance::Variance]>>;
#[ra_salsa::invoke(chalk_db::associated_ty_value_query)]
fn associated_ty_value(
&self,

View file

@ -132,6 +132,14 @@ impl Generics {
self.params.len()
}
pub(crate) fn len_self_lifetimes(&self) -> usize {
self.params.len_lifetimes()
}
pub(crate) fn has_trait_self(&self) -> bool {
self.params.trait_self_param().is_some()
}
/// (parent total, self param, type params, const params, impl trait list, lifetimes)
pub(crate) fn provenance_split(&self) -> (usize, bool, usize, usize, usize, usize) {
let mut self_param = false;

View file

@ -50,6 +50,7 @@ pub mod traits;
mod test_db;
#[cfg(test)]
mod tests;
mod variance;
use std::hash::Hash;

View file

@ -127,7 +127,15 @@ fn check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_sour
None => continue,
};
let def_map = module.def_map(&db);
visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
visit_module(&db, &def_map, module.local_id, &mut |it| {
defs.push(match it {
ModuleDefId::FunctionId(it) => it.into(),
ModuleDefId::EnumVariantId(it) => it.into(),
ModuleDefId::ConstId(it) => it.into(),
ModuleDefId::StaticId(it) => it.into(),
_ => return,
})
});
}
defs.sort_by_key(|def| match def {
DefWithBodyId::FunctionId(it) => {
@ -375,7 +383,15 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
let def_map = module.def_map(&db);
let mut defs: Vec<DefWithBodyId> = Vec::new();
visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
visit_module(&db, &def_map, module.local_id, &mut |it| {
defs.push(match it {
ModuleDefId::FunctionId(it) => it.into(),
ModuleDefId::EnumVariantId(it) => it.into(),
ModuleDefId::ConstId(it) => it.into(),
ModuleDefId::StaticId(it) => it.into(),
_ => return,
})
});
defs.sort_by_key(|def| match def {
DefWithBodyId::FunctionId(it) => {
let loc = it.lookup(&db);
@ -405,11 +421,11 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
buf
}
fn visit_module(
pub(crate) fn visit_module(
db: &TestDB,
crate_def_map: &DefMap,
module_id: LocalModuleId,
cb: &mut dyn FnMut(DefWithBodyId),
cb: &mut dyn FnMut(ModuleDefId),
) {
visit_scope(db, crate_def_map, &crate_def_map[module_id].scope, cb);
for impl_id in crate_def_map[module_id].scope.impls() {
@ -417,18 +433,18 @@ fn visit_module(
for &item in impl_data.items.iter() {
match item {
AssocItemId::FunctionId(it) => {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
cb(it.into());
visit_body(db, &body, cb);
}
AssocItemId::ConstId(it) => {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
cb(it.into());
visit_body(db, &body, cb);
}
AssocItemId::TypeAliasId(_) => (),
AssocItemId::TypeAliasId(it) => {
cb(it.into());
}
}
}
}
@ -437,33 +453,27 @@ fn visit_module(
db: &TestDB,
crate_def_map: &DefMap,
scope: &ItemScope,
cb: &mut dyn FnMut(DefWithBodyId),
cb: &mut dyn FnMut(ModuleDefId),
) {
for decl in scope.declarations() {
cb(decl);
match decl {
ModuleDefId::FunctionId(it) => {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
visit_body(db, &body, cb);
}
ModuleDefId::ConstId(it) => {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
visit_body(db, &body, cb);
}
ModuleDefId::StaticId(it) => {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
visit_body(db, &body, cb);
}
ModuleDefId::AdtId(hir_def::AdtId::EnumId(it)) => {
db.enum_data(it).variants.iter().for_each(|&(it, _)| {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
cb(it.into());
visit_body(db, &body, cb);
});
}
@ -473,7 +483,7 @@ fn visit_module(
match item {
AssocItemId::FunctionId(it) => cb(it.into()),
AssocItemId::ConstId(it) => cb(it.into()),
AssocItemId::TypeAliasId(_) => (),
AssocItemId::TypeAliasId(it) => cb(it.into()),
}
}
}
@ -483,7 +493,7 @@ fn visit_module(
}
}
fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(DefWithBodyId)) {
fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(ModuleDefId)) {
for (_, def_map) in body.blocks(db) {
for (mod_id, _) in def_map.modules() {
visit_module(db, &def_map, mod_id, cb);
@ -553,7 +563,13 @@ fn salsa_bug() {
let module = db.module_for_file(pos.file_id);
let crate_def_map = module.def_map(&db);
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
db.infer(def);
db.infer(match def {
ModuleDefId::FunctionId(it) => it.into(),
ModuleDefId::EnumVariantId(it) => it.into(),
ModuleDefId::ConstId(it) => it.into(),
ModuleDefId::StaticId(it) => it.into(),
_ => return,
});
});
let new_text = "
@ -586,6 +602,12 @@ fn salsa_bug() {
let module = db.module_for_file(pos.file_id);
let crate_def_map = module.def_map(&db);
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
db.infer(def);
db.infer(match def {
ModuleDefId::FunctionId(it) => it.into(),
ModuleDefId::EnumVariantId(it) => it.into(),
ModuleDefId::ConstId(it) => it.into(),
ModuleDefId::StaticId(it) => it.into(),
_ => return,
});
});
}

View file

@ -24,6 +24,13 @@ fn check_closure_captures(ra_fixture: &str, expect: Expect) {
let mut captures_info = Vec::new();
for def in defs {
let def = match def {
hir_def::ModuleDefId::FunctionId(it) => it.into(),
hir_def::ModuleDefId::EnumVariantId(it) => it.into(),
hir_def::ModuleDefId::ConstId(it) => it.into(),
hir_def::ModuleDefId::StaticId(it) => it.into(),
_ => continue,
};
let infer = db.infer(def);
let db = &db;
captures_info.extend(infer.closure_info.iter().flat_map(|(closure_id, (captures, _))| {

View file

@ -1,4 +1,5 @@
use base_db::SourceDatabaseFileInputExt as _;
use hir_def::ModuleDefId;
use test_fixture::WithFixture;
use crate::{db::HirDatabase, test_db::TestDB};
@ -19,7 +20,9 @@ fn foo() -> i32 {
let module = db.module_for_file(pos.file_id.file_id());
let crate_def_map = module.def_map(&db);
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
db.infer(def);
if let ModuleDefId::FunctionId(it) = def {
db.infer(it.into());
}
});
});
assert!(format!("{events:?}").contains("infer"))
@ -39,7 +42,9 @@ fn foo() -> i32 {
let module = db.module_for_file(pos.file_id.file_id());
let crate_def_map = module.def_map(&db);
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
db.infer(def);
if let ModuleDefId::FunctionId(it) = def {
db.infer(it.into());
}
});
});
assert!(!format!("{events:?}").contains("infer"), "{events:#?}")
@ -66,7 +71,9 @@ fn baz() -> i32 {
let module = db.module_for_file(pos.file_id.file_id());
let crate_def_map = module.def_map(&db);
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
db.infer(def);
if let ModuleDefId::FunctionId(it) = def {
db.infer(it.into());
}
});
});
assert!(format!("{events:?}").contains("infer"))
@ -91,7 +98,9 @@ fn baz() -> i32 {
let module = db.module_for_file(pos.file_id.file_id());
let crate_def_map = module.def_map(&db);
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
db.infer(def);
if let ModuleDefId::FunctionId(it) = def {
db.infer(it.into());
}
});
});
assert!(format!("{events:?}").matches("infer").count() == 1, "{events:#?}")

File diff suppressed because it is too large Load diff