diff --git a/crates/base_db/src/fixture.rs b/crates/base_db/src/fixture.rs index 017684fc83..58b55e48c1 100644 --- a/crates/base_db/src/fixture.rs +++ b/crates/base_db/src/fixture.rs @@ -270,6 +270,10 @@ fn test_proc_macros(proc_macros: &[String]) -> (Vec, String) { pub fn identity(_attr: TokenStream, item: TokenStream) -> TokenStream { item } +#[proc_macro_derive(derive_identity)] +pub fn derive_identity(item: TokenStream) -> TokenStream { + item +} #[proc_macro_attribute] pub fn input_replace(attr: TokenStream, _item: TokenStream) -> TokenStream { attr @@ -285,6 +289,11 @@ pub fn mirror(input: TokenStream) -> TokenStream { kind: crate::ProcMacroKind::Attr, expander: Arc::new(IdentityProcMacroExpander), }, + ProcMacro { + name: "derive_identity".into(), + kind: crate::ProcMacroKind::CustomDerive, + expander: Arc::new(IdentityProcMacroExpander), + }, ProcMacro { name: "input_replace".into(), kind: crate::ProcMacroKind::Attr, diff --git a/crates/hir_def/src/macro_expansion_tests.rs b/crates/hir_def/src/macro_expansion_tests.rs index 5f8bf44cc9..036d9147a0 100644 --- a/crates/hir_def/src/macro_expansion_tests.rs +++ b/crates/hir_def/src/macro_expansion_tests.rs @@ -12,6 +12,7 @@ mod mbe; mod builtin_fn_macro; mod builtin_derive_macro; +mod proc_macros; use std::{iter, ops::Range}; @@ -28,7 +29,7 @@ use syntax::{ use crate::{ db::DefDatabase, nameres::ModuleSource, resolver::HasResolver, src::HasSource, test_db::TestDB, - AsMacroCall, Lookup, + AdtId, AsMacroCall, Lookup, ModuleDefId, }; #[track_caller] @@ -124,6 +125,16 @@ fn check(ra_fixture: &str, mut expect: Expect) { expanded_text.replace_range(range, &expn_text) } + for decl_id in def_map[local_id].scope.declarations() { + if let ModuleDefId::AdtId(AdtId::StructId(struct_id)) = decl_id { + let src = struct_id.lookup(&db).source(&db); + if src.file_id.is_attr_macro(&db) || src.file_id.is_custom_derive(&db) { + let pp = pretty_print_macro_expansion(src.value.syntax().clone()); + format_to!(expanded_text, "\n{}", pp) + } + } + } + for impl_id in def_map[local_id].scope.impls() { let src = impl_id.lookup(&db).source(&db); if src.file_id.is_builtin_derive(&db).is_some() { diff --git a/crates/hir_def/src/macro_expansion_tests/proc_macros.rs b/crates/hir_def/src/macro_expansion_tests/proc_macros.rs new file mode 100644 index 0000000000..ef8dc3e3b0 --- /dev/null +++ b/crates/hir_def/src/macro_expansion_tests/proc_macros.rs @@ -0,0 +1,53 @@ +//! Tests for user-defined procedural macros. +//! +//! Note `//- proc_macros: identity` fixture metas in tests -- we don't use real +//! proc-macros here, as that would be slow. Instead, we use several hard-coded +//! in-memory macros. +use expect_test::expect; + +use crate::macro_expansion_tests::check; + +#[test] +fn attribute_macro_attr_censoring() { + cov_mark::check!(attribute_macro_attr_censoring); + check( + r#" +//- proc_macros: identity +#[attr1] #[proc_macros::identity] #[attr2] +struct S; +"#, + expect![[r##" +#[attr1] #[proc_macros::identity] #[attr2] +struct S; + +#[attr1] +#[attr2] struct S;"##]], + ); +} + +#[test] +fn derive_censoring() { + cov_mark::check!(derive_censoring); + check( + r#" +//- proc_macros: derive_identity +#[attr1] +#[derive(Foo)] +#[derive(proc_macros::derive_identity)] +#[derive(Bar)] +#[attr2] +struct S; +"#, + expect![[r##" +#[attr1] +#[derive(Foo)] +#[derive(proc_macros::derive_identity)] +#[derive(Bar)] +#[attr2] +struct S; + +#[attr1] +#[derive(Bar)] +#[attr2] struct S;"##]], + ); +} diff --git a/crates/hir_expand/src/db.rs b/crates/hir_expand/src/db.rs index 08ad9ffd94..c87fb6a17d 100644 --- a/crates/hir_expand/src/db.rs +++ b/crates/hir_expand/src/db.rs @@ -319,18 +319,24 @@ fn censor_for_macro_input(loc: &MacroCallLoc, node: &SyntaxNode) -> FxHashSet return None, - MacroCallKind::Derive { derive_attr_index, .. } => ast::Item::cast(node.clone())? - .attrs() - .take(derive_attr_index as usize + 1) - .filter(|attr| attr.simple_name().as_deref() == Some("derive")) - .map(|it| it.syntax().clone()) - .collect(), - MacroCallKind::Attr { invoc_attr_index, .. } => ast::Item::cast(node.clone())? - .attrs() - .nth(invoc_attr_index as usize) - .map(|attr| attr.syntax().clone()) - .into_iter() - .collect(), + MacroCallKind::Derive { derive_attr_index, .. } => { + cov_mark::hit!(derive_censoring); + ast::Item::cast(node.clone())? + .attrs() + .take(derive_attr_index as usize + 1) + .filter(|attr| attr.simple_name().as_deref() == Some("derive")) + .map(|it| it.syntax().clone()) + .collect() + } + MacroCallKind::Attr { invoc_attr_index, .. } => { + cov_mark::hit!(attribute_macro_attr_censoring); + ast::Item::cast(node.clone())? + .attrs() + .nth(invoc_attr_index as usize) + .map(|attr| attr.syntax().clone()) + .into_iter() + .collect() + } }; Some(censor) })() diff --git a/crates/hir_expand/src/lib.rs b/crates/hir_expand/src/lib.rs index de32d02415..1fc74cb3b0 100644 --- a/crates/hir_expand/src/lib.rs +++ b/crates/hir_expand/src/lib.rs @@ -175,6 +175,19 @@ impl HirFileId { } } + pub fn is_custom_derive(&self, db: &dyn db::AstDatabase) -> bool { + match self.0 { + HirFileIdRepr::FileId(_) => false, + HirFileIdRepr::MacroFile(macro_file) => { + let loc: MacroCallLoc = db.lookup_intern_macro(macro_file.macro_call_id); + match loc.def.kind { + MacroDefKind::ProcMacro(_, ProcMacroKind::CustomDerive, _) => true, + _ => false, + } + } + } + } + /// Return whether this file is an include macro pub fn is_include_macro(&self, db: &dyn db::AstDatabase) -> bool { match self.0 { diff --git a/crates/mbe/src/lib.rs b/crates/mbe/src/lib.rs index 61e032e0af..6cd084eaea 100644 --- a/crates/mbe/src/lib.rs +++ b/crates/mbe/src/lib.rs @@ -12,9 +12,6 @@ mod syntax_bridge; mod tt_iter; mod subtree_source; -#[cfg(test)] -mod tests; - #[cfg(test)] mod benchmark; mod token_map; diff --git a/crates/mbe/src/tests.rs b/crates/mbe/src/tests.rs deleted file mode 100644 index f4d7e60595..0000000000 --- a/crates/mbe/src/tests.rs +++ /dev/null @@ -1,55 +0,0 @@ -use std::iter; - -use syntax::{ast, AstNode}; - -use super::*; - -#[test] -fn test_node_to_tt_censor() { - use syntax::ast::{HasAttrs, HasModuleItem}; - - let source = r##" -#[attr0] -#[attr1] -#[attr2] -struct Struct { - field: () -} -"##; - let source_file = ast::SourceFile::parse(source).ok().unwrap(); - let item = source_file.items().next().unwrap(); - let attr = item.attrs().nth(1).unwrap(); - - let (tt, _) = syntax_node_to_token_tree_censored( - item.syntax(), - &iter::once(attr.syntax().clone()).collect(), - ); - expect_test::expect![[r##"# [attr0] # [attr2] struct Struct {field : ()}"##]] - .assert_eq(&tt.to_string()); - - let source = r##" -#[attr0] -#[derive(Derive0)] -#[attr1] -#[derive(Derive1)] -#[attr2] -#[derive(Derive2)] -#[attr3] -struct Struct { - field: () -} -"##; - let source_file = ast::SourceFile::parse(source).ok().unwrap(); - let item = source_file.items().next().unwrap(); - let derive_attr_index = 3; - let censor = item - .attrs() - .take(derive_attr_index as usize + 1) - .filter(|attr| attr.simple_name().as_deref() == Some("derive")) - .map(|it| it.syntax().clone()) - .collect(); - - let (tt, _) = syntax_node_to_token_tree_censored(item.syntax(), &censor); - expect_test::expect![[r##"# [attr0] # [attr1] # [attr2] # [derive (Derive2)] # [attr3] struct Struct {field : ()}"##]] - .assert_eq(&tt.to_string()); -}