move decode_struct_field and encode_struct_field to sqlx-core

This commit is contained in:
Tom Dohrmann 2020-02-10 11:43:37 +01:00 committed by Ryan Leckey
parent e603f5fcf6
commit 4cd179d42b
5 changed files with 71 additions and 44 deletions

View file

@ -29,4 +29,4 @@ pub type PgPool = crate::pool::Pool<PgConnection>;
make_query_as!(PgQueryAs, Postgres, PgRow);
impl_map_row_for_row!(Postgres, PgRow);
impl_column_index_for_row!(Postgres);
impl_from_row_for_tuples!(Postgres, PgRow);
impl_from_row_for_tuples!(Postgres, PgRow);

View file

@ -12,6 +12,7 @@ mod bytes;
mod float;
mod int;
mod str;
pub mod r#struct;
#[cfg(feature = "chrono")]
mod chrono;

View file

@ -0,0 +1,59 @@
use crate::decode::{Decode, DecodeError};
use crate::encode::Encode;
use crate::postgres::protocol::TypeId;
use crate::postgres::types::PgTypeInfo;
use crate::types::HasSqlType;
use crate::Postgres;
use std::convert::TryInto;
/// read a struct field and advance the buffer
pub fn decode_struct_field<T: Decode<Postgres>>(buf: &mut &[u8]) -> Result<T, DecodeError>
where
Postgres: HasSqlType<T>,
{
if buf.len() < 8 {
return Err(DecodeError::Message(std::boxed::Box::new(
"Not enough data sent",
)));
}
let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap());
if oid != <Postgres as HasSqlType<T>>::type_info().oid() {
return Err(DecodeError::Message(std::boxed::Box::new("Invalid oid")));
}
let len = u32::from_be_bytes(buf[4..8].try_into().unwrap()) as usize;
if buf.len() < 8 + len {
return Err(DecodeError::Message(std::boxed::Box::new(
"Not enough data sent",
)));
}
let raw = &buf[8..8 + len];
let value = T::decode(raw)?;
*buf = &buf[8 + len..];
Ok(value)
}
pub fn encode_struct_field<T: Encode<Postgres>>(buf: &mut Vec<u8>, value: &T)
where
Postgres: HasSqlType<T>,
{
// write oid
let info = <Postgres as HasSqlType<T>>::type_info();
buf.extend(&info.oid().to_be_bytes());
// write zeros for length
buf.extend(&[0; 4]);
let start = buf.len();
value.encode(buf);
let end = buf.len();
let size = end - start;
// replaces zeros with actual length
buf[start - 4..start].copy_from_slice(&(size as u32).to_be_bytes());
}

View file

@ -161,35 +161,16 @@ fn expand_derive_decode_struct(
}
let (impl_generics, _, where_clause) = generics.split_for_impl();
let mut reads: Vec<Vec<Stmt>> = Vec::new();
let mut reads: Vec<Stmt> = Vec::new();
let mut names: Vec<Ident> = Vec::new();
for field in fields {
let id = &field.ident;
names.push(id.clone().unwrap());
let ty = &field.ty;
reads.push(parse_quote!(
if buf.len() < 8 {
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent")));
}
let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap());
if oid != <sqlx::Postgres as sqlx::types::HasSqlType<#ty>>::type_info().oid() {
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid oid")));
}
let len = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[4..8]).unwrap()) as usize;
if buf.len() < 8 + len {
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent")));
}
let raw = &buf[8..8+len];
let #id = <#ty as sqlx::decode::Decode<sqlx::Postgres>>::decode(raw)?;
let buf = &buf[8+len..];
));
let #id = sqlx::postgres::decode_struct_field::<#ty>(&mut buf)?;
));
}
let reads = reads.into_iter().flatten();
tts.extend(quote!(
impl #impl_generics sqlx::decode::Decode<sqlx::Postgres> for #ident#ty_generics #where_clause {
@ -202,7 +183,7 @@ fn expand_derive_decode_struct(
if column_count != #column_count {
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count")));
}
let buf = &buf[4..];
let mut buf = &buf[4..];
#(#reads)*

View file

@ -6,8 +6,8 @@ use quote::quote;
use syn::punctuated::Punctuated;
use syn::token::Comma;
use syn::{
parse_quote, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed,
FieldsUnnamed, Variant,
parse_quote, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed,
FieldsUnnamed, Stmt, Variant,
};
pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
@ -160,26 +160,12 @@ fn expand_derive_encode_struct(
}
let (impl_generics, _, where_clause) = generics.split_for_impl();
let mut writes: Vec<Block> = Vec::new();
let mut writes: Vec<Stmt> = Vec::new();
for field in fields {
let id = &field.ident;
let ty = &field.ty;
writes.push(parse_quote!({
// write oid
let info = <sqlx::Postgres as sqlx::types::HasSqlType<#ty>>::type_info();
buf.extend(&info.oid().to_be_bytes());
// write zeros for length
buf.extend(&[0; 4]);
let start = buf.len();
sqlx::encode::Encode::<sqlx::Postgres>::encode(&self. #id, buf);
let end = buf.len();
let size = end - start;
// replaces zeros with actual length
buf[start-4..start].copy_from_slice(&(size as u32).to_be_bytes());
}));
writes.push(parse_quote!(
sqlx::postgres::encode_struct_field(buf, &self. #id);
));
}
let mut sizes: Vec<Expr> = Vec::new();