diff --git a/crates/hir_def/src/item_scope.rs b/crates/hir_def/src/item_scope.rs index 51586142e2..9f543b79b5 100644 --- a/crates/hir_def/src/item_scope.rs +++ b/crates/hir_def/src/item_scope.rs @@ -8,6 +8,7 @@ use hir_expand::{name::Name, AstId, MacroCallId, MacroDefKind}; use once_cell::sync::Lazy; use profile::Count; use rustc_hash::{FxHashMap, FxHashSet}; +use smallvec::SmallVec; use stdx::format_to; use syntax::ast; @@ -61,7 +62,7 @@ pub struct ItemScope { // be all resolved to the last one defined if shadowing happens. legacy_macros: FxHashMap, attr_macros: FxHashMap, MacroCallId>, - derive_macros: FxHashMap, (AttrId, MacroCallId)>, + derive_macros: FxHashMap, SmallVec<[(AttrId, MacroCallId); 1]>>, } pub(crate) static BUILTIN_SCOPE: Lazy> = Lazy::new(|| { @@ -189,13 +190,13 @@ impl ItemScope { call: MacroCallId, attr_id: AttrId, ) { - self.derive_macros.insert(item, (attr_id, call)); + self.derive_macros.entry(item).or_default().push((attr_id, call)); } pub(crate) fn derive_macro_invocs( &self, ) -> impl Iterator, (AttrId, MacroCallId))> + '_ { - self.derive_macros.iter().map(|(k, v)| (*k, *v)) + self.derive_macros.iter().flat_map(|(k, v)| v.iter().map(move |v| (*k, *v))) } pub(crate) fn unnamed_trait_vis(&self, tr: TraitId) -> Option { diff --git a/crates/ide/src/expand_macro.rs b/crates/ide/src/expand_macro.rs index 363b77967c..889b3748ad 100644 --- a/crates/ide/src/expand_macro.rs +++ b/crates/ide/src/expand_macro.rs @@ -3,7 +3,7 @@ use std::iter; use hir::Semantics; use ide_db::{helpers::pick_best_token, RootDatabase}; use itertools::Itertools; -use syntax::{ast, ted, AstNode, NodeOrToken, SyntaxKind, SyntaxKind::*, SyntaxNode, WalkEvent, T}; +use syntax::{ast, ted, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, WalkEvent, T}; use crate::FilePosition; @@ -31,21 +31,22 @@ pub(crate) fn expand_macro(db: &RootDatabase, position: FilePosition) -> Option< SyntaxKind::IDENT => 1, _ => 0, })?; + let descended = sema.descend_into_macros(tok.clone()); + if let Some(attr) = descended.ancestors().find_map(ast::Attr::cast) { + if let Some((path, tt)) = attr.as_simple_call() { + if path == "derive" { + let mut tt = tt.syntax().children_with_tokens().skip(1).join(""); + tt.pop(); + return sema + .expand_derive_macro(&attr) + .map(insert_whitespaces) + .map(|expansion| ExpandedMacro { name: tt, expansion }); + } + } + } let mut expanded = None; let mut name = None; for node in tok.ancestors() { - if let Some(attr) = ast::Attr::cast(node.clone()) { - if let Some((path, tt)) = attr.as_simple_call() { - if path == "derive" { - let mut tt = tt.syntax().children_with_tokens().skip(1).join(""); - tt.pop(); - name = Some(tt); - expanded = sema.expand_derive_macro(&attr); - break; - } - } - } - if let Some(item) = ast::Item::cast(node.clone()) { if let Some(def) = sema.resolve_attr_macro_call(&item) { name = def.name(db).map(|name| name.to_string()); @@ -53,7 +54,6 @@ pub(crate) fn expand_macro(db: &RootDatabase, position: FilePosition) -> Option< break; } } - if let Some(mac) = ast::MacroCall::cast(node) { name = Some(mac.path()?.segment()?.name_ref()?.to_string()); expanded = expand_macro_recur(&sema, &mac); @@ -107,24 +107,26 @@ fn expand( // FIXME: It would also be cool to share logic here and in the mbe tests, // which are pretty unreadable at the moment. fn insert_whitespaces(syn: SyntaxNode) -> String { + use SyntaxKind::*; let mut res = String::new(); - let mut token_iter = syn - .preorder_with_tokens() - .filter_map(|event| { - if let WalkEvent::Enter(NodeOrToken::Token(token)) = event { - Some(token) - } else { - None - } - }) - .peekable(); let mut indent = 0; let mut last: Option = None; - while let Some(token) = token_iter.next() { - let mut is_next = |f: fn(SyntaxKind) -> bool, default| -> bool { - token_iter.peek().map(|it| f(it.kind())).unwrap_or(default) + for event in syn.preorder_with_tokens() { + let token = match event { + WalkEvent::Enter(NodeOrToken::Token(token)) => token, + WalkEvent::Leave(NodeOrToken::Node(node)) + if matches!(node.kind(), ATTR | MATCH_ARM | STRUCT | ENUM | UNION | FN | IMPL) => + { + res.push('\n'); + res.extend(iter::repeat(" ").take(2 * indent)); + continue; + } + _ => continue, + }; + let is_next = |f: fn(SyntaxKind) -> bool, default| -> bool { + token.next_token().map(|it| f(it.kind())).unwrap_or(default) }; let is_last = |f: fn(SyntaxKind) -> bool, default| -> bool { last.map(f).unwrap_or(default) }; @@ -182,6 +184,7 @@ mod tests { use crate::fixture; + #[track_caller] fn check(ra_fixture: &str, expect: Expect) { let (analysis, pos) = fixture::position(ra_fixture); let expansion = analysis.expand_macro(pos).unwrap().unwrap(); @@ -207,6 +210,7 @@ f$0oo!(); expect![[r#" foo fn b(){} + "#]], ); } @@ -226,11 +230,12 @@ macro_rules! foo { f$0oo!(); "#, expect![[r#" - foo - fn some_thing() -> u32 { - let a = 0; - a+10 - }"#]], + foo + fn some_thing() -> u32 { + let a = 0; + a+10 + } + "#]], ); } @@ -343,7 +348,6 @@ fn main() { fn macro_expand_derive() { check( r#" - #[rustc_builtin_macro] pub macro Clone {} @@ -353,6 +357,28 @@ struct Foo {} expect![[r#" Clone impl< >crate::clone::Clone for Foo< >{} + + "#]], + ); + } + + #[test] + fn macro_expand_derive2() { + check( + r#" +#[rustc_builtin_macro] +pub macro Clone {} +#[rustc_builtin_macro] +pub macro Copy {} + +#[derive(Cop$0y)] +#[derive(Clone)] +struct Foo {} +"#, + expect![[r#" + Copy + impl< >crate::marker::Copy for Foo< >{} + "#]], ); }