unified prepared statement interface

This commit is contained in:
Austin Bonander 2019-11-12 20:57:55 -08:00 committed by Austin Bonander
parent c993e4eee0
commit bbdc03c576
29 changed files with 654 additions and 253 deletions

View file

@ -8,14 +8,14 @@ edition = "2018"
proc-macro = true proc-macro = true
[dependencies] [dependencies]
dotenv = "0.15.0"
futures-preview = "0.3.0-alpha.18" futures-preview = "0.3.0-alpha.18"
hex = "0.4.0"
proc-macro2 = "1.0.6" proc-macro2 = "1.0.6"
sqlx = { path = "../", features = ["postgres"] } sqlx = { path = "../", features = ["postgres"] }
syn = "1.0" syn = "1.0"
quote = "1.0" quote = "1.0"
sha2 = "0.8.0"
tokio = { version = "0.2.0-alpha.4", default-features = false, features = [ "tcp" ] } tokio = { version = "0.2.0-alpha.4", default-features = false, features = [ "tcp" ] }
url = "2.1.0"
[features] [features]
postgres = ["sqlx/postgres"] postgres = ["sqlx/postgres"]

View file

@ -6,47 +6,52 @@ use proc_macro2::Span;
use quote::{format_ident, quote, quote_spanned, ToTokens}; use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::{parse_macro_input, Expr, ExprLit, Lit, LitStr, Token, Type}; use syn::{
use syn::spanned::Spanned; parse::{self, Parse, ParseStream},
use syn::punctuated::Punctuated; parse_macro_input,
use syn::parse::{self, Parse, ParseStream}; punctuated::Punctuated,
spanned::Spanned,
Expr, ExprLit, Lit, LitStr, Token, Type,
};
use sha2::{Sha256, Digest}; use sqlx::{HasTypeMetadata, Postgres};
use sqlx::Postgres;
use tokio::runtime::Runtime; use tokio::runtime::Runtime;
use std::error::Error as _; use std::{error::Error as _, fmt::Display, str::FromStr};
use url::Url;
type Error = Box<dyn std::error::Error>; type Error = Box<dyn std::error::Error>;
type Result<T> = std::result::Result<T, Error>; type Result<T> = std::result::Result<T, Error>;
mod postgres;
struct MacroInput { struct MacroInput {
sql: String, sql: String,
sql_span: Span, sql_span: Span,
args: Vec<Expr> args: Vec<Expr>,
} }
impl Parse for MacroInput { impl Parse for MacroInput {
fn parse(input: ParseStream) -> parse::Result<Self> { fn parse(input: ParseStream) -> parse::Result<Self> {
let mut args = Punctuated::<Expr, Token![,]>::parse_terminated(input)? let mut args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?.into_iter();
.into_iter();
let sql = match args.next() { let sql = match args.next() {
Some(Expr::Lit(ExprLit { lit: Lit::Str(sql), .. })) => sql, Some(Expr::Lit(ExprLit {
Some(other_expr) => return Err(parse::Error::new_spanned(other_expr, "expected string literal")), lit: Lit::Str(sql), ..
})) => sql,
Some(other_expr) => {
return Err(parse::Error::new_spanned(
other_expr,
"expected string literal",
));
}
None => return Err(input.error("expected SQL string literal")), None => return Err(input.error("expected SQL string literal")),
}; };
Ok( Ok(MacroInput {
MacroInput {
sql: sql.value(), sql: sql.value(),
sql_span: sql.span(), sql_span: sql.span(),
args: args.collect(), args: args.collect(),
} })
)
} }
} }
@ -56,51 +61,108 @@ pub fn sql(input: TokenStream) -> TokenStream {
eprintln!("expanding macro"); eprintln!("expanding macro");
match Runtime::new().map_err(Error::from).and_then(|runtime| runtime.block_on(process_sql(input))) { match Runtime::new()
.map_err(Error::from)
.and_then(|runtime| runtime.block_on(process_sql(input)))
{
Ok(ts) => { Ok(ts) => {
eprintln!("emitting output: {}", ts); eprintln!("emitting output: {}", ts);
ts ts
}, }
Err(e) => { Err(e) => {
if let Some(parse_err) = e.downcast_ref::<parse::Error>() { if let Some(parse_err) = e.downcast_ref::<parse::Error>() {
return parse_err.to_compile_error().into(); return parse_err.to_compile_error().into();
} }
let msg = e.to_string(); let msg = e.to_string();
quote! ( compile_error!(#msg) ).into() quote!(compile_error!(#msg)).into()
} }
} }
} }
async fn process_sql(input: MacroInput) -> Result<TokenStream> { async fn process_sql(input: MacroInput) -> Result<TokenStream> {
let hash = dbg!(hex::encode(&Sha256::digest(input.sql.as_bytes()))); let db_url = Url::parse(&dotenv::var("DB_URL")?)?;
let conn = sqlx::Connection::<Postgres>::establish("postgresql://postgres@127.0.0.1/sqlx_test") match db_url.scheme() {
#[cfg(feature = "postgres")]
"postgresql" => {
process_sql_with(
input,
sqlx::Connection::<sqlx::Postgres>::establish(db_url.as_str())
.await .await
.map_err(|e| format!("failed to connect to database: {}", e))?; .map_err(|e| format!("failed to connect to database: {}", e))?,
)
.await
}
#[cfg(feature = "mysql")]
"mysql" => {
process_sql_with(
input,
sqlx::Connection::<sqlx::MariaDb>::establish(
"postgresql://postgres@127.0.0.1/sqlx_test",
)
.await
.map_err(|e| format!("failed to connect to database: {}", e))?,
)
.await
}
scheme => Err(format!("unexpected scheme {:?} in DB_URL {}", scheme, db_url).into()),
}
}
async fn process_sql_with<DB: sqlx::Backend>(
input: MacroInput,
conn: sqlx::Connection<DB>,
) -> Result<TokenStream>
where
<DB as HasTypeMetadata>::TypeId: Display,
{
eprintln!("connection established"); eprintln!("connection established");
let prepared = conn.prepare(&hash, &input.sql) let prepared = conn
.prepare(&input.sql)
.await .await
.map_err(|e| parse::Error::new(input.sql_span, e))?; .map_err(|e| parse::Error::new(input.sql_span, e))?;
if input.args.len() != prepared.param_types.len() { if input.args.len() != prepared.param_types.len() {
return Err(parse::Error::new( return Err(parse::Error::new(
Span::call_site(), Span::call_site(),
format!("expected {} parameters, got {}", prepared.param_types.len(), input.args.len()) format!(
).into()); "expected {} parameters, got {}",
prepared.param_types.len(),
input.args.len()
),
)
.into());
} }
let param_types = prepared.param_types.iter().zip(&*input.args).map(|(type_, expr)| { let param_types = prepared
.param_types
.iter()
.zip(&*input.args)
.map(|(type_, expr)| {
get_type_override(expr) get_type_override(expr)
.or_else(|| postgres::map_param_type_oid(*type_)) .or_else(|| {
.ok_or_else(|| format!("unknown type OID: {}", type_).into()) Some(
<DB as sqlx::HasTypeMetadata>::param_type_for_id(type_)?
.parse::<proc_macro2::TokenStream>()
.unwrap(),
)
})
.ok_or_else(|| format!("unknown type ID: {}", type_).into())
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let output_types = prepared.fields.iter().map(|field| { let output_types = prepared
postgres::map_output_type_oid(field.type_id) .columns
.iter()
.map(|column| {
Ok(
<DB as sqlx::HasTypeMetadata>::return_type_for_id(&column.type_id)
.ok_or_else(|| format!("unknown type ID: {}", &column.type_id))?
.parse::<proc_macro2::TokenStream>()
.unwrap(),
)
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
@ -112,8 +174,7 @@ async fn process_sql(input: MacroInput) -> Result<TokenStream> {
let query = &input.sql; let query = &input.sql;
Ok( Ok(quote! {{
quote! {{
use sqlx::TyConsExt as _; use sqlx::TyConsExt as _;
let params = (#(#params),*,); let params = (#(#params),*,);
@ -129,8 +190,7 @@ async fn process_sql(input: MacroInput) -> Result<TokenStream> {
backend: ::core::marker::PhantomData, backend: ::core::marker::PhantomData,
} }
}} }}
.into() .into())
)
} }
fn get_type_override(expr: &Expr) -> Option<proc_macro2::TokenStream> { fn get_type_override(expr: &Expr) -> Option<proc_macro2::TokenStream> {

View file

@ -1,45 +0,0 @@
use proc_macro2::TokenStream;
pub fn map_param_type_oid(oid: u32) -> Option<TokenStream> {
Some(match oid {
16 => "bool",
1000 => "&[bool]",
25 => "&str",
1009 => "&[&str]",
21 => "i16",
1005 => "&[i16]",
23 => "i32",
1007 => "&[i32]",
20 => "i64",
1016 => "&[i64]",
700 => "f32",
1021 => "&[f32]",
701 => "f64",
1022 => "&[f64]",
2950 => "sqlx::Uuid",
2951 => "&[sqlx::Uuid]",
_ => return None
}.parse().unwrap())
}
pub fn map_output_type_oid(oid: u32) -> crate::Result<TokenStream> {
Ok(match oid {
16 => "bool",
1000 => "Vec<bool>",
25 => "String",
1009 => "Vec<String>",
21 => "i16",
1005 => "Vec<i16>",
23 => "i32",
1007 => "Vec<i32>",
20 => "i64",
1016 => "Vec<i64>",
700 => "f32",
1021 => "Vec<f32>",
701 => "f64",
1022 => "Vec<f64>",
2950 => "sqlx::Uuid",
2951 => "Vec<sqlx::Uuid>",
_ => return Err(format!("unknown type ID: {}", oid).into())
}.parse().unwrap())
}

View file

@ -1,9 +1,9 @@
use crate::{connection::RawConnection, query::QueryParameters, row::Row}; use crate::{connection::RawConnection, query::QueryParameters, row::Row, types::HasTypeMetadata};
/// A database backend. /// A database backend.
/// ///
/// This trait represents the concept of a backend (e.g. "MySQL" vs "SQLite"). /// This trait represents the concept of a backend (e.g. "MySQL" vs "SQLite").
pub trait Backend: Sized { pub trait Backend: HasTypeMetadata + Sized {
/// The concrete `QueryParameters` implementation for this backend. /// The concrete `QueryParameters` implementation for this backend.
type QueryParameters: QueryParameters<Backend = Self>; type QueryParameters: QueryParameters<Backend = Self>;
@ -13,4 +13,12 @@ pub trait Backend: Sized {
/// The concrete `Row` implementation for this backend. This type is returned /// The concrete `Row` implementation for this backend. This type is returned
/// from methods in the `RawConnection`. /// from methods in the `RawConnection`.
type Row: Row<Backend = Self>; type Row: Row<Backend = Self>;
/// The identifier for prepared statements; in Postgres this is a string
/// and in MariaDB/MySQL this is an integer.
type StatementIdent;
/// The identifier for tables; in Postgres this is an `oid` while
/// in MariaDB/MySQL this is the qualified name of the table.
type TableIdent;
} }

View file

@ -1,5 +1,8 @@
use std::{env, str}; use std::{
use std::io::{self, Write, Read}; env,
io::{self, Read, Write},
str,
};
use std::process::{Command, Stdio}; use std::process::{Command, Stdio};
@ -11,13 +14,17 @@ fn get_expanded_target() -> crate::Result<Vec<u8>> {
let mut args = env::args_os().skip(2); let mut args = env::args_os().skip(2);
let cargo_args = args.by_ref().take_while(|arg| arg != "--").collect::<Vec<_>>(); let cargo_args = args
.by_ref()
.take_while(|arg| arg != "--")
.collect::<Vec<_>>();
let rustc_args = args.collect::<Vec<_>>(); let rustc_args = args.collect::<Vec<_>>();
let mut command = Command::new(cargo_path); let mut command = Command::new(cargo_path);
command.arg("rustc") command
.arg("rustc")
.args(cargo_args) .args(cargo_args)
.arg("--") .arg("--")
.arg("-Z") .arg("-Z")
@ -61,7 +68,7 @@ fn find_next_sql_string(input: &str) -> Result<Option<(&str, &str)>> {
let start = idx + STRING_START.len(); let start = idx + STRING_START.len();
while let Some(end) = input[start..].find(STRING_END) { while let Some(end) = input[start..].find(STRING_END) {
if &input[start + end - 1 .. start + end] != "\\" { if &input[start + end - 1..start + end] != "\\" {
return Ok(Some(input[start..].split_at(end))); return Ok(Some(input[start..].split_at(end)));
} }
} }

View file

@ -1,9 +1,6 @@
use crate::{query::IntoQueryParameters, Backend, Executor, FromSqlRow};
use futures_core::{future::BoxFuture, stream::BoxStream, Stream};
use std::marker::PhantomData; use std::marker::PhantomData;
use crate::query::IntoQueryParameters;
use crate::{Backend, FromSqlRow, Executor};
use futures_core::Stream;
use futures_core::stream::BoxStream;
use futures_core::future::BoxFuture;
pub struct CompiledSql<P, O, DB> { pub struct CompiledSql<P, O, DB> {
#[doc(hidden)] #[doc(hidden)]
@ -12,14 +9,22 @@ pub struct CompiledSql<P, O, DB> {
pub params: P, pub params: P,
#[doc(hidden)] #[doc(hidden)]
pub output: PhantomData<O>, pub output: PhantomData<O>,
pub backend: PhantomData<DB> pub backend: PhantomData<DB>,
} }
impl<DB, P, O> CompiledSql<P, O, DB> where DB: Backend, P: IntoQueryParameters<DB> + Send, O: FromSqlRow<DB> + Send + Unpin { impl<DB, P, O> CompiledSql<P, O, DB>
where
DB: Backend,
P: IntoQueryParameters<DB> + Send,
O: FromSqlRow<DB> + Send + Unpin,
{
#[inline] #[inline]
pub fn execute<'e, E: 'e>(self, executor: &'e E) -> BoxFuture<'e, crate::Result<u64>> pub fn execute<'e, E: 'e>(self, executor: &'e E) -> BoxFuture<'e, crate::Result<u64>>
where where
E: Executor<Backend = DB>, DB: 'e, P: 'e, O: 'e E: Executor<Backend = DB>,
DB: 'e,
P: 'e,
O: 'e,
{ {
executor.execute(self.query, self.params) executor.execute(self.query, self.params)
} }
@ -27,7 +32,10 @@ impl<DB, P, O> CompiledSql<P, O, DB> where DB: Backend, P: IntoQueryParameters<D
#[inline] #[inline]
pub fn fetch<'e, E: 'e>(self, executor: &'e E) -> BoxStream<'e, crate::Result<O>> pub fn fetch<'e, E: 'e>(self, executor: &'e E) -> BoxStream<'e, crate::Result<O>>
where where
E: Executor<Backend = DB>, DB: 'e, P: 'e, O: 'e E: Executor<Backend = DB>,
DB: 'e,
P: 'e,
O: 'e,
{ {
executor.fetch(self.query, self.params) executor.fetch(self.query, self.params)
} }
@ -35,7 +43,10 @@ impl<DB, P, O> CompiledSql<P, O, DB> where DB: Backend, P: IntoQueryParameters<D
#[inline] #[inline]
pub fn fetch_all<'e, E: 'e>(self, executor: &'e E) -> BoxFuture<'e, crate::Result<Vec<O>>> pub fn fetch_all<'e, E: 'e>(self, executor: &'e E) -> BoxFuture<'e, crate::Result<Vec<O>>>
where where
E: Executor<Backend = DB>, DB: 'e, P: 'e, O: 'e E: Executor<Backend = DB>,
DB: 'e,
P: 'e,
O: 'e,
{ {
executor.fetch_all(self.query, self.params) executor.fetch_all(self.query, self.params)
} }
@ -46,7 +57,10 @@ impl<DB, P, O> CompiledSql<P, O, DB> where DB: Backend, P: IntoQueryParameters<D
executor: &'e E, executor: &'e E,
) -> BoxFuture<'e, crate::Result<Option<O>>> ) -> BoxFuture<'e, crate::Result<Option<O>>>
where where
E: Executor<Backend = DB>, DB: 'e, P: 'e, O: 'e E: Executor<Backend = DB>,
DB: 'e,
P: 'e,
O: 'e,
{ {
executor.fetch_optional(self.query, self.params) executor.fetch_optional(self.query, self.params)
} }
@ -54,7 +68,10 @@ impl<DB, P, O> CompiledSql<P, O, DB> where DB: Backend, P: IntoQueryParameters<D
#[inline] #[inline]
pub fn fetch_one<'e, E: 'e>(self, executor: &'e E) -> BoxFuture<'e, crate::Result<O>> pub fn fetch_one<'e, E: 'e>(self, executor: &'e E) -> BoxFuture<'e, crate::Result<O>>
where where
E: Executor<Backend = DB>, DB: 'e, P: 'e, O: 'e E: Executor<Backend = DB>,
DB: 'e,
P: 'e,
O: 'e,
{ {
executor.fetch_one(self.query, self.params) executor.fetch_one(self.query, self.params)
} }

View file

@ -3,6 +3,7 @@ use crate::{
error::Error, error::Error,
executor::Executor, executor::Executor,
pool::{Live, SharedPool}, pool::{Live, SharedPool},
prepared::PreparedStatement,
query::{IntoQueryParameters, QueryParameters}, query::{IntoQueryParameters, QueryParameters},
row::FromSqlRow, row::FromSqlRow,
}; };
@ -19,7 +20,6 @@ use std::{
}, },
time::Instant, time::Instant,
}; };
use crate::prepared::PreparedStatement;
/// A connection.bak to the database. /// A connection.bak to the database.
/// ///
@ -73,10 +73,15 @@ pub trait RawConnection: Send {
params: <Self::Backend as Backend>::QueryParameters, params: <Self::Backend as Backend>::QueryParameters,
) -> crate::Result<Option<<Self::Backend as Backend>::Row>>; ) -> crate::Result<Option<<Self::Backend as Backend>::Row>>;
async fn prepare(&mut self, name: &str, body: &str) -> crate::Result<PreparedStatement> { async fn prepare(
// TODO: implement for other backends &mut self,
unimplemented!() query: &str,
} ) -> crate::Result<<Self::Backend as Backend>::StatementIdent>;
async fn prepare_describe(
&mut self,
query: &str,
) -> crate::Result<PreparedStatement<Self::Backend>>;
} }
pub struct Connection<DB>(Arc<SharedConnection<DB>>) pub struct Connection<DB>(Arc<SharedConnection<DB>>)
@ -128,9 +133,12 @@ where
} }
/// Prepares a statement. /// Prepares a statement.
pub async fn prepare(&self, name: &str, body: &str) -> crate::Result<PreparedStatement> { ///
/// UNSTABLE: for use by sqlx-macros only
#[doc(hidden)]
pub async fn prepare(&self, body: &str) -> crate::Result<PreparedStatement<DB>> {
let mut live = self.0.acquire().await; let mut live = self.0.acquire().await;
let ret = live.raw.prepare(name, body).await?; let ret = live.raw.prepare_describe(body).await?;
self.0.release(live); self.0.release(live);
Ok(ret) Ok(ret)
} }

View file

@ -54,7 +54,7 @@ impl Display for Error {
match self { match self {
Error::Io(error) => write!(f, "{}", error), Error::Io(error) => write!(f, "{}", error),
Error::Database(error) => f.write_str(error.message()), Error::Database(error) => Display::fmt(error, f),
Error::NotFound => f.write_str("found no rows when we expected at least one"), Error::NotFound => f.write_str("found no rows when we expected at least one"),
@ -85,8 +85,6 @@ where
} }
/// An error that was returned by the database backend. /// An error that was returned by the database backend.
pub trait DatabaseError: Debug + Send + Sync { pub trait DatabaseError: Display + Debug + Send + Sync {
fn message(&self) -> &str; fn message(&self) -> &str;
// TODO: Expose more error properties
} }

View file

@ -33,12 +33,11 @@ mod compiled;
#[doc(inline)] #[doc(inline)]
pub use self::{ pub use self::{
backend::Backend, backend::Backend,
connection::Connection,
compiled::CompiledSql, compiled::CompiledSql,
connection::Connection,
deserialize::FromSql, deserialize::FromSql,
error::{Error, Result}, error::{Error, Result},
executor::Executor, executor::Executor,
prepared::{PreparedStatement, Field},
pool::Pool, pool::Pool,
row::{FromSqlRow, Row}, row::{FromSqlRow, Row},
serialize::ToSql, serialize::ToSql,
@ -46,6 +45,9 @@ pub use self::{
types::HasSqlType, types::HasSqlType,
}; };
#[doc(hidden)]
pub use types::HasTypeMetadata;
#[cfg(feature = "mariadb")] #[cfg(feature = "mariadb")]
pub mod mariadb; pub mod mariadb;

View file

@ -7,6 +7,8 @@ impl Backend for MariaDb {
type QueryParameters = super::MariaDbQueryParameters; type QueryParameters = super::MariaDbQueryParameters;
type RawConnection = super::MariaDbRawConnection; type RawConnection = super::MariaDbRawConnection;
type Row = super::MariaDbRow; type Row = super::MariaDbRow;
type StatementIdent = u32;
type TableIdent = String;
} }
impl_from_sql_row_tuples_for_backend!(MariaDb); impl_from_sql_row_tuples_for_backend!(MariaDb);

View file

@ -11,11 +11,13 @@ use crate::{
}, },
MariaDb, MariaDbQueryParameters, MariaDbRow, MariaDb, MariaDbQueryParameters, MariaDbRow,
}, },
Backend, Error, Result, prepared::{Column, PreparedStatement},
Backend, Error, PreparedStatement, Result,
}; };
use async_trait::async_trait; use async_trait::async_trait;
use byteorder::{ByteOrder, LittleEndian}; use byteorder::{ByteOrder, LittleEndian};
use futures_core::{future::BoxFuture, stream::BoxStream}; use futures_core::{future::BoxFuture, stream::BoxStream};
use futures_util::stream::{self, StreamExt};
use std::{ use std::{
future::Future, future::Future,
io, io,
@ -173,51 +175,33 @@ impl MariaDbRawConnection {
}) })
} }
// This should not be used by the user. It's mean for `RawConnection` impl async fn check_eof(&mut self) -> Result<()> {
// This assumes the buffer has been set and all it needs is a flush if !self
async fn exec_prepare(&mut self) -> Result<u32> {
self.stream.flush().await?;
// COM_STMT_PREPARE returns COM_STMT_PREPARE_OK (0x00) or ERR (0xFF)
let mut packet = self.receive().await?;
let ok = match packet[0] {
0xFF => {
let err = ErrPacket::decode(packet)?;
// TODO: Bubble as Error::Database
// panic!("received db err = {:?}", err);
return Err(
io::Error::new(io::ErrorKind::InvalidInput, format!("{:?}", err)).into(),
);
}
_ => ComStmtPrepareOk::decode(packet)?,
};
// Skip decoding Column Definition packets for the result from a prepare statement
for _ in 0..ok.columns {
let _ = self.receive().await?;
}
if ok.columns > 0
&& !self
.capabilities .capabilities
.contains(Capabilities::CLIENT_DEPRECATE_EOF) .contains(Capabilities::CLIENT_DEPRECATE_EOF)
{ {
// TODO: Should we do something with the warning indicators here? let _ = EofPacket::decode(self.receive().await?)?;
let _eof = EofPacket::decode(self.receive().await?)?;
} }
Ok(ok.statement_id) Ok(())
} }
async fn prepare<'c>(&'c mut self, statement: &'c str) -> Result<u32> { async fn send_prepare<'c>(&'c mut self, statement: &'c str) -> Result<ComStmtPrepareOk> {
self.stream.flush().await?; self.stream.flush().await?;
self.start_sequence(); self.start_sequence();
self.write(ComStmtPrepare { statement }); self.write(ComStmtPrepare { statement });
self.exec_prepare().await self.stream.flush().await?;
// COM_STMT_PREPARE returns COM_STMT_PREPARE_OK (0x00) or ERR (0xFF)
let packet = self.receive().await?;
if packet[0] == 0xFF {
return Err(ErrPacket::decode(packet)?.into());
}
ComStmtPrepareOk::decode(packet).map_err(Into::into)
} }
async fn execute(&mut self, statement_id: u32, params: MariaDbQueryParameters) -> Result<u64> { async fn execute(&mut self, statement_id: u32, params: MariaDbQueryParameters) -> Result<u64> {
@ -323,11 +307,9 @@ impl RawConnection for MariaDbRawConnection {
async fn execute(&mut self, query: &str, params: MariaDbQueryParameters) -> crate::Result<u64> { async fn execute(&mut self, query: &str, params: MariaDbQueryParameters) -> crate::Result<u64> {
// Write prepare statement to buffer // Write prepare statement to buffer
self.start_sequence(); self.start_sequence();
self.write(ComStmtPrepare { statement: query }); let prepare_ok = self.send_prepare(query).await?;
let statement_id = self.exec_prepare().await?; let affected = self.execute(prepare_ok.statement_id, params).await?;
let affected = self.execute(statement_id, params).await?;
Ok(affected) Ok(affected)
} }
@ -347,6 +329,56 @@ impl RawConnection for MariaDbRawConnection {
) -> crate::Result<Option<<Self::Backend as Backend>::Row>> { ) -> crate::Result<Option<<Self::Backend as Backend>::Row>> {
unimplemented!(); unimplemented!();
} }
async fn prepare(&mut self, query: &str) -> crate::Result<u32> {
let prepare_ok = self.send_prepare(query).await?;
for _ in 0..prepare_ok.params {
let _ = self.receive().await?;
}
self.check_eof().await?;
for _ in 0..prepare_ok.columns {
let _ = self.receive().await?;
}
self.check_eof().await?;
Ok(prepare_ok.statement_id)
}
async fn prepare_describe(&mut self, query: &str) -> crate::Result<PreparedStatement<MariaDb>> {
let prepare_ok = self.send_prepare(query).await?;
let mut param_types = Vec::with_capacity(prepare_ok.params as usize);
for _ in 0..prepare_ok.params {
let param = ColumnDefinitionPacket::decode(self.receive().await?)?;
param_types.push(param.field_type.0);
}
self.check_eof().await?;
let mut columns = Vec::with_capacity(prepare_ok.columns as usize);
for _ in 0..prepare_ok.columns {
let column = ColumnDefinitionPacket::decode(self.receive().await?)?;
columns.push(Column {
name: column.column_alias.or(column.column),
table_id: column.table_alias.or(column.table),
type_id: column.field_type.0,
})
}
self.check_eof().await?;
Ok(PreparedStatement {
identifier: prepare_ok.statement_id,
param_types,
columns,
})
}
} }
#[cfg(test)] #[cfg(test)]

21
src/mariadb/error.rs Normal file
View file

@ -0,0 +1,21 @@
use crate::{error::DatabaseError, mariadb::protocol::ErrorCode};
use std::fmt;
#[derive(Debug)]
pub struct Error {
pub code: ErrorCode,
pub message: Box<str>,
}
impl DatabaseError for Error {
fn message(&self) -> &str {
&self.message
}
}
impl fmt::Display for ErrorCode {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "MariaDB returned an error: {}",)
}
}

View file

@ -1,5 +1,6 @@
mod backend; mod backend;
mod connection; mod connection;
mod error;
mod establish; mod establish;
mod io; mod io;
mod protocol; mod protocol;

View file

@ -1,10 +1,40 @@
use std::fmt;
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct ErrorCode(pub(crate) u16); pub struct ErrorCode(pub(crate) u16);
// TODO: It would be nice to figure out a clean way to go from 1152 to "ER_ABORTING_CONNECTION (1152)" in Debug. use crate::error::DatabaseError;
use bitflags::_core::fmt::{Error, Formatter};
macro_rules! error_code_impl {
($(const $name:ident: ErrorCode = ErrorCode($code:expr));*;) => {
impl ErrorCode {
$(const $name: ErrorCode = ErrorCode($code);)*
pub fn code_name(&self) -> &'static str {
match self.0 {
$($code => $name,)*
_ => "<unknown error>"
}
}
}
}
}
impl fmt::Debug for ErrorCode {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "ErrorCode({} [()])",)
}
}
impl fmt::Display for ErrorCode {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{} ({})", self.code_name(), self.0)
}
}
// Values from https://mariadb.com/kb/en/library/mariadb-error-codes/ // Values from https://mariadb.com/kb/en/library/mariadb-error-codes/
impl ErrorCode { error_code_impl! {
const ER_ABORTING_CONNECTION: ErrorCode = ErrorCode(1152); const ER_ABORTING_CONNECTION: ErrorCode = ErrorCode(1152);
const ER_ACCESS_DENIED_CHANGE_USER_ERROR: ErrorCode = ErrorCode(1873); const ER_ACCESS_DENIED_CHANGE_USER_ERROR: ErrorCode = ErrorCode(1873);
const ER_ACCESS_DENIED_ERROR: ErrorCode = ErrorCode(1045); const ER_ACCESS_DENIED_ERROR: ErrorCode = ErrorCode(1045);

View file

@ -1,6 +1,6 @@
use crate::{ use crate::{
io::Buf, io::Buf,
mariadb::{io::BufExt, protocol::ErrorCode}, mariadb::{error::Error, io::BufExt, protocol::ErrorCode},
}; };
use byteorder::LittleEndian; use byteorder::LittleEndian;
use std::io; use std::io;
@ -66,6 +66,15 @@ impl ErrPacket {
}) })
} }
} }
pub fn expect_error<T>(self) -> crate::Result<T> {
match self {
ErrPacket::Progress { .. } => {
Err(format!("expected ErrPacket::Err, got {:?}", self).into())
}
ErrPacket::Error { code, message, .. } => Err(Error { code, message }.into()),
}
}
} }
#[cfg(test)] #[cfg(test)]

