internal: replace AstTransformer with mutable syntax trees

This commit is contained in:
Aleksey Kladov 2021-05-18 14:42:41 +03:00
parent 3cfe2d0a5d
commit 47d7434dde
6 changed files with 123 additions and 208 deletions

View file

@ -1,31 +1,12 @@
//! `AstTransformer`s are functions that replace nodes in an AST and can be easily combined.
use hir::{HirDisplay, PathResolution, SemanticsScope};
use hir::{HirDisplay, SemanticsScope};
use ide_db::helpers::mod_path_to_ast;
use rustc_hash::FxHashMap;
use syntax::{
ast::{self, AstNode},
ted, SyntaxNode,
ted,
};
pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: &N) {
let mut skip_to = None;
for event in node.syntax().preorder() {
match event {
syntax::WalkEvent::Enter(node) if skip_to.is_none() => {
skip_to = transformer.get_substitution(&node, transformer).zip(Some(node));
}
syntax::WalkEvent::Enter(_) => (),
syntax::WalkEvent::Leave(node) => match &skip_to {
Some((replacement, skip_target)) if *skip_target == node => {
ted::replace(node, replacement.clone_for_update());
skip_to.take();
}
_ => (),
},
}
}
}
/// `AstTransform` helps with applying bulk transformations to syntax nodes.
///
/// This is mostly useful for IDE code generation. If you paste some existing
@ -35,8 +16,8 @@ pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: &N) {
///
/// ```
/// mod x {
/// pub struct A;
/// pub trait T<U> { fn foo(&self, _: U) -> A; }
/// pub struct A<V>;
/// pub trait T<U> { fn foo(&self, _: U) -> A<U>; }
/// }
///
/// mod y {
@ -45,7 +26,7 @@ pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: &N) {
/// impl T<()> for () {
/// // If we invoke **Add Missing Members** here, we want to copy-paste `foo`.
/// // But we want a slightly-modified version of it:
/// fn foo(&self, _: ()) -> x::A {}
/// fn foo(&self, _: ()) -> x::A<()> {}
/// }
/// }
/// ```
@ -54,49 +35,27 @@ pub fn apply<'a, N: AstNode>(transformer: &dyn AstTransform<'a>, node: &N) {
/// `SyntaxNode`. Note that the API here is a bit too high-order and high-brow.
/// We'd want to somehow express this concept simpler, but so far nobody got to
/// simplifying this!
pub trait AstTransform<'a> {
fn get_substitution(
&self,
node: &SyntaxNode,
recur: &dyn AstTransform<'a>,
) -> Option<SyntaxNode>;
fn or<T: AstTransform<'a> + 'a>(self, other: T) -> Box<dyn AstTransform<'a> + 'a>
where
Self: Sized + 'a,
{
Box::new(Or(Box::new(self), Box::new(other)))
}
pub(crate) struct AstTransform<'a> {
pub(crate) subst: (hir::Trait, ast::Impl),
pub(crate) target_scope: &'a SemanticsScope<'a>,
pub(crate) source_scope: &'a SemanticsScope<'a>,
}
struct Or<'a>(Box<dyn AstTransform<'a> + 'a>, Box<dyn AstTransform<'a> + 'a>);
impl<'a> AstTransform<'a> for Or<'a> {
fn get_substitution(
&self,
node: &SyntaxNode,
recur: &dyn AstTransform<'a>,
) -> Option<SyntaxNode> {
self.0.get_substitution(node, recur).or_else(|| self.1.get_substitution(node, recur))
impl<'a> AstTransform<'a> {
pub(crate) fn apply(&self, item: ast::AssocItem) {
if let Some(ctx) = self.build_ctx() {
ctx.apply(item)
}
}
fn build_ctx(&self) -> Option<Ctx<'a>> {
let db = self.source_scope.db;
let target_module = self.target_scope.module()?;
let source_module = self.source_scope.module()?;
pub struct SubstituteTypeParams<'a> {
source_scope: &'a SemanticsScope<'a>,
substs: FxHashMap<hir::TypeParam, ast::Type>,
}
impl<'a> SubstituteTypeParams<'a> {
pub fn for_trait_impl(
source_scope: &'a SemanticsScope<'a>,
// FIXME: there's implicit invariant that `trait_` and `source_scope` match...
trait_: hir::Trait,
impl_def: ast::Impl,
) -> SubstituteTypeParams<'a> {
let substs = get_syntactic_substs(impl_def).unwrap_or_default();
let generic_def: hir::GenericDef = trait_.into();
let substs = get_syntactic_substs(self.subst.1.clone()).unwrap_or_default();
let generic_def: hir::GenericDef = self.subst.0.into();
let substs_by_param: FxHashMap<_, _> = generic_def
.type_params(source_scope.db)
.type_params(db)
.into_iter()
// this is a trait impl, so we need to skip the first type parameter -- this is a bit hacky
.skip(1)
@ -110,19 +69,76 @@ impl<'a> SubstituteTypeParams<'a> {
.filter_map(|(k, v)| match v {
Some(v) => Some((k, v)),
None => {
let default = k.default(source_scope.db)?;
let default = k.default(db)?;
Some((
k,
ast::make::ty(
&default
.display_source_code(source_scope.db, source_scope.module()?.into())
.ok()?,
),
ast::make::ty(&default.display_source_code(db, source_module.into()).ok()?),
))
}
})
.collect();
return SubstituteTypeParams { source_scope, substs: substs_by_param };
let res = Ctx { substs: substs_by_param, target_module, source_scope: self.source_scope };
Some(res)
}
}
struct Ctx<'a> {
substs: FxHashMap<hir::TypeParam, ast::Type>,
target_module: hir::Module,
source_scope: &'a SemanticsScope<'a>,
}
impl<'a> Ctx<'a> {
fn apply(&self, item: ast::AssocItem) {
for event in item.syntax().preorder() {
let node = match event {
syntax::WalkEvent::Enter(_) => continue,
syntax::WalkEvent::Leave(it) => it,
};
if let Some(path) = ast::Path::cast(node.clone()) {
self.transform_path(path);
}
}
}
fn transform_path(&self, path: ast::Path) -> Option<()> {
if path.qualifier().is_some() {
return None;
}
if path.segment().and_then(|s| s.param_list()).is_some() {
// don't try to qualify `Fn(Foo) -> Bar` paths, they are in prelude anyway
return None;
}
let resolution = self.source_scope.speculative_resolve(&path)?;
match resolution {
hir::PathResolution::TypeParam(tp) => {
if let Some(subst) = self.substs.get(&tp) {
ted::replace(path.syntax(), subst.clone_subtree().clone_for_update().syntax())
}
}
hir::PathResolution::Def(def) => {
let found_path =
self.target_module.find_use_path(self.source_scope.db.upcast(), def)?;
let res = mod_path_to_ast(&found_path).clone_for_update();
if let Some(args) = path.segment().and_then(|it| it.generic_arg_list()) {
if let Some(segment) = res.segment() {
let old = segment.get_or_create_generic_arg_list();
ted::replace(old.syntax(), args.clone_subtree().syntax().clone_for_update())
}
}
ted::replace(path.syntax(), res.syntax())
}
hir::PathResolution::Local(_)
| hir::PathResolution::ConstParam(_)
| hir::PathResolution::SelfType(_)
| hir::PathResolution::Macro(_)
| hir::PathResolution::AssocItem(_) => (),
}
Some(())
}
}
// FIXME: It would probably be nicer if we could get this via HIR (i.e. get the
// trait ref, and then go from the types in the substs back to the syntax).
@ -146,73 +162,3 @@ impl<'a> SubstituteTypeParams<'a> {
Some(result)
}
}
}
impl<'a> AstTransform<'a> for SubstituteTypeParams<'a> {
fn get_substitution(
&self,
node: &SyntaxNode,
_recur: &dyn AstTransform<'a>,
) -> Option<SyntaxNode> {
let type_ref = ast::Type::cast(node.clone())?;
let path = match &type_ref {
ast::Type::PathType(path_type) => path_type.path()?,
_ => return None,
};
let resolution = self.source_scope.speculative_resolve(&path)?;
match resolution {
hir::PathResolution::TypeParam(tp) => Some(self.substs.get(&tp)?.syntax().clone()),
_ => None,
}
}
}
pub struct QualifyPaths<'a> {
target_scope: &'a SemanticsScope<'a>,
source_scope: &'a SemanticsScope<'a>,
}
impl<'a> QualifyPaths<'a> {
pub fn new(target_scope: &'a SemanticsScope<'a>, source_scope: &'a SemanticsScope<'a>) -> Self {
Self { target_scope, source_scope }
}
}
impl<'a> AstTransform<'a> for QualifyPaths<'a> {
fn get_substitution(
&self,
node: &SyntaxNode,
recur: &dyn AstTransform<'a>,
) -> Option<SyntaxNode> {
// FIXME handle value ns?
let from = self.target_scope.module()?;
let p = ast::Path::cast(node.clone())?;
if p.segment().and_then(|s| s.param_list()).is_some() {
// don't try to qualify `Fn(Foo) -> Bar` paths, they are in prelude anyway
return None;
}
let resolution = self.source_scope.speculative_resolve(&p)?;
match resolution {
PathResolution::Def(def) => {
let found_path = from.find_use_path(self.source_scope.db.upcast(), def)?;
let mut path = mod_path_to_ast(&found_path);
let type_args = p.segment().and_then(|s| s.generic_arg_list());
if let Some(type_args) = type_args {
apply(recur, &type_args);
let last_segment = path.segment().unwrap();
path = path.with_segment(last_segment.with_generic_args(type_args))
}
Some(path.syntax().clone())
}
PathResolution::Local(_)
| PathResolution::TypeParam(_)
| PathResolution::SelfType(_)
| PathResolution::ConstParam(_) => None,
PathResolution::Macro(_) => None,
PathResolution::AssocItem(_) => None,
}
}
}

