diff --git a/crates/assists/src/handlers/replace_impl_trait_with_generic.rs b/crates/assists/src/handlers/replace_impl_trait_with_generic.rs new file mode 100644 index 0000000000..8af2d16dd8 --- /dev/null +++ b/crates/assists/src/handlers/replace_impl_trait_with_generic.rs @@ -0,0 +1,62 @@ +use syntax::ast::{self, edit::AstNodeEdit, make, AstNode, GenericParamsOwner}; + +use crate::{AssistContext, AssistId, AssistKind, Assists}; + +// Assist: replace_impl_trait_with_generic +// +// Replaces `impl Trait` function argument with the named generic. +pub(crate) fn replace_impl_trait_with_generic( + acc: &mut Assists, + ctx: &AssistContext, +) -> Option<()> { + let type_impl_trait = ctx.find_node_at_offset::()?; + let type_param = type_impl_trait.syntax().parent().and_then(ast::Param::cast)?; + let type_fn = type_param.syntax().ancestors().nth(2).and_then(ast::Fn::cast)?; + + let generic_param_list = + type_fn.generic_param_list().unwrap_or_else(|| make::generic_param_list(None)); + + let impl_trait_ty = type_impl_trait + .syntax() + .descendants() + .last() + .and_then(ast::NameRef::cast)? + .text() + .to_string(); + + let target = type_fn.syntax().text_range(); + acc.add( + AssistId("replace_impl_trait_with_generic", AssistKind::RefactorRewrite), + "Replace impl trait with generic", + target, + |edit| { + let generic_letter = impl_trait_ty[..1].to_string(); + edit.replace_ast::(type_impl_trait.into(), make::ty(&generic_letter)); + + let new_params = generic_param_list + .append_param(make::generic_param(generic_letter, Some(impl_trait_ty))); + let new_type_fn = type_fn.replace_descendant(generic_param_list, new_params); + edit.replace_ast(type_fn.clone(), new_type_fn); + }, + ) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::tests::check_assist; + + #[test] + fn replace_with_generic_params() { + check_assist( + replace_impl_trait_with_generic, + r#" + fn foo(bar: <|>impl Bar) {} + "#, + r#" + fn foo(bar: B) {} + "#, + ); + } +} diff --git a/crates/assists/src/lib.rs b/crates/assists/src/lib.rs index 2e0d191a60..cbac53e711 100644 --- a/crates/assists/src/lib.rs +++ b/crates/assists/src/lib.rs @@ -155,6 +155,7 @@ mod handlers { mod remove_unused_param; mod reorder_fields; mod replace_if_let_with_match; + mod replace_impl_trait_with_generic; mod replace_let_with_if_let; mod replace_qualified_name_with_use; mod replace_unwrap_with_match; @@ -202,6 +203,7 @@ mod handlers { remove_unused_param::remove_unused_param, reorder_fields::reorder_fields, replace_if_let_with_match::replace_if_let_with_match, + replace_impl_trait_with_generic::replace_impl_trait_with_generic, replace_let_with_if_let::replace_let_with_if_let, replace_qualified_name_with_use::replace_qualified_name_with_use, replace_unwrap_with_match::replace_unwrap_with_match, diff --git a/crates/syntax/src/ast/edit.rs b/crates/syntax/src/ast/edit.rs index 8234753332..1ccb4de6af 100644 --- a/crates/syntax/src/ast/edit.rs +++ b/crates/syntax/src/ast/edit.rs @@ -459,6 +459,72 @@ impl ast::MatchArmList { } } +impl ast::GenericParamList { + #[must_use] + pub fn append_params(&self, params: impl IntoIterator) -> Self { + let mut res = self.clone(); + params.into_iter().for_each(|it| res = res.append_param(it)); + res + } + + #[must_use] + pub fn append_param(&self, item: ast::GenericParam) -> Self { + let is_multiline = self.syntax().text().contains_char('\n'); + let ws; + let space = if is_multiline { + ws = tokens::WsBuilder::new(&format!( + "\n{} ", + leading_indent(self.syntax()).unwrap_or_default() + )); + ws.ws() + } else { + tokens::single_space() + }; + + let mut to_insert: ArrayVec<[SyntaxElement; 4]> = ArrayVec::new(); + to_insert.push(space.into()); + to_insert.push(item.syntax().clone().into()); + to_insert.push(make::token(T![,]).into()); + + macro_rules! after_l_angle { + () => {{ + let anchor = match self.l_angle_token() { + Some(it) => it.into(), + None => return self.clone(), + }; + InsertPosition::After(anchor) + }}; + } + + macro_rules! after_field { + ($anchor:expr) => { + if let Some(comma) = $anchor + .syntax() + .siblings_with_tokens(Direction::Next) + .find(|it| it.kind() == T![,]) + { + InsertPosition::After(comma) + } else { + to_insert.insert(0, make::token(T![,]).into()); + InsertPosition::After($anchor.syntax().clone().into()) + } + }; + }; + + if !is_multiline { + // don't insert comma before angle + to_insert.pop(); + } + + let position = match self.generic_params().last() { + Some(it) => after_field!(it), + None => after_l_angle!(), + }; + + self.insert_children(position, to_insert) + } +} + #[must_use] pub fn remove_attrs_and_docs(node: &N) -> N { N::cast(remove_attrs_and_docs_inner(node.syntax().clone())).unwrap() diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs index c2c938ad11..7329e3039c 100644 --- a/crates/syntax/src/ast/make.rs +++ b/crates/syntax/src/ast/make.rs @@ -294,6 +294,21 @@ pub fn param_list(pats: impl IntoIterator) -> ast::ParamList ast_from_text(&format!("fn f({}) {{ }}", args)) } +pub fn generic_param(name: String, ty: Option) -> ast::GenericParam { + let bound = match ty { + Some(it) => format!(": {}", it), + None => String::new(), + }; + ast_from_text(&format!("fn f<{}{}>() {{ }}", name, bound)) +} + +pub fn generic_param_list( + pats: impl IntoIterator, +) -> ast::GenericParamList { + let args = pats.into_iter().join(", "); + ast_from_text(&format!("fn f<{}>() {{ }}", args)) +} + pub fn visibility_pub_crate() -> ast::Visibility { ast_from_text("pub(crate) struct S") }