diff --git a/crates/ra_assists/src/handlers/extract_struct_from_enum_variant.rs b/crates/ra_assists/src/handlers/extract_struct_from_enum_variant.rs index 3592838029..d5397bf213 100644 --- a/crates/ra_assists/src/handlers/extract_struct_from_enum_variant.rs +++ b/crates/ra_assists/src/handlers/extract_struct_from_enum_variant.rs @@ -39,15 +39,16 @@ pub(crate) fn extract_struct_from_enum(acc: &mut Assists, ctx: &AssistContext) - _ => return None, }; let variant_name = variant.name()?.to_string(); - let enum_ast = variant.parent_enum(); - let enum_name = enum_ast.name()?.to_string(); - let visibility = enum_ast.visibility(); let variant_hir = ctx.sema.to_def(&variant)?; - if existing_struct_def(ctx.db, &variant_name, &variant_hir) { return None; } - + let enum_ast = variant.parent_enum(); + let enum_name = enum_ast.name()?.to_string(); + let visibility = enum_ast.visibility(); + let current_module_def = + ImportsLocator::new(ctx.db).find_imports(&enum_name).first()?.left()?; + let current_module = current_module_def.module(ctx.db)?; let target = variant.syntax().text_range(); return acc.add_in_multiple_files( AssistId("extract_struct_from_enum_variant"), @@ -56,10 +57,9 @@ pub(crate) fn extract_struct_from_enum(acc: &mut Assists, ctx: &AssistContext) - |edit| { let definition = Definition::ModuleDef(ModuleDef::EnumVariant(variant_hir)); let res = definition.find_usages(&ctx.db, None); - let module_def = mod_def_for_target_module(ctx, &enum_name); let start_offset = variant.parent_enum().syntax().text_range().start(); let mut visited_modules_set: FxHashSet = FxHashSet::default(); - visited_modules_set.insert(module_def.module(ctx.db).unwrap()); + visited_modules_set.insert(current_module); for reference in res { let source_file = ctx.sema.parse(reference.file_range.file_id); update_reference( @@ -67,7 +67,7 @@ pub(crate) fn extract_struct_from_enum(acc: &mut Assists, ctx: &AssistContext) - edit, reference, &source_file, - &module_def, + ¤t_module_def, &mut visited_modules_set, ); } @@ -95,10 +95,6 @@ fn existing_struct_def(db: &RootDatabase, variant_name: &str, variant: &EnumVari .any(|(name, _)| name.to_string() == variant_name.to_string()) } -fn mod_def_for_target_module(ctx: &AssistContext, enum_name: &str) -> ModuleDef { - ImportsLocator::new(ctx.db).find_imports(enum_name).first().unwrap().left().unwrap() -} - fn insert_import( ctx: &AssistContext, builder: &mut AssistBuilder, @@ -186,23 +182,16 @@ fn update_reference( let call = path_expr.syntax().parent().and_then(ast::CallExpr::cast)?; let list = call.arg_list()?; let segment = path_expr.path()?.segment()?; + let segment_name = segment.name_ref()?; + let module = ctx.sema.scope(&path_expr.syntax()).module()?; let list_range = list.syntax().text_range(); let inside_list_range = TextRange::new( list_range.start().checked_add(TextSize::from(1))?, list_range.end().checked_sub(TextSize::from(1))?, ); edit.perform(reference.file_range.file_id, |builder| { - let module = ctx.sema.scope(&path_expr.syntax()).module().unwrap(); if !visited_modules_set.contains(&module) { - if insert_import( - ctx, - builder, - &path_expr, - &module, - module_def, - segment.name_ref().unwrap(), - ) - .is_some() + if insert_import(ctx, builder, &path_expr, &module, module_def, segment_name).is_some() { visited_modules_set.insert(module); }