refactor remove trailing return diagnostic

This commit is contained in:
davidsemakula 2024-01-31 15:04:41 +03:00
parent 98e6f43a2f
commit a250c2dde0
2 changed files with 24 additions and 32 deletions

View file

@ -11,7 +11,7 @@ use cfg::{CfgExpr, CfgOptions};
use either::Either;
use hir_def::{body::SyntheticSyntax, hir::ExprOrPatId, path::ModPath, AssocItemId, DefWithBodyId};
use hir_expand::{name::Name, HirFileId, InFile};
use syntax::{ast, AstNode, AstPtr, SyntaxError, SyntaxNodePtr, TextRange};
use syntax::{ast, AstPtr, SyntaxError, SyntaxNodePtr, TextRange};
use crate::{AssocItem, Field, Local, MacroKind, Trait, Type};
@ -346,8 +346,7 @@ pub struct TraitImplRedundantAssocItems {
#[derive(Debug)]
pub struct RemoveTrailingReturn {
pub file_id: HirFileId,
pub return_expr: AstPtr<ast::Expr>,
pub return_expr: InFile<AstPtr<ast::ReturnExpr>>,
}
#[derive(Debug)]
@ -460,11 +459,10 @@ impl AnyDiagnostic {
BodyValidationDiagnostic::RemoveTrailingReturn { return_expr } => {
if let Ok(source_ptr) = source_map.expr_syntax(return_expr) {
// Filters out desugared return expressions (e.g. desugared try operators).
if ast::ReturnExpr::can_cast(source_ptr.value.kind()) {
if let Some(ptr) = source_ptr.value.cast::<ast::ReturnExpr>() {
return Some(
RemoveTrailingReturn {
file_id: source_ptr.file_id,
return_expr: source_ptr.value,
return_expr: InFile::new(source_ptr.file_id, ptr),
}
.into(),
);

View file

@ -1,9 +1,9 @@
use hir::{db::ExpandDatabase, diagnostics::RemoveTrailingReturn, HirFileIdExt, InFile};
use ide_db::{assists::Assist, source_change::SourceChange};
use syntax::{ast, AstNode, SyntaxNodePtr};
use hir::{db::ExpandDatabase, diagnostics::RemoveTrailingReturn};
use ide_db::{assists::Assist, base_db::FileRange, source_change::SourceChange};
use syntax::{ast, AstNode};
use text_edit::TextEdit;
use crate::{fix, Diagnostic, DiagnosticCode, DiagnosticsContext};
use crate::{adjusted_display_range, fix, Diagnostic, DiagnosticCode, DiagnosticsContext};
// Diagnostic: remove-trailing-return
//
@ -13,12 +13,12 @@ pub(crate) fn remove_trailing_return(
ctx: &DiagnosticsContext<'_>,
d: &RemoveTrailingReturn,
) -> Diagnostic {
let display_range = ctx.sema.diagnostics_display_range(InFile {
file_id: d.file_id,
value: expr_stmt(ctx, d)
.as_ref()
.map(|stmt| SyntaxNodePtr::new(stmt.syntax()))
.unwrap_or_else(|| d.return_expr.into()),
let display_range = adjusted_display_range(ctx, d.return_expr, &|return_expr| {
return_expr
.syntax()
.parent()
.and_then(ast::ExprStmt::cast)
.map(|stmt| stmt.syntax().text_range())
});
Diagnostic::new(
DiagnosticCode::Clippy("needless_return"),
@ -29,15 +29,20 @@ pub(crate) fn remove_trailing_return(
}
fn fixes(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option<Vec<Assist>> {
let return_expr = return_expr(ctx, d)?;
let stmt = expr_stmt(ctx, d);
let root = ctx.sema.db.parse_or_expand(d.return_expr.file_id);
let return_expr = d.return_expr.value.to_node(&root);
let stmt = return_expr.syntax().parent().and_then(ast::ExprStmt::cast);
let FileRange { range, file_id } =
ctx.sema.original_range_opt(stmt.as_ref().map_or(return_expr.syntax(), AstNode::syntax))?;
if Some(file_id) != d.return_expr.file_id.file_id() {
return None;
}
let range = stmt.as_ref().map_or(return_expr.syntax(), AstNode::syntax).text_range();
let replacement =
return_expr.expr().map_or_else(String::new, |expr| format!("{}", expr.syntax().text()));
let edit = TextEdit::replace(range, replacement);
let source_change = SourceChange::from_text_edit(d.file_id.original_file(ctx.sema.db), edit);
let source_change = SourceChange::from_text_edit(file_id, edit);
Some(vec![fix(
"remove_trailing_return",
@ -47,17 +52,6 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option<Vec<A
)])
}
fn return_expr(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option<ast::ReturnExpr> {
let root = ctx.sema.db.parse_or_expand(d.file_id);
let expr = d.return_expr.to_node(&root);
ast::ReturnExpr::cast(expr.syntax().clone())
}
fn expr_stmt(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option<ast::ExprStmt> {
let return_expr = return_expr(ctx, d)?;
return_expr.syntax().parent().and_then(ast::ExprStmt::cast)
}
#[cfg(test)]
mod tests {
use crate::tests::{check_diagnostics, check_fix};