feat: Text adapter (#2894)

This commit is contained in:
Austin Bonander 2023-11-22 17:06:47 -08:00 committed by GitHub
parent 62f82cc43a
commit 9fc9e7518e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 442 additions and 9 deletions

View file

@ -345,4 +345,3 @@ required-features = ["postgres", "macros", "migrate"]
name = "postgres-migrate"
path = "tests/postgres/migrate.rs"
required-features = ["postgres", "macros", "migrate"]

View file

@ -92,5 +92,5 @@ event-listener = "2.5.2"
dotenvy = "0.15"
[dev-dependencies]
sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros"] }
sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] }
tokio = { version = "1", features = ["rt"] }

View file

@ -28,6 +28,8 @@ pub mod bstr;
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
mod json;
mod text;
#[cfg(feature = "uuid")]
#[cfg_attr(docsrs, doc(cfg(feature = "uuid")))]
#[doc(no_inline)]
@ -81,6 +83,8 @@ pub mod mac_address {
#[cfg(feature = "json")]
pub use json::{Json, JsonRawValue, JsonValue};
pub use text::Text;
/// Indicates that a SQL type is supported for a database.
///
/// ## Compile-time verification

134
sqlx-core/src/types/text.rs Normal file
View file

@ -0,0 +1,134 @@
use std::ops::{Deref, DerefMut};
/// Map a SQL text value to/from a Rust type using [`Display`] and [`FromStr`].
///
/// This can be useful for types that do not have a direct SQL equivalent, or are simply not
/// supported by SQLx for one reason or another.
///
/// For strongly typed databases like Postgres, this will report the value's type as `TEXT`.
/// Explicit conversion may be necessary on the SQL side depending on the desired type.
///
/// [`Display`]: std::fmt::Display
/// [`FromStr`]: std::str::FromStr
///
/// ### Panics
///
/// You should only use this adapter with `Display` implementations that are infallible,
/// otherwise you may encounter panics when attempting to bind a value.
///
/// This is because the design of the `Encode` trait assumes encoding is infallible, so there is no
/// way to bubble up the error.
///
/// Fortunately, most `Display` implementations are infallible by convention anyway
/// (the standard `ToString` trait also assumes this), but you may still want to audit
/// the source code for any types you intend to use with this adapter, just to be safe.
///
/// ### Example: `SocketAddr`
///
/// MySQL and SQLite do not have a native SQL equivalent for `SocketAddr`, so if you want to
/// store and retrieve instances of it, it makes sense to map it to `TEXT`:
///
/// ```rust,no_run
/// # use sqlx::types::{time, uuid};
///
/// use std::net::SocketAddr;
///
/// use sqlx::Connection;
/// use sqlx::mysql::MySqlConnection;
/// use sqlx::types::Text;
///
/// use uuid::Uuid;
/// use time::OffsetDateTime;
///
/// #[derive(sqlx::FromRow, Debug)]
/// struct Login {
/// user_id: Uuid,
/// socket_addr: Text<SocketAddr>,
/// login_at: OffsetDateTime
/// }
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
///
/// let mut conn: MySqlConnection = MySqlConnection::connect("<DATABASE URL>").await?;
///
/// let user_id: Uuid = "e9a72cdc-d907-48d6-a488-c64a91fd063c".parse().unwrap();
/// let socket_addr: SocketAddr = "198.51.100.47:31790".parse().unwrap();
///
/// // CREATE TABLE user_login(user_id VARCHAR(36), socket_addr TEXT, login_at TIMESTAMP);
/// sqlx::query("INSERT INTO user_login(user_id, socket_addr, login_at) VALUES (?, ?, NOW())")
/// .bind(user_id)
/// .bind(Text(socket_addr))
/// .execute(&mut conn)
/// .await?;
///
/// let logins: Vec<Login> = sqlx::query_as("SELECT * FROM user_login")
/// .fetch_all(&mut conn)
/// .await?;
///
/// println!("Logins for user ID {user_id}: {logins:?}");
///
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Text<T>(pub T);
impl<T> Text<T> {
/// Extract the inner value.
pub fn into_inner(self) -> T {
self.0
}
}
impl<T> Deref for Text<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> DerefMut for Text<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
/* We shouldn't use blanket impls so individual drivers can provide specialized ones.
impl<T, DB> Type<DB> for Text<T>
where
String: Type<DB>,
DB: Database,
{
fn type_info() -> DB::TypeInfo {
String::type_info()
}
fn compatible(ty: &DB::TypeInfo) -> bool {
String::compatible(ty)
}
}
impl<'q, T, DB> Encode<'q, DB> for Text<T>
where
T: Display,
String: Encode<'q, DB>,
DB: Database,
{
fn encode_by_ref(&self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
self.0.to_string().encode(buf)
}
}
impl<'r, T, DB> Decode<'r, DB> for Text<T>
where
T: FromStr,
BoxDynError: From<<T as FromStr>::Err>,
&'r str: Decode<'r, DB>,
DB: Database,
{
fn decode(value: <DB as HasValueRef<'r>>::ValueRef) -> Result<Self, BoxDynError> {
Ok(Text(<&'r str as Decode<'r, DB>>::decode(value)?.parse()?))
}
}
*/

View file

@ -104,6 +104,7 @@ mod bytes;
mod float;
mod int;
mod str;
mod text;
mod uint;
#[cfg(feature = "json")]

View file

@ -0,0 +1,49 @@
use crate::{MySql, MySqlTypeInfo, MySqlValueRef};
use sqlx_core::decode::Decode;
use sqlx_core::encode::{Encode, IsNull};
use sqlx_core::error::BoxDynError;
use sqlx_core::types::{Text, Type};
use std::fmt::Display;
use std::str::FromStr;
impl<T> Type<MySql> for Text<T> {
fn type_info() -> MySqlTypeInfo {
<String as Type<MySql>>::type_info()
}
fn compatible(ty: &MySqlTypeInfo) -> bool {
<String as Type<MySql>>::compatible(ty)
}
}
impl<'q, T> Encode<'q, MySql> for Text<T>
where
T: Display,
{
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
// We can't really do the trick like with Postgres where we reserve the space for the
// length up-front and then overwrite it later, because MySQL appears to enforce that
// length-encoded integers use the smallest encoding for the value:
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_dt_integers.html#sect_protocol_basic_dt_int_le
//
// So we'd have to reserve space for the max-width encoding, format into the buffer,
// then figure out how many bytes our length-encoded integer needs to be and move the
// value bytes down to use up the empty space.
//
// Copying from a completely separate buffer instead is easier. It may or may not be faster
// or slower depending on a ton of different variables, but I don't currently have the time
// to implement both approaches and compare their performance.
Encode::<MySql>::encode(self.0.to_string(), buf)
}
}
impl<'r, T> Decode<'r, MySql> for Text<T>
where
T: FromStr,
BoxDynError: From<<T as FromStr>::Err>,
{
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
let s: &str = Decode::<MySql>::decode(value)?;
Ok(Self(s.parse()?))
}
}

