unified prepared statement interface

This commit is contained in:
Austin Bonander 2019-11-12 20:57:55 -08:00 committed by Austin Bonander
parent c993e4eee0
commit bbdc03c576
29 changed files with 654 additions and 253 deletions

View file

@ -8,14 +8,14 @@ edition = "2018"
proc-macro = true
[dependencies]
dotenv = "0.15.0"
futures-preview = "0.3.0-alpha.18"
hex = "0.4.0"
proc-macro2 = "1.0.6"
sqlx = { path = "../", features = ["postgres"] }
syn = "1.0"
quote = "1.0"
sha2 = "0.8.0"
tokio = { version = "0.2.0-alpha.4", default-features = false, features = [ "tcp" ] }
url = "2.1.0"
[features]
postgres = ["sqlx/postgres"]

View file

@ -6,47 +6,52 @@ use proc_macro2::Span;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::{parse_macro_input, Expr, ExprLit, Lit, LitStr, Token, Type};
use syn::spanned::Spanned;
use syn::punctuated::Punctuated;
use syn::parse::{self, Parse, ParseStream};
use syn::{
parse::{self, Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
spanned::Spanned,
Expr, ExprLit, Lit, LitStr, Token, Type,
};
use sha2::{Sha256, Digest};
use sqlx::Postgres;
use sqlx::{HasTypeMetadata, Postgres};
use tokio::runtime::Runtime;
use std::error::Error as _;
use std::{error::Error as _, fmt::Display, str::FromStr};
use url::Url;
type Error = Box<dyn std::error::Error>;
type Result<T> = std::result::Result<T, Error>;
mod postgres;
struct MacroInput {
sql: String,
sql_span: Span,
args: Vec<Expr>
args: Vec<Expr>,
}
impl Parse for MacroInput {
fn parse(input: ParseStream) -> parse::Result<Self> {
let mut args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?
.into_iter();
let mut args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?.into_iter();
let sql = match args.next() {
Some(Expr::Lit(ExprLit { lit: Lit::Str(sql), .. })) => sql,
Some(other_expr) => return Err(parse::Error::new_spanned(other_expr, "expected string literal")),
Some(Expr::Lit(ExprLit {
lit: Lit::Str(sql), ..
})) => sql,
Some(other_expr) => {
return Err(parse::Error::new_spanned(
other_expr,
"expected string literal",
));
}
None => return Err(input.error("expected SQL string literal")),
};
Ok(
MacroInput {
sql: sql.value(),
sql_span: sql.span(),
args: args.collect(),
}
)
Ok(MacroInput {
sql: sql.value(),
sql_span: sql.span(),
args: args.collect(),
})
}
}
@ -56,52 +61,109 @@ pub fn sql(input: TokenStream) -> TokenStream {
eprintln!("expanding macro");
match Runtime::new().map_err(Error::from).and_then(|runtime| runtime.block_on(process_sql(input))) {
match Runtime::new()
.map_err(Error::from)
.and_then(|runtime| runtime.block_on(process_sql(input)))
{
Ok(ts) => {
eprintln!("emitting output: {}", ts);
ts
},
}
Err(e) => {
if let Some(parse_err) = e.downcast_ref::<parse::Error>() {
return parse_err.to_compile_error().into();
}
let msg = e.to_string();
quote! ( compile_error!(#msg) ).into()
quote!(compile_error!(#msg)).into()
}
}
}
async fn process_sql(input: MacroInput) -> Result<TokenStream> {
let hash = dbg!(hex::encode(&Sha256::digest(input.sql.as_bytes())));
let db_url = Url::parse(&dotenv::var("DB_URL")?)?;
let conn = sqlx::Connection::<Postgres>::establish("postgresql://postgres@127.0.0.1/sqlx_test")
.await
.map_err(|e| format!("failed to connect to database: {}", e))?;
match db_url.scheme() {
#[cfg(feature = "postgres")]
"postgresql" => {
process_sql_with(
input,
sqlx::Connection::<sqlx::Postgres>::establish(db_url.as_str())
.await
.map_err(|e| format!("failed to connect to database: {}", e))?,
)
.await
}
#[cfg(feature = "mysql")]
"mysql" => {
process_sql_with(
input,
sqlx::Connection::<sqlx::MariaDb>::establish(
"postgresql://postgres@127.0.0.1/sqlx_test",
)
.await
.map_err(|e| format!("failed to connect to database: {}", e))?,
)
.await
}
scheme => Err(format!("unexpected scheme {:?} in DB_URL {}", scheme, db_url).into()),
}
}
async fn process_sql_with<DB: sqlx::Backend>(
input: MacroInput,
conn: sqlx::Connection<DB>,
) -> Result<TokenStream>
where
<DB as HasTypeMetadata>::TypeId: Display,
{
eprintln!("connection established");
let prepared = conn.prepare(&hash, &input.sql)
let prepared = conn
.prepare(&input.sql)
.await
.map_err(|e| parse::Error::new(input.sql_span, e))?;
if input.args.len() != prepared.param_types.len() {
return Err(parse::Error::new(
Span::call_site(),
format!("expected {} parameters, got {}", prepared.param_types.len(), input.args.len())
).into());
format!(
"expected {} parameters, got {}",
prepared.param_types.len(),
input.args.len()
),
)
.into());
}
let param_types = prepared.param_types.iter().zip(&*input.args).map(|(type_, expr)| {
get_type_override(expr)
.or_else(|| postgres::map_param_type_oid(*type_))
.ok_or_else(|| format!("unknown type OID: {}", type_).into())
})
let param_types = prepared
.param_types
.iter()
.zip(&*input.args)
.map(|(type_, expr)| {
get_type_override(expr)
.or_else(|| {
Some(
<DB as sqlx::HasTypeMetadata>::param_type_for_id(type_)?
.parse::<proc_macro2::TokenStream>()
.unwrap(),
)
})
.ok_or_else(|| format!("unknown type ID: {}", type_).into())
})
.collect::<Result<Vec<_>>>()?;
let output_types = prepared.fields.iter().map(|field| {
postgres::map_output_type_oid(field.type_id)
})
let output_types = prepared
.columns
.iter()
.map(|column| {
Ok(
<DB as sqlx::HasTypeMetadata>::return_type_for_id(&column.type_id)
.ok_or_else(|| format!("unknown type ID: {}", &column.type_id))?
.parse::<proc_macro2::TokenStream>()
.unwrap(),
)
})
.collect::<Result<Vec<_>>>()?;
let params = input.args.iter();
@ -112,25 +174,23 @@ async fn process_sql(input: MacroInput) -> Result<TokenStream> {
let query = &input.sql;
Ok(
quote! {{
use sqlx::TyConsExt as _;
Ok(quote! {{
use sqlx::TyConsExt as _;
let params = (#(#params),*,);
let params = (#(#params),*,);
if false {
let _: (#(#param_types),*,) = (#(#params_ty_cons),*,);
}
if false {
let _: (#(#param_types),*,) = (#(#params_ty_cons),*,);
}
sqlx::CompiledSql::<_, (#(#output_types),*), sqlx::Postgres> {
query: #query,
params,
output: ::core::marker::PhantomData,
backend: ::core::marker::PhantomData,
}
}}
.into()
)
sqlx::CompiledSql::<_, (#(#output_types),*), sqlx::Postgres> {
query: #query,
params,
output: ::core::marker::PhantomData,
backend: ::core::marker::PhantomData,
}
}}
.into())
}
fn get_type_override(expr: &Expr) -> Option<proc_macro2::TokenStream> {

View file

@ -1,45 +0,0 @@
use proc_macro2::TokenStream;
pub fn map_param_type_oid(oid: u32) -> Option<TokenStream> {
Some(match oid {
16 => "bool",
1000 => "&[bool]",
25 => "&str",
1009 => "&[&str]",
21 => "i16",
1005 => "&[i16]",
23 => "i32",
1007 => "&[i32]",
20 => "i64",
1016 => "&[i64]",
700 => "f32",
1021 => "&[f32]",
701 => "f64",
1022 => "&[f64]",
2950 => "sqlx::Uuid",
2951 => "&[sqlx::Uuid]",
_ => return None
}.parse().unwrap())
}
pub fn map_output_type_oid(oid: u32) -> crate::Result<TokenStream> {
Ok(match oid {
16 => "bool",
1000 => "Vec<bool>",
25 => "String",
1009 => "Vec<String>",
21 => "i16",
1005 => "Vec<i16>",
23 => "i32",
1007 => "Vec<i32>",
20 => "i64",
1016 => "Vec<i64>",
700 => "f32",
1021 => "Vec<f32>",
701 => "f64",
1022 => "Vec<f64>",
2950 => "sqlx::Uuid",
2951 => "Vec<sqlx::Uuid>",
_ => return Err(format!("unknown type ID: {}", oid).into())
}.parse().unwrap())
}

View file

@ -1,9 +1,9 @@
use crate::{connection::RawConnection, query::QueryParameters, row::Row};
use crate::{connection::RawConnection, query::QueryParameters, row::Row, types::HasTypeMetadata};
/// A database backend.
///
/// This trait represents the concept of a backend (e.g. "MySQL" vs "SQLite").
pub trait Backend: Sized {
pub trait Backend: HasTypeMetadata + Sized {
/// The concrete `QueryParameters` implementation for this backend.
type QueryParameters: QueryParameters<Backend = Self>;
@ -13,4 +13,12 @@ pub trait Backend: Sized {
/// The concrete `Row` implementation for this backend. This type is returned
/// from methods in the `RawConnection`.
type Row: Row<Backend = Self>;
/// The identifier for prepared statements; in Postgres this is a string
/// and in MariaDB/MySQL this is an integer.
type StatementIdent;
/// The identifier for tables; in Postgres this is an `oid` while
/// in MariaDB/MySQL this is the qualified name of the table.
type TableIdent;
}

View file

@ -1,5 +1,8 @@
use std::{env, str};
use std::io::{self, Write, Read};
use std::{
env,
io::{self, Read, Write},
str,
};
use std::process::{Command, Stdio};
@ -11,13 +14,17 @@ fn get_expanded_target() -> crate::Result<Vec<u8>> {
let mut args = env::args_os().skip(2);
let cargo_args = args.by_ref().take_while(|arg| arg != "--").collect::<Vec<_>>();
let cargo_args = args
.by_ref()
.take_while(|arg| arg != "--")
.collect::<Vec<_>>();
let rustc_args = args.collect::<Vec<_>>();
let mut command = Command::new(cargo_path);
command.arg("rustc")
command
.arg("rustc")
.args(cargo_args)
.arg("--")
.arg("-Z")
@ -61,7 +68,7 @@ fn find_next_sql_string(input: &str) -> Result<Option<(&str, &str)>> {
let start = idx + STRING_START.len();
while let Some(end) = input[start..].find(STRING_END) {
if &input[start + end - 1 .. start + end] != "\\" {
if &input[start + end - 1..start + end] != "\\" {
return Ok(Some(input[start..].split_at(end)));
}
}

View file

@ -1,9 +1,6 @@
use crate::{query::IntoQueryParameters, Backend, Executor, FromSqlRow};
use futures_core::{future::BoxFuture, stream::BoxStream, Stream};
use std::marker::PhantomData;
use crate::query::IntoQueryParameters;
use crate::{Backend, FromSqlRow, Executor};
use futures_core::Stream;
use futures_core::stream::BoxStream;
use futures_core::future::BoxFuture;
pub struct CompiledSql<P, O, DB> {
#[doc(hidden)]
@ -12,30 +9,44 @@ pub struct CompiledSql<P, O, DB> {
pub params: P,
#[doc(hidden)]
pub output: PhantomData<O>,
pub backend: PhantomData<DB>
pub backend: PhantomData<DB>,
}
impl<DB, P, O> CompiledSql<P, O, DB> where DB: Backend, P: IntoQueryParameters<DB> + Send, O: FromSqlRow<DB> + Send + Unpin {
impl<DB, P, O> CompiledSql<P, O, DB>
where
DB: Backend,
P: IntoQueryParameters<DB> + Send,
O: FromSqlRow<DB> + Send + Unpin,
{
#[inline]
pub fn execute<'e, E: 'e>(self, executor: &'e E) -> BoxFuture<'e, crate::Result<u64>>
where
E: Executor<Backend = DB>, DB: 'e, P: 'e, O: 'e
where
E: Executor<Backend = DB>,
DB: 'e,
P: 'e,
O: 'e,
{
executor.execute(self.query, self.params)
}
#[inline]
pub fn fetch<'e, E: 'e>(self, executor: &'e E) -> BoxStream<'e, crate::Result<O>>
where
E: Executor<Backend = DB>, DB: 'e, P: 'e, O: 'e
where
E: Executor<Backend = DB>,
DB: 'e,
P: 'e,
O: 'e,
{
executor.fetch(self.query, self.params)
}
#[inline]
pub fn fetch_all<'e, E: 'e>(self, executor: &'e E) -> BoxFuture<'e, crate::Result<Vec<O>>>
where
E: Executor<Backend = DB>, DB: 'e, P: 'e, O: 'e
where
E: Executor<Backend = DB>,
DB: 'e,
P: 'e,
O: 'e,
{
executor.fetch_all(self.query, self.params)
}
@ -45,16 +56,22 @@ impl<DB, P, O> CompiledSql<P, O, DB> where DB: Backend, P: IntoQueryParameters<D
self,
executor: &'e E,
) -> BoxFuture<'e, crate::Result<Option<O>>>
where
E: Executor<Backend = DB>, DB: 'e, P: 'e, O: 'e
where
E: Executor<Backend = DB>,
DB: 'e,
P: 'e,
O: 'e,
{
executor.fetch_optional(self.query, self.params)
}
#[inline]
pub fn fetch_one<'e, E: 'e>(self, executor: &'e E) -> BoxFuture<'e, crate::Result<O>>
where
E: Executor<Backend = DB>, DB: 'e, P: 'e, O: 'e
where
E: Executor<Backend = DB>,
DB: 'e,
P: 'e,
O: 'e,
{
executor.fetch_one(self.query, self.params)
}

View file

@ -3,6 +3,7 @@ use crate::{
error::Error,
executor::Executor,
pool::{Live, SharedPool},
prepared::PreparedStatement,
query::{IntoQueryParameters, QueryParameters},
row::FromSqlRow,
};
@ -19,7 +20,6 @@ use std::{
},
time::Instant,
};
use crate::prepared::PreparedStatement;
/// A connection.bak to the database.
///
@ -73,10 +73,15 @@ pub trait RawConnection: Send {
params: <Self::Backend as Backend>::QueryParameters,
) -> crate::Result<Option<<Self::Backend as Backend>::Row>>;
async fn prepare(&mut self, name: &str, body: &str) -> crate::Result<PreparedStatement> {
// TODO: implement for other backends
unimplemented!()
}
async fn prepare(
&mut self,
query: &str,
) -> crate::Result<<Self::Backend as Backend>::StatementIdent>;
async fn prepare_describe(
&mut self,
query: &str,
) -> crate::Result<PreparedStatement<Self::Backend>>;
}
pub struct Connection<DB>(Arc<SharedConnection<DB>>)
@ -128,9 +133,12 @@ where
}
/// Prepares a statement.
pub async fn prepare(&self, name: &str, body: &str) -> crate::Result<PreparedStatement> {
///
/// UNSTABLE: for use by sqlx-macros only
#[doc(hidden)]
pub async fn prepare(&self, body: &str) -> crate::Result<PreparedStatement<DB>> {
let mut live = self.0.acquire().await;
let ret = live.raw.prepare(name, body).await?;
let ret = live.raw.prepare_describe(body).await?;
self.0.release(live);
Ok(ret)
}

View file

@ -54,7 +54,7 @@ impl Display for Error {
match self {
Error::Io(error) => write!(f, "{}", error),
Error::Database(error) => f.write_str(error.message()),
Error::Database(error) => Display::fmt(error, f),
Error::NotFound => f.write_str("found no rows when we expected at least one"),
@ -85,8 +85,6 @@ where
}
/// An error that was returned by the database backend.
pub trait DatabaseError: Debug + Send + Sync {
pub trait DatabaseError: Display + Debug + Send + Sync {
fn message(&self) -> &str;
// TODO: Expose more error properties
}

View file

@ -33,12 +33,11 @@ mod compiled;
#[doc(inline)]
pub use self::{
backend::Backend,
connection::Connection,
compiled::CompiledSql,
connection::Connection,
deserialize::FromSql,
error::{Error, Result},
executor::Executor,
prepared::{PreparedStatement, Field},
pool::Pool,
row::{FromSqlRow, Row},
serialize::ToSql,
@ -46,6 +45,9 @@ pub use self::{
types::HasSqlType,
};
#[doc(hidden)]
pub use types::HasTypeMetadata;
#[cfg(feature = "mariadb")]
pub mod mariadb;

View file

@ -7,6 +7,8 @@ impl Backend for MariaDb {
type QueryParameters = super::MariaDbQueryParameters;
type RawConnection = super::MariaDbRawConnection;
type Row = super::MariaDbRow;
type StatementIdent = u32;
type TableIdent = String;
}
impl_from_sql_row_tuples_for_backend!(MariaDb);

View file

@ -11,11 +11,13 @@ use crate::{
},
MariaDb, MariaDbQueryParameters, MariaDbRow,
},
Backend, Error, Result,
prepared::{Column, PreparedStatement},
Backend, Error, PreparedStatement, Result,
};
use async_trait::async_trait;
use byteorder::{ByteOrder, LittleEndian};
use futures_core::{future::BoxFuture, stream::BoxStream};
use futures_util::stream::{self, StreamExt};
use std::{
future::Future,
io,
@ -173,51 +175,33 @@ impl MariaDbRawConnection {
})
}
// This should not be used by the user. It's mean for `RawConnection` impl
// This assumes the buffer has been set and all it needs is a flush
async fn exec_prepare(&mut self) -> Result<u32> {
self.stream.flush().await?;
// COM_STMT_PREPARE returns COM_STMT_PREPARE_OK (0x00) or ERR (0xFF)
let mut packet = self.receive().await?;
let ok = match packet[0] {
0xFF => {
let err = ErrPacket::decode(packet)?;
// TODO: Bubble as Error::Database
// panic!("received db err = {:?}", err);
return Err(
io::Error::new(io::ErrorKind::InvalidInput, format!("{:?}", err)).into(),
);
}
_ => ComStmtPrepareOk::decode(packet)?,
};
// Skip decoding Column Definition packets for the result from a prepare statement
for _ in 0..ok.columns {
let _ = self.receive().await?;
}
if ok.columns > 0
&& !self
.capabilities
.contains(Capabilities::CLIENT_DEPRECATE_EOF)
async fn check_eof(&mut self) -> Result<()> {
if !self
.capabilities
.contains(Capabilities::CLIENT_DEPRECATE_EOF)
{
// TODO: Should we do something with the warning indicators here?
let _eof = EofPacket::decode(self.receive().await?)?;
let _ = EofPacket::decode(self.receive().await?)?;
}
Ok(ok.statement_id)
Ok(())
}
async fn prepare<'c>(&'c mut self, statement: &'c str) -> Result<u32> {
async fn send_prepare<'c>(&'c mut self, statement: &'c str) -> Result<ComStmtPrepareOk> {
self.stream.flush().await?;
self.start_sequence();
self.write(ComStmtPrepare { statement });
self.exec_prepare().await
self.stream.flush().await?;
// COM_STMT_PREPARE returns COM_STMT_PREPARE_OK (0x00) or ERR (0xFF)
let packet = self.receive().await?;
if packet[0] == 0xFF {
return Err(ErrPacket::decode(packet)?.into());
}
ComStmtPrepareOk::decode(packet).map_err(Into::into)
}
async fn execute(&mut self, statement_id: u32, params: MariaDbQueryParameters) -> Result<u64> {
@ -323,11 +307,9 @@ impl RawConnection for MariaDbRawConnection {
async fn execute(&mut self, query: &str, params: MariaDbQueryParameters) -> crate::Result<u64> {
// Write prepare statement to buffer
self.start_sequence();
self.write(ComStmtPrepare { statement: query });
let prepare_ok = self.send_prepare(query).await?;
let statement_id = self.exec_prepare().await?;
let affected = self.execute(statement_id, params).await?;
let affected = self.execute(prepare_ok.statement_id, params).await?;
Ok(affected)
}
@ -347,6 +329,56 @@ impl RawConnection for MariaDbRawConnection {
) -> crate::Result<Option<<Self::Backend as Backend>::Row>> {
unimplemented!();
}
async fn prepare(&mut self, query: &str) -> crate::Result<u32> {
let prepare_ok = self.send_prepare(query).await?;
for _ in 0..prepare_ok.params {
let _ = self.receive().await?;
}
self.check_eof().await?;
for _ in 0..prepare_ok.columns {
let _ = self.receive().await?;
}
self.check_eof().await?;
Ok(prepare_ok.statement_id)
}
async fn prepare_describe(&mut self, query: &str) -> crate::Result<PreparedStatement<MariaDb>> {
let prepare_ok = self.send_prepare(query).await?;
let mut param_types = Vec::with_capacity(prepare_ok.params as usize);
for _ in 0..prepare_ok.params {
let param = ColumnDefinitionPacket::decode(self.receive().await?)?;
param_types.push(param.field_type.0);
}
self.check_eof().await?;
let mut columns = Vec::with_capacity(prepare_ok.columns as usize);
for _ in 0..prepare_ok.columns {
let column = ColumnDefinitionPacket::decode(self.receive().await?)?;
columns.push(Column {
name: column.column_alias.or(column.column),
table_id: column.table_alias.or(column.table),
type_id: column.field_type.0,
})
}
self.check_eof().await?;
Ok(PreparedStatement {
identifier: prepare_ok.statement_id,
param_types,
columns,
})
}
}
#[cfg(test)]

21
src/mariadb/error.rs Normal file
View file

@ -0,0 +1,21 @@
use crate::{error::DatabaseError, mariadb::protocol::ErrorCode};
use std::fmt;
#[derive(Debug)]
pub struct Error {
pub code: ErrorCode,
pub message: Box<str>,
}
impl DatabaseError for Error {
fn message(&self) -> &str {
&self.message
}
}
impl fmt::Display for ErrorCode {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "MariaDB returned an error: {}",)
}
}

View file

@ -1,5 +1,6 @@
mod backend;
mod connection;
mod error;
mod establish;
mod io;
mod protocol;

View file

@ -1,10 +1,40 @@
use std::fmt;
#[derive(Default, Debug)]
pub struct ErrorCode(pub(crate) u16);
// TODO: It would be nice to figure out a clean way to go from 1152 to "ER_ABORTING_CONNECTION (1152)" in Debug.
use crate::error::DatabaseError;
use bitflags::_core::fmt::{Error, Formatter};
macro_rules! error_code_impl {
($(const $name:ident: ErrorCode = ErrorCode($code:expr));*;) => {
impl ErrorCode {
$(const $name: ErrorCode = ErrorCode($code);)*
pub fn code_name(&self) -> &'static str {
match self.0 {
$($code => $name,)*
_ => "<unknown error>"
}
}
}
}
}
impl fmt::Debug for ErrorCode {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "ErrorCode({} [()])",)
}
}
impl fmt::Display for ErrorCode {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{} ({})", self.code_name(), self.0)
}
}
// Values from https://mariadb.com/kb/en/library/mariadb-error-codes/
impl ErrorCode {
error_code_impl! {
const ER_ABORTING_CONNECTION: ErrorCode = ErrorCode(1152);
const ER_ACCESS_DENIED_CHANGE_USER_ERROR: ErrorCode = ErrorCode(1873);
const ER_ACCESS_DENIED_ERROR: ErrorCode = ErrorCode(1045);

View file

@ -1,6 +1,6 @@
use crate::{
io::Buf,
mariadb::{io::BufExt, protocol::ErrorCode},
mariadb::{error::Error, io::BufExt, protocol::ErrorCode},
};
use byteorder::LittleEndian;
use std::io;
@ -66,6 +66,15 @@ impl ErrPacket {
})
}
}
pub fn expect_error<T>(self) -> crate::Result<T> {
match self {
ErrPacket::Progress { .. } => {
Err(format!("expected ErrPacket::Err, got {:?}", self).into())
}
ErrPacket::Error { code, message, .. } => Err(Error { code, message }.into()),
}
}
}
#[cfg(test)]

View file

@ -1,5 +1,8 @@
use super::protocol::{FieldType, ParameterFlag};
use crate::{mariadb::MariaDb, types::TypeMetadata};
use crate::{
mariadb::MariaDb,
types::{HasTypeMetadata, TypeMetadata},
};
pub mod boolean;
pub mod character;
@ -11,6 +14,43 @@ pub struct MariaDbTypeMetadata {
pub param_flag: ParameterFlag,
}
impl TypeMetadata for MariaDb {
impl HasTypeMetadata for MariaDb {
type TypeMetadata = MariaDbTypeMetadata;
type TypeId = u8;
fn param_type_for_id(id: &Self::TypeId) -> Option<&'static str> {
Some(match FieldType(*id) {
FieldType::MYSQL_TYPE_TINY => "i8",
FieldType::MYSQL_TYPE_SHORT => "i16",
FieldType::MYSQL_TYPE_LONG => "i32",
FieldType::MYSQL_TYPE_LONGLONG => "i64",
FieldType::MYSQL_TYPE_VAR_STRING => "&str",
FieldType::MYSQL_TYPE_FLOAT => "f32",
FieldType::MYSQL_TYPE_DOUBLE => "f64",
FieldType::MYSQL_TYPE_BLOB => "&[u8]",
_ => return None
})
}
fn return_type_for_id(id: &Self::TypeId) -> Option<&'static str> {
Some(match FieldType(*id) {
FieldType::MYSQL_TYPE_TINY => "i8",
FieldType::MYSQL_TYPE_SHORT => "i16",
FieldType::MYSQL_TYPE_LONG => "i32",
FieldType::MYSQL_TYPE_LONGLONG => "i64",
FieldType::MYSQL_TYPE_VAR_STRING => "String",
FieldType::MYSQL_TYPE_FLOAT => "f32",
FieldType::MYSQL_TYPE_DOUBLE => "f64",
FieldType::MYSQL_TYPE_BLOB => "Vec<u8>",
_ => return None
})
}
}
impl TypeMetadata for MariaDbTypeMetadata {
type TypeId = u8;
fn type_id(&self) -> &Self::TypeId {
&self.field_type.0
}
}

View file

@ -7,6 +7,8 @@ impl Backend for Postgres {
type QueryParameters = super::PostgresQueryParameters;
type RawConnection = super::PostgresRawConnection;
type Row = super::PostgresRow;
type StatementIdent = String;
type TableIdent = u32;
}
impl_from_sql_row_tuples_for_backend!(Postgres);

View file

@ -1,10 +1,19 @@
use super::{Postgres, PostgresQueryParameters, PostgresRawConnection, PostgresRow};
use crate::{connection::RawConnection, postgres::raw::Step, url::Url, Error};
use crate::query::QueryParameters;
use crate::{
connection::RawConnection,
postgres::{error::ProtocolError, raw::Step},
prepared::{Column, PreparedStatement},
query::QueryParameters,
url::Url,
Error,
};
use async_trait::async_trait;
use futures_core::stream::BoxStream;
use crate::prepared::{PreparedStatement, Field};
use crate::postgres::error::ProtocolError;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::postgres::{protocol::Message, PostgresDatabaseError};
use std::hash::Hasher;
#[async_trait]
impl RawConnection for PostgresRawConnection {
@ -96,45 +105,80 @@ impl RawConnection for PostgresRawConnection {
Ok(row)
}
async fn prepare(&mut self, name: &str, body: &str) -> crate::Result<PreparedStatement> {
self.parse(name, body, &PostgresQueryParameters::new());
self.describe(name);
async fn prepare(&mut self, body: &str) -> crate::Result<String> {
let name = gen_statement_name(body);
self.parse(&name, body, &PostgresQueryParameters::new());
match self.receive().await? {
Some(Message::Response(response)) => Err(PostgresDatabaseError(response).into()),
Some(Message::ParseComplete) => Ok(name),
Some(message) => {
Err(ProtocolError(format!("unexpected message: {:?}", message)).into())
}
None => Err(ProtocolError("expected ParseComplete or ErrorResponse").into()),
}
}
async fn prepare_describe(&mut self, body: &str) -> crate::Result<PreparedStatement<Postgres>> {
let name = gen_statement_name(body);
self.parse(&name, body, &PostgresQueryParameters::new());
self.describe(&name);
self.sync().await?;
let param_desc= loop {
let step = self.step().await?
let param_desc = loop {
let step = self
.step()
.await?
.ok_or(ProtocolError("did not receive ParameterDescription"));
if let Step::ParamDesc(desc) = dbg!(step)?
{
break desc;
}
if let Step::ParamDesc(desc) = step? {
break desc;
}
};
let row_desc = loop {
let step = self.step().await?
let step = self
.step()
.await?
.ok_or(ProtocolError("did not receive RowDescription"));
if let Step::RowDesc(desc) = dbg!(step)?
{
if let Step::RowDesc(desc) = step? {
break desc;
}
};
Ok(PreparedStatement {
name: name.into(),
param_types: param_desc.ids,
fields: row_desc.fields.into_vec().into_iter()
.map(|field| Field {
name: field.name,
table_id: field.table_id,
type_id: field.type_id
identifier: name.into(),
param_types: param_desc.ids.into_vec(),
columns: row_desc
.fields
.into_vec()
.into_iter()
.map(|field| Column {
name: Some(field.name),
table_id: Some(field.table_id),
type_id: field.type_id,
})
.collect(),
})
}
}
static STATEMENT_COUNT: AtomicU64 = AtomicU64::new(0);
fn gen_statement_name(query: &str) -> String {
// hasher with no external dependencies
use std::collections::hash_map::DefaultHasher;
let mut hasher = DefaultHasher::new();
// including a global counter should help prevent collision
// with queries with the same content
hasher.write_u64(STATEMENT_COUNT.fetch_add(1, Ordering::SeqCst));
hasher.write(query.as_bytes());
format!("sqlx_stmt_{:x}", hasher.finish())
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -1,7 +1,10 @@
use super::protocol::Response;
use crate::error::DatabaseError;
use std::borrow::Cow;
use std::fmt::Debug;
use bitflags::_core::fmt::{Error, Formatter};
use std::{
borrow::Cow,
fmt::{self, Debug, Display},
};
#[derive(Debug)]
pub struct PostgresDatabaseError(pub(super) Box<Response>);
@ -15,8 +18,20 @@ impl DatabaseError for PostgresDatabaseError {
}
}
impl Display for PostgresDatabaseError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.pad(self.message())
}
}
impl<T: AsRef<str> + Debug + Send + Sync> DatabaseError for ProtocolError<T> {
fn message(&self) -> &str {
self.0.as_ref()
}
}
impl<T: AsRef<str>> Display for ProtocolError<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.pad(self.0.as_ref())
}
}

View file

@ -21,5 +21,5 @@ pub enum Message {
NoData,
PortalSuspended,
ParameterDescription(Box<ParameterDescription>),
RowDescription(Box<RowDescription>)
RowDescription(Box<RowDescription>),
}

View file

@ -31,10 +31,22 @@ mod terminate;
// TODO: mod ssl_request;
pub use self::{
bind::Bind, cancel_request::CancelRequest, close::Close, copy_data::CopyData,
copy_done::CopyDone, copy_fail::CopyFail, describe::Describe, describe::DescribeKind, encode::Encode, execute::Execute,
flush::Flush, parse::Parse, password_message::PasswordMessage, query::Query,
startup_message::StartupMessage, sync::Sync, terminate::Terminate,
bind::Bind,
cancel_request::CancelRequest,
close::Close,
copy_data::CopyData,
copy_done::CopyDone,
copy_fail::CopyFail,
describe::{Describe, DescribeKind},
encode::Encode,
execute::Execute,
flush::Flush,
parse::Parse,
password_message::PasswordMessage,
query::Query,
startup_message::StartupMessage,
sync::Sync,
terminate::Terminate,
};
mod authentication;
@ -54,10 +66,17 @@ mod row_description;
mod message;
pub use self::{
authentication::Authentication, backend_key_data::BackendKeyData,
command_complete::CommandComplete, data_row::DataRow, decode::Decode, message::Message,
notification_response::NotificationResponse, parameter_description::ParameterDescription,
parameter_status::ParameterStatus, ready_for_query::ReadyForQuery, response::Response,
authentication::Authentication,
backend_key_data::BackendKeyData,
command_complete::CommandComplete,
data_row::DataRow,
decode::Decode,
message::Message,
notification_response::NotificationResponse,
parameter_description::ParameterDescription,
parameter_status::ParameterStatus,
ready_for_query::ReadyForQuery,
response::Response,
row_description::{RowDescription, RowField},
};

View file

@ -1,8 +1,7 @@
use super::Decode;
use crate::io::Buf;
use byteorder::NetworkEndian;
use std::io;
use std::io::BufRead;
use std::{io, io::BufRead};
#[derive(Debug)]
pub struct RowDescription {
@ -17,7 +16,7 @@ pub struct RowField {
pub type_id: u32,
pub type_size: i16,
pub type_mod: i32,
pub format_code: i16
pub format_code: i16,
}
impl Decode for RowDescription {
@ -26,7 +25,7 @@ impl Decode for RowDescription {
let mut fields = Vec::with_capacity(cnt);
for _ in 0..cnt {
fields.push(dbg!(RowField {
fields.push(RowField {
name: super::read_string(&mut buf)?,
table_id: buf.get_u32::<NetworkEndian>()?,
attr_num: buf.get_i16::<NetworkEndian>()?,
@ -34,7 +33,7 @@ impl Decode for RowDescription {
type_size: buf.get_i16::<NetworkEndian>()?,
type_mod: buf.get_i32::<NetworkEndian>()?,
format_code: buf.get_i16::<NetworkEndian>()?,
}));
});
}
Ok(Self {

View file

@ -151,8 +151,9 @@ impl PostgresRawConnection {
pub(super) fn describe(&mut self, statement: &str) {
protocol::Describe {
kind: protocol::DescribeKind::PreparedStatement,
name: statement
}.encode(self.stream.buffer_mut())
name: statement,
}
.encode(self.stream.buffer_mut())
}
pub(super) fn bind(&mut self, portal: &str, statement: &str, params: &PostgresQueryParameters) {
@ -198,15 +199,15 @@ impl PostgresRawConnection {
Message::ReadyForQuery(_) => {
return Ok(None);
},
}
Message::ParameterDescription(desc) => {
return Ok(Some(Step::ParamDesc(desc)));
},
}
Message::RowDescription(desc) => {
return Ok(Some(Step::RowDesc(desc)));
},
}
message => {
return Err(io::Error::new(
@ -260,9 +261,7 @@ impl PostgresRawConnection {
b't' => Message::ParameterDescription(Box::new(
protocol::ParameterDescription::decode(body)?,
)),
b'T' => Message::RowDescription(Box::new(
protocol::RowDescription::decode(body)?
)),
b'T' => Message::RowDescription(Box::new(protocol::RowDescription::decode(body)?)),
id => {
return Err(io::Error::new(

View file

@ -0,0 +1,41 @@
use crate::{
postgres::types::{PostgresTypeFormat, PostgresTypeMetadata},
serialize::IsNull,
types::TypeMetadata,
FromSql, HasSqlType, Postgres, ToSql,
};
impl HasSqlType<[u8]> for Postgres {
fn metadata() -> Self::TypeMetadata {
PostgresTypeMetadata {
format: PostgresTypeFormat::Binary,
oid: 17,
array_oid: 1001,
}
}
}
impl HasSqlType<Vec<u8>> for Postgres {
fn metadata() -> Self::TypeMetadata {
<Postgres as HasSqlType<[u8]>>::metadata()
}
}
impl ToSql<Postgres> for [u8] {
fn to_sql(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend_from_slice(self);
IsNull::No
}
}
impl ToSql<Postgres> for Vec<u8> {
fn to_sql(&self, buf: &mut Vec<u8>) -> IsNull {
<[u8] as ToSql<Postgres>>::to_sql(self, buf)
}
}
impl FromSql<Postgres> for Vec<u8> {
fn from_sql(raw: Option<&[u8]>) -> Self {
raw.unwrap().into()
}
}

View file

@ -28,9 +28,12 @@
//! | `Uuid` (`uuid` feature) | UUID |
use super::Postgres;
use crate::types::TypeMetadata;
use crate::HasSqlType;
use crate::{
types::{HasTypeMetadata, TypeMetadata},
HasSqlType,
};
mod binary;
mod boolean;
mod character;
mod numeric;
@ -54,6 +57,59 @@ pub struct PostgresTypeMetadata {
pub array_oid: u32,
}
impl TypeMetadata for Postgres {
impl HasTypeMetadata for Postgres {
type TypeId = u32;
type TypeMetadata = PostgresTypeMetadata;
fn param_type_for_id(id: &Self::TypeId) -> Option<&'static str> {
Some(match id {
16 => "bool",
1000 => "&[bool]",
25 => "&str",
1009 => "&[&str]",
21 => "i16",
1005 => "&[i16]",
23 => "i32",
1007 => "&[i32]",
20 => "i64",
1016 => "&[i64]",
700 => "f32",
1021 => "&[f32]",
701 => "f64",
1022 => "&[f64]",
2950 => "sqlx::Uuid",
2951 => "&[sqlx::Uuid]",
_ => return None,
})
}
fn return_type_for_id(id: &Self::TypeId) -> Option<&'static str> {
Some(match id {
16 => "bool",
1000 => "Vec<bool>",
25 => "String",
1009 => "Vec<String>",
21 => "i16",
1005 => "Vec<i16>",
23 => "i32",
1007 => "Vec<i32>",
20 => "i64",
1016 => "Vec<i64>",
700 => "f32",
1021 => "Vec<f32>",
701 => "f64",
1022 => "Vec<f64>",
2950 => "sqlx::Uuid",
2951 => "Vec<sqlx::Uuid>",
_ => return None,
})
}
}
impl TypeMetadata for PostgresTypeMetadata {
type TypeId = u32;
fn type_id(&self) -> &u32 {
&self.oid
}
}

View file

@ -1,13 +1,25 @@
#[derive(Debug)]
pub struct PreparedStatement {
pub name: String,
pub param_types: Box<[u32]>,
pub fields: Vec<Field>,
use crate::{query::QueryParameters, Backend, Error, Executor, FromSqlRow, HasSqlType, ToSql};
use futures_core::{future::BoxFuture, stream::BoxStream};
use std::marker::PhantomData;
use crate::types::{HasTypeMetadata, TypeMetadata};
use std::fmt::{self, Debug};
/// A prepared statement.
pub struct PreparedStatement<DB: Backend> {
///
pub identifier: <DB as Backend>::StatementIdent,
/// The expected type IDs of bind parameters.
pub param_types: Vec<<DB as HasTypeMetadata>::TypeId>,
///
pub columns: Vec<Column<DB>>,
}
#[derive(Debug)]
pub struct Field {
pub name: String,
pub table_id: u32,
pub type_id: u32,
pub struct Column<DB: Backend> {
pub name: Option<String>,
pub table_id: Option<<DB as Backend>::TableIdent>,
/// The type ID of this result column.
pub type_id: <DB as HasTypeMetadata>::TypeId,
}

View file

@ -46,9 +46,9 @@ where
}
impl<T: ?Sized, DB> ToSql<DB> for &'_ T
where
DB: Backend + HasSqlType<T>,
T: ToSql<DB>,
where
DB: Backend + HasSqlType<T>,
T: ToSql<DB>,
{
#[inline]
fn to_sql(&self, buf: &mut Vec<u8>) -> IsNull {

View file

@ -1,18 +1,39 @@
/// Information about how a backend stores metadata about
/// given SQL types.
pub trait TypeMetadata {
pub trait HasTypeMetadata {
/// The actual type used to represent metadata.
type TypeMetadata;
type TypeMetadata: TypeMetadata<TypeId = Self::TypeId>;
/// The Rust type of the type ID for the backend.
type TypeId: Eq;
/// UNSTABLE: for internal use only
#[doc(hidden)]
fn param_type_for_id(id: &Self::TypeId) -> Option<&'static str>;
/// UNSTABLE: for internal use only
#[doc(hidden)]
fn return_type_for_id(id: &Self::TypeId) -> Option<&'static str>;
}
pub trait TypeMetadata {
type TypeId: Eq;
fn type_id(&self) -> &Self::TypeId;
fn type_id_eq(&self, id: &Self::TypeId) -> bool {
self.type_id() == id
}
}
/// Indicates that a SQL type exists for a backend and defines
/// useful metadata for the backend.
pub trait HasSqlType<A: ?Sized>: TypeMetadata {
pub trait HasSqlType<A: ?Sized>: HasTypeMetadata {
fn metadata() -> Self::TypeMetadata;
}
impl<A: ?Sized, DB> HasSqlType<&'_ A> for DB
where DB: HasSqlType<A>
where
DB: HasSqlType<A>,
{
fn metadata() -> Self::TypeMetadata {
<DB as HasSqlType<A>>::metadata()
@ -20,7 +41,8 @@ impl<A: ?Sized, DB> HasSqlType<&'_ A> for DB
}
impl<A, DB> HasSqlType<Option<A>> for DB
where DB: HasSqlType<A>
where
DB: HasSqlType<A>,
{
fn metadata() -> Self::TypeMetadata {
<DB as HasSqlType<A>>::metadata()

View file

@ -2,7 +2,9 @@
#[tokio::test]
async fn test_sqlx_macro() -> sqlx::Result<()> {
let conn = sqlx::Connection::<sqlx::Postgres>::establish("postgres://postgres@127.0.0.1/sqlx_test").await?;
let conn =
sqlx::Connection::<sqlx::Postgres>::establish("postgres://postgres@127.0.0.1/sqlx_test")
.await?;
let uuid: sqlx::Uuid = "256ba9c8-0048-11ea-b0f0-8f04859d047e".parse().unwrap();
let accounts = sqlx_macros::sql!("SELECT * from accounts where id = $1", 5i64)
.fetch_one(&conn)