From 52f1ce17aa3aa97920d47c3df7fbf400b3a672b1 Mon Sep 17 00:00:00 2001 From: oxalica Date: Sun, 18 Jun 2023 16:59:11 +0800 Subject: [PATCH] Correctly handle inlining of async fn --- .../ide-assists/src/handlers/inline_call.rs | 141 ++++++++++++++++-- crates/syntax/src/ast/make.rs | 15 ++ 2 files changed, 146 insertions(+), 10 deletions(-) diff --git a/crates/ide-assists/src/handlers/inline_call.rs b/crates/ide-assists/src/handlers/inline_call.rs index 28d815e81b..642af853fa 100644 --- a/crates/ide-assists/src/handlers/inline_call.rs +++ b/crates/ide-assists/src/handlers/inline_call.rs @@ -15,7 +15,7 @@ use ide_db::{ }; use itertools::{izip, Itertools}; use syntax::{ - ast::{self, edit_in_place::Indent, HasArgList, PathExpr}, + ast::{self, edit::IndentLevel, edit_in_place::Indent, HasArgList, PathExpr}, ted, AstNode, NodeOrToken, SyntaxKind, }; @@ -306,7 +306,7 @@ fn inline( params: &[(ast::Pat, Option, hir::Param)], CallInfo { node, arguments, generic_arg_list }: &CallInfo, ) -> ast::Expr { - let body = if sema.hir_file_for(fn_body.syntax()).is_macro() { + let mut body = if sema.hir_file_for(fn_body.syntax()).is_macro() { cov_mark::hit!(inline_call_defined_in_macro); if let Some(body) = ast::BlockExpr::cast(insert_ws_into(fn_body.syntax().clone())) { body @@ -391,19 +391,19 @@ fn inline( } } + let mut let_stmts = Vec::new(); + // Inline parameter expressions or generate `let` statements depending on whether inlining works or not. - for ((pat, param_ty, _), usages, expr) in izip!(params, param_use_nodes, arguments).rev() { + for ((pat, param_ty, _), usages, expr) in izip!(params, param_use_nodes, arguments) { // izip confuses RA due to our lack of hygiene info currently losing us type info causing incorrect errors let usages: &[ast::PathExpr] = &usages; let expr: &ast::Expr = expr; - let insert_let_stmt = || { + let mut insert_let_stmt = || { let ty = sema.type_of_expr(expr).filter(TypeInfo::has_adjustment).and(param_ty.clone()); - if let Some(stmt_list) = body.stmt_list() { - stmt_list.push_front( - make::let_stmt(pat.clone(), ty, Some(expr.clone())).clone_for_update().into(), - ) - } + let_stmts.push( + make::let_stmt(pat.clone(), ty, Some(expr.clone())).clone_for_update().into(), + ); }; // check if there is a local var in the function that conflicts with parameter @@ -457,6 +457,24 @@ fn inline( } } + let is_async_fn = function.is_async(sema.db); + if is_async_fn { + cov_mark::hit!(inline_call_async_fn); + body = make::async_move_block_expr(body.statements(), body.tail_expr()).clone_for_update(); + + // Arguments should be evaluated outside the async block, and then moved into it. + if !let_stmts.is_empty() { + cov_mark::hit!(inline_call_async_fn_with_let_stmts); + body.indent(IndentLevel(1)); + body = make::block_expr(let_stmts, Some(body.into())).clone_for_update(); + } + } else if let Some(stmt_list) = body.stmt_list() { + ted::insert_all( + ted::Position::after(stmt_list.l_curly_token().unwrap()), + let_stmts.into_iter().map(|stmt| stmt.syntax().clone().into()).collect(), + ); + } + let original_indentation = match node { ast::CallableExpr::Call(it) => it.indent_level(), ast::CallableExpr::MethodCall(it) => it.indent_level(), @@ -464,7 +482,7 @@ fn inline( body.reindent_to(original_indentation); match body.tail_expr() { - Some(expr) if body.statements().next().is_none() => expr, + Some(expr) if !is_async_fn && body.statements().next().is_none() => expr, _ => match node .syntax() .parent() @@ -1351,6 +1369,109 @@ fn main() { bar * b * a * 6 }; } +"#, + ); + } + + #[test] + fn async_fn_single_expression() { + cov_mark::check!(inline_call_async_fn); + check_assist( + inline_call, + r#" +async fn bar(x: u32) -> u32 { x + 1 } +async fn foo(arg: u32) -> u32 { + bar(arg).await * 2 +} +fn spawn(_: T) {} +fn main() { + spawn(foo$0(42)); +} +"#, + r#" +async fn bar(x: u32) -> u32 { x + 1 } +async fn foo(arg: u32) -> u32 { + bar(arg).await * 2 +} +fn spawn(_: T) {} +fn main() { + spawn(async move { + bar(42).await * 2 + }); +} +"#, + ); + } + + #[test] + fn async_fn_multiple_statements() { + cov_mark::check!(inline_call_async_fn); + check_assist( + inline_call, + r#" +async fn bar(x: u32) -> u32 { x + 1 } +async fn foo(arg: u32) -> u32 { + bar(arg).await; + 42 +} +fn spawn(_: T) {} +fn main() { + spawn(foo$0(42)); +} +"#, + r#" +async fn bar(x: u32) -> u32 { x + 1 } +async fn foo(arg: u32) -> u32 { + bar(arg).await; + 42 +} +fn spawn(_: T) {} +fn main() { + spawn(async move { + bar(42).await; + 42 + }); +} +"#, + ); + } + + #[test] + fn async_fn_with_let_statements() { + cov_mark::check!(inline_call_async_fn); + cov_mark::check!(inline_call_async_fn_with_let_stmts); + check_assist( + inline_call, + r#" +async fn bar(x: u32) -> u32 { x + 1 } +async fn foo(x: u32, y: u32, z: &u32) -> u32 { + bar(x).await; + y + y + *z +} +fn spawn(_: T) {} +fn main() { + let var = 42; + spawn(foo$0(var, var + 1, &var)); +} +"#, + r#" +async fn bar(x: u32) -> u32 { x + 1 } +async fn foo(x: u32, y: u32, z: &u32) -> u32 { + bar(x).await; + y + y + *z +} +fn spawn(_: T) {} +fn main() { + let var = 42; + spawn({ + let y = var + 1; + let z: &u32 = &var; + async move { + bar(var).await; + y + y + *z + } + }); +} "#, ); } diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs index 3c2b7e56b0..e435766000 100644 --- a/crates/syntax/src/ast/make.rs +++ b/crates/syntax/src/ast/make.rs @@ -447,6 +447,21 @@ pub fn block_expr( ast_from_text(&format!("fn f() {buf}")) } +pub fn async_move_block_expr( + stmts: impl IntoIterator, + tail_expr: Option, +) -> ast::BlockExpr { + let mut buf = "async move {\n".to_string(); + for stmt in stmts.into_iter() { + format_to!(buf, " {stmt}\n"); + } + if let Some(tail_expr) = tail_expr { + format_to!(buf, " {tail_expr}\n"); + } + buf += "}"; + ast_from_text(&format!("const _: () = {buf};")) +} + pub fn tail_only_block_expr(tail_expr: ast::Expr) -> ast::BlockExpr { ast_from_text(&format!("fn f() {{ {tail_expr} }}")) }