Auto merge of #14667 - unexge:nested-types-in-unwrap-result-type, r=HKalbasi

Handle nested types in `unwrap_result_return_type` assist

Fixes https://github.com/rust-lang/rust-analyzer/issues/14496
This commit is contained in:
bors 2023-04-27 06:38:36 +00:00
commit 237ffa3997

View file

@ -5,7 +5,7 @@ use ide_db::{
use itertools::Itertools;
use syntax::{
ast::{self, Expr},
match_ast, AstNode, TextRange, TextSize,
match_ast, AstNode, NodeOrToken, SyntaxKind, TextRange,
};
use crate::{AssistContext, AssistId, AssistKind, Assists};
@ -38,14 +38,15 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
};
let type_ref = &ret_type.ty()?;
let ty = ctx.sema.resolve_type(type_ref)?.as_adt();
let Some(hir::Adt::Enum(ret_enum)) = ctx.sema.resolve_type(type_ref)?.as_adt() else { return None; };
let result_enum =
FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate()).core_result_Result()?;
if !matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) {
if ret_enum != result_enum {
return None;
}
let Some(ok_type) = unwrap_result_type(type_ref) else { return None; };
acc.add(
AssistId("unwrap_result_return_type", AssistKind::RefactorRewrite),
"Unwrap Result return type",
@ -64,26 +65,19 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<'
});
for_each_tail_expr(&body, tail_cb);
let mut is_unit_type = false;
if let Some((_, inner_type)) = type_ref.to_string().split_once('<') {
let inner_type = match inner_type.split_once(',') {
Some((success_inner_type, _)) => success_inner_type,
None => inner_type,
};
let new_ret_type = inner_type.strip_suffix('>').unwrap_or(inner_type);
if new_ret_type == "()" {
is_unit_type = true;
let text_range = TextRange::new(
ret_type.syntax().text_range().start(),
ret_type.syntax().text_range().end() + TextSize::from(1u32),
);
builder.delete(text_range)
} else {
builder.replace(
type_ref.syntax().text_range(),
inner_type.strip_suffix('>').unwrap_or(inner_type),
)
let is_unit_type = is_unit_type(&ok_type);
if is_unit_type {
let mut text_range = ret_type.syntax().text_range();
if let Some(NodeOrToken::Token(token)) = ret_type.syntax().next_sibling_or_token() {
if token.kind() == SyntaxKind::WHITESPACE {
text_range = TextRange::new(text_range.start(), token.text_range().end());
}
}
builder.delete(text_range);
} else {
builder.replace(type_ref.syntax().text_range(), ok_type.syntax().text());
}
for ret_expr_arg in exprs_to_unwrap {
@ -134,6 +128,22 @@ fn tail_cb_impl(acc: &mut Vec<ast::Expr>, e: &ast::Expr) {
}
}
// Tries to extract `T` from `Result<T, E>`.
fn unwrap_result_type(ty: &ast::Type) -> Option<ast::Type> {
let ast::Type::PathType(path_ty) = ty else { return None; };
let path = path_ty.path()?;
let segment = path.first_segment()?;
let generic_arg_list = segment.generic_arg_list()?;
let generic_args: Vec<_> = generic_arg_list.generic_args().collect();
let ast::GenericArg::TypeArg(ok_type) = generic_args.first()? else { return None; };
ok_type.ty()
}
fn is_unit_type(ty: &ast::Type) -> bool {
let ast::Type::TupleType(tuple) = ty else { return false };
tuple.fields().next().is_none()
}
#[cfg(test)]
mod tests {
use crate::tests::{check_assist, check_assist_not_applicable};
@ -173,6 +183,21 @@ fn foo() -> Result<(), Box<dyn Error$0>> {
r#"
fn foo() {
}
"#,
);
// Unformatted return type
check_assist(
unwrap_result_return_type,
r#"
//- minicore: result
fn foo() -> Result<(), Box<dyn Error$0>>{
Ok(())
}
"#,
r#"
fn foo() {
}
"#,
);
}
@ -1014,6 +1039,54 @@ fn foo(the_field: u32) -> u32 {
}
the_field
}
"#,
);
}
#[test]
fn unwrap_result_return_type_nested_type() {
check_assist(
unwrap_result_return_type,
r#"
//- minicore: result, option
fn foo() -> Result<Option<i32$0>, ()> {
Ok(Some(42))
}
"#,
r#"
fn foo() -> Option<i32> {
Some(42)
}
"#,
);
check_assist(
unwrap_result_return_type,
r#"
//- minicore: result, option
fn foo() -> Result<Option<Result<i32$0, ()>>, ()> {
Ok(None)
}
"#,
r#"
fn foo() -> Option<Result<i32, ()>> {
None
}
"#,
);
check_assist(
unwrap_result_return_type,
r#"
//- minicore: result, option, iterators
fn foo() -> Result<impl Iterator<Item = i32>$0, ()> {
Ok(Some(42).into_iter())
}
"#,
r#"
fn foo() -> impl Iterator<Item = i32> {
Some(42).into_iter()
}
"#,
);
}