feat(postgres): add support for built-in range types and allow derives to handle custom range types

Co-authored-by: Caio <c410.f3r@gmail.com>
Author: Ryan Leckey 2020-06-12 15:20:24 -07:00
parent fedd883d91
commit c9f3e1adca
24 changed files with 922 additions and 861 deletions
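
For context, a minimal sketch (not part of this commit) of what the change enables from application code: binding and decoding Postgres range values through the new public PgRange type. The query is illustrative and conn is assumed to be an open PgConnection.

use sqlx::postgres::{types::PgRange, PgConnection};

async fn demo(conn: &mut PgConnection) -> Result<(), sqlx::Error> {
    // Rust range syntax converts into PgRange; PgRange<i64> maps to int8range
    let wanted: PgRange<i64> = (10..20).into();

    // on the decode side, the int8range comes back as a PgRange<i64>
    let echoed: PgRange<i64> = sqlx::query_scalar("SELECT $1::int8range")
        .bind(wanted)
        .fetch_one(conn)
        .await?;

    assert_eq!(echoed, PgRange::from(10_i64..20_i64));
    Ok(())
}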

View file

@ -33,6 +33,8 @@ pub trait Encode<'q, DB: Database> {
fn encode_by_ref(&self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull;
fn produces(&self) -> Option<DB::TypeInfo> {
// `produces` is inherently a hook to allow database drivers to produce value-dependent
// type information; if the driver doesn't need this, it can leave this as `None`
None
}
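
As an illustration of the hook, a hedged sketch of a driver-specific Encode impl that reports a value-dependent type; the MaybeCi wrapper and the citext/text choice are invented for the example and are not part of this commit.

use sqlx::encode::{Encode, IsNull};
use sqlx::postgres::{PgArgumentBuffer, PgTypeInfo, Postgres};

struct MaybeCi(String, /* case-insensitive? */ bool);

impl<'q> Encode<'q, Postgres> for MaybeCi {
    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
        // the wire format is the same either way, so delegate to String
        <String as Encode<'q, Postgres>>::encode_by_ref(&self.0, buf)
    }

    fn produces(&self) -> Option<PgTypeInfo> {
        // the reported type depends on the value itself, which is exactly
        // what the default None opts out of
        Some(PgTypeInfo::with_name(if self.1 { "citext" } else { "text" }))
    }
}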

View file

@ -11,7 +11,8 @@ use crate::database::Database;
pub type Result<T> = StdResult<T, Error>;
// Convenience type alias for usage within SQLx.
pub type BoxDynError = Box<dyn StdError + 'static + Send + Sync>;
// Do not make this type public.
pub(crate) type BoxDynError = Box<dyn StdError + 'static + Send + Sync>;
/// An unexpected `NULL` was encountered during decoding.
///

View file

@ -21,6 +21,15 @@ impl MySqlTypeInfo {
}
}
#[doc(hidden)]
pub const fn __enum() -> Self {
Self {
r#type: ColumnType::Enum,
flags: ColumnFlags::BINARY,
char_set: 63,
}
}
#[doc(hidden)]
pub fn __type_feature_gate(&self) -> Option<&'static str> {
match self.r#type {

View file

@ -47,26 +47,34 @@ impl<'q> Arguments<'q> for PgArguments {
self.types
.push(value.produces().unwrap_or_else(T::type_info));
// reserve space to write the prefixed length of the value
let offset = self.buffer.len();
self.buffer.extend(&[0; 4]);
// encode the value into our buffer
let len = if let IsNull::No = value.encode(&mut self.buffer) {
(self.buffer.len() - offset - 4) as i32
} else {
// Write a -1 to indicate NULL
// NOTE: It is illegal for [encode] to write any data
debug_assert_eq!(self.buffer.len(), offset + 4);
-1_i32
};
// write the len to the beginning of the value
self.buffer.buffer[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());
self.buffer.encode(value);
}
}
impl PgArgumentBuffer {
pub(crate) fn encode<'q, T>(&mut self, value: T)
where
T: Encode<'q, Postgres>,
{
// reserve space to write the prefixed length of the value
let offset = self.len();
self.extend(&[0; 4]);
// encode the value into our buffer
let len = if let IsNull::No = value.encode(self) {
(self.len() - offset - 4) as i32
} else {
// Write a -1 to indicate NULL
// NOTE: It is illegal for [encode] to write any data
debug_assert_eq!(self.len(), offset + 4);
-1_i32
};
// write the len to the beginning of the value
self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());
}
// Extends the inner buffer by enough space to have an OID
// Remembers where the OID goes and type name for the OID
pub(crate) fn push_type_hole(&mut self, type_name: &UStr) {
@ -81,7 +89,7 @@ impl PgArgumentBuffer {
pub(crate) async fn patch_type_holes(&mut self, conn: &mut PgConnection) -> Result<(), Error> {
for (offset, name) in &self.type_holes {
let oid = conn.fetch_type_id_by_name(&*name).await?;
self.buffer[*offset..].copy_from_slice(&oid.to_be_bytes());
self.buffer[*offset..(*offset + 4)].copy_from_slice(&oid.to_be_bytes());
}
Ok(())
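
The type-hole mechanism above is the usual reserve-then-patch buffer trick; a standalone sketch of the same pattern on a plain Vec<u8> (not sqlx internals):

fn main() {
    let mut buf: Vec<u8> = Vec::new();
    buf.extend_from_slice(b"HEAD");

    // reserve a 4-byte hole for a value we don't know yet and remember its offset
    let hole = buf.len();
    buf.extend_from_slice(&[0u8; 4]);
    buf.extend_from_slice(b"TAIL");

    // later, once the value is known (here a hypothetical type OID fetched
    // from the server), patch exactly those 4 bytes in big-endian order
    let oid: u32 = 3906;
    buf[hole..hole + 4].copy_from_slice(&oid.to_be_bytes());

    assert_eq!(&buf[hole..hole + 4], &oid.to_be_bytes());
}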

View file

@ -209,27 +209,31 @@ ORDER BY attnum
})
}
async fn fetch_range_by_oid(&mut self, oid: u32, name: String) -> Result<PgTypeInfo, Error> {
let _: i32 = query_scalar(
r#"
SELECT 1
fn fetch_range_by_oid(
&mut self,
oid: u32,
name: String,
) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
Box::pin(async move {
let element_oid: u32 = query_scalar(
r#"
SELECT rngsubtype
FROM pg_catalog.pg_range
WHERE rngtypid = $1
"#,
)
.bind(oid)
.fetch_one(self)
.await?;
"#,
)
.bind(oid)
.fetch_one(&mut *self)
.await?;
let pg_type = PgType::try_from_oid(oid).ok_or_else(|| {
err_protocol!("Trying to retrieve a DB type that doesn't exist in SQLx")
})?;
let element = self.maybe_fetch_type_info_by_oid(element_oid, true).await?;
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Range(PgTypeInfo(pg_type)),
name: name.into(),
oid,
}))))
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Range(element),
name: name.into(),
oid,
}))))
})
}
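
The change from async fn to a BoxFuture return type is the standard workaround for recursion: fetch_range_by_oid now resolves the element type through maybe_fetch_type_info_by_oid, which can lead back into this function, and recursive async calls need a boxed future to keep the type finite. A minimal illustration of the pattern (not the sqlx code):

use futures::future::BoxFuture;

fn countdown(n: u32) -> BoxFuture<'static, u32> {
    Box::pin(async move {
        if n == 0 {
            0
        } else {
            // the recursive call goes through a boxed, type-erased future
            countdown(n - 1).await
        }
    })
}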
pub(crate) async fn fetch_type_id_by_name(&mut self, name: &str) -> Result<u32, Error> {

View file

@ -28,15 +28,22 @@ async fn prepare(
// additional queries here to get any missing OIDs
let mut param_types = Vec::with_capacity(arguments.types.len());
let mut has_fetched = false;
for ty in &arguments.types {
param_types.push(if let PgType::DeclareWithName(name) = &ty.0 {
has_fetched = true;
conn.fetch_type_id_by_name(name).await?
} else {
ty.0.oid()
});
}
// flush and wait until we are re-ready
if has_fetched {
conn.wait_until_ready().await?;
}
// next we send the PARSE command to the server
conn.stream.write(Parse {
param_types: &*param_types,
@ -111,6 +118,18 @@ impl PgConnection {
// patch holes created during encoding
arguments.buffer.patch_type_holes(self).await?;
// describe the statement and, again, ask the server to immediately respond
// we need to fully realize the types
self.stream.write(message::Describe::Statement(statement));
self.stream.write(message::Flush);
self.stream.flush().await?;
let _ = recv_desc_params(self).await?;
let rows = recv_desc_rows(self).await?;
self.handle_row_description(rows, true).await?;
self.wait_until_ready().await?;
// bind to attach the arguments to the statement and create a portal
self.stream.write(Bind {
portal: None,
@ -121,17 +140,6 @@ impl PgConnection {
result_formats: &[PgValueFormat::Binary],
});
// describe the portal and, again, ask the server to immediately respond
// we need to fully realize the types
self.stream.write(message::Describe::UnnamedPortal);
self.stream.write(Flush);
self.stream.flush().await?;
let _ = self.stream.recv_expect(MessageFormat::BindComplete).await?;
let rows = recv_desc_rows(self).await?;
self.handle_row_description(rows, true).await?;
// executes the portal up to the passed limit
// the protocol-level limit acts nearly identically to the `LIMIT` in SQL
self.stream.write(message::Execute {

View file

@ -11,7 +11,7 @@ use crate::type_info::TypeInfo;
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
pub struct PgTypeInfo(pub(crate) PgType);
#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))]
#[repr(u32)]
pub(crate) enum PgType {
@ -197,7 +197,12 @@ impl PgTypeInfo {
Self(PgType::DeclareWithName(UStr::Static(name)))
}
pub(crate) const fn with_oid(oid: u32) -> Self {
/// Create a `PgTypeInfo` from an OID.
///
/// Note that the OID for a type is very dependent on the environment. If you only ever use
/// one database or if this is an unhandled built-in type, you should be fine. Otherwise,
/// you will be better served using [`with_name`](#method.with_name).
pub const fn with_oid(oid: u32) -> Self {
Self(PgType::DeclareWithOid(oid))
}
}
@ -308,7 +313,14 @@ impl PgType {
}
pub(crate) fn oid(&self) -> u32 {
match self {
match self.try_oid() {
Some(oid) => oid,
None => unreachable!("(bug) use of unresolved type declaration [oid]"),
}
}
pub(crate) fn try_oid(&self) -> Option<u32> {
Some(match self {
PgType::Bool => 16,
PgType::Bytea => 17,
PgType::Char => 18,
@ -400,8 +412,10 @@ impl PgType {
PgType::Custom(ty) => ty.oid,
PgType::DeclareWithOid(oid) => *oid,
PgType::DeclareWithName(_) => unreachable!("(bug) use of unresolved type declaration"),
}
PgType::DeclareWithName(_) => {
return None;
}
})
}
pub(crate) fn name(&self) -> &str {
@ -576,24 +590,24 @@ impl PgType {
PgType::UuidArray => &PgTypeKind::Array(PgTypeInfo(PgType::Uuid)),
PgType::Jsonb => &PgTypeKind::Simple,
PgType::JsonbArray => &PgTypeKind::Array(PgTypeInfo(PgType::Jsonb)),
PgType::Int4Range => &PgTypeKind::Simple,
PgType::Int4Range => &PgTypeKind::Range(PgTypeInfo::INT4),
PgType::Int4RangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::Int4Range)),
PgType::NumRange => &PgTypeKind::Simple,
PgType::NumRange => &PgTypeKind::Range(PgTypeInfo::NUMERIC),
PgType::NumRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::NumRange)),
PgType::TsRange => &PgTypeKind::Simple,
PgType::TsRange => &PgTypeKind::Range(PgTypeInfo::TIMESTAMP),
PgType::TsRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::TsRange)),
PgType::TstzRange => &PgTypeKind::Simple,
PgType::TstzRange => &PgTypeKind::Range(PgTypeInfo::TIMESTAMPTZ),
PgType::TstzRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::TstzRange)),
PgType::DateRange => &PgTypeKind::Simple,
PgType::DateRange => &PgTypeKind::Range(PgTypeInfo::DATE),
PgType::DateRangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::DateRange)),
PgType::Int8Range => &PgTypeKind::Simple,
PgType::Int8Range => &PgTypeKind::Range(PgTypeInfo::INT8),
PgType::Int8RangeArray => &PgTypeKind::Array(PgTypeInfo(PgType::Int8Range)),
PgType::Jsonpath => &PgTypeKind::Simple,
PgType::JsonpathArray => &PgTypeKind::Array(PgTypeInfo(PgType::Jsonpath)),
PgType::Custom(ty) => &ty.kind,
PgType::DeclareWithOid(_) | PgType::DeclareWithName(_) => {
unreachable!("(bug) use of unresolved type declaration")
unreachable!("(bug) use of unresolved type declaration [kind]")
}
}
}
@ -817,3 +831,15 @@ impl Display for PgTypeInfo {
f.pad(self.0.name())
}
}
impl PartialEq<PgType> for PgType {
fn eq(&self, other: &PgType) -> bool {
if let (Some(a), Some(b)) = (self.try_oid(), other.try_oid()) {
// If there are OIDs available, use OIDs to perform a direct match
a == b
} else {
// Otherwise, perform a match on the name
self.name().eq_ignore_ascii_case(other.name())
}
}
}
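
The rule above (compare by OID when both sides are resolved, otherwise fall back to a case-insensitive name match) in a standalone sketch, since PgType itself is crate-private:

fn types_equal(a: (Option<u32>, &str), b: (Option<u32>, &str)) -> bool {
    match (a.0, b.0) {
        // both OIDs known: compare numerically, the names are ignored
        (Some(x), Some(y)) => x == y,
        // at least one side is an unresolved declaration: fall back to the name
        _ => a.1.eq_ignore_ascii_case(b.1),
    }
}

fn main() {
    assert!(types_equal((Some(23), "INT4"), (Some(23), "int4")));
    assert!(types_equal((None, "float_range"), (None, "FLOAT_RANGE")));
}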

View file

@ -63,22 +63,7 @@ where
buf.extend(&1_i32.to_be_bytes()); // lower bound
for element in self.iter() {
// allocate space for the length of the encoded element
let el_len_offset = buf.len();
buf.extend(&0_i32.to_be_bytes());
let el_start = buf.len();
if let IsNull::Yes = element.encode_by_ref(buf) {
// NULL is encoded as -1 for a length
buf[el_len_offset..el_start].copy_from_slice(&(-1_i32).to_be_bytes());
} else {
let el_end = buf.len();
let el_len = el_end - el_start;
// now we can go back and update the length
buf[el_len_offset..el_start].copy_from_slice(&(el_len as i32).to_be_bytes());
}
buf.encode(element);
}
IsNull::No
@ -144,23 +129,11 @@ where
let mut elements = Vec::with_capacity(len as usize);
for _ in 0..len {
let mut element_len = buf.get_i32();
let element_val = if element_len == -1 {
element_len = 0;
None
} else {
Some(&buf[..(element_len as usize)])
};
elements.push(T::decode(PgValueRef {
value: element_val,
row: None,
type_info: element_type_info.clone(),
elements.push(T::decode(PgValueRef::get(
&mut buf,
format,
})?);
buf.advance(element_len as usize);
element_type_info.clone(),
))?)
}
Ok(elements)

View file

@ -132,8 +132,8 @@ mod array;
mod bool;
mod bytes;
mod float;
mod num;
mod ranges;
mod int;
mod range;
mod record;
mod str;
mod tuple;
@ -159,7 +159,9 @@ mod json;
#[cfg(feature = "ipnetwork")]
mod ipnetwork;
pub use {
ranges::{pg_range::PgRange, pg_ranges::*},
record::{PgRecordDecoder, PgRecordEncoder},
};
pub use range::PgRange;
// used in derive(Type) for `struct`
// but the interface is not considered part of the public API
#[doc(hidden)]
pub use record::{PgRecordDecoder, PgRecordEncoder};

View file

@ -0,0 +1,530 @@
use std::fmt::{self, Debug, Display, Formatter};
use std::ops::{Bound, Range, RangeBounds, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive};
use bitflags::bitflags;
use bytes::Buf;
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::postgres::{
PgArgumentBuffer, PgTypeInfo, PgTypeKind, PgValueFormat, PgValueRef, Postgres,
};
use crate::types::Type;
// https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/include/utils/rangetypes.h#L35-L44
bitflags! {
struct RangeFlags: u8 {
const EMPTY = 0x01;
const LB_INC = 0x02;
const UB_INC = 0x04;
const LB_INF = 0x08;
const UB_INF = 0x10;
const LB_NULL = 0x20; // not used
const UB_NULL = 0x40; // not used
const CONTAIN_EMPTY = 0x80; // internal
}
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct PgRange<T> {
pub start: Bound<T>,
pub end: Bound<T>,
}
impl<T> From<[Bound<T>; 2]> for PgRange<T> {
fn from(v: [Bound<T>; 2]) -> Self {
let [start, end] = v;
Self { start, end }
}
}
impl<T> From<(Bound<T>, Bound<T>)> for PgRange<T> {
fn from(v: (Bound<T>, Bound<T>)) -> Self {
Self {
start: v.0,
end: v.1,
}
}
}
impl<T> From<Range<T>> for PgRange<T> {
fn from(v: Range<T>) -> Self {
Self {
start: Bound::Included(v.start),
end: Bound::Excluded(v.end),
}
}
}
impl<T> From<RangeFrom<T>> for PgRange<T> {
fn from(v: RangeFrom<T>) -> Self {
Self {
start: Bound::Included(v.start),
end: Bound::Unbounded,
}
}
}
impl<T> From<RangeInclusive<T>> for PgRange<T> {
fn from(v: RangeInclusive<T>) -> Self {
let (start, end) = v.into_inner();
Self {
start: Bound::Included(start),
end: Bound::Included(end),
}
}
}
impl<T> From<RangeTo<T>> for PgRange<T> {
fn from(v: RangeTo<T>) -> Self {
Self {
start: Bound::Unbounded,
end: Bound::Excluded(v.end),
}
}
}
impl<T> From<RangeToInclusive<T>> for PgRange<T> {
fn from(v: RangeToInclusive<T>) -> Self {
Self {
start: Bound::Unbounded,
end: Bound::Included(v.end),
}
}
}
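// For reference (not part of the diff): a quick check of how the conversions
// above map Rust range syntax onto PgRange bounds.
#[cfg(test)]
mod conversion_sketch {
    use super::*;

    #[test]
    fn rust_ranges_map_to_expected_bounds() {
        let half_open: PgRange<i32> = (1..5).into();
        assert_eq!(half_open.start, Bound::Included(1));
        assert_eq!(half_open.end, Bound::Excluded(5));

        let closed: PgRange<i32> = (1..=5).into();
        assert_eq!(closed.end, Bound::Included(5));

        let from: PgRange<i32> = (1..).into();
        assert_eq!(from.end, Bound::Unbounded);
    }
}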
impl<T> RangeBounds<T> for PgRange<T> {
fn start_bound(&self) -> Bound<&T> {
match self.start {
Bound::Included(ref start) => Bound::Included(start),
Bound::Excluded(ref start) => Bound::Excluded(start),
Bound::Unbounded => Bound::Unbounded,
}
}
fn end_bound(&self) -> Bound<&T> {
match self.end {
Bound::Included(ref end) => Bound::Included(end),
Bound::Excluded(ref end) => Bound::Excluded(end),
Bound::Unbounded => Bound::Unbounded,
}
}
}
impl Type<Postgres> for PgRange<i32> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::INT4_RANGE
}
}
impl Type<Postgres> for PgRange<i64> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::INT8_RANGE
}
}
#[cfg(feature = "bigdecimal")]
impl Type<Postgres> for PgRange<bigdecimal::BigDecimal> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::NUM_RANGE
}
}
#[cfg(feature = "chrono")]
impl Type<Postgres> for PgRange<chrono::NaiveDate> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::DATE_RANGE
}
}
#[cfg(feature = "chrono")]
impl Type<Postgres> for PgRange<chrono::NaiveDateTime> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::TS_RANGE
}
}
#[cfg(feature = "chrono")]
impl<Tz: chrono::TimeZone> Type<Postgres> for PgRange<chrono::DateTime<Tz>> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::TSTZ_RANGE
}
}
#[cfg(feature = "time")]
impl Type<Postgres> for PgRange<time::Date> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::DATE_RANGE
}
}
#[cfg(feature = "time")]
impl Type<Postgres> for PgRange<time::PrimitiveDateTime> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::TS_RANGE
}
}
#[cfg(feature = "time")]
impl Type<Postgres> for PgRange<time::OffsetDateTime> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::TSTZ_RANGE
}
}
impl Type<Postgres> for [PgRange<i32>] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::INT4_RANGE_ARRAY
}
}
impl Type<Postgres> for [PgRange<i64>] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::INT8_RANGE_ARRAY
}
}
#[cfg(feature = "bigdecimal")]
impl Type<Postgres> for [PgRange<bigdecimal::BigDecimal>] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::NUM_RANGE_ARRAY
}
}
#[cfg(feature = "chrono")]
impl Type<Postgres> for [PgRange<chrono::NaiveDate>] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::DATE_RANGE_ARRAY
}
}
#[cfg(feature = "chrono")]
impl Type<Postgres> for [PgRange<chrono::NaiveDateTime>] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::TS_RANGE_ARRAY
}
}
#[cfg(feature = "chrono")]
impl<Tz: chrono::TimeZone> Type<Postgres> for [PgRange<chrono::DateTime<Tz>>] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::TSTZ_RANGE_ARRAY
}
}
#[cfg(feature = "time")]
impl Type<Postgres> for [PgRange<time::Date>] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::DATE_RANGE_ARRAY
}
}
#[cfg(feature = "time")]
impl Type<Postgres> for [PgRange<time::PrimitiveDateTime>] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::TS_RANGE_ARRAY
}
}
#[cfg(feature = "time")]
impl Type<Postgres> for [PgRange<time::OffsetDateTime>] {
fn type_info() -> PgTypeInfo {
PgTypeInfo::TSTZ_RANGE_ARRAY
}
}
impl Type<Postgres> for Vec<PgRange<i32>> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::INT4_RANGE_ARRAY
}
}
impl Type<Postgres> for Vec<PgRange<i64>> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::INT8_RANGE_ARRAY
}
}
#[cfg(feature = "bigdecimal")]
impl Type<Postgres> for Vec<PgRange<bigdecimal::BigDecimal>> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::NUM_RANGE_ARRAY
}
}
#[cfg(feature = "chrono")]
impl Type<Postgres> for Vec<PgRange<chrono::NaiveDate>> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::DATE_RANGE_ARRAY
}
}
#[cfg(feature = "chrono")]
impl Type<Postgres> for Vec<PgRange<chrono::NaiveDateTime>> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::TS_RANGE_ARRAY
}
}
#[cfg(feature = "chrono")]
impl<Tz: chrono::TimeZone> Type<Postgres> for Vec<PgRange<chrono::DateTime<Tz>>> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::TSTZ_RANGE_ARRAY
}
}
#[cfg(feature = "time")]
impl Type<Postgres> for Vec<PgRange<time::Date>> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::DATE_RANGE_ARRAY
}
}
#[cfg(feature = "time")]
impl Type<Postgres> for Vec<PgRange<time::PrimitiveDateTime>> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::TS_RANGE_ARRAY
}
}
#[cfg(feature = "time")]
impl Type<Postgres> for Vec<PgRange<time::OffsetDateTime>> {
fn type_info() -> PgTypeInfo {
PgTypeInfo::TSTZ_RANGE_ARRAY
}
}
impl<'q, T> Encode<'q, Postgres> for PgRange<T>
where
T: Encode<'q, Postgres>,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
// https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/backend/utils/adt/rangetypes.c#L245
let mut flags = RangeFlags::empty();
flags |= match self.start {
Bound::Included(_) => RangeFlags::LB_INC,
Bound::Unbounded => RangeFlags::LB_INF,
Bound::Excluded(_) => RangeFlags::empty(),
};
flags |= match self.end {
Bound::Included(_) => RangeFlags::UB_INC,
Bound::Unbounded => RangeFlags::UB_INF,
Bound::Excluded(_) => RangeFlags::empty(),
};
buf.push(flags.bits());
if let Bound::Included(v) | Bound::Excluded(v) = &self.start {
buf.encode(v);
}
if let Bound::Included(v) | Bound::Excluded(v) = &self.end {
buf.encode(v);
}
// ranges are themselves never null
IsNull::No
}
}
impl<'r, T> Decode<'r, Postgres> for PgRange<T>
where
T: Type<Postgres> + for<'a> Decode<'a, Postgres>,
{
fn accepts(ty: &PgTypeInfo) -> bool {
// we require the declared type to be a _range_ with an
// element type that is acceptable
if let PgTypeKind::Range(element) = &ty.0.kind() {
return T::accepts(&element);
}
false
}
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format {
PgValueFormat::Binary => {
let element_ty = if let PgTypeKind::Range(element) = &value.type_info.0.kind() {
element
} else {
return Err(format!("unexpected non-range type {}", value.type_info).into());
};
let mut buf = value.as_bytes()?;
let mut start = Bound::Unbounded;
let mut end = Bound::Unbounded;
let flags = RangeFlags::from_bits_truncate(buf.get_u8());
if flags.contains(RangeFlags::EMPTY) {
return Ok(PgRange { start, end });
}
if !flags.contains(RangeFlags::LB_INF) {
let value =
T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone()))?;
start = if flags.contains(RangeFlags::LB_INC) {
Bound::Included(value)
} else {
Bound::Excluded(value)
};
}
if !flags.contains(RangeFlags::UB_INF) {
let value =
T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone()))?;
end = if flags.contains(RangeFlags::UB_INC) {
Bound::Included(value)
} else {
Bound::Excluded(value)
};
}
Ok(PgRange { start, end })
}
PgValueFormat::Text => {
// https://github.com/postgres/postgres/blob/2f48ede080f42b97b594fb14102c82ca1001b80c/src/backend/utils/adt/rangetypes.c#L2046
let mut start = None;
let mut end = None;
let s = value.as_str()?;
// remember the bounds
let sb = s.as_bytes();
let lower = sb[0] as char;
let upper = sb[sb.len() - 1] as char;
// trim the wrapping braces/brackets
let s = &s[1..(s.len() - 1)];
let mut chars = s.chars();
let mut element = String::new();
let mut done = false;
let mut quoted = false;
let mut in_quotes = false;
let mut in_escape = false;
let mut prev_ch = '\0';
let mut count = 0;
while !done {
element.clear();
loop {
match chars.next() {
Some(ch) => {
match ch {
_ if in_escape => {
element.push(ch);
in_escape = false;
}
'"' if in_quotes => {
in_quotes = false;
}
'"' => {
in_quotes = true;
quoted = true;
if prev_ch == '"' {
element.push('"')
}
}
'\\' if !in_escape => {
in_escape = true;
}
',' if !in_quotes => break,
_ => {
element.push(ch);
}
}
prev_ch = ch;
}
None => {
done = true;
break;
}
}
}
count += 1;
if !(element.is_empty() && !quoted) {
let value = Some(T::decode(PgValueRef {
type_info: T::type_info(),
format: PgValueFormat::Text,
value: Some(element.as_bytes()),
row: None,
})?);
if count == 1 {
start = value;
} else if count == 2 {
end = value;
} else {
return Err("more than 2 elements found in a range".into());
}
}
}
let start = parse_bound(lower, start)?;
let end = parse_bound(upper, end)?;
Ok(PgRange { start, end })
}
}
}
}
fn parse_bound<T>(ch: char, value: Option<T>) -> Result<Bound<T>, BoxDynError> {
Ok(if let Some(value) = value {
match ch {
'(' | ')' => Bound::Excluded(value),
'[' | ']' => Bound::Included(value),
_ => {
return Err(format!(
"expected `(`, `)`, `[`, or `]` but found `{}` for range literal",
ch
)
.into());
}
}
} else {
Bound::Unbounded
})
}
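// For reference (not part of the diff): how the bound characters of the
// Postgres text format map through parse_bound; "[1,5)" reads back as
// Included(1)..Excluded(5), and "(,)" is unbounded on both sides.
#[cfg(test)]
mod parse_bound_sketch {
    use super::*;

    #[test]
    fn bound_characters_map_as_documented() {
        // "[1,5)": inclusive lower bound, exclusive upper bound
        assert_eq!(parse_bound('[', Some(1)).unwrap(), Bound::Included(1));
        assert_eq!(parse_bound(')', Some(5)).unwrap(), Bound::Excluded(5));

        // a missing value means that side is unbounded, whatever the character
        assert_eq!(parse_bound::<i32>('(', None).unwrap(), Bound::Unbounded);
    }
}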
impl<T> Display for PgRange<T>
where
T: Display,
{
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match &self.start {
Bound::Unbounded => f.write_str("(,")?,
Bound::Excluded(v) => write!(f, "({},", v)?,
Bound::Included(v) => write!(f, "[{},", v)?,
}
match &self.end {
Bound::Unbounded => f.write_str(")")?,
Bound::Excluded(v) => write!(f, "{})", v)?,
Bound::Included(v) => write!(f, "{}]", v)?,
}
Ok(())
}
}
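// For reference (not part of the diff): the Display impl above produces the
// same text form Postgres itself uses for range literals.
#[cfg(test)]
mod display_sketch {
    use super::*;

    #[test]
    fn displays_in_postgres_text_form() {
        let half_open: PgRange<i32> = (1..5).into();
        assert_eq!(half_open.to_string(), "[1,5)");

        let open_start: PgRange<i32> = (..5).into();
        assert_eq!(open_start.to_string(), "(,5)");
    }
}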

View file

@ -1,87 +0,0 @@
pub(crate) mod pg_range;
pub(crate) mod pg_ranges;
use crate::{
decode::Decode,
encode::{Encode, IsNull},
postgres::{types::PgRange, PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres},
types::Type,
};
use core::{
convert::TryInto,
ops::{Range, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive},
};
macro_rules! impl_range {
($range:ident) => {
impl<'a, T> Decode<'a, Postgres> for $range<T>
where
T: for<'b> Decode<'b, Postgres> + Type<Postgres> + 'a,
{
fn accepts(ty: &PgTypeInfo) -> bool {
<PgRange<T> as Decode<'_, Postgres>>::accepts(ty)
}
fn decode(value: PgValueRef<'a>) -> Result<$range<T>, crate::error::BoxDynError> {
let bounds: PgRange<T> = Decode::<Postgres>::decode(value)?;
let rslt = bounds.try_into()?;
Ok(rslt)
}
}
impl<'a, T> Encode<'a, Postgres> for $range<T>
where
T: Clone + for<'b> Encode<'b, Postgres> + 'a,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
<PgRange<T> as Encode<'_, Postgres>>::encode(self.clone().into(), buf)
}
}
};
}
impl_range!(Range);
impl_range!(RangeFrom);
impl_range!(RangeInclusive);
impl_range!(RangeTo);
impl_range!(RangeToInclusive);
#[test]
fn test_decode_str_bounds() {
use crate::postgres::type_info::PgType;
const EXC1: Bound<i32> = Bound::Excluded(1);
const EXC2: Bound<i32> = Bound::Excluded(2);
const INC1: Bound<i32> = Bound::Included(1);
const INC2: Bound<i32> = Bound::Included(2);
const UNB: Bound<i32> = Bound::Unbounded;
let check = |s: &str, range_cmp: [Bound<i32>; 2]| {
let pg_value = PgValueRef {
type_info: PgTypeInfo(PgType::Int4Range),
format: PgValueFormat::Text,
value: Some(s.as_bytes()),
row: None,
};
let range: PgRange<i32> = Decode::<Postgres>::decode(pg_value).unwrap();
assert_eq!(Into::<[Bound<i32>; 2]>::into(range), range_cmp);
};
check("(,)", [UNB, UNB]);
check("(,]", [UNB, UNB]);
check("(,2)", [UNB, EXC2]);
check("(,2]", [UNB, INC2]);
check("(1,)", [EXC1, UNB]);
check("(1,]", [EXC1, UNB]);
check("(1,2)", [EXC1, EXC2]);
check("(1,2]", [EXC1, INC2]);
check("[,)", [UNB, UNB]);
check("[,]", [UNB, UNB]);
check("[,2)", [UNB, EXC2]);
check("[,2]", [UNB, INC2]);
check("[1,)", [INC1, UNB]);
check("[1,]", [INC1, UNB]);
check("[1,2)", [INC1, EXC2]);
check("[1,2]", [INC1, INC2]);
}

View file

@ -1,385 +0,0 @@
use crate::{
decode::Decode,
encode::{Encode, IsNull},
postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres},
types::Type,
};
use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
use core::{
convert::TryFrom,
ops::{Bound, Range, RangeBounds, RangeFrom, RangeInclusive, RangeTo, RangeToInclusive},
};
bitflags::bitflags! {
struct RangeFlags: u8 {
const EMPTY = 0x01;
const LB_INC = 0x02;
const UB_INC = 0x04;
const LB_INF = 0x08;
const UB_INF = 0x10;
const LB_NULL = 0x20;
const UB_NULL = 0x40;
const CONTAIN_EMPTY = 0x80;
}
}
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct PgRange<T> {
pub start: Bound<T>,
pub end: Bound<T>,
}
impl<T> PgRange<T> {
pub fn new(start: Bound<T>, end: Bound<T>) -> Self {
Self { start, end }
}
}
impl<'a, T> Decode<'a, Postgres> for PgRange<T>
where
T: for<'b> Decode<'b, Postgres> + Type<Postgres> + 'a,
{
fn accepts(ty: &PgTypeInfo) -> bool {
[
PgTypeInfo::INT4_RANGE,
PgTypeInfo::NUM_RANGE,
PgTypeInfo::TS_RANGE,
PgTypeInfo::TSTZ_RANGE,
PgTypeInfo::DATE_RANGE,
PgTypeInfo::INT8_RANGE,
]
.contains(ty)
}
fn decode(value: PgValueRef<'a>) -> Result<PgRange<T>, crate::error::BoxDynError> {
match value.format() {
PgValueFormat::Binary => {
decode_binary(value.as_bytes()?, value.format, value.type_info)
}
PgValueFormat::Text => decode_str(value.as_str()?, value.format(), value.type_info),
}
}
}
impl<'a, T> Encode<'a, Postgres> for PgRange<T>
where
T: for<'b> Encode<'b, Postgres> + 'a,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
let mut flags = match self.start {
Bound::Included(_) => RangeFlags::LB_INC,
Bound::Excluded(_) => RangeFlags::empty(),
Bound::Unbounded => RangeFlags::LB_INF,
};
flags |= match self.end {
Bound::Included(_) => RangeFlags::UB_INC,
Bound::Excluded(_) => RangeFlags::empty(),
Bound::Unbounded => RangeFlags::UB_INF,
};
buf.write_u8(flags.bits()).unwrap();
let mut write = |bound: &Bound<T>| -> IsNull {
match bound {
Bound::Included(ref value) | Bound::Excluded(ref value) => {
buf.write_u32::<NetworkEndian>(0).unwrap();
let prev = buf.len();
if let IsNull::Yes = Encode::<Postgres>::encode(value, buf) {
return IsNull::Yes;
}
let len = buf.len() - prev;
buf[prev - 4..prev].copy_from_slice(&(len as u32).to_be_bytes());
}
Bound::Unbounded => {}
}
IsNull::No
};
if let IsNull::Yes = write(&self.start) {
return IsNull::Yes;
}
write(&self.end)
}
}
impl<T> From<[Bound<T>; 2]> for PgRange<T> {
fn from(from: [Bound<T>; 2]) -> Self {
let [start, end] = from;
Self { start, end }
}
}
impl<T> From<(Bound<T>, Bound<T>)> for PgRange<T> {
fn from(from: (Bound<T>, Bound<T>)) -> Self {
Self {
start: from.0,
end: from.1,
}
}
}
impl<T> From<PgRange<T>> for [Bound<T>; 2] {
fn from(from: PgRange<T>) -> Self {
[from.start, from.end]
}
}
impl<T> From<PgRange<T>> for (Bound<T>, Bound<T>) {
fn from(from: PgRange<T>) -> Self {
(from.start, from.end)
}
}
impl<T> From<Range<T>> for PgRange<T> {
fn from(from: Range<T>) -> Self {
Self {
start: Bound::Included(from.start),
end: Bound::Excluded(from.end),
}
}
}
impl<T> From<RangeFrom<T>> for PgRange<T> {
fn from(from: RangeFrom<T>) -> Self {
Self {
start: Bound::Included(from.start),
end: Bound::Unbounded,
}
}
}
impl<T> From<RangeInclusive<T>> for PgRange<T> {
fn from(from: RangeInclusive<T>) -> Self {
let (start, end) = from.into_inner();
Self {
start: Bound::Included(start),
end: Bound::Excluded(end),
}
}
}
impl<T> From<RangeTo<T>> for PgRange<T> {
fn from(from: RangeTo<T>) -> Self {
Self {
start: Bound::Unbounded,
end: Bound::Excluded(from.end),
}
}
}
impl<T> From<RangeToInclusive<T>> for PgRange<T> {
fn from(from: RangeToInclusive<T>) -> Self {
Self {
start: Bound::Unbounded,
end: Bound::Included(from.end),
}
}
}
impl<T> RangeBounds<T> for PgRange<T> {
fn start_bound(&self) -> Bound<&T> {
match &self.start {
Bound::Included(ref start) => Bound::Included(start),
Bound::Excluded(ref start) => Bound::Excluded(start),
Bound::Unbounded => Bound::Unbounded,
}
}
fn end_bound(&self) -> Bound<&T> {
match &self.end {
Bound::Included(ref end) => Bound::Included(end),
Bound::Excluded(ref end) => Bound::Excluded(end),
Bound::Unbounded => Bound::Unbounded,
}
}
}
impl<T> TryFrom<PgRange<T>> for Range<T> {
type Error = crate::error::Error;
fn try_from(from: PgRange<T>) -> crate::error::Result<Self> {
let err_msg = "Invalid data for core::ops::Range";
let start = included(from.start, err_msg)?;
let end = excluded(from.end, err_msg)?;
Ok(start..end)
}
}
impl<T> TryFrom<PgRange<T>> for RangeFrom<T> {
type Error = crate::error::Error;
fn try_from(from: PgRange<T>) -> crate::error::Result<Self> {
let err_msg = "Invalid data for core::ops::RangeFrom";
let start = included(from.start, err_msg)?;
unbounded(from.end, err_msg)?;
Ok(start..)
}
}
impl<T> TryFrom<PgRange<T>> for RangeInclusive<T> {
type Error = crate::error::Error;
fn try_from(from: PgRange<T>) -> crate::error::Result<Self> {
let err_msg = "Invalid data for core::ops::RangeInclusive";
let start = included(from.start, err_msg)?;
let end = included(from.end, err_msg)?;
Ok(start..=end)
}
}
impl<T> TryFrom<PgRange<T>> for RangeTo<T> {
type Error = crate::error::Error;
fn try_from(from: PgRange<T>) -> crate::error::Result<Self> {
let err_msg = "Invalid data for core::ops::RangeTo";
unbounded(from.start, err_msg)?;
let end = excluded(from.end, err_msg)?;
Ok(..end)
}
}
impl<T> TryFrom<PgRange<T>> for RangeToInclusive<T> {
type Error = crate::error::Error;
fn try_from(from: PgRange<T>) -> crate::error::Result<Self> {
let err_msg = "Invalid data for core::ops::RangeToInclusive";
unbounded(from.start, err_msg)?;
let end = included(from.end, err_msg)?;
Ok(..=end)
}
}
fn decode_binary<'r, T>(
mut bytes: &[u8],
format: PgValueFormat,
type_info: PgTypeInfo,
) -> Result<PgRange<T>, crate::error::BoxDynError>
where
T: for<'rec> Decode<'rec, Postgres> + 'r,
{
let flags: RangeFlags = RangeFlags::from_bits_truncate(bytes.read_u8()?);
let mut start_value = Bound::Unbounded;
let mut end_value = Bound::Unbounded;
if flags.contains(RangeFlags::EMPTY) {
return Ok(PgRange {
start: start_value,
end: end_value,
});
}
if !flags.contains(RangeFlags::LB_INF) {
let elem_size = bytes.read_i32::<NetworkEndian>()?;
let (elem_bytes, new_bytes) = bytes.split_at(elem_size as usize);
bytes = new_bytes;
let value = T::decode(PgValueRef {
type_info: type_info.clone(),
format,
value: Some(elem_bytes),
row: None,
})?;
start_value = if flags.contains(RangeFlags::LB_INC) {
Bound::Included(value)
} else {
Bound::Excluded(value)
};
}
if !flags.contains(RangeFlags::UB_INF) {
bytes.read_i32::<NetworkEndian>()?;
let value = T::decode(PgValueRef {
type_info,
format,
value: Some(bytes),
row: None,
})?;
end_value = if flags.contains(RangeFlags::UB_INC) {
Bound::Included(value)
} else {
Bound::Excluded(value)
};
}
Ok(PgRange {
start: start_value,
end: end_value,
})
}
fn decode_str<'r, T>(
s: &str,
format: PgValueFormat,
type_info: PgTypeInfo,
) -> Result<PgRange<T>, crate::error::BoxDynError>
where
T: for<'rec> Decode<'rec, Postgres> + 'r,
{
let err = || crate::error::Error::Decode("Invalid PostgreSQL range string".into());
let value =
|bound: &str, delim, bounds: [&str; 2]| -> Result<Bound<T>, crate::error::BoxDynError> {
if bound.len() == 0 {
return Ok(Bound::Unbounded);
}
let bound_value = T::decode(PgValueRef {
type_info: type_info.clone(),
format,
value: Some(bound.as_bytes()),
row: None,
})?;
if delim == bounds[0] {
Ok(Bound::Excluded(bound_value))
} else if delim == bounds[1] {
Ok(Bound::Included(bound_value))
} else {
Err(Box::new(err()))
}
};
let mut parts = s.split(',');
let start_str = parts.next().ok_or_else(err)?;
let start_value = value(
start_str.get(1..).ok_or_else(err)?,
start_str.get(0..1).ok_or_else(err)?,
["(", "["],
)?;
let end_str = parts.next().ok_or_else(err)?;
let last_char_idx = end_str.len() - 1;
let end_value = value(
end_str.get(..last_char_idx).ok_or_else(err)?,
end_str.get(last_char_idx..).ok_or_else(err)?,
[")", "]"],
)?;
Ok(PgRange {
start: start_value,
end: end_value,
})
}
fn excluded<T>(b: Bound<T>, err_msg: &str) -> crate::error::Result<T> {
if let Bound::Excluded(rslt) = b {
Ok(rslt)
} else {
Err(crate::error::Error::Decode(err_msg.into()))
}
}
fn included<T>(b: Bound<T>, err_msg: &str) -> crate::error::Result<T> {
if let Bound::Included(rslt) = b {
Ok(rslt)
} else {
Err(crate::error::Error::Decode(err_msg.into()))
}
}
fn unbounded<T>(b: Bound<T>, err_msg: &str) -> crate::error::Result<()> {
if matches!(b, Bound::Unbounded) {
Ok(())
} else {
Err(crate::error::Error::Decode(err_msg.into()))
}
}

View file

@ -1,84 +0,0 @@
use crate::{
decode::Decode,
encode::{Encode, IsNull},
postgres::{
types::ranges::pg_range::PgRange, PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres,
},
types::Type,
};
macro_rules! impl_pg_range {
($range_name:ident, $type_info:expr, $type_info_array:expr, $range_type:ty) => {
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
#[repr(transparent)]
pub struct $range_name(pub PgRange<$range_type>);
impl<'a> Decode<'a, Postgres> for $range_name {
fn accepts(ty: &PgTypeInfo) -> bool {
<PgRange<$range_type> as Decode<'_, Postgres>>::accepts(ty)
}
fn decode(value: PgValueRef<'a>) -> Result<$range_name, crate::error::BoxDynError> {
Ok(Self(Decode::<Postgres>::decode(value)?))
}
}
impl<'a> Encode<'a, Postgres> for $range_name {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
<PgRange<$range_type> as Encode<'_, Postgres>>::encode_by_ref(&self.0, buf)
}
}
impl Type<Postgres> for $range_name {
fn type_info() -> PgTypeInfo {
$type_info
}
}
impl Type<Postgres> for [$range_name] {
fn type_info() -> PgTypeInfo {
$type_info_array
}
}
impl Type<Postgres> for Vec<$range_name> {
fn type_info() -> PgTypeInfo {
$type_info_array
}
}
};
}
impl_pg_range!(
Int4Range,
PgTypeInfo::INT4_RANGE,
PgTypeInfo::INT4_RANGE_ARRAY,
i32
);
#[cfg(feature = "bigdecimal")]
impl_pg_range!(
NumRange,
PgTypeInfo::NUM_RANGE,
PgTypeInfo::NUM_RANGE_ARRAY,
bigdecimal::BigDecimal
);
#[cfg(feature = "chrono")]
impl_pg_range!(
TsRange,
PgTypeInfo::TS_RANGE,
PgTypeInfo::TS_RANGE_ARRAY,
chrono::NaiveDateTime
);
#[cfg(feature = "chrono")]
impl_pg_range!(
DateRange,
PgTypeInfo::DATE_RANGE,
PgTypeInfo::DATE_RANGE_ARRAY,
chrono::NaiveDate
);
impl_pg_range!(
Int8Range,
PgTypeInfo::INT8_RANGE,
PgTypeInfo::INT8_RANGE_ARRAY,
i64
);

View file

@ -1,7 +1,7 @@
use bytes::Buf;
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::encode::Encode;
use crate::error::{mismatched_types, BoxDynError};
use crate::postgres::type_info::PgType;
use crate::postgres::{
@ -30,7 +30,7 @@ impl<'a> PgRecordEncoder<'a> {
#[doc(hidden)]
pub fn finish(&mut self) {
// fill in the record length
self.buf[self.off..].copy_from_slice(&self.num.to_be_bytes());
self.buf[self.off..(self.off + 4)].copy_from_slice(&self.num.to_be_bytes());
}
#[doc(hidden)]
@ -50,16 +50,8 @@ impl<'a> PgRecordEncoder<'a> {
self.buf.extend(&ty.0.oid().to_be_bytes());
}
let offset = self.buf.len();
self.buf.extend(&(0_u32).to_be_bytes());
let size = if let IsNull::Yes = value.encode(self.buf) {
-1
} else {
(self.buf.len() - offset + 4) as i32
};
self.buf[offset..].copy_from_slice(&size.to_be_bytes());
self.buf.encode(value);
self.num += 1;
self
}
@ -133,6 +125,8 @@ impl<'r> PgRecordDecoder<'r> {
}
};
self.ind += 1;
if let Some(ty) = &element_type_opt {
if !T::accepts(ty) {
return Err(mismatched_types::<Postgres, T>(&T::type_info(), ty));
@ -142,23 +136,7 @@ impl<'r> PgRecordDecoder<'r> {
let element_type =
element_type_opt.unwrap_or_else(|| PgTypeInfo::with_oid(element_type_oid));
let mut element_len = self.buf.get_i32();
let element_buf = if element_len < 0 {
element_len = 0;
None
} else {
Some(&self.buf[..(element_len as usize)])
};
self.buf.advance(element_len as usize);
self.ind += 1;
T::decode(PgValueRef {
type_info: element_type,
format: self.fmt,
value: element_buf,
row: None,
})
T::decode(PgValueRef::get(&mut self.buf, self.fmt, element_type))
}
PgValueFormat::Text => {

View file

@ -1,7 +1,7 @@
use std::borrow::Cow;
use std::str::from_utf8;
use bytes::Bytes;
use bytes::{Buf, Bytes};
use crate::error::{BoxDynError, UnexpectedNullError};
use crate::postgres::{PgTypeInfo, Postgres};
@ -32,6 +32,26 @@ pub struct PgValue {
}
impl<'r> PgValueRef<'r> {
pub(crate) fn get(buf: &mut &'r [u8], format: PgValueFormat, ty: PgTypeInfo) -> Self {
let mut element_len = buf.get_i32();
let element_val = if element_len == -1 {
element_len = 0;
None
} else {
Some(&buf[..(element_len as usize)])
};
buf.advance(element_len as usize);
PgValueRef {
value: element_val,
row: None,
type_info: ty,
format,
}
}
pub(crate) fn format(&self) -> PgValueFormat {
self.format
}
@ -62,7 +82,13 @@ impl Value for PgValue {
}
fn type_info(&self) -> Option<Cow<'_, PgTypeInfo>> {
Some(Cow::Borrowed(&self.type_info))
if self.format == PgValueFormat::Text {
// For TEXT encoding, the type defined on the value is unreliable
// We don't even bother to return it, so type checking is implicitly opted out
None
} else {
Some(Cow::Borrowed(&self.type_info))
}
}
fn is_null(&self) -> bool {
@ -90,7 +116,13 @@ impl<'r> ValueRef<'r> for PgValueRef<'r> {
}
fn type_info(&self) -> Option<Cow<'_, PgTypeInfo>> {
Some(Cow::Borrowed(&self.type_info))
if self.format == PgValueFormat::Text {
// For TEXT encoding, the type defined on the value is unreliable
// We don't even bother to return it, so type checking is implicitly opted out
None
} else {
Some(Cow::Borrowed(&self.type_info))
}
}
fn is_null(&self) -> bool {

View file

@ -146,15 +146,12 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttri
Ok(SqlxChildAttributes { rename })
}
pub fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::Result<()> {
pub fn check_transparent_attributes(
input: &DeriveInput,
field: &Field,
) -> syn::Result<SqlxContainerAttributes> {
let attributes = parse_container_attributes(&input.attrs)?;
assert_attribute!(
attributes.transparent,
"expected #[sqlx(transparent)]",
input
);
assert_attribute!(
attributes.rename_all.is_none(),
"unexpected #[sqlx(rename_all = ..)]",
@ -163,15 +160,15 @@ pub fn check_transparent_attributes(input: &DeriveInput, field: &Field) -> syn::
assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input);
let attributes = parse_child_attributes(&field.attrs)?;
let ch_attributes = parse_child_attributes(&field.attrs)?;
assert_attribute!(
attributes.rename.is_none(),
ch_attributes.rename.is_none(),
"unexpected #[sqlx(rename = ..)]",
field
);
Ok(())
Ok(attributes)
}
pub fn check_enum_attributes<'a>(input: &'a DeriveInput) -> syn::Result<SqlxContainerAttributes> {

View file

@ -62,21 +62,21 @@ fn expand_derive_decode_transparent(
// add db type for impl generics & where clause
let mut generics = generics.clone();
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics.params.insert(0, parse_quote!('de));
generics.params.insert(0, parse_quote!('r));
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::decode::Decode<'de, DB>));
.push(parse_quote!(#ty: sqlx::decode::Decode<'r, DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
let tts = quote!(
impl #impl_generics sqlx::decode::Decode<'de, DB> for #ident #ty_generics #where_clause {
impl #impl_generics sqlx::decode::Decode<'r, DB> for #ident #ty_generics #where_clause {
fn accepts(ty: &DB::TypeInfo) -> bool {
<#ty as sqlx::decode::Decode<'de, DB>>::accepts(ty)
<#ty as sqlx::decode::Decode<'r, DB>>::accepts(ty)
}
fn decode(value: <DB as sqlx::database::HasValueRef<'de>>::ValueRef) -> std::result::Result<Self, sqlx::BoxDynError> {
<#ty as sqlx::decode::Decode<'de, DB>>::decode(value).map(Self)
fn decode(value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
<#ty as sqlx::decode::Decode<'r, DB>>::decode(value).map(Self)
}
}
);
@ -103,13 +103,13 @@ fn expand_derive_decode_weak_enum(
.collect::<Vec<Arm>>();
Ok(quote!(
impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where #repr: sqlx::decode::Decode<'de, DB> {
fn accepts(ty: &MySqlTypeInfo) -> bool {
*ty == Self::type_info()
impl<'r, DB: sqlx::Database> sqlx::decode::Decode<'r, DB> for #ident where #repr: sqlx::decode::Decode<'r, DB> {
fn accepts(ty: &DB::TypeInfo) -> bool {
<#repr as sqlx::decode::Decode<'r, DB>>::accepts(ty)
}
fn decode(value: <DB as sqlx::database::HasValueRef<'de>>::ValueRef) -> std::result::Result<Self, sqlx::BoxDynError> {
let value = <#repr as sqlx::decode::Decode<'de, DB>>::decode(value)?;
fn decode(value: <DB as sqlx::database::HasValueRef<'r>>::ValueRef) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let value = <#repr as sqlx::decode::Decode<'r, DB>>::decode(value)?;
match value {
#(#arms)*
@ -146,22 +146,65 @@ fn expand_derive_decode_strong_enum(
}
});
Ok(quote!(
impl<'de, DB: sqlx::Database> sqlx::decode::Decode<'de, DB> for #ident where &'de str: sqlx::decode::Decode<'de, DB> {
fn accepts(ty: &MySqlTypeInfo) -> bool {
*ty == Self::type_info()
}
let values = quote! {
match value {
#(#value_arms)*
fn decode(value: <DB as sqlx::database::HasValueRef<'de>>::ValueRef) -> std::result::Result<Self, sqlx::BoxDynError> {
let value = <&'de str as sqlx::decode::Decode<'de, DB>>::decode(value)?;
match value {
#(#value_arms)*
_ => Err(format!("invalid value {:?} for enum {}", value, #ident_s).into())
}
};
_ => Err(Box::new(sqlx::Error::Decode(format!("invalid value {:?} for enum {}", value, #ident_s).into())))
let mut tts = proc_macro2::TokenStream::new();
if cfg!(feature = "mysql") {
tts.extend(quote!(
impl<'r> sqlx::decode::Decode<'r, sqlx::mysql::MySql> for #ident {
fn accepts(ty: &sqlx::mysql::MySqlTypeInfo) -> bool {
*ty == sqlx::mysql::MySqlTypeInfo::__enum()
}
fn decode(value: sqlx::mysql::MySqlValueRef<'r>) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let value = <&'r str as sqlx::decode::Decode<'r, sqlx::mysql::MySql>>::decode(value)?;
#values
}
}
}
))
));
}
if cfg!(feature = "postgres") {
tts.extend(quote!(
impl<'r> sqlx::decode::Decode<'r, sqlx::postgres::Postgres> for #ident {
fn accepts(ty: &sqlx::postgres::PgTypeInfo) -> bool {
*ty == <#ident as sqlx::Type<sqlx::postgres::Postgres>>::type_info()
}
fn decode(value: sqlx::postgres::PgValueRef<'r>) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let value = <&'r str as sqlx::decode::Decode<'r, sqlx::postgres::Postgres>>::decode(value)?;
#values
}
}
));
}
if cfg!(feature = "sqlite") {
tts.extend(quote!(
impl<'r> sqlx::decode::Decode<'r, sqlx::sqlite::Sqlite> for #ident {
fn accepts(ty: &sqlx::sqlite::SqliteTypeInfo) -> bool {
<&str as sqlx::decode::Decode<'r, sqlx::sqlite::Sqlite>>::accepts(ty)
}
fn decode(value: sqlx::sqlite::SqliteValueRef<'r>) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let value = <&'r str as sqlx::decode::Decode<'r, sqlx::sqlite::Sqlite>>::decode(value)?;
#values
}
}
));
}
Ok(tts)
}
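
For context, roughly the kind of user code this strong-enum expansion targets; the mood type and its variants line up with the integration test schema further down, and rename_all maps each variant to its lowercase string form:

#[derive(sqlx::Type, Debug, PartialEq)]
#[sqlx(rename = "mood", rename_all = "lowercase")]
enum Mood {
    Sad,
    Ok,
    Happy,
}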
fn expand_derive_decode_struct(
@ -181,14 +224,14 @@ fn expand_derive_decode_struct(
// add db type for impl generics & where clause
let mut generics = generics.clone();
generics.params.insert(0, parse_quote!('de));
generics.params.insert(0, parse_quote!('r));
let predicates = &mut generics.make_where_clause().predicates;
for field in fields {
let ty = &field.ty;
predicates.push(parse_quote!(#ty: sqlx::decode::Decode<'de, sqlx::Postgres>));
predicates.push(parse_quote!(#ty: sqlx::decode::Decode<'r, sqlx::Postgres>));
predicates.push(parse_quote!(#ty: sqlx::types::Type<sqlx::Postgres>));
}
@ -199,20 +242,20 @@ fn expand_derive_decode_struct(
let ty = &field.ty;
parse_quote!(
let #id = decoder.decode::<#ty>()?;
let #id = decoder.try_decode::<#ty>()?;
)
});
let names = fields.iter().map(|field| &field.ident);
tts.extend(quote!(
impl #impl_generics sqlx::decode::Decode<'de, sqlx::Postgres> for #ident #ty_generics #where_clause {
fn accepts(ty: &MySqlTypeInfo) -> bool {
*ty == Self::type_info()
impl #impl_generics sqlx::decode::Decode<'r, sqlx::Postgres> for #ident #ty_generics #where_clause {
fn accepts(ty: &sqlx::postgres::PgTypeInfo) -> bool {
*ty == <Self as sqlx::Type<sqlx::postgres::Postgres>>::type_info()
}
fn decode(value: <sqlx::Postgres as sqlx::value::HasRawValue<'de>>::RawValue) -> sqlx::Result<Self> {
let mut decoder = sqlx::postgres::types::raw::PgRecordDecoder::new(value)?;
fn decode(value: sqlx::postgres::PgValueRef<'r>) -> std::result::Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?;
#(#reads)*

View file

@ -67,6 +67,7 @@ fn expand_derive_encode_transparent(
generics
.params
.insert(0, LifetimeDef::new(lifetime.clone()).into());
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics
.make_where_clause()
@ -76,18 +77,16 @@ fn expand_derive_encode_transparent(
Ok(quote!(
impl #impl_generics sqlx::encode::Encode<#lifetime, DB> for #ident #ty_generics #where_clause {
fn encode(self, buf: &mut <DB as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull {
sqlx::encode::Encode::encode(self.0, buf)
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#ty as sqlx::encode::Encode<#lifetime, DB>>::encode_by_ref(&self.0, buf)
}
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<#lifetime>>::ArgumentBuffer) -> sqlx::encode::IsNull {
sqlx::encode::Encode::encode_by_ref(&self.0, buf)
}
fn produces(&self) -> Option<DB::TypeInfo> {
<#ty as sqlx::encode::Encode<DB>>::produces(&self.0)
<#ty as sqlx::encode::Encode<#lifetime, DB>>::produces(&self.0)
}
fn size_hint(&self) -> usize {
sqlx::encode::Encode::size_hint(&self.0)
<#ty as sqlx::encode::Encode<#lifetime, DB>>::size_hint(&self.0)
}
}
))
@ -103,21 +102,17 @@ fn expand_derive_encode_weak_enum(
let ident = &input.ident;
Ok(quote!(
impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where #repr: sqlx::encode::Encode<'q, DB> {
fn encode(self, buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
sqlx::encode::Encode::encode((self as #repr), buf)
}
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
sqlx::encode::Encode::encode_by_ref(&(*self as #repr), buf)
}
impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where #repr: sqlx::encode::Encode<'q, DB> {
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
<#repr as sqlx::encode::Encode<DB>>::encode_by_ref(&(*self as #repr), buf)
}
fn produces(&self) -> Option<DB::TypeInfo> {
<Self as Type<MySql>>::type_info().into()
<#repr as sqlx::encode::Encode<DB>>::produces(&(*self as #repr))
}
fn size_hint(&self) -> usize {
sqlx::encode::Encode::size_hint(&(*self as #repr))
<#repr as sqlx::encode::Encode<DB>>::size_hint(&(*self as #repr))
}
}
))
@ -149,24 +144,21 @@ fn expand_derive_encode_strong_enum(
}
Ok(quote!(
impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where str: sqlx::encode::Encode<'q, DB> {
impl<'q, DB: sqlx::Database> sqlx::encode::Encode<'q, DB> for #ident where &'q str: sqlx::encode::Encode<'q, DB> {
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
let val = match self {
#(#value_arms)*
};
<str as sqlx::encode::Encode<'q, DB>>::encode_by_ref(val, buf)
}
fn produces(&self) -> Option<DB::TypeInfo> {
<Self as Type<MySql>>::type_info().into()
<&str as sqlx::encode::Encode<'q, DB>>::encode(val, buf)
}
fn size_hint(&self) -> usize {
let val = match self {
#(#value_arms)*
};
<str as sqlx::encode::Encode<'q, DB>>::size_hint(val)
<&str as sqlx::encode::Encode<'q, DB>>::size_hint(&val)
}
}
))
@ -190,14 +182,14 @@ fn expand_derive_encode_struct(
// add db type for impl generics & where clause
let mut generics = generics.clone();
let predicates = &mut generics.make_where_clause().predicates;
for field in fields {
let ty = &field.ty;
predicates.insert(0, parse_quote!('q));
predicates.push(parse_quote!(#ty: sqlx::encode::Encode<'q, sqlx::Postgres>));
predicates.push(parse_quote!(#ty: sqlx::types::Type<'q, sqlx::Postgres>));
predicates.push(parse_quote!(#ty: for<'q> sqlx::encode::Encode<'q, sqlx::Postgres>));
predicates.push(parse_quote!(#ty: sqlx::types::Type<sqlx::Postgres>));
}
let (impl_generics, _, where_clause) = generics.split_for_impl();
@ -206,7 +198,6 @@ fn expand_derive_encode_struct(
let id = &field.ident;
parse_quote!(
// sqlx::postgres::encode_struct_field(buf, &self. #id);
encoder.encode(&self. #id);
)
});
@ -221,17 +212,15 @@ fn expand_derive_encode_struct(
});
tts.extend(quote!(
impl #impl_generics sqlx::encode::Encode<'q, sqlx::Postgres> for #ident #ty_generics #where_clause {
fn encode_by_ref(&self, buf: &mut <sqlx::Postgres as sqlx::database::HasArguments<'q>>::ArgumentBuffer) -> sqlx::encode::IsNull {
let mut encoder = sqlx::postgres::types::raw::PgRecordEncoder::new(buf);
impl #impl_generics sqlx::encode::Encode<'_, sqlx::Postgres> for #ident #ty_generics #where_clause {
fn encode_by_ref(&self, buf: &mut sqlx::postgres::PgArgumentBuffer) -> sqlx::encode::IsNull {
let mut encoder = sqlx::postgres::types::PgRecordEncoder::new(buf);
#(#writes)*
encoder.finish()
}
encoder.finish();
fn produces(&self) -> Option<DB::TypeInfo> {
<Self as Type<MySql>>::type_info().into()
sqlx::encode::IsNull::No
}
fn size_hint(&self) -> usize {

View file

@ -49,32 +49,48 @@ fn expand_derive_has_sql_type_transparent(
input: &DeriveInput,
field: &Field,
) -> syn::Result<proc_macro2::TokenStream> {
check_transparent_attributes(input, field)?;
let attr = check_transparent_attributes(input, field)?;
let ident = &input.ident;
let ty = &field.ty;
// extract type generics
let generics = &input.generics;
let (_, ty_generics, _) = generics.split_for_impl();
// add db type for clause
let mut generics = generics.clone();
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::Type<DB>));
if attr.transparent {
let mut generics = generics.clone();
generics.params.insert(0, parse_quote!(DB: sqlx::Database));
generics
.make_where_clause()
.predicates
.push(parse_quote!(#ty: sqlx::Type<DB>));
let (impl_generics, _, where_clause) = generics.split_for_impl();
let (impl_generics, _, where_clause) = generics.split_for_impl();
Ok(quote!(
impl #impl_generics sqlx::Type< DB > for #ident #ty_generics #where_clause {
fn type_info() -> DB::TypeInfo {
<#ty as sqlx::Type<DB>>::type_info()
return Ok(quote!(
impl #impl_generics sqlx::Type< DB > for #ident #ty_generics #where_clause {
fn type_info() -> DB::TypeInfo {
<#ty as sqlx::Type<DB>>::type_info()
}
}
}
))
));
}
let mut tts = proc_macro2::TokenStream::new();
if cfg!(feature = "postgres") {
let ty_name = attr.rename.unwrap_or_else(|| ident.to_string());
tts.extend(quote!(
impl sqlx::Type< sqlx::postgres::Postgres > for #ident #ty_generics {
fn type_info() -> sqlx::postgres::PgTypeInfo {
sqlx::postgres::PgTypeInfo::with_name(#ty_name)
}
}
));
}
Ok(tts)
}
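
For context, the non-transparent path added above is what lets a derived wrapper over a custom range (or domain) type resolve by its Postgres type name; it mirrors the FloatRange wrapper exercised in the integration tests below:

#[derive(sqlx::Type, Debug, PartialEq)]
#[sqlx(rename = "float_range")]
struct FloatRange(sqlx::postgres::types::PgRange<f64>);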
fn expand_derive_has_sql_type_weak_enum(
@ -84,8 +100,7 @@ fn expand_derive_has_sql_type_weak_enum(
let attr = check_weak_enum_attributes(input, variants)?;
let repr = attr.repr.unwrap();
let ident = &input.ident;
Ok(quote!(
let ts = quote!(
impl<DB: sqlx::Database> sqlx::Type<DB> for #ident
where
#repr: sqlx::Type<DB>,
@ -94,7 +109,9 @@ fn expand_derive_has_sql_type_weak_enum(
<#repr as sqlx::Type<DB>>::type_info()
}
}
))
);
Ok(ts)
}
fn expand_derive_has_sql_type_strong_enum(
@ -110,7 +127,7 @@ fn expand_derive_has_sql_type_strong_enum(
tts.extend(quote!(
impl sqlx::Type< sqlx::MySql > for #ident {
fn type_info() -> sqlx::mysql::MySqlTypeInfo {
sqlx::mysql::MySqlTypeInfo::r#enum()
sqlx::mysql::MySqlTypeInfo::__enum()
}
}
));

View file

@ -11,16 +11,14 @@ pub use sqlx_core::query_as::{query_as, query_as_with};
pub use sqlx_core::query_scalar::{query_scalar, query_scalar_with};
pub use sqlx_core::row::{ColumnIndex, Row};
pub use sqlx_core::transaction::{Transaction, TransactionManager};
pub use sqlx_core::types::Type;
pub use sqlx_core::value::{Value, ValueRef};
#[doc(hidden)]
pub use sqlx_core::describe;
#[doc(inline)]
pub use sqlx_core::types::{self, Type};
#[doc(inline)]
pub use sqlx_core::error::{self, BoxDynError, Error, Result};
pub use sqlx_core::error::{self, Error, Result};
#[cfg(feature = "mysql")]
#[cfg_attr(docsrs, doc(cfg(feature = "mysql")))]
@ -42,6 +40,7 @@ pub use sqlx_core::sqlite::{self, Sqlite, SqliteConnection, SqlitePool};
#[doc(hidden)]
pub extern crate sqlx_macros;
// derives
#[cfg(feature = "macros")]
pub use sqlx_macros::{FromRow, Type};
@ -57,6 +56,13 @@ pub mod ty_match;
#[doc(hidden)]
pub mod result_ext;
pub mod types {
pub use sqlx_core::types::*;
#[cfg(feature = "macros")]
pub use sqlx_macros::Type;
}
/// Types and traits for encoding values for the database.
pub mod encode {
pub use sqlx_core::encode::{Encode, IsNull};

View file

@ -1,6 +1,9 @@
use sqlx::{postgres::PgQueryAs, Connection, Cursor, Executor, FromRow, Postgres};
use futures::TryStreamExt;
use sqlx::{Connection, Executor, FromRow, Postgres};
use sqlx_core::postgres::types::PgRange;
use sqlx_test::{new, test_type};
use std::fmt::Debug;
use std::ops::Bound;
// Transparent types are rust-side wrappers over DB types
#[derive(PartialEq, Debug, sqlx::Type)]
@ -37,6 +40,7 @@ enum ColorLower {
Green,
Blue,
}
#[derive(PartialEq, Debug, sqlx::Type)]
#[sqlx(rename = "color_snake")]
#[sqlx(rename_all = "snake_case")]
@ -44,6 +48,7 @@ enum ColorSnake {
RedGreen,
BlueBlack,
}
#[derive(PartialEq, Debug, sqlx::Type)]
#[sqlx(rename = "color_upper")]
#[sqlx(rename_all = "uppercase")]
@ -73,36 +78,43 @@ struct InventoryItem {
price: Option<i64>,
}
test_type!(transparent(
Postgres,
Transparent,
// Custom range type
#[derive(sqlx::Type, Debug, PartialEq)]
#[sqlx(rename = "float_range")]
struct FloatRange(PgRange<f64>);
// Custom domain type
#[derive(sqlx::Type, Debug)]
#[sqlx(rename = "int4rangeL0pC")]
struct RangeInclusive(PgRange<i32>);
test_type!(transparent<Transparent>(Postgres,
"0" == Transparent(0),
"23523" == Transparent(23523)
));
test_type!(weak_enum(
Postgres,
Weak,
test_type!(weak_enum<Weak>(Postgres,
"0::int4" == Weak::One,
"2::int4" == Weak::Two,
"4::int4" == Weak::Three
));
test_type!(strong_enum(
Postgres,
Strong,
test_type!(strong_enum<Strong>(Postgres,
"'one'::text" == Strong::One,
"'two'::text" == Strong::Two,
"'four'::text" == Strong::Three
));
test_type!(floatrange<FloatRange>(Postgres,
"'[1.234, 5.678]'::float_range" == FloatRange(PgRange::from((Bound::Included(1.234), Bound::Included(5.678)))),
));
#[sqlx_macros::test]
async fn test_enum_type() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
conn.execute(
r#"
DROP TABLE IF EXISTS people;
DROP TYPE IF EXISTS mood CASCADE;
@ -154,7 +166,7 @@ RETURNING id
let rec: PeopleRow = sqlx::query_as(
"
SELECT id, mood FROM people WHERE id = $1
",
",
)
.bind(people_id)
.fetch_one(&mut conn)
@ -169,20 +181,23 @@ SELECT id, mood FROM people WHERE id = $1
let stmt = format!("SELECT id, mood FROM people WHERE id = {}", people_id);
dbg!(&stmt);
let mut cursor = conn.fetch(&*stmt);
let row = cursor.next().await?.unwrap();
let row = cursor.try_next().await?.unwrap();
let rec = PeopleRow::from_row(&row)?;
assert_eq!(rec.id, people_id);
assert_eq!(rec.mood, Mood::Sad);
drop(cursor);
// Normal type equivalency test
let rec: (bool, Mood) = sqlx::query_as(
"
SELECT $1 = 'happy'::mood, $1
",
SELECT $1 = 'happy'::mood, $1
",
)
.bind(&Mood::Happy)
.fetch_one(&mut conn)
@ -193,8 +208,8 @@ SELECT $1 = 'happy'::mood, $1
let rec: (bool, ColorLower) = sqlx::query_as(
"
SELECT $1 = 'green'::color_lower, $1
",
SELECT $1 = 'green'::color_lower, $1
",
)
.bind(&ColorLower::Green)
.fetch_one(&mut conn)
@ -205,8 +220,8 @@ SELECT $1 = 'green'::color_lower, $1
let rec: (bool, ColorSnake) = sqlx::query_as(
"
SELECT $1 = 'red_green'::color_snake, $1
",
SELECT $1 = 'red_green'::color_snake, $1
",
)
.bind(&ColorSnake::RedGreen)
.fetch_one(&mut conn)
@ -217,8 +232,8 @@ SELECT $1 = 'red_green'::color_snake, $1
let rec: (bool, ColorUpper) = sqlx::query_as(
"
SELECT $1 = 'RED'::color_upper, $1
",
SELECT $1 = 'RED'::color_upper, $1
",
)
.bind(&ColorUpper::Red)
.fetch_one(&mut conn)
@ -234,23 +249,6 @@ SELECT $1 = 'RED'::color_upper, $1
async fn test_record_type() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;
conn.execute(
r#"
DO $$ BEGIN
CREATE TYPE inventory_item AS (
name text,
supplier_id int,
price bigint
);
EXCEPTION
WHEN duplicate_object THEN null;
END $$;
"#,
)
.await?;
let value = InventoryItem {
name: "fuzzy dice".to_owned(),
supplier_id: Some(42),
@ -259,7 +257,7 @@ END $$;
let rec: (bool, InventoryItem) = sqlx::query_as(
"
SELECT $1 = ROW('fuzzy dice', 42, 199)::inventory_item, $1
SELECT $1 = ROW('fuzzy dice', 42, 199)::inventory_item, $1
",
)
.bind(&value)
@ -275,9 +273,6 @@ END $$;
#[cfg(feature = "macros")]
#[sqlx_macros::test]
async fn test_from_row() -> anyhow::Result<()> {
// Needed for PgQueryAs
use sqlx::prelude::*;
let mut conn = new::<Postgres>().await?;
#[derive(sqlx::FromRow)]
@ -310,7 +305,8 @@ async fn test_from_row() -> anyhow::Result<()> {
.bind(1_i32)
.fetch(&mut conn);
let account = RefAccount::from_row(&cursor.next().await?.unwrap())?;
let row = cursor.try_next().await?.unwrap();
let account = RefAccount::from_row(&row)?;
assert_eq!(account.id, 1);
assert_eq!(account.name, "Herp Derpinson");
@ -321,8 +317,6 @@ async fn test_from_row() -> anyhow::Result<()> {
#[cfg(feature = "macros")]
#[sqlx_macros::test]
async fn test_from_row_with_keyword() -> anyhow::Result<()> {
use sqlx::prelude::*;
#[derive(Debug, sqlx::FromRow)]
struct AccountKeyword {
r#type: i32,
@ -353,8 +347,6 @@ async fn test_from_row_with_keyword() -> anyhow::Result<()> {
#[cfg(feature = "macros")]
#[sqlx_macros::test]
async fn test_from_row_with_rename() -> anyhow::Result<()> {
use sqlx::prelude::*;
#[derive(Debug, sqlx::FromRow)]
struct AccountKeyword {
#[sqlx(rename = "type")]

View file

@ -17,3 +17,9 @@ CREATE TABLE tweet
text TEXT NOT NULL,
owner_id BIGINT
);
CREATE TYPE float_range AS RANGE
(
subtype = float8,
subtype_diff = float8mi
);

View file

@ -1,5 +1,8 @@
extern crate time_ as time;
use std::ops::Bound;
use sqlx::postgres::types::PgRange;
use sqlx::postgres::Postgres;
use sqlx_test::{test_decode_type, test_prepared_type, test_type};
@ -334,35 +337,26 @@ test_type!(decimal<sqlx::types::BigDecimal>(Postgres,
"12345.6789::numeric" == "12345.6789".parse::<sqlx::types::BigDecimal>().unwrap(),
));
mod ranges {
use super::*;
use core::ops::Bound;
use sqlx::postgres::types::{Int4Range, PgRange};
const EXC2: Bound<i32> = Bound::Excluded(2);
const EXC3: Bound<i32> = Bound::Excluded(3);
const INC1: Bound<i32> = Bound::Included(1);
const INC2: Bound<i32> = Bound::Included(2);
const UNB: Bound<i32> = Bound::Unbounded;
const EXC2: Bound<i32> = Bound::Excluded(2);
const EXC3: Bound<i32> = Bound::Excluded(3);
const INC1: Bound<i32> = Bound::Included(1);
const INC2: Bound<i32> = Bound::Included(2);
const UNB: Bound<i32> = Bound::Unbounded;
// int4range display is hard-coded into [l, u)
test_type!(int4range<PgRange<i32>>(Postgres,
"'(,)'::int4range" == Int4Range(PgRange::new([UNB, UNB])),
"'(,]'::int4range" == Int4Range(PgRange::new([UNB, UNB])),
"'(,2)'::int4range" == Int4Range(PgRange::new([UNB, EXC2])),
"'(,2]'::int4range" == Int4Range(PgRange::new([UNB, EXC3])),
"'(1,)'::int4range" == Int4Range(PgRange::new([INC2, UNB])),
"'(1,]'::int4range" == Int4Range(PgRange::new([INC2, UNB])),
"'(1,2]'::int4range" == Int4Range(PgRange::new([INC2, EXC3])),
"'[,)'::int4range" == Int4Range(PgRange::new([UNB, UNB])),
"'[,]'::int4range" == Int4Range(PgRange::new([UNB, UNB])),
"'[,2)'::int4range" == Int4Range(PgRange::new([UNB, EXC2])),
"'[,2]'::int4range" == Int4Range(PgRange::new([UNB, EXC3])),
"'[1,)'::int4range" == Int4Range(PgRange::new([INC1, UNB])),
"'[1,]'::int4range" == Int4Range(PgRange::new([INC1, UNB])),
"'[1,2)'::int4range" == Int4Range(PgRange::new([INC1, EXC2])),
"'[1,2]'::int4range" == Int4Range(PgRange::new([INC1, EXC3])),
));
}
test_type!(int4range<PgRange<i32>>(Postgres,
"'(,)'::int4range" == PgRange::from((UNB, UNB)),
"'(,]'::int4range" == PgRange::from((UNB, UNB)),
"'(,2)'::int4range" == PgRange::from((UNB, EXC2)),
"'(,2]'::int4range" == PgRange::from((UNB, EXC3)),
"'(1,)'::int4range" == PgRange::from((INC2, UNB)),
"'(1,]'::int4range" == PgRange::from((INC2, UNB)),
"'(1,2]'::int4range" == PgRange::from((INC2, EXC3)),
"'[,)'::int4range" == PgRange::from((UNB, UNB)),
"'[,]'::int4range" == PgRange::from((UNB, UNB)),
"'[,2)'::int4range" == PgRange::from((UNB, EXC2)),
"'[,2]'::int4range" == PgRange::from((UNB, EXC3)),
"'[1,)'::int4range" == PgRange::from((INC1, UNB)),
"'[1,]'::int4range" == PgRange::from((INC1, UNB)),
"'[1,2)'::int4range" == PgRange::from((INC1, EXC2)),
"'[1,2]'::int4range" == PgRange::from((INC1, EXC3)),
));