Merge pull request #3806 from BlackDex/fix-3776

Allow Authorization header for Web Sockets
This commit is contained in:
Daniel García 2023-08-31 20:46:07 +02:00 committed by GitHub
commit f579a4154c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 33 additions and 2 deletions

View file

@ -20,7 +20,7 @@ use tokio_tungstenite::{
}; };
use crate::{ use crate::{
auth::ClientIp, auth::{ClientIp, WsAccessTokenHeader},
db::{ db::{
models::{Cipher, Folder, Send as DbSend, User}, models::{Cipher, Folder, Send as DbSend, User},
DbConn, DbConn,
@ -111,11 +111,19 @@ fn websockets_hub<'r>(
ws: rocket_ws::WebSocket, ws: rocket_ws::WebSocket,
data: WsAccessToken, data: WsAccessToken,
ip: ClientIp, ip: ClientIp,
header_token: WsAccessTokenHeader,
) -> Result<rocket_ws::Stream!['r], Error> { ) -> Result<rocket_ws::Stream!['r], Error> {
let addr = ip.ip; let addr = ip.ip;
info!("Accepting Rocket WS connection from {addr}"); info!("Accepting Rocket WS connection from {addr}");
let Some(token) = data.access_token else { err_code!("Invalid claim", 401) }; let token = if let Some(token) = data.access_token {
token
} else if let Some(token) = header_token.access_token {
token
} else {
err_code!("Invalid claim", 401)
};
let Ok(claims) = crate::auth::decode_login(&token) else { err_code!("Invalid token", 401) }; let Ok(claims) = crate::auth::decode_login(&token) else { err_code!("Invalid token", 401) };
let (mut rx, guard) = { let (mut rx, guard) = {

View file

@ -825,3 +825,26 @@ impl<'r> FromRequest<'r> for ClientIp {
}) })
} }
} }
pub struct WsAccessTokenHeader {
pub access_token: Option<String>,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for WsAccessTokenHeader {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
let headers = request.headers();
// Get access_token
let access_token = match headers.get_one("Authorization") {
Some(a) => a.rsplit("Bearer ").next().map(String::from),
None => None,
};
Outcome::Success(Self {
access_token,
})
}
}