From 57fb18e3bdc44e0ec390ec48b501f85bb78218e8 Mon Sep 17 00:00:00 2001 From: rami3l Date: Wed, 7 Dec 2022 10:25:17 +0800 Subject: [PATCH] fix: refine fallback case in generated `PartialEq` impl --- .../replace_derive_with_manual_impl.rs | 29 ++++++++++++++++++- .../src/utils/gen_trait_fn_body.rs | 17 +++++++++-- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs b/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs index a1ca286121..2854701c08 100644 --- a/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs +++ b/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs @@ -906,6 +906,33 @@ impl PartialEq for Foo { ) } + #[test] + fn add_custom_impl_partial_eq_single_variant_tuple_enum() { + check_assist( + replace_derive_with_manual_impl, + r#" +//- minicore: eq, derive +#[derive(Partial$0Eq)] +enum Foo { + Bar(String), +} +"#, + r#" +enum Foo { + Bar(String), +} + +impl PartialEq for Foo { + $0fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::Bar(l0), Self::Bar(r0)) => l0 == r0, + } + } +} +"#, + ) + } + #[test] fn add_custom_impl_partial_eq_partial_tuple_enum() { check_assist( @@ -959,7 +986,7 @@ impl PartialEq for Foo { match (self, other) { (Self::Bar(l0), Self::Bar(r0)) => l0 == r0, (Self::Baz(l0), Self::Baz(r0)) => l0 == r0, - _ => core::mem::discriminant(self) == core::mem::discriminant(other), + _ => false, } } } diff --git a/crates/ide-assists/src/utils/gen_trait_fn_body.rs b/crates/ide-assists/src/utils/gen_trait_fn_body.rs index 287001af84..f32e5ce97d 100644 --- a/crates/ide-assists/src/utils/gen_trait_fn_body.rs +++ b/crates/ide-assists/src/utils/gen_trait_fn_body.rs @@ -439,8 +439,10 @@ fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> { let eq_check = make::expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs); + let mut n_cases = 0; let mut arms = vec![]; for variant in enum_.variant_list()?.variants() { + n_cases += 1; match variant.field_list() { // => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin, Some(ast::FieldList::RecordFieldList(list)) => { @@ -514,8 +516,19 @@ fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> { let expr = match arms.len() { 0 => eq_check, - _ => { - arms.push(make::match_arm(Some(make::wildcard_pat().into()), None, eq_check)); + arms_len => { + // Generate the fallback arm when this enum has >1 variants. + // The fallback arm will be `_ => false,` if we've already gone through every case where the variants of self and other match, + // and `_ => std::mem::discriminant(self) == std::mem::discriminant(other),` otherwise. + if n_cases > 1 { + let lhs = make::wildcard_pat().into(); + let rhs = if arms_len == n_cases { + make::expr_literal("false").into() + } else { + eq_check + }; + arms.push(make::match_arm(Some(lhs), None, rhs)); + } let match_target = make::expr_tuple(vec![lhs_name, rhs_name]); let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));