feat: add generic Array adapter

Signed-off-by: Austin Bonander <austin@launchbadge.com>
This commit is contained in:
Austin Bonander 2021-04-30 18:52:14 -07:00
parent 405474b575
commit 9fc011d827
No known key found for this signature in database
GPG key ID: 4E7DA63E66AFC37E
9 changed files with 364 additions and 134 deletions

View file

@ -32,6 +32,7 @@ postgres = [
"futures-util/sink",
"json",
"dirs",
"array"
]
mysql = [
"sha-1",
@ -58,6 +59,7 @@ all-types = [
"uuid",
"bit-vec",
]
array = []
bigdecimal = ["bigdecimal_", "num-bigint"]
decimal = ["rust_decimal", "num-bigint"]
json = ["serde", "serde_json"]

View file

@ -4,8 +4,10 @@ use crate::arguments::Arguments;
use crate::encode::{Encode, IsNull};
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::postgres::type_info::PgType;
use crate::postgres::{PgConnection, PgTypeInfo, Postgres};
use crate::types::Type;
use std::convert::TryInto;
// TODO: buf.patch(|| ...) is a poor name, can we think of a better name? Maybe `buf.lazy(||)` ?
// TODO: Extend the patch system to support dynamic lengths
@ -141,6 +143,42 @@ impl PgArgumentBuffer {
self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());
}
pub(crate) fn encode_iter<'q, T, I>(&mut self, iter: I)
where
T: Encode<'q, Postgres> + Type<Postgres>,
I: IntoIterator<Item = T>,
{
self.extend(&1_i32.to_be_bytes()); // number of dimensions
self.extend(&0_i32.to_be_bytes()); // flags
// element type
match T::type_info().0 {
PgType::DeclareWithName(name) => self.patch_type_by_name(&name),
ty => {
self.extend(&ty.oid().to_be_bytes());
}
}
let len_at = self.len();
self.extend(&[0u8; 4]); // len (initially zero but we'll fix this up)
self.extend(&1_i32.to_be_bytes()); // lower bound
// count while encoding items at the same time
let len: i32 = iter
.into_iter()
.map(|item| item.encode(self))
.count()
.try_into()
// in practice, Postgres will reject arrays significantly smaller than this:
// https://github.com/postgres/postgres/blob/e6f9539dc32473793c03cbe95bc099ee0a199c73/src/backend/utils/adt/arrayutils.c#L66
.expect("array length exceeds maximum the Postgres protocol can handle");
// fixup the actual length
self[len_at..len_at + 4].copy_from_slice(&len.to_be_bytes());
}
// Adds a callback to be invoked later when we know the parameter type
#[allow(dead_code)]
pub(crate) fn patch<F>(&mut self, callback: F)

View file

