query_macros: allow Option<&str> to be passed in place of String

closes #93
This commit is contained in:
Austin Bonander 2020-01-27 19:02:46 -08:00 committed by Ryan Leckey
parent 4163388298
commit 800af574c5
13 changed files with 266 additions and 144 deletions

View file

@ -75,10 +75,6 @@ where
}
fn size_hint(&self) -> usize {
if self.is_some() {
(*self).size_hint()
} else {
0
}
self.as_ref().map_or(0, Encode::size_hint)
}
}

View file

@ -1,7 +1,7 @@
impl_database_ext! {
sqlx::postgres::Postgres {
bool,
String,
String | &str,
i16,
i32,
i64,
@ -9,7 +9,7 @@ impl_database_ext! {
f64,
// BYTEA
Vec<u8>,
Vec<u8> | &[u8],
#[cfg(feature = "uuid")]
sqlx::types::Uuid,

View file

@ -1,79 +1,96 @@
use proc_macro2::TokenStream;
use quote::{quote, quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::Expr;
use quote::{quote, quote_spanned, ToTokens};
use sqlx::describe::Describe;
use crate::database::{DatabaseExt, ParamChecking};
use crate::query_macros::QueryMacroInput;
/// Returns a tokenstream which typechecks the arguments passed to the macro
/// and binds them to `DB::Arguments` with the ident `query_args`.
pub fn quote_args<DB: DatabaseExt>(
input: &QueryMacroInput,
describe: &Describe<DB>,
) -> crate::Result<TokenStream> {
let db_path = DB::db_path();
if input.arg_names.is_empty() {
return Ok(quote! {
let args = ();
let query_args = <#db_path as sqlx::Database>::Arguments::default();
});
}
let arg_name = &input.arg_names;
let args_check = if DB::PARAM_CHECKING == ParamChecking::Strong {
let param_types = describe
describe
.param_types
.iter()
.zip(&*input.arg_exprs)
.zip(input.arg_names.iter().zip(&input.arg_exprs))
.enumerate()
.map(|(i, (type_, expr))| {
get_type_override(expr)
.map(|(i, (param_ty, (name, expr)))| -> crate::Result<_>{
let param_ty = get_type_override(expr)
.or_else(|| {
Some(
DB::param_type_for_id(type_)?
DB::param_type_for_id(param_ty)?
.parse::<proc_macro2::TokenStream>()
.unwrap(),
)
})
.ok_or_else(|| {
if let Some(feature_gate) = <DB as DatabaseExt>::get_feature_gate(&type_) {
if let Some(feature_gate) = <DB as DatabaseExt>::get_feature_gate(&param_ty) {
format!(
"optional feature `{}` required for type {} of param #{}",
feature_gate,
type_,
param_ty,
i + 1,
)
.into()
} else {
format!("unsupported type {} for param #{}", type_, i + 1).into()
format!("unsupported type {} for param #{}", param_ty, i + 1)
}
})
})
.collect::<crate::Result<Vec<_>>>()?;
})?;
let args_ty_cons = input.arg_names.iter().enumerate().map(|(i, expr)| {
// required or `quote!()` emits it as `Nusize`
let i = syn::Index::from(i);
quote_spanned!( expr.span() => {
use sqlx::ty_cons::TyConsExt as _;
sqlx::ty_cons::TyCons::new(&args.#i).ty_cons()
})
});
Ok(quote_spanned!(expr.span() =>
// this shouldn't actually run
if false {
use sqlx::ty_match::{WrapSameExt as _, MatchBorrowExt as _};
// we want to make sure it doesn't run
quote! {
if false {
let _: (#(#param_types),*,) = (#(#args_ty_cons),*,);
}
}
// evaluate the expression only once in case it contains moves
let _expr = sqlx::ty_match::dupe_value(&$#name);
// if `_expr` is `Option<T>`, get `Option<$ty>`, otherwise `$ty`
let ty_check = sqlx::ty_match::WrapSame::<#param_ty, _>::new(&_expr).wrap_same();
// if `_expr` is `&str`, convert `String` to `&str`
let (mut ty_check, match_borrow) = sqlx::ty_match::MatchBorrow::new(ty_check, &_expr);
ty_check = match_borrow.match_borrow();
// this causes move-analysis to effectively ignore this block
panic!();
}
))
})
.collect::<crate::Result<TokenStream>>()?
} else {
// all we can do is check arity which we did in `QueryMacroInput::describe_validate()`
TokenStream::new()
};
let args = input.arg_names.iter();
let args_count = input.arg_names.len();
Ok(quote! {
// emit as a tuple first so each expression is only evaluated once
let args = (#(&$#args),*,);
#args_check
// bind as a local expression, by-ref
#(let #arg_name = &$#arg_name;)*
let mut query_args = <#db_path as sqlx::Database>::Arguments::default();
query_args.reserve(
#args_count,
0 #(+ sqlx::encode::Encode::<#db_path>::size_hint(#arg_name))*
);
#(query_args.add(#arg_name);)*
})
}

View file

@ -1,7 +1,7 @@
use std::env;
use proc_macro2::{Ident, Span};
use sqlx::runtime::fs;
use quote::{format_ident, ToTokens};
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
@ -9,10 +9,9 @@ use syn::token::Group;
use syn::{Expr, ExprLit, ExprPath, Lit};
use syn::{ExprGroup, Token};
use quote::{format_ident, ToTokens};
use sqlx::connection::Connection;
use sqlx::describe::Describe;
use sqlx::runtime::fs;
/// Macro input shared by `query!()` and `query_file!()`
pub struct QueryMacroInput {

View file

@ -46,7 +46,6 @@ where
}
let args_tokens = args::quote_args(&input.query_input, &describe)?;
let arg_names = &input.query_input.arg_names;
let query_args = format_ident!("query_args");
@ -58,10 +57,7 @@ where
&columns,
);
let db_path = <C::Database as DatabaseExt>::db_path();
let args_count = arg_names.len();
let arg_indices = (0..args_count).map(|i| syn::Index::from(i));
let arg_indices_2 = arg_indices.clone();
let arg_names = &input.query_input.arg_names;
Ok(quote! {
macro_rules! macro_result {
@ -70,13 +66,6 @@ where
#args_tokens
let mut #query_args = <#db_path as sqlx::Database>::Arguments::default();
#query_args.reserve(
#args_count,
0 #(+ sqlx::encode::Encode::<#db_path>::size_hint(args.#arg_indices))*
);
#(#query_args.add(args.#arg_indices_2);)*
#output
}}
}

View file

@ -26,9 +26,6 @@ where
let args = args::quote_args(&input, &describe)?;
let arg_names = &input.arg_names;
let args_count = arg_names.len();
let arg_indices = (0..args_count).map(|i| syn::Index::from(i));
let arg_indices_2 = arg_indices.clone();
let db_path = <C::Database as DatabaseExt>::db_path();
if describe.result_columns.is_empty() {
@ -39,14 +36,6 @@ where
#args
let mut query_args = <#db_path as sqlx::Database>::Arguments::default();
query_args.reserve(
#args_count,
0 #(+ sqlx::encode::Encode::<#db_path>::size_hint(args.#arg_indices))*
);
#(query_args.add(args.#arg_indices_2);)*
sqlx::query::<#db_path>(#sql).bind_all(query_args)
}
}}
@ -85,14 +74,6 @@ where
#args
let mut #query_args = <#db_path as sqlx::Database>::Arguments::default();
#query_args.reserve(
#args_count,
0 #(+ sqlx::encode::Encode::<#db_path>::size_hint(args.#arg_indices))*
);
#(#query_args.add(args.#arg_indices_2);)*
#output
}
}}

View file

@ -40,7 +40,7 @@ mod macros;
// macro support
#[cfg(feature = "macros")]
#[doc(hidden)]
pub mod ty_cons;
pub mod ty_match;
#[cfg(feature = "macros")]
#[doc(hidden)]

View file

@ -97,9 +97,9 @@ macro_rules! query (
($query:literal, $($args:expr),*$(,)?) => ({
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query!($query, $($args),*);
$crate::sqlx_macros::query!($query, $($args)*);
}
macro_result!($($args),*)
macro_result!($($args)*)
})
);
@ -158,9 +158,9 @@ macro_rules! query_file (
($query:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)]{
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_file!($query, $($args),*);
$crate::sqlx_macros::query_file!($query, $($args)*);
}
macro_result!($($args),*)
macro_result!($($args)*)
})
);
@ -224,9 +224,9 @@ macro_rules! query_as (
($out_struct:path, $query:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)] {
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_as!($out_struct, $query, $($args),*);
$crate::sqlx_macros::query_as!($out_struct, $query, $($args)*);
}
macro_result!($($args),*)
macro_result!($($args)*)
})
);
@ -275,8 +275,8 @@ macro_rules! query_file_as (
($out_struct:path, $query:literal, $($args:expr),*$(,)?) => (#[allow(dead_code)] {
#[macro_use]
mod _macro_result {
$crate::sqlx_macros::query_file_as!($out_struct, $query, $($args),*);
$crate::sqlx_macros::query_file_as!($out_struct, $query, $($args)*);
}
macro_result!($($args),*)
macro_result!($($args)*)
})
);