View file

@ -1,5 +1,8 @@
use super::protocol::{FieldType, ParameterFlag}; use super::protocol::{FieldType, ParameterFlag};
use crate::{mariadb::MariaDb, types::TypeMetadata}; use crate::{
mariadb::MariaDb,
types::{HasTypeMetadata, TypeMetadata},
};
pub mod boolean; pub mod boolean;
pub mod character; pub mod character;
@ -11,6 +14,43 @@ pub struct MariaDbTypeMetadata {
pub param_flag: ParameterFlag, pub param_flag: ParameterFlag,
} }
impl TypeMetadata for MariaDb { impl HasTypeMetadata for MariaDb {
type TypeMetadata = MariaDbTypeMetadata; type TypeMetadata = MariaDbTypeMetadata;
type TypeId = u8;
fn param_type_for_id(id: &Self::TypeId) -> Option<&'static str> {
Some(match FieldType(*id) {
FieldType::MYSQL_TYPE_TINY => "i8",
FieldType::MYSQL_TYPE_SHORT => "i16",
FieldType::MYSQL_TYPE_LONG => "i32",
FieldType::MYSQL_TYPE_LONGLONG => "i64",
FieldType::MYSQL_TYPE_VAR_STRING => "&str",
FieldType::MYSQL_TYPE_FLOAT => "f32",
FieldType::MYSQL_TYPE_DOUBLE => "f64",
FieldType::MYSQL_TYPE_BLOB => "&[u8]",
_ => return None
})
}
fn return_type_for_id(id: &Self::TypeId) -> Option<&'static str> {
Some(match FieldType(*id) {
FieldType::MYSQL_TYPE_TINY => "i8",
FieldType::MYSQL_TYPE_SHORT => "i16",
FieldType::MYSQL_TYPE_LONG => "i32",
FieldType::MYSQL_TYPE_LONGLONG => "i64",
FieldType::MYSQL_TYPE_VAR_STRING => "String",
FieldType::MYSQL_TYPE_FLOAT => "f32",
FieldType::MYSQL_TYPE_DOUBLE => "f64",
FieldType::MYSQL_TYPE_BLOB => "Vec<u8>",
_ => return None
})
}
}
impl TypeMetadata for MariaDbTypeMetadata {
type TypeId = u8;
fn type_id(&self) -> &Self::TypeId {
&self.field_type.0
}
} }

