add more code for parrellel passes

This commit is contained in:
Evan Almloff 2022-11-25 11:44:46 -06:00
parent 4096861c59
commit c7eeeef68e
2 changed files with 213 additions and 26 deletions

View file

@ -22,6 +22,7 @@ anymap = "0.12.1"
slab = "0.4"
parking_lot = "0.12.1"
crossbeam-deque = "0.8.2"
dashmap = "5.4.0"
[dev-dependencies]
rand = "0.8.5"

View file

@ -1,7 +1,14 @@
use core::panic;
use std::hash::BuildHasherDefault;
use std::sync::atomic::{AtomicU64, Ordering};
use crossbeam_deque::{Injector, Stealer, Worker};
use dashmap::DashSet;
use dioxus_core::ScopeId;
use dioxus_html::u;
use parking_lot::lock_api::RawMutex as _;
use parking_lot::{RawMutex, RwLock};
use parking_lot::{Mutex, RawMutex, RwLock, RwLockWriteGuard};
use rustc_hash::{FxHashMap, FxHashSet, FxHasher};
use slab::Slab;
use std::cell::UnsafeCell;
use std::collections::VecDeque;
@ -822,39 +829,218 @@ fn traverse_breadth_first() {
});
}
enum PassDirection {
Up,
Down,
Node,
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
struct PassId(u64);
trait UpwardPass<T> {
fn upward_pass(&mut self, node: &mut T, parent: Option<&mut T>) -> bool;
fn pass_id(&self) -> PassId;
fn dependancies(&self) -> &'static [PassId];
fn dependants(&self) -> &'static [PassId];
fn upward_pass(&self, node: &mut T, parent: Option<&mut T>) -> bool;
}
fn resolve_pass(&mut self, tree: &mut impl TreeView<T>, starting_nodes: &[NodeId]) {
let global = Injector::default();
for node in starting_nodes {
global.push(*node);
trait DownwardPass<T> {
fn pass_id(&self) -> PassId;
fn dependancies(&self) -> &'static [PassId];
fn dependants(&self) -> &'static [PassId];
fn downward_pass<'a>(
&self,
node: &mut T,
children: &mut dyn Iterator<Item = &'a mut T>,
) -> bool;
}
trait NodePass<T> {
fn pass_id(&self) -> PassId;
fn dependancies(&self) -> &'static [PassId];
fn dependants(&self) -> &'static [PassId];
fn node_pass(&self, node: &mut T) -> bool;
}
enum AnyPass<T> {
Upward(Box<dyn UpwardPass<T> + Send + Sync>),
Downward(Box<dyn DownwardPass<T> + Send + Sync>),
Node(Box<dyn NodePass<T> + Send + Sync>),
}
impl<T> AnyPass<T> {
fn pass_id(&self) -> PassId {
match self {
Self::Upward(pass) => pass.pass_id(),
Self::Downward(pass) => pass.pass_id(),
Self::Node(pass) => pass.pass_id(),
}
}
let core_count = thread::available_parallelism()
.map(|c| c.get())
.unwrap_or(1);
let workers: Vec<_> = (0..core_count).map(|_| Worker::new_fifo()).collect();
let stealers: Vec<_> = workers.iter().map(|w| w.stealer()).collect();
let shared_view = SharedView::new(tree);
thread::scope(|s| {
let global = &global;
let stealers = &stealers;
for (_, w) in (0..core_count).zip(workers.into_iter()) {
let mut shared_view = shared_view.clone();
s.spawn(move || {
while let Some(id) = find_task(&w, &global, &stealers) {
let (node, parent) = shared_view.node_parent_mut(id).unwrap();
if self.upward_pass(node, parent) {
if let Some(id) = shared_view.parent_id(id) {
w.push(id);
fn dependancies(&self) -> &'static [PassId] {
match self {
Self::Upward(pass) => pass.dependancies(),
Self::Downward(pass) => pass.dependancies(),
Self::Node(pass) => pass.dependancies(),
}
}
fn dependants(&self) -> &'static [PassId] {
match self {
Self::Upward(pass) => pass.dependants(),
Self::Downward(pass) => pass.dependants(),
Self::Node(pass) => pass.dependants(),
}
}
}
type FxDashSet<T> = dashmap::DashSet<T, BuildHasherDefault<FxHasher>>;
type FxDashMap<K, V> = dashmap::DashMap<K, V, BuildHasherDefault<FxHasher>>;
#[derive(Default)]
struct DirtyNodeStates {
dirty: FxDashMap<NodeId, Vec<AtomicU64>>,
}
impl DirtyNodeStates {
fn new(starting_nodes: FxHashMap<NodeId, FxDashSet<PassId>>) -> Self {
let mut this = Self::default();
for (node, nodes) in starting_nodes {
for pass_id in nodes {
this.insert(pass_id, node);
}
}
this
}
fn insert(&self, pass_id: PassId, node_id: NodeId) {
let pass_id = pass_id.0;
let index = pass_id / 64;
let bit = pass_id % 64;
let encoded = 1 << bit;
if let Some(dirty) = self.dirty.get(&node_id) {
if let Some(atomic) = dirty.get(index as usize) {
atomic.fetch_or(encoded, Ordering::Relaxed);
} else {
drop(dirty);
let mut write = self.dirty.get_mut(&node_id).unwrap();
write.resize_with(index as usize + 1, || AtomicU64::new(0));
write[index as usize].fetch_or(encoded, Ordering::Relaxed);
}
} else {
self.dirty.insert(node_id, vec![AtomicU64::new(encoded)]);
}
}
fn all_dirty(&self, pass_id: PassId) -> impl Iterator<Item = NodeId> + '_ {
let pass_id = pass_id.0;
let index = pass_id / 64;
let bit = pass_id % 64;
let encoded = 1 << bit;
self.dirty.iter().filter_map(move |entry| {
let node_id = entry.key();
let dirty = entry.value();
if let Some(atomic) = dirty.get(index as usize) {
if atomic.load(Ordering::Relaxed) & encoded != 0 {
Some(*node_id)
} else {
None
}
} else {
None
}
})
}
}
fn resolve_passes<T>(
tree: &mut impl TreeView<T>,
starting_nodes: FxHashMap<NodeId, FxDashSet<PassId>>,
mut passes: Vec<AnyPass<T>>,
) {
let dirty_nodes: Arc<DirtyNodeStates> = Arc::new(DirtyNodeStates::new(starting_nodes));
let global = Injector::default();
let core_count = thread::available_parallelism()
.map(|c| c.get())
.unwrap_or(1);
let workers: Vec<Worker<NodeId>> = (0..core_count).map(|_| Worker::new_fifo()).collect();
let stealers: Vec<_> = workers.iter().map(|w| w.stealer()).collect();
let shared_view = SharedView::new(tree);
let mut resolved_passes: FxHashSet<PassId> = FxHashSet::default();
let current_pass: Arc<RwLock<Option<AnyPass<T>>>> = Arc::new(RwLock::new(None));
thread::scope(|s| {
let global = &global;
let stealers = &stealers;
for (_, w) in (0..core_count).zip(workers.into_iter()) {
let mut shared_view = shared_view.clone();
let current_pass = current_pass.clone();
let dirty_nodes = dirty_nodes.clone();
s.spawn(move || {
while let Some(current_pass) = &*current_pass.read() {
match current_pass {
AnyPass::Upward(pass) => {
while let Some(id) = find_task(&w, global, stealers) {
let (node, parent) = shared_view.node_parent_mut(id).unwrap();
if pass.upward_pass(node, parent) {
if let Some(id) = shared_view.parent_id(id) {
for dependant in pass.dependants() {
dirty_nodes.insert(*dependant, id);
}
w.push(id);
}
}
}
}
AnyPass::Downward(pass) => {
while let Some(id) = find_task(&w, global, stealers) {
let (node, mut children) =
shared_view.parent_child_mut(id).unwrap();
if pass.downward_pass(node, &mut children) {
drop(children);
for id in shared_view.children_ids(id).unwrap() {
for dependant in pass.dependants() {
dirty_nodes.insert(*dependant, *id);
}
w.push(*id);
}
}
}
}
AnyPass::Node(pass) => {
while let Some(id) = find_task(&w, global, stealers) {
let node = shared_view.get_mut(id).unwrap();
if pass.node_pass(node) {
for dependant in pass.dependants() {
dirty_nodes.insert(*dependant, id);
}
}
}
}
}
});
}
});
}
while !passes.is_empty() {
for i in 0..passes.len() {
if passes[i]
.dependancies()
.iter()
.all(|id| resolved_passes.contains(id))
{
let pass = passes.remove(i);
let pass_id = pass.pass_id();
for node in dirty_nodes.all_dirty(pass_id) {
global.push(node);
}
resolved_passes.insert(pass_id);
break;
}
}
});
}
}
*current_pass.write() = None;
});
}
fn find_task<T>(local: &Worker<T>, global: &Injector<T>, stealers: &[Stealer<T>]) -> Option<T> {