Merge branch '1d-array-pg-support' of https://github.com/oeb25/sqlx into oeb25-1d-array-pg-support

This commit is contained in:
Ryan Leckey 2020-03-20 10:11:02 -07:00
commit f562e919b7
12 changed files with 329 additions and 1 deletions

View file

@ -0,0 +1,202 @@
/// Encoding and decoding of Postgres arrays. Documentation of the byte format can be found [here](https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/include/utils/array.h;h=7f7e744cb12bc872f628f90dad99dfdf074eb314;hb=master#l6)
use crate::decode::Decode;
use crate::decode::DecodeError;
use crate::encode::Encode;
use crate::io::{Buf, BufMut};
use crate::postgres::database::Postgres;
use crate::types::HasSqlType;
use std::marker::PhantomData;
impl<T> Encode<Postgres> for [T]
where
T: Encode<Postgres>,
Postgres: HasSqlType<T>,
{
fn encode(&self, buf: &mut Vec<u8>) {
let mut encoder = ArrayEncoder::new(buf);
for item in self {
encoder.push(item);
}
}
}
impl<T> Encode<Postgres> for Vec<T>
where
[T]: Encode<Postgres>,
Postgres: HasSqlType<T>,
{
fn encode(&self, buf: &mut Vec<u8>) {
self.as_slice().encode(buf)
}
}
impl<T> Decode<Postgres> for Vec<T>
where
T: Decode<Postgres>,
Postgres: HasSqlType<T>,
{
fn decode(buf: &[u8]) -> Result<Self, DecodeError> {
let decoder = ArrayDecoder::<T>::new(buf)?;
decoder.collect()
}
}
type Order = byteorder::BigEndian;
struct ArrayDecoder<'a, T>
where
T: Decode<Postgres>,
Postgres: HasSqlType<T>,
{
left: usize,
did_error: bool,
buf: &'a [u8],
phantom: PhantomData<T>,
}
impl<T> ArrayDecoder<'_, T>
where
T: Decode<Postgres>,
Postgres: HasSqlType<T>,
{
fn new(mut buf: &[u8]) -> Result<ArrayDecoder<T>, DecodeError> {
let ndim = buf.get_i32::<Order>()?;
let dataoffset = buf.get_i32::<Order>()?;
let elemtype = buf.get_i32::<Order>()?;
if ndim == 0 {
return Ok(ArrayDecoder {
left: 0,
did_error: false,
buf,
phantom: PhantomData,
});
}
assert_eq!(ndim, 1, "only arrays of dimension 1 is supported");
let dimensions = buf.get_i32::<Order>()?;
let lower_bnds = buf.get_i32::<Order>()?;
assert_eq!(dataoffset, 0, "arrays with [null bitmap] is not supported");
assert_eq!(
elemtype,
<Postgres as HasSqlType<T>>::type_info().id.0 as i32,
"mismatched array element type"
);
assert_eq!(lower_bnds, 1);
Ok(ArrayDecoder {
left: dimensions as usize,
did_error: false,
buf,
phantom: PhantomData,
})
}
/// Decodes the next element without worring how many are left, or if it previously errored
fn decode_next_element(&mut self) -> Result<T, DecodeError> {
let len = self.buf.get_i32::<Order>()?;
let bytes = self.buf.get_bytes(len as usize)?;
Decode::decode(bytes)
}
}
impl<T> Iterator for ArrayDecoder<'_, T>
where
T: Decode<Postgres>,
Postgres: HasSqlType<T>,
{
type Item = Result<T, DecodeError>;
fn next(&mut self) -> Option<Result<T, DecodeError>> {
if self.did_error || self.left == 0 {
return None;
}
self.left -= 1;
let decoded = self.decode_next_element();
self.did_error = decoded.is_err();
Some(decoded)
}
}
struct ArrayEncoder<'a, T>
where
T: Encode<Postgres>,
Postgres: HasSqlType<T>,
{
count: usize,
len_start_index: usize,
buf: &'a mut Vec<u8>,
phantom: PhantomData<T>,
}
impl<T> ArrayEncoder<'_, T>
where
T: Encode<Postgres>,
Postgres: HasSqlType<T>,
{
fn new(buf: &mut Vec<u8>) -> ArrayEncoder<T> {
let ty = <Postgres as HasSqlType<T>>::type_info();
// ndim
buf.put_i32::<Order>(1);
// dataoffset
buf.put_i32::<Order>(0);
// elemtype
buf.put_i32::<Order>(ty.id.0 as i32);
let len_start_index = buf.len();
// dimensions
buf.put_i32::<Order>(0);
// lower_bnds
buf.put_i32::<Order>(1);
ArrayEncoder {
count: 0,
len_start_index,
buf,
phantom: PhantomData,
}
}
fn push(&mut self, item: &T) {
// Allocate space for the length of the encoded elemement up front
let el_len_index = self.buf.len();
self.buf.put_i32::<Order>(0);
// Allocate the element it self
let el_start = self.buf.len();
Encode::encode(item, self.buf);
let el_end = self.buf.len();
// Now we know the actual length of the encoded element
let el_len = el_end - el_start;
// And we can now go back and update the length
self.buf[el_len_index..el_start].copy_from_slice(&(el_len as i32).to_be_bytes());
self.count += 1;
}
fn update_len(&mut self) {
const I32_SIZE: usize = std::mem::size_of::<i32>();
let size_bytes = (self.count as i32).to_be_bytes();
self.buf[self.len_start_index..self.len_start_index + I32_SIZE]
.copy_from_slice(&size_bytes);
}
}
impl<T> Drop for ArrayEncoder<'_, T>
where
T: Encode<Postgres>,
Postgres: HasSqlType<T>,
{
fn drop(&mut self) {
self.update_len();
}
}

