mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 06:24:16 +00:00
fix(postgres): Add support for domain types description
Fix commit updates the `postgres::connection::describe` module to add full support for domain types. Domain types were previously confused with their category which caused invalid oid resolution. Fixes launchbadge/sqlx#110
This commit is contained in:
parent
edcc91c9f2
commit
93b90be9f7
3 changed files with 245 additions and 16 deletions
|
@ -9,9 +9,86 @@ use crate::query_scalar::{query_scalar, query_scalar_with};
|
|||
use crate::types::Json;
|
||||
use crate::HashMap;
|
||||
use futures_core::future::BoxFuture;
|
||||
use std::convert::TryFrom;
|
||||
use std::fmt::Write;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Describes the type of the `pg_type.typtype` column
|
||||
///
|
||||
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html>
|
||||
enum TypType {
|
||||
Base,
|
||||
Composite,
|
||||
Domain,
|
||||
Enum,
|
||||
Pseudo,
|
||||
Range,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for TypType {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(t: u8) -> Result<Self, Self::Error> {
|
||||
let t = match t {
|
||||
b'b' => Self::Base,
|
||||
b'c' => Self::Composite,
|
||||
b'd' => Self::Domain,
|
||||
b'e' => Self::Enum,
|
||||
b'p' => Self::Pseudo,
|
||||
b'r' => Self::Range,
|
||||
_ => return Err(()),
|
||||
};
|
||||
Ok(t)
|
||||
}
|
||||
}
|
||||
|
||||
/// Describes the type of the `pg_type.typcategory` column
|
||||
///
|
||||
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html#CATALOG-TYPCATEGORY-TABLE>
|
||||
enum TypCategory {
|
||||
Array,
|
||||
Boolean,
|
||||
Composite,
|
||||
DateTime,
|
||||
Enum,
|
||||
Geometric,
|
||||
Network,
|
||||
Numeric,
|
||||
Pseudo,
|
||||
Range,
|
||||
String,
|
||||
Timespan,
|
||||
User,
|
||||
BitString,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for TypCategory {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(c: u8) -> Result<Self, Self::Error> {
|
||||
let c = match c {
|
||||
b'A' => Self::Array,
|
||||
b'B' => Self::Boolean,
|
||||
b'C' => Self::Composite,
|
||||
b'D' => Self::DateTime,
|
||||
b'E' => Self::Enum,
|
||||
b'G' => Self::Geometric,
|
||||
b'I' => Self::Network,
|
||||
b'N' => Self::Numeric,
|
||||
b'P' => Self::Pseudo,
|
||||
b'R' => Self::Range,
|
||||
b'S' => Self::String,
|
||||
b'T' => Self::Timespan,
|
||||
b'U' => Self::User,
|
||||
b'V' => Self::BitString,
|
||||
b'X' => Self::Unknown,
|
||||
_ => return Err(()),
|
||||
};
|
||||
Ok(c)
|
||||
}
|
||||
}
|
||||
|
||||
impl PgConnection {
|
||||
pub(super) async fn handle_row_description(
|
||||
&mut self,
|
||||
|
@ -106,31 +183,46 @@ impl PgConnection {
|
|||
|
||||
fn fetch_type_by_oid(&mut self, oid: u32) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
|
||||
Box::pin(async move {
|
||||
let (name, category, relation_id, element): (String, i8, u32, u32) = query_as(
|
||||
"SELECT typname, typcategory, typrelid, typelem FROM pg_catalog.pg_type WHERE oid = $1",
|
||||
let (name, typ_type, category, relation_id, element, base_type): (String, i8, i8, u32, u32, u32) = query_as(
|
||||
"SELECT typname, typtype, typcategory, typrelid, typelem, typbasetype FROM pg_catalog.pg_type WHERE oid = $1",
|
||||
)
|
||||
.bind(oid)
|
||||
.fetch_one(&mut *self)
|
||||
.await?;
|
||||
|
||||
match category as u8 {
|
||||
b'A' => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
|
||||
kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?),
|
||||
name: name.into(),
|
||||
oid,
|
||||
})))),
|
||||
let typ_type = TypType::try_from(typ_type as u8);
|
||||
let category = TypCategory::try_from(category as u8);
|
||||
|
||||
b'P' => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
|
||||
kind: PgTypeKind::Pseudo,
|
||||
name: name.into(),
|
||||
oid,
|
||||
})))),
|
||||
match (typ_type, category) {
|
||||
(Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await,
|
||||
|
||||
b'R' => self.fetch_range_by_oid(oid, name).await,
|
||||
(Ok(TypType::Base), Ok(TypCategory::Array)) => {
|
||||
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
|
||||
kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?),
|
||||
name: name.into(),
|
||||
oid,
|
||||
}))))
|
||||
}
|
||||
|
||||
b'E' => self.fetch_enum_by_oid(oid, name).await,
|
||||
(Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => {
|
||||
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
|
||||
kind: PgTypeKind::Pseudo,
|
||||
name: name.into(),
|
||||
oid,
|
||||
}))))
|
||||
}
|
||||
|
||||
b'C' => self.fetch_composite_by_oid(oid, relation_id, name).await,
|
||||
(Ok(TypType::Range), Ok(TypCategory::Range)) => {
|
||||
self.fetch_range_by_oid(oid, name).await
|
||||
}
|
||||
|
||||
(Ok(TypType::Enum), Ok(TypCategory::Enum)) => {
|
||||
self.fetch_enum_by_oid(oid, name).await
|
||||
}
|
||||
|
||||
(Ok(TypType::Composite), Ok(TypCategory::Composite)) => {
|
||||
self.fetch_composite_by_oid(oid, relation_id, name).await
|
||||
}
|
||||
|
||||
_ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
|
||||
kind: PgTypeKind::Simple,
|
||||
|
@ -198,6 +290,23 @@ ORDER BY attnum
|
|||
})
|
||||
}
|
||||
|
||||
fn fetch_domain_by_oid(
|
||||
&mut self,
|
||||
oid: u32,
|
||||
base_type: u32,
|
||||
name: String,
|
||||
) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
|
||||
Box::pin(async move {
|
||||
let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?;
|
||||
|
||||
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
|
||||
oid,
|
||||
name: name.into(),
|
||||
kind: PgTypeKind::Domain(base_type),
|
||||
}))))
|
||||
})
|
||||
}
|
||||
|
||||
fn fetch_range_by_oid(
|
||||
&mut self,
|
||||
oid: u32,
|
||||
|
|
|
@ -887,3 +887,115 @@ from (values (null)) vals(val)
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn it_supports_domain_types_in_composite_domain_types() -> anyhow::Result<()> {
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
struct MonthId(i16);
|
||||
|
||||
impl sqlx::Type<Postgres> for MonthId {
|
||||
fn type_info() -> sqlx::postgres::PgTypeInfo {
|
||||
sqlx::postgres::PgTypeInfo::with_name("month_id")
|
||||
}
|
||||
|
||||
fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool {
|
||||
*ty == Self::type_info()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> sqlx::Decode<'r, Postgres> for MonthId {
|
||||
fn decode(
|
||||
value: sqlx::postgres::PgValueRef<'r>,
|
||||
) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
|
||||
Ok(Self(<i16 as sqlx::Decode<Postgres>>::decode(value)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'q> sqlx::Encode<'q, Postgres> for MonthId {
|
||||
fn encode_by_ref(
|
||||
&self,
|
||||
buf: &mut sqlx::postgres::PgArgumentBuffer,
|
||||
) -> sqlx::encode::IsNull {
|
||||
self.0.encode(buf)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
struct WinterYearMonth {
|
||||
year: i32,
|
||||
month: MonthId,
|
||||
}
|
||||
|
||||
impl sqlx::Type<Postgres> for WinterYearMonth {
|
||||
fn type_info() -> sqlx::postgres::PgTypeInfo {
|
||||
sqlx::postgres::PgTypeInfo::with_name("winter_year_month")
|
||||
}
|
||||
|
||||
fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool {
|
||||
*ty == Self::type_info()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r> sqlx::Decode<'r, Postgres> for WinterYearMonth {
|
||||
fn decode(
|
||||
value: sqlx::postgres::PgValueRef<'r>,
|
||||
) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
|
||||
let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?;
|
||||
|
||||
let year = decoder.try_decode::<i32>()?;
|
||||
let month = decoder.try_decode::<MonthId>()?;
|
||||
|
||||
Ok(Self { year, month })
|
||||
}
|
||||
}
|
||||
|
||||
impl<'q> sqlx::Encode<'q, Postgres> for WinterYearMonth {
|
||||
fn encode_by_ref(
|
||||
&self,
|
||||
buf: &mut sqlx::postgres::PgArgumentBuffer,
|
||||
) -> sqlx::encode::IsNull {
|
||||
let mut encoder = sqlx::postgres::types::PgRecordEncoder::new(buf);
|
||||
encoder.encode(self.year);
|
||||
encoder.encode(self.month);
|
||||
encoder.finish();
|
||||
sqlx::encode::IsNull::No
|
||||
}
|
||||
}
|
||||
|
||||
let mut conn = new::<Postgres>().await?;
|
||||
|
||||
{
|
||||
let result = sqlx::query("DELETE FROM heating_bills;")
|
||||
.execute(&mut conn)
|
||||
.await;
|
||||
|
||||
let result = result.unwrap();
|
||||
assert_eq!(result.rows_affected(), 1);
|
||||
}
|
||||
|
||||
{
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO heating_bills(month, cost) VALUES($1::winter_year_month, 100);",
|
||||
)
|
||||
.bind(WinterYearMonth {
|
||||
year: 2021,
|
||||
month: MonthId(1),
|
||||
})
|
||||
.execute(&mut conn)
|
||||
.await;
|
||||
|
||||
let result = result.unwrap();
|
||||
assert_eq!(result.rows_affected(), 1);
|
||||
}
|
||||
|
||||
{
|
||||
let result = sqlx::query("DELETE FROM heating_bills;")
|
||||
.execute(&mut conn)
|
||||
.await;
|
||||
|
||||
let result = result.unwrap();
|
||||
assert_eq!(result.rows_affected(), 1);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -29,3 +29,11 @@ CREATE TABLE products (
|
|||
name TEXT,
|
||||
price NUMERIC CHECK (price > 0)
|
||||
);
|
||||
|
||||
CREATE DOMAIN month_id AS INT2 CHECK (1 <= value AND value <= 12);
|
||||
CREATE TYPE year_month AS (year INT4, month month_id);
|
||||
CREATE DOMAIN winter_year_month AS year_month CHECK ((value).month <= 3);
|
||||
CREATE TABLE heating_bills (
|
||||
month winter_year_month NOT NULL PRIMARY KEY,
|
||||
cost INT4 NOT NULL
|
||||
);
|
||||
|
|
Loading…
Reference in a new issue