From a0638b786c8bf9536d9d4804de5e19aca3105eb4 Mon Sep 17 00:00:00 2001 From: Greg Johnston Date: Sun, 18 Aug 2024 17:31:06 -0400 Subject: [PATCH] feat: allow mutating AsyncDerived and therefore Resources (closes #2743) --- .../async_derived/arc_async_derived.rs | 105 ++++++++++++++---- .../computed/async_derived/async_derived.rs | 41 ++++++- 2 files changed, 121 insertions(+), 25 deletions(-) diff --git a/reactive_graph/src/computed/async_derived/arc_async_derived.rs b/reactive_graph/src/computed/async_derived/arc_async_derived.rs index 978687354..bbac44ce4 100644 --- a/reactive_graph/src/computed/async_derived/arc_async_derived.rs +++ b/reactive_graph/src/computed/async_derived/arc_async_derived.rs @@ -12,8 +12,8 @@ use crate::{ SubscriberSet, ToAnySource, ToAnySubscriber, WithObserver, }, owner::{use_context, Owner}, - signal::guards::{AsyncPlain, ReadGuard}, - traits::{DefinedAt, ReadUntracked}, + signal::guards::{AsyncPlain, ReadGuard, WriteGuard}, + traits::{DefinedAt, ReadUntracked, Trigger, UntrackableGuard, Writeable}, transition::AsyncTransition, }; use any_spawner::Executor; @@ -24,6 +24,7 @@ use or_poisoned::OrPoisoned; use std::{ future::Future, mem, + ops::DerefMut, panic::Location, sync::{ atomic::{AtomicBool, Ordering}, @@ -112,6 +113,10 @@ pub(crate) trait BlockingLock { fn blocking_read_arc(self: &Arc) -> async_lock::RwLockReadGuardArc; + fn blocking_write_arc( + self: &Arc, + ) -> async_lock::RwLockWriteGuardArc; + fn blocking_read(&self) -> async_lock::RwLockReadGuard<'_, T>; fn blocking_write(&self) -> async_lock::RwLockWriteGuard<'_, T>; @@ -131,6 +136,19 @@ impl BlockingLock for AsyncRwLock { } } + fn blocking_write_arc( + self: &Arc, + ) -> async_lock::RwLockWriteGuardArc { + #[cfg(not(target_family = "wasm"))] + { + self.write_arc_blocking() + } + #[cfg(target_family = "wasm")] + { + self.write_arc().now_or_never().unwrap() + } + } + fn blocking_read(&self) -> async_lock::RwLockReadGuard<'_, T> { #[cfg(not(target_family = "wasm"))] { @@ -293,25 +311,7 @@ macro_rules! spawn_derived { // generate and assign new value loading.store(true, Ordering::Relaxed); let new_value = fut.await; - loading.store(false, Ordering::Relaxed); - *value.write().await = Some(new_value); - inner.write().or_poisoned().dirty = true; - - // if it's an Err, that just means the Receiver was dropped - // we don't particularly care about that: the point is to notify if - // it still exists, but we don't need to know if Suspense is no - // longer listening - _ = ready_tx.send(()); - - // notify reactive subscribers that we're not loading any more - for sub in (&inner.read().or_poisoned().subscribers).into_iter() { - sub.mark_dirty(); - } - - // notify async .awaiters - for waker in mem::take(&mut *wakers.write().or_poisoned()) { - waker.wake(); - } + Self::set_inner_value(new_value, value, wakers, inner, loading, Some(ready_tx)).await; } _ => break, } @@ -330,6 +330,49 @@ macro_rules! spawn_derived { }}; } +impl ArcAsyncDerived { + async fn set_inner_value( + new_value: T, + value: Arc>>, + wakers: Arc>>, + inner: Arc>, + loading: Arc, + ready_tx: Option>, + ) { + *value.write().await = Some(new_value); + Self::notify_subs(&wakers, &inner, &loading, ready_tx); + } + + fn notify_subs( + wakers: &Arc>>, + inner: &Arc>, + loading: &Arc, + ready_tx: Option>, + ) { + loading.store(false, Ordering::Relaxed); + + inner.write().or_poisoned().dirty = true; + + if let Some(ready_tx) = ready_tx { + // if it's an Err, that just means the Receiver was dropped + // we don't particularly care about that: the point is to notify if + // it still exists, but we don't need to know if Suspense is no + // longer listening + _ = ready_tx.send(()); + } + + // notify reactive subscribers that we're not loading any more + for sub in (&inner.read().or_poisoned().subscribers).into_iter() { + sub.mark_dirty(); + } + + // notify async .awaiters + for waker in mem::take(&mut *wakers.write().or_poisoned()) { + waker.wake(); + } + } +} + impl ArcAsyncDerived { /// Creates a new async derived computation. /// @@ -456,6 +499,26 @@ impl ReadUntracked for ArcAsyncDerived { } } +impl Trigger for ArcAsyncDerived { + fn trigger(&self) { + Self::notify_subs(&self.wakers, &self.inner, &self.loading, None); + } +} + +impl Writeable for ArcAsyncDerived { + type Value = Option; + + fn try_write(&self) -> Option> { + Some(WriteGuard::new(self.clone(), self.value.blocking_write())) + } + + fn try_write_untracked( + &self, + ) -> Option> { + Some(self.value.blocking_write()) + } +} + impl ToAnySource for ArcAsyncDerived { fn to_any_source(&self) -> AnySource { AnySource( diff --git a/reactive_graph/src/computed/async_derived/async_derived.rs b/reactive_graph/src/computed/async_derived/async_derived.rs index 5ebec8af3..0b46bd87d 100644 --- a/reactive_graph/src/computed/async_derived/async_derived.rs +++ b/reactive_graph/src/computed/async_derived/async_derived.rs @@ -1,16 +1,18 @@ -use super::{ArcAsyncDerived, AsyncDerivedReadyFuture}; +use super::{ArcAsyncDerived, AsyncDerivedReadyFuture, BlockingLock}; use crate::{ graph::{ AnySource, AnySubscriber, ReactiveNode, Source, Subscriber, ToAnySource, ToAnySubscriber, }, owner::{FromLocal, LocalStorage, Storage, StoredValue, SyncStorage}, - signal::guards::{AsyncPlain, ReadGuard}, - traits::{DefinedAt, Dispose, ReadUntracked}, + signal::guards::{AsyncPlain, ReadGuard, WriteGuard}, + traits::{ + DefinedAt, Dispose, ReadUntracked, Trigger, UntrackableGuard, Writeable, + }, unwrap_signal, }; use core::fmt::Debug; -use std::{future::Future, panic::Location}; +use std::{future::Future, ops::DerefMut, panic::Location}; /// A reactive value that is derived by running an asynchronous computation in response to changes /// in its sources. @@ -286,6 +288,37 @@ where } } +impl Trigger for AsyncDerived +where + T: 'static, + S: Storage>, +{ + fn trigger(&self) { + self.inner.try_with_value(|inner| inner.trigger()); + } +} + +impl Writeable for AsyncDerived +where + T: 'static, + S: Storage>, +{ + type Value = Option; + + fn try_write(&self) -> Option> { + let guard = self + .inner + .try_with_value(|n| n.value.blocking_write_arc())?; + Some(WriteGuard::new(*self, guard)) + } + + fn try_write_untracked( + &self, + ) -> Option> { + self.inner.try_with_value(|n| n.value.blocking_write_arc()) + } +} + impl ToAnySource for AsyncDerived where T: 'static,