diff --git a/crates/hir_def/src/macro_expansion_tests.rs b/crates/hir_def/src/macro_expansion_tests.rs index 80747f6c63..16df7ce4cf 100644 --- a/crates/hir_def/src/macro_expansion_tests.rs +++ b/crates/hir_def/src/macro_expansion_tests.rs @@ -345,6 +345,7 @@ impl base_db::ProcMacroExpander for IdentityWhenValidProcMacroExpander { if parse.errors().is_empty() { Ok(subtree.clone()) } else { + eprintln!("parse errors: {:?}", parse.errors()); use tt::{Delimiter, DelimiterKind, Ident, Leaf, Literal, Punct, TokenTree}; let mut subtree = Subtree::default(); subtree.token_trees.push(TokenTree::Leaf( diff --git a/crates/hir_def/src/macro_expansion_tests/proc_macros.rs b/crates/hir_def/src/macro_expansion_tests/proc_macros.rs index dead99a40c..e0c5367cf3 100644 --- a/crates/hir_def/src/macro_expansion_tests/proc_macros.rs +++ b/crates/hir_def/src/macro_expansion_tests/proc_macros.rs @@ -86,6 +86,10 @@ fn foo() { bar.; blub } expect![[r##" #[proc_macros::identity_when_valid] fn foo() { bar.; blub } -"##]], + +fn foo() { + bar.; + blub +}"##]], ); } diff --git a/crates/hir_expand/src/db.rs b/crates/hir_expand/src/db.rs index d9bfdd6fd4..6576701817 100644 --- a/crates/hir_expand/src/db.rs +++ b/crates/hir_expand/src/db.rs @@ -5,8 +5,8 @@ use std::sync::Arc; use base_db::{salsa, SourceDatabase}; use either::Either; use limit::Limit; -use mbe::{syntax_node_to_token_tree, ExpandError, ExpandResult}; -use rustc_hash::FxHashSet; +use mbe::{syntax_node_to_token_tree, ExpandError, ExpandResult, SyntheticToken}; +use rustc_hash::{FxHashMap, FxHashSet}; use syntax::{ algo::diff, ast::{self, HasAttrs, HasDocComments}, @@ -14,7 +14,7 @@ use syntax::{ }; use crate::{ - ast_id_map::AstIdMap, hygiene::HygieneFrame, BuiltinAttrExpander, BuiltinDeriveExpander, + ast_id_map::AstIdMap, fixup, hygiene::HygieneFrame, BuiltinAttrExpander, BuiltinDeriveExpander, BuiltinFnLikeExpander, ExpandTo, HirFileId, HirFileIdRepr, MacroCallId, MacroCallKind, MacroCallLoc, MacroDefId, MacroDefKind, MacroFile, ProcMacroExpander, }; @@ -146,8 +146,10 @@ pub fn expand_speculative( // Build the subtree and token mapping for the speculative args let censor = censor_for_macro_input(&loc, &speculative_args); + let mut fixups = fixup::fixup_syntax(&speculative_args); + fixups.replace.extend(censor.into_iter().map(|node| (node, Vec::new()))); let (mut tt, spec_args_tmap) = - mbe::syntax_node_to_token_tree_censored(&speculative_args, &censor); + mbe::syntax_node_to_token_tree_censored(&speculative_args, fixups.replace, fixups.append); let (attr_arg, token_id) = match loc.kind { MacroCallKind::Attr { invoc_attr_index, .. } => { @@ -294,8 +296,17 @@ fn macro_arg(db: &dyn AstDatabase, id: MacroCallId) -> Option (tt::Subtree, TokenMap) { - syntax_node_to_token_tree_censored(node, &Default::default()) + syntax_node_to_token_tree_censored(node, Default::default(), Default::default()) } +// TODO rename /// Convert the syntax node to a `TokenTree` (what macro will consume) /// with the censored range excluded. pub fn syntax_node_to_token_tree_censored( node: &SyntaxNode, - censor: &FxHashSet, + replace: FxHashMap>, + append: FxHashMap>, ) -> (tt::Subtree, TokenMap) { let global_offset = node.text_range().start(); - let mut c = Convertor::new(node, global_offset, censor); + let mut c = Convertor::new(node, global_offset, replace, append); let subtree = convert_tokens(&mut c); c.id_alloc.map.shrink_to_fit(); (subtree, c.id_alloc.map) } +pub type SyntheticToken = (SyntaxKind, SmolStr); + // The following items are what `rustc` macro can be parsed into : // link: https://github.com/rust-lang/rust/blob/9ebf47851a357faa4cd97f4b1dc7835f6376e639/src/libsyntax/ext/expand.rs#L141 // * Expr(P) -> token_tree_to_expr @@ -465,86 +469,124 @@ impl<'a> TokenConvertor for RawConvertor<'a> { } } -struct Convertor<'c> { +struct Convertor { id_alloc: TokenIdAlloc, current: Option, + current_synthetic: Vec, preorder: PreorderWithTokens, - censor: &'c FxHashSet, + replace: FxHashMap>, + append: FxHashMap>, range: TextRange, punct_offset: Option<(SyntaxToken, TextSize)>, } -impl<'c> Convertor<'c> { +impl Convertor { fn new( node: &SyntaxNode, global_offset: TextSize, - censor: &'c FxHashSet, - ) -> Convertor<'c> { + replace: FxHashMap>, + append: FxHashMap>, + ) -> Convertor { let range = node.text_range(); let mut preorder = node.preorder_with_tokens(); - let first = Self::next_token(&mut preorder, censor); + let (first, synthetic) = Self::next_token(&mut preorder, &replace, &append); Convertor { id_alloc: { TokenIdAlloc { map: TokenMap::default(), global_offset, next_id: 0 } }, current: first, + current_synthetic: synthetic, preorder, range, - censor, + replace, + append, punct_offset: None, } } fn next_token( preorder: &mut PreorderWithTokens, - censor: &FxHashSet, - ) -> Option { + replace: &FxHashMap>, + append: &FxHashMap>, + ) -> (Option, Vec) { while let Some(ev) = preorder.next() { let ele = match ev { WalkEvent::Enter(ele) => ele, + WalkEvent::Leave(SyntaxElement::Node(node)) => { + if let Some(v) = append.get(&node) { + eprintln!("after {:?}, appending {:?}", node, v); + if !v.is_empty() { + let mut reversed = v.clone(); + reversed.reverse(); + return (None, reversed); + } + } + continue; + } _ => continue, }; match ele { - SyntaxElement::Token(t) => return Some(t), - SyntaxElement::Node(node) if censor.contains(&node) => preorder.skip_subtree(), - SyntaxElement::Node(_) => (), + SyntaxElement::Token(t) => return (Some(t), Vec::new()), + SyntaxElement::Node(node) => { + if let Some(v) = replace.get(&node) { + preorder.skip_subtree(); + eprintln!("replacing {:?} by {:?}", node, v); + if !v.is_empty() { + let mut reversed = v.clone(); + reversed.reverse(); + return (None, reversed); + } + } + } } } - None + (None, Vec::new()) } } #[derive(Debug)] enum SynToken { Ordinary(SyntaxToken), + // FIXME is this supposed to be `Punct`? Punch(SyntaxToken, TextSize), + Synthetic(SyntheticToken), } impl SynToken { - fn token(&self) -> &SyntaxToken { + fn token(&self) -> Option<&SyntaxToken> { match self { - SynToken::Ordinary(it) | SynToken::Punch(it, _) => it, + SynToken::Ordinary(it) | SynToken::Punch(it, _) => Some(it), + SynToken::Synthetic(_) => None, } } } -impl<'a> SrcToken> for SynToken { - fn kind(&self, _ctx: &Convertor<'a>) -> SyntaxKind { - self.token().kind() +impl SrcToken for SynToken { + fn kind(&self, _ctx: &Convertor) -> SyntaxKind { + match self { + SynToken::Ordinary(token) => token.kind(), + SynToken::Punch(token, _) => token.kind(), + SynToken::Synthetic((kind, _)) => *kind, + } } - fn to_char(&self, _ctx: &Convertor<'a>) -> Option { + fn to_char(&self, _ctx: &Convertor) -> Option { match self { SynToken::Ordinary(_) => None, SynToken::Punch(it, i) => it.text().chars().nth((*i).into()), + SynToken::Synthetic(_) => None, } } - fn to_text(&self, _ctx: &Convertor<'a>) -> SmolStr { - self.token().text().into() + fn to_text(&self, _ctx: &Convertor) -> SmolStr { + match self { + SynToken::Ordinary(token) => token.text().into(), + SynToken::Punch(token, _) => token.text().into(), + SynToken::Synthetic((_, text)) => text.clone(), + } } } -impl TokenConvertor for Convertor<'_> { +impl TokenConvertor for Convertor { type Token = SynToken; fn convert_doc_comment(&self, token: &Self::Token) -> Option> { - convert_doc_comment(token.token()) + convert_doc_comment(token.token()?) } fn bump(&mut self) -> Option<(Self::Token, TextRange)> { @@ -558,11 +600,25 @@ impl TokenConvertor for Convertor<'_> { } } + if let Some(synth_token) = self.current_synthetic.pop() { + if self.current_synthetic.is_empty() { + let (new_current, new_synth) = + Self::next_token(&mut self.preorder, &self.replace, &self.append); + self.current = new_current; + self.current_synthetic = new_synth; + } + // TODO fix range? + return Some((SynToken::Synthetic(synth_token), self.range)); + } + let curr = self.current.clone()?; if !&self.range.contains_range(curr.text_range()) { return None; } - self.current = Self::next_token(&mut self.preorder, self.censor); + let (new_current, new_synth) = + Self::next_token(&mut self.preorder, &self.replace, &self.append); + self.current = new_current; + self.current_synthetic = new_synth; let token = if curr.kind().is_punct() { self.punct_offset = Some((curr.clone(), 0.into())); let range = curr.text_range(); @@ -585,6 +641,11 @@ impl TokenConvertor for Convertor<'_> { } } + if let Some(synth_token) = self.current_synthetic.last() { + // TODO fix range? + return Some(SynToken::Synthetic(synth_token.clone())); + } + let curr = self.current.clone()?; if !self.range.contains_range(curr.text_range()) { return None;