From ff2629d65147919a16c3446293943d3efa1d7894 Mon Sep 17 00:00:00 2001 From: morine0122 Date: Wed, 17 Apr 2024 20:39:41 +0900 Subject: [PATCH] Make generate function assist generate a function as a constructor if the name of function is new --- crates/hir/src/lib.rs | 8 + .../src/handlers/generate_function.rs | 288 ++++++++++++++++-- 2 files changed, 270 insertions(+), 26 deletions(-) diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index 8556d35a43..80afab95c2 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -1489,6 +1489,14 @@ impl Adt { .map(|arena| arena.1.clone()) } + pub fn as_struct(&self) -> Option { + if let Self::Struct(v) = self { + Some(*v) + } else { + None + } + } + pub fn as_enum(&self) -> Option { if let Self::Enum(v) = self { Some(*v) diff --git a/crates/ide-assists/src/handlers/generate_function.rs b/crates/ide-assists/src/handlers/generate_function.rs index db94a21a6d..0fc122d623 100644 --- a/crates/ide-assists/src/handlers/generate_function.rs +++ b/crates/ide-assists/src/handlers/generate_function.rs @@ -1,6 +1,6 @@ use hir::{ - Adt, AsAssocItem, HasSource, HirDisplay, HirFileIdExt, Module, PathResolution, Semantics, Type, - TypeInfo, + Adt, AsAssocItem, HasSource, HirDisplay, HirFileIdExt, Module, PathResolution, Semantics, + StructKind, Type, TypeInfo, }; use ide_db::{ base_db::FileId, @@ -15,8 +15,8 @@ use itertools::Itertools; use stdx::to_lower_snake_case; use syntax::{ ast::{ - self, edit::IndentLevel, edit_in_place::Indent, make, AstNode, CallExpr, HasArgList, - HasGenericParams, HasModuleItem, HasTypeBounds, + self, edit::IndentLevel, edit_in_place::Indent, make, AstNode, BlockExpr, CallExpr, + HasArgList, HasGenericParams, HasModuleItem, HasTypeBounds, }, ted, SyntaxKind, SyntaxNode, TextRange, T, }; @@ -66,7 +66,7 @@ fn gen_fn(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { } let fn_name = &*name_ref.text(); - let TargetInfo { target_module, adt_name, target, file } = + let TargetInfo { target_module, adt_info, target, file } = fn_target_info(ctx, path, &call, fn_name)?; if let Some(m) = target_module { @@ -75,15 +75,16 @@ fn gen_fn(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { } } - let function_builder = FunctionBuilder::from_call(ctx, &call, fn_name, target_module, target)?; + let function_builder = + FunctionBuilder::from_call(ctx, &call, fn_name, target_module, target, &adt_info)?; let text_range = call.syntax().text_range(); let label = format!("Generate {} function", function_builder.fn_name); - add_func_to_accumulator(acc, ctx, text_range, function_builder, file, adt_name, label) + add_func_to_accumulator(acc, ctx, text_range, function_builder, file, adt_info, label) } struct TargetInfo { target_module: Option, - adt_name: Option, + adt_info: Option, target: GeneratedFunctionTarget, file: FileId, } @@ -91,11 +92,11 @@ struct TargetInfo { impl TargetInfo { fn new( target_module: Option, - adt_name: Option, + adt_info: Option, target: GeneratedFunctionTarget, file: FileId, ) -> Self { - Self { target_module, adt_name, target, file } + Self { target_module, adt_info, target, file } } } @@ -157,9 +158,9 @@ fn gen_method(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { target, )?; let text_range = call.syntax().text_range(); - let adt_name = if impl_.is_none() { Some(adt.name(ctx.sema.db)) } else { None }; + let adt_info = AdtInfo::new(adt, impl_.is_some()); let label = format!("Generate {} method", function_builder.fn_name); - add_func_to_accumulator(acc, ctx, text_range, function_builder, file, adt_name, label) + add_func_to_accumulator(acc, ctx, text_range, function_builder, file, Some(adt_info), label) } fn add_func_to_accumulator( @@ -168,7 +169,7 @@ fn add_func_to_accumulator( text_range: TextRange, function_builder: FunctionBuilder, file: FileId, - adt_name: Option, + adt_info: Option, label: String, ) -> Option<()> { acc.add(AssistId("generate_function", AssistKind::Generate), label, text_range, |edit| { @@ -177,8 +178,14 @@ fn add_func_to_accumulator( let target = function_builder.target.clone(); let func = function_builder.render(ctx.config.snippet_cap, edit); - if let Some(name) = adt_name { - let name = make::ty_path(make::ext::ident_path(&format!("{}", name.display(ctx.db())))); + if let Some(adt) = + adt_info + .and_then(|adt_info| if adt_info.impl_exists { None } else { Some(adt_info.adt) }) + { + let name = make::ty_path(make::ext::ident_path(&format!( + "{}", + adt.name(ctx.db()).display(ctx.db()) + ))); // FIXME: adt may have generic params. let impl_ = make::impl_(None, None, name, None, None).clone_for_update(); @@ -210,6 +217,7 @@ struct FunctionBuilder { generic_param_list: Option, where_clause: Option, params: ast::ParamList, + fn_body: BlockExpr, ret_type: Option, should_focus_return_type: bool, visibility: Visibility, @@ -225,6 +233,7 @@ impl FunctionBuilder { fn_name: &str, target_module: Option, target: GeneratedFunctionTarget, + adt_info: &Option, ) -> Option { let target_module = target_module.or_else(|| ctx.sema.scope(target.syntax()).map(|it| it.module()))?; @@ -243,9 +252,27 @@ impl FunctionBuilder { let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast); let is_async = await_expr.is_some(); - let expr_for_ret_ty = await_expr.map_or_else(|| call.clone().into(), |it| it.into()); - let (ret_type, should_focus_return_type) = - make_return_type(ctx, &expr_for_ret_ty, target_module, &mut necessary_generic_params); + let ret_type; + let should_focus_return_type; + let fn_body; + + // If generated function has the name "new" and is an associated function, we generate fn body + // as a constructor and assume a "Self" return type. + if let Some(body) = make_fn_body_as_new_function(ctx, &fn_name.text(), adt_info) { + ret_type = Some(make::ret_type(make::ty_path(make::ext::ident_path("Self")))); + should_focus_return_type = false; + fn_body = body; + } else { + let expr_for_ret_ty = await_expr.map_or_else(|| call.clone().into(), |it| it.into()); + (ret_type, should_focus_return_type) = make_return_type( + ctx, + &expr_for_ret_ty, + target_module, + &mut necessary_generic_params, + ); + let placeholder_expr = make::ext::expr_todo(); + fn_body = make::block_expr(vec![], Some(placeholder_expr)); + }; let (generic_param_list, where_clause) = fn_generic_params(ctx, necessary_generic_params, &target)?; @@ -256,6 +283,7 @@ impl FunctionBuilder { generic_param_list, where_clause, params, + fn_body, ret_type, should_focus_return_type, visibility, @@ -294,12 +322,16 @@ impl FunctionBuilder { let (generic_param_list, where_clause) = fn_generic_params(ctx, necessary_generic_params, &target)?; + let placeholder_expr = make::ext::expr_todo(); + let fn_body = make::block_expr(vec![], Some(placeholder_expr)); + Some(Self { target, fn_name, generic_param_list, where_clause, params, + fn_body, ret_type, should_focus_return_type, visibility, @@ -308,8 +340,6 @@ impl FunctionBuilder { } fn render(self, cap: Option, edit: &mut SourceChangeBuilder) -> ast::Fn { - let placeholder_expr = make::ext::expr_todo(); - let fn_body = make::block_expr(vec![], Some(placeholder_expr)); let visibility = match self.visibility { Visibility::None => None, Visibility::Crate => Some(make::visibility_pub_crate()), @@ -321,7 +351,7 @@ impl FunctionBuilder { self.generic_param_list, self.where_clause, self.params, - fn_body, + self.fn_body, self.ret_type, self.is_async, false, // FIXME : const and unsafe are not handled yet. @@ -391,6 +421,53 @@ fn make_return_type( (ret_type, should_focus_return_type) } +fn make_fn_body_as_new_function( + ctx: &AssistContext<'_>, + fn_name: &str, + adt_info: &Option, +) -> Option { + if fn_name != "new" { + return None; + }; + let adt_info = adt_info.as_ref()?; + + let path_self = make::ext::ident_path("Self"); + let placeholder_expr = make::ext::expr_todo(); + let tail_expr = if let Some(strukt) = adt_info.adt.as_struct() { + match strukt.kind(ctx.db()) { + StructKind::Record => { + let fields = strukt + .fields(ctx.db()) + .iter() + .map(|field| { + make::record_expr_field( + make::name_ref(&format!("{}", field.name(ctx.db()).display(ctx.db()))), + Some(placeholder_expr.clone()), + ) + }) + .collect::>(); + + make::record_expr(path_self, make::record_expr_field_list(fields)).into() + } + StructKind::Tuple => { + let args = strukt + .fields(ctx.db()) + .iter() + .map(|_| placeholder_expr.clone()) + .collect::>(); + + make::expr_call(make::expr_path(path_self), make::arg_list(args)) + } + StructKind::Unit => make::expr_path(path_self), + } + } else { + placeholder_expr + }; + + let fn_body = make::block_expr(vec![], Some(tail_expr)); + Some(fn_body) +} + fn get_fn_target_info( ctx: &AssistContext<'_>, target_module: Option, @@ -443,8 +520,8 @@ fn assoc_fn_target_info( } let (impl_, file) = get_adt_source(ctx, &adt, fn_name)?; let target = get_method_target(ctx, &impl_, &adt)?; - let adt_name = if impl_.is_none() { Some(adt.name(ctx.sema.db)) } else { None }; - Some(TargetInfo::new(target_module, adt_name, target, file)) + let adt_info = AdtInfo::new(adt, impl_.is_some()); + Some(TargetInfo::new(target_module, Some(adt_info), target, file)) } #[derive(Clone)] @@ -560,6 +637,17 @@ impl GeneratedFunctionTarget { } } +struct AdtInfo { + adt: hir::Adt, + impl_exists: bool, +} + +impl AdtInfo { + fn new(adt: Adt, impl_exists: bool) -> Self { + Self { adt, impl_exists } + } +} + /// Computes parameter list for the generated function. fn fn_args( ctx: &AssistContext<'_>, @@ -2758,18 +2846,18 @@ fn main() { r" enum Foo {} fn main() { - Foo::new$0(); + Foo::bar$0(); } ", r" enum Foo {} impl Foo { - fn new() ${0:-> _} { + fn bar() ${0:-> _} { todo!() } } fn main() { - Foo::new(); + Foo::bar(); } ", ) @@ -2849,4 +2937,152 @@ fn main() { ", ); } + + #[test] + fn new_function_assume_self_type() { + check_assist( + generate_function, + r" +pub struct Foo { + field_1: usize, + field_2: String, +} + +fn main() { + let foo = Foo::new$0(); +} + ", + r" +pub struct Foo { + field_1: usize, + field_2: String, +} +impl Foo { + fn new() -> Self { + ${0:Self { field_1: todo!(), field_2: todo!() }} + } +} + +fn main() { + let foo = Foo::new(); +} + ", + ) + } + + #[test] + fn new_function_assume_self_type_for_tuple_struct() { + check_assist( + generate_function, + r" +pub struct Foo (usize, String); + +fn main() { + let foo = Foo::new$0(); +} + ", + r" +pub struct Foo (usize, String); +impl Foo { + fn new() -> Self { + ${0:Self(todo!(), todo!())} + } +} + +fn main() { + let foo = Foo::new(); +} + ", + ) + } + + #[test] + fn new_function_assume_self_type_for_unit_struct() { + check_assist( + generate_function, + r" +pub struct Foo; + +fn main() { + let foo = Foo::new$0(); +} + ", + r" +pub struct Foo; +impl Foo { + fn new() -> Self { + ${0:Self} + } +} + +fn main() { + let foo = Foo::new(); +} + ", + ) + } + + #[test] + fn new_function_assume_self_type_for_enum() { + check_assist( + generate_function, + r" +pub enum Foo {} + +fn main() { + let foo = Foo::new$0(); +} + ", + r" +pub enum Foo {} +impl Foo { + fn new() -> Self { + ${0:todo!()} + } +} + +fn main() { + let foo = Foo::new(); +} + ", + ) + } + + #[test] + fn new_function_assume_self_type_with_args() { + check_assist( + generate_function, + r#" +pub struct Foo { + field_1: usize, + field_2: String, +} + +struct Baz; +fn baz() -> Baz { Baz } + +fn main() { + let foo = Foo::new$0(baz(), baz(), "foo", "bar"); +} + "#, + r#" +pub struct Foo { + field_1: usize, + field_2: String, +} +impl Foo { + fn new(baz_1: Baz, baz_2: Baz, arg_1: &str, arg_2: &str) -> Self { + ${0:Self { field_1: todo!(), field_2: todo!() }} + } +} + +struct Baz; +fn baz() -> Baz { Baz } + +fn main() { + let foo = Foo::new(baz(), baz(), "foo", "bar"); +} + "#, + ) + } }