From 9fc9e7518e40c0c6ac674cf5aa278ea0facf6e69 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 22 Nov 2023 17:06:47 -0800 Subject: [PATCH] feat: `Text` adapter (#2894) --- Cargo.toml | 1 - sqlx-core/Cargo.toml | 2 +- sqlx-core/src/types/mod.rs | 4 + sqlx-core/src/types/text.rs | 134 +++++++++++++++++++++++++++++++ sqlx-mysql/src/types/mod.rs | 1 + sqlx-mysql/src/types/text.rs | 49 +++++++++++ sqlx-postgres/src/types/array.rs | 11 +++ sqlx-postgres/src/types/mod.rs | 1 + sqlx-postgres/src/types/text.rs | 50 ++++++++++++ sqlx-sqlite/src/types/mod.rs | 7 +- sqlx-sqlite/src/types/text.rs | 37 +++++++++ tests/mysql/types.rs | 60 +++++++++++++- tests/postgres/types.rs | 48 ++++++++++- tests/sqlite/types.rs | 46 +++++++++++ 14 files changed, 442 insertions(+), 9 deletions(-) create mode 100644 sqlx-core/src/types/text.rs create mode 100644 sqlx-mysql/src/types/text.rs create mode 100644 sqlx-postgres/src/types/text.rs create mode 100644 sqlx-sqlite/src/types/text.rs diff --git a/Cargo.toml b/Cargo.toml index 03d818d8..3ed868ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -345,4 +345,3 @@ required-features = ["postgres", "macros", "migrate"] name = "postgres-migrate" path = "tests/postgres/migrate.rs" required-features = ["postgres", "macros", "migrate"] - diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index e638773a..98688969 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -92,5 +92,5 @@ event-listener = "2.5.2" dotenvy = "0.15" [dev-dependencies] -sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros"] } +sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] } tokio = { version = "1", features = ["rt"] } diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index 3ff6c80f..7e8df217 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -28,6 +28,8 @@ pub mod bstr; #[cfg_attr(docsrs, doc(cfg(feature = "json")))] mod json; +mod text; + #[cfg(feature = "uuid")] #[cfg_attr(docsrs, doc(cfg(feature = "uuid")))] #[doc(no_inline)] @@ -81,6 +83,8 @@ pub mod mac_address { #[cfg(feature = "json")] pub use json::{Json, JsonRawValue, JsonValue}; +pub use text::Text; + /// Indicates that a SQL type is supported for a database. /// /// ## Compile-time verification diff --git a/sqlx-core/src/types/text.rs b/sqlx-core/src/types/text.rs new file mode 100644 index 00000000..90480e6d --- /dev/null +++ b/sqlx-core/src/types/text.rs @@ -0,0 +1,134 @@ +use std::ops::{Deref, DerefMut}; + +/// Map a SQL text value to/from a Rust type using [`Display`] and [`FromStr`]. +/// +/// This can be useful for types that do not have a direct SQL equivalent, or are simply not +/// supported by SQLx for one reason or another. +/// +/// For strongly typed databases like Postgres, this will report the value's type as `TEXT`. +/// Explicit conversion may be necessary on the SQL side depending on the desired type. +/// +/// [`Display`]: std::fmt::Display +/// [`FromStr`]: std::str::FromStr +/// +/// ### Panics +/// +/// You should only use this adapter with `Display` implementations that are infallible, +/// otherwise you may encounter panics when attempting to bind a value. +/// +/// This is because the design of the `Encode` trait assumes encoding is infallible, so there is no +/// way to bubble up the error. +/// +/// Fortunately, most `Display` implementations are infallible by convention anyway +/// (the standard `ToString` trait also assumes this), but you may still want to audit +/// the source code for any types you intend to use with this adapter, just to be safe. +/// +/// ### Example: `SocketAddr` +/// +/// MySQL and SQLite do not have a native SQL equivalent for `SocketAddr`, so if you want to +/// store and retrieve instances of it, it makes sense to map it to `TEXT`: +/// +/// ```rust,no_run +/// # use sqlx::types::{time, uuid}; +/// +/// use std::net::SocketAddr; +/// +/// use sqlx::Connection; +/// use sqlx::mysql::MySqlConnection; +/// use sqlx::types::Text; +/// +/// use uuid::Uuid; +/// use time::OffsetDateTime; +/// +/// #[derive(sqlx::FromRow, Debug)] +/// struct Login { +/// user_id: Uuid, +/// socket_addr: Text, +/// login_at: OffsetDateTime +/// } +/// +/// # async fn example() -> Result<(), Box> { +/// +/// let mut conn: MySqlConnection = MySqlConnection::connect("").await?; +/// +/// let user_id: Uuid = "e9a72cdc-d907-48d6-a488-c64a91fd063c".parse().unwrap(); +/// let socket_addr: SocketAddr = "198.51.100.47:31790".parse().unwrap(); +/// +/// // CREATE TABLE user_login(user_id VARCHAR(36), socket_addr TEXT, login_at TIMESTAMP); +/// sqlx::query("INSERT INTO user_login(user_id, socket_addr, login_at) VALUES (?, ?, NOW())") +/// .bind(user_id) +/// .bind(Text(socket_addr)) +/// .execute(&mut conn) +/// .await?; +/// +/// let logins: Vec = sqlx::query_as("SELECT * FROM user_login") +/// .fetch_all(&mut conn) +/// .await?; +/// +/// println!("Logins for user ID {user_id}: {logins:?}"); +/// +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct Text(pub T); + +impl Text { + /// Extract the inner value. + pub fn into_inner(self) -> T { + self.0 + } +} + +impl Deref for Text { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Text { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +/* We shouldn't use blanket impls so individual drivers can provide specialized ones. +impl Type for Text +where + String: Type, + DB: Database, +{ + fn type_info() -> DB::TypeInfo { + String::type_info() + } + + fn compatible(ty: &DB::TypeInfo) -> bool { + String::compatible(ty) + } +} + +impl<'q, T, DB> Encode<'q, DB> for Text +where + T: Display, + String: Encode<'q, DB>, + DB: Database, +{ + fn encode_by_ref(&self, buf: &mut >::ArgumentBuffer) -> IsNull { + self.0.to_string().encode(buf) + } +} + +impl<'r, T, DB> Decode<'r, DB> for Text +where + T: FromStr, + BoxDynError: From<::Err>, + &'r str: Decode<'r, DB>, + DB: Database, +{ + fn decode(value: >::ValueRef) -> Result { + Ok(Text(<&'r str as Decode<'r, DB>>::decode(value)?.parse()?)) + } +} +*/ diff --git a/sqlx-mysql/src/types/mod.rs b/sqlx-mysql/src/types/mod.rs index 889c4d51..6cbbde71 100644 --- a/sqlx-mysql/src/types/mod.rs +++ b/sqlx-mysql/src/types/mod.rs @@ -104,6 +104,7 @@ mod bytes; mod float; mod int; mod str; +mod text; mod uint; #[cfg(feature = "json")] diff --git a/sqlx-mysql/src/types/text.rs b/sqlx-mysql/src/types/text.rs new file mode 100644 index 00000000..6b617289 --- /dev/null +++ b/sqlx-mysql/src/types/text.rs @@ -0,0 +1,49 @@ +use crate::{MySql, MySqlTypeInfo, MySqlValueRef}; +use sqlx_core::decode::Decode; +use sqlx_core::encode::{Encode, IsNull}; +use sqlx_core::error::BoxDynError; +use sqlx_core::types::{Text, Type}; +use std::fmt::Display; +use std::str::FromStr; + +impl Type for Text { + fn type_info() -> MySqlTypeInfo { + >::type_info() + } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + >::compatible(ty) + } +} + +impl<'q, T> Encode<'q, MySql> for Text +where + T: Display, +{ + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + // We can't really do the trick like with Postgres where we reserve the space for the + // length up-front and then overwrite it later, because MySQL appears to enforce that + // length-encoded integers use the smallest encoding for the value: + // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_integers.html#sect_protocol_basic_dt_int_le + // + // So we'd have to reserve space for the max-width encoding, format into the buffer, + // then figure out how many bytes our length-encoded integer needs to be and move the + // value bytes down to use up the empty space. + // + // Copying from a completely separate buffer instead is easier. It may or may not be faster + // or slower depending on a ton of different variables, but I don't currently have the time + // to implement both approaches and compare their performance. + Encode::::encode(self.0.to_string(), buf) + } +} + +impl<'r, T> Decode<'r, MySql> for Text +where + T: FromStr, + BoxDynError: From<::Err>, +{ + fn decode(value: MySqlValueRef<'r>) -> Result { + let s: &str = Decode::::decode(value)?; + Ok(Self(s.parse()?)) + } +} diff --git a/sqlx-postgres/src/types/array.rs b/sqlx-postgres/src/types/array.rs index cbe760a2..dac9b684 100644 --- a/sqlx-postgres/src/types/array.rs +++ b/sqlx-postgres/src/types/array.rs @@ -1,4 +1,5 @@ use sqlx_core::bytes::Buf; +use sqlx_core::types::Text; use std::borrow::Cow; use crate::decode::Decode; @@ -67,6 +68,16 @@ where } } +impl PgHasArrayType for Text { + fn array_type_info() -> PgTypeInfo { + String::array_type_info() + } + + fn array_compatible(ty: &PgTypeInfo) -> bool { + String::array_compatible(ty) + } +} + impl Type for [T] where T: PgHasArrayType, diff --git a/sqlx-postgres/src/types/mod.rs b/sqlx-postgres/src/types/mod.rs index 5a80b8e5..cf2c9ea9 100644 --- a/sqlx-postgres/src/types/mod.rs +++ b/sqlx-postgres/src/types/mod.rs @@ -193,6 +193,7 @@ mod oid; mod range; mod record; mod str; +mod text; mod tuple; mod void; diff --git a/sqlx-postgres/src/types/text.rs b/sqlx-postgres/src/types/text.rs new file mode 100644 index 00000000..7e96d03f --- /dev/null +++ b/sqlx-postgres/src/types/text.rs @@ -0,0 +1,50 @@ +use crate::{PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres}; +use sqlx_core::decode::Decode; +use sqlx_core::encode::{Encode, IsNull}; +use sqlx_core::error::BoxDynError; +use sqlx_core::types::{Text, Type}; +use std::fmt::Display; +use std::str::FromStr; + +use std::io::Write; + +impl Type for Text { + fn type_info() -> PgTypeInfo { + >::type_info() + } + + fn compatible(ty: &PgTypeInfo) -> bool { + >::compatible(ty) + } +} + +impl<'q, T> Encode<'q, Postgres> for Text +where + T: Display, +{ + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + // Unfortunately, our API design doesn't give us a way to bubble up the error here. + // + // Fortunately, writing to `Vec` is infallible so the only possible source of + // errors is from the implementation of `Display::fmt()` itself, + // where the onus is on the user. + // + // The blanket impl of `ToString` also panics if there's an error, so this is not + // unprecedented. + // + // However, the panic should be documented anyway. + write!(**buf, "{}", self.0).expect("unexpected error from `Display::fmt()`"); + IsNull::No + } +} + +impl<'r, T> Decode<'r, Postgres> for Text +where + T: FromStr, + BoxDynError: From<::Err>, +{ + fn decode(value: PgValueRef<'r>) -> Result { + let s: &str = Decode::::decode(value)?; + Ok(Self(s.parse()?)) + } +} diff --git a/sqlx-sqlite/src/types/mod.rs b/sqlx-sqlite/src/types/mod.rs index 33a1d584..54284c8a 100644 --- a/sqlx-sqlite/src/types/mod.rs +++ b/sqlx-sqlite/src/types/mod.rs @@ -124,12 +124,14 @@ //! over a floating-point type in the first place. //! //! Instead, you should only use a type affinity that SQLite will not attempt to convert implicitly, -//! such as `TEXT` or `BLOB`, and map values to/from SQLite as strings. +//! such as `TEXT` or `BLOB`, and map values to/from SQLite as strings. You can do this easily +//! using [the `Text` adapter]. +//! //! //! [`decimal.c`]: https://www.sqlite.org/floatingpoint.html#the_decimal_c_extension //! [amalgamation]: https://www.sqlite.org/amalgamation.html //! [type-affinity]: https://www.sqlite.org/datatype3.html#type_affinity -//! +//! [the `Text` adapter]: Text pub(crate) use sqlx_core::types::*; @@ -142,6 +144,7 @@ mod int; #[cfg(feature = "json")] mod json; mod str; +mod text; #[cfg(feature = "time")] mod time; mod uint; diff --git a/sqlx-sqlite/src/types/text.rs b/sqlx-sqlite/src/types/text.rs new file mode 100644 index 00000000..63fd01a8 --- /dev/null +++ b/sqlx-sqlite/src/types/text.rs @@ -0,0 +1,37 @@ +use crate::{Sqlite, SqliteArgumentValue, SqliteTypeInfo, SqliteValueRef}; +use sqlx_core::decode::Decode; +use sqlx_core::encode::{Encode, IsNull}; +use sqlx_core::error::BoxDynError; +use sqlx_core::types::{Text, Type}; +use std::fmt::Display; +use std::str::FromStr; + +impl Type for Text { + fn type_info() -> SqliteTypeInfo { + >::type_info() + } + + fn compatible(ty: &SqliteTypeInfo) -> bool { + >::compatible(ty) + } +} + +impl<'q, T> Encode<'q, Sqlite> for Text +where + T: Display, +{ + fn encode_by_ref(&self, buf: &mut Vec>) -> IsNull { + Encode::::encode(self.0.to_string(), buf) + } +} + +impl<'r, T> Decode<'r, Sqlite> for Text +where + T: FromStr, + BoxDynError: From<::Err>, +{ + fn decode(value: SqliteValueRef<'r>) -> Result { + let s: &str = Decode::::decode(value)?; + Ok(Self(s.parse()?)) + } +} diff --git a/tests/mysql/types.rs b/tests/mysql/types.rs index 497e4576..fad95b36 100644 --- a/tests/mysql/types.rs +++ b/tests/mysql/types.rs @@ -1,10 +1,14 @@ extern crate time_ as time; +use std::net::SocketAddr; #[cfg(feature = "rust_decimal")] use std::str::FromStr; use sqlx::mysql::MySql; use sqlx::{Executor, Row}; + +use sqlx::types::Text; + use sqlx_test::{new, test_type}; test_type!(bool(MySql, "false" == false, "true" == true)); @@ -68,9 +72,10 @@ test_type!(uuid_simple(MySql, #[cfg(feature = "chrono")] mod chrono { - use super::*; use sqlx::types::chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; + use super::*; + test_type!(chrono_date(MySql, "DATE '2001-01-05'" == NaiveDate::from_ymd(2001, 1, 5), "DATE '2050-11-23'" == NaiveDate::from_ymd(2050, 11, 23) @@ -137,10 +142,12 @@ mod chrono { #[cfg(feature = "time")] mod time_tests { - use super::*; - use sqlx::types::time::{Date, OffsetDateTime, PrimitiveDateTime, Time}; use time::macros::{date, time}; + use sqlx::types::time::{Date, OffsetDateTime, PrimitiveDateTime, Time}; + + use super::*; + test_type!(time_date( MySql, "DATE '2001-01-05'" == date!(2001 - 1 - 5), @@ -236,11 +243,13 @@ test_type!(decimal(MySql, #[cfg(feature = "json")] mod json_tests { - use super::*; use serde_json::{json, Value as JsonValue}; + use sqlx::types::Json; use sqlx_test::test_type; + use super::*; + test_type!(json( MySql, // MySQL 8.0.27 changed `<=>` to return an unsigned integer @@ -319,3 +328,46 @@ CREATE TEMPORARY TABLE with_bits ( Ok(()) } + +#[sqlx_macros::test] +async fn test_text_adapter() -> anyhow::Result<()> { + #[derive(sqlx::FromRow, Debug, PartialEq, Eq)] + struct Login { + user_id: i32, + socket_addr: Text, + #[cfg(feature = "time")] + login_at: time::OffsetDateTime, + } + + let mut conn = new::().await?; + + conn.execute( + r#" +CREATE TEMPORARY TABLE user_login ( + user_id INT PRIMARY KEY AUTO_INCREMENT, + socket_addr TEXT NOT NULL, + login_at TIMESTAMP NOT NULL +); + "#, + ) + .await?; + + let user_id = 1234; + let socket_addr: SocketAddr = "198.51.100.47:31790".parse().unwrap(); + + sqlx::query("INSERT INTO user_login (user_id, socket_addr, login_at) VALUES (?, ?, NOW())") + .bind(user_id) + .bind(Text(socket_addr)) + .execute(&mut conn) + .await?; + + let last_login: Login = + sqlx::query_as("SELECT * FROM user_login ORDER BY login_at DESC LIMIT 1") + .fetch_one(&mut conn) + .await?; + + assert_eq!(last_login.user_id, user_id); + assert_eq!(*last_login.socket_addr, socket_addr); + + Ok(()) +} diff --git a/tests/postgres/types.rs b/tests/postgres/types.rs index 57cf0b9d..184007ce 100644 --- a/tests/postgres/types.rs +++ b/tests/postgres/types.rs @@ -1,11 +1,14 @@ extern crate time_ as time; +use std::net::SocketAddr; use std::ops::Bound; use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange}; use sqlx::postgres::Postgres; -use sqlx_test::{test_decode_type, test_prepared_type, test_type}; +use sqlx_test::{new, test_decode_type, test_prepared_type, test_type}; +use sqlx_core::executor::Executor; +use sqlx_core::types::Text; use std::str::FromStr; test_type!(null>(Postgres, @@ -579,3 +582,46 @@ test_type!(ltree_vec>(Postgres, sqlx::postgres::types::PgLTree::from_iter(["Alpha", "Beta", "Delta", "Gamma"]).unwrap() ] )); + +#[sqlx_macros::test] +async fn test_text_adapter() -> anyhow::Result<()> { + #[derive(sqlx::FromRow, Debug, PartialEq, Eq)] + struct Login { + user_id: i32, + socket_addr: Text, + #[cfg(feature = "time")] + login_at: time::OffsetDateTime, + } + + let mut conn = new::().await?; + + conn.execute( + r#" +CREATE TEMPORARY TABLE user_login ( + user_id INT PRIMARY KEY, + socket_addr TEXT NOT NULL, + login_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + "#, + ) + .await?; + + let user_id = 1234; + let socket_addr: SocketAddr = "198.51.100.47:31790".parse().unwrap(); + + sqlx::query("INSERT INTO user_login (user_id, socket_addr) VALUES ($1, $2)") + .bind(user_id) + .bind(Text(socket_addr)) + .execute(&mut conn) + .await?; + + let last_login: Login = + sqlx::query_as("SELECT * FROM user_login ORDER BY login_at DESC LIMIT 1") + .fetch_one(&mut conn) + .await?; + + assert_eq!(last_login.user_id, user_id); + assert_eq!(*last_login.socket_addr, socket_addr); + + Ok(()) +} diff --git a/tests/sqlite/types.rs b/tests/sqlite/types.rs index 307e1409..71c788d4 100644 --- a/tests/sqlite/types.rs +++ b/tests/sqlite/types.rs @@ -1,9 +1,12 @@ extern crate time_ as time; use sqlx::sqlite::{Sqlite, SqliteRow}; +use sqlx_core::executor::Executor; use sqlx_core::row::Row; +use sqlx_core::types::Text; use sqlx_test::new; use sqlx_test::test_type; +use std::net::SocketAddr; test_type!(null>(Sqlite, "NULL" == None:: @@ -204,3 +207,46 @@ test_type!(uuid_simple(Sqlite, "'00000000000000000000000000000000'" == sqlx::types::Uuid::parse_str("00000000000000000000000000000000").unwrap().simple() )); + +#[sqlx_macros::test] +async fn test_text_adapter() -> anyhow::Result<()> { + #[derive(sqlx::FromRow, Debug, PartialEq, Eq)] + struct Login { + user_id: i32, + socket_addr: Text, + #[cfg(feature = "time")] + login_at: time::OffsetDateTime, + } + + let mut conn = new::().await?; + + conn.execute( + r#" +CREATE TEMPORARY TABLE user_login ( + user_id INT PRIMARY KEY, + socket_addr TEXT NOT NULL, + login_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); + "#, + ) + .await?; + + let user_id = 1234; + let socket_addr: SocketAddr = "198.51.100.47:31790".parse().unwrap(); + + sqlx::query("INSERT INTO user_login (user_id, socket_addr) VALUES (?, ?)") + .bind(user_id) + .bind(Text(socket_addr)) + .execute(&mut conn) + .await?; + + let last_login: Login = + sqlx::query_as("SELECT * FROM user_login ORDER BY login_at DESC LIMIT 1") + .fetch_one(&mut conn) + .await?; + + assert_eq!(last_login.user_id, user_id); + assert_eq!(*last_login.socket_addr, socket_addr); + + Ok(()) +}