diff --git a/crates/libanalysis/Cargo.toml b/crates/libanalysis/Cargo.toml index 4c92951b14..88f29d7c86 100644 --- a/crates/libanalysis/Cargo.toml +++ b/crates/libanalysis/Cargo.toml @@ -14,6 +14,7 @@ fst = "0.3.1" im = "12.0.0" libsyntax2 = { path = "../libsyntax2" } libeditor = { path = "../libeditor" } +salsa = { path = "../salsa" } [dev-dependencies] test_utils = { path = "../test_utils" } diff --git a/crates/libanalysis/src/db.rs b/crates/libanalysis/src/db.rs deleted file mode 100644 index d30e75fe2b..0000000000 --- a/crates/libanalysis/src/db.rs +++ /dev/null @@ -1,306 +0,0 @@ -use std::{ - hash::{Hash, Hasher}, - sync::Arc, - cell::RefCell, - fmt::Debug, - any::Any, -}; -use parking_lot::Mutex; -use libsyntax2::{File}; -use im; -use { - FileId, - imp::{FileResolverImp}, -}; - -#[derive(Debug)] -pub(crate) struct DbHost { - db: Arc, -} - -impl DbHost { - pub(crate) fn new() -> DbHost { - let db = Db { - file_resolver: FileResolverImp::default(), - files: im::HashMap::new(), - cache: Mutex::new(Cache::new()) - }; - DbHost { db: Arc::new(db) } - } - pub(crate) fn change_file(&mut self, file_id: FileId, text: Option) { - let db = self.db_mut(); - match text { - None => { - db.files.remove(&file_id); - } - Some(text) => { - db.files.insert(file_id, Arc::new(text)); - } - } - } - pub(crate) fn set_file_resolver(&mut self, file_resolver: FileResolverImp) { - let db = self.db_mut(); - db.file_resolver = file_resolver - } - pub(crate) fn query_ctx(&self) -> QueryCtx { - QueryCtx { - db: Arc::clone(&self.db), - stack: RefCell::new(Vec::new()), - trace: RefCell::new(Vec::new()), - } - } - fn db_mut(&mut self) -> &mut Db { - // NB: this "forks" the database - let db = Arc::make_mut(&mut self.db); - db.cache.get_mut().gen += 1; - db - } -} - -type QueryInvocationId = (u32, u64); -type Gen = u64; -type OutputHash = u64; - -fn id(params: &Q::Params) -> QueryInvocationId { - use std::collections::hash_map::DefaultHasher; - let mut hasher = DefaultHasher::new(); - params.hash(&mut hasher); - (Q::ID, hasher.finish()) -} -fn output_hash(output: &Q::Output) -> OutputHash { - use std::collections::hash_map::DefaultHasher; - let mut hasher = DefaultHasher::new(); - output.hash(&mut hasher); - hasher.finish() -} - -#[derive(Debug)] -pub(crate) struct Db { - file_resolver: FileResolverImp, - files: im::HashMap>, - cache: Mutex, -} - -impl Clone for Db { - fn clone(&self) -> Db { - Db { - file_resolver: self.file_resolver.clone(), - files: self.files.clone(), - cache: Mutex::new(Cache::new()), - } - } -} - -type QueryDeps = Vec<(QueryInvocationId, Arc, OutputHash)>; - -#[derive(Default, Debug)] -pub(crate) struct Cache { - gen: Gen, - green: im::HashMap, - deps: im::HashMap, - results: im::HashMap>, -} - - -#[allow(type_alias_bounds)] -pub(crate) type QueryCache = im::HashMap< - ::Params, - ::Output ->; - -impl Cache { - fn new() -> Cache { - Default::default() - } - - fn get_result(&self, id: QueryInvocationId) -> Q::Output - where - Q::Output: Clone - { - let res = &self.results[&id]; - let res = res.downcast_ref::().unwrap(); - res.clone() - } -} - -pub(crate) struct QueryCtx { - db: Arc, - stack: RefCell>, - pub(crate) trace: RefCell>, -} - -#[derive(Clone, Copy, Debug)] -pub(crate) struct TraceEvent { - pub(crate) query_id: u32, - pub(crate) kind: TraceEventKind -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub(crate) enum TraceEventKind { - Start, Evaluating, Finish -} - -impl QueryCtx { - pub(crate) fn get(&self, params: &Q::Params) -> Q::Output { - let me = id::(params); - self.trace(TraceEvent { query_id: Q::ID, kind: TraceEventKind::Start }); - let res = Q::get(self, params); - self.trace(TraceEvent { query_id: Q::ID, kind: TraceEventKind::Finish }); - { - let mut stack = self.stack.borrow_mut(); - if let Some((_, ref mut deps)) = stack.last_mut() { - let params = Arc::new(params.clone()); - deps.push((me, params, output_hash::(&res))); - } - } - - res - } - fn trace(&self, event: TraceEvent) { - self.trace.borrow_mut().push(event) - } -} - -pub(crate) trait Query { - const ID: u32; - type Params: Hash + Eq + Debug + Clone + Any + 'static; - type Output: Hash + Debug + Any + 'static; -} - -pub(crate) trait Get: Query { - fn get(ctx: &QueryCtx, params: &Self::Params) -> Self::Output; -} - -impl Get for Q -where - Q::Params: Clone, - Q::Output: Clone, -{ - fn get(ctx: &QueryCtx, params: &Self::Params) -> Self::Output { - if let Some(res) = try_reuse::(ctx, params) { - return res; - } - - let me = id::(params); - ctx.trace(TraceEvent { query_id: Q::ID, kind: TraceEventKind::Evaluating }); - ctx.stack.borrow_mut().push((me, Vec::new())); - let res = Self::eval(ctx, params); - let (also_me, deps) = ctx.stack.borrow_mut().pop().unwrap(); - assert_eq!(also_me, me); - let mut cache = ctx.db.cache.lock(); - cache.deps.insert(me, deps); - let gen = cache.gen; - let output_hash = output_hash::(&res); - let id = id::(params); - cache.green.insert(id, (gen, output_hash)); - cache.results.insert(me, Arc::new(res.clone())); - res - } -} - -fn try_reuse(ctx: &QueryCtx, params: &Q::Params) -> Option -where - Q::Params: Clone, - Q::Output: Clone, -{ - let id = id::(params); - let mut cache = ctx.db.cache.lock(); - let curr_gen = cache.gen; - let old_hash = match *cache.green.get(&id)? { - (gen, _) if gen == curr_gen => { - return Some(cache.get_result::(id)); - } - (_, hash) => hash, - }; - let deps_are_fresh = cache.deps[&id] - .iter() - .all(|&(dep_id, _, dep_hash)| { - match cache.green.get(&dep_id) { - //TODO: store the value of parameters, and re-execute the query - Some((gen, hash)) if gen == &curr_gen && hash == &dep_hash => true, - _ => false, - } - }); - if !deps_are_fresh { - return None; - } - cache.green.insert(id, (curr_gen, old_hash)); - Some(cache.get_result::(id)) -} - -pub(crate) trait Eval: Query -where - Self::Params: Clone, - Self::Output: Clone, -{ - fn eval(ctx: &QueryCtx, params: &Self::Params) -> Self::Output; -} - -#[derive(Debug)] -pub(crate) struct DbFiles { - db: Arc, -} - -impl Hash for DbFiles { - fn hash(&self, hasher: &mut H) { - self.db.cache.lock().gen.hash(hasher) - } -} - -impl DbFiles { - pub(crate) fn iter<'a>(&'a self) -> impl Iterator + 'a { - self.db.files.keys().cloned() - } - pub(crate) fn file_resolver(&self) -> FileResolverImp { - self.db.file_resolver.clone() - } -} - -pub(crate) enum Files {} -impl Query for Files { - const ID: u32 = 1; - type Params = (); - type Output = DbFiles; -} -impl Get for Files { - fn get(ctx: &QueryCtx, params: &()) -> DbFiles { - let res = DbFiles { db: Arc::clone(&ctx.db) }; - let id = id::(params); - let hash = output_hash::(&res); - let mut cache = ctx.db.cache.lock(); - let gen = cache.gen; - cache.green.insert(id, (gen, hash)); - res - } -} - -enum FileText {} -impl Query for FileText { - const ID: u32 = 10; - type Params = FileId; - type Output = Arc; -} -impl Get for FileText { - fn get(ctx: &QueryCtx, file_id: &FileId) -> Arc { - let res = ctx.db.files[file_id].clone(); - let id = id::(file_id); - let hash = output_hash::(&res); - let mut cache = ctx.db.cache.lock(); - let gen = cache.gen; - cache.green.insert(id, (gen, hash)); - res - } -} - -pub(crate) enum FileSyntax {} -impl Query for FileSyntax { - const ID: u32 = 20; - type Params = FileId; - type Output = File; -} -impl Eval for FileSyntax { - fn eval(ctx: &QueryCtx, file_id: &FileId) -> File { - let text = ctx.get::(file_id); - File::parse(&text) - } -} diff --git a/crates/libanalysis/src/db/mod.rs b/crates/libanalysis/src/db/mod.rs new file mode 100644 index 0000000000..f68aab61c0 --- /dev/null +++ b/crates/libanalysis/src/db/mod.rs @@ -0,0 +1,196 @@ +mod queries; + +use std::{ + hash::{Hash}, + sync::Arc, + fmt::Debug, + any::Any, + iter, +}; +use im; +use salsa; +use { + FileId, + imp::{FileResolverImp}, +}; + + +#[derive(Clone, Default)] +pub(crate) struct State { + pub(crate) resolver: FileResolverImp, + pub(crate) file_map: im::HashMap>, +} + +type Data = Arc; + +pub(crate) struct QueryCtx<'a> { + inner: &'a salsa::QueryCtx +} + +pub(crate) struct Db { + inner: salsa::Db +} + +struct GroundQuery { + id: u16, + f: fn(&State, &T) -> R, + h: fn(&R) -> u64, +} + +pub(crate) struct Query { + pub(crate) id: u16, + pub(crate) f: fn(QueryCtx, &T) -> R, +} + +impl Db { + pub(crate) fn new(state: State) -> Db { + Db { inner: salsa::Db::new(query_config(), state) } + } + pub(crate) fn state(&self) -> &State { + self.inner.ground_data() + } + pub(crate) fn with_state( + &self, + new_state: State, + updated_files: &[FileId], + file_set_changed: bool, + ) -> Db { + let mut inv = salsa::Invalidations::new(); + if file_set_changed { + inv.invalidate( + salsa::QueryTypeId(queries::FILE_SET.id), + iter::once(salsa::InputFingerprint(hash(&()))), + ); + } else { + inv.invalidate( + salsa::QueryTypeId(queries::FILE_SET.id), + iter::empty(), + ); + } + inv.invalidate( + salsa::QueryTypeId(queries::FILE_TEXT.id), + updated_files.iter().map(hash).map(salsa::InputFingerprint), + ); + Db { inner: self.inner.with_ground_data(new_state, inv) } + } + pub(crate) fn get(&self, q: Query, params: T) -> (Arc, Vec) + where + T: Hash + Send + Sync + 'static, + R: Send + Sync + 'static, + { + let query_id = salsa::QueryId( + salsa::QueryTypeId(q.id), + salsa::InputFingerprint(hash(¶ms)), + ); + let params = Arc::new(params); + let (res, events) = self.inner.get(query_id, params); + let res = res.downcast().unwrap(); + let events = events.into_iter().map(|it| it.0).collect(); + (res, events) + } + +} + +impl<'a> QueryCtx<'a> { + fn get_g(&self, q: GroundQuery, params: T) -> Arc + where + T: Hash + Send + Sync + 'static, + R: Send + Sync + 'static, + { + let query_id = salsa::QueryId( + salsa::QueryTypeId(q.id), + salsa::InputFingerprint(hash(¶ms)), + ); + let res = self.inner.get(query_id, Arc::new(params)); + res.downcast().unwrap() + } + pub(crate) fn get(&self, q: Query, params: T) -> Arc + where + T: Hash + Send + Sync + 'static, + R: Send + Sync + 'static, + { + let query_id = salsa::QueryId( + salsa::QueryTypeId(q.id), + salsa::InputFingerprint(hash(¶ms)), + ); + let res = self.inner.get(query_id, Arc::new(params)); + res.downcast().unwrap() + } +} + +fn query_config() -> salsa::QueryConfig { + let mut res = salsa::QueryConfig::new(); + let queries: Vec = vec![ + queries::FILE_TEXT.into(), + queries::FILE_SET.into(), + ]; + for q in queries { + res = res.with_ground_query(q.query_type, q.f) + } + let queries: Vec = vec![ + queries::FILE_SYNTAX.into(), + ::module_map_db::MODULE_DESCR.into(), + ::module_map_db::RESOLVE_SUBMODULE.into(), + ::module_map_db::PARENT_MODULE.into(), + ]; + for q in queries { + res = res.with_query(q.query_type, q.f); + } + res +} + +struct SalsaGroundQuery { + query_type: salsa::QueryTypeId, + f: Box (Data, salsa::OutputFingerprint) + Send + Sync + 'static>, +} + +impl From> for SalsaGroundQuery +where + T: Send + Sync + 'static, + R: Send + Sync + 'static, +{ + fn from(q: GroundQuery) -> SalsaGroundQuery + { + SalsaGroundQuery { + query_type: salsa::QueryTypeId(q.id), + f: Box::new(move |state, data| { + let data: &T = data.downcast_ref().unwrap(); + let res = (q.f)(state, data); + let h = (q.h)(&res); + (Arc::new(res), salsa::OutputFingerprint(h)) + }) + } + } +} + +struct SalsaQuery { + query_type: salsa::QueryTypeId, + f: Box, &Data) -> (Data, salsa::OutputFingerprint) + Send + Sync + 'static>, +} + +impl From> for SalsaQuery +where + T: Hash + Send + Sync + 'static, + R: Hash + Send + Sync + 'static, +{ + fn from(q: Query) -> SalsaQuery + { + SalsaQuery { + query_type: salsa::QueryTypeId(q.id), + f: Box::new(move |ctx, data| { + let ctx = QueryCtx { inner: ctx }; + let data: &T = data.downcast_ref().unwrap(); + let res = (q.f)(ctx, data); + let h = hash(&res); + (Arc::new(res), salsa::OutputFingerprint(h)) + }) + } + } +} + +fn hash(x: &T) -> u64 { + use std::hash::Hasher; + let mut hasher = ::std::collections::hash_map::DefaultHasher::new(); + ::std::hash::Hash::hash(x, &mut hasher); + hasher.finish() +} diff --git a/crates/libanalysis/src/db/queries.rs b/crates/libanalysis/src/db/queries.rs new file mode 100644 index 0000000000..2d4aac6e93 --- /dev/null +++ b/crates/libanalysis/src/db/queries.rs @@ -0,0 +1,43 @@ +use std::sync::Arc; +use libsyntax2::{File}; +use { + FileId, FileResolverImp, + db::{Query, GroundQuery, QueryCtx, hash}, +}; + + +impl<'a> QueryCtx<'a> { + pub(crate) fn file_set(&self) -> Arc<(Vec, FileResolverImp)> { + self.get_g(FILE_SET, ()) + } + pub(crate) fn file_text(&self, file_id: FileId) -> Arc { + Arc::clone(&*self.get_g(FILE_TEXT, file_id)) + } + pub(crate) fn file_syntax(&self, file_id: FileId) -> File { + (&*self.get(FILE_SYNTAX, file_id)).clone() + } +} + +pub(super) const FILE_TEXT: GroundQuery> = GroundQuery { + id: 10, + f: |state, id| state.file_map[&id].clone(), + h: hash, +}; + +pub(super) const FILE_SET: GroundQuery<(), (Vec, FileResolverImp)> = GroundQuery { + id: 11, + f: |state, &()| { + let files = state.file_map.keys().cloned().collect(); + let resolver = state.resolver.clone(); + (files, resolver) + }, + h: |(files, _)| hash(files), +}; + +pub(super) const FILE_SYNTAX: Query = Query { + id: 20, + f: |ctx, file_id: &FileId| { + let text = ctx.file_text(*file_id); + File::parse(&*text) + } +}; diff --git a/crates/libanalysis/src/lib.rs b/crates/libanalysis/src/lib.rs index 3e77006c5a..bca5d4c1d6 100644 --- a/crates/libanalysis/src/lib.rs +++ b/crates/libanalysis/src/lib.rs @@ -10,6 +10,7 @@ extern crate relative_path; #[macro_use] extern crate crossbeam_channel; extern crate im; +extern crate salsa; mod symbol_index; mod module_map; diff --git a/crates/libanalysis/src/module_map_db.rs b/crates/libanalysis/src/module_map_db.rs index 4d4bd91046..ff69cc0d20 100644 --- a/crates/libanalysis/src/module_map_db.rs +++ b/crates/libanalysis/src/module_map_db.rs @@ -2,66 +2,55 @@ use std::sync::Arc; use { FileId, db::{ - Query, Eval, QueryCtx, FileSyntax, Files, - Cache, QueryCache, + Query, QueryCtx }, module_map::resolve_submodule, }; -pub(crate) enum ModuleDescr {} -impl Query for ModuleDescr { - const ID: u32 = 30; - type Params = FileId; - type Output = Arc; -} - -enum ResolveSubmodule {} -impl Query for ResolveSubmodule { - const ID: u32 = 31; - type Params = (FileId, descr::Submodule); - type Output = Arc>; -} - -enum ParentModule {} -impl Query for ParentModule { - const ID: u32 = 40; - type Params = FileId; - type Output = Arc>; -} - -impl Eval for ModuleDescr { - fn eval(ctx: &QueryCtx, file_id: &FileId) -> Arc { - let file = ctx.get::(file_id); - Arc::new(descr::ModuleDescr::new(file.ast())) +impl<'a> QueryCtx<'a> { + fn module_descr(&self, file_id: FileId) -> Arc { + self.get(MODULE_DESCR, file_id) + } + fn resolve_submodule(&self, file_id: FileId, submod: descr::Submodule) -> Arc> { + self.get(RESOLVE_SUBMODULE, (file_id, submod)) } } -impl Eval for ResolveSubmodule { - fn eval(ctx: &QueryCtx, &(file_id, ref submodule): &(FileId, descr::Submodule)) -> Arc> { - let files = ctx.get::(&()); - let res = resolve_submodule(file_id, &submodule.name, &files.file_resolver()).0; - Arc::new(res) +pub(crate) const MODULE_DESCR: Query = Query { + id: 30, + f: |ctx, &file_id| { + let file = ctx.file_syntax(file_id); + descr::ModuleDescr::new(file.ast()) } -} +}; -impl Eval for ParentModule { - fn eval(ctx: &QueryCtx, file_id: &FileId) -> Arc> { - let files = ctx.get::(&()); - let res = files.iter() - .map(|parent_id| (parent_id, ctx.get::(&parent_id))) +pub(crate) const RESOLVE_SUBMODULE: Query<(FileId, descr::Submodule), Vec> = Query { + id: 31, + f: |ctx, params| { + let files = ctx.file_set(); + resolve_submodule(params.0, ¶ms.1.name, &files.1).0 + } +}; + +pub(crate) const PARENT_MODULE: Query> = Query { + id: 40, + f: |ctx, file_id| { + let files = ctx.file_set(); + let res = files.0.iter() + .map(|&parent_id| (parent_id, ctx.module_descr(parent_id))) .filter(|(parent_id, descr)| { descr.submodules.iter() .any(|subm| { - ctx.get::(&(*parent_id, subm.clone())) + ctx.resolve_submodule(*parent_id, subm.clone()) .iter() .any(|it| it == file_id) }) }) .map(|(id, _)| id) .collect(); - Arc::new(res) + res } -} +}; mod descr { use libsyntax2::{ @@ -102,7 +91,7 @@ mod tests { use im; use relative_path::{RelativePath, RelativePathBuf}; use { - db::{Query, DbHost, TraceEventKind}, + db::{Query, Db, State}, imp::FileResolverImp, FileId, FileResolver, }; @@ -126,7 +115,7 @@ mod tests { struct Fixture { next_file_id: u32, fm: im::HashMap, - db: DbHost, + db: Db, } impl Fixture { @@ -134,7 +123,7 @@ mod tests { Fixture { next_file_id: 1, fm: im::HashMap::new(), - db: DbHost::new(), + db: Db::new(State::default()), } } fn add_file(&mut self, path: &str, text: &str) -> FileId { @@ -142,36 +131,39 @@ mod tests { let file_id = FileId(self.next_file_id); self.next_file_id += 1; self.fm.insert(file_id, RelativePathBuf::from(&path[1..])); - self.db.change_file(file_id, Some(text.to_string())); - self.db.set_file_resolver(FileResolverImp::new( + let mut new_state = self.db.state().clone(); + new_state.file_map.insert(file_id, text.to_string().into_boxed_str().into()); + new_state.resolver = FileResolverImp::new( Arc::new(FileMap(self.fm.clone())) - )); - + ); + self.db = self.db.with_state(new_state, &[file_id], true); file_id } fn remove_file(&mut self, file_id: FileId) { self.fm.remove(&file_id); - self.db.change_file(file_id, None); - self.db.set_file_resolver(FileResolverImp::new( + let mut new_state = self.db.state().clone(); + new_state.file_map.remove(&file_id); + new_state.resolver = FileResolverImp::new( Arc::new(FileMap(self.fm.clone())) - )) + ); + self.db = self.db.with_state(new_state, &[file_id], true); } fn change_file(&mut self, file_id: FileId, new_text: &str) { - self.db.change_file(file_id, Some(new_text.to_string())); + let mut new_state = self.db.state().clone(); + new_state.file_map.insert(file_id, new_text.to_string().into_boxed_str().into()); + self.db = self.db.with_state(new_state, &[file_id], false); } fn check_parent_modules( &self, file_id: FileId, expected: &[FileId], - queries: &[(u32, u64)] + queries: &[(u16, u64)] ) { - let ctx = self.db.query_ctx(); - let actual = ctx.get::(&file_id); + let (actual, events) = self.db.get(PARENT_MODULE, file_id); assert_eq!(actual.as_slice(), expected); let mut counts = HashMap::new(); - ctx.trace.borrow().iter() - .filter(|event| event.kind == TraceEventKind::Evaluating) - .for_each(|event| *counts.entry(event.query_id).or_insert(0) += 1); + events.into_iter() + .for_each(|event| *counts.entry(event).or_insert(0) += 1); for &(query_id, expected_count) in queries.iter() { let actual_count = *counts.get(&query_id).unwrap_or(&0); assert_eq!( @@ -189,25 +181,25 @@ mod tests { fn test_parent_module() { let mut f = Fixture::new(); let foo = f.add_file("/foo.rs", ""); - // f.check_parent_modules(foo, &[], &[(ModuleDescr::ID, 1)]); + f.check_parent_modules(foo, &[], &[(MODULE_DESCR.id, 1)]); let lib = f.add_file("/lib.rs", "mod foo;"); - f.check_parent_modules(foo, &[lib], &[(ModuleDescr::ID, 2)]); - f.check_parent_modules(foo, &[lib], &[(ModuleDescr::ID, 0)]); + f.check_parent_modules(foo, &[lib], &[(MODULE_DESCR.id, 1)]); + f.check_parent_modules(foo, &[lib], &[(MODULE_DESCR.id, 0)]); f.change_file(lib, ""); - f.check_parent_modules(foo, &[], &[(ModuleDescr::ID, 2)]); + f.check_parent_modules(foo, &[], &[(MODULE_DESCR.id, 1)]); - // f.change_file(lib, "mod foo;"); - // f.check_parent_modules(foo, &[lib], &[(ModuleDescr::ID, 2)]); + f.change_file(lib, "mod foo;"); + f.check_parent_modules(foo, &[lib], &[(MODULE_DESCR.id, 1)]); - // f.change_file(lib, "mod bar;"); - // f.check_parent_modules(foo, &[], &[(ModuleDescr::ID, 2)]); + f.change_file(lib, "mod bar;"); + f.check_parent_modules(foo, &[], &[(MODULE_DESCR.id, 1)]); - // f.change_file(lib, "mod foo;"); - // f.check_parent_modules(foo, &[lib], &[(ModuleDescr::ID, 2)]); + f.change_file(lib, "mod foo;"); + f.check_parent_modules(foo, &[lib], &[(MODULE_DESCR.id, 1)]); - // f.remove_file(lib); - // f.check_parent_modules(foo, &[], &[(ModuleDescr::ID, 1)]); + f.remove_file(lib); + f.check_parent_modules(foo, &[], &[(MODULE_DESCR.id, 0)]); } } diff --git a/crates/salsa/src/lib.rs b/crates/salsa/src/lib.rs index 5de3c77747..75815e8bd2 100644 --- a/crates/salsa/src/lib.rs +++ b/crates/salsa/src/lib.rs @@ -8,8 +8,8 @@ use std::{ }; use parking_lot::Mutex; -type GroundQueryFn = fn(&T, &D) -> (D, OutputFingerprint); -type QueryFn = fn(&QueryCtx, &D) -> (D, OutputFingerprint); +type GroundQueryFn = Box (D, OutputFingerprint) + Send + Sync + 'static>; +type QueryFn = Box, &D) -> (D, OutputFingerprint) + Send + Sync + 'static>; #[derive(Debug)] pub struct Db { @@ -119,7 +119,7 @@ where res } - pub fn get_inner( + fn get_inner( &self, query_id: QueryId, params: D, @@ -176,9 +176,9 @@ where self.executed.borrow_mut().push(query_id.0); self.stack.borrow_mut().push(Vec::new()); - let (res, output_fingerprint) = if let Some(f) = self.ground_query_fn_by_type(query_id.0) { + let (res, output_fingerprint) = if let Some(f) = self.query_config.ground_fn.get(&query_id.0) { f(&self.db.ground_data, ¶ms) - } else if let Some(f) = self.query_fn_by_type(query_id.0) { + } else if let Some(f) = self.query_config.query_fn.get(&query_id.0) { f(self, ¶ms) } else { panic!("unknown query type: {:?}", query_id.0); @@ -190,12 +190,6 @@ where self.db.record(query_id, params, res.clone(), output_fingerprint, deps); (res, output_fingerprint) } - fn ground_query_fn_by_type(&self, query_type: QueryTypeId) -> Option> { - self.query_config.ground_fn.get(&query_type).map(|&it| it) - } - fn query_fn_by_type(&self, query_type: QueryTypeId) -> Option> { - self.query_config.query_fn.get(&query_type).map(|&it| it) - } fn record_dep( &self, query_id: QueryId, @@ -239,7 +233,9 @@ where query_config: Arc::new(query_config), } } - + pub fn ground_data(&self) -> &T { + &self.db.ground_data + } pub fn with_ground_data( &self, ground_data: T, diff --git a/crates/salsa/tests/integration.rs b/crates/salsa/tests/integration.rs index 3cec330e64..aed9219bee 100644 --- a/crates/salsa/tests/integration.rs +++ b/crates/salsa/tests/integration.rs @@ -79,19 +79,19 @@ where fn mk_queries() -> salsa::QueryConfig { salsa::QueryConfig::::new() - .with_ground_query(GET_TEXT, |state, id| { + .with_ground_query(GET_TEXT, Box::new(|state, id| { mk_ground_query::(state, id, |state, id| state[id].clone()) - }) - .with_ground_query(GET_FILES, |state, id| { + })) + .with_ground_query(GET_FILES, Box::new(|state, id| { mk_ground_query::<(), Vec>(state, id, |state, &()| state.keys().cloned().collect()) - }) - .with_query(FILE_NEWLINES, |query_ctx, id| { + })) + .with_query(FILE_NEWLINES, Box::new(|query_ctx, id| { mk_query(query_ctx, id, |query_ctx, &id| { let text = query_ctx.get_text(id); text.lines().count() }) - }) - .with_query(TOTAL_NEWLINES, |query_ctx, id| { + })) + .with_query(TOTAL_NEWLINES, Box::new(|query_ctx, id| { mk_query(query_ctx, id, |query_ctx, &()| { let mut total = 0; for &id in query_ctx.get_files().iter() { @@ -99,7 +99,7 @@ fn mk_queries() -> salsa::QueryConfig { } total }) - }) + })) } #[test]