Move protocol around a bit more and work more towards the refactor

This commit is contained in:
Ryan Leckey 2019-09-06 14:50:57 -07:00
parent 8ada85ef62
commit 837e12d797
37 changed files with 172 additions and 1211 deletions

View file

@ -7,6 +7,7 @@ mod encode;
mod error_code;
mod field;
mod response;
mod text;
mod server_status;
pub use capabilities::Capabilities;
@ -20,3 +21,7 @@ pub use response::{
ColumnCountPacket, ColumnDefinitionPacket, EofPacket, ErrPacket, OkPacket, ResultRow,
};
pub use server_status::ServerStatusFlag;
pub use text::{
ComDebug, ComInitDb,
ComPing, ComProcessKill, ComQuery, ComQuit,
};

View file

@ -1,47 +0,0 @@
use crate::mariadb::io::BufExt;
use byteorder::LittleEndian;
use std::io;
// The column packet is the first packet of a result set.
// Inside of it it contains the number of columns in the result set
// encoded as an int<lenenc>.
// https://mariadb.com/kb/en/library/resultset/#column-count-packet
#[derive(Debug)]
pub struct ColumnCountPacket {
pub columns: u64,
}
impl ColumnCountPacket {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
let columns = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0);
Ok(Self { columns })
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::__bytes_builder;
#[test]
fn it_decodes_column_packet_0x_fb() -> io::Result<()> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<3> length
0u8, 0u8, 0u8,
// int<1> seq_no
0u8,
// int<lenenc> tag code: Some(3 bytes)
0xFD_u8,
// value: 3 bytes
0x01_u8, 0x01_u8, 0x01_u8
);
let message = ColumnCountPacket::decode(&buf)?;
assert_eq!(message.columns, Some(0x010101));
Ok(())
}
}

View file

@ -1,127 +0,0 @@
use crate::{
io::Buf,
mariadb::{
io::BufExt,
protocol::{FieldDetailFlag, FieldType},
},
};
use byteorder::LittleEndian;
use std::io;
#[derive(Debug)]
// ColumnDefinitionPacket doesn't have a packet header because
// it's nested inside a result set packet
pub struct ColumnDefinitionPacket {
pub schema: Option<String>,
pub table_alias: Option<String>,
pub table: Option<String>,
pub column_alias: Option<String>,
pub column: Option<String>,
pub char_set: u16,
pub max_columns: i32,
pub field_type: FieldType,
pub field_details: FieldDetailFlag,
pub decimals: u8,
}
impl ColumnDefinitionPacket {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
// string<lenenc> catalog (always 'def')
let _catalog = buf.get_str_lenenc::<LittleEndian>()?;
// TODO: Assert that this is always DEF
// string<lenenc> schema
let schema = buf.get_str_lenenc::<LittleEndian>()?.map(ToOwned::to_owned);
// string<lenenc> table alias
let table_alias = buf.get_str_lenenc::<LittleEndian>()?.map(ToOwned::to_owned);
// string<lenenc> table
let table = buf.get_str_lenenc::<LittleEndian>()?.map(ToOwned::to_owned);
// string<lenenc> column alias
let column_alias = buf.get_str_lenenc::<LittleEndian>()?.map(ToOwned::to_owned);
// string<lenenc> column
let column = buf.get_str_lenenc::<LittleEndian>()?.map(ToOwned::to_owned);
// int<lenenc> length of fixed fields (=0xC)
let _length_of_fixed_fields = buf.get_uint_lenenc::<LittleEndian>()?;
// TODO: Assert that this is always 0xC
// int<2> character set number
let char_set = buf.get_u16::<LittleEndian>()?;
// int<4> max. column size
let max_columns = buf.get_i32::<LittleEndian>()?;
// int<1> Field types
let field_type = FieldType(buf.get_u8()?);
// int<2> Field detail flag
let field_details = FieldDetailFlag::from_bits_truncate(buf.get_u16::<LittleEndian>()?);
// int<1> decimals
let decimals = buf.get_u8()?;
// int<2> - unused -
buf.advance(2);
Ok(Self {
schema,
table_alias,
table,
column_alias,
column,
char_set,
max_columns,
field_type,
field_details,
decimals,
})
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::__bytes_builder;
#[test]
fn it_decodes_column_def_packet() -> io::Result<()> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// length
1u8, 0u8, 0u8,
// seq_no
0u8,
// string<lenenc> catalog (always 'def')
1u8, b'a',
// string<lenenc> schema
1u8, b'b',
// string<lenenc> table alias
1u8, b'c',
// string<lenenc> table
1u8, b'd',
// string<lenenc> column alias
1u8, b'e',
// string<lenenc> column
1u8, b'f',
// int<lenenc> length of fixed fields (=0xC)
0xFC_u8, 1u8, 1u8,
// int<2> character set number
1u8, 1u8,
// int<4> max. column size
1u8, 1u8, 1u8, 1u8,
// int<1> Field types
1u8,
// int<2> Field detail flag
1u8, 0u8,
// int<1> decimals
1u8,
// int<2> - unused -
0u8, 0u8
);
let message = ColumnDefinitionPacket::decode(&buf)?;
assert_eq!(message.schema, Some(b"b"));
assert_eq!(message.table_alias, Some(b"c"));
assert_eq!(message.table, Some(b"d"));
assert_eq!(message.column_alias, Some(b"e"));
assert_eq!(message.column, Some(b"f"));
Ok(())
}
}

