derives: update for new Decode/Encode traits and extensively test in usage

This commit is contained in:
Ryan Leckey 2020-03-17 19:26:59 -07:00
parent 21059620dc
commit d77b2b1e97
10 changed files with 378 additions and 525 deletions

View file

@ -92,6 +92,10 @@ required-features = [ "mysql" ]
name = "mysql-raw"
required-features = [ "mysql" ]
[[test]]
name = "mysql-derives"
required-features = [ "mysql", "macros" ]
[[test]]
name = "postgres"
required-features = [ "postgres" ]
@ -104,6 +108,10 @@ required-features = [ "postgres" ]
name = "postgres-types"
required-features = [ "postgres" ]
[[test]]
name = "postgres-derives"
required-features = [ "postgres", "macros" ]
[[test]]
name = "mysql-types"
required-features = [ "mysql" ]

View file

@ -11,33 +11,42 @@ macro_rules! assert_attribute {
};
}
pub struct SqlxAttributes {
macro_rules! fail {
($t:expr, $m:expr) => {
return Err(syn::Error::new_spanned($t, $m));
};
}
macro_rules! try_set {
($i:ident, $v:expr, $t:expr) => {
match $i {
None => $i = Some($v),
Some(_) => fail!($t, "duplicate attribute"),
}
};
}
#[derive(Copy, Clone)]
pub enum RenameAll {
LowerCase,
}
pub struct SqlxContainerAttributes {
pub transparent: bool,
pub postgres_oid: Option<u32>,
pub rename_all: Option<RenameAll>,
pub repr: Option<Ident>,
}
pub struct SqlxChildAttributes {
pub rename: Option<String>,
}
pub fn parse_attributes(input: &[Attribute]) -> syn::Result<SqlxAttributes> {
pub fn parse_container_attributes(input: &[Attribute]) -> syn::Result<SqlxContainerAttributes> {
let mut transparent = None;
let mut postgres_oid = None;
let mut repr = None;
let mut rename = None;
macro_rules! fail {
($t:expr, $m:expr) => {
return Err(syn::Error::new_spanned($t, $m));
};
}
macro_rules! try_set {
($i:ident, $v:expr, $t:expr) => {
match $i {
None => $i = Some($v),
Some(_) => fail!($t, "duplicate attribute"),
}
};
}
let mut rename_all = None;
for attr in input {
let meta = attr
@ -51,11 +60,21 @@ pub fn parse_attributes(input: &[Attribute]) -> syn::Result<SqlxAttributes> {
Meta::Path(p) if p.is_ident("transparent") => {
try_set!(transparent, true, value)
}
Meta::NameValue(MetaNameValue {
path,
lit: Lit::Str(val),
..
}) if path.is_ident("rename") => try_set!(rename, val.value(), value),
}) if path.is_ident("rename_all") => {
let val = match &*val.value() {
"lowercase" => RenameAll::LowerCase,
_ => fail!(meta, "unexpected value for rename_all"),
};
try_set!(rename_all, val, value)
},
Meta::List(list) if list.path.is_ident("postgres") => {
for value in list.nested.iter() {
match value {
@ -92,85 +111,93 @@ pub fn parse_attributes(input: &[Attribute]) -> syn::Result<SqlxAttributes> {
}
}
Ok(SqlxAttributes {
Ok(SqlxContainerAttributes {
transparent: transparent.unwrap_or(false),
postgres_oid,
repr,
rename_all,
})
}
pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttributes> {
let mut rename = None;
for attr in input {
let meta = attr
.parse_meta()
.map_err(|e| syn::Error::new_spanned(attr, e))?;
match meta {
Meta::List(list) if list.path.is_ident("sqlx") => {
for value in list.nested.iter() {
match value {
NestedMeta::Meta(meta) => match meta {
Meta::NameValue(MetaNameValue {
path,
lit: Lit::Str(val),
..
}) if path.is_ident("rename") => try_set!(rename, val.value(), value),
u => fail!(u, "unexpected attribute"),
},
u => fail!(u, "unexpected attribute"),
}
}
}
_ => {}
}
}
Ok(SqlxChildAttributes {
rename,
})
}
pub fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> {
let attributes = parse_attributes(&input.attrs)?;
let attributes = parse_container_attributes(&input.attrs)?;
assert_attribute!(
attributes.transparent,
"expected #[sqlx(transparent)]",
input
);
#[cfg(feature = "postgres")]
assert_attribute!(
attributes.postgres_oid.is_none(),
"unexpected #[sqlx(postgres(oid = ..))]",
input
);
assert_attribute!(
attributes.rename.is_none(),
"unexpected #[sqlx(rename = ..)]",
attributes.rename_all.is_none(),
"unexpected #[sqlx(rename_all = ..)]",
field
);
assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input);
let attributes = parse_attributes(&field.attrs)?;
assert_attribute!(
!attributes.transparent,
"unexpected #[sqlx(transparent)]",
field
);
#[cfg(feature = "postgres")]
assert_attribute!(
attributes.postgres_oid.is_none(),
"unexpected #[sqlx(postgres(oid = ..))]",
field
);
let attributes = parse_child_attributes(&field.attrs)?;
assert_attribute!(
attributes.rename.is_none(),
"unexpected #[sqlx(rename = ..)]",
field
);
assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field);
Ok(())
}
pub fn check_enum_attributes<'a>(
input: &'a DeriveInput,
variants: &Punctuated<Variant, Comma>,
) -> syn::Result<SqlxAttributes> {
let attributes = parse_attributes(&input.attrs)?;
) -> syn::Result<SqlxContainerAttributes> {
let attributes = parse_container_attributes(&input.attrs)?;
assert_attribute!(
!attributes.transparent,
"unexpected #[sqlx(transparent)]",
input
);
assert_attribute!(
attributes.rename.is_none(),
"unexpected #[sqlx(rename = ..)]",
input
);
for variant in variants {
let attributes = parse_attributes(&variant.attrs)?;
assert_attribute!(
!attributes.transparent,
"unexpected #[sqlx(transparent)]",
variant
);
#[cfg(feature = "postgres")]
assert_attribute!(
attributes.postgres_oid.is_none(),
"unexpected #[sqlx(postgres(oid = ..))]",
variant
);
assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", variant);
}
Ok(attributes)
}
@ -178,83 +205,90 @@ pub fn check_enum_attributes<'a>(
pub fn check_weak_enum_attributes(
input: &DeriveInput,
variants: &Punctuated<Variant, Comma>,
) -> syn::Result<Ident> {
let attributes = check_enum_attributes(input, variants)?;
) -> syn::Result<SqlxContainerAttributes> {
let attributes = check_enum_attributes(input)?;
#[cfg(feature = "postgres")]
assert_attribute!(
attributes.postgres_oid.is_none(),
"unexpected #[sqlx(postgres(oid = ..))]",
input
);
assert_attribute!(attributes.repr.is_some(), "expected #[repr(..)]", input);
assert_attribute!(
attributes.rename_all.is_none(),
"unexpected #[sqlx(c = ..)]",
input
);
for variant in variants {
let attributes = parse_attributes(&variant.attrs)?;
let attributes = parse_child_attributes(&variant.attrs)?;
assert_attribute!(
attributes.rename.is_none(),
"unexpected #[sqlx(rename = ..)]",
variant
);
}
Ok(attributes.repr.unwrap())
Ok(attributes)
}
pub fn check_strong_enum_attributes(
input: &DeriveInput,
variants: &Punctuated<Variant, Comma>,
) -> syn::Result<SqlxAttributes> {
let attributes = check_enum_attributes(input, variants)?;
_variants: &Punctuated<Variant, Comma>,
) -> syn::Result<SqlxContainerAttributes> {
let attributes = check_enum_attributes(input)?;
#[cfg(feature = "postgres")]
assert_attribute!(
attributes.postgres_oid.is_some(),
"expected #[sqlx(postgres(oid = ..))]",
input
);
assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input);
Ok(attributes)
}
pub fn check_struct_attributes<'a>(
input: &'a DeriveInput,
fields: &Punctuated<Field, Comma>,
) -> syn::Result<SqlxAttributes> {
let attributes = parse_attributes(&input.attrs)?;
) -> syn::Result<SqlxContainerAttributes> {
let attributes = parse_container_attributes(&input.attrs)?;
assert_attribute!(
!attributes.transparent,
"unexpected #[sqlx(transparent)]",
input
);
#[cfg(feature = "postgres")]
assert_attribute!(
attributes.postgres_oid.is_some(),
"expected #[sqlx(postgres(oid = ..))]",
input
);
assert_attribute!(
attributes.rename.is_none(),
"unexpected #[sqlx(rename = ..)]",
attributes.rename_all.is_none(),
"unexpected #[sqlx(rename_all = ..)]",
input
);
assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input);
for field in fields {
let attributes = parse_attributes(&field.attrs)?;
assert_attribute!(
!attributes.transparent,
"unexpected #[sqlx(transparent)]",
field
);
#[cfg(feature = "postgres")]
assert_attribute!(
attributes.postgres_oid.is_none(),
"unexpected #[sqlx(postgres(oid = ..))]",
field
);
let attributes = parse_child_attributes(&field.attrs)?;
assert_attribute!(
attributes.rename.is_none(),
"unexpected #[sqlx(rename = ..)]",
field
);
assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", field);
}
Ok(attributes)

View file

@ -1,7 +1,9 @@
use super::attributes::{
check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes,
check_weak_enum_attributes, parse_attributes,
check_weak_enum_attributes, parse_container_attributes,
parse_child_attributes,
};
use super::rename_all;
use quote::quote;
use syn::punctuated::Punctuated;
use syn::token::Comma;
@ -11,7 +13,7 @@ use syn::{
};
pub fn expand_derive_decode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let attrs = parse_attributes(&input.attrs)?;
let attrs = parse_container_attributes(&input.attrs)?;
match &input.data {
Data::Struct(DataStruct {
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
@ -83,24 +85,29 @@ fn expand_derive_decode_weak_enum(
input: &DeriveInput,
variants: &Punctuated<Variant, Comma>,
) -> syn::Result<proc_macro2::TokenStream> {
let repr = check_weak_enum_attributes(input, &variants)?;
let attr = check_weak_enum_attributes(input, &variants)?;
let repr = attr.repr.unwrap();
let ident = &input.ident;
let ident_s = ident.to_string();
let arms = variants
.iter()
.map(|v| {
let id = &v.ident;
parse_quote!(_ if (#ident :: #id as #repr) == val => Ok(#ident :: #id),)
parse_quote!(_ if (#ident :: #id as #repr) == value => Ok(#ident :: #id),)
})
.collect::<Vec<Arm>>();
Ok(quote!(
impl<DB: sqlx::Database> sqlx::decode::Decode<DB> for #ident where #repr: sqlx::decode::Decode<DB> {
fn decode(raw: &[u8]) -> std::result::Result<Self, sqlx::decode::DecodeError> {
let val = <#repr as sqlx::decode::Decode<DB>>::decode(raw)?;
match val {
impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where #repr: sqlx::decode::Decode<'de, DB> {
fn decode(value: <DB as sqlx::database::HasRawValue<'de>>::RawValue) -> sqlx::Result<Self> {
let value = <#repr as sqlx::decode::Decode<'de, DB>>::decode(value)?;
match value {
#(#arms)*
_ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value")))
_ => Err(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into()))
}
}
}
@ -111,29 +118,35 @@ fn expand_derive_decode_strong_enum(
input: &DeriveInput,
variants: &Punctuated<Variant, Comma>,
) -> syn::Result<proc_macro2::TokenStream> {
check_strong_enum_attributes(input, &variants)?;
let cattr = check_strong_enum_attributes(input, &variants)?;
let ident = &input.ident;
let ident_s = ident.to_string();
let value_arms = variants.iter().map(|v| -> Arm {
let id = &v.ident;
let attributes = parse_attributes(&v.attrs).unwrap();
let attributes = parse_child_attributes(&v.attrs).unwrap();
if let Some(rename) = attributes.rename {
parse_quote!(#rename => Ok(#ident :: #id),)
} else if let Some(pattern) = cattr.rename_all {
let name = rename_all(&*id.to_string(), pattern);
parse_quote!(#name => Ok(#ident :: #id),)
} else {
let name = id.to_string();
parse_quote!(#name => Ok(#ident :: #id),)
}
});
// TODO: prevent heap allocation
Ok(quote!(
impl<DB: sqlx::Database> sqlx::decode::Decode<DB> for #ident where std::string::String: sqlx::decode::Decode<DB> {
fn decode(buf: &[u8]) -> std::result::Result<Self, sqlx::decode::DecodeError> {
let val = <String as sqlx::decode::Decode<DB>>::decode(buf)?;
match val.as_str() {
impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where &'de str: sqlx::decode::Decode<'de, DB> {
fn decode(value: <DB as sqlx::database::HasRawValue<'de>>::RawValue) -> sqlx::Result<Self> {
let value = <&'de str as sqlx::decode::Decode<'de, DB>>::decode(value)?;
match value {
#(#value_arms)*
_ => Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid value")))
_ => Err(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into()))
}
}
}
@ -151,57 +164,50 @@ fn expand_derive_decode_struct(
if cfg!(feature = "postgres") {
let ident = &input.ident;
let column_count = fields.len();
// extract type generics
let generics = &input.generics;
let (_, ty_generics, _) = generics.split_for_impl();
// add db type for impl generics & where clause
let mut generics = generics.clone();
generics.params.insert(0, parse_quote!('de));
let predicates = &mut generics.make_where_clause().predicates;
for field in fields {
let ty = &field.ty;
predicates.push(parse_quote!(#ty: sqlx::decode::Decode<sqlx::Postgres>));
predicates.push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::Postgres>));
predicates.push(parse_quote!(#ty: sqlx::types::Type<sqlx::Postgres>));
}
let (impl_generics, _, where_clause) = generics.split_for_impl();
let reads = fields.iter().map(|field| -> Stmt {
let id = &field.ident;
let ty = &field.ty;
parse_quote!(
let #id = sqlx::postgres::decode_struct_field::<#ty>(&mut buf)?;
let #id = decoder.decode::<#ty>()?;
)
});
let names = fields.iter().map(|field| &field.ident);
tts.extend(quote!(
impl #impl_generics sqlx::decode::Decode<sqlx::Postgres> for #ident#ty_generics #where_clause {
fn decode(buf: &[u8]) -> std::result::Result<Self, sqlx::decode::DecodeError> {
if buf.len() < 4 {
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent")));
impl #impl_generics sqlx::decode::Decode<'de, sqlx::Postgres> for #ident #ty_generics #where_clause {
fn decode(value: <sqlx::Postgres as sqlx::database::HasRawValue<'de>>::RawValue) -> sqlx::Result<Self> {
let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?;
#(#reads)*
Ok(#ident {
#(#names),*
})
}
let column_count = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[..4]).unwrap()) as usize;
if column_count != #column_count {
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count")));
}
let mut buf = &buf[4..];
#(#reads)*
if !buf.is_empty() {
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new(format!("Too much data sent ({} bytes left)", buf.len()))));
}
Ok(#ident {
#(#names),*
})
}
}
))
));
}
Ok(tts)
}

View file

@ -1,7 +1,8 @@
use super::attributes::{
check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes,
check_weak_enum_attributes, parse_attributes,
check_weak_enum_attributes, parse_container_attributes, parse_child_attributes,
};
use super::rename_all;
use quote::quote;
use syn::punctuated::Punctuated;
use syn::token::Comma;
@ -11,7 +12,7 @@ use syn::{
};
pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let args = parse_attributes(&input.attrs)?;
let args = parse_container_attributes(&input.attrs)?;
match &input.data {
Data::Struct(DataStruct {
@ -87,18 +88,21 @@ fn expand_derive_encode_weak_enum(
input: &DeriveInput,
variants: &Punctuated<Variant, Comma>,
) -> syn::Result<proc_macro2::TokenStream> {
let repr = check_weak_enum_attributes(input, &variants)?;
let attr = check_weak_enum_attributes(input, &variants)?;
let repr = attr.repr.unwrap();
let ident = &input.ident;
Ok(quote!(
impl<DB: sqlx::Database> sqlx::encode::Encode<DB> for #ident where #repr: sqlx::encode::Encode<DB> {
fn encode(&self, buf: &mut std::vec::Vec<u8>) {
fn encode(&self, buf: &mut DB::RawBuffer) {
sqlx::encode::Encode::encode(&(*self as #repr), buf)
}
fn encode_nullable(&self, buf: &mut std::vec::Vec<u8>) -> sqlx::encode::IsNull {
fn encode_nullable(&self, buf: &mut DB::RawBuffer) -> sqlx::encode::IsNull {
sqlx::encode::Encode::encode_nullable(&(*self as #repr), buf)
}
fn size_hint(&self) -> usize {
sqlx::encode::Encode::size_hint(&(*self as #repr))
}
@ -110,16 +114,21 @@ fn expand_derive_encode_strong_enum(
input: &DeriveInput,
variants: &Punctuated<Variant, Comma>,
) -> syn::Result<proc_macro2::TokenStream> {
check_strong_enum_attributes(input, &variants)?;
let cattr = check_strong_enum_attributes(input, &variants)?;
let ident = &input.ident;
let mut value_arms = Vec::new();
for v in variants {
let id = &v.ident;
let attributes = parse_attributes(&v.attrs)?;
let attributes = parse_child_attributes(&v.attrs)?;
if let Some(rename) = attributes.rename {
value_arms.push(quote!(#ident :: #id => #rename,));
} else if let Some(pattern) = cattr.rename_all {
let name = rename_all(&*id.to_string(), pattern);
value_arms.push(quote!(#ident :: #id => #name,));
} else {
let name = id.to_string();
value_arms.push(quote!(#ident :: #id => #name,));
@ -128,12 +137,13 @@ fn expand_derive_encode_strong_enum(
Ok(quote!(
impl<DB: sqlx::Database> sqlx::encode::Encode<DB> for #ident where str: sqlx::encode::Encode<DB> {
fn encode(&self, buf: &mut std::vec::Vec<u8>) {
fn encode(&self, buf: &mut DB::RawBuffer) {
let val = match self {
#(#value_arms)*
};
<str as sqlx::encode::Encode<DB>>::encode(val, buf)
}
fn size_hint(&self) -> usize {
let val = match self {
#(#value_arms)*
@ -154,7 +164,6 @@ fn expand_derive_encode_struct(
if cfg!(feature = "postgres") {
let ident = &input.ident;
let column_count = fields.len();
// extract type generics
@ -164,23 +173,29 @@ fn expand_derive_encode_struct(
// add db type for impl generics & where clause
let mut generics = generics.clone();
let predicates = &mut generics.make_where_clause().predicates;
for field in fields {
let ty = &field.ty;
predicates.push(parse_quote!(#ty: sqlx::encode::Encode<sqlx::Postgres>));
predicates.push(parse_quote!(sqlx::Postgres: sqlx::types::HasSqlType<#ty>));
predicates.push(parse_quote!(#ty: sqlx::types::Type<sqlx::Postgres>));
}
let (impl_generics, _, where_clause) = generics.split_for_impl();
let writes = fields.iter().map(|field| -> Stmt {
let id = &field.ident;
parse_quote!(
sqlx::postgres::encode_struct_field(buf, &self. #id);
// sqlx::postgres::encode_struct_field(buf, &self. #id);
encoder.encode(&self. #id);
)
});
let sizes = fields.iter().map(|field| -> Expr {
let id = &field.ident;
let ty = &field.ty;
parse_quote!(
<#ty as sqlx::encode::Encode<sqlx::Postgres>>::size_hint(&self. #id)
)
@ -189,13 +204,16 @@ fn expand_derive_encode_struct(
tts.extend(quote!(
impl #impl_generics sqlx::encode::Encode<sqlx::Postgres> for #ident #ty_generics #where_clause {
fn encode(&self, buf: &mut std::vec::Vec<u8>) {
buf.extend(&(#column_count as u32).to_be_bytes());
let mut encoder = sqlx::postgres::types::PgRecordEncoder::new(buf);
#(#writes)*
encoder.finish();
}
fn size_hint(&self) -> usize {
4 // oid
+ #column_count * (4 + 4) // oid (int) and length (int) for each column
+ #(#sizes)+* // sum of the size hints for each column
#column_count * (4 + 4) // oid (int) and length (int) for each column
+ #(#sizes)+* // sum of the size hints for each column
}
}
));

View file

@ -7,6 +7,7 @@ pub(crate) use decode::expand_derive_decode;
pub(crate) use encode::expand_derive_encode;
pub(crate) use r#type::expand_derive_type;
use self::attributes::RenameAll;
use std::iter::FromIterator;
use syn::DeriveInput;
@ -23,3 +24,11 @@ pub(crate) fn expand_derive_type_encode_decode(
Ok(combined)
}
pub(crate) fn rename_all(s: &str, pattern: RenameAll) -> String {
match pattern {
RenameAll::LowerCase => {
s.to_lowercase()
}
}
}

View file

@ -1,6 +1,6 @@
use super::attributes::{
check_strong_enum_attributes, check_struct_attributes, check_transparent_attributes,
check_weak_enum_attributes, parse_attributes,
check_weak_enum_attributes, parse_container_attributes,
};
use quote::quote;
use syn::punctuated::Punctuated;
@ -11,7 +11,7 @@ use syn::{
};
pub fn expand_derive_type(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
let attrs = parse_attributes(&input.attrs)?;
let attrs = parse_container_attributes(&input.attrs)?;
match &input.data {
Data::Struct(DataStruct {
fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
@ -65,64 +65,36 @@ fn expand_derive_has_sql_type_transparent(
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::types::Type<DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
let mut tts = proc_macro2::TokenStream::new();
// if cfg!(feature = "mysql") {
tts.extend(quote!(
Ok(quote!(
impl #impl_generics sqlx::types::Type< DB > for #ident #ty_generics #where_clause {
fn type_info() -> DB::TypeInfo {
<#ty as sqlx::Type<DB>>::type_info()
}
}
));
// }
// if cfg!(feature = "postgres") {
// tts.extend(quote!(
// impl #impl_generics sqlx::types::HasSqlType< sqlx::Postgres > #ident #ty_generics #where_clause {
// fn type_info() -> Self::TypeInfo {
// <Self as HasSqlType<#ty>>::type_info()
// }
// }
// ));
// }
Ok(tts)
))
}
fn expand_derive_has_sql_type_weak_enum(
input: &DeriveInput,
variants: &Punctuated<Variant, Comma>,
) -> syn::Result<proc_macro2::TokenStream> {
let repr = check_weak_enum_attributes(input, variants)?;
let attr = check_weak_enum_attributes(input, variants)?;
let repr = attr.repr.unwrap();
let ident = &input.ident;
let mut tts = proc_macro2::TokenStream::new();
if cfg!(feature = "mysql") {
tts.extend(quote!(
impl sqlx::types::HasSqlType< #ident > for sqlx::MySql where Self: sqlx::types::HasSqlType< #repr > {
fn type_info() -> Self::TypeInfo {
<Self as HasSqlType<#repr>>::type_info()
}
Ok(quote!(
impl<DB: sqlx::Database> sqlx::Type<DB> for #ident
where
#repr: sqlx::Type<DB>,
{
fn type_info() -> DB::TypeInfo {
<#repr as sqlx::Type<DB>>::type_info()
}
));
}
if cfg!(feature = "postgres") {
tts.extend(quote!(
impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres where Self: sqlx::types::HasSqlType< #repr > {
fn type_info() -> Self::TypeInfo {
<Self as HasSqlType<#repr>>::type_info()
}
}
));
}
Ok(tts)
}
))
}
fn expand_derive_has_sql_type_strong_enum(
@ -136,9 +108,11 @@ fn expand_derive_has_sql_type_strong_enum(
if cfg!(feature = "mysql") {
tts.extend(quote!(
impl sqlx::types::HasSqlType< #ident > for sqlx::MySql {
fn type_info() -> Self::TypeInfo {
sqlx::mysql::MySqlTypeInfo::r#enum()
impl sqlx::Type< sqlx::MySql > for #ident {
fn type_info() -> sqlx::mysql::MySqlTypeInfo {
// This is really fine, MySQL is loosely typed and
// we don't nede to be specific here
<str as sqlx::Type<sqlx::MySql>>::type_info()
}
}
));
@ -147,8 +121,8 @@ fn expand_derive_has_sql_type_strong_enum(
if cfg!(feature = "postgres") {
let oid = attributes.postgres_oid.unwrap();
tts.extend(quote!(
impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres {
fn type_info() -> Self::TypeInfo {
impl sqlx::Type< sqlx::Postgres > for #ident {
fn type_info() -> sqlx::postgres::PgTypeInfo {
sqlx::postgres::PgTypeInfo::with_oid(#oid)
}
}
@ -170,8 +144,8 @@ fn expand_derive_has_sql_type_struct(
if cfg!(feature = "postgres") {
let oid = attributes.postgres_oid.unwrap();
tts.extend(quote!(
impl sqlx::types::HasSqlType< #ident > for sqlx::Postgres {
fn type_info() -> Self::TypeInfo {
impl sqlx::types::Type< sqlx::Postgres > for #ident {
fn type_info() -> sqlx::postgres::PgTypeInfo {
sqlx::postgres::PgTypeInfo::with_oid(#oid)
}
}

View file

@ -1,318 +0,0 @@
use sqlx::decode::Decode;
use sqlx::encode::Encode;
use sqlx::types::TypeInfo;
use sqlx::Type;
use std::fmt::Debug;
#[derive(PartialEq, Debug, Type)]
#[sqlx(transparent)]
struct Transparent(i32);
// #[derive(PartialEq, Debug, Clone, Copy, Encode, Decode, HasSqlType)]
// #[repr(i32)]
// #[allow(dead_code)]
// enum Weak {
// One,
// Two,
// Three,
// }
//
// #[derive(PartialEq, Debug, Encode, Decode, HasSqlType)]
// #[sqlx(postgres(oid = 10101010))]
// #[allow(dead_code)]
// enum Strong {
// One,
// Two,
// #[sqlx(rename = "four")]
// Three,
// }
//
// #[derive(PartialEq, Debug, Encode, Decode, HasSqlType)]
// #[sqlx(postgres(oid = 20202020))]
// #[allow(dead_code)]
// struct Struct {
// field1: String,
// field2: i64,
// field3: bool,
// }
#[test]
#[cfg(feature = "mysql")]
fn encode_transparent_mysql() {
encode_transparent::<sqlx::MySql>();
}
#[test]
#[cfg(feature = "postgres")]
fn encode_transparent_postgres() {
encode_transparent::<sqlx::Postgres>();
}
#[allow(dead_code)]
fn encode_transparent<DB: sqlx::Database<RawBuffer = Vec<u8>>>()
where
Transparent: Encode<DB>,
i32: Encode<DB>,
{
let example = Transparent(0x1122_3344);
let mut encoded = Vec::new();
let mut encoded_orig = Vec::new();
Encode::<DB>::encode(&example, &mut encoded);
Encode::<DB>::encode(&example.0, &mut encoded_orig);
assert_eq!(encoded, encoded_orig);
}
//
// #[test]
// #[cfg(feature = "mysql")]
// fn encode_weak_enum_mysql() {
// encode_weak_enum::<sqlx::MySql>();
// }
//
// #[test]
// #[cfg(feature = "postgres")]
// fn encode_weak_enum_postgres() {
// encode_weak_enum::<sqlx::Postgres>();
// }
//
// #[allow(dead_code)]
// fn encode_weak_enum<DB: sqlx::Database<RawBuffer = Vec<u8>>>()
// where
// Weak: Encode<DB>,
// i32: Encode<DB>,
// {
// for example in [Weak::One, Weak::Two, Weak::Three].iter() {
// let mut encoded = Vec::new();
// let mut encoded_orig = Vec::new();
//
// Encode::<DB>::encode(example, &mut encoded);
// Encode::<DB>::encode(&(*example as i32), &mut encoded_orig);
//
// assert_eq!(encoded, encoded_orig);
// }
// }
//
// #[test]
// #[cfg(feature = "mysql")]
// fn encode_strong_enum_mysql() {
// encode_strong_enum::<sqlx::MySql>();
// }
//
// #[test]
// #[cfg(feature = "postgres")]
// fn encode_strong_enum_postgres() {
// encode_strong_enum::<sqlx::Postgres>();
// }
//
// #[allow(dead_code)]
// fn encode_strong_enum<DB: sqlx::Database<RawBuffer = Vec<u8>>>()
// where
// Strong: Encode<DB>,
// str: Encode<DB>,
// {
// for (example, name) in [
// (Strong::One, "One"),
// (Strong::Two, "Two"),
// (Strong::Three, "four"),
// ]
// .iter()
// {
// let mut encoded = Vec::new();
// let mut encoded_orig = Vec::new();
//
// Encode::<DB>::encode(example, &mut encoded);
// Encode::<DB>::encode(*name, &mut encoded_orig);
//
// assert_eq!(encoded, encoded_orig);
// }
// }
//
// #[test]
// #[cfg(feature = "postgres")]
// fn encode_struct_postgres() {
// let field1 = "Foo".to_string();
// let field2 = 3;
// let field3 = false;
//
// let example = Struct {
// field1: field1.clone(),
// field2,
// field3,
// };
//
// let mut encoded = Vec::new();
// Encode::<sqlx::Postgres>::encode(&example, &mut encoded);
//
// let string_oid = <sqlx::Postgres as HasSqlType<String>>::type_info().oid();
// let i64_oid = <sqlx::Postgres as HasSqlType<i64>>::type_info().oid();
// let bool_oid = <sqlx::Postgres as HasSqlType<bool>>::type_info().oid();
//
// // 3 columns
// assert_eq!(&[0, 0, 0, 3], &encoded[..4]);
// let encoded = &encoded[4..];
//
// // check field1 (string)
// assert_eq!(&string_oid.to_be_bytes(), &encoded[0..4]);
// assert_eq!(&(field1.len() as u32).to_be_bytes(), &encoded[4..8]);
// assert_eq!(field1.as_bytes(), &encoded[8..8 + field1.len()]);
// let encoded = &encoded[8 + field1.len()..];
//
// // check field2 (i64)
// assert_eq!(&i64_oid.to_be_bytes(), &encoded[0..4]);
// assert_eq!(&8u32.to_be_bytes(), &encoded[4..8]);
// assert_eq!(field2.to_be_bytes(), &encoded[8..16]);
// let encoded = &encoded[16..];
//
// // check field3 (bool)
// assert_eq!(&bool_oid.to_be_bytes(), &encoded[0..4]);
// assert_eq!(&1u32.to_be_bytes(), &encoded[4..8]);
// assert_eq!(field3, encoded[8] != 0);
// let encoded = &encoded[9..];
//
// assert!(encoded.is_empty());
//
// let string_size = <String as Encode<sqlx::Postgres>>::size_hint(&field1);
// let i64_size = <i64 as Encode<sqlx::Postgres>>::size_hint(&field2);
// let bool_size = <bool as Encode<sqlx::Postgres>>::size_hint(&field3);
//
// assert_eq!(
// 4 + 3 * (4 + 4) + string_size + i64_size + bool_size,
// example.size_hint()
// );
// }
#[test]
#[cfg(feature = "mysql")]
fn decode_transparent_mysql() {
decode_with_db::<sqlx::MySql, Transparent>(Transparent(0x1122_3344));
}
#[test]
#[cfg(feature = "postgres")]
fn decode_transparent_postgres() {
decode_with_db::<sqlx::Postgres, Transparent>(Transparent(0x1122_3344));
}
//
// #[test]
// #[cfg(feature = "mysql")]
// fn decode_weak_enum_mysql() {
// decode_with_db::<sqlx::MySql, Weak>(Weak::One);
// decode_with_db::<sqlx::MySql, Weak>(Weak::Two);
// decode_with_db::<sqlx::MySql, Weak>(Weak::Three);
// }
//
// #[test]
// #[cfg(feature = "postgres")]
// fn decode_weak_enum_postgres() {
// decode_with_db::<sqlx::Postgres, Weak>(Weak::One);
// decode_with_db::<sqlx::Postgres, Weak>(Weak::Two);
// decode_with_db::<sqlx::Postgres, Weak>(Weak::Three);
// }
//
// #[test]
// #[cfg(feature = "mysql")]
// fn decode_strong_enum_mysql() {
// decode_with_db::<sqlx::MySql, Strong>(Strong::One);
// decode_with_db::<sqlx::MySql, Strong>(Strong::Two);
// decode_with_db::<sqlx::MySql, Strong>(Strong::Three);
// }
//
// #[test]
// #[cfg(feature = "postgres")]
// fn decode_strong_enum_postgres() {
// decode_with_db::<sqlx::Postgres, Strong>(Strong::One);
// decode_with_db::<sqlx::Postgres, Strong>(Strong::Two);
// decode_with_db::<sqlx::Postgres, Strong>(Strong::Three);
// }
//
// #[test]
// #[cfg(feature = "postgres")]
// fn decode_struct_postgres() {
// decode_with_db::<sqlx::Postgres, Struct>(Struct {
// field1: "Foo".to_string(),
// field2: 3,
// field3: true,
// });
// }
//
#[allow(dead_code)]
fn decode_with_db<
DB: sqlx::Database<RawBuffer = Vec<u8>>,
V: for<'de> Decode<'de, DB> + Encode<DB> + PartialEq + Debug,
>(
example: V,
) {
let mut encoded = Vec::new();
Encode::<DB>::encode(&example, &mut encoded);
// let decoded = V::decode(&encoded).unwrap();
// assert_eq!(example, decoded);
}
#[test]
#[cfg(feature = "mysql")]
fn type_transparent_mysql() {
type_transparent::<sqlx::MySql>();
}
#[test]
#[cfg(feature = "postgres")]
fn type_transparent_postgres() {
type_transparent::<sqlx::Postgres>();
}
#[allow(dead_code)]
fn type_transparent<DB: sqlx::Database<RawBuffer = Vec<u8>>>()
where
Transparent: Type<DB>,
i32: Type<DB>,
{
let info: DB::TypeInfo = <Transparent as Type<DB>>::type_info();
let info_orig: DB::TypeInfo = <i32 as Type<DB>>::type_info();
assert!(info.compatible(&info_orig));
}
//
// #[test]
// #[cfg(feature = "mysql")]
// fn type_weak_enum_mysql() {
// type_weak_enum::<sqlx::MySql>();
// }
//
// #[test]
// #[cfg(feature = "postgres")]
// fn type_weak_enum_postgres() {
// type_weak_enum::<sqlx::Postgres>();
// }
//
// #[allow(dead_code)]
// fn type_weak_enum<DB: sqlx::Database<RawBuffer = Vec<u8>>>()
// where
// DB: HasSqlType<Weak> + HasSqlType<i32>,
// {
// let info: DB::TypeInfo = <DB as HasSqlType<Weak>>::type_info();
// let info_orig: DB::TypeInfo = <DB as HasSqlType<i32>>::type_info();
// assert!(info.compatible(&info_orig));
// }
//
// #[test]
// #[cfg(feature = "mysql")]
// fn type_strong_enum_mysql() {
// let info: sqlx::mysql::MySqlTypeInfo = <sqlx::MySql as HasSqlType<Strong>>::type_info();
// assert!(info.compatible(&sqlx::mysql::MySqlTypeInfo::r#enum()))
// }
//
// #[test]
// #[cfg(feature = "postgres")]
// fn type_strong_enum_postgres() {
// let info: sqlx::postgres::PgTypeInfo = <sqlx::Postgres as HasSqlType<Strong>>::type_info();
// assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(10101010)))
// }
//
// #[test]
// #[cfg(feature = "postgres")]
// fn type_struct_postgres() {
// let info: sqlx::postgres::PgTypeInfo = <sqlx::Postgres as HasSqlType<Struct>>::type_info();
// assert!(info.compatible(&sqlx::postgres::PgTypeInfo::with_oid(20202020)))
// }

47
tests/mysql-derives.rs Normal file
View file

@ -0,0 +1,47 @@
use sqlx_test::test_type;
use std::fmt::Debug;
use sqlx::MySql;
// Transparent types are rust-side wrappers over DB types
#[derive(PartialEq, Debug, sqlx::Type)]
#[sqlx(transparent)]
struct Transparent(i32);
// "Weak" enums map to an integer type indicated by #[repr]
#[derive(PartialEq, Copy, Clone, Debug, sqlx::Type)]
#[repr(i32)]
enum Weak {
One = 0,
Two = 2,
Three = 4,
}
// "Strong" enums can map to TEXT or a custom enum
#[derive(PartialEq, Debug, sqlx::Type)]
#[sqlx(rename_all = "lowercase")]
enum Color {
Red,
Green,
Blue,
}
test_type!(transparent(
MySql,
Transparent,
"0" == Transparent(0),
"23523" == Transparent(23523)
));
test_type!(weak_enum(
MySql,
Weak,
"0" == Weak::One,
"2" == Weak::Two,
"4" == Weak::Three
));
test_type!(strong_color_enum(
MySql,
Color,
"'green'" == Color::Green
));

81
tests/postgres-derives.rs Normal file
View file

@ -0,0 +1,81 @@
use sqlx_test::test_type;
use std::fmt::Debug;
use sqlx::Postgres;
// Transparent types are rust-side wrappers over DB types
#[derive(PartialEq, Debug, sqlx::Type)]
#[sqlx(transparent)]
struct Transparent(i32);
// "Weak" enums map to an integer type indicated by #[repr]
#[derive(PartialEq, Copy, Clone, Debug, sqlx::Type)]
#[repr(i32)]
enum Weak {
One = 0,
Two = 2,
Three = 4,
}
// "Strong" enums can map to TEXT (25) or a custom enum type
#[derive(PartialEq, Debug, sqlx::Type)]
#[sqlx(postgres(oid = 25))]
#[sqlx(rename_all = "lowercase")]
enum Strong {
One,
Two,
#[sqlx(rename = "four")]
Three,
}
// Records must map to a custom type
// Note that all types are types in Postgres
#[derive(PartialEq, Debug, sqlx::Type)]
#[sqlx(postgres(oid = 12184))]
struct PgConfig {
name: String,
setting: Option<String>,
}
test_type!(transparent(
Postgres,
Transparent,
"0" == Transparent(0),
"23523" == Transparent(23523)
));
test_type!(weak_enum(
Postgres,
Weak,
"0::int4" == Weak::One,
"2::int4" == Weak::Two,
"4::int4" == Weak::Three
));
test_type!(strong_enum(
Postgres,
Strong,
"'one'::text" == Strong::One,
"'two'::text" == Strong::Two,
"'four'::text" == Strong::Three
));
test_type!(record_pg_config(
Postgres,
PgConfig,
// (CC,gcc)
"(SELECT ROW('CC', 'gcc')::pg_config)" == PgConfig {
name: "CC".to_owned(),
setting: Some("gcc".to_owned()),
},
// (CC,)
"(SELECT '(\"CC\",)'::pg_config)" == PgConfig {
name: "CC".to_owned(),
setting: None,
},
// (CC,"")
"(SELECT '(\"CC\",\"\")'::pg_config)" == PgConfig {
name: "CC".to_owned(),
setting: Some("".to_owned()),
}
));

View file

@ -1,12 +1,13 @@
use futures::TryStreamExt;
use sqlx_test::new;
use sqlx::postgres::{PgPool, PgQueryAs, PgRow};
use sqlx::{postgres::PgConnection, Connect, Connection, Executor, Row};
use sqlx::{Postgres, Connection, Executor, Row};
use std::time::Duration;
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn it_connects() -> anyhow::Result<()> {
let mut conn = connect().await?;
let mut conn = new::<Postgres>().await?;
let value = sqlx::query("select 1 + 1")
.try_map(|row: PgRow| row.try_get::<i32, _>(0))
@ -21,7 +22,7 @@ async fn it_connects() -> anyhow::Result<()> {
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn it_executes() -> anyhow::Result<()> {
let mut conn = connect().await?;
let mut conn = new::<Postgres>().await?;
let _ = conn
.execute(
@ -55,7 +56,7 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY);
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn it_can_return_interleaved_nulls_issue_104() -> anyhow::Result<()> {
let mut conn = connect().await?;
let mut conn = new::<Postgres>().await?;
let tuple =
sqlx::query("SELECT NULL::INT, 10::INT, NULL, 20::INT, NULL, 40::INT, NULL, 80::INT")
@ -89,7 +90,7 @@ async fn it_can_return_interleaved_nulls_issue_104() -> anyhow::Result<()> {
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn it_can_work_with_transactions() -> anyhow::Result<()> {
let mut conn = connect().await?;
let mut conn = new::<Postgres>().await?;
conn.execute("CREATE TABLE IF NOT EXISTS _sqlx_users_1922 (id INTEGER PRIMARY KEY)")
.await?;
@ -141,7 +142,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> {
.await?;
}
conn = connect().await?;
conn = new::<Postgres>().await?;
let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_1922")
.fetch_one(&mut conn)
@ -210,7 +211,7 @@ async fn pool_smoke_test() -> anyhow::Result<()> {
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn test_describe() -> anyhow::Result<()> {
let mut conn = connect().await?;
let mut conn = new::<Postgres>().await?;
let _ = conn
.execute(
@ -239,10 +240,3 @@ async fn test_describe() -> anyhow::Result<()> {
Ok(())
}
async fn connect() -> anyhow::Result<PgConnection> {
let _ = dotenv::dotenv();
let _ = env_logger::try_init();
Ok(PgConnection::connect(dotenv::var("DATABASE_URL")?).await?)
}