handle match scrutinee in closure captures

This commit is contained in:
hkalbasi 2023-05-19 12:00:19 +03:30
parent e8ae2d3976
commit c5ea2d7adc
8 changed files with 137 additions and 10 deletions

View file

@ -221,15 +221,15 @@ impl Body {
pub fn walk_bindings_in_pat(&self, pat_id: PatId, mut f: impl FnMut(BindingId)) {
self.walk_pats(pat_id, &mut |pat| {
if let Pat::Bind { id, .. } = pat {
if let Pat::Bind { id, .. } = &self[pat] {
f(*id);
}
});
}
pub fn walk_pats(&self, pat_id: PatId, f: &mut impl FnMut(&Pat)) {
pub fn walk_pats(&self, pat_id: PatId, f: &mut impl FnMut(PatId)) {
let pat = &self[pat_id];
f(pat);
f(pat_id);
match pat {
Pat::Range { .. }
| Pat::Lit(..)

View file

@ -147,7 +147,7 @@ impl<'a> PatCtxt<'a> {
}
hir_def::hir::Pat::Bind { id, subpat, .. } => {
let bm = self.infer.pat_binding_modes[&pat];
let bm = self.infer.binding_modes[id];
ty = &self.infer[id];
let name = &self.body.bindings[id].name;
match (bm, ty.kind(Interner)) {

View file

@ -389,7 +389,7 @@ pub struct InferenceResult {
standard_types: InternedStandardTypes,
/// Stores the types which were implicitly dereferenced in pattern binding modes.
pub pat_adjustments: FxHashMap<PatId, Vec<Ty>>,
pub pat_binding_modes: FxHashMap<PatId, BindingMode>,
pub binding_modes: ArenaMap<BindingId, BindingMode>,
pub expr_adjustments: FxHashMap<ExprId, Vec<Adjustment>>,
pub(crate) closure_info: FxHashMap<ClosureId, (Vec<CapturedItem>, FnTrait)>,
// FIXME: remove this field

View file

@ -504,9 +504,27 @@ impl InferenceContext<'_> {
self.consume_exprs(args.iter().copied());
}
Expr::Match { expr, arms } => {
self.consume_expr(*expr);
for arm in arms.iter() {
self.consume_expr(arm.expr);
if let Some(guard) = arm.guard {
self.consume_expr(guard);
}
}
self.walk_expr(*expr);
if let Some(discr_place) = self.place_of_expr(*expr) {
if self.is_upvar(&discr_place) {
let mut capture_mode = None;
for arm in arms.iter() {
self.walk_pat(&mut capture_mode, arm.pat);
}
if let Some(c) = capture_mode {
self.push_capture(CapturedItemWithoutTy {
place: discr_place,
kind: c,
span: (*expr).into(),
})
}
}
}
}
Expr::Break { expr, label: _ }
@ -618,6 +636,57 @@ impl InferenceContext<'_> {
}
}
fn walk_pat(&mut self, result: &mut Option<CaptureKind>, pat: PatId) {
let mut update_result = |ck: CaptureKind| match result {
Some(r) => {
*r = cmp::max(*r, ck);
}
None => *result = Some(ck),
};
self.body.walk_pats(pat, &mut |p| match &self.body[p] {
Pat::Ref { .. }
| Pat::Box { .. }
| Pat::Missing
| Pat::Wild
| Pat::Tuple { .. }
| Pat::Or(_) => (),
Pat::TupleStruct { .. } | Pat::Record { .. } => {
if let Some(variant) = self.result.variant_resolution_for_pat(p) {
let adt = variant.adt_id();
let is_multivariant = match adt {
hir_def::AdtId::EnumId(e) => self.db.enum_data(e).variants.len() != 1,
_ => false,
};
if is_multivariant {
update_result(CaptureKind::ByRef(BorrowKind::Shared));
}
}
}
Pat::Slice { .. }
| Pat::ConstBlock(_)
| Pat::Path(_)
| Pat::Lit(_)
| Pat::Range { .. } => {
update_result(CaptureKind::ByRef(BorrowKind::Shared));
}
Pat::Bind { id, .. } => match self.result.binding_modes[*id] {
crate::BindingMode::Move => {
if self.is_ty_copy(self.result.type_of_binding[*id].clone()) {
update_result(CaptureKind::ByRef(BorrowKind::Shared));
} else {
update_result(CaptureKind::ByValue);
}
}
crate::BindingMode::Ref(r) => match r {
Mutability::Mut => update_result(CaptureKind::ByRef(BorrowKind::Mut {
allow_two_phase_borrow: false,
})),
Mutability::Not => update_result(CaptureKind::ByRef(BorrowKind::Shared)),
},
},
});
}
fn expr_ty(&self, expr: ExprId) -> Ty {
self.result[expr].clone()
}

View file

@ -340,7 +340,7 @@ impl<'a> InferenceContext<'a> {
} else {
BindingMode::convert(mode)
};
self.result.pat_binding_modes.insert(pat, mode);
self.result.binding_modes.insert(binding, mode);
let inner_ty = match subpat {
Some(subpat) => self.infer_pat(subpat, &expected, default_bm),
@ -439,7 +439,7 @@ fn is_non_ref_pat(body: &hir_def::body::Body, pat: PatId) -> bool {
pub(super) fn contains_explicit_ref_binding(body: &Body, pat_id: PatId) -> bool {
let mut res = false;
body.walk_pats(pat_id, &mut |pat| {
res |= matches!(pat, Pat::Bind { id, .. } if body.bindings[*id].mode == BindingAnnotation::Ref);
res |= matches!(body[pat], Pat::Bind { id, .. } if body.bindings[id].mode == BindingAnnotation::Ref);
});
res
}

View file

@ -182,6 +182,43 @@ fn capture_specific_fields() {
}
}
#[test]
fn match_pattern() {
size_and_align_expr! {
struct X(i64, i32, (u8, i128));
let y: X = X(2, 5, (7, 3));
move |x: i64| {
match y {
_ => x,
}
}
}
size_and_align_expr! {
minicore: copy;
stmts: [
struct X(i64, i32, (u8, i128));
let y: X = X(2, 5, (7, 3));
]
|x: i64| {
match y {
X(_a, _b, _c) => x,
}
}
}
size_and_align_expr! {
minicore: copy;
stmts: [
struct X(i64, i32, (u8, i128));
let y: X = X(2, 5, (7, 3));
]
|x: i64| {
match y {
_y => x,
}
}
}
}
#[test]
fn ellipsis_pattern() {
size_and_align_expr! {

View file

@ -236,9 +236,9 @@ impl SourceAnalyzer {
_db: &dyn HirDatabase,
pat: &ast::IdentPat,
) -> Option<BindingMode> {
let pat_id = self.pat_id(&pat.clone().into())?;
let binding_id = self.binding_id_of_pat(pat)?;
let infer = self.infer.as_ref()?;
infer.pat_binding_modes.get(&pat_id).map(|bm| match bm {
infer.binding_modes.get(binding_id).map(|bm| match bm {
hir_ty::BindingMode::Move => BindingMode::Move,
hir_ty::BindingMode::Ref(hir_ty::Mutability::Mut) => BindingMode::Ref(Mutability::Mut),
hir_ty::BindingMode::Ref(hir_ty::Mutability::Not) => {

View file

@ -151,4 +151,25 @@ fn f(x: &mut X<'_>) {
"#,
);
}
#[test]
fn no_false_positive_match_and_closure_capture() {
check_diagnostics(
r#"
//- minicore: copy, fn
enum X {
Foo(u16),
Bar,
}
fn main() {
let x = &X::Bar;
let c = || match *x {
X::Foo(t) => t,
_ => 5,
};
}
"#,
);
}
}