View file

@ -1,65 +0,0 @@
use crate::{
io::Buf,
mariadb::{
io::BufExt,
protocol::{ErrorCode, ServerStatusFlag},
},
};
use byteorder::LittleEndian;
use std::io;
#[derive(Default, Debug)]
pub struct EofPacket {
pub warning_count: u16,
pub status: ServerStatusFlag,
}
impl EofPacket {
fn decode(mut buf: &[u8]) -> io::Result<Self> {
let header = buf.get_u8()?;
if header != 0xFE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected 0xFE; received {}", header),
));
}
let warning_count = buf.get_u16::<LittleEndian>()?;
let status = ServerStatusFlag::from_bits_truncate(buf.get_u16::<LittleEndian>()?);
Ok(Self {
warning_count,
status,
})
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::{__bytes_builder, mariadb::ConnContext};
use bytes::Bytes;
#[test]
fn it_decodes_eof_packet() -> Result<(), Error> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<3> length
1u8, 0u8, 0u8,
// int<1> seq_no
1u8,
// int<1> 0xfe : EOF header
0xFE_u8,
// int<2> warning count
0u8, 0u8,
// int<2> server status
1u8, 1u8
);
let _message = EofPacket::decode(&buf)?;
// TODO: Assert fields?
Ok(())
}
}

View file

@ -1,133 +0,0 @@
use std::convert::TryFrom;
use bytes::Bytes;
use failure::Error;
use crate::mariadb::{DeContext, Decode, ErrorCode};
#[derive(Default, Debug)]
pub struct ErrPacket {
pub error_code: ErrorCode,
pub stage: Option<u8>,
pub max_stage: Option<u8>,
pub progress: Option<i32>,
pub progress_info: Option<Bytes>,
pub sql_state_marker: Option<Bytes>,
pub sql_state: Option<Bytes>,
pub error_message: Option<Bytes>,
}
impl Decode for ErrPacket {
fn decode(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_u8();
let packet_header = decoder.decode_int_u8();
if packet_header != 0xFF {
panic!("Packet header is not 0xFF for ErrPacket");
}
let error_code = ErrorCode(decoder.decode_int_u16());
let mut stage = None;
let mut max_stage = None;
let mut progress = None;
let mut progress_info = None;
let mut sql_state_marker = None;
let mut sql_state = None;
let mut error_message = None;
// Progress Reporting
if error_code.0 == 0xFFFF {
stage = Some(decoder.decode_int_u8());
max_stage = Some(decoder.decode_int_u8());
progress = Some(decoder.decode_int_i24());
progress_info = Some(decoder.decode_string_lenenc());
} else {
if decoder.buf[decoder.index] == b'#' {
sql_state_marker = Some(decoder.decode_string_fix(1));
sql_state = Some(decoder.decode_string_fix(5));
error_message = Some(decoder.decode_string_eof(Some(length as usize)));
} else {
error_message = Some(decoder.decode_string_eof(Some(length as usize)));
}
}
Ok(ErrPacket {
length,
seq_no,
error_code,
stage,
max_stage,
progress,
progress_info,
sql_state_marker,
sql_state,
error_message,
})
}
}
impl std::error::Error for ErrPacket {
fn description(&self) -> &str {
"Received error packet"
}
}
impl std::fmt::Display for ErrPacket {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
write!(f, "{:?}:{:?}", self.error_code, self.error_message)
}
}
#[cfg(test)]
mod test {
use bytes::Bytes;
use super::*;
use crate::{
__bytes_builder,
mariadb::{ConnContext, Decoder},
};
#[test]
fn it_decodes_err_packet() -> Result<(), Error> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<3> length
1u8, 0u8, 0u8,
// int<1> seq_no
1u8,
// int<1> 0xfe : EOF header
0xFF_u8,
// int<2> error code
0x84_u8, 0x04_u8,
// if (errorcode == 0xFFFF) /* progress reporting */ {
// int<1> stage
// int<1> max_stage
// int<3> progress
// string<lenenc> progress_info
// } else {
// if (next byte = '#') {
// string<1> sql state marker '#'
b"#",
// string<5>sql state
b"08S01",
// string<EOF> error message
b"Got packets out of order"
// } else {
// string<EOF> error message
// }
// }
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, buf);
let _message = ErrPacket::decode(&mut ctx)?;
Ok(())
}
}

View file

@ -1,38 +0,0 @@
// pub mod auth_switch_request;
// pub mod binary;
pub mod column_count_packet;
pub mod column_definition_packet;
pub mod eof_packet;
// pub mod err_packet;
pub mod handshake_response_packet;
// pub mod initial_handshake_packet;
pub mod ok_packet;
// pub mod packet_header;
// pub mod result_row;
// pub mod result_set;
// pub mod ssl_request;
// pub mod text;
// pub use auth_switch_request::AuthenticationSwitchRequestPacket;
pub use column_count_packet::ColumnCountPacket;
pub use column_definition_packet::ColumnDefinitionPacket;
// pub use eof::EofPacket;
// pub use err::ErrPacket;
// pub use handshake_response::HandshakeResponsePacket;
// pub use initial::InitialHandshakePacket;
// pub use ok::OkPacket;
// pub use packet_header::PacketHeader;
// pub use result_row::ResultRow;
// pub use result_set::ResultSet;
// pub use ssl_request::SSLRequestPacket;
// pub use text::{
// ComDebug, ComInitDb, ComPing, ComProcessKill, ComQuery, ComQuit, ComResetConnection,
// ComSetOption, ComShutdown, ComSleep, ComStatistics, ResultRow as ResultRowText,
// SetOptionOptions, ShutdownOptions,
// };
// pub use binary::{
// ComStmtClose, ComStmtExec, ComStmtFetch, ComStmtPrepare, ComStmtPrepareOk, ComStmtPrepareResp,
// ComStmtReset, ResultRow as ResultRowBinary,
// };

View file

@ -1,117 +0,0 @@
use crate::{
io::Buf,
mariadb::{
io::BufExt,
protocol::{Capabilities, ServerStatusFlag},
},
};
use byteorder::LittleEndian;
use std::io;
// https://mariadb.com/kb/en/library/ok_packet/
#[derive(Debug)]
pub struct OkPacket {
pub affected_rows: u64,
pub last_insert_id: u64,
pub server_status: ServerStatusFlag,
pub warning_count: u16,
pub info: Box<str>,
pub session_state_info: Option<Box<[u8]>>,
pub value_of_variable: Option<Box<str>>,
}
impl OkPacket {
fn decode(mut buf: &[u8], capabilities: Capabilities) -> io::Result<Self> {
let header = buf.get_u8()?;
if header != 0 && header != 0xFE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("expected 0x00 or 0xFE; received 0x{:X}", header),
));
}
let affected_rows = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0);
let last_insert_id = buf.get_uint_lenenc::<LittleEndian>()?.unwrap_or(0);
let server_status = ServerStatusFlag::from_bits_truncate(buf.get_u16::<LittleEndian>()?);
let warning_count = buf.get_u16::<LittleEndian>()?;
let info;
let mut session_state_info = None;
let mut value_of_variable = None;
if capabilities.contains(Capabilities::CLIENT_SESSION_TRACK) {
info = buf
.get_str_lenenc::<LittleEndian>()?
.unwrap_or_default()
.to_owned()
.into();
session_state_info = buf.get_byte_lenenc::<LittleEndian>()?.map(Into::into);
value_of_variable = buf.get_str_lenenc::<LittleEndian>()?.map(Into::into);
} else {
info = buf.get_str_eof()?.to_owned().into();
}
Ok(Self {
affected_rows,
last_insert_id,
server_status,
warning_count,
info,
session_state_info,
value_of_variable,
})
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::{
__bytes_builder,
mariadb::{ConnContext, Decoder},
};
#[test]
fn it_decodes_ok_packet() -> Result<(), Error> {
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<3> length
0u8, 0u8, 0u8,
// // int<1> seq_no
1u8,
// 0x00 : OK_Packet header or (0xFE if CLIENT_DEPRECATE_EOF is set)
0u8,
// int<lenenc> affected rows
0xFB_u8,
// int<lenenc> last insert id
0xFB_u8,
// int<2> server status
1u8, 1u8,
// int<2> warning count
0u8, 0u8,
// if session_tracking_supported (see CLIENT_SESSION_TRACK) {
// string<lenenc> info
// if (status flags & SERVER_SESSION_STATE_CHANGED) {
// string<lenenc> session state info
// string<lenenc> value of variable
// }
// } else {
// string<EOF> info
b"info"
// }
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, buf);
let message = OkPacket::decode(&mut ctx)?;
assert_eq!(message.affected_rows, None);
assert_eq!(message.last_insert_id, None);
assert!(!(message.server_status & ServerStatusFlag::SERVER_STATUS_IN_TRANS).is_empty());
assert_eq!(message.warning_count, 0);
assert_eq!(message.info, b"info".to_vec());
Ok(())
}
}