View file

@ -7,6 +7,8 @@ impl Backend for Postgres {
type QueryParameters = super::PostgresQueryParameters; type QueryParameters = super::PostgresQueryParameters;
type RawConnection = super::PostgresRawConnection; type RawConnection = super::PostgresRawConnection;
type Row = super::PostgresRow; type Row = super::PostgresRow;
type StatementIdent = String;
type TableIdent = u32;
} }
impl_from_sql_row_tuples_for_backend!(Postgres); impl_from_sql_row_tuples_for_backend!(Postgres);

View file

@ -1,10 +1,19 @@
use super::{Postgres, PostgresQueryParameters, PostgresRawConnection, PostgresRow}; use super::{Postgres, PostgresQueryParameters, PostgresRawConnection, PostgresRow};
use crate::{connection::RawConnection, postgres::raw::Step, url::Url, Error}; use crate::{
use crate::query::QueryParameters; connection::RawConnection,
postgres::{error::ProtocolError, raw::Step},
prepared::{Column, PreparedStatement},
query::QueryParameters,
url::Url,
Error,
};
use async_trait::async_trait; use async_trait::async_trait;
use futures_core::stream::BoxStream; use futures_core::stream::BoxStream;
use crate::prepared::{PreparedStatement, Field};
use crate::postgres::error::ProtocolError; use std::sync::atomic::{AtomicU64, Ordering};
use crate::postgres::{protocol::Message, PostgresDatabaseError};
use std::hash::Hasher;
#[async_trait] #[async_trait]
impl RawConnection for PostgresRawConnection { impl RawConnection for PostgresRawConnection {
@ -96,45 +105,80 @@ impl RawConnection for PostgresRawConnection {
Ok(row) Ok(row)
} }
async fn prepare(&mut self, name: &str, body: &str) -> crate::Result<PreparedStatement> { async fn prepare(&mut self, body: &str) -> crate::Result<String> {
self.parse(name, body, &PostgresQueryParameters::new()); let name = gen_statement_name(body);
self.describe(name); self.parse(&name, body, &PostgresQueryParameters::new());
match self.receive().await? {
Some(Message::Response(response)) => Err(PostgresDatabaseError(response).into()),
Some(Message::ParseComplete) => Ok(name),
Some(message) => {
Err(ProtocolError(format!("unexpected message: {:?}", message)).into())
}
None => Err(ProtocolError("expected ParseComplete or ErrorResponse").into()),
}
}
async fn prepare_describe(&mut self, body: &str) -> crate::Result<PreparedStatement<Postgres>> {
let name = gen_statement_name(body);
self.parse(&name, body, &PostgresQueryParameters::new());
self.describe(&name);
self.sync().await?; self.sync().await?;
let param_desc= loop { let param_desc = loop {
let step = self.step().await? let step = self
.step()
.await?
.ok_or(ProtocolError("did not receive ParameterDescription")); .ok_or(ProtocolError("did not receive ParameterDescription"));
if let Step::ParamDesc(desc) = dbg!(step)? if let Step::ParamDesc(desc) = step? {
{
break desc; break desc;
} }
}; };
let row_desc = loop { let row_desc = loop {
let step = self.step().await? let step = self
.step()
.await?
.ok_or(ProtocolError("did not receive RowDescription")); .ok_or(ProtocolError("did not receive RowDescription"));
if let Step::RowDesc(desc) = dbg!(step)? if let Step::RowDesc(desc) = step? {
{
break desc; break desc;
} }
}; };
Ok(PreparedStatement { Ok(PreparedStatement {
name: name.into(), identifier: name.into(),
param_types: param_desc.ids, param_types: param_desc.ids.into_vec(),
fields: row_desc.fields.into_vec().into_iter() columns: row_desc
.map(|field| Field { .fields
name: field.name, .into_vec()
table_id: field.table_id, .into_iter()
type_id: field.type_id .map(|field| Column {
name: Some(field.name),
table_id: Some(field.table_id),
type_id: field.type_id,
}) })
.collect(), .collect(),
}) })
} }
} }
static STATEMENT_COUNT: AtomicU64 = AtomicU64::new(0);
fn gen_statement_name(query: &str) -> String {
// hasher with no external dependencies
use std::collections::hash_map::DefaultHasher;
let mut hasher = DefaultHasher::new();
// including a global counter should help prevent collision
// with queries with the same content
hasher.write_u64(STATEMENT_COUNT.fetch_add(1, Ordering::SeqCst));
hasher.write(query.as_bytes());
format!("sqlx_stmt_{:x}", hasher.finish())
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View file

@ -1,7 +1,10 @@
use super::protocol::Response; use super::protocol::Response;
use crate::error::DatabaseError; use crate::error::DatabaseError;
use std::borrow::Cow; use bitflags::_core::fmt::{Error, Formatter};
use std::fmt::Debug; use std::{
borrow::Cow,
fmt::{self, Debug, Display},
};
#[derive(Debug)] #[derive(Debug)]
pub struct PostgresDatabaseError(pub(super) Box<Response>); pub struct PostgresDatabaseError(pub(super) Box<Response>);
@ -15,8 +18,20 @@ impl DatabaseError for PostgresDatabaseError {
} }
} }
impl Display for PostgresDatabaseError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.pad(self.message())
}
}
impl<T: AsRef<str> + Debug + Send + Sync> DatabaseError for ProtocolError<T> { impl<T: AsRef<str> + Debug + Send + Sync> DatabaseError for ProtocolError<T> {
fn message(&self) -> &str { fn message(&self) -> &str {
self.0.as_ref() self.0.as_ref()
} }
} }
impl<T: AsRef<str>> Display for ProtocolError<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.pad(self.0.as_ref())
}
}

