handle client-side and server-side redirects correctly (in Axum)

This commit is contained in:
Greg Johnston 2024-01-06 11:46:47 -05:00
parent 566df034ff
commit 4e3f1c834c
13 changed files with 111 additions and 16 deletions

View file

@ -77,6 +77,7 @@ pub async fn add_todo(title: String) -> Result<(), ServerFnError> {
pub async fn delete_todo(id: u16) -> Result<(), ServerFnError> {
let mut conn = db().await?;
leptos_axum::redirect("/foo");
Ok(sqlx::query("DELETE FROM todos WHERE id = $1")
.bind(id)
.execute(&mut conn)

View file

@ -37,7 +37,7 @@ use axum::{
body::{Body, Bytes},
extract::{FromRef, FromRequestParts, MatchedPath},
http::{
header::{self, HeaderName, HeaderValue},
header::{self, HeaderName, HeaderValue, ACCEPT, LOCATION, REFERER},
request::Parts,
HeaderMap, Method, Request, Response, StatusCode,
},
@ -54,6 +54,7 @@ use leptos_meta::{generate_head_metadata_separated, MetaContext};
use leptos_router::*;
use once_cell::sync::OnceCell;
use parking_lot::RwLock;
use server_fn::redirect::REDIRECT_HEADER;
use std::{
error::Error, fmt::Debug, io, pin::Pin, sync::Arc,
thread::available_parallelism,
@ -114,13 +115,40 @@ impl ResponseOptions {
/// it sets a StatusCode of 302 and a LOCATION header with the provided value.
/// If looking to redirect from the client, `leptos_router::use_navigate()` should be used instead
pub fn redirect(path: &str) {
if let Some(response_options) = use_context::<ResponseOptions>() {
response_options.set_status(StatusCode::FOUND);
response_options.insert_header(
if let (Some(req), Some(res)) =
(use_context::<Parts>(), use_context::<ResponseOptions>())
{
// insert the Location header in any case
res.insert_header(
header::LOCATION,
header::HeaderValue::from_str(path)
.expect("Failed to create HeaderValue"),
);
let accepts_html = req
.headers
.get(ACCEPT)
.and_then(|v| v.to_str().ok())
.map(|v| v.contains("text/html"))
.unwrap_or(false);
if accepts_html {
// if the request accepts text/html, it's a plain form request and needs
// to have the 302 code set
res.set_status(StatusCode::FOUND);
} else {
// otherwise, we sent it from the server fn client and actually don't want
// to set a real redirect, as this will break the ability to return data
// instead, set the REDIRECT_HEADER to indicate that the client should redirect
res.insert_header(
HeaderName::from_static(REDIRECT_HEADER),
HeaderValue::from_str("").unwrap(),
);
}
} else {
tracing::warn!(
"Couldn't retrieve either Parts or ResponseOptions while trying \
to redirect()."
);
}
}
@ -244,13 +272,36 @@ async fn handle_server_fns_inner(
provide_context(parts);
provide_context(ResponseOptions::default());
// store Accepts and Referer in case we need them for redirect (below)
let accepts_html = req
.headers()
.get(ACCEPT)
.and_then(|v| v.to_str().ok())
.map(|v| v.contains("text/html"))
.unwrap_or(false);
let referrer = req.headers().get(REFERER).cloned();
// actually run the server fn
let mut res = service.run(req).await;
// update response as needed
let res_options = expect_context::<ResponseOptions>().0;
let res_options_inner = res_options.read();
let (status, mut res_headers) =
(res_options_inner.status, res_options_inner.headers.clone());
// it it accepts text/html (i.e., is a plain form post) and doesn't already have a
// Location set, then redirect to to Referer
if accepts_html {
if let Some(referrer) = referrer {
let has_location = res.headers().get(LOCATION).is_some();
if !has_location {
*res.status_mut() = StatusCode::FOUND;
res.headers_mut().insert(LOCATION, referrer);
}
}
}
// apply status code and headers if used changed them
if let Some(status) = status {
*res.status_mut() = status;

View file

@ -31,6 +31,7 @@ wasm-bindgen-futures = { version = "0.4" }
lru = { version = "0.11", optional = true }
serde_json = "1.0.96"
itertools = "0.12.0"
send_wrapper = "0.6.0"
[dependencies.web-sys]
version = "0.3"

View file

@ -8,15 +8,16 @@ use leptos::{
server_fn::{
client::Client,
codec::{Encoding, PostUrl},
redirect::RedirectHook,
request::ClientReq,
ServerFn,
},
*,
};
use send_wrapper::SendWrapper;
use serde::{de::DeserializeOwned, Serialize};
use std::{error::Error, fmt::Debug, rc::Rc};
use wasm_bindgen::{JsCast, UnwrapThrowExt};
use wasm_bindgen_futures::JsFuture;
use web_sys::{
FormData, HtmlButtonElement, HtmlFormElement, HtmlInputElement,
RequestRedirect, SubmitEvent,
@ -454,6 +455,7 @@ where
ServFn:
Clone + DeserializeOwned + ServerFn<InputEncoding = PostUrl> + 'static,
ServerFnError<ServFn::Error>: Debug + Clone,
ServFn::Output: Debug,
ServFn::Error: Debug + 'static,
<<ServFn::Client as Client<ServFn::Error>>::Request as ClientReq<
ServFn::Error,
@ -485,10 +487,24 @@ where
ev.prevent_default();
let navigate = has_router.then(use_navigate);
let navigate_options = NavigateOptions {
let navigate_options = SendWrapper::new(NavigateOptions {
scroll: !noscroll,
..Default::default()
};
});
let redirect_hook = navigate.map(|navigate| {
let navigate = SendWrapper::new(navigate);
Box::new(move |path: &str| {
let path = path.to_string();
// delay by a tick here, so that the Action updates *before* the redirect
request_animation_frame({
let navigate = navigate.clone();
let navigate_options = navigate_options.clone();
move || {
navigate(&path, navigate_options.take());
}
});
}) as RedirectHook
});
let form =
form_from_event(&ev).expect("couldn't find form submitter");
@ -504,10 +520,14 @@ where
match req {
Ok(req) => {
spawn_local(async move {
// TODO set input
// TODO check order of setting things here, and use batch as needed
// TODO set version?
match <ServFn as ServerFn>::run_on_client_with_req(req)
.await
match <ServFn as ServerFn>::run_on_client_with_req(
req,
redirect_hook.as_ref(),
)
.await
{
Ok(res) => {
batch(move || {

View file

@ -190,7 +190,7 @@ pub fn use_resolved_path(
/// # runtime.dispose();
/// ```
#[track_caller]
pub fn use_navigate() -> impl Fn(&str, NavigateOptions) {
pub fn use_navigate() -> impl Fn(&str, NavigateOptions) + Clone {
let router = use_router();
move |to, options| {
let router = Rc::clone(&router.inner);

View file

@ -23,6 +23,7 @@ use error::ServerFnErrorSerde;
use http::Method;
use middleware::{Layer, Service};
use once_cell::sync::Lazy;
use redirect::RedirectHook;
use request::Req;
use response::{ClientRes, Res};
#[doc(hidden)]
@ -104,12 +105,14 @@ where
// create and send request on client
let req =
self.into_req(Self::PATH, Self::OutputEncoding::CONTENT_TYPE)?;
Self::run_on_client_with_req(req).await
Self::run_on_client_with_req(req, redirect::REDIRECT_HOOK.get())
.await
}
}
fn run_on_client_with_req(
req: <Self::Client as Client<Self::Error>>::Request,
redirect_hook: Option<&RedirectHook>,
) -> impl Future<Output = Result<Self::Output, ServerFnError<Self::Error>>> + Send
{
async move {
@ -117,6 +120,7 @@ where
let status = res.status();
let location = res.location();
let has_redirect_header = res.has_redirect();
// if it returns an error status, deserialize the error using FromStr
let res = if (400..=599).contains(&status) {
@ -128,8 +132,10 @@ where
}?;
// if redirected, call the redirect hook (if that's been set)
if (300..=399).contains(&status) {
redirect::call_redirect_hook(&location);
if let Some(redirect_hook) = redirect_hook {
if (300..=399).contains(&status) || has_redirect_header {
redirect_hook(&location);
}
}
res
}

View file

@ -1,10 +1,12 @@
use std::sync::OnceLock;
pub const REDIRECT_HEADER: &str = "serverfnredirect";
pub type RedirectHook = Box<dyn Fn(&str) + Send + Sync>;
// allowed: not in a public API, and pretty straightforward
#[allow(clippy::type_complexity)]
static REDIRECT_HOOK: OnceLock<RedirectHook> = OnceLock::new();
pub(crate) static REDIRECT_HOOK: OnceLock<RedirectHook> = OnceLock::new();
pub fn set_redirect_hook(
hook: impl Fn(&str) + Send + Sync + 'static,

View file

@ -25,6 +25,7 @@ impl<CustErr> Req<CustErr> for Request<Body> {
}
async fn try_into_string(self) -> Result<String, ServerFnError<CustErr>> {
println!("accepts = {:?}", self.headers().get(http::header::ACCEPT));
let bytes = self.try_into_bytes().await?;
String::from_utf8(bytes.to_vec())
.map_err(|e| ServerFnError::Deserialization(e.to_string()))

View file

@ -1,5 +1,5 @@
use super::ClientRes;
use crate::error::ServerFnError;
use crate::{error::ServerFnError, redirect::REDIRECT_HEADER};
use bytes::Bytes;
use futures::{Stream, StreamExt};
pub use gloo_net::http::Response;
@ -73,4 +73,8 @@ impl<CustErr> ClientRes<CustErr> for BrowserResponse {
.get("Location")
.unwrap_or_else(|| self.0.url())
}
fn has_redirect(&self) -> bool {
self.0.headers().get(REDIRECT_HEADER).is_some()
}
}

View file

@ -68,6 +68,9 @@ pub trait ClientRes<CustErr> {
/// The `Location` header or (if none is set), the URL of the response.
fn location(&self) -> String;
/// Whether the response has the [`REDIRECT_HEADER`](crate::redirect::REDIRECT_HEADER) set.
fn has_redirect(&self) -> bool;
}
/// A mocked response type that can be used in place of the actual server response,

View file

@ -42,4 +42,8 @@ impl<CustErr> ClientRes<CustErr> for Response {
.map(|value| String::from_utf8_lossy(value.as_bytes()).to_string())
.unwrap_or_else(|| self.url().to_string())
}
fn has_redirect(&self) -> bool {
self.headers().get("Location").is_some()
}
}

View file

@ -270,7 +270,9 @@ pub fn server_macro_impl(
})
}
} else {
body
quote! { async move {
#body
}}
};
quote! {
fn run_body(self) -> impl std::future::Future<Output = #return_ty> + Send {