View file

@ -1,59 +0,0 @@
use std::marker::PhantomData;
// These types allow the `sqlx_macros::query_[as]!()` macros to polymorphically compare a
// given parameter's type to an expected parameter type even if the former
// is behind a reference or in `Option`
#[doc(hidden)]
pub struct TyCons<T>(PhantomData<T>);
impl<T> TyCons<T> {
pub fn new(_t: &T) -> TyCons<T> {
TyCons(PhantomData)
}
}
#[doc(hidden)]
pub trait TyConsExt: Sized {
type Cons;
fn ty_cons(self) -> Self::Cons {
panic!("should not be run, only for type resolution")
}
}
impl<T> TyCons<Option<&'_ T>> {
pub fn ty_cons(self) -> T {
panic!("should not be run, only for type resolution")
}
}
// no overlap with the following impls because of the `: Sized` bound
impl<T: Sized> TyConsExt for TyCons<&'_ T> {
type Cons = T;
}
impl TyConsExt for TyCons<&'_ str> {
type Cons = String;
}
impl<T> TyConsExt for TyCons<&'_ [T]> {
type Cons = Vec<T>;
}
impl<T> TyConsExt for TyCons<Option<T>> {
type Cons = T;
}
impl<T> TyConsExt for &'_ TyCons<T> {
type Cons = T;
}
#[test]
fn test_tycons_ext() {
if false {
let _: u64 = TyCons::new(&Some(5u64)).ty_cons();
let _: u64 = TyCons::new(&Some(&5u64)).ty_cons();
let _: u64 = TyCons::new(&&5u64).ty_cons();
let _: u64 = TyCons::new(&5u64).ty_cons();
}
}

122
src/ty_match.rs Normal file
View file

@ -0,0 +1,122 @@
use std::marker::PhantomData;
// These types allow the `query!()` and friends to compare a given parameter's type to
// an expected parameter type even if the former is behind a reference or in `Option`.
// For query parameters, Postgres gives us a single type ID which we convert to an "expected" or
// preferred Rust type, but there can actually be several types that are compatible for a given type
// in input position. E.g. for an expected parameter of `String`, we want to accept `String`,
// `Option<String>`, `&str` and `Option<&str>`. And for the best compiler errors we don't just
// want an `IsCompatible` trait (at least not without `#[on_unimplemented]` which is unstable
// for the foreseeable future).
// We can do this by using autoref (for method calls, the compiler adds reference ops until
// it finds a matching impl) with impls that technically don't overlap as a hacky form of
// specialization (but this works only if all types are statically known, i.e. we're not in a
// generic context; this should suit 99% of use cases for the macros).
pub fn same_type<T>(_1: &T, _2: &T) {}
pub struct WrapSame<T, U>(PhantomData<T>, PhantomData<U>);
impl<T, U> WrapSame<T, U> {
pub fn new(_arg: &U) -> Self {
WrapSame(PhantomData, PhantomData)
}
}
pub trait WrapSameExt: Sized {
type Wrapped;
fn wrap_same(self) -> Self::Wrapped {
panic!("only for type resolution")
}
}
impl<T, U> WrapSameExt for WrapSame<T, Option<U>> {
type Wrapped = Option<T>;
}
impl<T, U> WrapSameExt for &'_ WrapSame<T, U> {
type Wrapped = T;
}
pub struct MatchBorrow<T, U>(PhantomData<T>, PhantomData<U>);
impl<T, U> MatchBorrow<T, U> {
pub fn new(t: T, _u: &U) -> (T, Self) {
(t, MatchBorrow(PhantomData, PhantomData))
}
}
pub trait MatchBorrowExt: Sized {
type Matched;
fn match_borrow(self) -> Self::Matched {
panic!("only for type resolution")
}
}
impl<'a> MatchBorrowExt for MatchBorrow<Option<&'a str>, Option<String>> {
type Matched = Option<&'a str>;
}
impl<'a> MatchBorrowExt for MatchBorrow<Option<&'a [u8]>, Option<Vec<u8>>> {
type Matched = Option<&'a [u8]>;
}
impl<'a> MatchBorrowExt for MatchBorrow<Option<&'a str>, Option<&'a String>> {
type Matched = Option<&'a str>;
}
impl<'a> MatchBorrowExt for MatchBorrow<Option<&'a [u8]>, Option<&'a Vec<u8>>> {
type Matched = Option<&'a [u8]>;
}
impl<'a> MatchBorrowExt for MatchBorrow<&'a str, String> {
type Matched = &'a str;
}
impl<'a> MatchBorrowExt for MatchBorrow<&'a [u8], Vec<u8>> {
type Matched = &'a [u8];
}
impl<T, U> MatchBorrowExt for &'_ MatchBorrow<T, U> {
type Matched = U;
}
pub fn conjure_value<T>() -> T {
panic!()
}
pub fn dupe_value<T>(_t: &T) -> T {
panic!()
}
#[test]
fn test_dupe_value() {
let ref val = (String::new(),);
if false {
let _: i32 = dupe_value(&0i32);
let _: String = dupe_value(&String::new());
let _: String = dupe_value(&val.0);
}
}
#[test]
fn test_wrap_same() {
if false {
let _: i32 = WrapSame::<i32, _>::new(&0i32).wrap_same();
let _: i32 = WrapSame::<i32, _>::new(&"hello, world!").wrap_same();
let _: Option<i32> = WrapSame::<i32, _>::new(&Some(String::new())).wrap_same();
}
}
#[test]
fn test_match_borrow() {
if false {
let (_, match_borrow) = MatchBorrow::new("", &String::new());
let _: &str = match_borrow.match_borrow();
}
}