View file

@ -21,5 +21,5 @@ pub enum Message {
NoData, NoData,
PortalSuspended, PortalSuspended,
ParameterDescription(Box<ParameterDescription>), ParameterDescription(Box<ParameterDescription>),
RowDescription(Box<RowDescription>) RowDescription(Box<RowDescription>),
} }

View file

@ -31,10 +31,22 @@ mod terminate;
// TODO: mod ssl_request; // TODO: mod ssl_request;
pub use self::{ pub use self::{
bind::Bind, cancel_request::CancelRequest, close::Close, copy_data::CopyData, bind::Bind,
copy_done::CopyDone, copy_fail::CopyFail, describe::Describe, describe::DescribeKind, encode::Encode, execute::Execute, cancel_request::CancelRequest,
flush::Flush, parse::Parse, password_message::PasswordMessage, query::Query, close::Close,
startup_message::StartupMessage, sync::Sync, terminate::Terminate, copy_data::CopyData,
copy_done::CopyDone,
copy_fail::CopyFail,
describe::{Describe, DescribeKind},
encode::Encode,
execute::Execute,
flush::Flush,
parse::Parse,
password_message::PasswordMessage,
query::Query,
startup_message::StartupMessage,
sync::Sync,
terminate::Terminate,
}; };
mod authentication; mod authentication;
@ -54,10 +66,17 @@ mod row_description;
mod message; mod message;
pub use self::{ pub use self::{
authentication::Authentication, backend_key_data::BackendKeyData, authentication::Authentication,
command_complete::CommandComplete, data_row::DataRow, decode::Decode, message::Message, backend_key_data::BackendKeyData,
notification_response::NotificationResponse, parameter_description::ParameterDescription, command_complete::CommandComplete,
parameter_status::ParameterStatus, ready_for_query::ReadyForQuery, response::Response, data_row::DataRow,
decode::Decode,
message::Message,
notification_response::NotificationResponse,
parameter_description::ParameterDescription,
parameter_status::ParameterStatus,
ready_for_query::ReadyForQuery,
response::Response,
row_description::{RowDescription, RowField}, row_description::{RowDescription, RowField},
}; };

