From 0e83d48c0439edc88823e40a967be49d9c984f3f Mon Sep 17 00:00:00 2001 From: Evan Almloff Date: Mon, 10 Jul 2023 18:48:55 -0500 Subject: [PATCH] Fix salvo fullstack builds --- packages/fullstack/Cargo.toml | 2 +- .../examples/salvo-hello-world/src/main.rs | 2 - .../fullstack/src/adapters/salvo_adapter.rs | 207 +++++++++--------- packages/fullstack/src/launch.rs | 5 +- 4 files changed, 102 insertions(+), 114 deletions(-) diff --git a/packages/fullstack/Cargo.toml b/packages/fullstack/Cargo.toml index b1821eac6..6ba72b995 100644 --- a/packages/fullstack/Cargo.toml +++ b/packages/fullstack/Cargo.toml @@ -28,7 +28,7 @@ tower = { version = "0.4.13", features = ["util"], optional = true } axum-macros = "0.3.7" # salvo -salvo = { version = "0.37.7", optional = true, features = ["serve-static", "ws", "compression"] } +salvo = { version = "0.46.0", optional = true, features = ["serve-static", "websocket", "compression"] } serde = "1.0.159" # Dioxus + SSR diff --git a/packages/fullstack/examples/salvo-hello-world/src/main.rs b/packages/fullstack/examples/salvo-hello-world/src/main.rs index 84e190aab..9e5649efc 100644 --- a/packages/fullstack/examples/salvo-hello-world/src/main.rs +++ b/packages/fullstack/examples/salvo-hello-world/src/main.rs @@ -50,8 +50,6 @@ fn app(cx: Scope) -> Element { async fn post_server_data(data: String) -> Result<(), ServerFnError> { // The server context contains information about the current request and allows you to modify the response. let cx = server_context(); - cx.response_headers_mut() - .insert("Set-Cookie", "foo=bar".parse().unwrap()); println!("Server received: {}", data); println!("Request parts are {:?}", cx.request_parts()); diff --git a/packages/fullstack/src/adapters/salvo_adapter.rs b/packages/fullstack/src/adapters/salvo_adapter.rs index 5008f3f73..dc536376d 100644 --- a/packages/fullstack/src/adapters/salvo_adapter.rs +++ b/packages/fullstack/src/adapters/salvo_adapter.rs @@ -49,21 +49,31 @@ //! } //! ``` -use hyper::{http::HeaderValue, StatusCode}; +use http_body_util::{BodyExt, Limited}; +use hyper::body::Body as HyperBody; +use hyper::StatusCode; use salvo::{ async_trait, handler, + http::{ + cookie::{Cookie, CookieJar}, + ParseError, + }, serve_static::{StaticDir, StaticFile}, - Depot, FlowCtrl, Handler, Request, Response, Router, + Depot, Error as SalvoError, FlowCtrl, Handler, Request, Response, Router, }; -use server_fn::{Encoding, Payload, ServerFunctionRegistry}; +use server_fn::{Encoding, ServerFunctionRegistry}; use std::error::Error; use std::sync::Arc; -use tokio::task::spawn_blocking; +use std::sync::RwLock; use crate::{ - prelude::*, render::SSRState, serve_config::ServeConfig, server_fn::DioxusServerFnRegistry, + layer::Service, prelude::*, render::SSRState, serve_config::ServeConfig, + server_fn::DioxusServerFnRegistry, server_fn_service, }; +type HyperRequest = hyper::Request; +type HyperResponse = hyper::Response; + /// A extension trait with utilities for integrating Dioxus with your Salvo router. pub trait DioxusRouterExt { /// Registers server functions with a custom handler function. This allows you to pass custom context to your server functions by generating a [`DioxusServerContext`] from the request. @@ -297,13 +307,71 @@ impl DioxusRouterExt for Router { } /// Extracts the parts of a request that are needed for server functions. This will take parts of the request and replace them with empty values. -pub fn extract_parts(req: &mut Request) -> RequestParts { - RequestParts { - method: std::mem::take(req.method_mut()), - uri: std::mem::take(req.uri_mut()), - version: req.version(), - headers: std::mem::take(req.headers_mut()), - extensions: std::mem::take(req.extensions_mut()), +pub fn extract_parts(req: &mut Request) -> http::request::Parts { + let mut parts = http::request::Request::new(()).into_parts().0; + parts.method = std::mem::take(req.method_mut()); + parts.uri = std::mem::take(req.uri_mut()); + parts.version = req.version(); + parts.headers = std::mem::take(req.headers_mut()); + parts.extensions = std::mem::take(req.extensions_mut()); + + parts +} + +fn apply_request_parts_to_response( + headers: hyper::header::HeaderMap, + response: &mut salvo::prelude::Response, +) { + let mut_headers = response.headers_mut(); + for (key, value) in headers.iter() { + mut_headers.insert(key, value.clone()); + } +} + +#[inline] +async fn convert_request(req: &mut Request) -> Result { + let forward_url: hyper::Uri = TryFrom::try_from(req.uri()).map_err(SalvoError::other)?; + let mut build = hyper::Request::builder() + .method(req.method()) + .uri(&forward_url); + for (key, value) in req.headers() { + build = build.header(key, value); + } + static SECURE_MAX_SIZE: usize = 64 * 1024; + + let body = Limited::new(req.take_body(), SECURE_MAX_SIZE) + .collect() + .await + .map_err(ParseError::other)? + .to_bytes(); + build.body(body.into()).map_err(SalvoError::other) +} + +#[inline] +async fn convert_response(response: HyperResponse, res: &mut Response) { + let (parts, body) = response.into_parts(); + let http::response::Parts { + version, + headers, + status, + .. + } = parts; + res.status_code = Some(status); + res.version = version; + res.cookies = CookieJar::new(); + for cookie in headers.get_all(http::header::SET_COOKIE).iter() { + if let Some(cookie) = cookie + .to_str() + .ok() + .and_then(|s| Cookie::parse(s.to_string()).ok()) + { + res.cookies.add_original(cookie); + } + } + res.headers = headers; + res.version = version; + if let Ok(bytes) = hyper::body::to_bytes(body).await { + res.body = bytes.into() } } @@ -328,8 +396,9 @@ impl Handler for SSRHandler depot.inject(renderer.clone()); renderer }; - let parts: Arc = Arc::new(extract_parts(req)); - let route = parts.uri.path().to_string(); + + let route = req.uri().path().to_string(); + let parts: Arc> = Arc::new(RwLock::new(extract_parts(req))); let server_context = DioxusServerContext::new(parts); match renderer_pool @@ -341,7 +410,8 @@ impl Handler for SSRHandler res.write_body(html).unwrap(); - *res.headers_mut() = server_context.take_response_headers(); + let headers = server_context.response_parts().unwrap().headers.clone(); + apply_request_parts_to_response(headers, res); freshness.write(res.headers_mut()); } Err(err) => { @@ -375,95 +445,14 @@ impl ServerFnHandler { #[handler] impl ServerFnHandler { async fn handle(&self, req: &mut Request, _depot: &mut Depot, res: &mut Response) { - let Self { - server_context, - function, - } = self; - - let query = req - .uri() - .query() - .unwrap_or_default() - .as_bytes() - .to_vec() - .into(); - let body = hyper::body::to_bytes(req.body_mut().unwrap()).await; - let Ok(body)=body else { - handle_error(body.err().unwrap(), res); - return; - }; - let headers = req.headers(); - let accept_header = headers.get("Accept").cloned(); - - let parts = Arc::new(extract_parts(req)); - - // Because the future returned by `server_fn_handler` is `Send`, and the future returned by this function must be send, we need to spawn a new runtime - let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); - spawn_blocking({ - let function = function.clone(); - let mut server_context = server_context.clone(); - server_context.parts = parts; - move || { - tokio::runtime::Runtime::new() - .expect("couldn't spawn runtime") - .block_on(async move { - let data = match function.encoding() { - Encoding::Url | Encoding::Cbor => &body, - Encoding::GetJSON | Encoding::GetCBOR => &query, - }; - let server_function_future = function.call((), data); - let server_function_future = ProvideServerContext::new( - server_function_future, - server_context.clone(), - ); - let resp = server_function_future.await; - - resp_tx.send(resp).unwrap(); - }) - } - }); - let result = resp_rx.await.unwrap(); - - // Set the headers from the server context - *res.headers_mut() = server_context.take_response_headers(); - - match result { - Ok(serialized) => { - // if this is Accept: application/json then send a serialized JSON response - let accept_header = accept_header.as_ref().and_then(|value| value.to_str().ok()); - if accept_header == Some("application/json") - || accept_header - == Some( - "application/\ - x-www-form-urlencoded", - ) - || accept_header == Some("application/cbor") - { - res.set_status_code(StatusCode::OK); - } - - match serialized { - Payload::Binary(data) => { - res.headers_mut() - .insert("Content-Type", HeaderValue::from_static("application/cbor")); - res.write_body(data).unwrap(); - } - Payload::Url(data) => { - res.headers_mut().insert( - "Content-Type", - HeaderValue::from_static( - "application/\ - x-www-form-urlencoded", - ), - ); - res.write_body(data).unwrap(); - } - Payload::Json(data) => { - res.headers_mut() - .insert("Content-Type", HeaderValue::from_static("application/json")); - res.write_body(data).unwrap(); - } - } + match convert_request(req).await { + Ok(hyper_req) => { + let response = + server_fn_service(self.server_context.clone(), self.function.clone()) + .run(hyper_req) + .await + .unwrap(); + convert_response(response, res).await; } Err(err) => handle_error(err, res), } @@ -472,7 +461,7 @@ impl ServerFnHandler { fn handle_error(error: impl Error + Send + Sync, res: &mut Response) { let mut resp_err = Response::new(); - resp_err.set_status_code(StatusCode::INTERNAL_SERVER_ERROR); + resp_err.status_code(StatusCode::INTERNAL_SERVER_ERROR); resp_err.render(format!("Internal Server Error: {}", error)); *res = resp_err; } @@ -509,8 +498,8 @@ impl HotReloadHandler { _depot: &mut Depot, res: &mut Response, ) -> Result<(), salvo::http::StatusError> { - use salvo::ws::Message; - use salvo::ws::WebSocketUpgrade; + use salvo::websocket::Message; + use salvo::websocket::WebSocketUpgrade; let state = crate::hot_reload::spawn_hot_reload().await; @@ -557,10 +546,10 @@ impl HotReloadHandler { #[cfg(all(debug_assertions, feature = "hot-reload", feature = "ssr"))] #[handler] async fn ignore_ws(req: &mut Request, res: &mut Response) -> Result<(), salvo::http::StatusError> { - use salvo::ws::WebSocketUpgrade; + use salvo::websocket::WebSocketUpgrade; WebSocketUpgrade::new() .upgrade(req, res, |mut ws| async move { - let _ = ws.send(salvo::ws::Message::text("connected")).await; + let _ = ws.send(salvo::websocket::Message::text("connected")).await; while let Some(msg) = ws.recv().await { if msg.is_err() { return; diff --git a/packages/fullstack/src/launch.rs b/packages/fullstack/src/launch.rs index ced609cf4..6432e69ce 100644 --- a/packages/fullstack/src/launch.rs +++ b/packages/fullstack/src/launch.rs @@ -161,11 +161,12 @@ pub async fn launch_server( #[cfg(all(feature = "salvo", not(feature = "axum"), not(feature = "warp")))] { use crate::adapters::salvo_adapter::DioxusRouterExt; + use salvo::conn::Listener; let router = salvo::Router::new().serve_dioxus_application("", cfg).hoop( salvo::compression::Compression::new() - .with_algos(&[salvo::prelude::CompressionAlgo::Gzip]), + .enable_gzip(salvo::prelude::CompressionLevel::Default), ); - salvo::Server::new(salvo::listener::TcpListener::bind(addr)) + salvo::Server::new(salvo::conn::tcp::TcpListener::new(addr).bind().await) .serve(router) .await; }