From 4317927231961c7e8e02c020864586392df862bf Mon Sep 17 00:00:00 2001 From: Chayim Refael Friedman Date: Wed, 30 Oct 2024 21:24:19 +0200 Subject: [PATCH] Avoid interior mutability in `TyLoweringContext` This requires some serious code juggling. --- crates/hir-ty/src/chalk_db.rs | 18 +- crates/hir-ty/src/infer.rs | 2 +- crates/hir-ty/src/infer/path.rs | 2 +- crates/hir-ty/src/lower.rs | 405 +++++++++++++++----------------- 4 files changed, 202 insertions(+), 225 deletions(-) diff --git a/crates/hir-ty/src/chalk_db.rs b/crates/hir-ty/src/chalk_db.rs index 4bc78afacc..05cd7bd37b 100644 --- a/crates/hir-ty/src/chalk_db.rs +++ b/crates/hir-ty/src/chalk_db.rs @@ -615,7 +615,7 @@ pub(crate) fn associated_ty_data_query( let type_alias_data = db.type_alias_data(type_alias); let generic_params = generics(db.upcast(), type_alias.into()); let resolver = hir_def::resolver::HasResolver::resolver(type_alias, db.upcast()); - let ctx = + let mut ctx = crate::TyLoweringContext::new(db, &resolver, &type_alias_data.types_map, type_alias.into()) .with_type_param_mode(crate::lower::ParamLoweringMode::Variable); @@ -627,14 +627,16 @@ pub(crate) fn associated_ty_data_query( .build(); let self_ty = TyKind::Alias(AliasTy::Projection(pro_ty)).intern(Interner); - let mut bounds: Vec<_> = type_alias_data - .bounds - .iter() - .flat_map(|bound| ctx.lower_type_bound(bound, self_ty.clone(), false)) - .filter_map(|pred| generic_predicate_to_inline_bound(db, &pred, &self_ty)) - .collect(); + let mut bounds = Vec::new(); + for bound in &type_alias_data.bounds { + ctx.lower_type_bound(bound, self_ty.clone(), false).for_each(|pred| { + if let Some(pred) = generic_predicate_to_inline_bound(db, &pred, &self_ty) { + bounds.push(pred); + } + }); + } - if !ctx.unsized_types.borrow().contains(&self_ty) { + if !ctx.unsized_types.contains(&self_ty) { let sized_trait = db .lang_item(resolver.krate(), LangItem::Sized) .and_then(|lang_item| lang_item.as_trait().map(to_chalk_trait_id)); diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index 3685ed5696..01e0b635b2 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -1420,7 +1420,7 @@ impl<'a> InferenceContext<'a> { Some(path) => path, None => return (self.err_ty(), None), }; - let ctx = crate::lower::TyLoweringContext::new( + let mut ctx = crate::lower::TyLoweringContext::new( self.db, &self.resolver, &self.body.types, diff --git a/crates/hir-ty/src/infer/path.rs b/crates/hir-ty/src/infer/path.rs index 442daa9f9e..7550d197a3 100644 --- a/crates/hir-ty/src/infer/path.rs +++ b/crates/hir-ty/src/infer/path.rs @@ -151,7 +151,7 @@ impl InferenceContext<'_> { let last = path.segments().last()?; // Don't use `self.make_ty()` here as we need `orig_ns`. - let ctx = crate::lower::TyLoweringContext::new( + let mut ctx = crate::lower::TyLoweringContext::new( self.db, &self.resolver, &self.body.types, diff --git a/crates/hir-ty/src/lower.rs b/crates/hir-ty/src/lower.rs index ee9fd02cdf..b868ea95f8 100644 --- a/crates/hir-ty/src/lower.rs +++ b/crates/hir-ty/src/lower.rs @@ -6,8 +6,8 @@ //! //! This usually involves resolving names, collecting generic arguments etc. use std::{ - cell::{Cell, OnceCell, RefCell, RefMut}, - iter, + cell::OnceCell, + iter, mem, ops::{self, Not as _}, }; @@ -72,47 +72,32 @@ use crate::{ TraitRefExt, Ty, TyBuilder, TyKind, WhereClause, }; -#[derive(Debug)] -enum ImplTraitLoweringState { +#[derive(Debug, Default)] +struct ImplTraitLoweringState { /// When turning `impl Trait` into opaque types, we have to collect the /// bounds at the same time to get the IDs correct (without becoming too - /// complicated). I don't like using interior mutability (as for the - /// counter), but I've tried and failed to make the lifetimes work for - /// passing around a `&mut TyLoweringContext`. The core problem is that - /// we're grouping the mutable data (the counter and this field) together - /// with the immutable context (the references to the DB and resolver). - /// Splitting this up would be a possible fix. - Opaque(RefCell>), - Param(Cell), - Variable(Cell), - Disallowed, + /// complicated). + mode: ImplTraitLoweringMode, + // This is structured as a struct with fields and not as an enum because it helps with the borrow checker. + opaque_type_data: Arena, + param_and_variable_counter: u16, } impl ImplTraitLoweringState { - fn new(impl_trait_mode: ImplTraitLoweringMode) -> ImplTraitLoweringState { - match impl_trait_mode { - ImplTraitLoweringMode::Opaque => Self::Opaque(RefCell::new(Arena::new())), - ImplTraitLoweringMode::Param => Self::Param(Cell::new(0)), - ImplTraitLoweringMode::Variable => Self::Variable(Cell::new(0)), - ImplTraitLoweringMode::Disallowed => Self::Disallowed, + fn new(mode: ImplTraitLoweringMode) -> ImplTraitLoweringState { + Self { mode, opaque_type_data: Arena::new(), param_and_variable_counter: 0 } + } + fn param(counter: u16) -> Self { + Self { + mode: ImplTraitLoweringMode::Param, + opaque_type_data: Arena::new(), + param_and_variable_counter: counter, } } - - fn take(&self) -> Self { - match self { - Self::Opaque(x) => Self::Opaque(RefCell::new(x.take())), - Self::Param(x) => Self::Param(Cell::new(x.get())), - Self::Variable(x) => Self::Variable(Cell::new(x.get())), - Self::Disallowed => Self::Disallowed, - } - } - - fn swap(&self, impl_trait_mode: &Self) { - match (self, impl_trait_mode) { - (Self::Opaque(x), Self::Opaque(y)) => x.swap(y), - (Self::Param(x), Self::Param(y)) => x.swap(y), - (Self::Variable(x), Self::Variable(y)) => x.swap(y), - (Self::Disallowed, Self::Disallowed) => (), - _ => panic!("mismatched lowering mode"), + fn variable(counter: u16) -> Self { + Self { + mode: ImplTraitLoweringMode::Variable, + opaque_type_data: Arena::new(), + param_and_variable_counter: counter, } } } @@ -137,9 +122,9 @@ pub struct TyLoweringContext<'a> { /// possible currently, so this should be fine for now. pub type_param_mode: ParamLoweringMode, impl_trait_mode: ImplTraitLoweringState, - expander: RefCell>, + expander: Option, /// Tracks types with explicit `?Sized` bounds. - pub(crate) unsized_types: RefCell>, + pub(crate) unsized_types: FxHashSet, } impl<'a> TyLoweringContext<'a> { @@ -159,7 +144,7 @@ impl<'a> TyLoweringContext<'a> { types_source_map: Option<&'a TypesSourceMap>, owner: Option, ) -> Self { - let impl_trait_mode = ImplTraitLoweringState::Disallowed; + let impl_trait_mode = ImplTraitLoweringState::new(ImplTraitLoweringMode::Disallowed); let type_param_mode = ParamLoweringMode::Placeholder; let in_binders = DebruijnIndex::INNERMOST; Self { @@ -172,38 +157,26 @@ impl<'a> TyLoweringContext<'a> { in_binders, impl_trait_mode, type_param_mode, - expander: RefCell::new(None), - unsized_types: RefCell::default(), + expander: None, + unsized_types: FxHashSet::default(), } } pub fn with_debruijn( - &self, + &mut self, debruijn: DebruijnIndex, - f: impl FnOnce(&TyLoweringContext<'_>) -> T, + f: impl FnOnce(&mut TyLoweringContext<'_>) -> T, ) -> T { - let impl_trait_mode = self.impl_trait_mode.take(); - let expander = self.expander.take(); - let unsized_types = self.unsized_types.take(); - let new_ctx = Self { - in_binders: debruijn, - impl_trait_mode, - expander: RefCell::new(expander), - unsized_types: RefCell::new(unsized_types), - generics: self.generics.clone(), - ..*self - }; - let result = f(&new_ctx); - self.impl_trait_mode.swap(&new_ctx.impl_trait_mode); - self.expander.replace(new_ctx.expander.into_inner()); - self.unsized_types.replace(new_ctx.unsized_types.into_inner()); + let old_debruijn = mem::replace(&mut self.in_binders, debruijn); + let result = f(self); + self.in_binders = old_debruijn; result } pub fn with_shifted_in( - &self, + &mut self, debruijn: DebruijnIndex, - f: impl FnOnce(&TyLoweringContext<'_>) -> T, + f: impl FnOnce(&mut TyLoweringContext<'_>) -> T, ) -> T { self.with_debruijn(self.in_binders.shifted_in_from(debruijn), f) } @@ -227,7 +200,7 @@ impl<'a> TyLoweringContext<'a> { } } -#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)] pub enum ImplTraitLoweringMode { /// `impl Trait` gets lowered into an opaque type that doesn't unify with /// anything except itself. This is used in places where values flow 'out', @@ -244,6 +217,7 @@ pub enum ImplTraitLoweringMode { /// currently checking. Variable, /// `impl Trait` is disallowed and will be an error. + #[default] Disallowed, } @@ -254,12 +228,13 @@ pub enum ParamLoweringMode { } impl<'a> TyLoweringContext<'a> { - pub fn lower_ty(&self, type_ref: TypeRefId) -> Ty { + pub fn lower_ty(&mut self, type_ref: TypeRefId) -> Ty { self.lower_ty_ext(type_ref).0 } - pub fn lower_const(&self, const_ref: &ConstRef, const_type: Ty) -> Const { + pub fn lower_const(&mut self, const_ref: &ConstRef, const_type: Ty) -> Const { let Some(owner) = self.owner else { return unknown_const(const_type) }; + let debruijn = self.in_binders; const_or_path_to_chalk( self.db, self.resolver, @@ -268,7 +243,7 @@ impl<'a> TyLoweringContext<'a> { const_ref, self.type_param_mode, || self.generics(), - self.in_binders, + debruijn, ) } @@ -278,7 +253,7 @@ impl<'a> TyLoweringContext<'a> { .as_ref() } - pub fn lower_ty_ext(&self, type_ref_id: TypeRefId) -> (Ty, Option) { + pub fn lower_ty_ext(&mut self, type_ref_id: TypeRefId) -> (Ty, Option) { let mut res = None; let type_ref = &self.types_map[type_ref_id]; let ty = match type_ref { @@ -337,8 +312,8 @@ impl<'a> TyLoweringContext<'a> { } TypeRef::DynTrait(bounds) => self.lower_dyn_trait(bounds), TypeRef::ImplTrait(bounds) => { - match &self.impl_trait_mode { - ImplTraitLoweringState::Opaque(opaque_type_data) => { + match self.impl_trait_mode.mode { + ImplTraitLoweringMode::Opaque => { let origin = match self.resolver.generic_def() { Some(GenericDefId::FunctionId(it)) => Either::Left(it), Some(GenericDefId::TypeAliasId(it)) => Either::Right(it), @@ -350,7 +325,7 @@ impl<'a> TyLoweringContext<'a> { // this dance is to make sure the data is in the right // place even if we encounter more opaque types while // lowering the bounds - let idx = opaque_type_data.borrow_mut().alloc(ImplTrait { + let idx = self.impl_trait_mode.opaque_type_data.alloc(ImplTrait { bounds: crate::make_single_type_binders(Vec::default()), }); // We don't want to lower the bounds inside the binders @@ -366,7 +341,7 @@ impl<'a> TyLoweringContext<'a> { .with_debruijn(DebruijnIndex::INNERMOST, |ctx| { ctx.lower_impl_trait(bounds, self.resolver.krate()) }); - opaque_type_data.borrow_mut()[idx] = actual_opaque_type_data; + self.impl_trait_mode.opaque_type_data[idx] = actual_opaque_type_data; let impl_trait_id = origin.either( |f| ImplTraitId::ReturnTypeImplTrait(f, idx), @@ -378,11 +353,13 @@ impl<'a> TyLoweringContext<'a> { let parameters = generics.bound_vars_subst(self.db, self.in_binders); TyKind::OpaqueType(opaque_ty_id, parameters).intern(Interner) } - ImplTraitLoweringState::Param(counter) => { - let idx = counter.get(); + ImplTraitLoweringMode::Param => { + let idx = self.impl_trait_mode.param_and_variable_counter; // Count the number of `impl Trait` things that appear within our bounds. // Since those have been emitted as implicit type args already. - counter.set(idx + self.count_impl_traits(type_ref_id) as u16); + self.impl_trait_mode.param_and_variable_counter = + idx + self.count_impl_traits(type_ref_id) as u16; + let db = self.db; let kind = self .generics() .expect("param impl trait lowering must be in a generic def") @@ -398,15 +375,17 @@ impl<'a> TyLoweringContext<'a> { }) .nth(idx as usize) .map_or(TyKind::Error, |id| { - TyKind::Placeholder(to_placeholder_idx(self.db, id.into())) + TyKind::Placeholder(to_placeholder_idx(db, id.into())) }); kind.intern(Interner) } - ImplTraitLoweringState::Variable(counter) => { - let idx = counter.get(); + ImplTraitLoweringMode::Variable => { + let idx = self.impl_trait_mode.param_and_variable_counter; // Count the number of `impl Trait` things that appear within our bounds. // Since t hose have been emitted as implicit type args already. - counter.set(idx + self.count_impl_traits(type_ref_id) as u16); + self.impl_trait_mode.param_and_variable_counter = + idx + self.count_impl_traits(type_ref_id) as u16; + let debruijn = self.in_binders; let kind = self .generics() .expect("variable impl trait lowering must be in a generic def") @@ -423,33 +402,31 @@ impl<'a> TyLoweringContext<'a> { }) .nth(idx as usize) .map_or(TyKind::Error, |id| { - TyKind::BoundVar(BoundVar { debruijn: self.in_binders, index: id }) + TyKind::BoundVar(BoundVar { debruijn, index: id }) }); kind.intern(Interner) } - ImplTraitLoweringState::Disallowed => { + ImplTraitLoweringMode::Disallowed => { // FIXME: report error TyKind::Error.intern(Interner) } } } TypeRef::Macro(macro_call) => { - let (mut expander, recursion_start) = { - match RefMut::filter_map(self.expander.borrow_mut(), Option::as_mut) { + let (expander, recursion_start) = { + match &mut self.expander { // There already is an expander here, this means we are already recursing - Ok(expander) => (expander, false), + Some(expander) => (expander, false), // No expander was created yet, so we are at the start of the expansion recursion // and therefore have to create an expander. - Err(expander) => ( - RefMut::map(expander, |it| { - it.insert(Expander::new( - self.db.upcast(), - macro_call.file_id, - self.resolver.module(), - )) - }), - true, - ), + None => { + let expander = self.expander.insert(Expander::new( + self.db.upcast(), + macro_call.file_id, + self.resolver.module(), + )); + (expander, true) + } } }; let ty = { @@ -473,11 +450,8 @@ impl<'a> TyLoweringContext<'a> { // FIXME: Report syntax errors in expansion here let type_ref = TypeRef::from_ast(&mut ctx, expanded.tree()); - drop(expander); - - // FIXME: That may be better served by mutating `self` then restoring, but this requires - // making it `&mut self`. - let inner_ctx = TyLoweringContext { + // Can't mutate `self`, must create a new instance, because of the lifetimes. + let mut inner_ctx = TyLoweringContext { db: self.db, resolver: self.resolver, generics: self.generics.clone(), @@ -486,30 +460,27 @@ impl<'a> TyLoweringContext<'a> { in_binders: self.in_binders, owner: self.owner, type_param_mode: self.type_param_mode, - impl_trait_mode: self.impl_trait_mode.take(), - expander: RefCell::new(self.expander.take()), - unsized_types: RefCell::new(self.unsized_types.take()), + impl_trait_mode: mem::take(&mut self.impl_trait_mode), + expander: self.expander.take(), + unsized_types: mem::take(&mut self.unsized_types), }; let ty = inner_ctx.lower_ty(type_ref); - self.impl_trait_mode.swap(&inner_ctx.impl_trait_mode); - *self.expander.borrow_mut() = inner_ctx.expander.into_inner(); - *self.unsized_types.borrow_mut() = inner_ctx.unsized_types.into_inner(); + self.impl_trait_mode = inner_ctx.impl_trait_mode; + self.expander = inner_ctx.expander; + self.unsized_types = inner_ctx.unsized_types; - self.expander.borrow_mut().as_mut().unwrap().exit(mark); + self.expander.as_mut().unwrap().exit(mark); Some(ty) } - _ => { - drop(expander); - None - } + _ => None, } }; // drop the expander, resetting it to pre-recursion state if recursion_start { - *self.expander.borrow_mut() = None; + self.expander = None; } ty.unwrap_or_else(|| TyKind::Error.intern(Interner)) } @@ -544,7 +515,7 @@ impl<'a> TyLoweringContext<'a> { } pub(crate) fn lower_ty_relative_path( - &self, + &mut self, ty: Ty, // We need the original resolution to lower `Self::AssocTy` correctly res: Option, @@ -565,7 +536,7 @@ impl<'a> TyLoweringContext<'a> { } pub(crate) fn lower_partly_resolved_path( - &self, + &mut self, resolution: TypeNs, resolved_segment: PathSegment<'_>, remaining_segments: PathSegments<'_>, @@ -706,7 +677,7 @@ impl<'a> TyLoweringContext<'a> { self.lower_ty_relative_path(ty, Some(resolution), remaining_segments) } - pub(crate) fn lower_path(&self, path: &Path) -> (Ty, Option) { + pub(crate) fn lower_path(&mut self, path: &Path) -> (Ty, Option) { // Resolve the path (in type namespace) if let Some(type_ref) = path.type_anchor() { let (ty, res) = self.lower_ty_ext(type_ref); @@ -736,7 +707,7 @@ impl<'a> TyLoweringContext<'a> { self.lower_partly_resolved_path(resolution, resolved_segment, remaining_segments, false) } - fn select_associated_type(&self, res: Option, segment: PathSegment<'_>) -> Ty { + fn select_associated_type(&mut self, res: Option, segment: PathSegment<'_>) -> Ty { let Some((generics, res)) = self.generics().zip(res) else { return TyKind::Error.intern(Interner); }; @@ -746,6 +717,8 @@ impl<'a> TyLoweringContext<'a> { res, Some(segment.name.clone()), move |name, t, associated_ty| { + let generics = self.generics().unwrap(); + if name != segment.name { return None; } @@ -797,7 +770,7 @@ impl<'a> TyLoweringContext<'a> { } fn lower_path_inner( - &self, + &mut self, segment: PathSegment<'_>, typeable: TyDefId, infer_args: bool, @@ -814,7 +787,7 @@ impl<'a> TyLoweringContext<'a> { /// Collect generic arguments from a path into a `Substs`. See also /// `create_substs_for_ast_path` and `def_to_ty` in rustc. pub(super) fn substs_from_path( - &self, + &mut self, path: &Path, // Note that we don't call `db.value_type(resolved)` here, // `ValueTyDefId` is just a convenient way to pass generics and @@ -855,7 +828,7 @@ impl<'a> TyLoweringContext<'a> { } pub(super) fn substs_from_path_segment( - &self, + &mut self, segment: PathSegment<'_>, def: Option, infer_args: bool, @@ -870,7 +843,7 @@ impl<'a> TyLoweringContext<'a> { } fn substs_from_args_and_bindings( - &self, + &mut self, args_and_bindings: Option<&GenericArgs>, def: Option, infer_args: bool, @@ -959,11 +932,11 @@ impl<'a> TyLoweringContext<'a> { self.db, id, arg, - &mut (), + self, self.types_map, - |_, type_ref| self.lower_ty(type_ref), - |_, const_ref, ty| self.lower_const(const_ref, ty), - |_, lifetime_ref| self.lower_lifetime(lifetime_ref), + |this, type_ref| this.lower_ty(type_ref), + |this, const_ref, ty| this.lower_const(const_ref, ty), + |this, lifetime_ref| this.lower_lifetime(lifetime_ref), ); substs.push(arg); } @@ -1016,7 +989,7 @@ impl<'a> TyLoweringContext<'a> { } pub(crate) fn lower_trait_ref_from_resolved_path( - &self, + &mut self, resolved: TraitId, segment: PathSegment<'_>, explicit_self_ty: Ty, @@ -1025,7 +998,7 @@ impl<'a> TyLoweringContext<'a> { TraitRef { trait_id: to_chalk_trait_id(resolved), substitution: substs } } - fn lower_trait_ref_from_path(&self, path: &Path, explicit_self_ty: Ty) -> Option { + fn lower_trait_ref_from_path(&mut self, path: &Path, explicit_self_ty: Ty) -> Option { let resolved = match self.resolver.resolve_path_in_type_ns_fully(self.db.upcast(), path)? { // FIXME(trait_alias): We need to handle trait alias here. TypeNs::TraitId(tr) => tr, @@ -1035,12 +1008,16 @@ impl<'a> TyLoweringContext<'a> { Some(self.lower_trait_ref_from_resolved_path(resolved, segment, explicit_self_ty)) } - fn lower_trait_ref(&self, trait_ref: &HirTraitRef, explicit_self_ty: Ty) -> Option { + fn lower_trait_ref( + &mut self, + trait_ref: &HirTraitRef, + explicit_self_ty: Ty, + ) -> Option { self.lower_trait_ref_from_path(&trait_ref.path, explicit_self_ty) } fn trait_ref_substs_from_path( - &self, + &mut self, segment: PathSegment<'_>, resolved: TraitId, explicit_self_ty: Ty, @@ -1049,11 +1026,11 @@ impl<'a> TyLoweringContext<'a> { } pub(crate) fn lower_where_predicate<'b>( - &'b self, + &'b mut self, where_predicate: &'b WherePredicate, &def: &GenericDefId, ignore_bindings: bool, - ) -> impl Iterator + 'b { + ) -> impl Iterator + use<'a, 'b> { match where_predicate { WherePredicate::ForLifetime { target, bound, .. } | WherePredicate::TypeBound { target, bound } => { @@ -1087,12 +1064,12 @@ impl<'a> TyLoweringContext<'a> { .into_iter() } - pub(crate) fn lower_type_bound( - &'a self, - bound: &'a TypeBound, + pub(crate) fn lower_type_bound<'b>( + &'b mut self, + bound: &'b TypeBound, self_ty: Ty, ignore_bindings: bool, - ) -> impl Iterator + 'a { + ) -> impl Iterator + use<'b, 'a> { let mut trait_ref = None; let clause = match bound { TypeBound::Path(path, TraitBoundModifier::None) => { @@ -1111,7 +1088,7 @@ impl<'a> TyLoweringContext<'a> { .lower_trait_ref_from_path(path, self_ty.clone()) .map(|trait_ref| trait_ref.hir_trait_id()); if trait_id == sized_trait { - self.unsized_types.borrow_mut().insert(self_ty); + self.unsized_types.insert(self_ty); } None } @@ -1131,17 +1108,18 @@ impl<'a> TyLoweringContext<'a> { }; clause.into_iter().chain( trait_ref - .into_iter() .filter(move |_| !ignore_bindings) - .flat_map(move |tr| self.assoc_type_bindings_from_type_bound(bound, tr)), + .map(move |tr| self.assoc_type_bindings_from_type_bound(bound, tr)) + .into_iter() + .flatten(), ) } - fn assoc_type_bindings_from_type_bound( - &'a self, - bound: &'a TypeBound, + fn assoc_type_bindings_from_type_bound<'b>( + &'b mut self, + bound: &'b TypeBound, trait_ref: TraitRef, - ) -> impl Iterator + 'a { + ) -> impl Iterator + use<'b, 'a> { let last_segment = match bound { TypeBound::Path(path, TraitBoundModifier::None) | TypeBound::ForLifetime(_, path) => { path.segments().last() @@ -1192,22 +1170,16 @@ impl<'a> TyLoweringContext<'a> { binding.type_ref.as_ref().map_or(0, |_| 1) + binding.bounds.len(), ); if let Some(type_ref) = binding.type_ref { - match (&self.types_map[type_ref], &self.impl_trait_mode) { - (TypeRef::ImplTrait(_), ImplTraitLoweringState::Disallowed) => (), - ( - _, - ImplTraitLoweringState::Disallowed | ImplTraitLoweringState::Opaque(_), - ) => { + match (&self.types_map[type_ref], self.impl_trait_mode.mode) { + (TypeRef::ImplTrait(_), ImplTraitLoweringMode::Disallowed) => (), + (_, ImplTraitLoweringMode::Disallowed | ImplTraitLoweringMode::Opaque) => { let ty = self.lower_ty(type_ref); let alias_eq = AliasEq { alias: AliasTy::Projection(projection_ty.clone()), ty }; predicates .push(crate::wrap_empty_binders(WhereClause::AliasEq(alias_eq))); } - ( - _, - ImplTraitLoweringState::Param(_) | ImplTraitLoweringState::Variable(_), - ) => { + (_, ImplTraitLoweringMode::Param | ImplTraitLoweringMode::Variable) => { // Find the generic index for the target of our `bound` let target_param_idx = self .resolver @@ -1244,14 +1216,14 @@ impl<'a> TyLoweringContext<'a> { self.owner, ) .with_type_param_mode(self.type_param_mode); - match &self.impl_trait_mode { - ImplTraitLoweringState::Param(_) => { + match self.impl_trait_mode.mode { + ImplTraitLoweringMode::Param => { ext.impl_trait_mode = - ImplTraitLoweringState::Param(Cell::new(counter)); + ImplTraitLoweringState::param(counter); } - ImplTraitLoweringState::Variable(_) => { + ImplTraitLoweringMode::Variable => { ext.impl_trait_mode = - ImplTraitLoweringState::Variable(Cell::new(counter)); + ImplTraitLoweringState::variable(counter); } _ => unreachable!(), } @@ -1278,7 +1250,7 @@ impl<'a> TyLoweringContext<'a> { }) } - fn lower_dyn_trait(&self, bounds: &[TypeBound]) -> Ty { + fn lower_dyn_trait(&mut self, bounds: &[TypeBound]) -> Ty { let self_ty = TyKind::BoundVar(BoundVar::new(DebruijnIndex::INNERMOST, 0)).intern(Interner); // INVARIANT: The principal trait bound, if present, must come first. Others may be in any // order but should be in the same order for the same set but possibly different order of @@ -1287,22 +1259,26 @@ impl<'a> TyLoweringContext<'a> { // These invariants are utilized by `TyExt::dyn_trait()` and chalk. let mut lifetime = None; let bounds = self.with_shifted_in(DebruijnIndex::ONE, |ctx| { - let mut bounds: Vec<_> = bounds - .iter() - .flat_map(|b| ctx.lower_type_bound(b, self_ty.clone(), false)) - .filter(|b| match b.skip_binders() { - WhereClause::Implemented(_) | WhereClause::AliasEq(_) => true, - WhereClause::LifetimeOutlives(_) => false, - WhereClause::TypeOutlives(t) => { - lifetime = Some(t.lifetime.clone()); - false + let mut lowered_bounds = Vec::new(); + for b in bounds { + ctx.lower_type_bound(b, self_ty.clone(), false).for_each(|b| { + let filter = match b.skip_binders() { + WhereClause::Implemented(_) | WhereClause::AliasEq(_) => true, + WhereClause::LifetimeOutlives(_) => false, + WhereClause::TypeOutlives(t) => { + lifetime = Some(t.lifetime.clone()); + false + } + }; + if filter { + lowered_bounds.push(b); } - }) - .collect(); + }); + } let mut multiple_regular_traits = false; let mut multiple_same_projection = false; - bounds.sort_unstable_by(|lhs, rhs| { + lowered_bounds.sort_unstable_by(|lhs, rhs| { use std::cmp::Ordering; match (lhs.skip_binders(), rhs.skip_binders()) { (WhereClause::Implemented(lhs), WhereClause::Implemented(rhs)) => { @@ -1344,13 +1320,13 @@ impl<'a> TyLoweringContext<'a> { return None; } - bounds.first().and_then(|b| b.trait_id())?; + lowered_bounds.first().and_then(|b| b.trait_id())?; // As multiple occurrences of the same auto traits *are* permitted, we deduplicate the // bounds. We shouldn't have repeated elements besides auto traits at this point. - bounds.dedup(); + lowered_bounds.dedup(); - Some(QuantifiedWhereClauses::from_iter(Interner, bounds)) + Some(QuantifiedWhereClauses::from_iter(Interner, lowered_bounds)) }); if let Some(bounds) = bounds { @@ -1376,16 +1352,16 @@ impl<'a> TyLoweringContext<'a> { } } - fn lower_impl_trait(&self, bounds: &[TypeBound], krate: CrateId) -> ImplTrait { + fn lower_impl_trait(&mut self, bounds: &[TypeBound], krate: CrateId) -> ImplTrait { cov_mark::hit!(lower_rpit); let self_ty = TyKind::BoundVar(BoundVar::new(DebruijnIndex::INNERMOST, 0)).intern(Interner); let predicates = self.with_shifted_in(DebruijnIndex::ONE, |ctx| { - let mut predicates: Vec<_> = bounds - .iter() - .flat_map(|b| ctx.lower_type_bound(b, self_ty.clone(), false)) - .collect(); + let mut predicates = Vec::new(); + for b in bounds { + predicates.extend(ctx.lower_type_bound(b, self_ty.clone(), false)); + } - if !ctx.unsized_types.borrow().contains(&self_ty) { + if !ctx.unsized_types.contains(&self_ty) { let sized_trait = ctx .db .lang_item(krate, LangItem::Sized) @@ -1562,7 +1538,7 @@ pub(crate) fn field_types_query( }; let generics = generics(db.upcast(), def); let mut res = ArenaMap::default(); - let ctx = TyLoweringContext::new(db, &resolver, var_data.types_map(), def.into()) + let mut ctx = TyLoweringContext::new(db, &resolver, var_data.types_map(), def.into()) .with_type_param_mode(ParamLoweringMode::Variable); for (field_id, field_data) in var_data.fields().iter() { res.insert(field_id, make_binders(db, &generics, ctx.lower_ty(field_data.type_ref))); @@ -1596,7 +1572,7 @@ pub(crate) fn generic_predicates_for_param_query( let generics = generics(db.upcast(), def); // we have to filter out all other predicates *first*, before attempting to lower them - let predicate = |pred: &_, def: &_, ctx: &TyLoweringContext<'_>| match pred { + let predicate = |pred: &_, def: &_, ctx: &mut TyLoweringContext<'_>| match pred { WherePredicate::ForLifetime { target, bound, .. } | WherePredicate::TypeBound { target, bound, .. } => { let invalid_target = match target { @@ -1642,16 +1618,19 @@ pub(crate) fn generic_predicates_for_param_query( let mut predicates = Vec::new(); for (params, def) in resolver.all_generic_params() { ctx.types_map = ¶ms.types_map; - predicates.extend( - params.where_predicates().filter(|pred| predicate(pred, def, &ctx)).flat_map(|pred| { - ctx.lower_where_predicate(pred, def, true).map(|p| make_binders(db, &generics, p)) - }), - ); + for pred in params.where_predicates() { + if predicate(pred, def, &mut ctx) { + predicates.extend( + ctx.lower_where_predicate(pred, def, true) + .map(|p| make_binders(db, &generics, p)), + ); + } + } } let subst = generics.bound_vars_subst(db, DebruijnIndex::INNERMOST); if !subst.is_empty(Interner) { - let explicitly_unsized_tys = ctx.unsized_types.into_inner(); + let explicitly_unsized_tys = ctx.unsized_types; if let Some(implicitly_sized_predicates) = implicitly_sized_clauses( db, param_id.parent, @@ -1731,7 +1710,7 @@ pub(crate) fn trait_environment_query( let subst = generics(db.upcast(), def).placeholder_subst(db); if !subst.is_empty(Interner) { - let explicitly_unsized_tys = ctx.unsized_types.into_inner(); + let explicitly_unsized_tys = ctx.unsized_types; if let Some(implicitly_sized_clauses) = implicitly_sized_clauses(db, def, &explicitly_unsized_tys, &subst, &resolver) { @@ -1801,16 +1780,19 @@ where let mut predicates = Vec::new(); for (params, def) in resolver.all_generic_params() { ctx.types_map = ¶ms.types_map; - predicates.extend(params.where_predicates().filter(|pred| filter(pred, def)).flat_map( - |pred| { - ctx.lower_where_predicate(pred, def, false).map(|p| make_binders(db, &generics, p)) - }, - )); + for pred in params.where_predicates() { + if filter(pred, def) { + predicates.extend( + ctx.lower_where_predicate(pred, def, false) + .map(|p| make_binders(db, &generics, p)), + ); + } + } } if generics.len() > 0 { let subst = generics.bound_vars_subst(db, DebruijnIndex::INNERMOST); - let explicitly_unsized_tys = ctx.unsized_types.into_inner(); + let explicitly_unsized_tys = ctx.unsized_types; if let Some(implicitly_sized_predicates) = implicitly_sized_clauses(db, def, &explicitly_unsized_tys, &subst, &resolver) { @@ -1906,7 +1888,8 @@ pub(crate) fn generic_defaults_query(db: &dyn HirDatabase, def: GenericDefId) -> let mut val = p.default.as_ref().map_or_else( || unknown_const_as_generic(db.const_param_ty(id)), |c| { - let c = ctx.lower_const(c, ctx.lower_ty(p.ty)); + let param_ty = ctx.lower_ty(p.ty); + let c = ctx.lower_const(c, param_ty); c.cast(Interner) }, ); @@ -1946,11 +1929,11 @@ pub(crate) fn generic_defaults_recover( fn fn_sig_for_fn(db: &dyn HirDatabase, def: FunctionId) -> PolyFnSig { let data = db.function_data(def); let resolver = def.resolver(db.upcast()); - let ctx_params = TyLoweringContext::new(db, &resolver, &data.types_map, def.into()) + let mut ctx_params = TyLoweringContext::new(db, &resolver, &data.types_map, def.into()) .with_impl_trait_mode(ImplTraitLoweringMode::Variable) .with_type_param_mode(ParamLoweringMode::Variable); let params = data.params.iter().map(|&tr| ctx_params.lower_ty(tr)); - let ctx_ret = TyLoweringContext::new(db, &resolver, &data.types_map, def.into()) + let mut ctx_ret = TyLoweringContext::new(db, &resolver, &data.types_map, def.into()) .with_impl_trait_mode(ImplTraitLoweringMode::Opaque) .with_type_param_mode(ParamLoweringMode::Variable); let ret = ctx_ret.lower_ty(data.ret_type); @@ -1982,7 +1965,7 @@ fn type_for_const(db: &dyn HirDatabase, def: ConstId) -> Binders { let data = db.const_data(def); let generics = generics(db.upcast(), def.into()); let resolver = def.resolver(db.upcast()); - let ctx = TyLoweringContext::new(db, &resolver, &data.types_map, def.into()) + let mut ctx = TyLoweringContext::new(db, &resolver, &data.types_map, def.into()) .with_type_param_mode(ParamLoweringMode::Variable); make_binders(db, &generics, ctx.lower_ty(data.type_ref)) @@ -1992,7 +1975,7 @@ fn type_for_const(db: &dyn HirDatabase, def: ConstId) -> Binders { fn type_for_static(db: &dyn HirDatabase, def: StaticId) -> Binders { let data = db.static_data(def); let resolver = def.resolver(db.upcast()); - let ctx = TyLoweringContext::new(db, &resolver, &data.types_map, def.into()); + let mut ctx = TyLoweringContext::new(db, &resolver, &data.types_map, def.into()); Binders::empty(Interner, ctx.lower_ty(data.type_ref)) } @@ -2001,7 +1984,7 @@ fn fn_sig_for_struct_constructor(db: &dyn HirDatabase, def: StructId) -> PolyFnS let struct_data = db.struct_data(def); let fields = struct_data.variant_data.fields(); let resolver = def.resolver(db.upcast()); - let ctx = TyLoweringContext::new( + let mut ctx = TyLoweringContext::new( db, &resolver, struct_data.variant_data.types_map(), @@ -2038,7 +2021,7 @@ fn fn_sig_for_enum_variant_constructor(db: &dyn HirDatabase, def: EnumVariantId) let var_data = db.enum_variant_data(def); let fields = var_data.variant_data.fields(); let resolver = def.resolver(db.upcast()); - let ctx = TyLoweringContext::new( + let mut ctx = TyLoweringContext::new( db, &resolver, var_data.variant_data.types_map(), @@ -2087,7 +2070,7 @@ fn type_for_type_alias(db: &dyn HirDatabase, t: TypeAliasId) -> Binders { let generics = generics(db.upcast(), t.into()); let resolver = t.resolver(db.upcast()); let type_alias_data = db.type_alias_data(t); - let ctx = TyLoweringContext::new(db, &resolver, &type_alias_data.types_map, t.into()) + let mut ctx = TyLoweringContext::new(db, &resolver, &type_alias_data.types_map, t.into()) .with_impl_trait_mode(ImplTraitLoweringMode::Opaque) .with_type_param_mode(ParamLoweringMode::Variable); let inner = if type_alias_data.is_extern { @@ -2169,7 +2152,7 @@ pub(crate) fn impl_self_ty_query(db: &dyn HirDatabase, impl_id: ImplId) -> Binde let impl_data = db.impl_data(impl_id); let resolver = impl_id.resolver(db.upcast()); let generics = generics(db.upcast(), impl_id.into()); - let ctx = TyLoweringContext::new(db, &resolver, &impl_data.types_map, impl_id.into()) + let mut ctx = TyLoweringContext::new(db, &resolver, &impl_data.types_map, impl_id.into()) .with_type_param_mode(ParamLoweringMode::Variable); make_binders(db, &generics, ctx.lower_ty(impl_data.self_ty)) } @@ -2179,7 +2162,8 @@ pub(crate) fn const_param_ty_query(db: &dyn HirDatabase, def: ConstParamId) -> T let parent_data = db.generic_params(def.parent()); let data = &parent_data[def.local_id()]; let resolver = def.parent().resolver(db.upcast()); - let ctx = TyLoweringContext::new(db, &resolver, &parent_data.types_map, def.parent().into()); + let mut ctx = + TyLoweringContext::new(db, &resolver, &parent_data.types_map, def.parent().into()); match data { TypeOrConstParamData::TypeParamData(_) => { never!(); @@ -2201,7 +2185,7 @@ pub(crate) fn impl_self_ty_recover( pub(crate) fn impl_trait_query(db: &dyn HirDatabase, impl_id: ImplId) -> Option> { let impl_data = db.impl_data(impl_id); let resolver = impl_id.resolver(db.upcast()); - let ctx = TyLoweringContext::new(db, &resolver, &impl_data.types_map, impl_id.into()) + let mut ctx = TyLoweringContext::new(db, &resolver, &impl_data.types_map, impl_id.into()) .with_type_param_mode(ParamLoweringMode::Variable); let (self_ty, binders) = db.impl_self_ty(impl_id).into_value_and_skipped_binders(); let target_trait = impl_data.target_trait.as_ref()?; @@ -2215,17 +2199,13 @@ pub(crate) fn return_type_impl_traits( // FIXME unify with fn_sig_for_fn instead of doing lowering twice, maybe let data = db.function_data(def); let resolver = def.resolver(db.upcast()); - let ctx_ret = TyLoweringContext::new(db, &resolver, &data.types_map, def.into()) + let mut ctx_ret = TyLoweringContext::new(db, &resolver, &data.types_map, def.into()) .with_impl_trait_mode(ImplTraitLoweringMode::Opaque) .with_type_param_mode(ParamLoweringMode::Variable); let _ret = ctx_ret.lower_ty(data.ret_type); let generics = generics(db.upcast(), def.into()); - let return_type_impl_traits = ImplTraits { - impl_traits: match ctx_ret.impl_trait_mode { - ImplTraitLoweringState::Opaque(x) => x.into_inner(), - _ => unreachable!(), - }, - }; + let return_type_impl_traits = + ImplTraits { impl_traits: ctx_ret.impl_trait_mode.opaque_type_data }; if return_type_impl_traits.impl_traits.is_empty() { None } else { @@ -2239,18 +2219,13 @@ pub(crate) fn type_alias_impl_traits( ) -> Option>> { let data = db.type_alias_data(def); let resolver = def.resolver(db.upcast()); - let ctx = TyLoweringContext::new(db, &resolver, &data.types_map, def.into()) + let mut ctx = TyLoweringContext::new(db, &resolver, &data.types_map, def.into()) .with_impl_trait_mode(ImplTraitLoweringMode::Opaque) .with_type_param_mode(ParamLoweringMode::Variable); if let Some(type_ref) = data.type_ref { let _ty = ctx.lower_ty(type_ref); } - let type_alias_impl_traits = ImplTraits { - impl_traits: match ctx.impl_trait_mode { - ImplTraitLoweringState::Opaque(x) => x.into_inner(), - _ => unreachable!(), - }, - }; + let type_alias_impl_traits = ImplTraits { impl_traits: ctx.impl_trait_mode.opaque_type_data }; if type_alias_impl_traits.impl_traits.is_empty() { None } else {