diff --git a/crates/ra_assists/src/move_guard.rs b/crates/ra_assists/src/move_guard.rs index 127c9e0682..699221e335 100644 --- a/crates/ra_assists/src/move_guard.rs +++ b/crates/ra_assists/src/move_guard.rs @@ -65,9 +65,9 @@ pub(crate) fn move_arm_cond_to_match_guard(mut ctx: AssistCtx) "move condition to match guard", |edit| { edit.target(if_expr.syntax().text_range()); - let then_only_expr = then_block.statements().next().is_none(); + let then_only_expr = then_block.block().and_then(|it| it.statements().next()).is_none(); - match &then_block.expr() { + match &then_block.block().and_then(|it| it.expr()) { Some(then_expr) if then_only_expr => { edit.replace(if_expr.syntax().text_range(), then_expr.syntax().text()) } diff --git a/crates/ra_assists/src/replace_if_let_with_match.rs b/crates/ra_assists/src/replace_if_let_with_match.rs index c0bf6d2351..401835c579 100644 --- a/crates/ra_assists/src/replace_if_let_with_match.rs +++ b/crates/ra_assists/src/replace_if_let_with_match.rs @@ -1,3 +1,4 @@ +use format_buf::format; use hir::db::HirDatabase; use ra_fmt::extract_trivial_expression; use ra_syntax::{ast, AstNode}; @@ -25,16 +26,21 @@ pub(crate) fn replace_if_let_with_match(mut ctx: AssistCtx) -> ctx.build() } -fn build_match_expr(expr: ast::Expr, pat1: ast::Pat, arm1: ast::Block, arm2: ast::Block) -> String { +fn build_match_expr( + expr: ast::Expr, + pat1: ast::Pat, + arm1: ast::BlockExpr, + arm2: ast::BlockExpr, +) -> String { let mut buf = String::new(); - buf.push_str(&format!("match {} {{\n", expr.syntax().text())); - buf.push_str(&format!(" {} => {}\n", pat1.syntax().text(), format_arm(&arm1))); - buf.push_str(&format!(" _ => {}\n", format_arm(&arm2))); + format!(buf, "match {} {{\n", expr.syntax().text()); + format!(buf, " {} => {}\n", pat1.syntax().text(), format_arm(&arm1)); + format!(buf, " _ => {}\n", format_arm(&arm2)); buf.push_str("}"); buf } -fn format_arm(block: &ast::Block) -> String { +fn format_arm(block: &ast::BlockExpr) -> String { match extract_trivial_expression(block) { None => block.syntax().text().to_string(), Some(e) => format!("{},", e.syntax().text()), diff --git a/crates/ra_fmt/src/lib.rs b/crates/ra_fmt/src/lib.rs index b09478d7a3..e22ac9753f 100644 --- a/crates/ra_fmt/src/lib.rs +++ b/crates/ra_fmt/src/lib.rs @@ -34,7 +34,8 @@ fn prev_tokens(token: SyntaxToken) -> impl Iterator { successors(token.prev_token(), |token| token.prev_token()) } -pub fn extract_trivial_expression(block: &ast::Block) -> Option { +pub fn extract_trivial_expression(expr: &ast::BlockExpr) -> Option { + let block = expr.block()?; let expr = block.expr()?; if expr.syntax().text().contains_char('\n') { return None; diff --git a/crates/ra_hir/src/code_model/src.rs b/crates/ra_hir/src/code_model/src.rs index e5bae16ab5..7c9454c0b5 100644 --- a/crates/ra_hir/src/code_model/src.rs +++ b/crates/ra_hir/src/code_model/src.rs @@ -119,10 +119,10 @@ where expr_id: crate::expr::ExprId, ) -> Option> { let source_map = self.body_source_map(db); - let expr_syntax = source_map.expr_syntax(expr_id)?; + let expr_syntax = source_map.expr_syntax(expr_id)?.a()?; let source = self.source(db); - let node = expr_syntax.to_node(&source.ast.syntax()); - ast::Expr::cast(node).map(|ast| Source { file_id: source.file_id, ast }) + let ast = expr_syntax.to_node(&source.ast.syntax()); + Some(Source { file_id: source.file_id, ast }) } } diff --git a/crates/ra_hir/src/expr.rs b/crates/ra_hir/src/expr.rs index c7530849b2..5c95bed40c 100644 --- a/crates/ra_hir/src/expr.rs +++ b/crates/ra_hir/src/expr.rs @@ -9,7 +9,7 @@ use ra_syntax::{ self, ArgListOwner, ArrayExprKind, LiteralKind, LoopBodyOwner, NameOwner, TypeAscriptionOwner, }, - AstNode, AstPtr, SyntaxNodePtr, + AstNode, AstPtr, }; use test_utils::tested_by; @@ -56,13 +56,14 @@ pub struct Body { /// file, so that we don't recompute types whenever some whitespace is typed. #[derive(Default, Debug, Eq, PartialEq)] pub struct BodySourceMap { - expr_map: FxHashMap, - expr_map_back: ArenaMap, + expr_map: FxHashMap, + expr_map_back: ArenaMap, pat_map: FxHashMap, pat_map_back: ArenaMap, field_map: FxHashMap<(ExprId, usize), AstPtr>, } +type ExprPtr = Either, AstPtr>; type PatPtr = Either, AstPtr>; impl Body { @@ -128,16 +129,12 @@ impl Index for Body { } impl BodySourceMap { - pub(crate) fn expr_syntax(&self, expr: ExprId) -> Option { + pub(crate) fn expr_syntax(&self, expr: ExprId) -> Option { self.expr_map_back.get(expr).cloned() } - pub(crate) fn syntax_expr(&self, ptr: SyntaxNodePtr) -> Option { - self.expr_map.get(&ptr).cloned() - } - pub(crate) fn node_expr(&self, node: &ast::Expr) -> Option { - self.expr_map.get(&SyntaxNodePtr::new(node.syntax())).cloned() + self.expr_map.get(&Either::A(AstPtr::new(node))).cloned() } pub(crate) fn pat_syntax(&self, pat: PatId) -> Option { @@ -575,11 +572,12 @@ where current_file_id: file_id, } } - fn alloc_expr(&mut self, expr: Expr, syntax_ptr: SyntaxNodePtr) -> ExprId { + fn alloc_expr(&mut self, expr: Expr, ptr: AstPtr) -> ExprId { + let ptr = Either::A(ptr); let id = self.exprs.alloc(expr); if self.current_file_id == self.original_file_id { - self.source_map.expr_map.insert(syntax_ptr, id); - self.source_map.expr_map_back.insert(id, syntax_ptr); + self.source_map.expr_map.insert(ptr, id); + self.source_map.expr_map_back.insert(id, ptr); } id } @@ -601,7 +599,7 @@ where } fn collect_expr(&mut self, expr: ast::Expr) -> ExprId { - let syntax_ptr = SyntaxNodePtr::new(expr.syntax()); + let syntax_ptr = AstPtr::new(&expr); match expr { ast::Expr::IfExpr(e) => { let then_branch = self.collect_block_opt(e.then_branch()); @@ -640,10 +638,10 @@ where self.alloc_expr(Expr::If { condition, then_branch, else_branch }, syntax_ptr) } ast::Expr::TryBlockExpr(e) => { - let body = self.collect_block_opt(e.block()); + let body = self.collect_block_opt(e.body()); self.alloc_expr(Expr::TryBlock { body }, syntax_ptr) } - ast::Expr::BlockExpr(e) => self.collect_block_opt(e.block()), + ast::Expr::BlockExpr(e) => self.collect_block(e), ast::Expr::LoopExpr(e) => { let body = self.collect_block_opt(e.loop_body()); self.alloc_expr(Expr::Loop { body }, syntax_ptr) @@ -739,7 +737,7 @@ where ast::Expr::ParenExpr(e) => { let inner = self.collect_expr_opt(e.expr()); // make the paren expr point to the inner expression as well - self.source_map.expr_map.insert(syntax_ptr, inner); + self.source_map.expr_map.insert(Either::A(syntax_ptr), inner); inner } ast::Expr::ReturnExpr(e) => { @@ -763,12 +761,9 @@ where } else if let Some(nr) = field.name_ref() { // field shorthand let id = self.exprs.alloc(Expr::Path(Path::from_name_ref(&nr))); - self.source_map - .expr_map - .insert(SyntaxNodePtr::new(nr.syntax()), id); - self.source_map - .expr_map_back - .insert(id, SyntaxNodePtr::new(nr.syntax())); + let ptr = Either::B(AstPtr::new(&field)); + self.source_map.expr_map.insert(ptr, id); + self.source_map.expr_map_back.insert(id, ptr); id } else { self.exprs.alloc(Expr::Missing) @@ -942,7 +937,12 @@ where } } - fn collect_block(&mut self, block: ast::Block) -> ExprId { + fn collect_block(&mut self, expr: ast::BlockExpr) -> ExprId { + let syntax_node_ptr = AstPtr::new(&expr.clone().into()); + let block = match expr.block() { + Some(block) => block, + None => return self.alloc_expr(Expr::Missing, syntax_node_ptr), + }; let statements = block .statements() .map(|s| match s { @@ -956,11 +956,11 @@ where }) .collect(); let tail = block.expr().map(|e| self.collect_expr(e)); - self.alloc_expr(Expr::Block { statements, tail }, SyntaxNodePtr::new(block.syntax())) + self.alloc_expr(Expr::Block { statements, tail }, syntax_node_ptr) } - fn collect_block_opt(&mut self, block: Option) -> ExprId { - if let Some(block) = block { + fn collect_block_opt(&mut self, expr: Option) -> ExprId { + if let Some(block) = expr { self.collect_block(block) } else { self.exprs.alloc(Expr::Missing) diff --git a/crates/ra_hir/src/expr/scope.rs b/crates/ra_hir/src/expr/scope.rs index 79e1857f93..b6d7f3fc14 100644 --- a/crates/ra_hir/src/expr/scope.rs +++ b/crates/ra_hir/src/expr/scope.rs @@ -172,7 +172,7 @@ fn compute_expr_scopes(expr: ExprId, body: &Body, scopes: &mut ExprScopes, scope #[cfg(test)] mod tests { use ra_db::SourceDatabase; - use ra_syntax::{algo::find_node_at_offset, ast, AstNode, SyntaxNodePtr}; + use ra_syntax::{algo::find_node_at_offset, ast, AstNode}; use test_utils::{assert_eq_text, extract_offset}; use crate::{mock::MockDatabase, source_binder::SourceAnalyzer}; @@ -194,8 +194,7 @@ mod tests { let analyzer = SourceAnalyzer::new(&db, file_id, marker.syntax(), None); let scopes = analyzer.scopes(); - let expr_id = - analyzer.body_source_map().syntax_expr(SyntaxNodePtr::new(marker.syntax())).unwrap(); + let expr_id = analyzer.body_source_map().node_expr(&marker.into()).unwrap(); let scope = scopes.scope_for(expr_id); let actual = scopes diff --git a/crates/ra_hir/src/expr/validation.rs b/crates/ra_hir/src/expr/validation.rs index c8ae198696..6fdaf1fce3 100644 --- a/crates/ra_hir/src/expr/validation.rs +++ b/crates/ra_hir/src/expr/validation.rs @@ -1,7 +1,7 @@ -use rustc_hash::FxHashSet; use std::sync::Arc; -use ra_syntax::ast::{AstNode, RecordLit}; +use ra_syntax::ast::{self, AstNode}; +use rustc_hash::FxHashSet; use super::{Expr, ExprId, RecordLitField}; use crate::{ @@ -13,7 +13,6 @@ use crate::{ ty::{ApplicationTy, InferenceResult, Ty, TypeCtor}, Function, HasSource, HirDatabase, ModuleDef, Name, Path, PerNs, Resolution, }; -use ra_syntax::ast; pub(crate) struct ExprValidator<'a, 'b: 'a> { func: Function, @@ -84,8 +83,12 @@ impl<'a, 'b> ExprValidator<'a, 'b> { let source_file = parse.tree(); if let Some(field_list_node) = source_map .expr_syntax(id) + .and_then(|ptr| ptr.a()) .map(|ptr| ptr.to_node(source_file.syntax())) - .and_then(RecordLit::cast) + .and_then(|expr| match expr { + ast::Expr::RecordLit(it) => Some(it), + _ => None, + }) .and_then(|lit| lit.record_field_list()) { let field_list_ptr = AstPtr::new(&field_list_node); @@ -135,7 +138,7 @@ impl<'a, 'b> ExprValidator<'a, 'b> { let source_map = self.func.body_source_map(db); let file_id = self.func.source(db).file_id; - if let Some(expr) = source_map.expr_syntax(id).and_then(|n| n.cast::()) { + if let Some(expr) = source_map.expr_syntax(id).and_then(|n| n.a()) { self.sink.push(MissingOkInTailExpr { file: file_id, expr }); } } diff --git a/crates/ra_hir/src/source_binder.rs b/crates/ra_hir/src/source_binder.rs index 43aec201a7..e5f4d11a64 100644 --- a/crates/ra_hir/src/source_binder.rs +++ b/crates/ra_hir/src/source_binder.rs @@ -462,8 +462,8 @@ fn scope_for( node: &SyntaxNode, ) -> Option { node.ancestors() - .map(|it| SyntaxNodePtr::new(&it)) - .filter_map(|ptr| source_map.syntax_expr(ptr)) + .filter_map(ast::Expr::cast) + .filter_map(|it| source_map.node_expr(&it)) .find_map(|it| scopes.scope_for(it)) } @@ -475,7 +475,10 @@ fn scope_for_offset( scopes .scope_by_expr() .iter() - .filter_map(|(id, scope)| Some((source_map.expr_syntax(*id)?, scope))) + .filter_map(|(id, scope)| { + let ast_ptr = source_map.expr_syntax(*id)?.a()?; + Some((ast_ptr.syntax_node_ptr(), scope)) + }) // find containing scope .min_by_key(|(ptr, _scope)| { (!(ptr.range().start() <= offset && offset <= ptr.range().end()), ptr.range().len()) @@ -495,7 +498,10 @@ fn adjust( let child_scopes = scopes .scope_by_expr() .iter() - .filter_map(|(id, scope)| Some((source_map.expr_syntax(*id)?, scope))) + .filter_map(|(id, scope)| { + let ast_ptr = source_map.expr_syntax(*id)?.a()?; + Some((ast_ptr.syntax_node_ptr(), scope)) + }) .map(|(ptr, scope)| (ptr.range(), scope)) .filter(|(range, _)| range.start() <= offset && range.is_subrange(&r) && *range != r); diff --git a/crates/ra_hir/src/ty/tests.rs b/crates/ra_hir/src/ty/tests.rs index b034fd59e9..d344ab12e7 100644 --- a/crates/ra_hir/src/ty/tests.rs +++ b/crates/ra_hir/src/ty/tests.rs @@ -3582,7 +3582,7 @@ fn infer(content: &str) -> String { for (expr, ty) in inference_result.type_of_expr.iter() { let syntax_ptr = match body_source_map.expr_syntax(expr) { - Some(sp) => sp, + Some(sp) => sp.either(|it| it.syntax_node_ptr(), |it| it.syntax_node_ptr()), None => continue, }; types.push((syntax_ptr, ty)); diff --git a/crates/ra_ide_api/src/join_lines.rs b/crates/ra_ide_api/src/join_lines.rs index a2e4b6f3cf..a71e4ed7dc 100644 --- a/crates/ra_ide_api/src/join_lines.rs +++ b/crates/ra_ide_api/src/join_lines.rs @@ -123,7 +123,7 @@ fn has_comma_after(node: &SyntaxNode) -> bool { fn join_single_expr_block(edit: &mut TextEditBuilder, token: &SyntaxToken) -> Option<()> { let block = ast::Block::cast(token.parent())?; let block_expr = ast::BlockExpr::cast(block.syntax().parent()?)?; - let expr = extract_trivial_expression(&block)?; + let expr = extract_trivial_expression(&block_expr)?; let block_range = block_expr.syntax().text_range(); let mut buf = expr.syntax().text().to_string(); diff --git a/crates/ra_syntax/src/ast/expr_extensions.rs b/crates/ra_syntax/src/ast/expr_extensions.rs index d7ea4354df..1324965cfb 100644 --- a/crates/ra_syntax/src/ast/expr_extensions.rs +++ b/crates/ra_syntax/src/ast/expr_extensions.rs @@ -9,12 +9,12 @@ use crate::{ #[derive(Debug, Clone, PartialEq, Eq)] pub enum ElseBranch { - Block(ast::Block), + Block(ast::BlockExpr), IfExpr(ast::IfExpr), } impl ast::IfExpr { - pub fn then_branch(&self) -> Option { + pub fn then_branch(&self) -> Option { self.blocks().nth(0) } pub fn else_branch(&self) -> Option { @@ -28,7 +28,7 @@ impl ast::IfExpr { Some(res) } - fn blocks(&self) -> AstChildren { + fn blocks(&self) -> AstChildren { children(self) } } diff --git a/crates/ra_syntax/src/ast/generated.rs b/crates/ra_syntax/src/ast/generated.rs index fd85a32315..e2a92ae604 100644 --- a/crates/ra_syntax/src/ast/generated.rs +++ b/crates/ra_syntax/src/ast/generated.rs @@ -3135,7 +3135,7 @@ impl AstNode for TryBlockExpr { } } impl TryBlockExpr { - pub fn block(&self) -> Option { + pub fn body(&self) -> Option { AstChildren::new(&self.syntax).next() } } diff --git a/crates/ra_syntax/src/ast/traits.rs b/crates/ra_syntax/src/ast/traits.rs index 20c251fbad..c3e676d4c2 100644 --- a/crates/ra_syntax/src/ast/traits.rs +++ b/crates/ra_syntax/src/ast/traits.rs @@ -28,7 +28,7 @@ pub trait VisibilityOwner: AstNode { } pub trait LoopBodyOwner: AstNode { - fn loop_body(&self) -> Option { + fn loop_body(&self) -> Option { child_opt(self) } } diff --git a/crates/ra_syntax/src/grammar.ron b/crates/ra_syntax/src/grammar.ron index 37166182f9..c14ee0e856 100644 --- a/crates/ra_syntax/src/grammar.ron +++ b/crates/ra_syntax/src/grammar.ron @@ -426,7 +426,7 @@ Grammar( traits: ["LoopBodyOwner"], ), "TryBlockExpr": ( - options: ["Block"], + options: [["body", "BlockExpr"]], ), "ForExpr": ( traits: ["LoopBodyOwner"],