View file

@ -24,7 +24,7 @@ use syntax::{
use crate::{
assist_context::{AssistBuilder, AssistContext},
ast_transform::{self, AstTransform, QualifyPaths, SubstituteTypeParams},
ast_transform::AstTransform,
};
pub(crate) fn unwrap_trivial_block(block: ast::BlockExpr) -> ast::Expr {
@ -132,14 +132,18 @@ pub fn add_trait_assoc_items_to_impl(
target_scope: hir::SemanticsScope,
) -> (ast::Impl, ast::AssocItem) {
let source_scope = sema.scope_for_def(trait_);
let ast_transform = QualifyPaths::new(&target_scope, &source_scope)
.or(SubstituteTypeParams::for_trait_impl(&source_scope, trait_, impl_.clone()));
let items = items
.into_iter()
.map(|it| it.clone_for_update())
.inspect(|it| ast_transform::apply(&*ast_transform, it))
.map(|it| edit::remove_attrs_and_docs(&it).clone_subtree().clone_for_update());
let transform = AstTransform {
subst: (trait_, impl_.clone()),
source_scope: &source_scope,
target_scope: &target_scope,
};
let items = items.into_iter().map(|assoc_item| {
let assoc_item = assoc_item.clone_for_update();
transform.apply(assoc_item.clone());
edit::remove_attrs_and_docs(&assoc_item).clone_subtree().clone_for_update()
});
let res = impl_.clone_for_update();

View file

@ -6,14 +6,12 @@ use std::{
ops::{self, RangeInclusive},
};
use arrayvec::ArrayVec;
use crate::{
algo,
ast::{self, make, AstNode},
ted, AstToken, InsertPosition, NodeOrToken, SyntaxElement, SyntaxKind,
ted, AstToken, NodeOrToken, SyntaxElement, SyntaxKind,
SyntaxKind::{ATTR, COMMENT, WHITESPACE},
SyntaxNode, SyntaxToken, T,
SyntaxNode, SyntaxToken,
};
impl ast::BinExpr {
@ -25,46 +23,6 @@ impl ast::BinExpr {
}
}
impl ast::Path {
#[must_use]
pub fn with_segment(&self, segment: ast::PathSegment) -> ast::Path {
if let Some(old) = self.segment() {
return self.replace_children(
single_node(old.syntax().clone()),
iter::once(segment.syntax().clone().into()),
);
}
self.clone()
}
}
impl ast::PathSegment {
#[must_use]
pub fn with_generic_args(&self, type_args: ast::GenericArgList) -> ast::PathSegment {
self._with_generic_args(type_args, false)
}
#[must_use]
pub fn with_turbo_fish(&self, type_args: ast::GenericArgList) -> ast::PathSegment {
self._with_generic_args(type_args, true)
}
fn _with_generic_args(&self, type_args: ast::GenericArgList, turbo: bool) -> ast::PathSegment {
if let Some(old) = self.generic_arg_list() {
return self.replace_children(
single_node(old.syntax().clone()),
iter::once(type_args.syntax().clone().into()),
);
}
let mut to_insert: ArrayVec<SyntaxElement, 2> = ArrayVec::new();
if turbo {
to_insert.push(make::token(T![::]).into());
}
to_insert.push(type_args.syntax().clone().into());
self.insert_children(InsertPosition::Last, to_insert)
}
}
impl ast::UseTree {
/// Splits off the given prefix, making it the path component of the use tree, appending the rest of the path to all UseTreeList items.
#[must_use]
@ -233,16 +191,6 @@ fn prev_tokens(token: SyntaxToken) -> impl Iterator<Item = SyntaxToken> {
}
pub trait AstNodeEdit: AstNode + Clone + Sized {
#[must_use]
fn insert_children(
&self,
position: InsertPosition<SyntaxElement>,
to_insert: impl IntoIterator<Item = SyntaxElement>,
) -> Self {
let new_syntax = algo::insert_children(self.syntax(), position, to_insert);
Self::cast(new_syntax).unwrap()
}
#[must_use]
fn replace_children(
&self,

View file

@ -239,6 +239,16 @@ impl ast::TypeBoundList {
}
}
impl ast::PathSegment {
pub fn get_or_create_generic_arg_list(&self) -> ast::GenericArgList {
if self.generic_arg_list().is_none() {
let arg_list = make::generic_arg_list().clone_for_update();
ted::append_child(self.syntax(), arg_list.syntax())
}
self.generic_arg_list().unwrap()
}
}
impl ast::UseTree {
pub fn remove(&self) {
for &dir in [Direction::Next, Direction::Prev].iter() {

View file

@ -106,6 +106,10 @@ pub fn impl_trait(trait_: ast::Path, ty: ast::Path) -> ast::Impl {
ast_from_text(&format!("impl {} for {} {{}}", trait_, ty))
}
pub(crate) fn generic_arg_list() -> ast::GenericArgList {
ast_from_text("const S: T<> = ();")
}
pub fn path_segment(name_ref: ast::NameRef) -> ast::PathSegment {
ast_from_text(&format!("use {};", name_ref))
}

View file

@ -184,6 +184,9 @@ fn ws_between(left: &SyntaxElement, right: &SyntaxElement) -> Option<SyntaxToken
if left.kind() == T![&] && right.kind() == SyntaxKind::LIFETIME {
return None;
}
if right.kind() == SyntaxKind::GENERIC_ARG_LIST {
return None;
}
if right.kind() == SyntaxKind::USE {
let mut indent = IndentLevel::from_element(left);