diff --git a/crates/ra_assists/src/assist_ctx.rs b/crates/ra_assists/src/assist_ctx.rs index 62182cf03f..c3e6532999 100644 --- a/crates/ra_assists/src/assist_ctx.rs +++ b/crates/ra_assists/src/assist_ctx.rs @@ -11,6 +11,7 @@ use ra_syntax::{ use ra_text_edit::TextEditBuilder; use crate::{AssistAction, AssistId, AssistLabel, GroupLabel, ResolvedAssist}; +use algo::SyntaxRewriter; #[derive(Clone, Debug)] pub(crate) struct Assist(pub(crate) Vec); @@ -234,6 +235,11 @@ impl ActionBuilder { pub(crate) fn replace_ast(&mut self, old: N, new: N) { algo::diff(old.syntax(), new.syntax()).into_text_edit(&mut self.edit) } + pub(crate) fn rewrite(&mut self, rewriter: SyntaxRewriter) { + let node = rewriter.rewrite_root().unwrap(); + let new = rewriter.rewrite(&node); + algo::diff(&node, &new).into_text_edit(&mut self.edit) + } fn build(self) -> AssistAction { AssistAction { diff --git a/crates/ra_assists/src/ast_transform.rs b/crates/ra_assists/src/ast_transform.rs index 45558c4487..52b4c82db6 100644 --- a/crates/ra_assists/src/ast_transform.rs +++ b/crates/ra_assists/src/ast_transform.rs @@ -3,7 +3,10 @@ use rustc_hash::FxHashMap; use hir::{PathResolution, SemanticsScope}; use ra_ide_db::RootDatabase; -use ra_syntax::ast::{self, AstNode}; +use ra_syntax::{ + algo::SyntaxRewriter, + ast::{self, AstNode}, +}; pub trait AstTransform<'a> { fn get_substitution(&self, node: &ra_syntax::SyntaxNode) -> Option; @@ -153,15 +156,14 @@ impl<'a> QualifyPaths<'a> { } pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: N) -> N { - let syntax = node.syntax(); - let result = ra_syntax::algo::replace_descendants(syntax, |element| match element { + SyntaxRewriter::from_fn(|element| match element { ra_syntax::SyntaxElement::Node(n) => { let replacement = transformer.get_substitution(&n)?; Some(replacement.into()) } _ => None, - }); - N::cast(result).unwrap() + }) + .rewrite_ast(&node) } impl<'a> AstTransform<'a> for QualifyPaths<'a> { diff --git a/crates/ra_assists/src/handlers/merge_imports.rs b/crates/ra_assists/src/handlers/merge_imports.rs index 89bc975bd0..9c57d1e309 100644 --- a/crates/ra_assists/src/handlers/merge_imports.rs +++ b/crates/ra_assists/src/handlers/merge_imports.rs @@ -1,9 +1,9 @@ use std::iter::successors; use ra_syntax::{ - algo::neighbor, + algo::{neighbor, SyntaxRewriter}, ast::{self, edit::AstNodeEdit, make}, - AstNode, AstToken, Direction, InsertPosition, SyntaxElement, TextRange, T, + AstNode, Direction, InsertPosition, SyntaxElement, T, }; use crate::{Assist, AssistCtx, AssistId}; @@ -22,9 +22,10 @@ use crate::{Assist, AssistCtx, AssistId}; // ``` pub(crate) fn merge_imports(ctx: AssistCtx) -> Option { let tree: ast::UseTree = ctx.find_node_at_offset()?; - let (new_tree, to_delete) = if let Some(use_item) = - tree.syntax().parent().and_then(ast::UseItem::cast) - { + let mut rewriter = SyntaxRewriter::default(); + let mut offset = ctx.frange.range.start(); + + if let Some(use_item) = tree.syntax().parent().and_then(ast::UseItem::cast) { let (merged, to_delete) = next_prev() .filter_map(|dir| neighbor(&use_item, dir)) .filter_map(|it| Some((it.clone(), it.use_tree()?))) @@ -32,42 +33,28 @@ pub(crate) fn merge_imports(ctx: AssistCtx) -> Option { Some((try_merge_trees(&tree, &use_tree)?, use_item.clone())) })?; - let mut range = to_delete.syntax().text_range(); - let next_ws = to_delete - .syntax() - .next_sibling_or_token() - .and_then(|it| it.into_token()) - .and_then(ast::Whitespace::cast); - if let Some(ws) = next_ws { - range = range.extend_to(&ws.syntax().text_range()) + rewriter.replace_ast(&tree, &merged); + rewriter += to_delete.remove(); + + if to_delete.syntax().text_range().end() < offset { + offset -= to_delete.syntax().text_range().len(); } - (merged, range) } else { let (merged, to_delete) = next_prev() .filter_map(|dir| neighbor(&tree, dir)) .find_map(|use_tree| Some((try_merge_trees(&tree, &use_tree)?, use_tree.clone())))?; - let mut range = to_delete.syntax().text_range(); - if let Some((dir, nb)) = next_prev().find_map(|dir| Some((dir, neighbor(&to_delete, dir)?))) - { - let nb_range = nb.syntax().text_range(); - if dir == Direction::Prev { - range = TextRange::from_to(nb_range.end(), range.end()); - } else { - range = TextRange::from_to(range.start(), nb_range.start()); - } + rewriter.replace_ast(&tree, &merged); + rewriter += to_delete.remove(); + + if to_delete.syntax().text_range().end() < offset { + offset -= to_delete.syntax().text_range().len(); } - (merged, range) }; - let mut offset = ctx.frange.range.start(); ctx.add_assist(AssistId("merge_imports"), "Merge imports", |edit| { - edit.replace_ast(tree, new_tree); - edit.delete(to_delete); - - if to_delete.end() <= offset { - offset -= to_delete.len(); - } + edit.rewrite(rewriter); + // FIXME: we only need because our diff is imprecise edit.set_cursor(offset); }) } @@ -156,7 +143,7 @@ use std::fmt::Debug; use std::fmt<|>::Display; ", r" -use std::fmt<|>::{Display, Debug}; +use std::fmt:<|>:{Display, Debug}; ", ); } @@ -178,7 +165,57 @@ use std::{fmt<|>::{Debug, Display}}; use std::{fmt::Debug, fmt<|>::Display}; ", r" -use std::{fmt<|>::{Display, Debug}}; +use std::{fmt::<|>{Display, Debug}}; +", + ); + } + + #[test] + fn removes_just_enough_whitespace() { + check_assist( + merge_imports, + r" +use foo<|>::bar; +use foo::baz; + +/// Doc comment +", + r" +use foo<|>::{bar, baz}; + +/// Doc comment +", + ); + } + + #[test] + fn works_with_trailing_comma() { + check_assist( + merge_imports, + r" +use { + foo<|>::bar, + foo::baz, +}; +", + r" +use { + foo<|>::{bar, baz}, +}; +", + ); + check_assist( + merge_imports, + r" +use { + foo::baz, + foo<|>::bar, +}; +", + r" +use { + foo::{bar<|>, baz}, +}; ", ); } diff --git a/crates/ra_hir_expand/src/eager.rs b/crates/ra_hir_expand/src/eager.rs index 4cbce4df5b..687d402946 100644 --- a/crates/ra_hir_expand/src/eager.rs +++ b/crates/ra_hir_expand/src/eager.rs @@ -26,8 +26,8 @@ use crate::{ }; use ra_parser::FragmentKind; -use ra_syntax::{algo::replace_descendants, SyntaxElement, SyntaxNode}; -use std::{collections::HashMap, sync::Arc}; +use ra_syntax::{algo::SyntaxRewriter, SyntaxNode}; +use std::sync::Arc; pub fn expand_eager_macro( db: &dyn AstDatabase, @@ -95,10 +95,10 @@ fn eager_macro_recur( curr: InFile, macro_resolver: &dyn Fn(ast::Path) -> Option, ) -> Option { - let mut original = curr.value.clone(); + let original = curr.value.clone(); let children = curr.value.descendants().filter_map(ast::MacroCall::cast); - let mut replaces: HashMap = HashMap::default(); + let mut rewriter = SyntaxRewriter::default(); // Collect replacement for child in children { @@ -119,12 +119,9 @@ fn eager_macro_recur( } }; - replaces.insert(child.syntax().clone().into(), insert.into()); + rewriter.replace(child.syntax(), &insert); } - if !replaces.is_empty() { - original = replace_descendants(&original, |n| replaces.get(n).cloned()); - } - - Some(original) + let res = rewriter.rewrite(&original); + Some(res) } diff --git a/crates/ra_ide/src/expand_macro.rs b/crates/ra_ide/src/expand_macro.rs index e58526f31f..f536ba3e78 100644 --- a/crates/ra_ide/src/expand_macro.rs +++ b/crates/ra_ide/src/expand_macro.rs @@ -3,10 +3,9 @@ use hir::Semantics; use ra_ide_db::RootDatabase; use ra_syntax::{ - algo::{find_node_at_offset, replace_descendants}, - ast, AstNode, NodeOrToken, SyntaxElement, SyntaxKind, SyntaxNode, WalkEvent, T, + algo::{find_node_at_offset, SyntaxRewriter}, + ast, AstNode, NodeOrToken, SyntaxKind, SyntaxNode, WalkEvent, T, }; -use rustc_hash::FxHashMap; use crate::FilePosition; @@ -37,7 +36,7 @@ fn expand_macro_recur( let mut expanded = sema.expand(macro_call)?; let children = expanded.descendants().filter_map(ast::MacroCall::cast); - let mut replaces: FxHashMap = FxHashMap::default(); + let mut rewriter = SyntaxRewriter::default(); for child in children.into_iter() { if let Some(new_node) = expand_macro_recur(sema, &child) { @@ -47,12 +46,13 @@ fn expand_macro_recur( if expanded == *child.syntax() { expanded = new_node; } else { - replaces.insert(child.syntax().clone().into(), new_node.into()); + rewriter.replace(child.syntax(), &new_node) } } } - Some(replace_descendants(&expanded, |n| replaces.get(n).cloned())) + let res = rewriter.rewrite(&expanded); + Some(res) } // FIXME: It would also be cool to share logic here and in the mbe tests, diff --git a/crates/ra_syntax/src/algo.rs b/crates/ra_syntax/src/algo.rs index ffdbdc7671..4d463a3ef0 100644 --- a/crates/ra_syntax/src/algo.rs +++ b/crates/ra_syntax/src/algo.rs @@ -1,6 +1,9 @@ //! FIXME: write short doc here -use std::ops::RangeInclusive; +use std::{ + fmt, + ops::{self, RangeInclusive}, +}; use itertools::Itertools; use ra_text_edit::TextEditBuilder; @@ -222,42 +225,119 @@ fn _replace_children( with_children(parent, new_children) } +#[derive(Default)] +pub struct SyntaxRewriter<'a> { + f: Option Option + 'a>>, + //FIXME: add debug_assertions that all elements are in fact from the same file. + replacements: FxHashMap, +} + +impl fmt::Debug for SyntaxRewriter<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SyntaxRewriter").field("replacements", &self.replacements).finish() + } +} + +impl<'a> SyntaxRewriter<'a> { + pub fn from_fn(f: impl Fn(&SyntaxElement) -> Option + 'a) -> SyntaxRewriter<'a> { + SyntaxRewriter { f: Some(Box::new(f)), replacements: FxHashMap::default() } + } + pub fn delete>(&mut self, what: &T) { + let what = what.clone().into(); + let replacement = Replacement::Delete; + self.replacements.insert(what, replacement); + } + pub fn replace>(&mut self, what: &T, with: &T) { + let what = what.clone().into(); + let replacement = Replacement::Single(with.clone().into()); + self.replacements.insert(what, replacement); + } + pub fn replace_ast(&mut self, what: &T, with: &T) { + self.replace(what.syntax(), with.syntax()) + } + + pub fn rewrite(&self, node: &SyntaxNode) -> SyntaxNode { + if self.f.is_none() && self.replacements.is_empty() { + return node.clone(); + } + self.rewrite_children(node) + } + + pub fn rewrite_ast(self, node: &N) -> N { + N::cast(self.rewrite(node.syntax())).unwrap() + } + + pub fn rewrite_root(&self) -> Option { + assert!(self.f.is_none()); + self.replacements + .keys() + .map(|element| match element { + SyntaxElement::Node(it) => it.clone(), + SyntaxElement::Token(it) => it.parent(), + }) + .fold1(|a, b| least_common_ancestor(&a, &b).unwrap()) + } + + fn replacement(&self, element: &SyntaxElement) -> Option { + if let Some(f) = &self.f { + assert!(self.replacements.is_empty()); + return f(element).map(Replacement::Single); + } + self.replacements.get(element).cloned() + } + + fn rewrite_children(&self, node: &SyntaxNode) -> SyntaxNode { + // FIXME: this could be made much faster. + let new_children = + node.children_with_tokens().flat_map(|it| self.rewrite_self(&it)).collect::>(); + with_children(node, new_children) + } + + fn rewrite_self( + &self, + element: &SyntaxElement, + ) -> Option> { + if let Some(replacement) = self.replacement(&element) { + return match replacement { + Replacement::Single(NodeOrToken::Node(it)) => { + Some(NodeOrToken::Node(it.green().clone())) + } + Replacement::Single(NodeOrToken::Token(it)) => { + Some(NodeOrToken::Token(it.green().clone())) + } + Replacement::Delete => None, + }; + } + let res = match element { + NodeOrToken::Token(it) => NodeOrToken::Token(it.green().clone()), + NodeOrToken::Node(it) => NodeOrToken::Node(self.rewrite_children(it).green().clone()), + }; + Some(res) + } +} + +impl<'a> ops::AddAssign for SyntaxRewriter<'_> { + fn add_assign(&mut self, rhs: SyntaxRewriter) { + assert!(rhs.f.is_none()); + self.replacements.extend(rhs.replacements) + } +} + +#[derive(Clone, Debug)] +enum Replacement { + Delete, + Single(SyntaxElement), +} + /// Replaces descendants in the node, according to the mapping. /// /// This is a type-unsafe low-level editing API, if you need to use it, prefer /// to create a type-safe abstraction on top of it instead. -pub fn replace_descendants( +pub fn _replace_descendants( parent: &SyntaxNode, map: impl Fn(&SyntaxElement) -> Option, ) -> SyntaxNode { - _replace_descendants(parent, &map) -} - -fn _replace_descendants( - parent: &SyntaxNode, - map: &dyn Fn(&SyntaxElement) -> Option, -) -> SyntaxNode { - // FIXME: this could be made much faster. - let new_children = parent.children_with_tokens().map(|it| go(map, it)).collect::>(); - return with_children(parent, new_children); - - fn go( - map: &dyn Fn(&SyntaxElement) -> Option, - element: SyntaxElement, - ) -> NodeOrToken { - if let Some(replacement) = map(&element) { - return match replacement { - NodeOrToken::Node(it) => NodeOrToken::Node(it.green().clone()), - NodeOrToken::Token(it) => NodeOrToken::Token(it.green().clone()), - }; - } - match element { - NodeOrToken::Token(it) => NodeOrToken::Token(it.green().clone()), - NodeOrToken::Node(it) => { - NodeOrToken::Node(_replace_descendants(&it, map).green().clone()) - } - } - } + SyntaxRewriter::from_fn(map).rewrite(parent) } fn with_children( diff --git a/crates/ra_syntax/src/ast/edit.rs b/crates/ra_syntax/src/ast/edit.rs index 324327162c..f74c9f9c62 100644 --- a/crates/ra_syntax/src/ast/edit.rs +++ b/crates/ra_syntax/src/ast/edit.rs @@ -4,7 +4,6 @@ use std::{iter, ops::RangeInclusive}; use arrayvec::ArrayVec; -use rustc_hash::FxHashMap; use crate::{ algo, @@ -17,6 +16,7 @@ use crate::{ SyntaxKind::{ATTR, COMMENT, WHITESPACE}, SyntaxNode, SyntaxToken, T, }; +use algo::{neighbor, SyntaxRewriter}; impl ast::BinExpr { #[must_use] @@ -255,6 +255,28 @@ impl ast::UseItem { } self.clone() } + + pub fn remove(&self) -> SyntaxRewriter<'static> { + let mut res = SyntaxRewriter::default(); + res.delete(self.syntax()); + let next_ws = self + .syntax() + .next_sibling_or_token() + .and_then(|it| it.into_token()) + .and_then(ast::Whitespace::cast); + if let Some(next_ws) = next_ws { + let ws_text = next_ws.syntax().text(); + if ws_text.starts_with('\n') { + let rest = &ws_text[1..]; + if rest.is_empty() { + res.delete(next_ws.syntax()) + } else { + res.replace(next_ws.syntax(), &make::tokens::whitespace(rest)); + } + } + } + res + } } impl ast::UseTree { @@ -293,6 +315,22 @@ impl ast::UseTree { Some(res) } } + + pub fn remove(&self) -> SyntaxRewriter<'static> { + let mut res = SyntaxRewriter::default(); + res.delete(self.syntax()); + for &dir in [Direction::Next, Direction::Prev].iter() { + if let Some(nb) = neighbor(self, dir) { + self.syntax() + .siblings_with_tokens(dir) + .skip(1) + .take_while(|it| it.as_node() != Some(nb.syntax())) + .for_each(|el| res.delete(&el)); + return res; + } + } + res + } } #[must_use] @@ -343,28 +381,24 @@ impl IndentLevel { } fn _increase_indent(self, node: SyntaxNode) -> SyntaxNode { - let replacements: FxHashMap = node - .descendants_with_tokens() + let mut rewriter = SyntaxRewriter::default(); + node.descendants_with_tokens() .filter_map(|el| el.into_token()) .filter_map(ast::Whitespace::cast) .filter(|ws| { let text = ws.syntax().text(); text.contains('\n') }) - .map(|ws| { - ( - ws.syntax().clone().into(), - make::tokens::whitespace(&format!( - "{}{:width$}", - ws.syntax().text(), - "", - width = self.0 as usize * 4 - )) - .into(), - ) - }) - .collect(); - algo::replace_descendants(&node, |n| replacements.get(n).cloned()) + .for_each(|ws| { + let new_ws = make::tokens::whitespace(&format!( + "{}{:width$}", + ws.syntax().text(), + "", + width = self.0 as usize * 4 + )); + rewriter.replace(ws.syntax(), &new_ws) + }); + rewriter.rewrite(&node) } pub fn decrease_indent(self, node: N) -> N { @@ -372,27 +406,21 @@ impl IndentLevel { } fn _decrease_indent(self, node: SyntaxNode) -> SyntaxNode { - let replacements: FxHashMap = node - .descendants_with_tokens() + let mut rewriter = SyntaxRewriter::default(); + node.descendants_with_tokens() .filter_map(|el| el.into_token()) .filter_map(ast::Whitespace::cast) .filter(|ws| { let text = ws.syntax().text(); text.contains('\n') }) - .map(|ws| { - ( - ws.syntax().clone().into(), - make::tokens::whitespace( - &ws.syntax() - .text() - .replace(&format!("\n{:1$}", "", self.0 as usize * 4), "\n"), - ) - .into(), - ) - }) - .collect(); - algo::replace_descendants(&node, |n| replacements.get(n).cloned()) + .for_each(|ws| { + let new_ws = make::tokens::whitespace( + &ws.syntax().text().replace(&format!("\n{:1$}", "", self.0 as usize * 4), "\n"), + ); + rewriter.replace(ws.syntax(), &new_ws) + }); + rewriter.rewrite(&node) } } @@ -442,12 +470,11 @@ pub trait AstNodeEdit: AstNode + Sized { &self, replacement_map: impl IntoIterator, ) -> Self { - let map = replacement_map - .into_iter() - .map(|(from, to)| (from.syntax().clone().into(), to.syntax().clone().into())) - .collect::>(); - let new_syntax = algo::replace_descendants(self.syntax(), |n| map.get(n).cloned()); - Self::cast(new_syntax).unwrap() + let mut rewriter = SyntaxRewriter::default(); + for (from, to) in replacement_map { + rewriter.replace(from.syntax(), to.syntax()) + } + rewriter.rewrite_ast(self) } }