From e36a31bd9534254df49714bd9fa3ca9dcb31f805 Mon Sep 17 00:00:00 2001 From: Shoyu Vanilla Date: Thu, 22 Feb 2024 00:44:48 +0900 Subject: [PATCH] Port closure kind deduction logic from rustc --- crates/hir-ty/src/infer/closure.rs | 96 +++++++++++++++++++++++++++--- crates/hir-ty/src/infer/unify.rs | 73 ++++++++++++++++++++++- crates/hir-ty/src/utils.rs | 46 ++++++++++++++ 3 files changed, 205 insertions(+), 10 deletions(-) diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs index 000248c2ce..62740f354c 100644 --- a/crates/hir-ty/src/infer/closure.rs +++ b/crates/hir-ty/src/infer/closure.rs @@ -5,7 +5,7 @@ use std::{cmp, convert::Infallible, mem}; use chalk_ir::{ cast::Cast, fold::{FallibleTypeFolder, TypeFoldable}, - AliasEq, AliasTy, BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind, WhereClause, + BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind, }; use either::Either; use hir_def::{ @@ -22,13 +22,14 @@ use stdx::never; use crate::{ db::{HirDatabase, InternedClosure}, - from_placeholder_idx, make_binders, + from_chalk_trait_id, from_placeholder_idx, make_binders, mir::{BorrowKind, MirSpan, MutBorrowKind, ProjectionElem}, static_lifetime, to_chalk_trait_id, traits::FnTrait, - utils::{self, generics, Generics}, - Adjust, Adjustment, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy, FnAbi, FnPointer, - FnSig, Interner, Substitution, Ty, TyExt, + utils::{self, elaborate_clause_supertraits, generics, Generics}, + Adjust, Adjustment, AliasEq, AliasTy, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy, + DynTyExt, FnAbi, FnPointer, FnSig, Interner, OpaqueTy, ProjectionTyExt, Substitution, Ty, + TyExt, WhereClause, }; use super::{Expectation, InferenceContext}; @@ -47,6 +48,15 @@ impl InferenceContext<'_> { None => return, }; + if let TyKind::Closure(closure_id, _) = closure_ty.kind(Interner) { + if let Some(closure_kind) = self.deduce_closure_kind_from_expectations(&expected_ty) { + self.result + .closure_info + .entry(*closure_id) + .or_insert_with(|| (Vec::new(), closure_kind)); + } + } + // Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here. let _ = self.coerce(Some(closure_expr), closure_ty, &expected_ty); @@ -65,6 +75,60 @@ impl InferenceContext<'_> { } } + // Closure kind deductions are mostly from `rustc_hir_typeck/src/closure.rs`. + // Might need to port closure sig deductions too. + fn deduce_closure_kind_from_expectations(&mut self, expected_ty: &Ty) -> Option { + match expected_ty.kind(Interner) { + TyKind::Alias(AliasTy::Opaque(OpaqueTy { .. })) | TyKind::OpaqueType(..) => { + let clauses = expected_ty + .impl_trait_bounds(self.db) + .into_iter() + .flatten() + .map(|b| b.into_value_and_skipped_binders().0); + self.deduce_closure_kind_from_predicate_clauses(clauses) + } + TyKind::Dyn(dyn_ty) => dyn_ty.principal().and_then(|trait_ref| { + self.fn_trait_kind_from_trait_id(from_chalk_trait_id(trait_ref.trait_id)) + }), + TyKind::InferenceVar(ty, chalk_ir::TyVariableKind::General) => { + let clauses = self.clauses_for_self_ty(*ty); + self.deduce_closure_kind_from_predicate_clauses(clauses.into_iter()) + } + TyKind::Function(_) => Some(FnTrait::Fn), + _ => None, + } + } + + fn deduce_closure_kind_from_predicate_clauses( + &self, + clauses: impl DoubleEndedIterator, + ) -> Option { + let mut expected_kind = None; + + for clause in elaborate_clause_supertraits(self.db, clauses.rev()) { + let trait_id = match clause { + WhereClause::AliasEq(AliasEq { + alias: AliasTy::Projection(projection), .. + }) => Some(projection.trait_(self.db)), + WhereClause::Implemented(trait_ref) => { + Some(from_chalk_trait_id(trait_ref.trait_id)) + } + _ => None, + }; + if let Some(closure_kind) = + trait_id.and_then(|trait_id| self.fn_trait_kind_from_trait_id(trait_id)) + { + // `FnX`'s variants order is opposite from rustc, so use `cmp::max` instead of `cmp::min` + expected_kind = Some( + expected_kind + .map_or_else(|| closure_kind, |current| cmp::max(current, closure_kind)), + ); + } + } + + expected_kind + } + fn deduce_sig_from_dyn_ty(&self, dyn_ty: &DynTy) -> Option { // Search for a predicate like `<$self as FnX>::Output == Ret` @@ -111,6 +175,18 @@ impl InferenceContext<'_> { None } + + fn fn_trait_kind_from_trait_id(&self, trait_id: hir_def::TraitId) -> Option { + utils::fn_traits(self.db.upcast(), self.owner.module(self.db.upcast()).krate()) + .enumerate() + .find_map(|(i, t)| (t == trait_id).then_some(i)) + .map(|i| match i { + 0 => FnTrait::Fn, + 1 => FnTrait::FnMut, + 2 => FnTrait::FnOnce, + _ => unreachable!(), + }) + } } // The below functions handle capture and closure kind (Fn, FnMut, ..) @@ -962,8 +1038,14 @@ impl InferenceContext<'_> { } } self.restrict_precision_for_unsafe(); - // closure_kind should be done before adjust_for_move_closure - let closure_kind = self.closure_kind(); + // `closure_kind` should be done before adjust_for_move_closure + // If there exists pre-deduced kind of a closure, use it instead of one determined by capture, as rustc does. + // rustc also does diagnostics here if the latter is not a subtype of the former. + let closure_kind = self + .result + .closure_info + .get(&closure) + .map_or_else(|| self.closure_kind(), |info| info.1); match capture_by { CaptureBy::Value => self.adjust_for_move_closure(), CaptureBy::Ref => (), diff --git a/crates/hir-ty/src/infer/unify.rs b/crates/hir-ty/src/infer/unify.rs index 709760b64f..d24e938a54 100644 --- a/crates/hir-ty/src/infer/unify.rs +++ b/crates/hir-ty/src/infer/unify.rs @@ -10,15 +10,16 @@ use chalk_solve::infer::ParameterEnaVariableExt; use either::Either; use ena::unify::UnifyKey; use hir_expand::name; +use smallvec::SmallVec; use triomphe::Arc; use super::{InferOk, InferResult, InferenceContext, TypeError}; use crate::{ consteval::unknown_const, db::HirDatabase, fold_tys_and_consts, static_lifetime, to_chalk_trait_id, traits::FnTrait, AliasEq, AliasTy, BoundVar, Canonical, Const, ConstValue, - DebruijnIndex, GenericArg, GenericArgData, Goal, Guidance, InEnvironment, InferenceVar, - Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution, Substitution, - TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind, + DebruijnIndex, DomainGoal, GenericArg, GenericArgData, Goal, GoalData, Guidance, InEnvironment, + InferenceVar, Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution, + Substitution, TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind, WhereClause, }; impl InferenceContext<'_> { @@ -31,6 +32,72 @@ impl InferenceContext<'_> { { self.table.canonicalize(t) } + + pub(super) fn clauses_for_self_ty( + &mut self, + self_ty: InferenceVar, + ) -> SmallVec<[WhereClause; 4]> { + self.table.resolve_obligations_as_possible(); + + let root = self.table.var_unification_table.inference_var_root(self_ty); + let pending_obligations = mem::take(&mut self.table.pending_obligations); + let obligations = pending_obligations + .iter() + .filter_map(|obligation| match obligation.value.value.goal.data(Interner) { + GoalData::DomainGoal(DomainGoal::Holds( + clause @ WhereClause::AliasEq(AliasEq { + alias: AliasTy::Projection(projection), + .. + }), + )) => { + let projection_self = projection.self_type_parameter(self.db); + let uncanonical = chalk_ir::Substitute::apply( + &obligation.free_vars, + projection_self, + Interner, + ); + if matches!( + self.resolve_ty_shallow(&uncanonical).kind(Interner), + TyKind::InferenceVar(iv, TyVariableKind::General) if *iv == root, + ) { + Some(chalk_ir::Substitute::apply( + &obligation.free_vars, + clause.clone(), + Interner, + )) + } else { + None + } + } + GoalData::DomainGoal(DomainGoal::Holds( + clause @ WhereClause::Implemented(trait_ref), + )) => { + let trait_ref_self = trait_ref.self_type_parameter(Interner); + let uncanonical = chalk_ir::Substitute::apply( + &obligation.free_vars, + trait_ref_self, + Interner, + ); + if matches!( + self.resolve_ty_shallow(&uncanonical).kind(Interner), + TyKind::InferenceVar(iv, TyVariableKind::General) if *iv == root, + ) { + Some(chalk_ir::Substitute::apply( + &obligation.free_vars, + clause.clone(), + Interner, + )) + } else { + None + } + } + _ => None, + }) + .collect(); + self.table.pending_obligations = pending_obligations; + + obligations + } } #[derive(Debug, Clone)] diff --git a/crates/hir-ty/src/utils.rs b/crates/hir-ty/src/utils.rs index c150314138..8bd57820d2 100644 --- a/crates/hir-ty/src/utils.rs +++ b/crates/hir-ty/src/utils.rs @@ -112,6 +112,52 @@ impl Iterator for SuperTraits<'_> { } } +pub(super) fn elaborate_clause_supertraits( + db: &dyn HirDatabase, + clauses: impl Iterator, +) -> ClauseElaborator<'_> { + let mut elaborator = ClauseElaborator { db, stack: Vec::new(), seen: FxHashSet::default() }; + elaborator.extend_deduped(clauses); + + elaborator +} + +pub(super) struct ClauseElaborator<'a> { + db: &'a dyn HirDatabase, + stack: Vec, + seen: FxHashSet, +} + +impl<'a> ClauseElaborator<'a> { + fn extend_deduped(&mut self, clauses: impl IntoIterator) { + self.stack.extend(clauses.into_iter().filter(|c| self.seen.insert(c.clone()))) + } + + fn elaborate_supertrait(&mut self, clause: &WhereClause) { + if let WhereClause::Implemented(trait_ref) = clause { + direct_super_trait_refs(self.db, trait_ref, |t| { + let clause = WhereClause::Implemented(t); + if self.seen.insert(clause.clone()) { + self.stack.push(clause); + } + }); + } + } +} + +impl Iterator for ClauseElaborator<'_> { + type Item = WhereClause; + + fn next(&mut self) -> Option { + if let Some(next) = self.stack.pop() { + self.elaborate_supertrait(&next); + Some(next) + } else { + None + } + } +} + fn direct_super_traits(db: &dyn DefDatabase, trait_: TraitId, cb: impl FnMut(TraitId)) { let resolver = trait_.resolver(db); let generic_params = db.generic_params(trait_.into());