fix(postgres): Add support for domain types description

Fix commit updates the `postgres::connection::describe` module to add full support for domain types. Domain types were previously confused with their category which caused invalid oid resolution.

Fixes launchbadge/sqlx#110
This commit is contained in:
Charles Samborski 2021-03-16 18:39:52 +01:00
parent edcc91c9f2
commit 93b90be9f7
3 changed files with 245 additions and 16 deletions

View file

@ -9,9 +9,86 @@ use crate::query_scalar::{query_scalar, query_scalar_with};
use crate::types::Json;
use crate::HashMap;
use futures_core::future::BoxFuture;
use std::convert::TryFrom;
use std::fmt::Write;
use std::sync::Arc;
/// Describes the type of the `pg_type.typtype` column
///
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html>
enum TypType {
Base,
Composite,
Domain,
Enum,
Pseudo,
Range,
}
impl TryFrom<u8> for TypType {
type Error = ();
fn try_from(t: u8) -> Result<Self, Self::Error> {
let t = match t {
b'b' => Self::Base,
b'c' => Self::Composite,
b'd' => Self::Domain,
b'e' => Self::Enum,
b'p' => Self::Pseudo,
b'r' => Self::Range,
_ => return Err(()),
};
Ok(t)
}
}
/// Describes the type of the `pg_type.typcategory` column
///
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html#CATALOG-TYPCATEGORY-TABLE>
enum TypCategory {
Array,
Boolean,
Composite,
DateTime,
Enum,
Geometric,
Network,
Numeric,
Pseudo,
Range,
String,
Timespan,
User,
BitString,
Unknown,
}
impl TryFrom<u8> for TypCategory {
type Error = ();
fn try_from(c: u8) -> Result<Self, Self::Error> {
let c = match c {
b'A' => Self::Array,
b'B' => Self::Boolean,
b'C' => Self::Composite,
b'D' => Self::DateTime,
b'E' => Self::Enum,
b'G' => Self::Geometric,
b'I' => Self::Network,
b'N' => Self::Numeric,
b'P' => Self::Pseudo,
b'R' => Self::Range,
b'S' => Self::String,
b'T' => Self::Timespan,
b'U' => Self::User,
b'V' => Self::BitString,
b'X' => Self::Unknown,
_ => return Err(()),
};
Ok(c)
}
}
impl PgConnection {
pub(super) async fn handle_row_description(
&mut self,
@ -106,31 +183,46 @@ impl PgConnection {
fn fetch_type_by_oid(&mut self, oid: u32) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
Box::pin(async move {
let (name, category, relation_id, element): (String, i8, u32, u32) = query_as(
"SELECT typname, typcategory, typrelid, typelem FROM pg_catalog.pg_type WHERE oid = $1",
let (name, typ_type, category, relation_id, element, base_type): (String, i8, i8, u32, u32, u32) = query_as(
"SELECT typname, typtype, typcategory, typrelid, typelem, typbasetype FROM pg_catalog.pg_type WHERE oid = $1",
)
.bind(oid)
.fetch_one(&mut *self)
.await?;
match category as u8 {
b'A' => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?),
name: name.into(),
oid,
})))),
let typ_type = TypType::try_from(typ_type as u8);
let category = TypCategory::try_from(category as u8);
b'P' => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Pseudo,
name: name.into(),
oid,
})))),
match (typ_type, category) {
(Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await,
b'R' => self.fetch_range_by_oid(oid, name).await,
(Ok(TypType::Base), Ok(TypCategory::Array)) => {
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?),
name: name.into(),
oid,
}))))
}
b'E' => self.fetch_enum_by_oid(oid, name).await,
(Ok(TypType::Pseudo), Ok(TypCategory::Pseudo)) => {
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Pseudo,
name: name.into(),
oid,
}))))
}
b'C' => self.fetch_composite_by_oid(oid, relation_id, name).await,
(Ok(TypType::Range), Ok(TypCategory::Range)) => {
self.fetch_range_by_oid(oid, name).await
}
(Ok(TypType::Enum), Ok(TypCategory::Enum)) => {
self.fetch_enum_by_oid(oid, name).await
}
(Ok(TypType::Composite), Ok(TypCategory::Composite)) => {
self.fetch_composite_by_oid(oid, relation_id, name).await
}
_ => Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Simple,
@ -198,6 +290,23 @@ ORDER BY attnum
})
}
fn fetch_domain_by_oid(
&mut self,
oid: u32,
base_type: u32,
name: String,
) -> BoxFuture<'_, Result<PgTypeInfo, Error>> {
Box::pin(async move {
let base_type = self.maybe_fetch_type_info_by_oid(base_type, true).await?;
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
oid,
name: name.into(),
kind: PgTypeKind::Domain(base_type),
}))))
})
}
fn fetch_range_by_oid(
&mut self,
oid: u32,

View file

@ -887,3 +887,115 @@ from (values (null)) vals(val)
Ok(())
}
#[sqlx_macros::test]
async fn it_supports_domain_types_in_composite_domain_types() -> anyhow::Result<()> {
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct MonthId(i16);
impl sqlx::Type<Postgres> for MonthId {
fn type_info() -> sqlx::postgres::PgTypeInfo {
sqlx::postgres::PgTypeInfo::with_name("month_id")
}
fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool {
*ty == Self::type_info()
}
}
impl<'r> sqlx::Decode<'r, Postgres> for MonthId {
fn decode(
value: sqlx::postgres::PgValueRef<'r>,
) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
Ok(Self(<i16 as sqlx::Decode<Postgres>>::decode(value)?))
}
}
impl<'q> sqlx::Encode<'q, Postgres> for MonthId {
fn encode_by_ref(
&self,
buf: &mut sqlx::postgres::PgArgumentBuffer,
) -> sqlx::encode::IsNull {
self.0.encode(buf)
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct WinterYearMonth {
year: i32,
month: MonthId,
}
impl sqlx::Type<Postgres> for WinterYearMonth {
fn type_info() -> sqlx::postgres::PgTypeInfo {
sqlx::postgres::PgTypeInfo::with_name("winter_year_month")
}
fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool {
*ty == Self::type_info()
}
}
impl<'r> sqlx::Decode<'r, Postgres> for WinterYearMonth {
fn decode(
value: sqlx::postgres::PgValueRef<'r>,
) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?;
let year = decoder.try_decode::<i32>()?;
let month = decoder.try_decode::<MonthId>()?;
Ok(Self { year, month })
}
}
impl<'q> sqlx::Encode<'q, Postgres> for WinterYearMonth {
fn encode_by_ref(
&self,
buf: &mut sqlx::postgres::PgArgumentBuffer,
) -> sqlx::encode::IsNull {
let mut encoder = sqlx::postgres::types::PgRecordEncoder::new(buf);
encoder.encode(self.year);
encoder.encode(self.month);
encoder.finish();
sqlx::encode::IsNull::No
}
}
let mut conn = new::<Postgres>().await?;
{
let result = sqlx::query("DELETE FROM heating_bills;")
.execute(&mut conn)
.await;
let result = result.unwrap();
assert_eq!(result.rows_affected(), 1);
}
{
let result = sqlx::query(
"INSERT INTO heating_bills(month, cost) VALUES($1::winter_year_month, 100);",
)
.bind(WinterYearMonth {
year: 2021,
month: MonthId(1),
})
.execute(&mut conn)
.await;
let result = result.unwrap();
assert_eq!(result.rows_affected(), 1);
}
{
let result = sqlx::query("DELETE FROM heating_bills;")
.execute(&mut conn)
.await;
let result = result.unwrap();
assert_eq!(result.rows_affected(), 1);
}
Ok(())
}

View file

@ -29,3 +29,11 @@ CREATE TABLE products (
name TEXT,
price NUMERIC CHECK (price > 0)
);
CREATE DOMAIN month_id AS INT2 CHECK (1 <= value AND value <= 12);
CREATE TYPE year_month AS (year INT4, month month_id);
CREATE DOMAIN winter_year_month AS year_month CHECK ((value).month <= 3);
CREATE TABLE heating_bills (
month winter_year_month NOT NULL PRIMARY KEY,
cost INT4 NOT NULL
);