[MySQL] Add fetch_optional, fix encode/decode for integers

This commit is contained in:
Ryan Leckey 2019-12-02 22:14:41 -08:00
parent 6925d5999c
commit bf4f65ea2f
7 changed files with 117 additions and 17 deletions

View file

@ -43,6 +43,10 @@ required-features = [ "postgres", "uuid", "macros" ]
name = "postgres-types"
required-features = [ "postgres" ]
[[test]]
name = "mysql-types"
required-features = [ "mariadb" ]
[[bench]]
name = "postgres-protocol"
required-features = [ "postgres", "unstable" ]

View file

@ -175,7 +175,37 @@ impl MariaDb {
return ErrPacket::decode(packet)?.expect_error();
}
ComStmtPrepareOk::decode(packet).map_err(Into::into)
let ok = ComStmtPrepareOk::decode(packet)?;
// Input parameters
for _ in 0..ok.params {
// TODO: Maybe do something with this data ?
let _column = ColumnDefinitionPacket::decode(self.receive().await?)?;
}
// TODO: De-duplicate this
if !self
.capabilities
.contains(Capabilities::CLIENT_DEPRECATE_EOF)
{
let _eof = EofPacket::decode(self.receive().await?)?;
}
// Output parameters
for _ in 0..ok.columns {
// TODO: Maybe do something with this data ?
let _column = ColumnDefinitionPacket::decode(self.receive().await?)?;
}
// TODO: De-duplicate this
if !self
.capabilities
.contains(Capabilities::CLIENT_DEPRECATE_EOF)
{
let _eof = EofPacket::decode(self.receive().await?)?;
}
Ok(ok)
}
pub(super) async fn column_definitions(
@ -192,7 +222,8 @@ impl MariaDb {
// TODO: This information was *already* returned by PREPARE .., is there a way to suppress generation
let mut columns = vec![];
for _ in 0..column_count {
columns.push(ColumnDefinitionPacket::decode(self.receive().await?)?);
let column =ColumnDefinitionPacket::decode(self.receive().await?)?;
columns.push(column);
}
// When (legacy) EOFs are enabled, the fixed number column definitions are further terminated by
@ -210,7 +241,7 @@ impl MariaDb {
pub(super) async fn send_execute(
&mut self,
statement_id: u32,
_params: MariaDbQueryParameters,
params: MariaDbQueryParameters,
) -> Result<()> {
// TODO: EXECUTE(READ_ONLY) => FETCH instead of EXECUTE(NO)
@ -218,10 +249,10 @@ impl MariaDb {
self.start_sequence();
self.write(ComStmtExecute {
statement_id,
params: &[],
null: &[],
params: &params.params,
null: &params.null,
flags: StmtExecFlag::NO_CURSOR,
param_types: &[],
param_types: &params.param_types,
});
self.stream.flush().await?;
// =====================

View file

@ -115,7 +115,38 @@ impl Executor for MariaDb {
I: IntoQueryParameters<Self::Backend> + Send,
T: FromRow<Self::Backend, O> + Send,
{
unimplemented!();
let params = params.into_params();
Box::pin(async move {
let prepare = self.send_prepare(query).await?;
self.send_execute(prepare.statement_id, params).await?;
let columns = self.column_definitions().await?;
let capabilities = self.capabilities;
let mut row: Option<_> = None;
loop {
let packet = self.receive().await?;
if packet[0] == 0xFE && packet.len() < 0xFF_FF_FF {
// NOTE: It's possible for a ResultRow to start with 0xFE (which would normally signify end-of-rows)
// but it's not possible for an Ok/Eof to be larger than 0xFF_FF_FF.
if !capabilities.contains(Capabilities::CLIENT_DEPRECATE_EOF) {
let _eof = EofPacket::decode(packet)?;
} else {
let _ok = OkPacket::decode(packet, capabilities)?;
}
break;
} else if packet[0] == 0xFF {
let _err = ErrPacket::decode(packet)?;
} else {
row = Some(FromRow::from_row(ResultRow::decode(packet, &columns)?));
}
}
Ok(row)
})
}
fn describe<'e, 'q: 'e>(

View file

@ -45,6 +45,10 @@ impl ResultRow {
values.push(None);
} else {
match columns[column_idx].field_type {
FieldType::MYSQL_TYPE_TINY => {
values.push(Some(buf.get_bytes(1)?.into()));
}
FieldType::MYSQL_TYPE_LONG => {
values.push(Some(buf.get_bytes(4)?.into()));
}

View file

@ -7,9 +7,9 @@ use crate::{
};
pub struct MariaDbQueryParameters {
param_types: Vec<MariaDbTypeMetadata>,
params: Vec<u8>,
null: Vec<u8>,
pub(crate) param_types: Vec<MariaDbTypeMetadata>,
pub(crate) params: Vec<u8>,
pub(crate) null: Vec<u8>,
}
impl QueryParameters for MariaDbQueryParameters {

View file

@ -5,7 +5,7 @@ use crate::{
mariadb::protocol::{FieldType, ParameterFlag},
types::HasSqlType,
};
use byteorder::{BigEndian, ByteOrder};
use byteorder::{LittleEndian, ByteOrder};
impl HasSqlType<i16> for MariaDb {
#[inline]
@ -21,7 +21,7 @@ impl HasSqlType<i16> for MariaDb {
impl Encode<MariaDb> for i16 {
#[inline]
fn encode(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend_from_slice(&self.to_be_bytes());
buf.extend_from_slice(&self.to_le_bytes());
IsNull::No
}
@ -30,7 +30,7 @@ impl Encode<MariaDb> for i16 {
impl Decode<MariaDb> for i16 {
#[inline]
fn decode(buf: Option<&[u8]>) -> Self {
BigEndian::read_i16(buf.unwrap())
LittleEndian::read_i16(buf.unwrap())
}
}
@ -48,7 +48,7 @@ impl HasSqlType<i32> for MariaDb {
impl Encode<MariaDb> for i32 {
#[inline]
fn encode(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend_from_slice(&self.to_be_bytes());
buf.extend_from_slice(&self.to_le_bytes());
IsNull::No
}
@ -57,7 +57,7 @@ impl Encode<MariaDb> for i32 {
impl Decode<MariaDb> for i32 {
#[inline]
fn decode(buf: Option<&[u8]>) -> Self {
BigEndian::read_i32(buf.unwrap())
LittleEndian::read_i32(buf.unwrap())
}
}
@ -75,7 +75,7 @@ impl HasSqlType<i64> for MariaDb {
impl Encode<MariaDb> for i64 {
#[inline]
fn encode(&self, buf: &mut Vec<u8>) -> IsNull {
buf.extend_from_slice(&self.to_be_bytes());
buf.extend_from_slice(&self.to_le_bytes());
IsNull::No
}
@ -84,7 +84,7 @@ impl Encode<MariaDb> for i64 {
impl Decode<MariaDb> for i64 {
#[inline]
fn decode(buf: Option<&[u8]>) -> Self {
BigEndian::read_i64(buf.unwrap())
LittleEndian::read_i64(buf.unwrap())
}
}

30
tests/mysql-types.rs Normal file
View file

@ -0,0 +1,30 @@
use sqlx::{Connection, MariaDb, Row};
use std::env;
macro_rules! test {
($name:ident: $ty:ty: $($text:literal == $value:expr),+) => {
#[async_std::test]
async fn $name () -> sqlx::Result<()> {
let mut conn =
Connection::<MariaDb>::open(&env::var("DATABASE_URL").unwrap()).await?;
$(
let row = sqlx::query(&format!("SELECT {} = ?, ?", $text))
.bind($value)
.bind($value)
.fetch_one(&mut conn)
.await?;
assert_eq!(row.get::<i32>(0), 1);
assert!($value == row.get::<$ty>(1));
)+
Ok(())
}
}
}
test!(mysql_bool: bool: "false" == false, "true" == true);
test!(mysql_long: i32: "2141512" == 2141512_i32);