View file

@ -19,6 +19,11 @@ impl Type<Postgres> for [bool] {
PgTypeInfo::new(TypeId::ARRAY_BOOL, "BOOL[]")
}
}
impl HasSqlType<Vec<bool>> for Postgres {
fn type_info() -> PgTypeInfo {
<Self as HasSqlType<[bool]>>::type_info()
}
}
impl Encode<Postgres> for bool {
fn encode(&self, buf: &mut Vec<u8>) {

View file

@ -19,6 +19,12 @@ impl Type<Postgres> for [&'_ [u8]] {
}
}
impl Type<Postgres> for Vec<&'_ [u8]> {
fn type_info() -> PgTypeInfo {
<&'_ [u8] as Type<Postgres>>::type_info()
}
}
impl Type<Postgres> for Vec<u8> {
fn type_info() -> PgTypeInfo {
<[u8] as Type<Postgres>>::type_info()

View file

@ -67,6 +67,33 @@ where
}
}
impl Type<Postgres> for Vec<NaiveTime> {
fn type_info() -> PgTypeInfo {
<[NaiveTime] as Type<Postgres>>::type_info()
}
}
impl Type<Postgres> for Vec<NaiveDate> {
fn type_info() -> PgTypeInfo {
<[NaiveDate] as Type<Postgres>>::type_info()
}
}
impl Type<Postgres> for Vec<NaiveDateTime> {
fn type_info() -> PgTypeInfo {
<[NaiveDateTime] as Type<Postgres>>::type_info()
}
}
impl<Tz> Type<Postgres> for Vec<DateTime<Tz>>
where
Tz: TimeZone,
{
fn type_info() -> PgTypeInfo {
<[NaiveDateTime] as Type<Postgres>>::type_info()
}
}
impl<'de> Decode<'de, Postgres> for NaiveTime {
fn decode(value: Option<PgValue<'de>>) -> crate::Result<Self> {
match value.try_into()? {

View file

@ -22,6 +22,11 @@ impl Type<Postgres> for [f32] {
PgTypeInfo::new(TypeId::ARRAY_FLOAT4, "FLOAT4[]")
}
}
impl HasSqlType<Vec<f32>> for Postgres {
fn type_info() -> PgTypeInfo {
<Self as HasSqlType<[f32]>>::type_info()
}
}
impl Encode<Postgres> for f32 {
fn encode(&self, buf: &mut Vec<u8>) {
@ -53,6 +58,11 @@ impl Type<Postgres> for [f64] {
PgTypeInfo::new(TypeId::ARRAY_FLOAT8, "FLOAT8[]")
}
}
impl HasSqlType<Vec<f64>> for Postgres {
fn type_info() -> PgTypeInfo {
<Self as HasSqlType<[f64]>>::type_info()
}
}
impl Encode<Postgres> for f64 {
fn encode(&self, buf: &mut Vec<u8>) {

View file

@ -22,6 +22,11 @@ impl Type<Postgres> for [i16] {
PgTypeInfo::new(TypeId::ARRAY_INT2, "INT2[]")
}
}
impl HasSqlType<Vec<i16>> for Postgres {
fn type_info() -> PgTypeInfo {
<Self as HasSqlType<[i16]>>::type_info()
}
}
impl Encode<Postgres> for i16 {
fn encode(&self, buf: &mut Vec<u8>) {
@ -49,6 +54,11 @@ impl Type<Postgres> for [i32] {
PgTypeInfo::new(TypeId::ARRAY_INT4, "INT4[]")
}
}
impl HasSqlType<Vec<i32>> for Postgres {
fn type_info() -> PgTypeInfo {
<Self as HasSqlType<[i32]>>::type_info()
}
}
impl Encode<Postgres> for i32 {
fn encode(&self, buf: &mut Vec<u8>) {
@ -76,6 +86,11 @@ impl Type<Postgres> for [i64] {
PgTypeInfo::new(TypeId::ARRAY_INT8, "INT8[]")
}
}
impl HasSqlType<Vec<i64>> for Postgres {
fn type_info() -> PgTypeInfo {
<Self as HasSqlType<[i64]>>::type_info()
}
}
impl Encode<Postgres> for i64 {
fn encode(&self, buf: &mut Vec<u8>) {

View file

@ -59,6 +59,7 @@ use crate::postgres::protocol::TypeId;
use crate::postgres::{PgValue, Postgres};
use crate::types::TypeInfo;
mod array;
mod bool;
mod bytes;
mod float;

View file

@ -21,12 +21,27 @@ impl Type<Postgres> for [&'_ str] {
PgTypeInfo::new(TypeId::ARRAY_TEXT, "TEXT[]")
}
}
impl HasSqlType<Vec<&'_ str>> for Postgres {
fn type_info() -> PgTypeInfo {
<Self as HasSqlType<[&'_ str]>>::type_info()
}
}
impl Type<Postgres> for String {
fn type_info() -> PgTypeInfo {
<str as Type<Postgres>>::type_info()
}
}
impl HasSqlType<[String]> for Postgres {
fn type_info() -> PgTypeInfo {
<Self as HasSqlType<[&'_ str]>>::type_info()
}
}
impl HasSqlType<Vec<String>> for Postgres {
fn type_info() -> PgTypeInfo {
<Self as HasSqlType<Vec<&'_ str>>>::type_info()
}
}
impl Encode<Postgres> for str {
fn encode(&self, buf: &mut Vec<u8>) {

View file

@ -23,6 +23,12 @@ impl Type<Postgres> for [Uuid] {
}
}
impl HasSqlType<Vec<Uuid>> for Postgres {
fn type_info() -> PgTypeInfo {
<Postgres as HasSqlType<[Uuid]>>::type_info()
}
}
impl Encode<Postgres> for Uuid {
fn encode(&self, buf: &mut Vec<u8>) {
buf.extend_from_slice(self.as_bytes());

View file

@ -30,7 +30,16 @@ impl_database_ext! {
sqlx::types::BigDecimal,
#[cfg(feature = "ipnetwork")]
sqlx::types::ipnetwork::IpNetwork
sqlx::types::ipnetwork::IpNetwork,
// Arrays
Vec<bool> | &[bool],
Vec<String> | &[String],
Vec<i16> | &[i16],
Vec<i32> | &[i32],
Vec<i64> | &[i64],
Vec<f32> | &[f32],
Vec<f64> | &[f64],
},
ParamChecking::Strong,
feature-types: info => info.type_feature_gate(),

View file

@ -186,6 +186,32 @@ async fn test_many_args() -> anyhow::Result<()> {
Ok(())
}
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn test_array_from_slice() -> anyhow::Result<()> {
let mut conn = connect().await?;
let list: &[i32] = &[1, 2, 3, 4i32];
let result = sqlx::query!("SELECT $1::int[] as my_array", *list)
.fetch_one(&mut conn)
.await?;
assert_eq!(result.my_array, vec![1, 2, 3, 4]);
println!("result ID: {:?}", result.my_array);
let account = sqlx::query!("SELECT ARRAY[4,3,2,1] as my_array")
.fetch_one(&mut conn)
.await?;
assert_eq!(account.my_array, vec![4, 3, 2, 1]);
println!("account ID: {:?}", account.my_array);
Ok(())
}
async fn connect() -> anyhow::Result<PgConnection> {
let _ = dotenv::dotenv();
let _ = env_logger::try_init();

View file

@ -282,6 +282,12 @@ async fn test_unprepared_anonymous_record() -> anyhow::Result<()> {
Ok(())
}
test!(postgres_int_vec: Vec<i32>: "ARRAY[1, 2, 3]::int[]" == vec![1, 2, 3i32], "ARRAY[3, 292, 15, 2, 3]::int[]" == vec![3, 292, 15, 2, 3], "ARRAY[7, 6, 5, 4, 3, 2, 1]::int[]" == vec![7, 6, 5, 4, 3, 2, 1], "ARRAY[]::int[]" == vec![] as Vec<i32>);
test!(postgres_string_vec: Vec<String>: "ARRAY['Hello', 'world', 'friend']::text[]" == vec!["Hello", "world", "friend"]);
test!(postgres_bool_vec: Vec<bool>: "ARRAY[true, true, false, true]::bool[]" == vec![true, true, false, true]);
test!(postgres_real_vec: Vec<f32>: "ARRAY[0.0, 1.0, 3.14, 1.234, -0.002, 100000.0]::real[]" == vec![0.0, 1.0, 3.14, 1.234, -0.002, 100000.0_f32]);
test!(postgres_double_vec: Vec<f64>: "ARRAY[0.0, 1.0, 3.14, 1.234, -0.002, 100000.0]::double precision[]" == vec![0.0, 1.0, 3.14, 1.234, -0.002, 100000.0_f64]);
#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn test_prepared_structs() -> anyhow::Result<()> {