View file

@ -1,36 +0,0 @@
use byteorder::{ByteOrder, LittleEndian};
#[derive(Debug, Default, Clone, Copy)]
pub struct PacketHeader {
pub length: u32,
pub seq_no: u8,
}
impl PacketHeader {
pub fn size() -> usize {
4
}
pub fn combined_length(&self) -> usize {
PacketHeader::size() + self.length as usize
}
}
impl core::convert::TryFrom<&[u8]> for PacketHeader {
type Error = failure::Error;
fn try_from(buffer: &[u8]) -> Result<Self, Self::Error> {
if buffer.len() < 4 {
failure::bail!("Buffer length is too short")
} else {
let packet = PacketHeader {
length: LittleEndian::read_u24(&buffer),
seq_no: buffer[3],
};
if packet.length == 0 && packet.seq_no == 0 {
failure::bail!("Length and seq_no cannot be zero");
}
Ok(packet)
}
}
}

View file

@ -1,28 +0,0 @@
use crate::mariadb::{ResultRowBinary, ResultRowText};
#[derive(Debug)]
pub struct ResultRow {
pub length: u32,
pub seq_no: u8,
pub columns: Vec<Option<bytes::Bytes>>,
}
impl From<ResultRowText> for ResultRow {
fn from(row: ResultRowText) -> Self {
ResultRow {
length: row.length,
seq_no: row.seq_no,
columns: row.columns,
}
}
}
impl From<ResultRowBinary> for ResultRow {
fn from(row: ResultRowBinary) -> Self {
ResultRow {
length: row.length,
seq_no: row.seq_no,
columns: row.columns,
}
}
}

View file

