From 6d4989b3c754eebbc2b372545db14243b4fbbc44 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Fri, 19 Jul 2024 18:37:57 +0200 Subject: [PATCH] Make LRU opt-in --- crates/base-db/src/lib.rs | 2 + crates/hir-expand/src/db.rs | 2 +- crates/hir-ty/src/db.rs | 1 + crates/ide-db/src/lib.rs | 129 --- crates/salsa/salsa-macros/src/query_group.rs | 20 +- crates/salsa/src/derived.rs | 95 +-- crates/salsa/src/derived/slot.rs | 166 +--- crates/salsa/src/derived_lru.rs | 233 +++++ crates/salsa/src/derived_lru/slot.rs | 845 +++++++++++++++++++ crates/salsa/src/lib.rs | 1 + crates/salsa/src/plumbing.rs | 3 +- crates/salsa/tests/lru.rs | 2 + 12 files changed, 1158 insertions(+), 341 deletions(-) create mode 100644 crates/salsa/src/derived_lru.rs create mode 100644 crates/salsa/src/derived_lru/slot.rs diff --git a/crates/base-db/src/lib.rs b/crates/base-db/src/lib.rs index f1bcf93bc0..f319f98537 100644 --- a/crates/base-db/src/lib.rs +++ b/crates/base-db/src/lib.rs @@ -59,6 +59,7 @@ pub trait FileLoader { #[salsa::query_group(SourceDatabaseStorage)] pub trait SourceDatabase: FileLoader + std::fmt::Debug { /// Parses the file into the syntax tree. + #[salsa::lru] fn parse(&self, file_id: EditionedFileId) -> Parse; /// Returns the set of errors obtained from parsing the file including validation errors. @@ -105,6 +106,7 @@ pub trait SourceDatabaseExt: SourceDatabase { #[salsa::input] fn compressed_file_text(&self, file_id: FileId) -> Arc<[u8]>; + #[salsa::lru] fn file_text(&self, file_id: FileId) -> Arc; /// Path to a file, relative to the root of its source root. diff --git a/crates/hir-expand/src/db.rs b/crates/hir-expand/src/db.rs index 8be2dcf0de..5a499d2efd 100644 --- a/crates/hir-expand/src/db.rs +++ b/crates/hir-expand/src/db.rs @@ -65,7 +65,7 @@ pub trait ExpandDatabase: SourceDatabase { #[salsa::transparent] fn parse_or_expand_with_err(&self, file_id: HirFileId) -> ExpandResult>; /// Implementation for the macro case. - // This query is LRU cached + #[salsa::lru] fn parse_macro_expansion( &self, macro_file: MacroFileId, diff --git a/crates/hir-ty/src/db.rs b/crates/hir-ty/src/db.rs index 734aad4945..9a1f2158bf 100644 --- a/crates/hir-ty/src/db.rs +++ b/crates/hir-ty/src/db.rs @@ -61,6 +61,7 @@ pub trait HirDatabase: DefDatabase + Upcast { ) -> Result, MirLowerError>; #[salsa::invoke(crate::mir::borrowck_query)] + #[salsa::lru] fn borrowck(&self, def: DefWithBodyId) -> Result, MirLowerError>; #[salsa::invoke(crate::consteval::const_eval_query)] diff --git a/crates/ide-db/src/lib.rs b/crates/ide-db/src/lib.rs index f259698af0..322c4e9e5c 100644 --- a/crates/ide-db/src/lib.rs +++ b/crates/ide-db/src/lib.rs @@ -192,135 +192,6 @@ impl RootDatabase { .copied() .unwrap_or(base_db::DEFAULT_BORROWCK_LRU_CAP), ); - - macro_rules! update_lru_capacity_per_query { - ($( $module:ident :: $query:ident )*) => {$( - if let Some(&cap) = lru_capacities.get(stringify!($query)) { - $module::$query.in_db_mut(self).set_lru_capacity(cap); - } - )*} - } - update_lru_capacity_per_query![ - // SourceDatabase - // base_db::ParseQuery - // base_db::CrateGraphQuery - // base_db::ProcMacrosQuery - - // SourceDatabaseExt - base_db::FileTextQuery - // base_db::FileSourceRootQuery - // base_db::SourceRootQuery - base_db::SourceRootCratesQuery - - // ExpandDatabase - hir_db::AstIdMapQuery - // hir_db::ParseMacroExpansionQuery - // hir_db::InternMacroCallQuery - hir_db::MacroArgQuery - hir_db::DeclMacroExpanderQuery - // hir_db::MacroExpandQuery - hir_db::ExpandProcMacroQuery - hir_db::ParseMacroExpansionErrorQuery - - // DefDatabase - hir_db::FileItemTreeQuery - hir_db::BlockDefMapQuery - hir_db::StructDataWithDiagnosticsQuery - hir_db::UnionDataWithDiagnosticsQuery - hir_db::EnumDataQuery - hir_db::EnumVariantDataWithDiagnosticsQuery - hir_db::ImplDataWithDiagnosticsQuery - hir_db::TraitDataWithDiagnosticsQuery - hir_db::TraitAliasDataQuery - hir_db::TypeAliasDataQuery - hir_db::FunctionDataQuery - hir_db::ConstDataQuery - hir_db::StaticDataQuery - hir_db::Macro2DataQuery - hir_db::MacroRulesDataQuery - hir_db::ProcMacroDataQuery - hir_db::BodyWithSourceMapQuery - hir_db::BodyQuery - hir_db::ExprScopesQuery - hir_db::GenericParamsQuery - hir_db::FieldsAttrsQuery - hir_db::FieldsAttrsSourceMapQuery - hir_db::AttrsQuery - hir_db::CrateLangItemsQuery - hir_db::LangItemQuery - hir_db::ImportMapQuery - hir_db::FieldVisibilitiesQuery - hir_db::FunctionVisibilityQuery - hir_db::ConstVisibilityQuery - hir_db::CrateSupportsNoStdQuery - - // HirDatabase - hir_db::MirBodyQuery - hir_db::BorrowckQuery - hir_db::TyQuery - hir_db::ValueTyQuery - hir_db::ImplSelfTyQuery - hir_db::ConstParamTyQuery - hir_db::ConstEvalQuery - hir_db::ConstEvalDiscriminantQuery - hir_db::ImplTraitQuery - hir_db::FieldTypesQuery - hir_db::LayoutOfAdtQuery - hir_db::TargetDataLayoutQuery - hir_db::CallableItemSignatureQuery - hir_db::ReturnTypeImplTraitsQuery - hir_db::GenericPredicatesForParamQuery - hir_db::GenericPredicatesQuery - hir_db::TraitEnvironmentQuery - hir_db::GenericDefaultsQuery - hir_db::InherentImplsInCrateQuery - hir_db::InherentImplsInBlockQuery - hir_db::IncoherentInherentImplCratesQuery - hir_db::TraitImplsInCrateQuery - hir_db::TraitImplsInBlockQuery - hir_db::TraitImplsInDepsQuery - // hir_db::InternCallableDefQuery - // hir_db::InternLifetimeParamIdQuery - // hir_db::InternImplTraitIdQuery - // hir_db::InternTypeOrConstParamIdQuery - // hir_db::InternClosureQuery - // hir_db::InternCoroutineQuery - hir_db::AssociatedTyDataQuery - hir_db::TraitDatumQuery - hir_db::AdtDatumQuery - hir_db::ImplDatumQuery - hir_db::FnDefDatumQuery - hir_db::FnDefVarianceQuery - hir_db::AdtVarianceQuery - hir_db::AssociatedTyValueQuery - hir_db::ProgramClausesForChalkEnvQuery - - // SymbolsDatabase - symbol_index::ModuleSymbolsQuery - symbol_index::LibrarySymbolsQuery - // symbol_index::LocalRootsQuery - // symbol_index::LibraryRootsQuery - - // LineIndexDatabase - crate::LineIndexQuery - - // InternDatabase - // hir_db::InternFunctionQuery - // hir_db::InternStructQuery - // hir_db::InternUnionQuery - // hir_db::InternEnumQuery - // hir_db::InternConstQuery - // hir_db::InternStaticQuery - // hir_db::InternTraitQuery - // hir_db::InternTraitAliasQuery - // hir_db::InternTypeAliasQuery - // hir_db::InternImplQuery - // hir_db::InternExternBlockQuery - // hir_db::InternBlockQuery - // hir_db::InternMacro2Query - // hir_db::InternProcMacroQuery - // hir_db::InternMacroRulesQuery - ]; } } diff --git a/crates/salsa/salsa-macros/src/query_group.rs b/crates/salsa/salsa-macros/src/query_group.rs index 4e70741239..eeaf008a15 100644 --- a/crates/salsa/salsa-macros/src/query_group.rs +++ b/crates/salsa/salsa-macros/src/query_group.rs @@ -53,7 +53,11 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream num_storages += 1; } "dependencies" => { - storage = QueryStorage::Dependencies; + storage = QueryStorage::LruDependencies; + num_storages += 1; + } + "lru" => { + storage = QueryStorage::LruMemoized; num_storages += 1; } "input" => { @@ -235,7 +239,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream queries_with_storage.push(fn_name); - let tracing = if let QueryStorage::Memoized = query.storage { + let tracing = if let QueryStorage::Memoized | QueryStorage::LruMemoized = query.storage { let s = format!("{trait_name}::{fn_name}"); Some(quote! { let _p = tracing::debug_span!(#s, #(#key_names = tracing::field::debug(&#key_names)),*).entered(); @@ -376,8 +380,9 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream let storage = match &query.storage { QueryStorage::Memoized => quote!(salsa::plumbing::MemoizedStorage), - QueryStorage::Dependencies => { - quote!(salsa::plumbing::DependencyStorage) + QueryStorage::LruMemoized => quote!(salsa::plumbing::LruMemoizedStorage), + QueryStorage::LruDependencies => { + quote!(salsa::plumbing::LruDependencyStorage) } QueryStorage::Input if query.keys.is_empty() => { quote!(salsa::plumbing::UnitInputStorage) @@ -724,7 +729,8 @@ impl Query { #[derive(Debug, Clone, PartialEq, Eq)] enum QueryStorage { Memoized, - Dependencies, + LruDependencies, + LruMemoized, Input, Interned, InternedLookup { intern_query_type: Ident }, @@ -739,7 +745,9 @@ impl QueryStorage { | QueryStorage::Interned | QueryStorage::InternedLookup { .. } | QueryStorage::Transparent => false, - QueryStorage::Memoized | QueryStorage::Dependencies => true, + QueryStorage::Memoized | QueryStorage::LruMemoized | QueryStorage::LruDependencies => { + true + } } } } diff --git a/crates/salsa/src/derived.rs b/crates/salsa/src/derived.rs index bdb448e241..435baef98c 100644 --- a/crates/salsa/src/derived.rs +++ b/crates/salsa/src/derived.rs @@ -1,9 +1,7 @@ use crate::debug::TableEntry; use crate::durability::Durability; use crate::hash::FxIndexMap; -use crate::lru::Lru; use crate::plumbing::DerivedQueryStorageOps; -use crate::plumbing::LruQueryStorageOps; use crate::plumbing::QueryFunction; use crate::plumbing::QueryStorageMassOps; use crate::plumbing::QueryStorageOps; @@ -13,7 +11,6 @@ use crate::{Database, DatabaseKeyIndex, QueryDb, Revision}; use parking_lot::RwLock; use std::borrow::Borrow; use std::hash::Hash; -use std::marker::PhantomData; use triomphe::Arc; mod slot; @@ -22,79 +19,32 @@ use slot::Slot; /// Memoized queries store the result plus a list of the other queries /// that they invoked. This means we can avoid recomputing them when /// none of those inputs have changed. -pub type MemoizedStorage = DerivedStorage; - -/// "Dependency" queries just track their dependencies and not the -/// actual value (which they produce on demand). This lessens the -/// storage requirements. -pub type DependencyStorage = DerivedStorage; +pub type MemoizedStorage = DerivedStorage; /// Handles storage where the value is 'derived' by executing a /// function (in contrast to "inputs"). -pub struct DerivedStorage +pub struct DerivedStorage where Q: QueryFunction, - MP: MemoizationPolicy, { group_index: u16, - lru_list: Lru>, - slot_map: RwLock>>>, - policy: PhantomData, + slot_map: RwLock>>>, } -impl std::panic::RefUnwindSafe for DerivedStorage +impl std::panic::RefUnwindSafe for DerivedStorage where Q: QueryFunction, - MP: MemoizationPolicy, + Q::Key: std::panic::RefUnwindSafe, Q::Value: std::panic::RefUnwindSafe, { } -pub trait MemoizationPolicy: Send + Sync +impl DerivedStorage where Q: QueryFunction, { - fn should_memoize_value(key: &Q::Key) -> bool; - - fn memoized_value_eq(old_value: &Q::Value, new_value: &Q::Value) -> bool; -} - -pub enum AlwaysMemoizeValue {} -impl MemoizationPolicy for AlwaysMemoizeValue -where - Q: QueryFunction, - Q::Value: Eq, -{ - fn should_memoize_value(_key: &Q::Key) -> bool { - true - } - - fn memoized_value_eq(old_value: &Q::Value, new_value: &Q::Value) -> bool { - old_value == new_value - } -} - -pub enum NeverMemoizeValue {} -impl MemoizationPolicy for NeverMemoizeValue -where - Q: QueryFunction, -{ - fn should_memoize_value(_key: &Q::Key) -> bool { - false - } - - fn memoized_value_eq(_old_value: &Q::Value, _new_value: &Q::Value) -> bool { - panic!("cannot reach since we never memoize") - } -} - -impl DerivedStorage -where - Q: QueryFunction, - MP: MemoizationPolicy, -{ - fn slot(&self, key: &Q::Key) -> Arc> { + fn slot(&self, key: &Q::Key) -> Arc> { if let Some(v) = self.slot_map.read().get(key) { return v.clone(); } @@ -111,20 +61,14 @@ where } } -impl QueryStorageOps for DerivedStorage +impl QueryStorageOps for DerivedStorage where Q: QueryFunction, - MP: MemoizationPolicy, { const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = Q::CYCLE_STRATEGY; fn new(group_index: u16) -> Self { - DerivedStorage { - group_index, - slot_map: RwLock::new(FxIndexMap::default()), - lru_list: Default::default(), - policy: PhantomData, - } + DerivedStorage { group_index, slot_map: RwLock::new(FxIndexMap::default()) } } fn fmt_index( @@ -161,10 +105,6 @@ where let slot = self.slot(key); let StampedValue { value, durability, changed_at } = slot.read(db, key); - if let Some(evicted) = self.lru_list.record_use(&slot) { - evicted.evict(); - } - db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted( slot.database_key_index(), durability, @@ -187,31 +127,18 @@ where } } -impl QueryStorageMassOps for DerivedStorage +impl QueryStorageMassOps for DerivedStorage where Q: QueryFunction, - MP: MemoizationPolicy, { fn purge(&self) { - self.lru_list.purge(); *self.slot_map.write() = Default::default(); } } -impl LruQueryStorageOps for DerivedStorage +impl DerivedQueryStorageOps for DerivedStorage where Q: QueryFunction, - MP: MemoizationPolicy, -{ - fn set_lru_capacity(&self, new_capacity: u16) { - self.lru_list.set_lru_capacity(new_capacity); - } -} - -impl DerivedQueryStorageOps for DerivedStorage -where - Q: QueryFunction, - MP: MemoizationPolicy, { fn invalidate(&self, runtime: &mut Runtime, key: &S) where diff --git a/crates/salsa/src/derived/slot.rs b/crates/salsa/src/derived/slot.rs index 5c21fd8d97..b5c3d9f4f2 100644 --- a/crates/salsa/src/derived/slot.rs +++ b/crates/salsa/src/derived/slot.rs @@ -1,8 +1,5 @@ use crate::debug::TableEntry; -use crate::derived::MemoizationPolicy; use crate::durability::Durability; -use crate::lru::LruIndex; -use crate::lru::LruNode; use crate::plumbing::{DatabaseOps, QueryFunction}; use crate::revision::Revision; use crate::runtime::local_state::ActiveQueryGuard; @@ -14,21 +11,18 @@ use crate::runtime::WaitResult; use crate::Cycle; use crate::{Database, DatabaseKeyIndex, Event, EventKind, QueryDb}; use parking_lot::{RawRwLock, RwLock}; -use std::marker::PhantomData; use std::ops::Deref; use std::sync::atomic::{AtomicBool, Ordering}; use tracing::{debug, info}; -pub(super) struct Slot +pub(super) struct Slot where Q: QueryFunction, - MP: MemoizationPolicy, { key_index: u32, + // FIXME: Yeet this group_index: u16, state: RwLock>, - lru_index: LruIndex, - policy: PhantomData, } /// Defines the "current state" of query's memoized results. @@ -54,7 +48,7 @@ where struct Memo { /// The result of the query, if we decide to memoize it. - value: Option, + value: V, /// Last revision when this memo was verified; this begins /// as the current revision. @@ -77,12 +71,6 @@ enum ProbeState { /// verified in this revision. Stale(G), - /// There is an entry, and it has been verified - /// in this revision, but it has no cached - /// value. The `Revision` is the revision where the - /// value last changed (if we were to recompute it). - NoValue(G, Revision), - /// There is an entry which has been verified, /// and it has the following value-- or, we blocked /// on another thread, and that resulted in a cycle. @@ -103,18 +91,15 @@ enum MaybeChangedSinceProbeState { Stale(G), } -impl Slot +impl Slot where Q: QueryFunction, - MP: MemoizationPolicy, { pub(super) fn new(database_key_index: DatabaseKeyIndex) -> Self { Self { key_index: database_key_index.key_index, group_index: database_key_index.group_index, state: RwLock::new(QueryState::NotComputed), - lru_index: LruIndex::default(), - policy: PhantomData, } } @@ -146,9 +131,7 @@ where loop { match self.probe(db, self.state.read(), runtime, revision_now) { ProbeState::UpToDate(v) => return v, - ProbeState::Stale(..) | ProbeState::NoValue(..) | ProbeState::NotComputed(..) => { - break - } + ProbeState::Stale(..) | ProbeState::NotComputed(..) => break, ProbeState::Retry => continue, } } @@ -176,9 +159,7 @@ where let mut old_memo = loop { match self.probe(db, self.state.upgradable_read(), runtime, revision_now) { ProbeState::UpToDate(v) => return v, - ProbeState::Stale(state) - | ProbeState::NotComputed(state) - | ProbeState::NoValue(state, _) => { + ProbeState::Stale(state) | ProbeState::NotComputed(state) => { type RwLockUpgradableReadGuard<'a, T> = lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>; @@ -226,7 +207,7 @@ where runtime: &Runtime, revision_now: Revision, active_query: ActiveQueryGuard<'_>, - panic_guard: PanicGuard<'_, Q, MP>, + panic_guard: PanicGuard<'_, Q>, old_memo: Option>, key: &Q::Key, ) -> StampedValue { @@ -285,22 +266,18 @@ where // "backdate" its `changed_at` revision to be the same as the // old value. if let Some(old_memo) = &old_memo { - if let Some(old_value) = &old_memo.value { - // Careful: if the value became less durable than it - // used to be, that is a "breaking change" that our - // consumers must be aware of. Becoming *more* durable - // is not. See the test `constant_to_non_constant`. - if revisions.durability >= old_memo.revisions.durability - && MP::memoized_value_eq(old_value, &value) - { - debug!( - "read_upgrade({:?}): value is equal, back-dating to {:?}", - self, old_memo.revisions.changed_at, - ); + // Careful: if the value became less durable than it + // used to be, that is a "breaking change" that our + // consumers must be aware of. Becoming *more* durable + // is not. See the test `constant_to_non_constant`. + if revisions.durability >= old_memo.revisions.durability { + debug!( + "read_upgrade({:?}): value is equal, back-dating to {:?}", + self, old_memo.revisions.changed_at, + ); - assert!(old_memo.revisions.changed_at <= revisions.changed_at); - revisions.changed_at = old_memo.revisions.changed_at; - } + assert!(old_memo.revisions.changed_at <= revisions.changed_at); + revisions.changed_at = old_memo.revisions.changed_at; } } @@ -310,8 +287,7 @@ where changed_at: revisions.changed_at, }; - let memo_value = - if self.should_memoize_value(key) { Some(new_value.value.clone()) } else { None }; + let memo_value = new_value.value.clone(); debug!("read_upgrade({:?}): result.revisions = {:#?}", self, revisions,); @@ -371,20 +347,16 @@ where return ProbeState::Stale(state); } - if let Some(value) = &memo.value { - let value = StampedValue { - durability: memo.revisions.durability, - changed_at: memo.revisions.changed_at, - value: value.clone(), - }; + let value = &memo.value; + let value = StampedValue { + durability: memo.revisions.durability, + changed_at: memo.revisions.changed_at, + value: value.clone(), + }; - info!("{:?}: returning memoized value changed at {:?}", self, value.changed_at); + info!("{:?}: returning memoized value changed at {:?}", self, value.changed_at); - ProbeState::UpToDate(value) - } else { - let changed_at = memo.revisions.changed_at; - ProbeState::NoValue(state, changed_at) - } + ProbeState::UpToDate(value) } } } @@ -407,21 +379,9 @@ where match &*self.state.read() { QueryState::NotComputed => None, QueryState::InProgress { .. } => Some(TableEntry::new(key.clone(), None)), - QueryState::Memoized(memo) => Some(TableEntry::new(key.clone(), memo.value.clone())), - } - } - - pub(super) fn evict(&self) { - let mut state = self.state.write(); - if let QueryState::Memoized(memo) = &mut *state { - // Evicting a value with an untracked input could - // lead to inconsistencies. Note that we can't check - // `has_untracked_input` when we add the value to the cache, - // because inputs can become untracked in the next revision. - if memo.has_untracked_input() { - return; + QueryState::Memoized(memo) => { + Some(TableEntry::new(key.clone(), Some(memo.value.clone()))) } - memo.value = None; } } @@ -489,8 +449,7 @@ where // If we know when value last changed, we can return right away. // Note that we don't need the actual value to be available. - ProbeState::NoValue(_, changed_at) - | ProbeState::UpToDate(StampedValue { value: _, durability: _, changed_at }) => { + ProbeState::UpToDate(StampedValue { value: _, durability: _, changed_at }) => { MaybeChangedSinceProbeState::ChangedAt(changed_at) } @@ -545,7 +504,7 @@ where let maybe_changed = old_memo.revisions.changed_at > revision; panic_guard.proceed(Some(old_memo)); maybe_changed - } else if old_memo.value.is_some() { + } else { // We found that this memoized value may have changed // but we have an old value. We can re-run the code and // actually *check* if it has changed. @@ -559,12 +518,6 @@ where key, ); changed_at > revision - } else { - // We found that inputs to this memoized value may have chanced - // but we don't have an old value to compare against or re-use. - // No choice but to drop the memo and say that its value may have changed. - panic_guard.proceed(None); - true } } @@ -583,10 +536,6 @@ where mutex_guard, ) } - - fn should_memoize_value(&self, key: &Q::Key) -> bool { - MP::should_memoize_value(key) - } } impl QueryState @@ -598,21 +547,19 @@ where } } -struct PanicGuard<'me, Q, MP> +struct PanicGuard<'me, Q> where Q: QueryFunction, - MP: MemoizationPolicy, { - slot: &'me Slot, + slot: &'me Slot, runtime: &'me Runtime, } -impl<'me, Q, MP> PanicGuard<'me, Q, MP> +impl<'me, Q> PanicGuard<'me, Q> where Q: QueryFunction, - MP: MemoizationPolicy, { - fn new(slot: &'me Slot, runtime: &'me Runtime) -> Self { + fn new(slot: &'me Slot, runtime: &'me Runtime) -> Self { Self { slot, runtime } } @@ -666,10 +613,9 @@ Please report this bug to https://github.com/salsa-rs/salsa/issues." } } -impl<'me, Q, MP> Drop for PanicGuard<'me, Q, MP> +impl<'me, Q> Drop for PanicGuard<'me, Q> where Q: QueryFunction, - MP: MemoizationPolicy, { fn drop(&mut self) { if std::thread::panicking() { @@ -702,15 +648,11 @@ where revision_now: Revision, active_query: &ActiveQueryGuard<'_>, ) -> Option> { - // If we don't have a memoized value, nothing to validate. - if self.value.is_none() { - return None; - } if self.verify_revisions(db, revision_now, active_query) { - self.value.clone().map(|value| StampedValue { + Some(StampedValue { durability: self.revisions.durability, changed_at: self.revisions.changed_at, - value, + value: self.value.clone(), }) } else { None @@ -788,58 +730,42 @@ where self.verified_at = revision_now; true } - - fn has_untracked_input(&self) -> bool { - self.revisions.untracked - } } -impl std::fmt::Debug for Slot +impl std::fmt::Debug for Slot where Q: QueryFunction, - MP: MemoizationPolicy, { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(fmt, "{:?}", Q::default()) } } -impl LruNode for Slot -where - Q: QueryFunction, - MP: MemoizationPolicy, -{ - fn lru_index(&self) -> &LruIndex { - &self.lru_index - } -} - -/// Check that `Slot: Send + Sync` as long as +/// Check that `Slot: Send + Sync` as long as /// `DB::DatabaseData: Send + Sync`, which in turn implies that /// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`. #[allow(dead_code)] -fn check_send_sync() +fn check_send_sync() where Q: QueryFunction, - MP: MemoizationPolicy, + Q::Key: Send + Sync, Q::Value: Send + Sync, { fn is_send_sync() {} - is_send_sync::>(); + is_send_sync::>(); } -/// Check that `Slot: 'static` as long as +/// Check that `Slot: 'static` as long as /// `DB::DatabaseData: 'static`, which in turn implies that /// `Q::Key: 'static`, `Q::Value: 'static`. #[allow(dead_code)] -fn check_static() +fn check_static() where Q: QueryFunction + 'static, - MP: MemoizationPolicy + 'static, Q::Key: 'static, Q::Value: 'static, { fn is_static() {} - is_static::>(); + is_static::>(); } diff --git a/crates/salsa/src/derived_lru.rs b/crates/salsa/src/derived_lru.rs new file mode 100644 index 0000000000..bdb448e241 --- /dev/null +++ b/crates/salsa/src/derived_lru.rs @@ -0,0 +1,233 @@ +use crate::debug::TableEntry; +use crate::durability::Durability; +use crate::hash::FxIndexMap; +use crate::lru::Lru; +use crate::plumbing::DerivedQueryStorageOps; +use crate::plumbing::LruQueryStorageOps; +use crate::plumbing::QueryFunction; +use crate::plumbing::QueryStorageMassOps; +use crate::plumbing::QueryStorageOps; +use crate::runtime::StampedValue; +use crate::Runtime; +use crate::{Database, DatabaseKeyIndex, QueryDb, Revision}; +use parking_lot::RwLock; +use std::borrow::Borrow; +use std::hash::Hash; +use std::marker::PhantomData; +use triomphe::Arc; + +mod slot; +use slot::Slot; + +/// Memoized queries store the result plus a list of the other queries +/// that they invoked. This means we can avoid recomputing them when +/// none of those inputs have changed. +pub type MemoizedStorage = DerivedStorage; + +/// "Dependency" queries just track their dependencies and not the +/// actual value (which they produce on demand). This lessens the +/// storage requirements. +pub type DependencyStorage = DerivedStorage; + +/// Handles storage where the value is 'derived' by executing a +/// function (in contrast to "inputs"). +pub struct DerivedStorage +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + group_index: u16, + lru_list: Lru>, + slot_map: RwLock>>>, + policy: PhantomData, +} + +impl std::panic::RefUnwindSafe for DerivedStorage +where + Q: QueryFunction, + MP: MemoizationPolicy, + Q::Key: std::panic::RefUnwindSafe, + Q::Value: std::panic::RefUnwindSafe, +{ +} + +pub trait MemoizationPolicy: Send + Sync +where + Q: QueryFunction, +{ + fn should_memoize_value(key: &Q::Key) -> bool; + + fn memoized_value_eq(old_value: &Q::Value, new_value: &Q::Value) -> bool; +} + +pub enum AlwaysMemoizeValue {} +impl MemoizationPolicy for AlwaysMemoizeValue +where + Q: QueryFunction, + Q::Value: Eq, +{ + fn should_memoize_value(_key: &Q::Key) -> bool { + true + } + + fn memoized_value_eq(old_value: &Q::Value, new_value: &Q::Value) -> bool { + old_value == new_value + } +} + +pub enum NeverMemoizeValue {} +impl MemoizationPolicy for NeverMemoizeValue +where + Q: QueryFunction, +{ + fn should_memoize_value(_key: &Q::Key) -> bool { + false + } + + fn memoized_value_eq(_old_value: &Q::Value, _new_value: &Q::Value) -> bool { + panic!("cannot reach since we never memoize") + } +} + +impl DerivedStorage +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn slot(&self, key: &Q::Key) -> Arc> { + if let Some(v) = self.slot_map.read().get(key) { + return v.clone(); + } + + let mut write = self.slot_map.write(); + let entry = write.entry(key.clone()); + let key_index = entry.index() as u32; + let database_key_index = DatabaseKeyIndex { + group_index: self.group_index, + query_index: Q::QUERY_INDEX, + key_index, + }; + entry.or_insert_with(|| Arc::new(Slot::new(database_key_index))).clone() + } +} + +impl QueryStorageOps for DerivedStorage +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = Q::CYCLE_STRATEGY; + + fn new(group_index: u16) -> Self { + DerivedStorage { + group_index, + slot_map: RwLock::new(FxIndexMap::default()), + lru_list: Default::default(), + policy: PhantomData, + } + } + + fn fmt_index( + &self, + _db: &>::DynDb, + index: u32, + fmt: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + let slot_map = self.slot_map.read(); + let key = slot_map.get_index(index as usize).unwrap().0; + write!(fmt, "{}::{}({:?})", std::any::type_name::(), Q::QUERY_NAME, key) + } + + fn maybe_changed_after( + &self, + db: &>::DynDb, + index: u32, + revision: Revision, + ) -> bool { + debug_assert!(revision < db.salsa_runtime().current_revision()); + let (key, slot) = { + let read = self.slot_map.read(); + let Some((key, slot)) = read.get_index(index as usize) else { + return false; + }; + (key.clone(), slot.clone()) + }; + slot.maybe_changed_after(db, revision, &key) + } + + fn fetch(&self, db: &>::DynDb, key: &Q::Key) -> Q::Value { + db.unwind_if_cancelled(); + + let slot = self.slot(key); + let StampedValue { value, durability, changed_at } = slot.read(db, key); + + if let Some(evicted) = self.lru_list.record_use(&slot) { + evicted.evict(); + } + + db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted( + slot.database_key_index(), + durability, + changed_at, + ); + + value + } + + fn durability(&self, db: &>::DynDb, key: &Q::Key) -> Durability { + self.slot(key).durability(db) + } + + fn entries(&self, _db: &>::DynDb) -> C + where + C: std::iter::FromIterator>, + { + let slot_map = self.slot_map.read(); + slot_map.iter().filter_map(|(key, slot)| slot.as_table_entry(key)).collect() + } +} + +impl QueryStorageMassOps for DerivedStorage +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn purge(&self) { + self.lru_list.purge(); + *self.slot_map.write() = Default::default(); + } +} + +impl LruQueryStorageOps for DerivedStorage +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn set_lru_capacity(&self, new_capacity: u16) { + self.lru_list.set_lru_capacity(new_capacity); + } +} + +impl DerivedQueryStorageOps for DerivedStorage +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn invalidate(&self, runtime: &mut Runtime, key: &S) + where + S: Eq + Hash, + Q::Key: Borrow, + { + runtime.with_incremented_revision(|new_revision| { + let map_read = self.slot_map.read(); + + if let Some(slot) = map_read.get(key) { + if let Some(durability) = slot.invalidate(new_revision) { + return Some(durability); + } + } + + None + }) + } +} diff --git a/crates/salsa/src/derived_lru/slot.rs b/crates/salsa/src/derived_lru/slot.rs new file mode 100644 index 0000000000..d0e4b5422b --- /dev/null +++ b/crates/salsa/src/derived_lru/slot.rs @@ -0,0 +1,845 @@ +use crate::debug::TableEntry; +use crate::derived_lru::MemoizationPolicy; +use crate::durability::Durability; +use crate::lru::LruIndex; +use crate::lru::LruNode; +use crate::plumbing::{DatabaseOps, QueryFunction}; +use crate::revision::Revision; +use crate::runtime::local_state::ActiveQueryGuard; +use crate::runtime::local_state::QueryRevisions; +use crate::runtime::Runtime; +use crate::runtime::RuntimeId; +use crate::runtime::StampedValue; +use crate::runtime::WaitResult; +use crate::Cycle; +use crate::{Database, DatabaseKeyIndex, Event, EventKind, QueryDb}; +use parking_lot::{RawRwLock, RwLock}; +use std::marker::PhantomData; +use std::ops::Deref; +use std::sync::atomic::{AtomicBool, Ordering}; +use tracing::{debug, info}; + +pub(super) struct Slot +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + key_index: u32, + group_index: u16, + state: RwLock>, + lru_index: LruIndex, + policy: PhantomData, +} + +/// Defines the "current state" of query's memoized results. +enum QueryState +where + Q: QueryFunction, +{ + NotComputed, + + /// The runtime with the given id is currently computing the + /// result of this query. + InProgress { + id: RuntimeId, + + /// Set to true if any other queries are blocked, + /// waiting for this query to complete. + anyone_waiting: AtomicBool, + }, + + /// We have computed the query already, and here is the result. + Memoized(Memo), +} + +struct Memo { + /// The result of the query, if we decide to memoize it. + value: Option, + + /// Last revision when this memo was verified; this begins + /// as the current revision. + pub(crate) verified_at: Revision, + + /// Revision information + revisions: QueryRevisions, +} + +/// Return value of `probe` helper. +enum ProbeState { + /// Another thread was active but has completed. + /// Try again! + Retry, + + /// No entry for this key at all. + NotComputed(G), + + /// There is an entry, but its contents have not been + /// verified in this revision. + Stale(G), + + /// There is an entry, and it has been verified + /// in this revision, but it has no cached + /// value. The `Revision` is the revision where the + /// value last changed (if we were to recompute it). + NoValue(G, Revision), + + /// There is an entry which has been verified, + /// and it has the following value-- or, we blocked + /// on another thread, and that resulted in a cycle. + UpToDate(V), +} + +/// Return value of `maybe_changed_after_probe` helper. +enum MaybeChangedSinceProbeState { + /// Another thread was active but has completed. + /// Try again! + Retry, + + /// Value may have changed in the given revision. + ChangedAt(Revision), + + /// There is a stale cache entry that has not been + /// verified in this revision, so we can't say. + Stale(G), +} + +impl Slot +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + pub(super) fn new(database_key_index: DatabaseKeyIndex) -> Self { + Self { + key_index: database_key_index.key_index, + group_index: database_key_index.group_index, + state: RwLock::new(QueryState::NotComputed), + lru_index: LruIndex::default(), + policy: PhantomData, + } + } + + pub(super) fn database_key_index(&self) -> DatabaseKeyIndex { + DatabaseKeyIndex { + group_index: self.group_index, + query_index: Q::QUERY_INDEX, + key_index: self.key_index, + } + } + + pub(super) fn read( + &self, + db: &>::DynDb, + key: &Q::Key, + ) -> StampedValue { + let runtime = db.salsa_runtime(); + + // NB: We don't need to worry about people modifying the + // revision out from under our feet. Either `db` is a frozen + // database, in which case there is a lock, or the mutator + // thread is the current thread, and it will be prevented from + // doing any `set` invocations while the query function runs. + let revision_now = runtime.current_revision(); + + info!("{:?}: invoked at {:?}", self, revision_now,); + + // First, do a check with a read-lock. + loop { + match self.probe(db, self.state.read(), runtime, revision_now) { + ProbeState::UpToDate(v) => return v, + ProbeState::Stale(..) | ProbeState::NoValue(..) | ProbeState::NotComputed(..) => { + break + } + ProbeState::Retry => continue, + } + } + + self.read_upgrade(db, key, revision_now) + } + + /// Second phase of a read operation: acquires an upgradable-read + /// and -- if needed -- validates whether inputs have changed, + /// recomputes value, etc. This is invoked after our initial probe + /// shows a potentially out of date value. + fn read_upgrade( + &self, + db: &>::DynDb, + key: &Q::Key, + revision_now: Revision, + ) -> StampedValue { + let runtime = db.salsa_runtime(); + + debug!("{:?}: read_upgrade(revision_now={:?})", self, revision_now,); + + // Check with an upgradable read to see if there is a value + // already. (This permits other readers but prevents anyone + // else from running `read_upgrade` at the same time.) + let mut old_memo = loop { + match self.probe(db, self.state.upgradable_read(), runtime, revision_now) { + ProbeState::UpToDate(v) => return v, + ProbeState::Stale(state) + | ProbeState::NotComputed(state) + | ProbeState::NoValue(state, _) => { + type RwLockUpgradableReadGuard<'a, T> = + lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>; + + let mut state = RwLockUpgradableReadGuard::upgrade(state); + match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) { + QueryState::Memoized(old_memo) => break Some(old_memo), + QueryState::InProgress { .. } => unreachable!(), + QueryState::NotComputed => break None, + } + } + ProbeState::Retry => continue, + } + }; + + let panic_guard = PanicGuard::new(self, runtime); + let active_query = runtime.push_query(self.database_key_index()); + + // If we have an old-value, it *may* now be stale, since there + // has been a new revision since the last time we checked. So, + // first things first, let's walk over each of our previous + // inputs and check whether they are out of date. + if let Some(memo) = &mut old_memo { + if let Some(value) = memo.verify_value(db.ops_database(), revision_now, &active_query) { + info!("{:?}: validated old memoized value", self,); + + db.salsa_event(Event { + runtime_id: runtime.id(), + kind: EventKind::DidValidateMemoizedValue { + database_key: self.database_key_index(), + }, + }); + + panic_guard.proceed(old_memo); + + return value; + } + } + + self.execute(db, runtime, revision_now, active_query, panic_guard, old_memo, key) + } + + fn execute( + &self, + db: &>::DynDb, + runtime: &Runtime, + revision_now: Revision, + active_query: ActiveQueryGuard<'_>, + panic_guard: PanicGuard<'_, Q, MP>, + old_memo: Option>, + key: &Q::Key, + ) -> StampedValue { + tracing::info!("{:?}: executing query", self.database_key_index().debug(db)); + + db.salsa_event(Event { + runtime_id: db.salsa_runtime().id(), + kind: EventKind::WillExecute { database_key: self.database_key_index() }, + }); + + // Query was not previously executed, or value is potentially + // stale, or value is absent. Let's execute! + let value = match Cycle::catch(|| Q::execute(db, key.clone())) { + Ok(v) => v, + Err(cycle) => { + tracing::debug!( + "{:?}: caught cycle {:?}, have strategy {:?}", + self.database_key_index().debug(db), + cycle, + Q::CYCLE_STRATEGY, + ); + match Q::CYCLE_STRATEGY { + crate::plumbing::CycleRecoveryStrategy::Panic => { + panic_guard.proceed(None); + cycle.throw() + } + crate::plumbing::CycleRecoveryStrategy::Fallback => { + if let Some(c) = active_query.take_cycle() { + assert!(c.is(&cycle)); + Q::cycle_fallback(db, &cycle, key) + } else { + // we are not a participant in this cycle + debug_assert!(!cycle + .participant_keys() + .any(|k| k == self.database_key_index())); + cycle.throw() + } + } + } + } + }; + + let mut revisions = active_query.pop(); + + // We assume that query is side-effect free -- that is, does + // not mutate the "inputs" to the query system. Sanity check + // that assumption here, at least to the best of our ability. + assert_eq!( + runtime.current_revision(), + revision_now, + "revision altered during query execution", + ); + + // If the new value is equal to the old one, then it didn't + // really change, even if some of its inputs have. So we can + // "backdate" its `changed_at` revision to be the same as the + // old value. + if let Some(old_memo) = &old_memo { + if let Some(old_value) = &old_memo.value { + // Careful: if the value became less durable than it + // used to be, that is a "breaking change" that our + // consumers must be aware of. Becoming *more* durable + // is not. See the test `constant_to_non_constant`. + if revisions.durability >= old_memo.revisions.durability + && MP::memoized_value_eq(old_value, &value) + { + debug!( + "read_upgrade({:?}): value is equal, back-dating to {:?}", + self, old_memo.revisions.changed_at, + ); + + assert!(old_memo.revisions.changed_at <= revisions.changed_at); + revisions.changed_at = old_memo.revisions.changed_at; + } + } + } + + let new_value = StampedValue { + value, + durability: revisions.durability, + changed_at: revisions.changed_at, + }; + + let memo_value = + if self.should_memoize_value(key) { Some(new_value.value.clone()) } else { None }; + + debug!("read_upgrade({:?}): result.revisions = {:#?}", self, revisions,); + + panic_guard.proceed(Some(Memo { value: memo_value, verified_at: revision_now, revisions })); + + new_value + } + + /// Helper for `read` that does a shallow check (not recursive) if we have an up-to-date value. + /// + /// Invoked with the guard `state` corresponding to the `QueryState` of some `Slot` (the guard + /// can be either read or write). Returns a suitable `ProbeState`: + /// + /// - `ProbeState::UpToDate(r)` if the table has an up-to-date value (or we blocked on another + /// thread that produced such a value). + /// - `ProbeState::StaleOrAbsent(g)` if either (a) there is no memo for this key, (b) the memo + /// has no value; or (c) the memo has not been verified at the current revision. + /// + /// Note that in case `ProbeState::UpToDate`, the lock will have been released. + fn probe( + &self, + db: &>::DynDb, + state: StateGuard, + runtime: &Runtime, + revision_now: Revision, + ) -> ProbeState, StateGuard> + where + StateGuard: Deref>, + { + match &*state { + QueryState::NotComputed => ProbeState::NotComputed(state), + + QueryState::InProgress { id, anyone_waiting } => { + let other_id = *id; + + // NB: `Ordering::Relaxed` is sufficient here, + // as there are no loads that are "gated" on this + // value. Everything that is written is also protected + // by a lock that must be acquired. The role of this + // boolean is to decide *whether* to acquire the lock, + // not to gate future atomic reads. + anyone_waiting.store(true, Ordering::Relaxed); + + self.block_on_or_unwind(db, runtime, other_id, state); + + // Other thread completely normally, so our value may be available now. + ProbeState::Retry + } + + QueryState::Memoized(memo) => { + debug!( + "{:?}: found memoized value, verified_at={:?}, changed_at={:?}", + self, memo.verified_at, memo.revisions.changed_at, + ); + + if memo.verified_at < revision_now { + return ProbeState::Stale(state); + } + + if let Some(value) = &memo.value { + let value = StampedValue { + durability: memo.revisions.durability, + changed_at: memo.revisions.changed_at, + value: value.clone(), + }; + + info!("{:?}: returning memoized value changed at {:?}", self, value.changed_at); + + ProbeState::UpToDate(value) + } else { + let changed_at = memo.revisions.changed_at; + ProbeState::NoValue(state, changed_at) + } + } + } + } + + pub(super) fn durability(&self, db: &>::DynDb) -> Durability { + match &*self.state.read() { + QueryState::NotComputed => Durability::LOW, + QueryState::InProgress { .. } => panic!("query in progress"), + QueryState::Memoized(memo) => { + if memo.check_durability(db.salsa_runtime()) { + memo.revisions.durability + } else { + Durability::LOW + } + } + } + } + + pub(super) fn as_table_entry(&self, key: &Q::Key) -> Option> { + match &*self.state.read() { + QueryState::NotComputed => None, + QueryState::InProgress { .. } => Some(TableEntry::new(key.clone(), None)), + QueryState::Memoized(memo) => Some(TableEntry::new(key.clone(), memo.value.clone())), + } + } + + pub(super) fn evict(&self) { + let mut state = self.state.write(); + if let QueryState::Memoized(memo) = &mut *state { + // Evicting a value with an untracked input could + // lead to inconsistencies. Note that we can't check + // `has_untracked_input` when we add the value to the cache, + // because inputs can become untracked in the next revision. + if memo.has_untracked_input() { + return; + } + memo.value = None; + } + } + + pub(super) fn invalidate(&self, new_revision: Revision) -> Option { + tracing::debug!("Slot::invalidate(new_revision = {:?})", new_revision); + match &mut *self.state.write() { + QueryState::Memoized(memo) => { + memo.revisions.untracked = true; + memo.revisions.inputs = None; + memo.revisions.changed_at = new_revision; + Some(memo.revisions.durability) + } + QueryState::NotComputed => None, + QueryState::InProgress { .. } => unreachable!(), + } + } + + pub(super) fn maybe_changed_after( + &self, + db: &>::DynDb, + revision: Revision, + key: &Q::Key, + ) -> bool { + let runtime = db.salsa_runtime(); + let revision_now = runtime.current_revision(); + + db.unwind_if_cancelled(); + + debug!( + "maybe_changed_after({:?}) called with revision={:?}, revision_now={:?}", + self, revision, revision_now, + ); + + // Do an initial probe with just the read-lock. + // + // If we find that a cache entry for the value is present + // but hasn't been verified in this revision, we'll have to + // do more. + loop { + match self.maybe_changed_after_probe(db, self.state.read(), runtime, revision_now) { + MaybeChangedSinceProbeState::Retry => continue, + MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision, + MaybeChangedSinceProbeState::Stale(state) => { + drop(state); + return self.maybe_changed_after_upgrade(db, revision, key); + } + } + } + } + + fn maybe_changed_after_probe( + &self, + db: &>::DynDb, + state: StateGuard, + runtime: &Runtime, + revision_now: Revision, + ) -> MaybeChangedSinceProbeState + where + StateGuard: Deref>, + { + match self.probe(db, state, runtime, revision_now) { + ProbeState::Retry => MaybeChangedSinceProbeState::Retry, + + ProbeState::Stale(state) => MaybeChangedSinceProbeState::Stale(state), + + // If we know when value last changed, we can return right away. + // Note that we don't need the actual value to be available. + ProbeState::NoValue(_, changed_at) + | ProbeState::UpToDate(StampedValue { value: _, durability: _, changed_at }) => { + MaybeChangedSinceProbeState::ChangedAt(changed_at) + } + + // If we have nothing cached, then value may have changed. + ProbeState::NotComputed(_) => MaybeChangedSinceProbeState::ChangedAt(revision_now), + } + } + + fn maybe_changed_after_upgrade( + &self, + db: &>::DynDb, + revision: Revision, + key: &Q::Key, + ) -> bool { + let runtime = db.salsa_runtime(); + let revision_now = runtime.current_revision(); + + // Get an upgradable read lock, which permits other reads but no writers. + // Probe again. If the value is stale (needs to be verified), then upgrade + // to a write lock and swap it with InProgress while we work. + let mut old_memo = match self.maybe_changed_after_probe( + db, + self.state.upgradable_read(), + runtime, + revision_now, + ) { + MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision, + + // If another thread was active, then the cache line is going to be + // either verified or cleared out. Just recurse to figure out which. + // Note that we don't need an upgradable read. + MaybeChangedSinceProbeState::Retry => { + return self.maybe_changed_after(db, revision, key) + } + + MaybeChangedSinceProbeState::Stale(state) => { + type RwLockUpgradableReadGuard<'a, T> = + lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>; + + let mut state = RwLockUpgradableReadGuard::upgrade(state); + match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) { + QueryState::Memoized(old_memo) => old_memo, + QueryState::NotComputed | QueryState::InProgress { .. } => unreachable!(), + } + } + }; + + let panic_guard = PanicGuard::new(self, runtime); + let active_query = runtime.push_query(self.database_key_index()); + + if old_memo.verify_revisions(db.ops_database(), revision_now, &active_query) { + let maybe_changed = old_memo.revisions.changed_at > revision; + panic_guard.proceed(Some(old_memo)); + maybe_changed + } else if old_memo.value.is_some() { + // We found that this memoized value may have changed + // but we have an old value. We can re-run the code and + // actually *check* if it has changed. + let StampedValue { changed_at, .. } = self.execute( + db, + runtime, + revision_now, + active_query, + panic_guard, + Some(old_memo), + key, + ); + changed_at > revision + } else { + // We found that inputs to this memoized value may have chanced + // but we don't have an old value to compare against or re-use. + // No choice but to drop the memo and say that its value may have changed. + panic_guard.proceed(None); + true + } + } + + /// Helper: see [`Runtime::try_block_on_or_unwind`]. + fn block_on_or_unwind( + &self, + db: &>::DynDb, + runtime: &Runtime, + other_id: RuntimeId, + mutex_guard: MutexGuard, + ) { + runtime.block_on_or_unwind( + db.ops_database(), + self.database_key_index(), + other_id, + mutex_guard, + ) + } + + fn should_memoize_value(&self, key: &Q::Key) -> bool { + MP::should_memoize_value(key) + } +} + +impl QueryState +where + Q: QueryFunction, +{ + fn in_progress(id: RuntimeId) -> Self { + QueryState::InProgress { id, anyone_waiting: Default::default() } + } +} + +struct PanicGuard<'me, Q, MP> +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + slot: &'me Slot, + runtime: &'me Runtime, +} + +impl<'me, Q, MP> PanicGuard<'me, Q, MP> +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn new(slot: &'me Slot, runtime: &'me Runtime) -> Self { + Self { slot, runtime } + } + + /// Indicates that we have concluded normally (without panicking). + /// If `opt_memo` is some, then this memo is installed as the new + /// memoized value. If `opt_memo` is `None`, then the slot is cleared + /// and has no value. + fn proceed(mut self, opt_memo: Option>) { + self.overwrite_placeholder(WaitResult::Completed, opt_memo); + std::mem::forget(self) + } + + /// Overwrites the `InProgress` placeholder for `key` that we + /// inserted; if others were blocked, waiting for us to finish, + /// then notify them. + fn overwrite_placeholder(&mut self, wait_result: WaitResult, opt_memo: Option>) { + let old_value = { + let mut write = self.slot.state.write(); + match opt_memo { + // Replace the `InProgress` marker that we installed with the new + // memo, thus releasing our unique access to this key. + Some(memo) => std::mem::replace(&mut *write, QueryState::Memoized(memo)), + + // We had installed an `InProgress` marker, but we panicked before + // it could be removed. At this point, we therefore "own" unique + // access to our slot, so we can just remove the key. + None => std::mem::replace(&mut *write, QueryState::NotComputed), + } + }; + + match old_value { + QueryState::InProgress { id, anyone_waiting } => { + assert_eq!(id, self.runtime.id()); + + // NB: As noted on the `store`, `Ordering::Relaxed` is + // sufficient here. This boolean signals us on whether to + // acquire a mutex; the mutex will guarantee that all writes + // we are interested in are visible. + if anyone_waiting.load(Ordering::Relaxed) { + self.runtime + .unblock_queries_blocked_on(self.slot.database_key_index(), wait_result); + } + } + _ => panic!( + "\ +Unexpected panic during query evaluation, aborting the process. + +Please report this bug to https://github.com/salsa-rs/salsa/issues." + ), + } + } +} + +impl<'me, Q, MP> Drop for PanicGuard<'me, Q, MP> +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn drop(&mut self) { + if std::thread::panicking() { + // We panicked before we could proceed and need to remove `key`. + self.overwrite_placeholder(WaitResult::Panicked, None) + } else { + // If no panic occurred, then panic guard ought to be + // "forgotten" and so this Drop code should never run. + panic!(".forget() was not called") + } + } +} + +impl Memo +where + V: Clone, +{ + /// Determines whether the value stored in this memo (if any) is still + /// valid in the current revision. If so, returns a stamped value. + /// + /// If needed, this will walk each dependency and + /// recursively invoke `maybe_changed_after`, which may in turn + /// re-execute the dependency. This can cause cycles to occur, + /// so the current query must be pushed onto the + /// stack to permit cycle detection and recovery: therefore, + /// takes the `active_query` argument as evidence. + fn verify_value( + &mut self, + db: &dyn Database, + revision_now: Revision, + active_query: &ActiveQueryGuard<'_>, + ) -> Option> { + // If we don't have a memoized value, nothing to validate. + if self.value.is_none() { + return None; + } + if self.verify_revisions(db, revision_now, active_query) { + self.value.clone().map(|value| StampedValue { + durability: self.revisions.durability, + changed_at: self.revisions.changed_at, + value, + }) + } else { + None + } + } + + /// Determines whether the value represented by this memo is still + /// valid in the current revision; note that the value itself is + /// not needed for this check. If needed, this will walk each + /// dependency and recursively invoke `maybe_changed_after`, which + /// may in turn re-execute the dependency. This can cause cycles to occur, + /// so the current query must be pushed onto the + /// stack to permit cycle detection and recovery: therefore, + /// takes the `active_query` argument as evidence. + fn verify_revisions( + &mut self, + db: &dyn Database, + revision_now: Revision, + _active_query: &ActiveQueryGuard<'_>, + ) -> bool { + assert!(self.verified_at != revision_now); + let verified_at = self.verified_at; + + debug!( + "verify_revisions: verified_at={:?}, revision_now={:?}, inputs={:#?}", + verified_at, revision_now, self.revisions.inputs + ); + + if self.check_durability(db.salsa_runtime()) { + return self.mark_value_as_verified(revision_now); + } + + match &self.revisions.inputs { + // We can't validate values that had untracked inputs; just have to + // re-execute. + None if self.revisions.untracked => return false, + None => {} + + // Check whether any of our inputs changed since the + // **last point where we were verified** (not since we + // last changed). This is important: if we have + // memoized values, then an input may have changed in + // revision R2, but we found that *our* value was the + // same regardless, so our change date is still + // R1. But our *verification* date will be R2, and we + // are only interested in finding out whether the + // input changed *again*. + Some(inputs) => { + let changed_input = + inputs.slice.iter().find(|&&input| db.maybe_changed_after(input, verified_at)); + if let Some(input) = changed_input { + debug!("validate_memoized_value: `{:?}` may have changed", input); + + return false; + } + } + }; + + self.mark_value_as_verified(revision_now) + } + + /// True if this memo is known not to have changed based on its durability. + fn check_durability(&self, runtime: &Runtime) -> bool { + let last_changed = runtime.last_changed_revision(self.revisions.durability); + debug!( + "check_durability(last_changed={:?} <= verified_at={:?}) = {:?}", + last_changed, + self.verified_at, + last_changed <= self.verified_at, + ); + last_changed <= self.verified_at + } + + fn mark_value_as_verified(&mut self, revision_now: Revision) -> bool { + self.verified_at = revision_now; + true + } + + fn has_untracked_input(&self) -> bool { + self.revisions.untracked + } +} + +impl std::fmt::Debug for Slot +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(fmt, "{:?}", Q::default()) + } +} + +impl LruNode for Slot +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn lru_index(&self) -> &LruIndex { + &self.lru_index + } +} + +/// Check that `Slot: Send + Sync` as long as +/// `DB::DatabaseData: Send + Sync`, which in turn implies that +/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`. +#[allow(dead_code)] +fn check_send_sync() +where + Q: QueryFunction, + MP: MemoizationPolicy, + Q::Key: Send + Sync, + Q::Value: Send + Sync, +{ + fn is_send_sync() {} + is_send_sync::>(); +} + +/// Check that `Slot: 'static` as long as +/// `DB::DatabaseData: 'static`, which in turn implies that +/// `Q::Key: 'static`, `Q::Value: 'static`. +#[allow(dead_code)] +fn check_static() +where + Q: QueryFunction + 'static, + MP: MemoizationPolicy + 'static, + Q::Key: 'static, + Q::Value: 'static, +{ + fn is_static() {} + is_static::>(); +} diff --git a/crates/salsa/src/lib.rs b/crates/salsa/src/lib.rs index eac52653f2..48d6dc2e38 100644 --- a/crates/salsa/src/lib.rs +++ b/crates/salsa/src/lib.rs @@ -10,6 +10,7 @@ //! from previous invocations as appropriate. mod derived; +mod derived_lru; mod durability; mod hash; mod input; diff --git a/crates/salsa/src/plumbing.rs b/crates/salsa/src/plumbing.rs index cae3a50132..e96b9daa97 100644 --- a/crates/salsa/src/plumbing.rs +++ b/crates/salsa/src/plumbing.rs @@ -12,8 +12,9 @@ use std::fmt::Debug; use std::hash::Hash; use triomphe::Arc; -pub use crate::derived::DependencyStorage; pub use crate::derived::MemoizedStorage; +pub use crate::derived_lru::DependencyStorage as LruDependencyStorage; +pub use crate::derived_lru::MemoizedStorage as LruMemoizedStorage; pub use crate::input::{InputStorage, UnitInputStorage}; pub use crate::interned::InternedStorage; pub use crate::interned::LookupInternedStorage; diff --git a/crates/salsa/tests/lru.rs b/crates/salsa/tests/lru.rs index 3da8519b08..ef98a2c32b 100644 --- a/crates/salsa/tests/lru.rs +++ b/crates/salsa/tests/lru.rs @@ -24,7 +24,9 @@ impl Drop for HotPotato { #[salsa::query_group(QueryGroupStorage)] trait QueryGroup: salsa::Database { + #[salsa::lru] fn get(&self, x: u32) -> Arc; + #[salsa::lru] fn get_volatile(&self, x: u32) -> usize; }