refactor(derives): use separate impls per database

database-generic impls are *mostly* impossible in SQLx so we recently
capitalized on that and made it *totally* impossible (until Rust
has specialization and lazy norm)
This commit is contained in:
Ryan Leckey 2020-06-27 05:30:38 -07:00
parent af7bd71ab2
commit e3483230e0
3 changed files with 347 additions and 87 deletions

View file

@ -59,23 +59,37 @@ fn expand_derive_decode_transparent(
let generics = &input.generics;
let (_, ty_generics, _) = generics.split_for_impl();
// add db type for impl generics & where clause
let mut generics = generics.clone();
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics.params.insert(0, parse_quote!('r));
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::decode::Decode<'r, DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
let mut tts = proc_macro2::TokenStream::new();
let tts = quote!(
impl #impl_generics sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause {
fn decode(value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
<#ty as sqlx::decode::Decode<'r, DB>>::decode(value).map(Self)
if cfg!(feature = "mysql") {
tts.extend(quote!(
impl<'r> sqlx::decode::Decode<'r, sqlx::MySql> for #ident #ty_generics where #ty: sqlx::decode::Decode<'r, sqlx::MySql> {
fn decode(value: <sqlx::MySql as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
<#ty as sqlx::decode::Decode<'r, sqlx::MySql>>::decode(value).map(Self)
}
}
}
);
));
}
if cfg!(feature = "postgres") {
tts.extend(quote!(
impl<'r> sqlx::decode::Decode<'r, sqlx::Postgres> for #ident #ty_generics where #ty: sqlx::decode::Decode<'r, sqlx::Postgres> {
fn decode(value: <sqlx::Postgres as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
<#ty as sqlx::decode::Decode<'r, sqlx::Postgres>>::decode(value).map(Self)
}
}
));
}
if cfg!(feature = "sqlite") {
tts.extend(quote!(
impl<'r> sqlx::decode::Decode<'r, sqlx::Sqlite> for #ident #ty_generics where #ty: sqlx::decode::Decode<'r, sqlx::Sqlite> {
fn decode(value: <sqlx::Sqlite as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
<#ty as sqlx::decode::Decode<'r, sqlx::Sqlite>>::decode(value).map(Self)
}
}
));
}
Ok(tts)
}
@ -98,19 +112,57 @@ fn expand_derive_decode_weak_enum(
})
.collect::<Vec<Arm>>();
Ok(quote!(
impl<'r, DB: sqlx::Database> sqlx::decode::Decode<'r, DB> for #ident where #repr: sqlx::decode::Decode<'r, DB> {
fn decode(value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let value = <#repr as sqlx::decode::Decode<'r, DB>>::decode(value)?;
let mut tts = proc_macro2::TokenStream::new();
match value {
#(#arms)*
if cfg!(feature = "mysql") {
tts.extend(quote!(
impl<'r> sqlx::decode::Decode<'r, sqlx::MySql> for #ident where #repr: sqlx::decode::Decode<'r, sqlx::MySql> {
fn decode(value: <sqlx::MySql as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let value = <#repr as sqlx::decode::Decode<'r, sqlx::MySql>>::decode(value)?;
_ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into())))
match value {
#(#arms)*
_ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into())))
}
}
}
}
))
));
}
if cfg!(feature = "postgres") {
tts.extend(quote!(
impl<'r> sqlx::decode::Decode<'r, sqlx::Postgres> for #ident where #repr: sqlx::decode::Decode<'r, sqlx::Postgres> {
fn decode(value: <sqlx::Postgres as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let value = <#repr as sqlx::decode::Decode<'r, sqlx::Postgres>>::decode(value)?;
match value {
#(#arms)*
_ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into())))
}
}
}
));
}
if cfg!(feature = "sqlite") {
tts.extend(quote!(
impl<'r> sqlx::decode::Decode<'r, sqlx::Sqlite> for #ident where #repr: sqlx::decode::Decode<'r, sqlx::Sqlite> {
fn decode(value: <sqlx::Sqlite as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let value = <#repr as sqlx::decode::Decode<'r, sqlx::Sqlite>>::decode(value)?;
match value {
#(#arms)*
_ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into())))
}
}
}
));
}
Ok(tts)
}
fn expand_derive_decode_strong_enum(

View file

@ -68,28 +68,70 @@ fn expand_derive_encode_transparent(
.params
.insert(0, LifetimeDef::new(lifetime.clone()).into());
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::encode::Encode<#lifetime, DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
Ok(quote!(
impl #impl_generics sqlx::encode::Encode<#lifetime, DB> for #ident #ty_generics #where_clause {
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#ty as sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.0, buf)
}
let (impl_generics, _, _) = generics.split_for_impl();
fn produces(&self) -> Option<DB::TypeInfo> {
<#ty as sqlx::encode::Encode<#lifetime, DB>>::produces(&self.0)
}
let mut tts = proc_macro2::TokenStream::new();
fn size_hint(&self) -> usize {
<#ty as sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.0)
if cfg!(feature = "mysql") {
tts.extend(quote!(
impl #impl_generics sqlx::encode::Encode<#lifetime, sqlx::MySql> for #ident #ty_generics where #ty: sqlx::encode::Encode<#lifetime, sqlx::MySql> {
fn encode_by_ref(&self, buf: &mut <sqlx::MySql as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::MySql>>::encode_by_ref(&self.0, buf)
}
fn produces(&self) -> Option<sqlx::mysql::MySqlTypeInfo> {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::MySql>>::produces(&self.0)
}
fn size_hint(&self) -> usize {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::MySql>>::size_hint(&self.0)
}
}
}
))
));
}
if cfg!(feature = "postgres") {
tts.extend(quote!(
impl #impl_generics sqlx::encode::Encode<#lifetime, sqlx::Postgres> for #ident #ty_generics where #ty: sqlx::encode::Encode<#lifetime, sqlx::Postgres> {
fn encode_by_ref(&self, buf: &mut <sqlx::Postgres as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::Postgres>>::encode_by_ref(&self.0, buf)
}
fn produces(&self) -> Option<sqlx::postgres::PgTypeInfo> {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::Postgres>>::produces(&self.0)
}
fn size_hint(&self) -> usize {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::Postgres>>::size_hint(&self.0)
}
}
));
}
if cfg!(feature = "sqlite") {
tts.extend(quote!(
impl #impl_generics sqlx::encode::Encode<#lifetime, sqlx::Sqlite> for #ident #ty_generics where #ty: sqlx::encode::Encode<#lifetime, sqlx::Sqlite> {
fn encode_by_ref(&self, buf: &mut <sqlx::Sqlite as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::Sqlite>>::encode_by_ref(&self.0, buf)
}
fn produces(&self) -> Option<sqlx::sqlite::SqliteTypeInfo> {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::Sqlite>>::produces(&self.0)
}
fn size_hint(&self) -> usize {
<#ty as sqlx::encode::Encode<#lifetime, sqlx::Sqlite>>::size_hint(&self.0)
}
}
));
}
Ok(tts)
}
fn expand_derive_encode_weak_enum(
@ -101,21 +143,63 @@ fn expand_derive_encode_weak_enum(
let ident = &input.ident;
Ok(quote!(
impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where #repr: sqlx::encode::Encode<'q, DB> {
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#repr as sqlx::encode::Encode<DB>>::encode_by_ref(&(*self as #repr), buf)
}
let mut tts = proc_macro2::TokenStream::new();
fn produces(&self) -> Option<DB::TypeInfo> {
<#repr as sqlx::encode::Encode<DB>>::produces(&(*self as #repr))
}
if cfg!(feature = "mysql") {
tts.extend(quote!(
impl<'q> sqlx::encode::Encode<'q, sqlx::MySql> for #ident where #repr: sqlx::encode::Encode<'q, sqlx::MySql> {
fn encode_by_ref(&self, buf: &mut <sqlx::MySql as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#repr as sqlx::encode::Encode<sqlx::MySql>>::encode_by_ref(&(*self as #repr), buf)
}
fn size_hint(&self) -> usize {
<#repr as sqlx::encode::Encode<DB>>::size_hint(&(*self as #repr))
fn produces(&self) -> Option<sqlx::mysql::MySqlTypeInfo> {
<#repr as sqlx::encode::Encode<sqlx::MySql>>::produces(&(*self as #repr))
}
fn size_hint(&self) -> usize {
<#repr as sqlx::encode::Encode<sqlx::MySql>>::size_hint(&(*self as #repr))
}
}
}
))
));
}
if cfg!(feature = "postgres") {
tts.extend(quote!(
impl<'q> sqlx::encode::Encode<'q, sqlx::Postgres> for #ident where #repr: sqlx::encode::Encode<'q, sqlx::Postgres> {
fn encode_by_ref(&self, buf: &mut <sqlx::Postgres as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#repr as sqlx::encode::Encode<sqlx::Postgres>>::encode_by_ref(&(*self as #repr), buf)
}
fn produces(&self) -> Option<sqlx::postgres::PgTypeInfo> {
<#repr as sqlx::encode::Encode<sqlx::Postgres>>::produces(&(*self as #repr))
}
fn size_hint(&self) -> usize {
<#repr as sqlx::encode::Encode<sqlx::Postgres>>::size_hint(&(*self as #repr))
}
}
));
}
if cfg!(feature = "sqlite") {
tts.extend(quote!(
impl<'q> sqlx::encode::Encode<'q, sqlx::Sqlite> for #ident where #repr: sqlx::encode::Encode<'q, sqlx::Sqlite> {
fn encode_by_ref(&self, buf: &mut <sqlx::Sqlite as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#repr as sqlx::encode::Encode<sqlx::Sqlite>>::encode_by_ref(&(*self as #repr), buf)
}
fn produces(&self) -> Option<sqlx::sqlite::SqliteTypeInfo> {
<#repr as sqlx::encode::Encode<sqlx::Sqlite>>::produces(&(*self as #repr))
}
fn size_hint(&self) -> usize {
<#repr as sqlx::encode::Encode<sqlx::Sqlite>>::size_hint(&(*self as #repr))
}
}
));
}
Ok(tts)
}
fn expand_derive_encode_strong_enum(
@ -143,25 +227,75 @@ fn expand_derive_encode_strong_enum(
}
}
Ok(quote!(
impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where &'q str: sqlx::encode::Encode<'q, DB> {
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
let val = match self {
#(#value_arms)*
};
let mut tts = proc_macro2::TokenStream::new();
<&str as sqlx::encode::Encode<'q, DB>>::encode(val, buf)
if cfg!(feature = "mysql") {
tts.extend(quote!(
impl<'q> sqlx::encode::Encode<'q, sqlx::MySql> for #ident where &'q str: sqlx::encode::Encode<'q, sqlx::MySql> {
fn encode_by_ref(&self, buf: &mut <sqlx::MySql as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
let val = match self {
#(#value_arms)*
};
<&str as sqlx::encode::Encode<'q, sqlx::MySql>>::encode(val, buf)
}
fn size_hint(&self) -> usize {
let val = match self {
#(#value_arms)*
};
<&str as sqlx::encode::Encode<'q, sqlx::MySql>>::size_hint(&val)
}
}
));
}
fn size_hint(&self) -> usize {
let val = match self {
#(#value_arms)*
};
if cfg!(feature = "postgres") {
tts.extend(quote!(
impl<'q> sqlx::encode::Encode<'q, sqlx::Postgres> for #ident where &'q str: sqlx::encode::Encode<'q, sqlx::Postgres> {
fn encode_by_ref(&self, buf: &mut <sqlx::Postgres as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
let val = match self {
#(#value_arms)*
};
<&str as sqlx::encode::Encode<'q, DB>>::size_hint(&val)
<&str as sqlx::encode::Encode<'q, sqlx::Postgres>>::encode(val, buf)
}
fn size_hint(&self) -> usize {
let val = match self {
#(#value_arms)*
};
<&str as sqlx::encode::Encode<'q, sqlx::Postgres>>::size_hint(&val)
}
}
}
))
));
}
if cfg!(feature = "sqlite") {
tts.extend(quote!(
impl<'q> sqlx::encode::Encode<'q, sqlx::Sqlite> for #ident where &'q str: sqlx::encode::Encode<'q, sqlx::Sqlite> {
fn encode_by_ref(&self, buf: &mut <sqlx::Sqlite as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
let val = match self {
#(#value_arms)*
};
<&str as sqlx::encode::Encode<'q, sqlx::Sqlite>>::encode(val, buf)
}
fn size_hint(&self) -> usize {
let val = match self {
#(#value_arms)*
};
<&str as sqlx::encode::Encode<'q, sqlx::Sqlite>>::size_hint(&val)
}
}
));
}
Ok(tts)
}
fn expand_derive_encode_struct(

View file

@ -59,21 +59,64 @@ fn expand_derive_has_sql_type_transparent(
if attr.transparent {
let mut generics = generics.clone();
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::Type<DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
let mut tts = proc_macro2::TokenStream::new();
return Ok(quote!(
impl #impl_generics sqlx::Type< DB > for #ident #ty_generics #where_clause {
fn type_info() -> DB::TypeInfo {
<#ty as sqlx::Type<DB>>::type_info()
if cfg!(feature = "mysql") {
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::Type<sqlx::MySql>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
tts.extend(quote!(
impl #impl_generics sqlx::Type<sqlx::MySql> for #ident #ty_generics #where_clause
{
fn type_info() -> sqlx::mysql::MySqlTypeInfo {
<#ty as sqlx::Type<sqlx::MySql>>::type_info()
}
}
}
));
));
}
if cfg!(feature = "postgres") {
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::Type<sqlx::Postgres>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
tts.extend(quote!(
impl #impl_generics sqlx::Type<sqlx::Postgres> for #ident #ty_generics #where_clause
{
fn type_info() -> sqlx::postgres::PgTypeInfo {
<#ty as sqlx::Type<sqlx::Postgres>>::type_info()
}
}
));
}
if cfg!(feature = "sqlite") {
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::Type<sqlx::Sqlite>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
tts.extend(quote!(
impl #impl_generics sqlx::Type<sqlx::Sqlite> for #ident #ty_generics #where_clause
{
fn type_info() -> sqlx::sqlite::SqliteTypeInfo {
<#ty as sqlx::Type<sqlx::Sqlite>>::type_info()
}
}
));
}
return Ok(tts);
}
let mut tts = proc_macro2::TokenStream::new();
@ -100,18 +143,49 @@ fn expand_derive_has_sql_type_weak_enum(
let attr = check_weak_enum_attributes(input, variants)?;
let repr = attr.repr.unwrap();
let ident = &input.ident;
let ts = quote!(
impl<DB: sqlx::Database> sqlx::Type<DB> for #ident
where
#repr: sqlx::Type<DB>,
{
fn type_info() -> DB::TypeInfo {
<#repr as sqlx::Type<DB>>::type_info()
}
}
);
Ok(ts)
let mut tts = proc_macro2::TokenStream::new();
if cfg!(feature = "mysql") {
tts.extend(quote!(
impl sqlx::Type<sqlx::MySql> for #ident
where
#repr: sqlx::Type<sqlx::MySql>,
{
fn type_info() -> sqlx::mysql::MySqlTypeInfo {
<#repr as sqlx::Type<sqlx::MySql>>::type_info()
}
}
));
}
if cfg!(feature = "postgres") {
tts.extend(quote!(
impl sqlx::Type<sqlx::Postgres> for #ident
where
#repr: sqlx::Type<sqlx::Postgres>,
{
fn type_info() -> sqlx::postgres::PgTypeInfo {
<#repr as sqlx::Type<sqlx::Postgres>>::type_info()
}
}
));
}
if cfg!(feature = "sqlite") {
tts.extend(quote!(
impl sqlx::Type<sqlx::Sqlite> for #ident
where
#repr: sqlx::Type<sqlx::Sqlite>,
{
fn type_info() -> sqlx::sqlite::SqliteTypeInfo {
<#repr as sqlx::Type<sqlx::Sqlite>>::type_info()
}
}
));
}
Ok(tts)
}
fn expand_derive_has_sql_type_strong_enum(