diff --git a/crates/ra_assists/src/ast_editor.rs b/crates/ra_assists/src/ast_editor.rs index 0484786625..a710edce8e 100644 --- a/crates/ra_assists/src/ast_editor.rs +++ b/crates/ra_assists/src/ast_editor.rs @@ -297,6 +297,11 @@ impl AstBuilder { ast_node_from_file_text(text) } + pub fn from_name(name: ast::Name) -> ast::Path { + let name = name.syntax().to_string(); + Self::from_text(name.as_str()) + } + pub fn from_pieces(enum_name: ast::Name, var_name: ast::Name) -> ast::Path { Self::from_text(&format!("{}::{}", enum_name.syntax(), var_name.syntax())) } @@ -380,6 +385,31 @@ impl AstBuilder { } } +impl AstBuilder { + fn from_text(text: &str) -> ast::WherePred { + ast_node_from_file_text(&format!("fn f() where {} {{ }}", text)) + } + + pub fn from_pieces( + path: ast::Path, + bounds: impl Iterator, + ) -> ast::WherePred { + let bounds = bounds.map(|b| b.syntax().to_string()).collect::>().join(" + "); + Self::from_text(&format!("{}: {}", path.syntax(), bounds)) + } +} + +impl AstBuilder { + fn from_text(text: &str) -> ast::WhereClause { + ast_node_from_file_text(&format!("fn f() where {} {{ }}", text)) + } + + pub fn from_predicates(preds: impl Iterator) -> ast::WhereClause { + let preds = preds.map(|p| p.syntax().to_string()).collect::>().join(", "); + Self::from_text(preds.as_str()) + } +} + fn ast_node_from_file_text(text: &str) -> N { let parse = SourceFile::parse(text); let res = parse.tree().syntax().descendants().find_map(N::cast).unwrap(); diff --git a/crates/ra_assists/src/lib.rs b/crates/ra_assists/src/lib.rs index 03eec73ad8..10ccc345c9 100644 --- a/crates/ra_assists/src/lib.rs +++ b/crates/ra_assists/src/lib.rs @@ -102,6 +102,7 @@ mod remove_dbg; pub mod auto_import; mod add_missing_impl_members; mod move_guard; +mod move_bounds; fn all_assists() -> &'static [fn(AssistCtx) -> Option] { &[ @@ -123,6 +124,7 @@ fn all_assists() -> &'static [fn(AssistCtx) -> Option) -> Option { + let type_param_list = ctx.node_at_offset::()?; + + let mut type_params = type_param_list.type_params(); + if type_params.all(|p| p.type_bound_list().is_none()) { + return None; + } + + let parent = type_param_list.syntax().parent()?; + if parent.children_with_tokens().find(|it| it.kind() == WHERE_CLAUSE).is_some() { + return None; + } + + let anchor: SyntaxElement = match parent.kind() { + FN_DEF => ast::FnDef::cast(parent)?.body()?.syntax().clone().into(), + TRAIT_DEF => ast::TraitDef::cast(parent)?.item_list()?.syntax().clone().into(), + IMPL_BLOCK => ast::ImplBlock::cast(parent)?.item_list()?.syntax().clone().into(), + ENUM_DEF => ast::EnumDef::cast(parent)?.variant_list()?.syntax().clone().into(), + STRUCT_DEF => parent + .children_with_tokens() + .find(|it| it.kind() == RECORD_FIELD_DEF_LIST || it.kind() == SEMI)?, + _ => return None, + }; + + ctx.add_action( + AssistId("move_bounds_to_where_clause"), + "move_bounds_to_where_clause", + |edit| { + let type_params = type_param_list.type_params().collect::>(); + + for param in &type_params { + if let Some(bounds) = param.type_bound_list() { + let colon = param + .syntax() + .children_with_tokens() + .find(|it| it.kind() == COLON) + .unwrap(); + let start = colon.text_range().start(); + let end = bounds.syntax().text_range().end(); + edit.delete(TextRange::from_to(start, end)); + } + } + + let predicates = type_params.iter().filter_map(build_predicate); + let where_clause = AstBuilder::::from_predicates(predicates); + + let to_insert = match anchor.prev_sibling_or_token() { + Some(ref elem) if elem.kind() == WHITESPACE => { + format!("{} ", where_clause.syntax()) + } + _ => format!(" {}", where_clause.syntax()), + }; + edit.insert(anchor.text_range().start(), to_insert); + edit.target(type_param_list.syntax().text_range()); + }, + ); + + ctx.build() +} + +fn build_predicate(param: &ast::TypeParam) -> Option { + let path = AstBuilder::::from_name(param.name()?); + let predicate = + AstBuilder::::from_pieces(path, param.type_bound_list()?.bounds()); + Some(predicate) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::helpers::check_assist; + + #[test] + fn move_bounds_to_where_clause_fn() { + check_assist( + move_bounds_to_where_clause, + r#" + fn fooF: FnOnce(T) -> T>() {} + "#, + r#" + fn fooF>() where T: u32, F: FnOnce(T) -> T {} + "#, + ); + } + + #[test] + fn move_bounds_to_where_clause_impl() { + check_assist( + move_bounds_to_where_clause, + r#" + implT> A {} + "#, + r#" + implT> A where U: u32 {} + "#, + ); + } + + #[test] + fn move_bounds_to_where_clause_struct() { + check_assist( + move_bounds_to_where_clause, + r#" + struct A<<|>T: Iterator> {} + "#, + r#" + struct A<<|>T> where T: Iterator {} + "#, + ); + } + + #[test] + fn move_bounds_to_where_clause_tuple_struct() { + check_assist( + move_bounds_to_where_clause, + r#" + struct Pair<<|>T: u32>(T, T); + "#, + r#" + struct Pair<<|>T>(T, T) where T: u32; + "#, + ); + } +} diff --git a/docs/user/features.md b/docs/user/features.md index b6e6008c42..1034a51172 100644 --- a/docs/user/features.md +++ b/docs/user/features.md @@ -435,6 +435,16 @@ fn f() { } ``` +- Move type bounds to where clause + +```rust +// before: +fn foo T>() {} + +// after: +fn foo() where T: u32, F: FnOnce(T) -> T {} +``` + ### Magic Completions In addition to usual reference completion, rust-analyzer provides some ✨magic✨