First commit, checkpoint for cyclical dependency error

This commit is contained in:
benwis 2023-12-31 12:09:38 -08:00 committed by Greg Johnston
parent cadd217078
commit c7422cd96e
37 changed files with 4173 additions and 1109 deletions

View file

@ -11,7 +11,7 @@ members = [
"leptos_server",
"server_fn",
"server_fn_macro",
"server_fn/server_fn_macro_default",
"server_fn_macro_default",
# integrations
"integrations/actix",

View file

@ -5,11 +5,11 @@
"systems": "systems"
},
"locked": {
"lastModified": 1681202837,
"narHash": "sha256-H+Rh19JDwRtpVPAWp64F+rlEtxUWBAQW28eAi3SRSzg=",
"lastModified": 1701680307,
"narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "cfacdce06f30d2b68473a46042957675eebb3401",
"rev": "4022d587cbbfd70fe950c1e2083a02621806a725",
"type": "github"
},
"original": {
@ -38,11 +38,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1672580127,
"narHash": "sha256-3lW3xZslREhJogoOkjeZtlBtvFMyxHku7I/9IVehhT8=",
"lastModified": 1685573264,
"narHash": "sha256-Zffu01pONhs/pqH07cjlF10NnMDLok8ix5Uk4rhOnZQ=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "0874168639713f547c05947c76124f78441ea46c",
"rev": "380be19fbd2d9079f677978361792cb25e8a3635",
"type": "github"
},
"original": {
@ -67,11 +67,11 @@
]
},
"locked": {
"lastModified": 1681525152,
"narHash": "sha256-KzI+ILcmU03iFWtB+ysPqtNmp8TP8v1BBReTuPP8MJY=",
"lastModified": 1703902408,
"narHash": "sha256-qXdWvu+tlgNjeoz8yQMRKSom6QyRROfgpmeOhwbujqw=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "b6f8d87208336d7cb85003b2e439fc707c38f92a",
"rev": "319f57cd2c34348c55970a4bf2b35afe82088681",
"type": "github"
},
"original": {

View file

@ -38,11 +38,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1702312524,
"narHash": "sha256-gkZJRDBUCpTPBvQk25G0B7vfbpEYM5s5OZqghkjZsnE=",
"lastModified": 1703637592,
"narHash": "sha256-8MXjxU0RfFfzl57Zy3OfXCITS0qWDNLzlBAdwxGZwfY=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "a9bf124c46ef298113270b1f84a164865987a91c",
"rev": "cfc3698c31b1fb9cdcf10f36c9643460264d0ca8",
"type": "github"
},
"original": {
@ -81,11 +81,11 @@
"nixpkgs": "nixpkgs_2"
},
"locked": {
"lastModified": 1702433821,
"narHash": "sha256-Kxv+dRbzj1fLQG0fyF/H6nswda6cN48r6kjctysnY4o=",
"lastModified": 1703902408,
"narHash": "sha256-qXdWvu+tlgNjeoz8yQMRKSom6QyRROfgpmeOhwbujqw=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "cb9016d3a569100a609bb92c0a45beb9e23cd4eb",
"rev": "319f57cd2c34348c55970a4bf2b35afe82088681",
"type": "github"
},
"original": {

View file

@ -0,0 +1,24 @@
[package]
name = "leptos_pavex"
version = { workspace = true }
edition = "2021"
authors = ["Ben Wishovich"]
license = "MIT"
repository = "https://github.com/leptos-rs/leptos"
description = "Pavex integrations for the Leptos web framework."
[dependencies]
pavex = { git = "https://github.com/LukeMathWalker/pavex", branch = "main" }
futures = "0.3"
leptos = { workspace = true, features = ["ssr"] }
leptos_meta = { workspace = true, features = ["ssr"] }
leptos_router = { workspace = true, features = ["ssr"] }
leptos_integration_utils = { workspace = true }
parking_lot = "0.12.1"
regex = "1.7.0"
tracing = "0.1.37"
tokio = { version = "1", features = ["full"] }
[features]
nonce = ["leptos/nonce"]
experimental-islands = ["leptos_integration_utils/experimental-islands"]

File diff suppressed because it is too large Load diff

View file

@ -1,42 +1,92 @@
[package]
name = "server_fn"
version = { workspace = true }
version = "0.6.0"
edition = "2021"
authors = ["Greg Johnston"]
license = "MIT"
repository = "https://github.com/leptos-rs/leptos"
description = "RPC for any web framework."
readme = "../README.md"
[dependencies]
server_fn_macro_default = { workspace = true }
serde = { version = "1", features = ["derive"] }
serde_qs = "0.12"
thiserror = "1"
serde_json = "1"
quote = "1"
syn = { version = "2", features = ["full", "parsing", "extra-traits"] }
proc-macro2 = "1"
ciborium = "0.2"
xxhash-rust = { version = "0.8", features = ["const_xxh64"] }
# used for hashing paths in #[server] macro
const_format = "0.2"
inventory = { version = "0.3", optional = true }
lazy_static = "1"
xxhash-rust = { version = "0.8", features = ["const_xxh64"] }
[dev-dependencies]
server_fn = { version = "0.2" }
serde = "1"
# used across multiple featurs
serde = { version = "1", features = ["derive"] }
send_wrapper = { version = "0.6", features = ["futures"], optional = true }
[target.'cfg(target_arch = "wasm32")'.dependencies]
gloo-net = "0.2"
js-sys = "0.3"
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
reqwest = { version = "0.11", default-features = false }
# registration system
inventory = {version="0.3",optional=true}
dashmap = "5"
once_cell = "1"
## servers
# actix
actix-web = { version = "4", optional = true }
# axum
axum = { version = "0.7", optional = true, features = ["multipart"] }
tower = { version = "0.4", optional = true }
tower-layer = { version = "0.3", optional = true }
## input encodings
serde_qs = { version = "0.12", optional = true }
multer = { version = "3", optional = true }
## output encodings
# serde
serde_json = { version = "1", optional = true }
futures = "0.3"
http = { version = "1", optional = true }
ciborium = { version = "0.2", optional = true }
hyper = { version = "1", optional = true }
bytes = "1"
thiserror = "1"
http-body-util = { version = "0.1.0", optional = true }
rkyv = { version = "0.7", features = [
"validation",
"uuid",
"strict",
], optional = true }
# client
gloo-net = { version = "0.5", optional = true }
js-sys = { version = "0.3", optional = true }
wasm-bindgen-futures = { version = "0.4", optional = true }
wasm-streams = { version = "0.4", optional = true }
web-sys = { version = "0.3", optional = true, features = [
"console",
"ReadableStream",
"ReadableStreamDefaultReader",
] }
# reqwest client
reqwest = { version = "0.11", default-features = false, optional = true, features = [
"multipart",
"stream",
] }
[features]
actix = ["dep:actix-web", "dep:send_wrapper"]
axum = [
"dep:axum",
"dep:http",
"dep:hyper",
"dep:http-body-util",
"dep:tower",
"dep:tower-layer",
]
browser = [
"dep:gloo-net",
"dep:js-sys",
"dep:send_wrapper",
"dep:web-sys",
"dep:wasm-streams",
"dep:wasm-bindgen-futures",
]
json = ["dep:serde_json"]
multipart = ["dep:multer"]
url = ["dep:serde_qs"]
cbor = ["dep:ciborium"]
rkyv = ["dep:rkyv"]
default-tls = ["reqwest/default-tls"]
rustls = ["reqwest/rustls-tls"]
reqwest = ["dep:http", "dep:reqwest"]
ssr = ["inventory"]
nightly = ["server_fn_macro_default/nightly"]

View file

@ -1 +0,0 @@
extend = { path = "../cargo-make/main.toml" }

View file

@ -1,69 +0,0 @@
#![cfg_attr(feature = "nightly", feature(proc_macro_span))]
//! This crate contains the default implementation of the #[macro@crate::server] macro without a context from the server. See the [server_fn_macro] crate for more information.
#![forbid(unsafe_code)]
// to prevent warnings from popping up when a nightly feature is stabilized
#![allow(stable_features)]
use proc_macro::TokenStream;
use server_fn_macro::server_macro_impl;
use syn::__private::ToTokens;
/// Declares that a function is a [server function](https://docs.rs/server_fn/).
/// This means that its body will only run on the server, i.e., when the `ssr`
/// feature is enabled.
///
/// You can specify one, two, three, or four arguments to the server function:
/// 1. **Required**: A type name that will be used to identify and register the server function
/// (e.g., `MyServerFn`).
/// 2. *Optional*: A URL prefix at which the function will be mounted when its registered
/// (e.g., `"/api"`). Defaults to `"/"`.
/// 3. *Optional*: The encoding for the server function (`"Url"`, `"Cbor"`, `"GetJson"`, or `"GetCbor`". See **Server Function Encodings** below.)
/// 4. *Optional*: A specific endpoint path to be used in the URL. (By default, a unique path will be generated.)
///
/// ```rust,ignore
/// // will generate a server function at `/api-prefix/hello`
/// #[server(MyServerFnType, "/api-prefix", "Url", "hello")]
/// ```
///
/// The server function itself can take any number of arguments, each of which should be serializable
/// and deserializable with `serde`.
///
/// ```ignore
/// # use server_fn::*; use serde::{Serialize, Deserialize};
/// # #[derive(Serialize, Deserialize)]
/// # pub struct Post { }
/// #[server(ReadPosts, "/api")]
/// pub async fn read_posts(how_many: u8, query: String) -> Result<Vec<Post>, ServerFnError> {
/// // do some work on the server to access the database
/// todo!()
/// }
/// ```
///
/// Note the following:
/// - **Server functions must be `async`.** Even if the work being done inside the function body
/// can run synchronously on the server, from the clients perspective it involves an asynchronous
/// function call.
/// - **Server functions must return `Result<T, ServerFnError>`.** Even if the work being done
/// inside the function body cant fail, the processes of serialization/deserialization and the
/// network call are fallible.
/// - **Return types must implement [Serialize](https://docs.rs/serde/latest/serde/trait.Serialize.html).**
/// This should be fairly obvious: we have to serialize arguments to send them to the server, and we
/// need to deserialize the result to return it to the client.
/// - **Arguments must be implement [`Serialize`](https://docs.rs/serde/latest/serde/trait.Serialize.html)
/// and [`DeserializeOwned`](https://docs.rs/serde/latest/serde/de/trait.DeserializeOwned.html).**
/// They are serialized as an `application/x-www-form-urlencoded`
/// form data using [`serde_qs`](https://docs.rs/serde_qs/latest/serde_qs/) or as `application/cbor`
/// using [`cbor`](https://docs.rs/cbor/latest/cbor/).
#[proc_macro_attribute]
pub fn server(args: proc_macro::TokenStream, s: TokenStream) -> TokenStream {
match server_macro_impl(
args.into(),
s.into(),
syn::parse_quote!(server_fn::default::DefaultServerFnTraitObj),
None,
Some(syn::parse_quote!(server_fn)),
) {
Err(e) => e.to_compile_error().into(),
Ok(s) => s.to_token_stream().into(),
}
}

65
server_fn/src/client.rs Normal file
View file

@ -0,0 +1,65 @@
use crate::{error::ServerFnError, request::ClientReq, response::ClientRes};
use std::future::Future;
pub trait Client<CustErr> {
type Request: ClientReq<CustErr> + Send;
type Response: ClientRes<CustErr> + Send;
fn send(
req: Self::Request,
) -> impl Future<Output = Result<Self::Response, ServerFnError<CustErr>>> + Send;
}
#[cfg(feature = "browser")]
pub mod browser {
use super::Client;
use crate::{
error::ServerFnError, request::browser::BrowserRequest, response::browser::BrowserResponse,
};
use send_wrapper::SendWrapper;
use std::future::Future;
pub struct BrowserClient;
impl<CustErr> Client<CustErr> for BrowserClient {
type Request = BrowserRequest;
type Response = BrowserResponse;
fn send(
req: Self::Request,
) -> impl Future<Output = Result<Self::Response, ServerFnError<CustErr>>> + Send {
SendWrapper::new(async move {
req.0
.take()
.send()
.await
.map(|res| BrowserResponse(SendWrapper::new(res)))
.map_err(|e| ServerFnError::Request(e.to_string()))
})
}
}
}
#[cfg(feature = "reqwest")]
pub mod reqwest {
use super::Client;
use crate::{error::ServerFnError, request::reqwest::CLIENT};
use futures::TryFutureExt;
use reqwest::{Request, Response};
use std::future::Future;
pub struct ReqwestClient;
impl<CustErr> Client<CustErr> for ReqwestClient {
type Request = Request;
type Response = Response;
fn send(
req: Self::Request,
) -> impl Future<Output = Result<Self::Response, ServerFnError<CustErr>>> + Send {
CLIENT
.execute(req)
.map_err(|e| ServerFnError::Request(e.to_string()))
}
}
}

160
server_fn/src/codec/cbor.rs Normal file
View file

@ -0,0 +1,160 @@
use super::{Encoding, FromReq, FromRes, IntoReq, IntoRes};
use crate::error::ServerFnError;
use crate::request::{ClientReq, Req};
use crate::response::{ClientRes, Res};
use bytes::Bytes;
use serde::de::DeserializeOwned;
use serde::Serialize;
/// Pass arguments and receive responses using `cbor` in a `POST` request.
pub struct Cbor;
impl Encoding for Cbor {
const CONTENT_TYPE: &'static str = "application/cbor";
}
impl<CustErr, T, Request> IntoReq<CustErr, Request, Cbor> for T
where
Request: ClientReq<CustErr>,
T: Serialize + Send,
{
fn into_req(self, path: &str, accepts: &str) -> Result<Request, ServerFnError<CustErr>> {
let mut buffer: Vec<u8> = Vec::new();
ciborium::ser::into_writer(&self, &mut buffer)
.map_err(|e| ServerFnError::Serialization(e.to_string()))?;
Request::try_new_post_bytes(path, accepts, Cbor::CONTENT_TYPE, Bytes::from(buffer))
}
}
impl<CustErr, T, Request> FromReq<CustErr, Request, Cbor> for T
where
Request: Req<CustErr> + Send + 'static,
T: DeserializeOwned,
{
async fn from_req(req: Request) -> Result<Self, ServerFnError<CustErr>> {
let body_bytes = req.try_into_bytes().await?;
ciborium::de::from_reader(body_bytes.as_ref())
.map_err(|e| ServerFnError::Args(e.to_string()))
}
}
impl<CustErr, T, Response> IntoRes<CustErr, Response, Cbor> for T
where
Response: Res<CustErr>,
T: Serialize + Send,
{
async fn into_res(self) -> Result<Response, ServerFnError<CustErr>> {
let mut buffer: Vec<u8> = Vec::new();
ciborium::ser::into_writer(&self, &mut buffer)
.map_err(|e| ServerFnError::Serialization(e.to_string()))?;
Response::try_from_bytes(Cbor::CONTENT_TYPE, Bytes::from(buffer))
}
}
impl<CustErr, T, Response> FromRes<CustErr, Response, Cbor> for T
where
Response: ClientRes<CustErr> + Send,
T: DeserializeOwned + Send,
{
async fn from_res(res: Response) -> Result<Self, ServerFnError<CustErr>> {
let data = res.try_into_bytes().await?;
ciborium::de::from_reader(data.as_ref()).map_err(|e| ServerFnError::Args(e.to_string()))
}
}
/* use std::fmt::Display;
use super::{Codec, Encoding};
use crate::error::{ServerFnError, IntoErrorResponse};
use async_trait::async_trait;
use axum::body::{Body, HttpBody};
use http_body_util::BodyExt;
use serde::de::DeserializeOwned;
use serde::Serialize;
/// Pass argument as JSON in the body of a POST Request
pub struct PostCbor;
impl Encoding for PostCbor {
const REQUEST_CONTENT_TYPE: &'static str = "application/cbor";
const RESPONSE_CONTENT_TYPE: &'static str = "application/cbor";
}
#[async_trait]
impl<T, RequestBody, ResponseBody>
Codec<
RequestBody,
ResponseBody,
http::Request<RequestBody>,
http::Response<ResponseBody>,
Body,
Body,
http::Request<Body>,
http::Response<Body>,
PostCbor,
> for T
where
T: DeserializeOwned + Serialize + Send,
for<'a> RequestBody: HttpBody + Send + 'a,
<RequestBody as HttpBody>::Error: Display + Send ,
<ResponseBody as HttpBody>::Error: Display + Send ,
for<'a> ResponseBody: HttpBody + Send + 'a,
<ResponseBody as HttpBody>::Data: Send ,
<RequestBody as HttpBody>::Data: Send ,
{
async fn from_req(req: http::Request<RequestBody>) -> Result<Self, ServerFnError<CustErr>> {
let (_parts, body) = req.into_parts();
let body_bytes = body
.collect()
.await
.map(|c| c.to_bytes())
.map_err(|e| ServerFnError::Deserialization(e.to_string()))?;
let data = ciborium::de::from_reader(body_bytes.as_ref())
.map_err(|e| ServerFnError::Args(e.to_string()))?;
Ok(data)
}
async fn into_req(self) -> Result<http::Request<Body>, ServerFnError<CustErr>> {
let mut buffer: Vec<u8> = Vec::new();
ciborium::ser::into_writer(&self, &mut buffer)?;
let req = http::Request::builder()
.method("POST")
.header(
http::header::CONTENT_TYPE,
<PostCbor as Encoding>::REQUEST_CONTENT_TYPE,
)
.body(Body::from(buffer))?;
Ok(req)
}
async fn from_res(res: http::Response<ResponseBody>) -> Result<Self, ServerFnError<CustErr>> {
let (_parts, body) = res.into_parts();
let body_bytes = body
.collect()
.await
.map(|c| c.to_bytes())
.map_err(|e| ServerFnError::Deserialization(e.to_string()))?;
ciborium::de::from_reader(body_bytes.as_ref())
.map_err(|e| ServerFnError::Args(e.to_string()))
}
async fn into_res(self) -> http::Response<Body> {
let mut buffer: Vec<u8> = Vec::new();
match ciborium::ser::into_writer(&self, &mut buffer) {
Ok(_) => (),
Err(e) => return e.into_err_res(),
};
let res = http::Response::builder()
.status(200)
.header(
http::header::CONTENT_TYPE,
<PostCbor as Encoding>::REQUEST_CONTENT_TYPE,
)
.body(Body::from(buffer))
.unwrap();
res
}
}
*/

View file

@ -0,0 +1,59 @@
use super::{Encoding, FromReq, FromRes};
use crate::error::ServerFnError;
use crate::request::{ClientReq, Req};
use crate::response::{ClientRes, Res};
use crate::{IntoReq, IntoRes};
use serde::de::DeserializeOwned;
use serde::Serialize;
/// Pass arguments and receive responses as JSON in the body of a `POST` request.
pub struct Json;
impl Encoding for Json {
const CONTENT_TYPE: &'static str = "application/json";
}
impl<CustErr, T, Request> IntoReq<CustErr, Request, Json> for T
where
Request: ClientReq<CustErr>,
T: Serialize + Send,
{
fn into_req(self, path: &str, accepts: &str) -> Result<Request, ServerFnError<CustErr>> {
let data = serde_json::to_string(&self)
.map_err(|e| ServerFnError::Serialization(e.to_string()))?;
Request::try_new_post(path, accepts, Json::CONTENT_TYPE, data)
}
}
impl<CustErr, T, Request> FromReq<CustErr, Request, Json> for T
where
Request: Req<CustErr> + Send + 'static,
T: DeserializeOwned,
{
async fn from_req(req: Request) -> Result<Self, ServerFnError<CustErr>> {
let string_data = req.try_into_string().await?;
serde_json::from_str::<Self>(&string_data).map_err(|e| ServerFnError::Args(e.to_string()))
}
}
impl<CustErr, T, Response> IntoRes<CustErr, Response, Json> for T
where
Response: Res<CustErr>,
T: Serialize + Send,
{
async fn into_res(self) -> Result<Response, ServerFnError<CustErr>> {
let data = serde_json::to_string(&self)
.map_err(|e| ServerFnError::Serialization(e.to_string()))?;
Response::try_from_string(Json::CONTENT_TYPE, data)
}
}
impl<CustErr, T, Response> FromRes<CustErr, Response, Json> for T
where
Response: ClientRes<CustErr> + Send,
T: DeserializeOwned + Send,
{
async fn from_res(res: Response) -> Result<Self, ServerFnError<CustErr>> {
let data = res.try_into_string().await?;
serde_json::from_str(&data).map_err(|e| ServerFnError::Deserialization(e.to_string()))
}
}

View file

@ -0,0 +1,53 @@
#[cfg(feature = "cbor")]
mod cbor;
#[cfg(feature = "cbor")]
pub use cbor::*;
#[cfg(feature = "json")]
mod json;
#[cfg(feature = "json")]
pub use json::*;
#[cfg(feature = "rkyv")]
mod rkyv;
#[cfg(feature = "rkyv")]
pub use rkyv::*;
#[cfg(feature = "url")]
mod url;
use crate::error::ServerFnError;
use futures::Future;
#[cfg(feature = "url")]
pub use url::*;
#[cfg(feature = "multipart")]
mod multipart;
#[cfg(feature = "multipart")]
pub use multipart::*;
mod stream;
pub use stream::*;
pub trait FromReq<CustErr, Request, Encoding>
where
Self: Sized,
{
fn from_req(req: Request) -> impl Future<Output = Result<Self, ServerFnError<CustErr>>> + Send;
}
pub trait IntoReq<CustErr, Request, Encoding> {
fn into_req(self, path: &str, accepts: &str) -> Result<Request, ServerFnError<CustErr>>;
}
pub trait FromRes<CustErr, Response, Encoding>
where
Self: Sized,
{
fn from_res(res: Response)
-> impl Future<Output = Result<Self, ServerFnError<CustErr>>> + Send;
}
pub trait IntoRes<CustErr, Response, Encoding> {
fn into_res(self) -> impl Future<Output = Result<Response, ServerFnError<CustErr>>> + Send;
}
pub trait Encoding {
const CONTENT_TYPE: &'static str;
}

View file

@ -0,0 +1,71 @@
use super::{Encoding, FromReq};
use crate::error::ServerFnError;
use crate::request::browser::BrowserFormData;
use crate::request::{ClientReq, Req};
use crate::IntoReq;
use futures::StreamExt;
use multer::Multipart;
use web_sys::FormData;
pub struct MultipartFormData;
impl Encoding for MultipartFormData {
const CONTENT_TYPE: &'static str = "multipart/form-data";
}
#[derive(Debug)]
pub enum MultipartData {
Client(BrowserFormData),
Server(multer::Multipart<'static>),
}
impl MultipartData {
pub fn into_client_data(self) -> Option<BrowserFormData> {
match self {
MultipartData::Client(data) => Some(data),
MultipartData::Server(_) => None,
}
}
pub fn into_data(self) -> Option<Multipart<'static>> {
match self {
MultipartData::Client(_) => None,
MultipartData::Server(data) => Some(data),
}
}
}
impl From<FormData> for MultipartData {
fn from(value: FormData) -> Self {
MultipartData::Client(value.into())
}
}
impl<CustErr, T, Request> IntoReq<CustErr, Request, MultipartFormData> for T
where
Request: ClientReq<CustErr, FormData = BrowserFormData>,
T: Into<MultipartData>,
{
fn into_req(self, path: &str, accepts: &str) -> Result<Request, ServerFnError<CustErr>> {
let multi = self.into();
Request::try_new_multipart(path, accepts, multi.into_client_data().unwrap())
}
}
impl<CustErr, T, Request> FromReq<CustErr, Request, MultipartFormData> for T
where
Request: Req<CustErr> + Send + 'static,
T: From<MultipartData>,
CustErr: 'static,
{
async fn from_req(req: Request) -> Result<Self, ServerFnError<CustErr>> {
let boundary = req
.to_content_type()
.and_then(|ct| multer::parse_boundary(ct).ok())
.expect("couldn't parse boundary");
let stream = req.try_into_stream()?;
let data =
multer::Multipart::new(stream.map(|data| data.map_err(|e| e.to_string())), boundary);
Ok(MultipartData::Server(data).into())
}
}

View file

@ -0,0 +1,72 @@
use rkyv::{
de::deserializers::SharedDeserializeMap, ser::serializers::AllocSerializer,
validation::validators::DefaultValidator, Archive, CheckBytes, Deserialize, Serialize,
};
use super::{Encoding, FromReq, FromRes, IntoReq, IntoRes};
use crate::error::ServerFnError;
use crate::request::{ClientReq, Req};
use crate::response::{ClientRes, Res};
use bytes::Bytes;
/// Pass arguments and receive responses using `rkyv` in a `POST` request.
pub struct Rkyv;
impl Encoding for Rkyv {
const CONTENT_TYPE: &'static str = "application/rkyv";
}
impl<CustErr, T, Request> IntoReq<CustErr, Request, Rkyv> for T
where
Request: ClientReq<CustErr>,
T: Serialize<AllocSerializer<1024>> + Send,
T: Archive,
T::Archived: for<'a> CheckBytes<DefaultValidator<'a>> + Deserialize<T, SharedDeserializeMap>,
{
fn into_req(self, path: &str, accepts: &str) -> Result<Request, ServerFnError<CustErr>> {
let encoded = rkyv::to_bytes::<T, 1024>(&self)?;
let bytes = Bytes::copy_from_slice(encoded.as_ref());
Request::try_new_post_bytes(path, accepts, Rkyv::CONTENT_TYPE, bytes)
}
}
impl<CustErr, T, Request> FromReq<CustErr, Request, Rkyv> for T
where
Request: Req<CustErr> + Send + 'static,
T: Serialize<AllocSerializer<1024>> + Send,
T: Archive,
T::Archived: for<'a> CheckBytes<DefaultValidator<'a>> + Deserialize<T, SharedDeserializeMap>,
{
async fn from_req(req: Request) -> Result<Self, ServerFnError<CustErr>> {
let body_bytes = req.try_into_bytes().await?;
rkyv::from_bytes::<T>(&body_bytes).map_err(|e| ServerFnError::Args(e.to_string()))
}
}
impl<CustErr, T, Response> IntoRes<CustErr, Response, Rkyv> for T
where
Response: Res<CustErr>,
T: Serialize<AllocSerializer<1024>> + Send,
T: Archive,
T::Archived: for<'a> CheckBytes<DefaultValidator<'a>> + Deserialize<T, SharedDeserializeMap>,
{
async fn into_res(self) -> Result<Response, ServerFnError<CustErr>> {
let encoded = rkyv::to_bytes::<T, 1024>(&self)
.map_err(|e| ServerFnError::Serialization(e.to_string()))?;
let bytes = Bytes::copy_from_slice(encoded.as_ref());
Response::try_from_bytes(Rkyv::CONTENT_TYPE, bytes)
}
}
impl<CustErr, T, Response> FromRes<CustErr, Response, Rkyv> for T
where
Response: ClientRes<CustErr> + Send,
T: Serialize<AllocSerializer<1024>> + Send,
T: Archive,
T::Archived: for<'a> CheckBytes<DefaultValidator<'a>> + Deserialize<T, SharedDeserializeMap>,
{
async fn from_res(res: Response) -> Result<Self, ServerFnError<CustErr>> {
let data = res.try_into_bytes().await?;
rkyv::from_bytes::<T>(&data).map_err(|e| ServerFnError::Deserialization(e.to_string()))
}
}

View file

@ -0,0 +1,128 @@
use std::pin::Pin;
use super::{Encoding, FromRes};
use crate::error::{NoCustomError, ServerFnError};
use crate::response::{ClientRes, Res};
use crate::IntoRes;
use bytes::Bytes;
use futures::{Stream, StreamExt};
pub struct Streaming;
impl Encoding for Streaming {
const CONTENT_TYPE: &'static str = "application/octet-stream";
}
/* impl<CustErr, T, Request> IntoReq<CustErr, Request, ByteStream> for T
where
Request: ClientReq<CustErr>,
T: Stream<Item = Bytes> + Send,
{
fn into_req(self, path: &str, accepts: &str) -> Result<Request, ServerFnError<CustErr>> {
Request::try_new_stream(path, ByteStream::CONTENT_TYPE, self)
}
} */
/* impl<CustErr, T, Request> FromReq<CustErr, Request, ByteStream> for T
where
Request: Req<CustErr> + Send + 'static,
T: Stream<Item = Bytes> + Send,
{
async fn from_req(req: Request) -> Result<Self, ServerFnError<CustErr>> {
req.try_into_stream().await
}
} */
pub struct ByteStream<CustErr = NoCustomError>(
Pin<Box<dyn Stream<Item = Result<Bytes, ServerFnError<CustErr>>> + Send>>,
);
impl<CustErr> ByteStream<CustErr> {
pub fn into_inner(self) -> impl Stream<Item = Result<Bytes, ServerFnError<CustErr>>> + Send {
self.0
}
}
impl<S, T> From<S> for ByteStream
where
S: Stream<Item = T> + Send + 'static,
T: Into<Bytes>,
{
fn from(value: S) -> Self {
Self(Box::pin(value.map(|data| Ok(data.into()))))
}
}
impl<CustErr, Response> IntoRes<CustErr, Response, Streaming> for ByteStream<CustErr>
where
Response: Res<CustErr>,
CustErr: 'static,
{
async fn into_res(self) -> Result<Response, ServerFnError<CustErr>> {
Response::try_from_stream(Streaming::CONTENT_TYPE, self.into_inner())
}
}
impl<CustErr, Response> FromRes<CustErr, Response, Streaming> for ByteStream
where
Response: ClientRes<CustErr> + Send,
{
async fn from_res(res: Response) -> Result<Self, ServerFnError<CustErr>> {
let stream = res.try_into_stream()?;
Ok(ByteStream(Box::pin(stream)))
}
}
pub struct StreamingText;
impl Encoding for StreamingText {
const CONTENT_TYPE: &'static str = "text/plain";
}
pub struct TextStream<CustErr = NoCustomError>(
Pin<Box<dyn Stream<Item = Result<String, ServerFnError<CustErr>>> + Send>>,
);
impl<CustErr> TextStream<CustErr> {
pub fn into_inner(self) -> impl Stream<Item = Result<String, ServerFnError<CustErr>>> + Send {
self.0
}
}
impl<S, T> From<S> for TextStream
where
S: Stream<Item = T> + Send + 'static,
T: Into<String>,
{
fn from(value: S) -> Self {
Self(Box::pin(value.map(|data| Ok(data.into()))))
}
}
impl<CustErr, Response> IntoRes<CustErr, Response, StreamingText> for TextStream<CustErr>
where
Response: Res<CustErr>,
CustErr: 'static,
{
async fn into_res(self) -> Result<Response, ServerFnError<CustErr>> {
Response::try_from_stream(
Streaming::CONTENT_TYPE,
self.into_inner().map(|stream| stream.map(Into::into)),
)
}
}
impl<CustErr, Response> FromRes<CustErr, Response, StreamingText> for TextStream
where
Response: ClientRes<CustErr> + Send,
{
async fn from_res(res: Response) -> Result<Self, ServerFnError<CustErr>> {
let stream = res.try_into_stream()?;
Ok(TextStream(Box::pin(stream.map(|chunk| {
chunk.and_then(|bytes| {
String::from_utf8(bytes.into())
.map_err(|e| ServerFnError::Deserialization(e.to_string()))
})
}))))
}
}

131
server_fn/src/codec/url.rs Normal file
View file

@ -0,0 +1,131 @@
use super::{Encoding, FromReq, IntoReq};
use crate::error::ServerFnError;
use crate::request::{ClientReq, Req};
use serde::de::DeserializeOwned;
use serde::Serialize;
/// Pass arguments as a URL-encoded query string of a `GET` request.
pub struct GetUrl;
/// Pass arguments as the URL-encoded body of a `POST` request.
pub struct PostUrl;
impl Encoding for GetUrl {
const CONTENT_TYPE: &'static str = "application/x-www-form-urlencoded";
}
impl<CustErr, T, Request> IntoReq<CustErr, Request, GetUrl> for T
where
Request: ClientReq<CustErr>,
T: Serialize + Send,
{
fn into_req(self, path: &str, accepts: &str) -> Result<Request, ServerFnError<CustErr>> {
let data =
serde_qs::to_string(&self).map_err(|e| ServerFnError::Serialization(e.to_string()))?;
Request::try_new_get(path, accepts, GetUrl::CONTENT_TYPE, &data)
}
}
impl<CustErr, T, Request> FromReq<CustErr, Request, GetUrl> for T
where
Request: Req<CustErr> + Send + 'static,
T: DeserializeOwned,
{
async fn from_req(req: Request) -> Result<Self, ServerFnError<CustErr>> {
let string_data = req.as_query().unwrap_or_default();
let args = serde_qs::from_str::<Self>(string_data)
.map_err(|e| ServerFnError::Args(e.to_string()))?;
Ok(args)
}
}
impl Encoding for PostUrl {
const CONTENT_TYPE: &'static str = "application/x-www-form-urlencoded";
}
impl<CustErr, T, Request> IntoReq<CustErr, Request, PostUrl> for T
where
Request: ClientReq<CustErr>,
T: Serialize + Send,
{
fn into_req(self, path: &str, accepts: &str) -> Result<Request, ServerFnError<CustErr>> {
let qs =
serde_qs::to_string(&self).map_err(|e| ServerFnError::Serialization(e.to_string()))?;
Request::try_new_post(path, accepts, PostUrl::CONTENT_TYPE, qs)
}
}
impl<CustErr, T, Request> FromReq<CustErr, Request, PostUrl> for T
where
Request: Req<CustErr> + Send + 'static,
T: DeserializeOwned,
{
async fn from_req(req: Request) -> Result<Self, ServerFnError<CustErr>> {
let string_data = req.try_into_string().await?;
let args = serde_qs::from_str::<Self>(&string_data)
.map_err(|e| ServerFnError::Args(e.to_string()))?;
Ok(args)
}
}
/* #[async_trait]
impl<T, Request, Response> Codec<Request, Response, GetUrlJson> for T
where
T: DeserializeOwned + Serialize + Send,
Request: Req<CustErr> + Send,
Response: Res<CustErr> + Send,
{
async fn from_req(req: Request) -> Result<Self, ServerFnError<CustErr>> {
let string_data = req.try_into_string()?;
let args = serde_json::from_str::<Self>(&string_data)
.map_err(|e| ServerFnError::Args(e.to_string()))?;
Ok(args)
}
async fn into_req(self) -> Result<Request, ServerFnError<CustErr>> {
/* let qs = serde_qs::to_string(&self)?;
let req = http::Request::builder()
.method("GET")
.header(
http::header::CONTENT_TYPE,
<GetUrlJson as Encoding>::REQUEST_CONTENT_TYPE,
)
.body(Body::from(qs))?;
Ok(req) */
todo!()
}
async fn from_res(res: Response) -> Result<Self, ServerFnError<CustErr>> {
todo!()
/* let (_parts, body) = res.into_parts();
let body_bytes = body
.collect()
.await
.map(|c| c.to_bytes())
.map_err(|e| ServerFnError::Deserialization(e.to_string()))?;
let string_data = String::from_utf8(body_bytes.to_vec())?;
serde_json::from_str(&string_data)
.map_err(|e| ServerFnError::Deserialization(e.to_string())) */
}
async fn into_res(self) -> Response {
// Need to catch and err or here, or handle Errors at a higher level
let data = match serde_json::to_string(&self) {
Ok(d) => d,
Err(e) => return e.into_err_res(),
};
let builder = http::Response::builder();
let res = builder
.status(200)
.header(
http::header::CONTENT_TYPE,
<GetUrlJson as Encoding>::RESPONSE_CONTENT_TYPE,
)
.body(Body::from(data))
.unwrap();
res
}
}
*/