@ -1,446 +0,0 @@
use bytes::Bytes;
use failure::Error;
use crate::mariadb::{
Capabilities, ColumnDefPacket, ColumnPacket, ConnContext, DeContext, Decode, Decoder,
EofPacket, ErrPacket, Framed, OkPacket, ProtocolType, ResultRow, ResultRowBinary,
ResultRowText,
};
#[derive(Debug, Default)]
pub struct ResultSet {
pub column_packet: ColumnPacket,
pub column_defs: Option<Vec<ColumnDefPacket>>,
pub rows: Vec<ResultRow>,
}
impl ResultSet {
pub async fn deserialize(
mut ctx: DeContext<'_>,
protocol: ProtocolType,
) -> Result<Self, Error> {
let column_packet = ColumnPacket::decode(&mut ctx)?;
ctx.columns = column_packet.columns;
let column_defs = if let Some(columns) = column_packet.columns {
let mut column_defs = Vec::new();
for _ in 0..columns {
ctx.next_packet().await?;
column_defs.push(ColumnDefPacket::decode(&mut ctx)?);
}
Some(column_defs)
} else {
None
};
if column_defs.is_some() {
ctx.column_defs = column_defs.clone();
}
ctx.next_packet().await?;
let eof_packet = if !ctx
.ctx
.capabilities
.contains(Capabilities::CLIENT_DEPRECATE_EOF)
{
// If we get an eof packet we must update ctx to hold a new buffer of the next packet.
let eof_packet = Some(EofPacket::decode(&mut ctx)?);
ctx.next_packet().await?;
eof_packet
} else {
None
};
let mut rows = Vec::new();
loop {
let packet_header = match ctx.decoder.peek_packet_header() {
Ok(v) => v,
Err(_) => break,
};
let tag = ctx.decoder.peek_tag();
if tag == &0xFE && packet_header.length <= 0xFFFFFF {
break;
} else {
let index = ctx.decoder.index;
match protocol {
ProtocolType::Text => match ResultRowText::decode(&mut ctx) {
Ok(row) => {
rows.push(ResultRow::from(row));
ctx.next_packet().await?;
}
Err(_) => {
ctx.decoder.index = index;
break;
}
},
ProtocolType::Binary => match ResultRowBinary::decode(&mut ctx) {
Ok(row) => {
rows.push(ResultRow::from(row));
ctx.next_packet().await?;
}
Err(_) => {
ctx.decoder.index = index;
break;
}
},
}
}
}
if ctx.decoder.peek_packet_header()?.length > 0 {
if ctx
.ctx
.capabilities
.contains(Capabilities::CLIENT_DEPRECATE_EOF)
{
OkPacket::decode(&mut ctx)?;
} else {
EofPacket::decode(&mut ctx)?;
}
}
Ok(ResultSet {
column_packet,
column_defs,
rows,
})
}
}
#[cfg(test)]
mod test {
use bytes::{BufMut, Bytes};
use super::*;
use crate::{
__bytes_builder,
mariadb::{
Capabilities, ConnContext, EofPacket, ErrPacket, MariaDbRawConnection, OkPacket,
ResultRow, ServerStatusFlag,
},
};
#[tokio::test]
async fn it_decodes_result_set_text_packet() -> Result<(), Error> {
// TODO: Use byte string as input for test; this is a valid return from a mariadb.
#[rustfmt::skip]
let buf: Bytes = __bytes_builder!(
// ------------------- //
// Column Count packet //
// ------------------- //
// int<3> length
1u8, 0u8, 0u8,
// int<1> seq_no
1u8,
// int<lenenc> tag code or length
4u8,
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// int<3> length
40u8, 0u8, 0u8,
// int<1> seq_no
2u8,
// string<lenenc> catalog (always 'def')
3u8, b"def",
// string<lenenc> schema
4u8, b"test",
// string<lenenc> table alias
5u8, b"users",
// string<lenenc> table
5u8, b"users",
// string<lenenc> column alias
2u8, b"id",
// string<lenenc> column
2u8, b"id",
// int<lenenc> length of fixed fields (=0xC)
0x0C_u8,
// int<2> character set number
8u8, 0u8,
// int<4> max. column size
0x80_u8, 0u8, 0u8, 0u8,
// int<1> Field types
0xFD_u8,
// int<2> Field detail flag
3u8, 64u8,
// int<1> decimals
0u8,
// int<2> - unused -
0u8, 0u8,
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// int<3> length
52u8, 0u8, 0u8,
// int<1> seq_no
3u8,
// string<lenenc> catalog (always 'def')
3u8, b"def",
// string<lenenc> schema
4u8, b"test",
// string<lenenc> table alias
5u8, b"users",
// string<lenenc> table
5u8, b"users",
// string<lenenc> column alias
8u8, b"username",
// string<lenenc> column
8u8, b"username",
// int<lenenc> length of fixed fields (=0xC)
0x0C_u8,
// int<2> character set number
8u8, 0u8,
// int<4> max. column size
0xFF_u8, 0xFF_u8, 0u8, 0u8,
// int<1> Field types
0xFC_u8,
// int<2> Field detail flag
0x11_u8, 0x10_u8,
// int<1> decimals
0u8,
// int<2> - unused -
0u8, 0u8,
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// int<3> length
52u8, 0u8, 0u8,
// int<1> seq_no
4u8,
// string<lenenc> catalog (always 'def')
3u8, b"def",
// string<lenenc> schema
4u8, b"test",
// string<lenenc> table alias
5u8, b"users",
// string<lenenc> table
5u8, b"users",
// string<lenenc> column alias
8u8, b"password",
// string<lenenc> column
8u8, b"password",
// int<lenenc> length of fixed fields (=0xC)
0x0C_u8,
// int<2> character set number
8u8, 0u8,
// int<4> max. column size
0xFF_u8, 0xFF_u8, 0u8, 0u8,
// int<1> Field types
0xFC_u8,
// int<2> Field detail flag
0x11_u8, 0x10_u8,
// int<1> decimals
0u8,
// int<2> - unused -
0u8, 0u8,
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// int<3> length
60u8, 0u8, 0u8,
// int<1> seq_no
5u8,
// string<lenenc> catalog (always 'def')
3u8, b"def",
// string<lenenc> schema
4u8, b"test",
// string<lenenc> table alias
5u8, b"users",
// string<lenenc> table
5u8, b"users",
// string<lenenc> column alias
0x0C_u8, b"access_level",
// string<lenenc> column
0x0C_u8, b"access_level",
// int<lenenc> length of fixed fields (=0xC)
0x0C_u8,
// int<2> character set number
8u8, 0u8,
// int<4> max. column size
7u8, 0u8, 0u8, 0u8,
// int<1> Field types
0xFE_u8,
// int<2> Field detail flag
1u8, 0x11_u8,
// int<1> decimals
0u8,
// int<2> - unused -
0u8, 0u8,
// ---------- //
// EOF Packet //
// ---------- //
// int<3> length
5u8, 0u8, 0u8,
// int<1> seq_no
6u8,
// int<1> 0xfe : EOF header
0xFE_u8,
// int<2> warning count
0u8, 0u8,
// int<2> server status
34u8, 0u8,
// ----------------- //
// Result Row Packet //
// ----------------- //
// int<3> length
62u8, 0u8, 0u8,
// int<1> seq_no
7u8,
// string<lenenc> column data
36u8, b"044d3f34-af65-11e9-a2e5-0242ac110003",
// string<lenenc> column data
4u8, b"josh",
// string<lenenc> column data
0x0B_u8, b"password123",
// string<lenenc> column data
7u8, b"regular",
// ----------------- //
// Result Row Packet //
// ----------------- //
// int<3> length
52u8, 0u8, 0u8,
// int<1> seq_no
8u8,
// string<lenenc> column data
36u8, b"d83dd1c4-ada9-11e9-96bc-0242ac110003",
// string<lenenc> column data
6u8, b"daniel",
// string<lenenc> column data
1u8, b"f",
// string<lenenc> column data
5u8, b"admin",
// ------------- //
// OK/EOF Packet //
// ------------- //
// int<3> length
5u8, 0u8, 0u8,
// int<1> seq_no
1u8,
// 0xFE: Required header for last packet of result set
0xFE_u8,
// int<2> warning count
0u8, 0u8,
// int<2> server status
34u8, 0u8
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, buf);
ResultSet::deserialize(ctx, ProtocolType::Text).await?;
Ok(())
}
#[tokio::test]
async fn it_decodes_result_set_binary_packet() -> Result<(), Error> {
// TODO: Use byte string as input for test; this is a valid return from a mariadb.
#[rustfmt::skip]
let buf: Bytes = __bytes_builder!(
// ------------------- //
// Column Count packet //
// ------------------- //
// int<3> length
1u8, 0u8, 0u8,
// int<1> seq_no
1u8,
// int<lenenc> tag code or length
1u8,
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// int<3> length
40u8, 0u8, 0u8,
// int<1> seq_no
5u8,
// string<lenenc> catalog (always 'def')
3u8, b"def",
// string<lenenc> schema
4u8, b"test",
// string<lenenc> table alias
5u8, b"users",
// string<lenenc> table
5u8, b"users",
// string<lenenc> column alias
2u8, b"id",
// string<lenenc> column
2u8, b"id",
// int<lenenc> length of fixed fields (=0xC)
0x0C_u8,
// int<2> character set number
8u8, 0u8,
// int<4> max. column size
0x80u8, 0u8, 0u8, 0u8,
// int<1> Field types
0xFD_u8,
// int<2> Field detail flag
3u8, 64u8,
// int<1> decimals
0u8,
// int<2> - unused -
0u8, 0u8,
// ---------- //
// EOF Packet //
// ---------- //
// int<3> length
5u8, 0u8, 0u8,
// int<1> seq_no
3u8,
// int<1> 0xfe : EOF header
0xFE_u8,
// int<2> warning count
0u8, 0u8,
// int<2> server status
34u8, 0u8,
// ----------------- //
// Result Row Packet //
// ----------------- //
// int<3> length
39u8, 0u8, 0u8,
// int<1> seq_no
4u8,
// byte<1> 0x00 header
0u8,
// byte<(number_of_columns + 9) / 8> NULL-Bitmap
0u8,
// byte<lenenc> encoded result
36u8, b"044d3f34-af65-11e9-a2e5-0242ac110003",
// ---------- //
// EOF Packet //
// ---------- //
// int<3> length
5u8, 0u8, 0u8,
// int<1> seq_no
5u8,
// int<1> 0xfe : EOF header
0xFE_u8,
// int<2> warning count
0u8, 0u8,
// int<2> server status
34u8, 0u8
);
let mut context = ConnContext::new();
let mut ctx = DeContext::new(&mut context, buf);
ResultSet::deserialize(ctx, ProtocolType::Binary).await?;
Ok(())
}
}

