Handle not let if expressions in replace_if_let_with_match

This commit is contained in:
Lukas Wirth 2021-07-02 00:58:56 +02:00
parent 20be999304
commit 8967856d78
2 changed files with 83 additions and 24 deletions

View file

@ -1,5 +1,6 @@
use std::iter::{self, successors};
use either::Either;
use ide_db::{ty_filter::TryEnum, RootDatabase};
use syntax::{
ast::{
@ -53,17 +54,30 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext)
});
let scrutinee_to_be_expr = if_expr.condition()?.expr()?;
let mut pat_bodies = Vec::new();
let mut pat_seen = false;
let mut cond_bodies = Vec::new();
for if_expr in if_exprs {
let cond = if_expr.condition()?;
let expr = cond.expr()?;
if scrutinee_to_be_expr.syntax().text() != expr.syntax().text() {
// Only if all condition expressions are equal we can merge them into a match
return None;
}
let pat = cond.pat()?;
let cond = match cond.pat() {
Some(pat) => {
if scrutinee_to_be_expr.syntax().text() != expr.syntax().text() {
// Only if all condition expressions are equal we can merge them into a match
return None;
} else {
pat_seen = true;
Either::Left(pat)
}
}
None => Either::Right(expr),
};
let body = if_expr.then_branch()?;
pat_bodies.push((pat, body));
cond_bodies.push((cond, body));
}
if !pat_seen {
// Don't offer turning an if (chain) without patterns into a match
return None;
}
let target = if_expr.syntax().text_range();
@ -76,8 +90,8 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext)
let else_arm = {
match else_block {
Some(else_block) => {
let pattern = match &*pat_bodies {
[(pat, _)] => ctx
let pattern = match &*cond_bodies {
[(Either::Left(pat), _)] => ctx
.sema
.type_of_pat(&pat)
.and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty))
@ -99,23 +113,34 @@ pub(crate) fn replace_if_let_with_match(acc: &mut Assists, ctx: &AssistContext)
),
}
};
let arms = pat_bodies
let arms = cond_bodies
.into_iter()
.map(|(pat, body)| {
let body = body.reset_indent().indent(IndentLevel(1));
make::match_arm(vec![pat], unwrap_trivial_block(body))
match pat {
Either::Left(pat) => {
make::match_arm(iter::once(pat), unwrap_trivial_block(body))
}
Either::Right(expr) => make::match_arm_with_guard(
iter::once(make::wildcard_pat().into()),
expr,
unwrap_trivial_block(body),
),
}
})
.chain(iter::once(else_arm));
let match_expr = make::expr_match(scrutinee_to_be_expr, make::match_arm_list(arms));
match_expr.indent(IndentLevel::from_node(if_expr.syntax()))
};
let expr =
if if_expr.syntax().parent().map_or(false, |it| ast::IfExpr::can_cast(it.kind())) {
make::block_expr(None, Some(match_expr)).into()
} else {
match_expr
};
let has_preceding_if_expr =
if_expr.syntax().parent().map_or(false, |it| ast::IfExpr::can_cast(it.kind()));
let expr = if has_preceding_if_expr {
// make sure we replace the `else if let ...` with a block so we don't end up with `else expr`
make::block_expr(None, Some(match_expr)).into()
} else {
match_expr
};
edit.replace_ast::<ast::Expr>(if_expr.into(), expr);
},
)
@ -210,7 +235,19 @@ fn is_pat_wildcard_or_sad(sema: &hir::Semantics<RootDatabase>, pat: &ast::Pat) -
mod tests {
use super::*;
use crate::tests::{check_assist, check_assist_target};
use crate::tests::{check_assist, check_assist_not_applicable, check_assist_target};
#[test]
fn test_if_let_with_match_unapplicable_for_simple_ifs() {
check_assist_not_applicable(
replace_if_let_with_match,
r#"
fn main() {
if $0true {} else if false {} else {}
}
"#,
)
}
#[test]
fn test_if_let_with_match_no_else() {
@ -223,7 +260,8 @@ impl VariantData {
self.foo();
}
}
} "#,
}
"#,
r#"
impl VariantData {
pub fn foo(&self) {
@ -234,7 +272,8 @@ impl VariantData {
_ => (),
}
}
} "#,
}
"#,
)
}
@ -249,19 +288,23 @@ impl VariantData {
true
} else if let VariantData::Tuple(..) = *self {
false
} else if cond() {
true
} else {
bar(
123
)
}
}
} "#,
}
"#,
r#"
impl VariantData {
pub fn is_struct(&self) -> bool {
match *self {
VariantData::Struct(..) => true,
VariantData::Tuple(..) => false,
_ if cond() => true,
_ => {
bar(
123
@ -269,7 +312,8 @@ impl VariantData {
}
}
}
} "#,
}
"#,
)
}
@ -288,7 +332,8 @@ impl VariantData {
false
}
}
} "#,
}
"#,
r#"
impl VariantData {
pub fn is_struct(&self) -> bool {
@ -301,7 +346,8 @@ impl VariantData {
}
}
}
} "#,
}
"#,
)
}

View file

@ -430,6 +430,19 @@ pub fn match_arm(pats: impl IntoIterator<Item = ast::Pat>, expr: ast::Expr) -> a
}
}
pub fn match_arm_with_guard(
pats: impl IntoIterator<Item = ast::Pat>,
guard: ast::Expr,
expr: ast::Expr,
) -> ast::MatchArm {
let pats_str = pats.into_iter().join(" | ");
return from_text(&format!("{} if {} => {}", pats_str, guard, expr));
fn from_text(text: &str) -> ast::MatchArm {
ast_from_text(&format!("fn f() {{ match () {{{}}} }}", text))
}
}
pub fn match_arm_list(arms: impl IntoIterator<Item = ast::MatchArm>) -> ast::MatchArmList {
let arms_str = arms
.into_iter()