diff --git a/crates/hir-ty/src/variance.rs b/crates/hir-ty/src/variance.rs index 0cce1aec2b..64286121b6 100644 --- a/crates/hir-ty/src/variance.rs +++ b/crates/hir-ty/src/variance.rs @@ -6,8 +6,8 @@ use crate::db::HirDatabase; use crate::generics::{generics, Generics}; use crate::{ - AliasTy, Const, ConstScalar, DynTyExt, FnPointer, GenericArg, GenericArgData, Interner, - Lifetime, LifetimeData, Ty, TyKind, + AliasTy, Const, ConstScalar, DynTyExt, GenericArg, GenericArgData, Interner, Lifetime, + LifetimeData, Ty, TyKind, }; use base_db::ra_salsa::Cycle; use chalk_ir::Mutability; @@ -15,6 +15,7 @@ use hir_def::data::adt::StructFlags; use hir_def::{AdtId, GenericDefId, GenericParamId, VariantId}; use std::fmt; use std::ops::Not; +use stdx::never; use triomphe::Arc; pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> Option> { @@ -156,9 +157,19 @@ impl Variance { (x, Variance::Bivariant) | (Variance::Bivariant, x) => x, } } + + pub fn invariant(self) -> Self { + self.xform(Variance::Invariant) + } + + pub fn covariant(self) -> Self { + self.xform(Variance::Covariant) + } + + pub fn contravariant(self) -> Self { + self.xform(Variance::Contravariant) + } } -#[derive(Copy, Clone, Debug)] -struct InferredIndex(usize); struct Context<'db> { db: &'db dyn HirDatabase, @@ -193,12 +204,12 @@ impl Context<'_> { } GenericDefId::FunctionId(f) => { let subst = self.generics.placeholder_subst(self.db); - self.add_constraints_from_sig2( - &self - .db + self.add_constraints_from_sig( + self.db .callable_item_signature(f.into()) .substitute(Interner, &subst) - .params_and_return, + .params_and_return + .iter(), Variance::Covariant, ); } @@ -216,41 +227,15 @@ impl Context<'_> { // Functions are permitted to have unused generic parameters: make those invariant. if let GenericDefId::FunctionId(_) = self.generics.def() { - for variance in &mut variances { - if *variance == Variance::Bivariant { - *variance = Variance::Invariant; - } - } + variances + .iter_mut() + .filter(|&&mut v| v == Variance::Bivariant) + .for_each(|v| *v = Variance::Invariant); } variances } - fn contravariant(&mut self, variance: Variance) -> Variance { - variance.xform(Variance::Contravariant) - } - - fn invariant(&mut self, variance: Variance) -> Variance { - variance.xform(Variance::Invariant) - } - - fn add_constraints_from_invariant_args(&mut self, args: &[GenericArg], variance: Variance) { - tracing::debug!( - "add_constraints_from_invariant_args(args={:?}, variance={:?})", - args, - variance - ); - let variance_i = self.invariant(variance); - - for k in args { - match k.data(Interner) { - GenericArgData::Lifetime(lt) => self.add_constraints_from_region(lt, variance_i), - GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i), - GenericArgData::Const(val) => self.add_constraints_from_const(val, variance_i), - } - } - } - /// Adds constraints appropriate for an instance of `ty` appearing /// in a context with the generics defined in `generics` and /// ambient variance `variance` @@ -260,39 +245,31 @@ impl Context<'_> { TyKind::Scalar(_) | TyKind::Never | TyKind::Str | TyKind::Foreign(..) => { // leaf type -- noop } - TyKind::FnDef(..) | TyKind::Coroutine(..) | TyKind::Closure(..) => { - panic!("Unexpected unnameable type in variance computation: {ty:?}"); + never!("Unexpected unnameable type in variance computation: {:?}", ty); } - TyKind::Ref(mutbl, lifetime, ty) => { self.add_constraints_from_region(lifetime, variance); self.add_constraints_from_mt(ty, *mutbl, variance); } - TyKind::Array(typ, len) => { self.add_constraints_from_const(len, variance); self.add_constraints_from_ty(typ, variance); } - TyKind::Slice(typ) => { self.add_constraints_from_ty(typ, variance); } - TyKind::Raw(mutbl, ty) => { self.add_constraints_from_mt(ty, *mutbl, variance); } - TyKind::Tuple(_, subtys) => { for subty in subtys.type_parameters(Interner) { self.add_constraints_from_ty(&subty, variance); } } - TyKind::Adt(def, args) => { self.add_constraints_from_args(def.0.into(), args.as_slice(Interner), variance); } - TyKind::Alias(AliasTy::Opaque(opaque)) => { self.add_constraints_from_invariant_args( opaque.substitution.as_slice(Interner), @@ -313,7 +290,6 @@ impl Context<'_> { TyKind::OpaqueType(_, subst) => { self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance); } - TyKind::Dyn(it) => { // The type `dyn Trait +'a` is covariant w/r/t `'a`: self.add_constraints_from_region(&it.lifetime, variance); @@ -352,20 +328,33 @@ impl Context<'_> { // Chalk has no params, so use placeholders for now? TyKind::Placeholder(index) => { let idx = crate::from_placeholder_idx(self.db, *index); - let inferred = InferredIndex(self.generics.type_or_const_param_idx(idx).unwrap()); - self.constrain(inferred, variance); + let index = self.generics.type_or_const_param_idx(idx).unwrap(); + self.constrain(index, variance); } TyKind::Function(f) => { - self.add_constraints_from_sig(f, variance); + self.add_constraints_from_sig( + f.substitution.0.iter(Interner).filter_map(move |p| p.ty(Interner)), + variance, + ); } - TyKind::Error => { // we encounter this when walking the trait references for object // types, where we use Error as the Self type } - TyKind::CoroutineWitness(..) | TyKind::BoundVar(..) | TyKind::InferenceVar(..) => { - panic!("unexpected type encountered in variance inference: {:?}", ty); + never!("unexpected type encountered in variance inference: {:?}", ty) + } + } + } + + fn add_constraints_from_invariant_args(&mut self, args: &[GenericArg], variance: Variance) { + let variance_i = variance.invariant(); + + for k in args { + match k.data(Interner) { + GenericArgData::Lifetime(lt) => self.add_constraints_from_region(lt, variance_i), + GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i), + GenericArgData::Const(val) => self.add_constraints_from_const(val, variance_i), } } } @@ -378,13 +367,6 @@ impl Context<'_> { args: &[GenericArg], variance: Variance, ) { - tracing::debug!( - "add_constraints_from_args(def_id={:?}, args={:?}, variance={:?})", - def_id, - args, - variance - ); - // We don't record `inferred_starts` entries for empty generics. if args.is_empty() { return; @@ -392,13 +374,12 @@ impl Context<'_> { if def_id == self.generics.def() { // HACK: Workaround for the trivial cycle salsa case (see // recursive_one_bivariant_more_non_bivariant_params test) - let variance_i = variance.xform(Variance::Bivariant); for k in args { match k.data(Interner) { GenericArgData::Lifetime(lt) => { - self.add_constraints_from_region(lt, variance_i) + self.add_constraints_from_region(lt, Variance::Bivariant) } - GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i), + GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, Variance::Bivariant), GenericArgData::Const(val) => self.add_constraints_from_const(val, variance), } } @@ -408,12 +389,13 @@ impl Context<'_> { }; for (i, k) in args.iter().enumerate() { - let variance_i = variance.xform(variances[i]); match k.data(Interner) { GenericArgData::Lifetime(lt) => { - self.add_constraints_from_region(lt, variance_i) + self.add_constraints_from_region(lt, variance.xform(variances[i])) + } + GenericArgData::Ty(ty) => { + self.add_constraints_from_ty(ty, variance.xform(variances[i])) } - GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i), GenericArgData::Const(val) => self.add_constraints_from_const(val, variance), } } @@ -435,20 +417,17 @@ impl Context<'_> { /// Adds constraints appropriate for a function with signature /// `sig` appearing in a context with ambient variance `variance` - fn add_constraints_from_sig(&mut self, sig: &FnPointer, variance: Variance) { - let contra = self.contravariant(variance); - let mut tys = sig.substitution.0.iter(Interner).filter_map(move |p| p.ty(Interner)); - self.add_constraints_from_ty(tys.next_back().unwrap(), variance); - for input in tys { - self.add_constraints_from_ty(input, contra); - } - } - - fn add_constraints_from_sig2(&mut self, sig: &[Ty], variance: Variance) { - let contra = self.contravariant(variance); - let mut tys = sig.iter(); - self.add_constraints_from_ty(tys.next_back().unwrap(), variance); - for input in tys { + fn add_constraints_from_sig<'a>( + &mut self, + mut sig_tys: impl DoubleEndedIterator, + variance: Variance, + ) { + let contra = variance.contravariant(); + let Some(output) = sig_tys.next_back() else { + return never!("function signature has no return type"); + }; + self.add_constraints_from_ty(output, variance); + for input in sig_tys { self.add_constraints_from_ty(input, contra); } } @@ -462,27 +441,23 @@ impl Context<'_> { variance ); match region.data(Interner) { - // FIXME: chalk has no params? LifetimeData::Placeholder(index) => { let idx = crate::lt_from_placeholder_idx(self.db, *index); - let inferred = InferredIndex(self.generics.lifetime_idx(idx).unwrap()); + let inferred = self.generics.lifetime_idx(idx).unwrap(); self.constrain(inferred, variance); } LifetimeData::Static => {} - LifetimeData::BoundVar(..) => { // Either a higher-ranked region inside of a type or a // late-bound function parameter. // // We do not compute constraints for either of these. } - LifetimeData::Error => {} - LifetimeData::Phantom(..) | LifetimeData::InferenceVar(..) | LifetimeData::Erased => { // We don't expect to see anything but 'static or bound // regions when visiting member types or method types. - panic!( + never!( "unexpected region encountered in variance \ inference: {:?}", region @@ -494,26 +469,23 @@ impl Context<'_> { /// Adds constraints appropriate for a mutability-type pair /// appearing in a context with ambient variance `variance` fn add_constraints_from_mt(&mut self, ty: &Ty, mt: Mutability, variance: Variance) { - match mt { - Mutability::Mut => { - let invar = self.invariant(variance); - self.add_constraints_from_ty(ty, invar); - } - - Mutability::Not => { - self.add_constraints_from_ty(ty, variance); - } - } + self.add_constraints_from_ty( + ty, + match mt { + Mutability::Mut => variance.invariant(), + Mutability::Not => variance, + }, + ); } - fn constrain(&mut self, inferred: InferredIndex, variance: Variance) { + fn constrain(&mut self, index: usize, variance: Variance) { tracing::debug!( "constrain(index={:?}, variance={:?}, to={:?})", - inferred, - self.variances[inferred.0], + index, + self.variances[index], variance ); - self.variances[inferred.0] = self.variances[inferred.0].glb(variance); + self.variances[index] = self.variances[index].glb(variance); } } @@ -967,6 +939,22 @@ fn bar<'min,'max>(v: SomeStruct<&'min ()>) ); } + #[test] + fn invalid_arg_counts() { + check( + r#" +struct S(T); +struct S2(S<>); +struct S3(S); +"#, + expect![[r#" + S[T: covariant] + S2[T: bivariant] + S3[T: covariant] + "#]], + ); + } + #[test] fn recursive_one_bivariant_more_non_bivariant_params() { // FIXME: This is wrong, this should be `BivariantPartialIndirect[T: bivariant, U: covariant]` (likewise for Wrapper)