From af034f3b5e1324ed50a59c01d34e5221b34e2f9f Mon Sep 17 00:00:00 2001 From: Nutomic Date: Mon, 27 May 2024 15:34:58 +0200 Subject: [PATCH] Unit tests and cleanup for outgoing federation code (#4733) * test setup * code cleanup * cleanup * move stats to own file * basic test working * cleanup * processes test * more test cases * fmt * add file * add assert * error handling * fmt * use instance id instead of domain for stats channel --- Cargo.lock | 2 + crates/db_schema/src/impls/instance.rs | 6 +- crates/db_schema/src/newtypes.rs | 6 +- crates/federate/Cargo.toml | 4 + crates/federate/src/lib.rs | 438 ++++++++++++++++--------- crates/federate/src/stats.rs | 97 ++++++ crates/federate/src/util.rs | 27 +- crates/federate/src/worker.rs | 90 +++-- scripts/test.sh | 5 +- src/lib.rs | 10 +- 10 files changed, 449 insertions(+), 236 deletions(-) create mode 100644 crates/federate/src/stats.rs diff --git a/Cargo.lock b/Cargo.lock index 948a18125..e7d1455d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2986,10 +2986,12 @@ dependencies = [ "lemmy_apub", "lemmy_db_schema", "lemmy_db_views_actor", + "lemmy_utils", "moka", "once_cell", "reqwest 0.11.27", "serde_json", + "serial_test", "tokio", "tokio-util", "tracing", diff --git a/crates/db_schema/src/impls/instance.rs b/crates/db_schema/src/impls/instance.rs index 8d2e4b75a..94bf909a3 100644 --- a/crates/db_schema/src/impls/instance.rs +++ b/crates/db_schema/src/impls/instance.rs @@ -94,11 +94,15 @@ impl Instance { .await } - #[cfg(test)] + /// Only for use in tests pub async fn delete_all(pool: &mut DbPool<'_>) -> Result { let conn = &mut get_conn(pool).await?; + diesel::delete(federation_queue_state::table) + .execute(conn) + .await?; diesel::delete(instance::table).execute(conn).await } + pub async fn allowlist(pool: &mut DbPool<'_>) -> Result, Error> { let conn = &mut get_conn(pool).await?; instance::table diff --git a/crates/db_schema/src/newtypes.rs b/crates/db_schema/src/newtypes.rs index e0c516037..10abfaec4 100644 --- a/crates/db_schema/src/newtypes.rs +++ b/crates/db_schema/src/newtypes.rs @@ -127,11 +127,13 @@ pub struct LanguageId(pub i32); /// The comment reply id. pub struct CommentReplyId(i32); -#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Serialize, Deserialize, Default)] +#[derive( + Debug, Copy, Clone, Hash, Eq, PartialEq, Serialize, Deserialize, Default, Ord, PartialOrd, +)] #[cfg_attr(feature = "full", derive(DieselNewType, TS))] #[cfg_attr(feature = "full", ts(export))] /// The instance id. -pub struct InstanceId(i32); +pub struct InstanceId(pub i32); #[derive( Debug, Copy, Clone, Hash, Eq, PartialEq, Serialize, Deserialize, Default, PartialOrd, Ord, diff --git a/crates/federate/Cargo.toml b/crates/federate/Cargo.toml index 00e2f5a60..2405d3af0 100644 --- a/crates/federate/Cargo.toml +++ b/crates/federate/Cargo.toml @@ -19,6 +19,7 @@ lemmy_api_common.workspace = true lemmy_apub.workspace = true lemmy_db_schema = { workspace = true, features = ["full"] } lemmy_db_views_actor.workspace = true +lemmy_utils.workspace = true activitypub_federation.workspace = true anyhow.workspace = true @@ -33,3 +34,6 @@ tokio = { workspace = true, features = ["full"] } tracing.workspace = true moka.workspace = true tokio-util = "0.7.11" + +[dev-dependencies] +serial_test = { workspace = true } diff --git a/crates/federate/src/lib.rs b/crates/federate/src/lib.rs index e6145dad9..d3876226f 100644 --- a/crates/federate/src/lib.rs +++ b/crates/federate/src/lib.rs @@ -1,20 +1,22 @@ use crate::{util::CancellableTask, worker::InstanceWorker}; use activitypub_federation::config::FederationConfig; -use chrono::{Local, Timelike}; -use lemmy_api_common::{context::LemmyContext, federate_retry_sleep_duration}; +use lemmy_api_common::context::LemmyContext; use lemmy_db_schema::{ newtypes::InstanceId, source::{federation_queue_state::FederationQueueState, instance::Instance}, - utils::{ActualDbPool, DbPool}, }; +use lemmy_utils::error::LemmyResult; +use stats::receive_print_stats; use std::{collections::HashMap, time::Duration}; use tokio::{ - sync::mpsc::{unbounded_channel, UnboundedReceiver}, + sync::mpsc::{unbounded_channel, UnboundedSender}, + task::JoinHandle, time::sleep, }; use tokio_util::sync::CancellationToken; use tracing::info; +mod stats; mod util; mod worker; @@ -32,175 +34,293 @@ pub struct Opts { pub process_index: i32, } -async fn start_stop_federation_workers( +pub struct SendManager { opts: Opts, - pool: ActualDbPool, - federation_config: FederationConfig, - cancel: CancellationToken, -) -> anyhow::Result<()> { - let mut workers = HashMap::::new(); + workers: HashMap, + context: FederationConfig, + stats_sender: UnboundedSender<(InstanceId, FederationQueueState)>, + exit_print: JoinHandle<()>, +} - let (stats_sender, stats_receiver) = unbounded_channel(); - let exit_print = tokio::spawn(receive_print_stats(pool.clone(), stats_receiver)); - let pool2 = &mut DbPool::Pool(&pool); - let process_index = opts.process_index - 1; - let local_domain = federation_config.settings().get_hostname_without_port()?; - info!( - "Starting federation workers for process count {} and index {}", - opts.process_count, process_index - ); - loop { - let mut total_count = 0; - let mut dead_count = 0; - let mut disallowed_count = 0; - for (instance, allowed, is_dead) in - Instance::read_federated_with_blocked_and_dead(pool2).await? - { - if instance.domain == local_domain { - continue; - } - if instance.id.inner() % opts.process_count != process_index { - continue; - } - total_count += 1; - if !allowed { - disallowed_count += 1; - } - if is_dead { - dead_count += 1; - } - let should_federate = allowed && !is_dead; - if should_federate { - if workers.contains_key(&instance.id) { - // worker already running +impl SendManager { + pub fn new(opts: Opts, context: FederationConfig) -> Self { + assert!(opts.process_count > 0); + assert!(opts.process_index > 0); + assert!(opts.process_index <= opts.process_count); + + let (stats_sender, stats_receiver) = unbounded_channel(); + Self { + opts, + workers: HashMap::new(), + stats_sender, + exit_print: tokio::spawn(receive_print_stats( + context.inner_pool().clone(), + stats_receiver, + )), + context, + } + } + + pub fn run(mut self) -> CancellableTask { + CancellableTask::spawn(WORKER_EXIT_TIMEOUT, move |cancel| async move { + self.do_loop(cancel).await?; + self.cancel().await?; + Ok(()) + }) + } + + async fn do_loop(&mut self, cancel: CancellationToken) -> LemmyResult<()> { + let process_index = self.opts.process_index - 1; + info!( + "Starting federation workers for process count {} and index {}", + self.opts.process_count, process_index + ); + let local_domain = self.context.settings().get_hostname_without_port()?; + let mut pool = self.context.pool(); + loop { + let mut total_count = 0; + let mut dead_count = 0; + let mut disallowed_count = 0; + for (instance, allowed, is_dead) in + Instance::read_federated_with_blocked_and_dead(&mut pool).await? + { + if instance.domain == local_domain { continue; } - // create new worker - let config = federation_config.clone(); - let stats_sender = stats_sender.clone(); - let pool = pool.clone(); - workers.insert( - instance.id, - CancellableTask::spawn(WORKER_EXIT_TIMEOUT, move |stop| { - let instance = instance.clone(); - let req_data = config.clone().to_request_data(); - let stats_sender = stats_sender.clone(); - let pool = pool.clone(); - async move { - InstanceWorker::init_and_loop( - instance, - req_data, - &mut DbPool::Pool(&pool), - stop, - stats_sender, - ) - .await + if instance.id.inner() % self.opts.process_count != process_index { + continue; + } + total_count += 1; + if !allowed { + disallowed_count += 1; + } + if is_dead { + dead_count += 1; + } + let should_federate = allowed && !is_dead; + if should_federate { + if self.workers.contains_key(&instance.id) { + // worker already running + continue; + } + // create new worker + let instance = instance.clone(); + let req_data = self.context.to_request_data(); + let stats_sender = self.stats_sender.clone(); + self.workers.insert( + instance.id, + CancellableTask::spawn(WORKER_EXIT_TIMEOUT, move |stop| async move { + InstanceWorker::init_and_loop(instance, req_data, stop, stats_sender).await?; + Ok(()) + }), + ); + } else if !should_federate { + if let Some(worker) = self.workers.remove(&instance.id) { + if let Err(e) = worker.cancel().await { + tracing::error!("error stopping worker: {e}"); } - }), - ); - } else if !should_federate { - if let Some(worker) = workers.remove(&instance.id) { - if let Err(e) = worker.cancel().await { - tracing::error!("error stopping worker: {e}"); } } } - } - let worker_count = workers.len(); - tracing::info!("Federating to {worker_count}/{total_count} instances ({dead_count} dead, {disallowed_count} disallowed)"); - tokio::select! { - () = sleep(INSTANCES_RECHECK_DELAY) => {}, - _ = cancel.cancelled() => { break; } - } - } - drop(stats_sender); - tracing::warn!( - "Waiting for {} workers ({:.2?} max)", - workers.len(), - WORKER_EXIT_TIMEOUT - ); - // the cancel futures need to be awaited concurrently for the shutdown processes to be triggered - // concurrently - futures::future::join_all(workers.into_values().map(util::CancellableTask::cancel)).await; - exit_print.await?; - Ok(()) -} - -/// starts and stops federation workers depending on which instances are on db -/// await the returned future to stop/cancel all workers gracefully -pub fn start_stop_federation_workers_cancellable( - opts: Opts, - pool: ActualDbPool, - config: FederationConfig, -) -> CancellableTask { - CancellableTask::spawn(WORKER_EXIT_TIMEOUT, move |stop| { - let opts = opts.clone(); - let pool = pool.clone(); - let config = config.clone(); - async move { start_stop_federation_workers(opts, pool, config, stop).await } - }) -} - -/// every 60s, print the state for every instance. exits if the receiver is done (all senders -/// dropped) -async fn receive_print_stats( - pool: ActualDbPool, - mut receiver: UnboundedReceiver<(String, FederationQueueState)>, -) { - let pool = &mut DbPool::Pool(&pool); - let mut printerval = tokio::time::interval(Duration::from_secs(60)); - printerval.tick().await; // skip first - let mut stats = HashMap::new(); - loop { - tokio::select! { - ele = receiver.recv() => { - let Some((domain, ele)) = ele else { - print_stats(pool, &stats).await; - return; - }; - stats.insert(domain, ele); - }, - _ = printerval.tick() => { - print_stats(pool, &stats).await; + let worker_count = self.workers.len(); + tracing::info!("Federating to {worker_count}/{total_count} instances ({dead_count} dead, {disallowed_count} disallowed)"); + tokio::select! { + () = sleep(INSTANCES_RECHECK_DELAY) => {}, + _ = cancel.cancelled() => { return Ok(()) } } } } + + pub async fn cancel(self) -> LemmyResult<()> { + drop(self.stats_sender); + tracing::warn!( + "Waiting for {} workers ({:.2?} max)", + self.workers.len(), + WORKER_EXIT_TIMEOUT + ); + // the cancel futures need to be awaited concurrently for the shutdown processes to be triggered + // concurrently + futures::future::join_all( + self + .workers + .into_values() + .map(util::CancellableTask::cancel), + ) + .await; + self.exit_print.await?; + Ok(()) + } } -async fn print_stats(pool: &mut DbPool<'_>, stats: &HashMap) { - let last_id = crate::util::get_latest_activity_id(pool).await; - let Ok(last_id) = last_id else { - tracing::error!("could not get last id"); - return; +#[cfg(test)] +#[allow(clippy::unwrap_used)] +#[allow(clippy::indexing_slicing)] +mod test { + + use super::*; + use activitypub_federation::config::Data; + use chrono::DateTime; + use lemmy_db_schema::source::{ + federation_allowlist::FederationAllowList, + federation_blocklist::FederationBlockList, + instance::InstanceForm, }; - // it's expected that the values are a bit out of date, everything < SAVE_STATE_EVERY should be - // considered up to date - tracing::info!( - "Federation state as of {}:", - Local::now() - .with_nanosecond(0) - .expect("0 is valid nanos") - .to_rfc3339() - ); - // todo: more stats (act/sec, avg http req duration) - let mut ok_count = 0; - let mut behind_count = 0; - for (domain, stat) in stats { - let behind = last_id.0 - stat.last_successful_id.map(|e| e.0).unwrap_or(0); - if stat.fail_count > 0 { - tracing::info!( - "{}: Warning. {} behind, {} consecutive fails, current retry delay {:.2?}", - domain, - behind, - stat.fail_count, - federate_retry_sleep_duration(stat.fail_count) - ); - } else if behind > 0 { - tracing::debug!("{}: Ok. {} activities behind", domain, behind); - behind_count += 1; - } else { - ok_count += 1; + use lemmy_utils::error::LemmyError; + use serial_test::serial; + use std::{ + collections::HashSet, + sync::{Arc, Mutex}, + }; + use tokio::{spawn, time::sleep}; + + struct TestData { + send_manager: SendManager, + context: Data, + instances: Vec, + } + impl TestData { + async fn init(process_count: i32, process_index: i32) -> LemmyResult { + let context = LemmyContext::init_test_context().await; + let opts = Opts { + process_count, + process_index, + }; + let federation_config = FederationConfig::builder() + .domain("local.com") + .app_data(context.clone()) + .build() + .await?; + + let pool = &mut context.pool(); + let instances = vec![ + Instance::read_or_create(pool, "alpha.com".to_string()).await?, + Instance::read_or_create(pool, "beta.com".to_string()).await?, + Instance::read_or_create(pool, "gamma.com".to_string()).await?, + ]; + + let send_manager = SendManager::new(opts, federation_config); + Ok(Self { + send_manager, + context, + instances, + }) + } + + async fn run(&mut self) -> LemmyResult<()> { + // start it and cancel after workers are running + let cancel = CancellationToken::new(); + let cancel_ = cancel.clone(); + spawn(async move { + sleep(Duration::from_millis(100)).await; + cancel_.cancel(); + }); + self.send_manager.do_loop(cancel.clone()).await?; + Ok(()) + } + + async fn cleanup(self) -> LemmyResult<()> { + self.send_manager.cancel().await?; + Instance::delete_all(&mut self.context.pool()).await?; + Ok(()) } } - tracing::info!("{ok_count} others up to date. {behind_count} instances behind."); + + /// Basic test with default params and only active/allowed instances + #[tokio::test] + #[serial] + async fn test_send_manager() -> LemmyResult<()> { + let mut data = TestData::init(1, 1).await?; + + data.run().await?; + assert_eq!(3, data.send_manager.workers.len()); + let workers: HashSet<_> = data.send_manager.workers.keys().cloned().collect(); + let instances: HashSet<_> = data.instances.iter().map(|i| i.id).collect(); + assert_eq!(instances, workers); + + data.cleanup().await?; + Ok(()) + } + + /// Running with multiple processes should start correct workers + #[tokio::test] + #[serial] + async fn test_send_manager_processes() -> LemmyResult<()> { + let active = Arc::new(Mutex::new(vec![])); + let execute = |count, index, active: Arc>>| async move { + let mut data = TestData::init(count, index).await?; + data.run().await?; + assert_eq!(1, data.send_manager.workers.len()); + for k in data.send_manager.workers.keys() { + active.lock().unwrap().push(*k); + } + data.cleanup().await?; + Ok::<(), LemmyError>(()) + }; + execute(3, 1, active.clone()).await?; + execute(3, 2, active.clone()).await?; + execute(3, 3, active.clone()).await?; + + // Should run exactly three workers + assert_eq!(3, active.lock().unwrap().len()); + + Ok(()) + } + + /// Use blocklist, should not send to blocked instances + #[tokio::test] + #[serial] + async fn test_send_manager_blocked() -> LemmyResult<()> { + let mut data = TestData::init(1, 1).await?; + + let domain = data.instances[0].domain.clone(); + FederationBlockList::replace(&mut data.context.pool(), Some(vec![domain])).await?; + data.run().await?; + let workers = &data.send_manager.workers; + assert_eq!(2, workers.len()); + assert!(workers.contains_key(&data.instances[1].id)); + assert!(workers.contains_key(&data.instances[2].id)); + + data.cleanup().await?; + Ok(()) + } + + /// Use allowlist, should only send to allowed instance + #[tokio::test] + #[serial] + async fn test_send_manager_allowed() -> LemmyResult<()> { + let mut data = TestData::init(1, 1).await?; + + let domain = data.instances[0].domain.clone(); + FederationAllowList::replace(&mut data.context.pool(), Some(vec![domain])).await?; + data.run().await?; + let workers = &data.send_manager.workers; + assert_eq!(1, workers.len()); + assert!(workers.contains_key(&data.instances[0].id)); + + data.cleanup().await?; + Ok(()) + } + + /// Mark instance as dead, there should be no worker created for it + #[tokio::test] + #[serial] + async fn test_send_manager_dead() -> LemmyResult<()> { + let mut data = TestData::init(1, 1).await?; + + let instance = &data.instances[0]; + let form = InstanceForm::builder() + .domain(instance.domain.clone()) + .updated(DateTime::from_timestamp(0, 0)) + .build(); + Instance::update(&mut data.context.pool(), instance.id, form).await?; + + data.run().await?; + let workers = &data.send_manager.workers; + assert_eq!(2, workers.len()); + assert!(workers.contains_key(&data.instances[1].id)); + assert!(workers.contains_key(&data.instances[2].id)); + + data.cleanup().await?; + Ok(()) + } } diff --git a/crates/federate/src/stats.rs b/crates/federate/src/stats.rs new file mode 100644 index 000000000..bb6510263 --- /dev/null +++ b/crates/federate/src/stats.rs @@ -0,0 +1,97 @@ +use crate::util::get_latest_activity_id; +use chrono::Local; +use diesel::result::Error::NotFound; +use lemmy_api_common::federate_retry_sleep_duration; +use lemmy_db_schema::{ + newtypes::InstanceId, + source::{federation_queue_state::FederationQueueState, instance::Instance}, + utils::{ActualDbPool, DbPool}, +}; +use lemmy_utils::{error::LemmyResult, CACHE_DURATION_FEDERATION}; +use moka::future::Cache; +use once_cell::sync::Lazy; +use std::{collections::HashMap, time::Duration}; +use tokio::{sync::mpsc::UnboundedReceiver, time::interval}; +use tracing::{debug, info, warn}; + +/// every 60s, print the state for every instance. exits if the receiver is done (all senders +/// dropped) +pub(crate) async fn receive_print_stats( + pool: ActualDbPool, + mut receiver: UnboundedReceiver<(InstanceId, FederationQueueState)>, +) { + let pool = &mut DbPool::Pool(&pool); + let mut printerval = interval(Duration::from_secs(60)); + let mut stats = HashMap::new(); + loop { + tokio::select! { + ele = receiver.recv() => { + match ele { + // update stats for instance + Some((instance_id, ele)) => {stats.insert(instance_id, ele);}, + // receiver closed, print stats and exit + None => { + print_stats(pool, &stats).await; + return; + } + } + }, + _ = printerval.tick() => { + print_stats(pool, &stats).await; + } + } + } +} + +async fn print_stats(pool: &mut DbPool<'_>, stats: &HashMap) { + let res = print_stats_with_error(pool, stats).await; + if let Err(e) = res { + warn!("Failed to print stats: {e}"); + } +} + +async fn print_stats_with_error( + pool: &mut DbPool<'_>, + stats: &HashMap, +) -> LemmyResult<()> { + static INSTANCE_CACHE: Lazy>> = Lazy::new(|| { + Cache::builder() + .max_capacity(1) + .time_to_live(CACHE_DURATION_FEDERATION) + .build() + }); + let instances = INSTANCE_CACHE + .try_get_with((), async { Instance::read_all(pool).await }) + .await?; + + let last_id = get_latest_activity_id(pool).await?; + + // it's expected that the values are a bit out of date, everything < SAVE_STATE_EVERY should be + // considered up to date + info!("Federation state as of {}:", Local::now().to_rfc3339()); + // todo: more stats (act/sec, avg http req duration) + let mut ok_count = 0; + let mut behind_count = 0; + for (instance_id, stat) in stats { + let domain = &instances + .iter() + .find(|i| &i.id == instance_id) + .ok_or(NotFound)? + .domain; + let behind = last_id.0 - stat.last_successful_id.map(|e| e.0).unwrap_or(0); + if stat.fail_count > 0 { + info!( + "{domain}: Warning. {behind} behind, {} consecutive fails, current retry delay {:.2?}", + stat.fail_count, + federate_retry_sleep_duration(stat.fail_count) + ); + } else if behind > 0 { + debug!("{}: Ok. {} activities behind", domain, behind); + behind_count += 1; + } else { + ok_count += 1; + } + } + info!("{ok_count} others up to date. {behind_count} instances behind."); + Ok(()) +} diff --git a/crates/federate/src/util.rs b/crates/federate/src/util.rs index a64d49f03..02a90dee9 100644 --- a/crates/federate/src/util.rs +++ b/crates/federate/src/util.rs @@ -17,6 +17,7 @@ use lemmy_db_schema::{ traits::ApubActor, utils::{get_conn, DbPool}, }; +use lemmy_utils::error::LemmyResult; use moka::future::Cache; use once_cell::sync::Lazy; use reqwest::Url; @@ -24,6 +25,7 @@ use serde_json::Value; use std::{fmt::Debug, future::Future, pin::Pin, sync::Arc, time::Duration}; use tokio::{task::JoinHandle, time::sleep}; use tokio_util::sync::CancellationToken; +use tracing::error; /// Decrease the delays of the federation queue. /// Should only be used for federation tests since it significantly increases CPU and DB load of the @@ -59,36 +61,29 @@ impl CancellableTask { /// spawn a task but with graceful shutdown pub fn spawn( timeout: Duration, - task: impl Fn(CancellationToken) -> F + Send + 'static, + task: impl FnOnce(CancellationToken) -> F + Send + 'static, ) -> CancellableTask where - F: Future + Send + 'static, + F: Future> + Send + 'static, + R: Send + 'static, { let stop = CancellationToken::new(); let stop2 = stop.clone(); - let task: JoinHandle<()> = tokio::spawn(async move { - loop { - let res = task(stop2.clone()).await; - if stop2.is_cancelled() { - return; - } else { - tracing::warn!("task exited, restarting: {res:?}"); - } - } - }); + let task: JoinHandle> = tokio::spawn(task(stop2)); let abort = task.abort_handle(); CancellableTask { f: Box::pin(async move { stop.cancel(); tokio::select! { r = task => { - r.context("could not join")?; - Ok(()) + if let Err(ref e) = r? { + error!("CancellableTask threw error: {e}"); + } + Ok(()) }, _ = sleep(timeout) => { abort.abort(); - tracing::warn!("Graceful shutdown timed out, aborting task"); - Err(anyhow!("task aborted due to timeout")) + Err(anyhow!("CancellableTask aborted due to shutdown timeout")) } } }), diff --git a/crates/federate/src/worker.rs b/crates/federate/src/worker.rs index c11c019d6..f13a02678 100644 --- a/crates/federate/src/worker.rs +++ b/crates/federate/src/worker.rs @@ -22,7 +22,7 @@ use lemmy_db_schema::{ instance::{Instance, InstanceForm}, site::Site, }, - utils::{naive_now, DbPool}, + utils::naive_now, }; use lemmy_db_views_actor::structs::CommunityFollowerView; use once_cell::sync::Lazy; @@ -75,7 +75,7 @@ pub(crate) struct InstanceWorker { followed_communities: HashMap>, stop: CancellationToken, context: Data, - stats_sender: UnboundedSender<(String, FederationQueueState)>, + stats_sender: UnboundedSender<(InstanceId, FederationQueueState)>, last_full_communities_fetch: DateTime, last_incremental_communities_fetch: DateTime, state: FederationQueueState, @@ -86,12 +86,11 @@ impl InstanceWorker { pub(crate) async fn init_and_loop( instance: Instance, context: Data, - pool: &mut DbPool<'_>, /* in theory there's a ref to the pool in context, but i couldn't get - * that to work wrt lifetimes */ stop: CancellationToken, - stats_sender: UnboundedSender<(String, FederationQueueState)>, + stats_sender: UnboundedSender<(InstanceId, FederationQueueState)>, ) -> Result<(), anyhow::Error> { - let state = FederationQueueState::load(pool, instance.id).await?; + let mut pool = context.pool(); + let state = FederationQueueState::load(&mut pool, instance.id).await?; let mut worker = InstanceWorker { instance, site_loaded: false, @@ -105,32 +104,29 @@ impl InstanceWorker { state, last_state_insert: Utc.timestamp_nanos(0), }; - worker.loop_until_stopped(pool).await + worker.loop_until_stopped().await } /// loop fetch new activities from db and send them to the inboxes of the given instances /// this worker only returns if (a) there is an internal error or (b) the cancellation token is /// cancelled (graceful exit) - pub(crate) async fn loop_until_stopped( - &mut self, - pool: &mut DbPool<'_>, - ) -> Result<(), anyhow::Error> { + pub(crate) async fn loop_until_stopped(&mut self) -> Result<(), anyhow::Error> { debug!("Starting federation worker for {}", self.instance.domain); let save_state_every = chrono::Duration::from_std(SAVE_STATE_EVERY_TIME).expect("not negative"); - self.update_communities(pool).await?; + self.update_communities().await?; self.initial_fail_sleep().await?; while !self.stop.is_cancelled() { - self.loop_batch(pool).await?; + self.loop_batch().await?; if self.stop.is_cancelled() { break; } if (Utc::now() - self.last_state_insert) > save_state_every { - self.save_and_send_state(pool).await?; + self.save_and_send_state().await?; } - self.update_communities(pool).await?; + self.update_communities().await?; } // final update of state in db - self.save_and_send_state(pool).await?; + self.save_and_send_state().await?; Ok(()) } @@ -155,8 +151,8 @@ impl InstanceWorker { Ok(()) } /// send out a batch of CHECK_SAVE_STATE_EVERY_IT activities - async fn loop_batch(&mut self, pool: &mut DbPool<'_>) -> Result<()> { - let latest_id = get_latest_activity_id(pool).await?; + async fn loop_batch(&mut self) -> Result<()> { + let latest_id = get_latest_activity_id(&mut self.context.pool()).await?; let mut id = if let Some(id) = self.state.last_successful_id { id } else { @@ -166,7 +162,7 @@ impl InstanceWorker { // skip all past activities: self.state.last_successful_id = Some(latest_id); // save here to ensure it's not read as 0 again later if no activities have happened - self.save_and_send_state(pool).await?; + self.save_and_send_state().await?; latest_id }; if id >= latest_id { @@ -184,7 +180,7 @@ impl InstanceWorker { { id = ActivityId(id.0 + 1); processed_activities += 1; - let Some(ele) = get_activity_cached(pool, id) + let Some(ele) = get_activity_cached(&mut self.context.pool(), id) .await .context("failed reading activity from db")? else { @@ -192,7 +188,7 @@ impl InstanceWorker { self.state.last_successful_id = Some(id); continue; }; - if let Err(e) = self.send_retry_loop(pool, &ele.0, &ele.1).await { + if let Err(e) = self.send_retry_loop(&ele.0, &ele.1).await { warn!( "sending {} errored internally, skipping activity: {:?}", ele.0.ap_id, e @@ -213,12 +209,11 @@ impl InstanceWorker { // and will return an error if an internal error occurred (send errors cause an infinite loop) async fn send_retry_loop( &mut self, - pool: &mut DbPool<'_>, activity: &SentActivity, object: &SharedInboxActivities, ) -> Result<()> { let inbox_urls = self - .get_inbox_urls(pool, activity) + .get_inbox_urls(activity) .await .context("failed figuring out inbox urls")?; if inbox_urls.is_empty() { @@ -230,7 +225,7 @@ impl InstanceWorker { let Some(actor_apub_id) = &activity.actor_apub_id else { return Ok(()); // activity was inserted before persistent queue was activated }; - let actor = get_actor_cached(pool, activity.actor_type, actor_apub_id) + let actor = get_actor_cached(&mut self.context.pool(), activity.actor_type, actor_apub_id) .await .context("failed getting actor instance (was it marked deleted / removed?)")?; @@ -249,7 +244,7 @@ impl InstanceWorker { "{}: retrying {:?} attempt {} with delay {retry_delay:.2?}. ({e})", self.instance.domain, activity.id, self.state.fail_count ); - self.save_and_send_state(pool).await?; + self.save_and_send_state().await?; tokio::select! { () = sleep(retry_delay) => {}, () = self.stop.cancelled() => { @@ -268,7 +263,7 @@ impl InstanceWorker { .domain(self.instance.domain.clone()) .updated(Some(naive_now())) .build(); - Instance::update(pool, self.instance.id, form).await?; + Instance::update(&mut self.context.pool(), self.instance.id, form).await?; } } Ok(()) @@ -278,16 +273,12 @@ impl InstanceWorker { /// most often this will return 0 values (if instance doesn't care about the activity) /// or 1 value (the shared inbox) /// > 1 values only happens for non-lemmy software - async fn get_inbox_urls( - &mut self, - pool: &mut DbPool<'_>, - activity: &SentActivity, - ) -> Result> { + async fn get_inbox_urls(&mut self, activity: &SentActivity) -> Result> { let mut inbox_urls: HashSet = HashSet::new(); if activity.send_all_instances { if !self.site_loaded { - self.site = Site::read_from_instance_id(pool, self.instance.id).await?; + self.site = Site::read_from_instance_id(&mut self.context.pool(), self.instance.id).await?; self.site_loaded = true; } if let Some(site) = &self.site { @@ -312,22 +303,18 @@ impl InstanceWorker { Ok(inbox_urls) } - async fn update_communities(&mut self, pool: &mut DbPool<'_>) -> Result<()> { + async fn update_communities(&mut self) -> Result<()> { if (Utc::now() - self.last_full_communities_fetch) > *FOLLOW_REMOVALS_RECHECK_DELAY { // process removals every hour (self.followed_communities, self.last_full_communities_fetch) = self - .get_communities(pool, self.instance.id, Utc.timestamp_nanos(0)) + .get_communities(self.instance.id, Utc.timestamp_nanos(0)) .await?; self.last_incremental_communities_fetch = self.last_full_communities_fetch; } if (Utc::now() - self.last_incremental_communities_fetch) > *FOLLOW_ADDITIONS_RECHECK_DELAY { // process additions every minute let (news, time) = self - .get_communities( - pool, - self.instance.id, - self.last_incremental_communities_fetch, - ) + .get_communities(self.instance.id, self.last_incremental_communities_fetch) .await?; self.followed_communities.extend(news); self.last_incremental_communities_fetch = time; @@ -339,7 +326,6 @@ impl InstanceWorker { /// them async fn get_communities( &mut self, - pool: &mut DbPool<'_>, instance_id: InstanceId, last_fetch: DateTime, ) -> Result<(HashMap>, DateTime)> { @@ -347,22 +333,26 @@ impl InstanceWorker { Utc::now() - chrono::TimeDelta::try_seconds(10).expect("TimeDelta out of bounds"); // update to time before fetch to ensure overlap. subtract 10s to ensure overlap even if // published date is not exact Ok(( - CommunityFollowerView::get_instance_followed_community_inboxes(pool, instance_id, last_fetch) - .await? - .into_iter() - .fold(HashMap::new(), |mut map, (c, u)| { - map.entry(c).or_default().insert(u.into()); - map - }), + CommunityFollowerView::get_instance_followed_community_inboxes( + &mut self.context.pool(), + instance_id, + last_fetch, + ) + .await? + .into_iter() + .fold(HashMap::new(), |mut map, (c, u)| { + map.entry(c).or_default().insert(u.into()); + map + }), new_last_fetch, )) } - async fn save_and_send_state(&mut self, pool: &mut DbPool<'_>) -> Result<()> { + async fn save_and_send_state(&mut self) -> Result<()> { self.last_state_insert = Utc::now(); - FederationQueueState::upsert(pool, &self.state).await?; + FederationQueueState::upsert(&mut self.context.pool(), &self.state).await?; self .stats_sender - .send((self.instance.domain.clone(), self.state.clone()))?; + .send((self.instance.id, self.state.clone()))?; Ok(()) } } diff --git a/scripts/test.sh b/scripts/test.sh index 3e0581fc7..9bb6acaa8 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -20,11 +20,10 @@ then cargo test -p $PACKAGE --all-features --no-fail-fast $TEST else cargo test --workspace --no-fail-fast + # Testing lemmy utils all features in particular (for ts-rs bindings) + cargo test -p lemmy_utils --all-features --no-fail-fast fi -# Testing lemmy utils all features in particular (for ts-rs bindings) -cargo test -p lemmy_utils --all-features --no-fail-fast - # Add this to do printlns: -- --nocapture pg_ctl stop --silent diff --git a/src/lib.rs b/src/lib.rs index 38e9addd7..c2b5e57c2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,7 +41,7 @@ use lemmy_apub::{ FEDERATION_HTTP_FETCH_LIMIT, }; use lemmy_db_schema::{source::secret::Secret, utils::build_db_pool}; -use lemmy_federate::{start_stop_federation_workers_cancellable, Opts}; +use lemmy_federate::{Opts, SendManager}; use lemmy_routes::{feeds, images, nodeinfo, webfinger}; use lemmy_utils::{ error::LemmyResult, @@ -210,14 +210,14 @@ pub async fn start_lemmy_server(args: CmdArgs) -> LemmyResult<()> { None }; let federate = (!args.disable_activity_sending).then(|| { - start_stop_federation_workers_cancellable( + let task = SendManager::new( Opts { process_index: args.federate_process_index, process_count: args.federate_process_count, }, - pool.clone(), - federation_config.clone(), - ) + federation_config, + ); + task.run() }); let mut interrupt = tokio::signal::unix::signal(SignalKind::interrupt())?; let mut terminate = tokio::signal::unix::signal(SignalKind::terminate())?;