Auto merge of #16503 - Veykril:salsa, r=Veykril

Move salsa fork in-tree

No one else is supposed to rely on it anyways, this makes it easier to edit.
This commit is contained in:
bors 2024-02-07 16:20:53 +00:00
commit 2d2ddd318b
69 changed files with 9083 additions and 38 deletions

146
Cargo.lock generated
View file

@ -72,8 +72,8 @@ dependencies = [
"cfg", "cfg",
"la-arena 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "la-arena 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"profile", "profile",
"rust-analyzer-salsa",
"rustc-hash", "rustc-hash",
"salsa",
"semver", "semver",
"span", "span",
"stdx", "stdx",
@ -357,6 +357,15 @@ dependencies = [
"log", "log",
] ]
[[package]]
name = "env_logger"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580"
dependencies = [
"log",
]
[[package]] [[package]]
name = "equivalent" name = "equivalent"
version = "1.0.0" version = "1.0.0"
@ -441,6 +450,17 @@ version = "0.4.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a" checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a"
[[package]]
name = "getrandom"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]] [[package]]
name = "gimli" name = "gimli"
version = "0.27.3" version = "0.27.3"
@ -918,6 +938,12 @@ dependencies = [
"text-size", "text-size",
] ]
[[package]]
name = "linked-hash-map"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
[[package]] [[package]]
name = "load-cargo" name = "load-cargo"
version = "0.0.0" version = "0.0.0"
@ -1261,6 +1287,12 @@ version = "0.2.9"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116"
[[package]]
name = "ppv-lite86"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]] [[package]]
name = "proc-macro-api" name = "proc-macro-api"
version = "0.0.0" version = "0.0.0"
@ -1504,6 +1536,36 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]] [[package]]
name = "rayon" name = "rayon"
version = "1.8.0" version = "1.8.0"
@ -1611,35 +1673,6 @@ dependencies = [
"xshell", "xshell",
] ]
[[package]]
name = "rust-analyzer-salsa"
version = "0.17.0-pre.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719825638c59fd26a55412a24561c7c5bcf54364c88b9a7a04ba08a6eafaba8d"
dependencies = [
"indexmap",
"lock_api",
"oorandom",
"parking_lot",
"rust-analyzer-salsa-macros",
"rustc-hash",
"smallvec",
"tracing",
"triomphe",
]
[[package]]
name = "rust-analyzer-salsa-macros"
version = "0.17.0-pre.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d96498e9684848c6676c399032ebc37c52da95ecbefa83d71ccc53b9f8a4a8e"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "rustc-demangle" name = "rustc-demangle"
version = "0.1.23" version = "0.1.23"
@ -1668,6 +1701,36 @@ version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
[[package]]
name = "salsa"
version = "0.0.0"
dependencies = [
"dissimilar",
"expect-test",
"indexmap",
"linked-hash-map",
"lock_api",
"oorandom",
"parking_lot",
"rand",
"rustc-hash",
"salsa-macros",
"smallvec",
"test-log",
"tracing",
"triomphe",
]
[[package]]
name = "salsa-macros"
version = "0.0.0"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "same-file" name = "same-file"
version = "1.0.6" version = "1.0.6"
@ -1792,7 +1855,7 @@ name = "span"
version = "0.0.0" version = "0.0.0"
dependencies = [ dependencies = [
"la-arena 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "la-arena 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"rust-analyzer-salsa", "salsa",
"stdx", "stdx",
"syntax", "syntax",
"vfs", "vfs",
@ -1889,6 +1952,27 @@ dependencies = [
"tt", "tt",
] ]
[[package]]
name = "test-log"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6159ab4116165c99fc88cce31f99fa2c9dbe08d3691cb38da02fc3b45f357d2b"
dependencies = [
"env_logger",
"test-log-macros",
]
[[package]]
name = "test-log-macros"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ba277e77219e9eea169e8508942db1bf5d8a41ff2db9b20aab5a5aadc9fa25d"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "test-utils" name = "test-utils"
version = "0.0.0" version = "0.0.0"

View file

@ -70,6 +70,7 @@ proc-macro-srv = { path = "./crates/proc-macro-srv", version = "0.0.0" }
proc-macro-srv-cli = { path = "./crates/proc-macro-srv-cli", version = "0.0.0" } proc-macro-srv-cli = { path = "./crates/proc-macro-srv-cli", version = "0.0.0" }
profile = { path = "./crates/profile", version = "0.0.0" } profile = { path = "./crates/profile", version = "0.0.0" }
project-model = { path = "./crates/project-model", version = "0.0.0" } project-model = { path = "./crates/project-model", version = "0.0.0" }
salsa = { path = "./crates/salsa", version = "0.0.0" }
span = { path = "./crates/span", version = "0.0.0" } span = { path = "./crates/span", version = "0.0.0" }
stdx = { path = "./crates/stdx", version = "0.0.0" } stdx = { path = "./crates/stdx", version = "0.0.0" }
syntax = { path = "./crates/syntax", version = "0.0.0" } syntax = { path = "./crates/syntax", version = "0.0.0" }
@ -113,7 +114,6 @@ itertools = "0.12.0"
libc = "0.2.150" libc = "0.2.150"
nohash-hasher = "0.2.0" nohash-hasher = "0.2.0"
rayon = "1.8.0" rayon = "1.8.0"
rust-analyzer-salsa = "0.17.0-pre.6"
rustc-hash = "1.1.0" rustc-hash = "1.1.0"
semver = "1.0.14" semver = "1.0.14"
serde = { version = "1.0.192", features = ["derive"] } serde = { version = "1.0.192", features = ["derive"] }

View file

@ -13,7 +13,7 @@ doctest = false
[dependencies] [dependencies]
la-arena.workspace = true la-arena.workspace = true
rust-analyzer-salsa.workspace = true salsa.workspace = true
rustc-hash.workspace = true rustc-hash.workspace = true
triomphe.workspace = true triomphe.workspace = true
semver.workspace = true semver.workspace = true

35
crates/salsa/Cargo.toml Normal file
View file

@ -0,0 +1,35 @@
[package]
name = "salsa"
version = "0.0.0"
authors = ["Salsa developers"]
edition = "2021"
license = "Apache-2.0 OR MIT"
repository = "https://github.com/salsa-rs/salsa"
description = "A generic framework for on-demand, incrementalized computation (experimental)"
rust-version.workspace = true
[lib]
name = "salsa"
[dependencies]
indexmap = "2.1.0"
lock_api = "0.4"
tracing = "0.1"
parking_lot = "0.12.1"
rustc-hash = "1.0"
smallvec = "1.0.0"
oorandom = "11"
triomphe = "0.1.11"
salsa-macros = { version = "0.0.0", path = "salsa-macros" }
[dev-dependencies]
linked-hash-map = "0.5.6"
rand = "0.8.5"
test-log = "0.2.14"
expect-test = "1.4.0"
dissimilar = "1.0.7"
[lints]
workspace = true

34
crates/salsa/FAQ.md Normal file
View file

@ -0,0 +1,34 @@
# Frequently asked questions
## Why is it called salsa?
I like salsa! Don't you?! Well, ok, there's a bit more to it. The
underlying algorithm for figuring out which bits of code need to be
re-executed after any given change is based on the algorithm used in
rustc. Michael Woerister and I first described the rustc algorithm in
terms of two colors, red and green, and hence we called it the
"red-green algorithm". This made me think of the New Mexico State
Question --- ["Red or green?"][nm] --- which refers to chile
(salsa). Although this version no longer uses colors (we borrowed
revision counters from Glimmer, instead), I still like the name.
[nm]: https://www.sos.state.nm.us/about-new-mexico/state-question/
## What is the relationship between salsa and an Entity-Component System (ECS)?
You may have noticed that Salsa "feels" a lot like an ECS in some
ways. That's true -- Salsa's queries are a bit like *components* (and
the keys to the queries are a bit like *entities*). But there is one
big difference: **ECS is -- at its heart -- a mutable system**. You
can get or set a component of some entity whenever you like. In
contrast, salsa's queries **define "derived values" via pure
computations**.
Partly as a consequence, ECS doesn't handle incremental updates for
you. When you update some component of some entity, you have to ensure
that other entities' components are updated appropriately.
Finally, ECS offers interesting metadata and "aspect-like" facilities,
such as iterating over all entities that share certain components.
Salsa has no analogue to that.

201
crates/salsa/LICENSE-APACHE Normal file
View file

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

23
crates/salsa/LICENSE-MIT Normal file
View file

@ -0,0 +1,23 @@
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

42
crates/salsa/README.md Normal file
View file

@ -0,0 +1,42 @@
# salsa
*A generic framework for on-demand, incrementalized computation.*
## Obligatory warning
This is a fork of https://github.com/salsa-rs/salsa/ adjusted to rust-analyzer's needs.
## Credits
This system is heavily inspired by [adapton](http://adapton.org/), [glimmer](https://github.com/glimmerjs/glimmer-vm), and rustc's query
system. So credit goes to Eduard-Mihai Burtescu, Matthew Hammer,
Yehuda Katz, and Michael Woerister.
## Key idea
The key idea of `salsa` is that you define your program as a set of
**queries**. Every query is used like function `K -> V` that maps from
some key of type `K` to a value of type `V`. Queries come in two basic
varieties:
- **Inputs**: the base inputs to your system. You can change these
whenever you like.
- **Functions**: pure functions (no side effects) that transform your
inputs into other values. The results of queries is memoized to
avoid recomputing them a lot. When you make changes to the inputs,
we'll figure out (fairly intelligently) when we can re-use these
memoized values and when we have to recompute them.
## Want to learn more?
To learn more about Salsa, try one of the following:
- read the [heavily commented `hello_world` example](https://github.com/salsa-rs/salsa/blob/master/examples/hello_world/main.rs);
- check out the [Salsa book](https://salsa-rs.github.io/salsa);
- watch one of our [videos](https://salsa-rs.github.io/salsa/videos.html).
## Getting in touch
The bulk of the discussion happens in the [issues](https://github.com/salsa-rs/salsa/issues)
and [pull requests](https://github.com/salsa-rs/salsa/pulls),
but we have a [zulip chat](https://salsa.zulipchat.com/) as well.

View file

@ -0,0 +1,23 @@
[package]
name = "salsa-macros"
version = "0.0.0"
authors = ["Salsa developers"]
edition = "2021"
license = "Apache-2.0 OR MIT"
repository = "https://github.com/salsa-rs/salsa"
description = "Procedural macros for the salsa crate"
rust-version.workspace = true
[lib]
proc-macro = true
name = "salsa_macros"
[dependencies]
heck = "0.4"
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "2.0", features = ["full", "extra-traits"] }
[lints]
workspace = true

View file

@ -0,0 +1 @@
../LICENSE-APACHE

View file

@ -0,0 +1 @@
../LICENSE-MIT

View file

@ -0,0 +1 @@
../README.md

View file

@ -0,0 +1,250 @@
//!
use heck::ToSnakeCase;
use proc_macro::TokenStream;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{Ident, ItemStruct, Path, Token};
type PunctuatedQueryGroups = Punctuated<QueryGroup, Token![,]>;
pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream {
let args = syn::parse_macro_input!(args as QueryGroupList);
let input = syn::parse_macro_input!(input as ItemStruct);
let query_groups = &args.query_groups;
let database_name = &input.ident;
let visibility = &input.vis;
let db_storage_field = quote! { storage };
let mut output = proc_macro2::TokenStream::new();
output.extend(quote! { #input });
let query_group_names_snake: Vec<_> = query_groups
.iter()
.map(|query_group| {
let group_name = query_group.name();
Ident::new(&group_name.to_string().to_snake_case(), group_name.span())
})
.collect();
let query_group_storage_names: Vec<_> = query_groups
.iter()
.map(|QueryGroup { group_path }| {
quote! {
<#group_path as salsa::plumbing::QueryGroup>::GroupStorage
}
})
.collect();
// For each query group `foo::MyGroup` create a link to its
// `foo::MyGroupGroupStorage`
let mut storage_fields = proc_macro2::TokenStream::new();
let mut storage_initializers = proc_macro2::TokenStream::new();
let mut has_group_impls = proc_macro2::TokenStream::new();
for (((query_group, group_name_snake), group_storage), group_index) in query_groups
.iter()
.zip(&query_group_names_snake)
.zip(&query_group_storage_names)
.zip(0_u16..)
{
let group_path = &query_group.group_path;
// rewrite the last identifier (`MyGroup`, above) to
// (e.g.) `MyGroupGroupStorage`.
storage_fields.extend(quote! {
#group_name_snake: #group_storage,
});
// rewrite the last identifier (`MyGroup`, above) to
// (e.g.) `MyGroupGroupStorage`.
storage_initializers.extend(quote! {
#group_name_snake: #group_storage::new(#group_index),
});
// ANCHOR:HasQueryGroup
has_group_impls.extend(quote! {
impl salsa::plumbing::HasQueryGroup<#group_path> for #database_name {
fn group_storage(&self) -> &#group_storage {
&self.#db_storage_field.query_store().#group_name_snake
}
fn group_storage_mut(&mut self) -> (&#group_storage, &mut salsa::Runtime) {
let (query_store_mut, runtime) = self.#db_storage_field.query_store_mut();
(&query_store_mut.#group_name_snake, runtime)
}
}
});
// ANCHOR_END:HasQueryGroup
}
// create group storage wrapper struct
output.extend(quote! {
#[doc(hidden)]
#visibility struct __SalsaDatabaseStorage {
#storage_fields
}
impl Default for __SalsaDatabaseStorage {
fn default() -> Self {
Self {
#storage_initializers
}
}
}
});
// Create a tuple (D1, D2, ...) where Di is the data for a given query group.
let mut database_data = vec![];
for QueryGroup { group_path } in query_groups {
database_data.push(quote! {
<#group_path as salsa::plumbing::QueryGroup>::GroupData
});
}
// ANCHOR:DatabaseStorageTypes
output.extend(quote! {
impl salsa::plumbing::DatabaseStorageTypes for #database_name {
type DatabaseStorage = __SalsaDatabaseStorage;
}
});
// ANCHOR_END:DatabaseStorageTypes
// ANCHOR:DatabaseOps
let mut fmt_ops = proc_macro2::TokenStream::new();
let mut maybe_changed_ops = proc_macro2::TokenStream::new();
let mut cycle_recovery_strategy_ops = proc_macro2::TokenStream::new();
let mut for_each_ops = proc_macro2::TokenStream::new();
for ((QueryGroup { group_path }, group_storage), group_index) in
query_groups.iter().zip(&query_group_storage_names).zip(0_u16..)
{
fmt_ops.extend(quote! {
#group_index => {
let storage: &#group_storage =
<Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
storage.fmt_index(self, input, fmt)
}
});
maybe_changed_ops.extend(quote! {
#group_index => {
let storage: &#group_storage =
<Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
storage.maybe_changed_after(self, input, revision)
}
});
cycle_recovery_strategy_ops.extend(quote! {
#group_index => {
let storage: &#group_storage =
<Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
storage.cycle_recovery_strategy(self, input)
}
});
for_each_ops.extend(quote! {
let storage: &#group_storage =
<Self as salsa::plumbing::HasQueryGroup<#group_path>>::group_storage(self);
storage.for_each_query(runtime, &mut op);
});
}
output.extend(quote! {
impl salsa::plumbing::DatabaseOps for #database_name {
fn ops_database(&self) -> &dyn salsa::Database {
self
}
fn ops_salsa_runtime(&self) -> &salsa::Runtime {
self.#db_storage_field.salsa_runtime()
}
fn ops_salsa_runtime_mut(&mut self) -> &mut salsa::Runtime {
self.#db_storage_field.salsa_runtime_mut()
}
fn fmt_index(
&self,
input: salsa::DatabaseKeyIndex,
fmt: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
match input.group_index() {
#fmt_ops
i => panic!("salsa: invalid group index {}", i)
}
}
fn maybe_changed_after(
&self,
input: salsa::DatabaseKeyIndex,
revision: salsa::Revision
) -> bool {
match input.group_index() {
#maybe_changed_ops
i => panic!("salsa: invalid group index {}", i)
}
}
fn cycle_recovery_strategy(
&self,
input: salsa::DatabaseKeyIndex,
) -> salsa::plumbing::CycleRecoveryStrategy {
match input.group_index() {
#cycle_recovery_strategy_ops
i => panic!("salsa: invalid group index {}", i)
}
}
fn for_each_query(
&self,
mut op: &mut dyn FnMut(&dyn salsa::plumbing::QueryStorageMassOps),
) {
let runtime = salsa::Database::salsa_runtime(self);
#for_each_ops
}
}
});
// ANCHOR_END:DatabaseOps
output.extend(has_group_impls);
output.into()
}
#[derive(Clone, Debug)]
struct QueryGroupList {
query_groups: PunctuatedQueryGroups,
}
impl Parse for QueryGroupList {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let query_groups: PunctuatedQueryGroups =
input.parse_terminated(QueryGroup::parse, Token![,])?;
Ok(QueryGroupList { query_groups })
}
}
#[derive(Clone, Debug)]
struct QueryGroup {
group_path: Path,
}
impl QueryGroup {
/// The name of the query group trait.
fn name(&self) -> Ident {
self.group_path.segments.last().unwrap().ident.clone()
}
}
impl Parse for QueryGroup {
/// ```ignore
/// impl HelloWorldDatabase;
/// ```
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let group_path: Path = input.parse()?;
Ok(QueryGroup { group_path })
}
}
struct Nothing;
impl Parse for Nothing {
fn parse(_input: ParseStream<'_>) -> syn::Result<Self> {
Ok(Nothing)
}
}

View file

@ -0,0 +1,146 @@
//! This crate provides salsa's macros and attributes.
#![recursion_limit = "256"]
#[macro_use]
extern crate quote;
use proc_macro::TokenStream;
mod database_storage;
mod parenthesized;
mod query_group;
/// The decorator that defines a salsa "query group" trait. This is a
/// trait that defines everything that a block of queries need to
/// execute, as well as defining the queries themselves that are
/// exported for others to use.
///
/// This macro declares the "prototype" for a group of queries. It will
/// expand into a trait and a set of structs, one per query.
///
/// For each query, you give the name of the accessor method to invoke
/// the query (e.g., `my_query`, below), as well as its parameter
/// types and the output type. You also give the name for a query type
/// (e.g., `MyQuery`, below) that represents the query, and optionally
/// other details, such as its storage.
///
/// # Examples
///
/// The simplest example is something like this:
///
/// ```ignore
/// #[salsa::query_group]
/// trait TypeckDatabase {
/// #[salsa::input] // see below for other legal attributes
/// fn my_query(&self, input: u32) -> u64;
///
/// /// Queries can have any number of inputs (including zero); if there
/// /// is not exactly one input, then the key type will be
/// /// a tuple of the input types, so in this case `(u32, f32)`.
/// fn other_query(&self, input1: u32, input2: f32) -> u64;
/// }
/// ```
///
/// Here is a list of legal `salsa::XXX` attributes:
///
/// - Storage attributes: control how the query data is stored and set. These
/// are described in detail in the section below.
/// - `#[salsa::input]`
/// - `#[salsa::memoized]`
/// - `#[salsa::dependencies]`
/// - Query execution:
/// - `#[salsa::invoke(path::to::my_fn)]` -- for a non-input, this
/// indicates the function to call when a query must be
/// recomputed. The default is to call a function in the same
/// module with the same name as the query.
/// - `#[query_type(MyQueryTypeName)]` specifies the name of the
/// dummy struct created for the query. Default is the name of the
/// query, in camel case, plus the word "Query" (e.g.,
/// `MyQueryQuery` and `OtherQueryQuery` in the examples above).
///
/// # Storage attributes
///
/// Here are the possible storage values for each query. The default
/// is `storage memoized`.
///
/// ## Input queries
///
/// Specifying `storage input` will give you an **input
/// query**. Unlike derived queries, whose value is given by a
/// function, input queries are explicitly set by doing
/// `db.query(QueryType).set(key, value)` (where `QueryType` is the
/// `type` specified for the query). Accessing a value that has not
/// yet been set will panic. Each time you invoke `set`, we assume the
/// value has changed, and so we will potentially re-execute derived
/// queries that read (transitively) from this input.
///
/// ## Derived queries
///
/// Derived queries are specified by a function.
///
/// - `#[salsa::memoized]` (the default) -- The result is memoized
/// between calls. If the inputs have changed, we will recompute
/// the value, but then compare against the old memoized value,
/// which can significantly reduce the amount of recomputation
/// required in new revisions. This does require that the value
/// implements `Eq`.
/// - `#[salsa::dependencies]` -- does not cache the value, so it will
/// be recomputed every time it is needed. We do track the inputs, however,
/// so if they have not changed, then things that rely on this query
/// may be known not to have changed.
///
/// ## Attribute combinations
///
/// Some attributes are mutually exclusive. For example, it is an error to add
/// multiple storage specifiers:
///
/// ```compile_fail
/// # use salsa_macros as salsa;
/// #[salsa::query_group]
/// trait CodegenDatabase {
/// #[salsa::input]
/// #[salsa::memoized]
/// fn my_query(&self, input: u32) -> u64;
/// }
/// ```
///
/// It is also an error to annotate a function to `invoke` on an `input` query:
///
/// ```compile_fail
/// # use salsa_macros as salsa;
/// #[salsa::query_group]
/// trait CodegenDatabase {
/// #[salsa::input]
/// #[salsa::invoke(typeck::my_query)]
/// fn my_query(&self, input: u32) -> u64;
/// }
/// ```
#[proc_macro_attribute]
pub fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
query_group::query_group(args, input)
}
/// This attribute is placed on your database struct. It takes a list of the
/// query groups that your database supports. The format looks like so:
///
/// ```rust,ignore
/// #[salsa::database(MyQueryGroup1, MyQueryGroup2)]
/// struct MyDatabase {
/// runtime: salsa::Runtime<MyDatabase>, // <-- your database will need this field, too
/// }
/// ```
///
/// Here, the struct `MyDatabase` would support the two query groups
/// `MyQueryGroup1` and `MyQueryGroup2`. In addition to the `database`
/// attribute, the struct needs to have a `runtime` field (of type
/// [`salsa::Runtime`]) and to implement the `salsa::Database` trait.
///
/// See [the `hello_world` example][hw] for more details.
///
/// [`salsa::Runtime`]: struct.Runtime.html
/// [hw]: https://github.com/salsa-rs/salsa/tree/master/examples/hello_world
#[proc_macro_attribute]
pub fn database(args: TokenStream, input: TokenStream) -> TokenStream {
database_storage::database(args, input)
}

View file

@ -0,0 +1,13 @@
//!
pub(crate) struct Parenthesized<T>(pub(crate) T);
impl<T> syn::parse::Parse for Parenthesized<T>
where
T: syn::parse::Parse,
{
fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
let content;
syn::parenthesized!(content in input);
content.parse::<T>().map(Parenthesized)
}
}

View file

