mirror of
https://github.com/rust-lang/rust-analyzer
synced 2025-01-11 20:58:54 +00:00
Support the new CoercePointee
derive
This commit is contained in:
parent
061d257e90
commit
ce323627c5
8 changed files with 767 additions and 63 deletions
|
@ -2,7 +2,7 @@
|
|||
|
||||
use expect_test::expect;
|
||||
|
||||
use crate::macro_expansion_tests::check;
|
||||
use crate::macro_expansion_tests::{check, check_errors};
|
||||
|
||||
#[test]
|
||||
fn test_copy_expand_simple() {
|
||||
|
@ -654,3 +654,95 @@ impl < > $crate::default::Default for Bar< > where {
|
|||
}"##]],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coerce_pointee_expansion() {
|
||||
check(
|
||||
r#"
|
||||
//- minicore: coerce_pointee
|
||||
|
||||
use core::marker::CoercePointee;
|
||||
|
||||
pub trait Trait<T: ?Sized> {}
|
||||
|
||||
#[derive(CoercePointee)]
|
||||
#[repr(transparent)]
|
||||
pub struct Foo<'a, T: ?Sized + Trait<U>, #[pointee] U: ?Sized, const N: u32>(T)
|
||||
where
|
||||
U: Trait<U> + ToString;"#,
|
||||
expect![[r#"
|
||||
|
||||
use core::marker::CoercePointee;
|
||||
|
||||
pub trait Trait<T: ?Sized> {}
|
||||
|
||||
#[derive(CoercePointee)]
|
||||
#[repr(transparent)]
|
||||
pub struct Foo<'a, T: ?Sized + Trait<U>, #[pointee] U: ?Sized, const N: u32>(T)
|
||||
where
|
||||
U: Trait<U> + ToString;
|
||||
impl <T, U, const N: u32, __S> $crate::ops::DispatchFromDyn<Foo<'a, T, __S, N>> for Foo<T, U, N, > where U: Trait<U> +ToString, T: Trait<__S>, __S: ?Sized, __S: Trait<__S> +ToString, U: ::core::marker::Unsize<__S>, T:?Sized+Trait<U>, U:?Sized, {}
|
||||
impl <T, U, const N: u32, __S> $crate::ops::CoerceUnsized<Foo<'a, T, __S, N>> for Foo<T, U, N, > where U: Trait<U> +ToString, T: Trait<__S>, __S: ?Sized, __S: Trait<__S> +ToString, U: ::core::marker::Unsize<__S>, T:?Sized+Trait<U>, U:?Sized, {}"#]],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn coerce_pointee_errors() {
|
||||
check_errors(
|
||||
r#"
|
||||
//- minicore: coerce_pointee
|
||||
|
||||
use core::marker::CoercePointee;
|
||||
|
||||
#[derive(CoercePointee)]
|
||||
enum Enum {}
|
||||
|
||||
#[derive(CoercePointee)]
|
||||
struct Struct1;
|
||||
|
||||
#[derive(CoercePointee)]
|
||||
struct Struct2();
|
||||
|
||||
#[derive(CoercePointee)]
|
||||
struct Struct3 {}
|
||||
|
||||
#[derive(CoercePointee)]
|
||||
struct Struct4<T: ?Sized>(T);
|
||||
|
||||
#[derive(CoercePointee)]
|
||||
#[repr(transparent)]
|
||||
struct Struct5(i32);
|
||||
|
||||
#[derive(CoercePointee)]
|
||||
#[repr(transparent)]
|
||||
struct Struct6<#[pointee] T: ?Sized, #[pointee] U: ?Sized>(T, U);
|
||||
|
||||
#[derive(CoercePointee)]
|
||||
#[repr(transparent)]
|
||||
struct Struct7<T: ?Sized, U: ?Sized>(T, U);
|
||||
|
||||
#[derive(CoercePointee)]
|
||||
#[repr(transparent)]
|
||||
struct Struct8<#[pointee] T, U: ?Sized>(T);
|
||||
|
||||
#[derive(CoercePointee)]
|
||||
#[repr(transparent)]
|
||||
struct Struct9<T>(T);
|
||||
|
||||
#[derive(CoercePointee)]
|
||||
#[repr(transparent)]
|
||||
struct Struct9<#[pointee] T, U>(T) where T: ?Sized;
|
||||
"#,
|
||||
expect![[r#"
|
||||
35..72: `CoercePointee` can only be derived on `struct`s
|
||||
74..114: `CoercePointee` can only be derived on `struct`s with at least one field
|
||||
116..158: `CoercePointee` can only be derived on `struct`s with at least one field
|
||||
160..202: `CoercePointee` can only be derived on `struct`s with at least one field
|
||||
204..258: `CoercePointee` can only be derived on `struct`s with `#[repr(transparent)]`
|
||||
260..326: `CoercePointee` can only be derived on `struct`s that are generic over at least one type
|
||||
328..439: only one type parameter can be marked as `#[pointee]` when deriving `CoercePointee` traits
|
||||
441..530: exactly one generic type parameter must be marked as `#[pointee]` to derive `CoercePointee` traits
|
||||
532..621: `derive(CoercePointee)` requires `T` to be marked `?Sized`
|
||||
623..690: `derive(CoercePointee)` requires `T` to be marked `?Sized`"#]],
|
||||
);
|
||||
}
|
||||
|
|
|
@ -16,14 +16,16 @@ mod proc_macros;
|
|||
|
||||
use std::{iter, ops::Range, sync};
|
||||
|
||||
use base_db::SourceDatabase;
|
||||
use expect_test::Expect;
|
||||
use hir_expand::{
|
||||
db::ExpandDatabase,
|
||||
proc_macro::{ProcMacro, ProcMacroExpander, ProcMacroExpansionError, ProcMacroKind},
|
||||
span_map::SpanMapRef,
|
||||
InFile, MacroFileId, MacroFileIdExt,
|
||||
InFile, MacroCallKind, MacroFileId, MacroFileIdExt,
|
||||
};
|
||||
use intern::Symbol;
|
||||
use itertools::Itertools;
|
||||
use span::{Edition, Span};
|
||||
use stdx::{format_to, format_to_acc};
|
||||
use syntax::{
|
||||
|
@ -44,6 +46,36 @@ use crate::{
|
|||
AdtId, AsMacroCall, Lookup, ModuleDefId,
|
||||
};
|
||||
|
||||
#[track_caller]
|
||||
fn check_errors(ra_fixture: &str, expect: Expect) {
|
||||
let db = TestDB::with_files(ra_fixture);
|
||||
let krate = db.fetch_test_crate();
|
||||
let def_map = db.crate_def_map(krate);
|
||||
let errors = def_map
|
||||
.modules()
|
||||
.flat_map(|module| module.1.scope.all_macro_calls())
|
||||
.filter_map(|macro_call| {
|
||||
let errors = db.parse_macro_expansion_error(macro_call)?;
|
||||
let errors = errors.err.as_ref()?.render_to_string(&db);
|
||||
let macro_loc = db.lookup_intern_macro_call(macro_call);
|
||||
let ast_id = match macro_loc.kind {
|
||||
MacroCallKind::FnLike { ast_id, .. } => ast_id.map(|it| it.erase()),
|
||||
MacroCallKind::Derive { ast_id, .. } => ast_id.map(|it| it.erase()),
|
||||
MacroCallKind::Attr { ast_id, .. } => ast_id.map(|it| it.erase()),
|
||||
};
|
||||
let ast = db
|
||||
.parse(ast_id.file_id.file_id().expect("macros inside macros are not supported"))
|
||||
.syntax_node();
|
||||
let ast_id_map = db.ast_id_map(ast_id.file_id);
|
||||
let node = ast_id_map.get_erased(ast_id.value).to_node(&ast);
|
||||
Some((node.text_range(), errors))
|
||||
})
|
||||
.sorted_unstable_by_key(|(range, _)| range.start())
|
||||
.format_with("\n", |(range, err), format| format(&format_args!("{range:?}: {err}")))
|
||||
.to_string();
|
||||
expect.assert_eq(&errors);
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn check(ra_fixture: &str, mut expect: Expect) {
|
||||
let extra_proc_macros = vec![(
|
||||
|
@ -245,7 +277,9 @@ fn pretty_print_macro_expansion(
|
|||
let mut res = String::new();
|
||||
let mut prev_kind = EOF;
|
||||
let mut indent_level = 0;
|
||||
for token in iter::successors(expn.first_token(), |t| t.next_token()) {
|
||||
for token in iter::successors(expn.first_token(), |t| t.next_token())
|
||||
.take_while(|token| token.text_range().start() < expn.text_range().end())
|
||||
{
|
||||
let curr_kind = token.kind();
|
||||
let space = match (prev_kind, curr_kind) {
|
||||
_ if prev_kind.is_trivia() || curr_kind.is_trivia() => "",
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
//! Builtin derives.
|
||||
|
||||
use intern::sym;
|
||||
use itertools::izip;
|
||||
use itertools::{izip, Itertools};
|
||||
use parser::SyntaxKind;
|
||||
use rustc_hash::FxHashSet;
|
||||
use span::{MacroCallId, Span};
|
||||
use span::{MacroCallId, Span, SyntaxContextId};
|
||||
use stdx::never;
|
||||
use syntax_bridge::DocCommentDesugarMode;
|
||||
use tracing::debug;
|
||||
|
@ -16,8 +17,12 @@ use crate::{
|
|||
span_map::ExpansionSpanMap,
|
||||
tt, ExpandError, ExpandResult,
|
||||
};
|
||||
use syntax::ast::{
|
||||
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds,
|
||||
use syntax::{
|
||||
ast::{
|
||||
self, edit_in_place::GenericParamsOwnerEdit, make, AstNode, FieldList, HasAttrs,
|
||||
HasGenericArgs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds,
|
||||
},
|
||||
ted,
|
||||
};
|
||||
|
||||
macro_rules! register_builtin {
|
||||
|
@ -67,13 +72,15 @@ register_builtin! {
|
|||
Ord => ord_expand,
|
||||
PartialOrd => partial_ord_expand,
|
||||
Eq => eq_expand,
|
||||
PartialEq => partial_eq_expand
|
||||
PartialEq => partial_eq_expand,
|
||||
CoercePointee => coerce_pointee_expand
|
||||
}
|
||||
|
||||
pub fn find_builtin_derive(ident: &name::Name) -> Option<BuiltinDeriveExpander> {
|
||||
BuiltinDeriveExpander::find_by_name(ident)
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum VariantShape {
|
||||
Struct(Vec<tt::Ident>),
|
||||
Tuple(usize),
|
||||
|
@ -147,6 +154,7 @@ impl VariantShape {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum AdtShape {
|
||||
Struct(VariantShape),
|
||||
Enum { variants: Vec<(tt::Ident, VariantShape)>, default_variant: Option<usize> },
|
||||
|
@ -197,30 +205,38 @@ impl AdtShape {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct BasicAdtInfo {
|
||||
name: tt::Ident,
|
||||
shape: AdtShape,
|
||||
/// first field is the name, and
|
||||
/// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param.
|
||||
/// third fields is where bounds, if any
|
||||
param_types: Vec<(tt::TopSubtree, Option<tt::TopSubtree>, Option<tt::TopSubtree>)>,
|
||||
param_types: Vec<AdtParam>,
|
||||
where_clause: Vec<tt::TopSubtree>,
|
||||
associated_types: Vec<tt::TopSubtree>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct AdtParam {
|
||||
name: tt::TopSubtree,
|
||||
/// `None` if this is a type parameter.
|
||||
const_ty: Option<tt::TopSubtree>,
|
||||
bounds: Option<tt::TopSubtree>,
|
||||
}
|
||||
|
||||
// FIXME: This whole thing needs a refactor. Each derive requires its special values, and the result is a mess.
|
||||
fn parse_adt(tt: &tt::TopSubtree, call_site: Span) -> Result<BasicAdtInfo, ExpandError> {
|
||||
let (parsed, tm) = &syntax_bridge::token_tree_to_syntax_node(
|
||||
tt,
|
||||
syntax_bridge::TopEntryPoint::MacroItems,
|
||||
parser::Edition::CURRENT_FIXME,
|
||||
);
|
||||
let macro_items = ast::MacroItems::cast(parsed.syntax_node())
|
||||
.ok_or_else(|| ExpandError::other(call_site, "invalid item definition"))?;
|
||||
let item =
|
||||
macro_items.items().next().ok_or_else(|| ExpandError::other(call_site, "no item found"))?;
|
||||
let adt = &ast::Adt::cast(item.syntax().clone())
|
||||
.ok_or_else(|| ExpandError::other(call_site, "expected struct, enum or union"))?;
|
||||
let (name, generic_param_list, where_clause, shape) = match adt {
|
||||
let (adt, tm) = to_adt_syntax(tt, call_site)?;
|
||||
parse_adt_from_syntax(&adt, &tm, call_site)
|
||||
}
|
||||
|
||||
fn parse_adt_from_syntax(
|
||||
adt: &ast::Adt,
|
||||
tm: &span::SpanMap<SyntaxContextId>,
|
||||
call_site: Span,
|
||||
) -> Result<BasicAdtInfo, ExpandError> {
|
||||
let (name, generic_param_list, where_clause, shape) = match &adt {
|
||||
ast::Adt::Struct(it) => (
|
||||
it.name(),
|
||||
it.generic_param_list(),
|
||||
|
@ -291,7 +307,7 @@ fn parse_adt(tt: &tt::TopSubtree, call_site: Span) -> Result<BasicAdtInfo, Expan
|
|||
}),
|
||||
ast::TypeOrConstParam::Const(_) => None,
|
||||
};
|
||||
let ty = if let ast::TypeOrConstParam::Const(param) = param {
|
||||
let const_ty = if let ast::TypeOrConstParam::Const(param) = param {
|
||||
let ty = param
|
||||
.ty()
|
||||
.map(|ty| {
|
||||
|
@ -309,7 +325,7 @@ fn parse_adt(tt: &tt::TopSubtree, call_site: Span) -> Result<BasicAdtInfo, Expan
|
|||
} else {
|
||||
None
|
||||
};
|
||||
(name, ty, bounds)
|
||||
AdtParam { name, const_ty, bounds }
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
@ -365,6 +381,24 @@ fn parse_adt(tt: &tt::TopSubtree, call_site: Span) -> Result<BasicAdtInfo, Expan
|
|||
Ok(BasicAdtInfo { name: name_token, shape, param_types, where_clause, associated_types })
|
||||
}
|
||||
|
||||
fn to_adt_syntax(
|
||||
tt: &tt::TopSubtree,
|
||||
call_site: Span,
|
||||
) -> Result<(ast::Adt, span::SpanMap<SyntaxContextId>), ExpandError> {
|
||||
let (parsed, tm) = syntax_bridge::token_tree_to_syntax_node(
|
||||
tt,
|
||||
syntax_bridge::TopEntryPoint::MacroItems,
|
||||
parser::Edition::CURRENT_FIXME,
|
||||
);
|
||||
let macro_items = ast::MacroItems::cast(parsed.syntax_node())
|
||||
.ok_or_else(|| ExpandError::other(call_site, "invalid item definition"))?;
|
||||
let item =
|
||||
macro_items.items().next().ok_or_else(|| ExpandError::other(call_site, "no item found"))?;
|
||||
let adt = ast::Adt::cast(item.syntax().clone())
|
||||
.ok_or_else(|| ExpandError::other(call_site, "expected struct, enum or union"))?;
|
||||
Ok((adt, tm))
|
||||
}
|
||||
|
||||
fn name_to_token(
|
||||
call_site: Span,
|
||||
token_map: &ExpansionSpanMap,
|
||||
|
@ -426,38 +460,64 @@ fn expand_simple_derive(
|
|||
)
|
||||
}
|
||||
};
|
||||
ExpandResult::ok(expand_simple_derive_with_parsed(
|
||||
invoc_span,
|
||||
info,
|
||||
trait_path,
|
||||
make_trait_body,
|
||||
true,
|
||||
tt::TopSubtree::empty(tt::DelimSpan::from_single(invoc_span)),
|
||||
))
|
||||
}
|
||||
|
||||
fn expand_simple_derive_with_parsed(
|
||||
invoc_span: Span,
|
||||
info: BasicAdtInfo,
|
||||
trait_path: tt::TopSubtree,
|
||||
make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::TopSubtree,
|
||||
constrain_to_trait: bool,
|
||||
extra_impl_params: tt::TopSubtree,
|
||||
) -> tt::TopSubtree {
|
||||
let trait_body = make_trait_body(&info);
|
||||
let mut where_block: Vec<_> =
|
||||
info.where_clause.into_iter().map(|w| quote! {invoc_span => #w , }).collect();
|
||||
let (params, args): (Vec<_>, Vec<_>) = info
|
||||
.param_types
|
||||
.into_iter()
|
||||
.map(|(ident, param_ty, bound)| {
|
||||
let ident_ = ident.clone();
|
||||
if let Some(b) = bound {
|
||||
let ident = ident.clone();
|
||||
where_block.push(quote! {invoc_span => #ident : #b , });
|
||||
.map(|param| {
|
||||
let ident = param.name;
|
||||
if let Some(b) = param.bounds {
|
||||
let ident2 = ident.clone();
|
||||
where_block.push(quote! {invoc_span => #ident2 : #b , });
|
||||
}
|
||||
if let Some(ty) = param_ty {
|
||||
(quote! {invoc_span => const #ident : #ty , }, quote! {invoc_span => #ident_ , })
|
||||
if let Some(ty) = param.const_ty {
|
||||
let ident2 = ident.clone();
|
||||
(quote! {invoc_span => const #ident : #ty , }, quote! {invoc_span => #ident2 , })
|
||||
} else {
|
||||
let bound = trait_path.clone();
|
||||
(quote! {invoc_span => #ident : #bound , }, quote! {invoc_span => #ident_ , })
|
||||
let ident2 = ident.clone();
|
||||
let param = if constrain_to_trait {
|
||||
quote! {invoc_span => #ident : #bound , }
|
||||
} else {
|
||||
quote! {invoc_span => #ident , }
|
||||
};
|
||||
(param, quote! {invoc_span => #ident2 , })
|
||||
}
|
||||
})
|
||||
.unzip();
|
||||
|
||||
if constrain_to_trait {
|
||||
where_block.extend(info.associated_types.iter().map(|it| {
|
||||
let it = it.clone();
|
||||
let bound = trait_path.clone();
|
||||
quote! {invoc_span => #it : #bound , }
|
||||
}));
|
||||
}
|
||||
|
||||
let name = info.name;
|
||||
let expanded = quote! {invoc_span =>
|
||||
impl < ##params > #trait_path for #name < ##args > where ##where_block { #trait_body }
|
||||
};
|
||||
ExpandResult::ok(expanded)
|
||||
quote! {invoc_span =>
|
||||
impl < ##params #extra_impl_params > #trait_path for #name < ##args > where ##where_block { #trait_body }
|
||||
}
|
||||
}
|
||||
|
||||
fn copy_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
|
||||
|
@ -871,3 +931,493 @@ fn partial_ord_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSu
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn coerce_pointee_expand(span: Span, tt: &tt::TopSubtree) -> ExpandResult<tt::TopSubtree> {
|
||||
let (adt, _span_map) = match to_adt_syntax(tt, span) {
|
||||
Ok(it) => it,
|
||||
Err(err) => {
|
||||
return ExpandResult::new(tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), err);
|
||||
}
|
||||
};
|
||||
let adt = adt.clone_for_update();
|
||||
let ast::Adt::Struct(strukt) = &adt else {
|
||||
return ExpandResult::new(
|
||||
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
|
||||
ExpandError::other(span, "`CoercePointee` can only be derived on `struct`s"),
|
||||
);
|
||||
};
|
||||
let has_at_least_one_field = strukt
|
||||
.field_list()
|
||||
.map(|it| match it {
|
||||
ast::FieldList::RecordFieldList(it) => it.fields().next().is_some(),
|
||||
ast::FieldList::TupleFieldList(it) => it.fields().next().is_some(),
|
||||
})
|
||||
.unwrap_or(false);
|
||||
if !has_at_least_one_field {
|
||||
return ExpandResult::new(
|
||||
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
|
||||
ExpandError::other(
|
||||
span,
|
||||
"`CoercePointee` can only be derived on `struct`s with at least one field",
|
||||
),
|
||||
);
|
||||
}
|
||||
let is_repr_transparent = strukt.attrs().any(|attr| {
|
||||
attr.as_simple_call().is_some_and(|(name, tt)| {
|
||||
name == "repr"
|
||||
&& tt.syntax().children_with_tokens().any(|it| {
|
||||
it.into_token().is_some_and(|it| {
|
||||
it.kind() == SyntaxKind::IDENT && it.text() == "transparent"
|
||||
})
|
||||
})
|
||||
})
|
||||
});
|
||||
if !is_repr_transparent {
|
||||
return ExpandResult::new(
|
||||
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
|
||||
ExpandError::other(
|
||||
span,
|
||||
"`CoercePointee` can only be derived on `struct`s with `#[repr(transparent)]`",
|
||||
),
|
||||
);
|
||||
}
|
||||
let type_params = strukt
|
||||
.generic_param_list()
|
||||
.into_iter()
|
||||
.flat_map(|generics| {
|
||||
generics.generic_params().filter_map(|param| match param {
|
||||
ast::GenericParam::TypeParam(param) => Some(param),
|
||||
_ => None,
|
||||
})
|
||||
})
|
||||
.collect_vec();
|
||||
if type_params.is_empty() {
|
||||
return ExpandResult::new(
|
||||
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
|
||||
ExpandError::other(
|
||||
span,
|
||||
"`CoercePointee` can only be derived on `struct`s that are generic over at least one type",
|
||||
),
|
||||
);
|
||||
}
|
||||
let (pointee_param, pointee_param_idx) = if type_params.len() == 1 {
|
||||
// Regardless of the only type param being designed as `#[pointee]` or not, we can just use it as such.
|
||||
(type_params[0].clone(), 0)
|
||||
} else {
|
||||
let mut pointees = type_params.iter().cloned().enumerate().filter(|(_, param)| {
|
||||
param.attrs().any(|attr| {
|
||||
let is_pointee = attr.as_simple_atom().is_some_and(|name| name == "pointee");
|
||||
if is_pointee {
|
||||
// Remove the `#[pointee]` attribute so it won't be present in the generated
|
||||
// impls (where we cannot resolve it).
|
||||
ted::remove(attr.syntax());
|
||||
}
|
||||
is_pointee
|
||||
})
|
||||
});
|
||||
match (pointees.next(), pointees.next()) {
|
||||
(Some((pointee_idx, pointee)), None) => (pointee, pointee_idx),
|
||||
(None, _) => {
|
||||
return ExpandResult::new(
|
||||
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
|
||||
ExpandError::other(
|
||||
span,
|
||||
"exactly one generic type parameter must be marked \
|
||||
as `#[pointee]` to derive `CoercePointee` traits",
|
||||
),
|
||||
)
|
||||
}
|
||||
(Some(_), Some(_)) => {
|
||||
return ExpandResult::new(
|
||||
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
|
||||
ExpandError::other(
|
||||
span,
|
||||
"only one type parameter can be marked as `#[pointee]` \
|
||||
when deriving `CoercePointee` traits",
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
let (Some(struct_name), Some(pointee_param_name)) = (strukt.name(), pointee_param.name())
|
||||
else {
|
||||
return ExpandResult::new(
|
||||
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
|
||||
ExpandError::other(span, "invalid item"),
|
||||
);
|
||||
};
|
||||
|
||||
{
|
||||
let mut pointee_has_maybe_sized_bound = false;
|
||||
if let Some(bounds) = pointee_param.type_bound_list() {
|
||||
pointee_has_maybe_sized_bound |= bounds.bounds().any(is_maybe_sized_bound);
|
||||
}
|
||||
if let Some(where_clause) = strukt.where_clause() {
|
||||
pointee_has_maybe_sized_bound |= where_clause.predicates().any(|pred| {
|
||||
let Some(ast::Type::PathType(ty)) = pred.ty() else { return false };
|
||||
let is_not_pointee = ty.path().is_none_or(|path| {
|
||||
let is_pointee = path
|
||||
.as_single_name_ref()
|
||||
.is_some_and(|name| name.text() == pointee_param_name.text());
|
||||
!is_pointee
|
||||
});
|
||||
if is_not_pointee {
|
||||
return false;
|
||||
}
|
||||
pred.type_bound_list()
|
||||
.is_some_and(|bounds| bounds.bounds().any(is_maybe_sized_bound))
|
||||
})
|
||||
}
|
||||
if !pointee_has_maybe_sized_bound {
|
||||
return ExpandResult::new(
|
||||
tt::TopSubtree::empty(tt::DelimSpan::from_single(span)),
|
||||
ExpandError::other(
|
||||
span,
|
||||
format!("`derive(CoercePointee)` requires `{pointee_param_name}` to be marked `?Sized`"),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const ADDED_PARAM: &str = "__S";
|
||||
|
||||
let where_clause = strukt.get_or_create_where_clause();
|
||||
|
||||
{
|
||||
let mut new_predicates = Vec::new();
|
||||
|
||||
// # Rewrite generic parameter bounds
|
||||
// For each bound `U: ..` in `struct<U: ..>`, make a new bound with `__S` in place of `#[pointee]`
|
||||
// Example:
|
||||
// ```
|
||||
// struct<
|
||||
// U: Trait<T>,
|
||||
// #[pointee] T: Trait<T> + ?Sized,
|
||||
// V: Trait<T>> ...
|
||||
// ```
|
||||
// ... generates this `impl` generic parameters
|
||||
// ```
|
||||
// impl<
|
||||
// U: Trait<T>,
|
||||
// T: Trait<T> + ?Sized,
|
||||
// V: Trait<T>
|
||||
// >
|
||||
// where
|
||||
// U: Trait<__S>,
|
||||
// __S: Trait<__S> + ?Sized,
|
||||
// V: Trait<__S> ...
|
||||
// ```
|
||||
for param in &type_params {
|
||||
let Some(param_name) = param.name() else { continue };
|
||||
if let Some(bounds) = param.type_bound_list() {
|
||||
// If the target type is the pointee, duplicate the bound as whole.
|
||||
// Otherwise, duplicate only bounds that mention the pointee.
|
||||
let is_pointee = param_name.text() == pointee_param_name.text();
|
||||
let new_bounds = bounds
|
||||
.bounds()
|
||||
.map(|bound| bound.clone_subtree().clone_for_update())
|
||||
.filter(|bound| {
|
||||
bound.ty().is_some_and(|ty| {
|
||||
substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM)
|
||||
|| is_pointee
|
||||
})
|
||||
});
|
||||
let new_bounds_target = if is_pointee {
|
||||
make::name_ref(ADDED_PARAM)
|
||||
} else {
|
||||
make::name_ref(¶m_name.text())
|
||||
};
|
||||
new_predicates.push(
|
||||
make::where_pred(
|
||||
make::ty_path(make::path_from_segments(
|
||||
[make::path_segment(new_bounds_target)],
|
||||
false,
|
||||
)),
|
||||
new_bounds,
|
||||
)
|
||||
.clone_for_update(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// # Rewrite `where` clauses
|
||||
//
|
||||
// Move on to `where` clauses.
|
||||
// Example:
|
||||
// ```
|
||||
// struct MyPointer<#[pointee] T, ..>
|
||||
// where
|
||||
// U: Trait<V> + Trait<T>,
|
||||
// Companion<T>: Trait<T>,
|
||||
// T: Trait<T> + ?Sized,
|
||||
// { .. }
|
||||
// ```
|
||||
// ... will have a impl prelude like so
|
||||
// ```
|
||||
// impl<..> ..
|
||||
// where
|
||||
// U: Trait<V> + Trait<T>,
|
||||
// U: Trait<__S>,
|
||||
// Companion<T>: Trait<T>,
|
||||
// Companion<__S>: Trait<__S>,
|
||||
// T: Trait<T> + ?Sized,
|
||||
// __S: Trait<__S> + ?Sized,
|
||||
// ```
|
||||
//
|
||||
// We should also write a few new `where` bounds from `#[pointee] T` to `__S`
|
||||
// as well as any bound that indirectly involves the `#[pointee] T` type.
|
||||
for predicate in where_clause.predicates() {
|
||||
let predicate = predicate.clone_subtree().clone_for_update();
|
||||
let Some(pred_target) = predicate.ty() else { continue };
|
||||
|
||||
// If the target type references the pointee, duplicate the bound as whole.
|
||||
// Otherwise, duplicate only bounds that mention the pointee.
|
||||
if substitute_type_in_bound(
|
||||
pred_target.clone(),
|
||||
&pointee_param_name.text(),
|
||||
ADDED_PARAM,
|
||||
) {
|
||||
if let Some(bounds) = predicate.type_bound_list() {
|
||||
for bound in bounds.bounds() {
|
||||
if let Some(ty) = bound.ty() {
|
||||
substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
new_predicates.push(predicate);
|
||||
} else if let Some(bounds) = predicate.type_bound_list() {
|
||||
let new_bounds = bounds
|
||||
.bounds()
|
||||
.map(|bound| bound.clone_subtree().clone_for_update())
|
||||
.filter(|bound| {
|
||||
bound.ty().is_some_and(|ty| {
|
||||
substitute_type_in_bound(ty, &pointee_param_name.text(), ADDED_PARAM)
|
||||
})
|
||||
});
|
||||
new_predicates.push(make::where_pred(pred_target, new_bounds).clone_for_update());
|
||||
}
|
||||
}
|
||||
|
||||
for new_predicate in new_predicates {
|
||||
where_clause.add_predicate(new_predicate);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// # Add `Unsize<__S>` bound to `#[pointee]` at the generic parameter location
|
||||
//
|
||||
// Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it.
|
||||
where_clause.add_predicate(
|
||||
make::where_pred(
|
||||
make::ty_path(make::path_from_segments(
|
||||
[make::path_segment(make::name_ref(&pointee_param_name.text()))],
|
||||
false,
|
||||
)),
|
||||
[make::type_bound(make::ty_path(make::path_from_segments(
|
||||
[
|
||||
make::path_segment(make::name_ref("core")),
|
||||
make::path_segment(make::name_ref("marker")),
|
||||
make::generic_ty_path_segment(
|
||||
make::name_ref("Unsize"),
|
||||
[make::type_arg(make::ty_path(make::path_from_segments(
|
||||
[make::path_segment(make::name_ref(ADDED_PARAM))],
|
||||
false,
|
||||
)))
|
||||
.into()],
|
||||
),
|
||||
],
|
||||
true,
|
||||
)))],
|
||||
)
|
||||
.clone_for_update(),
|
||||
);
|
||||
}
|
||||
|
||||
let self_for_traits = {
|
||||
// Replace the `#[pointee]` with `__S`.
|
||||
let mut type_param_idx = 0;
|
||||
let self_params_for_traits = strukt
|
||||
.generic_param_list()
|
||||
.into_iter()
|
||||
.flat_map(|params| params.generic_params())
|
||||
.filter_map(|param| {
|
||||
Some(match param {
|
||||
ast::GenericParam::ConstParam(param) => {
|
||||
ast::GenericArg::ConstArg(make::expr_const_value(¶m.name()?.text()))
|
||||
}
|
||||
ast::GenericParam::LifetimeParam(param) => {
|
||||
make::lifetime_arg(param.lifetime()?).into()
|
||||
}
|
||||
ast::GenericParam::TypeParam(param) => {
|
||||
let name = if pointee_param_idx == type_param_idx {
|
||||
make::name_ref(ADDED_PARAM)
|
||||
} else {
|
||||
make::name_ref(¶m.name()?.text())
|
||||
};
|
||||
type_param_idx += 1;
|
||||
make::type_arg(make::ty_path(make::path_from_segments(
|
||||
[make::path_segment(name)],
|
||||
false,
|
||||
)))
|
||||
.into()
|
||||
}
|
||||
})
|
||||
});
|
||||
let self_for_traits = make::path_from_segments(
|
||||
[make::generic_ty_path_segment(
|
||||
make::name_ref(&struct_name.text()),
|
||||
self_params_for_traits,
|
||||
)],
|
||||
false,
|
||||
)
|
||||
.clone_for_update();
|
||||
self_for_traits
|
||||
};
|
||||
|
||||
let mut span_map = span::SpanMap::empty();
|
||||
// One span for them all.
|
||||
span_map.push(adt.syntax().text_range().end(), span);
|
||||
|
||||
let self_for_traits = syntax_bridge::syntax_node_to_token_tree(
|
||||
self_for_traits.syntax(),
|
||||
&span_map,
|
||||
span,
|
||||
DocCommentDesugarMode::ProcMacro,
|
||||
);
|
||||
let info = match parse_adt_from_syntax(&adt, &span_map, span) {
|
||||
Ok(it) => it,
|
||||
Err(err) => {
|
||||
return ExpandResult::new(tt::TopSubtree::empty(tt::DelimSpan::from_single(span)), err)
|
||||
}
|
||||
};
|
||||
|
||||
let self_for_traits2 = self_for_traits.clone();
|
||||
let krate = dollar_crate(span);
|
||||
let krate2 = krate.clone();
|
||||
let dispatch_from_dyn = expand_simple_derive_with_parsed(
|
||||
span,
|
||||
info.clone(),
|
||||
quote! {span => #krate2::ops::DispatchFromDyn<#self_for_traits2> },
|
||||
|_adt| quote! {span => },
|
||||
false,
|
||||
quote! {span => __S },
|
||||
);
|
||||
let coerce_unsized = expand_simple_derive_with_parsed(
|
||||
span,
|
||||
info,
|
||||
quote! {span => #krate::ops::CoerceUnsized<#self_for_traits> },
|
||||
|_adt| quote! {span => },
|
||||
false,
|
||||
quote! {span => __S },
|
||||
);
|
||||
return ExpandResult::ok(quote! {span => #dispatch_from_dyn #coerce_unsized });
|
||||
|
||||
fn is_maybe_sized_bound(bound: ast::TypeBound) -> bool {
|
||||
if bound.question_mark_token().is_none() {
|
||||
return false;
|
||||
}
|
||||
let Some(ast::Type::PathType(ty)) = bound.ty() else {
|
||||
return false;
|
||||
};
|
||||
let Some(path) = ty.path() else {
|
||||
return false;
|
||||
};
|
||||
return segments_eq(&path, &["Sized"])
|
||||
|| segments_eq(&path, &["core", "marker", "Sized"])
|
||||
|| segments_eq(&path, &["std", "marker", "Sized"]);
|
||||
|
||||
fn segments_eq(path: &ast::Path, expected: &[&str]) -> bool {
|
||||
path.segments().zip_longest(expected.iter().copied()).all(|value| {
|
||||
value.both().is_some_and(|(segment, expected)| {
|
||||
segment.name_ref().is_some_and(|name| name.text() == expected)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if any substitution was performed.
|
||||
fn substitute_type_in_bound(ty: ast::Type, param_name: &str, replacement: &str) -> bool {
|
||||
return match ty {
|
||||
ast::Type::ArrayType(ty) => {
|
||||
ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
|
||||
}
|
||||
ast::Type::DynTraitType(ty) => go_bounds(ty.type_bound_list(), param_name, replacement),
|
||||
ast::Type::FnPtrType(ty) => any_long(
|
||||
ty.param_list()
|
||||
.into_iter()
|
||||
.flat_map(|params| params.params().filter_map(|param| param.ty()))
|
||||
.chain(ty.ret_type().and_then(|it| it.ty())),
|
||||
|ty| substitute_type_in_bound(ty, param_name, replacement),
|
||||
),
|
||||
ast::Type::ForType(ty) => {
|
||||
ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
|
||||
}
|
||||
ast::Type::ImplTraitType(ty) => {
|
||||
go_bounds(ty.type_bound_list(), param_name, replacement)
|
||||
}
|
||||
ast::Type::ParenType(ty) => {
|
||||
ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
|
||||
}
|
||||
ast::Type::PathType(ty) => ty.path().is_some_and(|path| {
|
||||
if path.as_single_name_ref().is_some_and(|name| name.text() == param_name) {
|
||||
ted::replace(
|
||||
path.syntax(),
|
||||
make::path_from_segments(
|
||||
[make::path_segment(make::name_ref(replacement))],
|
||||
false,
|
||||
)
|
||||
.clone_for_update()
|
||||
.syntax(),
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
any_long(
|
||||
path.segments()
|
||||
.filter_map(|segment| segment.generic_arg_list())
|
||||
.flat_map(|it| it.generic_args())
|
||||
.filter_map(|generic_arg| match generic_arg {
|
||||
ast::GenericArg::TypeArg(ty) => ty.ty(),
|
||||
_ => None,
|
||||
}),
|
||||
|ty| substitute_type_in_bound(ty, param_name, replacement),
|
||||
)
|
||||
}),
|
||||
ast::Type::PtrType(ty) => {
|
||||
ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
|
||||
}
|
||||
ast::Type::RefType(ty) => {
|
||||
ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
|
||||
}
|
||||
ast::Type::SliceType(ty) => {
|
||||
ty.ty().is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
|
||||
}
|
||||
ast::Type::TupleType(ty) => {
|
||||
any_long(ty.fields(), |ty| substitute_type_in_bound(ty, param_name, replacement))
|
||||
}
|
||||
ast::Type::InferType(_) | ast::Type::MacroType(_) | ast::Type::NeverType(_) => false,
|
||||
};
|
||||
|
||||
fn go_bounds(
|
||||
bounds: Option<ast::TypeBoundList>,
|
||||
param_name: &str,
|
||||
replacement: &str,
|
||||
) -> bool {
|
||||
bounds.is_some_and(|bounds| {
|
||||
any_long(bounds.bounds(), |bound| {
|
||||
bound
|
||||
.ty()
|
||||
.is_some_and(|ty| substitute_type_in_bound(ty, param_name, replacement))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Like [`Iterator::any()`], but not short-circuiting.
|
||||
fn any_long<I: Iterator, F: FnMut(I::Item) -> bool>(iter: I, mut f: F) -> bool {
|
||||
let mut result = false;
|
||||
iter.for_each(|item| result |= f(item));
|
||||
result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -78,7 +78,7 @@ pub(crate) fn move_bounds_to_where_clause(
|
|||
|
||||
fn build_predicate(param: ast::TypeParam) -> Option<ast::WherePred> {
|
||||
let path = make::ext::ident_path(¶m.name()?.syntax().to_string());
|
||||
let predicate = make::where_pred(path, param.type_bound_list()?.bounds());
|
||||
let predicate = make::where_pred(make::ty_path(path), param.type_bound_list()?.bounds());
|
||||
Some(predicate.clone_for_update())
|
||||
}
|
||||
|
||||
|
|
|
@ -361,6 +361,7 @@ define_symbols! {
|
|||
partial_ord,
|
||||
PartialEq,
|
||||
PartialOrd,
|
||||
CoercePointee,
|
||||
path,
|
||||
Pending,
|
||||
phantom_data,
|
||||
|
|
|
@ -336,6 +336,24 @@ pub fn path_segment(name_ref: ast::NameRef) -> ast::PathSegment {
|
|||
ast_from_text(&format!("type __ = {name_ref};"))
|
||||
}
|
||||
|
||||
/// Type and expressions/patterns path differ in whether they require `::` before generic arguments.
|
||||
/// Type paths allow them but they are often omitted, while expression/pattern paths require them.
|
||||
pub fn generic_ty_path_segment(
|
||||
name_ref: ast::NameRef,
|
||||
generic_args: impl IntoIterator<Item = ast::GenericArg>,
|
||||
) -> ast::PathSegment {
|
||||
let mut generic_args = generic_args.into_iter();
|
||||
let first_generic_arg = generic_args.next();
|
||||
quote! {
|
||||
PathSegment {
|
||||
#name_ref
|
||||
GenericArgList {
|
||||
[<] #first_generic_arg #([,] " " #generic_args)* [>]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn path_segment_ty(type_ref: ast::Type, trait_ref: Option<ast::PathType>) -> ast::PathSegment {
|
||||
let text = match trait_ref {
|
||||
Some(trait_ref) => format!("fn f(x: <{type_ref} as {trait_ref}>) {{}}"),
|
||||
|
@ -814,7 +832,7 @@ pub fn match_arm_list(arms: impl IntoIterator<Item = ast::MatchArm>) -> ast::Mat
|
|||
}
|
||||
|
||||
pub fn where_pred(
|
||||
path: ast::Path,
|
||||
path: ast::Type,
|
||||
bounds: impl IntoIterator<Item = ast::TypeBound>,
|
||||
) -> ast::WherePred {
|
||||
let bounds = bounds.into_iter().join(" + ");
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
//! builtin_impls:
|
||||
//! cell: copy, drop
|
||||
//! clone: sized
|
||||
//! coerce_pointee: derive, sized, unsize, coerce_unsized, dispatch_from_dyn
|
||||
//! coerce_unsized: unsize
|
||||
//! concat:
|
||||
//! copy: clone
|
||||
|
@ -157,6 +158,14 @@ pub mod marker {
|
|||
type Discriminant;
|
||||
}
|
||||
// endregion:discriminant
|
||||
|
||||
// region:coerce_pointee
|
||||
#[rustc_builtin_macro(CoercePointee, attributes(pointee))]
|
||||
#[allow_internal_unstable(dispatch_from_dyn, coerce_unsized, unsize)]
|
||||
pub macro CoercePointee($item:item) {
|
||||
/* compiler built-in */
|
||||
}
|
||||
// endregion:coerce_pointee
|
||||
}
|
||||
|
||||
// region:default
|
||||
|
|
Loading…
Reference in a new issue