work on implementing macro

This commit is contained in:
Evan Almloff 2022-12-03 18:28:32 -06:00
parent 0d780fee61
commit 53094a7bbd
2 changed files with 144 additions and 288 deletions

View file

@ -154,77 +154,21 @@ fn impl_derive_macro(ast: &syn::DeriveInput) -> TokenStream {
.map(|m| &m.mem.ident) .map(|m| &m.mem.ident)
.collect(); .collect();
let member_types = state_strct.state_members.iter().map(|m| &m.mem.ty); let member_types = state_strct.state_members.iter().map(|m| &m.mem.ty);
let resolve_members = state_strct let impl_members = state_strct
.state_members .state_members
.iter() .iter()
.map(|m| state_strct.resolve(m)); .map(|m| state_strct.impl_pass(m));
let child_types = state_strct.child_states.iter().map(|s| &s.ty); // let child_types = state_strct.child_states.iter().map(|s| &s.ty);
let child_members = state_strct.child_states.iter().map(|s| &s.ident); // let child_members = state_strct.child_states.iter().map(|s| &s.ident);
let gen = quote! { let gen = quote! {
#(#impl_members)*
impl State for #type_name { impl State for #type_name {
fn update<'a, T: dioxus_native_core::traversable::Traversable<Node = Self, Id = dioxus_native_core::RealNodeId>,T2: dioxus_native_core::traversable::Traversable<Node = dioxus_native_core::real_dom::NodeData, Id = dioxus_native_core::RealNodeId>>( const PASSES: &'static [AnyPass<Node<Self>>] = &[
dirty: &[(dioxus_native_core::RealNodeId, dioxus_native_core::node_ref::NodeMask)], AnyPass(#(&#member_types),*)
state_tree: &'a mut T, ];
rdom: &'a T2, const MASKS: &'static [NodeMask] = &[#(#member_types::NODE_MASK),*];
ctx: &anymap::AnyMap,
) -> rustc_hash::FxHashSet<dioxus_native_core::RealNodeId>{
#[derive(Clone, Copy)]
struct MembersDirty {
#(#members: bool, )*
}
impl MembersDirty {
fn new() -> Self {
Self {#(#members: false),*}
}
fn any(&self) -> bool {
#(self.#members || )* false
}
fn union(self, other: Self) -> Self {
Self {#(#members: self.#members || other.#members),*}
}
}
let mut dirty_elements = rustc_hash::FxHashSet::default();
// the states of any elements that are dirty
let mut states: rustc_hash::FxHashMap<dioxus_native_core::RealNodeId, MembersDirty> = rustc_hash::FxHashMap::default();
for (id, mask) in dirty {
let members_dirty = MembersDirty {
#(#members: #member_types::NODE_MASK.overlaps(mask),)*
};
if members_dirty.any(){
if let Some(state) = states.get_mut(id){
*state = state.union(members_dirty);
}
else{
states.insert(*id, members_dirty);
}
}
dirty_elements.insert(*id);
}
#(
#resolve_members;
)*
#(
dirty_elements.extend(
<#child_types as dioxus_native_core::state::State>::update(
dirty,
&mut state_tree.map(|n| &n.#child_members, |n| &mut n.#child_members),
rdom,
ctx,
)
);
)*
dirty_elements
}
} }
}; };
gen.into() gen.into()
@ -246,7 +190,11 @@ struct Struct {
impl Struct { impl Struct {
fn new(name: Ident, fields: &[&Field]) -> Self { fn new(name: Ident, fields: &[&Field]) -> Self {
let members = fields.iter().filter_map(|f| Member::parse(f)).collect(); let members = fields
.iter()
.enumerate()
.filter_map(|(i, f)| Member::parse(f, 1 << i))
.collect();
Self { name, members } Self { name, members }
} }
} }
@ -260,7 +208,7 @@ impl<'a> StateStruct<'a> {
/// Parse the state structure, and find a resolution order that will allow us to update the state for each node in after the state(s) it depends on have been resolved. /// Parse the state structure, and find a resolution order that will allow us to update the state for each node in after the state(s) it depends on have been resolved.
fn parse(fields: &[&'a Field], strct: &'a Struct) -> Result<Self> { fn parse(fields: &[&'a Field], strct: &'a Struct) -> Result<Self> {
let mut parse_err = Ok(()); let mut parse_err = Ok(());
let mut unordered_state_members: Vec<_> = strct let mut state_members: Vec<_> = strct
.members .members
.iter() .iter()
.zip(fields.iter()) .zip(fields.iter())
@ -273,41 +221,10 @@ impl<'a> StateStruct<'a> {
}) })
.collect(); .collect();
parse_err?; parse_err?;
for i in 0..state_members.len() {
let mut state_members = Vec::new(); let deps = state_members[i].dep_mems.clone();
// Keep adding members that have had all of there dependancies resolved until there are no more members left. for dep in deps {
while !unordered_state_members.is_empty() { state_members[dep as usize].dependant_mems.push(i as u64);
let mut resolved = false;
for i in 0..unordered_state_members.len() {
let mem = &mut unordered_state_members[i];
// if this member has all of its dependancies resolved other than itself, resolve it next.
if mem.dep_mems.iter().all(|(dep, resolved)| {
*resolved || (*dep == mem.mem && mem.dep_kind != DependencyKind::Node)
}) {
let mem = unordered_state_members.remove(i);
// mark any dependency that depends on this member as resolved
for member in unordered_state_members.iter_mut() {
for (dep, resolved) in &mut member.dep_mems {
*resolved |= *dep == mem.mem;
}
}
state_members.push(mem);
resolved = true;
break;
}
}
if !resolved {
return Err(Error::new(
strct.name.span(),
format!(
"{} has circular dependacy in {:?}",
strct.name,
unordered_state_members
.iter()
.map(|m| format!("{}", &m.mem.ident))
.collect::<Vec<_>>()
),
));
} }
} }
@ -351,175 +268,6 @@ impl<'a> StateStruct<'a> {
} }
dependants dependants
} }
// Mark the states that depend on the current state as dirty
fn update_dependants(&self, mem: &Member) -> impl ToTokens {
let dep = self.get_depenadants(mem);
let update_child_dependants = if dep.child.is_empty() {
quote!()
} else {
let insert = dep.child.iter().map(|d|{
if *d == mem {
quote! {
let seeking = dioxus_native_core::HeightOrdering::new(state_tree.height(parent_id).unwrap(), parent_id);
if let Err(idx) = resolution_order
.binary_search_by(|ordering| ordering.cmp(&seeking).reverse()){
resolution_order.insert(
idx,
seeking,
);
}
}
} else {
quote! {}
}
});
let update: Vec<_> = dep
.child
.iter()
.map(|d| {
let ident = &d.ident;
quote! {
dirty.#ident = true;
}
})
.collect();
quote! {
if let Some(parent_id) = state_tree.parent(id) {
#(#insert)*
if let Some(dirty) = states.get_mut(&parent_id) {
#(#update)*
}
else {
let mut dirty = MembersDirty::new();
#(#update)*
states.insert(parent_id, dirty);
}
}
}
};
let node_dependants: Vec<_> = dep.node.iter().map(|d| &d.ident).collect();
let update_node_dependants = quote! {#(members_dirty.#node_dependants = true;)*};
let update_parent_dependants = if dep.parent.is_empty() {
quote!()
} else {
let insert = dep.parent.iter().map(|d| {
if *d == mem {
quote! {
let seeking = dioxus_native_core::HeightOrdering::new(state_tree.height(*child_id).unwrap(), *child_id);
if let Err(idx) = resolution_order
.binary_search(&seeking){
resolution_order.insert(
idx,
seeking,
);
}
}
} else {
quote! {}
}
});
let update: Vec<_> = dep
.parent
.iter()
.map(|d| {
let ident = &d.ident;
quote! {
dirty.#ident = true;
}
})
.collect();
quote! {
for child_id in state_tree.children(id) {
#(#insert)*
if let Some(dirty) = states.get_mut(&child_id) {
#(#update)*
}
else {
let mut dirty = MembersDirty::new();
#(#update)*
states.insert(*child_id, dirty);
}
}
}
};
quote! {
#update_node_dependants
#update_child_dependants
#update_parent_dependants
}
}
// Generate code to resolve this state
fn resolve(&self, mem: &StateMember) -> impl ToTokens {
let reduce_member = mem.reduce_self();
let update_dependant = self.update_dependants(mem.mem);
let member = &mem.mem.ident;
match mem.dep_kind {
DependencyKind::Parent => {
quote! {
// resolve parent dependant state
let mut resolution_order = states.keys().copied().map(|id| dioxus_native_core::HeightOrdering::new(state_tree.height(id).unwrap(), id)).collect::<Vec<_>>();
resolution_order.sort();
let mut i = 0;
while i < resolution_order.len(){
let id = resolution_order[i].id;
let node = rdom.get(id).unwrap();
let members_dirty = states.get_mut(&id).unwrap();
let (current_state, parent) = state_tree.get_node_parent_mut(id);
let current_state = current_state.unwrap();
if members_dirty.#member && #reduce_member {
dirty_elements.insert(id);
#update_dependant
}
i += 1;
}
}
}
DependencyKind::Child => {
quote! {
// resolve child dependant state
let mut resolution_order = states.keys().copied().map(|id| dioxus_native_core::HeightOrdering::new(state_tree.height(id).unwrap(), id)).collect::<Vec<_>>();
resolution_order.sort_by(|height_ordering1, height_ordering2| {
height_ordering1.cmp(&height_ordering2).reverse()
});
let mut i = 0;
while i < resolution_order.len(){
let id = resolution_order[i].id;
let node = rdom.get(id).unwrap();
let members_dirty = states.get_mut(&id).unwrap();
let (current_state, children) = state_tree.get_node_children_mut(id);
let current_state = current_state.unwrap();
if members_dirty.#member && #reduce_member {
dirty_elements.insert(id);
#update_dependant
}
i += 1;
}
}
}
DependencyKind::Node => {
quote! {
// resolve node dependant state
let mut resolution_order = states.keys().copied().collect::<Vec<_>>();
let mut i = 0;
while i < resolution_order.len(){
let id = resolution_order[i];
let node = rdom.get(id).unwrap();
let members_dirty = states.get_mut(&id).unwrap();
let current_state = state_tree.get_mut(id).unwrap();
if members_dirty.#member && #reduce_member {
dirty_elements.insert(id);
#update_dependant
}
i += 1;
}
}
}
}
}
} }
fn try_parenthesized(input: ParseStream) -> Result<ParseBuffer> { fn try_parenthesized(input: ParseStream) -> Result<ParseBuffer> {
@ -562,13 +310,15 @@ impl Parse for Dependency {
/// The type of the member and the ident of the member /// The type of the member and the ident of the member
#[derive(PartialEq, Debug)] #[derive(PartialEq, Debug)]
struct Member { struct Member {
id: u64,
ty: Type, ty: Type,
ident: Ident, ident: Ident,
} }
impl Member { impl Member {
fn parse(field: &Field) -> Option<Self> { fn parse(field: &Field, id: u64) -> Option<Self> {
Some(Self { Some(Self {
id,
ty: field.ty.clone(), ty: field.ty.clone(),
ident: field.ident.as_ref()?.clone(), ident: field.ident.as_ref()?.clone(),
}) })
@ -581,7 +331,9 @@ struct StateMember<'a> {
// the kind of dependncies this state has // the kind of dependncies this state has
dep_kind: DependencyKind, dep_kind: DependencyKind,
// the depenancy and if it is satified // the depenancy and if it is satified
dep_mems: Vec<(&'a Member, bool)>, dep_mems: Vec<u64>,
// any members that depend on this member
dependant_mems: Vec<u64>,
// the context this state requires // the context this state requires
ctx_ty: Option<Type>, ctx_ty: Option<Type>,
} }
@ -610,7 +362,7 @@ impl<'a> StateMember<'a> {
.iter() .iter()
.filter_map(|name| { .filter_map(|name| {
if let Some(found) = parent.members.iter().find(|m| &m.ident == name) { if let Some(found) = parent.members.iter().find(|m| &m.ident == name) {
Some((found, false)) Some((found.id, false))
} else { } else {
err = Err(Error::new( err = Err(Error::new(
name.span(), name.span(),
@ -624,6 +376,7 @@ impl<'a> StateMember<'a> {
mem, mem,
dep_kind, dep_kind,
dep_mems, dep_mems,
dependant_mems: Vec::new(),
ctx_ty: dependency.ctx_ty, ctx_ty: dependency.ctx_ty,
}) })
} }
@ -638,8 +391,9 @@ impl<'a> StateMember<'a> {
} }
/// generate code to call the resolve function for the state. This does not handle checking if resolving the state is necessary, or marking the states that depend on this state as dirty. /// generate code to call the resolve function for the state. This does not handle checking if resolving the state is necessary, or marking the states that depend on this state as dirty.
fn reduce_self(&self) -> quote::__private::TokenStream { fn impl_pass(&self, parent_type: Type) -> quote::__private::TokenStream {
let ident = &self.mem.ident; let ident = &self.mem.ident;
let ty = &self.mem.ty;
let get_ctx = if let Some(ctx_ty) = &self.ctx_ty { let get_ctx = if let Some(ctx_ty) = &self.ctx_ty {
if ctx_ty == &parse_quote!(()) { if ctx_ty == &parse_quote!(()) {
quote! {&()} quote! {&()}
@ -652,23 +406,125 @@ impl<'a> StateMember<'a> {
}; };
let ty = &self.mem.ty; let ty = &self.mem.ty;
let node_view = quote!(dioxus_native_core::node_ref::NodeView::new(node, #ty::NODE_MASK)); let node_view = quote!(dioxus_native_core::node_ref::NodeView::new(unsafe{&*{&node.node_data as *const _}}, #ty::NODE_MASK));
let dep_idents = self.dep_mems.iter().map(|m| &m.0.ident); let dep_idents = self.dep_mems.iter().map(|m| &m.0.ident);
match self.dep_kind { let impl_specific = match self.dep_kind {
DependencyKind::Node => { DependencyKind::Node => {
quote!({ quote! {
current_state.#ident.reduce(#node_view, (#(&current_state.#dep_idents,)*), #get_ctx) impl NodePass<#parent_type> for #ty {
}) fn pass(&self, node: &mut T, ctx: &SendAnyMap) -> bool {
node.state.#ident.reduce(#node_view, (#(&node.state.#dep_idents,)*), #get_ctx)
}
}
}
} }
DependencyKind::Child => { DependencyKind::Child => {
quote!({ let update = if self.dep_mems.iter().any(|id| id == self.mem.id) {
current_state.#ident.reduce(#node_view, children.iter().map(|c| (#(&c.#dep_idents)*)), #get_ctx) quote! {
}) if update {
PassReturn{
progress: true,
mark_dirty: true,
}
} else {
PassReturn{
progress: false,
mark_dirty: false,
}
}
}
} else {
quote! {
if update {
PassReturn{
progress: false,
mark_dirty: true,
}
} else {
PassReturn{
progress: false,
mark_dirty: false,
}
}
}
};
quote!(
pub trait UpwardPass<T>: Pass {
fn pass<'a>(
&self,
node: &mut T,
children: &mut dyn Iterator<Item = &'a mut T>,
ctx: &SendAnyMap,
) -> PassReturn {
let update = node.state.#ident.reduce(#node_view, children.map(|c| (#(&c.state.#dep_idents,)*)), #get_ctx);
#update
}
}
)
} }
DependencyKind::Parent => { DependencyKind::Parent => {
quote!({ let update = if self.dep_mems.iter().any(|id| id == self.mem.id) {
current_state.#ident.reduce(#node_view, parent.as_ref().map(|p| (#(&p.#dep_idents)*)), #get_ctx) quote! {
}) if update {
PassReturn{
progress: true,
mark_dirty: true,
}
} else {
PassReturn{
progress: false,
mark_dirty: false,
}
}
}
} else {
quote! {
if update {
PassReturn{
progress: false,
mark_dirty: true,
}
} else {
PassReturn{
progress: false,
mark_dirty: false,
}
}
}
};
quote!(
impl DownwardPass<#parent_type> for #ty {
fn pass(&self, node: &mut T, parent: Option<&mut T>, ctx: &SendAnyMap) -> PassReturn{
let update = current_state.state.#ident.reduce(#node_view, parent.as_ref().map(|p| (#(&p.state.#dep_idents,)*)), #get_ctx);
#update
}
}
)
}
};
let pass_id = self.mem.id;
let depenancies = &self.dep_mems;
let dependants = &self.dependant_mems;
let mask = self
.dep_mems
.iter()
.map(|m| m.0.id)
.fold(self.mem.id, |a, b| a | b);
quote! {
#impl_specific
impl Pass for #ty {
fn pass_id(&self) -> PassId {
PassId(#pass_id)
}
fn dependancies(&self) -> &'static [PassId] {
&[#(PassId(#depenancies)),*]
}
fn dependants(&self) -> &'static [PassId] {
&[#(PassId(#dependants)),*]
}
fn mask(&self) -> MemberMask {
MemberMask(#mask)
}
} }
} }
} }

View file

@ -151,8 +151,8 @@ impl BitOr for MemberMask {
} }
pub struct PassReturn { pub struct PassReturn {
progress: bool, pub progress: bool,
mark_dirty: bool, pub mark_dirty: bool,
} }
pub trait Pass { pub trait Pass {