@ -0,0 +1,734 @@
//!
use std::{convert::TryFrom, iter::FromIterator};
use crate::parenthesized::Parenthesized;
use heck::ToUpperCamelCase;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::ToTokens;
use syn::{
parse_macro_input, parse_quote, spanned::Spanned, Attribute, Error, FnArg, Ident, ItemTrait,
ReturnType, TraitItem, Type,
};
/// Implementation for `[salsa::query_group]` decorator.
pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream {
let group_struct = parse_macro_input!(args as Ident);
let input: ItemTrait = parse_macro_input!(input as ItemTrait);
// println!("args: {:#?}", args);
// println!("input: {:#?}", input);
let input_span = input.span();
let (trait_attrs, salsa_attrs) = filter_attrs(input.attrs);
if !salsa_attrs.is_empty() {
return Error::new(input_span, format!("unsupported attributes: {:?}", salsa_attrs))
.to_compile_error()
.into();
}
let trait_vis = input.vis;
let trait_name = input.ident;
let _generics = input.generics.clone();
let dyn_db = quote! { dyn #trait_name };
// Decompose the trait into the corresponding queries.
let mut queries = vec![];
for item in input.items {
if let TraitItem::Fn(method) = item {
let query_name = method.sig.ident.to_string();
let mut storage = QueryStorage::Memoized;
let mut cycle = None;
let mut invoke = None;
let mut query_type =
format_ident!("{}Query", query_name.to_string().to_upper_camel_case());
let mut num_storages = 0;
// Extract attributes.
let (attrs, salsa_attrs) = filter_attrs(method.attrs);
for SalsaAttr { name, tts, span } in salsa_attrs {
match name.as_str() {
"memoized" => {
storage = QueryStorage::Memoized;
num_storages += 1;
}
"dependencies" => {
storage = QueryStorage::Dependencies;
num_storages += 1;
}
"input" => {
storage = QueryStorage::Input;
num_storages += 1;
}
"interned" => {
storage = QueryStorage::Interned;
num_storages += 1;
}
"cycle" => {
cycle = Some(parse_macro_input!(tts as Parenthesized<syn::Path>).0);
}
"invoke" => {
invoke = Some(parse_macro_input!(tts as Parenthesized<syn::Path>).0);
}
"query_type" => {
query_type = parse_macro_input!(tts as Parenthesized<Ident>).0;
}
"transparent" => {
storage = QueryStorage::Transparent;
num_storages += 1;
}
_ => {
return Error::new(span, format!("unknown salsa attribute `{}`", name))
.to_compile_error()
.into();
}
}
}
let sig_span = method.sig.span();
// Check attribute combinations.
if num_storages > 1 {
return Error::new(sig_span, "multiple storage attributes specified")
.to_compile_error()
.into();
}
match &invoke {
Some(invoke) if storage == QueryStorage::Input => {
return Error::new(
invoke.span(),
"#[salsa::invoke] cannot be set on #[salsa::input] queries",
)
.to_compile_error()
.into();
}
_ => {}
}
// Extract keys.
let mut iter = method.sig.inputs.iter();
let self_receiver = match iter.next() {
Some(FnArg::Receiver(sr)) if sr.mutability.is_none() => sr,
_ => {
return Error::new(
sig_span,
format!("first argument of query `{}` must be `&self`", query_name),
)
.to_compile_error()
.into();
}
};
let mut keys: Vec<(Ident, Type)> = vec![];
for (idx, arg) in iter.enumerate() {
match arg {
FnArg::Typed(syn::PatType { pat, ty, .. }) => keys.push((
match pat.as_ref() {
syn::Pat::Ident(ident_pat) => ident_pat.ident.clone(),
_ => format_ident!("key{}", idx),
},
Type::clone(ty),
)),
arg => {
return Error::new(
arg.span(),
format!("unsupported argument `{:?}` of `{}`", arg, query_name,),
)
.to_compile_error()
.into();
}
}
}
// Extract value.
let value = match method.sig.output {
ReturnType::Type(_, ref ty) => ty.as_ref().clone(),
ref ret => {
return Error::new(
ret.span(),
format!("unsupported return type `{:?}` of `{}`", ret, query_name),
)
.to_compile_error()
.into();
}
};
// For `#[salsa::interned]` keys, we create a "lookup key" automatically.
//
// For a query like:
//
// fn foo(&self, x: Key1, y: Key2) -> u32
//
// we would create
//
// fn lookup_foo(&self, x: u32) -> (Key1, Key2)
let lookup_query = if let QueryStorage::Interned = storage {
let lookup_query_type =
format_ident!("{}LookupQuery", query_name.to_string().to_upper_camel_case());
let lookup_fn_name = format_ident!("lookup_{}", query_name);
let keys = keys.iter().map(|(_, ty)| ty);
let lookup_value: Type = parse_quote!((#(#keys),*));
let lookup_keys = vec![(parse_quote! { key }, value.clone())];
Some(Query {
query_type: lookup_query_type,
query_name: format!("{}", lookup_fn_name),
fn_name: lookup_fn_name,
receiver: self_receiver.clone(),
attrs: vec![], // FIXME -- some automatically generated docs on this method?
storage: QueryStorage::InternedLookup { intern_query_type: query_type.clone() },
keys: lookup_keys,
value: lookup_value,
invoke: None,
cycle: cycle.clone(),
})
} else {
None
};
queries.push(Query {
query_type,
query_name,
fn_name: method.sig.ident,
receiver: self_receiver.clone(),
attrs,
storage,
keys,
value,
invoke,
cycle,
});
queries.extend(lookup_query);
}
}
let group_storage = format_ident!("{}GroupStorage__", trait_name, span = Span::call_site());
let mut query_fn_declarations = proc_macro2::TokenStream::new();
let mut query_fn_definitions = proc_macro2::TokenStream::new();
let mut storage_fields = proc_macro2::TokenStream::new();
let mut queries_with_storage = vec![];
for query in &queries {
#[allow(clippy::map_identity)]
// clippy is incorrect here, this is not the identity function due to match ergonomics
let (key_names, keys): (Vec<_>, Vec<_>) = query.keys.iter().map(|(a, b)| (a, b)).unzip();
let value = &query.value;
let fn_name = &query.fn_name;
let qt = &query.query_type;
let attrs = &query.attrs;
let self_receiver = &query.receiver;
query_fn_declarations.extend(quote! {
#(#attrs)*
fn #fn_name(#self_receiver, #(#key_names: #keys),*) -> #value;
});
// Special case: transparent queries don't create actual storage,
// just inline the definition
if let QueryStorage::Transparent = query.storage {
let invoke = query.invoke_tt();
query_fn_definitions.extend(quote! {
fn #fn_name(&self, #(#key_names: #keys),*) -> #value {
#invoke(self, #(#key_names),*)
}
});
continue;
}
queries_with_storage.push(fn_name);
query_fn_definitions.extend(quote! {
fn #fn_name(&self, #(#key_names: #keys),*) -> #value {
// Create a shim to force the code to be monomorphized in the
// query crate. Our experiments revealed that this makes a big
// difference in total compilation time in rust-analyzer, though
// it's not totally obvious why that should be.
fn __shim(db: &(dyn #trait_name + '_), #(#key_names: #keys),*) -> #value {
salsa::plumbing::get_query_table::<#qt>(db).get((#(#key_names),*))
}
__shim(self, #(#key_names),*)
}
});
// For input queries, we need `set_foo` etc
if let QueryStorage::Input = query.storage {
let set_fn_name = format_ident!("set_{}", fn_name);
let set_with_durability_fn_name = format_ident!("set_{}_with_durability", fn_name);
let set_fn_docs = format!(
"
Set the value of the `{fn_name}` input.
See `{fn_name}` for details.
*Note:* Setting values will trigger cancellation
of any ongoing queries; this method blocks until
those queries have been cancelled.
",
fn_name = fn_name
);
let set_constant_fn_docs = format!(
"
Set the value of the `{fn_name}` input with a
specific durability instead of the default of
`Durability::LOW`. You can use `Durability::MAX`
to promise that its value will never change again.
See `{fn_name}` for details.
*Note:* Setting values will trigger cancellation
of any ongoing queries; this method blocks until
those queries have been cancelled.
",
fn_name = fn_name
);
query_fn_declarations.extend(quote! {
# [doc = #set_fn_docs]
fn #set_fn_name(&mut self, #(#key_names: #keys,)* value__: #value);
# [doc = #set_constant_fn_docs]
fn #set_with_durability_fn_name(&mut self, #(#key_names: #keys,)* value__: #value, durability__: salsa::Durability);
});
query_fn_definitions.extend(quote! {
fn #set_fn_name(&mut self, #(#key_names: #keys,)* value__: #value) {
fn __shim(db: &mut dyn #trait_name, #(#key_names: #keys,)* value__: #value) {
salsa::plumbing::get_query_table_mut::<#qt>(db).set((#(#key_names),*), value__)
}
__shim(self, #(#key_names,)* value__)
}
fn #set_with_durability_fn_name(&mut self, #(#key_names: #keys,)* value__: #value, durability__: salsa::Durability) {
fn __shim(db: &mut dyn #trait_name, #(#key_names: #keys,)* value__: #value, durability__: salsa::Durability) {
salsa::plumbing::get_query_table_mut::<#qt>(db).set_with_durability((#(#key_names),*), value__, durability__)
}
__shim(self, #(#key_names,)* value__ ,durability__)
}
});
}
// A field for the storage struct
storage_fields.extend(quote! {
#fn_name: std::sync::Arc<<#qt as salsa::Query>::Storage>,
});
}
// Emit the trait itself.
let mut output = {
let bounds = &input.supertraits;
quote! {
#(#trait_attrs)*
#trait_vis trait #trait_name :
salsa::Database +
salsa::plumbing::HasQueryGroup<#group_struct> +
#bounds
{
#query_fn_declarations
}
}
};
// Emit the query group struct and impl of `QueryGroup`.
output.extend(quote! {
/// Representative struct for the query group.
#trait_vis struct #group_struct { }
impl salsa::plumbing::QueryGroup for #group_struct
{
type DynDb = #dyn_db;
type GroupStorage = #group_storage;
}
});
// Emit an impl of the trait
output.extend({
let bounds = input.supertraits;
quote! {
impl<DB> #trait_name for DB
where
DB: #bounds,
DB: salsa::Database,
DB: salsa::plumbing::HasQueryGroup<#group_struct>,
{
#query_fn_definitions
}
}
});
let non_transparent_queries =
|| queries.iter().filter(|q| !matches!(q.storage, QueryStorage::Transparent));
// Emit the query types.
for (query, query_index) in non_transparent_queries().zip(0_u16..) {
let fn_name = &query.fn_name;
let qt = &query.query_type;
let storage = match &query.storage {
QueryStorage::Memoized => quote!(salsa::plumbing::MemoizedStorage<Self>),
QueryStorage::Dependencies => {
quote!(salsa::plumbing::DependencyStorage<Self>)
}
QueryStorage::Input => quote!(salsa::plumbing::InputStorage<Self>),
QueryStorage::Interned => quote!(salsa::plumbing::InternedStorage<Self>),
QueryStorage::InternedLookup { intern_query_type } => {
quote!(salsa::plumbing::LookupInternedStorage<Self, #intern_query_type>)
}
QueryStorage::Transparent => panic!("should have been filtered"),
};
let keys = query.keys.iter().map(|(_, ty)| ty);
let value = &query.value;
let query_name = &query.query_name;
// Emit the query struct and implement the Query trait on it.
output.extend(quote! {
#[derive(Default, Debug)]
#trait_vis struct #qt;
});
output.extend(quote! {
impl #qt {
/// Get access to extra methods pertaining to this query.
/// You can also use it to invoke this query.
#trait_vis fn in_db(self, db: &#dyn_db) -> salsa::QueryTable<'_, Self>
{
salsa::plumbing::get_query_table::<#qt>(db)
}
}
});
output.extend(quote! {
impl #qt {
/// Like `in_db`, but gives access to methods for setting the
/// value of an input. Not applicable to derived queries.
///
/// # Threads, cancellation, and blocking
///
/// Mutating the value of a query cannot be done while there are
/// still other queries executing. If you are using your database
/// within a single thread, this is not a problem: you only have
/// `&self` access to the database, but this method requires `&mut
/// self`.
///
/// However, if you have used `snapshot` to create other threads,
/// then attempts to `set` will **block the current thread** until
/// those snapshots are dropped (usually when those threads
/// complete). This also implies that if you create a snapshot but
/// do not send it to another thread, then invoking `set` will
/// deadlock.
///
/// Before blocking, the thread that is attempting to `set` will
/// also set a cancellation flag. This will cause any query
/// invocations in other threads to unwind with a `Cancelled`
/// sentinel value and eventually let the `set` succeed once all
/// threads have unwound past the salsa invocation.
///
/// If your query implementations are performing expensive
/// operations without invoking another query, you can also use
/// the `Runtime::unwind_if_cancelled` method to check for an
/// ongoing cancellation and bring those operations to a close,
/// thus allowing the `set` to succeed. Otherwise, long-running
/// computations may lead to "starvation", meaning that the
/// thread attempting to `set` has to wait a long, long time. =)
#trait_vis fn in_db_mut(self, db: &mut #dyn_db) -> salsa::QueryTableMut<'_, Self>
{
salsa::plumbing::get_query_table_mut::<#qt>(db)
}
}
impl<'d> salsa::QueryDb<'d> for #qt
{
type DynDb = #dyn_db + 'd;
type Group = #group_struct;
type GroupStorage = #group_storage;
}
// ANCHOR:Query_impl
impl salsa::Query for #qt
{
type Key = (#(#keys),*);
type Value = #value;
type Storage = #storage;
const QUERY_INDEX: u16 = #query_index;
const QUERY_NAME: &'static str = #query_name;
fn query_storage<'a>(
group_storage: &'a <Self as salsa::QueryDb<'_>>::GroupStorage,
) -> &'a std::sync::Arc<Self::Storage> {
&group_storage.#fn_name
}
fn query_storage_mut<'a>(
group_storage: &'a <Self as salsa::QueryDb<'_>>::GroupStorage,
) -> &'a std::sync::Arc<Self::Storage> {
&group_storage.#fn_name
}
}
// ANCHOR_END:Query_impl
});
// Implement the QueryFunction trait for queries which need it.
if query.storage.needs_query_function() {
let span = query.fn_name.span();
let key_names: Vec<_> = query.keys.iter().map(|(pat, _)| pat).collect();
let key_pattern = if query.keys.len() == 1 {
quote! { #(#key_names),* }
} else {
quote! { (#(#key_names),*) }
};
let invoke = query.invoke_tt();
let recover = if let Some(cycle_recovery_fn) = &query.cycle {
quote! {
const CYCLE_STRATEGY: salsa::plumbing::CycleRecoveryStrategy =
salsa::plumbing::CycleRecoveryStrategy::Fallback;
fn cycle_fallback(db: &<Self as salsa::QueryDb<'_>>::DynDb, cycle: &salsa::Cycle, #key_pattern: &<Self as salsa::Query>::Key)
-> <Self as salsa::Query>::Value {
#cycle_recovery_fn(
db,
cycle,
#(#key_names),*
)
}
}
} else {
quote! {
const CYCLE_STRATEGY: salsa::plumbing::CycleRecoveryStrategy =
salsa::plumbing::CycleRecoveryStrategy::Panic;
}
};
output.extend(quote_spanned! {span=>
// ANCHOR:QueryFunction_impl
impl salsa::plumbing::QueryFunction for #qt
{
fn execute(db: &<Self as salsa::QueryDb<'_>>::DynDb, #key_pattern: <Self as salsa::Query>::Key)
-> <Self as salsa::Query>::Value {
#invoke(db, #(#key_names),*)
}
#recover
}
// ANCHOR_END:QueryFunction_impl
});
}
}
let mut fmt_ops = proc_macro2::TokenStream::new();
for (Query { fn_name, .. }, query_index) in non_transparent_queries().zip(0_u16..) {
fmt_ops.extend(quote! {
#query_index => {
salsa::plumbing::QueryStorageOps::fmt_index(
&*self.#fn_name, db, input, fmt,
)
}
});
}
let mut maybe_changed_ops = proc_macro2::TokenStream::new();
for (Query { fn_name, .. }, query_index) in non_transparent_queries().zip(0_u16..) {
maybe_changed_ops.extend(quote! {
#query_index => {
salsa::plumbing::QueryStorageOps::maybe_changed_after(
&*self.#fn_name, db, input, revision
)
}
});
}
let mut cycle_recovery_strategy_ops = proc_macro2::TokenStream::new();
for (Query { fn_name, .. }, query_index) in non_transparent_queries().zip(0_u16..) {
cycle_recovery_strategy_ops.extend(quote! {
#query_index => {
salsa::plumbing::QueryStorageOps::cycle_recovery_strategy(
&*self.#fn_name
)
}
});
}
let mut for_each_ops = proc_macro2::TokenStream::new();
for Query { fn_name, .. } in non_transparent_queries() {
for_each_ops.extend(quote! {
op(&*self.#fn_name);
});
}
// Emit query group storage struct
output.extend(quote! {
#trait_vis struct #group_storage {
#storage_fields
}
// ANCHOR:group_storage_new
impl #group_storage {
#trait_vis fn new(group_index: u16) -> Self {
#group_storage {
#(
#queries_with_storage:
std::sync::Arc::new(salsa::plumbing::QueryStorageOps::new(group_index)),
)*
}
}
}
// ANCHOR_END:group_storage_new
// ANCHOR:group_storage_methods
impl #group_storage {
#trait_vis fn fmt_index(
&self,
db: &(#dyn_db + '_),
input: salsa::DatabaseKeyIndex,
fmt: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
match input.query_index() {
#fmt_ops
i => panic!("salsa: impossible query index {}", i),
}
}
#trait_vis fn maybe_changed_after(
&self,
db: &(#dyn_db + '_),
input: salsa::DatabaseKeyIndex,
revision: salsa::Revision,
) -> bool {
match input.query_index() {
#maybe_changed_ops
i => panic!("salsa: impossible query index {}", i),
}
}
#trait_vis fn cycle_recovery_strategy(
&self,
db: &(#dyn_db + '_),
input: salsa::DatabaseKeyIndex,
) -> salsa::plumbing::CycleRecoveryStrategy {
match input.query_index() {
#cycle_recovery_strategy_ops
i => panic!("salsa: impossible query index {}", i),
}
}
#trait_vis fn for_each_query(
&self,
_runtime: &salsa::Runtime,
mut op: &mut dyn FnMut(&dyn salsa::plumbing::QueryStorageMassOps),
) {
#for_each_ops
}
}
// ANCHOR_END:group_storage_methods
});
output.into()
}
struct SalsaAttr {
name: String,
tts: TokenStream,
span: Span,
}
impl std::fmt::Debug for SalsaAttr {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(fmt, "{:?}", self.name)
}
}
impl TryFrom<syn::Attribute> for SalsaAttr {
type Error = syn::Attribute;
fn try_from(attr: syn::Attribute) -> Result<SalsaAttr, syn::Attribute> {
if is_not_salsa_attr_path(attr.path()) {
return Err(attr);
}
let span = attr.span();
let name = attr.path().segments[1].ident.to_string();
let tts = match attr.meta {
syn::Meta::Path(path) => path.into_token_stream(),
syn::Meta::List(ref list) => {
let tts = list
.into_token_stream()
.into_iter()
.skip(attr.path().to_token_stream().into_iter().count());
proc_macro2::TokenStream::from_iter(tts)
}
syn::Meta::NameValue(nv) => nv.into_token_stream(),
}
.into();
Ok(SalsaAttr { name, tts, span })
}
}
fn is_not_salsa_attr_path(path: &syn::Path) -> bool {
path.segments.first().map(|s| s.ident != "salsa").unwrap_or(true) || path.segments.len() != 2
}
fn filter_attrs(attrs: Vec<Attribute>) -> (Vec<Attribute>, Vec<SalsaAttr>) {
let mut other = vec![];
let mut salsa = vec![];
// Leave non-salsa attributes untouched. These are
// attributes that don't start with `salsa::` or don't have
// exactly two segments in their path.
// Keep the salsa attributes around.
for attr in attrs {
match SalsaAttr::try_from(attr) {
Ok(it) => salsa.push(it),
Err(it) => other.push(it),
}
}
(other, salsa)
}
#[derive(Debug)]
struct Query {
fn_name: Ident,
receiver: syn::Receiver,
query_name: String,
attrs: Vec<syn::Attribute>,
query_type: Ident,
storage: QueryStorage,
keys: Vec<(Ident, syn::Type)>,
value: syn::Type,
invoke: Option<syn::Path>,
cycle: Option<syn::Path>,
}
impl Query {
fn invoke_tt(&self) -> proc_macro2::TokenStream {
match &self.invoke {
Some(i) => i.into_token_stream(),
None => self.fn_name.clone().into_token_stream(),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
enum QueryStorage {
Memoized,
Dependencies,
Input,
Interned,
InternedLookup { intern_query_type: Ident },
Transparent,
}
impl QueryStorage {
/// Do we need a `QueryFunction` impl for this type of query?
fn needs_query_function(&self) -> bool {
match self {
QueryStorage::Input
| QueryStorage::Interned
| QueryStorage::InternedLookup { .. }
| QueryStorage::Transparent => false,
QueryStorage::Memoized | QueryStorage::Dependencies => true,
}
}
}

66
crates/salsa/src/debug.rs Normal file
View file

@ -0,0 +1,66 @@
//! Debugging APIs: these are meant for use when unit-testing or
//! debugging your application but aren't ordinarily needed.
use crate::durability::Durability;
use crate::plumbing::QueryStorageOps;
use crate::Query;
use crate::QueryTable;
use std::iter::FromIterator;
/// Additional methods on queries that can be used to "peek into"
/// their current state. These methods are meant for debugging and
/// observing the effects of garbage collection etc.
pub trait DebugQueryTable {
/// Key of this query.
type Key;
/// Value of this query.
type Value;
/// Returns a lower bound on the durability for the given key.
/// This is typically the minimum durability of all values that
/// the query accessed, but we may return a lower durability in
/// some cases.
fn durability(&self, key: Self::Key) -> Durability;
/// Get the (current) set of the entries in the query table.
fn entries<C>(&self) -> C
where
C: FromIterator<TableEntry<Self::Key, Self::Value>>;
}
/// An entry from a query table, for debugging and inspecting the table state.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[non_exhaustive]
pub struct TableEntry<K, V> {
/// key of the query
pub key: K,
/// value of the query, if it is stored
pub value: Option<V>,
}
impl<K, V> TableEntry<K, V> {
pub(crate) fn new(key: K, value: Option<V>) -> TableEntry<K, V> {
TableEntry { key, value }
}
}
impl<Q> DebugQueryTable for QueryTable<'_, Q>
where
Q: Query,
Q::Storage: QueryStorageOps<Q>,
{
type Key = Q::Key;
type Value = Q::Value;
fn durability(&self, key: Q::Key) -> Durability {
self.storage.durability(self.db, &key)
}
fn entries<C>(&self) -> C
where
C: FromIterator<TableEntry<Self::Key, Self::Value>>,
{
self.storage.entries(self.db)
}
}

233
crates/salsa/src/derived.rs Normal file
View file

@ -0,0 +1,233 @@
//!
use crate::debug::TableEntry;
use crate::durability::Durability;
use crate::hash::FxIndexMap;
use crate::lru::Lru;
use crate::plumbing::DerivedQueryStorageOps;
use crate::plumbing::LruQueryStorageOps;
use crate::plumbing::QueryFunction;
use crate::plumbing::QueryStorageMassOps;
use crate::plumbing::QueryStorageOps;
use crate::runtime::StampedValue;
use crate::Runtime;
use crate::{Database, DatabaseKeyIndex, QueryDb, Revision};
use parking_lot::RwLock;
use std::borrow::Borrow;
use std::convert::TryFrom;
use std::hash::Hash;
use std::marker::PhantomData;
use triomphe::Arc;
mod slot;
use slot::Slot;
/// Memoized queries store the result plus a list of the other queries
/// that they invoked. This means we can avoid recomputing them when
/// none of those inputs have changed.
pub type MemoizedStorage<Q> = DerivedStorage<Q, AlwaysMemoizeValue>;
/// "Dependency" queries just track their dependencies and not the
/// actual value (which they produce on demand). This lessens the
/// storage requirements.
pub type DependencyStorage<Q> = DerivedStorage<Q, NeverMemoizeValue>;
/// Handles storage where the value is 'derived' by executing a
/// function (in contrast to "inputs").
pub struct DerivedStorage<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
group_index: u16,
lru_list: Lru<Slot<Q, MP>>,
slot_map: RwLock<FxIndexMap<Q::Key, Arc<Slot<Q, MP>>>>,
policy: PhantomData<MP>,
}
impl<Q, MP> std::panic::RefUnwindSafe for DerivedStorage<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
Q::Key: std::panic::RefUnwindSafe,
Q::Value: std::panic::RefUnwindSafe,
{
}
pub trait MemoizationPolicy<Q>: Send + Sync
where
Q: QueryFunction,
{
fn should_memoize_value(key: &Q::Key) -> bool;
fn memoized_value_eq(old_value: &Q::Value, new_value: &Q::Value) -> bool;
}
pub enum AlwaysMemoizeValue {}
impl<Q> MemoizationPolicy<Q> for AlwaysMemoizeValue
where
Q: QueryFunction,
Q::Value: Eq,
{
fn should_memoize_value(_key: &Q::Key) -> bool {
true
}
fn memoized_value_eq(old_value: &Q::Value, new_value: &Q::Value) -> bool {
old_value == new_value
}
}
pub enum NeverMemoizeValue {}
impl<Q> MemoizationPolicy<Q> for NeverMemoizeValue
where
Q: QueryFunction,
{
fn should_memoize_value(_key: &Q::Key) -> bool {
false
}
fn memoized_value_eq(_old_value: &Q::Value, _new_value: &Q::Value) -> bool {
panic!("cannot reach since we never memoize")
}
}
impl<Q, MP> DerivedStorage<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
fn slot(&self, key: &Q::Key) -> Arc<Slot<Q, MP>> {
if let Some(v) = self.slot_map.read().get(key) {
return v.clone();
}
let mut write = self.slot_map.write();
let entry = write.entry(key.clone());
let key_index = u32::try_from(entry.index()).unwrap();
let database_key_index = DatabaseKeyIndex {
group_index: self.group_index,
query_index: Q::QUERY_INDEX,
key_index,
};
entry.or_insert_with(|| Arc::new(Slot::new(key.clone(), database_key_index))).clone()
}
}
impl<Q, MP> QueryStorageOps<Q> for DerivedStorage<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = Q::CYCLE_STRATEGY;
fn new(group_index: u16) -> Self {
DerivedStorage {
group_index,
slot_map: RwLock::new(FxIndexMap::default()),
lru_list: Default::default(),
policy: PhantomData,
}
}
fn fmt_index(
&self,
_db: &<Q as QueryDb<'_>>::DynDb,
index: DatabaseKeyIndex,
fmt: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
assert_eq!(index.group_index, self.group_index);
assert_eq!(index.query_index, Q::QUERY_INDEX);
let slot_map = self.slot_map.read();
let key = slot_map.get_index(index.key_index as usize).unwrap().0;
write!(fmt, "{}({:?})", Q::QUERY_NAME, key)
}
fn maybe_changed_after(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
input: DatabaseKeyIndex,
revision: Revision,
) -> bool {
assert_eq!(input.group_index, self.group_index);
assert_eq!(input.query_index, Q::QUERY_INDEX);
debug_assert!(revision < db.salsa_runtime().current_revision());
let slot = self.slot_map.read().get_index(input.key_index as usize).unwrap().1.clone();
slot.maybe_changed_after(db, revision)
}
fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value {
db.unwind_if_cancelled();
let slot = self.slot(key);
let StampedValue { value, durability, changed_at } = slot.read(db);
if let Some(evicted) = self.lru_list.record_use(&slot) {
evicted.evict();
}
db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted(
slot.database_key_index(),
durability,
changed_at,
);
value
}
fn durability(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Durability {
self.slot(key).durability(db)
}
fn entries<C>(&self, _db: &<Q as QueryDb<'_>>::DynDb) -> C
where
C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>,
{
let slot_map = self.slot_map.read();
slot_map.values().filter_map(|slot| slot.as_table_entry()).collect()
}
}
impl<Q, MP> QueryStorageMassOps for DerivedStorage<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
fn purge(&self) {
self.lru_list.purge();
*self.slot_map.write() = Default::default();
}
}
impl<Q, MP> LruQueryStorageOps for DerivedStorage<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
fn set_lru_capacity(&self, new_capacity: usize) {
self.lru_list.set_lru_capacity(new_capacity);
}
}
impl<Q, MP> DerivedQueryStorageOps<Q> for DerivedStorage<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
fn invalidate<S>(&self, runtime: &mut Runtime, key: &S)
where
S: Eq + Hash,
Q::Key: Borrow<S>,
{
runtime.with_incremented_revision(|new_revision| {
let map_read = self.slot_map.read();
if let Some(slot) = map_read.get(key) {
if let Some(durability) = slot.invalidate(new_revision) {
return Some(durability);
}
}
None
})
}
}

View file

