This commit is contained in:
Caio 2023-06-17 17:03:31 -03:00
parent 8c8ff5f31d
commit 0e1caa765e
9 changed files with 114 additions and 72 deletions

View file

@ -51,7 +51,7 @@ impl<'tcx> LateLintPass<'tcx> for UnportableVariant {
.const_eval_poly(def_id.to_def_id())
.ok()
.map(|val| rustc_middle::mir::ConstantKind::from_value(val, ty));
if let Some(Constant::Int(val)) = constant.and_then(|c| miri_to_const(cx.tcx, c)) {
if let Some(Constant::Int(val)) = constant.and_then(|c| miri_to_const(cx, c)) {
if let ty::Adt(adt, _) = ty.kind() {
if adt.is_enum() {
ty = adt.repr().discr_type().to_ty(cx.tcx);

View file

@ -215,7 +215,7 @@ fn check_ln1p(cx: &LateContext<'_>, expr: &Expr<'_>, receiver: &Expr<'_>) {
// ranges [-16777215, 16777216) for type f32 as whole number floats outside
// this range are lossy and ambiguous.
#[expect(clippy::cast_possible_truncation)]
fn get_integer_from_float_constant(value: &Constant) -> Option<i32> {
fn get_integer_from_float_constant(value: &Constant<'_>) -> Option<i32> {
match value {
F32(num) if num.fract() == 0.0 => {
if (-16_777_215.0..16_777_216.0).contains(num) {

View file

@ -41,7 +41,7 @@ fn all_ranges<'tcx>(cx: &LateContext<'tcx>, arms: &'tcx [Arm<'_>], ty: Ty<'tcx>)
cx.tcx.valtree_to_const_val((ty, min_val_const.to_valtree())),
ty,
);
miri_to_const(cx.tcx, min_constant)?
miri_to_const(cx, min_constant)?
},
};
let rhs_const = match rhs {
@ -52,7 +52,7 @@ fn all_ranges<'tcx>(cx: &LateContext<'tcx>, arms: &'tcx [Arm<'_>], ty: Ty<'tcx>)
cx.tcx.valtree_to_const_val((ty, max_val_const.to_valtree())),
ty,
);
miri_to_const(cx.tcx, max_constant)?
miri_to_const(cx, max_constant)?
},
};
let lhs_val = lhs_const.int_value(cx, ty)?;

View file

@ -66,7 +66,7 @@ enum MinMax {
Max,
}
fn min_max<'a>(cx: &LateContext<'_>, expr: &'a Expr<'a>) -> Option<(MinMax, Constant, &'a Expr<'a>)> {
fn min_max<'a, 'tcx>(cx: &LateContext<'tcx>, expr: &'a Expr<'a>) -> Option<(MinMax, Constant<'tcx>, &'a Expr<'a>)> {
match expr.kind {
ExprKind::Call(path, args) => {
if let ExprKind::Path(ref qpath) = path.kind {
@ -99,12 +99,12 @@ fn min_max<'a>(cx: &LateContext<'_>, expr: &'a Expr<'a>) -> Option<(MinMax, Cons
}
}
fn fetch_const<'a>(
cx: &LateContext<'_>,
fn fetch_const<'a, 'tcx>(
cx: &LateContext<'tcx>,
receiver: Option<&'a Expr<'a>>,
args: &'a [Expr<'a>],
m: MinMax,
) -> Option<(MinMax, Constant, &'a Expr<'a>)> {
) -> Option<(MinMax, Constant<'tcx>, &'a Expr<'a>)> {
let mut args = receiver.into_iter().chain(args);
let first_arg = args.next()?;
let second_arg = args.next()?;

View file

@ -85,7 +85,7 @@ fn get_lint_and_message(is_local: bool, is_comparing_arrays: bool) -> (&'static
}
}
fn is_allowed(val: &Constant) -> bool {
fn is_allowed(val: &Constant<'_>) -> bool {
match val {
&Constant::F32(f) => f == 0.0 || f.is_infinite(),
&Constant::F64(f) => f == 0.0 || f.is_infinite(),

View file

@ -296,8 +296,8 @@ fn check_possible_range_contains(
}
}
struct RangeBounds<'a> {
val: Constant,
struct RangeBounds<'a, 'tcx> {
val: Constant<'tcx>,
expr: &'a Expr<'a>,
id: HirId,
name_span: Span,
@ -309,7 +309,7 @@ struct RangeBounds<'a> {
// Takes a binary expression such as x <= 2 as input
// Breaks apart into various pieces, such as the value of the number,
// hir id of the variable, and direction/inclusiveness of the operator
fn check_range_bounds<'a>(cx: &'a LateContext<'_>, ex: &'a Expr<'_>) -> Option<RangeBounds<'a>> {
fn check_range_bounds<'a, 'tcx>(cx: &'a LateContext<'tcx>, ex: &'a Expr<'_>) -> Option<RangeBounds<'a, 'tcx>> {
if let ExprKind::Binary(ref op, l, r) = ex.kind {
let (inclusive, ordering) = match op.node {
BinOpKind::Gt => (false, Ordering::Greater),

View file

@ -14,7 +14,7 @@ use rustc_middle::mir::interpret::Scalar;
use rustc_middle::ty::{self, EarlyBinder, FloatTy, ScalarInt, Ty, TyCtxt};
use rustc_middle::ty::{List, SubstsRef};
use rustc_middle::{bug, span_bug};
use rustc_span::symbol::Symbol;
use rustc_span::symbol::{Ident, Symbol};
use rustc_span::SyntaxContext;
use std::cmp::Ordering::{self, Equal};
use std::hash::{Hash, Hasher};
@ -22,7 +22,8 @@ use std::iter;
/// A `LitKind`-like enum to fold constant `Expr`s into.
#[derive(Debug, Clone)]
pub enum Constant {
pub enum Constant<'tcx> {
Adt(rustc_middle::mir::ConstantKind<'tcx>),
/// A `String` (e.g., "abc").
Str(String),
/// A binary string (e.g., `b"abc"`).
@ -38,20 +39,20 @@ pub enum Constant {
/// `true` or `false`.
Bool(bool),
/// An array of constants.
Vec(Vec<Constant>),
Vec(Vec<Constant<'tcx>>),
/// Also an array, but with only one constant, repeated N times.
Repeat(Box<Constant>, u64),
Repeat(Box<Constant<'tcx>>, u64),
/// A tuple of constants.
Tuple(Vec<Constant>),
Tuple(Vec<Constant<'tcx>>),
/// A raw pointer.
RawPtr(u128),
/// A reference
Ref(Box<Constant>),
Ref(Box<Constant<'tcx>>),
/// A literal with syntax error.
Err,
}
impl PartialEq for Constant {
impl<'tcx> PartialEq for Constant<'tcx> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Str(ls), Self::Str(rs)) => ls == rs,
@ -80,13 +81,16 @@ impl PartialEq for Constant {
}
}
impl Hash for Constant {
impl<'tcx> Hash for Constant<'tcx> {
fn hash<H>(&self, state: &mut H)
where
H: Hasher,
{
std::mem::discriminant(self).hash(state);
match *self {
Self::Adt(ref elem) => {
elem.hash(state);
},
Self::Str(ref s) => {
s.hash(state);
},
@ -126,7 +130,7 @@ impl Hash for Constant {
}
}
impl Constant {
impl<'tcx> Constant<'tcx> {
pub fn partial_cmp(tcx: TyCtxt<'_>, cmp_type: Ty<'_>, left: &Self, right: &Self) -> Option<Ordering> {
match (left, right) {
(Self::Str(ls), Self::Str(rs)) => Some(ls.cmp(rs)),
@ -209,7 +213,7 @@ impl Constant {
}
/// Parses a `LitKind` to a `Constant`.
pub fn lit_to_mir_constant(lit: &LitKind, ty: Option<Ty<'_>>) -> Constant {
pub fn lit_to_mir_constant<'tcx>(lit: &LitKind, ty: Option<Ty<'tcx>>) -> Constant<'tcx> {
match *lit {
LitKind::Str(ref is, _) => Constant::Str(is.to_string()),
LitKind::Byte(b) => Constant::Int(u128::from(b)),
@ -248,7 +252,7 @@ pub fn constant<'tcx>(
lcx: &LateContext<'tcx>,
typeck_results: &ty::TypeckResults<'tcx>,
e: &Expr<'_>,
) -> Option<Constant> {
) -> Option<Constant<'tcx>> {
ConstEvalLateContext::new(lcx, typeck_results).expr(e)
}
@ -257,7 +261,7 @@ pub fn constant_with_source<'tcx>(
lcx: &LateContext<'tcx>,
typeck_results: &ty::TypeckResults<'tcx>,
e: &Expr<'_>,
) -> Option<(Constant, ConstantSource)> {
) -> Option<(Constant<'tcx>, ConstantSource)> {
let mut ctxt = ConstEvalLateContext::new(lcx, typeck_results);
let res = ctxt.expr(e);
res.map(|x| (x, ctxt.source))
@ -268,7 +272,7 @@ pub fn constant_simple<'tcx>(
lcx: &LateContext<'tcx>,
typeck_results: &ty::TypeckResults<'tcx>,
e: &Expr<'_>,
) -> Option<Constant> {
) -> Option<Constant<'tcx>> {
constant_with_source(lcx, typeck_results, e).and_then(|(c, s)| s.is_local().then_some(c))
}
@ -338,7 +342,7 @@ impl<'a, 'tcx> ConstEvalLateContext<'a, 'tcx> {
}
/// Simple constant folding: Insert an expression, get a constant or none.
pub fn expr(&mut self, e: &Expr<'_>) -> Option<Constant> {
pub fn expr(&mut self, e: &Expr<'_>) -> Option<Constant<'tcx>> {
match e.kind {
ExprKind::Path(ref qpath) => self.fetch_path(qpath, e.hir_id, self.typeck_results.expr_ty(e)),
ExprKind::Block(block, _) => self.block(block),
@ -392,13 +396,25 @@ impl<'a, 'tcx> ConstEvalLateContext<'a, 'tcx> {
},
ExprKind::Index(arr, index) => self.index(arr, index),
ExprKind::AddrOf(_, _, inner) => self.expr(inner).map(|r| Constant::Ref(Box::new(r))),
// TODO: add other expressions.
ExprKind::Field(ref local_expr, ref field) => {
let result = self.expr(local_expr);
if let Some(Constant::Adt(constant)) = &self.expr(local_expr)
&& let ty::Adt(adt_def, _) = constant.ty().kind()
&& adt_def.is_struct()
&& let Some(desired_field) = field_of_struct(*adt_def, self.lcx, *constant, field)
{
miri_to_const(self.lcx, desired_field)
}
else {
result
}
},
_ => None,
}
}
#[expect(clippy::cast_possible_wrap)]
fn constant_not(&self, o: &Constant, ty: Ty<'_>) -> Option<Constant> {
fn constant_not(&self, o: &Constant<'tcx>, ty: Ty<'_>) -> Option<Constant<'tcx>> {
use self::Constant::{Bool, Int};
match *o {
Bool(b) => Some(Bool(!b)),
@ -414,7 +430,7 @@ impl<'a, 'tcx> ConstEvalLateContext<'a, 'tcx> {
}
}
fn constant_negate(&self, o: &Constant, ty: Ty<'_>) -> Option<Constant> {
fn constant_negate(&self, o: &Constant<'tcx>, ty: Ty<'_>) -> Option<Constant<'tcx>> {
use self::Constant::{Int, F32, F64};
match *o {
Int(value) => {
@ -433,28 +449,25 @@ impl<'a, 'tcx> ConstEvalLateContext<'a, 'tcx> {
/// Create `Some(Vec![..])` of all constants, unless there is any
/// non-constant part.
fn multi(&mut self, vec: &[Expr<'_>]) -> Option<Vec<Constant>> {
fn multi(&mut self, vec: &[Expr<'_>]) -> Option<Vec<Constant<'tcx>>> {
vec.iter().map(|elem| self.expr(elem)).collect::<Option<_>>()
}
/// Lookup a possibly constant expression from an `ExprKind::Path`.
fn fetch_path(&mut self, qpath: &QPath<'_>, id: HirId, ty: Ty<'tcx>) -> Option<Constant> {
fn fetch_path(&mut self, qpath: &QPath<'_>, id: HirId, ty: Ty<'tcx>) -> Option<Constant<'tcx>> {
let res = self.typeck_results.qpath_res(qpath, id);
match res {
Res::Def(DefKind::Const | DefKind::AssocConst, def_id) => {
// Check if this constant is based on `cfg!(..)`,
// which is NOT constant for our purposes.
if let Some(node) = self.lcx.tcx.hir().get_if_local(def_id) &&
let Node::Item(&Item {
kind: ItemKind::Const(_, body_id),
..
}) = node &&
let Node::Expr(&Expr {
kind: ExprKind::Lit(_),
span,
..
}) = self.lcx.tcx.hir().get(body_id.hir_id) &&
is_direct_expn_of(span, "cfg").is_some() {
if let Some(node) = self.lcx.tcx.hir().get_if_local(def_id)
&& let Node::Item(Item { kind: ItemKind::Const(_, body_id), .. }) = node
&& let Node::Expr(Expr { kind: ExprKind::Lit(_), span, .. }) = self.lcx
.tcx
.hir()
.get(body_id.hir_id)
&& is_direct_expn_of(*span, "cfg").is_some()
{
return None;
}
@ -464,23 +477,21 @@ impl<'a, 'tcx> ConstEvalLateContext<'a, 'tcx> {
} else {
EarlyBinder::bind(substs).subst(self.lcx.tcx, self.substs)
};
let result = self
.lcx
.tcx
.const_eval_resolve(self.param_env, mir::UnevaluatedConst::new(def_id, substs), None)
.ok()
.map(|val| rustc_middle::mir::ConstantKind::from_value(val, ty))?;
let result = miri_to_const(self.lcx.tcx, result)?;
let result = miri_to_const(self.lcx, result)?;
self.source = ConstantSource::Constant;
Some(result)
},
// FIXME: cover all usable cases.
_ => None,
}
}
fn index(&mut self, lhs: &'_ Expr<'_>, index: &'_ Expr<'_>) -> Option<Constant> {
fn index(&mut self, lhs: &'_ Expr<'_>, index: &'_ Expr<'_>) -> Option<Constant<'tcx>> {
let lhs = self.expr(lhs);
let index = self.expr(index);
@ -506,7 +517,7 @@ impl<'a, 'tcx> ConstEvalLateContext<'a, 'tcx> {
}
/// A block can only yield a constant if it only has one constant expression.
fn block(&mut self, block: &Block<'_>) -> Option<Constant> {
fn block(&mut self, block: &Block<'_>) -> Option<Constant<'tcx>> {
if block.stmts.is_empty()
&& let Some(expr) = block.expr
{
@ -539,7 +550,7 @@ impl<'a, 'tcx> ConstEvalLateContext<'a, 'tcx> {
}
}
fn ifthenelse(&mut self, cond: &Expr<'_>, then: &Expr<'_>, otherwise: Option<&Expr<'_>>) -> Option<Constant> {
fn ifthenelse(&mut self, cond: &Expr<'_>, then: &Expr<'_>, otherwise: Option<&Expr<'_>>) -> Option<Constant<'tcx>> {
if let Some(Constant::Bool(b)) = self.expr(cond) {
if b {
self.expr(then)
@ -551,7 +562,7 @@ impl<'a, 'tcx> ConstEvalLateContext<'a, 'tcx> {
}
}
fn binop(&mut self, op: BinOp, left: &Expr<'_>, right: &Expr<'_>) -> Option<Constant> {
fn binop(&mut self, op: BinOp, left: &Expr<'_>, right: &Expr<'_>) -> Option<Constant<'tcx>> {
let l = self.expr(left)?;
let r = self.expr(right);
match (l, r) {
@ -644,23 +655,21 @@ impl<'a, 'tcx> ConstEvalLateContext<'a, 'tcx> {
}
}
pub fn miri_to_const<'tcx>(tcx: TyCtxt<'tcx>, result: mir::ConstantKind<'tcx>) -> Option<Constant> {
pub fn miri_to_const<'tcx>(lcx: &LateContext<'tcx>, result: mir::ConstantKind<'tcx>) -> Option<Constant<'tcx>> {
use rustc_middle::mir::interpret::ConstValue;
match result {
mir::ConstantKind::Val(ConstValue::Scalar(Scalar::Int(int)), _) => {
match result.ty().kind() {
ty::Bool => Some(Constant::Bool(int == ScalarInt::TRUE)),
ty::Uint(_) | ty::Int(_) => Some(Constant::Int(int.assert_bits(int.size()))),
ty::Float(FloatTy::F32) => Some(Constant::F32(f32::from_bits(
int.try_into().expect("invalid f32 bit representation"),
))),
ty::Float(FloatTy::F64) => Some(Constant::F64(f64::from_bits(
int.try_into().expect("invalid f64 bit representation"),
))),
ty::RawPtr(_) => Some(Constant::RawPtr(int.assert_bits(int.size()))),
// FIXME: implement other conversions.
_ => None,
}
mir::ConstantKind::Val(ConstValue::Scalar(Scalar::Int(int)), _) => match result.ty().kind() {
ty::Adt(adt_def, _) if adt_def.is_struct() => Some(Constant::Adt(result)),
ty::Bool => Some(Constant::Bool(int == ScalarInt::TRUE)),
ty::Uint(_) | ty::Int(_) => Some(Constant::Int(int.assert_bits(int.size()))),
ty::Float(FloatTy::F32) => Some(Constant::F32(f32::from_bits(
int.try_into().expect("invalid f32 bit representation"),
))),
ty::Float(FloatTy::F64) => Some(Constant::F64(f64::from_bits(
int.try_into().expect("invalid f64 bit representation"),
))),
ty::RawPtr(_) => Some(Constant::RawPtr(int.assert_bits(int.size()))),
_ => None,
},
mir::ConstantKind::Val(ConstValue::Slice { data, start, end }, _) => match result.ty().kind() {
ty::Ref(_, tam, _) => match tam.kind() {
@ -676,35 +685,53 @@ pub fn miri_to_const<'tcx>(tcx: TyCtxt<'tcx>, result: mir::ConstantKind<'tcx>) -
_ => None,
},
mir::ConstantKind::Val(ConstValue::ByRef { alloc, offset: _ }, _) => match result.ty().kind() {
ty::Adt(adt_def, _) if adt_def.is_struct() => Some(Constant::Adt(result)),
ty::Array(sub_type, len) => match sub_type.kind() {
ty::Float(FloatTy::F32) => match len.kind().try_to_target_usize(tcx) {
ty::Float(FloatTy::F32) => match len.kind().try_to_target_usize(lcx.tcx) {
Some(len) => alloc
.inner()
.inspect_with_uninit_and_ptr_outside_interpreter(0..(4 * usize::try_from(len).unwrap()))
.to_owned()
.array_chunks::<4>()
.map(|&chunk| Some(Constant::F32(f32::from_le_bytes(chunk))))
.collect::<Option<Vec<Constant>>>()
.collect::<Option<Vec<Constant<'tcx>>>>()
.map(Constant::Vec),
_ => None,
},
ty::Float(FloatTy::F64) => match len.kind().try_to_target_usize(tcx) {
ty::Float(FloatTy::F64) => match len.kind().try_to_target_usize(lcx.tcx) {
Some(len) => alloc
.inner()
.inspect_with_uninit_and_ptr_outside_interpreter(0..(8 * usize::try_from(len).unwrap()))
.to_owned()
.array_chunks::<8>()
.map(|&chunk| Some(Constant::F64(f64::from_le_bytes(chunk))))
.collect::<Option<Vec<Constant>>>()
.collect::<Option<Vec<Constant<'tcx>>>>()
.map(Constant::Vec),
_ => None,
},
// FIXME: implement other array type conversions.
_ => None,
},
_ => None,
},
// FIXME: implement other conversions.
_ => None,
}
}
fn field_of_struct<'tcx>(
adt_def: ty::AdtDef<'tcx>,
lcx: &LateContext<'tcx>,
result: mir::ConstantKind<'tcx>,
field: &Ident,
) -> Option<mir::ConstantKind<'tcx>> {
let dc = lcx.tcx.destructure_mir_constant(lcx.param_env, result);
if let Some(dc_variant) = dc.variant
&& let Some(variant) = adt_def.variants().get(dc_variant)
&& let Some(field_idx) = variant.fields.iter().position(|el| el.name == field.name)
&& let Some(dc_field) = dc.fields.get(field_idx)
{
Some(*dc_field)
}
else {
None
}
}

View file

@ -1499,7 +1499,7 @@ pub fn is_range_full(cx: &LateContext<'_>, expr: &Expr<'_>, container_path: Opti
&& let Some(min_val) = bnd_ty.numeric_min_val(cx.tcx)
&& let const_val = cx.tcx.valtree_to_const_val((bnd_ty, min_val.to_valtree()))
&& let min_const_kind = ConstantKind::from_value(const_val, bnd_ty)
&& let Some(min_const) = miri_to_const(cx.tcx, min_const_kind)
&& let Some(min_const) = miri_to_const(cx, min_const_kind)
&& let Some(start_const) = constant(cx, cx.typeck_results(), start)
{
start_const == min_const
@ -1515,7 +1515,7 @@ pub fn is_range_full(cx: &LateContext<'_>, expr: &Expr<'_>, container_path: Opti
&& let Some(max_val) = bnd_ty.numeric_max_val(cx.tcx)
&& let const_val = cx.tcx.valtree_to_const_val((bnd_ty, max_val.to_valtree()))
&& let max_const_kind = ConstantKind::from_value(const_val, bnd_ty)
&& let Some(max_const) = miri_to_const(cx.tcx, max_const_kind)
&& let Some(max_const) = miri_to_const(cx, max_const_kind)
&& let Some(end_const) = constant(cx, cx.typeck_results(), end)
{
end_const == max_const

View file

@ -466,4 +466,19 @@ pub fn issue_10767() {
&3.5_f32 + &1.3_f32;
}
pub fn issue_10792() {
struct One {
a: u32,
}
struct Two {
b: u32,
c: u64,
}
const ONE: One = One { a: 1 };
const TWO: Two = Two { b: 2, c: 3 };
let _ = 10 / ONE.a;
let _ = 10 / TWO.b;
let _ = 10 / TWO.c;
}
fn main() {}