Track expr parents during lowering, use parent map when checking if unsafe exprs are within unsafe blocks

This commit is contained in:
Paul Daniel Faria 2020-05-27 22:21:20 -04:00
parent 9ce44be2ab
commit 7f2219dc76
4 changed files with 177 additions and 100 deletions

View file

@ -184,6 +184,7 @@ pub struct Body {
/// The `ExprId` of the actual body expression.
pub body_expr: ExprId,
pub item_scope: ItemScope,
pub parent_map: FxHashMap<ExprId, ExprId>,
}
pub type ExprPtr = AstPtr<ast::Expr>;

View file

@ -15,6 +15,7 @@ use ra_syntax::{
},
AstNode, AstPtr,
};
use rustc_hash::FxHashMap;
use test_utils::mark;
use crate::{
@ -74,6 +75,7 @@ pub(super) fn lower(
params: Vec::new(),
body_expr: dummy_expr_id(),
item_scope: Default::default(),
parent_map: FxHashMap::default(),
},
item_trees: {
let mut map = FxHashMap::default();
@ -171,11 +173,28 @@ impl ExprCollector<'_> {
id
}
fn update_parent_map(
&mut self,
(parent_expr, children_exprs): (ExprId, Vec<ExprId>),
) -> ExprId {
for child_expr in children_exprs {
self.body.parent_map.insert(child_expr, parent_expr);
}
parent_expr
}
fn collect_expr(&mut self, expr: ast::Expr) -> ExprId {
let parent_and_children = self.collect_expr_inner(expr);
self.update_parent_map(parent_and_children)
}
fn collect_expr_inner(&mut self, expr: ast::Expr) -> (ExprId, Vec<ExprId>) {
let syntax_ptr = AstPtr::new(&expr);
if !self.expander.is_cfg_enabled(&expr) {
return self.missing_expr();
return (self.missing_expr(), vec![]);
}
match expr {
ast::Expr::IfExpr(e) => {
let then_branch = self.collect_block_opt(e.then_branch());
@ -205,32 +224,48 @@ impl ExprCollector<'_> {
guard: None,
},
];
return self
.alloc_expr(Expr::Match { expr: match_expr, arms }, syntax_ptr);
let children_exprs = if let Some(else_branch) = else_branch {
vec![match_expr, then_branch, else_branch]
} else {
vec![match_expr, then_branch]
};
return (
self.alloc_expr(Expr::Match { expr: match_expr, arms }, syntax_ptr),
children_exprs,
);
}
},
};
self.alloc_expr(Expr::If { condition, then_branch, else_branch }, syntax_ptr)
let children_exprs = if let Some(else_branch) = else_branch {
vec![then_branch, else_branch, condition]
} else {
vec![then_branch, condition]
};
(
self.alloc_expr(Expr::If { condition, then_branch, else_branch }, syntax_ptr),
children_exprs,
)
}
ast::Expr::EffectExpr(e) => match e.effect() {
ast::Effect::Try(_) => {
let body = self.collect_block_opt(e.block_expr());
self.alloc_expr(Expr::TryBlock { body }, syntax_ptr)
(self.alloc_expr(Expr::TryBlock { body }, syntax_ptr), vec![body])
}
ast::Effect::Unsafe(_) => {
let body = self.collect_block_opt(e.block_expr());
self.alloc_expr(Expr::Unsafe { body }, syntax_ptr)
(self.alloc_expr(Expr::Unsafe { body }, syntax_ptr), vec![body])
}
// FIXME: we need to record these effects somewhere...
ast::Effect::Async(_) | ast::Effect::Label(_) => {
self.collect_block_opt(e.block_expr())
(self.collect_block_opt(e.block_expr()), vec![])
}
},
ast::Expr::BlockExpr(e) => self.collect_block(e),
ast::Expr::BlockExpr(e) => (self.collect_block(e), vec![]),
ast::Expr::LoopExpr(e) => {
let body = self.collect_block_opt(e.loop_body());
self.alloc_expr(
(self.alloc_expr(
Expr::Loop {
body,
label: e
@ -239,7 +274,7 @@ impl ExprCollector<'_> {
.map(|l| Name::new_lifetime(&l)),
},
syntax_ptr,
)
), vec![body])
}
ast::Expr::WhileExpr(e) => {
let body = self.collect_block_opt(e.loop_body());
@ -250,6 +285,7 @@ impl ExprCollector<'_> {
None => self.collect_expr_opt(condition.expr()),
// if let -- desugar to match
Some(pat) => {
// FIXME(pfaria) track the break and arms parents here?
mark::hit!(infer_resolve_while_let);
let pat = self.collect_pat(pat);
let match_expr = self.collect_expr_opt(condition.expr());
@ -262,7 +298,7 @@ impl ExprCollector<'_> {
];
let match_expr =
self.alloc_expr_desugared(Expr::Match { expr: match_expr, arms });
return self.alloc_expr(
return (self.alloc_expr(
Expr::Loop {
body: match_expr,
label: e
@ -271,12 +307,12 @@ impl ExprCollector<'_> {
.map(|l| Name::new_lifetime(&l)),
},
syntax_ptr,
);
), vec![match_expr]);
}
},
};
self.alloc_expr(
(self.alloc_expr(
Expr::While {
condition,
body,
@ -286,13 +322,13 @@ impl ExprCollector<'_> {
.map(|l| Name::new_lifetime(&l)),
},
syntax_ptr,
)
), vec![body, condition])
}
ast::Expr::ForExpr(e) => {
let iterable = self.collect_expr_opt(e.iterable());
let pat = self.collect_pat_opt(e.pat());
let body = self.collect_block_opt(e.loop_body());
self.alloc_expr(
(self.alloc_expr(
Expr::For {
iterable,
pat,
@ -303,7 +339,7 @@ impl ExprCollector<'_> {
.map(|l| Name::new_lifetime(&l)),
},
syntax_ptr,
)
), vec![iterable, body])
}
ast::Expr::CallExpr(e) => {
let callee = self.collect_expr_opt(e.expr());
@ -312,41 +348,56 @@ impl ExprCollector<'_> {
} else {
Vec::new()
};
self.alloc_expr(Expr::Call { callee, args }, syntax_ptr)
let mut children_exprs = args.clone();
children_exprs.push(callee);
(self.alloc_expr(Expr::Call { callee, args }, syntax_ptr), children_exprs)
}
ast::Expr::MethodCallExpr(e) => {
let receiver = self.collect_expr_opt(e.expr());
let args = if let Some(arg_list) = e.arg_list() {
arg_list.args().map(|e| self.collect_expr(e)).collect()
} else {
Vec::new()
vec![]
};
let method_name = e.name_ref().map(|nr| nr.as_name()).unwrap_or_else(Name::missing);
let generic_args =
e.type_arg_list().and_then(|it| GenericArgs::from_ast(&self.ctx(), it));
let mut children_exprs = args.clone();
children_exprs.push(receiver);
(
self.alloc_expr(
Expr::MethodCall { receiver, method_name, args, generic_args },
syntax_ptr,
),
children_exprs,
)
}
ast::Expr::MatchExpr(e) => {
let expr = self.collect_expr_opt(e.expr());
let arms = if let Some(match_arm_list) = e.match_arm_list() {
let (arms, mut children_exprs): (Vec<_>, Vec<_>) =
if let Some(match_arm_list) = e.match_arm_list() {
match_arm_list
.arms()
.map(|arm| MatchArm {
.map(|arm| {
let expr = self.collect_expr_opt(arm.expr());
(
MatchArm {
pat: self.collect_pat_opt(arm.pat()),
expr: self.collect_expr_opt(arm.expr()),
expr,
guard: arm
.guard()
.and_then(|guard| guard.expr())
.map(|e| self.collect_expr(e)),
},
expr,
)
})
.collect()
.unzip()
} else {
Vec::new()
(vec![], vec![])
};
self.alloc_expr(Expr::Match { expr, arms }, syntax_ptr)
children_exprs.push(expr);
(self.alloc_expr(Expr::Match { expr, arms }, syntax_ptr), children_exprs)
}
ast::Expr::PathExpr(e) => {
let path = e
@ -354,35 +405,35 @@ impl ExprCollector<'_> {
.and_then(|path| self.expander.parse_path(path))
.map(Expr::Path)
.unwrap_or(Expr::Missing);
self.alloc_expr(path, syntax_ptr)
(self.alloc_expr(path, syntax_ptr), vec![])
}
ast::Expr::ContinueExpr(e) => self.alloc_expr(
ast::Expr::ContinueExpr(e) => (self.alloc_expr(
Expr::Continue { label: e.lifetime_token().map(|l| Name::new_lifetime(&l)) },
syntax_ptr,
),
), vec![]),
ast::Expr::BreakExpr(e) => {
let expr = e.expr().map(|e| self.collect_expr(e));
self.alloc_expr(
(self.alloc_expr(
Expr::Break { expr, label: e.lifetime_token().map(|l| Name::new_lifetime(&l)) },
syntax_ptr,
)
), expr.into_iter().collect())
}
ast::Expr::ParenExpr(e) => {
let inner = self.collect_expr_opt(e.expr());
// make the paren expr point to the inner expression as well
let src = self.expander.to_source(syntax_ptr);
self.source_map.expr_map.insert(src, inner);
inner
(inner, vec![])
}
ast::Expr::ReturnExpr(e) => {
let expr = e.expr().map(|e| self.collect_expr(e));
self.alloc_expr(Expr::Return { expr }, syntax_ptr)
(self.alloc_expr(Expr::Return { expr }, syntax_ptr), expr.into_iter().collect())
}
ast::Expr::RecordLit(e) => {
let path = e.path().and_then(|path| self.expander.parse_path(path));
let mut field_ptrs = Vec::new();
let record_lit = if let Some(nfl) = e.record_field_list() {
let fields = nfl
let (record_lit, children) = if let Some(nfl) = e.record_field_list() {
let (fields, children): (Vec<_>, Vec<_>) = nfl
.fields()
.inspect(|field| field_ptrs.push(AstPtr::new(field)))
.filter_map(|field| {
@ -391,19 +442,20 @@ impl ExprCollector<'_> {
}
let name = field.field_name()?.as_name();
Some(RecordLitField {
name,
expr: match field.expr() {
let expr = match field.expr() {
Some(e) => self.collect_expr(e),
None => self.missing_expr(),
},
};
Some((RecordLitField { name, expr }, expr))
})
})
.collect();
.unzip();
let spread = nfl.spread().map(|s| self.collect_expr(s));
Expr::RecordLit { path, fields, spread }
(
Expr::RecordLit { path, fields, spread: spread },
children.into_iter().chain(spread.into_iter()).collect(),
)
} else {
Expr::RecordLit { path, fields: Vec::new(), spread: None }
(Expr::RecordLit { path, fields: Vec::new(), spread: None }, vec![])
};
let res = self.alloc_expr(record_lit, syntax_ptr);
@ -411,7 +463,7 @@ impl ExprCollector<'_> {
let src = self.expander.to_source(ptr);
self.source_map.field_map.insert((res, i), src);
}
res
(res, children)
}
ast::Expr::FieldExpr(e) => {
let expr = self.collect_expr_opt(e.expr());
@ -419,20 +471,20 @@ impl ExprCollector<'_> {
Some(kind) => kind.as_name(),
_ => Name::missing(),
};
self.alloc_expr(Expr::Field { expr, name }, syntax_ptr)
(self.alloc_expr(Expr::Field { expr, name }, syntax_ptr), vec![expr])
}
ast::Expr::AwaitExpr(e) => {
let expr = self.collect_expr_opt(e.expr());
self.alloc_expr(Expr::Await { expr }, syntax_ptr)
(self.alloc_expr(Expr::Await { expr }, syntax_ptr), vec![expr])
}
ast::Expr::TryExpr(e) => {
let expr = self.collect_expr_opt(e.expr());
self.alloc_expr(Expr::Try { expr }, syntax_ptr)
(self.alloc_expr(Expr::Try { expr }, syntax_ptr), vec![expr])
}
ast::Expr::CastExpr(e) => {
let expr = self.collect_expr_opt(e.expr());
let type_ref = TypeRef::from_ast_opt(&self.ctx(), e.type_ref());
self.alloc_expr(Expr::Cast { expr, type_ref }, syntax_ptr)
(self.alloc_expr(Expr::Cast { expr, type_ref }, syntax_ptr), vec![expr])
}
ast::Expr::RefExpr(e) => {
let expr = self.collect_expr_opt(e.expr());
@ -455,9 +507,9 @@ impl ExprCollector<'_> {
ast::Expr::PrefixExpr(e) => {
let expr = self.collect_expr_opt(e.expr());
if let Some(op) = e.op_kind() {
self.alloc_expr(Expr::UnaryOp { expr, op }, syntax_ptr)
(self.alloc_expr(Expr::UnaryOp { expr, op }, syntax_ptr), vec![expr])
} else {
self.alloc_expr(Expr::Missing, syntax_ptr)
(self.alloc_expr(Expr::Missing, syntax_ptr), vec![])
}
}
ast::Expr::LambdaExpr(e) => {
@ -477,21 +529,24 @@ impl ExprCollector<'_> {
.and_then(|r| r.type_ref())
.map(|it| TypeRef::from_ast(&self.ctx(), it));
let body = self.collect_expr_opt(e.body());
self.alloc_expr(Expr::Lambda { args, arg_types, ret_type, body }, syntax_ptr)
(
self.alloc_expr(Expr::Lambda { args, arg_types, ret_type, body }, syntax_ptr),
vec![body],
)
}
ast::Expr::BinExpr(e) => {
let lhs = self.collect_expr_opt(e.lhs());
let rhs = self.collect_expr_opt(e.rhs());
let op = e.op_kind().map(BinaryOp::from);
self.alloc_expr(Expr::BinaryOp { lhs, rhs, op }, syntax_ptr)
(self.alloc_expr(Expr::BinaryOp { lhs, rhs, op }, syntax_ptr), vec![lhs, rhs])
}
ast::Expr::TupleExpr(e) => {
let exprs = e.exprs().map(|expr| self.collect_expr(expr)).collect();
self.alloc_expr(Expr::Tuple { exprs }, syntax_ptr)
let exprs = e.exprs().map(|expr| self.collect_expr(expr)).collect::<Vec<_>>();
(self.alloc_expr(Expr::Tuple { exprs: exprs.clone() }, syntax_ptr), exprs)
}
ast::Expr::BoxExpr(e) => {
let expr = self.collect_expr_opt(e.expr());
self.alloc_expr(Expr::Box { expr }, syntax_ptr)
(self.alloc_expr(Expr::Box { expr }, syntax_ptr), vec![expr])
}
ast::Expr::ArrayExpr(e) => {
@ -499,34 +554,46 @@ impl ExprCollector<'_> {
match kind {
ArrayExprKind::ElementList(e) => {
let exprs = e.map(|expr| self.collect_expr(expr)).collect();
self.alloc_expr(Expr::Array(Array::ElementList(exprs)), syntax_ptr)
let exprs = e.map(|expr| self.collect_expr(expr)).collect::<Vec<_>>();
(
self.alloc_expr(
Expr::Array(Array::ElementList(exprs.clone())),
syntax_ptr,
),
exprs,
)
}
ArrayExprKind::Repeat { initializer, repeat } => {
let initializer = self.collect_expr_opt(initializer);
let repeat = self.collect_expr_opt(repeat);
(
self.alloc_expr(
Expr::Array(Array::Repeat { initializer, repeat }),
syntax_ptr,
),
vec![initializer, repeat],
)
}
}
}
ast::Expr::Literal(e) => self.alloc_expr(Expr::Literal(e.kind().into()), syntax_ptr),
ast::Expr::Literal(e) => {
(self.alloc_expr(Expr::Literal(e.kind().into()), syntax_ptr), vec![])
}
ast::Expr::IndexExpr(e) => {
let base = self.collect_expr_opt(e.base());
let index = self.collect_expr_opt(e.index());
self.alloc_expr(Expr::Index { base, index }, syntax_ptr)
(self.alloc_expr(Expr::Index { base, index }, syntax_ptr), vec![base, index])
}
ast::Expr::RangeExpr(e) => {
let lhs = e.start().map(|lhs| self.collect_expr(lhs));
let rhs = e.end().map(|rhs| self.collect_expr(rhs));
match e.op_kind() {
Some(range_type) => {
self.alloc_expr(Expr::Range { lhs, rhs, range_type }, syntax_ptr)
}
None => self.alloc_expr(Expr::Missing, syntax_ptr),
Some(range_type) => (
self.alloc_expr(Expr::Range { lhs, rhs, range_type }, syntax_ptr),
lhs.into_iter().chain(rhs.into_iter()).collect(),
),
None => (self.alloc_expr(Expr::Missing, syntax_ptr), vec![]),
}
}
ast::Expr::MacroCall(e) => {
@ -540,7 +607,7 @@ impl ExprCollector<'_> {
self.body.item_scope.define_legacy_macro(name, mac);
// FIXME: do we still need to allocate this as missing ?
self.alloc_expr(Expr::Missing, syntax_ptr)
(self.alloc_expr(Expr::Missing, syntax_ptr), vec![])
} else {
let macro_call = self.expander.to_source(AstPtr::new(&e));
match self.expander.enter_expand(self.db, Some(&self.body.item_scope), e) {
@ -553,15 +620,15 @@ impl ExprCollector<'_> {
self.item_trees.insert(self.expander.current_file_id, item_tree);
let id = self.collect_expr(expansion);
self.expander.exit(self.db, mark);
id
(id, vec![])
}
None => self.alloc_expr(Expr::Missing, syntax_ptr),
None => (self.alloc_expr(Expr::Missing, syntax_ptr), vec![]),
}
}
}
// FIXME implement HIR for these:
ast::Expr::Label(_e) => self.alloc_expr(Expr::Missing, syntax_ptr),
ast::Expr::Label(_e) => (self.alloc_expr(Expr::Missing, syntax_ptr), vec![]),
}
}
@ -600,9 +667,14 @@ impl ExprCollector<'_> {
}
fn collect_block(&mut self, block: ast::BlockExpr) -> ExprId {
let parent_and_children = self.collect_block_inner(block);
self.update_parent_map(parent_and_children)
}
fn collect_block_inner(&mut self, block: ast::BlockExpr) -> (ExprId, Vec<ExprId>) {
let syntax_node_ptr = AstPtr::new(&block.clone().into());
self.collect_block_items(&block);
let statements = block
let (statements, children_exprs): (Vec<_>, Vec<_>) = block
.statements()
.map(|s| match s {
ast::Stmt::LetStmt(stmt) => {
@ -610,14 +682,18 @@ impl ExprCollector<'_> {
let type_ref =
stmt.ascribed_type().map(|it| TypeRef::from_ast(&self.ctx(), it));
let initializer = stmt.initializer().map(|e| self.collect_expr(e));
Statement::Let { pat, type_ref, initializer }
(Statement::Let { pat, type_ref, initializer }, initializer)
}
ast::Stmt::ExprStmt(stmt) => {
let expr = self.collect_expr_opt(stmt.expr());
(Statement::Expr(expr), Some(expr))
}
ast::Stmt::ExprStmt(stmt) => Statement::Expr(self.collect_expr_opt(stmt.expr())),
})
.collect();
.unzip();
let tail = block.expr().map(|e| self.collect_expr(e));
let label = block.label().and_then(|l| l.lifetime_token()).map(|t| Name::new_lifetime(&t));
self.alloc_expr(Expr::Block { statements, tail, label }, syntax_node_ptr)
let children_exprs = children_exprs.into_iter().flatten().chain(tail.into_iter()).collect();
(self.alloc_expr(Expr::Block { statements, tail, label }, syntax_node_ptr), children_exprs)
}
fn collect_block_items(&mut self, block: &ast::BlockExpr) {

View file

@ -333,15 +333,12 @@ pub fn unsafe_expressions(
def: DefWithBodyId,
) -> Vec<UnsafeExpr> {
let mut unsafe_exprs = vec![];
let mut unsafe_block_scopes = vec![];
let mut unsafe_block_exprs = FxHashSet::default();
let body = db.body(def);
let expr_scopes = db.expr_scopes(def);
for (id, expr) in body.exprs.iter() {
match expr {
Expr::Unsafe { body } => {
if let Some(scope) = expr_scopes.scope_for(*body) {
unsafe_block_scopes.push(scope);
}
Expr::Unsafe { .. } => {
unsafe_block_exprs.insert(id);
}
Expr::Call { callee, .. } => {
let ty = &infer[*callee];
@ -374,12 +371,13 @@ pub fn unsafe_expressions(
}
'unsafe_exprs: for unsafe_expr in &mut unsafe_exprs {
let scope = expr_scopes.scope_for(unsafe_expr.expr);
for scope in expr_scopes.scope_chain(scope) {
if unsafe_block_scopes.contains(&scope) {
let mut child = unsafe_expr.expr;
while let Some(parent) = body.parent_map.get(&child) {
if unsafe_block_exprs.contains(parent) {
unsafe_expr.inside_unsafe_block = true;
continue 'unsafe_exprs;
}
child = *parent;
}
}
@ -417,9 +415,11 @@ impl<'a, 'b> UnsafeValidator<'a, 'b> {
let (_, body_source) = db.body_with_source_map(def);
for unsafe_expr in unsafe_expressions {
if !unsafe_expr.inside_unsafe_block {
if let Ok(in_file) = body_source.as_ref().expr_syntax(unsafe_expr.expr) {
self.sink.push(MissingUnsafe { file: in_file.file_id, expr: in_file.value })
}
}
}
}
}

View file

@ -638,7 +638,7 @@ fn nothing_to_see_move_along() {
.diagnostics()
.0;
assert_snapshot!(diagnostics, @"");
assert_snapshot!(diagnostics, @r#""*x": This operation is unsafe and requires an unsafe function or block"#);
}
#[test]