@ -0,0 +1,833 @@
//!
use crate::debug::TableEntry;
use crate::derived::MemoizationPolicy;
use crate::durability::Durability;
use crate::lru::LruIndex;
use crate::lru::LruNode;
use crate::plumbing::{DatabaseOps, QueryFunction};
use crate::revision::Revision;
use crate::runtime::local_state::ActiveQueryGuard;
use crate::runtime::local_state::QueryInputs;
use crate::runtime::local_state::QueryRevisions;
use crate::runtime::Runtime;
use crate::runtime::RuntimeId;
use crate::runtime::StampedValue;
use crate::runtime::WaitResult;
use crate::Cycle;
use crate::{Database, DatabaseKeyIndex, Event, EventKind, QueryDb};
use parking_lot::{RawRwLock, RwLock};
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::atomic::{AtomicBool, Ordering};
use tracing::{debug, info};
pub(super) struct Slot<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
key: Q::Key,
database_key_index: DatabaseKeyIndex,
state: RwLock<QueryState<Q>>,
policy: PhantomData<MP>,
lru_index: LruIndex,
}
/// Defines the "current state" of query's memoized results.
enum QueryState<Q>
where
Q: QueryFunction,
{
NotComputed,
/// The runtime with the given id is currently computing the
/// result of this query.
InProgress {
id: RuntimeId,
/// Set to true if any other queries are blocked,
/// waiting for this query to complete.
anyone_waiting: AtomicBool,
},
/// We have computed the query already, and here is the result.
Memoized(Memo<Q::Value>),
}
struct Memo<V> {
/// The result of the query, if we decide to memoize it.
value: Option<V>,
/// Last revision when this memo was verified; this begins
/// as the current revision.
pub(crate) verified_at: Revision,
/// Revision information
revisions: QueryRevisions,
}
/// Return value of `probe` helper.
enum ProbeState<V, G> {
/// Another thread was active but has completed.
/// Try again!
Retry,
/// No entry for this key at all.
NotComputed(G),
/// There is an entry, but its contents have not been
/// verified in this revision.
Stale(G),
/// There is an entry, and it has been verified
/// in this revision, but it has no cached
/// value. The `Revision` is the revision where the
/// value last changed (if we were to recompute it).
NoValue(G, Revision),
/// There is an entry which has been verified,
/// and it has the following value-- or, we blocked
/// on another thread, and that resulted in a cycle.
UpToDate(V),
}
/// Return value of `maybe_changed_after_probe` helper.
enum MaybeChangedSinceProbeState<G> {
/// Another thread was active but has completed.
/// Try again!
Retry,
/// Value may have changed in the given revision.
ChangedAt(Revision),
/// There is a stale cache entry that has not been
/// verified in this revision, so we can't say.
Stale(G),
}
impl<Q, MP> Slot<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
pub(super) fn new(key: Q::Key, database_key_index: DatabaseKeyIndex) -> Self {
Self {
key,
database_key_index,
state: RwLock::new(QueryState::NotComputed),
lru_index: LruIndex::default(),
policy: PhantomData,
}
}
pub(super) fn database_key_index(&self) -> DatabaseKeyIndex {
self.database_key_index
}
pub(super) fn read(&self, db: &<Q as QueryDb<'_>>::DynDb) -> StampedValue<Q::Value> {
let runtime = db.salsa_runtime();
// NB: We don't need to worry about people modifying the
// revision out from under our feet. Either `db` is a frozen
// database, in which case there is a lock, or the mutator
// thread is the current thread, and it will be prevented from
// doing any `set` invocations while the query function runs.
let revision_now = runtime.current_revision();
info!("{:?}: invoked at {:?}", self, revision_now,);
// First, do a check with a read-lock.
loop {
match self.probe(db, self.state.read(), runtime, revision_now) {
ProbeState::UpToDate(v) => return v,
ProbeState::Stale(..) | ProbeState::NoValue(..) | ProbeState::NotComputed(..) => {
break
}
ProbeState::Retry => continue,
}
}
self.read_upgrade(db, revision_now)
}
/// Second phase of a read operation: acquires an upgradable-read
/// and -- if needed -- validates whether inputs have changed,
/// recomputes value, etc. This is invoked after our initial probe
/// shows a potentially out of date value.
fn read_upgrade(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
revision_now: Revision,
) -> StampedValue<Q::Value> {
let runtime = db.salsa_runtime();
debug!("{:?}: read_upgrade(revision_now={:?})", self, revision_now,);
// Check with an upgradable read to see if there is a value
// already. (This permits other readers but prevents anyone
// else from running `read_upgrade` at the same time.)
let mut old_memo = loop {
match self.probe(db, self.state.upgradable_read(), runtime, revision_now) {
ProbeState::UpToDate(v) => return v,
ProbeState::Stale(state)
| ProbeState::NotComputed(state)
| ProbeState::NoValue(state, _) => {
type RwLockUpgradableReadGuard<'a, T> =
lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>;
let mut state = RwLockUpgradableReadGuard::upgrade(state);
match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) {
QueryState::Memoized(old_memo) => break Some(old_memo),
QueryState::InProgress { .. } => unreachable!(),
QueryState::NotComputed => break None,
}
}
ProbeState::Retry => continue,
}
};
let panic_guard = PanicGuard::new(self.database_key_index, self, runtime);
let active_query = runtime.push_query(self.database_key_index);
// If we have an old-value, it *may* now be stale, since there
// has been a new revision since the last time we checked. So,
// first things first, let's walk over each of our previous
// inputs and check whether they are out of date.
if let Some(memo) = &mut old_memo {
if let Some(value) = memo.verify_value(db.ops_database(), revision_now, &active_query) {
info!("{:?}: validated old memoized value", self,);
db.salsa_event(Event {
runtime_id: runtime.id(),
kind: EventKind::DidValidateMemoizedValue {
database_key: self.database_key_index,
},
});
panic_guard.proceed(old_memo);
return value;
}
}
self.execute(db, runtime, revision_now, active_query, panic_guard, old_memo)
}
fn execute(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
runtime: &Runtime,
revision_now: Revision,
active_query: ActiveQueryGuard<'_>,
panic_guard: PanicGuard<'_, Q, MP>,
old_memo: Option<Memo<Q::Value>>,
) -> StampedValue<Q::Value> {
tracing::info!("{:?}: executing query", self.database_key_index.debug(db));
db.salsa_event(Event {
runtime_id: db.salsa_runtime().id(),
kind: EventKind::WillExecute { database_key: self.database_key_index },
});
// Query was not previously executed, or value is potentially
// stale, or value is absent. Let's execute!
let value = match Cycle::catch(|| Q::execute(db, self.key.clone())) {
Ok(v) => v,
Err(cycle) => {
tracing::debug!(
"{:?}: caught cycle {:?}, have strategy {:?}",
self.database_key_index.debug(db),
cycle,
Q::CYCLE_STRATEGY,
);
match Q::CYCLE_STRATEGY {
crate::plumbing::CycleRecoveryStrategy::Panic => {
panic_guard.proceed(None);
cycle.throw()
}
crate::plumbing::CycleRecoveryStrategy::Fallback => {
if let Some(c) = active_query.take_cycle() {
assert!(c.is(&cycle));
Q::cycle_fallback(db, &cycle, &self.key)
} else {
// we are not a participant in this cycle
debug_assert!(!cycle
.participant_keys()
.any(|k| k == self.database_key_index));
cycle.throw()
}
}
}
}
};
let mut revisions = active_query.pop();
// We assume that query is side-effect free -- that is, does
// not mutate the "inputs" to the query system. Sanity check
// that assumption here, at least to the best of our ability.
assert_eq!(
runtime.current_revision(),
revision_now,
"revision altered during query execution",
);
// If the new value is equal to the old one, then it didn't
// really change, even if some of its inputs have. So we can
// "backdate" its `changed_at` revision to be the same as the
// old value.
if let Some(old_memo) = &old_memo {
if let Some(old_value) = &old_memo.value {
// Careful: if the value became less durable than it
// used to be, that is a "breaking change" that our
// consumers must be aware of. Becoming *more* durable
// is not. See the test `constant_to_non_constant`.
if revisions.durability >= old_memo.revisions.durability
&& MP::memoized_value_eq(old_value, &value)
{
debug!(
"read_upgrade({:?}): value is equal, back-dating to {:?}",
self, old_memo.revisions.changed_at,
);
assert!(old_memo.revisions.changed_at <= revisions.changed_at);
revisions.changed_at = old_memo.revisions.changed_at;
}
}
}
let new_value = StampedValue {
value,
durability: revisions.durability,
changed_at: revisions.changed_at,
};
let memo_value =
if self.should_memoize_value(&self.key) { Some(new_value.value.clone()) } else { None };
debug!("read_upgrade({:?}): result.revisions = {:#?}", self, revisions,);
panic_guard.proceed(Some(Memo { value: memo_value, verified_at: revision_now, revisions }));
new_value
}
/// Helper for `read` that does a shallow check (not recursive) if we have an up-to-date value.
///
/// Invoked with the guard `state` corresponding to the `QueryState` of some `Slot` (the guard
/// can be either read or write). Returns a suitable `ProbeState`:
///
/// - `ProbeState::UpToDate(r)` if the table has an up-to-date value (or we blocked on another
/// thread that produced such a value).
/// - `ProbeState::StaleOrAbsent(g)` if either (a) there is no memo for this key, (b) the memo
/// has no value; or (c) the memo has not been verified at the current revision.
///
/// Note that in case `ProbeState::UpToDate`, the lock will have been released.
fn probe<StateGuard>(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
state: StateGuard,
runtime: &Runtime,
revision_now: Revision,
) -> ProbeState<StampedValue<Q::Value>, StateGuard>
where
StateGuard: Deref<Target = QueryState<Q>>,
{
match &*state {
QueryState::NotComputed => ProbeState::NotComputed(state),
QueryState::InProgress { id, anyone_waiting } => {
let other_id = *id;
// NB: `Ordering::Relaxed` is sufficient here,
// as there are no loads that are "gated" on this
// value. Everything that is written is also protected
// by a lock that must be acquired. The role of this
// boolean is to decide *whether* to acquire the lock,
// not to gate future atomic reads.
anyone_waiting.store(true, Ordering::Relaxed);
self.block_on_or_unwind(db, runtime, other_id, state);
// Other thread completely normally, so our value may be available now.
ProbeState::Retry
}
QueryState::Memoized(memo) => {
debug!(
"{:?}: found memoized value, verified_at={:?}, changed_at={:?}",
self, memo.verified_at, memo.revisions.changed_at,
);
if memo.verified_at < revision_now {
return ProbeState::Stale(state);
}
if let Some(value) = &memo.value {
let value = StampedValue {
durability: memo.revisions.durability,
changed_at: memo.revisions.changed_at,
value: value.clone(),
};
info!("{:?}: returning memoized value changed at {:?}", self, value.changed_at);
ProbeState::UpToDate(value)
} else {
let changed_at = memo.revisions.changed_at;
ProbeState::NoValue(state, changed_at)
}
}
}
}
pub(super) fn durability(&self, db: &<Q as QueryDb<'_>>::DynDb) -> Durability {
match &*self.state.read() {
QueryState::NotComputed => Durability::LOW,
QueryState::InProgress { .. } => panic!("query in progress"),
QueryState::Memoized(memo) => {
if memo.check_durability(db.salsa_runtime()) {
memo.revisions.durability
} else {
Durability::LOW
}
}
}
}
pub(super) fn as_table_entry(&self) -> Option<TableEntry<Q::Key, Q::Value>> {
match &*self.state.read() {
QueryState::NotComputed => None,
QueryState::InProgress { .. } => Some(TableEntry::new(self.key.clone(), None)),
QueryState::Memoized(memo) => {
Some(TableEntry::new(self.key.clone(), memo.value.clone()))
}
}
}
pub(super) fn evict(&self) {
let mut state = self.state.write();
if let QueryState::Memoized(memo) = &mut *state {
// Evicting a value with an untracked input could
// lead to inconsistencies. Note that we can't check
// `has_untracked_input` when we add the value to the cache,
// because inputs can become untracked in the next revision.
if memo.has_untracked_input() {
return;
}
memo.value = None;
}
}
pub(super) fn invalidate(&self, new_revision: Revision) -> Option<Durability> {
tracing::debug!("Slot::invalidate(new_revision = {:?})", new_revision);
match &mut *self.state.write() {
QueryState::Memoized(memo) => {
memo.revisions.inputs = QueryInputs::Untracked;
memo.revisions.changed_at = new_revision;
Some(memo.revisions.durability)
}
QueryState::NotComputed => None,
QueryState::InProgress { .. } => unreachable!(),
}
}
pub(super) fn maybe_changed_after(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
revision: Revision,
) -> bool {
let runtime = db.salsa_runtime();
let revision_now = runtime.current_revision();
db.unwind_if_cancelled();
debug!(
"maybe_changed_after({:?}) called with revision={:?}, revision_now={:?}",
self, revision, revision_now,
);
// Do an initial probe with just the read-lock.
//
// If we find that a cache entry for the value is present
// but hasn't been verified in this revision, we'll have to
// do more.
loop {
match self.maybe_changed_after_probe(db, self.state.read(), runtime, revision_now) {
MaybeChangedSinceProbeState::Retry => continue,
MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision,
MaybeChangedSinceProbeState::Stale(state) => {
drop(state);
return self.maybe_changed_after_upgrade(db, revision);
}
}
}
}
fn maybe_changed_after_probe<StateGuard>(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
state: StateGuard,
runtime: &Runtime,
revision_now: Revision,
) -> MaybeChangedSinceProbeState<StateGuard>
where
StateGuard: Deref<Target = QueryState<Q>>,
{
match self.probe(db, state, runtime, revision_now) {
ProbeState::Retry => MaybeChangedSinceProbeState::Retry,
ProbeState::Stale(state) => MaybeChangedSinceProbeState::Stale(state),
// If we know when value last changed, we can return right away.
// Note that we don't need the actual value to be available.
ProbeState::NoValue(_, changed_at)
| ProbeState::UpToDate(StampedValue { value: _, durability: _, changed_at }) => {
MaybeChangedSinceProbeState::ChangedAt(changed_at)
}
// If we have nothing cached, then value may have changed.
ProbeState::NotComputed(_) => MaybeChangedSinceProbeState::ChangedAt(revision_now),
}
}
fn maybe_changed_after_upgrade(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
revision: Revision,
) -> bool {
let runtime = db.salsa_runtime();
let revision_now = runtime.current_revision();
// Get an upgradable read lock, which permits other reads but no writers.
// Probe again. If the value is stale (needs to be verified), then upgrade
// to a write lock and swap it with InProgress while we work.
let mut old_memo = match self.maybe_changed_after_probe(
db,
self.state.upgradable_read(),
runtime,
revision_now,
) {
MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision,
// If another thread was active, then the cache line is going to be
// either verified or cleared out. Just recurse to figure out which.
// Note that we don't need an upgradable read.
MaybeChangedSinceProbeState::Retry => return self.maybe_changed_after(db, revision),
MaybeChangedSinceProbeState::Stale(state) => {
type RwLockUpgradableReadGuard<'a, T> =
lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>;
let mut state = RwLockUpgradableReadGuard::upgrade(state);
match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) {
QueryState::Memoized(old_memo) => old_memo,
QueryState::NotComputed | QueryState::InProgress { .. } => unreachable!(),
}
}
};
let panic_guard = PanicGuard::new(self.database_key_index, self, runtime);
let active_query = runtime.push_query(self.database_key_index);
if old_memo.verify_revisions(db.ops_database(), revision_now, &active_query) {
let maybe_changed = old_memo.revisions.changed_at > revision;
panic_guard.proceed(Some(old_memo));
maybe_changed
} else if old_memo.value.is_some() {
// We found that this memoized value may have changed
// but we have an old value. We can re-run the code and
// actually *check* if it has changed.
let StampedValue { changed_at, .. } =
self.execute(db, runtime, revision_now, active_query, panic_guard, Some(old_memo));
changed_at > revision
} else {
// We found that inputs to this memoized value may have chanced
// but we don't have an old value to compare against or re-use.
// No choice but to drop the memo and say that its value may have changed.
panic_guard.proceed(None);
true
}
}
/// Helper: see [`Runtime::try_block_on_or_unwind`].
fn block_on_or_unwind<MutexGuard>(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
runtime: &Runtime,
other_id: RuntimeId,
mutex_guard: MutexGuard,
) {
runtime.block_on_or_unwind(
db.ops_database(),
self.database_key_index,
other_id,
mutex_guard,
)
}
fn should_memoize_value(&self, key: &Q::Key) -> bool {
MP::should_memoize_value(key)
}
}
impl<Q> QueryState<Q>
where
Q: QueryFunction,
{
fn in_progress(id: RuntimeId) -> Self {
QueryState::InProgress { id, anyone_waiting: Default::default() }
}
}
struct PanicGuard<'me, Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
database_key_index: DatabaseKeyIndex,
slot: &'me Slot<Q, MP>,
runtime: &'me Runtime,
}
impl<'me, Q, MP> PanicGuard<'me, Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
fn new(
database_key_index: DatabaseKeyIndex,
slot: &'me Slot<Q, MP>,
runtime: &'me Runtime,
) -> Self {
Self { database_key_index, slot, runtime }
}
/// Indicates that we have concluded normally (without panicking).
/// If `opt_memo` is some, then this memo is installed as the new
/// memoized value. If `opt_memo` is `None`, then the slot is cleared
/// and has no value.
fn proceed(mut self, opt_memo: Option<Memo<Q::Value>>) {
self.overwrite_placeholder(WaitResult::Completed, opt_memo);
std::mem::forget(self)
}
/// Overwrites the `InProgress` placeholder for `key` that we
/// inserted; if others were blocked, waiting for us to finish,
/// then notify them.
fn overwrite_placeholder(&mut self, wait_result: WaitResult, opt_memo: Option<Memo<Q::Value>>) {
let mut write = self.slot.state.write();
let old_value = match opt_memo {
// Replace the `InProgress` marker that we installed with the new
// memo, thus releasing our unique access to this key.
Some(memo) => std::mem::replace(&mut *write, QueryState::Memoized(memo)),
// We had installed an `InProgress` marker, but we panicked before
// it could be removed. At this point, we therefore "own" unique
// access to our slot, so we can just remove the key.
None => std::mem::replace(&mut *write, QueryState::NotComputed),
};
match old_value {
QueryState::InProgress { id, anyone_waiting } => {
assert_eq!(id, self.runtime.id());
// NB: As noted on the `store`, `Ordering::Relaxed` is
// sufficient here. This boolean signals us on whether to
// acquire a mutex; the mutex will guarantee that all writes
// we are interested in are visible.
if anyone_waiting.load(Ordering::Relaxed) {
self.runtime.unblock_queries_blocked_on(self.database_key_index, wait_result);
}
}
_ => panic!(
"\
Unexpected panic during query evaluation, aborting the process.
Please report this bug to https://github.com/salsa-rs/salsa/issues."
),
}
}
}
impl<'me, Q, MP> Drop for PanicGuard<'me, Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
fn drop(&mut self) {
if std::thread::panicking() {
// We panicked before we could proceed and need to remove `key`.
self.overwrite_placeholder(WaitResult::Panicked, None)
} else {
// If no panic occurred, then panic guard ought to be
// "forgotten" and so this Drop code should never run.
panic!(".forget() was not called")
}
}
}
impl<V> Memo<V>
where
V: Clone,
{
/// Determines whether the value stored in this memo (if any) is still
/// valid in the current revision. If so, returns a stamped value.
///
/// If needed, this will walk each dependency and
/// recursively invoke `maybe_changed_after`, which may in turn
/// re-execute the dependency. This can cause cycles to occur,
/// so the current query must be pushed onto the
/// stack to permit cycle detection and recovery: therefore,
/// takes the `active_query` argument as evidence.
fn verify_value(
&mut self,
db: &dyn Database,
revision_now: Revision,
active_query: &ActiveQueryGuard<'_>,
) -> Option<StampedValue<V>> {
// If we don't have a memoized value, nothing to validate.
if self.value.is_none() {
return None;
}
if self.verify_revisions(db, revision_now, active_query) {
Some(StampedValue {
durability: self.revisions.durability,
changed_at: self.revisions.changed_at,
value: self.value.as_ref().unwrap().clone(),
})
} else {
None
}
}
/// Determines whether the value represented by this memo is still
/// valid in the current revision; note that the value itself is
/// not needed for this check. If needed, this will walk each
/// dependency and recursively invoke `maybe_changed_after`, which
/// may in turn re-execute the dependency. This can cause cycles to occur,
/// so the current query must be pushed onto the
/// stack to permit cycle detection and recovery: therefore,
/// takes the `active_query` argument as evidence.
fn verify_revisions(
&mut self,
db: &dyn Database,
revision_now: Revision,
_active_query: &ActiveQueryGuard<'_>,
) -> bool {
assert!(self.verified_at != revision_now);
let verified_at = self.verified_at;
debug!(
"verify_revisions: verified_at={:?}, revision_now={:?}, inputs={:#?}",
verified_at, revision_now, self.revisions.inputs
);
if self.check_durability(db.salsa_runtime()) {
return self.mark_value_as_verified(revision_now);
}
match &self.revisions.inputs {
// We can't validate values that had untracked inputs; just have to
// re-execute.
QueryInputs::Untracked => {
return false;
}
QueryInputs::NoInputs => {}
// Check whether any of our inputs changed since the
// **last point where we were verified** (not since we
// last changed). This is important: if we have
// memoized values, then an input may have changed in
// revision R2, but we found that *our* value was the
// same regardless, so our change date is still
// R1. But our *verification* date will be R2, and we
// are only interested in finding out whether the
// input changed *again*.
QueryInputs::Tracked { inputs } => {
let changed_input =
inputs.iter().find(|&&input| db.maybe_changed_after(input, verified_at));
if let Some(input) = changed_input {
debug!("validate_memoized_value: `{:?}` may have changed", input);
return false;
}
}
};
self.mark_value_as_verified(revision_now)
}
/// True if this memo is known not to have changed based on its durability.
fn check_durability(&self, runtime: &Runtime) -> bool {
let last_changed = runtime.last_changed_revision(self.revisions.durability);
debug!(
"check_durability(last_changed={:?} <= verified_at={:?}) = {:?}",
last_changed,
self.verified_at,
last_changed <= self.verified_at,
);
last_changed <= self.verified_at
}
fn mark_value_as_verified(&mut self, revision_now: Revision) -> bool {
self.verified_at = revision_now;
true
}
fn has_untracked_input(&self) -> bool {
matches!(self.revisions.inputs, QueryInputs::Untracked)
}
}
impl<Q, MP> std::fmt::Debug for Slot<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(fmt, "{:?}({:?})", Q::default(), self.key)
}
}
impl<Q, MP> LruNode for Slot<Q, MP>
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
{
fn lru_index(&self) -> &LruIndex {
&self.lru_index
}
}
/// Check that `Slot<Q, MP>: Send + Sync` as long as
/// `DB::DatabaseData: Send + Sync`, which in turn implies that
/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`.
#[allow(dead_code)]
fn check_send_sync<Q, MP>()
where
Q: QueryFunction,
MP: MemoizationPolicy<Q>,
Q::Key: Send + Sync,
Q::Value: Send + Sync,
{
fn is_send_sync<T: Send + Sync>() {}
is_send_sync::<Slot<Q, MP>>();
}
/// Check that `Slot<Q, MP>: 'static` as long as
/// `DB::DatabaseData: 'static`, which in turn implies that
/// `Q::Key: 'static`, `Q::Value: 'static`.
#[allow(dead_code)]
fn check_static<Q, MP>()
where
Q: QueryFunction + 'static,
MP: MemoizationPolicy<Q> + 'static,
Q::Key: 'static,
Q::Value: 'static,
{
fn is_static<T: 'static>() {}
is_static::<Slot<Q, MP>>();
}

115
crates/salsa/src/doctest.rs Normal file
View file

@ -0,0 +1,115 @@
//!
#![allow(dead_code)]
/// Test that a database with a key/value that is not `Send` will,
/// indeed, not be `Send`.
///
/// ```compile_fail,E0277
/// use std::rc::Rc;
///
/// #[salsa::query_group(NoSendSyncStorage)]
/// trait NoSendSyncDatabase: salsa::Database {
/// fn no_send_sync_value(&self, key: bool) -> Rc<bool>;
/// fn no_send_sync_key(&self, key: Rc<bool>) -> bool;
/// }
///
/// fn no_send_sync_value(_db: &dyn NoSendSyncDatabase, key: bool) -> Rc<bool> {
/// Rc::new(key)
/// }
///
/// fn no_send_sync_key(_db: &dyn NoSendSyncDatabase, key: Rc<bool>) -> bool {
/// *key
/// }
///
/// #[salsa::database(NoSendSyncStorage)]
/// #[derive(Default)]
/// struct DatabaseImpl {
/// storage: salsa::Storage<Self>,
/// }
///
/// impl salsa::Database for DatabaseImpl {
/// }
///
/// fn is_send<T: Send>(_: T) { }
///
/// fn assert_send() {
/// is_send(DatabaseImpl::default());
/// }
/// ```
fn test_key_not_send_db_not_send() {}
/// Test that a database with a key/value that is not `Sync` will not
/// be `Send`.
///
/// ```compile_fail,E0277
/// use std::rc::Rc;
/// use std::cell::Cell;
///
/// #[salsa::query_group(NoSendSyncStorage)]
/// trait NoSendSyncDatabase: salsa::Database {
/// fn no_send_sync_value(&self, key: bool) -> Cell<bool>;
/// fn no_send_sync_key(&self, key: Cell<bool>) -> bool;
/// }
///
/// fn no_send_sync_value(_db: &dyn NoSendSyncDatabase, key: bool) -> Cell<bool> {
/// Cell::new(key)
/// }
///
/// fn no_send_sync_key(_db: &dyn NoSendSyncDatabase, key: Cell<bool>) -> bool {
/// *key
/// }
///
/// #[salsa::database(NoSendSyncStorage)]
/// #[derive(Default)]
/// struct DatabaseImpl {
/// runtime: salsa::Storage<Self>,
/// }
///
/// impl salsa::Database for DatabaseImpl {
/// }
///
/// fn is_send<T: Send>(_: T) { }
///
/// fn assert_send() {
/// is_send(DatabaseImpl::default());
/// }
/// ```
fn test_key_not_sync_db_not_send() {}
/// Test that a database with a key/value that is not `Sync` will
/// not be `Sync`.
///
/// ```compile_fail,E0277
/// use std::cell::Cell;
/// use std::rc::Rc;
///
/// #[salsa::query_group(NoSendSyncStorage)]
/// trait NoSendSyncDatabase: salsa::Database {
/// fn no_send_sync_value(&self, key: bool) -> Cell<bool>;
/// fn no_send_sync_key(&self, key: Cell<bool>) -> bool;
/// }
///
/// fn no_send_sync_value(_db: &dyn NoSendSyncDatabase, key: bool) -> Cell<bool> {
/// Cell::new(key)
/// }
///
/// fn no_send_sync_key(_db: &dyn NoSendSyncDatabase, key: Cell<bool>) -> bool {
/// *key
/// }
///
/// #[salsa::database(NoSendSyncStorage)]
/// #[derive(Default)]
/// struct DatabaseImpl {
/// runtime: salsa::Storage<Self>,
/// }
///
/// impl salsa::Database for DatabaseImpl {
/// }
///
/// fn is_sync<T: Sync>(_: T) { }
///
/// fn assert_send() {
/// is_sync(DatabaseImpl::default());
/// }
/// ```
fn test_key_not_sync_db_not_sync() {}

View file

@ -0,0 +1,50 @@
//!
/// Describes how likely a value is to change -- how "durable" it is.
/// By default, inputs have `Durability::LOW` and interned values have
/// `Durability::HIGH`. But inputs can be explicitly set with other
/// durabilities.
///
/// We use durabilities to optimize the work of "revalidating" a query
/// after some input has changed. Ordinarily, in a new revision,
/// queries have to trace all their inputs back to the base inputs to
/// determine if any of those inputs have changed. But if we know that
/// the only changes were to inputs of low durability (the common
/// case), and we know that the query only used inputs of medium
/// durability or higher, then we can skip that enumeration.
///
/// Typically, one assigns low durabilites to inputs that the user is
/// frequently editing. Medium or high durabilities are used for
/// configuration, the source from library crates, or other things
/// that are unlikely to be edited.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Durability(u8);
impl Durability {
/// Low durability: things that change frequently.
///
/// Example: part of the crate being edited
pub const LOW: Durability = Durability(0);
/// Medium durability: things that change sometimes, but rarely.
///
/// Example: a Cargo.toml file
pub const MEDIUM: Durability = Durability(1);
/// High durability: things that are not expected to change under
/// common usage.
///
/// Example: the standard library or something from crates.io
pub const HIGH: Durability = Durability(2);
/// The maximum possible durability; equivalent to HIGH but
/// "conceptually" distinct (i.e., if we add more durability
/// levels, this could change).
pub(crate) const MAX: Durability = Self::HIGH;
/// Number of durability levels.
pub(crate) const LEN: usize = 3;
pub(crate) fn index(self) -> usize {
self.0 as usize
}
}

4
crates/salsa/src/hash.rs Normal file
View file

@ -0,0 +1,4 @@
//!
pub(crate) type FxHasher = std::hash::BuildHasherDefault<rustc_hash::FxHasher>;
pub(crate) type FxIndexSet<K> = indexmap::IndexSet<K, FxHasher>;
pub(crate) type FxIndexMap<K, V> = indexmap::IndexMap<K, V, FxHasher>;

240
crates/salsa/src/input.rs Normal file
View file

@ -0,0 +1,240 @@
//!
use crate::debug::TableEntry;
use crate::durability::Durability;
use crate::hash::FxIndexMap;
use crate::plumbing::CycleRecoveryStrategy;
use crate::plumbing::InputQueryStorageOps;
use crate::plumbing::QueryStorageMassOps;
use crate::plumbing::QueryStorageOps;
use crate::revision::Revision;
use crate::runtime::StampedValue;
use crate::Database;
use crate::Query;
use crate::Runtime;
use crate::{DatabaseKeyIndex, QueryDb};
use indexmap::map::Entry;
use parking_lot::RwLock;
use std::convert::TryFrom;
use tracing::debug;
/// Input queries store the result plus a list of the other queries
/// that they invoked. This means we can avoid recomputing them when
/// none of those inputs have changed.
pub struct InputStorage<Q>
where
Q: Query,
{
group_index: u16,
slots: RwLock<FxIndexMap<Q::Key, Slot<Q>>>,
}
struct Slot<Q>
where
Q: Query,
{
database_key_index: DatabaseKeyIndex,
stamped_value: RwLock<StampedValue<Q::Value>>,
}
impl<Q> std::panic::RefUnwindSafe for InputStorage<Q>
where
Q: Query,
Q::Key: std::panic::RefUnwindSafe,
Q::Value: std::panic::RefUnwindSafe,
{
}
impl<Q> QueryStorageOps<Q> for InputStorage<Q>
where
Q: Query,
{
const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = CycleRecoveryStrategy::Panic;
fn new(group_index: u16) -> Self {
InputStorage { group_index, slots: Default::default() }
}
fn fmt_index(
&self,
_db: &<Q as QueryDb<'_>>::DynDb,
index: DatabaseKeyIndex,
fmt: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
assert_eq!(index.group_index, self.group_index);
assert_eq!(index.query_index, Q::QUERY_INDEX);
let slot_map = self.slots.read();
let key = slot_map.get_index(index.key_index as usize).unwrap().0;
write!(fmt, "{}({:?})", Q::QUERY_NAME, key)
}
fn maybe_changed_after(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
input: DatabaseKeyIndex,
revision: Revision,
) -> bool {
assert_eq!(input.group_index, self.group_index);
assert_eq!(input.query_index, Q::QUERY_INDEX);
debug_assert!(revision < db.salsa_runtime().current_revision());
let slots = &self.slots.read();
let slot = slots.get_index(input.key_index as usize).unwrap().1;
slot.maybe_changed_after(db, revision)
}
fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value {
db.unwind_if_cancelled();
let slots = &self.slots.read();
let slot = slots
.get(key)
.unwrap_or_else(|| panic!("no value set for {:?}({:?})", Q::default(), key));
let StampedValue { value, durability, changed_at } = slot.stamped_value.read().clone();
db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted(
slot.database_key_index,
durability,
changed_at,
);
value
}
fn durability(&self, _db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Durability {
match self.slots.read().get(key) {
Some(slot) => slot.stamped_value.read().durability,
None => panic!("no value set for {:?}({:?})", Q::default(), key),
}
}
fn entries<C>(&self, _db: &<Q as QueryDb<'_>>::DynDb) -> C
where
C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>,
{
let slots = self.slots.read();
slots
.iter()
.map(|(key, slot)| {
TableEntry::new(key.clone(), Some(slot.stamped_value.read().value.clone()))
})
.collect()
}
}
impl<Q> Slot<Q>
where
Q: Query,
{
fn maybe_changed_after(&self, _db: &<Q as QueryDb<'_>>::DynDb, revision: Revision) -> bool {
debug!("maybe_changed_after(slot={:?}, revision={:?})", self, revision,);
let changed_at = self.stamped_value.read().changed_at;
debug!("maybe_changed_after: changed_at = {:?}", changed_at);
changed_at > revision
}
}
impl<Q> QueryStorageMassOps for InputStorage<Q>
where
Q: Query,
{
fn purge(&self) {
*self.slots.write() = Default::default();
}
}
impl<Q> InputQueryStorageOps<Q> for InputStorage<Q>
where
Q: Query,
{
fn set(&self, runtime: &mut Runtime, key: &Q::Key, value: Q::Value, durability: Durability) {
tracing::debug!("{:?}({:?}) = {:?} ({:?})", Q::default(), key, value, durability);
// The value is changing, so we need a new revision (*). We also
// need to update the 'last changed' revision by invoking
// `guard.mark_durability_as_changed`.
//
// CAREFUL: This will block until the global revision lock can
// be acquired. If there are still queries executing, they may
// need to read from this input. Therefore, we wait to acquire
// the lock on `map` until we also hold the global query write
// lock.
//
// (*) Technically, since you can't presently access an input
// for a non-existent key, and you can't enumerate the set of
// keys, we only need a new revision if the key used to
// exist. But we may add such methods in the future and this
// case doesn't generally seem worth optimizing for.
runtime.with_incremented_revision(|next_revision| {
let mut slots = self.slots.write();
// Do this *after* we acquire the lock, so that we are not
// racing with somebody else to modify this same cell.
// (Otherwise, someone else might write a *newer* revision
// into the same cell while we block on the lock.)
let stamped_value = StampedValue { value, durability, changed_at: next_revision };
match slots.entry(key.clone()) {
Entry::Occupied(entry) => {
let mut slot_stamped_value = entry.get().stamped_value.write();
let old_durability = slot_stamped_value.durability;
*slot_stamped_value = stamped_value;
Some(old_durability)
}
Entry::Vacant(entry) => {
let key_index = u32::try_from(entry.index()).unwrap();
let database_key_index = DatabaseKeyIndex {
group_index: self.group_index,
query_index: Q::QUERY_INDEX,
key_index,
};
entry.insert(Slot {
database_key_index,
stamped_value: RwLock::new(stamped_value),
});
None
}
}
});
}
}
/// Check that `Slot<Q, MP>: Send + Sync` as long as
/// `DB::DatabaseData: Send + Sync`, which in turn implies that
/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`.
#[allow(dead_code)]
fn check_send_sync<Q>()
where
Q: Query,
Q::Key: Send + Sync,
Q::Value: Send + Sync,
{
fn is_send_sync<T: Send + Sync>() {}
is_send_sync::<Slot<Q>>();
}
/// Check that `Slot<Q, MP>: 'static` as long as
/// `DB::DatabaseData: 'static`, which in turn implies that
/// `Q::Key: 'static`, `Q::Value: 'static`.
#[allow(dead_code)]
fn check_static<Q>()
where
Q: Query + 'static,
Q::Key: 'static,
Q::Value: 'static,
{
fn is_static<T: 'static>() {}
is_static::<Slot<Q>>();
}
impl<Q> std::fmt::Debug for Slot<Q>
where
Q: Query,
{
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(fmt, "{:?}", Q::default())
}
}

View file

@ -0,0 +1,131 @@
//!
use std::fmt;
use std::num::NonZeroU32;
/// The "raw-id" is used for interned keys in salsa -- it is basically
/// a newtype'd u32. Typically, it is wrapped in a type of your own
/// devising. For more information about interned keys, see [the
/// interned key RFC][rfc].
///
/// # Creating a `InternId`
//
/// InternId values can be constructed using the `From` impls,
/// which are implemented for `u32` and `usize`:
///
/// ```
/// # use salsa::InternId;
/// let intern_id1 = InternId::from(22_u32);
/// let intern_id2 = InternId::from(22_usize);
/// assert_eq!(intern_id1, intern_id2);
/// ```
///
/// # Converting to a u32 or usize
///
/// Normally, there should be no need to access the underlying integer
/// in a `InternId`. But if you do need to do so, you can convert to a
/// `usize` using the `as_u32` or `as_usize` methods or the `From` impls.
///
/// ```
/// # use salsa::InternId;
/// let intern_id = InternId::from(22_u32);
/// let value = u32::from(intern_id);
/// assert_eq!(value, 22);
/// ```
///
/// ## Illegal values
///
/// Be warned, however, that `InternId` values cannot be created from
/// *arbitrary* values -- in particular large values greater than
/// `InternId::MAX` will panic. Those large values are reserved so that
/// the Rust compiler can use them as sentinel values, which means
/// that (for example) `Option<InternId>` is represented in a single
/// word.
///
/// ```should_panic
/// # use salsa::InternId;
/// InternId::from(InternId::MAX);
/// ```
///
/// [rfc]: https://github.com/salsa-rs/salsa-rfcs/pull/2
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct InternId {
value: NonZeroU32,
}
impl InternId {
/// The maximum allowed `InternId`. This value can grow between
/// releases without affecting semver.
pub const MAX: u32 = 0xFFFF_FF00;
/// Creates a new InternId.
///
/// # Safety
///
/// `value` must be less than `MAX`
pub const unsafe fn new_unchecked(value: u32) -> Self {
debug_assert!(value < InternId::MAX);
InternId { value: NonZeroU32::new_unchecked(value + 1) }
}
/// Convert this raw-id into a u32 value.
///
/// ```
/// # use salsa::InternId;
/// let intern_id = InternId::from(22_u32);
/// let value = intern_id.as_usize();
/// assert_eq!(value, 22);
/// ```
pub fn as_u32(self) -> u32 {
self.value.get() - 1
}
/// Convert this raw-id into a usize value.
///
/// ```
/// # use salsa::InternId;
/// let intern_id = InternId::from(22_u32);
/// let value = intern_id.as_usize();
/// assert_eq!(value, 22);
/// ```
pub fn as_usize(self) -> usize {
self.as_u32() as usize
}
}
impl From<InternId> for u32 {
fn from(raw: InternId) -> u32 {
raw.as_u32()
}
}
impl From<InternId> for usize {
fn from(raw: InternId) -> usize {
raw.as_usize()
}
}
impl From<u32> for InternId {
fn from(id: u32) -> InternId {
assert!(id < InternId::MAX);
unsafe { InternId::new_unchecked(id) }
}
}
impl From<usize> for InternId {
fn from(id: usize) -> InternId {
assert!(id < (InternId::MAX as usize));
unsafe { InternId::new_unchecked(id as u32) }
}
}
impl fmt::Debug for InternId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.as_usize().fmt(f)
}
}
impl fmt::Display for InternId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.as_usize().fmt(f)
}
}

View file

