mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 14:34:19 +00:00
add parameter support to the sqlx macro
This commit is contained in:
parent
3a76f9d207
commit
b05ec7686a
2 changed files with 53 additions and 9 deletions
|
@ -2,35 +2,72 @@ extern crate proc_macro;
|
|||
|
||||
use proc_macro::TokenStream;
|
||||
|
||||
use proc_macro2::Span;
|
||||
|
||||
use quote::quote;
|
||||
|
||||
use syn::parse_macro_input;
|
||||
use syn::{parse_macro_input, Expr, ExprLit, Lit, LitStr, Token};
|
||||
use syn::punctuated::Punctuated;
|
||||
use syn::parse::{self, Parse, ParseStream};
|
||||
|
||||
use sha2::{Sha256, Digest};
|
||||
use sqlx::Postgres;
|
||||
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
use std::error::Error as _;
|
||||
|
||||
type Error = Box<dyn std::error::Error>;
|
||||
type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
struct MacroInput {
|
||||
sql: String,
|
||||
sql_span: Span,
|
||||
args: Vec<Expr>
|
||||
}
|
||||
|
||||
impl Parse for MacroInput {
|
||||
fn parse(input: ParseStream) -> parse::Result<Self> {
|
||||
let mut args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?
|
||||
.into_iter();
|
||||
|
||||
let sql = match args.next() {
|
||||
Some(Expr::Lit(ExprLit { lit: Lit::Str(sql), .. })) => sql,
|
||||
Some(other_expr) => return Err(parse::Error::new_spanned(other_expr, "expected string literal")),
|
||||
None => return Err(input.error("expected SQL string literal")),
|
||||
};
|
||||
|
||||
Ok(
|
||||
MacroInput {
|
||||
sql: sql.value(),
|
||||
sql_span: sql.span(),
|
||||
args: args.collect(),
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[proc_macro]
|
||||
pub fn sql(input: TokenStream) -> TokenStream {
|
||||
let string = parse_macro_input!(input as syn::LitStr).value();
|
||||
let input = parse_macro_input!(input as MacroInput);
|
||||
|
||||
eprintln!("expanding macro");
|
||||
|
||||
match Runtime::new().map_err(Error::from).and_then(|runtime| runtime.block_on(process_sql(&string))) {
|
||||
match Runtime::new().map_err(Error::from).and_then(|runtime| runtime.block_on(process_sql(input))) {
|
||||
Ok(ts) => ts,
|
||||
Err(e) => {
|
||||
if let Some(parse_err) = e.downcast_ref::<parse::Error>() {
|
||||
return parse_err.to_compile_error().into();
|
||||
}
|
||||
|
||||
let msg = e.to_string();
|
||||
quote! ( compile_error!(#msg) ).into()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn process_sql(sql: &str) -> Result<TokenStream> {
|
||||
let hash = dbg!(hex::encode(&Sha256::digest(sql.as_bytes())));
|
||||
async fn process_sql(input: MacroInput) -> Result<TokenStream> {
|
||||
let hash = dbg!(hex::encode(&Sha256::digest(input.sql.as_bytes())));
|
||||
|
||||
let conn = sqlx::Connection::<Postgres>::establish("postgresql://postgres@127.0.0.1/sqlx_test")
|
||||
.await
|
||||
|
@ -38,9 +75,16 @@ async fn process_sql(sql: &str) -> Result<TokenStream> {
|
|||
|
||||
eprintln!("connection established");
|
||||
|
||||
let prepared = conn.prepare(&hash, sql).await?;
|
||||
let prepared = conn.prepare(&hash, &input.sql)
|
||||
.await
|
||||
.map_err(|e| parse::Error::new(input.sql_span, e))?;
|
||||
|
||||
let msg = format!("{:?}", prepared);
|
||||
if input.args.len() != prepared.param_types.len() {
|
||||
return Err(parse::Error::new(
|
||||
Span::call_site(),
|
||||
format!("expected {} parameters, got {}", prepared.param_types.len(), input.args.len())
|
||||
).into());
|
||||
}
|
||||
|
||||
Ok(quote! { compile_error!(#msg) }.into())
|
||||
Ok(quote! { compile_error!("implementation not finished yet") }.into())
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
#![feature(proc_macro_hygiene)]
|
||||
|
||||
fn main() {
|
||||
sqlx_macros::sql!("SELECT * from accounts");
|
||||
sqlx_macros::sql!("SELECT * from accounts where id != $1", "");
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue