Got a result set!!!

This commit is contained in:
Daniel Akhterov 2019-07-24 20:44:30 -07:00
parent 6e282ee33b
commit 70bca63fcd
9 changed files with 238 additions and 101 deletions

View file

@ -1,7 +1,7 @@
use crate::protocol::{
deserialize::{DeContext, Deserialize},
encode::Encoder,
packets::{com_ping::ComPing, com_query::ComQuery, com_quit::ComQuit, ok::OkPacket},
packets::{com_ping::ComPing, com_query::ComQuery, com_quit::ComQuit, com_init_db::ComInitDb, ok::OkPacket},
serialize::Serialize,
server::Message as ServerMessage,
types::{Capabilities, ServerStatusFlag},
@ -19,7 +19,7 @@ use runtime::net::TcpStream;
mod establish;
pub struct Connection {
stream: Framed,
pub stream: Framed,
// Buffer used when serializing outgoing messages
pub encoder: Encoder,
@ -47,7 +47,7 @@ pub struct ConnContext {
}
impl Connection {
pub(crate) async fn establish(options: ConnectOptions<'static>) -> Result<Self, Error> {
pub async fn establish(options: ConnectOptions<'static>) -> Result<Self, Error> {
let stream: Framed = Framed::new(TcpStream::connect((options.host, options.port)).await?);
let mut conn: Connection = Self {
stream,
@ -66,7 +66,7 @@ impl Connection {
Ok(conn)
}
pub(crate) async fn send<S>(&mut self, message: S) -> Result<(), Error>
pub async fn send<S>(&mut self, message: S) -> Result<(), Error>
where
S: Serialize,
{
@ -82,19 +82,29 @@ impl Connection {
Ok(())
}
pub(crate) async fn quit(&mut self) -> Result<(), Error> {
pub async fn quit(&mut self) -> Result<(), Error> {
self.context.seq_no = 0;
self.send(ComQuit()).await?;
Ok(())
}
pub(crate) async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result<(), Error> {
pub async fn query<'a>(&'a mut self, sql_statement: &'a str) -> Result<(), Error> {
self.context.seq_no = 0;
self.send(ComQuery { sql_statement: bytes::Bytes::from(sql_statement) }).await?;
Ok(())
}
pub(crate) async fn ping(&mut self) -> Result<(), Error> {
pub async fn select_db<'a>(&'a mut self, db: &'a str) -> Result<(), Error> {
self.context.seq_no = 0;
self.send(ComInitDb { schema_name: bytes::Bytes::from(db) }).await?;
Ok(())
}
pub async fn ping(&mut self) -> Result<(), Error> {
self.context.seq_no = 0;
self.send(ComPing()).await?;
@ -105,7 +115,7 @@ impl Connection {
Ok(())
}
pub(crate) async fn next(&mut self) -> Result<Option<ServerMessage>, Error> {
pub async fn next(&mut self) -> Result<Option<ServerMessage>, Error> {
let mut rbuf = BytesMut::new();
let mut len = 0;
@ -151,7 +161,7 @@ impl Connection {
}
}
struct Framed {
pub struct Framed {
inner: TcpStream,
readable: bool,
eof: bool,
@ -168,14 +178,14 @@ impl Framed {
}
}
async fn next_bytes(&mut self) -> Result<Bytes, Error> {
pub async fn next_bytes(&mut self) -> Result<Bytes, Error> {
let mut rbuf = BytesMut::new();
let mut len = 0;
let mut packet_len: u32 = 0;
loop {
if len == rbuf.len() {
rbuf.reserve(32);
rbuf.reserve(20000);
unsafe {
// Set length to the capacity and efficiently

View file

@ -51,6 +51,12 @@ impl<'a> Decoder<'a> {
pub fn eof(&self) -> bool {
self.buf.len() == self.index
}
#[inline]
pub fn eof_byte(&self) -> bool {
self.buf[self.index] == 0xFE
}
#[inline]
pub fn decode_int_lenenc(&mut self) -> Option<usize> {
match self.buf[self.index] {
@ -119,7 +125,7 @@ impl<'a> Decoder<'a> {
#[inline]
pub fn decode_string_lenenc(&mut self) -> Bytes {
let length = self.decode_int_3();
let length = self.decode_int_1();
let value = Bytes::from(&self.buf[self.index..self.index + length as usize]);
self.index = self.index + length as usize;
value
@ -294,12 +300,12 @@ mod tests {
#[test]
fn it_decodes_string_lenenc() {
let buf = Bytes::from(b"\x01\x00\x00\x01".to_vec());
let buf = Bytes::from(b"\x03sup".to_vec());
let mut decoder = Decoder::new(&buf);
let string: Bytes = decoder.decode_string_lenenc();
assert_eq!(string[0], b'\x01');
assert_eq!(string.len(), 1);
assert_eq!(string[..], b"sup"[..]);
assert_eq!(string.len(), 3);
assert_eq!(decoder.index, 4);
}

View file

@ -6,11 +6,12 @@ use failure::Error;
pub struct DeContext<'a> {
pub conn: &'a mut ConnContext,
pub decoder: Decoder<'a>,
pub columns: Option<usize>,
}
impl<'a> DeContext<'a> {
pub fn new(conn: &'a mut ConnContext, buf: &'a Bytes) -> Self {
DeContext { conn, decoder: Decoder::new(&buf) }
DeContext { conn, decoder: Decoder::new(&buf), columns: None }
}
}

View file

@ -1,7 +1,7 @@
use super::super::deserialize::{DeContext, Deserialize};
use failure::Error;
#[derive(Default, Debug)]
#[derive(Default, Debug, Clone, Copy)]
// ColumnPacket doesn't have a packet header because
// it's nested inside a result set packet
pub struct ColumnPacket {

View file

@ -27,23 +27,37 @@ pub struct ColumnDefPacket {
impl Deserialize for ColumnDefPacket {
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
// string<lenenc> catalog (always 'def')
let catalog = decoder.decode_string_lenenc();
// string<lenenc> schema
let schema = decoder.decode_string_lenenc();
// string<lenenc> table alias
let table_alias = decoder.decode_string_lenenc();
// string<lenenc> table
let table = decoder.decode_string_lenenc();
// string<lenenc> column alias
let column_alias = decoder.decode_string_lenenc();
// string<lenenc> column
let column = decoder.decode_string_lenenc();
// int<lenenc> length of fixed fields (=0xC)
let length_of_fixed_fields = decoder.decode_int_lenenc();
// int<2> character set number
let char_set = decoder.decode_int_2();
// int<4> max. column size
let max_columns = decoder.decode_int_4();
// int<1> Field types
let field_type = FieldType::try_from(decoder.decode_int_1())?;
// int<2> Field detail flag
let field_details = FieldDetailFlag::from_bits_truncate(decoder.decode_int_2());
// int<1> decimals
let decimals = decoder.decode_int_1();
// Skip last two unused bytes
// int<2> - unused -
decoder.skip_bytes(2);
Ok(ColumnDefPacket {
catalog,
schema,
@ -81,18 +95,22 @@ mod test {
#[rustfmt::skip]
let buf = __bytes_builder!(
// length
1u8, 0u8, 0u8,
// seq_no
0u8,
// string<lenenc> catalog (always 'def')
1u8, 0u8, 0u8, b'a',
1u8, b'a',
// string<lenenc> schema
1u8, 0u8, 0u8, b'b',
1u8, b'b',
// string<lenenc> table alias
1u8, 0u8, 0u8, b'c',
1u8, b'c',
// string<lenenc> table
1u8, 0u8, 0u8, b'd',
1u8, b'd',
// string<lenenc> column alias
1u8, 0u8, 0u8, b'e',
1u8, b'e',
// string<lenenc> column
1u8, 0u8, 0u8, b'f',
1u8, b'f',
// int<lenenc> length of fixed fields (=0xC)
0xFC_u8, 1u8, 1u8,
// int<2> character set number

View file

@ -25,9 +25,9 @@ impl Deserialize for EofPacket {
let packet_header = decoder.decode_int_1();
if packet_header != 0xFE {
panic!("Packet header is not 0xFE for ErrPacket");
}
// if packet_header != 0xFE {
// panic!("Packet header is not 0xFE for ErrPacket");
// }
let warning_count = decoder.decode_int_2();
let status = ServerStatusFlag::from_bits_truncate(decoder.decode_int_2());

View file

@ -20,3 +20,4 @@ pub mod ok;
pub mod packet_header;
pub mod result_set;
pub mod ssl_request;
pub mod result_row;

View file

@ -0,0 +1,68 @@
use super::super::{
decode::Decoder,
deserialize::{DeContext, Deserialize},
error_codes::ErrorCode,
types::ServerStatusFlag,
};
use bytes::Bytes;
use failure::Error;
use std::convert::TryFrom;
#[derive(Default, Debug)]
pub struct ResultRow {
pub length: u32,
pub seq_no: u8,
pub row: Vec<Bytes>,
}
impl Deserialize for ResultRow {
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
let row = if let Some(columns) = ctx.columns {
(0..columns).map(|_| decoder.decode_string_lenenc()).collect::<Vec<Bytes>>()
} else {
Vec::new()
};
Ok(ResultRow { length, seq_no, row })
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::{__bytes_builder, connection::Connection};
use bytes::Bytes;
use mason_core::ConnectOptions;
#[runtime::test]
async fn it_decodes_result_row_packet() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
#[rustfmt::skip]
let buf = __bytes_builder!(
// int<3> length
1u8, 0u8, 0u8,
// int<1> seq_no
1u8,
// string<lenenc> column data
1u8, b"s"
);
let mut ctx = DeContext::new(&mut conn.context, &buf);
ctx.columns = Some(1);
let _message = ResultRow::deserialize(&mut ctx)?;
Ok(())
}
}

View file

@ -1,9 +1,12 @@
use super::super::{
deserialize::{DeContext, Deserialize},
packets::{column::ColumnPacket, column_def::ColumnDefPacket},
packets::{ok::OkPacket, err::ErrPacket, eof::EofPacket, column::ColumnPacket, column_def::ColumnDefPacket, result_row::ResultRow},
};
use bytes::Bytes;
use failure::Error;
use crate::protocol::server::Message;
use crate::protocol::types::Capabilities;
use crate::protocol::decode::Decoder;
#[derive(Debug, Default)]
pub struct ResultSet {
@ -11,7 +14,7 @@ pub struct ResultSet {
pub seq_no: u8,
pub column_packet: ColumnPacket,
pub columns: Vec<ColumnDefPacket>,
pub rows: Vec<Vec<Bytes>>,
pub rows: Vec<ResultRow>,
}
impl Deserialize for ResultSet {
@ -21,18 +24,6 @@ impl Deserialize for ResultSet {
let column_packet = ColumnPacket::deserialize(ctx)?;
match ctx.decoder.decode_int_1() {
// 0x00 -> PACKET_OK
0x00 => {}
// 0xFF -> PACKET_ERR
0xFF => {}
_ => {
panic!("Didn't receive 0x00 nor 0xFF");
}
}
let columns = if let Some(columns) = column_packet.columns {
(0..columns)
.map(|_| ColumnDefPacket::deserialize(ctx))
@ -43,22 +34,43 @@ impl Deserialize for ResultSet {
Vec::new()
};
let mut rows = Vec::new();
println!("length: {:?}", length);
println!("seq_no: {:?}", seq_no);
loop {
// if end of buffer stop
if ctx.decoder.eof() {
break;
}
let eof_packet = if !(ctx.conn.capabilities & Capabilities::CLIENT_DEPRECATE_EOF).is_empty() {
Some(EofPacket::deserialize(ctx)?)
} else {
None
};
let columns = if let Some(columns) = column_packet.columns {
(0..columns).map(|_| ctx.decoder.decode_string_lenenc()).collect::<Vec<Bytes>>()
} else {
Vec::new()
};
println!("column_packet: {:?}", column_packet);
for col in &columns {
println!("col: {:?}", col);
}
Ok(ResultSet { length, seq_no, column_packet, columns, rows })
println!("eof_packet: {:?}", eof_packet);
ctx.columns = column_packet.columns.clone();
// TODO: Deserialize all rows
let rows = vec![ResultRow::deserialize(ctx)?];
if (ctx.conn.capabilities & Capabilities::CLIENT_DEPRECATE_EOF).is_empty() {
println!("eof_packet: {:?}", EofPacket::deserialize(ctx)?);
} else {
println!("ok_packet: {:?}", OkPacket::deserialize(ctx)?);
}
println!("rows: {:?}", rows);
Ok(ResultSet {
length,
seq_no,
column_packet,
columns,
rows
})
}
}
@ -67,6 +79,7 @@ mod test {
use super::*;
use crate::{__bytes_builder, connection::Connection};
use bytes::{BufMut, Bytes};
use crate::protocol::packets::{ok::OkPacket, err::ErrPacket, eof::EofPacket, result_row::ResultRow};
#[runtime::test]
async fn it_decodes_result_set_packet() -> Result<(), Error> {
@ -79,7 +92,27 @@ mod test {
})
.await?;
// conn.query("SELECT * FROM users");
conn.select_db("test").await?;
match conn.next().await? {
Some(Message::OkPacket(_)) => {},
Some(message @ Message::ErrPacket(_)) => {
failure::bail!("Received an ErrPacket packet: {:?}", message);
},
Some(message) => {
failure::bail!("Received an unexpected packet type: {:?}", message);
}
None => {
failure::bail!("Did not receive a packet when one was expected");
}
}
conn.query("SELECT * FROM users").await?;
let buf = conn.stream.next_bytes().await?;
let mut ctx = DeContext::new(&mut conn.context, &buf);
ResultSet::deserialize(&mut ctx)?;
#[rustfmt::skip]
let buf = __bytes_builder!(
@ -88,120 +121,120 @@ mod test {
// ------------------- //
// length
0x02_u8, 0x0_u8, 0x0_u8,
2u8, 0u8, 0u8,
// seq_no
0x02_u8,
2u8,
// int<lenenc> Column count packet
0x02_u8, 0x00_u8,
2u8, 0u8,
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// length
0x02_u8, 0x0_u8, 0x0_u8,
2u8, 0u8, 0u8,
// seq_no
0x02_u8,
2u8,
// string<lenenc> catalog (always 'def')
0x03_u8, 0x0_u8, 0x0_u8, b"def",
3u8, b"def",
// string<lenenc> schema
0x01_u8, 0x0_u8, 0x0_u8, b'b',
1u8, b'b',
// string<lenenc> table alias
0x01_u8, 0x0_u8, 0x0_u8, b'c',
1u8, b'c',
// string<lenenc> table
0x01_u8, 0x0_u8, 0x0_u8, b'd',
1u8, b'd',
// string<lenenc> column alias
0x01_u8, 0x0_u8, 0x0_u8, b'e',
1u8, b'e',
// string<lenenc> column
0x01_u8, 0x0_u8, 0x0_u8, b'f',
1u8, b'f',
// int<lenenc> length of fixed fields (=0xC)
0xfc_u8, 0x01_u8, 0x01_u8,
0xFC_u8, 1u8, 1u8,
// int<2> character set number
0x01_u8, 0x01_u8,
1u8, 1u8,
// int<4> max. column size
0x01_u8, 0x01_u8, 0x01_u8, 0x01_u8,
1u8, 1u8, 1u8, 1u8,
// int<1> Field types
0x00_u8,
0u8,
// int<2> Field detail flag
0x00_u8, 0x00_u8,
0u8, 0u8,
// int<1> decimals
0x01_u8,
1u8,
// int<2> - unused -
0x0_u8, 0x0_u8,
0u8, 0u8,
// ------------------------ //
// Column Definition packet //
// ------------------------ //
// length
0x02_u8, 0x0_u8, 0x0_u8,
2u8, 0u8, 0u8,
// seq_no
0x02_u8,
2u8,
// string<lenenc> catalog (always 'def')
0x03_u8, 0x0_u8, 0x0_u8, b"def",
3u8, b"def",
// string<lenenc> schema
0x01_u8, 0x0_u8, 0x0_u8, b'b',
1u8, b'b',
// string<lenenc> table alias
0x01_u8, 0x0_u8, 0x0_u8, b'c',
1u8, b'c',
// string<lenenc> table
0x01_u8, 0x0_u8, 0x0_u8, b'd',
1u8, b'd',
// string<lenenc> column alias
0x01_u8, 0x0_u8, 0x0_u8, b'e',
1u8, b'e',
// string<lenenc> column
0x01_u8, 0x0_u8, 0x0_u8, b'f',
1u8, b'f',
// int<lenenc> length of fixed fields (=0xC)
0xfc_u8, 0x01_u8, 0x01_u8,
0xFC_u8, 1u8, 1u8,
// int<2> character set number
0x01_u8, 0x01_u8,
1u8, 1u8,
// int<4> max. column size
0x01_u8, 0x01_u8, 0x01_u8, 0x01_u8,
1u8, 1u8, 1u8, 1u8,
// int<1> Field types
0x00_u8,
// int<2> Field detail flag
0x00_u8, 0x00_u8,
0u8,
// int<2> Field detail flag
0u8, 0u8,
// int<1> decimals
0x01_u8,
1u8,
// int<2> - unused -
0x0_u8, 0x00_u8,
0u8, 0u8,
// ---------- //
// EOF Packet //
// ---------- //
// length
0x02_u8, 0x0_u8, 0x0_u8,
1u8, 0u8, 0u8,
// seq_no
0x02_u8,
1u8,
// int<1> 0xfe : EOF header
0xfe_u8,
0xFE_u8,
// int<2> warning count
0x0_u8, 0x0_u8,
0u8, 0u8,
// int<2> server status
0x01_u8, 0x00_u8,
1u8, 0u8,
// ------------------- //
// N Result Row Packet //
// ------------------- //
// string<lenenc> column data
0x01_u8, 0x0_u8, 0x0_u8, b'h',
1u8, b'h',
// string<lenenc> column data
0x01_u8, 0x0_u8, 0x0_u8, b'i',
1u8, b'i',
// ---------- //
// EOF Packet //
// ---------- //
// length
0x02_u8, 0x0_u8, 0x0_u8,
1u8, 0u8, 0u8,
// seq_no
0x02_u8,
1u8,
// int<1> 0xfe : EOF header
0xfe_u8,
0xFE_u8,
// int<2> warning count
0x0_u8, 0x0_u8,
0u8, 0u8,
// int<2> server status
0x01_u8, 0x00_u8
1u8, 0u8
);
Ok(())