@ -0,0 +1,409 @@
//!
use crate::debug::TableEntry;
use crate::durability::Durability;
use crate::intern_id::InternId;
use crate::plumbing::CycleRecoveryStrategy;
use crate::plumbing::HasQueryGroup;
use crate::plumbing::QueryStorageMassOps;
use crate::plumbing::QueryStorageOps;
use crate::revision::Revision;
use crate::Query;
use crate::{Database, DatabaseKeyIndex, QueryDb};
use parking_lot::RwLock;
use rustc_hash::FxHashMap;
use std::collections::hash_map::Entry;
use std::convert::From;
use std::fmt::Debug;
use std::hash::Hash;
use triomphe::Arc;
const INTERN_DURABILITY: Durability = Durability::HIGH;
/// Handles storage where the value is 'derived' by executing a
/// function (in contrast to "inputs").
pub struct InternedStorage<Q>
where
Q: Query,
Q::Value: InternKey,
{
group_index: u16,
tables: RwLock<InternTables<Q::Key>>,
}
/// Storage for the looking up interned things.
pub struct LookupInternedStorage<Q, IQ>
where
Q: Query,
Q::Key: InternKey,
Q::Value: Eq + Hash,
{
phantom: std::marker::PhantomData<(Q::Key, IQ)>,
}
struct InternTables<K> {
/// Map from the key to the corresponding intern-index.
map: FxHashMap<K, InternId>,
/// For each valid intern-index, stores the interned value.
values: Vec<Arc<Slot<K>>>,
}
/// Trait implemented for the "key" that results from a
/// `#[salsa::intern]` query. This is basically meant to be a
/// "newtype"'d `u32`.
pub trait InternKey {
/// Create an instance of the intern-key from a `u32` value.
fn from_intern_id(v: InternId) -> Self;
/// Extract the `u32` with which the intern-key was created.
fn as_intern_id(&self) -> InternId;
}
impl InternKey for InternId {
fn from_intern_id(v: InternId) -> InternId {
v
}
fn as_intern_id(&self) -> InternId {
*self
}
}
#[derive(Debug)]
struct Slot<K> {
/// DatabaseKeyIndex for this slot.
database_key_index: DatabaseKeyIndex,
/// Value that was interned.
value: K,
/// When was this intern'd?
///
/// (This informs the "changed-at" result)
interned_at: Revision,
}
impl<Q> std::panic::RefUnwindSafe for InternedStorage<Q>
where
Q: Query,
Q::Key: std::panic::RefUnwindSafe,
Q::Value: InternKey,
Q::Value: std::panic::RefUnwindSafe,
{
}
impl<K: Debug + Hash + Eq> InternTables<K> {
/// Returns the slot for the given key.
fn slot_for_key(&self, key: &K) -> Option<(Arc<Slot<K>>, InternId)> {
let &index = self.map.get(key)?;
Some((self.slot_for_index(index), index))
}
/// Returns the slot at the given index.
fn slot_for_index(&self, index: InternId) -> Arc<Slot<K>> {
let slot = &self.values[index.as_usize()];
slot.clone()
}
}
impl<K> Default for InternTables<K>
where
K: Eq + Hash,
{
fn default() -> Self {
Self { map: Default::default(), values: Default::default() }
}
}
impl<Q> InternedStorage<Q>
where
Q: Query,
Q::Key: Eq + Hash + Clone,
Q::Value: InternKey,
{
/// If `key` has already been interned, returns its slot. Otherwise, creates a new slot.
fn intern_index(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
key: &Q::Key,
) -> (Arc<Slot<Q::Key>>, InternId) {
if let Some(i) = self.intern_check(key) {
return i;
}
let owned_key1 = key.to_owned();
let owned_key2 = owned_key1.clone();
let revision_now = db.salsa_runtime().current_revision();
let mut tables = self.tables.write();
let tables = &mut *tables;
let entry = match tables.map.entry(owned_key1) {
Entry::Vacant(entry) => entry,
Entry::Occupied(entry) => {
// Somebody inserted this key while we were waiting
// for the write lock. In this case, we don't need to
// update the `accessed_at` field because they should
// have already done so!
let index = *entry.get();
let slot = &tables.values[index.as_usize()];
debug_assert_eq!(owned_key2, slot.value);
return (slot.clone(), index);
}
};
let create_slot = |index: InternId| {
let database_key_index = DatabaseKeyIndex {
group_index: self.group_index,
query_index: Q::QUERY_INDEX,
key_index: index.as_u32(),
};
Arc::new(Slot { database_key_index, value: owned_key2, interned_at: revision_now })
};
let (slot, index);
index = InternId::from(tables.values.len());
slot = create_slot(index);
tables.values.push(slot.clone());
entry.insert(index);
(slot, index)
}
fn intern_check(&self, key: &Q::Key) -> Option<(Arc<Slot<Q::Key>>, InternId)> {
self.tables.read().slot_for_key(key)
}
/// Given an index, lookup and clone its value, updating the
/// `accessed_at` time if necessary.
fn lookup_value(&self, index: InternId) -> Arc<Slot<Q::Key>> {
self.tables.read().slot_for_index(index)
}
}
impl<Q> QueryStorageOps<Q> for InternedStorage<Q>
where
Q: Query,
Q::Value: InternKey,
{
const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = CycleRecoveryStrategy::Panic;
fn new(group_index: u16) -> Self {
InternedStorage { group_index, tables: RwLock::new(InternTables::default()) }
}
fn fmt_index(
&self,
_db: &<Q as QueryDb<'_>>::DynDb,
index: DatabaseKeyIndex,
fmt: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
assert_eq!(index.group_index, self.group_index);
assert_eq!(index.query_index, Q::QUERY_INDEX);
let intern_id = InternId::from(index.key_index);
let slot = self.lookup_value(intern_id);
write!(fmt, "{}({:?})", Q::QUERY_NAME, slot.value)
}
fn maybe_changed_after(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
input: DatabaseKeyIndex,
revision: Revision,
) -> bool {
assert_eq!(input.group_index, self.group_index);
assert_eq!(input.query_index, Q::QUERY_INDEX);
debug_assert!(revision < db.salsa_runtime().current_revision());
let intern_id = InternId::from(input.key_index);
let slot = self.lookup_value(intern_id);
slot.maybe_changed_after(revision)
}
fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value {
db.unwind_if_cancelled();
let (slot, index) = self.intern_index(db, key);
let changed_at = slot.interned_at;
db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted(
slot.database_key_index,
INTERN_DURABILITY,
changed_at,
);
<Q::Value>::from_intern_id(index)
}
fn durability(&self, _db: &<Q as QueryDb<'_>>::DynDb, _key: &Q::Key) -> Durability {
INTERN_DURABILITY
}
fn entries<C>(&self, _db: &<Q as QueryDb<'_>>::DynDb) -> C
where
C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>,
{
let tables = self.tables.read();
tables
.map
.iter()
.map(|(key, index)| {
TableEntry::new(key.clone(), Some(<Q::Value>::from_intern_id(*index)))
})
.collect()
}
}
impl<Q> QueryStorageMassOps for InternedStorage<Q>
where
Q: Query,
Q::Value: InternKey,
{
fn purge(&self) {
*self.tables.write() = Default::default();
}
}
// Workaround for
// ```
// IQ: for<'d> QueryDb<
// 'd,
// DynDb = <Q as QueryDb<'d>>::DynDb,
// Group = <Q as QueryDb<'d>>::Group,
// GroupStorage = <Q as QueryDb<'d>>::GroupStorage,
// >,
// ```
// not working to make rustc know DynDb, Group and GroupStorage being the same in `Q` and `IQ`
#[doc(hidden)]
pub trait EqualDynDb<'d, IQ>: QueryDb<'d>
where
IQ: QueryDb<'d>,
{
fn convert_db(d: &Self::DynDb) -> &IQ::DynDb;
fn convert_group_storage(d: &Self::GroupStorage) -> &IQ::GroupStorage;
}
impl<'d, IQ, Q> EqualDynDb<'d, IQ> for Q
where
Q: QueryDb<'d, DynDb = IQ::DynDb, Group = IQ::Group, GroupStorage = IQ::GroupStorage>,
Q::DynDb: HasQueryGroup<Q::Group>,
IQ: QueryDb<'d>,
{
fn convert_db(d: &Self::DynDb) -> &IQ::DynDb {
d
}
fn convert_group_storage(d: &Self::GroupStorage) -> &IQ::GroupStorage {
d
}
}
impl<Q, IQ> QueryStorageOps<Q> for LookupInternedStorage<Q, IQ>
where
Q: Query,
Q::Key: InternKey,
Q::Value: Eq + Hash,
IQ: Query<Key = Q::Value, Value = Q::Key, Storage = InternedStorage<IQ>>,
for<'d> Q: EqualDynDb<'d, IQ>,
{
const CYCLE_STRATEGY: CycleRecoveryStrategy = CycleRecoveryStrategy::Panic;
fn new(_group_index: u16) -> Self {
LookupInternedStorage { phantom: std::marker::PhantomData }
}
fn fmt_index(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
index: DatabaseKeyIndex,
fmt: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
let group_storage =
<<Q as QueryDb<'_>>::DynDb as HasQueryGroup<Q::Group>>::group_storage(db);
let interned_storage = IQ::query_storage(Q::convert_group_storage(group_storage));
interned_storage.fmt_index(Q::convert_db(db), index, fmt)
}
fn maybe_changed_after(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
input: DatabaseKeyIndex,
revision: Revision,
) -> bool {
let group_storage =
<<Q as QueryDb<'_>>::DynDb as HasQueryGroup<Q::Group>>::group_storage(db);
let interned_storage = IQ::query_storage(Q::convert_group_storage(group_storage));
interned_storage.maybe_changed_after(Q::convert_db(db), input, revision)
}
fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value {
let index = key.as_intern_id();
let group_storage =
<<Q as QueryDb<'_>>::DynDb as HasQueryGroup<Q::Group>>::group_storage(db);
let interned_storage = IQ::query_storage(Q::convert_group_storage(group_storage));
let slot = interned_storage.lookup_value(index);
let value = slot.value.clone();
let interned_at = slot.interned_at;
db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted(
slot.database_key_index,
INTERN_DURABILITY,
interned_at,
);
value
}
fn durability(&self, _db: &<Q as QueryDb<'_>>::DynDb, _key: &Q::Key) -> Durability {
INTERN_DURABILITY
}
fn entries<C>(&self, db: &<Q as QueryDb<'_>>::DynDb) -> C
where
C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>,
{
let group_storage =
<<Q as QueryDb<'_>>::DynDb as HasQueryGroup<Q::Group>>::group_storage(db);
let interned_storage = IQ::query_storage(Q::convert_group_storage(group_storage));
let tables = interned_storage.tables.read();
tables
.map
.iter()
.map(|(key, index)| {
TableEntry::new(<Q::Key>::from_intern_id(*index), Some(key.clone()))
})
.collect()
}
}
impl<Q, IQ> QueryStorageMassOps for LookupInternedStorage<Q, IQ>
where
Q: Query,
Q::Key: InternKey,
Q::Value: Eq + Hash,
IQ: Query<Key = Q::Value, Value = Q::Key>,
{
fn purge(&self) {}
}
impl<K> Slot<K> {
fn maybe_changed_after(&self, revision: Revision) -> bool {
self.interned_at > revision
}
}
/// Check that `Slot<Q, MP>: Send + Sync` as long as
/// `DB::DatabaseData: Send + Sync`, which in turn implies that
/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`.
#[allow(dead_code)]
fn check_send_sync<K>()
where
K: Send + Sync,
{
fn is_send_sync<T: Send + Sync>() {}
is_send_sync::<Slot<K>>();
}
/// Check that `Slot<Q, MP>: 'static` as long as
/// `DB::DatabaseData: 'static`, which in turn implies that
/// `Q::Key: 'static`, `Q::Value: 'static`.
#[allow(dead_code)]
fn check_static<K>()
where
K: 'static,
{
fn is_static<T: 'static>() {}
is_static::<Slot<K>>();
}

742
crates/salsa/src/lib.rs Normal file
View file

@ -0,0 +1,742 @@
//!
#![allow(clippy::type_complexity)]
#![allow(clippy::question_mark)]
#![warn(rust_2018_idioms)]
#![warn(missing_docs)]
//! The salsa crate is a crate for incremental recomputation. It
//! permits you to define a "database" of queries with both inputs and
//! values derived from those inputs; as you set the inputs, you can
//! re-execute the derived queries and it will try to re-use results
//! from previous invocations as appropriate.
mod derived;
mod doctest;
mod durability;
mod hash;
mod input;
mod intern_id;
mod interned;
mod lru;
mod revision;
mod runtime;
mod storage;
pub mod debug;
/// Items in this module are public for implementation reasons,
/// and are exempt from the SemVer guarantees.
#[doc(hidden)]
pub mod plumbing;
use crate::plumbing::CycleRecoveryStrategy;
use crate::plumbing::DerivedQueryStorageOps;
use crate::plumbing::InputQueryStorageOps;
use crate::plumbing::LruQueryStorageOps;
use crate::plumbing::QueryStorageMassOps;
use crate::plumbing::QueryStorageOps;
pub use crate::revision::Revision;
use std::fmt::{self, Debug};
use std::hash::Hash;
use std::panic::AssertUnwindSafe;
use std::panic::{self, UnwindSafe};
pub use crate::durability::Durability;
pub use crate::intern_id::InternId;
pub use crate::interned::InternKey;
pub use crate::runtime::Runtime;
pub use crate::runtime::RuntimeId;
pub use crate::storage::Storage;
/// The base trait which your "query context" must implement. Gives
/// access to the salsa runtime, which you must embed into your query
/// context (along with whatever other state you may require).
pub trait Database: plumbing::DatabaseOps {
/// This function is invoked at key points in the salsa
/// runtime. It permits the database to be customized and to
/// inject logging or other custom behavior.
fn salsa_event(&self, event_fn: Event) {
#![allow(unused_variables)]
}
/// Starts unwinding the stack if the current revision is cancelled.
///
/// This method can be called by query implementations that perform
/// potentially expensive computations, in order to speed up propagation of
/// cancellation.
///
/// Cancellation will automatically be triggered by salsa on any query
/// invocation.
///
/// This method should not be overridden by `Database` implementors. A
/// `salsa_event` is emitted when this method is called, so that should be
/// used instead.
#[inline]
fn unwind_if_cancelled(&self) {
let runtime = self.salsa_runtime();
self.salsa_event(Event {
runtime_id: runtime.id(),
kind: EventKind::WillCheckCancellation,
});
let current_revision = runtime.current_revision();
let pending_revision = runtime.pending_revision();
tracing::debug!(
"unwind_if_cancelled: current_revision={:?}, pending_revision={:?}",
current_revision,
pending_revision
);
if pending_revision > current_revision {
runtime.unwind_cancelled();
}
}
/// Gives access to the underlying salsa runtime.
///
/// This method should not be overridden by `Database` implementors.
fn salsa_runtime(&self) -> &Runtime {
self.ops_salsa_runtime()
}
/// Gives access to the underlying salsa runtime.
///
/// This method should not be overridden by `Database` implementors.
fn salsa_runtime_mut(&mut self) -> &mut Runtime {
self.ops_salsa_runtime_mut()
}
}
/// The `Event` struct identifies various notable things that can
/// occur during salsa execution. Instances of this struct are given
/// to `salsa_event`.
pub struct Event {
/// The id of the snapshot that triggered the event. Usually
/// 1-to-1 with a thread, as well.
pub runtime_id: RuntimeId,
/// What sort of event was it.
pub kind: EventKind,
}
impl Event {
/// Returns a type that gives a user-readable debug output.
/// Use like `println!("{:?}", index.debug(db))`.
pub fn debug<'me, D: ?Sized>(&'me self, db: &'me D) -> impl std::fmt::Debug + 'me
where
D: plumbing::DatabaseOps,
{
EventDebug { event: self, db }
}
}
impl fmt::Debug for Event {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Event")
.field("runtime_id", &self.runtime_id)
.field("kind", &self.kind)
.finish()
}
}
struct EventDebug<'me, D: ?Sized>
where
D: plumbing::DatabaseOps,
{
event: &'me Event,
db: &'me D,
}
impl<'me, D: ?Sized> fmt::Debug for EventDebug<'me, D>
where
D: plumbing::DatabaseOps,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Event")
.field("runtime_id", &self.event.runtime_id)
.field("kind", &self.event.kind.debug(self.db))
.finish()
}
}
/// An enum identifying the various kinds of events that can occur.
pub enum EventKind {
/// Occurs when we found that all inputs to a memoized value are
/// up-to-date and hence the value can be re-used without
/// executing the closure.
///
/// Executes before the "re-used" value is returned.
DidValidateMemoizedValue {
/// The database-key for the affected value. Implements `Debug`.
database_key: DatabaseKeyIndex,
},
/// Indicates that another thread (with id `other_runtime_id`) is processing the
/// given query (`database_key`), so we will block until they
/// finish.
///
/// Executes after we have registered with the other thread but
/// before they have answered us.
///
/// (NB: you can find the `id` of the current thread via the
/// `salsa_runtime`)
WillBlockOn {
/// The id of the runtime we will block on.
other_runtime_id: RuntimeId,
/// The database-key for the affected value. Implements `Debug`.
database_key: DatabaseKeyIndex,
},
/// Indicates that the function for this query will be executed.
/// This is either because it has never executed before or because
/// its inputs may be out of date.
WillExecute {
/// The database-key for the affected value. Implements `Debug`.
database_key: DatabaseKeyIndex,
},
/// Indicates that `unwind_if_cancelled` was called and salsa will check if
/// the current revision has been cancelled.
WillCheckCancellation,
}
impl EventKind {
/// Returns a type that gives a user-readable debug output.
/// Use like `println!("{:?}", index.debug(db))`.
pub fn debug<'me, D: ?Sized>(&'me self, db: &'me D) -> impl std::fmt::Debug + 'me
where
D: plumbing::DatabaseOps,
{
EventKindDebug { kind: self, db }
}
}
impl fmt::Debug for EventKind {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
EventKind::DidValidateMemoizedValue { database_key } => fmt
.debug_struct("DidValidateMemoizedValue")
.field("database_key", database_key)
.finish(),
EventKind::WillBlockOn { other_runtime_id, database_key } => fmt
.debug_struct("WillBlockOn")
.field("other_runtime_id", other_runtime_id)
.field("database_key", database_key)
.finish(),
EventKind::WillExecute { database_key } => {
fmt.debug_struct("WillExecute").field("database_key", database_key).finish()
}
EventKind::WillCheckCancellation => fmt.debug_struct("WillCheckCancellation").finish(),
}
}
}
struct EventKindDebug<'me, D: ?Sized>
where
D: plumbing::DatabaseOps,
{
kind: &'me EventKind,
db: &'me D,
}
impl<'me, D: ?Sized> fmt::Debug for EventKindDebug<'me, D>
where
D: plumbing::DatabaseOps,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.kind {
EventKind::DidValidateMemoizedValue { database_key } => fmt
.debug_struct("DidValidateMemoizedValue")
.field("database_key", &database_key.debug(self.db))
.finish(),
EventKind::WillBlockOn { other_runtime_id, database_key } => fmt
.debug_struct("WillBlockOn")
.field("other_runtime_id", &other_runtime_id)
.field("database_key", &database_key.debug(self.db))
.finish(),
EventKind::WillExecute { database_key } => fmt
.debug_struct("WillExecute")
.field("database_key", &database_key.debug(self.db))
.finish(),
EventKind::WillCheckCancellation => fmt.debug_struct("WillCheckCancellation").finish(),
}
}
}
/// Indicates a database that also supports parallel query
/// evaluation. All of Salsa's base query support is capable of
/// parallel execution, but for it to work, your query key/value types
/// must also be `Send`, as must any additional data in your database.
pub trait ParallelDatabase: Database + Send {
/// Creates a second handle to the database that holds the
/// database fixed at a particular revision. So long as this
/// "frozen" handle exists, any attempt to [`set`] an input will
/// block.
///
/// [`set`]: struct.QueryTable.html#method.set
///
/// This is the method you are meant to use most of the time in a
/// parallel setting where modifications may arise asynchronously
/// (e.g., a language server). In this context, it is common to
/// wish to "fork off" a snapshot of the database performing some
/// series of queries in parallel and arranging the results. Using
/// this method for that purpose ensures that those queries will
/// see a consistent view of the database (it is also advisable
/// for those queries to use the [`Runtime::unwind_if_cancelled`]
/// method to check for cancellation).
///
/// # Panics
///
/// It is not permitted to create a snapshot from inside of a
/// query. Attepting to do so will panic.
///
/// # Deadlock warning
///
/// The intended pattern for snapshots is that, once created, they
/// are sent to another thread and used from there. As such, the
/// `snapshot` acquires a "read lock" on the database --
/// therefore, so long as the `snapshot` is not dropped, any
/// attempt to `set` a value in the database will block. If the
/// `snapshot` is owned by the same thread that is attempting to
/// `set`, this will cause a problem.
///
/// # How to implement this
///
/// Typically, this method will create a second copy of your
/// database type (`MyDatabaseType`, in the example below),
/// cloning over each of the fields from `self` into this new
/// copy. For the field that stores the salsa runtime, you should
/// use [the `Runtime::snapshot` method][rfm] to create a snapshot of the
/// runtime. Finally, package up the result using `Snapshot::new`,
/// which is a simple wrapper type that only gives `&self` access
/// to the database within (thus preventing the use of methods
/// that may mutate the inputs):
///
/// [rfm]: struct.Runtime.html#method.snapshot
///
/// ```rust,ignore
/// impl ParallelDatabase for MyDatabaseType {
/// fn snapshot(&self) -> Snapshot<Self> {
/// Snapshot::new(
/// MyDatabaseType {
/// runtime: self.runtime.snapshot(self),
/// other_field: self.other_field.clone(),
/// }
/// )
/// }
/// }
/// ```
fn snapshot(&self) -> Snapshot<Self>;
}
/// Simple wrapper struct that takes ownership of a database `DB` and
/// only gives `&self` access to it. See [the `snapshot` method][fm]
/// for more details.
///
/// [fm]: trait.ParallelDatabase.html#method.snapshot
#[derive(Debug)]
pub struct Snapshot<DB: ?Sized>
where
DB: ParallelDatabase,
{
db: DB,
}
impl<DB> Snapshot<DB>
where
DB: ParallelDatabase,
{
/// Creates a `Snapshot` that wraps the given database handle
/// `db`. From this point forward, only shared references to `db`
/// will be possible.
pub fn new(db: DB) -> Self {
Snapshot { db }
}
}
impl<DB> std::ops::Deref for Snapshot<DB>
where
DB: ParallelDatabase,
{
type Target = DB;
fn deref(&self) -> &DB {
&self.db
}
}
/// An integer that uniquely identifies a particular query instance within the
/// database. Used to track dependencies between queries. Fully ordered and
/// equatable but those orderings are arbitrary, and meant to be used only for
/// inserting into maps and the like.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct DatabaseKeyIndex {
group_index: u16,
query_index: u16,
key_index: u32,
}
impl DatabaseKeyIndex {
/// Returns the index of the query group containing this key.
#[inline]
pub fn group_index(self) -> u16 {
self.group_index
}
/// Returns the index of the query within its query group.
#[inline]
pub fn query_index(self) -> u16 {
self.query_index
}
/// Returns the index of this particular query key within the query.
#[inline]
pub fn key_index(self) -> u32 {
self.key_index
}
/// Returns a type that gives a user-readable debug output.
/// Use like `println!("{:?}", index.debug(db))`.
pub fn debug<D: ?Sized>(self, db: &D) -> impl std::fmt::Debug + '_
where
D: plumbing::DatabaseOps,
{
DatabaseKeyIndexDebug { index: self, db }
}
}
/// Helper type for `DatabaseKeyIndex::debug`
struct DatabaseKeyIndexDebug<'me, D: ?Sized>
where
D: plumbing::DatabaseOps,
{
index: DatabaseKeyIndex,
db: &'me D,
}
impl<D: ?Sized> std::fmt::Debug for DatabaseKeyIndexDebug<'_, D>
where
D: plumbing::DatabaseOps,
{
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.db.fmt_index(self.index, fmt)
}
}
/// Trait implements by all of the "special types" associated with
/// each of your queries.
///
/// Base trait of `Query` that has a lifetime parameter to allow the `DynDb` to be non-'static.
pub trait QueryDb<'d>: Sized {
/// Dyn version of the associated trait for this query group.
type DynDb: ?Sized + Database + HasQueryGroup<Self::Group> + 'd;
/// Associate query group struct.
type Group: plumbing::QueryGroup<GroupStorage = Self::GroupStorage>;
/// Generated struct that contains storage for all queries in a group.
type GroupStorage;
}
/// Trait implements by all of the "special types" associated with
/// each of your queries.
pub trait Query: Debug + Default + Sized + for<'d> QueryDb<'d> {
/// Type that you you give as a parameter -- for queries with zero
/// or more than one input, this will be a tuple.
type Key: Clone + Debug + Hash + Eq;
/// What value does the query return?
type Value: Clone + Debug;
/// Internal struct storing the values for the query.
// type Storage: plumbing::QueryStorageOps<Self>;
type Storage;
/// A unique index identifying this query within the group.
const QUERY_INDEX: u16;
/// Name of the query method (e.g., `foo`)
const QUERY_NAME: &'static str;
/// Extact storage for this query from the storage for its group.
fn query_storage<'a>(
group_storage: &'a <Self as QueryDb<'_>>::GroupStorage,
) -> &'a std::sync::Arc<Self::Storage>;
/// Extact storage for this query from the storage for its group.
fn query_storage_mut<'a>(
group_storage: &'a <Self as QueryDb<'_>>::GroupStorage,
) -> &'a std::sync::Arc<Self::Storage>;
}
/// Return value from [the `query` method] on `Database`.
/// Gives access to various less common operations on queries.
///
/// [the `query` method]: trait.Database.html#method.query
pub struct QueryTable<'me, Q>
where
Q: Query,
{
db: &'me <Q as QueryDb<'me>>::DynDb,
storage: &'me Q::Storage,
}
impl<'me, Q> QueryTable<'me, Q>
where
Q: Query,
Q::Storage: QueryStorageOps<Q>,
{
/// Constructs a new `QueryTable`.
pub fn new(db: &'me <Q as QueryDb<'me>>::DynDb, storage: &'me Q::Storage) -> Self {
Self { db, storage }
}
/// Execute the query on a given input. Usually it's easier to
/// invoke the trait method directly. Note that for variadic
/// queries (those with no inputs, or those with more than one
/// input) the key will be a tuple.
pub fn get(&self, key: Q::Key) -> Q::Value {
self.storage.fetch(self.db, &key)
}
/// Completely clears the storage for this query.
///
/// This method breaks internal invariants of salsa, so any further queries
/// might return nonsense results. It is useful only in very specific
/// circumstances -- for example, when one wants to observe which values
/// dropped together with the table
pub fn purge(&self)
where
Q::Storage: plumbing::QueryStorageMassOps,
{
self.storage.purge();
}
}
/// Return value from [the `query_mut` method] on `Database`.
/// Gives access to the `set` method, notably, that is used to
/// set the value of an input query.
///
/// [the `query_mut` method]: trait.Database.html#method.query_mut
pub struct QueryTableMut<'me, Q>
where
Q: Query + 'me,
{
runtime: &'me mut Runtime,
storage: &'me Q::Storage,
}
impl<'me, Q> QueryTableMut<'me, Q>
where
Q: Query,
{
/// Constructs a new `QueryTableMut`.
pub fn new(runtime: &'me mut Runtime, storage: &'me Q::Storage) -> Self {
Self { runtime, storage }
}
/// Assign a value to an "input query". Must be used outside of
/// an active query computation.
///
/// If you are using `snapshot`, see the notes on blocking
/// and cancellation on [the `query_mut` method].
///
/// [the `query_mut` method]: trait.Database.html#method.query_mut
pub fn set(&mut self, key: Q::Key, value: Q::Value)
where
Q::Storage: plumbing::InputQueryStorageOps<Q>,
{
self.set_with_durability(key, value, Durability::LOW);
}
/// Assign a value to an "input query", with the additional
/// promise that this value will **never change**. Must be used
/// outside of an active query computation.
///
/// If you are using `snapshot`, see the notes on blocking
/// and cancellation on [the `query_mut` method].
///
/// [the `query_mut` method]: trait.Database.html#method.query_mut
pub fn set_with_durability(&mut self, key: Q::Key, value: Q::Value, durability: Durability)
where
Q::Storage: plumbing::InputQueryStorageOps<Q>,
{
self.storage.set(self.runtime, &key, value, durability);
}
/// Sets the size of LRU cache of values for this query table.
///
/// That is, at most `cap` values will be preset in the table at the same
/// time. This helps with keeping maximum memory usage under control, at the
/// cost of potential extra recalculations of evicted values.
///
/// If `cap` is zero, all values are preserved, this is the default.
pub fn set_lru_capacity(&self, cap: usize)
where
Q::Storage: plumbing::LruQueryStorageOps,
{
self.storage.set_lru_capacity(cap);
}
/// Marks the computed value as outdated.
///
/// This causes salsa to re-execute the query function on the next access to
/// the query, even if all dependencies are up to date.
///
/// This is most commonly used as part of the [on-demand input
/// pattern](https://salsa-rs.github.io/salsa/common_patterns/on_demand_inputs.html).
pub fn invalidate(&mut self, key: &Q::Key)
where
Q::Storage: plumbing::DerivedQueryStorageOps<Q>,
{
self.storage.invalidate(self.runtime, key)
}
}
/// A panic payload indicating that execution of a salsa query was cancelled.
///
/// This can occur for a few reasons:
/// *
/// *
/// *
#[derive(Debug)]
#[non_exhaustive]
pub enum Cancelled {
/// The query was operating on revision R, but there is a pending write to move to revision R+1.
#[non_exhaustive]
PendingWrite,
/// The query was blocked on another thread, and that thread panicked.
#[non_exhaustive]
PropagatedPanic,
}
impl Cancelled {
fn throw(self) -> ! {
// We use resume and not panic here to avoid running the panic
// hook (that is, to avoid collecting and printing backtrace).
std::panic::resume_unwind(Box::new(self));
}
/// Runs `f`, and catches any salsa cancellation.
pub fn catch<F, T>(f: F) -> Result<T, Cancelled>
where
F: FnOnce() -> T + UnwindSafe,
{
match panic::catch_unwind(f) {
Ok(t) => Ok(t),
Err(payload) => match payload.downcast() {
Ok(cancelled) => Err(*cancelled),
Err(payload) => panic::resume_unwind(payload),
},
}
}
}
impl std::fmt::Display for Cancelled {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let why = match self {
Cancelled::PendingWrite => "pending write",
Cancelled::PropagatedPanic => "propagated panic",
};
f.write_str("cancelled because of ")?;
f.write_str(why)
}
}
impl std::error::Error for Cancelled {}
/// Captures the participants of a cycle that occurred when executing a query.
///
/// This type is meant to be used to help give meaningful error messages to the
/// user or to help salsa developers figure out why their program is resulting
/// in a computation cycle.
///
/// It is used in a few ways:
///
/// * During [cycle recovery](https://https://salsa-rs.github.io/salsa/cycles/fallback.html),
/// where it is given to the fallback function.
/// * As the panic value when an unexpected cycle (i.e., a cycle where one or more participants
/// lacks cycle recovery information) occurs.
///
/// You can read more about cycle handling in
/// the [salsa book](https://https://salsa-rs.github.io/salsa/cycles.html).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Cycle {
participants: plumbing::CycleParticipants,
}
impl Cycle {
pub(crate) fn new(participants: plumbing::CycleParticipants) -> Self {
Self { participants }
}
/// True if two `Cycle` values represent the same cycle.
pub(crate) fn is(&self, cycle: &Cycle) -> bool {
triomphe::Arc::ptr_eq(&self.participants, &cycle.participants)
}
pub(crate) fn throw(self) -> ! {
tracing::debug!("throwing cycle {:?}", self);
std::panic::resume_unwind(Box::new(self))
}
pub(crate) fn catch<T>(execute: impl FnOnce() -> T) -> Result<T, Cycle> {
match std::panic::catch_unwind(AssertUnwindSafe(execute)) {
Ok(v) => Ok(v),
Err(err) => match err.downcast::<Cycle>() {
Ok(cycle) => Err(*cycle),
Err(other) => std::panic::resume_unwind(other),
},
}
}
/// Iterate over the [`DatabaseKeyIndex`] for each query participating
/// in the cycle. The start point of this iteration within the cycle
/// is arbitrary but deterministic, but the ordering is otherwise determined
/// by the execution.
pub fn participant_keys(&self) -> impl Iterator<Item = DatabaseKeyIndex> + '_ {
self.participants.iter().copied()
}
/// Returns a vector with the debug information for
/// all the participants in the cycle.
pub fn all_participants<DB: ?Sized + Database>(&self, db: &DB) -> Vec<String> {
self.participant_keys().map(|d| format!("{:?}", d.debug(db))).collect()
}
/// Returns a vector with the debug information for
/// those participants in the cycle that lacked recovery
/// information.
pub fn unexpected_participants<DB: ?Sized + Database>(&self, db: &DB) -> Vec<String> {
self.participant_keys()
.filter(|&d| db.cycle_recovery_strategy(d) == CycleRecoveryStrategy::Panic)
.map(|d| format!("{:?}", d.debug(db)))
.collect()
}
/// Returns a "debug" view onto this strict that can be used to print out information.
pub fn debug<'me, DB: ?Sized + Database>(&'me self, db: &'me DB) -> impl std::fmt::Debug + 'me {
struct UnexpectedCycleDebug<'me> {
c: &'me Cycle,
db: &'me dyn Database,
}
impl<'me> std::fmt::Debug for UnexpectedCycleDebug<'me> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fmt.debug_struct("UnexpectedCycle")
.field("all_participants", &self.c.all_participants(self.db))
.field("unexpected_participants", &self.c.unexpected_participants(self.db))
.finish()
}
}
UnexpectedCycleDebug { c: self, db: db.ops_database() }
}
}
// Re-export the procedural macros.
#[allow(unused_imports)]
#[macro_use]
extern crate salsa_macros;
use plumbing::HasQueryGroup;
pub use salsa_macros::*;

325
crates/salsa/src/lru.rs Normal file
View file

