Correctly resolve variables and labels from before macro definition in macro expansion

E.g.:
```rust
let v;
macro_rules! m { () => { v }; }
```

This was an existing bug, but it was less severe because unless the variable was shadowed it would be correctly resolved. With hygiene however, without this fix the variable is never resolved.
This commit is contained in:
Chayim Refael Friedman 2024-10-22 20:58:25 +03:00
parent 8adcbdcc49
commit 4ac3dc1a2f
12 changed files with 287 additions and 50 deletions

View file

@ -35,7 +35,7 @@ use crate::{
/// A wrapper around [`span::SyntaxContextId`] that is intended only for comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct HygieneId(span::SyntaxContextId);
pub struct HygieneId(pub(crate) span::SyntaxContextId);
impl HygieneId {
pub const ROOT: Self = Self(span::SyntaxContextId::ROOT);
@ -44,7 +44,7 @@ impl HygieneId {
Self(ctx)
}
fn is_root(self) -> bool {
pub(crate) fn is_root(self) -> bool {
self.0.is_root()
}
}
@ -420,7 +420,7 @@ impl Body {
self.walk_exprs_in_pat(*pat, &mut f);
}
Statement::Expr { expr: expression, .. } => f(*expression),
Statement::Item => (),
Statement::Item(_) => (),
}
}
if let &Some(expr) = tail {

View file

@ -10,7 +10,7 @@ use either::Either;
use hir_expand::{
name::{AsName, Name},
span_map::{ExpansionSpanMap, SpanMap},
InFile,
InFile, MacroDefId,
};
use intern::{sym, Interned, Symbol};
use rustc_hash::FxHashMap;
@ -39,8 +39,8 @@ use crate::{
FormatPlaceholder, FormatSign, FormatTrait,
},
Array, Binding, BindingAnnotation, BindingId, BindingProblems, CaptureBy, ClosureKind,
Expr, ExprId, Label, LabelId, Literal, LiteralOrConst, MatchArm, Movability, OffsetOf, Pat,
PatId, RecordFieldPat, RecordLitField, Statement,
Expr, ExprId, Item, Label, LabelId, Literal, LiteralOrConst, MatchArm, Movability,
OffsetOf, Pat, PatId, RecordFieldPat, RecordLitField, Statement,
},
item_scope::BuiltinShadowMode,
lang_item::LangItem,
@ -48,7 +48,7 @@ use crate::{
nameres::{DefMap, MacroSubNs},
path::{GenericArgs, Path},
type_ref::{Mutability, Rawness, TypeRef},
AdtId, BlockId, BlockLoc, ConstBlockLoc, DefWithBodyId, ModuleDefId, UnresolvedMacro,
AdtId, BlockId, BlockLoc, ConstBlockLoc, DefWithBodyId, MacroId, ModuleDefId, UnresolvedMacro,
};
type FxIndexSet<K> = indexmap::IndexSet<K, std::hash::BuildHasherDefault<rustc_hash::FxHasher>>;
@ -88,6 +88,7 @@ pub(super) fn lower(
current_binding_owner: None,
awaitable_context: None,
current_span_map: span_map,
current_block_legacy_macro_defs_count: FxHashMap::default(),
}
.collect(params, body, is_async_fn)
}
@ -104,6 +105,10 @@ struct ExprCollector<'a> {
is_lowering_coroutine: bool,
/// Legacy (`macro_rules!`) macros can have multiple definitions and shadow each other,
/// and we need to find the current definition. So we track the number of definitions we saw.
current_block_legacy_macro_defs_count: FxHashMap<Name, usize>,
current_span_map: Option<Arc<ExpansionSpanMap>>,
current_try_block_label: Option<LabelId>,
@ -124,31 +129,27 @@ struct ExprCollector<'a> {
#[derive(Clone, Debug)]
struct LabelRib {
kind: RibKind,
// Once we handle macro hygiene this will need to be a map
label: Option<(Name, LabelId, HygieneId)>,
}
impl LabelRib {
fn new(kind: RibKind) -> Self {
LabelRib { kind, label: None }
}
fn new_normal(label: (Name, LabelId, HygieneId)) -> Self {
LabelRib { kind: RibKind::Normal, label: Some(label) }
LabelRib { kind }
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq)]
enum RibKind {
Normal,
Normal(Name, LabelId, HygieneId),
Closure,
Constant,
MacroDef(Box<MacroDefId>),
}
impl RibKind {
/// This rib forbids referring to labels defined in upwards ribs.
fn is_label_barrier(self) -> bool {
fn is_label_barrier(&self) -> bool {
match self {
RibKind::Normal => false,
RibKind::Normal(..) | RibKind::MacroDef(_) => false,
RibKind::Closure | RibKind::Constant => true,
}
}
@ -1350,8 +1351,44 @@ impl ExprCollector<'_> {
statements.push(Statement::Expr { expr, has_semi });
}
}
ast::Stmt::Item(_item) => statements.push(Statement::Item),
ast::Stmt::Item(ast::Item::MacroDef(macro_)) => {
let Some(name) = macro_.name() else {
statements.push(Statement::Item(Item::Other));
return;
};
let name = name.as_name();
let macro_id = self.def_map.modules[DefMap::ROOT].scope.get(&name).take_macros();
self.collect_macro_def(statements, macro_id);
}
ast::Stmt::Item(ast::Item::MacroRules(macro_)) => {
let Some(name) = macro_.name() else {
statements.push(Statement::Item(Item::Other));
return;
};
let name = name.as_name();
let macro_defs_count =
self.current_block_legacy_macro_defs_count.entry(name.clone()).or_insert(0);
let macro_id = self.def_map.modules[DefMap::ROOT]
.scope
.get_legacy_macro(&name)
.and_then(|it| it.get(*macro_defs_count))
.copied();
*macro_defs_count += 1;
self.collect_macro_def(statements, macro_id);
}
ast::Stmt::Item(_item) => statements.push(Statement::Item(Item::Other)),
}
}
fn collect_macro_def(&mut self, statements: &mut Vec<Statement>, macro_id: Option<MacroId>) {
let Some(macro_id) = macro_id else {
never!("def map should have macro definition, but it doesn't");
statements.push(Statement::Item(Item::Other));
return;
};
let macro_id = self.db.macro_def(macro_id);
statements.push(Statement::Item(Item::MacroDef(Box::new(macro_id))));
self.label_ribs.push(LabelRib::new(RibKind::MacroDef(Box::new(macro_id))));
}
fn collect_block(&mut self, block: ast::BlockExpr) -> ExprId {
@ -1399,6 +1436,7 @@ impl ExprCollector<'_> {
};
let prev_def_map = mem::replace(&mut self.def_map, def_map);
let prev_local_module = mem::replace(&mut self.expander.module, module);
let prev_legacy_macros_count = mem::take(&mut self.current_block_legacy_macro_defs_count);
let mut statements = Vec::new();
block.statements().for_each(|s| self.collect_stmt(&mut statements, s));
@ -1421,6 +1459,7 @@ impl ExprCollector<'_> {
self.def_map = prev_def_map;
self.expander.module = prev_local_module;
self.current_block_legacy_macro_defs_count = prev_legacy_macros_count;
expr_id
}
@ -1780,12 +1819,25 @@ impl ExprCollector<'_> {
lifetime: Option<ast::Lifetime>,
) -> Result<Option<LabelId>, BodyDiagnostic> {
let Some(lifetime) = lifetime else { return Ok(None) };
let hygiene = self.hygiene_id_for(lifetime.syntax().text_range().start());
let (mut hygiene_id, mut hygiene_info) = match &self.current_span_map {
None => (HygieneId::ROOT, None),
Some(span_map) => {
let span = span_map.span_at(lifetime.syntax().text_range().start());
let ctx = self.db.lookup_intern_syntax_context(span.ctx);
let hygiene_id = HygieneId::new(ctx.opaque_and_semitransparent);
let hygiene_info = ctx.outer_expn.map(|expansion| {
let expansion = self.db.lookup_intern_macro_call(expansion);
(ctx.parent, expansion.def)
});
(hygiene_id, hygiene_info)
}
};
let name = Name::new_lifetime(&lifetime);
for (rib_idx, rib) in self.label_ribs.iter().enumerate().rev() {
if let Some((label_name, id, label_hygiene)) = &rib.label {
if *label_name == name && *label_hygiene == hygiene {
match &rib.kind {
RibKind::Normal(label_name, id, label_hygiene) => {
if *label_name == name && *label_hygiene == hygiene_id {
return if self.is_label_valid_from_rib(rib_idx) {
Ok(Some(*id))
} else {
@ -1796,6 +1848,23 @@ impl ExprCollector<'_> {
};
}
}
RibKind::MacroDef(macro_id) => {
if let Some((parent_ctx, label_macro_id)) = hygiene_info {
if label_macro_id == **macro_id {
// A macro is allowed to refer to labels from before its declaration.
// Therefore, if we got to the rib of its declaration, give up its hygiene
// and use its parent expansion.
let parent_ctx = self.db.lookup_intern_syntax_context(parent_ctx);
hygiene_id = HygieneId::new(parent_ctx.opaque_and_semitransparent);
hygiene_info = parent_ctx.outer_expn.map(|expansion| {
let expansion = self.db.lookup_intern_macro_call(expansion);
(parent_ctx.parent, expansion.def)
});
}
}
}
_ => {}
}
}
Err(BodyDiagnostic::UndeclaredLabel {
@ -1808,10 +1877,17 @@ impl ExprCollector<'_> {
!self.label_ribs[rib_index + 1..].iter().any(|rib| rib.kind.is_label_barrier())
}
fn pop_label_rib(&mut self) {
// We need to pop all macro defs, plus one rib.
while let Some(LabelRib { kind: RibKind::MacroDef(_) }) = self.label_ribs.pop() {
// Do nothing.
}
}
fn with_label_rib<T>(&mut self, kind: RibKind, f: impl FnOnce(&mut Self) -> T) -> T {
self.label_ribs.push(LabelRib::new(kind));
let res = f(self);
self.label_ribs.pop();
self.pop_label_rib();
res
}
@ -1821,9 +1897,13 @@ impl ExprCollector<'_> {
hygiene: HygieneId,
f: impl FnOnce(&mut Self) -> T,
) -> T {
self.label_ribs.push(LabelRib::new_normal((self.body[label].name.clone(), label, hygiene)));
self.label_ribs.push(LabelRib::new(RibKind::Normal(
self.body[label].name.clone(),
label,
hygiene,
)));
let res = f(self);
self.label_ribs.pop();
self.pop_label_rib();
res
}

View file

@ -753,7 +753,7 @@ impl Printer<'_> {
}
wln!(self);
}
Statement::Item => (),
Statement::Item(_) => (),
}
}

