mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 06:24:16 +00:00
feat(postgres): add support for built-in range types and allow derives to handle custom range types
Co-authored-by: Caio <c410.f3r@gmail.com>
This commit is contained in:
parent
fedd883d91
commit
c9f3e1adca
24 changed files with 922 additions and 861 deletions
|
@ -33,6 +33,8 @@ pub trait Encode<'q, DB: Database> {
|
|||
fn encode_by_ref(&self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull;
|
||||
|
||||
fn produces(&self) -> Option<DB::TypeInfo> {
|
||||
// `produces` is inherently a hook to allow database drivers to produce value-dependent
|
||||
// type information; if the driver doesn't need this, it can leave this as `None`
|
||||
None
|
||||
}
|
||||
|
||||
|
|
|
@ -11,7 +11,8 @@ use crate::database::Database;
|
|||
pub type Result<T> = StdResult<T, Error>;
|
||||
|
||||
// Convenience type alias for usage within SQLx.
|
||||
pub type BoxDynError = Box<dyn StdError + 'static + Send + Sync>;
|
||||
// Do not make this type public.
|
||||
pub(crate) type BoxDynError = Box<dyn StdError + 'static + Send + Sync>;
|
||||
|
||||
/// An unexpected `NULL` was encountered during decoding.
|
||||
///
|
||||
|
|
|
@ -21,6 +21,15 @@ impl MySqlTypeInfo {
|
|||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub const fn __enum() -> Self {
|
||||
Self {
|
||||
r#type: ColumnType::Enum,
|
||||
flags: ColumnFlags::BINARY,
|
||||
char_set: 63,
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn __type_feature_gate(&self) -> Option<&'static str> {
|
||||
match self.r#type {
|
||||
|
|
|
@ -47,26 +47,34 @@ impl<'q> Arguments<'q> for PgArguments {
|
|||
self.types
|
||||
.push(value.produces().unwrap_or_else(T::type_info));
|
||||
|
||||
// reserve space to write the prefixed length of the value
|
||||
let offset = self.buffer.len();
|
||||
self.buffer.extend(&[0; 4]);
|
||||
|
||||
// encode the value into our buffer
|
||||
let len = if let IsNull::No = value.encode(&mut self.buffer) {
|
||||
(self.buffer.len() - offset - 4) as i32
|
||||
} else {
|
||||
// Write a -1 to indicate NULL
|
||||
// NOTE: It is illegal for [encode] to write any data
|
||||
debug_assert_eq!(self.buffer.len(), offset + 4);
|
||||
-1_i32
|
||||
};
|
||||
|
||||
// write the len to the beginning of the value
|
||||
self.buffer.buffer[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());
|
||||
self.buffer.encode(value);
|
||||
}
|
||||
}
|
||||
|
||||
impl PgArgumentBuffer {
|
||||
pub(crate) fn encode<'q, T>(&mut self, value: T)
|
||||
where
|
||||
T: Encode<'q, Postgres>,
|
||||
{
|
||||
// reserve space to write the prefixed length of the value
|
||||
let offset = self.len();
|
||||
self.extend(&[0; 4]);
|
||||
|
||||
// encode the value into our buffer
|
||||
let len = if let IsNull::No = value.encode(self) {
|
||||
(self.len() - offset - 4) as i32
|
||||
} else {
|
||||
// Write a -1 to indicate NULL
|
||||
// NOTE: It is illegal for [encode] to write any data
|
||||
debug_assert_eq!(self.len(), offset + 4);
|
||||
-1_i32
|
||||
};
|
||||
|
||||
// write the len to the beginning of the value
|
||||
self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());
|
||||
}
|
||||
|
||||
// Extends the inner buffer by enough space to have an OID
|
||||
// Remembers where the OID goes and type name for the OID
|
||||
pub(crate) fn push_type_hole(&mut self, type_name: &UStr) {
|
||||
|
@ -81,7 +89,7 @@ impl PgArgumentBuffer {
|
|||
pub(crate) async fn patch_type_holes(&mut self, conn: &mut PgConnection) -> Result<(), Error> {
|
||||
for (offset, name) in &self.type_holes {
|
||||
let oid = conn.fetch_type_id_by_name(&*name).await?;
|
||||
self.buffer[*offset..].copy_from_slice(&oid.to_be_bytes());
|
||||
self.buffer[*offset..(*offset + 4)].copy_from_slice(&oid.to_be_bytes());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
|
@ -209,27 +209,31 @@ ORDER BY attnum
|
|||
})
|
||||
}
|
||||
|
||||
async fn fetch_range_by_oid(&mut self, oid: u32, name: String) -> Result<PgTypeInfo, Error> {
|
||||
let _: i32 = query_scalar(
|
||||
r#"
|
||||
SELECT 1
|
||||
fn fetch_range_by_oid(
|
||||
&mut self,
|
||||
oid: u32,
|
||||
name: String,
|
||||
) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
|
||||
Box::pin(async move {
|
||||
let element_oid: u32 = query_scalar(
|
||||
r#"
|
||||
SELECT rngsubtype
|
||||
FROM pg_catalog.pg_range
|
||||
WHERE rngtypid = $1
|
||||
"#,
|
||||
)
|
||||
.bind(oid)
|
||||
.fetch_one(self)
|
||||
.await?;
|
||||
"#,
|
||||
)
|
||||
.bind(oid)
|
||||
.fetch_one(&mut *self)
|
||||
.await?;
|
||||
|
||||
let pg_type = PgType::try_from_oid(oid).ok_or_else(|| {
|
||||
err_protocol!("Trying to retrieve a DB type that doesn't exist in SQLx")
|
||||
})?;
|
||||
let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?;
|
||||
|
||||
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
|
||||
kind: PgTypeKind::Range(PgTypeInfo(pg_type)),
|
||||
name: name.into(),
|
||||
oid,
|
||||
}))))
|
||||
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
|
||||
kind: PgTypeKind::Range(element),
|
||||
name: name.into(),
|
||||
oid,
|
||||
}))))
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result<u32, Error> {
|
||||
|
|
|
@ -28,15 +28,22 @@ async fn prepare(
|
|||
// additional queries here to get any missing OIDs
|
||||
|
||||
let mut param_types = Vec::with_capacity(arguments.types.len());
|
||||
let mut has_fetched = false;
|
||||
|
||||
for ty in &arguments.types {
|
||||
param_types.push(if let PgType::DeclareWithName(name) = &ty.0 {
|
||||
has_fetched = true;
|
||||
conn.fetch_type_id_by_name(name).await?
|
||||
} else {
|
||||
ty.0.oid()
|
||||
});
|
||||
}
|
||||
|
||||
// flush and wait until we are re-ready
|
||||
if has_fetched {
|
||||
conn.wait_until_ready().await?;
|
||||
}
|
||||
|
||||
// next we send the PARSE command to the server
|
||||
conn.stream.write(Parse {
|
||||
param_types: &*param_types,
|
||||
|
@ -111,6 +118,18 @@ impl PgConnection {
|
|||
// patch holes created during encoding
|
||||
arguments.buffer.patch_type_holes(self).await?;
|
||||
|
||||
// describe the statement and, again, ask the server to immediately respond
|
||||
// we need to fully realize the types
|
||||
self.stream.write(message::Describe::Statement(statement));
|
||||
self.stream.write(message::Flush);
|
||||
self.stream.flush().await?;
|
||||
|
||||
let _ = recv_desc_params(self).await?;
|
||||
let rows = recv_desc_rows(self).await?;
|
||||
|
||||
self.handle_row_description(rows, true).await?;
|
||||
self.wait_until_ready().await?;
|
||||
|
||||
// bind to attach the arguments to the statement and create a portal
|
||||
self.stream.write(Bind {
|
||||
portal: None,
|
||||
|
@ -121,17 +140,6 @@ impl PgConnection {
|
|||
result_formats: &[PgValueFormat::Binary],
|
||||
});
|
||||
|
||||
// describe the portal and, again, ask the server to immediately respond
|
||||
// we need to fully realize the types
|
||||
self.stream.write(message::Describe::UnnamedPortal);
|
||||
self.stream.write(Flush);
|
||||
self.stream.flush().await?;
|
||||
|
||||
let _ = self.stream.recv_expect(MessageFormat::BindComplete).await?;
|
||||
|
||||
let rows = recv_desc_rows(self).await?;
|
||||
self.handle_row_description(rows, true).await?;
|
||||
|
||||
// executes the portal up to the passed limit
|
||||
// the protocol-level limit acts nearly identically to the `LIMIT` in SQL
|
||||
self.stream.write(message::Execute {
|
||||
|
|
|
@ -11,7 +11,7 @@ use crate::type_info::TypeInfo;
|
|||
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct PgTypeInfo(pub(crate) PgType);
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
|
||||
#[repr(u32)]
|
||||
pub(crate) enum PgType {
|
||||
|
@ -197,7 +197,12 @@ impl PgTypeInfo {
|
|||
Self(PgType::DeclareWithName(UStr::Static(name)))
|
||||
}
|
||||
|
||||
pub(crate) const fn with_oid(oid: u32) -> Self {
|
||||
/// Create a `PgTypeInfo` from an OID.
|
||||
///
|
||||
/// Note that the OID for a type is very dependent on the environment. If you only ever use
|
||||
/// one database or if this is an unhandled build-in type, you should be fine. Otherwise,
|
||||
/// you will be better served using [`with_name`](#method.with_name).
|
||||
pub const fn with_oid(oid: u32) -> Self {
|
||||
Self(PgType::DeclareWithOid(oid))
|
||||
}
|
||||
}
|
||||
|
@ -308,7 +313,14 @@ impl PgType {
|
|||
}
|
||||
|
||||
pub(crate) fn oid(&self) -> u32 {
|
||||
match self {
|
||||
match self.try_oid() {
|
||||
Some(oid) => oid,
|
||||
None => unreachable!("(bug) use of unresolved type declaration [oid]"),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn try_oid(&self) -> Option<u32> {
|
||||
Some(match self {
|
||||
PgType::Bool => 16,
|
||||
PgType::Bytea => 17,
|
||||
PgType::Char => 18,
|
||||
|
@ -400,8 +412,10 @@ impl PgType {
|
|||
PgType::Custom(ty) => ty.oid,
|
||||
|
||||
PgType::DeclareWithOid(oid) => *oid,
|
||||
PgType::DeclareWithName(_) => unreachable!("(bug) use of unresolved type declaration"),
|
||||
}
|
||||
PgType::DeclareWithName(_) => {
|
||||
return None;
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn name(&self) -> &str {
|
||||
|
@ -576,24 +590,24 @@ impl PgType {
|
|||
PgType::UuidArray => &PgTypeKind::Array(PgTypeInfo(PgType::Uuid)),
|
||||
PgType::Jsonb => &PgTypeKind::Simple,
|
||||
PgType::JsonbArray => &PgTypeKind::Array(PgTypeInfo(PgType::Jsonb)),
|
||||
PgType::Int4Range => &PgTypeKind::Simple,
|
||||
PgType::Int4Range => &PgTypeKind::Range(PgTypeInfo::INT4),
|
||||
PgType::Int4RangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::Int4Range)),
|
||||
PgType::NumRange => &PgTypeKind::Simple,
|
||||
PgType::NumRange => &PgTypeKind::Range(PgTypeInfo::NUMERIC),
|
||||
PgType::NumRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::NumRange)),
|
||||
PgType::TsRange => &PgTypeKind::Simple,
|
||||
PgType::TsRange => &PgTypeKind::Range(PgTypeInfo::TIMESTAMP),
|
||||
PgType::TsRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::TsRange)),
|
||||
PgType::TstzRange => &PgTypeKind::Simple,
|
||||
PgType::TstzRange => &PgTypeKind::Range(PgTypeInfo::TIMESTAMPTZ),
|
||||
PgType::TstzRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::TstzRange)),
|
||||
PgType::DateRange => &PgTypeKind::Simple,
|
||||
PgType::DateRange => &PgTypeKind::Range(PgTypeInfo::DATE),
|
||||
PgType::DateRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::DateRange)),
|
||||
PgType::Int8Range => &PgTypeKind::Simple,
|
||||
PgType::Int8Range => &PgTypeKind::Range(PgTypeInfo::INT8),
|
||||
PgType::Int8RangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::Int8Range)),
|
||||
PgType::Jsonpath => &PgTypeKind::Simple,
|
||||
PgType::JsonpathArray => &PgTypeKind::Array(PgTypeInfo(PgType::Jsonpath)),
|
||||
PgType::Custom(ty) => &ty.kind,
|
||||
|
||||
PgType::DeclareWithOid(_) | PgType::DeclareWithName(_) => {
|
||||
unreachable!("(bug) use of unresolved type declaration")
|
||||
unreachable!("(bug) use of unresolved type declaration [kind]")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -817,3 +831,15 @@ impl Display for PgTypeInfo {
|
|||
f.pad(self.0.name())
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq<PgType> for PgType {
|
||||
fn eq(&self, other: &PgType) -> bool {
|
||||
if let (Some(a), Some(b)) = (self.try_oid(), other.try_oid()) {
|
||||
// If there are OIDs available, use OIDs to perform a direct match
|
||||
a == b
|
||||
} else {
|
||||
// Otherwise, perform a match on the name
|
||||
self.name().eq_ignore_ascii_case(other.name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,22 +63,7 @@ where
|
|||
buf.extend(&1_i32.to_be_bytes()); // lower bound
|
||||
|
||||
for element in self.iter() {
|
||||
// allocate space for the length of the encoded element
|
||||
let el_len_offset = buf.len();
|
||||
buf.extend(&0_i32.to_be_bytes());
|
||||
|
||||
let el_start = buf.len();
|
||||
|
||||
if let IsNull::Yes = element.encode_by_ref(buf) {
|
||||
// NULL is encoded as -1 for a length
|
||||
buf[el_len_offset..el_start].copy_from_slice(&(-1_i32).to_be_bytes());
|
||||
} else {
|
||||
let el_end = buf.len();
|
||||
let el_len = el_end - el_start;
|
||||
|
||||
// now we can go back and update the length
|
||||
buf[el_len_offset..el_start].copy_from_slice(&(el_len as i32).to_be_bytes());
|
||||
}
|
||||
buf.encode(element);
|
||||
}
|
||||
|
||||
IsNull::No
|
||||
|
@ -144,23 +129,11 @@ where
|
|||
let mut elements = Vec::with_capacity(len as usize);
|
||||
|
||||
for _ in 0..len {
|
||||
let mut element_len = buf.get_i32();
|
||||
|
||||
let element_val = if element_len == -1 {
|
||||
element_len = 0;
|
||||
None
|
||||
} else {
|
||||
Some(&buf[..(element_len as usize)])
|
||||
};
|
||||
|
||||
elements.push(T::decode(PgValueRef {
|
||||
value: element_val,
|
||||
row: None,
|
||||
type_info: element_type_info.clone(),
|
||||
elements.push(T::decode(PgValueRef::get(
|
||||
&mut buf,
|
||||
format,
|
||||
})?);
|
||||
|
||||
buf.advance(element_len as usize);
|
||||
element_type_info.clone(),
|
||||
))?)
|
||||
}
|
||||
|
||||
Ok(elements)
|
||||
|
|
|
@ -132,8 +132,8 @@ mod array;
|
|||
mod bool;
|
||||
mod bytes;
|
||||
mod float;
|
||||
mod num;
|
||||
mod ranges;
|
||||
mod int;
|
||||
mod range;
|
||||
mod record;
|
||||
mod str;
|
||||
mod tuple;
|
||||
|
@ -159,7 +159,9 @@ mod json;
|
|||
#[cfg(feature = "ipnetwork")]
|
||||
mod ipnetwork;
|
||||
|
||||
pub use {
|
||||
ranges::{pg_range::PgRange, pg_ranges::*},
|
||||
record::{PgRecordDecoder, PgRecordEncoder},
|
||||
};
|
||||
pub use range::PgRange;
|
||||
|
||||
// used in derive(Type) for `struct`
|
||||
// but the interface is not considered part of the public API
|
||||
#[doc(hidden)]
|
||||
pub use record::{PgRecordDecoder, PgRecordEncoder};
|
||||
|
|
530
sqlx-core/src/postgres/types/range.rs
Normal file
530
sqlx-core/src/postgres/types/range.rs
Normal file
|
@ -0,0 +1,530 @@
|
|||
use std::fmt::{self, Debug, Display, Formatter};
|
||||
use std::ops::{Bound, Range, RangeBounds, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive};
|
||||
|
||||
use bitflags::bitflags;
|
||||
use bytes::Buf;
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::error::BoxDynError;
|
||||
use crate::postgres::{
|
||||
PgArgumentBuffer, PgTypeInfo, PgTypeKind, PgValueFormat, PgValueRef, Postgres,
|
||||
};
|
||||
use crate::types::Type;
|
||||
|
||||
// https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/include/utils/rangetypes.h#L35-L44
|
||||
bitflags! {
|
||||
struct RangeFlags: u8 {
|
||||
const EMPTY = 0x01;
|
||||
const LB_INC = 0x02;
|
||||
const UB_INC = 0x04;
|
||||
const LB_INF = 0x08;
|
||||
const UB_INF = 0x10;
|
||||
const LB_NULL = 0x20; // not used
|
||||
const UB_NULL = 0x40; // not used
|
||||
const CONTAIN_EMPTY = 0x80; // internal
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub struct PgRange<T> {
|
||||
pub start: Bound<T>,
|
||||
pub end: Bound<T>,
|
||||
}
|
||||
|
||||
impl<T> From<[Bound<T>; 2]> for PgRange<T> {
|
||||
fn from(v: [Bound<T>; 2]) -> Self {
|
||||
let [start, end] = v;
|
||||
Self { start, end }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<(Bound<T>, Bound<T>)> for PgRange<T> {
|
||||
fn from(v: (Bound<T>, Bound<T>)) -> Self {
|
||||
Self {
|
||||
start: v.0,
|
||||
end: v.1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<Range<T>> for PgRange<T> {
|
||||
fn from(v: Range<T>) -> Self {
|
||||
Self {
|
||||
start: Bound::Included(v.start),
|
||||
end: Bound::Excluded(v.end),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<RangeFrom<T>> for PgRange<T> {
|
||||
fn from(v: RangeFrom<T>) -> Self {
|
||||
Self {
|
||||
start: Bound::Included(v.start),
|
||||
end: Bound::Unbounded,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<RangeInclusive<T>> for PgRange<T> {
|
||||
fn from(v: RangeInclusive<T>) -> Self {
|
||||
let (start, end) = v.into_inner();
|
||||
Self {
|
||||
start: Bound::Included(start),
|
||||
end: Bound::Included(end),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<RangeTo<T>> for PgRange<T> {
|
||||
fn from(v: RangeTo<T>) -> Self {
|
||||
Self {
|
||||
start: Bound::Unbounded,
|
||||
end: Bound::Excluded(v.end),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<RangeToInclusive<T>> for PgRange<T> {
|
||||
fn from(v: RangeToInclusive<T>) -> Self {
|
||||
Self {
|
||||
start: Bound::Unbounded,
|
||||
end: Bound::Included(v.end),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> RangeBounds<T> for PgRange<T> {
|
||||
fn start_bound(&self) -> Bound<&T> {
|
||||
match self.start {
|
||||
Bound::Included(ref start) => Bound::Included(start),
|
||||
Bound::Excluded(ref start) => Bound::Excluded(start),
|
||||
Bound::Unbounded => Bound::Unbounded,
|
||||
}
|
||||
}
|
||||
|
||||
fn end_bound(&self) -> Bound<&T> {
|
||||
match self.end {
|
||||
Bound::Included(ref end) => Bound::Included(end),
|
||||
Bound::Excluded(ref end) => Bound::Excluded(end),
|
||||
Bound::Unbounded => Bound::Unbounded,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<Postgres> for PgRange<i32> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::INT4_RANGE
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<Postgres> for PgRange<i64> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::INT8_RANGE
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "bigdecimal")]
|
||||
impl Type<Postgres> for PgRange<bigdecimal::BigDecimal> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::NUM_RANGE
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "chrono")]
|
||||
impl Type<Postgres> for PgRange<chrono::NaiveDate> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::DATE_RANGE
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "chrono")]
|
||||
impl Type<Postgres> for PgRange<chrono::NaiveDateTime> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::TS_RANGE
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "chrono")]
|
||||
impl<Tz: chrono::TimeZone> Type<Postgres> for PgRange<chrono::DateTime<Tz>> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::TSTZ_RANGE
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "time")]
|
||||
impl Type<Postgres> for PgRange<time::Date> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::DATE_RANGE
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "time")]
|
||||
impl Type<Postgres> for PgRange<time::PrimitiveDateTime> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::TS_RANGE
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "time")]
|
||||
impl Type<Postgres> for PgRange<time::OffsetDateTime> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::TSTZ_RANGE
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<Postgres> for [PgRange<i32>] {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::INT4_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<Postgres> for [PgRange<i64>] {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::INT8_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "bigdecimal")]
|
||||
impl Type<Postgres> for [PgRange<bigdecimal::BigDecimal>] {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::NUM_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "chrono")]
|
||||
impl Type<Postgres> for [PgRange<chrono::NaiveDate>] {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::DATE_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "chrono")]
|
||||
impl Type<Postgres> for [PgRange<chrono::NaiveDateTime>] {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::TS_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "chrono")]
|
||||
impl<Tz: chrono::TimeZone> Type<Postgres> for [PgRange<chrono::DateTime<Tz>>] {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::TSTZ_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "time")]
|
||||
impl Type<Postgres> for [PgRange<time::Date>] {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::DATE_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "time")]
|
||||
impl Type<Postgres> for [PgRange<time::PrimitiveDateTime>] {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::TS_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "time")]
|
||||
impl Type<Postgres> for [PgRange<time::OffsetDateTime>] {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::TSTZ_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<Postgres> for Vec<PgRange<i32>> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::INT4_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<Postgres> for Vec<PgRange<i64>> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::INT8_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "bigdecimal")]
|
||||
impl Type<Postgres> for Vec<PgRange<bigdecimal::BigDecimal>> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::NUM_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "chrono")]
|
||||
impl Type<Postgres> for Vec<PgRange<chrono::NaiveDate>> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::DATE_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "chrono")]
|
||||
impl Type<Postgres> for Vec<PgRange<chrono::NaiveDateTime>> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::TS_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "chrono")]
|
||||
impl<Tz: chrono::TimeZone> Type<Postgres> for Vec<PgRange<chrono::DateTime<Tz>>> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::TSTZ_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "time")]
|
||||
impl Type<Postgres> for Vec<PgRange<time::Date>> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::DATE_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "time")]
|
||||
impl Type<Postgres> for Vec<PgRange<time::PrimitiveDateTime>> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::TS_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "time")]
|
||||
impl Type<Postgres> for Vec<PgRange<time::OffsetDateTime>> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::TSTZ_RANGE_ARRAY
|
||||
}
|
||||
}
|
||||
|
||||
impl<'q, T> Encode<'q, Postgres> for PgRange<T>
|
||||
where
|
||||
T: Encode<'q, Postgres>,
|
||||
{
|
||||
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
|
||||
// https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/backend/utils/adt/rangetypes.c#L245
|
||||
|
||||
let mut flags = RangeFlags::empty();
|
||||
|
||||
flags |= match self.start {
|
||||
Bound::Included(_) => RangeFlags::LB_INC,
|
||||
Bound::Unbounded => RangeFlags::LB_INF,
|
||||
Bound::Excluded(_) => RangeFlags::empty(),
|
||||
};
|
||||
|
||||
flags |= match self.end {
|
||||
Bound::Included(_) => RangeFlags::UB_INC,
|
||||
Bound::Unbounded => RangeFlags::UB_INF,
|
||||
Bound::Excluded(_) => RangeFlags::empty(),
|
||||
};
|
||||
|
||||
buf.push(flags.bits());
|
||||
|
||||
if let Bound::Included(v) | Bound::Excluded(v) = &self.start {
|
||||
buf.encode(v);
|
||||
}
|
||||
|
||||
if let Bound::Included(v) | Bound::Excluded(v) = &self.end {
|
||||
buf.encode(v);
|
||||
}
|
||||
|
||||
// ranges are themselves never null
|
||||
IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r, T> Decode<'r, Postgres> for PgRange<T>
|
||||
where
|
||||
T: Type<Postgres> + for<'a> Decode<'a, Postgres>,
|
||||
{
|
||||
fn accepts(ty: &PgTypeInfo) -> bool {
|
||||
// we require the declared type to be a _range_ with an
|
||||
// element type that is acceptable
|
||||
if let PgTypeKind::Range(element) = &ty.0.kind() {
|
||||
return T::accepts(&element);
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
|
||||
match value.format {
|
||||
PgValueFormat::Binary => {
|
||||
let element_ty = if let PgTypeKind::Range(element) = &value.type_info.0.kind() {
|
||||
element
|
||||
} else {
|
||||
return Err(format!("unexpected non-range type {}", value.type_info).into());
|
||||
};
|
||||
|
||||
let mut buf = value.as_bytes()?;
|
||||
|
||||
let mut start = Bound::Unbounded;
|
||||
let mut end = Bound::Unbounded;
|
||||
|
||||
let flags = RangeFlags::from_bits_truncate(buf.get_u8());
|
||||
|
||||
if flags.contains(RangeFlags::EMPTY) {
|
||||
return Ok(PgRange { start, end });
|
||||
}
|
||||
|
||||
if !flags.contains(RangeFlags::LB_INF) {
|
||||
let value =
|
||||
T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone()))?;
|
||||
|
||||
start = if flags.contains(RangeFlags::LB_INC) {
|
||||
Bound::Included(value)
|
||||
} else {
|
||||
Bound::Excluded(value)
|
||||
};
|
||||
}
|
||||
|
||||
if !flags.contains(RangeFlags::UB_INF) {
|
||||
let value =
|
||||
T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone()))?;
|
||||
|
||||
end = if flags.contains(RangeFlags::UB_INC) {
|
||||
Bound::Included(value)
|
||||
} else {
|
||||
Bound::Excluded(value)
|
||||
};
|
||||
}
|
||||
|
||||
Ok(PgRange { start, end })
|
||||
}
|
||||
|
||||
PgValueFormat::Text => {
|
||||
// https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/backend/utils/adt/rangetypes.c#L2046
|
||||
|
||||
let mut start = None;
|
||||
let mut end = None;
|
||||
|
||||
let s = value.as_str()?;
|
||||
|
||||
// remember the bounds
|
||||
let sb = s.as_bytes();
|
||||
let lower = sb[0] as char;
|
||||
let upper = sb[sb.len() - 1] as char;
|
||||
|
||||
// trim the wrapping braces/brackets
|
||||
let s = &s[1..(s.len() - 1)];
|
||||
|
||||
let mut chars = s.chars();
|
||||
|
||||
let mut element = String::new();
|
||||
let mut done = false;
|
||||
let mut quoted = false;
|
||||
let mut in_quotes = false;
|
||||
let mut in_escape = false;
|
||||
let mut prev_ch = '\0';
|
||||
let mut count = 0;
|
||||
|
||||
while !done {
|
||||
element.clear();
|
||||
|
||||
loop {
|
||||
match chars.next() {
|
||||
Some(ch) => {
|
||||
match ch {
|
||||
_ if in_escape => {
|
||||
element.push(ch);
|
||||
in_escape = false;
|
||||
}
|
||||
|
||||
'"' if in_quotes => {
|
||||
in_quotes = false;
|
||||
}
|
||||
|
||||
'"' => {
|
||||
in_quotes = true;
|
||||
quoted = true;
|
||||
|
||||
if prev_ch == '"' {
|
||||
element.push('"')
|
||||
}
|
||||
}
|
||||
|
||||
'\\' if !in_escape => {
|
||||
in_escape = true;
|
||||
}
|
||||
|
||||
',' if !in_quotes => break,
|
||||
|
||||
_ => {
|
||||
element.push(ch);
|
||||
}
|
||||
}
|
||||
prev_ch = ch;
|
||||
}
|
||||
|
||||
None => {
|
||||
done = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
count += 1;
|
||||
if !(element.is_empty() && !quoted) {
|
||||
let value = Some(T::decode(PgValueRef {
|
||||
type_info: T::type_info(),
|
||||
format: PgValueFormat::Text,
|
||||
value: Some(element.as_bytes()),
|
||||
row: None,
|
||||
})?);
|
||||
|
||||
if count == 1 {
|
||||
start = value;
|
||||
} else if count == 2 {
|
||||
end = value;
|
||||
} else {
|
||||
return Err("more than 2 elements found in a range".into());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let start = parse_bound(lower, start)?;
|
||||
let end = parse_bound(upper, end)?;
|
||||
|
||||
Ok(PgRange { start, end })
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_bound<T>(ch: char, value: Option<T>) -> Result<Bound<T>, BoxDynError> {
|
||||
Ok(if let Some(value) = value {
|
||||
match ch {
|
||||
'(' | ')' => Bound::Excluded(value),
|
||||
'[' | ']' => Bound::Included(value),
|
||||
|
||||
_ => {
|
||||
return Err(format!(
|
||||
"expected `(`, ')', '[', or `]` but found `{}` for range literal",
|
||||
ch
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Bound::Unbounded
|
||||
})
|
||||
}
|
||||
|
||||
impl<T> Display for PgRange<T>
|
||||
where
|
||||
T: Display,
|
||||
{
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
match &self.start {
|
||||
Bound::Unbounded => f.write_str("(,")?,
|
||||
Bound::Excluded(v) => write!(f, "({},", v)?,
|
||||
Bound::Included(v) => write!(f, "[{},", v)?,
|
||||
}
|
||||
|
||||
match &self.end {
|
||||
Bound::Unbounded => f.write_str(")")?,
|
||||
Bound::Excluded(v) => write!(f, "{})", v)?,
|
||||
Bound::Included(v) => write!(f, "{}]", v)?,
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,87 +0,0 @@
|
|||
pub(crate) mod pg_range;
|
||||
pub(crate) mod pg_ranges;
|
||||
|
||||
use crate::{
|
||||
decode::Decode,
|
||||
encode::{Encode, IsNull},
|
||||
postgres::{types::PgRange, PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres},
|
||||
types::Type,
|
||||
};
|
||||
use core::{
|
||||
convert::TryInto,
|
||||
ops::{Range, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive},
|
||||
};
|
||||
|
||||
macro_rules! impl_range {
|
||||
($range:ident) => {
|
||||
impl<'a, T> Decode<'a, Postgres> for $range<T>
|
||||
where
|
||||
T: for<'b> Decode<'b, Postgres> + Type<Postgres> + 'a,
|
||||
{
|
||||
fn accepts(ty: &PgTypeInfo) -> bool {
|
||||
<PgRange<T> as Decode<'_, Postgres>>::accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: PgValueRef<'a>) -> Result<$range<T>, crate::error::BoxDynError> {
|
||||
let bounds: PgRange<T> = Decode::<Postgres>::decode(value)?;
|
||||
let rslt = bounds.try_into()?;
|
||||
Ok(rslt)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Encode<'a, Postgres> for $range<T>
|
||||
where
|
||||
T: Clone + for<'b> Encode<'b, Postgres> + 'a,
|
||||
{
|
||||
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
|
||||
<PgRange<T> as Encode<'_, Postgres>>::encode(self.clone().into(), buf)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_range!(Range);
|
||||
impl_range!(RangeFrom);
|
||||
impl_range!(RangeInclusive);
|
||||
impl_range!(RangeTo);
|
||||
impl_range!(RangeToInclusive);
|
||||
|
||||
#[test]
|
||||
fn test_decode_str_bounds() {
|
||||
use crate::postgres::type_info::PgType;
|
||||
|
||||
const EXC1: Bound<i32> = Bound::Excluded(1);
|
||||
const EXC2: Bound<i32> = Bound::Excluded(2);
|
||||
const INC1: Bound<i32> = Bound::Included(1);
|
||||
const INC2: Bound<i32> = Bound::Included(2);
|
||||
const UNB: Bound<i32> = Bound::Unbounded;
|
||||
|
||||
let check = |s: &str, range_cmp: [Bound<i32>; 2]| {
|
||||
let pg_value = PgValueRef {
|
||||
type_info: PgTypeInfo(PgType::Int4Range),
|
||||
format: PgValueFormat::Text,
|
||||
value: Some(s.as_bytes()),
|
||||
row: None,
|
||||
};
|
||||
let range: PgRange<i32> = Decode::<Postgres>::decode(pg_value).unwrap();
|
||||
assert_eq!(Into::<[Bound<i32>; 2]>::into(range), range_cmp);
|
||||
};
|
||||
|
||||
check("(,)", [UNB, UNB]);
|
||||
check("(,]", [UNB, UNB]);
|
||||
check("(,2)", [UNB, EXC2]);
|
||||
check("(,2]", [UNB, INC2]);
|
||||
check("(1,)", [EXC1, UNB]);
|
||||
check("(1,]", [EXC1, UNB]);
|
||||
check("(1,2)", [EXC1, EXC2]);
|
||||
check("(1,2]", [EXC1, INC2]);
|
||||
|
||||
check("[,)", [UNB, UNB]);
|
||||
check("[,]", [UNB, UNB]);
|
||||
check("[,2)", [UNB, EXC2]);
|
||||
check("[,2]", [UNB, INC2]);
|
||||
check("[1,)", [INC1, UNB]);
|
||||
check("[1,]", [INC1, UNB]);
|
||||
check("[1,2)", [INC1, EXC2]);
|
||||
check("[1,2]", [INC1, INC2]);
|
||||
}
|
|
@ -1,385 +0,0 @@
|
|||
use crate::{
|
||||
decode::Decode,
|
||||
encode::{Encode, IsNull},
|
||||
postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres},
|
||||
types::Type,
|
||||
};
|
||||
use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
|
||||
use core::{
|
||||
convert::TryFrom,
|
||||
ops::{Bound, Range, RangeBounds, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive},
|
||||
};
|
||||
|
||||
bitflags::bitflags! {
|
||||
struct RangeFlags: u8 {
|
||||
const EMPTY = 0x01;
|
||||
const LB_INC = 0x02;
|
||||
const UB_INC = 0x04;
|
||||
const LB_INF = 0x08;
|
||||
const UB_INF = 0x10;
|
||||
const LB_NULL = 0x20;
|
||||
const UB_NULL = 0x40;
|
||||
const CONTAIN_EMPTY = 0x80;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
|
||||
pub struct PgRange<T> {
|
||||
pub start: Bound<T>,
|
||||
pub end: Bound<T>,
|
||||
}
|
||||
|
||||
impl<T> PgRange<T> {
|
||||
pub fn new(start: Bound<T>, end: Bound<T>) -> Self {
|
||||
Self { start, end }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Decode<'a, Postgres> for PgRange<T>
|
||||
where
|
||||
T: for<'b> Decode<'b, Postgres> + Type<Postgres> + 'a,
|
||||
{
|
||||
fn accepts(ty: &PgTypeInfo) -> bool {
|
||||
[
|
||||
PgTypeInfo::INT4_RANGE,
|
||||
PgTypeInfo::NUM_RANGE,
|
||||
PgTypeInfo::TS_RANGE,
|
||||
PgTypeInfo::TSTZ_RANGE,
|
||||
PgTypeInfo::DATE_RANGE,
|
||||
PgTypeInfo::INT8_RANGE,
|
||||
]
|
||||
.contains(ty)
|
||||
}
|
||||
|
||||
fn decode(value: PgValueRef<'a>) -> Result<PgRange<T>, crate::error::BoxDynError> {
|
||||
match value.format() {
|
||||
PgValueFormat::Binary => {
|
||||
decode_binary(value.as_bytes()?, value.format, value.type_info)
|
||||
}
|
||||
PgValueFormat::Text => decode_str(value.as_str()?, value.format(), value.type_info),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Encode<'a, Postgres> for PgRange<T>
|
||||
where
|
||||
T: for<'b> Encode<'b, Postgres> + 'a,
|
||||
{
|
||||
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
|
||||
let mut flags = match self.start {
|
||||
Bound::Included(_) => RangeFlags::LB_INC,
|
||||
Bound::Excluded(_) => RangeFlags::empty(),
|
||||
Bound::Unbounded => RangeFlags::LB_INF,
|
||||
};
|
||||
|
||||
flags |= match self.end {
|
||||
Bound::Included(_) => RangeFlags::UB_INC,
|
||||
Bound::Excluded(_) => RangeFlags::empty(),
|
||||
Bound::Unbounded => RangeFlags::UB_INF,
|
||||
};
|
||||
|
||||
buf.write_u8(flags.bits()).unwrap();
|
||||
|
||||
let mut write = |bound: &Bound<T>| -> IsNull {
|
||||
match bound {
|
||||
Bound::Included(ref value) | Bound::Excluded(ref value) => {
|
||||
buf.write_u32::<NetworkEndian>(0).unwrap();
|
||||
let prev = buf.len();
|
||||
if let IsNull::Yes = Encode::<Postgres>::encode(value, buf) {
|
||||
return IsNull::Yes;
|
||||
}
|
||||
let len = buf.len() - prev;
|
||||
buf[prev - 4..prev].copy_from_slice(&(len as u32).to_be_bytes());
|
||||
}
|
||||
Bound::Unbounded => {}
|
||||
}
|
||||
IsNull::No
|
||||
};
|
||||
|
||||
if let IsNull::Yes = write(&self.start) {
|
||||
return IsNull::Yes;
|
||||
}
|
||||
write(&self.end)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<[Bound<T>; 2]> for PgRange<T> {
|
||||
fn from(from: [Bound<T>; 2]) -> Self {
|
||||
let [start, end] = from;
|
||||
Self { start, end }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<(Bound<T>, Bound<T>)> for PgRange<T> {
|
||||
fn from(from: (Bound<T>, Bound<T>)) -> Self {
|
||||
Self {
|
||||
start: from.0,
|
||||
end: from.1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<PgRange<T>> for [Bound<T>; 2] {
|
||||
fn from(from: PgRange<T>) -> Self {
|
||||
[from.start, from.end]
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<PgRange<T>> for (Bound<T>, Bound<T>) {
|
||||
fn from(from: PgRange<T>) -> Self {
|
||||
(from.start, from.end)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<Range<T>> for PgRange<T> {
|
||||
fn from(from: Range<T>) -> Self {
|
||||
Self {
|
||||
start: Bound::Included(from.start),
|
||||
end: Bound::Excluded(from.end),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<RangeFrom<T>> for PgRange<T> {
|
||||
fn from(from: RangeFrom<T>) -> Self {
|
||||
Self {
|
||||
start: Bound::Included(from.start),
|
||||
end: Bound::Unbounded,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<RangeInclusive<T>> for PgRange<T> {
|
||||
fn from(from: RangeInclusive<T>) -> Self {
|
||||
let (start, end) = from.into_inner();
|
||||
Self {
|
||||
start: Bound::Included(start),
|
||||
end: Bound::Excluded(end),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<RangeTo<T>> for PgRange<T> {
|
||||
fn from(from: RangeTo<T>) -> Self {
|
||||
Self {
|
||||
start: Bound::Unbounded,
|
||||
end: Bound::Excluded(from.end),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<RangeToInclusive<T>> for PgRange<T> {
|
||||
fn from(from: RangeToInclusive<T>) -> Self {
|
||||
Self {
|
||||
start: Bound::Unbounded,
|
||||
end: Bound::Included(from.end),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> RangeBounds<T> for PgRange<T> {
|
||||
fn start_bound(&self) -> Bound<&T> {
|
||||
match &self.start {
|
||||
Bound::Included(ref start) => Bound::Included(start),
|
||||
Bound::Excluded(ref start) => Bound::Excluded(start),
|
||||
Bound::Unbounded => Bound::Unbounded,
|
||||
}
|
||||
}
|
||||
|
||||
fn end_bound(&self) -> Bound<&T> {
|
||||
match &self.end {
|
||||
Bound::Included(ref end) => Bound::Included(end),
|
||||
Bound::Excluded(ref end) => Bound::Excluded(end),
|
||||
Bound::Unbounded => Bound::Unbounded,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> TryFrom<PgRange<T>> for Range<T> {
|
||||
type Error = crate::error::Error;
|
||||
|
||||
fn try_from(from: PgRange<T>) -> crate::error::Result<Self> {
|
||||
let err_msg = "Invalid data for core::ops::Range";
|
||||
let start = included(from.start, err_msg)?;
|
||||
let end = excluded(from.end, err_msg)?;
|
||||
Ok(start..end)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> TryFrom<PgRange<T>> for RangeFrom<T> {
|
||||
type Error = crate::error::Error;
|
||||
|
||||
fn try_from(from: PgRange<T>) -> crate::error::Result<Self> {
|
||||
let err_msg = "Invalid data for core::ops::RangeFrom";
|
||||
let start = included(from.start, err_msg)?;
|
||||
unbounded(from.end, err_msg)?;
|
||||
Ok(start..)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> TryFrom<PgRange<T>> for RangeInclusive<T> {
|
||||
type Error = crate::error::Error;
|
||||
|
||||
fn try_from(from: PgRange<T>) -> crate::error::Result<Self> {
|
||||
let err_msg = "Invalid data for core::ops::RangeInclusive";
|
||||
let start = included(from.start, err_msg)?;
|
||||
let end = included(from.end, err_msg)?;
|
||||
Ok(start..=end)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> TryFrom<PgRange<T>> for RangeTo<T> {
|
||||
type Error = crate::error::Error;
|
||||
|
||||
fn try_from(from: PgRange<T>) -> crate::error::Result<Self> {
|
||||
let err_msg = "Invalid data for core::ops::RangeTo";
|
||||
unbounded(from.start, err_msg)?;
|
||||
let end = excluded(from.end, err_msg)?;
|
||||
Ok(..end)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> TryFrom<PgRange<T>> for RangeToInclusive<T> {
|
||||
type Error = crate::error::Error;
|
||||
|
||||
fn try_from(from: PgRange<T>) -> crate::error::Result<Self> {
|
||||
let err_msg = "Invalid data for core::ops::RangeToInclusive";
|
||||
unbounded(from.start, err_msg)?;
|
||||
let end = included(from.end, err_msg)?;
|
||||
Ok(..=end)
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_binary<'r, T>(
|
||||
mut bytes: &[u8],
|
||||
format: PgValueFormat,
|
||||
type_info: PgTypeInfo,
|
||||
) -> Result<PgRange<T>, crate::error::BoxDynError>
|
||||
where
|
||||
T: for<'rec> Decode<'rec, Postgres> + 'r,
|
||||
{
|
||||
let flags: RangeFlags = RangeFlags::from_bits_truncate(bytes.read_u8()?);
|
||||
let mut start_value = Bound::Unbounded;
|
||||
let mut end_value = Bound::Unbounded;
|
||||
|
||||
if flags.contains(RangeFlags::EMPTY) {
|
||||
return Ok(PgRange {
|
||||
start: start_value,
|
||||
end: end_value,
|
||||
});
|
||||
}
|
||||
|
||||
if !flags.contains(RangeFlags::LB_INF) {
|
||||
let elem_size = bytes.read_i32::<NetworkEndian>()?;
|
||||
let (elem_bytes, new_bytes) = bytes.split_at(elem_size as usize);
|
||||
bytes = new_bytes;
|
||||
let value = T::decode(PgValueRef {
|
||||
type_info: type_info.clone(),
|
||||
format,
|
||||
value: Some(elem_bytes),
|
||||
row: None,
|
||||
})?;
|
||||
|
||||
start_value = if flags.contains(RangeFlags::LB_INC) {
|
||||
Bound::Included(value)
|
||||
} else {
|
||||
Bound::Excluded(value)
|
||||
};
|
||||
}
|
||||
|
||||
if !flags.contains(RangeFlags::UB_INF) {
|
||||
bytes.read_i32::<NetworkEndian>()?;
|
||||
let value = T::decode(PgValueRef {
|
||||
type_info,
|
||||
format,
|
||||
value: Some(bytes),
|
||||
row: None,
|
||||
})?;
|
||||
|
||||
end_value = if flags.contains(RangeFlags::UB_INC) {
|
||||
Bound::Included(value)
|
||||
} else {
|
||||
Bound::Excluded(value)
|
||||
};
|
||||
}
|
||||
|
||||
Ok(PgRange {
|
||||
start: start_value,
|
||||
end: end_value,
|
||||
})
|
||||
}
|
||||
|
||||
fn decode_str<'r, T>(
|
||||
s: &str,
|
||||
format: PgValueFormat,
|
||||
type_info: PgTypeInfo,
|
||||
) -> Result<PgRange<T>, crate::error::BoxDynError>
|
||||
where
|
||||
T: for<'rec> Decode<'rec, Postgres> + 'r,
|
||||
{
|
||||
let err = || crate::error::Error::Decode("Invalid PostgreSQL range string".into());
|
||||
|
||||
let value =
|
||||
|bound: &str, delim, bounds: [&str; 2]| -> Result<Bound<T>, crate::error::BoxDynError> {
|
||||
if bound.len() == 0 {
|
||||
return Ok(Bound::Unbounded);
|
||||
}
|
||||
let bound_value = T::decode(PgValueRef {
|
||||
type_info: type_info.clone(),
|
||||
format,
|
||||
value: Some(bound.as_bytes()),
|
||||
row: None,
|
||||
})?;
|
||||
if delim == bounds[0] {
|
||||
Ok(Bound::Excluded(bound_value))
|
||||
} else if delim == bounds[1] {
|
||||
Ok(Bound::Included(bound_value))
|
||||
} else {
|
||||
Err(Box::new(err()))
|
||||
}
|
||||
};
|
||||
|
||||
let mut parts = s.split(',');
|
||||
let start_str = parts.next().ok_or_else(err)?;
|
||||
let start_value = value(
|
||||
start_str.get(1..).ok_or_else(err)?,
|
||||
start_str.get(0..1).ok_or_else(err)?,
|
||||
["(", "["],
|
||||
)?;
|
||||
let end_str = parts.next().ok_or_else(err)?;
|
||||
let last_char_idx = end_str.len() - 1;
|
||||
let end_value = value(
|
||||
end_str.get(..last_char_idx).ok_or_else(err)?,
|
||||
end_str.get(last_char_idx..).ok_or_else(err)?,
|
||||
[")", "]"],
|
||||
)?;
|
||||
|
||||
Ok(PgRange {
|
||||
start: start_value,
|
||||
end: end_value,
|
||||
})
|
||||
}
|
||||
|
||||
fn excluded<T>(b: Bound<T>, err_msg: &str) -> crate::error::Result<T> {
|
||||
if let Bound::Excluded(rslt) = b {
|
||||
Ok(rslt)
|
||||
} else {
|
||||
Err(crate::error::Error::Decode(err_msg.into()))
|
||||
}
|
||||
}
|
||||
|
||||
fn included<T>(b: Bound<T>, err_msg: &str) -> crate::error::Result<T> {
|
||||
if let Bound::Included(rslt) = b {
|
||||
Ok(rslt)
|
||||
} else {
|
||||
Err(crate::error::Error::Decode(err_msg.into()))
|
||||
}
|
||||
}
|
||||
|
||||
fn unbounded<T>(b: Bound<T>, err_msg: &str) -> crate::error::Result<()> {
|
||||
if matches!(b, Bound::Unbounded) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(crate::error::Error::Decode(err_msg.into()))
|
||||
}
|
||||
}
|
|
@ -1,84 +0,0 @@
|
|||
use crate::{
|
||||
decode::Decode,
|
||||
encode::{Encode, IsNull},
|
||||
postgres::{
|
||||
types::ranges::pg_range::PgRange, PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres,
|
||||
},
|
||||
types::Type,
|
||||
};
|
||||
|
||||
macro_rules! impl_pg_range {
|
||||
($range_name:ident, $type_info:expr, $type_info_array:expr, $range_type:ty) => {
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
|
||||
#[repr(transparent)]
|
||||
pub struct $range_name(pub PgRange<$range_type>);
|
||||
|
||||
impl<'a> Decode<'a, Postgres> for $range_name {
|
||||
fn accepts(ty: &PgTypeInfo) -> bool {
|
||||
<PgRange<$range_type> as Decode<'_, Postgres>>::accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: PgValueRef<'a>) -> Result<$range_name, crate::error::BoxDynError> {
|
||||
Ok(Self(Decode::<Postgres>::decode(value)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Encode<'a, Postgres> for $range_name {
|
||||
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
|
||||
<PgRange<$range_type> as Encode<'_, Postgres>>::encode_by_ref(&self.0, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<Postgres> for $range_name {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
$type_info
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<Postgres> for [$range_name] {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
$type_info_array
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<Postgres> for Vec<$range_name> {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
$type_info_array
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_pg_range!(
|
||||
Int4Range,
|
||||
PgTypeInfo::INT4_RANGE,
|
||||
PgTypeInfo::INT4_RANGE_ARRAY,
|
||||
i32
|
||||
);
|
||||
#[cfg(feature = "bigdecimal")]
|
||||
impl_pg_range!(
|
||||
NumRange,
|
||||
PgTypeInfo::NUM_RANGE,
|
||||
PgTypeInfo::NUM_RANGE_ARRAY,
|
||||
bigdecimal::BigDecimal
|
||||
);
|
||||
#[cfg(feature = "chrono")]
|
||||
impl_pg_range!(
|
||||
TsRange,
|
||||
PgTypeInfo::TS_RANGE,
|
||||
PgTypeInfo::TS_RANGE_ARRAY,
|
||||
chrono::NaiveDateTime
|
||||
);
|
||||
#[cfg(feature = "chrono")]
|
||||
impl_pg_range!(
|
||||
DateRange,
|
||||
PgTypeInfo::DATE_RANGE,
|
||||
PgTypeInfo::DATE_RANGE_ARRAY,
|
||||
chrono::NaiveDate
|
||||
);
|
||||
impl_pg_range!(
|
||||
Int8Range,
|
||||
PgTypeInfo::INT8_RANGE,
|
||||
PgTypeInfo::INT8_RANGE_ARRAY,
|
||||
i64
|
||||
);
|
|
@ -1,7 +1,7 @@
|
|||
use bytes::Buf;
|
||||
|
||||
use crate::decode::Decode;
|
||||
use crate::encode::{Encode, IsNull};
|
||||
use crate::encode::Encode;
|
||||
use crate::error::{mismatched_types, BoxDynError};
|
||||
use crate::postgres::type_info::PgType;
|
||||
use crate::postgres::{
|
||||
|
@ -30,7 +30,7 @@ impl<'a> PgRecordEncoder<'a> {
|
|||
#[doc(hidden)]
|
||||
pub fn finish(&mut self) {
|
||||
// fill in the record length
|
||||
self.buf[self.off..].copy_from_slice(&self.num.to_be_bytes());
|
||||
self.buf[self.off..(self.off + 4)].copy_from_slice(&self.num.to_be_bytes());
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
|
@ -50,16 +50,8 @@ impl<'a> PgRecordEncoder<'a> {
|
|||
self.buf.extend(&ty.0.oid().to_be_bytes());
|
||||
}
|
||||
|
||||
let offset = self.buf.len();
|
||||
self.buf.extend(&(0_u32).to_be_bytes());
|
||||
|
||||
let size = if let IsNull::Yes = value.encode(self.buf) {
|
||||
-1
|
||||
} else {
|
||||
(self.buf.len() - offset + 4) as i32
|
||||
};
|
||||
|
||||
self.buf[offset..].copy_from_slice(&size.to_be_bytes());
|
||||
self.buf.encode(value);
|
||||
self.num += 1;
|
||||
|
||||
self
|
||||
}
|
||||
|
@ -133,6 +125,8 @@ impl<'r> PgRecordDecoder<'r> {
|
|||
}
|
||||
};
|
||||
|
||||
self.ind += 1;
|
||||
|
||||
if let Some(ty) = &element_type_opt {
|
||||
if !T::accepts(ty) {
|
||||
return Err(mismatched_types::<Postgres, T>(&T::type_info(), ty));
|
||||
|
@ -142,23 +136,7 @@ impl<'r> PgRecordDecoder<'r> {
|
|||
let element_type =
|
||||
element_type_opt.unwrap_or_else(|| PgTypeInfo::with_oid(element_type_oid));
|
||||
|
||||
let mut element_len = self.buf.get_i32();
|
||||
let element_buf = if element_len < 0 {
|
||||
element_len = 0;
|
||||
None
|
||||
} else {
|
||||
Some(&self.buf[..(element_len as usize)])
|
||||
};
|
||||
|
||||
self.buf.advance(element_len as usize);
|
||||
self.ind += 1;
|
||||
|
||||
T::decode(PgValueRef {
|
||||
type_info: element_type,
|
||||
format: self.fmt,
|
||||
value: element_buf,
|
||||
row: None,
|
||||
})
|
||||
T::decode(PgValueRef::get(&mut self.buf, self.fmt, element_type))
|
||||
}
|
||||
|
||||
PgValueFormat::Text => {
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use std::borrow::Cow;
|
||||
use std::str::from_utf8;
|
||||
|
||||
use bytes::Bytes;
|
||||
use bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::{BoxDynError, UnexpectedNullError};
|
||||
use crate::postgres::{PgTypeInfo, Postgres};
|
||||
|
@ -32,6 +32,26 @@ pub struct PgValue {
|
|||
}
|
||||
|
||||
impl<'r> PgValueRef<'r> {
|
||||
pub(crate) fn get(buf: &mut &'r [u8], format: PgValueFormat, ty: PgTypeInfo) -> Self {
|
||||
let mut element_len = buf.get_i32();
|
||||
|
||||
let element_val = if element_len == -1 {
|
||||
element_len = 0;
|
||||
None
|
||||
} else {
|
||||
Some(&buf[..(element_len as usize)])
|
||||
};
|
||||
|
||||
buf.advance(element_len as usize);
|
||||
|
||||
PgValueRef {
|
||||
value: element_val,
|
||||
row: None,
|
||||
type_info: ty,
|
||||
format,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn format(&self) -> PgValueFormat {
|
||||
self.format
|
||||
}
|
||||
|
@ -62,7 +82,13 @@ impl Value for PgValue {
|
|||
}
|
||||
|
||||
fn type_info(&self) -> Option<Cow<'_, PgTypeInfo>> {
|
||||
Some(Cow::Borrowed(&self.type_info))
|
||||
if self.format == PgValueFormat::Text {
|
||||
// For TEXT encoding the type defined on the value is unreliable
|
||||
// We don't even bother to return it so type checking is implicitly opted-out
|
||||
None
|
||||
} else {
|
||||
Some(Cow::Borrowed(&self.type_info))
|
||||
}
|
||||
}
|
||||
|
||||
fn is_null(&self) -> bool {
|
||||
|
@ -90,7 +116,13 @@ impl<'r> ValueRef<'r> for PgValueRef<'r> {
|
|||
}
|
||||
|
||||
fn type_info(&self) -> Option<Cow<'_, PgTypeInfo>> {
|
||||
Some(Cow::Borrowed(&self.type_info))
|
||||
if self.format == PgValueFormat::Text {
|
||||
// For TEXT encoding the type defined on the value is unreliable
|
||||
// We don't even bother to return it so type checking is implicitly opted-out
|
||||
None
|
||||
} else {
|
||||
Some(Cow::Borrowed(&self.type_info))
|
||||
}
|
||||
}
|
||||
|
||||
fn is_null(&self) -> bool {
|
||||
|
|
|
@ -146,15 +146,12 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
|
|||
Ok(SqlxChildAttributes { rename })
|
||||
}
|
||||
|
||||
pub fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> {
|
||||
pub fn check_transparent_attributes(
|
||||
input: &DeriveInput,
|
||||
field: &Field,
|
||||
) -> syn::Result<SqlxContainerAttributes> {
|
||||
let attributes = parse_container_attributes(&input.attrs)?;
|
||||
|
||||
assert_attribute!(
|
||||
attributes.transparent,
|
||||
"expected #[sqlx(transparent)]",
|
||||
input
|
||||
);
|
||||
|
||||
assert_attribute!(
|
||||
attributes.rename_all.is_none(),
|
||||
"unexpected #[sqlx(rename_all = ..)]",
|
||||
|
@ -163,15 +160,15 @@ pub fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::
|
|||
|
||||
assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input);
|
||||
|
||||
let attributes = parse_child_attributes(&field.attrs)?;
|
||||
let ch_attributes = parse_child_attributes(&field.attrs)?;
|
||||
|
||||
assert_attribute!(
|
||||
attributes.rename.is_none(),
|
||||
ch_attributes.rename.is_none(),
|
||||
"unexpected #[sqlx(rename = ..)]",
|
||||
field
|
||||
);
|
||||
|
||||
Ok(())
|
||||
Ok(attributes)
|
||||
}
|
||||
|
||||
pub fn check_enum_attributes<'a>(input: &'a DeriveInput) -> syn::Result<SqlxContainerAttributes> {
|
||||
|
|
|
@ -62,21 +62,21 @@ fn expand_derive_decode_transparent(
|
|||
// add db type for impl generics & where clause
|
||||
let mut generics = generics.clone();
|
||||
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
|
||||
generics.params.insert(0, parse_quote!('de));
|
||||
generics.params.insert(0, parse_quote!('r));
|
||||
generics
|
||||
.make_where_clause()
|
||||
.predicates
|
||||
.push(parse_quote!(#ty: sqlx::decode::Decode<'de, DB>));
|
||||
.push(parse_quote!(#ty: sqlx::decode::Decode<'r, DB>));
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
let tts = quote!(
|
||||
impl #impl_generics sqlx::decode::Decode<'de, DB> for #ident #ty_generics #where_clause {
|
||||
impl #impl_generics sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause {
|
||||
fn accepts(ty: &DB::TypeInfo) -> bool {
|
||||
<#ty as sqlx::decode::Decode<'de, DB>>::accepts(ty)
|
||||
<#ty as sqlx::decode::Decode<'r, DB>>::accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: <DB as sqlx::database::HasValueRef<'de>>::ValueRef) -> std::result::Result<Self, sqlx::BoxDynError> {
|
||||
<#ty as sqlx::decode::Decode<'de, DB>>::decode(value).map(Self)
|
||||
fn decode(value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
|
||||
<#ty as sqlx::decode::Decode<'r, DB>>::decode(value).map(Self)
|
||||
}
|
||||
}
|
||||
);
|
||||
|
@ -103,13 +103,13 @@ fn expand_derive_decode_weak_enum(
|
|||
.collect::<Vec<Arm>>();
|
||||
|
||||
Ok(quote!(
|
||||
impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where #repr: sqlx::decode::Decode<'de, DB> {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
*ty == Self::type_info()
|
||||
impl<'r, DB: sqlx::Database> sqlx::decode::Decode<'r, DB> for #ident where #repr: sqlx::decode::Decode<'r, DB> {
|
||||
fn accepts(ty: &DB::TypeInfo) -> bool {
|
||||
<#repr as sqlx::decode::Decode<'r, DB>>::accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: <DB as sqlx::database::HasValueRef<'de>>::ValueRef) -> std::result::Result<Self, sqlx::BoxDynError> {
|
||||
let value = <#repr as sqlx::decode::Decode<'de, DB>>::decode(value)?;
|
||||
fn decode(value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
|
||||
let value = <#repr as sqlx::decode::Decode<'r, DB>>::decode(value)?;
|
||||
|
||||
match value {
|
||||
#(#arms)*
|
||||
|
@ -146,22 +146,65 @@ fn expand_derive_decode_strong_enum(
|
|||
}
|
||||
});
|
||||
|
||||
Ok(quote!(
|
||||
impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where &'de str: sqlx::decode::Decode<'de, DB> {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
*ty == Self::type_info()
|
||||
}
|
||||
let values = quote! {
|
||||
match value {
|
||||
#(#value_arms)*
|
||||
|
||||
fn decode(value: <DB as sqlx::database::HasValueRef<'de>>::ValueRef) -> std::result::Result<Self, sqlx::BoxDynError> {
|
||||
let value = <&'de str as sqlx::decode::Decode<'de, DB>>::decode(value)?;
|
||||
match value {
|
||||
#(#value_arms)*
|
||||
_ => Err(format!("invalid value {:?} for enum {}", value, #ident_s).into())
|
||||
}
|
||||
};
|
||||
|
||||
_ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into())))
|
||||
let mut tts = proc_macro2::TokenStream::new();
|
||||
|
||||
if cfg!(feature = "mysql") {
|
||||
tts.extend(quote!(
|
||||
impl<'r> sqlx::decode::Decode<'r, sqlx::mysql::MySql> for #ident {
|
||||
fn accepts(ty: &sqlx::mysql::MySqlTypeInfo) -> bool {
|
||||
ty == sqlx::mysql::MySqlTypeInfo::__enum()
|
||||
}
|
||||
|
||||
fn decode(value: sqlx::mysql::MySqlValueRef<'r>) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
|
||||
let value = <&'r str as sqlx::decode::Decode<'r, sqlx::mysql::MySql>>::decode(value)?;
|
||||
|
||||
#values
|
||||
}
|
||||
}
|
||||
}
|
||||
))
|
||||
));
|
||||
}
|
||||
|
||||
if cfg!(feature = "postgres") {
|
||||
tts.extend(quote!(
|
||||
impl<'r> sqlx::decode::Decode<'r, sqlx::postgres::Postgres> for #ident {
|
||||
fn accepts(ty: &sqlx::postgres::PgTypeInfo) -> bool {
|
||||
*ty == <#ident as sqlx::Type<sqlx::postgres::Postgres>>::type_info()
|
||||
}
|
||||
|
||||
fn decode(value: sqlx::postgres::PgValueRef<'r>) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
|
||||
let value = <&'r str as sqlx::decode::Decode<'r, sqlx::postgres::Postgres>>::decode(value)?;
|
||||
|
||||
#values
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
if cfg!(feature = "sqlite") {
|
||||
tts.extend(quote!(
|
||||
impl<'r> sqlx::decode::Decode<'r, sqlx::sqlite::Sqlite> for #ident {
|
||||
fn accepts(ty: &sqlx::sqlite::SqliteTypeInfo) -> bool {
|
||||
<&str as sqlx::decode::Decode<'r, DB>>::accepts(ty)
|
||||
}
|
||||
|
||||
fn decode(value: sqlx::sqlite::SqliteValueRef<'r>) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
|
||||
let value = <&'r str as sqlx::decode::Decode<'r, sqlx::sqlite::Sqlite>>::decode(value)?;
|
||||
|
||||
#values
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
Ok(tts)
|
||||
}
|
||||
|
||||
fn expand_derive_decode_struct(
|
||||
|
@ -181,14 +224,14 @@ fn expand_derive_decode_struct(
|
|||
|
||||
// add db type for impl generics & where clause
|
||||
let mut generics = generics.clone();
|
||||
generics.params.insert(0, parse_quote!('de));
|
||||
generics.params.insert(0, parse_quote!('r));
|
||||
|
||||
let predicates = &mut generics.make_where_clause().predicates;
|
||||
|
||||
for field in fields {
|
||||
let ty = &field.ty;
|
||||
|
||||
predicates.push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::Postgres>));
|
||||
predicates.push(parse_quote!(#ty: sqlx::decode::Decode<'r, sqlx::Postgres>));
|
||||
predicates.push(parse_quote!(#ty: sqlx::types::Type<sqlx::Postgres>));
|
||||
}
|
||||
|
||||
|
@ -199,20 +242,20 @@ fn expand_derive_decode_struct(
|
|||
let ty = &field.ty;
|
||||
|
||||
parse_quote!(
|
||||
let #id = decoder.decode::<#ty>()?;
|
||||
let #id = decoder.try_decode::<#ty>()?;
|
||||
)
|
||||
});
|
||||
|
||||
let names = fields.iter().map(|field| &field.ident);
|
||||
|
||||
tts.extend(quote!(
|
||||
impl #impl_generics sqlx::decode::Decode<'de, sqlx::Postgres> for #ident #ty_generics #where_clause {
|
||||
fn accepts(ty: &MySqlTypeInfo) -> bool {
|
||||
*ty == Self::type_info()
|
||||
impl #impl_generics sqlx::decode::Decode<'r, sqlx::Postgres> for #ident #ty_generics #where_clause {
|
||||
fn accepts(ty: &sqlx::postgres::PgTypeInfo) -> bool {
|
||||
*ty == <Self as sqlx::Type<sqlx::postgres::Postgres>>::type_info()
|
||||
}
|
||||
|
||||
fn decode(value: <sqlx::Postgres as sqlx::value::HasRawValue<'de>>::RawValue) -> sqlx::Result<Self> {
|
||||
let mut decoder = sqlx::postgres::types::raw::PgRecordDecoder::new(value)?;
|
||||
fn decode(value: sqlx::postgres::PgValueRef<'r>) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
|
||||
let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?;
|
||||
|
||||
#(#reads)*
|
||||
|
||||
|
|
|
@ -67,6 +67,7 @@ fn expand_derive_encode_transparent(
|
|||
generics
|
||||
.params
|
||||
.insert(0, LifetimeDef::new(lifetime.clone()).into());
|
||||
|
||||
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
|
||||
generics
|
||||
.make_where_clause()
|
||||
|
@ -76,18 +77,16 @@ fn expand_derive_encode_transparent(
|
|||
|
||||
Ok(quote!(
|
||||
impl #impl_generics sqlx::encode::Encode<#lifetime, DB> for #ident #ty_generics #where_clause {
|
||||
fn encode(self, buf: &mut <DB as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull {
|
||||
sqlx::encode::Encode::encode(self.0, buf)
|
||||
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull {
|
||||
<#ty as sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.0, buf)
|
||||
}
|
||||
|
||||
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull {
|
||||
sqlx::encode::Encode::encode_by_ref(&self.0, buf)
|
||||
}
|
||||
fn produces(&self) -> Option<DB::TypeInfo> {
|
||||
<#ty as sqlx::encode::Encode<DB>>::produces(&self.0)
|
||||
<#ty as sqlx::encode::Encode<#lifetime, DB>>::produces(&self.0)
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
sqlx::encode::Encode::size_hint(&self.0)
|
||||
<#ty as sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.0)
|
||||
}
|
||||
}
|
||||
))
|
||||
|
@ -103,21 +102,17 @@ fn expand_derive_encode_weak_enum(
|
|||
let ident = &input.ident;
|
||||
|
||||
Ok(quote!(
|
||||
impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where #repr: sqlx::encode::Encode<'q, DB> {
|
||||
fn encode(self, buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
|
||||
sqlx::encode::Encode::encode((self as #repr), buf)
|
||||
}
|
||||
|
||||
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
|
||||
sqlx::encode::Encode::encode_by_ref(&(*self as #repr), buf)
|
||||
}
|
||||
impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where #repr: sqlx::encode::Encode<'q, DB> {
|
||||
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
|
||||
<#repr as sqlx::encode::Encode<DB>>::encode_by_ref(&(*self as #repr), buf)
|
||||
}
|
||||
|
||||
fn produces(&self) -> Option<DB::TypeInfo> {
|
||||
<Self as Type<MySql>>::type_info().into()
|
||||
<#repr as sqlx::encode::Encode<DB>>::produces(&(*self as #repr))
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
sqlx::encode::Encode::size_hint(&(*self as #repr))
|
||||
<#repr as sqlx::encode::Encode<DB>>::size_hint(&(*self as #repr))
|
||||
}
|
||||
}
|
||||
))
|
||||
|
@ -149,24 +144,21 @@ fn expand_derive_encode_strong_enum(
|
|||
}
|
||||
|
||||
Ok(quote!(
|
||||
impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where str: sqlx::encode::Encode<'q, DB> {
|
||||
impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where &'q str: sqlx::encode::Encode<'q, DB> {
|
||||
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
|
||||
let val = match self {
|
||||
#(#value_arms)*
|
||||
};
|
||||
|
||||
<str as sqlx::encode::Encode<'q, DB>>::encode_by_ref(val, buf)
|
||||
}
|
||||
|
||||
fn produces(&self) -> Option<DB::TypeInfo> {
|
||||
<Self as Type<MySql>>::type_info().into()
|
||||
<&str as sqlx::encode::Encode<'q, DB>>::encode(val, buf)
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
let val = match self {
|
||||
#(#value_arms)*
|
||||
};
|
||||
<str as sqlx::encode::Encode<'q, DB>>::size_hint(val)
|
||||
|
||||
<&str as sqlx::encode::Encode<'q, DB>>::size_hint(&val)
|
||||
}
|
||||
}
|
||||
))
|
||||
|
@ -190,14 +182,14 @@ fn expand_derive_encode_struct(
|
|||
|
||||
// add db type for impl generics & where clause
|
||||
let mut generics = generics.clone();
|
||||
|
||||
let predicates = &mut generics.make_where_clause().predicates;
|
||||
|
||||
for field in fields {
|
||||
let ty = &field.ty;
|
||||
|
||||
predicates.insert(0, parse_quote!('q));
|
||||
predicates.push(parse_quote!(#ty: sqlx::encode::Encode<'q, sqlx::Postgres>));
|
||||
predicates.push(parse_quote!(#ty: sqlx::types::Type<'q, sqlx::Postgres>));
|
||||
predicates.push(parse_quote!(#ty: for<'q> sqlx::encode::Encode<'q, sqlx::Postgres>));
|
||||
predicates.push(parse_quote!(#ty: sqlx::types::Type<sqlx::Postgres>));
|
||||
}
|
||||
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
@ -206,7 +198,6 @@ fn expand_derive_encode_struct(
|
|||
let id = &field.ident;
|
||||
|
||||
parse_quote!(
|
||||
// sqlx::postgres::encode_struct_field(buf, &self. #id);
|
||||
encoder.encode(&self. #id);
|
||||
)
|
||||
});
|
||||
|
@ -221,17 +212,15 @@ fn expand_derive_encode_struct(
|
|||
});
|
||||
|
||||
tts.extend(quote!(
|
||||
impl #impl_generics sqlx::encode::Encode<'q, sqlx::Postgres> for #ident #ty_generics #where_clause {
|
||||
fn encode_by_ref(&self, buf: &mut <sqlx::Postgres as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
|
||||
let mut encoder = sqlx::postgres::types::raw::PgRecordEncoder::new(buf);
|
||||
impl #impl_generics sqlx::encode::Encode<'_, sqlx::Postgres> for #ident #ty_generics #where_clause {
|
||||
fn encode_by_ref(&self, buf: &mut sqlx::postgres::PgArgumentBuffer) -> sqlx::encode::IsNull {
|
||||
let mut encoder = sqlx::postgres::types::PgRecordEncoder::new(buf);
|
||||
|
||||
#(#writes)*
|
||||
|
||||
encoder.finish()
|
||||
}
|
||||
encoder.finish();
|
||||
|
||||
fn produces(&self) -> Option<DB::TypeInfo> {
|
||||
<Self as Type<MySql>>::type_info().into()
|
||||
sqlx::encode::IsNull::No
|
||||
}
|
||||
|
||||
fn size_hint(&self) -> usize {
|
||||
|
|
|
@ -49,32 +49,48 @@ fn expand_derive_has_sql_type_transparent(
|
|||
input: &DeriveInput,
|
||||
field: &Field,
|
||||
) -> syn::Result<proc_macro2::TokenStream> {
|
||||
check_transparent_attributes(input, field)?;
|
||||
let attr = check_transparent_attributes(input, field)?;
|
||||
|
||||
let ident = &input.ident;
|
||||
let ty = &field.ty;
|
||||
|
||||
// extract type generics
|
||||
let generics = &input.generics;
|
||||
let (_, ty_generics, _) = generics.split_for_impl();
|
||||
|
||||
// add db type for clause
|
||||
let mut generics = generics.clone();
|
||||
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
|
||||
generics
|
||||
.make_where_clause()
|
||||
.predicates
|
||||
.push(parse_quote!(#ty: sqlx::Type<DB>));
|
||||
if attr.transparent {
|
||||
let mut generics = generics.clone();
|
||||
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
|
||||
generics
|
||||
.make_where_clause()
|
||||
.predicates
|
||||
.push(parse_quote!(#ty: sqlx::Type<DB>));
|
||||
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
let (impl_generics, _, where_clause) = generics.split_for_impl();
|
||||
|
||||
Ok(quote!(
|
||||
impl #impl_generics sqlx::Type< DB > for #ident #ty_generics #where_clause {
|
||||
fn type_info() -> DB::TypeInfo {
|
||||
<#ty as sqlx::Type<DB>>::type_info()
|
||||
return Ok(quote!(
|
||||
impl #impl_generics sqlx::Type< DB > for #ident #ty_generics #where_clause {
|
||||
fn type_info() -> DB::TypeInfo {
|
||||
<#ty as sqlx::Type<DB>>::type_info()
|
||||
}
|
||||
}
|
||||
}
|
||||
))
|
||||
));
|
||||
}
|
||||
|
||||
let mut tts = proc_macro2::TokenStream::new();
|
||||
|
||||
if cfg!(feature = "postgres") {
|
||||
let ty_name = attr.rename.unwrap_or_else(|| ident.to_string());
|
||||
|
||||
tts.extend(quote!(
|
||||
impl sqlx::Type< sqlx::postgres::Postgres > for #ident #ty_generics {
|
||||
fn type_info() -> sqlx::postgres::PgTypeInfo {
|
||||
sqlx::postgres::PgTypeInfo::with_name(#ty_name)
|
||||
}
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
Ok(tts)
|
||||
}
|
||||
|
||||
fn expand_derive_has_sql_type_weak_enum(
|
||||
|
@ -84,8 +100,7 @@ fn expand_derive_has_sql_type_weak_enum(
|
|||
let attr = check_weak_enum_attributes(input, variants)?;
|
||||
let repr = attr.repr.unwrap();
|
||||
let ident = &input.ident;
|
||||
|
||||
Ok(quote!(
|
||||
let ts = quote!(
|
||||
impl<DB: sqlx::Database> sqlx::Type<DB> for #ident
|
||||
where
|
||||
#repr: sqlx::Type<DB>,
|
||||
|
@ -94,7 +109,9 @@ fn expand_derive_has_sql_type_weak_enum(
|
|||
<#repr as sqlx::Type<DB>>::type_info()
|
||||
}
|
||||
}
|
||||
))
|
||||
);
|
||||
|
||||
Ok(ts)
|
||||
}
|
||||
|
||||
fn expand_derive_has_sql_type_strong_enum(
|
||||
|
@ -110,7 +127,7 @@ fn expand_derive_has_sql_type_strong_enum(
|
|||
tts.extend(quote!(
|
||||
impl sqlx::Type< sqlx::MySql > for #ident {
|
||||
fn type_info() -> sqlx::mysql::MySqlTypeInfo {
|
||||
sqlx::mysql::MySqlTypeInfo::r#enum()
|
||||
sqlx::mysql::MySqlTypeInfo::__enum()
|
||||
}
|
||||
}
|
||||
));
|
||||
|
|
14
src/lib.rs
14
src/lib.rs
|
@ -11,16 +11,14 @@ pub use sqlx_core::query_as::{query_as, query_as_with};
|
|||
pub use sqlx_core::query_scalar::{query_scalar, query_scalar_with};
|
||||
pub use sqlx_core::row::{ColumnIndex, Row};
|
||||
pub use sqlx_core::transaction::{Transaction, TransactionManager};
|
||||
pub use sqlx_core::types::Type;
|
||||
pub use sqlx_core::value::{Value, ValueRef};
|
||||
|
||||
#[doc(hidden)]
|
||||
pub use sqlx_core::describe;
|
||||
|
||||
#[doc(inline)]
|
||||
pub use sqlx_core::types::{self, Type};
|
||||
|
||||
#[doc(inline)]
|
||||
pub use sqlx_core::error::{self, BoxDynError, Error, Result};
|
||||
pub use sqlx_core::error::{self, Error, Result};
|
||||
|
||||
#[cfg(feature = "mysql")]
|
||||
#[cfg_attr(docsrs, doc(cfg(feature = "mysql")))]
|
||||
|
@ -42,6 +40,7 @@ pub use sqlx_core::sqlite::{self, Sqlite, SqliteConnection, SqlitePool};
|
|||
#[doc(hidden)]
|
||||
pub extern crate sqlx_macros;
|
||||
|
||||
// derives
|
||||
#[cfg(feature = "macros")]
|
||||
pub use sqlx_macros::{FromRow, Type};
|
||||
|
||||
|
@ -57,6 +56,13 @@ pub mod ty_match;
|
|||
#[doc(hidden)]
|
||||
pub mod result_ext;
|
||||
|
||||
pub mod types {
|
||||
pub use sqlx_core::types::*;
|
||||
|
||||
#[cfg(feature = "macros")]
|
||||
pub use sqlx_macros::Type;
|
||||
}
|
||||
|
||||
/// Types and traits for encoding values for the database.
|
||||
pub mod encode {
|
||||
pub use sqlx_core::encode::{Encode, IsNull};
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
use sqlx::{postgres::PgQueryAs, Connection, Cursor, Executor, FromRow, Postgres};
|
||||
use futures::TryStreamExt;
|
||||
use sqlx::{Connection, Executor, FromRow, Postgres};
|
||||
use sqlx_core::postgres::types::PgRange;
|
||||
use sqlx_test::{new, test_type};
|
||||
use std::fmt::Debug;
|
||||
use std::ops::Bound;
|
||||
|
||||
// Transparent types are rust-side wrappers over DB types
|
||||
#[derive(PartialEq, Debug, sqlx::Type)]
|
||||
|
@ -37,6 +40,7 @@ enum ColorLower {
|
|||
Green,
|
||||
Blue,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug, sqlx::Type)]
|
||||
#[sqlx(rename = "color_snake")]
|
||||
#[sqlx(rename_all = "snake_case")]
|
||||
|
@ -44,6 +48,7 @@ enum ColorSnake {
|
|||
RedGreen,
|
||||
BlueBlack,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug, sqlx::Type)]
|
||||
#[sqlx(rename = "color_upper")]
|
||||
#[sqlx(rename_all = "uppercase")]
|
||||
|
@ -73,36 +78,43 @@ struct InventoryItem {
|
|||
price: Option<i64>,
|
||||
}
|
||||
|
||||
test_type!(transparent(
|
||||
Postgres,
|
||||
Transparent,
|
||||
// Custom range type
|
||||
#[derive(sqlx::Type, Debug, PartialEq)]
|
||||
#[sqlx(rename = "float_range")]
|
||||
struct FloatRange(PgRange<f64>);
|
||||
|
||||
// Custom domain type
|
||||
#[derive(sqlx::Type, Debug)]
|
||||
#[sqlx(rename = "int4rangeL0pC")]
|
||||
struct RangeInclusive(PgRange<i32>);
|
||||
|
||||
test_type!(transparent<Transparent>(Postgres,
|
||||
"0" == Transparent(0),
|
||||
"23523" == Transparent(23523)
|
||||
));
|
||||
|
||||
test_type!(weak_enum(
|
||||
Postgres,
|
||||
Weak,
|
||||
test_type!(weak_enum<Weak>(Postgres,
|
||||
"0::int4" == Weak::One,
|
||||
"2::int4" == Weak::Two,
|
||||
"4::int4" == Weak::Three
|
||||
));
|
||||
|
||||
test_type!(strong_enum(
|
||||
Postgres,
|
||||
Strong,
|
||||
test_type!(strong_enum<Strong>(Postgres,
|
||||
"'one'::text" == Strong::One,
|
||||
"'two'::text" == Strong::Two,
|
||||
"'four'::text" == Strong::Three
|
||||
));
|
||||
|
||||
test_type!(floatrange<FloatRange>(Postgres,
|
||||
"'[1.234, 5.678]'::float_range" == FloatRange(PgRange::from((Bound::Included(1.234), Bound::Included(5.678)))),
|
||||
));
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn test_enum_type() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Postgres>().await?;
|
||||
|
||||
conn.execute(
|
||||
r#"
|
||||
|
||||
DROP TABLE IF EXISTS people;
|
||||
|
||||
DROP TYPE IF EXISTS mood CASCADE;
|
||||
|
@ -154,7 +166,7 @@ RETURNING id
|
|||
let rec: PeopleRow = sqlx::query_as(
|
||||
"
|
||||
SELECT id, mood FROM people WHERE id = $1
|
||||
",
|
||||
",
|
||||
)
|
||||
.bind(people_id)
|
||||
.fetch_one(&mut conn)
|
||||
|
@ -169,20 +181,23 @@ SELECT id, mood FROM people WHERE id = $1
|
|||
|
||||
let stmt = format!("SELECT id, mood FROM people WHERE id = {}", people_id);
|
||||
dbg!(&stmt);
|
||||
|
||||
let mut cursor = conn.fetch(&*stmt);
|
||||
|
||||
let row = cursor.next().await?.unwrap();
|
||||
let row = cursor.try_next().await?.unwrap();
|
||||
let rec = PeopleRow::from_row(&row)?;
|
||||
|
||||
assert_eq!(rec.id, people_id);
|
||||
assert_eq!(rec.mood, Mood::Sad);
|
||||
|
||||
drop(cursor);
|
||||
|
||||
// Normal type equivalency test
|
||||
|
||||
let rec: (bool, Mood) = sqlx::query_as(
|
||||
"
|
||||
SELECT $1 = 'happy'::mood, $1
|
||||
",
|
||||
SELECT $1 = 'happy'::mood, $1
|
||||
",
|
||||
)
|
||||
.bind(&Mood::Happy)
|
||||
.fetch_one(&mut conn)
|
||||
|
@ -193,8 +208,8 @@ SELECT $1 = 'happy'::mood, $1
|
|||
|
||||
let rec: (bool, ColorLower) = sqlx::query_as(
|
||||
"
|
||||
SELECT $1 = 'green'::color_lower, $1
|
||||
",
|
||||
SELECT $1 = 'green'::color_lower, $1
|
||||
",
|
||||
)
|
||||
.bind(&ColorLower::Green)
|
||||
.fetch_one(&mut conn)
|
||||
|
@ -205,8 +220,8 @@ SELECT $1 = 'green'::color_lower, $1
|
|||
|
||||
let rec: (bool, ColorSnake) = sqlx::query_as(
|
||||
"
|
||||
SELECT $1 = 'red_green'::color_snake, $1
|
||||
",
|
||||
SELECT $1 = 'red_green'::color_snake, $1
|
||||
",
|
||||
)
|
||||
.bind(&ColorSnake::RedGreen)
|
||||
.fetch_one(&mut conn)
|
||||
|
@ -217,8 +232,8 @@ SELECT $1 = 'red_green'::color_snake, $1
|
|||
|
||||
let rec: (bool, ColorUpper) = sqlx::query_as(
|
||||
"
|
||||
SELECT $1 = 'RED'::color_upper, $1
|
||||
",
|
||||
SELECT $1 = 'RED'::color_upper, $1
|
||||
",
|
||||
)
|
||||
.bind(&ColorUpper::Red)
|
||||
.fetch_one(&mut conn)
|
||||
|
@ -234,23 +249,6 @@ SELECT $1 = 'RED'::color_upper, $1
|
|||
async fn test_record_type() -> anyhow::Result<()> {
|
||||
let mut conn = new::<Postgres>().await?;
|
||||
|
||||
conn.execute(
|
||||
r#"
|
||||
DO $$ BEGIN
|
||||
|
||||
CREATE TYPE inventory_item AS (
|
||||
name text,
|
||||
supplier_id int,
|
||||
price bigint
|
||||
);
|
||||
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
"#,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let value = InventoryItem {
|
||||
name: "fuzzy dice".to_owned(),
|
||||
supplier_id: Some(42),
|
||||
|
@ -259,7 +257,7 @@ END $$;
|
|||
|
||||
let rec: (bool, InventoryItem) = sqlx::query_as(
|
||||
"
|
||||
SELECT $1 = ROW('fuzzy dice', 42, 199)::inventory_item, $1
|
||||
SELECT $1 = ROW('fuzzy dice', 42, 199)::inventory_item, $1
|
||||
",
|
||||
)
|
||||
.bind(&value)
|
||||
|
@ -275,9 +273,6 @@ END $$;
|
|||
#[cfg(feature = "macros")]
|
||||
#[sqlx_macros::test]
|
||||
async fn test_from_row() -> anyhow::Result<()> {
|
||||
// Needed for PgQueryAs
|
||||
use sqlx::prelude::*;
|
||||
|
||||
let mut conn = new::<Postgres>().await?;
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
|
@ -310,7 +305,8 @@ async fn test_from_row() -> anyhow::Result<()> {
|
|||
.bind(1_i32)
|
||||
.fetch(&mut conn);
|
||||
|
||||
let account = RefAccount::from_row(&cursor.next().await?.unwrap())?;
|
||||
let row = cursor.try_next().await?.unwrap();
|
||||
let account = RefAccount::from_row(&row)?;
|
||||
|
||||
assert_eq!(account.id, 1);
|
||||
assert_eq!(account.name, "Herp Derpinson");
|
||||
|
@ -321,8 +317,6 @@ async fn test_from_row() -> anyhow::Result<()> {
|
|||
#[cfg(feature = "macros")]
|
||||
#[sqlx_macros::test]
|
||||
async fn test_from_row_with_keyword() -> anyhow::Result<()> {
|
||||
use sqlx::prelude::*;
|
||||
|
||||
#[derive(Debug, sqlx::FromRow)]
|
||||
struct AccountKeyword {
|
||||
r#type: i32,
|
||||
|
@ -353,8 +347,6 @@ async fn test_from_row_with_keyword() -> anyhow::Result<()> {
|
|||
#[cfg(feature = "macros")]
|
||||
#[sqlx_macros::test]
|
||||
async fn test_from_row_with_rename() -> anyhow::Result<()> {
|
||||
use sqlx::prelude::*;
|
||||
|
||||
#[derive(Debug, sqlx::FromRow)]
|
||||
struct AccountKeyword {
|
||||
#[sqlx(rename = "type")]
|
||||
|
|
|
@ -17,3 +17,9 @@ CREATE TABLE tweet
|
|||
text TEXT NOT NULL,
|
||||
owner_id BIGINT
|
||||
);
|
||||
|
||||
CREATE TYPE float_range AS RANGE
|
||||
(
|
||||
subtype = float8,
|
||||
subtype_diff = float8mi
|
||||
);
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
extern crate time_ as time;
|
||||
|
||||
use std::ops::Bound;
|
||||
|
||||
use sqlx::postgres::types::PgRange;
|
||||
use sqlx::postgres::Postgres;
|
||||
use sqlx_test::{test_decode_type, test_prepared_type, test_type};
|
||||
|
||||
|
@ -334,35 +337,26 @@ test_type!(decimal<sqlx::types::BigDecimal>(Postgres,
|
|||
"12345.6789::numeric" == "12345.6789".parse::<sqlx::types::BigDecimal>().unwrap(),
|
||||
));
|
||||
|
||||
mod ranges {
|
||||
use super::*;
|
||||
use core::ops::Bound;
|
||||
use sqlx::postgres::types::{Int4Range, PgRange};
|
||||
const EXC2: Bound<i32> = Bound::Excluded(2);
|
||||
const EXC3: Bound<i32> = Bound::Excluded(3);
|
||||
const INC1: Bound<i32> = Bound::Included(1);
|
||||
const INC2: Bound<i32> = Bound::Included(2);
|
||||
const UNB: Bound<i32> = Bound::Unbounded;
|
||||
|
||||
const EXC2: Bound<i32> = Bound::Excluded(2);
|
||||
const EXC3: Bound<i32> = Bound::Excluded(3);
|
||||
const INC1: Bound<i32> = Bound::Included(1);
|
||||
const INC2: Bound<i32> = Bound::Included(2);
|
||||
const UNB: Bound<i32> = Bound::Unbounded;
|
||||
|
||||
// int4range display is hard-coded into [l, u)
|
||||
test_type!(int4range<PgRange<i32>>(Postgres,
|
||||
|
||||
"'(,)'::int4range" == Int4Range(PgRange::new([UNB, UNB])),
|
||||
"'(,]'::int4range" == Int4Range(PgRange::new([UNB, UNB])),
|
||||
"'(,2)'::int4range" == Int4Range(PgRange::new([UNB, EXC2])),
|
||||
"'(,2]'::int4range" == Int4Range(PgRange::new([UNB, EXC3])),
|
||||
"'(1,)'::int4range" == Int4Range(PgRange::new([INC2, UNB])),
|
||||
"'(1,]'::int4range" == Int4Range(PgRange::new([INC2, UNB])),
|
||||
"'(1,2]'::int4range" == Int4Range(PgRange::new([INC2, EXC3])),
|
||||
|
||||
"'[,)'::int4range" == Int4Range(PgRange::new([UNB, UNB])),
|
||||
"'[,]'::int4range" == Int4Range(PgRange::new([UNB, UNB])),
|
||||
"'[,2)'::int4range" == Int4Range(PgRange::new([UNB, EXC2])),
|
||||
"'[,2]'::int4range" == Int4Range(PgRange::new([UNB, EXC3])),
|
||||
"'[1,)'::int4range" == Int4Range(PgRange::new([INC1, UNB])),
|
||||
"'[1,]'::int4range" == Int4Range(PgRange::new([INC1, UNB])),
|
||||
"'[1,2)'::int4range" == Int4Range(PgRange::new([INC1, EXC2])),
|
||||
"'[1,2]'::int4range" == Int4Range(PgRange::new([INC1, EXC3])),
|
||||
));
|
||||
}
|
||||
test_type!(int4range<PgRange<i32>>(Postgres,
|
||||
"'(,)'::int4range" == PgRange::from((UNB, UNB)),
|
||||
"'(,]'::int4range" == PgRange::from((UNB, UNB)),
|
||||
"'(,2)'::int4range" == PgRange::from((UNB, EXC2)),
|
||||
"'(,2]'::int4range" == PgRange::from((UNB, EXC3)),
|
||||
"'(1,)'::int4range" == PgRange::from((INC2, UNB)),
|
||||
"'(1,]'::int4range" == PgRange::from((INC2, UNB)),
|
||||
"'(1,2]'::int4range" == PgRange::from((INC2, EXC3)),
|
||||
"'[,)'::int4range" == PgRange::from((UNB, UNB)),
|
||||
"'[,]'::int4range" == PgRange::from((UNB, UNB)),
|
||||
"'[,2)'::int4range" == PgRange::from((UNB, EXC2)),
|
||||
"'[,2]'::int4range" == PgRange::from((UNB, EXC3)),
|
||||
"'[1,)'::int4range" == PgRange::from((INC1, UNB)),
|
||||
"'[1,]'::int4range" == PgRange::from((INC1, UNB)),
|
||||
"'[1,2)'::int4range" == PgRange::from((INC1, EXC2)),
|
||||
"'[1,2]'::int4range" == PgRange::from((INC1, EXC3)),
|
||||
));
|
||||
|
|
Loading…
Reference in a new issue