Auto merge of #14955 - HKalbasi:mir-fix, r=HKalbasi

Remove unnecessary `StorageDead`

I hope this reduces MIR memory usage.
This commit is contained in:
bors 2023-06-03 13:55:40 +00:00
commit e5c56cd9a0
4 changed files with 238 additions and 301 deletions

View file

@ -9,10 +9,7 @@ use chalk_ir::{
}; };
use hir_def::{ use hir_def::{
data::adt::VariantData, data::adt::VariantData,
hir::{ hir::{Array, BinaryOp, BindingId, CaptureBy, Expr, ExprId, Pat, PatId, Statement, UnaryOp},
Array, BinaryOp, BindingAnnotation, BindingId, CaptureBy, Expr, ExprId, Pat, PatId,
Statement, UnaryOp,
},
lang_item::LangItem, lang_item::LangItem,
resolver::{resolver_for_expr, ResolveValueResult, ValueNs}, resolver::{resolver_for_expr, ResolveValueResult, ValueNs},
DefWithBodyId, FieldId, HasModule, VariantId, DefWithBodyId, FieldId, HasModule, VariantId,
@ -28,9 +25,9 @@ use crate::{
mir::{BorrowKind, MirSpan, ProjectionElem}, mir::{BorrowKind, MirSpan, ProjectionElem},
static_lifetime, to_chalk_trait_id, static_lifetime, to_chalk_trait_id,
traits::FnTrait, traits::FnTrait,
utils::{self, generics, pattern_matching_dereference_count, Generics}, utils::{self, generics, Generics},
Adjust, Adjustment, Binders, ChalkTraitId, ClosureId, ConstValue, DynTy, FnPointer, FnSig, Adjust, Adjustment, Binders, BindingMode, ChalkTraitId, ClosureId, ConstValue, DynTy,
Interner, Substitution, Ty, TyExt, FnPointer, FnSig, Interner, Substitution, Ty, TyExt,
}; };
use super::{Expectation, InferenceContext}; use super::{Expectation, InferenceContext};
@ -488,13 +485,7 @@ impl InferenceContext<'_> {
if let Some(initializer) = initializer { if let Some(initializer) = initializer {
self.walk_expr(*initializer); self.walk_expr(*initializer);
if let Some(place) = self.place_of_expr(*initializer) { if let Some(place) = self.place_of_expr(*initializer) {
let ty = self.expr_ty(*initializer); self.consume_with_pat(place, *pat);
self.consume_with_pat(
place,
ty,
BindingAnnotation::Unannotated,
*pat,
);
} }
} }
} }
@ -799,41 +790,37 @@ impl InferenceContext<'_> {
} }
} }
fn consume_with_pat( fn consume_with_pat(&mut self, mut place: HirPlace, pat: PatId) {
&mut self, let cnt = self.result.pat_adjustments.get(&pat).map(|x| x.len()).unwrap_or_default();
mut place: HirPlace, place.projections = place
mut ty: Ty, .projections
mut bm: BindingAnnotation, .iter()
pat: PatId, .cloned()
) { .chain((0..cnt).map(|_| ProjectionElem::Deref))
.collect::<Vec<_>>()
.into();
match &self.body[pat] { match &self.body[pat] {
Pat::Missing | Pat::Wild => (), Pat::Missing | Pat::Wild => (),
Pat::Tuple { args, ellipsis } => { Pat::Tuple { args, ellipsis } => {
pattern_matching_dereference(&mut ty, &mut bm, &mut place);
let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len())); let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len()));
let subst = match ty.kind(Interner) { let field_count = match self.result[pat].kind(Interner) {
TyKind::Tuple(_, s) => s, TyKind::Tuple(_, s) => s.len(Interner),
_ => return, _ => return,
}; };
let fields = subst.iter(Interner).map(|x| x.assert_ty_ref(Interner)).enumerate(); let fields = 0..field_count;
let it = al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev())); let it = al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev()));
for (arg, (i, ty)) in it { for (arg, i) in it {
let mut p = place.clone(); let mut p = place.clone();
p.projections.push(ProjectionElem::TupleOrClosureField(i)); p.projections.push(ProjectionElem::TupleOrClosureField(i));
self.consume_with_pat(p, ty.clone(), bm, *arg); self.consume_with_pat(p, *arg);
} }
} }
Pat::Or(pats) => { Pat::Or(pats) => {
for pat in pats.iter() { for pat in pats.iter() {
self.consume_with_pat(place.clone(), ty.clone(), bm, *pat); self.consume_with_pat(place.clone(), *pat);
} }
} }
Pat::Record { args, .. } => { Pat::Record { args, .. } => {
pattern_matching_dereference(&mut ty, &mut bm, &mut place);
let subst = match ty.kind(Interner) {
TyKind::Adt(_, s) => s,
_ => return,
};
let Some(variant) = self.result.variant_resolution_for_pat(pat) else { let Some(variant) = self.result.variant_resolution_for_pat(pat) else {
return; return;
}; };
@ -843,7 +830,6 @@ impl InferenceContext<'_> {
} }
VariantId::StructId(s) => { VariantId::StructId(s) => {
let vd = &*self.db.struct_data(s).variant_data; let vd = &*self.db.struct_data(s).variant_data;
let field_types = self.db.field_types(variant);
for field_pat in args.iter() { for field_pat in args.iter() {
let arg = field_pat.pat; let arg = field_pat.pat;
let Some(local_id) = vd.field(&field_pat.name) else { let Some(local_id) = vd.field(&field_pat.name) else {
@ -854,12 +840,7 @@ impl InferenceContext<'_> {
parent: variant.into(), parent: variant.into(),
local_id, local_id,
})); }));
self.consume_with_pat( self.consume_with_pat(p, arg);
p,
field_types[local_id].clone().substitute(Interner, subst),
bm,
arg,
);
} }
} }
} }
@ -870,26 +851,20 @@ impl InferenceContext<'_> {
| Pat::Path(_) | Pat::Path(_)
| Pat::Lit(_) => self.consume_place(place, pat.into()), | Pat::Lit(_) => self.consume_place(place, pat.into()),
Pat::Bind { id, subpat: _ } => { Pat::Bind { id, subpat: _ } => {
let mode = self.body.bindings[*id].mode; let mode = self.result.binding_modes[*id];
if matches!(mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) { let capture_kind = match mode {
bm = mode; BindingMode::Move => {
}
let capture_kind = match bm {
BindingAnnotation::Unannotated | BindingAnnotation::Mutable => {
self.consume_place(place, pat.into()); self.consume_place(place, pat.into());
return; return;
} }
BindingAnnotation::Ref => BorrowKind::Shared, BindingMode::Ref(Mutability::Not) => BorrowKind::Shared,
BindingAnnotation::RefMut => BorrowKind::Mut { allow_two_phase_borrow: false }, BindingMode::Ref(Mutability::Mut) => {
BorrowKind::Mut { allow_two_phase_borrow: false }
}
}; };
self.add_capture(place, CaptureKind::ByRef(capture_kind), pat.into()); self.add_capture(place, CaptureKind::ByRef(capture_kind), pat.into());
} }
Pat::TupleStruct { path: _, args, ellipsis } => { Pat::TupleStruct { path: _, args, ellipsis } => {
pattern_matching_dereference(&mut ty, &mut bm, &mut place);
let subst = match ty.kind(Interner) {
TyKind::Adt(_, s) => s,
_ => return,
};
let Some(variant) = self.result.variant_resolution_for_pat(pat) else { let Some(variant) = self.result.variant_resolution_for_pat(pat) else {
return; return;
}; };
@ -903,29 +878,20 @@ impl InferenceContext<'_> {
let fields = vd.fields().iter(); let fields = vd.fields().iter();
let it = let it =
al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev())); al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev()));
let field_types = self.db.field_types(variant);
for (arg, (i, _)) in it { for (arg, (i, _)) in it {
let mut p = place.clone(); let mut p = place.clone();
p.projections.push(ProjectionElem::Field(FieldId { p.projections.push(ProjectionElem::Field(FieldId {
parent: variant.into(), parent: variant.into(),
local_id: i, local_id: i,
})); }));
self.consume_with_pat( self.consume_with_pat(p, *arg);
p,
field_types[i].clone().substitute(Interner, subst),
bm,
*arg,
);
} }
} }
} }
} }
Pat::Ref { pat, mutability: _ } => { Pat::Ref { pat, mutability: _ } => {
if let Some((inner, _, _)) = ty.as_reference() {
ty = inner.clone();
place.projections.push(ProjectionElem::Deref); place.projections.push(ProjectionElem::Deref);
self.consume_with_pat(place, ty, bm, *pat) self.consume_with_pat(place, *pat)
}
} }
Pat::Box { .. } => (), // not supported Pat::Box { .. } => (), // not supported
} }
@ -1054,12 +1020,3 @@ fn apply_adjusts_to_place(mut r: HirPlace, adjustments: &[Adjustment]) -> Option
} }
Some(r) Some(r)
} }
fn pattern_matching_dereference(
cond_ty: &mut Ty,
binding_mode: &mut BindingAnnotation,
cond_place: &mut HirPlace,
) {
let cnt = pattern_matching_dereference_count(cond_ty, binding_mode);
cond_place.projections.extend((0..cnt).map(|_| ProjectionElem::Deref));
}