View file

@ -1,12 +1,12 @@
//! Name resolution for expressions.
use hir_expand::name::Name;
use hir_expand::{name::Name, MacroDefId};
use la_arena::{Arena, ArenaMap, Idx, IdxRange, RawIdx};
use triomphe::Arc;
use crate::{
body::{Body, HygieneId},
db::DefDatabase,
hir::{Binding, BindingId, Expr, ExprId, LabelId, Pat, PatId, Statement},
hir::{Binding, BindingId, Expr, ExprId, Item, LabelId, Pat, PatId, Statement},
BlockId, ConstBlockId, DefWithBodyId,
};
@ -45,6 +45,8 @@ pub struct ScopeData {
parent: Option<ScopeId>,
block: Option<BlockId>,
label: Option<(LabelId, Name)>,
// FIXME: We can compress this with an enum for this and `label`/`block` if memory usage matters.
macro_def: Option<Box<MacroDefId>>,
entries: IdxRange<ScopeEntry>,
}
@ -67,6 +69,12 @@ impl ExprScopes {
self.scopes[scope].block
}
/// If `scope` refers to a macro def scope, returns the corresponding `MacroId`.
#[allow(clippy::borrowed_box)] // If we return `&MacroDefId` we need to move it, this way we just clone the `Box`.
pub fn macro_def(&self, scope: ScopeId) -> Option<&Box<MacroDefId>> {
self.scopes[scope].macro_def.as_ref()
}
/// If `scope` refers to a labeled expression scope, returns the corresponding `Label`.
pub fn label(&self, scope: ScopeId) -> Option<(LabelId, Name)> {
self.scopes[scope].label.clone()
@ -119,6 +127,7 @@ impl ExprScopes {
parent: None,
block: None,
label: None,
macro_def: None,
entries: empty_entries(self.scope_entries.len()),
})
}
@ -128,6 +137,7 @@ impl ExprScopes {
parent: Some(parent),
block: None,
label: None,
macro_def: None,
entries: empty_entries(self.scope_entries.len()),
})
}
@ -137,6 +147,7 @@ impl ExprScopes {
parent: Some(parent),
block: None,
label,
macro_def: None,
entries: empty_entries(self.scope_entries.len()),
})
}
@ -151,6 +162,17 @@ impl ExprScopes {
parent: Some(parent),
block,
label,
macro_def: None,
entries: empty_entries(self.scope_entries.len()),
})
}
fn new_macro_def_scope(&mut self, parent: ScopeId, macro_id: Box<MacroDefId>) -> ScopeId {
self.scopes.alloc(ScopeData {
parent: Some(parent),
block: None,
label: None,
macro_def: Some(macro_id),
entries: empty_entries(self.scope_entries.len()),
})
}
@ -217,7 +239,10 @@ fn compute_block_scopes(
Statement::Expr { expr, .. } => {
compute_expr_scopes(*expr, body, scopes, scope, resolve_const_block);
}
Statement::Item => (),
Statement::Item(Item::MacroDef(macro_id)) => {
*scope = scopes.new_macro_def_scope(*scope, macro_id.clone());
}
Statement::Item(Item::Other) => (),
}
}
if let Some(expr) = tail {

View file

@ -17,7 +17,7 @@ pub mod type_ref;
use std::fmt;
use hir_expand::name::Name;
use hir_expand::{name::Name, MacroDefId};
use intern::{Interned, Symbol};
use la_arena::{Idx, RawIdx};
use rustc_apfloat::ieee::{Half as f16, Quad as f128};
@ -492,9 +492,13 @@ pub enum Statement {
expr: ExprId,
has_semi: bool,
},
// At the moment, we only use this to figure out if a return expression
// is really the last statement of a block. See #16566
Item,
Item(Item),
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Item {
MacroDef(Box<MacroDefId>),
Other,
}
/// Explicit binding annotations given in the HIR for a binding. Note

View file

@ -83,6 +83,8 @@ enum Scope {
AdtScope(AdtId),
/// Local bindings
ExprScope(ExprScope),
/// Macro definition inside bodies that affects all paths after it in the same block.
MacroDefScope(Box<MacroDefId>),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
@ -191,7 +193,7 @@ impl Resolver {
for scope in self.scopes() {
match scope {
Scope::ExprScope(_) => continue,
Scope::ExprScope(_) | Scope::MacroDefScope(_) => continue,
Scope::GenericParams { params, def } => {
if let Some(id) = params.find_type_by_name(first_name, *def) {
return Some((TypeNs::GenericParam(id), remaining_idx(), None));
@ -260,7 +262,7 @@ impl Resolver {
&self,
db: &dyn DefDatabase,
path: &Path,
hygiene: HygieneId,
mut hygiene_id: HygieneId,
) -> Option<ResolveValueResult> {
let path = match path {
Path::Normal { mod_path, .. } => mod_path,
@ -304,12 +306,21 @@ impl Resolver {
}
if n_segments <= 1 {
let mut hygiene_info = if !hygiene_id.is_root() {
let ctx = db.lookup_intern_syntax_context(hygiene_id.0);
ctx.outer_expn.map(|expansion| {
let expansion = db.lookup_intern_macro_call(expansion);
(ctx.parent, expansion.def)
})
} else {
None
};
for scope in self.scopes() {
match scope {
Scope::ExprScope(scope) => {
let entry =
scope.expr_scopes.entries(scope.scope_id).iter().find(|entry| {
entry.name() == first_name && entry.hygiene() == hygiene
entry.name() == first_name && entry.hygiene() == hygiene_id
});
if let Some(e) = entry {
@ -319,6 +330,21 @@ impl Resolver {
));
}
}
Scope::MacroDefScope(macro_id) => {
if let Some((parent_ctx, label_macro_id)) = hygiene_info {
if label_macro_id == **macro_id {
// A macro is allowed to refer to variables from before its declaration.
// Therefore, if we got to the rib of its declaration, give up its hygiene
// and use its parent expansion.
let parent_ctx = db.lookup_intern_syntax_context(parent_ctx);
hygiene_id = HygieneId::new(parent_ctx.opaque_and_semitransparent);
hygiene_info = parent_ctx.outer_expn.map(|expansion| {
let expansion = db.lookup_intern_macro_call(expansion);
(parent_ctx.parent, expansion.def)
});
}
}
}
Scope::GenericParams { params, def } => {
if let Some(id) = params.find_const_by_name(first_name, *def) {
let val = ValueNs::GenericParam(id);
@ -345,7 +371,7 @@ impl Resolver {
} else {
for scope in self.scopes() {
match scope {
Scope::ExprScope(_) => continue,
Scope::ExprScope(_) | Scope::MacroDefScope(_) => continue,
Scope::GenericParams { params, def } => {
if let Some(id) = params.find_type_by_name(first_name, *def) {
let ty = TypeNs::GenericParam(id);
@ -626,7 +652,7 @@ impl Resolver {
pub fn type_owner(&self) -> Option<TypeOwnerId> {
self.scopes().find_map(|scope| match scope {
Scope::BlockScope(_) => None,
Scope::BlockScope(_) | Scope::MacroDefScope(_) => None,
&Scope::GenericParams { def, .. } => Some(def.into()),
&Scope::ImplDefScope(id) => Some(id.into()),
&Scope::AdtScope(adt) => Some(adt.into()),
@ -657,6 +683,9 @@ impl Resolver {
expr_scopes: &Arc<ExprScopes>,
scope_id: ScopeId,
) {
if let Some(macro_id) = expr_scopes.macro_def(scope_id) {
resolver.scopes.push(Scope::MacroDefScope(macro_id.clone()));
}
resolver.scopes.push(Scope::ExprScope(ExprScope {
owner,
expr_scopes: expr_scopes.clone(),
@ -674,7 +703,7 @@ impl Resolver {
}
let start = self.scopes.len();
let innermost_scope = self.scopes().next();
let innermost_scope = self.scopes().find(|scope| !matches!(scope, Scope::MacroDefScope(_)));
match innermost_scope {
Some(&Scope::ExprScope(ExprScope { scope_id, ref expr_scopes, owner })) => {
let expr_scopes = expr_scopes.clone();
@ -798,6 +827,7 @@ impl Scope {
acc.add_local(e.name(), e.binding());
});
}
Scope::MacroDefScope(_) => {}
}
}
}
@ -837,6 +867,9 @@ fn resolver_for_scope_(
// already traverses all parents, so this is O(n²). I think we could only store the
// innermost module scope instead?
}
if let Some(macro_id) = scopes.macro_def(scope) {
r = r.push_scope(Scope::MacroDefScope(macro_id.clone()));
}
r = r.push_expr_scope(owner, Arc::clone(&scopes), scope);
}

View file

@ -747,7 +747,7 @@ impl InferenceContext<'_> {
Statement::Expr { expr, has_semi: _ } => {
self.consume_expr(*expr);
}
Statement::Item => (),
Statement::Item(_) => (),
}
}
if let Some(tail) = tail {

View file

@ -1656,7 +1656,7 @@ impl InferenceContext<'_> {
);
}
}
Statement::Item => (),
Statement::Item(_) => (),
}
}

View file

@ -89,7 +89,7 @@ impl InferenceContext<'_> {
Statement::Expr { expr, has_semi: _ } => {
self.infer_mut_expr(*expr, Mutability::Not);
}
Statement::Item => (),
Statement::Item(_) => (),
}
}
if let Some(tail) = tail {

View file

@ -1783,7 +1783,7 @@ impl<'ctx> MirLowerCtx<'ctx> {
self.push_fake_read(c, p, expr.into());
current = scope2.pop_and_drop(self, c, expr.into());
}
hir_def::hir::Statement::Item => (),
hir_def::hir::Statement::Item(_) => (),
}
}
if let Some(tail) = tail {

View file

@ -3737,3 +3737,68 @@ fn foo() {
"#,
);
}
#[test]
fn macro_expansion_can_refer_variables_defined_before_macro_definition() {
check_types(
r#"
fn foo() {
let v: i32 = 0;
macro_rules! m {
() => { v };
}
let v: bool = true;
m!();
// ^^^^ i32
}
"#,
);
}
#[test]
fn macro_rules_shadowing_works_with_hygiene() {
check_types(
r#"
fn foo() {
let v: bool;
macro_rules! m { () => { v } }
m!();
// ^^^^ bool
let v: char;
macro_rules! m { () => { v } }
m!();
// ^^^^ char
{
let v: u8;
macro_rules! m { () => { v } }
m!();
// ^^^^ u8
let v: i8;
macro_rules! m { () => { v } }
m!();
// ^^^^ i8
let v: i16;
macro_rules! m { () => { v } }
m!();
// ^^^^ i16
{
let v: u32;
macro_rules! m { () => { v } }
m!();
// ^^^^ u32
let v: u64;
macro_rules! m { () => { v } }
m!();
// ^^^^ u64
}
}
}
"#,
);
}

View file

@ -104,6 +104,36 @@ async fn foo() {
async fn foo() {
|| None?;
}
"#,
);
}
#[test]
fn macro_expansion_can_refer_label_defined_before_macro_definition() {
check_diagnostics(
r#"
fn foo() {
'bar: loop {
macro_rules! m {
() => { break 'bar };
}
m!();
}
}
"#,
);
check_diagnostics(
r#"
fn foo() {
'bar: loop {
macro_rules! m {
() => { break 'bar };
}
'bar: loop {
m!();
}
}
}
"#,
);
}