diff --git a/crates/hir_expand/src/name.rs b/crates/hir_expand/src/name.rs index 4c15ba1545..b7b0f0b0e6 100644 --- a/crates/hir_expand/src/name.rs +++ b/crates/hir_expand/src/name.rs @@ -200,6 +200,7 @@ pub mod known { Range, Neg, Not, + None, Index, // Components of known path (function name) filter_map, diff --git a/crates/ide_assists/src/handlers/convert_bool_then.rs b/crates/ide_assists/src/handlers/convert_bool_then.rs new file mode 100644 index 0000000000..1840b2d756 --- /dev/null +++ b/crates/ide_assists/src/handlers/convert_bool_then.rs @@ -0,0 +1,352 @@ +use hir::{known, Semantics}; +use ide_db::{ + helpers::{for_each_tail_expr, FamousDefs}, + RootDatabase, +}; +use syntax::{ + ast::{self, make, ArgListOwner}, + ted, AstNode, SyntaxNode, +}; + +use crate::{ + utils::{invert_boolean_expression, unwrap_trivial_block}, + AssistContext, AssistId, AssistKind, Assists, +}; + +// Assist: convert_if_to_bool_then +// +// Converts an if expression into a corresponding `bool::then` call. +// +// ``` +// # //- minicore: option +// fn main() { +// if$0 cond { +// Some(val) +// } else { +// None +// } +// } +// ``` +// -> +// ``` +// fn main() { +// cond.then(|| val) +// } +// ``` +pub(crate) fn convert_if_to_bool_then(acc: &mut Assists, ctx: &AssistContext) -> Option<()> { + // todo, applies to match as well + let expr = ctx.find_node_at_offset::()?; + if !expr.if_token()?.text_range().contains_inclusive(ctx.offset()) { + return None; + } + + let cond = expr.condition().filter(|cond| !cond.is_pattern_cond())?; + let cond = cond.expr()?; + let then = expr.then_branch()?; + let else_ = match expr.else_branch()? { + ast::ElseBranch::Block(b) => b, + ast::ElseBranch::IfExpr(_) => { + cov_mark::hit!(convert_if_to_bool_then_chain); + return None; + } + }; + + let (none_variant, some_variant) = option_variants(&ctx.sema, expr.syntax())?; + + let (invert_cond, closure_body) = match ( + block_is_none_variant(&ctx.sema, &then, none_variant), + block_is_none_variant(&ctx.sema, &else_, none_variant), + ) { + (invert @ true, false) => (invert, ast::Expr::BlockExpr(else_)), + (invert @ false, true) => (invert, ast::Expr::BlockExpr(then)), + _ => return None, + }; + + if is_invalid_body(&ctx.sema, some_variant, &closure_body) { + cov_mark::hit!(convert_if_to_bool_then_pattern_invalid_body); + return None; + } + + let target = expr.syntax().text_range(); + acc.add( + AssistId("convert_if_to_bool_then", AssistKind::RefactorRewrite), + "Convert `if` expression to `bool::then` call", + target, + |builder| { + let closure_body = closure_body.clone_for_update(); + // Rewrite all `Some(e)` in tail position to `e` + for_each_tail_expr(&closure_body, &mut |e| { + let e = match e { + ast::Expr::BreakExpr(e) => e.expr(), + e @ ast::Expr::CallExpr(_) => Some(e.clone()), + _ => None, + }; + if let Some(ast::Expr::CallExpr(call)) = e { + if let Some(arg_list) = call.arg_list() { + if let Some(arg) = arg_list.args().next() { + ted::replace(call.syntax(), arg.syntax()); + } + } + } + }); + let closure_body = match closure_body { + ast::Expr::BlockExpr(block) => unwrap_trivial_block(block), + e => e, + }; + + let cond = if invert_cond { invert_boolean_expression(&ctx.sema, cond) } else { cond }; + let arg_list = make::arg_list(Some(make::expr_closure(None, closure_body))); + let mcall = make::expr_method_call(cond, make::name_ref("then"), arg_list); + builder.replace(target, mcall.to_string()); + }, + ) +} + +fn option_variants( + sema: &Semantics, + expr: &SyntaxNode, +) -> Option<(hir::Variant, hir::Variant)> { + let fam = FamousDefs(&sema, sema.scope(expr).krate()); + let option_variants = fam.core_option_Option()?.variants(sema.db); + match &*option_variants { + &[variant0, variant1] => Some(if variant0.name(sema.db) == known::None { + (variant0, variant1) + } else { + (variant1, variant0) + }), + _ => None, + } +} + +/// Traverses the expression checking if it contains `return` or `?` expressions or if any tail is not a `Some(expr)` expression. +/// If any of these conditions are met it is impossible to rewrite this as a `bool::then` call. +fn is_invalid_body( + sema: &Semantics, + some_variant: hir::Variant, + expr: &ast::Expr, +) -> bool { + let mut invalid = false; + expr.preorder(&mut |e| { + invalid |= + matches!(e, syntax::WalkEvent::Enter(ast::Expr::TryExpr(_) | ast::Expr::ReturnExpr(_))); + invalid + }); + if !invalid { + for_each_tail_expr(&expr, &mut |e| { + if invalid { + return; + } + let e = match e { + ast::Expr::BreakExpr(e) => e.expr(), + e @ ast::Expr::CallExpr(_) => Some(e.clone()), + _ => None, + }; + if let Some(ast::Expr::CallExpr(call)) = e { + if let Some(ast::Expr::PathExpr(p)) = call.expr() { + let res = p.path().and_then(|p| sema.resolve_path(&p)); + if let Some(hir::PathResolution::Def(hir::ModuleDef::Variant(v))) = res { + return invalid |= v != some_variant; + } + } + } + invalid = true + }); + } + invalid +} + +fn block_is_none_variant( + sema: &Semantics, + block: &ast::BlockExpr, + none_variant: hir::Variant, +) -> bool { + block.as_lone_tail().and_then(|e| match e { + ast::Expr::PathExpr(pat) => match sema.resolve_path(&pat.path()?)? { + hir::PathResolution::Def(hir::ModuleDef::Variant(v)) => Some(v), + _ => None, + }, + _ => None, + }) == Some(none_variant) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn convert_if_to_bool_then_simple() { + check_assist( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 true { + Some(15) + } else { + None + } +} +", + r" +fn main() { + true.then(|| 15) +} +", + ); + } + + #[test] + fn convert_if_to_bool_then_invert() { + check_assist( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 true { + None + } else { + Some(15) + } +} +", + r" +fn main() { + false.then(|| 15) +} +", + ); + } + + #[test] + fn convert_if_to_bool_then_none_none() { + check_assist_not_applicable( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 true { + None + } else { + None + } +} +", + ); + } + + #[test] + fn convert_if_to_bool_then_some_some() { + check_assist_not_applicable( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 true { + Some(15) + } else { + Some(15) + } +} +", + ); + } + + #[test] + fn convert_if_to_bool_then_mixed() { + check_assist_not_applicable( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 true { + if true { + Some(15) + } else { + None + } + } else { + None + } +} +", + ); + } + + #[test] + fn convert_if_to_bool_then_chain() { + cov_mark::check!(convert_if_to_bool_then_chain); + check_assist_not_applicable( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 true { + Some(15) + } else if true { + None + } else { + None + } +} +", + ); + } + + #[test] + fn convert_if_to_bool_then_pattern_cond() { + check_assist_not_applicable( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 let true = true { + Some(15) + } else { + None + } +} +", + ); + } + + #[test] + fn convert_if_to_bool_then_pattern_invalid_body() { + cov_mark::check_count!(convert_if_to_bool_then_pattern_invalid_body, 2); + check_assist_not_applicable( + convert_if_to_bool_then, + r" +//- minicore:option +fn make_me_an_option() -> Option { None } +fn main() { + if$0 true { + if true { + make_me_an_option() + } else { + Some(15) + } + } else { + None + } +} +", + ); + check_assist_not_applicable( + convert_if_to_bool_then, + r" +//- minicore:option +fn main() { + if$0 true { + if true { + return; + } + Some(15) + } else { + None + } +} +", + ); + } +} diff --git a/crates/ide_assists/src/lib.rs b/crates/ide_assists/src/lib.rs index 14bf565e56..71a2008609 100644 --- a/crates/ide_assists/src/lib.rs +++ b/crates/ide_assists/src/lib.rs @@ -55,10 +55,11 @@ mod handlers { mod apply_demorgan; mod auto_import; mod change_visibility; - mod convert_integer_literal; + mod convert_bool_then; mod convert_comment_block; - mod convert_iter_for_each_to_for; + mod convert_integer_literal; mod convert_into_to_from; + mod convert_iter_for_each_to_for; mod convert_tuple_struct_to_named_struct; mod early_return; mod expand_glob_import; @@ -73,7 +74,6 @@ mod handlers { mod flip_trait_bound; mod generate_default_from_enum_variant; mod generate_default_from_new; - mod generate_is_empty_from_len; mod generate_deref; mod generate_derive; mod generate_enum_is_method; @@ -82,6 +82,7 @@ mod handlers { mod generate_function; mod generate_getter; mod generate_impl; + mod generate_is_empty_from_len; mod generate_new; mod generate_setter; mod infer_function_return_type; @@ -124,10 +125,11 @@ mod handlers { apply_demorgan::apply_demorgan, auto_import::auto_import, change_visibility::change_visibility, - convert_integer_literal::convert_integer_literal, + convert_bool_then::convert_if_to_bool_then, convert_comment_block::convert_comment_block, - convert_iter_for_each_to_for::convert_iter_for_each_to_for, + convert_integer_literal::convert_integer_literal, convert_into_to_from::convert_into_to_from, + convert_iter_for_each_to_for::convert_iter_for_each_to_for, convert_tuple_struct_to_named_struct::convert_tuple_struct_to_named_struct, early_return::convert_to_guarded_return, expand_glob_import::expand_glob_import, diff --git a/crates/ide_assists/src/tests/generated.rs b/crates/ide_assists/src/tests/generated.rs index ebf312aa3f..cb67b77168 100644 --- a/crates/ide_assists/src/tests/generated.rs +++ b/crates/ide_assists/src/tests/generated.rs @@ -191,6 +191,28 @@ pub(crate) fn frobnicate() {} ) } +#[test] +fn doctest_convert_if_to_bool_then() { + check_doc_test( + "convert_if_to_bool_then", + r#####" +//- minicore: option +fn main() { + if$0 cond { + Some(val) + } else { + None + } +} +"#####, + r#####" +fn main() { + cond.then(|| val) +} +"#####, + ) +} + #[test] fn doctest_convert_integer_literal() { check_doc_test( diff --git a/crates/syntax/src/ast/node_ext.rs b/crates/syntax/src/ast/node_ext.rs index 68dcac4b03..99ef5c264f 100644 --- a/crates/syntax/src/ast/node_ext.rs +++ b/crates/syntax/src/ast/node_ext.rs @@ -56,6 +56,10 @@ impl ast::BlockExpr { pub fn is_empty(&self) -> bool { self.statements().next().is_none() && self.tail_expr().is_none() } + + pub fn as_lone_tail(&self) -> Option { + self.statements().next().is_none().then(|| self.tail_expr()).flatten() + } } impl ast::Pat {