Auto merge of #15000 - lowr:fix/builtin-derive-bound-for-assoc, r=HKalbasi

fix: only generate trait bound for associated types in field types

Given the following definitions:

```rust
trait Trait {
    type A;
    type B;
    type C;
}

#[derive(Clone)]
struct S<T: Trait>
where
    T::A: Send,
{
    qualified: <T as Trait>::B,
    shorthand: T::C,
}
```

we currently expand the derive macro to:

```rust
impl<T> Clone for S<T>
where
    T: Trait + Clone,
    T::A: Clone,
    T::B: Clone,
    T::C: Clone,
{ /* ... */ }
```

This does not match how rustc expands it. Specifically, `Clone` bounds for `T::A` and `T::B` should not be generated.

The criteria for associated types to get bound seem to be 1) the associated type appears as part of field types AND 2) it's written in the shorthand form. I have no idea why rustc doesn't consider qualified associated types (there's even a comment that suggests they should be considered; see rust-lang/rust#50730), but it's important to follow rustc.
This commit is contained in:
bors 2023-06-07 13:00:24 +00:00
commit 085a3112ae
3 changed files with 139 additions and 71 deletions

View file

@ -114,6 +114,66 @@ impl <A: core::clone::Clone, B: core::clone::Clone, > core::clone::Clone for Com
); );
} }
#[test]
fn test_clone_expand_with_associated_types() {
check(
r#"
//- minicore: derive, clone
trait Trait {
type InWc;
type InFieldQualified;
type InFieldShorthand;
type InGenericArg;
}
trait Marker {}
struct Vec<T>(T);
#[derive(Clone)]
struct Foo<T: Trait>
where
<T as Trait>::InWc: Marker,
{
qualified: <T as Trait>::InFieldQualified,
shorthand: T::InFieldShorthand,
generic: Vec<T::InGenericArg>,
}
"#,
expect![[r#"
trait Trait {
type InWc;
type InFieldQualified;
type InFieldShorthand;
type InGenericArg;
}
trait Marker {}
struct Vec<T>(T);
#[derive(Clone)]
struct Foo<T: Trait>
where
<T as Trait>::InWc: Marker,
{
qualified: <T as Trait>::InFieldQualified,
shorthand: T::InFieldShorthand,
generic: Vec<T::InGenericArg>,
}
impl <T: core::clone::Clone, > core::clone::Clone for Foo<T, > where T: Trait, T::InFieldShorthand: core::clone::Clone, T::InGenericArg: core::clone::Clone, {
fn clone(&self ) -> Self {
match self {
Foo {
qualified: qualified, shorthand: shorthand, generic: generic,
}
=>Foo {
qualified: qualified.clone(), shorthand: shorthand.clone(), generic: generic.clone(),
}
,
}
}
}"#]],
);
}
#[test] #[test]
fn test_clone_expand_with_const_generics() { fn test_clone_expand_with_const_generics() {
check( check(

View file

@ -4,17 +4,16 @@ use ::tt::Ident;
use base_db::{CrateOrigin, LangCrateOrigin}; use base_db::{CrateOrigin, LangCrateOrigin};
use itertools::izip; use itertools::izip;
use mbe::TokenMap; use mbe::TokenMap;
use std::collections::HashSet; use rustc_hash::FxHashSet;
use stdx::never; use stdx::never;
use tracing::debug; use tracing::debug;
use crate::tt::{self, TokenId}; use crate::{
use syntax::{ name::{AsName, Name},
ast::{ tt::{self, TokenId},
self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, };
HasTypeBounds, PathType, use syntax::ast::{
}, self, AstNode, FieldList, HasAttrs, HasGenericParams, HasModuleItem, HasName, HasTypeBounds,
match_ast,
}; };
use crate::{db::ExpandDatabase, name, quote, ExpandError, ExpandResult, MacroCallId}; use crate::{db::ExpandDatabase, name, quote, ExpandError, ExpandResult, MacroCallId};
@ -201,33 +200,46 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
debug!("no module item parsed"); debug!("no module item parsed");
ExpandError::Other("no item found".into()) ExpandError::Other("no item found".into())
})?; })?;
let node = item.syntax(); let adt = ast::Adt::cast(item.syntax().clone()).ok_or_else(|| {
let (name, params, shape) = match_ast! { debug!("expected adt, found: {:?}", item);
match node { ExpandError::Other("expected struct, enum or union".into())
ast::Struct(it) => (it.name(), it.generic_param_list(), AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?)), })?;
ast::Enum(it) => { let (name, generic_param_list, shape) = match &adt {
let default_variant = it.variant_list().into_iter().flat_map(|x| x.variants()).position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into()))); ast::Adt::Struct(it) => (
it.name(),
it.generic_param_list(),
AdtShape::Struct(VariantShape::from(it.field_list(), &token_map)?),
),
ast::Adt::Enum(it) => {
let default_variant = it
.variant_list()
.into_iter()
.flat_map(|x| x.variants())
.position(|x| x.attrs().any(|x| x.simple_name() == Some("default".into())));
( (
it.name(), it.name(),
it.generic_param_list(), it.generic_param_list(),
AdtShape::Enum { AdtShape::Enum {
default_variant, default_variant,
variants: it.variant_list() variants: it
.variant_list()
.into_iter() .into_iter()
.flat_map(|x| x.variants()) .flat_map(|x| x.variants())
.map(|x| Ok((name_to_token(&token_map,x.name())?, VariantShape::from(x.field_list(), &token_map)?))).collect::<Result<_, ExpandError>>()? .map(|x| {
} Ok((
name_to_token(&token_map, x.name())?,
VariantShape::from(x.field_list(), &token_map)?,
))
})
.collect::<Result<_, ExpandError>>()?,
},
) )
},
ast::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
_ => {
debug!("unexpected node is {:?}", node);
return Err(ExpandError::Other("expected struct, enum or union".into()))
},
} }
ast::Adt::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union),
}; };
let mut param_type_set: HashSet<String> = HashSet::new();
let param_types = params let mut param_type_set: FxHashSet<Name> = FxHashSet::default();
let param_types = generic_param_list
.into_iter() .into_iter()
.flat_map(|param_list| param_list.type_or_const_params()) .flat_map(|param_list| param_list.type_or_const_params())
.map(|param| { .map(|param| {
@ -235,7 +247,7 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
let this = param.name(); let this = param.name();
match this { match this {
Some(x) => { Some(x) => {
param_type_set.insert(x.to_string()); param_type_set.insert(x.as_name());
mbe::syntax_node_to_token_tree(x.syntax()).0 mbe::syntax_node_to_token_tree(x.syntax()).0
} }
None => tt::Subtree::empty(), None => tt::Subtree::empty(),
@ -259,37 +271,33 @@ fn parse_adt(tt: &tt::Subtree) -> Result<BasicAdtInfo, ExpandError> {
(name, ty, bounds) (name, ty, bounds)
}) })
.collect(); .collect();
let is_associated_type = |p: &PathType| {
if let Some(p) = p.path() { // For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field
if let Some(parent) = p.qualifier() { // types (of any variant for enums), we generate trait bound for it. It sounds reasonable to
if let Some(x) = parent.segment() { // also generate trait bound for qualified associated type `<T as Trait>::Assoc`, but rustc
if let Some(x) = x.path_type() { // does not do that for some unknown reason.
if let Some(x) = x.path() { //
if let Some(pname) = x.as_single_name_ref() { // See the analogous function in rustc [find_type_parameters()] and rust-lang/rust#50730.
if param_type_set.contains(&pname.to_string()) { // [find_type_parameters()]: https://github.com/rust-lang/rust/blob/1.70.0/compiler/rustc_builtin_macros/src/deriving/generic/mod.rs#L378
// <T as Trait>::Assoc
return true; // It's cumbersome to deal with the distinct structures of ADTs, so let's just get untyped
} // `SyntaxNode` that contains fields and look for descendant `ast::PathType`s. Of note is that
} // we should not inspect `ast::PathType`s in parameter bounds and where clauses.
} let field_list = match adt {
} ast::Adt::Enum(it) => it.variant_list().map(|list| list.syntax().clone()),
} ast::Adt::Struct(it) => it.field_list().map(|list| list.syntax().clone()),
if let Some(pname) = parent.as_single_name_ref() { ast::Adt::Union(it) => it.record_field_list().map(|list| list.syntax().clone()),
if param_type_set.contains(&pname.to_string()) {
// T::Assoc
return true;
}
}
}
}
false
}; };
let associated_types = node let associated_types = field_list
.descendants() .into_iter()
.filter_map(PathType::cast) .flat_map(|it| it.descendants())
.filter(is_associated_type) .filter_map(ast::PathType::cast)
.filter_map(|p| {
let name = p.path()?.qualifier()?.as_single_name_ref()?.as_name();
param_type_set.contains(&name).then_some(p)
})
.map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0) .map(|x| mbe::syntax_node_to_token_tree(x.syntax()).0)
.collect::<Vec<_>>(); .collect();
let name_token = name_to_token(&token_map, name)?; let name_token = name_to_token(&token_map, name)?;
Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types }) Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types })
} }
@ -334,18 +342,18 @@ fn name_to_token(token_map: &TokenMap, name: Option<ast::Name>) -> Result<tt::Id
/// } /// }
/// ``` /// ```
/// ///
/// where B1, ..., BN are the bounds given by `bounds_paths`.'. Z is a phantom type, and /// where B1, ..., BN are the bounds given by `bounds_paths`. Z is a phantom type, and
/// therefore does not get bound by the derived trait. /// therefore does not get bound by the derived trait.
fn expand_simple_derive( fn expand_simple_derive(
tt: &tt::Subtree, tt: &tt::Subtree,
trait_path: tt::Subtree, trait_path: tt::Subtree,
trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree, make_trait_body: impl FnOnce(&BasicAdtInfo) -> tt::Subtree,
) -> ExpandResult<tt::Subtree> { ) -> ExpandResult<tt::Subtree> {
let info = match parse_adt(tt) { let info = match parse_adt(tt) {
Ok(info) => info, Ok(info) => info,
Err(e) => return ExpandResult::new(tt::Subtree::empty(), e), Err(e) => return ExpandResult::new(tt::Subtree::empty(), e),
}; };
let trait_body = trait_body(&info); let trait_body = make_trait_body(&info);
let mut where_block = vec![]; let mut where_block = vec![];
let (params, args): (Vec<_>, Vec<_>) = info let (params, args): (Vec<_>, Vec<_>) = info
.param_types .param_types

View file

@ -4335,8 +4335,9 @@ fn derive_macro_bounds() {
#[derive(Clone)] #[derive(Clone)]
struct AssocGeneric<T: Tr>(T::Assoc); struct AssocGeneric<T: Tr>(T::Assoc);
#[derive(Clone)] // Currently rustc does not accept this.
struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc); // #[derive(Clone)]
// struct AssocGeneric2<T: Tr>(<T as Tr>::Assoc);
#[derive(Clone)] #[derive(Clone)]
struct AssocGeneric3<T: Tr>(Generic<T::Assoc>); struct AssocGeneric3<T: Tr>(Generic<T::Assoc>);
@ -4361,9 +4362,8 @@ fn derive_macro_bounds() {
let x: &AssocGeneric<Copy> = &AssocGeneric(NotCopy); let x: &AssocGeneric<Copy> = &AssocGeneric(NotCopy);
let x = x.clone(); let x = x.clone();
//^ &AssocGeneric<Copy> //^ &AssocGeneric<Copy>
let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy); // let x: &AssocGeneric2<Copy> = &AssocGeneric2(NotCopy);
let x = x.clone(); // let x = x.clone();
//^ &AssocGeneric2<Copy>
let x: &AssocGeneric3<Copy> = &AssocGeneric3(Generic(NotCopy)); let x: &AssocGeneric3<Copy> = &AssocGeneric3(Generic(NotCopy));
let x = x.clone(); let x = x.clone();
//^ &AssocGeneric3<Copy> //^ &AssocGeneric3<Copy>