[unnecessary_fold]: suggest turbofish if necessary

This commit is contained in:
y21 2023-06-12 02:43:23 +02:00
parent 60258b061d
commit c9daec2585
4 changed files with 217 additions and 11 deletions

View file

@ -7,10 +7,74 @@ use rustc_errors::Applicability;
use rustc_hir as hir;
use rustc_hir::PatKind;
use rustc_lint::LateContext;
use rustc_middle::ty;
use rustc_span::{source_map::Span, sym};
use super::UNNECESSARY_FOLD;
/// No turbofish needed in any case.
fn no_turbofish(_: &LateContext<'_>, _: &hir::Expr<'_>) -> bool {
false
}
/// Turbofish (`::<T>`) may be needed, but can be omitted if we are certain
/// that the type can be inferred from usage.
fn turbofish_if_not_inferred(cx: &LateContext<'_>, expr: &hir::Expr<'_>) -> bool {
let parent = cx.tcx.hir().get_parent(expr.hir_id);
// some common cases where turbofish isn't needed:
// - assigned to a local variable with a type annotation
if let hir::Node::Local(local) = parent
&& local.ty.is_some()
{
return false;
}
// - part of a function call argument, can be inferred from the function signature (provided that
// the parameter is not a generic type parameter)
if let hir::Node::Expr(parent_expr) = parent
&& let hir::ExprKind::Call(recv, args) = parent_expr.kind
&& let hir::ExprKind::Path(ref qpath) = recv.kind
&& let Some(fn_def_id) = cx.qpath_res(qpath, recv.hir_id).opt_def_id()
&& let fn_sig = cx.tcx.fn_sig(fn_def_id).skip_binder().skip_binder()
&& let Some(arg_pos) = args.iter().position(|arg| arg.hir_id == expr.hir_id)
&& let Some(ty) = fn_sig.inputs().get(arg_pos)
&& !matches!(ty.kind(), ty::Param(_))
{
return false;
}
// if it's neither of those, stay on the safe side and suggest turbofish,
// even if it could work!
true
}
#[derive(Copy, Clone)]
struct Replacement {
method_name: &'static str,
has_args: bool,
requires_turbofish: fn(&LateContext<'_>, &hir::Expr<'_>) -> bool,
}
impl Replacement {
/// `any(f)`, `all(f)`
pub fn non_generic(method_name: &'static str) -> Self {
Self {
method_name,
has_args: true,
requires_turbofish: no_turbofish,
}
}
/// `sum::<T>()`, `product::<T>()`
pub fn generic(method_name: &'static str) -> Self {
Self {
method_name,
has_args: false,
requires_turbofish: turbofish_if_not_inferred,
}
}
}
pub(super) fn check(
cx: &LateContext<'_>,
expr: &hir::Expr<'_>,
@ -24,8 +88,7 @@ pub(super) fn check(
acc: &hir::Expr<'_>,
fold_span: Span,
op: hir::BinOpKind,
replacement_method_name: &str,
replacement_has_args: bool,
replacement: Replacement,
) {
if_chain! {
// Extract the body of the closure passed to fold
@ -43,18 +106,27 @@ pub(super) fn check(
if let PatKind::Binding(_, second_arg_id, second_arg_ident, _) = strip_pat_refs(param_b.pat).kind;
if path_to_local_id(left_expr, first_arg_id);
if replacement_has_args || path_to_local_id(right_expr, second_arg_id);
if replacement.has_args || path_to_local_id(right_expr, second_arg_id);
then {
let mut applicability = Applicability::MachineApplicable;
let sugg = if replacement_has_args {
let turbofish = if (replacement.requires_turbofish)(cx, expr) {
format!("::<{}>", cx.typeck_results().expr_ty_adjusted(right_expr).peel_refs())
} else {
String::new()
};
let sugg = if replacement.has_args {
format!(
"{replacement_method_name}(|{second_arg_ident}| {r})",
"{method}{turbofish}(|{second_arg_ident}| {r})",
method = replacement.method_name,
r = snippet_with_applicability(cx, right_expr.span, "EXPR", &mut applicability),
)
} else {
format!(
"{replacement_method_name}()",
"{method}{turbofish}()",
method = replacement.method_name,
)
};
@ -80,11 +152,43 @@ pub(super) fn check(
// Check if the first argument to .fold is a suitable literal
if let hir::ExprKind::Lit(lit) = init.kind {
match lit.node {
ast::LitKind::Bool(false) => check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Or, "any", true),
ast::LitKind::Bool(true) => check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::And, "all", true),
ast::LitKind::Int(0, _) => check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Add, "sum", false),
ast::LitKind::Bool(false) => {
check_fold_with_op(
cx,
expr,
acc,
fold_span,
hir::BinOpKind::Or,
Replacement::non_generic("any"),
);
},
ast::LitKind::Bool(true) => {
check_fold_with_op(
cx,
expr,
acc,
fold_span,
hir::BinOpKind::And,
Replacement::non_generic("all"),
);
},
ast::LitKind::Int(0, _) => check_fold_with_op(
cx,
expr,
acc,
fold_span,
hir::BinOpKind::Add,
Replacement::generic("sum"),
),
ast::LitKind::Int(1, _) => {
check_fold_with_op(cx, expr, acc, fold_span, hir::BinOpKind::Mul, "product", false);
check_fold_with_op(
cx,
expr,
acc,
fold_span,
hir::BinOpKind::Mul,
Replacement::generic("product"),
);
},
_ => (),
}

View file

@ -49,4 +49,28 @@ fn unnecessary_fold_over_multiple_lines() {
.any(|x| x > 2);
}
fn issue10000() {
use std::collections::HashMap;
use std::hash::BuildHasher;
fn anything<T>(_: T) {}
fn num(_: i32) {}
fn smoketest_map<S: BuildHasher>(mut map: HashMap<i32, i32, S>) {
map.insert(0, 0);
assert_eq!(map.values().sum::<i32>(), 0);
// more cases:
let _ = map.values().sum::<i32>();
let _ = map.values().product::<i32>();
let _: i32 = map.values().sum();
let _: i32 = map.values().product();
anything(map.values().sum::<i32>());
anything(map.values().product::<i32>());
num(map.values().sum());
num(map.values().product());
}
smoketest_map(HashMap::new());
}
fn main() {}

View file

@ -49,4 +49,28 @@ fn unnecessary_fold_over_multiple_lines() {
.fold(false, |acc, x| acc || x > 2);
}
fn issue10000() {
use std::collections::HashMap;
use std::hash::BuildHasher;
fn anything<T>(_: T) {}
fn num(_: i32) {}
fn smoketest_map<S: BuildHasher>(mut map: HashMap<i32, i32, S>) {
map.insert(0, 0);
assert_eq!(map.values().fold(0, |x, y| x + y), 0);
// more cases:
let _ = map.values().fold(0, |x, y| x + y);
let _ = map.values().fold(1, |x, y| x * y);
let _: i32 = map.values().fold(0, |x, y| x + y);
let _: i32 = map.values().fold(1, |x, y| x * y);
anything(map.values().fold(0, |x, y| x + y));
anything(map.values().fold(1, |x, y| x * y));
num(map.values().fold(0, |x, y| x + y));
num(map.values().fold(1, |x, y| x * y));
}
smoketest_map(HashMap::new());
}
fn main() {}

View file

@ -36,5 +36,59 @@ error: this `.fold` can be written more succinctly using another method
LL | .fold(false, |acc, x| acc || x > 2);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ help: try: `any(|x| x > 2)`
error: aborting due to 6 previous errors
error: this `.fold` can be written more succinctly using another method
--> $DIR/unnecessary_fold.rs:60:33
|
LL | assert_eq!(map.values().fold(0, |x, y| x + y), 0);
| ^^^^^^^^^^^^^^^^^^^^^ help: try: `sum::<i32>()`
error: this `.fold` can be written more succinctly using another method
--> $DIR/unnecessary_fold.rs:63:30
|
LL | let _ = map.values().fold(0, |x, y| x + y);
| ^^^^^^^^^^^^^^^^^^^^^ help: try: `sum::<i32>()`
error: this `.fold` can be written more succinctly using another method
--> $DIR/unnecessary_fold.rs:64:30
|
LL | let _ = map.values().fold(1, |x, y| x * y);
| ^^^^^^^^^^^^^^^^^^^^^ help: try: `product::<i32>()`
error: this `.fold` can be written more succinctly using another method
--> $DIR/unnecessary_fold.rs:65:35
|
LL | let _: i32 = map.values().fold(0, |x, y| x + y);
| ^^^^^^^^^^^^^^^^^^^^^ help: try: `sum()`
error: this `.fold` can be written more succinctly using another method
--> $DIR/unnecessary_fold.rs:66:35
|
LL | let _: i32 = map.values().fold(1, |x, y| x * y);
| ^^^^^^^^^^^^^^^^^^^^^ help: try: `product()`
error: this `.fold` can be written more succinctly using another method
--> $DIR/unnecessary_fold.rs:67:31
|
LL | anything(map.values().fold(0, |x, y| x + y));
| ^^^^^^^^^^^^^^^^^^^^^ help: try: `sum::<i32>()`
error: this `.fold` can be written more succinctly using another method
--> $DIR/unnecessary_fold.rs:68:31
|
LL | anything(map.values().fold(1, |x, y| x * y));
| ^^^^^^^^^^^^^^^^^^^^^ help: try: `product::<i32>()`
error: this `.fold` can be written more succinctly using another method
--> $DIR/unnecessary_fold.rs:69:26
|
LL | num(map.values().fold(0, |x, y| x + y));
| ^^^^^^^^^^^^^^^^^^^^^ help: try: `sum()`
error: this `.fold` can be written more succinctly using another method
--> $DIR/unnecessary_fold.rs:70:26
|
LL | num(map.values().fold(1, |x, y| x * y));
| ^^^^^^^^^^^^^^^^^^^^^ help: try: `product()`
error: aborting due to 15 previous errors