From 9e47468c3ac40ab04719fc46aa33e9cfa2966841 Mon Sep 17 00:00:00 2001 From: Evan Almloff Date: Fri, 25 Nov 2022 13:38:01 -0600 Subject: [PATCH] more work on parallel passes --- packages/native-core/src/lib.rs | 1 + packages/native-core/src/passes.rs | 432 +++++++++++++++++++++++++++++ packages/native-core/src/tree.rs | 259 +---------------- 3 files changed, 435 insertions(+), 257 deletions(-) create mode 100644 packages/native-core/src/passes.rs diff --git a/packages/native-core/src/lib.rs b/packages/native-core/src/lib.rs index 895bb7dc9..56c4a6a73 100644 --- a/packages/native-core/src/lib.rs +++ b/packages/native-core/src/lib.rs @@ -3,6 +3,7 @@ use tree::NodeId; pub mod layout_attributes; pub mod node; pub mod node_ref; +pub mod passes; pub mod real_dom; pub mod state; pub mod tree; diff --git a/packages/native-core/src/passes.rs b/packages/native-core/src/passes.rs new file mode 100644 index 000000000..aa25a81f2 --- /dev/null +++ b/packages/native-core/src/passes.rs @@ -0,0 +1,432 @@ +use crossbeam_deque::{Injector, Stealer, Worker}; +use parking_lot::RwLock; +use rustc_hash::{FxHashMap, FxHashSet, FxHasher}; +use std::hash::BuildHasherDefault; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +use crate::tree::{NodeId, SharedView, TreeView}; + +#[derive(PartialEq, Eq, Hash, Clone, Copy)] +pub struct PassId(u64); + +pub trait UpwardPass { + fn pass_id(&self) -> PassId; + fn dependancies(&self) -> &'static [PassId]; + fn dependants(&self) -> &'static [PassId]; + fn pass<'a>(&self, node: &mut T, children: &mut dyn Iterator) -> bool; +} + +pub trait DownwardPass { + fn pass_id(&self) -> PassId; + fn dependancies(&self) -> &'static [PassId]; + fn dependants(&self) -> &'static [PassId]; + fn pass(&self, node: &mut T, parent: Option<&mut T>) -> bool; +} + +pub trait NodePass { + fn pass_id(&self) -> PassId; + fn dependancies(&self) -> &'static [PassId]; + fn dependants(&self) -> &'static [PassId]; + fn 
pass(&self, node: &mut T) -> bool; +} + +pub enum AnyPass { + Upward(Box + Send + Sync>), + Downward(Box + Send + Sync>), + Node(Box + Send + Sync>), +} + +impl AnyPass { + fn pass_id(&self) -> PassId { + match self { + Self::Upward(pass) => pass.pass_id(), + Self::Downward(pass) => pass.pass_id(), + Self::Node(pass) => pass.pass_id(), + } + } + + fn dependancies(&self) -> &'static [PassId] { + match self { + Self::Upward(pass) => pass.dependancies(), + Self::Downward(pass) => pass.dependancies(), + Self::Node(pass) => pass.dependancies(), + } + } +} + +type FxDashMap = dashmap::DashMap>; + +#[derive(Default)] +struct DirtyNodeStates { + dirty: FxDashMap>, +} + +impl DirtyNodeStates { + fn new(starting_nodes: FxHashMap>) -> Self { + let this = Self::default(); + for (node, nodes) in starting_nodes { + for pass_id in nodes { + this.insert(pass_id, node); + } + } + this + } + + fn insert(&self, pass_id: PassId, node_id: NodeId) { + let pass_id = pass_id.0; + let index = pass_id / 64; + let bit = pass_id % 64; + let encoded = 1 << bit; + if let Some(dirty) = self.dirty.get(&node_id) { + if let Some(atomic) = dirty.get(index as usize) { + atomic.fetch_or(encoded, Ordering::Relaxed); + } else { + drop(dirty); + let mut write = self.dirty.get_mut(&node_id).unwrap(); + write.resize_with(index as usize + 1, || AtomicU64::new(0)); + write[index as usize].fetch_or(encoded, Ordering::Relaxed); + } + } else { + let mut v = Vec::with_capacity(index as usize + 1); + v.resize_with(index as usize + 1, || AtomicU64::new(0)); + v[index as usize].fetch_or(encoded, Ordering::Relaxed); + self.dirty.insert(node_id, v); + } + } + + fn all_dirty(&self, pass_id: PassId) -> impl Iterator + '_ { + let pass_id = pass_id.0; + let index = pass_id / 64; + let bit = pass_id % 64; + let encoded = 1 << bit; + self.dirty.iter().filter_map(move |entry| { + let node_id = entry.key(); + let dirty = entry.value(); + if let Some(atomic) = dirty.get(index as usize) { + if atomic.load(Ordering::Relaxed) & 
encoded != 0 { + Some(*node_id) + } else { + None + } + } else { + None + } + }) + } +} + +fn get_pass>( + passes: &mut Vec>, + resolved_passes: &mut FxHashSet, + dirty_nodes: &DirtyNodeStates, + shared_view: &mut SharedView, + global: &Injector, + current_pass: &RwLock>>, +) { + for i in 0..passes.len() { + if passes[i] + .dependancies() + .iter() + .all(|id| resolved_passes.contains(id)) + { + let pass = passes.remove(i); + let pass_id = pass.pass_id(); + resolved_passes.insert(pass_id); + match pass { + AnyPass::Upward(pass) => { + // Upward passes are more difficult. Right now we limit them to only one thread. + let worker = Worker::new_fifo(); + let mut queued_nodes = FxHashSet::default(); + for node in dirty_nodes.all_dirty(pass_id) { + queued_nodes.insert(node); + worker.push(node); + } + while let Some(id) = worker.pop() { + let (node, mut children) = shared_view.parent_child_mut(id).unwrap(); + if pass.pass(node, &mut children) { + drop(children); + if let Some(id) = shared_view.parent_id(id) { + for dependant in pass.dependants() { + dirty_nodes.insert(*dependant, id); + } + if !queued_nodes.contains(&id) { + queued_nodes.insert(id); + worker.push(id); + } + } + } + } + } + _ => { + for node in dirty_nodes.all_dirty(pass_id) { + global.push(node); + } + current_pass.write().replace(pass); + } + } + + break; + } + } +} + +pub fn resolve_passes( + tree: &mut impl TreeView, + starting_nodes: FxHashMap>, + mut passes: Vec>, +) { + let dirty_nodes: Arc = Arc::new(DirtyNodeStates::new(starting_nodes)); + let global = Injector::default(); + + let core_count = thread::available_parallelism() + .map(|c| c.get()) + .unwrap_or(1); + let workers: Vec> = (0..core_count).map(|_| Worker::new_fifo()).collect(); + let stealers: Vec<_> = workers.iter().map(|w| w.stealer()).collect(); + let mut shared_view = SharedView::new(tree); + let mut resolved_passes: FxHashSet = FxHashSet::default(); + let current_pass: Arc>>> = Arc::new(RwLock::new(None)); + + thread::scope(|s| { + 
get_pass( + &mut passes, + &mut resolved_passes, + &dirty_nodes, + &mut shared_view, + &global, + &current_pass, + ); + let global = &global; + let stealers = &stealers; + for (_, w) in (0..core_count).zip(workers.into_iter()) { + let mut shared_view = shared_view.clone(); + let current_pass = current_pass.clone(); + let dirty_nodes = dirty_nodes.clone(); + s.spawn(move || { + while let Some(current_pass) = &*current_pass.read() { + match current_pass { + AnyPass::Upward(_) => { + todo!("Upward passes are single threaded") + } + AnyPass::Node(pass) => { + // Node passes are the easiest to parallelize. We just run the pass on each node. + while let Some(id) = find_task(&w, global, stealers) { + let node = shared_view.get_mut(id).unwrap(); + if pass.pass(node) { + for dependant in pass.dependants() { + dirty_nodes.insert(*dependant, id); + } + } + } + } + AnyPass::Downward(pass) => { + // Downward passes are easy to parallelize. We try to keep trees localized to one thread, but allow work stealing to balance the load. + while let Some(id) = find_task(&w, global, stealers) { + let (node, parent) = shared_view.node_parent_mut(id).unwrap(); + if pass.pass(node, parent) { + for id in shared_view.children_ids(id).unwrap() { + for dependant in pass.dependants() { + dirty_nodes.insert(*dependant, *id); + } + w.push(*id); + } + } + } + } + } + } + }); + } + while !passes.is_empty() { + while !stealers.iter().all(|s| s.is_empty()) { + std::thread::sleep(Duration::from_millis(50)); + } + get_pass( + &mut passes, + &mut resolved_passes, + &dirty_nodes, + &mut shared_view, + global, + &current_pass, + ); + } + *current_pass.write() = None; + }); +} + +fn find_task(local: &Worker, global: &Injector, stealers: &[Stealer]) -> Option { + // Pop a task from the local queue, if not empty. + local.pop().or_else(|| { + // Otherwise, we need to look for a task elsewhere. + std::iter::repeat_with(|| { + // Try stealing a batch of tasks from the global queue. 
+ global + .steal_batch_and_pop(local) + // Or try stealing a task from one of the other threads. + .or_else(|| stealers.iter().map(|s| s.steal()).collect()) + }) + // Loop while no task was stolen and any steal operation needs to be retried. + .find(|s| !s.is_retry()) + // Extract the stolen task, if there is one. + .and_then(|s| s.success()) + }) +} + +#[test] +fn node_pass() { + use crate::tree::{Tree, TreeLike}; + let mut tree = Tree::new(0); + let parent = tree.root(); + let child1 = tree.create_node(1); + tree.add_child(parent, child1); + let grandchild1 = tree.create_node(3); + tree.add_child(child1, grandchild1); + let child2 = tree.create_node(2); + tree.add_child(parent, child2); + let grandchild2 = tree.create_node(4); + tree.add_child(child2, grandchild2); + + struct AddPass; + + impl NodePass for AddPass { + fn pass_id(&self) -> PassId { + PassId(0) + } + + fn dependancies(&self) -> &'static [PassId] { + &[] + } + + fn dependants(&self) -> &'static [PassId] { + &[] + } + + fn pass(&self, node: &mut i32) -> bool { + *node += 1; + true + } + } + + let passes = vec![AnyPass::Node(Box::new(AddPass))]; + let mut dirty_nodes: FxHashMap> = FxHashMap::default(); + dirty_nodes.insert(tree.root(), [PassId(0)].into_iter().collect()); + resolve_passes(&mut tree, dirty_nodes, passes); + + assert_eq!(tree.get(tree.root()).unwrap(), &1); +} + +#[test] +fn down_pass() { + use crate::tree::{Tree, TreeLike}; + let mut tree = Tree::new(1); + let parent = tree.root(); + let child1 = tree.create_node(1); + tree.add_child(parent, child1); + let grandchild1 = tree.create_node(1); + tree.add_child(child1, grandchild1); + let child2 = tree.create_node(1); + tree.add_child(parent, child2); + let grandchild2 = tree.create_node(1); + tree.add_child(child2, grandchild2); + + struct AddPass; + + impl DownwardPass for AddPass { + fn pass_id(&self) -> PassId { + PassId(0) + } + + fn dependancies(&self) -> &'static [PassId] { + &[] + } + + fn dependants(&self) -> &'static [PassId] { + 
&[] + } + + fn pass(&self, node: &mut i32, parent: Option<&mut i32>) -> bool { + if let Some(parent) = parent { + *node += *parent; + } + true + } + } + + let passes = vec![AnyPass::Downward(Box::new(AddPass))]; + let mut dirty_nodes: FxHashMap> = FxHashMap::default(); + dirty_nodes.insert(tree.root(), [PassId(0)].into_iter().collect()); + resolve_passes(&mut tree, dirty_nodes, passes); + + assert_eq!(tree.get(tree.root()).unwrap(), &1); + assert_eq!(tree.get(child1).unwrap(), &2); + assert_eq!(tree.get(grandchild1).unwrap(), &3); + assert_eq!(tree.get(child2).unwrap(), &2); + assert_eq!(tree.get(grandchild2).unwrap(), &3); +} + +#[test] +fn up_pass() { + use crate::tree::{Tree, TreeLike}; + // Tree before: + // 0=\ + // 0=\ + // 1 + // 0=\ + // 1 + // Tree after: + // 2=\ + // 1=\ + // 1 + // 1=\ + // 1 + let mut tree = Tree::new(0); + let parent = tree.root(); + let child1 = tree.create_node(0); + tree.add_child(parent, child1); + let grandchild1 = tree.create_node(1); + tree.add_child(child1, grandchild1); + let child2 = tree.create_node(0); + tree.add_child(parent, child2); + let grandchild2 = tree.create_node(1); + tree.add_child(child2, grandchild2); + + struct AddPass; + + impl UpwardPass for AddPass { + fn pass_id(&self) -> PassId { + PassId(0) + } + + fn dependancies(&self) -> &'static [PassId] { + &[] + } + + fn dependants(&self) -> &'static [PassId] { + &[] + } + + fn pass<'a>( + &self, + node: &mut i32, + children: &mut dyn Iterator, + ) -> bool { + *node += children.map(|i| *i).sum::(); + true + } + } + + let passes = vec![AnyPass::Upward(Box::new(AddPass))]; + let mut dirty_nodes: FxHashMap> = FxHashMap::default(); + dirty_nodes.insert(grandchild1, [PassId(0)].into_iter().collect()); + dirty_nodes.insert(grandchild2, [PassId(0)].into_iter().collect()); + resolve_passes(&mut tree, dirty_nodes, passes); + + assert_eq!(tree.get(tree.root()).unwrap(), &2); + assert_eq!(tree.get(child1).unwrap(), &1); + assert_eq!(tree.get(grandchild1).unwrap(), &1); + 
assert_eq!(tree.get(child2).unwrap(), &1); + assert_eq!(tree.get(grandchild2).unwrap(), &1); +} diff --git a/packages/native-core/src/tree.rs b/packages/native-core/src/tree.rs index f5b4c09bf..4bec046be 100644 --- a/packages/native-core/src/tree.rs +++ b/packages/native-core/src/tree.rs @@ -1,20 +1,11 @@ use core::panic; -use std::hash::BuildHasherDefault; -use std::sync::atomic::{AtomicU64, Ordering}; - -use crossbeam_deque::{Injector, Stealer, Worker}; -use dashmap::DashSet; -use dioxus_core::ScopeId; -use dioxus_html::u; use parking_lot::lock_api::RawMutex as _; -use parking_lot::{Mutex, RawMutex, RwLock, RwLockWriteGuard}; -use rustc_hash::{FxHashMap, FxHashSet, FxHasher}; +use parking_lot::{RawMutex, RwLock}; use slab::Slab; use std::cell::UnsafeCell; use std::collections::VecDeque; use std::marker::PhantomData; use std::sync::Arc; -use std::thread; #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug, PartialOrd, Ord)] pub struct NodeId(pub usize); @@ -109,7 +100,7 @@ pub trait TreeView: Sized { // Safety: No node has itself as a parent. 
(*mut_ptr) .get_mut(id) - .map(|node| (node, (*mut_ptr).parent_mut(id).map(|parent| parent))) + .map(|node| (node, (*mut_ptr).parent_mut(id))) } } @@ -808,249 +799,3 @@ fn traverse_depth_first() { node_count += 1; }); } - -#[test] -fn traverse_breadth_first() { - let mut tree = Tree::new(0); - let parent = tree.root(); - let child1 = tree.create_node(1); - tree.add_child(parent, child1); - let grandchild1 = tree.create_node(3); - tree.add_child(child1, grandchild1); - let child2 = tree.create_node(2); - tree.add_child(parent, child2); - let grandchild2 = tree.create_node(4); - tree.add_child(child2, grandchild2); - - let mut node_count = 0; - tree.traverse_breadth_first(move |node| { - assert_eq!(*node, node_count); - node_count += 1; - }); -} - -#[derive(PartialEq, Eq, Hash, Clone, Copy)] -pub struct PassId(u64); - -pub trait UpwardPass { - fn pass_id(&self) -> PassId; - fn dependancies(&self) -> &'static [PassId]; - fn dependants(&self) -> &'static [PassId]; - fn upward_pass(&self, node: &mut T, parent: Option<&mut T>) -> bool; -} - -pub trait DownwardPass { - fn pass_id(&self) -> PassId; - fn dependancies(&self) -> &'static [PassId]; - fn dependants(&self) -> &'static [PassId]; - fn downward_pass<'a>( - &self, - node: &mut T, - children: &mut dyn Iterator, - ) -> bool; -} - -pub trait NodePass { - fn pass_id(&self) -> PassId; - fn dependancies(&self) -> &'static [PassId]; - fn dependants(&self) -> &'static [PassId]; - fn node_pass(&self, node: &mut T) -> bool; -} - -pub enum AnyPass { - Upward(Box + Send + Sync>), - Downward(Box + Send + Sync>), - Node(Box + Send + Sync>), -} - -impl AnyPass { - fn pass_id(&self) -> PassId { - match self { - Self::Upward(pass) => pass.pass_id(), - Self::Downward(pass) => pass.pass_id(), - Self::Node(pass) => pass.pass_id(), - } - } - - fn dependancies(&self) -> &'static [PassId] { - match self { - Self::Upward(pass) => pass.dependancies(), - Self::Downward(pass) => pass.dependancies(), - Self::Node(pass) => pass.dependancies(), 
- } - } - - fn dependants(&self) -> &'static [PassId] { - match self { - Self::Upward(pass) => pass.dependants(), - Self::Downward(pass) => pass.dependants(), - Self::Node(pass) => pass.dependants(), - } - } -} - -type FxDashSet = dashmap::DashSet>; -type FxDashMap = dashmap::DashMap>; - -#[derive(Default)] -struct DirtyNodeStates { - dirty: FxDashMap>, -} - -impl DirtyNodeStates { - fn new(starting_nodes: FxHashMap>) -> Self { - let mut this = Self::default(); - for (node, nodes) in starting_nodes { - for pass_id in nodes { - this.insert(pass_id, node); - } - } - this - } - - fn insert(&self, pass_id: PassId, node_id: NodeId) { - let pass_id = pass_id.0; - let index = pass_id / 64; - let bit = pass_id % 64; - let encoded = 1 << bit; - if let Some(dirty) = self.dirty.get(&node_id) { - if let Some(atomic) = dirty.get(index as usize) { - atomic.fetch_or(encoded, Ordering::Relaxed); - } else { - drop(dirty); - let mut write = self.dirty.get_mut(&node_id).unwrap(); - write.resize_with(index as usize + 1, || AtomicU64::new(0)); - write[index as usize].fetch_or(encoded, Ordering::Relaxed); - } - } else { - self.dirty.insert(node_id, vec![AtomicU64::new(encoded)]); - } - } - - fn all_dirty(&self, pass_id: PassId) -> impl Iterator + '_ { - let pass_id = pass_id.0; - let index = pass_id / 64; - let bit = pass_id % 64; - let encoded = 1 << bit; - self.dirty.iter().filter_map(move |entry| { - let node_id = entry.key(); - let dirty = entry.value(); - if let Some(atomic) = dirty.get(index as usize) { - if atomic.load(Ordering::Relaxed) & encoded != 0 { - Some(*node_id) - } else { - None - } - } else { - None - } - }) - } -} - -pub fn resolve_passes( - tree: &mut impl TreeView, - starting_nodes: FxHashMap>, - mut passes: Vec>, -) { - let dirty_nodes: Arc = Arc::new(DirtyNodeStates::new(starting_nodes)); - let global = Injector::default(); - - let core_count = thread::available_parallelism() - .map(|c| c.get()) - .unwrap_or(1); - let workers: Vec> = (0..core_count).map(|_| 
Worker::new_fifo()).collect(); - let stealers: Vec<_> = workers.iter().map(|w| w.stealer()).collect(); - let shared_view = SharedView::new(tree); - let mut resolved_passes: FxHashSet = FxHashSet::default(); - let current_pass: Arc>>> = Arc::new(RwLock::new(None)); - thread::scope(|s| { - let global = &global; - let stealers = &stealers; - for (_, w) in (0..core_count).zip(workers.into_iter()) { - let mut shared_view = shared_view.clone(); - let current_pass = current_pass.clone(); - let dirty_nodes = dirty_nodes.clone(); - s.spawn(move || { - while let Some(current_pass) = &*current_pass.read() { - match current_pass { - AnyPass::Upward(pass) => { - while let Some(id) = find_task(&w, global, stealers) { - let (node, parent) = shared_view.node_parent_mut(id).unwrap(); - if pass.upward_pass(node, parent) { - if let Some(id) = shared_view.parent_id(id) { - for dependant in pass.dependants() { - dirty_nodes.insert(*dependant, id); - } - w.push(id); - } - } - } - } - AnyPass::Downward(pass) => { - while let Some(id) = find_task(&w, global, stealers) { - let (node, mut children) = - shared_view.parent_child_mut(id).unwrap(); - if pass.downward_pass(node, &mut children) { - drop(children); - for id in shared_view.children_ids(id).unwrap() { - for dependant in pass.dependants() { - dirty_nodes.insert(*dependant, *id); - } - w.push(*id); - } - } - } - } - AnyPass::Node(pass) => { - while let Some(id) = find_task(&w, global, stealers) { - let node = shared_view.get_mut(id).unwrap(); - if pass.node_pass(node) { - for dependant in pass.dependants() { - dirty_nodes.insert(*dependant, id); - } - } - } - } - } - } - }); - } - while !passes.is_empty() { - for i in 0..passes.len() { - if passes[i] - .dependancies() - .iter() - .all(|id| resolved_passes.contains(id)) - { - let pass = passes.remove(i); - let pass_id = pass.pass_id(); - for node in dirty_nodes.all_dirty(pass_id) { - global.push(node); - } - resolved_passes.insert(pass_id); - break; - } - } - } - *current_pass.write() 
= None; - }); -} - -fn find_task(local: &Worker, global: &Injector, stealers: &[Stealer]) -> Option { - // Pop a task from the local queue, if not empty. - local.pop().or_else(|| { - // Otherwise, we need to look for a task elsewhere. - std::iter::repeat_with(|| { - // Try stealing a batch of tasks from the global queue. - global - .steal_batch_and_pop(local) - // Or try stealing a task from one of the other threads. - .or_else(|| stealers.iter().map(|s| s.steal()).collect()) - }) - // Loop while no task was stolen and any steal operation needs to be retried. - .find(|s| !s.is_retry()) - // Extract the stolen task, if there is one. - .and_then(|s| s.success()) - }) -}