This commit is contained in:
Lukas Wirth 2024-12-28 19:51:04 +01:00
parent 9419e199d8
commit 60e28c6bd9

View file

@ -6,8 +6,8 @@
use crate::db::HirDatabase; use crate::db::HirDatabase;
use crate::generics::{generics, Generics}; use crate::generics::{generics, Generics};
use crate::{ use crate::{
AliasTy, Const, ConstScalar, DynTyExt, FnPointer, GenericArg, GenericArgData, Interner, AliasTy, Const, ConstScalar, DynTyExt, GenericArg, GenericArgData, Interner, Lifetime,
Lifetime, LifetimeData, Ty, TyKind, LifetimeData, Ty, TyKind,
}; };
use base_db::ra_salsa::Cycle; use base_db::ra_salsa::Cycle;
use chalk_ir::Mutability; use chalk_ir::Mutability;
@ -15,6 +15,7 @@ use hir_def::data::adt::StructFlags;
use hir_def::{AdtId, GenericDefId, GenericParamId, VariantId}; use hir_def::{AdtId, GenericDefId, GenericParamId, VariantId};
use std::fmt; use std::fmt;
use std::ops::Not; use std::ops::Not;
use stdx::never;
use triomphe::Arc; use triomphe::Arc;
pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> Option<Arc<[Variance]>> { pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> Option<Arc<[Variance]>> {
@ -156,9 +157,19 @@ impl Variance {
(x, Variance::Bivariant) | (Variance::Bivariant, x) => x, (x, Variance::Bivariant) | (Variance::Bivariant, x) => x,
} }
} }
pub fn invariant(self) -> Self {
self.xform(Variance::Invariant)
}
pub fn covariant(self) -> Self {
self.xform(Variance::Covariant)
}
pub fn contravariant(self) -> Self {
self.xform(Variance::Contravariant)
}
} }
#[derive(Copy, Clone, Debug)]
struct InferredIndex(usize);
struct Context<'db> { struct Context<'db> {
db: &'db dyn HirDatabase, db: &'db dyn HirDatabase,
@ -193,12 +204,12 @@ impl Context<'_> {
} }
GenericDefId::FunctionId(f) => { GenericDefId::FunctionId(f) => {
let subst = self.generics.placeholder_subst(self.db); let subst = self.generics.placeholder_subst(self.db);
self.add_constraints_from_sig2( self.add_constraints_from_sig(
&self self.db
.db
.callable_item_signature(f.into()) .callable_item_signature(f.into())
.substitute(Interner, &subst) .substitute(Interner, &subst)
.params_and_return, .params_and_return
.iter(),
Variance::Covariant, Variance::Covariant,
); );
} }
@ -216,41 +227,15 @@ impl Context<'_> {
// Functions are permitted to have unused generic parameters: make those invariant. // Functions are permitted to have unused generic parameters: make those invariant.
if let GenericDefId::FunctionId(_) = self.generics.def() { if let GenericDefId::FunctionId(_) = self.generics.def() {
for variance in &mut variances { variances
if *variance == Variance::Bivariant { .iter_mut()
*variance = Variance::Invariant; .filter(|&&mut v| v == Variance::Bivariant)
} .for_each(|v| *v = Variance::Invariant);
}
} }
variances variances
} }
fn contravariant(&mut self, variance: Variance) -> Variance {
variance.xform(Variance::Contravariant)
}
fn invariant(&mut self, variance: Variance) -> Variance {
variance.xform(Variance::Invariant)
}
fn add_constraints_from_invariant_args(&mut self, args: &[GenericArg], variance: Variance) {
tracing::debug!(
"add_constraints_from_invariant_args(args={:?}, variance={:?})",
args,
variance
);
let variance_i = self.invariant(variance);
for k in args {
match k.data(Interner) {
GenericArgData::Lifetime(lt) => self.add_constraints_from_region(lt, variance_i),
GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i),
GenericArgData::Const(val) => self.add_constraints_from_const(val, variance_i),
}
}
}
/// Adds constraints appropriate for an instance of `ty` appearing /// Adds constraints appropriate for an instance of `ty` appearing
/// in a context with the generics defined in `generics` and /// in a context with the generics defined in `generics` and
/// ambient variance `variance` /// ambient variance `variance`
@ -260,39 +245,31 @@ impl Context<'_> {
TyKind::Scalar(_) | TyKind::Never | TyKind::Str | TyKind::Foreign(..) => { TyKind::Scalar(_) | TyKind::Never | TyKind::Str | TyKind::Foreign(..) => {
// leaf type -- noop // leaf type -- noop
} }
TyKind::FnDef(..) | TyKind::Coroutine(..) | TyKind::Closure(..) => { TyKind::FnDef(..) | TyKind::Coroutine(..) | TyKind::Closure(..) => {
panic!("Unexpected unnameable type in variance computation: {ty:?}"); never!("Unexpected unnameable type in variance computation: {:?}", ty);
} }
TyKind::Ref(mutbl, lifetime, ty) => { TyKind::Ref(mutbl, lifetime, ty) => {
self.add_constraints_from_region(lifetime, variance); self.add_constraints_from_region(lifetime, variance);
self.add_constraints_from_mt(ty, *mutbl, variance); self.add_constraints_from_mt(ty, *mutbl, variance);
} }
TyKind::Array(typ, len) => { TyKind::Array(typ, len) => {
self.add_constraints_from_const(len, variance); self.add_constraints_from_const(len, variance);
self.add_constraints_from_ty(typ, variance); self.add_constraints_from_ty(typ, variance);
} }
TyKind::Slice(typ) => { TyKind::Slice(typ) => {
self.add_constraints_from_ty(typ, variance); self.add_constraints_from_ty(typ, variance);
} }
TyKind::Raw(mutbl, ty) => { TyKind::Raw(mutbl, ty) => {
self.add_constraints_from_mt(ty, *mutbl, variance); self.add_constraints_from_mt(ty, *mutbl, variance);
} }
TyKind::Tuple(_, subtys) => { TyKind::Tuple(_, subtys) => {
for subty in subtys.type_parameters(Interner) { for subty in subtys.type_parameters(Interner) {
self.add_constraints_from_ty(&subty, variance); self.add_constraints_from_ty(&subty, variance);
} }
} }
TyKind::Adt(def, args) => { TyKind::Adt(def, args) => {
self.add_constraints_from_args(def.0.into(), args.as_slice(Interner), variance); self.add_constraints_from_args(def.0.into(), args.as_slice(Interner), variance);
} }
TyKind::Alias(AliasTy::Opaque(opaque)) => { TyKind::Alias(AliasTy::Opaque(opaque)) => {
self.add_constraints_from_invariant_args( self.add_constraints_from_invariant_args(
opaque.substitution.as_slice(Interner), opaque.substitution.as_slice(Interner),
@ -313,7 +290,6 @@ impl Context<'_> {
TyKind::OpaqueType(_, subst) => { TyKind::OpaqueType(_, subst) => {
self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance); self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance);
} }
TyKind::Dyn(it) => { TyKind::Dyn(it) => {
// The type `dyn Trait<T> +'a` is covariant w/r/t `'a`: // The type `dyn Trait<T> +'a` is covariant w/r/t `'a`:
self.add_constraints_from_region(&it.lifetime, variance); self.add_constraints_from_region(&it.lifetime, variance);
@ -352,20 +328,33 @@ impl Context<'_> {
// Chalk has no params, so use placeholders for now? // Chalk has no params, so use placeholders for now?
TyKind::Placeholder(index) => { TyKind::Placeholder(index) => {
let idx = crate::from_placeholder_idx(self.db, *index); let idx = crate::from_placeholder_idx(self.db, *index);
let inferred = InferredIndex(self.generics.type_or_const_param_idx(idx).unwrap()); let index = self.generics.type_or_const_param_idx(idx).unwrap();
self.constrain(inferred, variance); self.constrain(index, variance);
} }
TyKind::Function(f) => { TyKind::Function(f) => {
self.add_constraints_from_sig(f, variance); self.add_constraints_from_sig(
f.substitution.0.iter(Interner).filter_map(move |p| p.ty(Interner)),
variance,
);
} }
TyKind::Error => { TyKind::Error => {
// we encounter this when walking the trait references for object // we encounter this when walking the trait references for object
// types, where we use Error as the Self type // types, where we use Error as the Self type
} }
TyKind::CoroutineWitness(..) | TyKind::BoundVar(..) | TyKind::InferenceVar(..) => { TyKind::CoroutineWitness(..) | TyKind::BoundVar(..) | TyKind::InferenceVar(..) => {
panic!("unexpected type encountered in variance inference: {:?}", ty); never!("unexpected type encountered in variance inference: {:?}", ty)
}
}
}
fn add_constraints_from_invariant_args(&mut self, args: &[GenericArg], variance: Variance) {
let variance_i = variance.invariant();
for k in args {
match k.data(Interner) {
GenericArgData::Lifetime(lt) => self.add_constraints_from_region(lt, variance_i),
GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i),
GenericArgData::Const(val) => self.add_constraints_from_const(val, variance_i),
} }
} }
} }
@ -378,13 +367,6 @@ impl Context<'_> {
args: &[GenericArg], args: &[GenericArg],
variance: Variance, variance: Variance,
) { ) {
tracing::debug!(
"add_constraints_from_args(def_id={:?}, args={:?}, variance={:?})",
def_id,
args,
variance
);
// We don't record `inferred_starts` entries for empty generics. // We don't record `inferred_starts` entries for empty generics.
if args.is_empty() { if args.is_empty() {
return; return;
@ -392,13 +374,12 @@ impl Context<'_> {
if def_id == self.generics.def() { if def_id == self.generics.def() {
// HACK: Workaround for the trivial cycle salsa case (see // HACK: Workaround for the trivial cycle salsa case (see
// recursive_one_bivariant_more_non_bivariant_params test) // recursive_one_bivariant_more_non_bivariant_params test)
let variance_i = variance.xform(Variance::Bivariant);
for k in args { for k in args {
match k.data(Interner) { match k.data(Interner) {
GenericArgData::Lifetime(lt) => { GenericArgData::Lifetime(lt) => {
self.add_constraints_from_region(lt, variance_i) self.add_constraints_from_region(lt, Variance::Bivariant)
} }
GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i), GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, Variance::Bivariant),
GenericArgData::Const(val) => self.add_constraints_from_const(val, variance), GenericArgData::Const(val) => self.add_constraints_from_const(val, variance),
} }
} }
@ -408,12 +389,13 @@ impl Context<'_> {
}; };
for (i, k) in args.iter().enumerate() { for (i, k) in args.iter().enumerate() {
let variance_i = variance.xform(variances[i]);
match k.data(Interner) { match k.data(Interner) {
GenericArgData::Lifetime(lt) => { GenericArgData::Lifetime(lt) => {
self.add_constraints_from_region(lt, variance_i) self.add_constraints_from_region(lt, variance.xform(variances[i]))
}
GenericArgData::Ty(ty) => {
self.add_constraints_from_ty(ty, variance.xform(variances[i]))
} }
GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i),
GenericArgData::Const(val) => self.add_constraints_from_const(val, variance), GenericArgData::Const(val) => self.add_constraints_from_const(val, variance),
} }
} }
@ -435,20 +417,17 @@ impl Context<'_> {
/// Adds constraints appropriate for a function with signature /// Adds constraints appropriate for a function with signature
/// `sig` appearing in a context with ambient variance `variance` /// `sig` appearing in a context with ambient variance `variance`
fn add_constraints_from_sig(&mut self, sig: &FnPointer, variance: Variance) { fn add_constraints_from_sig<'a>(
let contra = self.contravariant(variance); &mut self,
let mut tys = sig.substitution.0.iter(Interner).filter_map(move |p| p.ty(Interner)); mut sig_tys: impl DoubleEndedIterator<Item = &'a Ty>,
self.add_constraints_from_ty(tys.next_back().unwrap(), variance); variance: Variance,
for input in tys { ) {
self.add_constraints_from_ty(input, contra); let contra = variance.contravariant();
} let Some(output) = sig_tys.next_back() else {
} return never!("function signature has no return type");
};
fn add_constraints_from_sig2(&mut self, sig: &[Ty], variance: Variance) { self.add_constraints_from_ty(output, variance);
let contra = self.contravariant(variance); for input in sig_tys {
let mut tys = sig.iter();
self.add_constraints_from_ty(tys.next_back().unwrap(), variance);
for input in tys {
self.add_constraints_from_ty(input, contra); self.add_constraints_from_ty(input, contra);
} }
} }
@ -462,27 +441,23 @@ impl Context<'_> {
variance variance
); );
match region.data(Interner) { match region.data(Interner) {
// FIXME: chalk has no params?
LifetimeData::Placeholder(index) => { LifetimeData::Placeholder(index) => {
let idx = crate::lt_from_placeholder_idx(self.db, *index); let idx = crate::lt_from_placeholder_idx(self.db, *index);
let inferred = InferredIndex(self.generics.lifetime_idx(idx).unwrap()); let inferred = self.generics.lifetime_idx(idx).unwrap();
self.constrain(inferred, variance); self.constrain(inferred, variance);
} }
LifetimeData::Static => {} LifetimeData::Static => {}
LifetimeData::BoundVar(..) => { LifetimeData::BoundVar(..) => {
// Either a higher-ranked region inside of a type or a // Either a higher-ranked region inside of a type or a
// late-bound function parameter. // late-bound function parameter.
// //
// We do not compute constraints for either of these. // We do not compute constraints for either of these.
} }
LifetimeData::Error => {} LifetimeData::Error => {}
LifetimeData::Phantom(..) | LifetimeData::InferenceVar(..) | LifetimeData::Erased => { LifetimeData::Phantom(..) | LifetimeData::InferenceVar(..) | LifetimeData::Erased => {
// We don't expect to see anything but 'static or bound // We don't expect to see anything but 'static or bound
// regions when visiting member types or method types. // regions when visiting member types or method types.
panic!( never!(
"unexpected region encountered in variance \ "unexpected region encountered in variance \
inference: {:?}", inference: {:?}",
region region
@ -494,26 +469,23 @@ impl Context<'_> {
/// Adds constraints appropriate for a mutability-type pair /// Adds constraints appropriate for a mutability-type pair
/// appearing in a context with ambient variance `variance` /// appearing in a context with ambient variance `variance`
fn add_constraints_from_mt(&mut self, ty: &Ty, mt: Mutability, variance: Variance) { fn add_constraints_from_mt(&mut self, ty: &Ty, mt: Mutability, variance: Variance) {
self.add_constraints_from_ty(
ty,
match mt { match mt {
Mutability::Mut => { Mutability::Mut => variance.invariant(),
let invar = self.invariant(variance); Mutability::Not => variance,
self.add_constraints_from_ty(ty, invar); },
);
} }
Mutability::Not => { fn constrain(&mut self, index: usize, variance: Variance) {
self.add_constraints_from_ty(ty, variance);
}
}
}
fn constrain(&mut self, inferred: InferredIndex, variance: Variance) {
tracing::debug!( tracing::debug!(
"constrain(index={:?}, variance={:?}, to={:?})", "constrain(index={:?}, variance={:?}, to={:?})",
inferred, index,
self.variances[inferred.0], self.variances[index],
variance variance
); );
self.variances[inferred.0] = self.variances[inferred.0].glb(variance); self.variances[index] = self.variances[index].glb(variance);
} }
} }
@ -967,6 +939,22 @@ fn bar<'min,'max>(v: SomeStruct<&'min ()>)
); );
} }
#[test]
fn invalid_arg_counts() {
check(
r#"
struct S<T>(T);
struct S2<T>(S<>);
struct S3<T>(S<T, T>);
"#,
expect![[r#"
S[T: covariant]
S2[T: bivariant]
S3[T: covariant]
"#]],
);
}
#[test] #[test]
fn recursive_one_bivariant_more_non_bivariant_params() { fn recursive_one_bivariant_more_non_bivariant_params() {
// FIXME: This is wrong, this should be `BivariantPartialIndirect[T: bivariant, U: covariant]` (likewise for Wrapper) // FIXME: This is wrong, this should be `BivariantPartialIndirect[T: bivariant, U: covariant]` (likewise for Wrapper)