diff --git a/crates/hir/src/semantics.rs b/crates/hir/src/semantics.rs index c41654fc2e..84563fd892 100644 --- a/crates/hir/src/semantics.rs +++ b/crates/hir/src/semantics.rs @@ -884,6 +884,7 @@ to_def_impls![ (crate::Local, ast::IdentPat, bind_pat_to_def), (crate::Local, ast::SelfParam, self_param_to_def), (crate::Label, ast::Label, label_to_def), + (crate::Adt, ast::Adt, adt_to_def), ]; fn find_root(node: &SyntaxNode) -> SyntaxNode { diff --git a/crates/hir/src/semantics/source_to_def.rs b/crates/hir/src/semantics/source_to_def.rs index 6beb664ffb..93b78a1a16 100644 --- a/crates/hir/src/semantics/source_to_def.rs +++ b/crates/hir/src/semantics/source_to_def.rs @@ -91,9 +91,9 @@ use hir_def::{ dyn_map::DynMap, expr::{LabelId, PatId}, keys::{self, Key}, - ConstId, ConstParamId, DefWithBodyId, EnumId, EnumVariantId, FieldId, FunctionId, GenericDefId, - ImplId, LifetimeParamId, ModuleId, StaticId, StructId, TraitId, TypeAliasId, TypeParamId, - UnionId, VariantId, + AdtId, ConstId, ConstParamId, DefWithBodyId, EnumId, EnumVariantId, FieldId, FunctionId, + GenericDefId, ImplId, LifetimeParamId, ModuleId, StaticId, StructId, TraitId, TypeAliasId, + TypeParamId, UnionId, VariantId, }; use hir_expand::{name::AsName, AstId, MacroCallId, MacroDefId, MacroDefKind}; use rustc_hash::FxHashMap; @@ -201,6 +201,18 @@ impl SourceToDefCtx<'_, '_> { ) -> Option { self.to_def(src, keys::VARIANT) } + pub(super) fn adt_to_def( + &mut self, + InFile { file_id, value }: InFile, + ) -> Option { + match value { + ast::Adt::Enum(it) => self.enum_to_def(InFile::new(file_id, it)).map(AdtId::EnumId), + ast::Adt::Struct(it) => { + self.struct_to_def(InFile::new(file_id, it)).map(AdtId::StructId) + } + ast::Adt::Union(it) => self.union_to_def(InFile::new(file_id, it)).map(AdtId::UnionId), + } + } pub(super) fn bind_pat_to_def( &mut self, src: InFile, diff --git a/crates/ide_assists/src/handlers/generate_function.rs b/crates/ide_assists/src/handlers/generate_function.rs index 754e995e57..a0e682c4ce 100644 --- a/crates/ide_assists/src/handlers/generate_function.rs +++ b/crates/ide_assists/src/handlers/generate_function.rs @@ -1,4 +1,4 @@ -use hir::{HasSource, HirDisplay, InFile, Module, TypeInfo}; +use hir::{HasSource, HirDisplay, Module, TypeInfo}; use ide_db::{base_db::FileId, helpers::SnippetCap}; use rustc_hash::{FxHashMap, FxHashSet}; use stdx::to_lower_snake_case; @@ -106,31 +106,28 @@ fn gen_fn(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { fn gen_method(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { let call: ast::MethodCallExpr = ctx.find_node_at_offset()?; - let fn_name: ast::NameRef = ast::NameRef::cast( - call.syntax().children().find(|child| child.kind() == SyntaxKind::NAME_REF)?, - )?; - let ty = ctx.sema.type_of_expr(&call.receiver()?)?.original().strip_references().as_adt()?; + let fn_name = call.name_ref()?; + let adt = ctx.sema.type_of_expr(&call.receiver()?)?.original().strip_references().as_adt()?; - let current_module = - ctx.sema.scope(ctx.find_node_at_offset::()?.syntax()).module()?; - let target_module = ty.module(ctx.sema.db); + let current_module = ctx.sema.scope(call.syntax()).module()?; + let target_module = adt.module(ctx.sema.db); if current_module.krate() != target_module.krate() { return None; } - let (impl_, file) = match ty { - hir::Adt::Struct(strukt) => get_impl(strukt.source(ctx.sema.db)?.syntax(), &fn_name, ctx), - hir::Adt::Enum(en) => get_impl(en.source(ctx.sema.db)?.syntax(), &fn_name, ctx), - hir::Adt::Union(union) => get_impl(union.source(ctx.sema.db)?.syntax(), &fn_name, ctx), - }?; + let range = adt.source(ctx.sema.db)?.syntax().original_file_range(ctx.sema.db); + let file = ctx.sema.parse(range.file_id); + let adt_source = + ctx.sema.find_node_at_offset_with_macros(file.syntax(), range.range.start())?; + let impl_ = find_struct_impl(ctx, &adt_source, fn_name.text().as_str())?; let function_builder = FunctionBuilder::from_method_call( ctx, &call, &fn_name, &impl_, - file, + range.file_id, target_module, current_module, )?; @@ -145,7 +142,7 @@ fn gen_method(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { builder.edit_file(function_template.file); let mut new_fn = function_template.to_string(ctx.config.snippet_cap); if impl_.is_none() { - new_fn = format!("\nimpl {} {{\n{}\n}}", ty.name(ctx.sema.db), new_fn,); + new_fn = format!("\nimpl {} {{\n{}\n}}", adt.name(ctx.sema.db), new_fn,); } match ctx.config.snippet_cap { Some(cap) => builder.insert_snippet(cap, function_template.insert_offset, new_fn), @@ -155,18 +152,6 @@ fn gen_method(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { ) } -fn get_impl( - adt: InFile<&SyntaxNode>, - fn_name: &ast::NameRef, - ctx: &AssistContext, -) -> Option<(Option, FileId)> { - let file = adt.file_id.original_file(ctx.sema.db); - let adt = adt.value; - let adt = ast::Adt::cast(adt.clone())?; - let r = find_struct_impl(ctx, &adt, fn_name.text().as_str())?; - Some((r, file)) -} - struct FunctionTemplate { insert_offset: TextSize, leading_ws: String, diff --git a/crates/ide_assists/src/utils.rs b/crates/ide_assists/src/utils.rs index e59fc6ba4f..edbc6dd609 100644 --- a/crates/ide_assists/src/utils.rs +++ b/crates/ide_assists/src/utils.rs @@ -5,7 +5,7 @@ mod gen_trait_fn_body; use std::ops; -use hir::{Adt, HasSource}; +use hir::HasSource; use ide_db::{helpers::SnippetCap, path_transform::PathTransform, RootDatabase}; use itertools::Itertools; use stdx::format_to; @@ -290,19 +290,13 @@ pub(crate) fn does_pat_match_variant(pat: &ast::Pat, var: &ast::Pat) -> bool { // FIXME: this partially overlaps with `find_impl_block_*` pub(crate) fn find_struct_impl( ctx: &AssistContext, - strukt: &ast::Adt, + adt: &ast::Adt, name: &str, ) -> Option> { let db = ctx.db(); - let module = strukt.syntax().ancestors().find(|node| { - ast::Module::can_cast(node.kind()) || ast::SourceFile::can_cast(node.kind()) - })?; + let module = adt.syntax().parent()?; - let struct_def = match strukt { - ast::Adt::Enum(e) => Adt::Enum(ctx.sema.to_def(e)?), - ast::Adt::Struct(s) => Adt::Struct(ctx.sema.to_def(s)?), - ast::Adt::Union(u) => Adt::Union(ctx.sema.to_def(u)?), - }; + let struct_def = ctx.sema.to_def(adt)?; let block = module.descendants().filter_map(ast::Impl::cast).find_map(|impl_blk| { let blk = ctx.sema.to_def(&impl_blk)?;