//! Inference of closure parameter types based on the closure's expected type. use std::{cmp, collections::HashMap, convert::Infallible, mem}; use chalk_ir::{ cast::Cast, fold::{FallibleTypeFolder, TypeFoldable}, AliasEq, AliasTy, BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind, WhereClause, }; use hir_def::{ data::adt::VariantData, hir::{Array, BinaryOp, BindingId, CaptureBy, Expr, ExprId, Pat, PatId, Statement, UnaryOp}, lang_item::LangItem, resolver::{resolver_for_expr, ResolveValueResult, ValueNs}, DefWithBodyId, FieldId, HasModule, VariantId, }; use hir_expand::name; use rustc_hash::FxHashMap; use smallvec::SmallVec; use stdx::never; use crate::{ db::HirDatabase, from_placeholder_idx, make_binders, mir::{BorrowKind, MirSpan, ProjectionElem}, static_lifetime, to_chalk_trait_id, traits::FnTrait, utils::{self, generics, Generics}, Adjust, Adjustment, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy, FnPointer, FnSig, Interner, Substitution, Ty, TyExt, }; use super::{Expectation, InferenceContext}; impl InferenceContext<'_> { // This function handles both closures and generators. pub(super) fn deduce_closure_type_from_expectations( &mut self, closure_expr: ExprId, closure_ty: &Ty, sig_ty: &Ty, expectation: &Expectation, ) { let expected_ty = match expectation.to_option(&mut self.table) { Some(ty) => ty, None => return, }; // Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here. let _ = self.coerce(Some(closure_expr), closure_ty, &expected_ty); // Generators are not Fn* so return early. if matches!(closure_ty.kind(Interner), TyKind::Generator(..)) { return; } // Deduction based on the expected `dyn Fn` is done separately. if let TyKind::Dyn(dyn_ty) = expected_ty.kind(Interner) { if let Some(sig) = self.deduce_sig_from_dyn_ty(dyn_ty) { let expected_sig_ty = TyKind::Function(sig).intern(Interner); self.unify(sig_ty, &expected_sig_ty); } } } fn deduce_sig_from_dyn_ty(&self, dyn_ty: &DynTy) -> Option { // Search for a predicate like `<$self as FnX>::Output == Ret` let fn_traits: SmallVec<[ChalkTraitId; 3]> = utils::fn_traits(self.db.upcast(), self.owner.module(self.db.upcast()).krate()) .map(to_chalk_trait_id) .collect(); let self_ty = self.result.standard_types.unknown.clone(); let bounds = dyn_ty.bounds.clone().substitute(Interner, &[self_ty.cast(Interner)]); for bound in bounds.iter(Interner) { // NOTE(skip_binders): the extracted types are rebound by the returned `FnPointer` if let WhereClause::AliasEq(AliasEq { alias: AliasTy::Projection(projection), ty }) = bound.skip_binders() { let assoc_data = self.db.associated_ty_data(projection.associated_ty_id); if !fn_traits.contains(&assoc_data.trait_id) { return None; } // Skip `Self`, get the type argument. let arg = projection.substitution.as_slice(Interner).get(1)?; if let Some(subst) = arg.ty(Interner)?.as_tuple() { let generic_args = subst.as_slice(Interner); let mut sig_tys = Vec::with_capacity(generic_args.len() + 1); for arg in generic_args { sig_tys.push(arg.ty(Interner)?.clone()); } sig_tys.push(ty.clone()); cov_mark::hit!(dyn_fn_param_informs_call_site_closure_signature); return Some(FnPointer { num_binders: bound.len(Interner), sig: FnSig { abi: (), safety: chalk_ir::Safety::Safe, variadic: false }, substitution: FnSubst(Substitution::from_iter(Interner, sig_tys)), }); } } } None } } // The below functions handle capture and closure kind (Fn, FnMut, ..) #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub(crate) struct HirPlace { pub(crate) local: BindingId, pub(crate) projections: Vec>, } impl HirPlace { fn ty(&self, ctx: &mut InferenceContext<'_>) -> Ty { let mut ty = ctx.table.resolve_completely(ctx.result[self.local].clone()); for p in &self.projections { ty = p.projected_ty( ty, ctx.db, |_, _, _| { unreachable!("Closure field only happens in MIR"); }, ctx.owner.module(ctx.db.upcast()).krate(), ); } ty.clone() } fn capture_kind_of_truncated_place( &self, mut current_capture: CaptureKind, len: usize, ) -> CaptureKind { match current_capture { CaptureKind::ByRef(BorrowKind::Mut { .. }) => { if self.projections[len..].iter().any(|x| *x == ProjectionElem::Deref) { current_capture = CaptureKind::ByRef(BorrowKind::Unique); } } _ => (), } current_capture } } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum CaptureKind { ByRef(BorrowKind), ByValue, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct CapturedItem { pub(crate) place: HirPlace, pub(crate) kind: CaptureKind, pub(crate) span: MirSpan, pub(crate) ty: Binders, } impl CapturedItem { pub fn local(&self) -> BindingId { self.place.local } pub fn ty(&self, subst: &Substitution) -> Ty { self.ty.clone().substitute(Interner, utils::ClosureSubst(subst).parent_subst()) } pub fn kind(&self) -> CaptureKind { self.kind } pub fn display_place(&self, owner: DefWithBodyId, db: &dyn HirDatabase) -> String { let body = db.body(owner); let mut result = body[self.place.local].name.display(db.upcast()).to_string(); let mut field_need_paren = false; for proj in &self.place.projections { match proj { ProjectionElem::Deref => { result = format!("*{result}"); field_need_paren = true; } ProjectionElem::Field(f) => { if field_need_paren { result = format!("({result})"); } let variant_data = f.parent.variant_data(db.upcast()); let field = match &*variant_data { VariantData::Record(fields) => fields[f.local_id] .name .as_str() .unwrap_or("[missing field]") .to_string(), VariantData::Tuple(fields) => fields .iter() .position(|x| x.0 == f.local_id) .unwrap_or_default() .to_string(), VariantData::Unit => "[missing field]".to_string(), }; result = format!("{result}.{field}"); field_need_paren = false; } &ProjectionElem::TupleOrClosureField(field) => { if field_need_paren { result = format!("({result})"); } result = format!("{result}.{field}"); field_need_paren = false; } ProjectionElem::Index(_) | ProjectionElem::ConstantIndex { .. } | ProjectionElem::Subslice { .. } | ProjectionElem::OpaqueCast(_) => { never!("Not happen in closure capture"); continue; } } } result } } #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct CapturedItemWithoutTy { pub(crate) place: HirPlace, pub(crate) kind: CaptureKind, pub(crate) span: MirSpan, } impl CapturedItemWithoutTy { fn with_ty(self, ctx: &mut InferenceContext<'_>) -> CapturedItem { let ty = self.place.ty(ctx).clone(); let ty = match &self.kind { CaptureKind::ByValue => ty, CaptureKind::ByRef(bk) => { let m = match bk { BorrowKind::Mut { .. } => Mutability::Mut, _ => Mutability::Not, }; TyKind::Ref(m, static_lifetime(), ty).intern(Interner) } }; return CapturedItem { place: self.place, kind: self.kind, span: self.span, ty: replace_placeholder_with_binder(ctx.db, ctx.owner, ty), }; fn replace_placeholder_with_binder( db: &dyn HirDatabase, owner: DefWithBodyId, ty: Ty, ) -> Binders { struct Filler<'a> { db: &'a dyn HirDatabase, generics: Generics, } impl FallibleTypeFolder for Filler<'_> { type Error = (); fn as_dyn(&mut self) -> &mut dyn FallibleTypeFolder { self } fn interner(&self) -> Interner { Interner } fn try_fold_free_placeholder_const( &mut self, ty: chalk_ir::Ty, idx: chalk_ir::PlaceholderIndex, outer_binder: DebruijnIndex, ) -> Result, Self::Error> { let x = from_placeholder_idx(self.db, idx); let Some(idx) = self.generics.param_idx(x) else { return Err(()); }; Ok(BoundVar::new(outer_binder, idx).to_const(Interner, ty)) } fn try_fold_free_placeholder_ty( &mut self, idx: chalk_ir::PlaceholderIndex, outer_binder: DebruijnIndex, ) -> std::result::Result { let x = from_placeholder_idx(self.db, idx); let Some(idx) = self.generics.param_idx(x) else { return Err(()); }; Ok(BoundVar::new(outer_binder, idx).to_ty(Interner)) } } let Some(generic_def) = owner.as_generic_def_id() else { return Binders::empty(Interner, ty); }; let filler = &mut Filler { db, generics: generics(db.upcast(), generic_def) }; let result = ty.clone().try_fold_with(filler, DebruijnIndex::INNERMOST).unwrap_or(ty); make_binders(db, &filler.generics, result) } } } impl InferenceContext<'_> { fn place_of_expr(&mut self, tgt_expr: ExprId) -> Option { let r = self.place_of_expr_without_adjust(tgt_expr)?; let default = vec![]; let adjustments = self.result.expr_adjustments.get(&tgt_expr).unwrap_or(&default); apply_adjusts_to_place(r, adjustments) } fn place_of_expr_without_adjust(&mut self, tgt_expr: ExprId) -> Option { match &self.body[tgt_expr] { Expr::Path(p) => { let resolver = resolver_for_expr(self.db.upcast(), self.owner, tgt_expr); if let Some(r) = resolver.resolve_path_in_value_ns(self.db.upcast(), p) { if let ResolveValueResult::ValueNs(v) = r { if let ValueNs::LocalBinding(b) = v { return Some(HirPlace { local: b, projections: vec![] }); } } } } Expr::Field { expr, name } => { let mut place = self.place_of_expr(*expr)?; if let TyKind::Tuple(..) = self.expr_ty(*expr).kind(Interner) { let index = name.as_tuple_index()?; place.projections.push(ProjectionElem::TupleOrClosureField(index)) } else { let field = self.result.field_resolution(tgt_expr)?; place.projections.push(ProjectionElem::Field(field)); } return Some(place); } Expr::UnaryOp { expr, op: UnaryOp::Deref } => { if matches!( self.expr_ty_after_adjustments(*expr).kind(Interner), TyKind::Ref(..) | TyKind::Raw(..) ) { let mut place = self.place_of_expr(*expr)?; place.projections.push(ProjectionElem::Deref); return Some(place); } } _ => (), } None } fn push_capture(&mut self, capture: CapturedItemWithoutTy) { self.current_captures.push(capture); } fn ref_expr(&mut self, expr: ExprId) { if let Some(place) = self.place_of_expr(expr) { self.add_capture(place, CaptureKind::ByRef(BorrowKind::Shared), expr.into()); } self.walk_expr(expr); } fn add_capture(&mut self, place: HirPlace, kind: CaptureKind, span: MirSpan) { if self.is_upvar(&place) { self.push_capture(CapturedItemWithoutTy { place, kind, span }); } } fn mutate_expr(&mut self, expr: ExprId) { if let Some(place) = self.place_of_expr(expr) { self.add_capture( place, CaptureKind::ByRef(BorrowKind::Mut { allow_two_phase_borrow: false }), expr.into(), ); } self.walk_expr(expr); } fn consume_expr(&mut self, expr: ExprId) { if let Some(place) = self.place_of_expr(expr) { self.consume_place(place, expr.into()); } self.walk_expr(expr); } fn consume_place(&mut self, place: HirPlace, span: MirSpan) { if self.is_upvar(&place) { let ty = place.ty(self).clone(); let kind = if self.is_ty_copy(ty) { CaptureKind::ByRef(BorrowKind::Shared) } else { CaptureKind::ByValue }; self.push_capture(CapturedItemWithoutTy { place, kind, span }); } } fn walk_expr_with_adjust(&mut self, tgt_expr: ExprId, adjustment: &[Adjustment]) { if let Some((last, rest)) = adjustment.split_last() { match last.kind { Adjust::NeverToAny | Adjust::Deref(None) | Adjust::Pointer(_) => { self.walk_expr_with_adjust(tgt_expr, rest) } Adjust::Deref(Some(m)) => match m.0 { Some(m) => { self.ref_capture_with_adjusts(m, tgt_expr, rest); } None => unreachable!(), }, Adjust::Borrow(b) => { self.ref_capture_with_adjusts(b.mutability(), tgt_expr, rest); } } } else { self.walk_expr_without_adjust(tgt_expr); } } fn ref_capture_with_adjusts(&mut self, m: Mutability, tgt_expr: ExprId, rest: &[Adjustment]) { let capture_kind = match m { Mutability::Mut => { CaptureKind::ByRef(BorrowKind::Mut { allow_two_phase_borrow: false }) } Mutability::Not => CaptureKind::ByRef(BorrowKind::Shared), }; if let Some(place) = self.place_of_expr_without_adjust(tgt_expr) { if let Some(place) = apply_adjusts_to_place(place, rest) { self.add_capture(place, capture_kind, tgt_expr.into()); } } self.walk_expr_with_adjust(tgt_expr, rest); } fn walk_expr(&mut self, tgt_expr: ExprId) { if let Some(x) = self.result.expr_adjustments.get_mut(&tgt_expr) { // FIXME: this take is completely unneeded, and just is here to make borrow checker // happy. Remove it if you can. let x_taken = mem::take(x); self.walk_expr_with_adjust(tgt_expr, &x_taken); *self.result.expr_adjustments.get_mut(&tgt_expr).unwrap() = x_taken; } else { self.walk_expr_without_adjust(tgt_expr); } } fn walk_expr_without_adjust(&mut self, tgt_expr: ExprId) { match &self.body[tgt_expr] { Expr::If { condition, then_branch, else_branch } => { self.consume_expr(*condition); self.consume_expr(*then_branch); if let &Some(expr) = else_branch { self.consume_expr(expr); } } Expr::Async { statements, tail, .. } | Expr::Unsafe { statements, tail, .. } | Expr::Block { statements, tail, .. } => { for s in statements.iter() { match s { Statement::Let { pat, type_ref: _, initializer, else_branch } => { if let Some(else_branch) = else_branch { self.consume_expr(*else_branch); if let Some(initializer) = initializer { self.consume_expr(*initializer); } return; } if let Some(initializer) = initializer { self.walk_expr(*initializer); if let Some(place) = self.place_of_expr(*initializer) { self.consume_with_pat(place, *pat); } } } Statement::Expr { expr, has_semi: _ } => { self.consume_expr(*expr); } } } if let Some(tail) = tail { self.consume_expr(*tail); } } Expr::While { condition, body, label: _ } => { self.consume_expr(*condition); self.consume_expr(*body); } Expr::Call { callee, args, is_assignee_expr: _ } => { self.consume_expr(*callee); self.consume_exprs(args.iter().copied()); } Expr::MethodCall { receiver, args, .. } => { self.consume_expr(*receiver); self.consume_exprs(args.iter().copied()); } Expr::Match { expr, arms } => { for arm in arms.iter() { self.consume_expr(arm.expr); if let Some(guard) = arm.guard { self.consume_expr(guard); } } self.walk_expr(*expr); if let Some(discr_place) = self.place_of_expr(*expr) { if self.is_upvar(&discr_place) { let mut capture_mode = None; for arm in arms.iter() { self.walk_pat(&mut capture_mode, arm.pat); } if let Some(c) = capture_mode { self.push_capture(CapturedItemWithoutTy { place: discr_place, kind: c, span: (*expr).into(), }) } } } } Expr::Break { expr, label: _ } | Expr::Return { expr } | Expr::Yield { expr } | Expr::Yeet { expr } => { if let &Some(expr) = expr { self.consume_expr(expr); } } Expr::RecordLit { fields, spread, .. } => { if let &Some(expr) = spread { self.consume_expr(expr); } self.consume_exprs(fields.iter().map(|x| x.expr)); } Expr::Field { expr, name: _ } => self.select_from_expr(*expr), Expr::UnaryOp { expr, op: UnaryOp::Deref } => { if matches!( self.expr_ty_after_adjustments(*expr).kind(Interner), TyKind::Ref(..) | TyKind::Raw(..) ) { self.select_from_expr(*expr); } else if let Some((f, _)) = self.result.method_resolution(tgt_expr) { let mutability = 'b: { if let Some(deref_trait) = self.resolve_lang_item(LangItem::DerefMut).and_then(|x| x.as_trait()) { if let Some(deref_fn) = self.db.trait_data(deref_trait).method_by_name(&name![deref_mut]) { break 'b deref_fn == f; } } false }; if mutability { self.mutate_expr(*expr); } else { self.ref_expr(*expr); } } else { self.select_from_expr(*expr); } } Expr::UnaryOp { expr, op: _ } | Expr::Array(Array::Repeat { initializer: expr, repeat: _ }) | Expr::Await { expr } | Expr::Loop { body: expr, label: _ } | Expr::Let { pat: _, expr } | Expr::Box { expr } | Expr::Cast { expr, type_ref: _ } => { self.consume_expr(*expr); } Expr::Ref { expr, rawness: _, mutability } => match mutability { hir_def::type_ref::Mutability::Shared => self.ref_expr(*expr), hir_def::type_ref::Mutability::Mut => self.mutate_expr(*expr), }, Expr::BinaryOp { lhs, rhs, op } => { let Some(op) = op else { return; }; if matches!(op, BinaryOp::Assignment { .. }) { self.mutate_expr(*lhs); self.consume_expr(*rhs); return; } self.consume_expr(*lhs); self.consume_expr(*rhs); } Expr::Range { lhs, rhs, range_type: _ } => { if let &Some(expr) = lhs { self.consume_expr(expr); } if let &Some(expr) = rhs { self.consume_expr(expr); } } Expr::Index { base, index } => { self.select_from_expr(*base); self.consume_expr(*index); } Expr::Closure { .. } => { let ty = self.expr_ty(tgt_expr); let TyKind::Closure(id, _) = ty.kind(Interner) else { never!("closure type is always closure"); return; }; let (captures, _) = self.result.closure_info.get(id).expect( "We sort closures, so we should always have data for inner closures", ); let mut cc = mem::take(&mut self.current_captures); cc.extend(captures.iter().filter(|x| self.is_upvar(&x.place)).map(|x| { CapturedItemWithoutTy { place: x.place.clone(), kind: x.kind, span: x.span } })); self.current_captures = cc; } Expr::Array(Array::ElementList { elements: exprs, is_assignee_expr: _ }) | Expr::Tuple { exprs, is_assignee_expr: _ } => { self.consume_exprs(exprs.iter().copied()) } Expr::Missing | Expr::Continue { .. } | Expr::Path(_) | Expr::Literal(_) | Expr::Const(_) | Expr::Underscore => (), } } fn walk_pat(&mut self, result: &mut Option, pat: PatId) { let mut update_result = |ck: CaptureKind| match result { Some(r) => { *r = cmp::max(*r, ck); } None => *result = Some(ck), }; self.walk_pat_inner( pat, &mut update_result, BorrowKind::Mut { allow_two_phase_borrow: false }, ); } fn walk_pat_inner( &mut self, p: PatId, update_result: &mut impl FnMut(CaptureKind), mut for_mut: BorrowKind, ) { match &self.body[p] { Pat::Ref { .. } | Pat::Box { .. } | Pat::Missing | Pat::Wild | Pat::Tuple { .. } | Pat::Or(_) => (), Pat::TupleStruct { .. } | Pat::Record { .. } => { if let Some(variant) = self.result.variant_resolution_for_pat(p) { let adt = variant.adt_id(); let is_multivariant = match adt { hir_def::AdtId::EnumId(e) => self.db.enum_data(e).variants.len() != 1, _ => false, }; if is_multivariant { update_result(CaptureKind::ByRef(BorrowKind::Shared)); } } } Pat::Slice { .. } | Pat::ConstBlock(_) | Pat::Path(_) | Pat::Lit(_) | Pat::Range { .. } => { update_result(CaptureKind::ByRef(BorrowKind::Shared)); } Pat::Bind { id, .. } => match self.result.binding_modes[*id] { crate::BindingMode::Move => { if self.is_ty_copy(self.result.type_of_binding[*id].clone()) { update_result(CaptureKind::ByRef(BorrowKind::Shared)); } else { update_result(CaptureKind::ByValue); } } crate::BindingMode::Ref(r) => match r { Mutability::Mut => update_result(CaptureKind::ByRef(for_mut)), Mutability::Not => update_result(CaptureKind::ByRef(BorrowKind::Shared)), }, }, } if self.result.pat_adjustments.get(&p).map_or(false, |x| !x.is_empty()) { for_mut = BorrowKind::Unique; } self.body.walk_pats_shallow(p, |p| self.walk_pat_inner(p, update_result, for_mut)); } fn expr_ty(&self, expr: ExprId) -> Ty { self.result[expr].clone() } fn expr_ty_after_adjustments(&self, e: ExprId) -> Ty { let mut ty = None; if let Some(x) = self.result.expr_adjustments.get(&e) { if let Some(x) = x.last() { ty = Some(x.target.clone()); } } ty.unwrap_or_else(|| self.expr_ty(e)) } fn is_upvar(&self, place: &HirPlace) -> bool { let b = &self.body[place.local]; if let Some(c) = self.current_closure { let (_, root) = self.db.lookup_intern_closure(c.into()); return b.is_upvar(root); } false } fn is_ty_copy(&mut self, ty: Ty) -> bool { if let TyKind::Closure(id, _) = ty.kind(Interner) { // FIXME: We handle closure as a special case, since chalk consider every closure as copy. We // should probably let chalk know which closures are copy, but I don't know how doing it // without creating query cycles. return self.result.closure_info.get(id).map(|x| x.1 == FnTrait::Fn).unwrap_or(true); } self.table.resolve_completely(ty).is_copy(self.db, self.owner) } fn select_from_expr(&mut self, expr: ExprId) { self.walk_expr(expr); } fn adjust_for_move_closure(&mut self) { for capture in &mut self.current_captures { if let Some(first_deref) = capture.place.projections.iter().position(|proj| *proj == ProjectionElem::Deref) { capture.place.projections.truncate(first_deref); } capture.kind = CaptureKind::ByValue; } } fn minimize_captures(&mut self) { self.current_captures.sort_by_key(|x| x.place.projections.len()); let mut hash_map = HashMap::::new(); let result = mem::take(&mut self.current_captures); for item in result { let mut lookup_place = HirPlace { local: item.place.local, projections: vec![] }; let mut it = item.place.projections.iter(); let prev_index = loop { if let Some(k) = hash_map.get(&lookup_place) { break Some(*k); } match it.next() { Some(x) => lookup_place.projections.push(x.clone()), None => break None, } }; match prev_index { Some(p) => { let len = self.current_captures[p].place.projections.len(); let kind_after_truncate = item.place.capture_kind_of_truncated_place(item.kind, len); self.current_captures[p].kind = cmp::max(kind_after_truncate, self.current_captures[p].kind); } None => { hash_map.insert(item.place.clone(), self.current_captures.len()); self.current_captures.push(item); } } } } fn consume_with_pat(&mut self, mut place: HirPlace, pat: PatId) { let cnt = self.result.pat_adjustments.get(&pat).map(|x| x.len()).unwrap_or_default(); place.projections = place .projections .iter() .cloned() .chain((0..cnt).map(|_| ProjectionElem::Deref)) .collect::>() .into(); match &self.body[pat] { Pat::Missing | Pat::Wild => (), Pat::Tuple { args, ellipsis } => { let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len())); let field_count = match self.result[pat].kind(Interner) { TyKind::Tuple(_, s) => s.len(Interner), _ => return, }; let fields = 0..field_count; let it = al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev())); for (arg, i) in it { let mut p = place.clone(); p.projections.push(ProjectionElem::TupleOrClosureField(i)); self.consume_with_pat(p, *arg); } } Pat::Or(pats) => { for pat in pats.iter() { self.consume_with_pat(place.clone(), *pat); } } Pat::Record { args, .. } => { let Some(variant) = self.result.variant_resolution_for_pat(pat) else { return; }; match variant { VariantId::EnumVariantId(_) | VariantId::UnionId(_) => { self.consume_place(place, pat.into()) } VariantId::StructId(s) => { let vd = &*self.db.struct_data(s).variant_data; for field_pat in args.iter() { let arg = field_pat.pat; let Some(local_id) = vd.field(&field_pat.name) else { continue; }; let mut p = place.clone(); p.projections.push(ProjectionElem::Field(FieldId { parent: variant.into(), local_id, })); self.consume_with_pat(p, arg); } } } } Pat::Range { .. } | Pat::Slice { .. } | Pat::ConstBlock(_) | Pat::Path(_) | Pat::Lit(_) => self.consume_place(place, pat.into()), Pat::Bind { id, subpat: _ } => { let mode = self.result.binding_modes[*id]; let capture_kind = match mode { BindingMode::Move => { self.consume_place(place, pat.into()); return; } BindingMode::Ref(Mutability::Not) => BorrowKind::Shared, BindingMode::Ref(Mutability::Mut) => { BorrowKind::Mut { allow_two_phase_borrow: false } } }; self.add_capture(place, CaptureKind::ByRef(capture_kind), pat.into()); } Pat::TupleStruct { path: _, args, ellipsis } => { let Some(variant) = self.result.variant_resolution_for_pat(pat) else { return; }; match variant { VariantId::EnumVariantId(_) | VariantId::UnionId(_) => { self.consume_place(place, pat.into()) } VariantId::StructId(s) => { let vd = &*self.db.struct_data(s).variant_data; let (al, ar) = args.split_at(ellipsis.unwrap_or(args.len())); let fields = vd.fields().iter(); let it = al.iter().zip(fields.clone()).chain(ar.iter().rev().zip(fields.rev())); for (arg, (i, _)) in it { let mut p = place.clone(); p.projections.push(ProjectionElem::Field(FieldId { parent: variant.into(), local_id: i, })); self.consume_with_pat(p, *arg); } } } } Pat::Ref { pat, mutability: _ } => { place.projections.push(ProjectionElem::Deref); self.consume_with_pat(place, *pat) } Pat::Box { .. } => (), // not supported } } fn consume_exprs(&mut self, exprs: impl Iterator) { for expr in exprs { self.consume_expr(expr); } } fn closure_kind(&self) -> FnTrait { let mut r = FnTrait::Fn; for x in &self.current_captures { r = cmp::min( r, match &x.kind { CaptureKind::ByRef(BorrowKind::Unique | BorrowKind::Mut { .. }) => { FnTrait::FnMut } CaptureKind::ByRef(BorrowKind::Shallow | BorrowKind::Shared) => FnTrait::Fn, CaptureKind::ByValue => FnTrait::FnOnce, }, ) } r } fn analyze_closure(&mut self, closure: ClosureId) -> FnTrait { let (_, root) = self.db.lookup_intern_closure(closure.into()); self.current_closure = Some(closure); let Expr::Closure { body, capture_by, .. } = &self.body[root] else { unreachable!("Closure expression id is always closure"); }; self.consume_expr(*body); for item in &self.current_captures { if matches!(item.kind, CaptureKind::ByRef(BorrowKind::Mut { .. })) && !item.place.projections.contains(&ProjectionElem::Deref) { // FIXME: remove the `mutated_bindings_in_closure` completely and add proper fake reads in // MIR. I didn't do that due duplicate diagnostics. self.result.mutated_bindings_in_closure.insert(item.place.local); } } // closure_kind should be done before adjust_for_move_closure let closure_kind = self.closure_kind(); match capture_by { CaptureBy::Value => self.adjust_for_move_closure(), CaptureBy::Ref => (), } self.minimize_captures(); let result = mem::take(&mut self.current_captures); let captures = result.into_iter().map(|x| x.with_ty(self)).collect::>(); self.result.closure_info.insert(closure, (captures, closure_kind)); closure_kind } pub(crate) fn infer_closures(&mut self) { let deferred_closures = self.sort_closures(); for (closure, exprs) in deferred_closures.into_iter().rev() { self.current_captures = vec![]; let kind = self.analyze_closure(closure); for (derefed_callee, callee_ty, params, expr) in exprs { if let &Expr::Call { callee, .. } = &self.body[expr] { let mut adjustments = self.result.expr_adjustments.remove(&callee).unwrap_or_default(); self.write_fn_trait_method_resolution( kind, &derefed_callee, &mut adjustments, &callee_ty, ¶ms, expr, ); self.result.expr_adjustments.insert(callee, adjustments); } } } } /// We want to analyze some closures before others, to have a correct analysis: /// * We should analyze nested closures before the parent, since the parent should capture some of /// the things that its children captures. /// * If a closure calls another closure, we need to analyze the callee, to find out how we should /// capture it (e.g. by move for FnOnce) /// /// These dependencies are collected in the main inference. We do a topological sort in this function. It /// will consume the `deferred_closures` field and return its content in a sorted vector. fn sort_closures(&mut self) -> Vec<(ClosureId, Vec<(Ty, Ty, Vec, ExprId)>)> { let mut deferred_closures = mem::take(&mut self.deferred_closures); let mut dependents_count: FxHashMap = deferred_closures.keys().map(|x| (*x, 0)).collect(); for (_, deps) in &self.closure_dependencies { for dep in deps { *dependents_count.entry(*dep).or_default() += 1; } } let mut queue: Vec<_> = deferred_closures.keys().copied().filter(|x| dependents_count[x] == 0).collect(); let mut result = vec![]; while let Some(x) = queue.pop() { if let Some(d) = deferred_closures.remove(&x) { result.push((x, d)); } for dep in self.closure_dependencies.get(&x).into_iter().flat_map(|x| x.iter()) { let cnt = dependents_count.get_mut(dep).unwrap(); *cnt -= 1; if *cnt == 0 { queue.push(*dep); } } } result } } fn apply_adjusts_to_place(mut r: HirPlace, adjustments: &[Adjustment]) -> Option { for adj in adjustments { match &adj.kind { Adjust::Deref(None) => { r.projections.push(ProjectionElem::Deref); } _ => return None, } } Some(r) }