Support the new CoercePointee derive

This commit is contained in:
Chayim Refael Friedman 2025-01-02 19:05:07 +02:00
parent 061d257e90
commit ce323627c5
8 changed files with 767 additions and 63 deletions

View file

@ -2,7 +2,7 @@
use expect_test::expect; use expect_test::expect;
use crate::macro_expansion_tests::check; use crate::macro_expansion_tests::{check, check_errors};
#[test] #[test]
fn test_copy_expand_simple() { fn test_copy_expand_simple() {
@ -654,3 +654,95 @@ impl < > $crate::default::Default for Bar< > where {
}"##]], }"##]],
); );
} }
#[test]
fn coerce_pointee_expansion() {
check(
r#"
//- minicore: coerce_pointee
use core::marker::CoercePointee;
pub trait Trait<T: ?Sized> {}
#[derive(CoercePointee)]
#[repr(transparent)]
pub struct Foo<'a, T: ?Sized + Trait<U>, #[pointee] U: ?Sized, const N: u32>(T)
where
U: Trait<U> + ToString;"#,
expect![[r#"
use core::marker::CoercePointee;
pub trait Trait<T: ?Sized> {}
#[derive(CoercePointee)]
#[repr(transparent)]
pub struct Foo<'a, T: ?Sized + Trait<U>, #[pointee] U: ?Sized, const N: u32>(T)
where
U: Trait<U> + ToString;
impl <T, U, const N: u32, __S> $crate::ops::DispatchFromDyn<Foo<'a, T, __S, N>> for Foo<T, U, N, > where U: Trait<U> +ToString, T: Trait<__S>, __S: ?Sized, __S: Trait<__S> +ToString, U: ::core::marker::Unsize<__S>, T:?Sized+Trait<U>, U:?Sized, {}
impl <T, U, const N: u32, __S> $crate::ops::CoerceUnsized<Foo<'a, T, __S, N>> for Foo<T, U, N, > where U: Trait<U> +ToString, T: Trait<__S>, __S: ?Sized, __S: Trait<__S> +ToString, U: ::core::marker::Unsize<__S>, T:?Sized+Trait<U>, U:?Sized, {}"#]],
);
}
#[test]
fn coerce_pointee_errors() {
check_errors(
r#"
//- minicore: coerce_pointee
use core::marker::CoercePointee;
#[derive(CoercePointee)]
enum Enum {}
#[derive(CoercePointee)]
struct Struct1;
#[derive(CoercePointee)]
struct Struct2();
#[derive(CoercePointee)]
struct Struct3 {}
#[derive(CoercePointee)]
struct Struct4<T: ?Sized>(T);
#[derive(CoercePointee)]
#[repr(transparent)]
struct Struct5(i32);
#[derive(CoercePointee)]
#[repr(transparent)]
struct Struct6<#[pointee] T: ?Sized, #[pointee] U: ?Sized>(T, U);
#[derive(CoercePointee)]
#[repr(transparent)]
struct Struct7<T: ?Sized, U: ?Sized>(T, U);
#[derive(CoercePointee)]
#[repr(transparent)]
struct Struct8<#[pointee] T, U: ?Sized>(T);
#[derive(CoercePointee)]
#[repr(transparent)]
struct Struct9<T>(T);
#[derive(CoercePointee)]
#[repr(transparent)]
struct Struct9<#[pointee] T, U>(T) where T: ?Sized;
"#,
expect![[r#"
35..72: `CoercePointee` can only be derived on `struct`s
74..114: `CoercePointee` can only be derived on `struct`s with at least one field
116..158: `CoercePointee` can only be derived on `struct`s with at least one field
160..202: `CoercePointee` can only be derived on `struct`s with at least one field
204..258: `CoercePointee` can only be derived on `struct`s with `#[repr(transparent)]`
260..326: `CoercePointee` can only be derived on `struct`s that are generic over at least one type
328..439: only one type parameter can be marked as `#[pointee]` when deriving `CoercePointee` traits
441..530: exactly one generic type parameter must be marked as `#[pointee]` to derive `CoercePointee` traits
532..621: `derive(CoercePointee)` requires `T` to be marked `?Sized`
623..690: `derive(CoercePointee)` requires `T` to be marked `?Sized`"#]],
);
}

View file

@ -16,14 +16,16 @@ mod proc_macros;
use std::{iter, ops::Range, sync}; use std::{iter, ops::Range, sync};
use base_db::SourceDatabase;
use expect_test::Expect; use expect_test::Expect;
use hir_expand::{ use hir_expand::{
db::ExpandDatabase, db::ExpandDatabase,
proc_macro::{ProcMacro, ProcMacroExpander, ProcMacroExpansionError, ProcMacroKind}, proc_macro::{ProcMacro, ProcMacroExpander, ProcMacroExpansionError, ProcMacroKind},
span_map::SpanMapRef, span_map::SpanMapRef,
InFile, MacroFileId, MacroFileIdExt, InFile, MacroCallKind, MacroFileId, MacroFileIdExt,
}; };
use intern::Symbol; use intern::Symbol;
use itertools::Itertools;
use span::{Edition, Span}; use span::{Edition, Span};
use stdx::{format_to, format_to_acc}; use stdx::{format_to, format_to_acc};
use syntax::{ use syntax::{
@ -44,6 +46,36 @@ use crate::{
AdtId, AsMacroCall, Lookup, ModuleDefId, AdtId, AsMacroCall, Lookup, ModuleDefId,
}; };
#[track_caller]
fn check_errors(ra_fixture: &str, expect: Expect) {
let db = TestDB::with_files(ra_fixture);
let krate = db.fetch_test_crate();
let def_map = db.crate_def_map(krate);
let errors = def_map
.modules()
.flat_map(|module| module.1.scope.all_macro_calls())
.filter_map(|macro_call| {
let errors = db.parse_macro_expansion_error(macro_call)?;
let errors = errors.err.as_ref()?.render_to_string(&db);
let macro_loc = db.lookup_intern_macro_call(macro_call);
let ast_id = match macro_loc.kind {
MacroCallKind::FnLike { ast_id, .. } => ast_id.map(|it| it.erase()),
MacroCallKind::Derive { ast_id, .. } => ast_id.map(|it| it.erase()),
MacroCallKind::Attr { ast_id, .. } => ast_id.map(|it| it.erase()),
};
let ast = db
.parse(ast_id.file_id.file_id().expect("macros inside macros are not supported"))
.syntax_node();
let ast_id_map = db.ast_id_map(ast_id.file_id);
let node = ast_id_map.get_erased(ast_id.value).to_node(&ast);
Some((node.text_range(), errors))
})
.sorted_unstable_by_key(|(range, _)| range.start())
.format_with("\n", |(range, err), format| format(&format_args!("{range:?}: {err}")))
.to_string();
expect.assert_eq(&errors);
}
#[track_caller] #[track_caller]
fn check(ra_fixture: &str, mut expect: Expect) { fn check(ra_fixture: &str, mut expect: Expect) {
let extra_proc_macros = vec![( let extra_proc_macros = vec![(
@ -245,7 +277,9 @@ fn pretty_print_macro_expansion(
let mut res = String::new(); let mut res = String::new();
let mut prev_kind = EOF; let mut prev_kind = EOF;
let mut indent_level = 0; let mut indent_level = 0;
for token in iter::successors(expn.first_token(), |t| t.next_token()) { for token in iter::successors(expn.first_token(), |t| t.next_token())
.take_while(|token| token.text_range().start() < expn.text_range().end())
{
let curr_kind = token.kind(); let curr_kind = token.kind();
let space = match (prev_kind, curr_kind) { let space = match (prev_kind, curr_kind) {
_ if prev_kind.is_trivia() || curr_kind.is_trivia() => "", _ if prev_kind.is_trivia() || curr_kind.is_trivia() => "",

View file

@ -1,9 +1,10 @@
//! Builtin derives. //! Builtin derives.
use intern::sym; use intern::sym;
use itertools::izip; use itertools::{izip, Itertools};
use parser::SyntaxKind;
use rustc_hash::FxHashSet; use rustc_hash::FxHashSet;
use span::{MacroCallId, Span}; use span::{MacroCallId, Span, SyntaxContextId};
use stdx::never; use stdx::never;
use syntax_bridge::DocCommentDesugarMode; use syntax_bridge::DocCommentDesugarMode;
use tracing::debug; use tracing::debug;
@ -16,8 +17,12 @@ use crate::{
span_map::ExpansionSpanMap, span_map::ExpansionSpanMap,
tt, ExpandError, ExpandResult, tt, ExpandError, ExpandResult,
}; };
use syntax::ast::{ use syntax::{
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds, ast::{
self, edit_in_place::GenericParamsOwnerEdit, make, AstNode, FieldList, HasAttrs,
HasGenericArgs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds,
},
ted,
}; };
macro_rules! register_builtin { macro_rules! register_builtin {
@ -67,13 +72,15 @@ register_builtin! {
Ord => ord_expand, Ord => ord_expand,
PartialOrd => partial_ord_expand, PartialOrd => partial_ord_expand,
Eq => eq_expand, Eq => eq_expand,
PartialEq => partial_eq_expand PartialEq => partial_eq_expand,
CoercePointee => coerce_pointee_expand
} }
pub fn find_builtin_derive(ident: &name::Name) -> Option<BuiltinDeriveExpander> { pub fn find_builtin_derive(ident: &name::Name) -> Option<BuiltinDeriveExpander> {
BuiltinDeriveExpander::find_by_name(ident) BuiltinDeriveExpander::find_by_name(ident)
} }
#[derive(Clone)]
enum VariantShape { enum VariantShape {
Struct(Vec<tt::Ident>), Struct(Vec<tt::Ident>),
Tuple(usize), Tuple(usize),
@ -147,6 +154,7 @@ impl VariantShape {
} }
} }
#[derive(Clone)]
enum AdtShape { enum AdtShape {
Struct(VariantShape), Struct(VariantShape),
Enum { variants: Vec<(tt::Ident, VariantShape)>, default_variant: Option<usize> }, Enum { variants: Vec<(tt::Ident, VariantShape)>, default_variant: Option<usize> },
@ -197,30 +205,38 @@ impl AdtShape {
} }
} }
#[derive(Clone)]
struct BasicAdtInfo { struct BasicAdtInfo {
name: tt::Ident, name: tt::Ident,
shape: AdtShape, shape: AdtShape,
/// first field is the name, and /// first field is the name, and
/// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param. /// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
/// third fields is where bounds, if any /// third fields is where bounds, if any
param_types: Vec<(tt::TopSubtree, Option<tt::TopSubtree>, Option<tt::TopSubtree>)>, param_types: Vec<AdtParam>,
where_clause: Vec<tt::TopSubtree>, where_clause: Vec<tt::TopSubtree>,
associated_types: Vec<tt::TopSubtree>, associated_types: Vec<tt::TopSubtree>,
} }
#[derive(Clone)]
struct AdtParam {
name: tt::TopSubtree,
/// `None` if this is a type parameter.
const_ty: Option<tt::TopSubtree>,
bounds: Option<tt::TopSubtree>,
}
// FIXME: This whole thing needs a refactor. Each derive requires its special values, and the result is a mess.
fn parse_adt(tt: &tt::TopSubtree, call_site: Span) -> Result<BasicAdtInfo, ExpandError> { fn parse_adt(tt: &tt::TopSubtree, call_site: Span) -> Result<BasicAdtInfo, ExpandError> {
let (parsed, tm) = &syntax_bridge::token_tree_to_syntax_node( let (adt, tm) = to_adt_syntax(tt, call_site)?;
tt, parse_adt_from_syntax(&adt, &tm, call_site)
syntax_bridge::TopEntryPoint::MacroItems, }
parser::Edition::CURRENT_FIXME,
); fn parse_adt_from_syntax(
let macro_items = ast::MacroItems::cast(parsed.syntax_node()) adt: &ast::Adt,
.ok_or_else(|| ExpandError::other(call_site, "invalid item definition"))?; tm: &span::SpanMap<SyntaxContextId>,
let item = call_site: Span,
macro_items.items().next().ok_or_else(|| ExpandError::other(call_site, "no item found"))?; ) -> Result<BasicAdtInfo, ExpandError> {
let adt = &ast::Adt::cast(item.syntax().clone()) let (name, generic_param_list, where_clause, shape) = match &adt {
.ok_or_else(|| ExpandError::other(call_site, "expected struct, enum or union"))?;
let (name, generic_param_list, where_clause, shape) = match adt {
ast::Adt::Struct(it) => ( ast::Adt::Struct(it) => (
it.name(), it.name(),
it.generic_param_list(), it.generic_param_list(),
@ -291,7 +307,7 @@ fn parse_adt(tt: &tt::TopSubtree, call_site: Span) -> Result<BasicAdtInfo, Expan
}), }),
ast::TypeOrConstParam::Const(_) => None, ast::TypeOrConstParam::Const(_) => None,
}; };
let ty = if let ast::TypeOrConstParam::Const(param) = param { let const_ty = if let ast::TypeOrConstParam::Const(param) = param {
let ty = param let ty = param
.ty() .ty()
.map(|ty| { .map(|ty| {
@ -309,7 +325,7 @@ fn parse_adt(tt: &tt::TopSubtree, call_site: Span) -> Result<BasicAdtInfo, Expan
} else { } else {
None None
}; };
(name, ty, bounds) AdtParam { name, const_ty, bounds }
}) })
.collect(); .collect();
@ -365,6 +381,24 @@ fn parse_adt(tt: &tt::TopSubtree, call_site: Span) -> Result<BasicAdtInfo, Expan
Ok(BasicAdtInfo { name: name_token, shape, param_types, where_clause, associated_types }) Ok(BasicAdtInfo { name: name_token, shape, param_types, where_clause, associated_types })
} }
fn to_adt_syntax(
tt: &tt::TopSubtree,
call_site: Span,
) -> Result<(ast::Adt, span::SpanMap<SyntaxContextId>), ExpandError> {
let (parsed, tm) = syntax_bridge::token_tree_to_syntax_node(
tt,
syntax_bridge::TopEntryPoint::MacroItems,
parser::Edition::CURRENT_FIXME,
);
let macro_items = ast::MacroItems::cast(parsed.syntax_node())
.ok_or_else(|| ExpandError::other(call_site, "invalid item definition"))?;
let item =
macro_items.items().next().ok_or_else(|| ExpandError::other(call_site, "no item found"))?;
let adt = ast::Adt::cast(item.syntax().clone())
.ok_or_else(|| ExpandError::other(call_site, "expected struct, enum or union"))?;
Ok((adt, tm))
}
fn name_to_token( fn name_to_token(
call_site: Span, call_site: Span,
token_map: &ExpansionSpanMap, token_map: &ExpansionSpanMap,
@ -426,38 +460,64 @@ fn expand_simple_derive(
) )
} }
}; };
ExpandResult::ok(expand_simple_derive_with_parsed(
invoc_span,
info,
trait_path,
make_trait_body,
true,
tt::TopSubtree::empty(tt::DelimSpan::from_single(invoc_span)),
))
}
fn expand_simple_derive_with_parsed(
invoc_span: Span,
info: BasicAdtInfo,
trait_path: tt::TopSubtree,
make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::TopSubtree,
constrain_to_trait: bool,
extra_impl_params: tt::TopSubtree,
) -> tt::TopSubtree {
let trait_body = make_trait_body(&info); let trait_body = make_trait_body(&info);
let mut where_block: Vec<_> = let mut where_block: Vec<_> =
info.where_clause.into_iter().map(|w| quote! {invoc_span => #w , }).collect(); info.where_clause.into_iter().map(|w| quote! {invoc_span => #w , }).collect();
let (params, args): (Vec<_>, Vec<_>) = info let (params, args): (Vec<_>, Vec<_>) = info
.param_types .param_types
.into_iter() .into_iter()
.map(|(ident, param_ty, bound)| { .map(|param| {
let ident_ = ident.clone(); let ident = param.name;
if let Some(b) = bound { if let Some(b) = param.bounds {
let ident = ident.clone(); let ident2 = ident.clone();
where_block.push(quote! {invoc_span => #ident : #b , }); where_block.push(quote! {invoc_span => #ident2 : #b , });
} }
if let Some(ty) = param_ty { if let Some(ty) = param.const_ty {
(quote! {invoc_span => const #ident : #ty , }, quote! {invoc_span => #ident_ , }) let ident2 = ident.clone();
(quote! {invoc_span => const #ident : #ty , }, quote! {invoc_span => #ident2 , })
} else { } else {
let bound = trait_path.clone(); let bound = trait_path.clone();
(quote! {invoc_span => #ident : #bound , }, quote! {invoc_span => #ident_ , }) let ident2 = ident.clone();
let param = if constrain_to_trait {
quote! {invoc_span => #ident : #bound , }
} else {
quote! {invoc_span => #ident , }
};
(param, quote! {invoc_span => #ident2 , })
} }
}) })
.unzip(); .unzip();
if constrain_to_trait {
where_block.extend(info.associated_types.iter().map(|it| { where_block.extend(info.associated_types.iter().map(|it| {
let it = it.clone(); let it = it.clone();
let bound = trait_path.clone(); let bound = trait_path.clone();
quote! {invoc_span => #it : #bound , } quote! {invoc_span => #it : #bound , }
})); }));
}
let name = info.name; let name = info.name;
let expanded = quote! {invoc_span => quote! {invoc_span =>
impl < ##params > #trait_path for #name < ##args > where ##where_block { #trait_body } impl < ##params #extra_impl_params > #trait_path for #name < ##args > where ##where_block { #trait_body }
}; }
ExpandResult::ok(expanded)
} }
fn copy_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> { fn copy_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
@ -871,3 +931,493 @@ fn partial_ord_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSu
} }
}) })
} }
fn coerce_pointee_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
let (adt, _span_map) = match to_adt_syntax(tt, span) {
Ok(it) => it,
Err(err) => {
return ExpandResult::new(tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), err);
}
};
let adt = adt.clone_for_update();
let ast::Adt::Struct(strukt) = &adt else {
return ExpandResult::new(
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
ExpandError::other(span, "`CoercePointee` can only be derived on `struct`s"),
);
};
let has_at_least_one_field = strukt
.field_list()
.map(|it| match it {
ast::FieldList::RecordFieldList(it) => it.fields().next().is_some(),
ast::FieldList::TupleFieldList(it) => it.fields().next().is_some(),
})
.unwrap_or(false);
if !has_at_least_one_field {
return ExpandResult::new(
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
ExpandError::other(
span,
"`CoercePointee` can only be derived on `struct`s with at least one field",
),
);
}
let is_repr_transparent = strukt.attrs().any(|attr| {
attr.as_simple_call().is_some_and(|(name, tt)| {
name == "repr"
&& tt.syntax().children_with_tokens().any(|it| {
it.into_token().is_some_and(|it| {
it.kind() == SyntaxKind::IDENT && it.text() == "transparent"
})
})
})
});
if !is_repr_transparent {
return ExpandResult::new(
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
ExpandError::other(
span,
"`CoercePointee` can only be derived on `struct`s with `#[repr(transparent)]`",
),
);
}
let type_params = strukt
.generic_param_list()
.into_iter()
.flat_map(|generics| {
generics.generic_params().filter_map(|param| match param {
ast::GenericParam::TypeParam(param) => Some(param),
_ => None,
})
})
.collect_vec();
if type_params.is_empty() {
return ExpandResult::new(
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
ExpandError::other(
span,
"`CoercePointee` can only be derived on `struct`s that are generic over at least one type",
),
);
}
let (pointee_param, pointee_param_idx) = if type_params.len() == 1 {
// Regardless of the only type param being designed as `#[pointee]` or not, we can just use it as such.
(type_params[0].clone(), 0)
} else {
let mut pointees = type_params.iter().cloned().enumerate().filter(|(_, param)| {
param.attrs().any(|attr| {
let is_pointee = attr.as_simple_atom().is_some_and(|name| name == "pointee");
if is_pointee {
// Remove the `#[pointee]` attribute so it won't be present in the generated
// impls (where we cannot resolve it).
ted::remove(attr.syntax());
}
is_pointee
})
});
match (pointees.next(), pointees.next()) {
(Some((pointee_idx, pointee)), None) => (pointee, pointee_idx),
(None, _) => {
return ExpandResult::new(
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
ExpandError::other(
span,
"exactly one generic type parameter must be marked \
as `#[pointee]` to derive `CoercePointee` traits",
),
)
}
(Some(_), Some(_)) => {
return ExpandResult::new(
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
ExpandError::other(
span,
"only one type parameter can be marked as `#[pointee]` \
when deriving `CoercePointee` traits",
),
)
}
}
};
let (Some(struct_name), Some(pointee_param_name)) = (strukt.name(), pointee_param.name())
else {
return ExpandResult::new(
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
ExpandError::other(span, "invalid item"),
);
};
{
let mut pointee_has_maybe_sized_bound = false;
if let Some(bounds) = pointee_param.type_bound_list() {
pointee_has_maybe_sized_bound |= bounds.bounds().any(is_maybe_sized_bound);
}
if let Some(where_clause) = strukt.where_clause() {
pointee_has_maybe_sized_bound |= where_clause.predicates().any(|pred| {
let Some(ast::Type::PathType(ty)) = pred.ty() else { return false };
let is_not_pointee = ty.path().is_none_or(|path| {
let is_pointee = path
.as_single_name_ref()
.is_some_and(|name| name.text() == pointee_param_name.text());
!is_pointee
});
if is_not_pointee {
return false;
}
pred.type_bound_list()
.is_some_and(|bounds| bounds.bounds().any(is_maybe_sized_bound))
})
}
if !pointee_has_maybe_sized_bound {
return ExpandResult::new(
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
ExpandError::other(
span,
format!("`derive(CoercePointee)` requires `{pointee_param_name}` to be marked `?Sized`"),
),
);
}
}
const ADDED_PARAM: &str = "__S";
let where_clause = strukt.get_or_create_where_clause();
{
let mut new_predicates = Vec::new();
// # Rewrite generic parameter bounds
// For each bound `U: ..` in `struct<U: ..>`, make a new bound with `__S` in place of `#[pointee]`
// Example:
// ```
// struct<
// U: Trait<T>,
// #[pointee] T: Trait<T> + ?Sized,
// V: Trait<T>> ...
// ```
// ... generates this `impl` generic parameters
// ```
// impl<
// U: Trait<T>,
// T: Trait<T> + ?Sized,
// V: Trait<T>
// >
// where
// U: Trait<__S>,
// __S: Trait<__S> + ?Sized,
// V: Trait<__S> ...
// ```
for param in &type_params {
let Some(param_name) = param.name() else { continue };
if let Some(bounds) = param.type_bound_list() {
// If the target type is the pointee, duplicate the bound as whole.
// Otherwise, duplicate only bounds that mention the pointee.
let is_pointee = param_name.text() == pointee_param_name.text();
let new_bounds = bounds
.bounds()
.map(|bound| bound.clone_subtree().clone_for_update())
.filter(|bound| {
bound.ty().is_some_and(|ty| {
substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM)
|| is_pointee
})
});
let new_bounds_target = if is_pointee {
make::name_ref(ADDED_PARAM)
} else {
make::name_ref(&param_name.text())
};
new_predicates.push(
make::where_pred(
make::ty_path(make::path_from_segments(
[make::path_segment(new_bounds_target)],
false,
)),
new_bounds,
)
.clone_for_update(),
);
}
}
// # Rewrite `where` clauses
//
// Move on to `where` clauses.
// Example:
// ```
// struct MyPointer<#[pointee] T, ..>
// where
// U: Trait<V> + Trait<T>,
// Companion<T>: Trait<T>,
// T: Trait<T> + ?Sized,
// { .. }
// ```
// ... will have a impl prelude like so
// ```
// impl<..> ..
// where
// U: Trait<V> + Trait<T>,
// U: Trait<__S>,
// Companion<T>: Trait<T>,
// Companion<__S>: Trait<__S>,
// T: Trait<T> + ?Sized,
// __S: Trait<__S> + ?Sized,
// ```
//
// We should also write a few new `where` bounds from `#[pointee] T` to `__S`
// as well as any bound that indirectly involves the `#[pointee] T` type.
for predicate in where_clause.predicates() {
let predicate = predicate.clone_subtree().clone_for_update();
let Some(pred_target) = predicate.ty() else { continue };
// If the target type references the pointee, duplicate the bound as whole.
// Otherwise, duplicate only bounds that mention the pointee.
if substitute_type_in_bound(
pred_target.clone(),
&pointee_param_name.text(),
ADDED_PARAM,
) {
if let Some(bounds) = predicate.type_bound_list() {
for bound in bounds.bounds() {
if let Some(ty) = bound.ty() {
substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM);
}
}
}
new_predicates.push(predicate);
} else if let Some(bounds) = predicate.type_bound_list() {
let new_bounds = bounds
.bounds()
.map(|bound| bound.clone_subtree().clone_for_update())
.filter(|bound| {
bound.ty().is_some_and(|ty| {
substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM)
})
});
new_predicates.push(make::where_pred(pred_target, new_bounds).clone_for_update());
}
}
for new_predicate in new_predicates {
where_clause.add_predicate(new_predicate);
}
}
{
// # Add `Unsize<__S>` bound to `#[pointee]` at the generic parameter location
//
// Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it.
where_clause.add_predicate(
make::where_pred(
make::ty_path(make::path_from_segments(
[make::path_segment(make::name_ref(&pointee_param_name.text()))],
false,
)),
[make::type_bound(make::ty_path(make::path_from_segments(
[
make::path_segment(make::name_ref("core")),
make::path_segment(make::name_ref("marker")),
make::generic_ty_path_segment(
make::name_ref("Unsize"),
[make::type_arg(make::ty_path(make::path_from_segments(
[make::path_segment(make::name_ref(ADDED_PARAM))],
false,
)))
.into()],
),
],
true,
)))],
)
.clone_for_update(),
);
}
let self_for_traits = {
// Replace the `#[pointee]` with `__S`.
let mut type_param_idx = 0;
let self_params_for_traits = strukt
.generic_param_list()
.into_iter()
.flat_map(|params| params.generic_params())
.filter_map(|param| {
Some(match param {
ast::GenericParam::ConstParam(param) => {
ast::GenericArg::ConstArg(make::expr_const_value(&param.name()?.text()))
}
ast::GenericParam::LifetimeParam(param) => {
make::lifetime_arg(param.lifetime()?).into()
}
ast::GenericParam::TypeParam(param) => {
let name = if pointee_param_idx == type_param_idx {
make::name_ref(ADDED_PARAM)
} else {
make::name_ref(&param.name()?.text())
};
type_param_idx += 1;
make::type_arg(make::ty_path(make::path_from_segments(
[make::path_segment(name)],
false,
)))
.into()
}
})
});
let self_for_traits = make::path_from_segments(
[make::generic_ty_path_segment(
make::name_ref(&struct_name.text()),
self_params_for_traits,
)],
false,
)
.clone_for_update();
self_for_traits
};
let mut span_map = span::SpanMap::empty();
// One span for them all.
span_map.push(adt.syntax().text_range().end(), span);
let self_for_traits = syntax_bridge::syntax_node_to_token_tree(
self_for_traits.syntax(),
&span_map,
span,
DocCommentDesugarMode::ProcMacro,
);
let info = match parse_adt_from_syntax(&adt, &span_map, span) {
Ok(it) => it,
Err(err) => {
return ExpandResult::new(tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), err)
}
};
let self_for_traits2 = self_for_traits.clone();
let krate = dollar_crate(span);
let krate2 = krate.clone();
let dispatch_from_dyn = expand_simple_derive_with_parsed(
span,
info.clone(),
quote! {span => #krate2::ops::DispatchFromDyn<#self_for_traits2> },
|_adt| quote! {span => },
false,
quote! {span => __S },
);
let coerce_unsized = expand_simple_derive_with_parsed(
span,
info,
quote! {span => #krate::ops::CoerceUnsized<#self_for_traits> },
|_adt| quote! {span => },
false,
quote! {span => __S },
);
return ExpandResult::ok(quote! {span => #dispatch_from_dyn #coerce_unsized });
fn is_maybe_sized_bound(bound: ast::TypeBound) -> bool {
if bound.question_mark_token().is_none() {
return false;
}
let Some(ast::Type::PathType(ty)) = bound.ty() else {
return false;
};
let Some(path) = ty.path() else {
return false;
};
return segments_eq(&path, &["Sized"])
|| segments_eq(&path, &["core", "marker", "Sized"])
|| segments_eq(&path, &["std", "marker", "Sized"]);
fn segments_eq(path: &ast::Path, expected: &[&str]) -> bool {
path.segments().zip_longest(expected.iter().copied()).all(|value| {
value.both().is_some_and(|(segment, expected)| {
segment.name_ref().is_some_and(|name| name.text() == expected)
})
})
}
}
/// Returns true if any substitution was performed.
fn substitute_type_in_bound(ty: ast::Type, param_name: &str, replacement: &str) -> bool {
return match ty {
ast::Type::ArrayType(ty) => {
ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
}
ast::Type::DynTraitType(ty) => go_bounds(ty.type_bound_list(), param_name, replacement),
ast::Type::FnPtrType(ty) => any_long(
ty.param_list()
.into_iter()
.flat_map(|params| params.params().filter_map(|param| param.ty()))
.chain(ty.ret_type().and_then(|it| it.ty())),
|ty| substitute_type_in_bound(ty, param_name, replacement),
),
ast::Type::ForType(ty) => {
ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
}
ast::Type::ImplTraitType(ty) => {
go_bounds(ty.type_bound_list(), param_name, replacement)
}
ast::Type::ParenType(ty) => {
ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
}
ast::Type::PathType(ty) => ty.path().is_some_and(|path| {
if path.as_single_name_ref().is_some_and(|name| name.text() == param_name) {
ted::replace(
path.syntax(),
make::path_from_segments(
[make::path_segment(make::name_ref(replacement))],
false,
)
.clone_for_update()
.syntax(),
);
return true;
}
any_long(
path.segments()
.filter_map(|segment| segment.generic_arg_list())
.flat_map(|it| it.generic_args())
.filter_map(|generic_arg| match generic_arg {
ast::GenericArg::TypeArg(ty) => ty.ty(),
_ => None,
}),
|ty| substitute_type_in_bound(ty, param_name, replacement),
)
}),
ast::Type::PtrType(ty) => {
ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
}
ast::Type::RefType(ty) => {
ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
}
ast::Type::SliceType(ty) => {
ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
}
ast::Type::TupleType(ty) => {
any_long(ty.fields(), |ty| substitute_type_in_bound(ty, param_name, replacement))
}
ast::Type::InferType(_) | ast::Type::MacroType(_) | ast::Type::NeverType(_) => false,
};
fn go_bounds(
bounds: Option<ast::TypeBoundList>,
param_name: &str,
replacement: &str,
) -> bool {
bounds.is_some_and(|bounds| {
any_long(bounds.bounds(), |bound| {
bound
.ty()
.is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
})
})
}
/// Like [`Iterator::any()`], but not short-circuiting.
fn any_long<I: Iterator, F: FnMut(I::Item) -> bool>(iter: I, mut f: F) -> bool {
let mut result = false;
iter.for_each(|item| result |= f(item));
result
}
}
}

View file

@ -78,7 +78,7 @@ pub(crate) fn move_bounds_to_where_clause(
fn build_predicate(param: ast::TypeParam) -> Option<ast::WherePred> { fn build_predicate(param: ast::TypeParam) -> Option<ast::WherePred> {
let path = make::ext::ident_path(&param.name()?.syntax().to_string()); let path = make::ext::ident_path(&param.name()?.syntax().to_string());
let predicate = make::where_pred(path, param.type_bound_list()?.bounds()); let predicate = make::where_pred(make::ty_path(path), param.type_bound_list()?.bounds());
Some(predicate.clone_for_update()) Some(predicate.clone_for_update())
} }

View file

@ -361,6 +361,7 @@ define_symbols! {
partial_ord, partial_ord,
PartialEq, PartialEq,
PartialOrd, PartialOrd,
CoercePointee,
path, path,
Pending, Pending,
phantom_data, phantom_data,

View file

@ -336,6 +336,24 @@ pub fn path_segment(name_ref: ast::NameRef) -> ast::PathSegment {
ast_from_text(&format!("type __ = {name_ref};")) ast_from_text(&format!("type __ = {name_ref};"))
} }
/// Type and expressions/patterns path differ in whether they require `::` before generic arguments.
/// Type paths allow them but they are often omitted, while expression/pattern paths require them.
pub fn generic_ty_path_segment(
name_ref: ast::NameRef,
generic_args: impl IntoIterator<Item = ast::GenericArg>,
) -> ast::PathSegment {
let mut generic_args = generic_args.into_iter();
let first_generic_arg = generic_args.next();
quote! {
PathSegment {
#name_ref
GenericArgList {
[<] #first_generic_arg #([,] " " #generic_args)* [>]
}
}
}
}
pub fn path_segment_ty(type_ref: ast::Type, trait_ref: Option<ast::PathType>) -> ast::PathSegment { pub fn path_segment_ty(type_ref: ast::Type, trait_ref: Option<ast::PathType>) -> ast::PathSegment {
let text = match trait_ref { let text = match trait_ref {
Some(trait_ref) => format!("fn f(x: <{type_ref} as {trait_ref}>) {{}}"), Some(trait_ref) => format!("fn f(x: <{type_ref} as {trait_ref}>) {{}}"),
@ -814,7 +832,7 @@ pub fn match_arm_list(arms: impl IntoIterator<Item = ast::MatchArm>) -> ast::Mat
} }
pub fn where_pred( pub fn where_pred(
path: ast::Path, path: ast::Type,
bounds: impl IntoIterator<Item = ast::TypeBound>, bounds: impl IntoIterator<Item = ast::TypeBound>,
) -> ast::WherePred { ) -> ast::WherePred {
let bounds = bounds.into_iter().join(" + "); let bounds = bounds.into_iter().join(" + ");

View file

@ -17,6 +17,7 @@
//! builtin_impls: //! builtin_impls:
//! cell: copy, drop //! cell: copy, drop
//! clone: sized //! clone: sized
//! coerce_pointee: derive, sized, unsize, coerce_unsized, dispatch_from_dyn
//! coerce_unsized: unsize //! coerce_unsized: unsize
//! concat: //! concat:
//! copy: clone //! copy: clone
@ -157,6 +158,14 @@ pub mod marker {
type Discriminant; type Discriminant;
} }
// endregion:discriminant // endregion:discriminant
// region:coerce_pointee
#[rustc_builtin_macro(CoercePointee, attributes(pointee))]
#[allow_internal_unstable(dispatch_from_dyn, coerce_unsized, unsize)]
pub macro CoercePointee($item:item) {
/* compiler built-in */
}
// endregion:coerce_pointee
} }
// region:default // region:default