Support record pattern MIR lowering

This commit is contained in:
hkalbasi 2023-03-14 17:02:38 +03:30
parent 513e340bd3
commit 051dae2221
5 changed files with 201 additions and 63 deletions

View file

@ -555,6 +555,38 @@ fn structs() {
"#, "#,
17, 17,
); );
check_number(
r#"
struct Point {
x: i32,
y: i32,
}
const GOAL: i32 = {
let p = Point { x: 5, y: 2 };
let p2 = Point { x: 3, ..p };
p.x * 1000 + p.y * 100 + p2.x * 10 + p2.y
};
"#,
5232,
);
check_number(
r#"
struct Point {
x: i32,
y: i32,
}
const GOAL: i32 = {
let p = Point { x: 5, y: 2 };
let Point { x, y } = p;
let Point { x: x2, .. } = p;
let Point { y: y2, .. } = p;
x * 1000 + y * 100 + x2 * 10 + y2
};
"#,
5252,
);
} }
#[test] #[test]
@ -599,13 +631,14 @@ fn tuples() {
); );
check_number( check_number(
r#" r#"
struct TupleLike(i32, u8, i64, u16); struct TupleLike(i32, i64, u8, u16);
const GOAL: u8 = { const GOAL: i64 = {
let a = TupleLike(10, 20, 3, 15); let a = TupleLike(10, 20, 3, 15);
a.1 let TupleLike(b, .., c) = a;
a.1 * 100 + b as i64 + c as i64
}; };
"#, "#,
20, 2025,
); );
check_number( check_number(
r#" r#"

View file

@ -711,7 +711,8 @@ pub fn is_dyn_method(
}; };
let self_ty = trait_ref.self_type_parameter(Interner); let self_ty = trait_ref.self_type_parameter(Interner);
if let TyKind::Dyn(d) = self_ty.kind(Interner) { if let TyKind::Dyn(d) = self_ty.kind(Interner) {
let is_my_trait_in_bounds = d.bounds.skip_binders().as_slice(Interner).iter().any(|x| match x.skip_binders() { let is_my_trait_in_bounds =
d.bounds.skip_binders().as_slice(Interner).iter().any(|x| match x.skip_binders() {
// rustc doesn't accept `impl Foo<2> for dyn Foo<5>`, so if the trait id is equal, no matter // rustc doesn't accept `impl Foo<2> for dyn Foo<5>`, so if the trait id is equal, no matter
// what the generics are, we are sure that the method is come from the vtable. // what the generics are, we are sure that the method is come from the vtable.
WhereClause::Implemented(tr) => tr.trait_id == trait_ref.trait_id, WhereClause::Implemented(tr) => tr.trait_id == trait_ref.trait_id,

View file

@ -25,8 +25,8 @@ use crate::{
mapping::from_chalk, mapping::from_chalk,
method_resolution::{is_dyn_method, lookup_impl_method}, method_resolution::{is_dyn_method, lookup_impl_method},
traits::FnTrait, traits::FnTrait,
CallableDefId, Const, ConstScalar, FnDefId, Interner, MemoryMap, Substitution, CallableDefId, Const, ConstScalar, FnDefId, GenericArgData, Interner, MemoryMap, Substitution,
TraitEnvironment, Ty, TyBuilder, TyExt, GenericArgData, TraitEnvironment, Ty, TyBuilder, TyExt,
}; };
use super::{ use super::{
@ -1315,10 +1315,13 @@ impl Evaluator<'_> {
args_for_target[0] = args_for_target[0][0..self.ptr_size()].to_vec(); args_for_target[0] = args_for_target[0][0..self.ptr_size()].to_vec();
let generics_for_target = Substitution::from_iter( let generics_for_target = Substitution::from_iter(
Interner, Interner,
generic_args generic_args.iter(Interner).enumerate().map(|(i, x)| {
.iter(Interner) if i == self_ty_idx {
.enumerate() &ty
.map(|(i, x)| if i == self_ty_idx { &ty } else { x }) } else {
x
}
}),
); );
return self.exec_fn_with_args( return self.exec_fn_with_args(
def, def,

View file

@ -4,16 +4,17 @@ use std::{iter, mem, sync::Arc};
use chalk_ir::{BoundVar, ConstData, DebruijnIndex, TyKind}; use chalk_ir::{BoundVar, ConstData, DebruijnIndex, TyKind};
use hir_def::{ use hir_def::{
adt::VariantData,
body::Body, body::Body,
expr::{ expr::{
Array, BindingAnnotation, BindingId, ExprId, LabelId, Literal, MatchArm, Pat, PatId, Array, BindingAnnotation, BindingId, ExprId, LabelId, Literal, MatchArm, Pat, PatId,
RecordLitField, RecordFieldPat, RecordLitField,
}, },
lang_item::{LangItem, LangItemTarget}, lang_item::{LangItem, LangItemTarget},
layout::LayoutError, layout::LayoutError,
path::Path, path::Path,
resolver::{resolver_for_expr, ResolveValueResult, ValueNs}, resolver::{resolver_for_expr, ResolveValueResult, ValueNs},
DefWithBodyId, EnumVariantId, HasModule, ItemContainerId, TraitId, DefWithBodyId, EnumVariantId, HasModule, ItemContainerId, LocalFieldId, TraitId,
}; };
use hir_expand::name::Name; use hir_expand::name::Name;
use la_arena::ArenaMap; use la_arena::ArenaMap;
@ -106,6 +107,12 @@ impl MirLowerError {
type Result<T> = std::result::Result<T, MirLowerError>; type Result<T> = std::result::Result<T, MirLowerError>;
enum AdtPatternShape<'a> {
Tuple { args: &'a [PatId], ellipsis: Option<usize> },
Record { args: &'a [RecordFieldPat] },
Unit,
}
impl MirLowerCtx<'_> { impl MirLowerCtx<'_> {
fn temp(&mut self, ty: Ty) -> Result<LocalId> { fn temp(&mut self, ty: Ty) -> Result<LocalId> {
if matches!(ty.kind(Interner), TyKind::Slice(_) | TyKind::Dyn(_)) { if matches!(ty.kind(Interner), TyKind::Slice(_) | TyKind::Dyn(_)) {
@ -444,7 +451,8 @@ impl MirLowerCtx<'_> {
current, current,
pat.into(), pat.into(),
Some(end), Some(end),
&[pat], &None)?; AdtPatternShape::Tuple { args: &[pat], ellipsis: None },
)?;
if let Some((_, block)) = this.lower_expr_as_place(current, body, true)? { if let Some((_, block)) = this.lower_expr_as_place(current, body, true)? {
this.set_goto(block, begin); this.set_goto(block, begin);
} }
@ -573,7 +581,17 @@ impl MirLowerCtx<'_> {
Ok(None) Ok(None)
} }
Expr::Yield { .. } => not_supported!("yield"), Expr::Yield { .. } => not_supported!("yield"),
Expr::RecordLit { fields, path, .. } => { Expr::RecordLit { fields, path, spread, ellipsis: _, is_assignee_expr: _ } => {
let spread_place = match spread {
&Some(x) => {
let Some((p, c)) = self.lower_expr_as_place(current, x, true)? else {
return Ok(None);
};
current = c;
Some(p)
},
None => None,
};
let variant_id = self let variant_id = self
.infer .infer
.variant_resolution_for_expr(expr_id) .variant_resolution_for_expr(expr_id)
@ -603,9 +621,24 @@ impl MirLowerCtx<'_> {
place, place,
Rvalue::Aggregate( Rvalue::Aggregate(
AggregateKind::Adt(variant_id, subst), AggregateKind::Adt(variant_id, subst),
operands.into_iter().map(|x| x).collect::<Option<_>>().ok_or( match spread_place {
Some(sp) => operands.into_iter().enumerate().map(|(i, x)| {
match x {
Some(x) => x,
None => {
let mut p = sp.clone();
p.projection.push(ProjectionElem::Field(FieldId {
parent: variant_id,
local_id: LocalFieldId::from_raw(RawIdx::from(i as u32)),
}));
Operand::Copy(p)
},
}
}).collect(),
None => operands.into_iter().map(|x| x).collect::<Option<_>>().ok_or(
MirLowerError::TypeError("missing field in record literal"), MirLowerError::TypeError("missing field in record literal"),
)?, )?,
},
), ),
expr_id.into(), expr_id.into(),
); );
@ -1021,14 +1054,11 @@ impl MirLowerCtx<'_> {
self.pattern_match_tuple_like( self.pattern_match_tuple_like(
current, current,
current_else, current_else,
args.iter().enumerate().map(|(i, x)| { args,
(
PlaceElem::TupleField(i),
*x,
subst.at(Interner, i).assert_ty_ref(Interner).clone(),
)
}),
*ellipsis, *ellipsis,
subst.iter(Interner).enumerate().map(|(i, x)| {
(PlaceElem::TupleField(i), x.assert_ty_ref(Interner).clone())
}),
&cond_place, &cond_place,
binding_mode, binding_mode,
)? )?
@ -1062,7 +1092,21 @@ impl MirLowerCtx<'_> {
} }
(then_target, current_else) (then_target, current_else)
} }
Pat::Record { .. } => not_supported!("record pattern"), Pat::Record { args, .. } => {
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
not_supported!("unresolved variant");
};
self.pattern_matching_variant(
cond_ty,
binding_mode,
cond_place,
variant,
current,
pattern.into(),
current_else,
AdtPatternShape::Record { args: &*args },
)?
}
Pat::Range { .. } => not_supported!("range pattern"), Pat::Range { .. } => not_supported!("range pattern"),
Pat::Slice { .. } => not_supported!("slice pattern"), Pat::Slice { .. } => not_supported!("slice pattern"),
Pat::Path(_) => { Pat::Path(_) => {
@ -1077,8 +1121,7 @@ impl MirLowerCtx<'_> {
current, current,
pattern.into(), pattern.into(),
current_else, current_else,
&[], AdtPatternShape::Unit,
&None,
)? )?
} }
Pat::Lit(l) => { Pat::Lit(l) => {
@ -1160,8 +1203,7 @@ impl MirLowerCtx<'_> {
current, current,
pattern.into(), pattern.into(),
current_else, current_else,
args, AdtPatternShape::Tuple { args, ellipsis: *ellipsis },
ellipsis,
)? )?
} }
Pat::Ref { .. } => not_supported!("& pattern"), Pat::Ref { .. } => not_supported!("& pattern"),
@ -1179,15 +1221,13 @@ impl MirLowerCtx<'_> {
current: BasicBlockId, current: BasicBlockId,
span: MirSpan, span: MirSpan,
current_else: Option<BasicBlockId>, current_else: Option<BasicBlockId>,
args: &[PatId], shape: AdtPatternShape<'_>,
ellipsis: &Option<usize>,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> { ) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place); pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place);
let subst = match cond_ty.kind(Interner) { let subst = match cond_ty.kind(Interner) {
TyKind::Adt(_, s) => s, TyKind::Adt(_, s) => s,
_ => return Err(MirLowerError::TypeError("non adt type matched with tuple struct")), _ => return Err(MirLowerError::TypeError("non adt type matched with tuple struct")),
}; };
let fields_type = self.db.field_types(variant);
Ok(match variant { Ok(match variant {
VariantId::EnumVariantId(v) => { VariantId::EnumVariantId(v) => {
let e = self.db.const_eval_discriminant(v)? as u128; let e = self.db.const_eval_discriminant(v)? as u128;
@ -1208,35 +1248,26 @@ impl MirLowerCtx<'_> {
}, },
); );
let enum_data = self.db.enum_data(v.parent); let enum_data = self.db.enum_data(v.parent);
let fields = self.pattern_matching_variant_fields(
enum_data.variants[v.local_id].variant_data.fields().iter().map(|(x, _)| { shape,
( &enum_data.variants[v.local_id].variant_data,
PlaceElem::Field(FieldId { parent: v.into(), local_id: x }), variant,
fields_type[x].clone().substitute(Interner, subst), subst,
)
});
self.pattern_match_tuple_like(
next, next,
Some(else_target), Some(else_target),
args.iter().zip(fields).map(|(x, y)| (y.0, *x, y.1)),
*ellipsis,
&cond_place, &cond_place,
binding_mode, binding_mode,
)? )?
} }
VariantId::StructId(s) => { VariantId::StructId(s) => {
let struct_data = self.db.struct_data(s); let struct_data = self.db.struct_data(s);
let fields = struct_data.variant_data.fields().iter().map(|(x, _)| { self.pattern_matching_variant_fields(
( shape,
PlaceElem::Field(FieldId { parent: s.into(), local_id: x }), &struct_data.variant_data,
fields_type[x].clone().substitute(Interner, subst), variant,
) subst,
});
self.pattern_match_tuple_like(
current, current,
current_else, current_else,
args.iter().zip(fields).map(|(x, y)| (y.0, *x, y.1)),
*ellipsis,
&cond_place, &cond_place,
binding_mode, binding_mode,
)? )?
@ -1247,18 +1278,69 @@ impl MirLowerCtx<'_> {
}) })
} }
fn pattern_match_tuple_like( fn pattern_matching_variant_fields(
&mut self,
shape: AdtPatternShape<'_>,
variant_data: &VariantData,
v: VariantId,
subst: &Substitution,
current: BasicBlockId,
current_else: Option<BasicBlockId>,
cond_place: &Place,
binding_mode: BindingAnnotation,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
let fields_type = self.db.field_types(v);
Ok(match shape {
AdtPatternShape::Record { args } => {
let it = args
.iter()
.map(|x| {
let field_id =
variant_data.field(&x.name).ok_or(MirLowerError::UnresolvedField)?;
Ok((
PlaceElem::Field(FieldId { parent: v.into(), local_id: field_id }),
x.pat,
fields_type[field_id].clone().substitute(Interner, subst),
))
})
.collect::<Result<Vec<_>>>()?;
self.pattern_match_adt(
current,
current_else,
it.into_iter(),
cond_place,
binding_mode,
)?
}
AdtPatternShape::Tuple { args, ellipsis } => {
let fields = variant_data.fields().iter().map(|(x, _)| {
(
PlaceElem::Field(FieldId { parent: v.into(), local_id: x }),
fields_type[x].clone().substitute(Interner, subst),
)
});
self.pattern_match_tuple_like(
current,
current_else,
args,
ellipsis,
fields,
cond_place,
binding_mode,
)?
}
AdtPatternShape::Unit => (current, current_else),
})
}
fn pattern_match_adt(
&mut self, &mut self,
mut current: BasicBlockId, mut current: BasicBlockId,
mut current_else: Option<BasicBlockId>, mut current_else: Option<BasicBlockId>,
args: impl Iterator<Item = (PlaceElem, PatId, Ty)>, args: impl Iterator<Item = (PlaceElem, PatId, Ty)>,
ellipsis: Option<usize>,
cond_place: &Place, cond_place: &Place,
binding_mode: BindingAnnotation, binding_mode: BindingAnnotation,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> { ) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
if ellipsis.is_some() {
not_supported!("tuple like pattern with ellipsis");
}
for (proj, arg, ty) in args { for (proj, arg, ty) in args {
let mut cond_place = cond_place.clone(); let mut cond_place = cond_place.clone();
cond_place.projection.push(proj); cond_place.projection.push(proj);
@ -1268,6 +1350,25 @@ impl MirLowerCtx<'_> {
Ok((current, current_else)) Ok((current, current_else))
} }
fn pattern_match_tuple_like(
&mut self,
current: BasicBlockId,
current_else: Option<BasicBlockId>,
args: &[PatId],
ellipsis: Option<usize>,
fields: impl DoubleEndedIterator<Item = (PlaceElem, Ty)> + Clone,
cond_place: &Place,
binding_mode: BindingAnnotation,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len()));
let it = al
.iter()
.zip(fields.clone())
.chain(ar.iter().rev().zip(fields.rev()))
.map(|(x, y)| (y.0, *x, y.1));
self.pattern_match_adt(current, current_else, it, cond_place, binding_mode)
}
fn discr_temp_place(&mut self) -> Place { fn discr_temp_place(&mut self) -> Place {
match &self.discr_temp { match &self.discr_temp {
Some(x) => x.clone(), Some(x) => x.clone(),

View file

@ -295,7 +295,7 @@ impl<T> Arena<T> {
/// ``` /// ```
pub fn iter( pub fn iter(
&self, &self,
) -> impl Iterator<Item = (Idx<T>, &T)> + ExactSizeIterator + DoubleEndedIterator { ) -> impl Iterator<Item = (Idx<T>, &T)> + ExactSizeIterator + DoubleEndedIterator + Clone {
self.data.iter().enumerate().map(|(idx, value)| (Idx::from_raw(RawIdx(idx as u32)), value)) self.data.iter().enumerate().map(|(idx, value)| (Idx::from_raw(RawIdx(idx as u32)), value))
} }