Auto merge of #14291 - HKalbasi:master, r=HKalbasi

fix multiple definition binding in match to let-else

fix #14290
This commit is contained in:
bors 2023-03-09 11:52:22 +00:00
commit 8e404f4928

View file

@ -1,6 +1,6 @@
use ide_db::defs::{Definition, NameRefClass}; use ide_db::defs::{Definition, NameRefClass};
use syntax::{ use syntax::{
ast::{self, HasName}, ast::{self, HasName, Name},
ted, AstNode, SyntaxNode, ted, AstNode, SyntaxNode,
}; };
@ -48,7 +48,7 @@ pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'
other => format!("{{ {other} }}"), other => format!("{{ {other} }}"),
}; };
let extracting_arm_pat = extracting_arm.pat()?; let extracting_arm_pat = extracting_arm.pat()?;
let extracted_variable = find_extracted_variable(ctx, &extracting_arm)?; let extracted_variable_positions = find_extracted_variable(ctx, &extracting_arm)?;
acc.add( acc.add(
AssistId("convert_match_to_let_else", AssistKind::RefactorRewrite), AssistId("convert_match_to_let_else", AssistKind::RefactorRewrite),
@ -56,7 +56,7 @@ pub(crate) fn convert_match_to_let_else(acc: &mut Assists, ctx: &AssistContext<'
let_stmt.syntax().text_range(), let_stmt.syntax().text_range(),
|builder| { |builder| {
let extracting_arm_pat = let extracting_arm_pat =
rename_variable(&extracting_arm_pat, extracted_variable, binding); rename_variable(&extracting_arm_pat, &extracted_variable_positions, binding);
builder.replace( builder.replace(
let_stmt.syntax().text_range(), let_stmt.syntax().text_range(),
format!("let {extracting_arm_pat} = {initializer_expr} else {diverging_arm_expr};"), format!("let {extracting_arm_pat} = {initializer_expr} else {diverging_arm_expr};"),
@ -95,14 +95,15 @@ fn find_arms(
} }
// Given an extracting arm, find the extracted variable. // Given an extracting arm, find the extracted variable.
fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<ast::Name> { fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Option<Vec<Name>> {
match arm.expr()? { match arm.expr()? {
ast::Expr::PathExpr(path) => { ast::Expr::PathExpr(path) => {
let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?; let name_ref = path.syntax().descendants().find_map(ast::NameRef::cast)?;
match NameRefClass::classify(&ctx.sema, &name_ref)? { match NameRefClass::classify(&ctx.sema, &name_ref)? {
NameRefClass::Definition(Definition::Local(local)) => { NameRefClass::Definition(Definition::Local(local)) => {
let source = local.primary_source(ctx.db()).into_ident_pat()?; let source =
Some(source.name()?) local.sources(ctx.db()).into_iter().map(|x| x.into_ident_pat()?.name());
source.collect()
} }
_ => None, _ => None,
} }
@ -115,27 +116,34 @@ fn find_extracted_variable(ctx: &AssistContext<'_>, arm: &ast::MatchArm) -> Opti
} }
// Rename `extracted` with `binding` in `pat`. // Rename `extracted` with `binding` in `pat`.
fn rename_variable(pat: &ast::Pat, extracted: ast::Name, binding: ast::Pat) -> SyntaxNode { fn rename_variable(pat: &ast::Pat, extracted: &[Name], binding: ast::Pat) -> SyntaxNode {
let syntax = pat.syntax().clone_for_update(); let syntax = pat.syntax().clone_for_update();
let extracted_syntax = syntax.covering_element(extracted.syntax().text_range()); let extracted = extracted
.iter()
.map(|e| syntax.covering_element(e.syntax().text_range()))
.collect::<Vec<_>>();
for extracted_syntax in extracted {
// If `extracted` variable is a record field, we should rename it to `binding`,
// otherwise we just need to replace `extracted` with `binding`.
// If `extracted` variable is a record field, we should rename it to `binding`, if let Some(record_pat_field) =
// otherwise we just need to replace `extracted` with `binding`. extracted_syntax.ancestors().find_map(ast::RecordPatField::cast)
{
if let Some(record_pat_field) = extracted_syntax.ancestors().find_map(ast::RecordPatField::cast) if let Some(name_ref) = record_pat_field.field_name() {
{ ted::replace(
if let Some(name_ref) = record_pat_field.field_name() { record_pat_field.syntax(),
ted::replace( ast::make::record_pat_field(
record_pat_field.syntax(), ast::make::name_ref(&name_ref.text()),
ast::make::record_pat_field(ast::make::name_ref(&name_ref.text()), binding) binding.clone(),
)
.syntax() .syntax()
.clone_for_update(), .clone_for_update(),
); );
}
} else {
ted::replace(extracted_syntax, binding.clone().syntax().clone_for_update());
} }
} else {
ted::replace(extracted_syntax, binding.syntax().clone_for_update());
} }
syntax syntax
} }
@ -162,6 +170,39 @@ fn foo(opt: Option<()>) {
); );
} }
#[test]
fn or_pattern_multiple_binding() {
check_assist(
convert_match_to_let_else,
r#"
//- minicore: option
enum Foo {
A(u32),
B(u32),
C(String),
}
fn foo(opt: Option<Foo>) -> Result<u32, ()> {
let va$0lue = match opt {
Some(Foo::A(it) | Foo::B(it)) => it,
_ => return Err(()),
};
}
"#,
r#"
enum Foo {
A(u32),
B(u32),
C(String),
}
fn foo(opt: Option<Foo>) -> Result<u32, ()> {
let Some(Foo::A(value) | Foo::B(value)) = opt else { return Err(()) };
}
"#,
);
}
#[test] #[test]
fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() { fn should_not_be_applicable_if_extracting_arm_is_not_an_identity_expr() {
cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2); cov_mark::check_count!(extracting_arm_is_not_an_identity_expr, 2);