fix: untrack in the async block of a Resource (closes #2937) (#2941)

This commit is contained in:
Greg Johnston 2024-09-06 17:23:40 -04:00 committed by GitHub
parent 32bea69c28
commit 48c2148589
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 51 additions and 18 deletions

View file

@ -91,8 +91,7 @@ where
}
};
let data =
ArcAsyncDerived::new_with_initial_without_spawning(initial, fun);
let data = ArcAsyncDerived::new_with_manual_dependencies(initial, fun);
if is_ready {
source.with(|_| ());
source.add_subscriber(data.to_any_subscriber());

View file

@ -214,7 +214,7 @@ impl<T> DefinedAt for ArcAsyncDerived<T> {
// whether `fun` returns a `Future` that is `Send`. Doing it as a function would,
// as far as I can tell, require repeating most of the function body.
macro_rules! spawn_derived {
($spawner:expr, $initial:ident, $fun:ident, $should_spawn:literal, $force_spawn:literal) => {{
($spawner:expr, $initial:ident, $fun:ident, $should_spawn:literal, $force_spawn:literal, $should_track:literal) => {{
let (notifier, mut rx) = channel();
let is_ready = $initial.is_some() && !$force_spawn;
@ -239,10 +239,14 @@ macro_rules! spawn_derived {
loading: Arc::new(AtomicBool::new(!is_ready)),
};
let any_subscriber = this.to_any_subscriber();
let initial_fut = owner.with_cleanup(|| {
any_subscriber
.with_observer(|| ScopedFuture::new($fun()))
});
let initial_fut = if $should_track {
owner.with_cleanup(|| {
any_subscriber
.with_observer(|| ScopedFuture::new($fun()))
})
} else {
crate::untrack(|| ScopedFuture::new($fun()))
};
#[cfg(feature = "sandboxed-arenas")]
let initial_fut = Sandboxed::new(initial_fut);
let mut initial_fut = Box::pin(initial_fut);
@ -293,10 +297,14 @@ macro_rules! spawn_derived {
// generate new Future
let owner = inner.read().or_poisoned().owner.clone();
let fut = initial_fut.take().unwrap_or_else(|| {
let fut = owner.with_cleanup(|| {
any_subscriber
.with_observer(|| ScopedFuture::new($fun()))
});
let fut = if $should_track {
owner.with_cleanup(|| {
any_subscriber
.with_observer(|| ScopedFuture::new($fun()))
})
} else {
crate::untrack(|| ScopedFuture::new($fun()))
};
#[cfg(feature = "sandboxed-arenas")]
let fut = Sandboxed::new(fut);
Box::pin(fut)
@ -399,8 +407,14 @@ impl<T: 'static> ArcAsyncDerived<T> {
T: Send + Sync + 'static,
Fut: Future<Output = T> + Send + 'static,
{
let (this, _) =
spawn_derived!(Executor::spawn, initial_value, fun, true, true);
let (this, _) = spawn_derived!(
Executor::spawn,
initial_value,
fun,
true,
true,
true
);
this
}
@ -411,7 +425,7 @@ impl<T: 'static> ArcAsyncDerived<T> {
/// where you do not want to run the run the `Future` unnecessarily.
#[doc(hidden)]
#[track_caller]
pub fn new_with_initial_without_spawning<Fut>(
pub fn new_with_manual_dependencies<Fut>(
initial_value: Option<T>,
fun: impl Fn() -> Fut + Send + Sync + 'static,
) -> Self
@ -419,8 +433,14 @@ impl<T: 'static> ArcAsyncDerived<T> {
T: Send + Sync + 'static,
Fut: Future<Output = T> + Send + 'static,
{
let (this, _) =
spawn_derived!(Executor::spawn, initial_value, fun, true, false);
let (this, _) = spawn_derived!(
Executor::spawn,
initial_value,
fun,
true,
false,
false
);
this
}
@ -454,6 +474,7 @@ impl<T: 'static> ArcAsyncDerived<T> {
initial_value,
fun,
true,
true,
true
);
this
@ -485,8 +506,14 @@ impl<T: 'static> ArcAsyncDerived<SendWrapper<T>> {
SendWrapper::new(value)
}
};
let (this, _) =
spawn_derived!(Executor::spawn_local, initial, fun, false, false);
let (this, _) = spawn_derived!(
Executor::spawn_local,
initial,
fun,
false,
false,
true
);
this
}
}

View file

@ -1,5 +1,6 @@
use super::{ArcAsyncDerived, AsyncDerived};
use crate::{
diagnostics::SpecialNonReactiveZone,
graph::{AnySource, ToAnySource},
owner::Storage,
signal::guards::{AsyncPlain, Mapped, ReadGuard},
@ -36,6 +37,8 @@ impl Future for AsyncDerivedReadyFuture {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
#[cfg(debug_assertions)]
let _guard = SpecialNonReactiveZone::enter();
let waker = cx.waker();
self.source.track();
if self.loading.load(Ordering::Relaxed) {
@ -99,6 +102,8 @@ where
#[track_caller]
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
#[cfg(debug_assertions)]
let _guard = SpecialNonReactiveZone::enter();
let waker = cx.waker();
self.source.track();
let value = self.value.read_arc();
@ -163,6 +168,8 @@ where
type Output = AsyncDerivedGuard<T>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
#[cfg(debug_assertions)]
let _guard = SpecialNonReactiveZone::enter();
let waker = cx.waker();
self.source.track();
let value = self.value.read_arc();