Mirror of https://github.com/dani-garcia/vaultwarden (synced 2024-11-25 21:30:18 +00:00)

Commit 2d5f172e77 (parent: 08f0de7b46)

Update to Rocket 0.5 and make the code async; updating all the DB calls, which currently still block, remains to be done.

30 changed files with 1314 additions and 1028 deletions
Cargo.lock (generated): 874 changes. File diff suppressed because it is too large.
Cargo.toml: 27 changes
@@ -3,7 +3,7 @@ name = "vaultwarden"
 version = "1.0.0"
 authors = ["Daniel García <dani-garcia@users.noreply.github.com>"]
 edition = "2021"
-rust-version = "1.60"
+rust-version = "1.56"
 resolver = "2"
 
 repository = "https://github.com/dani-garcia/vaultwarden"
@@ -13,6 +13,7 @@ publish = false
 build = "build.rs"
 
 [features]
+# default = ["sqlite"]
 # Empty to keep compatibility, prefer to set USE_SYSLOG=true
 enable_syslog = []
 mysql = ["diesel/mysql", "diesel_migrations/mysql"]
@@ -29,22 +30,22 @@ unstable = []
 syslog = "4.0.1"
 
 [dependencies]
-# Web framework for nightly with a focus on ease-of-use, expressibility, and speed.
-rocket = { version = "=0.5.0-dev", features = ["tls"], default-features = false }
-rocket_contrib = "=0.5.0-dev"
+# Web framework
+rocket = { version = "0.5.0-rc.1", features = ["tls", "json"], default-features = false }
 
-# HTTP client
-reqwest = { version = "0.11.9", features = ["blocking", "json", "gzip", "brotli", "socks", "cookies", "trust-dns"] }
+# Async futures
+futures = "0.3.19"
+tokio = { version = "1.16.1", features = ["rt-multi-thread", "fs", "io-util", "parking_lot"] }
 
+# HTTP client
+reqwest = { version = "0.11.9", features = ["stream", "json", "gzip", "brotli", "socks", "cookies", "trust-dns"] }
+bytes = "1.1.0"
 
 # Used for custom short lived cookie jar
 cookie = "0.15.1"
 cookie_store = "0.15.1"
-bytes = "1.1.0"
 url = "2.2.2"
 
-# multipart/form-data support
-multipart = { version = "0.18.0", features = ["server"], default-features = false }
 
 # WebSockets library
 ws = { version = "0.11.1", package = "parity-ws" }
@@ -141,10 +142,10 @@ backtrace = "0.3.64"
 paste = "1.0.6"
 governor = "0.4.1"
 
+ctrlc = { version = "3.2.1", features = ["termination"] }
 
 [patch.crates-io]
-# Use newest ring
-rocket = { git = 'https://github.com/SergioBenitez/Rocket', rev = '263e39b5b429de1913ce7e3036575a7b4d88b6d7' }
-rocket_contrib = { git = 'https://github.com/SergioBenitez/Rocket', rev = '263e39b5b429de1913ce7e3036575a7b4d88b6d7' }
+rocket = { git = 'https://github.com/SergioBenitez/Rocket', rev = '8cae077ba1d54b92cdef3e171a730b819d5eeb8e' }
 
 # The maintainer of the `job_scheduler` crate doesn't seem to have responded
 # to any issues or PRs for almost a year (as of April 2021). This hopefully
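Taken together, the Cargo.toml changes swap the nightly-only `rocket = "=0.5.0-dev"` plus `rocket_contrib` pair for the stable `0.5.0-rc` line and pull in `futures` and `tokio` as the async runtime. For orientation, a minimal Rocket 0.5-rc entry point looks roughly like the sketch below; the route and names are illustrative and not vaultwarden's actual `main`:

```rust
#[macro_use]
extern crate rocket;

// Illustrative route, not part of vaultwarden.
#[get("/alive")]
fn alive() -> &'static str {
    "ok"
}

// With 0.5-rc the server runs on Tokio, so handlers and request guards
// may be `async fn` and `.await` I/O directly.
#[rocket::main]
async fn main() -> Result<(), rocket::Error> {
    rocket::build().mount("/", routes![alive]).launch().await?;
    Ok(())
}
```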
@@ -1,2 +0,0 @@
-[global.limits]
-json = 10485760 # 10 MiB
@@ -1 +1 @@
-nightly-2022-01-23
+stable
@@ -3,13 +3,14 @@ use serde::de::DeserializeOwned;
 use serde_json::Value;
 use std::env;
 
+use rocket::serde::json::Json;
 use rocket::{
-    http::{Cookie, Cookies, SameSite, Status},
-    request::{self, FlashMessage, Form, FromRequest, Outcome, Request},
-    response::{content::Html, Flash, Redirect},
+    form::Form,
+    http::{Cookie, CookieJar, SameSite, Status},
+    request::{self, FlashMessage, FromRequest, Outcome, Request},
+    response::{content::RawHtml as Html, Flash, Redirect},
     Route,
 };
-use rocket_contrib::json::Json;
 
 use crate::{
     api::{ApiResult, EmptyResult, JsonResult, NumberOrString},
@@ -85,10 +86,11 @@ fn admin_path() -> String {
 
 struct Referer(Option<String>);
 
-impl<'a, 'r> FromRequest<'a, 'r> for Referer {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for Referer {
     type Error = ();
 
-    fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
+    async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
         Outcome::Success(Referer(request.headers().get_one("Referer").map(str::to_string)))
     }
 }
@@ -96,10 +98,11 @@ impl<'a, 'r> FromRequest<'a, 'r> for Referer {
 #[derive(Debug)]
 struct IpHeader(Option<String>);
 
-impl<'a, 'r> FromRequest<'a, 'r> for IpHeader {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for IpHeader {
     type Error = ();
 
-    fn from_request(req: &'a Request<'r>) -> Outcome<Self, Self::Error> {
+    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
         if req.headers().get_one(&CONFIG.ip_header()).is_some() {
             Outcome::Success(IpHeader(Some(CONFIG.ip_header())))
         } else if req.headers().get_one("X-Client-IP").is_some() {
@@ -138,7 +141,7 @@ fn admin_url(referer: Referer) -> String {
 #[get("/", rank = 2)]
 fn admin_login(flash: Option<FlashMessage>) -> ApiResult<Html<String>> {
     // If there is an error, show it
-    let msg = flash.map(|msg| format!("{}: {}", msg.name(), msg.msg()));
+    let msg = flash.map(|msg| format!("{}: {}", msg.kind(), msg.message()));
     let json = json!({
         "page_content": "admin/login",
         "version": VERSION,
@@ -159,7 +162,7 @@ struct LoginForm {
 #[post("/", data = "<data>")]
 fn post_admin_login(
     data: Form<LoginForm>,
-    mut cookies: Cookies,
+    cookies: &CookieJar,
     ip: ClientIp,
     referer: Referer,
 ) -> Result<Redirect, Flash<Redirect>> {
@@ -180,7 +183,7 @@ fn post_admin_login(
 
     let cookie = Cookie::build(COOKIE_NAME, jwt)
         .path(admin_path())
-        .max_age(time::Duration::minutes(20))
+        .max_age(rocket::time::Duration::minutes(20))
         .same_site(SameSite::Strict)
         .http_only(true)
         .finish();
@@ -297,7 +300,7 @@ fn test_smtp(data: Json<InviteData>, _token: AdminToken) -> EmptyResult {
 }
 
 #[get("/logout")]
-fn logout(mut cookies: Cookies, referer: Referer) -> Redirect {
+fn logout(cookies: &CookieJar, referer: Referer) -> Redirect {
     cookies.remove(Cookie::named(COOKIE_NAME));
     Redirect::to(admin_url(referer))
 }
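The two admin handlers above show the cookie API change: Rocket 0.5 replaces the owned `Cookies` guard with a shared `&CookieJar`, whose interior mutability removes the need for a `mut` binding. A minimal sketch of the new shape (cookie name and redirect target are illustrative):

```rust
use rocket::get;
use rocket::http::{Cookie, CookieJar};
use rocket::response::Redirect;

#[get("/logout-example")]
fn logout_example(cookies: &CookieJar<'_>) -> Redirect {
    // The jar is mutated through a shared reference; no `mut` needed.
    cookies.remove(Cookie::named("EXAMPLE_COOKIE"));
    Redirect::to("/")
}
```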
@@ -462,23 +465,23 @@ struct GitCommit {
     sha: String,
 }
 
-fn get_github_api<T: DeserializeOwned>(url: &str) -> Result<T, Error> {
+async fn get_github_api<T: DeserializeOwned>(url: &str) -> Result<T, Error> {
     let github_api = get_reqwest_client();
 
-    Ok(github_api.get(url).send()?.error_for_status()?.json::<T>()?)
+    Ok(github_api.get(url).send().await?.error_for_status()?.json::<T>().await?)
 }
 
-fn has_http_access() -> bool {
+async fn has_http_access() -> bool {
     let http_access = get_reqwest_client();
 
-    match http_access.head("https://github.com/dani-garcia/vaultwarden").send() {
+    match http_access.head("https://github.com/dani-garcia/vaultwarden").send().await {
         Ok(r) => r.status().is_success(),
         _ => false,
     }
 }
 
 #[get("/diagnostics")]
-fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResult<Html<String>> {
+async fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResult<Html<String>> {
     use crate::util::read_file_string;
     use chrono::prelude::*;
     use std::net::ToSocketAddrs;
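With the `blocking` feature dropped from reqwest, every HTTP call becomes a future that is awaited twice: once for the request, once for reading the body. A minimal sketch of the pattern the two helpers above follow (the function name is illustrative):

```rust
use serde::de::DeserializeOwned;

// Both the send and the body deserialization are awaited now; this mirrors
// get_github_api() above, but against a plain reqwest::Client.
async fn fetch_json<T: DeserializeOwned>(client: &reqwest::Client, url: &str) -> Result<T, reqwest::Error> {
    client.get(url).send().await?.error_for_status()?.json::<T>().await
}
```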
@@ -497,7 +500,7 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
 
     // Execute some environment checks
     let running_within_docker = is_running_in_docker();
-    let has_http_access = has_http_access();
+    let has_http_access = has_http_access().await;
    let uses_proxy = env::var_os("HTTP_PROXY").is_some()
         || env::var_os("http_proxy").is_some()
         || env::var_os("HTTPS_PROXY").is_some()
@@ -513,11 +516,14 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
     // TODO: Maybe we need to cache this using a LazyStatic or something. Github only allows 60 requests per hour, and we use 3 here already.
     let (latest_release, latest_commit, latest_web_build) = if has_http_access {
         (
-            match get_github_api::<GitRelease>("https://api.github.com/repos/dani-garcia/vaultwarden/releases/latest") {
+            match get_github_api::<GitRelease>("https://api.github.com/repos/dani-garcia/vaultwarden/releases/latest")
+                .await
+            {
                 Ok(r) => r.tag_name,
                 _ => "-".to_string(),
             },
-            match get_github_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main") {
+            match get_github_api::<GitCommit>("https://api.github.com/repos/dani-garcia/vaultwarden/commits/main").await
+            {
                 Ok(mut c) => {
                     c.sha.truncate(8);
                     c.sha
@@ -531,7 +537,9 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
             } else {
                 match get_github_api::<GitRelease>(
                     "https://api.github.com/repos/dani-garcia/bw_web_builds/releases/latest",
-                ) {
+                )
+                .await
+                {
                     Ok(r) => r.tag_name.trim_start_matches('v').to_string(),
                     _ => "-".to_string(),
                 }
|
@ -562,7 +570,7 @@ fn diagnostics(_token: AdminToken, ip_header: IpHeader, conn: DbConn) -> ApiResu
|
||||||
"ip_header_config": &CONFIG.ip_header(),
|
"ip_header_config": &CONFIG.ip_header(),
|
||||||
"uses_proxy": uses_proxy,
|
"uses_proxy": uses_proxy,
|
||||||
"db_type": *DB_TYPE,
|
"db_type": *DB_TYPE,
|
||||||
"db_version": get_sql_server_version(&conn),
|
"db_version": get_sql_server_version(&conn).await,
|
||||||
"admin_url": format!("{}/diagnostics", admin_url(Referer(None))),
|
"admin_url": format!("{}/diagnostics", admin_url(Referer(None))),
|
||||||
"overrides": &CONFIG.get_overrides().join(", "),
|
"overrides": &CONFIG.get_overrides().join(", "),
|
||||||
"server_time_local": Local::now().format("%Y-%m-%d %H:%M:%S %Z").to_string(),
|
"server_time_local": Local::now().format("%Y-%m-%d %H:%M:%S %Z").to_string(),
|
||||||
|
@ -591,9 +599,9 @@ fn delete_config(_token: AdminToken) -> EmptyResult {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/config/backup_db")]
|
#[post("/config/backup_db")]
|
||||||
fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
|
async fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
|
||||||
if *CAN_BACKUP {
|
if *CAN_BACKUP {
|
||||||
backup_database(&conn)
|
backup_database(&conn).await
|
||||||
} else {
|
} else {
|
||||||
err!("Can't back up current DB (Only SQLite supports this feature)");
|
err!("Can't back up current DB (Only SQLite supports this feature)");
|
||||||
}
|
}
|
||||||
|
@ -601,21 +609,22 @@ fn backup_db(_token: AdminToken, conn: DbConn) -> EmptyResult {
|
||||||
|
|
||||||
pub struct AdminToken {}
|
pub struct AdminToken {}
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for AdminToken {
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for AdminToken {
|
||||||
type Error = &'static str;
|
type Error = &'static str;
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
|
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
|
||||||
if CONFIG.disable_admin_token() {
|
if CONFIG.disable_admin_token() {
|
||||||
Outcome::Success(AdminToken {})
|
Outcome::Success(AdminToken {})
|
||||||
} else {
|
} else {
|
||||||
let mut cookies = request.cookies();
|
let cookies = request.cookies();
|
||||||
|
|
||||||
let access_token = match cookies.get(COOKIE_NAME) {
|
let access_token = match cookies.get(COOKIE_NAME) {
|
||||||
Some(cookie) => cookie.value(),
|
Some(cookie) => cookie.value(),
|
||||||
None => return Outcome::Forward(()), // If there is no cookie, redirect to login
|
None => return Outcome::Forward(()), // If there is no cookie, redirect to login
|
||||||
};
|
};
|
||||||
|
|
||||||
let ip = match request.guard::<ClientIp>() {
|
let ip = match ClientIp::from_request(request).await {
|
||||||
Outcome::Success(ip) => ip.ip,
|
Outcome::Success(ip) => ip.ip,
|
||||||
_ => err_handler!("Error getting Client IP"),
|
_ => err_handler!("Error getting Client IP"),
|
||||||
};
|
};
|
||||||
|
|
|
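All three guards in this file follow the same migration recipe: `FromRequest` now takes a single `'r` lifetime and an async `from_request`, written under `#[rocket::async_trait]`, and one guard invokes another with `Guard::from_request(request).await` instead of `request.guard::<Guard>()`. A hedged sketch of the pattern, using an illustrative guard that is not in vaultwarden:

```rust
use rocket::request::{FromRequest, Outcome, Request};

struct UserAgent(Option<String>); // illustrative, not a vaultwarden type

#[rocket::async_trait]
impl<'r> FromRequest<'r> for UserAgent {
    type Error = ();

    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        // Being async, this body could await other guards or I/O, as
        // AdminToken does above with ClientIp::from_request(request).await.
        Outcome::Success(UserAgent(req.headers().get_one("User-Agent").map(str::to_string)))
    }
}
```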
@@ -1,5 +1,5 @@
 use chrono::Utc;
-use rocket_contrib::json::Json;
+use rocket::serde::json::Json;
 use serde_json::Value;
 
 use crate::{
@@ -1,13 +1,14 @@
 use std::collections::{HashMap, HashSet};
-use std::path::{Path, PathBuf};
 
 use chrono::{NaiveDateTime, Utc};
-use rocket::{http::ContentType, request::Form, Data, Route};
-use rocket_contrib::json::Json;
+use rocket::fs::TempFile;
+use rocket::serde::json::Json;
+use rocket::{
+    form::{Form, FromForm},
+    Route,
+};
 use serde_json::Value;
 
-use multipart::server::{save::SavedData, Multipart, SaveResult};
-
 use crate::{
     api::{self, EmptyResult, JsonResult, JsonUpcase, Notify, PasswordData, UpdateType},
     auth::Headers,
@@ -79,9 +80,9 @@ pub fn routes() -> Vec<Route> {
     ]
 }
 
-pub fn purge_trashed_ciphers(pool: DbPool) {
+pub async fn purge_trashed_ciphers(pool: DbPool) {
     debug!("Purging trashed ciphers");
-    if let Ok(conn) = pool.get() {
+    if let Ok(conn) = pool.get().await {
         Cipher::purge_trash(&conn);
     } else {
         error!("Failed to get DB connection while purging trashed ciphers")
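Note the split the commit message describes: obtaining a connection from the pool is now awaited, while the Diesel query behind `Cipher::purge_trash` still blocks the executor. A sketch of where a follow-up change would plausibly move the blocking work; the `spawn_blocking` part is an assumption, not something this commit does:

```rust
// Hypothetical follow-up shape; this commit only awaits pool.get().
async fn purge_job_sketch(pool: DbPool) {
    if let Ok(conn) = pool.get().await {
        // Diesel is synchronous, so a fully async version would push the
        // query onto a blocking-capable thread, e.g. with Tokio:
        let _ = tokio::task::spawn_blocking(move || Cipher::purge_trash(&conn)).await;
    } else {
        error!("Failed to get DB connection while purging trashed ciphers")
    }
}
```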
@@ -90,12 +91,12 @@ pub fn purge_trashed_ciphers(pool: DbPool) {
 
 #[derive(FromForm, Default)]
 struct SyncData {
-    #[form(field = "excludeDomains")]
+    #[field(name = "excludeDomains")]
     exclude_domains: bool, // Default: 'false'
 }
 
 #[get("/sync?<data..>")]
-fn sync(data: Form<SyncData>, headers: Headers, conn: DbConn) -> Json<Value> {
+fn sync(data: SyncData, headers: Headers, conn: DbConn) -> Json<Value> {
     let user_json = headers.user.to_json(&conn);
 
     let folders = Folder::find_by_user(&headers.user.uuid, &conn);
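Two form changes ride along here: the field-rename attribute moves from `#[form(field = "...")]` to `#[field(name = "...")]`, and a query guard declared with `?<data..>` can now be taken as the bare `FromForm` type instead of `Form<T>`. A self-contained sketch (route and field are illustrative):

```rust
use rocket::{get, FromForm};

#[derive(FromForm, Default)]
struct SyncQuery {
    // Maps the camelCase query key onto the snake_case field.
    #[field(name = "excludeDomains")]
    exclude_domains: bool,
}

#[get("/sync-example?<data..>")]
fn sync_example(data: SyncQuery) -> String {
    format!("excludeDomains = {}", data.exclude_domains)
}
```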
@@ -828,6 +829,12 @@ fn post_attachment_v2(
     })))
 }
 
+#[derive(FromForm)]
+struct UploadData<'f> {
+    key: Option<String>,
+    data: TempFile<'f>,
+}
+
 /// Saves the data content of an attachment to a file. This is common code
 /// shared between the v2 and legacy attachment APIs.
 ///
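This new struct is the heart of the multipart migration: instead of hand-parsing the body with the `multipart` crate, Rocket 0.5 derives `FromForm` over a struct whose `TempFile` field receives the uploaded part. A minimal, self-contained sketch of the same shape (route and paths are illustrative):

```rust
use rocket::form::{Form, FromForm};
use rocket::fs::TempFile;
use rocket::post;
use rocket::response::Debug;

#[derive(FromForm)]
struct Upload<'f> {
    key: Option<String>,
    data: TempFile<'f>,
}

#[post("/upload-example", format = "multipart/form-data", data = "<form>")]
async fn upload_example(mut form: Form<Upload<'_>>) -> Result<(), Debug<std::io::Error>> {
    // The temp file is persisted asynchronously; save_attachment() below
    // does the same after its size checks.
    form.data.persist_to("/tmp/upload-example.bin").await?;
    Ok(())
}
```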
@@ -836,22 +843,21 @@ fn post_attachment_v2(
 ///
 /// When used with the v2 API, post_attachment_v2() has already created the
 /// database record, which is passed in as `attachment`.
-fn save_attachment(
+async fn save_attachment(
     mut attachment: Option<Attachment>,
     cipher_uuid: String,
-    data: Data,
-    content_type: &ContentType,
+    data: Form<UploadData<'_>>,
     headers: &Headers,
-    conn: &DbConn,
-    nt: Notify,
-) -> Result<Cipher, crate::error::Error> {
-    let cipher = match Cipher::find_by_uuid(&cipher_uuid, conn) {
+    conn: DbConn,
+    nt: Notify<'_>,
+) -> Result<(Cipher, DbConn), crate::error::Error> {
+    let cipher = match Cipher::find_by_uuid(&cipher_uuid, &conn) {
         Some(cipher) => cipher,
-        None => err_discard!("Cipher doesn't exist", data),
+        None => err!("Cipher doesn't exist"),
     };
 
-    if !cipher.is_write_accessible_to_user(&headers.user.uuid, conn) {
-        err_discard!("Cipher is not write accessible", data)
+    if !cipher.is_write_accessible_to_user(&headers.user.uuid, &conn) {
+        err!("Cipher is not write accessible")
     }
 
     // In the v2 API, the attachment record has already been created,
@@ -863,11 +869,11 @@ fn save_attachment(
 
     let size_limit = if let Some(ref user_uuid) = cipher.user_uuid {
         match CONFIG.user_attachment_limit() {
-            Some(0) => err_discard!("Attachments are disabled", data),
+            Some(0) => err!("Attachments are disabled"),
             Some(limit_kb) => {
-                let left = (limit_kb * 1024) - Attachment::size_by_user(user_uuid, conn) + size_adjust;
+                let left = (limit_kb * 1024) - Attachment::size_by_user(user_uuid, &conn) + size_adjust;
                 if left <= 0 {
-                    err_discard!("Attachment storage limit reached! Delete some attachments to free up space", data)
+                    err!("Attachment storage limit reached! Delete some attachments to free up space")
                 }
                 Some(left as u64)
             }
@@ -875,130 +881,78 @@ fn save_attachment(
         }
     } else if let Some(ref org_uuid) = cipher.organization_uuid {
         match CONFIG.org_attachment_limit() {
-            Some(0) => err_discard!("Attachments are disabled", data),
+            Some(0) => err!("Attachments are disabled"),
             Some(limit_kb) => {
-                let left = (limit_kb * 1024) - Attachment::size_by_org(org_uuid, conn) + size_adjust;
+                let left = (limit_kb * 1024) - Attachment::size_by_org(org_uuid, &conn) + size_adjust;
                 if left <= 0 {
-                    err_discard!("Attachment storage limit reached! Delete some attachments to free up space", data)
+                    err!("Attachment storage limit reached! Delete some attachments to free up space")
                 }
                 Some(left as u64)
             }
             None => None,
         }
     } else {
-        err_discard!("Cipher is neither owned by a user nor an organization", data);
+        err!("Cipher is neither owned by a user nor an organization");
     };
 
-    let mut params = content_type.params();
-    let boundary_pair = params.next().expect("No boundary provided");
-    let boundary = boundary_pair.1;
+    let mut data = data.into_inner();
 
-    let base_path = Path::new(&CONFIG.attachments_folder()).join(&cipher_uuid);
-    let mut path = PathBuf::new();
-
-    let mut attachment_key = None;
-    let mut error = None;
-
-    Multipart::with_body(data.open(), boundary)
-        .foreach_entry(|mut field| {
-            match &*field.headers.name {
-                "key" => {
-                    use std::io::Read;
-                    let mut key_buffer = String::new();
-                    if field.data.read_to_string(&mut key_buffer).is_ok() {
-                        attachment_key = Some(key_buffer);
-                    }
-                }
-                "data" => {
-                    // In the legacy API, this is the encrypted filename
-                    // provided by the client, stored to the database as-is.
-                    // In the v2 API, this value doesn't matter, as it was
-                    // already provided and stored via an earlier API call.
-                    let encrypted_filename = field.headers.filename;
-
-                    // This random ID is used as the name of the file on disk.
-                    // In the legacy API, we need to generate this value here.
-                    // In the v2 API, we use the value from post_attachment_v2().
-                    let file_id = match &attachment {
-                        Some(attachment) => attachment.id.clone(), // v2 API
-                        None => crypto::generate_attachment_id(), // Legacy API
-                    };
-                    path = base_path.join(&file_id);
-
-                    let size =
-                        match field.data.save().memory_threshold(0).size_limit(size_limit).with_path(path.clone()) {
-                            SaveResult::Full(SavedData::File(_, size)) => size as i32,
-                            SaveResult::Full(other) => {
-                                error = Some(format!("Attachment is not a file: {:?}", other));
-                                return;
-                            }
-                            SaveResult::Partial(_, reason) => {
-                                error = Some(format!("Attachment storage limit exceeded with this file: {:?}", reason));
-                                return;
-                            }
-                            SaveResult::Error(e) => {
-                                error = Some(format!("Error: {:?}", e));
-                                return;
-                            }
-                        };
-
-                    if let Some(attachment) = &mut attachment {
-                        // v2 API
-
-                        // Check the actual size against the size initially provided by
-                        // the client. Upstream allows +/- 1 MiB deviation from this
-                        // size, but it's not clear when or why this is needed.
-                        const LEEWAY: i32 = 1024 * 1024; // 1 MiB
-                        let min_size = attachment.file_size - LEEWAY;
-                        let max_size = attachment.file_size + LEEWAY;
-
-                        if min_size <= size && size <= max_size {
-                            if size != attachment.file_size {
-                                // Update the attachment with the actual file size.
-                                attachment.file_size = size;
-                                attachment.save(conn).expect("Error updating attachment");
-                            }
-                        } else {
-                            attachment.delete(conn).ok();
-
-                            let err_msg = "Attachment size mismatch".to_string();
-                            error!("{} (expected within [{}, {}], got {})", err_msg, min_size, max_size, size);
-                            error = Some(err_msg);
-                        }
-                    } else {
-                        // Legacy API
-
-                        if encrypted_filename.is_none() {
-                            error = Some("No filename provided".to_string());
-                            return;
-                        }
-                        if attachment_key.is_none() {
-                            error = Some("No attachment key provided".to_string());
-                            return;
-                        }
-                        let attachment = Attachment::new(
-                            file_id,
-                            cipher_uuid.clone(),
-                            encrypted_filename.unwrap(),
-                            size,
-                            attachment_key.clone(),
-                        );
-                        attachment.save(conn).expect("Error saving attachment");
-                    }
-                }
-                _ => error!("Invalid multipart name"),
-            }
-        })
-        .expect("Error processing multipart data");
-
-    if let Some(ref e) = error {
-        std::fs::remove_file(path).ok();
-        err!(e);
+    if let Some(size_limit) = size_limit {
+        if data.data.len() > size_limit {
+            err!("Attachment storage limit exceeded with this file");
+        }
     }
 
-    nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(conn));
+    let file_id = match &attachment {
+        Some(attachment) => attachment.id.clone(), // v2 API
+        None => crypto::generate_attachment_id(), // Legacy API
+    };
 
-    Ok(cipher)
+    let folder_path = tokio::fs::canonicalize(&CONFIG.attachments_folder()).await?.join(&cipher_uuid);
+    let file_path = folder_path.join(&file_id);
+    tokio::fs::create_dir_all(&folder_path).await?;
+
+    let size = data.data.len() as i32;
+    if let Some(attachment) = &mut attachment {
+        // v2 API
+
+        // Check the actual size against the size initially provided by
+        // the client. Upstream allows +/- 1 MiB deviation from this
+        // size, but it's not clear when or why this is needed.
+        const LEEWAY: i32 = 1024 * 1024; // 1 MiB
+        let min_size = attachment.file_size - LEEWAY;
+        let max_size = attachment.file_size + LEEWAY;
+
+        if min_size <= size && size <= max_size {
+            if size != attachment.file_size {
+                // Update the attachment with the actual file size.
+                attachment.file_size = size;
+                attachment.save(&conn).expect("Error updating attachment");
+            }
+        } else {
+            attachment.delete(&conn).ok();
+
+            err!(format!("Attachment size mismatch (expected within [{}, {}], got {})", min_size, max_size, size));
+        }
+    } else {
+        // Legacy API
+        let encrypted_filename = data.data.raw_name().map(|s| s.dangerous_unsafe_unsanitized_raw().to_string());
+
+        if encrypted_filename.is_none() {
+            err!("No filename provided")
+        }
+        if data.key.is_none() {
+            err!("No attachment key provided")
+        }
+        let attachment = Attachment::new(file_id, cipher_uuid.clone(), encrypted_filename.unwrap(), size, data.key);
+        attachment.save(&conn).expect("Error saving attachment");
+    }
+
+    data.data.persist_to(file_path).await?;
+
+    nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(&conn));
+
+    Ok((cipher, conn))
 }
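File-system work inside the rewritten body also moves from `std::fs` to `tokio::fs`, so directory setup awaits on the runtime. A small sketch of the canonicalize/create_dir_all pairing used above (the function name is illustrative):

```rust
use std::path::PathBuf;

// Resolve the base folder, append an id, and make sure it exists, all
// without blocking the async executor.
async fn prepare_folder(base: &str, id: &str) -> std::io::Result<PathBuf> {
    let folder = tokio::fs::canonicalize(base).await?.join(id);
    tokio::fs::create_dir_all(&folder).await?;
    Ok(folder)
}
```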
@@ -1006,14 +960,13 @@ fn save_attachment(
 /// /ciphers/<uuid>/attachment/v2 route, which would otherwise conflict
 /// with this one.
 #[post("/ciphers/<uuid>/attachment/<attachment_id>", format = "multipart/form-data", data = "<data>", rank = 1)]
-fn post_attachment_v2_data(
+async fn post_attachment_v2_data(
     uuid: String,
     attachment_id: String,
-    data: Data,
-    content_type: &ContentType,
+    data: Form<UploadData<'_>>,
     headers: Headers,
     conn: DbConn,
-    nt: Notify,
+    nt: Notify<'_>,
 ) -> EmptyResult {
     let attachment = match Attachment::find_by_id(&attachment_id, &conn) {
         Some(attachment) if uuid == attachment.cipher_uuid => Some(attachment),
@@ -1021,54 +974,51 @@ fn post_attachment_v2_data(
         None => err!("Attachment doesn't exist"),
     };
 
-    save_attachment(attachment, uuid, data, content_type, &headers, &conn, nt)?;
+    save_attachment(attachment, uuid, data, &headers, conn, nt).await?;
 
     Ok(())
 }
 
 /// Legacy API for creating an attachment associated with a cipher.
 #[post("/ciphers/<uuid>/attachment", format = "multipart/form-data", data = "<data>")]
-fn post_attachment(
+async fn post_attachment(
     uuid: String,
-    data: Data,
-    content_type: &ContentType,
+    data: Form<UploadData<'_>>,
     headers: Headers,
     conn: DbConn,
-    nt: Notify,
+    nt: Notify<'_>,
 ) -> JsonResult {
     // Setting this as None signifies to save_attachment() that it should create
     // the attachment database record as well as saving the data to disk.
     let attachment = None;
 
-    let cipher = save_attachment(attachment, uuid, data, content_type, &headers, &conn, nt)?;
+    let (cipher, conn) = save_attachment(attachment, uuid, data, &headers, conn, nt).await?;
 
     Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, &conn)))
 }
 
 #[post("/ciphers/<uuid>/attachment-admin", format = "multipart/form-data", data = "<data>")]
-fn post_attachment_admin(
+async fn post_attachment_admin(
     uuid: String,
-    data: Data,
-    content_type: &ContentType,
+    data: Form<UploadData<'_>>,
     headers: Headers,
     conn: DbConn,
-    nt: Notify,
+    nt: Notify<'_>,
 ) -> JsonResult {
-    post_attachment(uuid, data, content_type, headers, conn, nt)
+    post_attachment(uuid, data, headers, conn, nt).await
 }
 
 #[post("/ciphers/<uuid>/attachment/<attachment_id>/share", format = "multipart/form-data", data = "<data>")]
-fn post_attachment_share(
+async fn post_attachment_share(
     uuid: String,
     attachment_id: String,
-    data: Data,
-    content_type: &ContentType,
+    data: Form<UploadData<'_>>,
     headers: Headers,
     conn: DbConn,
-    nt: Notify,
+    nt: Notify<'_>,
 ) -> JsonResult {
     _delete_cipher_attachment_by_id(&uuid, &attachment_id, &headers, &conn, &nt)?;
-    post_attachment(uuid, data, content_type, headers, conn, nt)
+    post_attachment(uuid, data, headers, conn, nt).await
 }
 
 #[post("/ciphers/<uuid>/attachment/<attachment_id>/delete-admin")]
@@ -1248,13 +1198,13 @@ fn move_cipher_selected_put(
 
 #[derive(FromForm)]
 struct OrganizationId {
-    #[form(field = "organizationId")]
+    #[field(name = "organizationId")]
     org_id: String,
 }
 
 #[post("/ciphers/purge?<organization..>", data = "<data>")]
 fn delete_all(
-    organization: Option<Form<OrganizationId>>,
+    organization: Option<OrganizationId>,
     data: JsonUpcase<PasswordData>,
     headers: Headers,
     conn: DbConn,
@@ -1,6 +1,6 @@
 use chrono::{Duration, Utc};
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 use serde_json::Value;
 use std::borrow::Borrow;
 
@@ -709,13 +709,13 @@ fn check_emergency_access_allowed() -> EmptyResult {
     Ok(())
 }
 
-pub fn emergency_request_timeout_job(pool: DbPool) {
+pub async fn emergency_request_timeout_job(pool: DbPool) {
     debug!("Start emergency_request_timeout_job");
     if !CONFIG.emergency_access_allowed() {
         return;
     }
 
-    if let Ok(conn) = pool.get() {
+    if let Ok(conn) = pool.get().await {
         let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn);
 
         if emergency_access_list.is_empty() {
@@ -756,13 +756,13 @@ pub fn emergency_request_timeout_job(pool: DbPool) {
     }
 }
 
-pub fn emergency_notification_reminder_job(pool: DbPool) {
+pub async fn emergency_notification_reminder_job(pool: DbPool) {
     debug!("Start emergency_notification_reminder_job");
     if !CONFIG.emergency_access_allowed() {
         return;
     }
 
-    if let Ok(conn) = pool.get() {
+    if let Ok(conn) = pool.get().await {
         let emergency_access_list = EmergencyAccess::find_all_recoveries(&conn);
 
         if emergency_access_list.is_empty() {
@@ -1,4 +1,4 @@
-use rocket_contrib::json::Json;
+use rocket::serde::json::Json;
 use serde_json::Value;
 
 use crate::{
@@ -31,8 +31,8 @@ pub fn routes() -> Vec<Route> {
 //
 // Move this somewhere else
 //
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 use serde_json::Value;
 
 use crate::{
@@ -144,7 +144,7 @@ fn put_eq_domains(data: JsonUpcase<EquivDomainData>, headers: Headers, conn: DbC
 }
 
 #[get("/hibp/breach?<username>")]
-fn hibp_breach(username: String) -> JsonResult {
+async fn hibp_breach(username: String) -> JsonResult {
     let url = format!(
         "https://haveibeenpwned.com/api/v3/breachedaccount/{}?truncateResponse=false&includeUnverified=false",
         username
|
@ -153,14 +153,14 @@ fn hibp_breach(username: String) -> JsonResult {
|
||||||
if let Some(api_key) = crate::CONFIG.hibp_api_key() {
|
if let Some(api_key) = crate::CONFIG.hibp_api_key() {
|
||||||
let hibp_client = get_reqwest_client();
|
let hibp_client = get_reqwest_client();
|
||||||
|
|
||||||
let res = hibp_client.get(&url).header("hibp-api-key", api_key).send()?;
|
let res = hibp_client.get(&url).header("hibp-api-key", api_key).send().await?;
|
||||||
|
|
||||||
// If we get a 404, return a 404, it means no breached accounts
|
// If we get a 404, return a 404, it means no breached accounts
|
||||||
if res.status() == 404 {
|
if res.status() == 404 {
|
||||||
return Err(Error::empty().with_code(404));
|
return Err(Error::empty().with_code(404));
|
||||||
}
|
}
|
||||||
|
|
||||||
let value: Value = res.error_for_status()?.json()?;
|
let value: Value = res.error_for_status()?.json().await?;
|
||||||
Ok(Json(value))
|
Ok(Json(value))
|
||||||
} else {
|
} else {
|
||||||
Ok(Json(json!([{
|
Ok(Json(json!([{
|
||||||
|
|
|
@@ -1,6 +1,6 @@
 use num_traits::FromPrimitive;
-use rocket::{request::Form, Route};
-use rocket_contrib::json::Json;
+use rocket::serde::json::Json;
+use rocket::Route;
 use serde_json::Value;
 
 use crate::{
|
@ -469,12 +469,12 @@ fn put_collection_users(
|
||||||
|
|
||||||
#[derive(FromForm)]
|
#[derive(FromForm)]
|
||||||
struct OrgIdData {
|
struct OrgIdData {
|
||||||
#[form(field = "organizationId")]
|
#[field(name = "organizationId")]
|
||||||
organization_id: String,
|
organization_id: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/ciphers/organization-details?<data..>")]
|
#[get("/ciphers/organization-details?<data..>")]
|
||||||
fn get_org_details(data: Form<OrgIdData>, headers: Headers, conn: DbConn) -> Json<Value> {
|
fn get_org_details(data: OrgIdData, headers: Headers, conn: DbConn) -> Json<Value> {
|
||||||
let ciphers = Cipher::find_by_org(&data.organization_id, &conn);
|
let ciphers = Cipher::find_by_org(&data.organization_id, &conn);
|
||||||
let ciphers_json: Vec<Value> =
|
let ciphers_json: Vec<Value> =
|
||||||
ciphers.iter().map(|c| c.to_json(&headers.host, &headers.user.uuid, &conn)).collect();
|
ciphers.iter().map(|c| c.to_json(&headers.host, &headers.user.uuid, &conn)).collect();
|
||||||
|
@@ -1097,14 +1097,14 @@ struct RelationsData {
 
 #[post("/ciphers/import-organization?<query..>", data = "<data>")]
 fn post_org_import(
-    query: Form<OrgIdData>,
+    query: OrgIdData,
     data: JsonUpcase<ImportData>,
     headers: AdminHeaders,
     conn: DbConn,
     nt: Notify,
 ) -> EmptyResult {
     let data: ImportData = data.into_inner().data;
-    let org_id = query.into_inner().organization_id;
+    let org_id = query.organization_id;
 
     // Read and create the collections
     let collections: Vec<_> = data
@@ -1,9 +1,10 @@
-use std::{io::Read, path::Path};
+use std::path::Path;
 
 use chrono::{DateTime, Duration, Utc};
-use multipart::server::{save::SavedData, Multipart, SaveResult};
-use rocket::{http::ContentType, response::NamedFile, Data};
-use rocket_contrib::json::Json;
+use rocket::form::Form;
+use rocket::fs::NamedFile;
+use rocket::fs::TempFile;
+use rocket::serde::json::Json;
 use serde_json::Value;
 
 use crate::{
@@ -31,9 +32,9 @@ pub fn routes() -> Vec<rocket::Route> {
     ]
 }
 
-pub fn purge_sends(pool: DbPool) {
+pub async fn purge_sends(pool: DbPool) {
     debug!("Purging sends");
-    if let Ok(conn) = pool.get() {
+    if let Ok(conn) = pool.get().await {
         Send::purge(&conn);
     } else {
         error!("Failed to get DB connection while purging sends")
@@ -177,25 +178,23 @@ fn post_send(data: JsonUpcase<SendData>, headers: Headers, conn: DbConn, nt: Not
     Ok(Json(send.to_json()))
 }
 
+#[derive(FromForm)]
+struct UploadData<'f> {
+    model: Json<crate::util::UpCase<SendData>>,
+    data: TempFile<'f>,
+}
+
 #[post("/sends/file", format = "multipart/form-data", data = "<data>")]
-fn post_send_file(data: Data, content_type: &ContentType, headers: Headers, conn: DbConn, nt: Notify) -> JsonResult {
+async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, conn: DbConn, nt: Notify<'_>) -> JsonResult {
     enforce_disable_send_policy(&headers, &conn)?;
 
-    let boundary = content_type.params().next().expect("No boundary provided").1;
+    let UploadData {
+        model,
+        mut data,
+    } = data.into_inner();
+    let model = model.into_inner().data;
 
-    let mut mpart = Multipart::with_body(data.open(), boundary);
+    enforce_disable_hide_email_policy(&model, &headers, &conn)?;
 
-    // First entry is the SendData JSON
-    let mut model_entry = match mpart.read_entry()? {
-        Some(e) if &*e.headers.name == "model" => e,
-        Some(_) => err!("Invalid entry name"),
-        None => err!("No model entry present"),
-    };
-
-    let mut buf = String::new();
-    model_entry.data.read_to_string(&mut buf)?;
-    let data = serde_json::from_str::<crate::util::UpCase<SendData>>(&buf)?;
-    enforce_disable_hide_email_policy(&data.data, &headers, &conn)?;
-
     // Get the file length and add an extra 5% to avoid issues
     const SIZE_525_MB: u64 = 550_502_400;
@@ -212,45 +211,27 @@ fn post_send_file(data: Data, content_type: &ContentType, headers: Headers, conn
         None => SIZE_525_MB,
     };
 
-    // Create the Send
-    let mut send = create_send(data.data, headers.user.uuid)?;
-    let file_id = crate::crypto::generate_send_id();
+    let mut send = create_send(model, headers.user.uuid)?;
 
     if send.atype != SendType::File as i32 {
         err!("Send content is not a file");
     }
 
-    let file_path = Path::new(&CONFIG.sends_folder()).join(&send.uuid).join(&file_id);
+    let size = data.len();
+    if size > size_limit {
+        err!("Attachment storage limit exceeded with this file");
+    }
 
-    // Read the data entry and save the file
-    let mut data_entry = match mpart.read_entry()? {
-        Some(e) if &*e.headers.name == "data" => e,
-        Some(_) => err!("Invalid entry name"),
-        None => err!("No model entry present"),
-    };
-
-    let size = match data_entry.data.save().memory_threshold(0).size_limit(size_limit).with_path(&file_path) {
-        SaveResult::Full(SavedData::File(_, size)) => size as i32,
-        SaveResult::Full(other) => {
-            std::fs::remove_file(&file_path).ok();
-            err!(format!("Attachment is not a file: {:?}", other));
-        }
-        SaveResult::Partial(_, reason) => {
-            std::fs::remove_file(&file_path).ok();
-            err!(format!("Attachment storage limit exceeded with this file: {:?}", reason));
-        }
-        SaveResult::Error(e) => {
-            std::fs::remove_file(&file_path).ok();
-            err!(format!("Error: {:?}", e));
-        }
-    };
-
-    // Set ID and sizes
+    let file_id = crate::crypto::generate_send_id();
+    let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(&send.uuid);
+    let file_path = folder_path.join(&file_id);
+    tokio::fs::create_dir_all(&folder_path).await?;
+    data.persist_to(&file_path).await?;
+
     let mut data_value: Value = serde_json::from_str(&send.data)?;
     if let Some(o) = data_value.as_object_mut() {
         o.insert(String::from("Id"), Value::String(file_id));
         o.insert(String::from("Size"), Value::Number(size.into()));
-        o.insert(String::from("SizeName"), Value::String(crate::util::get_display_size(size)));
+        o.insert(String::from("SizeName"), Value::String(crate::util::get_display_size(size as i32)));
     }
     send.data = serde_json::to_string(&data_value)?;
 
@@ -367,10 +348,10 @@ fn post_access_file(
 }
 
 #[get("/sends/<send_id>/<file_id>?<t>")]
-fn download_send(send_id: SafeString, file_id: SafeString, t: String) -> Option<NamedFile> {
+async fn download_send(send_id: SafeString, file_id: SafeString, t: String) -> Option<NamedFile> {
     if let Ok(claims) = crate::auth::decode_send(&t) {
         if claims.sub == format!("{}/{}", send_id, file_id) {
-            return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).ok();
+            return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok();
         }
     }
     None
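`NamedFile` moved from `rocket::response` to `rocket::fs`, and its `open` now returns a future, so file responses are awaited. A short sketch (paths illustrative):

```rust
use std::path::Path;
use rocket::fs::NamedFile;

async fn serve_file(dir: &str, name: &str) -> Option<NamedFile> {
    // `open` yields io::Result<NamedFile>; `.ok()` maps a miss to a 404
    // via the Option responder, as download_send() does above.
    NamedFile::open(Path::new(dir).join(name)).await.ok()
}
```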
@@ -1,6 +1,6 @@
 use data_encoding::BASE32;
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 
 use crate::{
     api::{
@@ -1,7 +1,7 @@
 use chrono::Utc;
 use data_encoding::BASE64;
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 
 use crate::{
     api::{core::two_factor::_generate_recover_code, ApiResult, EmptyResult, JsonResult, JsonUpcase, PasswordData},
@@ -152,7 +152,7 @@ fn check_duo_fields_custom(data: &EnableDuoData) -> bool {
 }
 
 #[post("/two-factor/duo", data = "<data>")]
-fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
+async fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
     let data: EnableDuoData = data.into_inner().data;
     let mut user = headers.user;
 
@@ -163,7 +163,7 @@ fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn)
     let (data, data_str) = if check_duo_fields_custom(&data) {
         let data_req: DuoData = data.into();
         let data_str = serde_json::to_string(&data_req)?;
-        duo_api_request("GET", "/auth/v2/check", "", &data_req).map_res("Failed to validate Duo credentials")?;
+        duo_api_request("GET", "/auth/v2/check", "", &data_req).await.map_res("Failed to validate Duo credentials")?;
         (data_req.obscure(), data_str)
     } else {
         (DuoData::secret(), String::new())
@@ -185,11 +185,11 @@ fn activate_duo(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn)
 }
 
 #[put("/two-factor/duo", data = "<data>")]
-fn activate_duo_put(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
-    activate_duo(data, headers, conn)
+async fn activate_duo_put(data: JsonUpcase<EnableDuoData>, headers: Headers, conn: DbConn) -> JsonResult {
+    activate_duo(data, headers, conn).await
 }
 
-fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult {
+async fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> EmptyResult {
     use reqwest::{header, Method};
     use std::str::FromStr;
 
@@ -209,7 +209,8 @@ fn duo_api_request(method: &str, path: &str, params: &str, data: &DuoData) -> Em
         .basic_auth(username, Some(password))
         .header(header::USER_AGENT, "vaultwarden:Duo/1.0 (Rust)")
         .header(header::DATE, date)
-        .send()?
+        .send()
+        .await?
         .error_for_status()?;
 
     Ok(())
@@ -1,6 +1,6 @@
 use chrono::{Duration, NaiveDateTime, Utc};
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 
 use crate::{
     api::{core::two_factor::_generate_recover_code, EmptyResult, JsonResult, JsonUpcase, PasswordData},
@@ -1,7 +1,7 @@
 use chrono::{Duration, Utc};
 use data_encoding::BASE32;
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 use serde_json::Value;
 
 use crate::{
@@ -158,14 +158,14 @@ fn disable_twofactor_put(data: JsonUpcase<DisableTwoFactorData>, headers: Header
     disable_twofactor(data, headers, conn)
 }
 
-pub fn send_incomplete_2fa_notifications(pool: DbPool) {
+pub async fn send_incomplete_2fa_notifications(pool: DbPool) {
     debug!("Sending notifications for incomplete 2FA logins");
 
     if CONFIG.incomplete_2fa_time_limit() <= 0 || !CONFIG.mail_enabled() {
         return;
     }
 
-    let conn = match pool.get() {
+    let conn = match pool.get().await {
         Ok(conn) => conn,
         _ => {
             error!("Failed to get DB connection in send_incomplete_2fa_notifications()");
@@ -1,6 +1,6 @@
 use once_cell::sync::Lazy;
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 use serde_json::Value;
 use u2f::{
     messages::{RegisterResponse, SignResponse, U2fSignRequest},
@@ -1,5 +1,5 @@
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 use serde_json::Value;
 use url::Url;
 use webauthn_rs::{base64_data::Base64UrlSafeData, proto::*, AuthenticationState, RegistrationState, Webauthn};
@@ -1,5 +1,5 @@
+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 use serde_json::Value;
 use yubico::{config::Config, verify};
 
src/api/icons.rs: 125 changes
@@ -1,19 +1,19 @@
 use std::{
     collections::HashMap,
-    fs::{create_dir_all, remove_file, symlink_metadata, File},
-    io::prelude::*,
     net::{IpAddr, ToSocketAddrs},
     sync::{Arc, RwLock},
     time::{Duration, SystemTime},
 };

+use bytes::{Buf, Bytes, BytesMut};
+use futures::{stream::StreamExt, TryFutureExt};
 use once_cell::sync::Lazy;
 use regex::Regex;
-use reqwest::{blocking::Client, blocking::Response, header};
-use rocket::{
-    http::ContentType,
-    response::{Content, Redirect},
-    Route,
+use reqwest::{header, Client, Response};
+use rocket::{http::ContentType, response::Redirect, Route};
+use tokio::{
+    fs::{create_dir_all, remove_file, symlink_metadata, File},
+    io::{AsyncReadExt, AsyncWriteExt},
 };

 use crate::{
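Note on the import swap above: the std::fs/std::io calls in this file move to their tokio equivalents with the same call shapes, so each I/O point simply gains an .await. A minimal runnable sketch (the read_all helper is hypothetical, not part of this commit):

use tokio::{fs::File, io::AsyncReadExt};

// Open a file and read it fully, yielding to the executor at each await.
async fn read_all(path: &str) -> std::io::Result<Vec<u8>> {
    let mut f = File::open(path).await?;
    let mut buf = Vec::new();
    f.read_to_end(&mut buf).await?;
    Ok(buf)
}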
@@ -104,27 +104,23 @@ fn icon_google(domain: String) -> Option<Redirect> {
 }

 #[get("/<domain>/icon.png")]
-fn icon_internal(domain: String) -> Cached<Content<Vec<u8>>> {
+async fn icon_internal(domain: String) -> Cached<(ContentType, Vec<u8>)> {
     const FALLBACK_ICON: &[u8] = include_bytes!("../static/images/fallback-icon.png");

     if !is_valid_domain(&domain) {
         warn!("Invalid domain: {}", domain);
         return Cached::ttl(
-            Content(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
+            (ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
             CONFIG.icon_cache_negttl(),
             true,
         );
     }

-    match get_icon(&domain) {
+    match get_icon(&domain).await {
         Some((icon, icon_type)) => {
-            Cached::ttl(Content(ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
+            Cached::ttl((ContentType::new("image", icon_type), icon), CONFIG.icon_cache_ttl(), true)
         }
-        _ => Cached::ttl(
-            Content(ContentType::new("image", "png"), FALLBACK_ICON.to_vec()),
-            CONFIG.icon_cache_negttl(),
-            true,
-        ),
+        _ => Cached::ttl((ContentType::new("image", "png"), FALLBACK_ICON.to_vec()), CONFIG.icon_cache_negttl(), true),
     }
 }
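The Content wrapper disappears because Rocket 0.5 implements Responder for (ContentType, body) tuples directly, which this hunk relies on. A minimal sketch of that shape (toy route, not from this commit):

use rocket::http::ContentType;

// The tuple (ContentType, body) is itself a Responder in Rocket 0.5.
#[rocket::get("/ping.txt")]
fn ping() -> (ContentType, &'static str) {
    (ContentType::Text, "pong")
}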
@@ -317,15 +313,15 @@ fn is_domain_blacklisted(domain: &str) -> bool {
     is_blacklisted
 }

-fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
+async fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
     let path = format!("{}/{}.png", CONFIG.icon_cache_folder(), domain);

     // Check for expiration of negatively cached copy
-    if icon_is_negcached(&path) {
+    if icon_is_negcached(&path).await {
         return None;
     }

-    if let Some(icon) = get_cached_icon(&path) {
+    if let Some(icon) = get_cached_icon(&path).await {
         let icon_type = match get_icon_type(&icon) {
             Some(x) => x,
             _ => "x-icon",
@@ -338,31 +334,31 @@ fn get_icon(domain: &str) -> Option<(Vec<u8>, String)> {
     }

     // Get the icon, or None in case of error
-    match download_icon(domain) {
+    match download_icon(domain).await {
         Ok((icon, icon_type)) => {
-            save_icon(&path, &icon);
-            Some((icon, icon_type.unwrap_or("x-icon").to_string()))
+            save_icon(&path, &icon).await;
+            Some((icon.to_vec(), icon_type.unwrap_or("x-icon").to_string()))
         }
         Err(e) => {
             warn!("Unable to download icon: {:?}", e);
             let miss_indicator = path + ".miss";
-            save_icon(&miss_indicator, &[]);
+            save_icon(&miss_indicator, &[]).await;
             None
         }
     }
 }

-fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
+async fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
     // Check for expiration of successfully cached copy
-    if icon_is_expired(path) {
+    if icon_is_expired(path).await {
         return None;
     }

     // Try to read the cached icon, and return it if it exists
-    if let Ok(mut f) = File::open(path) {
+    if let Ok(mut f) = File::open(path).await {
         let mut buffer = Vec::new();

-        if f.read_to_end(&mut buffer).is_ok() {
+        if f.read_to_end(&mut buffer).await.is_ok() {
             return Some(buffer);
         }
     }
@@ -370,22 +366,22 @@ fn get_cached_icon(path: &str) -> Option<Vec<u8>> {
     None
 }

-fn file_is_expired(path: &str, ttl: u64) -> Result<bool, Error> {
-    let meta = symlink_metadata(path)?;
+async fn file_is_expired(path: &str, ttl: u64) -> Result<bool, Error> {
+    let meta = symlink_metadata(path).await?;
     let modified = meta.modified()?;
     let age = SystemTime::now().duration_since(modified)?;

     Ok(ttl > 0 && ttl <= age.as_secs())
 }

-fn icon_is_negcached(path: &str) -> bool {
+async fn icon_is_negcached(path: &str) -> bool {
     let miss_indicator = path.to_owned() + ".miss";
-    let expired = file_is_expired(&miss_indicator, CONFIG.icon_cache_negttl());
+    let expired = file_is_expired(&miss_indicator, CONFIG.icon_cache_negttl()).await;

     match expired {
         // No longer negatively cached, drop the marker
         Ok(true) => {
-            if let Err(e) = remove_file(&miss_indicator) {
+            if let Err(e) = remove_file(&miss_indicator).await {
                 error!("Could not remove negative cache indicator for icon {:?}: {:?}", path, e);
             }
             false
@@ -397,8 +393,8 @@ fn icon_is_negcached(path: &str) -> bool {
     }
 }

-fn icon_is_expired(path: &str) -> bool {
-    let expired = file_is_expired(path, CONFIG.icon_cache_ttl());
+async fn icon_is_expired(path: &str) -> bool {
+    let expired = file_is_expired(path, CONFIG.icon_cache_ttl()).await;
     expired.unwrap_or(true)
 }
@@ -521,13 +517,13 @@ struct IconUrlResult {
 /// let icon_result = get_icon_url("github.com")?;
 /// let icon_result = get_icon_url("vaultwarden.discourse.group")?;
 /// ```
-fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
+async fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
     // Default URL with secure and insecure schemes
     let ssldomain = format!("https://{}", domain);
     let httpdomain = format!("http://{}", domain);

     // First check the domain as given during the request for both HTTPS and HTTP.
-    let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)) {
+    let resp = match get_page(&ssldomain).or_else(|_| get_page(&httpdomain)).await {
         Ok(c) => Ok(c),
         Err(e) => {
             let mut sub_resp = Err(e);
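The fallback chain above now composes futures: TryFutureExt::or_else takes a closure producing a second future that only runs when the first resolves to Err, and a single .await drives whichever branch was taken. A runnable sketch of just that combinator (toy fetch function, not from this commit):

use futures::TryFutureExt;

// Resolves to Err for the "primary" call, so the fallback runs.
async fn fetch(primary: bool) -> Result<&'static str, ()> {
    if primary { Err(()) } else { Ok("fallback") }
}

#[tokio::main]
async fn main() {
    let page = fetch(true).or_else(|_| fetch(false)).await;
    assert_eq!(page, Ok("fallback"));
}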
@@ -546,7 +542,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
             let httpbase = format!("http://{}", base_domain);
             debug!("[get_icon_url]: Trying without subdomains '{}'", base_domain);

-            sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase));
+            sub_resp = get_page(&sslbase).or_else(|_| get_page(&httpbase)).await;
         }

         // When the domain is not an IP, and has less then 2 dots, try to add www. infront of it.
@@ -557,7 +553,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
             let httpwww = format!("http://{}", www_domain);
             debug!("[get_icon_url]: Trying with www. prefix '{}'", www_domain);

-            sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww));
+            sub_resp = get_page(&sslwww).or_else(|_| get_page(&httpwww)).await;
         }
     }
@@ -581,7 +577,7 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
     iconlist.push(Icon::new(35, String::from(url.join("/favicon.ico").unwrap())));

     // 384KB should be more than enough for the HTML, though as we only really need the HTML header.
-    let mut limited_reader = content.take(384 * 1024);
+    let mut limited_reader = stream_to_bytes_limit(content, 384 * 1024).await?.reader();

     use html5ever::tendril::TendrilSink;
     let dom = html5ever::parse_document(markup5ever_rcdom::RcDom::default(), Default::default())
@@ -607,11 +603,11 @@ fn get_icon_url(domain: &str) -> Result<IconUrlResult, Error> {
     })
 }

-fn get_page(url: &str) -> Result<Response, Error> {
-    get_page_with_referer(url, "")
+async fn get_page(url: &str) -> Result<Response, Error> {
+    get_page_with_referer(url, "").await
 }

-fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
+async fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
     if is_domain_blacklisted(url::Url::parse(url).unwrap().host_str().unwrap_or_default()) {
         warn!("Favicon '{}' resolves to a blacklisted domain or IP!", url);
     }
@@ -621,7 +617,7 @@ fn get_page_with_referer(url: &str, referer: &str) -> Result<Response, Error> {
         client = client.header("Referer", referer)
     }

-    match client.send() {
+    match client.send().await {
         Ok(c) => c.error_for_status().map_err(Into::into),
         Err(e) => err_silent!(format!("{}", e)),
     }
@@ -706,14 +702,14 @@ fn parse_sizes(sizes: Option<&str>) -> (u16, u16) {
     (width, height)
 }

-fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
+async fn download_icon(domain: &str) -> Result<(Bytes, Option<&str>), Error> {
     if is_domain_blacklisted(domain) {
         err_silent!("Domain is blacklisted", domain)
     }

-    let icon_result = get_icon_url(domain)?;
+    let icon_result = get_icon_url(domain).await?;

-    let mut buffer = Vec::new();
+    let mut buffer = Bytes::new();
     let mut icon_type: Option<&str> = None;

     use data_url::DataUrl;
@@ -722,8 +718,12 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
         if icon.href.starts_with("data:image") {
             let datauri = DataUrl::process(&icon.href).unwrap();
             // Check if we are able to decode the data uri
-            match datauri.decode_to_vec() {
-                Ok((body, _fragment)) => {
+            let mut body = BytesMut::new();
+            match datauri.decode::<_, ()>(|bytes| {
+                body.extend_from_slice(bytes);
+                Ok(())
+            }) {
+                Ok(_) => {
                     // Also check if the size is atleast 67 bytes, which seems to be the smallest png i could create
                     if body.len() >= 67 {
                         // Check if the icon type is allowed, else try an icon from the list.
@@ -733,17 +733,17 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
                             continue;
                         }
                         info!("Extracted icon from data:image uri for {}", domain);
-                        buffer = body;
+                        buffer = body.freeze();
                         break;
                     }
                 }
                 _ => debug!("Extracted icon from data:image uri is invalid"),
             };
         } else {
-            match get_page_with_referer(&icon.href, &icon_result.referer) {
-                Ok(mut res) => {
-                    res.copy_to(&mut buffer)?;
+            match get_page_with_referer(&icon.href, &icon_result.referer).await {
+                Ok(res) => {
+                    buffer = stream_to_bytes_limit(res, 512 * 1024).await?; // 512 KB for each icon max
                     // Check if the icon type is allowed, else try an icon from the list.
                     icon_type = get_icon_type(&buffer);
                     if icon_type.is_none() {
                         buffer.clear();
@@ -765,13 +765,13 @@ fn download_icon(domain: &str) -> Result<(Vec<u8>, Option<&str>), Error> {
     Ok((buffer, icon_type))
 }

-fn save_icon(path: &str, icon: &[u8]) {
-    match File::create(path) {
+async fn save_icon(path: &str, icon: &[u8]) {
+    match File::create(path).await {
         Ok(mut f) => {
-            f.write_all(icon).expect("Error writing icon file");
+            f.write_all(icon).await.expect("Error writing icon file");
         }
         Err(ref e) if e.kind() == std::io::ErrorKind::NotFound => {
-            create_dir_all(&CONFIG.icon_cache_folder()).expect("Error creating icon cache folder");
+            create_dir_all(&CONFIG.icon_cache_folder()).await.expect("Error creating icon cache folder");
         }
         Err(e) => {
             warn!("Unable to save icon: {:?}", e);
@@ -820,8 +820,6 @@ impl reqwest::cookie::CookieStore for Jar {
     }

     fn cookies(&self, url: &url::Url) -> Option<header::HeaderValue> {
-        use bytes::Bytes;
-
         let cookie_store = self.0.read().unwrap();
         let s = cookie_store
             .get_request_values(url)
@@ -836,3 +834,12 @@ impl reqwest::cookie::CookieStore for Jar {
         header::HeaderValue::from_maybe_shared(Bytes::from(s)).ok()
     }
 }
+
+async fn stream_to_bytes_limit(res: Response, max_size: usize) -> Result<Bytes, reqwest::Error> {
+    let mut stream = res.bytes_stream().take(max_size);
+    let mut buf = BytesMut::new();
+    while let Some(chunk) = stream.next().await {
+        buf.extend(chunk?);
+    }
+    Ok(buf.freeze())
+}
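One hedged observation on this helper: futures' StreamExt::take caps the number of items a stream yields, so max_size here bounds the chunk count rather than the byte count. The collect loop itself is the usual stream-to-buffer pattern; a runnable sketch over an in-memory stream (toy data, not from this commit):

use bytes::{Bytes, BytesMut};
use futures::stream::{self, StreamExt};

#[tokio::main]
async fn main() {
    let chunks = vec![Bytes::from_static(b"hello "), Bytes::from_static(b"world")];
    // take(2) limits the stream to two chunks, not two bytes.
    let mut stream = stream::iter(chunks).take(2);
    let mut buf = BytesMut::new();
    while let Some(chunk) = stream.next().await {
        buf.extend(chunk);
    }
    assert_eq!(buf.freeze(), Bytes::from_static(b"hello world"));
}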
@@ -1,10 +1,10 @@
 use chrono::Utc;
 use num_traits::FromPrimitive;
+use rocket::serde::json::Json;
 use rocket::{
-    request::{Form, FormItems, FromForm},
+    form::{Form, FromForm},
     Route,
 };
-use rocket_contrib::json::Json;
 use serde_json::Value;

 use crate::{
@@ -455,66 +455,57 @@ fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &DbConn) -> Api

 // https://github.com/bitwarden/jslib/blob/master/common/src/models/request/tokenRequest.ts
 // https://github.com/bitwarden/mobile/blob/master/src/Core/Models/Request/TokenRequest.cs
-#[derive(Debug, Clone, Default)]
+#[derive(Debug, Clone, Default, FromForm)]
 #[allow(non_snake_case)]
 struct ConnectData {
-    // refresh_token, password, client_credentials (API key)
-    grant_type: String,
+    #[field(name = uncased("grant_type"))]
+    #[field(name = uncased("granttype"))]
+    grant_type: String, // refresh_token, password, client_credentials (API key)

     // Needed for grant_type="refresh_token"
+    #[field(name = uncased("refresh_token"))]
+    #[field(name = uncased("refreshtoken"))]
     refresh_token: Option<String>,

     // Needed for grant_type = "password" | "client_credentials"
-    client_id: Option<String>, // web, cli, desktop, browser, mobile
-    client_secret: Option<String>, // API key login (cli only)
+    #[field(name = uncased("client_id"))]
+    #[field(name = uncased("clientid"))]
+    client_id: Option<String>, // web, cli, desktop, browser, mobile
+    #[field(name = uncased("client_secret"))]
+    #[field(name = uncased("clientsecret"))]
+    client_secret: Option<String>,
+    #[field(name = uncased("password"))]
     password: Option<String>,
+    #[field(name = uncased("scope"))]
     scope: Option<String>,
+    #[field(name = uncased("username"))]
     username: Option<String>,

+    #[field(name = uncased("device_identifier"))]
+    #[field(name = uncased("deviceidentifier"))]
     device_identifier: Option<String>,
+    #[field(name = uncased("device_name"))]
+    #[field(name = uncased("devicename"))]
     device_name: Option<String>,
+    #[field(name = uncased("device_type"))]
+    #[field(name = uncased("devicetype"))]
     device_type: Option<String>,
+    #[field(name = uncased("device_push_token"))]
+    #[field(name = uncased("devicepushtoken"))]
     device_push_token: Option<String>, // Unused; mobile device push not yet supported.

     // Needed for two-factor auth
+    #[field(name = uncased("two_factor_provider"))]
+    #[field(name = uncased("twofactorprovider"))]
     two_factor_provider: Option<i32>,
+    #[field(name = uncased("two_factor_token"))]
+    #[field(name = uncased("twofactortoken"))]
     two_factor_token: Option<String>,
+    #[field(name = uncased("two_factor_remember"))]
+    #[field(name = uncased("twofactorremember"))]
     two_factor_remember: Option<i32>,
 }

-impl<'f> FromForm<'f> for ConnectData {
-    type Error = String;
-
-    fn from_form(items: &mut FormItems<'f>, _strict: bool) -> Result<Self, Self::Error> {
-        let mut form = Self::default();
-        for item in items {
-            let (key, value) = item.key_value_decoded();
-            let mut normalized_key = key.to_lowercase();
-            normalized_key.retain(|c| c != '_'); // Remove '_'
-
-            match normalized_key.as_ref() {
-                "granttype" => form.grant_type = value,
-                "refreshtoken" => form.refresh_token = Some(value),
-                "clientid" => form.client_id = Some(value),
-                "clientsecret" => form.client_secret = Some(value),
-                "password" => form.password = Some(value),
-                "scope" => form.scope = Some(value),
-                "username" => form.username = Some(value),
-                "deviceidentifier" => form.device_identifier = Some(value),
-                "devicename" => form.device_name = Some(value),
-                "devicetype" => form.device_type = Some(value),
-                "devicepushtoken" => form.device_push_token = Some(value),
-                "twofactorprovider" => form.two_factor_provider = value.parse().ok(),
-                "twofactortoken" => form.two_factor_token = Some(value),
-                "twofactorremember" => form.two_factor_remember = value.parse().ok(),
-                key => warn!("Detected unexpected parameter during login: {}", key),
-            }
-        }
-
-        Ok(form)
-    }
-}
-
 fn _check_is_some<T>(value: &Option<T>, msg: &str) -> EmptyResult {
     if value.is_none() {
         err!(msg)
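The hand-written FromForm impl above is replaced by Rocket 0.5's derive; stacking #[field(name = uncased(...))] attributes gives one field several case-insensitive aliases, which is what the old normalize-and-match loop did by hand. A trimmed sketch of the pattern (single-field toy form, not the full struct):

use rocket::form::FromForm;

#[derive(FromForm, Default, Debug)]
struct LoginForm {
    // Accepts "grant_type", "granttype", "Grant_Type", and so on.
    #[field(name = uncased("grant_type"))]
    #[field(name = uncased("granttype"))]
    grant_type: String,
}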
@@ -5,7 +5,7 @@ mod identity;
 mod notifications;
 mod web;

-use rocket_contrib::json::Json;
+use rocket::serde::json::Json;
 use serde_json::Value;

 pub use crate::api::{
@@ -1,7 +1,7 @@
 use std::sync::atomic::{AtomicBool, Ordering};

+use rocket::serde::json::Json;
 use rocket::Route;
-use rocket_contrib::json::Json;
 use serde_json::Value as JsonValue;

 use crate::{api::EmptyResult, auth::Headers, Error, CONFIG};
@@ -417,7 +417,7 @@ pub enum UpdateType {
 }

 use rocket::State;
-pub type Notify<'a> = State<'a, WebSocketUsers>;
+pub type Notify<'a> = &'a State<WebSocketUsers>;

 pub fn start_notification_server() -> WebSocketUsers {
     let factory = WsFactory::init();
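Managed state is borrowed in Rocket 0.5: handlers and type aliases take &State<T> instead of an owned State<'_, T>, hence the new Notify alias. A minimal sketch (hypothetical counter route, not from this commit):

use std::sync::atomic::{AtomicU64, Ordering};
use rocket::State;

struct Counter(AtomicU64);

// Handlers borrow managed state as &State<T> in Rocket 0.5.
#[rocket::get("/count")]
fn count(counter: &State<Counter>) -> String {
    counter.0.fetch_add(1, Ordering::Relaxed).to_string()
}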
@@ -430,12 +430,11 @@ pub fn start_notification_server() -> WebSocketUsers {
         settings.queue_size = 2;
         settings.panic_on_internal = false;

-        ws::Builder::new()
-            .with_settings(settings)
-            .build(factory)
-            .unwrap()
-            .listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port()))
-            .unwrap();
+        let ws = ws::Builder::new().with_settings(settings).build(factory).unwrap();
+        CONFIG.set_ws_shutdown_handle(ws.broadcaster());
+        ws.listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port())).unwrap();
+
+        warn!("WS Server stopped!");
     });
 }
@@ -1,7 +1,7 @@
 use std::path::{Path, PathBuf};

-use rocket::{http::ContentType, response::content::Content, response::NamedFile, Route};
-use rocket_contrib::json::Json;
+use rocket::serde::json::Json;
+use rocket::{fs::NamedFile, http::ContentType, Route};
 use serde_json::Value;

 use crate::{
@@ -21,16 +21,16 @@ pub fn routes() -> Vec<Route> {
 }

 #[get("/")]
-fn web_index() -> Cached<Option<NamedFile>> {
-    Cached::short(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join("index.html")).ok(), false)
+async fn web_index() -> Cached<Option<NamedFile>> {
+    Cached::short(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join("index.html")).await.ok(), false)
 }

 #[get("/app-id.json")]
-fn app_id() -> Cached<Content<Json<Value>>> {
+fn app_id() -> Cached<(ContentType, Json<Value>)> {
     let content_type = ContentType::new("application", "fido.trusted-apps+json");

     Cached::long(
-        Content(
+        (
             content_type,
             Json(json!({
                 "trustedFacets": [
@@ -58,13 +58,13 @@ fn app_id() -> Cached<Content<Json<Value>>> {
 }

 #[get("/<p..>", rank = 10)] // Only match this if the other routes don't match
-fn web_files(p: PathBuf) -> Cached<Option<NamedFile>> {
-    Cached::long(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join(p)).ok(), true)
+async fn web_files(p: PathBuf) -> Cached<Option<NamedFile>> {
+    Cached::long(NamedFile::open(Path::new(&CONFIG.web_vault_folder()).join(p)).await.ok(), true)
 }

 #[get("/attachments/<uuid>/<file_id>")]
-fn attachments(uuid: SafeString, file_id: SafeString) -> Option<NamedFile> {
-    NamedFile::open(Path::new(&CONFIG.attachments_folder()).join(uuid).join(file_id)).ok()
+async fn attachments(uuid: SafeString, file_id: SafeString) -> Option<NamedFile> {
+    NamedFile::open(Path::new(&CONFIG.attachments_folder()).join(uuid).join(file_id)).await.ok()
 }

 // We use DbConn here to let the alive healthcheck also verify the database connection.
@@ -78,25 +78,20 @@ fn alive(_conn: DbConn) -> Json<String> {
 }

 #[get("/vw_static/<filename>")]
-fn static_files(filename: String) -> Result<Content<&'static [u8]>, Error> {
+fn static_files(filename: String) -> Result<(ContentType, &'static [u8]), Error> {
     match filename.as_ref() {
-        "mail-github.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/mail-github.png"))),
-        "logo-gray.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/logo-gray.png"))),
-        "error-x.svg" => Ok(Content(ContentType::SVG, include_bytes!("../static/images/error-x.svg"))),
-        "hibp.png" => Ok(Content(ContentType::PNG, include_bytes!("../static/images/hibp.png"))),
-        "vaultwarden-icon.png" => {
-            Ok(Content(ContentType::PNG, include_bytes!("../static/images/vaultwarden-icon.png")))
-        }
-
-        "bootstrap.css" => Ok(Content(ContentType::CSS, include_bytes!("../static/scripts/bootstrap.css"))),
-        "bootstrap-native.js" => {
-            Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/bootstrap-native.js")))
-        }
-        "identicon.js" => Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/identicon.js"))),
-        "datatables.js" => Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/datatables.js"))),
-        "datatables.css" => Ok(Content(ContentType::CSS, include_bytes!("../static/scripts/datatables.css"))),
+        "mail-github.png" => Ok((ContentType::PNG, include_bytes!("../static/images/mail-github.png"))),
+        "logo-gray.png" => Ok((ContentType::PNG, include_bytes!("../static/images/logo-gray.png"))),
+        "error-x.svg" => Ok((ContentType::SVG, include_bytes!("../static/images/error-x.svg"))),
+        "hibp.png" => Ok((ContentType::PNG, include_bytes!("../static/images/hibp.png"))),
+        "vaultwarden-icon.png" => Ok((ContentType::PNG, include_bytes!("../static/images/vaultwarden-icon.png"))),
+        "bootstrap.css" => Ok((ContentType::CSS, include_bytes!("../static/scripts/bootstrap.css"))),
+        "bootstrap-native.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/bootstrap-native.js"))),
+        "identicon.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/identicon.js"))),
+        "datatables.js" => Ok((ContentType::JavaScript, include_bytes!("../static/scripts/datatables.js"))),
+        "datatables.css" => Ok((ContentType::CSS, include_bytes!("../static/scripts/datatables.css"))),
         "jquery-3.6.0.slim.js" => {
-            Ok(Content(ContentType::JavaScript, include_bytes!("../static/scripts/jquery-3.6.0.slim.js")))
+            Ok((ContentType::JavaScript, include_bytes!("../static/scripts/jquery-3.6.0.slim.js")))
         }
         _ => err!(format!("Static file not found: {}", filename)),
     }
 }
262
src/auth.rs
@@ -257,7 +257,10 @@ pub fn generate_send_claims(send_id: &str, file_id: &str) -> BasicJwtClaims {
 //
 // Bearer token authentication
 //
-use rocket::request::{FromRequest, Outcome, Request};
+use rocket::{
+    outcome::try_outcome,
+    request::{FromRequest, Outcome, Request},
+};

 use crate::db::{
     models::{CollectionUser, Device, User, UserOrgStatus, UserOrgType, UserOrganization, UserStampException},
@@ -268,10 +271,11 @@ pub struct Host {
     pub host: String,
 }

-impl<'a, 'r> FromRequest<'a, 'r> for Host {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for Host {
     type Error = &'static str;

-    fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
         let headers = request.headers();

         // Get host
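Every guard in this file follows the same migration pattern: the two-lifetime synchronous impl becomes a single-lifetime async impl under #[rocket::async_trait], and the nested matches over inner guards collapse into try_outcome!, which early-returns Forward/Failure outcomes. A minimal sketch of the new shape (hypothetical RequestId guard, not from this commit):

use rocket::request::{FromRequest, Outcome, Request};

struct RequestId(String);

#[rocket::async_trait]
impl<'r> FromRequest<'r> for RequestId {
    type Error = &'static str;

    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
        // Succeed when the header is present, otherwise forward to other routes.
        match request.headers().get_one("X-Request-Id") {
            Some(id) => Outcome::Success(RequestId(id.to_string())),
            None => Outcome::Forward(()),
        }
    }
}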
@@ -314,17 +318,14 @@ pub struct Headers {
     pub user: User,
 }

-impl<'a, 'r> FromRequest<'a, 'r> for Headers {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for Headers {
     type Error = &'static str;

-    fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
         let headers = request.headers();

-        let host = match Host::from_request(request) {
-            Outcome::Forward(_) => return Outcome::Forward(()),
-            Outcome::Failure(f) => return Outcome::Failure(f),
-            Outcome::Success(host) => host.host,
-        };
+        let host = try_outcome!(Host::from_request(request).await).host;

         // Get access_token
         let access_token: &str = match headers.get_one("Authorization") {
@@ -344,7 +345,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers {
         let device_uuid = claims.device;
         let user_uuid = claims.sub;

-        let conn = match request.guard::<DbConn>() {
+        let conn = match DbConn::from_request(request).await {
             Outcome::Success(conn) => conn,
             _ => err_handler!("Error getting DB"),
         };
@@ -363,7 +364,7 @@ impl<'a, 'r> FromRequest<'a, 'r> for Headers {
         if let Some(stamp_exception) =
             user.stamp_exception.as_deref().and_then(|s| serde_json::from_str::<UserStampException>(s).ok())
         {
-            let current_route = match request.route().and_then(|r| r.name) {
+            let current_route = match request.route().and_then(|r| r.name.as_deref()) {
                 Some(name) => name,
                 _ => err_handler!("Error getting current route for stamp exception"),
             };
@@ -411,13 +412,13 @@ pub struct OrgHeaders {
 // but there are cases where it is a query value.
 // First check the path, if this is not a valid uuid, try the query values.
 fn get_org_id(request: &Request) -> Option<String> {
-    if let Some(Ok(org_id)) = request.get_param::<String>(1) {
+    if let Some(Ok(org_id)) = request.param::<String>(1) {
         if uuid::Uuid::parse_str(&org_id).is_ok() {
             return Some(org_id);
         }
     }

-    if let Some(Ok(org_id)) = request.get_query_value::<String>("organizationId") {
+    if let Some(Ok(org_id)) = request.query_value::<String>("organizationId") {
         if uuid::Uuid::parse_str(&org_id).is_ok() {
             return Some(org_id);
         }
@@ -426,52 +427,48 @@ fn get_org_id(request: &Request) -> Option<String> {
     None
 }

-impl<'a, 'r> FromRequest<'a, 'r> for OrgHeaders {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for OrgHeaders {
     type Error = &'static str;

-    fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
-        match request.guard::<Headers>() {
-            Outcome::Forward(_) => Outcome::Forward(()),
-            Outcome::Failure(f) => Outcome::Failure(f),
-            Outcome::Success(headers) => {
-                match get_org_id(request) {
-                    Some(org_id) => {
-                        let conn = match request.guard::<DbConn>() {
-                            Outcome::Success(conn) => conn,
-                            _ => err_handler!("Error getting DB"),
-                        };
-
-                        let user = headers.user;
-                        let org_user = match UserOrganization::find_by_user_and_org(&user.uuid, &org_id, &conn) {
-                            Some(user) => {
-                                if user.status == UserOrgStatus::Confirmed as i32 {
-                                    user
-                                } else {
-                                    err_handler!("The current user isn't confirmed member of the organization")
-                                }
-                            }
-                            None => err_handler!("The current user isn't member of the organization"),
-                        };
-
-                        Outcome::Success(Self {
-                            host: headers.host,
-                            device: headers.device,
-                            user,
-                            org_user_type: {
-                                if let Some(org_usr_type) = UserOrgType::from_i32(org_user.atype) {
-                                    org_usr_type
-                                } else {
-                                    // This should only happen if the DB is corrupted
-                                    err_handler!("Unknown user type in the database")
-                                }
-                            },
-                            org_user,
-                            org_id,
-                        })
-                    }
-                    _ => err_handler!("Error getting the organization id"),
-                }
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+        let headers = try_outcome!(Headers::from_request(request).await);
+        match get_org_id(request) {
+            Some(org_id) => {
+                let conn = match DbConn::from_request(request).await {
+                    Outcome::Success(conn) => conn,
+                    _ => err_handler!("Error getting DB"),
+                };
+
+                let user = headers.user;
+                let org_user = match UserOrganization::find_by_user_and_org(&user.uuid, &org_id, &conn) {
+                    Some(user) => {
+                        if user.status == UserOrgStatus::Confirmed as i32 {
+                            user
+                        } else {
+                            err_handler!("The current user isn't confirmed member of the organization")
+                        }
+                    }
+                    None => err_handler!("The current user isn't member of the organization"),
+                };
+
+                Outcome::Success(Self {
+                    host: headers.host,
+                    device: headers.device,
+                    user,
+                    org_user_type: {
+                        if let Some(org_usr_type) = UserOrgType::from_i32(org_user.atype) {
+                            org_usr_type
+                        } else {
+                            // This should only happen if the DB is corrupted
+                            err_handler!("Unknown user type in the database")
+                        }
+                    },
+                    org_user,
+                    org_id,
+                })
             }
+            _ => err_handler!("Error getting the organization id"),
         }
     }
 }
@@ -483,25 +480,21 @@ pub struct AdminHeaders {
     pub org_user_type: UserOrgType,
 }

-impl<'a, 'r> FromRequest<'a, 'r> for AdminHeaders {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for AdminHeaders {
     type Error = &'static str;

-    fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
-        match request.guard::<OrgHeaders>() {
-            Outcome::Forward(_) => Outcome::Forward(()),
-            Outcome::Failure(f) => Outcome::Failure(f),
-            Outcome::Success(headers) => {
-                if headers.org_user_type >= UserOrgType::Admin {
-                    Outcome::Success(Self {
-                        host: headers.host,
-                        device: headers.device,
-                        user: headers.user,
-                        org_user_type: headers.org_user_type,
-                    })
-                } else {
-                    err_handler!("You need to be Admin or Owner to call this endpoint")
-                }
-            }
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+        let headers = try_outcome!(OrgHeaders::from_request(request).await);
+        if headers.org_user_type >= UserOrgType::Admin {
+            Outcome::Success(Self {
+                host: headers.host,
+                device: headers.device,
+                user: headers.user,
+                org_user_type: headers.org_user_type,
+            })
+        } else {
+            err_handler!("You need to be Admin or Owner to call this endpoint")
         }
     }
 }
@@ -520,13 +513,13 @@ impl From<AdminHeaders> for Headers {
 // but there could be cases where it is a query value.
 // First check the path, if this is not a valid uuid, try the query values.
 fn get_col_id(request: &Request) -> Option<String> {
-    if let Some(Ok(col_id)) = request.get_param::<String>(3) {
+    if let Some(Ok(col_id)) = request.param::<String>(3) {
         if uuid::Uuid::parse_str(&col_id).is_ok() {
             return Some(col_id);
         }
     }

-    if let Some(Ok(col_id)) = request.get_query_value::<String>("collectionId") {
+    if let Some(Ok(col_id)) = request.query_value::<String>("collectionId") {
         if uuid::Uuid::parse_str(&col_id).is_ok() {
             return Some(col_id);
         }
|
||||||
pub org_user_type: UserOrgType,
|
pub org_user_type: UserOrgType,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'r> FromRequest<'a, 'r> for ManagerHeaders {
|
#[rocket::async_trait]
|
||||||
|
impl<'r> FromRequest<'r> for ManagerHeaders {
|
||||||
type Error = &'static str;
|
type Error = &'static str;
|
||||||
|
|
||||||
fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
||||||
match request.guard::<OrgHeaders>() {
|
let headers = try_outcome!(OrgHeaders::from_request(request).await);
|
||||||
Outcome::Forward(_) => Outcome::Forward(()),
|
if headers.org_user_type >= UserOrgType::Manager {
|
||||||
Outcome::Failure(f) => Outcome::Failure(f),
|
match get_col_id(request) {
|
||||||
Outcome::Success(headers) => {
|
Some(col_id) => {
|
||||||
if headers.org_user_type >= UserOrgType::Manager {
|
let conn = match DbConn::from_request(request).await {
|
||||||
match get_col_id(request) {
|
Outcome::Success(conn) => conn,
|
||||||
Some(col_id) => {
|
_ => err_handler!("Error getting DB"),
|
||||||
let conn = match request.guard::<DbConn>() {
|
};
|
||||||
Outcome::Success(conn) => conn,
|
|
||||||
_ => err_handler!("Error getting DB"),
|
|
||||||
};
|
|
||||||
|
|
||||||
if !headers.org_user.has_full_access() {
|
if !headers.org_user.has_full_access() {
|
||||||
match CollectionUser::find_by_collection_and_user(
|
match CollectionUser::find_by_collection_and_user(&col_id, &headers.org_user.user_uuid, &conn) {
|
||||||
&col_id,
|
Some(_) => (),
|
||||||
&headers.org_user.user_uuid,
|
None => err_handler!("The current user isn't a manager for this collection"),
|
||||||
&conn,
|
|
||||||
) {
|
|
||||||
Some(_) => (),
|
|
||||||
None => err_handler!("The current user isn't a manager for this collection"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
_ => err_handler!("Error getting the collection id"),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Outcome::Success(Self {
|
|
||||||
host: headers.host,
|
|
||||||
device: headers.device,
|
|
||||||
user: headers.user,
|
|
||||||
org_user_type: headers.org_user_type,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
|
|
||||||
}
|
}
|
||||||
|
_ => err_handler!("Error getting the collection id"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Outcome::Success(Self {
|
||||||
|
host: headers.host,
|
||||||
|
device: headers.device,
|
||||||
|
user: headers.user,
|
||||||
|
org_user_type: headers.org_user_type,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -608,25 +593,21 @@ pub struct ManagerHeadersLoose {
     pub org_user_type: UserOrgType,
 }

-impl<'a, 'r> FromRequest<'a, 'r> for ManagerHeadersLoose {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for ManagerHeadersLoose {
     type Error = &'static str;

-    fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
-        match request.guard::<OrgHeaders>() {
-            Outcome::Forward(_) => Outcome::Forward(()),
-            Outcome::Failure(f) => Outcome::Failure(f),
-            Outcome::Success(headers) => {
-                if headers.org_user_type >= UserOrgType::Manager {
-                    Outcome::Success(Self {
-                        host: headers.host,
-                        device: headers.device,
-                        user: headers.user,
-                        org_user_type: headers.org_user_type,
-                    })
-                } else {
-                    err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
-                }
-            }
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+        let headers = try_outcome!(OrgHeaders::from_request(request).await);
+        if headers.org_user_type >= UserOrgType::Manager {
+            Outcome::Success(Self {
+                host: headers.host,
+                device: headers.device,
+                user: headers.user,
+                org_user_type: headers.org_user_type,
+            })
+        } else {
+            err_handler!("You need to be a Manager, Admin or Owner to call this endpoint")
         }
     }
 }
@@ -647,24 +628,20 @@ pub struct OwnerHeaders {
     pub user: User,
 }

-impl<'a, 'r> FromRequest<'a, 'r> for OwnerHeaders {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for OwnerHeaders {
     type Error = &'static str;

-    fn from_request(request: &'a Request<'r>) -> Outcome<Self, Self::Error> {
-        match request.guard::<OrgHeaders>() {
-            Outcome::Forward(_) => Outcome::Forward(()),
-            Outcome::Failure(f) => Outcome::Failure(f),
-            Outcome::Success(headers) => {
-                if headers.org_user_type == UserOrgType::Owner {
-                    Outcome::Success(Self {
-                        host: headers.host,
-                        device: headers.device,
-                        user: headers.user,
-                    })
-                } else {
-                    err_handler!("You need to be Owner to call this endpoint")
-                }
-            }
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+        let headers = try_outcome!(OrgHeaders::from_request(request).await);
+        if headers.org_user_type == UserOrgType::Owner {
+            Outcome::Success(Self {
+                host: headers.host,
+                device: headers.device,
+                user: headers.user,
+            })
+        } else {
+            err_handler!("You need to be Owner to call this endpoint")
         }
     }
 }
@@ -678,10 +655,11 @@ pub struct ClientIp {
     pub ip: IpAddr,
 }

-impl<'a, 'r> FromRequest<'a, 'r> for ClientIp {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for ClientIp {
     type Error = ();

-    fn from_request(req: &'a Request<'r>) -> Outcome<Self, Self::Error> {
+    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
         let ip = if CONFIG._ip_header_enabled() {
             req.headers().get_one(&CONFIG.ip_header()).and_then(|ip| {
                 match ip.find(',') {
@@ -36,6 +36,9 @@ macro_rules! make_config {
         pub struct Config { inner: RwLock<Inner> }

         struct Inner {
+            rocket_shutdown_handle: Option<rocket::Shutdown>,
+            ws_shutdown_handle: Option<ws::Sender>,
+
             templates: Handlebars<'static>,
             config: ConfigItems,
@@ -332,6 +335,8 @@ make_config! {
         attachments_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "attachments");
         /// Sends folder
         sends_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "sends");
+        /// Temp folder |> Used for storing temporary file uploads
+        tmp_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "tmp");
         /// Templates folder
         templates_folder: String, false, auto, |c| format!("{}/{}", c.data_folder, "templates");
         /// Session JWT key
|
||||||
/// Max database connection retries |> Number of times to retry the database connection during startup, with 1 second between each retry, set to 0 to retry indefinitely
|
/// Max database connection retries |> Number of times to retry the database connection during startup, with 1 second between each retry, set to 0 to retry indefinitely
|
||||||
db_connection_retries: u32, false, def, 15;
|
db_connection_retries: u32, false, def, 15;
|
||||||
|
|
||||||
|
/// Timeout when aquiring database connection
|
||||||
|
database_timeout: u64, false, def, 30;
|
||||||
|
|
||||||
/// Database connection pool size
|
/// Database connection pool size
|
||||||
database_max_conns: u32, false, def, 10;
|
database_max_conns: u32, false, def, 10;
|
||||||
|
|
||||||
|
@ -743,6 +751,8 @@ impl Config {
|
||||||
|
|
||||||
Ok(Config {
|
Ok(Config {
|
||||||
inner: RwLock::new(Inner {
|
inner: RwLock::new(Inner {
|
||||||
|
rocket_shutdown_handle: None,
|
||||||
|
ws_shutdown_handle: None,
|
||||||
templates: load_templates(&config.templates_folder),
|
templates: load_templates(&config.templates_folder),
|
||||||
config,
|
config,
|
||||||
_env,
|
_env,
|
||||||
|
@ -907,6 +917,27 @@ impl Config {
|
||||||
hb.render(name, data).map_err(Into::into)
|
hb.render(name, data).map_err(Into::into)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_rocket_shutdown_handle(&self, handle: rocket::Shutdown) {
|
||||||
|
self.inner.write().unwrap().rocket_shutdown_handle = Some(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_ws_shutdown_handle(&self, handle: ws::Sender) {
|
||||||
|
self.inner.write().unwrap().ws_shutdown_handle = Some(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn shutdown(&self) {
|
||||||
|
if let Ok(c) = self.inner.read() {
|
||||||
|
if let Some(handle) = c.ws_shutdown_handle.clone() {
|
||||||
|
handle.shutdown().ok();
|
||||||
|
}
|
||||||
|
// Wait a bit before stopping the web server
|
||||||
|
std::thread::sleep(std::time::Duration::from_secs(1));
|
||||||
|
if let Some(handle) = c.rocket_shutdown_handle.clone() {
|
||||||
|
handle.notify();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
use handlebars::{Context, Handlebars, Helper, HelperResult, Output, RenderContext, RenderError, Renderable};
|
use handlebars::{Context, Handlebars, Helper, HelperResult, Output, RenderContext, RenderError, Renderable};
|
||||||
|
|
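These setters only store the handles; the launch code that produces them is outside this section. A hedged sketch of how the Rocket side could be wired (the launch function is assumed, not shown in this diff; only set_rocket_shutdown_handle above is from it):

async fn launch(rocket: rocket::Rocket<rocket::Build>) -> Result<(), rocket::Error> {
    let rocket = rocket.ignite().await?;
    // Store the Shutdown handle so Config::shutdown() can notify it later.
    CONFIG.set_rocket_shutdown_handle(rocket.shutdown());
    rocket.launch().await?;
    Ok(())
}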
231
src/db/mod.rs
@@ -1,8 +1,16 @@
+use std::{sync::Arc, time::Duration};
+
 use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
 use rocket::{
     http::Status,
+    outcome::IntoOutcome,
     request::{FromRequest, Outcome},
-    Request, State,
+    Request,
+};
+
+use tokio::{
+    sync::{Mutex, OwnedSemaphorePermit, Semaphore},
+    time::timeout,
 };

 use crate::{
@@ -22,6 +30,23 @@ pub mod __mysql_schema;
 #[path = "schemas/postgresql/schema.rs"]
 pub mod __postgresql_schema;

+// These changes are based on Rocket 0.5-rc's wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools
+
+// A wrapper around spawn_blocking that propagates panics to the calling code.
+pub async fn run_blocking<F, R>(job: F) -> R
+where
+    F: FnOnce() -> R + Send + 'static,
+    R: Send + 'static,
+{
+    match tokio::task::spawn_blocking(job).await {
+        Ok(ret) => ret,
+        Err(e) => match e.try_into_panic() {
+            Ok(panic) => std::panic::resume_unwind(panic),
+            Err(_) => unreachable!("spawn_blocking tasks are never cancelled"),
+        },
+    }
+}
+
 // This is used to generate the main DbConn and DbPool enums, which contain one variant for each database supported
 macro_rules! generate_connections {
     ( $( $name:ident: $ty:ty ),+ ) => {
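run_blocking is the bridge used for the still-synchronous Diesel calls: the closure runs on Tokio's blocking thread pool, and a panic inside it is resumed on the awaiting task rather than swallowed. A runnable toy usage, assuming the run_blocking above is in scope (the summation stands in for a blocking DB call):

#[tokio::main]
async fn main() {
    // Any synchronous, CPU-bound or blocking work stays off the async workers.
    let sum = run_blocking(|| (1u64..=1_000_000).sum::<u64>()).await;
    assert_eq!(sum, 500_000_500_000);
}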
@@ -29,12 +54,53 @@ macro_rules! generate_connections {
         #[derive(Eq, PartialEq)]
         pub enum DbConnType { $( $name, )+ }

+        pub struct DbConn {
+            conn: Arc<Mutex<Option<DbConnInner>>>,
+            permit: Option<OwnedSemaphorePermit>,
+        }
+
         #[allow(non_camel_case_types)]
-        pub enum DbConn { $( #[cfg($name)] $name(PooledConnection<ConnectionManager< $ty >>), )+ }
+        pub enum DbConnInner { $( #[cfg($name)] $name(PooledConnection<ConnectionManager< $ty >>), )+ }
+
+        #[derive(Clone)]
+        pub struct DbPool {
+            // This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
+            pool: Option<DbPoolInner>,
+            semaphore: Arc<Semaphore>
+        }

         #[allow(non_camel_case_types)]
         #[derive(Clone)]
-        pub enum DbPool { $( #[cfg($name)] $name(Pool<ConnectionManager< $ty >>), )+ }
+        pub enum DbPoolInner { $( #[cfg($name)] $name(Pool<ConnectionManager< $ty >>), )+ }
+
+        impl Drop for DbConn {
+            fn drop(&mut self) {
+                let conn = self.conn.clone();
+                let permit = self.permit.take();
+
+                // Since connection can't be on the stack in an async fn during an
+                // await, we have to spawn a new blocking-safe thread...
+                tokio::task::spawn_blocking(move || {
+                    // And then re-enter the runtime to wait on the async mutex, but in a blocking fashion.
+                    let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned());
+
+                    if let Some(conn) = conn.take() {
+                        drop(conn);
+                    }
+
+                    // Drop permit after the connection is dropped
+                    drop(permit);
+                });
+            }
+        }
+
+        impl Drop for DbPool {
+            fn drop(&mut self) {
+                let pool = self.pool.take();
+                tokio::task::spawn_blocking(move || drop(pool));
+            }
+        }

         impl DbPool {
             // For the given database URL, guess it's type, run migrations create pool and return it
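The Drop impl above is the core trick of this wrapper: a pooled connection must not be dropped on an async worker thread, so the guarded slot is moved to a blocking thread, which re-enters the runtime only to take the lock. A standalone sketch of the same move (generic slot type, not tied to Diesel):

use std::sync::Arc;
use tokio::sync::Mutex;

// Drop the value inside an Arc<Mutex<Option<T>>> off the async threads.
fn drop_in_background<T: Send + 'static>(slot: Arc<Mutex<Option<T>>>) {
    tokio::task::spawn_blocking(move || {
        // block_on is safe here: we are on a blocking-pool thread, not an async worker.
        let mut guard = tokio::runtime::Handle::current().block_on(slot.lock_owned());
        drop(guard.take());
    });
}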
@@ -50,9 +116,13 @@ macro_rules! generate_connections {
                         let manager = ConnectionManager::new(&url);
                         let pool = Pool::builder()
                             .max_size(CONFIG.database_max_conns())
+                            .connection_timeout(Duration::from_secs(CONFIG.database_timeout()))
                             .build(manager)
                             .map_res("Failed to create pool")?;
-                        return Ok(Self::$name(pool));
+                        return Ok(DbPool {
+                            pool: Some(DbPoolInner::$name(pool)),
+                            semaphore: Arc::new(Semaphore::new(CONFIG.database_max_conns() as usize)),
+                        });
                     }
                     #[cfg(not($name))]
                     #[allow(unreachable_code)]
|
||||||
)+ }
|
)+ }
|
||||||
}
|
}
|
||||||
// Get a connection from the pool
|
// Get a connection from the pool
|
||||||
pub fn get(&self) -> Result<DbConn, Error> {
|
pub async fn get(&self) -> Result<DbConn, Error> {
|
||||||
match self { $(
|
let duration = Duration::from_secs(CONFIG.database_timeout());
|
||||||
|
let permit = match timeout(duration, self.semaphore.clone().acquire_owned()).await {
|
||||||
|
Ok(p) => p.expect("Semaphore should be open"),
|
||||||
|
Err(_) => {
|
||||||
|
err!("Timeout waiting for database connection");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match self.pool.as_ref().expect("DbPool.pool should always be Some()") { $(
|
||||||
#[cfg($name)]
|
#[cfg($name)]
|
||||||
Self::$name(p) => Ok(DbConn::$name(p.get().map_res("Error retrieving connection from pool")?)),
|
DbPoolInner::$name(p) => {
|
||||||
|
let pool = p.clone();
|
||||||
|
let c = run_blocking(move || pool.get_timeout(duration)).await.map_res("Error retrieving connection from pool")?;
|
||||||
|
|
||||||
|
return Ok(DbConn {
|
||||||
|
conn: Arc::new(Mutex::new(Some(DbConnInner::$name(c)))),
|
||||||
|
permit: Some(permit)
|
||||||
|
});
|
||||||
|
},
|
||||||
)+ }
|
)+ }
|
||||||
}
|
}
|
||||||
}
|
}
|
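The new get() gates connection checkout behind a tokio Semaphore sized to database_max_conns, and gives up after database_timeout seconds instead of blocking forever. A hypothetical standalone version of just that acquisition step:

use std::{sync::Arc, time::Duration};
use tokio::{sync::{OwnedSemaphorePermit, Semaphore}, time::timeout};

// Acquire a slot or fail after `wait`; mirrors the logic in DbPool::get().
async fn acquire(sem: Arc<Semaphore>, wait: Duration) -> Result<OwnedSemaphorePermit, &'static str> {
    match timeout(wait, sem.acquire_owned()).await {
        // acquire_owned() only errors if the semaphore was closed, which never happens here.
        Ok(permit) => permit.map_err(|_| "semaphore closed"),
        Err(_) => Err("timed out waiting for a database connection"),
    }
}

#[tokio::main]
async fn main() {
    let sem = Arc::new(Semaphore::new(10)); // mirrors database_max_conns
    let permit = acquire(sem, Duration::from_secs(30)).await.unwrap();
    // ...store the permit next to the connection; dropping it frees the slot...
    drop(permit);
}

Keeping the OwnedSemaphorePermit inside DbConn (and dropping it only after the connection itself) is what makes the semaphore an accurate count of live connections.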
@@ -113,42 +199,95 @@ macro_rules! db_run {
         db_run! { $conn: sqlite, mysql, postgresql $body }
     };

-    // Different code for each db
-    ( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
-        #[allow(unused)] use diesel::prelude::*;
-        match $conn {
-            $($(
-                #[cfg($db)]
-                crate::db::DbConn::$db(ref $conn) => {
-                    paste::paste! {
-                        #[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
-                        #[allow(unused)] use [<__ $db _model>]::*;
-                        #[allow(unused)] use crate::db::FromDb;
-                    }
-                    $body
-                },
-            )+)+
-        }}
-    };
-
-    // Same for all dbs
     ( @raw $conn:ident: $body:block ) => {
         db_run! { @raw $conn: sqlite, mysql, postgresql $body }
     };

     // Different code for each db
-    ( @raw $conn:ident: $( $($db:ident),+ $body:block )+ ) => {
-        #[allow(unused)] use diesel::prelude::*;
-        #[allow(unused_variables)]
-        match $conn {
-            $($(
-                #[cfg($db)]
-                crate::db::DbConn::$db(ref $conn) => {
-                    $body
-                },
-            )+)+
-        }
-    };
+    ( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
+        #[allow(unused)] use diesel::prelude::*;
+
+        // It is important that this inner Arc<Mutex<>> (or the OwnedMutexGuard
+        // derived from it) never be a variable on the stack at an await point,
+        // where Drop might be called at any time. This causes (synchronous)
+        // Drop to be called from asynchronous code, which some database
+        // wrappers do not or can not handle.
+        let conn = $conn.conn.clone();
+
+        // Since connection can't be on the stack in an async fn during an
+        // await, we have to spawn a new blocking-safe thread...
+        /*
+        run_blocking(move || {
+            // And then re-enter the runtime to wait on the async mutex, but in
+            // a blocking fashion.
+            let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned());
+            let conn = conn.as_mut().expect("internal invariant broken: self.connection is Some");
+        */
+        let mut __conn_mutex = conn.try_lock_owned().unwrap();
+        let conn = __conn_mutex.as_mut().unwrap();
+        match conn {
+            $($(
+                #[cfg($db)]
+                crate::db::DbConnInner::$db($conn) => {
+                    paste::paste! {
+                        #[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
+                        #[allow(unused)] use [<__ $db _model>]::*;
+                        #[allow(unused)] use crate::db::FromDb;
+                    }
+
+                    /*
+                    // Since connection can't be on the stack in an async fn during an
+                    // await, we have to spawn a new blocking-safe thread...
+                    run_blocking(move || {
+                        // And then re-enter the runtime to wait on the async mutex, but in
+                        // a blocking fashion.
+                        let mut conn = tokio::runtime::Handle::current().block_on(async {
+                            conn.lock_owned().await
+                        });
+
+                        let conn = conn.as_mut().expect("internal invariant broken: self.connection is Some");
+                        f(conn)
+                    }).await;*/
+
+                    $body
+                },
+            )+)+
+        }
+        // }).await
+    }};
+
+    ( @raw $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
+        #[allow(unused)] use diesel::prelude::*;
+
+        // It is important that this inner Arc<Mutex<>> (or the OwnedMutexGuard
+        // derived from it) never be a variable on the stack at an await point,
+        // where Drop might be called at any time. This causes (synchronous)
+        // Drop to be called from asynchronous code, which some database
+        // wrappers do not or can not handle.
+        let conn = $conn.conn.clone();
+
+        // Since connection can't be on the stack in an async fn during an
+        // await, we have to spawn a new blocking-safe thread...
+        run_blocking(move || {
+            // And then re-enter the runtime to wait on the async mutex, but in
+            // a blocking fashion.
+            let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned());
+            match conn.as_mut().expect("internal invariant broken: self.connection is Some") {
+                $($(
+                    #[cfg($db)]
+                    crate::db::DbConnInner::$db($conn) => {
+                        paste::paste! {
+                            #[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
+                            // @RAW: #[allow(unused)] use [<__ $db _model>]::*;
+                            #[allow(unused)] use crate::db::FromDb;
+                        }

+                        $body
+                    },
+                )+)+
+            }
+        }).await
+    }};
 }
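For orientation, a sketch of what a typical call site looks like after this change. The table and model names (users, UserDb) are illustrative of vaultwarden's conventions rather than quoted from this diff; per the commit message, the model methods themselves are still blocking at this stage, and db_run! locks the inner connection synchronously via try_lock_owned():

// Inside a model impl; `conn` is now the DbConn wrapper struct.
pub fn find_by_uuid(uuid: &str, conn: &DbConn) -> Option<User> {
    db_run! { conn: {
        users::table
            .filter(users::uuid.eq(uuid))
            .first::<UserDb>(conn)
            .ok()
            .from_db()
    }}
}

The schema (users::...) and model (UserDb) names come into scope from the paste::paste! imports the macro injects per database backend.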

 pub trait FromDb {
@@ -227,9 +366,10 @@ pub mod models;

 /// Creates a back-up of the sqlite database
 /// MySQL/MariaDB and PostgreSQL are not supported.
-pub fn backup_database(conn: &DbConn) -> Result<(), Error> {
+pub async fn backup_database(conn: &DbConn) -> Result<(), Error> {
     db_run! {@raw conn:
         postgresql, mysql {
+            let _ = conn;
             err!("PostgreSQL and MySQL/MariaDB do not support this backup feature");
         }
         sqlite {
@@ -244,7 +384,7 @@ pub fn backup_database(conn: &DbConn) -> Result<(), Error> {
 }

 /// Get the SQL Server version
-pub fn get_sql_server_version(conn: &DbConn) -> String {
+pub async fn get_sql_server_version(conn: &DbConn) -> String {
     db_run! {@raw conn:
         postgresql, mysql {
             no_arg_sql_function!(version, diesel::sql_types::Text);
@@ -260,15 +400,14 @@ pub fn get_sql_server_version(conn: &DbConn) -> String {
 /// Attempts to retrieve a single connection from the managed database pool. If
 /// no pool is currently managed, fails with an `InternalServerError` status. If
 /// no connections are available, fails with a `ServiceUnavailable` status.
-impl<'a, 'r> FromRequest<'a, 'r> for DbConn {
+#[rocket::async_trait]
+impl<'r> FromRequest<'r> for DbConn {
     type Error = ();

-    fn from_request(request: &'a Request<'r>) -> Outcome<DbConn, ()> {
-        // https://github.com/SergioBenitez/Rocket/commit/e3c1a4ad3ab9b840482ec6de4200d30df43e357c
-        let pool = try_outcome!(request.guard::<State<DbPool>>());
-        match pool.get() {
-            Ok(conn) => Outcome::Success(conn),
-            Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())),
+    async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+        match request.rocket().state::<DbPool>() {
+            Some(p) => p.get().await.map_err(|_| ()).into_outcome(Status::ServiceUnavailable),
+            None => Outcome::Failure((Status::InternalServerError, ())),
         }
     }
 }
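With the request guard now async, handlers just declare DbConn as a parameter and Rocket awaits DbPool::get() before the handler body runs. An illustrative route (not from this diff):

use rocket::get;

// Rocket resolves the DbConn guard asynchronously before entering the body.
#[get("/alive")]
async fn alive(_conn: crate::db::DbConn) -> &'static str {
    "ok" // _conn is dropped here, returning the connection to the pool
}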

10
src/error.rs

@@ -45,6 +45,7 @@ use lettre::transport::smtp::Error as SmtpErr;
 use openssl::error::ErrorStack as SSLErr;
 use regex::Error as RegexErr;
 use reqwest::Error as ReqErr;
+use rocket::error::Error as RocketErr;
 use serde_json::{Error as SerdeErr, Value};
 use std::io::Error as IoErr;
 use std::time::SystemTimeError as TimeErr;
@@ -84,6 +85,7 @@ make_error! {
     Address(AddrErr): _has_source, _api_error,
     Smtp(SmtpErr): _has_source, _api_error,
     OpenSSL(SSLErr): _has_source, _api_error,
+    Rocket(RocketErr): _has_source, _api_error,

     DieselCon(DieselConErr): _has_source, _api_error,
     DieselMig(DieselMigErr): _has_source, _api_error,
@@ -193,8 +195,8 @@ use rocket::http::{ContentType, Status};
 use rocket::request::Request;
 use rocket::response::{self, Responder, Response};

-impl<'r> Responder<'r> for Error {
-    fn respond_to(self, _: &Request) -> response::Result<'r> {
+impl<'r> Responder<'r, 'static> for Error {
+    fn respond_to(self, _: &Request) -> response::Result<'static> {
         match self.error {
             ErrorKind::Empty(_) => {} // Don't print the error in this situation
             ErrorKind::Simple(_) => {} // Don't print the error in this situation
@@ -202,8 +204,8 @@ impl<'r> Responder<'r> for Error {
         };

         let code = Status::from_code(self.error_code).unwrap_or(Status::BadRequest);
-        Response::build().status(code).header(ContentType::JSON).sized_body(Cursor::new(format!("{}", self))).ok()
+        let body = self.to_string();
+        Response::build().status(code).header(ContentType::JSON).sized_body(Some(body.len()), Cursor::new(body)).ok()
     }
 }
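Two rocket 0.5 changes show up in this hunk: Responder now takes two lifetimes (request and response body), and sized_body() wants an explicit length. A minimal sketch of the new shape, using an illustrative message type:

use std::io::Cursor;
use rocket::http::{ContentType, Status};
use rocket::request::Request;
use rocket::response::{self, Responder, Response};

struct Message(String);

impl<'r> Responder<'r, 'static> for Message {
    fn respond_to(self, _: &'r Request<'_>) -> response::Result<'static> {
        let body = self.0;
        Response::build()
            .status(Status::Ok)
            .header(ContentType::Plain)
            // rocket 0.5 takes the body size up front (None would mean "seek to find it")
            .sized_body(Some(body.len()), Cursor::new(body))
            .ok()
    }
}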

98
src/main.rs

@@ -20,8 +20,15 @@ extern crate diesel;
 #[macro_use]
 extern crate diesel_migrations;

-use job_scheduler::{Job, JobScheduler};
-use std::{fs::create_dir_all, panic, path::Path, process::exit, str::FromStr, thread, time::Duration};
+use std::{
+    fs::{canonicalize, create_dir_all},
+    panic,
+    path::Path,
+    process::exit,
+    str::FromStr,
+    thread,
+    time::Duration,
+};

 #[macro_use]
 mod error;
@@ -37,9 +44,11 @@ mod util;

 pub use config::CONFIG;
 pub use error::{Error, MapResult};
+use rocket::data::{Limits, ToByteUnit};
 pub use util::is_running_in_docker;

-fn main() {
+#[rocket::main]
+async fn main() -> Result<(), Error> {
     parse_args();
     launch_info();

@@ -56,13 +65,16 @@ fn main() {
     });
     check_web_vault();

-    create_icon_cache_folder();
+    create_dir(&CONFIG.icon_cache_folder(), "icon cache");
+    create_dir(&CONFIG.tmp_folder(), "tmp folder");
+    create_dir(&CONFIG.sends_folder(), "sends folder");
+    create_dir(&CONFIG.attachments_folder(), "attachments folder");

     let pool = create_db_pool();
-    schedule_jobs(pool.clone());
-    crate::db::models::TwoFactor::migrate_u2f_to_webauthn(&pool.get().unwrap()).unwrap();
+    schedule_jobs(pool.clone()).await;
+    crate::db::models::TwoFactor::migrate_u2f_to_webauthn(&pool.get().await.unwrap()).unwrap();

-    launch_rocket(pool, extra_debug); // Blocks until program termination.
+    launch_rocket(pool, extra_debug).await // Blocks until program termination.
 }
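The entry point itself becomes async: #[rocket::main] wraps the function in a tokio runtime so that launch_rocket() and pool.get() can be awaited. A minimal sketch of that shape under the rc-era rocket 0.5 API (routes omitted):

// #[rocket::main] builds the runtime and blocks on this async main.
#[rocket::main]
async fn main() -> Result<(), rocket::Error> {
    rocket::build()
        .mount("/", rocket::routes![])
        .launch()
        .await?;
    Ok(())
}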

 const HELP: &str = "\
@@ -127,10 +139,12 @@ fn init_logging(level: log::LevelFilter) -> Result<(), fern::InitError> {
.level_for("hyper::server", log::LevelFilter::Warn)
|
.level_for("hyper::server", log::LevelFilter::Warn)
|
||||||
// Silence rocket logs
|
// Silence rocket logs
|
||||||
.level_for("_", log::LevelFilter::Off)
|
.level_for("_", log::LevelFilter::Off)
|
||||||
.level_for("launch", log::LevelFilter::Off)
|
.level_for("rocket::launch", log::LevelFilter::Error)
|
||||||
.level_for("launch_", log::LevelFilter::Off)
|
.level_for("rocket::launch_", log::LevelFilter::Error)
|
||||||
.level_for("rocket::rocket", log::LevelFilter::Off)
|
.level_for("rocket::rocket", log::LevelFilter::Warn)
|
||||||
.level_for("rocket::fairing", log::LevelFilter::Off)
|
.level_for("rocket::server", log::LevelFilter::Warn)
|
||||||
|
.level_for("rocket::fairing::fairings", log::LevelFilter::Warn)
|
||||||
|
.level_for("rocket::shield::shield", log::LevelFilter::Warn)
|
||||||
// Never show html5ever and hyper::proto logs, too noisy
|
// Never show html5ever and hyper::proto logs, too noisy
|
||||||
.level_for("html5ever", log::LevelFilter::Off)
|
.level_for("html5ever", log::LevelFilter::Off)
|
||||||
.level_for("hyper::proto", log::LevelFilter::Off)
|
.level_for("hyper::proto", log::LevelFilter::Off)
|
||||||
|
@@ -243,10 +257,6 @@ fn create_dir(path: &str, description: &str) {
     create_dir_all(path).expect(&err_msg);
 }

-fn create_icon_cache_folder() {
-    create_dir(&CONFIG.icon_cache_folder(), "icon cache");
-}
-
 fn check_data_folder() {
     let data_folder = &CONFIG.data_folder();
     let path = Path::new(data_folder);
@@ -314,51 +324,73 @@ fn create_db_pool() -> db::DbPool {
     }
 }

-fn launch_rocket(pool: db::DbPool, extra_debug: bool) {
+async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> {
     let basepath = &CONFIG.domain_path();

+    let mut config = rocket::Config::from(rocket::Config::figment());
+    config.address = std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED); // TODO: Allow this to be changed, keep ROCKET_ADDRESS for compat
+    config.temp_dir = canonicalize(CONFIG.tmp_folder()).unwrap().into();
+    config.limits = Limits::new() //
+        .limit("json", 10.megabytes())
+        .limit("data-form", 150.megabytes())
+        .limit("file", 150.megabytes());
+
     // If adding more paths here, consider also adding them to
     // crate::utils::LOGGED_ROUTES to make sure they appear in the log
-    let result = rocket::ignite()
-        .mount(&[basepath, "/"].concat(), api::web_routes())
-        .mount(&[basepath, "/api"].concat(), api::core_routes())
-        .mount(&[basepath, "/admin"].concat(), api::admin_routes())
-        .mount(&[basepath, "/identity"].concat(), api::identity_routes())
-        .mount(&[basepath, "/icons"].concat(), api::icons_routes())
-        .mount(&[basepath, "/notifications"].concat(), api::notifications_routes())
+    let instance = rocket::custom(config)
+        .mount([basepath, "/"].concat(), api::web_routes())
+        .mount([basepath, "/api"].concat(), api::core_routes())
+        .mount([basepath, "/admin"].concat(), api::admin_routes())
+        .mount([basepath, "/identity"].concat(), api::identity_routes())
+        .mount([basepath, "/icons"].concat(), api::icons_routes())
+        .mount([basepath, "/notifications"].concat(), api::notifications_routes())
         .manage(pool)
         .manage(api::start_notification_server())
         .attach(util::AppHeaders())
         .attach(util::Cors())
         .attach(util::BetterLogging(extra_debug))
-        .launch();
-
-    // Launch and print error if there is one
-    // The launch will restore the original logging level
-    error!("Launch error {:#?}", result);
+        .ignite()
+        .await?;
+
+    CONFIG.set_rocket_shutdown_handle(instance.shutdown());
+    ctrlc::set_handler(move || {
+        info!("Exiting vaultwarden!");
+        CONFIG.shutdown();
+    })
+    .expect("Error setting Ctrl-C handler");
+
+    instance.launch().await?;
+
+    info!("Vaultwarden process exited!");
+    Ok(())
 }
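The shutdown wiring above stores rocket's Shutdown handle in CONFIG so other code (the Ctrl-C handler, the admin interface) can stop the server. A condensed sketch of the same idea without the CONFIG indirection, under the same rc-era API; the ctrlc crate usage mirrors the diff:

#[rocket::main]
async fn main() -> Result<(), rocket::Error> {
    let instance = rocket::build().ignite().await?;

    // Keep a cloneable handle we can trigger from a signal handler.
    let shutdown = instance.shutdown();
    ctrlc::set_handler(move || shutdown.clone().notify())
        .expect("Error setting Ctrl-C handler");

    // launch() resolves once notify() fires (or Rocket stops on its own).
    instance.launch().await?;
    Ok(())
}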
-fn schedule_jobs(pool: db::DbPool) {
+async fn schedule_jobs(pool: db::DbPool) {
     if CONFIG.job_poll_interval_ms() == 0 {
         info!("Job scheduler disabled.");
         return;
     }
+
+    let runtime = tokio::runtime::Handle::current();
+
     thread::Builder::new()
         .name("job-scheduler".to_string())
         .spawn(move || {
+            use job_scheduler::{Job, JobScheduler};
+
             let mut sched = JobScheduler::new();

             // Purge sends that are past their deletion date.
             if !CONFIG.send_purge_schedule().is_empty() {
                 sched.add(Job::new(CONFIG.send_purge_schedule().parse().unwrap(), || {
-                    api::purge_sends(pool.clone());
+                    runtime.spawn(api::purge_sends(pool.clone()));
                 }));
             }

             // Purge trashed items that are old enough to be auto-deleted.
             if !CONFIG.trash_purge_schedule().is_empty() {
                 sched.add(Job::new(CONFIG.trash_purge_schedule().parse().unwrap(), || {
-                    api::purge_trashed_ciphers(pool.clone());
+                    runtime.spawn(api::purge_trashed_ciphers(pool.clone()));
                 }));
             }

@@ -366,7 +398,7 @@ fn schedule_jobs(pool: db::DbPool) {
     // indicates that a user's master password has been compromised.
     if !CONFIG.incomplete_2fa_schedule().is_empty() {
         sched.add(Job::new(CONFIG.incomplete_2fa_schedule().parse().unwrap(), || {
-            api::send_incomplete_2fa_notifications(pool.clone());
+            runtime.spawn(api::send_incomplete_2fa_notifications(pool.clone()));
         }));
     }

@@ -375,7 +407,7 @@ fn schedule_jobs(pool: db::DbPool) {
     // sending reminders for requests that are about to be granted anyway.
     if !CONFIG.emergency_request_timeout_schedule().is_empty() {
         sched.add(Job::new(CONFIG.emergency_request_timeout_schedule().parse().unwrap(), || {
-            api::emergency_request_timeout_job(pool.clone());
+            runtime.spawn(api::emergency_request_timeout_job(pool.clone()));
        }));
     }

@@ -383,7 +415,7 @@ fn schedule_jobs(pool: db::DbPool) {
     // emergency access requests.
     if !CONFIG.emergency_notification_reminder_schedule().is_empty() {
         sched.add(Job::new(CONFIG.emergency_notification_reminder_schedule().parse().unwrap(), || {
-            api::emergency_notification_reminder_job(pool.clone());
+            runtime.spawn(api::emergency_notification_reminder_job(pool.clone()));
         }));
     }
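The pattern in schedule_jobs is worth calling out: the job scheduler stays on a plain OS thread, but the jobs themselves are now async fns, so the thread captures a tokio runtime Handle and uses runtime.spawn() to hand each future back to the async workers. A minimal sketch (the job body is illustrative):

use tokio::runtime::Handle;

// Must be called from async context so Handle::current() succeeds.
async fn start_scheduler() {
    let runtime = Handle::current();
    std::thread::spawn(move || loop {
        // Fire-and-forget an async job onto the runtime's worker threads,
        // e.g. runtime.spawn(api::purge_sends(pool.clone())) in the real code.
        runtime.spawn(async { /* async job body */ });
        std::thread::sleep(std::time::Duration::from_secs(30));
    });
}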

70
src/util.rs

@@ -5,10 +5,10 @@ use std::io::Cursor;

 use rocket::{
     fairing::{Fairing, Info, Kind},
-    http::{ContentType, Header, HeaderMap, Method, RawStr, Status},
+    http::{ContentType, Header, HeaderMap, Method, Status},
     request::FromParam,
     response::{self, Responder},
-    Data, Request, Response, Rocket,
+    Data, Orbit, Request, Response, Rocket,
 };

 use std::thread::sleep;
@@ -18,6 +18,7 @@ use crate::CONFIG;

 pub struct AppHeaders();

+#[rocket::async_trait]
 impl Fairing for AppHeaders {
     fn info(&self) -> Info {
         Info {
@@ -26,7 +27,7 @@ impl Fairing for AppHeaders {
         }
     }

-    fn on_response(&self, _req: &Request, res: &mut Response) {
+    async fn on_response<'r>(&self, _req: &'r Request<'_>, res: &mut Response<'r>) {
         res.set_raw_header("Permissions-Policy", "accelerometer=(), ambient-light-sensor=(), autoplay=(), camera=(), encrypted-media=(), fullscreen=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), midi=(), payment=(), picture-in-picture=(), sync-xhr=(self \"https://haveibeenpwned.com\" \"https://2fa.directory\"), usb=(), vr=()");
         res.set_raw_header("Referrer-Policy", "same-origin");
         res.set_raw_header("X-Frame-Options", "SAMEORIGIN");
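All of the fairings in this file follow the same migration: rocket 0.5 fairing hooks are async, so each impl gains #[rocket::async_trait] and its on_request/on_response methods become async with explicit request lifetimes. A self-contained sketch of the new shape, with an illustrative fairing name:

use rocket::fairing::{Fairing, Info, Kind};
use rocket::{Request, Response};

pub struct ServerHeader;

#[rocket::async_trait]
impl Fairing for ServerHeader {
    fn info(&self) -> Info {
        Info { name: "Server header", kind: Kind::Response }
    }

    // The hook is now async and ties the response to the request lifetime.
    async fn on_response<'r>(&self, _req: &'r Request<'_>, res: &mut Response<'r>) {
        res.set_raw_header("Server", "vaultwarden");
    }
}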
@@ -72,6 +73,7 @@ impl Cors {
     }
 }

+#[rocket::async_trait]
 impl Fairing for Cors {
     fn info(&self) -> Info {
         Info {
@@ -80,7 +82,7 @@ impl Fairing for Cors {
         }
     }

-    fn on_response(&self, request: &Request, response: &mut Response) {
+    async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
         let req_headers = request.headers();

         if let Some(origin) = Cors::get_allowed_origin(req_headers) {
@@ -97,7 +99,7 @@ impl Fairing for Cors {
             response.set_header(Header::new("Access-Control-Allow-Credentials", "true"));
             response.set_status(Status::Ok);
             response.set_header(ContentType::Plain);
-            response.set_sized_body(Cursor::new(""));
+            response.set_sized_body(Some(0), Cursor::new(""));
         }
     }
 }
@@ -134,25 +136,21 @@ impl<R> Cached<R> {
     }
 }

-impl<'r, R: Responder<'r>> Responder<'r> for Cached<R> {
-    fn respond_to(self, req: &Request) -> response::Result<'r> {
+impl<'r, R: 'r + Responder<'r, 'static> + Send> Responder<'r, 'static> for Cached<R> {
+    fn respond_to(self, request: &'r Request<'_>) -> response::Result<'static> {
+        let mut res = self.response.respond_to(request)?;
+
         let cache_control_header = if self.is_immutable {
             format!("public, immutable, max-age={}", self.ttl)
         } else {
             format!("public, max-age={}", self.ttl)
         };
+        res.set_raw_header("Cache-Control", cache_control_header);

         let time_now = chrono::Local::now();
-        match self.response.respond_to(req) {
-            Ok(mut res) => {
-                res.set_raw_header("Cache-Control", cache_control_header);
-                let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap());
-                res.set_raw_header("Expires", format_datetime_http(&expiry_time));
-                Ok(res)
-            }
-            e @ Err(_) => e,
-        }
+        let expiry_time = time_now + chrono::Duration::seconds(self.ttl.try_into().unwrap());
+        res.set_raw_header("Expires", format_datetime_http(&expiry_time));
+        Ok(res)
     }
 }

@@ -175,11 +173,9 @@ impl<'r> FromParam<'r> for SafeString {
     type Error = ();

     #[inline(always)]
-    fn from_param(param: &'r RawStr) -> Result<Self, Self::Error> {
-        let s = param.percent_decode().map(|cow| cow.into_owned()).map_err(|_| ())?;
-        if s.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
-            Ok(SafeString(s))
+    fn from_param(param: &'r str) -> Result<Self, Self::Error> {
+        if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
+            Ok(SafeString(param.to_string()))
         } else {
             Err(())
         }
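The FromParam change reflects rocket 0.5 handing implementations a plain &str segment instead of a RawStr, so the explicit percent_decode() step above disappears. A sketch of the new trait shape with an illustrative param type:

use rocket::request::FromParam;

struct Hex(String);

impl<'r> FromParam<'r> for Hex {
    type Error = ();

    // rocket 0.5 passes the path segment as &str; validate and own it.
    fn from_param(param: &'r str) -> Result<Self, Self::Error> {
        if !param.is_empty() && param.chars().all(|c| c.is_ascii_hexdigit()) {
            Ok(Hex(param.to_string()))
        } else {
            Err(())
        }
    }
}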
@@ -193,15 +189,16 @@ const LOGGED_ROUTES: [&str; 6] =

 // Boolean is extra debug, when true, we ignore the whitelist above and also print the mounts
 pub struct BetterLogging(pub bool);
+
+#[rocket::async_trait]
 impl Fairing for BetterLogging {
     fn info(&self) -> Info {
         Info {
             name: "Better Logging",
-            kind: Kind::Launch | Kind::Request | Kind::Response,
+            kind: Kind::Liftoff | Kind::Request | Kind::Response,
         }
     }

-    fn on_launch(&self, rocket: &Rocket) {
+    async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
         if self.0 {
             info!(target: "routes", "Routes loaded:");
             let mut routes: Vec<_> = rocket.routes().collect();
@@ -225,34 +222,36 @@ impl Fairing for BetterLogging {
         info!(target: "start", "Rocket has launched from {}", addr);
     }

-    fn on_request(&self, request: &mut Request<'_>, _data: &Data) {
+    async fn on_request(&self, request: &mut Request<'_>, _data: &mut Data<'_>) {
         let method = request.method();
         if !self.0 && method == Method::Options {
             return;
         }
         let uri = request.uri();
         let uri_path = uri.path();
-        let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path);
+        let uri_path_str = uri_path.url_decode_lossy();
+        let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
         if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
             match uri.query() {
-                Some(q) => info!(target: "request", "{} {}?{}", method, uri_path, &q[..q.len().min(30)]),
-                None => info!(target: "request", "{} {}", method, uri_path),
+                Some(q) => info!(target: "request", "{} {}?{}", method, uri_path_str, &q[..q.len().min(30)]),
+                None => info!(target: "request", "{} {}", method, uri_path_str),
             };
         }
     }

-    fn on_response(&self, request: &Request, response: &mut Response) {
+    async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
         if !self.0 && request.method() == Method::Options {
             return;
         }
         let uri_path = request.uri().path();
-        let uri_subpath = uri_path.strip_prefix(&CONFIG.domain_path()).unwrap_or(uri_path);
+        let uri_path_str = uri_path.url_decode_lossy();
+        let uri_subpath = uri_path_str.strip_prefix(&CONFIG.domain_path()).unwrap_or(&uri_path_str);
         if self.0 || LOGGED_ROUTES.iter().any(|r| uri_subpath.starts_with(r)) {
             let status = response.status();
-            if let Some(route) = request.route() {
-                info!(target: "response", "{} => {} {}", route, status.code, status.reason)
+            if let Some(ref route) = request.route() {
+                info!(target: "response", "{} => {}", route, status)
             } else {
-                info!(target: "response", "{} {}", status.code, status.reason)
+                info!(target: "response", "{}", status)
             }
         }
     }
@@ -614,10 +613,7 @@ where
     }
 }

-use reqwest::{
-    blocking::{Client, ClientBuilder},
-    header,
-};
+use reqwest::{header, Client, ClientBuilder};

 pub fn get_reqwest_client() -> Client {
     get_reqwest_client_builder().build().expect("Failed to build client")