@ -0,0 +1,325 @@
//!
use oorandom::Rand64;
use parking_lot::Mutex;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use triomphe::Arc;
/// A simple and approximate concurrent lru list.
///
/// We assume but do not verify that each node is only used with one
/// list. If this is not the case, it is not *unsafe*, but panics and
/// weird results will ensue.
///
/// Each "node" in the list is of type `Node` and must implement
/// `LruNode`, which is a trait that gives access to a field that
/// stores the index in the list. This index gives us a rough idea of
/// how recently the node has been used.
#[derive(Debug)]
pub(crate) struct Lru<Node>
where
Node: LruNode,
{
green_zone: AtomicUsize,
data: Mutex<LruData<Node>>,
}
#[derive(Debug)]
struct LruData<Node> {
end_red_zone: usize,
end_yellow_zone: usize,
end_green_zone: usize,
rng: Rand64,
entries: Vec<Arc<Node>>,
}
pub(crate) trait LruNode: Sized + Debug {
fn lru_index(&self) -> &LruIndex;
}
#[derive(Debug)]
pub(crate) struct LruIndex {
/// Index in the approprate LRU list, or std::usize::MAX if not a
/// member.
index: AtomicUsize,
}
impl<Node> Default for Lru<Node>
where
Node: LruNode,
{
fn default() -> Self {
Lru::new()
}
}
// We always use a fixed seed for our randomness so that we have
// predictable results.
const LRU_SEED: &str = "Hello, Rustaceans";
impl<Node> Lru<Node>
where
Node: LruNode,
{
/// Creates a new LRU list where LRU caching is disabled.
pub(crate) fn new() -> Self {
Self::with_seed(LRU_SEED)
}
#[cfg_attr(not(test), allow(dead_code))]
fn with_seed(seed: &str) -> Self {
Lru { green_zone: AtomicUsize::new(0), data: Mutex::new(LruData::with_seed(seed)) }
}
/// Adjust the total number of nodes permitted to have a value at
/// once. If `len` is zero, this disables LRU caching completely.
pub(crate) fn set_lru_capacity(&self, len: usize) {
let mut data = self.data.lock();
// We require each zone to have at least 1 slot. Therefore,
// the length cannot be just 1 or 2.
if len == 0 {
self.green_zone.store(0, Ordering::Release);
data.resize(0, 0, 0);
} else {
let len = std::cmp::max(len, 3);
// Top 10% is the green zone. This must be at least length 1.
let green_zone = std::cmp::max(len / 10, 1);
// Next 20% is the yellow zone.
let yellow_zone = std::cmp::max(len / 5, 1);
// Remaining 70% is the red zone.
let red_zone = len - yellow_zone - green_zone;
// We need quick access to the green zone.
self.green_zone.store(green_zone, Ordering::Release);
// Resize existing array.
data.resize(green_zone, yellow_zone, red_zone);
}
}
/// Records that `node` was used. This may displace an old node (if the LRU limits are
pub(crate) fn record_use(&self, node: &Arc<Node>) -> Option<Arc<Node>> {
tracing::debug!("record_use(node={:?})", node);
// Load green zone length and check if the LRU cache is even enabled.
let green_zone = self.green_zone.load(Ordering::Acquire);
tracing::debug!("record_use: green_zone={}", green_zone);
if green_zone == 0 {
return None;
}
// Find current index of list (if any) and the current length
// of our green zone.
let index = node.lru_index().load();
tracing::debug!("record_use: index={}", index);
// Already a member of the list, and in the green zone -- nothing to do!
if index < green_zone {
return None;
}
self.data.lock().record_use(node)
}
pub(crate) fn purge(&self) {
self.green_zone.store(0, Ordering::SeqCst);
*self.data.lock() = LruData::with_seed(LRU_SEED);
}
}
impl<Node> LruData<Node>
where
Node: LruNode,
{
fn with_seed(seed_str: &str) -> Self {
Self::with_rng(rng_with_seed(seed_str))
}
fn with_rng(rng: Rand64) -> Self {
LruData { end_yellow_zone: 0, end_green_zone: 0, end_red_zone: 0, entries: Vec::new(), rng }
}
fn green_zone(&self) -> std::ops::Range<usize> {
0..self.end_green_zone
}
fn yellow_zone(&self) -> std::ops::Range<usize> {
self.end_green_zone..self.end_yellow_zone
}
fn red_zone(&self) -> std::ops::Range<usize> {
self.end_yellow_zone..self.end_red_zone
}
fn resize(&mut self, len_green_zone: usize, len_yellow_zone: usize, len_red_zone: usize) {
self.end_green_zone = len_green_zone;
self.end_yellow_zone = self.end_green_zone + len_yellow_zone;
self.end_red_zone = self.end_yellow_zone + len_red_zone;
let entries = std::mem::replace(&mut self.entries, Vec::with_capacity(self.end_red_zone));
tracing::debug!("green_zone = {:?}", self.green_zone());
tracing::debug!("yellow_zone = {:?}", self.yellow_zone());
tracing::debug!("red_zone = {:?}", self.red_zone());
// We expect to resize when the LRU cache is basically empty.
// So just forget all the old LRU indices to start.
for entry in entries {
entry.lru_index().clear();
}
}
/// Records that a node was used. If it is already a member of the
/// LRU list, it is promoted to the green zone (unless it's
/// already there). Otherwise, it is added to the list first and
/// *then* promoted to the green zone. Adding a new node to the
/// list may displace an old member of the red zone, in which case
/// that is returned.
fn record_use(&mut self, node: &Arc<Node>) -> Option<Arc<Node>> {
tracing::debug!("record_use(node={:?})", node);
// NB: When this is invoked, we have typically already loaded
// the LRU index (to check if it is in green zone). But that
// check was done outside the lock and -- for all we know --
// the index may have changed since. So we always reload.
let index = node.lru_index().load();
if index < self.end_green_zone {
None
} else if index < self.end_yellow_zone {
self.promote_yellow_to_green(node, index);
None
} else if index < self.end_red_zone {
self.promote_red_to_green(node, index);
None
} else {
self.insert_new(node)
}
}
/// Inserts a node that is not yet a member of the LRU list. If
/// the list is at capacity, this can displace an existing member.
fn insert_new(&mut self, node: &Arc<Node>) -> Option<Arc<Node>> {
debug_assert!(!node.lru_index().is_in_lru());
// Easy case: we still have capacity. Push it, and then promote
// it up to the appropriate zone.
let len = self.entries.len();
if len < self.end_red_zone {
self.entries.push(node.clone());
node.lru_index().store(len);
tracing::debug!("inserted node {:?} at {}", node, len);
return self.record_use(node);
}
// Harder case: no capacity. Create some by evicting somebody from red
// zone and then promoting.
let victim_index = self.pick_index(self.red_zone());
let victim_node = std::mem::replace(&mut self.entries[victim_index], node.clone());
tracing::debug!("evicting red node {:?} from {}", victim_node, victim_index);
victim_node.lru_index().clear();
self.promote_red_to_green(node, victim_index);
Some(victim_node)
}
/// Promotes the node `node`, stored at `red_index` (in the red
/// zone), into a green index, demoting yellow/green nodes at
/// random.
///
/// NB: It is not required that `node.lru_index()` is up-to-date
/// when entering this method.
fn promote_red_to_green(&mut self, node: &Arc<Node>, red_index: usize) {
debug_assert!(self.red_zone().contains(&red_index));
// Pick a yellow at random and switch places with it.
//
// Subtle: we do not update `node.lru_index` *yet* -- we're
// going to invoke `self.promote_yellow` next, and it will get
// updated then.
let yellow_index = self.pick_index(self.yellow_zone());
tracing::debug!(
"demoting yellow node {:?} from {} to red at {}",
self.entries[yellow_index],
yellow_index,
red_index,
);
self.entries.swap(yellow_index, red_index);
self.entries[red_index].lru_index().store(red_index);
// Now move ourselves up into the green zone.
self.promote_yellow_to_green(node, yellow_index);
}
/// Promotes the node `node`, stored at `yellow_index` (in the
/// yellow zone), into a green index, demoting a green node at
/// random to replace it.
///
/// NB: It is not required that `node.lru_index()` is up-to-date
/// when entering this method.
fn promote_yellow_to_green(&mut self, node: &Arc<Node>, yellow_index: usize) {
debug_assert!(self.yellow_zone().contains(&yellow_index));
// Pick a yellow at random and switch places with it.
let green_index = self.pick_index(self.green_zone());
tracing::debug!(
"demoting green node {:?} from {} to yellow at {}",
self.entries[green_index],
green_index,
yellow_index
);
self.entries.swap(green_index, yellow_index);
self.entries[yellow_index].lru_index().store(yellow_index);
node.lru_index().store(green_index);
tracing::debug!("promoted {:?} to green index {}", node, green_index);
}
fn pick_index(&mut self, zone: std::ops::Range<usize>) -> usize {
let end_index = std::cmp::min(zone.end, self.entries.len());
self.rng.rand_range(zone.start as u64..end_index as u64) as usize
}
}
impl Default for LruIndex {
fn default() -> Self {
Self { index: AtomicUsize::new(std::usize::MAX) }
}
}
impl LruIndex {
fn load(&self) -> usize {
self.index.load(Ordering::Acquire) // see note on ordering below
}
fn store(&self, value: usize) {
self.index.store(value, Ordering::Release) // see note on ordering below
}
fn clear(&self) {
self.store(std::usize::MAX);
}
fn is_in_lru(&self) -> bool {
self.load() != std::usize::MAX
}
}
fn rng_with_seed(seed_str: &str) -> Rand64 {
let mut seed: [u8; 16] = [0; 16];
for (i, &b) in seed_str.as_bytes().iter().take(16).enumerate() {
seed[i] = b;
}
Rand64::new(u128::from_le_bytes(seed))
}
// A note on ordering:
//
// I chose to use AcqRel for the ordering but I don't think it's
// strictly needed. All writes occur under a lock, so they should be
// ordered w/r/t one another. As for the reads, they can occur
// outside the lock, but they don't themselves enable dependent reads
// -- if the reads are out of bounds, we would acquire a lock.

View file

@ -0,0 +1,238 @@
//!
#![allow(missing_docs)]
use crate::debug::TableEntry;
use crate::durability::Durability;
use crate::Cycle;
use crate::Database;
use crate::Query;
use crate::QueryTable;
use crate::QueryTableMut;
use std::borrow::Borrow;
use std::fmt::Debug;
use std::hash::Hash;
use triomphe::Arc;
pub use crate::derived::DependencyStorage;
pub use crate::derived::MemoizedStorage;
pub use crate::input::InputStorage;
pub use crate::interned::InternedStorage;
pub use crate::interned::LookupInternedStorage;
pub use crate::{revision::Revision, DatabaseKeyIndex, QueryDb, Runtime};
/// Defines various associated types. An impl of this
/// should be generated for your query-context type automatically by
/// the `database_storage` macro, so you shouldn't need to mess
/// with this trait directly.
pub trait DatabaseStorageTypes: Database {
/// Defines the "storage type", where all the query data is kept.
/// This type is defined by the `database_storage` macro.
type DatabaseStorage: Default;
}
/// Internal operations that the runtime uses to operate on the database.
pub trait DatabaseOps {
/// Upcast this type to a `dyn Database`.
fn ops_database(&self) -> &dyn Database;
/// Gives access to the underlying salsa runtime.
fn ops_salsa_runtime(&self) -> &Runtime;
/// Gives access to the underlying salsa runtime.
fn ops_salsa_runtime_mut(&mut self) -> &mut Runtime;
/// Formats a database key index in a human readable fashion.
fn fmt_index(
&self,
index: DatabaseKeyIndex,
fmt: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result;
/// True if the computed value for `input` may have changed since `revision`.
fn maybe_changed_after(&self, input: DatabaseKeyIndex, revision: Revision) -> bool;
/// Find the `CycleRecoveryStrategy` for a given input.
fn cycle_recovery_strategy(&self, input: DatabaseKeyIndex) -> CycleRecoveryStrategy;
/// Executes the callback for each kind of query.
fn for_each_query(&self, op: &mut dyn FnMut(&dyn QueryStorageMassOps));
}
/// Internal operations performed on the query storage as a whole
/// (note that these ops do not need to know the identity of the
/// query, unlike `QueryStorageOps`).
pub trait QueryStorageMassOps {
fn purge(&self);
}
pub trait DatabaseKey: Clone + Debug + Eq + Hash {}
pub trait QueryFunction: Query {
/// See `CycleRecoveryStrategy`
const CYCLE_STRATEGY: CycleRecoveryStrategy;
fn execute(db: &<Self as QueryDb<'_>>::DynDb, key: Self::Key) -> Self::Value;
fn cycle_fallback(
db: &<Self as QueryDb<'_>>::DynDb,
cycle: &Cycle,
key: &Self::Key,
) -> Self::Value {
let _ = (db, cycle, key);
panic!("query `{:?}` doesn't support cycle fallback", Self::default())
}
}
/// Cycle recovery strategy: Is this query capable of recovering from
/// a cycle that results from executing the function? If so, how?
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum CycleRecoveryStrategy {
/// Cannot recover from cycles: panic.
///
/// This is the default. It is also what happens if a cycle
/// occurs and the queries involved have different recovery
/// strategies.
///
/// In the case of a failure due to a cycle, the panic
/// value will be XXX (FIXME).
Panic,
/// Recovers from cycles by storing a sentinel value.
///
/// This value is computed by the `QueryFunction::cycle_fallback`
/// function.
Fallback,
}
/// Create a query table, which has access to the storage for the query
/// and offers methods like `get`.
pub fn get_query_table<'me, Q>(db: &'me <Q as QueryDb<'me>>::DynDb) -> QueryTable<'me, Q>
where
Q: Query + 'me,
Q::Storage: QueryStorageOps<Q>,
{
let group_storage: &Q::GroupStorage = HasQueryGroup::group_storage(db);
let query_storage: &Q::Storage = Q::query_storage(group_storage);
QueryTable::new(db, query_storage)
}
/// Create a mutable query table, which has access to the storage
/// for the query and offers methods like `set`.
pub fn get_query_table_mut<'me, Q>(db: &'me mut <Q as QueryDb<'me>>::DynDb) -> QueryTableMut<'me, Q>
where
Q: Query,
{
let (group_storage, runtime) = HasQueryGroup::group_storage_mut(db);
let query_storage = Q::query_storage_mut(group_storage);
QueryTableMut::new(runtime, &**query_storage)
}
pub trait QueryGroup: Sized {
type GroupStorage;
/// Dyn version of the associated database trait.
type DynDb: ?Sized + Database + HasQueryGroup<Self>;
}
/// Trait implemented by a database for each group that it supports.
/// `S` and `K` are the types for *group storage* and *group key*, respectively.
pub trait HasQueryGroup<G>: Database
where
G: QueryGroup,
{
/// Access the group storage struct from the database.
fn group_storage(&self) -> &G::GroupStorage;
/// Access the group storage struct from the database.
/// Also returns a ref to the `Runtime`, since otherwise
/// the database is borrowed and one cannot get access to it.
fn group_storage_mut(&mut self) -> (&G::GroupStorage, &mut Runtime);
}
// ANCHOR:QueryStorageOps
pub trait QueryStorageOps<Q>
where
Self: QueryStorageMassOps,
Q: Query,
{
// ANCHOR_END:QueryStorageOps
/// See CycleRecoveryStrategy
const CYCLE_STRATEGY: CycleRecoveryStrategy;
fn new(group_index: u16) -> Self;
/// Format a database key index in a suitable way.
fn fmt_index(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
index: DatabaseKeyIndex,
fmt: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result;
// ANCHOR:maybe_changed_after
/// True if the value of `input`, which must be from this query, may have
/// changed after the given revision ended.
///
/// This function should only be invoked with a revision less than the current
/// revision.
fn maybe_changed_after(
&self,
db: &<Q as QueryDb<'_>>::DynDb,
input: DatabaseKeyIndex,
revision: Revision,
) -> bool;
// ANCHOR_END:maybe_changed_after
fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy {
Self::CYCLE_STRATEGY
}
// ANCHOR:fetch
/// Execute the query, returning the result (often, the result
/// will be memoized). This is the "main method" for
/// queries.
///
/// Returns `Err` in the event of a cycle, meaning that computing
/// the value for this `key` is recursively attempting to fetch
/// itself.
fn fetch(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Q::Value;
// ANCHOR_END:fetch
/// Returns the durability associated with a given key.
fn durability(&self, db: &<Q as QueryDb<'_>>::DynDb, key: &Q::Key) -> Durability;
/// Get the (current) set of the entries in the query storage
fn entries<C>(&self, db: &<Q as QueryDb<'_>>::DynDb) -> C
where
C: std::iter::FromIterator<TableEntry<Q::Key, Q::Value>>;
}
/// An optional trait that is implemented for "user mutable" storage:
/// that is, storage whose value is not derived from other storage but
/// is set independently.
pub trait InputQueryStorageOps<Q>
where
Q: Query,
{
fn set(&self, runtime: &mut Runtime, key: &Q::Key, new_value: Q::Value, durability: Durability);
}
/// An optional trait that is implemented for "user mutable" storage:
/// that is, storage whose value is not derived from other storage but
/// is set independently.
pub trait LruQueryStorageOps {
fn set_lru_capacity(&self, new_capacity: usize);
}
pub trait DerivedQueryStorageOps<Q>
where
Q: Query,
{
fn invalidate<S>(&self, runtime: &mut Runtime, key: &S)
where
S: Eq + Hash,
Q::Key: Borrow<S>;
}
pub type CycleParticipants = Arc<Vec<DatabaseKeyIndex>>;

View file

@ -0,0 +1,67 @@
//!
use std::num::NonZeroU32;
use std::sync::atomic::{AtomicU32, Ordering};
/// Value of the initial revision, as a u32. We don't use 0
/// because we want to use a `NonZeroU32`.
const START: u32 = 1;
/// A unique identifier for the current version of the database; each
/// time an input is changed, the revision number is incremented.
/// `Revision` is used internally to track which values may need to be
/// recomputed, but is not something you should have to interact with
/// directly as a user of salsa.
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Revision {
generation: NonZeroU32,
}
impl Revision {
pub(crate) fn start() -> Self {
Self::from(START)
}
pub(crate) fn from(g: u32) -> Self {
Self { generation: NonZeroU32::new(g).unwrap() }
}
pub(crate) fn next(self) -> Revision {
Self::from(self.generation.get() + 1)
}
fn as_u32(self) -> u32 {
self.generation.get()
}
}
impl std::fmt::Debug for Revision {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(fmt, "R{}", self.generation)
}
}
#[derive(Debug)]
pub(crate) struct AtomicRevision {
data: AtomicU32,
}
impl AtomicRevision {
pub(crate) fn start() -> Self {
Self { data: AtomicU32::new(START) }
}
pub(crate) fn load(&self) -> Revision {
Revision::from(self.data.load(Ordering::SeqCst))
}
pub(crate) fn store(&self, r: Revision) {
self.data.store(r.as_u32(), Ordering::SeqCst);
}
/// Increment by 1, returning previous value.
pub(crate) fn fetch_then_increment(&self) -> Revision {
let v = self.data.fetch_add(1, Ordering::SeqCst);
assert!(v != u32::max_value(), "revision overflow");
Revision::from(v)
}
}

667
crates/salsa/src/runtime.rs Normal file
View file

@ -0,0 +1,667 @@
//!
use crate::durability::Durability;
use crate::hash::FxIndexSet;
use crate::plumbing::CycleRecoveryStrategy;
use crate::revision::{AtomicRevision, Revision};
use crate::{Cancelled, Cycle, Database, DatabaseKeyIndex, Event, EventKind};
use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive};
use parking_lot::{Mutex, RwLock};
use std::hash::Hash;
use std::panic::panic_any;
use std::sync::atomic::{AtomicUsize, Ordering};
use tracing::debug;
use triomphe::Arc;
mod dependency_graph;
use dependency_graph::DependencyGraph;
pub(crate) mod local_state;
use local_state::LocalState;
use self::local_state::{ActiveQueryGuard, QueryInputs, QueryRevisions};
/// The salsa runtime stores the storage for all queries as well as
/// tracking the query stack and dependencies between cycles.
///
/// Each new runtime you create (e.g., via `Runtime::new` or
/// `Runtime::default`) will have an independent set of query storage
/// associated with it. Normally, therefore, you only do this once, at
/// the start of your application.
pub struct Runtime {
/// Our unique runtime id.
id: RuntimeId,
/// If this is a "forked" runtime, then the `revision_guard` will
/// be `Some`; this guard holds a read-lock on the global query
/// lock.
revision_guard: Option<RevisionGuard>,
/// Local state that is specific to this runtime (thread).
local_state: LocalState,
/// Shared state that is accessible via all runtimes.
shared_state: Arc<SharedState>,
}
#[derive(Clone, Debug)]
pub(crate) enum WaitResult {
Completed,
Panicked,
Cycle(Cycle),
}
impl Default for Runtime {
fn default() -> Self {
Runtime {
id: RuntimeId { counter: 0 },
revision_guard: None,
shared_state: Default::default(),
local_state: Default::default(),
}
}
}
impl std::fmt::Debug for Runtime {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fmt.debug_struct("Runtime")
.field("id", &self.id())
.field("forked", &self.revision_guard.is_some())
.field("shared_state", &self.shared_state)
.finish()
}
}
impl Runtime {
/// Create a new runtime; equivalent to `Self::default`. This is
/// used when creating a new database.
pub fn new() -> Self {
Self::default()
}
/// See [`crate::storage::Storage::snapshot`].
pub(crate) fn snapshot(&self) -> Self {
if self.local_state.query_in_progress() {
panic!("it is not legal to `snapshot` during a query (see salsa-rs/salsa#80)");
}
let revision_guard = RevisionGuard::new(&self.shared_state);
let id = RuntimeId { counter: self.shared_state.next_id.fetch_add(1, Ordering::SeqCst) };
Runtime {
id,
revision_guard: Some(revision_guard),
shared_state: self.shared_state.clone(),
local_state: Default::default(),
}
}
/// A "synthetic write" causes the system to act *as though* some
/// input of durability `durability` has changed. This is mostly
/// useful for profiling scenarios.
///
/// **WARNING:** Just like an ordinary write, this method triggers
/// cancellation. If you invoke it while a snapshot exists, it
/// will block until that snapshot is dropped -- if that snapshot
/// is owned by the current thread, this could trigger deadlock.
pub fn synthetic_write(&mut self, durability: Durability) {
self.with_incremented_revision(|_next_revision| Some(durability));
}
/// The unique identifier attached to this `SalsaRuntime`. Each
/// snapshotted runtime has a distinct identifier.
#[inline]
pub fn id(&self) -> RuntimeId {
self.id
}
/// Returns the database-key for the query that this thread is
/// actively executing (if any).
pub fn active_query(&self) -> Option<DatabaseKeyIndex> {
self.local_state.active_query()
}
/// Read current value of the revision counter.
#[inline]
pub(crate) fn current_revision(&self) -> Revision {
self.shared_state.revisions[0].load()
}
/// The revision in which values with durability `d` may have last
/// changed. For D0, this is just the current revision. But for
/// higher levels of durability, this value may lag behind the
/// current revision. If we encounter a value of durability Di,
/// then, we can check this function to get a "bound" on when the
/// value may have changed, which allows us to skip walking its
/// dependencies.
#[inline]
pub(crate) fn last_changed_revision(&self, d: Durability) -> Revision {
self.shared_state.revisions[d.index()].load()
}
/// Read current value of the revision counter.
#[inline]
pub(crate) fn pending_revision(&self) -> Revision {
self.shared_state.pending_revision.load()
}
#[cold]
pub(crate) fn unwind_cancelled(&self) {
self.report_untracked_read();
Cancelled::PendingWrite.throw();
}
/// Acquires the **global query write lock** (ensuring that no queries are
/// executing) and then increments the current revision counter; invokes
/// `op` with the global query write lock still held.
///
/// While we wait to acquire the global query write lock, this method will
/// also increment `pending_revision_increments`, thus signalling to queries
/// that their results are "cancelled" and they should abort as expeditiously
/// as possible.
///
/// The `op` closure should actually perform the writes needed. It is given
/// the new revision as an argument, and its return value indicates whether
/// any pre-existing value was modified:
///
/// - returning `None` means that no pre-existing value was modified (this
/// could occur e.g. when setting some key on an input that was never set
/// before)
/// - returning `Some(d)` indicates that a pre-existing value was modified
/// and it had the durability `d`. This will update the records for when
/// values with each durability were modified.
///
/// Note that, given our writer model, we can assume that only one thread is
/// attempting to increment the global revision at a time.
pub(crate) fn with_incremented_revision<F>(&mut self, op: F)
where
F: FnOnce(Revision) -> Option<Durability>,
{
tracing::debug!("increment_revision()");
if !self.permits_increment() {
panic!("increment_revision invoked during a query computation");
}
// Set the `pending_revision` field so that people
// know current revision is cancelled.
let current_revision = self.shared_state.pending_revision.fetch_then_increment();
// To modify the revision, we need the lock.
let shared_state = self.shared_state.clone();
let _lock = shared_state.query_lock.write();
let old_revision = self.shared_state.revisions[0].fetch_then_increment();
assert_eq!(current_revision, old_revision);
let new_revision = current_revision.next();
debug!("increment_revision: incremented to {:?}", new_revision);
if let Some(d) = op(new_revision) {
for rev in &self.shared_state.revisions[1..=d.index()] {
rev.store(new_revision);
}
}
}
pub(crate) fn permits_increment(&self) -> bool {
self.revision_guard.is_none() && !self.local_state.query_in_progress()
}
#[inline]
pub(crate) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> {
self.local_state.push_query(database_key_index)
}
/// Reports that the currently active query read the result from
/// another query.
///
/// Also checks whether the "cycle participant" flag is set on
/// the current stack frame -- if so, panics with `CycleParticipant`
/// value, which should be caught by the code executing the query.
///
/// # Parameters
///
/// - `database_key`: the query whose result was read
/// - `changed_revision`: the last revision in which the result of that
/// query had changed
pub(crate) fn report_query_read_and_unwind_if_cycle_resulted(
&self,
input: DatabaseKeyIndex,
durability: Durability,
changed_at: Revision,
) {
self.local_state
.report_query_read_and_unwind_if_cycle_resulted(input, durability, changed_at);
}
/// Reports that the query depends on some state unknown to salsa.
///
/// Queries which report untracked reads will be re-executed in the next
/// revision.
pub fn report_untracked_read(&self) {
self.local_state.report_untracked_read(self.current_revision());
}
/// Acts as though the current query had read an input with the given durability; this will force the current query's durability to be at most `durability`.
///
/// This is mostly useful to control the durability level for [on-demand inputs](https://salsa-rs.github.io/salsa/common_patterns/on_demand_inputs.html).
pub fn report_synthetic_read(&self, durability: Durability) {
let changed_at = self.last_changed_revision(durability);
self.local_state.report_synthetic_read(durability, changed_at);
}
/// Handles a cycle in the dependency graph that was detected when the
/// current thread tried to block on `database_key_index` which is being
/// executed by `to_id`. If this function returns, then `to_id` no longer
/// depends on the current thread, and so we should continue executing
/// as normal. Otherwise, the function will throw a `Cycle` which is expected
/// to be caught by some frame on our stack. This occurs either if there is
/// a frame on our stack with cycle recovery (possibly the top one!) or if there
/// is no cycle recovery at all.
fn unblock_cycle_and_maybe_throw(
&self,
db: &dyn Database,
dg: &mut DependencyGraph,
database_key_index: DatabaseKeyIndex,
to_id: RuntimeId,
) {
debug!("unblock_cycle_and_maybe_throw(database_key={:?})", database_key_index);
let mut from_stack = self.local_state.take_query_stack();
let from_id = self.id();
// Make a "dummy stack frame". As we iterate through the cycle, we will collect the
// inputs from each participant. Then, if we are participating in cycle recovery, we
// will propagate those results to all participants.
let mut cycle_query = ActiveQuery::new(database_key_index);
// Identify the cycle participants:
let cycle = {
let mut v = vec![];
dg.for_each_cycle_participant(
from_id,
&mut from_stack,
database_key_index,
to_id,
|aqs| {
aqs.iter_mut().for_each(|aq| {
cycle_query.add_from(aq);
v.push(aq.database_key_index);
});
},
);
// We want to give the participants in a deterministic order
// (at least for this execution, not necessarily across executions),
// no matter where it started on the stack. Find the minimum
// key and rotate it to the front.
let min = v.iter().min().unwrap();
let index = v.iter().position(|p| p == min).unwrap();
v.rotate_left(index);
// No need to store extra memory.
v.shrink_to_fit();
Cycle::new(Arc::new(v))
};
debug!("cycle {:?}, cycle_query {:#?}", cycle.debug(db), cycle_query,);
// We can remove the cycle participants from the list of dependencies;
// they are a strongly connected component (SCC) and we only care about
// dependencies to things outside the SCC that control whether it will
// form again.
cycle_query.remove_cycle_participants(&cycle);
// Mark each cycle participant that has recovery set, along with
// any frames that come after them on the same thread. Those frames
// are going to be unwound so that fallback can occur.
dg.for_each_cycle_participant(from_id, &mut from_stack, database_key_index, to_id, |aqs| {
aqs.iter_mut()
.skip_while(|aq| match db.cycle_recovery_strategy(aq.database_key_index) {
CycleRecoveryStrategy::Panic => true,
CycleRecoveryStrategy::Fallback => false,
})
.for_each(|aq| {
debug!("marking {:?} for fallback", aq.database_key_index.debug(db));
aq.take_inputs_from(&cycle_query);
assert!(aq.cycle.is_none());
aq.cycle = Some(cycle.clone());
});
});
// Unblock every thread that has cycle recovery with a `WaitResult::Cycle`.
// They will throw the cycle, which will be caught by the frame that has
// cycle recovery so that it can execute that recovery.
let (me_recovered, others_recovered) =
dg.maybe_unblock_runtimes_in_cycle(from_id, &from_stack, database_key_index, to_id);
self.local_state.restore_query_stack(from_stack);
if me_recovered {
// If the current thread has recovery, we want to throw
// so that it can begin.
cycle.throw()
} else if others_recovered {
// If other threads have recovery but we didn't: return and we will block on them.
} else {
// if nobody has recover, then we panic
panic_any(cycle);
}
}
/// Block until `other_id` completes executing `database_key`;
/// panic or unwind in the case of a cycle.
///
/// `query_mutex_guard` is the guard for the current query's state;
/// it will be dropped after we have successfully registered the
/// dependency.
///
/// # Propagating panics
///
/// If the thread `other_id` panics, then our thread is considered
/// cancelled, so this function will panic with a `Cancelled` value.
///
/// # Cycle handling
///
/// If the thread `other_id` already depends on the current thread,
/// and hence there is a cycle in the query graph, then this function
/// will unwind instead of returning normally. The method of unwinding
/// depends on the [`Self::mutual_cycle_recovery_strategy`]
/// of the cycle participants:
///
/// * [`CycleRecoveryStrategy::Panic`]: panic with the [`Cycle`] as the value.
/// * [`CycleRecoveryStrategy::Fallback`]: initiate unwinding with [`CycleParticipant::unwind`].
pub(crate) fn block_on_or_unwind<QueryMutexGuard>(
&self,
db: &dyn Database,
database_key: DatabaseKeyIndex,
other_id: RuntimeId,
query_mutex_guard: QueryMutexGuard,
) {
let mut dg = self.shared_state.dependency_graph.lock();
if dg.depends_on(other_id, self.id()) {
self.unblock_cycle_and_maybe_throw(db, &mut dg, database_key, other_id);
// If the above fn returns, then (via cycle recovery) it has unblocked the
// cycle, so we can continue.
assert!(!dg.depends_on(other_id, self.id()));
}
db.salsa_event(Event {
runtime_id: self.id(),
kind: EventKind::WillBlockOn { other_runtime_id: other_id, database_key },
});
let stack = self.local_state.take_query_stack();
let (stack, result) = DependencyGraph::block_on(
dg,
self.id(),
database_key,
other_id,
stack,
query_mutex_guard,
);
self.local_state.restore_query_stack(stack);
match result {
WaitResult::Completed => (),
// If the other thread panicked, then we consider this thread
// cancelled. The assumption is that the panic will be detected
// by the other thread and responded to appropriately.
WaitResult::Panicked => Cancelled::PropagatedPanic.throw(),
WaitResult::Cycle(c) => c.throw(),
}
}
/// Invoked when this runtime completed computing `database_key` with
/// the given result `wait_result` (`wait_result` should be `None` if
/// computing `database_key` panicked and could not complete).
/// This function unblocks any dependent queries and allows them
/// to continue executing.
pub(crate) fn unblock_queries_blocked_on(
&self,
database_key: DatabaseKeyIndex,
wait_result: WaitResult,
) {
self.shared_state
.dependency_graph
.lock()
.unblock_runtimes_blocked_on(database_key, wait_result);
}
}
/// State that will be common to all threads (when we support multiple threads)
struct SharedState {
/// Stores the next id to use for a snapshotted runtime (starts at 1).
next_id: AtomicUsize,
/// Whenever derived queries are executing, they acquire this lock
/// in read mode. Mutating inputs (and thus creating a new
/// revision) requires a write lock (thus guaranteeing that no
/// derived queries are in progress). Note that this is not needed
/// to prevent **race conditions** -- the revision counter itself
/// is stored in an `AtomicUsize` so it can be cheaply read
/// without acquiring the lock. Rather, the `query_lock` is used
/// to ensure a higher-level consistency property.
query_lock: RwLock<()>,
/// This is typically equal to `revision` -- set to `revision+1`
/// when a new revision is pending (which implies that the current
/// revision is cancelled).
pending_revision: AtomicRevision,
/// Stores the "last change" revision for values of each duration.
/// This vector is always of length at least 1 (for Durability 0)
/// but its total length depends on the number of durations. The
/// element at index 0 is special as it represents the "current
/// revision". In general, we have the invariant that revisions
/// in here are *declining* -- that is, `revisions[i] >=
/// revisions[i + 1]`, for all `i`. This is because when you
/// modify a value with durability D, that implies that values
/// with durability less than D may have changed too.
revisions: Vec<AtomicRevision>,
/// The dependency graph tracks which runtimes are blocked on one
/// another, waiting for queries to terminate.
dependency_graph: Mutex<DependencyGraph>,
}
impl SharedState {
fn with_durabilities(durabilities: usize) -> Self {
SharedState {
next_id: AtomicUsize::new(1),
query_lock: Default::default(),
revisions: (0..durabilities).map(|_| AtomicRevision::start()).collect(),
pending_revision: AtomicRevision::start(),
dependency_graph: Default::default(),
}
}
}
impl std::panic::RefUnwindSafe for SharedState {}
impl Default for SharedState {
fn default() -> Self {
Self::with_durabilities(Durability::LEN)
}
}
impl std::fmt::Debug for SharedState {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let query_lock = if self.query_lock.try_write().is_some() {
"<unlocked>"
} else if self.query_lock.try_read().is_some() {
"<rlocked>"
} else {
"<wlocked>"
};
fmt.debug_struct("SharedState")
.field("query_lock", &query_lock)
.field("revisions", &self.revisions)
.field("pending_revision", &self.pending_revision)
.finish()
}
}
#[derive(Debug)]
struct ActiveQuery {
/// What query is executing
database_key_index: DatabaseKeyIndex,
/// Minimum durability of inputs observed so far.
durability: Durability,
/// Maximum revision of all inputs observed. If we observe an
/// untracked read, this will be set to the most recent revision.
changed_at: Revision,
/// Set of subqueries that were accessed thus far, or `None` if
/// there was an untracked the read.
dependencies: Option<FxIndexSet<DatabaseKeyIndex>>,
/// Stores the entire cycle, if one is found and this query is part of it.
cycle: Option<Cycle>,
}
impl ActiveQuery {
fn new(database_key_index: DatabaseKeyIndex) -> Self {
ActiveQuery {
database_key_index,
durability: Durability::MAX,
changed_at: Revision::start(),
dependencies: Some(FxIndexSet::default()),
cycle: None,
}
}
fn add_read(&mut self, input: DatabaseKeyIndex, durability: Durability, revision: Revision) {
if let Some(set) = &mut self.dependencies {
set.insert(input);
}
self.durability = self.durability.min(durability);
self.changed_at = self.changed_at.max(revision);
}
fn add_untracked_read(&mut self, changed_at: Revision) {
self.dependencies = None;
self.durability = Durability::LOW;
self.changed_at = changed_at;
}
fn add_synthetic_read(&mut self, durability: Durability, revision: Revision) {
self.dependencies = None;
self.durability = self.durability.min(durability);
self.changed_at = self.changed_at.max(revision);
}
pub(crate) fn revisions(&self) -> QueryRevisions {
let inputs = match &self.dependencies {
None => QueryInputs::Untracked,
Some(dependencies) => {
if dependencies.is_empty() {
QueryInputs::NoInputs
} else {
QueryInputs::Tracked { inputs: dependencies.iter().copied().collect() }
}
}
};
QueryRevisions { changed_at: self.changed_at, inputs, durability: self.durability }
}
/// Adds any dependencies from `other` into `self`.
/// Used during cycle recovery, see [`Runtime::create_cycle_error`].
fn add_from(&mut self, other: &ActiveQuery) {
self.changed_at = self.changed_at.max(other.changed_at);
self.durability = self.durability.min(other.durability);
if let Some(other_dependencies) = &other.dependencies {
if let Some(my_dependencies) = &mut self.dependencies {
my_dependencies.extend(other_dependencies.iter().copied());
}
} else {
self.dependencies = None;
}
}
/// Removes the participants in `cycle` from my dependencies.
/// Used during cycle recovery, see [`Runtime::create_cycle_error`].
fn remove_cycle_participants(&mut self, cycle: &Cycle) {
if let Some(my_dependencies) = &mut self.dependencies {
for p in cycle.participant_keys() {
my_dependencies.remove(&p);
}
}
}
/// Copy the changed-at, durability, and dependencies from `cycle_query`.
/// Used during cycle recovery, see [`Runtime::create_cycle_error`].
pub(crate) fn take_inputs_from(&mut self, cycle_query: &ActiveQuery) {
self.changed_at = cycle_query.changed_at;
self.durability = cycle_query.durability;
self.dependencies = cycle_query.dependencies.clone();
}
}
/// A unique identifier for a particular runtime. Each time you create
/// a snapshot, a fresh `RuntimeId` is generated. Once a snapshot is
/// complete, its `RuntimeId` may potentially be re-used.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RuntimeId {
counter: usize,
}
#[derive(Clone, Debug)]
pub(crate) struct StampedValue<V> {
pub(crate) value: V,
pub(crate) durability: Durability,
pub(crate) changed_at: Revision,
}
struct RevisionGuard {
shared_state: Arc<SharedState>,
}
impl RevisionGuard {
fn new(shared_state: &Arc<SharedState>) -> Self {
// Subtle: we use a "recursive" lock here so that it is not an
// error to acquire a read-lock when one is already held (this
// happens when a query uses `snapshot` to spawn off parallel
// workers, for example).
//
// This has the side-effect that we are responsible to ensure
// that people contending for the write lock do not starve,
// but this is what we achieve via the cancellation mechanism.
//
// (In particular, since we only ever have one "mutating
// handle" to the database, the only contention for the global
// query lock occurs when there are "futures" evaluating
// queries in parallel, and those futures hold a read-lock
// already, so the starvation problem is more about them bring
// themselves to a close, versus preventing other people from
// *starting* work).
unsafe {
shared_state.query_lock.raw().lock_shared_recursive();
}
Self { shared_state: shared_state.clone() }
}
}
impl Drop for RevisionGuard {
fn drop(&mut self) {
// Release our read-lock without using RAII. As documented in
// `Snapshot::new` above, this requires the unsafe keyword.
unsafe {
self.shared_state.query_lock.raw().unlock_shared();
}
}
}

View file

@ -0,0 +1,251 @@
//!
use triomphe::Arc;
use crate::{DatabaseKeyIndex, RuntimeId};
use parking_lot::{Condvar, MutexGuard};
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use super::{ActiveQuery, WaitResult};
type QueryStack = Vec<ActiveQuery>;
#[derive(Debug, Default)]
pub(super) struct DependencyGraph {
/// A `(K -> V)` pair in this map indicates that the the runtime
/// `K` is blocked on some query executing in the runtime `V`.
/// This encodes a graph that must be acyclic (or else deadlock
/// will result).
edges: FxHashMap<RuntimeId, Edge>,
/// Encodes the `RuntimeId` that are blocked waiting for the result
/// of a given query.
query_dependents: FxHashMap<DatabaseKeyIndex, SmallVec<[RuntimeId; 4]>>,
/// When a key K completes which had dependent queries Qs blocked on it,
/// it stores its `WaitResult` here. As they wake up, each query Q in Qs will
/// come here to fetch their results.
wait_results: FxHashMap<RuntimeId, (QueryStack, WaitResult)>,
}
#[derive(Debug)]
struct Edge {
blocked_on_id: RuntimeId,
blocked_on_key: DatabaseKeyIndex,
stack: QueryStack,
/// Signalled whenever a query with dependents completes.
/// Allows those dependents to check if they are ready to unblock.
condvar: Arc<parking_lot::Condvar>,
}
impl DependencyGraph {
/// True if `from_id` depends on `to_id`.
///
/// (i.e., there is a path from `from_id` to `to_id` in the graph.)
pub(super) fn depends_on(&mut self, from_id: RuntimeId, to_id: RuntimeId) -> bool {
let mut p = from_id;
while let Some(q) = self.edges.get(&p).map(|edge| edge.blocked_on_id) {
if q == to_id {
return true;
}
p = q;
}
p == to_id
}
/// Invokes `closure` with a `&mut ActiveQuery` for each query that participates in the cycle.
/// The cycle runs as follows:
///
/// 1. The runtime `from_id`, which has the stack `from_stack`, would like to invoke `database_key`...
/// 2. ...but `database_key` is already being executed by `to_id`...
/// 3. ...and `to_id` is transitively dependent on something which is present on `from_stack`.
pub(super) fn for_each_cycle_participant(
&mut self,
from_id: RuntimeId,
from_stack: &mut QueryStack,
database_key: DatabaseKeyIndex,
to_id: RuntimeId,
mut closure: impl FnMut(&mut [ActiveQuery]),
) {
debug_assert!(self.depends_on(to_id, from_id));
// To understand this algorithm, consider this [drawing](https://is.gd/TGLI9v):
//
// database_key = QB2
// from_id = A
// to_id = B
// from_stack = [QA1, QA2, QA3]
//
// self.edges[B] = { C, QC2, [QB1..QB3] }
// self.edges[C] = { A, QA2, [QC1..QC3] }
//
// The cyclic
// edge we have
// failed to add.
// :
// A : B C
// :
// QA1 v QB1 QC1
// ┌► QA2 ┌──► QB2 ┌─► QC2
// │ QA3 ───┘ QB3 ──┘ QC3 ───┐
// │ │
// └───────────────────────────────┘
//
// Final output: [QB2, QB3, QC2, QC3, QA2, QA3]
let mut id = to_id;
let mut key = database_key;
while id != from_id {
// Looking at the diagram above, the idea is to
// take the edge from `to_id` starting at `key`
// (inclusive) and down to the end. We can then
// load up the next thread (i.e., we start at B/QB2,
// and then load up the dependency on C/QC2).
let edge = self.edges.get_mut(&id).unwrap();
let prefix = edge.stack.iter_mut().take_while(|p| p.database_key_index != key).count();
closure(&mut edge.stack[prefix..]);
id = edge.blocked_on_id;
key = edge.blocked_on_key;
}
// Finally, we copy in the results from `from_stack`.
let prefix = from_stack.iter_mut().take_while(|p| p.database_key_index != key).count();
closure(&mut from_stack[prefix..]);
}
/// Unblock each blocked runtime (excluding the current one) if some
/// query executing in that runtime is participating in cycle fallback.
///
/// Returns a boolean (Current, Others) where:
/// * Current is true if the current runtime has cycle participants
/// with fallback;
/// * Others is true if other runtimes were unblocked.
pub(super) fn maybe_unblock_runtimes_in_cycle(
&mut self,
from_id: RuntimeId,
from_stack: &QueryStack,
database_key: DatabaseKeyIndex,
to_id: RuntimeId,
) -> (bool, bool) {
// See diagram in `for_each_cycle_participant`.
let mut id = to_id;
let mut key = database_key;
let mut others_unblocked = false;
while id != from_id {
let edge = self.edges.get(&id).unwrap();
let prefix = edge.stack.iter().take_while(|p| p.database_key_index != key).count();
let next_id = edge.blocked_on_id;
let next_key = edge.blocked_on_key;
if let Some(cycle) = edge.stack[prefix..].iter().rev().find_map(|aq| aq.cycle.clone()) {
// Remove `id` from the list of runtimes blocked on `next_key`:
self.query_dependents.get_mut(&next_key).unwrap().retain(|r| *r != id);
// Unblock runtime so that it can resume execution once lock is released:
self.unblock_runtime(id, WaitResult::Cycle(cycle));
others_unblocked = true;
}
id = next_id;
key = next_key;
}
let prefix = from_stack.iter().take_while(|p| p.database_key_index != key).count();
let this_unblocked = from_stack[prefix..].iter().any(|aq| aq.cycle.is_some());
(this_unblocked, others_unblocked)
}
/// Modifies the graph so that `from_id` is blocked
/// on `database_key`, which is being computed by
/// `to_id`.
///
/// For this to be reasonable, the lock on the
/// results table for `database_key` must be held.
/// This ensures that computing `database_key` doesn't
/// complete before `block_on` executes.
///
/// Preconditions:
/// * No path from `to_id` to `from_id`
/// (i.e., `me.depends_on(to_id, from_id)` is false)
/// * `held_mutex` is a read lock (or stronger) on `database_key`
pub(super) fn block_on<QueryMutexGuard>(
mut me: MutexGuard<'_, Self>,
from_id: RuntimeId,
database_key: DatabaseKeyIndex,
to_id: RuntimeId,
from_stack: QueryStack,
query_mutex_guard: QueryMutexGuard,
) -> (QueryStack, WaitResult) {
let condvar = me.add_edge(from_id, database_key, to_id, from_stack);
// Release the mutex that prevents `database_key`
// from completing, now that the edge has been added.
drop(query_mutex_guard);
loop {
if let Some(stack_and_result) = me.wait_results.remove(&from_id) {
debug_assert!(!me.edges.contains_key(&from_id));
return stack_and_result;
}
condvar.wait(&mut me);
}
}
/// Helper for `block_on`: performs actual graph modification
/// to add a dependency edge from `from_id` to `to_id`, which is
/// computing `database_key`.
fn add_edge(
&mut self,
from_id: RuntimeId,
database_key: DatabaseKeyIndex,
to_id: RuntimeId,
from_stack: QueryStack,
) -> Arc<parking_lot::Condvar> {
assert_ne!(from_id, to_id);
debug_assert!(!self.edges.contains_key(&from_id));
debug_assert!(!self.depends_on(to_id, from_id));
let condvar = Arc::new(Condvar::new());
self.edges.insert(
from_id,
Edge {
blocked_on_id: to_id,
blocked_on_key: database_key,
stack: from_stack,
condvar: condvar.clone(),
},
);
self.query_dependents.entry(database_key).or_default().push(from_id);
condvar
}
/// Invoked when runtime `to_id` completes executing
/// `database_key`.
pub(super) fn unblock_runtimes_blocked_on(
&mut self,
database_key: DatabaseKeyIndex,
wait_result: WaitResult,
) {
let dependents = self.query_dependents.remove(&database_key).unwrap_or_default();
for from_id in dependents {
self.unblock_runtime(from_id, wait_result.clone());
}
}
/// Unblock the runtime with the given id with the given wait-result.
/// This will cause it resume execution (though it will have to grab
/// the lock on this data structure first, to recover the wait result).
fn unblock_runtime(&mut self, id: RuntimeId, wait_result: WaitResult) {
let edge = self.edges.remove(&id).expect("not blocked");
self.wait_results.insert(id, (edge.stack, wait_result));
// Now that we have inserted the `wait_results`,
// notify the thread.
edge.condvar.notify_one();
}
}

View file

@ -0,0 +1,214 @@
//!
use tracing::debug;
use crate::durability::Durability;
use crate::runtime::ActiveQuery;
use crate::runtime::Revision;
use crate::Cycle;
use crate::DatabaseKeyIndex;
use std::cell::RefCell;
use triomphe::Arc;
/// State that is specific to a single execution thread.
///
/// Internally, this type uses ref-cells.
///
/// **Note also that all mutations to the database handle (and hence
/// to the local-state) must be undone during unwinding.**
pub(super) struct LocalState {
/// Vector of active queries.
///
/// This is normally `Some`, but it is set to `None`
/// while the query is blocked waiting for a result.
///
/// Unwinding note: pushes onto this vector must be popped -- even
/// during unwinding.
query_stack: RefCell<Option<Vec<ActiveQuery>>>,
}
/// Summarizes "all the inputs that a query used"
#[derive(Debug, Clone)]
pub(crate) struct QueryRevisions {
/// The most revision in which some input changed.
pub(crate) changed_at: Revision,
/// Minimum durability of the inputs to this query.
pub(crate) durability: Durability,
/// The inputs that went into our query, if we are tracking them.
pub(crate) inputs: QueryInputs,
}
/// Every input.
#[derive(Debug, Clone)]
pub(crate) enum QueryInputs {
/// Non-empty set of inputs, fully known
Tracked { inputs: Arc<[DatabaseKeyIndex]> },
/// Empty set of inputs, fully known.
NoInputs,
/// Unknown quantity of inputs
Untracked,
}
impl Default for LocalState {
fn default() -> Self {
LocalState { query_stack: RefCell::new(Some(Vec::new())) }
}
}
impl LocalState {
#[inline]
pub(super) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> {
let mut query_stack = self.query_stack.borrow_mut();
let query_stack = query_stack.as_mut().expect("local stack taken");
query_stack.push(ActiveQuery::new(database_key_index));
ActiveQueryGuard { local_state: self, database_key_index, push_len: query_stack.len() }
}
fn with_query_stack<R>(&self, c: impl FnOnce(&mut Vec<ActiveQuery>) -> R) -> R {
c(self.query_stack.borrow_mut().as_mut().expect("query stack taken"))
}
pub(super) fn query_in_progress(&self) -> bool {
self.with_query_stack(|stack| !stack.is_empty())
}
pub(super) fn active_query(&self) -> Option<DatabaseKeyIndex> {
self.with_query_stack(|stack| {
stack.last().map(|active_query| active_query.database_key_index)
})
}
pub(super) fn report_query_read_and_unwind_if_cycle_resulted(
&self,
input: DatabaseKeyIndex,
durability: Durability,
changed_at: Revision,
) {
debug!(
"report_query_read_and_unwind_if_cycle_resulted(input={:?}, durability={:?}, changed_at={:?})",
input, durability, changed_at
);
self.with_query_stack(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_read(input, durability, changed_at);
// We are a cycle participant:
//
// C0 --> ... --> Ci --> Ci+1 -> ... -> Cn --> C0
// ^ ^
// : |
// This edge -----+ |
// |
// |
// N0
//
// In this case, the value we have just read from `Ci+1`
// is actually the cycle fallback value and not especially
// interesting. We unwind now with `CycleParticipant` to avoid
// executing the rest of our query function. This unwinding
// will be caught and our own fallback value will be used.
//
// Note that `Ci+1` may` have *other* callers who are not
// participants in the cycle (e.g., N0 in the graph above).
// They will not have the `cycle` marker set in their
// stack frames, so they will just read the fallback value
// from `Ci+1` and continue on their merry way.
if let Some(cycle) = &top_query.cycle {
cycle.clone().throw()
}
}
})
}
pub(super) fn report_untracked_read(&self, current_revision: Revision) {
self.with_query_stack(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_untracked_read(current_revision);
}
})
}
/// Update the top query on the stack to act as though it read a value
/// of durability `durability` which changed in `revision`.
pub(super) fn report_synthetic_read(&self, durability: Durability, revision: Revision) {
self.with_query_stack(|stack| {
if let Some(top_query) = stack.last_mut() {
top_query.add_synthetic_read(durability, revision);
}
})
}
/// Takes the query stack and returns it. This is used when
/// the current thread is blocking. The stack must be restored
/// with [`Self::restore_query_stack`] when the thread unblocks.
pub(super) fn take_query_stack(&self) -> Vec<ActiveQuery> {
assert!(self.query_stack.borrow().is_some(), "query stack already taken");
self.query_stack.take().unwrap()
}
/// Restores a query stack taken with [`Self::take_query_stack`] once
/// the thread unblocks.
pub(super) fn restore_query_stack(&self, stack: Vec<ActiveQuery>) {
assert!(self.query_stack.borrow().is_none(), "query stack not taken");
self.query_stack.replace(Some(stack));
}
}
impl std::panic::RefUnwindSafe for LocalState {}
/// When a query is pushed onto the `active_query` stack, this guard
/// is returned to represent its slot. The guard can be used to pop
/// the query from the stack -- in the case of unwinding, the guard's
/// destructor will also remove the query.
pub(crate) struct ActiveQueryGuard<'me> {
local_state: &'me LocalState,
push_len: usize,
database_key_index: DatabaseKeyIndex,
}
impl ActiveQueryGuard<'_> {
fn pop_helper(&self) -> ActiveQuery {
self.local_state.with_query_stack(|stack| {
// Sanity check: pushes and pops should be balanced.
assert_eq!(stack.len(), self.push_len);
debug_assert_eq!(stack.last().unwrap().database_key_index, self.database_key_index);
stack.pop().unwrap()
})
}
/// Invoked when the query has successfully completed execution.
pub(super) fn complete(self) -> ActiveQuery {
let query = self.pop_helper();
std::mem::forget(self);
query
}
/// Pops an active query from the stack. Returns the [`QueryRevisions`]
/// which summarizes the other queries that were accessed during this
/// query's execution.
#[inline]
pub(crate) fn pop(self) -> QueryRevisions {
// Extract accumulated inputs.
let popped_query = self.complete();
// If this frame were a cycle participant, it would have unwound.
assert!(popped_query.cycle.is_none());
popped_query.revisions()
}
/// If the active query is registered as a cycle participant, remove and
/// return that cycle.
pub(crate) fn take_cycle(&self) -> Option<Cycle> {
self.local_state.with_query_stack(|stack| stack.last_mut()?.cycle.take())
}
}
impl Drop for ActiveQueryGuard<'_> {
fn drop(&mut self) {
self.pop_helper();
}
}