@ -1,11 +1,13 @@
use bytes::Buf;
use crate::database::{HasArguments, HasValueRef};
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::postgres::type_info::PgType;
use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
use crate::types::Type;
use std::iter::FromIterator;
impl<T> Type<Postgres> for [Option<T>]
where
@ -33,14 +35,28 @@ where
}
}
impl<I> Type<Postgres> for crate::types::Array<I>
where
I: IntoIterator,
[I::Item]: Type<Postgres>,
{
fn type_info() -> PgTypeInfo {
<[I::Item] as Type<Postgres>>::type_info()
}
fn compatible(ty: &PgTypeInfo) -> bool {
<[I::Item] as Type<Postgres>>::compatible(ty)
}
}
impl<'q, T> Encode<'q, Postgres> for Vec<T>
where
for<'a> &'a [T]: Encode<'q, Postgres>,
T: Encode<'q, Postgres>,
T: Encode<'q, Postgres> + Type<Postgres>,
{
#[inline]
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
self.as_slice().encode_by_ref(buf)
buf.encode_iter(self.as_slice());
IsNull::No
}
}
@ -49,25 +65,18 @@ where
T: Encode<'q, Postgres> + Type<Postgres>,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
buf.extend(&1_i32.to_be_bytes()); // number of dimensions
buf.extend(&0_i32.to_be_bytes()); // flags
// element type
match T::type_info().0 {
PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
ty => {
buf.extend(&ty.oid().to_be_bytes());
}
}
buf.extend(&(self.len() as i32).to_be_bytes()); // len
buf.extend(&1_i32.to_be_bytes()); // lower bound
for element in self.iter() {
buf.encode(element);
}
buf.encode_iter(*self);
IsNull::No
}
}
impl<'q, T, I> Encode<'q, Postgres> for crate::types::Array<I>
where
for<'a> &'a I: IntoIterator<Item = T>,
T: Encode<'q, Postgres> + Type<Postgres>,
{
fn encode_by_ref(&self, buf: &mut <Postgres as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
buf.encode_iter(&self.0);
IsNull::No
}
}
@ -77,141 +86,149 @@ where
T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
{
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
let element_type_info;
let format = value.format();
// `impl<T> FromIterator<T> for Vec<T>` is specialized for `vec::IntoIter<T>`:
// https://github.com/rust-lang/rust/blob/8a9fa3682dcf0de095ec308a29a7b19b0e011ef0/library/alloc/src/vec/spec_from_iter.rs
decode_array(value)
}
}
match format {
PgValueFormat::Binary => {
// https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L1548
impl<'r, I> Decode<'r, Postgres> for crate::types::Array<I>
where
I: IntoIterator + FromIterator<<I as IntoIterator>::Item>,
I::Item: for<'a> Decode<'a, Postgres> + Type<Postgres>,
{
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
decode_array(value).map(Self)
}
}
let mut buf = value.as_bytes()?;
fn decode_array<T, I>(value: PgValueRef<'_>) -> Result<I, BoxDynError>
where
I: FromIterator<T>,
T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
{
let element_type_info;
let format = value.format();
// number of dimensions in the array
let ndim = buf.get_i32();
match format {
PgValueFormat::Binary => {
// https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L1548
if ndim == 0 {
// zero dimensions is an empty array
return Ok(Vec::new());
}
let mut buf = value.as_bytes()?;
if ndim != 1 {
return Err(format!("encountered an array of {} dimensions; only one-dimensional arrays are supported", ndim).into());
}
// number of dimensions in the array
let ndim = buf.get_i32();
// appears to have been used in the past to communicate potential NULLS
// but reading source code back through our supported postgres versions (9.5+)
// this is never used for anything
let _flags = buf.get_i32();
// the OID of the element
let element_type_oid = buf.get_u32();
element_type_info = PgTypeInfo::try_from_oid(element_type_oid)
.unwrap_or_else(|| PgTypeInfo(PgType::DeclareWithOid(element_type_oid)));
// length of the array axis
let len = buf.get_i32();
// the lower bound, we only support arrays starting from "1"
let lower = buf.get_i32();
if lower != 1 {
return Err(format!("encountered an array with a lower bound of {} in the first dimension; only arrays starting at one are supported", lower).into());
}
let mut elements = Vec::with_capacity(len as usize);
for _ in 0..len {
elements.push(T::decode(PgValueRef::get(
&mut buf,
format,
element_type_info.clone(),
))?)
}
Ok(elements)
if ndim == 0 {
// zero dimensions is an empty array
return Ok(I::from_iter(std::iter::empty()));
}
PgValueFormat::Text => {
// no type is provided from the database for the element
element_type_info = T::type_info();
if ndim != 1 {
return Err(format!("encountered an array of {} dimensions; only one-dimensional arrays are supported", ndim).into());
}
let s = value.as_str()?;
// appears to have been used in the past to communicate potential NULLS
// but reading source code back through our supported postgres versions (9.5+)
// this is never used for anything
let _flags = buf.get_i32();
// https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L718
// the OID of the element
let element_type_oid = buf.get_u32();
element_type_info = PgTypeInfo::try_from_oid(element_type_oid)
.unwrap_or_else(|| PgTypeInfo(PgType::DeclareWithOid(element_type_oid)));
// trim the wrapping braces
let s = &s[1..(s.len() - 1)];
// length of the array axis
let len = buf.get_i32();
if s.is_empty() {
// short-circuit empty arrays up here
return Ok(Vec::new());
// the lower bound, we only support arrays starting from "1"
let lower = buf.get_i32();
if lower != 1 {
return Err(format!("encountered an array with a lower bound of {} in the first dimension; only arrays starting at one are supported", lower).into());
}
(0..len)
.map(|_| T::decode(PgValueRef::get(&mut buf, format, element_type_info.clone())))
.collect()
}
PgValueFormat::Text => {
// no type is provided from the database for the element
element_type_info = T::type_info();
let s = value.as_str()?;
// https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L718
// trim the wrapping braces
let s = &s[1..(s.len() - 1)];
if s.is_empty() {
// short-circuit empty arrays up here
return Ok(I::from_iter(std::iter::empty()));
}
// NOTE: Nearly *all* types use ',' as the sequence delimiter. Yes, there is one
// that does not. The BOX (not PostGIS) type uses ';' as a delimiter.
// TODO: When we add support for BOX we need to figure out some way to make the
// delimiter selection
let delimiter = ',';
let mut in_quotes = false;
let mut in_escape = false;
let mut value = String::with_capacity(10);
let mut chars = s.chars();
std::iter::from_fn(|| {
if chars.as_str().is_empty() {
return None;
}
// NOTE: Nearly *all* types use ',' as the sequence delimiter. Yes, there is one
// that does not. The BOX (not PostGIS) type uses ';' as a delimiter.
for ch in &mut chars {
match ch {
_ if in_escape => {
value.push(ch);
in_escape = false;
}
// TODO: When we add support for BOX we need to figure out some way to make the
// delimiter selection
'"' => {
in_quotes = !in_quotes;
}
let delimiter = ',';
let mut done = false;
let mut in_quotes = false;
let mut in_escape = false;
let mut value = String::with_capacity(10);
let mut chars = s.chars();
let mut elements = Vec::with_capacity(4);
'\\' => {
in_escape = true;
}
while !done {
loop {
match chars.next() {
Some(ch) => match ch {
_ if in_escape => {
value.push(ch);
in_escape = false;
}
_ if ch == delimiter && !in_quotes => {
break;
}
'"' => {
in_quotes = !in_quotes;
}
'\\' => {
in_escape = true;
}
_ if ch == delimiter && !in_quotes => {
break;
}
_ => {
value.push(ch);
}
},
None => {
done = true;
break;
}
_ => {
value.push(ch);
}
}
let value_opt = if value == "NULL" {
None
} else {
Some(value.as_bytes())
};
elements.push(T::decode(PgValueRef {
value: value_opt,
row: None,
type_info: element_type_info.clone(),
format,
})?);
value.clear();
}
Ok(elements)
}
let value_opt = if value == "NULL" {
None
} else {
Some(value.as_bytes())
};
let ret = T::decode(PgValueRef {
value: value_opt,
row: None,
type_info: element_type_info.clone(),
format,
});
value.clear();
Some(ret)
})
.collect()
}
}
}

View file

@ -83,6 +83,19 @@ impl<'q, DB: Database> Query<'q, DB, <DB as HasArguments<'q>>::Arguments> {
self
}
/// Bind any iterable as an array.
/// Only supported on databases with first-class arrays, like Postgres.
///
/// See also: [Array][crate::types::Array]
#[cfg(feature = "array")]
#[cfg_attr(docsrs, doc(cfg(feature = "array")))]
pub fn bind_array<I>(mut self, iter: I) -> Self
where
crate::types::Array<I>: Encode<'q, DB> + Type<DB> + Send + 'q,
{
self.bind(crate::types::Array(iter))
}
}
impl<'q, DB, A> Query<'q, DB, A>

View file

@ -55,6 +55,19 @@ impl<'q, DB: Database, O> QueryAs<'q, DB, O, <DB as HasArguments<'q>>::Arguments
self.inner = self.inner.bind(value);
self
}
/// Bind any iterable as an array.
/// Only supported on databases with first-class arrays, like Postgres.
///
/// See also: [Array][crate::types::Array]
#[cfg(feature = "array")]
#[cfg_attr(docsrs, doc(cfg(feature = "array")))]
pub fn bind_array<I>(mut self, iter: I) -> Self
where
crate::types::Array<I>: Encode<'q, DB> + Type<DB> + Send + 'q,
{
self.bind(crate::types::Array(iter))
}
}
// FIXME: This is very close, nearly 1:1 with `Map`

View file

@ -52,6 +52,19 @@ impl<'q, DB: Database, O> QueryScalar<'q, DB, O, <DB as HasArguments<'q>>::Argum
self.inner = self.inner.bind(value);
self
}
/// Bind any iterable as an array.
/// Only supported on databases with first-class arrays, like Postgres.
///
/// See also: [Array][crate::types::Array]
#[cfg(feature = "array")]
#[cfg_attr(docsrs, doc(cfg(feature = "array")))]
pub fn bind_array<I>(mut self, iter: I) -> Self
where
crate::types::Array<I>: Encode<'q, DB> + Type<DB> + Send + 'q,
{
self.bind(crate::types::Array(iter))
}
}
// FIXME: This is very close, nearly 1:1 with `Map`

