fix: Support generics in extract_function assist

This change attempts to resolve issue #7636: Extract into Function does not
create a generic function with constraints when extracting generic code.

In `FunctionBody::analyze_container`, we now traverse the `ancestors` in search
of `AnyHasGenericParams`, and attach any `GenericParamList`s and `WhereClause`s
we find to the `ContainerInfo`.

Later, in `format_function`, we collect all the `GenericParam`s and
`WherePred`s from the container, and filter them to keep only types matching
`TypeParam`s used within the newly extracted function body or param list. We
can then include the new `GenericParamList` and `WhereClause` in the new
function definition.

This change only impacts `TypeParam`s. `LifetimeParam`s and `ConstParam`s are
out of scope for this change.
This commit is contained in:
Dorian Scheidt 2022-04-30 12:29:55 -05:00
parent 794ecd58a3
commit 075ab03851
2 changed files with 436 additions and 7 deletions

View file

@ -3307,6 +3307,15 @@ impl Type {
let tys = hir_ty::replace_errors_with_variables(&(self.ty.clone(), to.ty.clone()));
hir_ty::could_coerce(db, self.env.clone(), &tys)
}
pub fn as_type_param(&self, db: &dyn HirDatabase) -> Option<TypeParam> {
match self.ty.kind(Interner) {
TyKind::Placeholder(p) => Some(TypeParam {
id: TypeParamId::from_unchecked(hir_ty::from_placeholder_idx(db, *p)),
}),
_ => None,
}
}
}
#[derive(Debug)]

View file