View file

@ -0,0 +1,54 @@
//!
use crate::{plumbing::DatabaseStorageTypes, Runtime};
use triomphe::Arc;
/// Stores the cached results and dependency information for all the queries
/// defined on your salsa database. Also embeds a [`Runtime`] which is used to
/// manage query execution. Every database must include a `storage:
/// Storage<Self>` field.
pub struct Storage<DB: DatabaseStorageTypes> {
query_store: Arc<DB::DatabaseStorage>,
runtime: Runtime,
}
impl<DB: DatabaseStorageTypes> Default for Storage<DB> {
fn default() -> Self {
Self { query_store: Default::default(), runtime: Default::default() }
}
}
impl<DB: DatabaseStorageTypes> Storage<DB> {
/// Gives access to the underlying salsa runtime.
pub fn salsa_runtime(&self) -> &Runtime {
&self.runtime
}
/// Gives access to the underlying salsa runtime.
pub fn salsa_runtime_mut(&mut self) -> &mut Runtime {
&mut self.runtime
}
/// Access the query storage tables. Not meant to be used directly by end
/// users.
pub fn query_store(&self) -> &DB::DatabaseStorage {
&self.query_store
}
/// Access the query storage tables. Not meant to be used directly by end
/// users.
pub fn query_store_mut(&mut self) -> (&DB::DatabaseStorage, &mut Runtime) {
(&self.query_store, &mut self.runtime)
}
/// Returns a "snapshotted" storage, suitable for use in a forked database.
/// This snapshot hold a read-lock on the global state, which means that any
/// attempt to `set` an input will block until the forked runtime is
/// dropped. See `ParallelDatabase::snapshot` for more information.
///
/// **Warning.** This second handle is intended to be used from a separate
/// thread. Using two database handles from the **same thread** can lead to
/// deadlock.
pub fn snapshot(&self) -> Self {
Storage { query_store: self.query_store.clone(), runtime: self.runtime.snapshot() }
}
}

View file

