Fix salvo fullstack builds

This commit is contained in:
Evan Almloff 2023-07-10 18:48:55 -05:00
parent 73a45fb113
commit 0e83d48c04
4 changed files with 102 additions and 114 deletions

View file

@ -28,7 +28,7 @@ tower = { version = "0.4.13", features = ["util"], optional = true }
axum-macros = "0.3.7"
# salvo
salvo = { version = "0.37.7", optional = true, features = ["serve-static", "ws", "compression"] }
salvo = { version = "0.46.0", optional = true, features = ["serve-static", "websocket", "compression"] }
serde = "1.0.159"
# Dioxus + SSR

View file

@ -50,8 +50,6 @@ fn app(cx: Scope<AppProps>) -> Element {
async fn post_server_data(data: String) -> Result<(), ServerFnError> {
// The server context contains information about the current request and allows you to modify the response.
let cx = server_context();
cx.response_headers_mut()
.insert("Set-Cookie", "foo=bar".parse().unwrap());
println!("Server received: {}", data);
println!("Request parts are {:?}", cx.request_parts());

View file

@ -49,21 +49,31 @@
//! }
//! ```
use hyper::{http::HeaderValue, StatusCode};
use http_body_util::{BodyExt, Limited};
use hyper::body::Body as HyperBody;
use hyper::StatusCode;
use salvo::{
async_trait, handler,
http::{
cookie::{Cookie, CookieJar},
ParseError,
},
serve_static::{StaticDir, StaticFile},
Depot, FlowCtrl, Handler, Request, Response, Router,
Depot, Error as SalvoError, FlowCtrl, Handler, Request, Response, Router,
};
use server_fn::{Encoding, Payload, ServerFunctionRegistry};
use server_fn::{Encoding, ServerFunctionRegistry};
use std::error::Error;
use std::sync::Arc;
use tokio::task::spawn_blocking;
use std::sync::RwLock;
use crate::{
prelude::*, render::SSRState, serve_config::ServeConfig, server_fn::DioxusServerFnRegistry,
layer::Service, prelude::*, render::SSRState, serve_config::ServeConfig,
server_fn::DioxusServerFnRegistry, server_fn_service,
};
type HyperRequest = hyper::Request<hyper::Body>;
type HyperResponse = hyper::Response<HyperBody>;
/// A extension trait with utilities for integrating Dioxus with your Salvo router.
pub trait DioxusRouterExt {
/// Registers server functions with a custom handler function. This allows you to pass custom context to your server functions by generating a [`DioxusServerContext`] from the request.
@ -297,13 +307,71 @@ impl DioxusRouterExt for Router {
}
/// Extracts the parts of a request that are needed for server functions. This will take parts of the request and replace them with empty values.
pub fn extract_parts(req: &mut Request) -> RequestParts {
RequestParts {
method: std::mem::take(req.method_mut()),
uri: std::mem::take(req.uri_mut()),
version: req.version(),
headers: std::mem::take(req.headers_mut()),
extensions: std::mem::take(req.extensions_mut()),
pub fn extract_parts(req: &mut Request) -> http::request::Parts {
let mut parts = http::request::Request::new(()).into_parts().0;
parts.method = std::mem::take(req.method_mut());
parts.uri = std::mem::take(req.uri_mut());
parts.version = req.version();
parts.headers = std::mem::take(req.headers_mut());
parts.extensions = std::mem::take(req.extensions_mut());
parts
}
fn apply_request_parts_to_response(
headers: hyper::header::HeaderMap,
response: &mut salvo::prelude::Response,
) {
let mut_headers = response.headers_mut();
for (key, value) in headers.iter() {
mut_headers.insert(key, value.clone());
}
}
#[inline]
async fn convert_request(req: &mut Request) -> Result<HyperRequest, SalvoError> {
let forward_url: hyper::Uri = TryFrom::try_from(req.uri()).map_err(SalvoError::other)?;
let mut build = hyper::Request::builder()
.method(req.method())
.uri(&forward_url);
for (key, value) in req.headers() {
build = build.header(key, value);
}
static SECURE_MAX_SIZE: usize = 64 * 1024;
let body = Limited::new(req.take_body(), SECURE_MAX_SIZE)
.collect()
.await
.map_err(ParseError::other)?
.to_bytes();
build.body(body.into()).map_err(SalvoError::other)
}
#[inline]
async fn convert_response(response: HyperResponse, res: &mut Response) {
let (parts, body) = response.into_parts();
let http::response::Parts {
version,
headers,
status,
..
} = parts;
res.status_code = Some(status);
res.version = version;
res.cookies = CookieJar::new();
for cookie in headers.get_all(http::header::SET_COOKIE).iter() {
if let Some(cookie) = cookie
.to_str()
.ok()
.and_then(|s| Cookie::parse(s.to_string()).ok())
{
res.cookies.add_original(cookie);
}
}
res.headers = headers;
res.version = version;
if let Ok(bytes) = hyper::body::to_bytes(body).await {
res.body = bytes.into()
}
}
@ -328,8 +396,9 @@ impl<P: Clone + serde::Serialize + Send + Sync + 'static> Handler for SSRHandler
depot.inject(renderer.clone());
renderer
};
let parts: Arc<RequestParts> = Arc::new(extract_parts(req));
let route = parts.uri.path().to_string();
let route = req.uri().path().to_string();
let parts: Arc<RwLock<http::request::Parts>> = Arc::new(RwLock::new(extract_parts(req)));
let server_context = DioxusServerContext::new(parts);
match renderer_pool
@ -341,7 +410,8 @@ impl<P: Clone + serde::Serialize + Send + Sync + 'static> Handler for SSRHandler
res.write_body(html).unwrap();
*res.headers_mut() = server_context.take_response_headers();
let headers = server_context.response_parts().unwrap().headers.clone();
apply_request_parts_to_response(headers, res);
freshness.write(res.headers_mut());
}
Err(err) => {
@ -375,95 +445,14 @@ impl ServerFnHandler {
#[handler]
impl ServerFnHandler {
async fn handle(&self, req: &mut Request, _depot: &mut Depot, res: &mut Response) {
let Self {
server_context,
function,
} = self;
let query = req
.uri()
.query()
.unwrap_or_default()
.as_bytes()
.to_vec()
.into();
let body = hyper::body::to_bytes(req.body_mut().unwrap()).await;
let Ok(body)=body else {
handle_error(body.err().unwrap(), res);
return;
};
let headers = req.headers();
let accept_header = headers.get("Accept").cloned();
let parts = Arc::new(extract_parts(req));
// Because the future returned by `server_fn_handler` is `Send`, and the future returned by this function must be send, we need to spawn a new runtime
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
spawn_blocking({
let function = function.clone();
let mut server_context = server_context.clone();
server_context.parts = parts;
move || {
tokio::runtime::Runtime::new()
.expect("couldn't spawn runtime")
.block_on(async move {
let data = match function.encoding() {
Encoding::Url | Encoding::Cbor => &body,
Encoding::GetJSON | Encoding::GetCBOR => &query,
};
let server_function_future = function.call((), data);
let server_function_future = ProvideServerContext::new(
server_function_future,
server_context.clone(),
);
let resp = server_function_future.await;
resp_tx.send(resp).unwrap();
})
}
});
let result = resp_rx.await.unwrap();
// Set the headers from the server context
*res.headers_mut() = server_context.take_response_headers();
match result {
Ok(serialized) => {
// if this is Accept: application/json then send a serialized JSON response
let accept_header = accept_header.as_ref().and_then(|value| value.to_str().ok());
if accept_header == Some("application/json")
|| accept_header
== Some(
"application/\
x-www-form-urlencoded",
)
|| accept_header == Some("application/cbor")
{
res.set_status_code(StatusCode::OK);
}
match serialized {
Payload::Binary(data) => {
res.headers_mut()
.insert("Content-Type", HeaderValue::from_static("application/cbor"));
res.write_body(data).unwrap();
}
Payload::Url(data) => {
res.headers_mut().insert(
"Content-Type",
HeaderValue::from_static(
"application/\
x-www-form-urlencoded",
),
);
res.write_body(data).unwrap();
}
Payload::Json(data) => {
res.headers_mut()
.insert("Content-Type", HeaderValue::from_static("application/json"));
res.write_body(data).unwrap();
}
}
match convert_request(req).await {
Ok(hyper_req) => {
let response =
server_fn_service(self.server_context.clone(), self.function.clone())
.run(hyper_req)
.await
.unwrap();
convert_response(response, res).await;
}
Err(err) => handle_error(err, res),
}
@ -472,7 +461,7 @@ impl ServerFnHandler {
fn handle_error(error: impl Error + Send + Sync, res: &mut Response) {
let mut resp_err = Response::new();
resp_err.set_status_code(StatusCode::INTERNAL_SERVER_ERROR);
resp_err.status_code(StatusCode::INTERNAL_SERVER_ERROR);
resp_err.render(format!("Internal Server Error: {}", error));
*res = resp_err;
}
@ -509,8 +498,8 @@ impl HotReloadHandler {
_depot: &mut Depot,
res: &mut Response,
) -> Result<(), salvo::http::StatusError> {
use salvo::ws::Message;
use salvo::ws::WebSocketUpgrade;
use salvo::websocket::Message;
use salvo::websocket::WebSocketUpgrade;
let state = crate::hot_reload::spawn_hot_reload().await;
@ -557,10 +546,10 @@ impl HotReloadHandler {
#[cfg(all(debug_assertions, feature = "hot-reload", feature = "ssr"))]
#[handler]
async fn ignore_ws(req: &mut Request, res: &mut Response) -> Result<(), salvo::http::StatusError> {
use salvo::ws::WebSocketUpgrade;
use salvo::websocket::WebSocketUpgrade;
WebSocketUpgrade::new()
.upgrade(req, res, |mut ws| async move {
let _ = ws.send(salvo::ws::Message::text("connected")).await;
let _ = ws.send(salvo::websocket::Message::text("connected")).await;
while let Some(msg) = ws.recv().await {
if msg.is_err() {
return;

View file

@ -161,11 +161,12 @@ pub async fn launch_server<P: Clone + serde::Serialize + Send + Sync + 'static>(
#[cfg(all(feature = "salvo", not(feature = "axum"), not(feature = "warp")))]
{
use crate::adapters::salvo_adapter::DioxusRouterExt;
use salvo::conn::Listener;
let router = salvo::Router::new().serve_dioxus_application("", cfg).hoop(
salvo::compression::Compression::new()
.with_algos(&[salvo::prelude::CompressionAlgo::Gzip]),
.enable_gzip(salvo::prelude::CompressionLevel::Default),
);
salvo::Server::new(salvo::listener::TcpListener::bind(addr))
salvo::Server::new(salvo::conn::tcp::TcpListener::new(addr).bind().await)
.serve(router)
.await;
}