From 075ab03851d7b4304ddf563d75c4ee3d713f2583 Mon Sep 17 00:00:00 2001 From: Dorian Scheidt Date: Sat, 30 Apr 2022 12:29:55 -0500 Subject: [PATCH 1/2] fix: Support generics in extract_function assist This change attempts to resolve issue #7636: Extract into Function does not create a generic function with constraints when extracting generic code. In `FunctionBody::analyze_container`, we now traverse the `ancestors` in search of `AnyHasGenericParams`, and attach any `GenericParamList`s and `WhereClause`s we find to the `ContainerInfo`. Later, in `format_function`, we collect all the `GenericParam`s and `WherePred`s from the container, and filter them to keep only types matching `TypeParam`s used within the newly extracted function body or param list. We can then include the new `GenericParamList` and `WhereClause` in the new function definition. This change only impacts `TypeParam`s. `LifetimeParam`s and `ConstParam`s are out of scope for this change. --- crates/hir/src/lib.rs | 9 + .../src/handlers/extract_function.rs | 434 +++++++++++++++++- 2 files changed, 436 insertions(+), 7 deletions(-) diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index 96424d087e..86124b68b5 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -3307,6 +3307,15 @@ impl Type { let tys = hir_ty::replace_errors_with_variables(&(self.ty.clone(), to.ty.clone())); hir_ty::could_coerce(db, self.env.clone(), &tys) } + + pub fn as_type_param(&self, db: &dyn HirDatabase) -> Option { + match self.ty.kind(Interner) { + TyKind::Placeholder(p) => Some(TypeParam { + id: TypeParamId::from_unchecked(hir_ty::from_placeholder_idx(db, *p)), + }), + _ => None, + } + } } #[derive(Debug)] diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs index 9233c198df..aa1c3a548c 100644 --- a/crates/ide-assists/src/handlers/extract_function.rs +++ b/crates/ide-assists/src/handlers/extract_function.rs @@ -2,7 +2,9 @@ use std::iter; use ast::make; use either::Either; -use hir::{HasSource, HirDisplay, InFile, Local, ModuleDef, Semantics, TypeInfo}; +use hir::{ + HasSource, HirDisplay, InFile, Local, ModuleDef, PathResolution, Semantics, TypeInfo, TypeParam, +}; use ide_db::{ defs::{Definition, NameRefClass}, famous_defs::FamousDefs, @@ -18,7 +20,7 @@ use syntax::{ ast::{ self, edit::{AstNodeEdit, IndentLevel}, - AstNode, + AstNode, HasGenericParams, }, match_ast, ted, SyntaxElement, SyntaxKind::{self, COMMENT}, @@ -294,6 +296,8 @@ struct ContainerInfo { parent_loop: Option, /// The function's return type, const's type etc. ret_type: Option, + generic_param_lists: Vec, + where_clauses: Vec, } /// Control flow that is exported from extracted function @@ -517,6 +521,24 @@ impl FunctionBody { } } + fn descendants(&self) -> impl Iterator { + match self { + FunctionBody::Expr(expr) => expr.syntax().descendants(), + FunctionBody::Span { parent, .. } => parent.syntax().descendants(), + } + } + + fn descendant_paths(&self) -> impl Iterator { + self.descendants().filter_map(|node| { + match_ast! { + match node { + ast::Path(it) => Some(it), + _ => None + } + } + }) + } + fn from_expr(expr: ast::Expr) -> Option { match expr { ast::Expr::BreakExpr(it) => it.expr().map(Self::Expr), @@ -731,6 +753,7 @@ impl FunctionBody { parent_loop.get_or_insert(loop_.syntax().clone()); } }; + let (is_const, expr, ty) = loop { let anc = ancestors.next()?; break match_ast! { @@ -798,7 +821,19 @@ impl FunctionBody { container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| { container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range()) }); - Some(ContainerInfo { is_in_tail, is_const, parent_loop, ret_type: ty }) + + let parent = self.parent()?; + let generic_param_lists = parent_generic_param_lists(&parent); + let where_clauses = parent_where_clauses(&parent); + + Some(ContainerInfo { + is_in_tail, + is_const, + parent_loop, + ret_type: ty, + generic_param_lists, + where_clauses, + }) } fn return_ty(&self, ctx: &AssistContext) -> Option { @@ -955,6 +990,26 @@ impl FunctionBody { } } +fn parent_where_clauses(parent: &SyntaxNode) -> Vec { + let mut where_clause: Vec = parent + .ancestors() + .filter_map(ast::AnyHasGenericParams::cast) + .filter_map(|it| it.where_clause()) + .collect(); + where_clause.reverse(); + where_clause +} + +fn parent_generic_param_lists(parent: &SyntaxNode) -> Vec { + let mut generic_param_list: Vec = parent + .ancestors() + .filter_map(ast::AnyHasGenericParams::cast) + .filter_map(|it| it.generic_param_list()) + .collect(); + generic_param_list.reverse(); + generic_param_list +} + /// checks if relevant var is used with `&mut` access inside body fn has_exclusive_usages(ctx: &AssistContext, usages: &LocalUsages, body: &FunctionBody) -> bool { usages @@ -1362,37 +1417,154 @@ fn format_function( let const_kw = if fun.mods.is_const { "const " } else { "" }; let async_kw = if fun.control_flow.is_async { "async " } else { "" }; let unsafe_kw = if fun.control_flow.is_unsafe { "unsafe " } else { "" }; + let (generic_params, where_clause) = make_generic_params_and_where_clause(ctx, fun); match ctx.config.snippet_cap { Some(_) => format_to!( fn_def, - "\n\n{}{}{}{}fn $0{}{}", + "\n\n{}{}{}{}fn $0{}", new_indent, const_kw, async_kw, unsafe_kw, fun.name, - params ), None => format_to!( fn_def, - "\n\n{}{}{}{}fn {}{}", + "\n\n{}{}{}{}fn {}", new_indent, const_kw, async_kw, unsafe_kw, fun.name, - params ), } + + if let Some(generic_params) = generic_params { + format_to!(fn_def, "{}", generic_params); + } + + format_to!(fn_def, "{}", params); + if let Some(ret_ty) = ret_ty { format_to!(fn_def, " {}", ret_ty); } + + if let Some(where_clause) = where_clause { + format_to!(fn_def, " {}", where_clause); + } + format_to!(fn_def, " {}", body); fn_def } +fn make_generic_params_and_where_clause( + ctx: &AssistContext, + fun: &Function, +) -> (Option, Option) { + let used_type_params = fun.type_params(ctx); + + let generic_param_list = make_generic_param_list(ctx, fun, &used_type_params); + let where_clause = make_where_clause(ctx, fun, &used_type_params); + + (generic_param_list, where_clause) +} + +fn make_generic_param_list( + ctx: &AssistContext, + fun: &Function, + used_type_params: &[TypeParam], +) -> Option { + let mut generic_params = fun + .mods + .generic_param_lists + .iter() + .flat_map(|parent_params| { + parent_params + .generic_params() + .filter(|param| param_is_required(ctx, param, used_type_params)) + }) + .peekable(); + + if generic_params.peek().is_some() { + Some(make::generic_param_list(generic_params)) + } else { + None + } +} + +fn param_is_required( + ctx: &AssistContext, + param: &ast::GenericParam, + used_type_params: &[TypeParam], +) -> bool { + match param { + ast::GenericParam::ConstParam(_) | ast::GenericParam::LifetimeParam(_) => false, + ast::GenericParam::TypeParam(type_param) => match &ctx.sema.to_def(type_param) { + Some(def) => used_type_params.contains(def), + _ => false, + }, + } +} + +fn make_where_clause( + ctx: &AssistContext, + fun: &Function, + used_type_params: &[TypeParam], +) -> Option { + let mut predicates = fun + .mods + .where_clauses + .iter() + .flat_map(|parent_where_clause| { + parent_where_clause + .predicates() + .filter(|pred| pred_is_required(ctx, pred, used_type_params)) + }) + .peekable(); + + if predicates.peek().is_some() { + Some(make::where_clause(predicates)) + } else { + None + } +} + +fn pred_is_required( + ctx: &AssistContext, + pred: &ast::WherePred, + used_type_params: &[TypeParam], +) -> bool { + match resolved_type_param(ctx, pred) { + Some(it) => used_type_params.contains(&it), + None => false, + } +} + +fn resolved_type_param(ctx: &AssistContext, pred: &ast::WherePred) -> Option { + let path = match pred.ty()? { + ast::Type::PathType(path_type) => path_type.path(), + _ => None, + }?; + + match ctx.sema.resolve_path(&path)? { + PathResolution::TypeParam(type_param) => Some(type_param), + _ => None, + } +} + impl Function { + /// Collect all the `TypeParam`s used in the `body` and `params`. + fn type_params(&self, ctx: &AssistContext) -> Vec { + let type_params_in_descendant_paths = + self.body.descendant_paths().filter_map(|it| match ctx.sema.resolve_path(&it) { + Some(PathResolution::TypeParam(type_param)) => Some(type_param), + _ => None, + }); + let type_params_in_params = self.params.iter().filter_map(|p| p.ty.as_type_param(ctx.db())); + type_params_in_descendant_paths.chain(type_params_in_params).collect() + } + fn make_param_list(&self, ctx: &AssistContext, module: hir::Module) -> ast::ParamList { let self_param = self.self_param.clone(); let params = self.params.iter().map(|param| param.to_param(ctx, module)); @@ -4872,6 +5044,254 @@ fn parent(factor: i32) { fn $0fun_name(v: &[i32; 3], factor: i32) { v.iter().map(|it| it * factor); } +"#, + ); + } + + #[test] + fn preserve_generics() { + check_assist( + extract_function, + r#" +fn func(i: T) { + $0foo(i);$0 +} +"#, + r#" +fn func(i: T) { + fun_name(i); +} + +fn $0fun_name(i: T) { + foo(i); +} +"#, + ); + } + + #[test] + fn preserve_generics_from_body() { + check_assist( + extract_function, + r#" +fn func() -> T { + $0T::default()$0 +} +"#, + r#" +fn func() -> T { + fun_name() +} + +fn $0fun_name() -> T { + T::default() +} +"#, + ); + } + + #[test] + fn filter_unused_generics() { + check_assist( + extract_function, + r#" +fn func(i: T, u: U) { + bar(u); + $0foo(i);$0 +} +"#, + r#" +fn func(i: T, u: U) { + bar(u); + fun_name(i); +} + +fn $0fun_name(i: T) { + foo(i); +} +"#, + ); + } + + #[test] + fn empty_generic_param_list() { + check_assist( + extract_function, + r#" +fn func(t: T, i: u32) { + bar(t); + $0foo(i);$0 +} +"#, + r#" +fn func(t: T, i: u32) { + bar(t); + fun_name(i); +} + +fn $0fun_name(i: u32) { + foo(i); +} +"#, + ); + } + + #[test] + fn preserve_where_clause() { + check_assist( + extract_function, + r#" +fn func(i: T) where T: Debug { + $0foo(i);$0 +} +"#, + r#" +fn func(i: T) where T: Debug { + fun_name(i); +} + +fn $0fun_name(i: T) where T: Debug { + foo(i); +} +"#, + ); + } + + #[test] + fn filter_unused_where_clause() { + check_assist( + extract_function, + r#" +fn func(i: T, u: U) where T: Debug, U: Copy { + bar(u); + $0foo(i);$0 +} +"#, + r#" +fn func(i: T, u: U) where T: Debug, U: Copy { + bar(u); + fun_name(i); +} + +fn $0fun_name(i: T) where T: Debug { + foo(i); +} +"#, + ); + } + + #[test] + fn nested_generics() { + check_assist( + extract_function, + r#" +struct Struct>(T); +impl + Copy> Struct { + fn func>(&self, v: V) -> i32 { + let t = self.0; + $0t.into() + v.into()$0 + } +} +"#, + r#" +struct Struct>(T); +impl + Copy> Struct { + fn func>(&self, v: V) -> i32 { + let t = self.0; + fun_name(t, v) + } +} + +fn $0fun_name + Copy, V: Into>(t: T, v: V) -> i32 { + t.into() + v.into() +} +"#, + ); + } + + #[test] + fn filters_unused_nested_generics() { + check_assist( + extract_function, + r#" +struct Struct, U: Debug>(T, U); +impl + Copy, U: Debug> Struct { + fn func>(&self, v: V) -> i32 { + let t = self.0; + $0t.into() + v.into()$0 + } +} +"#, + r#" +struct Struct, U: Debug>(T, U); +impl + Copy, U: Debug> Struct { + fn func>(&self, v: V) -> i32 { + let t = self.0; + fun_name(t, v) + } +} + +fn $0fun_name + Copy, V: Into>(t: T, v: V) -> i32 { + t.into() + v.into() +} +"#, + ); + } + + #[test] + fn nested_where_clauses() { + check_assist( + extract_function, + r#" +struct Struct(T) where T: Into; +impl Struct where T: Into + Copy { + fn func(&self, v: V) -> i32 where V: Into { + let t = self.0; + $0t.into() + v.into()$0 + } +} +"#, + r#" +struct Struct(T) where T: Into; +impl Struct where T: Into + Copy { + fn func(&self, v: V) -> i32 where V: Into { + let t = self.0; + fun_name(t, v) + } +} + +fn $0fun_name(t: T, v: V) -> i32 where T: Into + Copy, V: Into { + t.into() + v.into() +} +"#, + ); + } + + #[test] + fn filters_unused_nested_where_clauses() { + check_assist( + extract_function, + r#" +struct Struct(T, U) where T: Into, U: Debug; +impl Struct where T: Into + Copy, U: Debug { + fn func(&self, v: V) -> i32 where V: Into { + let t = self.0; + $0t.into() + v.into()$0 + } +} +"#, + r#" +struct Struct(T, U) where T: Into, U: Debug; +impl Struct where T: Into + Copy, U: Debug { + fn func(&self, v: V) -> i32 where V: Into { + let t = self.0; + fun_name(t, v) + } +} + +fn $0fun_name(t: T, v: V) -> i32 where T: Into + Copy, V: Into { + t.into() + v.into() +} "#, ); } From 796641b5d85d06f2884e28971a358421976aefaa Mon Sep 17 00:00:00 2001 From: Dorian Scheidt Date: Wed, 13 Jul 2022 10:20:55 -0500 Subject: [PATCH 2/2] Make search for applicable generics more precise --- .../src/handlers/extract_function.rs | 67 ++++++++++++++----- 1 file changed, 49 insertions(+), 18 deletions(-) diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs index aa1c3a548c..94b638d4c6 100644 --- a/crates/ide-assists/src/handlers/extract_function.rs +++ b/crates/ide-assists/src/handlers/extract_function.rs @@ -823,8 +823,9 @@ impl FunctionBody { }); let parent = self.parent()?; - let generic_param_lists = parent_generic_param_lists(&parent); - let where_clauses = parent_where_clauses(&parent); + let parents = generic_parents(&parent); + let generic_param_lists = parents.iter().filter_map(|it| it.generic_param_list()).collect(); + let where_clauses = parents.iter().filter_map(|it| it.where_clause()).collect(); Some(ContainerInfo { is_in_tail, @@ -990,24 +991,54 @@ impl FunctionBody { } } -fn parent_where_clauses(parent: &SyntaxNode) -> Vec { - let mut where_clause: Vec = parent - .ancestors() - .filter_map(ast::AnyHasGenericParams::cast) - .filter_map(|it| it.where_clause()) - .collect(); - where_clause.reverse(); - where_clause +enum GenericParent { + Fn(ast::Fn), + Impl(ast::Impl), + Trait(ast::Trait), } -fn parent_generic_param_lists(parent: &SyntaxNode) -> Vec { - let mut generic_param_list: Vec = parent - .ancestors() - .filter_map(ast::AnyHasGenericParams::cast) - .filter_map(|it| it.generic_param_list()) - .collect(); - generic_param_list.reverse(); - generic_param_list +impl GenericParent { + fn generic_param_list(&self) -> Option { + match self { + GenericParent::Fn(fn_) => fn_.generic_param_list(), + GenericParent::Impl(impl_) => impl_.generic_param_list(), + GenericParent::Trait(trait_) => trait_.generic_param_list(), + } + } + + fn where_clause(&self) -> Option { + match self { + GenericParent::Fn(fn_) => fn_.where_clause(), + GenericParent::Impl(impl_) => impl_.where_clause(), + GenericParent::Trait(trait_) => trait_.where_clause(), + } + } +} + +/// Search `parent`'s ancestors for items with potentially applicable generic parameters +fn generic_parents(parent: &SyntaxNode) -> Vec { + let mut list = Vec::new(); + if let Some(parent_item) = parent.ancestors().find_map(ast::Item::cast) { + match parent_item { + ast::Item::Fn(ref fn_) => { + if let Some(parent_parent) = parent_item + .syntax() + .parent() + .and_then(|it| it.parent()) + .and_then(ast::Item::cast) + { + match parent_parent { + ast::Item::Impl(impl_) => list.push(GenericParent::Impl(impl_)), + ast::Item::Trait(trait_) => list.push(GenericParent::Trait(trait_)), + _ => (), + } + } + list.push(GenericParent::Fn(fn_.clone())); + } + _ => (), + } + } + list } /// checks if relevant var is used with `&mut` access inside body