Tweak variant_id_for_adt

This commit is contained in:
Nadrieril 2024-02-06 05:01:05 +01:00
parent 66cec4d11a
commit c2d21242aa

View file

@ -75,18 +75,15 @@ impl<'p> MatchCheckCtx<'p> {
} }
} }
fn variant_id_for_adt(&self, ctor: &Constructor<Self>, adt: hir_def::AdtId) -> VariantId { fn variant_id_for_adt(ctor: &Constructor<Self>, adt: hir_def::AdtId) -> Option<VariantId> {
match ctor { match ctor {
&Variant(id) => id.into(), &Variant(id) => Some(id.into()),
Struct | UnionField => { Struct | UnionField => match adt {
assert!(!matches!(adt, hir_def::AdtId::EnumId(_))); hir_def::AdtId::EnumId(_) => None,
match adt { hir_def::AdtId::StructId(id) => Some(id.into()),
hir_def::AdtId::EnumId(_) => unreachable!(), hir_def::AdtId::UnionId(id) => Some(id.into()),
hir_def::AdtId::StructId(id) => id.into(), },
hir_def::AdtId::UnionId(id) => id.into(), _ => panic!("bad constructor {ctor:?} for adt {adt:?}"),
}
}
_ => panic!("bad constructor {self:?} for adt {adt:?}"),
} }
} }
@ -200,7 +197,7 @@ impl<'p> MatchCheckCtx<'p> {
Wildcard Wildcard
} }
}; };
let variant = self.variant_id_for_adt(&ctor, adt.0); let variant = Self::variant_id_for_adt(&ctor, adt.0).unwrap();
let fields_len = variant.variant_data(self.db.upcast()).fields().len(); let fields_len = variant.variant_data(self.db.upcast()).fields().len();
// For each field in the variant, we store the relevant index into `self.fields` if any. // For each field in the variant, we store the relevant index into `self.fields` if any.
let mut field_id_to_id: Vec<Option<usize>> = vec![None; fields_len]; let mut field_id_to_id: Vec<Option<usize>> = vec![None; fields_len];
@ -266,7 +263,7 @@ impl<'p> MatchCheckCtx<'p> {
PatKind::Deref { subpattern: subpatterns.next().unwrap() } PatKind::Deref { subpattern: subpatterns.next().unwrap() }
} }
TyKind::Adt(adt, substs) => { TyKind::Adt(adt, substs) => {
let variant = self.variant_id_for_adt(pat.ctor(), adt.0); let variant = Self::variant_id_for_adt(pat.ctor(), adt.0).unwrap();
let subpatterns = self let subpatterns = self
.list_variant_nonhidden_fields(pat.ty(), variant) .list_variant_nonhidden_fields(pat.ty(), variant)
.zip(subpatterns) .zip(subpatterns)
@ -327,7 +324,7 @@ impl<'p> TypeCx for MatchCheckCtx<'p> {
// patterns. If we're here we can assume this is a box pattern. // patterns. If we're here we can assume this is a box pattern.
1 1
} else { } else {
let variant = self.variant_id_for_adt(ctor, adt); let variant = Self::variant_id_for_adt(ctor, adt).unwrap();
self.list_variant_nonhidden_fields(ty, variant).count() self.list_variant_nonhidden_fields(ty, variant).count()
} }
} }
@ -370,7 +367,7 @@ impl<'p> TypeCx for MatchCheckCtx<'p> {
let subst_ty = substs.at(Interner, 0).assert_ty_ref(Interner).clone(); let subst_ty = substs.at(Interner, 0).assert_ty_ref(Interner).clone();
alloc(self, once(subst_ty)) alloc(self, once(subst_ty))
} else { } else {
let variant = self.variant_id_for_adt(ctor, adt); let variant = Self::variant_id_for_adt(ctor, adt).unwrap();
let tys = self.list_variant_nonhidden_fields(ty, variant).map(|(_, ty)| ty); let tys = self.list_variant_nonhidden_fields(ty, variant).map(|(_, ty)| ty);
alloc(self, tys) alloc(self, tys)
} }