diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 07fa85a61c..964be478fa 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -28,7 +28,7 @@ jobs: proc_macros: ${{ steps.filter.outputs.proc_macros }} steps: - uses: actions/checkout@v3 - - uses: dorny/paths-filter@4067d885736b84de7c414f582ac45897079b0a78 + - uses: dorny/paths-filter@1441771bbfdd59dcd748680ee64ebd8faab1a242 id: filter with: filters: | @@ -104,8 +104,8 @@ jobs: run: target/${{ matrix.target }}/debug/rust-analyzer analysis-stats --with-deps $(rustc --print sysroot)/lib/rustlib/src/rust/library/std - name: clippy - if: matrix.os == 'ubuntu-latest' - run: cargo clippy --all-targets + if: matrix.os == 'windows-latest' + run: cargo clippy --all-targets -- -D clippy::disallowed_macros -D clippy::dbg_macro -D clippy::todo -D clippy::print_stdout -D clippy::print_stderr # Weird targets to catch non-portable code rust-cross: diff --git a/Cargo.lock b/Cargo.lock index e9492ce020..dc2bf3a769 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -72,8 +72,8 @@ dependencies = [ "cfg", "la-arena 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "profile", - "rust-analyzer-salsa", "rustc-hash", + "salsa", "semver", "span", "stdx", @@ -167,7 +167,7 @@ checksum = "5676cea088c32290fe65c82895be9d06dd21e0fa49bb97ca840529e9417ab71a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn", "synstructure", ] @@ -313,17 +313,6 @@ dependencies = [ "parking_lot_core", ] -[[package]] -name = "derivative" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "derive_arbitrary" version = "1.3.2" @@ -332,7 +321,7 @@ checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn", ] [[package]] @@ -368,6 +357,15 @@ dependencies = [ "log", ] +[[package]] +name = "env_logger" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580" +dependencies = [ + "log", +] + [[package]] name = "equivalent" version = "1.0.0" @@ -452,6 +450,17 @@ version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a" +[[package]] +name = "getrandom" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "gimli" version = "0.27.3" @@ -929,6 +938,12 @@ dependencies = [ "text-size", ] +[[package]] +name = "linked-hash-map" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" + [[package]] name = "load-cargo" version = "0.0.0" @@ -1272,6 +1287,12 @@ version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro-api" version = "0.0.0" @@ -1283,6 +1304,7 @@ dependencies = [ "object 0.32.0", "paths", "profile", + "rustc-hash", "serde", "serde_json", "snap", @@ -1435,17 +1457,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "ra-ap-rustc_index" -version = "0.33.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5313d7f243b63ef9e58d94355b11aa8499f1328055f1f58adf0a5ea7d2faca" -dependencies = [ - "arrayvec", - "ra-ap-rustc_index_macros 0.33.0", - "smallvec", -] - [[package]] name = "ra-ap-rustc_index" version = "0.35.0" @@ -1458,15 +1469,14 @@ dependencies = [ ] [[package]] -name = "ra-ap-rustc_index_macros" -version = "0.33.0" +name = "ra-ap-rustc_index" +version = "0.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a83108ebf3e73dde205b9c25706209bcd7736480820f90ded28eabaf8b469f25" +checksum = "f8a41dee58608b1fc93779ea365edaa70ac9927e3335ae914b675be0fa063cd7" dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.39", - "synstructure", + "arrayvec", + "ra-ap-rustc_index_macros 0.36.0", + "smallvec", ] [[package]] @@ -1477,7 +1487,19 @@ checksum = "054e25eac52f0506c1309ca4317c11ad4925d7b99eb897f71aa7c3cbafb46c2b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn", + "synstructure", +] + +[[package]] +name = "ra-ap-rustc_index_macros" +version = "0.36.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbfe98def54c4337a2f7d8233850bd5d5349972b185fe8a0db2b979164b30ed8" +dependencies = [ + "proc-macro2", + "quote", + "syn", "synstructure", ] @@ -1503,18 +1525,47 @@ dependencies = [ [[package]] name = "ra-ap-rustc_pattern_analysis" -version = "0.33.0" +version = "0.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c4085e0c771fd4b883930b599ef42966b855762bbe4052c17673b3253421a6d" +checksum = "b5529bffec7530b4a3425640bfdfd9b95d87c4c620f740266c0de6572561aab4" dependencies = [ - "derivative", - "ra-ap-rustc_index 0.33.0", + "ra-ap-rustc_index 0.36.0", "rustc-hash", "rustc_apfloat", "smallvec", "tracing", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "rayon" version = "1.8.0" @@ -1622,35 +1673,6 @@ dependencies = [ "xshell", ] -[[package]] -name = "rust-analyzer-salsa" -version = "0.17.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "719825638c59fd26a55412a24561c7c5bcf54364c88b9a7a04ba08a6eafaba8d" -dependencies = [ - "indexmap", - "lock_api", - "oorandom", - "parking_lot", - "rust-analyzer-salsa-macros", - "rustc-hash", - "smallvec", - "tracing", - "triomphe", -] - -[[package]] -name = "rust-analyzer-salsa-macros" -version = "0.17.0-pre.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d96498e9684848c6676c399032ebc37c52da95ecbefa83d71ccc53b9f8a4a8e" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn 2.0.39", -] - [[package]] name = "rustc-demangle" version = "0.1.23" @@ -1679,6 +1701,36 @@ version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" +[[package]] +name = "salsa" +version = "0.0.0" +dependencies = [ + "dissimilar", + "expect-test", + "indexmap", + "linked-hash-map", + "lock_api", + "oorandom", + "parking_lot", + "rand", + "rustc-hash", + "salsa-macros", + "smallvec", + "test-log", + "tracing", + "triomphe", +] + +[[package]] +name = "salsa-macros" +version = "0.0.0" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "same-file" version = "1.0.6" @@ -1735,7 +1787,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn", ] [[package]] @@ -1758,7 +1810,7 @@ checksum = "bcec881020c684085e55a25f7fd888954d56609ef363479dc5a1305eb0d40cab" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn", ] [[package]] @@ -1803,7 +1855,7 @@ name = "span" version = "0.0.0" dependencies = [ "la-arena 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", - "rust-analyzer-salsa", + "salsa", "stdx", "syntax", "vfs", @@ -1835,17 +1887,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - [[package]] name = "syn" version = "2.0.39" @@ -1865,7 +1906,7 @@ checksum = "285ba80e733fac80aa4270fbcdf83772a79b80aa35c97075320abfee4a915b06" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn", "unicode-xid", ] @@ -1911,6 +1952,27 @@ dependencies = [ "tt", ] +[[package]] +name = "test-log" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6159ab4116165c99fc88cce31f99fa2c9dbe08d3691cb38da02fc3b45f357d2b" +dependencies = [ + "env_logger", + "test-log-macros", +] + +[[package]] +name = "test-log-macros" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ba277e77219e9eea169e8508942db1bf5d8a41ff2db9b20aab5a5aadc9fa25d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "test-utils" version = "0.0.0" @@ -1954,7 +2016,7 @@ checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn", ] [[package]] @@ -2055,7 +2117,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.39", + "syn", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 3fb5d9aa7a..2b81f7b11b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,6 +70,7 @@ proc-macro-srv = { path = "./crates/proc-macro-srv", version = "0.0.0" } proc-macro-srv-cli = { path = "./crates/proc-macro-srv-cli", version = "0.0.0" } profile = { path = "./crates/profile", version = "0.0.0" } project-model = { path = "./crates/project-model", version = "0.0.0" } +salsa = { path = "./crates/salsa", version = "0.0.0" } span = { path = "./crates/span", version = "0.0.0" } stdx = { path = "./crates/stdx", version = "0.0.0" } syntax = { path = "./crates/syntax", version = "0.0.0" } @@ -83,7 +84,7 @@ ra-ap-rustc_lexer = { version = "0.35.0", default-features = false } ra-ap-rustc_parse_format = { version = "0.35.0", default-features = false } ra-ap-rustc_index = { version = "0.35.0", default-features = false } ra-ap-rustc_abi = { version = "0.35.0", default-features = false } -ra-ap-rustc_pattern_analysis = { version = "0.33.0", default-features = false } +ra-ap-rustc_pattern_analysis = { version = "0.36.0", default-features = false } # local crates that aren't published to crates.io. These should not have versions. sourcegen = { path = "./crates/sourcegen" } @@ -106,22 +107,21 @@ dissimilar = "1.0.7" either = "1.9.0" expect-test = "1.4.0" hashbrown = { version = "0.14", features = [ - "inline-more", + "inline-more", ], default-features = false } indexmap = "2.1.0" itertools = "0.12.0" libc = "0.2.150" nohash-hasher = "0.2.0" rayon = "1.8.0" -rust-analyzer-salsa = "0.17.0-pre.6" rustc-hash = "1.1.0" semver = "1.0.14" serde = { version = "1.0.192", features = ["derive"] } serde_json = "1.0.108" smallvec = { version = "1.10.0", features = [ - "const_new", - "union", - "const_generics", + "const_new", + "union", + "const_generics", ] } smol_str = "0.2.1" text-size = "1.1.1" @@ -164,23 +164,20 @@ len_without_is_empty = "allow" enum_variant_names = "allow" # Builder pattern disagrees new_ret_no_self = "allow" +# Has a bunch of false positives +useless_asref = "allow" ## Following lints should be tackled at some point -borrowed_box = "allow" -derived_hash_with_manual_eq = "allow" -forget_non_drop = "allow" -needless_doctest_main = "allow" -non_canonical_clone_impl = "allow" -non_canonical_partial_ord_impl = "allow" -self_named_constructors = "allow" too_many_arguments = "allow" type_complexity = "allow" wrong_self_convention = "allow" ## warn at following lints +# CI raises these to deny dbg_macro = "warn" todo = "warn" -unimplemented = "allow" +print_stdout = "warn" +print_stderr = "warn" + rc_buffer = "warn" -# FIXME enable this, we use this pattern a lot so its annoying work ... -# str_to_string = "warn" +str_to_string = "warn" diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 0000000000..8032c775ab --- /dev/null +++ b/clippy.toml @@ -0,0 +1,5 @@ +disallowed-types = [ + { path = "std::collections::HashMap", reason = "use FxHashMap" }, + { path = "std::collections::HashSet", reason = "use FxHashSet" }, + { path = "std::collections::hash_map::RandomState", reason = "use BuildHasherDefault"} +] diff --git a/crates/base-db/Cargo.toml b/crates/base-db/Cargo.toml index 485ba78846..801ba2d1f6 100644 --- a/crates/base-db/Cargo.toml +++ b/crates/base-db/Cargo.toml @@ -13,7 +13,7 @@ doctest = false [dependencies] la-arena.workspace = true -rust-analyzer-salsa.workspace = true +salsa.workspace = true rustc-hash.workspace = true triomphe.workspace = true semver.workspace = true diff --git a/crates/base-db/src/input.rs b/crates/base-db/src/input.rs index 51e6fdb951..9560826e37 100644 --- a/crates/base-db/src/input.rs +++ b/crates/base-db/src/input.rs @@ -782,7 +782,7 @@ impl FromStr for Edition { "2018" => Edition::Edition2018, "2021" => Edition::Edition2021, "2024" => Edition::Edition2024, - _ => return Err(ParseEditionError { invalid_input: s.to_string() }), + _ => return Err(ParseEditionError { invalid_input: s.to_owned() }), }; Ok(res) } diff --git a/crates/flycheck/src/lib.rs b/crates/flycheck/src/lib.rs index 22603842a1..c59aff2a8b 100644 --- a/crates/flycheck/src/lib.rs +++ b/crates/flycheck/src/lib.rs @@ -100,9 +100,14 @@ impl FlycheckHandle { FlycheckHandle { id, sender, _thread: thread } } - /// Schedule a re-start of the cargo check worker. - pub fn restart(&self) { - self.sender.send(StateChange::Restart).unwrap(); + /// Schedule a re-start of the cargo check worker to do a workspace wide check. + pub fn restart_workspace(&self) { + self.sender.send(StateChange::Restart(None)).unwrap(); + } + + /// Schedule a re-start of the cargo check worker to do a package wide check. + pub fn restart_for_package(&self, package: String) { + self.sender.send(StateChange::Restart(Some(package))).unwrap(); } /// Stop this cargo check worker. @@ -153,7 +158,7 @@ pub enum Progress { } enum StateChange { - Restart, + Restart(Option), Cancel, } @@ -213,7 +218,7 @@ impl FlycheckActor { tracing::debug!(flycheck_id = self.id, "flycheck cancelled"); self.cancel_check_process(); } - Event::RequestStateChange(StateChange::Restart) => { + Event::RequestStateChange(StateChange::Restart(package)) => { // Cancel the previously spawned process self.cancel_check_process(); while let Ok(restart) = inbox.recv_timeout(Duration::from_millis(50)) { @@ -223,7 +228,7 @@ impl FlycheckActor { } } - let command = self.check_command(); + let command = self.check_command(package.as_deref()); let formatted_command = format!("{:?}", command); tracing::debug!(?command, "will restart flycheck"); @@ -297,7 +302,7 @@ impl FlycheckActor { } } - fn check_command(&self) -> Command { + fn check_command(&self, package: Option<&str>) -> Command { let (mut cmd, args) = match &self.config { FlycheckConfig::CargoCommand { command, @@ -314,7 +319,11 @@ impl FlycheckActor { let mut cmd = Command::new(toolchain::cargo()); cmd.arg(command); cmd.current_dir(&self.root); - cmd.arg("--workspace"); + + match package { + Some(pkg) => cmd.arg("-p").arg(pkg), + None => cmd.arg("--workspace"), + }; cmd.arg(if *ansi_color_output { "--message-format=json-diagnostic-rendered-ansi" @@ -493,9 +502,7 @@ impl CargoActor { // Skip certain kinds of messages to only spend time on what's useful JsonMessage::Cargo(message) => match message { cargo_metadata::Message::CompilerArtifact(artifact) if !artifact.fresh => { - self.sender - .send(CargoMessage::CompilerArtifact(Box::new(artifact))) - .unwrap(); + self.sender.send(CargoMessage::CompilerArtifact(artifact)).unwrap(); } cargo_metadata::Message::CompilerMessage(msg) => { self.sender.send(CargoMessage::Diagnostic(msg.message)).unwrap(); @@ -539,8 +546,9 @@ impl CargoActor { } } +#[allow(clippy::large_enum_variant)] enum CargoMessage { - CompilerArtifact(Box), + CompilerArtifact(cargo_metadata::Artifact), Diagnostic(Diagnostic), } diff --git a/crates/hir-def/src/attr.rs b/crates/hir-def/src/attr.rs index bee6f0083b..c91a549726 100644 --- a/crates/hir-def/src/attr.rs +++ b/crates/hir-def/src/attr.rs @@ -24,11 +24,11 @@ use triomphe::Arc; use crate::{ db::DefDatabase, - item_tree::{AttrOwner, Fields, ItemTreeId, ItemTreeModItemNode}, + item_tree::{AttrOwner, Fields, ItemTreeNode}, lang_item::LangItem, nameres::{ModuleOrigin, ModuleSource}, src::{HasChildSource, HasSource}, - AdtId, AssocItemLoc, AttrDefId, GenericParamId, ItemLoc, LocalFieldId, Lookup, MacroId, + AdtId, AttrDefId, GenericParamId, HasModule, ItemTreeLoc, LocalFieldId, Lookup, MacroId, VariantId, }; @@ -317,7 +317,7 @@ fn parse_comma_sep(subtree: &tt::Subtree) -> Vec { } impl AttrsWithOwner { - pub fn attrs_with_owner(db: &dyn DefDatabase, owner: AttrDefId) -> Self { + pub fn new(db: &dyn DefDatabase, owner: AttrDefId) -> Self { Self { attrs: db.attrs(owner), owner } } @@ -356,12 +356,7 @@ impl AttrsWithOwner { AttrDefId::FieldId(it) => { return db.fields_attrs(it.parent)[it.local_id].clone(); } - // FIXME: DRY this up - AttrDefId::EnumVariantId(it) => { - let id = it.lookup(db).id; - let tree = id.item_tree(db); - tree.raw_attrs(id.value.into()).clone() - } + AttrDefId::EnumVariantId(it) => attrs_from_item_tree_loc(db, it), AttrDefId::AdtId(it) => match it { AdtId::StructId(it) => attrs_from_item_tree_loc(db, it), AdtId::EnumId(it) => attrs_from_item_tree_loc(db, it), @@ -370,15 +365,15 @@ impl AttrsWithOwner { AttrDefId::TraitId(it) => attrs_from_item_tree_loc(db, it), AttrDefId::TraitAliasId(it) => attrs_from_item_tree_loc(db, it), AttrDefId::MacroId(it) => match it { - MacroId::Macro2Id(it) => attrs_from_item_tree(db, it.lookup(db).id), - MacroId::MacroRulesId(it) => attrs_from_item_tree(db, it.lookup(db).id), - MacroId::ProcMacroId(it) => attrs_from_item_tree(db, it.lookup(db).id), + MacroId::Macro2Id(it) => attrs_from_item_tree_loc(db, it), + MacroId::MacroRulesId(it) => attrs_from_item_tree_loc(db, it), + MacroId::ProcMacroId(it) => attrs_from_item_tree_loc(db, it), }, AttrDefId::ImplId(it) => attrs_from_item_tree_loc(db, it), - AttrDefId::ConstId(it) => attrs_from_item_tree_assoc(db, it), - AttrDefId::StaticId(it) => attrs_from_item_tree_assoc(db, it), - AttrDefId::FunctionId(it) => attrs_from_item_tree_assoc(db, it), - AttrDefId::TypeAliasId(it) => attrs_from_item_tree_assoc(db, it), + AttrDefId::ConstId(it) => attrs_from_item_tree_loc(db, it), + AttrDefId::StaticId(it) => attrs_from_item_tree_loc(db, it), + AttrDefId::FunctionId(it) => attrs_from_item_tree_loc(db, it), + AttrDefId::TypeAliasId(it) => attrs_from_item_tree_loc(db, it), AttrDefId::GenericParamId(it) => match it { GenericParamId::ConstParamId(it) => { let src = it.parent().child_source(db); @@ -603,29 +598,14 @@ fn any_has_attrs<'db>( id.lookup(db).source(db).map(ast::AnyHasAttrs::new) } -fn attrs_from_item_tree( - db: &dyn DefDatabase, - id: ItemTreeId, +fn attrs_from_item_tree_loc<'db, N: ItemTreeNode>( + db: &(dyn DefDatabase + 'db), + lookup: impl Lookup = dyn DefDatabase + 'db, Data = impl ItemTreeLoc>, ) -> RawAttrs { + let id = lookup.lookup(db).item_tree_id(); let tree = id.item_tree(db); - let mod_item = N::id_to_mod_item(id.value); - tree.raw_attrs(mod_item.into()).clone() -} - -fn attrs_from_item_tree_loc<'db, N: ItemTreeModItemNode>( - db: &(dyn DefDatabase + 'db), - lookup: impl Lookup = dyn DefDatabase + 'db, Data = ItemLoc>, -) -> RawAttrs { - let id = lookup.lookup(db).id; - attrs_from_item_tree(db, id) -} - -fn attrs_from_item_tree_assoc<'db, N: ItemTreeModItemNode>( - db: &(dyn DefDatabase + 'db), - lookup: impl Lookup = dyn DefDatabase + 'db, Data = AssocItemLoc>, -) -> RawAttrs { - let id = lookup.lookup(db).id; - attrs_from_item_tree(db, id) + let attr_owner = N::attr_owner(id.value); + tree.raw_attrs(attr_owner).clone() } pub(crate) fn fields_attrs_source_map( diff --git a/crates/hir-def/src/attr/builtin.rs b/crates/hir-def/src/attr/builtin.rs index 48a596f7f5..b20ee9e5bf 100644 --- a/crates/hir-def/src/attr/builtin.rs +++ b/crates/hir-def/src/attr/builtin.rs @@ -283,9 +283,6 @@ pub const INERT_ATTRIBUTES: &[BuiltinAttribute] = &[ experimental!(optimize), ), - gated!( - ffi_returns_twice, Normal, template!(Word), WarnFollowing, experimental!(ffi_returns_twice) - ), gated!(ffi_pure, Normal, template!(Word), WarnFollowing, experimental!(ffi_pure)), gated!(ffi_const, Normal, template!(Word), WarnFollowing, experimental!(ffi_const)), gated!( diff --git a/crates/hir-def/src/body/lower.rs b/crates/hir-def/src/body/lower.rs index 492ea6d5c5..29ac666277 100644 --- a/crates/hir-def/src/body/lower.rs +++ b/crates/hir-def/src/body/lower.rs @@ -1980,10 +1980,7 @@ fn pat_literal_to_hir(lit: &ast::LiteralPat) -> Option<(Literal, ast::Literal)> let ast_lit = lit.literal()?; let mut hir_lit: Literal = ast_lit.kind().into(); if lit.minus_token().is_some() { - let Some(h) = hir_lit.negate() else { - return None; - }; - hir_lit = h; + hir_lit = hir_lit.negate()?; } Some((hir_lit, ast_lit)) } diff --git a/crates/hir-def/src/body/pretty.rs b/crates/hir-def/src/body/pretty.rs index b821b91b89..4afb408651 100644 --- a/crates/hir-def/src/body/pretty.rs +++ b/crates/hir-def/src/body/pretty.rs @@ -29,11 +29,11 @@ pub(super) fn print_body_hir(db: &dyn DefDatabase, body: &Body, owner: DefWithBo "const {} = ", match &it.name { Some(name) => name.display(db.upcast()).to_string(), - None => "_".to_string(), + None => "_".to_owned(), } ) }), - DefWithBodyId::InTypeConstId(_) => "In type const = ".to_string(), + DefWithBodyId::InTypeConstId(_) => "In type const = ".to_owned(), DefWithBodyId::VariantId(it) => { let loc = it.lookup(db); let enum_loc = loc.parent.lookup(db); @@ -123,7 +123,7 @@ impl Printer<'_> { wln!(self); f(self); self.indent_level -= 1; - self.buf = self.buf.trim_end_matches('\n').to_string(); + self.buf = self.buf.trim_end_matches('\n').to_owned(); } fn whitespace(&mut self) { diff --git a/crates/hir-def/src/child_by_source.rs b/crates/hir-def/src/child_by_source.rs index b3bb3355f1..ba7d06272a 100644 --- a/crates/hir-def/src/child_by_source.rs +++ b/crates/hir-def/src/child_by_source.rs @@ -6,15 +6,21 @@ use either::Either; use hir_expand::{attrs::collect_attrs, HirFileId}; +use syntax::ast; use crate::{ db::DefDatabase, - dyn_map::{keys, DynMap}, + dyn_map::{ + keys::{self, Key}, + DynMap, + }, item_scope::ItemScope, + item_tree::ItemTreeNode, nameres::DefMap, src::{HasChildSource, HasSource}, - AdtId, AssocItemId, DefWithBodyId, EnumId, ExternCrateId, FieldId, ImplId, Lookup, MacroId, - ModuleDefId, ModuleId, TraitId, UseId, VariantId, + AdtId, AssocItemId, DefWithBodyId, EnumId, FieldId, GenericDefId, ImplId, ItemTreeLoc, + LifetimeParamId, Lookup, MacroId, ModuleDefId, ModuleId, TraitId, TypeOrConstParamId, + VariantId, }; pub trait ChildBySource { @@ -55,29 +61,6 @@ impl ChildBySource for ImplId { } } -fn add_assoc_item(db: &dyn DefDatabase, res: &mut DynMap, file_id: HirFileId, item: AssocItemId) { - match item { - AssocItemId::FunctionId(func) => { - let loc = func.lookup(db); - if loc.id.file_id() == file_id { - res[keys::FUNCTION].insert(loc.source(db).value, func) - } - } - AssocItemId::ConstId(konst) => { - let loc = konst.lookup(db); - if loc.id.file_id() == file_id { - res[keys::CONST].insert(loc.source(db).value, konst) - } - } - AssocItemId::TypeAliasId(ty) => { - let loc = ty.lookup(db); - if loc.id.file_id() == file_id { - res[keys::TYPE_ALIAS].insert(loc.source(db).value, ty) - } - } - } -} - impl ChildBySource for ModuleId { fn child_by_source_to(&self, db: &dyn DefDatabase, res: &mut DynMap, file_id: HirFileId) { let def_map = self.def_map(db); @@ -89,15 +72,12 @@ impl ChildBySource for ModuleId { impl ChildBySource for ItemScope { fn child_by_source_to(&self, db: &dyn DefDatabase, res: &mut DynMap, file_id: HirFileId) { self.declarations().for_each(|item| add_module_def(db, res, file_id, item)); - self.impls().for_each(|imp| add_impl(db, res, file_id, imp)); - self.extern_crate_decls().for_each(|ext| add_extern_crate(db, res, file_id, ext)); - self.use_decls().for_each(|ext| add_use(db, res, file_id, ext)); - self.unnamed_consts(db).for_each(|konst| { - let loc = konst.lookup(db); - if loc.id.file_id() == file_id { - res[keys::CONST].insert(loc.source(db).value, konst); - } - }); + self.impls().for_each(|imp| insert_item_loc(db, res, file_id, imp, keys::IMPL)); + self.extern_crate_decls() + .for_each(|ext| insert_item_loc(db, res, file_id, ext, keys::EXTERN_CRATE)); + self.use_decls().for_each(|ext| insert_item_loc(db, res, file_id, ext, keys::USE)); + self.unnamed_consts(db) + .for_each(|konst| insert_item_loc(db, res, file_id, konst, keys::CONST)); self.attr_macro_invocs().filter(|(id, _)| id.file_id == file_id).for_each( |(ast_id, call_id)| { res[keys::ATTR_MACRO_CALL].insert(ast_id.to_node(db.upcast()), call_id); @@ -132,59 +112,38 @@ impl ChildBySource for ItemScope { file_id: HirFileId, item: ModuleDefId, ) { - macro_rules! insert { - ($map:ident[$key:path].$insert:ident($id:ident)) => {{ - let loc = $id.lookup(db); - if loc.id.file_id() == file_id { - $map[$key].$insert(loc.source(db).value, $id) - } - }}; - } match item { - ModuleDefId::FunctionId(id) => insert!(map[keys::FUNCTION].insert(id)), - ModuleDefId::ConstId(id) => insert!(map[keys::CONST].insert(id)), - ModuleDefId::StaticId(id) => insert!(map[keys::STATIC].insert(id)), - ModuleDefId::TypeAliasId(id) => insert!(map[keys::TYPE_ALIAS].insert(id)), - ModuleDefId::TraitId(id) => insert!(map[keys::TRAIT].insert(id)), - ModuleDefId::TraitAliasId(id) => insert!(map[keys::TRAIT_ALIAS].insert(id)), + ModuleDefId::FunctionId(id) => { + insert_item_loc(db, map, file_id, id, keys::FUNCTION) + } + ModuleDefId::ConstId(id) => insert_item_loc(db, map, file_id, id, keys::CONST), + ModuleDefId::TypeAliasId(id) => { + insert_item_loc(db, map, file_id, id, keys::TYPE_ALIAS) + } + ModuleDefId::StaticId(id) => insert_item_loc(db, map, file_id, id, keys::STATIC), + ModuleDefId::TraitId(id) => insert_item_loc(db, map, file_id, id, keys::TRAIT), + ModuleDefId::TraitAliasId(id) => { + insert_item_loc(db, map, file_id, id, keys::TRAIT_ALIAS) + } ModuleDefId::AdtId(adt) => match adt { - AdtId::StructId(id) => insert!(map[keys::STRUCT].insert(id)), - AdtId::UnionId(id) => insert!(map[keys::UNION].insert(id)), - AdtId::EnumId(id) => insert!(map[keys::ENUM].insert(id)), + AdtId::StructId(id) => insert_item_loc(db, map, file_id, id, keys::STRUCT), + AdtId::UnionId(id) => insert_item_loc(db, map, file_id, id, keys::UNION), + AdtId::EnumId(id) => insert_item_loc(db, map, file_id, id, keys::ENUM), }, ModuleDefId::MacroId(id) => match id { - MacroId::Macro2Id(id) => insert!(map[keys::MACRO2].insert(id)), - MacroId::MacroRulesId(id) => insert!(map[keys::MACRO_RULES].insert(id)), - MacroId::ProcMacroId(id) => insert!(map[keys::PROC_MACRO].insert(id)), + MacroId::Macro2Id(id) => insert_item_loc(db, map, file_id, id, keys::MACRO2), + MacroId::MacroRulesId(id) => { + insert_item_loc(db, map, file_id, id, keys::MACRO_RULES) + } + MacroId::ProcMacroId(id) => { + insert_item_loc(db, map, file_id, id, keys::PROC_MACRO) + } }, ModuleDefId::ModuleId(_) | ModuleDefId::EnumVariantId(_) | ModuleDefId::BuiltinType(_) => (), } } - fn add_impl(db: &dyn DefDatabase, map: &mut DynMap, file_id: HirFileId, imp: ImplId) { - let loc = imp.lookup(db); - if loc.id.file_id() == file_id { - map[keys::IMPL].insert(loc.source(db).value, imp) - } - } - fn add_extern_crate( - db: &dyn DefDatabase, - map: &mut DynMap, - file_id: HirFileId, - ext: ExternCrateId, - ) { - let loc = ext.lookup(db); - if loc.id.file_id() == file_id { - map[keys::EXTERN_CRATE].insert(loc.source(db).value, ext) - } - } - fn add_use(db: &dyn DefDatabase, map: &mut DynMap, file_id: HirFileId, ext: UseId) { - let loc = ext.lookup(db); - if loc.id.file_id() == file_id { - map[keys::USE].insert(loc.source(db).value, ext) - } - } } } @@ -237,3 +196,63 @@ impl ChildBySource for DefWithBodyId { } } } + +impl ChildBySource for GenericDefId { + fn child_by_source_to(&self, db: &dyn DefDatabase, res: &mut DynMap, file_id: HirFileId) { + let (gfile_id, generic_params_list) = self.file_id_and_params_of(db); + if gfile_id != file_id { + return; + } + + let generic_params = db.generic_params(*self); + let mut toc_idx_iter = generic_params.type_or_consts.iter().map(|(idx, _)| idx); + let lts_idx_iter = generic_params.lifetimes.iter().map(|(idx, _)| idx); + + // For traits the first type index is `Self`, skip it. + if let GenericDefId::TraitId(_) = *self { + toc_idx_iter.next().unwrap(); // advance_by(1); + } + + if let Some(generic_params_list) = generic_params_list { + for (local_id, ast_param) in + toc_idx_iter.zip(generic_params_list.type_or_const_params()) + { + let id = TypeOrConstParamId { parent: *self, local_id }; + match ast_param { + ast::TypeOrConstParam::Type(a) => res[keys::TYPE_PARAM].insert(a, id), + ast::TypeOrConstParam::Const(a) => res[keys::CONST_PARAM].insert(a, id), + } + } + for (local_id, ast_param) in lts_idx_iter.zip(generic_params_list.lifetime_params()) { + let id = LifetimeParamId { parent: *self, local_id }; + res[keys::LIFETIME_PARAM].insert(ast_param, id); + } + } + } +} + +fn insert_item_loc( + db: &dyn DefDatabase, + res: &mut DynMap, + file_id: HirFileId, + id: ID, + key: Key, +) where + ID: for<'db> Lookup = dyn DefDatabase + 'db, Data = Data> + 'static, + Data: ItemTreeLoc, + N: ItemTreeNode, + N::Source: 'static, +{ + let loc = id.lookup(db); + if loc.item_tree_id().file_id() == file_id { + res[key].insert(loc.source(db).value, id) + } +} + +fn add_assoc_item(db: &dyn DefDatabase, res: &mut DynMap, file_id: HirFileId, item: AssocItemId) { + match item { + AssocItemId::FunctionId(func) => insert_item_loc(db, res, file_id, func, keys::FUNCTION), + AssocItemId::ConstId(konst) => insert_item_loc(db, res, file_id, konst, keys::CONST), + AssocItemId::TypeAliasId(ty) => insert_item_loc(db, res, file_id, ty, keys::TYPE_ALIAS), + } +} diff --git a/crates/hir-def/src/data/adt.rs b/crates/hir-def/src/data/adt.rs index 5986b7df3d..540f643ae7 100644 --- a/crates/hir-def/src/data/adt.rs +++ b/crates/hir-def/src/data/adt.rs @@ -10,7 +10,7 @@ use hir_expand::{ HirFileId, InFile, }; use intern::Interned; -use la_arena::{Arena, ArenaMap}; +use la_arena::Arena; use rustc_abi::{Align, Integer, IntegerType, ReprFlags, ReprOptions}; use syntax::ast::{self, HasName, HasVisibility}; use triomphe::Arc; @@ -22,13 +22,11 @@ use crate::{ lang_item::LangItem, lower::LowerCtx, nameres::diagnostics::{DefDiagnostic, DefDiagnostics}, - src::HasChildSource, - src::HasSource, trace::Trace, tt::{Delimiter, DelimiterKind, Leaf, Subtree, TokenTree}, type_ref::TypeRef, visibility::RawVisibility, - EnumId, EnumVariantId, LocalFieldId, LocalModuleId, Lookup, StructId, UnionId, VariantId, + EnumId, EnumVariantId, LocalFieldId, LocalModuleId, Lookup, StructId, UnionId, }; /// Note that we use `StructData` for unions as well! @@ -387,46 +385,6 @@ impl VariantData { } } -impl HasChildSource for VariantId { - type Value = Either; - - fn child_source(&self, db: &dyn DefDatabase) -> InFile> { - let item_tree; - let (src, fields, container) = match *self { - VariantId::EnumVariantId(it) => { - let lookup = it.lookup(db); - item_tree = lookup.id.item_tree(db); - ( - lookup.source(db).map(|it| it.kind()), - &item_tree[lookup.id.value].fields, - lookup.parent.lookup(db).container, - ) - } - VariantId::StructId(it) => { - let lookup = it.lookup(db); - item_tree = lookup.id.item_tree(db); - ( - lookup.source(db).map(|it| it.kind()), - &item_tree[lookup.id.value].fields, - lookup.container, - ) - } - VariantId::UnionId(it) => { - let lookup = it.lookup(db); - item_tree = lookup.id.item_tree(db); - ( - lookup.source(db).map(|it| it.kind()), - &item_tree[lookup.id.value].fields, - lookup.container, - ) - } - }; - let mut trace = Trace::new_for_map(); - lower_struct(db, &mut trace, &src, container.krate, &item_tree, fields); - src.with_value(trace.into_map()) - } -} - #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum StructKind { Tuple, @@ -434,7 +392,7 @@ pub enum StructKind { Unit, } -fn lower_struct( +pub(crate) fn lower_struct( db: &dyn DefDatabase, trace: &mut Trace>, ast: &InFile, diff --git a/crates/hir-def/src/find_path.rs b/crates/hir-def/src/find_path.rs index 1a44c319de..2e137f67b4 100644 --- a/crates/hir-def/src/find_path.rs +++ b/crates/hir-def/src/find_path.rs @@ -171,6 +171,7 @@ fn find_path_inner(ctx: FindPathCtx<'_>, item: ItemInNs, from: ModuleId) -> Opti .map(|(item, _)| item) } +#[tracing::instrument(skip_all)] fn find_path_for_module( ctx: FindPathCtx<'_>, def_map: &DefMap, @@ -312,6 +313,7 @@ fn find_self_super(def_map: &DefMap, item: ModuleId, from: ModuleId) -> Option, def_map: &DefMap, diff --git a/crates/hir-def/src/generics.rs b/crates/hir-def/src/generics.rs index 349d327aaa..1d2c7c3a55 100644 --- a/crates/hir-def/src/generics.rs +++ b/crates/hir-def/src/generics.rs @@ -3,31 +3,27 @@ //! generic parameters. See also the `Generics` type and the `generics_of` query //! in rustc. -use base_db::FileId; use either::Either; use hir_expand::{ name::{AsName, Name}, - ExpandResult, HirFileId, InFile, + ExpandResult, }; use intern::Interned; -use la_arena::{Arena, ArenaMap, Idx}; +use la_arena::{Arena, Idx}; use once_cell::unsync::Lazy; use stdx::impl_from; use syntax::ast::{self, HasGenericParams, HasName, HasTypeBounds}; use triomphe::Arc; use crate::{ - child_by_source::ChildBySource, db::DefDatabase, - dyn_map::{keys, DynMap}, expander::Expander, - item_tree::ItemTree, + item_tree::{GenericsItemTreeNode, ItemTree}, lower::LowerCtx, nameres::{DefMap, MacroSubNs}, - src::{HasChildSource, HasSource}, type_ref::{ConstRef, LifetimeRef, TypeBound, TypeRef}, - AdtId, ConstParamId, GenericDefId, HasModule, LifetimeParamId, LocalLifetimeParamId, - LocalTypeOrConstParamId, Lookup, TypeOrConstParamId, TypeParamId, + AdtId, ConstParamId, GenericDefId, HasModule, ItemTreeLoc, LocalTypeOrConstParamId, Lookup, + TypeOrConstParamId, TypeParamId, }; /// Data about a generic type parameter (to a function, struct, impl, ...). @@ -264,7 +260,7 @@ impl GenericParamsCollector { self.add_where_predicate_from_bound( lower_ctx, bound, - lifetimes.as_ref(), + lifetimes.as_deref(), target.clone(), ); } @@ -275,14 +271,14 @@ impl GenericParamsCollector { &mut self, lower_ctx: &LowerCtx<'_>, bound: ast::TypeBound, - hrtb_lifetimes: Option<&Box<[Name]>>, + hrtb_lifetimes: Option<&[Name]>, target: Either, ) { let bound = TypeBound::from_ast(lower_ctx, bound); let predicate = match (target, bound) { (Either::Left(type_ref), bound) => match hrtb_lifetimes { Some(hrtb_lifetimes) => WherePredicate::ForLifetime { - lifetimes: hrtb_lifetimes.clone(), + lifetimes: hrtb_lifetimes.to_vec().into_boxed_slice(), target: WherePredicateTypeTarget::TypeRef(Interned::new(type_ref)), bound: Interned::new(bound), }, @@ -418,13 +414,18 @@ impl GenericParams { }) } }; - macro_rules! id_to_generics { - ($id:ident) => {{ - let id = $id.lookup(db).id; - let tree = id.item_tree(db); - let item = &tree[id.value]; - enabled_params(&item.generic_params, &tree) - }}; + fn id_to_generics( + db: &dyn DefDatabase, + id: impl for<'db> Lookup< + Database<'db> = dyn DefDatabase + 'db, + Data = impl ItemTreeLoc, + >, + enabled_params: impl Fn(&Interned, &ItemTree) -> Interned, + ) -> Interned { + let id = id.lookup(db).item_tree_id(); + let tree = id.item_tree(db); + let item = &tree[id.value]; + enabled_params(item.generic_params(), &tree) } match def { @@ -457,13 +458,13 @@ impl GenericParams { Interned::new(generic_params.finish()) } } - GenericDefId::AdtId(AdtId::StructId(id)) => id_to_generics!(id), - GenericDefId::AdtId(AdtId::EnumId(id)) => id_to_generics!(id), - GenericDefId::AdtId(AdtId::UnionId(id)) => id_to_generics!(id), - GenericDefId::TraitId(id) => id_to_generics!(id), - GenericDefId::TraitAliasId(id) => id_to_generics!(id), - GenericDefId::TypeAliasId(id) => id_to_generics!(id), - GenericDefId::ImplId(id) => id_to_generics!(id), + GenericDefId::AdtId(AdtId::StructId(id)) => id_to_generics(db, id, enabled_params), + GenericDefId::AdtId(AdtId::EnumId(id)) => id_to_generics(db, id, enabled_params), + GenericDefId::AdtId(AdtId::UnionId(id)) => id_to_generics(db, id, enabled_params), + GenericDefId::TraitId(id) => id_to_generics(db, id, enabled_params), + GenericDefId::TraitAliasId(id) => id_to_generics(db, id, enabled_params), + GenericDefId::TypeAliasId(id) => id_to_generics(db, id, enabled_params), + GenericDefId::ImplId(id) => id_to_generics(db, id, enabled_params), GenericDefId::EnumVariantId(_) | GenericDefId::ConstId(_) => { Interned::new(GenericParams { type_or_consts: Default::default(), @@ -507,141 +508,3 @@ impl GenericParams { }) } } - -fn file_id_and_params_of( - def: GenericDefId, - db: &dyn DefDatabase, -) -> (HirFileId, Option) { - match def { - GenericDefId::FunctionId(it) => { - let src = it.lookup(db).source(db); - (src.file_id, src.value.generic_param_list()) - } - GenericDefId::AdtId(AdtId::StructId(it)) => { - let src = it.lookup(db).source(db); - (src.file_id, src.value.generic_param_list()) - } - GenericDefId::AdtId(AdtId::UnionId(it)) => { - let src = it.lookup(db).source(db); - (src.file_id, src.value.generic_param_list()) - } - GenericDefId::AdtId(AdtId::EnumId(it)) => { - let src = it.lookup(db).source(db); - (src.file_id, src.value.generic_param_list()) - } - GenericDefId::TraitId(it) => { - let src = it.lookup(db).source(db); - (src.file_id, src.value.generic_param_list()) - } - GenericDefId::TraitAliasId(it) => { - let src = it.lookup(db).source(db); - (src.file_id, src.value.generic_param_list()) - } - GenericDefId::TypeAliasId(it) => { - let src = it.lookup(db).source(db); - (src.file_id, src.value.generic_param_list()) - } - GenericDefId::ImplId(it) => { - let src = it.lookup(db).source(db); - (src.file_id, src.value.generic_param_list()) - } - // We won't be using this ID anyway - GenericDefId::EnumVariantId(_) | GenericDefId::ConstId(_) => (FileId::BOGUS.into(), None), - } -} - -impl HasChildSource for GenericDefId { - type Value = Either; - fn child_source( - &self, - db: &dyn DefDatabase, - ) -> InFile> { - let generic_params = db.generic_params(*self); - let mut idx_iter = generic_params.type_or_consts.iter().map(|(idx, _)| idx); - - let (file_id, generic_params_list) = file_id_and_params_of(*self, db); - - let mut params = ArenaMap::default(); - - // For traits and trait aliases the first type index is `Self`, we need to add it before - // the other params. - match *self { - GenericDefId::TraitId(id) => { - let trait_ref = id.lookup(db).source(db).value; - let idx = idx_iter.next().unwrap(); - params.insert(idx, Either::Right(ast::TraitOrAlias::Trait(trait_ref))); - } - GenericDefId::TraitAliasId(id) => { - let alias = id.lookup(db).source(db).value; - let idx = idx_iter.next().unwrap(); - params.insert(idx, Either::Right(ast::TraitOrAlias::TraitAlias(alias))); - } - _ => {} - } - - if let Some(generic_params_list) = generic_params_list { - for (idx, ast_param) in idx_iter.zip(generic_params_list.type_or_const_params()) { - params.insert(idx, Either::Left(ast_param)); - } - } - - InFile::new(file_id, params) - } -} - -impl HasChildSource for GenericDefId { - type Value = ast::LifetimeParam; - fn child_source( - &self, - db: &dyn DefDatabase, - ) -> InFile> { - let generic_params = db.generic_params(*self); - let idx_iter = generic_params.lifetimes.iter().map(|(idx, _)| idx); - - let (file_id, generic_params_list) = file_id_and_params_of(*self, db); - - let mut params = ArenaMap::default(); - - if let Some(generic_params_list) = generic_params_list { - for (idx, ast_param) in idx_iter.zip(generic_params_list.lifetime_params()) { - params.insert(idx, ast_param); - } - } - - InFile::new(file_id, params) - } -} - -impl ChildBySource for GenericDefId { - fn child_by_source_to(&self, db: &dyn DefDatabase, res: &mut DynMap, file_id: HirFileId) { - let (gfile_id, generic_params_list) = file_id_and_params_of(*self, db); - if gfile_id != file_id { - return; - } - - let generic_params = db.generic_params(*self); - let mut toc_idx_iter = generic_params.type_or_consts.iter().map(|(idx, _)| idx); - let lts_idx_iter = generic_params.lifetimes.iter().map(|(idx, _)| idx); - - // For traits the first type index is `Self`, skip it. - if let GenericDefId::TraitId(_) = *self { - toc_idx_iter.next().unwrap(); // advance_by(1); - } - - if let Some(generic_params_list) = generic_params_list { - for (local_id, ast_param) in - toc_idx_iter.zip(generic_params_list.type_or_const_params()) - { - let id = TypeOrConstParamId { parent: *self, local_id }; - match ast_param { - ast::TypeOrConstParam::Type(a) => res[keys::TYPE_PARAM].insert(a, id), - ast::TypeOrConstParam::Const(a) => res[keys::CONST_PARAM].insert(a, id), - } - } - for (local_id, ast_param) in lts_idx_iter.zip(generic_params_list.lifetime_params()) { - let id = LifetimeParamId { parent: *self, local_id }; - res[keys::LIFETIME_PARAM].insert(ast_param, id); - } - } - } -} diff --git a/crates/hir-def/src/import_map.rs b/crates/hir-def/src/import_map.rs index c698510ca9..98982c7db8 100644 --- a/crates/hir-def/src/import_map.rs +++ b/crates/hir-def/src/import_map.rs @@ -859,7 +859,7 @@ mod tests { check_search( ra_fixture, "main", - Query::new("fmt".to_string()).fuzzy(), + Query::new("fmt".to_owned()).fuzzy(), expect![[r#" dep::fmt (t) dep::fmt::Display::FMT_CONST (a) @@ -888,9 +888,7 @@ mod tests { check_search( ra_fixture, "main", - Query::new("fmt".to_string()) - .fuzzy() - .assoc_search_mode(AssocSearchMode::AssocItemsOnly), + Query::new("fmt".to_owned()).fuzzy().assoc_search_mode(AssocSearchMode::AssocItemsOnly), expect![[r#" dep::fmt::Display::FMT_CONST (a) dep::fmt::Display::format_function (a) @@ -901,7 +899,7 @@ mod tests { check_search( ra_fixture, "main", - Query::new("fmt".to_string()).fuzzy().assoc_search_mode(AssocSearchMode::Exclude), + Query::new("fmt".to_owned()).fuzzy().assoc_search_mode(AssocSearchMode::Exclude), expect![[r#" dep::fmt (t) "#]], @@ -937,7 +935,7 @@ pub mod fmt { check_search( ra_fixture, "main", - Query::new("fmt".to_string()).fuzzy(), + Query::new("fmt".to_owned()).fuzzy(), expect![[r#" dep::Fmt (m) dep::Fmt (t) @@ -951,7 +949,7 @@ pub mod fmt { check_search( ra_fixture, "main", - Query::new("fmt".to_string()), + Query::new("fmt".to_owned()), expect![[r#" dep::Fmt (m) dep::Fmt (t) @@ -991,7 +989,7 @@ pub mod fmt { check_search( ra_fixture, "main", - Query::new("fmt".to_string()), + Query::new("fmt".to_owned()), expect![[r#" dep::Fmt (m) dep::Fmt (t) @@ -1015,7 +1013,7 @@ pub mod fmt { check_search( ra_fixture, "main", - Query::new("FMT".to_string()), + Query::new("FMT".to_owned()), expect![[r#" dep::FMT (t) dep::FMT (v) @@ -1027,7 +1025,7 @@ pub mod fmt { check_search( ra_fixture, "main", - Query::new("FMT".to_string()).case_sensitive(), + Query::new("FMT".to_owned()).case_sensitive(), expect![[r#" dep::FMT (t) dep::FMT (v) diff --git a/crates/hir-def/src/item_scope.rs b/crates/hir-def/src/item_scope.rs index 0b0c838bed..0e6826a75a 100644 --- a/crates/hir-def/src/item_scope.rs +++ b/crates/hir-def/src/item_scope.rs @@ -222,17 +222,15 @@ impl ItemScope { self.declarations.iter().copied() } - pub fn extern_crate_decls( - &self, - ) -> impl Iterator + ExactSizeIterator + '_ { + pub fn extern_crate_decls(&self) -> impl ExactSizeIterator + '_ { self.extern_crate_decls.iter().copied() } - pub fn use_decls(&self) -> impl Iterator + ExactSizeIterator + '_ { + pub fn use_decls(&self) -> impl ExactSizeIterator + '_ { self.use_decls.iter().copied() } - pub fn impls(&self) -> impl Iterator + ExactSizeIterator + '_ { + pub fn impls(&self) -> impl ExactSizeIterator + '_ { self.impls.iter().copied() } @@ -674,7 +672,7 @@ impl ItemScope { format_to!( buf, "{}:", - name.map_or("_".to_string(), |name| name.display(db).to_string()) + name.map_or("_".to_owned(), |name| name.display(db).to_string()) ); if let Some((.., i)) = def.types { diff --git a/crates/hir-def/src/item_tree.rs b/crates/hir-def/src/item_tree.rs index 336e0de7fd..be16a5e31a 100644 --- a/crates/hir-def/src/item_tree.rs +++ b/crates/hir-def/src/item_tree.rs @@ -337,20 +337,18 @@ from_attrs!( LifetimeParamData(Idx), ); -/// Trait implemented by all item nodes in the item tree. -pub trait ItemTreeModItemNode: Clone { - type Source: AstIdNode + Into; +/// Trait implemented by all nodes in the item tree. +pub trait ItemTreeNode: Clone { + type Source: AstIdNode; fn ast_id(&self) -> FileAstId; /// Looks up an instance of `Self` in an item tree. fn lookup(tree: &ItemTree, index: Idx) -> &Self; - - /// Downcasts a `ModItem` to a `FileItemTreeId` specific to this type. - fn id_from_mod_item(mod_item: ModItem) -> Option>; - - /// Upcasts a `FileItemTreeId` to a generic `ModItem`. - fn id_to_mod_item(id: FileItemTreeId) -> ModItem; + fn attr_owner(id: FileItemTreeId) -> AttrOwner; +} +pub trait GenericsItemTreeNode: ItemTreeNode { + fn generic_params(&self) -> &Interned; } pub struct FileItemTreeId(Idx); @@ -372,7 +370,7 @@ impl FileItemTreeId { impl Clone for FileItemTreeId { fn clone(&self) -> Self { - Self(self.0) + *self } } impl Copy for FileItemTreeId {} @@ -478,7 +476,7 @@ impl Hash for ItemTreeId { } macro_rules! mod_items { - ( $( $typ:ident in $fld:ident -> $ast:ty ),+ $(,)? ) => { + ( $( $typ:ident $(<$generic_params:ident>)? in $fld:ident -> $ast:ty ),+ $(,)? ) => { #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] pub enum ModItem { $( @@ -495,7 +493,7 @@ macro_rules! mod_items { )+ $( - impl ItemTreeModItemNode for $typ { + impl ItemTreeNode for $typ { type Source = $ast; fn ast_id(&self) -> FileAstId { @@ -506,15 +504,8 @@ macro_rules! mod_items { &tree.data().$fld[index] } - fn id_from_mod_item(mod_item: ModItem) -> Option> { - match mod_item { - ModItem::$typ(id) => Some(id), - _ => None, - } - } - - fn id_to_mod_item(id: FileItemTreeId) -> ModItem { - ModItem::$typ(id) + fn attr_owner(id: FileItemTreeId) -> AttrOwner { + AttrOwner::ModItem(ModItem::$typ(id)) } } @@ -525,6 +516,14 @@ macro_rules! mod_items { &self.data().$fld[index] } } + + $( + impl GenericsItemTreeNode for $typ { + fn generic_params(&self) -> &Interned { + &self.$generic_params + } + } + )? )+ }; } @@ -533,16 +532,16 @@ mod_items! { Use in uses -> ast::Use, ExternCrate in extern_crates -> ast::ExternCrate, ExternBlock in extern_blocks -> ast::ExternBlock, - Function in functions -> ast::Fn, - Struct in structs -> ast::Struct, - Union in unions -> ast::Union, - Enum in enums -> ast::Enum, + Function in functions -> ast::Fn, + Struct in structs -> ast::Struct, + Union in unions -> ast::Union, + Enum in enums -> ast::Enum, Const in consts -> ast::Const, Static in statics -> ast::Static, - Trait in traits -> ast::Trait, - TraitAlias in trait_aliases -> ast::TraitAlias, - Impl in impls -> ast::Impl, - TypeAlias in type_aliases -> ast::TypeAlias, + Trait in traits -> ast::Trait, + TraitAlias in trait_aliases -> ast::TraitAlias, + Impl in impls -> ast::Impl, + TypeAlias in type_aliases -> ast::TypeAlias, Mod in mods -> ast::Module, MacroCall in macro_calls -> ast::MacroCall, MacroRules in macro_rules -> ast::MacroRules, @@ -578,17 +577,26 @@ impl Index for ItemTree { } } -impl Index> for ItemTree { +impl Index> for ItemTree { type Output = N; fn index(&self, id: FileItemTreeId) -> &N { N::lookup(self, id.index()) } } -impl Index> for ItemTree { - type Output = Variant; - fn index(&self, id: FileItemTreeId) -> &Variant { - &self[id.index()] +impl ItemTreeNode for Variant { + type Source = ast::Variant; + + fn ast_id(&self) -> FileAstId { + self.ast_id + } + + fn lookup(tree: &ItemTree, index: Idx) -> &Self { + &tree.data().variants[index] + } + + fn attr_owner(id: FileItemTreeId) -> AttrOwner { + AttrOwner::Variant(id) } } @@ -1027,7 +1035,7 @@ impl AssocItem { } } -#[derive(Debug, Eq, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct Variant { pub name: Name, pub fields: Fields, diff --git a/crates/hir-def/src/item_tree/lower.rs b/crates/hir-def/src/item_tree/lower.rs index b500f56b6c..e0aa3ae612 100644 --- a/crates/hir-def/src/item_tree/lower.rs +++ b/crates/hir-def/src/item_tree/lower.rs @@ -13,7 +13,7 @@ use crate::{ use super::*; -fn id(index: Idx) -> FileItemTreeId { +fn id(index: Idx) -> FileItemTreeId { FileItemTreeId(index) } @@ -267,7 +267,7 @@ impl<'a> Ctx<'a> { if let Some(data) = self.lower_variant(&variant) { let idx = self.data().variants.alloc(data); self.add_attrs( - FileItemTreeId(idx).into(), + id(idx).into(), RawAttrs::new(self.db.upcast(), &variant, self.span_map()), ); } @@ -658,7 +658,7 @@ impl<'a> Ctx<'a> { fn lower_visibility(&mut self, item: &dyn ast::HasVisibility) -> RawVisibilityId { let vis = - RawVisibility::from_ast_with_span_map(self.db, item.visibility(), self.span_map()); + RawVisibility::from_opt_ast_with_span_map(self.db, item.visibility(), self.span_map()); self.data().vis.alloc(vis) } diff --git a/crates/hir-def/src/item_tree/pretty.rs b/crates/hir-def/src/item_tree/pretty.rs index 520034d213..0086b7180b 100644 --- a/crates/hir-def/src/item_tree/pretty.rs +++ b/crates/hir-def/src/item_tree/pretty.rs @@ -24,7 +24,7 @@ pub(super) fn print_item_tree(db: &dyn DefDatabase, tree: &ItemTree) -> String { p.print_mod_item(*item); } - let mut s = p.buf.trim_end_matches('\n').to_string(); + let mut s = p.buf.trim_end_matches('\n').to_owned(); s.push('\n'); s } @@ -58,7 +58,7 @@ impl Printer<'_> { wln!(self); f(self); self.indent_level -= 1; - self.buf = self.buf.trim_end_matches('\n').to_string(); + self.buf = self.buf.trim_end_matches('\n').to_owned(); } /// Ensures that a blank line is output before the next text. diff --git a/crates/hir-def/src/lib.rs b/crates/hir-def/src/lib.rs index 71bc521333..5670ebfa17 100644 --- a/crates/hir-def/src/lib.rs +++ b/crates/hir-def/src/lib.rs @@ -70,7 +70,11 @@ use std::{ panic::{RefUnwindSafe, UnwindSafe}, }; -use base_db::{impl_intern_key, salsa, CrateId, Edition}; +use base_db::{ + impl_intern_key, + salsa::{self, impl_intern_value_trivial}, + CrateId, Edition, +}; use hir_expand::{ ast_id_map::{AstIdNode, FileAstId}, builtin_attr_macro::BuiltinAttrExpander, @@ -87,7 +91,7 @@ use hir_expand::{ use item_tree::ExternBlock; use la_arena::Idx; use nameres::DefMap; -use span::Span; +use span::{FileId, Span}; use stdx::impl_from; use syntax::{ast, AstNode}; @@ -98,11 +102,268 @@ use crate::{ data::adt::VariantData, db::DefDatabase, item_tree::{ - Const, Enum, ExternCrate, Function, Impl, ItemTreeId, ItemTreeModItemNode, Macro2, - MacroRules, Static, Struct, Trait, TraitAlias, TypeAlias, Union, Use, Variant, + Const, Enum, ExternCrate, Function, Impl, ItemTreeId, ItemTreeNode, Macro2, MacroRules, + Static, Struct, Trait, TraitAlias, TypeAlias, Union, Use, Variant, }, }; +#[derive(Debug)] +pub struct ItemLoc { + pub container: ModuleId, + pub id: ItemTreeId, +} + +impl Clone for ItemLoc { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for ItemLoc {} + +impl PartialEq for ItemLoc { + fn eq(&self, other: &Self) -> bool { + self.container == other.container && self.id == other.id + } +} + +impl Eq for ItemLoc {} + +impl Hash for ItemLoc { + fn hash(&self, state: &mut H) { + self.container.hash(state); + self.id.hash(state); + } +} + +#[derive(Debug)] +pub struct AssocItemLoc { + pub container: ItemContainerId, + pub id: ItemTreeId, +} + +impl Clone for AssocItemLoc { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for AssocItemLoc {} + +impl PartialEq for AssocItemLoc { + fn eq(&self, other: &Self) -> bool { + self.container == other.container && self.id == other.id + } +} + +impl Eq for AssocItemLoc {} + +impl Hash for AssocItemLoc { + fn hash(&self, state: &mut H) { + self.container.hash(state); + self.id.hash(state); + } +} + +pub trait ItemTreeLoc { + type Container; + type Id; + fn item_tree_id(&self) -> ItemTreeId; + fn container(&self) -> Self::Container; +} + +macro_rules! impl_intern { + ($id:ident, $loc:ident, $intern:ident, $lookup:ident) => { + impl_intern_key!($id); + impl_intern_value_trivial!($loc); + impl_intern_lookup!(DefDatabase, $id, $loc, $intern, $lookup); + }; +} + +macro_rules! impl_loc { + ($loc:ident, $id:ident: $id_ty:ident, $container:ident: $container_type:ident) => { + impl ItemTreeLoc for $loc { + type Container = $container_type; + type Id = $id_ty; + fn item_tree_id(&self) -> ItemTreeId { + self.$id + } + fn container(&self) -> Self::Container { + self.$container + } + } + }; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct FunctionId(salsa::InternId); +type FunctionLoc = AssocItemLoc; +impl_intern!(FunctionId, FunctionLoc, intern_function, lookup_intern_function); +impl_loc!(FunctionLoc, id: Function, container: ItemContainerId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct StructId(salsa::InternId); +type StructLoc = ItemLoc; +impl_intern!(StructId, StructLoc, intern_struct, lookup_intern_struct); +impl_loc!(StructLoc, id: Struct, container: ModuleId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct UnionId(salsa::InternId); +pub type UnionLoc = ItemLoc; +impl_intern!(UnionId, UnionLoc, intern_union, lookup_intern_union); +impl_loc!(UnionLoc, id: Union, container: ModuleId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct EnumId(salsa::InternId); +pub type EnumLoc = ItemLoc; +impl_intern!(EnumId, EnumLoc, intern_enum, lookup_intern_enum); +impl_loc!(EnumLoc, id: Enum, container: ModuleId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ConstId(salsa::InternId); +type ConstLoc = AssocItemLoc; +impl_intern!(ConstId, ConstLoc, intern_const, lookup_intern_const); +impl_loc!(ConstLoc, id: Const, container: ItemContainerId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct StaticId(salsa::InternId); +pub type StaticLoc = AssocItemLoc; +impl_intern!(StaticId, StaticLoc, intern_static, lookup_intern_static); +impl_loc!(StaticLoc, id: Static, container: ItemContainerId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct TraitId(salsa::InternId); +pub type TraitLoc = ItemLoc; +impl_intern!(TraitId, TraitLoc, intern_trait, lookup_intern_trait); +impl_loc!(TraitLoc, id: Trait, container: ModuleId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct TraitAliasId(salsa::InternId); +pub type TraitAliasLoc = ItemLoc; +impl_intern!(TraitAliasId, TraitAliasLoc, intern_trait_alias, lookup_intern_trait_alias); +impl_loc!(TraitAliasLoc, id: TraitAlias, container: ModuleId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct TypeAliasId(salsa::InternId); +type TypeAliasLoc = AssocItemLoc; +impl_intern!(TypeAliasId, TypeAliasLoc, intern_type_alias, lookup_intern_type_alias); +impl_loc!(TypeAliasLoc, id: TypeAlias, container: ItemContainerId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct ImplId(salsa::InternId); +type ImplLoc = ItemLoc; +impl_intern!(ImplId, ImplLoc, intern_impl, lookup_intern_impl); +impl_loc!(ImplLoc, id: Impl, container: ModuleId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct UseId(salsa::InternId); +type UseLoc = ItemLoc; +impl_intern!(UseId, UseLoc, intern_use, lookup_intern_use); +impl_loc!(UseLoc, id: Use, container: ModuleId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct ExternCrateId(salsa::InternId); +type ExternCrateLoc = ItemLoc; +impl_intern!(ExternCrateId, ExternCrateLoc, intern_extern_crate, lookup_intern_extern_crate); +impl_loc!(ExternCrateLoc, id: ExternCrate, container: ModuleId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct ExternBlockId(salsa::InternId); +type ExternBlockLoc = ItemLoc; +impl_intern!(ExternBlockId, ExternBlockLoc, intern_extern_block, lookup_intern_extern_block); +impl_loc!(ExternBlockLoc, id: ExternBlock, container: ModuleId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct EnumVariantId(salsa::InternId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct EnumVariantLoc { + pub id: ItemTreeId, + pub parent: EnumId, + pub index: u32, +} +impl_intern!(EnumVariantId, EnumVariantLoc, intern_enum_variant, lookup_intern_enum_variant); +impl_loc!(EnumVariantLoc, id: Variant, parent: EnumId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct Macro2Id(salsa::InternId); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Macro2Loc { + pub container: ModuleId, + pub id: ItemTreeId, + pub expander: MacroExpander, + pub allow_internal_unsafe: bool, + pub edition: Edition, +} +impl_intern!(Macro2Id, Macro2Loc, intern_macro2, lookup_intern_macro2); +impl_loc!(Macro2Loc, id: Macro2, container: ModuleId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct MacroRulesId(salsa::InternId); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct MacroRulesLoc { + pub container: ModuleId, + pub id: ItemTreeId, + pub expander: MacroExpander, + pub flags: MacroRulesLocFlags, + pub edition: Edition, +} +impl_intern!(MacroRulesId, MacroRulesLoc, intern_macro_rules, lookup_intern_macro_rules); +impl_loc!(MacroRulesLoc, id: MacroRules, container: ModuleId); + +bitflags::bitflags! { + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + pub struct MacroRulesLocFlags: u8 { + const ALLOW_INTERNAL_UNSAFE = 1 << 0; + const LOCAL_INNER = 1 << 1; + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum MacroExpander { + Declarative, + BuiltIn(BuiltinFnLikeExpander), + BuiltInAttr(BuiltinAttrExpander), + BuiltInDerive(BuiltinDeriveExpander), + BuiltInEager(EagerExpander), +} +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct ProcMacroId(salsa::InternId); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct ProcMacroLoc { + pub container: CrateRootModuleId, + pub id: ItemTreeId, + pub expander: CustomProcMacroExpander, + pub kind: ProcMacroKind, + pub edition: Edition, +} +impl_intern!(ProcMacroId, ProcMacroLoc, intern_proc_macro, lookup_intern_proc_macro); +impl_loc!(ProcMacroLoc, id: Function, container: CrateRootModuleId); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct BlockId(salsa::InternId); +#[derive(Debug, Hash, PartialEq, Eq, Clone)] +pub struct BlockLoc { + ast_id: AstId, + /// The containing module. + module: ModuleId, +} +impl_intern!(BlockId, BlockLoc, intern_block, lookup_intern_block); + +/// Id of the anonymous const block expression and patterns. This is very similar to `ClosureId` and +/// shouldn't be a `DefWithBodyId` since its type inference is dependent on its parent. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] +pub struct ConstBlockId(salsa::InternId); +impl_intern!(ConstBlockId, ConstBlockLoc, intern_anonymous_const, lookup_intern_anonymous_const); + +#[derive(Debug, Hash, PartialEq, Eq, Clone)] +pub struct ConstBlockLoc { + /// The parent of the anonymous const block. + pub parent: DefWithBodyId, + /// The root expression of this const block in the parent body. + pub root: hir::ExprId, +} + /// A `ModuleId` that is always a crate's root module. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct CrateRootModuleId { @@ -124,23 +385,6 @@ impl PartialEq for CrateRootModuleId { other.block.is_none() && other.local_id == DefMap::ROOT && self.krate == other.krate } } -impl PartialEq for ModuleId { - fn eq(&self, other: &CrateRootModuleId) -> bool { - other == self - } -} - -impl From for ModuleId { - fn from(CrateRootModuleId { krate }: CrateRootModuleId) -> Self { - ModuleId { krate, block: None, local_id: DefMap::ROOT } - } -} - -impl From for ModuleDefId { - fn from(value: CrateRootModuleId) -> Self { - ModuleDefId::ModuleId(value.into()) - } -} impl From for CrateRootModuleId { fn from(krate: CrateId) -> Self { @@ -208,105 +452,27 @@ impl ModuleId { } } +impl PartialEq for ModuleId { + fn eq(&self, other: &CrateRootModuleId) -> bool { + other == self + } +} + +impl From for ModuleId { + fn from(CrateRootModuleId { krate }: CrateRootModuleId) -> Self { + ModuleId { krate, block: None, local_id: DefMap::ROOT } + } +} + +impl From for ModuleDefId { + fn from(value: CrateRootModuleId) -> Self { + ModuleDefId::ModuleId(value.into()) + } +} + /// An ID of a module, **local** to a `DefMap`. pub type LocalModuleId = Idx; -#[derive(Debug)] -pub struct ItemLoc { - pub container: ModuleId, - pub id: ItemTreeId, -} - -impl Clone for ItemLoc { - fn clone(&self) -> Self { - Self { container: self.container, id: self.id } - } -} - -impl Copy for ItemLoc {} - -impl PartialEq for ItemLoc { - fn eq(&self, other: &Self) -> bool { - self.container == other.container && self.id == other.id - } -} - -impl Eq for ItemLoc {} - -impl Hash for ItemLoc { - fn hash(&self, state: &mut H) { - self.container.hash(state); - self.id.hash(state); - } -} - -#[derive(Debug)] -pub struct AssocItemLoc { - pub container: ItemContainerId, - pub id: ItemTreeId, -} - -impl Clone for AssocItemLoc { - fn clone(&self) -> Self { - Self { container: self.container, id: self.id } - } -} - -impl Copy for AssocItemLoc {} - -impl PartialEq for AssocItemLoc { - fn eq(&self, other: &Self) -> bool { - self.container == other.container && self.id == other.id - } -} - -impl Eq for AssocItemLoc {} - -impl Hash for AssocItemLoc { - fn hash(&self, state: &mut H) { - self.container.hash(state); - self.id.hash(state); - } -} - -macro_rules! impl_intern { - ($id:ident, $loc:ident, $intern:ident, $lookup:ident) => { - impl_intern_key!($id); - impl_intern_lookup!(DefDatabase, $id, $loc, $intern, $lookup); - }; -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct FunctionId(salsa::InternId); -type FunctionLoc = AssocItemLoc; -impl_intern!(FunctionId, FunctionLoc, intern_function, lookup_intern_function); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct StructId(salsa::InternId); -type StructLoc = ItemLoc; -impl_intern!(StructId, StructLoc, intern_struct, lookup_intern_struct); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct UnionId(salsa::InternId); -pub type UnionLoc = ItemLoc; -impl_intern!(UnionId, UnionLoc, intern_union, lookup_intern_union); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct EnumId(salsa::InternId); -pub type EnumLoc = ItemLoc; -impl_intern!(EnumId, EnumLoc, intern_enum, lookup_intern_enum); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct EnumVariantId(salsa::InternId); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct EnumVariantLoc { - pub id: ItemTreeId, - pub parent: EnumId, - pub index: u32, -} -impl_intern!(EnumVariantId, EnumVariantLoc, intern_enum_variant, lookup_intern_enum_variant); - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct FieldId { pub parent: VariantId, @@ -324,119 +490,12 @@ pub struct TupleFieldId { pub index: u32, } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct ConstId(salsa::InternId); -type ConstLoc = AssocItemLoc; -impl_intern!(ConstId, ConstLoc, intern_const, lookup_intern_const); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct StaticId(salsa::InternId); -pub type StaticLoc = AssocItemLoc; -impl_intern!(StaticId, StaticLoc, intern_static, lookup_intern_static); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct TraitId(salsa::InternId); -pub type TraitLoc = ItemLoc; -impl_intern!(TraitId, TraitLoc, intern_trait, lookup_intern_trait); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct TraitAliasId(salsa::InternId); -pub type TraitAliasLoc = ItemLoc; -impl_intern!(TraitAliasId, TraitAliasLoc, intern_trait_alias, lookup_intern_trait_alias); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct TypeAliasId(salsa::InternId); -type TypeAliasLoc = AssocItemLoc; -impl_intern!(TypeAliasId, TypeAliasLoc, intern_type_alias, lookup_intern_type_alias); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] -pub struct ImplId(salsa::InternId); -type ImplLoc = ItemLoc; -impl_intern!(ImplId, ImplLoc, intern_impl, lookup_intern_impl); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] -pub struct UseId(salsa::InternId); -type UseLoc = ItemLoc; -impl_intern!(UseId, UseLoc, intern_use, lookup_intern_use); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] -pub struct ExternCrateId(salsa::InternId); -type ExternCrateLoc = ItemLoc; -impl_intern!(ExternCrateId, ExternCrateLoc, intern_extern_crate, lookup_intern_extern_crate); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] -pub struct ExternBlockId(salsa::InternId); -type ExternBlockLoc = ItemLoc; -impl_intern!(ExternBlockId, ExternBlockLoc, intern_extern_block, lookup_intern_extern_block); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum MacroExpander { - Declarative, - BuiltIn(BuiltinFnLikeExpander), - BuiltInAttr(BuiltinAttrExpander), - BuiltInDerive(BuiltinDeriveExpander), - BuiltInEager(EagerExpander), -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] -pub struct Macro2Id(salsa::InternId); -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct Macro2Loc { - pub container: ModuleId, - pub id: ItemTreeId, - pub expander: MacroExpander, - pub allow_internal_unsafe: bool, - pub edition: Edition, -} -impl_intern!(Macro2Id, Macro2Loc, intern_macro2, lookup_intern_macro2); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] -pub struct MacroRulesId(salsa::InternId); -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct MacroRulesLoc { - pub container: ModuleId, - pub id: ItemTreeId, - pub expander: MacroExpander, - pub flags: MacroRulesLocFlags, - pub edition: Edition, -} -impl_intern!(MacroRulesId, MacroRulesLoc, intern_macro_rules, lookup_intern_macro_rules); - -bitflags::bitflags! { - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] - pub struct MacroRulesLocFlags: u8 { - const ALLOW_INTERNAL_UNSAFE = 1 << 0; - const LOCAL_INNER = 1 << 1; - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] -pub struct ProcMacroId(salsa::InternId); -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct ProcMacroLoc { - pub container: CrateRootModuleId, - pub id: ItemTreeId, - pub expander: CustomProcMacroExpander, - pub kind: ProcMacroKind, - pub edition: Edition, -} -impl_intern!(ProcMacroId, ProcMacroLoc, intern_proc_macro, lookup_intern_proc_macro); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] -pub struct BlockId(salsa::InternId); -#[derive(Debug, Hash, PartialEq, Eq, Clone)] -pub struct BlockLoc { - ast_id: AstId, - /// The containing module. - module: ModuleId, -} -impl_intern!(BlockId, BlockLoc, intern_block, lookup_intern_block); - #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct TypeOrConstParamId { pub parent: GenericDefId, pub local_id: LocalTypeOrConstParamId, } +impl_intern_value_trivial!(TypeOrConstParamId); /// A TypeOrConstParamId with an invariant that it actually belongs to a type #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -498,6 +557,7 @@ pub struct LifetimeParamId { pub local_id: LocalLifetimeParamId, } pub type LocalLifetimeParamId = Idx; +impl_intern_value_trivial!(LifetimeParamId); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum ItemContainerId { @@ -572,20 +632,6 @@ impl_from!( for ModuleDefId ); -/// Id of the anonymous const block expression and patterns. This is very similar to `ClosureId` and -/// shouldn't be a `DefWithBodyId` since its type inference is dependent on its parent. -#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -pub struct ConstBlockId(salsa::InternId); -impl_intern!(ConstBlockId, ConstBlockLoc, intern_anonymous_const, lookup_intern_anonymous_const); - -#[derive(Debug, Hash, PartialEq, Eq, Clone)] -pub struct ConstBlockLoc { - /// The parent of the anonymous const block. - pub parent: DefWithBodyId, - /// The root expression of this const block in the parent body. - pub root: hir::ExprId, -} - /// Something that holds types, required for the current const arg lowering implementation as they /// need to be able to query where they are defined. #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] @@ -721,6 +767,9 @@ impl Clone for Box { pub struct InTypeConstId(salsa::InternId); impl_intern!(InTypeConstId, InTypeConstLoc, intern_in_type_const, lookup_intern_in_type_const); +// We would like to set `derive(PartialEq)` +// but the compiler complains about that `.expected_ty` does not implement the `Copy` trait. +#[allow(clippy::derived_hash_with_manual_eq)] #[derive(Debug, Hash, Eq, Clone)] pub struct InTypeConstLoc { pub id: AstId, @@ -850,6 +899,39 @@ impl_from!( for GenericDefId ); +impl GenericDefId { + fn file_id_and_params_of( + self, + db: &dyn DefDatabase, + ) -> (HirFileId, Option) { + fn file_id_and_params_of_item_loc( + db: &dyn DefDatabase, + def: impl for<'db> Lookup = dyn DefDatabase + 'db, Data = Loc>, + ) -> (HirFileId, Option) + where + Loc: src::HasSource, + Loc::Value: ast::HasGenericParams, + { + let src = def.lookup(db).source(db); + (src.file_id, ast::HasGenericParams::generic_param_list(&src.value)) + } + + match self { + GenericDefId::FunctionId(it) => file_id_and_params_of_item_loc(db, it), + GenericDefId::TypeAliasId(it) => file_id_and_params_of_item_loc(db, it), + GenericDefId::ConstId(_) => (FileId::BOGUS.into(), None), + GenericDefId::AdtId(AdtId::StructId(it)) => file_id_and_params_of_item_loc(db, it), + GenericDefId::AdtId(AdtId::UnionId(it)) => file_id_and_params_of_item_loc(db, it), + GenericDefId::AdtId(AdtId::EnumId(it)) => file_id_and_params_of_item_loc(db, it), + GenericDefId::TraitId(it) => file_id_and_params_of_item_loc(db, it), + GenericDefId::TraitAliasId(it) => file_id_and_params_of_item_loc(db, it), + GenericDefId::ImplId(it) => file_id_and_params_of_item_loc(db, it), + // We won't be using this ID anyway + GenericDefId::EnumVariantId(_) => (FileId::BOGUS.into(), None), + } + } +} + impl From for GenericDefId { fn from(item: AssocItemId) -> Self { match item { @@ -983,44 +1065,92 @@ impl VariantId { } pub trait HasModule { + /// Returns the enclosing module this thing is defined within. fn module(&self, db: &dyn DefDatabase) -> ModuleId; -} - -impl HasModule for ItemContainerId { - fn module(&self, db: &dyn DefDatabase) -> ModuleId { - match *self { - ItemContainerId::ModuleId(it) => it, - ItemContainerId::ImplId(it) => it.lookup(db).container, - ItemContainerId::TraitId(it) => it.lookup(db).container, - ItemContainerId::ExternBlockId(it) => it.lookup(db).container, - } - } -} - -impl HasModule for AssocItemLoc { + /// Returns the crate this thing is defined within. #[inline] - fn module(&self, db: &dyn DefDatabase) -> ModuleId { - self.container.module(db) + #[doc(alias = "crate")] + fn krate(&self, db: &dyn DefDatabase) -> CrateId { + self.module(db).krate } } -impl HasModule for AdtId { - fn module(&self, db: &dyn DefDatabase) -> ModuleId { - match self { - AdtId::StructId(it) => it.lookup(db).container, - AdtId::UnionId(it) => it.lookup(db).container, - AdtId::EnumId(it) => it.lookup(db).container, - } - } -} +// In theory this impl should work out for us, but rustc thinks it collides with all the other +// manual impls that do not have a ModuleId container... +// impl HasModule for ItemId +// where +// N: ItemTreeNode, +// ItemId: for<'db> Lookup = dyn DefDatabase + 'db, Data = Data> + Copy, +// Data: ItemTreeLoc, +// { +// #[inline] +// fn module(&self, db: &dyn DefDatabase) -> ModuleId { +// self.lookup(db).container() +// } +// } -impl HasModule for EnumId { +impl HasModule for ItemId +where + N: ItemTreeNode, + ItemId: for<'db> Lookup = dyn DefDatabase + 'db, Data = ItemLoc> + Copy, +{ #[inline] fn module(&self, db: &dyn DefDatabase) -> ModuleId { self.lookup(db).container } } +// Technically this does not overlap with the above, but rustc currently forbids this, hence why we +// need to write the 3 impls manually instead +// impl HasModule for ItemId +// where +// N: ItemTreeModItemNode, +// ItemId: for<'db> Lookup = dyn DefDatabase + 'db, Data = AssocItemLoc> + Copy, +// { +// #[inline] +// fn module(&self, db: &dyn DefDatabase) -> ModuleId { +// self.lookup(db).container.module(db) +// } +// } + +// region: manual-assoc-has-module-impls +#[inline] +fn module_for_assoc_item_loc<'db>( + db: &(dyn 'db + DefDatabase), + id: impl Lookup = dyn DefDatabase + 'db, Data = AssocItemLoc>, +) -> ModuleId { + id.lookup(db).container.module(db) +} + +impl HasModule for FunctionId { + #[inline] + fn module(&self, db: &dyn DefDatabase) -> ModuleId { + module_for_assoc_item_loc(db, *self) + } +} + +impl HasModule for ConstId { + #[inline] + fn module(&self, db: &dyn DefDatabase) -> ModuleId { + module_for_assoc_item_loc(db, *self) + } +} + +impl HasModule for StaticId { + #[inline] + fn module(&self, db: &dyn DefDatabase) -> ModuleId { + module_for_assoc_item_loc(db, *self) + } +} + +impl HasModule for TypeAliasId { + #[inline] + fn module(&self, db: &dyn DefDatabase) -> ModuleId { + module_for_assoc_item_loc(db, *self) + } +} +// endregion: manual-assoc-has-module-impls + impl HasModule for EnumVariantId { #[inline] fn module(&self, db: &dyn DefDatabase) -> ModuleId { @@ -1028,46 +1158,81 @@ impl HasModule for EnumVariantId { } } -impl HasModule for ExternCrateId { +impl HasModule for MacroRulesId { #[inline] fn module(&self, db: &dyn DefDatabase) -> ModuleId { self.lookup(db).container } } +impl HasModule for Macro2Id { + #[inline] + fn module(&self, db: &dyn DefDatabase) -> ModuleId { + self.lookup(db).container + } +} + +impl HasModule for ProcMacroId { + #[inline] + fn module(&self, db: &dyn DefDatabase) -> ModuleId { + self.lookup(db).container.into() + } +} + +impl HasModule for ItemContainerId { + fn module(&self, db: &dyn DefDatabase) -> ModuleId { + match *self { + ItemContainerId::ModuleId(it) => it, + ItemContainerId::ImplId(it) => it.module(db), + ItemContainerId::TraitId(it) => it.module(db), + ItemContainerId::ExternBlockId(it) => it.module(db), + } + } +} + +impl HasModule for AdtId { + fn module(&self, db: &dyn DefDatabase) -> ModuleId { + match *self { + AdtId::StructId(it) => it.module(db), + AdtId::UnionId(it) => it.module(db), + AdtId::EnumId(it) => it.module(db), + } + } +} + impl HasModule for VariantId { fn module(&self, db: &dyn DefDatabase) -> ModuleId { - match self { - VariantId::EnumVariantId(it) => it.lookup(db).parent.module(db), - VariantId::StructId(it) => it.lookup(db).container, - VariantId::UnionId(it) => it.lookup(db).container, + match *self { + VariantId::EnumVariantId(it) => it.module(db), + VariantId::StructId(it) => it.module(db), + VariantId::UnionId(it) => it.module(db), } } } impl HasModule for MacroId { fn module(&self, db: &dyn DefDatabase) -> ModuleId { - match self { - MacroId::MacroRulesId(it) => it.lookup(db).container, - MacroId::Macro2Id(it) => it.lookup(db).container, - MacroId::ProcMacroId(it) => it.lookup(db).container.into(), + match *self { + MacroId::MacroRulesId(it) => it.module(db), + MacroId::Macro2Id(it) => it.module(db), + MacroId::ProcMacroId(it) => it.module(db), } } } impl HasModule for TypeOwnerId { fn module(&self, db: &dyn DefDatabase) -> ModuleId { - match self { - TypeOwnerId::FunctionId(it) => it.lookup(db).module(db), - TypeOwnerId::StaticId(it) => it.lookup(db).module(db), - TypeOwnerId::ConstId(it) => it.lookup(db).module(db), - TypeOwnerId::InTypeConstId(it) => it.lookup(db).owner.module(db), + match *self { + TypeOwnerId::FunctionId(it) => it.module(db), + TypeOwnerId::StaticId(it) => it.module(db), + TypeOwnerId::ConstId(it) => it.module(db), TypeOwnerId::AdtId(it) => it.module(db), - TypeOwnerId::TraitId(it) => it.lookup(db).container, - TypeOwnerId::TraitAliasId(it) => it.lookup(db).container, - TypeOwnerId::TypeAliasId(it) => it.lookup(db).module(db), - TypeOwnerId::ImplId(it) => it.lookup(db).container, - TypeOwnerId::EnumVariantId(it) => it.lookup(db).parent.module(db), + TypeOwnerId::TraitId(it) => it.module(db), + TypeOwnerId::TraitAliasId(it) => it.module(db), + TypeOwnerId::TypeAliasId(it) => it.module(db), + TypeOwnerId::ImplId(it) => it.module(db), + TypeOwnerId::EnumVariantId(it) => it.module(db), + TypeOwnerId::InTypeConstId(it) => it.lookup(db).owner.module(db), } } } @@ -1075,10 +1240,10 @@ impl HasModule for TypeOwnerId { impl HasModule for DefWithBodyId { fn module(&self, db: &dyn DefDatabase) -> ModuleId { match self { - DefWithBodyId::FunctionId(it) => it.lookup(db).module(db), - DefWithBodyId::StaticId(it) => it.lookup(db).module(db), - DefWithBodyId::ConstId(it) => it.lookup(db).module(db), - DefWithBodyId::VariantId(it) => it.lookup(db).parent.module(db), + DefWithBodyId::FunctionId(it) => it.module(db), + DefWithBodyId::StaticId(it) => it.module(db), + DefWithBodyId::ConstId(it) => it.module(db), + DefWithBodyId::VariantId(it) => it.module(db), DefWithBodyId::InTypeConstId(it) => it.lookup(db).owner.module(db), } } @@ -1087,29 +1252,43 @@ impl HasModule for DefWithBodyId { impl HasModule for GenericDefId { fn module(&self, db: &dyn DefDatabase) -> ModuleId { match self { - GenericDefId::FunctionId(it) => it.lookup(db).module(db), + GenericDefId::FunctionId(it) => it.module(db), GenericDefId::AdtId(it) => it.module(db), - GenericDefId::TraitId(it) => it.lookup(db).container, - GenericDefId::TraitAliasId(it) => it.lookup(db).container, - GenericDefId::TypeAliasId(it) => it.lookup(db).module(db), - GenericDefId::ImplId(it) => it.lookup(db).container, - GenericDefId::EnumVariantId(it) => it.lookup(db).parent.lookup(db).container, - GenericDefId::ConstId(it) => it.lookup(db).module(db), + GenericDefId::TraitId(it) => it.module(db), + GenericDefId::TraitAliasId(it) => it.module(db), + GenericDefId::TypeAliasId(it) => it.module(db), + GenericDefId::ImplId(it) => it.module(db), + GenericDefId::EnumVariantId(it) => it.module(db), + GenericDefId::ConstId(it) => it.module(db), } } } -impl HasModule for TypeAliasId { - #[inline] +impl HasModule for AttrDefId { fn module(&self, db: &dyn DefDatabase) -> ModuleId { - self.lookup(db).module(db) - } -} - -impl HasModule for TraitId { - #[inline] - fn module(&self, db: &dyn DefDatabase) -> ModuleId { - self.lookup(db).container + match self { + AttrDefId::ModuleId(it) => *it, + AttrDefId::FieldId(it) => it.parent.module(db), + AttrDefId::AdtId(it) => it.module(db), + AttrDefId::FunctionId(it) => it.module(db), + AttrDefId::EnumVariantId(it) => it.module(db), + AttrDefId::StaticId(it) => it.module(db), + AttrDefId::ConstId(it) => it.module(db), + AttrDefId::TraitId(it) => it.module(db), + AttrDefId::TraitAliasId(it) => it.module(db), + AttrDefId::TypeAliasId(it) => it.module(db), + AttrDefId::ImplId(it) => it.module(db), + AttrDefId::ExternBlockId(it) => it.module(db), + AttrDefId::GenericParamId(it) => match it { + GenericParamId::TypeParamId(it) => it.parent(), + GenericParamId::ConstParamId(it) => it.parent(), + GenericParamId::LifetimeParamId(it) => it.parent, + } + .module(db), + AttrDefId::MacroId(it) => it.module(db), + AttrDefId::ExternCrateId(it) => it.module(db), + AttrDefId::UseId(it) => it.module(db), + } } } @@ -1120,51 +1299,20 @@ impl ModuleDefId { pub fn module(&self, db: &dyn DefDatabase) -> Option { Some(match self { ModuleDefId::ModuleId(id) => *id, - ModuleDefId::FunctionId(id) => id.lookup(db).module(db), + ModuleDefId::FunctionId(id) => id.module(db), ModuleDefId::AdtId(id) => id.module(db), - ModuleDefId::EnumVariantId(id) => id.lookup(db).parent.module(db), - ModuleDefId::ConstId(id) => id.lookup(db).container.module(db), - ModuleDefId::StaticId(id) => id.lookup(db).module(db), - ModuleDefId::TraitId(id) => id.lookup(db).container, - ModuleDefId::TraitAliasId(id) => id.lookup(db).container, - ModuleDefId::TypeAliasId(id) => id.lookup(db).module(db), + ModuleDefId::EnumVariantId(id) => id.module(db), + ModuleDefId::ConstId(id) => id.module(db), + ModuleDefId::StaticId(id) => id.module(db), + ModuleDefId::TraitId(id) => id.module(db), + ModuleDefId::TraitAliasId(id) => id.module(db), + ModuleDefId::TypeAliasId(id) => id.module(db), ModuleDefId::MacroId(id) => id.module(db), ModuleDefId::BuiltinType(_) => return None, }) } } -impl AttrDefId { - pub fn krate(&self, db: &dyn DefDatabase) -> CrateId { - match self { - AttrDefId::ModuleId(it) => it.krate, - AttrDefId::FieldId(it) => it.parent.module(db).krate, - AttrDefId::AdtId(it) => it.module(db).krate, - AttrDefId::FunctionId(it) => it.lookup(db).module(db).krate, - AttrDefId::EnumVariantId(it) => it.lookup(db).parent.module(db).krate, - AttrDefId::StaticId(it) => it.lookup(db).module(db).krate, - AttrDefId::ConstId(it) => it.lookup(db).module(db).krate, - AttrDefId::TraitId(it) => it.lookup(db).container.krate, - AttrDefId::TraitAliasId(it) => it.lookup(db).container.krate, - AttrDefId::TypeAliasId(it) => it.lookup(db).module(db).krate, - AttrDefId::ImplId(it) => it.lookup(db).container.krate, - AttrDefId::ExternBlockId(it) => it.lookup(db).container.krate, - AttrDefId::GenericParamId(it) => { - match it { - GenericParamId::TypeParamId(it) => it.parent(), - GenericParamId::ConstParamId(it) => it.parent(), - GenericParamId::LifetimeParamId(it) => it.parent, - } - .module(db) - .krate - } - AttrDefId::MacroId(it) => it.module(db).krate, - AttrDefId::ExternCrateId(it) => it.lookup(db).container.krate, - AttrDefId::UseId(it) => it.lookup(db).container.krate, - } - } -} - /// A helper trait for converting to MacroCallId pub trait AsMacroCall { fn as_call_id( diff --git a/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs b/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs index 553c0b7953..86b4466153 100644 --- a/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs +++ b/crates/hir-def/src/macro_expansion_tests/builtin_derive_macro.rs @@ -157,7 +157,7 @@ where generic: Vec, } -impl $crate::clone::Clone for Foo where T: Trait, T::InFieldShorthand: $crate::clone::Clone, T::InGenericArg: $crate::clone::Clone, { +impl $crate::clone::Clone for Foo where ::InWc: Marker, T: Trait, T::InFieldShorthand: $crate::clone::Clone, T::InGenericArg: $crate::clone::Clone, { fn clone(&self ) -> Self { match self { Foo { diff --git a/crates/hir-def/src/macro_expansion_tests/mbe.rs b/crates/hir-def/src/macro_expansion_tests/mbe.rs index d0ae1f59f7..edc8247f16 100644 --- a/crates/hir-def/src/macro_expansion_tests/mbe.rs +++ b/crates/hir-def/src/macro_expansion_tests/mbe.rs @@ -35,9 +35,9 @@ macro_rules! f { }; } -struct#0:1@58..64#2# MyTraitMap2#0:2@31..42#0# {#0:1@72..73#2# - map#0:1@86..89#2#:#0:1@89..90#2# #0:1@89..90#2#::#0:1@91..92#2#std#0:1@93..96#2#::#0:1@96..97#2#collections#0:1@98..109#2#::#0:1@109..110#2#HashSet#0:1@111..118#2#<#0:1@118..119#2#(#0:1@119..120#2#)#0:1@120..121#2#>#0:1@121..122#2#,#0:1@122..123#2# -}#0:1@132..133#2# +struct#0:1@58..64#1# MyTraitMap2#0:2@31..42#0# {#0:1@72..73#1# + map#0:1@86..89#1#:#0:1@89..90#1# #0:1@89..90#1#::#0:1@91..92#1#std#0:1@93..96#1#::#0:1@96..97#1#collections#0:1@98..109#1#::#0:1@109..110#1#HashSet#0:1@111..118#1#<#0:1@118..119#1#(#0:1@119..120#1#)#0:1@120..121#1#>#0:1@121..122#1#,#0:1@122..123#1# +}#0:1@132..133#1# "#]], ); } @@ -171,7 +171,7 @@ fn main(foo: ()) { } fn main(foo: ()) { - /* error: unresolved macro unresolved */"helloworld!"#0:3@207..323#6#; + /* error: unresolved macro unresolved */"helloworld!"#0:3@207..323#2#; } } @@ -197,7 +197,7 @@ macro_rules! mk_struct { #[macro_use] mod foo; -struct#1:1@59..65#2# Foo#0:2@32..35#0#(#1:1@70..71#2#u32#0:2@41..44#0#)#1:1@74..75#2#;#1:1@75..76#2# +struct#1:1@59..65#1# Foo#0:2@32..35#0#(#1:1@70..71#1#u32#0:2@41..44#0#)#1:1@74..75#1#;#1:1@75..76#1# "#]], ); } diff --git a/crates/hir-def/src/macro_expansion_tests/mod.rs b/crates/hir-def/src/macro_expansion_tests/mod.rs index e315414e9b..fc5a6e80a4 100644 --- a/crates/hir-def/src/macro_expansion_tests/mod.rs +++ b/crates/hir-def/src/macro_expansion_tests/mod.rs @@ -224,7 +224,7 @@ fn reindent(indent: IndentLevel, pp: String) -> String { return pp; } let mut lines = pp.split_inclusive('\n'); - let mut res = lines.next().unwrap().to_string(); + let mut res = lines.next().unwrap().to_owned(); for line in lines { if line.trim().is_empty() { res.push_str(line) diff --git a/crates/hir-def/src/nameres/collector.rs b/crates/hir-def/src/nameres/collector.rs index fb6fd867a1..21cc28f1b3 100644 --- a/crates/hir-def/src/nameres/collector.rs +++ b/crates/hir-def/src/nameres/collector.rs @@ -33,8 +33,8 @@ use crate::{ db::DefDatabase, item_scope::{ImportId, ImportOrExternCrate, ImportType, PerNsGlobImports}, item_tree::{ - self, ExternCrate, Fields, FileItemTreeId, ImportKind, ItemTree, ItemTreeId, - ItemTreeModItemNode, Macro2, MacroCall, MacroRules, Mod, ModItem, ModKind, TreeId, + self, ExternCrate, Fields, FileItemTreeId, ImportKind, ItemTree, ItemTreeId, ItemTreeNode, + Macro2, MacroCall, MacroRules, Mod, ModItem, ModKind, TreeId, }, macro_call_as_call_id, macro_call_as_call_id_with_eager, nameres::{ @@ -2125,7 +2125,7 @@ impl ModCollector<'_, '_> { let is_export = export_attr.exists(); let local_inner = if is_export { - export_attr.tt_values().flat_map(|it| &it.token_trees).any(|it| match it { + export_attr.tt_values().flat_map(|it| it.token_trees.iter()).any(|it| match it { tt::TokenTree::Leaf(tt::Leaf::Ident(ident)) => { ident.text.contains("local_inner_macros") } diff --git a/crates/hir-def/src/resolver.rs b/crates/hir-def/src/resolver.rs index 7a9c4ea016..db47d743c5 100644 --- a/crates/hir-def/src/resolver.rs +++ b/crates/hir-def/src/resolver.rs @@ -27,9 +27,9 @@ use crate::{ visibility::{RawVisibility, Visibility}, AdtId, ConstId, ConstParamId, CrateRootModuleId, DefWithBodyId, EnumId, EnumVariantId, ExternBlockId, ExternCrateId, FunctionId, GenericDefId, GenericParamId, HasModule, ImplId, - ItemContainerId, LifetimeParamId, LocalModuleId, Lookup, Macro2Id, MacroId, MacroRulesId, - ModuleDefId, ModuleId, ProcMacroId, StaticId, StructId, TraitAliasId, TraitId, TypeAliasId, - TypeOrConstParamId, TypeOwnerId, TypeParamId, UseId, VariantId, + ItemContainerId, ItemTreeLoc, LifetimeParamId, LocalModuleId, Lookup, Macro2Id, MacroId, + MacroRulesId, ModuleDefId, ModuleId, ProcMacroId, StaticId, StructId, TraitAliasId, TraitId, + TypeAliasId, TypeOrConstParamId, TypeOwnerId, TypeParamId, UseId, VariantId, }; #[derive(Debug, Clone)] @@ -248,6 +248,7 @@ impl Resolver { RawVisibility::Public => Some(Visibility::Public), } } + pub fn resolve_path_in_value_ns( &self, db: &dyn DefDatabase, @@ -1014,13 +1015,13 @@ impl HasResolver for CrateRootModuleId { impl HasResolver for TraitId { fn resolver(self, db: &dyn DefDatabase) -> Resolver { - self.lookup(db).container.resolver(db).push_generic_params_scope(db, self.into()) + lookup_resolver(db, self).push_generic_params_scope(db, self.into()) } } impl HasResolver for TraitAliasId { fn resolver(self, db: &dyn DefDatabase) -> Resolver { - self.lookup(db).container.resolver(db).push_generic_params_scope(db, self.into()) + lookup_resolver(db, self).push_generic_params_scope(db, self.into()) } } @@ -1036,25 +1037,25 @@ impl + Copy> HasResolver for T { impl HasResolver for FunctionId { fn resolver(self, db: &dyn DefDatabase) -> Resolver { - self.lookup(db).container.resolver(db).push_generic_params_scope(db, self.into()) + lookup_resolver(db, self).push_generic_params_scope(db, self.into()) } } impl HasResolver for ConstId { fn resolver(self, db: &dyn DefDatabase) -> Resolver { - self.lookup(db).container.resolver(db) + lookup_resolver(db, self) } } impl HasResolver for StaticId { fn resolver(self, db: &dyn DefDatabase) -> Resolver { - self.lookup(db).container.resolver(db) + lookup_resolver(db, self) } } impl HasResolver for TypeAliasId { fn resolver(self, db: &dyn DefDatabase) -> Resolver { - self.lookup(db).container.resolver(db).push_generic_params_scope(db, self.into()) + lookup_resolver(db, self).push_generic_params_scope(db, self.into()) } } @@ -1071,19 +1072,19 @@ impl HasResolver for ImplId { impl HasResolver for ExternBlockId { fn resolver(self, db: &dyn DefDatabase) -> Resolver { // Same as parent's - self.lookup(db).container.resolver(db) + lookup_resolver(db, self) } } impl HasResolver for ExternCrateId { fn resolver(self, db: &dyn DefDatabase) -> Resolver { - self.lookup(db).container.resolver(db) + lookup_resolver(db, self) } } impl HasResolver for UseId { fn resolver(self, db: &dyn DefDatabase) -> Resolver { - self.lookup(db).container.resolver(db) + lookup_resolver(db, self) } } @@ -1170,18 +1171,28 @@ impl HasResolver for MacroId { impl HasResolver for Macro2Id { fn resolver(self, db: &dyn DefDatabase) -> Resolver { - self.lookup(db).container.resolver(db) + lookup_resolver(db, self) } } impl HasResolver for ProcMacroId { fn resolver(self, db: &dyn DefDatabase) -> Resolver { - self.lookup(db).container.resolver(db) + lookup_resolver(db, self) } } impl HasResolver for MacroRulesId { fn resolver(self, db: &dyn DefDatabase) -> Resolver { - self.lookup(db).container.resolver(db) + lookup_resolver(db, self) } } + +fn lookup_resolver<'db>( + db: &(dyn DefDatabase + 'db), + lookup: impl Lookup< + Database<'db> = dyn DefDatabase + 'db, + Data = impl ItemTreeLoc, + >, +) -> Resolver { + lookup.lookup(db).container().resolver(db) +} diff --git a/crates/hir-def/src/src.rs b/crates/hir-def/src/src.rs index 9bd8c8d221..4283f003f8 100644 --- a/crates/hir-def/src/src.rs +++ b/crates/hir-def/src/src.rs @@ -1,12 +1,14 @@ //! Utilities for mapping between hir IDs and the surface syntax. +use either::Either; use hir_expand::InFile; use la_arena::ArenaMap; use syntax::ast; use crate::{ - db::DefDatabase, item_tree::ItemTreeModItemNode, AssocItemLoc, EnumVariantLoc, ItemLoc, Lookup, - Macro2Loc, MacroRulesLoc, ProcMacroLoc, UseId, + data::adt::lower_struct, db::DefDatabase, item_tree::ItemTreeNode, trace::Trace, GenericDefId, + ItemTreeLoc, LocalFieldId, LocalLifetimeParamId, LocalTypeOrConstParamId, Lookup, UseId, + VariantId, }; pub trait HasSource { @@ -14,81 +16,22 @@ pub trait HasSource { fn source(&self, db: &dyn DefDatabase) -> InFile; } -impl HasSource for AssocItemLoc { - type Value = N::Source; - - fn source(&self, db: &dyn DefDatabase) -> InFile { - let tree = self.id.item_tree(db); - let ast_id_map = db.ast_id_map(self.id.file_id()); - let root = db.parse_or_expand(self.id.file_id()); - let node = &tree[self.id.value]; - - InFile::new(self.id.file_id(), ast_id_map.get(node.ast_id()).to_node(&root)) - } -} - -impl HasSource for ItemLoc { - type Value = N::Source; - - fn source(&self, db: &dyn DefDatabase) -> InFile { - let tree = self.id.item_tree(db); - let ast_id_map = db.ast_id_map(self.id.file_id()); - let root = db.parse_or_expand(self.id.file_id()); - let node = &tree[self.id.value]; - - InFile::new(self.id.file_id(), ast_id_map.get(node.ast_id()).to_node(&root)) - } -} - -impl HasSource for EnumVariantLoc { - type Value = ast::Variant; - - fn source(&self, db: &dyn DefDatabase) -> InFile { - let tree = self.id.item_tree(db); - let ast_id_map = db.ast_id_map(self.id.file_id()); - let root = db.parse_or_expand(self.id.file_id()); - let node = &tree[self.id.value]; - - InFile::new(self.id.file_id(), ast_id_map.get(node.ast_id).to_node(&root)) - } -} - -impl HasSource for Macro2Loc { - type Value = ast::MacroDef; +impl HasSource for T +where + T: ItemTreeLoc, + T::Id: ItemTreeNode, +{ + type Value = ::Source; fn source(&self, db: &dyn DefDatabase) -> InFile { - let tree = self.id.item_tree(db); - let ast_id_map = db.ast_id_map(self.id.file_id()); - let root = db.parse_or_expand(self.id.file_id()); - let node = &tree[self.id.value]; + let id = self.item_tree_id(); + let file_id = id.file_id(); + let tree = id.item_tree(db); + let ast_id_map = db.ast_id_map(file_id); + let root = db.parse_or_expand(file_id); + let node = &tree[id.value]; - InFile::new(self.id.file_id(), ast_id_map.get(node.ast_id()).to_node(&root)) - } -} - -impl HasSource for MacroRulesLoc { - type Value = ast::MacroRules; - - fn source(&self, db: &dyn DefDatabase) -> InFile { - let tree = self.id.item_tree(db); - let ast_id_map = db.ast_id_map(self.id.file_id()); - let root = db.parse_or_expand(self.id.file_id()); - let node = &tree[self.id.value]; - - InFile::new(self.id.file_id(), ast_id_map.get(node.ast_id()).to_node(&root)) - } -} - -impl HasSource for ProcMacroLoc { - type Value = ast::Fn; - - fn source(&self, db: &dyn DefDatabase) -> InFile { - let tree = self.id.item_tree(db); - let ast_id_map = db.ast_id_map(self.id.file_id()); - let root = db.parse_or_expand(self.id.file_id()); - let node = &tree[self.id.value]; - - InFile::new(self.id.file_id(), ast_id_map.get(node.ast_id()).to_node(&root)) + InFile::new(file_id, ast_id_map.get(node.ast_id()).to_node(&root)) } } @@ -111,3 +54,105 @@ impl HasChildSource> for UseId { ) } } + +impl HasChildSource for GenericDefId { + type Value = Either; + fn child_source( + &self, + db: &dyn DefDatabase, + ) -> InFile> { + let generic_params = db.generic_params(*self); + let mut idx_iter = generic_params.type_or_consts.iter().map(|(idx, _)| idx); + + let (file_id, generic_params_list) = self.file_id_and_params_of(db); + + let mut params = ArenaMap::default(); + + // For traits and trait aliases the first type index is `Self`, we need to add it before + // the other params. + match *self { + GenericDefId::TraitId(id) => { + let trait_ref = id.lookup(db).source(db).value; + let idx = idx_iter.next().unwrap(); + params.insert(idx, Either::Right(ast::TraitOrAlias::Trait(trait_ref))); + } + GenericDefId::TraitAliasId(id) => { + let alias = id.lookup(db).source(db).value; + let idx = idx_iter.next().unwrap(); + params.insert(idx, Either::Right(ast::TraitOrAlias::TraitAlias(alias))); + } + _ => {} + } + + if let Some(generic_params_list) = generic_params_list { + for (idx, ast_param) in idx_iter.zip(generic_params_list.type_or_const_params()) { + params.insert(idx, Either::Left(ast_param)); + } + } + + InFile::new(file_id, params) + } +} + +impl HasChildSource for GenericDefId { + type Value = ast::LifetimeParam; + fn child_source( + &self, + db: &dyn DefDatabase, + ) -> InFile> { + let generic_params = db.generic_params(*self); + let idx_iter = generic_params.lifetimes.iter().map(|(idx, _)| idx); + + let (file_id, generic_params_list) = self.file_id_and_params_of(db); + + let mut params = ArenaMap::default(); + + if let Some(generic_params_list) = generic_params_list { + for (idx, ast_param) in idx_iter.zip(generic_params_list.lifetime_params()) { + params.insert(idx, ast_param); + } + } + + InFile::new(file_id, params) + } +} + +impl HasChildSource for VariantId { + type Value = Either; + + fn child_source(&self, db: &dyn DefDatabase) -> InFile> { + let item_tree; + let (src, fields, container) = match *self { + VariantId::EnumVariantId(it) => { + let lookup = it.lookup(db); + item_tree = lookup.id.item_tree(db); + ( + lookup.source(db).map(|it| it.kind()), + &item_tree[lookup.id.value].fields, + lookup.parent.lookup(db).container, + ) + } + VariantId::StructId(it) => { + let lookup = it.lookup(db); + item_tree = lookup.id.item_tree(db); + ( + lookup.source(db).map(|it| it.kind()), + &item_tree[lookup.id.value].fields, + lookup.container, + ) + } + VariantId::UnionId(it) => { + let lookup = it.lookup(db); + item_tree = lookup.id.item_tree(db); + ( + lookup.source(db).map(|it| it.kind()), + &item_tree[lookup.id.value].fields, + lookup.container, + ) + } + }; + let mut trace = Trace::new_for_map(); + lower_struct(db, &mut trace, &src, container.krate, &item_tree, fields); + src.with_value(trace.into_map()) + } +} diff --git a/crates/hir-def/src/trace.rs b/crates/hir-def/src/trace.rs index 04d5b26619..da50ee8dc7 100644 --- a/crates/hir-def/src/trace.rs +++ b/crates/hir-def/src/trace.rs @@ -11,6 +11,8 @@ //! projections. use la_arena::{Arena, ArenaMap, Idx, RawIdx}; +// FIXME: This isn't really used anymore, at least not in a way where it does anything useful. +// Check if we should get rid of this or make proper use of it instead. pub(crate) struct Trace { arena: Option>, map: Option, V>>, diff --git a/crates/hir-def/src/visibility.rs b/crates/hir-def/src/visibility.rs index b9676179a5..0f3fac1cec 100644 --- a/crates/hir-def/src/visibility.rs +++ b/crates/hir-def/src/visibility.rs @@ -37,10 +37,14 @@ impl RawVisibility { db: &dyn DefDatabase, node: InFile>, ) -> RawVisibility { + let node = match node.transpose() { + None => return RawVisibility::private(), + Some(node) => node, + }; Self::from_ast_with_span_map(db, node.value, db.span_map(node.file_id).as_ref()) } - pub(crate) fn from_ast_with_span_map( + pub(crate) fn from_opt_ast_with_span_map( db: &dyn DefDatabase, node: Option, span_map: SpanMapRef<'_>, @@ -49,29 +53,28 @@ impl RawVisibility { None => return RawVisibility::private(), Some(node) => node, }; - match node.kind() { + Self::from_ast_with_span_map(db, node, span_map) + } + + fn from_ast_with_span_map( + db: &dyn DefDatabase, + node: ast::Visibility, + span_map: SpanMapRef<'_>, + ) -> RawVisibility { + let path = match node.kind() { ast::VisibilityKind::In(path) => { let path = ModPath::from_src(db.upcast(), path, span_map); - let path = match path { + match path { None => return RawVisibility::private(), Some(path) => path, - }; - RawVisibility::Module(path, VisibilityExplicitness::Explicit) + } } - ast::VisibilityKind::PubCrate => { - let path = ModPath::from_kind(PathKind::Crate); - RawVisibility::Module(path, VisibilityExplicitness::Explicit) - } - ast::VisibilityKind::PubSuper => { - let path = ModPath::from_kind(PathKind::Super(1)); - RawVisibility::Module(path, VisibilityExplicitness::Explicit) - } - ast::VisibilityKind::PubSelf => { - let path = ModPath::from_kind(PathKind::Super(0)); - RawVisibility::Module(path, VisibilityExplicitness::Explicit) - } - ast::VisibilityKind::Pub => RawVisibility::Public, - } + ast::VisibilityKind::PubCrate => ModPath::from_kind(PathKind::Crate), + ast::VisibilityKind::PubSuper => ModPath::from_kind(PathKind::Super(1)), + ast::VisibilityKind::PubSelf => ModPath::from_kind(PathKind::Super(0)), + ast::VisibilityKind::Pub => return RawVisibility::Public, + }; + RawVisibility::Module(path, VisibilityExplicitness::Explicit) } pub fn resolve( @@ -94,6 +97,11 @@ pub enum Visibility { } impl Visibility { + pub(crate) fn is_visible_from_other_crate(self) -> bool { + matches!(self, Visibility::Public) + } + + #[tracing::instrument(skip_all)] pub fn is_visible_from(self, db: &dyn DefDatabase, from_module: ModuleId) -> bool { let to_module = match self { Visibility::Module(m, _) => m, @@ -104,24 +112,33 @@ impl Visibility { return false; } let def_map = from_module.def_map(db); - self.is_visible_from_def_map(db, &def_map, from_module.local_id) - } - - pub(crate) fn is_visible_from_other_crate(self) -> bool { - matches!(self, Visibility::Public) + Self::is_visible_from_def_map_(db, &def_map, to_module, from_module.local_id) } pub(crate) fn is_visible_from_def_map( self, db: &dyn DefDatabase, def_map: &DefMap, - mut from_module: LocalModuleId, + from_module: LocalModuleId, ) -> bool { - let mut to_module = match self { + let to_module = match self { Visibility::Module(m, _) => m, Visibility::Public => return true, }; + // if they're not in the same crate, it can't be visible + if def_map.krate() != to_module.krate { + return false; + } + Self::is_visible_from_def_map_(db, def_map, to_module, from_module) + } + fn is_visible_from_def_map_( + db: &dyn DefDatabase, + def_map: &DefMap, + mut to_module: ModuleId, + mut from_module: LocalModuleId, + ) -> bool { + debug_assert_eq!(to_module.krate, def_map.krate()); // `to_module` might be the root module of a block expression. Those have the same // visibility as the containing module (even though no items are directly nameable from // there, getting this right is important for method resolution). @@ -129,20 +146,25 @@ impl Visibility { // Additional complication: `to_module` might be in `from_module`'s `DefMap`, which we're // currently computing, so we must not call the `def_map` query for it. - let mut arc; + let def_map_block = def_map.block_id(); loop { - let to_module_def_map = - if to_module.krate == def_map.krate() && to_module.block == def_map.block_id() { + match (to_module.block, def_map_block) { + // to_module is not a block, so there is no parent def map to use + (None, _) => (), + (Some(a), Some(b)) if a == b => { cov_mark::hit!(is_visible_from_same_block_def_map); - def_map - } else { - arc = to_module.def_map(db); - &arc - }; - match to_module_def_map.parent() { - Some(parent) => to_module = parent, - None => break, + if let Some(parent) = def_map.parent() { + to_module = parent; + } + } + _ => { + if let Some(parent) = to_module.def_map(db).parent() { + to_module = parent; + continue; + } + } } + break; } // from_module needs to be a descendant of to_module @@ -175,30 +197,25 @@ impl Visibility { /// visible in unrelated modules). pub(crate) fn max(self, other: Visibility, def_map: &DefMap) -> Option { match (self, other) { - (Visibility::Module(_, _) | Visibility::Public, Visibility::Public) - | (Visibility::Public, Visibility::Module(_, _)) => Some(Visibility::Public), - (Visibility::Module(mod_a, vis_a), Visibility::Module(mod_b, vis_b)) => { + (_, Visibility::Public) | (Visibility::Public, _) => Some(Visibility::Public), + (Visibility::Module(mod_a, expl_a), Visibility::Module(mod_b, expl_b)) => { if mod_a.krate != mod_b.krate { return None; } - let mut a_ancestors = iter::successors(Some(mod_a.local_id), |&m| { - let parent_id = def_map[m].parent?; - Some(parent_id) - }); - let mut b_ancestors = iter::successors(Some(mod_b.local_id), |&m| { - let parent_id = def_map[m].parent?; - Some(parent_id) - }); + let mut a_ancestors = + iter::successors(Some(mod_a.local_id), |&m| def_map[m].parent); + let mut b_ancestors = + iter::successors(Some(mod_b.local_id), |&m| def_map[m].parent); if a_ancestors.any(|m| m == mod_b.local_id) { // B is above A - return Some(Visibility::Module(mod_b, vis_b)); + return Some(Visibility::Module(mod_b, expl_b)); } if b_ancestors.any(|m| m == mod_a.local_id) { // A is above B - return Some(Visibility::Module(mod_a, vis_a)); + return Some(Visibility::Module(mod_a, expl_a)); } None @@ -207,7 +224,8 @@ impl Visibility { } } -/// Whether the item was imported through `pub(crate) use` or just `use`. +/// Whether the item was imported through an explicit `pub(crate) use` or just a `use` without +/// visibility. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum VisibilityExplicitness { Explicit, diff --git a/crates/hir-expand/src/ast_id_map.rs b/crates/hir-expand/src/ast_id_map.rs index 530f10a068..ab582741f5 100644 --- a/crates/hir-expand/src/ast_id_map.rs +++ b/crates/hir-expand/src/ast_id_map.rs @@ -155,7 +155,7 @@ impl PartialEq for AstIdMap { impl Eq for AstIdMap {} impl AstIdMap { - pub(crate) fn ast_id_map( + pub(crate) fn new( db: &dyn ExpandDatabase, file_id: span::HirFileId, ) -> triomphe::Arc { diff --git a/crates/hir-expand/src/attrs.rs b/crates/hir-expand/src/attrs.rs index c20c1639e1..1c92dea38e 100644 --- a/crates/hir-expand/src/attrs.rs +++ b/crates/hir-expand/src/attrs.rs @@ -123,7 +123,7 @@ impl RawAttrs { .filter_map(|(idx, attr)| Attr::from_tt(db, attr, index.with_cfg_attr(idx))); let cfg_options = &crate_graph[krate].cfg_options; - let cfg = Subtree { delimiter: subtree.delimiter, token_trees: cfg.to_vec() }; + let cfg = Subtree { delimiter: subtree.delimiter, token_trees: Box::from(cfg) }; let cfg = CfgExpr::parse(&cfg); if cfg_options.check(&cfg) == Some(false) { smallvec![] diff --git a/crates/hir-expand/src/builtin_attr_macro.rs b/crates/hir-expand/src/builtin_attr_macro.rs index dd2aa94ad0..903b0d4807 100644 --- a/crates/hir-expand/src/builtin_attr_macro.rs +++ b/crates/hir-expand/src/builtin_attr_macro.rs @@ -137,5 +137,8 @@ pub fn pseudo_derive_attr_expansion( token_trees.extend(tt.iter().cloned()); token_trees.push(mk_leaf(']')); } - ExpandResult::ok(tt::Subtree { delimiter: tt.delimiter, token_trees }) + ExpandResult::ok(tt::Subtree { + delimiter: tt.delimiter, + token_trees: token_trees.into_boxed_slice(), + }) } diff --git a/crates/hir-expand/src/builtin_derive_macro.rs b/crates/hir-expand/src/builtin_derive_macro.rs index 024fb8c1f6..2795487514 100644 --- a/crates/hir-expand/src/builtin_derive_macro.rs +++ b/crates/hir-expand/src/builtin_derive_macro.rs @@ -194,6 +194,7 @@ struct BasicAdtInfo { /// second field is `Some(ty)` if it's a const param of type `ty`, `None` if it's a type param. /// third fields is where bounds, if any param_types: Vec<(tt::Subtree, Option, Option)>, + where_clause: Vec, associated_types: Vec, } @@ -202,10 +203,11 @@ fn parse_adt( adt: &ast::Adt, call_site: Span, ) -> Result { - let (name, generic_param_list, shape) = match adt { + let (name, generic_param_list, where_clause, shape) = match adt { ast::Adt::Struct(it) => ( it.name(), it.generic_param_list(), + it.where_clause(), AdtShape::Struct(VariantShape::from(tm, it.field_list())?), ), ast::Adt::Enum(it) => { @@ -217,6 +219,7 @@ fn parse_adt( ( it.name(), it.generic_param_list(), + it.where_clause(), AdtShape::Enum { default_variant, variants: it @@ -233,7 +236,9 @@ fn parse_adt( }, ) } - ast::Adt::Union(it) => (it.name(), it.generic_param_list(), AdtShape::Union), + ast::Adt::Union(it) => { + (it.name(), it.generic_param_list(), it.where_clause(), AdtShape::Union) + } }; let mut param_type_set: FxHashSet = FxHashSet::default(); @@ -274,6 +279,14 @@ fn parse_adt( }) .collect(); + let where_clause = if let Some(w) = where_clause { + w.predicates() + .map(|it| mbe::syntax_node_to_token_tree(it.syntax(), tm, call_site)) + .collect() + } else { + vec![] + }; + // For a generic parameter `T`, when shorthand associated type `T::Assoc` appears in field // types (of any variant for enums), we generate trait bound for it. It sounds reasonable to // also generate trait bound for qualified associated type `::Assoc`, but rustc @@ -301,7 +314,7 @@ fn parse_adt( .map(|it| mbe::syntax_node_to_token_tree(it.syntax(), tm, call_site)) .collect(); let name_token = name_to_token(tm, name)?; - Ok(BasicAdtInfo { name: name_token, shape, param_types, associated_types }) + Ok(BasicAdtInfo { name: name_token, shape, param_types, where_clause, associated_types }) } fn name_to_token( @@ -366,7 +379,8 @@ fn expand_simple_derive( } }; let trait_body = make_trait_body(&info); - let mut where_block = vec![]; + let mut where_block: Vec<_> = + info.where_clause.into_iter().map(|w| quote! {invoc_span => #w , }).collect(); let (params, args): (Vec<_>, Vec<_>) = info .param_types .into_iter() diff --git a/crates/hir-expand/src/builtin_fn_macro.rs b/crates/hir-expand/src/builtin_fn_macro.rs index 29d389f656..6d3de0e55d 100644 --- a/crates/hir-expand/src/builtin_fn_macro.rs +++ b/crates/hir-expand/src/builtin_fn_macro.rs @@ -155,10 +155,10 @@ fn line_expand( // not incremental ExpandResult::ok(tt::Subtree { delimiter: tt::Delimiter::invisible_spanned(span), - token_trees: vec![tt::TokenTree::Leaf(tt::Leaf::Literal(tt::Literal { + token_trees: Box::new([tt::TokenTree::Leaf(tt::Leaf::Literal(tt::Literal { text: "0u32".into(), span, - }))], + }))]), }) } @@ -208,11 +208,11 @@ fn assert_expand( [cond, panic_args @ ..] => { let comma = tt::Subtree { delimiter: tt::Delimiter::invisible_spanned(call_site_span), - token_trees: vec![tt::TokenTree::Leaf(tt::Leaf::Punct(tt::Punct { + token_trees: Box::new([tt::TokenTree::Leaf(tt::Leaf::Punct(tt::Punct { char: ',', spacing: tt::Spacing::Alone, span: call_site_span, - }))], + }))]), }; let cond = cond.clone(); let panic_args = itertools::Itertools::intersperse(panic_args.iter().cloned(), comma); @@ -359,7 +359,10 @@ fn panic_expand( close: call_site_span, kind: tt::DelimiterKind::Parenthesis, }; - call.token_trees.push(tt::TokenTree::Subtree(subtree)); + + // FIXME(slow): quote! have a way to expand to builder to make this a vec! + call.push(tt::TokenTree::Subtree(subtree)); + ExpandResult::ok(call) } @@ -388,7 +391,10 @@ fn unreachable_expand( close: call_site_span, kind: tt::DelimiterKind::Parenthesis, }; - call.token_trees.push(tt::TokenTree::Subtree(subtree)); + + // FIXME(slow): quote! have a way to expand to builder to make this a vec! + call.push(tt::TokenTree::Subtree(subtree)); + ExpandResult::ok(call) } @@ -509,7 +515,7 @@ fn concat_bytes_expand( tt::TokenTree::Leaf(tt::Leaf::Literal(lit)) => { let token = ast::make::tokens::literal(&lit.to_string()); match token.kind() { - syntax::SyntaxKind::BYTE => bytes.push(token.text().to_string()), + syntax::SyntaxKind::BYTE => bytes.push(token.text().to_owned()), syntax::SyntaxKind::BYTE_STRING => { let components = unquote_byte_string(lit).unwrap_or_default(); components.into_iter().for_each(|it| bytes.push(it.to_string())); @@ -564,7 +570,7 @@ fn concat_bytes_expand_subtree( let lit = ast::make::tokens::literal(&lit.to_string()); match lit.kind() { syntax::SyntaxKind::BYTE | syntax::SyntaxKind::INT_NUMBER => { - bytes.push(lit.text().to_string()) + bytes.push(lit.text().to_owned()) } _ => { return Err(mbe::ExpandError::UnexpectedToken.into()); @@ -675,10 +681,10 @@ fn include_bytes_expand( // FIXME: actually read the file here if the user asked for macro expansion let res = tt::Subtree { delimiter: tt::Delimiter::invisible_spanned(span), - token_trees: vec![tt::TokenTree::Leaf(tt::Leaf::Literal(tt::Literal { + token_trees: Box::new([tt::TokenTree::Leaf(tt::Leaf::Literal(tt::Literal { text: r#"b"""#.into(), span, - }))], + }))]), }; ExpandResult::ok(res) } @@ -743,7 +749,7 @@ fn env_expand( // We cannot use an empty string here, because for // `include!(concat!(env!("OUT_DIR"), "/foo.rs"))` will become // `include!("foo.rs"), which might go to infinite loop - "UNRESOLVED_ENV_VAR".to_string() + "UNRESOLVED_ENV_VAR".to_owned() }); let expanded = quote! {span => #s }; diff --git a/crates/hir-expand/src/db.rs b/crates/hir-expand/src/db.rs index f220284fae..6a288cf919 100644 --- a/crates/hir-expand/src/db.rs +++ b/crates/hir-expand/src/db.rs @@ -61,7 +61,7 @@ pub trait ExpandDatabase: SourceDatabase { #[salsa::input] fn proc_macros(&self) -> Arc; - #[salsa::invoke(AstIdMap::ast_id_map)] + #[salsa::invoke(AstIdMap::new)] fn ast_id_map(&self, file_id: HirFileId) -> Arc; /// Main public API -- parses a hir file, not caring whether it's a real @@ -524,7 +524,7 @@ fn macro_expand( return ExpandResult { value: CowArc::Owned(tt::Subtree { delimiter: tt::Delimiter::invisible_spanned(loc.call_site), - token_trees: Vec::new(), + token_trees: Box::new([]), }), // FIXME: We should make sure to enforce an invariant that invalid macro // calls do not reach this call path! @@ -586,7 +586,7 @@ fn macro_expand( return value.map(|()| { CowArc::Owned(tt::Subtree { delimiter: tt::Delimiter::invisible_spanned(loc.call_site), - token_trees: vec![], + token_trees: Box::new([]), }) }); } @@ -601,7 +601,7 @@ fn expand_proc_macro(db: &dyn ExpandDatabase, id: MacroCallId) -> ExpandResult ExpandResult bool { a.delimiter.kind == b.delimiter.kind && a.token_trees.len() == b.token_trees.len() - && a.token_trees.iter().zip(&b.token_trees).all(|(a, b)| check_tt_eq(a, b)) + && a.token_trees.iter().zip(b.token_trees.iter()).all(|(a, b)| check_tt_eq(a, b)) } fn check_tt_eq(a: &tt::TokenTree, b: &tt::TokenTree) -> bool { diff --git a/crates/hir-expand/src/hygiene.rs b/crates/hir-expand/src/hygiene.rs index 8ddaa3f303..65b834d7a8 100644 --- a/crates/hir-expand/src/hygiene.rs +++ b/crates/hir-expand/src/hygiene.rs @@ -7,9 +7,10 @@ use std::iter; +use base_db::salsa::{self, InternValue}; use span::{MacroCallId, Span, SyntaxContextId}; -use crate::db::ExpandDatabase; +use crate::db::{ExpandDatabase, InternSyntaxContextQuery}; #[derive(Copy, Clone, Hash, PartialEq, Eq)] pub struct SyntaxContextData { @@ -22,6 +23,14 @@ pub struct SyntaxContextData { pub opaque_and_semitransparent: SyntaxContextId, } +impl InternValue for SyntaxContextData { + type Key = (SyntaxContextId, Option, Transparency); + + fn into_key(&self) -> Self::Key { + (self.parent, self.outer_expn, self.outer_transparency) + } +} + impl std::fmt::Debug for SyntaxContextData { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SyntaxContextData") @@ -149,38 +158,36 @@ fn apply_mark_internal( transparency: Transparency, ) -> SyntaxContextId { let syntax_context_data = db.lookup_intern_syntax_context(ctxt); - let mut opaque = handle_self_ref(ctxt, syntax_context_data.opaque); - let mut opaque_and_semitransparent = - handle_self_ref(ctxt, syntax_context_data.opaque_and_semitransparent); + let mut opaque = syntax_context_data.opaque; + let mut opaque_and_semitransparent = syntax_context_data.opaque_and_semitransparent; if transparency >= Transparency::Opaque { let parent = opaque; - // Unlike rustc, with salsa we can't prefetch the to be allocated ID to create cycles with - // salsa when interning, so we use a sentinel value that effectively means the current - // syntax context. - let new_opaque = SyntaxContextId::SELF_REF; - opaque = db.intern_syntax_context(SyntaxContextData { - outer_expn: call_id, - outer_transparency: transparency, - parent, - opaque: new_opaque, - opaque_and_semitransparent: new_opaque, - }); + opaque = salsa::plumbing::get_query_table::(db).get_or_insert( + (parent, call_id, transparency), + |new_opaque| SyntaxContextData { + outer_expn: call_id, + outer_transparency: transparency, + parent, + opaque: new_opaque, + opaque_and_semitransparent: new_opaque, + }, + ); } if transparency >= Transparency::SemiTransparent { let parent = opaque_and_semitransparent; - // Unlike rustc, with salsa we can't prefetch the to be allocated ID to create cycles with - // salsa when interning, so we use a sentinel value that effectively means the current - // syntax context. - let new_opaque_and_semitransparent = SyntaxContextId::SELF_REF; - opaque_and_semitransparent = db.intern_syntax_context(SyntaxContextData { - outer_expn: call_id, - outer_transparency: transparency, - parent, - opaque, - opaque_and_semitransparent: new_opaque_and_semitransparent, - }); + opaque_and_semitransparent = + salsa::plumbing::get_query_table::(db).get_or_insert( + (parent, call_id, transparency), + |new_opaque_and_semitransparent| SyntaxContextData { + outer_expn: call_id, + outer_transparency: transparency, + parent, + opaque, + opaque_and_semitransparent: new_opaque_and_semitransparent, + }, + ); } let parent = ctxt; @@ -201,20 +208,12 @@ pub trait SyntaxContextExt { fn marks(self, db: &dyn ExpandDatabase) -> Vec<(Option, Transparency)>; } -#[inline(always)] -fn handle_self_ref(p: SyntaxContextId, n: SyntaxContextId) -> SyntaxContextId { - match n { - SyntaxContextId::SELF_REF => p, - _ => n, - } -} - impl SyntaxContextExt for SyntaxContextId { fn normalize_to_macro_rules(self, db: &dyn ExpandDatabase) -> Self { - handle_self_ref(self, db.lookup_intern_syntax_context(self).opaque_and_semitransparent) + db.lookup_intern_syntax_context(self).opaque_and_semitransparent } fn normalize_to_macros_2_0(self, db: &dyn ExpandDatabase) -> Self { - handle_self_ref(self, db.lookup_intern_syntax_context(self).opaque) + db.lookup_intern_syntax_context(self).opaque } fn parent_ctxt(self, db: &dyn ExpandDatabase) -> Self { db.lookup_intern_syntax_context(self).parent diff --git a/crates/hir-expand/src/lib.rs b/crates/hir-expand/src/lib.rs index bd25052490..fd028182fa 100644 --- a/crates/hir-expand/src/lib.rs +++ b/crates/hir-expand/src/lib.rs @@ -30,7 +30,7 @@ use triomphe::Arc; use std::{fmt, hash::Hash}; -use base_db::{CrateId, Edition, FileId}; +use base_db::{salsa::impl_intern_value_trivial, CrateId, Edition, FileId}; use either::Either; use span::{FileRange, HirFileIdRepr, Span, SyntaxContextId}; use syntax::{ @@ -66,6 +66,7 @@ pub mod tt { pub type Delimiter = ::tt::Delimiter; pub type DelimSpan = ::tt::DelimSpan; pub type Subtree = ::tt::Subtree; + pub type SubtreeBuilder = ::tt::SubtreeBuilder; pub type Leaf = ::tt::Leaf; pub type Literal = ::tt::Literal; pub type Punct = ::tt::Punct; @@ -175,6 +176,7 @@ pub struct MacroCallLoc { pub kind: MacroCallKind, pub call_site: Span, } +impl_intern_value_trivial!(MacroCallLoc); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct MacroDefId { @@ -760,7 +762,7 @@ impl ExpansionInfo { ( Arc::new(tt::Subtree { delimiter: tt::Delimiter::invisible_spanned(loc.call_site), - token_trees: Vec::new(), + token_trees: Box::new([]), }), SyntaxFixupUndoInfo::NONE, ) diff --git a/crates/hir-expand/src/quote.rs b/crates/hir-expand/src/quote.rs index 824f3c3e8f..c1930c94f5 100644 --- a/crates/hir-expand/src/quote.rs +++ b/crates/hir-expand/src/quote.rs @@ -31,7 +31,7 @@ macro_rules! __quote { open: $span, close: $span, }, - token_trees: $crate::quote::IntoTt::to_tokens(children), + token_trees: $crate::quote::IntoTt::to_tokens(children).into_boxed_slice(), } } }; @@ -146,7 +146,7 @@ impl IntoTt for Vec { fn to_subtree(self, span: Span) -> crate::tt::Subtree { crate::tt::Subtree { delimiter: crate::tt::Delimiter::invisible_spanned(span), - token_trees: self, + token_trees: self.into_boxed_slice(), } } @@ -296,8 +296,9 @@ mod tests { // } let struct_name = mk_ident("Foo"); let fields = [mk_ident("name"), mk_ident("id")]; - let fields = - fields.iter().flat_map(|it| quote!(DUMMY =>#it: self.#it.clone(), ).token_trees); + let fields = fields + .iter() + .flat_map(|it| quote!(DUMMY =>#it: self.#it.clone(), ).token_trees.into_vec()); let list = crate::tt::Subtree { delimiter: crate::tt::Delimiter { diff --git a/crates/hir-ty/src/autoderef.rs b/crates/hir-ty/src/autoderef.rs index 991fd2f91d..8d819e41aa 100644 --- a/crates/hir-ty/src/autoderef.rs +++ b/crates/hir-ty/src/autoderef.rs @@ -88,6 +88,7 @@ impl<'a, 'db> Autoderef<'a, 'db> { impl Iterator for Autoderef<'_, '_> { type Item = (Ty, usize); + #[tracing::instrument(skip_all)] fn next(&mut self) -> Option { if self.at_start { self.at_start = false; diff --git a/crates/hir-ty/src/builder.rs b/crates/hir-ty/src/builder.rs index 24a7eb3ff0..c485c9b2e8 100644 --- a/crates/hir-ty/src/builder.rs +++ b/crates/hir-ty/src/builder.rs @@ -125,6 +125,7 @@ impl TyBuilder { this } + #[tracing::instrument(skip_all)] pub(crate) fn fill_with_inference_vars(self, table: &mut InferenceTable<'_>) -> Self { self.fill(|x| match x { ParamKind::Type => table.new_type_var().cast(Interner), @@ -208,6 +209,7 @@ impl TyBuilder<()> { ) } + #[tracing::instrument(skip_all)] pub fn subst_for_def( db: &dyn HirDatabase, def: impl Into, diff --git a/crates/hir-ty/src/chalk_db.rs b/crates/hir-ty/src/chalk_db.rs index 7e460f9f86..bd243518fc 100644 --- a/crates/hir-ty/src/chalk_db.rs +++ b/crates/hir-ty/src/chalk_db.rs @@ -17,7 +17,7 @@ use hir_def::{ use hir_expand::name::name; use crate::{ - db::HirDatabase, + db::{HirDatabase, InternedCoroutine}, display::HirDisplay, from_assoc_type_id, from_chalk_trait_id, from_foreign_def_id, make_binders, make_single_type_binders, @@ -428,7 +428,7 @@ impl chalk_solve::RustIrDatabase for ChalkContext<'_> { &self, id: chalk_ir::CoroutineId, ) -> Arc> { - let (parent, expr) = self.db.lookup_intern_coroutine(id.into()); + let InternedCoroutine(parent, expr) = self.db.lookup_intern_coroutine(id.into()); // We fill substitution with unknown type, because we only need to know whether the generic // params are types or consts to build `Binders` and those being filled up are for @@ -473,7 +473,7 @@ impl chalk_solve::RustIrDatabase for ChalkContext<'_> { let inner_types = rust_ir::CoroutineWitnessExistential { types: wrap_empty_binders(vec![]) }; - let (parent, _) = self.db.lookup_intern_coroutine(id.into()); + let InternedCoroutine(parent, _) = self.db.lookup_intern_coroutine(id.into()); // See the comment in `coroutine_datum()` for unknown types. let subst = TyBuilder::subst_for_coroutine(self.db, parent).fill_with_unknown().build(); let it = subst diff --git a/crates/hir-ty/src/consteval/tests.rs b/crates/hir-ty/src/consteval/tests.rs index ac82208708..98384c4749 100644 --- a/crates/hir-ty/src/consteval/tests.rs +++ b/crates/hir-ty/src/consteval/tests.rs @@ -133,7 +133,7 @@ fn bit_op() { check_number(r#"const GOAL: i8 = 1 << 7"#, (1i8 << 7) as i128); check_number(r#"const GOAL: i8 = -1 << 2"#, (-1i8 << 2) as i128); check_fail(r#"const GOAL: i8 = 1 << 8"#, |e| { - e == ConstEvalError::MirEvalError(MirEvalError::Panic("Overflow in Shl".to_string())) + e == ConstEvalError::MirEvalError(MirEvalError::Panic("Overflow in Shl".to_owned())) }); check_number(r#"const GOAL: i32 = 100000000i32 << 11"#, (100000000i32 << 11) as i128); } @@ -2756,7 +2756,7 @@ fn memory_limit() { "#, |e| { e == ConstEvalError::MirEvalError(MirEvalError::Panic( - "Memory allocation of 30000000000 bytes failed".to_string(), + "Memory allocation of 30000000000 bytes failed".to_owned(), )) }, ); diff --git a/crates/hir-ty/src/db.rs b/crates/hir-ty/src/db.rs index 21679150b3..fbd366864a 100644 --- a/crates/hir-ty/src/db.rs +++ b/crates/hir-ty/src/db.rs @@ -3,7 +3,11 @@ use std::sync; -use base_db::{impl_intern_key, salsa, CrateId, Upcast}; +use base_db::{ + impl_intern_key, + salsa::{self, impl_intern_value_trivial}, + CrateId, Upcast, +}; use hir_def::{ db::DefDatabase, hir::ExprId, layout::TargetDataLayout, AdtId, BlockId, ConstParamId, DefWithBodyId, EnumVariantId, FunctionId, GeneralConstId, GenericDefId, ImplId, @@ -199,9 +203,9 @@ pub trait HirDatabase: DefDatabase + Upcast { #[salsa::interned] fn intern_impl_trait_id(&self, id: ImplTraitId) -> InternedOpaqueTyId; #[salsa::interned] - fn intern_closure(&self, id: (DefWithBodyId, ExprId)) -> InternedClosureId; + fn intern_closure(&self, id: InternedClosure) -> InternedClosureId; #[salsa::interned] - fn intern_coroutine(&self, id: (DefWithBodyId, ExprId)) -> InternedCoroutineId; + fn intern_coroutine(&self, id: InternedCoroutine) -> InternedCoroutineId; #[salsa::invoke(chalk_db::associated_ty_data_query)] fn associated_ty_data( @@ -337,10 +341,18 @@ impl_intern_key!(InternedOpaqueTyId); pub struct InternedClosureId(salsa::InternId); impl_intern_key!(InternedClosureId); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct InternedClosure(pub DefWithBodyId, pub ExprId); +impl_intern_value_trivial!(InternedClosure); + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct InternedCoroutineId(salsa::InternId); impl_intern_key!(InternedCoroutineId); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct InternedCoroutine(pub DefWithBodyId, pub ExprId); +impl_intern_value_trivial!(InternedCoroutine); + /// This exists just for Chalk, because Chalk just has a single `FnDefId` where /// we have different IDs for struct and enum variant constructors. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] diff --git a/crates/hir-ty/src/diagnostics/decl_check.rs b/crates/hir-ty/src/diagnostics/decl_check.rs index 78f2005e67..38eb3371e3 100644 --- a/crates/hir-ty/src/diagnostics/decl_check.rs +++ b/crates/hir-ty/src/diagnostics/decl_check.rs @@ -16,11 +16,9 @@ mod case_conv; use std::fmt; use hir_def::{ - data::adt::VariantData, - hir::{Pat, PatId}, - src::HasSource, - AdtId, AttrDefId, ConstId, EnumId, FunctionId, ItemContainerId, Lookup, ModuleDefId, ModuleId, - StaticId, StructId, + data::adt::VariantData, db::DefDatabase, hir::Pat, src::HasSource, AdtId, AttrDefId, ConstId, + EnumId, FunctionId, ItemContainerId, Lookup, ModuleDefId, ModuleId, StaticId, StructId, + TraitId, TypeAliasId, }; use hir_expand::{ name::{AsName, Name}, @@ -79,12 +77,14 @@ pub enum IdentType { Enum, Field, Function, + Module, Parameter, StaticVariable, Structure, + Trait, + TypeAlias, Variable, Variant, - Module, } impl fmt::Display for IdentType { @@ -94,12 +94,14 @@ impl fmt::Display for IdentType { IdentType::Enum => "Enum", IdentType::Field => "Field", IdentType::Function => "Function", + IdentType::Module => "Module", IdentType::Parameter => "Parameter", IdentType::StaticVariable => "Static variable", IdentType::Structure => "Structure", + IdentType::Trait => "Trait", + IdentType::TypeAlias => "Type alias", IdentType::Variable => "Variable", IdentType::Variant => "Variant", - IdentType::Module => "Module", }; repr.fmt(f) @@ -136,10 +138,12 @@ impl<'a> DeclValidator<'a> { pub(super) fn validate_item(&mut self, item: ModuleDefId) { match item { ModuleDefId::ModuleId(module_id) => self.validate_module(module_id), + ModuleDefId::TraitId(trait_id) => self.validate_trait(trait_id), ModuleDefId::FunctionId(func) => self.validate_func(func), ModuleDefId::AdtId(adt) => self.validate_adt(adt), ModuleDefId::ConstId(const_id) => self.validate_const(const_id), ModuleDefId::StaticId(static_id) => self.validate_static(static_id), + ModuleDefId::TypeAliasId(type_alias_id) => self.validate_type_alias(type_alias_id), _ => (), } } @@ -242,50 +246,46 @@ impl<'a> DeclValidator<'a> { // Check the module name. let Some(module_name) = module_id.name(self.db.upcast()) else { return }; - let module_name_replacement = + let Some(module_name_replacement) = module_name.as_str().and_then(to_lower_snake_case).map(|new_name| Replacement { current_name: module_name, suggested_text: new_name, expected_case: CaseType::LowerSnakeCase, - }); + }) + else { + return; + }; + let module_data = &module_id.def_map(self.db.upcast())[module_id.local_id]; + let Some(module_src) = module_data.declaration_source(self.db.upcast()) else { + return; + }; + self.create_incorrect_case_diagnostic_for_ast_node( + module_name_replacement, + module_src.file_id, + &module_src.value, + IdentType::Module, + ); + } - if let Some(module_name_replacement) = module_name_replacement { - let module_data = &module_id.def_map(self.db.upcast())[module_id.local_id]; - let module_src = module_data.declaration_source(self.db.upcast()); - - if let Some(module_src) = module_src { - let ast_ptr = match module_src.value.name() { - Some(name) => name, - None => { - never!( - "Replacement ({:?}) was generated for a module without a name: {:?}", - module_name_replacement, - module_src - ); - return; - } - }; - - let diagnostic = IncorrectCase { - file: module_src.file_id, - ident_type: IdentType::Module, - ident: AstPtr::new(&ast_ptr), - expected_case: module_name_replacement.expected_case, - ident_text: module_name_replacement - .current_name - .display(self.db.upcast()) - .to_string(), - suggested_text: module_name_replacement.suggested_text, - }; - - self.sink.push(diagnostic); - } + fn validate_trait(&mut self, trait_id: TraitId) { + // Check whether non-snake case identifiers are allowed for this trait. + if self.allowed(trait_id.into(), allow::NON_CAMEL_CASE_TYPES, false) { + return; } + + // Check the trait name. + let data = self.db.trait_data(trait_id); + self.create_incorrect_case_diagnostic_for_item_name( + trait_id, + &data.name, + CaseType::UpperCamelCase, + IdentType::Trait, + ); } fn validate_func(&mut self, func: FunctionId) { - let data = self.db.function_data(func); - if matches!(func.lookup(self.db.upcast()).container, ItemContainerId::ExternBlockId(_)) { + let container = func.lookup(self.db.upcast()).container; + if matches!(container, ItemContainerId::ExternBlockId(_)) { cov_mark::hit!(extern_func_incorrect_case_ignored); return; } @@ -296,270 +296,173 @@ impl<'a> DeclValidator<'a> { } // Check the function name. - let function_name = data.name.display(self.db.upcast()).to_string(); - let fn_name_replacement = to_lower_snake_case(&function_name).map(|new_name| Replacement { - current_name: data.name.clone(), - suggested_text: new_name, - expected_case: CaseType::LowerSnakeCase, - }); - - let body = self.db.body(func.into()); + // Skipped if function is an associated item of a trait implementation. + if !self.is_trait_impl_container(container) { + let data = self.db.function_data(func); + self.create_incorrect_case_diagnostic_for_item_name( + func, + &data.name, + CaseType::LowerSnakeCase, + IdentType::Function, + ); + } else { + cov_mark::hit!(trait_impl_assoc_func_name_incorrect_case_ignored); + } // Check the patterns inside the function body. - // This includes function parameters. - let pats_replacements = body + self.validate_func_body(func); + } + + /// Check incorrect names for patterns inside the function body. + /// This includes function parameters except for trait implementation associated functions. + fn validate_func_body(&mut self, func: FunctionId) { + let body = self.db.body(func.into()); + let mut pats_replacements = body .pats .iter() .filter_map(|(pat_id, pat)| match pat { - Pat::Bind { id, .. } => Some((pat_id, &body.bindings[*id].name)), + Pat::Bind { id, .. } => { + let bind_name = &body.bindings[*id].name; + let replacement = Replacement { + current_name: bind_name.clone(), + suggested_text: to_lower_snake_case(&bind_name.to_smol_str())?, + expected_case: CaseType::LowerSnakeCase, + }; + Some((pat_id, replacement)) + } _ => None, }) - .filter_map(|(id, bind_name)| { - Some(( - id, - Replacement { - current_name: bind_name.clone(), - suggested_text: to_lower_snake_case( - &bind_name.display(self.db.upcast()).to_string(), - )?, - expected_case: CaseType::LowerSnakeCase, - }, - )) - }) - .collect(); + .peekable(); - // If there is at least one element to spawn a warning on, go to the source map and generate a warning. - if let Some(fn_name_replacement) = fn_name_replacement { - self.create_incorrect_case_diagnostic_for_func(func, fn_name_replacement); - } - - self.create_incorrect_case_diagnostic_for_variables(func, pats_replacements); - } - - /// Given the information about incorrect names in the function declaration, looks up into the source code - /// for exact locations and adds diagnostics into the sink. - fn create_incorrect_case_diagnostic_for_func( - &mut self, - func: FunctionId, - fn_name_replacement: Replacement, - ) { - let fn_loc = func.lookup(self.db.upcast()); - let fn_src = fn_loc.source(self.db.upcast()); - - // Diagnostic for function name. - let ast_ptr = match fn_src.value.name() { - Some(name) => name, - None => { - never!( - "Replacement ({:?}) was generated for a function without a name: {:?}", - fn_name_replacement, - fn_src - ); - return; - } - }; - - let diagnostic = IncorrectCase { - file: fn_src.file_id, - ident_type: IdentType::Function, - ident: AstPtr::new(&ast_ptr), - expected_case: fn_name_replacement.expected_case, - ident_text: fn_name_replacement.current_name.display(self.db.upcast()).to_string(), - suggested_text: fn_name_replacement.suggested_text, - }; - - self.sink.push(diagnostic); - } - - /// Given the information about incorrect variable names, looks up into the source code - /// for exact locations and adds diagnostics into the sink. - fn create_incorrect_case_diagnostic_for_variables( - &mut self, - func: FunctionId, - pats_replacements: Vec<(PatId, Replacement)>, - ) { // XXX: only look at source_map if we do have missing fields - if pats_replacements.is_empty() { + if pats_replacements.peek().is_none() { return; } let (_, source_map) = self.db.body_with_source_map(func.into()); - for (id, replacement) in pats_replacements { - if let Ok(source_ptr) = source_map.pat_syntax(id) { - if let Some(ptr) = source_ptr.value.cast::() { - let root = source_ptr.file_syntax(self.db.upcast()); - let ident_pat = ptr.to_node(&root); - let parent = match ident_pat.syntax().parent() { - Some(parent) => parent, - None => continue, - }; - let name_ast = match ident_pat.name() { - Some(name_ast) => name_ast, - None => continue, - }; + let Ok(source_ptr) = source_map.pat_syntax(id) else { + continue; + }; + let Some(ptr) = source_ptr.value.cast::() else { + continue; + }; + let root = source_ptr.file_syntax(self.db.upcast()); + let ident_pat = ptr.to_node(&root); + let Some(parent) = ident_pat.syntax().parent() else { + continue; + }; - let is_param = ast::Param::can_cast(parent.kind()); - - // We have to check that it's either `let var = ...` or `var @ Variant(_)` statement, - // because e.g. match arms are patterns as well. - // In other words, we check that it's a named variable binding. - let is_binding = ast::LetStmt::can_cast(parent.kind()) - || (ast::MatchArm::can_cast(parent.kind()) - && ident_pat.at_token().is_some()); - if !(is_param || is_binding) { - // This pattern is not an actual variable declaration, e.g. `Some(val) => {..}` match arm. - continue; - } - - let ident_type = - if is_param { IdentType::Parameter } else { IdentType::Variable }; - - let diagnostic = IncorrectCase { - file: source_ptr.file_id, - ident_type, - ident: AstPtr::new(&name_ast), - expected_case: replacement.expected_case, - ident_text: replacement.current_name.display(self.db.upcast()).to_string(), - suggested_text: replacement.suggested_text, - }; - - self.sink.push(diagnostic); - } + let is_param = ast::Param::can_cast(parent.kind()); + // We have to check that it's either `let var = ...` or `var @ Variant(_)` statement, + // because e.g. match arms are patterns as well. + // In other words, we check that it's a named variable binding. + let is_binding = ast::LetStmt::can_cast(parent.kind()) + || (ast::MatchArm::can_cast(parent.kind()) && ident_pat.at_token().is_some()); + if !(is_param || is_binding) { + // This pattern is not an actual variable declaration, e.g. `Some(val) => {..}` match arm. + continue; } + + let ident_type = if is_param { IdentType::Parameter } else { IdentType::Variable }; + + self.create_incorrect_case_diagnostic_for_ast_node( + replacement, + source_ptr.file_id, + &ident_pat, + ident_type, + ); } } fn validate_struct(&mut self, struct_id: StructId) { - let data = self.db.struct_data(struct_id); - + // Check the structure name. let non_camel_case_allowed = self.allowed(struct_id.into(), allow::NON_CAMEL_CASE_TYPES, false); - let non_snake_case_allowed = self.allowed(struct_id.into(), allow::NON_SNAKE_CASE, false); - - // Check the structure name. - let struct_name = data.name.display(self.db.upcast()).to_string(); - let struct_name_replacement = if !non_camel_case_allowed { - to_camel_case(&struct_name).map(|new_name| Replacement { - current_name: data.name.clone(), - suggested_text: new_name, - expected_case: CaseType::UpperCamelCase, - }) - } else { - None - }; - - // Check the field names. - let mut struct_fields_replacements = Vec::new(); - - if !non_snake_case_allowed { - if let VariantData::Record(fields) = data.variant_data.as_ref() { - for (_, field) in fields.iter() { - let field_name = field.name.display(self.db.upcast()).to_string(); - if let Some(new_name) = to_lower_snake_case(&field_name) { - let replacement = Replacement { - current_name: field.name.clone(), - suggested_text: new_name, - expected_case: CaseType::LowerSnakeCase, - }; - struct_fields_replacements.push(replacement); - } - } - } + if !non_camel_case_allowed { + let data = self.db.struct_data(struct_id); + self.create_incorrect_case_diagnostic_for_item_name( + struct_id, + &data.name, + CaseType::UpperCamelCase, + IdentType::Structure, + ); } - // If there is at least one element to spawn a warning on, go to the source map and generate a warning. - self.create_incorrect_case_diagnostic_for_struct( - struct_id, - struct_name_replacement, - struct_fields_replacements, - ); + // Check the field names. + self.validate_struct_fields(struct_id); } - /// Given the information about incorrect names in the struct declaration, looks up into the source code - /// for exact locations and adds diagnostics into the sink. - fn create_incorrect_case_diagnostic_for_struct( - &mut self, - struct_id: StructId, - struct_name_replacement: Option, - struct_fields_replacements: Vec, - ) { + /// Check incorrect names for struct fields. + fn validate_struct_fields(&mut self, struct_id: StructId) { + if self.allowed(struct_id.into(), allow::NON_SNAKE_CASE, false) { + return; + } + + let data = self.db.struct_data(struct_id); + let VariantData::Record(fields) = data.variant_data.as_ref() else { + return; + }; + let mut struct_fields_replacements = fields + .iter() + .filter_map(|(_, field)| { + to_lower_snake_case(&field.name.to_smol_str()).map(|new_name| Replacement { + current_name: field.name.clone(), + suggested_text: new_name, + expected_case: CaseType::LowerSnakeCase, + }) + }) + .peekable(); + // XXX: Only look at sources if we do have incorrect names. - if struct_name_replacement.is_none() && struct_fields_replacements.is_empty() { + if struct_fields_replacements.peek().is_none() { return; } let struct_loc = struct_id.lookup(self.db.upcast()); let struct_src = struct_loc.source(self.db.upcast()); - if let Some(replacement) = struct_name_replacement { - let ast_ptr = match struct_src.value.name() { - Some(name) => name, - None => { + let Some(ast::FieldList::RecordFieldList(struct_fields_list)) = + struct_src.value.field_list() + else { + always!( + struct_fields_replacements.peek().is_none(), + "Replacements ({:?}) were generated for a structure fields \ + which had no fields list: {:?}", + struct_fields_replacements.collect::>(), + struct_src + ); + return; + }; + let mut struct_fields_iter = struct_fields_list.fields(); + for field_replacement in struct_fields_replacements { + // We assume that parameters in replacement are in the same order as in the + // actual params list, but just some of them (ones that named correctly) are skipped. + let field = loop { + if let Some(field) = struct_fields_iter.next() { + let Some(field_name) = field.name() else { + continue; + }; + if field_name.as_name() == field_replacement.current_name { + break field; + } + } else { never!( - "Replacement ({:?}) was generated for a structure without a name: {:?}", - replacement, + "Replacement ({:?}) was generated for a structure field \ + which was not found: {:?}", + field_replacement, struct_src ); return; } }; - let diagnostic = IncorrectCase { - file: struct_src.file_id, - ident_type: IdentType::Structure, - ident: AstPtr::new(&ast_ptr), - expected_case: replacement.expected_case, - ident_text: replacement.current_name.display(self.db.upcast()).to_string(), - suggested_text: replacement.suggested_text, - }; - - self.sink.push(diagnostic); - } - - let struct_fields_list = match struct_src.value.field_list() { - Some(ast::FieldList::RecordFieldList(fields)) => fields, - _ => { - always!( - struct_fields_replacements.is_empty(), - "Replacements ({:?}) were generated for a structure fields which had no fields list: {:?}", - struct_fields_replacements, - struct_src - ); - return; - } - }; - let mut struct_fields_iter = struct_fields_list.fields(); - for field_to_rename in struct_fields_replacements { - // We assume that parameters in replacement are in the same order as in the - // actual params list, but just some of them (ones that named correctly) are skipped. - let ast_ptr = loop { - match struct_fields_iter.next().and_then(|field| field.name()) { - Some(field_name) => { - if field_name.as_name() == field_to_rename.current_name { - break field_name; - } - } - None => { - never!( - "Replacement ({:?}) was generated for a structure field which was not found: {:?}", - field_to_rename, struct_src - ); - return; - } - } - }; - - let diagnostic = IncorrectCase { - file: struct_src.file_id, - ident_type: IdentType::Field, - ident: AstPtr::new(&ast_ptr), - expected_case: field_to_rename.expected_case, - ident_text: field_to_rename.current_name.display(self.db.upcast()).to_string(), - suggested_text: field_to_rename.suggested_text, - }; - - self.sink.push(diagnostic); + self.create_incorrect_case_diagnostic_for_ast_node( + field_replacement, + struct_src.file_id, + &field, + IdentType::Field, + ); } } @@ -572,163 +475,103 @@ impl<'a> DeclValidator<'a> { } // Check the enum name. - let enum_name = data.name.display(self.db.upcast()).to_string(); - let enum_name_replacement = to_camel_case(&enum_name).map(|new_name| Replacement { - current_name: data.name.clone(), - suggested_text: new_name, - expected_case: CaseType::UpperCamelCase, - }); + self.create_incorrect_case_diagnostic_for_item_name( + enum_id, + &data.name, + CaseType::UpperCamelCase, + IdentType::Enum, + ); - // Check the field names. - let enum_fields_replacements = data + // Check the variant names. + self.validate_enum_variants(enum_id) + } + + /// Check incorrect names for enum variants. + fn validate_enum_variants(&mut self, enum_id: EnumId) { + let data = self.db.enum_data(enum_id); + let mut enum_variants_replacements = data .variants .iter() .filter_map(|(_, name)| { - Some(Replacement { + to_camel_case(&name.to_smol_str()).map(|new_name| Replacement { current_name: name.clone(), - suggested_text: to_camel_case(&name.to_smol_str())?, + suggested_text: new_name, expected_case: CaseType::UpperCamelCase, }) }) - .collect(); + .peekable(); - // If there is at least one element to spawn a warning on, go to the source map and generate a warning. - self.create_incorrect_case_diagnostic_for_enum( - enum_id, - enum_name_replacement, - enum_fields_replacements, - ) - } - - /// Given the information about incorrect names in the struct declaration, looks up into the source code - /// for exact locations and adds diagnostics into the sink. - fn create_incorrect_case_diagnostic_for_enum( - &mut self, - enum_id: EnumId, - enum_name_replacement: Option, - enum_variants_replacements: Vec, - ) { // XXX: only look at sources if we do have incorrect names - if enum_name_replacement.is_none() && enum_variants_replacements.is_empty() { + if enum_variants_replacements.peek().is_none() { return; } let enum_loc = enum_id.lookup(self.db.upcast()); let enum_src = enum_loc.source(self.db.upcast()); - if let Some(replacement) = enum_name_replacement { - let ast_ptr = match enum_src.value.name() { - Some(name) => name, - None => { + let Some(enum_variants_list) = enum_src.value.variant_list() else { + always!( + enum_variants_replacements.peek().is_none(), + "Replacements ({:?}) were generated for enum variants \ + which had no fields list: {:?}", + enum_variants_replacements, + enum_src + ); + return; + }; + let mut enum_variants_iter = enum_variants_list.variants(); + for variant_replacement in enum_variants_replacements { + // We assume that parameters in replacement are in the same order as in the + // actual params list, but just some of them (ones that named correctly) are skipped. + let variant = loop { + if let Some(variant) = enum_variants_iter.next() { + let Some(variant_name) = variant.name() else { + continue; + }; + if variant_name.as_name() == variant_replacement.current_name { + break variant; + } + } else { never!( - "Replacement ({:?}) was generated for a enum without a name: {:?}", - replacement, + "Replacement ({:?}) was generated for an enum variant \ + which was not found: {:?}", + variant_replacement, enum_src ); return; } }; - let diagnostic = IncorrectCase { - file: enum_src.file_id, - ident_type: IdentType::Enum, - ident: AstPtr::new(&ast_ptr), - expected_case: replacement.expected_case, - ident_text: replacement.current_name.display(self.db.upcast()).to_string(), - suggested_text: replacement.suggested_text, - }; - - self.sink.push(diagnostic); - } - - let enum_variants_list = match enum_src.value.variant_list() { - Some(variants) => variants, - _ => { - always!( - enum_variants_replacements.is_empty(), - "Replacements ({:?}) were generated for a enum variants which had no fields list: {:?}", - enum_variants_replacements, - enum_src - ); - return; - } - }; - let mut enum_variants_iter = enum_variants_list.variants(); - for variant_to_rename in enum_variants_replacements { - // We assume that parameters in replacement are in the same order as in the - // actual params list, but just some of them (ones that named correctly) are skipped. - let ast_ptr = loop { - match enum_variants_iter.next().and_then(|v| v.name()) { - Some(variant_name) => { - if variant_name.as_name() == variant_to_rename.current_name { - break variant_name; - } - } - None => { - never!( - "Replacement ({:?}) was generated for a enum variant which was not found: {:?}", - variant_to_rename, enum_src - ); - return; - } - } - }; - - let diagnostic = IncorrectCase { - file: enum_src.file_id, - ident_type: IdentType::Variant, - ident: AstPtr::new(&ast_ptr), - expected_case: variant_to_rename.expected_case, - ident_text: variant_to_rename.current_name.display(self.db.upcast()).to_string(), - suggested_text: variant_to_rename.suggested_text, - }; - - self.sink.push(diagnostic); + self.create_incorrect_case_diagnostic_for_ast_node( + variant_replacement, + enum_src.file_id, + &variant, + IdentType::Variant, + ); } } fn validate_const(&mut self, const_id: ConstId) { - let data = self.db.const_data(const_id); + let container = const_id.lookup(self.db.upcast()).container; + if self.is_trait_impl_container(container) { + cov_mark::hit!(trait_impl_assoc_const_incorrect_case_ignored); + return; + } if self.allowed(const_id.into(), allow::NON_UPPER_CASE_GLOBAL, false) { return; } - let name = match &data.name { - Some(name) => name, - None => return, - }; - - let const_name = name.to_smol_str(); - let replacement = if let Some(new_name) = to_upper_snake_case(&const_name) { - Replacement { - current_name: name.clone(), - suggested_text: new_name, - expected_case: CaseType::UpperSnakeCase, - } - } else { - // Nothing to do here. + let data = self.db.const_data(const_id); + let Some(name) = &data.name else { return; }; - - let const_loc = const_id.lookup(self.db.upcast()); - let const_src = const_loc.source(self.db.upcast()); - - let ast_ptr = match const_src.value.name() { - Some(name) => name, - None => return, - }; - - let diagnostic = IncorrectCase { - file: const_src.file_id, - ident_type: IdentType::Constant, - ident: AstPtr::new(&ast_ptr), - expected_case: replacement.expected_case, - ident_text: replacement.current_name.display(self.db.upcast()).to_string(), - suggested_text: replacement.suggested_text, - }; - - self.sink.push(diagnostic); + self.create_incorrect_case_diagnostic_for_item_name( + const_id, + name, + CaseType::UpperSnakeCase, + IdentType::Constant, + ); } fn validate_static(&mut self, static_id: StaticId) { @@ -742,32 +585,91 @@ impl<'a> DeclValidator<'a> { return; } - let name = &data.name; + self.create_incorrect_case_diagnostic_for_item_name( + static_id, + &data.name, + CaseType::UpperSnakeCase, + IdentType::StaticVariable, + ); + } - let static_name = name.to_smol_str(); - let replacement = if let Some(new_name) = to_upper_snake_case(&static_name) { - Replacement { - current_name: name.clone(), - suggested_text: new_name, - expected_case: CaseType::UpperSnakeCase, - } - } else { - // Nothing to do here. + fn validate_type_alias(&mut self, type_alias_id: TypeAliasId) { + let container = type_alias_id.lookup(self.db.upcast()).container; + if self.is_trait_impl_container(container) { + cov_mark::hit!(trait_impl_assoc_type_incorrect_case_ignored); + return; + } + + // Check whether non-snake case identifiers are allowed for this type alias. + if self.allowed(type_alias_id.into(), allow::NON_CAMEL_CASE_TYPES, false) { + return; + } + + // Check the type alias name. + let data = self.db.type_alias_data(type_alias_id); + self.create_incorrect_case_diagnostic_for_item_name( + type_alias_id, + &data.name, + CaseType::UpperCamelCase, + IdentType::TypeAlias, + ); + } + + fn create_incorrect_case_diagnostic_for_item_name( + &mut self, + item_id: L, + name: &Name, + expected_case: CaseType, + ident_type: IdentType, + ) where + N: AstNode + HasName + fmt::Debug, + S: HasSource, + L: Lookup = dyn DefDatabase + 'a>, + { + let to_expected_case_type = match expected_case { + CaseType::LowerSnakeCase => to_lower_snake_case, + CaseType::UpperSnakeCase => to_upper_snake_case, + CaseType::UpperCamelCase => to_camel_case, + }; + let Some(replacement) = to_expected_case_type(&name.to_smol_str()).map(|new_name| { + Replacement { current_name: name.clone(), suggested_text: new_name, expected_case } + }) else { return; }; - let static_loc = static_id.lookup(self.db.upcast()); - let static_src = static_loc.source(self.db.upcast()); + let item_loc = item_id.lookup(self.db.upcast()); + let item_src = item_loc.source(self.db.upcast()); + self.create_incorrect_case_diagnostic_for_ast_node( + replacement, + item_src.file_id, + &item_src.value, + ident_type, + ); + } - let ast_ptr = match static_src.value.name() { - Some(name) => name, - None => return, + fn create_incorrect_case_diagnostic_for_ast_node( + &mut self, + replacement: Replacement, + file_id: HirFileId, + node: &T, + ident_type: IdentType, + ) where + T: AstNode + HasName + fmt::Debug, + { + let Some(name_ast) = node.name() else { + never!( + "Replacement ({:?}) was generated for a {:?} without a name: {:?}", + replacement, + ident_type, + node + ); + return; }; let diagnostic = IncorrectCase { - file: static_src.file_id, - ident_type: IdentType::StaticVariable, - ident: AstPtr::new(&ast_ptr), + file: file_id, + ident_type, + ident: AstPtr::new(&name_ast), expected_case: replacement.expected_case, ident_text: replacement.current_name.display(self.db.upcast()).to_string(), suggested_text: replacement.suggested_text, @@ -775,4 +677,13 @@ impl<'a> DeclValidator<'a> { self.sink.push(diagnostic); } + + fn is_trait_impl_container(&self, container_id: ItemContainerId) -> bool { + if let ItemContainerId::ImplId(impl_id) = container_id { + if self.db.impl_trait(impl_id).is_some() { + return true; + } + } + false + } } diff --git a/crates/hir-ty/src/diagnostics/expr.rs b/crates/hir-ty/src/diagnostics/expr.rs index eda8f2371c..7f8fb7f4b5 100644 --- a/crates/hir-ty/src/diagnostics/expr.rs +++ b/crates/hir-ty/src/diagnostics/expr.rs @@ -27,7 +27,7 @@ use crate::{ pub(crate) use hir_def::{ body::Body, - hir::{Expr, ExprId, MatchArm, Pat, PatId}, + hir::{Expr, ExprId, MatchArm, Pat, PatId, Statement}, LocalFieldId, VariantId, }; @@ -44,6 +44,12 @@ pub enum BodyValidationDiagnostic { match_expr: ExprId, uncovered_patterns: String, }, + RemoveTrailingReturn { + return_expr: ExprId, + }, + RemoveUnnecessaryElse { + if_expr: ExprId, + }, } impl BodyValidationDiagnostic { @@ -72,6 +78,10 @@ impl ExprValidator { let body = db.body(self.owner); let mut filter_map_next_checker = None; + if matches!(self.owner, DefWithBodyId::FunctionId(_)) { + self.check_for_trailing_return(body.body_expr, &body); + } + for (id, expr) in body.exprs.iter() { if let Some((variant, missed_fields, true)) = record_literal_missing_fields(db, &self.infer, id, expr) @@ -90,9 +100,16 @@ impl ExprValidator { Expr::Call { .. } | Expr::MethodCall { .. } => { self.validate_call(db, id, expr, &mut filter_map_next_checker); } + Expr::Closure { body: body_expr, .. } => { + self.check_for_trailing_return(*body_expr, &body); + } + Expr::If { .. } => { + self.check_for_unnecessary_else(id, expr, &body); + } _ => {} } } + for (id, pat) in body.pats.iter() { if let Some((variant, missed_fields, true)) = record_pattern_missing_fields(db, &self.infer, id, pat) @@ -153,14 +170,7 @@ impl ExprValidator { } let pattern_arena = Arena::new(); - let ty_arena = Arena::new(); - let cx = MatchCheckCtx::new( - self.owner.module(db.upcast()), - self.owner, - db, - &pattern_arena, - &ty_arena, - ); + let cx = MatchCheckCtx::new(self.owner.module(db.upcast()), self.owner, db, &pattern_arena); let mut m_arms = Vec::with_capacity(arms.len()); let mut has_lowering_errors = false; @@ -207,7 +217,7 @@ impl ExprValidator { } let report = match compute_match_usefulness( - rustc_pattern_analysis::MatchCtxt { tycx: &cx }, + &cx, m_arms.as_slice(), scrut_ty.clone(), ValidityConstraint::ValidOnly, @@ -244,6 +254,59 @@ impl ExprValidator { } pattern } + + fn check_for_trailing_return(&mut self, body_expr: ExprId, body: &Body) { + match &body.exprs[body_expr] { + Expr::Block { statements, tail, .. } => { + let last_stmt = tail.or_else(|| match statements.last()? { + Statement::Expr { expr, .. } => Some(*expr), + _ => None, + }); + if let Some(last_stmt) = last_stmt { + self.check_for_trailing_return(last_stmt, body); + } + } + Expr::If { then_branch, else_branch, .. } => { + self.check_for_trailing_return(*then_branch, body); + if let Some(else_branch) = else_branch { + self.check_for_trailing_return(*else_branch, body); + } + } + Expr::Match { arms, .. } => { + for arm in arms.iter() { + let MatchArm { expr, .. } = arm; + self.check_for_trailing_return(*expr, body); + } + } + Expr::Return { .. } => { + self.diagnostics.push(BodyValidationDiagnostic::RemoveTrailingReturn { + return_expr: body_expr, + }); + } + _ => (), + } + } + + fn check_for_unnecessary_else(&mut self, id: ExprId, expr: &Expr, body: &Body) { + if let Expr::If { condition: _, then_branch, else_branch } = expr { + if else_branch.is_none() { + return; + } + if let Expr::Block { statements, tail, .. } = &body.exprs[*then_branch] { + let last_then_expr = tail.or_else(|| match statements.last()? { + Statement::Expr { expr, .. } => Some(*expr), + _ => None, + }); + if let Some(last_then_expr) = last_then_expr { + let last_then_expr_ty = &self.infer[last_then_expr]; + if last_then_expr_ty.is_never() { + self.diagnostics + .push(BodyValidationDiagnostic::RemoveUnnecessaryElse { if_expr: id }) + } + } + } + } + } } struct FilterMapNextChecker { diff --git a/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs b/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs index 0b595042cd..712842372b 100644 --- a/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs +++ b/crates/hir-ty/src/diagnostics/match_check/pat_analysis.rs @@ -9,7 +9,7 @@ use rustc_pattern_analysis::{ index::IdxContainer, Captures, TypeCx, }; -use smallvec::SmallVec; +use smallvec::{smallvec, SmallVec}; use stdx::never; use typed_arena::Arena; @@ -41,8 +41,14 @@ pub(crate) struct MatchCheckCtx<'p> { body: DefWithBodyId, pub(crate) db: &'p dyn HirDatabase, pub(crate) pattern_arena: &'p Arena>, - ty_arena: &'p Arena, exhaustive_patterns: bool, + min_exhaustive_patterns: bool, +} + +#[derive(Clone)] +pub(crate) struct PatData<'p> { + /// Keep db around so that we can print variant names in `Debug`. + pub(crate) db: &'p dyn HirDatabase, } impl<'p> MatchCheckCtx<'p> { @@ -51,11 +57,12 @@ impl<'p> MatchCheckCtx<'p> { body: DefWithBodyId, db: &'p dyn HirDatabase, pattern_arena: &'p Arena>, - ty_arena: &'p Arena, ) -> Self { let def_map = db.crate_def_map(module.krate()); let exhaustive_patterns = def_map.is_unstable_feature_enabled("exhaustive_patterns"); - Self { module, body, db, pattern_arena, exhaustive_patterns, ty_arena } + let min_exhaustive_patterns = + def_map.is_unstable_feature_enabled("min_exhaustive_patterns"); + Self { module, body, db, pattern_arena, exhaustive_patterns, min_exhaustive_patterns } } fn is_uninhabited(&self, ty: &Ty) -> bool { @@ -75,18 +82,15 @@ impl<'p> MatchCheckCtx<'p> { } } - fn variant_id_for_adt(&self, ctor: &Constructor, adt: hir_def::AdtId) -> VariantId { + fn variant_id_for_adt(ctor: &Constructor, adt: hir_def::AdtId) -> Option { match ctor { - &Variant(id) => id.into(), - Struct | UnionField => { - assert!(!matches!(adt, hir_def::AdtId::EnumId(_))); - match adt { - hir_def::AdtId::EnumId(_) => unreachable!(), - hir_def::AdtId::StructId(id) => id.into(), - hir_def::AdtId::UnionId(id) => id.into(), - } - } - _ => panic!("bad constructor {self:?} for adt {adt:?}"), + &Variant(id) => Some(id.into()), + Struct | UnionField => match adt { + hir_def::AdtId::EnumId(_) => None, + hir_def::AdtId::StructId(id) => Some(id.into()), + hir_def::AdtId::UnionId(id) => Some(id.into()), + }, + _ => panic!("bad constructor {ctor:?} for adt {adt:?}"), } } @@ -200,7 +204,7 @@ impl<'p> MatchCheckCtx<'p> { Wildcard } }; - let variant = self.variant_id_for_adt(&ctor, adt.0); + let variant = Self::variant_id_for_adt(&ctor, adt.0).unwrap(); let fields_len = variant.variant_data(self.db.upcast()).fields().len(); // For each field in the variant, we store the relevant index into `self.fields` if any. let mut field_id_to_id: Vec> = vec![None; fields_len]; @@ -241,7 +245,8 @@ impl<'p> MatchCheckCtx<'p> { fields = self.pattern_arena.alloc_extend(subpats); } } - DeconstructedPat::new(ctor, fields, pat.ty.clone(), ()) + let data = PatData { db: self.db }; + DeconstructedPat::new(ctor, fields, pat.ty.clone(), data) } pub(crate) fn hoist_witness_pat(&self, pat: &WitnessPat<'p>) -> Pat { @@ -266,7 +271,7 @@ impl<'p> MatchCheckCtx<'p> { PatKind::Deref { subpattern: subpatterns.next().unwrap() } } TyKind::Adt(adt, substs) => { - let variant = self.variant_id_for_adt(pat.ctor(), adt.0); + let variant = Self::variant_id_for_adt(pat.ctor(), adt.0).unwrap(); let subpatterns = self .list_variant_nonhidden_fields(pat.ty(), variant) .zip(subpatterns) @@ -307,11 +312,14 @@ impl<'p> TypeCx for MatchCheckCtx<'p> { type VariantIdx = EnumVariantId; type StrLit = Void; type ArmData = (); - type PatData = (); + type PatData = PatData<'p>; fn is_exhaustive_patterns_feature_on(&self) -> bool { self.exhaustive_patterns } + fn is_min_exhaustive_patterns_feature_on(&self) -> bool { + self.min_exhaustive_patterns + } fn ctor_arity( &self, @@ -327,7 +335,7 @@ impl<'p> TypeCx for MatchCheckCtx<'p> { // patterns. If we're here we can assume this is a box pattern. 1 } else { - let variant = self.variant_id_for_adt(ctor, adt); + let variant = Self::variant_id_for_adt(ctor, adt).unwrap(); self.list_variant_nonhidden_fields(ty, variant).count() } } @@ -347,54 +355,51 @@ impl<'p> TypeCx for MatchCheckCtx<'p> { } } - fn ctor_sub_tys( - &self, - ctor: &rustc_pattern_analysis::constructor::Constructor, - ty: &Self::Ty, - ) -> &[Self::Ty] { - use std::iter::once; - fn alloc<'a>(cx: &'a MatchCheckCtx<'_>, iter: impl Iterator) -> &'a [Ty] { - cx.ty_arena.alloc_extend(iter) - } - match ctor { + fn ctor_sub_tys<'a>( + &'a self, + ctor: &'a rustc_pattern_analysis::constructor::Constructor, + ty: &'a Self::Ty, + ) -> impl ExactSizeIterator + Captures<'a> { + let single = |ty| smallvec![ty]; + let tys: SmallVec<[_; 2]> = match ctor { Struct | Variant(_) | UnionField => match ty.kind(Interner) { TyKind::Tuple(_, substs) => { let tys = substs.iter(Interner).map(|ty| ty.assert_ty_ref(Interner)); - alloc(self, tys.cloned()) + tys.cloned().collect() } - TyKind::Ref(.., rty) => alloc(self, once(rty.clone())), + TyKind::Ref(.., rty) => single(rty.clone()), &TyKind::Adt(AdtId(adt), ref substs) => { if is_box(self.db, adt) { // The only legal patterns of type `Box` (outside `std`) are `_` and box // patterns. If we're here we can assume this is a box pattern. let subst_ty = substs.at(Interner, 0).assert_ty_ref(Interner).clone(); - alloc(self, once(subst_ty)) + single(subst_ty) } else { - let variant = self.variant_id_for_adt(ctor, adt); - let tys = self.list_variant_nonhidden_fields(ty, variant).map(|(_, ty)| ty); - alloc(self, tys) + let variant = Self::variant_id_for_adt(ctor, adt).unwrap(); + self.list_variant_nonhidden_fields(ty, variant).map(|(_, ty)| ty).collect() } } ty_kind => { never!("Unexpected type for `{:?}` constructor: {:?}", ctor, ty_kind); - alloc(self, once(ty.clone())) + single(ty.clone()) } }, Ref => match ty.kind(Interner) { - TyKind::Ref(.., rty) => alloc(self, once(rty.clone())), + TyKind::Ref(.., rty) => single(rty.clone()), ty_kind => { never!("Unexpected type for `{:?}` constructor: {:?}", ctor, ty_kind); - alloc(self, once(ty.clone())) + single(ty.clone()) } }, Slice(_) => unreachable!("Found a `Slice` constructor in match checking"), Bool(..) | IntRange(..) | F32Range(..) | F64Range(..) | Str(..) | Opaque(..) - | NonExhaustive | Hidden | Missing | Wildcard => &[], + | NonExhaustive | Hidden | Missing | Wildcard => smallvec![], Or => { never!("called `Fields::wildcards` on an `Or` ctor"); - &[] + smallvec![] } - } + }; + tys.into_iter() } fn ctors_for_ty( @@ -456,11 +461,27 @@ impl<'p> TypeCx for MatchCheckCtx<'p> { }) } - fn debug_pat( - _f: &mut fmt::Formatter<'_>, - _pat: &rustc_pattern_analysis::pat::DeconstructedPat<'_, Self>, + fn write_variant_name( + f: &mut fmt::Formatter<'_>, + pat: &rustc_pattern_analysis::pat::DeconstructedPat<'_, Self>, ) -> fmt::Result { - // FIXME: implement this, as using `unimplemented!()` causes panics in `tracing`. + let variant = + pat.ty().as_adt().and_then(|(adt, _)| Self::variant_id_for_adt(pat.ctor(), adt)); + + let db = pat.data().unwrap().db; + if let Some(variant) = variant { + match variant { + VariantId::EnumVariantId(v) => { + write!(f, "{}", db.enum_variant_data(v).name.display(db.upcast()))?; + } + VariantId::StructId(s) => { + write!(f, "{}", db.struct_data(s).name.display(db.upcast()))? + } + VariantId::UnionId(u) => { + write!(f, "{}", db.union_data(u).name.display(db.upcast()))? + } + } + } Ok(()) } diff --git a/crates/hir-ty/src/display.rs b/crates/hir-ty/src/display.rs index 2327c8df1b..fe51ec3f82 100644 --- a/crates/hir-ty/src/display.rs +++ b/crates/hir-ty/src/display.rs @@ -32,7 +32,7 @@ use triomphe::Arc; use crate::{ consteval::try_const_usize, - db::HirDatabase, + db::{HirDatabase, InternedClosure}, from_assoc_type_id, from_foreign_def_id, from_placeholder_idx, layout::Layout, lt_from_placeholder_idx, @@ -814,9 +814,8 @@ impl HirDisplay for Ty { // Don't count Sized but count when it absent // (i.e. when explicit ?Sized bound is set). - let default_sized = SizedByDefault::Sized { - anchor: func.lookup(db.upcast()).module(db.upcast()).krate(), - }; + let default_sized = + SizedByDefault::Sized { anchor: func.krate(db.upcast()) }; let sized_bounds = bounds .skip_binders() .iter() @@ -1025,7 +1024,7 @@ impl HirDisplay for Ty { let data = (*datas).as_ref().map(|rpit| rpit.impl_traits[idx].bounds.clone()); let bounds = data.substitute(Interner, ¶meters); - let krate = func.lookup(db.upcast()).module(db.upcast()).krate(); + let krate = func.krate(db.upcast()); write_bounds_like_dyn_trait_with_prefix( f, "impl", @@ -1086,7 +1085,7 @@ impl HirDisplay for Ty { } let sig = ClosureSubst(substs).sig_ty().callable_sig(db); if let Some(sig) = sig { - let (def, _) = db.lookup_intern_closure((*id).into()); + let InternedClosure(def, _) = db.lookup_intern_closure((*id).into()); let infer = db.infer(def); let (_, kind) = infer.closure_info(id); match f.closure_style { @@ -1191,7 +1190,7 @@ impl HirDisplay for Ty { let data = (*datas).as_ref().map(|rpit| rpit.impl_traits[idx].bounds.clone()); let bounds = data.substitute(Interner, &opaque_ty.substitution); - let krate = func.lookup(db.upcast()).module(db.upcast()).krate(); + let krate = func.krate(db.upcast()); write_bounds_like_dyn_trait_with_prefix( f, "impl", diff --git a/crates/hir-ty/src/infer/closure.rs b/crates/hir-ty/src/infer/closure.rs index 572df8f713..c3746f7870 100644 --- a/crates/hir-ty/src/infer/closure.rs +++ b/crates/hir-ty/src/infer/closure.rs @@ -21,7 +21,7 @@ use smallvec::SmallVec; use stdx::never; use crate::{ - db::HirDatabase, + db::{HirDatabase, InternedClosure}, from_placeholder_idx, make_binders, mir::{BorrowKind, MirSpan, ProjectionElem}, static_lifetime, to_chalk_trait_id, @@ -194,17 +194,15 @@ impl CapturedItem { } let variant_data = f.parent.variant_data(db.upcast()); let field = match &*variant_data { - VariantData::Record(fields) => fields[f.local_id] - .name - .as_str() - .unwrap_or("[missing field]") - .to_string(), + VariantData::Record(fields) => { + fields[f.local_id].name.as_str().unwrap_or("[missing field]").to_owned() + } VariantData::Tuple(fields) => fields .iter() .position(|it| it.0 == f.local_id) .unwrap_or_default() .to_string(), - VariantData::Unit => "[missing field]".to_string(), + VariantData::Unit => "[missing field]".to_owned(), }; result = format!("{result}.{field}"); field_need_paren = false; @@ -718,7 +716,7 @@ impl InferenceContext<'_> { fn is_upvar(&self, place: &HirPlace) -> bool { if let Some(c) = self.current_closure { - let (_, root) = self.db.lookup_intern_closure(c.into()); + let InternedClosure(_, root) = self.db.lookup_intern_closure(c.into()); return self.body.is_binding_upvar(place.local, root); } false @@ -940,7 +938,7 @@ impl InferenceContext<'_> { } fn analyze_closure(&mut self, closure: ClosureId) -> FnTrait { - let (_, root) = self.db.lookup_intern_closure(closure.into()); + let InternedClosure(_, root) = self.db.lookup_intern_closure(closure.into()); self.current_closure = Some(closure); let Expr::Closure { body, capture_by, .. } = &self.body[root] else { unreachable!("Closure expression id is always closure"); diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index 842f7bdafe..8b8e97b008 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -23,6 +23,7 @@ use syntax::ast::RangeOp; use crate::{ autoderef::{builtin_deref, deref_by_trait, Autoderef}, consteval, + db::{InternedClosure, InternedCoroutine}, infer::{ coerce::{CoerceMany, CoercionCause}, find_continuable, @@ -253,13 +254,17 @@ impl InferenceContext<'_> { .push(ret_ty.clone()) .build(); - let coroutine_id = self.db.intern_coroutine((self.owner, tgt_expr)).into(); + let coroutine_id = self + .db + .intern_coroutine(InternedCoroutine(self.owner, tgt_expr)) + .into(); let coroutine_ty = TyKind::Coroutine(coroutine_id, subst).intern(Interner); (None, coroutine_ty, Some((resume_ty, yield_ty))) } ClosureKind::Closure | ClosureKind::Async => { - let closure_id = self.db.intern_closure((self.owner, tgt_expr)).into(); + let closure_id = + self.db.intern_closure(InternedClosure(self.owner, tgt_expr)).into(); let closure_ty = TyKind::Closure( closure_id, TyBuilder::subst_for_closure(self.db, self.owner, sig_ty.clone()), diff --git a/crates/hir-ty/src/infer/unify.rs b/crates/hir-ty/src/infer/unify.rs index 61c8233950..de23ca3499 100644 --- a/crates/hir-ty/src/infer/unify.rs +++ b/crates/hir-ty/src/infer/unify.rs @@ -469,12 +469,14 @@ impl<'a> InferenceTable<'a> { } } + #[tracing::instrument(skip_all)] pub(crate) fn rollback_to(&mut self, snapshot: InferenceTableSnapshot) { self.var_unification_table.rollback_to(snapshot.var_table_snapshot); self.type_variable_table = snapshot.type_variable_table_snapshot; self.pending_obligations = snapshot.pending_obligations; } + #[tracing::instrument(skip_all)] pub(crate) fn run_in_snapshot(&mut self, f: impl FnOnce(&mut InferenceTable<'_>) -> T) -> T { let snapshot = self.snapshot(); let result = f(self); diff --git a/crates/hir-ty/src/layout.rs b/crates/hir-ty/src/layout.rs index 310c4cc9ff..be1c8d9094 100644 --- a/crates/hir-ty/src/layout.rs +++ b/crates/hir-ty/src/layout.rs @@ -19,8 +19,12 @@ use stdx::never; use triomphe::Arc; use crate::{ - consteval::try_const_usize, db::HirDatabase, infer::normalize, layout::adt::struct_variant_idx, - utils::ClosureSubst, Interner, ProjectionTy, Substitution, TraitEnvironment, Ty, + consteval::try_const_usize, + db::{HirDatabase, InternedClosure}, + infer::normalize, + layout::adt::struct_variant_idx, + utils::ClosureSubst, + Interner, ProjectionTy, Substitution, TraitEnvironment, Ty, }; pub use self::{ @@ -391,7 +395,7 @@ pub fn layout_of_ty_query( } } TyKind::Closure(c, subst) => { - let (def, _) = db.lookup_intern_closure((*c).into()); + let InternedClosure(def, _) = db.lookup_intern_closure((*c).into()); let infer = db.infer(def); let (captures, _) = infer.closure_info(c); let fields = captures diff --git a/crates/hir-ty/src/layout/tests.rs b/crates/hir-ty/src/layout/tests.rs index 1f2ea753c1..ba3dfe8100 100644 --- a/crates/hir-ty/src/layout/tests.rs +++ b/crates/hir-ty/src/layout/tests.rs @@ -1,8 +1,7 @@ -use std::collections::HashMap; - use chalk_ir::{AdtId, TyKind}; use either::Either; use hir_def::db::DefDatabase; +use rustc_hash::FxHashMap; use test_fixture::WithFixture; use triomphe::Arc; @@ -16,7 +15,7 @@ use crate::{ mod closure; fn current_machine_data_layout() -> String { - project_model::target_data_layout::get(None, None, &HashMap::default()).unwrap() + project_model::target_data_layout::get(None, None, &FxHashMap::default()).unwrap() } fn eval_goal(ra_fixture: &str, minicore: &str) -> Result, LayoutError> { diff --git a/crates/hir-ty/src/lib.rs b/crates/hir-ty/src/lib.rs index 288c42405d..7013863334 100644 --- a/crates/hir-ty/src/lib.rs +++ b/crates/hir-ty/src/lib.rs @@ -51,6 +51,7 @@ use std::{ hash::{BuildHasherDefault, Hash}, }; +use base_db::salsa::impl_intern_value_trivial; use chalk_ir::{ fold::{Shift, TypeFoldable}, interner::HasInterner, @@ -228,7 +229,7 @@ impl MemoryMap { &self, mut f: impl FnMut(&[u8], usize) -> Result, ) -> Result, MirEvalError> { - let mut transform = |(addr, val): (&usize, &Box<[u8]>)| { + let mut transform = |(addr, val): (&usize, &[u8])| { let addr = *addr; let align = if addr == 0 { 64 } else { (addr - (addr & (addr - 1))).min(64) }; f(val, align).map(|it| (addr, it)) @@ -240,7 +241,9 @@ impl MemoryMap { map.insert(addr, val); map }), - MemoryMap::Complex(cm) => cm.memory.iter().map(transform).collect(), + MemoryMap::Complex(cm) => { + cm.memory.iter().map(|(addr, val)| transform((addr, val))).collect() + } } } @@ -584,6 +587,7 @@ pub enum ImplTraitId { ReturnTypeImplTrait(hir_def::FunctionId, RpitId), AsyncBlockTypeImplTrait(hir_def::DefWithBodyId, ExprId), } +impl_intern_value_trivial!(ImplTraitId); #[derive(Clone, PartialEq, Eq, Debug, Hash)] pub struct ReturnTypeImplTraits { diff --git a/crates/hir-ty/src/lower.rs b/crates/hir-ty/src/lower.rs index 386a03d93f..75ac3b0d66 100644 --- a/crates/hir-ty/src/lower.rs +++ b/crates/hir-ty/src/lower.rs @@ -10,7 +10,10 @@ use std::{ iter, }; -use base_db::{salsa::Cycle, CrateId}; +use base_db::{ + salsa::{impl_intern_value_trivial, Cycle}, + CrateId, +}; use chalk_ir::{ cast::Cast, fold::Shift, fold::TypeFoldable, interner::HasInterner, Mutability, Safety, }; @@ -1225,7 +1228,7 @@ impl<'a> TyLoweringContext<'a> { .collect(); if !ctx.unsized_types.borrow().contains(&self_ty) { - let krate = func.lookup(ctx.db.upcast()).module(ctx.db.upcast()).krate(); + let krate = func.krate(ctx.db.upcast()); let sized_trait = ctx .db .lang_item(krate, LangItem::Sized) @@ -1809,6 +1812,7 @@ pub enum CallableDefId { StructId(StructId), EnumVariantId(EnumVariantId), } +impl_intern_value_trivial!(CallableDefId); impl_from!(FunctionId, StructId, EnumVariantId for CallableDefId); impl From for ModuleDefId { fn from(def: CallableDefId) -> ModuleDefId { @@ -1824,11 +1828,10 @@ impl CallableDefId { pub fn krate(self, db: &dyn HirDatabase) -> CrateId { let db = db.upcast(); match self { - CallableDefId::FunctionId(f) => f.lookup(db).module(db), - CallableDefId::StructId(s) => s.lookup(db).container, - CallableDefId::EnumVariantId(e) => e.module(db), + CallableDefId::FunctionId(f) => f.krate(db), + CallableDefId::StructId(s) => s.krate(db), + CallableDefId::EnumVariantId(e) => e.krate(db), } - .krate() } } diff --git a/crates/hir-ty/src/method_resolution.rs b/crates/hir-ty/src/method_resolution.rs index 1c068bf684..a4baf572d9 100644 --- a/crates/hir-ty/src/method_resolution.rs +++ b/crates/hir-ty/src/method_resolution.rs @@ -931,6 +931,15 @@ pub fn iterate_method_candidates_dyn( mode: LookupMode, callback: &mut dyn FnMut(ReceiverAdjustments, AssocItemId, bool) -> ControlFlow<()>, ) -> ControlFlow<()> { + let _p = tracing::span!( + tracing::Level::INFO, + "iterate_method_candidates_dyn", + ?mode, + ?name, + traits_in_scope_len = traits_in_scope.len() + ) + .entered(); + match mode { LookupMode::MethodCall => { // For method calls, rust first does any number of autoderef, and @@ -984,6 +993,7 @@ pub fn iterate_method_candidates_dyn( } } +#[tracing::instrument(skip_all, fields(name = ?name))] fn iterate_method_candidates_with_autoref( receiver_ty: &Canonical, first_adjustment: ReceiverAdjustments, @@ -1041,6 +1051,7 @@ fn iterate_method_candidates_with_autoref( ) } +#[tracing::instrument(skip_all, fields(name = ?name))] fn iterate_method_candidates_by_receiver( receiver_ty: &Canonical, receiver_adjustments: ReceiverAdjustments, @@ -1088,6 +1099,7 @@ fn iterate_method_candidates_by_receiver( ControlFlow::Continue(()) } +#[tracing::instrument(skip_all, fields(name = ?name))] fn iterate_method_candidates_for_self_ty( self_ty: &Canonical, db: &dyn HirDatabase, @@ -1119,6 +1131,7 @@ fn iterate_method_candidates_for_self_ty( ) } +#[tracing::instrument(skip_all, fields(name = ?name, visible_from_module, receiver_ty))] fn iterate_trait_method_candidates( self_ty: &Ty, table: &mut InferenceTable<'_>, @@ -1175,6 +1188,7 @@ fn iterate_trait_method_candidates( ControlFlow::Continue(()) } +#[tracing::instrument(skip_all, fields(name = ?name, visible_from_module, receiver_ty))] fn iterate_inherent_methods( self_ty: &Ty, table: &mut InferenceTable<'_>, @@ -1267,6 +1281,7 @@ fn iterate_inherent_methods( } return ControlFlow::Continue(()); + #[tracing::instrument(skip_all, fields(name = ?name, visible_from_module, receiver_ty))] fn iterate_inherent_trait_methods( self_ty: &Ty, table: &mut InferenceTable<'_>, @@ -1293,6 +1308,7 @@ fn iterate_inherent_methods( ControlFlow::Continue(()) } + #[tracing::instrument(skip_all, fields(name = ?name, visible_from_module, receiver_ty))] fn impls_for_self_ty( impls: &InherentImpls, self_ty: &Ty, @@ -1356,6 +1372,7 @@ macro_rules! check_that { }; } +#[tracing::instrument(skip_all, fields(name))] fn is_valid_candidate( table: &mut InferenceTable<'_>, name: Option<&Name>, @@ -1403,6 +1420,7 @@ enum IsValidCandidate { NotVisible, } +#[tracing::instrument(skip_all, fields(name))] fn is_valid_fn_candidate( table: &mut InferenceTable<'_>, fn_id: FunctionId, @@ -1439,15 +1457,15 @@ fn is_valid_fn_candidate( _ => unreachable!(), }; - let fn_subst = TyBuilder::subst_for_def(db, fn_id, Some(impl_subst.clone())) - .fill_with_inference_vars(table) - .build(); - check_that!(table.unify(&expect_self_ty, self_ty)); if let Some(receiver_ty) = receiver_ty { check_that!(data.has_self_param()); + let fn_subst = TyBuilder::subst_for_def(db, fn_id, Some(impl_subst.clone())) + .fill_with_inference_vars(table) + .build(); + let sig = db.callable_item_signature(fn_id.into()); let expected_receiver = sig.map(|s| s.params()[0].clone()).substitute(Interner, &fn_subst); @@ -1540,6 +1558,7 @@ pub fn implements_trait_unique( /// This creates Substs for a trait with the given Self type and type variables /// for all other parameters, to query Chalk with it. +#[tracing::instrument(skip_all)] fn generic_implements_goal( db: &dyn HirDatabase, env: Arc, diff --git a/crates/hir-ty/src/mir/borrowck.rs b/crates/hir-ty/src/mir/borrowck.rs index ea4e60cad3..9089c11c5d 100644 --- a/crates/hir-ty/src/mir/borrowck.rs +++ b/crates/hir-ty/src/mir/borrowck.rs @@ -11,7 +11,10 @@ use stdx::never; use triomphe::Arc; use crate::{ - db::HirDatabase, mir::Operand, utils::ClosureSubst, ClosureId, Interner, Ty, TyExt, TypeFlags, + db::{HirDatabase, InternedClosure}, + mir::Operand, + utils::ClosureSubst, + ClosureId, Interner, Ty, TyExt, TypeFlags, }; use super::{ @@ -97,7 +100,7 @@ fn moved_out_of_ref(db: &dyn HirDatabase, body: &MirBody) -> Vec ty, db, |c, subst, f| { - let (def, _) = db.lookup_intern_closure(c.into()); + let InternedClosure(def, _) = db.lookup_intern_closure(c.into()); let infer = db.infer(def); let (captures, _) = infer.closure_info(&c); let parent_subst = ClosureSubst(subst).parent_subst(); @@ -215,7 +218,7 @@ fn place_case(db: &dyn HirDatabase, body: &MirBody, lvalue: &Place) -> Projectio ty, db, |c, subst, f| { - let (def, _) = db.lookup_intern_closure(c.into()); + let InternedClosure(def, _) = db.lookup_intern_closure(c.into()); let infer = db.infer(def); let (captures, _) = infer.closure_info(&c); let parent_subst = ClosureSubst(subst).parent_subst(); diff --git a/crates/hir-ty/src/mir/eval.rs b/crates/hir-ty/src/mir/eval.rs index 50c4d00660..2428678d72 100644 --- a/crates/hir-ty/src/mir/eval.rs +++ b/crates/hir-ty/src/mir/eval.rs @@ -25,7 +25,7 @@ use triomphe::Arc; use crate::{ consteval::{intern_const_scalar, try_const_usize, ConstEvalError}, - db::HirDatabase, + db::{HirDatabase, InternedClosure}, display::{ClosureStyle, HirDisplay}, infer::PointerCast, layout::{Layout, LayoutError, RustcEnumVariantIdx}, @@ -647,7 +647,7 @@ impl Evaluator<'_> { ty.clone(), self.db, |c, subst, f| { - let (def, _) = self.db.lookup_intern_closure(c.into()); + let InternedClosure(def, _) = self.db.lookup_intern_closure(c.into()); let infer = self.db.infer(def); let (captures, _) = infer.closure_info(&c); let parent_subst = ClosureSubst(subst).parent_subst(); @@ -1763,7 +1763,7 @@ impl Evaluator<'_> { } }; mem.get(pos..pos + size) - .ok_or_else(|| MirEvalError::UndefinedBehavior("out of bound memory read".to_string())) + .ok_or_else(|| MirEvalError::UndefinedBehavior("out of bound memory read".to_owned())) } fn write_memory_using_ref(&mut self, addr: Address, size: usize) -> Result<&mut [u8]> { @@ -1777,7 +1777,7 @@ impl Evaluator<'_> { } }; mem.get_mut(pos..pos + size) - .ok_or_else(|| MirEvalError::UndefinedBehavior("out of bound memory write".to_string())) + .ok_or_else(|| MirEvalError::UndefinedBehavior("out of bound memory write".to_owned())) } fn write_memory(&mut self, addr: Address, r: &[u8]) -> Result<()> { @@ -1800,7 +1800,7 @@ impl Evaluator<'_> { return Ok(()); } - let oob = || MirEvalError::UndefinedBehavior("out of bounds memory write".to_string()); + let oob = || MirEvalError::UndefinedBehavior("out of bounds memory write".to_owned()); match (addr, r.addr) { (Stack(dst), Stack(src)) => { @@ -2653,7 +2653,7 @@ pub fn render_const_using_debug_impl( ptr: ArenaMap::new(), body: db .mir_body(owner.into()) - .map_err(|_| MirEvalError::NotSupported("unreachable".to_string()))?, + .map_err(|_| MirEvalError::NotSupported("unreachable".to_owned()))?, drop_flags: DropFlags::default(), }; let data = evaluator.allocate_const_in_heap(locals, c)?; diff --git a/crates/hir-ty/src/mir/eval/shim.rs b/crates/hir-ty/src/mir/eval/shim.rs index b4fb99acae..d68803fe28 100644 --- a/crates/hir-ty/src/mir/eval/shim.rs +++ b/crates/hir-ty/src/mir/eval/shim.rs @@ -178,7 +178,7 @@ impl Evaluator<'_> { not_supported!("wrong arg count for clone"); }; let addr = Address::from_bytes(arg.get(self)?)?; - let (closure_owner, _) = self.db.lookup_intern_closure((*id).into()); + let InternedClosure(closure_owner, _) = self.db.lookup_intern_closure((*id).into()); let infer = self.db.infer(closure_owner); let (captures, _) = infer.closure_info(id); let layout = self.layout(&self_ty)?; @@ -304,7 +304,7 @@ impl Evaluator<'_> { use LangItem::*; let mut args = args.iter(); match it { - BeginPanic => Err(MirEvalError::Panic("".to_string())), + BeginPanic => Err(MirEvalError::Panic("".to_owned())), PanicFmt => { let message = (|| { let resolver = self diff --git a/crates/hir-ty/src/mir/lower.rs b/crates/hir-ty/src/mir/lower.rs index 28d26c6c8a..1572a6d497 100644 --- a/crates/hir-ty/src/mir/lower.rs +++ b/crates/hir-ty/src/mir/lower.rs @@ -25,7 +25,7 @@ use triomphe::Arc; use crate::{ consteval::ConstEvalError, - db::HirDatabase, + db::{HirDatabase, InternedClosure}, display::HirDisplay, infer::{CaptureKind, CapturedItem, TypeMismatch}, inhabitedness::is_ty_uninhabited_from, @@ -126,6 +126,10 @@ impl DropScopeToken { } } +impl Drop for DropScopeToken { + fn drop(&mut self) {} +} + // Uncomment this to make `DropScopeToken` a drop bomb. Unfortunately we can't do this in release, since // in cases that mir lowering fails, we don't handle (and don't need to handle) drop scopes so it will be // actually reached. `pop_drop_scope_assert_finished` will also detect this case, but doesn't show useful @@ -1630,7 +1634,7 @@ impl<'ctx> MirLowerCtx<'ctx> { self.set_goto(prev_block, begin, span); f(self, begin)?; let my = mem::replace(&mut self.current_loop_blocks, prev).ok_or( - MirLowerError::ImplementationError("current_loop_blocks is corrupt".to_string()), + MirLowerError::ImplementationError("current_loop_blocks is corrupt".to_owned()), )?; if let Some(prev) = prev_label { self.labeled_loop_blocks.insert(label.unwrap(), prev); @@ -1665,7 +1669,7 @@ impl<'ctx> MirLowerCtx<'ctx> { .current_loop_blocks .as_mut() .ok_or(MirLowerError::ImplementationError( - "Current loop access out of loop".to_string(), + "Current loop access out of loop".to_owned(), ))? .end { @@ -1675,7 +1679,7 @@ impl<'ctx> MirLowerCtx<'ctx> { self.current_loop_blocks .as_mut() .ok_or(MirLowerError::ImplementationError( - "Current loop access out of loop".to_string(), + "Current loop access out of loop".to_owned(), ))? .end = Some(s); s @@ -1973,7 +1977,7 @@ pub fn mir_body_for_closure_query( db: &dyn HirDatabase, closure: ClosureId, ) -> Result> { - let (owner, expr) = db.lookup_intern_closure(closure.into()); + let InternedClosure(owner, expr) = db.lookup_intern_closure(closure.into()); let body = db.body(owner); let infer = db.infer(owner); let Expr::Closure { args, body: root, .. } = &body[expr] else { diff --git a/crates/hir-ty/src/mir/lower/as_place.rs b/crates/hir-ty/src/mir/lower/as_place.rs index 8d15794402..afe33607d4 100644 --- a/crates/hir-ty/src/mir/lower/as_place.rs +++ b/crates/hir-ty/src/mir/lower/as_place.rs @@ -225,7 +225,7 @@ impl MirLowerCtx<'_> { { let Some(index_fn) = self.infer.method_resolution(expr_id) else { return Err(MirLowerError::UnresolvedMethod( - "[overloaded index]".to_string(), + "[overloaded index]".to_owned(), )); }; let Some((base_place, current)) = diff --git a/crates/hir-ty/src/mir/monomorphization.rs b/crates/hir-ty/src/mir/monomorphization.rs index 46dec257e8..d2e413f0a3 100644 --- a/crates/hir-ty/src/mir/monomorphization.rs +++ b/crates/hir-ty/src/mir/monomorphization.rs @@ -19,7 +19,7 @@ use triomphe::Arc; use crate::{ consteval::{intern_const_scalar, unknown_const}, - db::HirDatabase, + db::{HirDatabase, InternedClosure}, from_placeholder_idx, infer::normalize, utils::{generics, Generics}, @@ -315,7 +315,7 @@ pub fn monomorphized_mir_body_for_closure_query( subst: Substitution, trait_env: Arc, ) -> Result, MirLowerError> { - let (owner, _) = db.lookup_intern_closure(closure.into()); + let InternedClosure(owner, _) = db.lookup_intern_closure(closure.into()); let generics = owner.as_generic_def_id().map(|g_def| generics(db.upcast(), g_def)); let filler = &mut Filler { db, subst: &subst, trait_env, generics, owner }; let body = db.mir_body_for_closure(closure)?; diff --git a/crates/hir-ty/src/tests.rs b/crates/hir-ty/src/tests.rs index 9804910c87..5e159236f4 100644 --- a/crates/hir-ty/src/tests.rs +++ b/crates/hir-ty/src/tests.rs @@ -10,7 +10,7 @@ mod regression; mod simple; mod traits; -use std::{collections::HashMap, env}; +use std::env; use base_db::{FileRange, SourceDatabaseExt}; use expect_test::Expect; @@ -25,6 +25,7 @@ use hir_def::{ }; use hir_expand::{db::ExpandDatabase, InFile}; use once_cell::race::OnceBool; +use rustc_hash::FxHashMap; use stdx::format_to; use syntax::{ ast::{self, AstNode, HasName}, @@ -90,16 +91,16 @@ fn check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_sour let (db, files) = TestDB::with_many_files(ra_fixture); let mut had_annotations = false; - let mut mismatches = HashMap::new(); - let mut types = HashMap::new(); - let mut adjustments = HashMap::<_, Vec<_>>::new(); + let mut mismatches = FxHashMap::default(); + let mut types = FxHashMap::default(); + let mut adjustments = FxHashMap::<_, Vec<_>>::default(); for (file_id, annotations) in db.extract_annotations() { for (range, expected) in annotations { let file_range = FileRange { file_id, range }; if only_types { types.insert(file_range, expected); } else if expected.starts_with("type: ") { - types.insert(file_range, expected.trim_start_matches("type: ").to_string()); + types.insert(file_range, expected.trim_start_matches("type: ").to_owned()); } else if expected.starts_with("expected") { mismatches.insert(file_range, expected); } else if expected.starts_with("adjustments:") { @@ -109,7 +110,7 @@ fn check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_sour .trim_start_matches("adjustments:") .trim() .split(',') - .map(|it| it.trim().to_string()) + .map(|it| it.trim().to_owned()) .filter(|it| !it.is_empty()) .collect(), ); @@ -330,7 +331,7 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String { }); for (node, ty) in &types { let (range, text) = if let Some(self_param) = ast::SelfParam::cast(node.value.clone()) { - (self_param.name().unwrap().syntax().text_range(), "self".to_string()) + (self_param.name().unwrap().syntax().text_range(), "self".to_owned()) } else { (node.value.text_range(), node.value.text().to_string().replace('\n', " ")) }; diff --git a/crates/hir-ty/src/tests/macros.rs b/crates/hir-ty/src/tests/macros.rs index b0a9361f1c..2f75338f99 100644 --- a/crates/hir-ty/src/tests/macros.rs +++ b/crates/hir-ty/src/tests/macros.rs @@ -1373,3 +1373,34 @@ pub fn attr_macro() {} "#, ); } + +#[test] +fn clone_with_type_bound() { + check_types( + r#" +//- minicore: derive, clone, builtin_impls +#[derive(Clone)] +struct Float; + +trait TensorKind: Clone { + /// The primitive type of the tensor. + type Primitive: Clone; +} + +impl TensorKind for Float { + type Primitive = f64; +} + +#[derive(Clone)] +struct Tensor where K: TensorKind +{ + primitive: K::Primitive, +} + +fn foo(t: Tensor) { + let x = t.clone(); + //^ Tensor +} +"#, + ); +} diff --git a/crates/hir-ty/src/traits.rs b/crates/hir-ty/src/traits.rs index 3a1a4e63ea..b2232b920a 100644 --- a/crates/hir-ty/src/traits.rs +++ b/crates/hir-ty/src/traits.rs @@ -104,8 +104,8 @@ pub(crate) fn trait_solve_query( GoalData::DomainGoal(DomainGoal::Holds(WhereClause::Implemented(it))) => { db.trait_data(it.hir_trait_id()).name.display(db.upcast()).to_string() } - GoalData::DomainGoal(DomainGoal::Holds(WhereClause::AliasEq(_))) => "alias_eq".to_string(), - _ => "??".to_string(), + GoalData::DomainGoal(DomainGoal::Holds(WhereClause::AliasEq(_))) => "alias_eq".to_owned(), + _ => "??".to_owned(), }; let _p = tracing::span!(tracing::Level::INFO, "trait_solve_query", ?detail).entered(); tracing::info!("trait_solve_query({:?})", goal.value.goal); @@ -187,7 +187,7 @@ struct LoggingRustIrDatabaseLoggingOnDrop<'a>(LoggingRustIrDatabase { fn drop(&mut self) { - eprintln!("chalk program:\n{}", self.0); + tracing::info!("chalk program:\n{}", self.0); } } diff --git a/crates/hir/src/attrs.rs b/crates/hir/src/attrs.rs index 5c369f42e6..7d637bac09 100644 --- a/crates/hir/src/attrs.rs +++ b/crates/hir/src/attrs.rs @@ -30,7 +30,7 @@ macro_rules! impl_has_attrs { impl HasAttrs for $def { fn attrs(self, db: &dyn HirDatabase) -> AttrsWithOwner { let def = AttrDefId::$def_id(self.into()); - AttrsWithOwner::attrs_with_owner(db.upcast(), def) + AttrsWithOwner::new(db.upcast(), def) } fn attr_id(self) -> AttrDefId { AttrDefId::$def_id(self.into()) diff --git a/crates/hir/src/diagnostics.rs b/crates/hir/src/diagnostics.rs index 2d8f1dbad5..08843a6c99 100644 --- a/crates/hir/src/diagnostics.rs +++ b/crates/hir/src/diagnostics.rs @@ -67,6 +67,8 @@ diagnostics![ NoSuchField, PrivateAssocItem, PrivateField, + RemoveTrailingReturn, + RemoveUnnecessaryElse, ReplaceFilterMapNextWithFindMap, TraitImplIncorrectSafety, TraitImplMissingAssocItems, @@ -342,6 +344,16 @@ pub struct TraitImplRedundantAssocItems { pub assoc_item: (Name, AssocItem), } +#[derive(Debug)] +pub struct RemoveTrailingReturn { + pub return_expr: InFile>, +} + +#[derive(Debug)] +pub struct RemoveUnnecessaryElse { + pub if_expr: InFile>, +} + impl AnyDiagnostic { pub(crate) fn body_validation_diagnostic( db: &dyn HirDatabase, @@ -444,6 +456,29 @@ impl AnyDiagnostic { Err(SyntheticSyntax) => (), } } + BodyValidationDiagnostic::RemoveTrailingReturn { return_expr } => { + if let Ok(source_ptr) = source_map.expr_syntax(return_expr) { + // Filters out desugared return expressions (e.g. desugared try operators). + if let Some(ptr) = source_ptr.value.cast::() { + return Some( + RemoveTrailingReturn { + return_expr: InFile::new(source_ptr.file_id, ptr), + } + .into(), + ); + } + } + } + BodyValidationDiagnostic::RemoveUnnecessaryElse { if_expr } => { + if let Ok(source_ptr) = source_map.expr_syntax(if_expr) { + if let Some(ptr) = source_ptr.value.cast::() { + return Some( + RemoveUnnecessaryElse { if_expr: InFile::new(source_ptr.file_id, ptr) } + .into(), + ); + } + } + } } None } @@ -546,9 +581,7 @@ impl AnyDiagnostic { source_map.pat_syntax(pat).expect("unexpected synthetic"); // cast from Either -> Either<_, Pat> - let Some(ptr) = AstPtr::try_from_raw(value.syntax_node_ptr()) else { - return None; - }; + let ptr = AstPtr::try_from_raw(value.syntax_node_ptr())?; InFile { file_id, value: ptr } } }; diff --git a/crates/hir/src/display.rs b/crates/hir/src/display.rs index 9b99b141fc..30f402a79f 100644 --- a/crates/hir/src/display.rs +++ b/crates/hir/src/display.rs @@ -158,7 +158,8 @@ impl HirDisplay for Adt { impl HirDisplay for Struct { fn hir_fmt(&self, f: &mut HirFormatter<'_>) -> Result<(), HirDisplayError> { - write_visibility(self.module(f.db).id, self.visibility(f.db), f)?; + let module_id = self.module(f.db).id; + write_visibility(module_id, self.visibility(f.db), f)?; f.write_str("struct ")?; write!(f, "{}", self.name(f.db).display(f.db.upcast()))?; let def_id = GenericDefId::AdtId(AdtId::StructId(self.id)); @@ -171,6 +172,7 @@ impl HirDisplay for Struct { while let Some((id, _)) = it.next() { let field = Field { parent: (*self).into(), id }; + write_visibility(module_id, field.visibility(f.db), f)?; field.ty(f.db).hir_fmt(f)?; if it.peek().is_some() { f.write_str(", ")?; diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index 1e21045e98..32abbc80c6 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -44,7 +44,7 @@ use hir_def::{ data::adt::VariantData, generics::{LifetimeParamData, TypeOrConstParamData, TypeParamProvenance}, hir::{BindingAnnotation, BindingId, ExprOrPatId, LabelId, Pat}, - item_tree::ItemTreeModItemNode, + item_tree::ItemTreeNode, lang_item::LangItemTarget, layout::{self, ReprOptions, TargetDataLayout}, nameres::{self, diagnostics::DefDiagnostic}, @@ -62,6 +62,7 @@ use hir_expand::{attrs::collect_attrs, name::name, proc_macro::ProcMacroKind, Ma use hir_ty::{ all_super_traits, autoderef, check_orphan_rules, consteval::{try_const_usize, unknown_const_as_generic, ConstExt}, + db::InternedClosure, diagnostics::BodyValidationDiagnostic, known_const_to_ast, layout::{Layout as TyLayout, RustcEnumVariantIdx, RustcFieldIdx, TagEncoding}, @@ -563,6 +564,11 @@ impl Module { for diag in db.trait_data_with_diagnostics(t.id).1.iter() { emit_def_diagnostic(db, acc, diag); } + + for item in t.items(db) { + item.diagnostics(db, acc); + } + acc.extend(def.diagnostics(db)) } ModuleDef::Adt(adt) => { @@ -730,13 +736,7 @@ impl Module { } for &item in &db.impl_data(impl_def.id).items { - let def: DefWithBody = match AssocItem::from(item) { - AssocItem::Function(it) => it.into(), - AssocItem::Const(it) => it.into(), - AssocItem::TypeAlias(_) => continue, - }; - - def.diagnostics(db, acc); + AssocItem::from(item).diagnostics(db, acc); } } } @@ -1769,7 +1769,7 @@ pub struct Function { impl Function { pub fn module(self, db: &dyn HirDatabase) -> Module { - self.id.lookup(db.upcast()).module(db.upcast()).into() + self.id.module(db.upcast()).into() } pub fn name(self, db: &dyn HirDatabase) -> Name { @@ -1910,8 +1910,7 @@ impl Function { { return None; } - let loc = self.id.lookup(db.upcast()); - let def_map = db.crate_def_map(loc.krate(db).into()); + let def_map = db.crate_def_map(HasModule::krate(&self.id, db.upcast())); def_map.fn_as_proc_macro(self.id).map(|id| Macro { id: id.into() }) } @@ -1934,7 +1933,7 @@ impl Function { }; let (result, output) = interpret_mir(db, body, false, None); let mut text = match result { - Ok(_) => "pass".to_string(), + Ok(_) => "pass".to_owned(), Err(e) => { let mut r = String::new(); _ = e.pretty_print(&mut r, db, &span_formatter); @@ -2120,7 +2119,7 @@ pub struct Const { impl Const { pub fn module(self, db: &dyn HirDatabase) -> Module { - Module { id: self.id.lookup(db.upcast()).module(db.upcast()) } + Module { id: self.id.module(db.upcast()) } } pub fn name(self, db: &dyn HirDatabase) -> Option { @@ -2175,7 +2174,7 @@ pub struct Static { impl Static { pub fn module(self, db: &dyn HirDatabase) -> Module { - Module { id: self.id.lookup(db.upcast()).module(db.upcast()) } + Module { id: self.id.module(db.upcast()) } } pub fn name(self, db: &dyn HirDatabase) -> Name { @@ -2294,7 +2293,7 @@ impl TypeAlias { } pub fn module(self, db: &dyn HirDatabase) -> Module { - Module { id: self.id.lookup(db.upcast()).module(db.upcast()) } + Module { id: self.id.module(db.upcast()) } } pub fn type_ref(self, db: &dyn HirDatabase) -> Option { @@ -2517,11 +2516,13 @@ pub enum AssocItem { Const(Const), TypeAlias(TypeAlias), } + #[derive(Debug, Clone)] pub enum AssocItemContainer { Trait(Trait), Impl(Impl), } + pub trait AsAssocItem { fn as_assoc_item(self, db: &dyn HirDatabase) -> Option; } @@ -2531,16 +2532,19 @@ impl AsAssocItem for Function { as_assoc_item(db, AssocItem::Function, self.id) } } + impl AsAssocItem for Const { fn as_assoc_item(self, db: &dyn HirDatabase) -> Option { as_assoc_item(db, AssocItem::Const, self.id) } } + impl AsAssocItem for TypeAlias { fn as_assoc_item(self, db: &dyn HirDatabase) -> Option { as_assoc_item(db, AssocItem::TypeAlias, self.id) } } + impl AsAssocItem for ModuleDef { fn as_assoc_item(self, db: &dyn HirDatabase) -> Option { match self { @@ -2551,6 +2555,7 @@ impl AsAssocItem for ModuleDef { } } } + impl AsAssocItem for DefWithBody { fn as_assoc_item(self, db: &dyn HirDatabase) -> Option { match self { @@ -2561,16 +2566,15 @@ impl AsAssocItem for DefWithBody { } } -fn as_assoc_item<'db, ID, DEF, CTOR, AST>( +fn as_assoc_item<'db, ID, DEF, LOC>( db: &(dyn HirDatabase + 'db), - ctor: CTOR, + ctor: impl FnOnce(DEF) -> AssocItem, id: ID, ) -> Option where - ID: Lookup = dyn DefDatabase + 'db, Data = AssocItemLoc>, + ID: Lookup = dyn DefDatabase + 'db, Data = AssocItemLoc>, DEF: From, - CTOR: FnOnce(DEF) -> AssocItem, - AST: ItemTreeModItemNode, + LOC: ItemTreeNode, { match id.lookup(db.upcast()).container { ItemContainerId::TraitId(_) | ItemContainerId::ImplId(_) => Some(ctor(DEF::from(id))), @@ -2610,27 +2614,34 @@ impl AssocItem { } } - pub fn containing_trait(self, db: &dyn HirDatabase) -> Option { + pub fn container_trait(self, db: &dyn HirDatabase) -> Option { match self.container(db) { AssocItemContainer::Trait(t) => Some(t), _ => None, } } - pub fn containing_trait_impl(self, db: &dyn HirDatabase) -> Option { + pub fn implemented_trait(self, db: &dyn HirDatabase) -> Option { match self.container(db) { AssocItemContainer::Impl(i) => i.trait_(db), _ => None, } } - pub fn containing_trait_or_trait_impl(self, db: &dyn HirDatabase) -> Option { + pub fn container_or_implemented_trait(self, db: &dyn HirDatabase) -> Option { match self.container(db) { AssocItemContainer::Trait(t) => Some(t), AssocItemContainer::Impl(i) => i.trait_(db), } } + pub fn implementing_ty(self, db: &dyn HirDatabase) -> Option { + match self.container(db) { + AssocItemContainer::Impl(i) => Some(i.self_ty(db)), + _ => None, + } + } + pub fn as_function(self) -> Option { match self { Self::Function(v) => Some(v), @@ -2651,6 +2662,22 @@ impl AssocItem { _ => None, } } + + pub fn diagnostics(self, db: &dyn HirDatabase, acc: &mut Vec) { + match self { + AssocItem::Function(func) => { + DefWithBody::from(func).diagnostics(db, acc); + } + AssocItem::Const(const_) => { + DefWithBody::from(const_).diagnostics(db, acc); + } + AssocItem::TypeAlias(type_alias) => { + for diag in hir_ty::diagnostics::incorrect_case(db, type_alias.id.into()) { + acc.push(diag.into()); + } + } + } + } } impl HasVisibility for AssocItem { @@ -3306,7 +3333,7 @@ impl Impl { } pub fn items(self, db: &dyn HirDatabase) -> Vec { - db.impl_data(self.id).items.iter().map(|it| (*it).into()).collect() + db.impl_data(self.id).items.iter().map(|&it| it.into()).collect() } pub fn is_negative(self, db: &dyn HirDatabase) -> bool { @@ -3662,7 +3689,7 @@ impl Type { .and_then(|it| { let into_future_fn = it.as_function()?; let assoc_item = as_assoc_item(db, AssocItem::Function, into_future_fn)?; - let into_future_trait = assoc_item.containing_trait_or_trait_impl(db)?; + let into_future_trait = assoc_item.container_or_implemented_trait(db)?; Some(into_future_trait.id) }) .or_else(|| { @@ -4101,6 +4128,14 @@ impl Type { name: Option<&Name>, callback: &mut dyn FnMut(AssocItemId) -> ControlFlow<()>, ) { + let _p = tracing::span!( + tracing::Level::INFO, + "iterate_method_candidates_dyn", + with_local_impls = traits_in_scope.len(), + traits_in_scope = traits_in_scope.len(), + ?name, + ) + .entered(); // There should be no inference vars in types passed here let canonical = hir_ty::replace_errors_with_variables(&self.ty); @@ -4122,6 +4157,7 @@ impl Type { ); } + #[tracing::instrument(skip_all, fields(name = ?name))] pub fn iterate_path_candidates( &self, db: &dyn HirDatabase, @@ -4150,6 +4186,7 @@ impl Type { slot } + #[tracing::instrument(skip_all, fields(name = ?name))] fn iterate_path_candidates_dyn( &self, db: &dyn HirDatabase, @@ -4463,7 +4500,7 @@ impl Callable { } fn closure_source(db: &dyn HirDatabase, closure: ClosureId) -> Option { - let (owner, expr_id) = db.lookup_intern_closure(closure.into()); + let InternedClosure(owner, expr_id) = db.lookup_intern_closure(closure.into()); let (_, source_map) = db.body_with_source_map(owner); let ast = source_map.expr_syntax(expr_id).ok()?; let root = ast.file_syntax(db.upcast()); diff --git a/crates/ide-assists/src/handlers/apply_demorgan.rs b/crates/ide-assists/src/handlers/apply_demorgan.rs index 2d41243c20..55e0d7f3b2 100644 --- a/crates/ide-assists/src/handlers/apply_demorgan.rs +++ b/crates/ide-assists/src/handlers/apply_demorgan.rs @@ -96,7 +96,7 @@ pub(crate) fn apply_demorgan(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opti let dm_lhs = demorganed.lhs()?; acc.add_group( - &GroupLabel("Apply De Morgan's law".to_string()), + &GroupLabel("Apply De Morgan's law".to_owned()), AssistId("apply_demorgan", AssistKind::RefactorRewrite), "Apply De Morgan's law", op_range, @@ -187,7 +187,7 @@ pub(crate) fn apply_demorgan_iterator(acc: &mut Assists, ctx: &AssistContext<'_> let op_range = method_call.syntax().text_range(); let label = format!("Apply De Morgan's law to `Iterator::{}`", name.text().as_str()); acc.add_group( - &GroupLabel("Apply De Morgan's law".to_string()), + &GroupLabel("Apply De Morgan's law".to_owned()), AssistId("apply_demorgan_iterator", AssistKind::RefactorRewrite), label, op_range, diff --git a/crates/ide-assists/src/handlers/convert_bool_then.rs b/crates/ide-assists/src/handlers/convert_bool_then.rs index d231708c55..61b7b41217 100644 --- a/crates/ide-assists/src/handlers/convert_bool_then.rs +++ b/crates/ide-assists/src/handlers/convert_bool_then.rs @@ -163,9 +163,8 @@ pub(crate) fn convert_bool_then_to_if(acc: &mut Assists, ctx: &AssistContext<'_> return None; } let assoc = func.as_assoc_item(ctx.sema.db)?; - match assoc.container(ctx.sema.db) { - hir::AssocItemContainer::Impl(impl_) if impl_.self_ty(ctx.sema.db).is_bool() => {} - _ => return None, + if !assoc.implementing_ty(ctx.sema.db)?.is_bool() { + return None; } let target = mcall.syntax().text_range(); diff --git a/crates/ide-assists/src/handlers/convert_comment_block.rs b/crates/ide-assists/src/handlers/convert_comment_block.rs index 3f478ee7d3..fbc0b9f673 100644 --- a/crates/ide-assists/src/handlers/convert_comment_block.rs +++ b/crates/ide-assists/src/handlers/convert_comment_block.rs @@ -57,7 +57,7 @@ fn block_to_line(acc: &mut Assists, comment: ast::Comment) -> Option<()> { // Don't introduce trailing whitespace if line.is_empty() { - line_prefix.to_string() + line_prefix.to_owned() } else { format!("{line_prefix} {line}") } diff --git a/crates/ide-assists/src/handlers/convert_to_guarded_return.rs b/crates/ide-assists/src/handlers/convert_to_guarded_return.rs index 6f30ffa622..e1966d476c 100644 --- a/crates/ide-assists/src/handlers/convert_to_guarded_return.rs +++ b/crates/ide-assists/src/handlers/convert_to_guarded_return.rs @@ -1,6 +1,9 @@ use std::iter::once; -use ide_db::syntax_helpers::node_ext::{is_pattern_cond, single_let}; +use ide_db::{ + syntax_helpers::node_ext::{is_pattern_cond, single_let}, + ty_filter::TryEnum, +}; use syntax::{ ast::{ self, @@ -41,13 +44,35 @@ use crate::{ // } // ``` pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { - let if_expr: ast::IfExpr = ctx.find_node_at_offset()?; + if let Some(let_stmt) = ctx.find_node_at_offset() { + let_stmt_to_guarded_return(let_stmt, acc, ctx) + } else if let Some(if_expr) = ctx.find_node_at_offset() { + if_expr_to_guarded_return(if_expr, acc, ctx) + } else { + None + } +} + +fn if_expr_to_guarded_return( + if_expr: ast::IfExpr, + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { if if_expr.else_branch().is_some() { return None; } let cond = if_expr.condition()?; + let if_token_range = if_expr.if_token()?.text_range(); + let if_cond_range = cond.syntax().text_range(); + + let cursor_in_range = + if_token_range.cover(if_cond_range).contains_range(ctx.selection_trimmed()); + if !cursor_in_range { + return None; + } + // Check if there is an IfLet that we can handle. let (if_let_pat, cond_expr) = if is_pattern_cond(cond.clone()) { let let_ = single_let(cond)?; @@ -148,6 +173,65 @@ pub(crate) fn convert_to_guarded_return(acc: &mut Assists, ctx: &AssistContext<' ) } +fn let_stmt_to_guarded_return( + let_stmt: ast::LetStmt, + acc: &mut Assists, + ctx: &AssistContext<'_>, +) -> Option<()> { + let pat = let_stmt.pat()?; + let expr = let_stmt.initializer()?; + + let let_token_range = let_stmt.let_token()?.text_range(); + let let_pattern_range = pat.syntax().text_range(); + let cursor_in_range = + let_token_range.cover(let_pattern_range).contains_range(ctx.selection_trimmed()); + + if !cursor_in_range { + return None; + } + + let try_enum = + ctx.sema.type_of_expr(&expr).and_then(|ty| TryEnum::from_ty(&ctx.sema, &ty.adjusted()))?; + + let happy_pattern = try_enum.happy_pattern(pat); + let target = let_stmt.syntax().text_range(); + + let early_expression: ast::Expr = { + let parent_block = + let_stmt.syntax().parent()?.ancestors().find_map(ast::BlockExpr::cast)?; + let parent_container = parent_block.syntax().parent()?; + + match parent_container.kind() { + WHILE_EXPR | LOOP_EXPR | FOR_EXPR => make::expr_continue(None), + FN => make::expr_return(None), + _ => return None, + } + }; + + acc.add( + AssistId("convert_to_guarded_return", AssistKind::RefactorRewrite), + "Convert to guarded return", + target, + |edit| { + let let_stmt = edit.make_mut(let_stmt); + let let_indent_level = IndentLevel::from_node(let_stmt.syntax()); + + let replacement = { + let let_else_stmt = make::let_else_stmt( + happy_pattern, + let_stmt.ty(), + expr, + ast::make::tail_only_block_expr(early_expression), + ); + let let_else_stmt = let_else_stmt.indent(let_indent_level); + let_else_stmt.syntax().clone_for_update() + }; + + ted::replace(let_stmt.syntax(), replacement) + }, + ) +} + #[cfg(test)] mod tests { use crate::tests::{check_assist, check_assist_not_applicable}; @@ -450,6 +534,62 @@ fn main() { ); } + #[test] + fn convert_let_stmt_inside_fn() { + check_assist( + convert_to_guarded_return, + r#" +//- minicore: option +fn foo() -> Option { + None +} + +fn main() { + let x$0 = foo(); +} +"#, + r#" +fn foo() -> Option { + None +} + +fn main() { + let Some(x) = foo() else { return }; +} +"#, + ); + } + + #[test] + fn convert_let_stmt_inside_loop() { + check_assist( + convert_to_guarded_return, + r#" +//- minicore: option +fn foo() -> Option { + None +} + +fn main() { + loop { + let x$0 = foo(); + } +} +"#, + r#" +fn foo() -> Option { + None +} + +fn main() { + loop { + let Some(x) = foo() else { continue }; + } +} +"#, + ); + } + #[test] fn convert_arbitrary_if_let_patterns() { check_assist( @@ -591,6 +731,37 @@ fn main() { } } } +"#, + ); + } + + #[test] + fn ignore_inside_if_stmt() { + check_assist_not_applicable( + convert_to_guarded_return, + r#" +fn main() { + if false { + foo()$0; + } +} +"#, + ); + } + + #[test] + fn ignore_inside_let_initializer() { + check_assist_not_applicable( + convert_to_guarded_return, + r#" +//- minicore: option +fn foo() -> Option { + None +} + +fn main() { + let x = foo()$0; +} "#, ); } diff --git a/crates/ide-assists/src/handlers/desugar_doc_comment.rs b/crates/ide-assists/src/handlers/desugar_doc_comment.rs index c859e98524..d264928046 100644 --- a/crates/ide-assists/src/handlers/desugar_doc_comment.rs +++ b/crates/ide-assists/src/handlers/desugar_doc_comment.rs @@ -27,9 +27,7 @@ use crate::{ pub(crate) fn desugar_doc_comment(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let comment = ctx.find_token_at_offset::()?; // Only allow doc comments - let Some(placement) = comment.kind().doc else { - return None; - }; + let placement = comment.kind().doc?; // Only allow comments which are alone on their line if let Some(prev) = comment.syntax().prev_token() { diff --git a/crates/ide-assists/src/handlers/extract_function.rs b/crates/ide-assists/src/handlers/extract_function.rs index 1eb28626f7..d111005c2e 100644 --- a/crates/ide-assists/src/handlers/extract_function.rs +++ b/crates/ide-assists/src/handlers/extract_function.rs @@ -1,4 +1,4 @@ -use std::iter; +use std::{iter, ops::RangeInclusive}; use ast::make; use either::Either; @@ -12,27 +12,25 @@ use ide_db::{ helpers::mod_path_to_ast, imports::insert_use::{insert_use, ImportScope}, search::{FileReference, ReferenceCategory, SearchScope}, + source_change::SourceChangeBuilder, syntax_helpers::node_ext::{ for_each_tail_expr, preorder_expr, walk_expr, walk_pat, walk_patterns_in_expr, }, FxIndexSet, RootDatabase, }; -use itertools::Itertools; -use stdx::format_to; use syntax::{ ast::{ - self, - edit::{AstNodeEdit, IndentLevel}, - AstNode, HasGenericParams, + self, edit::IndentLevel, edit_in_place::Indent, AstNode, AstToken, HasGenericParams, + HasName, }, - match_ast, ted, AstToken, SyntaxElement, + match_ast, ted, SyntaxElement, SyntaxKind::{self, COMMENT}, SyntaxNode, SyntaxToken, TextRange, TextSize, TokenAtOffset, WalkEvent, T, }; use crate::{ assist_context::{AssistContext, Assists, TreeMutator}, - utils::generate_impl_text, + utils::generate_impl, AssistId, }; @@ -134,17 +132,65 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op let new_indent = IndentLevel::from_node(&insert_after); let old_indent = fun.body.indent_level(); - builder.replace(target_range, make_call(ctx, &fun, old_indent)); + let insert_after = builder.make_syntax_mut(insert_after); + + let call_expr = make_call(ctx, &fun, old_indent); + + // Map the element range to replace into the mutable version + let elements = match &fun.body { + FunctionBody::Expr(expr) => { + // expr itself becomes the replacement target + let expr = &builder.make_mut(expr.clone()); + let node = SyntaxElement::Node(expr.syntax().clone()); + + node.clone()..=node + } + FunctionBody::Span { parent, elements, .. } => { + // Map the element range into the mutable versions + let parent = builder.make_mut(parent.clone()); + + let start = parent + .syntax() + .children_with_tokens() + .nth(elements.start().index()) + .expect("should be able to find mutable start element"); + + let end = parent + .syntax() + .children_with_tokens() + .nth(elements.end().index()) + .expect("should be able to find mutable end element"); + + start..=end + } + }; let has_impl_wrapper = insert_after.ancestors().any(|a| a.kind() == SyntaxKind::IMPL && a != insert_after); + let fn_def = format_function(ctx, module, &fun, old_indent).clone_for_update(); + + if let Some(cap) = ctx.config.snippet_cap { + if let Some(name) = fn_def.name() { + builder.add_tabstop_before(cap, name); + } + } + let fn_def = match fun.self_param_adt(ctx) { Some(adt) if anchor == Anchor::Method && !has_impl_wrapper => { - let fn_def = format_function(ctx, module, &fun, old_indent, new_indent + 1); - generate_impl_text(&adt, &fn_def).replace("{\n\n", "{") + fn_def.indent(1.into()); + + let impl_ = generate_impl(&adt); + impl_.indent(new_indent); + impl_.get_or_create_assoc_item_list().add_item(fn_def.into()); + + impl_.syntax().clone() + } + _ => { + fn_def.indent(new_indent); + + fn_def.syntax().clone() } - _ => format_function(ctx, module, &fun, old_indent, new_indent), }; // There are external control flows @@ -177,12 +223,15 @@ pub(crate) fn extract_function(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op } } - let insert_offset = insert_after.text_range().end(); + // Replace the call site with the call to the new function + fixup_call_site(builder, &fun.body); + ted::replace_all(elements, vec![call_expr.into()]); - match ctx.config.snippet_cap { - Some(cap) => builder.insert_snippet(cap, insert_offset, fn_def), - None => builder.insert(insert_offset, fn_def), - }; + // Insert the newly extracted function (or impl) + ted::insert_all_raw( + ted::Position::after(insert_after), + vec![make::tokens::whitespace(&format!("\n\n{new_indent}")).into(), fn_def.into()], + ); }, ) } @@ -195,7 +244,7 @@ fn make_function_name(semantics_scope: &hir::SemanticsScope<'_>) -> ast::NameRef let default_name = "fun_name"; - let mut name = default_name.to_string(); + let mut name = default_name.to_owned(); let mut counter = 0; while names_in_scope.contains(&name) { counter += 1; @@ -225,10 +274,10 @@ fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option None, - ast::Stmt::ExprStmt(_) | ast::Stmt::LetStmt(_) => Some(FunctionBody::from_range( + ast::Stmt::ExprStmt(_) | ast::Stmt::LetStmt(_) => FunctionBody::from_range( node.parent().and_then(ast::StmtList::cast)?, node.text_range(), - )), + ), }; } @@ -241,7 +290,7 @@ fn extraction_target(node: &SyntaxNode, selection_range: TextRange) -> Option, text_range: TextRange }, } #[derive(Debug)] @@ -569,26 +618,38 @@ impl FunctionBody { } } - fn from_range(parent: ast::StmtList, selected: TextRange) -> FunctionBody { + fn from_range(parent: ast::StmtList, selected: TextRange) -> Option { let full_body = parent.syntax().children_with_tokens(); - let mut text_range = full_body + // Get all of the elements intersecting with the selection + let mut stmts_in_selection = full_body .filter(|it| ast::Stmt::can_cast(it.kind()) || it.kind() == COMMENT) - .map(|element| element.text_range()) - .filter(|&range| selected.intersect(range).filter(|it| !it.is_empty()).is_some()) - .reduce(|acc, stmt| acc.cover(stmt)); + .filter(|it| selected.intersect(it.text_range()).filter(|it| !it.is_empty()).is_some()); - if let Some(tail_range) = parent - .tail_expr() - .map(|it| it.syntax().text_range()) - .filter(|&it| selected.intersect(it).is_some()) + let first_element = stmts_in_selection.next(); + + // If the tail expr is part of the selection too, make that the last element + // Otherwise use the last stmt + let last_element = if let Some(tail_expr) = + parent.tail_expr().filter(|it| selected.intersect(it.syntax().text_range()).is_some()) { - text_range = Some(match text_range { - Some(text_range) => text_range.cover(tail_range), - None => tail_range, - }); - } - Self::Span { parent, text_range: text_range.unwrap_or(selected) } + Some(tail_expr.syntax().clone().into()) + } else { + stmts_in_selection.last() + }; + + let elements = match (first_element, last_element) { + (None, _) => { + cov_mark::hit!(extract_function_empty_selection_is_not_applicable); + return None; + } + (Some(first), None) => first.clone()..=first, + (Some(first), Some(last)) => first..=last, + }; + + let text_range = elements.start().text_range().cover(elements.end().text_range()); + + Some(Self::Span { parent, elements, text_range }) } fn indent_level(&self) -> IndentLevel { @@ -601,7 +662,7 @@ impl FunctionBody { fn tail_expr(&self) -> Option { match &self { FunctionBody::Expr(expr) => Some(expr.clone()), - FunctionBody::Span { parent, text_range } => { + FunctionBody::Span { parent, text_range, .. } => { let tail_expr = parent.tail_expr()?; text_range.contains_range(tail_expr.syntax().text_range()).then_some(tail_expr) } @@ -611,7 +672,7 @@ impl FunctionBody { fn walk_expr(&self, cb: &mut dyn FnMut(ast::Expr)) { match self { FunctionBody::Expr(expr) => walk_expr(expr, cb), - FunctionBody::Span { parent, text_range } => { + FunctionBody::Span { parent, text_range, .. } => { parent .statements() .filter(|stmt| text_range.contains_range(stmt.syntax().text_range())) @@ -634,7 +695,7 @@ impl FunctionBody { fn preorder_expr(&self, cb: &mut dyn FnMut(WalkEvent) -> bool) { match self { FunctionBody::Expr(expr) => preorder_expr(expr, cb), - FunctionBody::Span { parent, text_range } => { + FunctionBody::Span { parent, text_range, .. } => { parent .statements() .filter(|stmt| text_range.contains_range(stmt.syntax().text_range())) @@ -657,7 +718,7 @@ impl FunctionBody { fn walk_pat(&self, cb: &mut dyn FnMut(ast::Pat)) { match self { FunctionBody::Expr(expr) => walk_patterns_in_expr(expr, cb), - FunctionBody::Span { parent, text_range } => { + FunctionBody::Span { parent, text_range, .. } => { parent .statements() .filter(|stmt| text_range.contains_range(stmt.syntax().text_range())) @@ -1151,7 +1212,7 @@ impl HasTokenAtOffset for FunctionBody { fn token_at_offset(&self, offset: TextSize) -> TokenAtOffset { match self { FunctionBody::Expr(expr) => expr.syntax().token_at_offset(offset), - FunctionBody::Span { parent, text_range } => { + FunctionBody::Span { parent, text_range, .. } => { match parent.syntax().token_at_offset(offset) { TokenAtOffset::None => TokenAtOffset::None, TokenAtOffset::Single(t) => { @@ -1316,7 +1377,19 @@ fn impl_type_name(impl_node: &ast::Impl) -> Option { Some(impl_node.self_ty()?.to_string()) } -fn make_call(ctx: &AssistContext<'_>, fun: &Function, indent: IndentLevel) -> String { +/// Fixes up the call site before the target expressions are replaced with the call expression +fn fixup_call_site(builder: &mut SourceChangeBuilder, body: &FunctionBody) { + let parent_match_arm = body.parent().and_then(ast::MatchArm::cast); + + if let Some(parent_match_arm) = parent_match_arm { + if parent_match_arm.comma_token().is_none() { + let parent_match_arm = builder.make_mut(parent_match_arm); + ted::append_child_raw(parent_match_arm.syntax(), make::token(T![,])); + } + } +} + +fn make_call(ctx: &AssistContext<'_>, fun: &Function, indent: IndentLevel) -> SyntaxNode { let ret_ty = fun.return_type(ctx); let args = make::arg_list(fun.params.iter().map(|param| param.to_arg(ctx))); @@ -1334,44 +1407,45 @@ fn make_call(ctx: &AssistContext<'_>, fun: &Function, indent: IndentLevel) -> St if fun.control_flow.is_async { call_expr = make::expr_await(call_expr); } - let expr = handler.make_call_expr(call_expr).indent(indent); - let mut_modifier = |var: &OutlivedLocal| if var.mut_usage_outside_body { "mut " } else { "" }; + let expr = handler.make_call_expr(call_expr).clone_for_update(); + expr.indent(indent); - let mut buf = String::new(); - match fun.outliving_locals.as_slice() { - [] => {} + let outliving_bindings = match fun.outliving_locals.as_slice() { + [] => None, [var] => { - let modifier = mut_modifier(var); let name = var.local.name(ctx.db()); - format_to!(buf, "let {modifier}{} = ", name.display(ctx.db())) + let name = make::name(&name.display(ctx.db()).to_string()); + Some(ast::Pat::IdentPat(make::ident_pat(false, var.mut_usage_outside_body, name))) } vars => { - buf.push_str("let ("); - let bindings = vars.iter().format_with(", ", |local, f| { - let modifier = mut_modifier(local); - let name = local.local.name(ctx.db()); - f(&format_args!("{modifier}{}", name.display(ctx.db())))?; - Ok(()) + let binding_pats = vars.iter().map(|var| { + let name = var.local.name(ctx.db()); + let name = make::name(&name.display(ctx.db()).to_string()); + make::ident_pat(false, var.mut_usage_outside_body, name).into() }); - format_to!(buf, "{bindings}"); - buf.push_str(") = "); + Some(ast::Pat::TuplePat(make::tuple_pat(binding_pats))) } - } + }; - format_to!(buf, "{expr}"); let parent_match_arm = fun.body.parent().and_then(ast::MatchArm::cast); - let insert_comma = parent_match_arm.as_ref().is_some_and(|it| it.comma_token().is_none()); - if insert_comma { - buf.push(','); - } else if parent_match_arm.is_none() + if let Some(bindings) = outliving_bindings { + // with bindings that outlive it + make::let_stmt(bindings, None, Some(expr)).syntax().clone_for_update() + } else if parent_match_arm.as_ref().is_some() { + // as a tail expr for a match arm + expr.syntax().clone() + } else if parent_match_arm.as_ref().is_none() && fun.ret_ty.is_unit() && (!fun.outliving_locals.is_empty() || !expr.is_block_like()) { - buf.push(';'); + // as an expr stmt + make::expr_stmt(expr).syntax().clone_for_update() + } else { + // as a tail expr, or a block + expr.syntax().clone() } - buf } enum FlowHandler { @@ -1500,42 +1574,25 @@ fn format_function( module: hir::Module, fun: &Function, old_indent: IndentLevel, - new_indent: IndentLevel, -) -> String { - let mut fn_def = String::new(); - - let fun_name = &fun.name; +) -> ast::Fn { + let fun_name = make::name(&fun.name.text()); let params = fun.make_param_list(ctx, module); let ret_ty = fun.make_ret_ty(ctx, module); - let body = make_body(ctx, old_indent, new_indent, fun); - let const_kw = if fun.mods.is_const { "const " } else { "" }; - let async_kw = if fun.control_flow.is_async { "async " } else { "" }; - let unsafe_kw = if fun.control_flow.is_unsafe { "unsafe " } else { "" }; + let body = make_body(ctx, old_indent, fun); let (generic_params, where_clause) = make_generic_params_and_where_clause(ctx, fun); - format_to!(fn_def, "\n\n{new_indent}{const_kw}{async_kw}{unsafe_kw}"); - match ctx.config.snippet_cap { - Some(_) => format_to!(fn_def, "fn $0{fun_name}"), - None => format_to!(fn_def, "fn {fun_name}"), - } - - if let Some(generic_params) = generic_params { - format_to!(fn_def, "{generic_params}"); - } - - format_to!(fn_def, "{params}"); - - if let Some(ret_ty) = ret_ty { - format_to!(fn_def, " {ret_ty}"); - } - - if let Some(where_clause) = where_clause { - format_to!(fn_def, " {where_clause}"); - } - - format_to!(fn_def, " {body}"); - - fn_def + make::fn_( + None, + fun_name, + generic_params, + where_clause, + params, + body, + ret_ty, + fun.control_flow.is_async, + fun.mods.is_const, + fun.control_flow.is_unsafe, + ) } fn make_generic_params_and_where_clause( @@ -1716,12 +1773,7 @@ impl FunType { } } -fn make_body( - ctx: &AssistContext<'_>, - old_indent: IndentLevel, - new_indent: IndentLevel, - fun: &Function, -) -> ast::BlockExpr { +fn make_body(ctx: &AssistContext<'_>, old_indent: IndentLevel, fun: &Function) -> ast::BlockExpr { let ret_ty = fun.return_type(ctx); let handler = FlowHandler::from_ret_ty(fun, &ret_ty); @@ -1732,7 +1784,7 @@ fn make_body( match expr { ast::Expr::BlockExpr(block) => { // If the extracted expression is itself a block, there is no need to wrap it inside another block. - let block = block.dedent(old_indent); + block.dedent(old_indent); let elements = block.stmt_list().map_or_else( || Either::Left(iter::empty()), |stmt_list| { @@ -1752,13 +1804,13 @@ fn make_body( make::hacky_block_expr(elements, block.tail_expr()) } _ => { - let expr = expr.dedent(old_indent).indent(IndentLevel(1)); + expr.reindent_to(1.into()); make::block_expr(Vec::new(), Some(expr)) } } } - FunctionBody::Span { parent, text_range } => { + FunctionBody::Span { parent, text_range, .. } => { let mut elements: Vec<_> = parent .syntax() .children_with_tokens() @@ -1801,8 +1853,8 @@ fn make_body( .map(|node_or_token| match &node_or_token { syntax::NodeOrToken::Node(node) => match ast::Stmt::cast(node.clone()) { Some(stmt) => { - let indented = stmt.dedent(old_indent).indent(body_indent); - let ast_node = indented.syntax().clone_subtree(); + stmt.reindent_to(body_indent); + let ast_node = stmt.syntax().clone_subtree(); syntax::NodeOrToken::Node(ast_node) } _ => node_or_token, @@ -1810,13 +1862,15 @@ fn make_body( _ => node_or_token, }) .collect::>(); - let tail_expr = tail_expr.map(|expr| expr.dedent(old_indent).indent(body_indent)); + if let Some(tail_expr) = &mut tail_expr { + tail_expr.reindent_to(body_indent); + } make::hacky_block_expr(elements, tail_expr) } }; - let block = match &handler { + match &handler { FlowHandler::None => block, FlowHandler::Try { kind } => { let block = with_default_tail_expr(block, make::expr_unit()); @@ -1851,9 +1905,7 @@ fn make_body( let args = make::arg_list(iter::once(tail_expr)); make::expr_call(ok, args) }), - }; - - block.indent(new_indent) + } } fn map_tail_expr(block: ast::BlockExpr, f: impl FnOnce(ast::Expr) -> ast::Expr) -> ast::BlockExpr { @@ -1897,7 +1949,7 @@ fn with_tail_expr(block: ast::BlockExpr, tail_expr: ast::Expr) -> ast::BlockExpr } fn format_type(ty: &hir::Type, ctx: &AssistContext<'_>, module: hir::Module) -> String { - ty.display_source_code(ctx.db(), module.into(), true).ok().unwrap_or_else(|| "_".to_string()) + ty.display_source_code(ctx.db(), module.into(), true).ok().unwrap_or_else(|| "_".to_owned()) } fn make_ty(ty: &hir::Type, ctx: &AssistContext<'_>, module: hir::Module) -> ast::Type { @@ -2551,6 +2603,20 @@ fn $0fun_name(n: u32) -> u32 { check_assist_not_applicable(extract_function, r"fn main() { 1 + /* $0comment$0 */ 1; }"); } + #[test] + fn empty_selection_is_not_applicable() { + cov_mark::check!(extract_function_empty_selection_is_not_applicable); + check_assist_not_applicable( + extract_function, + r#" +fn main() { + $0 + + $0 +}"#, + ); + } + #[test] fn part_of_expr_stmt() { check_assist( diff --git a/crates/ide-assists/src/handlers/extract_module.rs b/crates/ide-assists/src/handlers/extract_module.rs index 30c3983dc4..af834c8a53 100644 --- a/crates/ide-assists/src/handlers/extract_module.rs +++ b/crates/ide-assists/src/handlers/extract_module.rs @@ -1,7 +1,4 @@ -use std::{ - collections::{HashMap, HashSet}, - iter, -}; +use std::iter; use hir::{HasSource, HirFileIdExt, ModuleSource}; use ide_db::{ @@ -9,6 +6,7 @@ use ide_db::{ base_db::FileId, defs::{Definition, NameClass, NameRefClass}, search::{FileReference, SearchScope}, + FxHashMap, FxHashSet, }; use itertools::Itertools; use smallvec::SmallVec; @@ -235,9 +233,9 @@ impl Module { fn get_usages_and_record_fields( &self, ctx: &AssistContext<'_>, - ) -> (HashMap>, Vec) { + ) -> (FxHashMap>, Vec) { let mut adt_fields = Vec::new(); - let mut refs: HashMap> = HashMap::new(); + let mut refs: FxHashMap> = FxHashMap::default(); //Here impl is not included as each item inside impl will be tied to the parent of //implementing block(a struct, enum, etc), if the parent is in selected module, it will @@ -320,7 +318,7 @@ impl Module { &self, ctx: &AssistContext<'_>, node_def: Definition, - refs_in_files: &mut HashMap>, + refs_in_files: &mut FxHashMap>, ) { for (file_id, references) in node_def.usages(&ctx.sema).all() { let source_file = ctx.sema.parse(file_id); @@ -400,7 +398,7 @@ impl Module { ctx: &AssistContext<'_>, ) -> Vec { let mut import_paths_to_be_removed: Vec = vec![]; - let mut node_set: HashSet = HashSet::new(); + let mut node_set: FxHashSet = FxHashSet::default(); for item in self.body_items.clone() { for x in item.syntax().descendants() { diff --git a/crates/ide-assists/src/handlers/extract_variable.rs b/crates/ide-assists/src/handlers/extract_variable.rs index 22d16cf6b3..36d3122651 100644 --- a/crates/ide-assists/src/handlers/extract_variable.rs +++ b/crates/ide-assists/src/handlers/extract_variable.rs @@ -115,7 +115,7 @@ pub(crate) fn extract_variable(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op let trailing_ws = if prev_ws.is_some_and(|it| it.text().starts_with('\n')) { format!("\n{indent_to}") } else { - " ".to_string() + " ".to_owned() }; ted::insert_all_raw( diff --git a/crates/ide-assists/src/handlers/flip_trait_bound.rs b/crates/ide-assists/src/handlers/flip_trait_bound.rs index 430cd5b080..70b5efcb64 100644 --- a/crates/ide-assists/src/handlers/flip_trait_bound.rs +++ b/crates/ide-assists/src/handlers/flip_trait_bound.rs @@ -58,6 +58,11 @@ mod tests { check_assist_not_applicable(flip_trait_bound, "struct S where T: $0A { }") } + #[test] + fn flip_trait_bound_works_for_dyn() { + check_assist(flip_trait_bound, "fn f<'a>(x: dyn Copy $0+ 'a)", "fn f<'a>(x: dyn 'a + Copy)") + } + #[test] fn flip_trait_bound_works_for_struct() { check_assist( diff --git a/crates/ide-assists/src/handlers/generate_delegate_methods.rs b/crates/ide-assists/src/handlers/generate_delegate_methods.rs index d59bd71d31..4f2df5633c 100644 --- a/crates/ide-assists/src/handlers/generate_delegate_methods.rs +++ b/crates/ide-assists/src/handlers/generate_delegate_methods.rs @@ -1,7 +1,5 @@ -use std::collections::HashSet; - use hir::{self, HasCrate, HasVisibility}; -use ide_db::path_transform::PathTransform; +use ide_db::{path_transform::PathTransform, FxHashSet}; use syntax::{ ast::{ self, edit_in_place::Indent, make, AstNode, HasGenericParams, HasName, HasVisibility as _, @@ -71,7 +69,7 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<' let sema_field_ty = ctx.sema.resolve_type(&field_ty)?; let mut methods = vec![]; - let mut seen_names = HashSet::new(); + let mut seen_names = FxHashSet::default(); for ty in sema_field_ty.autoderef(ctx.db()) { let krate = ty.krate(ctx.db()); @@ -163,13 +161,13 @@ pub(crate) fn generate_delegate_methods(acc: &mut Assists, ctx: &AssistContext<' Some(impl_def) => edit.make_mut(impl_def), None => { let name = &strukt_name.to_string(); - let params = strukt.generic_param_list(); - let ty_params = params; + let ty_params = strukt.generic_param_list(); + let ty_args = ty_params.as_ref().map(|it| it.to_generic_args()); let where_clause = strukt.where_clause(); let impl_def = make::impl_( ty_params, - None, + ty_args, make::ty_path(make::ext::ident_path(name)), where_clause, None, diff --git a/crates/ide-assists/src/handlers/generate_delegate_trait.rs b/crates/ide-assists/src/handlers/generate_delegate_trait.rs index 898bd01291..7a60287f92 100644 --- a/crates/ide-assists/src/handlers/generate_delegate_trait.rs +++ b/crates/ide-assists/src/handlers/generate_delegate_trait.rs @@ -502,9 +502,7 @@ fn generate_args_for_impl( trait_params: &Option, old_trait_args: &FxHashSet, ) -> Option { - let Some(old_impl_args) = old_impl_gpl.map(|gpl| gpl.to_generic_args().generic_args()) else { - return None; - }; + let old_impl_args = old_impl_gpl.map(|gpl| gpl.to_generic_args().generic_args())?; // Create pairs of the args of `self_ty` and corresponding `field_ty` to // form the substitution list let mut arg_substs = FxHashMap::default(); @@ -958,7 +956,8 @@ where impl AnotherTrait for S where T: AnotherTrait, -{}"#, +{ +}"#, ); } @@ -1448,7 +1447,8 @@ where impl AnotherTrait for S where T: AnotherTrait, -{}"#, +{ +}"#, ); } diff --git a/crates/ide-assists/src/handlers/generate_documentation_template.rs b/crates/ide-assists/src/handlers/generate_documentation_template.rs index f298ce8916..38b24fd19c 100644 --- a/crates/ide-assists/src/handlers/generate_documentation_template.rs +++ b/crates/ide-assists/src/handlers/generate_documentation_template.rs @@ -364,7 +364,7 @@ fn is_in_trait_impl(ast_func: &ast::Fn, ctx: &AssistContext<'_>) -> bool { ctx.sema .to_def(ast_func) .and_then(|hir_func| hir_func.as_assoc_item(ctx.db())) - .and_then(|assoc_item| assoc_item.containing_trait_impl(ctx.db())) + .and_then(|assoc_item| assoc_item.implemented_trait(ctx.db())) .is_some() } @@ -373,7 +373,7 @@ fn is_in_trait_def(ast_func: &ast::Fn, ctx: &AssistContext<'_>) -> bool { ctx.sema .to_def(ast_func) .and_then(|hir_func| hir_func.as_assoc_item(ctx.db())) - .and_then(|assoc_item| assoc_item.containing_trait(ctx.db())) + .and_then(|assoc_item| assoc_item.container_trait(ctx.db())) .is_some() } @@ -416,9 +416,9 @@ fn arguments_from_params(param_list: &ast::ParamList) -> String { true => format!("&mut {name}"), false => name.to_string(), }, - None => "_".to_string(), + None => "_".to_owned(), }, - _ => "_".to_string(), + _ => "_".to_owned(), }); args_iter.format(", ").to_string() } diff --git a/crates/ide-assists/src/handlers/generate_enum_variant.rs b/crates/ide-assists/src/handlers/generate_enum_variant.rs index 681f8c1fcf..7faf2d5b13 100644 --- a/crates/ide-assists/src/handlers/generate_enum_variant.rs +++ b/crates/ide-assists/src/handlers/generate_enum_variant.rs @@ -162,7 +162,7 @@ fn make_record_field_list( fn name_from_field(field: &ast::RecordExprField) -> ast::Name { let text = match field.name_ref() { Some(it) => it.to_string(), - None => name_from_field_shorthand(field).unwrap_or("unknown".to_string()), + None => name_from_field_shorthand(field).unwrap_or("unknown".to_owned()), }; make::name(&text) } diff --git a/crates/ide-assists/src/handlers/generate_function.rs b/crates/ide-assists/src/handlers/generate_function.rs index 50528e1caa..fe2f8ed641 100644 --- a/crates/ide-assists/src/handlers/generate_function.rs +++ b/crates/ide-assists/src/handlers/generate_function.rs @@ -202,7 +202,7 @@ fn get_adt_source( let file = ctx.sema.parse(range.file_id); let adt_source = ctx.sema.find_node_at_offset_with_macros(file.syntax(), range.range.start())?; - find_struct_impl(ctx, &adt_source, &[fn_name.to_string()]).map(|impl_| (impl_, range.file_id)) + find_struct_impl(ctx, &adt_source, &[fn_name.to_owned()]).map(|impl_| (impl_, range.file_id)) } struct FunctionTemplate { @@ -908,7 +908,7 @@ fn filter_unnecessary_bounds( } } - let starting_nodes = necessary_params.iter().map(|param| param_map[param]); + let starting_nodes = necessary_params.iter().flat_map(|param| param_map.get(param).copied()); let reachable = graph.compute_reachable_nodes(starting_nodes); // Not pretty, but effective. If only there were `Vec::retain_index()`... @@ -1007,7 +1007,7 @@ fn fn_arg_name(sema: &Semantics<'_, RootDatabase>, arg_expr: &ast::Expr) -> Stri name } Some(name) => name, - None => "arg".to_string(), + None => "arg".to_owned(), } } @@ -1033,7 +1033,7 @@ fn fn_arg_type( if ty.is_reference() || ty.is_mutable_reference() { let famous_defs = &FamousDefs(&ctx.sema, ctx.sema.scope(fn_arg.syntax())?.krate()); convert_reference_type(ty.strip_references(), ctx.db(), famous_defs) - .map(|conversion| conversion.convert_type(ctx.db())) + .map(|conversion| conversion.convert_type(ctx.db()).to_string()) .or_else(|| ty.display_source_code(ctx.db(), target_module.into(), true).ok()) } else { ty.display_source_code(ctx.db(), target_module.into(), true).ok() diff --git a/crates/ide-assists/src/handlers/generate_getter_or_setter.rs b/crates/ide-assists/src/handlers/generate_getter_or_setter.rs index 79307fcec5..e90a032f1c 100644 --- a/crates/ide-assists/src/handlers/generate_getter_or_setter.rs +++ b/crates/ide-assists/src/handlers/generate_getter_or_setter.rs @@ -1,12 +1,12 @@ use ide_db::{famous_defs::FamousDefs, source_change::SourceChangeBuilder}; use stdx::{format_to, to_lower_snake_case}; use syntax::{ - ast::{self, AstNode, HasName, HasVisibility}, - TextRange, + ast::{self, edit_in_place::Indent, make, AstNode, HasName, HasVisibility}, + ted, TextRange, }; use crate::{ - utils::{convert_reference_type, find_impl_block_end, find_struct_impl, generate_impl_text}, + utils::{convert_reference_type, find_struct_impl, generate_impl}, AssistContext, AssistId, AssistKind, Assists, GroupLabel, }; @@ -75,7 +75,7 @@ pub(crate) fn generate_setter(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opt // Generate a getter method. // // ``` -// # //- minicore: as_ref +// # //- minicore: as_ref, deref // # pub struct String; // # impl AsRef for String { // # fn as_ref(&self) -> &str { @@ -83,6 +83,13 @@ pub(crate) fn generate_setter(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opt // # } // # } // # +// # impl core::ops::Deref for String { +// # type Target = str; +// # fn deref(&self) -> &Self::Target { +// # "" +// # } +// # } +// # // struct Person { // nam$0e: String, // } @@ -96,13 +103,20 @@ pub(crate) fn generate_setter(acc: &mut Assists, ctx: &AssistContext<'_>) -> Opt // # } // # } // # +// # impl core::ops::Deref for String { +// # type Target = str; +// # fn deref(&self) -> &Self::Target { +// # "" +// # } +// # } +// # // struct Person { // name: String, // } // // impl Person { // fn $0name(&self) -> &str { -// self.name.as_ref() +// &self.name // } // } // ``` @@ -200,14 +214,14 @@ fn generate_getter_from_info( ctx: &AssistContext<'_>, info: &AssistInfo, record_field_info: &RecordFieldInfo, -) -> String { - let mut buf = String::with_capacity(512); - - let vis = info.strukt.visibility().map_or(String::new(), |v| format!("{v} ")); +) -> ast::Fn { let (ty, body) = if matches!(info.assist_type, AssistType::MutGet) { ( - format!("&mut {}", record_field_info.field_ty), - format!("&mut self.{}", record_field_info.field_name), + make::ty_ref(record_field_info.field_ty.clone(), true), + make::expr_ref( + make::expr_field(make::ext::expr_self(), &record_field_info.field_name.text()), + true, + ), ) } else { (|| { @@ -226,41 +240,52 @@ fn generate_getter_from_info( })() .unwrap_or_else(|| { ( - format!("&{}", record_field_info.field_ty), - format!("&self.{}", record_field_info.field_name), + make::ty_ref(record_field_info.field_ty.clone(), false), + make::expr_ref( + make::expr_field(make::ext::expr_self(), &record_field_info.field_name.text()), + false, + ), ) }) }; - format_to!( - buf, - " {}fn {}(&{}self) -> {} {{ - {} - }}", - vis, - record_field_info.fn_name, - matches!(info.assist_type, AssistType::MutGet).then_some("mut ").unwrap_or_default(), - ty, - body, - ); + let self_param = if matches!(info.assist_type, AssistType::MutGet) { + make::mut_self_param() + } else { + make::self_param() + }; - buf + let strukt = &info.strukt; + let fn_name = make::name(&record_field_info.fn_name); + let params = make::param_list(Some(self_param), []); + let ret_type = Some(make::ret_type(ty)); + let body = make::block_expr([], Some(body)); + + make::fn_(strukt.visibility(), fn_name, None, None, params, body, ret_type, false, false, false) } -fn generate_setter_from_info(info: &AssistInfo, record_field_info: &RecordFieldInfo) -> String { - let mut buf = String::with_capacity(512); +fn generate_setter_from_info(info: &AssistInfo, record_field_info: &RecordFieldInfo) -> ast::Fn { let strukt = &info.strukt; - let fn_name = &record_field_info.fn_name; + let field_name = &record_field_info.fn_name; + let fn_name = make::name(&format!("set_{field_name}")); let field_ty = &record_field_info.field_ty; - let vis = strukt.visibility().map_or(String::new(), |v| format!("{v} ")); - format_to!( - buf, - " {vis}fn set_{fn_name}(&mut self, {fn_name}: {field_ty}) {{ - self.{fn_name} = {fn_name}; - }}" - ); - buf + // Make the param list + // `(&mut self, $field_name: $field_ty)` + let field_param = + make::param(make::ident_pat(false, false, make::name(field_name)).into(), field_ty.clone()); + let params = make::param_list(Some(make::mut_self_param()), [field_param]); + + // Make the assignment body + // `self.$field_name = $field_name` + let self_expr = make::ext::expr_self(); + let lhs = make::expr_field(self_expr, field_name); + let rhs = make::expr_path(make::ext::ident_path(field_name)); + let assign_stmt = make::expr_stmt(make::expr_assignment(lhs, rhs)); + let body = make::block_expr([assign_stmt.into()], None); + + // Make the setter fn + make::fn_(strukt.visibility(), fn_name, None, None, params, body, None, false, false, false) } fn extract_and_parse( @@ -353,74 +378,45 @@ fn build_source_change( ) { let record_fields_count = info_of_record_fields.len(); - let mut buf = String::with_capacity(512); + let impl_def = if let Some(impl_def) = &assist_info.impl_def { + // We have an existing impl to add to + builder.make_mut(impl_def.clone()) + } else { + // Generate a new impl to add the methods to + let impl_def = generate_impl(&ast::Adt::Struct(assist_info.strukt.clone())); - // Check if an impl exists - if let Some(impl_def) = &assist_info.impl_def { - // Check if impl is empty - if let Some(assoc_item_list) = impl_def.assoc_item_list() { - if assoc_item_list.assoc_items().next().is_some() { - // If not empty then only insert a new line - buf.push('\n'); - } - } - } + // Insert it after the adt + let strukt = builder.make_mut(assist_info.strukt.clone()); + + ted::insert_all_raw( + ted::Position::after(strukt.syntax()), + vec![make::tokens::blank_line().into(), impl_def.syntax().clone().into()], + ); + + impl_def + }; + + let assoc_item_list = impl_def.get_or_create_assoc_item_list(); for (i, record_field_info) in info_of_record_fields.iter().enumerate() { - // this buf inserts a newline at the end of a getter - // automatically, if one wants to add one more newline - // for separating it from other assoc items, that needs - // to be handled separately - let mut getter_buf = match assist_info.assist_type { + // Make the new getter or setter fn + let new_fn = match assist_info.assist_type { AssistType::Set => generate_setter_from_info(&assist_info, record_field_info), _ => generate_getter_from_info(ctx, &assist_info, record_field_info), - }; + } + .clone_for_update(); + new_fn.indent(1.into()); - // Insert `$0` only for last getter we generate - if i == record_fields_count - 1 && ctx.config.snippet_cap.is_some() { - getter_buf = getter_buf.replacen("fn ", "fn $0", 1); + // Insert a tabstop only for last method we generate + if i == record_fields_count - 1 { + if let Some(cap) = ctx.config.snippet_cap { + if let Some(name) = new_fn.name() { + builder.add_tabstop_before(cap, name); + } + } } - // For first element we do not merge with '\n', as - // that can be inserted by impl_def check defined - // above, for other cases which are: - // - // - impl exists but it empty, here we would ideally - // not want to keep newline between impl { - // and fn () { line - // - // - next if impl itself does not exist, in this - // case we ourselves generate a new impl and that - // again ends up with the same reasoning as above - // for not keeping newline - if i == 0 { - buf = buf + &getter_buf; - } else { - buf = buf + "\n" + &getter_buf; - } - - // We don't insert a new line at the end of - // last getter as it will end up in the end - // of an impl where we would not like to keep - // getter and end of impl ( i.e. `}` ) with an - // extra line for no reason - if i < record_fields_count - 1 { - buf += "\n"; - } - } - - let start_offset = assist_info - .impl_def - .as_ref() - .and_then(|impl_def| find_impl_block_end(impl_def.to_owned(), &mut buf)) - .unwrap_or_else(|| { - buf = generate_impl_text(&ast::Adt::Struct(assist_info.strukt.clone()), &buf); - assist_info.strukt.syntax().text_range().end() - }); - - match ctx.config.snippet_cap { - Some(cap) => builder.insert_snippet(cap, start_offset, buf), - None => builder.insert(start_offset, buf), + assoc_item_list.add_item(new_fn.clone().into()); } } diff --git a/crates/ide-assists/src/handlers/generate_impl.rs b/crates/ide-assists/src/handlers/generate_impl.rs index d52d778d34..821783c283 100644 --- a/crates/ide-assists/src/handlers/generate_impl.rs +++ b/crates/ide-assists/src/handlers/generate_impl.rs @@ -1,10 +1,10 @@ -use syntax::ast::{self, AstNode, HasName}; - -use crate::{ - utils::{generate_impl_text, generate_trait_impl_text_intransitive}, - AssistContext, AssistId, AssistKind, Assists, +use syntax::{ + ast::{self, make, AstNode, HasName}, + ted, }; +use crate::{utils, AssistContext, AssistId, AssistKind, Assists}; + // Assist: generate_impl // // Adds a new inherent impl for a type. @@ -20,9 +20,7 @@ use crate::{ // data: T, // } // -// impl Ctx { -// $0 -// } +// impl Ctx {$0} // ``` pub(crate) fn generate_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let nominal = ctx.find_node_at_offset::()?; @@ -38,17 +36,22 @@ pub(crate) fn generate_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> Optio format!("Generate impl for `{name}`"), target, |edit| { - let start_offset = nominal.syntax().text_range().end(); - match ctx.config.snippet_cap { - Some(cap) => { - let snippet = generate_impl_text(&nominal, " $0"); - edit.insert_snippet(cap, start_offset, snippet); - } - None => { - let snippet = generate_impl_text(&nominal, ""); - edit.insert(start_offset, snippet); + // Generate the impl + let impl_ = utils::generate_impl(&nominal); + + // Add a tabstop after the left curly brace + if let Some(cap) = ctx.config.snippet_cap { + if let Some(l_curly) = impl_.assoc_item_list().and_then(|it| it.l_curly_token()) { + edit.add_tabstop_after_token(cap, l_curly); } } + + // Add the impl after the adt + let nominal = edit.make_mut(nominal); + ted::insert_all_raw( + ted::Position::after(nominal.syntax()), + vec![make::tokens::blank_line().into(), impl_.syntax().clone().into()], + ); }, ) } @@ -68,9 +71,7 @@ pub(crate) fn generate_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> Optio // data: T, // } // -// impl $0 for Ctx { -// -// } +// impl ${0:_} for Ctx {} // ``` pub(crate) fn generate_trait_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { let nominal = ctx.find_node_at_offset::()?; @@ -86,17 +87,22 @@ pub(crate) fn generate_trait_impl(acc: &mut Assists, ctx: &AssistContext<'_>) -> format!("Generate trait impl for `{name}`"), target, |edit| { - let start_offset = nominal.syntax().text_range().end(); - match ctx.config.snippet_cap { - Some(cap) => { - let snippet = generate_trait_impl_text_intransitive(&nominal, "$0", ""); - edit.insert_snippet(cap, start_offset, snippet); - } - None => { - let text = generate_trait_impl_text_intransitive(&nominal, "", ""); - edit.insert(start_offset, text); + // Generate the impl + let impl_ = utils::generate_trait_impl_intransitive(&nominal, make::ty_placeholder()); + + // Make the trait type a placeholder snippet + if let Some(cap) = ctx.config.snippet_cap { + if let Some(trait_) = impl_.trait_() { + edit.add_placeholder_snippet(cap, trait_); } } + + // Add the impl after the adt + let nominal = edit.make_mut(nominal); + ted::insert_all_raw( + ted::Position::after(nominal.syntax()), + vec![make::tokens::blank_line().into(), impl_.syntax().clone().into()], + ); }, ) } @@ -117,9 +123,7 @@ mod tests { r#" struct Foo {} - impl Foo { - $0 - } + impl Foo {$0} "#, ); } @@ -134,9 +138,7 @@ mod tests { r#" struct Foo {} - impl Foo { - $0 - } + impl Foo {$0} "#, ); } @@ -151,9 +153,7 @@ mod tests { r#" struct Foo<'a, T: Foo<'a>> {} - impl<'a, T: Foo<'a>> Foo<'a, T> { - $0 - } + impl<'a, T: Foo<'a>> Foo<'a, T> {$0} "#, ); } @@ -171,9 +171,7 @@ mod tests { struct Foo<'a, T: Foo<'a>> {} #[cfg(feature = "foo")] - impl<'a, T: Foo<'a>> Foo<'a, T> { - $0 - } + impl<'a, T: Foo<'a>> Foo<'a, T> {$0} "#, ); } @@ -188,9 +186,7 @@ mod tests { r#" struct Defaulted {} - impl Defaulted { - $0 - } + impl Defaulted {$0} "#, ); } @@ -205,9 +201,7 @@ mod tests { r#" struct Defaulted<'a, 'b: 'a, T: Debug + Clone + 'a + 'b = String, const S: usize> {} - impl<'a, 'b: 'a, T: Debug + Clone + 'a + 'b, const S: usize> Defaulted<'a, 'b, T, S> { - $0 - } + impl<'a, 'b: 'a, T: Debug + Clone + 'a + 'b, const S: usize> Defaulted<'a, 'b, T, S> {$0} "#, ); } @@ -222,9 +216,7 @@ mod tests { r#" struct Defaulted {} - impl Defaulted { - $0 - } + impl Defaulted {$0} "#, ); } @@ -254,8 +246,7 @@ mod tests { impl Struct where T: Trait, - { - $0 + {$0 } "#, ); @@ -285,9 +276,7 @@ mod tests { r#" struct Foo {} - impl $0 for Foo { - - } + impl ${0:_} for Foo {} "#, ); } @@ -302,9 +291,7 @@ mod tests { r#" struct Foo {} - impl $0 for Foo { - - } + impl ${0:_} for Foo {} "#, ); } @@ -319,9 +306,7 @@ mod tests { r#" struct Foo<'a, T: Foo<'a>> {} - impl<'a, T: Foo<'a>> $0 for Foo<'a, T> { - - } + impl<'a, T: Foo<'a>> ${0:_} for Foo<'a, T> {} "#, ); } @@ -339,9 +324,7 @@ mod tests { struct Foo<'a, T: Foo<'a>> {} #[cfg(feature = "foo")] - impl<'a, T: Foo<'a>> $0 for Foo<'a, T> { - - } + impl<'a, T: Foo<'a>> ${0:_} for Foo<'a, T> {} "#, ); } @@ -356,9 +339,7 @@ mod tests { r#" struct Defaulted {} - impl $0 for Defaulted { - - } + impl ${0:_} for Defaulted {} "#, ); } @@ -373,9 +354,7 @@ mod tests { r#" struct Defaulted<'a, 'b: 'a, T: Debug + Clone + 'a + 'b = String, const S: usize> {} - impl<'a, 'b: 'a, T: Debug + Clone + 'a + 'b, const S: usize> $0 for Defaulted<'a, 'b, T, S> { - - } + impl<'a, 'b: 'a, T: Debug + Clone + 'a + 'b, const S: usize> ${0:_} for Defaulted<'a, 'b, T, S> {} "#, ); } @@ -390,9 +369,7 @@ mod tests { r#" struct Defaulted {} - impl $0 for Defaulted { - - } + impl ${0:_} for Defaulted {} "#, ); } @@ -419,11 +396,10 @@ mod tests { inner: T, } - impl $0 for Struct + impl ${0:_} for Struct where T: Trait, { - } "#, ); diff --git a/crates/ide-assists/src/handlers/generate_is_empty_from_len.rs b/crates/ide-assists/src/handlers/generate_is_empty_from_len.rs index 6bfc69b0ad..4d369e705e 100644 --- a/crates/ide-assists/src/handlers/generate_is_empty_from_len.rs +++ b/crates/ide-assists/src/handlers/generate_is_empty_from_len.rs @@ -79,7 +79,7 @@ pub(crate) fn generate_is_empty_from_len(acc: &mut Assists, ctx: &AssistContext< pub fn is_empty(&self) -> bool { self.len() == 0 }"# - .to_string(); + .to_owned(); builder.insert(range.end(), code) }, ) diff --git a/crates/ide-assists/src/handlers/generate_new.rs b/crates/ide-assists/src/handlers/generate_new.rs index 7bfd599660..22c75cd5ee 100644 --- a/crates/ide-assists/src/handlers/generate_new.rs +++ b/crates/ide-assists/src/handlers/generate_new.rs @@ -1,12 +1,13 @@ use ide_db::{ imports::import_assets::item_for_path_search, use_trivial_constructor::use_trivial_constructor, }; -use itertools::Itertools; -use stdx::format_to; -use syntax::ast::{self, AstNode, HasName, HasVisibility, StructKind}; +use syntax::{ + ast::{self, edit_in_place::Indent, make, AstNode, HasName, HasVisibility, StructKind}, + ted, +}; use crate::{ - utils::{find_impl_block_start, find_struct_impl, generate_impl_text}, + utils::{find_struct_impl, generate_impl}, AssistContext, AssistId, AssistKind, Assists, }; @@ -26,7 +27,9 @@ use crate::{ // } // // impl Ctx { -// fn $0new(data: T) -> Self { Self { data } } +// fn $0new(data: T) -> Self { +// Self { data } +// } // } // ``` pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> { @@ -46,14 +49,6 @@ pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option let target = strukt.syntax().text_range(); acc.add(AssistId("generate_new", AssistKind::Generate), "Generate `new`", target, |builder| { - let mut buf = String::with_capacity(512); - - if impl_def.is_some() { - buf.push('\n'); - } - - let vis = strukt.visibility().map_or(String::new(), |v| format!("{v} ")); - let trivial_constructors = field_list .fields() .map(|f| { @@ -76,54 +71,79 @@ pub(crate) fn generate_new(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option &ty, )?; - Some(format!("{name}: {expr}")) + Some(make::record_expr_field(make::name_ref(&name.text()), Some(expr))) }) .collect::>(); - let params = field_list - .fields() - .enumerate() - .filter_map(|(i, f)| { - if trivial_constructors[i].is_none() { - let name = f.name()?; - let ty = f.ty()?; + let params = field_list.fields().enumerate().filter_map(|(i, f)| { + if trivial_constructors[i].is_none() { + let name = f.name()?; + let ty = f.ty()?; - Some(format!("{name}: {ty}")) - } else { - None - } - }) - .format(", "); + Some(make::param(make::ident_pat(false, false, name).into(), ty)) + } else { + None + } + }); + let params = make::param_list(None, params); - let fields = field_list - .fields() - .enumerate() - .filter_map(|(i, f)| { - let constructor = trivial_constructors[i].clone(); - if constructor.is_some() { - constructor - } else { - Some(f.name()?.to_string()) - } - }) - .format(", "); + let fields = field_list.fields().enumerate().filter_map(|(i, f)| { + let constructor = trivial_constructors[i].clone(); + if constructor.is_some() { + constructor + } else { + Some(make::record_expr_field(make::name_ref(&f.name()?.text()), None)) + } + }); + let fields = make::record_expr_field_list(fields); - format_to!(buf, " {vis}fn new({params}) -> Self {{ Self {{ {fields} }} }}"); + let record_expr = make::record_expr(make::ext::ident_path("Self"), fields); + let body = make::block_expr(None, Some(record_expr.into())); - let start_offset = impl_def - .and_then(|impl_def| find_impl_block_start(impl_def, &mut buf)) - .unwrap_or_else(|| { - buf = generate_impl_text(&ast::Adt::Struct(strukt.clone()), &buf); - strukt.syntax().text_range().end() - }); + let ret_type = make::ret_type(make::ty_path(make::ext::ident_path("Self"))); - match ctx.config.snippet_cap { - None => builder.insert(start_offset, buf), - Some(cap) => { - buf = buf.replace("fn new", "fn $0new"); - builder.insert_snippet(cap, start_offset, buf); + let fn_ = make::fn_( + strukt.visibility(), + make::name("new"), + None, + None, + params, + body, + Some(ret_type), + false, + false, + false, + ) + .clone_for_update(); + fn_.indent(1.into()); + + // Add a tabstop before the name + if let Some(cap) = ctx.config.snippet_cap { + if let Some(name) = fn_.name() { + builder.add_tabstop_before(cap, name); } } + + // Get the mutable version of the impl to modify + let impl_def = if let Some(impl_def) = impl_def { + builder.make_mut(impl_def) + } else { + // Generate a new impl to add the method to + let impl_def = generate_impl(&ast::Adt::Struct(strukt.clone())); + + // Insert it after the adt + let strukt = builder.make_mut(strukt.clone()); + + ted::insert_all_raw( + ted::Position::after(strukt.syntax()), + vec![make::tokens::blank_line().into(), impl_def.syntax().clone().into()], + ); + + impl_def + }; + + // Add the `new` method at the start of the impl + impl_def.get_or_create_assoc_item_list().add_item_at_start(fn_.into()); }) } @@ -148,7 +168,9 @@ struct Empty; struct Foo { empty: Empty } impl Foo { - fn $0new() -> Self { Self { empty: Empty } } + fn $0new() -> Self { + Self { empty: Empty } + } } "#, ); @@ -165,7 +187,9 @@ struct Empty; struct Foo { baz: String, empty: Empty } impl Foo { - fn $0new(baz: String) -> Self { Self { baz, empty: Empty } } + fn $0new(baz: String) -> Self { + Self { baz, empty: Empty } + } } "#, ); @@ -182,7 +206,9 @@ enum Empty { Bar } struct Foo { empty: Empty } impl Foo { - fn $0new() -> Self { Self { empty: Empty::Bar } } + fn $0new() -> Self { + Self { empty: Empty::Bar } + } } "#, ); @@ -201,7 +227,9 @@ struct Empty {} struct Foo { empty: Empty } impl Foo { - fn $0new(empty: Empty) -> Self { Self { empty } } + fn $0new(empty: Empty) -> Self { + Self { empty } + } } "#, ); @@ -218,7 +246,9 @@ enum Empty { Bar {} } struct Foo { empty: Empty } impl Foo { - fn $0new(empty: Empty) -> Self { Self { empty } } + fn $0new(empty: Empty) -> Self { + Self { empty } + } } "#, ); @@ -235,7 +265,9 @@ struct Foo {$0} struct Foo {} impl Foo { - fn $0new() -> Self { Self { } } + fn $0new() -> Self { + Self { } + } } "#, ); @@ -248,7 +280,9 @@ struct Foo {$0} struct Foo {} impl Foo { - fn $0new() -> Self { Self { } } + fn $0new() -> Self { + Self { } + } } "#, ); @@ -261,7 +295,9 @@ struct Foo<'a, T: Foo<'a>> {$0} struct Foo<'a, T: Foo<'a>> {} impl<'a, T: Foo<'a>> Foo<'a, T> { - fn $0new() -> Self { Self { } } + fn $0new() -> Self { + Self { } + } } "#, ); @@ -274,7 +310,9 @@ struct Foo { baz: String $0} struct Foo { baz: String } impl Foo { - fn $0new(baz: String) -> Self { Self { baz } } + fn $0new(baz: String) -> Self { + Self { baz } + } } "#, ); @@ -287,7 +325,9 @@ struct Foo { baz: String, qux: Vec $0} struct Foo { baz: String, qux: Vec } impl Foo { - fn $0new(baz: String, qux: Vec) -> Self { Self { baz, qux } } + fn $0new(baz: String, qux: Vec) -> Self { + Self { baz, qux } + } } "#, ); @@ -304,7 +344,9 @@ struct Foo { pub baz: String, pub qux: Vec $0} struct Foo { pub baz: String, pub qux: Vec } impl Foo { - fn $0new(baz: String, qux: Vec) -> Self { Self { baz, qux } } + fn $0new(baz: String, qux: Vec) -> Self { + Self { baz, qux } + } } "#, ); @@ -323,7 +365,9 @@ impl Foo {} struct Foo {} impl Foo { - fn $0new() -> Self { Self { } } + fn $0new() -> Self { + Self { } + } } "#, ); @@ -340,7 +384,9 @@ impl Foo { struct Foo {} impl Foo { - fn $0new() -> Self { Self { } } + fn $0new() -> Self { + Self { } + } fn qux(&self) {} } @@ -363,7 +409,9 @@ impl Foo { struct Foo {} impl Foo { - fn $0new() -> Self { Self { } } + fn $0new() -> Self { + Self { } + } fn qux(&self) {} fn baz() -> i32 { @@ -385,7 +433,9 @@ pub struct Foo {$0} pub struct Foo {} impl Foo { - pub fn $0new() -> Self { Self { } } + pub fn $0new() -> Self { + Self { } + } } "#, ); @@ -398,7 +448,9 @@ pub(crate) struct Foo {$0} pub(crate) struct Foo {} impl Foo { - pub(crate) fn $0new() -> Self { Self { } } + pub(crate) fn $0new() -> Self { + Self { } + } } "#, ); @@ -493,7 +545,9 @@ pub struct Source { } impl Source { - pub fn $0new(file_id: HirFileId, ast: T) -> Self { Self { file_id, ast } } + pub fn $0new(file_id: HirFileId, ast: T) -> Self { + Self { file_id, ast } + } pub fn map U, U>(self, f: F) -> Source { Source { file_id: self.file_id, ast: f(self.ast) } diff --git a/crates/ide-assists/src/handlers/generate_trait_from_impl.rs b/crates/ide-assists/src/handlers/generate_trait_from_impl.rs index 8881aa69f2..24094de22c 100644 --- a/crates/ide-assists/src/handlers/generate_trait_from_impl.rs +++ b/crates/ide-assists/src/handlers/generate_trait_from_impl.rs @@ -118,7 +118,7 @@ pub(crate) fn generate_trait_from_impl(acc: &mut Assists, ctx: &AssistContext<'_ let arg_list = if let Some(genpars) = impl_ast.generic_param_list() { genpars.to_generic_args().to_string() } else { - "".to_string() + "".to_owned() }; if let Some(snippet_cap) = ctx.config.snippet_cap { diff --git a/crates/ide-assists/src/handlers/inline_call.rs b/crates/ide-assists/src/handlers/inline_call.rs index 4ba33ada48..11b22b6520 100644 --- a/crates/ide-assists/src/handlers/inline_call.rs +++ b/crates/ide-assists/src/handlers/inline_call.rs @@ -415,7 +415,24 @@ fn inline( let expr: &ast::Expr = expr; let mut insert_let_stmt = || { - let ty = sema.type_of_expr(expr).filter(TypeInfo::has_adjustment).and(param_ty.clone()); + let param_ty = match param_ty { + None => None, + Some(param_ty) => { + if sema.hir_file_for(param_ty.syntax()).is_macro() { + if let Some(param_ty) = + ast::Type::cast(insert_ws_into(param_ty.syntax().clone())) + { + Some(param_ty) + } else { + Some(param_ty.clone_for_update()) + } + } else { + Some(param_ty.clone_for_update()) + } + } + }; + let ty: Option = + sema.type_of_expr(expr).filter(TypeInfo::has_adjustment).and(param_ty); let is_self = param .name(sema.db) @@ -1732,6 +1749,49 @@ pub fn main() { this.0 += 1; }; } +"#, + ) + } + + #[test] + fn inline_call_with_reference_in_macros() { + check_assist( + inline_call, + r#" +fn _write_u64(s: &mut u64, x: u64) { + *s += x; +} +macro_rules! impl_write { + ($(($ty:ident, $meth:ident),)*) => {$( + fn _hash(inner_self_: &u64, state: &mut u64) { + $meth(state, *inner_self_) + } + )*} +} +impl_write! { (u64, _write_u64), } +fn _hash2(self_: &u64, state: &mut u64) { + $0_hash(&self_, state); +} +"#, + r#" +fn _write_u64(s: &mut u64, x: u64) { + *s += x; +} +macro_rules! impl_write { + ($(($ty:ident, $meth:ident),)*) => {$( + fn _hash(inner_self_: &u64, state: &mut u64) { + $meth(state, *inner_self_) + } + )*} +} +impl_write! { (u64, _write_u64), } +fn _hash2(self_: &u64, state: &mut u64) { + { + let inner_self_: &u64 = &self_; + let state: &mut u64 = state; + _write_u64(state, *inner_self_) + }; +} "#, ) } diff --git a/crates/ide-assists/src/handlers/inline_const_as_literal.rs b/crates/ide-assists/src/handlers/inline_const_as_literal.rs index 1843745376..111ea50fdc 100644 --- a/crates/ide-assists/src/handlers/inline_const_as_literal.rs +++ b/crates/ide-assists/src/handlers/inline_const_as_literal.rs @@ -60,7 +60,7 @@ pub(crate) fn inline_const_as_literal(acc: &mut Assists, ctx: &AssistContext<'_> let id = AssistId("inline_const_as_literal", AssistKind::RefactorInline); - let label = "Inline const as literal".to_string(); + let label = "Inline const as literal".to_owned(); let target = variable.syntax().text_range(); return acc.add(id, label, target, |edit| { diff --git a/crates/ide-assists/src/handlers/inline_macro.rs b/crates/ide-assists/src/handlers/inline_macro.rs index c1beb46c80..0c9e971dd2 100644 --- a/crates/ide-assists/src/handlers/inline_macro.rs +++ b/crates/ide-assists/src/handlers/inline_macro.rs @@ -41,7 +41,7 @@ pub(crate) fn inline_macro(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option acc.add( AssistId("inline_macro", AssistKind::RefactorInline), - "Inline macro".to_string(), + "Inline macro".to_owned(), text_range, |builder| builder.replace(text_range, expanded.to_string()), ) diff --git a/crates/ide-assists/src/handlers/inline_type_alias.rs b/crates/ide-assists/src/handlers/inline_type_alias.rs index 5982e9d61d..e2f3d9edcd 100644 --- a/crates/ide-assists/src/handlers/inline_type_alias.rs +++ b/crates/ide-assists/src/handlers/inline_type_alias.rs @@ -3,12 +3,12 @@ // - Remove unused aliases if there are no longer any users, see inline_call.rs. use hir::{HasSource, PathResolution}; +use ide_db::FxHashMap; use ide_db::{ defs::Definition, imports::insert_use::ast_to_remove_for_path_in_use_stmt, search::FileReference, }; use itertools::Itertools; -use std::collections::HashMap; use syntax::{ ast::{self, make, HasGenericParams, HasName}, ted, AstNode, NodeOrToken, SyntaxNode, @@ -189,14 +189,14 @@ fn inline(alias_def: &ast::TypeAlias, alias_instance: &ast::PathType) -> Option< Some(repl) } -struct LifetimeMap(HashMap); +struct LifetimeMap(FxHashMap); impl LifetimeMap { fn new( instance_args: &Option, alias_generics: &ast::GenericParamList, ) -> Option { - let mut inner = HashMap::new(); + let mut inner = FxHashMap::default(); let wildcard_lifetime = make::lifetime("'_"); let lifetimes = alias_generics @@ -231,14 +231,14 @@ impl LifetimeMap { } } -struct ConstAndTypeMap(HashMap); +struct ConstAndTypeMap(FxHashMap); impl ConstAndTypeMap { fn new( instance_args: &Option, alias_generics: &ast::GenericParamList, ) -> Option { - let mut inner = HashMap::new(); + let mut inner = FxHashMap::default(); let instance_generics = generic_args_to_const_and_type_generics(instance_args); let alias_generics = generic_param_list_to_const_and_type_generics(alias_generics); diff --git a/crates/ide-assists/src/handlers/into_to_qualified_from.rs b/crates/ide-assists/src/handlers/into_to_qualified_from.rs index f7da88b2c1..dee74afcbe 100644 --- a/crates/ide-assists/src/handlers/into_to_qualified_from.rs +++ b/crates/ide-assists/src/handlers/into_to_qualified_from.rs @@ -48,7 +48,7 @@ pub(crate) fn into_to_qualified_from(acc: &mut Assists, ctx: &AssistContext<'_>) let fnc = sema.resolve_method_call(&method_call)?; let scope = sema.scope(method_call.syntax())?; // Check if the method call refers to Into trait. - if fnc.as_assoc_item(db)?.containing_trait_impl(db)? + if fnc.as_assoc_item(db)?.implemented_trait(db)? == FamousDefs(sema, scope.krate()).core_convert_Into()? { let type_call = sema.type_of_expr(&method_call.clone().into())?; diff --git a/crates/ide-assists/src/handlers/introduce_named_lifetime.rs b/crates/ide-assists/src/handlers/introduce_named_lifetime.rs index c5aa9755bc..62909c586e 100644 --- a/crates/ide-assists/src/handlers/introduce_named_lifetime.rs +++ b/crates/ide-assists/src/handlers/introduce_named_lifetime.rs @@ -129,7 +129,7 @@ fn generate_unique_lifetime_param_name( type_params.lifetime_params().map(|p| p.syntax().text().to_string()).collect(); ('a'..='z').map(|it| format!("'{it}")).find(|it| !used_lifetime_params.contains(it)) } - None => Some("'a".to_string()), + None => Some("'a".to_owned()), } .map(|it| make::lifetime(&it)) } diff --git a/crates/ide-assists/src/handlers/merge_match_arms.rs b/crates/ide-assists/src/handlers/merge_match_arms.rs index aae9f20d4e..4608e9494b 100644 --- a/crates/ide-assists/src/handlers/merge_match_arms.rs +++ b/crates/ide-assists/src/handlers/merge_match_arms.rs @@ -1,5 +1,6 @@ use hir::Type; -use std::{collections::HashMap, iter::successors}; +use ide_db::FxHashMap; +use std::iter::successors; use syntax::{ algo::neighbor, ast::{self, AstNode, HasName}, @@ -95,7 +96,7 @@ fn contains_placeholder(a: &ast::MatchArm) -> bool { } fn are_same_types( - current_arm_types: &HashMap>, + current_arm_types: &FxHashMap>, arm: &ast::MatchArm, ctx: &AssistContext<'_>, ) -> bool { @@ -114,11 +115,11 @@ fn are_same_types( fn get_arm_types( context: &AssistContext<'_>, arm: &ast::MatchArm, -) -> HashMap> { - let mut mapping: HashMap> = HashMap::new(); +) -> FxHashMap> { + let mut mapping: FxHashMap> = FxHashMap::default(); fn recurse( - map: &mut HashMap>, + map: &mut FxHashMap>, ctx: &AssistContext<'_>, pat: &Option, ) { diff --git a/crates/ide-assists/src/handlers/move_module_to_file.rs b/crates/ide-assists/src/handlers/move_module_to_file.rs index 166b25c69e..048906d9d9 100644 --- a/crates/ide-assists/src/handlers/move_module_to_file.rs +++ b/crates/ide-assists/src/handlers/move_module_to_file.rs @@ -75,7 +75,7 @@ pub(crate) fn move_module_to_file(acc: &mut Assists, ctx: &AssistContext<'_>) -> let contents = { let items = module_items.dedent(IndentLevel(1)).to_string(); let mut items = - items.trim_start_matches('{').trim_end_matches('}').trim().to_string(); + items.trim_start_matches('{').trim_end_matches('}').trim().to_owned(); if !items.is_empty() { items.push('\n'); } diff --git a/crates/ide-assists/src/handlers/number_representation.rs b/crates/ide-assists/src/handlers/number_representation.rs index 7e3fef516b..a13799f9b1 100644 --- a/crates/ide-assists/src/handlers/number_representation.rs +++ b/crates/ide-assists/src/handlers/number_representation.rs @@ -33,7 +33,7 @@ pub(crate) fn reformat_number_literal(acc: &mut Assists, ctx: &AssistContext<'_> } let radix = literal.radix(); - let mut converted = prefix.to_string(); + let mut converted = prefix.to_owned(); converted.push_str(&add_group_separators(value, group_size(radix))); converted.push_str(suffix); diff --git a/crates/ide-assists/src/handlers/qualify_path.rs b/crates/ide-assists/src/handlers/qualify_path.rs index 0864871849..63a09ce2e9 100644 --- a/crates/ide-assists/src/handlers/qualify_path.rs +++ b/crates/ide-assists/src/handlers/qualify_path.rs @@ -203,7 +203,7 @@ fn item_as_trait(db: &RootDatabase, item: hir::ItemInNs) -> Option { match item_module_def { hir::ModuleDef::Trait(trait_) => Some(trait_), - _ => item_module_def.as_assoc_item(db)?.containing_trait(db), + _ => item_module_def.as_assoc_item(db)?.container_trait(db), } } diff --git a/crates/ide-assists/src/handlers/remove_unused_imports.rs b/crates/ide-assists/src/handlers/remove_unused_imports.rs index 35bf84c434..d67b259d2f 100644 --- a/crates/ide-assists/src/handlers/remove_unused_imports.rs +++ b/crates/ide-assists/src/handlers/remove_unused_imports.rs @@ -1,11 +1,11 @@ -use std::collections::{hash_map::Entry, HashMap}; +use std::collections::hash_map::Entry; use hir::{HirFileIdExt, InFile, InRealFile, Module, ModuleSource}; use ide_db::{ base_db::FileRange, defs::Definition, search::{FileReference, ReferenceCategory, SearchScope}, - RootDatabase, + FxHashMap, RootDatabase, }; use syntax::{ast, AstNode}; use text_edit::TextRange; @@ -44,7 +44,7 @@ pub(crate) fn remove_unused_imports(acc: &mut Assists, ctx: &AssistContext<'_>) let uses = uses_up.chain(uses_down).collect::>(); // Maps use nodes to the scope that we should search through to find - let mut search_scopes = HashMap::>::new(); + let mut search_scopes = FxHashMap::>::default(); // iterator over all unused use trees let mut unused = uses diff --git a/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs b/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs index 788cc846c2..3420d906de 100644 --- a/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs +++ b/crates/ide-assists/src/handlers/replace_derive_with_manual_impl.rs @@ -2,15 +2,17 @@ use hir::{InFile, MacroFileIdExt, ModuleDef}; use ide_db::{helpers::mod_path_to_ast, imports::import_assets::NameToImport, items_locator}; use itertools::Itertools; use syntax::{ - ast::{self, AstNode, HasName}, + ast::{self, make, AstNode, HasName}, + ted, SyntaxKind::WHITESPACE, + T, }; use crate::{ assist_context::{AssistContext, Assists, SourceChangeBuilder}, utils::{ - add_trait_assoc_items_to_impl, filter_assoc_items, gen_trait_fn_body, - generate_trait_impl_text, render_snippet, Cursor, DefaultMethods, IgnoreAssocItems, + add_trait_assoc_items_to_impl, filter_assoc_items, gen_trait_fn_body, generate_trait_impl, + DefaultMethods, IgnoreAssocItems, }, AssistId, AssistKind, }; @@ -132,35 +134,59 @@ fn add_assist( label, target, |builder| { - let insert_pos = adt.syntax().text_range().end(); + let insert_after = ted::Position::after(builder.make_mut(adt.clone()).syntax()); + let impl_def_with_items = impl_def_from_trait(&ctx.sema, adt, &annotated_name, trait_, replace_trait_path); update_attribute(builder, old_derives, old_tree, old_trait_path, attr); - let trait_path = replace_trait_path.to_string(); + + let trait_path = make::ty_path(replace_trait_path.clone()); + match (ctx.config.snippet_cap, impl_def_with_items) { (None, _) => { - builder.insert(insert_pos, generate_trait_impl_text(adt, &trait_path, "")) + let impl_def = generate_trait_impl(adt, trait_path); + + ted::insert_all( + insert_after, + vec![make::tokens::blank_line().into(), impl_def.syntax().clone().into()], + ); + } + (Some(cap), None) => { + let impl_def = generate_trait_impl(adt, trait_path); + + if let Some(l_curly) = + impl_def.assoc_item_list().and_then(|it| it.l_curly_token()) + { + builder.add_tabstop_after_token(cap, l_curly); + } + + ted::insert_all( + insert_after, + vec![make::tokens::blank_line().into(), impl_def.syntax().clone().into()], + ); } - (Some(cap), None) => builder.insert_snippet( - cap, - insert_pos, - generate_trait_impl_text(adt, &trait_path, " $0"), - ), (Some(cap), Some((impl_def, first_assoc_item))) => { - let mut cursor = Cursor::Before(first_assoc_item.syntax()); - let placeholder; + let mut added_snippet = false; if let ast::AssocItem::Fn(ref func) = first_assoc_item { if let Some(m) = func.syntax().descendants().find_map(ast::MacroCall::cast) { if m.syntax().text() == "todo!()" { - placeholder = m; - cursor = Cursor::Replace(placeholder.syntax()); + // Make the `todo!()` a placeholder + builder.add_placeholder_snippet(cap, m); + added_snippet = true; } } } - let rendered = render_snippet(cap, impl_def.syntax(), cursor); - builder.insert_snippet(cap, insert_pos, format!("\n\n{rendered}")) + if !added_snippet { + // If we haven't already added a snippet, add a tabstop before the generated function + builder.add_tabstop_before(cap, first_assoc_item); + } + + ted::insert_all( + insert_after, + vec![make::tokens::blank_line().into(), impl_def.syntax().clone().into()], + ); } }; }, @@ -190,28 +216,7 @@ fn impl_def_from_trait( if trait_items.is_empty() { return None; } - let impl_def = { - use syntax::ast::Impl; - let text = generate_trait_impl_text(adt, trait_path.to_string().as_str(), ""); - // FIXME: `generate_trait_impl_text` currently generates two newlines - // at the front, but these leading newlines should really instead be - // inserted at the same time the impl is inserted - assert_eq!(&text[..2], "\n\n", "`generate_trait_impl_text` output changed"); - let parse = syntax::SourceFile::parse(&text[2..]); - let node = match parse.tree().syntax().descendants().find_map(Impl::cast) { - Some(it) => it, - None => { - panic!( - "Failed to make ast node `{}` from text {}", - std::any::type_name::(), - text - ) - } - }; - let node = node.clone_for_update(); - assert_eq!(node.syntax().text_range().start(), 0.into()); - node - }; + let impl_def = generate_trait_impl(adt, make::ty_path(trait_path.clone())); let first_assoc_item = add_trait_assoc_items_to_impl(sema, &trait_items, trait_, &impl_def, target_scope); @@ -238,20 +243,34 @@ fn update_attribute( let has_more_derives = !new_derives.is_empty(); if has_more_derives { - let new_derives = format!("({})", new_derives.iter().format(", ")); - builder.replace(old_tree.syntax().text_range(), new_derives); - } else { - let attr_range = attr.syntax().text_range(); - builder.delete(attr_range); + let old_tree = builder.make_mut(old_tree.clone()); - if let Some(line_break_range) = attr - .syntax() - .next_sibling_or_token() - .filter(|t| t.kind() == WHITESPACE) - .map(|t| t.text_range()) + // Make the paths into flat lists of tokens in a vec + let tt = new_derives.iter().map(|path| path.syntax().clone()).map(|node| { + node.descendants_with_tokens() + .filter_map(|element| element.into_token()) + .collect::>() + }); + // ...which are interspersed with ", " + let tt = Itertools::intersperse(tt, vec![make::token(T![,]), make::tokens::single_space()]); + // ...wrap them into the appropriate `NodeOrToken` variant + let tt = tt.flatten().map(syntax::NodeOrToken::Token); + // ...and make them into a flat list of tokens + let tt = tt.collect::>(); + + let new_tree = make::token_tree(T!['('], tt).clone_for_update(); + ted::replace(old_tree.syntax(), new_tree.syntax()); + } else { + // Remove the attr and any trailing whitespace + let attr = builder.make_mut(attr.clone()); + + if let Some(line_break) = + attr.syntax().next_sibling_or_token().filter(|t| t.kind() == WHITESPACE) { - builder.delete(line_break_range); + ted::remove(line_break) } + + ted::remove(attr.syntax()) } } @@ -1168,9 +1187,7 @@ struct Foo { bar: String, } -impl Debug for Foo { - $0 -} +impl Debug for Foo {$0} "#, ) } @@ -1191,9 +1208,7 @@ pub struct Foo { bar: String, } -impl Debug for Foo { - $0 -} +impl Debug for Foo {$0} "#, ) } @@ -1211,9 +1226,7 @@ struct Foo {} #[derive(Display, Serialize)] struct Foo {} -impl Debug for Foo { - $0 -} +impl Debug for Foo {$0} "#, ) } diff --git a/crates/ide-assists/src/handlers/unwrap_result_return_type.rs b/crates/ide-assists/src/handlers/unwrap_result_return_type.rs index 03e6dfebeb..8a9e669630 100644 --- a/crates/ide-assists/src/handlers/unwrap_result_return_type.rs +++ b/crates/ide-assists/src/handlers/unwrap_result_return_type.rs @@ -47,9 +47,7 @@ pub(crate) fn unwrap_result_return_type(acc: &mut Assists, ctx: &AssistContext<' return None; } - let Some(ok_type) = unwrap_result_type(type_ref) else { - return None; - }; + let ok_type = unwrap_result_type(type_ref)?; acc.add( AssistId("unwrap_result_return_type", AssistKind::RefactorRewrite), diff --git a/crates/ide-assists/src/tests.rs b/crates/ide-assists/src/tests.rs index 573d69b5c6..9b6f7d018e 100644 --- a/crates/ide-assists/src/tests.rs +++ b/crates/ide-assists/src/tests.rs @@ -474,7 +474,7 @@ pub fn test_some_range(a: int) -> bool { &db, &cfg, AssistResolveStrategy::Single(SingleResolve { - assist_id: "SOMETHING_MISMATCHING".to_string(), + assist_id: "SOMETHING_MISMATCHING".to_owned(), assist_kind: AssistKind::RefactorExtract, }), frange, @@ -520,7 +520,7 @@ pub fn test_some_range(a: int) -> bool { &db, &cfg, AssistResolveStrategy::Single(SingleResolve { - assist_id: "extract_variable".to_string(), + assist_id: "extract_variable".to_owned(), assist_kind: AssistKind::RefactorExtract, }), frange, @@ -687,12 +687,21 @@ pub fn test_some_range(a: int) -> bool { delete: 59..60, }, Indel { - insert: "\n\nfn $0fun_name() -> i32 {\n 5\n}", + insert: "\n\nfn fun_name() -> i32 {\n 5\n}", delete: 110..110, }, ], }, - None, + Some( + SnippetEdit( + [ + ( + 0, + 124..124, + ), + ], + ), + ), ), }, file_system_edits: [], diff --git a/crates/ide-assists/src/tests/generated.rs b/crates/ide-assists/src/tests/generated.rs index 8d7c49d52c..8ad735d0ae 100644 --- a/crates/ide-assists/src/tests/generated.rs +++ b/crates/ide-assists/src/tests/generated.rs @@ -1429,7 +1429,7 @@ fn doctest_generate_getter() { check_doc_test( "generate_getter", r#####" -//- minicore: as_ref +//- minicore: as_ref, deref pub struct String; impl AsRef for String { fn as_ref(&self) -> &str { @@ -1437,6 +1437,13 @@ impl AsRef for String { } } +impl core::ops::Deref for String { + type Target = str; + fn deref(&self) -> &Self::Target { + "" + } +} + struct Person { nam$0e: String, } @@ -1449,13 +1456,20 @@ impl AsRef for String { } } +impl core::ops::Deref for String { + type Target = str; + fn deref(&self) -> &Self::Target { + "" + } +} + struct Person { name: String, } impl Person { fn $0name(&self) -> &str { - self.name.as_ref() + &self.name } } "#####, @@ -1499,9 +1513,7 @@ struct Ctx { data: T, } -impl Ctx { - $0 -} +impl Ctx {$0} "#####, ) } @@ -1589,7 +1601,9 @@ struct Ctx { } impl Ctx { - fn $0new(data: T) -> Self { Self { data } } + fn $0new(data: T) -> Self { + Self { data } + } } "#####, ) @@ -1688,9 +1702,7 @@ struct Ctx { data: T, } -impl $0 for Ctx { - -} +impl ${0:_} for Ctx {} "#####, ) } diff --git a/crates/ide-assists/src/tests/sourcegen.rs b/crates/ide-assists/src/tests/sourcegen.rs index 088d93f9a6..847cb1af51 100644 --- a/crates/ide-assists/src/tests/sourcegen.rs +++ b/crates/ide-assists/src/tests/sourcegen.rs @@ -15,7 +15,7 @@ fn sourcegen_assists_docs() { let mut buf = " use super::check_doc_test; " - .to_string(); + .to_owned(); for assist in assists.iter() { for (idx, section) in assist.sections.iter().enumerate() { let test_id = @@ -101,7 +101,7 @@ impl Assist { let mut assist = Assist { id, location, sections: Vec::new() }; while lines.peek().is_some() { - let doc = take_until(lines.by_ref(), "```").trim().to_string(); + let doc = take_until(lines.by_ref(), "```").trim().to_owned(); assert!( (doc.chars().next().unwrap().is_ascii_uppercase() && doc.ends_with('.')) || !assist.sections.is_empty(), diff --git a/crates/ide-assists/src/utils.rs b/crates/ide-assists/src/utils.rs index eeb3d80d07..a4f1432675 100644 --- a/crates/ide-assists/src/utils.rs +++ b/crates/ide-assists/src/utils.rs @@ -1,12 +1,10 @@ //! Assorted functions shared by several assists. -use std::ops; - pub(crate) use gen_trait_fn_body::gen_trait_fn_body; use hir::{db::HirDatabase, HasAttrs as HirHasAttrs, HirDisplay, InFile, Semantics}; use ide_db::{ famous_defs::FamousDefs, path_transform::PathTransform, - syntax_helpers::insert_whitespace_into_node::insert_ws_into, RootDatabase, SnippetCap, + syntax_helpers::insert_whitespace_into_node::insert_ws_into, RootDatabase, }; use stdx::format_to; use syntax::{ @@ -217,43 +215,6 @@ pub fn add_trait_assoc_items_to_impl( first_item.unwrap() } -#[derive(Clone, Copy, Debug)] -pub(crate) enum Cursor<'a> { - Replace(&'a SyntaxNode), - Before(&'a SyntaxNode), -} - -impl<'a> Cursor<'a> { - fn node(self) -> &'a SyntaxNode { - match self { - Cursor::Replace(node) | Cursor::Before(node) => node, - } - } -} - -pub(crate) fn render_snippet(_cap: SnippetCap, node: &SyntaxNode, cursor: Cursor<'_>) -> String { - assert!(cursor.node().ancestors().any(|it| it == *node)); - let range = cursor.node().text_range() - node.text_range().start(); - let range: ops::Range = range.into(); - - let mut placeholder = cursor.node().to_string(); - escape(&mut placeholder); - let tab_stop = match cursor { - Cursor::Replace(placeholder) => format!("${{0:{placeholder}}}"), - Cursor::Before(placeholder) => format!("$0{placeholder}"), - }; - - let mut buf = node.to_string(); - buf.replace_range(range, &tab_stop); - return buf; - - fn escape(buf: &mut String) { - stdx::replace(buf, '{', r"\{"); - stdx::replace(buf, '}', r"\}"); - stdx::replace(buf, '$', r"\$"); - } -} - pub(crate) fn vis_offset(node: &SyntaxNode) -> TextSize { node.children_with_tokens() .find(|it| !matches!(it.kind(), WHITESPACE | COMMENT | ATTR)) @@ -445,15 +406,6 @@ fn has_any_fn(imp: &ast::Impl, names: &[String]) -> bool { false } -/// Find the start of the `impl` block for the given `ast::Impl`. -// -// FIXME: this partially overlaps with `find_struct_impl` -pub(crate) fn find_impl_block_start(impl_def: ast::Impl, buf: &mut String) -> Option { - buf.push('\n'); - let start = impl_def.assoc_item_list().and_then(|it| it.l_curly_token())?.text_range().end(); - Some(start) -} - /// Find the end of the `impl` block for the given `ast::Impl`. // // FIXME: this partially overlaps with `find_struct_impl` @@ -470,6 +422,7 @@ pub(crate) fn find_impl_block_end(impl_def: ast::Impl, buf: &mut String) -> Opti /// Generates the surrounding `impl Type { }` including type and lifetime /// parameters. +// FIXME: migrate remaining uses to `generate_impl` pub(crate) fn generate_impl_text(adt: &ast::Adt, code: &str) -> String { generate_impl_text_inner(adt, None, true, code) } @@ -478,6 +431,7 @@ pub(crate) fn generate_impl_text(adt: &ast::Adt, code: &str) -> String { /// and lifetime parameters, with `` appended to `impl`'s generic parameters' bounds. /// /// This is useful for traits like `PartialEq`, since `impl PartialEq for U` often requires `T: PartialEq`. +// FIXME: migrate remaining uses to `generate_trait_impl` pub(crate) fn generate_trait_impl_text(adt: &ast::Adt, trait_text: &str, code: &str) -> String { generate_impl_text_inner(adt, Some(trait_text), true, code) } @@ -486,6 +440,7 @@ pub(crate) fn generate_trait_impl_text(adt: &ast::Adt, trait_text: &str, code: & /// and lifetime parameters, with `impl`'s generic parameters' bounds kept as-is. /// /// This is useful for traits like `From`, since `impl From for U` doesn't require `T: From`. +// FIXME: migrate remaining uses to `generate_trait_impl_intransitive` pub(crate) fn generate_trait_impl_text_intransitive( adt: &ast::Adt, trait_text: &str, @@ -516,7 +471,7 @@ fn generate_impl_text_inner( // Add the current trait to `bounds` if the trait is transitive, // meaning `impl Trait for U` requires `T: Trait`. if trait_is_transitive { - bounds.push(make::type_bound(trait_)); + bounds.push(make::type_bound_text(trait_)); } }; // `{ty_param}: {bounds}` @@ -574,6 +529,101 @@ fn generate_impl_text_inner( buf } +/// Generates the corresponding `impl Type {}` including type and lifetime +/// parameters. +pub(crate) fn generate_impl(adt: &ast::Adt) -> ast::Impl { + generate_impl_inner(adt, None, true) +} + +/// Generates the corresponding `impl for Type {}` including type +/// and lifetime parameters, with `` appended to `impl`'s generic parameters' bounds. +/// +/// This is useful for traits like `PartialEq`, since `impl PartialEq for U` often requires `T: PartialEq`. +pub(crate) fn generate_trait_impl(adt: &ast::Adt, trait_: ast::Type) -> ast::Impl { + generate_impl_inner(adt, Some(trait_), true) +} + +/// Generates the corresponding `impl for Type {}` including type +/// and lifetime parameters, with `impl`'s generic parameters' bounds kept as-is. +/// +/// This is useful for traits like `From`, since `impl From for U` doesn't require `T: From`. +pub(crate) fn generate_trait_impl_intransitive(adt: &ast::Adt, trait_: ast::Type) -> ast::Impl { + generate_impl_inner(adt, Some(trait_), false) +} + +fn generate_impl_inner( + adt: &ast::Adt, + trait_: Option, + trait_is_transitive: bool, +) -> ast::Impl { + // Ensure lifetime params are before type & const params + let generic_params = adt.generic_param_list().map(|generic_params| { + let lifetime_params = + generic_params.lifetime_params().map(ast::GenericParam::LifetimeParam); + let ty_or_const_params = generic_params.type_or_const_params().map(|param| { + match param { + ast::TypeOrConstParam::Type(param) => { + let param = param.clone_for_update(); + // remove defaults since they can't be specified in impls + param.remove_default(); + let mut bounds = + param.type_bound_list().map_or_else(Vec::new, |it| it.bounds().collect()); + if let Some(trait_) = &trait_ { + // Add the current trait to `bounds` if the trait is transitive, + // meaning `impl Trait for U` requires `T: Trait`. + if trait_is_transitive { + bounds.push(make::type_bound(trait_.clone())); + } + }; + // `{ty_param}: {bounds}` + let param = + make::type_param(param.name().unwrap(), make::type_bound_list(bounds)); + ast::GenericParam::TypeParam(param) + } + ast::TypeOrConstParam::Const(param) => { + let param = param.clone_for_update(); + // remove defaults since they can't be specified in impls + param.remove_default(); + ast::GenericParam::ConstParam(param) + } + } + }); + + make::generic_param_list(itertools::chain(lifetime_params, ty_or_const_params)) + }); + let generic_args = + generic_params.as_ref().map(|params| params.to_generic_args().clone_for_update()); + let ty = make::ty_path(make::ext::ident_path(&adt.name().unwrap().text())); + + let impl_ = match trait_ { + Some(trait_) => make::impl_trait( + false, + None, + None, + generic_params, + generic_args, + false, + trait_, + ty, + None, + adt.where_clause(), + None, + ), + None => make::impl_(generic_params, generic_args, ty, adt.where_clause(), None), + } + .clone_for_update(); + + // Copy any cfg attrs from the original adt + let cfg_attrs = adt + .attrs() + .filter(|attr| attr.as_simple_call().map(|(name, _arg)| name == "cfg").unwrap_or(false)); + for attr in cfg_attrs { + impl_.add_attr(attr.clone_for_update()); + } + + impl_ +} + pub(crate) fn add_method_to_adt( builder: &mut SourceChangeBuilder, adt: &ast::Adt, @@ -600,6 +650,7 @@ pub(crate) fn add_method_to_adt( pub(crate) struct ReferenceConversion { conversion: ReferenceConversionType, ty: hir::Type, + impls_deref: bool, } #[derive(Debug)] @@ -619,10 +670,10 @@ enum ReferenceConversionType { } impl ReferenceConversion { - pub(crate) fn convert_type(&self, db: &dyn HirDatabase) -> String { - match self.conversion { + pub(crate) fn convert_type(&self, db: &dyn HirDatabase) -> ast::Type { + let ty = match self.conversion { ReferenceConversionType::Copy => self.ty.display(db).to_string(), - ReferenceConversionType::AsRefStr => "&str".to_string(), + ReferenceConversionType::AsRefStr => "&str".to_owned(), ReferenceConversionType::AsRefSlice => { let type_argument_name = self.ty.type_arguments().next().unwrap().display(db).to_string(); @@ -646,17 +697,27 @@ impl ReferenceConversion { type_arguments.next().unwrap().display(db).to_string(); format!("Result<&{first_type_argument_name}, &{second_type_argument_name}>") } - } + }; + + make::ty(&ty) } - pub(crate) fn getter(&self, field_name: String) -> String { + pub(crate) fn getter(&self, field_name: String) -> ast::Expr { + let expr = make::expr_field(make::ext::expr_self(), &field_name); + match self.conversion { - ReferenceConversionType::Copy => format!("self.{field_name}"), + ReferenceConversionType::Copy => expr, ReferenceConversionType::AsRefStr | ReferenceConversionType::AsRefSlice | ReferenceConversionType::Dereferenced | ReferenceConversionType::Option - | ReferenceConversionType::Result => format!("self.{field_name}.as_ref()"), + | ReferenceConversionType::Result => { + if self.impls_deref { + make::expr_ref(expr, false) + } else { + make::expr_method_call(expr, make::name_ref("as_ref"), make::arg_list([])) + } + } } } } @@ -675,54 +736,64 @@ pub(crate) fn convert_reference_type( .or_else(|| handle_dereferenced(&ty, db, famous_defs)) .or_else(|| handle_option_as_ref(&ty, db, famous_defs)) .or_else(|| handle_result_as_ref(&ty, db, famous_defs)) - .map(|conversion| ReferenceConversion { ty, conversion }) + .map(|(conversion, impls_deref)| ReferenceConversion { ty, conversion, impls_deref }) } -fn handle_copy(ty: &hir::Type, db: &dyn HirDatabase) -> Option { - ty.is_copy(db).then_some(ReferenceConversionType::Copy) +fn could_deref_to_target(ty: &hir::Type, target: &hir::Type, db: &dyn HirDatabase) -> bool { + let ty_ref = hir::Type::reference(ty, hir::Mutability::Shared); + let target_ref = hir::Type::reference(target, hir::Mutability::Shared); + ty_ref.could_coerce_to(db, &target_ref) +} + +fn handle_copy(ty: &hir::Type, db: &dyn HirDatabase) -> Option<(ReferenceConversionType, bool)> { + ty.is_copy(db).then_some((ReferenceConversionType::Copy, true)) } fn handle_as_ref_str( ty: &hir::Type, db: &dyn HirDatabase, famous_defs: &FamousDefs<'_, '_>, -) -> Option { +) -> Option<(ReferenceConversionType, bool)> { let str_type = hir::BuiltinType::str().ty(db); - ty.impls_trait(db, famous_defs.core_convert_AsRef()?, &[str_type]) - .then_some(ReferenceConversionType::AsRefStr) + ty.impls_trait(db, famous_defs.core_convert_AsRef()?, &[str_type.clone()]) + .then_some((ReferenceConversionType::AsRefStr, could_deref_to_target(ty, &str_type, db))) } fn handle_as_ref_slice( ty: &hir::Type, db: &dyn HirDatabase, famous_defs: &FamousDefs<'_, '_>, -) -> Option { +) -> Option<(ReferenceConversionType, bool)> { let type_argument = ty.type_arguments().next()?; let slice_type = hir::Type::new_slice(type_argument); - ty.impls_trait(db, famous_defs.core_convert_AsRef()?, &[slice_type]) - .then_some(ReferenceConversionType::AsRefSlice) + ty.impls_trait(db, famous_defs.core_convert_AsRef()?, &[slice_type.clone()]).then_some(( + ReferenceConversionType::AsRefSlice, + could_deref_to_target(ty, &slice_type, db), + )) } fn handle_dereferenced( ty: &hir::Type, db: &dyn HirDatabase, famous_defs: &FamousDefs<'_, '_>, -) -> Option { +) -> Option<(ReferenceConversionType, bool)> { let type_argument = ty.type_arguments().next()?; - ty.impls_trait(db, famous_defs.core_convert_AsRef()?, &[type_argument]) - .then_some(ReferenceConversionType::Dereferenced) + ty.impls_trait(db, famous_defs.core_convert_AsRef()?, &[type_argument.clone()]).then_some(( + ReferenceConversionType::Dereferenced, + could_deref_to_target(ty, &type_argument, db), + )) } fn handle_option_as_ref( ty: &hir::Type, db: &dyn HirDatabase, famous_defs: &FamousDefs<'_, '_>, -) -> Option { +) -> Option<(ReferenceConversionType, bool)> { if ty.as_adt() == famous_defs.core_option_Option()?.ty(db).as_adt() { - Some(ReferenceConversionType::Option) + Some((ReferenceConversionType::Option, false)) } else { None } @@ -732,9 +803,9 @@ fn handle_result_as_ref( ty: &hir::Type, db: &dyn HirDatabase, famous_defs: &FamousDefs<'_, '_>, -) -> Option { +) -> Option<(ReferenceConversionType, bool)> { if ty.as_adt() == famous_defs.core_result_Result()?.ty(db).as_adt() { - Some(ReferenceConversionType::Result) + Some((ReferenceConversionType::Result, false)) } else { None } diff --git a/crates/ide-assists/src/utils/suggest_name.rs b/crates/ide-assists/src/utils/suggest_name.rs index 78dee24a6d..1859825b3d 100644 --- a/crates/ide-assists/src/utils/suggest_name.rs +++ b/crates/ide-assists/src/utils/suggest_name.rs @@ -1,9 +1,7 @@ //! This module contains functions to suggest names for expressions, functions and other items -use std::collections::HashSet; - use hir::Semantics; -use ide_db::RootDatabase; +use ide_db::{FxHashSet, RootDatabase}; use itertools::Itertools; use stdx::to_lower_snake_case; use syntax::{ @@ -78,8 +76,8 @@ pub(crate) fn for_unique_generic_name( ast::GenericParam::TypeParam(t) => t.name().unwrap().to_string(), p => p.to_string(), }) - .collect::>(); - let mut name = name.to_string(); + .collect::>(); + let mut name = name.to_owned(); let base_len = name.len(); let mut count = 0; while param_names.contains(&name) { @@ -167,7 +165,7 @@ pub(crate) fn for_variable(expr: &ast::Expr, sema: &Semantics<'_, RootDatabase>) } } - "var_name".to_string() + "var_name".to_owned() } fn normalize(name: &str) -> Option { diff --git a/crates/ide-completion/src/completions/dot.rs b/crates/ide-completion/src/completions/dot.rs index 00135a6d20..24a1f9492e 100644 --- a/crates/ide-completion/src/completions/dot.rs +++ b/crates/ide-completion/src/completions/dot.rs @@ -4,7 +4,10 @@ use ide_db::FxHashSet; use syntax::SmolStr; use crate::{ - context::{CompletionContext, DotAccess, DotAccessKind, ExprCtx, PathCompletionCtx, Qualified}, + context::{ + CompletionContext, DotAccess, DotAccessExprCtx, DotAccessKind, PathCompletionCtx, + PathExprCtx, Qualified, + }, CompletionItem, CompletionItemKind, Completions, }; @@ -51,7 +54,7 @@ pub(crate) fn complete_undotted_self( acc: &mut Completions, ctx: &CompletionContext<'_>, path_ctx: &PathCompletionCtx, - expr_ctx: &ExprCtx, + expr_ctx: &PathExprCtx, ) { if !ctx.config.enable_self_on_the_fly { return; @@ -66,7 +69,7 @@ pub(crate) fn complete_undotted_self( return; } let self_param = match expr_ctx { - ExprCtx { self_param: Some(self_param), .. } => self_param, + PathExprCtx { self_param: Some(self_param), .. } => self_param, _ => return, }; @@ -82,6 +85,10 @@ pub(crate) fn complete_undotted_self( receiver: None, receiver_ty: None, kind: DotAccessKind::Field { receiver_is_ambiguous_float_literal: false }, + ctx: DotAccessExprCtx { + in_block_expr: expr_ctx.in_block_expr, + in_breakable: expr_ctx.in_breakable, + }, }, Some(hir::known::SELF_PARAM), field, @@ -99,6 +106,10 @@ pub(crate) fn complete_undotted_self( receiver: None, receiver_ty: None, kind: DotAccessKind::Method { has_parens: false }, + ctx: DotAccessExprCtx { + in_block_expr: expr_ctx.in_block_expr, + in_breakable: expr_ctx.in_breakable, + }, }, func, Some(hir::known::SELF_PARAM), diff --git a/crates/ide-completion/src/completions/expr.rs b/crates/ide-completion/src/completions/expr.rs index 1433216d61..77fd5dd98b 100644 --- a/crates/ide-completion/src/completions/expr.rs +++ b/crates/ide-completion/src/completions/expr.rs @@ -5,7 +5,7 @@ use syntax::ast; use crate::{ completions::record::add_default_update, - context::{ExprCtx, PathCompletionCtx, Qualified}, + context::{BreakableKind, PathCompletionCtx, PathExprCtx, Qualified}, CompletionContext, Completions, }; @@ -13,16 +13,16 @@ pub(crate) fn complete_expr_path( acc: &mut Completions, ctx: &CompletionContext<'_>, path_ctx @ PathCompletionCtx { qualified, .. }: &PathCompletionCtx, - expr_ctx: &ExprCtx, + expr_ctx: &PathExprCtx, ) { let _p = tracing::span!(tracing::Level::INFO, "complete_expr_path").entered(); if !ctx.qualifier_ctx.none() { return; } - let &ExprCtx { + let &PathExprCtx { in_block_expr, - in_loop_body, + in_breakable, after_if_expr, in_condition, incomplete_let, @@ -290,7 +290,7 @@ pub(crate) fn complete_expr_path( add_keyword("mut", "mut "); } - if in_loop_body { + if in_breakable != BreakableKind::None { if in_block_expr { add_keyword("continue", "continue;"); add_keyword("break", "break;"); diff --git a/crates/ide-completion/src/completions/extern_crate.rs b/crates/ide-completion/src/completions/extern_crate.rs index f9cde44667..b67d82c20d 100644 --- a/crates/ide-completion/src/completions/extern_crate.rs +++ b/crates/ide-completion/src/completions/extern_crate.rs @@ -46,7 +46,7 @@ mod other_mod {} let completion_list = completion_list_no_kw(case); - assert_eq!("md other_crate_a\n".to_string(), completion_list); + assert_eq!("md other_crate_a\n".to_owned(), completion_list); } #[test] @@ -66,6 +66,6 @@ mod other_mod {} let completion_list = completion_list_no_kw(case); - assert_eq!("md other_crate_a\n".to_string(), completion_list); + assert_eq!("md other_crate_a\n".to_owned(), completion_list); } } diff --git a/crates/ide-completion/src/completions/flyimport.rs b/crates/ide-completion/src/completions/flyimport.rs index 0e04ad35d3..b9f91d34b2 100644 --- a/crates/ide-completion/src/completions/flyimport.rs +++ b/crates/ide-completion/src/completions/flyimport.rs @@ -395,6 +395,14 @@ fn import_assets_for_path( potential_import_name: &str, qualifier: Option, ) -> Option { + let _p = tracing::span!( + tracing::Level::INFO, + "import_assets_for_path", + ?potential_import_name, + ?qualifier + ) + .entered(); + let fuzzy_name_length = potential_import_name.len(); let mut assets_for_path = ImportAssets::for_fuzzy_path( ctx.module, diff --git a/crates/ide-completion/src/completions/item_list.rs b/crates/ide-completion/src/completions/item_list.rs index addd9dac1a..0a6a8633a2 100644 --- a/crates/ide-completion/src/completions/item_list.rs +++ b/crates/ide-completion/src/completions/item_list.rs @@ -1,7 +1,7 @@ //! Completion of paths and keywords at item list position. use crate::{ - context::{ExprCtx, ItemListKind, PathCompletionCtx, Qualified}, + context::{ItemListKind, PathCompletionCtx, PathExprCtx, Qualified}, CompletionContext, Completions, }; @@ -11,7 +11,7 @@ pub(crate) fn complete_item_list_in_expr( acc: &mut Completions, ctx: &CompletionContext<'_>, path_ctx: &PathCompletionCtx, - expr_ctx: &ExprCtx, + expr_ctx: &PathExprCtx, ) { if !expr_ctx.in_block_expr { return; diff --git a/crates/ide-completion/src/completions/keyword.rs b/crates/ide-completion/src/completions/keyword.rs index b9ab2afca2..ed32a5db23 100644 --- a/crates/ide-completion/src/completions/keyword.rs +++ b/crates/ide-completion/src/completions/keyword.rs @@ -81,11 +81,13 @@ fn foo(a: A) { a.$0 } sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn let let sn letm let mut sn match match expr {} sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} "#]], ); @@ -106,11 +108,13 @@ fn foo() { sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn let let sn letm let mut sn match match expr {} sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} "#]], ); @@ -133,11 +137,13 @@ fn foo(a: A) { a.$0 } sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn let let sn letm let mut sn match match expr {} sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} "#]], ); diff --git a/crates/ide-completion/src/completions/postfix.rs b/crates/ide-completion/src/completions/postfix.rs index af83d4104f..72c0885e92 100644 --- a/crates/ide-completion/src/completions/postfix.rs +++ b/crates/ide-completion/src/completions/postfix.rs @@ -2,6 +2,7 @@ mod format_like; +use hir::ItemInNs; use ide_db::{ documentation::{Documentation, HasDocs}, imports::insert_use::ImportScope, @@ -17,7 +18,7 @@ use text_edit::TextEdit; use crate::{ completions::postfix::format_like::add_format_like_completions, - context::{CompletionContext, DotAccess, DotAccessKind}, + context::{BreakableKind, CompletionContext, DotAccess, DotAccessKind}, item::{Builder, CompletionRelevancePostfixMatch}, CompletionItem, CompletionItemKind, CompletionRelevance, Completions, SnippetScope, }; @@ -44,6 +45,7 @@ pub(crate) fn complete_postfix( ), _ => return, }; + let expr_ctx = &dot_access.ctx; let receiver_text = get_receiver_text(dot_receiver, receiver_is_ambiguous_float_literal); @@ -59,16 +61,22 @@ pub(crate) fn complete_postfix( if let Some(drop_trait) = ctx.famous_defs().core_ops_Drop() { if receiver_ty.impls_trait(ctx.db, drop_trait, &[]) { - if let &[hir::AssocItem::Function(drop_fn)] = &*drop_trait.items(ctx.db) { - cov_mark::hit!(postfix_drop_completion); - // FIXME: check that `drop` is in scope, use fully qualified path if it isn't/if shadowed - let mut item = postfix_snippet( - "drop", - "fn drop(&mut self)", - &format!("drop($0{receiver_text})"), - ); - item.set_documentation(drop_fn.docs(ctx.db)); - item.add_to(acc, ctx.db); + if let Some(drop_fn) = ctx.famous_defs().core_mem_drop() { + if let Some(path) = ctx.module.find_use_path( + ctx.db, + ItemInNs::Values(drop_fn.into()), + ctx.config.prefer_no_std, + ctx.config.prefer_prelude, + ) { + cov_mark::hit!(postfix_drop_completion); + let mut item = postfix_snippet( + "drop", + "fn drop(&mut self)", + &format!("{path}($0{receiver_text})", path = path.display(ctx.db)), + ); + item.set_documentation(drop_fn.docs(ctx.db)); + item.add_to(acc, ctx.db); + } } } } @@ -140,6 +148,7 @@ pub(crate) fn complete_postfix( postfix_snippet("ref", "&expr", &format!("&{receiver_text}")).add_to(acc, ctx.db); postfix_snippet("refm", "&mut expr", &format!("&mut {receiver_text}")).add_to(acc, ctx.db); + postfix_snippet("deref", "*expr", &format!("*{receiver_text}")).add_to(acc, ctx.db); let mut unsafe_should_be_wrapped = true; if dot_receiver.syntax().kind() == BLOCK_EXPR { @@ -224,6 +233,28 @@ pub(crate) fn complete_postfix( add_format_like_completions(acc, ctx, &dot_receiver, cap, &literal_text); } } + + postfix_snippet( + "return", + "return expr", + &format!( + "return {receiver_text}{semi}", + semi = if expr_ctx.in_block_expr { ";" } else { "" } + ), + ) + .add_to(acc, ctx.db); + + if let BreakableKind::Block | BreakableKind::Loop = expr_ctx.in_breakable { + postfix_snippet( + "break", + "break expr", + &format!( + "break {receiver_text}{semi}", + semi = if expr_ctx.in_block_expr { ";" } else { "" } + ), + ) + .add_to(acc, ctx.db); + } } fn get_receiver_text(receiver: &ast::Expr, receiver_is_ambiguous_float_literal: bool) -> String { @@ -295,7 +326,7 @@ fn build_postfix_snippet_builder<'ctx>( delete_range: TextRange, ) -> impl Fn(&str, &str, &str) -> Builder + 'ctx { move |label, detail, snippet| { - let edit = TextEdit::replace(delete_range, snippet.to_string()); + let edit = TextEdit::replace(delete_range, snippet.to_owned()); let mut item = CompletionItem::new(CompletionItemKind::Snippet, ctx.source_range(), label); item.detail(detail).snippet_edit(cap, edit); @@ -368,6 +399,7 @@ fn main() { sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn if if expr {} sn let let sn letm let mut @@ -375,6 +407,7 @@ fn main() { sn not !expr sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} sn while while expr {} "#]], @@ -399,11 +432,13 @@ fn main() { sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn if if expr {} sn match match expr {} sn not !expr sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} sn while while expr {} "#]], @@ -424,11 +459,13 @@ fn main() { sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn let let sn letm let mut sn match match expr {} sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} "#]], ) @@ -448,6 +485,7 @@ fn main() { sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn if if expr {} sn let let sn letm let mut @@ -455,6 +493,7 @@ fn main() { sn not !expr sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} sn while while expr {} "#]], diff --git a/crates/ide-completion/src/completions/record.rs b/crates/ide-completion/src/completions/record.rs index e53d1cc632..1dcf41f8dd 100644 --- a/crates/ide-completion/src/completions/record.rs +++ b/crates/ide-completion/src/completions/record.rs @@ -6,7 +6,7 @@ use syntax::{ }; use crate::{ - context::{DotAccess, DotAccessKind, PatternContext}, + context::{DotAccess, DotAccessExprCtx, DotAccessKind, PatternContext}, CompletionContext, CompletionItem, CompletionItemKind, CompletionRelevance, CompletionRelevancePostfixMatch, Completions, }; @@ -118,12 +118,17 @@ fn complete_fields( missing_fields: Vec<(hir::Field, hir::Type)>, ) { for (field, ty) in missing_fields { + // This should call something else, we shouldn't be synthesizing a DotAccess here acc.add_field( ctx, &DotAccess { receiver: None, receiver_ty: None, kind: DotAccessKind::Field { receiver_is_ambiguous_float_literal: false }, + ctx: DotAccessExprCtx { + in_block_expr: false, + in_breakable: crate::context::BreakableKind::None, + }, }, None, field, diff --git a/crates/ide-completion/src/completions/snippet.rs b/crates/ide-completion/src/completions/snippet.rs index a019192205..e831113350 100644 --- a/crates/ide-completion/src/completions/snippet.rs +++ b/crates/ide-completion/src/completions/snippet.rs @@ -3,7 +3,7 @@ use ide_db::{documentation::Documentation, imports::insert_use::ImportScope, SnippetCap}; use crate::{ - context::{ExprCtx, ItemListKind, PathCompletionCtx, Qualified}, + context::{ItemListKind, PathCompletionCtx, PathExprCtx, Qualified}, item::Builder, CompletionContext, CompletionItem, CompletionItemKind, Completions, SnippetScope, }; @@ -12,7 +12,7 @@ pub(crate) fn complete_expr_snippet( acc: &mut Completions, ctx: &CompletionContext<'_>, path_ctx: &PathCompletionCtx, - &ExprCtx { in_block_expr, .. }: &ExprCtx, + &PathExprCtx { in_block_expr, .. }: &PathExprCtx, ) { if !matches!(path_ctx.qualified, Qualified::No) { return; diff --git a/crates/ide-completion/src/context.rs b/crates/ide-completion/src/context.rs index 2c0370c58f..2a0004f60b 100644 --- a/crates/ide-completion/src/context.rs +++ b/crates/ide-completion/src/context.rs @@ -45,13 +45,13 @@ pub(crate) enum Visible { /// Existing qualifiers for the thing we are currently completing. #[derive(Debug, Default)] -pub(super) struct QualifierCtx { - pub(super) unsafe_tok: Option, - pub(super) vis_node: Option, +pub(crate) struct QualifierCtx { + pub(crate) unsafe_tok: Option, + pub(crate) vis_node: Option, } impl QualifierCtx { - pub(super) fn none(&self) -> bool { + pub(crate) fn none(&self) -> bool { self.unsafe_tok.is_none() && self.vis_node.is_none() } } @@ -60,27 +60,27 @@ impl QualifierCtx { #[derive(Debug)] pub(crate) struct PathCompletionCtx { /// If this is a call with () already there (or {} in case of record patterns) - pub(super) has_call_parens: bool, + pub(crate) has_call_parens: bool, /// If this has a macro call bang ! - pub(super) has_macro_bang: bool, + pub(crate) has_macro_bang: bool, /// The qualifier of the current path. - pub(super) qualified: Qualified, + pub(crate) qualified: Qualified, /// The parent of the path we are completing. - pub(super) parent: Option, + pub(crate) parent: Option, #[allow(dead_code)] /// The path of which we are completing the segment - pub(super) path: ast::Path, + pub(crate) path: ast::Path, /// The path of which we are completing the segment in the original file pub(crate) original_path: Option, - pub(super) kind: PathKind, + pub(crate) kind: PathKind, /// Whether the path segment has type args or not. - pub(super) has_type_args: bool, + pub(crate) has_type_args: bool, /// Whether the qualifier comes from a use tree parent or not pub(crate) use_tree_parent: bool, } impl PathCompletionCtx { - pub(super) fn is_trivial_path(&self) -> bool { + pub(crate) fn is_trivial_path(&self) -> bool { matches!( self, PathCompletionCtx { @@ -97,9 +97,9 @@ impl PathCompletionCtx { /// The kind of path we are completing right now. #[derive(Debug, PartialEq, Eq)] -pub(super) enum PathKind { +pub(crate) enum PathKind { Expr { - expr_ctx: ExprCtx, + expr_ctx: PathExprCtx, }, Type { location: TypeLocation, @@ -132,9 +132,9 @@ pub(crate) struct AttrCtx { } #[derive(Debug, PartialEq, Eq)] -pub(crate) struct ExprCtx { +pub(crate) struct PathExprCtx { pub(crate) in_block_expr: bool, - pub(crate) in_loop_body: bool, + pub(crate) in_breakable: BreakableKind, pub(crate) after_if_expr: bool, /// Whether this expression is the direct condition of an if or while expression pub(crate) in_condition: bool, @@ -221,7 +221,7 @@ pub(crate) enum TypeAscriptionTarget { /// The kind of item list a [`PathKind::Item`] belongs to. #[derive(Debug, PartialEq, Eq)] -pub(super) enum ItemListKind { +pub(crate) enum ItemListKind { SourceFile, Module, Impl, @@ -231,7 +231,7 @@ pub(super) enum ItemListKind { } #[derive(Debug)] -pub(super) enum Qualified { +pub(crate) enum Qualified { No, With { path: ast::Path, @@ -259,37 +259,37 @@ pub(super) enum Qualified { /// The state of the pattern we are completing. #[derive(Debug, Clone, PartialEq, Eq)] -pub(super) struct PatternContext { - pub(super) refutability: PatternRefutability, - pub(super) param_ctx: Option, - pub(super) has_type_ascription: bool, - pub(super) parent_pat: Option, - pub(super) ref_token: Option, - pub(super) mut_token: Option, +pub(crate) struct PatternContext { + pub(crate) refutability: PatternRefutability, + pub(crate) param_ctx: Option, + pub(crate) has_type_ascription: bool, + pub(crate) parent_pat: Option, + pub(crate) ref_token: Option, + pub(crate) mut_token: Option, /// The record pattern this name or ref is a field of - pub(super) record_pat: Option, - pub(super) impl_: Option, + pub(crate) record_pat: Option, + pub(crate) impl_: Option, /// List of missing variants in a match expr - pub(super) missing_variants: Vec, + pub(crate) missing_variants: Vec, } #[derive(Debug, Clone, PartialEq, Eq)] -pub(super) struct ParamContext { - pub(super) param_list: ast::ParamList, - pub(super) param: ast::Param, - pub(super) kind: ParamKind, +pub(crate) struct ParamContext { + pub(crate) param_list: ast::ParamList, + pub(crate) param: ast::Param, + pub(crate) kind: ParamKind, } /// The state of the lifetime we are completing. #[derive(Debug)] -pub(super) struct LifetimeContext { - pub(super) lifetime: Option, - pub(super) kind: LifetimeKind, +pub(crate) struct LifetimeContext { + pub(crate) lifetime: Option, + pub(crate) kind: LifetimeKind, } /// The kind of lifetime we are completing. #[derive(Debug)] -pub(super) enum LifetimeKind { +pub(crate) enum LifetimeKind { LifetimeParam { is_decl: bool, param: ast::LifetimeParam }, Lifetime, LabelRef, @@ -298,16 +298,16 @@ pub(super) enum LifetimeKind { /// The state of the name we are completing. #[derive(Debug)] -pub(super) struct NameContext { +pub(crate) struct NameContext { #[allow(dead_code)] - pub(super) name: Option, - pub(super) kind: NameKind, + pub(crate) name: Option, + pub(crate) kind: NameKind, } /// The kind of the name we are completing. #[derive(Debug)] #[allow(dead_code)] -pub(super) enum NameKind { +pub(crate) enum NameKind { Const, ConstParam, Enum, @@ -331,15 +331,15 @@ pub(super) enum NameKind { /// The state of the NameRef we are completing. #[derive(Debug)] -pub(super) struct NameRefContext { +pub(crate) struct NameRefContext { /// NameRef syntax in the original file - pub(super) nameref: Option, - pub(super) kind: NameRefKind, + pub(crate) nameref: Option, + pub(crate) kind: NameRefKind, } /// The kind of the NameRef we are completing. #[derive(Debug)] -pub(super) enum NameRefKind { +pub(crate) enum NameRefKind { Path(PathCompletionCtx), DotAccess(DotAccess), /// Position where we are only interested in keyword completions @@ -355,7 +355,7 @@ pub(super) enum NameRefKind { /// The identifier we are currently completing. #[derive(Debug)] -pub(super) enum CompletionAnalysis { +pub(crate) enum CompletionAnalysis { Name(NameContext), NameRef(NameRefContext), Lifetime(LifetimeContext), @@ -376,14 +376,15 @@ pub(super) enum CompletionAnalysis { /// Information about the field or method access we are completing. #[derive(Debug)] -pub(super) struct DotAccess { - pub(super) receiver: Option, - pub(super) receiver_ty: Option, - pub(super) kind: DotAccessKind, +pub(crate) struct DotAccess { + pub(crate) receiver: Option, + pub(crate) receiver_ty: Option, + pub(crate) kind: DotAccessKind, + pub(crate) ctx: DotAccessExprCtx, } #[derive(Debug)] -pub(super) enum DotAccessKind { +pub(crate) enum DotAccessKind { Field { /// True if the receiver is an integer and there is no ident in the original file after it yet /// like `0.$0` @@ -394,6 +395,21 @@ pub(super) enum DotAccessKind { }, } +#[derive(Debug, PartialEq, Eq)] +pub(crate) struct DotAccessExprCtx { + pub(crate) in_block_expr: bool, + pub(crate) in_breakable: BreakableKind, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub(crate) enum BreakableKind { + None, + Loop, + For, + While, + Block, +} + #[derive(Clone, Debug, PartialEq, Eq)] pub(crate) enum ParamKind { Function(ast::Fn), @@ -404,39 +420,39 @@ pub(crate) enum ParamKind { /// exactly is the cursor, syntax-wise. #[derive(Debug)] pub(crate) struct CompletionContext<'a> { - pub(super) sema: Semantics<'a, RootDatabase>, - pub(super) scope: SemanticsScope<'a>, - pub(super) db: &'a RootDatabase, - pub(super) config: &'a CompletionConfig, - pub(super) position: FilePosition, + pub(crate) sema: Semantics<'a, RootDatabase>, + pub(crate) scope: SemanticsScope<'a>, + pub(crate) db: &'a RootDatabase, + pub(crate) config: &'a CompletionConfig, + pub(crate) position: FilePosition, /// The token before the cursor, in the original file. - pub(super) original_token: SyntaxToken, + pub(crate) original_token: SyntaxToken, /// The token before the cursor, in the macro-expanded file. - pub(super) token: SyntaxToken, + pub(crate) token: SyntaxToken, /// The crate of the current file. - pub(super) krate: hir::Crate, + pub(crate) krate: hir::Crate, /// The module of the `scope`. - pub(super) module: hir::Module, + pub(crate) module: hir::Module, /// Whether nightly toolchain is used. Cached since this is looked up a lot. is_nightly: bool, /// The expected name of what we are completing. /// This is usually the parameter name of the function argument we are completing. - pub(super) expected_name: Option, + pub(crate) expected_name: Option, /// The expected type of what we are completing. - pub(super) expected_type: Option, + pub(crate) expected_type: Option, - pub(super) qualifier_ctx: QualifierCtx, + pub(crate) qualifier_ctx: QualifierCtx, - pub(super) locals: FxHashMap, + pub(crate) locals: FxHashMap, /// The module depth of the current module of the cursor position. /// - crate-root /// - mod foo /// - mod bar /// Here depth will be 2 - pub(super) depth_from_crate_root: usize, + pub(crate) depth_from_crate_root: usize, } impl CompletionContext<'_> { @@ -634,7 +650,7 @@ impl CompletionContext<'_> { // CompletionContext construction impl<'a> CompletionContext<'a> { - pub(super) fn new( + pub(crate) fn new( db: &'a RootDatabase, position @ FilePosition { file_id, offset }: FilePosition, config: &'a CompletionConfig, @@ -649,7 +665,7 @@ impl<'a> CompletionContext<'a> { // actual completion. let file_with_fake_ident = { let parse = db.parse(file_id); - let edit = Indel::insert(offset, COMPLETION_MARKER.to_string()); + let edit = Indel::insert(offset, COMPLETION_MARKER.to_owned()); parse.reparse(&edit).tree() }; diff --git a/crates/ide-completion/src/context/analysis.rs b/crates/ide-completion/src/context/analysis.rs index c06b64df1c..92af688977 100644 --- a/crates/ide-completion/src/context/analysis.rs +++ b/crates/ide-completion/src/context/analysis.rs @@ -11,10 +11,11 @@ use syntax::{ }; use crate::context::{ - AttrCtx, CompletionAnalysis, DotAccess, DotAccessKind, ExprCtx, ItemListKind, LifetimeContext, - LifetimeKind, NameContext, NameKind, NameRefContext, NameRefKind, ParamContext, ParamKind, - PathCompletionCtx, PathKind, PatternContext, PatternRefutability, Qualified, QualifierCtx, - TypeAscriptionTarget, TypeLocation, COMPLETION_MARKER, + AttrCtx, BreakableKind, CompletionAnalysis, DotAccess, DotAccessExprCtx, DotAccessKind, + ItemListKind, LifetimeContext, LifetimeKind, NameContext, NameKind, NameRefContext, + NameRefKind, ParamContext, ParamKind, PathCompletionCtx, PathExprCtx, PathKind, PatternContext, + PatternRefutability, Qualified, QualifierCtx, TypeAscriptionTarget, TypeLocation, + COMPLETION_MARKER, }; struct ExpansionResult { @@ -623,7 +624,8 @@ fn classify_name_ref( let kind = NameRefKind::DotAccess(DotAccess { receiver_ty: receiver.as_ref().and_then(|it| sema.type_of_expr(it)), kind: DotAccessKind::Field { receiver_is_ambiguous_float_literal }, - receiver + receiver, + ctx: DotAccessExprCtx { in_block_expr: is_in_block(field.syntax()), in_breakable: is_in_breakable(field.syntax()) } }); return Some(make_res(kind)); }, @@ -636,7 +638,8 @@ fn classify_name_ref( let kind = NameRefKind::DotAccess(DotAccess { receiver_ty: receiver.as_ref().and_then(|it| sema.type_of_expr(it)), kind: DotAccessKind::Method { has_parens: method.arg_list().map_or(false, |it| it.l_paren_token().is_some()) }, - receiver + receiver, + ctx: DotAccessExprCtx { in_block_expr: is_in_block(method.syntax()), in_breakable: is_in_breakable(method.syntax()) } }); return Some(make_res(kind)); }, @@ -659,13 +662,6 @@ fn classify_name_ref( use_tree_parent: false, }; - let is_in_block = |it: &SyntaxNode| { - it.parent() - .map(|node| { - ast::ExprStmt::can_cast(node.kind()) || ast::StmtList::can_cast(node.kind()) - }) - .unwrap_or(false) - }; let func_update_record = |syn: &SyntaxNode| { if let Some(record_expr) = syn.ancestors().nth(2).and_then(ast::RecordExpr::cast) { find_node_in_file_compensated(sema, original_file, &record_expr) @@ -932,7 +928,7 @@ fn classify_name_ref( let make_path_kind_expr = |expr: ast::Expr| { let it = expr.syntax(); let in_block_expr = is_in_block(it); - let in_loop_body = is_in_loop_body(it); + let in_loop_body = is_in_breakable(it); let after_if_expr = after_if_expr(it.clone()); let ref_expr_parent = path.as_single_name_ref().and_then(|_| it.parent()).and_then(ast::RefExpr::cast); @@ -998,9 +994,9 @@ fn classify_name_ref( }; PathKind::Expr { - expr_ctx: ExprCtx { + expr_ctx: PathExprCtx { in_block_expr, - in_loop_body, + in_breakable: in_loop_body, after_if_expr, in_condition, ref_expr_parent, @@ -1202,7 +1198,7 @@ fn classify_name_ref( if path_ctx.is_trivial_path() { // fetch the full expression that may have qualifiers attached to it let top_node = match path_ctx.kind { - PathKind::Expr { expr_ctx: ExprCtx { in_block_expr: true, .. } } => { + PathKind::Expr { expr_ctx: PathExprCtx { in_block_expr: true, .. } } => { parent.ancestors().find(|it| ast::PathExpr::can_cast(it.kind())).and_then(|p| { let parent = p.parent()?; if ast::StmtList::can_cast(parent.kind()) { @@ -1467,21 +1463,30 @@ fn is_in_token_of_for_loop(path: &ast::Path) -> bool { .unwrap_or(false) } -fn is_in_loop_body(node: &SyntaxNode) -> bool { +fn is_in_breakable(node: &SyntaxNode) -> BreakableKind { node.ancestors() .take_while(|it| it.kind() != SyntaxKind::FN && it.kind() != SyntaxKind::CLOSURE_EXPR) .find_map(|it| { - let loop_body = match_ast! { + let (breakable, loop_body) = match_ast! { match it { - ast::ForExpr(it) => it.loop_body(), - ast::WhileExpr(it) => it.loop_body(), - ast::LoopExpr(it) => it.loop_body(), - _ => None, + ast::ForExpr(it) => (BreakableKind::For, it.loop_body()), + ast::WhileExpr(it) => (BreakableKind::While, it.loop_body()), + ast::LoopExpr(it) => (BreakableKind::Loop, it.loop_body()), + ast::BlockExpr(it) => return it.label().map(|_| BreakableKind::Block), + _ => return None, } }; - loop_body.filter(|it| it.syntax().text_range().contains_range(node.text_range())) + loop_body + .filter(|it| it.syntax().text_range().contains_range(node.text_range())) + .map(|_| breakable) }) - .is_some() + .unwrap_or(BreakableKind::None) +} + +fn is_in_block(node: &SyntaxNode) -> bool { + node.parent() + .map(|node| ast::ExprStmt::can_cast(node.kind()) || ast::StmtList::can_cast(node.kind())) + .unwrap_or(false) } fn previous_non_trivia_token(e: impl Into) -> Option { diff --git a/crates/ide-completion/src/item.rs b/crates/ide-completion/src/item.rs index bcf169f465..8552a20392 100644 --- a/crates/ide-completion/src/item.rs +++ b/crates/ide-completion/src/item.rs @@ -553,7 +553,7 @@ impl Builder { self.detail = detail.map(Into::into); if let Some(detail) = &self.detail { if never!(detail.contains('\n'), "multiline detail:\n{}", detail) { - self.detail = Some(detail.split('\n').next().unwrap().to_string()); + self.detail = Some(detail.split('\n').next().unwrap().to_owned()); } } self diff --git a/crates/ide-completion/src/lib.rs b/crates/ide-completion/src/lib.rs index 733523d369..912f2fba2b 100644 --- a/crates/ide-completion/src/lib.rs +++ b/crates/ide-completion/src/lib.rs @@ -64,6 +64,7 @@ pub use crate::{ // - `expr.ref` -> `&expr` // - `expr.refm` -> `&mut expr` // - `expr.let` -> `let $0 = expr;` +// - `expr.lete` -> `let $1 = expr else { $0 };` // - `expr.letm` -> `let mut $0 = expr;` // - `expr.not` -> `!expr` // - `expr.dbg` -> `dbg!(expr)` diff --git a/crates/ide-completion/src/render.rs b/crates/ide-completion/src/render.rs index 4d49d2f498..2ed080a834 100644 --- a/crates/ide-completion/src/render.rs +++ b/crates/ide-completion/src/render.rs @@ -112,7 +112,7 @@ impl<'a> RenderContext<'a> { }; is_assoc_deprecated || assoc - .containing_trait_or_trait_impl(db) + .container_or_implemented_trait(db) .map(|trait_| self.is_deprecated(trait_)) .unwrap_or(false) } @@ -167,14 +167,14 @@ pub(crate) fn render_field( if !expected_fn_type { if let Some(receiver) = &dot_access.receiver { if let Some(receiver) = ctx.completion.sema.original_ast_node(receiver.clone()) { - builder.insert(receiver.syntax().text_range().start(), "(".to_string()); - builder.insert(ctx.source_range().end(), ")".to_string()); + builder.insert(receiver.syntax().text_range().start(), "(".to_owned()); + builder.insert(ctx.source_range().end(), ")".to_owned()); let is_parens_needed = !matches!(dot_access.kind, DotAccessKind::Method { has_parens: true }); if is_parens_needed { - builder.insert(ctx.source_range().end(), "()".to_string()); + builder.insert(ctx.source_range().end(), "()".to_owned()); } } } @@ -2199,12 +2199,14 @@ fn main() { sn while [] sn ref [] sn refm [] + sn deref [] sn unsafe [] sn match [] sn box [] sn dbg [] sn dbgr [] sn call [] + sn return [] "#]], ); } @@ -2227,6 +2229,7 @@ fn main() { me f() [] sn ref [] sn refm [] + sn deref [] sn unsafe [] sn match [] sn box [] @@ -2235,6 +2238,7 @@ fn main() { sn call [] sn let [] sn letm [] + sn return [] "#]], ); } diff --git a/crates/ide-completion/src/render/const_.rs b/crates/ide-completion/src/render/const_.rs index a2bfac994f..0d24882156 100644 --- a/crates/ide-completion/src/render/const_.rs +++ b/crates/ide-completion/src/render/const_.rs @@ -23,7 +23,7 @@ fn render(ctx: RenderContext<'_>, const_: hir::Const) -> Option .set_relevance(ctx.completion_relevance()); if let Some(actm) = const_.as_assoc_item(db) { - if let Some(trt) = actm.containing_trait_or_trait_impl(db) { + if let Some(trt) = actm.container_or_implemented_trait(db) { item.trait_name(trt.name(db).to_smol_str()); } } diff --git a/crates/ide-completion/src/render/function.rs b/crates/ide-completion/src/render/function.rs index 4ae7ea861c..27186a2b7f 100644 --- a/crates/ide-completion/src/render/function.rs +++ b/crates/ide-completion/src/render/function.rs @@ -75,7 +75,7 @@ fn render( let ret_type = func.ret_type(db); let assoc_item = func.as_assoc_item(db); - let trait_ = assoc_item.and_then(|trait_| trait_.containing_trait_or_trait_impl(db)); + let trait_ = assoc_item.and_then(|trait_| trait_.container_or_implemented_trait(db)); let is_op_method = trait_.map_or(false, |trait_| completion.is_ops_trait(trait_)); let is_item_from_notable_trait = @@ -145,7 +145,7 @@ fn render( } None => { if let Some(actm) = assoc_item { - if let Some(trt) = actm.containing_trait_or_trait_impl(db) { + if let Some(trt) = actm.container_or_implemented_trait(db) { item.trait_name(trt.name(db).to_smol_str()); } } @@ -184,12 +184,12 @@ pub(super) fn add_call_parens<'b>( } None => { let name = match param.ty().as_adt() { - None => "_".to_string(), + None => "_".to_owned(), Some(adt) => adt .name(ctx.db) .as_text() .map(|s| to_lower_snake_case(s.as_str())) - .unwrap_or_else(|| "_".to_string()), + .unwrap_or_else(|| "_".to_owned()), }; f(&format_args!("${{{}:{name}}}", index + offset)) } diff --git a/crates/ide-completion/src/render/pattern.rs b/crates/ide-completion/src/render/pattern.rs index a5f851566c..c07966f7a7 100644 --- a/crates/ide-completion/src/render/pattern.rs +++ b/crates/ide-completion/src/render/pattern.rs @@ -140,7 +140,7 @@ fn render_pat( StructKind::Record => { render_record_as_pat(ctx.db(), ctx.snippet_cap(), fields, name, fields_omitted) } - StructKind::Unit => name.to_string(), + StructKind::Unit => name.to_owned(), }; let needs_ascription = matches!( diff --git a/crates/ide-completion/src/render/type_alias.rs b/crates/ide-completion/src/render/type_alias.rs index b192309e93..8f80793dd7 100644 --- a/crates/ide-completion/src/render/type_alias.rs +++ b/crates/ide-completion/src/render/type_alias.rs @@ -47,7 +47,7 @@ fn render( .set_relevance(ctx.completion_relevance()); if let Some(actm) = type_alias.as_assoc_item(db) { - if let Some(trt) = actm.containing_trait_or_trait_impl(db) { + if let Some(trt) = actm.container_or_implemented_trait(db) { item.trait_name(trt.name(db).to_smol_str()); } } diff --git a/crates/ide-completion/src/render/variant.rs b/crates/ide-completion/src/render/variant.rs index a9a01a3a30..28238de455 100644 --- a/crates/ide-completion/src/render/variant.rs +++ b/crates/ide-completion/src/render/variant.rs @@ -23,7 +23,7 @@ pub(crate) fn render_record_lit( path: &str, ) -> RenderedLiteral { if snippet_cap.is_none() { - return RenderedLiteral { literal: path.to_string(), detail: path.to_string() }; + return RenderedLiteral { literal: path.to_owned(), detail: path.to_owned() }; } let completions = fields.iter().enumerate().format_with(", ", |(idx, field), f| { if snippet_cap.is_some() { @@ -52,7 +52,7 @@ pub(crate) fn render_tuple_lit( path: &str, ) -> RenderedLiteral { if snippet_cap.is_none() { - return RenderedLiteral { literal: path.to_string(), detail: path.to_string() }; + return RenderedLiteral { literal: path.to_owned(), detail: path.to_owned() }; } let completions = fields.iter().enumerate().format_with(", ", |(idx, _), f| { if snippet_cap.is_some() { diff --git a/crates/ide-completion/src/tests/expression.rs b/crates/ide-completion/src/tests/expression.rs index 758c254a88..78907a2896 100644 --- a/crates/ide-completion/src/tests/expression.rs +++ b/crates/ide-completion/src/tests/expression.rs @@ -362,6 +362,27 @@ fn completes_in_loop_ctx() { sn ppd "#]], ); + check_empty( + r"fn my() { loop { foo.$0 } }", + expect![[r#" + sn box Box::new(expr) + sn break break expr + sn call function(expr) + sn dbg dbg!(expr) + sn dbgr dbg!(&expr) + sn deref *expr + sn if if expr {} + sn let let + sn letm let mut + sn match match expr {} + sn not !expr + sn ref &expr + sn refm &mut expr + sn return return expr + sn unsafe unsafe {} + sn while while expr {} + "#]], + ); } #[test] @@ -1115,9 +1136,11 @@ fn main() { sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn match match expr {} sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} "#]], ); @@ -1139,9 +1162,11 @@ fn main() { sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn match match expr {} sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} "#]], ); @@ -1167,9 +1192,11 @@ fn main() { sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn match match expr {} sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} "#]], ); @@ -1191,9 +1218,11 @@ fn main() { sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn match match expr {} sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} "#]], ); @@ -1215,9 +1244,11 @@ fn main() { sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn match match expr {} sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} "#]], ); @@ -1238,11 +1269,13 @@ fn main() { sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn if if expr {} sn match match expr {} sn not !expr sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} sn while while expr {} "#]], diff --git a/crates/ide-completion/src/tests/proc_macros.rs b/crates/ide-completion/src/tests/proc_macros.rs index 2d6234e310..613f33309f 100644 --- a/crates/ide-completion/src/tests/proc_macros.rs +++ b/crates/ide-completion/src/tests/proc_macros.rs @@ -29,11 +29,13 @@ fn main() { sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn let let sn letm let mut sn match match expr {} sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} "#]], ) @@ -60,11 +62,13 @@ fn main() { sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn let let sn letm let mut sn match match expr {} sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} "#]], ) @@ -93,11 +97,13 @@ fn main() {} sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn let let sn letm let mut sn match match expr {} sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} "#]], ) @@ -126,11 +132,13 @@ fn main() {} sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn let let sn letm let mut sn match match expr {} sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} "#]], ) diff --git a/crates/ide-completion/src/tests/special.rs b/crates/ide-completion/src/tests/special.rs index f96fb71f28..a87d16c789 100644 --- a/crates/ide-completion/src/tests/special.rs +++ b/crates/ide-completion/src/tests/special.rs @@ -1157,11 +1157,13 @@ fn here_we_go() { sn call function(expr) sn dbg dbg!(expr) sn dbgr dbg!(&expr) + sn deref *expr sn let let sn letm let mut sn match match expr {} sn ref &expr sn refm &mut expr + sn return return expr sn unsafe unsafe {} "#]], ); diff --git a/crates/ide-db/src/defs.rs b/crates/ide-db/src/defs.rs index 81f2f87d96..d95d94ec72 100644 --- a/crates/ide-db/src/defs.rs +++ b/crates/ide-db/src/defs.rs @@ -206,7 +206,7 @@ impl Definition { // docs are missing, for assoc items of trait impls try to fall back to the docs of the // original item of the trait let assoc = self.as_assoc_item(db)?; - let trait_ = assoc.containing_trait_impl(db)?; + let trait_ = assoc.implemented_trait(db)?; let name = Some(assoc.name(db)?); let item = trait_.items(db).into_iter().find(|it| it.name(db) == name)?; item.docs(db) diff --git a/crates/ide-db/src/famous_defs.rs b/crates/ide-db/src/famous_defs.rs index 722517a767..4edfa37b32 100644 --- a/crates/ide-db/src/famous_defs.rs +++ b/crates/ide-db/src/famous_defs.rs @@ -1,7 +1,7 @@ //! See [`FamousDefs`]. use base_db::{CrateOrigin, LangCrateOrigin, SourceDatabase}; -use hir::{Crate, Enum, Macro, Module, ScopeDef, Semantics, Trait}; +use hir::{Crate, Enum, Function, Macro, Module, ScopeDef, Semantics, Trait}; use crate::RootDatabase; @@ -110,6 +110,10 @@ impl FamousDefs<'_, '_> { self.find_macro("core:macros:builtin:derive") } + pub fn core_mem_drop(&self) -> Option { + self.find_function("core:mem:drop") + } + pub fn builtin_crates(&self) -> impl Iterator { IntoIterator::into_iter([ self.std(), @@ -149,6 +153,13 @@ impl FamousDefs<'_, '_> { } } + fn find_function(&self, path: &str) -> Option { + match self.find_def(path)? { + hir::ScopeDef::ModuleDef(hir::ModuleDef::Function(it)) => Some(it), + _ => None, + } + } + fn find_lang_crate(&self, origin: LangCrateOrigin) -> Option { let krate = self.1; let db = self.0.db; diff --git a/crates/ide-db/src/generated/lints.rs b/crates/ide-db/src/generated/lints.rs index 677c8fd54c..2fc0793320 100644 --- a/crates/ide-db/src/generated/lints.rs +++ b/crates/ide-db/src/generated/lints.rs @@ -5044,17 +5044,6 @@ against are compatible with those of the `#[ffi_pure]`. [ARM C/C++ compiler]: http://infocenter.arm.com/help/index.jsp?topic=/com.arm.doc.dui0491c/Cacigdac.html [GCC]: https://gcc.gnu.org/onlinedocs/gcc/Common-Function-Attributes.html#index-pure-function-attribute [IBM ILE C/C++]: https://www.ibm.com/support/knowledgecenter/fr/ssw_ibm_i_71/rzarg/fn_attrib_pure.htm -"##, - }, - Lint { - label: "ffi_returns_twice", - description: r##"# `ffi_returns_twice` - -The tracking issue for this feature is: [#58314] - -[#58314]: https://github.com/rust-lang/rust/issues/58314 - ------------------------- "##, }, Lint { diff --git a/crates/ide-db/src/imports/import_assets.rs b/crates/ide-db/src/imports/import_assets.rs index cb3f01f345..a71d8e9002 100644 --- a/crates/ide-db/src/imports/import_assets.rs +++ b/crates/ide-db/src/imports/import_assets.rs @@ -529,7 +529,7 @@ fn trait_applicable_items( return None; } - let assoc_item_trait = assoc.containing_trait(db)?; + let assoc_item_trait = assoc.container_trait(db)?; if related_traits.contains(&assoc_item_trait) { return None; } @@ -550,8 +550,7 @@ fn trait_applicable_items( None, |assoc| { if required_assoc_items.contains(&assoc) { - let located_trait = - assoc.containing_trait(db).filter(|&it| scope_filter(it))?; + let located_trait = assoc.container_trait(db).filter(|&it| scope_filter(it))?; let trait_item = ItemInNs::from(ModuleDef::from(located_trait)); let import_path = trait_import_paths .entry(trait_item) @@ -576,8 +575,7 @@ fn trait_applicable_items( |function| { let assoc = function.as_assoc_item(db)?; if required_assoc_items.contains(&assoc) { - let located_trait = - assoc.containing_trait(db).filter(|&it| scope_filter(it))?; + let located_trait = assoc.container_trait(db).filter(|&it| scope_filter(it))?; let trait_item = ItemInNs::from(ModuleDef::from(located_trait)); let import_path = trait_import_paths .entry(trait_item) @@ -605,6 +603,7 @@ fn assoc_to_item(assoc: AssocItem) -> ItemInNs { } } +#[tracing::instrument(skip_all)] fn get_mod_path( db: &RootDatabase, item_to_search: ItemInNs, diff --git a/crates/ide-db/src/imports/insert_use.rs b/crates/ide-db/src/imports/insert_use.rs index f29f91eea8..bd5c464c55 100644 --- a/crates/ide-db/src/imports/insert_use.rs +++ b/crates/ide-db/src/imports/insert_use.rs @@ -214,7 +214,7 @@ fn insert_use_with_alias_option( }; } - let mut use_tree = make::use_tree(path.clone(), None, alias, false); + let mut use_tree = make::use_tree(path, None, alias, false); if mb == Some(MergeBehavior::One) && use_tree.path().is_some() { use_tree = use_tree.clone_for_update(); use_tree.wrap_in_tree_list(); diff --git a/crates/ide-db/src/syntax_helpers/format_string_exprs.rs b/crates/ide-db/src/syntax_helpers/format_string_exprs.rs index fc23081819..49594aee9f 100644 --- a/crates/ide-db/src/syntax_helpers/format_string_exprs.rs +++ b/crates/ide-db/src/syntax_helpers/format_string_exprs.rs @@ -203,7 +203,7 @@ mod tests { use expect_test::{expect, Expect}; fn check(input: &str, expect: &Expect) { - let (output, exprs) = parse_format_exprs(input).unwrap_or(("-".to_string(), vec![])); + let (output, exprs) = parse_format_exprs(input).unwrap_or(("-".to_owned(), vec![])); let outcome_repr = if !exprs.is_empty() { format!("{output}; {}", with_placeholders(exprs).join(", ")) } else { diff --git a/crates/ide-db/src/tests/sourcegen_lints.rs b/crates/ide-db/src/tests/sourcegen_lints.rs index c8cf87d3c2..86ed01c8e7 100644 --- a/crates/ide-db/src/tests/sourcegen_lints.rs +++ b/crates/ide-db/src/tests/sourcegen_lints.rs @@ -52,7 +52,7 @@ pub struct LintGroup { generate_lint_descriptor(sh, &mut contents); contents.push('\n'); - let cargo = std::env::var("CARGO").unwrap_or_else(|_| "cargo".to_string()); + let cargo = std::env::var("CARGO").unwrap_or_else(|_| "cargo".to_owned()); let unstable_book = project_root().join("./target/unstable-book-gen"); cmd!( sh, @@ -241,6 +241,7 @@ fn unescape(s: &str) -> String { s.replace(r#"\""#, "").replace(r#"\n"#, "\n").replace(r#"\r"#, "") } +#[allow(clippy::print_stderr)] fn generate_descriptor_clippy(buf: &mut String, path: &Path) { let file_content = std::fs::read_to_string(path).unwrap(); let mut clippy_lints: Vec = Vec::new(); @@ -282,7 +283,7 @@ fn generate_descriptor_clippy(buf: &mut String, path: &Path) { let line = &line[..up_to]; let clippy_lint = clippy_lints.last_mut().expect("clippy lint must already exist"); - clippy_lint.help = unescape(line).trim().to_string(); + clippy_lint.help = unescape(line).trim().to_owned(); } } clippy_lints.sort_by(|lint, lint2| lint.id.cmp(&lint2.id)); diff --git a/crates/ide-db/src/traits.rs b/crates/ide-db/src/traits.rs index bbdfd81d65..ee7c448bb8 100644 --- a/crates/ide-db/src/traits.rs +++ b/crates/ide-db/src/traits.rs @@ -75,7 +75,7 @@ pub fn get_missing_assoc_items( pub(crate) fn convert_to_def_in_trait(db: &dyn HirDatabase, def: Definition) -> Definition { (|| { let assoc = def.as_assoc_item(db)?; - let trait_ = assoc.containing_trait_impl(db)?; + let trait_ = assoc.implemented_trait(db)?; assoc_item_of_trait(db, assoc, trait_) })() .unwrap_or(def) diff --git a/crates/ide-diagnostics/src/handlers/inactive_code.rs b/crates/ide-diagnostics/src/handlers/inactive_code.rs index 9f754f9c6f..7db5ea04fb 100644 --- a/crates/ide-diagnostics/src/handlers/inactive_code.rs +++ b/crates/ide-diagnostics/src/handlers/inactive_code.rs @@ -16,7 +16,7 @@ pub(crate) fn inactive_code( } let inactive = DnfExpr::new(d.cfg.clone()).why_inactive(&d.opts); - let mut message = "code is inactive due to #[cfg] directives".to_string(); + let mut message = "code is inactive due to #[cfg] directives".to_owned(); if let Some(inactive) = inactive { let inactive_reasons = inactive.to_string(); diff --git a/crates/ide-diagnostics/src/handlers/incoherent_impl.rs b/crates/ide-diagnostics/src/handlers/incoherent_impl.rs index 3b4d400912..9f56e10414 100644 --- a/crates/ide-diagnostics/src/handlers/incoherent_impl.rs +++ b/crates/ide-diagnostics/src/handlers/incoherent_impl.rs @@ -9,7 +9,7 @@ pub(crate) fn incoherent_impl(ctx: &DiagnosticsContext<'_>, d: &hir::IncoherentI Diagnostic::new_with_syntax_node_ptr( ctx, DiagnosticCode::RustcHardError("E0210"), - "cannot define inherent `impl` for foreign type".to_string(), + "cannot define inherent `impl` for foreign type".to_owned(), InFile::new(d.file_id, d.impl_.into()), ) } diff --git a/crates/ide-diagnostics/src/handlers/incorrect_case.rs b/crates/ide-diagnostics/src/handlers/incorrect_case.rs index f5a6aa1197..dd64b93e45 100644 --- a/crates/ide-diagnostics/src/handlers/incorrect_case.rs +++ b/crates/ide-diagnostics/src/handlers/incorrect_case.rs @@ -52,7 +52,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::IncorrectCase) -> Option bool { false } @@ -403,7 +402,9 @@ impl T for U { trait __BitFlagsBad { const HiImAlsoBad: u8 = 2; + // ^^^^^^^^^^^ 💡 warn: Constant `HiImAlsoBad` should have UPPER_SNAKE_CASE name, e.g. `HI_IM_ALSO_BAD` fn Dirty(&self) -> bool { false } + // ^^^^^💡 warn: Function `Dirty` should have snake_case name, e.g. `dirty` } } } @@ -462,19 +463,59 @@ extern { } #[test] - fn bug_traits_arent_checked() { - // FIXME: Traits and functions in traits aren't currently checked by - // r-a, even though rustc will complain about them. + fn incorrect_trait_and_assoc_item_names() { check_diagnostics( r#" trait BAD_TRAIT { + // ^^^^^^^^^ 💡 warn: Trait `BAD_TRAIT` should have CamelCase name, e.g. `BadTrait` + const bad_const: u8; + // ^^^^^^^^^ 💡 warn: Constant `bad_const` should have UPPER_SNAKE_CASE name, e.g. `BAD_CONST` + type BAD_TYPE; + // ^^^^^^^^ 💡 warn: Type alias `BAD_TYPE` should have CamelCase name, e.g. `BadType` fn BAD_FUNCTION(); + // ^^^^^^^^^^^^ 💡 warn: Function `BAD_FUNCTION` should have snake_case name, e.g. `bad_function` fn BadFunction(); + // ^^^^^^^^^^^ 💡 warn: Function `BadFunction` should have snake_case name, e.g. `bad_function` } "#, ); } + #[test] + fn no_diagnostics_for_trait_impl_assoc_items_except_pats_in_body() { + cov_mark::check!(trait_impl_assoc_const_incorrect_case_ignored); + cov_mark::check!(trait_impl_assoc_type_incorrect_case_ignored); + cov_mark::check_count!(trait_impl_assoc_func_name_incorrect_case_ignored, 2); + check_diagnostics_with_disabled( + r#" +trait BAD_TRAIT { + // ^^^^^^^^^ 💡 warn: Trait `BAD_TRAIT` should have CamelCase name, e.g. `BadTrait` + const bad_const: u8; + // ^^^^^^^^^ 💡 warn: Constant `bad_const` should have UPPER_SNAKE_CASE name, e.g. `BAD_CONST` + type BAD_TYPE; + // ^^^^^^^^ 💡 warn: Type alias `BAD_TYPE` should have CamelCase name, e.g. `BadType` + fn BAD_FUNCTION(BAD_PARAM: u8); + // ^^^^^^^^^^^^ 💡 warn: Function `BAD_FUNCTION` should have snake_case name, e.g. `bad_function` + // ^^^^^^^^^ 💡 warn: Parameter `BAD_PARAM` should have snake_case name, e.g. `bad_param` + fn BadFunction(); + // ^^^^^^^^^^^ 💡 warn: Function `BadFunction` should have snake_case name, e.g. `bad_function` +} + +impl BAD_TRAIT for () { + const bad_const: u8 = 0; + type BAD_TYPE = (); + fn BAD_FUNCTION(BAD_PARAM: u8) { + // ^^^^^^^^^ 💡 warn: Parameter `BAD_PARAM` should have snake_case name, e.g. `bad_param` + let BAD_VAR = 0; + // ^^^^^^^ 💡 warn: Variable `BAD_VAR` should have snake_case name, e.g. `bad_var` + } + fn BadFunction() {} +} + "#, + std::iter::once("unused_variables".to_owned()), + ); + } + #[test] fn allow_attributes() { check_diagnostics( @@ -519,6 +560,14 @@ pub const some_const: u8 = 10; #[allow(non_upper_case_globals)] pub static SomeStatic: u8 = 10; + +#[allow(non_snake_case, non_camel_case_types, non_upper_case_globals)] +trait BAD_TRAIT { + const bad_const: u8; + type BAD_TYPE; + fn BAD_FUNCTION(BAD_PARAM: u8); + fn BadFunction(); +} "#, ); } @@ -578,6 +627,20 @@ pub const some_const: u8 = 10; #[deny(non_upper_case_globals)] pub static SomeStatic: u8 = 10; //^^^^^^^^^^ 💡 error: Static variable `SomeStatic` should have UPPER_SNAKE_CASE name, e.g. `SOME_STATIC` + +#[deny(non_snake_case, non_camel_case_types, non_upper_case_globals)] +trait BAD_TRAIT { + // ^^^^^^^^^ 💡 error: Trait `BAD_TRAIT` should have CamelCase name, e.g. `BadTrait` + const bad_const: u8; + // ^^^^^^^^^ 💡 error: Constant `bad_const` should have UPPER_SNAKE_CASE name, e.g. `BAD_CONST` + type BAD_TYPE; + // ^^^^^^^^ 💡 error: Type alias `BAD_TYPE` should have CamelCase name, e.g. `BadType` + fn BAD_FUNCTION(BAD_PARAM: u8); + // ^^^^^^^^^^^^ 💡 error: Function `BAD_FUNCTION` should have snake_case name, e.g. `bad_function` + // ^^^^^^^^^ 💡 error: Parameter `BAD_PARAM` should have snake_case name, e.g. `bad_param` + fn BadFunction(); + // ^^^^^^^^^^^ 💡 error: Function `BadFunction` should have snake_case name, e.g. `bad_function` +} "#, ); } diff --git a/crates/ide-diagnostics/src/handlers/json_is_not_rust.rs b/crates/ide-diagnostics/src/handlers/json_is_not_rust.rs index d330973aaa..241fddbb90 100644 --- a/crates/ide-diagnostics/src/handlers/json_is_not_rust.rs +++ b/crates/ide-diagnostics/src/handlers/json_is_not_rust.rs @@ -42,12 +42,12 @@ impl State { v.push("Deserialize"); } match v.as_slice() { - [] => "".to_string(), + [] => "".to_owned(), [x] => format!("#[derive({x})]\n"), [x, y] => format!("#[derive({x}, {y})]\n"), _ => { never!(); - "".to_string() + "".to_owned() } } } @@ -176,7 +176,7 @@ mod tests { #[test] fn diagnostic_for_simple_case() { let mut config = DiagnosticsConfig::test_sample(); - config.disabled.insert("syntax-error".to_string()); + config.disabled.insert("syntax-error".to_owned()); check_diagnostics_with_config( config, r#" diff --git a/crates/ide-diagnostics/src/handlers/macro_error.rs b/crates/ide-diagnostics/src/handlers/macro_error.rs index fc5c715981..e4cb53f3a2 100644 --- a/crates/ide-diagnostics/src/handlers/macro_error.rs +++ b/crates/ide-diagnostics/src/handlers/macro_error.rs @@ -99,7 +99,7 @@ pub macro panic { // FIXME: This is a false-positive, the file is actually linked in via // `include!` macro - config.disabled.insert("unlinked-file".to_string()); + config.disabled.insert("unlinked-file".to_owned()); check_diagnostics_with_config( config, @@ -268,8 +268,8 @@ fn f() { #[test] fn include_does_not_break_diagnostics() { let mut config = DiagnosticsConfig::test_sample(); - config.disabled.insert("inactive-code".to_string()); - config.disabled.insert("unlinked-file".to_string()); + config.disabled.insert("inactive-code".to_owned()); + config.disabled.insert("unlinked-file".to_owned()); check_diagnostics_with_config( config, r#" diff --git a/crates/ide-diagnostics/src/handlers/mismatched_arg_count.rs b/crates/ide-diagnostics/src/handlers/mismatched_arg_count.rs index 66ebf59350..41c762c85b 100644 --- a/crates/ide-diagnostics/src/handlers/mismatched_arg_count.rs +++ b/crates/ide-diagnostics/src/handlers/mismatched_arg_count.rs @@ -199,7 +199,7 @@ fn f() { // future, but we shouldn't emit an argument count diagnostic here check_diagnostics( r#" -trait Foo { fn method(&self, arg: usize) {} } +trait Foo { fn method(&self, _arg: usize) {} } fn f() { let x; diff --git a/crates/ide-diagnostics/src/handlers/missing_fields.rs b/crates/ide-diagnostics/src/handlers/missing_fields.rs index 37ac912f06..c70f39eb28 100644 --- a/crates/ide-diagnostics/src/handlers/missing_fields.rs +++ b/crates/ide-diagnostics/src/handlers/missing_fields.rs @@ -59,9 +59,15 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingFields) -> Option, d: &hir::MissingFields) -> Option, d: &hir::MissingFields) -> Option, d: &hir::MissingFields) -> Option { let missing_fields = ctx.sema.record_pattern_missing_fields(field_list_parent); @@ -160,11 +161,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::MissingFields) -> Option ast::Ty let ty_str = match ty.as_adt() { Some(adt) => adt.name(db).display(db.upcast()).to_string(), None => { - ty.display_source_code(db, module.into(), false).ok().unwrap_or_else(|| "_".to_string()) + ty.display_source_code(db, module.into(), false).ok().unwrap_or_else(|| "_".to_owned()) } }; diff --git a/crates/ide-diagnostics/src/handlers/missing_match_arms.rs b/crates/ide-diagnostics/src/handlers/missing_match_arms.rs index cb6d568442..17dc679e05 100644 --- a/crates/ide-diagnostics/src/handlers/missing_match_arms.rs +++ b/crates/ide-diagnostics/src/handlers/missing_match_arms.rs @@ -31,7 +31,7 @@ mod tests { #[test] fn empty_body() { let mut config = DiagnosticsConfig::test_sample(); - config.disabled.insert("syntax-error".to_string()); + config.disabled.insert("syntax-error".to_owned()); check_diagnostics_with_config( config, r#" diff --git a/crates/ide-diagnostics/src/handlers/mutability_errors.rs b/crates/ide-diagnostics/src/handlers/mutability_errors.rs index 773a075f8f..bdb55a9d98 100644 --- a/crates/ide-diagnostics/src/handlers/mutability_errors.rs +++ b/crates/ide-diagnostics/src/handlers/mutability_errors.rs @@ -19,7 +19,7 @@ pub(crate) fn need_mut(ctx: &DiagnosticsContext<'_>, d: &hir::NeedMut) -> Diagno for source in d.local.sources(ctx.sema.db) { let Some(ast) = source.name() else { continue }; // FIXME: macros - edit_builder.insert(ast.value.syntax().text_range().start(), "mut ".to_string()); + edit_builder.insert(ast.value.syntax().text_range().start(), "mut ".to_owned()); } let edit = edit_builder.finish(); Some(vec![fix( @@ -86,7 +86,7 @@ pub(super) fn token(parent: &SyntaxNode, kind: SyntaxKind) -> Option X { @@ -448,8 +448,9 @@ fn main(b: bool) { &mut x; } "#, + std::iter::once("remove-unnecessary-else".to_owned()), ); - check_diagnostics( + check_diagnostics_with_disabled( r#" fn main(b: bool) { if b { @@ -462,6 +463,7 @@ fn main(b: bool) { &mut x; } "#, + std::iter::once("remove-unnecessary-else".to_owned()), ); } diff --git a/crates/ide-diagnostics/src/handlers/private_field.rs b/crates/ide-diagnostics/src/handlers/private_field.rs index 3179a632e2..e91e64c81b 100644 --- a/crates/ide-diagnostics/src/handlers/private_field.rs +++ b/crates/ide-diagnostics/src/handlers/private_field.rs @@ -83,6 +83,32 @@ fn main() { }; strukt.field; } +"#, + ); + } + + #[test] + fn block_module_madness2() { + check_diagnostics( + r#" +fn main() { + use crate as ForceParentBlockDefMap; + let strukt = { + use crate as ForceParentBlockDefMap; + { + pub struct Struct { + field: (), + } + { + use crate as ForceParentBlockDefMap; + { + Struct { field: () } + } + } + } + }; + strukt.field; +} "#, ); } diff --git a/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs b/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs new file mode 100644 index 0000000000..a0d5d742d3 --- /dev/null +++ b/crates/ide-diagnostics/src/handlers/remove_trailing_return.rs @@ -0,0 +1,375 @@ +use hir::{db::ExpandDatabase, diagnostics::RemoveTrailingReturn}; +use ide_db::{assists::Assist, base_db::FileRange, source_change::SourceChange}; +use syntax::{ast, AstNode}; +use text_edit::TextEdit; + +use crate::{adjusted_display_range, fix, Diagnostic, DiagnosticCode, DiagnosticsContext}; + +// Diagnostic: remove-trailing-return +// +// This diagnostic is triggered when there is a redundant `return` at the end of a function +// or closure. +pub(crate) fn remove_trailing_return( + ctx: &DiagnosticsContext<'_>, + d: &RemoveTrailingReturn, +) -> Diagnostic { + let display_range = adjusted_display_range(ctx, d.return_expr, &|return_expr| { + return_expr + .syntax() + .parent() + .and_then(ast::ExprStmt::cast) + .map(|stmt| stmt.syntax().text_range()) + }); + Diagnostic::new( + DiagnosticCode::Clippy("needless_return"), + "replace return ; with ", + display_range, + ) + .with_fixes(fixes(ctx, d)) +} + +fn fixes(ctx: &DiagnosticsContext<'_>, d: &RemoveTrailingReturn) -> Option> { + let root = ctx.sema.db.parse_or_expand(d.return_expr.file_id); + let return_expr = d.return_expr.value.to_node(&root); + let stmt = return_expr.syntax().parent().and_then(ast::ExprStmt::cast); + + let FileRange { range, file_id } = + ctx.sema.original_range_opt(stmt.as_ref().map_or(return_expr.syntax(), AstNode::syntax))?; + if Some(file_id) != d.return_expr.file_id.file_id() { + return None; + } + + let replacement = + return_expr.expr().map_or_else(String::new, |expr| format!("{}", expr.syntax().text())); + let edit = TextEdit::replace(range, replacement); + let source_change = SourceChange::from_text_edit(file_id, edit); + + Some(vec![fix( + "remove_trailing_return", + "Replace return ; with ", + source_change, + range, + )]) +} + +#[cfg(test)] +mod tests { + use crate::tests::{ + check_diagnostics, check_diagnostics_with_disabled, check_fix, check_fix_with_disabled, + }; + + #[test] + fn remove_trailing_return() { + check_diagnostics( + r#" +fn foo() -> u8 { + return 2; +} //^^^^^^^^^ 💡 weak: replace return ; with +"#, + ); + } + + #[test] + fn remove_trailing_return_inner_function() { + check_diagnostics( + r#" +fn foo() -> u8 { + fn bar() -> u8 { + return 2; + } //^^^^^^^^^ 💡 weak: replace return ; with + bar() +} +"#, + ); + } + + #[test] + fn remove_trailing_return_closure() { + check_diagnostics( + r#" +fn foo() -> u8 { + let bar = || return 2; + bar() //^^^^^^^^ 💡 weak: replace return ; with +} +"#, + ); + check_diagnostics( + r#" +fn foo() -> u8 { + let bar = || { + return 2; + };//^^^^^^^^^ 💡 weak: replace return ; with + bar() +} +"#, + ); + } + + #[test] + fn remove_trailing_return_unit() { + check_diagnostics( + r#" +fn foo() { + return +} //^^^^^^ 💡 weak: replace return ; with +"#, + ); + } + + #[test] + fn remove_trailing_return_no_semi() { + check_diagnostics( + r#" +fn foo() -> u8 { + return 2 +} //^^^^^^^^ 💡 weak: replace return ; with +"#, + ); + } + + #[test] + fn remove_trailing_return_in_if() { + check_diagnostics_with_disabled( + r#" +fn foo(x: usize) -> u8 { + if x > 0 { + return 1; + //^^^^^^^^^ 💡 weak: replace return ; with + } else { + return 0; + } //^^^^^^^^^ 💡 weak: replace return ; with +} +"#, + std::iter::once("remove-unnecessary-else".to_owned()), + ); + } + + #[test] + fn remove_trailing_return_in_match() { + check_diagnostics( + r#" +fn foo(x: Result) -> u8 { + match x { + Ok(_) => return 1, + //^^^^^^^^ 💡 weak: replace return ; with + Err(_) => return 0, + } //^^^^^^^^ 💡 weak: replace return ; with +} +"#, + ); + } + + #[test] + fn no_diagnostic_if_no_return_keyword() { + check_diagnostics( + r#" +fn foo() -> u8 { + 3 +} +"#, + ); + } + + #[test] + fn no_diagnostic_if_not_last_statement() { + check_diagnostics( + r#" +fn foo() -> u8 { + if true { return 2; } + 3 +} +"#, + ); + } + + #[test] + fn replace_with_expr() { + check_fix( + r#" +fn foo() -> u8 { + return$0 2; +} +"#, + r#" +fn foo() -> u8 { + 2 +} +"#, + ); + } + + #[test] + fn replace_with_unit() { + check_fix( + r#" +fn foo() { + return$0/*ensure tidy is happy*/ +} +"#, + r#" +fn foo() { + /*ensure tidy is happy*/ +} +"#, + ); + } + + #[test] + fn replace_with_expr_no_semi() { + check_fix( + r#" +fn foo() -> u8 { + return$0 2 +} +"#, + r#" +fn foo() -> u8 { + 2 +} +"#, + ); + } + + #[test] + fn replace_in_inner_function() { + check_fix( + r#" +fn foo() -> u8 { + fn bar() -> u8 { + return$0 2; + } + bar() +} +"#, + r#" +fn foo() -> u8 { + fn bar() -> u8 { + 2 + } + bar() +} +"#, + ); + } + + #[test] + fn replace_in_closure() { + check_fix( + r#" +fn foo() -> u8 { + let bar = || return$0 2; + bar() +} +"#, + r#" +fn foo() -> u8 { + let bar = || 2; + bar() +} +"#, + ); + check_fix( + r#" +fn foo() -> u8 { + let bar = || { + return$0 2; + }; + bar() +} +"#, + r#" +fn foo() -> u8 { + let bar = || { + 2 + }; + bar() +} +"#, + ); + } + + #[test] + fn replace_in_if() { + check_fix_with_disabled( + r#" +fn foo(x: usize) -> u8 { + if x > 0 { + return$0 1; + } else { + 0 + } +} +"#, + r#" +fn foo(x: usize) -> u8 { + if x > 0 { + 1 + } else { + 0 + } +} +"#, + std::iter::once("remove-unnecessary-else".to_owned()), + ); + check_fix( + r#" +fn foo(x: usize) -> u8 { + if x > 0 { + 1 + } else { + return$0 0; + } +} +"#, + r#" +fn foo(x: usize) -> u8 { + if x > 0 { + 1 + } else { + 0 + } +} +"#, + ); + } + + #[test] + fn replace_in_match() { + check_fix( + r#" +fn foo(x: Result) -> u8 { + match x { + Ok(_) => return$0 1, + Err(_) => 0, + } +} +"#, + r#" +fn foo(x: Result) -> u8 { + match x { + Ok(_) => 1, + Err(_) => 0, + } +} +"#, + ); + check_fix( + r#" +fn foo(x: Result) -> u8 { + match x { + Ok(_) => 1, + Err(_) => return$0 0, + } +} +"#, + r#" +fn foo(x: Result) -> u8 { + match x { + Ok(_) => 1, + Err(_) => 0, + } +} +"#, + ); + } +} diff --git a/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs new file mode 100644 index 0000000000..ae8241ec2c --- /dev/null +++ b/crates/ide-diagnostics/src/handlers/remove_unnecessary_else.rs @@ -0,0 +1,390 @@ +use hir::{db::ExpandDatabase, diagnostics::RemoveUnnecessaryElse, HirFileIdExt}; +use ide_db::{assists::Assist, source_change::SourceChange}; +use itertools::Itertools; +use syntax::{ + ast::{self, edit::IndentLevel}, + AstNode, SyntaxToken, TextRange, +}; +use text_edit::TextEdit; + +use crate::{ + adjusted_display_range, fix, Diagnostic, DiagnosticCode, DiagnosticsContext, Severity, +}; + +// Diagnostic: remove-unnecessary-else +// +// This diagnostic is triggered when there is an `else` block for an `if` expression whose +// then branch diverges (e.g. ends with a `return`, `continue`, `break` e.t.c). +pub(crate) fn remove_unnecessary_else( + ctx: &DiagnosticsContext<'_>, + d: &RemoveUnnecessaryElse, +) -> Diagnostic { + let display_range = adjusted_display_range(ctx, d.if_expr, &|if_expr| { + if_expr.else_token().as_ref().map(SyntaxToken::text_range) + }); + Diagnostic::new( + DiagnosticCode::Ra("remove-unnecessary-else", Severity::WeakWarning), + "remove unnecessary else block", + display_range, + ) + .with_fixes(fixes(ctx, d)) +} + +fn fixes(ctx: &DiagnosticsContext<'_>, d: &RemoveUnnecessaryElse) -> Option> { + let root = ctx.sema.db.parse_or_expand(d.if_expr.file_id); + let if_expr = d.if_expr.value.to_node(&root); + let if_expr = ctx.sema.original_ast_node(if_expr.clone())?; + + let mut indent = IndentLevel::from_node(if_expr.syntax()); + let has_parent_if_expr = if_expr.syntax().parent().and_then(ast::IfExpr::cast).is_some(); + if has_parent_if_expr { + indent = indent + 1; + } + let else_replacement = match if_expr.else_branch()? { + ast::ElseBranch::Block(ref block) => { + block.statements().map(|stmt| format!("\n{indent}{stmt}")).join("") + } + ast::ElseBranch::IfExpr(ref nested_if_expr) => { + format!("\n{indent}{nested_if_expr}") + } + }; + let (replacement, range) = if has_parent_if_expr { + let base_indent = IndentLevel::from_node(if_expr.syntax()); + let then_indent = base_indent + 1; + let then_child_indent = then_indent + 1; + + let condition = if_expr.condition()?; + let then_stmts = if_expr + .then_branch()? + .statements() + .map(|stmt| format!("\n{then_child_indent}{stmt}")) + .join(""); + let then_replacement = + format!("\n{then_indent}if {condition} {{{then_stmts}\n{then_indent}}}",); + let replacement = format!("{{{then_replacement}{else_replacement}\n{base_indent}}}"); + (replacement, if_expr.syntax().text_range()) + } else { + ( + else_replacement, + TextRange::new( + if_expr.then_branch()?.syntax().text_range().end(), + if_expr.syntax().text_range().end(), + ), + ) + }; + + let edit = TextEdit::replace(range, replacement); + let source_change = + SourceChange::from_text_edit(d.if_expr.file_id.original_file(ctx.sema.db), edit); + + Some(vec![fix( + "remove_unnecessary_else", + "Remove unnecessary else block", + source_change, + range, + )]) +} + +#[cfg(test)] +mod tests { + use crate::tests::{check_diagnostics, check_diagnostics_with_disabled, check_fix}; + + fn check_diagnostics_with_needless_return_disabled(ra_fixture: &str) { + check_diagnostics_with_disabled(ra_fixture, std::iter::once("needless_return".to_owned())); + } + + #[test] + fn remove_unnecessary_else_for_return() { + check_diagnostics_with_needless_return_disabled( + r#" +fn test() { + if foo { + return bar; + } else { + //^^^^ 💡 weak: remove unnecessary else block + do_something_else(); + } +} +"#, + ); + check_fix( + r#" +fn test() { + if foo { + return bar; + } else$0 { + do_something_else(); + } +} +"#, + r#" +fn test() { + if foo { + return bar; + } + do_something_else(); +} +"#, + ); + } + + #[test] + fn remove_unnecessary_else_for_return2() { + check_diagnostics_with_needless_return_disabled( + r#" +fn test() { + if foo { + return bar; + } else if qux { + //^^^^ 💡 weak: remove unnecessary else block + do_something_else(); + } else { + do_something_else2(); + } +} +"#, + ); + check_fix( + r#" +fn test() { + if foo { + return bar; + } else$0 if qux { + do_something_else(); + } else { + do_something_else2(); + } +} +"#, + r#" +fn test() { + if foo { + return bar; + } + if qux { + do_something_else(); + } else { + do_something_else2(); + } +} +"#, + ); + } + + #[test] + fn remove_unnecessary_else_for_return_in_child_if_expr() { + check_diagnostics_with_needless_return_disabled( + r#" +fn test() { + if foo { + do_something(); + } else if qux { + return bar; + } else { + //^^^^ 💡 weak: remove unnecessary else block + do_something_else(); + } +} +"#, + ); + check_fix( + r#" +fn test() { + if foo { + do_something(); + } else if qux { + return bar; + } else$0 { + do_something_else(); + } +} +"#, + r#" +fn test() { + if foo { + do_something(); + } else { + if qux { + return bar; + } + do_something_else(); + } +} +"#, + ); + } + + #[test] + fn remove_unnecessary_else_for_break() { + check_diagnostics( + r#" +fn test() { + loop { + if foo { + break; + } else { + //^^^^ 💡 weak: remove unnecessary else block + do_something_else(); + } + } +} +"#, + ); + check_fix( + r#" +fn test() { + loop { + if foo { + break; + } else$0 { + do_something_else(); + } + } +} +"#, + r#" +fn test() { + loop { + if foo { + break; + } + do_something_else(); + } +} +"#, + ); + } + + #[test] + fn remove_unnecessary_else_for_continue() { + check_diagnostics( + r#" +fn test() { + loop { + if foo { + continue; + } else { + //^^^^ 💡 weak: remove unnecessary else block + do_something_else(); + } + } +} +"#, + ); + check_fix( + r#" +fn test() { + loop { + if foo { + continue; + } else$0 { + do_something_else(); + } + } +} +"#, + r#" +fn test() { + loop { + if foo { + continue; + } + do_something_else(); + } +} +"#, + ); + } + + #[test] + fn remove_unnecessary_else_for_never() { + check_diagnostics( + r#" +fn test() { + if foo { + never(); + } else { + //^^^^ 💡 weak: remove unnecessary else block + do_something_else(); + } +} + +fn never() -> ! { + loop {} +} +"#, + ); + check_fix( + r#" +fn test() { + if foo { + never(); + } else$0 { + do_something_else(); + } +} + +fn never() -> ! { + loop {} +} +"#, + r#" +fn test() { + if foo { + never(); + } + do_something_else(); +} + +fn never() -> ! { + loop {} +} +"#, + ); + } + + #[test] + fn no_diagnostic_if_no_else_branch() { + check_diagnostics( + r#" +fn test() { + if foo { + return bar; + } + + do_something_else(); +} +"#, + ); + } + + #[test] + fn no_diagnostic_if_no_divergence() { + check_diagnostics( + r#" +fn test() { + if foo { + do_something(); + } else { + do_something_else(); + } +} +"#, + ); + } + + #[test] + fn no_diagnostic_if_no_divergence_in_else_branch() { + check_diagnostics_with_needless_return_disabled( + r#" +fn test() { + if foo { + do_something(); + } else { + return bar; + } +} +"#, + ); + } +} diff --git a/crates/ide-diagnostics/src/handlers/replace_filter_map_next_with_find_map.rs b/crates/ide-diagnostics/src/handlers/replace_filter_map_next_with_find_map.rs index 72896b891b..6d3dcf31ab 100644 --- a/crates/ide-diagnostics/src/handlers/replace_filter_map_next_with_find_map.rs +++ b/crates/ide-diagnostics/src/handlers/replace_filter_map_next_with_find_map.rs @@ -63,8 +63,8 @@ mod tests { #[track_caller] pub(crate) fn check_diagnostics(ra_fixture: &str) { let mut config = DiagnosticsConfig::test_sample(); - config.disabled.insert("inactive-code".to_string()); - config.disabled.insert("E0599".to_string()); + config.disabled.insert("inactive-code".to_owned()); + config.disabled.insert("E0599".to_owned()); check_diagnostics_with_config(config, ra_fixture) } diff --git a/crates/ide-diagnostics/src/handlers/trait_impl_orphan.rs b/crates/ide-diagnostics/src/handlers/trait_impl_orphan.rs index d36813381e..78a04e1542 100644 --- a/crates/ide-diagnostics/src/handlers/trait_impl_orphan.rs +++ b/crates/ide-diagnostics/src/handlers/trait_impl_orphan.rs @@ -13,7 +13,7 @@ pub(crate) fn trait_impl_orphan( ctx, DiagnosticCode::RustcHardError("E0117"), "only traits defined in the current crate can be implemented for arbitrary types" - .to_string(), + .to_owned(), InFile::new(d.file_id, d.impl_.into()), ) // Not yet checked for false positives diff --git a/crates/ide-diagnostics/src/handlers/trait_impl_redundant_assoc_item.rs b/crates/ide-diagnostics/src/handlers/trait_impl_redundant_assoc_item.rs index f58fcd1f7e..00710ef507 100644 --- a/crates/ide-diagnostics/src/handlers/trait_impl_redundant_assoc_item.rs +++ b/crates/ide-diagnostics/src/handlers/trait_impl_redundant_assoc_item.rs @@ -103,7 +103,7 @@ fn quickfix_for_redundant_assoc_item( Some(vec![Assist { id: AssistId("add assoc item def into trait def", AssistKind::QuickFix), - label: Label::new("Add assoc item def into trait def".to_string()), + label: Label::new("Add assoc item def into trait def".to_owned()), group: None, target: range, source_change: Some(source_change_builder.finish()), diff --git a/crates/ide-diagnostics/src/handlers/type_mismatch.rs b/crates/ide-diagnostics/src/handlers/type_mismatch.rs index 750189beec..e93eea8ce2 100644 --- a/crates/ide-diagnostics/src/handlers/type_mismatch.rs +++ b/crates/ide-diagnostics/src/handlers/type_mismatch.rs @@ -120,7 +120,7 @@ fn add_missing_ok_or_some( let mut builder = TextEdit::builder(); builder.insert(expr.syntax().text_range().start(), format!("{variant_name}(")); - builder.insert(expr.syntax().text_range().end(), ")".to_string()); + builder.insert(expr.syntax().text_range().end(), ")".to_owned()); let source_change = SourceChange::from_text_edit(expr_ptr.file_id.original_file(ctx.sema.db), builder.finish()); let name = format!("Wrap in {variant_name}"); @@ -174,7 +174,7 @@ fn str_ref_to_owned( let expr = expr_ptr.value.to_node(&root); let expr_range = expr.syntax().text_range(); - let to_owned = ".to_owned()".to_string(); + let to_owned = ".to_owned()".to_owned(); let edit = TextEdit::insert(expr.syntax().text_range().end(), to_owned); let source_change = @@ -186,7 +186,9 @@ fn str_ref_to_owned( #[cfg(test)] mod tests { - use crate::tests::{check_diagnostics, check_fix, check_no_fix}; + use crate::tests::{ + check_diagnostics, check_diagnostics_with_disabled, check_fix, check_no_fix, + }; #[test] fn missing_reference() { @@ -718,7 +720,7 @@ struct Bar { #[test] fn return_no_value() { - check_diagnostics( + check_diagnostics_with_disabled( r#" fn f() -> i32 { return; @@ -727,6 +729,7 @@ fn f() -> i32 { } fn g() { return; } "#, + std::iter::once("needless_return".to_owned()), ); } diff --git a/crates/ide-diagnostics/src/handlers/unimplemented_builtin_macro.rs b/crates/ide-diagnostics/src/handlers/unimplemented_builtin_macro.rs index 996b6eda59..06f176f86f 100644 --- a/crates/ide-diagnostics/src/handlers/unimplemented_builtin_macro.rs +++ b/crates/ide-diagnostics/src/handlers/unimplemented_builtin_macro.rs @@ -10,7 +10,7 @@ pub(crate) fn unimplemented_builtin_macro( Diagnostic::new_with_syntax_node_ptr( ctx, DiagnosticCode::Ra("unimplemented-builtin-macro", Severity::WeakWarning), - "unimplemented built-in macro".to_string(), + "unimplemented built-in macro".to_owned(), d.node, ) } diff --git a/crates/ide-diagnostics/src/handlers/unresolved_field.rs b/crates/ide-diagnostics/src/handlers/unresolved_field.rs index 0e7a5720d4..65abfd8a29 100644 --- a/crates/ide-diagnostics/src/handlers/unresolved_field.rs +++ b/crates/ide-diagnostics/src/handlers/unresolved_field.rs @@ -65,7 +65,7 @@ fn method_fix( let FileRange { range, file_id } = ctx.sema.original_range_opt(expr.syntax())?; Some(vec![Assist { id: AssistId("expected-field-found-method-call-fix", AssistKind::QuickFix), - label: Label::new("Use parentheses to call the method".to_string()), + label: Label::new("Use parentheses to call the method".to_owned()), group: None, target: range, source_change: Some(SourceChange::from_text_edit( diff --git a/crates/ide-diagnostics/src/handlers/unresolved_method.rs b/crates/ide-diagnostics/src/handlers/unresolved_method.rs index 9f8fee67f3..648d081898 100644 --- a/crates/ide-diagnostics/src/handlers/unresolved_method.rs +++ b/crates/ide-diagnostics/src/handlers/unresolved_method.rs @@ -101,7 +101,7 @@ fn field_fix( }; Some(Assist { id: AssistId("expected-method-found-field-fix", AssistKind::QuickFix), - label: Label::new("Use parentheses to call the value of the field".to_string()), + label: Label::new("Use parentheses to call the value of the field".to_owned()), group: None, target: range, source_change: Some(SourceChange::from_iter([ diff --git a/crates/ide-diagnostics/src/handlers/unresolved_module.rs b/crates/ide-diagnostics/src/handlers/unresolved_module.rs index 1604decf90..1155688324 100644 --- a/crates/ide-diagnostics/src/handlers/unresolved_module.rs +++ b/crates/ide-diagnostics/src/handlers/unresolved_module.rs @@ -16,7 +16,7 @@ pub(crate) fn unresolved_module( ctx, DiagnosticCode::RustcHardError("E0583"), match &*d.candidates { - [] => "unresolved module".to_string(), + [] => "unresolved module".to_owned(), [candidate] => format!("unresolved module, can't find module file: {candidate}"), [candidates @ .., last] => { format!( @@ -46,7 +46,7 @@ fn fixes(ctx: &DiagnosticsContext<'_>, d: &hir::UnresolvedModule) -> Option format!("proc macro `{name}` not expanded"), - None => "proc macro not expanded".to_string(), + None => "proc macro not expanded".to_owned(), }; let severity = if config_enabled { Severity::Error } else { Severity::WeakWarning }; let def_map = ctx.sema.db.crate_def_map(d.krate); let message = if config_enabled { - def_map.proc_macro_loading_error().unwrap_or("proc macro not found in the built dylib") + def_map.proc_macro_loading_error().unwrap_or("internal error") } else { match d.kind { hir::MacroKind::Attr if proc_macros_enabled => "attribute macro expansion is disabled", diff --git a/crates/ide-diagnostics/src/handlers/useless_braces.rs b/crates/ide-diagnostics/src/handlers/useless_braces.rs index 8dce2af23e..863a7ab783 100644 --- a/crates/ide-diagnostics/src/handlers/useless_braces.rs +++ b/crates/ide-diagnostics/src/handlers/useless_braces.rs @@ -40,7 +40,7 @@ pub(crate) fn useless_braces( acc.push( Diagnostic::new( DiagnosticCode::RustcLint("unused_braces"), - "Unnecessary braces in use statement".to_string(), + "Unnecessary braces in use statement".to_owned(), FileRange { file_id, range: use_range }, ) .with_main_node(InFile::new(file_id.into(), node.clone())) @@ -112,7 +112,7 @@ mod a { ); let mut config = DiagnosticsConfig::test_sample(); - config.disabled.insert("syntax-error".to_string()); + config.disabled.insert("syntax-error".to_owned()); check_diagnostics_with_config( config, r#" diff --git a/crates/ide-diagnostics/src/lib.rs b/crates/ide-diagnostics/src/lib.rs index 5ad7069e31..9d21bb4cd9 100644 --- a/crates/ide-diagnostics/src/lib.rs +++ b/crates/ide-diagnostics/src/lib.rs @@ -43,6 +43,8 @@ mod handlers { pub(crate) mod no_such_field; pub(crate) mod private_assoc_item; pub(crate) mod private_field; + pub(crate) mod remove_trailing_return; + pub(crate) mod remove_unnecessary_else; pub(crate) mod replace_filter_map_next_with_find_map; pub(crate) mod trait_impl_incorrect_safety; pub(crate) mod trait_impl_missing_assoc_item; @@ -73,8 +75,6 @@ mod handlers { #[cfg(test)] mod tests; -use std::collections::HashMap; - use hir::{diagnostics::AnyDiagnostic, InFile, Semantics}; use ide_db::{ assists::{Assist, AssistId, AssistKind, AssistResolveStrategy}, @@ -384,6 +384,8 @@ pub fn diagnostics( AnyDiagnostic::UnusedVariable(d) => handlers::unused_variables::unused_variables(&ctx, &d), AnyDiagnostic::BreakOutsideOfLoop(d) => handlers::break_outside_of_loop::break_outside_of_loop(&ctx, &d), AnyDiagnostic::MismatchedTupleStructPatArgCount(d) => handlers::mismatched_arg_count::mismatched_tuple_struct_pat_arg_count(&ctx, &d), + AnyDiagnostic::RemoveTrailingReturn(d) => handlers::remove_trailing_return::remove_trailing_return(&ctx, &d), + AnyDiagnostic::RemoveUnnecessaryElse(d) => handlers::remove_unnecessary_else::remove_unnecessary_else(&ctx, &d), }; res.push(d) } @@ -413,18 +415,18 @@ pub fn diagnostics( // `__RA_EVERY_LINT` is a fake lint group to allow every lint in proc macros -static RUSTC_LINT_GROUPS_DICT: Lazy>> = +static RUSTC_LINT_GROUPS_DICT: Lazy>> = Lazy::new(|| build_group_dict(DEFAULT_LINT_GROUPS, &["warnings", "__RA_EVERY_LINT"], "")); -static CLIPPY_LINT_GROUPS_DICT: Lazy>> = +static CLIPPY_LINT_GROUPS_DICT: Lazy>> = Lazy::new(|| build_group_dict(CLIPPY_LINT_GROUPS, &["__RA_EVERY_LINT"], "clippy::")); fn build_group_dict( lint_group: &'static [LintGroup], all_groups: &'static [&'static str], prefix: &'static str, -) -> HashMap<&'static str, Vec<&'static str>> { - let mut r: HashMap<&str, Vec<&str>> = HashMap::new(); +) -> FxHashMap<&'static str, Vec<&'static str>> { + let mut r: FxHashMap<&str, Vec<&str>> = FxHashMap::default(); for g in lint_group { for child in g.children { r.entry(child.strip_prefix(prefix).unwrap()) @@ -561,7 +563,7 @@ fn unresolved_fix(id: &'static str, label: &str, target: TextRange) -> Assist { assert!(!id.contains(' ')); Assist { id: AssistId(id, AssistKind::QuickFix), - label: Label::new(label.to_string()), + label: Label::new(label.to_owned()), group: None, target, source_change: None, diff --git a/crates/ide-diagnostics/src/tests.rs b/crates/ide-diagnostics/src/tests.rs index f394a491b5..b62bb5affd 100644 --- a/crates/ide-diagnostics/src/tests.rs +++ b/crates/ide-diagnostics/src/tests.rs @@ -1,3 +1,4 @@ +#![allow(clippy::print_stderr)] #[cfg(not(feature = "in-rust-tree"))] mod sourcegen; @@ -33,13 +34,35 @@ pub(crate) fn check_fixes(ra_fixture_before: &str, ra_fixtures_after: Vec<&str>) #[track_caller] fn check_nth_fix(nth: usize, ra_fixture_before: &str, ra_fixture_after: &str) { + let mut config = DiagnosticsConfig::test_sample(); + config.expr_fill_default = ExprFillDefaultMode::Default; + check_nth_fix_with_config(config, nth, ra_fixture_before, ra_fixture_after) +} + +#[track_caller] +pub(crate) fn check_fix_with_disabled( + ra_fixture_before: &str, + ra_fixture_after: &str, + disabled: impl Iterator, +) { + let mut config = DiagnosticsConfig::test_sample(); + config.expr_fill_default = ExprFillDefaultMode::Default; + config.disabled.extend(disabled); + check_nth_fix_with_config(config, 0, ra_fixture_before, ra_fixture_after) +} + +#[track_caller] +fn check_nth_fix_with_config( + config: DiagnosticsConfig, + nth: usize, + ra_fixture_before: &str, + ra_fixture_after: &str, +) { let after = trim_indent(ra_fixture_after); let (db, file_position) = RootDatabase::with_position(ra_fixture_before); - let mut conf = DiagnosticsConfig::test_sample(); - conf.expr_fill_default = ExprFillDefaultMode::Default; let diagnostic = - super::diagnostics(&db, &conf, &AssistResolveStrategy::All, file_position.file_id) + super::diagnostics(&db, &config, &AssistResolveStrategy::All, file_position.file_id) .pop() .expect("no diagnostics"); let fix = &diagnostic @@ -85,7 +108,17 @@ pub(crate) fn check_no_fix(ra_fixture: &str) { #[track_caller] pub(crate) fn check_diagnostics(ra_fixture: &str) { let mut config = DiagnosticsConfig::test_sample(); - config.disabled.insert("inactive-code".to_string()); + config.disabled.insert("inactive-code".to_owned()); + check_diagnostics_with_config(config, ra_fixture) +} + +#[track_caller] +pub(crate) fn check_diagnostics_with_disabled( + ra_fixture: &str, + disabled: impl Iterator, +) { + let mut config = DiagnosticsConfig::test_sample(); + config.disabled.extend(disabled); check_diagnostics_with_config(config, ra_fixture) } @@ -174,7 +207,8 @@ fn minicore_smoke_test() { let source = minicore.source_code(); let mut config = DiagnosticsConfig::test_sample(); // This should be ignored since we conditionally remove code which creates single item use with braces - config.disabled.insert("unused_braces".to_string()); + config.disabled.insert("unused_braces".to_owned()); + config.disabled.insert("unused_variables".to_owned()); check_diagnostics_with_config(config, &source); } diff --git a/crates/ide-ssr/src/matching.rs b/crates/ide-ssr/src/matching.rs index 060897a685..fb98e95684 100644 --- a/crates/ide-ssr/src/matching.rs +++ b/crates/ide-ssr/src/matching.rs @@ -456,7 +456,7 @@ impl<'db, 'sema> Matcher<'db, 'sema> { SyntaxElement::Token(t) => Some(t.clone()), SyntaxElement::Node(n) => n.first_token(), }) - .map(|p| p.text().to_string()); + .map(|p| p.text().to_owned()); let first_matched_token = child.clone(); let mut last_matched_token = child; // Read code tokens util we reach one equal to the next token from our pattern @@ -706,7 +706,7 @@ where // we are trying to match that bit of code. This saves us having to pass a boolean into all the bits // of code that can make the decision to not match. thread_local! { - pub static RECORDING_MATCH_FAIL_REASONS: Cell = Cell::new(false); + pub static RECORDING_MATCH_FAIL_REASONS: Cell = const { Cell::new(false) }; } fn recording_match_fail_reasons() -> bool { @@ -795,7 +795,7 @@ mod tests { let edits = match_finder.edits(); assert_eq!(edits.len(), 1); let edit = &edits[&position.file_id]; - let mut after = input.to_string(); + let mut after = input.to_owned(); edit.apply(&mut after); assert_eq!(after, "fn foo() {} fn bar() {} fn main() { bar(1+2); }"); } diff --git a/crates/ide-ssr/src/parsing.rs b/crates/ide-ssr/src/parsing.rs index d78d009681..2f91271c46 100644 --- a/crates/ide-ssr/src/parsing.rs +++ b/crates/ide-ssr/src/parsing.rs @@ -152,7 +152,7 @@ impl FromStr for SsrRule { .next() .ok_or_else(|| SsrError("Cannot find delimiter `==>>`".into()))? .trim() - .to_string(); + .to_owned(); if it.next().is_some() { return Err(SsrError("More than one delimiter found".into())); } diff --git a/crates/ide-ssr/src/tests.rs b/crates/ide-ssr/src/tests.rs index 7c7d146cb4..e608b0a7c4 100644 --- a/crates/ide-ssr/src/tests.rs +++ b/crates/ide-ssr/src/tests.rs @@ -113,6 +113,7 @@ fn assert_ssr_transforms(rules: &[&str], input: &str, expected: Expect) { expected.assert_eq(&actual); } +#[allow(clippy::print_stdout)] fn print_match_debug_info(match_finder: &MatchFinder<'_>, file_id: FileId, snippet: &str) { let debug_info = match_finder.debug_where_text_equal(file_id, snippet); println!( diff --git a/crates/ide/src/doc_links.rs b/crates/ide/src/doc_links.rs index f221985719..dbe6a5507c 100644 --- a/crates/ide/src/doc_links.rs +++ b/crates/ide/src/doc_links.rs @@ -58,7 +58,7 @@ pub(crate) fn rewrite_links(db: &RootDatabase, markdown: &str, definition: Defin // and valid URLs so we choose to be too eager to try to resolve what might be // a URL. if target.contains("://") { - (Some(LinkType::Inline), target.to_string(), title.to_string()) + (Some(LinkType::Inline), target.to_owned(), title.to_owned()) } else { // Two possibilities: // * path-based links: `../../module/struct.MyStruct.html` @@ -66,9 +66,9 @@ pub(crate) fn rewrite_links(db: &RootDatabase, markdown: &str, definition: Defin if let Some((target, title)) = rewrite_intra_doc_link(db, definition, target, title) { (None, target, title) } else if let Some(target) = rewrite_url_link(db, definition, target) { - (Some(LinkType::Inline), target, title.to_string()) + (Some(LinkType::Inline), target, title.to_owned()) } else { - (None, target.to_string(), title.to_string()) + (None, target.to_owned(), title.to_owned()) } } }); @@ -186,7 +186,7 @@ pub(crate) fn extract_definitions_from_docs( let (link, ns) = parse_intra_doc_link(&target); Some(( TextRange::new(range.start.try_into().ok()?, range.end.try_into().ok()?), - link.to_string(), + link.to_owned(), ns, )) } @@ -388,7 +388,7 @@ fn rewrite_intra_doc_link( url = url.join(&file).ok()?; url.set_fragment(anchor); - Some((url.into(), strip_prefixes_suffixes(title).to_string())) + Some((url.into(), strip_prefixes_suffixes(title).to_owned())) } /// Try to resolve path to local documentation via path-based links (i.e. `../gateway/struct.Shard.html`). @@ -668,7 +668,7 @@ fn get_assoc_item_fragment(db: &dyn HirDatabase, assoc_item: hir::AssocItem) -> Some(match assoc_item { AssocItem::Function(function) => { let is_trait_method = - function.as_assoc_item(db).and_then(|assoc| assoc.containing_trait(db)).is_some(); + function.as_assoc_item(db).and_then(|assoc| assoc.container_trait(db)).is_some(); // This distinction may get more complicated when specialization is available. // Rustdoc makes this decision based on whether a method 'has defaultness'. // Currently this is only the case for provided trait methods. diff --git a/crates/ide/src/doc_links/tests.rs b/crates/ide/src/doc_links/tests.rs index 3bb0fc6064..60e8d29a71 100644 --- a/crates/ide/src/doc_links/tests.rs +++ b/crates/ide/src/doc_links/tests.rs @@ -29,9 +29,6 @@ fn check_external_docs( let web_url = links.web_url; let local_url = links.local_url; - println!("web_url: {:?}", web_url); - println!("local_url: {:?}", local_url); - match (expect_web_url, web_url) { (Some(expect), Some(url)) => expect.assert_eq(&url), (None, None) => (), diff --git a/crates/ide/src/file_structure.rs b/crates/ide/src/file_structure.rs index b278924721..0e790e1420 100644 --- a/crates/ide/src/file_structure.rs +++ b/crates/ide/src/file_structure.rs @@ -193,7 +193,7 @@ fn structure_token(token: SyntaxToken) -> Option { if let Some(region_name) = text.strip_prefix("// region:").map(str::trim) { return Some(StructureNode { parent: None, - label: region_name.to_string(), + label: region_name.to_owned(), navigation_range: comment.syntax().text_range(), node_range: comment.syntax().text_range(), kind: StructureNodeKind::Region, diff --git a/crates/ide/src/goto_declaration.rs b/crates/ide/src/goto_declaration.rs index fae1007435..fab62e95d1 100644 --- a/crates/ide/src/goto_declaration.rs +++ b/crates/ide/src/goto_declaration.rs @@ -61,7 +61,7 @@ pub(crate) fn goto_declaration( _ => None, }?; - let trait_ = assoc.containing_trait_impl(db)?; + let trait_ = assoc.implemented_trait(db)?; let name = Some(assoc.name(db)?); let item = trait_.items(db).into_iter().find(|it| it.name(db) == name)?; item.try_to_nav(db) diff --git a/crates/ide/src/goto_definition.rs b/crates/ide/src/goto_definition.rs index 4fed1f9158..88255d222e 100644 --- a/crates/ide/src/goto_definition.rs +++ b/crates/ide/src/goto_definition.rs @@ -182,11 +182,7 @@ fn try_filter_trait_item_definition( match assoc { AssocItem::Function(..) => None, AssocItem::Const(..) | AssocItem::TypeAlias(..) => { - let imp = match assoc.container(db) { - hir::AssocItemContainer::Impl(imp) => imp, - _ => return None, - }; - let trait_ = imp.trait_(db)?; + let trait_ = assoc.implemented_trait(db)?; let name = def.name(db)?; let discri_value = discriminant(&assoc); trait_ diff --git a/crates/ide/src/goto_implementation.rs b/crates/ide/src/goto_implementation.rs index c1a4a7b1fc..8a12cbaccc 100644 --- a/crates/ide/src/goto_implementation.rs +++ b/crates/ide/src/goto_implementation.rs @@ -60,13 +60,13 @@ pub(crate) fn goto_implementation( Definition::Function(f) => { let assoc = f.as_assoc_item(sema.db)?; let name = assoc.name(sema.db)?; - let trait_ = assoc.containing_trait_or_trait_impl(sema.db)?; + let trait_ = assoc.container_or_implemented_trait(sema.db)?; impls_for_trait_item(&sema, trait_, name) } Definition::Const(c) => { let assoc = c.as_assoc_item(sema.db)?; let name = assoc.name(sema.db)?; - let trait_ = assoc.containing_trait_or_trait_impl(sema.db)?; + let trait_ = assoc.container_or_implemented_trait(sema.db)?; impls_for_trait_item(&sema, trait_, name) } _ => return None, diff --git a/crates/ide/src/highlight_related.rs b/crates/ide/src/highlight_related.rs index 979ca4575d..dd285e9b32 100644 --- a/crates/ide/src/highlight_related.rs +++ b/crates/ide/src/highlight_related.rs @@ -521,7 +521,7 @@ mod tests { ReferenceCategory::Import => "import", ReferenceCategory::Test => "test", } - .to_string() + .to_owned() }), ) }) diff --git a/crates/ide/src/hover/render.rs b/crates/ide/src/hover/render.rs index 45386df2b2..eff055c959 100644 --- a/crates/ide/src/hover/render.rs +++ b/crates/ide/src/hover/render.rs @@ -621,7 +621,7 @@ fn closure_ty( }) .join("\n"); if captures_rendered.trim().is_empty() { - captures_rendered = "This closure captures nothing".to_string(); + captures_rendered = "This closure captures nothing".to_owned(); } let mut targets: Vec = Vec::new(); let mut push_new_def = |item: hir::ModuleDef| { @@ -823,7 +823,7 @@ fn keyword_hints( } } _ => KeywordHint { - description: token.text().to_string(), + description: token.text().to_owned(), keyword_mod, actions: Vec::new(), }, @@ -835,9 +835,9 @@ fn keyword_hints( Some(_) => format!("prim_{}", token.text()), None => format!("{}_keyword", token.text()), }; - KeywordHint::new(token.text().to_string(), module) + KeywordHint::new(token.text().to_owned(), module) } - T![Self] => KeywordHint::new(token.text().to_string(), "self_upper_keyword".into()), - _ => KeywordHint::new(token.text().to_string(), format!("{}_keyword", token.text())), + T![Self] => KeywordHint::new(token.text().to_owned(), "self_upper_keyword".into()), + _ => KeywordHint::new(token.text().to_owned(), format!("{}_keyword", token.text())), } } diff --git a/crates/ide/src/hover/tests.rs b/crates/ide/src/hover/tests.rs index 9f4427090e..30bfe6ee9d 100644 --- a/crates/ide/src/hover/tests.rs +++ b/crates/ide/src/hover/tests.rs @@ -406,8 +406,8 @@ fn main() { file_id: FileId( 1, ), - full_range: 631..866, - focus_range: 692..698, + full_range: 632..867, + focus_range: 693..699, name: "FnOnce", kind: Trait, container_name: "function", @@ -702,7 +702,7 @@ fn hover_shows_struct_field_info() { // Hovering over the field when instantiating check( r#" -struct Foo { field_a: u32 } +struct Foo { pub field_a: u32 } fn main() { let foo = Foo { field_a$0: 0, }; @@ -717,7 +717,7 @@ fn main() { ```rust // size = 4, align = 4, offset = 0 - field_a: u32 + pub field_a: u32 ``` "#]], ); @@ -725,7 +725,7 @@ fn main() { // Hovering over the field in the definition check( r#" -struct Foo { field_a$0: u32 } +struct Foo { pub field_a$0: u32 } fn main() { let foo = Foo { field_a: 0 }; @@ -740,7 +740,74 @@ fn main() { ```rust // size = 4, align = 4, offset = 0 - field_a: u32 + pub field_a: u32 + ``` + "#]], + ); +} + +#[test] +fn hover_shows_tuple_struct_field_info() { + check( + r#" +struct Foo(pub u32) + +fn main() { + let foo = Foo { 0$0: 0, }; +} +"#, + expect![[r#" + *0* + + ```rust + test::Foo + ``` + + ```rust + // size = 4, align = 4, offset = 0 + pub 0: u32 + ``` + "#]], + ); + check( + r#" +struct Foo(pub u32) + +fn foo(foo: Foo) { + foo.0$0; +} +"#, + expect![[r#" + *0* + + ```rust + test::Foo + ``` + + ```rust + // size = 4, align = 4, offset = 0 + pub 0: u32 + ``` + "#]], + ); +} + +#[test] +fn hover_tuple_struct() { + check( + r#" +struct Foo$0(pub u32) +"#, + expect![[r#" + *Foo* + + ```rust + test + ``` + + ```rust + // size = 4, align = 4 + struct Foo(pub u32); ``` "#]], ); @@ -7196,8 +7263,8 @@ impl Iterator for S { file_id: FileId( 1, ), - full_range: 6156..6364, - focus_range: 6221..6227, + full_range: 6157..6365, + focus_range: 6222..6228, name: "Future", kind: Trait, container_name: "future", @@ -7210,8 +7277,8 @@ impl Iterator for S { file_id: FileId( 1, ), - full_range: 6994..7460, - focus_range: 7038..7046, + full_range: 6995..7461, + focus_range: 7039..7047, name: "Iterator", kind: Trait, container_name: "iterator", diff --git a/crates/ide/src/inlay_hints/implicit_drop.rs b/crates/ide/src/inlay_hints/implicit_drop.rs index 3104b85768..8d9ad5bda1 100644 --- a/crates/ide/src/inlay_hints/implicit_drop.rs +++ b/crates/ide/src/inlay_hints/implicit_drop.rs @@ -1,10 +1,8 @@ //! Implementation of "implicit drop" inlay hints: -//! ```no_run -//! fn main() { -//! let x = vec![2]; -//! if some_condition() { -//! /* drop(x) */return; -//! } +//! ```ignore +//! let x = vec![2]; +//! if some_condition() { +//! /* drop(x) */return; //! } //! ``` use hir::{ diff --git a/crates/ide/src/inlay_hints/range_exclusive.rs b/crates/ide/src/inlay_hints/range_exclusive.rs index 50ab15c504..c4b0c199fc 100644 --- a/crates/ide/src/inlay_hints/range_exclusive.rs +++ b/crates/ide/src/inlay_hints/range_exclusive.rs @@ -1,5 +1,5 @@ //! Implementation of "range exclusive" inlay hints: -//! ```no_run +//! ```ignore //! for i in 0../* < */10 {} //! if let ../* < */100 = 50 {} //! ``` diff --git a/crates/ide/src/interpret_function.rs b/crates/ide/src/interpret_function.rs index adbd191888..df444a3f4d 100644 --- a/crates/ide/src/interpret_function.rs +++ b/crates/ide/src/interpret_function.rs @@ -15,8 +15,8 @@ use syntax::{algo::ancestors_at_offset, ast, AstNode, TextRange}; // |=== pub(crate) fn interpret_function(db: &RootDatabase, position: FilePosition) -> String { let start_time = Instant::now(); - let mut result = find_and_interpret(db, position) - .unwrap_or_else(|| "Not inside a function body".to_string()); + let mut result = + find_and_interpret(db, position).unwrap_or_else(|| "Not inside a function body".to_owned()); let duration = Instant::now() - start_time; writeln!(result).unwrap(); writeln!(result, "----------------------").unwrap(); diff --git a/crates/ide/src/join_lines.rs b/crates/ide/src/join_lines.rs index 1cfde23624..fef0ec35ba 100644 --- a/crates/ide/src/join_lines.rs +++ b/crates/ide/src/join_lines.rs @@ -115,7 +115,7 @@ fn remove_newline( let range = TextRange::at(offset, ((n_spaces_after_line_break + 1) as u32).into()); let replace_with = if no_space { "" } else { " " }; - edit.replace(range, replace_with.to_string()); + edit.replace(range, replace_with.to_owned()); return; } @@ -140,7 +140,7 @@ fn remove_newline( }; edit.replace( TextRange::new(prev.text_range().start(), token.text_range().end()), - space.to_string(), + space.to_owned(), ); return; } @@ -154,7 +154,7 @@ fn remove_newline( Some(_) => cov_mark::hit!(join_two_ifs_with_existing_else), None => { cov_mark::hit!(join_two_ifs); - edit.replace(token.text_range(), " else ".to_string()); + edit.replace(token.text_range(), " else ".to_owned()); return; } } @@ -203,7 +203,7 @@ fn remove_newline( } // Remove newline but add a computed amount of whitespace characters - edit.replace(token.text_range(), compute_ws(prev.kind(), next.kind()).to_string()); + edit.replace(token.text_range(), compute_ws(prev.kind(), next.kind()).to_owned()); } fn join_single_expr_block(edit: &mut TextEditBuilder, token: &SyntaxToken) -> Option<()> { diff --git a/crates/ide/src/lib.rs b/crates/ide/src/lib.rs index e9f42d4785..effdbf2c1f 100644 --- a/crates/ide/src/lib.rs +++ b/crates/ide/src/lib.rs @@ -238,7 +238,7 @@ impl Analysis { let mut host = AnalysisHost::default(); let file_id = FileId::from_raw(0); let mut file_set = FileSet::default(); - file_set.insert(file_id, VfsPath::new_virtual_path("/main.rs".to_string())); + file_set.insert(file_id, VfsPath::new_virtual_path("/main.rs".to_owned())); let source_root = SourceRoot::new_local(file_set); let mut change = Change::new(); diff --git a/crates/ide/src/moniker.rs b/crates/ide/src/moniker.rs index c49d75b2f8..80d265ae37 100644 --- a/crates/ide/src/moniker.rs +++ b/crates/ide/src/moniker.rs @@ -383,18 +383,18 @@ pub(crate) fn def_to_moniker( let (name, repo, version) = match krate.origin(db) { CrateOrigin::Library { repo, name } => (name, repo, krate.version(db)), CrateOrigin::Local { repo, name } => ( - name.unwrap_or(krate.display_name(db)?.canonical_name().to_string()), + name.unwrap_or(krate.display_name(db)?.canonical_name().to_owned()), repo, krate.version(db), ), CrateOrigin::Rustc { name } => ( name.clone(), - Some("https://github.com/rust-lang/rust/".to_string()), + Some("https://github.com/rust-lang/rust/".to_owned()), Some(format!("https://github.com/rust-lang/rust/compiler/{name}",)), ), CrateOrigin::Lang(lang) => ( - krate.display_name(db)?.canonical_name().to_string(), - Some("https://github.com/rust-lang/rust/".to_string()), + krate.display_name(db)?.canonical_name().to_owned(), + Some("https://github.com/rust-lang/rust/".to_owned()), Some(match lang { LangCrateOrigin::Other => { "https://github.com/rust-lang/rust/library/".into() diff --git a/crates/ide/src/navigation_target.rs b/crates/ide/src/navigation_target.rs index bfd91feeb3..674ce6d52b 100644 --- a/crates/ide/src/navigation_target.rs +++ b/crates/ide/src/navigation_target.rs @@ -860,7 +860,7 @@ fn foo() { enum FooInner { } } "#, ); - let navs = analysis.symbol_search(Query::new("FooInner".to_string()), !0).unwrap(); + let navs = analysis.symbol_search(Query::new("FooInner".to_owned()), !0).unwrap(); expect![[r#" [ NavigationTarget { @@ -898,7 +898,7 @@ struct Foo; "#, ); - let navs = analysis.symbol_search(Query::new("foo".to_string()), !0).unwrap(); + let navs = analysis.symbol_search(Query::new("foo".to_owned()), !0).unwrap(); assert_eq!(navs.len(), 2) } } diff --git a/crates/ide/src/prime_caches.rs b/crates/ide/src/prime_caches.rs index a95d1771ce..5c14f496a0 100644 --- a/crates/ide/src/prime_caches.rs +++ b/crates/ide/src/prime_caches.rs @@ -105,7 +105,7 @@ pub(crate) fn parallel_prime_caches( work_sender .send(( crate_id, - graph[crate_id].display_name.as_deref().unwrap_or_default().to_string(), + graph[crate_id].display_name.as_deref().unwrap_or_default().to_owned(), )) .ok(); } diff --git a/crates/ide/src/references.rs b/crates/ide/src/references.rs index bdda25a111..dcdc6118a3 100644 --- a/crates/ide/src/references.rs +++ b/crates/ide/src/references.rs @@ -9,8 +9,6 @@ //! at the index that the match starts at and its tree parent is //! resolved to the search element definition, we get a reference. -use std::collections::HashMap; - use hir::{DescendPreference, PathResolution, Semantics}; use ide_db::{ base_db::FileId, @@ -79,7 +77,7 @@ pub(crate) fn find_all_refs( .collect(), ) }) - .collect::, _>>(); + .collect::>>(); let declaration = match def { Definition::Module(module) => { Some(NavigationTarget::from_module_to_decl(sema.db, module)) diff --git a/crates/ide/src/runnables.rs b/crates/ide/src/runnables.rs index 3008722cdb..ae107a9604 100644 --- a/crates/ide/src/runnables.rs +++ b/crates/ide/src/runnables.rs @@ -72,7 +72,7 @@ impl Runnable { RunnableKind::Bench { test_id } => format!("bench {test_id}"), RunnableKind::DocTest { test_id, .. } => format!("doctest {test_id}"), RunnableKind::Bin => { - target.map_or_else(|| "run binary".to_string(), |t| format!("run {t}")) + target.map_or_else(|| "run binary".to_owned(), |t| format!("run {t}")) } } } @@ -442,8 +442,7 @@ fn module_def_doctest(db: &RootDatabase, def: Definition) -> Option { .for_each(|name| format_to!(path, "{}::", name.display(db))); // This probably belongs to canonical_path? if let Some(assoc_item) = def.as_assoc_item(db) { - if let hir::AssocItemContainer::Impl(imp) = assoc_item.container(db) { - let ty = imp.self_ty(db); + if let Some(ty) = assoc_item.implementing_ty(db) { if let Some(adt) = ty.as_adt() { let name = adt.name(db); let mut ty_args = ty.generic_parameters(db).peekable(); diff --git a/crates/ide/src/ssr.rs b/crates/ide/src/ssr.rs index f0d18fdefa..b49fe391bf 100644 --- a/crates/ide/src/ssr.rs +++ b/crates/ide/src/ssr.rs @@ -41,7 +41,7 @@ pub(crate) fn ssr_assists( for (label, source_change) in assists.into_iter() { let assist = Assist { id, - label: Label::new(label.to_string()), + label: Label::new(label.to_owned()), group: Some(GroupLabel("Apply SSR".into())), target: comment_range, source_change, diff --git a/crates/ide/src/static_index.rs b/crates/ide/src/static_index.rs index 5b7094e6bc..dee5afbf8d 100644 --- a/crates/ide/src/static_index.rs +++ b/crates/ide/src/static_index.rs @@ -1,14 +1,12 @@ //! This module provides `StaticIndex` which is used for powering //! read-only code browsers and emitting LSIF -use std::collections::HashMap; - use hir::{db::HirDatabase, Crate, HirFileIdExt, Module}; -use ide_db::helpers::get_definition; use ide_db::{ base_db::{FileId, FileRange, SourceDatabaseExt}, defs::Definition, - FxHashSet, RootDatabase, + helpers::get_definition, + FxHashMap, FxHashSet, RootDatabase, }; use syntax::{AstNode, SyntaxKind::*, TextRange, T}; @@ -31,7 +29,7 @@ pub struct StaticIndex<'a> { pub tokens: TokenStore, analysis: &'a Analysis, db: &'a RootDatabase, - def_map: HashMap, + def_map: FxHashMap, } #[derive(Debug)] @@ -232,14 +230,13 @@ impl StaticIndex<'_> { #[cfg(test)] mod tests { use crate::{fixture, StaticIndex}; - use ide_db::base_db::FileRange; - use std::collections::HashSet; + use ide_db::{base_db::FileRange, FxHashSet}; use syntax::TextSize; fn check_all_ranges(ra_fixture: &str) { let (analysis, ranges) = fixture::annotations_without_marker(ra_fixture); let s = StaticIndex::compute(&analysis); - let mut range_set: HashSet<_> = ranges.iter().map(|it| it.0).collect(); + let mut range_set: FxHashSet<_> = ranges.iter().map(|it| it.0).collect(); for f in s.files { for (range, _) in f.tokens { let it = FileRange { file_id: f.file_id, range }; @@ -258,7 +255,7 @@ mod tests { fn check_definitions(ra_fixture: &str) { let (analysis, ranges) = fixture::annotations_without_marker(ra_fixture); let s = StaticIndex::compute(&analysis); - let mut range_set: HashSet<_> = ranges.iter().map(|it| it.0).collect(); + let mut range_set: FxHashSet<_> = ranges.iter().map(|it| it.0).collect(); for (_, t) in s.tokens.iter() { if let Some(t) = t.definition { if t.range.start() == TextSize::from(0) { diff --git a/crates/ide/src/status.rs b/crates/ide/src/status.rs index b2b305c1d3..3321a0513b 100644 --- a/crates/ide/src/status.rs +++ b/crates/ide/src/status.rs @@ -105,7 +105,7 @@ pub(crate) fn status(db: &RootDatabase, file_id: Option) -> String { } } - buf.trim().to_string() + buf.trim().to_owned() } fn collect_query<'q, Q>(table: QueryTable<'q, Q>) -> ::Collector diff --git a/crates/ide/src/syntax_highlighting/highlight.rs b/crates/ide/src/syntax_highlighting/highlight.rs index d686652bb3..e7c1b4497e 100644 --- a/crates/ide/src/syntax_highlighting/highlight.rs +++ b/crates/ide/src/syntax_highlighting/highlight.rs @@ -559,7 +559,7 @@ fn highlight_method_call( } if func .as_assoc_item(sema.db) - .and_then(|it| it.containing_trait_or_trait_impl(sema.db)) + .and_then(|it| it.container_or_implemented_trait(sema.db)) .is_some() { h |= HlMod::Trait; diff --git a/crates/ide/src/syntax_tree.rs b/crates/ide/src/syntax_tree.rs index 2108b53861..1065d5899a 100644 --- a/crates/ide/src/syntax_tree.rs +++ b/crates/ide/src/syntax_tree.rs @@ -55,7 +55,7 @@ fn syntax_tree_for_string(token: &SyntaxToken, text_range: TextRange) -> Option< fn syntax_tree_for_token(node: &SyntaxToken, text_range: TextRange) -> Option { // Range of the full node let node_range = node.text_range(); - let text = node.text().to_string(); + let text = node.text().to_owned(); // We start at some point inside the node // Either we have selected the whole string diff --git a/crates/ide/src/typing.rs b/crates/ide/src/typing.rs index b8856882ed..e87fc89fea 100644 --- a/crates/ide/src/typing.rs +++ b/crates/ide/src/typing.rs @@ -149,10 +149,7 @@ fn on_opening_bracket_typed( let tree: ast::UseTree = find_node_at_offset(file.syntax(), offset)?; - Some(TextEdit::insert( - tree.syntax().text_range().end() + TextSize::of("{"), - "}".to_string(), - )) + Some(TextEdit::insert(tree.syntax().text_range().end() + TextSize::of("{"), "}".to_owned())) } fn bracket_expr( @@ -235,7 +232,7 @@ fn on_eq_typed(file: &SourceFile, offset: TextSize) -> Option { return None; } let offset = expr.syntax().text_range().end(); - Some(TextEdit::insert(offset, ";".to_string())) + Some(TextEdit::insert(offset, ";".to_owned())) } /// `a =$0 b;` removes the semicolon if an expression is valid in this context. @@ -275,7 +272,7 @@ fn on_eq_typed(file: &SourceFile, offset: TextSize) -> Option { return None; } let offset = let_stmt.syntax().text_range().end(); - Some(TextEdit::insert(offset, ";".to_string())) + Some(TextEdit::insert(offset, ";".to_owned())) } } @@ -353,7 +350,7 @@ fn on_left_angle_typed(file: &SourceFile, offset: TextSize) -> Option".to_string()), + edit: TextEdit::replace(range, "<$0>".to_owned()), is_snippet: true, }); } @@ -363,7 +360,7 @@ fn on_left_angle_typed(file: &SourceFile, offset: TextSize) -> Option".to_string()), + edit: TextEdit::replace(range, "<$0>".to_owned()), is_snippet: true, }) } else { @@ -383,7 +380,7 @@ fn on_right_angle_typed(file: &SourceFile, offset: TextSize) -> Option } find_node_at_offset::(file.syntax(), offset)?; - Some(TextEdit::insert(after_arrow, " ".to_string())) + Some(TextEdit::insert(after_arrow, " ".to_owned())) } #[cfg(test)] diff --git a/crates/ide/src/view_hir.rs b/crates/ide/src/view_hir.rs index 9abe54cd39..51cf45bd22 100644 --- a/crates/ide/src/view_hir.rs +++ b/crates/ide/src/view_hir.rs @@ -12,7 +12,7 @@ use syntax::{algo::ancestors_at_offset, ast, AstNode}; // |=== // image::https://user-images.githubusercontent.com/48062697/113065588-068bdb80-91b1-11eb-9a78-0b4ef1e972fb.gif[] pub(crate) fn view_hir(db: &RootDatabase, position: FilePosition) -> String { - body_hir(db, position).unwrap_or_else(|| "Not inside a function body".to_string()) + body_hir(db, position).unwrap_or_else(|| "Not inside a function body".to_owned()) } fn body_hir(db: &RootDatabase, position: FilePosition) -> Option { diff --git a/crates/ide/src/view_mir.rs b/crates/ide/src/view_mir.rs index 08d810c134..5fb4703989 100644 --- a/crates/ide/src/view_mir.rs +++ b/crates/ide/src/view_mir.rs @@ -11,7 +11,7 @@ use syntax::{algo::ancestors_at_offset, ast, AstNode}; // | VS Code | **rust-analyzer: View Mir** // |=== pub(crate) fn view_mir(db: &RootDatabase, position: FilePosition) -> String { - body_mir(db, position).unwrap_or_else(|| "Not inside a function body".to_string()) + body_mir(db, position).unwrap_or_else(|| "Not inside a function body".to_owned()) } fn body_mir(db: &RootDatabase, position: FilePosition) -> Option { diff --git a/crates/limit/src/lib.rs b/crates/limit/src/lib.rs index 7f4b00df0b..27471db6a3 100644 --- a/crates/limit/src/lib.rs +++ b/crates/limit/src/lib.rs @@ -55,13 +55,12 @@ impl Limit { if other <= old_max || old_max == 0 { break; } - if self - .max - .compare_exchange_weak(old_max, other, Ordering::Relaxed, Ordering::Relaxed) - .is_ok() - { - eprintln!("new max: {other}"); - } + _ = self.max.compare_exchange_weak( + old_max, + other, + Ordering::Relaxed, + Ordering::Relaxed, + ); } Ok(()) diff --git a/crates/load-cargo/src/lib.rs b/crates/load-cargo/src/lib.rs index 3878e20a2a..c6dc071c39 100644 --- a/crates/load-cargo/src/lib.rs +++ b/crates/load-cargo/src/lib.rs @@ -279,7 +279,7 @@ pub fn load_proc_macro( let dylib = MacroDylib::new(path.to_path_buf()); let vec = server.load_dylib(dylib).map_err(|e| format!("{e}"))?; if vec.is_empty() { - return Err("proc macro library returned no proc macros".to_string()); + return Err("proc macro library returned no proc macros".to_owned()); } Ok(vec .into_iter() @@ -382,7 +382,7 @@ impl ProcMacroExpander for Expander { call_site: Span, mixed_site: Span, ) -> Result, ProcMacroExpansionError> { - let env = env.iter().map(|(k, v)| (k.to_string(), v.to_string())).collect(); + let env = env.iter().map(|(k, v)| (k.to_owned(), v.to_owned())).collect(); match self.0.expand(subtree, attrs, env, def_site, call_site, mixed_site) { Ok(Ok(subtree)) => Ok(subtree), Ok(Err(err)) => Err(ProcMacroExpansionError::Panic(err.0)), diff --git a/crates/mbe/src/benchmark.rs b/crates/mbe/src/benchmark.rs index 6c3917b37f..d946ecc1ca 100644 --- a/crates/mbe/src/benchmark.rs +++ b/crates/mbe/src/benchmark.rs @@ -100,17 +100,19 @@ fn invocation_fixtures( // So we just skip any error cases and try again let mut try_cnt = 0; loop { - let mut subtree = tt::Subtree { + let mut token_trees = Vec::new(); + for op in rule.lhs.iter() { + collect_from_op(op, &mut token_trees, &mut seed); + } + + let subtree = tt::Subtree { delimiter: tt::Delimiter { open: DUMMY, close: DUMMY, kind: tt::DelimiterKind::Invisible, }, - token_trees: vec![], + token_trees: token_trees.into_boxed_slice(), }; - for op in rule.lhs.iter() { - collect_from_op(op, &mut subtree, &mut seed); - } if it.expand(&subtree, |_| (), true, DUMMY).err.is_none() { res.push((name.clone(), subtree)); break; @@ -127,45 +129,45 @@ fn invocation_fixtures( fn collect_from_op( op: &Op, - parent: &mut tt::Subtree, + token_trees: &mut Vec>, seed: &mut usize, ) { return match op { Op::Var { kind, .. } => match kind.as_ref() { - Some(MetaVarKind::Ident) => parent.token_trees.push(make_ident("foo")), - Some(MetaVarKind::Ty) => parent.token_trees.push(make_ident("Foo")), - Some(MetaVarKind::Tt) => parent.token_trees.push(make_ident("foo")), - Some(MetaVarKind::Vis) => parent.token_trees.push(make_ident("pub")), - Some(MetaVarKind::Pat) => parent.token_trees.push(make_ident("foo")), - Some(MetaVarKind::Path) => parent.token_trees.push(make_ident("foo")), - Some(MetaVarKind::Literal) => parent.token_trees.push(make_literal("1")), - Some(MetaVarKind::Expr) => parent.token_trees.push(make_ident("foo")), + Some(MetaVarKind::Ident) => token_trees.push(make_ident("foo")), + Some(MetaVarKind::Ty) => token_trees.push(make_ident("Foo")), + Some(MetaVarKind::Tt) => token_trees.push(make_ident("foo")), + Some(MetaVarKind::Vis) => token_trees.push(make_ident("pub")), + Some(MetaVarKind::Pat) => token_trees.push(make_ident("foo")), + Some(MetaVarKind::Path) => token_trees.push(make_ident("foo")), + Some(MetaVarKind::Literal) => token_trees.push(make_literal("1")), + Some(MetaVarKind::Expr) => token_trees.push(make_ident("foo")), Some(MetaVarKind::Lifetime) => { - parent.token_trees.push(make_punct('\'')); - parent.token_trees.push(make_ident("a")); + token_trees.push(make_punct('\'')); + token_trees.push(make_ident("a")); } Some(MetaVarKind::Block) => { - parent.token_trees.push(make_subtree(tt::DelimiterKind::Brace, None)) + token_trees.push(make_subtree(tt::DelimiterKind::Brace, None)) } Some(MetaVarKind::Item) => { - parent.token_trees.push(make_ident("fn")); - parent.token_trees.push(make_ident("foo")); - parent.token_trees.push(make_subtree(tt::DelimiterKind::Parenthesis, None)); - parent.token_trees.push(make_subtree(tt::DelimiterKind::Brace, None)); + token_trees.push(make_ident("fn")); + token_trees.push(make_ident("foo")); + token_trees.push(make_subtree(tt::DelimiterKind::Parenthesis, None)); + token_trees.push(make_subtree(tt::DelimiterKind::Brace, None)); } Some(MetaVarKind::Meta) => { - parent.token_trees.push(make_ident("foo")); - parent.token_trees.push(make_subtree(tt::DelimiterKind::Parenthesis, None)); + token_trees.push(make_ident("foo")); + token_trees.push(make_subtree(tt::DelimiterKind::Parenthesis, None)); } None => (), Some(kind) => panic!("Unhandled kind {kind:?}"), }, - Op::Literal(it) => parent.token_trees.push(tt::Leaf::from(it.clone()).into()), - Op::Ident(it) => parent.token_trees.push(tt::Leaf::from(it.clone()).into()), + Op::Literal(it) => token_trees.push(tt::Leaf::from(it.clone()).into()), + Op::Ident(it) => token_trees.push(tt::Leaf::from(it.clone()).into()), Op::Punct(puncts) => { for punct in puncts { - parent.token_trees.push(tt::Leaf::from(*punct).into()); + token_trees.push(tt::Leaf::from(*punct).into()); } } Op::Repeat { tokens, kind, separator } => { @@ -177,20 +179,20 @@ fn invocation_fixtures( }; for i in 0..cnt { for it in tokens.iter() { - collect_from_op(it, parent, seed); + collect_from_op(it, token_trees, seed); } if i + 1 != cnt { if let Some(sep) = separator { match sep { Separator::Literal(it) => { - parent.token_trees.push(tt::Leaf::Literal(it.clone()).into()) + token_trees.push(tt::Leaf::Literal(it.clone()).into()) } Separator::Ident(it) => { - parent.token_trees.push(tt::Leaf::Ident(it.clone()).into()) + token_trees.push(tt::Leaf::Ident(it.clone()).into()) } Separator::Puncts(puncts) => { for it in puncts { - parent.token_trees.push(tt::Leaf::Punct(*it).into()) + token_trees.push(tt::Leaf::Punct(*it).into()) } } }; @@ -199,11 +201,15 @@ fn invocation_fixtures( } } Op::Subtree { tokens, delimiter } => { - let mut subtree = tt::Subtree { delimiter: *delimiter, token_trees: Vec::new() }; + let mut subtree = Vec::new(); tokens.iter().for_each(|it| { collect_from_op(it, &mut subtree, seed); }); - parent.token_trees.push(subtree.into()); + + let subtree = + tt::Subtree { delimiter: *delimiter, token_trees: subtree.into_boxed_slice() }; + + token_trees.push(subtree.into()); } Op::Ignore { .. } | Op::Index { .. } | Op::Count { .. } | Op::Length { .. } => {} }; @@ -230,7 +236,7 @@ fn invocation_fixtures( ) -> tt::TokenTree { tt::Subtree { delimiter: tt::Delimiter { open: DUMMY, close: DUMMY, kind }, - token_trees: token_trees.unwrap_or_default(), + token_trees: token_trees.map(Vec::into_boxed_slice).unwrap_or_default(), } .into() } diff --git a/crates/mbe/src/expander.rs b/crates/mbe/src/expander.rs index 60483809dc..9366048fd9 100644 --- a/crates/mbe/src/expander.rs +++ b/crates/mbe/src/expander.rs @@ -58,7 +58,7 @@ pub(crate) fn expand_rules( ExpandResult::new( tt::Subtree { delimiter: tt::Delimiter::invisible_spanned(call_site), - token_trees: vec![], + token_trees: Box::new([]), }, ExpandError::NoMatchingRule, ) diff --git a/crates/mbe/src/expander/matcher.rs b/crates/mbe/src/expander/matcher.rs index 40b4c7cdd6..eea92cfba4 100644 --- a/crates/mbe/src/expander/matcher.rs +++ b/crates/mbe/src/expander/matcher.rs @@ -798,7 +798,7 @@ fn match_meta_var( tt.map(|tt| match tt { tt::TokenTree::Leaf(leaf) => tt::Subtree { delimiter: tt::Delimiter::invisible_spanned(*leaf.span()), - token_trees: vec![leaf.into()], + token_trees: Box::new([leaf.into()]), }, tt::TokenTree::Subtree(mut s) => { if s.delimiter.kind == tt::DelimiterKind::Invisible { @@ -832,7 +832,7 @@ fn match_meta_var( None => lit.into(), Some(neg) => tt::TokenTree::Subtree(tt::Subtree { delimiter: tt::Delimiter::invisible_spanned(*literal.span()), - token_trees: vec![neg, lit.into()], + token_trees: Box::new([neg, lit.into()]), }), } }) @@ -989,10 +989,10 @@ impl TtIter<'_, S> { close: ident.span, kind: tt::DelimiterKind::Invisible, }, - token_trees: vec![ + token_trees: Box::new([ tt::Leaf::Punct(*punct).into(), tt::Leaf::Ident(ident.clone()).into(), - ], + ]), } .into()) } diff --git a/crates/mbe/src/expander/transcriber.rs b/crates/mbe/src/expander/transcriber.rs index 800bc994ad..9291f799cc 100644 --- a/crates/mbe/src/expander/transcriber.rs +++ b/crates/mbe/src/expander/transcriber.rs @@ -83,7 +83,7 @@ impl Bindings { close: span, kind: tt::DelimiterKind::Brace, }, - token_trees: vec![], + token_trees: Box::new([]), })), // FIXME: Meta and Item should get proper defaults MetaVarKind::Meta | MetaVarKind::Item | MetaVarKind::Tt | MetaVarKind::Vis => { @@ -349,11 +349,11 @@ fn expand_var( // We just treat it a normal tokens let tt = tt::Subtree { delimiter: tt::Delimiter::invisible_spanned(id), - token_trees: vec![ + token_trees: Box::new([ tt::Leaf::from(tt::Punct { char: '$', spacing: tt::Spacing::Alone, span: id }) .into(), tt::Leaf::from(tt::Ident { text: v.clone(), span: id }).into(), - ], + ]), } .into(); ExpandResult::ok(Fragment::Tokens(tt)) @@ -406,7 +406,7 @@ fn expand_repeat( value: Fragment::Tokens( tt::Subtree { delimiter: tt::Delimiter::invisible_spanned(ctx.call_site), - token_trees: vec![], + token_trees: Box::new([]), } .into(), ), @@ -455,7 +455,7 @@ fn expand_repeat( // e.g {Delimiter:None> ['>'] /Delimiter:None>} let tt = tt::Subtree { delimiter: tt::Delimiter::invisible_spanned(ctx.call_site), - token_trees: buf, + token_trees: buf.into_boxed_slice(), } .into(); @@ -486,7 +486,7 @@ fn push_fragment( fn push_subtree(buf: &mut Vec>, tt: tt::Subtree) { match tt.delimiter.kind { - tt::DelimiterKind::Invisible => buf.extend(tt.token_trees), + tt::DelimiterKind::Invisible => buf.extend(Vec::from(tt.token_trees)), _ => buf.push(tt.into()), } } @@ -504,7 +504,7 @@ fn fix_up_and_push_path_tt( // Note that we only need to fix up the top-level `TokenTree`s because the // context of the paths in the descendant `Subtree`s won't be changed by the // mbe transcription. - for tt in subtree.token_trees { + for tt in Vec::from(subtree.token_trees) { if prev_was_ident { // Pedantically, `(T) -> U` in `FnOnce(T) -> U` is treated as a generic // argument list and thus needs `::` between it and `FnOnce`. However in diff --git a/crates/mbe/src/syntax_bridge.rs b/crates/mbe/src/syntax_bridge.rs index d6c3bd1892..bfc5d197f6 100644 --- a/crates/mbe/src/syntax_bridge.rs +++ b/crates/mbe/src/syntax_bridge.rs @@ -130,7 +130,7 @@ where tt::Subtree { delimiter: tt::Delimiter { kind: tt::DelimiterKind::Invisible, .. }, token_trees, - } => TokenBuffer::from_tokens(token_trees.as_slice()), + } => TokenBuffer::from_tokens(token_trees), _ => TokenBuffer::from_subtree(tt), }; let parser_input = to_parser_input(&buffer); @@ -146,7 +146,7 @@ where } parser::Step::Enter { kind } => tree_sink.start_node(kind), parser::Step::Exit => tree_sink.finish_node(), - parser::Step::Error { msg } => tree_sink.error(msg.to_string()), + parser::Step::Error { msg } => tree_sink.error(msg.to_owned()), } } tree_sink.finish() @@ -227,14 +227,14 @@ where C: TokenConverter, S: Span, { - let entry = tt::Subtree { + let entry = tt::SubtreeBuilder { delimiter: tt::Delimiter::invisible_spanned(conv.call_site()), token_trees: vec![], }; let mut stack = NonEmptyVec::new(entry); while let Some((token, abs_range)) = conv.bump() { - let tt::Subtree { delimiter, token_trees } = stack.last_mut(); + let tt::SubtreeBuilder { delimiter, token_trees } = stack.last_mut(); let tt = match token.as_leaf() { Some(leaf) => tt::TokenTree::Leaf(leaf.clone()), @@ -260,7 +260,7 @@ where if matches!(expected, Some(expected) if expected == kind) { if let Some(mut subtree) = stack.pop() { subtree.delimiter.close = conv.span_for(abs_range); - stack.last_mut().token_trees.push(subtree.into()); + stack.last_mut().token_trees.push(subtree.build().into()); } continue; } @@ -275,7 +275,7 @@ where // Start a new subtree if let Some(kind) = delim { let open = conv.span_for(abs_range); - stack.push(tt::Subtree { + stack.push(tt::SubtreeBuilder { delimiter: tt::Delimiter { open, // will be overwritten on subtree close above @@ -361,7 +361,7 @@ where parent.token_trees.extend(entry.token_trees); } - let subtree = stack.into_last(); + let subtree = stack.into_last().build(); if let [tt::TokenTree::Subtree(first)] = &*subtree.token_trees { first.clone() } else { @@ -454,7 +454,7 @@ fn convert_doc_comment( }; // Make `doc="\" Comments\"" - let meta_tkns = vec![mk_ident("doc"), mk_punct('='), mk_doc_literal(&comment)]; + let meta_tkns = Box::new([mk_ident("doc"), mk_punct('='), mk_doc_literal(&comment)]); // Make `#![]` let mut token_trees = Vec::with_capacity(3); diff --git a/crates/mbe/src/syntax_bridge/tests.rs b/crates/mbe/src/syntax_bridge/tests.rs index e5569138db..11d1a72879 100644 --- a/crates/mbe/src/syntax_bridge/tests.rs +++ b/crates/mbe/src/syntax_bridge/tests.rs @@ -1,5 +1,4 @@ -use std::collections::HashMap; - +use rustc_hash::FxHashMap; use syntax::{ast, AstNode}; use test_utils::extract_annotations; use tt::{ @@ -12,7 +11,7 @@ use crate::{syntax_node_to_token_tree, DummyTestSpanData, DummyTestSpanMap, DUMM fn check_punct_spacing(fixture: &str) { let source_file = ast::SourceFile::parse(fixture).ok().unwrap(); let subtree = syntax_node_to_token_tree(source_file.syntax(), DummyTestSpanMap, DUMMY); - let mut annotations: HashMap<_, _> = extract_annotations(fixture) + let mut annotations: FxHashMap<_, _> = extract_annotations(fixture) .into_iter() .map(|(range, annotation)| { let spacing = match annotation.as_str() { diff --git a/crates/mbe/src/tt_iter.rs b/crates/mbe/src/tt_iter.rs index 71513ef439..f9913cb6f9 100644 --- a/crates/mbe/src/tt_iter.rs +++ b/crates/mbe/src/tt_iter.rs @@ -180,7 +180,7 @@ impl<'a, S: Span> TtIter<'a, S> { [] | [_] => res.pop(), [first, ..] => Some(tt::TokenTree::Subtree(tt::Subtree { delimiter: tt::Delimiter::invisible_spanned(first.first_span()), - token_trees: res, + token_trees: res.into_boxed_slice(), })), }; ExpandResult { value: res, err } diff --git a/crates/parser/src/grammar.rs b/crates/parser/src/grammar.rs index 53fda3ae4f..34715628f1 100644 --- a/crates/parser/src/grammar.rs +++ b/crates/parser/src/grammar.rs @@ -393,11 +393,26 @@ fn delimited( bra: SyntaxKind, ket: SyntaxKind, delim: SyntaxKind, + unexpected_delim_message: impl Fn() -> String, first_set: TokenSet, mut parser: impl FnMut(&mut Parser<'_>) -> bool, ) { p.bump(bra); while !p.at(ket) && !p.at(EOF) { + if p.at(delim) { + // Recover if an argument is missing and only got a delimiter, + // e.g. `(a, , b)`. + + // Wrap the erroneous delimiter in an error node so that fixup logic gets rid of it. + // FIXME: Ideally this should be handled in fixup in a structured way, but our list + // nodes currently have no concept of a missing node between two delimiters. + // So doing it this way is easier. + let m = p.start(); + p.error(unexpected_delim_message()); + p.bump(delim); + m.complete(p, ERROR); + continue; + } if !parser(p) { break; } diff --git a/crates/parser/src/grammar/expressions.rs b/crates/parser/src/grammar/expressions.rs index c862611114..f40c515fa0 100644 --- a/crates/parser/src/grammar/expressions.rs +++ b/crates/parser/src/grammar/expressions.rs @@ -530,6 +530,15 @@ fn method_call_expr( generic_args::opt_generic_arg_list(p, true); if p.at(T!['(']) { arg_list(p); + } else { + // emit an error when argument list is missing + + // test_err method_call_missing_argument_list + // fn func() { + // foo.bar::<> + // foo.bar::; + // } + p.error("expected argument list"); } m.complete(p, METHOD_CALL_EXPR) } @@ -602,6 +611,7 @@ fn cast_expr(p: &mut Parser<'_>, lhs: CompletedMarker) -> CompletedMarker { // foo(bar::); // foo(bar:); // foo(bar+); +// foo(a, , b); // } fn arg_list(p: &mut Parser<'_>) { assert!(p.at(T!['('])); @@ -615,8 +625,9 @@ fn arg_list(p: &mut Parser<'_>) { T!['('], T![')'], T![,], + || "expected expression".into(), EXPR_FIRST.union(ATTRIBUTE_FIRST), - |p: &mut Parser<'_>| expr(p).is_some(), + |p| expr(p).is_some(), ); m.complete(p, ARG_LIST); } diff --git a/crates/parser/src/grammar/generic_args.rs b/crates/parser/src/grammar/generic_args.rs index 211af98e6e..249be2a333 100644 --- a/crates/parser/src/grammar/generic_args.rs +++ b/crates/parser/src/grammar/generic_args.rs @@ -1,5 +1,7 @@ use super::*; +// test_err generic_arg_list_recover +// type T = T<0, ,T>; pub(super) fn opt_generic_arg_list(p: &mut Parser<'_>, colon_colon_required: bool) { let m; if p.at(T![::]) && p.nth(2) == T![<] { @@ -11,7 +13,15 @@ pub(super) fn opt_generic_arg_list(p: &mut Parser<'_>, colon_colon_required: boo return; } - delimited(p, T![<], T![>], T![,], GENERIC_ARG_FIRST, generic_arg); + delimited( + p, + T![<], + T![>], + T![,], + || "expected generic argument".into(), + GENERIC_ARG_FIRST, + generic_arg, + ); m.complete(p, GENERIC_ARG_LIST); } diff --git a/crates/parser/src/grammar/generic_params.rs b/crates/parser/src/grammar/generic_params.rs index 29d9b05d3f..3c577aa3cb 100644 --- a/crates/parser/src/grammar/generic_params.rs +++ b/crates/parser/src/grammar/generic_params.rs @@ -10,16 +10,27 @@ pub(super) fn opt_generic_param_list(p: &mut Parser<'_>) { // test generic_param_list // fn f() {} + +// test_err generic_param_list_recover +// fn f() {} fn generic_param_list(p: &mut Parser<'_>) { assert!(p.at(T![<])); let m = p.start(); - delimited(p, T![<], T![>], T![,], GENERIC_PARAM_FIRST.union(ATTRIBUTE_FIRST), |p| { - // test generic_param_attribute - // fn foo<#[lt_attr] 'a, #[t_attr] T>() {} - let m = p.start(); - attributes::outer_attrs(p); - generic_param(p, m) - }); + delimited( + p, + T![<], + T![>], + T![,], + || "expected generic parameter".into(), + GENERIC_PARAM_FIRST.union(ATTRIBUTE_FIRST), + |p| { + // test generic_param_attribute + // fn foo<#[lt_attr] 'a, #[t_attr] T>() {} + let m = p.start(); + attributes::outer_attrs(p); + generic_param(p, m) + }, + ); m.complete(p, GENERIC_PARAM_LIST); } diff --git a/crates/parser/src/grammar/items/adt.rs b/crates/parser/src/grammar/items/adt.rs index 17f41b8e13..21078175c0 100644 --- a/crates/parser/src/grammar/items/adt.rs +++ b/crates/parser/src/grammar/items/adt.rs @@ -146,28 +146,39 @@ pub(crate) fn record_field_list(p: &mut Parser<'_>) { const TUPLE_FIELD_FIRST: TokenSet = types::TYPE_FIRST.union(ATTRIBUTE_FIRST).union(VISIBILITY_FIRST); +// test_err tuple_field_list_recovery +// struct S(struct S; +// struct S(A,,B); fn tuple_field_list(p: &mut Parser<'_>) { assert!(p.at(T!['('])); let m = p.start(); - delimited(p, T!['('], T![')'], T![,], TUPLE_FIELD_FIRST, |p| { - let m = p.start(); - // test tuple_field_attrs - // struct S (#[attr] f32); - attributes::outer_attrs(p); - let has_vis = opt_visibility(p, true); - if !p.at_ts(types::TYPE_FIRST) { - p.error("expected a type"); - if has_vis { - m.complete(p, ERROR); - } else { - m.abandon(p); + delimited( + p, + T!['('], + T![')'], + T![,], + || "expected tuple field".into(), + TUPLE_FIELD_FIRST, + |p| { + let m = p.start(); + // test tuple_field_attrs + // struct S (#[attr] f32); + attributes::outer_attrs(p); + let has_vis = opt_visibility(p, true); + if !p.at_ts(types::TYPE_FIRST) { + p.error("expected a type"); + if has_vis { + m.complete(p, ERROR); + } else { + m.abandon(p); + } + return false; } - return false; - } - types::type_(p); - m.complete(p, TUPLE_FIELD); - true - }); + types::type_(p); + m.complete(p, TUPLE_FIELD); + true + }, + ); m.complete(p, TUPLE_FIELD_LIST); } diff --git a/crates/parser/src/grammar/items/use_item.rs b/crates/parser/src/grammar/items/use_item.rs index f689c06b31..675a1fd465 100644 --- a/crates/parser/src/grammar/items/use_item.rs +++ b/crates/parser/src/grammar/items/use_item.rs @@ -93,9 +93,16 @@ pub(crate) fn use_tree_list(p: &mut Parser<'_>) { // use b; // struct T; // fn test() {} - delimited(p, T!['{'], T!['}'], T![,], USE_TREE_LIST_FIRST_SET, |p: &mut Parser<'_>| { - use_tree(p, false) || p.at_ts(USE_TREE_LIST_RECOVERY_SET) - }); + // use {a ,, b}; + delimited( + p, + T!['{'], + T!['}'], + T![,], + || "expected use tree".into(), + USE_TREE_LIST_FIRST_SET, + |p: &mut Parser<'_>| use_tree(p, false) || p.at_ts(USE_TREE_LIST_RECOVERY_SET), + ); m.complete(p, USE_TREE_LIST); } diff --git a/crates/parser/src/lexed_str.rs b/crates/parser/src/lexed_str.rs index bf1feb9a7e..2da9184693 100644 --- a/crates/parser/src/lexed_str.rs +++ b/crates/parser/src/lexed_str.rs @@ -149,7 +149,7 @@ impl<'a> Converter<'a> { if let Some(err) = err { let token = self.res.len() as u32; - let msg = err.to_string(); + let msg = err.to_owned(); self.res.error.push(LexError { msg, token }); } } diff --git a/crates/parser/src/tests/sourcegen_inline_tests.rs b/crates/parser/src/tests/sourcegen_inline_tests.rs index bd9e188e4d..5a71bfd82b 100644 --- a/crates/parser/src/tests/sourcegen_inline_tests.rs +++ b/crates/parser/src/tests/sourcegen_inline_tests.rs @@ -1,5 +1,6 @@ //! This module greps parser's code for specially formatted comments and turns //! them into tests. +#![allow(clippy::disallowed_types, clippy::print_stdout)] use std::{ collections::HashMap, @@ -59,9 +60,9 @@ fn collect_tests(s: &str) -> Vec { for comment_block in sourcegen::CommentBlock::extract_untagged(s) { let first_line = &comment_block.contents[0]; let (name, ok) = if let Some(name) = first_line.strip_prefix("test ") { - (name.to_string(), true) + (name.to_owned(), true) } else if let Some(name) = first_line.strip_prefix("test_err ") { - (name.to_string(), false) + (name.to_owned(), false) } else { continue; }; diff --git a/crates/parser/test_data/parser/inline/err/0015_arg_list_recovery.rast b/crates/parser/test_data/parser/inline/err/0015_arg_list_recovery.rast index 5d0fe859c2..cd5aa680c6 100644 --- a/crates/parser/test_data/parser/inline/err/0015_arg_list_recovery.rast +++ b/crates/parser/test_data/parser/inline/err/0015_arg_list_recovery.rast @@ -68,6 +68,33 @@ SOURCE_FILE PLUS "+" R_PAREN ")" SEMICOLON ";" + WHITESPACE "\n " + EXPR_STMT + CALL_EXPR + PATH_EXPR + PATH + PATH_SEGMENT + NAME_REF + IDENT "foo" + ARG_LIST + L_PAREN "(" + PATH_EXPR + PATH + PATH_SEGMENT + NAME_REF + IDENT "a" + COMMA "," + WHITESPACE " " + ERROR + COMMA "," + WHITESPACE " " + PATH_EXPR + PATH + PATH_SEGMENT + NAME_REF + IDENT "b" + R_PAREN ")" + SEMICOLON ";" WHITESPACE "\n" R_CURLY "}" WHITESPACE "\n" @@ -75,3 +102,4 @@ error 25: expected identifier error 39: expected COMMA error 39: expected expression error 55: expected expression +error 69: expected expression diff --git a/crates/parser/test_data/parser/inline/err/0015_arg_list_recovery.rs b/crates/parser/test_data/parser/inline/err/0015_arg_list_recovery.rs index 0e7ac9cc30..175a31f8b5 100644 --- a/crates/parser/test_data/parser/inline/err/0015_arg_list_recovery.rs +++ b/crates/parser/test_data/parser/inline/err/0015_arg_list_recovery.rs @@ -2,4 +2,5 @@ fn main() { foo(bar::); foo(bar:); foo(bar+); + foo(a, , b); } diff --git a/crates/parser/test_data/parser/inline/err/0026_use_tree_list_err_recovery.rast b/crates/parser/test_data/parser/inline/err/0026_use_tree_list_err_recovery.rast index cb90b093ba..b576d872e1 100644 --- a/crates/parser/test_data/parser/inline/err/0026_use_tree_list_err_recovery.rast +++ b/crates/parser/test_data/parser/inline/err/0026_use_tree_list_err_recovery.rast @@ -43,4 +43,29 @@ SOURCE_FILE L_CURLY "{" R_CURLY "}" WHITESPACE "\n" + USE + USE_KW "use" + WHITESPACE " " + USE_TREE + USE_TREE_LIST + L_CURLY "{" + USE_TREE + PATH + PATH_SEGMENT + NAME_REF + IDENT "a" + WHITESPACE " " + COMMA "," + ERROR + COMMA "," + WHITESPACE " " + USE_TREE + PATH + PATH_SEGMENT + NAME_REF + IDENT "b" + R_CURLY "}" + SEMICOLON ";" + WHITESPACE "\n" error 6: expected R_CURLY +error 46: expected use tree diff --git a/crates/parser/test_data/parser/inline/err/0026_use_tree_list_err_recovery.rs b/crates/parser/test_data/parser/inline/err/0026_use_tree_list_err_recovery.rs index f16959c25f..9885e6ab27 100644 --- a/crates/parser/test_data/parser/inline/err/0026_use_tree_list_err_recovery.rs +++ b/crates/parser/test_data/parser/inline/err/0026_use_tree_list_err_recovery.rs @@ -2,3 +2,4 @@ use {a; use b; struct T; fn test() {} +use {a ,, b}; diff --git a/crates/parser/test_data/parser/inline/err/0028_method_call_missing_argument_list.rast b/crates/parser/test_data/parser/inline/err/0028_method_call_missing_argument_list.rast new file mode 100644 index 0000000000..02544b5e53 --- /dev/null +++ b/crates/parser/test_data/parser/inline/err/0028_method_call_missing_argument_list.rast @@ -0,0 +1,56 @@ +SOURCE_FILE + FN + FN_KW "fn" + WHITESPACE " " + NAME + IDENT "func" + PARAM_LIST + L_PAREN "(" + R_PAREN ")" + WHITESPACE " " + BLOCK_EXPR + STMT_LIST + L_CURLY "{" + WHITESPACE "\n " + EXPR_STMT + METHOD_CALL_EXPR + PATH_EXPR + PATH + PATH_SEGMENT + NAME_REF + IDENT "foo" + DOT "." + NAME_REF + IDENT "bar" + GENERIC_ARG_LIST + COLON2 "::" + L_ANGLE "<" + R_ANGLE ">" + WHITESPACE "\n " + EXPR_STMT + METHOD_CALL_EXPR + PATH_EXPR + PATH + PATH_SEGMENT + NAME_REF + IDENT "foo" + DOT "." + NAME_REF + IDENT "bar" + GENERIC_ARG_LIST + COLON2 "::" + L_ANGLE "<" + TYPE_ARG + PATH_TYPE + PATH + PATH_SEGMENT + NAME_REF + IDENT "i32" + R_ANGLE ">" + SEMICOLON ";" + WHITESPACE "\n" + R_CURLY "}" + WHITESPACE "\n" +error 27: expected argument list +error 27: expected SEMICOLON +error 46: expected argument list diff --git a/crates/parser/test_data/parser/inline/err/0028_method_call_missing_argument_list.rs b/crates/parser/test_data/parser/inline/err/0028_method_call_missing_argument_list.rs new file mode 100644 index 0000000000..7c8baecaa9 --- /dev/null +++ b/crates/parser/test_data/parser/inline/err/0028_method_call_missing_argument_list.rs @@ -0,0 +1,4 @@ +fn func() { + foo.bar::<> + foo.bar::; +} diff --git a/crates/parser/test_data/parser/inline/err/0029_tuple_field_list_recovery.rast b/crates/parser/test_data/parser/inline/err/0029_tuple_field_list_recovery.rast new file mode 100644 index 0000000000..6b0bfa007e --- /dev/null +++ b/crates/parser/test_data/parser/inline/err/0029_tuple_field_list_recovery.rast @@ -0,0 +1,44 @@ +SOURCE_FILE + STRUCT + STRUCT_KW "struct" + WHITESPACE " " + NAME + IDENT "S" + TUPLE_FIELD_LIST + L_PAREN "(" + STRUCT + STRUCT_KW "struct" + WHITESPACE " " + NAME + IDENT "S" + SEMICOLON ";" + WHITESPACE "\n" + STRUCT + STRUCT_KW "struct" + WHITESPACE " " + NAME + IDENT "S" + TUPLE_FIELD_LIST + L_PAREN "(" + TUPLE_FIELD + PATH_TYPE + PATH + PATH_SEGMENT + NAME_REF + IDENT "A" + COMMA "," + ERROR + COMMA "," + TUPLE_FIELD + PATH_TYPE + PATH + PATH_SEGMENT + NAME_REF + IDENT "B" + R_PAREN ")" + SEMICOLON ";" + WHITESPACE "\n" +error 9: expected a type +error 9: expected R_PAREN +error 9: expected SEMICOLON +error 30: expected tuple field diff --git a/crates/parser/test_data/parser/inline/err/0029_tuple_field_list_recovery.rs b/crates/parser/test_data/parser/inline/err/0029_tuple_field_list_recovery.rs new file mode 100644 index 0000000000..ecb4d8bda1 --- /dev/null +++ b/crates/parser/test_data/parser/inline/err/0029_tuple_field_list_recovery.rs @@ -0,0 +1,2 @@ +struct S(struct S; +struct S(A,,B); diff --git a/crates/parser/test_data/parser/inline/err/0030_generic_arg_list_recover.rast b/crates/parser/test_data/parser/inline/err/0030_generic_arg_list_recover.rast new file mode 100644 index 0000000000..4cf5a3386b --- /dev/null +++ b/crates/parser/test_data/parser/inline/err/0030_generic_arg_list_recover.rast @@ -0,0 +1,33 @@ +SOURCE_FILE + TYPE_ALIAS + TYPE_KW "type" + WHITESPACE " " + NAME + IDENT "T" + WHITESPACE " " + EQ "=" + WHITESPACE " " + PATH_TYPE + PATH + PATH_SEGMENT + NAME_REF + IDENT "T" + GENERIC_ARG_LIST + L_ANGLE "<" + CONST_ARG + LITERAL + INT_NUMBER "0" + COMMA "," + WHITESPACE " " + ERROR + COMMA "," + TYPE_ARG + PATH_TYPE + PATH + PATH_SEGMENT + NAME_REF + IDENT "T" + R_ANGLE ">" + SEMICOLON ";" + WHITESPACE "\n" +error 14: expected generic argument diff --git a/crates/parser/test_data/parser/inline/err/0030_generic_arg_list_recover.rs b/crates/parser/test_data/parser/inline/err/0030_generic_arg_list_recover.rs new file mode 100644 index 0000000000..7d849aa1be --- /dev/null +++ b/crates/parser/test_data/parser/inline/err/0030_generic_arg_list_recover.rs @@ -0,0 +1 @@ +type T = T<0, ,T>; diff --git a/crates/parser/test_data/parser/inline/err/0031_generic_param_list_recover.rast b/crates/parser/test_data/parser/inline/err/0031_generic_param_list_recover.rast new file mode 100644 index 0000000000..0a1ed01fbe --- /dev/null +++ b/crates/parser/test_data/parser/inline/err/0031_generic_param_list_recover.rast @@ -0,0 +1,45 @@ +SOURCE_FILE + FN + FN_KW "fn" + WHITESPACE " " + NAME + IDENT "f" + GENERIC_PARAM_LIST + L_ANGLE "<" + TYPE_PARAM + NAME + IDENT "T" + COLON ":" + WHITESPACE " " + TYPE_BOUND_LIST + TYPE_BOUND + PATH_TYPE + PATH + PATH_SEGMENT + NAME_REF + IDENT "Clone" + COMMA "," + ERROR + COMMA "," + WHITESPACE " " + TYPE_PARAM + NAME + IDENT "U" + COLON ":" + TYPE_BOUND_LIST + COMMA "," + WHITESPACE " " + TYPE_PARAM + NAME + IDENT "V" + R_ANGLE ">" + PARAM_LIST + L_PAREN "(" + R_PAREN ")" + WHITESPACE " " + BLOCK_EXPR + STMT_LIST + L_CURLY "{" + R_CURLY "}" + WHITESPACE "\n" +error 14: expected generic parameter diff --git a/crates/parser/test_data/parser/inline/err/0031_generic_param_list_recover.rs b/crates/parser/test_data/parser/inline/err/0031_generic_param_list_recover.rs new file mode 100644 index 0000000000..2b5149bb0d --- /dev/null +++ b/crates/parser/test_data/parser/inline/err/0031_generic_param_list_recover.rs @@ -0,0 +1 @@ +fn f() {} diff --git a/crates/proc-macro-api/Cargo.toml b/crates/proc-macro-api/Cargo.toml index 49a0979f4f..cf01b94c0a 100644 --- a/crates/proc-macro-api/Cargo.toml +++ b/crates/proc-macro-api/Cargo.toml @@ -23,6 +23,7 @@ serde.workspace = true serde_json = { workspace = true, features = ["unbounded_depth"] } tracing.workspace = true triomphe.workspace = true +rustc-hash.workspace = true memmap2 = "0.5.4" snap = "1.1.0" indexmap = "2.1.0" @@ -40,4 +41,4 @@ base-db.workspace = true la-arena.workspace = true [lints] -workspace = true \ No newline at end of file +workspace = true diff --git a/crates/proc-macro-api/src/lib.rs b/crates/proc-macro-api/src/lib.rs index 379d184dd6..1dadfc40ac 100644 --- a/crates/proc-macro-api/src/lib.rs +++ b/crates/proc-macro-api/src/lib.rs @@ -197,7 +197,7 @@ impl ProcMacro { &deserialize_span_data_index_map(&resp.span_data_table), ) })), - _ => Err(ServerError { message: "unexpected response".to_string(), io: None }), + _ => Err(ServerError { message: "unexpected response".to_owned(), io: None }), } } } diff --git a/crates/proc-macro-api/src/msg.rs b/crates/proc-macro-api/src/msg.rs index e28fe387b8..aa5aff455f 100644 --- a/crates/proc-macro-api/src/msg.rs +++ b/crates/proc-macro-api/src/msg.rs @@ -187,7 +187,67 @@ mod tests { file_id: FileId::from_raw(0), ast_id: ErasedFileAstId::from_raw(RawIdx::from(0)), }; - let mut subtree = Subtree { + + let token_trees = Box::new([ + TokenTree::Leaf( + Ident { + text: "struct".into(), + span: Span { + range: TextRange::at(TextSize::new(0), TextSize::of("struct")), + anchor, + ctx: SyntaxContextId::ROOT, + }, + } + .into(), + ), + TokenTree::Leaf( + Ident { + text: "Foo".into(), + span: Span { + range: TextRange::at(TextSize::new(5), TextSize::of("Foo")), + anchor, + ctx: SyntaxContextId::ROOT, + }, + } + .into(), + ), + TokenTree::Leaf(Leaf::Literal(Literal { + text: "Foo".into(), + + span: Span { + range: TextRange::at(TextSize::new(8), TextSize::of("Foo")), + anchor, + ctx: SyntaxContextId::ROOT, + }, + })), + TokenTree::Leaf(Leaf::Punct(Punct { + char: '@', + span: Span { + range: TextRange::at(TextSize::new(11), TextSize::of('@')), + anchor, + ctx: SyntaxContextId::ROOT, + }, + spacing: Spacing::Joint, + })), + TokenTree::Subtree(Subtree { + delimiter: Delimiter { + open: Span { + range: TextRange::at(TextSize::new(12), TextSize::of('{')), + anchor, + ctx: SyntaxContextId::ROOT, + }, + close: Span { + range: TextRange::at(TextSize::new(13), TextSize::of('}')), + anchor, + ctx: SyntaxContextId::ROOT, + }, + kind: DelimiterKind::Brace, + }, + token_trees: Box::new([]), + }), + ]); + + Subtree { delimiter: Delimiter { open: Span { range: TextRange::empty(TextSize::new(0)), @@ -201,65 +261,8 @@ mod tests { }, kind: DelimiterKind::Invisible, }, - token_trees: Vec::new(), - }; - subtree.token_trees.push(TokenTree::Leaf( - Ident { - text: "struct".into(), - span: Span { - range: TextRange::at(TextSize::new(0), TextSize::of("struct")), - anchor, - ctx: SyntaxContextId::ROOT, - }, - } - .into(), - )); - subtree.token_trees.push(TokenTree::Leaf( - Ident { - text: "Foo".into(), - span: Span { - range: TextRange::at(TextSize::new(5), TextSize::of("Foo")), - anchor, - ctx: SyntaxContextId::ROOT, - }, - } - .into(), - )); - subtree.token_trees.push(TokenTree::Leaf(Leaf::Literal(Literal { - text: "Foo".into(), - - span: Span { - range: TextRange::at(TextSize::new(8), TextSize::of("Foo")), - anchor, - ctx: SyntaxContextId::ROOT, - }, - }))); - subtree.token_trees.push(TokenTree::Leaf(Leaf::Punct(Punct { - char: '@', - span: Span { - range: TextRange::at(TextSize::new(11), TextSize::of('@')), - anchor, - ctx: SyntaxContextId::ROOT, - }, - spacing: Spacing::Joint, - }))); - subtree.token_trees.push(TokenTree::Subtree(Subtree { - delimiter: Delimiter { - open: Span { - range: TextRange::at(TextSize::new(12), TextSize::of('{')), - anchor, - ctx: SyntaxContextId::ROOT, - }, - close: Span { - range: TextRange::at(TextSize::new(13), TextSize::of('}')), - anchor, - ctx: SyntaxContextId::ROOT, - }, - kind: DelimiterKind::Brace, - }, - token_trees: vec![], - })); - subtree + token_trees, + } } #[test] diff --git a/crates/proc-macro-api/src/msg/flat.rs b/crates/proc-macro-api/src/msg/flat.rs index 8dfaba5262..caf9e237fd 100644 --- a/crates/proc-macro-api/src/msg/flat.rs +++ b/crates/proc-macro-api/src/msg/flat.rs @@ -35,10 +35,11 @@ //! as we don't have bincode in Cargo.toml yet, lets stick with serde_json for //! the time being. -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; use indexmap::IndexSet; use la_arena::RawIdx; +use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use span::{ErasedFileAstId, FileId, Span, SpanAnchor, SyntaxContextId}; use text_size::TextRange; @@ -129,7 +130,7 @@ impl FlatTree { span_data_table: &mut SpanDataIndexMap, ) -> FlatTree { let mut w = Writer { - string_table: HashMap::new(), + string_table: FxHashMap::default(), work: VecDeque::new(), span_data_table, @@ -158,7 +159,7 @@ impl FlatTree { pub fn new_raw(subtree: &tt::Subtree, version: u32) -> FlatTree { let mut w = Writer { - string_table: HashMap::new(), + string_table: FxHashMap::default(), work: VecDeque::new(), span_data_table: &mut (), @@ -340,7 +341,7 @@ impl InternableSpan for Span { struct Writer<'a, 'span, S: InternableSpan> { work: VecDeque<(usize, &'a tt::Subtree)>, - string_table: HashMap<&'a str, u32>, + string_table: FxHashMap<&'a str, u32>, span_data_table: &'span mut S::Table, subtree: Vec, @@ -370,7 +371,7 @@ impl<'a, 'span, S: InternableSpan> Writer<'a, 'span, S> { self.subtree[idx].tt = [first_tt as u32, (first_tt + n_tt) as u32]; - for child in &subtree.token_trees { + for child in subtree.token_trees.iter() { let idx_tag = match child { tt::TokenTree::Subtree(it) => { let idx = self.enqueue(it); @@ -418,7 +419,7 @@ impl<'a, 'span, S: InternableSpan> Writer<'a, 'span, S> { let table = &mut self.text; *self.string_table.entry(text).or_insert_with(|| { let idx = table.len(); - table.push(text.to_string()); + table.push(text.to_owned()); idx as u32 }) } diff --git a/crates/proc-macro-api/src/process.rs b/crates/proc-macro-api/src/process.rs index 5ce601bce6..96f97bf5e2 100644 --- a/crates/proc-macro-api/src/process.rs +++ b/crates/proc-macro-api/src/process.rs @@ -78,7 +78,7 @@ impl ProcMacroProcessSrv { match response { Response::ApiVersionCheck(version) => Ok(version), - _ => Err(ServerError { message: "unexpected response".to_string(), io: None }), + _ => Err(ServerError { message: "unexpected response".to_owned(), io: None }), } } @@ -90,7 +90,7 @@ impl ProcMacroProcessSrv { match response { Response::SetConfig(crate::msg::ServerConfig { span_mode }) => Ok(span_mode), - _ => Err(ServerError { message: "unexpected response".to_string(), io: None }), + _ => Err(ServerError { message: "unexpected response".to_owned(), io: None }), } } @@ -104,7 +104,7 @@ impl ProcMacroProcessSrv { match response { Response::ListMacros(it) => Ok(it), - _ => Err(ServerError { message: "unexpected response".to_string(), io: None }), + _ => Err(ServerError { message: "unexpected response".to_owned(), io: None }), } } diff --git a/crates/proc-macro-api/src/version.rs b/crates/proc-macro-api/src/version.rs index 5f81c0a96d..f768de3e31 100644 --- a/crates/proc-macro-api/src/version.rs +++ b/crates/proc-macro-api/src/version.rs @@ -38,7 +38,7 @@ pub fn read_dylib_info(dylib_path: &AbsPath) -> io::Result { let version_part = items.next().ok_or_else(|| err!("no version string"))?; let mut version_parts = version_part.split('-'); let version = version_parts.next().ok_or_else(|| err!("no version"))?; - let channel = version_parts.next().unwrap_or_default().to_string(); + let channel = version_parts.next().unwrap_or_default().to_owned(); let commit = match items.next() { Some(commit) => { diff --git a/crates/proc-macro-srv-cli/src/main.rs b/crates/proc-macro-srv-cli/src/main.rs index a36200cdb4..df0ae3171f 100644 --- a/crates/proc-macro-srv-cli/src/main.rs +++ b/crates/proc-macro-srv-cli/src/main.rs @@ -1,6 +1,8 @@ //! A standalone binary for `proc-macro-srv`. //! Driver for proc macro server #![cfg_attr(feature = "in-rust-tree", feature(rustc_private))] +#![allow(clippy::print_stderr)] + #[cfg(feature = "in-rust-tree")] extern crate rustc_driver as _; diff --git a/crates/proc-macro-srv/proc-macro-test/imp/src/lib.rs b/crates/proc-macro-srv/proc-macro-test/imp/src/lib.rs index d9018b1b87..5f8530d08c 100644 --- a/crates/proc-macro-srv/proc-macro-test/imp/src/lib.rs +++ b/crates/proc-macro-srv/proc-macro-test/imp/src/lib.rs @@ -2,6 +2,7 @@ #![warn(rust_2018_idioms, unused_lifetimes)] #![feature(proc_macro_span, proc_macro_def_site)] +#![allow(clippy::all)] use proc_macro::{Group, Ident, Literal, Punct, Span, TokenStream, TokenTree}; diff --git a/crates/proc-macro-srv/src/server/rust_analyzer_span.rs b/crates/proc-macro-srv/src/server/rust_analyzer_span.rs index b864a5e4fd..c7c7bea994 100644 --- a/crates/proc-macro-srv/src/server/rust_analyzer_span.rs +++ b/crates/proc-macro-srv/src/server/rust_analyzer_span.rs @@ -104,7 +104,7 @@ impl server::TokenStream for RaSpanServer { delimiter: delim_to_internal(group.delimiter, group.span), token_trees: match group.stream { Some(stream) => stream.into_iter().collect(), - None => Vec::new(), + None => Box::new([]), }, }; let tree = tt::TokenTree::from(group); @@ -221,7 +221,7 @@ impl server::TokenStream for RaSpanServer { stream: if subtree.token_trees.is_empty() { None } else { - Some(subtree.token_trees.into_iter().collect()) + Some(subtree.token_trees.into_vec().into_iter().collect()) }, span: bridge::DelimSpan::from_single(subtree.delimiter.open), }), diff --git a/crates/proc-macro-srv/src/server/token_id.rs b/crates/proc-macro-srv/src/server/token_id.rs index c83e09af0d..edbdc67b48 100644 --- a/crates/proc-macro-srv/src/server/token_id.rs +++ b/crates/proc-macro-srv/src/server/token_id.rs @@ -94,7 +94,7 @@ impl server::TokenStream for TokenIdServer { delimiter: delim_to_internal(group.delimiter, group.span), token_trees: match group.stream { Some(stream) => stream.into_iter().collect(), - None => Vec::new(), + None => Box::new([]), }, }; let tree = TokenTree::from(group); @@ -206,7 +206,7 @@ impl server::TokenStream for TokenIdServer { stream: if subtree.token_trees.is_empty() { None } else { - Some(TokenStream { token_trees: subtree.token_trees }) + Some(TokenStream { token_trees: subtree.token_trees.into_vec() }) }, span: bridge::DelimSpan::from_single(subtree.delimiter.open), }), @@ -338,7 +338,7 @@ mod tests { close: tt::TokenId(0), kind: tt::DelimiterKind::Brace, }, - token_trees: vec![], + token_trees: Box::new([]), }), ], }; @@ -354,10 +354,10 @@ mod tests { close: tt::TokenId(0), kind: tt::DelimiterKind::Parenthesis, }, - token_trees: vec![tt::TokenTree::Leaf(tt::Leaf::Ident(tt::Ident { + token_trees: Box::new([tt::TokenTree::Leaf(tt::Leaf::Ident(tt::Ident { text: "a".into(), span: tt::TokenId(0), - }))], + }))]), }); let t1 = TokenStream::from_str("(a)", tt::TokenId(0)).unwrap(); diff --git a/crates/proc-macro-srv/src/server/token_stream.rs b/crates/proc-macro-srv/src/server/token_stream.rs index 8f669a3049..5edaa720fc 100644 --- a/crates/proc-macro-srv/src/server/token_stream.rs +++ b/crates/proc-macro-srv/src/server/token_stream.rs @@ -15,14 +15,14 @@ impl Default for TokenStream { impl TokenStream { pub(crate) fn new() -> Self { - TokenStream { token_trees: vec![] } + TokenStream::default() } pub(crate) fn with_subtree(subtree: tt::Subtree) -> Self { if subtree.delimiter.kind != tt::DelimiterKind::Invisible { TokenStream { token_trees: vec![TokenTree::Subtree(subtree)] } } else { - TokenStream { token_trees: subtree.token_trees } + TokenStream { token_trees: subtree.token_trees.into_vec() } } } @@ -36,7 +36,7 @@ impl TokenStream { close: call_site, kind: tt::DelimiterKind::Invisible, }, - token_trees: self.token_trees, + token_trees: self.token_trees.into_boxed_slice(), } } @@ -83,7 +83,7 @@ impl Extend> for TokenStream { tt::TokenTree::Subtree(subtree) if subtree.delimiter.kind == tt::DelimiterKind::Invisible => { - self.token_trees.extend(subtree.token_trees); + self.token_trees.extend(subtree.token_trees.into_vec().into_iter()); } _ => { self.token_trees.push(tkn); diff --git a/crates/profile/src/lib.rs b/crates/profile/src/lib.rs index 38c5b3fc9c..3639981560 100644 --- a/crates/profile/src/lib.rs +++ b/crates/profile/src/lib.rs @@ -21,7 +21,7 @@ pub use countme; /// almost zero. pub use countme::Count; -thread_local!(static IN_SCOPE: RefCell = RefCell::new(false)); +thread_local!(static IN_SCOPE: RefCell = const { RefCell::new(false) }); /// Allows to check if the current code is within some dynamic scope, can be /// useful during debugging to figure out why a function is called. @@ -88,6 +88,7 @@ pub fn cpu_span() -> CpuSpan { } #[cfg(not(feature = "cpu_profiler"))] + #[allow(clippy::print_stderr)] { eprintln!( r#"cpu profiling is disabled, uncomment `default = [ "cpu_profiler" ]` in Cargo.toml to enable."# diff --git a/crates/profile/src/stop_watch.rs b/crates/profile/src/stop_watch.rs index 814a025740..990b59cad4 100644 --- a/crates/profile/src/stop_watch.rs +++ b/crates/profile/src/stop_watch.rs @@ -1,4 +1,7 @@ //! Like `std::time::Instant`, but also measures memory & CPU cycles. + +#![allow(clippy::print_stderr)] + use std::{ fmt, time::{Duration, Instant}, diff --git a/crates/project-model/src/build_scripts.rs b/crates/project-model/src/build_scripts.rs index c1670c2004..a2c9856a3f 100644 --- a/crates/project-model/src/build_scripts.rs +++ b/crates/project-model/src/build_scripts.rs @@ -322,7 +322,7 @@ impl WorkspaceBuildScripts { let mut deserializer = serde_json::Deserializer::from_str(line); deserializer.disable_recursion_limit(); let message = Message::deserialize(&mut deserializer) - .unwrap_or_else(|_| Message::TextLine(line.to_string())); + .unwrap_or_else(|_| Message::TextLine(line.to_owned())); match message { Message::BuildScriptExecuted(mut message) => { @@ -356,7 +356,7 @@ impl WorkspaceBuildScripts { if let Some(out_dir) = out_dir.as_os_str().to_str().map(|s| s.to_owned()) { - data.envs.push(("OUT_DIR".to_string(), out_dir)); + data.envs.push(("OUT_DIR".to_owned(), out_dir)); } data.out_dir = Some(out_dir); data.cfgs = cfgs; @@ -396,7 +396,7 @@ impl WorkspaceBuildScripts { let errors = if !output.status.success() { let errors = errors.into_inner(); - Some(if errors.is_empty() { "cargo check failed".to_string() } else { errors }) + Some(if errors.is_empty() { "cargo check failed".to_owned() } else { errors }) } else { None }; @@ -490,7 +490,7 @@ impl WorkspaceBuildScripts { // FIXME: Find a better way to know if it is a dylib. fn is_dylib(path: &Utf8Path) -> bool { - match path.extension().map(|e| e.to_string().to_lowercase()) { + match path.extension().map(|e| e.to_owned().to_lowercase()) { None => false, Some(ext) => matches!(ext.as_str(), "dll" | "dylib" | "so"), } diff --git a/crates/project-model/src/cargo_workspace.rs b/crates/project-model/src/cargo_workspace.rs index 361f8721a4..a99ee6e664 100644 --- a/crates/project-model/src/cargo_workspace.rs +++ b/crates/project-model/src/cargo_workspace.rs @@ -285,7 +285,7 @@ impl CargoWorkspace { // FIXME: Fetching metadata is a slow process, as it might require // calling crates.io. We should be reporting progress here, but it's // unclear whether cargo itself supports it. - progress("metadata".to_string()); + progress("metadata".to_owned()); (|| -> Result { let mut command = meta.cargo_command(); @@ -399,7 +399,7 @@ impl CargoWorkspace { CargoWorkspace { packages, targets, workspace_root, target_directory } } - pub fn packages(&self) -> impl Iterator + ExactSizeIterator + '_ { + pub fn packages(&self) -> impl ExactSizeIterator + '_ { self.packages.iter().map(|(id, _pkg)| id) } @@ -502,7 +502,7 @@ fn rustc_discover_host_triple( let field = "host: "; let target = stdout.lines().find_map(|l| l.strip_prefix(field)); if let Some(target) = target { - Some(target.to_string()) + Some(target.to_owned()) } else { // If we fail to resolve the host platform, it's not the end of the world. tracing::info!("rustc -vV did not report host platform, got:\n{}", stdout); @@ -536,7 +536,7 @@ fn parse_output_cargo_config_build_target(stdout: String) -> Vec { let trimmed = stdout.trim_start_matches("build.target = ").trim_matches('"'); if !trimmed.starts_with('[') { - return [trimmed.to_string()].to_vec(); + return [trimmed.to_owned()].to_vec(); } let res = serde_json::from_str(trimmed); diff --git a/crates/project-model/src/cfg_flag.rs b/crates/project-model/src/cfg_flag.rs index e366d441c1..af682904b1 100644 --- a/crates/project-model/src/cfg_flag.rs +++ b/crates/project-model/src/cfg_flag.rs @@ -19,7 +19,7 @@ impl FromStr for CfgFlag { if !(value.starts_with('"') && value.ends_with('"')) { return Err(format!("Invalid cfg ({s:?}), value should be in quotes")); } - let key = key.to_string(); + let key = key.to_owned(); let value = value[1..value.len() - 1].to_string(); CfgFlag::KeyValue { key, value } } diff --git a/crates/project-model/src/lib.rs b/crates/project-model/src/lib.rs index 5114c9c016..5b91f5d805 100644 --- a/crates/project-model/src/lib.rs +++ b/crates/project-model/src/lib.rs @@ -167,7 +167,7 @@ fn utf8_stdout(mut cmd: Command) -> anyhow::Result { } } let stdout = String::from_utf8(output.stdout)?; - Ok(stdout.trim().to_string()) + Ok(stdout.trim().to_owned()) } #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] diff --git a/crates/project-model/src/rustc_cfg.rs b/crates/project-model/src/rustc_cfg.rs index cf12d5b71d..0aee002fbb 100644 --- a/crates/project-model/src/rustc_cfg.rs +++ b/crates/project-model/src/rustc_cfg.rs @@ -33,7 +33,7 @@ pub(crate) fn get( res.push(CfgFlag::Atom("target_thread_local".into())); for ty in ["8", "16", "32", "64", "cas", "ptr"] { for key in ["target_has_atomic", "target_has_atomic_load_store"] { - res.push(CfgFlag::KeyValue { key: key.to_string(), value: ty.into() }); + res.push(CfgFlag::KeyValue { key: key.to_owned(), value: ty.into() }); } } diff --git a/crates/project-model/src/sysroot.rs b/crates/project-model/src/sysroot.rs index c24c0196dd..9e19a52583 100644 --- a/crates/project-model/src/sysroot.rs +++ b/crates/project-model/src/sysroot.rs @@ -57,7 +57,7 @@ impl Stitched { self.by_name("proc_macro") } - pub(crate) fn crates(&self) -> impl Iterator + ExactSizeIterator + '_ { + pub(crate) fn crates(&self) -> impl ExactSizeIterator + '_ { self.crates.iter().map(|(id, _data)| id) } diff --git a/crates/project-model/src/tests.rs b/crates/project-model/src/tests.rs index 7c078f72f5..74042e925e 100644 --- a/crates/project-model/src/tests.rs +++ b/crates/project-model/src/tests.rs @@ -129,7 +129,7 @@ fn get_fake_sysroot() -> Sysroot { } fn rooted_project_json(data: ProjectJsonData) -> ProjectJson { - let mut root = "$ROOT$".to_string(); + let mut root = "$ROOT$".to_owned(); replace_root(&mut root, true); let path = Path::new(&root); let base = AbsPath::assert(path); diff --git a/crates/project-model/src/workspace.rs b/crates/project-model/src/workspace.rs index 8c5ea0619a..cda5ad2f11 100644 --- a/crates/project-model/src/workspace.rs +++ b/crates/project-model/src/workspace.rs @@ -241,7 +241,7 @@ impl ProjectWorkspace { .map_err(|p| Some(format!("rustc source path is not absolute: {p}"))), Some(RustLibSource::Discover) => { sysroot.as_ref().ok().and_then(Sysroot::discover_rustc_src).ok_or_else( - || Some("Failed to discover rustc source for sysroot.".to_string()), + || Some("Failed to discover rustc source for sysroot.".to_owned()), ) } None => Err(None), @@ -840,7 +840,7 @@ fn project_json_to_crate_graph( if let Some(name) = display_name.clone() { CrateOrigin::Local { repo: repository.clone(), - name: Some(name.canonical_name().to_string()), + name: Some(name.canonical_name().to_owned()), } } else { CrateOrigin::Local { repo: None, name: None } @@ -1117,7 +1117,7 @@ fn detached_files_to_crate_graph( let display_name = detached_file .file_stem() .and_then(|os_str| os_str.to_str()) - .map(|file_stem| CrateDisplayName::from_canonical_name(file_stem.to_string())); + .map(|file_stem| CrateDisplayName::from_canonical_name(file_stem.to_owned())); let detached_file_crate = crate_graph.add_crate_root( file_id, Edition::CURRENT, @@ -1129,7 +1129,7 @@ fn detached_files_to_crate_graph( false, CrateOrigin::Local { repo: None, - name: display_name.map(|n| n.canonical_name().to_string()), + name: display_name.map(|n| n.canonical_name().to_owned()), }, target_layout.clone(), None, @@ -1323,7 +1323,7 @@ fn add_target_crate_root( } } - let display_name = CrateDisplayName::from_canonical_name(cargo_name.to_string()); + let display_name = CrateDisplayName::from_canonical_name(cargo_name.to_owned()); let crate_id = crate_graph.add_crate_root( file_id, edition, @@ -1455,7 +1455,7 @@ fn sysroot_to_crate_graph( (SysrootPublicDeps { deps: pub_deps }, libproc_macro) } SysrootMode::Stitched(stitched) => { - let cfg_options = create_cfg_options(rustc_cfg.clone()); + let cfg_options = create_cfg_options(rustc_cfg); let sysroot_crates: FxHashMap = stitched .crates() .filter_map(|krate| { diff --git a/crates/rust-analyzer/src/bin/main.rs b/crates/rust-analyzer/src/bin/main.rs index 66b680571a..269dd3cfff 100644 --- a/crates/rust-analyzer/src/bin/main.rs +++ b/crates/rust-analyzer/src/bin/main.rs @@ -3,7 +3,9 @@ //! Based on cli flags, either spawns an LSP server, or runs a batch analysis #![warn(rust_2018_idioms, unused_lifetimes)] +#![allow(clippy::print_stdout, clippy::print_stderr)] #![cfg_attr(feature = "in-rust-tree", feature(rustc_private))] + #[cfg(feature = "in-rust-tree")] extern crate rustc_driver as _; @@ -132,7 +134,7 @@ fn setup_logging(log_file_flag: Option) -> anyhow::Result<()> { writer, // Deliberately enable all `error` logs if the user has not set RA_LOG, as there is usually // useful information in there for debugging. - filter: env::var("RA_LOG").ok().unwrap_or_else(|| "error".to_string()), + filter: env::var("RA_LOG").ok().unwrap_or_else(|| "error".to_owned()), chalk_filter: env::var("CHALK_DEBUG").ok(), profile_filter: env::var("RA_PROFILE").ok(), } @@ -222,7 +224,7 @@ fn run_server() -> anyhow::Result<()> { MessageType, ShowMessageParams, }; let not = lsp_server::Notification::new( - ShowMessage::METHOD.to_string(), + ShowMessage::METHOD.to_owned(), ShowMessageParams { typ: MessageType::WARNING, message: e.to_string() }, ); connection.sender.send(lsp_server::Message::Notification(not)).unwrap(); diff --git a/crates/rust-analyzer/src/caps.rs b/crates/rust-analyzer/src/caps.rs index 94eab97e8f..a1469c22ab 100644 --- a/crates/rust-analyzer/src/caps.rs +++ b/crates/rust-analyzer/src/caps.rs @@ -44,17 +44,17 @@ pub fn server_capabilities(config: &Config) -> ServerCapabilities { completion_provider: Some(CompletionOptions { resolve_provider: completions_resolve_provider(config.caps()), trigger_characters: Some(vec![ - ":".to_string(), - ".".to_string(), - "'".to_string(), - "(".to_string(), + ":".to_owned(), + ".".to_owned(), + "'".to_owned(), + "(".to_owned(), ]), all_commit_characters: None, completion_item: completion_item(config), work_done_progress_options: WorkDoneProgressOptions { work_done_progress: None }, }), signature_help_provider: Some(SignatureHelpOptions { - trigger_characters: Some(vec!["(".to_string(), ",".to_string(), "<".to_string()]), + trigger_characters: Some(vec!["(".to_owned(), ",".to_owned(), "<".to_owned()]), retrigger_characters: None, work_done_progress_options: WorkDoneProgressOptions { work_done_progress: None }, }), @@ -74,7 +74,7 @@ pub fn server_capabilities(config: &Config) -> ServerCapabilities { _ => Some(OneOf::Left(false)), }, document_on_type_formatting_provider: Some(DocumentOnTypeFormattingOptions { - first_trigger_character: "=".to_string(), + first_trigger_character: "=".to_owned(), more_trigger_character: Some(more_trigger_character(config)), }), selection_range_provider: Some(SelectionRangeProviderCapability::Simple(true)), @@ -222,9 +222,9 @@ fn code_action_capabilities(client_caps: &ClientCapabilities) -> CodeActionProvi } fn more_trigger_character(config: &Config) -> Vec { - let mut res = vec![".".to_string(), ">".to_string(), "{".to_string(), "(".to_string()]; + let mut res = vec![".".to_owned(), ">".to_owned(), "{".to_owned(), "(".to_owned()]; if config.snippet_cap() { - res.push("<".to_string()); + res.push("<".to_owned()); } res } diff --git a/crates/rust-analyzer/src/cli.rs b/crates/rust-analyzer/src/cli.rs index 00670f2cb4..0bd6677b66 100644 --- a/crates/rust-analyzer/src/cli.rs +++ b/crates/rust-analyzer/src/cli.rs @@ -1,5 +1,7 @@ //! Various batch processing tasks, intended primarily for debugging. +#![allow(clippy::print_stdout, clippy::print_stderr)] + mod analysis_stats; mod diagnostics; pub mod flags; diff --git a/crates/rust-analyzer/src/cli/analysis_stats.rs b/crates/rust-analyzer/src/cli/analysis_stats.rs index 31bdd2a0e8..2741b45222 100644 --- a/crates/rust-analyzer/src/cli/analysis_stats.rs +++ b/crates/rust-analyzer/src/cli/analysis_stats.rs @@ -397,7 +397,7 @@ impl flags::AnalysisStats { module .krate() .display_name(db) - .map(|it| it.canonical_name().to_string()) + .map(|it| it.canonical_name().to_owned()) .into_iter() .chain( module @@ -688,7 +688,7 @@ impl flags::AnalysisStats { module .krate() .display_name(db) - .map(|it| it.canonical_name().to_string()) + .map(|it| it.canonical_name().to_owned()) .into_iter() .chain( module @@ -833,7 +833,7 @@ impl flags::AnalysisStats { fn location_csv_expr(db: &RootDatabase, vfs: &Vfs, sm: &BodySourceMap, expr_id: ExprId) -> String { let src = match sm.expr_syntax(expr_id) { Ok(s) => s, - Err(SyntheticSyntax) => return "synthetic,,".to_string(), + Err(SyntheticSyntax) => return "synthetic,,".to_owned(), }; let root = db.parse_or_expand(src.file_id); let node = src.map(|e| e.to_node(&root).syntax().clone()); @@ -849,7 +849,7 @@ fn location_csv_expr(db: &RootDatabase, vfs: &Vfs, sm: &BodySourceMap, expr_id: fn location_csv_pat(db: &RootDatabase, vfs: &Vfs, sm: &BodySourceMap, pat_id: PatId) -> String { let src = match sm.pat_syntax(pat_id) { Ok(s) => s, - Err(SyntheticSyntax) => return "synthetic,,".to_string(), + Err(SyntheticSyntax) => return "synthetic,,".to_owned(), }; let root = db.parse_or_expand(src.file_id); let node = src.map(|e| e.to_node(&root).syntax().clone()); diff --git a/crates/rust-analyzer/src/cli/diagnostics.rs b/crates/rust-analyzer/src/cli/diagnostics.rs index 6d2e97be20..605670f6a8 100644 --- a/crates/rust-analyzer/src/cli/diagnostics.rs +++ b/crates/rust-analyzer/src/cli/diagnostics.rs @@ -45,7 +45,7 @@ impl flags::Diagnostics { let file_id = module.definition_source_file_id(db).original_file(db); if !visited_files.contains(&file_id) { let crate_name = - module.krate().display_name(db).as_deref().unwrap_or("unknown").to_string(); + module.krate().display_name(db).as_deref().unwrap_or("unknown").to_owned(); println!("processing crate: {crate_name}, module: {}", _vfs.file_path(file_id)); for diagnostic in analysis .diagnostics( diff --git a/crates/rust-analyzer/src/cli/lsif.rs b/crates/rust-analyzer/src/cli/lsif.rs index 64f965e22a..1424a77577 100644 --- a/crates/rust-analyzer/src/cli/lsif.rs +++ b/crates/rust-analyzer/src/cli/lsif.rs @@ -1,6 +1,5 @@ //! LSIF (language server index format) generator -use std::collections::HashMap; use std::env; use std::time::Instant; @@ -16,6 +15,7 @@ use ide_db::{ use load_cargo::{load_workspace, LoadCargoConfig, ProcMacroServerChoice}; use lsp_types::{self, lsif}; use project_model::{CargoConfig, ProjectManifest, ProjectWorkspace, RustLibSource}; +use rustc_hash::FxHashMap; use vfs::{AbsPathBuf, Vfs}; use crate::{ @@ -35,10 +35,10 @@ impl Clone for Snap> { struct LsifManager<'a> { count: i32, - token_map: HashMap, - range_map: HashMap, - file_map: HashMap, - package_map: HashMap, + token_map: FxHashMap, + range_map: FxHashMap, + file_map: FxHashMap, + package_map: FxHashMap, analysis: &'a Analysis, db: &'a RootDatabase, vfs: &'a Vfs, @@ -57,10 +57,10 @@ impl LsifManager<'_> { fn new<'a>(analysis: &'a Analysis, db: &'a RootDatabase, vfs: &'a Vfs) -> LsifManager<'a> { LsifManager { count: 0, - token_map: HashMap::default(), - range_map: HashMap::default(), - file_map: HashMap::default(), - package_map: HashMap::default(), + token_map: FxHashMap::default(), + range_map: FxHashMap::default(), + file_map: FxHashMap::default(), + package_map: FxHashMap::default(), analysis, db, vfs, @@ -104,12 +104,12 @@ impl LsifManager<'_> { let result_set_id = self.add_vertex(lsif::Vertex::PackageInformation(lsif::PackageInformation { name: pi.name, - manager: "cargo".to_string(), + manager: "cargo".to_owned(), uri: None, content: None, repository: pi.repo.map(|url| lsif::Repository { url, - r#type: "git".to_string(), + r#type: "git".to_owned(), commit_id: None, }), version: pi.version, @@ -148,7 +148,7 @@ impl LsifManager<'_> { let path = self.vfs.file_path(id); let path = path.as_path().unwrap(); let doc_id = self.add_vertex(lsif::Vertex::Document(lsif::Document { - language_id: "rust".to_string(), + language_id: "rust".to_owned(), uri: lsp_types::Url::from_file_path(path).unwrap(), })); self.file_map.insert(id, doc_id); @@ -175,7 +175,7 @@ impl LsifManager<'_> { if let Some(moniker) = token.moniker { let package_id = self.get_package_id(moniker.package_information); let moniker_id = self.add_vertex(lsif::Vertex::Moniker(lsp_types::Moniker { - scheme: "rust-analyzer".to_string(), + scheme: "rust-analyzer".to_owned(), identifier: moniker.identifier.to_string(), unique: lsp_types::UniquenessLevel::Scheme, kind: Some(match moniker.kind { @@ -215,7 +215,7 @@ impl LsifManager<'_> { out_v: result_set_id.into(), })); let mut edges = token.references.iter().fold( - HashMap::<_, Vec>::new(), + FxHashMap::<_, Vec>::default(), |mut edges, it| { let entry = edges.entry((it.range.file_id, it.is_definition)).or_default(); entry.push((*self.range_map.get(&it.range).unwrap()).into()); @@ -313,7 +313,7 @@ impl flags::Lsif { project_root: lsp_types::Url::from_file_path(path).unwrap(), position_encoding: lsif::Encoding::Utf16, tool_info: Some(lsp_types::lsif::ToolInfo { - name: "rust-analyzer".to_string(), + name: "rust-analyzer".to_owned(), args: vec![], version: Some(version().to_string()), }), diff --git a/crates/rust-analyzer/src/cli/progress_report.rs b/crates/rust-analyzer/src/cli/progress_report.rs index 8166aa23b4..b233730099 100644 --- a/crates/rust-analyzer/src/cli/progress_report.rs +++ b/crates/rust-analyzer/src/cli/progress_report.rs @@ -92,7 +92,7 @@ impl<'a> ProgressReport<'a> { let _ = io::stdout().write(output.as_bytes()); let _ = io::stdout().flush(); - self.text = text.to_string(); + self.text = text.to_owned(); } fn set_value(&mut self, value: f32) { diff --git a/crates/rust-analyzer/src/cli/run_tests.rs b/crates/rust-analyzer/src/cli/run_tests.rs index d07dcdec25..6b43e09542 100644 --- a/crates/rust-analyzer/src/cli/run_tests.rs +++ b/crates/rust-analyzer/src/cli/run_tests.rs @@ -34,7 +34,7 @@ impl flags::RunTests { .filter(|x| x.is_test(db)); let span_formatter = |file_id, text_range: TextRange| { let line_col = match db.line_index(file_id).try_line_col(text_range.start()) { - None => " (unknown line col)".to_string(), + None => " (unknown line col)".to_owned(), Some(x) => format!("#{}:{}", x.line + 1, x.col), }; let path = &db diff --git a/crates/rust-analyzer/src/cli/rustc_tests.rs b/crates/rust-analyzer/src/cli/rustc_tests.rs index be7e434aca..64ea246a45 100644 --- a/crates/rust-analyzer/src/cli/rustc_tests.rs +++ b/crates/rust-analyzer/src/cli/rustc_tests.rs @@ -1,8 +1,6 @@ //! Run all tests in a project, similar to `cargo test`, but using the mir interpreter. -use std::{ - cell::RefCell, collections::HashMap, fs::read_to_string, panic::AssertUnwindSafe, path::PathBuf, -}; +use std::{cell::RefCell, fs::read_to_string, panic::AssertUnwindSafe, path::PathBuf}; use hir::{Change, Crate}; use ide::{AnalysisHost, DiagnosticCode, DiagnosticsConfig}; @@ -10,6 +8,7 @@ use profile::StopWatch; use project_model::{CargoConfig, ProjectWorkspace, RustLibSource, Sysroot}; use load_cargo::{load_workspace, LoadCargoConfig, ProcMacroServerChoice}; +use rustc_hash::FxHashMap; use triomphe::Arc; use vfs::{AbsPathBuf, FileId}; use walkdir::WalkDir; @@ -27,7 +26,7 @@ struct Tester { fn string_to_diagnostic_code_leaky(code: &str) -> DiagnosticCode { thread_local! { - static LEAK_STORE: RefCell> = RefCell::new(HashMap::new()); + static LEAK_STORE: RefCell> = RefCell::new(FxHashMap::default()); } LEAK_STORE.with_borrow_mut(|s| match s.get(code) { Some(c) => *c, @@ -39,9 +38,9 @@ fn string_to_diagnostic_code_leaky(code: &str) -> DiagnosticCode { }) } -fn detect_errors_from_rustc_stderr_file(p: PathBuf) -> HashMap { +fn detect_errors_from_rustc_stderr_file(p: PathBuf) -> FxHashMap { let text = read_to_string(p).unwrap(); - let mut result = HashMap::new(); + let mut result = FxHashMap::default(); { let mut text = &*text; while let Some(p) = text.find("error[E") { @@ -106,7 +105,7 @@ impl Tester { let expected = if stderr_path.exists() { detect_errors_from_rustc_stderr_file(stderr_path) } else { - HashMap::new() + FxHashMap::default() }; let text = read_to_string(&p).unwrap(); let mut change = Change::new(); @@ -125,7 +124,7 @@ impl Tester { self.host.apply_change(change); let diagnostic_config = DiagnosticsConfig::test_sample(); - let mut actual = HashMap::new(); + let mut actual = FxHashMap::default(); let panicked = match std::panic::catch_unwind(|| { self.host .analysis() diff --git a/crates/rust-analyzer/src/cli/scip.rs b/crates/rust-analyzer/src/cli/scip.rs index 81622a4617..f4aec28834 100644 --- a/crates/rust-analyzer/src/cli/scip.rs +++ b/crates/rust-analyzer/src/cli/scip.rs @@ -1,10 +1,6 @@ //! SCIP generator -use std::{ - collections::{HashMap, HashSet}, - path::PathBuf, - time::Instant, -}; +use std::{path::PathBuf, time::Instant}; use ide::{ LineCol, MonikerDescriptorKind, MonikerResult, StaticIndex, StaticIndexedFile, @@ -12,6 +8,7 @@ use ide::{ }; use ide_db::LineIndexDatabase; use load_cargo::{load_workspace_at, LoadCargoConfig, ProcMacroServerChoice}; +use rustc_hash::{FxHashMap, FxHashSet}; use scip::types as scip_types; use crate::{ @@ -76,9 +73,10 @@ impl flags::Scip { }; let mut documents = Vec::new(); - let mut symbols_emitted: HashSet = HashSet::default(); - let mut tokens_to_symbol: HashMap = HashMap::new(); - let mut tokens_to_enclosing_symbol: HashMap> = HashMap::new(); + let mut symbols_emitted: FxHashSet = FxHashSet::default(); + let mut tokens_to_symbol: FxHashMap = FxHashMap::default(); + let mut tokens_to_enclosing_symbol: FxHashMap> = + FxHashMap::default(); for StaticIndexedFile { file_id, tokens, .. } in si.files { let mut local_count = 0; @@ -148,7 +146,7 @@ impl flags::Scip { let signature_documentation = token.signature.clone().map(|text| scip_types::Document { relative_path: relative_path.clone(), - language: "rust".to_string(), + language: "rust".to_owned(), text, position_encoding, ..Default::default() @@ -188,7 +186,7 @@ impl flags::Scip { scip_types::PositionEncoding::UTF8CodeUnitOffsetFromLineStart.into(); documents.push(scip_types::Document { relative_path, - language: "rust".to_string(), + language: "rust".to_owned(), occurrences, symbols, text: String::new(), @@ -218,7 +216,7 @@ fn get_relative_filepath( rootpath: &vfs::AbsPathBuf, file_id: ide::FileId, ) -> Option { - Some(vfs.file_path(file_id).as_path()?.strip_prefix(rootpath)?.as_ref().to_str()?.to_string()) + Some(vfs.file_path(file_id).as_path()?.strip_prefix(rootpath)?.as_ref().to_str()?.to_owned()) } // SCIP Ranges have a (very large) optimization that ranges if they are on the same line @@ -241,8 +239,8 @@ fn new_descriptor_str( suffix: scip_types::descriptor::Suffix, ) -> scip_types::Descriptor { scip_types::Descriptor { - name: name.to_string(), - disambiguator: "".to_string(), + name: name.to_owned(), + disambiguator: "".to_owned(), suffix: suffix.into(), special_fields: Default::default(), } @@ -313,9 +311,9 @@ fn moniker_to_symbol(moniker: &MonikerResult) -> scip_types::Symbol { scip_types::Symbol { scheme: "rust-analyzer".into(), package: Some(scip_types::Package { - manager: "cargo".to_string(), + manager: "cargo".to_owned(), name: package_name, - version: version.unwrap_or_else(|| ".".to_string()), + version: version.unwrap_or_else(|| ".".to_owned()), special_fields: Default::default(), }) .into(), diff --git a/crates/rust-analyzer/src/config.rs b/crates/rust-analyzer/src/config.rs index 815f6ea12e..7bdd9ec866 100644 --- a/crates/rust-analyzer/src/config.rs +++ b/crates/rust-analyzer/src/config.rs @@ -223,6 +223,9 @@ config_data! { /// /// Aliased as `"checkOnSave.targets"`. check_targets | checkOnSave_targets | checkOnSave_target: Option = "null", + /// Whether `--workspace` should be passed to `cargo check`. + /// If false, `-p ` will be passed instead. + check_workspace: bool = "true", /// Toggles the additional completions that automatically add imports when completed. /// Note that your client must specify the `additionalTextEdits` LSP client capability to truly have this feature enabled. @@ -478,6 +481,9 @@ config_data! { /// Whether to show `can't find Cargo.toml` error message. notifications_cargoTomlNotFound: bool = "true", + /// Whether to send an UnindexedProject notification to the client. + notifications_unindexedProject: bool = "false", + /// How many worker threads in the main loop. The default `null` means to pick automatically. numThreads: Option = "null", @@ -745,6 +751,7 @@ pub enum FilesWatcher { #[derive(Debug, Clone)] pub struct NotificationsConfig { pub cargo_toml_not_found: bool, + pub unindexed_project: bool, } #[derive(Debug, Clone)] @@ -897,7 +904,7 @@ impl Config { use serde::de::Error; if self.data.check_command.is_empty() { error_sink.push(( - "/check/command".to_string(), + "/check/command".to_owned(), serde_json::Error::custom("expected a non-empty string"), )); } @@ -1220,7 +1227,10 @@ impl Config { } pub fn notifications(&self) -> NotificationsConfig { - NotificationsConfig { cargo_toml_not_found: self.data.notifications_cargoTomlNotFound } + NotificationsConfig { + cargo_toml_not_found: self.data.notifications_cargoTomlNotFound, + unindexed_project: self.data.notifications_unindexedProject, + } } pub fn cargo_autoreload(&self) -> bool { @@ -1323,6 +1333,10 @@ impl Config { } } + pub fn flycheck_workspace(&self) -> bool { + self.data.check_workspace + } + pub fn flycheck(&self) -> FlycheckConfig { match &self.data.check_overrideCommand { Some(args) if !args.is_empty() => { @@ -2612,7 +2626,7 @@ mod tests { .replace('\n', "\n ") .trim_start_matches('\n') .trim_end() - .to_string(); + .to_owned(); schema.push_str(",\n"); // Transform the asciidoc form link to markdown style. diff --git a/crates/rust-analyzer/src/config/patch_old_style.rs b/crates/rust-analyzer/src/config/patch_old_style.rs index 73d2ed3298..92c0c0d048 100644 --- a/crates/rust-analyzer/src/config/patch_old_style.rs +++ b/crates/rust-analyzer/src/config/patch_old_style.rs @@ -19,7 +19,7 @@ pub(super) fn patch_json_for_outdated_configs(json: &mut Value) { Some(it) => { let mut last = it; for segment in [$(stringify!($dst)),+].into_iter().rev() { - last = Value::Object(serde_json::Map::from_iter(std::iter::once((segment.to_string(), last)))); + last = Value::Object(serde_json::Map::from_iter(std::iter::once((segment.to_owned(), last)))); } merge(json, last); diff --git a/crates/rust-analyzer/src/diagnostics.rs b/crates/rust-analyzer/src/diagnostics.rs index c91b22999d..a0a53f545c 100644 --- a/crates/rust-analyzer/src/diagnostics.rs +++ b/crates/rust-analyzer/src/diagnostics.rs @@ -135,11 +135,11 @@ pub(crate) fn fetch_native_diagnostics( |line_index: &crate::line_index::LineIndex, d: ide::Diagnostic| lsp_types::Diagnostic { range: lsp::to_proto::range(line_index, d.range.range), severity: Some(lsp::to_proto::diagnostic_severity(d.severity)), - code: Some(lsp_types::NumberOrString::String(d.code.as_str().to_string())), + code: Some(lsp_types::NumberOrString::String(d.code.as_str().to_owned())), code_description: Some(lsp_types::CodeDescription { href: lsp_types::Url::parse(&d.code.url()).unwrap(), }), - source: Some("rust-analyzer".to_string()), + source: Some("rust-analyzer".to_owned()), message: d.message, related_information: None, tags: d.unused.then(|| vec![lsp_types::DiagnosticTag::UNNECESSARY]), diff --git a/crates/rust-analyzer/src/diagnostics/to_proto.rs b/crates/rust-analyzer/src/diagnostics/to_proto.rs index f8bc66ff8e..e900f2601d 100644 --- a/crates/rust-analyzer/src/diagnostics/to_proto.rs +++ b/crates/rust-analyzer/src/diagnostics/to_proto.rs @@ -1,9 +1,9 @@ //! This module provides the functionality needed to convert diagnostics from //! `cargo check` json format to the LSP diagnostic format. -use std::collections::HashMap; use flycheck::{Applicability, DiagnosticLevel, DiagnosticSpan}; use itertools::Itertools; +use rustc_hash::FxHashMap; use stdx::format_to; use vfs::{AbsPath, AbsPathBuf}; @@ -186,7 +186,7 @@ fn map_rust_child_diagnostic( return MappedRustChildDiagnostic::MessageLine(rd.message.clone()); } - let mut edit_map: HashMap> = HashMap::new(); + let mut edit_map: FxHashMap> = FxHashMap::default(); let mut suggested_replacements = Vec::new(); let mut is_preferred = true; for &span in &spans { @@ -403,16 +403,16 @@ pub(crate) fn map_rust_diagnostic_to_lsp( related_info_macro_calls.push(lsp_types::DiagnosticRelatedInformation { location: secondary_location.clone(), message: if is_in_macro_call { - "Error originated from macro call here".to_string() + "Error originated from macro call here".to_owned() } else { - "Actual error occurred here".to_string() + "Actual error occurred here".to_owned() }, }); // For the additional in-macro diagnostic we add the inverse message pointing to the error location in code. let information_for_additional_diagnostic = vec![lsp_types::DiagnosticRelatedInformation { location: primary_location.clone(), - message: "Exact error occurred here".to_string(), + message: "Exact error occurred here".to_owned(), }]; let diagnostic = lsp_types::Diagnostic { @@ -467,7 +467,7 @@ pub(crate) fn map_rust_diagnostic_to_lsp( // `related_information`, which just produces hard-to-read links, at least in VS Code. let back_ref = lsp_types::DiagnosticRelatedInformation { location: primary_location, - message: "original diagnostic".to_string(), + message: "original diagnostic".to_owned(), }; for sub in &subdiagnostics { diagnostics.push(MappedRustDiagnostic { @@ -685,7 +685,7 @@ mod tests { fn rustc_unused_variable_as_info() { check_with_config( DiagnosticsMapConfig { - warnings_as_info: vec!["unused_variables".to_string()], + warnings_as_info: vec!["unused_variables".to_owned()], ..DiagnosticsMapConfig::default() }, r##"{ @@ -769,7 +769,7 @@ mod tests { fn rustc_unused_variable_as_hint() { check_with_config( DiagnosticsMapConfig { - warnings_as_hint: vec!["unused_variables".to_string()], + warnings_as_hint: vec!["unused_variables".to_owned()], ..DiagnosticsMapConfig::default() }, r##"{ diff --git a/crates/rust-analyzer/src/dispatch.rs b/crates/rust-analyzer/src/dispatch.rs index fa856a796a..7adaef4ff6 100644 --- a/crates/rust-analyzer/src/dispatch.rs +++ b/crates/rust-analyzer/src/dispatch.rs @@ -123,7 +123,7 @@ impl RequestDispatcher<'_> { Err(_) => Task::Response(lsp_server::Response::new_err( req.id, lsp_server::ErrorCode::ContentModified as i32, - "content modified".to_string(), + "content modified".to_owned(), )), } } @@ -179,7 +179,7 @@ impl RequestDispatcher<'_> { let response = lsp_server::Response::new_err( req.id, lsp_server::ErrorCode::MethodNotFound as i32, - "unknown request".to_string(), + "unknown request".to_owned(), ); self.global_state.respond(response); } @@ -269,7 +269,7 @@ where .map(String::as_str) .or_else(|| panic.downcast_ref::<&str>().copied()); - let mut message = "request handler panicked".to_string(); + let mut message = "request handler panicked".to_owned(); if let Some(panic_message) = panic_message { message.push_str(": "); message.push_str(panic_message) diff --git a/crates/rust-analyzer/src/global_state.rs b/crates/rust-analyzer/src/global_state.rs index 2f226d0115..da4422a60a 100644 --- a/crates/rust-analyzer/src/global_state.rs +++ b/crates/rust-analyzer/src/global_state.rs @@ -33,7 +33,7 @@ use crate::{ mem_docs::MemDocs, op_queue::OpQueue, reload, - task_pool::TaskPool, + task_pool::{TaskPool, TaskQueue}, }; // Enforces drop order @@ -126,6 +126,17 @@ pub(crate) struct GlobalState { OpQueue<(), (Arc>, Vec>)>, pub(crate) fetch_proc_macros_queue: OpQueue, bool>, pub(crate) prime_caches_queue: OpQueue, + + /// A deferred task queue. + /// + /// This queue is used for doing database-dependent work inside of sync + /// handlers, as accessing the database may block latency-sensitive + /// interactions and should be moved away from the main thread. + /// + /// For certain features, such as [`lsp_ext::UnindexedProjectParams`], + /// this queue should run only *after* [`GlobalState::process_changes`] has + /// been called. + pub(crate) deferred_task_queue: TaskQueue, } /// An immutable snapshot of the world's state at a point in time. @@ -165,6 +176,11 @@ impl GlobalState { Handle { handle, receiver } }; + let task_queue = { + let (sender, receiver) = unbounded(); + TaskQueue { sender, receiver } + }; + let mut analysis_host = AnalysisHost::new(config.lru_parse_query_capacity()); if let Some(capacities) = config.lru_query_capacities() { analysis_host.update_lru_capacities(capacities); @@ -208,6 +224,8 @@ impl GlobalState { fetch_proc_macros_queue: OpQueue::default(), prime_caches_queue: OpQueue::default(), + + deferred_task_queue: task_queue, }; // Apply any required database inputs from the config. this.update_configuration(config); @@ -370,7 +388,7 @@ impl GlobalState { params: R::Params, handler: ReqHandler, ) { - let request = self.req_queue.outgoing.register(R::METHOD.to_string(), params, handler); + let request = self.req_queue.outgoing.register(R::METHOD.to_owned(), params, handler); self.send(request.into()); } @@ -387,7 +405,7 @@ impl GlobalState { &self, params: N::Params, ) { - let not = lsp_server::Notification::new(N::METHOD.to_string(), params); + let not = lsp_server::Notification::new(N::METHOD.to_owned(), params); self.send(not.into()); } diff --git a/crates/rust-analyzer/src/handlers/notification.rs b/crates/rust-analyzer/src/handlers/notification.rs index 1f24e95010..d3c2073f09 100644 --- a/crates/rust-analyzer/src/handlers/notification.rs +++ b/crates/rust-analyzer/src/handlers/notification.rs @@ -70,7 +70,15 @@ pub(crate) fn handle_did_open_text_document( if already_exists { tracing::error!("duplicate DidOpenTextDocument: {}", path); } + state.vfs.write().0.set_file_contents(path, Some(params.text_document.text.into_bytes())); + if state.config.notifications().unindexed_project { + tracing::debug!("queuing task"); + let _ = state + .deferred_task_queue + .sender + .send(crate::main_loop::QueuedTask::CheckIfIndexed(params.text_document.uri)); + } } Ok(()) } @@ -160,7 +168,7 @@ pub(crate) fn handle_did_save_text_document( } else if state.config.check_on_save() { // No specific flycheck was triggered, so let's trigger all of them. for flycheck in state.flycheck.iter() { - flycheck.restart(); + flycheck.restart_workspace(); } } Ok(()) @@ -176,7 +184,7 @@ pub(crate) fn handle_did_change_configuration( lsp_types::ConfigurationParams { items: vec![lsp_types::ConfigurationItem { scope_uri: None, - section: Some("rust-analyzer".to_string()), + section: Some("rust-analyzer".to_owned()), }], }, |this, resp| { @@ -228,7 +236,7 @@ pub(crate) fn handle_did_change_workspace_folders( if !config.has_linked_projects() && config.detached_files().is_empty() { config.rediscover_workspaces(); - state.fetch_workspaces_queue.request_op("client workspaces changed".to_string(), false) + state.fetch_workspaces_queue.request_op("client workspaces changed".to_owned(), false) } Ok(()) @@ -281,27 +289,40 @@ fn run_flycheck(state: &mut GlobalState, vfs_path: VfsPath) -> bool { let crate_root_paths: Vec<_> = crate_root_paths.iter().map(Deref::deref).collect(); // Find all workspaces that have at least one target containing the saved file - let workspace_ids = world.workspaces.iter().enumerate().filter(|(_, ws)| match ws { - project_model::ProjectWorkspace::Cargo { cargo, .. } => { - cargo.packages().any(|pkg| { - cargo[pkg] - .targets - .iter() - .any(|&it| crate_root_paths.contains(&cargo[it].root.as_path())) - }) - } - project_model::ProjectWorkspace::Json { project, .. } => { - project.crates().any(|(c, _)| crate_ids.iter().any(|&crate_id| crate_id == c)) - } - project_model::ProjectWorkspace::DetachedFiles { .. } => false, + let workspace_ids = world.workspaces.iter().enumerate().filter_map(|(idx, ws)| { + let package = match ws { + project_model::ProjectWorkspace::Cargo { cargo, .. } => { + cargo.packages().find_map(|pkg| { + let has_target_with_root = cargo[pkg] + .targets + .iter() + .any(|&it| crate_root_paths.contains(&cargo[it].root.as_path())); + has_target_with_root.then(|| cargo[pkg].name.clone()) + }) + } + project_model::ProjectWorkspace::Json { project, .. } => { + if !project + .crates() + .any(|(c, _)| crate_ids.iter().any(|&crate_id| crate_id == c)) + { + return None; + } + None + } + project_model::ProjectWorkspace::DetachedFiles { .. } => return None, + }; + Some((idx, package)) }); // Find and trigger corresponding flychecks for flycheck in world.flycheck.iter() { - for (id, _) in workspace_ids.clone() { + for (id, package) in workspace_ids.clone() { if id == flycheck.id() { updated = true; - flycheck.restart(); + match package.filter(|_| !world.config.flycheck_workspace()) { + Some(package) => flycheck.restart_for_package(package), + None => flycheck.restart_workspace(), + } continue; } } @@ -309,7 +330,7 @@ fn run_flycheck(state: &mut GlobalState, vfs_path: VfsPath) -> bool { // No specific flycheck was triggered, so let's trigger all of them. if !updated { for flycheck in world.flycheck.iter() { - flycheck.restart(); + flycheck.restart_workspace(); } } Ok(()) @@ -351,7 +372,7 @@ pub(crate) fn handle_run_flycheck( } // No specific flycheck was triggered, so let's trigger all of them. for flycheck in state.flycheck.iter() { - flycheck.restart(); + flycheck.restart_workspace(); } Ok(()) } diff --git a/crates/rust-analyzer/src/handlers/request.rs b/crates/rust-analyzer/src/handlers/request.rs index 2be2ba5c44..2a3633a48e 100644 --- a/crates/rust-analyzer/src/handlers/request.rs +++ b/crates/rust-analyzer/src/handlers/request.rs @@ -54,7 +54,7 @@ pub(crate) fn handle_workspace_reload(state: &mut GlobalState, _: ()) -> anyhow: state.proc_macro_clients = Arc::from_iter([]); state.proc_macro_changed = false; - state.fetch_workspaces_queue.request_op("reload workspace request".to_string(), false); + state.fetch_workspaces_queue.request_op("reload workspace request".to_owned(), false); Ok(()) } @@ -62,7 +62,7 @@ pub(crate) fn handle_proc_macros_rebuild(state: &mut GlobalState, _: ()) -> anyh state.proc_macro_clients = Arc::from_iter([]); state.proc_macro_changed = false; - state.fetch_build_data_queue.request_op("rebuild proc macros request".to_string(), ()); + state.fetch_build_data_queue.request_op("rebuild proc macros request".to_owned(), ()); Ok(()) } @@ -562,7 +562,7 @@ pub(crate) fn handle_will_rename_files( (Some(p1), Some(p2)) if p1 == p2 => { if from_path.is_dir() { // add '/' to end of url -- from `file://path/to/folder` to `file://path/to/folder/` - let mut old_folder_name = from_path.file_stem()?.to_str()?.to_string(); + let mut old_folder_name = from_path.file_stem()?.to_str()?.to_owned(); old_folder_name.push('/'); let from_with_trailing_slash = from.join(&old_folder_name).ok()?; @@ -570,7 +570,7 @@ pub(crate) fn handle_will_rename_files( let new_file_name = to_path.file_name()?.to_str()?; Some(( snap.url_to_file_id(&imitate_from_url).ok()?, - new_file_name.to_string(), + new_file_name.to_owned(), )) } else { let old_name = from_path.file_stem()?.to_str()?; @@ -578,7 +578,7 @@ pub(crate) fn handle_will_rename_files( match (old_name, new_name) { ("mod", _) => None, (_, "mod") => None, - _ => Some((snap.url_to_file_id(&from).ok()?, new_name.to_string())), + _ => Some((snap.url_to_file_id(&from).ok()?, new_name.to_owned())), } } } @@ -799,13 +799,13 @@ pub(crate) fn handle_runnables( None => { if !snap.config.linked_or_discovered_projects().is_empty() { res.push(lsp_ext::Runnable { - label: "cargo check --workspace".to_string(), + label: "cargo check --workspace".to_owned(), location: None, kind: lsp_ext::RunnableKind::Cargo, args: lsp_ext::CargoRunnable { workspace_root: None, override_cargo: config.override_cargo, - cargo_args: vec!["check".to_string(), "--workspace".to_string()], + cargo_args: vec!["check".to_owned(), "--workspace".to_owned()], cargo_extra_args: config.cargo_extra_args, executable_args: Vec::new(), expect_test: None, @@ -879,7 +879,7 @@ pub(crate) fn handle_completion_resolve( if !all_edits_are_disjoint(&original_completion, &[]) { return Err(invalid_params_error( - "Received a completion with overlapping edits, this is not LSP-compliant".to_string(), + "Received a completion with overlapping edits, this is not LSP-compliant".to_owned(), ) .into()); } @@ -1191,7 +1191,7 @@ pub(crate) fn handle_code_action_resolve( let _p = tracing::span!(tracing::Level::INFO, "handle_code_action_resolve").entered(); let params = match code_action.data.take() { Some(it) => it, - None => return Err(invalid_params_error("code action without data".to_string()).into()), + None => return Err(invalid_params_error("code action without data".to_owned()).into()), }; let file_id = from_proto::file_id(&snap, ¶ms.code_action_params.text_document.uri)?; @@ -1270,7 +1270,7 @@ fn parse_action_id(action_id: &str) -> anyhow::Result<(usize, SingleResolve), St }; Ok((index, SingleResolve { assist_id: assist_id_string.to_string(), assist_kind })) } - _ => Err("Action id contains incorrect number of segments".to_string()), + _ => Err("Action id contains incorrect number of segments".to_owned()), } } diff --git a/crates/rust-analyzer/src/lsp/ext.rs b/crates/rust-analyzer/src/lsp/ext.rs index 35c8fad374..aa40728ce6 100644 --- a/crates/rust-analyzer/src/lsp/ext.rs +++ b/crates/rust-analyzer/src/lsp/ext.rs @@ -1,6 +1,8 @@ //! rust-analyzer extensions to the LSP. -use std::{collections::HashMap, path::PathBuf}; +#![allow(clippy::disallowed_types)] + +use std::path::PathBuf; use ide_db::line_index::WideEncoding; use lsp_types::request::Request; @@ -9,6 +11,7 @@ use lsp_types::{ PartialResultParams, Position, Range, TextDocumentIdentifier, WorkDoneProgressParams, }; use lsp_types::{PositionEncodingKind, Url}; +use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use crate::line_index::PositionEncoding; @@ -448,12 +451,16 @@ pub struct CodeActionData { #[serde(rename_all = "camelCase")] pub struct SnippetWorkspaceEdit { #[serde(skip_serializing_if = "Option::is_none")] - pub changes: Option>>, + pub changes: Option>>, #[serde(skip_serializing_if = "Option::is_none")] pub document_changes: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub change_annotations: - Option>, + pub change_annotations: Option< + std::collections::HashMap< + lsp_types::ChangeAnnotationIdentifier, + lsp_types::ChangeAnnotation, + >, + >, } #[derive(Debug, Eq, PartialEq, Clone, Deserialize, Serialize)] @@ -696,3 +703,16 @@ pub struct CompletionImport { pub struct ClientCommandOptions { pub commands: Vec, } + +pub enum UnindexedProject {} + +impl Notification for UnindexedProject { + type Params = UnindexedProjectParams; + const METHOD: &'static str = "rust-analyzer/unindexedProject"; +} + +#[derive(Deserialize, Serialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct UnindexedProjectParams { + pub text_documents: Vec, +} diff --git a/crates/rust-analyzer/src/lsp/from_proto.rs b/crates/rust-analyzer/src/lsp/from_proto.rs index 9923be382b..f42985a916 100644 --- a/crates/rust-analyzer/src/lsp/from_proto.rs +++ b/crates/rust-analyzer/src/lsp/from_proto.rs @@ -108,7 +108,7 @@ pub(crate) fn annotation( code_lens: lsp_types::CodeLens, ) -> anyhow::Result> { let data = - code_lens.data.ok_or_else(|| invalid_params_error("code lens without data".to_string()))?; + code_lens.data.ok_or_else(|| invalid_params_error("code lens without data".to_owned()))?; let resolve = from_json::("CodeLensResolveData", &data)?; match resolve.kind { diff --git a/crates/rust-analyzer/src/lsp/to_proto.rs b/crates/rust-analyzer/src/lsp/to_proto.rs index d363ac69fd..64f19f0b32 100644 --- a/crates/rust-analyzer/src/lsp/to_proto.rs +++ b/crates/rust-analyzer/src/lsp/to_proto.rs @@ -245,7 +245,7 @@ fn completion_item( ) { let insert_replace_support = config.insert_replace_support().then_some(tdpp.position); let ref_match = item.ref_match(); - let lookup = item.lookup().to_string(); + let lookup = item.lookup().to_owned(); let mut additional_text_edits = Vec::new(); @@ -367,7 +367,7 @@ pub(crate) fn signature_help( let params = call_info .parameter_labels() .map(|label| lsp_types::ParameterInformation { - label: lsp_types::ParameterLabel::Simple(label.to_string()), + label: lsp_types::ParameterLabel::Simple(label.to_owned()), documentation: None, }) .collect::>(); @@ -1498,7 +1498,7 @@ pub(crate) mod command { pub(crate) fn run_single(runnable: &lsp_ext::Runnable, title: &str) -> lsp_types::Command { lsp_types::Command { - title: title.to_string(), + title: title.to_owned(), command: "rust-analyzer.runSingle".into(), arguments: Some(vec![to_value(runnable).unwrap()]), } @@ -1608,7 +1608,7 @@ fn main() { } }"#; - let (analysis, file_id) = Analysis::from_single_file(text.to_string()); + let (analysis, file_id) = Analysis::from_single_file(text.to_owned()); let folds = analysis.folding_ranges(file_id).unwrap(); assert_eq!(folds.len(), 4); diff --git a/crates/rust-analyzer/src/lsp/utils.rs b/crates/rust-analyzer/src/lsp/utils.rs index fa5ea5b57d..10335cb145 100644 --- a/crates/rust-analyzer/src/lsp/utils.rs +++ b/crates/rust-analyzer/src/lsp/utils.rs @@ -333,21 +333,20 @@ mod tests { #[test] fn empty_completion_disjoint_tests() { - let empty_completion = - CompletionItem::new_simple("label".to_string(), "detail".to_string()); + let empty_completion = CompletionItem::new_simple("label".to_owned(), "detail".to_owned()); let disjoint_edit_1 = lsp_types::TextEdit::new( Range::new(Position::new(2, 2), Position::new(3, 3)), - "new_text".to_string(), + "new_text".to_owned(), ); let disjoint_edit_2 = lsp_types::TextEdit::new( Range::new(Position::new(3, 3), Position::new(4, 4)), - "new_text".to_string(), + "new_text".to_owned(), ); let joint_edit = lsp_types::TextEdit::new( Range::new(Position::new(1, 1), Position::new(5, 5)), - "new_text".to_string(), + "new_text".to_owned(), ); assert!( @@ -375,19 +374,19 @@ mod tests { fn completion_with_joint_edits_disjoint_tests() { let disjoint_edit = lsp_types::TextEdit::new( Range::new(Position::new(1, 1), Position::new(2, 2)), - "new_text".to_string(), + "new_text".to_owned(), ); let disjoint_edit_2 = lsp_types::TextEdit::new( Range::new(Position::new(2, 2), Position::new(3, 3)), - "new_text".to_string(), + "new_text".to_owned(), ); let joint_edit = lsp_types::TextEdit::new( Range::new(Position::new(1, 1), Position::new(5, 5)), - "new_text".to_string(), + "new_text".to_owned(), ); let mut completion_with_joint_edits = - CompletionItem::new_simple("label".to_string(), "detail".to_string()); + CompletionItem::new_simple("label".to_owned(), "detail".to_owned()); completion_with_joint_edits.additional_text_edits = Some(vec![disjoint_edit.clone(), joint_edit.clone()]); assert!( @@ -405,7 +404,7 @@ mod tests { completion_with_joint_edits.text_edit = Some(CompletionTextEdit::InsertAndReplace(InsertReplaceEdit { - new_text: "new_text".to_string(), + new_text: "new_text".to_owned(), insert: disjoint_edit.range, replace: disjoint_edit_2.range, })); @@ -420,19 +419,19 @@ mod tests { fn completion_with_disjoint_edits_disjoint_tests() { let disjoint_edit = lsp_types::TextEdit::new( Range::new(Position::new(1, 1), Position::new(2, 2)), - "new_text".to_string(), + "new_text".to_owned(), ); let disjoint_edit_2 = lsp_types::TextEdit::new( Range::new(Position::new(2, 2), Position::new(3, 3)), - "new_text".to_string(), + "new_text".to_owned(), ); let joint_edit = lsp_types::TextEdit::new( Range::new(Position::new(1, 1), Position::new(5, 5)), - "new_text".to_string(), + "new_text".to_owned(), ); let mut completion_with_disjoint_edits = - CompletionItem::new_simple("label".to_string(), "detail".to_string()); + CompletionItem::new_simple("label".to_owned(), "detail".to_owned()); completion_with_disjoint_edits.text_edit = Some(CompletionTextEdit::Edit(disjoint_edit)); let completion_with_disjoint_edits = completion_with_disjoint_edits; diff --git a/crates/rust-analyzer/src/main_loop.rs b/crates/rust-analyzer/src/main_loop.rs index f3ead6d04f..88660db7e9 100644 --- a/crates/rust-analyzer/src/main_loop.rs +++ b/crates/rust-analyzer/src/main_loop.rs @@ -1,5 +1,6 @@ //! The main loop of `rust-analyzer` responsible for dispatching LSP //! requests/replies and notifications back to the client. +use crate::lsp::ext; use std::{ fmt, time::{Duration, Instant}, @@ -56,6 +57,7 @@ pub fn main_loop(config: Config, connection: Connection) -> anyhow::Result<()> { enum Event { Lsp(lsp_server::Message), Task(Task), + QueuedTask(QueuedTask), Vfs(vfs::loader::Message), Flycheck(flycheck::Message), } @@ -67,13 +69,20 @@ impl fmt::Display for Event { Event::Task(_) => write!(f, "Event::Task"), Event::Vfs(_) => write!(f, "Event::Vfs"), Event::Flycheck(_) => write!(f, "Event::Flycheck"), + Event::QueuedTask(_) => write!(f, "Event::QueuedTask"), } } } +#[derive(Debug)] +pub(crate) enum QueuedTask { + CheckIfIndexed(lsp_types::Url), +} + #[derive(Debug)] pub(crate) enum Task { Response(lsp_server::Response), + ClientNotification(ext::UnindexedProjectParams), Retry(lsp_server::Request), Diagnostics(Vec<(FileId, Vec)>), PrimeCaches(PrimeCachesProgress), @@ -115,6 +124,7 @@ impl fmt::Debug for Event { match self { Event::Lsp(it) => fmt::Debug::fmt(it, f), Event::Task(it) => fmt::Debug::fmt(it, f), + Event::QueuedTask(it) => fmt::Debug::fmt(it, f), Event::Vfs(it) => fmt::Debug::fmt(it, f), Event::Flycheck(it) => fmt::Debug::fmt(it, f), } @@ -129,7 +139,7 @@ impl GlobalState { self.register_did_save_capability(); } - self.fetch_workspaces_queue.request_op("startup".to_string(), false); + self.fetch_workspaces_queue.request_op("startup".to_owned(), false); if let Some((cause, force_crate_graph_reload)) = self.fetch_workspaces_queue.should_start_op() { @@ -175,8 +185,8 @@ impl GlobalState { }; let registration = lsp_types::Registration { - id: "textDocument/didSave".to_string(), - method: "textDocument/didSave".to_string(), + id: "textDocument/didSave".to_owned(), + method: "textDocument/didSave".to_owned(), register_options: Some(serde_json::to_value(save_registration_options).unwrap()), }; self.send_request::( @@ -193,6 +203,9 @@ impl GlobalState { recv(self.task_pool.receiver) -> task => Some(Event::Task(task.unwrap())), + recv(self.deferred_task_queue.receiver) -> task => + Some(Event::QueuedTask(task.unwrap())), + recv(self.fmt_pool.receiver) -> task => Some(Event::Task(task.unwrap())), @@ -211,7 +224,7 @@ impl GlobalState { .entered(); let event_dbg_msg = format!("{event:?}"); - tracing::debug!("{:?} handle_event({})", loop_start, event_dbg_msg); + tracing::debug!(?loop_start, ?event, "handle_event"); if tracing::enabled!(tracing::Level::INFO) { let task_queue_len = self.task_pool.handle.len(); if task_queue_len > 0 { @@ -226,6 +239,16 @@ impl GlobalState { lsp_server::Message::Notification(not) => self.on_notification(not)?, lsp_server::Message::Response(resp) => self.complete_request(resp), }, + Event::QueuedTask(task) => { + let _p = + tracing::span!(tracing::Level::INFO, "GlobalState::handle_event/queued_task") + .entered(); + self.handle_queued_task(task); + // Coalesce multiple task events into one loop turn + while let Ok(task) = self.deferred_task_queue.receiver.try_recv() { + self.handle_queued_task(task); + } + } Event::Task(task) => { let _p = tracing::span!(tracing::Level::INFO, "GlobalState::handle_event/task") .entered(); @@ -273,7 +296,7 @@ impl GlobalState { self.prime_caches_queue.op_completed(()); if cancelled { self.prime_caches_queue - .request_op("restart after cancellation".to_string(), ()); + .request_op("restart after cancellation".to_owned(), ()); } } }; @@ -314,10 +337,10 @@ impl GlobalState { if became_quiescent { if self.config.check_on_save() { // Project has loaded properly, kick off initial flycheck - self.flycheck.iter().for_each(FlycheckHandle::restart); + self.flycheck.iter().for_each(FlycheckHandle::restart_workspace); } if self.config.prefill_caches() { - self.prime_caches_queue.request_op("became quiescent".to_string(), ()); + self.prime_caches_queue.request_op("became quiescent".to_owned(), ()); } } @@ -367,7 +390,7 @@ impl GlobalState { // See https://github.com/rust-lang/rust-analyzer/issues/13130 let patch_empty = |message: &mut String| { if message.is_empty() { - *message = " ".to_string(); + *message = " ".to_owned(); } }; @@ -498,6 +521,9 @@ impl GlobalState { fn handle_task(&mut self, prime_caches_progress: &mut Vec, task: Task) { match task { Task::Response(response) => self.respond(response), + Task::ClientNotification(params) => { + self.send_notification::(params) + } // Only retry requests that haven't been cancelled. Otherwise we do unnecessary work. Task::Retry(req) if !self.is_completed(&req) => self.on_request(req), Task::Retry(_) => (), @@ -531,12 +557,12 @@ impl GlobalState { } let old = Arc::clone(&self.workspaces); - self.switch_workspaces("fetched workspace".to_string()); + self.switch_workspaces("fetched workspace".to_owned()); let workspaces_updated = !Arc::ptr_eq(&old, &self.workspaces); if self.config.run_build_scripts() && workspaces_updated { self.fetch_build_data_queue - .request_op("workspace updated".to_string(), ()); + .request_op("workspace updated".to_owned(), ()); } (Progress::End, None) @@ -555,7 +581,7 @@ impl GlobalState { tracing::error!("FetchBuildDataError:\n{e}"); } - self.switch_workspaces("fetched build data".to_string()); + self.switch_workspaces("fetched build data".to_owned()); self.send_hint_refresh_query = true; (Some(Progress::End), None) @@ -638,6 +664,31 @@ impl GlobalState { } } + fn handle_queued_task(&mut self, task: QueuedTask) { + match task { + QueuedTask::CheckIfIndexed(uri) => { + let snap = self.snapshot(); + + self.task_pool.handle.spawn_with_sender(ThreadIntent::Worker, move |sender| { + let _p = tracing::span!(tracing::Level::INFO, "GlobalState::check_if_indexed") + .entered(); + tracing::debug!(?uri, "handling uri"); + let id = from_proto::file_id(&snap, &uri).expect("unable to get FileId"); + if let Ok(crates) = &snap.analysis.crates_for(id) { + if crates.is_empty() { + let params = ext::UnindexedProjectParams { + text_documents: vec![lsp_types::TextDocumentIdentifier { uri }], + }; + sender.send(Task::ClientNotification(params)).unwrap(); + } else { + tracing::debug!(?uri, "is indexed"); + } + } + }); + } + } + } + fn handle_flycheck_msg(&mut self, message: flycheck::Message) { match message { flycheck::Message::AddDiagnostic { id, workspace_root, diagnostic } => { diff --git a/crates/rust-analyzer/src/reload.rs b/crates/rust-analyzer/src/reload.rs index 65c00cc08d..7bd2877b00 100644 --- a/crates/rust-analyzer/src/reload.rs +++ b/crates/rust-analyzer/src/reload.rs @@ -83,7 +83,7 @@ impl GlobalState { } if self.config.linked_or_discovered_projects() != old_config.linked_or_discovered_projects() { - self.fetch_workspaces_queue.request_op("linked projects changed".to_string(), false) + self.fetch_workspaces_queue.request_op("linked projects changed".to_owned(), false) } else if self.config.flycheck() != old_config.flycheck() { self.reload_flycheck(); } @@ -440,8 +440,8 @@ impl GlobalState { .collect(), }; let registration = lsp_types::Registration { - id: "workspace/didChangeWatchedFiles".to_string(), - method: "workspace/didChangeWatchedFiles".to_string(), + id: "workspace/didChangeWatchedFiles".to_owned(), + method: "workspace/didChangeWatchedFiles".to_owned(), register_options: Some(serde_json::to_value(registration_options).unwrap()), }; self.send_request::( @@ -503,7 +503,7 @@ impl GlobalState { let mut crate_graph_file_dependencies = FxHashSet::default(); let mut load = |path: &AbsPath| { - let _p = tracing::span!(tracing::Level::INFO, "switch_workspaces::load").entered(); + let _p = tracing::span!(tracing::Level::DEBUG, "switch_workspaces::load").entered(); let vfs_path = vfs::VfsPath::from(path.to_path_buf()); crate_graph_file_dependencies.insert(vfs_path.clone()); match vfs.file_id(&vfs_path) { @@ -528,10 +528,16 @@ impl GlobalState { (crate_graph, proc_macros, crate_graph_file_dependencies) }; + let mut change = Change::new(); if self.config.expand_proc_macros() { + change.set_proc_macros( + crate_graph + .iter() + .map(|id| (id, Err("Proc-macros have not been built yet".to_owned()))) + .collect(), + ); self.fetch_proc_macros_queue.request_op(cause, proc_macro_paths); } - let mut change = Change::new(); change.set_crate_graph(crate_graph); self.analysis_host.apply_change(change); self.crate_graph_file_dependencies = crate_graph_file_dependencies; diff --git a/crates/rust-analyzer/src/task_pool.rs b/crates/rust-analyzer/src/task_pool.rs index a5a10e8691..f7de5fb2ff 100644 --- a/crates/rust-analyzer/src/task_pool.rs +++ b/crates/rust-analyzer/src/task_pool.rs @@ -4,6 +4,8 @@ use crossbeam_channel::Sender; use stdx::thread::{Pool, ThreadIntent}; +use crate::main_loop::QueuedTask; + pub(crate) struct TaskPool { sender: Sender, pool: Pool, @@ -40,3 +42,12 @@ impl TaskPool { self.pool.len() } } + +/// `TaskQueue`, like its name suggests, queues tasks. +/// +/// This should only be used used if a task must run after [`GlobalState::process_changes`] +/// has been called. +pub(crate) struct TaskQueue { + pub(crate) sender: crossbeam_channel::Sender, + pub(crate) receiver: crossbeam_channel::Receiver, +} diff --git a/crates/rust-analyzer/src/tracing/hprof.rs b/crates/rust-analyzer/src/tracing/hprof.rs index c99b551df8..9064987329 100644 --- a/crates/rust-analyzer/src/tracing/hprof.rs +++ b/crates/rust-analyzer/src/tracing/hprof.rs @@ -179,6 +179,7 @@ impl Node { self.go(0, filter) } + #[allow(clippy::print_stderr)] fn go(&self, level: usize, filter: &WriteFilter) { if self.duration > filter.longer_than && level < filter.depth { let duration = ms(self.duration); diff --git a/crates/rust-analyzer/tests/slow-tests/main.rs b/crates/rust-analyzer/tests/slow-tests/main.rs index 19890110d5..79ae0c30cf 100644 --- a/crates/rust-analyzer/tests/slow-tests/main.rs +++ b/crates/rust-analyzer/tests/slow-tests/main.rs @@ -9,6 +9,7 @@ //! be sure without a real client anyway. #![warn(rust_2018_idioms, unused_lifetimes)] +#![allow(clippy::disallowed_types)] #[cfg(not(feature = "in-rust-tree"))] mod sourcegen; @@ -29,7 +30,7 @@ use lsp_types::{ PartialResultParams, Position, Range, RenameFilesParams, TextDocumentItem, TextDocumentPositionParams, WorkDoneProgressParams, }; -use rust_analyzer::lsp::ext::{OnEnter, Runnables, RunnablesParams}; +use rust_analyzer::lsp::ext::{OnEnter, Runnables, RunnablesParams, UnindexedProject}; use serde_json::json; use stdx::format_to_acc; use test_utils::skip_slow_tests; @@ -586,6 +587,66 @@ fn main() {{}} ); } +#[test] +fn test_opening_a_file_outside_of_indexed_workspace() { + if skip_slow_tests() { + return; + } + + let tmp_dir = TestDir::new(); + let path = tmp_dir.path(); + + let project = json!({ + "roots": [path], + "crates": [ { + "root_module": path.join("src/crate_one/lib.rs"), + "deps": [], + "edition": "2015", + "cfg": [ "cfg_atom_1", "feature=\"cfg_1\""], + } ] + }); + + let code = format!( + r#" +//- /rust-project.json +{project} + +//- /src/crate_one/lib.rs +mod bar; + +fn main() {{}} +"#, + ); + + let server = Project::with_fixture(&code) + .tmp_dir(tmp_dir) + .with_config(serde_json::json!({ + "notifications": { + "unindexedProject": true + }, + })) + .server() + .wait_until_workspace_is_loaded(); + + let uri = server.doc_id("src/crate_two/lib.rs").uri; + server.notification::(DidOpenTextDocumentParams { + text_document: TextDocumentItem { + uri: uri.clone(), + language_id: "rust".to_owned(), + version: 0, + text: "/// Docs\nfn foo() {}".to_owned(), + }, + }); + let expected = json!({ + "textDocuments": [ + { + "uri": uri + } + ] + }); + server.expect_notification::(expected); +} + #[test] fn diagnostics_dont_block_typing() { if skip_slow_tests() { @@ -621,9 +682,9 @@ fn main() {{}} server.notification::(DidOpenTextDocumentParams { text_document: TextDocumentItem { uri: server.doc_id(&format!("src/m{i}.rs")).uri, - language_id: "rust".to_string(), + language_id: "rust".to_owned(), version: 0, - text: "/// Docs\nfn foo() {}".to_string(), + text: "/// Docs\nfn foo() {}".to_owned(), }, }); } @@ -1017,15 +1078,15 @@ use crate::old_folder::nested::foo as bar; server.request::( RenameFilesParams { files: vec![FileRename { - old_uri: base_path.join("src/old_file.rs").to_str().unwrap().to_string(), - new_uri: base_path.join("src/new_file.rs").to_str().unwrap().to_string(), + old_uri: base_path.join("src/old_file.rs").to_str().unwrap().to_owned(), + new_uri: base_path.join("src/new_file.rs").to_str().unwrap().to_owned(), }], }, json!({ "documentChanges": [ { "textDocument": { - "uri": format!("file://{}", tmp_dir_path.join("src").join("lib.rs").to_str().unwrap().to_string().replace("C:\\", "/c:/").replace('\\', "/")), + "uri": format!("file://{}", tmp_dir_path.join("src").join("lib.rs").to_str().unwrap().to_owned().replace("C:\\", "/c:/").replace('\\', "/")), "version": null }, "edits": [ @@ -1052,8 +1113,8 @@ use crate::old_folder::nested::foo as bar; server.request::( RenameFilesParams { files: vec![FileRename { - old_uri: base_path.join("src/from_mod/mod.rs").to_str().unwrap().to_string(), - new_uri: base_path.join("src/from_mod/foo.rs").to_str().unwrap().to_string(), + old_uri: base_path.join("src/from_mod/mod.rs").to_str().unwrap().to_owned(), + new_uri: base_path.join("src/from_mod/foo.rs").to_str().unwrap().to_owned(), }], }, json!(null), @@ -1063,8 +1124,8 @@ use crate::old_folder::nested::foo as bar; server.request::( RenameFilesParams { files: vec![FileRename { - old_uri: base_path.join("src/to_mod/foo.rs").to_str().unwrap().to_string(), - new_uri: base_path.join("src/to_mod/mod.rs").to_str().unwrap().to_string(), + old_uri: base_path.join("src/to_mod/foo.rs").to_str().unwrap().to_owned(), + new_uri: base_path.join("src/to_mod/mod.rs").to_str().unwrap().to_owned(), }], }, json!(null), @@ -1074,15 +1135,15 @@ use crate::old_folder::nested::foo as bar; server.request::( RenameFilesParams { files: vec![FileRename { - old_uri: base_path.join("src/old_folder").to_str().unwrap().to_string(), - new_uri: base_path.join("src/new_folder").to_str().unwrap().to_string(), + old_uri: base_path.join("src/old_folder").to_str().unwrap().to_owned(), + new_uri: base_path.join("src/new_folder").to_str().unwrap().to_owned(), }], }, json!({ "documentChanges": [ { "textDocument": { - "uri": format!("file://{}", tmp_dir_path.join("src").join("lib.rs").to_str().unwrap().to_string().replace("C:\\", "/c:/").replace('\\', "/")), + "uri": format!("file://{}", tmp_dir_path.join("src").join("lib.rs").to_str().unwrap().to_owned().replace("C:\\", "/c:/").replace('\\', "/")), "version": null }, "edits": [ @@ -1103,7 +1164,7 @@ use crate::old_folder::nested::foo as bar; }, { "textDocument": { - "uri": format!("file://{}", tmp_dir_path.join("src").join("old_folder").join("nested.rs").to_str().unwrap().to_string().replace("C:\\", "/c:/").replace('\\', "/")), + "uri": format!("file://{}", tmp_dir_path.join("src").join("old_folder").join("nested.rs").to_str().unwrap().to_owned().replace("C:\\", "/c:/").replace('\\', "/")), "version": null }, "edits": [ diff --git a/crates/rust-analyzer/tests/slow-tests/support.rs b/crates/rust-analyzer/tests/slow-tests/support.rs index d699374f9c..d02cb45b8e 100644 --- a/crates/rust-analyzer/tests/slow-tests/support.rs +++ b/crates/rust-analyzer/tests/slow-tests/support.rs @@ -9,7 +9,7 @@ use std::{ use crossbeam_channel::{after, select, Receiver}; use lsp_server::{Connection, Message, Notification, Request}; use lsp_types::{notification::Exit, request::Shutdown, TextDocumentIdentifier, Url}; -use rust_analyzer::{config::Config, lsp, main_loop, tracing}; +use rust_analyzer::{config::Config, lsp, main_loop}; use serde::Serialize; use serde_json::{json, to_string_pretty, Value}; use test_utils::FixtureWithProjectMeta; @@ -91,11 +91,11 @@ impl Project<'_> { static INIT: Once = Once::new(); INIT.call_once(|| { - let _ = tracing::Config { + let _ = rust_analyzer::tracing::Config { writer: TestWriter::default(), // Deliberately enable all `error` logs if the user has not set RA_LOG, as there is usually // useful information in there for debugging. - filter: std::env::var("RA_LOG").ok().unwrap_or_else(|| "error".to_string()), + filter: std::env::var("RA_LOG").ok().unwrap_or_else(|| "error".to_owned()), chalk_filter: std::env::var("CHALK_DEBUG").ok(), profile_filter: std::env::var("RA_PROFILE").ok(), }; @@ -193,7 +193,7 @@ impl Server { let (connection, client) = Connection::memory(); let _thread = stdx::thread::Builder::new(stdx::thread::ThreadIntent::Worker) - .name("test server".to_string()) + .name("test server".to_owned()) .spawn(move || main_loop(config, connection).unwrap()) .expect("failed to spawn a thread"); @@ -210,10 +210,44 @@ impl Server { N: lsp_types::notification::Notification, N::Params: Serialize, { - let r = Notification::new(N::METHOD.to_string(), params); + let r = Notification::new(N::METHOD.to_owned(), params); self.send_notification(r) } + pub(crate) fn expect_notification(&self, expected: Value) + where + N: lsp_types::notification::Notification, + N::Params: Serialize, + { + while let Some(Message::Notification(actual)) = + recv_timeout(&self.client.receiver).unwrap_or_else(|_| panic!("timed out")) + { + if actual.method == N::METHOD { + let actual = actual + .clone() + .extract::(N::METHOD) + .expect("was not able to extract notification"); + + tracing::debug!(?actual, "got notification"); + if let Some((expected_part, actual_part)) = find_mismatch(&expected, &actual) { + panic!( + "JSON mismatch\nExpected:\n{}\nWas:\n{}\nExpected part:\n{}\nActual part:\n{}\n", + to_string_pretty(&expected).unwrap(), + to_string_pretty(&actual).unwrap(), + to_string_pretty(expected_part).unwrap(), + to_string_pretty(actual_part).unwrap(), + ); + } else { + tracing::debug!("sucessfully matched notification"); + return; + } + } else { + continue; + } + } + panic!("never got expected notification"); + } + #[track_caller] pub(crate) fn request(&self, params: R::Params, expected_resp: Value) where @@ -240,7 +274,7 @@ impl Server { let id = self.req_id.get(); self.req_id.set(id.wrapping_add(1)); - let r = Request::new(id.into(), R::METHOD.to_string(), params); + let r = Request::new(id.into(), R::METHOD.to_owned(), params); self.send_request_(r) } fn send_request_(&self, r: Request) -> Value { diff --git a/crates/rust-analyzer/tests/slow-tests/tidy.rs b/crates/rust-analyzer/tests/slow-tests/tidy.rs index d3146ab767..3e38fc3ebc 100644 --- a/crates/rust-analyzer/tests/slow-tests/tidy.rs +++ b/crates/rust-analyzer/tests/slow-tests/tidy.rs @@ -1,3 +1,4 @@ +#![allow(clippy::disallowed_types, clippy::print_stderr)] use std::{ collections::HashSet, path::{Path, PathBuf}, @@ -78,8 +79,6 @@ fn files_are_tidy() { match extension { "rs" => { let text = sh.read_file(&path).unwrap(); - check_todo(&path, &text); - check_dbg(&path, &text); check_test_attrs(&path, &text); check_trailing_ws(&path, &text); tidy_docs.visit(&path, &text); @@ -205,74 +204,6 @@ Zlib OR Apache-2.0 OR MIT assert_eq!(licenses, expected); } -fn check_todo(path: &Path, text: &str) { - let need_todo = &[ - // This file itself obviously needs to use todo (<- like this!). - "tests/tidy.rs", - // Some of our assists generate `todo!()`. - "handlers/add_turbo_fish.rs", - "handlers/generate_function.rs", - "handlers/add_missing_match_arms.rs", - "handlers/replace_derive_with_manual_impl.rs", - // To support generating `todo!()` in assists, we have `expr_todo()` in - // `ast::make`. - "ast/make.rs", - // The documentation in string literals may contain anything for its own purposes - "ide-db/src/generated/lints.rs", - "ide-assists/src/utils/gen_trait_fn_body.rs", - "ide-assists/src/tests/generated.rs", - // The tests for missing fields - "ide-diagnostics/src/handlers/missing_fields.rs", - ]; - if need_todo.iter().any(|p| path.ends_with(p)) { - return; - } - if text.contains("TODO") || text.contains("TOOD") || text.contains("todo!") { - // Generated by an assist - if text.contains("${0:todo!()}") { - return; - } - - panic!( - "\nTODO markers or todo! macros should not be committed to the master branch,\n\ - use FIXME instead\n\ - {}\n", - path.display(), - ) - } -} - -fn check_dbg(path: &Path, text: &str) { - let need_dbg = &[ - // This file itself obviously needs to use dbg. - "slow-tests/tidy.rs", - // Assists to remove `dbg!()` - "handlers/remove_dbg.rs", - // We have .dbg postfix - "ide-completion/src/completions/postfix.rs", - "ide-completion/src/completions/keyword.rs", - "ide-completion/src/tests/expression.rs", - "ide-completion/src/tests/proc_macros.rs", - // The documentation in string literals may contain anything for its own purposes - "ide-completion/src/lib.rs", - "ide-db/src/generated/lints.rs", - // test for doc test for remove_dbg - "src/tests/generated.rs", - // `expect!` string can contain `dbg!` (due to .dbg postfix) - "ide-completion/src/tests/special.rs", - ]; - if need_dbg.iter().any(|p| path.ends_with(p)) { - return; - } - if text.contains("dbg!") { - panic!( - "\ndbg! macros should not be committed to the master branch,\n\ - {}\n", - path.display(), - ) - } -} - fn check_test_attrs(path: &Path, text: &str) { let ignore_rule = "https://github.com/rust-lang/rust-analyzer/blob/master/docs/dev/style.md#ignore"; @@ -442,7 +373,7 @@ fn find_marks(set: &mut HashSet, text: &str, mark: &str) { text = stripped_text.trim_start(); if let Some(idx2) = text.find(|c: char| !(c.is_alphanumeric() || c == '_')) { let mark_text = &text[..idx2]; - set.insert(mark_text.to_string()); + set.insert(mark_text.to_owned()); text = &text[idx2..]; } } diff --git a/crates/salsa/Cargo.toml b/crates/salsa/Cargo.toml new file mode 100644 index 0000000000..4ccbc3de84 --- /dev/null +++ b/crates/salsa/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "salsa" +version = "0.0.0" +authors = ["Salsa developers"] +edition = "2021" +license = "Apache-2.0 OR MIT" +repository = "https://github.com/salsa-rs/salsa" +description = "A generic framework for on-demand, incrementalized computation (experimental)" + +rust-version.workspace = true + +[lib] +name = "salsa" + +[dependencies] +indexmap = "2.1.0" +lock_api = "0.4" +tracing = "0.1" +parking_lot = "0.12.1" +rustc-hash = "1.0" +smallvec = "1.0.0" +oorandom = "11" +triomphe = "0.1.11" + +salsa-macros = { version = "0.0.0", path = "salsa-macros" } + +[dev-dependencies] +linked-hash-map = "0.5.6" +rand = "0.8.5" +test-log = "0.2.14" +expect-test = "1.4.0" +dissimilar = "1.0.7" + +[lints] +workspace = true diff --git a/crates/salsa/FAQ.md b/crates/salsa/FAQ.md new file mode 100644 index 0000000000..9c9f6f92da --- /dev/null +++ b/crates/salsa/FAQ.md @@ -0,0 +1,34 @@ +# Frequently asked questions + +## Why is it called salsa? + +I like salsa! Don't you?! Well, ok, there's a bit more to it. The +underlying algorithm for figuring out which bits of code need to be +re-executed after any given change is based on the algorithm used in +rustc. Michael Woerister and I first described the rustc algorithm in +terms of two colors, red and green, and hence we called it the +"red-green algorithm". This made me think of the New Mexico State +Question --- ["Red or green?"][nm] --- which refers to chile +(salsa). Although this version no longer uses colors (we borrowed +revision counters from Glimmer, instead), I still like the name. + +[nm]: https://www.sos.state.nm.us/about-new-mexico/state-question/ + +## What is the relationship between salsa and an Entity-Component System (ECS)? + +You may have noticed that Salsa "feels" a lot like an ECS in some +ways. That's true -- Salsa's queries are a bit like *components* (and +the keys to the queries are a bit like *entities*). But there is one +big difference: **ECS is -- at its heart -- a mutable system**. You +can get or set a component of some entity whenever you like. In +contrast, salsa's queries **define "derived values" via pure +computations**. + +Partly as a consequence, ECS doesn't handle incremental updates for +you. When you update some component of some entity, you have to ensure +that other entities' components are updated appropriately. + +Finally, ECS offers interesting metadata and "aspect-like" facilities, +such as iterating over all entities that share certain components. +Salsa has no analogue to that. + diff --git a/crates/salsa/LICENSE-APACHE b/crates/salsa/LICENSE-APACHE new file mode 100644 index 0000000000..16fe87b06e --- /dev/null +++ b/crates/salsa/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/crates/salsa/LICENSE-MIT b/crates/salsa/LICENSE-MIT new file mode 100644 index 0000000000..31aa79387f --- /dev/null +++ b/crates/salsa/LICENSE-MIT @@ -0,0 +1,23 @@ +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/crates/salsa/README.md b/crates/salsa/README.md new file mode 100644 index 0000000000..4a8d9f8c73 --- /dev/null +++ b/crates/salsa/README.md @@ -0,0 +1,42 @@ +# salsa + +*A generic framework for on-demand, incrementalized computation.* + +## Obligatory warning + +This is a fork of https://github.com/salsa-rs/salsa/ adjusted to rust-analyzer's needs. + +## Credits + +This system is heavily inspired by [adapton](http://adapton.org/), [glimmer](https://github.com/glimmerjs/glimmer-vm), and rustc's query +system. So credit goes to Eduard-Mihai Burtescu, Matthew Hammer, +Yehuda Katz, and Michael Woerister. + +## Key idea + +The key idea of `salsa` is that you define your program as a set of +**queries**. Every query is used like function `K -> V` that maps from +some key of type `K` to a value of type `V`. Queries come in two basic +varieties: + +- **Inputs**: the base inputs to your system. You can change these + whenever you like. +- **Functions**: pure functions (no side effects) that transform your + inputs into other values. The results of queries is memoized to + avoid recomputing them a lot. When you make changes to the inputs, + we'll figure out (fairly intelligently) when we can re-use these + memoized values and when we have to recompute them. + +## Want to learn more? + +To learn more about Salsa, try one of the following: + +- read the [heavily commented `hello_world` example](https://github.com/salsa-rs/salsa/blob/master/examples/hello_world/main.rs); +- check out the [Salsa book](https://salsa-rs.github.io/salsa); +- watch one of our [videos](https://salsa-rs.github.io/salsa/videos.html). + +## Getting in touch + +The bulk of the discussion happens in the [issues](https://github.com/salsa-rs/salsa/issues) +and [pull requests](https://github.com/salsa-rs/salsa/pulls), +but we have a [zulip chat](https://salsa.zulipchat.com/) as well. diff --git a/crates/salsa/salsa-macros/Cargo.toml b/crates/salsa/salsa-macros/Cargo.toml new file mode 100644 index 0000000000..791d2f6e9f --- /dev/null +++ b/crates/salsa/salsa-macros/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "salsa-macros" +version = "0.0.0" +authors = ["Salsa developers"] +edition = "2021" +license = "Apache-2.0 OR MIT" +repository = "https://github.com/salsa-rs/salsa" +description = "Procedural macros for the salsa crate" + +rust-version.workspace = true + +[lib] +proc-macro = true +name = "salsa_macros" + +[dependencies] +heck = "0.4" +proc-macro2 = "1.0" +quote = "1.0" +syn = { version = "2.0", features = ["full", "extra-traits"] } + +[lints] +workspace = true diff --git a/crates/salsa/salsa-macros/LICENSE-APACHE b/crates/salsa/salsa-macros/LICENSE-APACHE new file mode 100644 index 0000000000..0bf2cad648 --- /dev/null +++ b/crates/salsa/salsa-macros/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE diff --git a/crates/salsa/salsa-macros/LICENSE-MIT b/crates/salsa/salsa-macros/LICENSE-MIT new file mode 100644 index 0000000000..d99cce5f72 --- /dev/null +++ b/crates/salsa/salsa-macros/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT diff --git a/crates/salsa/salsa-macros/README.md b/crates/salsa/salsa-macros/README.md new file mode 100644 index 0000000000..94389aee61 --- /dev/null +++ b/crates/salsa/salsa-macros/README.md @@ -0,0 +1 @@ +../README.md diff --git a/crates/salsa/salsa-macros/src/database_storage.rs b/crates/salsa/salsa-macros/src/database_storage.rs new file mode 100644 index 0000000000..0ec75bb043 --- /dev/null +++ b/crates/salsa/salsa-macros/src/database_storage.rs @@ -0,0 +1,250 @@ +//! +use heck::ToSnakeCase; +use proc_macro::TokenStream; +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::{Ident, ItemStruct, Path, Token}; + +type PunctuatedQueryGroups = Punctuated; + +pub(crate) fn database(args: TokenStream, input: TokenStream) -> TokenStream { + let args = syn::parse_macro_input!(args as QueryGroupList); + let input = syn::parse_macro_input!(input as ItemStruct); + + let query_groups = &args.query_groups; + let database_name = &input.ident; + let visibility = &input.vis; + let db_storage_field = quote! { storage }; + + let mut output = proc_macro2::TokenStream::new(); + output.extend(quote! { #input }); + + let query_group_names_snake: Vec<_> = query_groups + .iter() + .map(|query_group| { + let group_name = query_group.name(); + Ident::new(&group_name.to_string().to_snake_case(), group_name.span()) + }) + .collect(); + + let query_group_storage_names: Vec<_> = query_groups + .iter() + .map(|QueryGroup { group_path }| { + quote! { + <#group_path as salsa::plumbing::QueryGroup>::GroupStorage + } + }) + .collect(); + + // For each query group `foo::MyGroup` create a link to its + // `foo::MyGroupGroupStorage` + let mut storage_fields = proc_macro2::TokenStream::new(); + let mut storage_initializers = proc_macro2::TokenStream::new(); + let mut has_group_impls = proc_macro2::TokenStream::new(); + for (((query_group, group_name_snake), group_storage), group_index) in query_groups + .iter() + .zip(&query_group_names_snake) + .zip(&query_group_storage_names) + .zip(0_u16..) + { + let group_path = &query_group.group_path; + + // rewrite the last identifier (`MyGroup`, above) to + // (e.g.) `MyGroupGroupStorage`. + storage_fields.extend(quote! { + #group_name_snake: #group_storage, + }); + + // rewrite the last identifier (`MyGroup`, above) to + // (e.g.) `MyGroupGroupStorage`. + storage_initializers.extend(quote! { + #group_name_snake: #group_storage::new(#group_index), + }); + + // ANCHOR:HasQueryGroup + has_group_impls.extend(quote! { + impl salsa::plumbing::HasQueryGroup<#group_path> for #database_name { + fn group_storage(&self) -> &#group_storage { + &self.#db_storage_field.query_store().#group_name_snake + } + + fn group_storage_mut(&mut self) -> (&#group_storage, &mut salsa::Runtime) { + let (query_store_mut, runtime) = self.#db_storage_field.query_store_mut(); + (&query_store_mut.#group_name_snake, runtime) + } + } + }); + // ANCHOR_END:HasQueryGroup + } + + // create group storage wrapper struct + output.extend(quote! { + #[doc(hidden)] + #visibility struct __SalsaDatabaseStorage { + #storage_fields + } + + impl Default for __SalsaDatabaseStorage { + fn default() -> Self { + Self { + #storage_initializers + } + } + } + }); + + // Create a tuple (D1, D2, ...) where Di is the data for a given query group. + let mut database_data = vec![]; + for QueryGroup { group_path } in query_groups { + database_data.push(quote! { + <#group_path as salsa::plumbing::QueryGroup>::GroupData + }); + } + + // ANCHOR:DatabaseStorageTypes + output.extend(quote! { + impl salsa::plumbing::DatabaseStorageTypes for #database_name { + type DatabaseStorage = __SalsaDatabaseStorage; + } + }); + // ANCHOR_END:DatabaseStorageTypes + + // ANCHOR:DatabaseOps + let mut fmt_ops = proc_macro2::TokenStream::new(); + let mut maybe_changed_ops = proc_macro2::TokenStream::new(); + let mut cycle_recovery_strategy_ops = proc_macro2::TokenStream::new(); + let mut for_each_ops = proc_macro2::TokenStream::new(); + for ((QueryGroup { group_path }, group_storage), group_index) in + query_groups.iter().zip(&query_group_storage_names).zip(0_u16..) + { + fmt_ops.extend(quote! { + #group_index => { + let storage: &#group_storage = + >::group_storage(self); + storage.fmt_index(self, input, fmt) + } + }); + maybe_changed_ops.extend(quote! { + #group_index => { + let storage: &#group_storage = + >::group_storage(self); + storage.maybe_changed_after(self, input, revision) + } + }); + cycle_recovery_strategy_ops.extend(quote! { + #group_index => { + let storage: &#group_storage = + >::group_storage(self); + storage.cycle_recovery_strategy(self, input) + } + }); + for_each_ops.extend(quote! { + let storage: &#group_storage = + >::group_storage(self); + storage.for_each_query(runtime, &mut op); + }); + } + output.extend(quote! { + impl salsa::plumbing::DatabaseOps for #database_name { + fn ops_database(&self) -> &dyn salsa::Database { + self + } + + fn ops_salsa_runtime(&self) -> &salsa::Runtime { + self.#db_storage_field.salsa_runtime() + } + + fn ops_salsa_runtime_mut(&mut self) -> &mut salsa::Runtime { + self.#db_storage_field.salsa_runtime_mut() + } + + fn fmt_index( + &self, + input: salsa::DatabaseKeyIndex, + fmt: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + match input.group_index() { + #fmt_ops + i => panic!("salsa: invalid group index {}", i) + } + } + + fn maybe_changed_after( + &self, + input: salsa::DatabaseKeyIndex, + revision: salsa::Revision + ) -> bool { + match input.group_index() { + #maybe_changed_ops + i => panic!("salsa: invalid group index {}", i) + } + } + + fn cycle_recovery_strategy( + &self, + input: salsa::DatabaseKeyIndex, + ) -> salsa::plumbing::CycleRecoveryStrategy { + match input.group_index() { + #cycle_recovery_strategy_ops + i => panic!("salsa: invalid group index {}", i) + } + } + + fn for_each_query( + &self, + mut op: &mut dyn FnMut(&dyn salsa::plumbing::QueryStorageMassOps), + ) { + let runtime = salsa::Database::salsa_runtime(self); + #for_each_ops + } + } + }); + // ANCHOR_END:DatabaseOps + + output.extend(has_group_impls); + + output.into() +} + +#[derive(Clone, Debug)] +struct QueryGroupList { + query_groups: PunctuatedQueryGroups, +} + +impl Parse for QueryGroupList { + fn parse(input: ParseStream<'_>) -> syn::Result { + let query_groups: PunctuatedQueryGroups = + input.parse_terminated(QueryGroup::parse, Token![,])?; + Ok(QueryGroupList { query_groups }) + } +} + +#[derive(Clone, Debug)] +struct QueryGroup { + group_path: Path, +} + +impl QueryGroup { + /// The name of the query group trait. + fn name(&self) -> Ident { + self.group_path.segments.last().unwrap().ident.clone() + } +} + +impl Parse for QueryGroup { + /// ```ignore + /// impl HelloWorldDatabase; + /// ``` + fn parse(input: ParseStream<'_>) -> syn::Result { + let group_path: Path = input.parse()?; + Ok(QueryGroup { group_path }) + } +} + +struct Nothing; + +impl Parse for Nothing { + fn parse(_input: ParseStream<'_>) -> syn::Result { + Ok(Nothing) + } +} diff --git a/crates/salsa/salsa-macros/src/lib.rs b/crates/salsa/salsa-macros/src/lib.rs new file mode 100644 index 0000000000..8af48b1e3f --- /dev/null +++ b/crates/salsa/salsa-macros/src/lib.rs @@ -0,0 +1,146 @@ +//! This crate provides salsa's macros and attributes. + +#![recursion_limit = "256"] + +#[macro_use] +extern crate quote; + +use proc_macro::TokenStream; + +mod database_storage; +mod parenthesized; +mod query_group; + +/// The decorator that defines a salsa "query group" trait. This is a +/// trait that defines everything that a block of queries need to +/// execute, as well as defining the queries themselves that are +/// exported for others to use. +/// +/// This macro declares the "prototype" for a group of queries. It will +/// expand into a trait and a set of structs, one per query. +/// +/// For each query, you give the name of the accessor method to invoke +/// the query (e.g., `my_query`, below), as well as its parameter +/// types and the output type. You also give the name for a query type +/// (e.g., `MyQuery`, below) that represents the query, and optionally +/// other details, such as its storage. +/// +/// # Examples +/// +/// The simplest example is something like this: +/// +/// ```ignore +/// #[salsa::query_group] +/// trait TypeckDatabase { +/// #[salsa::input] // see below for other legal attributes +/// fn my_query(&self, input: u32) -> u64; +/// +/// /// Queries can have any number of inputs (including zero); if there +/// /// is not exactly one input, then the key type will be +/// /// a tuple of the input types, so in this case `(u32, f32)`. +/// fn other_query(&self, input1: u32, input2: f32) -> u64; +/// } +/// ``` +/// +/// Here is a list of legal `salsa::XXX` attributes: +/// +/// - Storage attributes: control how the query data is stored and set. These +/// are described in detail in the section below. +/// - `#[salsa::input]` +/// - `#[salsa::memoized]` +/// - `#[salsa::dependencies]` +/// - Query execution: +/// - `#[salsa::invoke(path::to::my_fn)]` -- for a non-input, this +/// indicates the function to call when a query must be +/// recomputed. The default is to call a function in the same +/// module with the same name as the query. +/// - `#[query_type(MyQueryTypeName)]` specifies the name of the +/// dummy struct created for the query. Default is the name of the +/// query, in camel case, plus the word "Query" (e.g., +/// `MyQueryQuery` and `OtherQueryQuery` in the examples above). +/// +/// # Storage attributes +/// +/// Here are the possible storage values for each query. The default +/// is `storage memoized`. +/// +/// ## Input queries +/// +/// Specifying `storage input` will give you an **input +/// query**. Unlike derived queries, whose value is given by a +/// function, input queries are explicitly set by doing +/// `db.query(QueryType).set(key, value)` (where `QueryType` is the +/// `type` specified for the query). Accessing a value that has not +/// yet been set will panic. Each time you invoke `set`, we assume the +/// value has changed, and so we will potentially re-execute derived +/// queries that read (transitively) from this input. +/// +/// ## Derived queries +/// +/// Derived queries are specified by a function. +/// +/// - `#[salsa::memoized]` (the default) -- The result is memoized +/// between calls. If the inputs have changed, we will recompute +/// the value, but then compare against the old memoized value, +/// which can significantly reduce the amount of recomputation +/// required in new revisions. This does require that the value +/// implements `Eq`. +/// - `#[salsa::dependencies]` -- does not cache the value, so it will +/// be recomputed every time it is needed. We do track the inputs, however, +/// so if they have not changed, then things that rely on this query +/// may be known not to have changed. +/// +/// ## Attribute combinations +/// +/// Some attributes are mutually exclusive. For example, it is an error to add +/// multiple storage specifiers: +/// +/// ```compile_fail +/// # use salsa_macros as salsa; +/// #[salsa::query_group] +/// trait CodegenDatabase { +/// #[salsa::input] +/// #[salsa::memoized] +/// fn my_query(&self, input: u32) -> u64; +/// } +/// ``` +/// +/// It is also an error to annotate a function to `invoke` on an `input` query: +/// +/// ```compile_fail +/// # use salsa_macros as salsa; +/// #[salsa::query_group] +/// trait CodegenDatabase { +/// #[salsa::input] +/// #[salsa::invoke(typeck::my_query)] +/// fn my_query(&self, input: u32) -> u64; +/// } +/// ``` +#[proc_macro_attribute] +pub fn query_group(args: TokenStream, input: TokenStream) -> TokenStream { + query_group::query_group(args, input) +} + +/// This attribute is placed on your database struct. It takes a list of the +/// query groups that your database supports. The format looks like so: +/// +/// ```rust,ignore +/// #[salsa::database(MyQueryGroup1, MyQueryGroup2)] +/// struct MyDatabase { +/// runtime: salsa::Runtime, // <-- your database will need this field, too +/// } +/// ``` +/// +/// Here, the struct `MyDatabase` would support the two query groups +/// `MyQueryGroup1` and `MyQueryGroup2`. In addition to the `database` +/// attribute, the struct needs to have a `runtime` field (of type +/// [`salsa::Runtime`]) and to implement the `salsa::Database` trait. +/// +/// See [the `hello_world` example][hw] for more details. +/// +/// [`salsa::Runtime`]: struct.Runtime.html +/// [hw]: https://github.com/salsa-rs/salsa/tree/master/examples/hello_world +#[proc_macro_attribute] +pub fn database(args: TokenStream, input: TokenStream) -> TokenStream { + database_storage::database(args, input) +} diff --git a/crates/salsa/salsa-macros/src/parenthesized.rs b/crates/salsa/salsa-macros/src/parenthesized.rs new file mode 100644 index 0000000000..9df41e03c1 --- /dev/null +++ b/crates/salsa/salsa-macros/src/parenthesized.rs @@ -0,0 +1,13 @@ +//! +pub(crate) struct Parenthesized(pub(crate) T); + +impl syn::parse::Parse for Parenthesized +where + T: syn::parse::Parse, +{ + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { + let content; + syn::parenthesized!(content in input); + content.parse::().map(Parenthesized) + } +} diff --git a/crates/salsa/salsa-macros/src/query_group.rs b/crates/salsa/salsa-macros/src/query_group.rs new file mode 100644 index 0000000000..e535d7ed04 --- /dev/null +++ b/crates/salsa/salsa-macros/src/query_group.rs @@ -0,0 +1,737 @@ +//! +use std::{convert::TryFrom, iter::FromIterator}; + +use crate::parenthesized::Parenthesized; +use heck::ToUpperCamelCase; +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::ToTokens; +use syn::{ + parse_macro_input, parse_quote, spanned::Spanned, Attribute, Error, FnArg, Ident, ItemTrait, + ReturnType, TraitItem, Type, +}; + +/// Implementation for `[salsa::query_group]` decorator. +pub(crate) fn query_group(args: TokenStream, input: TokenStream) -> TokenStream { + let group_struct = parse_macro_input!(args as Ident); + let input: ItemTrait = parse_macro_input!(input as ItemTrait); + // println!("args: {:#?}", args); + // println!("input: {:#?}", input); + + let input_span = input.span(); + let (trait_attrs, salsa_attrs) = filter_attrs(input.attrs); + if !salsa_attrs.is_empty() { + return Error::new(input_span, format!("unsupported attributes: {:?}", salsa_attrs)) + .to_compile_error() + .into(); + } + + let trait_vis = input.vis; + let trait_name = input.ident; + let _generics = input.generics.clone(); + let dyn_db = quote! { dyn #trait_name }; + + // Decompose the trait into the corresponding queries. + let mut queries = vec![]; + for item in input.items { + if let TraitItem::Fn(method) = item { + let query_name = method.sig.ident.to_string(); + + let mut storage = QueryStorage::Memoized; + let mut cycle = None; + let mut invoke = None; + + let mut query_type = + format_ident!("{}Query", query_name.to_string().to_upper_camel_case()); + let mut num_storages = 0; + + // Extract attributes. + let (attrs, salsa_attrs) = filter_attrs(method.attrs); + for SalsaAttr { name, tts, span } in salsa_attrs { + match name.as_str() { + "memoized" => { + storage = QueryStorage::Memoized; + num_storages += 1; + } + "dependencies" => { + storage = QueryStorage::Dependencies; + num_storages += 1; + } + "input" => { + storage = QueryStorage::Input; + num_storages += 1; + } + "interned" => { + storage = QueryStorage::Interned; + num_storages += 1; + } + "cycle" => { + cycle = Some(parse_macro_input!(tts as Parenthesized).0); + } + "invoke" => { + invoke = Some(parse_macro_input!(tts as Parenthesized).0); + } + "query_type" => { + query_type = parse_macro_input!(tts as Parenthesized).0; + } + "transparent" => { + storage = QueryStorage::Transparent; + num_storages += 1; + } + _ => { + return Error::new(span, format!("unknown salsa attribute `{}`", name)) + .to_compile_error() + .into(); + } + } + } + + let sig_span = method.sig.span(); + // Check attribute combinations. + if num_storages > 1 { + return Error::new(sig_span, "multiple storage attributes specified") + .to_compile_error() + .into(); + } + match &invoke { + Some(invoke) if storage == QueryStorage::Input => { + return Error::new( + invoke.span(), + "#[salsa::invoke] cannot be set on #[salsa::input] queries", + ) + .to_compile_error() + .into(); + } + _ => {} + } + + // Extract keys. + let mut iter = method.sig.inputs.iter(); + let self_receiver = match iter.next() { + Some(FnArg::Receiver(sr)) if sr.mutability.is_none() => sr, + _ => { + return Error::new( + sig_span, + format!("first argument of query `{}` must be `&self`", query_name), + ) + .to_compile_error() + .into(); + } + }; + let mut keys: Vec<(Ident, Type)> = vec![]; + for (idx, arg) in iter.enumerate() { + match arg { + FnArg::Typed(syn::PatType { pat, ty, .. }) => keys.push(( + match pat.as_ref() { + syn::Pat::Ident(ident_pat) => ident_pat.ident.clone(), + _ => format_ident!("key{}", idx), + }, + Type::clone(ty), + )), + arg => { + return Error::new( + arg.span(), + format!("unsupported argument `{:?}` of `{}`", arg, query_name,), + ) + .to_compile_error() + .into(); + } + } + } + + // Extract value. + let value = match method.sig.output { + ReturnType::Type(_, ref ty) => ty.as_ref().clone(), + ref ret => { + return Error::new( + ret.span(), + format!("unsupported return type `{:?}` of `{}`", ret, query_name), + ) + .to_compile_error() + .into(); + } + }; + + // For `#[salsa::interned]` keys, we create a "lookup key" automatically. + // + // For a query like: + // + // fn foo(&self, x: Key1, y: Key2) -> u32 + // + // we would create + // + // fn lookup_foo(&self, x: u32) -> (Key1, Key2) + let lookup_query = if let QueryStorage::Interned = storage { + let lookup_query_type = + format_ident!("{}LookupQuery", query_name.to_string().to_upper_camel_case()); + let lookup_fn_name = format_ident!("lookup_{}", query_name); + let keys = keys.iter().map(|(_, ty)| ty); + let lookup_value: Type = parse_quote!((#(#keys),*)); + let lookup_keys = vec![(parse_quote! { key }, value.clone())]; + Some(Query { + query_type: lookup_query_type, + query_name: format!("{}", lookup_fn_name), + fn_name: lookup_fn_name, + receiver: self_receiver.clone(), + attrs: vec![], // FIXME -- some automatically generated docs on this method? + storage: QueryStorage::InternedLookup { intern_query_type: query_type.clone() }, + keys: lookup_keys, + value: lookup_value, + invoke: None, + cycle: cycle.clone(), + }) + } else { + None + }; + + queries.push(Query { + query_type, + query_name, + fn_name: method.sig.ident, + receiver: self_receiver.clone(), + attrs, + storage, + keys, + value, + invoke, + cycle, + }); + + queries.extend(lookup_query); + } + } + + let group_storage = format_ident!("{}GroupStorage__", trait_name, span = Span::call_site()); + + let mut query_fn_declarations = proc_macro2::TokenStream::new(); + let mut query_fn_definitions = proc_macro2::TokenStream::new(); + let mut storage_fields = proc_macro2::TokenStream::new(); + let mut queries_with_storage = vec![]; + for query in &queries { + #[allow(clippy::map_identity)] + // clippy is incorrect here, this is not the identity function due to match ergonomics + let (key_names, keys): (Vec<_>, Vec<_>) = query.keys.iter().map(|(a, b)| (a, b)).unzip(); + let value = &query.value; + let fn_name = &query.fn_name; + let qt = &query.query_type; + let attrs = &query.attrs; + let self_receiver = &query.receiver; + + query_fn_declarations.extend(quote! { + #(#attrs)* + fn #fn_name(#self_receiver, #(#key_names: #keys),*) -> #value; + }); + + // Special case: transparent queries don't create actual storage, + // just inline the definition + if let QueryStorage::Transparent = query.storage { + let invoke = query.invoke_tt(); + query_fn_definitions.extend(quote! { + fn #fn_name(&self, #(#key_names: #keys),*) -> #value { + #invoke(self, #(#key_names),*) + } + }); + continue; + } + + queries_with_storage.push(fn_name); + + query_fn_definitions.extend(quote! { + fn #fn_name(&self, #(#key_names: #keys),*) -> #value { + // Create a shim to force the code to be monomorphized in the + // query crate. Our experiments revealed that this makes a big + // difference in total compilation time in rust-analyzer, though + // it's not totally obvious why that should be. + fn __shim(db: &(dyn #trait_name + '_), #(#key_names: #keys),*) -> #value { + salsa::plumbing::get_query_table::<#qt>(db).get((#(#key_names),*)) + } + __shim(self, #(#key_names),*) + + } + }); + + // For input queries, we need `set_foo` etc + if let QueryStorage::Input = query.storage { + let set_fn_name = format_ident!("set_{}", fn_name); + let set_with_durability_fn_name = format_ident!("set_{}_with_durability", fn_name); + + let set_fn_docs = format!( + " + Set the value of the `{fn_name}` input. + + See `{fn_name}` for details. + + *Note:* Setting values will trigger cancellation + of any ongoing queries; this method blocks until + those queries have been cancelled. + ", + fn_name = fn_name + ); + + let set_constant_fn_docs = format!( + " + Set the value of the `{fn_name}` input with a + specific durability instead of the default of + `Durability::LOW`. You can use `Durability::MAX` + to promise that its value will never change again. + + See `{fn_name}` for details. + + *Note:* Setting values will trigger cancellation + of any ongoing queries; this method blocks until + those queries have been cancelled. + ", + fn_name = fn_name + ); + + query_fn_declarations.extend(quote! { + # [doc = #set_fn_docs] + fn #set_fn_name(&mut self, #(#key_names: #keys,)* value__: #value); + + + # [doc = #set_constant_fn_docs] + fn #set_with_durability_fn_name(&mut self, #(#key_names: #keys,)* value__: #value, durability__: salsa::Durability); + }); + + query_fn_definitions.extend(quote! { + fn #set_fn_name(&mut self, #(#key_names: #keys,)* value__: #value) { + fn __shim(db: &mut dyn #trait_name, #(#key_names: #keys,)* value__: #value) { + salsa::plumbing::get_query_table_mut::<#qt>(db).set((#(#key_names),*), value__) + } + __shim(self, #(#key_names,)* value__) + } + + fn #set_with_durability_fn_name(&mut self, #(#key_names: #keys,)* value__: #value, durability__: salsa::Durability) { + fn __shim(db: &mut dyn #trait_name, #(#key_names: #keys,)* value__: #value, durability__: salsa::Durability) { + salsa::plumbing::get_query_table_mut::<#qt>(db).set_with_durability((#(#key_names),*), value__, durability__) + } + __shim(self, #(#key_names,)* value__ ,durability__) + } + }); + } + + // A field for the storage struct + storage_fields.extend(quote! { + #fn_name: std::sync::Arc<<#qt as salsa::Query>::Storage>, + }); + } + + // Emit the trait itself. + let mut output = { + let bounds = &input.supertraits; + quote! { + #(#trait_attrs)* + #trait_vis trait #trait_name : + salsa::Database + + salsa::plumbing::HasQueryGroup<#group_struct> + + #bounds + { + #query_fn_declarations + } + } + }; + + // Emit the query group struct and impl of `QueryGroup`. + output.extend(quote! { + /// Representative struct for the query group. + #trait_vis struct #group_struct { } + + impl salsa::plumbing::QueryGroup for #group_struct + { + type DynDb = #dyn_db; + type GroupStorage = #group_storage; + } + }); + + // Emit an impl of the trait + output.extend({ + let bounds = input.supertraits; + quote! { + impl #trait_name for DB + where + DB: #bounds, + DB: salsa::Database, + DB: salsa::plumbing::HasQueryGroup<#group_struct>, + { + #query_fn_definitions + } + } + }); + + let non_transparent_queries = + || queries.iter().filter(|q| !matches!(q.storage, QueryStorage::Transparent)); + + // Emit the query types. + for (query, query_index) in non_transparent_queries().zip(0_u16..) { + let fn_name = &query.fn_name; + let qt = &query.query_type; + + let storage = match &query.storage { + QueryStorage::Memoized => quote!(salsa::plumbing::MemoizedStorage), + QueryStorage::Dependencies => { + quote!(salsa::plumbing::DependencyStorage) + } + QueryStorage::Input if query.keys.is_empty() => { + quote!(salsa::plumbing::UnitInputStorage) + } + QueryStorage::Input => quote!(salsa::plumbing::InputStorage), + QueryStorage::Interned => quote!(salsa::plumbing::InternedStorage), + QueryStorage::InternedLookup { intern_query_type } => { + quote!(salsa::plumbing::LookupInternedStorage) + } + QueryStorage::Transparent => panic!("should have been filtered"), + }; + let keys = query.keys.iter().map(|(_, ty)| ty); + let value = &query.value; + let query_name = &query.query_name; + + // Emit the query struct and implement the Query trait on it. + output.extend(quote! { + #[derive(Default, Debug)] + #trait_vis struct #qt; + }); + + output.extend(quote! { + impl #qt { + /// Get access to extra methods pertaining to this query. + /// You can also use it to invoke this query. + #trait_vis fn in_db(self, db: &#dyn_db) -> salsa::QueryTable<'_, Self> + { + salsa::plumbing::get_query_table::<#qt>(db) + } + } + }); + + output.extend(quote! { + impl #qt { + /// Like `in_db`, but gives access to methods for setting the + /// value of an input. Not applicable to derived queries. + /// + /// # Threads, cancellation, and blocking + /// + /// Mutating the value of a query cannot be done while there are + /// still other queries executing. If you are using your database + /// within a single thread, this is not a problem: you only have + /// `&self` access to the database, but this method requires `&mut + /// self`. + /// + /// However, if you have used `snapshot` to create other threads, + /// then attempts to `set` will **block the current thread** until + /// those snapshots are dropped (usually when those threads + /// complete). This also implies that if you create a snapshot but + /// do not send it to another thread, then invoking `set` will + /// deadlock. + /// + /// Before blocking, the thread that is attempting to `set` will + /// also set a cancellation flag. This will cause any query + /// invocations in other threads to unwind with a `Cancelled` + /// sentinel value and eventually let the `set` succeed once all + /// threads have unwound past the salsa invocation. + /// + /// If your query implementations are performing expensive + /// operations without invoking another query, you can also use + /// the `Runtime::unwind_if_cancelled` method to check for an + /// ongoing cancellation and bring those operations to a close, + /// thus allowing the `set` to succeed. Otherwise, long-running + /// computations may lead to "starvation", meaning that the + /// thread attempting to `set` has to wait a long, long time. =) + #trait_vis fn in_db_mut(self, db: &mut #dyn_db) -> salsa::QueryTableMut<'_, Self> + { + salsa::plumbing::get_query_table_mut::<#qt>(db) + } + } + + impl<'d> salsa::QueryDb<'d> for #qt + { + type DynDb = #dyn_db + 'd; + type Group = #group_struct; + type GroupStorage = #group_storage; + } + + // ANCHOR:Query_impl + impl salsa::Query for #qt + { + type Key = (#(#keys),*); + type Value = #value; + type Storage = #storage; + + const QUERY_INDEX: u16 = #query_index; + + const QUERY_NAME: &'static str = #query_name; + + fn query_storage<'a>( + group_storage: &'a >::GroupStorage, + ) -> &'a std::sync::Arc { + &group_storage.#fn_name + } + + fn query_storage_mut<'a>( + group_storage: &'a >::GroupStorage, + ) -> &'a std::sync::Arc { + &group_storage.#fn_name + } + } + // ANCHOR_END:Query_impl + }); + + // Implement the QueryFunction trait for queries which need it. + if query.storage.needs_query_function() { + let span = query.fn_name.span(); + + let key_names: Vec<_> = query.keys.iter().map(|(pat, _)| pat).collect(); + let key_pattern = if query.keys.len() == 1 { + quote! { #(#key_names),* } + } else { + quote! { (#(#key_names),*) } + }; + let invoke = query.invoke_tt(); + + let recover = if let Some(cycle_recovery_fn) = &query.cycle { + quote! { + const CYCLE_STRATEGY: salsa::plumbing::CycleRecoveryStrategy = + salsa::plumbing::CycleRecoveryStrategy::Fallback; + fn cycle_fallback(db: &>::DynDb, cycle: &salsa::Cycle, #key_pattern: &::Key) + -> ::Value { + #cycle_recovery_fn( + db, + cycle, + #(#key_names),* + ) + } + } + } else { + quote! { + const CYCLE_STRATEGY: salsa::plumbing::CycleRecoveryStrategy = + salsa::plumbing::CycleRecoveryStrategy::Panic; + } + }; + + output.extend(quote_spanned! {span=> + // ANCHOR:QueryFunction_impl + impl salsa::plumbing::QueryFunction for #qt + { + fn execute(db: &>::DynDb, #key_pattern: ::Key) + -> ::Value { + #invoke(db, #(#key_names),*) + } + + #recover + } + // ANCHOR_END:QueryFunction_impl + }); + } + } + + let mut fmt_ops = proc_macro2::TokenStream::new(); + for (Query { fn_name, .. }, query_index) in non_transparent_queries().zip(0_u16..) { + fmt_ops.extend(quote! { + #query_index => { + salsa::plumbing::QueryStorageOps::fmt_index( + &*self.#fn_name, db, input, fmt, + ) + } + }); + } + + let mut maybe_changed_ops = proc_macro2::TokenStream::new(); + for (Query { fn_name, .. }, query_index) in non_transparent_queries().zip(0_u16..) { + maybe_changed_ops.extend(quote! { + #query_index => { + salsa::plumbing::QueryStorageOps::maybe_changed_after( + &*self.#fn_name, db, input, revision + ) + } + }); + } + + let mut cycle_recovery_strategy_ops = proc_macro2::TokenStream::new(); + for (Query { fn_name, .. }, query_index) in non_transparent_queries().zip(0_u16..) { + cycle_recovery_strategy_ops.extend(quote! { + #query_index => { + salsa::plumbing::QueryStorageOps::cycle_recovery_strategy( + &*self.#fn_name + ) + } + }); + } + + let mut for_each_ops = proc_macro2::TokenStream::new(); + for Query { fn_name, .. } in non_transparent_queries() { + for_each_ops.extend(quote! { + op(&*self.#fn_name); + }); + } + + // Emit query group storage struct + output.extend(quote! { + #trait_vis struct #group_storage { + #storage_fields + } + + // ANCHOR:group_storage_new + impl #group_storage { + #trait_vis fn new(group_index: u16) -> Self { + #group_storage { + #( + #queries_with_storage: + std::sync::Arc::new(salsa::plumbing::QueryStorageOps::new(group_index)), + )* + } + } + } + // ANCHOR_END:group_storage_new + + // ANCHOR:group_storage_methods + impl #group_storage { + #trait_vis fn fmt_index( + &self, + db: &(#dyn_db + '_), + input: salsa::DatabaseKeyIndex, + fmt: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + match input.query_index() { + #fmt_ops + i => panic!("salsa: impossible query index {}", i), + } + } + + #trait_vis fn maybe_changed_after( + &self, + db: &(#dyn_db + '_), + input: salsa::DatabaseKeyIndex, + revision: salsa::Revision, + ) -> bool { + match input.query_index() { + #maybe_changed_ops + i => panic!("salsa: impossible query index {}", i), + } + } + + #trait_vis fn cycle_recovery_strategy( + &self, + db: &(#dyn_db + '_), + input: salsa::DatabaseKeyIndex, + ) -> salsa::plumbing::CycleRecoveryStrategy { + match input.query_index() { + #cycle_recovery_strategy_ops + i => panic!("salsa: impossible query index {}", i), + } + } + + #trait_vis fn for_each_query( + &self, + _runtime: &salsa::Runtime, + mut op: &mut dyn FnMut(&dyn salsa::plumbing::QueryStorageMassOps), + ) { + #for_each_ops + } + } + // ANCHOR_END:group_storage_methods + }); + output.into() +} + +struct SalsaAttr { + name: String, + tts: TokenStream, + span: Span, +} + +impl std::fmt::Debug for SalsaAttr { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(fmt, "{:?}", self.name) + } +} + +impl TryFrom for SalsaAttr { + type Error = syn::Attribute; + + fn try_from(attr: syn::Attribute) -> Result { + if is_not_salsa_attr_path(attr.path()) { + return Err(attr); + } + + let span = attr.span(); + let name = attr.path().segments[1].ident.to_string(); + let tts = match attr.meta { + syn::Meta::Path(path) => path.into_token_stream(), + syn::Meta::List(ref list) => { + let tts = list + .into_token_stream() + .into_iter() + .skip(attr.path().to_token_stream().into_iter().count()); + proc_macro2::TokenStream::from_iter(tts) + } + syn::Meta::NameValue(nv) => nv.into_token_stream(), + } + .into(); + + Ok(SalsaAttr { name, tts, span }) + } +} + +fn is_not_salsa_attr_path(path: &syn::Path) -> bool { + path.segments.first().map(|s| s.ident != "salsa").unwrap_or(true) || path.segments.len() != 2 +} + +fn filter_attrs(attrs: Vec) -> (Vec, Vec) { + let mut other = vec![]; + let mut salsa = vec![]; + // Leave non-salsa attributes untouched. These are + // attributes that don't start with `salsa::` or don't have + // exactly two segments in their path. + // Keep the salsa attributes around. + for attr in attrs { + match SalsaAttr::try_from(attr) { + Ok(it) => salsa.push(it), + Err(it) => other.push(it), + } + } + (other, salsa) +} + +#[derive(Debug)] +struct Query { + fn_name: Ident, + receiver: syn::Receiver, + query_name: String, + attrs: Vec, + query_type: Ident, + storage: QueryStorage, + keys: Vec<(Ident, syn::Type)>, + value: syn::Type, + invoke: Option, + cycle: Option, +} + +impl Query { + fn invoke_tt(&self) -> proc_macro2::TokenStream { + match &self.invoke { + Some(i) => i.into_token_stream(), + None => self.fn_name.clone().into_token_stream(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum QueryStorage { + Memoized, + Dependencies, + Input, + Interned, + InternedLookup { intern_query_type: Ident }, + Transparent, +} + +impl QueryStorage { + /// Do we need a `QueryFunction` impl for this type of query? + fn needs_query_function(&self) -> bool { + match self { + QueryStorage::Input + | QueryStorage::Interned + | QueryStorage::InternedLookup { .. } + | QueryStorage::Transparent => false, + QueryStorage::Memoized | QueryStorage::Dependencies => true, + } + } +} diff --git a/crates/salsa/src/debug.rs b/crates/salsa/src/debug.rs new file mode 100644 index 0000000000..0925ddb3d8 --- /dev/null +++ b/crates/salsa/src/debug.rs @@ -0,0 +1,66 @@ +//! Debugging APIs: these are meant for use when unit-testing or +//! debugging your application but aren't ordinarily needed. + +use crate::durability::Durability; +use crate::plumbing::QueryStorageOps; +use crate::Query; +use crate::QueryTable; +use std::iter::FromIterator; + +/// Additional methods on queries that can be used to "peek into" +/// their current state. These methods are meant for debugging and +/// observing the effects of garbage collection etc. +pub trait DebugQueryTable { + /// Key of this query. + type Key; + + /// Value of this query. + type Value; + + /// Returns a lower bound on the durability for the given key. + /// This is typically the minimum durability of all values that + /// the query accessed, but we may return a lower durability in + /// some cases. + fn durability(&self, key: Self::Key) -> Durability; + + /// Get the (current) set of the entries in the query table. + fn entries(&self) -> C + where + C: FromIterator>; +} + +/// An entry from a query table, for debugging and inspecting the table state. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[non_exhaustive] +pub struct TableEntry { + /// key of the query + pub key: K, + /// value of the query, if it is stored + pub value: Option, +} + +impl TableEntry { + pub(crate) fn new(key: K, value: Option) -> TableEntry { + TableEntry { key, value } + } +} + +impl DebugQueryTable for QueryTable<'_, Q> +where + Q: Query, + Q::Storage: QueryStorageOps, +{ + type Key = Q::Key; + type Value = Q::Value; + + fn durability(&self, key: Q::Key) -> Durability { + self.storage.durability(self.db, &key) + } + + fn entries(&self) -> C + where + C: FromIterator>, + { + self.storage.entries(self.db) + } +} diff --git a/crates/salsa/src/derived.rs b/crates/salsa/src/derived.rs new file mode 100644 index 0000000000..c381e66e08 --- /dev/null +++ b/crates/salsa/src/derived.rs @@ -0,0 +1,233 @@ +//! +use crate::debug::TableEntry; +use crate::durability::Durability; +use crate::hash::FxIndexMap; +use crate::lru::Lru; +use crate::plumbing::DerivedQueryStorageOps; +use crate::plumbing::LruQueryStorageOps; +use crate::plumbing::QueryFunction; +use crate::plumbing::QueryStorageMassOps; +use crate::plumbing::QueryStorageOps; +use crate::runtime::StampedValue; +use crate::Runtime; +use crate::{Database, DatabaseKeyIndex, QueryDb, Revision}; +use parking_lot::RwLock; +use std::borrow::Borrow; +use std::convert::TryFrom; +use std::hash::Hash; +use std::marker::PhantomData; +use triomphe::Arc; + +mod slot; +use slot::Slot; + +/// Memoized queries store the result plus a list of the other queries +/// that they invoked. This means we can avoid recomputing them when +/// none of those inputs have changed. +pub type MemoizedStorage = DerivedStorage; + +/// "Dependency" queries just track their dependencies and not the +/// actual value (which they produce on demand). This lessens the +/// storage requirements. +pub type DependencyStorage = DerivedStorage; + +/// Handles storage where the value is 'derived' by executing a +/// function (in contrast to "inputs"). +pub struct DerivedStorage +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + group_index: u16, + lru_list: Lru>, + slot_map: RwLock>>>, + policy: PhantomData, +} + +impl std::panic::RefUnwindSafe for DerivedStorage +where + Q: QueryFunction, + MP: MemoizationPolicy, + Q::Key: std::panic::RefUnwindSafe, + Q::Value: std::panic::RefUnwindSafe, +{ +} + +pub trait MemoizationPolicy: Send + Sync +where + Q: QueryFunction, +{ + fn should_memoize_value(key: &Q::Key) -> bool; + + fn memoized_value_eq(old_value: &Q::Value, new_value: &Q::Value) -> bool; +} + +pub enum AlwaysMemoizeValue {} +impl MemoizationPolicy for AlwaysMemoizeValue +where + Q: QueryFunction, + Q::Value: Eq, +{ + fn should_memoize_value(_key: &Q::Key) -> bool { + true + } + + fn memoized_value_eq(old_value: &Q::Value, new_value: &Q::Value) -> bool { + old_value == new_value + } +} + +pub enum NeverMemoizeValue {} +impl MemoizationPolicy for NeverMemoizeValue +where + Q: QueryFunction, +{ + fn should_memoize_value(_key: &Q::Key) -> bool { + false + } + + fn memoized_value_eq(_old_value: &Q::Value, _new_value: &Q::Value) -> bool { + panic!("cannot reach since we never memoize") + } +} + +impl DerivedStorage +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn slot(&self, key: &Q::Key) -> Arc> { + if let Some(v) = self.slot_map.read().get(key) { + return v.clone(); + } + + let mut write = self.slot_map.write(); + let entry = write.entry(key.clone()); + let key_index = u32::try_from(entry.index()).unwrap(); + let database_key_index = DatabaseKeyIndex { + group_index: self.group_index, + query_index: Q::QUERY_INDEX, + key_index, + }; + entry.or_insert_with(|| Arc::new(Slot::new(key.clone(), database_key_index))).clone() + } +} + +impl QueryStorageOps for DerivedStorage +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = Q::CYCLE_STRATEGY; + + fn new(group_index: u16) -> Self { + DerivedStorage { + group_index, + slot_map: RwLock::new(FxIndexMap::default()), + lru_list: Default::default(), + policy: PhantomData, + } + } + + fn fmt_index( + &self, + _db: &>::DynDb, + index: DatabaseKeyIndex, + fmt: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + assert_eq!(index.group_index, self.group_index); + assert_eq!(index.query_index, Q::QUERY_INDEX); + let slot_map = self.slot_map.read(); + let key = slot_map.get_index(index.key_index as usize).unwrap().0; + write!(fmt, "{}({:?})", Q::QUERY_NAME, key) + } + + fn maybe_changed_after( + &self, + db: &>::DynDb, + input: DatabaseKeyIndex, + revision: Revision, + ) -> bool { + assert_eq!(input.group_index, self.group_index); + assert_eq!(input.query_index, Q::QUERY_INDEX); + debug_assert!(revision < db.salsa_runtime().current_revision()); + let slot = self.slot_map.read().get_index(input.key_index as usize).unwrap().1.clone(); + slot.maybe_changed_after(db, revision) + } + + fn fetch(&self, db: &>::DynDb, key: &Q::Key) -> Q::Value { + db.unwind_if_cancelled(); + + let slot = self.slot(key); + let StampedValue { value, durability, changed_at } = slot.read(db); + + if let Some(evicted) = self.lru_list.record_use(&slot) { + evicted.evict(); + } + + db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted( + slot.database_key_index(), + durability, + changed_at, + ); + + value + } + + fn durability(&self, db: &>::DynDb, key: &Q::Key) -> Durability { + self.slot(key).durability(db) + } + + fn entries(&self, _db: &>::DynDb) -> C + where + C: std::iter::FromIterator>, + { + let slot_map = self.slot_map.read(); + slot_map.values().filter_map(|slot| slot.as_table_entry()).collect() + } +} + +impl QueryStorageMassOps for DerivedStorage +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn purge(&self) { + self.lru_list.purge(); + *self.slot_map.write() = Default::default(); + } +} + +impl LruQueryStorageOps for DerivedStorage +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn set_lru_capacity(&self, new_capacity: usize) { + self.lru_list.set_lru_capacity(new_capacity); + } +} + +impl DerivedQueryStorageOps for DerivedStorage +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn invalidate(&self, runtime: &mut Runtime, key: &S) + where + S: Eq + Hash, + Q::Key: Borrow, + { + runtime.with_incremented_revision(|new_revision| { + let map_read = self.slot_map.read(); + + if let Some(slot) = map_read.get(key) { + if let Some(durability) = slot.invalidate(new_revision) { + return Some(durability); + } + } + + None + }) + } +} diff --git a/crates/salsa/src/derived/slot.rs b/crates/salsa/src/derived/slot.rs new file mode 100644 index 0000000000..4fad791a26 --- /dev/null +++ b/crates/salsa/src/derived/slot.rs @@ -0,0 +1,833 @@ +//! +use crate::debug::TableEntry; +use crate::derived::MemoizationPolicy; +use crate::durability::Durability; +use crate::lru::LruIndex; +use crate::lru::LruNode; +use crate::plumbing::{DatabaseOps, QueryFunction}; +use crate::revision::Revision; +use crate::runtime::local_state::ActiveQueryGuard; +use crate::runtime::local_state::QueryInputs; +use crate::runtime::local_state::QueryRevisions; +use crate::runtime::Runtime; +use crate::runtime::RuntimeId; +use crate::runtime::StampedValue; +use crate::runtime::WaitResult; +use crate::Cycle; +use crate::{Database, DatabaseKeyIndex, Event, EventKind, QueryDb}; +use parking_lot::{RawRwLock, RwLock}; +use std::marker::PhantomData; +use std::ops::Deref; +use std::sync::atomic::{AtomicBool, Ordering}; +use tracing::{debug, info}; + +pub(super) struct Slot +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + key: Q::Key, + database_key_index: DatabaseKeyIndex, + state: RwLock>, + policy: PhantomData, + lru_index: LruIndex, +} + +/// Defines the "current state" of query's memoized results. +enum QueryState +where + Q: QueryFunction, +{ + NotComputed, + + /// The runtime with the given id is currently computing the + /// result of this query. + InProgress { + id: RuntimeId, + + /// Set to true if any other queries are blocked, + /// waiting for this query to complete. + anyone_waiting: AtomicBool, + }, + + /// We have computed the query already, and here is the result. + Memoized(Memo), +} + +struct Memo { + /// The result of the query, if we decide to memoize it. + value: Option, + + /// Last revision when this memo was verified; this begins + /// as the current revision. + pub(crate) verified_at: Revision, + + /// Revision information + revisions: QueryRevisions, +} + +/// Return value of `probe` helper. +enum ProbeState { + /// Another thread was active but has completed. + /// Try again! + Retry, + + /// No entry for this key at all. + NotComputed(G), + + /// There is an entry, but its contents have not been + /// verified in this revision. + Stale(G), + + /// There is an entry, and it has been verified + /// in this revision, but it has no cached + /// value. The `Revision` is the revision where the + /// value last changed (if we were to recompute it). + NoValue(G, Revision), + + /// There is an entry which has been verified, + /// and it has the following value-- or, we blocked + /// on another thread, and that resulted in a cycle. + UpToDate(V), +} + +/// Return value of `maybe_changed_after_probe` helper. +enum MaybeChangedSinceProbeState { + /// Another thread was active but has completed. + /// Try again! + Retry, + + /// Value may have changed in the given revision. + ChangedAt(Revision), + + /// There is a stale cache entry that has not been + /// verified in this revision, so we can't say. + Stale(G), +} + +impl Slot +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + pub(super) fn new(key: Q::Key, database_key_index: DatabaseKeyIndex) -> Self { + Self { + key, + database_key_index, + state: RwLock::new(QueryState::NotComputed), + lru_index: LruIndex::default(), + policy: PhantomData, + } + } + + pub(super) fn database_key_index(&self) -> DatabaseKeyIndex { + self.database_key_index + } + + pub(super) fn read(&self, db: &>::DynDb) -> StampedValue { + let runtime = db.salsa_runtime(); + + // NB: We don't need to worry about people modifying the + // revision out from under our feet. Either `db` is a frozen + // database, in which case there is a lock, or the mutator + // thread is the current thread, and it will be prevented from + // doing any `set` invocations while the query function runs. + let revision_now = runtime.current_revision(); + + info!("{:?}: invoked at {:?}", self, revision_now,); + + // First, do a check with a read-lock. + loop { + match self.probe(db, self.state.read(), runtime, revision_now) { + ProbeState::UpToDate(v) => return v, + ProbeState::Stale(..) | ProbeState::NoValue(..) | ProbeState::NotComputed(..) => { + break + } + ProbeState::Retry => continue, + } + } + + self.read_upgrade(db, revision_now) + } + + /// Second phase of a read operation: acquires an upgradable-read + /// and -- if needed -- validates whether inputs have changed, + /// recomputes value, etc. This is invoked after our initial probe + /// shows a potentially out of date value. + fn read_upgrade( + &self, + db: &>::DynDb, + revision_now: Revision, + ) -> StampedValue { + let runtime = db.salsa_runtime(); + + debug!("{:?}: read_upgrade(revision_now={:?})", self, revision_now,); + + // Check with an upgradable read to see if there is a value + // already. (This permits other readers but prevents anyone + // else from running `read_upgrade` at the same time.) + let mut old_memo = loop { + match self.probe(db, self.state.upgradable_read(), runtime, revision_now) { + ProbeState::UpToDate(v) => return v, + ProbeState::Stale(state) + | ProbeState::NotComputed(state) + | ProbeState::NoValue(state, _) => { + type RwLockUpgradableReadGuard<'a, T> = + lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>; + + let mut state = RwLockUpgradableReadGuard::upgrade(state); + match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) { + QueryState::Memoized(old_memo) => break Some(old_memo), + QueryState::InProgress { .. } => unreachable!(), + QueryState::NotComputed => break None, + } + } + ProbeState::Retry => continue, + } + }; + + let panic_guard = PanicGuard::new(self.database_key_index, self, runtime); + let active_query = runtime.push_query(self.database_key_index); + + // If we have an old-value, it *may* now be stale, since there + // has been a new revision since the last time we checked. So, + // first things first, let's walk over each of our previous + // inputs and check whether they are out of date. + if let Some(memo) = &mut old_memo { + if let Some(value) = memo.verify_value(db.ops_database(), revision_now, &active_query) { + info!("{:?}: validated old memoized value", self,); + + db.salsa_event(Event { + runtime_id: runtime.id(), + kind: EventKind::DidValidateMemoizedValue { + database_key: self.database_key_index, + }, + }); + + panic_guard.proceed(old_memo); + + return value; + } + } + + self.execute(db, runtime, revision_now, active_query, panic_guard, old_memo) + } + + fn execute( + &self, + db: &>::DynDb, + runtime: &Runtime, + revision_now: Revision, + active_query: ActiveQueryGuard<'_>, + panic_guard: PanicGuard<'_, Q, MP>, + old_memo: Option>, + ) -> StampedValue { + tracing::info!("{:?}: executing query", self.database_key_index.debug(db)); + + db.salsa_event(Event { + runtime_id: db.salsa_runtime().id(), + kind: EventKind::WillExecute { database_key: self.database_key_index }, + }); + + // Query was not previously executed, or value is potentially + // stale, or value is absent. Let's execute! + let value = match Cycle::catch(|| Q::execute(db, self.key.clone())) { + Ok(v) => v, + Err(cycle) => { + tracing::debug!( + "{:?}: caught cycle {:?}, have strategy {:?}", + self.database_key_index.debug(db), + cycle, + Q::CYCLE_STRATEGY, + ); + match Q::CYCLE_STRATEGY { + crate::plumbing::CycleRecoveryStrategy::Panic => { + panic_guard.proceed(None); + cycle.throw() + } + crate::plumbing::CycleRecoveryStrategy::Fallback => { + if let Some(c) = active_query.take_cycle() { + assert!(c.is(&cycle)); + Q::cycle_fallback(db, &cycle, &self.key) + } else { + // we are not a participant in this cycle + debug_assert!(!cycle + .participant_keys() + .any(|k| k == self.database_key_index)); + cycle.throw() + } + } + } + } + }; + + let mut revisions = active_query.pop(); + + // We assume that query is side-effect free -- that is, does + // not mutate the "inputs" to the query system. Sanity check + // that assumption here, at least to the best of our ability. + assert_eq!( + runtime.current_revision(), + revision_now, + "revision altered during query execution", + ); + + // If the new value is equal to the old one, then it didn't + // really change, even if some of its inputs have. So we can + // "backdate" its `changed_at` revision to be the same as the + // old value. + if let Some(old_memo) = &old_memo { + if let Some(old_value) = &old_memo.value { + // Careful: if the value became less durable than it + // used to be, that is a "breaking change" that our + // consumers must be aware of. Becoming *more* durable + // is not. See the test `constant_to_non_constant`. + if revisions.durability >= old_memo.revisions.durability + && MP::memoized_value_eq(old_value, &value) + { + debug!( + "read_upgrade({:?}): value is equal, back-dating to {:?}", + self, old_memo.revisions.changed_at, + ); + + assert!(old_memo.revisions.changed_at <= revisions.changed_at); + revisions.changed_at = old_memo.revisions.changed_at; + } + } + } + + let new_value = StampedValue { + value, + durability: revisions.durability, + changed_at: revisions.changed_at, + }; + + let memo_value = + if self.should_memoize_value(&self.key) { Some(new_value.value.clone()) } else { None }; + + debug!("read_upgrade({:?}): result.revisions = {:#?}", self, revisions,); + + panic_guard.proceed(Some(Memo { value: memo_value, verified_at: revision_now, revisions })); + + new_value + } + + /// Helper for `read` that does a shallow check (not recursive) if we have an up-to-date value. + /// + /// Invoked with the guard `state` corresponding to the `QueryState` of some `Slot` (the guard + /// can be either read or write). Returns a suitable `ProbeState`: + /// + /// - `ProbeState::UpToDate(r)` if the table has an up-to-date value (or we blocked on another + /// thread that produced such a value). + /// - `ProbeState::StaleOrAbsent(g)` if either (a) there is no memo for this key, (b) the memo + /// has no value; or (c) the memo has not been verified at the current revision. + /// + /// Note that in case `ProbeState::UpToDate`, the lock will have been released. + fn probe( + &self, + db: &>::DynDb, + state: StateGuard, + runtime: &Runtime, + revision_now: Revision, + ) -> ProbeState, StateGuard> + where + StateGuard: Deref>, + { + match &*state { + QueryState::NotComputed => ProbeState::NotComputed(state), + + QueryState::InProgress { id, anyone_waiting } => { + let other_id = *id; + + // NB: `Ordering::Relaxed` is sufficient here, + // as there are no loads that are "gated" on this + // value. Everything that is written is also protected + // by a lock that must be acquired. The role of this + // boolean is to decide *whether* to acquire the lock, + // not to gate future atomic reads. + anyone_waiting.store(true, Ordering::Relaxed); + + self.block_on_or_unwind(db, runtime, other_id, state); + + // Other thread completely normally, so our value may be available now. + ProbeState::Retry + } + + QueryState::Memoized(memo) => { + debug!( + "{:?}: found memoized value, verified_at={:?}, changed_at={:?}", + self, memo.verified_at, memo.revisions.changed_at, + ); + + if memo.verified_at < revision_now { + return ProbeState::Stale(state); + } + + if let Some(value) = &memo.value { + let value = StampedValue { + durability: memo.revisions.durability, + changed_at: memo.revisions.changed_at, + value: value.clone(), + }; + + info!("{:?}: returning memoized value changed at {:?}", self, value.changed_at); + + ProbeState::UpToDate(value) + } else { + let changed_at = memo.revisions.changed_at; + ProbeState::NoValue(state, changed_at) + } + } + } + } + + pub(super) fn durability(&self, db: &>::DynDb) -> Durability { + match &*self.state.read() { + QueryState::NotComputed => Durability::LOW, + QueryState::InProgress { .. } => panic!("query in progress"), + QueryState::Memoized(memo) => { + if memo.check_durability(db.salsa_runtime()) { + memo.revisions.durability + } else { + Durability::LOW + } + } + } + } + + pub(super) fn as_table_entry(&self) -> Option> { + match &*self.state.read() { + QueryState::NotComputed => None, + QueryState::InProgress { .. } => Some(TableEntry::new(self.key.clone(), None)), + QueryState::Memoized(memo) => { + Some(TableEntry::new(self.key.clone(), memo.value.clone())) + } + } + } + + pub(super) fn evict(&self) { + let mut state = self.state.write(); + if let QueryState::Memoized(memo) = &mut *state { + // Evicting a value with an untracked input could + // lead to inconsistencies. Note that we can't check + // `has_untracked_input` when we add the value to the cache, + // because inputs can become untracked in the next revision. + if memo.has_untracked_input() { + return; + } + memo.value = None; + } + } + + pub(super) fn invalidate(&self, new_revision: Revision) -> Option { + tracing::debug!("Slot::invalidate(new_revision = {:?})", new_revision); + match &mut *self.state.write() { + QueryState::Memoized(memo) => { + memo.revisions.inputs = QueryInputs::Untracked; + memo.revisions.changed_at = new_revision; + Some(memo.revisions.durability) + } + QueryState::NotComputed => None, + QueryState::InProgress { .. } => unreachable!(), + } + } + + pub(super) fn maybe_changed_after( + &self, + db: &>::DynDb, + revision: Revision, + ) -> bool { + let runtime = db.salsa_runtime(); + let revision_now = runtime.current_revision(); + + db.unwind_if_cancelled(); + + debug!( + "maybe_changed_after({:?}) called with revision={:?}, revision_now={:?}", + self, revision, revision_now, + ); + + // Do an initial probe with just the read-lock. + // + // If we find that a cache entry for the value is present + // but hasn't been verified in this revision, we'll have to + // do more. + loop { + match self.maybe_changed_after_probe(db, self.state.read(), runtime, revision_now) { + MaybeChangedSinceProbeState::Retry => continue, + MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision, + MaybeChangedSinceProbeState::Stale(state) => { + drop(state); + return self.maybe_changed_after_upgrade(db, revision); + } + } + } + } + + fn maybe_changed_after_probe( + &self, + db: &>::DynDb, + state: StateGuard, + runtime: &Runtime, + revision_now: Revision, + ) -> MaybeChangedSinceProbeState + where + StateGuard: Deref>, + { + match self.probe(db, state, runtime, revision_now) { + ProbeState::Retry => MaybeChangedSinceProbeState::Retry, + + ProbeState::Stale(state) => MaybeChangedSinceProbeState::Stale(state), + + // If we know when value last changed, we can return right away. + // Note that we don't need the actual value to be available. + ProbeState::NoValue(_, changed_at) + | ProbeState::UpToDate(StampedValue { value: _, durability: _, changed_at }) => { + MaybeChangedSinceProbeState::ChangedAt(changed_at) + } + + // If we have nothing cached, then value may have changed. + ProbeState::NotComputed(_) => MaybeChangedSinceProbeState::ChangedAt(revision_now), + } + } + + fn maybe_changed_after_upgrade( + &self, + db: &>::DynDb, + revision: Revision, + ) -> bool { + let runtime = db.salsa_runtime(); + let revision_now = runtime.current_revision(); + + // Get an upgradable read lock, which permits other reads but no writers. + // Probe again. If the value is stale (needs to be verified), then upgrade + // to a write lock and swap it with InProgress while we work. + let mut old_memo = match self.maybe_changed_after_probe( + db, + self.state.upgradable_read(), + runtime, + revision_now, + ) { + MaybeChangedSinceProbeState::ChangedAt(changed_at) => return changed_at > revision, + + // If another thread was active, then the cache line is going to be + // either verified or cleared out. Just recurse to figure out which. + // Note that we don't need an upgradable read. + MaybeChangedSinceProbeState::Retry => return self.maybe_changed_after(db, revision), + + MaybeChangedSinceProbeState::Stale(state) => { + type RwLockUpgradableReadGuard<'a, T> = + lock_api::RwLockUpgradableReadGuard<'a, RawRwLock, T>; + + let mut state = RwLockUpgradableReadGuard::upgrade(state); + match std::mem::replace(&mut *state, QueryState::in_progress(runtime.id())) { + QueryState::Memoized(old_memo) => old_memo, + QueryState::NotComputed | QueryState::InProgress { .. } => unreachable!(), + } + } + }; + + let panic_guard = PanicGuard::new(self.database_key_index, self, runtime); + let active_query = runtime.push_query(self.database_key_index); + + if old_memo.verify_revisions(db.ops_database(), revision_now, &active_query) { + let maybe_changed = old_memo.revisions.changed_at > revision; + panic_guard.proceed(Some(old_memo)); + maybe_changed + } else if old_memo.value.is_some() { + // We found that this memoized value may have changed + // but we have an old value. We can re-run the code and + // actually *check* if it has changed. + let StampedValue { changed_at, .. } = + self.execute(db, runtime, revision_now, active_query, panic_guard, Some(old_memo)); + changed_at > revision + } else { + // We found that inputs to this memoized value may have chanced + // but we don't have an old value to compare against or re-use. + // No choice but to drop the memo and say that its value may have changed. + panic_guard.proceed(None); + true + } + } + + /// Helper: see [`Runtime::try_block_on_or_unwind`]. + fn block_on_or_unwind( + &self, + db: &>::DynDb, + runtime: &Runtime, + other_id: RuntimeId, + mutex_guard: MutexGuard, + ) { + runtime.block_on_or_unwind( + db.ops_database(), + self.database_key_index, + other_id, + mutex_guard, + ) + } + + fn should_memoize_value(&self, key: &Q::Key) -> bool { + MP::should_memoize_value(key) + } +} + +impl QueryState +where + Q: QueryFunction, +{ + fn in_progress(id: RuntimeId) -> Self { + QueryState::InProgress { id, anyone_waiting: Default::default() } + } +} + +struct PanicGuard<'me, Q, MP> +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + database_key_index: DatabaseKeyIndex, + slot: &'me Slot, + runtime: &'me Runtime, +} + +impl<'me, Q, MP> PanicGuard<'me, Q, MP> +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn new( + database_key_index: DatabaseKeyIndex, + slot: &'me Slot, + runtime: &'me Runtime, + ) -> Self { + Self { database_key_index, slot, runtime } + } + + /// Indicates that we have concluded normally (without panicking). + /// If `opt_memo` is some, then this memo is installed as the new + /// memoized value. If `opt_memo` is `None`, then the slot is cleared + /// and has no value. + fn proceed(mut self, opt_memo: Option>) { + self.overwrite_placeholder(WaitResult::Completed, opt_memo); + std::mem::forget(self) + } + + /// Overwrites the `InProgress` placeholder for `key` that we + /// inserted; if others were blocked, waiting for us to finish, + /// then notify them. + fn overwrite_placeholder(&mut self, wait_result: WaitResult, opt_memo: Option>) { + let mut write = self.slot.state.write(); + + let old_value = match opt_memo { + // Replace the `InProgress` marker that we installed with the new + // memo, thus releasing our unique access to this key. + Some(memo) => std::mem::replace(&mut *write, QueryState::Memoized(memo)), + + // We had installed an `InProgress` marker, but we panicked before + // it could be removed. At this point, we therefore "own" unique + // access to our slot, so we can just remove the key. + None => std::mem::replace(&mut *write, QueryState::NotComputed), + }; + + match old_value { + QueryState::InProgress { id, anyone_waiting } => { + assert_eq!(id, self.runtime.id()); + + // NB: As noted on the `store`, `Ordering::Relaxed` is + // sufficient here. This boolean signals us on whether to + // acquire a mutex; the mutex will guarantee that all writes + // we are interested in are visible. + if anyone_waiting.load(Ordering::Relaxed) { + self.runtime.unblock_queries_blocked_on(self.database_key_index, wait_result); + } + } + _ => panic!( + "\ +Unexpected panic during query evaluation, aborting the process. + +Please report this bug to https://github.com/salsa-rs/salsa/issues." + ), + } + } +} + +impl<'me, Q, MP> Drop for PanicGuard<'me, Q, MP> +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn drop(&mut self) { + if std::thread::panicking() { + // We panicked before we could proceed and need to remove `key`. + self.overwrite_placeholder(WaitResult::Panicked, None) + } else { + // If no panic occurred, then panic guard ought to be + // "forgotten" and so this Drop code should never run. + panic!(".forget() was not called") + } + } +} + +impl Memo +where + V: Clone, +{ + /// Determines whether the value stored in this memo (if any) is still + /// valid in the current revision. If so, returns a stamped value. + /// + /// If needed, this will walk each dependency and + /// recursively invoke `maybe_changed_after`, which may in turn + /// re-execute the dependency. This can cause cycles to occur, + /// so the current query must be pushed onto the + /// stack to permit cycle detection and recovery: therefore, + /// takes the `active_query` argument as evidence. + fn verify_value( + &mut self, + db: &dyn Database, + revision_now: Revision, + active_query: &ActiveQueryGuard<'_>, + ) -> Option> { + // If we don't have a memoized value, nothing to validate. + if self.value.is_none() { + return None; + } + if self.verify_revisions(db, revision_now, active_query) { + Some(StampedValue { + durability: self.revisions.durability, + changed_at: self.revisions.changed_at, + value: self.value.as_ref().unwrap().clone(), + }) + } else { + None + } + } + + /// Determines whether the value represented by this memo is still + /// valid in the current revision; note that the value itself is + /// not needed for this check. If needed, this will walk each + /// dependency and recursively invoke `maybe_changed_after`, which + /// may in turn re-execute the dependency. This can cause cycles to occur, + /// so the current query must be pushed onto the + /// stack to permit cycle detection and recovery: therefore, + /// takes the `active_query` argument as evidence. + fn verify_revisions( + &mut self, + db: &dyn Database, + revision_now: Revision, + _active_query: &ActiveQueryGuard<'_>, + ) -> bool { + assert!(self.verified_at != revision_now); + let verified_at = self.verified_at; + + debug!( + "verify_revisions: verified_at={:?}, revision_now={:?}, inputs={:#?}", + verified_at, revision_now, self.revisions.inputs + ); + + if self.check_durability(db.salsa_runtime()) { + return self.mark_value_as_verified(revision_now); + } + + match &self.revisions.inputs { + // We can't validate values that had untracked inputs; just have to + // re-execute. + QueryInputs::Untracked => { + return false; + } + + QueryInputs::NoInputs => {} + + // Check whether any of our inputs changed since the + // **last point where we were verified** (not since we + // last changed). This is important: if we have + // memoized values, then an input may have changed in + // revision R2, but we found that *our* value was the + // same regardless, so our change date is still + // R1. But our *verification* date will be R2, and we + // are only interested in finding out whether the + // input changed *again*. + QueryInputs::Tracked { inputs } => { + let changed_input = + inputs.iter().find(|&&input| db.maybe_changed_after(input, verified_at)); + if let Some(input) = changed_input { + debug!("validate_memoized_value: `{:?}` may have changed", input); + + return false; + } + } + }; + + self.mark_value_as_verified(revision_now) + } + + /// True if this memo is known not to have changed based on its durability. + fn check_durability(&self, runtime: &Runtime) -> bool { + let last_changed = runtime.last_changed_revision(self.revisions.durability); + debug!( + "check_durability(last_changed={:?} <= verified_at={:?}) = {:?}", + last_changed, + self.verified_at, + last_changed <= self.verified_at, + ); + last_changed <= self.verified_at + } + + fn mark_value_as_verified(&mut self, revision_now: Revision) -> bool { + self.verified_at = revision_now; + true + } + + fn has_untracked_input(&self) -> bool { + matches!(self.revisions.inputs, QueryInputs::Untracked) + } +} + +impl std::fmt::Debug for Slot +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(fmt, "{:?}({:?})", Q::default(), self.key) + } +} + +impl LruNode for Slot +where + Q: QueryFunction, + MP: MemoizationPolicy, +{ + fn lru_index(&self) -> &LruIndex { + &self.lru_index + } +} + +/// Check that `Slot: Send + Sync` as long as +/// `DB::DatabaseData: Send + Sync`, which in turn implies that +/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`. +#[allow(dead_code)] +fn check_send_sync() +where + Q: QueryFunction, + MP: MemoizationPolicy, + Q::Key: Send + Sync, + Q::Value: Send + Sync, +{ + fn is_send_sync() {} + is_send_sync::>(); +} + +/// Check that `Slot: 'static` as long as +/// `DB::DatabaseData: 'static`, which in turn implies that +/// `Q::Key: 'static`, `Q::Value: 'static`. +#[allow(dead_code)] +fn check_static() +where + Q: QueryFunction + 'static, + MP: MemoizationPolicy + 'static, + Q::Key: 'static, + Q::Value: 'static, +{ + fn is_static() {} + is_static::>(); +} diff --git a/crates/salsa/src/doctest.rs b/crates/salsa/src/doctest.rs new file mode 100644 index 0000000000..29a8066356 --- /dev/null +++ b/crates/salsa/src/doctest.rs @@ -0,0 +1,115 @@ +//! +#![allow(dead_code)] + +/// Test that a database with a key/value that is not `Send` will, +/// indeed, not be `Send`. +/// +/// ```compile_fail,E0277 +/// use std::rc::Rc; +/// +/// #[salsa::query_group(NoSendSyncStorage)] +/// trait NoSendSyncDatabase: salsa::Database { +/// fn no_send_sync_value(&self, key: bool) -> Rc; +/// fn no_send_sync_key(&self, key: Rc) -> bool; +/// } +/// +/// fn no_send_sync_value(_db: &dyn NoSendSyncDatabase, key: bool) -> Rc { +/// Rc::new(key) +/// } +/// +/// fn no_send_sync_key(_db: &dyn NoSendSyncDatabase, key: Rc) -> bool { +/// *key +/// } +/// +/// #[salsa::database(NoSendSyncStorage)] +/// #[derive(Default)] +/// struct DatabaseImpl { +/// storage: salsa::Storage, +/// } +/// +/// impl salsa::Database for DatabaseImpl { +/// } +/// +/// fn is_send(_: T) { } +/// +/// fn assert_send() { +/// is_send(DatabaseImpl::default()); +/// } +/// ``` +fn test_key_not_send_db_not_send() {} + +/// Test that a database with a key/value that is not `Sync` will not +/// be `Send`. +/// +/// ```compile_fail,E0277 +/// use std::rc::Rc; +/// use std::cell::Cell; +/// +/// #[salsa::query_group(NoSendSyncStorage)] +/// trait NoSendSyncDatabase: salsa::Database { +/// fn no_send_sync_value(&self, key: bool) -> Cell; +/// fn no_send_sync_key(&self, key: Cell) -> bool; +/// } +/// +/// fn no_send_sync_value(_db: &dyn NoSendSyncDatabase, key: bool) -> Cell { +/// Cell::new(key) +/// } +/// +/// fn no_send_sync_key(_db: &dyn NoSendSyncDatabase, key: Cell) -> bool { +/// *key +/// } +/// +/// #[salsa::database(NoSendSyncStorage)] +/// #[derive(Default)] +/// struct DatabaseImpl { +/// runtime: salsa::Storage, +/// } +/// +/// impl salsa::Database for DatabaseImpl { +/// } +/// +/// fn is_send(_: T) { } +/// +/// fn assert_send() { +/// is_send(DatabaseImpl::default()); +/// } +/// ``` +fn test_key_not_sync_db_not_send() {} + +/// Test that a database with a key/value that is not `Sync` will +/// not be `Sync`. +/// +/// ```compile_fail,E0277 +/// use std::cell::Cell; +/// use std::rc::Rc; +/// +/// #[salsa::query_group(NoSendSyncStorage)] +/// trait NoSendSyncDatabase: salsa::Database { +/// fn no_send_sync_value(&self, key: bool) -> Cell; +/// fn no_send_sync_key(&self, key: Cell) -> bool; +/// } +/// +/// fn no_send_sync_value(_db: &dyn NoSendSyncDatabase, key: bool) -> Cell { +/// Cell::new(key) +/// } +/// +/// fn no_send_sync_key(_db: &dyn NoSendSyncDatabase, key: Cell) -> bool { +/// *key +/// } +/// +/// #[salsa::database(NoSendSyncStorage)] +/// #[derive(Default)] +/// struct DatabaseImpl { +/// runtime: salsa::Storage, +/// } +/// +/// impl salsa::Database for DatabaseImpl { +/// } +/// +/// fn is_sync(_: T) { } +/// +/// fn assert_send() { +/// is_sync(DatabaseImpl::default()); +/// } +/// ``` +fn test_key_not_sync_db_not_sync() {} diff --git a/crates/salsa/src/durability.rs b/crates/salsa/src/durability.rs new file mode 100644 index 0000000000..0c82f6345a --- /dev/null +++ b/crates/salsa/src/durability.rs @@ -0,0 +1,50 @@ +//! +/// Describes how likely a value is to change -- how "durable" it is. +/// By default, inputs have `Durability::LOW` and interned values have +/// `Durability::HIGH`. But inputs can be explicitly set with other +/// durabilities. +/// +/// We use durabilities to optimize the work of "revalidating" a query +/// after some input has changed. Ordinarily, in a new revision, +/// queries have to trace all their inputs back to the base inputs to +/// determine if any of those inputs have changed. But if we know that +/// the only changes were to inputs of low durability (the common +/// case), and we know that the query only used inputs of medium +/// durability or higher, then we can skip that enumeration. +/// +/// Typically, one assigns low durabilites to inputs that the user is +/// frequently editing. Medium or high durabilities are used for +/// configuration, the source from library crates, or other things +/// that are unlikely to be edited. +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct Durability(u8); + +impl Durability { + /// Low durability: things that change frequently. + /// + /// Example: part of the crate being edited + pub const LOW: Durability = Durability(0); + + /// Medium durability: things that change sometimes, but rarely. + /// + /// Example: a Cargo.toml file + pub const MEDIUM: Durability = Durability(1); + + /// High durability: things that are not expected to change under + /// common usage. + /// + /// Example: the standard library or something from crates.io + pub const HIGH: Durability = Durability(2); + + /// The maximum possible durability; equivalent to HIGH but + /// "conceptually" distinct (i.e., if we add more durability + /// levels, this could change). + pub(crate) const MAX: Durability = Self::HIGH; + + /// Number of durability levels. + pub(crate) const LEN: usize = 3; + + pub(crate) fn index(self) -> usize { + self.0 as usize + } +} diff --git a/crates/salsa/src/hash.rs b/crates/salsa/src/hash.rs new file mode 100644 index 0000000000..47a2dd1ce0 --- /dev/null +++ b/crates/salsa/src/hash.rs @@ -0,0 +1,4 @@ +//! +pub(crate) type FxHasher = std::hash::BuildHasherDefault; +pub(crate) type FxIndexSet = indexmap::IndexSet; +pub(crate) type FxIndexMap = indexmap::IndexMap; diff --git a/crates/salsa/src/input.rs b/crates/salsa/src/input.rs new file mode 100644 index 0000000000..4e8fca6149 --- /dev/null +++ b/crates/salsa/src/input.rs @@ -0,0 +1,384 @@ +//! +use crate::debug::TableEntry; +use crate::durability::Durability; +use crate::hash::FxIndexMap; +use crate::plumbing::CycleRecoveryStrategy; +use crate::plumbing::InputQueryStorageOps; +use crate::plumbing::QueryStorageMassOps; +use crate::plumbing::QueryStorageOps; +use crate::revision::Revision; +use crate::runtime::StampedValue; +use crate::Database; +use crate::Query; +use crate::Runtime; +use crate::{DatabaseKeyIndex, QueryDb}; +use indexmap::map::Entry; +use parking_lot::RwLock; +use std::convert::TryFrom; +use std::iter; +use tracing::debug; + +/// Input queries store the result plus a list of the other queries +/// that they invoked. This means we can avoid recomputing them when +/// none of those inputs have changed. +pub struct InputStorage +where + Q: Query, +{ + group_index: u16, + slots: RwLock>>, +} + +struct Slot { + database_key_index: DatabaseKeyIndex, + stamped_value: RwLock>, +} + +impl std::panic::RefUnwindSafe for InputStorage +where + Q: Query, + Q::Key: std::panic::RefUnwindSafe, + Q::Value: std::panic::RefUnwindSafe, +{ +} + +impl QueryStorageOps for InputStorage +where + Q: Query, +{ + const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = CycleRecoveryStrategy::Panic; + + fn new(group_index: u16) -> Self { + InputStorage { group_index, slots: Default::default() } + } + + fn fmt_index( + &self, + _db: &>::DynDb, + index: DatabaseKeyIndex, + fmt: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + assert_eq!(index.group_index, self.group_index); + assert_eq!(index.query_index, Q::QUERY_INDEX); + let slot_map = self.slots.read(); + let key = slot_map.get_index(index.key_index as usize).unwrap().0; + write!(fmt, "{}({:?})", Q::QUERY_NAME, key) + } + + fn maybe_changed_after( + &self, + db: &>::DynDb, + input: DatabaseKeyIndex, + revision: Revision, + ) -> bool { + assert_eq!(input.group_index, self.group_index); + assert_eq!(input.query_index, Q::QUERY_INDEX); + debug_assert!(revision < db.salsa_runtime().current_revision()); + let slots = &self.slots.read(); + let slot = slots.get_index(input.key_index as usize).unwrap().1; + + debug!("maybe_changed_after(slot={:?}, revision={:?})", Q::default(), revision,); + + let changed_at = slot.stamped_value.read().changed_at; + + debug!("maybe_changed_after: changed_at = {:?}", changed_at); + + changed_at > revision + } + + fn fetch(&self, db: &>::DynDb, key: &Q::Key) -> Q::Value { + db.unwind_if_cancelled(); + + let slots = &self.slots.read(); + let slot = slots + .get(key) + .unwrap_or_else(|| panic!("no value set for {:?}({:?})", Q::default(), key)); + + let StampedValue { value, durability, changed_at } = slot.stamped_value.read().clone(); + + db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted( + slot.database_key_index, + durability, + changed_at, + ); + + value + } + + fn durability(&self, _db: &>::DynDb, key: &Q::Key) -> Durability { + match self.slots.read().get(key) { + Some(slot) => slot.stamped_value.read().durability, + None => panic!("no value set for {:?}({:?})", Q::default(), key), + } + } + + fn entries(&self, _db: &>::DynDb) -> C + where + C: std::iter::FromIterator>, + { + let slots = self.slots.read(); + slots + .iter() + .map(|(key, slot)| { + TableEntry::new(key.clone(), Some(slot.stamped_value.read().value.clone())) + }) + .collect() + } +} + +impl QueryStorageMassOps for InputStorage +where + Q: Query, +{ + fn purge(&self) { + *self.slots.write() = Default::default(); + } +} + +impl InputQueryStorageOps for InputStorage +where + Q: Query, +{ + fn set(&self, runtime: &mut Runtime, key: &Q::Key, value: Q::Value, durability: Durability) { + tracing::debug!("{:?}({:?}) = {:?} ({:?})", Q::default(), key, value, durability); + + // The value is changing, so we need a new revision (*). We also + // need to update the 'last changed' revision by invoking + // `guard.mark_durability_as_changed`. + // + // CAREFUL: This will block until the global revision lock can + // be acquired. If there are still queries executing, they may + // need to read from this input. Therefore, we wait to acquire + // the lock on `map` until we also hold the global query write + // lock. + // + // (*) Technically, since you can't presently access an input + // for a non-existent key, and you can't enumerate the set of + // keys, we only need a new revision if the key used to + // exist. But we may add such methods in the future and this + // case doesn't generally seem worth optimizing for. + runtime.with_incremented_revision(|next_revision| { + let mut slots = self.slots.write(); + + // Do this *after* we acquire the lock, so that we are not + // racing with somebody else to modify this same cell. + // (Otherwise, someone else might write a *newer* revision + // into the same cell while we block on the lock.) + let stamped_value = StampedValue { value, durability, changed_at: next_revision }; + + match slots.entry(key.clone()) { + Entry::Occupied(entry) => { + let mut slot_stamped_value = entry.get().stamped_value.write(); + let old_durability = slot_stamped_value.durability; + *slot_stamped_value = stamped_value; + Some(old_durability) + } + + Entry::Vacant(entry) => { + let key_index = u32::try_from(entry.index()).unwrap(); + let database_key_index = DatabaseKeyIndex { + group_index: self.group_index, + query_index: Q::QUERY_INDEX, + key_index, + }; + entry.insert(Slot { + database_key_index, + stamped_value: RwLock::new(stamped_value), + }); + None + } + } + }); + } +} + +/// Same as `InputStorage`, but optimized for queries that take no inputs. +pub struct UnitInputStorage +where + Q: Query, +{ + group_index: u16, + slot: UnitSlot, +} + +struct UnitSlot { + database_key_index: DatabaseKeyIndex, + stamped_value: RwLock>>, +} + +impl std::panic::RefUnwindSafe for UnitInputStorage +where + Q: Query, + Q::Key: std::panic::RefUnwindSafe, + Q::Value: std::panic::RefUnwindSafe, +{ +} + +impl QueryStorageOps for UnitInputStorage +where + Q: Query, +{ + const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = CycleRecoveryStrategy::Panic; + + fn new(group_index: u16) -> Self { + let database_key_index = + DatabaseKeyIndex { group_index, query_index: Q::QUERY_INDEX, key_index: 0 }; + UnitInputStorage { + group_index, + slot: UnitSlot { database_key_index, stamped_value: RwLock::new(None) }, + } + } + + fn fmt_index( + &self, + _db: &>::DynDb, + index: DatabaseKeyIndex, + fmt: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + assert_eq!(index.group_index, self.group_index); + assert_eq!(index.query_index, Q::QUERY_INDEX); + write!(fmt, "{}", Q::QUERY_NAME) + } + + fn maybe_changed_after( + &self, + db: &>::DynDb, + input: DatabaseKeyIndex, + revision: Revision, + ) -> bool { + assert_eq!(input.group_index, self.group_index); + assert_eq!(input.query_index, Q::QUERY_INDEX); + debug_assert!(revision < db.salsa_runtime().current_revision()); + + debug!("maybe_changed_after(slot={:?}, revision={:?})", Q::default(), revision,); + + let changed_at = self.slot.stamped_value.read().as_ref().unwrap().changed_at; + + debug!("maybe_changed_after: changed_at = {:?}", changed_at); + + changed_at > revision + } + + fn fetch(&self, db: &>::DynDb, &(): &Q::Key) -> Q::Value { + db.unwind_if_cancelled(); + + let StampedValue { value, durability, changed_at } = self + .slot + .stamped_value + .read() + .clone() + .unwrap_or_else(|| panic!("no value set for {:?}", Q::default())); + + db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted( + self.slot.database_key_index, + durability, + changed_at, + ); + + value + } + + fn durability(&self, _db: &>::DynDb, &(): &Q::Key) -> Durability { + match &*self.slot.stamped_value.read() { + Some(stamped_value) => stamped_value.durability, + None => panic!("no value set for {:?}", Q::default(),), + } + } + + fn entries(&self, _db: &>::DynDb) -> C + where + C: std::iter::FromIterator>, + { + iter::once(TableEntry::new( + (), + self.slot.stamped_value.read().as_ref().map(|it| it.value.clone()), + )) + .collect() + } +} + +impl QueryStorageMassOps for UnitInputStorage +where + Q: Query, +{ + fn purge(&self) { + *self.slot.stamped_value.write() = Default::default(); + } +} + +impl InputQueryStorageOps for UnitInputStorage +where + Q: Query, +{ + fn set(&self, runtime: &mut Runtime, (): &Q::Key, value: Q::Value, durability: Durability) { + tracing::debug!("{:?} = {:?} ({:?})", Q::default(), value, durability); + + // The value is changing, so we need a new revision (*). We also + // need to update the 'last changed' revision by invoking + // `guard.mark_durability_as_changed`. + // + // CAREFUL: This will block until the global revision lock can + // be acquired. If there are still queries executing, they may + // need to read from this input. Therefore, we wait to acquire + // the lock on `map` until we also hold the global query write + // lock. + // + // (*) Technically, since you can't presently access an input + // for a non-existent key, and you can't enumerate the set of + // keys, we only need a new revision if the key used to + // exist. But we may add such methods in the future and this + // case doesn't generally seem worth optimizing for. + runtime.with_incremented_revision(|next_revision| { + let mut stamped_value_slot = self.slot.stamped_value.write(); + + // Do this *after* we acquire the lock, so that we are not + // racing with somebody else to modify this same cell. + // (Otherwise, someone else might write a *newer* revision + // into the same cell while we block on the lock.) + let stamped_value = StampedValue { value, durability, changed_at: next_revision }; + + match &mut *stamped_value_slot { + Some(slot_stamped_value) => { + let old_durability = slot_stamped_value.durability; + *slot_stamped_value = stamped_value; + Some(old_durability) + } + + stamped_value_slot @ None => { + *stamped_value_slot = Some(stamped_value); + None + } + } + }); + } +} + +/// Check that `Slot: Send + Sync` as long as +/// `DB::DatabaseData: Send + Sync`, which in turn implies that +/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`. +#[allow(dead_code)] +fn check_send_sync() +where + Q: Query, + Q::Key: Send + Sync, + Q::Value: Send + Sync, +{ + fn is_send_sync() {} + is_send_sync::>(); + is_send_sync::>(); +} + +/// Check that `Slot: 'static` as long as +/// `DB::DatabaseData: 'static`, which in turn implies that +/// `Q::Key: 'static`, `Q::Value: 'static`. +#[allow(dead_code)] +fn check_static() +where + Q: Query + 'static, + Q::Key: 'static, + Q::Value: 'static, +{ + fn is_static() {} + is_static::>(); + is_static::>(); +} diff --git a/crates/salsa/src/intern_id.rs b/crates/salsa/src/intern_id.rs new file mode 100644 index 0000000000..a7bbc088f9 --- /dev/null +++ b/crates/salsa/src/intern_id.rs @@ -0,0 +1,131 @@ +//! +use std::fmt; +use std::num::NonZeroU32; + +/// The "raw-id" is used for interned keys in salsa -- it is basically +/// a newtype'd u32. Typically, it is wrapped in a type of your own +/// devising. For more information about interned keys, see [the +/// interned key RFC][rfc]. +/// +/// # Creating a `InternId` +// +/// InternId values can be constructed using the `From` impls, +/// which are implemented for `u32` and `usize`: +/// +/// ``` +/// # use salsa::InternId; +/// let intern_id1 = InternId::from(22_u32); +/// let intern_id2 = InternId::from(22_usize); +/// assert_eq!(intern_id1, intern_id2); +/// ``` +/// +/// # Converting to a u32 or usize +/// +/// Normally, there should be no need to access the underlying integer +/// in a `InternId`. But if you do need to do so, you can convert to a +/// `usize` using the `as_u32` or `as_usize` methods or the `From` impls. +/// +/// ``` +/// # use salsa::InternId; +/// let intern_id = InternId::from(22_u32); +/// let value = u32::from(intern_id); +/// assert_eq!(value, 22); +/// ``` +/// +/// ## Illegal values +/// +/// Be warned, however, that `InternId` values cannot be created from +/// *arbitrary* values -- in particular large values greater than +/// `InternId::MAX` will panic. Those large values are reserved so that +/// the Rust compiler can use them as sentinel values, which means +/// that (for example) `Option` is represented in a single +/// word. +/// +/// ```should_panic +/// # use salsa::InternId; +/// InternId::from(InternId::MAX); +/// ``` +/// +/// [rfc]: https://github.com/salsa-rs/salsa-rfcs/pull/2 +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct InternId { + value: NonZeroU32, +} + +impl InternId { + /// The maximum allowed `InternId`. This value can grow between + /// releases without affecting semver. + pub const MAX: u32 = 0xFFFF_FF00; + + /// Creates a new InternId. + /// + /// # Safety + /// + /// `value` must be less than `MAX` + pub const unsafe fn new_unchecked(value: u32) -> Self { + debug_assert!(value < InternId::MAX); + InternId { value: NonZeroU32::new_unchecked(value + 1) } + } + + /// Convert this raw-id into a u32 value. + /// + /// ``` + /// # use salsa::InternId; + /// let intern_id = InternId::from(22_u32); + /// let value = intern_id.as_usize(); + /// assert_eq!(value, 22); + /// ``` + pub fn as_u32(self) -> u32 { + self.value.get() - 1 + } + + /// Convert this raw-id into a usize value. + /// + /// ``` + /// # use salsa::InternId; + /// let intern_id = InternId::from(22_u32); + /// let value = intern_id.as_usize(); + /// assert_eq!(value, 22); + /// ``` + pub fn as_usize(self) -> usize { + self.as_u32() as usize + } +} + +impl From for u32 { + fn from(raw: InternId) -> u32 { + raw.as_u32() + } +} + +impl From for usize { + fn from(raw: InternId) -> usize { + raw.as_usize() + } +} + +impl From for InternId { + fn from(id: u32) -> InternId { + assert!(id < InternId::MAX); + unsafe { InternId::new_unchecked(id) } + } +} + +impl From for InternId { + fn from(id: usize) -> InternId { + assert!(id < (InternId::MAX as usize)); + unsafe { InternId::new_unchecked(id as u32) } + } +} + +impl fmt::Debug for InternId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.as_usize().fmt(f) + } +} + +impl fmt::Display for InternId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.as_usize().fmt(f) + } +} diff --git a/crates/salsa/src/interned.rs b/crates/salsa/src/interned.rs new file mode 100644 index 0000000000..731839e959 --- /dev/null +++ b/crates/salsa/src/interned.rs @@ -0,0 +1,508 @@ +//! +use crate::debug::TableEntry; +use crate::durability::Durability; +use crate::intern_id::InternId; +use crate::plumbing::CycleRecoveryStrategy; +use crate::plumbing::HasQueryGroup; +use crate::plumbing::QueryStorageMassOps; +use crate::plumbing::QueryStorageOps; +use crate::revision::Revision; +use crate::Query; +use crate::QueryTable; +use crate::{Database, DatabaseKeyIndex, QueryDb}; +use parking_lot::RwLock; +use rustc_hash::FxHashMap; +use std::collections::hash_map::Entry; +use std::convert::From; +use std::fmt::Debug; +use std::hash::Hash; +use triomphe::Arc; + +const INTERN_DURABILITY: Durability = Durability::HIGH; + +/// Handles storage where the value is 'derived' by executing a +/// function (in contrast to "inputs"). +pub struct InternedStorage +where + Q: Query, + Q::Key: InternValue, + Q::Value: InternKey, +{ + group_index: u16, + tables: RwLock, Q::Key>>, +} + +/// Storage for the looking up interned things. +pub struct LookupInternedStorage +where + Q: Query, + Q::Key: InternKey, + Q::Value: InternValue, +{ + phantom: std::marker::PhantomData<(Q::Key, IQ)>, +} + +struct InternTables { + /// Map from the key to the corresponding intern-index. + map: FxHashMap, + + /// For each valid intern-index, stores the interned value. + values: Vec>>, +} + +/// Trait implemented for the "key" that results from a +/// `#[salsa::intern]` query. This is basically meant to be a +/// "newtype"'d `u32`. +pub trait InternKey { + /// Create an instance of the intern-key from a `u32` value. + fn from_intern_id(v: InternId) -> Self; + + /// Extract the `u32` with which the intern-key was created. + fn as_intern_id(&self) -> InternId; +} + +impl InternKey for InternId { + fn from_intern_id(v: InternId) -> InternId { + v + } + + fn as_intern_id(&self) -> InternId { + *self + } +} + +/// Trait implemented for the "value" that is being interned. +pub trait InternValue { + /// They key used to intern this value by. + type Key: Eq + Hash + Debug + Clone; + /// Maps the value to a key that will be used to intern it. + fn into_key(&self) -> Self::Key; + /// Calls the given function with the key that was used to intern this value. + /// + /// This is mainly used to prevent frequent cloning of the key when doing a lookup. + #[inline] + fn with_key T, T>(&self, f: F) -> T { + f(&self.into_key()) + } +} + +impl + InternValue for (A, B) +{ + type Key = Self; + #[inline] + fn into_key(&self) -> Self::Key { + self.clone() + } + #[inline] + fn with_key T, T>(&self, f: F) -> T { + f(self) + } +} + +/// Implement [`InternValue`] trivially, that is without actually mapping at all. +#[macro_export] +macro_rules! impl_intern_value_trivial { + ($($ty:ty),*) => { + $( + impl $crate::InternValue for $ty { + type Key = $ty; + #[inline] + fn into_key(&self) -> Self::Key { + self.clone() + } + #[inline] + fn with_key T, T>(&self, f: F) -> T { + f(self) + } + } + )* + }; +} +impl_intern_value_trivial!(String); +#[derive(Debug)] +struct Slot { + /// DatabaseKeyIndex for this slot. + database_key_index: DatabaseKeyIndex, + + /// Value that was interned. + value: V, + + /// When was this intern'd? + /// + /// (This informs the "changed-at" result) + interned_at: Revision, +} + +impl std::panic::RefUnwindSafe for InternedStorage +where + Q: Query, + Q::Key: InternValue, + Q::Key: std::panic::RefUnwindSafe, + Q::Value: InternKey, + Q::Value: std::panic::RefUnwindSafe, +{ +} + +impl InternTables { + /// Returns the slot for the given key. + fn slot_for_key(&self, key: &K) -> Option<(Arc>, InternId)> { + let &index = self.map.get(key)?; + Some((self.slot_for_index(index), index)) + } + + /// Returns the slot at the given index. + fn slot_for_index(&self, index: InternId) -> Arc> { + let slot = &self.values[index.as_usize()]; + slot.clone() + } +} + +impl Default for InternTables +where + K: Eq + Hash, +{ + fn default() -> Self { + Self { map: Default::default(), values: Default::default() } + } +} + +type MappedKey = <::Key as InternValue>::Key; + +impl InternedStorage +where + Q: Query, + Q::Key: InternValue, + Q::Value: InternKey, +{ + /// Creates a new slot. + fn intern_index( + &self, + db: &>::DynDb, + mapped_key: MappedKey, + insert: impl FnOnce(Q::Value) -> Q::Key, + ) -> (Arc>, InternId) { + let revision_now = db.salsa_runtime().current_revision(); + + let mut tables = self.tables.write(); + let tables = &mut *tables; + let entry = match tables.map.entry(mapped_key) { + Entry::Vacant(entry) => entry, + Entry::Occupied(entry) => { + // Somebody inserted this key while we were waiting + // for the write lock. In this case, we don't need to + // update the `accessed_at` field because they should + // have already done so! + let index = *entry.get(); + let slot = &tables.values[index.as_usize()]; + return (slot.clone(), index); + } + }; + + let create_slot = |index: InternId| { + let database_key_index = DatabaseKeyIndex { + group_index: self.group_index, + query_index: Q::QUERY_INDEX, + key_index: index.as_u32(), + }; + Arc::new(Slot { + database_key_index, + value: insert(Q::Value::from_intern_id(index)), + interned_at: revision_now, + }) + }; + + let index = InternId::from(tables.values.len()); + let slot = create_slot(index); + tables.values.push(slot.clone()); + entry.insert(index); + + (slot, index) + } + + fn intern_check(&self, key: &MappedKey) -> Option<(Arc>, InternId)> { + self.tables.read().slot_for_key(key) + } + + /// Given an index, lookup and clone its value, updating the + /// `accessed_at` time if necessary. + fn lookup_value(&self, index: InternId) -> Arc> { + self.tables.read().slot_for_index(index) + } + + fn fetch_or_insert( + &self, + db: &>::DynDb, + key: MappedKey, + insert: impl FnOnce(Q::Value) -> Q::Key, + ) -> Q::Value { + db.unwind_if_cancelled(); + let (slot, index) = match self.intern_check(&key) { + Some(i) => i, + None => self.intern_index(db, key, insert), + }; + let changed_at = slot.interned_at; + db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted( + slot.database_key_index, + INTERN_DURABILITY, + changed_at, + ); + ::from_intern_id(index) + } +} + +impl QueryStorageOps for InternedStorage +where + Q: Query, + Q::Key: InternValue, + Q::Value: InternKey, +{ + const CYCLE_STRATEGY: crate::plumbing::CycleRecoveryStrategy = CycleRecoveryStrategy::Panic; + + fn new(group_index: u16) -> Self { + InternedStorage { group_index, tables: RwLock::new(InternTables::default()) } + } + + fn fmt_index( + &self, + _db: &>::DynDb, + index: DatabaseKeyIndex, + fmt: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + assert_eq!(index.group_index, self.group_index); + assert_eq!(index.query_index, Q::QUERY_INDEX); + let intern_id = InternId::from(index.key_index); + let slot = self.lookup_value(intern_id); + write!(fmt, "{}({:?})", Q::QUERY_NAME, slot.value) + } + + fn maybe_changed_after( + &self, + db: &>::DynDb, + input: DatabaseKeyIndex, + revision: Revision, + ) -> bool { + assert_eq!(input.group_index, self.group_index); + assert_eq!(input.query_index, Q::QUERY_INDEX); + debug_assert!(revision < db.salsa_runtime().current_revision()); + let intern_id = InternId::from(input.key_index); + let slot = self.lookup_value(intern_id); + slot.maybe_changed_after(revision) + } + + fn fetch(&self, db: &>::DynDb, key: &Q::Key) -> Q::Value { + db.unwind_if_cancelled(); + + let (slot, index) = match key.with_key(|key| self.intern_check(key)) { + Some(i) => i, + None => self.intern_index(db, key.into_key(), |_| key.clone()), + }; + let changed_at = slot.interned_at; + db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted( + slot.database_key_index, + INTERN_DURABILITY, + changed_at, + ); + ::from_intern_id(index) + } + + fn durability(&self, _db: &>::DynDb, _key: &Q::Key) -> Durability { + INTERN_DURABILITY + } + + fn entries(&self, _db: &>::DynDb) -> C + where + C: std::iter::FromIterator>, + { + let tables = self.tables.read(); + tables + .map + .values() + .map(|index| { + TableEntry::new( + tables.values[index.as_usize()].value.clone(), + Some(::from_intern_id(*index)), + ) + }) + .collect() + } +} + +impl QueryStorageMassOps for InternedStorage +where + Q: Query, + Q::Key: InternValue, + Q::Value: InternKey, +{ + fn purge(&self) { + *self.tables.write() = Default::default(); + } +} + +// Workaround for +// ``` +// IQ: for<'d> QueryDb< +// 'd, +// DynDb = >::DynDb, +// Group = >::Group, +// GroupStorage = >::GroupStorage, +// >, +// ``` +// not working to make rustc know DynDb, Group and GroupStorage being the same in `Q` and `IQ` +#[doc(hidden)] +pub trait EqualDynDb<'d, IQ>: QueryDb<'d> +where + IQ: QueryDb<'d>, +{ + fn convert_db(d: &Self::DynDb) -> &IQ::DynDb; + fn convert_group_storage(d: &Self::GroupStorage) -> &IQ::GroupStorage; +} + +impl<'d, IQ, Q> EqualDynDb<'d, IQ> for Q +where + Q: QueryDb<'d, DynDb = IQ::DynDb, Group = IQ::Group, GroupStorage = IQ::GroupStorage>, + Q::DynDb: HasQueryGroup, + IQ: QueryDb<'d>, +{ + fn convert_db(d: &Self::DynDb) -> &IQ::DynDb { + d + } + fn convert_group_storage(d: &Self::GroupStorage) -> &IQ::GroupStorage { + d + } +} + +impl QueryStorageOps for LookupInternedStorage +where + Q: Query, + Q::Key: InternKey, + Q::Value: InternValue, + IQ: Query>, + for<'d> Q: EqualDynDb<'d, IQ>, +{ + const CYCLE_STRATEGY: CycleRecoveryStrategy = CycleRecoveryStrategy::Panic; + + fn new(_group_index: u16) -> Self { + LookupInternedStorage { phantom: std::marker::PhantomData } + } + + fn fmt_index( + &self, + db: &>::DynDb, + index: DatabaseKeyIndex, + fmt: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + let group_storage = + <>::DynDb as HasQueryGroup>::group_storage(db); + let interned_storage = IQ::query_storage(Q::convert_group_storage(group_storage)); + interned_storage.fmt_index(Q::convert_db(db), index, fmt) + } + + fn maybe_changed_after( + &self, + db: &>::DynDb, + input: DatabaseKeyIndex, + revision: Revision, + ) -> bool { + let group_storage = + <>::DynDb as HasQueryGroup>::group_storage(db); + let interned_storage = IQ::query_storage(Q::convert_group_storage(group_storage)); + interned_storage.maybe_changed_after(Q::convert_db(db), input, revision) + } + + fn fetch(&self, db: &>::DynDb, key: &Q::Key) -> Q::Value { + let index = key.as_intern_id(); + let group_storage = + <>::DynDb as HasQueryGroup>::group_storage(db); + let interned_storage = IQ::query_storage(Q::convert_group_storage(group_storage)); + let slot = interned_storage.lookup_value(index); + let value = slot.value.clone(); + let interned_at = slot.interned_at; + db.salsa_runtime().report_query_read_and_unwind_if_cycle_resulted( + slot.database_key_index, + INTERN_DURABILITY, + interned_at, + ); + value + } + + fn durability(&self, _db: &>::DynDb, _key: &Q::Key) -> Durability { + INTERN_DURABILITY + } + + fn entries(&self, db: &>::DynDb) -> C + where + C: std::iter::FromIterator>, + { + let group_storage = + <>::DynDb as HasQueryGroup>::group_storage(db); + let interned_storage = IQ::query_storage(Q::convert_group_storage(group_storage)); + let tables = interned_storage.tables.read(); + tables + .map + .values() + .map(|index| { + TableEntry::new( + ::from_intern_id(*index), + Some(tables.values[index.as_usize()].value.clone()), + ) + }) + .collect() + } +} + +impl QueryStorageMassOps for LookupInternedStorage +where + Q: Query, + Q::Key: InternKey, + Q::Value: InternValue, + IQ: Query, +{ + fn purge(&self) {} +} + +impl Slot { + fn maybe_changed_after(&self, revision: Revision) -> bool { + self.interned_at > revision + } +} + +/// Check that `Slot: Send + Sync` as long as +/// `DB::DatabaseData: Send + Sync`, which in turn implies that +/// `Q::Key: Send + Sync`, `Q::Value: Send + Sync`. +#[allow(dead_code)] +fn check_send_sync() +where + K: Send + Sync, +{ + fn is_send_sync() {} + is_send_sync::>(); +} + +/// Check that `Slot: 'static` as long as +/// `DB::DatabaseData: 'static`, which in turn implies that +/// `Q::Key: 'static`, `Q::Value: 'static`. +#[allow(dead_code)] +fn check_static() +where + K: 'static, +{ + fn is_static() {} + is_static::>(); +} + +impl<'me, Q> QueryTable<'me, Q> +where + Q: Query>, + Q::Key: InternValue, + Q::Value: InternKey, +{ + /// Fetches the intern id for the given key or inserts it if it does not exist. + pub fn get_or_insert( + &self, + key: MappedKey, + insert: impl FnOnce(Q::Value) -> Q::Key, + ) -> Q::Value { + self.storage.fetch_or_insert(self.db, key, insert) + } +} diff --git a/crates/salsa/src/lib.rs b/crates/salsa/src/lib.rs new file mode 100644 index 0000000000..2d58beafb2 --- /dev/null +++ b/crates/salsa/src/lib.rs @@ -0,0 +1,742 @@ +//! +#![allow(clippy::type_complexity)] +#![allow(clippy::question_mark)] +#![warn(rust_2018_idioms)] +#![warn(missing_docs)] + +//! The salsa crate is a crate for incremental recomputation. It +//! permits you to define a "database" of queries with both inputs and +//! values derived from those inputs; as you set the inputs, you can +//! re-execute the derived queries and it will try to re-use results +//! from previous invocations as appropriate. + +mod derived; +mod doctest; +mod durability; +mod hash; +mod input; +mod intern_id; +mod interned; +mod lru; +mod revision; +mod runtime; +mod storage; + +pub mod debug; +/// Items in this module are public for implementation reasons, +/// and are exempt from the SemVer guarantees. +#[doc(hidden)] +pub mod plumbing; + +use crate::plumbing::CycleRecoveryStrategy; +use crate::plumbing::DerivedQueryStorageOps; +use crate::plumbing::InputQueryStorageOps; +use crate::plumbing::LruQueryStorageOps; +use crate::plumbing::QueryStorageMassOps; +use crate::plumbing::QueryStorageOps; +pub use crate::revision::Revision; +use std::fmt::{self, Debug}; +use std::hash::Hash; +use std::panic::AssertUnwindSafe; +use std::panic::{self, UnwindSafe}; + +pub use crate::durability::Durability; +pub use crate::intern_id::InternId; +pub use crate::interned::{InternKey, InternValue}; +pub use crate::runtime::Runtime; +pub use crate::runtime::RuntimeId; +pub use crate::storage::Storage; + +/// The base trait which your "query context" must implement. Gives +/// access to the salsa runtime, which you must embed into your query +/// context (along with whatever other state you may require). +pub trait Database: plumbing::DatabaseOps { + /// This function is invoked at key points in the salsa + /// runtime. It permits the database to be customized and to + /// inject logging or other custom behavior. + fn salsa_event(&self, event_fn: Event) { + #![allow(unused_variables)] + } + + /// Starts unwinding the stack if the current revision is cancelled. + /// + /// This method can be called by query implementations that perform + /// potentially expensive computations, in order to speed up propagation of + /// cancellation. + /// + /// Cancellation will automatically be triggered by salsa on any query + /// invocation. + /// + /// This method should not be overridden by `Database` implementors. A + /// `salsa_event` is emitted when this method is called, so that should be + /// used instead. + #[inline] + fn unwind_if_cancelled(&self) { + let runtime = self.salsa_runtime(); + self.salsa_event(Event { + runtime_id: runtime.id(), + kind: EventKind::WillCheckCancellation, + }); + + let current_revision = runtime.current_revision(); + let pending_revision = runtime.pending_revision(); + tracing::debug!( + "unwind_if_cancelled: current_revision={:?}, pending_revision={:?}", + current_revision, + pending_revision + ); + if pending_revision > current_revision { + runtime.unwind_cancelled(); + } + } + + /// Gives access to the underlying salsa runtime. + /// + /// This method should not be overridden by `Database` implementors. + fn salsa_runtime(&self) -> &Runtime { + self.ops_salsa_runtime() + } + + /// Gives access to the underlying salsa runtime. + /// + /// This method should not be overridden by `Database` implementors. + fn salsa_runtime_mut(&mut self) -> &mut Runtime { + self.ops_salsa_runtime_mut() + } +} + +/// The `Event` struct identifies various notable things that can +/// occur during salsa execution. Instances of this struct are given +/// to `salsa_event`. +pub struct Event { + /// The id of the snapshot that triggered the event. Usually + /// 1-to-1 with a thread, as well. + pub runtime_id: RuntimeId, + + /// What sort of event was it. + pub kind: EventKind, +} + +impl Event { + /// Returns a type that gives a user-readable debug output. + /// Use like `println!("{:?}", index.debug(db))`. + pub fn debug<'me, D: ?Sized>(&'me self, db: &'me D) -> impl std::fmt::Debug + 'me + where + D: plumbing::DatabaseOps, + { + EventDebug { event: self, db } + } +} + +impl fmt::Debug for Event { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Event") + .field("runtime_id", &self.runtime_id) + .field("kind", &self.kind) + .finish() + } +} + +struct EventDebug<'me, D: ?Sized> +where + D: plumbing::DatabaseOps, +{ + event: &'me Event, + db: &'me D, +} + +impl<'me, D: ?Sized> fmt::Debug for EventDebug<'me, D> +where + D: plumbing::DatabaseOps, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Event") + .field("runtime_id", &self.event.runtime_id) + .field("kind", &self.event.kind.debug(self.db)) + .finish() + } +} + +/// An enum identifying the various kinds of events that can occur. +pub enum EventKind { + /// Occurs when we found that all inputs to a memoized value are + /// up-to-date and hence the value can be re-used without + /// executing the closure. + /// + /// Executes before the "re-used" value is returned. + DidValidateMemoizedValue { + /// The database-key for the affected value. Implements `Debug`. + database_key: DatabaseKeyIndex, + }, + + /// Indicates that another thread (with id `other_runtime_id`) is processing the + /// given query (`database_key`), so we will block until they + /// finish. + /// + /// Executes after we have registered with the other thread but + /// before they have answered us. + /// + /// (NB: you can find the `id` of the current thread via the + /// `salsa_runtime`) + WillBlockOn { + /// The id of the runtime we will block on. + other_runtime_id: RuntimeId, + + /// The database-key for the affected value. Implements `Debug`. + database_key: DatabaseKeyIndex, + }, + + /// Indicates that the function for this query will be executed. + /// This is either because it has never executed before or because + /// its inputs may be out of date. + WillExecute { + /// The database-key for the affected value. Implements `Debug`. + database_key: DatabaseKeyIndex, + }, + + /// Indicates that `unwind_if_cancelled` was called and salsa will check if + /// the current revision has been cancelled. + WillCheckCancellation, +} + +impl EventKind { + /// Returns a type that gives a user-readable debug output. + /// Use like `println!("{:?}", index.debug(db))`. + pub fn debug<'me, D: ?Sized>(&'me self, db: &'me D) -> impl std::fmt::Debug + 'me + where + D: plumbing::DatabaseOps, + { + EventKindDebug { kind: self, db } + } +} + +impl fmt::Debug for EventKind { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + EventKind::DidValidateMemoizedValue { database_key } => fmt + .debug_struct("DidValidateMemoizedValue") + .field("database_key", database_key) + .finish(), + EventKind::WillBlockOn { other_runtime_id, database_key } => fmt + .debug_struct("WillBlockOn") + .field("other_runtime_id", other_runtime_id) + .field("database_key", database_key) + .finish(), + EventKind::WillExecute { database_key } => { + fmt.debug_struct("WillExecute").field("database_key", database_key).finish() + } + EventKind::WillCheckCancellation => fmt.debug_struct("WillCheckCancellation").finish(), + } + } +} + +struct EventKindDebug<'me, D: ?Sized> +where + D: plumbing::DatabaseOps, +{ + kind: &'me EventKind, + db: &'me D, +} + +impl<'me, D: ?Sized> fmt::Debug for EventKindDebug<'me, D> +where + D: plumbing::DatabaseOps, +{ + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.kind { + EventKind::DidValidateMemoizedValue { database_key } => fmt + .debug_struct("DidValidateMemoizedValue") + .field("database_key", &database_key.debug(self.db)) + .finish(), + EventKind::WillBlockOn { other_runtime_id, database_key } => fmt + .debug_struct("WillBlockOn") + .field("other_runtime_id", &other_runtime_id) + .field("database_key", &database_key.debug(self.db)) + .finish(), + EventKind::WillExecute { database_key } => fmt + .debug_struct("WillExecute") + .field("database_key", &database_key.debug(self.db)) + .finish(), + EventKind::WillCheckCancellation => fmt.debug_struct("WillCheckCancellation").finish(), + } + } +} + +/// Indicates a database that also supports parallel query +/// evaluation. All of Salsa's base query support is capable of +/// parallel execution, but for it to work, your query key/value types +/// must also be `Send`, as must any additional data in your database. +pub trait ParallelDatabase: Database + Send { + /// Creates a second handle to the database that holds the + /// database fixed at a particular revision. So long as this + /// "frozen" handle exists, any attempt to [`set`] an input will + /// block. + /// + /// [`set`]: struct.QueryTable.html#method.set + /// + /// This is the method you are meant to use most of the time in a + /// parallel setting where modifications may arise asynchronously + /// (e.g., a language server). In this context, it is common to + /// wish to "fork off" a snapshot of the database performing some + /// series of queries in parallel and arranging the results. Using + /// this method for that purpose ensures that those queries will + /// see a consistent view of the database (it is also advisable + /// for those queries to use the [`Runtime::unwind_if_cancelled`] + /// method to check for cancellation). + /// + /// # Panics + /// + /// It is not permitted to create a snapshot from inside of a + /// query. Attepting to do so will panic. + /// + /// # Deadlock warning + /// + /// The intended pattern for snapshots is that, once created, they + /// are sent to another thread and used from there. As such, the + /// `snapshot` acquires a "read lock" on the database -- + /// therefore, so long as the `snapshot` is not dropped, any + /// attempt to `set` a value in the database will block. If the + /// `snapshot` is owned by the same thread that is attempting to + /// `set`, this will cause a problem. + /// + /// # How to implement this + /// + /// Typically, this method will create a second copy of your + /// database type (`MyDatabaseType`, in the example below), + /// cloning over each of the fields from `self` into this new + /// copy. For the field that stores the salsa runtime, you should + /// use [the `Runtime::snapshot` method][rfm] to create a snapshot of the + /// runtime. Finally, package up the result using `Snapshot::new`, + /// which is a simple wrapper type that only gives `&self` access + /// to the database within (thus preventing the use of methods + /// that may mutate the inputs): + /// + /// [rfm]: struct.Runtime.html#method.snapshot + /// + /// ```rust,ignore + /// impl ParallelDatabase for MyDatabaseType { + /// fn snapshot(&self) -> Snapshot { + /// Snapshot::new( + /// MyDatabaseType { + /// runtime: self.runtime.snapshot(self), + /// other_field: self.other_field.clone(), + /// } + /// ) + /// } + /// } + /// ``` + fn snapshot(&self) -> Snapshot; +} + +/// Simple wrapper struct that takes ownership of a database `DB` and +/// only gives `&self` access to it. See [the `snapshot` method][fm] +/// for more details. +/// +/// [fm]: trait.ParallelDatabase.html#method.snapshot +#[derive(Debug)] +pub struct Snapshot +where + DB: ParallelDatabase, +{ + db: DB, +} + +impl Snapshot +where + DB: ParallelDatabase, +{ + /// Creates a `Snapshot` that wraps the given database handle + /// `db`. From this point forward, only shared references to `db` + /// will be possible. + pub fn new(db: DB) -> Self { + Snapshot { db } + } +} + +impl std::ops::Deref for Snapshot +where + DB: ParallelDatabase, +{ + type Target = DB; + + fn deref(&self) -> &DB { + &self.db + } +} + +/// An integer that uniquely identifies a particular query instance within the +/// database. Used to track dependencies between queries. Fully ordered and +/// equatable but those orderings are arbitrary, and meant to be used only for +/// inserting into maps and the like. +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub struct DatabaseKeyIndex { + group_index: u16, + query_index: u16, + key_index: u32, +} + +impl DatabaseKeyIndex { + /// Returns the index of the query group containing this key. + #[inline] + pub fn group_index(self) -> u16 { + self.group_index + } + + /// Returns the index of the query within its query group. + #[inline] + pub fn query_index(self) -> u16 { + self.query_index + } + + /// Returns the index of this particular query key within the query. + #[inline] + pub fn key_index(self) -> u32 { + self.key_index + } + + /// Returns a type that gives a user-readable debug output. + /// Use like `println!("{:?}", index.debug(db))`. + pub fn debug(self, db: &D) -> impl std::fmt::Debug + '_ + where + D: plumbing::DatabaseOps, + { + DatabaseKeyIndexDebug { index: self, db } + } +} + +/// Helper type for `DatabaseKeyIndex::debug` +struct DatabaseKeyIndexDebug<'me, D: ?Sized> +where + D: plumbing::DatabaseOps, +{ + index: DatabaseKeyIndex, + db: &'me D, +} + +impl std::fmt::Debug for DatabaseKeyIndexDebug<'_, D> +where + D: plumbing::DatabaseOps, +{ + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.db.fmt_index(self.index, fmt) + } +} + +/// Trait implements by all of the "special types" associated with +/// each of your queries. +/// +/// Base trait of `Query` that has a lifetime parameter to allow the `DynDb` to be non-'static. +pub trait QueryDb<'d>: Sized { + /// Dyn version of the associated trait for this query group. + type DynDb: ?Sized + Database + HasQueryGroup + 'd; + + /// Associate query group struct. + type Group: plumbing::QueryGroup; + + /// Generated struct that contains storage for all queries in a group. + type GroupStorage; +} + +/// Trait implements by all of the "special types" associated with +/// each of your queries. +pub trait Query: Debug + Default + Sized + for<'d> QueryDb<'d> { + /// Type that you you give as a parameter -- for queries with zero + /// or more than one input, this will be a tuple. + type Key: Clone + Debug + Hash + Eq; + + /// What value does the query return? + type Value: Clone + Debug; + + /// Internal struct storing the values for the query. + // type Storage: plumbing::QueryStorageOps; + type Storage; + + /// A unique index identifying this query within the group. + const QUERY_INDEX: u16; + + /// Name of the query method (e.g., `foo`) + const QUERY_NAME: &'static str; + + /// Extact storage for this query from the storage for its group. + fn query_storage<'a>( + group_storage: &'a >::GroupStorage, + ) -> &'a std::sync::Arc; + + /// Extact storage for this query from the storage for its group. + fn query_storage_mut<'a>( + group_storage: &'a >::GroupStorage, + ) -> &'a std::sync::Arc; +} + +/// Return value from [the `query` method] on `Database`. +/// Gives access to various less common operations on queries. +/// +/// [the `query` method]: trait.Database.html#method.query +pub struct QueryTable<'me, Q> +where + Q: Query, +{ + db: &'me >::DynDb, + storage: &'me Q::Storage, +} + +impl<'me, Q> QueryTable<'me, Q> +where + Q: Query, + Q::Storage: QueryStorageOps, +{ + /// Constructs a new `QueryTable`. + pub fn new(db: &'me >::DynDb, storage: &'me Q::Storage) -> Self { + Self { db, storage } + } + + /// Execute the query on a given input. Usually it's easier to + /// invoke the trait method directly. Note that for variadic + /// queries (those with no inputs, or those with more than one + /// input) the key will be a tuple. + pub fn get(&self, key: Q::Key) -> Q::Value { + self.storage.fetch(self.db, &key) + } + + /// Completely clears the storage for this query. + /// + /// This method breaks internal invariants of salsa, so any further queries + /// might return nonsense results. It is useful only in very specific + /// circumstances -- for example, when one wants to observe which values + /// dropped together with the table + pub fn purge(&self) + where + Q::Storage: plumbing::QueryStorageMassOps, + { + self.storage.purge(); + } +} + +/// Return value from [the `query_mut` method] on `Database`. +/// Gives access to the `set` method, notably, that is used to +/// set the value of an input query. +/// +/// [the `query_mut` method]: trait.Database.html#method.query_mut +pub struct QueryTableMut<'me, Q> +where + Q: Query + 'me, +{ + runtime: &'me mut Runtime, + storage: &'me Q::Storage, +} + +impl<'me, Q> QueryTableMut<'me, Q> +where + Q: Query, +{ + /// Constructs a new `QueryTableMut`. + pub fn new(runtime: &'me mut Runtime, storage: &'me Q::Storage) -> Self { + Self { runtime, storage } + } + + /// Assign a value to an "input query". Must be used outside of + /// an active query computation. + /// + /// If you are using `snapshot`, see the notes on blocking + /// and cancellation on [the `query_mut` method]. + /// + /// [the `query_mut` method]: trait.Database.html#method.query_mut + pub fn set(&mut self, key: Q::Key, value: Q::Value) + where + Q::Storage: plumbing::InputQueryStorageOps, + { + self.set_with_durability(key, value, Durability::LOW); + } + + /// Assign a value to an "input query", with the additional + /// promise that this value will **never change**. Must be used + /// outside of an active query computation. + /// + /// If you are using `snapshot`, see the notes on blocking + /// and cancellation on [the `query_mut` method]. + /// + /// [the `query_mut` method]: trait.Database.html#method.query_mut + pub fn set_with_durability(&mut self, key: Q::Key, value: Q::Value, durability: Durability) + where + Q::Storage: plumbing::InputQueryStorageOps, + { + self.storage.set(self.runtime, &key, value, durability); + } + + /// Sets the size of LRU cache of values for this query table. + /// + /// That is, at most `cap` values will be preset in the table at the same + /// time. This helps with keeping maximum memory usage under control, at the + /// cost of potential extra recalculations of evicted values. + /// + /// If `cap` is zero, all values are preserved, this is the default. + pub fn set_lru_capacity(&self, cap: usize) + where + Q::Storage: plumbing::LruQueryStorageOps, + { + self.storage.set_lru_capacity(cap); + } + + /// Marks the computed value as outdated. + /// + /// This causes salsa to re-execute the query function on the next access to + /// the query, even if all dependencies are up to date. + /// + /// This is most commonly used as part of the [on-demand input + /// pattern](https://salsa-rs.github.io/salsa/common_patterns/on_demand_inputs.html). + pub fn invalidate(&mut self, key: &Q::Key) + where + Q::Storage: plumbing::DerivedQueryStorageOps, + { + self.storage.invalidate(self.runtime, key) + } +} + +/// A panic payload indicating that execution of a salsa query was cancelled. +/// +/// This can occur for a few reasons: +/// * +/// * +/// * +#[derive(Debug)] +#[non_exhaustive] +pub enum Cancelled { + /// The query was operating on revision R, but there is a pending write to move to revision R+1. + #[non_exhaustive] + PendingWrite, + + /// The query was blocked on another thread, and that thread panicked. + #[non_exhaustive] + PropagatedPanic, +} + +impl Cancelled { + fn throw(self) -> ! { + // We use resume and not panic here to avoid running the panic + // hook (that is, to avoid collecting and printing backtrace). + std::panic::resume_unwind(Box::new(self)); + } + + /// Runs `f`, and catches any salsa cancellation. + pub fn catch(f: F) -> Result + where + F: FnOnce() -> T + UnwindSafe, + { + match panic::catch_unwind(f) { + Ok(t) => Ok(t), + Err(payload) => match payload.downcast() { + Ok(cancelled) => Err(*cancelled), + Err(payload) => panic::resume_unwind(payload), + }, + } + } +} + +impl std::fmt::Display for Cancelled { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let why = match self { + Cancelled::PendingWrite => "pending write", + Cancelled::PropagatedPanic => "propagated panic", + }; + f.write_str("cancelled because of ")?; + f.write_str(why) + } +} + +impl std::error::Error for Cancelled {} + +/// Captures the participants of a cycle that occurred when executing a query. +/// +/// This type is meant to be used to help give meaningful error messages to the +/// user or to help salsa developers figure out why their program is resulting +/// in a computation cycle. +/// +/// It is used in a few ways: +/// +/// * During [cycle recovery](https://https://salsa-rs.github.io/salsa/cycles/fallback.html), +/// where it is given to the fallback function. +/// * As the panic value when an unexpected cycle (i.e., a cycle where one or more participants +/// lacks cycle recovery information) occurs. +/// +/// You can read more about cycle handling in +/// the [salsa book](https://https://salsa-rs.github.io/salsa/cycles.html). +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Cycle { + participants: plumbing::CycleParticipants, +} + +impl Cycle { + pub(crate) fn new(participants: plumbing::CycleParticipants) -> Self { + Self { participants } + } + + /// True if two `Cycle` values represent the same cycle. + pub(crate) fn is(&self, cycle: &Cycle) -> bool { + triomphe::Arc::ptr_eq(&self.participants, &cycle.participants) + } + + pub(crate) fn throw(self) -> ! { + tracing::debug!("throwing cycle {:?}", self); + std::panic::resume_unwind(Box::new(self)) + } + + pub(crate) fn catch(execute: impl FnOnce() -> T) -> Result { + match std::panic::catch_unwind(AssertUnwindSafe(execute)) { + Ok(v) => Ok(v), + Err(err) => match err.downcast::() { + Ok(cycle) => Err(*cycle), + Err(other) => std::panic::resume_unwind(other), + }, + } + } + + /// Iterate over the [`DatabaseKeyIndex`] for each query participating + /// in the cycle. The start point of this iteration within the cycle + /// is arbitrary but deterministic, but the ordering is otherwise determined + /// by the execution. + pub fn participant_keys(&self) -> impl Iterator + '_ { + self.participants.iter().copied() + } + + /// Returns a vector with the debug information for + /// all the participants in the cycle. + pub fn all_participants(&self, db: &DB) -> Vec { + self.participant_keys().map(|d| format!("{:?}", d.debug(db))).collect() + } + + /// Returns a vector with the debug information for + /// those participants in the cycle that lacked recovery + /// information. + pub fn unexpected_participants(&self, db: &DB) -> Vec { + self.participant_keys() + .filter(|&d| db.cycle_recovery_strategy(d) == CycleRecoveryStrategy::Panic) + .map(|d| format!("{:?}", d.debug(db))) + .collect() + } + + /// Returns a "debug" view onto this strict that can be used to print out information. + pub fn debug<'me, DB: ?Sized + Database>(&'me self, db: &'me DB) -> impl std::fmt::Debug + 'me { + struct UnexpectedCycleDebug<'me> { + c: &'me Cycle, + db: &'me dyn Database, + } + + impl<'me> std::fmt::Debug for UnexpectedCycleDebug<'me> { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt.debug_struct("UnexpectedCycle") + .field("all_participants", &self.c.all_participants(self.db)) + .field("unexpected_participants", &self.c.unexpected_participants(self.db)) + .finish() + } + } + + UnexpectedCycleDebug { c: self, db: db.ops_database() } + } +} + +// Re-export the procedural macros. +#[allow(unused_imports)] +#[macro_use] +extern crate salsa_macros; +use plumbing::HasQueryGroup; +pub use salsa_macros::*; diff --git a/crates/salsa/src/lru.rs b/crates/salsa/src/lru.rs new file mode 100644 index 0000000000..c6b9778f20 --- /dev/null +++ b/crates/salsa/src/lru.rs @@ -0,0 +1,325 @@ +//! +use oorandom::Rand64; +use parking_lot::Mutex; +use std::fmt::Debug; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; +use triomphe::Arc; + +/// A simple and approximate concurrent lru list. +/// +/// We assume but do not verify that each node is only used with one +/// list. If this is not the case, it is not *unsafe*, but panics and +/// weird results will ensue. +/// +/// Each "node" in the list is of type `Node` and must implement +/// `LruNode`, which is a trait that gives access to a field that +/// stores the index in the list. This index gives us a rough idea of +/// how recently the node has been used. +#[derive(Debug)] +pub(crate) struct Lru +where + Node: LruNode, +{ + green_zone: AtomicUsize, + data: Mutex>, +} + +#[derive(Debug)] +struct LruData { + end_red_zone: usize, + end_yellow_zone: usize, + end_green_zone: usize, + rng: Rand64, + entries: Vec>, +} + +pub(crate) trait LruNode: Sized + Debug { + fn lru_index(&self) -> &LruIndex; +} + +#[derive(Debug)] +pub(crate) struct LruIndex { + /// Index in the approprate LRU list, or std::usize::MAX if not a + /// member. + index: AtomicUsize, +} + +impl Default for Lru +where + Node: LruNode, +{ + fn default() -> Self { + Lru::new() + } +} + +// We always use a fixed seed for our randomness so that we have +// predictable results. +const LRU_SEED: &str = "Hello, Rustaceans"; + +impl Lru +where + Node: LruNode, +{ + /// Creates a new LRU list where LRU caching is disabled. + pub(crate) fn new() -> Self { + Self::with_seed(LRU_SEED) + } + + #[cfg_attr(not(test), allow(dead_code))] + fn with_seed(seed: &str) -> Self { + Lru { green_zone: AtomicUsize::new(0), data: Mutex::new(LruData::with_seed(seed)) } + } + + /// Adjust the total number of nodes permitted to have a value at + /// once. If `len` is zero, this disables LRU caching completely. + pub(crate) fn set_lru_capacity(&self, len: usize) { + let mut data = self.data.lock(); + + // We require each zone to have at least 1 slot. Therefore, + // the length cannot be just 1 or 2. + if len == 0 { + self.green_zone.store(0, Ordering::Release); + data.resize(0, 0, 0); + } else { + let len = std::cmp::max(len, 3); + + // Top 10% is the green zone. This must be at least length 1. + let green_zone = std::cmp::max(len / 10, 1); + + // Next 20% is the yellow zone. + let yellow_zone = std::cmp::max(len / 5, 1); + + // Remaining 70% is the red zone. + let red_zone = len - yellow_zone - green_zone; + + // We need quick access to the green zone. + self.green_zone.store(green_zone, Ordering::Release); + + // Resize existing array. + data.resize(green_zone, yellow_zone, red_zone); + } + } + + /// Records that `node` was used. This may displace an old node (if the LRU limits are + pub(crate) fn record_use(&self, node: &Arc) -> Option> { + tracing::debug!("record_use(node={:?})", node); + + // Load green zone length and check if the LRU cache is even enabled. + let green_zone = self.green_zone.load(Ordering::Acquire); + tracing::debug!("record_use: green_zone={}", green_zone); + if green_zone == 0 { + return None; + } + + // Find current index of list (if any) and the current length + // of our green zone. + let index = node.lru_index().load(); + tracing::debug!("record_use: index={}", index); + + // Already a member of the list, and in the green zone -- nothing to do! + if index < green_zone { + return None; + } + + self.data.lock().record_use(node) + } + + pub(crate) fn purge(&self) { + self.green_zone.store(0, Ordering::SeqCst); + *self.data.lock() = LruData::with_seed(LRU_SEED); + } +} + +impl LruData +where + Node: LruNode, +{ + fn with_seed(seed_str: &str) -> Self { + Self::with_rng(rng_with_seed(seed_str)) + } + + fn with_rng(rng: Rand64) -> Self { + LruData { end_yellow_zone: 0, end_green_zone: 0, end_red_zone: 0, entries: Vec::new(), rng } + } + + fn green_zone(&self) -> std::ops::Range { + 0..self.end_green_zone + } + + fn yellow_zone(&self) -> std::ops::Range { + self.end_green_zone..self.end_yellow_zone + } + + fn red_zone(&self) -> std::ops::Range { + self.end_yellow_zone..self.end_red_zone + } + + fn resize(&mut self, len_green_zone: usize, len_yellow_zone: usize, len_red_zone: usize) { + self.end_green_zone = len_green_zone; + self.end_yellow_zone = self.end_green_zone + len_yellow_zone; + self.end_red_zone = self.end_yellow_zone + len_red_zone; + let entries = std::mem::replace(&mut self.entries, Vec::with_capacity(self.end_red_zone)); + + tracing::debug!("green_zone = {:?}", self.green_zone()); + tracing::debug!("yellow_zone = {:?}", self.yellow_zone()); + tracing::debug!("red_zone = {:?}", self.red_zone()); + + // We expect to resize when the LRU cache is basically empty. + // So just forget all the old LRU indices to start. + for entry in entries { + entry.lru_index().clear(); + } + } + + /// Records that a node was used. If it is already a member of the + /// LRU list, it is promoted to the green zone (unless it's + /// already there). Otherwise, it is added to the list first and + /// *then* promoted to the green zone. Adding a new node to the + /// list may displace an old member of the red zone, in which case + /// that is returned. + fn record_use(&mut self, node: &Arc) -> Option> { + tracing::debug!("record_use(node={:?})", node); + + // NB: When this is invoked, we have typically already loaded + // the LRU index (to check if it is in green zone). But that + // check was done outside the lock and -- for all we know -- + // the index may have changed since. So we always reload. + let index = node.lru_index().load(); + + if index < self.end_green_zone { + None + } else if index < self.end_yellow_zone { + self.promote_yellow_to_green(node, index); + None + } else if index < self.end_red_zone { + self.promote_red_to_green(node, index); + None + } else { + self.insert_new(node) + } + } + + /// Inserts a node that is not yet a member of the LRU list. If + /// the list is at capacity, this can displace an existing member. + fn insert_new(&mut self, node: &Arc) -> Option> { + debug_assert!(!node.lru_index().is_in_lru()); + + // Easy case: we still have capacity. Push it, and then promote + // it up to the appropriate zone. + let len = self.entries.len(); + if len < self.end_red_zone { + self.entries.push(node.clone()); + node.lru_index().store(len); + tracing::debug!("inserted node {:?} at {}", node, len); + return self.record_use(node); + } + + // Harder case: no capacity. Create some by evicting somebody from red + // zone and then promoting. + let victim_index = self.pick_index(self.red_zone()); + let victim_node = std::mem::replace(&mut self.entries[victim_index], node.clone()); + tracing::debug!("evicting red node {:?} from {}", victim_node, victim_index); + victim_node.lru_index().clear(); + self.promote_red_to_green(node, victim_index); + Some(victim_node) + } + + /// Promotes the node `node`, stored at `red_index` (in the red + /// zone), into a green index, demoting yellow/green nodes at + /// random. + /// + /// NB: It is not required that `node.lru_index()` is up-to-date + /// when entering this method. + fn promote_red_to_green(&mut self, node: &Arc, red_index: usize) { + debug_assert!(self.red_zone().contains(&red_index)); + + // Pick a yellow at random and switch places with it. + // + // Subtle: we do not update `node.lru_index` *yet* -- we're + // going to invoke `self.promote_yellow` next, and it will get + // updated then. + let yellow_index = self.pick_index(self.yellow_zone()); + tracing::debug!( + "demoting yellow node {:?} from {} to red at {}", + self.entries[yellow_index], + yellow_index, + red_index, + ); + self.entries.swap(yellow_index, red_index); + self.entries[red_index].lru_index().store(red_index); + + // Now move ourselves up into the green zone. + self.promote_yellow_to_green(node, yellow_index); + } + + /// Promotes the node `node`, stored at `yellow_index` (in the + /// yellow zone), into a green index, demoting a green node at + /// random to replace it. + /// + /// NB: It is not required that `node.lru_index()` is up-to-date + /// when entering this method. + fn promote_yellow_to_green(&mut self, node: &Arc, yellow_index: usize) { + debug_assert!(self.yellow_zone().contains(&yellow_index)); + + // Pick a yellow at random and switch places with it. + let green_index = self.pick_index(self.green_zone()); + tracing::debug!( + "demoting green node {:?} from {} to yellow at {}", + self.entries[green_index], + green_index, + yellow_index + ); + self.entries.swap(green_index, yellow_index); + self.entries[yellow_index].lru_index().store(yellow_index); + node.lru_index().store(green_index); + + tracing::debug!("promoted {:?} to green index {}", node, green_index); + } + + fn pick_index(&mut self, zone: std::ops::Range) -> usize { + let end_index = std::cmp::min(zone.end, self.entries.len()); + self.rng.rand_range(zone.start as u64..end_index as u64) as usize + } +} + +impl Default for LruIndex { + fn default() -> Self { + Self { index: AtomicUsize::new(std::usize::MAX) } + } +} + +impl LruIndex { + fn load(&self) -> usize { + self.index.load(Ordering::Acquire) // see note on ordering below + } + + fn store(&self, value: usize) { + self.index.store(value, Ordering::Release) // see note on ordering below + } + + fn clear(&self) { + self.store(std::usize::MAX); + } + + fn is_in_lru(&self) -> bool { + self.load() != std::usize::MAX + } +} + +fn rng_with_seed(seed_str: &str) -> Rand64 { + let mut seed: [u8; 16] = [0; 16]; + for (i, &b) in seed_str.as_bytes().iter().take(16).enumerate() { + seed[i] = b; + } + Rand64::new(u128::from_le_bytes(seed)) +} + +// A note on ordering: +// +// I chose to use AcqRel for the ordering but I don't think it's +// strictly needed. All writes occur under a lock, so they should be +// ordered w/r/t one another. As for the reads, they can occur +// outside the lock, but they don't themselves enable dependent reads +// -- if the reads are out of bounds, we would acquire a lock. diff --git a/crates/salsa/src/plumbing.rs b/crates/salsa/src/plumbing.rs new file mode 100644 index 0000000000..71332e39ca --- /dev/null +++ b/crates/salsa/src/plumbing.rs @@ -0,0 +1,238 @@ +//! +#![allow(missing_docs)] + +use crate::debug::TableEntry; +use crate::durability::Durability; +use crate::Cycle; +use crate::Database; +use crate::Query; +use crate::QueryTable; +use crate::QueryTableMut; +use std::borrow::Borrow; +use std::fmt::Debug; +use std::hash::Hash; +use triomphe::Arc; + +pub use crate::derived::DependencyStorage; +pub use crate::derived::MemoizedStorage; +pub use crate::input::{InputStorage, UnitInputStorage}; +pub use crate::interned::InternedStorage; +pub use crate::interned::LookupInternedStorage; +pub use crate::{revision::Revision, DatabaseKeyIndex, QueryDb, Runtime}; + +/// Defines various associated types. An impl of this +/// should be generated for your query-context type automatically by +/// the `database_storage` macro, so you shouldn't need to mess +/// with this trait directly. +pub trait DatabaseStorageTypes: Database { + /// Defines the "storage type", where all the query data is kept. + /// This type is defined by the `database_storage` macro. + type DatabaseStorage: Default; +} + +/// Internal operations that the runtime uses to operate on the database. +pub trait DatabaseOps { + /// Upcast this type to a `dyn Database`. + fn ops_database(&self) -> &dyn Database; + + /// Gives access to the underlying salsa runtime. + fn ops_salsa_runtime(&self) -> &Runtime; + + /// Gives access to the underlying salsa runtime. + fn ops_salsa_runtime_mut(&mut self) -> &mut Runtime; + + /// Formats a database key index in a human readable fashion. + fn fmt_index( + &self, + index: DatabaseKeyIndex, + fmt: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result; + + /// True if the computed value for `input` may have changed since `revision`. + fn maybe_changed_after(&self, input: DatabaseKeyIndex, revision: Revision) -> bool; + + /// Find the `CycleRecoveryStrategy` for a given input. + fn cycle_recovery_strategy(&self, input: DatabaseKeyIndex) -> CycleRecoveryStrategy; + + /// Executes the callback for each kind of query. + fn for_each_query(&self, op: &mut dyn FnMut(&dyn QueryStorageMassOps)); +} + +/// Internal operations performed on the query storage as a whole +/// (note that these ops do not need to know the identity of the +/// query, unlike `QueryStorageOps`). +pub trait QueryStorageMassOps { + fn purge(&self); +} + +pub trait DatabaseKey: Clone + Debug + Eq + Hash {} + +pub trait QueryFunction: Query { + /// See `CycleRecoveryStrategy` + const CYCLE_STRATEGY: CycleRecoveryStrategy; + + fn execute(db: &>::DynDb, key: Self::Key) -> Self::Value; + + fn cycle_fallback( + db: &>::DynDb, + cycle: &Cycle, + key: &Self::Key, + ) -> Self::Value { + let _ = (db, cycle, key); + panic!("query `{:?}` doesn't support cycle fallback", Self::default()) + } +} + +/// Cycle recovery strategy: Is this query capable of recovering from +/// a cycle that results from executing the function? If so, how? +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum CycleRecoveryStrategy { + /// Cannot recover from cycles: panic. + /// + /// This is the default. It is also what happens if a cycle + /// occurs and the queries involved have different recovery + /// strategies. + /// + /// In the case of a failure due to a cycle, the panic + /// value will be XXX (FIXME). + Panic, + + /// Recovers from cycles by storing a sentinel value. + /// + /// This value is computed by the `QueryFunction::cycle_fallback` + /// function. + Fallback, +} + +/// Create a query table, which has access to the storage for the query +/// and offers methods like `get`. +pub fn get_query_table<'me, Q>(db: &'me >::DynDb) -> QueryTable<'me, Q> +where + Q: Query + 'me, + Q::Storage: QueryStorageOps, +{ + let group_storage: &Q::GroupStorage = HasQueryGroup::group_storage(db); + let query_storage: &Q::Storage = Q::query_storage(group_storage); + QueryTable::new(db, query_storage) +} + +/// Create a mutable query table, which has access to the storage +/// for the query and offers methods like `set`. +pub fn get_query_table_mut<'me, Q>(db: &'me mut >::DynDb) -> QueryTableMut<'me, Q> +where + Q: Query, +{ + let (group_storage, runtime) = HasQueryGroup::group_storage_mut(db); + let query_storage = Q::query_storage_mut(group_storage); + QueryTableMut::new(runtime, &**query_storage) +} + +pub trait QueryGroup: Sized { + type GroupStorage; + + /// Dyn version of the associated database trait. + type DynDb: ?Sized + Database + HasQueryGroup; +} + +/// Trait implemented by a database for each group that it supports. +/// `S` and `K` are the types for *group storage* and *group key*, respectively. +pub trait HasQueryGroup: Database +where + G: QueryGroup, +{ + /// Access the group storage struct from the database. + fn group_storage(&self) -> &G::GroupStorage; + + /// Access the group storage struct from the database. + /// Also returns a ref to the `Runtime`, since otherwise + /// the database is borrowed and one cannot get access to it. + fn group_storage_mut(&mut self) -> (&G::GroupStorage, &mut Runtime); +} + +// ANCHOR:QueryStorageOps +pub trait QueryStorageOps +where + Self: QueryStorageMassOps, + Q: Query, +{ + // ANCHOR_END:QueryStorageOps + + /// See CycleRecoveryStrategy + const CYCLE_STRATEGY: CycleRecoveryStrategy; + + fn new(group_index: u16) -> Self; + + /// Format a database key index in a suitable way. + fn fmt_index( + &self, + db: &>::DynDb, + index: DatabaseKeyIndex, + fmt: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result; + + // ANCHOR:maybe_changed_after + /// True if the value of `input`, which must be from this query, may have + /// changed after the given revision ended. + /// + /// This function should only be invoked with a revision less than the current + /// revision. + fn maybe_changed_after( + &self, + db: &>::DynDb, + input: DatabaseKeyIndex, + revision: Revision, + ) -> bool; + // ANCHOR_END:maybe_changed_after + + fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { + Self::CYCLE_STRATEGY + } + + // ANCHOR:fetch + /// Execute the query, returning the result (often, the result + /// will be memoized). This is the "main method" for + /// queries. + /// + /// Returns `Err` in the event of a cycle, meaning that computing + /// the value for this `key` is recursively attempting to fetch + /// itself. + fn fetch(&self, db: &>::DynDb, key: &Q::Key) -> Q::Value; + // ANCHOR_END:fetch + + /// Returns the durability associated with a given key. + fn durability(&self, db: &>::DynDb, key: &Q::Key) -> Durability; + + /// Get the (current) set of the entries in the query storage + fn entries(&self, db: &>::DynDb) -> C + where + C: std::iter::FromIterator>; +} + +/// An optional trait that is implemented for "user mutable" storage: +/// that is, storage whose value is not derived from other storage but +/// is set independently. +pub trait InputQueryStorageOps +where + Q: Query, +{ + fn set(&self, runtime: &mut Runtime, key: &Q::Key, new_value: Q::Value, durability: Durability); +} + +/// An optional trait that is implemented for "user mutable" storage: +/// that is, storage whose value is not derived from other storage but +/// is set independently. +pub trait LruQueryStorageOps { + fn set_lru_capacity(&self, new_capacity: usize); +} + +pub trait DerivedQueryStorageOps +where + Q: Query, +{ + fn invalidate(&self, runtime: &mut Runtime, key: &S) + where + S: Eq + Hash, + Q::Key: Borrow; +} + +pub type CycleParticipants = Arc>; diff --git a/crates/salsa/src/revision.rs b/crates/salsa/src/revision.rs new file mode 100644 index 0000000000..d97aaf9deb --- /dev/null +++ b/crates/salsa/src/revision.rs @@ -0,0 +1,67 @@ +//! +use std::num::NonZeroU32; +use std::sync::atomic::{AtomicU32, Ordering}; + +/// Value of the initial revision, as a u32. We don't use 0 +/// because we want to use a `NonZeroU32`. +const START: u32 = 1; + +/// A unique identifier for the current version of the database; each +/// time an input is changed, the revision number is incremented. +/// `Revision` is used internally to track which values may need to be +/// recomputed, but is not something you should have to interact with +/// directly as a user of salsa. +#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct Revision { + generation: NonZeroU32, +} + +impl Revision { + pub(crate) fn start() -> Self { + Self::from(START) + } + + pub(crate) fn from(g: u32) -> Self { + Self { generation: NonZeroU32::new(g).unwrap() } + } + + pub(crate) fn next(self) -> Revision { + Self::from(self.generation.get() + 1) + } + + fn as_u32(self) -> u32 { + self.generation.get() + } +} + +impl std::fmt::Debug for Revision { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(fmt, "R{}", self.generation) + } +} + +#[derive(Debug)] +pub(crate) struct AtomicRevision { + data: AtomicU32, +} + +impl AtomicRevision { + pub(crate) fn start() -> Self { + Self { data: AtomicU32::new(START) } + } + + pub(crate) fn load(&self) -> Revision { + Revision::from(self.data.load(Ordering::SeqCst)) + } + + pub(crate) fn store(&self, r: Revision) { + self.data.store(r.as_u32(), Ordering::SeqCst); + } + + /// Increment by 1, returning previous value. + pub(crate) fn fetch_then_increment(&self) -> Revision { + let v = self.data.fetch_add(1, Ordering::SeqCst); + assert!(v != u32::max_value(), "revision overflow"); + Revision::from(v) + } +} diff --git a/crates/salsa/src/runtime.rs b/crates/salsa/src/runtime.rs new file mode 100644 index 0000000000..40b8856991 --- /dev/null +++ b/crates/salsa/src/runtime.rs @@ -0,0 +1,667 @@ +//! +use crate::durability::Durability; +use crate::hash::FxIndexSet; +use crate::plumbing::CycleRecoveryStrategy; +use crate::revision::{AtomicRevision, Revision}; +use crate::{Cancelled, Cycle, Database, DatabaseKeyIndex, Event, EventKind}; +use parking_lot::lock_api::{RawRwLock, RawRwLockRecursive}; +use parking_lot::{Mutex, RwLock}; +use std::hash::Hash; +use std::panic::panic_any; +use std::sync::atomic::{AtomicUsize, Ordering}; +use tracing::debug; +use triomphe::Arc; + +mod dependency_graph; +use dependency_graph::DependencyGraph; + +pub(crate) mod local_state; +use local_state::LocalState; + +use self::local_state::{ActiveQueryGuard, QueryInputs, QueryRevisions}; + +/// The salsa runtime stores the storage for all queries as well as +/// tracking the query stack and dependencies between cycles. +/// +/// Each new runtime you create (e.g., via `Runtime::new` or +/// `Runtime::default`) will have an independent set of query storage +/// associated with it. Normally, therefore, you only do this once, at +/// the start of your application. +pub struct Runtime { + /// Our unique runtime id. + id: RuntimeId, + + /// If this is a "forked" runtime, then the `revision_guard` will + /// be `Some`; this guard holds a read-lock on the global query + /// lock. + revision_guard: Option, + + /// Local state that is specific to this runtime (thread). + local_state: LocalState, + + /// Shared state that is accessible via all runtimes. + shared_state: Arc, +} + +#[derive(Clone, Debug)] +pub(crate) enum WaitResult { + Completed, + Panicked, + Cycle(Cycle), +} + +impl Default for Runtime { + fn default() -> Self { + Runtime { + id: RuntimeId { counter: 0 }, + revision_guard: None, + shared_state: Default::default(), + local_state: Default::default(), + } + } +} + +impl std::fmt::Debug for Runtime { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt.debug_struct("Runtime") + .field("id", &self.id()) + .field("forked", &self.revision_guard.is_some()) + .field("shared_state", &self.shared_state) + .finish() + } +} + +impl Runtime { + /// Create a new runtime; equivalent to `Self::default`. This is + /// used when creating a new database. + pub fn new() -> Self { + Self::default() + } + + /// See [`crate::storage::Storage::snapshot`]. + pub(crate) fn snapshot(&self) -> Self { + if self.local_state.query_in_progress() { + panic!("it is not legal to `snapshot` during a query (see salsa-rs/salsa#80)"); + } + + let revision_guard = RevisionGuard::new(&self.shared_state); + + let id = RuntimeId { counter: self.shared_state.next_id.fetch_add(1, Ordering::SeqCst) }; + + Runtime { + id, + revision_guard: Some(revision_guard), + shared_state: self.shared_state.clone(), + local_state: Default::default(), + } + } + + /// A "synthetic write" causes the system to act *as though* some + /// input of durability `durability` has changed. This is mostly + /// useful for profiling scenarios. + /// + /// **WARNING:** Just like an ordinary write, this method triggers + /// cancellation. If you invoke it while a snapshot exists, it + /// will block until that snapshot is dropped -- if that snapshot + /// is owned by the current thread, this could trigger deadlock. + pub fn synthetic_write(&mut self, durability: Durability) { + self.with_incremented_revision(|_next_revision| Some(durability)); + } + + /// The unique identifier attached to this `SalsaRuntime`. Each + /// snapshotted runtime has a distinct identifier. + #[inline] + pub fn id(&self) -> RuntimeId { + self.id + } + + /// Returns the database-key for the query that this thread is + /// actively executing (if any). + pub fn active_query(&self) -> Option { + self.local_state.active_query() + } + + /// Read current value of the revision counter. + #[inline] + pub(crate) fn current_revision(&self) -> Revision { + self.shared_state.revisions[0].load() + } + + /// The revision in which values with durability `d` may have last + /// changed. For D0, this is just the current revision. But for + /// higher levels of durability, this value may lag behind the + /// current revision. If we encounter a value of durability Di, + /// then, we can check this function to get a "bound" on when the + /// value may have changed, which allows us to skip walking its + /// dependencies. + #[inline] + pub(crate) fn last_changed_revision(&self, d: Durability) -> Revision { + self.shared_state.revisions[d.index()].load() + } + + /// Read current value of the revision counter. + #[inline] + pub(crate) fn pending_revision(&self) -> Revision { + self.shared_state.pending_revision.load() + } + + #[cold] + pub(crate) fn unwind_cancelled(&self) { + self.report_untracked_read(); + Cancelled::PendingWrite.throw(); + } + + /// Acquires the **global query write lock** (ensuring that no queries are + /// executing) and then increments the current revision counter; invokes + /// `op` with the global query write lock still held. + /// + /// While we wait to acquire the global query write lock, this method will + /// also increment `pending_revision_increments`, thus signalling to queries + /// that their results are "cancelled" and they should abort as expeditiously + /// as possible. + /// + /// The `op` closure should actually perform the writes needed. It is given + /// the new revision as an argument, and its return value indicates whether + /// any pre-existing value was modified: + /// + /// - returning `None` means that no pre-existing value was modified (this + /// could occur e.g. when setting some key on an input that was never set + /// before) + /// - returning `Some(d)` indicates that a pre-existing value was modified + /// and it had the durability `d`. This will update the records for when + /// values with each durability were modified. + /// + /// Note that, given our writer model, we can assume that only one thread is + /// attempting to increment the global revision at a time. + pub(crate) fn with_incremented_revision(&mut self, op: F) + where + F: FnOnce(Revision) -> Option, + { + tracing::debug!("increment_revision()"); + + if !self.permits_increment() { + panic!("increment_revision invoked during a query computation"); + } + + // Set the `pending_revision` field so that people + // know current revision is cancelled. + let current_revision = self.shared_state.pending_revision.fetch_then_increment(); + + // To modify the revision, we need the lock. + let shared_state = self.shared_state.clone(); + let _lock = shared_state.query_lock.write(); + + let old_revision = self.shared_state.revisions[0].fetch_then_increment(); + assert_eq!(current_revision, old_revision); + + let new_revision = current_revision.next(); + + debug!("increment_revision: incremented to {:?}", new_revision); + + if let Some(d) = op(new_revision) { + for rev in &self.shared_state.revisions[1..=d.index()] { + rev.store(new_revision); + } + } + } + + pub(crate) fn permits_increment(&self) -> bool { + self.revision_guard.is_none() && !self.local_state.query_in_progress() + } + + #[inline] + pub(crate) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> { + self.local_state.push_query(database_key_index) + } + + /// Reports that the currently active query read the result from + /// another query. + /// + /// Also checks whether the "cycle participant" flag is set on + /// the current stack frame -- if so, panics with `CycleParticipant` + /// value, which should be caught by the code executing the query. + /// + /// # Parameters + /// + /// - `database_key`: the query whose result was read + /// - `changed_revision`: the last revision in which the result of that + /// query had changed + pub(crate) fn report_query_read_and_unwind_if_cycle_resulted( + &self, + input: DatabaseKeyIndex, + durability: Durability, + changed_at: Revision, + ) { + self.local_state + .report_query_read_and_unwind_if_cycle_resulted(input, durability, changed_at); + } + + /// Reports that the query depends on some state unknown to salsa. + /// + /// Queries which report untracked reads will be re-executed in the next + /// revision. + pub fn report_untracked_read(&self) { + self.local_state.report_untracked_read(self.current_revision()); + } + + /// Acts as though the current query had read an input with the given durability; this will force the current query's durability to be at most `durability`. + /// + /// This is mostly useful to control the durability level for [on-demand inputs](https://salsa-rs.github.io/salsa/common_patterns/on_demand_inputs.html). + pub fn report_synthetic_read(&self, durability: Durability) { + let changed_at = self.last_changed_revision(durability); + self.local_state.report_synthetic_read(durability, changed_at); + } + + /// Handles a cycle in the dependency graph that was detected when the + /// current thread tried to block on `database_key_index` which is being + /// executed by `to_id`. If this function returns, then `to_id` no longer + /// depends on the current thread, and so we should continue executing + /// as normal. Otherwise, the function will throw a `Cycle` which is expected + /// to be caught by some frame on our stack. This occurs either if there is + /// a frame on our stack with cycle recovery (possibly the top one!) or if there + /// is no cycle recovery at all. + fn unblock_cycle_and_maybe_throw( + &self, + db: &dyn Database, + dg: &mut DependencyGraph, + database_key_index: DatabaseKeyIndex, + to_id: RuntimeId, + ) { + debug!("unblock_cycle_and_maybe_throw(database_key={:?})", database_key_index); + + let mut from_stack = self.local_state.take_query_stack(); + let from_id = self.id(); + + // Make a "dummy stack frame". As we iterate through the cycle, we will collect the + // inputs from each participant. Then, if we are participating in cycle recovery, we + // will propagate those results to all participants. + let mut cycle_query = ActiveQuery::new(database_key_index); + + // Identify the cycle participants: + let cycle = { + let mut v = vec![]; + dg.for_each_cycle_participant( + from_id, + &mut from_stack, + database_key_index, + to_id, + |aqs| { + aqs.iter_mut().for_each(|aq| { + cycle_query.add_from(aq); + v.push(aq.database_key_index); + }); + }, + ); + + // We want to give the participants in a deterministic order + // (at least for this execution, not necessarily across executions), + // no matter where it started on the stack. Find the minimum + // key and rotate it to the front. + let min = v.iter().min().unwrap(); + let index = v.iter().position(|p| p == min).unwrap(); + v.rotate_left(index); + + // No need to store extra memory. + v.shrink_to_fit(); + + Cycle::new(Arc::new(v)) + }; + debug!("cycle {:?}, cycle_query {:#?}", cycle.debug(db), cycle_query,); + + // We can remove the cycle participants from the list of dependencies; + // they are a strongly connected component (SCC) and we only care about + // dependencies to things outside the SCC that control whether it will + // form again. + cycle_query.remove_cycle_participants(&cycle); + + // Mark each cycle participant that has recovery set, along with + // any frames that come after them on the same thread. Those frames + // are going to be unwound so that fallback can occur. + dg.for_each_cycle_participant(from_id, &mut from_stack, database_key_index, to_id, |aqs| { + aqs.iter_mut() + .skip_while(|aq| match db.cycle_recovery_strategy(aq.database_key_index) { + CycleRecoveryStrategy::Panic => true, + CycleRecoveryStrategy::Fallback => false, + }) + .for_each(|aq| { + debug!("marking {:?} for fallback", aq.database_key_index.debug(db)); + aq.take_inputs_from(&cycle_query); + assert!(aq.cycle.is_none()); + aq.cycle = Some(cycle.clone()); + }); + }); + + // Unblock every thread that has cycle recovery with a `WaitResult::Cycle`. + // They will throw the cycle, which will be caught by the frame that has + // cycle recovery so that it can execute that recovery. + let (me_recovered, others_recovered) = + dg.maybe_unblock_runtimes_in_cycle(from_id, &from_stack, database_key_index, to_id); + + self.local_state.restore_query_stack(from_stack); + + if me_recovered { + // If the current thread has recovery, we want to throw + // so that it can begin. + cycle.throw() + } else if others_recovered { + // If other threads have recovery but we didn't: return and we will block on them. + } else { + // if nobody has recover, then we panic + panic_any(cycle); + } + } + + /// Block until `other_id` completes executing `database_key`; + /// panic or unwind in the case of a cycle. + /// + /// `query_mutex_guard` is the guard for the current query's state; + /// it will be dropped after we have successfully registered the + /// dependency. + /// + /// # Propagating panics + /// + /// If the thread `other_id` panics, then our thread is considered + /// cancelled, so this function will panic with a `Cancelled` value. + /// + /// # Cycle handling + /// + /// If the thread `other_id` already depends on the current thread, + /// and hence there is a cycle in the query graph, then this function + /// will unwind instead of returning normally. The method of unwinding + /// depends on the [`Self::mutual_cycle_recovery_strategy`] + /// of the cycle participants: + /// + /// * [`CycleRecoveryStrategy::Panic`]: panic with the [`Cycle`] as the value. + /// * [`CycleRecoveryStrategy::Fallback`]: initiate unwinding with [`CycleParticipant::unwind`]. + pub(crate) fn block_on_or_unwind( + &self, + db: &dyn Database, + database_key: DatabaseKeyIndex, + other_id: RuntimeId, + query_mutex_guard: QueryMutexGuard, + ) { + let mut dg = self.shared_state.dependency_graph.lock(); + + if dg.depends_on(other_id, self.id()) { + self.unblock_cycle_and_maybe_throw(db, &mut dg, database_key, other_id); + + // If the above fn returns, then (via cycle recovery) it has unblocked the + // cycle, so we can continue. + assert!(!dg.depends_on(other_id, self.id())); + } + + db.salsa_event(Event { + runtime_id: self.id(), + kind: EventKind::WillBlockOn { other_runtime_id: other_id, database_key }, + }); + + let stack = self.local_state.take_query_stack(); + + let (stack, result) = DependencyGraph::block_on( + dg, + self.id(), + database_key, + other_id, + stack, + query_mutex_guard, + ); + + self.local_state.restore_query_stack(stack); + + match result { + WaitResult::Completed => (), + + // If the other thread panicked, then we consider this thread + // cancelled. The assumption is that the panic will be detected + // by the other thread and responded to appropriately. + WaitResult::Panicked => Cancelled::PropagatedPanic.throw(), + + WaitResult::Cycle(c) => c.throw(), + } + } + + /// Invoked when this runtime completed computing `database_key` with + /// the given result `wait_result` (`wait_result` should be `None` if + /// computing `database_key` panicked and could not complete). + /// This function unblocks any dependent queries and allows them + /// to continue executing. + pub(crate) fn unblock_queries_blocked_on( + &self, + database_key: DatabaseKeyIndex, + wait_result: WaitResult, + ) { + self.shared_state + .dependency_graph + .lock() + .unblock_runtimes_blocked_on(database_key, wait_result); + } +} + +/// State that will be common to all threads (when we support multiple threads) +struct SharedState { + /// Stores the next id to use for a snapshotted runtime (starts at 1). + next_id: AtomicUsize, + + /// Whenever derived queries are executing, they acquire this lock + /// in read mode. Mutating inputs (and thus creating a new + /// revision) requires a write lock (thus guaranteeing that no + /// derived queries are in progress). Note that this is not needed + /// to prevent **race conditions** -- the revision counter itself + /// is stored in an `AtomicUsize` so it can be cheaply read + /// without acquiring the lock. Rather, the `query_lock` is used + /// to ensure a higher-level consistency property. + query_lock: RwLock<()>, + + /// This is typically equal to `revision` -- set to `revision+1` + /// when a new revision is pending (which implies that the current + /// revision is cancelled). + pending_revision: AtomicRevision, + + /// Stores the "last change" revision for values of each duration. + /// This vector is always of length at least 1 (for Durability 0) + /// but its total length depends on the number of durations. The + /// element at index 0 is special as it represents the "current + /// revision". In general, we have the invariant that revisions + /// in here are *declining* -- that is, `revisions[i] >= + /// revisions[i + 1]`, for all `i`. This is because when you + /// modify a value with durability D, that implies that values + /// with durability less than D may have changed too. + revisions: Vec, + + /// The dependency graph tracks which runtimes are blocked on one + /// another, waiting for queries to terminate. + dependency_graph: Mutex, +} + +impl SharedState { + fn with_durabilities(durabilities: usize) -> Self { + SharedState { + next_id: AtomicUsize::new(1), + query_lock: Default::default(), + revisions: (0..durabilities).map(|_| AtomicRevision::start()).collect(), + pending_revision: AtomicRevision::start(), + dependency_graph: Default::default(), + } + } +} + +impl std::panic::RefUnwindSafe for SharedState {} + +impl Default for SharedState { + fn default() -> Self { + Self::with_durabilities(Durability::LEN) + } +} + +impl std::fmt::Debug for SharedState { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let query_lock = if self.query_lock.try_write().is_some() { + "" + } else if self.query_lock.try_read().is_some() { + "" + } else { + "" + }; + fmt.debug_struct("SharedState") + .field("query_lock", &query_lock) + .field("revisions", &self.revisions) + .field("pending_revision", &self.pending_revision) + .finish() + } +} + +#[derive(Debug)] +struct ActiveQuery { + /// What query is executing + database_key_index: DatabaseKeyIndex, + + /// Minimum durability of inputs observed so far. + durability: Durability, + + /// Maximum revision of all inputs observed. If we observe an + /// untracked read, this will be set to the most recent revision. + changed_at: Revision, + + /// Set of subqueries that were accessed thus far, or `None` if + /// there was an untracked the read. + dependencies: Option>, + + /// Stores the entire cycle, if one is found and this query is part of it. + cycle: Option, +} + +impl ActiveQuery { + fn new(database_key_index: DatabaseKeyIndex) -> Self { + ActiveQuery { + database_key_index, + durability: Durability::MAX, + changed_at: Revision::start(), + dependencies: Some(FxIndexSet::default()), + cycle: None, + } + } + + fn add_read(&mut self, input: DatabaseKeyIndex, durability: Durability, revision: Revision) { + if let Some(set) = &mut self.dependencies { + set.insert(input); + } + + self.durability = self.durability.min(durability); + self.changed_at = self.changed_at.max(revision); + } + + fn add_untracked_read(&mut self, changed_at: Revision) { + self.dependencies = None; + self.durability = Durability::LOW; + self.changed_at = changed_at; + } + + fn add_synthetic_read(&mut self, durability: Durability, revision: Revision) { + self.dependencies = None; + self.durability = self.durability.min(durability); + self.changed_at = self.changed_at.max(revision); + } + + pub(crate) fn revisions(&self) -> QueryRevisions { + let inputs = match &self.dependencies { + None => QueryInputs::Untracked, + + Some(dependencies) => { + if dependencies.is_empty() { + QueryInputs::NoInputs + } else { + QueryInputs::Tracked { inputs: dependencies.iter().copied().collect() } + } + } + }; + + QueryRevisions { changed_at: self.changed_at, inputs, durability: self.durability } + } + + /// Adds any dependencies from `other` into `self`. + /// Used during cycle recovery, see [`Runtime::create_cycle_error`]. + fn add_from(&mut self, other: &ActiveQuery) { + self.changed_at = self.changed_at.max(other.changed_at); + self.durability = self.durability.min(other.durability); + if let Some(other_dependencies) = &other.dependencies { + if let Some(my_dependencies) = &mut self.dependencies { + my_dependencies.extend(other_dependencies.iter().copied()); + } + } else { + self.dependencies = None; + } + } + + /// Removes the participants in `cycle` from my dependencies. + /// Used during cycle recovery, see [`Runtime::create_cycle_error`]. + fn remove_cycle_participants(&mut self, cycle: &Cycle) { + if let Some(my_dependencies) = &mut self.dependencies { + for p in cycle.participant_keys() { + my_dependencies.remove(&p); + } + } + } + + /// Copy the changed-at, durability, and dependencies from `cycle_query`. + /// Used during cycle recovery, see [`Runtime::create_cycle_error`]. + pub(crate) fn take_inputs_from(&mut self, cycle_query: &ActiveQuery) { + self.changed_at = cycle_query.changed_at; + self.durability = cycle_query.durability; + self.dependencies = cycle_query.dependencies.clone(); + } +} + +/// A unique identifier for a particular runtime. Each time you create +/// a snapshot, a fresh `RuntimeId` is generated. Once a snapshot is +/// complete, its `RuntimeId` may potentially be re-used. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct RuntimeId { + counter: usize, +} + +#[derive(Clone, Debug)] +pub(crate) struct StampedValue { + pub(crate) value: V, + pub(crate) durability: Durability, + pub(crate) changed_at: Revision, +} + +struct RevisionGuard { + shared_state: Arc, +} + +impl RevisionGuard { + fn new(shared_state: &Arc) -> Self { + // Subtle: we use a "recursive" lock here so that it is not an + // error to acquire a read-lock when one is already held (this + // happens when a query uses `snapshot` to spawn off parallel + // workers, for example). + // + // This has the side-effect that we are responsible to ensure + // that people contending for the write lock do not starve, + // but this is what we achieve via the cancellation mechanism. + // + // (In particular, since we only ever have one "mutating + // handle" to the database, the only contention for the global + // query lock occurs when there are "futures" evaluating + // queries in parallel, and those futures hold a read-lock + // already, so the starvation problem is more about them bring + // themselves to a close, versus preventing other people from + // *starting* work). + unsafe { + shared_state.query_lock.raw().lock_shared_recursive(); + } + + Self { shared_state: shared_state.clone() } + } +} + +impl Drop for RevisionGuard { + fn drop(&mut self) { + // Release our read-lock without using RAII. As documented in + // `Snapshot::new` above, this requires the unsafe keyword. + unsafe { + self.shared_state.query_lock.raw().unlock_shared(); + } + } +} diff --git a/crates/salsa/src/runtime/dependency_graph.rs b/crates/salsa/src/runtime/dependency_graph.rs new file mode 100644 index 0000000000..e41eb280de --- /dev/null +++ b/crates/salsa/src/runtime/dependency_graph.rs @@ -0,0 +1,251 @@ +//! +use triomphe::Arc; + +use crate::{DatabaseKeyIndex, RuntimeId}; +use parking_lot::{Condvar, MutexGuard}; +use rustc_hash::FxHashMap; +use smallvec::SmallVec; + +use super::{ActiveQuery, WaitResult}; + +type QueryStack = Vec; + +#[derive(Debug, Default)] +pub(super) struct DependencyGraph { + /// A `(K -> V)` pair in this map indicates that the the runtime + /// `K` is blocked on some query executing in the runtime `V`. + /// This encodes a graph that must be acyclic (or else deadlock + /// will result). + edges: FxHashMap, + + /// Encodes the `RuntimeId` that are blocked waiting for the result + /// of a given query. + query_dependents: FxHashMap>, + + /// When a key K completes which had dependent queries Qs blocked on it, + /// it stores its `WaitResult` here. As they wake up, each query Q in Qs will + /// come here to fetch their results. + wait_results: FxHashMap, +} + +#[derive(Debug)] +struct Edge { + blocked_on_id: RuntimeId, + blocked_on_key: DatabaseKeyIndex, + stack: QueryStack, + + /// Signalled whenever a query with dependents completes. + /// Allows those dependents to check if they are ready to unblock. + condvar: Arc, +} + +impl DependencyGraph { + /// True if `from_id` depends on `to_id`. + /// + /// (i.e., there is a path from `from_id` to `to_id` in the graph.) + pub(super) fn depends_on(&mut self, from_id: RuntimeId, to_id: RuntimeId) -> bool { + let mut p = from_id; + while let Some(q) = self.edges.get(&p).map(|edge| edge.blocked_on_id) { + if q == to_id { + return true; + } + + p = q; + } + p == to_id + } + + /// Invokes `closure` with a `&mut ActiveQuery` for each query that participates in the cycle. + /// The cycle runs as follows: + /// + /// 1. The runtime `from_id`, which has the stack `from_stack`, would like to invoke `database_key`... + /// 2. ...but `database_key` is already being executed by `to_id`... + /// 3. ...and `to_id` is transitively dependent on something which is present on `from_stack`. + pub(super) fn for_each_cycle_participant( + &mut self, + from_id: RuntimeId, + from_stack: &mut QueryStack, + database_key: DatabaseKeyIndex, + to_id: RuntimeId, + mut closure: impl FnMut(&mut [ActiveQuery]), + ) { + debug_assert!(self.depends_on(to_id, from_id)); + + // To understand this algorithm, consider this [drawing](https://is.gd/TGLI9v): + // + // database_key = QB2 + // from_id = A + // to_id = B + // from_stack = [QA1, QA2, QA3] + // + // self.edges[B] = { C, QC2, [QB1..QB3] } + // self.edges[C] = { A, QA2, [QC1..QC3] } + // + // The cyclic + // edge we have + // failed to add. + // : + // A : B C + // : + // QA1 v QB1 QC1 + // ┌► QA2 ┌──► QB2 ┌─► QC2 + // │ QA3 ───┘ QB3 ──┘ QC3 ───┐ + // │ │ + // └───────────────────────────────┘ + // + // Final output: [QB2, QB3, QC2, QC3, QA2, QA3] + + let mut id = to_id; + let mut key = database_key; + while id != from_id { + // Looking at the diagram above, the idea is to + // take the edge from `to_id` starting at `key` + // (inclusive) and down to the end. We can then + // load up the next thread (i.e., we start at B/QB2, + // and then load up the dependency on C/QC2). + let edge = self.edges.get_mut(&id).unwrap(); + let prefix = edge.stack.iter_mut().take_while(|p| p.database_key_index != key).count(); + closure(&mut edge.stack[prefix..]); + id = edge.blocked_on_id; + key = edge.blocked_on_key; + } + + // Finally, we copy in the results from `from_stack`. + let prefix = from_stack.iter_mut().take_while(|p| p.database_key_index != key).count(); + closure(&mut from_stack[prefix..]); + } + + /// Unblock each blocked runtime (excluding the current one) if some + /// query executing in that runtime is participating in cycle fallback. + /// + /// Returns a boolean (Current, Others) where: + /// * Current is true if the current runtime has cycle participants + /// with fallback; + /// * Others is true if other runtimes were unblocked. + pub(super) fn maybe_unblock_runtimes_in_cycle( + &mut self, + from_id: RuntimeId, + from_stack: &QueryStack, + database_key: DatabaseKeyIndex, + to_id: RuntimeId, + ) -> (bool, bool) { + // See diagram in `for_each_cycle_participant`. + let mut id = to_id; + let mut key = database_key; + let mut others_unblocked = false; + while id != from_id { + let edge = self.edges.get(&id).unwrap(); + let prefix = edge.stack.iter().take_while(|p| p.database_key_index != key).count(); + let next_id = edge.blocked_on_id; + let next_key = edge.blocked_on_key; + + if let Some(cycle) = edge.stack[prefix..].iter().rev().find_map(|aq| aq.cycle.clone()) { + // Remove `id` from the list of runtimes blocked on `next_key`: + self.query_dependents.get_mut(&next_key).unwrap().retain(|r| *r != id); + + // Unblock runtime so that it can resume execution once lock is released: + self.unblock_runtime(id, WaitResult::Cycle(cycle)); + + others_unblocked = true; + } + + id = next_id; + key = next_key; + } + + let prefix = from_stack.iter().take_while(|p| p.database_key_index != key).count(); + let this_unblocked = from_stack[prefix..].iter().any(|aq| aq.cycle.is_some()); + + (this_unblocked, others_unblocked) + } + + /// Modifies the graph so that `from_id` is blocked + /// on `database_key`, which is being computed by + /// `to_id`. + /// + /// For this to be reasonable, the lock on the + /// results table for `database_key` must be held. + /// This ensures that computing `database_key` doesn't + /// complete before `block_on` executes. + /// + /// Preconditions: + /// * No path from `to_id` to `from_id` + /// (i.e., `me.depends_on(to_id, from_id)` is false) + /// * `held_mutex` is a read lock (or stronger) on `database_key` + pub(super) fn block_on( + mut me: MutexGuard<'_, Self>, + from_id: RuntimeId, + database_key: DatabaseKeyIndex, + to_id: RuntimeId, + from_stack: QueryStack, + query_mutex_guard: QueryMutexGuard, + ) -> (QueryStack, WaitResult) { + let condvar = me.add_edge(from_id, database_key, to_id, from_stack); + + // Release the mutex that prevents `database_key` + // from completing, now that the edge has been added. + drop(query_mutex_guard); + + loop { + if let Some(stack_and_result) = me.wait_results.remove(&from_id) { + debug_assert!(!me.edges.contains_key(&from_id)); + return stack_and_result; + } + condvar.wait(&mut me); + } + } + + /// Helper for `block_on`: performs actual graph modification + /// to add a dependency edge from `from_id` to `to_id`, which is + /// computing `database_key`. + fn add_edge( + &mut self, + from_id: RuntimeId, + database_key: DatabaseKeyIndex, + to_id: RuntimeId, + from_stack: QueryStack, + ) -> Arc { + assert_ne!(from_id, to_id); + debug_assert!(!self.edges.contains_key(&from_id)); + debug_assert!(!self.depends_on(to_id, from_id)); + + let condvar = Arc::new(Condvar::new()); + self.edges.insert( + from_id, + Edge { + blocked_on_id: to_id, + blocked_on_key: database_key, + stack: from_stack, + condvar: condvar.clone(), + }, + ); + self.query_dependents.entry(database_key).or_default().push(from_id); + condvar + } + + /// Invoked when runtime `to_id` completes executing + /// `database_key`. + pub(super) fn unblock_runtimes_blocked_on( + &mut self, + database_key: DatabaseKeyIndex, + wait_result: WaitResult, + ) { + let dependents = self.query_dependents.remove(&database_key).unwrap_or_default(); + + for from_id in dependents { + self.unblock_runtime(from_id, wait_result.clone()); + } + } + + /// Unblock the runtime with the given id with the given wait-result. + /// This will cause it resume execution (though it will have to grab + /// the lock on this data structure first, to recover the wait result). + fn unblock_runtime(&mut self, id: RuntimeId, wait_result: WaitResult) { + let edge = self.edges.remove(&id).expect("not blocked"); + self.wait_results.insert(id, (edge.stack, wait_result)); + + // Now that we have inserted the `wait_results`, + // notify the thread. + edge.condvar.notify_one(); + } +} diff --git a/crates/salsa/src/runtime/local_state.rs b/crates/salsa/src/runtime/local_state.rs new file mode 100644 index 0000000000..91b95dffe7 --- /dev/null +++ b/crates/salsa/src/runtime/local_state.rs @@ -0,0 +1,214 @@ +//! +use tracing::debug; + +use crate::durability::Durability; +use crate::runtime::ActiveQuery; +use crate::runtime::Revision; +use crate::Cycle; +use crate::DatabaseKeyIndex; +use std::cell::RefCell; +use triomphe::Arc; + +/// State that is specific to a single execution thread. +/// +/// Internally, this type uses ref-cells. +/// +/// **Note also that all mutations to the database handle (and hence +/// to the local-state) must be undone during unwinding.** +pub(super) struct LocalState { + /// Vector of active queries. + /// + /// This is normally `Some`, but it is set to `None` + /// while the query is blocked waiting for a result. + /// + /// Unwinding note: pushes onto this vector must be popped -- even + /// during unwinding. + query_stack: RefCell>>, +} + +/// Summarizes "all the inputs that a query used" +#[derive(Debug, Clone)] +pub(crate) struct QueryRevisions { + /// The most revision in which some input changed. + pub(crate) changed_at: Revision, + + /// Minimum durability of the inputs to this query. + pub(crate) durability: Durability, + + /// The inputs that went into our query, if we are tracking them. + pub(crate) inputs: QueryInputs, +} + +/// Every input. +#[derive(Debug, Clone)] +pub(crate) enum QueryInputs { + /// Non-empty set of inputs, fully known + Tracked { inputs: Arc<[DatabaseKeyIndex]> }, + + /// Empty set of inputs, fully known. + NoInputs, + + /// Unknown quantity of inputs + Untracked, +} + +impl Default for LocalState { + fn default() -> Self { + LocalState { query_stack: RefCell::new(Some(Vec::new())) } + } +} + +impl LocalState { + #[inline] + pub(super) fn push_query(&self, database_key_index: DatabaseKeyIndex) -> ActiveQueryGuard<'_> { + let mut query_stack = self.query_stack.borrow_mut(); + let query_stack = query_stack.as_mut().expect("local stack taken"); + query_stack.push(ActiveQuery::new(database_key_index)); + ActiveQueryGuard { local_state: self, database_key_index, push_len: query_stack.len() } + } + + fn with_query_stack(&self, c: impl FnOnce(&mut Vec) -> R) -> R { + c(self.query_stack.borrow_mut().as_mut().expect("query stack taken")) + } + + pub(super) fn query_in_progress(&self) -> bool { + self.with_query_stack(|stack| !stack.is_empty()) + } + + pub(super) fn active_query(&self) -> Option { + self.with_query_stack(|stack| { + stack.last().map(|active_query| active_query.database_key_index) + }) + } + + pub(super) fn report_query_read_and_unwind_if_cycle_resulted( + &self, + input: DatabaseKeyIndex, + durability: Durability, + changed_at: Revision, + ) { + debug!( + "report_query_read_and_unwind_if_cycle_resulted(input={:?}, durability={:?}, changed_at={:?})", + input, durability, changed_at + ); + self.with_query_stack(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_read(input, durability, changed_at); + + // We are a cycle participant: + // + // C0 --> ... --> Ci --> Ci+1 -> ... -> Cn --> C0 + // ^ ^ + // : | + // This edge -----+ | + // | + // | + // N0 + // + // In this case, the value we have just read from `Ci+1` + // is actually the cycle fallback value and not especially + // interesting. We unwind now with `CycleParticipant` to avoid + // executing the rest of our query function. This unwinding + // will be caught and our own fallback value will be used. + // + // Note that `Ci+1` may` have *other* callers who are not + // participants in the cycle (e.g., N0 in the graph above). + // They will not have the `cycle` marker set in their + // stack frames, so they will just read the fallback value + // from `Ci+1` and continue on their merry way. + if let Some(cycle) = &top_query.cycle { + cycle.clone().throw() + } + } + }) + } + + pub(super) fn report_untracked_read(&self, current_revision: Revision) { + self.with_query_stack(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_untracked_read(current_revision); + } + }) + } + + /// Update the top query on the stack to act as though it read a value + /// of durability `durability` which changed in `revision`. + pub(super) fn report_synthetic_read(&self, durability: Durability, revision: Revision) { + self.with_query_stack(|stack| { + if let Some(top_query) = stack.last_mut() { + top_query.add_synthetic_read(durability, revision); + } + }) + } + + /// Takes the query stack and returns it. This is used when + /// the current thread is blocking. The stack must be restored + /// with [`Self::restore_query_stack`] when the thread unblocks. + pub(super) fn take_query_stack(&self) -> Vec { + assert!(self.query_stack.borrow().is_some(), "query stack already taken"); + self.query_stack.take().unwrap() + } + + /// Restores a query stack taken with [`Self::take_query_stack`] once + /// the thread unblocks. + pub(super) fn restore_query_stack(&self, stack: Vec) { + assert!(self.query_stack.borrow().is_none(), "query stack not taken"); + self.query_stack.replace(Some(stack)); + } +} + +impl std::panic::RefUnwindSafe for LocalState {} + +/// When a query is pushed onto the `active_query` stack, this guard +/// is returned to represent its slot. The guard can be used to pop +/// the query from the stack -- in the case of unwinding, the guard's +/// destructor will also remove the query. +pub(crate) struct ActiveQueryGuard<'me> { + local_state: &'me LocalState, + push_len: usize, + database_key_index: DatabaseKeyIndex, +} + +impl ActiveQueryGuard<'_> { + fn pop_helper(&self) -> ActiveQuery { + self.local_state.with_query_stack(|stack| { + // Sanity check: pushes and pops should be balanced. + assert_eq!(stack.len(), self.push_len); + debug_assert_eq!(stack.last().unwrap().database_key_index, self.database_key_index); + stack.pop().unwrap() + }) + } + + /// Invoked when the query has successfully completed execution. + pub(super) fn complete(self) -> ActiveQuery { + let query = self.pop_helper(); + std::mem::forget(self); + query + } + + /// Pops an active query from the stack. Returns the [`QueryRevisions`] + /// which summarizes the other queries that were accessed during this + /// query's execution. + #[inline] + pub(crate) fn pop(self) -> QueryRevisions { + // Extract accumulated inputs. + let popped_query = self.complete(); + + // If this frame were a cycle participant, it would have unwound. + assert!(popped_query.cycle.is_none()); + + popped_query.revisions() + } + + /// If the active query is registered as a cycle participant, remove and + /// return that cycle. + pub(crate) fn take_cycle(&self) -> Option { + self.local_state.with_query_stack(|stack| stack.last_mut()?.cycle.take()) + } +} + +impl Drop for ActiveQueryGuard<'_> { + fn drop(&mut self) { + self.pop_helper(); + } +} diff --git a/crates/salsa/src/storage.rs b/crates/salsa/src/storage.rs new file mode 100644 index 0000000000..c0e6416f4a --- /dev/null +++ b/crates/salsa/src/storage.rs @@ -0,0 +1,54 @@ +//! +use crate::{plumbing::DatabaseStorageTypes, Runtime}; +use triomphe::Arc; + +/// Stores the cached results and dependency information for all the queries +/// defined on your salsa database. Also embeds a [`Runtime`] which is used to +/// manage query execution. Every database must include a `storage: +/// Storage` field. +pub struct Storage { + query_store: Arc, + runtime: Runtime, +} + +impl Default for Storage { + fn default() -> Self { + Self { query_store: Default::default(), runtime: Default::default() } + } +} + +impl Storage { + /// Gives access to the underlying salsa runtime. + pub fn salsa_runtime(&self) -> &Runtime { + &self.runtime + } + + /// Gives access to the underlying salsa runtime. + pub fn salsa_runtime_mut(&mut self) -> &mut Runtime { + &mut self.runtime + } + + /// Access the query storage tables. Not meant to be used directly by end + /// users. + pub fn query_store(&self) -> &DB::DatabaseStorage { + &self.query_store + } + + /// Access the query storage tables. Not meant to be used directly by end + /// users. + pub fn query_store_mut(&mut self) -> (&DB::DatabaseStorage, &mut Runtime) { + (&self.query_store, &mut self.runtime) + } + + /// Returns a "snapshotted" storage, suitable for use in a forked database. + /// This snapshot hold a read-lock on the global state, which means that any + /// attempt to `set` an input will block until the forked runtime is + /// dropped. See `ParallelDatabase::snapshot` for more information. + /// + /// **Warning.** This second handle is intended to be used from a separate + /// thread. Using two database handles from the **same thread** can lead to + /// deadlock. + pub fn snapshot(&self) -> Self { + Storage { query_store: self.query_store.clone(), runtime: self.runtime.snapshot() } + } +} diff --git a/crates/salsa/tests/cycles.rs b/crates/salsa/tests/cycles.rs new file mode 100644 index 0000000000..00ca533244 --- /dev/null +++ b/crates/salsa/tests/cycles.rs @@ -0,0 +1,493 @@ +use std::panic::UnwindSafe; + +use expect_test::expect; +use salsa::{Durability, ParallelDatabase, Snapshot}; +use test_log::test; + +// Axes: +// +// Threading +// * Intra-thread +// * Cross-thread -- part of cycle is on one thread, part on another +// +// Recovery strategies: +// * Panic +// * Fallback +// * Mixed -- multiple strategies within cycle participants +// +// Across revisions: +// * N/A -- only one revision +// * Present in new revision, not old +// * Present in old revision, not new +// * Present in both revisions +// +// Dependencies +// * Tracked +// * Untracked -- cycle participant(s) contain untracked reads +// +// Layers +// * Direct -- cycle participant is directly invoked from test +// * Indirect -- invoked a query that invokes the cycle +// +// +// | Thread | Recovery | Old, New | Dep style | Layers | Test Name | +// | ------ | -------- | -------- | --------- | ------ | --------- | +// | Intra | Panic | N/A | Tracked | direct | cycle_memoized | +// | Intra | Panic | N/A | Untracked | direct | cycle_volatile | +// | Intra | Fallback | N/A | Tracked | direct | cycle_cycle | +// | Intra | Fallback | N/A | Tracked | indirect | inner_cycle | +// | Intra | Fallback | Both | Tracked | direct | cycle_revalidate | +// | Intra | Fallback | New | Tracked | direct | cycle_appears | +// | Intra | Fallback | Old | Tracked | direct | cycle_disappears | +// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability | +// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 | +// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 | +// | Cross | Fallback | N/A | Tracked | both | parallel/cycles.rs: recover_parallel_cycle | +// | Cross | Panic | N/A | Tracked | both | parallel/cycles.rs: panic_parallel_cycle | + +#[derive(PartialEq, Eq, Hash, Clone, Debug)] +struct Error { + cycle: Vec, +} + +#[salsa::database(GroupStruct)] +#[derive(Default)] +struct DatabaseImpl { + storage: salsa::Storage, +} + +impl salsa::Database for DatabaseImpl {} + +impl ParallelDatabase for DatabaseImpl { + fn snapshot(&self) -> Snapshot { + Snapshot::new(DatabaseImpl { storage: self.storage.snapshot() }) + } +} + +/// The queries A, B, and C in `Database` can be configured +/// to invoke one another in arbitrary ways using this +/// enum. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum CycleQuery { + None, + A, + B, + C, + AthenC, +} + +#[salsa::query_group(GroupStruct)] +trait Database: salsa::Database { + // `a` and `b` depend on each other and form a cycle + fn memoized_a(&self) -> (); + fn memoized_b(&self) -> (); + fn volatile_a(&self) -> (); + fn volatile_b(&self) -> (); + + #[salsa::input] + fn a_invokes(&self) -> CycleQuery; + + #[salsa::input] + fn b_invokes(&self) -> CycleQuery; + + #[salsa::input] + fn c_invokes(&self) -> CycleQuery; + + #[salsa::cycle(recover_a)] + fn cycle_a(&self) -> Result<(), Error>; + + #[salsa::cycle(recover_b)] + fn cycle_b(&self) -> Result<(), Error>; + + fn cycle_c(&self) -> Result<(), Error>; +} + +fn recover_a(db: &dyn Database, cycle: &salsa::Cycle) -> Result<(), Error> { + Err(Error { cycle: cycle.all_participants(db) }) +} + +fn recover_b(db: &dyn Database, cycle: &salsa::Cycle) -> Result<(), Error> { + Err(Error { cycle: cycle.all_participants(db) }) +} + +fn memoized_a(db: &dyn Database) { + db.memoized_b() +} + +fn memoized_b(db: &dyn Database) { + db.memoized_a() +} + +fn volatile_a(db: &dyn Database) { + db.salsa_runtime().report_untracked_read(); + db.volatile_b() +} + +fn volatile_b(db: &dyn Database) { + db.salsa_runtime().report_untracked_read(); + db.volatile_a() +} + +impl CycleQuery { + fn invoke(self, db: &dyn Database) -> Result<(), Error> { + match self { + CycleQuery::A => db.cycle_a(), + CycleQuery::B => db.cycle_b(), + CycleQuery::C => db.cycle_c(), + CycleQuery::AthenC => { + let _ = db.cycle_a(); + db.cycle_c() + } + CycleQuery::None => Ok(()), + } + } +} + +fn cycle_a(db: &dyn Database) -> Result<(), Error> { + db.a_invokes().invoke(db) +} + +fn cycle_b(db: &dyn Database) -> Result<(), Error> { + db.b_invokes().invoke(db) +} + +fn cycle_c(db: &dyn Database) -> Result<(), Error> { + db.c_invokes().invoke(db) +} + +#[track_caller] +fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle { + let v = std::panic::catch_unwind(f); + if let Err(d) = &v { + if let Some(cycle) = d.downcast_ref::() { + return cycle.clone(); + } + } + panic!("unexpected value: {:?}", v) +} + +#[test] +fn cycle_memoized() { + let db = DatabaseImpl::default(); + let cycle = extract_cycle(|| db.memoized_a()); + expect![[r#" + [ + "memoized_a(())", + "memoized_b(())", + ] + "#]] + .assert_debug_eq(&cycle.unexpected_participants(&db)); +} + +#[test] +fn cycle_volatile() { + let db = DatabaseImpl::default(); + let cycle = extract_cycle(|| db.volatile_a()); + expect![[r#" + [ + "volatile_a(())", + "volatile_b(())", + ] + "#]] + .assert_debug_eq(&cycle.unexpected_participants(&db)); +} + +#[test] +fn cycle_cycle() { + let mut query = DatabaseImpl::default(); + + // A --> B + // ^ | + // +-----+ + + query.set_a_invokes(CycleQuery::B); + query.set_b_invokes(CycleQuery::A); + + assert!(query.cycle_a().is_err()); +} + +#[test] +fn inner_cycle() { + let mut query = DatabaseImpl::default(); + + // A --> B <-- C + // ^ | + // +-----+ + + query.set_a_invokes(CycleQuery::B); + query.set_b_invokes(CycleQuery::A); + query.set_c_invokes(CycleQuery::B); + + let err = query.cycle_c(); + assert!(err.is_err()); + let cycle = err.unwrap_err().cycle; + expect![[r#" + [ + "cycle_a(())", + "cycle_b(())", + ] + "#]] + .assert_debug_eq(&cycle); +} + +#[test] +fn cycle_revalidate() { + let mut db = DatabaseImpl::default(); + + // A --> B + // ^ | + // +-----+ + db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::A); + + assert!(db.cycle_a().is_err()); + db.set_b_invokes(CycleQuery::A); // same value as default + assert!(db.cycle_a().is_err()); +} + +#[test] +fn cycle_revalidate_unchanged_twice() { + let mut db = DatabaseImpl::default(); + + // A --> B + // ^ | + // +-----+ + db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::A); + + assert!(db.cycle_a().is_err()); + db.set_c_invokes(CycleQuery::A); // force new revisi5on + + // on this run + expect![[r#" + Err( + Error { + cycle: [ + "cycle_a(())", + "cycle_b(())", + ], + }, + ) + "#]] + .assert_debug_eq(&db.cycle_a()); +} + +#[test] +fn cycle_appears() { + let mut db = DatabaseImpl::default(); + + // A --> B + db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::None); + assert!(db.cycle_a().is_ok()); + + // A --> B + // ^ | + // +-----+ + db.set_b_invokes(CycleQuery::A); + tracing::debug!("Set Cycle Leaf"); + assert!(db.cycle_a().is_err()); +} + +#[test] +fn cycle_disappears() { + let mut db = DatabaseImpl::default(); + + // A --> B + // ^ | + // +-----+ + db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::A); + assert!(db.cycle_a().is_err()); + + // A --> B + db.set_b_invokes(CycleQuery::None); + assert!(db.cycle_a().is_ok()); +} + +/// A variant on `cycle_disappears` in which the values of +/// `a_invokes` and `b_invokes` are set with durability values. +/// If we are not careful, this could cause us to overlook +/// the fact that the cycle will no longer occur. +#[test] +fn cycle_disappears_durability() { + let mut db = DatabaseImpl::default(); + db.set_a_invokes_with_durability(CycleQuery::B, Durability::LOW); + db.set_b_invokes_with_durability(CycleQuery::A, Durability::HIGH); + + let res = db.cycle_a(); + assert!(res.is_err()); + + // At this point, `a` read `LOW` input, and `b` read `HIGH` input. However, + // because `b` participates in the same cycle as `a`, its final durability + // should be `LOW`. + // + // Check that setting a `LOW` input causes us to re-execute `b` query, and + // observe that the cycle goes away. + db.set_a_invokes_with_durability(CycleQuery::None, Durability::LOW); + + let res = db.cycle_b(); + assert!(res.is_ok()); +} + +#[test] +fn cycle_mixed_1() { + let mut db = DatabaseImpl::default(); + // A --> B <-- C + // | ^ + // +-----+ + db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::C); + db.set_c_invokes(CycleQuery::B); + + let u = db.cycle_c(); + expect![[r#" + Err( + Error { + cycle: [ + "cycle_b(())", + "cycle_c(())", + ], + }, + ) + "#]] + .assert_debug_eq(&u); +} + +#[test] +fn cycle_mixed_2() { + let mut db = DatabaseImpl::default(); + + // Configuration: + // + // A --> B --> C + // ^ | + // +-----------+ + db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::C); + db.set_c_invokes(CycleQuery::A); + + let u = db.cycle_a(); + expect![[r#" + Err( + Error { + cycle: [ + "cycle_a(())", + "cycle_b(())", + "cycle_c(())", + ], + }, + ) + "#]] + .assert_debug_eq(&u); +} + +#[test] +fn cycle_deterministic_order() { + // No matter whether we start from A or B, we get the same set of participants: + let db = || { + let mut db = DatabaseImpl::default(); + // A --> B + // ^ | + // +-----+ + db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::A); + db + }; + let a = db().cycle_a(); + let b = db().cycle_b(); + expect![[r#" + ( + Err( + Error { + cycle: [ + "cycle_a(())", + "cycle_b(())", + ], + }, + ), + Err( + Error { + cycle: [ + "cycle_a(())", + "cycle_b(())", + ], + }, + ), + ) + "#]] + .assert_debug_eq(&(a, b)); +} + +#[test] +fn cycle_multiple() { + // No matter whether we start from A or B, we get the same set of participants: + let mut db = DatabaseImpl::default(); + + // Configuration: + // + // A --> B <-- C + // ^ | ^ + // +-----+ | + // | | + // +-----+ + // + // Here, conceptually, B encounters a cycle with A and then + // recovers. + db.set_a_invokes(CycleQuery::B); + db.set_b_invokes(CycleQuery::AthenC); + db.set_c_invokes(CycleQuery::B); + + let c = db.cycle_c(); + let b = db.cycle_b(); + let a = db.cycle_a(); + expect![[r#" + ( + Err( + Error { + cycle: [ + "cycle_a(())", + "cycle_b(())", + ], + }, + ), + Err( + Error { + cycle: [ + "cycle_a(())", + "cycle_b(())", + ], + }, + ), + Err( + Error { + cycle: [ + "cycle_a(())", + "cycle_b(())", + ], + }, + ), + ) + "#]] + .assert_debug_eq(&(a, b, c)); +} + +#[test] +fn cycle_recovery_set_but_not_participating() { + let mut db = DatabaseImpl::default(); + + // A --> C -+ + // ^ | + // +--+ + db.set_a_invokes(CycleQuery::C); + db.set_c_invokes(CycleQuery::C); + + // Here we expect C to panic and A not to recover: + let r = extract_cycle(|| drop(db.cycle_a())); + expect![[r#" + [ + "cycle_c(())", + ] + "#]] + .assert_debug_eq(&r.all_participants(&db)); +} diff --git a/crates/salsa/tests/dyn_trait.rs b/crates/salsa/tests/dyn_trait.rs new file mode 100644 index 0000000000..09ebc5c4ce --- /dev/null +++ b/crates/salsa/tests/dyn_trait.rs @@ -0,0 +1,28 @@ +//! Test that you can implement a query using a `dyn Trait` setup. + +#[salsa::database(DynTraitStorage)] +#[derive(Default)] +struct DynTraitDatabase { + storage: salsa::Storage, +} + +impl salsa::Database for DynTraitDatabase {} + +#[salsa::query_group(DynTraitStorage)] +trait DynTrait { + #[salsa::input] + fn input(&self, x: u32) -> u32; + + fn output(&self, x: u32) -> u32; +} + +fn output(db: &dyn DynTrait, x: u32) -> u32 { + db.input(x) * 2 +} + +#[test] +fn dyn_trait() { + let mut query = DynTraitDatabase::default(); + query.set_input(22, 23); + assert_eq!(query.output(22), 46); +} diff --git a/crates/salsa/tests/incremental/constants.rs b/crates/salsa/tests/incremental/constants.rs new file mode 100644 index 0000000000..ea0eb81978 --- /dev/null +++ b/crates/salsa/tests/incremental/constants.rs @@ -0,0 +1,145 @@ +use crate::implementation::{TestContext, TestContextImpl}; +use salsa::debug::DebugQueryTable; +use salsa::Durability; + +#[salsa::query_group(Constants)] +pub(crate) trait ConstantsDatabase: TestContext { + #[salsa::input] + fn input(&self, key: char) -> usize; + + fn add(&self, key1: char, key2: char) -> usize; + + fn add3(&self, key1: char, key2: char, key3: char) -> usize; +} + +fn add(db: &dyn ConstantsDatabase, key1: char, key2: char) -> usize { + db.log().add(format!("add({}, {})", key1, key2)); + db.input(key1) + db.input(key2) +} + +fn add3(db: &dyn ConstantsDatabase, key1: char, key2: char, key3: char) -> usize { + db.log().add(format!("add3({}, {}, {})", key1, key2, key3)); + db.add(key1, key2) + db.input(key3) +} + +// Test we can assign a constant and things will be correctly +// recomputed afterwards. +#[test] +fn invalidate_constant() { + let db = &mut TestContextImpl::default(); + db.set_input_with_durability('a', 44, Durability::HIGH); + db.set_input_with_durability('b', 22, Durability::HIGH); + assert_eq!(db.add('a', 'b'), 66); + + db.set_input_with_durability('a', 66, Durability::HIGH); + assert_eq!(db.add('a', 'b'), 88); +} + +#[test] +fn invalidate_constant_1() { + let db = &mut TestContextImpl::default(); + + // Not constant: + db.set_input('a', 44); + assert_eq!(db.add('a', 'a'), 88); + + // Becomes constant: + db.set_input_with_durability('a', 44, Durability::HIGH); + assert_eq!(db.add('a', 'a'), 88); + + // Invalidates: + db.set_input_with_durability('a', 33, Durability::HIGH); + assert_eq!(db.add('a', 'a'), 66); +} + +// Test cases where we assign same value to 'a' after declaring it a +// constant. +#[test] +fn set_after_constant_same_value() { + let db = &mut TestContextImpl::default(); + db.set_input_with_durability('a', 44, Durability::HIGH); + db.set_input_with_durability('a', 44, Durability::HIGH); + db.set_input('a', 44); +} + +#[test] +fn not_constant() { + let mut db = TestContextImpl::default(); + + db.set_input('a', 22); + db.set_input('b', 44); + assert_eq!(db.add('a', 'b'), 66); + assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b'))); +} + +#[test] +fn durability() { + let mut db = TestContextImpl::default(); + + db.set_input_with_durability('a', 22, Durability::HIGH); + db.set_input_with_durability('b', 44, Durability::HIGH); + assert_eq!(db.add('a', 'b'), 66); + assert_eq!(Durability::HIGH, AddQuery.in_db(&db).durability(('a', 'b'))); +} + +#[test] +fn mixed_constant() { + let mut db = TestContextImpl::default(); + + db.set_input_with_durability('a', 22, Durability::HIGH); + db.set_input('b', 44); + assert_eq!(db.add('a', 'b'), 66); + assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b'))); +} + +#[test] +fn becomes_constant_with_change() { + let mut db = TestContextImpl::default(); + + db.set_input('a', 22); + db.set_input('b', 44); + assert_eq!(db.add('a', 'b'), 66); + assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b'))); + + db.set_input_with_durability('a', 23, Durability::HIGH); + assert_eq!(db.add('a', 'b'), 67); + assert_eq!(Durability::LOW, AddQuery.in_db(&db).durability(('a', 'b'))); + + db.set_input_with_durability('b', 45, Durability::HIGH); + assert_eq!(db.add('a', 'b'), 68); + assert_eq!(Durability::HIGH, AddQuery.in_db(&db).durability(('a', 'b'))); + + db.set_input_with_durability('b', 45, Durability::MEDIUM); + assert_eq!(db.add('a', 'b'), 68); + assert_eq!(Durability::MEDIUM, AddQuery.in_db(&db).durability(('a', 'b'))); +} + +// Test a subtle case in which an input changes from constant to +// non-constant, but its value doesn't change. If we're not careful, +// this can cause us to incorrectly consider derived values as still +// being constant. +#[test] +fn constant_to_non_constant() { + let mut db = TestContextImpl::default(); + + db.set_input_with_durability('a', 11, Durability::HIGH); + db.set_input_with_durability('b', 22, Durability::HIGH); + db.set_input_with_durability('c', 33, Durability::HIGH); + + // Here, `add3` invokes `add`, which yields 33. Both calls are + // constant. + assert_eq!(db.add3('a', 'b', 'c'), 66); + + db.set_input('a', 11); + + // Here, `add3` invokes `add`, which *still* yields 33, but which + // is no longer constant. Since value didn't change, we might + // preserve `add3` unchanged, not noticing that it is no longer + // constant. + assert_eq!(db.add3('a', 'b', 'c'), 66); + + // In that case, we would not get the correct result here, when + // 'a' changes *again*. + db.set_input('a', 22); + assert_eq!(db.add3('a', 'b', 'c'), 77); +} diff --git a/crates/salsa/tests/incremental/counter.rs b/crates/salsa/tests/incremental/counter.rs new file mode 100644 index 0000000000..c04857e24c --- /dev/null +++ b/crates/salsa/tests/incremental/counter.rs @@ -0,0 +1,14 @@ +use std::cell::Cell; + +#[derive(Default)] +pub(crate) struct Counter { + value: Cell, +} + +impl Counter { + pub(crate) fn increment(&self) -> usize { + let v = self.value.get(); + self.value.set(v + 1); + v + } +} diff --git a/crates/salsa/tests/incremental/implementation.rs b/crates/salsa/tests/incremental/implementation.rs new file mode 100644 index 0000000000..19752bba00 --- /dev/null +++ b/crates/salsa/tests/incremental/implementation.rs @@ -0,0 +1,59 @@ +use crate::constants; +use crate::counter::Counter; +use crate::log::Log; +use crate::memoized_dep_inputs; +use crate::memoized_inputs; +use crate::memoized_volatile; + +pub(crate) trait TestContext: salsa::Database { + fn clock(&self) -> &Counter; + fn log(&self) -> &Log; +} + +#[salsa::database( + constants::Constants, + memoized_dep_inputs::MemoizedDepInputs, + memoized_inputs::MemoizedInputs, + memoized_volatile::MemoizedVolatile +)] +#[derive(Default)] +pub(crate) struct TestContextImpl { + storage: salsa::Storage, + clock: Counter, + log: Log, +} + +impl TestContextImpl { + #[track_caller] + pub(crate) fn assert_log(&self, expected_log: &[&str]) { + let expected_text = &format!("{:#?}", expected_log); + let actual_text = &format!("{:#?}", self.log().take()); + + if expected_text == actual_text { + return; + } + + #[allow(clippy::print_stdout)] + for diff in dissimilar::diff(expected_text, actual_text) { + match diff { + dissimilar::Chunk::Delete(l) => println!("-{}", l), + dissimilar::Chunk::Equal(l) => println!(" {}", l), + dissimilar::Chunk::Insert(r) => println!("+{}", r), + } + } + + panic!("incorrect log results"); + } +} + +impl TestContext for TestContextImpl { + fn clock(&self) -> &Counter { + &self.clock + } + + fn log(&self) -> &Log { + &self.log + } +} + +impl salsa::Database for TestContextImpl {} diff --git a/crates/salsa/tests/incremental/log.rs b/crates/salsa/tests/incremental/log.rs new file mode 100644 index 0000000000..1ee57fe667 --- /dev/null +++ b/crates/salsa/tests/incremental/log.rs @@ -0,0 +1,16 @@ +use std::cell::RefCell; + +#[derive(Default)] +pub(crate) struct Log { + data: RefCell>, +} + +impl Log { + pub(crate) fn add(&self, text: impl Into) { + self.data.borrow_mut().push(text.into()); + } + + pub(crate) fn take(&self) -> Vec { + self.data.take() + } +} diff --git a/crates/salsa/tests/incremental/main.rs b/crates/salsa/tests/incremental/main.rs new file mode 100644 index 0000000000..bcd13c75f7 --- /dev/null +++ b/crates/salsa/tests/incremental/main.rs @@ -0,0 +1,9 @@ +mod constants; +mod counter; +mod implementation; +mod log; +mod memoized_dep_inputs; +mod memoized_inputs; +mod memoized_volatile; + +fn main() {} diff --git a/crates/salsa/tests/incremental/memoized_dep_inputs.rs b/crates/salsa/tests/incremental/memoized_dep_inputs.rs new file mode 100644 index 0000000000..4ea33e0c1a --- /dev/null +++ b/crates/salsa/tests/incremental/memoized_dep_inputs.rs @@ -0,0 +1,60 @@ +use crate::implementation::{TestContext, TestContextImpl}; + +#[salsa::query_group(MemoizedDepInputs)] +pub(crate) trait MemoizedDepInputsContext: TestContext { + fn dep_memoized2(&self) -> usize; + fn dep_memoized1(&self) -> usize; + #[salsa::dependencies] + fn dep_derived1(&self) -> usize; + #[salsa::input] + fn dep_input1(&self) -> usize; + #[salsa::input] + fn dep_input2(&self) -> usize; +} + +fn dep_memoized2(db: &dyn MemoizedDepInputsContext) -> usize { + db.log().add("Memoized2 invoked"); + db.dep_memoized1() +} + +fn dep_memoized1(db: &dyn MemoizedDepInputsContext) -> usize { + db.log().add("Memoized1 invoked"); + db.dep_derived1() * 2 +} + +fn dep_derived1(db: &dyn MemoizedDepInputsContext) -> usize { + db.log().add("Derived1 invoked"); + db.dep_input1() / 2 +} + +#[test] +fn revalidate() { + let db = &mut TestContextImpl::default(); + + db.set_dep_input1(0); + + // Initial run starts from Memoized2: + let v = db.dep_memoized2(); + assert_eq!(v, 0); + db.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Derived1 invoked"]); + + // After that, we first try to validate Memoized1 but wind up + // running Memoized2. Note that we don't try to validate + // Derived1, so it is invoked by Memoized1. + db.set_dep_input1(44); + let v = db.dep_memoized2(); + assert_eq!(v, 44); + db.assert_log(&["Memoized1 invoked", "Derived1 invoked", "Memoized2 invoked"]); + + // Here validation of Memoized1 succeeds so Memoized2 never runs. + db.set_dep_input1(45); + let v = db.dep_memoized2(); + assert_eq!(v, 44); + db.assert_log(&["Memoized1 invoked", "Derived1 invoked"]); + + // Here, a change to input2 doesn't affect us, so nothing runs. + db.set_dep_input2(45); + let v = db.dep_memoized2(); + assert_eq!(v, 44); + db.assert_log(&[]); +} diff --git a/crates/salsa/tests/incremental/memoized_inputs.rs b/crates/salsa/tests/incremental/memoized_inputs.rs new file mode 100644 index 0000000000..53d2ace887 --- /dev/null +++ b/crates/salsa/tests/incremental/memoized_inputs.rs @@ -0,0 +1,76 @@ +use crate::implementation::{TestContext, TestContextImpl}; + +#[salsa::query_group(MemoizedInputs)] +pub(crate) trait MemoizedInputsContext: TestContext { + fn max(&self) -> usize; + #[salsa::input] + fn input1(&self) -> usize; + #[salsa::input] + fn input2(&self) -> usize; +} + +fn max(db: &dyn MemoizedInputsContext) -> usize { + db.log().add("Max invoked"); + std::cmp::max(db.input1(), db.input2()) +} + +#[test] +fn revalidate() { + let db = &mut TestContextImpl::default(); + + db.set_input1(0); + db.set_input2(0); + + let v = db.max(); + assert_eq!(v, 0); + db.assert_log(&["Max invoked"]); + + let v = db.max(); + assert_eq!(v, 0); + db.assert_log(&[]); + + db.set_input1(44); + db.assert_log(&[]); + + let v = db.max(); + assert_eq!(v, 44); + db.assert_log(&["Max invoked"]); + + let v = db.max(); + assert_eq!(v, 44); + db.assert_log(&[]); + + db.set_input1(44); + db.assert_log(&[]); + db.set_input2(66); + db.assert_log(&[]); + db.set_input1(64); + db.assert_log(&[]); + + let v = db.max(); + assert_eq!(v, 66); + db.assert_log(&["Max invoked"]); + + let v = db.max(); + assert_eq!(v, 66); + db.assert_log(&[]); +} + +/// Test that invoking `set` on an input with the same value still +/// triggers a new revision. +#[test] +fn set_after_no_change() { + let db = &mut TestContextImpl::default(); + + db.set_input2(0); + + db.set_input1(44); + let v = db.max(); + assert_eq!(v, 44); + db.assert_log(&["Max invoked"]); + + db.set_input1(44); + let v = db.max(); + assert_eq!(v, 44); + db.assert_log(&["Max invoked"]); +} diff --git a/crates/salsa/tests/incremental/memoized_volatile.rs b/crates/salsa/tests/incremental/memoized_volatile.rs new file mode 100644 index 0000000000..6dc5030063 --- /dev/null +++ b/crates/salsa/tests/incremental/memoized_volatile.rs @@ -0,0 +1,77 @@ +use crate::implementation::{TestContext, TestContextImpl}; +use salsa::{Database, Durability}; + +#[salsa::query_group(MemoizedVolatile)] +pub(crate) trait MemoizedVolatileContext: TestContext { + // Queries for testing a "volatile" value wrapped by + // memoization. + fn memoized2(&self) -> usize; + fn memoized1(&self) -> usize; + fn volatile(&self) -> usize; +} + +fn memoized2(db: &dyn MemoizedVolatileContext) -> usize { + db.log().add("Memoized2 invoked"); + db.memoized1() +} + +fn memoized1(db: &dyn MemoizedVolatileContext) -> usize { + db.log().add("Memoized1 invoked"); + let v = db.volatile(); + v / 2 +} + +fn volatile(db: &dyn MemoizedVolatileContext) -> usize { + db.log().add("Volatile invoked"); + db.salsa_runtime().report_untracked_read(); + db.clock().increment() +} + +#[test] +fn volatile_x2() { + let query = TestContextImpl::default(); + + // Invoking volatile twice doesn't execute twice, because volatile + // queries are memoized by default. + query.volatile(); + query.volatile(); + query.assert_log(&["Volatile invoked"]); +} + +/// Test that: +/// +/// - On the first run of R0, we recompute everything. +/// - On the second run of R1, we recompute nothing. +/// - On the first run of R1, we recompute Memoized1 but not Memoized2 (since Memoized1 result +/// did not change). +/// - On the second run of R1, we recompute nothing. +/// - On the first run of R2, we recompute everything (since Memoized1 result *did* change). +#[test] +fn revalidate() { + let mut query = TestContextImpl::default(); + + query.memoized2(); + query.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Volatile invoked"]); + + query.memoized2(); + query.assert_log(&[]); + + // Second generation: volatile will change (to 1) but memoized1 + // will not (still 0, as 1/2 = 0) + query.salsa_runtime_mut().synthetic_write(Durability::LOW); + query.memoized2(); + query.assert_log(&["Volatile invoked", "Memoized1 invoked"]); + query.memoized2(); + query.assert_log(&[]); + + // Third generation: volatile will change (to 2) and memoized1 + // will too (to 1). Therefore, after validating that Memoized1 + // changed, we now invoke Memoized2. + query.salsa_runtime_mut().synthetic_write(Durability::LOW); + + query.memoized2(); + query.assert_log(&["Volatile invoked", "Memoized1 invoked", "Memoized2 invoked"]); + + query.memoized2(); + query.assert_log(&[]); +} diff --git a/crates/salsa/tests/interned.rs b/crates/salsa/tests/interned.rs new file mode 100644 index 0000000000..d097e41cfd --- /dev/null +++ b/crates/salsa/tests/interned.rs @@ -0,0 +1,90 @@ +//! Test that you can implement a query using a `dyn Trait` setup. + +use salsa::InternId; + +#[salsa::database(InternStorage)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, +} + +impl salsa::Database for Database {} + +impl salsa::ParallelDatabase for Database { + fn snapshot(&self) -> salsa::Snapshot { + salsa::Snapshot::new(Database { storage: self.storage.snapshot() }) + } +} + +#[salsa::query_group(InternStorage)] +trait Intern { + #[salsa::interned] + fn intern1(&self, x: String) -> InternId; + + #[salsa::interned] + fn intern2(&self, x: String, y: String) -> InternId; + + #[salsa::interned] + fn intern_key(&self, x: String) -> InternKey; +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct InternKey(InternId); + +impl salsa::InternKey for InternKey { + fn from_intern_id(v: InternId) -> Self { + InternKey(v) + } + + fn as_intern_id(&self) -> InternId { + self.0 + } +} + +#[test] +fn test_intern1() { + let db = Database::default(); + let foo0 = db.intern1("foo".to_owned()); + let bar0 = db.intern1("bar".to_owned()); + let foo1 = db.intern1("foo".to_owned()); + let bar1 = db.intern1("bar".to_owned()); + + assert_eq!(foo0, foo1); + assert_eq!(bar0, bar1); + assert_ne!(foo0, bar0); + + assert_eq!("foo".to_owned(), db.lookup_intern1(foo0)); + assert_eq!("bar".to_owned(), db.lookup_intern1(bar0)); +} + +#[test] +fn test_intern2() { + let db = Database::default(); + let foo0 = db.intern2("x".to_owned(), "foo".to_owned()); + let bar0 = db.intern2("x".to_owned(), "bar".to_owned()); + let foo1 = db.intern2("x".to_owned(), "foo".to_owned()); + let bar1 = db.intern2("x".to_owned(), "bar".to_owned()); + + assert_eq!(foo0, foo1); + assert_eq!(bar0, bar1); + assert_ne!(foo0, bar0); + + assert_eq!(("x".to_owned(), "foo".to_owned()), db.lookup_intern2(foo0)); + assert_eq!(("x".to_owned(), "bar".to_owned()), db.lookup_intern2(bar0)); +} + +#[test] +fn test_intern_key() { + let db = Database::default(); + let foo0 = db.intern_key("foo".to_owned()); + let bar0 = db.intern_key("bar".to_owned()); + let foo1 = db.intern_key("foo".to_owned()); + let bar1 = db.intern_key("bar".to_owned()); + + assert_eq!(foo0, foo1); + assert_eq!(bar0, bar1); + assert_ne!(foo0, bar0); + + assert_eq!("foo".to_owned(), db.lookup_intern_key(foo0)); + assert_eq!("bar".to_owned(), db.lookup_intern_key(bar0)); +} diff --git a/crates/salsa/tests/lru.rs b/crates/salsa/tests/lru.rs new file mode 100644 index 0000000000..3da8519b08 --- /dev/null +++ b/crates/salsa/tests/lru.rs @@ -0,0 +1,102 @@ +//! Test setting LRU actually limits the number of things in the database; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +#[derive(Debug, PartialEq, Eq)] +struct HotPotato(u32); + +static N_POTATOES: AtomicUsize = AtomicUsize::new(0); + +impl HotPotato { + fn new(id: u32) -> HotPotato { + N_POTATOES.fetch_add(1, Ordering::SeqCst); + HotPotato(id) + } +} + +impl Drop for HotPotato { + fn drop(&mut self) { + N_POTATOES.fetch_sub(1, Ordering::SeqCst); + } +} + +#[salsa::query_group(QueryGroupStorage)] +trait QueryGroup: salsa::Database { + fn get(&self, x: u32) -> Arc; + fn get_volatile(&self, x: u32) -> usize; +} + +fn get(_db: &dyn QueryGroup, x: u32) -> Arc { + Arc::new(HotPotato::new(x)) +} + +fn get_volatile(db: &dyn QueryGroup, _x: u32) -> usize { + static COUNTER: AtomicUsize = AtomicUsize::new(0); + db.salsa_runtime().report_untracked_read(); + COUNTER.fetch_add(1, Ordering::SeqCst) +} + +#[salsa::database(QueryGroupStorage)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, +} + +impl salsa::Database for Database {} + +#[test] +fn lru_works() { + let mut db = Database::default(); + GetQuery.in_db_mut(&mut db).set_lru_capacity(32); + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0); + + for i in 0..128u32 { + let p = db.get(i); + assert_eq!(p.0, i) + } + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32); + + for i in 0..128u32 { + let p = db.get(i); + assert_eq!(p.0, i) + } + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32); + + GetQuery.in_db_mut(&mut db).set_lru_capacity(32); + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32); + + GetQuery.in_db_mut(&mut db).set_lru_capacity(64); + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 32); + for i in 0..128u32 { + let p = db.get(i); + assert_eq!(p.0, i) + } + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64); + + // Special case: setting capacity to zero disables LRU + GetQuery.in_db_mut(&mut db).set_lru_capacity(0); + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 64); + for i in 0..128u32 { + let p = db.get(i); + assert_eq!(p.0, i) + } + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 128); + + drop(db); + assert_eq!(N_POTATOES.load(Ordering::SeqCst), 0); +} + +#[test] +fn lru_doesnt_break_volatile_queries() { + let mut db = Database::default(); + GetVolatileQuery.in_db_mut(&mut db).set_lru_capacity(32); + // Here, we check that we execute each volatile query at most once, despite + // LRU. That does mean that we have more values in DB than the LRU capacity, + // but it's much better than inconsistent results from volatile queries! + for i in (0..3).flat_map(|_| 0..128usize) { + let x = db.get_volatile(i as u32); + assert_eq!(x, i) + } +} diff --git a/crates/salsa/tests/macros.rs b/crates/salsa/tests/macros.rs new file mode 100644 index 0000000000..3d818e53c8 --- /dev/null +++ b/crates/salsa/tests/macros.rs @@ -0,0 +1,11 @@ +#[salsa::query_group(MyStruct)] +trait MyDatabase: salsa::Database { + #[salsa::invoke(another_module::another_name)] + fn my_query(&self, key: ()) -> (); +} + +mod another_module { + pub(crate) fn another_name(_: &dyn crate::MyDatabase, (): ()) {} +} + +fn main() {} diff --git a/crates/salsa/tests/no_send_sync.rs b/crates/salsa/tests/no_send_sync.rs new file mode 100644 index 0000000000..2a25c437c3 --- /dev/null +++ b/crates/salsa/tests/no_send_sync.rs @@ -0,0 +1,31 @@ +use std::rc::Rc; + +#[salsa::query_group(NoSendSyncStorage)] +trait NoSendSyncDatabase: salsa::Database { + fn no_send_sync_value(&self, key: bool) -> Rc; + fn no_send_sync_key(&self, key: Rc) -> bool; +} + +fn no_send_sync_value(_db: &dyn NoSendSyncDatabase, key: bool) -> Rc { + Rc::new(key) +} + +fn no_send_sync_key(_db: &dyn NoSendSyncDatabase, key: Rc) -> bool { + *key +} + +#[salsa::database(NoSendSyncStorage)] +#[derive(Default)] +struct DatabaseImpl { + storage: salsa::Storage, +} + +impl salsa::Database for DatabaseImpl {} + +#[test] +fn no_send_sync() { + let db = DatabaseImpl::default(); + + assert_eq!(db.no_send_sync_value(true), Rc::new(true)); + assert!(!db.no_send_sync_key(Rc::new(false))); +} diff --git a/crates/salsa/tests/on_demand_inputs.rs b/crates/salsa/tests/on_demand_inputs.rs new file mode 100644 index 0000000000..5d0e486644 --- /dev/null +++ b/crates/salsa/tests/on_demand_inputs.rs @@ -0,0 +1,147 @@ +//! Test that "on-demand" input pattern works. +//! +//! On-demand inputs are inputs computed lazily on the fly. They are simulated +//! via a b query with zero inputs, which uses `add_synthetic_read` to +//! tweak durability and `invalidate` to clear the input. + +#![allow(clippy::disallowed_types, clippy::type_complexity)] + +use std::{cell::RefCell, collections::HashMap, rc::Rc}; + +use salsa::{Database as _, Durability, EventKind}; + +#[salsa::query_group(QueryGroupStorage)] +trait QueryGroup: salsa::Database + AsRef> { + fn a(&self, x: u32) -> u32; + fn b(&self, x: u32) -> u32; + fn c(&self, x: u32) -> u32; +} + +fn a(db: &dyn QueryGroup, x: u32) -> u32 { + let durability = if x % 2 == 0 { Durability::LOW } else { Durability::HIGH }; + db.salsa_runtime().report_synthetic_read(durability); + let external_state: &HashMap = db.as_ref(); + external_state[&x] +} + +fn b(db: &dyn QueryGroup, x: u32) -> u32 { + db.a(x) +} + +fn c(db: &dyn QueryGroup, x: u32) -> u32 { + db.b(x) +} + +#[salsa::database(QueryGroupStorage)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, + external_state: HashMap, + on_event: Option>, +} + +impl salsa::Database for Database { + fn salsa_event(&self, event: salsa::Event) { + if let Some(cb) = &self.on_event { + cb(self, event) + } + } +} + +impl AsRef> for Database { + fn as_ref(&self) -> &HashMap { + &self.external_state + } +} + +#[test] +fn on_demand_input_works() { + let mut db = Database::default(); + + db.external_state.insert(1, 10); + assert_eq!(db.b(1), 10); + assert_eq!(db.a(1), 10); + + // We changed external state, but haven't signaled about this yet, + // so we expect to see the old answer + db.external_state.insert(1, 92); + assert_eq!(db.b(1), 10); + assert_eq!(db.a(1), 10); + + AQuery.in_db_mut(&mut db).invalidate(&1); + assert_eq!(db.b(1), 92); + assert_eq!(db.a(1), 92); + + // Downstream queries should also be rerun if we call `a` first. + db.external_state.insert(1, 50); + AQuery.in_db_mut(&mut db).invalidate(&1); + assert_eq!(db.a(1), 50); + assert_eq!(db.b(1), 50); +} + +#[test] +fn on_demand_input_durability() { + let mut db = Database::default(); + + let events = Rc::new(RefCell::new(vec![])); + db.on_event = Some(Box::new({ + let events = events.clone(); + move |db, event| { + if let EventKind::WillCheckCancellation = event.kind { + // these events are not interesting + } else { + events.borrow_mut().push(format!("{:?}", event.debug(db))) + } + } + })); + + events.replace(vec![]); + db.external_state.insert(1, 10); + db.external_state.insert(2, 20); + assert_eq!(db.b(1), 10); + assert_eq!(db.b(2), 20); + expect_test::expect![[r#" + RefCell { + value: [ + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: b(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: b(2) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }", + ], + } + "#]].assert_debug_eq(&events); + + db.salsa_runtime_mut().synthetic_write(Durability::LOW); + events.replace(vec![]); + assert_eq!(db.c(1), 10); + assert_eq!(db.c(2), 20); + // Re-execute `a(2)` because that has low durability, but not `a(1)` + expect_test::expect![[r#" + RefCell { + value: [ + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: c(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: c(2) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: b(2) } }", + ], + } + "#]].assert_debug_eq(&events); + + db.salsa_runtime_mut().synthetic_write(Durability::HIGH); + events.replace(vec![]); + assert_eq!(db.c(1), 10); + assert_eq!(db.c(2), 20); + // Re-execute both `a(1)` and `a(2)`, but we don't re-execute any `b` queries as the + // result didn't actually change. + expect_test::expect![[r#" + RefCell { + value: [ + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(1) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: WillExecute { database_key: a(2) } }", + "Event { runtime_id: RuntimeId { counter: 0 }, kind: DidValidateMemoizedValue { database_key: c(2) } }", + ], + } + "#]].assert_debug_eq(&events); +} diff --git a/crates/salsa/tests/panic_safely.rs b/crates/salsa/tests/panic_safely.rs new file mode 100644 index 0000000000..c11ae9c214 --- /dev/null +++ b/crates/salsa/tests/panic_safely.rs @@ -0,0 +1,93 @@ +use salsa::{Database, ParallelDatabase, Snapshot}; +use std::panic::{self, AssertUnwindSafe}; +use std::sync::atomic::{AtomicU32, Ordering::SeqCst}; + +#[salsa::query_group(PanicSafelyStruct)] +trait PanicSafelyDatabase: salsa::Database { + #[salsa::input] + fn one(&self) -> usize; + + fn panic_safely(&self) -> (); + + fn outer(&self) -> (); +} + +fn panic_safely(db: &dyn PanicSafelyDatabase) { + assert_eq!(db.one(), 1); +} + +static OUTER_CALLS: AtomicU32 = AtomicU32::new(0); + +fn outer(db: &dyn PanicSafelyDatabase) { + OUTER_CALLS.fetch_add(1, SeqCst); + db.panic_safely(); +} + +#[salsa::database(PanicSafelyStruct)] +#[derive(Default)] +struct DatabaseStruct { + storage: salsa::Storage, +} + +impl salsa::Database for DatabaseStruct {} + +impl salsa::ParallelDatabase for DatabaseStruct { + fn snapshot(&self) -> Snapshot { + Snapshot::new(DatabaseStruct { storage: self.storage.snapshot() }) + } +} + +#[test] +fn should_panic_safely() { + let mut db = DatabaseStruct::default(); + db.set_one(0); + + // Invoke `db.panic_safely() without having set `db.one`. `db.one` will + // return 0 and we should catch the panic. + let result = panic::catch_unwind(AssertUnwindSafe({ + let db = db.snapshot(); + move || db.panic_safely() + })); + assert!(result.is_err()); + + // Set `db.one` to 1 and assert ok + db.set_one(1); + let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely())); + assert!(result.is_ok()); + + // Check, that memoized outer is not invalidated by a panic + { + assert_eq!(OUTER_CALLS.load(SeqCst), 0); + db.outer(); + assert_eq!(OUTER_CALLS.load(SeqCst), 1); + + db.set_one(0); + let result = panic::catch_unwind(AssertUnwindSafe(|| db.outer())); + assert!(result.is_err()); + assert_eq!(OUTER_CALLS.load(SeqCst), 1); + + db.set_one(1); + db.outer(); + assert_eq!(OUTER_CALLS.load(SeqCst), 2); + } +} + +#[test] +fn storages_are_unwind_safe() { + fn check_unwind_safe() {} + check_unwind_safe::<&DatabaseStruct>(); +} + +#[test] +fn panics_clear_query_stack() { + let db = DatabaseStruct::default(); + + // Invoke `db.panic_if_not_one() without having set `db.input`. `db.input` + // will default to 0 and we should catch the panic. + let result = panic::catch_unwind(AssertUnwindSafe(|| db.panic_safely())); + assert!(result.is_err()); + + // The database has been poisoned and any attempt to increment the + // revision should panic. + assert_eq!(db.salsa_runtime().active_query(), None); +} diff --git a/crates/salsa/tests/parallel/cancellation.rs b/crates/salsa/tests/parallel/cancellation.rs new file mode 100644 index 0000000000..9a92e5cc1f --- /dev/null +++ b/crates/salsa/tests/parallel/cancellation.rs @@ -0,0 +1,132 @@ +use crate::setup::{CancellationFlag, Knobs, ParDatabase, ParDatabaseImpl, WithValue}; +use salsa::{Cancelled, ParallelDatabase}; + +macro_rules! assert_cancelled { + ($thread:expr) => { + match $thread.join() { + Ok(value) => panic!("expected cancellation, got {:?}", value), + Err(payload) => match payload.downcast::() { + Ok(_) => {} + Err(payload) => ::std::panic::resume_unwind(payload), + }, + } + }; +} + +/// Add test where a call to `sum` is cancelled by a simultaneous +/// write. Check that we recompute the result in next revision, even +/// though none of the inputs have changed. +#[test] +fn in_par_get_set_cancellation_immediate() { + let mut db = ParDatabaseImpl::default(); + + db.set_input('a', 100); + db.set_input('b', 10); + db.set_input('c', 1); + db.set_input('d', 0); + + let thread1 = std::thread::spawn({ + let db = db.snapshot(); + move || { + // This will not return until it sees cancellation is + // signaled. + db.knobs().sum_signal_on_entry.with_value(1, || { + db.knobs() + .sum_wait_for_cancellation + .with_value(CancellationFlag::Panic, || db.sum("abc")) + }) + } + }); + + // Wait until we have entered `sum` in the other thread. + db.wait_for(1); + + // Try to set the input. This will signal cancellation. + db.set_input('d', 1000); + + // This should re-compute the value (even though no input has changed). + let thread2 = std::thread::spawn({ + let db = db.snapshot(); + move || db.sum("abc") + }); + + assert_eq!(db.sum("d"), 1000); + assert_cancelled!(thread1); + assert_eq!(thread2.join().unwrap(), 111); +} + +/// Here, we check that `sum`'s cancellation is propagated +/// to `sum2` properly. +#[test] +fn in_par_get_set_cancellation_transitive() { + let mut db = ParDatabaseImpl::default(); + + db.set_input('a', 100); + db.set_input('b', 10); + db.set_input('c', 1); + db.set_input('d', 0); + + let thread1 = std::thread::spawn({ + let db = db.snapshot(); + move || { + // This will not return until it sees cancellation is + // signaled. + db.knobs().sum_signal_on_entry.with_value(1, || { + db.knobs() + .sum_wait_for_cancellation + .with_value(CancellationFlag::Panic, || db.sum2("abc")) + }) + } + }); + + // Wait until we have entered `sum` in the other thread. + db.wait_for(1); + + // Try to set the input. This will signal cancellation. + db.set_input('d', 1000); + + // This should re-compute the value (even though no input has changed). + let thread2 = std::thread::spawn({ + let db = db.snapshot(); + move || db.sum2("abc") + }); + + assert_eq!(db.sum2("d"), 1000); + assert_cancelled!(thread1); + assert_eq!(thread2.join().unwrap(), 111); +} + +/// https://github.com/salsa-rs/salsa/issues/66 +#[test] +fn no_back_dating_in_cancellation() { + let mut db = ParDatabaseImpl::default(); + + db.set_input('a', 1); + let thread1 = std::thread::spawn({ + let db = db.snapshot(); + move || { + // Here we compute a long-chain of queries, + // but the last one gets cancelled. + db.knobs().sum_signal_on_entry.with_value(1, || { + db.knobs() + .sum_wait_for_cancellation + .with_value(CancellationFlag::Panic, || db.sum3("a")) + }) + } + }); + + db.wait_for(1); + + // Set unrelated input to bump revision + db.set_input('b', 2); + + // Here we should recompuet the whole chain again, clearing the cancellation + // state. If we get `usize::max()` here, it is a bug! + assert_eq!(db.sum3("a"), 1); + + assert_cancelled!(thread1); + + db.set_input('a', 3); + db.set_input('a', 4); + assert_eq!(db.sum3("ab"), 6); +} diff --git a/crates/salsa/tests/parallel/frozen.rs b/crates/salsa/tests/parallel/frozen.rs new file mode 100644 index 0000000000..5359a8820e --- /dev/null +++ b/crates/salsa/tests/parallel/frozen.rs @@ -0,0 +1,57 @@ +use crate::setup::{ParDatabase, ParDatabaseImpl}; +use crate::signal::Signal; +use salsa::{Database, ParallelDatabase}; +use std::{ + panic::{catch_unwind, AssertUnwindSafe}, + sync::Arc, +}; + +/// Add test where a call to `sum` is cancelled by a simultaneous +/// write. Check that we recompute the result in next revision, even +/// though none of the inputs have changed. +#[test] +fn in_par_get_set_cancellation() { + let mut db = ParDatabaseImpl::default(); + + db.set_input('a', 1); + + let signal = Arc::new(Signal::default()); + + let thread1 = std::thread::spawn({ + let db = db.snapshot(); + let signal = signal.clone(); + move || { + // Check that cancellation flag is not yet set, because + // `set` cannot have been called yet. + catch_unwind(AssertUnwindSafe(|| db.unwind_if_cancelled())).unwrap(); + + // Signal other thread to proceed. + signal.signal(1); + + // Wait for other thread to signal cancellation + catch_unwind(AssertUnwindSafe(|| loop { + db.unwind_if_cancelled(); + std::thread::yield_now(); + })) + .unwrap_err(); + } + }); + + let thread2 = std::thread::spawn({ + move || { + // Wait until thread 1 has asserted that they are not cancelled + // before we invoke `set.` + signal.wait_for(1); + + // This will block until thread1 drops the revision lock. + db.set_input('a', 2); + + db.input('a') + } + }); + + thread1.join().unwrap(); + + let c = thread2.join().unwrap(); + assert_eq!(c, 2); +} diff --git a/crates/salsa/tests/parallel/independent.rs b/crates/salsa/tests/parallel/independent.rs new file mode 100644 index 0000000000..bd6ba3bf93 --- /dev/null +++ b/crates/salsa/tests/parallel/independent.rs @@ -0,0 +1,29 @@ +use crate::setup::{ParDatabase, ParDatabaseImpl}; +use salsa::ParallelDatabase; + +/// Test two `sum` queries (on distinct keys) executing in different +/// threads. Really just a test that `snapshot` etc compiles. +#[test] +fn in_par_two_independent_queries() { + let mut db = ParDatabaseImpl::default(); + + db.set_input('a', 100); + db.set_input('b', 10); + db.set_input('c', 1); + db.set_input('d', 200); + db.set_input('e', 20); + db.set_input('f', 2); + + let thread1 = std::thread::spawn({ + let db = db.snapshot(); + move || db.sum("abc") + }); + + let thread2 = std::thread::spawn({ + let db = db.snapshot(); + move || db.sum("def") + }); + + assert_eq!(thread1.join().unwrap(), 111); + assert_eq!(thread2.join().unwrap(), 222); +} diff --git a/crates/salsa/tests/parallel/main.rs b/crates/salsa/tests/parallel/main.rs new file mode 100644 index 0000000000..31c0da1837 --- /dev/null +++ b/crates/salsa/tests/parallel/main.rs @@ -0,0 +1,13 @@ +mod setup; + +mod cancellation; +mod frozen; +mod independent; +mod parallel_cycle_all_recover; +mod parallel_cycle_mid_recover; +mod parallel_cycle_none_recover; +mod parallel_cycle_one_recovers; +mod race; +mod signal; +mod stress; +mod true_parallel; diff --git a/crates/salsa/tests/parallel/parallel_cycle_all_recover.rs b/crates/salsa/tests/parallel/parallel_cycle_all_recover.rs new file mode 100644 index 0000000000..cee51b4db7 --- /dev/null +++ b/crates/salsa/tests/parallel/parallel_cycle_all_recover.rs @@ -0,0 +1,110 @@ +//! Test for cycle recover spread across two threads. +//! See `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::{Knobs, ParDatabaseImpl}; +use salsa::ParallelDatabase; +use test_log::test; + +// Recover cycle test: +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 b1 +// | wait for stage 1 (blocks) +// signal stage 1 | +// wait for stage 2 (blocks) (unblocked) +// | signal stage 2 +// (unblocked) wait for stage 3 (blocks) +// a2 | +// b1 (blocks -> stage 3) | +// | (unblocked) +// | b2 +// | a1 (cycle detected, recovers) +// | b2 completes, recovers +// | b1 completes, recovers +// a2 sees cycle, recovers +// a1 completes, recovers + +#[test] +fn parallel_cycle_all_recover() { + let db = ParDatabaseImpl::default(); + db.knobs().signal_on_will_block.set(3); + + let thread_a = std::thread::spawn({ + let db = db.snapshot(); + move || db.a1(1) + }); + + let thread_b = std::thread::spawn({ + let db = db.snapshot(); + move || db.b1(1) + }); + + assert_eq!(thread_a.join().unwrap(), 11); + assert_eq!(thread_b.join().unwrap(), 21); +} + +#[salsa::query_group(ParallelCycleAllRecover)] +pub(crate) trait TestDatabase: Knobs { + #[salsa::cycle(recover_a1)] + fn a1(&self, key: i32) -> i32; + + #[salsa::cycle(recover_a2)] + fn a2(&self, key: i32) -> i32; + + #[salsa::cycle(recover_b1)] + fn b1(&self, key: i32) -> i32; + + #[salsa::cycle(recover_b2)] + fn b2(&self, key: i32) -> i32; +} + +fn recover_a1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 { + tracing::debug!("recover_a1"); + key * 10 + 1 +} + +fn recover_a2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 { + tracing::debug!("recover_a2"); + key * 10 + 2 +} + +fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 { + tracing::debug!("recover_b1"); + key * 20 + 1 +} + +fn recover_b2(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 { + tracing::debug!("recover_b2"); + key * 20 + 2 +} + +fn a1(db: &dyn TestDatabase, key: i32) -> i32 { + // Wait to create the cycle until both threads have entered + db.signal(1); + db.wait_for(2); + + db.a2(key) +} + +fn a2(db: &dyn TestDatabase, key: i32) -> i32 { + db.b1(key) +} + +fn b1(db: &dyn TestDatabase, key: i32) -> i32 { + // Wait to create the cycle until both threads have entered + db.wait_for(1); + db.signal(2); + + // Wait for thread A to block on this thread + db.wait_for(3); + + db.b2(key) +} + +fn b2(db: &dyn TestDatabase, key: i32) -> i32 { + db.a1(key) +} diff --git a/crates/salsa/tests/parallel/parallel_cycle_mid_recover.rs b/crates/salsa/tests/parallel/parallel_cycle_mid_recover.rs new file mode 100644 index 0000000000..f78c05c559 --- /dev/null +++ b/crates/salsa/tests/parallel/parallel_cycle_mid_recover.rs @@ -0,0 +1,110 @@ +//! Test for cycle recover spread across two threads. +//! See `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::{Knobs, ParDatabaseImpl}; +use salsa::ParallelDatabase; +use test_log::test; + +// Recover cycle test: +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 b1 +// | wait for stage 1 (blocks) +// signal stage 1 | +// wait for stage 2 (blocks) (unblocked) +// | | +// | b2 +// | b3 +// | a1 (blocks -> stage 2) +// (unblocked) | +// a2 (cycle detected) | +// b3 recovers +// b2 resumes +// b1 panics because bug + +#[test] +fn parallel_cycle_mid_recovers() { + let db = ParDatabaseImpl::default(); + db.knobs().signal_on_will_block.set(2); + + let thread_a = std::thread::spawn({ + let db = db.snapshot(); + move || db.a1(1) + }); + + let thread_b = std::thread::spawn({ + let db = db.snapshot(); + move || db.b1(1) + }); + + // We expect that the recovery function yields + // `1 * 20 + 2`, which is returned (and forwarded) + // to b1, and from there to a2 and a1. + assert_eq!(thread_a.join().unwrap(), 22); + assert_eq!(thread_b.join().unwrap(), 22); +} + +#[salsa::query_group(ParallelCycleMidRecovers)] +pub(crate) trait TestDatabase: Knobs { + fn a1(&self, key: i32) -> i32; + + fn a2(&self, key: i32) -> i32; + + #[salsa::cycle(recover_b1)] + fn b1(&self, key: i32) -> i32; + + fn b2(&self, key: i32) -> i32; + + #[salsa::cycle(recover_b3)] + fn b3(&self, key: i32) -> i32; +} + +fn recover_b1(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 { + tracing::debug!("recover_b1"); + key * 20 + 2 +} + +fn recover_b3(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 { + tracing::debug!("recover_b1"); + key * 200 + 2 +} + +fn a1(db: &dyn TestDatabase, key: i32) -> i32 { + // tell thread b we have started + db.signal(1); + + // wait for thread b to block on a1 + db.wait_for(2); + + db.a2(key) +} + +fn a2(db: &dyn TestDatabase, key: i32) -> i32 { + // create the cycle + db.b1(key) +} + +fn b1(db: &dyn TestDatabase, key: i32) -> i32 { + // wait for thread a to have started + db.wait_for(1); + + db.b2(key); + + 0 +} + +fn b2(db: &dyn TestDatabase, key: i32) -> i32 { + // will encounter a cycle but recover + db.b3(key); + db.b1(key); // hasn't recovered yet + 0 +} + +fn b3(db: &dyn TestDatabase, key: i32) -> i32 { + // will block on thread a, signaling stage 2 + db.a1(key) +} diff --git a/crates/salsa/tests/parallel/parallel_cycle_none_recover.rs b/crates/salsa/tests/parallel/parallel_cycle_none_recover.rs new file mode 100644 index 0000000000..35fe379118 --- /dev/null +++ b/crates/salsa/tests/parallel/parallel_cycle_none_recover.rs @@ -0,0 +1,69 @@ +//! Test a cycle where no queries recover that occurs across threads. +//! See the `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::{Knobs, ParDatabaseImpl}; +use expect_test::expect; +use salsa::ParallelDatabase; +use test_log::test; + +#[test] +fn parallel_cycle_none_recover() { + let db = ParDatabaseImpl::default(); + db.knobs().signal_on_will_block.set(3); + + let thread_a = std::thread::spawn({ + let db = db.snapshot(); + move || db.a(-1) + }); + + let thread_b = std::thread::spawn({ + let db = db.snapshot(); + move || db.b(-1) + }); + + // We expect B to panic because it detects a cycle (it is the one that calls A, ultimately). + // Right now, it panics with a string. + let err_b = thread_b.join().unwrap_err(); + if let Some(c) = err_b.downcast_ref::() { + expect![[r#" + [ + "a(-1)", + "b(-1)", + ] + "#]] + .assert_debug_eq(&c.unexpected_participants(&db)); + } else { + panic!("b failed in an unexpected way: {:?}", err_b); + } + + // We expect A to propagate a panic, which causes us to use the sentinel + // type `Canceled`. + assert!(thread_a.join().unwrap_err().downcast_ref::().is_some()); +} + +#[salsa::query_group(ParallelCycleNoneRecover)] +pub(crate) trait TestDatabase: Knobs { + fn a(&self, key: i32) -> i32; + fn b(&self, key: i32) -> i32; +} + +fn a(db: &dyn TestDatabase, key: i32) -> i32 { + // Wait to create the cycle until both threads have entered + db.signal(1); + db.wait_for(2); + + db.b(key) +} + +fn b(db: &dyn TestDatabase, key: i32) -> i32 { + // Wait to create the cycle until both threads have entered + db.wait_for(1); + db.signal(2); + + // Wait for thread A to block on this thread + db.wait_for(3); + + // Now try to execute A + db.a(key) +} diff --git a/crates/salsa/tests/parallel/parallel_cycle_one_recovers.rs b/crates/salsa/tests/parallel/parallel_cycle_one_recovers.rs new file mode 100644 index 0000000000..7d3944714a --- /dev/null +++ b/crates/salsa/tests/parallel/parallel_cycle_one_recovers.rs @@ -0,0 +1,95 @@ +//! Test for cycle recover spread across two threads. +//! See `../cycles.rs` for a complete listing of cycle tests, +//! both intra and cross thread. + +use crate::setup::{Knobs, ParDatabaseImpl}; +use salsa::ParallelDatabase; +use test_log::test; + +// Recover cycle test: +// +// The pattern is as follows. +// +// Thread A Thread B +// -------- -------- +// a1 b1 +// | wait for stage 1 (blocks) +// signal stage 1 | +// wait for stage 2 (blocks) (unblocked) +// | signal stage 2 +// (unblocked) wait for stage 3 (blocks) +// a2 | +// b1 (blocks -> stage 3) | +// | (unblocked) +// | b2 +// | a1 (cycle detected) +// a2 recovery fn executes | +// a1 completes normally | +// b2 completes, recovers +// b1 completes, recovers + +#[test] +fn parallel_cycle_one_recovers() { + let db = ParDatabaseImpl::default(); + db.knobs().signal_on_will_block.set(3); + + let thread_a = std::thread::spawn({ + let db = db.snapshot(); + move || db.a1(1) + }); + + let thread_b = std::thread::spawn({ + let db = db.snapshot(); + move || db.b1(1) + }); + + // We expect that the recovery function yields + // `1 * 20 + 2`, which is returned (and forwarded) + // to b1, and from there to a2 and a1. + assert_eq!(thread_a.join().unwrap(), 22); + assert_eq!(thread_b.join().unwrap(), 22); +} + +#[salsa::query_group(ParallelCycleOneRecovers)] +pub(crate) trait TestDatabase: Knobs { + fn a1(&self, key: i32) -> i32; + + #[salsa::cycle(recover)] + fn a2(&self, key: i32) -> i32; + + fn b1(&self, key: i32) -> i32; + + fn b2(&self, key: i32) -> i32; +} + +fn recover(_db: &dyn TestDatabase, _cycle: &salsa::Cycle, key: &i32) -> i32 { + tracing::debug!("recover"); + key * 20 + 2 +} + +fn a1(db: &dyn TestDatabase, key: i32) -> i32 { + // Wait to create the cycle until both threads have entered + db.signal(1); + db.wait_for(2); + + db.a2(key) +} + +fn a2(db: &dyn TestDatabase, key: i32) -> i32 { + db.b1(key) +} + +fn b1(db: &dyn TestDatabase, key: i32) -> i32 { + // Wait to create the cycle until both threads have entered + db.wait_for(1); + db.signal(2); + + // Wait for thread A to block on this thread + db.wait_for(3); + + db.b2(key) +} + +fn b2(db: &dyn TestDatabase, key: i32) -> i32 { + db.a1(key) +} diff --git a/crates/salsa/tests/parallel/race.rs b/crates/salsa/tests/parallel/race.rs new file mode 100644 index 0000000000..e875de998f --- /dev/null +++ b/crates/salsa/tests/parallel/race.rs @@ -0,0 +1,37 @@ +use std::panic::AssertUnwindSafe; + +use crate::setup::{ParDatabase, ParDatabaseImpl}; +use salsa::{Cancelled, ParallelDatabase}; + +/// Test where a read and a set are racing with one another. +/// Should be atomic. +#[test] +fn in_par_get_set_race() { + let mut db = ParDatabaseImpl::default(); + + db.set_input('a', 100); + db.set_input('b', 10); + db.set_input('c', 1); + + let thread1 = std::thread::spawn({ + let db = db.snapshot(); + move || Cancelled::catch(AssertUnwindSafe(|| db.sum("abc"))) + }); + + let thread2 = std::thread::spawn(move || { + db.set_input('a', 1000); + db.sum("a") + }); + + // If the 1st thread runs first, you get 111, otherwise you get + // 1011; if they run concurrently and the 1st thread observes the + // cancellation, it'll unwind. + let result1 = thread1.join().unwrap(); + if let Ok(value1) = result1 { + assert!(value1 == 111 || value1 == 1011, "illegal result {}", value1); + } + + // thread2 can not observe a cancellation because it performs a + // database write before running any other queries. + assert_eq!(thread2.join().unwrap(), 1000); +} diff --git a/crates/salsa/tests/parallel/setup.rs b/crates/salsa/tests/parallel/setup.rs new file mode 100644 index 0000000000..0a35902b43 --- /dev/null +++ b/crates/salsa/tests/parallel/setup.rs @@ -0,0 +1,197 @@ +use crate::signal::Signal; +use salsa::Database; +use salsa::ParallelDatabase; +use salsa::Snapshot; +use std::sync::Arc; +use std::{ + cell::Cell, + panic::{catch_unwind, resume_unwind, AssertUnwindSafe}, +}; + +#[salsa::query_group(Par)] +pub(crate) trait ParDatabase: Knobs { + #[salsa::input] + fn input(&self, key: char) -> usize; + + fn sum(&self, key: &'static str) -> usize; + + /// Invokes `sum` + fn sum2(&self, key: &'static str) -> usize; + + /// Invokes `sum` but doesn't really care about the result. + fn sum2_drop_sum(&self, key: &'static str) -> usize; + + /// Invokes `sum2` + fn sum3(&self, key: &'static str) -> usize; + + /// Invokes `sum2_drop_sum` + fn sum3_drop_sum(&self, key: &'static str) -> usize; +} + +/// Various "knobs" and utilities used by tests to force +/// a certain behavior. +pub(crate) trait Knobs { + fn knobs(&self) -> &KnobsStruct; + + fn signal(&self, stage: usize); + + fn wait_for(&self, stage: usize); +} + +pub(crate) trait WithValue { + fn with_value(&self, value: T, closure: impl FnOnce() -> R) -> R; +} + +impl WithValue for Cell { + fn with_value(&self, value: T, closure: impl FnOnce() -> R) -> R { + let old_value = self.replace(value); + + let result = catch_unwind(AssertUnwindSafe(closure)); + + self.set(old_value); + + match result { + Ok(r) => r, + Err(payload) => resume_unwind(payload), + } + } +} + +#[derive(Default, Clone, Copy, PartialEq, Eq)] +pub(crate) enum CancellationFlag { + #[default] + Down, + Panic, +} + +/// Various "knobs" that can be used to customize how the queries +/// behave on one specific thread. Note that this state is +/// intentionally thread-local (apart from `signal`). +#[derive(Clone, Default)] +pub(crate) struct KnobsStruct { + /// A kind of flexible barrier used to coordinate execution across + /// threads to ensure we reach various weird states. + pub(crate) signal: Arc, + + /// When this database is about to block, send a signal. + pub(crate) signal_on_will_block: Cell, + + /// Invocations of `sum` will signal this stage on entry. + pub(crate) sum_signal_on_entry: Cell, + + /// Invocations of `sum` will wait for this stage on entry. + pub(crate) sum_wait_for_on_entry: Cell, + + /// If true, invocations of `sum` will panic before they exit. + pub(crate) sum_should_panic: Cell, + + /// If true, invocations of `sum` will wait for cancellation before + /// they exit. + pub(crate) sum_wait_for_cancellation: Cell, + + /// Invocations of `sum` will wait for this stage prior to exiting. + pub(crate) sum_wait_for_on_exit: Cell, + + /// Invocations of `sum` will signal this stage prior to exiting. + pub(crate) sum_signal_on_exit: Cell, + + /// Invocations of `sum3_drop_sum` will panic unconditionally + pub(crate) sum3_drop_sum_should_panic: Cell, +} + +fn sum(db: &dyn ParDatabase, key: &'static str) -> usize { + let mut sum = 0; + + db.signal(db.knobs().sum_signal_on_entry.get()); + + db.wait_for(db.knobs().sum_wait_for_on_entry.get()); + + if db.knobs().sum_should_panic.get() { + panic!("query set to panic before exit") + } + + for ch in key.chars() { + sum += db.input(ch); + } + + match db.knobs().sum_wait_for_cancellation.get() { + CancellationFlag::Down => (), + CancellationFlag::Panic => { + tracing::debug!("waiting for cancellation"); + loop { + db.unwind_if_cancelled(); + std::thread::yield_now(); + } + } + } + + db.wait_for(db.knobs().sum_wait_for_on_exit.get()); + + db.signal(db.knobs().sum_signal_on_exit.get()); + + sum +} + +fn sum2(db: &dyn ParDatabase, key: &'static str) -> usize { + db.sum(key) +} + +fn sum2_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize { + let _ = db.sum(key); + 22 +} + +fn sum3(db: &dyn ParDatabase, key: &'static str) -> usize { + db.sum2(key) +} + +fn sum3_drop_sum(db: &dyn ParDatabase, key: &'static str) -> usize { + if db.knobs().sum3_drop_sum_should_panic.get() { + panic!("sum3_drop_sum executed") + } + db.sum2_drop_sum(key) +} + +#[salsa::database( + Par, + crate::parallel_cycle_all_recover::ParallelCycleAllRecover, + crate::parallel_cycle_none_recover::ParallelCycleNoneRecover, + crate::parallel_cycle_mid_recover::ParallelCycleMidRecovers, + crate::parallel_cycle_one_recovers::ParallelCycleOneRecovers +)] +#[derive(Default)] +pub(crate) struct ParDatabaseImpl { + storage: salsa::Storage, + knobs: KnobsStruct, +} + +impl Database for ParDatabaseImpl { + fn salsa_event(&self, event: salsa::Event) { + if let salsa::EventKind::WillBlockOn { .. } = event.kind { + self.signal(self.knobs().signal_on_will_block.get()); + } + } +} + +impl ParallelDatabase for ParDatabaseImpl { + fn snapshot(&self) -> Snapshot { + Snapshot::new(ParDatabaseImpl { + storage: self.storage.snapshot(), + knobs: self.knobs.clone(), + }) + } +} + +impl Knobs for ParDatabaseImpl { + fn knobs(&self) -> &KnobsStruct { + &self.knobs + } + + fn signal(&self, stage: usize) { + self.knobs.signal.signal(stage); + } + + fn wait_for(&self, stage: usize) { + self.knobs.signal.wait_for(stage); + } +} diff --git a/crates/salsa/tests/parallel/signal.rs b/crates/salsa/tests/parallel/signal.rs new file mode 100644 index 0000000000..0af7b66e48 --- /dev/null +++ b/crates/salsa/tests/parallel/signal.rs @@ -0,0 +1,40 @@ +use parking_lot::{Condvar, Mutex}; + +#[derive(Default)] +pub(crate) struct Signal { + value: Mutex, + cond_var: Condvar, +} + +impl Signal { + pub(crate) fn signal(&self, stage: usize) { + tracing::debug!("signal({})", stage); + + // This check avoids acquiring the lock for things that will + // clearly be a no-op. Not *necessary* but helps to ensure we + // are more likely to encounter weird race conditions; + // otherwise calls to `sum` will tend to be unnecessarily + // synchronous. + if stage > 0 { + let mut v = self.value.lock(); + if stage > *v { + *v = stage; + self.cond_var.notify_all(); + } + } + } + + /// Waits until the given condition is true; the fn is invoked + /// with the current stage. + pub(crate) fn wait_for(&self, stage: usize) { + tracing::debug!("wait_for({})", stage); + + // As above, avoid lock if clearly a no-op. + if stage > 0 { + let mut v = self.value.lock(); + while *v < stage { + self.cond_var.wait(&mut v); + } + } + } +} diff --git a/crates/salsa/tests/parallel/stress.rs b/crates/salsa/tests/parallel/stress.rs new file mode 100644 index 0000000000..2fa317b2b9 --- /dev/null +++ b/crates/salsa/tests/parallel/stress.rs @@ -0,0 +1,168 @@ +use rand::seq::SliceRandom; +use rand::Rng; + +use salsa::ParallelDatabase; +use salsa::Snapshot; +use salsa::{Cancelled, Database}; + +// Number of operations a reader performs +const N_MUTATOR_OPS: usize = 100; +const N_READER_OPS: usize = 100; + +#[salsa::query_group(Stress)] +trait StressDatabase: salsa::Database { + #[salsa::input] + fn a(&self, key: usize) -> usize; + + fn b(&self, key: usize) -> usize; + + fn c(&self, key: usize) -> usize; +} + +fn b(db: &dyn StressDatabase, key: usize) -> usize { + db.unwind_if_cancelled(); + db.a(key) +} + +fn c(db: &dyn StressDatabase, key: usize) -> usize { + db.b(key) +} + +#[salsa::database(Stress)] +#[derive(Default)] +struct StressDatabaseImpl { + storage: salsa::Storage, +} + +impl salsa::Database for StressDatabaseImpl {} + +impl salsa::ParallelDatabase for StressDatabaseImpl { + fn snapshot(&self) -> Snapshot { + Snapshot::new(StressDatabaseImpl { storage: self.storage.snapshot() }) + } +} + +#[derive(Clone, Copy, Debug)] +enum Query { + A, + B, + C, +} + +enum MutatorOp { + WriteOp(WriteOp), + LaunchReader { ops: Vec, check_cancellation: bool }, +} + +#[derive(Debug)] +enum WriteOp { + SetA(usize, usize), +} + +#[derive(Debug)] +enum ReadOp { + Get(Query, usize), +} + +impl rand::distributions::Distribution for rand::distributions::Standard { + fn sample(&self, rng: &mut R) -> Query { + *[Query::A, Query::B, Query::C].choose(rng).unwrap() + } +} + +impl rand::distributions::Distribution for rand::distributions::Standard { + fn sample(&self, rng: &mut R) -> MutatorOp { + if rng.gen_bool(0.5) { + MutatorOp::WriteOp(rng.gen()) + } else { + MutatorOp::LaunchReader { + ops: (0..N_READER_OPS).map(|_| rng.gen()).collect(), + check_cancellation: rng.gen(), + } + } + } +} + +impl rand::distributions::Distribution for rand::distributions::Standard { + fn sample(&self, rng: &mut R) -> WriteOp { + let key = rng.gen::() % 10; + let value = rng.gen::() % 10; + WriteOp::SetA(key, value) + } +} + +impl rand::distributions::Distribution for rand::distributions::Standard { + fn sample(&self, rng: &mut R) -> ReadOp { + let query = rng.gen::(); + let key = rng.gen::() % 10; + ReadOp::Get(query, key) + } +} + +fn db_reader_thread(db: &StressDatabaseImpl, ops: Vec, check_cancellation: bool) { + for op in ops { + if check_cancellation { + db.unwind_if_cancelled(); + } + op.execute(db); + } +} + +impl WriteOp { + fn execute(self, db: &mut StressDatabaseImpl) { + match self { + WriteOp::SetA(key, value) => { + db.set_a(key, value); + } + } + } +} + +impl ReadOp { + fn execute(self, db: &StressDatabaseImpl) { + match self { + ReadOp::Get(query, key) => match query { + Query::A => { + db.a(key); + } + Query::B => { + let _ = db.b(key); + } + Query::C => { + let _ = db.c(key); + } + }, + } + } +} + +#[test] +fn stress_test() { + let mut db = StressDatabaseImpl::default(); + for i in 0..10 { + db.set_a(i, i); + } + + let mut rng = rand::thread_rng(); + + // generate the ops that the mutator thread will perform + let write_ops: Vec = (0..N_MUTATOR_OPS).map(|_| rng.gen()).collect(); + + // execute the "main thread", which sometimes snapshots off other threads + let mut all_threads = vec![]; + for op in write_ops { + match op { + MutatorOp::WriteOp(w) => w.execute(&mut db), + MutatorOp::LaunchReader { ops, check_cancellation } => { + all_threads.push(std::thread::spawn({ + let db = db.snapshot(); + move || Cancelled::catch(|| db_reader_thread(&db, ops, check_cancellation)) + })) + } + } + } + + for thread in all_threads { + thread.join().unwrap().ok(); + } +} diff --git a/crates/salsa/tests/parallel/true_parallel.rs b/crates/salsa/tests/parallel/true_parallel.rs new file mode 100644 index 0000000000..d0e58efd1a --- /dev/null +++ b/crates/salsa/tests/parallel/true_parallel.rs @@ -0,0 +1,125 @@ +use crate::setup::{Knobs, ParDatabase, ParDatabaseImpl, WithValue}; +use salsa::ParallelDatabase; +use std::panic::{self, AssertUnwindSafe}; + +/// Test where two threads are executing sum. We show that they can +/// both be executing sum in parallel by having thread1 wait for +/// thread2 to send a signal before it leaves (similarly, thread2 +/// waits for thread1 to send a signal before it enters). +#[test] +fn true_parallel_different_keys() { + let mut db = ParDatabaseImpl::default(); + + db.set_input('a', 100); + db.set_input('b', 10); + db.set_input('c', 1); + + // Thread 1 will signal stage 1 when it enters and wait for stage 2. + let thread1 = std::thread::spawn({ + let db = db.snapshot(); + move || { + let v = db + .knobs() + .sum_signal_on_entry + .with_value(1, || db.knobs().sum_wait_for_on_exit.with_value(2, || db.sum("a"))); + v + } + }); + + // Thread 2 will wait_for stage 1 when it enters and signal stage 2 + // when it leaves. + let thread2 = std::thread::spawn({ + let db = db.snapshot(); + move || { + let v = db + .knobs() + .sum_wait_for_on_entry + .with_value(1, || db.knobs().sum_signal_on_exit.with_value(2, || db.sum("b"))); + v + } + }); + + assert_eq!(thread1.join().unwrap(), 100); + assert_eq!(thread2.join().unwrap(), 10); +} + +/// Add a test that tries to trigger a conflict, where we fetch +/// `sum("abc")` from two threads simultaneously, and of them +/// therefore has to block. +#[test] +fn true_parallel_same_keys() { + let mut db = ParDatabaseImpl::default(); + + db.set_input('a', 100); + db.set_input('b', 10); + db.set_input('c', 1); + + // Thread 1 will wait_for a barrier in the start of `sum` + let thread1 = std::thread::spawn({ + let db = db.snapshot(); + move || { + let v = db + .knobs() + .sum_signal_on_entry + .with_value(1, || db.knobs().sum_wait_for_on_entry.with_value(2, || db.sum("abc"))); + v + } + }); + + // Thread 2 will wait until Thread 1 has entered sum and then -- + // once it has set itself to block -- signal Thread 1 to + // continue. This way, we test out the mechanism of one thread + // blocking on another. + let thread2 = std::thread::spawn({ + let db = db.snapshot(); + move || { + db.knobs().signal.wait_for(1); + db.knobs().signal_on_will_block.set(2); + db.sum("abc") + } + }); + + assert_eq!(thread1.join().unwrap(), 111); + assert_eq!(thread2.join().unwrap(), 111); +} + +/// Add a test that tries to trigger a conflict, where we fetch `sum("a")` +/// from two threads simultaneously. After `thread2` begins blocking, +/// we force `thread1` to panic and should see that propagate to `thread2`. +#[test] +fn true_parallel_propagate_panic() { + let mut db = ParDatabaseImpl::default(); + + db.set_input('a', 1); + + // `thread1` will wait_for a barrier in the start of `sum`. Once it can + // continue, it will panic. + let thread1 = std::thread::spawn({ + let db = db.snapshot(); + move || { + let v = db.knobs().sum_signal_on_entry.with_value(1, || { + db.knobs() + .sum_wait_for_on_entry + .with_value(2, || db.knobs().sum_should_panic.with_value(true, || db.sum("a"))) + }); + v + } + }); + + // `thread2` will wait until `thread1` has entered sum and then -- once it + // has set itself to block -- signal `thread1` to continue. + let thread2 = std::thread::spawn({ + let db = db.snapshot(); + move || { + db.knobs().signal.wait_for(1); + db.knobs().signal_on_will_block.set(2); + db.sum("a") + } + }); + + let result1 = panic::catch_unwind(AssertUnwindSafe(|| thread1.join().unwrap())); + let result2 = panic::catch_unwind(AssertUnwindSafe(|| thread2.join().unwrap())); + + assert!(result1.is_err()); + assert!(result2.is_err()); +} diff --git a/crates/salsa/tests/storage_varieties/implementation.rs b/crates/salsa/tests/storage_varieties/implementation.rs new file mode 100644 index 0000000000..2843660f15 --- /dev/null +++ b/crates/salsa/tests/storage_varieties/implementation.rs @@ -0,0 +1,19 @@ +use crate::queries; +use std::cell::Cell; + +#[salsa::database(queries::GroupStruct)] +#[derive(Default)] +pub(crate) struct DatabaseImpl { + storage: salsa::Storage, + counter: Cell, +} + +impl queries::Counter for DatabaseImpl { + fn increment(&self) -> usize { + let v = self.counter.get(); + self.counter.set(v + 1); + v + } +} + +impl salsa::Database for DatabaseImpl {} diff --git a/crates/salsa/tests/storage_varieties/main.rs b/crates/salsa/tests/storage_varieties/main.rs new file mode 100644 index 0000000000..e92c61740e --- /dev/null +++ b/crates/salsa/tests/storage_varieties/main.rs @@ -0,0 +1,5 @@ +mod implementation; +mod queries; +mod tests; + +fn main() {} diff --git a/crates/salsa/tests/storage_varieties/queries.rs b/crates/salsa/tests/storage_varieties/queries.rs new file mode 100644 index 0000000000..0847fadefb --- /dev/null +++ b/crates/salsa/tests/storage_varieties/queries.rs @@ -0,0 +1,22 @@ +pub(crate) trait Counter: salsa::Database { + fn increment(&self) -> usize; +} + +#[salsa::query_group(GroupStruct)] +pub(crate) trait Database: Counter { + fn memoized(&self) -> usize; + fn volatile(&self) -> usize; +} + +/// Because this query is memoized, we only increment the counter +/// the first time it is invoked. +fn memoized(db: &dyn Database) -> usize { + db.volatile() +} + +/// Because this query is volatile, each time it is invoked, +/// we will increment the counter. +fn volatile(db: &dyn Database) -> usize { + db.salsa_runtime().report_untracked_read(); + db.increment() +} diff --git a/crates/salsa/tests/storage_varieties/tests.rs b/crates/salsa/tests/storage_varieties/tests.rs new file mode 100644 index 0000000000..f75c7c142f --- /dev/null +++ b/crates/salsa/tests/storage_varieties/tests.rs @@ -0,0 +1,49 @@ +#![cfg(test)] + +use crate::implementation::DatabaseImpl; +use crate::queries::Database; +use salsa::Database as _Database; +use salsa::Durability; + +#[test] +fn memoized_twice() { + let db = DatabaseImpl::default(); + let v1 = db.memoized(); + let v2 = db.memoized(); + assert_eq!(v1, v2); +} + +#[test] +fn volatile_twice() { + let mut db = DatabaseImpl::default(); + let v1 = db.volatile(); + let v2 = db.volatile(); // volatiles are cached, so 2nd read returns the same + assert_eq!(v1, v2); + + db.salsa_runtime_mut().synthetic_write(Durability::LOW); // clears volatile caches + + let v3 = db.volatile(); // will re-increment the counter + let v4 = db.volatile(); // second call will be cached + assert_eq!(v1 + 1, v3); + assert_eq!(v3, v4); +} + +#[test] +fn intermingled() { + let mut db = DatabaseImpl::default(); + let v1 = db.volatile(); + let v2 = db.memoized(); + let v3 = db.volatile(); // cached + let v4 = db.memoized(); // cached + + assert_eq!(v1, v2); + assert_eq!(v1, v3); + assert_eq!(v2, v4); + + db.salsa_runtime_mut().synthetic_write(Durability::LOW); // clears volatile caches + + let v5 = db.memoized(); // re-executes volatile, caches new result + let v6 = db.memoized(); // re-use cached result + assert_eq!(v4 + 1, v5); + assert_eq!(v5, v6); +} diff --git a/crates/salsa/tests/transparent.rs b/crates/salsa/tests/transparent.rs new file mode 100644 index 0000000000..2e6dd4267b --- /dev/null +++ b/crates/salsa/tests/transparent.rs @@ -0,0 +1,39 @@ +//! Test that transparent (uncached) queries work + +#[salsa::query_group(QueryGroupStorage)] +trait QueryGroup { + #[salsa::input] + fn input(&self, x: u32) -> u32; + #[salsa::transparent] + fn wrap(&self, x: u32) -> u32; + fn get(&self, x: u32) -> u32; +} + +fn wrap(db: &dyn QueryGroup, x: u32) -> u32 { + db.input(x) +} + +fn get(db: &dyn QueryGroup, x: u32) -> u32 { + db.wrap(x) +} + +#[salsa::database(QueryGroupStorage)] +#[derive(Default)] +struct Database { + storage: salsa::Storage, +} + +impl salsa::Database for Database {} + +#[test] +fn transparent_queries_work() { + let mut db = Database::default(); + + db.set_input(1, 10); + assert_eq!(db.get(1), 10); + assert_eq!(db.get(1), 10); + + db.set_input(1, 92); + assert_eq!(db.get(1), 92); + assert_eq!(db.get(1), 92); +} diff --git a/crates/salsa/tests/variadic.rs b/crates/salsa/tests/variadic.rs new file mode 100644 index 0000000000..cb857844eb --- /dev/null +++ b/crates/salsa/tests/variadic.rs @@ -0,0 +1,51 @@ +#[salsa::query_group(HelloWorld)] +trait HelloWorldDatabase: salsa::Database { + #[salsa::input] + fn input(&self, a: u32, b: u32) -> u32; + + fn none(&self) -> u32; + + fn one(&self, k: u32) -> u32; + + fn two(&self, a: u32, b: u32) -> u32; + + fn trailing(&self, a: u32, b: u32) -> u32; +} + +fn none(_db: &dyn HelloWorldDatabase) -> u32 { + 22 +} + +fn one(_db: &dyn HelloWorldDatabase, k: u32) -> u32 { + k * 2 +} + +fn two(_db: &dyn HelloWorldDatabase, a: u32, b: u32) -> u32 { + a * b +} + +fn trailing(_db: &dyn HelloWorldDatabase, a: u32, b: u32) -> u32 { + a - b +} + +#[salsa::database(HelloWorld)] +#[derive(Default)] +struct DatabaseStruct { + storage: salsa::Storage, +} + +impl salsa::Database for DatabaseStruct {} + +#[test] +fn execute() { + let mut db = DatabaseStruct::default(); + + // test what happens with inputs: + db.set_input(1, 2, 3); + assert_eq!(db.input(1, 2), 3); + + assert_eq!(db.none(), 22); + assert_eq!(db.one(11), 22); + assert_eq!(db.two(11, 2), 22); + assert_eq!(db.trailing(24, 2), 22); +} diff --git a/crates/sourcegen/src/lib.rs b/crates/sourcegen/src/lib.rs index 18fa77fd97..295b716b4e 100644 --- a/crates/sourcegen/src/lib.rs +++ b/crates/sourcegen/src/lib.rs @@ -69,7 +69,7 @@ impl CommentBlock { panic!("Use plain (non-doc) comments with tags like {tag}:\n {first}"); } - block.id = id.trim().to_string(); + block.id = id.trim().to_owned(); true }); blocks @@ -93,7 +93,7 @@ impl CommentBlock { if let Some(' ') = contents.chars().next() { contents = &contents[1..]; } - block.contents.push(contents.to_string()); + block.contents.push(contents.to_owned()); } None => { if !block.contents.is_empty() { @@ -167,6 +167,7 @@ pub fn add_preamble(generator: &'static str, mut text: String) -> String { /// Checks that the `file` has the specified `contents`. If that is not the /// case, updates the file and then fails the test. +#[allow(clippy::print_stderr)] pub fn ensure_file_contents(file: &Path, contents: &str) { if let Ok(old_contents) = fs::read_to_string(file) { if normalize_newlines(&old_contents) == normalize_newlines(contents) { diff --git a/crates/span/Cargo.toml b/crates/span/Cargo.toml index a4abba29bb..7093f3a691 100644 --- a/crates/span/Cargo.toml +++ b/crates/span/Cargo.toml @@ -11,7 +11,7 @@ authors.workspace = true [dependencies] la-arena.workspace = true -rust-analyzer-salsa.workspace = true +salsa.workspace = true # local deps diff --git a/crates/span/src/lib.rs b/crates/span/src/lib.rs index 6796dc4188..7763d75cc9 100644 --- a/crates/span/src/lib.rs +++ b/crates/span/src/lib.rs @@ -68,26 +68,9 @@ impl fmt::Display for Span { } } -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct SyntaxContextId(InternId); -impl fmt::Debug for SyntaxContextId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - if *self == Self::SELF_REF { - f.debug_tuple("SyntaxContextId") - .field(&{ - #[derive(Debug)] - #[allow(non_camel_case_types)] - struct SELF_REF; - SELF_REF - }) - .finish() - } else { - f.debug_tuple("SyntaxContextId").field(&self.0).finish() - } - } -} - impl salsa::InternKey for SyntaxContextId { fn from_intern_id(v: salsa::InternId) -> Self { SyntaxContextId(v) @@ -106,10 +89,6 @@ impl fmt::Display for SyntaxContextId { // inherent trait impls please tyvm impl SyntaxContextId { pub const ROOT: Self = SyntaxContextId(unsafe { InternId::new_unchecked(0) }); - // veykril(HACK): FIXME salsa doesn't allow us fetching the id of the current input to be allocated so - // we need a special value that behaves as the current context. - pub const SELF_REF: Self = - SyntaxContextId(unsafe { InternId::new_unchecked(InternId::MAX - 1) }); pub fn is_root(self) -> bool { self == Self::ROOT diff --git a/crates/stdx/Cargo.toml b/crates/stdx/Cargo.toml index 2e3f9113b0..6cca116335 100644 --- a/crates/stdx/Cargo.toml +++ b/crates/stdx/Cargo.toml @@ -29,4 +29,4 @@ winapi = { version = "0.3.9", features = ["winerror"] } # default = [ "backtrace" ] [lints] -workspace = true \ No newline at end of file +workspace = true diff --git a/crates/stdx/src/anymap.rs b/crates/stdx/src/anymap.rs index 9990f8b086..899cd8ac6b 100644 --- a/crates/stdx/src/anymap.rs +++ b/crates/stdx/src/anymap.rs @@ -54,12 +54,13 @@ use core::any::{Any, TypeId}; use core::hash::BuildHasherDefault; use core::marker::PhantomData; -use ::std::collections::hash_map::{self, HashMap}; +use ::std::collections::hash_map; /// Raw access to the underlying `HashMap`. /// /// This alias is provided for convenience because of the ugly third generic parameter. -pub type RawMap = HashMap, BuildHasherDefault>; +#[allow(clippy::disallowed_types)] // Uses a custom hasher +pub type RawMap = hash_map::HashMap, BuildHasherDefault>; /// A collection containing zero or one values for any given type and allowing convenient, /// type-safe access to those values. diff --git a/crates/stdx/src/lib.rs b/crates/stdx/src/lib.rs index 07b7827228..9a9ebae74e 100644 --- a/crates/stdx/src/lib.rs +++ b/crates/stdx/src/lib.rs @@ -23,12 +23,14 @@ pub fn is_ci() -> bool { } #[must_use] +#[allow(clippy::print_stderr)] pub fn timeit(label: &'static str) -> impl Drop { let start = Instant::now(); defer(move || eprintln!("{}: {:.2?}", label, start.elapsed())) } /// Prints backtrace to stderr, useful for debugging. +#[allow(clippy::print_stderr)] pub fn print_backtrace() { #[cfg(feature = "backtrace")] eprintln!("{:?}", backtrace::Backtrace::new()); diff --git a/crates/stdx/src/panic_context.rs b/crates/stdx/src/panic_context.rs index c3e8813b0e..cf3d85b4da 100644 --- a/crates/stdx/src/panic_context.rs +++ b/crates/stdx/src/panic_context.rs @@ -18,6 +18,7 @@ pub struct PanicContext { } impl PanicContext { + #[allow(clippy::print_stderr)] fn init() { let default_hook = panic::take_hook(); let hook = move |panic_info: &panic::PanicInfo<'_>| { @@ -43,7 +44,7 @@ impl Drop for PanicContext { fn with_ctx(f: impl FnOnce(&mut Vec)) { thread_local! { - static CTX: RefCell> = RefCell::new(Vec::new()); + static CTX: RefCell> = const { RefCell::new(Vec::new()) }; } CTX.with(|ctx| f(&mut ctx.borrow_mut())); } diff --git a/crates/stdx/src/rand.rs b/crates/stdx/src/rand.rs index 64aa57eae0..115a073dab 100644 --- a/crates/stdx/src/rand.rs +++ b/crates/stdx/src/rand.rs @@ -14,8 +14,7 @@ pub fn shuffle(slice: &mut [T], mut rand_index: impl FnMut(usize) -> usize) { } pub fn seed() -> u64 { - use std::collections::hash_map::RandomState; use std::hash::{BuildHasher, Hasher}; - - RandomState::new().build_hasher().finish() + #[allow(clippy::disallowed_types)] + std::collections::hash_map::RandomState::new().build_hasher().finish() } diff --git a/crates/syntax/src/ast/edit_in_place.rs b/crates/syntax/src/ast/edit_in_place.rs index c9944b75b0..bc9c54d0b7 100644 --- a/crates/syntax/src/ast/edit_in_place.rs +++ b/crates/syntax/src/ast/edit_in_place.rs @@ -627,6 +627,8 @@ impl ast::Impl { } impl ast::AssocItemList { + /// Adds a new associated item after all of the existing associated items. + /// /// Attention! This function does align the first line of `item` with respect to `self`, /// but it does _not_ change indentation of other lines (if any). pub fn add_item(&self, item: ast::AssocItem) { @@ -650,6 +652,46 @@ impl ast::AssocItemList { ]; ted::insert_all(position, elements); } + + /// Adds a new associated item at the start of the associated item list. + /// + /// Attention! This function does align the first line of `item` with respect to `self`, + /// but it does _not_ change indentation of other lines (if any). + pub fn add_item_at_start(&self, item: ast::AssocItem) { + match self.assoc_items().next() { + Some(first_item) => { + let indent = IndentLevel::from_node(first_item.syntax()); + let before = Position::before(first_item.syntax()); + + ted::insert_all( + before, + vec![ + item.syntax().clone().into(), + make::tokens::whitespace(&format!("\n\n{indent}")).into(), + ], + ) + } + None => { + let (indent, position, whitespace) = match self.l_curly_token() { + Some(l_curly) => { + normalize_ws_between_braces(self.syntax()); + (IndentLevel::from_token(&l_curly) + 1, Position::after(&l_curly), "\n") + } + None => (IndentLevel::single(), Position::first_child_of(self.syntax()), ""), + }; + + let mut elements = vec![]; + + // Avoid pushing an empty whitespace token + if !indent.is_zero() || !whitespace.is_empty() { + elements.push(make::tokens::whitespace(&format!("{whitespace}{indent}")).into()) + } + elements.push(item.syntax().clone().into()); + + ted::insert_all(position, elements) + } + }; + } } impl ast::Fn { diff --git a/crates/syntax/src/ast/make.rs b/crates/syntax/src/ast/make.rs index d5eda8f15e..120d801c8d 100644 --- a/crates/syntax/src/ast/make.rs +++ b/crates/syntax/src/ast/make.rs @@ -68,6 +68,9 @@ pub mod ext { pub fn expr_ty_new(ty: &ast::Type) -> ast::Expr { expr_from_text(&format!("{ty}::new()")) } + pub fn expr_self() -> ast::Expr { + expr_from_text("self") + } pub fn zero_number() -> ast::Expr { expr_from_text("0") @@ -236,24 +239,21 @@ fn merge_where_clause( pub fn impl_( generic_params: Option, - generic_args: Option, + generic_args: Option, path_type: ast::Type, where_clause: Option, body: Option>>, ) -> ast::Impl { - let (gen_params, tr_gen_args) = match (generic_params, generic_args) { - (None, None) => (String::new(), String::new()), - (None, Some(args)) => (String::new(), args.to_generic_args().to_string()), - (Some(params), None) => (params.to_string(), params.to_generic_args().to_string()), - (Some(params), Some(args)) => match merge_gen_params(Some(params.clone()), Some(args)) { - Some(merged) => (params.to_string(), merged.to_generic_args().to_string()), - None => (params.to_string(), String::new()), - }, - }; + let gen_args = generic_args.map_or_else(String::new, |it| it.to_string()); + + let gen_params = generic_params.map_or_else(String::new, |it| it.to_string()); + + let body_newline = + if where_clause.is_some() && body.is_none() { "\n".to_owned() } else { String::new() }; let where_clause = match where_clause { - Some(pr) => pr.to_string(), - None => " ".to_string(), + Some(pr) => format!("\n{pr}\n"), + None => " ".to_owned(), }; let body = match body { @@ -261,7 +261,9 @@ pub fn impl_( None => String::new(), }; - ast_from_text(&format!("impl{gen_params} {path_type}{tr_gen_args}{where_clause}{{{}}}", body)) + ast_from_text(&format!( + "impl{gen_params} {path_type}{gen_args}{where_clause}{{{body_newline}{body}}}" + )) } pub fn impl_trait( @@ -282,22 +284,27 @@ pub fn impl_trait( let trait_gen_args = trait_gen_args.map(|args| args.to_string()).unwrap_or_default(); let type_gen_args = type_gen_args.map(|args| args.to_string()).unwrap_or_default(); - let gen_params = match merge_gen_params(trait_gen_params, type_gen_params) { - Some(pars) => pars.to_string(), - None => String::new(), - }; + let gen_params = merge_gen_params(trait_gen_params, type_gen_params) + .map_or_else(String::new, |it| it.to_string()); let is_negative = if is_negative { "! " } else { "" }; + let body_newline = + if (ty_where_clause.is_some() || trait_where_clause.is_some()) && body.is_none() { + "\n".to_owned() + } else { + String::new() + }; + let where_clause = merge_where_clause(ty_where_clause, trait_where_clause) - .map_or_else(|| " ".to_string(), |wc| format!("\n{}\n", wc)); + .map_or_else(|| " ".to_owned(), |wc| format!("\n{}\n", wc)); let body = match body { Some(bd) => bd.iter().map(|elem| elem.to_string()).join(""), None => String::new(), }; - ast_from_text(&format!("{is_unsafe}impl{gen_params} {is_negative}{path_type}{trait_gen_args} for {ty}{type_gen_args}{where_clause}{{{}}}" , body)) + ast_from_text(&format!("{is_unsafe}impl{gen_params} {is_negative}{path_type}{trait_gen_args} for {ty}{type_gen_args}{where_clause}{{{body_newline}{body}}}")) } pub fn impl_trait_type(bounds: ast::TypeBoundList) -> ast::ImplTraitType { @@ -371,7 +378,7 @@ pub fn use_tree( alias: Option, add_star: bool, ) -> ast::UseTree { - let mut buf = "use ".to_string(); + let mut buf = "use ".to_owned(); buf += &path.syntax().to_string(); if let Some(use_tree_list) = use_tree_list { format_to!(buf, "::{use_tree_list}"); @@ -437,7 +444,7 @@ pub fn block_expr( stmts: impl IntoIterator, tail_expr: Option, ) -> ast::BlockExpr { - let mut buf = "{\n".to_string(); + let mut buf = "{\n".to_owned(); for stmt in stmts.into_iter() { format_to!(buf, " {stmt}\n"); } @@ -452,7 +459,7 @@ pub fn async_move_block_expr( stmts: impl IntoIterator, tail_expr: Option, ) -> ast::BlockExpr { - let mut buf = "async move {\n".to_string(); + let mut buf = "async move {\n".to_owned(); for stmt in stmts.into_iter() { format_to!(buf, " {stmt}\n"); } @@ -475,7 +482,7 @@ pub fn hacky_block_expr( elements: impl IntoIterator, tail_expr: Option, ) -> ast::BlockExpr { - let mut buf = "{\n".to_string(); + let mut buf = "{\n".to_owned(); for node_or_token in elements.into_iter() { match node_or_token { rowan::NodeOrToken::Node(n) => format_to!(buf, " {n}\n"), @@ -903,7 +910,12 @@ pub fn trait_( ast_from_text(&text) } -pub fn type_bound(bound: &str) -> ast::TypeBound { +// FIXME: remove when no one depends on `generate_impl_text_inner` +pub fn type_bound_text(bound: &str) -> ast::TypeBound { + ast_from_text(&format!("fn f() {{ }}")) +} + +pub fn type_bound(bound: ast::Type) -> ast::TypeBound { ast_from_text(&format!("fn f() {{ }}")) } diff --git a/crates/syntax/src/fuzz.rs b/crates/syntax/src/fuzz.rs index 239a89f9b2..2873867179 100644 --- a/crates/syntax/src/fuzz.rs +++ b/crates/syntax/src/fuzz.rs @@ -34,7 +34,7 @@ impl CheckReparse { let mut lines = data.lines(); let delete_start = usize::from_str(lines.next()?).ok()? + PREFIX.len(); let delete_len = usize::from_str(lines.next()?).ok()?; - let insert = lines.next()?.to_string(); + let insert = lines.next()?.to_owned(); let text = lines.collect::>().join("\n"); let text = format!("{PREFIX}{text}{SUFFIX}"); text.get(delete_start..delete_start.checked_add(delete_len)?)?; // make sure delete is a valid range @@ -46,6 +46,7 @@ impl CheckReparse { Some(CheckReparse { text, edit, edited_text }) } + #[allow(clippy::print_stderr)] pub fn run(&self) { let parse = SourceFile::parse(&self.text); let new_parse = parse.reparse(&self.edit); diff --git a/crates/syntax/src/lib.rs b/crates/syntax/src/lib.rs index 62a0261d7a..960889b742 100644 --- a/crates/syntax/src/lib.rs +++ b/crates/syntax/src/lib.rs @@ -432,7 +432,7 @@ fn api_walkthrough() { WalkEvent::Enter(node) => { let text = match &node { NodeOrToken::Node(it) => it.text().to_string(), - NodeOrToken::Token(it) => it.text().to_string(), + NodeOrToken::Token(it) => it.text().to_owned(), }; format_to!(buf, "{:indent$}{:?} {:?}\n", " ", text, node.kind(), indent = indent); indent += 2; diff --git a/crates/syntax/src/parsing.rs b/crates/syntax/src/parsing.rs index 047e670c9f..1250b5274c 100644 --- a/crates/syntax/src/parsing.rs +++ b/crates/syntax/src/parsing.rs @@ -28,7 +28,7 @@ pub(crate) fn build_tree( parser::StrStep::Enter { kind } => builder.start_node(kind), parser::StrStep::Exit => builder.finish_node(), parser::StrStep::Error { msg, pos } => { - builder.error(msg.to_string(), pos.try_into().unwrap()) + builder.error(msg.to_owned(), pos.try_into().unwrap()) } }); diff --git a/crates/syntax/src/parsing/reparsing.rs b/crates/syntax/src/parsing/reparsing.rs index 0ddc641711..14715b5725 100644 --- a/crates/syntax/src/parsing/reparsing.rs +++ b/crates/syntax/src/parsing/reparsing.rs @@ -105,7 +105,7 @@ fn get_text_after_edit(element: SyntaxElement, edit: &Indel) -> String { let edit = Indel::replace(edit.delete - element.text_range().start(), edit.insert.clone()); let mut text = match element { - NodeOrToken::Token(token) => token.text().to_string(), + NodeOrToken::Token(token) => token.text().to_owned(), NodeOrToken::Node(node) => node.text().to_string(), }; edit.apply(&mut text); diff --git a/crates/syntax/src/ptr.rs b/crates/syntax/src/ptr.rs index b716d36706..fb8aee9c3b 100644 --- a/crates/syntax/src/ptr.rs +++ b/crates/syntax/src/ptr.rs @@ -36,7 +36,7 @@ impl std::fmt::Debug for AstPtr { impl Copy for AstPtr {} impl Clone for AstPtr { fn clone(&self) -> AstPtr { - AstPtr { raw: self.raw, _ty: PhantomData } + *self } } diff --git a/crates/syntax/src/tests/sourcegen_ast.rs b/crates/syntax/src/tests/sourcegen_ast.rs index c2e921e4b6..2fd7a47349 100644 --- a/crates/syntax/src/tests/sourcegen_ast.rs +++ b/crates/syntax/src/tests/sourcegen_ast.rs @@ -3,14 +3,12 @@ //! Specifically, it generates the `SyntaxKind` enum and a number of newtype //! wrappers around `SyntaxNode` which implement `syntax::AstNode`. -use std::{ - collections::{BTreeSet, HashSet}, - fmt::Write, -}; +use std::{collections::BTreeSet, fmt::Write}; use itertools::Itertools; use proc_macro2::{Punct, Spacing}; use quote::{format_ident, quote}; +use rustc_hash::FxHashSet; use ungrammar::{Grammar, Rule}; use crate::tests::ast_src::{ @@ -278,7 +276,7 @@ fn generate_nodes(kinds: KindsSrc<'_>, grammar: &AstSrc) -> String { } }); - let defined_nodes: HashSet<_> = node_names.collect(); + let defined_nodes: FxHashSet<_> = node_names.collect(); for node in kinds .nodes @@ -575,7 +573,7 @@ fn lower(grammar: &Grammar) -> AstSrc { tokens: "Whitespace Comment String ByteString CString IntNumber FloatNumber Char Byte Ident" .split_ascii_whitespace() - .map(|it| it.to_string()) + .map(|it| it.to_owned()) .collect::>(), ..Default::default() }; @@ -818,7 +816,7 @@ fn extract_struct_trait(node: &mut AstNodeSrc, trait_name: &str, methods: &[&str } } if to_remove.len() == methods.len() { - node.traits.push(trait_name.to_string()); + node.traits.push(trait_name.to_owned()); node.remove_field(to_remove); } } diff --git a/crates/test-fixture/src/lib.rs b/crates/test-fixture/src/lib.rs index b5ff7a1bf5..28e757e81b 100644 --- a/crates/test-fixture/src/lib.rs +++ b/crates/test-fixture/src/lib.rs @@ -260,7 +260,7 @@ impl ChangeFixture { file_id = FileId::from_raw(file_id.index() + 1); let mut fs = FileSet::default(); - fs.insert(core_file, VfsPath::new_virtual_path("/sysroot/core/lib.rs".to_string())); + fs.insert(core_file, VfsPath::new_virtual_path("/sysroot/core/lib.rs".to_owned())); roots.push(SourceRoot::new_library(fs)); source_change.change_file(core_file, Some(mini_core.source_code().into())); @@ -270,7 +270,7 @@ impl ChangeFixture { let core_crate = crate_graph.add_crate_root( core_file, Edition::Edition2021, - Some(CrateDisplayName::from_canonical_name("core".to_string())), + Some(CrateDisplayName::from_canonical_name("core".to_owned())), None, Default::default(), Default::default(), @@ -304,7 +304,7 @@ impl ChangeFixture { let mut fs = FileSet::default(); fs.insert( proc_lib_file, - VfsPath::new_virtual_path("/sysroot/proc_macros/lib.rs".to_string()), + VfsPath::new_virtual_path("/sysroot/proc_macros/lib.rs".to_owned()), ); roots.push(SourceRoot::new_library(fs)); @@ -315,7 +315,7 @@ impl ChangeFixture { let proc_macros_crate = crate_graph.add_crate_root( proc_lib_file, Edition::Edition2021, - Some(CrateDisplayName::from_canonical_name("proc_macros".to_string())), + Some(CrateDisplayName::from_canonical_name("proc_macros".to_owned())), None, Default::default(), Default::default(), @@ -598,7 +598,7 @@ impl ProcMacroExpander for MirrorProcMacroExpander { }; token_trees.push(tt); } - Subtree { delimiter: input.delimiter, token_trees } + Subtree { delimiter: input.delimiter, token_trees: token_trees.into_boxed_slice() } } Ok(traverse(input)) } diff --git a/crates/test-utils/src/bench_fixture.rs b/crates/test-utils/src/bench_fixture.rs index 9296fd2e68..d83f957686 100644 --- a/crates/test-utils/src/bench_fixture.rs +++ b/crates/test-utils/src/bench_fixture.rs @@ -12,7 +12,7 @@ pub fn big_struct() -> String { } pub fn big_struct_n(n: u32) -> String { - let mut buf = "pub struct RegisterBlock {".to_string(); + let mut buf = "pub struct RegisterBlock {".to_owned(); for i in 0..n { format_to!(buf, " /// Doc comment for {}.\n", i); format_to!(buf, " pub s{}: S{},\n", i, i); diff --git a/crates/test-utils/src/fixture.rs b/crates/test-utils/src/fixture.rs index 3f8b5a0896..595281336d 100644 --- a/crates/test-utils/src/fixture.rs +++ b/crates/test-utils/src/fixture.rs @@ -178,13 +178,13 @@ impl FixtureWithProjectMeta { if let Some(meta) = fixture.strip_prefix("//- toolchain:") { let (meta, remain) = meta.split_once('\n').unwrap(); - toolchain = Some(meta.trim().to_string()); + toolchain = Some(meta.trim().to_owned()); fixture = remain; } if let Some(meta) = fixture.strip_prefix("//- proc_macros:") { let (meta, remain) = meta.split_once('\n').unwrap(); - proc_macro_names = meta.split(',').map(|it| it.trim().to_string()).collect(); + proc_macro_names = meta.split(',').map(|it| it.trim().to_owned()).collect(); fixture = remain; } @@ -234,7 +234,7 @@ impl FixtureWithProjectMeta { let meta = meta["//-".len()..].trim(); let mut components = meta.split_ascii_whitespace(); - let path = components.next().expect("fixture meta must start with a path").to_string(); + let path = components.next().expect("fixture meta must start with a path").to_owned(); assert!(path.starts_with('/'), "fixture path does not start with `/`: {path:?}"); let mut krate = None; @@ -246,7 +246,7 @@ impl FixtureWithProjectMeta { let mut introduce_new_source_root = None; let mut library = false; let mut target_data_layout = Some( - "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128".to_string(), + "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128".to_owned(), ); for component in components { if component == "library" { @@ -257,22 +257,22 @@ impl FixtureWithProjectMeta { let (key, value) = component.split_once(':').unwrap_or_else(|| panic!("invalid meta line: {meta:?}")); match key { - "crate" => krate = Some(value.to_string()), - "deps" => deps = value.split(',').map(|it| it.to_string()).collect(), + "crate" => krate = Some(value.to_owned()), + "deps" => deps = value.split(',').map(|it| it.to_owned()).collect(), "extern-prelude" => { if value.is_empty() { extern_prelude = Some(Vec::new()); } else { extern_prelude = - Some(value.split(',').map(|it| it.to_string()).collect::>()); + Some(value.split(',').map(|it| it.to_owned()).collect::>()); } } - "edition" => edition = Some(value.to_string()), + "edition" => edition = Some(value.to_owned()), "cfg" => { for entry in value.split(',') { match entry.split_once('=') { - Some((k, v)) => cfgs.push((k.to_string(), Some(v.to_string()))), - None => cfgs.push((entry.to_string(), None)), + Some((k, v)) => cfgs.push((k.to_owned(), Some(v.to_owned()))), + None => cfgs.push((entry.to_owned(), None)), } } } @@ -283,8 +283,8 @@ impl FixtureWithProjectMeta { } } } - "new_source_root" => introduce_new_source_root = Some(value.to_string()), - "target_data_layout" => target_data_layout = Some(value.to_string()), + "new_source_root" => introduce_new_source_root = Some(value.to_owned()), + "target_data_layout" => target_data_layout = Some(value.to_owned()), _ => panic!("bad component: {component:?}"), } } @@ -381,7 +381,7 @@ impl MiniCore { let (flag, deps) = line.split_once(':').unwrap(); let flag = flag.trim(); - self.valid_flags.push(flag.to_string()); + self.valid_flags.push(flag.to_owned()); implications.extend( iter::repeat(flag) .zip(deps.split(", ").map(str::trim).filter(|dep| !dep.is_empty())), @@ -401,7 +401,7 @@ impl MiniCore { let mut changed = false; for &(u, v) in &implications { if self.has_flag(u) && !self.has_flag(v) { - self.activated_flags.push(v.to_string()); + self.activated_flags.push(v.to_owned()); changed = true; } } @@ -486,9 +486,9 @@ fn parse_fixture_gets_full_meta() { mod m; "#, ); - assert_eq!(toolchain, Some("nightly".to_string())); - assert_eq!(proc_macro_names, vec!["identity".to_string()]); - assert_eq!(mini_core.unwrap().activated_flags, vec!["coerce_unsized".to_string()]); + assert_eq!(toolchain, Some("nightly".to_owned())); + assert_eq!(proc_macro_names, vec!["identity".to_owned()]); + assert_eq!(mini_core.unwrap().activated_flags, vec!["coerce_unsized".to_owned()]); assert_eq!(1, parsed.len()); let meta = &parsed[0]; diff --git a/crates/test-utils/src/lib.rs b/crates/test-utils/src/lib.rs index e48b273130..b750107803 100644 --- a/crates/test-utils/src/lib.rs +++ b/crates/test-utils/src/lib.rs @@ -7,6 +7,7 @@ //! * marks (see the eponymous module). #![warn(rust_2018_idioms, unused_lifetimes)] +#![allow(clippy::print_stderr)] mod assert_linear; pub mod bench_fixture; @@ -163,7 +164,7 @@ pub fn extract_tags(mut text: &str, tag: &str) -> (Vec<(TextRange, Option').unwrap(); let attr = text[open.len()..close_open].trim(); - let attr = if attr.is_empty() { None } else { Some(attr.to_string()) }; + let attr = if attr.is_empty() { None } else { Some(attr.to_owned()) }; text = &text[close_open + '>'.len_utf8()..]; let from = TextSize::of(&res); stack.push((from, attr)); @@ -325,7 +326,7 @@ fn extract_line_annotations(mut line: &str) -> Vec { content = &content["file".len()..]; } - let content = content.trim_start().to_string(); + let content = content.trim_start().to_owned(); let annotation = if continuation { LineAnnotation::Continuation { offset: range.end(), content } @@ -424,7 +425,7 @@ pub fn format_diff(chunks: Vec>) -> String { /// /// A benchmark test looks like this: /// -/// ``` +/// ```ignore /// #[test] /// fn benchmark_foo() { /// if skip_slow_tests() { return; } diff --git a/crates/test-utils/src/minicore.rs b/crates/test-utils/src/minicore.rs index 9c25d88cb8..23a3a7e0af 100644 --- a/crates/test-utils/src/minicore.rs +++ b/crates/test-utils/src/minicore.rs @@ -328,7 +328,6 @@ pub mod convert { } pub mod mem { - // region:drop // region:manually_drop #[lang = "manually_drop"] #[repr(transparent)] @@ -353,6 +352,7 @@ pub mod mem { // endregion:manually_drop + // region:drop pub fn drop(_x: T) {} pub const fn replace(dest: &mut T, src: T) -> T { unsafe { diff --git a/crates/text-edit/src/lib.rs b/crates/text-edit/src/lib.rs index fb52a50f0b..e2ff373c1b 100644 --- a/crates/text-edit/src/lib.rs +++ b/crates/text-edit/src/lib.rs @@ -231,11 +231,11 @@ mod tests { #[test] fn test_apply() { - let mut text = "_11h1_2222_xx3333_4444_6666".to_string(); + let mut text = "_11h1_2222_xx3333_4444_6666".to_owned(); let mut builder = TextEditBuilder::default(); - builder.replace(range(3, 4), "1".to_string()); + builder.replace(range(3, 4), "1".to_owned()); builder.delete(range(11, 13)); - builder.insert(22.into(), "_5555".to_string()); + builder.insert(22.into(), "_5555".to_owned()); let text_edit = builder.finish(); text_edit.apply(&mut text); diff --git a/crates/tt/src/buffer.rs b/crates/tt/src/buffer.rs index cade2e9f67..cd41af03c6 100644 --- a/crates/tt/src/buffer.rs +++ b/crates/tt/src/buffer.rs @@ -106,7 +106,7 @@ impl<'t, Span> TokenBuffer<'t, Span> { for (child_idx, (subtree, tt)) in children { let idx = TokenBuffer::new_inner( - subtree.token_trees.as_slice(), + &*subtree.token_trees, buffers, Some(EntryPtr(EntryId(res), child_idx + 1)), ); diff --git a/crates/tt/src/lib.rs b/crates/tt/src/lib.rs index b3b0eeda75..9004bff53a 100644 --- a/crates/tt/src/lib.rs +++ b/crates/tt/src/lib.rs @@ -23,10 +23,10 @@ pub enum TokenTree { } impl_from!(Leaf, Subtree for TokenTree); impl TokenTree { - pub const fn empty(span: S) -> Self { + pub fn empty(span: S) -> Self { Self::Subtree(Subtree { delimiter: Delimiter::invisible_spanned(span), - token_trees: vec![], + token_trees: Box::new([]), }) } @@ -34,7 +34,7 @@ impl TokenTree { match self { TokenTree::Leaf(_) => Subtree { delimiter: Delimiter::invisible_delim_spanned(span), - token_trees: vec![self], + token_trees: Box::new([self]), }, TokenTree::Subtree(s) => s, } @@ -69,25 +69,35 @@ impl_from!(Literal, Punct, Ident for Leaf); #[derive(Clone, PartialEq, Eq, Hash)] pub struct Subtree { pub delimiter: Delimiter, - pub token_trees: Vec>, + pub token_trees: Box<[TokenTree]>, } impl Subtree { - pub const fn empty(span: DelimSpan) -> Self { - Subtree { delimiter: Delimiter::invisible_delim_spanned(span), token_trees: vec![] } + pub fn empty(span: DelimSpan) -> Self { + Subtree { delimiter: Delimiter::invisible_delim_spanned(span), token_trees: Box::new([]) } } - pub fn visit_ids(&mut self, f: &mut impl FnMut(S) -> S) { - self.delimiter.open = f(self.delimiter.open); - self.delimiter.close = f(self.delimiter.close); - self.token_trees.iter_mut().for_each(|tt| match tt { - crate::TokenTree::Leaf(leaf) => match leaf { - crate::Leaf::Literal(it) => it.span = f(it.span), - crate::Leaf::Punct(it) => it.span = f(it.span), - crate::Leaf::Ident(it) => it.span = f(it.span), - }, - crate::TokenTree::Subtree(s) => s.visit_ids(f), - }) + /// This is slow, and should be avoided, as it will always reallocate! + pub fn push(&mut self, subtree: TokenTree) { + let mut mutable_trees = std::mem::take(&mut self.token_trees).into_vec(); + + // Reserve exactly space for one element, to avoid `into_boxed_slice` having to reallocate again. + mutable_trees.reserve_exact(1); + mutable_trees.push(subtree); + + self.token_trees = mutable_trees.into_boxed_slice(); + } +} + +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct SubtreeBuilder { + pub delimiter: Delimiter, + pub token_trees: Vec>, +} + +impl SubtreeBuilder { + pub fn build(self) -> Subtree { + Subtree { delimiter: self.delimiter, token_trees: self.token_trees.into_boxed_slice() } } } @@ -241,7 +251,7 @@ impl fmt::Display for Subtree { }; f.write_str(l)?; let mut needs_space = false; - for tt in &self.token_trees { + for tt in self.token_trees.iter() { if needs_space { f.write_str(" ")?; } @@ -316,7 +326,7 @@ impl Subtree { let mut res = String::new(); res.push_str(delim.0); let mut last = None; - for child in &self.token_trees { + for child in self.token_trees.iter() { let s = match child { TokenTree::Leaf(it) => { let s = match it { @@ -326,11 +336,11 @@ impl Subtree { }; match (it, last) { (Leaf::Ident(_), Some(&TokenTree::Leaf(Leaf::Ident(_)))) => { - " ".to_string() + &s + " ".to_owned() + &s } (Leaf::Punct(_), Some(TokenTree::Leaf(Leaf::Punct(punct)))) => { if punct.spacing == Spacing::Alone { - " ".to_string() + &s + " ".to_owned() + &s } else { s } diff --git a/crates/vfs/src/loader.rs b/crates/vfs/src/loader.rs index e49849d230..c3d1ff7271 100644 --- a/crates/vfs/src/loader.rs +++ b/crates/vfs/src/loader.rs @@ -201,7 +201,7 @@ impl Directories { /// ``` fn dirs(base: AbsPathBuf, exclude: &[&str]) -> Directories { let exclude = exclude.iter().map(|it| base.join(it)).collect::>(); - Directories { extensions: vec!["rs".to_string()], include: vec![base], exclude } + Directories { extensions: vec!["rs".to_owned()], include: vec![base], exclude } } impl fmt::Debug for Message { diff --git a/crates/vfs/src/vfs_path/tests.rs b/crates/vfs/src/vfs_path/tests.rs index 510e021e89..2d89362ee0 100644 --- a/crates/vfs/src/vfs_path/tests.rs +++ b/crates/vfs/src/vfs_path/tests.rs @@ -2,29 +2,29 @@ use super::*; #[test] fn virtual_path_extensions() { - assert_eq!(VirtualPath("/".to_string()).name_and_extension(), None); + assert_eq!(VirtualPath("/".to_owned()).name_and_extension(), None); assert_eq!( - VirtualPath("/directory".to_string()).name_and_extension(), + VirtualPath("/directory".to_owned()).name_and_extension(), Some(("directory", None)) ); assert_eq!( - VirtualPath("/directory/".to_string()).name_and_extension(), + VirtualPath("/directory/".to_owned()).name_and_extension(), Some(("directory", None)) ); assert_eq!( - VirtualPath("/directory/file".to_string()).name_and_extension(), + VirtualPath("/directory/file".to_owned()).name_and_extension(), Some(("file", None)) ); assert_eq!( - VirtualPath("/directory/.file".to_string()).name_and_extension(), + VirtualPath("/directory/.file".to_owned()).name_and_extension(), Some((".file", None)) ); assert_eq!( - VirtualPath("/directory/.file.rs".to_string()).name_and_extension(), + VirtualPath("/directory/.file.rs".to_owned()).name_and_extension(), Some((".file", Some("rs"))) ); assert_eq!( - VirtualPath("/directory/file.rs".to_string()).name_and_extension(), + VirtualPath("/directory/file.rs".to_owned()).name_and_extension(), Some(("file", Some("rs"))) ); } diff --git a/docs/dev/lsp-extensions.md b/docs/dev/lsp-extensions.md index bc558c2024..f3100ee194 100644 --- a/docs/dev/lsp-extensions.md +++ b/docs/dev/lsp-extensions.md @@ -1,5 +1,5 @@