View file

@ -1,4 +1,5 @@
use sqlx_core::bytes::Buf;
use sqlx_core::types::Text;
use std::borrow::Cow;
use crate::decode::Decode;
@ -67,6 +68,16 @@ where
}
}
impl<T> PgHasArrayType for Text<T> {
fn array_type_info() -> PgTypeInfo {
String::array_type_info()
}
fn array_compatible(ty: &PgTypeInfo) -> bool {
String::array_compatible(ty)
}
}
impl<T> Type<Postgres> for [T]
where
T: PgHasArrayType,

View file

@ -193,6 +193,7 @@ mod oid;
mod range;
mod record;
mod str;
mod text;
mod tuple;
mod void;

View file

@ -0,0 +1,50 @@
use crate::{PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres};
use sqlx_core::decode::Decode;
use sqlx_core::encode::{Encode, IsNull};
use sqlx_core::error::BoxDynError;
use sqlx_core::types::{Text, Type};
use std::fmt::Display;
use std::str::FromStr;
use std::io::Write;
impl<T> Type<Postgres> for Text<T> {
fn type_info() -> PgTypeInfo {
<String as Type<Postgres>>::type_info()
}
fn compatible(ty: &PgTypeInfo) -> bool {
<String as Type<Postgres>>::compatible(ty)
}
}
impl<'q, T> Encode<'q, Postgres> for Text<T>
where
T: Display,
{
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
// Unfortunately, our API design doesn't give us a way to bubble up the error here.
//
// Fortunately, writing to `Vec<u8>` is infallible so the only possible source of
// errors is from the implementation of `Display::fmt()` itself,
// where the onus is on the user.
//
// The blanket impl of `ToString` also panics if there's an error, so this is not
// unprecedented.
//
// However, the panic should be documented anyway.
write!(**buf, "{}", self.0).expect("unexpected error from `Display::fmt()`");
IsNull::No
}
}
impl<'r, T> Decode<'r, Postgres> for Text<T>
where
T: FromStr,
BoxDynError: From<<T as FromStr>::Err>,
{
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
let s: &str = Decode::<Postgres>::decode(value)?;
Ok(Self(s.parse()?))
}
}

View file

