fix hir for new block syntax

This commit is contained in:
Aleksey Kladov 2019-09-02 21:23:19 +03:00
parent dcf8e89503
commit 5e3f291195
14 changed files with 72 additions and 57 deletions

View file

@ -65,9 +65,9 @@ pub(crate) fn move_arm_cond_to_match_guard(mut ctx: AssistCtx<impl HirDatabase>)
"move condition to match guard", "move condition to match guard",
|edit| { |edit| {
edit.target(if_expr.syntax().text_range()); edit.target(if_expr.syntax().text_range());
let then_only_expr = then_block.statements().next().is_none(); let then_only_expr = then_block.block().and_then(|it| it.statements().next()).is_none();
match &then_block.expr() { match &then_block.block().and_then(|it| it.expr()) {
Some(then_expr) if then_only_expr => { Some(then_expr) if then_only_expr => {
edit.replace(if_expr.syntax().text_range(), then_expr.syntax().text()) edit.replace(if_expr.syntax().text_range(), then_expr.syntax().text())
} }

View file

@ -1,3 +1,4 @@
use format_buf::format;
use hir::db::HirDatabase; use hir::db::HirDatabase;
use ra_fmt::extract_trivial_expression; use ra_fmt::extract_trivial_expression;
use ra_syntax::{ast, AstNode}; use ra_syntax::{ast, AstNode};
@ -25,16 +26,21 @@ pub(crate) fn replace_if_let_with_match(mut ctx: AssistCtx<impl HirDatabase>) ->
ctx.build() ctx.build()
} }
fn build_match_expr(expr: ast::Expr, pat1: ast::Pat, arm1: ast::Block, arm2: ast::Block) -> String { fn build_match_expr(
expr: ast::Expr,
pat1: ast::Pat,
arm1: ast::BlockExpr,
arm2: ast::BlockExpr,
) -> String {
let mut buf = String::new(); let mut buf = String::new();
buf.push_str(&format!("match {} {{\n", expr.syntax().text())); format!(buf, "match {} {{\n", expr.syntax().text());
buf.push_str(&format!(" {} => {}\n", pat1.syntax().text(), format_arm(&arm1))); format!(buf, " {} => {}\n", pat1.syntax().text(), format_arm(&arm1));
buf.push_str(&format!(" _ => {}\n", format_arm(&arm2))); format!(buf, " _ => {}\n", format_arm(&arm2));
buf.push_str("}"); buf.push_str("}");
buf buf
} }
fn format_arm(block: &ast::Block) -> String { fn format_arm(block: &ast::BlockExpr) -> String {
match extract_trivial_expression(block) { match extract_trivial_expression(block) {
None => block.syntax().text().to_string(), None => block.syntax().text().to_string(),
Some(e) => format!("{},", e.syntax().text()), Some(e) => format!("{},", e.syntax().text()),

View file

@ -34,7 +34,8 @@ fn prev_tokens(token: SyntaxToken) -> impl Iterator<Item = SyntaxToken> {
successors(token.prev_token(), |token| token.prev_token()) successors(token.prev_token(), |token| token.prev_token())
} }
pub fn extract_trivial_expression(block: &ast::Block) -> Option<ast::Expr> { pub fn extract_trivial_expression(expr: &ast::BlockExpr) -> Option<ast::Expr> {
let block = expr.block()?;
let expr = block.expr()?; let expr = block.expr()?;
if expr.syntax().text().contains_char('\n') { if expr.syntax().text().contains_char('\n') {
return None; return None;

View file

@ -119,10 +119,10 @@ where
expr_id: crate::expr::ExprId, expr_id: crate::expr::ExprId,
) -> Option<Source<ast::Expr>> { ) -> Option<Source<ast::Expr>> {
let source_map = self.body_source_map(db); let source_map = self.body_source_map(db);
let expr_syntax = source_map.expr_syntax(expr_id)?; let expr_syntax = source_map.expr_syntax(expr_id)?.a()?;
let source = self.source(db); let source = self.source(db);
let node = expr_syntax.to_node(&source.ast.syntax()); let ast = expr_syntax.to_node(&source.ast.syntax());
ast::Expr::cast(node).map(|ast| Source { file_id: source.file_id, ast }) Some(Source { file_id: source.file_id, ast })
} }
} }

View file

@ -9,7 +9,7 @@ use ra_syntax::{
self, ArgListOwner, ArrayExprKind, LiteralKind, LoopBodyOwner, NameOwner, self, ArgListOwner, ArrayExprKind, LiteralKind, LoopBodyOwner, NameOwner,
TypeAscriptionOwner, TypeAscriptionOwner,
}, },
AstNode, AstPtr, SyntaxNodePtr, AstNode, AstPtr,
}; };
use test_utils::tested_by; use test_utils::tested_by;
@ -56,13 +56,14 @@ pub struct Body {
/// file, so that we don't recompute types whenever some whitespace is typed. /// file, so that we don't recompute types whenever some whitespace is typed.
#[derive(Default, Debug, Eq, PartialEq)] #[derive(Default, Debug, Eq, PartialEq)]
pub struct BodySourceMap { pub struct BodySourceMap {
expr_map: FxHashMap<SyntaxNodePtr, ExprId>, expr_map: FxHashMap<ExprPtr, ExprId>,
expr_map_back: ArenaMap<ExprId, SyntaxNodePtr>, expr_map_back: ArenaMap<ExprId, ExprPtr>,
pat_map: FxHashMap<PatPtr, PatId>, pat_map: FxHashMap<PatPtr, PatId>,
pat_map_back: ArenaMap<PatId, PatPtr>, pat_map_back: ArenaMap<PatId, PatPtr>,
field_map: FxHashMap<(ExprId, usize), AstPtr<ast::RecordField>>, field_map: FxHashMap<(ExprId, usize), AstPtr<ast::RecordField>>,
} }
type ExprPtr = Either<AstPtr<ast::Expr>, AstPtr<ast::RecordField>>;
type PatPtr = Either<AstPtr<ast::Pat>, AstPtr<ast::SelfParam>>; type PatPtr = Either<AstPtr<ast::Pat>, AstPtr<ast::SelfParam>>;
impl Body { impl Body {
@ -128,16 +129,12 @@ impl Index<PatId> for Body {
} }
impl BodySourceMap { impl BodySourceMap {
pub(crate) fn expr_syntax(&self, expr: ExprId) -> Option<SyntaxNodePtr> { pub(crate) fn expr_syntax(&self, expr: ExprId) -> Option<ExprPtr> {
self.expr_map_back.get(expr).cloned() self.expr_map_back.get(expr).cloned()
} }
pub(crate) fn syntax_expr(&self, ptr: SyntaxNodePtr) -> Option<ExprId> {
self.expr_map.get(&ptr).cloned()
}
pub(crate) fn node_expr(&self, node: &ast::Expr) -> Option<ExprId> { pub(crate) fn node_expr(&self, node: &ast::Expr) -> Option<ExprId> {
self.expr_map.get(&SyntaxNodePtr::new(node.syntax())).cloned() self.expr_map.get(&Either::A(AstPtr::new(node))).cloned()
} }
pub(crate) fn pat_syntax(&self, pat: PatId) -> Option<PatPtr> { pub(crate) fn pat_syntax(&self, pat: PatId) -> Option<PatPtr> {
@ -575,11 +572,12 @@ where
current_file_id: file_id, current_file_id: file_id,
} }
} }
fn alloc_expr(&mut self, expr: Expr, syntax_ptr: SyntaxNodePtr) -> ExprId { fn alloc_expr(&mut self, expr: Expr, ptr: AstPtr<ast::Expr>) -> ExprId {
let ptr = Either::A(ptr);
let id = self.exprs.alloc(expr); let id = self.exprs.alloc(expr);
if self.current_file_id == self.original_file_id { if self.current_file_id == self.original_file_id {
self.source_map.expr_map.insert(syntax_ptr, id); self.source_map.expr_map.insert(ptr, id);
self.source_map.expr_map_back.insert(id, syntax_ptr); self.source_map.expr_map_back.insert(id, ptr);
} }
id id
} }
@ -601,7 +599,7 @@ where
} }
fn collect_expr(&mut self, expr: ast::Expr) -> ExprId { fn collect_expr(&mut self, expr: ast::Expr) -> ExprId {
let syntax_ptr = SyntaxNodePtr::new(expr.syntax()); let syntax_ptr = AstPtr::new(&expr);
match expr { match expr {
ast::Expr::IfExpr(e) => { ast::Expr::IfExpr(e) => {
let then_branch = self.collect_block_opt(e.then_branch()); let then_branch = self.collect_block_opt(e.then_branch());
@ -640,10 +638,10 @@ where
self.alloc_expr(Expr::If { condition, then_branch, else_branch }, syntax_ptr) self.alloc_expr(Expr::If { condition, then_branch, else_branch }, syntax_ptr)
} }
ast::Expr::TryBlockExpr(e) => { ast::Expr::TryBlockExpr(e) => {
let body = self.collect_block_opt(e.block()); let body = self.collect_block_opt(e.body());
self.alloc_expr(Expr::TryBlock { body }, syntax_ptr) self.alloc_expr(Expr::TryBlock { body }, syntax_ptr)
} }
ast::Expr::BlockExpr(e) => self.collect_block_opt(e.block()), ast::Expr::BlockExpr(e) => self.collect_block(e),
ast::Expr::LoopExpr(e) => { ast::Expr::LoopExpr(e) => {
let body = self.collect_block_opt(e.loop_body()); let body = self.collect_block_opt(e.loop_body());
self.alloc_expr(Expr::Loop { body }, syntax_ptr) self.alloc_expr(Expr::Loop { body }, syntax_ptr)
@ -739,7 +737,7 @@ where
ast::Expr::ParenExpr(e) => { ast::Expr::ParenExpr(e) => {
let inner = self.collect_expr_opt(e.expr()); let inner = self.collect_expr_opt(e.expr());
// make the paren expr point to the inner expression as well // make the paren expr point to the inner expression as well
self.source_map.expr_map.insert(syntax_ptr, inner); self.source_map.expr_map.insert(Either::A(syntax_ptr), inner);
inner inner
} }
ast::Expr::ReturnExpr(e) => { ast::Expr::ReturnExpr(e) => {
@ -763,12 +761,9 @@ where
} else if let Some(nr) = field.name_ref() { } else if let Some(nr) = field.name_ref() {
// field shorthand // field shorthand
let id = self.exprs.alloc(Expr::Path(Path::from_name_ref(&nr))); let id = self.exprs.alloc(Expr::Path(Path::from_name_ref(&nr)));
self.source_map let ptr = Either::B(AstPtr::new(&field));
.expr_map self.source_map.expr_map.insert(ptr, id);
.insert(SyntaxNodePtr::new(nr.syntax()), id); self.source_map.expr_map_back.insert(id, ptr);
self.source_map
.expr_map_back
.insert(id, SyntaxNodePtr::new(nr.syntax()));
id id
} else { } else {
self.exprs.alloc(Expr::Missing) self.exprs.alloc(Expr::Missing)
@ -942,7 +937,12 @@ where
} }
} }
fn collect_block(&mut self, block: ast::Block) -> ExprId { fn collect_block(&mut self, expr: ast::BlockExpr) -> ExprId {
let syntax_node_ptr = AstPtr::new(&expr.clone().into());
let block = match expr.block() {
Some(block) => block,
None => return self.alloc_expr(Expr::Missing, syntax_node_ptr),
};
let statements = block let statements = block
.statements() .statements()
.map(|s| match s { .map(|s| match s {
@ -956,11 +956,11 @@ where
}) })
.collect(); .collect();
let tail = block.expr().map(|e| self.collect_expr(e)); let tail = block.expr().map(|e| self.collect_expr(e));
self.alloc_expr(Expr::Block { statements, tail }, SyntaxNodePtr::new(block.syntax())) self.alloc_expr(Expr::Block { statements, tail }, syntax_node_ptr)
} }
fn collect_block_opt(&mut self, block: Option<ast::Block>) -> ExprId { fn collect_block_opt(&mut self, expr: Option<ast::BlockExpr>) -> ExprId {
if let Some(block) = block { if let Some(block) = expr {
self.collect_block(block) self.collect_block(block)
} else { } else {
self.exprs.alloc(Expr::Missing) self.exprs.alloc(Expr::Missing)

View file

@ -172,7 +172,7 @@ fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use ra_db::SourceDatabase; use ra_db::SourceDatabase;
use ra_syntax::{algo::find_node_at_offset, ast, AstNode, SyntaxNodePtr}; use ra_syntax::{algo::find_node_at_offset, ast, AstNode};
use test_utils::{assert_eq_text, extract_offset}; use test_utils::{assert_eq_text, extract_offset};
use crate::{mock::MockDatabase, source_binder::SourceAnalyzer}; use crate::{mock::MockDatabase, source_binder::SourceAnalyzer};
@ -194,8 +194,7 @@ mod tests {
let analyzer = SourceAnalyzer::new(&db, file_id, marker.syntax(), None); let analyzer = SourceAnalyzer::new(&db, file_id, marker.syntax(), None);
let scopes = analyzer.scopes(); let scopes = analyzer.scopes();
let expr_id = let expr_id = analyzer.body_source_map().node_expr(&marker.into()).unwrap();
analyzer.body_source_map().syntax_expr(SyntaxNodePtr::new(marker.syntax())).unwrap();
let scope = scopes.scope_for(expr_id); let scope = scopes.scope_for(expr_id);
let actual = scopes let actual = scopes

View file

@ -1,7 +1,7 @@
use rustc_hash::FxHashSet;
use std::sync::Arc; use std::sync::Arc;
use ra_syntax::ast::{AstNode, RecordLit}; use ra_syntax::ast::{self, AstNode};
use rustc_hash::FxHashSet;
use super::{Expr, ExprId, RecordLitField}; use super::{Expr, ExprId, RecordLitField};
use crate::{ use crate::{
@ -13,7 +13,6 @@ use crate::{
ty::{ApplicationTy, InferenceResult, Ty, TypeCtor}, ty::{ApplicationTy, InferenceResult, Ty, TypeCtor},
Function, HasSource, HirDatabase, ModuleDef, Name, Path, PerNs, Resolution, Function, HasSource, HirDatabase, ModuleDef, Name, Path, PerNs, Resolution,
}; };
use ra_syntax::ast;
pub(crate) struct ExprValidator<'a, 'b: 'a> { pub(crate) struct ExprValidator<'a, 'b: 'a> {
func: Function, func: Function,
@ -84,8 +83,12 @@ impl<'a, 'b> ExprValidator<'a, 'b> {
let source_file = parse.tree(); let source_file = parse.tree();
if let Some(field_list_node) = source_map if let Some(field_list_node) = source_map
.expr_syntax(id) .expr_syntax(id)
.and_then(|ptr| ptr.a())
.map(|ptr| ptr.to_node(source_file.syntax())) .map(|ptr| ptr.to_node(source_file.syntax()))
.and_then(RecordLit::cast) .and_then(|expr| match expr {
ast::Expr::RecordLit(it) => Some(it),
_ => None,
})
.and_then(|lit| lit.record_field_list()) .and_then(|lit| lit.record_field_list())
{ {
let field_list_ptr = AstPtr::new(&field_list_node); let field_list_ptr = AstPtr::new(&field_list_node);
@ -135,7 +138,7 @@ impl<'a, 'b> ExprValidator<'a, 'b> {
let source_map = self.func.body_source_map(db); let source_map = self.func.body_source_map(db);
let file_id = self.func.source(db).file_id; let file_id = self.func.source(db).file_id;
if let Some(expr) = source_map.expr_syntax(id).and_then(|n| n.cast::<ast::Expr>()) { if let Some(expr) = source_map.expr_syntax(id).and_then(|n| n.a()) {
self.sink.push(MissingOkInTailExpr { file: file_id, expr }); self.sink.push(MissingOkInTailExpr { file: file_id, expr });
} }
} }

View file

@ -462,8 +462,8 @@ fn scope_for(
node: &SyntaxNode, node: &SyntaxNode,
) -> Option<ScopeId> { ) -> Option<ScopeId> {
node.ancestors() node.ancestors()
.map(|it| SyntaxNodePtr::new(&it)) .filter_map(ast::Expr::cast)
.filter_map(|ptr| source_map.syntax_expr(ptr)) .filter_map(|it| source_map.node_expr(&it))
.find_map(|it| scopes.scope_for(it)) .find_map(|it| scopes.scope_for(it))
} }
@ -475,7 +475,10 @@ fn scope_for_offset(
scopes scopes
.scope_by_expr() .scope_by_expr()
.iter() .iter()
.filter_map(|(id, scope)| Some((source_map.expr_syntax(*id)?, scope))) .filter_map(|(id, scope)| {
let ast_ptr = source_map.expr_syntax(*id)?.a()?;
Some((ast_ptr.syntax_node_ptr(), scope))
})
// find containing scope // find containing scope
.min_by_key(|(ptr, _scope)| { .min_by_key(|(ptr, _scope)| {
(!(ptr.range().start() <= offset && offset <= ptr.range().end()), ptr.range().len()) (!(ptr.range().start() <= offset && offset <= ptr.range().end()), ptr.range().len())
@ -495,7 +498,10 @@ fn adjust(
let child_scopes = scopes let child_scopes = scopes
.scope_by_expr() .scope_by_expr()
.iter() .iter()
.filter_map(|(id, scope)| Some((source_map.expr_syntax(*id)?, scope))) .filter_map(|(id, scope)| {
let ast_ptr = source_map.expr_syntax(*id)?.a()?;
Some((ast_ptr.syntax_node_ptr(), scope))
})
.map(|(ptr, scope)| (ptr.range(), scope)) .map(|(ptr, scope)| (ptr.range(), scope))
.filter(|(range, _)| range.start() <= offset && range.is_subrange(&r) && *range != r); .filter(|(range, _)| range.start() <= offset && range.is_subrange(&r) && *range != r);

View file

@ -3582,7 +3582,7 @@ fn infer(content: &str) -> String {
for (expr, ty) in inference_result.type_of_expr.iter() { for (expr, ty) in inference_result.type_of_expr.iter() {
let syntax_ptr = match body_source_map.expr_syntax(expr) { let syntax_ptr = match body_source_map.expr_syntax(expr) {
Some(sp) => sp, Some(sp) => sp.either(|it| it.syntax_node_ptr(), |it| it.syntax_node_ptr()),
None => continue, None => continue,
}; };
types.push((syntax_ptr, ty)); types.push((syntax_ptr, ty));

View file

@ -123,7 +123,7 @@ fn has_comma_after(node: &SyntaxNode) -> bool {
fn join_single_expr_block(edit: &mut TextEditBuilder, token: &SyntaxToken) -> Option<()> { fn join_single_expr_block(edit: &mut TextEditBuilder, token: &SyntaxToken) -> Option<()> {
let block = ast::Block::cast(token.parent())?; let block = ast::Block::cast(token.parent())?;
let block_expr = ast::BlockExpr::cast(block.syntax().parent()?)?; let block_expr = ast::BlockExpr::cast(block.syntax().parent()?)?;
let expr = extract_trivial_expression(&block)?; let expr = extract_trivial_expression(&block_expr)?;
let block_range = block_expr.syntax().text_range(); let block_range = block_expr.syntax().text_range();
let mut buf = expr.syntax().text().to_string(); let mut buf = expr.syntax().text().to_string();

View file

@ -9,12 +9,12 @@ use crate::{
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum ElseBranch { pub enum ElseBranch {
Block(ast::Block), Block(ast::BlockExpr),
IfExpr(ast::IfExpr), IfExpr(ast::IfExpr),
} }
impl ast::IfExpr { impl ast::IfExpr {
pub fn then_branch(&self) -> Option<ast::Block> { pub fn then_branch(&self) -> Option<ast::BlockExpr> {
self.blocks().nth(0) self.blocks().nth(0)
} }
pub fn else_branch(&self) -> Option<ElseBranch> { pub fn else_branch(&self) -> Option<ElseBranch> {
@ -28,7 +28,7 @@ impl ast::IfExpr {
Some(res) Some(res)
} }
fn blocks(&self) -> AstChildren<ast::Block> { fn blocks(&self) -> AstChildren<ast::BlockExpr> {
children(self) children(self)
} }
} }

View file

@ -3135,7 +3135,7 @@ impl AstNode for TryBlockExpr {
} }
} }
impl TryBlockExpr { impl TryBlockExpr {
pub fn block(&self) -> Option<Block> { pub fn body(&self) -> Option<BlockExpr> {
AstChildren::new(&self.syntax).next() AstChildren::new(&self.syntax).next()
} }
} }

View file

@ -28,7 +28,7 @@ pub trait VisibilityOwner: AstNode {
} }
pub trait LoopBodyOwner: AstNode { pub trait LoopBodyOwner: AstNode {
fn loop_body(&self) -> Option<ast::Block> { fn loop_body(&self) -> Option<ast::BlockExpr> {
child_opt(self) child_opt(self)
} }
} }

View file

@ -426,7 +426,7 @@ Grammar(
traits: ["LoopBodyOwner"], traits: ["LoopBodyOwner"],
), ),
"TryBlockExpr": ( "TryBlockExpr": (
options: ["Block"], options: [["body", "BlockExpr"]],
), ),
"ForExpr": ( "ForExpr": (
traits: ["LoopBodyOwner"], traits: ["LoopBodyOwner"],