diff --git a/clippy_lints/src/eta_reduction.rs b/clippy_lints/src/eta_reduction.rs index 8d066f305..667eb8eb2 100644 --- a/clippy_lints/src/eta_reduction.rs +++ b/clippy_lints/src/eta_reduction.rs @@ -1,15 +1,16 @@ use clippy_utils::diagnostics::{span_lint_and_sugg, span_lint_and_then}; -use clippy_utils::higher; use clippy_utils::higher::VecArgs; use clippy_utils::source::snippet_opt; use clippy_utils::ty::{implements_trait, type_is_unsafe_function}; +use clippy_utils::usage::UsedAfterExprVisitor; +use clippy_utils::{get_enclosing_loop_or_closure, higher}; use clippy_utils::{is_adjusted, iter_input_pats}; use if_chain::if_chain; use rustc_errors::Applicability; use rustc_hir::{def_id, Expr, ExprKind, Param, PatKind, QPath}; use rustc_lint::{LateContext, LateLintPass, LintContext}; use rustc_middle::lint::in_external_macro; -use rustc_middle::ty::{self, Ty}; +use rustc_middle::ty::{self, ClosureKind, Ty}; use rustc_session::{declare_lint_pass, declare_tool_lint}; declare_clippy_lint! { @@ -86,7 +87,7 @@ impl<'tcx> LateLintPass<'tcx> for EtaReduction { } } -fn check_closure(cx: &LateContext<'_>, expr: &Expr<'_>) { +fn check_closure<'tcx>(cx: &LateContext<'tcx>, expr: &'tcx Expr<'_>) { if let ExprKind::Closure(_, decl, eid, _, _) = expr.kind { let body = cx.tcx.hir().body(eid); let ex = &body.value; @@ -131,7 +132,18 @@ fn check_closure(cx: &LateContext<'_>, expr: &Expr<'_>) { then { span_lint_and_then(cx, REDUNDANT_CLOSURE, expr.span, "redundant closure", |diag| { - if let Some(snippet) = snippet_opt(cx, caller.span) { + if let Some(mut snippet) = snippet_opt(cx, caller.span) { + if_chain! { + if let ty::Closure(_, substs) = fn_ty.kind(); + if let ClosureKind::FnMut = substs.as_closure().kind(); + if UsedAfterExprVisitor::is_found(cx, caller) + || get_enclosing_loop_or_closure(cx.tcx, expr).is_some(); + + then { + // Mutable closure is used after current expr; we cannot consume it. + snippet = format!("&mut {}", snippet); + } + } diag.span_suggestion( expr.span, "replace the closure with the function itself", diff --git a/clippy_utils/src/usage.rs b/clippy_utils/src/usage.rs index 2c55021ac..182d8cb11 100644 --- a/clippy_utils/src/usage.rs +++ b/clippy_utils/src/usage.rs @@ -199,3 +199,50 @@ pub fn contains_return_break_continue_macro(expression: &Expr<'_>) -> bool { recursive_visitor.visit_expr(expression); recursive_visitor.seen_return_break_continue } + +pub struct UsedAfterExprVisitor<'a, 'tcx> { + cx: &'a LateContext<'tcx>, + expr: &'tcx Expr<'tcx>, + definition: HirId, + past_expr: bool, + used_after_expr: bool, +} +impl<'a, 'tcx> UsedAfterExprVisitor<'a, 'tcx> { + pub fn is_found(cx: &'a LateContext<'tcx>, expr: &'tcx Expr<'_>) -> bool { + utils::path_to_local(expr).map_or(false, |definition| { + let mut visitor = UsedAfterExprVisitor { + cx, + expr, + definition, + past_expr: false, + used_after_expr: false, + }; + utils::get_enclosing_block(cx, definition).map_or(false, |block| { + visitor.visit_block(block); + visitor.used_after_expr + }) + }) + } +} + +impl<'a, 'tcx> intravisit::Visitor<'tcx> for UsedAfterExprVisitor<'a, 'tcx> { + type Map = Map<'tcx>; + + fn nested_visit_map(&mut self) -> NestedVisitorMap { + NestedVisitorMap::OnlyBodies(self.cx.tcx.hir()) + } + + fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) { + if self.used_after_expr { + return; + } + + if expr.hir_id == self.expr.hir_id { + self.past_expr = true; + } else if self.past_expr && utils::path_to_local_id(expr, self.definition) { + self.used_after_expr = true; + } else { + intravisit::walk_expr(self, expr); + } + } +} diff --git a/tests/ui/eta.fixed b/tests/ui/eta.fixed index 9e752311c..91b837f9a 100644 --- a/tests/ui/eta.fixed +++ b/tests/ui/eta.fixed @@ -220,3 +220,19 @@ impl std::ops::Deref for Bar { fn test_deref_with_trait_method() { let _ = [Bar].iter().map(|s| s.to_string()).collect::>(); } + +fn mutable_closure_used_again(x: Vec, y: Vec, z: Vec) { + let mut res = Vec::new(); + let mut add_to_res = |n| res.push(n); + x.into_iter().for_each(&mut add_to_res); + y.into_iter().for_each(&mut add_to_res); + z.into_iter().for_each(add_to_res); +} + +fn mutable_closure_in_loop() { + let mut value = 0; + let mut closure = |n| value += n; + for _ in 0..5 { + Some(1).map(&mut closure); + } +} diff --git a/tests/ui/eta.rs b/tests/ui/eta.rs index 44be4628c..1b5370028 100644 --- a/tests/ui/eta.rs +++ b/tests/ui/eta.rs @@ -220,3 +220,19 @@ impl std::ops::Deref for Bar { fn test_deref_with_trait_method() { let _ = [Bar].iter().map(|s| s.to_string()).collect::>(); } + +fn mutable_closure_used_again(x: Vec, y: Vec, z: Vec) { + let mut res = Vec::new(); + let mut add_to_res = |n| res.push(n); + x.into_iter().for_each(|x| add_to_res(x)); + y.into_iter().for_each(|x| add_to_res(x)); + z.into_iter().for_each(|x| add_to_res(x)); +} + +fn mutable_closure_in_loop() { + let mut value = 0; + let mut closure = |n| value += n; + for _ in 0..5 { + Some(1).map(|n| closure(n)); + } +} diff --git a/tests/ui/eta.stderr b/tests/ui/eta.stderr index 8795d3b42..28da89413 100644 --- a/tests/ui/eta.stderr +++ b/tests/ui/eta.stderr @@ -82,5 +82,29 @@ error: redundant closure LL | let a = Some(1u8).map(|a| closure(a)); | ^^^^^^^^^^^^^^ help: replace the closure with the function itself: `closure` -error: aborting due to 13 previous errors +error: redundant closure + --> $DIR/eta.rs:227:28 + | +LL | x.into_iter().for_each(|x| add_to_res(x)); + | ^^^^^^^^^^^^^^^^^ help: replace the closure with the function itself: `&mut add_to_res` + +error: redundant closure + --> $DIR/eta.rs:228:28 + | +LL | y.into_iter().for_each(|x| add_to_res(x)); + | ^^^^^^^^^^^^^^^^^ help: replace the closure with the function itself: `&mut add_to_res` + +error: redundant closure + --> $DIR/eta.rs:229:28 + | +LL | z.into_iter().for_each(|x| add_to_res(x)); + | ^^^^^^^^^^^^^^^^^ help: replace the closure with the function itself: `add_to_res` + +error: redundant closure + --> $DIR/eta.rs:236:21 + | +LL | Some(1).map(|n| closure(n)); + | ^^^^^^^^^^^^^^ help: replace the closure with the function itself: `&mut closure` + +error: aborting due to 17 previous errors