@ -124,12 +124,14 @@
//! over a floating-point type in the first place.
//!
//! Instead, you should only use a type affinity that SQLite will not attempt to convert implicitly,
//! such as `TEXT` or `BLOB`, and map values to/from SQLite as strings.
//! such as `TEXT` or `BLOB`, and map values to/from SQLite as strings. You can do this easily
//! using [the `Text` adapter].
//!
//!
//! [`decimal.c`]: https://www.sqlite.org/floatingpoint.html#the_decimal_c_extension
//! [amalgamation]: https://www.sqlite.org/amalgamation.html
//! [type-affinity]: https://www.sqlite.org/datatype3.html#type_affinity
//!
//! [the `Text` adapter]: Text
pub(crate) use sqlx_core::types::*;
@ -142,6 +144,7 @@ mod int;
#[cfg(feature = "json")]
mod json;
mod str;
mod text;
#[cfg(feature = "time")]
mod time;
mod uint;

View file

@ -0,0 +1,37 @@
use crate::{Sqlite, SqliteArgumentValue, SqliteTypeInfo, SqliteValueRef};
use sqlx_core::decode::Decode;
use sqlx_core::encode::{Encode, IsNull};
use sqlx_core::error::BoxDynError;
use sqlx_core::types::{Text, Type};
use std::fmt::Display;
use std::str::FromStr;
impl<T> Type<Sqlite> for Text<T> {
fn type_info() -> SqliteTypeInfo {
<String as Type<Sqlite>>::type_info()
}
fn compatible(ty: &SqliteTypeInfo) -> bool {
<String as Type<Sqlite>>::compatible(ty)
}
}
impl<'q, T> Encode<'q, Sqlite> for Text<T>
where
T: Display,
{
fn encode_by_ref(&self, buf: &mut Vec<SqliteArgumentValue<'q>>) -> IsNull {
Encode::<Sqlite>::encode(self.0.to_string(), buf)
}
}
impl<'r, T> Decode<'r, Sqlite> for Text<T>
where
T: FromStr,
BoxDynError: From<<T as FromStr>::Err>,
{
fn decode(value: SqliteValueRef<'r>) -> Result<Self, BoxDynError> {
let s: &str = Decode::<Sqlite>::decode(value)?;
Ok(Self(s.parse()?))
}
}

View file