View file

@ -1,26 +0,0 @@
use crate::{io::BufMut, mariadb::Encode};
pub struct ComDebug();
impl Encode for ComDebug {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_u8(super::TextProtocol::ComDebug as u8);
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io;
#[test]
fn it_encodes_com_debug() -> io::Result<()> {
let mut buf = Vec::with_capacity(1024);
ComDebug().encode(&mut buf);
assert_eq!(&buf[..], b"\x01\0\0\x00\x0D");
Ok(())
}
}

View file

@ -1,32 +0,0 @@
use crate::{io::BufMut, mariadb::Encode};
pub struct ComInitDb<'a> {
pub schema_name: &'a str,
}
impl<'a> Encode for ComInitDb<'a> {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_u8(super::TextProtocol::ComInitDb as u8);
buf.put_str_nul(self.schema_name);
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io;
#[test]
fn it_encodes_com_init_db() -> io::Result<()> {
let mut buf = Vec::with_capacity(1024);
ComInitDb {
schema_name: "portal",
}
.encode(&mut buf);
assert_eq!(&buf[..], b"\x08\0\0\x00\x02portal\0");
Ok(())
}
}

View file

@ -1,26 +0,0 @@
use crate::{io::BufMut, mariadb::Encode};
pub struct ComPing();
impl Encode for ComPing {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_u8(super::TextProtocol::ComPing.into());
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io;
#[test]
fn it_encodes_com_ping() -> io::Result<()> {
let mut buf = Vec::with_capacity(1024);
ComPing().encode(&mut buf);
assert_eq!(&buf[..], b"\x01\0\0\x00\x0E");
Ok(())
}
}

