mirror of
https://github.com/rust-lang/rust-analyzer
synced 2025-01-26 03:45:04 +00:00
all tests work
This commit is contained in:
parent
aaec467cfd
commit
18fb5412b2
1 changed files with 57 additions and 19 deletions
|
@ -1,8 +1,8 @@
|
|||
use itertools::Itertools;
|
||||
use std::iter::successors;
|
||||
use hir::TypeInfo;
|
||||
use std::{iter::successors, collections::HashMap};
|
||||
use syntax::{
|
||||
algo::neighbor,
|
||||
ast::{self, AstNode},
|
||||
ast::{self, AstNode, Pat, MatchArm, HasName},
|
||||
Direction,
|
||||
};
|
||||
|
||||
|
@ -52,6 +52,7 @@ pub(crate) fn merge_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option
|
|||
return false;
|
||||
}
|
||||
|
||||
println!("Checking types");
|
||||
return are_same_types(¤t_arm_types, arm, ctx);
|
||||
}
|
||||
_ => false,
|
||||
|
@ -90,35 +91,70 @@ pub(crate) fn merge_match_arms(acc: &mut Assists, ctx: &AssistContext) -> Option
|
|||
)
|
||||
}
|
||||
|
||||
fn contains_placeholder(a: &ast::MatchArm) -> bool {
|
||||
fn contains_placeholder(a: &MatchArm) -> bool {
|
||||
matches!(a.pat(), Some(ast::Pat::WildcardPat(..)))
|
||||
}
|
||||
|
||||
fn are_same_types(
|
||||
current_arm_types: &Vec<Option<hir::TypeInfo>>,
|
||||
current_arm_types: &HashMap<String, Option<TypeInfo>>,
|
||||
arm: &ast::MatchArm,
|
||||
ctx: &AssistContext,
|
||||
) -> bool {
|
||||
let arm_types = get_arm_types(&ctx, &arm);
|
||||
for i in 0..arm_types.len() {
|
||||
let other_arm_type = &arm_types[i];
|
||||
let current_arm_type = ¤t_arm_types[i];
|
||||
if let (Some(other_arm_type), Some(current_arm_type)) = (other_arm_type, current_arm_type) {
|
||||
return &other_arm_type.original == ¤t_arm_type.original;
|
||||
for other_arm_type_entry in arm_types {
|
||||
let current_arm_type = current_arm_types.get_key_value(&other_arm_type_entry.0);
|
||||
if current_arm_type.is_none() {
|
||||
println!("No corresponding type found for {:?}", {other_arm_type_entry});
|
||||
return false;
|
||||
}
|
||||
|
||||
let unwrapped_current_arm_type = current_arm_type.unwrap().1;
|
||||
|
||||
if let (Some(other_arm_type), Some(current_arm_type)) = (other_arm_type_entry.1, unwrapped_current_arm_type) {
|
||||
if other_arm_type.original != current_arm_type.original {
|
||||
println!("Type {:?} is different from {:?}", &other_arm_type.original, ¤t_arm_type.original);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
fn get_arm_types(ctx: &AssistContext, arm: &ast::MatchArm) -> Vec<Option<hir::TypeInfo>> {
|
||||
match arm.pat() {
|
||||
Some(ast::Pat::TupleStructPat(tp)) => {
|
||||
tp.fields().into_iter().map(|field| ctx.sema.type_of_pat(&field)).collect_vec()
|
||||
fn get_arm_types(context: &AssistContext, arm: &MatchArm) -> HashMap<String, Option<TypeInfo>> {
|
||||
let mut mapping: HashMap<String, Option<TypeInfo>> = HashMap::new();
|
||||
|
||||
fn recurse(pat: &Option<Pat>, map: &mut HashMap<String, Option<TypeInfo>>, ctx: &AssistContext) {
|
||||
if let Some(local_pat) = pat {
|
||||
println!("{:?}", pat);
|
||||
match pat {
|
||||
Some(ast::Pat::TupleStructPat(tuple)) => {
|
||||
for field in tuple.fields() {
|
||||
recurse(&Some(field), map, ctx);
|
||||
}
|
||||
_ => Vec::new(),
|
||||
},
|
||||
Some(ast::Pat::RecordPat(record)) => {
|
||||
if let Some(field_list) = record.record_pat_field_list() {
|
||||
for field in field_list.fields() {
|
||||
recurse(&field.pat(), map, ctx);
|
||||
}
|
||||
}
|
||||
},
|
||||
Some(ast::Pat::IdentPat(ident_pat)) => {
|
||||
if let Some(name) = ident_pat.name() {
|
||||
println!("Found name: {:?}", name.text().to_string());
|
||||
let pat_type = ctx.sema.type_of_pat(local_pat);
|
||||
map.insert(name.text().to_string(), pat_type);
|
||||
}
|
||||
},
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
recurse(&arm.pat(), &mut mapping, &context);
|
||||
return mapping;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
@ -430,6 +466,7 @@ fn func() {
|
|||
check_assist(
|
||||
merge_match_arms,
|
||||
r#"
|
||||
fn func() {
|
||||
let x = 'c';
|
||||
|
||||
match x {
|
||||
|
@ -437,14 +474,17 @@ let x = 'c';
|
|||
'c'..='z' => "",
|
||||
_ => "other",
|
||||
};
|
||||
}
|
||||
"#,
|
||||
r#"
|
||||
fn func() {
|
||||
let x = 'c';
|
||||
|
||||
match x {
|
||||
'a'..='j' | 'c'..='z' => "",
|
||||
_ => "other",
|
||||
};
|
||||
}
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
@ -675,5 +715,3 @@ fn main(msg: Message) {
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue