fix memo and add tests

This commit is contained in:
Evan Almloff 2023-08-07 14:31:03 -07:00
parent 7f2049b647
commit 190c65b15d
5 changed files with 198 additions and 27 deletions

View file

@ -20,7 +20,7 @@ fn app(cx: Scope) -> Element {
button { onclick: move |_| count += 1, "Up high!" } button { onclick: move |_| count += 1, "Up high!" }
button { onclick: move |_| count -= 1, "Down low!" } button { onclick: move |_| count -= 1, "Down low!" }
if count() > 5 { if count.value() > 5 {
rsx!{ h2 { "High five!" } } rsx!{ h2 { "High five!" } }
} }
}) })

View file

@ -1,5 +1,5 @@
use crate::rt::CopyValue; use crate::rt::CopyValue;
use crate::Signal; use crate::{Signal, Write};
use std::cell::{Ref, RefMut}; use std::cell::{Ref, RefMut};
@ -104,10 +104,6 @@ macro_rules! impls {
pub fn get(&self, index: usize) -> Option<Ref<'_, T>> { pub fn get(&self, index: usize) -> Option<Ref<'_, T>> {
Ref::filter_map(self.read(), |v| v.get(index)).ok() Ref::filter_map(self.read(), |v| v.get(index)).ok()
} }
pub fn get_mut(&self, index: usize) -> Option<RefMut<'_, T>> {
RefMut::filter_map(self.write(), |v| v.get_mut(index)).ok()
}
} }
impl<T: 'static> $ty<Option<T>> { impl<T: 'static> $ty<Option<T>> {
@ -130,10 +126,6 @@ macro_rules! impls {
Ref::filter_map(self.read(), |v| v.as_ref()).ok() Ref::filter_map(self.read(), |v| v.as_ref()).ok()
} }
pub fn as_mut(&self) -> Option<RefMut<'_, T>> {
RefMut::filter_map(self.write(), |v| v.as_mut()).ok()
}
pub fn get_or_insert(&self, default: T) -> Ref<'_, T> { pub fn get_or_insert(&self, default: T) -> Ref<'_, T> {
self.get_or_insert_with(|| default) self.get_or_insert_with(|| default)
} }
@ -183,6 +175,18 @@ impl<T: Clone + 'static> IntoIterator for CopyValue<Vec<T>> {
} }
} }
impl<T: 'static> CopyValue<Vec<T>> {
pub fn get_mut(&self, index: usize) -> Option<RefMut<'_, T>> {
RefMut::filter_map(self.write(), |v| v.get_mut(index)).ok()
}
}
impl<T: 'static> CopyValue<Option<T>> {
pub fn as_mut(&self) -> Option<RefMut<'_, T>> {
RefMut::filter_map(self.write(), |v| v.as_mut()).ok()
}
}
pub struct CopySignalIterator<T: 'static> { pub struct CopySignalIterator<T: 'static> {
index: usize, index: usize,
value: Signal<Vec<T>>, value: Signal<Vec<T>>,
@ -210,3 +214,15 @@ impl<T: Clone + 'static> IntoIterator for Signal<Vec<T>> {
} }
} }
} }
impl<T: 'static> Signal<Vec<T>> {
pub fn get_mut(&self, index: usize) -> Option<Write<'_, T, Vec<T>>> {
Write::filter_map(self.write(), |v| v.get_mut(index))
}
}
impl<T: 'static> Signal<Option<T>> {
pub fn as_mut(&self) -> Option<Write<'_, T, Option<T>>> {
Write::filter_map(self.write(), |v| v.as_mut())
}
}

View file

@ -1,5 +1,6 @@
use std::{ use std::{
cell::{Ref, RefCell, RefMut}, cell::{Ref, RefCell, RefMut},
ops::{Deref, DerefMut},
rc::Rc, rc::Rc,
sync::Arc, sync::Arc,
}; };
@ -52,7 +53,7 @@ struct SignalData<T> {
subscribers: Rc<RefCell<Vec<ScopeId>>>, subscribers: Rc<RefCell<Vec<ScopeId>>>,
effect_subscribers: Rc<RefCell<Vec<Effect>>>, effect_subscribers: Rc<RefCell<Vec<Effect>>>,
update_any: Arc<dyn Fn(ScopeId)>, update_any: Arc<dyn Fn(ScopeId)>,
value: T, pub(crate) value: T,
} }
pub struct Signal<T: 'static> { pub struct Signal<T: 'static> {
@ -77,7 +78,12 @@ impl<T: 'static> Signal<T> {
pub fn read(&self) -> Ref<T> { pub fn read(&self) -> Ref<T> {
let inner = self.inner.read(); let inner = self.inner.read();
if let Some(current_scope_id) = current_scope_id() { if let Some(effect) = Effect::current() {
let mut effect_subscribers = inner.effect_subscribers.borrow_mut();
if !effect_subscribers.contains(&effect) {
effect_subscribers.push(effect);
}
} else if let Some(current_scope_id) = current_scope_id() {
log::trace!( log::trace!(
"{:?} subscribed to {:?}", "{:?} subscribed to {:?}",
self.inner.value, self.inner.value,
@ -91,16 +97,19 @@ impl<T: 'static> Signal<T> {
inner.subscribers.borrow_mut().push(unsubscriber.scope); inner.subscribers.borrow_mut().push(unsubscriber.scope);
} }
} }
if let Some(effect) = Effect::current() {
let mut effect_subscribers = inner.effect_subscribers.borrow_mut();
if !effect_subscribers.contains(&effect) {
effect_subscribers.push(effect);
}
}
Ref::map(inner, |v| &v.value) Ref::map(inner, |v| &v.value)
} }
pub fn write(&self) -> RefMut<T> { pub fn write(&self) -> Write<'_, T> {
let inner = self.inner.write();
let borrow = RefMut::map(inner, |v| &mut v.value);
Write {
write: borrow,
signal: SignalSubscriberDrop { signal: *self },
}
}
fn update_subscribers(&self) {
{ {
let inner = self.inner.read(); let inner = self.inner.read();
for &scope_id in &*inner.subscribers.borrow() { for &scope_id in &*inner.subscribers.borrow() {
@ -113,8 +122,11 @@ impl<T: 'static> Signal<T> {
} }
} }
let subscribers = let subscribers = {
{ std::mem::take(&mut *self.inner.read().effect_subscribers.borrow_mut()) }; let self_read = self.inner.read();
let mut effects = self_read.effect_subscribers.borrow_mut();
std::mem::take(&mut *effects)
};
for effect in subscribers { for effect in subscribers {
log::trace!( log::trace!(
"Write on {:?} triggered effect {:?}", "Write on {:?} triggered effect {:?}",
@ -123,9 +135,6 @@ impl<T: 'static> Signal<T> {
); );
effect.try_run(); effect.try_run();
} }
let inner = self.inner.write();
RefMut::map(inner, |v| &mut v.value)
} }
pub fn set(&self, value: T) { pub fn set(&self, value: T) {
@ -154,3 +163,54 @@ impl<T: 'static> PartialEq for Signal<T> {
self.inner == other.inner self.inner == other.inner
} }
} }
struct SignalSubscriberDrop<T: 'static> {
signal: Signal<T>,
}
impl<T: 'static> Drop for SignalSubscriberDrop<T> {
fn drop(&mut self) {
self.signal.update_subscribers();
}
}
pub struct Write<'a, T: 'static, I: 'static = T> {
write: RefMut<'a, T>,
signal: SignalSubscriberDrop<I>,
}
impl<'a, T: 'static, I: 'static> Write<'a, T, I> {
pub fn map<O>(myself: Self, f: impl FnOnce(&mut T) -> &mut O) -> Write<'a, O, I> {
let Self { write, signal } = myself;
Write {
write: RefMut::map(write, f),
signal,
}
}
pub fn filter_map<O>(
myself: Self,
f: impl FnOnce(&mut T) -> Option<&mut O>,
) -> Option<Write<'a, O, I>> {
let Self { write, signal } = myself;
let write = RefMut::filter_map(write, f).ok();
write.map(|write| Write {
write,
signal: signal,
})
}
}
impl<'a, T: 'static> Deref for Write<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&*self.write
}
}
impl<T> DerefMut for Write<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut *self.write
}
}

View file

@ -30,8 +30,8 @@ pub fn memo<R: PartialEq>(mut f: impl FnMut() -> R + 'static) -> Signal<R> {
effect.callback.value.set(Box::new(move || { effect.callback.value.set(Box::new(move || {
let value = f(); let value = f();
let changed = { let changed = {
let state = state.read(); let old = state.inner.read();
value != *state value != old.value
}; };
if changed { if changed {
state.set(value) state.set(value)

View file

@ -2,13 +2,14 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::rc::Rc; use std::rc::Rc;
use dioxus::html::p;
use dioxus::prelude::*; use dioxus::prelude::*;
use dioxus_core::ElementId; use dioxus_core::ElementId;
use dioxus_signals::*; use dioxus_signals::*;
#[test] #[test]
fn memos_rerun() { fn memos_rerun() {
simple_logger::SimpleLogger::new().init().unwrap(); let _ = simple_logger::SimpleLogger::new().init();
#[derive(Default)] #[derive(Default)]
struct RunCounter { struct RunCounter {
@ -48,3 +49,97 @@ fn memos_rerun() {
assert_eq!(current_counter.component, 1); assert_eq!(current_counter.component, 1);
assert_eq!(current_counter.effect, 2); assert_eq!(current_counter.effect, 2);
} }
#[test]
fn memos_prevents_component_rerun() {
let _ = simple_logger::SimpleLogger::new().init();
#[derive(Default)]
struct RunCounter {
component: usize,
effect: usize,
}
let counter = Rc::new(RefCell::new(RunCounter::default()));
let mut dom = VirtualDom::new_with_props(
|cx| {
let mut signal = use_signal(cx, || 0);
if cx.generation() == 1 {
*signal.write() = 0;
}
if cx.generation() == 2 {
println!("Writing to signal");
*signal.write() = 1;
}
render! {
Child {
signal: signal,
counter: cx.props.clone(),
}
}
},
counter.clone(),
);
#[derive(Default, Props)]
struct ChildProps {
signal: Signal<usize>,
counter: Rc<RefCell<RunCounter>>,
}
impl PartialEq for ChildProps {
fn eq(&self, other: &Self) -> bool {
self.signal == other.signal
}
}
fn Child(cx: Scope<ChildProps>) -> Element {
let counter = &cx.props.counter;
let signal = cx.props.signal;
counter.borrow_mut().component += 1;
let memo = cx.use_hook(move || {
to_owned![counter];
memo(move || {
counter.borrow_mut().effect += 1;
println!("Signal: {:?}", signal);
signal.value()
})
});
match cx.generation() {
0 => {
assert_eq!(memo.value(), 0);
}
1 => {
assert_eq!(memo.value(), 1);
}
_ => panic!("Unexpected generation"),
}
render! {
div {}
}
}
let _ = dom.rebuild().santize();
dom.mark_dirty(ScopeId(0));
dom.render_immediate();
{
let current_counter = counter.borrow();
assert_eq!(current_counter.component, 1);
assert_eq!(current_counter.effect, 2);
}
dom.mark_dirty(ScopeId(0));
dom.render_immediate();
dom.render_immediate();
{
let current_counter = counter.borrow();
assert_eq!(current_counter.component, 2);
assert_eq!(current_counter.effect, 3);
}
}