fix(macros): fix derive for Encode

This commit is contained in:
Austin Bonander 2020-06-10 21:18:25 -07:00 committed by Ryan Leckey
parent 475ed9e1df
commit 646823e093

View file

@ -9,7 +9,7 @@ use syn::punctuated::Punctuated;
use syn::token::Comma; use syn::token::Comma;
use syn::{ use syn::{
parse_quote, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, parse_quote, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed,
FieldsUnnamed, Lifetime, Stmt, Variant, FieldsUnnamed, Lifetime, LifetimeDef, Stmt, Variant,
}; };
pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> { pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
@ -64,6 +64,9 @@ fn expand_derive_encode_transparent(
// add db type for impl generics & where clause // add db type for impl generics & where clause
let lifetime = Lifetime::new("'q", Span::call_site()); let lifetime = Lifetime::new("'q", Span::call_site());
let mut generics = generics.clone(); let mut generics = generics.clone();
generics
.params
.insert(0, LifetimeDef::new(lifetime.clone()).into());
generics.params.insert(0, parse_quote!(DB: sqlx::Database)); generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics generics
.make_where_clause() .make_where_clause()
@ -72,7 +75,7 @@ fn expand_derive_encode_transparent(
let (impl_generics, _, where_clause) = generics.split_for_impl(); let (impl_generics, _, where_clause) = generics.split_for_impl();
Ok(quote!( Ok(quote!(
impl<#lifetime> #impl_generics sqlx::encode::Encode<#lifetime, DB> for #ident #ty_generics #where_clause { impl #impl_generics sqlx::encode::Encode<#lifetime, DB> for #ident #ty_generics #where_clause {
fn encode(self, buf: &mut <DB as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull { fn encode(self, buf: &mut <DB as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull {
sqlx::encode::Encode::encode(self.0, buf) sqlx::encode::Encode::encode(self.0, buf)
} }
@ -182,6 +185,7 @@ fn expand_derive_encode_struct(
for field in fields { for field in fields {
let ty = &field.ty; let ty = &field.ty;
predicates.insert(0, parse_quote!('q));
predicates.push(parse_quote!(#ty: sqlx::encode::Encode<'q, sqlx::Postgres>)); predicates.push(parse_quote!(#ty: sqlx::encode::Encode<'q, sqlx::Postgres>));
predicates.push(parse_quote!(#ty: sqlx::types::Type<'q, sqlx::Postgres>)); predicates.push(parse_quote!(#ty: sqlx::types::Type<'q, sqlx::Postgres>));
} }
@ -207,7 +211,7 @@ fn expand_derive_encode_struct(
}); });
tts.extend(quote!( tts.extend(quote!(
impl<'q> #impl_generics sqlx::encode::Encode<'q, sqlx::Postgres> for #ident #ty_generics #where_clause { impl #impl_generics sqlx::encode::Encode<'q, sqlx::Postgres> for #ident #ty_generics #where_clause {
fn encode_by_ref(&self, buf: &mut <sqlx::Postgres as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull { fn encode_by_ref(&self, buf: &mut <sqlx::Postgres as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
let mut encoder = sqlx::postgres::types::raw::PgRecordEncoder::new(buf); let mut encoder = sqlx::postgres::types::raw::PgRecordEncoder::new(buf);