@ -1,10 +1,14 @@
extern crate time_ as time;
use std::net::SocketAddr;
#[cfg(feature = "rust_decimal")]
use std::str::FromStr;
use sqlx::mysql::MySql;
use sqlx::{Executor, Row};
use sqlx::types::Text;
use sqlx_test::{new, test_type};
test_type!(bool(MySql, "false" == false, "true" == true));
@ -68,9 +72,10 @@ test_type!(uuid_simple<sqlx::types::uuid::fmt::Simple>(MySql,
#[cfg(feature = "chrono")]
mod chrono {
use super::*;
use sqlx::types::chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
use super::*;
test_type!(chrono_date<NaiveDate>(MySql,
"DATE '2001-01-05'" == NaiveDate::from_ymd(2001, 1, 5),
"DATE '2050-11-23'" == NaiveDate::from_ymd(2050, 11, 23)
@ -137,10 +142,12 @@ mod chrono {
#[cfg(feature = "time")]
mod time_tests {
use super::*;
use sqlx::types::time::{Date, OffsetDateTime, PrimitiveDateTime, Time};
use time::macros::{date, time};
use sqlx::types::time::{Date, OffsetDateTime, PrimitiveDateTime, Time};
use super::*;
test_type!(time_date<Date>(
MySql,
"DATE '2001-01-05'" == date!(2001 - 1 - 5),
@ -236,11 +243,13 @@ test_type!(decimal<sqlx::types::Decimal>(MySql,
#[cfg(feature = "json")]
mod json_tests {
use super::*;
use serde_json::{json, Value as JsonValue};
use sqlx::types::Json;
use sqlx_test::test_type;
use super::*;
test_type!(json<JsonValue>(
MySql,
// MySQL 8.0.27 changed `<=>` to return an unsigned integer
@ -319,3 +328,46 @@ CREATE TEMPORARY TABLE with_bits (
Ok(())
}
#[sqlx_macros::test]
async fn test_text_adapter() -> anyhow::Result<()> {
#[derive(sqlx::FromRow, Debug, PartialEq, Eq)]
struct Login {
user_id: i32,
socket_addr: Text<SocketAddr>,
#[cfg(feature = "time")]
login_at: time::OffsetDateTime,
}
let mut conn = new::<MySql>().await?;
conn.execute(
r#"
CREATE TEMPORARY TABLE user_login (
user_id INT PRIMARY KEY AUTO_INCREMENT,
socket_addr TEXT NOT NULL,
login_at TIMESTAMP NOT NULL
);
"#,
)
.await?;
let user_id = 1234;
let socket_addr: SocketAddr = "198.51.100.47:31790".parse().unwrap();
sqlx::query("INSERT INTO user_login (user_id, socket_addr, login_at) VALUES (?, ?, NOW())")
.bind(user_id)
.bind(Text(socket_addr))
.execute(&mut conn)
.await?;
let last_login: Login =
sqlx::query_as("SELECT * FROM user_login ORDER BY login_at DESC LIMIT 1")
.fetch_one(&mut conn)
.await?;
assert_eq!(last_login.user_id, user_id);
assert_eq!(*last_login.socket_addr, socket_addr);
Ok(())
}

View file

@ -1,11 +1,14 @@
extern crate time_ as time;
use std::net::SocketAddr;
use std::ops::Bound;
use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange};
use sqlx::postgres::Postgres;
use sqlx_test::{test_decode_type, test_prepared_type, test_type};
use sqlx_test::{new, test_decode_type, test_prepared_type, test_type};
use sqlx_core::executor::Executor;
use sqlx_core::types::Text;
use std::str::FromStr;
test_type!(null<Option<i16>>(Postgres,
@ -579,3 +582,46 @@ test_type!(ltree_vec<Vec<sqlx::postgres::types::PgLTree>>(Postgres,
sqlx::postgres::types::PgLTree::from_iter(["Alpha", "Beta", "Delta", "Gamma"]).unwrap()
]
));
#[sqlx_macros::test]
async fn test_text_adapter() -> anyhow::Result<()> {
#[derive(sqlx::FromRow, Debug, PartialEq, Eq)]
struct Login {
user_id: i32,
socket_addr: Text<SocketAddr>,
#[cfg(feature = "time")]
login_at: time::OffsetDateTime,
}
let mut conn = new::<Postgres>().await?;
conn.execute(
r#"
CREATE TEMPORARY TABLE user_login (
user_id INT PRIMARY KEY,
socket_addr TEXT NOT NULL,
login_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"#,
)
.await?;
let user_id = 1234;
let socket_addr: SocketAddr = "198.51.100.47:31790".parse().unwrap();
sqlx::query("INSERT INTO user_login (user_id, socket_addr) VALUES ($1, $2)")
.bind(user_id)
.bind(Text(socket_addr))
.execute(&mut conn)
.await?;
let last_login: Login =
sqlx::query_as("SELECT * FROM user_login ORDER BY login_at DESC LIMIT 1")
.fetch_one(&mut conn)
.await?;
assert_eq!(last_login.user_id, user_id);
assert_eq!(*last_login.socket_addr, socket_addr);
Ok(())
}

View file

@ -1,9 +1,12 @@
extern crate time_ as time;
use sqlx::sqlite::{Sqlite, SqliteRow};
use sqlx_core::executor::Executor;
use sqlx_core::row::Row;
use sqlx_core::types::Text;
use sqlx_test::new;
use sqlx_test::test_type;
use std::net::SocketAddr;
test_type!(null<Option<i32>>(Sqlite,
"NULL" == None::<i32>
@ -204,3 +207,46 @@ test_type!(uuid_simple<sqlx::types::uuid::fmt::Simple>(Sqlite,
"'00000000000000000000000000000000'"
== sqlx::types::Uuid::parse_str("00000000000000000000000000000000").unwrap().simple()
));
#[sqlx_macros::test]
async fn test_text_adapter() -> anyhow::Result<()> {
#[derive(sqlx::FromRow, Debug, PartialEq, Eq)]
struct Login {
user_id: i32,
socket_addr: Text<SocketAddr>,
#[cfg(feature = "time")]
login_at: time::OffsetDateTime,
}
let mut conn = new::<Sqlite>().await?;
conn.execute(
r#"
CREATE TEMPORARY TABLE user_login (
user_id INT PRIMARY KEY,
socket_addr TEXT NOT NULL,
login_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);
"#,
)
.await?;
let user_id = 1234;
let socket_addr: SocketAddr = "198.51.100.47:31790".parse().unwrap();
sqlx::query("INSERT INTO user_login (user_id, socket_addr) VALUES (?, ?)")
.bind(user_id)
.bind(Text(socket_addr))
.execute(&mut conn)
.await?;
let last_login: Login =
sqlx::query_as("SELECT * FROM user_login ORDER BY login_at DESC LIMIT 1")
.fetch_one(&mut conn)
.await?;
assert_eq!(last_login.user_id, user_id);
assert_eq!(*last_login.socket_addr, socket_addr);
Ok(())
}