diff --git a/src/auth.rs b/src/auth.rs index 4ee9c188..0f4a3076 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -1,13 +1,18 @@ // JWT Handling // use chrono::{TimeDelta, Utc}; +use jsonwebtoken::{errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header}; use num_traits::FromPrimitive; use once_cell::sync::{Lazy, OnceCell}; - -use jsonwebtoken::{errors::ErrorKind, Algorithm, DecodingKey, EncodingKey, Header}; use openssl::rsa::Rsa; use serde::de::DeserializeOwned; use serde::ser::Serialize; +use std::{ + env, + fs::File, + io::{Read, Write}, + net::IpAddr, +}; use crate::{error::Error, CONFIG}; @@ -31,27 +36,36 @@ static PRIVATE_RSA_KEY: OnceCell = OnceCell::new(); static PUBLIC_RSA_KEY: OnceCell = OnceCell::new(); pub fn initialize_keys() -> Result<(), crate::error::Error> { - let mut priv_key_buffer = Vec::with_capacity(2048); + fn read_key(create_if_missing: bool) -> Result<(Rsa, Vec), crate::error::Error> { + let mut priv_key_buffer = Vec::with_capacity(2048); - let priv_key = { - let mut priv_key_file = - File::options().create(true).truncate(false).read(true).write(true).open(CONFIG.private_rsa_key())?; + let mut priv_key_file = File::options() + .create(create_if_missing) + .truncate(false) + .read(true) + .write(create_if_missing) + .open(CONFIG.private_rsa_key())?; #[allow(clippy::verbose_file_reads)] let bytes_read = priv_key_file.read_to_end(&mut priv_key_buffer)?; - if bytes_read > 0 { + let rsa_key = if bytes_read > 0 { Rsa::private_key_from_pem(&priv_key_buffer[..bytes_read])? - } else { + } else if create_if_missing { // Only create the key if the file doesn't exist or is empty let rsa_key = openssl::rsa::Rsa::generate(2048)?; priv_key_buffer = rsa_key.private_key_to_pem()?; priv_key_file.write_all(&priv_key_buffer)?; - info!("Private key created correctly."); + info!("Private key '{}' created correctly", CONFIG.private_rsa_key()); rsa_key - } - }; + } else { + err!("Private key does not exist or invalid format", CONFIG.private_rsa_key()); + }; + Ok((rsa_key, priv_key_buffer)) + } + + let (priv_key, priv_key_buffer) = read_key(true).or_else(|_| read_key(false))?; let pub_key_buffer = priv_key.public_key_to_pem()?; let enc = EncodingKey::from_rsa_pem(&priv_key_buffer)?; @@ -803,12 +817,6 @@ impl<'r> FromRequest<'r> for OwnerHeaders { // // Client IP address detection // -use std::{ - env, - fs::File, - io::{Read, Write}, - net::IpAddr, -}; pub struct ClientIp { pub ip: IpAddr, diff --git a/src/main.rs b/src/main.rs index ecc4f320..8a3f0eb8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -73,11 +73,9 @@ async fn main() -> Result<(), Error> { }); init_logging(level).ok(); - let extra_debug = matches!(level, LF::Trace | LF::Debug); - check_data_folder().await; - auth::initialize_keys().unwrap_or_else(|_| { - error!("Error creating keys, exiting..."); + auth::initialize_keys().unwrap_or_else(|e| { + error!("Error creating private key '{}'\n{e:?}\nExiting Vaultwarden!", CONFIG.private_rsa_key()); exit(1); }); check_web_vault(); @@ -91,6 +89,7 @@ async fn main() -> Result<(), Error> { schedule_jobs(pool.clone()); crate::db::models::TwoFactor::migrate_u2f_to_webauthn(&mut pool.get().await.unwrap()).await.unwrap(); + let extra_debug = matches!(level, LF::Trace | LF::Debug); launch_rocket(pool, extra_debug).await // Blocks until program termination. } @@ -514,7 +513,7 @@ async fn launch_rocket(pool: db::DbPool, extra_debug: bool) -> Result<(), Error> tokio::spawn(async move { tokio::signal::ctrl_c().await.expect("Error setting Ctrl-C handler"); - info!("Exiting vaultwarden!"); + info!("Exiting Vaultwarden!"); CONFIG.shutdown(); });