mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 14:34:19 +00:00
move decode_struct_field and encode_struct_field to sqlx-core
This commit is contained in:
parent
e603f5fcf6
commit
4cd179d42b
5 changed files with 71 additions and 44 deletions
|
@ -29,4 +29,4 @@ pub type PgPool = crate::pool::Pool<PgConnection>;
|
|||
make_query_as!(PgQueryAs, Postgres, PgRow);
|
||||
impl_map_row_for_row!(Postgres, PgRow);
|
||||
impl_column_index_for_row!(Postgres);
|
||||
impl_from_row_for_tuples!(Postgres, PgRow);
|
||||
impl_from_row_for_tuples!(Postgres, PgRow);
|
|
@ -12,6 +12,7 @@ mod bytes;
|
|||
mod float;
|
||||
mod int;
|
||||
mod str;
|
||||
pub mod r#struct;
|
||||
|
||||
#[cfg(feature = "chrono")]
|
||||
mod chrono;
|
||||
|
|
59
sqlx-core/src/postgres/types/struct.rs
Normal file
59
sqlx-core/src/postgres/types/struct.rs
Normal file
|
@ -0,0 +1,59 @@
|
|||
use crate::decode::{Decode, DecodeError};
|
||||
use crate::encode::Encode;
|
||||
use crate::postgres::protocol::TypeId;
|
||||
use crate::postgres::types::PgTypeInfo;
|
||||
use crate::types::HasSqlType;
|
||||
use crate::Postgres;
|
||||
use std::convert::TryInto;
|
||||
|
||||
/// read a struct field and advance the buffer
|
||||
pub fn decode_struct_field<T: Decode<Postgres>>(buf: &mut &[u8]) -> Result<T, DecodeError>
|
||||
where
|
||||
Postgres: HasSqlType<T>,
|
||||
{
|
||||
if buf.len() < 8 {
|
||||
return Err(DecodeError::Message(std::boxed::Box::new(
|
||||
"Not enough data sent",
|
||||
)));
|
||||
}
|
||||
|
||||
let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap());
|
||||
if oid != <Postgres as HasSqlType<T>>::type_info().oid() {
|
||||
return Err(DecodeError::Message(std::boxed::Box::new("Invalid oid")));
|
||||
}
|
||||
|
||||
let len = u32::from_be_bytes(buf[4..8].try_into().unwrap()) as usize;
|
||||
|
||||
if buf.len() < 8 + len {
|
||||
return Err(DecodeError::Message(std::boxed::Box::new(
|
||||
"Not enough data sent",
|
||||
)));
|
||||
}
|
||||
|
||||
let raw = &buf[8..8 + len];
|
||||
let value = T::decode(raw)?;
|
||||
|
||||
*buf = &buf[8 + len..];
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
pub fn encode_struct_field<T: Encode<Postgres>>(buf: &mut Vec<u8>, value: &T)
|
||||
where
|
||||
Postgres: HasSqlType<T>,
|
||||
{
|
||||
// write oid
|
||||
let info = <Postgres as HasSqlType<T>>::type_info();
|
||||
buf.extend(&info.oid().to_be_bytes());
|
||||
|
||||
// write zeros for length
|
||||
buf.extend(&[0; 4]);
|
||||
|
||||
let start = buf.len();
|
||||
value.encode(buf);
|
||||
let end = buf.len();
|
||||
let size = end - start;
|
||||
|
||||
// replaces zeros with actual length
|
||||
buf[start - 4..start].copy_from_slice(&(size as u32).to_be_bytes());
|
||||
}
|
|
@ -161,35 +161,16 @@ fn expand_derive_decode_struct(
|
|||
}
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
let mut reads: Vec<Vec<Stmt>> = Vec::new();
|
||||
let mut reads: Vec<Stmt> = Vec::new();
|
||||
let mut names: Vec<Ident> = Vec::new();
|
||||
for field in fields {
|
||||
let id = &field.ident;
|
||||
names.push(id.clone().unwrap());
|
||||
let ty = &field.ty;
|
||||
reads.push(parse_quote!(
|
||||
if buf.len() < 8 {
|
||||
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent")));
|
||||
}
|
||||
|
||||
let oid = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[0..4]).unwrap());
|
||||
if oid != <sqlx::Postgres as sqlx::types::HasSqlType<#ty>>::type_info().oid() {
|
||||
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid oid")));
|
||||
}
|
||||
|
||||
let len = u32::from_be_bytes(std::convert::TryInto::try_into(&buf[4..8]).unwrap()) as usize;
|
||||
|
||||
if buf.len() < 8 + len {
|
||||
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Not enough data sent")));
|
||||
}
|
||||
|
||||
let raw = &buf[8..8+len];
|
||||
let #id = <#ty as sqlx::decode::Decode<sqlx::Postgres>>::decode(raw)?;
|
||||
|
||||
let buf = &buf[8+len..];
|
||||
));
|
||||
let #id = sqlx::postgres::decode_struct_field::<#ty>(&mut buf)?;
|
||||
));
|
||||
}
|
||||
let reads = reads.into_iter().flatten();
|
||||
|
||||
tts.extend(quote!(
|
||||
impl #impl_generics sqlx::decode::Decode<sqlx::Postgres> for #ident#ty_generics #where_clause {
|
||||
|
@ -202,7 +183,7 @@ fn expand_derive_decode_struct(
|
|||
if column_count != #column_count {
|
||||
return Err(sqlx::decode::DecodeError::Message(std::boxed::Box::new("Invalid column count")));
|
||||
}
|
||||
let buf = &buf[4..];
|
||||
let mut buf = &buf[4..];
|
||||
|
||||
#(#reads)*
|
||||
|
||||
|
|
|
@ -6,8 +6,8 @@ use quote::quote;
|
|||
use syn::punctuated::Punctuated;
|
||||
use syn::token::Comma;
|
||||
use syn::{
|
||||
parse_quote, Block, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed,
|
||||
FieldsUnnamed, Variant,
|
||||
parse_quote, Data, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed,
|
||||
FieldsUnnamed, Stmt, Variant,
|
||||
};
|
||||
|
||||
pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
|
||||
|
@ -160,26 +160,12 @@ fn expand_derive_encode_struct(
|
|||
}
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
let mut writes: Vec<Block> = Vec::new();
|
||||
let mut writes: Vec<Stmt> = Vec::new();
|
||||
for field in fields {
|
||||
let id = &field.ident;
|
||||
let ty = &field.ty;
|
||||
writes.push(parse_quote!({
|
||||
// write oid
|
||||
let info = <sqlx::Postgres as sqlx::types::HasSqlType<#ty>>::type_info();
|
||||
buf.extend(&info.oid().to_be_bytes());
|
||||
|
||||
// write zeros for length
|
||||
buf.extend(&[0; 4]);
|
||||
|
||||
let start = buf.len();
|
||||
sqlx::encode::Encode::<sqlx::Postgres>::encode(&self. #id, buf);
|
||||
let end = buf.len();
|
||||
let size = end - start;
|
||||
|
||||
// replaces zeros with actual length
|
||||
buf[start-4..start].copy_from_slice(&(size as u32).to_be_bytes());
|
||||
}));
|
||||
writes.push(parse_quote!(
|
||||
sqlx::postgres::encode_struct_field(buf, &self. #id);
|
||||
));
|
||||
}
|
||||
|
||||
let mut sizes: Vec<Expr> = Vec::new();
|
||||
|
|
Loading…
Reference in a new issue