fix: refine fallback case in generated PartialEq impl

This commit is contained in:
rami3l 2022-12-07 10:25:17 +08:00
parent fed74c8b71
commit 57fb18e3bd
2 changed files with 43 additions and 3 deletions

View file

@ -906,6 +906,33 @@ impl PartialEq for Foo {
) )
} }
#[test]
fn add_custom_impl_partial_eq_single_variant_tuple_enum() {
check_assist(
replace_derive_with_manual_impl,
r#"
//- minicore: eq, derive
#[derive(Partial$0Eq)]
enum Foo {
Bar(String),
}
"#,
r#"
enum Foo {
Bar(String),
}
impl PartialEq for Foo {
$0fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Bar(l0), Self::Bar(r0)) => l0 == r0,
}
}
}
"#,
)
}
#[test] #[test]
fn add_custom_impl_partial_eq_partial_tuple_enum() { fn add_custom_impl_partial_eq_partial_tuple_enum() {
check_assist( check_assist(
@ -959,7 +986,7 @@ impl PartialEq for Foo {
match (self, other) { match (self, other) {
(Self::Bar(l0), Self::Bar(r0)) => l0 == r0, (Self::Bar(l0), Self::Bar(r0)) => l0 == r0,
(Self::Baz(l0), Self::Baz(r0)) => l0 == r0, (Self::Baz(l0), Self::Baz(r0)) => l0 == r0,
_ => core::mem::discriminant(self) == core::mem::discriminant(other), _ => false,
} }
} }
} }

View file

@ -439,8 +439,10 @@ fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
let eq_check = let eq_check =
make::expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs); make::expr_bin_op(lhs, BinaryOp::CmpOp(CmpOp::Eq { negated: false }), rhs);
let mut n_cases = 0;
let mut arms = vec![]; let mut arms = vec![];
for variant in enum_.variant_list()?.variants() { for variant in enum_.variant_list()?.variants() {
n_cases += 1;
match variant.field_list() { match variant.field_list() {
// => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin, // => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
Some(ast::FieldList::RecordFieldList(list)) => { Some(ast::FieldList::RecordFieldList(list)) => {
@ -514,8 +516,19 @@ fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
let expr = match arms.len() { let expr = match arms.len() {
0 => eq_check, 0 => eq_check,
_ => { arms_len => {
arms.push(make::match_arm(Some(make::wildcard_pat().into()), None, eq_check)); // Generate the fallback arm when this enum has >1 variants.
// The fallback arm will be `_ => false,` if we've already gone through every case where the variants of self and other match,
// and `_ => std::mem::discriminant(self) == std::mem::discriminant(other),` otherwise.
if n_cases > 1 {
let lhs = make::wildcard_pat().into();
let rhs = if arms_len == n_cases {
make::expr_literal("false").into()
} else {
eq_check
};
arms.push(make::match_arm(Some(lhs), None, rhs));
}
let match_target = make::expr_tuple(vec![lhs_name, rhs_name]); let match_target = make::expr_tuple(vec![lhs_name, rhs_name]);
let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1)); let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));