diff --git a/crates/ide_assists/src/handlers/merge_match_arms.rs b/crates/ide_assists/src/handlers/merge_match_arms.rs index 49543861c2..7f9b5f4597 100644 --- a/crates/ide_assists/src/handlers/merge_match_arms.rs +++ b/crates/ide_assists/src/handlers/merge_match_arms.rs @@ -1,8 +1,8 @@ -use itertools::Itertools; -use std::iter::successors; +use hir::TypeInfo; +use std::{iter::successors, collections::HashMap}; use syntax::{ algo::neighbor, - ast::{self, AstNode}, + ast::{self, AstNode, Pat, MatchArm, HasName}, Direction, }; @@ -52,6 +52,7 @@ pub(crate) fn merge_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option return false; } + println!("Checking types"); return are_same_types(¤t_arm_types, arm, ctx); } _ => false, @@ -90,34 +91,69 @@ pub(crate) fn merge_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option ) } -fn contains_placeholder(a: &ast::MatchArm) -> bool { +fn contains_placeholder(a: &MatchArm) -> bool { matches!(a.pat(), Some(ast::Pat::WildcardPat(..))) } fn are_same_types( - current_arm_types: &Vec>, + current_arm_types: &HashMap>, arm: &ast::MatchArm, ctx: &AssistContext, ) -> bool { let arm_types = get_arm_types(&ctx, &arm); - for i in 0..arm_types.len() { - let other_arm_type = &arm_types[i]; - let current_arm_type = ¤t_arm_types[i]; - if let (Some(other_arm_type), Some(current_arm_type)) = (other_arm_type, current_arm_type) { - return &other_arm_type.original == ¤t_arm_type.original; + for other_arm_type_entry in arm_types { + let current_arm_type = current_arm_types.get_key_value(&other_arm_type_entry.0); + if current_arm_type.is_none() { + println!("No corresponding type found for {:?}", {other_arm_type_entry}); + return false; + } + + let unwrapped_current_arm_type = current_arm_type.unwrap().1; + + if let (Some(other_arm_type), Some(current_arm_type)) = (other_arm_type_entry.1, unwrapped_current_arm_type) { + if other_arm_type.original != current_arm_type.original { + println!("Type {:?} is different from {:?}", &other_arm_type.original, ¤t_arm_type.original); + return false; + } } } return true; } -fn get_arm_types(ctx: &AssistContext, arm: &ast::MatchArm) -> Vec> { - match arm.pat() { - Some(ast::Pat::TupleStructPat(tp)) => { - tp.fields().into_iter().map(|field| ctx.sema.type_of_pat(&field)).collect_vec() +fn get_arm_types(context: &AssistContext, arm: &MatchArm) -> HashMap> { + let mut mapping: HashMap> = HashMap::new(); + + fn recurse(pat: &Option, map: &mut HashMap>, ctx: &AssistContext) { + if let Some(local_pat) = pat { + println!("{:?}", pat); + match pat { + Some(ast::Pat::TupleStructPat(tuple)) => { + for field in tuple.fields() { + recurse(&Some(field), map, ctx); + } + }, + Some(ast::Pat::RecordPat(record)) => { + if let Some(field_list) = record.record_pat_field_list() { + for field in field_list.fields() { + recurse(&field.pat(), map, ctx); + } + } + }, + Some(ast::Pat::IdentPat(ident_pat)) => { + if let Some(name) = ident_pat.name() { + println!("Found name: {:?}", name.text().to_string()); + let pat_type = ctx.sema.type_of_pat(local_pat); + map.insert(name.text().to_string(), pat_type); + } + }, + _ => (), + } } - _ => Vec::new(), } + + recurse(&arm.pat(), &mut mapping, &context); + return mapping; } #[cfg(test)] @@ -430,21 +466,25 @@ fn func() { check_assist( merge_match_arms, r#" -let x = 'c'; +fn func() { + let x = 'c'; match x { 'a'..='j' => $0"", 'c'..='z' => "", _ => "other", }; +} "#, r#" -let x = 'c'; +fn func() { + let x = 'c'; match x { 'a'..='j' | 'c'..='z' => "", _ => "other", }; +} "#, ); } @@ -675,5 +715,3 @@ fn main(msg: Message) { ) } } - -