Fix macro expansion for statements w/o semicolon

This commit is contained in:
Edwin Cheng 2021-03-16 13:44:50 +08:00
parent c0a2b4e826
commit 8e07b23b84
10 changed files with 99 additions and 61 deletions

4
Cargo.lock generated
View file

@ -1811,9 +1811,9 @@ checksum = "56dee185309b50d1f11bfedef0fe6d036842e3fb77413abef29f8f8d1c5d4c1c"
[[package]] [[package]]
name = "ungrammar" name = "ungrammar"
version = "1.11.0" version = "1.12.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "84c629795d377049f2a1dc5f42cf505dc5ba8b28a5df0a03f4183a24480e4a6a" checksum = "df6586a7c530704efe803d49a0b4132dcbdb4063163df39110548e6b5f2373ba"
[[package]] [[package]]
name = "unicase" name = "unicase"

View file

@ -519,7 +519,7 @@ impl ExprCollector<'_> {
} }
ast::Expr::MacroCall(e) => { ast::Expr::MacroCall(e) => {
let mut ids = vec![]; let mut ids = vec![];
self.collect_macro_call(e, syntax_ptr.clone(), |this, expansion| { self.collect_macro_call(e, syntax_ptr.clone(), true, |this, expansion| {
ids.push(match expansion { ids.push(match expansion {
Some(it) => this.collect_expr(it), Some(it) => this.collect_expr(it),
None => this.alloc_expr(Expr::Missing, syntax_ptr.clone()), None => this.alloc_expr(Expr::Missing, syntax_ptr.clone()),
@ -527,6 +527,17 @@ impl ExprCollector<'_> {
}); });
ids[0] ids[0]
} }
ast::Expr::MacroStmts(e) => {
// FIXME: these statements should be held by some hir containter
for stmt in e.statements() {
self.collect_stmt(stmt);
}
if let Some(expr) = e.expr() {
self.collect_expr(expr)
} else {
self.alloc_expr(Expr::Missing, syntax_ptr)
}
}
} }
} }
@ -534,6 +545,7 @@ impl ExprCollector<'_> {
&mut self, &mut self,
e: ast::MacroCall, e: ast::MacroCall,
syntax_ptr: AstPtr<ast::Expr>, syntax_ptr: AstPtr<ast::Expr>,
is_error_recoverable: bool,
mut collector: F, mut collector: F,
) { ) {
// File containing the macro call. Expansion errors will be attached here. // File containing the macro call. Expansion errors will be attached here.
@ -567,7 +579,7 @@ impl ExprCollector<'_> {
Some((mark, expansion)) => { Some((mark, expansion)) => {
// FIXME: Statements are too complicated to recover from error for now. // FIXME: Statements are too complicated to recover from error for now.
// It is because we don't have any hygiene for local variable expansion right now. // It is because we don't have any hygiene for local variable expansion right now.
if T::can_cast(syntax::SyntaxKind::MACRO_STMTS) && res.err.is_some() { if !is_error_recoverable && res.err.is_some() {
self.expander.exit(self.db, mark); self.expander.exit(self.db, mark);
collector(self, None); collector(self, None);
} else { } else {
@ -591,8 +603,7 @@ impl ExprCollector<'_> {
} }
fn collect_stmt(&mut self, s: ast::Stmt) -> Option<Vec<Statement>> { fn collect_stmt(&mut self, s: ast::Stmt) -> Option<Vec<Statement>> {
let stmt = let stmt = match s {
match s {
ast::Stmt::LetStmt(stmt) => { ast::Stmt::LetStmt(stmt) => {
self.check_cfg(&stmt)?; self.check_cfg(&stmt)?;
@ -609,7 +620,7 @@ impl ExprCollector<'_> {
let syntax_ptr = AstPtr::new(&stmt.expr().unwrap()); let syntax_ptr = AstPtr::new(&stmt.expr().unwrap());
let mut stmts = vec![]; let mut stmts = vec![];
self.collect_macro_call(m, syntax_ptr.clone(), |this, expansion| { self.collect_macro_call(m, syntax_ptr.clone(), false, |this, expansion| {
match expansion { match expansion {
Some(expansion) => { Some(expansion) => {
let statements: ast::MacroStmts = expansion; let statements: ast::MacroStmts = expansion;

View file

@ -110,6 +110,11 @@ impl ItemTree {
// still need to collect inner items. // still need to collect inner items.
ctx.lower_inner_items(e.syntax()) ctx.lower_inner_items(e.syntax())
}, },
ast::ExprStmt(stmt) => {
// Macros can expand to stmt. We return an empty item tree in this case, but
// still need to collect inner items.
ctx.lower_inner_items(stmt.syntax())
},
_ => { _ => {
panic!("cannot create item tree from {:?} {}", syntax, syntax); panic!("cannot create item tree from {:?} {}", syntax, syntax);
}, },

View file

@ -401,13 +401,14 @@ fn to_fragment_kind(db: &dyn AstDatabase, id: MacroCallId) -> FragmentKind {
match parent.kind() { match parent.kind() {
MACRO_ITEMS | SOURCE_FILE => FragmentKind::Items, MACRO_ITEMS | SOURCE_FILE => FragmentKind::Items,
MACRO_STMTS => FragmentKind::Statement,
ITEM_LIST => FragmentKind::Items, ITEM_LIST => FragmentKind::Items,
LET_STMT => { LET_STMT => {
// FIXME: Handle Pattern // FIXME: Handle Pattern
FragmentKind::Expr FragmentKind::Expr
} }
EXPR_STMT => FragmentKind::Statements, EXPR_STMT => FragmentKind::Statements,
BLOCK_EXPR => FragmentKind::Expr, BLOCK_EXPR => FragmentKind::Statements,
ARG_LIST => FragmentKind::Expr, ARG_LIST => FragmentKind::Expr,
TRY_EXPR => FragmentKind::Expr, TRY_EXPR => FragmentKind::Expr,
TUPLE_EXPR => FragmentKind::Expr, TUPLE_EXPR => FragmentKind::Expr,

View file

@ -215,6 +215,22 @@ fn expr_macro_expanded_in_various_places() {
); );
} }
#[test]
fn expr_macro_expanded_in_stmts() {
check_infer(
r#"
macro_rules! id { ($($es:tt)*) => { $($es)* } }
fn foo() {
id! { let a = (); }
}
"#,
expect![[r#"
!0..8 'leta=();': ()
57..84 '{ ...); } }': ()
"#]],
);
}
#[test] #[test]
fn infer_type_value_macro_having_same_name() { fn infer_type_value_macro_having_same_name() {
check_infer( check_infer(

View file

@ -662,7 +662,6 @@ fn test_tt_to_stmts() {
LITERAL@12..13 LITERAL@12..13
INT_NUMBER@12..13 "1" INT_NUMBER@12..13 "1"
SEMICOLON@13..14 ";" SEMICOLON@13..14 ";"
EXPR_STMT@14..15
PATH_EXPR@14..15 PATH_EXPR@14..15
PATH@14..15 PATH@14..15
PATH_SEGMENT@14..15 PATH_SEGMENT@14..15

View file

@ -63,11 +63,11 @@ pub(crate) mod fragments {
} }
pub(crate) fn stmt(p: &mut Parser) { pub(crate) fn stmt(p: &mut Parser) {
expressions::stmt(p, expressions::StmtWithSemi::No) expressions::stmt(p, expressions::StmtWithSemi::No, true)
} }
pub(crate) fn stmt_optional_semi(p: &mut Parser) { pub(crate) fn stmt_optional_semi(p: &mut Parser) {
expressions::stmt(p, expressions::StmtWithSemi::Optional) expressions::stmt(p, expressions::StmtWithSemi::Optional, false)
} }
pub(crate) fn opt_visibility(p: &mut Parser) { pub(crate) fn opt_visibility(p: &mut Parser) {
@ -133,7 +133,7 @@ pub(crate) mod fragments {
continue; continue;
} }
expressions::stmt(p, expressions::StmtWithSemi::Optional); expressions::stmt(p, expressions::StmtWithSemi::Optional, true);
} }
m.complete(p, MACRO_STMTS); m.complete(p, MACRO_STMTS);

View file

@ -54,7 +54,7 @@ fn is_expr_stmt_attr_allowed(kind: SyntaxKind) -> bool {
!forbid !forbid
} }
pub(super) fn stmt(p: &mut Parser, with_semi: StmtWithSemi) { pub(super) fn stmt(p: &mut Parser, with_semi: StmtWithSemi, prefer_expr: bool) {
let m = p.start(); let m = p.start();
// test attr_on_expr_stmt // test attr_on_expr_stmt
// fn foo() { // fn foo() {
@ -90,7 +90,7 @@ pub(super) fn stmt(p: &mut Parser, with_semi: StmtWithSemi) {
p.error(format!("attributes are not allowed on {:?}", kind)); p.error(format!("attributes are not allowed on {:?}", kind));
} }
if p.at(T!['}']) { if p.at(T!['}']) || (prefer_expr && p.at(EOF)) {
// test attr_on_last_expr_in_block // test attr_on_last_expr_in_block
// fn foo() { // fn foo() {
// { #[A] bar!()? } // { #[A] bar!()? }
@ -198,7 +198,7 @@ pub(super) fn expr_block_contents(p: &mut Parser) {
continue; continue;
} }
stmt(p, StmtWithSemi::Yes) stmt(p, StmtWithSemi::Yes, false)
} }
} }

View file

@ -1336,6 +1336,7 @@ pub enum Expr {
Literal(Literal), Literal(Literal),
LoopExpr(LoopExpr), LoopExpr(LoopExpr),
MacroCall(MacroCall), MacroCall(MacroCall),
MacroStmts(MacroStmts),
MatchExpr(MatchExpr), MatchExpr(MatchExpr),
MethodCallExpr(MethodCallExpr), MethodCallExpr(MethodCallExpr),
ParenExpr(ParenExpr), ParenExpr(ParenExpr),
@ -3034,6 +3035,9 @@ impl From<LoopExpr> for Expr {
impl From<MacroCall> for Expr { impl From<MacroCall> for Expr {
fn from(node: MacroCall) -> Expr { Expr::MacroCall(node) } fn from(node: MacroCall) -> Expr { Expr::MacroCall(node) }
} }
impl From<MacroStmts> for Expr {
fn from(node: MacroStmts) -> Expr { Expr::MacroStmts(node) }
}
impl From<MatchExpr> for Expr { impl From<MatchExpr> for Expr {
fn from(node: MatchExpr) -> Expr { Expr::MatchExpr(node) } fn from(node: MatchExpr) -> Expr { Expr::MatchExpr(node) }
} }
@ -3078,8 +3082,8 @@ impl AstNode for Expr {
match kind { match kind {
ARRAY_EXPR | AWAIT_EXPR | BIN_EXPR | BLOCK_EXPR | BOX_EXPR | BREAK_EXPR | CALL_EXPR ARRAY_EXPR | AWAIT_EXPR | BIN_EXPR | BLOCK_EXPR | BOX_EXPR | BREAK_EXPR | CALL_EXPR
| CAST_EXPR | CLOSURE_EXPR | CONTINUE_EXPR | EFFECT_EXPR | FIELD_EXPR | FOR_EXPR | CAST_EXPR | CLOSURE_EXPR | CONTINUE_EXPR | EFFECT_EXPR | FIELD_EXPR | FOR_EXPR
| IF_EXPR | INDEX_EXPR | LITERAL | LOOP_EXPR | MACRO_CALL | MATCH_EXPR | IF_EXPR | INDEX_EXPR | LITERAL | LOOP_EXPR | MACRO_CALL | MACRO_STMTS
| METHOD_CALL_EXPR | PAREN_EXPR | PATH_EXPR | PREFIX_EXPR | RANGE_EXPR | MATCH_EXPR | METHOD_CALL_EXPR | PAREN_EXPR | PATH_EXPR | PREFIX_EXPR | RANGE_EXPR
| RECORD_EXPR | REF_EXPR | RETURN_EXPR | TRY_EXPR | TUPLE_EXPR | WHILE_EXPR | RECORD_EXPR | REF_EXPR | RETURN_EXPR | TRY_EXPR | TUPLE_EXPR | WHILE_EXPR
| YIELD_EXPR => true, | YIELD_EXPR => true,
_ => false, _ => false,
@ -3105,6 +3109,7 @@ impl AstNode for Expr {
LITERAL => Expr::Literal(Literal { syntax }), LITERAL => Expr::Literal(Literal { syntax }),
LOOP_EXPR => Expr::LoopExpr(LoopExpr { syntax }), LOOP_EXPR => Expr::LoopExpr(LoopExpr { syntax }),
MACRO_CALL => Expr::MacroCall(MacroCall { syntax }), MACRO_CALL => Expr::MacroCall(MacroCall { syntax }),
MACRO_STMTS => Expr::MacroStmts(MacroStmts { syntax }),
MATCH_EXPR => Expr::MatchExpr(MatchExpr { syntax }), MATCH_EXPR => Expr::MatchExpr(MatchExpr { syntax }),
METHOD_CALL_EXPR => Expr::MethodCallExpr(MethodCallExpr { syntax }), METHOD_CALL_EXPR => Expr::MethodCallExpr(MethodCallExpr { syntax }),
PAREN_EXPR => Expr::ParenExpr(ParenExpr { syntax }), PAREN_EXPR => Expr::ParenExpr(ParenExpr { syntax }),
@ -3142,6 +3147,7 @@ impl AstNode for Expr {
Expr::Literal(it) => &it.syntax, Expr::Literal(it) => &it.syntax,
Expr::LoopExpr(it) => &it.syntax, Expr::LoopExpr(it) => &it.syntax,
Expr::MacroCall(it) => &it.syntax, Expr::MacroCall(it) => &it.syntax,
Expr::MacroStmts(it) => &it.syntax,
Expr::MatchExpr(it) => &it.syntax, Expr::MatchExpr(it) => &it.syntax,
Expr::MethodCallExpr(it) => &it.syntax, Expr::MethodCallExpr(it) => &it.syntax,
Expr::ParenExpr(it) => &it.syntax, Expr::ParenExpr(it) => &it.syntax,

View file

@ -11,7 +11,7 @@ anyhow = "1.0.26"
flate2 = "1.0" flate2 = "1.0"
proc-macro2 = "1.0.8" proc-macro2 = "1.0.8"
quote = "1.0.2" quote = "1.0.2"
ungrammar = "=1.11" ungrammar = "=1.12"
walkdir = "2.3.1" walkdir = "2.3.1"
write-json = "0.1.0" write-json = "0.1.0"
xshell = "0.1" xshell = "0.1"