View file

@ -1,30 +0,0 @@
use crate::{io::BufMut, mariadb::Encode};
use byteorder::LittleEndian;
pub struct ComProcessKill {
pub process_id: u32,
}
impl Encode for ComProcessKill {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_u8(super::TextProtocol::ComProcessKill.into());
buf.put_u32::<LittleEndian>(self.process_id);
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io;
#[test]
fn it_encodes_com_process_kill() -> io::Result<()> {
let mut buf = Vec::with_capacity(1024);
ComProcessKill { process_id: 1 }.encode(&mut buf);
assert_eq!(&buf[..], b"\x05\0\0\x00\x0C\x01\0\0\0");
Ok(())
}
}

View file

@ -1,25 +0,0 @@
use crate::{io::BufMut, mariadb::Encode};
pub struct ComQuit();
impl Encode for ComQuit {
fn encode(&self, buf: &mut Vec<u8>) {
buf.put_u8(super::TextProtocol::ComQuit as u8);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn it_encodes_com_quit() -> std::io::Result<()> {
let mut buf = Vec::with_capacity(1024);
ComQuit().encode(&mut buf);
assert_eq!(&buf[..], b"\x01\0\0\x00\x01");
Ok(())
}
}

View file

@ -0,0 +1,25 @@
use crate::{io::BufMut, mariadb::protocol::{Capabilities, Encode}};
use super::TextProtocol;
#[derive(Debug)]
pub struct ComDebug;
impl Encode for ComDebug {
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
// COM_DEBUG Header (0xOD) : int<1>
buf.put_u8(TextProtocol::ComDebug as u8);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn it_encodes_com_debug() {
let mut buf = Vec::new();
ComDebug.encode(&mut buf, Capabilities::empty());
assert_eq!(&buf[..], b"\x0D");
}
}

View file

@ -0,0 +1,33 @@
use crate::{io::BufMut, mariadb::protocol::{Encode, Capabilities}};
use super::TextProtocol;
pub struct ComInitDb<'a> {
pub schema_name: &'a str,
}
impl Encode for ComInitDb<'_> {
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
// COM_INIT_DB Header : int<1>
buf.put_u8(TextProtocol::ComInitDb as u8);
// schema name : string<NUL>
buf.put_str_nul(self.schema_name);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn it_encodes_com_init_db() {
let mut buf = Vec::new();
ComInitDb {
schema_name: "portal",
}
.encode(&mut buf, Capabilities::empty());
assert_eq!(&buf[..], b"\x02portal\0");
}
}

View file

@ -0,0 +1,25 @@
use crate::{io::BufMut, mariadb::{protocol::{Encode, Capabilities}}};
use super::TextProtocol;
#[derive(Debug)]
pub struct ComPing;
impl Encode for ComPing {
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
// COM_PING Header : int<1>
buf.put_u8(TextProtocol::ComPing as u8);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn it_encodes_com_ping() {
let mut buf = Vec::new();
ComPing.encode(&mut buf, Capabilities::empty());
assert_eq!(&buf[..], b"\x0E");
}
}

View file

