fix: ensure Resource always tracks its source, and does not double-run (#2948)

This commit is contained in:
Greg Johnston 2024-09-07 18:57:31 -04:00 committed by GitHub
parent 57bd343f4a
commit 827cc0bdfa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 147 additions and 32 deletions

View file

@ -91,9 +91,11 @@ where
}
};
let data = ArcAsyncDerived::new_with_manual_dependencies(initial, fun);
let data = ArcAsyncDerived::new_with_manual_dependencies(
initial, fun, &source,
);
if is_ready {
source.with(|_| ());
source.with_untracked(|_| ());
source.add_subscriber(data.to_any_subscriber());
}

View file

@ -1,5 +1,6 @@
use super::{
inner::ArcAsyncDerivedInner, AsyncDerivedReadyFuture, ScopedFuture,
inner::{ArcAsyncDerivedInner, AsyncDerivedState},
AsyncDerivedReadyFuture, ScopedFuture,
};
#[cfg(feature = "sandboxed-arenas")]
use crate::owner::Sandboxed;
@ -12,8 +13,13 @@ use crate::{
SubscriberSet, ToAnySource, ToAnySubscriber, WithObserver,
},
owner::{use_context, Owner},
signal::guards::{AsyncPlain, ReadGuard, WriteGuard},
traits::{DefinedAt, ReadUntracked, Trigger, UntrackableGuard, Writeable},
signal::{
guards::{AsyncPlain, ReadGuard, WriteGuard},
ArcTrigger,
},
traits::{
DefinedAt, ReadUntracked, Track, Trigger, UntrackableGuard, Writeable,
},
transition::AsyncTransition,
};
use any_spawner::Executor;
@ -214,7 +220,7 @@ impl<T> DefinedAt for ArcAsyncDerived<T> {
// whether `fun` returns a `Future` that is `Send`. Doing it as a function would,
// as far as I can tell, require repeating most of the function body.
macro_rules! spawn_derived {
($spawner:expr, $initial:ident, $fun:ident, $should_spawn:literal, $force_spawn:literal, $should_track:literal) => {{
($spawner:expr, $initial:ident, $fun:ident, $should_spawn:literal, $force_spawn:literal, $should_track:literal, $source:expr) => {{
let (notifier, mut rx) = channel();
let is_ready = $initial.is_some() && !$force_spawn;
@ -225,7 +231,7 @@ macro_rules! spawn_derived {
notifier,
sources: SourceSet::new(),
subscribers: SubscriberSet::new(),
dirty: false
state: AsyncDerivedState::Clean
}));
let value = Arc::new(AsyncRwLock::new($initial));
let wakers = Arc::new(RwLock::new(Vec::new()));
@ -245,7 +251,10 @@ macro_rules! spawn_derived {
.with_observer(|| ScopedFuture::new($fun()))
})
} else {
crate::untrack(|| ScopedFuture::new($fun()))
owner.with_cleanup(|| {
any_subscriber
.with_observer_untracked(|| ScopedFuture::new($fun()))
})
};
#[cfg(feature = "sandboxed-arenas")]
let initial_fut = Sandboxed::new(initial_fut);
@ -263,7 +272,7 @@ macro_rules! spawn_derived {
Some(orig_value) => {
let mut guard = this.inner.write().or_poisoned();
guard.dirty = false;
guard.state = AsyncDerivedState::Clean;
*value.blocking_write() = Some(orig_value);
this.loading.store(false, Ordering::Relaxed);
(true, None)
@ -283,6 +292,10 @@ macro_rules! spawn_derived {
any_subscriber.mark_dirty();
}
if let Some(source) = $source {
any_subscriber.with_observer(|| source.track());
}
if $should_spawn {
$spawner({
let value = Arc::downgrade(&this.value);
@ -291,7 +304,14 @@ macro_rules! spawn_derived {
let loading = Arc::downgrade(&this.loading);
let fut = async move {
while rx.next().await.is_some() {
if any_subscriber.with_observer(|| any_subscriber.update_if_necessary()) || first_run.is_some() {
let update_if_necessary = if $should_track {
any_subscriber
.with_observer(|| any_subscriber.update_if_necessary())
} else {
any_subscriber
.with_observer_untracked(|| any_subscriber.update_if_necessary())
};
if update_if_necessary || first_run.is_some() {
match (value.upgrade(), inner.upgrade(), wakers.upgrade(), loading.upgrade()) {
(Some(value), Some(inner), Some(wakers), Some(loading)) => {
// generate new Future
@ -303,7 +323,10 @@ macro_rules! spawn_derived {
.with_observer(|| ScopedFuture::new($fun()))
})
} else {
crate::untrack(|| ScopedFuture::new($fun()))
owner.with_cleanup(|| {
any_subscriber
.with_observer_untracked(|| ScopedFuture::new($fun()))
})
};
#[cfg(feature = "sandboxed-arenas")]
let fut = Sandboxed::new(fut);
@ -360,7 +383,7 @@ impl<T: 'static> ArcAsyncDerived<T> {
) {
loading.store(false, Ordering::Relaxed);
inner.write().or_poisoned().dirty = true;
inner.write().or_poisoned().state = AsyncDerivedState::Notifying;
if let Some(ready_tx) = ready_tx {
// if it's an Err, that just means the Receiver was dropped
@ -379,6 +402,8 @@ impl<T: 'static> ArcAsyncDerived<T> {
for waker in mem::take(&mut *wakers.write().or_poisoned()) {
waker.wake();
}
inner.write().or_poisoned().state = AsyncDerivedState::Clean;
}
}
@ -413,7 +438,8 @@ impl<T: 'static> ArcAsyncDerived<T> {
fun,
true,
true,
true
true,
None::<ArcTrigger>
);
this
}
@ -425,13 +451,15 @@ impl<T: 'static> ArcAsyncDerived<T> {
/// where you do not want to run the run the `Future` unnecessarily.
#[doc(hidden)]
#[track_caller]
pub fn new_with_manual_dependencies<Fut>(
pub fn new_with_manual_dependencies<Fut, S>(
initial_value: Option<T>,
fun: impl Fn() -> Fut + Send + Sync + 'static,
source: &S,
) -> Self
where
T: Send + Sync + 'static,
Fut: Future<Output = T> + Send + 'static,
S: Track,
{
let (this, _) = spawn_derived!(
Executor::spawn,
@ -439,7 +467,8 @@ impl<T: 'static> ArcAsyncDerived<T> {
fun,
true,
false,
false
false,
Some(source)
);
this
}
@ -475,7 +504,8 @@ impl<T: 'static> ArcAsyncDerived<T> {
fun,
true,
true,
true
true,
None::<ArcTrigger>
);
this
}
@ -512,7 +542,8 @@ impl<T: 'static> ArcAsyncDerived<SendWrapper<T>> {
fun,
false,
false,
true
true,
None::<ArcTrigger>
);
this
}

View file

@ -18,18 +18,30 @@ pub(crate) struct ArcAsyncDerivedInner {
pub subscribers: SubscriberSet,
// when a source changes, notifying this will cause the async work to rerun
pub notifier: Sender,
pub dirty: bool,
pub state: AsyncDerivedState,
}
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum AsyncDerivedState {
Clean,
Dirty,
Notifying,
}
impl ReactiveNode for RwLock<ArcAsyncDerivedInner> {
fn mark_dirty(&self) {
let mut lock = self.write().or_poisoned();
lock.dirty = true;
lock.notifier.notify();
if lock.state != AsyncDerivedState::Notifying {
lock.state = AsyncDerivedState::Dirty;
lock.notifier.notify();
}
}
fn mark_check(&self) {
self.write().or_poisoned().notifier.notify();
let mut lock = self.write().or_poisoned();
if lock.state != AsyncDerivedState::Notifying {
lock.notifier.notify();
}
}
fn mark_subscribers_check(&self) {
@ -41,11 +53,14 @@ impl ReactiveNode for RwLock<ArcAsyncDerivedInner> {
fn update_if_necessary(&self) -> bool {
let mut guard = self.write().or_poisoned();
let (is_dirty, sources) =
(guard.dirty, (!guard.dirty).then(|| guard.sources.clone()));
let (is_dirty, sources) = (
guard.state == AsyncDerivedState::Dirty,
(guard.state != AsyncDerivedState::Notifying)
.then(|| guard.sources.clone()),
);
if is_dirty {
guard.dirty = false;
guard.state = AsyncDerivedState::Clean;
return true;
}
drop(guard);

View file

@ -1,9 +1,17 @@
use super::{node::ReactiveNode, AnySource};
#[cfg(debug_assertions)]
use crate::diagnostics::SpecialNonReactiveZone;
use core::{fmt::Debug, hash::Hash};
use std::{cell::RefCell, mem, sync::Weak};
thread_local! {
static OBSERVER: RefCell<Option<AnySubscriber>> = const { RefCell::new(None) };
static OBSERVER: RefCell<Option<ObserverState>> = const { RefCell::new(None) };
}
#[derive(Debug)]
struct ObserverState {
subscriber: AnySubscriber,
untracked: bool,
}
/// The current reactive observer.
@ -25,24 +33,67 @@ impl Drop for SetObserverOnDrop {
impl Observer {
/// Returns the current observer, if any.
pub fn get() -> Option<AnySubscriber> {
OBSERVER.with_borrow(Clone::clone)
OBSERVER.with_borrow(|obs| {
obs.as_ref().and_then(|obs| {
if obs.untracked {
None
} else {
Some(obs.subscriber.clone())
}
})
})
}
pub(crate) fn is(observer: &AnySubscriber) -> bool {
OBSERVER.with_borrow(|o| o.as_ref() == Some(observer))
OBSERVER.with_borrow(|o| {
o.as_ref().map(|o| &o.subscriber) == Some(observer)
})
}
fn take() -> SetObserverOnDrop {
SetObserverOnDrop(OBSERVER.with_borrow_mut(Option::take))
SetObserverOnDrop(
OBSERVER.with_borrow_mut(Option::take).map(|o| o.subscriber),
)
}
fn set(observer: Option<AnySubscriber>) {
OBSERVER.with_borrow_mut(|o| *o = observer);
OBSERVER.with_borrow_mut(|o| {
*o = observer.map(|subscriber| ObserverState {
subscriber,
untracked: false,
})
});
}
fn replace(observer: Option<AnySubscriber>) -> SetObserverOnDrop {
SetObserverOnDrop(
OBSERVER.with(|o| mem::replace(&mut *o.borrow_mut(), observer)),
OBSERVER
.with(|o| {
mem::replace(
&mut *o.borrow_mut(),
observer.map(|subscriber| ObserverState {
subscriber,
untracked: false,
}),
)
})
.map(|o| o.subscriber),
)
}
fn replace_untracked(observer: Option<AnySubscriber>) -> SetObserverOnDrop {
SetObserverOnDrop(
OBSERVER
.with(|o| {
mem::replace(
&mut *o.borrow_mut(),
observer.map(|subscriber| ObserverState {
subscriber,
untracked: true,
}),
)
})
.map(|o| o.subscriber),
)
}
}
@ -155,22 +206,38 @@ impl ReactiveNode for AnySubscriber {
pub trait WithObserver {
/// Runs the given function with this subscriber as the thread-local [`Observer`].
fn with_observer<T>(&self, fun: impl FnOnce() -> T) -> T;
/// Runs the given function with this subscriber as the thread-local [`Observer`],
/// but without tracking dependencies.
fn with_observer_untracked<T>(&self, fun: impl FnOnce() -> T) -> T;
}
impl WithObserver for AnySubscriber {
/// Runs the given function with this subscriber as the thread-local [`Observer`].
fn with_observer<T>(&self, fun: impl FnOnce() -> T) -> T {
let _prev = Observer::replace(Some(self.clone()));
fun()
}
fn with_observer_untracked<T>(&self, fun: impl FnOnce() -> T) -> T {
#[cfg(debug_assertions)]
let _guard = SpecialNonReactiveZone::enter();
let _prev = Observer::replace_untracked(Some(self.clone()));
fun()
}
}
impl WithObserver for Option<AnySubscriber> {
/// Runs the given function with this subscriber as the thread-local [`Observer`].
fn with_observer<T>(&self, fun: impl FnOnce() -> T) -> T {
let _prev = Observer::replace(self.clone());
fun()
}
fn with_observer_untracked<T>(&self, fun: impl FnOnce() -> T) -> T {
#[cfg(debug_assertions)]
let _guard = SpecialNonReactiveZone::enter();
let _prev = Observer::replace_untracked(self.clone());
fun()
}
}
impl Debug for AnySubscriber {