mirror of
https://github.com/rust-lang/rust-analyzer
synced 2025-01-15 22:54:00 +00:00
fix: Support generics in extract_function assist
This change attempts to resolve issue #7636: Extract into Function does not create a generic function with constraints when extracting generic code. In `FunctionBody::analyze_container`, we now traverse the `ancestors` in search of `AnyHasGenericParams`, and attach any `GenericParamList`s and `WhereClause`s we find to the `ContainerInfo`. Later, in `format_function`, we collect all the `GenericParam`s and `WherePred`s from the container, and filter them to keep only types matching `TypeParam`s used within the newly extracted function body or param list. We can then include the new `GenericParamList` and `WhereClause` in the new function definition. This change only impacts `TypeParam`s. `LifetimeParam`s and `ConstParam`s are out of scope for this change.
This commit is contained in:
parent
794ecd58a3
commit
075ab03851
2 changed files with 436 additions and 7 deletions
|
@ -3307,6 +3307,15 @@ impl Type {
|
|||
let tys = hir_ty::replace_errors_with_variables(&(self.ty.clone(), to.ty.clone()));
|
||||
hir_ty::could_coerce(db, self.env.clone(), &tys)
|
||||
}
|
||||
|
||||
pub fn as_type_param(&self, db: &dyn HirDatabase) -> Option<TypeParam> {
|
||||
match self.ty.kind(Interner) {
|
||||
TyKind::Placeholder(p) => Some(TypeParam {
|
||||
id: TypeParamId::from_unchecked(hir_ty::from_placeholder_idx(db, *p)),
|
||||
}),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
|
|
@ -2,7 +2,9 @@ use std::iter;
|
|||
|
||||
use ast::make;
|
||||
use either::Either;
|
||||
use hir::{HasSource, HirDisplay, InFile, Local, ModuleDef, Semantics, TypeInfo};
|
||||
use hir::{
|
||||
HasSource, HirDisplay, InFile, Local, ModuleDef, PathResolution, Semantics, TypeInfo, TypeParam,
|
||||
};
|
||||
use ide_db::{
|
||||
defs::{Definition, NameRefClass},
|
||||
famous_defs::FamousDefs,
|
||||
|
@ -18,7 +20,7 @@ use syntax::{
|
|||
ast::{
|
||||
self,
|
||||
edit::{AstNodeEdit, IndentLevel},
|
||||
AstNode,
|
||||
AstNode, HasGenericParams,
|
||||
},
|
||||
match_ast, ted, SyntaxElement,
|
||||
SyntaxKind::{self, COMMENT},
|
||||
|
@ -294,6 +296,8 @@ struct ContainerInfo {
|
|||
parent_loop: Option<SyntaxNode>,
|
||||
/// The function's return type, const's type etc.
|
||||
ret_type: Option<hir::Type>,
|
||||
generic_param_lists: Vec<ast::GenericParamList>,
|
||||
where_clauses: Vec<ast::WhereClause>,
|
||||
}
|
||||
|
||||
/// Control flow that is exported from extracted function
|
||||
|
@ -517,6 +521,24 @@ impl FunctionBody {
|
|||
}
|
||||
}
|
||||
|
||||
fn descendants(&self) -> impl Iterator<Item = SyntaxNode> {
|
||||
match self {
|
||||
FunctionBody::Expr(expr) => expr.syntax().descendants(),
|
||||
FunctionBody::Span { parent, .. } => parent.syntax().descendants(),
|
||||
}
|
||||
}
|
||||
|
||||
fn descendant_paths(&self) -> impl Iterator<Item = ast::Path> {
|
||||
self.descendants().filter_map(|node| {
|
||||
match_ast! {
|
||||
match node {
|
||||
ast::Path(it) => Some(it),
|
||||
_ => None
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn from_expr(expr: ast::Expr) -> Option<Self> {
|
||||
match expr {
|
||||
ast::Expr::BreakExpr(it) => it.expr().map(Self::Expr),
|
||||
|
@ -731,6 +753,7 @@ impl FunctionBody {
|
|||
parent_loop.get_or_insert(loop_.syntax().clone());
|
||||
}
|
||||
};
|
||||
|
||||
let (is_const, expr, ty) = loop {
|
||||
let anc = ancestors.next()?;
|
||||
break match_ast! {
|
||||
|
@ -798,7 +821,19 @@ impl FunctionBody {
|
|||
container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| {
|
||||
container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range())
|
||||
});
|
||||
Some(ContainerInfo { is_in_tail, is_const, parent_loop, ret_type: ty })
|
||||
|
||||
let parent = self.parent()?;
|
||||
let generic_param_lists = parent_generic_param_lists(&parent);
|
||||
let where_clauses = parent_where_clauses(&parent);
|
||||
|
||||
Some(ContainerInfo {
|
||||
is_in_tail,
|
||||
is_const,
|
||||
parent_loop,
|
||||
ret_type: ty,
|
||||
generic_param_lists,
|
||||
where_clauses,
|
||||
})
|
||||
}
|
||||
|
||||
fn return_ty(&self, ctx: &AssistContext) -> Option<RetType> {
|
||||
|
@ -955,6 +990,26 @@ impl FunctionBody {
|
|||
}
|
||||
}
|
||||
|
||||
fn parent_where_clauses(parent: &SyntaxNode) -> Vec<ast::WhereClause> {
|
||||
let mut where_clause: Vec<ast::WhereClause> = parent
|
||||
.ancestors()
|
||||
.filter_map(ast::AnyHasGenericParams::cast)
|
||||
.filter_map(|it| it.where_clause())
|
||||
.collect();
|
||||
where_clause.reverse();
|
||||
where_clause
|
||||
}
|
||||
|
||||
fn parent_generic_param_lists(parent: &SyntaxNode) -> Vec<ast::GenericParamList> {
|
||||
let mut generic_param_list: Vec<ast::GenericParamList> = parent
|
||||
.ancestors()
|
||||
.filter_map(ast::AnyHasGenericParams::cast)
|
||||
.filter_map(|it| it.generic_param_list())
|
||||
.collect();
|
||||
generic_param_list.reverse();
|
||||
generic_param_list
|
||||
}
|
||||
|
||||
/// checks if relevant var is used with `&mut` access inside body
|
||||
fn has_exclusive_usages(ctx: &AssistContext, usages: &LocalUsages, body: &FunctionBody) -> bool {
|
||||
usages
|
||||
|
@ -1362,37 +1417,154 @@ fn format_function(
|
|||
let const_kw = if fun.mods.is_const { "const " } else { "" };
|
||||
let async_kw = if fun.control_flow.is_async { "async " } else { "" };
|
||||
let unsafe_kw = if fun.control_flow.is_unsafe { "unsafe " } else { "" };
|
||||
let (generic_params, where_clause) = make_generic_params_and_where_clause(ctx, fun);
|
||||
match ctx.config.snippet_cap {
|
||||
Some(_) => format_to!(
|
||||
fn_def,
|
||||
"\n\n{}{}{}{}fn $0{}{}",
|
||||
"\n\n{}{}{}{}fn $0{}",
|
||||
new_indent,
|
||||
const_kw,
|
||||
async_kw,
|
||||
unsafe_kw,
|
||||
fun.name,
|
||||
params
|
||||
),
|
||||
None => format_to!(
|
||||
fn_def,
|
||||
"\n\n{}{}{}{}fn {}{}",
|
||||
"\n\n{}{}{}{}fn {}",
|
||||
new_indent,
|
||||
const_kw,
|
||||
async_kw,
|
||||
unsafe_kw,
|
||||
fun.name,
|
||||
params
|
||||
),
|
||||
}
|
||||
|
||||
if let Some(generic_params) = generic_params {
|
||||
format_to!(fn_def, "{}", generic_params);
|
||||
}
|
||||
|
||||
format_to!(fn_def, "{}", params);
|
||||
|
||||
if let Some(ret_ty) = ret_ty {
|
||||
format_to!(fn_def, " {}", ret_ty);
|
||||
}
|
||||
|
||||
if let Some(where_clause) = where_clause {
|
||||
format_to!(fn_def, " {}", where_clause);
|
||||
}
|
||||
|
||||
format_to!(fn_def, " {}", body);
|
||||
|
||||
fn_def
|
||||
}
|
||||
|
||||
fn make_generic_params_and_where_clause(
|
||||
ctx: &AssistContext,
|
||||
fun: &Function,
|
||||
) -> (Option<ast::GenericParamList>, Option<ast::WhereClause>) {
|
||||
let used_type_params = fun.type_params(ctx);
|
||||
|
||||
let generic_param_list = make_generic_param_list(ctx, fun, &used_type_params);
|
||||
let where_clause = make_where_clause(ctx, fun, &used_type_params);
|
||||
|
||||
(generic_param_list, where_clause)
|
||||
}
|
||||
|
||||
fn make_generic_param_list(
|
||||
ctx: &AssistContext,
|
||||
fun: &Function,
|
||||
used_type_params: &[TypeParam],
|
||||
) -> Option<ast::GenericParamList> {
|
||||
let mut generic_params = fun
|
||||
.mods
|
||||
.generic_param_lists
|
||||
.iter()
|
||||
.flat_map(|parent_params| {
|
||||
parent_params
|
||||
.generic_params()
|
||||
.filter(|param| param_is_required(ctx, param, used_type_params))
|
||||
})
|
||||
.peekable();
|
||||
|
||||
if generic_params.peek().is_some() {
|
||||
Some(make::generic_param_list(generic_params))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn param_is_required(
|
||||
ctx: &AssistContext,
|
||||
param: &ast::GenericParam,
|
||||
used_type_params: &[TypeParam],
|
||||
) -> bool {
|
||||
match param {
|
||||
ast::GenericParam::ConstParam(_) | ast::GenericParam::LifetimeParam(_) => false,
|
||||
ast::GenericParam::TypeParam(type_param) => match &ctx.sema.to_def(type_param) {
|
||||
Some(def) => used_type_params.contains(def),
|
||||
_ => false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn make_where_clause(
|
||||
ctx: &AssistContext,
|
||||
fun: &Function,
|
||||
used_type_params: &[TypeParam],
|
||||
) -> Option<ast::WhereClause> {
|
||||
let mut predicates = fun
|
||||
.mods
|
||||
.where_clauses
|
||||
.iter()
|
||||
.flat_map(|parent_where_clause| {
|
||||
parent_where_clause
|
||||
.predicates()
|
||||
.filter(|pred| pred_is_required(ctx, pred, used_type_params))
|
||||
})
|
||||
.peekable();
|
||||
|
||||
if predicates.peek().is_some() {
|
||||
Some(make::where_clause(predicates))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn pred_is_required(
|
||||
ctx: &AssistContext,
|
||||
pred: &ast::WherePred,
|
||||
used_type_params: &[TypeParam],
|
||||
) -> bool {
|
||||
match resolved_type_param(ctx, pred) {
|
||||
Some(it) => used_type_params.contains(&it),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn resolved_type_param(ctx: &AssistContext, pred: &ast::WherePred) -> Option<TypeParam> {
|
||||
let path = match pred.ty()? {
|
||||
ast::Type::PathType(path_type) => path_type.path(),
|
||||
_ => None,
|
||||
}?;
|
||||
|
||||
match ctx.sema.resolve_path(&path)? {
|
||||
PathResolution::TypeParam(type_param) => Some(type_param),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
impl Function {
|
||||
/// Collect all the `TypeParam`s used in the `body` and `params`.
|
||||
fn type_params(&self, ctx: &AssistContext) -> Vec<TypeParam> {
|
||||
let type_params_in_descendant_paths =
|
||||
self.body.descendant_paths().filter_map(|it| match ctx.sema.resolve_path(&it) {
|
||||
Some(PathResolution::TypeParam(type_param)) => Some(type_param),
|
||||
_ => None,
|
||||
});
|
||||
let type_params_in_params = self.params.iter().filter_map(|p| p.ty.as_type_param(ctx.db()));
|
||||
type_params_in_descendant_paths.chain(type_params_in_params).collect()
|
||||
}
|
||||
|
||||
fn make_param_list(&self, ctx: &AssistContext, module: hir::Module) -> ast::ParamList {
|
||||
let self_param = self.self_param.clone();
|
||||
let params = self.params.iter().map(|param| param.to_param(ctx, module));
|
||||
|
@ -4872,6 +5044,254 @@ fn parent(factor: i32) {
|
|||
fn $0fun_name(v: &[i32; 3], factor: i32) {
|
||||
v.iter().map(|it| it * factor);
|
||||
}
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preserve_generics() {
|
||||
check_assist(
|
||||
extract_function,
|
||||
r#"
|
||||
fn func<T: Debug>(i: T) {
|
||||
$0foo(i);$0
|
||||
}
|
||||
"#,
|
||||
r#"
|
||||
fn func<T: Debug>(i: T) {
|
||||
fun_name(i);
|
||||
}
|
||||
|
||||
fn $0fun_name<T: Debug>(i: T) {
|
||||
foo(i);
|
||||
}
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preserve_generics_from_body() {
|
||||
check_assist(
|
||||
extract_function,
|
||||
r#"
|
||||
fn func<T: Default>() -> T {
|
||||
$0T::default()$0
|
||||
}
|
||||
"#,
|
||||
r#"
|
||||
fn func<T: Default>() -> T {
|
||||
fun_name()
|
||||
}
|
||||
|
||||
fn $0fun_name<T: Default>() -> T {
|
||||
T::default()
|
||||
}
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filter_unused_generics() {
|
||||
check_assist(
|
||||
extract_function,
|
||||
r#"
|
||||
fn func<T: Debug, U: Copy>(i: T, u: U) {
|
||||
bar(u);
|
||||
$0foo(i);$0
|
||||
}
|
||||
"#,
|
||||
r#"
|
||||
fn func<T: Debug, U: Copy>(i: T, u: U) {
|
||||
bar(u);
|
||||
fun_name(i);
|
||||
}
|
||||
|
||||
fn $0fun_name<T: Debug>(i: T) {
|
||||
foo(i);
|
||||
}
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_generic_param_list() {
|
||||
check_assist(
|
||||
extract_function,
|
||||
r#"
|
||||
fn func<T: Debug>(t: T, i: u32) {
|
||||
bar(t);
|
||||
$0foo(i);$0
|
||||
}
|
||||
"#,
|
||||
r#"
|
||||
fn func<T: Debug>(t: T, i: u32) {
|
||||
bar(t);
|
||||
fun_name(i);
|
||||
}
|
||||
|
||||
fn $0fun_name(i: u32) {
|
||||
foo(i);
|
||||
}
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preserve_where_clause() {
|
||||
check_assist(
|
||||
extract_function,
|
||||
r#"
|
||||
fn func<T>(i: T) where T: Debug {
|
||||
$0foo(i);$0
|
||||
}
|
||||
"#,
|
||||
r#"
|
||||
fn func<T>(i: T) where T: Debug {
|
||||
fun_name(i);
|
||||
}
|
||||
|
||||
fn $0fun_name<T>(i: T) where T: Debug {
|
||||
foo(i);
|
||||
}
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filter_unused_where_clause() {
|
||||
check_assist(
|
||||
extract_function,
|
||||
r#"
|
||||
fn func<T, U>(i: T, u: U) where T: Debug, U: Copy {
|
||||
bar(u);
|
||||
$0foo(i);$0
|
||||
}
|
||||
"#,
|
||||
r#"
|
||||
fn func<T, U>(i: T, u: U) where T: Debug, U: Copy {
|
||||
bar(u);
|
||||
fun_name(i);
|
||||
}
|
||||
|
||||
fn $0fun_name<T>(i: T) where T: Debug {
|
||||
foo(i);
|
||||
}
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nested_generics() {
|
||||
check_assist(
|
||||
extract_function,
|
||||
r#"
|
||||
struct Struct<T: Into<i32>>(T);
|
||||
impl <T: Into<i32> + Copy> Struct<T> {
|
||||
fn func<V: Into<i32>>(&self, v: V) -> i32 {
|
||||
let t = self.0;
|
||||
$0t.into() + v.into()$0
|
||||
}
|
||||
}
|
||||
"#,
|
||||
r#"
|
||||
struct Struct<T: Into<i32>>(T);
|
||||
impl <T: Into<i32> + Copy> Struct<T> {
|
||||
fn func<V: Into<i32>>(&self, v: V) -> i32 {
|
||||
let t = self.0;
|
||||
fun_name(t, v)
|
||||
}
|
||||
}
|
||||
|
||||
fn $0fun_name<T: Into<i32> + Copy, V: Into<i32>>(t: T, v: V) -> i32 {
|
||||
t.into() + v.into()
|
||||
}
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filters_unused_nested_generics() {
|
||||
check_assist(
|
||||
extract_function,
|
||||
r#"
|
||||
struct Struct<T: Into<i32>, U: Debug>(T, U);
|
||||
impl <T: Into<i32> + Copy, U: Debug> Struct<T, U> {
|
||||
fn func<V: Into<i32>>(&self, v: V) -> i32 {
|
||||
let t = self.0;
|
||||
$0t.into() + v.into()$0
|
||||
}
|
||||
}
|
||||
"#,
|
||||
r#"
|
||||
struct Struct<T: Into<i32>, U: Debug>(T, U);
|
||||
impl <T: Into<i32> + Copy, U: Debug> Struct<T, U> {
|
||||
fn func<V: Into<i32>>(&self, v: V) -> i32 {
|
||||
let t = self.0;
|
||||
fun_name(t, v)
|
||||
}
|
||||
}
|
||||
|
||||
fn $0fun_name<T: Into<i32> + Copy, V: Into<i32>>(t: T, v: V) -> i32 {
|
||||
t.into() + v.into()
|
||||
}
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nested_where_clauses() {
|
||||
check_assist(
|
||||
extract_function,
|
||||
r#"
|
||||
struct Struct<T>(T) where T: Into<i32>;
|
||||
impl <T> Struct<T> where T: Into<i32> + Copy {
|
||||
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
|
||||
let t = self.0;
|
||||
$0t.into() + v.into()$0
|
||||
}
|
||||
}
|
||||
"#,
|
||||
r#"
|
||||
struct Struct<T>(T) where T: Into<i32>;
|
||||
impl <T> Struct<T> where T: Into<i32> + Copy {
|
||||
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
|
||||
let t = self.0;
|
||||
fun_name(t, v)
|
||||
}
|
||||
}
|
||||
|
||||
fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
|
||||
t.into() + v.into()
|
||||
}
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filters_unused_nested_where_clauses() {
|
||||
check_assist(
|
||||
extract_function,
|
||||
r#"
|
||||
struct Struct<T, U>(T, U) where T: Into<i32>, U: Debug;
|
||||
impl <T, U> Struct<T, U> where T: Into<i32> + Copy, U: Debug {
|
||||
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
|
||||
let t = self.0;
|
||||
$0t.into() + v.into()$0
|
||||
}
|
||||
}
|
||||
"#,
|
||||
r#"
|
||||
struct Struct<T, U>(T, U) where T: Into<i32>, U: Debug;
|
||||
impl <T, U> Struct<T, U> where T: Into<i32> + Copy, U: Debug {
|
||||
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
|
||||
let t = self.0;
|
||||
fun_name(t, v)
|
||||
}
|
||||
}
|
||||
|
||||
fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
|
||||
t.into() + v.into()
|
||||
}
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue