Improve pattern matching MIR lowering

This commit is contained in:
hkalbasi 2023-03-14 23:01:46 +03:30
parent 051dae2221
commit 9564773d5e
9 changed files with 590 additions and 395 deletions

View file

@ -1030,9 +1030,16 @@ impl ExprCollector<'_> {
.collect(),
}
}
ast::Pat::LiteralPat(lit) => {
ast::Pat::LiteralPat(lit) => 'b: {
if let Some(ast_lit) = lit.literal() {
let expr = Expr::Literal(ast_lit.kind().into());
let mut hir_lit: Literal = ast_lit.kind().into();
if lit.minus_token().is_some() {
let Some(h) = hir_lit.negate() else {
break 'b Pat::Missing;
};
hir_lit = h;
}
let expr = Expr::Literal(hir_lit);
let expr_ptr = AstPtr::new(&ast::Expr::Literal(ast_lit));
let expr_id = self.alloc_expr(expr, expr_ptr);
Pat::Lit(expr_id)
@ -1144,11 +1151,11 @@ impl From<ast::LiteralKind> for Literal {
FloatTypeWrapper::new(lit.float_value().unwrap_or(Default::default())),
builtin,
)
} else if let builtin @ Some(_) = lit.suffix().and_then(BuiltinInt::from_suffix) {
Literal::Int(lit.value().unwrap_or(0) as i128, builtin)
} else {
let builtin = lit.suffix().and_then(BuiltinUint::from_suffix);
} else if let builtin @ Some(_) = lit.suffix().and_then(BuiltinUint::from_suffix) {
Literal::Uint(lit.value().unwrap_or(0), builtin)
} else {
let builtin = lit.suffix().and_then(BuiltinInt::from_suffix);
Literal::Int(lit.value().unwrap_or(0) as i128, builtin)
}
}
LiteralKind::FloatNumber(lit) => {

View file

@ -92,6 +92,16 @@ pub enum Literal {
Float(FloatTypeWrapper, Option<BuiltinFloat>),
}
impl Literal {
pub fn negate(self) -> Option<Self> {
if let Literal::Int(i, k) = self {
Some(Literal::Int(-i, k))
} else {
None
}
}
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Expr {
/// This is produced if the syntax tree does not have a required expression piece.

View file

@ -685,6 +685,36 @@ fn path_pattern_matching() {
);
}
#[test]
fn pattern_matching_literal() {
check_number(
r#"
const fn f(x: i32) -> i32 {
match x {
-1 => 1,
1 => 10,
_ => 100,
}
}
const GOAL: i32 = f(-1) + f(1) + f(0) + f(-5);
"#,
211
);
check_number(
r#"
const fn f(x: &str) -> u8 {
match x {
"foo" => 1,
"bar" => 10,
_ => 100,
}
}
const GOAL: u8 = f("foo") + f("bar");
"#,
11
);
}
#[test]
fn pattern_matching_ergonomics() {
check_number(
@ -698,6 +728,16 @@ fn pattern_matching_ergonomics() {
"#,
5,
);
check_number(
r#"
const GOAL: u8 = {
let a = &(2, 3);
let &(x, y) = a;
x + y
};
"#,
5,
);
}
#[test]
@ -781,6 +821,33 @@ fn function_param_patterns() {
);
}
#[test]
fn match_guards() {
check_number(
r#"
//- minicore: option, eq
impl<T: PartialEq> PartialEq for Option<T> {
fn eq(&self, other: &Rhs) -> bool {
match (self, other) {
(Some(x), Some(y)) => x == y,
(None, None) => true,
_ => false,
}
}
}
fn f(x: Option<i32>) -> i32 {
match x {
y if y == Some(42) => 42000,
Some(y) => y,
None => 10
}
}
const GOAL: i32 = f(Some(42)) + f(Some(2)) + f(None);
"#,
42012,
);
}
#[test]
fn options() {
check_number(
@ -983,6 +1050,51 @@ fn function_pointer() {
);
}
#[test]
fn enum_variant_as_function() {
check_number(
r#"
//- minicore: option
const GOAL: u8 = {
let f = Some;
f(3).unwrap_or(2)
};
"#,
3,
);
check_number(
r#"
//- minicore: option
const GOAL: u8 = {
let f: fn(u8) -> Option<u8> = Some;
f(3).unwrap_or(2)
};
"#,
3,
);
check_number(
r#"
//- minicore: coerce_unsized, index, slice
enum Foo {
Add2(u8),
Mult3(u8),
}
use Foo::*;
const fn f(x: Foo) -> u8 {
match x {
Add2(x) => x + 2,
Mult3(x) => x * 3,
}
}
const GOAL: u8 = {
let x = [Add2, Mult3];
f(x[0](1)) + f(x[1](5))
};
"#,
18,
);
}
#[test]
fn function_traits() {
check_number(

View file

@ -423,6 +423,7 @@ impl Evaluator<'_> {
args: impl Iterator<Item = Vec<u8>>,
subst: Substitution,
) -> Result<Vec<u8>> {
dbg!(body.dbg(self.db));
if let Some(x) = self.stack_depth_limit.checked_sub(1) {
self.stack_depth_limit = x;
} else {
@ -581,7 +582,14 @@ impl Evaluator<'_> {
let mut ty = self.operand_ty(lhs, locals)?;
while let TyKind::Ref(_, _, z) = ty.kind(Interner) {
ty = z.clone();
let size = self.size_of_sized(&ty, locals, "operand of binary op")?;
let size = if ty.kind(Interner) == &TyKind::Str {
let ns = from_bytes!(usize, &lc[self.ptr_size()..self.ptr_size() * 2]);
lc = &lc[..self.ptr_size()];
rc = &rc[..self.ptr_size()];
ns
} else {
self.size_of_sized(&ty, locals, "operand of binary op")?
};
lc = self.read_memory(Address::from_bytes(lc)?, size)?;
rc = self.read_memory(Address::from_bytes(rc)?, size)?;
}

View file

@ -4,7 +4,7 @@ use std::{iter, mem, sync::Arc};
use chalk_ir::{BoundVar, ConstData, DebruijnIndex, TyKind};
use hir_def::{
adt::VariantData,
adt::{VariantData, StructKind},
body::Body,
expr::{
Array, BindingAnnotation, BindingId, ExprId, LabelId, Literal, MatchArm, Pat, PatId,
@ -28,6 +28,9 @@ use crate::{
use super::*;
mod as_place;
mod pattern_matching;
use pattern_matching::AdtPatternShape;
#[derive(Debug, Clone, Copy)]
struct LoopBlocks {
@ -107,12 +110,6 @@ impl MirLowerError {
type Result<T> = std::result::Result<T, MirLowerError>;
enum AdtPatternShape<'a> {
Tuple { args: &'a [PatId], ellipsis: Option<usize> },
Record { args: &'a [RecordFieldPat] },
Unit,
}
impl MirLowerCtx<'_> {
fn temp(&mut self, ty: Ty) -> Result<LocalId> {
if matches!(ty.kind(Interner), TyKind::Slice(_) | TyKind::Dyn(_)) {
@ -275,15 +272,19 @@ impl MirLowerCtx<'_> {
Ok(Some(current))
}
ValueNs::EnumVariantId(variant_id) => {
let ty = self.infer.type_of_expr[expr_id].clone();
let current = self.lower_enum_variant(
variant_id,
current,
place,
ty,
vec![],
expr_id.into(),
)?;
let variant_data = &self.db.enum_data(variant_id.parent).variants[variant_id.local_id];
if variant_data.variant_data.kind() == StructKind::Unit {
let ty = self.infer.type_of_expr[expr_id].clone();
current = self.lower_enum_variant(
variant_id,
current,
place,
ty,
vec![],
expr_id.into(),
)?;
}
// Otherwise its a tuple like enum, treated like a zero sized function, so no action is needed
Ok(Some(current))
}
ValueNs::GenericParam(p) => {
@ -517,10 +518,7 @@ impl MirLowerCtx<'_> {
let cond_ty = self.expr_ty_after_adjustments(*expr);
let mut end = None;
for MatchArm { pat, guard, expr } in arms.iter() {
if guard.is_some() {
not_supported!("pattern matching with guard");
}
let (then, otherwise) = self.pattern_match(
let (then, mut otherwise) = self.pattern_match(
current,
None,
cond_place.clone(),
@ -528,6 +526,16 @@ impl MirLowerCtx<'_> {
*pat,
BindingAnnotation::Unannotated,
)?;
let then = if let &Some(guard) = guard {
let next = self.new_basic_block();
let o = otherwise.get_or_insert_with(|| self.new_basic_block());
if let Some((discr, c)) = self.lower_expr_to_some_operand(guard, then)? {
self.set_terminator(c, Terminator::SwitchInt { discr, targets: SwitchTargets::static_if(1, next, *o) });
}
next
} else {
then
};
if let Some(block) = self.lower_expr_to_place(*expr, place.clone(), then)? {
let r = end.get_or_insert_with(|| self.new_basic_block());
self.set_goto(block, *r);
@ -922,7 +930,7 @@ impl MirLowerCtx<'_> {
) -> Result<BasicBlockId> {
let subst = match ty.kind(Interner) {
TyKind::Adt(_, subst) => subst.clone(),
_ => not_supported!("Non ADT enum"),
_ => implementation_error!("Non ADT enum"),
};
self.push_assignment(
prev_block,
@ -1020,355 +1028,6 @@ impl MirLowerCtx<'_> {
self.push_statement(block, StatementKind::Assign(place, rvalue).with_span(span));
}
/// It gets a `current` unterminated block, appends some statements and possibly a terminator to it to check if
/// the pattern matches and write bindings, and returns two unterminated blocks, one for the matched path (which
/// can be the `current` block) and one for the mismatched path. If the input pattern is irrefutable, the
/// mismatched path block is `None`.
///
/// By default, it will create a new block for mismatched path. If you already have one, you can provide it with
/// `current_else` argument to save an unneccessary jump. If `current_else` isn't `None`, the result mismatched path
/// wouldn't be `None` as well. Note that this function will add jumps to the beginning of the `current_else` block,
/// so it should be an empty block.
fn pattern_match(
&mut self,
mut current: BasicBlockId,
mut current_else: Option<BasicBlockId>,
mut cond_place: Place,
mut cond_ty: Ty,
pattern: PatId,
mut binding_mode: BindingAnnotation,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
Ok(match &self.body.pats[pattern] {
Pat::Missing => return Err(MirLowerError::IncompleteExpr),
Pat::Wild => (current, current_else),
Pat::Tuple { args, ellipsis } => {
pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place);
let subst = match cond_ty.kind(Interner) {
TyKind::Tuple(_, s) => s,
_ => {
return Err(MirLowerError::TypeError(
"non tuple type matched with tuple pattern",
))
}
};
self.pattern_match_tuple_like(
current,
current_else,
args,
*ellipsis,
subst.iter(Interner).enumerate().map(|(i, x)| {
(PlaceElem::TupleField(i), x.assert_ty_ref(Interner).clone())
}),
&cond_place,
binding_mode,
)?
}
Pat::Or(pats) => {
let then_target = self.new_basic_block();
let mut finished = false;
for pat in &**pats {
let (next, next_else) = self.pattern_match(
current,
None,
cond_place.clone(),
cond_ty.clone(),
*pat,
binding_mode,
)?;
self.set_goto(next, then_target);
match next_else {
Some(t) => {
current = t;
}
None => {
finished = true;
break;
}
}
}
if !finished {
let ce = *current_else.get_or_insert_with(|| self.new_basic_block());
self.set_goto(current, ce);
}
(then_target, current_else)
}
Pat::Record { args, .. } => {
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
not_supported!("unresolved variant");
};
self.pattern_matching_variant(
cond_ty,
binding_mode,
cond_place,
variant,
current,
pattern.into(),
current_else,
AdtPatternShape::Record { args: &*args },
)?
}
Pat::Range { .. } => not_supported!("range pattern"),
Pat::Slice { .. } => not_supported!("slice pattern"),
Pat::Path(_) => {
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
not_supported!("unresolved variant");
};
self.pattern_matching_variant(
cond_ty,
binding_mode,
cond_place,
variant,
current,
pattern.into(),
current_else,
AdtPatternShape::Unit,
)?
}
Pat::Lit(l) => {
let then_target = self.new_basic_block();
let else_target = current_else.unwrap_or_else(|| self.new_basic_block());
match &self.body.exprs[*l] {
Expr::Literal(l) => match l {
hir_def::expr::Literal::Int(x, _) => {
self.set_terminator(
current,
Terminator::SwitchInt {
discr: Operand::Copy(cond_place),
targets: SwitchTargets::static_if(
*x as u128,
then_target,
else_target,
),
},
);
}
hir_def::expr::Literal::Uint(x, _) => {
self.set_terminator(
current,
Terminator::SwitchInt {
discr: Operand::Copy(cond_place),
targets: SwitchTargets::static_if(*x, then_target, else_target),
},
);
}
_ => not_supported!("non int path literal"),
},
_ => not_supported!("expression path literal"),
}
(then_target, Some(else_target))
}
Pat::Bind { id, subpat } => {
let target_place = self.result.binding_locals[*id];
let mode = self.body.bindings[*id].mode;
if let Some(subpat) = subpat {
(current, current_else) = self.pattern_match(
current,
current_else,
cond_place.clone(),
cond_ty,
*subpat,
binding_mode,
)?
}
if matches!(mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) {
binding_mode = mode;
}
self.push_storage_live(*id, current);
self.push_assignment(
current,
target_place.into(),
match binding_mode {
BindingAnnotation::Unannotated | BindingAnnotation::Mutable => {
Operand::Copy(cond_place).into()
}
BindingAnnotation::Ref => Rvalue::Ref(BorrowKind::Shared, cond_place),
BindingAnnotation::RefMut => Rvalue::Ref(
BorrowKind::Mut { allow_two_phase_borrow: false },
cond_place,
),
},
pattern.into(),
);
(current, current_else)
}
Pat::TupleStruct { path: _, args, ellipsis } => {
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
not_supported!("unresolved variant");
};
self.pattern_matching_variant(
cond_ty,
binding_mode,
cond_place,
variant,
current,
pattern.into(),
current_else,
AdtPatternShape::Tuple { args, ellipsis: *ellipsis },
)?
}
Pat::Ref { .. } => not_supported!("& pattern"),
Pat::Box { .. } => not_supported!("box pattern"),
Pat::ConstBlock(_) => not_supported!("const block pattern"),
})
}
fn pattern_matching_variant(
&mut self,
mut cond_ty: Ty,
mut binding_mode: BindingAnnotation,
mut cond_place: Place,
variant: VariantId,
current: BasicBlockId,
span: MirSpan,
current_else: Option<BasicBlockId>,
shape: AdtPatternShape<'_>,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place);
let subst = match cond_ty.kind(Interner) {
TyKind::Adt(_, s) => s,
_ => return Err(MirLowerError::TypeError("non adt type matched with tuple struct")),
};
Ok(match variant {
VariantId::EnumVariantId(v) => {
let e = self.db.const_eval_discriminant(v)? as u128;
let next = self.new_basic_block();
let tmp = self.discr_temp_place();
self.push_assignment(
current,
tmp.clone(),
Rvalue::Discriminant(cond_place.clone()),
span,
);
let else_target = current_else.unwrap_or_else(|| self.new_basic_block());
self.set_terminator(
current,
Terminator::SwitchInt {
discr: Operand::Copy(tmp),
targets: SwitchTargets::static_if(e, next, else_target),
},
);
let enum_data = self.db.enum_data(v.parent);
self.pattern_matching_variant_fields(
shape,
&enum_data.variants[v.local_id].variant_data,
variant,
subst,
next,
Some(else_target),
&cond_place,
binding_mode,
)?
}
VariantId::StructId(s) => {
let struct_data = self.db.struct_data(s);
self.pattern_matching_variant_fields(
shape,
&struct_data.variant_data,
variant,
subst,
current,
current_else,
&cond_place,
binding_mode,
)?
}
VariantId::UnionId(_) => {
return Err(MirLowerError::TypeError("pattern matching on union"))
}
})
}
fn pattern_matching_variant_fields(
&mut self,
shape: AdtPatternShape<'_>,
variant_data: &VariantData,
v: VariantId,
subst: &Substitution,
current: BasicBlockId,
current_else: Option<BasicBlockId>,
cond_place: &Place,
binding_mode: BindingAnnotation,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
let fields_type = self.db.field_types(v);
Ok(match shape {
AdtPatternShape::Record { args } => {
let it = args
.iter()
.map(|x| {
let field_id =
variant_data.field(&x.name).ok_or(MirLowerError::UnresolvedField)?;
Ok((
PlaceElem::Field(FieldId { parent: v.into(), local_id: field_id }),
x.pat,
fields_type[field_id].clone().substitute(Interner, subst),
))
})
.collect::<Result<Vec<_>>>()?;
self.pattern_match_adt(
current,
current_else,
it.into_iter(),
cond_place,
binding_mode,
)?
}
AdtPatternShape::Tuple { args, ellipsis } => {
let fields = variant_data.fields().iter().map(|(x, _)| {
(
PlaceElem::Field(FieldId { parent: v.into(), local_id: x }),
fields_type[x].clone().substitute(Interner, subst),
)
});
self.pattern_match_tuple_like(
current,
current_else,
args,
ellipsis,
fields,
cond_place,
binding_mode,
)?
}
AdtPatternShape::Unit => (current, current_else),
})
}
fn pattern_match_adt(
&mut self,
mut current: BasicBlockId,
mut current_else: Option<BasicBlockId>,
args: impl Iterator<Item = (PlaceElem, PatId, Ty)>,
cond_place: &Place,
binding_mode: BindingAnnotation,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
for (proj, arg, ty) in args {
let mut cond_place = cond_place.clone();
cond_place.projection.push(proj);
(current, current_else) =
self.pattern_match(current, current_else, cond_place, ty, arg, binding_mode)?;
}
Ok((current, current_else))
}
fn pattern_match_tuple_like(
&mut self,
current: BasicBlockId,
current_else: Option<BasicBlockId>,
args: &[PatId],
ellipsis: Option<usize>,
fields: impl DoubleEndedIterator<Item = (PlaceElem, Ty)> + Clone,
cond_place: &Place,
binding_mode: BindingAnnotation,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len()));
let it = al
.iter()
.zip(fields.clone())
.chain(ar.iter().rev().zip(fields.rev()))
.map(|(x, y)| (y.0, *x, y.1));
self.pattern_match_adt(current, current_else, it, cond_place, binding_mode)
}
fn discr_temp_place(&mut self) -> Place {
match &self.discr_temp {
Some(x) => x.clone(),
@ -1546,22 +1205,6 @@ impl MirLowerCtx<'_> {
}
}
fn pattern_matching_dereference(
cond_ty: &mut Ty,
binding_mode: &mut BindingAnnotation,
cond_place: &mut Place,
) {
while let Some((ty, _, mu)) = cond_ty.as_reference() {
if mu == Mutability::Mut && *binding_mode != BindingAnnotation::Ref {
*binding_mode = BindingAnnotation::RefMut;
} else {
*binding_mode = BindingAnnotation::Ref;
}
*cond_ty = ty.clone();
cond_place.projection.push(ProjectionElem::Deref);
}
}
fn cast_kind(source_ty: &Ty, target_ty: &Ty) -> Result<CastKind> {
Ok(match (source_ty.kind(Interner), target_ty.kind(Interner)) {
(TyKind::Scalar(s), TyKind::Scalar(t)) => match (s, t) {

View file

@ -0,0 +1,399 @@
//! MIR lowering for patterns
use super::*;
macro_rules! not_supported {
($x: expr) => {
return Err(MirLowerError::NotSupported(format!($x)))
};
}
pub(super) enum AdtPatternShape<'a> {
Tuple { args: &'a [PatId], ellipsis: Option<usize> },
Record { args: &'a [RecordFieldPat] },
Unit,
}
impl MirLowerCtx<'_> {
/// It gets a `current` unterminated block, appends some statements and possibly a terminator to it to check if
/// the pattern matches and write bindings, and returns two unterminated blocks, one for the matched path (which
/// can be the `current` block) and one for the mismatched path. If the input pattern is irrefutable, the
/// mismatched path block is `None`.
///
/// By default, it will create a new block for mismatched path. If you already have one, you can provide it with
/// `current_else` argument to save an unneccessary jump. If `current_else` isn't `None`, the result mismatched path
/// wouldn't be `None` as well. Note that this function will add jumps to the beginning of the `current_else` block,
/// so it should be an empty block.
pub(super) fn pattern_match(
&mut self,
mut current: BasicBlockId,
mut current_else: Option<BasicBlockId>,
mut cond_place: Place,
mut cond_ty: Ty,
pattern: PatId,
mut binding_mode: BindingAnnotation,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
Ok(match &self.body.pats[pattern] {
Pat::Missing => return Err(MirLowerError::IncompleteExpr),
Pat::Wild => (current, current_else),
Pat::Tuple { args, ellipsis } => {
pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place);
let subst = match cond_ty.kind(Interner) {
TyKind::Tuple(_, s) => s,
_ => {
return Err(MirLowerError::TypeError(
"non tuple type matched with tuple pattern",
))
}
};
self.pattern_match_tuple_like(
current,
current_else,
args,
*ellipsis,
subst.iter(Interner).enumerate().map(|(i, x)| {
(PlaceElem::TupleField(i), x.assert_ty_ref(Interner).clone())
}),
&cond_place,
binding_mode,
)?
}
Pat::Or(pats) => {
let then_target = self.new_basic_block();
let mut finished = false;
for pat in &**pats {
let (next, next_else) = self.pattern_match(
current,
None,
cond_place.clone(),
cond_ty.clone(),
*pat,
binding_mode,
)?;
self.set_goto(next, then_target);
match next_else {
Some(t) => {
current = t;
}
None => {
finished = true;
break;
}
}
}
if !finished {
let ce = *current_else.get_or_insert_with(|| self.new_basic_block());
self.set_goto(current, ce);
}
(then_target, current_else)
}
Pat::Record { args, .. } => {
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
not_supported!("unresolved variant");
};
self.pattern_matching_variant(
cond_ty,
binding_mode,
cond_place,
variant,
current,
pattern.into(),
current_else,
AdtPatternShape::Record { args: &*args },
)?
}
Pat::Range { .. } => not_supported!("range pattern"),
Pat::Slice { .. } => not_supported!("slice pattern"),
Pat::Path(_) => {
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
not_supported!("unresolved variant");
};
self.pattern_matching_variant(
cond_ty,
binding_mode,
cond_place,
variant,
current,
pattern.into(),
current_else,
AdtPatternShape::Unit,
)?
}
Pat::Lit(l) => match &self.body.exprs[*l] {
Expr::Literal(l) => {
let c = self.lower_literal_to_operand(cond_ty, l)?;
self.pattern_match_const(current_else, current, c, cond_place, pattern)?
}
_ => not_supported!("expression path literal"),
},
Pat::Bind { id, subpat } => {
let target_place = self.result.binding_locals[*id];
let mode = self.body.bindings[*id].mode;
if let Some(subpat) = subpat {
(current, current_else) = self.pattern_match(
current,
current_else,
cond_place.clone(),
cond_ty,
*subpat,
binding_mode,
)?
}
if matches!(mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) {
binding_mode = mode;
}
self.push_storage_live(*id, current);
self.push_assignment(
current,
target_place.into(),
match binding_mode {
BindingAnnotation::Unannotated | BindingAnnotation::Mutable => {
Operand::Copy(cond_place).into()
}
BindingAnnotation::Ref => Rvalue::Ref(BorrowKind::Shared, cond_place),
BindingAnnotation::RefMut => Rvalue::Ref(
BorrowKind::Mut { allow_two_phase_borrow: false },
cond_place,
),
},
pattern.into(),
);
(current, current_else)
}
Pat::TupleStruct { path: _, args, ellipsis } => {
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
not_supported!("unresolved variant");
};
self.pattern_matching_variant(
cond_ty,
binding_mode,
cond_place,
variant,
current,
pattern.into(),
current_else,
AdtPatternShape::Tuple { args, ellipsis: *ellipsis },
)?
}
Pat::Ref { pat, mutability: _ } => {
if let Some((ty, _, _)) = cond_ty.as_reference() {
cond_ty = ty.clone();
cond_place.projection.push(ProjectionElem::Deref);
self.pattern_match(
current,
current_else,
cond_place,
cond_ty,
*pat,
binding_mode,
)?
} else {
return Err(MirLowerError::TypeError("& pattern for non reference"));
}
}
Pat::Box { .. } => not_supported!("box pattern"),
Pat::ConstBlock(_) => not_supported!("const block pattern"),
})
}
fn pattern_match_const(
&mut self,
current_else: Option<BasicBlockId>,
current: BasicBlockId,
c: Operand,
cond_place: Place,
pattern: Idx<Pat>,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
let then_target = self.new_basic_block();
let else_target = current_else.unwrap_or_else(|| self.new_basic_block());
let discr: Place = self.temp(TyBuilder::bool())?.into();
self.push_assignment(
current,
discr.clone(),
Rvalue::CheckedBinaryOp(BinOp::Eq, c, Operand::Copy(cond_place)),
pattern.into(),
);
let discr = Operand::Copy(discr);
self.set_terminator(
current,
Terminator::SwitchInt {
discr,
targets: SwitchTargets::static_if(1, then_target, else_target),
},
);
Ok((then_target, Some(else_target)))
}
pub(super) fn pattern_matching_variant(
&mut self,
mut cond_ty: Ty,
mut binding_mode: BindingAnnotation,
mut cond_place: Place,
variant: VariantId,
current: BasicBlockId,
span: MirSpan,
current_else: Option<BasicBlockId>,
shape: AdtPatternShape<'_>,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place);
let subst = match cond_ty.kind(Interner) {
TyKind::Adt(_, s) => s,
_ => return Err(MirLowerError::TypeError("non adt type matched with tuple struct")),
};
Ok(match variant {
VariantId::EnumVariantId(v) => {
let e = self.db.const_eval_discriminant(v)? as u128;
let next = self.new_basic_block();
let tmp = self.discr_temp_place();
self.push_assignment(
current,
tmp.clone(),
Rvalue::Discriminant(cond_place.clone()),
span,
);
let else_target = current_else.unwrap_or_else(|| self.new_basic_block());
self.set_terminator(
current,
Terminator::SwitchInt {
discr: Operand::Copy(tmp),
targets: SwitchTargets::static_if(e, next, else_target),
},
);
let enum_data = self.db.enum_data(v.parent);
self.pattern_matching_variant_fields(
shape,
&enum_data.variants[v.local_id].variant_data,
variant,
subst,
next,
Some(else_target),
&cond_place,
binding_mode,
)?
}
VariantId::StructId(s) => {
let struct_data = self.db.struct_data(s);
self.pattern_matching_variant_fields(
shape,
&struct_data.variant_data,
variant,
subst,
current,
current_else,
&cond_place,
binding_mode,
)?
}
VariantId::UnionId(_) => {
return Err(MirLowerError::TypeError("pattern matching on union"))
}
})
}
fn pattern_matching_variant_fields(
&mut self,
shape: AdtPatternShape<'_>,
variant_data: &VariantData,
v: VariantId,
subst: &Substitution,
current: BasicBlockId,
current_else: Option<BasicBlockId>,
cond_place: &Place,
binding_mode: BindingAnnotation,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
let fields_type = self.db.field_types(v);
Ok(match shape {
AdtPatternShape::Record { args } => {
let it = args
.iter()
.map(|x| {
let field_id =
variant_data.field(&x.name).ok_or(MirLowerError::UnresolvedField)?;
Ok((
PlaceElem::Field(FieldId { parent: v.into(), local_id: field_id }),
x.pat,
fields_type[field_id].clone().substitute(Interner, subst),
))
})
.collect::<Result<Vec<_>>>()?;
self.pattern_match_adt(
current,
current_else,
it.into_iter(),
cond_place,
binding_mode,
)?
}
AdtPatternShape::Tuple { args, ellipsis } => {
let fields = variant_data.fields().iter().map(|(x, _)| {
(
PlaceElem::Field(FieldId { parent: v.into(), local_id: x }),
fields_type[x].clone().substitute(Interner, subst),
)
});
self.pattern_match_tuple_like(
current,
current_else,
args,
ellipsis,
fields,
cond_place,
binding_mode,
)?
}
AdtPatternShape::Unit => (current, current_else),
})
}
fn pattern_match_adt(
&mut self,
mut current: BasicBlockId,
mut current_else: Option<BasicBlockId>,
args: impl Iterator<Item = (PlaceElem, PatId, Ty)>,
cond_place: &Place,
binding_mode: BindingAnnotation,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
for (proj, arg, ty) in args {
let mut cond_place = cond_place.clone();
cond_place.projection.push(proj);
(current, current_else) =
self.pattern_match(current, current_else, cond_place, ty, arg, binding_mode)?;
}
Ok((current, current_else))
}
fn pattern_match_tuple_like(
&mut self,
current: BasicBlockId,
current_else: Option<BasicBlockId>,
args: &[PatId],
ellipsis: Option<usize>,
fields: impl DoubleEndedIterator<Item = (PlaceElem, Ty)> + Clone,
cond_place: &Place,
binding_mode: BindingAnnotation,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len()));
let it = al
.iter()
.zip(fields.clone())
.chain(ar.iter().rev().zip(fields.rev()))
.map(|(x, y)| (y.0, *x, y.1));
self.pattern_match_adt(current, current_else, it, cond_place, binding_mode)
}
}
fn pattern_matching_dereference(
cond_ty: &mut Ty,
binding_mode: &mut BindingAnnotation,
cond_place: &mut Place,
) {
while let Some((ty, _, mu)) = cond_ty.as_reference() {
if mu == Mutability::Mut && *binding_mode != BindingAnnotation::Ref {
*binding_mode = BindingAnnotation::RefMut;
} else {
*binding_mode = BindingAnnotation::Ref;
}
*cond_ty = ty.clone();
cond_place.projection.push(ProjectionElem::Deref);
}
}

View file

@ -1,6 +1,6 @@
//! A pretty-printer for MIR.
use std::fmt::{Display, Write};
use std::fmt::{Display, Write, Debug};
use hir_def::{body::Body, expr::BindingId};
use hir_expand::name::Name;
@ -23,6 +23,18 @@ impl MirBody {
ctx.for_body();
ctx.result
}
// String with lines is rendered poorly in `dbg!` macros, which I use very much, so this
// function exists to solve that.
pub fn dbg(&self, db: &dyn HirDatabase) -> impl Debug {
struct StringDbg(String);
impl Debug for StringDbg {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.0)
}
}
StringDbg(self.pretty_print(db))
}
}
struct MirPrettyCtx<'a> {

View file

@ -1376,6 +1376,7 @@ pub struct LiteralPat {
}
impl LiteralPat {
pub fn literal(&self) -> Option<Literal> { support::child(&self.syntax) }
pub fn minus_token(&self) -> Option<SyntaxToken> { support::token(&self.syntax, T![-]) }
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]

View file

@ -597,7 +597,10 @@ pub mod option {
loop {}
}
pub fn unwrap_or(self, default: T) -> T {
loop {}
match self {
Some(val) => val,
None => default,
}
}
// region:fn
pub fn and_then<U, F>(self, f: F) -> Option<U>