View file

@ -478,9 +478,7 @@ impl<'ctx> MirLowerCtx<'ctx> {
current, current,
None, None,
cond_place, cond_place,
self.expr_ty_after_adjustments(*expr),
*pat, *pat,
BindingAnnotation::Unannotated,
)?; )?;
self.write_bytes_to_place( self.write_bytes_to_place(
then_target, then_target,
@ -598,16 +596,13 @@ impl<'ctx> MirLowerCtx<'ctx> {
else { else {
return Ok(None); return Ok(None);
}; };
let cond_ty = self.expr_ty_after_adjustments(*expr);
let mut end = None; let mut end = None;
for MatchArm { pat, guard, expr } in arms.iter() { for MatchArm { pat, guard, expr } in arms.iter() {
let (then, mut otherwise) = self.pattern_match( let (then, mut otherwise) = self.pattern_match(
current, current,
None, None,
cond_place.clone(), cond_place.clone(),
cond_ty.clone(),
*pat, *pat,
BindingAnnotation::Unannotated,
)?; )?;
let then = if let &Some(guard) = guard { let then = if let &Some(guard) = guard {
let next = self.new_basic_block(); let next = self.new_basic_block();
@ -1477,9 +1472,6 @@ impl<'ctx> MirLowerCtx<'ctx> {
span: MirSpan, span: MirSpan,
) -> Result<()> { ) -> Result<()> {
self.drop_scopes.last_mut().unwrap().locals.push(l); self.drop_scopes.last_mut().unwrap().locals.push(l);
// FIXME: this storage dead is not neccessary, but since drop scope handling is broken, we need
// it to avoid falso positives in mutability errors
self.push_statement(current, StatementKind::StorageDead(l).with_span(span));
self.push_statement(current, StatementKind::StorageLive(l).with_span(span)); self.push_statement(current, StatementKind::StorageLive(l).with_span(span));
Ok(()) Ok(())
} }
@ -1508,14 +1500,8 @@ impl<'ctx> MirLowerCtx<'ctx> {
return Ok(None); return Ok(None);
}; };
current = c; current = c;
(current, else_block) = self.pattern_match( (current, else_block) =
current, self.pattern_match(current, None, init_place, *pat)?;
None,
init_place,
self.expr_ty_after_adjustments(*expr_id),
*pat,
BindingAnnotation::Unannotated,
)?;
match (else_block, else_branch) { match (else_block, else_branch) {
(None, _) => (), (None, _) => (),
(Some(else_block), None) => { (Some(else_block), None) => {
@ -1595,14 +1581,7 @@ impl<'ctx> MirLowerCtx<'ctx> {
continue; continue;
} }
} }
let r = self.pattern_match( let r = self.pattern_match(current, None, local.into(), param)?;
current,
None,
local.into(),
self.result.locals[local].ty.clone(),
param,
BindingAnnotation::Unannotated,
)?;
if let Some(b) = r.1 { if let Some(b) = r.1 {
self.set_terminator(b, TerminatorKind::Unreachable, param.into()); self.set_terminator(b, TerminatorKind::Unreachable, param.into());
} }

View file

@ -2,7 +2,7 @@
use hir_def::{hir::LiteralOrConst, resolver::HasResolver, AssocItemId}; use hir_def::{hir::LiteralOrConst, resolver::HasResolver, AssocItemId};
use crate::utils::pattern_matching_dereference_count; use crate::BindingMode;
use super::*; use super::*;
@ -18,6 +18,26 @@ pub(super) enum AdtPatternShape<'a> {
Unit, Unit,
} }
/// We need to do pattern matching in two phases: One to check if the pattern matches, and one to fill the bindings
/// of patterns. This is necessary to prevent double moves and similar problems. For example:
/// ```ignore
/// struct X;
/// match (X, 3) {
/// (b, 2) | (b, 3) => {},
/// _ => {}
/// }
/// ```
/// If we do everything in one pass, we will move `X` to the first `b`, then we see that the second field of tuple
/// doesn't match and we should move the `X` to the second `b` (which here is the same thing, but doesn't need to be) and
/// it might even doesn't match the second pattern and we may want to not move `X` at all.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MatchingMode {
/// Check that if this pattern matches
Check,
/// Assume that this pattern matches, fill bindings
Bind,
}
impl MirLowerCtx<'_> { impl MirLowerCtx<'_> {
/// It gets a `current` unterminated block, appends some statements and possibly a terminator to it to check if /// It gets a `current` unterminated block, appends some statements and possibly a terminator to it to check if
/// the pattern matches and write bindings, and returns two unterminated blocks, one for the matched path (which /// the pattern matches and write bindings, and returns two unterminated blocks, one for the matched path (which
@ -29,20 +49,50 @@ impl MirLowerCtx<'_> {
/// wouldn't be `None` as well. Note that this function will add jumps to the beginning of the `current_else` block, /// wouldn't be `None` as well. Note that this function will add jumps to the beginning of the `current_else` block,
/// so it should be an empty block. /// so it should be an empty block.
pub(super) fn pattern_match( pub(super) fn pattern_match(
&mut self,
current: BasicBlockId,
current_else: Option<BasicBlockId>,
cond_place: Place,
pattern: PatId,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
let (current, current_else) = self.pattern_match_inner(
current,
current_else,
cond_place.clone(),
pattern,
MatchingMode::Check,
)?;
let (current, current_else) = self.pattern_match_inner(
current,
current_else,
cond_place,
pattern,
MatchingMode::Bind,
)?;
Ok((current, current_else))
}
fn pattern_match_inner(
&mut self, &mut self,
mut current: BasicBlockId, mut current: BasicBlockId,
mut current_else: Option<BasicBlockId>, mut current_else: Option<BasicBlockId>,
mut cond_place: Place, mut cond_place: Place,
mut cond_ty: Ty,
pattern: PatId, pattern: PatId,
mut binding_mode: BindingAnnotation, mode: MatchingMode,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> { ) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
let cnt = self.infer.pat_adjustments.get(&pattern).map(|x| x.len()).unwrap_or_default();
cond_place.projection = cond_place
.projection
.iter()
.cloned()
.chain((0..cnt).map(|_| ProjectionElem::Deref))
.collect::<Vec<_>>()
.into();
Ok(match &self.body.pats[pattern] { Ok(match &self.body.pats[pattern] {
Pat::Missing => return Err(MirLowerError::IncompletePattern), Pat::Missing => return Err(MirLowerError::IncompletePattern),
Pat::Wild => (current, current_else), Pat::Wild => (current, current_else),
Pat::Tuple { args, ellipsis } => { Pat::Tuple { args, ellipsis } => {
pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place); let subst = match self.infer[pattern].kind(Interner) {
let subst = match cond_ty.kind(Interner) {
TyKind::Tuple(_, s) => s, TyKind::Tuple(_, s) => s,
_ => { _ => {
return Err(MirLowerError::TypeError( return Err(MirLowerError::TypeError(
@ -55,25 +105,31 @@ impl MirLowerCtx<'_> {
current_else, current_else,
args, args,
*ellipsis, *ellipsis,
subst.iter(Interner).enumerate().map(|(i, x)| { (0..subst.len(Interner)).map(|i| PlaceElem::TupleOrClosureField(i)),
(PlaceElem::TupleOrClosureField(i), x.assert_ty_ref(Interner).clone()) &(&mut cond_place),
}), mode,
&cond_place,
binding_mode,
)? )?
} }
Pat::Or(pats) => { Pat::Or(pats) => {
let then_target = self.new_basic_block(); let then_target = self.new_basic_block();
let mut finished = false; let mut finished = false;
for pat in &**pats { for pat in &**pats {
let (next, next_else) = self.pattern_match( let (mut next, next_else) = self.pattern_match_inner(
current, current,
None, None,
cond_place.clone(), (&mut cond_place).clone(),
cond_ty.clone(),
*pat, *pat,
binding_mode, MatchingMode::Check,
)?; )?;
if mode == MatchingMode::Bind {
(next, _) = self.pattern_match_inner(
next,
None,
(&mut cond_place).clone(),
*pat,
MatchingMode::Bind,
)?;
}
self.set_goto(next, then_target, pattern.into()); self.set_goto(next, then_target, pattern.into());
match next_else { match next_else {
Some(t) => { Some(t) => {
@ -86,9 +142,13 @@ impl MirLowerCtx<'_> {
} }
} }
if !finished { if !finished {
if mode == MatchingMode::Bind {
self.set_terminator(current, TerminatorKind::Unreachable, pattern.into());
} else {
let ce = *current_else.get_or_insert_with(|| self.new_basic_block()); let ce = *current_else.get_or_insert_with(|| self.new_basic_block());
self.set_goto(current, ce, pattern.into()); self.set_goto(current, ce, pattern.into());
} }
}
(then_target, current_else) (then_target, current_else)
} }
Pat::Record { args, .. } => { Pat::Record { args, .. } => {
@ -96,19 +156,19 @@ impl MirLowerCtx<'_> {
not_supported!("unresolved variant for record"); not_supported!("unresolved variant for record");
}; };
self.pattern_matching_variant( self.pattern_matching_variant(
cond_ty,
binding_mode,
cond_place, cond_place,
variant, variant,
current, current,
pattern.into(), pattern.into(),
current_else, current_else,
AdtPatternShape::Record { args: &*args }, AdtPatternShape::Record { args: &*args },
mode,
)? )?
} }
Pat::Range { start, end } => { Pat::Range { start, end } => {
let mut add_check = |l: &LiteralOrConst, binop| -> Result<()> { let mut add_check = |l: &LiteralOrConst, binop| -> Result<()> {
let lv = self.lower_literal_or_const_to_operand(cond_ty.clone(), l)?; let lv =
self.lower_literal_or_const_to_operand(self.infer[pattern].clone(), l)?;
let else_target = *current_else.get_or_insert_with(|| self.new_basic_block()); let else_target = *current_else.get_or_insert_with(|| self.new_basic_block());
let next = self.new_basic_block(); let next = self.new_basic_block();
let discr: Place = let discr: Place =
@ -116,7 +176,11 @@ impl MirLowerCtx<'_> {
self.push_assignment( self.push_assignment(
current, current,
discr.clone(), discr.clone(),
Rvalue::CheckedBinaryOp(binop, lv, Operand::Copy(cond_place.clone())), Rvalue::CheckedBinaryOp(
binop,
lv,
Operand::Copy((&mut cond_place).clone()),
),
pattern.into(), pattern.into(),
); );
let discr = Operand::Copy(discr); let discr = Operand::Copy(discr);
@ -131,24 +195,25 @@ impl MirLowerCtx<'_> {
current = next; current = next;
Ok(()) Ok(())
}; };
if mode == MatchingMode::Check {
if let Some(start) = start { if let Some(start) = start {
add_check(start, BinOp::Le)?; add_check(start, BinOp::Le)?;
} }
if let Some(end) = end { if let Some(end) = end {
add_check(end, BinOp::Ge)?; add_check(end, BinOp::Ge)?;
} }
}
(current, current_else) (current, current_else)
} }
Pat::Slice { prefix, slice, suffix } => { Pat::Slice { prefix, slice, suffix } => {
pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place); if let TyKind::Slice(_) = self.infer[pattern].kind(Interner) {
if let TyKind::Slice(_) = cond_ty.kind(Interner) {
let pattern_len = prefix.len() + suffix.len(); let pattern_len = prefix.len() + suffix.len();
let place_len: Place = let place_len: Place =
self.temp(TyBuilder::usize(), current, pattern.into())?.into(); self.temp(TyBuilder::usize(), current, pattern.into())?.into();
self.push_assignment( self.push_assignment(
current, current,
place_len.clone(), place_len.clone(),
Rvalue::Len(cond_place.clone()), Rvalue::Len((&mut cond_place).clone()),
pattern.into(), pattern.into(),
); );
let else_target = *current_else.get_or_insert_with(|| self.new_basic_block()); let else_target = *current_else.get_or_insert_with(|| self.new_basic_block());
@ -193,29 +258,22 @@ impl MirLowerCtx<'_> {
current = next; current = next;
} }
for (i, &pat) in prefix.iter().enumerate() { for (i, &pat) in prefix.iter().enumerate() {
let next_place = cond_place.project(ProjectionElem::ConstantIndex { let next_place = (&mut cond_place).project(ProjectionElem::ConstantIndex {
offset: i as u64, offset: i as u64,
from_end: false, from_end: false,
}); });
let cond_ty = self.infer[pat].clone(); (current, current_else) =
(current, current_else) = self.pattern_match( self.pattern_match_inner(current, current_else, next_place, pat, mode)?;
current,
current_else,
next_place,
cond_ty,
pat,
binding_mode,
)?;
} }
if let Some(slice) = slice { if let Some(slice) = slice {
if mode == MatchingMode::Bind {
if let Pat::Bind { id, subpat: _ } = self.body[*slice] { if let Pat::Bind { id, subpat: _ } = self.body[*slice] {
let next_place = cond_place.project(ProjectionElem::Subslice { let next_place = (&mut cond_place).project(ProjectionElem::Subslice {
from: prefix.len() as u64, from: prefix.len() as u64,
to: suffix.len() as u64, to: suffix.len() as u64,
}); });
(current, current_else) = self.pattern_match_binding( (current, current_else) = self.pattern_match_binding(
id, id,
&mut binding_mode,
next_place, next_place,
(*slice).into(), (*slice).into(),
current, current,
@ -223,33 +281,26 @@ impl MirLowerCtx<'_> {
)?; )?;
} }
} }
}
for (i, &pat) in suffix.iter().enumerate() { for (i, &pat) in suffix.iter().enumerate() {
let next_place = cond_place.project(ProjectionElem::ConstantIndex { let next_place = (&mut cond_place).project(ProjectionElem::ConstantIndex {
offset: i as u64, offset: i as u64,
from_end: true, from_end: true,
}); });
let cond_ty = self.infer[pat].clone(); (current, current_else) =
(current, current_else) = self.pattern_match( self.pattern_match_inner(current, current_else, next_place, pat, mode)?;
current,
current_else,
next_place,
cond_ty,
pat,
binding_mode,
)?;
} }
(current, current_else) (current, current_else)
} }
Pat::Path(p) => match self.infer.variant_resolution_for_pat(pattern) { Pat::Path(p) => match self.infer.variant_resolution_for_pat(pattern) {
Some(variant) => self.pattern_matching_variant( Some(variant) => self.pattern_matching_variant(
cond_ty,
binding_mode,
cond_place, cond_place,
variant, variant,
current, current,
pattern.into(), pattern.into(),
current_else, current_else,
AdtPatternShape::Unit, AdtPatternShape::Unit,
mode,
)?, )?,
None => { None => {
let unresolved_name = || MirLowerError::unresolved_path(self.db, p); let unresolved_name = || MirLowerError::unresolved_path(self.db, p);
@ -270,9 +321,17 @@ impl MirLowerCtx<'_> {
} }
not_supported!("path in pattern position that is not const or variant") not_supported!("path in pattern position that is not const or variant")
}; };
let tmp: Place = self.temp(cond_ty.clone(), current, pattern.into())?.into(); let tmp: Place =
self.temp(self.infer[pattern].clone(), current, pattern.into())?.into();
let span = pattern.into(); let span = pattern.into();
self.lower_const(c.into(), current, tmp.clone(), subst, span, cond_ty.clone())?; self.lower_const(
c.into(),
current,
tmp.clone(),
subst,
span,
self.infer[pattern].clone(),
)?;
let tmp2: Place = self.temp(TyBuilder::bool(), current, pattern.into())?.into(); let tmp2: Place = self.temp(TyBuilder::bool(), current, pattern.into())?.into();
self.push_assignment( self.push_assignment(
current, current,
@ -299,61 +358,58 @@ impl MirLowerCtx<'_> {
}, },
Pat::Lit(l) => match &self.body.exprs[*l] { Pat::Lit(l) => match &self.body.exprs[*l] {
Expr::Literal(l) => { Expr::Literal(l) => {
let c = self.lower_literal_to_operand(cond_ty, l)?; let c = self.lower_literal_to_operand(self.infer[pattern].clone(), l)?;
if mode == MatchingMode::Check {
self.pattern_match_const(current_else, current, c, cond_place, pattern)? self.pattern_match_const(current_else, current, c, cond_place, pattern)?
} else {
(current, current_else)
}
} }
_ => not_supported!("expression path literal"), _ => not_supported!("expression path literal"),
}, },
Pat::Bind { id, subpat } => { Pat::Bind { id, subpat } => {
if let Some(subpat) = subpat { if let Some(subpat) = subpat {
(current, current_else) = self.pattern_match( (current, current_else) = self.pattern_match_inner(
current, current,
current_else, current_else,
cond_place.clone(), (&mut cond_place).clone(),
cond_ty,
*subpat, *subpat,
binding_mode, mode,
)? )?
} }
if mode == MatchingMode::Bind {
self.pattern_match_binding( self.pattern_match_binding(
*id, *id,
&mut binding_mode,
cond_place, cond_place,
pattern.into(), pattern.into(),
current, current,
current_else, current_else,
)? )?
} else {
(current, current_else)
}
} }
Pat::TupleStruct { path: _, args, ellipsis } => { Pat::TupleStruct { path: _, args, ellipsis } => {
let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else { let Some(variant) = self.infer.variant_resolution_for_pat(pattern) else {
not_supported!("unresolved variant"); not_supported!("unresolved variant");
}; };
self.pattern_matching_variant( self.pattern_matching_variant(
cond_ty,
binding_mode,
cond_place, cond_place,
variant, variant,
current, current,
pattern.into(), pattern.into(),
current_else, current_else,
AdtPatternShape::Tuple { args, ellipsis: *ellipsis }, AdtPatternShape::Tuple { args, ellipsis: *ellipsis },
mode,
)? )?
} }
Pat::Ref { pat, mutability: _ } => { Pat::Ref { pat, mutability: _ } => self.pattern_match_inner(
if let Some((ty, _, _)) = cond_ty.as_reference() {
cond_ty = ty.clone();
self.pattern_match(
current, current,
current_else, current_else,
cond_place.project(ProjectionElem::Deref), cond_place.project(ProjectionElem::Deref),
cond_ty,
*pat, *pat,
binding_mode, mode,
)? )?,
} else {
return Err(MirLowerError::TypeError("& pattern for non reference"));
}
}
Pat::Box { .. } => not_supported!("box pattern"), Pat::Box { .. } => not_supported!("box pattern"),
Pat::ConstBlock(_) => not_supported!("const block pattern"), Pat::ConstBlock(_) => not_supported!("const block pattern"),
}) })
@ -362,27 +418,21 @@ impl MirLowerCtx<'_> {
fn pattern_match_binding( fn pattern_match_binding(
&mut self, &mut self,
id: BindingId, id: BindingId,
binding_mode: &mut BindingAnnotation,
cond_place: Place, cond_place: Place,
span: MirSpan, span: MirSpan,
current: BasicBlockId, current: BasicBlockId,
current_else: Option<BasicBlockId>, current_else: Option<BasicBlockId>,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> { ) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
let target_place = self.binding_local(id)?; let target_place = self.binding_local(id)?;
let mode = self.body.bindings[id].mode; let mode = self.infer.binding_modes[id];
if matches!(mode, BindingAnnotation::Ref | BindingAnnotation::RefMut) {
*binding_mode = mode;
}
self.push_storage_live(id, current)?; self.push_storage_live(id, current)?;
self.push_assignment( self.push_assignment(
current, current,
target_place.into(), target_place.into(),
match *binding_mode { match mode {
BindingAnnotation::Unannotated | BindingAnnotation::Mutable => { BindingMode::Move => Operand::Copy(cond_place).into(),
Operand::Copy(cond_place).into() BindingMode::Ref(Mutability::Not) => Rvalue::Ref(BorrowKind::Shared, cond_place),
} BindingMode::Ref(Mutability::Mut) => {
BindingAnnotation::Ref => Rvalue::Ref(BorrowKind::Shared, cond_place),
BindingAnnotation::RefMut => {
Rvalue::Ref(BorrowKind::Mut { allow_two_phase_borrow: false }, cond_place) Rvalue::Ref(BorrowKind::Mut { allow_two_phase_borrow: false }, cond_place)
} }
}, },
@ -420,24 +470,19 @@ impl MirLowerCtx<'_> {
Ok((then_target, Some(else_target))) Ok((then_target, Some(else_target)))
} }
pub(super) fn pattern_matching_variant( fn pattern_matching_variant(
&mut self, &mut self,
mut cond_ty: Ty, cond_place: Place,
mut binding_mode: BindingAnnotation,
mut cond_place: Place,
variant: VariantId, variant: VariantId,
current: BasicBlockId, mut current: BasicBlockId,
span: MirSpan, span: MirSpan,
current_else: Option<BasicBlockId>, mut current_else: Option<BasicBlockId>,
shape: AdtPatternShape<'_>, shape: AdtPatternShape<'_>,
mode: MatchingMode,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> { ) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
pattern_matching_dereference(&mut cond_ty, &mut binding_mode, &mut cond_place);
let subst = match cond_ty.kind(Interner) {
TyKind::Adt(_, s) => s,
_ => return Err(MirLowerError::TypeError("non adt type matched with tuple struct")),
};
Ok(match variant { Ok(match variant {
VariantId::EnumVariantId(v) => { VariantId::EnumVariantId(v) => {
if mode == MatchingMode::Check {
let e = self.const_eval_discriminant(v)? as u128; let e = self.const_eval_discriminant(v)? as u128;
let tmp = self.discr_temp_place(current); let tmp = self.discr_temp_place(current);
self.push_assignment( self.push_assignment(
@ -447,25 +492,26 @@ impl MirLowerCtx<'_> {
span, span,
); );
let next = self.new_basic_block(); let next = self.new_basic_block();
let else_target = current_else.unwrap_or_else(|| self.new_basic_block()); let else_target = current_else.get_or_insert_with(|| self.new_basic_block());
self.set_terminator( self.set_terminator(
current, current,
TerminatorKind::SwitchInt { TerminatorKind::SwitchInt {
discr: Operand::Copy(tmp), discr: Operand::Copy(tmp),
targets: SwitchTargets::static_if(e, next, else_target), targets: SwitchTargets::static_if(e, next, *else_target),
}, },
span, span,
); );
current = next;
}
let enum_data = self.db.enum_data(v.parent); let enum_data = self.db.enum_data(v.parent);
self.pattern_matching_variant_fields( self.pattern_matching_variant_fields(
shape, shape,
&enum_data.variants[v.local_id].variant_data, &enum_data.variants[v.local_id].variant_data,
variant, variant,
subst, current,
next, current_else,
Some(else_target),
&cond_place, &cond_place,
binding_mode, mode,
)? )?
} }
VariantId::StructId(s) => { VariantId::StructId(s) => {
@ -474,11 +520,10 @@ impl MirLowerCtx<'_> {
shape, shape,
&struct_data.variant_data, &struct_data.variant_data,
variant, variant,
subst,
current, current,
current_else, current_else,
&cond_place, &cond_place,
binding_mode, mode,
)? )?
} }
VariantId::UnionId(_) => { VariantId::UnionId(_) => {
@ -492,13 +537,11 @@ impl MirLowerCtx<'_> {
shape: AdtPatternShape<'_>, shape: AdtPatternShape<'_>,
variant_data: &VariantData, variant_data: &VariantData,
v: VariantId, v: VariantId,
subst: &Substitution,
current: BasicBlockId, current: BasicBlockId,
current_else: Option<BasicBlockId>, current_else: Option<BasicBlockId>,
cond_place: &Place, cond_place: &Place,
binding_mode: BindingAnnotation, mode: MatchingMode,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> { ) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
let fields_type = self.db.field_types(v);
Ok(match shape { Ok(match shape {
AdtPatternShape::Record { args } => { AdtPatternShape::Record { args } => {
let it = args let it = args
@ -509,25 +552,16 @@ impl MirLowerCtx<'_> {
Ok(( Ok((
PlaceElem::Field(FieldId { parent: v.into(), local_id: field_id }), PlaceElem::Field(FieldId { parent: v.into(), local_id: field_id }),
x.pat, x.pat,
fields_type[field_id].clone().substitute(Interner, subst),
)) ))
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
self.pattern_match_adt( self.pattern_match_adt(current, current_else, it.into_iter(), cond_place, mode)?
current,
current_else,
it.into_iter(),
cond_place,
binding_mode,
)?
} }
AdtPatternShape::Tuple { args, ellipsis } => { AdtPatternShape::Tuple { args, ellipsis } => {
let fields = variant_data.fields().iter().map(|(x, _)| { let fields = variant_data
( .fields()
PlaceElem::Field(FieldId { parent: v.into(), local_id: x }), .iter()
fields_type[x].clone().substitute(Interner, subst), .map(|(x, _)| PlaceElem::Field(FieldId { parent: v.into(), local_id: x }));
)
});
self.pattern_match_tuple_like( self.pattern_match_tuple_like(
current, current,
current_else, current_else,
@ -535,7 +569,7 @@ impl MirLowerCtx<'_> {
ellipsis, ellipsis,
fields, fields,
cond_place, cond_place,
binding_mode, mode,
)? )?
} }
AdtPatternShape::Unit => (current, current_else), AdtPatternShape::Unit => (current, current_else),
@ -546,14 +580,14 @@ impl MirLowerCtx<'_> {
&mut self, &mut self,
mut current: BasicBlockId, mut current: BasicBlockId,
mut current_else: Option<BasicBlockId>, mut current_else: Option<BasicBlockId>,
args: impl Iterator<Item = (PlaceElem, PatId, Ty)>, args: impl Iterator<Item = (PlaceElem, PatId)>,
cond_place: &Place, cond_place: &Place,
binding_mode: BindingAnnotation, mode: MatchingMode,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> { ) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
for (proj, arg, ty) in args { for (proj, arg) in args {
let cond_place = cond_place.project(proj); let cond_place = cond_place.project(proj);
(current, current_else) = (current, current_else) =
self.pattern_match(current, current_else, cond_place, ty, arg, binding_mode)?; self.pattern_match_inner(current, current_else, cond_place, arg, mode)?;
} }
Ok((current, current_else)) Ok((current, current_else))
} }
@ -564,31 +598,16 @@ impl MirLowerCtx<'_> {
current_else: Option<BasicBlockId>, current_else: Option<BasicBlockId>,
args: &[PatId], args: &[PatId],
ellipsis: Option<usize>, ellipsis: Option<usize>,
fields: impl DoubleEndedIterator<Item = (PlaceElem, Ty)> + Clone, fields: impl DoubleEndedIterator<Item = PlaceElem> + Clone,
cond_place: &Place, cond_place: &Place,
binding_mode: BindingAnnotation, mode: MatchingMode,
) -> Result<(BasicBlockId, Option<BasicBlockId>)> { ) -> Result<(BasicBlockId, Option<BasicBlockId>)> {
let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len())); let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len()));
let it = al let it = al
.iter() .iter()
.zip(fields.clone()) .zip(fields.clone())
.chain(ar.iter().rev().zip(fields.rev())) .chain(ar.iter().rev().zip(fields.rev()))
.map(|(x, y)| (y.0, *x, y.1)); .map(|(x, y)| (y, *x));
self.pattern_match_adt(current, current_else, it, cond_place, binding_mode) self.pattern_match_adt(current, current_else, it, cond_place, mode)
} }
} }
fn pattern_matching_dereference(
cond_ty: &mut Ty,
binding_mode: &mut BindingAnnotation,
cond_place: &mut Place,
) {
let cnt = pattern_matching_dereference_count(cond_ty, binding_mode);
cond_place.projection = cond_place
.projection
.iter()
.cloned()
.chain((0..cnt).map(|_| ProjectionElem::Deref))
.collect::<Vec<_>>()
.into();
}

View file

@ -7,7 +7,7 @@ use base_db::CrateId;
use chalk_ir::{ use chalk_ir::{
cast::Cast, cast::Cast,
fold::{FallibleTypeFolder, Shift}, fold::{FallibleTypeFolder, Shift},
BoundVar, DebruijnIndex, Mutability, BoundVar, DebruijnIndex,
}; };
use either::Either; use either::Either;
use hir_def::{ use hir_def::{
@ -16,7 +16,6 @@ use hir_def::{
GenericParams, TypeOrConstParamData, TypeParamProvenance, WherePredicate, GenericParams, TypeOrConstParamData, TypeParamProvenance, WherePredicate,
WherePredicateTypeTarget, WherePredicateTypeTarget,
}, },
hir::BindingAnnotation,
lang_item::LangItem, lang_item::LangItem,
resolver::{HasResolver, TypeNs}, resolver::{HasResolver, TypeNs},
type_ref::{TraitBoundModifier, TypeRef}, type_ref::{TraitBoundModifier, TypeRef},
@ -35,7 +34,7 @@ use crate::{
layout::{Layout, TagEncoding}, layout::{Layout, TagEncoding},
mir::pad16, mir::pad16,
ChalkTraitId, Const, ConstScalar, GenericArg, Interner, Substitution, TraitRef, TraitRefExt, ChalkTraitId, Const, ConstScalar, GenericArg, Interner, Substitution, TraitRef, TraitRefExt,
Ty, TyExt, WhereClause, Ty, WhereClause,
}; };
pub(crate) fn fn_traits( pub(crate) fn fn_traits(
@ -395,23 +394,6 @@ pub fn is_fn_unsafe_to_call(db: &dyn HirDatabase, func: FunctionId) -> bool {
} }
} }
pub(crate) fn pattern_matching_dereference_count(
cond_ty: &mut Ty,
binding_mode: &mut BindingAnnotation,
) -> usize {
let mut r = 0;
while let Some((ty, _, mu)) = cond_ty.as_reference() {
if mu == Mutability::Mut && *binding_mode != BindingAnnotation::Ref {
*binding_mode = BindingAnnotation::RefMut;
} else {
*binding_mode = BindingAnnotation::Ref;
}
*cond_ty = ty.clone();
r += 1;
}
r
}
pub(crate) struct UnevaluatedConstEvaluatorFolder<'a> { pub(crate) struct UnevaluatedConstEvaluatorFolder<'a> {
pub(crate) db: &'a dyn HirDatabase, pub(crate) db: &'a dyn HirDatabase,
} }