Auto merge of #16643 - Veykril:salsa-opt, r=Veykril

internal: Optimize salsa memory usage

Reduces memory on self by ~20mb for me, there is a few more mb to save here if we made LRU caching opt-in, as currently every entry in a memoized query will store an `AtomicUsize` for the LRU.
This commit is contained in:
bors 2024-02-23 15:18:58 +00:00
commit cbc579e51f
18 changed files with 173 additions and 162 deletions

1
Cargo.lock generated
View file

@ -1709,6 +1709,7 @@ dependencies = [
"dissimilar", "dissimilar",
"expect-test", "expect-test",
"indexmap", "indexmap",
"itertools",
"linked-hash-map", "linked-hash-map",
"lock_api", "lock_api",
"oorandom", "oorandom",

View file

@ -17,7 +17,7 @@ impl RootDatabase {
pub fn request_cancellation(&mut self) { pub fn request_cancellation(&mut self) {
let _p = let _p =
tracing::span!(tracing::Level::INFO, "RootDatabase::request_cancellation").entered(); tracing::span!(tracing::Level::INFO, "RootDatabase::request_cancellation").entered();
self.salsa_runtime_mut().synthetic_write(Durability::LOW); self.synthetic_write(Durability::LOW);
} }
pub fn apply_change(&mut self, change: Change) { pub fn apply_change(&mut self, change: Change) {

View file

@ -21,6 +21,7 @@ rustc-hash = "1.0"
smallvec = "1.0.0" smallvec = "1.0.0"
oorandom = "11" oorandom = "11"
triomphe = "0.1.11" triomphe = "0.1.11"
itertools.workspace = true
salsa-macros = { version = "0.0.0", path = "salsa-macros" } salsa-macros = { version = "0.0.0", path = "salsa-macros" }

View file

@ -154,8 +154,8 @@ pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream {
self.#db_storage_field.salsa_runtime() self.#db_storage_field.salsa_runtime()
} }
fn ops_salsa_runtime_mut(&mut self) -> &mut salsa::Runtime { fn synthetic_write(&mut self, durability: salsa::Durability) {
self.#db_storage_field.salsa_runtime_mut() self.#db_storage_field.salsa_runtime_mut().synthetic_write(durability)
} }
fn fmt_index( fn fmt_index(

View file

@ -526,7 +526,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
fmt_ops.extend(quote! { fmt_ops.extend(quote! {
#query_index => { #query_index => {
salsa::plumbing::QueryStorageOps::fmt_index( salsa::plumbing::QueryStorageOps::fmt_index(
&*self.#fn_name, db, input, fmt, &*self.#fn_name, db, input.key_index(), fmt,
) )
} }
}); });
@ -537,7 +537,7 @@ pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream
maybe_changed_ops.extend(quote! { maybe_changed_ops.extend(quote! {
#query_index => { #query_index => {
salsa::plumbing::QueryStorageOps::maybe_changed_after( salsa::plumbing::QueryStorageOps::maybe_changed_after(
&*self.#fn_name, db, input, revision &*self.#fn_name, db, input.key_index(), revision
) )
} }
}); });

View file

@ -102,13 +102,13 @@ where
let mut write = self.slot_map.write(); let mut write = self.slot_map.write();
let entry = write.entry(key.clone()); let entry = write.entry(key.clone());
let key_index = u32::try_from(entry.index()).unwrap(); let key_index = entry.index() as u32;
let database_key_index = DatabaseKeyIndex { let database_key_index = DatabaseKeyIndex {
group_index: self.group_index, group_index: self.group_index,
query_index: Q::QUERY_INDEX, query_index: Q::QUERY_INDEX,
key_index, key_index,
}; };
entry.or_insert_with(|| Arc::new(Slot::new(key.clone(), database_key_index))).clone() entry.or_insert_with(|| Arc::new(Slot::new(database_key_index))).clone()
} }
} }
@ -131,34 +131,33 @@ where
fn fmt_index( fn fmt_index(
&self, &self,
_db: &<Q as QueryDb<'_>>::DynDb, _db: &<Q as QueryDb<'_>>::DynDb,
index: DatabaseKeyIndex, index: u32,
fmt: &mut std::fmt::Formatter<'_>, fmt: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result { ) -> std::fmt::Result {
assert_eq!(index.group_index, self.group_index);
assert_eq!(index.query_index, Q::QUERY_INDEX);
let slot_map = self.slot_map.read(); let slot_map = self.slot_map.read();
let key = slot_map.get_index(index.key_index as usize).unwrap().0; let key = slot_map.get_index(index as usize).unwrap().0;
write!(fmt, "{}({:?})", Q::QUERY_NAME, key) write!(fmt, "{}({:?})", Q::QUERY_NAME, key)
} }
fn maybe_changed_after( fn maybe_changed_after(
&self, &self,
db: &<Q as QueryDb<'_>>::DynDb, db: &<Q as QueryDb<'_>>::DynDb,
input: DatabaseKeyIndex, index: u32,
revision: Revision, revision: Revision,
) -> bool { ) -> bool {
assert_eq!(input.group_index, self.group_index);
assert_eq!(input.query_index, Q::QUERY_INDEX);
debug_assert!(revision < db.salsa_runtime().current_revision()); debug_assert!(revision < db.salsa_runtime().current_revision());
let slot = self.slot_map.read().get_index(input.key_index as usize).unwrap().1.clone(); let read = &self.slot_map.read();
slot.maybe_changed_after(db, revision) let Some((key, slot)) = read.get_index(index as usize) else {
return false;
};
slot.maybe_changed_after(db, revision, key)
} }
fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value { fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value {
db.unwind_if_cancelled(); db.unwind_if_cancelled();
let slot = self.slot(key); let slot = self.slot(key);
let StampedValue { value, durability, changed_at } = slot.read(db); let StampedValue { value, durability, changed_at } = slot.read(db, key);
if let Some(evicted) = self.lru_list.record_use(&slot) { if let Some(evicted) = self.lru_list.record_use(&slot) {
evicted.evict(); evicted.evict();
@ -182,7 +181,7 @@ where
C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>, C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>,
{ {
let slot_map = self.slot_map.read(); let slot_map = self.slot_map.read();
slot_map.values().filter_map(|slot| slot.as_table_entry()).collect() slot_map.iter().filter_map(|(key, slot)| slot.as_table_entry(key)).collect()
} }
} }

View file

@ -26,8 +26,8 @@ where
Q: QueryFunction, Q: QueryFunction,
MP: MemoizationPolicy<Q>, MP: MemoizationPolicy<Q>,
{ {
key: Q::Key, key_index: u32,
database_key_index: DatabaseKeyIndex, group_index: u16,
state: RwLock<QueryState<Q>>, state: RwLock<QueryState<Q>>,
policy: PhantomData<MP>, policy: PhantomData<MP>,
lru_index: LruIndex, lru_index: LruIndex,
@ -110,10 +110,10 @@ where
Q: QueryFunction, Q: QueryFunction,
MP: MemoizationPolicy<Q>, MP: MemoizationPolicy<Q>,
{ {
pub(super) fn new(key: Q::Key, database_key_index: DatabaseKeyIndex) -> Self { pub(super) fn new(database_key_index: DatabaseKeyIndex) -> Self {
Self { Self {
key, key_index: database_key_index.key_index,
database_key_index, group_index: database_key_index.group_index,
state: RwLock::new(QueryState::NotComputed), state: RwLock::new(QueryState::NotComputed),
lru_index: LruIndex::default(), lru_index: LruIndex::default(),
policy: PhantomData, policy: PhantomData,
@ -121,10 +121,18 @@ where
} }
pub(super) fn database_key_index(&self) -> DatabaseKeyIndex { pub(super) fn database_key_index(&self) -> DatabaseKeyIndex {
self.database_key_index DatabaseKeyIndex {
group_index: self.group_index,
query_index: Q::QUERY_INDEX,
key_index: self.key_index,
}
} }
pub(super) fn read(&self, db: &<Q as QueryDb<'_>>::DynDb) -> StampedValue<Q::Value> { pub(super) fn read(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
key: &Q::Key,
) -> StampedValue<Q::Value> {
let runtime = db.salsa_runtime(); let runtime = db.salsa_runtime();
// NB: We don't need to worry about people modifying the // NB: We don't need to worry about people modifying the
@ -147,7 +155,7 @@ where
} }
} }
self.read_upgrade(db, revision_now) self.read_upgrade(db, key, revision_now)
} }
/// Second phase of a read operation: acquires an upgradable-read /// Second phase of a read operation: acquires an upgradable-read
@ -157,6 +165,7 @@ where
fn read_upgrade( fn read_upgrade(
&self, &self,
db: &<Q as QueryDb<'_>>::DynDb, db: &<Q as QueryDb<'_>>::DynDb,
key: &Q::Key,
revision_now: Revision, revision_now: Revision,
) -> StampedValue<Q::Value> { ) -> StampedValue<Q::Value> {
let runtime = db.salsa_runtime(); let runtime = db.salsa_runtime();
@ -186,8 +195,8 @@ where
} }
}; };
let panic_guard = PanicGuard::new(self.database_key_index, self, runtime); let panic_guard = PanicGuard::new(self, runtime);
let active_query = runtime.push_query(self.database_key_index); let active_query = runtime.push_query(self.database_key_index());
// If we have an old-value, it *may* now be stale, since there // If we have an old-value, it *may* now be stale, since there
// has been a new revision since the last time we checked. So, // has been a new revision since the last time we checked. So,
@ -200,7 +209,7 @@ where
db.salsa_event(Event { db.salsa_event(Event {
runtime_id: runtime.id(), runtime_id: runtime.id(),
kind: EventKind::DidValidateMemoizedValue { kind: EventKind::DidValidateMemoizedValue {
database_key: self.database_key_index, database_key: self.database_key_index(),
}, },
}); });
@ -210,7 +219,7 @@ where
} }
} }
self.execute(db, runtime, revision_now, active_query, panic_guard, old_memo) self.execute(db, runtime, revision_now, active_query, panic_guard, old_memo, key)
} }
fn execute( fn execute(
@ -221,22 +230,23 @@ where
active_query: ActiveQueryGuard<'_>, active_query: ActiveQueryGuard<'_>,
panic_guard: PanicGuard<'_, Q, MP>, panic_guard: PanicGuard<'_, Q, MP>,
old_memo: Option<Memo<Q::Value>>, old_memo: Option<Memo<Q::Value>>,
key: &Q::Key,
) -> StampedValue<Q::Value> { ) -> StampedValue<Q::Value> {
tracing::info!("{:?}: executing query", self.database_key_index.debug(db)); tracing::info!("{:?}: executing query", self.database_key_index().debug(db));
db.salsa_event(Event { db.salsa_event(Event {
runtime_id: db.salsa_runtime().id(), runtime_id: db.salsa_runtime().id(),
kind: EventKind::WillExecute { database_key: self.database_key_index }, kind: EventKind::WillExecute { database_key: self.database_key_index() },
}); });
// Query was not previously executed, or value is potentially // Query was not previously executed, or value is potentially
// stale, or value is absent. Let's execute! // stale, or value is absent. Let's execute!
let value = match Cycle::catch(|| Q::execute(db, self.key.clone())) { let value = match Cycle::catch(|| Q::execute(db, key.clone())) {
Ok(v) => v, Ok(v) => v,
Err(cycle) => { Err(cycle) => {
tracing::debug!( tracing::debug!(
"{:?}: caught cycle {:?}, have strategy {:?}", "{:?}: caught cycle {:?}, have strategy {:?}",
self.database_key_index.debug(db), self.database_key_index().debug(db),
cycle, cycle,
Q::CYCLE_STRATEGY, Q::CYCLE_STRATEGY,
); );
@ -248,12 +258,12 @@ where
crate::plumbing::CycleRecoveryStrategy::Fallback => { crate::plumbing::CycleRecoveryStrategy::Fallback => {
if let Some(c) = active_query.take_cycle() { if let Some(c) = active_query.take_cycle() {
assert!(c.is(&cycle)); assert!(c.is(&cycle));
Q::cycle_fallback(db, &cycle, &self.key) Q::cycle_fallback(db, &cycle, key)
} else { } else {
// we are not a participant in this cycle // we are not a participant in this cycle
debug_assert!(!cycle debug_assert!(!cycle
.participant_keys() .participant_keys()
.any(|k| k == self.database_key_index)); .any(|k| k == self.database_key_index()));
cycle.throw() cycle.throw()
} }
} }
@ -303,7 +313,7 @@ where
}; };
let memo_value = let memo_value =
if self.should_memoize_value(&self.key) { Some(new_value.value.clone()) } else { None }; if self.should_memoize_value(key) { Some(new_value.value.clone()) } else { None };
debug!("read_upgrade({:?}): result.revisions = {:#?}", self, revisions,); debug!("read_upgrade({:?}): result.revisions = {:#?}", self, revisions,);
@ -395,13 +405,11 @@ where
} }
} }
pub(super) fn as_table_entry(&self) -> Option<TableEntry<Q::Key, Q::Value>> { pub(super) fn as_table_entry(&self, key: &Q::Key) -> Option<TableEntry<Q::Key, Q::Value>> {
match &*self.state.read() { match &*self.state.read() {
QueryState::NotComputed => None, QueryState::NotComputed => None,
QueryState::InProgress { .. } => Some(TableEntry::new(self.key.clone(), None)), QueryState::InProgress { .. } => Some(TableEntry::new(key.clone(), None)),
QueryState::Memoized(memo) => { QueryState::Memoized(memo) => Some(TableEntry::new(key.clone(), memo.value.clone())),
Some(TableEntry::new(self.key.clone(), memo.value.clone()))
}
} }
} }
@ -436,6 +444,7 @@ where
&self, &self,
db: &<Q as QueryDb<'_>>::DynDb, db: &<Q as QueryDb<'_>>::DynDb,
revision: Revision, revision: Revision,
key: &Q::Key,
) -> bool { ) -> bool {
let runtime = db.salsa_runtime(); let runtime = db.salsa_runtime();
let revision_now = runtime.current_revision(); let revision_now = runtime.current_revision();
@ -458,7 +467,7 @@ where
MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision, MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision,
MaybeChangedSinceProbeState::Stale(state) => { MaybeChangedSinceProbeState::Stale(state) => {
drop(state); drop(state);
return self.maybe_changed_after_upgrade(db, revision); return self.maybe_changed_after_upgrade(db, revision, key);
} }
} }
} }
@ -495,6 +504,7 @@ where
&self, &self,
db: &<Q as QueryDb<'_>>::DynDb, db: &<Q as QueryDb<'_>>::DynDb,
revision: Revision, revision: Revision,
key: &Q::Key,
) -> bool { ) -> bool {
let runtime = db.salsa_runtime(); let runtime = db.salsa_runtime();
let revision_now = runtime.current_revision(); let revision_now = runtime.current_revision();
@ -513,7 +523,9 @@ where
// If another thread was active, then the cache line is going to be // If another thread was active, then the cache line is going to be
// either verified or cleared out. Just recurse to figure out which. // either verified or cleared out. Just recurse to figure out which.
// Note that we don't need an upgradable read. // Note that we don't need an upgradable read.
MaybeChangedSinceProbeState::Retry => return self.maybe_changed_after(db, revision), MaybeChangedSinceProbeState::Retry => {
return self.maybe_changed_after(db, revision, key)
}
MaybeChangedSinceProbeState::Stale(state) => { MaybeChangedSinceProbeState::Stale(state) => {
type RwLockUpgradableReadGuard<'a, T> = type RwLockUpgradableReadGuard<'a, T> =
@ -527,8 +539,8 @@ where
} }
}; };
let panic_guard = PanicGuard::new(self.database_key_index, self, runtime); let panic_guard = PanicGuard::new(self, runtime);
let active_query = runtime.push_query(self.database_key_index); let active_query = runtime.push_query(self.database_key_index());
if old_memo.verify_revisions(db.ops_database(), revision_now, &active_query) { if old_memo.verify_revisions(db.ops_database(), revision_now, &active_query) {
let maybe_changed = old_memo.revisions.changed_at > revision; let maybe_changed = old_memo.revisions.changed_at > revision;
@ -538,8 +550,15 @@ where
// We found that this memoized value may have changed // We found that this memoized value may have changed
// but we have an old value. We can re-run the code and // but we have an old value. We can re-run the code and
// actually *check* if it has changed. // actually *check* if it has changed.
let StampedValue { changed_at, .. } = let StampedValue { changed_at, .. } = self.execute(
self.execute(db, runtime, revision_now, active_query, panic_guard, Some(old_memo)); db,
runtime,
revision_now,
active_query,
panic_guard,
Some(old_memo),
key,
);
changed_at > revision changed_at > revision
} else { } else {
// We found that inputs to this memoized value may have chanced // We found that inputs to this memoized value may have chanced
@ -560,7 +579,7 @@ where
) { ) {
runtime.block_on_or_unwind( runtime.block_on_or_unwind(
db.ops_database(), db.ops_database(),
self.database_key_index, self.database_key_index(),
other_id, other_id,
mutex_guard, mutex_guard,
) )
@ -585,7 +604,6 @@ where
Q: QueryFunction, Q: QueryFunction,
MP: MemoizationPolicy<Q>, MP: MemoizationPolicy<Q>,
{ {
database_key_index: DatabaseKeyIndex,
slot: &'me Slot<Q, MP>, slot: &'me Slot<Q, MP>,
runtime: &'me Runtime, runtime: &'me Runtime,
} }
@ -595,12 +613,8 @@ where
Q: QueryFunction, Q: QueryFunction,
MP: MemoizationPolicy<Q>, MP: MemoizationPolicy<Q>,
{ {
fn new( fn new(slot: &'me Slot<Q, MP>, runtime: &'me Runtime) -> Self {
database_key_index: DatabaseKeyIndex, Self { slot, runtime }
slot: &'me Slot<Q, MP>,
runtime: &'me Runtime,
) -> Self {
Self { database_key_index, slot, runtime }
} }
/// Indicates that we have concluded normally (without panicking). /// Indicates that we have concluded normally (without panicking).
@ -616,17 +630,18 @@ where
/// inserted; if others were blocked, waiting for us to finish, /// inserted; if others were blocked, waiting for us to finish,
/// then notify them. /// then notify them.
fn overwrite_placeholder(&mut self, wait_result: WaitResult, opt_memo: Option<Memo<Q::Value>>) { fn overwrite_placeholder(&mut self, wait_result: WaitResult, opt_memo: Option<Memo<Q::Value>>) {
let mut write = self.slot.state.write(); let old_value = {
let mut write = self.slot.state.write();
match opt_memo {
// Replace the `InProgress` marker that we installed with the new
// memo, thus releasing our unique access to this key.
Some(memo) => std::mem::replace(&mut *write, QueryState::Memoized(memo)),
let old_value = match opt_memo { // We had installed an `InProgress` marker, but we panicked before
// Replace the `InProgress` marker that we installed with the new // it could be removed. At this point, we therefore "own" unique
// memo, thus releasing our unique access to this key. // access to our slot, so we can just remove the key.
Some(memo) => std::mem::replace(&mut *write, QueryState::Memoized(memo)), None => std::mem::replace(&mut *write, QueryState::NotComputed),
}
// We had installed an `InProgress` marker, but we panicked before
// it could be removed. At this point, we therefore "own" unique
// access to our slot, so we can just remove the key.
None => std::mem::replace(&mut *write, QueryState::NotComputed),
}; };
match old_value { match old_value {
@ -638,7 +653,8 @@ where
// acquire a mutex; the mutex will guarantee that all writes // acquire a mutex; the mutex will guarantee that all writes
// we are interested in are visible. // we are interested in are visible.
if anyone_waiting.load(Ordering::Relaxed) { if anyone_waiting.load(Ordering::Relaxed) {
self.runtime.unblock_queries_blocked_on(self.database_key_index, wait_result); self.runtime
.unblock_queries_blocked_on(self.slot.database_key_index(), wait_result);
} }
} }
_ => panic!( _ => panic!(
@ -692,10 +708,10 @@ where
return None; return None;
} }
if self.verify_revisions(db, revision_now, active_query) { if self.verify_revisions(db, revision_now, active_query) {
Some(StampedValue { self.value.clone().map(|value| StampedValue {
durability: self.revisions.durability, durability: self.revisions.durability,
changed_at: self.revisions.changed_at, changed_at: self.revisions.changed_at,
value: self.value.as_ref().unwrap().clone(), value,
}) })
} else { } else {
None None
@ -748,7 +764,7 @@ where
// input changed *again*. // input changed *again*.
QueryInputs::Tracked { inputs } => { QueryInputs::Tracked { inputs } => {
let changed_input = let changed_input =
inputs.iter().find(|&&input| db.maybe_changed_after(input, verified_at)); inputs.slice.iter().find(|&&input| db.maybe_changed_after(input, verified_at));
if let Some(input) = changed_input { if let Some(input) = changed_input {
debug!("validate_memoized_value: `{:?}` may have changed", input); debug!("validate_memoized_value: `{:?}` may have changed", input);
@ -788,7 +804,7 @@ where
MP: MemoizationPolicy<Q>, MP: MemoizationPolicy<Q>,
{ {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(fmt, "{:?}({:?})", Q::default(), self.key) write!(fmt, "{:?}", Q::default())
} }
} }

View file

@ -42,9 +42,9 @@ impl Durability {
pub(crate) const MAX: Durability = Self::HIGH; pub(crate) const MAX: Durability = Self::HIGH;
/// Number of durability levels. /// Number of durability levels.
pub(crate) const LEN: usize = 3; pub(crate) const LEN: usize = Self::MAX.index() + 1;
pub(crate) fn index(self) -> usize { pub(crate) const fn index(self) -> usize {
self.0 as usize self.0 as usize
} }
} }

View file

@ -29,7 +29,7 @@ where
} }
struct Slot<V> { struct Slot<V> {
database_key_index: DatabaseKeyIndex, key_index: u32,
stamped_value: RwLock<StampedValue<V>>, stamped_value: RwLock<StampedValue<V>>,
} }
@ -54,27 +54,25 @@ where
fn fmt_index( fn fmt_index(
&self, &self,
_db: &<Q as QueryDb<'_>>::DynDb, _db: &<Q as QueryDb<'_>>::DynDb,
index: DatabaseKeyIndex, index: u32,
fmt: &mut std::fmt::Formatter<'_>, fmt: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result { ) -> std::fmt::Result {
assert_eq!(index.group_index, self.group_index);
assert_eq!(index.query_index, Q::QUERY_INDEX);
let slot_map = self.slots.read(); let slot_map = self.slots.read();
let key = slot_map.get_index(index.key_index as usize).unwrap().0; let key = slot_map.get_index(index as usize).unwrap().0;
write!(fmt, "{}({:?})", Q::QUERY_NAME, key) write!(fmt, "{}({:?})", Q::QUERY_NAME, key)
} }
fn maybe_changed_after( fn maybe_changed_after(
&self, &self,
db: &<Q as QueryDb<'_>>::DynDb, db: &<Q as QueryDb<'_>>::DynDb,
input: DatabaseKeyIndex, index: u32,
revision: Revision, revision: Revision,
) -> bool { ) -> bool {
assert_eq!(input.group_index, self.group_index);
assert_eq!(input.query_index, Q::QUERY_INDEX);
debug_assert!(revision < db.salsa_runtime().current_revision()); debug_assert!(revision < db.salsa_runtime().current_revision());
let slots = &self.slots.read(); let slots = &self.slots.read();
let slot = slots.get_index(input.key_index as usize).unwrap().1; let Some((_, slot)) = slots.get_index(index as usize) else {
return true;
};
debug!("maybe_changed_after(slot={:?}, revision={:?})", Q::default(), revision,); debug!("maybe_changed_after(slot={:?}, revision={:?})", Q::default(), revision,);
@ -96,7 +94,11 @@ where
let StampedValue { value, durability, changed_at } = slot.stamped_value.read().clone(); let StampedValue { value, durability, changed_at } = slot.stamped_value.read().clone();
db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted( db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted(
slot.database_key_index, DatabaseKeyIndex {
group_index: self.group_index,
query_index: Q::QUERY_INDEX,
key_index: slot.key_index,
},
durability, durability,
changed_at, changed_at,
); );
@ -174,16 +176,8 @@ where
} }
Entry::Vacant(entry) => { Entry::Vacant(entry) => {
let key_index = u32::try_from(entry.index()).unwrap(); let key_index = entry.index() as u32;
let database_key_index = DatabaseKeyIndex { entry.insert(Slot { key_index, stamped_value: RwLock::new(stamped_value) });
group_index: self.group_index,
query_index: Q::QUERY_INDEX,
key_index,
};
entry.insert(Slot {
database_key_index,
stamped_value: RwLock::new(stamped_value),
});
None None
} }
} }
@ -196,7 +190,6 @@ pub struct UnitInputStorage<Q>
where where
Q: Query<Key = ()>, Q: Query<Key = ()>,
{ {
group_index: u16,
slot: UnitSlot<Q::Value>, slot: UnitSlot<Q::Value>,
} }
@ -222,36 +215,32 @@ where
fn new(group_index: u16) -> Self { fn new(group_index: u16) -> Self {
let database_key_index = let database_key_index =
DatabaseKeyIndex { group_index, query_index: Q::QUERY_INDEX, key_index: 0 }; DatabaseKeyIndex { group_index, query_index: Q::QUERY_INDEX, key_index: 0 };
UnitInputStorage { UnitInputStorage { slot: UnitSlot { database_key_index, stamped_value: RwLock::new(None) } }
group_index,
slot: UnitSlot { database_key_index, stamped_value: RwLock::new(None) },
}
} }
fn fmt_index( fn fmt_index(
&self, &self,
_db: &<Q as QueryDb<'_>>::DynDb, _db: &<Q as QueryDb<'_>>::DynDb,
index: DatabaseKeyIndex, _index: u32,
fmt: &mut std::fmt::Formatter<'_>, fmt: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result { ) -> std::fmt::Result {
assert_eq!(index.group_index, self.group_index);
assert_eq!(index.query_index, Q::QUERY_INDEX);
write!(fmt, "{}", Q::QUERY_NAME) write!(fmt, "{}", Q::QUERY_NAME)
} }
fn maybe_changed_after( fn maybe_changed_after(
&self, &self,
db: &<Q as QueryDb<'_>>::DynDb, db: &<Q as QueryDb<'_>>::DynDb,
input: DatabaseKeyIndex, _index: u32,
revision: Revision, revision: Revision,
) -> bool { ) -> bool {
assert_eq!(input.group_index, self.group_index);
assert_eq!(input.query_index, Q::QUERY_INDEX);
debug_assert!(revision < db.salsa_runtime().current_revision()); debug_assert!(revision < db.salsa_runtime().current_revision());
debug!("maybe_changed_after(slot={:?}, revision={:?})", Q::default(), revision,); debug!("maybe_changed_after(slot={:?}, revision={:?})", Q::default(), revision,);
let changed_at = self.slot.stamped_value.read().as_ref().unwrap().changed_at; let Some(value) = &*self.slot.stamped_value.read() else {
return true;
};
let changed_at = value.changed_at;
debug!("maybe_changed_after: changed_at = {:?}", changed_at); debug!("maybe_changed_after: changed_at = {:?}", changed_at);

View file

@ -265,12 +265,10 @@ where
fn fmt_index( fn fmt_index(
&self, &self,
_db: &<Q as QueryDb<'_>>::DynDb, _db: &<Q as QueryDb<'_>>::DynDb,
index: DatabaseKeyIndex, index: u32,
fmt: &mut std::fmt::Formatter<'_>, fmt: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result { ) -> std::fmt::Result {
assert_eq!(index.group_index, self.group_index); let intern_id = InternId::from(index);
assert_eq!(index.query_index, Q::QUERY_INDEX);
let intern_id = InternId::from(index.key_index);
let slot = self.lookup_value(intern_id); let slot = self.lookup_value(intern_id);
write!(fmt, "{}({:?})", Q::QUERY_NAME, slot.value) write!(fmt, "{}({:?})", Q::QUERY_NAME, slot.value)
} }
@ -278,13 +276,11 @@ where
fn maybe_changed_after( fn maybe_changed_after(
&self, &self,
db: &<Q as QueryDb<'_>>::DynDb, db: &<Q as QueryDb<'_>>::DynDb,
input: DatabaseKeyIndex, input: u32,
revision: Revision, revision: Revision,
) -> bool { ) -> bool {
assert_eq!(input.group_index, self.group_index);
assert_eq!(input.query_index, Q::QUERY_INDEX);
debug_assert!(revision < db.salsa_runtime().current_revision()); debug_assert!(revision < db.salsa_runtime().current_revision());
let intern_id = InternId::from(input.key_index); let intern_id = InternId::from(input);
let slot = self.lookup_value(intern_id); let slot = self.lookup_value(intern_id);
slot.maybe_changed_after(revision) slot.maybe_changed_after(revision)
} }
@ -388,7 +384,7 @@ where
fn fmt_index( fn fmt_index(
&self, &self,
db: &<Q as QueryDb<'_>>::DynDb, db: &<Q as QueryDb<'_>>::DynDb,
index: DatabaseKeyIndex, index: u32,
fmt: &mut std::fmt::Formatter<'_>, fmt: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result { ) -> std::fmt::Result {
let group_storage = let group_storage =
@ -400,7 +396,7 @@ where
fn maybe_changed_after( fn maybe_changed_after(
&self, &self,
db: &<Q as QueryDb<'_>>::DynDb, db: &<Q as QueryDb<'_>>::DynDb,
input: DatabaseKeyIndex, input: u32,
revision: Revision, revision: Revision,
) -> bool { ) -> bool {
let group_storage = let group_storage =

View file

@ -54,7 +54,7 @@ pub trait Database: plumbing::DatabaseOps {
/// runtime. It permits the database to be customized and to /// runtime. It permits the database to be customized and to
/// inject logging or other custom behavior. /// inject logging or other custom behavior.
fn salsa_event(&self, event_fn: Event) { fn salsa_event(&self, event_fn: Event) {
#![allow(unused_variables)] _ = event_fn;
} }
/// Starts unwinding the stack if the current revision is cancelled. /// Starts unwinding the stack if the current revision is cancelled.
@ -96,11 +96,16 @@ pub trait Database: plumbing::DatabaseOps {
self.ops_salsa_runtime() self.ops_salsa_runtime()
} }
/// Gives access to the underlying salsa runtime. /// A "synthetic write" causes the system to act *as though* some
/// input of durability `durability` has changed. This is mostly
/// useful for profiling scenarios.
/// ///
/// This method should not be overridden by `Database` implementors. /// **WARNING:** Just like an ordinary write, this method triggers
fn salsa_runtime_mut(&mut self) -> &mut Runtime { /// cancellation. If you invoke it while a snapshot exists, it
self.ops_salsa_runtime_mut() /// will block until that snapshot is dropped -- if that snapshot
/// is owned by the current thread, this could trigger deadlock.
fn synthetic_write(&mut self, durability: Durability) {
plumbing::DatabaseOps::synthetic_write(self, durability)
} }
} }

View file

@ -38,8 +38,15 @@ pub trait DatabaseOps {
/// Gives access to the underlying salsa runtime. /// Gives access to the underlying salsa runtime.
fn ops_salsa_runtime(&self) -> &Runtime; fn ops_salsa_runtime(&self) -> &Runtime;
/// Gives access to the underlying salsa runtime. /// A "synthetic write" causes the system to act *as though* some
fn ops_salsa_runtime_mut(&mut self) -> &mut Runtime; /// input of durability `durability` has changed. This is mostly
/// useful for profiling scenarios.
///
/// **WARNING:** Just like an ordinary write, this method triggers
/// cancellation. If you invoke it while a snapshot exists, it
/// will block until that snapshot is dropped -- if that snapshot
/// is owned by the current thread, this could trigger deadlock.
fn synthetic_write(&mut self, durability: Durability);
/// Formats a database key index in a human readable fashion. /// Formats a database key index in a human readable fashion.
fn fmt_index( fn fmt_index(
@ -166,7 +173,7 @@ where
fn fmt_index( fn fmt_index(
&self, &self,
db: &<Q as QueryDb<'_>>::DynDb, db: &<Q as QueryDb<'_>>::DynDb,
index: DatabaseKeyIndex, index: u32,
fmt: &mut std::fmt::Formatter<'_>, fmt: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result; ) -> std::fmt::Result;
@ -179,7 +186,7 @@ where
fn maybe_changed_after( fn maybe_changed_after(
&self, &self,
db: &<Q as QueryDb<'_>>::DynDb, db: &<Q as QueryDb<'_>>::DynDb,
input: DatabaseKeyIndex, index: u32,
revision: Revision, revision: Revision,
) -> bool; ) -> bool;
// ANCHOR_END:maybe_changed_after // ANCHOR_END:maybe_changed_after

View file

@ -46,7 +46,7 @@ pub(crate) struct AtomicRevision {
} }
impl AtomicRevision { impl AtomicRevision {
pub(crate) fn start() -> Self { pub(crate) const fn start() -> Self {
Self { data: AtomicU32::new(START) } Self { data: AtomicU32::new(START) }
} }

View file

@ -4,13 +4,14 @@ use crate::hash::FxIndexSet;
use crate::plumbing::CycleRecoveryStrategy; use crate::plumbing::CycleRecoveryStrategy;
use crate::revision::{AtomicRevision, Revision}; use crate::revision::{AtomicRevision, Revision};
use crate::{Cancelled, Cycle, Database, DatabaseKeyIndex, Event, EventKind}; use crate::{Cancelled, Cycle, Database, DatabaseKeyIndex, Event, EventKind};
use itertools::Itertools;
use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive}; use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive};
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use std::hash::Hash; use std::hash::Hash;
use std::panic::panic_any; use std::panic::panic_any;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicU32, Ordering};
use tracing::debug; use tracing::debug;
use triomphe::Arc; use triomphe::{Arc, ThinArc};
mod dependency_graph; mod dependency_graph;
use dependency_graph::DependencyGraph; use dependency_graph::DependencyGraph;
@ -297,8 +298,7 @@ impl Runtime {
// (at least for this execution, not necessarily across executions), // (at least for this execution, not necessarily across executions),
// no matter where it started on the stack. Find the minimum // no matter where it started on the stack. Find the minimum
// key and rotate it to the front. // key and rotate it to the front.
let min = v.iter().min().unwrap(); let index = v.iter().position_min().unwrap_or_default();
let index = v.iter().position(|p| p == min).unwrap();
v.rotate_left(index); v.rotate_left(index);
// No need to store extra memory. // No need to store extra memory.
@ -440,7 +440,7 @@ impl Runtime {
/// State that will be common to all threads (when we support multiple threads) /// State that will be common to all threads (when we support multiple threads)
struct SharedState { struct SharedState {
/// Stores the next id to use for a snapshotted runtime (starts at 1). /// Stores the next id to use for a snapshotted runtime (starts at 1).
next_id: AtomicUsize, next_id: AtomicU32,
/// Whenever derived queries are executing, they acquire this lock /// Whenever derived queries are executing, they acquire this lock
/// in read mode. Mutating inputs (and thus creating a new /// in read mode. Mutating inputs (and thus creating a new
@ -457,50 +457,46 @@ struct SharedState {
/// revision is cancelled). /// revision is cancelled).
pending_revision: AtomicRevision, pending_revision: AtomicRevision,
/// Stores the "last change" revision for values of each duration. /// Stores the "last change" revision for values of each Durability.
/// This vector is always of length at least 1 (for Durability 0) /// This vector is always of length at least 1 (for Durability 0)
/// but its total length depends on the number of durations. The /// but its total length depends on the number of Durabilities. The
/// element at index 0 is special as it represents the "current /// element at index 0 is special as it represents the "current
/// revision". In general, we have the invariant that revisions /// revision". In general, we have the invariant that revisions
/// in here are *declining* -- that is, `revisions[i] >= /// in here are *declining* -- that is, `revisions[i] >=
/// revisions[i + 1]`, for all `i`. This is because when you /// revisions[i + 1]`, for all `i`. This is because when you
/// modify a value with durability D, that implies that values /// modify a value with durability D, that implies that values
/// with durability less than D may have changed too. /// with durability less than D may have changed too.
revisions: Vec<AtomicRevision>, revisions: [AtomicRevision; Durability::LEN],
/// The dependency graph tracks which runtimes are blocked on one /// The dependency graph tracks which runtimes are blocked on one
/// another, waiting for queries to terminate. /// another, waiting for queries to terminate.
dependency_graph: Mutex<DependencyGraph>, dependency_graph: Mutex<DependencyGraph>,
} }
impl SharedState { impl std::panic::RefUnwindSafe for SharedState {}
fn with_durabilities(durabilities: usize) -> Self {
impl Default for SharedState {
fn default() -> Self {
#[allow(clippy::declare_interior_mutable_const)]
const START: AtomicRevision = AtomicRevision::start();
SharedState { SharedState {
next_id: AtomicUsize::new(1), next_id: AtomicU32::new(1),
query_lock: Default::default(), query_lock: Default::default(),
revisions: (0..durabilities).map(|_| AtomicRevision::start()).collect(), revisions: [START; Durability::LEN],
pending_revision: AtomicRevision::start(), pending_revision: START,
dependency_graph: Default::default(), dependency_graph: Default::default(),
} }
} }
} }
impl std::panic::RefUnwindSafe for SharedState {}
impl Default for SharedState {
fn default() -> Self {
Self::with_durabilities(Durability::LEN)
}
}
impl std::fmt::Debug for SharedState { impl std::fmt::Debug for SharedState {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let query_lock = if self.query_lock.try_write().is_some() { let query_lock = if self.query_lock.is_locked_exclusive() {
"<unlocked>" "<wlocked>"
} else if self.query_lock.try_read().is_some() { } else if self.query_lock.is_locked() {
"<rlocked>" "<rlocked>"
} else { } else {
"<wlocked>" "<unlocked>"
}; };
fmt.debug_struct("SharedState") fmt.debug_struct("SharedState")
.field("query_lock", &query_lock) .field("query_lock", &query_lock)
@ -570,7 +566,9 @@ impl ActiveQuery {
if dependencies.is_empty() { if dependencies.is_empty() {
QueryInputs::NoInputs QueryInputs::NoInputs
} else { } else {
QueryInputs::Tracked { inputs: dependencies.iter().copied().collect() } QueryInputs::Tracked {
inputs: ThinArc::from_header_and_iter((), dependencies.iter().copied()),
}
} }
} }
}; };
@ -616,7 +614,7 @@ impl ActiveQuery {
/// complete, its `RuntimeId` may potentially be re-used. /// complete, its `RuntimeId` may potentially be re-used.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RuntimeId { pub struct RuntimeId {
counter: usize, counter: u32,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]

View file

@ -1,5 +1,6 @@
//! //!
use tracing::debug; use tracing::debug;
use triomphe::ThinArc;
use crate::durability::Durability; use crate::durability::Durability;
use crate::runtime::ActiveQuery; use crate::runtime::ActiveQuery;
@ -7,7 +8,6 @@ use crate::runtime::Revision;
use crate::Cycle; use crate::Cycle;
use crate::DatabaseKeyIndex; use crate::DatabaseKeyIndex;
use std::cell::RefCell; use std::cell::RefCell;
use triomphe::Arc;
/// State that is specific to a single execution thread. /// State that is specific to a single execution thread.
/// ///
@ -43,7 +43,7 @@ pub(crate) struct QueryRevisions {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) enum QueryInputs { pub(crate) enum QueryInputs {
/// Non-empty set of inputs, fully known /// Non-empty set of inputs, fully known
Tracked { inputs: Arc<[DatabaseKeyIndex]> }, Tracked { inputs: ThinArc<(), DatabaseKeyIndex> },
/// Empty set of inputs, fully known. /// Empty set of inputs, fully known.
NoInputs, NoInputs,
@ -145,8 +145,7 @@ impl LocalState {
/// the current thread is blocking. The stack must be restored /// the current thread is blocking. The stack must be restored
/// with [`Self::restore_query_stack`] when the thread unblocks. /// with [`Self::restore_query_stack`] when the thread unblocks.
pub(super) fn take_query_stack(&self) -> Vec<ActiveQuery> { pub(super) fn take_query_stack(&self) -> Vec<ActiveQuery> {
assert!(self.query_stack.borrow().is_some(), "query stack already taken"); self.query_stack.take().expect("query stack already taken")
self.query_stack.take().unwrap()
} }
/// Restores a query stack taken with [`Self::take_query_stack`] once /// Restores a query stack taken with [`Self::take_query_stack`] once

View file

@ -58,7 +58,7 @@ fn revalidate() {
// Second generation: volatile will change (to 1) but memoized1 // Second generation: volatile will change (to 1) but memoized1
// will not (still 0, as 1/2 = 0) // will not (still 0, as 1/2 = 0)
query.salsa_runtime_mut().synthetic_write(Durability::LOW); query.synthetic_write(Durability::LOW);
query.memoized2(); query.memoized2();
query.assert_log(&["Volatile invoked", "Memoized1 invoked"]); query.assert_log(&["Volatile invoked", "Memoized1 invoked"]);
query.memoized2(); query.memoized2();
@ -67,7 +67,7 @@ fn revalidate() {
// Third generation: volatile will change (to 2) and memoized1 // Third generation: volatile will change (to 2) and memoized1
// will too (to 1). Therefore, after validating that Memoized1 // will too (to 1). Therefore, after validating that Memoized1
// changed, we now invoke Memoized2. // changed, we now invoke Memoized2.
query.salsa_runtime_mut().synthetic_write(Durability::LOW); query.synthetic_write(Durability::LOW);
query.memoized2(); query.memoized2();
query.assert_log(&["Volatile invoked", "Memoized1 invoked", "Memoized2 invoked"]); query.assert_log(&["Volatile invoked", "Memoized1 invoked", "Memoized2 invoked"]);

View file

@ -111,7 +111,7 @@ fn on_demand_input_durability() {
} }
"#]].assert_debug_eq(&events); "#]].assert_debug_eq(&events);
db.salsa_runtime_mut().synthetic_write(Durability::LOW); db.synthetic_write(Durability::LOW);
events.replace(vec![]); events.replace(vec![]);
assert_eq!(db.c(1), 10); assert_eq!(db.c(1), 10);
assert_eq!(db.c(2), 20); assert_eq!(db.c(2), 20);
@ -128,7 +128,7 @@ fn on_demand_input_durability() {
} }
"#]].assert_debug_eq(&events); "#]].assert_debug_eq(&events);
db.salsa_runtime_mut().synthetic_write(Durability::HIGH); db.synthetic_write(Durability::HIGH);
events.replace(vec![]); events.replace(vec![]);
assert_eq!(db.c(1), 10); assert_eq!(db.c(1), 10);
assert_eq!(db.c(2), 20); assert_eq!(db.c(2), 20);

View file

@ -20,7 +20,7 @@ fn volatile_twice() {
let v2 = db.volatile(); // volatiles are cached, so 2nd read returns the same let v2 = db.volatile(); // volatiles are cached, so 2nd read returns the same
assert_eq!(v1, v2); assert_eq!(v1, v2);
db.salsa_runtime_mut().synthetic_write(Durability::LOW); // clears volatile caches db.synthetic_write(Durability::LOW); // clears volatile caches
let v3 = db.volatile(); // will re-increment the counter let v3 = db.volatile(); // will re-increment the counter
let v4 = db.volatile(); // second call will be cached let v4 = db.volatile(); // second call will be cached
@ -40,7 +40,7 @@ fn intermingled() {
assert_eq!(v1, v3); assert_eq!(v1, v3);
assert_eq!(v2, v4); assert_eq!(v2, v4);
db.salsa_runtime_mut().synthetic_write(Durability::LOW); // clears volatile caches db.synthetic_write(Durability::LOW); // clears volatile caches
let v5 = db.memoized(); // re-executes volatile, caches new result let v5 = db.memoized(); // re-executes volatile, caches new result
let v6 = db.memoized(); // re-use cached result let v6 = db.memoized(); // re-use cached result