@ -2,7 +2,9 @@ use std::iter;
use ast::make;
use either::Either;
use hir::{HasSource, HirDisplay, InFile, Local, ModuleDef, Semantics, TypeInfo};
use hir::{
HasSource, HirDisplay, InFile, Local, ModuleDef, PathResolution, Semantics, TypeInfo, TypeParam,
};
use ide_db::{
defs::{Definition, NameRefClass},
famous_defs::FamousDefs,
@ -18,7 +20,7 @@ use syntax::{
ast::{
self,
edit::{AstNodeEdit, IndentLevel},
AstNode,
AstNode, HasGenericParams,
},
match_ast, ted, SyntaxElement,
SyntaxKind::{self, COMMENT},
@ -294,6 +296,8 @@ struct ContainerInfo {
parent_loop: Option<SyntaxNode>,
/// The function's return type, const's type etc.
ret_type: Option<hir::Type>,
generic_param_lists: Vec<ast::GenericParamList>,
where_clauses: Vec<ast::WhereClause>,
}
/// Control flow that is exported from extracted function
@ -517,6 +521,24 @@ impl FunctionBody {
}
}
fn descendants(&self) -> impl Iterator<Item = SyntaxNode> {
match self {
FunctionBody::Expr(expr) => expr.syntax().descendants(),
FunctionBody::Span { parent, .. } => parent.syntax().descendants(),
}
}
fn descendant_paths(&self) -> impl Iterator<Item = ast::Path> {
self.descendants().filter_map(|node| {
match_ast! {
match node {
ast::Path(it) => Some(it),
_ => None
}
}
})
}
fn from_expr(expr: ast::Expr) -> Option<Self> {
match expr {
ast::Expr::BreakExpr(it) => it.expr().map(Self::Expr),
@ -731,6 +753,7 @@ impl FunctionBody {
parent_loop.get_or_insert(loop_.syntax().clone());
}
};
let (is_const, expr, ty) = loop {
let anc = ancestors.next()?;
break match_ast! {
@ -798,7 +821,19 @@ impl FunctionBody {
container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| {
container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range())
});
Some(ContainerInfo { is_in_tail, is_const, parent_loop, ret_type: ty })
let parent = self.parent()?;
let generic_param_lists = parent_generic_param_lists(&parent);
let where_clauses = parent_where_clauses(&parent);
Some(ContainerInfo {
is_in_tail,
is_const,
parent_loop,
ret_type: ty,
generic_param_lists,
where_clauses,
})
}
fn return_ty(&self, ctx: &AssistContext) -> Option<RetType> {
@ -955,6 +990,26 @@ impl FunctionBody {
}
}
fn parent_where_clauses(parent: &SyntaxNode) -> Vec<ast::WhereClause> {
let mut where_clause: Vec<ast::WhereClause> = parent
.ancestors()
.filter_map(ast::AnyHasGenericParams::cast)
.filter_map(|it| it.where_clause())
.collect();
where_clause.reverse();
where_clause
}
fn parent_generic_param_lists(parent: &SyntaxNode) -> Vec<ast::GenericParamList> {
let mut generic_param_list: Vec<ast::GenericParamList> = parent
.ancestors()
.filter_map(ast::AnyHasGenericParams::cast)
.filter_map(|it| it.generic_param_list())
.collect();
generic_param_list.reverse();
generic_param_list
}
/// checks if relevant var is used with `&mut` access inside body
fn has_exclusive_usages(ctx: &AssistContext, usages: &LocalUsages, body: &FunctionBody) -> bool {
usages
@ -1362,37 +1417,154 @@ fn format_function(
let const_kw = if fun.mods.is_const { "const " } else { "" };
let async_kw = if fun.control_flow.is_async { "async " } else { "" };
let unsafe_kw = if fun.control_flow.is_unsafe { "unsafe " } else { "" };
let (generic_params, where_clause) = make_generic_params_and_where_clause(ctx, fun);
match ctx.config.snippet_cap {
Some(_) => format_to!(
fn_def,
"\n\n{}{}{}{}fn $0{}{}",
"\n\n{}{}{}{}fn $0{}",
new_indent,
const_kw,
async_kw,
unsafe_kw,
fun.name,
params
),
None => format_to!(
fn_def,
"\n\n{}{}{}{}fn {}{}",
"\n\n{}{}{}{}fn {}",
new_indent,
const_kw,
async_kw,
unsafe_kw,
fun.name,
params
),
}
if let Some(generic_params) = generic_params {
format_to!(fn_def, "{}", generic_params);
}
format_to!(fn_def, "{}", params);
if let Some(ret_ty) = ret_ty {
format_to!(fn_def, " {}", ret_ty);
}
if let Some(where_clause) = where_clause {
format_to!(fn_def, " {}", where_clause);
}
format_to!(fn_def, " {}", body);
fn_def
}
fn make_generic_params_and_where_clause(
ctx: &AssistContext,
fun: &Function,
) -> (Option<ast::GenericParamList>, Option<ast::WhereClause>) {
let used_type_params = fun.type_params(ctx);
let generic_param_list = make_generic_param_list(ctx, fun, &used_type_params);
let where_clause = make_where_clause(ctx, fun, &used_type_params);
(generic_param_list, where_clause)
}
fn make_generic_param_list(
ctx: &AssistContext,
fun: &Function,
used_type_params: &[TypeParam],
) -> Option<ast::GenericParamList> {
let mut generic_params = fun
.mods
.generic_param_lists
.iter()
.flat_map(|parent_params| {
parent_params
.generic_params()
.filter(|param| param_is_required(ctx, param, used_type_params))
})
.peekable();
if generic_params.peek().is_some() {
Some(make::generic_param_list(generic_params))
} else {
None
}
}
fn param_is_required(
ctx: &AssistContext,
param: &ast::GenericParam,
used_type_params: &[TypeParam],
) -> bool {
match param {
ast::GenericParam::ConstParam(_) | ast::GenericParam::LifetimeParam(_) => false,
ast::GenericParam::TypeParam(type_param) => match &ctx.sema.to_def(type_param) {
Some(def) => used_type_params.contains(def),
_ => false,
},
}
}
fn make_where_clause(
ctx: &AssistContext,
fun: &Function,
used_type_params: &[TypeParam],
) -> Option<ast::WhereClause> {
let mut predicates = fun
.mods
.where_clauses
.iter()
.flat_map(|parent_where_clause| {
parent_where_clause
.predicates()
.filter(|pred| pred_is_required(ctx, pred, used_type_params))
})
.peekable();
if predicates.peek().is_some() {
Some(make::where_clause(predicates))
} else {
None
}
}
fn pred_is_required(
ctx: &AssistContext,
pred: &ast::WherePred,
used_type_params: &[TypeParam],
) -> bool {
match resolved_type_param(ctx, pred) {
Some(it) => used_type_params.contains(&it),
None => false,
}
}
fn resolved_type_param(ctx: &AssistContext, pred: &ast::WherePred) -> Option<TypeParam> {
let path = match pred.ty()? {
ast::Type::PathType(path_type) => path_type.path(),
_ => None,
}?;
match ctx.sema.resolve_path(&path)? {
PathResolution::TypeParam(type_param) => Some(type_param),
_ => None,
}
}
impl Function {
/// Collect all the `TypeParam`s used in the `body` and `params`.
fn type_params(&self, ctx: &AssistContext) -> Vec<TypeParam> {
let type_params_in_descendant_paths =
self.body.descendant_paths().filter_map(|it| match ctx.sema.resolve_path(&it) {
Some(PathResolution::TypeParam(type_param)) => Some(type_param),
_ => None,
});
let type_params_in_params = self.params.iter().filter_map(|p| p.ty.as_type_param(ctx.db()));
type_params_in_descendant_paths.chain(type_params_in_params).collect()
}
fn make_param_list(&self, ctx: &AssistContext, module: hir::Module) -> ast::ParamList {
let self_param = self.self_param.clone();
let params = self.params.iter().map(|param| param.to_param(ctx, module));
@ -4872,6 +5044,254 @@ fn parent(factor: i32) {
fn $0fun_name(v: &[i32; 3], factor: i32) {
v.iter().map(|it| it * factor);
}
"#,
);
}
#[test]
fn preserve_generics() {
check_assist(
extract_function,
r#"
fn func<T: Debug>(i: T) {
$0foo(i);$0
}
"#,
r#"
fn func<T: Debug>(i: T) {
fun_name(i);
}
fn $0fun_name<T: Debug>(i: T) {
foo(i);
}
"#,
);
}
#[test]
fn preserve_generics_from_body() {
check_assist(
extract_function,
r#"
fn func<T: Default>() -> T {
$0T::default()$0
}
"#,
r#"
fn func<T: Default>() -> T {
fun_name()
}
fn $0fun_name<T: Default>() -> T {
T::default()
}
"#,
);
}
#[test]
fn filter_unused_generics() {
check_assist(
extract_function,
r#"
fn func<T: Debug, U: Copy>(i: T, u: U) {
bar(u);
$0foo(i);$0
}
"#,
r#"
fn func<T: Debug, U: Copy>(i: T, u: U) {
bar(u);
fun_name(i);
}
fn $0fun_name<T: Debug>(i: T) {
foo(i);
}
"#,
);
}
#[test]
fn empty_generic_param_list() {
check_assist(
extract_function,
r#"
fn func<T: Debug>(t: T, i: u32) {
bar(t);
$0foo(i);$0
}
"#,
r#"
fn func<T: Debug>(t: T, i: u32) {
bar(t);
fun_name(i);
}
fn $0fun_name(i: u32) {
foo(i);
}
"#,
);
}
#[test]
fn preserve_where_clause() {
check_assist(
extract_function,
r#"
fn func<T>(i: T) where T: Debug {
$0foo(i);$0
}
"#,
r#"
fn func<T>(i: T) where T: Debug {
fun_name(i);
}
fn $0fun_name<T>(i: T) where T: Debug {
foo(i);
}
"#,
);
}
#[test]
fn filter_unused_where_clause() {
check_assist(
extract_function,
r#"
fn func<T, U>(i: T, u: U) where T: Debug, U: Copy {
bar(u);
$0foo(i);$0
}
"#,
r#"
fn func<T, U>(i: T, u: U) where T: Debug, U: Copy {
bar(u);
fun_name(i);
}
fn $0fun_name<T>(i: T) where T: Debug {
foo(i);
}
"#,
);
}
#[test]
fn nested_generics() {
check_assist(
extract_function,
r#"
struct Struct<T: Into<i32>>(T);
impl <T: Into<i32> + Copy> Struct<T> {
fn func<V: Into<i32>>(&self, v: V) -> i32 {
let t = self.0;
$0t.into() + v.into()$0
}
}
"#,
r#"
struct Struct<T: Into<i32>>(T);
impl <T: Into<i32> + Copy> Struct<T> {
fn func<V: Into<i32>>(&self, v: V) -> i32 {
let t = self.0;
fun_name(t, v)
}
}
fn $0fun_name<T: Into<i32> + Copy, V: Into<i32>>(t: T, v: V) -> i32 {
t.into() + v.into()
}
"#,
);
}
#[test]
fn filters_unused_nested_generics() {
check_assist(
extract_function,
r#"
struct Struct<T: Into<i32>, U: Debug>(T, U);
impl <T: Into<i32> + Copy, U: Debug> Struct<T, U> {
fn func<V: Into<i32>>(&self, v: V) -> i32 {
let t = self.0;
$0t.into() + v.into()$0
}
}
"#,
r#"
struct Struct<T: Into<i32>, U: Debug>(T, U);
impl <T: Into<i32> + Copy, U: Debug> Struct<T, U> {
fn func<V: Into<i32>>(&self, v: V) -> i32 {
let t = self.0;
fun_name(t, v)
}
}
fn $0fun_name<T: Into<i32> + Copy, V: Into<i32>>(t: T, v: V) -> i32 {
t.into() + v.into()
}
"#,
);
}
#[test]
fn nested_where_clauses() {
check_assist(
extract_function,
r#"
struct Struct<T>(T) where T: Into<i32>;
impl <T> Struct<T> where T: Into<i32> + Copy {
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
let t = self.0;
$0t.into() + v.into()$0
}
}
"#,
r#"
struct Struct<T>(T) where T: Into<i32>;
impl <T> Struct<T> where T: Into<i32> + Copy {
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
let t = self.0;
fun_name(t, v)
}
}
fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
t.into() + v.into()
}
"#,
);
}
#[test]
fn filters_unused_nested_where_clauses() {
check_assist(
extract_function,
r#"
struct Struct<T, U>(T, U) where T: Into<i32>, U: Debug;
impl <T, U> Struct<T, U> where T: Into<i32> + Copy, U: Debug {
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
let t = self.0;
$0t.into() + v.into()$0
}
}
"#,
r#"
struct Struct<T, U>(T, U) where T: Into<i32>, U: Debug;
impl <T, U> Struct<T, U> where T: Into<i32> + Copy, U: Debug {
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
let t = self.0;
fun_name(t, v)
}
}
fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
t.into() + v.into()
}
"#,
);
}