View file

@ -54,9 +54,11 @@ struct Account {
async fn test_query_as() -> anyhow::Result<()> {
let mut conn = connect().await?;
let name: Option<&str> = None;
let account = sqlx::query_as!(
Account,
"SELECT * from (VALUES (1, null)) accounts(id, name)",
"SELECT * from (VALUES (1, $1)) accounts(id, name)",
name
)
.fetch_one(&mut conn)
.await?;
@ -114,12 +116,18 @@ async fn query_by_string() -> anyhow::Result<()> {
let mut conn = connect().await?;
let string = "Hello, world!".to_string();
let ref tuple = ("Hello, world!".to_string(),);
let result = sqlx::query!(
"SELECT * from (VALUES('Hello, world!')) strings(string)\
where string = $1 or string = $2",
string,
string[..]
where string in ($1, $2, $3, $4, $5, $6, $7)",
string, // make sure we don't actually take ownership here
&string[..],
Some(&string),
Some(&string[..]),
Option::<String>::None,
string.clone(),
tuple.0 // make sure we're not trying to move out of a field expression
)
.fetch_one(&mut conn)
.await?;

View file

@ -0,0 +1,14 @@
fn main() {
let _query = sqlx::query!("select $1::text", 0i32);
let _query = sqlx::query!("select $1::text", &0i32);
let _query = sqlx::query!("select $1::text", Some(0i32));
let arg = 0i32;
let _query = sqlx::query!("select $1::text", arg);
let arg = Some(0i32);
let _query = sqlx::query!("select $1::text", arg);
let _query = sqlx::query!("select $1::text", arg.as_ref());
}

View file

@ -0,0 +1,55 @@
error[E0308]: mismatched types
--> $DIR/wrong_param_type.rs:2:50
|
2 | let _query = sqlx::query!("select $1::text", 0i32);
| ^^^^ expected `&str`, found `i32`
|
= note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info)
error[E0308]: mismatched types
--> $DIR/wrong_param_type.rs:4:50
|
4 | let _query = sqlx::query!("select $1::text", &0i32);
| ^^^^^ expected `str`, found `i32`
|
= note: expected reference `&str`
found reference `&i32`
= note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info)
error[E0308]: mismatched types
--> $DIR/wrong_param_type.rs:6:50
|
6 | let _query = sqlx::query!("select $1::text", Some(0i32));
| ^^^^^^^^^^ expected `&str`, found `i32`
|
= note: expected enum `std::option::Option<&str>`
found enum `std::option::Option<i32>`
= note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info)
error[E0308]: mismatched types
--> $DIR/wrong_param_type.rs:9:50
|
9 | let _query = sqlx::query!("select $1::text", arg);
| ^^^ expected `&str`, found `i32`
|
= note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info)
error[E0308]: mismatched types
--> $DIR/wrong_param_type.rs:12:50
|
12 | let _query = sqlx::query!("select $1::text", arg);
| ^^^ expected `&str`, found `i32`
|
= note: expected enum `std::option::Option<&str>`
found enum `std::option::Option<i32>`
= note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info)
error[E0308]: mismatched types
--> $DIR/wrong_param_type.rs:13:50
|
13 | let _query = sqlx::query!("select $1::text", arg.as_ref());
| ^^^^^^^^^^^^ expected `str`, found `i32`
|
= note: expected enum `std::option::Option<&str>`
found enum `std::option::Option<&i32>`
= note: this error originates in a macro (in Nightly builds, run with -Z macro-backtrace for more info)