fix: properly trigger Suspense when Suspend is called again (#2993)

This commit is contained in:
Greg Johnston 2024-09-18 21:35:37 -04:00 committed by GitHub
parent c4b1176a6a
commit b0d8d4ee26
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 32 additions and 5 deletions

View file

@ -233,7 +233,8 @@ macro_rules! spawn_derived {
sources: SourceSet::new(), sources: SourceSet::new(),
subscribers: SubscriberSet::new(), subscribers: SubscriberSet::new(),
state: AsyncDerivedState::Clean, state: AsyncDerivedState::Clean,
version: 0 version: 0,
suspenses: Vec::new()
})); }));
let value = Arc::new(AsyncRwLock::new($initial)); let value = Arc::new(AsyncRwLock::new($initial));
let wakers = Arc::new(RwLock::new(Vec::new())); let wakers = Arc::new(RwLock::new(Vec::new()));
@ -345,14 +346,21 @@ macro_rules! spawn_derived {
// generate and assign new value // generate and assign new value
loading.store(true, Ordering::Relaxed); loading.store(true, Ordering::Relaxed);
let this_version = { let (this_version, suspense_ids) = {
let mut guard = inner.write().or_poisoned(); let mut guard = inner.write().or_poisoned();
guard.version += 1; guard.version += 1;
guard.version let version = guard.version;
let suspense_ids = mem::take(&mut guard.suspenses)
.into_iter()
.map(|sc| sc.task_id())
.collect::<Vec<_>>();
(version, suspense_ids)
}; };
let new_value = fut.await; let new_value = fut.await;
drop(suspense_ids);
let latest_version = inner.read().or_poisoned().version; let latest_version = inner.read().or_poisoned().version;
if latest_version == this_version { if latest_version == this_version {
@ -575,6 +583,11 @@ impl<T: 'static> ReadUntracked for ArcAsyncDerived<T> {
ready.await; ready.await;
drop(handle); drop(handle);
}); });
self.inner
.write()
.or_poisoned()
.suspenses
.push(suspense_context);
} }
} }
AsyncPlain::try_new(&self.value).map(ReadGuard::new) AsyncPlain::try_new(&self.value).map(ReadGuard::new)

View file

@ -1,8 +1,9 @@
use super::{ArcAsyncDerived, AsyncDerived}; use super::{inner::ArcAsyncDerivedInner, ArcAsyncDerived, AsyncDerived};
use crate::{ use crate::{
computed::suspense::SuspenseContext,
diagnostics::SpecialNonReactiveZone, diagnostics::SpecialNonReactiveZone,
graph::{AnySource, ToAnySource}, graph::{AnySource, ToAnySource},
owner::Storage, owner::{use_context, Storage},
signal::guards::{AsyncPlain, Mapped, ReadGuard}, signal::guards::{AsyncPlain, Mapped, ReadGuard},
traits::{DefinedAt, Track}, traits::{DefinedAt, Track},
unwrap_signal, unwrap_signal,
@ -63,6 +64,7 @@ where
value: Arc::clone(&self.value), value: Arc::clone(&self.value),
loading: Arc::clone(&self.loading), loading: Arc::clone(&self.loading),
wakers: Arc::clone(&self.wakers), wakers: Arc::clone(&self.wakers),
inner: Arc::clone(&self.inner),
} }
} }
} }
@ -92,6 +94,7 @@ pub struct AsyncDerivedFuture<T> {
value: Arc<async_lock::RwLock<Option<T>>>, value: Arc<async_lock::RwLock<Option<T>>>,
loading: Arc<AtomicBool>, loading: Arc<AtomicBool>,
wakers: Arc<RwLock<Vec<Waker>>>, wakers: Arc<RwLock<Vec<Waker>>>,
inner: Arc<RwLock<ArcAsyncDerivedInner>>,
} }
impl<T> Future for AsyncDerivedFuture<T> impl<T> Future for AsyncDerivedFuture<T>
@ -107,6 +110,15 @@ where
let waker = cx.waker(); let waker = cx.waker();
self.source.track(); self.source.track();
let value = self.value.read_arc(); let value = self.value.read_arc();
if let Some(suspense_context) = use_context::<SuspenseContext>() {
self.inner
.write()
.or_poisoned()
.suspenses
.push(suspense_context);
}
pin_mut!(value); pin_mut!(value);
match (self.loading.load(Ordering::Relaxed), value.poll(cx)) { match (self.loading.load(Ordering::Relaxed), value.poll(cx)) {
(true, _) => { (true, _) => {

View file

@ -1,5 +1,6 @@
use crate::{ use crate::{
channel::Sender, channel::Sender,
computed::suspense::SuspenseContext,
graph::{ graph::{
AnySource, AnySubscriber, ReactiveNode, Source, SourceSet, Subscriber, AnySource, AnySubscriber, ReactiveNode, Source, SourceSet, Subscriber,
SubscriberSet, SubscriberSet,
@ -20,6 +21,7 @@ pub(crate) struct ArcAsyncDerivedInner {
pub notifier: Sender, pub notifier: Sender,
pub state: AsyncDerivedState, pub state: AsyncDerivedState,
pub version: usize, pub version: usize,
pub suspenses: Vec<SuspenseContext>,
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]