diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index f4e58d88ed..18de04b16d 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -32,7 +32,7 @@ pub mod symbols; mod display; -use std::{collections::HashMap, iter, ops::ControlFlow, sync::Arc}; +use std::{iter, ops::ControlFlow, sync::Arc}; use arrayvec::ArrayVec; use base_db::{CrateDisplayName, CrateId, CrateOrigin, Edition, FileId, ProcMacroKind}; @@ -55,9 +55,7 @@ use hir_def::{ use hir_expand::{name::name, MacroCallKind}; use hir_ty::{ autoderef, - consteval::{ - eval_const, unknown_const_as_generic, ComputedExpr, ConstEvalCtx, ConstEvalError, ConstExt, - }, + consteval::{unknown_const_as_generic, ComputedExpr, ConstEvalError, ConstExt}, diagnostics::BodyValidationDiagnostic, method_resolution::{self, TyFingerprint}, primitive::UintTy, @@ -1602,20 +1600,7 @@ impl Const { } pub fn eval(self, db: &dyn HirDatabase) -> Result { - let body = db.body(self.id.into()); - let root = &body.exprs[body.body_expr]; - let infer = db.infer_query(self.id.into()); - let infer = infer.as_ref(); - let result = eval_const( - root, - &mut ConstEvalCtx { - exprs: &body.exprs, - pats: &body.pats, - local_data: HashMap::default(), - infer: &mut |x| infer[x].clone(), - }, - ); - result + db.const_eval(self.id) } } diff --git a/crates/hir_ty/src/consteval.rs b/crates/hir_ty/src/consteval.rs index 1ba291528f..009ea008fc 100644 --- a/crates/hir_ty/src/consteval.rs +++ b/crates/hir_ty/src/consteval.rs @@ -8,22 +8,19 @@ use std::{ use chalk_ir::{BoundVar, DebruijnIndex, GenericArgData, IntTy, Scalar}; use hir_def::{ - expr::{ArithOp, BinaryOp, Expr, Literal, Pat}, + expr::{ArithOp, BinaryOp, Expr, ExprId, Literal, Pat, PatId}, path::ModPath, - resolver::{Resolver, ValueNs}, + resolver::{resolver_for_expr, ResolveValueResult, Resolver, ValueNs}, type_ref::ConstScalar, + ConstId, DefWithBodyId, }; -use hir_expand::name::Name; use la_arena::{Arena, Idx}; use stdx::never; use crate::{ - db::HirDatabase, - infer::{Expectation, InferenceContext}, - lower::ParamLoweringMode, - to_placeholder_idx, - utils::Generics, - Const, ConstData, ConstValue, GenericArg, Interner, Ty, TyKind, + db::HirDatabase, infer::InferenceContext, lower::ParamLoweringMode, to_placeholder_idx, + utils::Generics, Const, ConstData, ConstValue, GenericArg, InferenceResult, Interner, Ty, + TyKind, }; /// Extension trait for [`Const`] @@ -55,21 +52,30 @@ impl ConstExt for Const { } pub struct ConstEvalCtx<'a> { + pub db: &'a dyn HirDatabase, + pub owner: DefWithBodyId, pub exprs: &'a Arena, pub pats: &'a Arena, - pub local_data: HashMap, - pub infer: &'a mut dyn FnMut(Idx) -> Ty, + pub local_data: HashMap, + infer: &'a InferenceResult, } -#[derive(Debug, Clone)] +impl ConstEvalCtx<'_> { + fn expr_ty(&mut self, expr: ExprId) -> Ty { + self.infer[expr].clone() + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] pub enum ConstEvalError { NotSupported(&'static str), - TypeError, + SemanticError(&'static str), + Loop, IncompleteExpr, Panic(String), } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum ComputedExpr { Literal(Literal), Tuple(Box<[ComputedExpr]>), @@ -80,14 +86,14 @@ impl Display for ComputedExpr { match self { ComputedExpr::Literal(l) => match l { Literal::Int(x, _) => { - if *x >= 16 { + if *x >= 10 { write!(f, "{} ({:#X})", x, x) } else { x.fmt(f) } } Literal::Uint(x, _) => { - if *x >= 16 { + if *x >= 10 { write!(f, "{} ({:#X})", x, x) } else { x.fmt(f) @@ -143,12 +149,17 @@ fn is_valid(scalar: &Scalar, value: i128) -> bool { } } -pub fn eval_const(expr: &Expr, ctx: &mut ConstEvalCtx<'_>) -> Result { +pub fn eval_const( + expr_id: ExprId, + ctx: &mut ConstEvalCtx<'_>, +) -> Result { + let expr = &ctx.exprs[expr_id]; match expr { + Expr::Missing => Err(ConstEvalError::IncompleteExpr), Expr::Literal(l) => Ok(ComputedExpr::Literal(l.clone())), &Expr::UnaryOp { expr, op } => { - let ty = &(ctx.infer)(expr); - let ev = eval_const(&ctx.exprs[expr], ctx)?; + let ty = &ctx.expr_ty(expr); + let ev = eval_const(expr, ctx)?; match op { hir_def::expr::UnaryOp::Deref => Err(ConstEvalError::NotSupported("deref")), hir_def::expr::UnaryOp::Not => { @@ -203,9 +214,9 @@ pub fn eval_const(expr: &Expr, ctx: &mut ConstEvalCtx<'_>) -> Result { - let ty = &(ctx.infer)(lhs); - let lhs = eval_const(&ctx.exprs[lhs], ctx)?; - let rhs = eval_const(&ctx.exprs[rhs], ctx)?; + let ty = &ctx.expr_ty(lhs); + let lhs = eval_const(lhs, ctx)?; + let rhs = eval_const(rhs, ctx)?; let op = op.ok_or(ConstEvalError::IncompleteExpr)?; let v1 = match lhs { ComputedExpr::Literal(Literal::Int(v, _)) => v, @@ -249,31 +260,31 @@ pub fn eval_const(expr: &Expr, ctx: &mut ConstEvalCtx<'_>) -> Result Err(ConstEvalError::TypeError), + BinaryOp::LogicOp(_) => Err(ConstEvalError::SemanticError("logic op on numbers")), _ => Err(ConstEvalError::NotSupported("bin op on this operators")), } } Expr::Block { statements, tail, .. } => { - let mut prev_values = HashMap::>::default(); + let mut prev_values = HashMap::>::default(); for statement in &**statements { match *statement { - hir_def::expr::Statement::Let { pat, initializer, .. } => { - let pat = &ctx.pats[pat]; - let name = match pat { - Pat::Bind { name, subpat, .. } if subpat.is_none() => name.clone(), + hir_def::expr::Statement::Let { pat: pat_id, initializer, .. } => { + let pat = &ctx.pats[pat_id]; + match pat { + Pat::Bind { subpat, .. } if subpat.is_none() => (), _ => { return Err(ConstEvalError::NotSupported("complex patterns in let")) } }; let value = match initializer { - Some(x) => eval_const(&ctx.exprs[x], ctx)?, + Some(x) => eval_const(x, ctx)?, None => continue, }; - if !prev_values.contains_key(&name) { - let prev = ctx.local_data.insert(name.clone(), value); - prev_values.insert(name, prev); + if !prev_values.contains_key(&pat_id) { + let prev = ctx.local_data.insert(pat_id, value); + prev_values.insert(pat_id, prev); } else { - ctx.local_data.insert(name, value); + ctx.local_data.insert(pat_id, value); } } hir_def::expr::Statement::Expr { .. } => { @@ -282,7 +293,7 @@ pub fn eval_const(expr: &Expr, ctx: &mut ConstEvalCtx<'_>) -> Result eval_const(&ctx.exprs[x], ctx), + &Some(x) => eval_const(x, ctx), None => Ok(ComputedExpr::Tuple(Box::new([]))), }; // clean up local data, so caller will receive the exact map that passed to us @@ -295,19 +306,48 @@ pub fn eval_const(expr: &Expr, ctx: &mut ConstEvalCtx<'_>) -> Result { - let name = p.mod_path().as_ident().ok_or(ConstEvalError::NotSupported("big paths"))?; - let r = ctx - .local_data - .get(name) - .ok_or(ConstEvalError::NotSupported("Non local name resolution"))?; - Ok(r.clone()) + let resolver = resolver_for_expr(ctx.db.upcast(), ctx.owner, expr_id); + let pr = resolver + .resolve_path_in_value_ns(ctx.db.upcast(), p.mod_path()) + .ok_or(ConstEvalError::SemanticError("unresolved path"))?; + let pr = match pr { + ResolveValueResult::ValueNs(v) => v, + ResolveValueResult::Partial(..) => { + return match ctx + .infer + .assoc_resolutions_for_expr(expr_id) + .ok_or(ConstEvalError::SemanticError("unresolved assoc item"))? + { + hir_def::AssocItemId::FunctionId(_) => { + Err(ConstEvalError::NotSupported("assoc function")) + } + hir_def::AssocItemId::ConstId(c) => ctx.db.const_eval(c), + hir_def::AssocItemId::TypeAliasId(_) => { + Err(ConstEvalError::NotSupported("assoc type alias")) + } + } + } + }; + match pr { + ValueNs::LocalBinding(pat_id) => { + let r = ctx + .local_data + .get(&pat_id) + .ok_or(ConstEvalError::NotSupported("Unexpected missing local"))?; + Ok(r.clone()) + } + ValueNs::ConstId(id) => ctx.db.const_eval(id), + ValueNs::GenericParam(_) => { + Err(ConstEvalError::NotSupported("const generic without substitution")) + } + _ => Err(ConstEvalError::NotSupported("path that are not const or local")), + } } _ => Err(ConstEvalError::NotSupported("This kind of expression")), } } pub fn eval_usize(expr: Idx, mut ctx: ConstEvalCtx<'_>) -> Option { - let expr = &ctx.exprs[expr]; if let Ok(ce) = eval_const(expr, &mut ctx) { match ce { ComputedExpr::Literal(Literal::Int(x, _)) => return x.try_into().ok(), @@ -380,10 +420,39 @@ pub fn usize_const(value: Option) -> Const { .intern(Interner) } -pub(crate) fn eval_to_const( +pub(crate) fn const_eval_recover( + _: &dyn HirDatabase, + _: &[String], + _: &ConstId, +) -> Result { + Err(ConstEvalError::Loop) +} + +pub(crate) fn const_eval_query( + db: &dyn HirDatabase, + const_id: ConstId, +) -> Result { + let def = const_id.into(); + let body = db.body(def); + let infer = &db.infer(def); + let result = eval_const( + body.body_expr, + &mut ConstEvalCtx { + db, + owner: const_id.into(), + exprs: &body.exprs, + pats: &body.pats, + local_data: HashMap::default(), + infer, + }, + ); + result +} + +pub(crate) fn eval_to_const<'a>( expr: Idx, mode: ParamLoweringMode, - ctx: &mut InferenceContext, + ctx: &mut InferenceContext<'a>, args: impl FnOnce() -> Generics, debruijn: DebruijnIndex, ) -> Const { @@ -396,10 +465,15 @@ pub(crate) fn eval_to_const( } let body = ctx.body.clone(); let ctx = ConstEvalCtx { + db: ctx.db, + owner: ctx.owner, exprs: &body.exprs, pats: &body.pats, local_data: HashMap::default(), - infer: &mut |x| ctx.infer_expr(x, &Expectation::None), + infer: &ctx.result, }; usize_const(eval_usize(expr, ctx)) } + +#[cfg(test)] +mod tests; diff --git a/crates/hir_ty/src/consteval/tests.rs b/crates/hir_ty/src/consteval/tests.rs new file mode 100644 index 0000000000..4a052851af --- /dev/null +++ b/crates/hir_ty/src/consteval/tests.rs @@ -0,0 +1,148 @@ +use base_db::fixture::WithFixture; +use hir_def::{db::DefDatabase, expr::Literal}; + +use crate::{consteval::ComputedExpr, db::HirDatabase, test_db::TestDB}; + +use super::ConstEvalError; + +fn check_fail(ra_fixture: &str, error: ConstEvalError) { + assert_eq!(eval_goal(ra_fixture), Err(error)); +} + +fn check_number(ra_fixture: &str, answer: i128) { + let r = eval_goal(ra_fixture).unwrap(); + match r { + ComputedExpr::Literal(Literal::Int(r, _)) => assert_eq!(r, answer), + ComputedExpr::Literal(Literal::Uint(r, _)) => assert_eq!(r, answer as u128), + x => panic!("Expected number but found {:?}", x), + } +} + +fn eval_goal(ra_fixture: &str) -> Result { + let (db, file_id) = TestDB::with_single_file(ra_fixture); + let module_id = db.module_for_file(file_id); + let def_map = module_id.def_map(&db); + let scope = &def_map[module_id.local_id].scope; + let const_id = scope + .declarations() + .into_iter() + .find_map(|x| match x { + hir_def::ModuleDefId::ConstId(x) => { + if db.const_data(x).name.as_ref()?.to_string() == "GOAL" { + Some(x) + } else { + None + } + } + _ => None, + }) + .unwrap(); + db.const_eval(const_id) +} + +#[test] +fn add() { + check_number(r#"const GOAL: usize = 2 + 2;"#, 4); +} + +#[test] +fn bit_op() { + check_number(r#"const GOAL: u8 = !0 & !(!0 >> 1)"#, 128); + check_number(r#"const GOAL: i8 = !0 & !(!0 >> 1)"#, 0); + // FIXME: rustc evaluate this to -128 + check_fail( + r#"const GOAL: i8 = 1 << 7"#, + ConstEvalError::Panic("attempt to run invalid arithmetic operation".to_string()), + ); + check_fail( + r#"const GOAL: i8 = 1 << 8"#, + ConstEvalError::Panic("attempt to run invalid arithmetic operation".to_string()), + ); +} + +#[test] +fn locals() { + check_number( + r#" + const GOAL: usize = { + let a = 3 + 2; + let b = a * a; + b + }; + "#, + 25, + ); +} + +#[test] +fn consts() { + check_number( + r#" + const F1: i32 = 1; + const F3: i32 = 3 * F2; + const F2: i32 = 2 * F1; + const GOAL: i32 = F3; + "#, + 6, + ); +} + +#[test] +fn const_loop() { + check_fail( + r#" + const F1: i32 = 1 * F3; + const F3: i32 = 3 * F2; + const F2: i32 = 2 * F1; + const GOAL: i32 = F3; + "#, + ConstEvalError::Loop, + ); +} + +#[test] +fn const_impl_assoc() { + check_number( + r#" + struct U5; + impl U5 { + const VAL: usize = 5; + } + const GOAL: usize = U5::VAL; + "#, + 5, + ); +} + +#[test] +fn const_generic_subst() { + // FIXME: this should evaluate to 5 + check_fail( + r#" + struct Adder; + impl Adder { + const VAL: usize = N + M; + } + const GOAL: usize = Adder::<2, 3>::VAL; + "#, + ConstEvalError::NotSupported("const generic without substitution"), + ); +} + +#[test] +fn const_trait_assoc() { + // FIXME: this should evaluate to 0 + check_fail( + r#" + struct U0; + trait ToConst { + const VAL: usize; + } + impl ToConst for U0 { + const VAL: usize = 0; + } + const GOAL: usize = U0::VAL; + "#, + ConstEvalError::IncompleteExpr, + ); +} diff --git a/crates/hir_ty/src/db.rs b/crates/hir_ty/src/db.rs index 599fd16dd0..467dcfa33e 100644 --- a/crates/hir_ty/src/db.rs +++ b/crates/hir_ty/src/db.rs @@ -5,13 +5,14 @@ use std::sync::Arc; use base_db::{impl_intern_key, salsa, CrateId, Upcast}; use hir_def::{ - db::DefDatabase, expr::ExprId, BlockId, ConstParamId, DefWithBodyId, FunctionId, GenericDefId, - ImplId, LifetimeParamId, LocalFieldId, TypeOrConstParamId, VariantId, + db::DefDatabase, expr::ExprId, BlockId, ConstId, ConstParamId, DefWithBodyId, FunctionId, + GenericDefId, ImplId, LifetimeParamId, LocalFieldId, TypeOrConstParamId, VariantId, }; use la_arena::ArenaMap; use crate::{ chalk_db, + consteval::{ComputedExpr, ConstEvalError}, method_resolution::{InherentImpls, TraitImpls}, Binders, CallableDefId, FnDefId, GenericArg, ImplTraitId, InferenceResult, Interner, PolyFnSig, QuantifiedWhereClause, ReturnTypeImplTraits, TraitRef, Ty, TyDefId, ValueTyDefId, @@ -41,6 +42,10 @@ pub trait HirDatabase: DefDatabase + Upcast { #[salsa::invoke(crate::lower::const_param_ty_query)] fn const_param_ty(&self, def: ConstParamId) -> Ty; + #[salsa::invoke(crate::consteval::const_eval_query)] + #[salsa::cycle(crate::consteval::const_eval_recover)] + fn const_eval(&self, def: ConstId) -> Result; + #[salsa::invoke(crate::lower::impl_trait_query)] fn impl_trait(&self, def: ImplId) -> Option>; diff --git a/crates/hir_ty/src/infer.rs b/crates/hir_ty/src/infer.rs index e78d3f267f..9a6795a1c8 100644 --- a/crates/hir_ty/src/infer.rs +++ b/crates/hir_ty/src/infer.rs @@ -358,12 +358,12 @@ impl Index for InferenceResult { #[derive(Clone, Debug)] pub(crate) struct InferenceContext<'a> { pub(crate) db: &'a dyn HirDatabase, - owner: DefWithBodyId, + pub(crate) owner: DefWithBodyId, pub(crate) body: Arc, pub(crate) resolver: Resolver, table: unify::InferenceTable<'a>, trait_env: Arc, - result: InferenceResult, + pub(crate) result: InferenceResult, /// The return type of the function being inferred, the closure or async block if we're /// currently within one. /// diff --git a/crates/hir_ty/src/tests/simple.rs b/crates/hir_ty/src/tests/simple.rs index 31045c193c..675f9038f0 100644 --- a/crates/hir_ty/src/tests/simple.rs +++ b/crates/hir_ty/src/tests/simple.rs @@ -1749,6 +1749,18 @@ fn main() { ); } +#[test] +fn const_eval_array_repeat_expr() { + check_types( + r#" +fn main() { + const X: usize = 6 - 1; + let t = [(); X + 2]; + //^ [(); 7] +}"#, + ); +} + #[test] fn shadowing_primitive_with_inner_items() { check_types( diff --git a/crates/ide/src/hover/tests.rs b/crates/ide/src/hover/tests.rs index 96cd83b924..67dc9884ed 100644 --- a/crates/ide/src/hover/tests.rs +++ b/crates/ide/src/hover/tests.rs @@ -3401,10 +3401,11 @@ impl Foo {} #[test] fn hover_const_eval() { + // show hex for <10 check( r#" /// This is a doc -const FOO$0: usize = !0 & !(!0 >> 1); +const FOO$0: usize = 1 << 3; "#, expect![[r#" *FOO* @@ -3414,7 +3415,7 @@ const FOO$0: usize = !0 & !(!0 >> 1); ``` ```rust - const FOO: usize = 9223372036854775808 (0x8000000000000000) + const FOO: usize = 8 ``` --- @@ -3422,14 +3423,11 @@ const FOO$0: usize = !0 & !(!0 >> 1); This is a doc "#]], ); + // show hex for >10 check( r#" /// This is a doc -const FOO$0: usize = { - let a = 3 + 2; - let b = a * a; - b -}; +const FOO$0: usize = (1 << 3) + (1 << 2); "#, expect![[r#" *FOO* @@ -3439,53 +3437,7 @@ const FOO$0: usize = { ``` ```rust - const FOO: usize = 25 (0x19) - ``` - - --- - - This is a doc - "#]], - ); - check( - r#" -/// This is a doc -const FOO$0: usize = 1 << 10; -"#, - expect![[r#" - *FOO* - - ```rust - test - ``` - - ```rust - const FOO: usize = 1024 (0x400) - ``` - - --- - - This is a doc - "#]], - ); - check( - r#" -/// This is a doc -const FOO$0: usize = { - let b = 4; - let a = { let b = 2; let a = b; a } + { let a = 1; a + b }; - a -}; -"#, - expect![[r#" - *FOO* - - ```rust - test - ``` - - ```rust - const FOO: usize = 7 + const FOO: usize = 12 (0xC) ``` --- @@ -3493,6 +3445,7 @@ const FOO$0: usize = { This is a doc "#]], ); + // show original body when const eval fails check( r#" /// This is a doc @@ -3514,6 +3467,7 @@ const FOO$0: usize = 2 - 3; This is a doc "#]], ); + // don't show hex for negatives check( r#" /// This is a doc @@ -3538,27 +3492,6 @@ const FOO$0: i32 = 2 - 3; check( r#" /// This is a doc -const FOO$0: usize = 1 << 100; -"#, - expect![[r#" - *FOO* - - ```rust - test - ``` - - ```rust - const FOO: usize = 1 << 100 - ``` - - --- - - This is a doc - "#]], - ); - check( - r#" -/// This is a doc const FOO$0: &str = "bar"; "#, expect![[r#"