internal: simplify and refactor write_where_clause

This commit is contained in:
roife 2024-06-08 03:10:27 +08:00
parent ff83ec6835
commit af1291bbf3
3 changed files with 110 additions and 101 deletions

1
Cargo.lock generated
View file

@ -512,6 +512,7 @@ dependencies = [
"hir-def",
"hir-expand",
"hir-ty",
"intern",
"itertools",
"once_cell",
"rustc-hash",

View file

@ -27,6 +27,7 @@ cfg.workspace = true
hir-def.workspace = true
hir-expand.workspace = true
hir-ty.workspace = true
intern.workspace = true
stdx.workspace = true
syntax.workspace = true
tt.workspace = true

View file

@ -3,7 +3,8 @@ use either::Either;
use hir_def::{
data::adt::{StructKind, VariantData},
generics::{
TypeOrConstParamData, TypeParamProvenance, WherePredicate, WherePredicateTypeTarget,
GenericParams, TypeOrConstParamData, TypeParamProvenance, WherePredicate,
WherePredicateTypeTarget,
},
lang_item::LangItem,
type_ref::{TypeBound, TypeRef},
@ -16,6 +17,8 @@ use hir_ty::{
},
AliasEq, AliasTy, Interner, ProjectionTyExt, TraitRefExt, TyKind, WhereClause,
};
use intern::Interned;
use itertools::Itertools;
use crate::{
Adt, AsAssocItem, AssocItem, AssocItemContainer, Const, ConstParam, Enum, ExternCrateDecl,
@ -30,10 +33,14 @@ impl HirDisplay for Function {
let data = db.function_data(self.id);
let container = self.as_assoc_item(db).map(|it| it.container(db));
let mut module = self.module(db);
if let Some(AssocItemContainer::Impl(_)) = container {
match container {
Some(AssocItemContainer::Impl(_)) => {
// Block-local impls are "hoisted" to the nearest (non-block) module.
module = module.nearest_non_block_module(db);
}
_ => {}
}
let module_id = module.id;
write_visibility(module_id, self.visibility(db), f)?;
if data.has_default_kw() {
@ -555,103 +562,98 @@ fn write_where_clause(
) -> Result<bool, HirDisplayError> {
let params = f.db.generic_params(def);
// unnamed type targets are displayed inline with the argument itself, e.g. `f: impl Y`.
let is_unnamed_type_target = |target: &WherePredicateTypeTarget| match target {
WherePredicateTypeTarget::TypeRef(_) => false,
WherePredicateTypeTarget::TypeOrConstParam(id) => {
params.type_or_consts[*id].name().is_none()
}
let no_displayable_pred = |params: &Interned<GenericParams>| {
params.where_predicates.iter().all(|pred| {
matches!(
pred,
WherePredicate::TypeBound { target: WherePredicateTypeTarget::TypeOrConstParam(id), .. }
if params.type_or_consts[*id].name().is_none()
)
})
};
let has_displayable_predicate = params
.where_predicates
.iter()
.any(|pred| {
!matches!(pred, WherePredicate::TypeBound { target, .. } if is_unnamed_type_target(target))
});
if !has_displayable_predicate {
if no_displayable_pred(&params) {
return Ok(false);
}
let write_target = |target: &WherePredicateTypeTarget, f: &mut HirFormatter<'_>| match target {
WherePredicateTypeTarget::TypeRef(ty) => ty.hir_fmt(f),
WherePredicateTypeTarget::TypeOrConstParam(id) => {
match &params.type_or_consts[*id].name() {
Some(name) => write!(f, "{}", name.display(f.db.upcast())),
None => f.write_str("{unnamed}"),
}
}
};
f.write_str("\nwhere")?;
for (pred_idx, pred) in params.where_predicates.iter().enumerate() {
let prev_pred =
if pred_idx == 0 { None } else { Some(&params.where_predicates[pred_idx - 1]) };
let new_predicate = |f: &mut HirFormatter<'_>| {
f.write_str(if pred_idx == 0 { "\n " } else { ",\n " })
};
match pred {
WherePredicate::TypeBound { target, .. } if is_unnamed_type_target(target) => {}
WherePredicate::TypeBound { target, bound } => {
if matches!(prev_pred, Some(WherePredicate::TypeBound { target: target_, .. }) if target_ == target)
{
f.write_str(" + ")?;
} else {
new_predicate(f)?;
write_target(target, f)?;
f.write_str(": ")?;
}
bound.hir_fmt(f)?;
}
WherePredicate::Lifetime { target, bound } => {
if matches!(prev_pred, Some(WherePredicate::Lifetime { target: target_, .. }) if target_ == target)
{
write!(f, " + {}", bound.name.display(f.db.upcast()))?;
} else {
new_predicate(f)?;
write!(
f,
"{}: {}",
target.name.display(f.db.upcast()),
bound.name.display(f.db.upcast())
)?;
}
}
WherePredicate::ForLifetime { lifetimes, target, bound } => {
if matches!(
prev_pred,
Some(WherePredicate::ForLifetime { lifetimes: lifetimes_, target: target_, .. })
if lifetimes_ == lifetimes && target_ == target,
) {
f.write_str(" + ")?;
} else {
new_predicate(f)?;
f.write_str("for<")?;
for (idx, lifetime) in lifetimes.iter().enumerate() {
if idx != 0 {
f.write_str(", ")?;
}
write!(f, "{}", lifetime.display(f.db.upcast()))?;
}
f.write_str("> ")?;
write_target(target, f)?;
f.write_str(": ")?;
}
bound.hir_fmt(f)?;
}
}
}
// End of final predicate. There must be at least one predicate here.
f.write_char(',')?;
write_where_predicates(&params, f)?;
Ok(true)
}
fn write_where_predicates(
params: &Interned<GenericParams>,
f: &mut HirFormatter<'_>,
) -> Result<(), HirDisplayError> {
use WherePredicate::*;
// unnamed type targets are displayed inline with the argument itself, e.g. `f: impl Y`.
let is_unnamed_type_target =
|params: &Interned<GenericParams>, target: &WherePredicateTypeTarget| {
matches!(target,
WherePredicateTypeTarget::TypeOrConstParam(id) if params.type_or_consts[*id].name().is_none()
)
};
let write_target = |target: &WherePredicateTypeTarget, f: &mut HirFormatter<'_>| match target {
WherePredicateTypeTarget::TypeRef(ty) => ty.hir_fmt(f),
WherePredicateTypeTarget::TypeOrConstParam(id) => match params.type_or_consts[*id].name() {
Some(name) => write!(f, "{}", name.display(f.db.upcast())),
None => f.write_str("{unnamed}"),
},
};
let check_same_target = |pred1: &WherePredicate, pred2: &WherePredicate| match (pred1, pred2) {
(TypeBound { target: t1, .. }, TypeBound { target: t2, .. }) => t1 == t2,
(Lifetime { target: t1, .. }, Lifetime { target: t2, .. }) => t1 == t2,
(
ForLifetime { lifetimes: l1, target: t1, .. },
ForLifetime { lifetimes: l2, target: t2, .. },
) => l1 == l2 && t1 == t2,
_ => false,
};
let mut iter = params.where_predicates.iter().peekable();
while let Some(pred) = iter.next() {
if matches!(pred, TypeBound { target, .. } if is_unnamed_type_target(params, target)) {
continue;
}
f.write_str("\n ")?;
match pred {
TypeBound { target, bound } => {
write_target(target, f)?;
f.write_str(": ")?;
bound.hir_fmt(f)?;
}
Lifetime { target, bound } => {
let target = target.name.display(f.db.upcast());
let bound = bound.name.display(f.db.upcast());
write!(f, "{target}: {bound}")?;
}
ForLifetime { lifetimes, target, bound } => {
let lifetimes = lifetimes.iter().map(|it| it.display(f.db.upcast())).join(", ");
write!(f, "for<{lifetimes}> ")?;
write_target(target, f)?;
f.write_str(": ")?;
bound.hir_fmt(f)?;
}
}
while let Some(nxt) = iter.next_if(|nxt| check_same_target(pred, nxt)) {
f.write_str(" + ")?;
match nxt {
TypeBound { bound, .. } | ForLifetime { bound, .. } => bound.hir_fmt(f)?,
Lifetime { bound, .. } => write!(f, "{}", bound.name.display(f.db.upcast()))?,
}
}
f.write_str(",")?;
}
Ok(())
}
impl HirDisplay for Const {
fn hir_fmt(&self, f: &mut HirFormatter<'_>) -> Result<(), HirDisplayError> {
let db = f.db;
@ -689,17 +691,8 @@ impl HirDisplay for Static {
impl HirDisplay for Trait {
fn hir_fmt(&self, f: &mut HirFormatter<'_>) -> Result<(), HirDisplayError> {
write_visibility(self.module(f.db).id, self.visibility(f.db), f)?;
let data = f.db.trait_data(self.id);
if data.is_unsafe {
f.write_str("unsafe ")?;
}
if data.is_auto {
f.write_str("auto ")?;
}
write!(f, "trait {}", data.name.display(f.db.upcast()))?;
write_trait_header(self, f)?;
let def_id = GenericDefId::TraitId(self.id);
write_generic_params(def_id, f)?;
let has_where_clause = write_where_clause(def_id, f)?;
if let Some(limit) = f.entity_limit {
@ -735,6 +728,20 @@ impl HirDisplay for Trait {
}
}
fn write_trait_header(trait_: &Trait, f: &mut HirFormatter<'_>) -> Result<(), HirDisplayError> {
write_visibility(trait_.module(f.db).id, trait_.visibility(f.db), f)?;
let data = f.db.trait_data(trait_.id);
if data.is_unsafe {
f.write_str("unsafe ")?;
}
if data.is_auto {
f.write_str("auto ")?;
}
write!(f, "trait {}", data.name.display(f.db.upcast()))?;
write_generic_params(GenericDefId::TraitId(trait_.id), f)?;
Ok(())
}
impl HirDisplay for TraitAlias {
fn hir_fmt(&self, f: &mut HirFormatter<'_>) -> Result<(), HirDisplayError> {
write_visibility(self.module(f.db).id, self.visibility(f.db), f)?;