@ -0,0 +1,493 @@
use std::panic::UnwindSafe;
use expect_test::expect;
use salsa::{Durability, ParallelDatabase, Snapshot};
use test_log::test;
// Axes:
//
// Threading
// * Intra-thread
// * Cross-thread -- part of cycle is on one thread, part on another
//
// Recovery strategies:
// * Panic
// * Fallback
// * Mixed -- multiple strategies within cycle participants
//
// Across revisions:
// * N/A -- only one revision
// * Present in new revision, not old
// * Present in old revision, not new
// * Present in both revisions
//
// Dependencies
// * Tracked
// * Untracked -- cycle participant(s) contain untracked reads
//
// Layers
// * Direct -- cycle participant is directly invoked from test
// * Indirect -- invoked a query that invokes the cycle
//
//
// | Thread | Recovery | Old, New | Dep style | Layers | Test Name |
// | ------ | -------- | -------- | --------- | ------ | --------- |
// | Intra | Panic | N/A | Tracked | direct | cycle_memoized |
// | Intra | Panic | N/A | Untracked | direct | cycle_volatile |
// | Intra | Fallback | N/A | Tracked | direct | cycle_cycle |
// | Intra | Fallback | N/A | Tracked | indirect | inner_cycle |
// | Intra | Fallback | Both | Tracked | direct | cycle_revalidate |
// | Intra | Fallback | New | Tracked | direct | cycle_appears |
// | Intra | Fallback | Old | Tracked | direct | cycle_disappears |
// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability |
// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 |
// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 |
// | Cross | Fallback | N/A | Tracked | both | parallel/cycles.rs: recover_parallel_cycle |
// | Cross | Panic | N/A | Tracked | both | parallel/cycles.rs: panic_parallel_cycle |
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
struct Error {
cycle: Vec<String>,
}
#[salsa::database(GroupStruct)]
#[derive(Default)]
struct DatabaseImpl {
storage: salsa::Storage<Self>,
}
impl salsa::Database for DatabaseImpl {}
impl ParallelDatabase for DatabaseImpl {
fn snapshot(&self) -> Snapshot<Self> {
Snapshot::new(DatabaseImpl { storage: self.storage.snapshot() })
}
}
/// The queries A, B, and C in `Database` can be configured
/// to invoke one another in arbitrary ways using this
/// enum.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum CycleQuery {
None,
A,
B,
C,
AthenC,
}
#[salsa::query_group(GroupStruct)]
trait Database: salsa::Database {
// `a` and `b` depend on each other and form a cycle
fn memoized_a(&self) -> ();
fn memoized_b(&self) -> ();
fn volatile_a(&self) -> ();
fn volatile_b(&self) -> ();
#[salsa::input]
fn a_invokes(&self) -> CycleQuery;
#[salsa::input]
fn b_invokes(&self) -> CycleQuery;
#[salsa::input]
fn c_invokes(&self) -> CycleQuery;
#[salsa::cycle(recover_a)]
fn cycle_a(&self) -> Result<(), Error>;
#[salsa::cycle(recover_b)]
fn cycle_b(&self) -> Result<(), Error>;
fn cycle_c(&self) -> Result<(), Error>;
}
fn recover_a(db: &dyn Database, cycle: &salsa::Cycle) -> Result<(), Error> {
Err(Error { cycle: cycle.all_participants(db) })
}
fn recover_b(db: &dyn Database, cycle: &salsa::Cycle) -> Result<(), Error> {
Err(Error { cycle: cycle.all_participants(db) })
}
fn memoized_a(db: &dyn Database) {
db.memoized_b()
}
fn memoized_b(db: &dyn Database) {
db.memoized_a()
}
fn volatile_a(db: &dyn Database) {
db.salsa_runtime().report_untracked_read();
db.volatile_b()
}
fn volatile_b(db: &dyn Database) {
db.salsa_runtime().report_untracked_read();
db.volatile_a()
}
impl CycleQuery {
fn invoke(self, db: &dyn Database) -> Result<(), Error> {
match self {
CycleQuery::A => db.cycle_a(),
CycleQuery::B => db.cycle_b(),
CycleQuery::C => db.cycle_c(),
CycleQuery::AthenC => {
let _ = db.cycle_a();
db.cycle_c()
}
CycleQuery::None => Ok(()),
}
}
}
fn cycle_a(db: &dyn Database) -> Result<(), Error> {
db.a_invokes().invoke(db)
}
fn cycle_b(db: &dyn Database) -> Result<(), Error> {
db.b_invokes().invoke(db)
}
fn cycle_c(db: &dyn Database) -> Result<(), Error> {
db.c_invokes().invoke(db)
}
#[track_caller]
fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle {
let v = std::panic::catch_unwind(f);
if let Err(d) = &v {
if let Some(cycle) = d.downcast_ref::<salsa::Cycle>() {
return cycle.clone();
}
}
panic!("unexpected value: {:?}", v)
}
#[test]
fn cycle_memoized() {
let db = DatabaseImpl::default();
let cycle = extract_cycle(|| db.memoized_a());
expect![[r#"
[
"memoized_a(())",
"memoized_b(())",
]
"#]]
.assert_debug_eq(&cycle.unexpected_participants(&db));
}
#[test]
fn cycle_volatile() {
let db = DatabaseImpl::default();
let cycle = extract_cycle(|| db.volatile_a());
expect![[r#"
[
"volatile_a(())",
"volatile_b(())",
]
"#]]
.assert_debug_eq(&cycle.unexpected_participants(&db));
}
#[test]
fn cycle_cycle() {
let mut query = DatabaseImpl::default();
// A --> B
// ^ |
// +-----+
query.set_a_invokes(CycleQuery::B);
query.set_b_invokes(CycleQuery::A);
assert!(query.cycle_a().is_err());
}
#[test]
fn inner_cycle() {
let mut query = DatabaseImpl::default();
// A --> B <-- C
// ^ |
// +-----+
query.set_a_invokes(CycleQuery::B);
query.set_b_invokes(CycleQuery::A);
query.set_c_invokes(CycleQuery::B);
let err = query.cycle_c();
assert!(err.is_err());
let cycle = err.unwrap_err().cycle;
expect![[r#"
[
"cycle_a(())",
"cycle_b(())",
]
"#]]
.assert_debug_eq(&cycle);
}
#[test]
fn cycle_revalidate() {
let mut db = DatabaseImpl::default();
// A --> B
// ^ |
// +-----+
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::A);
assert!(db.cycle_a().is_err());
db.set_b_invokes(CycleQuery::A); // same value as default
assert!(db.cycle_a().is_err());
}
#[test]
fn cycle_revalidate_unchanged_twice() {
let mut db = DatabaseImpl::default();
// A --> B
// ^ |
// +-----+
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::A);
assert!(db.cycle_a().is_err());
db.set_c_invokes(CycleQuery::A); // force new revisi5on
// on this run
expect![[r#"
Err(
Error {
cycle: [
"cycle_a(())",
"cycle_b(())",
],
},
)
"#]]
.assert_debug_eq(&db.cycle_a());
}
#[test]
fn cycle_appears() {
let mut db = DatabaseImpl::default();
// A --> B
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::None);
assert!(db.cycle_a().is_ok());
// A --> B
// ^ |
// +-----+
db.set_b_invokes(CycleQuery::A);
tracing::debug!("Set Cycle Leaf");
assert!(db.cycle_a().is_err());
}
#[test]
fn cycle_disappears() {
let mut db = DatabaseImpl::default();
// A --> B
// ^ |
// +-----+
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::A);
assert!(db.cycle_a().is_err());
// A --> B
db.set_b_invokes(CycleQuery::None);
assert!(db.cycle_a().is_ok());
}
/// A variant on `cycle_disappears` in which the values of
/// `a_invokes` and `b_invokes` are set with durability values.
/// If we are not careful, this could cause us to overlook
/// the fact that the cycle will no longer occur.
#[test]
fn cycle_disappears_durability() {
let mut db = DatabaseImpl::default();
db.set_a_invokes_with_durability(CycleQuery::B, Durability::LOW);
db.set_b_invokes_with_durability(CycleQuery::A, Durability::HIGH);
let res = db.cycle_a();
assert!(res.is_err());
// At this point, `a` read `LOW` input, and `b` read `HIGH` input. However,
// because `b` participates in the same cycle as `a`, its final durability
// should be `LOW`.
//
// Check that setting a `LOW` input causes us to re-execute `b` query, and
// observe that the cycle goes away.
db.set_a_invokes_with_durability(CycleQuery::None, Durability::LOW);
let res = db.cycle_b();
assert!(res.is_ok());
}
#[test]
fn cycle_mixed_1() {
let mut db = DatabaseImpl::default();
// A --> B <-- C
// | ^
// +-----+
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::C);
db.set_c_invokes(CycleQuery::B);
let u = db.cycle_c();
expect![[r#"
Err(
Error {
cycle: [
"cycle_b(())",
"cycle_c(())",
],
},
)
"#]]
.assert_debug_eq(&u);
}
#[test]
fn cycle_mixed_2() {
let mut db = DatabaseImpl::default();
// Configuration:
//
// A --> B --> C
// ^ |
// +-----------+
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::C);
db.set_c_invokes(CycleQuery::A);
let u = db.cycle_a();
expect![[r#"
Err(
Error {
cycle: [
"cycle_a(())",
"cycle_b(())",
"cycle_c(())",
],
},
)
"#]]
.assert_debug_eq(&u);
}
#[test]
fn cycle_deterministic_order() {
// No matter whether we start from A or B, we get the same set of participants:
let db = || {
let mut db = DatabaseImpl::default();
// A --> B
// ^ |
// +-----+
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::A);
db
};
let a = db().cycle_a();
let b = db().cycle_b();
expect![[r#"
(
Err(
Error {
cycle: [
"cycle_a(())",
"cycle_b(())",
],
},
),
Err(
Error {
cycle: [
"cycle_a(())",
"cycle_b(())",
],
},
),
)
"#]]
.assert_debug_eq(&(a, b));
}
#[test]
fn cycle_multiple() {
// No matter whether we start from A or B, we get the same set of participants:
let mut db = DatabaseImpl::default();
// Configuration:
//
// A --> B <-- C
// ^ | ^
// +-----+ |
// | |
// +-----+
//
// Here, conceptually, B encounters a cycle with A and then
// recovers.
db.set_a_invokes(CycleQuery::B);
db.set_b_invokes(CycleQuery::AthenC);
db.set_c_invokes(CycleQuery::B);
let c = db.cycle_c();
let b = db.cycle_b();
let a = db.cycle_a();
expect![[r#"
(
Err(
Error {
cycle: [
"cycle_a(())",
"cycle_b(())",
],
},
),
Err(
Error {
cycle: [
"cycle_a(())",
"cycle_b(())",
],
},
),
Err(
Error {
cycle: [
"cycle_a(())",
"cycle_b(())",
],
},
),
)
"#]]
.assert_debug_eq(&(a, b, c));
}
#[test]
fn cycle_recovery_set_but_not_participating() {
let mut db = DatabaseImpl::default();
// A --> C -+
// ^ |
// +--+
db.set_a_invokes(CycleQuery::C);
db.set_c_invokes(CycleQuery::C);
// Here we expect C to panic and A not to recover:
let r = extract_cycle(|| drop(db.cycle_a()));
expect![[r#"
[
"cycle_c(())",
]
"#]]
.assert_debug_eq(&r.all_participants(&db));
}

View file

@ -0,0 +1,28 @@
//! Test that you can implement a query using a `dyn Trait` setup.
#[salsa::database(DynTraitStorage)]
#[derive(Default)]
struct DynTraitDatabase {
storage: salsa::Storage<Self>,
}
impl salsa::Database for DynTraitDatabase {}
#[salsa::query_group(DynTraitStorage)]
trait DynTrait {
#[salsa::input]
fn input(&self, x: u32) -> u32;
fn output(&self, x: u32) -> u32;
}
fn output(db: &dyn DynTrait, x: u32) -> u32 {
db.input(x) * 2
}
#[test]
fn dyn_trait() {
let mut query = DynTraitDatabase::default();
query.set_input(22, 23);
assert_eq!(query.output(22), 46);
}

View file

@ -0,0 +1,145 @@
use crate::implementation::{TestContext, TestContextImpl};
use salsa::debug::DebugQueryTable;
use salsa::Durability;
#[salsa::query_group(Constants)]
pub(crate) trait ConstantsDatabase: TestContext {
#[salsa::input]
fn input(&self, key: char) -> usize;
fn add(&self, key1: char, key2: char) -> usize;
fn add3(&self, key1: char, key2: char, key3: char) -> usize;
}
fn add(db: &dyn ConstantsDatabase, key1: char, key2: char) -> usize {
db.log().add(format!("add({}, {})", key1, key2));
db.input(key1) + db.input(key2)
}
fn add3(db: &dyn ConstantsDatabase, key1: char, key2: char, key3: char) -> usize {
db.log().add(format!("add3({}, {}, {})", key1, key2, key3));
db.add(key1, key2) + db.input(key3)
}
// Test we can assign a constant and things will be correctly
// recomputed afterwards.
#[test]
fn invalidate_constant() {
let db = &mut TestContextImpl::default();
db.set_input_with_durability('a', 44, Durability::HIGH);
db.set_input_with_durability('b', 22, Durability::HIGH);
assert_eq!(db.add('a', 'b'), 66);
db.set_input_with_durability('a', 66, Durability::HIGH);
assert_eq!(db.add('a', 'b'), 88);
}
#[test]
fn invalidate_constant_1() {
let db = &mut TestContextImpl::default();
// Not constant:
db.set_input('a', 44);
assert_eq!(db.add('a', 'a'), 88);
// Becomes constant:
db.set_input_with_durability('a', 44, Durability::HIGH);
assert_eq!(db.add('a', 'a'), 88);
// Invalidates:
db.set_input_with_durability('a', 33, Durability::HIGH);
assert_eq!(db.add('a', 'a'), 66);
}
// Test cases where we assign same value to 'a' after declaring it a
// constant.
#[test]
fn set_after_constant_same_value() {
let db = &mut TestContextImpl::default();
db.set_input_with_durability('a', 44, Durability::HIGH);
db.set_input_with_durability('a', 44, Durability::HIGH);
db.set_input('a', 44);
}
#[test]
fn not_constant() {
let mut db = TestContextImpl::default();
db.set_input('a', 22);
db.set_input('b', 44);
assert_eq!(db.add('a', 'b'), 66);
assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
}
#[test]
fn durability() {
let mut db = TestContextImpl::default();
db.set_input_with_durability('a', 22, Durability::HIGH);
db.set_input_with_durability('b', 44, Durability::HIGH);
assert_eq!(db.add('a', 'b'), 66);
assert_eq!(Durability::HIGH, AddQuery.in_db(&db).durability(('a', 'b')));
}
#[test]
fn mixed_constant() {
let mut db = TestContextImpl::default();
db.set_input_with_durability('a', 22, Durability::HIGH);
db.set_input('b', 44);
assert_eq!(db.add('a', 'b'), 66);
assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
}
#[test]
fn becomes_constant_with_change() {
let mut db = TestContextImpl::default();
db.set_input('a', 22);
db.set_input('b', 44);
assert_eq!(db.add('a', 'b'), 66);
assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
db.set_input_with_durability('a', 23, Durability::HIGH);
assert_eq!(db.add('a', 'b'), 67);
assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b')));
db.set_input_with_durability('b', 45, Durability::HIGH);
assert_eq!(db.add('a', 'b'), 68);
assert_eq!(Durability::HIGH, AddQuery.in_db(&db).durability(('a', 'b')));
db.set_input_with_durability('b', 45, Durability::MEDIUM);
assert_eq!(db.add('a', 'b'), 68);
assert_eq!(Durability::MEDIUM, AddQuery.in_db(&db).durability(('a', 'b')));
}
// Test a subtle case in which an input changes from constant to
// non-constant, but its value doesn't change. If we're not careful,
// this can cause us to incorrectly consider derived values as still
// being constant.
#[test]
fn constant_to_non_constant() {
let mut db = TestContextImpl::default();
db.set_input_with_durability('a', 11, Durability::HIGH);
db.set_input_with_durability('b', 22, Durability::HIGH);
db.set_input_with_durability('c', 33, Durability::HIGH);
// Here, `add3` invokes `add`, which yields 33. Both calls are
// constant.
assert_eq!(db.add3('a', 'b', 'c'), 66);
db.set_input('a', 11);
// Here, `add3` invokes `add`, which *still* yields 33, but which
// is no longer constant. Since value didn't change, we might
// preserve `add3` unchanged, not noticing that it is no longer
// constant.
assert_eq!(db.add3('a', 'b', 'c'), 66);
// In that case, we would not get the correct result here, when
// 'a' changes *again*.
db.set_input('a', 22);
assert_eq!(db.add3('a', 'b', 'c'), 77);
}

View file

@ -0,0 +1,14 @@
use std::cell::Cell;
#[derive(Default)]
pub(crate) struct Counter {
value: Cell<usize>,
}
impl Counter {
pub(crate) fn increment(&self) -> usize {
let v = self.value.get();
self.value.set(v + 1);
v
}
}

View file

@ -0,0 +1,59 @@
use crate::constants;
use crate::counter::Counter;
use crate::log::Log;
use crate::memoized_dep_inputs;
use crate::memoized_inputs;
use crate::memoized_volatile;
pub(crate) trait TestContext: salsa::Database {
fn clock(&self) -> &Counter;
fn log(&self) -> &Log;
}
#[salsa::database(
constants::Constants,
memoized_dep_inputs::MemoizedDepInputs,
memoized_inputs::MemoizedInputs,
memoized_volatile::MemoizedVolatile
)]
#[derive(Default)]
pub(crate) struct TestContextImpl {
storage: salsa::Storage<TestContextImpl>,
clock: Counter,
log: Log,
}
impl TestContextImpl {
#[track_caller]
pub(crate) fn assert_log(&self, expected_log: &[&str]) {
let expected_text = &format!("{:#?}", expected_log);
let actual_text = &format!("{:#?}", self.log().take());
if expected_text == actual_text {
return;
}
#[allow(clippy::print_stdout)]
for diff in dissimilar::diff(expected_text, actual_text) {
match diff {
dissimilar::Chunk::Delete(l) => println!("-{}", l),
dissimilar::Chunk::Equal(l) => println!(" {}", l),
dissimilar::Chunk::Insert(r) => println!("+{}", r),
}
}
panic!("incorrect log results");
}
}
impl TestContext for TestContextImpl {
fn clock(&self) -> &Counter {
&self.clock
}
fn log(&self) -> &Log {
&self.log
}
}
impl salsa::Database for TestContextImpl {}

View file

@ -0,0 +1,16 @@
use std::cell::RefCell;
#[derive(Default)]
pub(crate) struct Log {
data: RefCell<Vec<String>>,
}
impl Log {
pub(crate) fn add(&self, text: impl Into<String>) {
self.data.borrow_mut().push(text.into());
}
pub(crate) fn take(&self) -> Vec<String> {
self.data.take()
}
}

View file

@ -0,0 +1,9 @@
mod constants;
mod counter;
mod implementation;
mod log;
mod memoized_dep_inputs;
mod memoized_inputs;
mod memoized_volatile;
fn main() {}

View file

@ -0,0 +1,60 @@
use crate::implementation::{TestContext, TestContextImpl};
#[salsa::query_group(MemoizedDepInputs)]
pub(crate) trait MemoizedDepInputsContext: TestContext {
fn dep_memoized2(&self) -> usize;
fn dep_memoized1(&self) -> usize;
#[salsa::dependencies]
fn dep_derived1(&self) -> usize;
#[salsa::input]
fn dep_input1(&self) -> usize;
#[salsa::input]
fn dep_input2(&self) -> usize;
}
fn dep_memoized2(db: &dyn MemoizedDepInputsContext) -> usize {
db.log().add("Memoized2 invoked");
db.dep_memoized1()
}
fn dep_memoized1(db: &dyn MemoizedDepInputsContext) -> usize {
db.log().add("Memoized1 invoked");
db.dep_derived1() * 2
}
fn dep_derived1(db: &dyn MemoizedDepInputsContext) -> usize {
db.log().add("Derived1 invoked");
db.dep_input1() / 2
}
#[test]
fn revalidate() {
let db = &mut TestContextImpl::default();
db.set_dep_input1(0);
// Initial run starts from Memoized2:
let v = db.dep_memoized2();
assert_eq!(v, 0);
db.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Derived1 invoked"]);
// After that, we first try to validate Memoized1 but wind up
// running Memoized2. Note that we don't try to validate
// Derived1, so it is invoked by Memoized1.
db.set_dep_input1(44);
let v = db.dep_memoized2();
assert_eq!(v, 44);
db.assert_log(&["Memoized1 invoked", "Derived1 invoked", "Memoized2 invoked"]);
// Here validation of Memoized1 succeeds so Memoized2 never runs.
db.set_dep_input1(45);
let v = db.dep_memoized2();
assert_eq!(v, 44);
db.assert_log(&["Memoized1 invoked", "Derived1 invoked"]);
// Here, a change to input2 doesn't affect us, so nothing runs.
db.set_dep_input2(45);
let v = db.dep_memoized2();
assert_eq!(v, 44);
db.assert_log(&[]);
}

View file

@ -0,0 +1,76 @@
use crate::implementation::{TestContext, TestContextImpl};
#[salsa::query_group(MemoizedInputs)]
pub(crate) trait MemoizedInputsContext: TestContext {
fn max(&self) -> usize;
#[salsa::input]
fn input1(&self) -> usize;
#[salsa::input]
fn input2(&self) -> usize;
}
fn max(db: &dyn MemoizedInputsContext) -> usize {
db.log().add("Max invoked");
std::cmp::max(db.input1(), db.input2())
}
#[test]
fn revalidate() {
let db = &mut TestContextImpl::default();
db.set_input1(0);
db.set_input2(0);
let v = db.max();
assert_eq!(v, 0);
db.assert_log(&["Max invoked"]);
let v = db.max();
assert_eq!(v, 0);
db.assert_log(&[]);
db.set_input1(44);
db.assert_log(&[]);
let v = db.max();
assert_eq!(v, 44);
db.assert_log(&["Max invoked"]);
let v = db.max();
assert_eq!(v, 44);
db.assert_log(&[]);
db.set_input1(44);
db.assert_log(&[]);
db.set_input2(66);
db.assert_log(&[]);
db.set_input1(64);
db.assert_log(&[]);
let v = db.max();
assert_eq!(v, 66);
db.assert_log(&["Max invoked"]);
let v = db.max();
assert_eq!(v, 66);
db.assert_log(&[]);
}
/// Test that invoking `set` on an input with the same value still
/// triggers a new revision.
#[test]
fn set_after_no_change() {
let db = &mut TestContextImpl::default();
db.set_input2(0);
db.set_input1(44);
let v = db.max();
assert_eq!(v, 44);
db.assert_log(&["Max invoked"]);
db.set_input1(44);
let v = db.max();
assert_eq!(v, 44);
db.assert_log(&["Max invoked"]);
}

View file

@ -0,0 +1,77 @@
use crate::implementation::{TestContext, TestContextImpl};
use salsa::{Database, Durability};
#[salsa::query_group(MemoizedVolatile)]
pub(crate) trait MemoizedVolatileContext: TestContext {
// Queries for testing a "volatile" value wrapped by
// memoization.
fn memoized2(&self) -> usize;
fn memoized1(&self) -> usize;
fn volatile(&self) -> usize;
}
fn memoized2(db: &dyn MemoizedVolatileContext) -> usize {
db.log().add("Memoized2 invoked");
db.memoized1()
}
fn memoized1(db: &dyn MemoizedVolatileContext) -> usize {
db.log().add("Memoized1 invoked");
let v = db.volatile();
v / 2
}
fn volatile(db: &dyn MemoizedVolatileContext) -> usize {
db.log().add("Volatile invoked");
db.salsa_runtime().report_untracked_read();
db.clock().increment()
}
#[test]
fn volatile_x2() {
let query = TestContextImpl::default();
// Invoking volatile twice doesn't execute twice, because volatile
// queries are memoized by default.
query.volatile();
query.volatile();
query.assert_log(&["Volatile invoked"]);
}
/// Test that:
///
/// - On the first run of R0, we recompute everything.
/// - On the second run of R1, we recompute nothing.
/// - On the first run of R1, we recompute Memoized1 but not Memoized2 (since Memoized1 result
/// did not change).
/// - On the second run of R1, we recompute nothing.
/// - On the first run of R2, we recompute everything (since Memoized1 result *did* change).
#[test]
fn revalidate() {
let mut query = TestContextImpl::default();
query.memoized2();
query.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Volatile invoked"]);
query.memoized2();
query.assert_log(&[]);
// Second generation: volatile will change (to 1) but memoized1
// will not (still 0, as 1/2 = 0)
query.salsa_runtime_mut().synthetic_write(Durability::LOW);
query.memoized2();
query.assert_log(&["Volatile invoked", "Memoized1 invoked"]);
query.memoized2();
query.assert_log(&[]);
// Third generation: volatile will change (to 2) and memoized1
// will too (to 1). Therefore, after validating that Memoized1
// changed, we now invoke Memoized2.
query.salsa_runtime_mut().synthetic_write(Durability::LOW);
query.memoized2();
query.assert_log(&["Volatile invoked", "Memoized1 invoked", "Memoized2 invoked"]);
query.memoized2();
query.assert_log(&[]);
}

View file

@ -0,0 +1,90 @@
//! Test that you can implement a query using a `dyn Trait` setup.
use salsa::InternId;
#[salsa::database(InternStorage)]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
}
impl salsa::Database for Database {}
impl salsa::ParallelDatabase for Database {
fn snapshot(&self) -> salsa::Snapshot<Self> {
salsa::Snapshot::new(Database { storage: self.storage.snapshot() })
}
}
#[salsa::query_group(InternStorage)]
trait Intern {
#[salsa::interned]
fn intern1(&self, x: String) -> InternId;
#[salsa::interned]
fn intern2(&self, x: String, y: String) -> InternId;
#[salsa::interned]
fn intern_key(&self, x: String) -> InternKey;
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct InternKey(InternId);
impl salsa::InternKey for InternKey {
fn from_intern_id(v: InternId) -> Self {
InternKey(v)
}
fn as_intern_id(&self) -> InternId {
self.0
}
}
#[test]
fn test_intern1() {
let db = Database::default();
let foo0 = db.intern1("foo".to_string());
let bar0 = db.intern1("bar".to_string());
let foo1 = db.intern1("foo".to_string());
let bar1 = db.intern1("bar".to_string());
assert_eq!(foo0, foo1);
assert_eq!(bar0, bar1);
assert_ne!(foo0, bar0);
assert_eq!("foo".to_string(), db.lookup_intern1(foo0));
assert_eq!("bar".to_string(), db.lookup_intern1(bar0));
}
#[test]
fn test_intern2() {
let db = Database::default();
let foo0 = db.intern2("x".to_string(), "foo".to_string());
let bar0 = db.intern2("x".to_string(), "bar".to_string());
let foo1 = db.intern2("x".to_string(), "foo".to_string());
let bar1 = db.intern2("x".to_string(), "bar".to_string());
assert_eq!(foo0, foo1);
assert_eq!(bar0, bar1);
assert_ne!(foo0, bar0);
assert_eq!(("x".to_string(), "foo".to_string()), db.lookup_intern2(foo0));
assert_eq!(("x".to_string(), "bar".to_string()), db.lookup_intern2(bar0));
}
#[test]
fn test_intern_key() {
let db = Database::default();
let foo0 = db.intern_key("foo".to_string());
let bar0 = db.intern_key("bar".to_string());
let foo1 = db.intern_key("foo".to_string());
let bar1 = db.intern_key("bar".to_string());
assert_eq!(foo0, foo1);
assert_eq!(bar0, bar1);
assert_ne!(foo0, bar0);
assert_eq!("foo".to_string(), db.lookup_intern_key(foo0));
assert_eq!("bar".to_string(), db.lookup_intern_key(bar0));
}

102
crates/salsa/tests/lru.rs Normal file
View file

@ -0,0 +1,102 @@
//! Test setting LRU actually limits the number of things in the database;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
#[derive(Debug, PartialEq, Eq)]
struct HotPotato(u32);
static N_POTATOES: AtomicUsize = AtomicUsize::new(0);
impl HotPotato {
fn new(id: u32) -> HotPotato {
N_POTATOES.fetch_add(1, Ordering::SeqCst);
HotPotato(id)
}
}
impl Drop for HotPotato {
fn drop(&mut self) {
N_POTATOES.fetch_sub(1, Ordering::SeqCst);
}
}
#[salsa::query_group(QueryGroupStorage)]
trait QueryGroup: salsa::Database {
fn get(&self, x: u32) -> Arc<HotPotato>;
fn get_volatile(&self, x: u32) -> usize;
}
fn get(_db: &dyn QueryGroup, x: u32) -> Arc<HotPotato> {
Arc::new(HotPotato::new(x))
}
fn get_volatile(db: &dyn QueryGroup, _x: u32) -> usize {
static COUNTER: AtomicUsize = AtomicUsize::new(0);
db.salsa_runtime().report_untracked_read();
COUNTER.fetch_add(1, Ordering::SeqCst)
}
#[salsa::database(QueryGroupStorage)]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
}
impl salsa::Database for Database {}
#[test]
fn lru_works() {
let mut db = Database::default();
GetQuery.in_db_mut(&mut db).set_lru_capacity(32);
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0);
for i in 0..128u32 {
let p = db.get(i);
assert_eq!(p.0, i)
}
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
for i in 0..128u32 {
let p = db.get(i);
assert_eq!(p.0, i)
}
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
GetQuery.in_db_mut(&mut db).set_lru_capacity(32);
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
GetQuery.in_db_mut(&mut db).set_lru_capacity(64);
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32);
for i in 0..128u32 {
let p = db.get(i);
assert_eq!(p.0, i)
}
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64);
// Special case: setting capacity to zero disables LRU
GetQuery.in_db_mut(&mut db).set_lru_capacity(0);
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64);
for i in 0..128u32 {
let p = db.get(i);
assert_eq!(p.0, i)
}
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 128);
drop(db);
assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0);
}
#[test]
fn lru_doesnt_break_volatile_queries() {
let mut db = Database::default();
GetVolatileQuery.in_db_mut(&mut db).set_lru_capacity(32);
// Here, we check that we execute each volatile query at most once, despite
// LRU. That does mean that we have more values in DB than the LRU capacity,
// but it's much better than inconsistent results from volatile queries!
for i in (0..3).flat_map(|_| 0..128usize) {
let x = db.get_volatile(i as u32);
assert_eq!(x, i)
}
}

View file

@ -0,0 +1,11 @@
#[salsa::query_group(MyStruct)]
trait MyDatabase: salsa::Database {
#[salsa::invoke(another_module::another_name)]
fn my_query(&self, key: ()) -> ();
}
mod another_module {
pub(crate) fn another_name(_: &dyn crate::MyDatabase, (): ()) {}
}
fn main() {}

View file

@ -0,0 +1,31 @@
use std::rc::Rc;
#[salsa::query_group(NoSendSyncStorage)]
trait NoSendSyncDatabase: salsa::Database {
fn no_send_sync_value(&self, key: bool) -> Rc<bool>;
fn no_send_sync_key(&self, key: Rc<bool>) -> bool;
}
fn no_send_sync_value(_db: &dyn NoSendSyncDatabase, key: bool) -> Rc<bool> {
Rc::new(key)
}
fn no_send_sync_key(_db: &dyn NoSendSyncDatabase, key: Rc<bool>) -> bool {
*key
}
#[salsa::database(NoSendSyncStorage)]
#[derive(Default)]
struct DatabaseImpl {
storage: salsa::Storage<Self>,
}
impl salsa::Database for DatabaseImpl {}
#[test]
fn no_send_sync() {
let db = DatabaseImpl::default();
assert_eq!(db.no_send_sync_value(true), Rc::new(true));
assert!(!db.no_send_sync_key(Rc::new(false)));
}

View file

@ -0,0 +1,147 @@
//! Test that "on-demand" input pattern works.
//!
//! On-demand inputs are inputs computed lazily on the fly. They are simulated
//! via a b query with zero inputs, which uses `add_synthetic_read` to
//! tweak durability and `invalidate` to clear the input.
#![allow(clippy::disallowed_types, clippy::type_complexity)]
use std::{cell::RefCell, collections::HashMap, rc::Rc};
use salsa::{Database as _, Durability, EventKind};
#[salsa::query_group(QueryGroupStorage)]
trait QueryGroup: salsa::Database + AsRef<HashMap<u32, u32>> {
fn a(&self, x: u32) -> u32;
fn b(&self, x: u32) -> u32;
fn c(&self, x: u32) -> u32;
}
fn a(db: &dyn QueryGroup, x: u32) -> u32 {
let durability = if x % 2 == 0 { Durability::LOW } else { Durability::HIGH };
db.salsa_runtime().report_synthetic_read(durability);
let external_state: &HashMap<u32, u32> = db.as_ref();
external_state[&x]
}
fn b(db: &dyn QueryGroup, x: u32) -> u32 {
db.a(x)
}
fn c(db: &dyn QueryGroup, x: u32) -> u32 {
db.b(x)
}
#[salsa::database(QueryGroupStorage)]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
external_state: HashMap<u32, u32>,
on_event: Option<Box<dyn Fn(&Database, salsa::Event)>>,
}
impl salsa::Database for Database {
fn salsa_event(&self, event: salsa::Event) {
if let Some(cb) = &self.on_event {
cb(self, event)
}
}
}
impl AsRef<HashMap<u32, u32>> for Database {
fn as_ref(&self) -> &HashMap<u32, u32> {
&self.external_state
}
}
#[test]
fn on_demand_input_works() {
let mut db = Database::default();
db.external_state.insert(1, 10);
assert_eq!(db.b(1), 10);
assert_eq!(db.a(1), 10);
// We changed external state, but haven't signaled about this yet,
// so we expect to see the old answer
db.external_state.insert(1, 92);
assert_eq!(db.b(1), 10);
assert_eq!(db.a(1), 10);
AQuery.in_db_mut(&mut db).invalidate(&1);
assert_eq!(db.b(1), 92);
assert_eq!(db.a(1), 92);
// Downstream queries should also be rerun if we call `a` first.
db.external_state.insert(1, 50);
AQuery.in_db_mut(&mut db).invalidate(&1);
assert_eq!(db.a(1), 50);
assert_eq!(db.b(1), 50);
}
#[test]
fn on_demand_input_durability() {
let mut db = Database::default();
let events = Rc::new(RefCell::new(vec![]));
db.on_event = Some(Box::new({
let events = events.clone();
move |db, event| {
if let EventKind::WillCheckCancellation = event.kind {
// these events are not interesting
} else {
events.borrow_mut().push(format!("{:?}", event.debug(db)))
}
}
}));
events.replace(vec![]);
db.external_state.insert(1, 10);
db.external_state.insert(2, 20);
assert_eq!(db.b(1), 10);
assert_eq!(db.b(2), 20);
expect_test::expect![[r#"
RefCell {
value: [
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: b(1) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: b(2) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }",
],
}
"#]].assert_debug_eq(&events);
db.salsa_runtime_mut().synthetic_write(Durability::LOW);
events.replace(vec![]);
assert_eq!(db.c(1), 10);
assert_eq!(db.c(2), 20);
// Re-execute `a(2)` because that has low durability, but not `a(1)`
expect_test::expect![[r#"
RefCell {
value: [
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: c(1) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(1) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: c(2) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(2) } }",
],
}
"#]].assert_debug_eq(&events);
db.salsa_runtime_mut().synthetic_write(Durability::HIGH);
events.replace(vec![]);
assert_eq!(db.c(1), 10);
assert_eq!(db.c(2), 20);
// Re-execute both `a(1)` and `a(2)`, but we don't re-execute any `b` queries as the
// result didn't actually change.
expect_test::expect![[r#"
RefCell {
value: [
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(1) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }",
"Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(2) } }",
],
}
"#]].assert_debug_eq(&events);
}

View file

@ -0,0 +1,93 @@
use salsa::{Database, ParallelDatabase, Snapshot};
use std::panic::{self, AssertUnwindSafe};
use std::sync::atomic::{AtomicU32, Ordering::SeqCst};
#[salsa::query_group(PanicSafelyStruct)]
trait PanicSafelyDatabase: salsa::Database {
#[salsa::input]
fn one(&self) -> usize;
fn panic_safely(&self) -> ();
fn outer(&self) -> ();
}
fn panic_safely(db: &dyn PanicSafelyDatabase) {
assert_eq!(db.one(), 1);
}
static OUTER_CALLS: AtomicU32 = AtomicU32::new(0);
fn outer(db: &dyn PanicSafelyDatabase) {
OUTER_CALLS.fetch_add(1, SeqCst);
db.panic_safely();
}
#[salsa::database(PanicSafelyStruct)]
#[derive(Default)]
struct DatabaseStruct {
storage: salsa::Storage<Self>,
}
impl salsa::Database for DatabaseStruct {}
impl salsa::ParallelDatabase for DatabaseStruct {
fn snapshot(&self) -> Snapshot<Self> {
Snapshot::new(DatabaseStruct { storage: self.storage.snapshot() })
}
}
#[test]
fn should_panic_safely() {
let mut db = DatabaseStruct::default();
db.set_one(0);
// Invoke `db.panic_safely() without having set `db.one`. `db.one` will
// return 0 and we should catch the panic.
let result = panic::catch_unwind(AssertUnwindSafe({
let db = db.snapshot();
move || db.panic_safely()
}));
assert!(result.is_err());
// Set `db.one` to 1 and assert ok
db.set_one(1);
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
assert!(result.is_ok());
// Check, that memoized outer is not invalidated by a panic
{
assert_eq!(OUTER_CALLS.load(SeqCst), 0);
db.outer();
assert_eq!(OUTER_CALLS.load(SeqCst), 1);
db.set_one(0);
let result = panic::catch_unwind(AssertUnwindSafe(|| db.outer()));
assert!(result.is_err());
assert_eq!(OUTER_CALLS.load(SeqCst), 1);
db.set_one(1);
db.outer();
assert_eq!(OUTER_CALLS.load(SeqCst), 2);
}
}
#[test]
fn storages_are_unwind_safe() {
fn check_unwind_safe<T: std::panic::UnwindSafe>() {}
check_unwind_safe::<&DatabaseStruct>();
}
#[test]
fn panics_clear_query_stack() {
let db = DatabaseStruct::default();
// Invoke `db.panic_if_not_one() without having set `db.input`. `db.input`
// will default to 0 and we should catch the panic.
let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely()));
assert!(result.is_err());
// The database has been poisoned and any attempt to increment the
// revision should panic.
assert_eq!(db.salsa_runtime().active_query(), None);
}

View file

@ -0,0 +1,132 @@
use crate::setup::{CancellationFlag, Knobs, ParDatabase, ParDatabaseImpl, WithValue};
use salsa::{Cancelled, ParallelDatabase};
macro_rules! assert_cancelled {
($thread:expr) => {
match $thread.join() {
Ok(value) => panic!("expected cancellation, got {:?}", value),
Err(payload) => match payload.downcast::<Cancelled>() {
Ok(_) => {}
Err(payload) => ::std::panic::resume_unwind(payload),
},
}
};
}
/// Add test where a call to `sum` is cancelled by a simultaneous
/// write. Check that we recompute the result in next revision, even
/// though none of the inputs have changed.
#[test]
fn in_par_get_set_cancellation_immediate() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
db.set_input('d', 0);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// This will not return until it sees cancellation is
// signaled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(CancellationFlag::Panic, || db.sum("abc"))
})
}
});
// Wait until we have entered `sum` in the other thread.
db.wait_for(1);
// Try to set the input. This will signal cancellation.
db.set_input('d', 1000);
// This should re-compute the value (even though no input has changed).
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || db.sum("abc")
});
assert_eq!(db.sum("d"), 1000);
assert_cancelled!(thread1);
assert_eq!(thread2.join().unwrap(), 111);
}
/// Here, we check that `sum`'s cancellation is propagated
/// to `sum2` properly.
#[test]
fn in_par_get_set_cancellation_transitive() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
db.set_input('d', 0);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// This will not return until it sees cancellation is
// signaled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(CancellationFlag::Panic, || db.sum2("abc"))
})
}
});
// Wait until we have entered `sum` in the other thread.
db.wait_for(1);
// Try to set the input. This will signal cancellation.
db.set_input('d', 1000);
// This should re-compute the value (even though no input has changed).
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || db.sum2("abc")
});
assert_eq!(db.sum2("d"), 1000);
assert_cancelled!(thread1);
assert_eq!(thread2.join().unwrap(), 111);
}
/// https://github.com/salsa-rs/salsa/issues/66
#[test]
fn no_back_dating_in_cancellation() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 1);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
// Here we compute a long-chain of queries,
// but the last one gets cancelled.
db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_cancellation
.with_value(CancellationFlag::Panic, || db.sum3("a"))
})
}
});
db.wait_for(1);
// Set unrelated input to bump revision
db.set_input('b', 2);
// Here we should recompuet the whole chain again, clearing the cancellation
// state. If we get `usize::max()` here, it is a bug!
assert_eq!(db.sum3("a"), 1);
assert_cancelled!(thread1);
db.set_input('a', 3);
db.set_input('a', 4);
assert_eq!(db.sum3("ab"), 6);
}