@ -0,0 +1,32 @@
use crate::{io::BufMut, mariadb::protocol::{Encode, Capabilities}};
use super::TextProtocol;
use byteorder::LittleEndian;
/// Forces the server to terminate a specified connection.
pub struct ComProcessKill {
pub process_id: u32,
}
impl Encode for ComProcessKill {
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
// COM_PROCESS_KILL : int<1>
buf.put_u8(TextProtocol::ComProcessKill as u8);
// process id : int<4>
buf.put_u32::<LittleEndian>(self.process_id);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn it_encodes_com_process_kill() {
let mut buf = Vec::new();
ComProcessKill { process_id: 1 }.encode(&mut buf, Capabilities::empty());
assert_eq!(&buf[..], b"\x0C\x01\0\0\0");
}
}

View file

@ -1,14 +1,15 @@
use crate::{
io::BufMut,
mariadb::{BufMutExt, Encode},
mariadb::{io::{BufMutExt}, protocol::{Encode, Capabilities}},
};
/// Sends the server an SQL statement to be executed immediately.
pub struct ComQuery<'a> {
pub sql_statement: &'a str,
}
impl<'a> Encode for ComQuery<'a> {
fn encode(&self, buf: &mut Vec<u8>) {
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.put_u8(super::TextProtocol::ComQuery as u8);
buf.put_str(&self.sql_statement);
}
@ -17,19 +18,16 @@ impl<'a> Encode for ComQuery<'a> {
#[cfg(test)]
mod tests {
use super::*;
use std::io;
#[test]
fn it_encodes_com_query() -> io::Result<()> {
let mut buf = Vec::with_capacity(1024);
fn it_encodes_com_query() {
let mut buf = Vec::new();
ComQuery {
sql_statement: "SELECT * FROM users",
}
.encode(&mut buf);
.encode(&mut buf, Capabilities::empty());
assert_eq!(&buf[..], b"\x14\0\0\x00\x03SELECT * FROM users");
Ok(())
assert_eq!(&buf, b"\x03SELECT * FROM users");
}
}

View file

@ -0,0 +1,26 @@
use crate::{io::BufMut, mariadb::protocol::{Encode, Capabilities}};
use super::TextProtocol;
pub struct ComQuit;
impl Encode for ComQuit {
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.put_u8(TextProtocol::ComQuit as u8);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn it_encodes_com_quit() -> std::io::Result<()> {
let mut buf = Vec::new();
ComQuit.encode(&mut buf, Capabilities::empty());
assert_eq!(&buf[..], b"\x01");
Ok(())
}
}

View file

@ -1,15 +1,15 @@
pub mod com_debug;
pub mod com_init_db;
pub mod com_ping;
pub mod com_process_kill;
pub mod com_query;
pub mod com_quit;
pub mod com_reset_conn;
pub mod com_set_option;
pub mod com_shutdown;
pub mod com_sleep;
pub mod com_statistics;
pub mod result_row;
mod com_debug;
mod com_init_db;
mod com_ping;
mod com_process_kill;
mod com_query;
mod com_quit;
// mod com_reset_conn;
// mod com_set_option;
// mod com_shutdown;
// mod com_sleep;
// mod com_statistics;
// mod result_row;
pub use com_debug::ComDebug;
pub use com_init_db::ComInitDb;
@ -17,19 +17,19 @@ pub use com_ping::ComPing;
pub use com_process_kill::ComProcessKill;
pub use com_query::ComQuery;
pub use com_quit::ComQuit;
pub use com_reset_conn::ComResetConnection;
pub use com_set_option::{ComSetOption, SetOptionOptions};
pub use com_shutdown::{ComShutdown, ShutdownOptions};
pub use com_sleep::ComSleep;
pub use com_statistics::ComStatistics;
pub use result_row::ResultRow;
// pub use com_reset_conn::ComResetConnection;
// pub use com_set_option::{ComSetOption, SetOptionOptions};
// pub use com_shutdown::{ComShutdown, ShutdownOptions};
// pub use com_sleep::ComSleep;
// pub use com_statistics::ComStatistics;
// pub use result_row::ResultRow;
// This is an enum of text protocol packet tags.
// Tags are the 5th byte of the packet (1st byte of packet body)
// and are used to determine which type of query was sent.
// The name of the enum variant represents the type of query, and
// the value is the byte value required by the server.
pub enum TextProtocol {
enum TextProtocol {
ComChangeUser = 0x11,
ComDebug = 0x0D,
ComInitDb = 0x02,
@ -43,10 +43,3 @@ pub enum TextProtocol {
ComSleep = 0x00,
ComStatistics = 0x09,
}
// Helper method to easily transform into u8
impl Into<u8> for TextProtocol {
fn into(self) -> u8 {
self as u8
}
}