diff --git a/clippy_lints/src/derive.rs b/clippy_lints/src/derive.rs index fe99f4a8d..99347ebad 100644 --- a/clippy_lints/src/derive.rs +++ b/clippy_lints/src/derive.rs @@ -1,16 +1,17 @@ use clippy_utils::diagnostics::{span_lint_and_help, span_lint_and_note, span_lint_and_sugg, span_lint_and_then}; use clippy_utils::paths; -use clippy_utils::ty::{implements_trait, is_copy}; +use clippy_utils::ty::{implements_trait, implements_trait_with_env, is_copy}; use clippy_utils::{is_lint_allowed, match_def_path}; use if_chain::if_chain; use rustc_errors::Applicability; use rustc_hir::intravisit::{walk_expr, walk_fn, walk_item, FnKind, Visitor}; use rustc_hir::{ - BlockCheckMode, BodyId, Expr, ExprKind, FnDecl, HirId, Impl, Item, ItemKind, TraitRef, UnsafeSource, Unsafety, + self as hir, BlockCheckMode, BodyId, Expr, ExprKind, FnDecl, HirId, Impl, Item, ItemKind, UnsafeSource, Unsafety, }; use rustc_lint::{LateContext, LateLintPass}; use rustc_middle::hir::nested_filter; -use rustc_middle::ty::{self, Ty}; +use rustc_middle::ty::subst::GenericArg; +use rustc_middle::ty::{self, BoundConstness, ImplPolarity, ParamEnv, PredicateKind, TraitPredicate, TraitRef, Ty}; use rustc_session::{declare_lint_pass, declare_tool_lint}; use rustc_span::source_map::Span; use rustc_span::sym; @@ -224,7 +225,7 @@ impl<'tcx> LateLintPass<'tcx> for Derive { fn check_hash_peq<'tcx>( cx: &LateContext<'tcx>, span: Span, - trait_ref: &TraitRef<'_>, + trait_ref: &hir::TraitRef<'_>, ty: Ty<'tcx>, hash_is_automatically_derived: bool, ) { @@ -277,7 +278,7 @@ fn check_hash_peq<'tcx>( fn check_ord_partial_ord<'tcx>( cx: &LateContext<'tcx>, span: Span, - trait_ref: &TraitRef<'_>, + trait_ref: &hir::TraitRef<'_>, ty: Ty<'tcx>, ord_is_automatically_derived: bool, ) { @@ -328,7 +329,7 @@ fn check_ord_partial_ord<'tcx>( } /// Implementation of the `EXPL_IMPL_CLONE_ON_COPY` lint. -fn check_copy_clone<'tcx>(cx: &LateContext<'tcx>, item: &Item<'_>, trait_ref: &TraitRef<'_>, ty: Ty<'tcx>) { +fn check_copy_clone<'tcx>(cx: &LateContext<'tcx>, item: &Item<'_>, trait_ref: &hir::TraitRef<'_>, ty: Ty<'tcx>) { let clone_id = match cx.tcx.lang_items().clone_trait() { Some(id) if trait_ref.trait_def_id() == Some(id) => id, _ => return, @@ -378,7 +379,7 @@ fn check_copy_clone<'tcx>(cx: &LateContext<'tcx>, item: &Item<'_>, trait_ref: &T fn check_unsafe_derive_deserialize<'tcx>( cx: &LateContext<'tcx>, item: &Item<'_>, - trait_ref: &TraitRef<'_>, + trait_ref: &hir::TraitRef<'_>, ty: Ty<'tcx>, ) { fn has_unsafe<'tcx>(cx: &LateContext<'tcx>, item: &'tcx Item<'_>) -> bool { @@ -455,13 +456,41 @@ impl<'tcx> Visitor<'tcx> for UnsafeVisitor<'_, 'tcx> { } /// Implementation of the `DERIVE_PARTIAL_EQ_WITHOUT_EQ` lint. -fn check_partial_eq_without_eq<'tcx>(cx: &LateContext<'tcx>, span: Span, trait_ref: &TraitRef<'_>, ty: Ty<'tcx>) { +fn check_partial_eq_without_eq<'tcx>(cx: &LateContext<'tcx>, span: Span, trait_ref: &hir::TraitRef<'_>, ty: Ty<'tcx>) { if_chain! { if let ty::Adt(adt, substs) = ty.kind(); if let Some(eq_trait_def_id) = cx.tcx.get_diagnostic_item(sym::Eq); + if let Some(peq_trait_def_id) = cx.tcx.get_diagnostic_item(sym::PartialEq); if let Some(def_id) = trait_ref.trait_def_id(); if cx.tcx.is_diagnostic_item(sym::PartialEq, def_id); - if !implements_trait(cx, ty, eq_trait_def_id, substs); + // New `ParamEnv` replacing `T: PartialEq` with `T: Eq` + let param_env = ParamEnv::new( + cx.tcx.mk_predicates(cx.param_env.caller_bounds().iter().map(|p| { + let kind = p.kind(); + match kind.skip_binder() { + PredicateKind::Trait(p) + if p.trait_ref.def_id == peq_trait_def_id + && p.trait_ref.substs.get(0) == p.trait_ref.substs.get(1) + && matches!(p.trait_ref.self_ty().kind(), ty::Param(_)) + && p.constness == BoundConstness::NotConst + && p.polarity == ImplPolarity::Positive => + { + cx.tcx.mk_predicate(kind.rebind(PredicateKind::Trait(TraitPredicate { + trait_ref: TraitRef::new( + eq_trait_def_id, + cx.tcx.mk_substs([GenericArg::from(p.trait_ref.self_ty())].into_iter()), + ), + constness: BoundConstness::NotConst, + polarity: ImplPolarity::Positive, + }))) + }, + _ => p, + } + })), + cx.param_env.reveal(), + cx.param_env.constness(), + ); + if !implements_trait_with_env(cx.tcx, param_env, ty, eq_trait_def_id, substs); then { // If all of our fields implement `Eq`, we can implement `Eq` too for variant in adt.variants() { diff --git a/clippy_utils/src/ty.rs b/clippy_utils/src/ty.rs index 07d3d2807..203f33d35 100644 --- a/clippy_utils/src/ty.rs +++ b/clippy_utils/src/ty.rs @@ -13,7 +13,8 @@ use rustc_lint::LateContext; use rustc_middle::mir::interpret::{ConstValue, Scalar}; use rustc_middle::ty::subst::{GenericArg, GenericArgKind, Subst}; use rustc_middle::ty::{ - self, AdtDef, Binder, FnSig, IntTy, Predicate, PredicateKind, Ty, TyCtxt, TypeFoldable, UintTy, VariantDiscr, + self, AdtDef, Binder, FnSig, IntTy, ParamEnv, Predicate, PredicateKind, Ty, TyCtxt, TypeFoldable, UintTy, + VariantDiscr, }; use rustc_span::symbol::Ident; use rustc_span::{sym, Span, Symbol, DUMMY_SP}; @@ -151,18 +152,29 @@ pub fn implements_trait<'tcx>( ty: Ty<'tcx>, trait_id: DefId, ty_params: &[GenericArg<'tcx>], +) -> bool { + implements_trait_with_env(cx.tcx, cx.param_env, ty, trait_id, ty_params) +} + +/// Same as `implements_trait` but allows using a `ParamEnv` different from the lint context. +pub fn implements_trait_with_env<'tcx>( + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + ty: Ty<'tcx>, + trait_id: DefId, + ty_params: &[GenericArg<'tcx>], ) -> bool { // Clippy shouldn't have infer types assert!(!ty.needs_infer()); - let ty = cx.tcx.erase_regions(ty); + let ty = tcx.erase_regions(ty); if ty.has_escaping_bound_vars() { return false; } - let ty_params = cx.tcx.mk_substs(ty_params.iter()); - cx.tcx.infer_ctxt().enter(|infcx| { + let ty_params = tcx.mk_substs(ty_params.iter()); + tcx.infer_ctxt().enter(|infcx| { infcx - .type_implements_trait(trait_id, ty, ty_params, cx.param_env) + .type_implements_trait(trait_id, ty, ty_params, param_env) .must_apply_modulo_regions() }) } diff --git a/tests/ui/derive_partial_eq_without_eq.fixed b/tests/ui/derive_partial_eq_without_eq.fixed index 7d4d1b3b6..012780258 100644 --- a/tests/ui/derive_partial_eq_without_eq.fixed +++ b/tests/ui/derive_partial_eq_without_eq.fixed @@ -95,4 +95,10 @@ enum EnumNotEq { #[derive(Debug, PartialEq, Eq, Clone)] struct RustFixWithOtherDerives; +#[derive(PartialEq)] +struct Generic(T); + +#[derive(PartialEq, Eq)] +struct GenericPhantom(core::marker::PhantomData); + fn main() {} diff --git a/tests/ui/derive_partial_eq_without_eq.rs b/tests/ui/derive_partial_eq_without_eq.rs index ab4e1df1c..fc8285b0c 100644 --- a/tests/ui/derive_partial_eq_without_eq.rs +++ b/tests/ui/derive_partial_eq_without_eq.rs @@ -95,4 +95,10 @@ enum EnumNotEq { #[derive(Debug, PartialEq, Clone)] struct RustFixWithOtherDerives; +#[derive(PartialEq)] +struct Generic(T); + +#[derive(PartialEq, Eq)] +struct GenericPhantom(core::marker::PhantomData); + fn main() {}