From 7c17dd331bb5708152ef2c2dd91c67287f9d821d Mon Sep 17 00:00:00 2001 From: Jason Newcomb Date: Sat, 15 Jun 2024 02:13:32 -0400 Subject: [PATCH] `needless_for_each`: Check HIR tree first. --- clippy_lints/src/needless_for_each.rs | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/clippy_lints/src/needless_for_each.rs b/clippy_lints/src/needless_for_each.rs index 143acc2b1..6390e51f9 100644 --- a/clippy_lints/src/needless_for_each.rs +++ b/clippy_lints/src/needless_for_each.rs @@ -3,7 +3,7 @@ use rustc_hir::intravisit::{walk_expr, Visitor}; use rustc_hir::{Block, BlockCheckMode, Closure, Expr, ExprKind, Stmt, StmtKind}; use rustc_lint::{LateContext, LateLintPass}; use rustc_session::declare_lint_pass; -use rustc_span::{sym, Span, Symbol}; +use rustc_span::{sym, Span}; use clippy_utils::diagnostics::span_lint_and_then; use clippy_utils::is_trait_method; @@ -55,16 +55,8 @@ declare_lint_pass!(NeedlessForEach => [NEEDLESS_FOR_EACH]); impl<'tcx> LateLintPass<'tcx> for NeedlessForEach { fn check_stmt(&mut self, cx: &LateContext<'tcx>, stmt: &'tcx Stmt<'_>) { - let (StmtKind::Expr(expr) | StmtKind::Semi(expr)) = stmt.kind else { - return; - }; - - if let ExprKind::MethodCall(method_name, for_each_recv, [for_each_arg], _) = expr.kind - // Check the method name is `for_each`. - && method_name.ident.name == Symbol::intern("for_each") - // Check `for_each` is an associated function of `Iterator`. - && is_trait_method(cx, expr, sym::Iterator) - // Checks the receiver of `for_each` is also a method call. + if let StmtKind::Expr(expr) | StmtKind::Semi(expr) = stmt.kind + && let ExprKind::MethodCall(method_name, for_each_recv, [for_each_arg], _) = expr.kind && let ExprKind::MethodCall(_, iter_recv, [], _) = for_each_recv.kind // Skip the lint if the call chain is too long. e.g. `v.field.iter().for_each()` or // `v.foo().iter().for_each()` must be skipped. @@ -72,6 +64,8 @@ impl<'tcx> LateLintPass<'tcx> for NeedlessForEach { iter_recv.kind, ExprKind::Array(..) | ExprKind::Call(..) | ExprKind::Path(..) ) + && method_name.ident.name.as_str() == "for_each" + && is_trait_method(cx, expr, sym::Iterator) // Checks the type of the `iter` method receiver is NOT a user defined type. && has_iter_method(cx, cx.typeck_results().expr_ty(iter_recv)).is_some() // Skip the lint if the body is not block because this is simpler than `for` loop.