11772: Support constants in const eval r=HKalbasi a=HKalbasi

This PR enables evaluating things like this:
```rust
const X: usize = 2;
const Y: usize = 3 + X; // = 5
```
My target was nalgebra's `U5`, `U22`, ... which are defined as `type U5 = Const<{ SomeType5::SOME_ASSOC_CONST }>` but I didn't find out how to find the `ConstId` of the implementation of the trait, not the trait itself (possibly related to #4558 ? We can find associated type alias, so maybe this is doable already) So it doesn't help for nalgebra currently, but it is useful anyway.


Co-authored-by: hkalbasi <hamidrezakalbasi@protonmail.com>
This commit is contained in:
bors[bot] 2022-03-24 09:42:09 +00:00 committed by GitHub
commit f3d1a53fa6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 298 additions and 141 deletions

View file

@ -32,7 +32,7 @@ pub mod symbols;
mod display;
use std::{collections::HashMap, iter, ops::ControlFlow, sync::Arc};
use std::{iter, ops::ControlFlow, sync::Arc};
use arrayvec::ArrayVec;
use base_db::{CrateDisplayName, CrateId, CrateOrigin, Edition, FileId, ProcMacroKind};
@ -55,9 +55,7 @@ use hir_def::{
use hir_expand::{name::name, MacroCallKind};
use hir_ty::{
autoderef,
consteval::{
eval_const, unknown_const_as_generic, ComputedExpr, ConstEvalCtx, ConstEvalError, ConstExt,
},
consteval::{unknown_const_as_generic, ComputedExpr, ConstEvalError, ConstExt},
diagnostics::BodyValidationDiagnostic,
method_resolution::{self, TyFingerprint},
primitive::UintTy,
@ -1602,20 +1600,7 @@ impl Const {
}
pub fn eval(self, db: &dyn HirDatabase) -> Result<ComputedExpr, ConstEvalError> {
let body = db.body(self.id.into());
let root = &body.exprs[body.body_expr];
let infer = db.infer_query(self.id.into());
let infer = infer.as_ref();
let result = eval_const(
root,
&mut ConstEvalCtx {
exprs: &body.exprs,
pats: &body.pats,
local_data: HashMap::default(),
infer: &mut |x| infer[x].clone(),
},
);
result
db.const_eval(self.id)
}
}

View file

@ -8,22 +8,19 @@ use std::{
use chalk_ir::{BoundVar, DebruijnIndex, GenericArgData, IntTy, Scalar};
use hir_def::{
expr::{ArithOp, BinaryOp, Expr, Literal, Pat},
expr::{ArithOp, BinaryOp, Expr, ExprId, Literal, Pat, PatId},
path::ModPath,
resolver::{Resolver, ValueNs},
resolver::{resolver_for_expr, ResolveValueResult, Resolver, ValueNs},
type_ref::ConstScalar,
ConstId, DefWithBodyId,
};
use hir_expand::name::Name;
use la_arena::{Arena, Idx};
use stdx::never;
use crate::{
db::HirDatabase,
infer::{Expectation, InferenceContext},
lower::ParamLoweringMode,
to_placeholder_idx,
utils::Generics,
Const, ConstData, ConstValue, GenericArg, Interner, Ty, TyKind,
db::HirDatabase, infer::InferenceContext, lower::ParamLoweringMode, to_placeholder_idx,
utils::Generics, Const, ConstData, ConstValue, GenericArg, InferenceResult, Interner, Ty,
TyKind,
};
/// Extension trait for [`Const`]
@ -55,21 +52,30 @@ impl ConstExt for Const {
}
pub struct ConstEvalCtx<'a> {
pub db: &'a dyn HirDatabase,
pub owner: DefWithBodyId,
pub exprs: &'a Arena<Expr>,
pub pats: &'a Arena<Pat>,
pub local_data: HashMap<Name, ComputedExpr>,
pub infer: &'a mut dyn FnMut(Idx<Expr>) -> Ty,
pub local_data: HashMap<PatId, ComputedExpr>,
infer: &'a InferenceResult,
}
#[derive(Debug, Clone)]
impl ConstEvalCtx<'_> {
fn expr_ty(&mut self, expr: ExprId) -> Ty {
self.infer[expr].clone()
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConstEvalError {
NotSupported(&'static str),
TypeError,
SemanticError(&'static str),
Loop,
IncompleteExpr,
Panic(String),
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ComputedExpr {
Literal(Literal),
Tuple(Box<[ComputedExpr]>),
@ -80,14 +86,14 @@ impl Display for ComputedExpr {
match self {
ComputedExpr::Literal(l) => match l {
Literal::Int(x, _) => {
if *x >= 16 {
if *x >= 10 {
write!(f, "{} ({:#X})", x, x)
} else {
x.fmt(f)
}
}
Literal::Uint(x, _) => {
if *x >= 16 {
if *x >= 10 {
write!(f, "{} ({:#X})", x, x)
} else {
x.fmt(f)
@ -143,12 +149,17 @@ fn is_valid(scalar: &Scalar, value: i128) -> bool {
}
}
pub fn eval_const(expr: &Expr, ctx: &mut ConstEvalCtx<'_>) -> Result<ComputedExpr, ConstEvalError> {
pub fn eval_const(
expr_id: ExprId,
ctx: &mut ConstEvalCtx<'_>,
) -> Result<ComputedExpr, ConstEvalError> {
let expr = &ctx.exprs[expr_id];
match expr {
Expr::Missing => Err(ConstEvalError::IncompleteExpr),
Expr::Literal(l) => Ok(ComputedExpr::Literal(l.clone())),
&Expr::UnaryOp { expr, op } => {
let ty = &(ctx.infer)(expr);
let ev = eval_const(&ctx.exprs[expr], ctx)?;
let ty = &ctx.expr_ty(expr);
let ev = eval_const(expr, ctx)?;
match op {
hir_def::expr::UnaryOp::Deref => Err(ConstEvalError::NotSupported("deref")),
hir_def::expr::UnaryOp::Not => {
@ -203,9 +214,9 @@ pub fn eval_const(expr: &Expr, ctx: &mut ConstEvalCtx<'_>) -> Result<ComputedExp
}
}
&Expr::BinaryOp { lhs, rhs, op } => {
let ty = &(ctx.infer)(lhs);
let lhs = eval_const(&ctx.exprs[lhs], ctx)?;
let rhs = eval_const(&ctx.exprs[rhs], ctx)?;
let ty = &ctx.expr_ty(lhs);
let lhs = eval_const(lhs, ctx)?;
let rhs = eval_const(rhs, ctx)?;
let op = op.ok_or(ConstEvalError::IncompleteExpr)?;
let v1 = match lhs {
ComputedExpr::Literal(Literal::Int(v, _)) => v,
@ -249,31 +260,31 @@ pub fn eval_const(expr: &Expr, ctx: &mut ConstEvalCtx<'_>) -> Result<ComputedExp
}
Ok(ComputedExpr::Literal(Literal::Int(r, None)))
}
BinaryOp::LogicOp(_) => Err(ConstEvalError::TypeError),
BinaryOp::LogicOp(_) => Err(ConstEvalError::SemanticError("logic op on numbers")),
_ => Err(ConstEvalError::NotSupported("bin op on this operators")),
}
}
Expr::Block { statements, tail, .. } => {
let mut prev_values = HashMap::<Name, Option<ComputedExpr>>::default();
let mut prev_values = HashMap::<PatId, Option<ComputedExpr>>::default();
for statement in &**statements {
match *statement {
hir_def::expr::Statement::Let { pat, initializer, .. } => {
let pat = &ctx.pats[pat];
let name = match pat {
Pat::Bind { name, subpat, .. } if subpat.is_none() => name.clone(),
hir_def::expr::Statement::Let { pat: pat_id, initializer, .. } => {
let pat = &ctx.pats[pat_id];
match pat {
Pat::Bind { subpat, .. } if subpat.is_none() => (),
_ => {
return Err(ConstEvalError::NotSupported("complex patterns in let"))
}
};
let value = match initializer {
Some(x) => eval_const(&ctx.exprs[x], ctx)?,
Some(x) => eval_const(x, ctx)?,
None => continue,
};
if !prev_values.contains_key(&name) {
let prev = ctx.local_data.insert(name.clone(), value);
prev_values.insert(name, prev);
if !prev_values.contains_key(&pat_id) {
let prev = ctx.local_data.insert(pat_id, value);
prev_values.insert(pat_id, prev);
} else {
ctx.local_data.insert(name, value);
ctx.local_data.insert(pat_id, value);
}
}
hir_def::expr::Statement::Expr { .. } => {
@ -282,7 +293,7 @@ pub fn eval_const(expr: &Expr, ctx: &mut ConstEvalCtx<'_>) -> Result<ComputedExp
}
}
let r = match tail {
&Some(x) => eval_const(&ctx.exprs[x], ctx),
&Some(x) => eval_const(x, ctx),
None => Ok(ComputedExpr::Tuple(Box::new([]))),
};
// clean up local data, so caller will receive the exact map that passed to us
@ -295,19 +306,48 @@ pub fn eval_const(expr: &Expr, ctx: &mut ConstEvalCtx<'_>) -> Result<ComputedExp
r
}
Expr::Path(p) => {
let name = p.mod_path().as_ident().ok_or(ConstEvalError::NotSupported("big paths"))?;
let resolver = resolver_for_expr(ctx.db.upcast(), ctx.owner, expr_id);
let pr = resolver
.resolve_path_in_value_ns(ctx.db.upcast(), p.mod_path())
.ok_or(ConstEvalError::SemanticError("unresolved path"))?;
let pr = match pr {
ResolveValueResult::ValueNs(v) => v,
ResolveValueResult::Partial(..) => {
return match ctx
.infer
.assoc_resolutions_for_expr(expr_id)
.ok_or(ConstEvalError::SemanticError("unresolved assoc item"))?
{
hir_def::AssocItemId::FunctionId(_) => {
Err(ConstEvalError::NotSupported("assoc function"))
}
hir_def::AssocItemId::ConstId(c) => ctx.db.const_eval(c),
hir_def::AssocItemId::TypeAliasId(_) => {
Err(ConstEvalError::NotSupported("assoc type alias"))
}
}
}
};
match pr {
ValueNs::LocalBinding(pat_id) => {
let r = ctx
.local_data
.get(name)
.ok_or(ConstEvalError::NotSupported("Non local name resolution"))?;
.get(&pat_id)
.ok_or(ConstEvalError::NotSupported("Unexpected missing local"))?;
Ok(r.clone())
}
ValueNs::ConstId(id) => ctx.db.const_eval(id),
ValueNs::GenericParam(_) => {
Err(ConstEvalError::NotSupported("const generic without substitution"))
}
_ => Err(ConstEvalError::NotSupported("path that are not const or local")),
}
}
_ => Err(ConstEvalError::NotSupported("This kind of expression")),
}
}
pub fn eval_usize(expr: Idx<Expr>, mut ctx: ConstEvalCtx<'_>) -> Option<u64> {
let expr = &ctx.exprs[expr];
if let Ok(ce) = eval_const(expr, &mut ctx) {
match ce {
ComputedExpr::Literal(Literal::Int(x, _)) => return x.try_into().ok(),
@ -380,10 +420,39 @@ pub fn usize_const(value: Option<u64>) -> Const {
.intern(Interner)
}
pub(crate) fn eval_to_const(
pub(crate) fn const_eval_recover(
_: &dyn HirDatabase,
_: &[String],
_: &ConstId,
) -> Result<ComputedExpr, ConstEvalError> {
Err(ConstEvalError::Loop)
}
pub(crate) fn const_eval_query(
db: &dyn HirDatabase,
const_id: ConstId,
) -> Result<ComputedExpr, ConstEvalError> {
let def = const_id.into();
let body = db.body(def);
let infer = &db.infer(def);
let result = eval_const(
body.body_expr,
&mut ConstEvalCtx {
db,
owner: const_id.into(),
exprs: &body.exprs,
pats: &body.pats,
local_data: HashMap::default(),
infer,
},
);
result
}
pub(crate) fn eval_to_const<'a>(
expr: Idx<Expr>,
mode: ParamLoweringMode,
ctx: &mut InferenceContext,
ctx: &mut InferenceContext<'a>,
args: impl FnOnce() -> Generics,
debruijn: DebruijnIndex,
) -> Const {
@ -396,10 +465,15 @@ pub(crate) fn eval_to_const(
}
let body = ctx.body.clone();
let ctx = ConstEvalCtx {
db: ctx.db,
owner: ctx.owner,
exprs: &body.exprs,
pats: &body.pats,
local_data: HashMap::default(),
infer: &mut |x| ctx.infer_expr(x, &Expectation::None),
infer: &ctx.result,
};
usize_const(eval_usize(expr, ctx))
}
#[cfg(test)]
mod tests;

View file

@ -0,0 +1,148 @@
use base_db::fixture::WithFixture;
use hir_def::{db::DefDatabase, expr::Literal};
use crate::{consteval::ComputedExpr, db::HirDatabase, test_db::TestDB};
use super::ConstEvalError;
fn check_fail(ra_fixture: &str, error: ConstEvalError) {
assert_eq!(eval_goal(ra_fixture), Err(error));
}
fn check_number(ra_fixture: &str, answer: i128) {
let r = eval_goal(ra_fixture).unwrap();
match r {
ComputedExpr::Literal(Literal::Int(r, _)) => assert_eq!(r, answer),
ComputedExpr::Literal(Literal::Uint(r, _)) => assert_eq!(r, answer as u128),
x => panic!("Expected number but found {:?}", x),
}
}
fn eval_goal(ra_fixture: &str) -> Result<ComputedExpr, ConstEvalError> {
let (db, file_id) = TestDB::with_single_file(ra_fixture);
let module_id = db.module_for_file(file_id);
let def_map = module_id.def_map(&db);
let scope = &def_map[module_id.local_id].scope;
let const_id = scope
.declarations()
.into_iter()
.find_map(|x| match x {
hir_def::ModuleDefId::ConstId(x) => {
if db.const_data(x).name.as_ref()?.to_string() == "GOAL" {
Some(x)
} else {
None
}
}
_ => None,
})
.unwrap();
db.const_eval(const_id)
}
#[test]
fn add() {
check_number(r#"const GOAL: usize = 2 + 2;"#, 4);
}
#[test]
fn bit_op() {
check_number(r#"const GOAL: u8 = !0 & !(!0 >> 1)"#, 128);
check_number(r#"const GOAL: i8 = !0 & !(!0 >> 1)"#, 0);
// FIXME: rustc evaluate this to -128
check_fail(
r#"const GOAL: i8 = 1 << 7"#,
ConstEvalError::Panic("attempt to run invalid arithmetic operation".to_string()),
);
check_fail(
r#"const GOAL: i8 = 1 << 8"#,
ConstEvalError::Panic("attempt to run invalid arithmetic operation".to_string()),
);
}
#[test]
fn locals() {
check_number(
r#"
const GOAL: usize = {
let a = 3 + 2;
let b = a * a;
b
};
"#,
25,
);
}
#[test]
fn consts() {
check_number(
r#"
const F1: i32 = 1;
const F3: i32 = 3 * F2;
const F2: i32 = 2 * F1;
const GOAL: i32 = F3;
"#,
6,
);
}
#[test]
fn const_loop() {
check_fail(
r#"
const F1: i32 = 1 * F3;
const F3: i32 = 3 * F2;
const F2: i32 = 2 * F1;
const GOAL: i32 = F3;
"#,
ConstEvalError::Loop,
);
}
#[test]
fn const_impl_assoc() {
check_number(
r#"
struct U5;
impl U5 {
const VAL: usize = 5;
}
const GOAL: usize = U5::VAL;
"#,
5,
);
}
#[test]
fn const_generic_subst() {
// FIXME: this should evaluate to 5
check_fail(
r#"
struct Adder<const N: usize, const M: usize>;
impl<const N: usize, const M: usize> Adder<N, M> {
const VAL: usize = N + M;
}
const GOAL: usize = Adder::<2, 3>::VAL;
"#,
ConstEvalError::NotSupported("const generic without substitution"),
);
}
#[test]
fn const_trait_assoc() {
// FIXME: this should evaluate to 0
check_fail(
r#"
struct U0;
trait ToConst {
const VAL: usize;
}
impl ToConst for U0 {
const VAL: usize = 0;
}
const GOAL: usize = U0::VAL;
"#,
ConstEvalError::IncompleteExpr,
);
}

View file

@ -5,13 +5,14 @@ use std::sync::Arc;
use base_db::{impl_intern_key, salsa, CrateId, Upcast};
use hir_def::{
db::DefDatabase, expr::ExprId, BlockId, ConstParamId, DefWithBodyId, FunctionId, GenericDefId,
ImplId, LifetimeParamId, LocalFieldId, TypeOrConstParamId, VariantId,
db::DefDatabase, expr::ExprId, BlockId, ConstId, ConstParamId, DefWithBodyId, FunctionId,
GenericDefId, ImplId, LifetimeParamId, LocalFieldId, TypeOrConstParamId, VariantId,
};
use la_arena::ArenaMap;
use crate::{
chalk_db,
consteval::{ComputedExpr, ConstEvalError},
method_resolution::{InherentImpls, TraitImpls},
Binders, CallableDefId, FnDefId, GenericArg, ImplTraitId, InferenceResult, Interner, PolyFnSig,
QuantifiedWhereClause, ReturnTypeImplTraits, TraitRef, Ty, TyDefId, ValueTyDefId,
@ -41,6 +42,10 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
#[salsa::invoke(crate::lower::const_param_ty_query)]
fn const_param_ty(&self, def: ConstParamId) -> Ty;
#[salsa::invoke(crate::consteval::const_eval_query)]
#[salsa::cycle(crate::consteval::const_eval_recover)]
fn const_eval(&self, def: ConstId) -> Result<ComputedExpr, ConstEvalError>;
#[salsa::invoke(crate::lower::impl_trait_query)]
fn impl_trait(&self, def: ImplId) -> Option<Binders<TraitRef>>;

View file

@ -358,12 +358,12 @@ impl Index<PatId> for InferenceResult {
#[derive(Clone, Debug)]
pub(crate) struct InferenceContext<'a> {
pub(crate) db: &'a dyn HirDatabase,
owner: DefWithBodyId,
pub(crate) owner: DefWithBodyId,
pub(crate) body: Arc<Body>,
pub(crate) resolver: Resolver,
table: unify::InferenceTable<'a>,
trait_env: Arc<TraitEnvironment>,
result: InferenceResult,
pub(crate) result: InferenceResult,
/// The return type of the function being inferred, the closure or async block if we're
/// currently within one.
///

View file

@ -1749,6 +1749,18 @@ fn main() {
);
}
#[test]
fn const_eval_array_repeat_expr() {
check_types(
r#"
fn main() {
const X: usize = 6 - 1;
let t = [(); X + 2];
//^ [(); 7]
}"#,
);
}
#[test]
fn shadowing_primitive_with_inner_items() {
check_types(

View file

@ -3401,10 +3401,11 @@ impl<const LEN: usize> Foo<LEN$0> {}
#[test]
fn hover_const_eval() {
// show hex for <10
check(
r#"
/// This is a doc
const FOO$0: usize = !0 & !(!0 >> 1);
const FOO$0: usize = 1 << 3;
"#,
expect![[r#"
*FOO*
@ -3414,7 +3415,7 @@ const FOO$0: usize = !0 & !(!0 >> 1);
```
```rust
const FOO: usize = 9223372036854775808 (0x8000000000000000)
const FOO: usize = 8
```
---
@ -3422,14 +3423,11 @@ const FOO$0: usize = !0 & !(!0 >> 1);
This is a doc
"#]],
);
// show hex for >10
check(
r#"
/// This is a doc
const FOO$0: usize = {
let a = 3 + 2;
let b = a * a;
b
};
const FOO$0: usize = (1 << 3) + (1 << 2);
"#,
expect![[r#"
*FOO*
@ -3439,53 +3437,7 @@ const FOO$0: usize = {
```
```rust
const FOO: usize = 25 (0x19)
```
---
This is a doc
"#]],
);
check(
r#"
/// This is a doc
const FOO$0: usize = 1 << 10;
"#,
expect![[r#"
*FOO*
```rust
test
```
```rust
const FOO: usize = 1024 (0x400)
```
---
This is a doc
"#]],
);
check(
r#"
/// This is a doc
const FOO$0: usize = {
let b = 4;
let a = { let b = 2; let a = b; a } + { let a = 1; a + b };
a
};
"#,
expect![[r#"
*FOO*
```rust
test
```
```rust
const FOO: usize = 7
const FOO: usize = 12 (0xC)
```
---
@ -3493,6 +3445,7 @@ const FOO$0: usize = {
This is a doc
"#]],
);
// show original body when const eval fails
check(
r#"
/// This is a doc
@ -3514,6 +3467,7 @@ const FOO$0: usize = 2 - 3;
This is a doc
"#]],
);
// don't show hex for negatives
check(
r#"
/// This is a doc
@ -3538,27 +3492,6 @@ const FOO$0: i32 = 2 - 3;
check(
r#"
/// This is a doc
const FOO$0: usize = 1 << 100;
"#,
expect![[r#"
*FOO*
```rust
test
```
```rust
const FOO: usize = 1 << 100
```
---
This is a doc
"#]],
);
check(
r#"
/// This is a doc
const FOO$0: &str = "bar";
"#,
expect![[r#"