From db33bc2e6178443ae1df9ef4aa2c9bb45b8fc33e Mon Sep 17 00:00:00 2001 From: Greg Johnston Date: Wed, 22 May 2024 21:45:02 -0400 Subject: [PATCH] feat: owning memo --- reactive_graph/src/computed/arc_memo.rs | 24 ++++++++++++++-- reactive_graph/src/computed/inner.rs | 22 ++++----------- reactive_graph/src/computed/memo.rs | 37 +++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 19 deletions(-) diff --git a/reactive_graph/src/computed/arc_memo.rs b/reactive_graph/src/computed/arc_memo.rs index 4143f9cff..e24ec299c 100644 --- a/reactive_graph/src/computed/arc_memo.rs +++ b/reactive_graph/src/computed/arc_memo.rs @@ -31,7 +31,7 @@ impl ArcMemo { where T: PartialEq, { - Self::new_with_compare(fun, |lhs, rhs| lhs.as_ref() == rhs.as_ref()) + Self::new_with_compare(fun, |lhs, rhs| lhs.as_ref() != rhs.as_ref()) } #[track_caller] @@ -41,7 +41,25 @@ impl ArcMemo { )] pub fn new_with_compare( fun: impl Fn(Option<&T>) -> T + Send + Sync + 'static, - is_same: fn(Option<&T>, Option<&T>) -> bool, + changed: fn(Option<&T>, Option<&T>) -> bool, + ) -> Self + where + T: PartialEq, + { + Self::new_owning(move |prev: Option| { + let new_value = fun(prev.as_ref()); + let changed = changed(prev.as_ref(), Some(&new_value)); + (new_value, changed) + }) + } + + #[track_caller] + #[cfg_attr( + feature = "tracing", + tracing::instrument(level = "trace", skip_all,) + )] + pub fn new_owning( + fun: impl Fn(Option) -> (T, bool) + Send + Sync + 'static, ) -> Self where T: PartialEq, @@ -52,7 +70,7 @@ impl ArcMemo { Weak::clone(weak) as Weak, ); - RwLock::new(MemoInner::new(Arc::new(fun), is_same, subscriber)) + RwLock::new(MemoInner::new(Arc::new(fun), subscriber)) }); Self { #[cfg(debug_assertions)] diff --git a/reactive_graph/src/computed/inner.rs b/reactive_graph/src/computed/inner.rs index d20352a64..0c7857e78 100644 --- a/reactive_graph/src/computed/inner.rs +++ b/reactive_graph/src/computed/inner.rs @@ -14,8 +14,7 @@ use std::{ pub struct MemoInner { pub(crate) value: Option, #[allow(clippy::type_complexity)] - pub(crate) fun: Arc) -> T + Send + Sync>, - pub(crate) compare_with: fn(Option<&T>, Option<&T>) -> bool, + pub(crate) fun: Arc) -> (T, bool) + Send + Sync>, pub(crate) owner: Owner, pub(crate) state: ReactiveNodeState, pub(crate) sources: SourceSet, @@ -32,14 +31,12 @@ impl Debug for MemoInner { impl MemoInner { #[allow(clippy::type_complexity)] pub fn new( - fun: Arc) -> T + Send + Sync>, - compare_with: fn(Option<&T>, Option<&T>) -> bool, + fun: Arc) -> (T, bool) + Send + Sync>, any_subscriber: AnySubscriber, ) -> Self { Self { value: None, fun, - compare_with, owner: Owner::new(), state: ReactiveNodeState::Dirty, sources: Default::default(), @@ -89,24 +86,17 @@ impl ReactiveNode for RwLock> { }; if needs_update { - let (fun, value, compare_with, owner) = { + let (fun, value, owner) = { let mut lock = self.write().or_poisoned(); - ( - lock.fun.clone(), - lock.value.take(), - lock.compare_with, - lock.owner.clone(), - ) + (lock.fun.clone(), lock.value.take(), lock.owner.clone()) }; let any_subscriber = { self.read().or_poisoned().any_subscriber.clone() }; any_subscriber.clear_sources(&any_subscriber); - let new_value = owner.with_cleanup(|| { - any_subscriber.with_observer(|| fun(value.as_ref())) - }); + let (new_value, changed) = owner + .with_cleanup(|| any_subscriber.with_observer(|| fun(value))); - let changed = !compare_with(Some(&new_value), value.as_ref()); let mut lock = self.write().or_poisoned(); lock.value = Some(new_value); lock.state = ReactiveNodeState::Clean; diff --git a/reactive_graph/src/computed/memo.rs b/reactive_graph/src/computed/memo.rs index c55362845..f9fa7c687 100644 --- a/reactive_graph/src/computed/memo.rs +++ b/reactive_graph/src/computed/memo.rs @@ -46,6 +46,43 @@ impl Memo { inner: StoredValue::new(ArcMemo::new(fun)), } } + + #[track_caller] + #[cfg_attr( + feature = "tracing", + tracing::instrument(level = "trace", skip_all,) + )] + pub fn new_with_compare( + fun: impl Fn(Option<&T>) -> T + Send + Sync + 'static, + changed: fn(Option<&T>, Option<&T>) -> bool, + ) -> Self + where + T: PartialEq, + { + Self { + #[cfg(debug_assertions)] + defined_at: Location::caller(), + inner: StoredValue::new(ArcMemo::new_with_compare(fun, changed)), + } + } + + #[track_caller] + #[cfg_attr( + feature = "tracing", + tracing::instrument(level = "trace", skip_all,) + )] + pub fn new_owning( + fun: impl Fn(Option) -> (T, bool) + Send + Sync + 'static, + ) -> Self + where + T: PartialEq, + { + Self { + #[cfg(debug_assertions)] + defined_at: Location::caller(), + inner: StoredValue::new(ArcMemo::new_owning(fun)), + } + } } impl Copy for Memo {}