View file

@ -1,8 +1,7 @@
use super::Decode; use super::Decode;
use crate::io::Buf; use crate::io::Buf;
use byteorder::NetworkEndian; use byteorder::NetworkEndian;
use std::io; use std::{io, io::BufRead};
use std::io::BufRead;
#[derive(Debug)] #[derive(Debug)]
pub struct RowDescription { pub struct RowDescription {
@ -17,7 +16,7 @@ pub struct RowField {
pub type_id: u32, pub type_id: u32,
pub type_size: i16, pub type_size: i16,
pub type_mod: i32, pub type_mod: i32,
pub format_code: i16 pub format_code: i16,
} }
impl Decode for RowDescription { impl Decode for RowDescription {
@ -26,7 +25,7 @@ impl Decode for RowDescription {
let mut fields = Vec::with_capacity(cnt); let mut fields = Vec::with_capacity(cnt);
for _ in 0..cnt { for _ in 0..cnt {
fields.push(dbg!(RowField { fields.push(RowField {
name: super::read_string(&mut buf)?, name: super::read_string(&mut buf)?,
table_id: buf.get_u32::<NetworkEndian>()?, table_id: buf.get_u32::<NetworkEndian>()?,
attr_num: buf.get_i16::<NetworkEndian>()?, attr_num: buf.get_i16::<NetworkEndian>()?,
@ -34,7 +33,7 @@ impl Decode for RowDescription {
type_size: buf.get_i16::<NetworkEndian>()?, type_size: buf.get_i16::<NetworkEndian>()?,
type_mod: buf.get_i32::<NetworkEndian>()?, type_mod: buf.get_i32::<NetworkEndian>()?,
format_code: buf.get_i16::<NetworkEndian>()?, format_code: buf.get_i16::<NetworkEndian>()?,
})); });
} }
Ok(Self { Ok(Self {

View file

@ -151,8 +151,9 @@ impl PostgresRawConnection {
pub(super) fn describe(&mut self, statement: &str) { pub(super) fn describe(&mut self, statement: &str) {
protocol::Describe { protocol::Describe {
kind: protocol::DescribeKind::PreparedStatement, kind: protocol::DescribeKind::PreparedStatement,
name: statement name: statement,
}.encode(self.stream.buffer_mut()) }
.encode(self.stream.buffer_mut())
} }
pub(super) fn bind(&mut self, portal: &str, statement: &str, params: &PostgresQueryParameters) { pub(super) fn bind(&mut self, portal: &str, statement: &str, params: &PostgresQueryParameters) {
@ -198,15 +199,15 @@ impl PostgresRawConnection {
Message::ReadyForQuery(_) => { Message::ReadyForQuery(_) => {
return Ok(None); return Ok(None);
}, }
Message::ParameterDescription(desc) => { Message::ParameterDescription(desc) => {
return Ok(Some(Step::ParamDesc(desc))); return Ok(Some(Step::ParamDesc(desc)));
}, }
Message::RowDescription(desc) => { Message::RowDescription(desc) => {
return Ok(Some(Step::RowDesc(desc))); return Ok(Some(Step::RowDesc(desc)));
}, }
message => { message => {
return Err(io::Error::new( return Err(io::Error::new(
@ -260,9 +261,7 @@ impl PostgresRawConnection {
b't' => Message::ParameterDescription(Box::new( b't' => Message::ParameterDescription(Box::new(
protocol::ParameterDescription::decode(body)?, protocol::ParameterDescription::decode(body)?,
)), )),
b'T' => Message::RowDescription(Box::new( b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body)?)),
protocol::RowDescription::decode(body)?
)),
id => { id => {
return Err(io::Error::new( return Err(io::Error::new(

View file

@ -0,0 +1,41 @@
use crate::{
postgres::types::{PostgresTypeFormat, PostgresTypeMetadata},
serialize::IsNull,
types::TypeMetadata,
FromSql, HasSqlType, Postgres, ToSql,
};
impl HasSqlType<[u8]> for Postgres {
fn metadata() -> Self::TypeMetadata {
PostgresTypeMetadata {
format: PostgresTypeFormat::Binary,
oid: 17,
array_oid: 1001,
}
}
}
impl HasSqlType<Vec<u8>> for Postgres {
fn metadata() -> Self::TypeMetadata {
<Postgres as HasSqlType<[u8]>>::metadata()
}
}
impl ToSql<Postgres> for [u8] {
fn to_sql(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend_from_slice(self);
IsNull::No
}
}
impl ToSql<Postgres> for Vec<u8> {
fn to_sql(&self, buf: &mut Vec<u8>) -> IsNull {
<[u8] as ToSql<Postgres>>::to_sql(self, buf)
}
}
impl FromSql<Postgres> for Vec<u8> {
fn from_sql(raw: Option<&[u8]>) -> Self {
raw.unwrap().into()
}
}

View file

@ -28,9 +28,12 @@
//! | `Uuid` (`uuid` feature) | UUID | //! | `Uuid` (`uuid` feature) | UUID |
use super::Postgres; use super::Postgres;
use crate::types::TypeMetadata; use crate::{
use crate::HasSqlType; types::{HasTypeMetadata, TypeMetadata},
HasSqlType,
};
mod binary;
mod boolean; mod boolean;
mod character; mod character;
mod numeric; mod numeric;
@ -54,6 +57,59 @@ pub struct PostgresTypeMetadata {
pub array_oid: u32, pub array_oid: u32,
} }
impl TypeMetadata for Postgres { impl HasTypeMetadata for Postgres {
type TypeId = u32;
type TypeMetadata = PostgresTypeMetadata; type TypeMetadata = PostgresTypeMetadata;
fn param_type_for_id(id: &Self::TypeId) -> Option<&'static str> {
Some(match id {
16 => "bool",
1000 => "&[bool]",
25 => "&str",
1009 => "&[&str]",
21 => "i16",
1005 => "&[i16]",
23 => "i32",
1007 => "&[i32]",
20 => "i64",
1016 => "&[i64]",
700 => "f32",
1021 => "&[f32]",
701 => "f64",
1022 => "&[f64]",
2950 => "sqlx::Uuid",
2951 => "&[sqlx::Uuid]",
_ => return None,
})
}
fn return_type_for_id(id: &Self::TypeId) -> Option<&'static str> {
Some(match id {
16 => "bool",
1000 => "Vec<bool>",
25 => "String",
1009 => "Vec<String>",
21 => "i16",
1005 => "Vec<i16>",
23 => "i32",
1007 => "Vec<i32>",
20 => "i64",
1016 => "Vec<i64>",
700 => "f32",
1021 => "Vec<f32>",
701 => "f64",
1022 => "Vec<f64>",
2950 => "sqlx::Uuid",
2951 => "Vec<sqlx::Uuid>",
_ => return None,
})
}
}
impl TypeMetadata for PostgresTypeMetadata {
type TypeId = u32;
fn type_id(&self) -> &u32 {
&self.oid
}
} }

View file

@ -1,13 +1,25 @@
#[derive(Debug)] use crate::{query::QueryParameters, Backend, Error, Executor, FromSqlRow, HasSqlType, ToSql};
pub struct PreparedStatement {
pub name: String, use futures_core::{future::BoxFuture, stream::BoxStream};
pub param_types: Box<[u32]>, use std::marker::PhantomData;
pub fields: Vec<Field>,
use crate::types::{HasTypeMetadata, TypeMetadata};
use std::fmt::{self, Debug};
/// A prepared statement.
pub struct PreparedStatement<DB: Backend> {
///
pub identifier: <DB as Backend>::StatementIdent,
/// The expected type IDs of bind parameters.
pub param_types: Vec<<DB as HasTypeMetadata>::TypeId>,
///
pub columns: Vec<Column<DB>>,
} }
#[derive(Debug)] pub struct Column<DB: Backend> {
pub struct Field { pub name: Option<String>,
pub name: String, pub table_id: Option<<DB as Backend>::TableIdent>,
pub table_id: u32, /// The type ID of this result column.
pub type_id: u32, pub type_id: <DB as HasTypeMetadata>::TypeId,
} }

View file

@ -46,7 +46,7 @@ where
} }
impl<T: ?Sized, DB> ToSql<DB> for &'_ T impl<T: ?Sized, DB> ToSql<DB> for &'_ T
where where
DB: Backend + HasSqlType<T>, DB: Backend + HasSqlType<T>,
T: ToSql<DB>, T: ToSql<DB>,
{ {

View file

@ -1,18 +1,39 @@
/// Information about how a backend stores metadata about /// Information about how a backend stores metadata about
/// given SQL types. /// given SQL types.
pub trait TypeMetadata { pub trait HasTypeMetadata {
/// The actual type used to represent metadata. /// The actual type used to represent metadata.
type TypeMetadata; type TypeMetadata: TypeMetadata<TypeId = Self::TypeId>;
/// The Rust type of the type ID for the backend.
type TypeId: Eq;
/// UNSTABLE: for internal use only
#[doc(hidden)]
fn param_type_for_id(id: &Self::TypeId) -> Option<&'static str>;
/// UNSTABLE: for internal use only
#[doc(hidden)]
fn return_type_for_id(id: &Self::TypeId) -> Option<&'static str>;
}
pub trait TypeMetadata {
type TypeId: Eq;
fn type_id(&self) -> &Self::TypeId;
fn type_id_eq(&self, id: &Self::TypeId) -> bool {
self.type_id() == id
}
} }
/// Indicates that a SQL type exists for a backend and defines /// Indicates that a SQL type exists for a backend and defines
/// useful metadata for the backend. /// useful metadata for the backend.
pub trait HasSqlType<A: ?Sized>: TypeMetadata { pub trait HasSqlType<A: ?Sized>: HasTypeMetadata {
fn metadata() -> Self::TypeMetadata; fn metadata() -> Self::TypeMetadata;
} }
impl<A: ?Sized, DB> HasSqlType<&'_ A> for DB impl<A: ?Sized, DB> HasSqlType<&'_ A> for DB
where DB: HasSqlType<A> where
DB: HasSqlType<A>,
{ {
fn metadata() -> Self::TypeMetadata { fn metadata() -> Self::TypeMetadata {
<DB as HasSqlType<A>>::metadata() <DB as HasSqlType<A>>::metadata()
@ -20,7 +41,8 @@ impl<A: ?Sized, DB> HasSqlType<&'_ A> for DB
} }
impl<A, DB> HasSqlType<Option<A>> for DB impl<A, DB> HasSqlType<Option<A>> for DB
where DB: HasSqlType<A> where
DB: HasSqlType<A>,
{ {
fn metadata() -> Self::TypeMetadata { fn metadata() -> Self::TypeMetadata {
<DB as HasSqlType<A>>::metadata() <DB as HasSqlType<A>>::metadata()

View file

@ -2,7 +2,9 @@
#[tokio::test] #[tokio::test]
async fn test_sqlx_macro() -> sqlx::Result<()> { async fn test_sqlx_macro() -> sqlx::Result<()> {
let conn = sqlx::Connection::<sqlx::Postgres>::establish("postgres://postgres@127.0.0.1/sqlx_test").await?; let conn =
sqlx::Connection::<sqlx::Postgres>::establish("postgres://postgres@127.0.0.1/sqlx_test")
.await?;
let uuid: sqlx::Uuid = "256ba9c8-0048-11ea-b0f0-8f04859d047e".parse().unwrap(); let uuid: sqlx::Uuid = "256ba9c8-0048-11ea-b0f0-8f04859d047e".parse().unwrap();
let accounts = sqlx_macros::sql!("SELECT * from accounts where id = $1", 5i64) let accounts = sqlx_macros::sql!("SELECT * from accounts where id = $1", 5i64)
.fetch_one(&conn) .fetch_one(&conn)