View file

@ -0,0 +1,57 @@
use crate::setup::{ParDatabase, ParDatabaseImpl};
use crate::signal::Signal;
use salsa::{Database, ParallelDatabase};
use std::{
panic::{catch_unwind, AssertUnwindSafe},
sync::Arc,
};
/// Add test where a call to `sum` is cancelled by a simultaneous
/// write. Check that we recompute the result in next revision, even
/// though none of the inputs have changed.
#[test]
fn in_par_get_set_cancellation() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 1);
let signal = Arc::new(Signal::default());
let thread1 = std::thread::spawn({
let db = db.snapshot();
let signal = signal.clone();
move || {
// Check that cancellation flag is not yet set, because
// `set` cannot have been called yet.
catch_unwind(AssertUnwindSafe(|| db.unwind_if_cancelled())).unwrap();
// Signal other thread to proceed.
signal.signal(1);
// Wait for other thread to signal cancellation
catch_unwind(AssertUnwindSafe(|| loop {
db.unwind_if_cancelled();
std::thread::yield_now();
}))
.unwrap_err();
}
});
let thread2 = std::thread::spawn({
move || {
// Wait until thread 1 has asserted that they are not cancelled
// before we invoke `set.`
signal.wait_for(1);
// This will block until thread1 drops the revision lock.
db.set_input('a', 2);
db.input('a')
}
});
thread1.join().unwrap();
let c = thread2.join().unwrap();
assert_eq!(c, 2);
}

View file

@ -0,0 +1,29 @@
use crate::setup::{ParDatabase, ParDatabaseImpl};
use salsa::ParallelDatabase;
/// Test two `sum` queries (on distinct keys) executing in different
/// threads. Really just a test that `snapshot` etc compiles.
#[test]
fn in_par_two_independent_queries() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
db.set_input('d', 200);
db.set_input('e', 20);
db.set_input('f', 2);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || db.sum("abc")
});
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || db.sum("def")
});
assert_eq!(thread1.join().unwrap(), 111);
assert_eq!(thread2.join().unwrap(), 222);
}

View file

@ -0,0 +1,13 @@
mod setup;
mod cancellation;
mod frozen;
mod independent;
mod parallel_cycle_all_recover;
mod parallel_cycle_mid_recover;
mod parallel_cycle_none_recover;
mod parallel_cycle_one_recovers;
mod race;
mod signal;
mod stress;
mod true_parallel;

View file

@ -0,0 +1,110 @@
//! Test for cycle recover spread across two threads.
//! See `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::{Knobs, ParDatabaseImpl};
use salsa::ParallelDatabase;
use test_log::test;
// Recover cycle test:
//
// The pattern is as follows.
//
// Thread A Thread B
// -------- --------
// a1 b1
// | wait for stage 1 (blocks)
// signal stage 1 |
// wait for stage 2 (blocks) (unblocked)
// | signal stage 2
// (unblocked) wait for stage 3 (blocks)
// a2 |
// b1 (blocks -> stage 3) |
// | (unblocked)
// | b2
// | a1 (cycle detected, recovers)
// | b2 completes, recovers
// | b1 completes, recovers
// a2 sees cycle, recovers
// a1 completes, recovers
#[test]
fn parallel_cycle_all_recover() {
let db = ParDatabaseImpl::default();
db.knobs().signal_on_will_block.set(3);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || db.a1(1)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || db.b1(1)
});
assert_eq!(thread_a.join().unwrap(), 11);
assert_eq!(thread_b.join().unwrap(), 21);
}
#[salsa::query_group(ParallelCycleAllRecover)]
pub(crate) trait TestDatabase: Knobs {
#[salsa::cycle(recover_a1)]
fn a1(&self, key: i32) -> i32;
#[salsa::cycle(recover_a2)]
fn a2(&self, key: i32) -> i32;
#[salsa::cycle(recover_b1)]
fn b1(&self, key: i32) -> i32;
#[salsa::cycle(recover_b2)]
fn b2(&self, key: i32) -> i32;
}
fn recover_a1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_a1");
key * 10 + 1
}
fn recover_a2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_a2");
key * 10 + 2
}
fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_b1");
key * 20 + 1
}
fn recover_b2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_b2");
key * 20 + 2
}
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.signal(1);
db.wait_for(2);
db.a2(key)
}
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
db.b1(key)
}
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.wait_for(1);
db.signal(2);
// Wait for thread A to block on this thread
db.wait_for(3);
db.b2(key)
}
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
db.a1(key)
}

View file

@ -0,0 +1,110 @@
//! Test for cycle recover spread across two threads.
//! See `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::{Knobs, ParDatabaseImpl};
use salsa::ParallelDatabase;
use test_log::test;
// Recover cycle test:
//
// The pattern is as follows.
//
// Thread A Thread B
// -------- --------
// a1 b1
// | wait for stage 1 (blocks)
// signal stage 1 |
// wait for stage 2 (blocks) (unblocked)
// | |
// | b2
// | b3
// | a1 (blocks -> stage 2)
// (unblocked) |
// a2 (cycle detected) |
// b3 recovers
// b2 resumes
// b1 panics because bug
#[test]
fn parallel_cycle_mid_recovers() {
let db = ParDatabaseImpl::default();
db.knobs().signal_on_will_block.set(2);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || db.a1(1)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || db.b1(1)
});
// We expect that the recovery function yields
// `1 * 20 + 2`, which is returned (and forwarded)
// to b1, and from there to a2 and a1.
assert_eq!(thread_a.join().unwrap(), 22);
assert_eq!(thread_b.join().unwrap(), 22);
}
#[salsa::query_group(ParallelCycleMidRecovers)]
pub(crate) trait TestDatabase: Knobs {
fn a1(&self, key: i32) -> i32;
fn a2(&self, key: i32) -> i32;
#[salsa::cycle(recover_b1)]
fn b1(&self, key: i32) -> i32;
fn b2(&self, key: i32) -> i32;
#[salsa::cycle(recover_b3)]
fn b3(&self, key: i32) -> i32;
}
fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_b1");
key * 20 + 2
}
fn recover_b3(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover_b1");
key * 200 + 2
}
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
// tell thread b we have started
db.signal(1);
// wait for thread b to block on a1
db.wait_for(2);
db.a2(key)
}
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
// create the cycle
db.b1(key)
}
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
// wait for thread a to have started
db.wait_for(1);
db.b2(key);
0
}
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
// will encounter a cycle but recover
db.b3(key);
db.b1(key); // hasn't recovered yet
0
}
fn b3(db: &dyn TestDatabase, key: i32) -> i32 {
// will block on thread a, signaling stage 2
db.a1(key)
}

View file

@ -0,0 +1,69 @@
//! Test a cycle where no queries recover that occurs across threads.
//! See the `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::{Knobs, ParDatabaseImpl};
use expect_test::expect;
use salsa::ParallelDatabase;
use test_log::test;
#[test]
fn parallel_cycle_none_recover() {
let db = ParDatabaseImpl::default();
db.knobs().signal_on_will_block.set(3);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || db.a(-1)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || db.b(-1)
});
// We expect B to panic because it detects a cycle (it is the one that calls A, ultimately).
// Right now, it panics with a string.
let err_b = thread_b.join().unwrap_err();
if let Some(c) = err_b.downcast_ref::<salsa::Cycle>() {
expect![[r#"
[
"a(-1)",
"b(-1)",
]
"#]]
.assert_debug_eq(&c.unexpected_participants(&db));
} else {
panic!("b failed in an unexpected way: {:?}", err_b);
}
// We expect A to propagate a panic, which causes us to use the sentinel
// type `Canceled`.
assert!(thread_a.join().unwrap_err().downcast_ref::<salsa::Cycle>().is_some());
}
#[salsa::query_group(ParallelCycleNoneRecover)]
pub(crate) trait TestDatabase: Knobs {
fn a(&self, key: i32) -> i32;
fn b(&self, key: i32) -> i32;
}
fn a(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.signal(1);
db.wait_for(2);
db.b(key)
}
fn b(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.wait_for(1);
db.signal(2);
// Wait for thread A to block on this thread
db.wait_for(3);
// Now try to execute A
db.a(key)
}

View file

@ -0,0 +1,95 @@
//! Test for cycle recover spread across two threads.
//! See `../cycles.rs` for a complete listing of cycle tests,
//! both intra and cross thread.
use crate::setup::{Knobs, ParDatabaseImpl};
use salsa::ParallelDatabase;
use test_log::test;
// Recover cycle test:
//
// The pattern is as follows.
//
// Thread A Thread B
// -------- --------
// a1 b1
// | wait for stage 1 (blocks)
// signal stage 1 |
// wait for stage 2 (blocks) (unblocked)
// | signal stage 2
// (unblocked) wait for stage 3 (blocks)
// a2 |
// b1 (blocks -> stage 3) |
// | (unblocked)
// | b2
// | a1 (cycle detected)
// a2 recovery fn executes |
// a1 completes normally |
// b2 completes, recovers
// b1 completes, recovers
#[test]
fn parallel_cycle_one_recovers() {
let db = ParDatabaseImpl::default();
db.knobs().signal_on_will_block.set(3);
let thread_a = std::thread::spawn({
let db = db.snapshot();
move || db.a1(1)
});
let thread_b = std::thread::spawn({
let db = db.snapshot();
move || db.b1(1)
});
// We expect that the recovery function yields
// `1 * 20 + 2`, which is returned (and forwarded)
// to b1, and from there to a2 and a1.
assert_eq!(thread_a.join().unwrap(), 22);
assert_eq!(thread_b.join().unwrap(), 22);
}
#[salsa::query_group(ParallelCycleOneRecovers)]
pub(crate) trait TestDatabase: Knobs {
fn a1(&self, key: i32) -> i32;
#[salsa::cycle(recover)]
fn a2(&self, key: i32) -> i32;
fn b1(&self, key: i32) -> i32;
fn b2(&self, key: i32) -> i32;
}
fn recover(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 {
tracing::debug!("recover");
key * 20 + 2
}
fn a1(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.signal(1);
db.wait_for(2);
db.a2(key)
}
fn a2(db: &dyn TestDatabase, key: i32) -> i32 {
db.b1(key)
}
fn b1(db: &dyn TestDatabase, key: i32) -> i32 {
// Wait to create the cycle until both threads have entered
db.wait_for(1);
db.signal(2);
// Wait for thread A to block on this thread
db.wait_for(3);
db.b2(key)
}
fn b2(db: &dyn TestDatabase, key: i32) -> i32 {
db.a1(key)
}

View file

@ -0,0 +1,37 @@
use std::panic::AssertUnwindSafe;
use crate::setup::{ParDatabase, ParDatabaseImpl};
use salsa::{Cancelled, ParallelDatabase};
/// Test where a read and a set are racing with one another.
/// Should be atomic.
#[test]
fn in_par_get_set_race() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || Cancelled::catch(AssertUnwindSafe(|| db.sum("abc")))
});
let thread2 = std::thread::spawn(move || {
db.set_input('a', 1000);
db.sum("a")
});
// If the 1st thread runs first, you get 111, otherwise you get
// 1011; if they run concurrently and the 1st thread observes the
// cancellation, it'll unwind.
let result1 = thread1.join().unwrap();
if let Ok(value1) = result1 {
assert!(value1 == 111 || value1 == 1011, "illegal result {}", value1);
}
// thread2 can not observe a cancellation because it performs a
// database write before running any other queries.
assert_eq!(thread2.join().unwrap(), 1000);
}

View file

@ -0,0 +1,197 @@
use crate::signal::Signal;
use salsa::Database;
use salsa::ParallelDatabase;
use salsa::Snapshot;
use std::sync::Arc;
use std::{
cell::Cell,
panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
};
#[salsa::query_group(Par)]
pub(crate) trait ParDatabase: Knobs {
#[salsa::input]
fn input(&self, key: char) -> usize;
fn sum(&self, key: &'static str) -> usize;
/// Invokes `sum`
fn sum2(&self, key: &'static str) -> usize;
/// Invokes `sum` but doesn't really care about the result.
fn sum2_drop_sum(&self, key: &'static str) -> usize;
/// Invokes `sum2`
fn sum3(&self, key: &'static str) -> usize;
/// Invokes `sum2_drop_sum`
fn sum3_drop_sum(&self, key: &'static str) -> usize;
}
/// Various "knobs" and utilities used by tests to force
/// a certain behavior.
pub(crate) trait Knobs {
fn knobs(&self) -> &KnobsStruct;
fn signal(&self, stage: usize);
fn wait_for(&self, stage: usize);
}
pub(crate) trait WithValue<T> {
fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R;
}
impl<T> WithValue<T> for Cell<T> {
fn with_value<R>(&self, value: T, closure: impl FnOnce() -> R) -> R {
let old_value = self.replace(value);
let result = catch_unwind(AssertUnwindSafe(closure));
self.set(old_value);
match result {
Ok(r) => r,
Err(payload) => resume_unwind(payload),
}
}
}
#[derive(Default, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CancellationFlag {
#[default]
Down,
Panic,
}
/// Various "knobs" that can be used to customize how the queries
/// behave on one specific thread. Note that this state is
/// intentionally thread-local (apart from `signal`).
#[derive(Clone, Default)]
pub(crate) struct KnobsStruct {
/// A kind of flexible barrier used to coordinate execution across
/// threads to ensure we reach various weird states.
pub(crate) signal: Arc<Signal>,
/// When this database is about to block, send a signal.
pub(crate) signal_on_will_block: Cell<usize>,
/// Invocations of `sum` will signal this stage on entry.
pub(crate) sum_signal_on_entry: Cell<usize>,
/// Invocations of `sum` will wait for this stage on entry.
pub(crate) sum_wait_for_on_entry: Cell<usize>,
/// If true, invocations of `sum` will panic before they exit.
pub(crate) sum_should_panic: Cell<bool>,
/// If true, invocations of `sum` will wait for cancellation before
/// they exit.
pub(crate) sum_wait_for_cancellation: Cell<CancellationFlag>,
/// Invocations of `sum` will wait for this stage prior to exiting.
pub(crate) sum_wait_for_on_exit: Cell<usize>,
/// Invocations of `sum` will signal this stage prior to exiting.
pub(crate) sum_signal_on_exit: Cell<usize>,
/// Invocations of `sum3_drop_sum` will panic unconditionally
pub(crate) sum3_drop_sum_should_panic: Cell<bool>,
}
fn sum(db: &dyn ParDatabase, key: &'static str) -> usize {
let mut sum = 0;
db.signal(db.knobs().sum_signal_on_entry.get());
db.wait_for(db.knobs().sum_wait_for_on_entry.get());
if db.knobs().sum_should_panic.get() {
panic!("query set to panic before exit")
}
for ch in key.chars() {
sum += db.input(ch);
}
match db.knobs().sum_wait_for_cancellation.get() {
CancellationFlag::Down => (),
CancellationFlag::Panic => {
tracing::debug!("waiting for cancellation");
loop {
db.unwind_if_cancelled();
std::thread::yield_now();
}
}
}
db.wait_for(db.knobs().sum_wait_for_on_exit.get());
db.signal(db.knobs().sum_signal_on_exit.get());
sum
}
fn sum2(db: &dyn ParDatabase, key: &'static str) -> usize {
db.sum(key)
}
fn sum2_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize {
let _ = db.sum(key);
22
}
fn sum3(db: &dyn ParDatabase, key: &'static str) -> usize {
db.sum2(key)
}
fn sum3_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize {
if db.knobs().sum3_drop_sum_should_panic.get() {
panic!("sum3_drop_sum executed")
}
db.sum2_drop_sum(key)
}
#[salsa::database(
Par,
crate::parallel_cycle_all_recover::ParallelCycleAllRecover,
crate::parallel_cycle_none_recover::ParallelCycleNoneRecover,
crate::parallel_cycle_mid_recover::ParallelCycleMidRecovers,
crate::parallel_cycle_one_recovers::ParallelCycleOneRecovers
)]
#[derive(Default)]
pub(crate) struct ParDatabaseImpl {
storage: salsa::Storage<Self>,
knobs: KnobsStruct,
}
impl Database for ParDatabaseImpl {
fn salsa_event(&self, event: salsa::Event) {
if let salsa::EventKind::WillBlockOn { .. } = event.kind {
self.signal(self.knobs().signal_on_will_block.get());
}
}
}
impl ParallelDatabase for ParDatabaseImpl {
fn snapshot(&self) -> Snapshot<Self> {
Snapshot::new(ParDatabaseImpl {
storage: self.storage.snapshot(),
knobs: self.knobs.clone(),
})
}
}
impl Knobs for ParDatabaseImpl {
fn knobs(&self) -> &KnobsStruct {
&self.knobs
}
fn signal(&self, stage: usize) {
self.knobs.signal.signal(stage);
}
fn wait_for(&self, stage: usize) {
self.knobs.signal.wait_for(stage);
}
}

View file

@ -0,0 +1,40 @@
use parking_lot::{Condvar, Mutex};
#[derive(Default)]
pub(crate) struct Signal {
value: Mutex<usize>,
cond_var: Condvar,
}
impl Signal {
pub(crate) fn signal(&self, stage: usize) {
tracing::debug!("signal({})", stage);
// This check avoids acquiring the lock for things that will
// clearly be a no-op. Not *necessary* but helps to ensure we
// are more likely to encounter weird race conditions;
// otherwise calls to `sum` will tend to be unnecessarily
// synchronous.
if stage > 0 {
let mut v = self.value.lock();
if stage > *v {
*v = stage;
self.cond_var.notify_all();
}
}
}
/// Waits until the given condition is true; the fn is invoked
/// with the current stage.
pub(crate) fn wait_for(&self, stage: usize) {
tracing::debug!("wait_for({})", stage);
// As above, avoid lock if clearly a no-op.
if stage > 0 {
let mut v = self.value.lock();
while *v < stage {
self.cond_var.wait(&mut v);
}
}
}
}

View file

@ -0,0 +1,168 @@
use rand::seq::SliceRandom;
use rand::Rng;
use salsa::ParallelDatabase;
use salsa::Snapshot;
use salsa::{Cancelled, Database};
// Number of operations a reader performs
const N_MUTATOR_OPS: usize = 100;
const N_READER_OPS: usize = 100;
#[salsa::query_group(Stress)]
trait StressDatabase: salsa::Database {
#[salsa::input]
fn a(&self, key: usize) -> usize;
fn b(&self, key: usize) -> usize;
fn c(&self, key: usize) -> usize;
}
fn b(db: &dyn StressDatabase, key: usize) -> usize {
db.unwind_if_cancelled();
db.a(key)
}
fn c(db: &dyn StressDatabase, key: usize) -> usize {
db.b(key)
}
#[salsa::database(Stress)]
#[derive(Default)]
struct StressDatabaseImpl {
storage: salsa::Storage<Self>,
}
impl salsa::Database for StressDatabaseImpl {}
impl salsa::ParallelDatabase for StressDatabaseImpl {
fn snapshot(&self) -> Snapshot<StressDatabaseImpl> {
Snapshot::new(StressDatabaseImpl { storage: self.storage.snapshot() })
}
}
#[derive(Clone, Copy, Debug)]
enum Query {
A,
B,
C,
}
enum MutatorOp {
WriteOp(WriteOp),
LaunchReader { ops: Vec<ReadOp>, check_cancellation: bool },
}
#[derive(Debug)]
enum WriteOp {
SetA(usize, usize),
}
#[derive(Debug)]
enum ReadOp {
Get(Query, usize),
}
impl rand::distributions::Distribution<Query> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Query {
*[Query::A, Query::B, Query::C].choose(rng).unwrap()
}
}
impl rand::distributions::Distribution<MutatorOp> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> MutatorOp {
if rng.gen_bool(0.5) {
MutatorOp::WriteOp(rng.gen())
} else {
MutatorOp::LaunchReader {
ops: (0..N_READER_OPS).map(|_| rng.gen()).collect(),
check_cancellation: rng.gen(),
}
}
}
}
impl rand::distributions::Distribution<WriteOp> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> WriteOp {
let key = rng.gen::<usize>() % 10;
let value = rng.gen::<usize>() % 10;
WriteOp::SetA(key, value)
}
}
impl rand::distributions::Distribution<ReadOp> for rand::distributions::Standard {
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> ReadOp {
let query = rng.gen::<Query>();
let key = rng.gen::<usize>() % 10;
ReadOp::Get(query, key)
}
}
fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec<ReadOp>, check_cancellation: bool) {
for op in ops {
if check_cancellation {
db.unwind_if_cancelled();
}
op.execute(db);
}
}
impl WriteOp {
fn execute(self, db: &mut StressDatabaseImpl) {
match self {
WriteOp::SetA(key, value) => {
db.set_a(key, value);
}
}
}
}
impl ReadOp {
fn execute(self, db: &StressDatabaseImpl) {
match self {
ReadOp::Get(query, key) => match query {
Query::A => {
db.a(key);
}
Query::B => {
let _ = db.b(key);
}
Query::C => {
let _ = db.c(key);
}
},
}
}
}
#[test]
fn stress_test() {
let mut db = StressDatabaseImpl::default();
for i in 0..10 {
db.set_a(i, i);
}
let mut rng = rand::thread_rng();
// generate the ops that the mutator thread will perform
let write_ops: Vec<MutatorOp> = (0..N_MUTATOR_OPS).map(|_| rng.gen()).collect();
// execute the "main thread", which sometimes snapshots off other threads
let mut all_threads = vec![];
for op in write_ops {
match op {
MutatorOp::WriteOp(w) => w.execute(&mut db),
MutatorOp::LaunchReader { ops, check_cancellation } => {
all_threads.push(std::thread::spawn({
let db = db.snapshot();
move || Cancelled::catch(|| db_reader_thread(&db, ops, check_cancellation))
}))
}
}
}
for thread in all_threads {
thread.join().unwrap().ok();
}
}

View file

@ -0,0 +1,125 @@
use crate::setup::{Knobs, ParDatabase, ParDatabaseImpl, WithValue};
use salsa::ParallelDatabase;
use std::panic::{self, AssertUnwindSafe};
/// Test where two threads are executing sum. We show that they can
/// both be executing sum in parallel by having thread1 wait for
/// thread2 to send a signal before it leaves (similarly, thread2
/// waits for thread1 to send a signal before it enters).
#[test]
fn true_parallel_different_keys() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
// Thread 1 will signal stage 1 when it enters and wait for stage 2.
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
let v = db
.knobs()
.sum_signal_on_entry
.with_value(1, || db.knobs().sum_wait_for_on_exit.with_value(2, || db.sum("a")));
v
}
});
// Thread 2 will wait_for stage 1 when it enters and signal stage 2
// when it leaves.
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || {
let v = db
.knobs()
.sum_wait_for_on_entry
.with_value(1, || db.knobs().sum_signal_on_exit.with_value(2, || db.sum("b")));
v
}
});
assert_eq!(thread1.join().unwrap(), 100);
assert_eq!(thread2.join().unwrap(), 10);
}
/// Add a test that tries to trigger a conflict, where we fetch
/// `sum("abc")` from two threads simultaneously, and of them
/// therefore has to block.
#[test]
fn true_parallel_same_keys() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 100);
db.set_input('b', 10);
db.set_input('c', 1);
// Thread 1 will wait_for a barrier in the start of `sum`
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
let v = db
.knobs()
.sum_signal_on_entry
.with_value(1, || db.knobs().sum_wait_for_on_entry.with_value(2, || db.sum("abc")));
v
}
});
// Thread 2 will wait until Thread 1 has entered sum and then --
// once it has set itself to block -- signal Thread 1 to
// continue. This way, we test out the mechanism of one thread
// blocking on another.
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || {
db.knobs().signal.wait_for(1);
db.knobs().signal_on_will_block.set(2);
db.sum("abc")
}
});
assert_eq!(thread1.join().unwrap(), 111);
assert_eq!(thread2.join().unwrap(), 111);
}
/// Add a test that tries to trigger a conflict, where we fetch `sum("a")`
/// from two threads simultaneously. After `thread2` begins blocking,
/// we force `thread1` to panic and should see that propagate to `thread2`.
#[test]
fn true_parallel_propagate_panic() {
let mut db = ParDatabaseImpl::default();
db.set_input('a', 1);
// `thread1` will wait_for a barrier in the start of `sum`. Once it can
// continue, it will panic.
let thread1 = std::thread::spawn({
let db = db.snapshot();
move || {
let v = db.knobs().sum_signal_on_entry.with_value(1, || {
db.knobs()
.sum_wait_for_on_entry
.with_value(2, || db.knobs().sum_should_panic.with_value(true, || db.sum("a")))
});
v
}
});
// `thread2` will wait until `thread1` has entered sum and then -- once it
// has set itself to block -- signal `thread1` to continue.
let thread2 = std::thread::spawn({
let db = db.snapshot();
move || {
db.knobs().signal.wait_for(1);
db.knobs().signal_on_will_block.set(2);
db.sum("a")
}
});
let result1 = panic::catch_unwind(AssertUnwindSafe(|| thread1.join().unwrap()));
let result2 = panic::catch_unwind(AssertUnwindSafe(|| thread2.join().unwrap()));
assert!(result1.is_err());
assert!(result2.is_err());
}

View file

@ -0,0 +1,19 @@
use crate::queries;
use std::cell::Cell;
#[salsa::database(queries::GroupStruct)]
#[derive(Default)]
pub(crate) struct DatabaseImpl {
storage: salsa::Storage<Self>,
counter: Cell<usize>,
}
impl queries::Counter for DatabaseImpl {
fn increment(&self) -> usize {
let v = self.counter.get();
self.counter.set(v + 1);
v
}
}
impl salsa::Database for DatabaseImpl {}

View file

@ -0,0 +1,5 @@
mod implementation;
mod queries;
mod tests;
fn main() {}

View file

@ -0,0 +1,22 @@
pub(crate) trait Counter: salsa::Database {
fn increment(&self) -> usize;
}
#[salsa::query_group(GroupStruct)]
pub(crate) trait Database: Counter {
fn memoized(&self) -> usize;
fn volatile(&self) -> usize;
}
/// Because this query is memoized, we only increment the counter
/// the first time it is invoked.
fn memoized(db: &dyn Database) -> usize {
db.volatile()
}
/// Because this query is volatile, each time it is invoked,
/// we will increment the counter.
fn volatile(db: &dyn Database) -> usize {
db.salsa_runtime().report_untracked_read();
db.increment()
}

View file

@ -0,0 +1,49 @@
#![cfg(test)]
use crate::implementation::DatabaseImpl;
use crate::queries::Database;
use salsa::Database as _Database;
use salsa::Durability;
#[test]
fn memoized_twice() {
let db = DatabaseImpl::default();
let v1 = db.memoized();
let v2 = db.memoized();
assert_eq!(v1, v2);
}
#[test]
fn volatile_twice() {
let mut db = DatabaseImpl::default();
let v1 = db.volatile();
let v2 = db.volatile(); // volatiles are cached, so 2nd read returns the same
assert_eq!(v1, v2);
db.salsa_runtime_mut().synthetic_write(Durability::LOW); // clears volatile caches
let v3 = db.volatile(); // will re-increment the counter
let v4 = db.volatile(); // second call will be cached
assert_eq!(v1 + 1, v3);
assert_eq!(v3, v4);
}
#[test]
fn intermingled() {
let mut db = DatabaseImpl::default();
let v1 = db.volatile();
let v2 = db.memoized();
let v3 = db.volatile(); // cached
let v4 = db.memoized(); // cached
assert_eq!(v1, v2);
assert_eq!(v1, v3);
assert_eq!(v2, v4);
db.salsa_runtime_mut().synthetic_write(Durability::LOW); // clears volatile caches
let v5 = db.memoized(); // re-executes volatile, caches new result
let v6 = db.memoized(); // re-use cached result
assert_eq!(v4 + 1, v5);
assert_eq!(v5, v6);
}

View file

@ -0,0 +1,39 @@
//! Test that transparent (uncached) queries work
#[salsa::query_group(QueryGroupStorage)]
trait QueryGroup {
#[salsa::input]
fn input(&self, x: u32) -> u32;
#[salsa::transparent]
fn wrap(&self, x: u32) -> u32;
fn get(&self, x: u32) -> u32;
}
fn wrap(db: &dyn QueryGroup, x: u32) -> u32 {
db.input(x)
}
fn get(db: &dyn QueryGroup, x: u32) -> u32 {
db.wrap(x)
}
#[salsa::database(QueryGroupStorage)]
#[derive(Default)]
struct Database {
storage: salsa::Storage<Self>,
}
impl salsa::Database for Database {}
#[test]
fn transparent_queries_work() {
let mut db = Database::default();
db.set_input(1, 10);
assert_eq!(db.get(1), 10);
assert_eq!(db.get(1), 10);
db.set_input(1, 92);
assert_eq!(db.get(1), 92);
assert_eq!(db.get(1), 92);
}

View file

@ -0,0 +1,51 @@
#[salsa::query_group(HelloWorld)]
trait HelloWorldDatabase: salsa::Database {
#[salsa::input]
fn input(&self, a: u32, b: u32) -> u32;
fn none(&self) -> u32;
fn one(&self, k: u32) -> u32;
fn two(&self, a: u32, b: u32) -> u32;
fn trailing(&self, a: u32, b: u32) -> u32;
}
fn none(_db: &dyn HelloWorldDatabase) -> u32 {
22
}
fn one(_db: &dyn HelloWorldDatabase, k: u32) -> u32 {
k * 2
}
fn two(_db: &dyn HelloWorldDatabase, a: u32, b: u32) -> u32 {
a * b
}
fn trailing(_db: &dyn HelloWorldDatabase, a: u32, b: u32) -> u32 {
a - b
}
#[salsa::database(HelloWorld)]
#[derive(Default)]
struct DatabaseStruct {
storage: salsa::Storage<Self>,
}
impl salsa::Database for DatabaseStruct {}
#[test]
fn execute() {
let mut db = DatabaseStruct::default();
// test what happens with inputs:
db.set_input(1, 2, 3);
assert_eq!(db.input(1, 2), 3);
assert_eq!(db.none(), 22);
assert_eq!(db.one(11), 22);
assert_eq!(db.two(11, 2), 22);
assert_eq!(db.trailing(24, 2), 22);
}

View file

@ -11,7 +11,7 @@ authors.workspace = true
[dependencies] [dependencies]
la-arena.workspace = true la-arena.workspace = true
rust-analyzer-salsa.workspace = true salsa.workspace = true
# local deps # local deps