all tests work

This commit is contained in:
Jeroen Vannevel 2022-01-11 21:39:50 +00:00
parent aaec467cfd
commit 18fb5412b2
No known key found for this signature in database
GPG key ID: 78EF5F52F38C49BD

View file

@ -1,8 +1,8 @@
use itertools::Itertools;
use std::iter::successors;
use hir::TypeInfo;
use std::{iter::successors, collections::HashMap};
use syntax::{
algo::neighbor,
ast::{self, AstNode},
ast::{self, AstNode, Pat, MatchArm, HasName},
Direction,
};
@ -52,6 +52,7 @@ pub(crate) fn merge_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option
return false;
}
println!("Checking types");
return are_same_types(&current_arm_types, arm, ctx);
}
_ => false,
@ -90,35 +91,70 @@ pub(crate) fn merge_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option
)
}
fn contains_placeholder(a: &ast::MatchArm) -> bool {
fn contains_placeholder(a: &MatchArm) -> bool {
matches!(a.pat(), Some(ast::Pat::WildcardPat(..)))
}
fn are_same_types(
current_arm_types: &Vec<Option<hir::TypeInfo>>,
current_arm_types: &HashMap<String, Option<TypeInfo>>,
arm: &ast::MatchArm,
ctx: &AssistContext,
) -> bool {
let arm_types = get_arm_types(&ctx, &arm);
for i in 0..arm_types.len() {
let other_arm_type = &arm_types[i];
let current_arm_type = &current_arm_types[i];
if let (Some(other_arm_type), Some(current_arm_type)) = (other_arm_type, current_arm_type) {
return &other_arm_type.original == &current_arm_type.original;
for other_arm_type_entry in arm_types {
let current_arm_type = current_arm_types.get_key_value(&other_arm_type_entry.0);
if current_arm_type.is_none() {
println!("No corresponding type found for {:?}", {other_arm_type_entry});
return false;
}
let unwrapped_current_arm_type = current_arm_type.unwrap().1;
if let (Some(other_arm_type), Some(current_arm_type)) = (other_arm_type_entry.1, unwrapped_current_arm_type) {
if other_arm_type.original != current_arm_type.original {
println!("Type {:?} is different from {:?}", &other_arm_type.original, &current_arm_type.original);
return false;
}
}
}
return true;
}
fn get_arm_types(ctx: &AssistContext, arm: &ast::MatchArm) -> Vec<Option<hir::TypeInfo>> {
match arm.pat() {
Some(ast::Pat::TupleStructPat(tp)) => {
tp.fields().into_iter().map(|field| ctx.sema.type_of_pat(&field)).collect_vec()
fn get_arm_types(context: &AssistContext, arm: &MatchArm) -> HashMap<String, Option<TypeInfo>> {
let mut mapping: HashMap<String, Option<TypeInfo>> = HashMap::new();
fn recurse(pat: &Option<Pat>, map: &mut HashMap<String, Option<TypeInfo>>, ctx: &AssistContext) {
if let Some(local_pat) = pat {
println!("{:?}", pat);
match pat {
Some(ast::Pat::TupleStructPat(tuple)) => {
for field in tuple.fields() {
recurse(&Some(field), map, ctx);
}
_ => Vec::new(),
},
Some(ast::Pat::RecordPat(record)) => {
if let Some(field_list) = record.record_pat_field_list() {
for field in field_list.fields() {
recurse(&field.pat(), map, ctx);
}
}
},
Some(ast::Pat::IdentPat(ident_pat)) => {
if let Some(name) = ident_pat.name() {
println!("Found name: {:?}", name.text().to_string());
let pat_type = ctx.sema.type_of_pat(local_pat);
map.insert(name.text().to_string(), pat_type);
}
},
_ => (),
}
}
}
recurse(&arm.pat(), &mut mapping, &context);
return mapping;
}
#[cfg(test)]
mod tests {
@ -430,6 +466,7 @@ fn func() {
check_assist(
merge_match_arms,
r#"
fn func() {
let x = 'c';
match x {
@ -437,14 +474,17 @@ let x = 'c';
'c'..='z' => "",
_ => "other",
};
}
"#,
r#"
fn func() {
let x = 'c';
match x {
'a'..='j' | 'c'..='z' => "",
_ => "other",
};
}
"#,
);
}
@ -675,5 +715,3 @@ fn main(msg: Message) {
)
}
}