Infer return type for async function in generate_function

This commit is contained in:
Ryo Yoshida 2023-06-08 23:47:21 +09:00
parent 1c25885bd2
commit 32768fe310
No known key found for this signature in database
GPG key ID: E25698A930586171

View file

@ -291,12 +291,9 @@ impl FunctionBuilder {
let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast); let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast);
let is_async = await_expr.is_some(); let is_async = await_expr.is_some();
let (ret_type, should_focus_return_type) = make_return_type( let expr_for_ret_ty = await_expr.map_or_else(|| call.clone().into(), |it| it.into());
ctx, let (ret_type, should_focus_return_type) =
&ast::Expr::CallExpr(call.clone()), make_return_type(ctx, &expr_for_ret_ty, target_module, &mut necessary_generic_params);
target_module,
&mut necessary_generic_params,
);
let (generic_param_list, where_clause) = let (generic_param_list, where_clause) =
fn_generic_params(ctx, necessary_generic_params, &target)?; fn_generic_params(ctx, necessary_generic_params, &target)?;
@ -338,12 +335,9 @@ impl FunctionBuilder {
let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast); let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast);
let is_async = await_expr.is_some(); let is_async = await_expr.is_some();
let (ret_type, should_focus_return_type) = make_return_type( let expr_for_ret_ty = await_expr.map_or_else(|| call.clone().into(), |it| it.into());
ctx, let (ret_type, should_focus_return_type) =
&ast::Expr::MethodCallExpr(call.clone()), make_return_type(ctx, &expr_for_ret_ty, target_module, &mut necessary_generic_params);
target_module,
&mut necessary_generic_params,
);
let (generic_param_list, where_clause) = let (generic_param_list, where_clause) =
fn_generic_params(ctx, necessary_generic_params, &target)?; fn_generic_params(ctx, necessary_generic_params, &target)?;
@ -429,12 +423,12 @@ impl FunctionBuilder {
/// user can change the `todo!` function body. /// user can change the `todo!` function body.
fn make_return_type( fn make_return_type(
ctx: &AssistContext<'_>, ctx: &AssistContext<'_>,
call: &ast::Expr, expr: &ast::Expr,
target_module: Module, target_module: Module,
necessary_generic_params: &mut FxHashSet<hir::GenericParam>, necessary_generic_params: &mut FxHashSet<hir::GenericParam>,
) -> (Option<ast::RetType>, bool) { ) -> (Option<ast::RetType>, bool) {
let (ret_ty, should_focus_return_type) = { let (ret_ty, should_focus_return_type) = {
match ctx.sema.type_of_expr(call).map(TypeInfo::original) { match ctx.sema.type_of_expr(expr).map(TypeInfo::original) {
Some(ty) if ty.is_unknown() => (Some(make::ty_placeholder()), true), Some(ty) if ty.is_unknown() => (Some(make::ty_placeholder()), true),
None => (Some(make::ty_placeholder()), true), None => (Some(make::ty_placeholder()), true),
Some(ty) if ty.is_unit() => (None, false), Some(ty) if ty.is_unit() => (None, false),
@ -2268,13 +2262,13 @@ impl Foo {
check_assist( check_assist(
generate_function, generate_function,
r" r"
fn foo() { async fn foo() {
$0bar(42).await(); $0bar(42).await;
} }
", ",
r" r"
fn foo() { async fn foo() {
bar(42).await(); bar(42).await;
} }
async fn bar(arg: i32) ${0:-> _} { async fn bar(arg: i32) ${0:-> _} {
@ -2284,6 +2278,28 @@ async fn bar(arg: i32) ${0:-> _} {
) )
} }
#[test]
fn return_type_for_async_fn() {
check_assist(
generate_function,
r"
//- minicore: result
async fn foo() {
if Err(()) = $0bar(42).await {}
}
",
r"
async fn foo() {
if Err(()) = bar(42).await {}
}
async fn bar(arg: i32) -> Result<_, ()> {
${0:todo!()}
}
",
);
}
#[test] #[test]
fn create_method() { fn create_method() {
check_assist( check_assist(
@ -2401,6 +2417,31 @@ fn foo() {S.bar();}
) )
} }
#[test]
fn create_async_method() {
check_assist(
generate_function,
r"
//- minicore: result
struct S;
async fn foo() {
if let Err(()) = S.$0bar(42).await {}
}
",
r"
struct S;
impl S {
async fn bar(&self, arg: i32) -> Result<_, ()> {
${0:todo!()}
}
}
async fn foo() {
if let Err(()) = S.bar(42).await {}
}
",
)
}
#[test] #[test]
fn create_static_method() { fn create_static_method() {
check_assist( check_assist(
@ -2421,6 +2462,31 @@ fn foo() {S::bar();}
) )
} }
#[test]
fn create_async_static_method() {
check_assist(
generate_function,
r"
//- minicore: result
struct S;
async fn foo() {
if let Err(()) = S::$0bar(42).await {}
}
",
r"
struct S;
impl S {
async fn bar(arg: i32) -> Result<_, ()> {
${0:todo!()}
}
}
async fn foo() {
if let Err(()) = S::bar(42).await {}
}
",
)
}
#[test] #[test]
fn create_generic_static_method() { fn create_generic_static_method() {
check_assist( check_assist(