View file

@ -0,0 +1,110 @@
use std::ops::{Deref, DerefMut};
/// A generic adapter for encoding and decoding any type that implements
/// [`IntoIterator`][std::iter::IntoIterator]/[`FromIterator`][std::iter::FromIterator]
/// to or from an array in SQL, respectively.
///
/// Only supported on databases that have native support for arrays, such as PostgreSQL.
///
/// ## Examples
///
/// #### (Postgres) Bulk Insert with Array of Structs -> Struct of Arrays
///
/// You can implement bulk insert of structs by turning an array of structs into
/// an array for each field in the struct and then using Postgres' `UNNEST()`
///
/// ```rust,ignore
/// use sqlx::types::Array;
///
/// struct Foo {
/// bar: String,
/// baz: i32,
/// quux: bool
/// }
///
/// let foos = vec![
/// Foo {
/// bar: "bar".to_string(),
/// baz: 0,
/// quux: bool
/// }
/// ];
///
/// sqlx::query!(
/// "
/// INSERT INTO foo(bar, baz, quux)
/// SELECT * FROM UNNEST($1, $2, $3)
/// ",
/// // type overrides are necessary for the macros to accept this instead of `[String]`, etc.
/// Array(foos.iter().map(|foo| &foo.bar)) as _,
/// Array(foos.iter().map(|foo| foo.baz)) as _,
/// Array(foos.iter().map(|foo| foo.quux)) as _
/// )
/// .execute(&pool)
/// .await?;
/// ```
///
/// #### (Postgres) Deserialize a Different Type than `Vec<T>`
///
/// ```sql,ignore
/// CREATE TABLE media(
/// id BIGSERIAL PRIMARY KEY,
/// filename TEXT NOT NULL,
/// tags TEXT[] NOT NULL
/// )
/// ```
///
/// ```rust,ignore
/// use sqlx::types::Array;
///
/// use std::collections::HashSet;
///
/// struct Media {
/// id: i32,
/// filename: String,
/// tags: Array<HashSet<T>>,
/// }
///
/// let media: Vec<Media> = sqlx::query_as!(
/// r#"
/// SELECT id, filename, tags AS "tags: Array<HashSet<_>>"
/// "#
/// )
/// .fetch_all(&pool)
/// .await?;
/// ```
#[derive(Debug)]
pub struct Array<I>(pub I);
impl<I> Array<I> {
pub fn into_inner(self) -> I {
self.0
}
}
impl<I> Deref for Array<I> {
type Target = I;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<I> DerefMut for Array<I> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl<I> From<I> for Array<I> {
fn from(iterable: I) -> Self {
Self(iterable)
}
}
// orphan trait impl error
// impl<I> From<Array<I>> for I {
// fn from(array: Array<I>) -> Self {
// array.0
// }
// }

View file

@ -20,6 +20,10 @@
use crate::database::Database;
#[cfg(feature = "array")]
#[cfg_attr(docsrs, doc(cfg(feature = "array")))]
mod array;
#[cfg(feature = "bstr")]
#[cfg_attr(docsrs, doc(cfg(feature = "bstr")))]
pub mod bstr;
@ -75,6 +79,9 @@ pub mod ipnetwork {
pub use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network};
}
#[cfg(feature = "array")]
pub use array::Array;
#[cfg(feature = "json")]
pub use json::Json;

View file

@ -476,3 +476,20 @@ test_prepared_type!(money<PgMoney>(Postgres, "123.45::money" == PgMoney(12345)))
test_prepared_type!(money_vec<Vec<PgMoney>>(Postgres,
"array[123.45,420.00,666.66]::money[]" == vec![PgMoney(12345), PgMoney(42000), PgMoney(66666)],
));
mod array {
use sqlx::types::Array;
use std::collections::HashSet;
macro_rules! set [
($($item:expr),*) => {{
let mut set = HashSet::new();
$(set.insert($item);)*
set
}}
];
test_type!(array_to_hashset<String>(Postgres,
"array['foo', 'bar', 'baz']" == Array(set!["foo", "bar", "baz"])));
}