View file

@ -1,117 +0,0 @@
#[cfg(any(feature = "ssr", doc))]
use crate::ServerFnTraitObj;
pub use server_fn_macro_default::server;
#[cfg(any(feature = "ssr", doc))]
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
#[cfg(any(feature = "ssr", doc))]
lazy_static::lazy_static! {
static ref REGISTERED_SERVER_FUNCTIONS: Arc<RwLock<HashMap<&'static str, &'static DefaultServerFnTraitObj>>> = {
let mut map = HashMap::new();
for server_fn in inventory::iter::<DefaultServerFnTraitObj> {
map.insert(server_fn.0.url(), server_fn);
}
Arc::new(RwLock::new(map))
};
}
#[cfg(feature = "ssr")]
inventory::collect!(DefaultServerFnTraitObj);
/// Attempts to find a server function registered at the given path.
///
/// This can be used by a server to handle the requests, as in the following example (using [`actix-web`]).
///
/// ```rust, ignore
/// #[post("{tail:.*}")]
/// async fn handle_server_fns(
/// req: HttpRequest,
/// params: web::Path<String>,
/// body: web::Bytes,
/// ) -> impl Responder {
/// let path = params.into_inner();
/// let accept_header = req
/// .headers()
/// .get("Accept")
/// .and_then(|value| value.to_str().ok());
///
/// if let Some(server_fn) = server_fn_by_path(path.as_str()) {
/// let body: &[u8] = &body;
/// match server_fn(&body).await {
/// Ok(serialized) => {
/// // if this is Accept: application/json then send a serialized JSON response
/// if let Some("application/json") = accept_header {
/// HttpResponse::Ok().body(serialized)
/// }
/// // otherwise, it's probably a <form> submit or something: redirect back to the referrer
/// else {
/// HttpResponse::SeeOther()
/// .insert_header(("Location", "/"))
/// .content_type("application/json")
/// .body(serialized)
/// }
/// }
/// Err(e) => {
/// eprintln!("server function error: {e:#?}");
/// HttpResponse::InternalServerError().body(e.to_string())
/// }
/// }
/// } else {
/// HttpResponse::BadRequest().body(format!("Could not find a server function at that route."))
/// }
/// }
/// ```
///
/// [`actix-web`]: <https://docs.rs/actix-web/>
#[cfg(any(feature = "ssr", doc))]
pub fn server_fn_by_path(
path: &str,
) -> Option<&'static DefaultServerFnTraitObj> {
REGISTERED_SERVER_FUNCTIONS
.read()
.expect("Server function registry is poisoned")
.get(path)
.copied()
}
/// Returns the set of currently-registered server function paths, for debugging purposes.
#[cfg(any(feature = "ssr", doc))]
pub fn server_fns_by_path() -> Vec<&'static str> {
REGISTERED_SERVER_FUNCTIONS
.read()
.expect("Server function registry is poisoned")
.keys()
.copied()
.collect()
}
#[cfg(any(feature = "ssr", doc))]
/// A server function that can be called from the client without any context from the server.
pub struct DefaultServerFnTraitObj(ServerFnTraitObj<()>);
#[cfg(any(feature = "ssr", doc))]
impl DefaultServerFnTraitObj {
/// Creates a new server function with the given prefix, URL, encoding, and function.
pub const fn from_generic_server_fn(f: ServerFnTraitObj<()>) -> Self {
Self(f)
}
}
#[cfg(any(feature = "ssr", doc))]
impl std::ops::Deref for DefaultServerFnTraitObj {
type Target = ServerFnTraitObj<()>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[cfg(any(feature = "ssr", doc))]
impl std::ops::DerefMut for DefaultServerFnTraitObj {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

View file

@ -1,65 +1,107 @@
use core::fmt::{self, Display};
use serde::{Deserialize, Serialize};
use std::{error, fmt, ops, sync::Arc};
use thiserror::Error;
/// This is a result type into which any error can be converted,
/// and which can be used directly in your `view`.
///
/// All errors will be stored as [`struct@Error`].
pub type Result<T, E = Error> = core::result::Result<T, E>;
/// An empty value indicating that there is no custom error type associated
/// with this server function.
#[derive(Debug, Deserialize, Serialize)]
pub struct NoCustomError(());
/// A generic wrapper for any error.
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct Error(Arc<dyn error::Error + Send + Sync>);
impl Error {
/// Converts the wrapper into the inner reference-counted error.
pub fn into_inner(self) -> Arc<dyn error::Error + Send + Sync> {
Arc::clone(&self.0)
// Implement `Display` for `NoCustomError`
impl fmt::Display for NoCustomError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Unit Type Displayed")
}
}
impl ops::Deref for Error {
type Target = Arc<dyn error::Error + Send + Sync>;
#[derive(Debug)]
pub struct WrapError<T>(pub T);
fn deref(&self) -> &Self::Target {
&self.0
}
/// This helper macro lets you call the gnarly autoref-specialization call
/// without having to worry about things like how many & you need.
/// Mostly used when you impl From<ServerFnError> for YourError
#[macro_export]
macro_rules! server_fn_error {
() => {{
use $crate::{ViaError, WrapError};
(&&&&&WrapError(())).to_server_error()
}};
($err:expr) => {{
use $crate::error::{ViaError, WrapError};
match $err {
error => (&&&&&WrapError(error)).to_server_error(),
}
}};
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
/// This trait serves as the conversion method between a variety of types
/// and [`ServerFnError`].
pub trait ViaError<E> {
fn to_server_error(&self) -> ServerFnError<E>;
}
impl<T> From<T> for Error
where
T: std::error::Error + Send + Sync + 'static,
// This impl should catch if you fed it a [`ServerFnError`] already.
impl<E: ServerFnErrorKind + std::error::Error + Clone> ViaError<E>
for &&&&WrapError<ServerFnError<E>>
{
fn from(value: T) -> Self {
Error(Arc::new(value))
fn to_server_error(&self) -> ServerFnError<E> {
self.0.clone()
}
}
impl From<ServerFnError> for Error {
fn from(e: ServerFnError) -> Self {
Error(Arc::new(ServerFnErrorErr::from(e)))
// A type tag for ServerFnError so we can special case it
pub(crate) trait ServerFnErrorKind {}
impl ServerFnErrorKind for ServerFnError {}
// This impl should catch passing () or nothing to server_fn_error
impl ViaError<()> for &&&WrapError<()> {
fn to_server_error(&self) -> ServerFnError<()> {
ServerFnError::WrappedServerError(self.0.clone())
}
}
// This impl will catch any type that implements any type that impls
// Error and Clone, so that it can be wrapped into ServerFnError
impl<E: std::error::Error + Clone> ViaError<E> for &&WrapError<E> {
fn to_server_error(&self) -> ServerFnError<E> {
ServerFnError::WrappedServerError(self.0.clone())
}
}
// If it doesn't impl Error, but does impl Display and Clone,
// we can still wrap it in String form
impl<E: Display + Clone> ViaError<E> for &WrapError<E> {
fn to_server_error(&self) -> ServerFnError<E> {
ServerFnError::WrappedServerError(self.0.clone())
}
}
// This is what happens if someone tries to pass in something that does
// not meet the above criteria
impl<E> ViaError<E> for WrapError<E> {
#[track_caller]
fn to_server_error(&self) -> ServerFnError<E> {
panic!("At {}, you call `to_server_error()` or use `server_fn_error!` with a value that does not implement `Clone` and either `Error` or `Display`.", std::panic::Location::caller());
}
}
/// Type for errors that can occur when using server functions.
///
/// Unlike [`ServerFnErrorErr`], this does not implement [`std::error::Error`].
/// Unlike [`ServerFnErrorErr`], this does not implement [`Error`](std::error::Error).
/// This means that other error types can easily be converted into it using the
/// `?` operator.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ServerFnError {
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ServerFnError<E = NoCustomError> {
WrappedServerError(E),
/// Error while trying to register the server function (only occurs in case of poisoned RwLock).
Registration(String),
/// Occurs on the client if there is a network error while trying to run function on server.
Request(String),
/// Occurs on the server if there is an error creating an HTTP response.
Response(String),
/// Occurs when there is an error while actually running the function on the server.
ServerError(String),
/// Occurs on the client if there is an error deserializing the server's response.
@ -72,39 +114,49 @@ pub enum ServerFnError {
MissingArg(String),
}
impl core::fmt::Display for ServerFnError {
impl<CustErr> From<CustErr> for ServerFnError<CustErr> {
fn from(value: CustErr) -> Self {
ServerFnError::WrappedServerError(value)
}
}
impl<CustErr> Display for ServerFnError<CustErr>
where
CustErr: Display,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}",
match self {
ServerFnError::Registration(s) => format!(
"error while trying to register the server function: {s}"
),
ServerFnError::Request(s) => format!(
"error reaching server to call server function: {s}"
),
ServerFnError::ServerError(s) =>
format!("error running server function: {s}"),
ServerFnError::Registration(s) =>
format!("error while trying to register the server function: {s}"),
ServerFnError::Request(s) =>
format!("error reaching server to call server function: {s}"),
ServerFnError::ServerError(s) => format!("error running server function: {s}"),
ServerFnError::Deserialization(s) =>
format!("error deserializing server function results: {s}"),
ServerFnError::Serialization(s) =>
format!("error serializing server function arguments: {s}"),
ServerFnError::Args(s) => format!(
"error deserializing server function arguments: {s}"
),
ServerFnError::Args(s) =>
format!("error deserializing server function arguments: {s}"),
ServerFnError::MissingArg(s) => format!("missing argument {s}"),
ServerFnError::Response(s) => format!("error generating HTTP response: {s}"),
ServerFnError::WrappedServerError(e) => format!("{}", e),
}
)
}
}
impl<E> From<E> for ServerFnError
impl<E> std::error::Error for ServerFnError<E>
where
E: std::error::Error,
E: std::error::Error + 'static,
ServerFnError<E>: std::fmt::Display,
{
fn from(e: E) -> Self {
ServerFnError::ServerError(e.to_string())
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
ServerFnError::WrappedServerError(e) => Some(e),
_ => None,
}
}
}
@ -117,8 +169,10 @@ where
///
/// [`ServerFnError`] and [`ServerFnErrorErr`] mutually implement [`From`], so
/// it is easy to convert between the two types.
#[derive(Error, Debug, Clone, Serialize, Deserialize)]
pub enum ServerFnErrorErr {
#[derive(Error, Debug, Clone)]
pub enum ServerFnErrorErr<E = NoCustomError> {
#[error("internal error: {0}")]
WrappedServerError(E),
/// Error while trying to register the server function (only occurs in case of poisoned RwLock).
#[error("error while trying to register the server function: {0}")]
Registration(String),
@ -140,28 +194,23 @@ pub enum ServerFnErrorErr {
/// Occurs on the server if there's a missing argument.
#[error("missing argument {0}")]
MissingArg(String),
/// Occurs on the server if there is an error creating an HTTP response.
#[error("error creating response {0}")]
Response(String),
}
impl From<ServerFnError> for ServerFnErrorErr {
fn from(value: ServerFnError) -> Self {
impl<CustErr> From<ServerFnError<CustErr>> for ServerFnErrorErr<CustErr> {
fn from(value: ServerFnError<CustErr>) -> Self {
match value {
ServerFnError::Registration(value) => {
ServerFnErrorErr::Registration(value)
}
ServerFnError::Registration(value) => ServerFnErrorErr::Registration(value),
ServerFnError::Request(value) => ServerFnErrorErr::Request(value),
ServerFnError::ServerError(value) => {
ServerFnErrorErr::ServerError(value)
}
ServerFnError::Deserialization(value) => {
ServerFnErrorErr::Deserialization(value)
}
ServerFnError::Serialization(value) => {
ServerFnErrorErr::Serialization(value)
}
ServerFnError::ServerError(value) => ServerFnErrorErr::ServerError(value),
ServerFnError::Deserialization(value) => ServerFnErrorErr::Deserialization(value),
ServerFnError::Serialization(value) => ServerFnErrorErr::Serialization(value),
ServerFnError::Args(value) => ServerFnErrorErr::Args(value),
ServerFnError::MissingArg(value) => {
ServerFnErrorErr::MissingArg(value)
}
ServerFnError::MissingArg(value) => ServerFnErrorErr::MissingArg(value),
ServerFnError::WrappedServerError(value) => ServerFnErrorErr::WrappedServerError(value),
ServerFnError::Response(value) => ServerFnErrorErr::Response(value),
}
}
}

View file

@ -1,632 +1,291 @@
#![forbid(unsafe_code)]
#![deny(missing_docs)]
pub mod client;
pub mod codec;
#[macro_use]
pub mod error;
pub mod middleware;
pub mod redirect;
pub mod request;
pub mod response;
//! # Server Functions
//!
//! This package is based on a simple idea: sometimes its useful to write functions
//! that will only run on the server, and call them from the client.
//!
//! If youre creating anything beyond a toy app, youll need to do this all the time:
//! reading from or writing to a database that only runs on the server, running expensive
//! computations using libraries you dont want to ship down to the client, accessing
//! APIs that need to be called from the server rather than the client for CORS reasons
//! or because you need a secret API key thats stored on the server and definitely
//! shouldnt be shipped down to a users browser.
//!
//! Traditionally, this is done by separating your server and client code, and by setting
//! up something like a REST API or GraphQL API to allow your client to fetch and mutate
//! data on the server. This is fine, but it requires you to write and maintain your code
//! in multiple separate places (client-side code for fetching, server-side functions to run),
//! as well as creating a third thing to manage, which is the API contract between the two.
//!
//! This package provides two simple primitives that allow you instead to write co-located,
//! isomorphic server functions. (*Co-located* means you can write them in your app code so
//! that they are “located alongside” the client code that calls them, rather than separating
//! the client and server sides. *Isomorphic* means you can call them from the client as if
//! you were simply calling a function; the function call has the “same shape” on the client
//! as it does on the server.)
//!
//! ### `#[server]`
//!
//! The [`#[server]`][server] macro allows you to annotate a function to
//! indicate that it should only run on the server (i.e., when you have an `ssr` feature in your
//! crate that is enabled).
//!
//! **Important**: Before calling a server function on a non-web platform, you must set the server URL by calling [`set_server_url`].
//!
//! ```rust,ignore
//! #[server(ReadFromDB)]
//! async fn read_posts(how_many: usize, query: String) -> Result<Vec<Posts>, ServerFnError> {
//! // do some server-only work here to access the database
//! let posts = ...;
//! Ok(posts)
//! }
//!
//! // call the function
//! # #[tokio::main]
//! # async fn main() {
//! async {
//! let posts = read_posts(3, "my search".to_string()).await;
//! log::debug!("posts = {posts:#?}");
//! }
//! # }
//!
//! // make sure you've registered it somewhere in main
//! fn main() {
//! // for non-web apps, you must set the server URL manually
//! server_fn::set_server_url("http://localhost:3000");
//! _ = ReadFromDB::register();
//! }
//! ```
//!
//! If you call this function from the client, it will serialize the function arguments and `POST`
//! them to the server as if they were the inputs in `<form method="POST">`.
//!
//! Heres what you need to remember:
//! - **Server functions must be `async`.** Even if the work being done inside the function body
//! can run synchronously on the server, from the clients perspective it involves an asynchronous
//! function call.
//! - **Server functions must return `Result<T, ServerFnError>`.** Even if the work being done
//! inside the function body cant fail, the processes of serialization/deserialization and the
//! network call are fallible.
//! - **Return types must implement [`serde::Serialize`].**
//! This should be fairly obvious: we have to serialize arguments to send them to the server, and we
//! need to deserialize the result to return it to the client.
//! - **Arguments must be implement [`serde::Serialize`].** They are serialized as an `application/x-www-form-urlencoded`
//! form data using [`serde_qs`] or as `application/cbor` using [`cbor`].
//!
//! [server]: <https://docs.rs/server_fn/latest/server_fn/attr.server.html>
//! [`serde_qs`]: <https://docs.rs/serde_qs/latest/serde_qs/>
//! [`cbor`]: <https://docs.rs/cbor/latest/cbor/>
use client::Client;
use codec::{Encoding, FromReq, FromRes, IntoReq, IntoRes};
use dashmap::DashMap;
pub use error::ServerFnError;
use middleware::{Layer, Service};
use once_cell::sync::Lazy;
use request::Req;
use response::{ClientRes, Res};
use serde::{de::DeserializeOwned, Serialize};
use std::{future::Future, pin::Pin, sync::Arc};
// used by the macro
// reexports for the sake of the macro
#[doc(hidden)]
pub use const_format;
// used by the macro
#[cfg(feature = "ssr")]
#[doc(hidden)]
pub use inventory;
#[cfg(any(feature = "ssr", doc))]
use proc_macro2::TokenStream;
#[cfg(any(feature = "ssr", doc))]
use quote::TokenStreamExt;
// used by the macro
#[doc(hidden)]
pub use serde;
use serde::{de::DeserializeOwned, Serialize};
pub use server_fn_macro_default::server;
use std::{future::Future, pin::Pin, str::FromStr};
#[cfg(any(feature = "ssr", doc))]
use syn::parse_quote;
// used by the macro
#[doc(hidden)]
pub use xxhash_rust;
/// Error types used in server functions.
pub mod error;
pub use error::ServerFnError;
/// Default server function registry
pub mod default;
pub trait ServerFn
where
Self: Send
+ FromReq<Self::Error, Self::ServerRequest, Self::InputEncoding>
+ IntoReq<Self::Error, <Self::Client as Client<Self::Error>>::Request, Self::InputEncoding>,
{
const PATH: &'static str;
/// Something that can register a server function.
pub trait ServerFunctionRegistry<T> {
/// An error that can occur when registering a server function.
type Error: std::error::Error;
/// The type of the HTTP client that will send the request from the client side.
///
/// For example, this might be `gloo-net` in the browser, or `reqwest` for a desktop app.
type Client: Client<Self::Error>;
/// Server functions are automatically registered on most platforms, (including Linux, macOS,
/// iOS, FreeBSD, Android, and Windows). If you are on another platform, like a WASM server runtime,
/// this will explicitly register server functions.
fn register_explicit(
prefix: &'static str,
url: &'static str,
server_function: SerializedFnTraitObj<T>,
encoding: Encoding,
) -> Result<(), Self::Error>;
/// The type of the HTTP request when received by the server function on the server side.
type ServerRequest: Req<Self::Error> + Send;
/// Returns the server function registered at the given URL, or `None` if no function is registered at that URL.
fn get(url: &str) -> Option<ServerFnTraitObj<T>>;
/// The type of the HTTP response returned by the server function on the server side.
type ServerResponse: Res<Self::Error> + Send;
/// Returns the server function registered at the given URL, or `None` if no function is registered at that URL.
fn get_trait_obj(url: &str) -> Option<ServerFnTraitObj<T>>;
/// Returns the encoding of the server FN at the given URL, or `None` if no function is
/// registered at that URL
fn get_encoding(url: &str) -> Option<Encoding>;
/// Returns a list of all registered server functions.
fn paths_registered() -> Vec<&'static str>;
/// The return type of the server function.
///
/// This needs to be converted into `ServerResponse` on the server side, and converted
/// *from* `ClientResponse` when received by the client.
type Output: IntoRes<Self::Error, Self::ServerResponse, Self::OutputEncoding>
+ FromRes<Self::Error, <Self::Client as Client<Self::Error>>::Response, Self::OutputEncoding>
+ Send;
/// The [`Encoding`] used in the request for arguments into the server function.
type InputEncoding: Encoding;
/// The [`Encoding`] used in the response for the result of the server function.
type OutputEncoding: Encoding;
/// The type of the custom error on [`ServerFnError`], if any. (If there is no
/// custom error type, this can be `NoCustomError` by default.)
type Error: Serialize + DeserializeOwned;
/// Middleware that should be applied to this server function.
fn middlewares() -> Vec<Arc<dyn Layer<Self::ServerRequest, Self::ServerResponse>>> {
Vec::new()
}
// The body of the server function. This will only run on the server.
fn run_body(
self,
) -> impl Future<Output = Result<Self::Output, ServerFnError<Self::Error>>> + Send;
fn run_on_server(
req: Self::ServerRequest,
) -> impl Future<Output = Self::ServerResponse> + Send {
async {
Self::execute_on_server(req)
.await
.unwrap_or_else(Self::ServerResponse::error_response)
}
}
fn run_on_client(
self,
) -> impl Future<Output = Result<Self::Output, ServerFnError<Self::Error>>> + Send {
async move {
// create and send request on client
let req = self.into_req(Self::PATH, Self::OutputEncoding::CONTENT_TYPE)?;
let res = Self::Client::send(req).await?;
let status = res.status();
let location = res.location();
// if it returns an error status, deserialize the error
// this is the same logic as the current implementation of server fns
// TODO I don't love that this requires shipping `serde_json` for errors
let res = if (400..=599).contains(&status) {
let status_text = res.status_text();
let text = res.try_into_string().await?;
match serde_json::from_str(&text) {
Ok(e) => Err(e),
Err(_) => Err(ServerFnError::ServerError(if text.is_empty() {
format!("{} {}", status, status_text)
} else {
format!("{} {}: {}", status, status_text, text)
})),
}
} else {
// otherwise, deserialize the body as is
Self::Output::from_res(res).await
};
// if redirected, call the redirect hook (if that's been set)
if (300..=399).contains(&status) {
redirect::call_redirect_hook(&location);
}
res
}
}
#[doc(hidden)]
fn execute_on_server(
req: Self::ServerRequest,
) -> impl Future<Output = Result<Self::ServerResponse, ServerFnError<Self::Error>>> + Send {
async {
let this = Self::from_req(req).await?;
let output = this.run_body().await?;
let res = output.into_res().await?;
Ok(res)
}
}
fn url() -> &'static str {
Self::PATH
}
}
/// A server function that can be called from the client.
pub type SerializedFnTraitObj<T> =
fn(
T,
&[u8],
) -> Pin<Box<dyn Future<Output = Result<Payload, ServerFnError>>>>;
#[doc(hidden)]
pub use inventory;
/// A server function that can be called from the client.
#[derive(Clone)]
pub struct ServerFnTraitObj<T> {
pub(crate) prefix: &'static str,
pub(crate) url: &'static str,
pub(crate) encoding: Encoding,
pub(crate) run: SerializedFnTraitObj<T>,
#[macro_export]
macro_rules! initialize_server_fn_map {
($req:ty, $res:ty) => {
once_cell::sync::Lazy::new(|| {
$crate::inventory::iter::<ServerFnTraitObj<$req, $res>>
.into_iter()
.map(|obj| (obj.path(), *obj))
.collect()
})
};
}
impl<T> ServerFnTraitObj<T> {
/// Creates a new server function with the given prefix, URL, encoding, and function.
pub struct ServerFnTraitObj<Req, Res> {
path: &'static str,
handler: fn(Req) -> Pin<Box<dyn Future<Output = Res> + Send>>,
middleware: fn() -> Vec<Arc<dyn Layer<Req, Res>>>,
}
impl<Req, Res> ServerFnTraitObj<Req, Res> {
pub const fn new(
prefix: &'static str,
url: &'static str,
encoding: Encoding,
run: SerializedFnTraitObj<T>,
path: &'static str,
handler: fn(Req) -> Pin<Box<dyn Future<Output = Res> + Send>>,
middleware: fn() -> Vec<Arc<dyn Layer<Req, Res>>>,
) -> Self {
Self {
prefix,
url,
encoding,
run,
path,
handler,
middleware,
}
}
/// Runs the server function with the given server agruments and serialized buffer from the client.
pub fn call(
&self,
args: T,
buffer: &[u8],
) -> Pin<Box<dyn Future<Output = Result<Payload, ServerFnError>>>> {
(self.run)(args, buffer)
}
/// Returns the prefix of the server function.
pub fn prefix(&self) -> &str {
self.prefix
}
/// Returns the URL of the server function.
pub fn url(&self) -> &str {
self.url
}
/// Returns the encoding of the server function.
pub fn encoding(&self) -> Encoding {
self.encoding
pub fn path(&self) -> &'static str {
self.path
}
}
/// A dual type to hold the possible Response datatypes
#[derive(Debug)]
pub enum Payload {
///Encodes Data using CBOR
Binary(Vec<u8>),
///Encodes data in the URL
Url(String),
///Encodes Data using Json
Json(String),
}
/// Attempts to find a server function registered at the given path.
///
/// This can be used by a server to handle the requests, as in the following example (using [`actix-web`])
///
/// ```rust, ignore
/// #[post("{tail:.*}")]
/// async fn handle_server_fns(
/// req: HttpRequest,
/// params: web::Path<String>,
/// body: web::Bytes,
/// ) -> impl Responder {
/// let path = params.into_inner();
/// let accept_header = req
/// .headers()
/// .get("Accept")
/// .and_then(|value| value.to_str().ok());
///
/// if let Some(server_fn) = server_fn_by_path::<MyRegistry>(path.as_str()) {
/// let body: &[u8] = &body;
/// match (server_fn.trait_obj)(&body).await {
/// Ok(serialized) => {
/// // if this is Accept: application/json then send a serialized JSON response
/// if let Some("application/json") = accept_header {
/// HttpResponse::Ok().body(serialized)
/// }
/// // otherwise, it's probably a <form> submit or something: redirect back to the referrer
/// else {
/// HttpResponse::SeeOther()
/// .insert_header(("Location", "/"))
/// .content_type("application/json")
/// .body(serialized)
/// }
/// }
/// Err(e) => {
/// eprintln!("server function error: {e:#?}");
/// HttpResponse::InternalServerError().body(e.to_string())
/// }
/// }
/// } else {
/// HttpResponse::BadRequest().body(format!("Could not find a server function at that route."))
/// }
/// }
/// ```
///
/// [`actix-web`]: <https://docs.rs/actix-web/>
#[cfg(any(feature = "ssr", doc))]
pub fn server_fn_by_path<T: 'static, R: ServerFunctionRegistry<T>>(
path: &str,
) -> Option<ServerFnTraitObj<T>> {
R::get(path)
}
/// Returns a trait obj of the server fn for calling purposes
#[cfg(any(feature = "ssr", doc))]
pub fn server_fn_trait_obj_by_path<T: 'static, R: ServerFunctionRegistry<T>>(
path: &str,
) -> Option<ServerFnTraitObj<T>> {
R::get_trait_obj(path)
}
/// Returns the Encoding of the server fn at a particular path
#[cfg(any(feature = "ssr", doc))]
pub fn server_fn_encoding_by_path<T: 'static, R: ServerFunctionRegistry<T>>(
path: &str,
) -> Option<Encoding> {
R::get_encoding(path)
}
/// Returns the set of currently-registered server function paths, for debugging purposes.
#[cfg(any(feature = "ssr", doc))]
pub fn server_fns_by_path<T: 'static, R: ServerFunctionRegistry<T>>(
) -> Vec<&'static str> {
R::paths_registered()
}
/// Holds the current options for encoding types.
/// More could be added, but they need to be serde
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub enum Encoding {
/// A Binary Encoding Scheme Called Cbor
Cbor,
/// The Default URL-encoded encoding method
#[default]
Url,
/// Pass arguments to server fns as part of the query string. Cacheable. Returns JSON
GetJSON,
/// Pass arguments to server fns as part of the query string. Cacheable. Returns CBOR
GetCBOR,
}
impl FromStr for Encoding {
type Err = ();
fn from_str(input: &str) -> Result<Encoding, Self::Err> {
match input {
"URL" => Ok(Encoding::Url),
"Cbor" => Ok(Encoding::Cbor),
"GetCbor" => Ok(Encoding::GetCBOR),
"GetJson" => Ok(Encoding::GetJSON),
_ => Err(()),
}
}
}
#[cfg(any(feature = "ssr", doc))]
impl quote::ToTokens for Encoding {
fn to_tokens(&self, tokens: &mut TokenStream) {
let option: syn::Ident = match *self {
Encoding::Cbor => parse_quote!(Cbor),
Encoding::Url => parse_quote!(Url),
Encoding::GetJSON => parse_quote!(GetJSON),
Encoding::GetCBOR => parse_quote!(GetCBOR),
};
let expansion: syn::Ident = syn::parse_quote! {
Encoding::#option
};
tokens.append(expansion);
}
}
/// Defines a "server function." A server function can be called from the server or the client,
/// but the body of its code will only be run on the server, i.e., if a crate feature `ssr` (server-side-rendering) is enabled.
///
/// Server functions are created using the `server` macro.
///
/// The set of server functions can be queried on the server for routing purposes by calling [`server_fn_by_path`].
///
/// Technically, the trait is implemented on a type that describes the server function's arguments.
pub trait ServerFn<T: 'static>
impl<Req, Res> Service<Req, Res> for ServerFnTraitObj<Req, Res>
where
Self: Serialize + DeserializeOwned + Sized + 'static,
Req: Send + 'static,
Res: 'static,
{
/// The return type of the function.
type Output: serde::Serialize;
/// URL prefix that should be prepended by the client to the generated URL.
fn prefix() -> &'static str;
/// The path at which the server function can be reached on the server.
fn url() -> &'static str;
/// The path at which the server function can be reached on the server.
fn encoding() -> Encoding;
/// Runs the function on the server.
#[cfg(any(feature = "ssr", doc))]
fn call_fn(
self,
cx: T,
) -> Pin<Box<dyn Future<Output = Result<Self::Output, ServerFnError>>>>;
/// Runs the function on the client by sending an HTTP request to the server.
#[cfg(any(not(feature = "ssr"), doc))]
fn call_fn_client(
self,
cx: T,
) -> Pin<Box<dyn Future<Output = Result<Self::Output, ServerFnError>>>>;
/// Returns a trait object that can be used to call the server function.
#[cfg(any(feature = "ssr", doc))]
fn call_from_bytes(
cx: T,
data: &[u8],
) -> Pin<Box<dyn Future<Output = Result<Payload, ServerFnError>>>> {
// decode the args
let value = match Self::encoding() {
Encoding::Url | Encoding::GetJSON | Encoding::GetCBOR => {
serde_qs::Config::new(5, false)
.deserialize_bytes(data)
.map_err(|e| ServerFnError::Deserialization(e.to_string()))
}
Encoding::Cbor => ciborium::de::from_reader(data)
.map_err(|e| ServerFnError::Deserialization(e.to_string())),
};
Box::pin(async move {
let value: Self = match value {
Ok(v) => v,
Err(e) => return Err(e),
};
// call the function
let result = match value.call_fn(cx).await {
Ok(r) => r,
Err(e) => return Err(e),
};
// serialize the output
let result = match Self::encoding() {
Encoding::Url | Encoding::GetJSON => {
match serde_json::to_string(&result).map_err(|e| {
ServerFnError::Serialization(e.to_string())
}) {
Ok(r) => Payload::Url(r),
Err(e) => return Err(e),
}
}
Encoding::Cbor | Encoding::GetCBOR => {
let mut buffer: Vec<u8> = Vec::new();
match ciborium::ser::into_writer(&result, &mut buffer)
.map_err(|e| {
ServerFnError::Serialization(e.to_string())
}) {
Ok(_) => Payload::Binary(buffer),
Err(e) => return Err(e),
}
}
};
Ok(result)
})
as Pin<Box<dyn Future<Output = Result<Payload, ServerFnError>>>>
}
/// Registers the server function explicitly on platforms that require it,
/// allowing the server to query it by URL.
#[cfg(any(feature = "ssr", doc,))]
fn register_in_explicit<R: ServerFunctionRegistry<T>>(
) -> Result<(), ServerFnError> {
// store it in the hashmap
R::register_explicit(
Self::prefix(),
Self::url(),
Self::call_from_bytes,
Self::encoding(),
)
.map_err(|e| ServerFnError::Registration(e.to_string()))
fn run(&mut self, req: Req) -> Pin<Box<dyn Future<Output = Res> + Send>> {
let handler = self.handler;
Box::pin(async move { handler(req).await })
}
}
/// Executes the HTTP call to call a server function from the client, given its URL and argument type.
#[cfg(not(feature = "ssr"))]
pub async fn call_server_fn<T, C: 'static>(
url: &str,
args: impl ServerFn<C>,
enc: Encoding,
) -> Result<T, ServerFnError>
where
T: serde::Serialize + serde::de::DeserializeOwned + Sized,
{
use ciborium::ser::into_writer;
use serde_json::Deserializer as JSONDeserializer;
#[cfg(not(target_arch = "wasm32"))]
let url = format!("{}{}", get_server_url(), url);
#[derive(Debug)]
enum Payload {
Binary(Vec<u8>),
Url(String),
impl<Req, Res> Clone for ServerFnTraitObj<Req, Res> {
fn clone(&self) -> Self {
*self
}
let args_encoded = match &enc {
Encoding::Url | Encoding::GetJSON | Encoding::GetCBOR => Payload::Url(
serde_qs::to_string(&args)
.map_err(|e| ServerFnError::Serialization(e.to_string()))?,
),
Encoding::Cbor => {
let mut buffer: Vec<u8> = Vec::new();
into_writer(&args, &mut buffer)
.map_err(|e| ServerFnError::Serialization(e.to_string()))?;
Payload::Binary(buffer)
}
};
}
let content_type_header = match &enc {
Encoding::Url | Encoding::GetJSON | Encoding::GetCBOR => {
"application/x-www-form-urlencoded"
}
Encoding::Cbor => "application/cbor",
};
impl<Req, Res> Copy for ServerFnTraitObj<Req, Res> {}
let accept_header = match &enc {
Encoding::Url | Encoding::GetJSON => {
"application/x-www-form-urlencoded"
}
Encoding::Cbor | Encoding::GetCBOR => "application/cbor",
};
type LazyServerFnMap<Req, Res> = Lazy<DashMap<&'static str, ServerFnTraitObj<Req, Res>>>;
#[cfg(target_arch = "wasm32")]
let resp = match &enc {
Encoding::Url | Encoding::Cbor => match args_encoded {
Payload::Binary(b) => {
let slice_ref: &[u8] = &b;
let js_array = js_sys::Uint8Array::from(slice_ref).buffer();
gloo_net::http::Request::post(url)
.header("Content-Type", content_type_header)
.header("Accept", accept_header)
.body(js_array)
.send()
.await
.map_err(|e| ServerFnError::Request(e.to_string()))?
}
Payload::Url(s) => gloo_net::http::Request::post(url)
.header("Content-Type", content_type_header)
.header("Accept", accept_header)
.body(s)
.send()
.await
.map_err(|e| ServerFnError::Request(e.to_string()))?,
},
Encoding::GetCBOR | Encoding::GetJSON => match args_encoded {
Payload::Binary(_) => panic!(
"Binary data cannot be transferred via GET request in a query \
string. Please try using the CBOR encoding."
// Axum integration
#[cfg(feature = "axum")]
pub mod axum {
use crate::{
middleware::{BoxedService, Layer, Service},
LazyServerFnMap, ServerFn, ServerFnTraitObj,
};
use axum::body::Body;
use http::{Request, Response, StatusCode};
inventory::collect!(ServerFnTraitObj<Request<Body>, Response<Body>>);
static REGISTERED_SERVER_FUNCTIONS: LazyServerFnMap<Request<Body>, Response<Body>> =
initialize_server_fn_map!(Request<Body>, Response<Body>);
pub fn register_explicit<T>()
where
T: ServerFn<ServerRequest = Request<Body>, ServerResponse = Response<Body>> + 'static,
{
REGISTERED_SERVER_FUNCTIONS.insert(
T::PATH,
ServerFnTraitObj::new(
T::PATH,
|req| Box::pin(T::run_on_server(req)),
T::middlewares,
),
Payload::Url(s) => {
let full_url = format!("{url}?{s}");
gloo_net::http::Request::get(&full_url)
.header("Content-Type", content_type_header)
.header("Accept", accept_header)
.send()
.await
.map_err(|e| ServerFnError::Request(e.to_string()))?
}
},
};
#[cfg(not(target_arch = "wasm32"))]
let resp = match &enc {
Encoding::Url | Encoding::Cbor => match args_encoded {
Payload::Binary(b) => CLIENT
.post(url)
.header("Content-Type", content_type_header)
.header("Accept", accept_header)
.body(b)
.send()
.await
.map_err(|e| ServerFnError::Request(e.to_string()))?,
Payload::Url(s) => CLIENT
.post(url)
.header("Content-Type", content_type_header)
.header("Accept", accept_header)
.body(s)
.send()
.await
.map_err(|e| ServerFnError::Request(e.to_string()))?,
},
Encoding::GetJSON | Encoding::GetCBOR => match args_encoded {
Payload::Binary(_) => panic!(
"Binary data cannot be transferred via GET request in a query \
string. Please try using the CBOR encoding."
),
Payload::Url(s) => {
let full_url = format!("{url}?{s}");
CLIENT
.get(full_url)
.header("Content-Type", content_type_header)
.header("Accept", accept_header)
.send()
.await
.map_err(|e| ServerFnError::Request(e.to_string()))?
}
},
};
// check for error status
let status = resp.status();
#[cfg(not(target_arch = "wasm32"))]
let status = status.as_u16();
if (400..=599).contains(&status) {
let text = resp.text().await.unwrap_or_default();
return Err(match serde_json::from_str(&text) {
Ok(e) => e,
Err(_) => {
#[cfg(target_arch = "wasm32")]
let status_text = resp.status_text();
#[cfg(not(target_arch = "wasm32"))]
let status_text = status.to_string();
ServerFnError::ServerError(if text.is_empty() {
format!("{} {}", status, status_text)
} else {
format!("{} {}: {}", status, status_text, text)
})
}
});
);
}
// Decoding the body of the request
if (enc == Encoding::Cbor) || (enc == Encoding::GetCBOR) {
#[cfg(target_arch = "wasm32")]
let binary = resp
.binary()
.await
.map_err(|e| ServerFnError::Deserialization(e.to_string()))?;
#[cfg(target_arch = "wasm32")]
let binary = binary.as_slice();
#[cfg(not(target_arch = "wasm32"))]
let binary = resp
.bytes()
.await
.map_err(|e| ServerFnError::Deserialization(e.to_string()))?;
#[cfg(not(target_arch = "wasm32"))]
let binary = binary.as_ref();
pub async fn handle_server_fn(req: Request<Body>) -> Response<Body> {
let path = req.uri().path();
ciborium::de::from_reader(binary)
.map_err(|e| ServerFnError::Deserialization(e.to_string()))
} else {
let text = resp
.text()
.await
.map_err(|e| ServerFnError::Deserialization(e.to_string()))?;
let mut deserializer = JSONDeserializer::from_str(&text);
T::deserialize(&mut deserializer)
.map_err(|e| ServerFnError::Deserialization(e.to_string()))
if let Some(server_fn) = REGISTERED_SERVER_FUNCTIONS.get(path) {
let middleware = (server_fn.middleware)();
let mut service = BoxedService::new(*server_fn);
for middleware in middleware {
service = middleware.layer(service);
}
service.run(req).await
} else {
Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from(format!(
"Could not find a server function at the route {path}. \n\nIt's likely that either\n 1. The API prefix you specify in the `#[server]` macro doesn't match the prefix at which your server function handler is mounted, or \n2. You are on a platform that doesn't support automatic server function registration and you need to call ServerFn::register_explicit() on the server function type, somewhere in your `main` function.",
)))
.unwrap()
}
}
}
// Lazily initialize the client to be reused for all server function calls.
#[cfg(any(all(not(feature = "ssr"), not(target_arch = "wasm32")), doc))]
static CLIENT: once_cell::sync::Lazy<reqwest::Client> =
once_cell::sync::Lazy::new(reqwest::Client::new);
// Actix integration
#[cfg(feature = "actix")]
pub mod actix {
use actix_web::{HttpRequest, HttpResponse};
use send_wrapper::SendWrapper;
#[cfg(any(all(not(feature = "ssr"), not(target_arch = "wasm32")), doc))]
static ROOT_URL: once_cell::sync::OnceCell<&'static str> =
once_cell::sync::OnceCell::new();
use crate::request::actix::ActixRequest;
use crate::response::actix::ActixResponse;
use crate::{LazyServerFnMap, ServerFn, ServerFnTraitObj};
#[cfg(any(all(not(feature = "ssr"), not(target_arch = "wasm32")), doc))]
/// Set the root server url that all server function paths are relative to for the client. On WASM this will default to the origin.
pub fn set_server_url(url: &'static str) {
ROOT_URL.set(url).unwrap();
}
#[cfg(all(not(feature = "ssr"), not(target_arch = "wasm32")))]
fn get_server_url() -> &'static str {
ROOT_URL
.get()
.expect("Call set_root_url before calling a server function.")
inventory::collect!(ServerFnTraitObj<ActixRequest, ActixResponse>);
static REGISTERED_SERVER_FUNCTIONS: LazyServerFnMap<ActixRequest, ActixResponse> =
initialize_server_fn_map!(ActixRequest, ActixResponse);
pub fn register_explicit<T>()
where
T: ServerFn<ServerRequest = ActixRequest, ServerResponse = ActixResponse> + 'static,
{
REGISTERED_SERVER_FUNCTIONS.insert(
T::PATH,
ServerFnTraitObj::new(T::PATH, |req| Box::pin(T::run_on_server(req))),
);
}
pub async fn handle_server_fn(req: HttpRequest) -> HttpResponse {
let path = req.uri().path();
if let Some(server_fn) = REGISTERED_SERVER_FUNCTIONS.get(path) {
server_fn
.run(ActixRequest(SendWrapper::new(req)))
.await
.0
.take()
} else {
HttpResponse::BadRequest().body(format!(
"Could not find a server function at the route {path}. \n\nIt's likely that either\n 1. The API prefix you specify in the `#[server]` macro doesn't match the prefix at which your server function handler is mounted, or \n2. You are on a platform that doesn't support automatic server function registration and you need to call ServerFn::register_explicit() on the server function type, somewhere in your `main` function.",
))
}
}
}

View file

@ -0,0 +1,109 @@
use std::{future::Future, pin::Pin};
pub trait Layer<Req, Res>: Send + Sync + 'static {
fn layer(&self, inner: BoxedService<Req, Res>) -> BoxedService<Req, Res>;
}
pub struct BoxedService<Req, Res>(pub Box<dyn Service<Req, Res> + Send>);
impl<Req, Res> BoxedService<Req, Res> {
pub fn new(service: impl Service<Req, Res> + Send + 'static) -> Self {
Self(Box::new(service))
}
}
pub trait Service<Request, Response> {
fn run(&mut self, req: Request) -> Pin<Box<dyn Future<Output = Response> + Send>>;
}
#[cfg(feature = "axum")]
mod axum {
use crate::{response::Res, ServerFnError};
use axum::body::Body;
use http::{Request, Response};
use std::fmt::{Debug, Display};
use std::future::Future;
use std::pin::Pin;
use super::{BoxedService, Service};
impl<S> super::Service<Request<Body>, Response<Body>> for S
where
S: tower::Service<Request<Body>, Response = Response<Body>>,
S::Future: Send + 'static,
S::Error: Into<ServerFnError> + Send + Debug + Display + Sync + 'static,
{
fn run(
&mut self,
req: Request<Body>,
) -> Pin<Box<dyn Future<Output = Response<Body>> + Send>> {
let inner = self.call(req);
Box::pin(async move {
inner.await.unwrap_or_else(|e| {
let err = ServerFnError::from(e);
Response::<Body>::error_response(err)
})
})
}
}
impl tower::Service<Request<Body>> for BoxedService<Request<Body>, Response<Body>> {
type Response = Response<Body>;
type Error = ServerFnError;
type Future =
Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(
&mut self,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
Ok(()).into()
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let inner = self.0.run(req);
Box::pin(async move { Ok(inner.await) })
}
}
impl<L> super::Layer<Request<Body>, Response<Body>> for L
where
L: tower_layer::Layer<BoxedService<Request<Body>, Response<Body>>> + Sync + Send + 'static,
L::Service: Service<Request<Body>, Response<Body>> + Send + 'static,
{
fn layer(
&self,
inner: BoxedService<Request<Body>, Response<Body>>,
) -> BoxedService<Request<Body>, Response<Body>> {
BoxedService(Box::new(self.layer(inner)))
}
}
}
#[cfg(feature = "actix")]
mod actix {
use crate::{
response::{actix::ActixResponse, Res},
ServerFnError,
};
use actix_web::{HttpRequest, HttpResponse};
use std::fmt::{Debug, Display};
use std::{future::Future, pin::Pin};
impl<S> super::Service<HttpRequest, HttpResponse> for S
where
S: actix_web::dev::Service<HttpRequest, Response = HttpResponse>,
S::Future: Send + 'static,
S::Error: Into<ServerFnError> + Debug + Display + 'static,
{
fn run(&mut self, req: HttpRequest) -> Pin<Box<dyn Future<Output = HttpResponse> + Send>> {
let inner = self.call(req);
Box::pin(async move {
inner.await.unwrap_or_else(|e| {
let err = ServerFnError::from(e);
ActixResponse::error_response(err).into_inner()
})
})
}
}
}

13
server_fn/src/redirect.rs Normal file
View file

@ -0,0 +1,13 @@
use std::sync::OnceLock;
static REDIRECT_HOOK: OnceLock<Box<dyn Fn(&str) + Send + Sync>> = OnceLock::new();
pub fn set_redirect_hook(hook: impl Fn(&str) + Send + Sync + 'static) {
REDIRECT_HOOK.set(Box::new(hook));
}
pub fn call_redirect_hook(path: &str) {
if let Some(hook) = REDIRECT_HOOK.get() {
hook(path)
}
}

View file

@ -0,0 +1,50 @@
use crate::{error::ServerFnError, request::Req};
use actix_web::{FromRequest, HttpRequest};
use bytes::Bytes;
use futures::Stream;
use send_wrapper::SendWrapper;
use std::future::Future;
pub struct ActixRequest(pub(crate) SendWrapper<HttpRequest>);
impl<CustErr> Req<CustErr> for ActixRequest {
fn as_query(&self) -> Option<&str> {
self.0.uri().query()
}
fn to_content_type(&self) -> Option<String> {
self.0
.headers()
.get("Content-Type")
.map(|h| String::from_utf8_lossy(h.as_bytes()).to_string())
}
fn try_into_bytes(self) -> impl Future<Output = Result<Bytes, ServerFnError<CustErr>>> + Send {
// Actix is going to keep this on a single thread anyway so it's fine to wrap it
// with SendWrapper, which makes it `Send` but will panic if it moves to another thread
SendWrapper::new(async move {
Bytes::extract(&self.0)
.await
.map_err(|e| ServerFnError::Deserialization(e.to_string()))
})
}
fn try_into_string(
self,
) -> impl Future<Output = Result<String, ServerFnError<CustErr>>> + Send {
// Actix is going to keep this on a single thread anyway so it's fine to wrap it
// with SendWrapper, which makes it `Send` but will panic if it moves to another thread
SendWrapper::new(async move {
String::extract(&self.0)
.await
.map_err(|e| ServerFnError::Deserialization(e.to_string()))
})
}
fn try_into_stream(
self,
) -> Result<impl Stream<Item = Result<Bytes, ServerFnError>> + Send, ServerFnError<CustErr>>
{
Ok(futures::stream::once(async { todo!() }))
}
}

View file

@ -0,0 +1,43 @@
use crate::{error::ServerFnError, request::Req};
use axum::body::{Body, Bytes};
use futures::{Stream, StreamExt};
use http::{header::CONTENT_TYPE, Request};
use http_body_util::BodyExt;
impl<CustErr> Req<CustErr> for Request<Body> {
fn as_query(&self) -> Option<&str> {
self.uri().query()
}
fn to_content_type(&self) -> Option<String> {
self.headers()
.get(CONTENT_TYPE)
.map(|h| String::from_utf8_lossy(h.as_bytes()).to_string())
}
async fn try_into_bytes(self) -> Result<Bytes, ServerFnError<CustErr>> {
let (_parts, body) = self.into_parts();
body.collect()
.await
.map(|c| c.to_bytes())
.map_err(|e| ServerFnError::Deserialization(e.to_string()))
}
async fn try_into_string(self) -> Result<String, ServerFnError<CustErr>> {
let bytes = self.try_into_bytes().await?;
let body = String::from_utf8(bytes.to_vec())
.map_err(|e| ServerFnError::Deserialization(e.to_string()));
body
}
fn try_into_stream(
self,
) -> Result<impl Stream<Item = Result<Bytes, ServerFnError>> + Send, ServerFnError<CustErr>>
{
Ok(self
.into_body()
.into_data_stream()
.map(|chunk| chunk.map_err(|e| ServerFnError::Deserialization(e.to_string()))))
}
}

View file

@ -0,0 +1,93 @@
use crate::error::ServerFnError;
use super::ClientReq;
use bytes::Bytes;
pub use gloo_net::http::Request;
use js_sys::Uint8Array;
use send_wrapper::SendWrapper;
use web_sys::FormData;
#[derive(Debug)]
pub struct BrowserRequest(pub(crate) SendWrapper<Request>);
impl From<Request> for BrowserRequest {
fn from(value: Request) -> Self {
Self(SendWrapper::new(value))
}
}
#[derive(Debug)]
pub struct BrowserFormData(pub(crate) SendWrapper<FormData>);
impl From<FormData> for BrowserFormData {
fn from(value: FormData) -> Self {
Self(SendWrapper::new(value))
}
}
impl<CustErr> ClientReq<CustErr> for BrowserRequest {
type FormData = BrowserFormData;
fn try_new_get(
path: &str,
accepts: &str,
content_type: &str,
query: &str,
) -> Result<Self, ServerFnError<CustErr>> {
let mut url = path.to_owned();
url.push('?');
url.push_str(query);
Ok(Self(SendWrapper::new(
Request::get(&url)
.header("Content-Type", content_type)
.header("Accept", accepts)
.build()
.map_err(|e| ServerFnError::Request(e.to_string()))?,
)))
}
fn try_new_post(
path: &str,
accepts: &str,
content_type: &str,
body: String,
) -> Result<Self, ServerFnError<CustErr>> {
Ok(Self(SendWrapper::new(
Request::post(path)
.header("Content-Type", content_type)
.header("Accept", accepts)
.body(body)
.map_err(|e| ServerFnError::Request(e.to_string()))?,
)))
}
fn try_new_post_bytes(
path: &str,
accepts: &str,
content_type: &str,
body: Bytes,
) -> Result<Self, ServerFnError<CustErr>> {
let body: &[u8] = &body;
let body = Uint8Array::from(body).buffer();
Ok(Self(SendWrapper::new(
Request::post(path)
.header("Content-Type", content_type)
.header("Accept", accepts)
.body(body)
.map_err(|e| ServerFnError::Request(e.to_string()))?,
)))
}
fn try_new_multipart(
path: &str,
accepts: &str,
body: Self::FormData,
) -> Result<Self, ServerFnError<CustErr>> {
Ok(Self(SendWrapper::new(
Request::post(path)
.header("Accept", accepts)
.body(body.0.take())
.map_err(|e| ServerFnError::Request(e.to_string()))?,
)))
}
}

View file

@ -0,0 +1,103 @@
use crate::error::ServerFnError;
use bytes::Bytes;
use futures::Stream;
use std::future::Future;
#[cfg(feature = "actix")]
pub mod actix;
#[cfg(feature = "axum")]
pub mod axum;
#[cfg(feature = "browser")]
pub mod browser;
#[cfg(feature = "reqwest")]
pub mod reqwest;
/// Represents a request as made by the client.
pub trait ClientReq<CustErr>
where
Self: Sized,
{
type FormData;
fn try_new_get(
path: &str,
content_type: &str,
accepts: &str,
query: &str,
) -> Result<Self, ServerFnError<CustErr>>;
fn try_new_post(
path: &str,
content_type: &str,
accepts: &str,
body: String,
) -> Result<Self, ServerFnError<CustErr>>;
fn try_new_post_bytes(
path: &str,
content_type: &str,
accepts: &str,
body: Bytes,
) -> Result<Self, ServerFnError<CustErr>>;
fn try_new_multipart(
path: &str,
accepts: &str,
body: Self::FormData,
) -> Result<Self, ServerFnError<CustErr>>;
}
/// Represents the request as received by the server.
pub trait Req<CustErr>
where
Self: Sized,
{
/// Returns the query string of the requests URL, starting after the `?`.
fn as_query(&self) -> Option<&str>;
/// Returns the `Content-Type` header, if any.
fn to_content_type(&self) -> Option<String>;
/// Attempts to extract the body of the request into [`Bytes`].
fn try_into_bytes(self) -> impl Future<Output = Result<Bytes, ServerFnError<CustErr>>> + Send;
/// Attempts to convert the body of the request into a string.
fn try_into_string(self)
-> impl Future<Output = Result<String, ServerFnError<CustErr>>> + Send;
/// Attempts to convert the body of the request into a string.
fn try_into_stream(
self,
) -> Result<impl Stream<Item = Result<Bytes, ServerFnError>> + Send, ServerFnError<CustErr>>;
}
/// A mocked request type that can be used in place of the actual server request,
/// when compiling for the browser.
pub struct BrowserMockReq;
impl<CustErr> Req<CustErr> for BrowserMockReq {
fn as_query(&self) -> Option<&str> {
unreachable!()
}
fn to_content_type(&self) -> Option<String> {
unreachable!()
}
fn try_into_bytes(self) -> impl Future<Output = Result<Bytes, ServerFnError<CustErr>>> + Send {
async { unreachable!() }
}
fn try_into_string(
self,
) -> impl Future<Output = Result<String, ServerFnError<CustErr>>> + Send {
async { unreachable!() }
}
fn try_into_stream(
self,
) -> Result<impl Stream<Item = Result<Bytes, ServerFnError>> + Send, ServerFnError<CustErr>>
{
Ok(futures::stream::once(async { unreachable!() }))
}
}

View file

@ -0,0 +1,93 @@
use std::sync::OnceLock;
use crate::error::ServerFnError;
use super::ClientReq;
use bytes::Bytes;
use once_cell::sync::Lazy;
use reqwest::header::{ACCEPT, CONTENT_TYPE};
pub use reqwest::{multipart::Form, Client, Method, Request, Url};
pub(crate) static CLIENT: Lazy<Client> = Lazy::new(Client::new);
static ROOT_URL: OnceLock<&'static str> = OnceLock::new();
/// Set the root server url that all server function paths are relative to for the client.
///
/// If this is not set, it defaults to the origin.
pub fn set_server_url(url: &'static str) {
ROOT_URL.set(url).unwrap();
}
fn get_server_url() -> &'static str {
ROOT_URL
.get()
.expect("Call `set_root_url` before calling a server function.")
}
impl<CustErr> ClientReq<CustErr> for Request {
type FormData = Form;
fn try_new_get(
path: &str,
accepts: &str,
content_type: &str,
query: &str,
) -> Result<Self, ServerFnError<CustErr>> {
let url = format!("{}{}", get_server_url(), path);
let mut url =
Url::try_from(url.as_str()).map_err(|e| ServerFnError::Request(e.to_string()))?;
url.set_query(Some(query));
let req = CLIENT
.get(url)
.header(CONTENT_TYPE, content_type)
.header(ACCEPT, accepts)
.build()
.map_err(|e| ServerFnError::Request(e.to_string()))?;
Ok(req)
}
fn try_new_post(
path: &str,
accepts: &str,
content_type: &str,
body: String,
) -> Result<Self, ServerFnError<CustErr>> {
let url = format!("{}{}", get_server_url(), path);
Ok(CLIENT
.post(url)
.header(CONTENT_TYPE, content_type)
.header(ACCEPT, accepts)
.body(body)
.build()
.map_err(|e| ServerFnError::Request(e.to_string()))?)
}
fn try_new_post_bytes(
path: &str,
accepts: &str,
content_type: &str,
body: Bytes,
) -> Result<Self, ServerFnError<CustErr>> {
let url = format!("{}{}", get_server_url(), path);
Ok(CLIENT
.post(url)
.header(CONTENT_TYPE, content_type)
.header(ACCEPT, accepts)
.body(body)
.build()
.map_err(|e| ServerFnError::Request(e.to_string()))?)
}
fn try_new_multipart(
path: &str,
accepts: &str,
body: Self::FormData,
) -> Result<Self, ServerFnError<CustErr>> {
Ok(CLIENT
.post(path)
.header(ACCEPT, accepts)
.multipart(body)
.build()
.map_err(|e| ServerFnError::Request(e.to_string()))?)
}
}

View file

@ -0,0 +1,51 @@
use super::Res;
use crate::error::ServerFnError;
use actix_web::{http::header, http::StatusCode, HttpResponse};
use bytes::Bytes;
use futures::Stream;
use send_wrapper::SendWrapper;
use std::fmt::Display;
pub struct ActixResponse(pub(crate) SendWrapper<HttpResponse>);
impl ActixResponse {
pub fn into_inner(self) -> HttpResponse {
self.0.take()
}
}
impl<CustErr> Res<CustErr> for ActixResponse
where
CustErr: Display,
{
fn try_from_string(content_type: &str, data: String) -> Result<Self, ServerFnError<CustErr>> {
let mut builder = HttpResponse::build(StatusCode::OK);
Ok(ActixResponse(SendWrapper::new(
builder
.insert_header((header::CONTENT_TYPE, content_type))
.body(data),
)))
}
fn try_from_bytes(content_type: &str, data: Bytes) -> Result<Self, ServerFnError<CustErr>> {
let mut builder = HttpResponse::build(StatusCode::OK);
Ok(ActixResponse(SendWrapper::new(
builder
.insert_header((header::CONTENT_TYPE, content_type))
.body(data),
)))
}
fn error_response(err: ServerFnError<CustErr>) -> Self {
ActixResponse(SendWrapper::new(
HttpResponse::build(StatusCode::INTERNAL_SERVER_ERROR).body(err.to_string()),
))
}
fn try_from_stream(
content_type: &str,
data: impl Stream<Item = Result<Bytes, ServerFnError<CustErr>>>,
) -> Result<Self, ServerFnError<CustErr>> {
todo!()
}
}

View file

@ -0,0 +1,73 @@
use crate::error::ServerFnError;
use super::ClientRes;
use bytes::Bytes;
use futures::{Stream, StreamExt};
pub use gloo_net::http::Response;
use js_sys::{wasm_bindgen::JsCast, Uint8Array};
use send_wrapper::SendWrapper;
use std::future::Future;
use wasm_streams::ReadableStream;
pub struct BrowserResponse(pub(crate) SendWrapper<Response>);
impl<CustErr> ClientRes<CustErr> for BrowserResponse {
fn try_into_string(
self,
) -> impl Future<Output = Result<String, ServerFnError<CustErr>>> + Send {
// the browser won't send this async work between threads (because it's single-threaded)
// so we can safely wrap this
SendWrapper::new(async move {
self.0
.text()
.await
.map_err(|e| ServerFnError::Deserialization(e.to_string()))
})
}
fn try_into_bytes(self) -> impl Future<Output = Result<Bytes, ServerFnError<CustErr>>> + Send {
// the browser won't send this async work between threads (because it's single-threaded)
// so we can safely wrap this
SendWrapper::new(async move {
self.0
.binary()
.await
.map(Bytes::from)
.map_err(|e| ServerFnError::Deserialization(e.to_string()))
})
}
fn try_into_stream(
self,
) -> Result<
impl Stream<Item = Result<Bytes, ServerFnError>> + Send + 'static,
ServerFnError<CustErr>,
> {
let stream = ReadableStream::from_raw(self.0.body().unwrap())
.into_stream()
.map(|data| {
let data = data.unwrap().unchecked_into::<Uint8Array>();
let mut buf = Vec::new();
let length = data.length();
buf.resize(length as usize, 0);
data.copy_to(&mut buf);
Ok(Bytes::from(buf))
});
Ok(SendWrapper::new(stream))
}
fn status(&self) -> u16 {
self.0.status()
}
fn status_text(&self) -> String {
self.0.status_text()
}
fn location(&self) -> String {
self.0
.headers()
.get("Location")
.unwrap_or_else(|| self.0.url())
}
}

View file

@ -0,0 +1,50 @@
use super::Res;
use crate::error::{ServerFnError, ServerFnErrorErr};
use axum::body::Body;
use bytes::Bytes;
use futures::{Stream, StreamExt};
use http::Response;
use std::fmt::{Debug, Display};
impl<CustErr> Res<CustErr> for Response<Body>
where
CustErr: Send + Sync + Debug + Display + 'static,
{
fn try_from_string(content_type: &str, data: String) -> Result<Self, ServerFnError<CustErr>> {
let builder = http::Response::builder();
builder
.status(200)
.header(http::header::CONTENT_TYPE, content_type)
.body(Body::from(data))
.map_err(|e| ServerFnError::Response(e.to_string()))
}
fn try_from_bytes(content_type: &str, data: Bytes) -> Result<Self, ServerFnError<CustErr>> {
let builder = http::Response::builder();
builder
.status(200)
.header(http::header::CONTENT_TYPE, content_type)
.body(Body::from(data))
.map_err(|e| ServerFnError::Response(e.to_string()))
}
fn try_from_stream(
content_type: &str,
data: impl Stream<Item = Result<Bytes, ServerFnError<CustErr>>> + Send + 'static,
) -> Result<Self, ServerFnError<CustErr>> {
let body = Body::from_stream(data.map(|n| n.map_err(ServerFnErrorErr::from)));
let builder = http::Response::builder();
builder
.status(200)
.header(http::header::CONTENT_TYPE, content_type)
.body(body)
.map_err(|e| ServerFnError::Response(e.to_string()))
}
fn error_response(err: ServerFnError<CustErr>) -> Self {
Response::builder()
.status(http::StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from(err.to_string()))
.unwrap()
}
}

View file

@ -0,0 +1,85 @@
#[cfg(feature = "actix")]
pub mod actix;
#[cfg(feature = "browser")]
pub mod browser;
#[cfg(feature = "axum")]
pub mod http;
#[cfg(feature = "reqwest")]
pub mod reqwest;
use crate::error::ServerFnError;
use bytes::Bytes;
use futures::Stream;
use std::future::Future;
/// Represents the response as created by the server;
pub trait Res<CustErr>
where
Self: Sized,
{
/// Attempts to convert a UTF-8 string into an HTTP response.
fn try_from_string(content_type: &str, data: String) -> Result<Self, ServerFnError<CustErr>>;
/// Attempts to convert a binary blob represented as bytes into an HTTP response.
fn try_from_bytes(content_type: &str, data: Bytes) -> Result<Self, ServerFnError<CustErr>>;
/// Attempts to convert a stream of bytes into an HTTP response.
fn try_from_stream(
content_type: &str,
data: impl Stream<Item = Result<Bytes, ServerFnError<CustErr>>> + Send + 'static,
) -> Result<Self, ServerFnError<CustErr>>;
fn error_response(err: ServerFnError<CustErr>) -> Self;
}
/// Represents the response as received by the client.
pub trait ClientRes<CustErr> {
/// Attempts to extract a UTF-8 string from an HTTP response.
fn try_into_string(self)
-> impl Future<Output = Result<String, ServerFnError<CustErr>>> + Send;
/// Attempts to extract a binary blob from an HTTP response.
fn try_into_bytes(self) -> impl Future<Output = Result<Bytes, ServerFnError<CustErr>>> + Send;
/// Attempts to extract a binary stream from an HTTP response.
fn try_into_stream(
self,
) -> Result<
impl Stream<Item = Result<Bytes, ServerFnError>> + Send + 'static,
ServerFnError<CustErr>,
>;
/// HTTP status code of the response.
fn status(&self) -> u16;
/// Status text for the status code.
fn status_text(&self) -> String;
/// The `Location` header or (if none is set), the URL of the response.
fn location(&self) -> String;
}
/// A mocked response type that can be used in place of the actual server response,
/// when compiling for the browser.
pub struct BrowserMockRes;
impl<CustErr> Res<CustErr> for BrowserMockRes {
fn try_from_string(content_type: &str, data: String) -> Result<Self, ServerFnError<CustErr>> {
unreachable!()
}
fn try_from_bytes(content_type: &str, data: Bytes) -> Result<Self, ServerFnError<CustErr>> {
unreachable!()
}
fn error_response(err: ServerFnError<CustErr>) -> Self {
unreachable!()
}
fn try_from_stream(
content_type: &str,
data: impl Stream<Item = Result<Bytes, ServerFnError<CustErr>>>,
) -> Result<Self, ServerFnError<CustErr>> {
todo!()
}
}

View file

@ -0,0 +1,45 @@
use super::ClientRes;
use crate::error::ServerFnError;
use bytes::Bytes;
use futures::{Stream, TryStreamExt};
use reqwest::Response;
impl<CustErr> ClientRes<CustErr> for Response {
async fn try_into_string(self) -> Result<String, ServerFnError<CustErr>> {
self.text()
.await
.map_err(|e| ServerFnError::Deserialization(e.to_string()))
}
async fn try_into_bytes(self) -> Result<Bytes, ServerFnError<CustErr>> {
self.bytes()
.await
.map_err(|e| ServerFnError::Deserialization(e.to_string()))
}
fn try_into_stream(
self,
) -> Result<
impl Stream<Item = Result<Bytes, ServerFnError>> + Send + 'static,
ServerFnError<CustErr>,
> {
Ok(self
.bytes_stream()
.map_err(|e| ServerFnError::Response(e.to_string())))
}
fn status(&self) -> u16 {
self.status().as_u16()
}
fn status_text(&self) -> String {
self.status().to_string()
}
fn location(&self) -> String {
self.headers()
.get("Location")
.map(|value| String::from_utf8_lossy(value.as_bytes()).to_string())
.unwrap_or_else(|| self.url().to_string())
}
}

View file

@ -1,22 +1,22 @@
[package]
name = "server_fn_macro"
version = { workspace = true }
version = "0.6.0"
edition = "2021"
authors = ["Greg Johnston"]
license = "MIT"
repository = "https://github.com/leptos-rs/leptos"
description = "RPC for any web framework."
readme = "../README.md"
[dependencies]
serde = { version = "1", features = ["derive"] }
quote = "1"
syn = { version = "2", features = ["full", "parsing", "extra-traits"] }
proc-macro2 = "1"
proc-macro-error = { version = "1", default-features = false }
xxhash-rust = { version = "0.8.6", features = ["const_xxh64"] }
const_format = "0.2.30"
convert_case = "0.6.0"
[features]
nightly = []
ssr = []
actix = []
axum = []

View file

@ -1 +0,0 @@
extend = { path = "../cargo-make/main.toml" }

View file

@ -1,16 +1,14 @@
#![cfg_attr(feature = "nightly", feature(proc_macro_span))]
#![forbid(unsafe_code)]
#![deny(missing_docs)]
// to prevent warnings from popping up when a nightly feature is stabilized
#![allow(stable_features)]
//! Implementation of the `server_fn` macro.
//!
//! This crate contains the implementation of the `server_fn` macro. [`server_macro_impl`] can be used to implement custom versions of the macro for different frameworks that allow users to pass a custom context from the server to the server function.
use convert_case::{Case, Converter};
use proc_macro2::{Literal, Span, TokenStream as TokenStream2};
use proc_macro_error::abort;
use quote::{quote, quote_spanned};
use quote::{quote, quote_spanned, ToTokens};
use syn::{
parse::{Parse, ParseStream},
punctuated::Punctuated,
@ -18,30 +16,6 @@ use syn::{
*,
};
/// Describes the custom context from the server that passed to the server function. Optionally, the first argument of a server function
/// can be a custom context of this type. This context can be used to access the server's state within the server function.
pub struct ServerContext {
/// The type of the context.
pub ty: Ident,
/// The path to the context type. Used to reference the context type in the generated code.
pub path: Path,
}
fn fn_arg_is_cx(f: &syn::FnArg, server_context: &ServerContext) -> bool {
if let FnArg::Typed(t) = f {
if let Type::Path(path) = &*t.ty {
path.path
.segments
.iter()
.any(|segment| segment.ident == server_context.ty)
} else {
false
}
} else {
false
}
}
/// The implementation of the `server_fn` macro.
/// To allow the macro to accept a custom context from the server, pass a custom server context to this function.
/// **The Context comes from the server.** Optionally, the first argument of a server function
@ -53,14 +27,9 @@ fn fn_arg_is_cx(f: &syn::FnArg, server_context: &ServerContext) -> bool {
/// ```ignore
/// #[proc_macro_attribute]
/// pub fn server(args: proc_macro::TokenStream, s: TokenStream) -> TokenStream {
/// let server_context = Some(ServerContext {
/// ty: syn::parse_quote!(MyContext),
/// path: syn::parse_quote!(my_crate::prelude::MyContext),
/// });
/// match server_macro_impl(
/// args.into(),
/// s.into(),
/// Some(server_context),
/// Some(syn::parse_quote!(my_crate::exports::server_fn)),
/// ) {
/// Err(e) => e.to_compile_error().into(),
@ -72,21 +41,57 @@ pub fn server_macro_impl(
args: TokenStream2,
body: TokenStream2,
trait_obj_wrapper: Type,
server_context: Option<ServerContext>,
server_fn_path: Option<Path>,
default_path: &str,
) -> Result<TokenStream2> {
let ServerFnName {
let mut body = syn::parse::<ServerFnBody>(body.into())?;
// extract all #[middleware] attributes, removing them from signature of dummy
let mut middlewares: Vec<Middleware> = vec![];
body.attrs.retain(|attr| {
if attr.meta.path().is_ident("middleware") {
if let Ok(middleware) = attr.parse_args() {
middlewares.push(middleware);
false
} else {
true
}
} else {
true
}
});
let dummy = body.to_dummy_output();
let dummy_name = body.to_dummy_ident();
let args = syn::parse::<ServerFnArgs>(args.into())?;
// default values for args
let ServerFnArgs {
struct_name,
prefix,
encoding,
input,
output,
fn_path,
..
} = syn::parse2::<ServerFnName>(args)?;
let prefix = prefix.unwrap_or_else(|| Literal::string(""));
} = args;
let prefix = prefix.unwrap_or_else(|| Literal::string(default_path));
let fn_path = fn_path.unwrap_or_else(|| Literal::string(""));
let encoding = quote!(#server_fn_path::#encoding);
let input = input.unwrap_or_else(|| syn::parse_quote!(PostUrl));
let input_is_rkyv = input == "Rkyv";
let input_is_multipart = input == "MultipartFormData";
let input = codec_ident(server_fn_path.as_ref(), input);
let output = output.unwrap_or_else(|| syn::parse_quote!(Json));
let output = codec_ident(server_fn_path.as_ref(), output);
// default to PascalCase version of function name if no struct name given
let struct_name = struct_name.unwrap_or_else(|| {
let upper_camel_case_name = Converter::new()
.from_case(Case::Snake)
.to_case(Case::UpperCamel)
.convert(body.ident.to_string());
Ident::new(&upper_camel_case_name, body.ident.span())
});
let mut body = syn::parse::<ServerFnBody>(body.into())?;
// build struct for type
let mut body = body;
let fn_name = &body.ident;
let fn_name_as_str = body.ident.to_string();
let vis = body.vis;
@ -96,23 +101,17 @@ pub fn server_macro_impl(
let fields = body
.inputs
.iter_mut()
.filter(|f| {
if let Some(ctx) = &server_context {
!fn_arg_is_cx(f, ctx)
} else {
true
}
})
.map(|f| {
let typed_arg = match f {
FnArg::Receiver(_) => {
abort!(
f,
"cannot use receiver types in server function macro"
)
return Err(syn::Error::new(
f.span(),
"cannot use receiver types in server function macro",
))
}
FnArg::Typed(t) => t,
};
// allow #[server(default)] on fields — TODO is this documented?
let mut default = false;
let mut other_attrs = Vec::new();
for attr in typed_arg.attrs.iter() {
@ -141,87 +140,76 @@ pub fn server_macro_impl(
})
.collect::<Result<Vec<_>>>()?;
let cx_arg = body.inputs.iter().next().and_then(|f| {
server_context
.as_ref()
.and_then(|ctx| fn_arg_is_cx(f, ctx).then_some(f))
});
let cx_fn_arg = if cx_arg.is_some() {
quote! { cx, }
} else {
quote! {}
};
let fn_args = body
.inputs
.iter()
.filter_map(|f| match f {
FnArg::Receiver(_) => None,
FnArg::Typed(t) => Some(t),
})
.collect::<Vec<_>>();
let fn_args = body.inputs.iter().map(|f| {
let typed_arg = match f {
FnArg::Receiver(_) => {
abort!(f, "cannot use receiver types in server function macro")
}
FnArg::Typed(t) => t,
};
let is_cx = if let Some(ctx) = &server_context {
fn_arg_is_cx(f, ctx)
} else {
false
};
if is_cx {
let field_names = body
.inputs
.iter()
.filter_map(|f| match f {
FnArg::Receiver(_) => None,
FnArg::Typed(t) => Some(&t.pat),
})
.collect::<Vec<_>>();
// if there's exactly one field, impl From<T> for the struct
let first_field = body
.inputs
.iter()
.filter_map(|f| match f {
FnArg::Receiver(_) => None,
FnArg::Typed(t) => Some((&t.pat, &t.ty)),
})
.next();
let from_impl =
(body.inputs.len() == 1 && first_field.is_some()).then(|| {
let field = first_field.unwrap();
let (name, ty) = field;
quote! {
#[allow(unused)]
#typed_arg
}
} else {
quote! { #typed_arg }
}
});
let fn_args_2 = fn_args.clone();
let field_names = body.inputs.iter().filter_map(|f| match f {
FnArg::Receiver(_) => todo!(),
FnArg::Typed(t) => {
if let Some(ctx) = &server_context {
if fn_arg_is_cx(f, ctx) {
None
} else {
Some(&t.pat)
impl From<#struct_name> for #ty {
fn from(value: #struct_name) -> Self {
let #struct_name { #name } = value;
#name
}
}
impl From<#ty> for #struct_name {
fn from(#name: #ty) -> Self {
#struct_name { #name }
}
}
} else {
Some(&t.pat)
}
}
});
let field_names_2 = field_names.clone();
let field_names_3 = field_names.clone();
let field_names_4 = field_names.clone();
let field_names_5 = field_names.clone();
});
// check output type
let output_arrow = body.output_arrow;
let return_ty = body.return_ty;
let output_ty = 'output_ty: {
if let syn::Type::Path(pat) = &return_ty {
if pat.path.segments[0].ident == "Result" {
if let PathArguments::AngleBracketed(args) =
&pat.path.segments[0].arguments
{
break 'output_ty &args.args[0];
}
let output_ty = output_type(&return_ty)?;
let error_ty = err_type(&return_ty)?;
let error_ty =
error_ty.map(ToTokens::to_token_stream).unwrap_or_else(|| {
quote! {
#server_fn_path::error::NoCustomError
}
}
abort!(
return_ty,
"server functions should return Result<T, ServerFnError>"
);
};
let server_ctx_path = if let Some(ctx) = &server_context {
let path = &ctx.path;
quote!(#path)
} else {
quote!(())
};
});
// build server fn path
let serde_path = server_fn_path.as_ref().map(|path| {
let path = path
.segments
.iter()
.map(|segment| segment.ident.to_string())
.collect::<Vec<_>>();
let path = path.join("::");
format!("{path}::serde")
});
let server_fn_path = server_fn_path
.map(|path| quote!(#path))
.unwrap_or_else(|| quote! { server_fn });
@ -239,43 +227,49 @@ pub fn server_macro_impl(
#[doc = #link_to_server_fn]
};
// pass through docs
let docs = body
.docs
.iter()
.map(|(doc, span)| quote_spanned!(*span=> #[doc = #doc]))
.collect::<TokenStream2>();
// auto-registration with inventory
let inventory = if cfg!(feature = "ssr") {
quote! {
#server_fn_path::inventory::submit! {
#trait_obj_wrapper::from_generic_server_fn(#server_fn_path::ServerFnTraitObj::new(
#struct_name::PREFIX,
#struct_name::URL,
#struct_name::ENCODING,
<#struct_name as #server_fn_path::ServerFn<#server_ctx_path>>::call_from_bytes,
))
}
#server_fn_path::inventory::submit! {{
use #server_fn_path::ServerFn;
#server_fn_path::ServerFnTraitObj::new(
#struct_name::PATH,
|req| {
Box::pin(#struct_name::run_on_server(req))
},
#struct_name::middlewares
)
}}
}
} else {
quote! {}
};
let call_fn = if cfg!(feature = "ssr") {
// run_body in the trait implementation
let run_body = if cfg!(feature = "ssr") {
quote! {
fn call_fn(self, cx: #server_ctx_path) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Output, #server_fn_path::ServerFnError>>>> {
async fn run_body(self) -> #return_ty {
let #struct_name { #(#field_names),* } = self;
Box::pin(async move { #fn_name( #cx_fn_arg #(#field_names_2),*).await })
#dummy_name(#(#field_names),*).await
}
}
} else {
quote! {
fn call_fn_client(self, cx: #server_ctx_path) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Output, #server_fn_path::ServerFnError>>>> {
let #struct_name { #(#field_names_3),* } = self;
Box::pin(async move { #fn_name( #cx_fn_arg #(#field_names_4),*).await })
#[allow(unused_variables)]
async fn run_body(self) -> #return_ty {
unreachable!()
}
}
};
// the actual function definition
let func = if cfg!(feature = "ssr") {
quote! {
#docs
@ -289,114 +283,414 @@ pub fn server_macro_impl(
#docs
#(#attrs)*
#[allow(unused_variables)]
#vis async fn #fn_name(#(#fn_args_2),*) #output_arrow #return_ty {
#server_fn_path::call_server_fn(
&{
let prefix = #struct_name::PREFIX.to_string();
prefix + "/" + #struct_name::URL
},
#struct_name { #(#field_names_5),* },
#encoding
).await
#vis async fn #fn_name(#(#fn_args),*) #output_arrow #return_ty {
use #server_fn_path::ServerFn;
let data = #struct_name { #(#field_names),* };
data.run_on_client().await
}
}
};
// TODO rkyv derives
let derives = if input_is_multipart {
quote! {}
} else if input_is_rkyv {
todo!("implement derives for Rkyv")
} else {
quote! {
Clone, #server_fn_path::serde::Serialize, #server_fn_path::serde::Deserialize
}
};
let serde_path = (!input_is_multipart && !input_is_rkyv).then(|| {
quote! {
#[serde(crate = #serde_path)]
}
});
// TODO reqwest
let client = quote! {
#server_fn_path::client::browser::BrowserClient
};
// TODO Actix etc
let req = if !cfg!(feature = "ssr") {
quote! {
#server_fn_path::request::BrowserMockReq
}
} else if cfg!(feature = "axum") {
quote! {
::axum::http::Request<::axum::body::Body>
}
} else if cfg!(feature = "actix") {
quote! {
::actix_web::HttpRequest
}
} else {
return Err(syn::Error::new(
Span::call_site(),
"If the `ssr` feature is enabled, either the `actix` or `axum` \
features should also be enabled.",
));
};
let res = if !cfg!(feature = "ssr") {
quote! {
#server_fn_path::response::BrowserMockRes
}
} else if cfg!(feature = "axum") {
quote! {
::axum::http::Response<::axum::body::Body>
}
} else if cfg!(feature = "actix") {
quote! {
::actix_web::HttpResponse
}
} else {
return Err(syn::Error::new(
Span::call_site(),
"If the `ssr` feature is enabled, either the `actix` or `axum` \
features should also be enabled.",
));
};
// generate path
let path = quote! {
if #fn_path.is_empty() {
#server_fn_path::const_format::concatcp!(
#prefix,
"/",
#fn_name_as_str,
#server_fn_path::xxhash_rust::const_xxh64::xxh64(
concat!(env!(#key_env_var), ":", file!(), ":", line!(), ":", column!()).as_bytes(),
0
)
)
} else {
#server_fn_path::const_format::concatcp!(
#prefix,
#fn_path
)
}
};
// only emit the dummy (unmodified server-only body) for the server build
let dummy = cfg!(feature = "ssr").then_some(dummy);
let middlewares = if cfg!(feature = "ssr") {
quote! {
vec![
#(
std::sync::Arc::new(#middlewares),
),*
]
}
} else {
quote! { vec![] }
};
Ok(quote::quote! {
#args_docs
#docs
#[derive(Clone, Debug, #server_fn_path::serde::Serialize, #server_fn_path::serde::Deserialize)]
#[derive(Debug, #derives)]
#serde_path
pub struct #struct_name {
#(#fields),*
}
impl #struct_name {
const URL: &'static str = if #fn_path.is_empty() {
#server_fn_path::const_format::concatcp!(
#fn_name_as_str,
#server_fn_path::xxhash_rust::const_xxh64::xxh64(
concat!(env!(#key_env_var), ":", file!(), ":", line!(), ":", column!()).as_bytes(),
0
)
)
} else {
#fn_path
};
const PREFIX: &'static str = #prefix;
const ENCODING: #server_fn_path::Encoding = #encoding;
#from_impl
impl #server_fn_path::ServerFn for #struct_name {
// TODO prefix
const PATH: &'static str = #path;
type Client = #client;
type ServerRequest = #req;
type ServerResponse = #res;
type Output = #output_ty;
type InputEncoding = #input;
type OutputEncoding = #output;
type Error = #error_ty;
fn middlewares() -> Vec<std::sync::Arc<dyn #server_fn_path::middleware::Layer<#req, #res>>> {
#middlewares
}
#run_body
}
#inventory
impl #server_fn_path::ServerFn<#server_ctx_path> for #struct_name {
type Output = #output_ty;
fn prefix() -> &'static str {
Self::PREFIX
}
fn url() -> &'static str {
Self::URL
}
fn encoding() -> #server_fn_path::Encoding {
Self::ENCODING
}
#call_fn
}
#func
#dummy
})
}
struct ServerFnName {
_attrs: Vec<Attribute>,
struct_name: Ident,
_comma: Option<Token![,]>,
#[derive(Debug)]
struct Middleware {
expr: syn::Expr,
}
impl ToTokens for Middleware {
fn to_tokens(&self, tokens: &mut TokenStream2) {
let expr = &self.expr;
tokens.extend(quote::quote! {
#expr
});
}
}
impl Parse for Middleware {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let arg: syn::Expr = input.parse()?;
Ok(Middleware { expr: arg })
}
}
fn output_type(return_ty: &Type) -> Result<&GenericArgument> {
if let syn::Type::Path(pat) = &return_ty {
if pat.path.segments[0].ident == "Result" {
if pat.path.segments.is_empty() {
panic!("{:#?}", pat.path);
} else if let PathArguments::AngleBracketed(args) =
&pat.path.segments[0].arguments
{
return Ok(&args.args[0]);
}
}
};
Err(syn::Error::new(
return_ty.span(),
"server functions should return Result<T, ServerFnError> or Result<T, \
ServerFnError<E>>",
))
}
fn err_type(return_ty: &Type) -> Result<Option<&GenericArgument>> {
if let syn::Type::Path(pat) = &return_ty {
if pat.path.segments[0].ident == "Result" {
if let PathArguments::AngleBracketed(args) =
&pat.path.segments[0].arguments
{
// Result<T>
if args.args.len() == 1 {
return Ok(None);
}
// Result<T, _>
else if let GenericArgument::Type(Type::Path(pat)) =
&args.args[1]
{
if pat.path.segments[0].ident == "ServerFnError" {
let args = &pat.path.segments[0].arguments;
match args {
// Result<T, ServerFnError>
PathArguments::None => return Ok(None),
// Result<T, ServerFnError<E>>
PathArguments::AngleBracketed(args) => {
if args.args.len() == 1 {
return Ok(Some(&args.args[0]));
}
}
_ => {}
}
}
}
}
}
};
Err(syn::Error::new(
return_ty.span(),
"server functions should return Result<T, ServerFnError> or Result<T, \
ServerFnError<E>>",
))
}
#[derive(Debug)]
struct ServerFnArgs {
struct_name: Option<Ident>,
prefix: Option<Literal>,
_comma2: Option<Token![,]>,
encoding: Path,
_comma3: Option<Token![,]>,
input: Option<Ident>,
output: Option<Ident>,
fn_path: Option<Literal>,
}
impl Parse for ServerFnName {
fn parse(input: ParseStream) -> syn::Result<Self> {
let _attrs: Vec<Attribute> = input.call(Attribute::parse_outer)?;
let struct_name = input.parse()?;
let _comma = input.parse()?;
let prefix = input.parse()?;
let _comma2 = input.parse()?;
let encoding = input
.parse::<Literal>()
.map(|encoding| {
match encoding.to_string().to_lowercase().as_str() {
"\"url\"" => syn::parse_quote!(Encoding::Url),
"\"cbor\"" => syn::parse_quote!(Encoding::Cbor),
"\"getcbor\"" => syn::parse_quote!(Encoding::GetCBOR),
"\"getjson\"" => syn::parse_quote!(Encoding::GetJSON),
_ => abort!(encoding, "Encoding Not Found"),
impl Parse for ServerFnArgs {
fn parse(stream: ParseStream) -> syn::Result<Self> {
// legacy 4-part arguments
let mut struct_name: Option<Ident> = None;
let mut prefix: Option<Literal> = None;
let mut encoding: Option<Literal> = None;
let mut fn_path: Option<Literal> = None;
// new arguments: can only be keyed by name
let mut input: Option<Ident> = None;
let mut output: Option<Ident> = None;
let mut use_key_and_value = false;
let mut arg_pos = 0;
while !stream.is_empty() {
arg_pos += 1;
let lookahead = stream.lookahead1();
if lookahead.peek(Ident) {
let key_or_value: Ident = stream.parse()?;
let lookahead = stream.lookahead1();
if lookahead.peek(Token![=]) {
stream.parse::<Token![=]>()?;
let key = key_or_value;
use_key_and_value = true;
if key == "name" {
if struct_name.is_some() {
return Err(syn::Error::new(
key.span(),
"keyword argument repeated: `name`",
));
}
struct_name = Some(stream.parse()?);
} else if key == "prefix" {
if prefix.is_some() {
return Err(syn::Error::new(
key.span(),
"keyword argument repeated: `prefix`",
));
}
prefix = Some(stream.parse()?);
} else if key == "encoding" {
if encoding.is_some() {
return Err(syn::Error::new(
key.span(),
"keyword argument repeated: `encoding`",
));
}
encoding = Some(stream.parse()?);
} else if key == "endpoint" {
if fn_path.is_some() {
return Err(syn::Error::new(
key.span(),
"keyword argument repeated: `endpoint`",
));
}
fn_path = Some(stream.parse()?);
} else if key == "input" {
if encoding.is_some() {
return Err(syn::Error::new(
key.span(),
"`encoding` and `input` should not both be \
specified",
));
} else if input.is_some() {
return Err(syn::Error::new(
key.span(),
"keyword argument repeated: `input`",
));
}
input = Some(stream.parse()?);
} else if key == "output" {
if encoding.is_some() {
return Err(syn::Error::new(
key.span(),
"`encoding` and `output` should not both be \
specified",
));
} else if output.is_some() {
return Err(syn::Error::new(
key.span(),
"keyword argument repeated: `output`",
));
}
output = Some(stream.parse()?);
} else {
return Err(lookahead.error());
}
} else {
let value = key_or_value;
if use_key_and_value {
return Err(syn::Error::new(
value.span(),
"positional argument follows keyword argument",
));
}
if arg_pos == 1 {
struct_name = Some(value)
} else {
return Err(syn::Error::new(
value.span(),
"expected string literal",
));
}
}
})
.unwrap_or_else(|_| syn::parse_quote!(Encoding::Url));
let _comma3 = input.parse()?;
let fn_path = input.parse()?;
} else if lookahead.peek(LitStr) {
let value: Literal = stream.parse()?;
if use_key_and_value {
return Err(syn::Error::new(
value.span(),
"If you use keyword arguments (e.g., `name` = \
Something), then you can no longer use arguments \
without a keyword.",
));
}
match arg_pos {
1 => return Err(lookahead.error()),
2 => prefix = Some(value),
3 => encoding = Some(value),
4 => fn_path = Some(value),
_ => {
return Err(syn::Error::new(
value.span(),
"unexpected extra argument",
))
}
}
} else {
return Err(lookahead.error());
}
if !stream.is_empty() {
stream.parse::<Token![,]>()?;
}
}
// parse legacy encoding into input/output
if let Some(encoding) = encoding {
match encoding.to_string().to_lowercase().as_str() {
"\"url\"" => {
input = syn::parse_quote!(PostUrl);
output = syn::parse_quote!(Json);
}
"\"cbor\"" => {
input = syn::parse_quote!(Cbor);
output = syn::parse_quote!(Cbor);
}
"\"getcbor\"" => {
input = syn::parse_quote!(GetUrl);
output = syn::parse_quote!(Cbor);
}
"\"getjson\"" => {
input = syn::parse_quote!(GetUrl);
output = syn::parse_quote!(Json);
}
_ => {
return Err(syn::Error::new(
encoding.span(),
"Encoding not found.",
))
}
}
}
Ok(Self {
_attrs,
struct_name,
_comma,
prefix,
_comma2,
encoding,
_comma3,
input,
output,
fn_path,
})
}
}
#[allow(unused)]
#[derive(Debug)]
struct ServerFnBody {
pub attrs: Vec<Attribute>,
pub vis: syn::Visibility,
@ -408,11 +702,10 @@ struct ServerFnBody {
pub inputs: Punctuated<FnArg, Token![,]>,
pub output_arrow: Token![->],
pub return_ty: syn::Type,
pub block: Box<Block>,
pub block: TokenStream2,
pub docs: Vec<(String, Span)>,
}
/// The custom rusty variant of parsing rsx!
impl Parse for ServerFnBody {
fn parse(input: ParseStream) -> Result<Self> {
let mut attrs: Vec<Attribute> = input.call(Attribute::parse_outer)?;
@ -478,3 +771,57 @@ impl Parse for ServerFnBody {
})
}
}
impl ServerFnBody {
fn to_dummy_ident(&self) -> Ident {
Ident::new(&format!("__{}", self.ident), self.ident.span())
}
fn to_dummy_output(&self) -> TokenStream2 {
let ident = self.to_dummy_ident();
let Self {
attrs,
vis,
async_token,
fn_token,
generics,
inputs,
output_arrow,
return_ty,
block,
..
} = &self;
quote! {
#[doc(hidden)]
#(#attrs)*
#vis #async_token #fn_token #ident #generics ( #inputs ) #output_arrow #return_ty
#block
}
}
}
/// Returns either the path of the codec (if it's a builtin) or the
/// original ident.
fn codec_ident(server_fn_path: Option<&Path>, ident: Ident) -> TokenStream2 {
if let Some(server_fn_path) = server_fn_path {
let str = ident.to_string();
if [
"GetUrl",
"PostUrl",
"Cbor",
"Json",
"Rkyv",
"Streaming",
"StreamingText",
"MultipartFormData",
]
.contains(&str.as_str())
{
return quote! {
#server_fn_path::codec::#ident
};
}
}
ident.into_token_stream()
}

View file

@ -1,6 +1,6 @@
[package]
name = "server_fn_macro_default"
version = { workspace = true }
version = "0.1.0"
edition = "2021"
authors = ["Greg Johnston"]
license = "MIT"
@ -11,8 +11,11 @@ description = "The default implementation of the server_fn macro without a conte
proc-macro = true
[dependencies]
syn = { version = "2", features = ["full"] }
server_fn_macro = { workspace = true }
syn = { version = "2" }
server_fn_macro = { path = "../server_fn_macro" }
[features]
nightly = ["server_fn_macro/nightly"]
ssr = ["server_fn_macro/ssr"]
actix = ["server_fn_macro/actix"]
axum = ["server_fn_macro/axum"]

View file

@ -0,0 +1,19 @@
#![forbid(unsafe_code)]
use proc_macro::TokenStream;
use server_fn_macro::server_macro_impl;
use syn::__private::ToTokens;
#[proc_macro_attribute]
pub fn server(args: proc_macro::TokenStream, s: TokenStream) -> TokenStream {
match server_macro_impl(
args.into(),
s.into(),
syn::parse_quote!(server_fn::default::DefaultServerFnTraitObj),
Some(syn::parse_quote!(server_fns)),
"/api",
) {
Err(e) => e.to_compile_error().into(),
Ok(s) => s.to_token_stream().into(),
}
}