mirror of
https://github.com/LemmyNet/activitypub-federation-rust
synced 2024-11-10 06:04:19 +00:00
Remove actix-rt
and replace with tokio tasks (#42)
* Remove `actix-rt` and replace with tokio tasks * Include activity queue test * Use older `Arc` method * Refactor to not re-process PEM data on each request * Add retry queue and spawn tokio tasks directly * Fix doc error * Remove semaphore and use join set for backpressure * Fix debug issue with multiple mailboxes
This commit is contained in:
parent
6ac6e2d90e
commit
c356265cf4
20 changed files with 601 additions and 153 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1,3 +1,5 @@
|
|||
/target
|
||||
/.idea
|
||||
/Cargo.lock
|
||||
perf.data*
|
||||
flamegraph.svg
|
25
Cargo.toml
25
Cargo.toml
|
@ -23,26 +23,37 @@ openssl = "0.10.54"
|
|||
once_cell = "1.18.0"
|
||||
http = "0.2.9"
|
||||
sha2 = "0.10.6"
|
||||
background-jobs = "0.13.0"
|
||||
thiserror = "1.0.40"
|
||||
derive_builder = "0.12.0"
|
||||
itertools = "0.10.5"
|
||||
dyn-clone = "1.0.11"
|
||||
enum_delegate = "0.2.0"
|
||||
httpdate = "1.0.2"
|
||||
http-signature-normalization-reqwest = { version = "0.8.0", default-features = false, features = ["sha-2", "middleware"] }
|
||||
http-signature-normalization-reqwest = { version = "0.8.0", default-features = false, features = [
|
||||
"sha-2",
|
||||
"middleware",
|
||||
] }
|
||||
http-signature-normalization = "0.7.0"
|
||||
bytes = "1.4.0"
|
||||
futures-core = { version = "0.3.28", default-features = false }
|
||||
pin-project-lite = "0.2.9"
|
||||
activitystreams-kinds = "0.3.0"
|
||||
regex = { version = "1.8.4", default-features = false, features = ["std"] }
|
||||
tokio = { version = "1.21.2", features = [
|
||||
"sync",
|
||||
"rt",
|
||||
"rt-multi-thread",
|
||||
"time",
|
||||
] }
|
||||
|
||||
# Actix-web
|
||||
actix-web = { version = "4.3.1", default-features = false, optional = true }
|
||||
|
||||
# Axum
|
||||
axum = { version = "0.6.18", features = ["json", "headers"], default-features = false, optional = true }
|
||||
axum = { version = "0.6.18", features = [
|
||||
"json",
|
||||
"headers",
|
||||
], default-features = false, optional = true }
|
||||
tower = { version = "0.4.13", optional = true }
|
||||
hyper = { version = "0.14", optional = true }
|
||||
displaydoc = "0.2.4"
|
||||
|
@ -56,9 +67,13 @@ axum = ["dep:axum", "dep:tower", "dep:hyper"]
|
|||
rand = "0.8.5"
|
||||
env_logger = "0.10.0"
|
||||
tower-http = { version = "0.4.0", features = ["map-request-body", "util"] }
|
||||
axum = { version = "0.6.18", features = ["http1", "tokio", "query"], default-features = false }
|
||||
axum = { version = "0.6.18", features = [
|
||||
"http1",
|
||||
"tokio",
|
||||
"query",
|
||||
], default-features = false }
|
||||
axum-macros = "0.3.7"
|
||||
actix-rt = "2.8.0"
|
||||
tokio = { version = "1.21.2", features = ["full"] }
|
||||
|
||||
[profile.dev]
|
||||
strip = "symbols"
|
||||
|
|
|
@ -5,12 +5,13 @@ Next we need to do some configuration. Most importantly we need to specify the d
|
|||
```
|
||||
# use activitypub_federation::config::FederationConfig;
|
||||
# let db_connection = ();
|
||||
# let _ = actix_rt::System::new();
|
||||
# tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
let config = FederationConfig::builder()
|
||||
.domain("example.com")
|
||||
.app_data(db_connection)
|
||||
.build()?;
|
||||
.build().await?;
|
||||
# Ok::<(), anyhow::Error>(())
|
||||
# }).unwrap()
|
||||
```
|
||||
|
||||
`debug` is necessary to test federation with http and localhost URLs, but it should never be used in production. The `worker_count` value can be adjusted depending on the instance size. A lower value saves resources on a small instance, while a higher value is necessary on larger instances to keep up with send jobs. `url_verifier` can be used to implement a domain blacklist.
|
|
@ -22,12 +22,12 @@ The next step is to allow other servers to fetch our actors and objects. For thi
|
|||
# use http::HeaderMap;
|
||||
# async fn generate_user_html(_: String, _: Data<DbConnection>) -> axum::response::Response { todo!() }
|
||||
|
||||
#[actix_rt::main]
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Error> {
|
||||
let data = FederationConfig::builder()
|
||||
.domain("example.com")
|
||||
.app_data(DbConnection)
|
||||
.build()?;
|
||||
.build().await?;
|
||||
|
||||
let app = axum::Router::new()
|
||||
.route("/user/:name", get(http_get_user))
|
||||
|
|
|
@ -7,18 +7,17 @@ After setting up our structs, implementing traits and initializing configuration
|
|||
# use activitypub_federation::traits::tests::DbUser;
|
||||
# use activitypub_federation::config::FederationConfig;
|
||||
# let db_connection = activitypub_federation::traits::tests::DbConnection;
|
||||
# let _ = actix_rt::System::new();
|
||||
# actix_rt::Runtime::new().unwrap().block_on(async {
|
||||
# tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
let config = FederationConfig::builder()
|
||||
.domain("example.com")
|
||||
.app_data(db_connection)
|
||||
.build()?;
|
||||
.build().await?;
|
||||
let user_id = ObjectId::<DbUser>::parse("https://mastodon.social/@LemmyDev")?;
|
||||
let data = config.to_request_data();
|
||||
let user = user_id.dereference(&data).await;
|
||||
assert!(user.is_ok());
|
||||
# Ok::<(), anyhow::Error>(())
|
||||
}).unwrap()
|
||||
# }).unwrap()
|
||||
```
|
||||
|
||||
`dereference` retrieves the object JSON at the given URL, and uses serde to convert it to `Person`. It then calls your method `Object::from_json` which inserts it in the database and returns a `DbUser` struct. `request_data` contains the federation config as well as a counter of outgoing HTTP requests. If this counter exceeds the configured maximum, further requests are aborted in order to avoid recursive fetching which could allow for a denial of service attack.
|
||||
|
@ -32,9 +31,8 @@ We can similarly dereference a user over webfinger with the following method. It
|
|||
# use activitypub_federation::fetch::webfinger::webfinger_resolve_actor;
|
||||
# use activitypub_federation::traits::tests::DbUser;
|
||||
# let db_connection = DbConnection;
|
||||
# let _ = actix_rt::System::new();
|
||||
# actix_rt::Runtime::new().unwrap().block_on(async {
|
||||
# let config = FederationConfig::builder().domain("example.com").app_data(db_connection).build()?;
|
||||
# tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
# let config = FederationConfig::builder().domain("example.com").app_data(db_connection).build().await?;
|
||||
# let data = config.to_request_data();
|
||||
let user: DbUser = webfinger_resolve_actor("nutomic@lemmy.ml", &data).await?;
|
||||
# Ok::<(), anyhow::Error>(())
|
||||
|
|
|
@ -9,13 +9,12 @@ To send an activity we need to initialize our previously defined struct, and pic
|
|||
# use activitypub_federation::traits::Actor;
|
||||
# use activitypub_federation::fetch::object_id::ObjectId;
|
||||
# use activitypub_federation::traits::tests::{DB_USER, DbConnection, Follow};
|
||||
# let _ = actix_rt::System::new();
|
||||
# actix_rt::Runtime::new().unwrap().block_on(async {
|
||||
# tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
# let db_connection = DbConnection;
|
||||
# let config = FederationConfig::builder()
|
||||
# .domain("example.com")
|
||||
# .app_data(db_connection)
|
||||
# .build()?;
|
||||
# .build().await?;
|
||||
# let data = config.to_request_data();
|
||||
# let sender = DB_USER.clone();
|
||||
# let recipient = DB_USER.clone();
|
||||
|
|
|
@ -61,9 +61,9 @@ impl Object for SearchableDbObjects {
|
|||
}
|
||||
}
|
||||
|
||||
#[actix_rt::main]
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), anyhow::Error> {
|
||||
# let config = FederationConfig::builder().domain("example.com").app_data(DbConnection).build().unwrap();
|
||||
# let config = FederationConfig::builder().domain("example.com").app_data(DbConnection).build().await.unwrap();
|
||||
# let data = config.to_request_data();
|
||||
let query = "https://example.com/id/413";
|
||||
let query_result = ObjectId::<SearchableDbObjects>::parse(query)?
|
||||
|
|
|
@ -28,7 +28,7 @@ const DOMAIN: &str = "example.com";
|
|||
const LOCAL_USER_NAME: &str = "alison";
|
||||
const BIND_ADDRESS: &str = "localhost:8003";
|
||||
|
||||
#[actix_rt::main]
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Error> {
|
||||
env_logger::builder()
|
||||
.filter_level(LevelFilter::Warn)
|
||||
|
@ -47,7 +47,8 @@ async fn main() -> Result<(), Error> {
|
|||
let config = FederationConfig::builder()
|
||||
.domain(DOMAIN)
|
||||
.app_data(database)
|
||||
.build()?;
|
||||
.build()
|
||||
.await?;
|
||||
|
||||
info!("Listen with HTTP server on {BIND_ADDRESS}");
|
||||
let config = config.clone();
|
||||
|
|
|
@ -30,7 +30,7 @@ pub fn listen(config: &FederationConfig<DatabaseHandle>) -> Result<(), Error> {
|
|||
})
|
||||
.bind(hostname)?
|
||||
.run();
|
||||
actix_rt::spawn(server);
|
||||
tokio::spawn(server);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ pub fn listen(config: &FederationConfig<DatabaseHandle>) -> Result<(), Error> {
|
|||
.expect("Failed to lookup domain name");
|
||||
let server = axum::Server::bind(&addr).serve(app.into_make_service());
|
||||
|
||||
actix_rt::spawn(server);
|
||||
tokio::spawn(server);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ use std::{
|
|||
};
|
||||
use url::Url;
|
||||
|
||||
pub fn new_instance(
|
||||
pub async fn new_instance(
|
||||
hostname: &str,
|
||||
name: String,
|
||||
) -> Result<FederationConfig<DatabaseHandle>, Error> {
|
||||
|
@ -29,7 +29,8 @@ pub fn new_instance(
|
|||
.signed_fetch_actor(&system_user)
|
||||
.app_data(database)
|
||||
.debug(true)
|
||||
.build()?;
|
||||
.build()
|
||||
.await?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ mod instance;
|
|||
mod objects;
|
||||
mod utils;
|
||||
|
||||
#[actix_rt::main]
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Error> {
|
||||
env_logger::builder()
|
||||
.filter_level(LevelFilter::Warn)
|
||||
|
@ -32,8 +32,8 @@ async fn main() -> Result<(), Error> {
|
|||
.map(|arg| Webserver::from_str(&arg).unwrap())
|
||||
.unwrap_or(Webserver::Axum);
|
||||
|
||||
let alpha = new_instance("localhost:8001", "alpha".to_string())?;
|
||||
let beta = new_instance("localhost:8002", "beta".to_string())?;
|
||||
let alpha = new_instance("localhost:8001", "alpha".to_string()).await?;
|
||||
let beta = new_instance("localhost:8002", "beta".to_string()).await?;
|
||||
listen(&alpha, &webserver)?;
|
||||
listen(&beta, &webserver)?;
|
||||
info!("Local instances started");
|
||||
|
|
|
@ -11,25 +11,28 @@ use crate::{
|
|||
FEDERATION_CONTENT_TYPE,
|
||||
};
|
||||
use anyhow::anyhow;
|
||||
use background_jobs::{
|
||||
memory_storage::{ActixTimer, Storage},
|
||||
ActixJob,
|
||||
Backoff,
|
||||
Manager,
|
||||
MaxRetries,
|
||||
WorkerConfig,
|
||||
};
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures_core::Future;
|
||||
use http::{header::HeaderName, HeaderMap, HeaderValue};
|
||||
use httpdate::fmt_http_date;
|
||||
use itertools::Itertools;
|
||||
use openssl::pkey::{PKey, Private};
|
||||
use reqwest::Request;
|
||||
use reqwest_middleware::ClientWithMiddleware;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::Serialize;
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
fmt::{Debug, Display},
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
use tokio::{
|
||||
sync::mpsc::{unbounded_channel, UnboundedSender},
|
||||
task::{JoinHandle, JoinSet},
|
||||
};
|
||||
use tracing::{debug, info, warn};
|
||||
use url::Url;
|
||||
|
||||
|
@ -56,10 +59,19 @@ where
|
|||
let config = &data.config;
|
||||
let actor_id = activity.actor();
|
||||
let activity_id = activity.id();
|
||||
let activity_serialized = serde_json::to_string_pretty(&activity)?;
|
||||
let private_key = actor
|
||||
let activity_serialized: Bytes = serde_json::to_vec(&activity)?.into();
|
||||
let private_key_pem = actor
|
||||
.private_key_pem()
|
||||
.expect("Actor for sending activity has private key");
|
||||
.ok_or_else(|| anyhow!("Actor {actor_id} does not contain a private key for signing"))?;
|
||||
|
||||
// This is a mostly expensive blocking call, we don't want to tie up other tasks while this is happening
|
||||
let private_key = tokio::task::spawn_blocking(move || {
|
||||
PKey::private_key_from_pem(private_key_pem.as_bytes())
|
||||
.map_err(|err| anyhow!("Could not create private key from PEM data:{err}"))
|
||||
})
|
||||
.await
|
||||
.map_err(|err| anyhow!("Error joining:{err}"))??;
|
||||
|
||||
let inboxes: Vec<Url> = inboxes
|
||||
.into_iter()
|
||||
.unique()
|
||||
|
@ -84,25 +96,32 @@ where
|
|||
private_key: private_key.clone(),
|
||||
http_signature_compat: config.http_signature_compat,
|
||||
};
|
||||
|
||||
// Don't use the activity queue if this is in debug mode, send and wait directly
|
||||
if config.debug {
|
||||
let res = do_send(message, &config.client, config.request_timeout).await;
|
||||
// Don't fail on error, as we intentionally do some invalid actions in tests, to verify that
|
||||
// they are rejected on the receiving side. These errors shouldn't bubble up to make the API
|
||||
// call fail. This matches the behaviour in production.
|
||||
if let Err(e) = res {
|
||||
warn!("{}", e);
|
||||
if let Err(err) = sign_and_send(
|
||||
&message,
|
||||
&config.client,
|
||||
config.request_timeout,
|
||||
Default::default(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!("{err}");
|
||||
}
|
||||
} else {
|
||||
activity_queue.queue(message).await?;
|
||||
let stats = activity_queue.get_stats().await?;
|
||||
let stats = activity_queue.get_stats();
|
||||
let running = stats.running.load(Ordering::Relaxed);
|
||||
let stats_fmt = format!(
|
||||
"Activity queue stats: pending: {}, running: {}, dead (this hour): {}, complete (this hour): {}",
|
||||
stats.pending,
|
||||
stats.running,
|
||||
stats.dead.this_hour(),
|
||||
stats.complete.this_hour()
|
||||
"Activity queue stats: pending: {}, running: {}, retries: {}, dead: {}, complete: {}",
|
||||
stats.pending.load(Ordering::Relaxed),
|
||||
running,
|
||||
stats.retries.load(Ordering::Relaxed),
|
||||
stats.dead_last_hour.load(Ordering::Relaxed),
|
||||
stats.completed_last_hour.load(Ordering::Relaxed),
|
||||
);
|
||||
if stats.running as u64 == config.worker_count {
|
||||
if running == config.worker_count {
|
||||
warn!("Reached max number of send activity workers ({}). Consider increasing worker count to avoid federation delays", config.worker_count);
|
||||
warn!(stats_fmt);
|
||||
} else {
|
||||
|
@ -114,56 +133,66 @@ where
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
#[derive(Clone, Debug)]
|
||||
struct SendActivityTask {
|
||||
actor_id: Url,
|
||||
activity_id: Url,
|
||||
activity: String,
|
||||
activity: Bytes,
|
||||
inbox: Url,
|
||||
private_key: String,
|
||||
private_key: PKey<Private>,
|
||||
http_signature_compat: bool,
|
||||
}
|
||||
|
||||
impl ActixJob for SendActivityTask {
|
||||
type State = QueueState;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<(), anyhow::Error>>>>;
|
||||
const NAME: &'static str = "SendActivityTask";
|
||||
|
||||
const MAX_RETRIES: MaxRetries = MaxRetries::Count(3);
|
||||
/// This gives the following retry intervals:
|
||||
/// - 60s (one minute, for service restart)
|
||||
/// - 60min (one hour, for instance maintenance)
|
||||
/// - 60h (2.5 days, for major incident with rebuild from backup)
|
||||
const BACKOFF: Backoff = Backoff::Exponential(60);
|
||||
|
||||
fn run(self, state: Self::State) -> Self::Future {
|
||||
Box::pin(async move { do_send(self, &state.client, state.timeout).await })
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_send(
|
||||
task: SendActivityTask,
|
||||
async fn sign_and_send(
|
||||
task: &SendActivityTask,
|
||||
client: &ClientWithMiddleware,
|
||||
timeout: Duration,
|
||||
retry_strategy: RetryStrategy,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
debug!("Sending {} to {}", task.activity_id, task.inbox);
|
||||
debug!(
|
||||
"Sending {} to {}, contents:\n {}",
|
||||
task.activity_id,
|
||||
task.inbox,
|
||||
serde_json::from_slice::<serde_json::Value>(&task.activity)?
|
||||
);
|
||||
let request_builder = client
|
||||
.post(task.inbox.to_string())
|
||||
.timeout(timeout)
|
||||
.headers(generate_request_headers(&task.inbox));
|
||||
let request = sign_request(
|
||||
request_builder,
|
||||
task.actor_id,
|
||||
task.activity,
|
||||
task.private_key,
|
||||
&task.actor_id,
|
||||
task.activity.clone(),
|
||||
task.private_key.clone(),
|
||||
task.http_signature_compat,
|
||||
)
|
||||
.await?;
|
||||
|
||||
retry(
|
||||
|| {
|
||||
send(
|
||||
task,
|
||||
client,
|
||||
request
|
||||
.try_clone()
|
||||
.expect("The body of the request is not cloneable"),
|
||||
)
|
||||
},
|
||||
retry_strategy,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send(
|
||||
task: &SendActivityTask,
|
||||
client: &ClientWithMiddleware,
|
||||
request: Request,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let response = client.execute(request).await;
|
||||
|
||||
match response {
|
||||
Ok(o) if o.status().is_success() => {
|
||||
info!(
|
||||
debug!(
|
||||
"Activity {} delivered successfully to {}",
|
||||
task.activity_id, task.inbox
|
||||
);
|
||||
|
@ -171,7 +200,7 @@ async fn do_send(
|
|||
}
|
||||
Ok(o) if o.status().is_client_error() => {
|
||||
let text = o.text_limited().await.map_err(Error::other)?;
|
||||
info!(
|
||||
debug!(
|
||||
"Activity {} was rejected by {}, aborting: {}",
|
||||
task.activity_id, task.inbox, text,
|
||||
);
|
||||
|
@ -189,7 +218,7 @@ async fn do_send(
|
|||
))
|
||||
}
|
||||
Err(e) => {
|
||||
info!(
|
||||
debug!(
|
||||
"Unable to connect to {}, aborting task {}: {}",
|
||||
task.inbox, task.activity_id, e
|
||||
);
|
||||
|
@ -220,27 +249,401 @@ pub(crate) fn generate_request_headers(inbox_url: &Url) -> HeaderMap {
|
|||
headers
|
||||
}
|
||||
|
||||
pub(crate) fn create_activity_queue(
|
||||
client: ClientWithMiddleware,
|
||||
worker_count: u64,
|
||||
request_timeout: Duration,
|
||||
debug: bool,
|
||||
) -> Manager {
|
||||
// queue is not used in debug mod, so dont create any workers to avoid log spam
|
||||
let worker_count = if debug { 0 } else { worker_count };
|
||||
|
||||
// Configure and start our workers
|
||||
WorkerConfig::new_managed(Storage::new(ActixTimer), move |_| QueueState {
|
||||
client: client.clone(),
|
||||
timeout: request_timeout,
|
||||
})
|
||||
.register::<SendActivityTask>()
|
||||
.set_worker_count("default", worker_count)
|
||||
.start()
|
||||
/// A simple activity queue which spawns tokio workers to send out requests
|
||||
/// When creating a queue, it will spawn a task per worker thread
|
||||
/// Uses an unbounded mpsc queue for communication (i.e, all messages are in memory)
|
||||
pub(crate) struct ActivityQueue {
|
||||
// Stats shared between the queue and workers
|
||||
stats: Arc<Stats>,
|
||||
sender: UnboundedSender<SendActivityTask>,
|
||||
sender_task: JoinHandle<()>,
|
||||
retry_sender_task: JoinHandle<()>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct QueueState {
|
||||
/// Simple stat counter to show where we're up to with sending messages
|
||||
/// This is a lock-free way to share things between tasks
|
||||
/// When reading these values it's possible (but extremely unlikely) to get stale data if a worker task is in the middle of transitioning
|
||||
#[derive(Default)]
|
||||
struct Stats {
|
||||
pending: AtomicUsize,
|
||||
running: AtomicUsize,
|
||||
retries: AtomicUsize,
|
||||
dead_last_hour: AtomicUsize,
|
||||
completed_last_hour: AtomicUsize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Default)]
|
||||
struct RetryStrategy {
|
||||
/// Amount of time in seconds to back off
|
||||
backoff: usize,
|
||||
/// Amount of times to retry
|
||||
retries: usize,
|
||||
/// If this particular request has already been retried, you can add an offset here to increment the count to start
|
||||
offset: usize,
|
||||
/// Number of seconds to sleep before trying
|
||||
initial_sleep: usize,
|
||||
}
|
||||
|
||||
/// A tokio spawned worker which is responsible for submitting requests to federated servers
|
||||
/// This will retry up to one time with the same signature, and if it fails, will move it to the retry queue.
|
||||
/// We need to retry activity sending in case the target instances is temporarily unreachable.
|
||||
/// In this case, the task is stored and resent when the instance is hopefully back up. This
|
||||
/// list shows the retry intervals, and which events of the target instance can be covered:
|
||||
/// - 60s (one minute, service restart) -- happens in the worker w/ same signature
|
||||
/// - 60min (one hour, instance maintenance) --- happens in the retry worker
|
||||
/// - 60h (2.5 days, major incident with rebuild from backup) --- happens in the retry worker
|
||||
async fn worker(
|
||||
client: ClientWithMiddleware,
|
||||
timeout: Duration,
|
||||
message: SendActivityTask,
|
||||
retry_queue: UnboundedSender<SendActivityTask>,
|
||||
stats: Arc<Stats>,
|
||||
strategy: RetryStrategy,
|
||||
) {
|
||||
stats.pending.fetch_sub(1, Ordering::Relaxed);
|
||||
stats.running.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
let outcome = sign_and_send(&message, &client, timeout, strategy).await;
|
||||
|
||||
// "Running" has finished, check the outcome
|
||||
stats.running.fetch_sub(1, Ordering::Relaxed);
|
||||
|
||||
match outcome {
|
||||
Ok(_) => {
|
||||
stats.completed_last_hour.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
Err(_err) => {
|
||||
stats.retries.fetch_add(1, Ordering::Relaxed);
|
||||
warn!(
|
||||
"Sending activity {} to {} to the retry queue to be tried again later",
|
||||
message.activity_id, message.inbox
|
||||
);
|
||||
// Send to the retry queue. Ignoring whether it succeeds or not
|
||||
retry_queue.send(message).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn retry_worker(
|
||||
client: ClientWithMiddleware,
|
||||
timeout: Duration,
|
||||
message: SendActivityTask,
|
||||
stats: Arc<Stats>,
|
||||
strategy: RetryStrategy,
|
||||
) {
|
||||
// Because the times are pretty extravagant between retries, we have to re-sign each time
|
||||
let outcome = retry(
|
||||
|| {
|
||||
sign_and_send(
|
||||
&message,
|
||||
&client,
|
||||
timeout,
|
||||
RetryStrategy {
|
||||
backoff: 0,
|
||||
retries: 0,
|
||||
offset: 0,
|
||||
initial_sleep: 0,
|
||||
},
|
||||
)
|
||||
},
|
||||
strategy,
|
||||
)
|
||||
.await;
|
||||
|
||||
stats.retries.fetch_sub(1, Ordering::Relaxed);
|
||||
|
||||
match outcome {
|
||||
Ok(_) => {
|
||||
stats.completed_last_hour.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
Err(_err) => {
|
||||
stats.dead_last_hour.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ActivityQueue {
|
||||
fn new(
|
||||
client: ClientWithMiddleware,
|
||||
worker_count: usize,
|
||||
retry_count: usize,
|
||||
timeout: Duration,
|
||||
backoff: usize, // This should be 60 seconds by default or 1 second in tests
|
||||
) -> Self {
|
||||
let stats: Arc<Stats> = Default::default();
|
||||
|
||||
// This task clears the dead/completed stats every hour
|
||||
let hour_stats = stats.clone();
|
||||
tokio::spawn(async move {
|
||||
let duration = Duration::from_secs(3600);
|
||||
loop {
|
||||
tokio::time::sleep(duration).await;
|
||||
hour_stats.completed_last_hour.store(0, Ordering::Relaxed);
|
||||
hour_stats.dead_last_hour.store(0, Ordering::Relaxed);
|
||||
}
|
||||
});
|
||||
|
||||
let (retry_sender, mut retry_receiver) = unbounded_channel();
|
||||
let retry_stats = stats.clone();
|
||||
let retry_client = client.clone();
|
||||
|
||||
// The "fast path" retry
|
||||
// The backoff should be < 5 mins for this to work otherwise signatures may expire
|
||||
// This strategy is the one that is used with the *same* signature
|
||||
let strategy = RetryStrategy {
|
||||
backoff,
|
||||
retries: 1,
|
||||
offset: 0,
|
||||
initial_sleep: 0,
|
||||
};
|
||||
|
||||
// The "retry path" strategy
|
||||
// After the fast path fails, a task will sleep up to backoff ^ 2 and then retry again
|
||||
let retry_strategy = RetryStrategy {
|
||||
backoff,
|
||||
retries: 3,
|
||||
offset: 2,
|
||||
initial_sleep: backoff.pow(2), // wait 60 mins before even trying
|
||||
};
|
||||
|
||||
let retry_sender_task = tokio::spawn(async move {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
while let Some(message) = retry_receiver.recv().await {
|
||||
// If we're over the limit of retries, wait for them to finish before spawning
|
||||
while join_set.len() >= retry_count {
|
||||
join_set.join_next().await;
|
||||
}
|
||||
|
||||
join_set.spawn(retry_worker(
|
||||
retry_client.clone(),
|
||||
timeout,
|
||||
message,
|
||||
retry_stats.clone(),
|
||||
retry_strategy,
|
||||
));
|
||||
}
|
||||
|
||||
while !join_set.is_empty() {
|
||||
join_set.join_next().await;
|
||||
}
|
||||
});
|
||||
|
||||
let (sender, mut receiver) = unbounded_channel();
|
||||
|
||||
let sender_stats = stats.clone();
|
||||
|
||||
let sender_task = tokio::spawn(async move {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
while let Some(message) = receiver.recv().await {
|
||||
// If we're over the limit of workers, wait for them to finish before spawning
|
||||
while join_set.len() >= worker_count {
|
||||
join_set.join_next().await;
|
||||
}
|
||||
|
||||
join_set.spawn(worker(
|
||||
client.clone(),
|
||||
timeout,
|
||||
message,
|
||||
retry_sender.clone(),
|
||||
sender_stats.clone(),
|
||||
strategy,
|
||||
));
|
||||
}
|
||||
|
||||
drop(retry_sender);
|
||||
|
||||
while !join_set.is_empty() {
|
||||
join_set.join_next().await;
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
stats,
|
||||
sender,
|
||||
sender_task,
|
||||
retry_sender_task,
|
||||
}
|
||||
}
|
||||
|
||||
async fn queue(&self, message: SendActivityTask) -> Result<(), anyhow::Error> {
|
||||
self.stats.pending.fetch_add(1, Ordering::Relaxed);
|
||||
self.sender.send(message)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_stats(&self) -> &Stats {
|
||||
&self.stats
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
// Drops all the senders and shuts down the workers
|
||||
async fn shutdown(self, wait_for_retries: bool) -> Result<Arc<Stats>, anyhow::Error> {
|
||||
drop(self.sender);
|
||||
|
||||
self.sender_task.await?;
|
||||
|
||||
if wait_for_retries {
|
||||
self.retry_sender_task.await?;
|
||||
}
|
||||
|
||||
Ok(self.stats)
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates an activity queue using tokio spawned tasks
|
||||
/// Note: requires a tokio runtime
|
||||
pub(crate) fn create_activity_queue(
|
||||
client: ClientWithMiddleware,
|
||||
worker_count: usize,
|
||||
retry_count: usize,
|
||||
request_timeout: Duration,
|
||||
) -> ActivityQueue {
|
||||
assert!(
|
||||
worker_count > 0,
|
||||
"worker count needs to be greater than zero"
|
||||
);
|
||||
|
||||
ActivityQueue::new(client, worker_count, retry_count, request_timeout, 60)
|
||||
}
|
||||
|
||||
/// Retries a future action factory function up to `amount` times with an exponential backoff timer between tries
|
||||
async fn retry<T, E: Display + Debug, F: Future<Output = Result<T, E>>, A: FnMut() -> F>(
|
||||
mut action: A,
|
||||
strategy: RetryStrategy,
|
||||
) -> Result<T, E> {
|
||||
let mut count = strategy.offset;
|
||||
|
||||
// Do an initial sleep if it's called for
|
||||
if strategy.initial_sleep > 0 {
|
||||
let sleep_dur = Duration::from_secs(strategy.initial_sleep as u64);
|
||||
tokio::time::sleep(sleep_dur).await;
|
||||
}
|
||||
|
||||
loop {
|
||||
match action().await {
|
||||
Ok(val) => return Ok(val),
|
||||
Err(err) => {
|
||||
if count < strategy.retries {
|
||||
count += 1;
|
||||
|
||||
let sleep_amt = strategy.backoff.pow(count as u32) as u64;
|
||||
let sleep_dur = Duration::from_secs(sleep_amt);
|
||||
warn!("{err:?}. Sleeping for {sleep_dur:?} and trying again");
|
||||
tokio::time::sleep(sleep_dur).await;
|
||||
continue;
|
||||
} else {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use axum::extract::State;
|
||||
use bytes::Bytes;
|
||||
use http::StatusCode;
|
||||
use std::time::Instant;
|
||||
|
||||
use crate::http_signatures::generate_actor_keypair;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[allow(unused)]
|
||||
// This will periodically send back internal errors to test the retry
|
||||
async fn dodgy_handler(
|
||||
State(state): State<Arc<AtomicUsize>>,
|
||||
headers: HeaderMap,
|
||||
body: Bytes,
|
||||
) -> Result<(), StatusCode> {
|
||||
debug!("Headers:{:?}", headers);
|
||||
debug!("Body len:{}", body.len());
|
||||
|
||||
if state.fetch_add(1, Ordering::Relaxed) % 20 == 0 {
|
||||
return Err(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn test_server() {
|
||||
use axum::{routing::post, Router};
|
||||
|
||||
// We should break every now and then ;)
|
||||
let state = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", post(dodgy_handler))
|
||||
.with_state(state);
|
||||
|
||||
axum::Server::bind(&"0.0.0.0:8001".parse().unwrap())
|
||||
.serve(app.into_make_service())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// Queues 100 messages and then asserts that the worker runs them
|
||||
async fn test_activity_queue_workers() {
|
||||
let num_workers = 64;
|
||||
let num_messages: usize = 100;
|
||||
|
||||
tokio::spawn(test_server());
|
||||
|
||||
/*
|
||||
// uncomment for debug logs & stats
|
||||
use tracing::log::LevelFilter;
|
||||
|
||||
env_logger::builder()
|
||||
.filter_level(LevelFilter::Warn)
|
||||
.filter_module("activitypub_federation", LevelFilter::Info)
|
||||
.format_timestamp(None)
|
||||
.init();
|
||||
|
||||
*/
|
||||
|
||||
let activity_queue = ActivityQueue::new(
|
||||
reqwest::Client::default().into(),
|
||||
num_workers,
|
||||
num_workers,
|
||||
Duration::from_secs(10),
|
||||
1,
|
||||
);
|
||||
|
||||
let keypair = generate_actor_keypair().unwrap();
|
||||
|
||||
let message = SendActivityTask {
|
||||
actor_id: "http://localhost:8001".parse().unwrap(),
|
||||
activity_id: "http://localhost:8001/activity".parse().unwrap(),
|
||||
activity: "{}".into(),
|
||||
inbox: "http://localhost:8001".parse().unwrap(),
|
||||
private_key: keypair.private_key().unwrap(),
|
||||
http_signature_compat: true,
|
||||
};
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
for _ in 0..num_messages {
|
||||
activity_queue.queue(message.clone()).await.unwrap();
|
||||
}
|
||||
|
||||
info!("Queue Sent: {:?}", start.elapsed());
|
||||
|
||||
let stats = activity_queue.shutdown(true).await.unwrap();
|
||||
|
||||
info!(
|
||||
"Queue Finished. Num msgs: {}, Time {:?}, msg/s: {:0.0}",
|
||||
num_messages,
|
||||
start.elapsed(),
|
||||
num_messages as f64 / start.elapsed().as_secs_f64()
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
stats.completed_last_hour.load(Ordering::Relaxed),
|
||||
num_messages
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -65,19 +65,19 @@ mod test {
|
|||
use reqwest_middleware::ClientWithMiddleware;
|
||||
use url::Url;
|
||||
|
||||
#[actix_rt::test]
|
||||
#[tokio::test]
|
||||
async fn test_receive_activity() {
|
||||
let (body, incoming_request, config) = setup_receive_test().await;
|
||||
receive_activity::<Follow, DbUser, DbConnection>(
|
||||
incoming_request.to_http_request(),
|
||||
body.into(),
|
||||
body,
|
||||
&config.to_request_data(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
#[tokio::test]
|
||||
async fn test_receive_activity_invalid_body_signature() {
|
||||
let (_, incoming_request, config) = setup_receive_test().await;
|
||||
let err = receive_activity::<Follow, DbUser, DbConnection>(
|
||||
|
@ -93,13 +93,13 @@ mod test {
|
|||
assert_eq!(e, &Error::ActivityBodyDigestInvalid)
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
#[tokio::test]
|
||||
async fn test_receive_activity_invalid_path() {
|
||||
let (body, incoming_request, config) = setup_receive_test().await;
|
||||
let incoming_request = incoming_request.uri("/wrong");
|
||||
let err = receive_activity::<Follow, DbUser, DbConnection>(
|
||||
incoming_request.to_http_request(),
|
||||
body.into(),
|
||||
body,
|
||||
&config.to_request_data(),
|
||||
)
|
||||
.await
|
||||
|
@ -110,7 +110,7 @@ mod test {
|
|||
assert_eq!(e, &Error::ActivitySignatureInvalid)
|
||||
}
|
||||
|
||||
async fn setup_receive_test() -> (String, TestRequest, FederationConfig<DbConnection>) {
|
||||
async fn setup_receive_test() -> (Bytes, TestRequest, FederationConfig<DbConnection>) {
|
||||
let inbox = "https://example.com/inbox";
|
||||
let headers = generate_request_headers(&Url::parse(inbox).unwrap());
|
||||
let request_builder = ClientWithMiddleware::from(Client::default())
|
||||
|
@ -122,12 +122,12 @@ mod test {
|
|||
kind: Default::default(),
|
||||
id: "http://localhost:123/1".try_into().unwrap(),
|
||||
};
|
||||
let body = serde_json::to_string(&activity).unwrap();
|
||||
let body: Bytes = serde_json::to_vec(&activity).unwrap().into();
|
||||
let outgoing_request = sign_request(
|
||||
request_builder,
|
||||
activity.actor.into_inner(),
|
||||
body.to_string(),
|
||||
DB_USER_KEYPAIR.private_key.clone(),
|
||||
&activity.actor.into_inner(),
|
||||
body.clone(),
|
||||
DB_USER_KEYPAIR.private_key().unwrap(),
|
||||
false,
|
||||
)
|
||||
.await
|
||||
|
@ -142,6 +142,7 @@ mod test {
|
|||
.app_data(DbConnection)
|
||||
.debug(true)
|
||||
.build()
|
||||
.await
|
||||
.unwrap();
|
||||
(body, incoming_request, config)
|
||||
}
|
||||
|
|
|
@ -4,26 +4,27 @@
|
|||
//!
|
||||
//! ```
|
||||
//! # use activitypub_federation::config::FederationConfig;
|
||||
//! # let _ = actix_rt::System::new();
|
||||
//! # tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
//! let settings = FederationConfig::builder()
|
||||
//! .domain("example.com")
|
||||
//! .app_data(())
|
||||
//! .http_fetch_limit(50)
|
||||
//! .worker_count(16)
|
||||
//! .build()?;
|
||||
//! .build().await?;
|
||||
//! # Ok::<(), anyhow::Error>(())
|
||||
//! # }).unwrap()
|
||||
//! ```
|
||||
|
||||
use crate::{
|
||||
activity_queue::create_activity_queue,
|
||||
activity_queue::{create_activity_queue, ActivityQueue},
|
||||
error::Error,
|
||||
protocol::verification::verify_domains_match,
|
||||
traits::{ActivityHandler, Actor},
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use background_jobs::Manager;
|
||||
use derive_builder::Builder;
|
||||
use dyn_clone::{clone_trait_object, DynClone};
|
||||
use openssl::pkey::{PKey, Private};
|
||||
use reqwest_middleware::ClientWithMiddleware;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::{
|
||||
|
@ -54,9 +55,14 @@ pub struct FederationConfig<T: Clone> {
|
|||
/// HTTP client used for all outgoing requests. Middleware can be used to add functionality
|
||||
/// like log tracing or retry of failed requests.
|
||||
pub(crate) client: ClientWithMiddleware,
|
||||
/// Number of worker threads for sending outgoing activities
|
||||
#[builder(default = "64")]
|
||||
pub(crate) worker_count: u64,
|
||||
/// Number of tasks that can be in-flight concurrently.
|
||||
/// Tasks are retried once after a minute, then put into the retry queue
|
||||
#[builder(default = "1024")]
|
||||
pub(crate) worker_count: usize,
|
||||
/// Number of concurrent tasks that are being retried in-flight concurrently.
|
||||
/// Tasks are retried after an hour, then again in 60 hours.
|
||||
#[builder(default = "128")]
|
||||
pub(crate) retry_count: usize,
|
||||
/// Run library in debug mode. This allows usage of http and localhost urls. It also sends
|
||||
/// outgoing activities synchronously, not in background thread. This helps to make tests
|
||||
/// more consistent. Do not use for production.
|
||||
|
@ -79,11 +85,11 @@ pub struct FederationConfig<T: Clone> {
|
|||
/// This can be used to implement secure mode federation.
|
||||
/// <https://docs.joinmastodon.org/spec/activitypub/#secure-mode>
|
||||
#[builder(default = "None", setter(custom))]
|
||||
pub(crate) signed_fetch_actor: Option<Arc<(Url, String)>>,
|
||||
pub(crate) signed_fetch_actor: Option<Arc<(Url, PKey<Private>)>>,
|
||||
/// Queue for sending outgoing activities. Only optional to make builder work, its always
|
||||
/// present once constructed.
|
||||
#[builder(setter(skip))]
|
||||
pub(crate) activity_queue: Option<Arc<Manager>>,
|
||||
pub(crate) activity_queue: Option<Arc<ActivityQueue>>,
|
||||
}
|
||||
|
||||
impl<T: Clone> FederationConfig<T> {
|
||||
|
@ -180,7 +186,10 @@ impl<T: Clone> FederationConfigBuilder<T> {
|
|||
let private_key_pem = actor
|
||||
.private_key_pem()
|
||||
.expect("actor does not have a private key to sign with");
|
||||
self.signed_fetch_actor = Some(Some(Arc::new((actor.id(), private_key_pem))));
|
||||
|
||||
let private_key = PKey::private_key_from_pem(private_key_pem.as_bytes())
|
||||
.expect("Could not decode PEM data");
|
||||
self.signed_fetch_actor = Some(Some(Arc::new((actor.id(), private_key))));
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -188,13 +197,14 @@ impl<T: Clone> FederationConfigBuilder<T> {
|
|||
///
|
||||
/// Values which are not explicitly specified use the defaults. Also initializes the
|
||||
/// queue for outgoing activities, which is stored internally in the config struct.
|
||||
pub fn build(&mut self) -> Result<FederationConfig<T>, FederationConfigBuilderError> {
|
||||
/// Requires a tokio runtime for the background queue.
|
||||
pub async fn build(&mut self) -> Result<FederationConfig<T>, FederationConfigBuilderError> {
|
||||
let mut config = self.partial_build()?;
|
||||
let queue = create_activity_queue(
|
||||
config.client.clone(),
|
||||
config.worker_count,
|
||||
config.retry_count,
|
||||
config.request_timeout,
|
||||
config.debug,
|
||||
);
|
||||
config.activity_queue = Some(Arc::new(queue));
|
||||
Ok(config)
|
||||
|
|
|
@ -30,7 +30,7 @@ where
|
|||
|
||||
/// Fetches collection over HTTP
|
||||
///
|
||||
/// Unlike [ObjectId::fetch](crate::fetch::object_id::ObjectId::fetch) this method doesn't do
|
||||
/// Unlike [ObjectId::dereference](crate::fetch::object_id::ObjectId::dereference) this method doesn't do
|
||||
/// any caching.
|
||||
pub async fn dereference(
|
||||
&self,
|
||||
|
|
|
@ -9,6 +9,7 @@ use crate::{
|
|||
reqwest_shim::ResponseExt,
|
||||
FEDERATION_CONTENT_TYPE,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use http::StatusCode;
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
@ -56,8 +57,8 @@ pub async fn fetch_object_http<T: Clone, Kind: DeserializeOwned>(
|
|||
let res = if let Some((actor_id, private_key_pem)) = config.signed_fetch_actor.as_deref() {
|
||||
let req = sign_request(
|
||||
req,
|
||||
actor_id.clone(),
|
||||
String::new(),
|
||||
actor_id,
|
||||
Bytes::new(),
|
||||
private_key_pem.clone(),
|
||||
data.config.http_signature_compat,
|
||||
)
|
||||
|
|
|
@ -36,13 +36,12 @@ where
|
|||
/// # use activitypub_federation::config::FederationConfig;
|
||||
/// # use activitypub_federation::error::Error::NotFound;
|
||||
/// # use activitypub_federation::traits::tests::{DbConnection, DbUser};
|
||||
/// # let _ = actix_rt::System::new();
|
||||
/// # actix_rt::Runtime::new().unwrap().block_on(async {
|
||||
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
|
||||
/// # let db_connection = DbConnection;
|
||||
/// let config = FederationConfig::builder()
|
||||
/// .domain("example.com")
|
||||
/// .app_data(db_connection)
|
||||
/// .build()?;
|
||||
/// .build().await?;
|
||||
/// let request_data = config.to_request_data();
|
||||
/// let object_id = ObjectId::<DbUser>::parse("https://lemmy.ml/u/nutomic")?;
|
||||
/// // Attempt to fetch object from local database or fall back to remote server
|
||||
|
|
|
@ -199,12 +199,13 @@ mod tests {
|
|||
traits::tests::{DbConnection, DbUser},
|
||||
};
|
||||
|
||||
#[actix_rt::test]
|
||||
#[tokio::test]
|
||||
async fn test_webfinger() {
|
||||
let config = FederationConfig::builder()
|
||||
.domain("example.com")
|
||||
.app_data(DbConnection)
|
||||
.build()
|
||||
.await
|
||||
.unwrap();
|
||||
let data = config.to_request_data();
|
||||
let res =
|
||||
|
|
|
@ -13,12 +13,13 @@ use crate::{
|
|||
traits::{Actor, Object},
|
||||
};
|
||||
use base64::{engine::general_purpose::STANDARD as Base64, Engine};
|
||||
use bytes::Bytes;
|
||||
use http::{header::HeaderName, uri::PathAndQuery, HeaderValue, Method, Uri};
|
||||
use http_signature_normalization_reqwest::prelude::{Config, SignExt};
|
||||
use once_cell::sync::Lazy;
|
||||
use openssl::{
|
||||
hash::MessageDigest,
|
||||
pkey::PKey,
|
||||
pkey::{PKey, Private},
|
||||
rsa::Rsa,
|
||||
sign::{Signer, Verifier},
|
||||
};
|
||||
|
@ -26,7 +27,7 @@ use reqwest::Request;
|
|||
use reqwest_middleware::RequestBuilder;
|
||||
use serde::Deserialize;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::{collections::BTreeMap, fmt::Debug, io::ErrorKind};
|
||||
use std::{collections::BTreeMap, fmt::Debug, io::ErrorKind, time::Duration};
|
||||
use tracing::debug;
|
||||
use url::Url;
|
||||
|
||||
|
@ -39,6 +40,14 @@ pub struct Keypair {
|
|||
pub public_key: String,
|
||||
}
|
||||
|
||||
impl Keypair {
|
||||
/// Helper method to turn this into an openssl private key
|
||||
#[cfg(test)]
|
||||
pub(crate) fn private_key(&self) -> Result<PKey<Private>, anyhow::Error> {
|
||||
Ok(PKey::private_key_from_pem(self.private_key.as_bytes())?)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a random asymmetric keypair for ActivityPub HTTP signatures.
|
||||
pub fn generate_actor_keypair() -> Result<Keypair, std::io::Error> {
|
||||
let rsa = Rsa::generate(2048)?;
|
||||
|
@ -58,19 +67,27 @@ pub fn generate_actor_keypair() -> Result<Keypair, std::io::Error> {
|
|||
})
|
||||
}
|
||||
|
||||
/// Sets the amount of time that a signed request is valid. Currenlty 5 minutes
|
||||
/// Mastodon & friends have ~1 hour expiry from creation if it's not set in the header
|
||||
pub(crate) const EXPIRES_AFTER: Duration = Duration::from_secs(300);
|
||||
|
||||
/// Creates an HTTP post request to `inbox_url`, with the given `client` and `headers`, and
|
||||
/// `activity` as request body. The request is signed with `private_key` and then sent.
|
||||
pub(crate) async fn sign_request(
|
||||
request_builder: RequestBuilder,
|
||||
actor_id: Url,
|
||||
activity: String,
|
||||
private_key: String,
|
||||
actor_id: &Url,
|
||||
activity: Bytes,
|
||||
private_key: PKey<Private>,
|
||||
http_signature_compat: bool,
|
||||
) -> Result<Request, anyhow::Error> {
|
||||
static CONFIG: Lazy<Config> = Lazy::new(Config::new);
|
||||
static CONFIG_COMPAT: Lazy<Config> = Lazy::new(|| Config::new().mastodon_compat());
|
||||
static CONFIG: Lazy<Config> = Lazy::new(|| Config::new().set_expiration(EXPIRES_AFTER));
|
||||
static CONFIG_COMPAT: Lazy<Config> = Lazy::new(|| {
|
||||
Config::new()
|
||||
.mastodon_compat()
|
||||
.set_expiration(EXPIRES_AFTER)
|
||||
});
|
||||
|
||||
let key_id = main_key_id(&actor_id);
|
||||
let key_id = main_key_id(actor_id);
|
||||
let sig_conf = match http_signature_compat {
|
||||
false => CONFIG.clone(),
|
||||
true => CONFIG_COMPAT.clone(),
|
||||
|
@ -82,7 +99,6 @@ pub(crate) async fn sign_request(
|
|||
Sha256::new(),
|
||||
activity,
|
||||
move |signing_string| {
|
||||
let private_key = PKey::private_key_from_pem(private_key.as_bytes())?;
|
||||
let mut signer = Signer::new(MessageDigest::sha256(), &private_key)?;
|
||||
signer.update(signing_string.as_bytes())?;
|
||||
|
||||
|
@ -259,7 +275,7 @@ pub mod test {
|
|||
static INBOX_URL: Lazy<Url> =
|
||||
Lazy::new(|| Url::parse("https://example.com/u/alice/inbox").unwrap());
|
||||
|
||||
#[actix_rt::test]
|
||||
#[tokio::test]
|
||||
async fn test_sign() {
|
||||
let mut headers = generate_request_headers(&INBOX_URL);
|
||||
// use hardcoded date in order to test against hardcoded signature
|
||||
|
@ -273,9 +289,9 @@ pub mod test {
|
|||
.headers(headers);
|
||||
let request = sign_request(
|
||||
request_builder,
|
||||
ACTOR_ID.clone(),
|
||||
"my activity".to_string(),
|
||||
test_keypair().private_key,
|
||||
&ACTOR_ID,
|
||||
"my activity".into(),
|
||||
PKey::private_key_from_pem(test_keypair().private_key.as_bytes()).unwrap(),
|
||||
// set this to prevent created/expires headers to be generated and inserted
|
||||
// automatically from current time
|
||||
true,
|
||||
|
@ -301,7 +317,7 @@ pub mod test {
|
|||
assert_eq!(signature, expected_signature);
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
#[tokio::test]
|
||||
async fn test_verify() {
|
||||
let headers = generate_request_headers(&INBOX_URL);
|
||||
let request_builder = ClientWithMiddleware::from(Client::new())
|
||||
|
@ -309,9 +325,9 @@ pub mod test {
|
|||
.headers(headers);
|
||||
let request = sign_request(
|
||||
request_builder,
|
||||
ACTOR_ID.clone(),
|
||||
"my activity".to_string(),
|
||||
test_keypair().private_key,
|
||||
&ACTOR_ID,
|
||||
"my activity".to_string().into(),
|
||||
PKey::private_key_from_pem(test_keypair().private_key.as_bytes()).unwrap(),
|
||||
false,
|
||||
)
|
||||
.await
|
||||
|
|
Loading…
Reference in a new issue