diff --git a/crates/ra_assists/src/handlers/change_return_type_to_result.rs b/crates/ra_assists/src/handlers/change_return_type_to_result.rs new file mode 100644 index 0000000000..1e8d986cdb --- /dev/null +++ b/crates/ra_assists/src/handlers/change_return_type_to_result.rs @@ -0,0 +1,971 @@ +use ra_syntax::{ + ast, AstNode, + SyntaxKind::{COMMENT, WHITESPACE}, + SyntaxNode, TextSize, +}; + +use crate::{Assist, AssistCtx, AssistId}; +use ast::{BlockExpr, Expr, LoopBodyOwner}; + +// Assist: change_return_type_to_result +// +// Change the function's return type to Result. +// +// ``` +// fn foo() -> i32<|> { 42i32 } +// ``` +// -> +// ``` +// fn foo() -> Result { Ok(42i32) } +// ``` +pub(crate) fn change_return_type_to_result(ctx: AssistCtx) -> Option { + let fn_def = ctx.find_node_at_offset::(); + let fn_def = &mut fn_def?; + let ret_type = &fn_def.ret_type()?.type_ref()?; + if ret_type.syntax().text().to_string().starts_with("Result<") { + return None; + } + + let block_expr = &fn_def.body()?; + let cursor_in_ret_type = + fn_def.ret_type()?.syntax().text_range().contains_range(ctx.frange.range); + if !cursor_in_ret_type { + return None; + } + + ctx.add_assist( + AssistId("change_return_type_to_result"), + "Change return type to Result", + ret_type.syntax().text_range(), + |edit| { + let mut tail_return_expr_collector = TailReturnCollector::new(); + tail_return_expr_collector.collect_jump_exprs(block_expr, false); + tail_return_expr_collector.collect_tail_exprs(block_expr); + + for ret_expr_arg in tail_return_expr_collector.exprs_to_wrap { + edit.replace_node_and_indent(&ret_expr_arg, format!("Ok({})", ret_expr_arg)); + } + edit.replace_node_and_indent(ret_type.syntax(), format!("Result<{}, >", ret_type)); + + if let Some(node_start) = result_insertion_offset(&ret_type) { + edit.set_cursor(node_start + TextSize::of(&format!("Result<{}, ", ret_type))); + } + }, + ) +} + +struct TailReturnCollector { + exprs_to_wrap: Vec, +} + +impl TailReturnCollector { + fn new() -> Self { + Self { exprs_to_wrap: vec![] } + } + /// Collect all`return` expression + fn collect_jump_exprs(&mut self, block_expr: &BlockExpr, collect_break: bool) { + let statements = block_expr.statements(); + for stmt in statements { + let expr = match &stmt { + ast::Stmt::ExprStmt(stmt) => stmt.expr(), + ast::Stmt::LetStmt(stmt) => stmt.initializer(), + }; + if let Some(expr) = &expr { + self.handle_exprs(expr, collect_break); + } + } + + // Browse tail expressions for each block + if let Some(expr) = block_expr.expr() { + if let Some(last_exprs) = get_tail_expr_from_block(&expr) { + for last_expr in last_exprs { + let last_expr = match last_expr { + NodeType::Node(expr) | NodeType::Leaf(expr) => expr, + }; + + if let Some(last_expr) = Expr::cast(last_expr.clone()) { + self.handle_exprs(&last_expr, collect_break); + } else if let Some(expr_stmt) = ast::Stmt::cast(last_expr) { + let expr_stmt = match &expr_stmt { + ast::Stmt::ExprStmt(stmt) => stmt.expr(), + ast::Stmt::LetStmt(stmt) => stmt.initializer(), + }; + if let Some(expr) = &expr_stmt { + self.handle_exprs(expr, collect_break); + } + } + } + } + } + } + + fn handle_exprs(&mut self, expr: &Expr, collect_break: bool) { + match expr { + Expr::BlockExpr(block_expr) => { + self.collect_jump_exprs(&block_expr, collect_break); + } + Expr::ReturnExpr(ret_expr) => { + if let Some(ret_expr_arg) = &ret_expr.expr() { + self.exprs_to_wrap.push(ret_expr_arg.syntax().clone()); + } + } + Expr::BreakExpr(break_expr) if collect_break => { + if let Some(break_expr_arg) = &break_expr.expr() { + self.exprs_to_wrap.push(break_expr_arg.syntax().clone()); + } + } + Expr::IfExpr(if_expr) => { + for block in if_expr.blocks() { + self.collect_jump_exprs(&block, collect_break); + } + } + Expr::LoopExpr(loop_expr) => { + if let Some(block_expr) = loop_expr.loop_body() { + self.collect_jump_exprs(&block_expr, collect_break); + } + } + Expr::ForExpr(for_expr) => { + if let Some(block_expr) = for_expr.loop_body() { + self.collect_jump_exprs(&block_expr, collect_break); + } + } + Expr::WhileExpr(while_expr) => { + if let Some(block_expr) = while_expr.loop_body() { + self.collect_jump_exprs(&block_expr, collect_break); + } + } + Expr::MatchExpr(match_expr) => { + if let Some(arm_list) = match_expr.match_arm_list() { + arm_list.arms().filter_map(|match_arm| match_arm.expr()).for_each(|expr| { + self.handle_exprs(&expr, collect_break); + }); + } + } + _ => {} + } + } + + fn collect_tail_exprs(&mut self, block: &BlockExpr) { + if let Some(expr) = block.expr() { + self.handle_exprs(&expr, true); + self.fetch_tail_exprs(&expr); + } + } + + fn fetch_tail_exprs(&mut self, expr: &Expr) { + if let Some(exprs) = get_tail_expr_from_block(expr) { + for node_type in &exprs { + match node_type { + NodeType::Leaf(expr) => { + self.exprs_to_wrap.push(expr.clone()); + } + NodeType::Node(expr) => match &Expr::cast(expr.clone()) { + Some(last_expr) => { + self.fetch_tail_exprs(last_expr); + } + None => { + self.exprs_to_wrap.push(expr.clone()); + } + }, + } + } + } + } +} + +#[derive(Debug)] +enum NodeType { + Leaf(SyntaxNode), + Node(SyntaxNode), +} + +/// Get a tail expression inside a block +fn get_tail_expr_from_block(expr: &Expr) -> Option> { + match expr { + Expr::IfExpr(if_expr) => { + let mut nodes = vec![]; + for block in if_expr.blocks() { + if let Some(block_expr) = block.expr() { + if let Some(tail_exprs) = get_tail_expr_from_block(&block_expr) { + nodes.extend(tail_exprs); + } + } else if let Some(last_expr) = block.syntax().last_child() { + nodes.push(NodeType::Node(last_expr)); + } else { + nodes.push(NodeType::Node(block.syntax().clone())); + } + } + Some(nodes) + } + Expr::LoopExpr(loop_expr) => { + loop_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)]) + } + Expr::ForExpr(for_expr) => { + for_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)]) + } + Expr::WhileExpr(while_expr) => { + while_expr.syntax().last_child().map(|lc| vec![NodeType::Node(lc)]) + } + Expr::BlockExpr(block_expr) => { + block_expr.expr().map(|lc| vec![NodeType::Node(lc.syntax().clone())]) + } + Expr::MatchExpr(match_expr) => { + let arm_list = match_expr.match_arm_list()?; + let arms: Vec = arm_list + .arms() + .filter_map(|match_arm| match_arm.expr()) + .map(|expr| match expr { + Expr::ReturnExpr(ret_expr) => NodeType::Node(ret_expr.syntax().clone()), + Expr::BreakExpr(break_expr) => NodeType::Node(break_expr.syntax().clone()), + _ => match expr.syntax().last_child() { + Some(last_expr) => NodeType::Node(last_expr), + None => NodeType::Node(expr.syntax().clone()), + }, + }) + .collect(); + + Some(arms) + } + Expr::BreakExpr(expr) => expr.expr().map(|e| vec![NodeType::Leaf(e.syntax().clone())]), + Expr::ReturnExpr(ret_expr) => Some(vec![NodeType::Node(ret_expr.syntax().clone())]), + Expr::CallExpr(call_expr) => Some(vec![NodeType::Leaf(call_expr.syntax().clone())]), + Expr::Literal(lit_expr) => Some(vec![NodeType::Leaf(lit_expr.syntax().clone())]), + Expr::TupleExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + Expr::ArrayExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + Expr::ParenExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + Expr::PathExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + Expr::Label(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + Expr::RecordLit(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + Expr::IndexExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + Expr::MethodCallExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + Expr::AwaitExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + Expr::CastExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + Expr::RefExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + Expr::PrefixExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + Expr::RangeExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + Expr::BinExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + Expr::MacroCall(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + Expr::BoxExpr(expr) => Some(vec![NodeType::Leaf(expr.syntax().clone())]), + _ => None, + } +} + +fn result_insertion_offset(ret_type: &ast::TypeRef) -> Option { + let non_ws_child = ret_type + .syntax() + .children_with_tokens() + .find(|it| it.kind() != COMMENT && it.kind() != WHITESPACE)?; + Some(non_ws_child.text_range().start()) +} + +#[cfg(test)] +mod tests { + + use crate::tests::{check_assist, check_assist_not_applicable}; + + use super::*; + + #[test] + fn change_return_type_to_result_simple() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i3<|>2 { + let test = "test"; + return 42i32; + }"#, + r#"fn foo() -> Result> { + let test = "test"; + return Ok(42i32); + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_return_type() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let test = "test"; + return 42i32; + }"#, + r#"fn foo() -> Result> { + let test = "test"; + return Ok(42i32); + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_return_type_bad_cursor() { + check_assist_not_applicable( + change_return_type_to_result, + r#"fn foo() -> i32 { + let test = "test";<|> + return 42i32; + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_cursor() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> <|>i32 { + let test = "test"; + return 42i32; + }"#, + r#"fn foo() -> Result> { + let test = "test"; + return Ok(42i32); + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_tail() { + check_assist( + change_return_type_to_result, + r#"fn foo() -><|> i32 { + let test = "test"; + 42i32 + }"#, + r#"fn foo() -> Result> { + let test = "test"; + Ok(42i32) + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_tail_only() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + 42i32 + }"#, + r#"fn foo() -> Result> { + Ok(42i32) + }"#, + ); + } + #[test] + fn change_return_type_to_result_simple_with_tail_block_like() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + if true { + 42i32 + } else { + 24i32 + } + }"#, + r#"fn foo() -> Result> { + if true { + Ok(42i32) + } else { + Ok(24i32) + } + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_nested_if() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + if true { + if false { + 1 + } else { + 2 + } + } else { + 24i32 + } + }"#, + r#"fn foo() -> Result> { + if true { + if false { + Ok(1) + } else { + Ok(2) + } + } else { + Ok(24i32) + } + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_await() { + check_assist( + change_return_type_to_result, + r#"async fn foo() -> i<|>32 { + if true { + if false { + 1.await + } else { + 2.await + } + } else { + 24i32.await + } + }"#, + r#"async fn foo() -> Result> { + if true { + if false { + Ok(1.await) + } else { + Ok(2.await) + } + } else { + Ok(24i32.await) + } + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_array() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> [i32;<|> 3] { + [1, 2, 3] + }"#, + r#"fn foo() -> Result<[i32; 3], <|>> { + Ok([1, 2, 3]) + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_cast() { + check_assist( + change_return_type_to_result, + r#"fn foo() -<|>> i32 { + if true { + if false { + 1 as i32 + } else { + 2 as i32 + } + } else { + 24 as i32 + } + }"#, + r#"fn foo() -> Result> { + if true { + if false { + Ok(1 as i32) + } else { + Ok(2 as i32) + } + } else { + Ok(24 as i32) + } + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_tail_block_like_match() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let my_var = 5; + match my_var { + 5 => 42i32, + _ => 24i32, + } + }"#, + r#"fn foo() -> Result> { + let my_var = 5; + match my_var { + 5 => Ok(42i32), + _ => Ok(24i32), + } + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_loop_with_tail() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let my_var = 5; + loop { + println!("test"); + 5 + } + + my_var + }"#, + r#"fn foo() -> Result> { + let my_var = 5; + loop { + println!("test"); + 5 + } + + Ok(my_var) + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_loop_in_let_stmt() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let my_var = let x = loop { + break 1; + }; + + my_var + }"#, + r#"fn foo() -> Result> { + let my_var = let x = loop { + break 1; + }; + + Ok(my_var) + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_tail_block_like_match_return_expr() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let my_var = 5; + let res = match my_var { + 5 => 42i32, + _ => return 24i32, + }; + + res + }"#, + r#"fn foo() -> Result> { + let my_var = 5; + let res = match my_var { + 5 => 42i32, + _ => return Ok(24i32), + }; + + Ok(res) + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let my_var = 5; + let res = if my_var == 5 { + 42i32 + } else { + return 24i32; + }; + + res + }"#, + r#"fn foo() -> Result> { + let my_var = 5; + let res = if my_var == 5 { + 42i32 + } else { + return Ok(24i32); + }; + + Ok(res) + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_tail_block_like_match_deeper() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let my_var = 5; + match my_var { + 5 => { + if true { + 42i32 + } else { + 25i32 + } + }, + _ => { + let test = "test"; + if test == "test" { + return bar(); + } + 53i32 + }, + } + }"#, + r#"fn foo() -> Result> { + let my_var = 5; + match my_var { + 5 => { + if true { + Ok(42i32) + } else { + Ok(25i32) + } + }, + _ => { + let test = "test"; + if test == "test" { + return Ok(bar()); + } + Ok(53i32) + }, + } + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_tail_block_like_early_return() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i<|>32 { + let test = "test"; + if test == "test" { + return 24i32; + } + 53i32 + }"#, + r#"fn foo() -> Result> { + let test = "test"; + if test == "test" { + return Ok(24i32); + } + Ok(53i32) + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_closure() { + check_assist( + change_return_type_to_result, + r#"fn foo(the_field: u32) -><|> u32 { + let true_closure = || { + return true; + }; + if the_field < 5 { + let mut i = 0; + + + if true_closure() { + return 99; + } else { + return 0; + } + } + + the_field + }"#, + r#"fn foo(the_field: u32) -> Result> { + let true_closure = || { + return true; + }; + if the_field < 5 { + let mut i = 0; + + + if true_closure() { + return Ok(99); + } else { + return Ok(0); + } + } + + Ok(the_field) + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo(the_field: u32) -> u32<|> { + let true_closure = || { + return true; + }; + if the_field < 5 { + let mut i = 0; + + + if true_closure() { + return 99; + } else { + return 0; + } + } + let t = None; + + t.unwrap_or_else(|| the_field) + }"#, + r#"fn foo(the_field: u32) -> Result> { + let true_closure = || { + return true; + }; + if the_field < 5 { + let mut i = 0; + + + if true_closure() { + return Ok(99); + } else { + return Ok(0); + } + } + let t = None; + + Ok(t.unwrap_or_else(|| the_field)) + }"#, + ); + } + + #[test] + fn change_return_type_to_result_simple_with_weird_forms() { + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let test = "test"; + if test == "test" { + return 24i32; + } + let mut i = 0; + loop { + if i == 1 { + break 55; + } + i += 1; + } + }"#, + r#"fn foo() -> Result> { + let test = "test"; + if test == "test" { + return Ok(24i32); + } + let mut i = 0; + loop { + if i == 1 { + break Ok(55); + } + i += 1; + } + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo() -> i32<|> { + let test = "test"; + if test == "test" { + return 24i32; + } + let mut i = 0; + loop { + loop { + if i == 1 { + break 55; + } + i += 1; + } + } + }"#, + r#"fn foo() -> Result> { + let test = "test"; + if test == "test" { + return Ok(24i32); + } + let mut i = 0; + loop { + loop { + if i == 1 { + break Ok(55); + } + i += 1; + } + } + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo() -> i3<|>2 { + let test = "test"; + let other = 5; + if test == "test" { + let res = match other { + 5 => 43, + _ => return 56, + }; + } + let mut i = 0; + loop { + loop { + if i == 1 { + break 55; + } + i += 1; + } + } + }"#, + r#"fn foo() -> Result> { + let test = "test"; + let other = 5; + if test == "test" { + let res = match other { + 5 => 43, + _ => return Ok(56), + }; + } + let mut i = 0; + loop { + loop { + if i == 1 { + break Ok(55); + } + i += 1; + } + } + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo(the_field: u32) -> u32<|> { + if the_field < 5 { + let mut i = 0; + loop { + if i > 5 { + return 55u32; + } + i += 3; + } + + match i { + 5 => return 99, + _ => return 0, + }; + } + + the_field + }"#, + r#"fn foo(the_field: u32) -> Result> { + if the_field < 5 { + let mut i = 0; + loop { + if i > 5 { + return Ok(55u32); + } + i += 3; + } + + match i { + 5 => return Ok(99), + _ => return Ok(0), + }; + } + + Ok(the_field) + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo(the_field: u32) -> u3<|>2 { + if the_field < 5 { + let mut i = 0; + + match i { + 5 => return 99, + _ => return 0, + } + } + + the_field + }"#, + r#"fn foo(the_field: u32) -> Result> { + if the_field < 5 { + let mut i = 0; + + match i { + 5 => return Ok(99), + _ => return Ok(0), + } + } + + Ok(the_field) + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo(the_field: u32) -> u32<|> { + if the_field < 5 { + let mut i = 0; + + if i == 5 { + return 99 + } else { + return 0 + } + } + + the_field + }"#, + r#"fn foo(the_field: u32) -> Result> { + if the_field < 5 { + let mut i = 0; + + if i == 5 { + return Ok(99) + } else { + return Ok(0) + } + } + + Ok(the_field) + }"#, + ); + + check_assist( + change_return_type_to_result, + r#"fn foo(the_field: u32) -> <|>u32 { + if the_field < 5 { + let mut i = 0; + + if i == 5 { + return 99; + } else { + return 0; + } + } + + the_field + }"#, + r#"fn foo(the_field: u32) -> Result> { + if the_field < 5 { + let mut i = 0; + + if i == 5 { + return Ok(99); + } else { + return Ok(0); + } + } + + Ok(the_field) + }"#, + ); + } +} diff --git a/crates/ra_assists/src/lib.rs b/crates/ra_assists/src/lib.rs index 13ea45ec7c..0473fd8c20 100644 --- a/crates/ra_assists/src/lib.rs +++ b/crates/ra_assists/src/lib.rs @@ -129,6 +129,7 @@ mod handlers { mod replace_qualified_name_with_use; mod replace_unwrap_with_match; mod split_import; + mod change_return_type_to_result; mod add_from_impl_for_enum; mod reorder_fields; mod unwrap_block; @@ -145,6 +146,7 @@ mod handlers { add_new::add_new, apply_demorgan::apply_demorgan, auto_import::auto_import, + change_return_type_to_result::change_return_type_to_result, change_visibility::change_visibility, early_return::convert_to_guarded_return, fill_match_arms::fill_match_arms, diff --git a/crates/ra_assists/src/tests/generated.rs b/crates/ra_assists/src/tests/generated.rs index 7d35fa2846..972dbd251d 100644 --- a/crates/ra_assists/src/tests/generated.rs +++ b/crates/ra_assists/src/tests/generated.rs @@ -249,6 +249,19 @@ pub mod std { pub mod collections { pub struct HashMap { } } } ) } +#[test] +fn doctest_change_return_type_to_result() { + check_doc_test( + "change_return_type_to_result", + r#####" +fn foo() -> i32<|> { 42i32 } +"#####, + r#####" +fn foo() -> Result { Ok(42i32) } +"#####, + ) +} + #[test] fn doctest_change_visibility() { check_doc_test( diff --git a/docs/user/assists.md b/docs/user/assists.md index ee515949e9..692fd4f52b 100644 --- a/docs/user/assists.md +++ b/docs/user/assists.md @@ -241,6 +241,18 @@ fn main() { } ``` +## `change_return_type_to_result` + +Change the function's return type to Result. + +```rust +// BEFORE +fn foo() -> i32┃ { 42i32 } + +// AFTER +fn foo() -> Result { Ok(42i32) } +``` + ## `change_visibility` Adds or changes existing visibility specifier.