WIP: DeContext

This commit is contained in:
Daniel Akhterov 2019-07-16 20:36:11 -07:00
parent 3d5590c6c9
commit 4cfb1d46a1
12 changed files with 104 additions and 61 deletions

View file

@ -1,6 +1,6 @@
use super::{Connection, Decoder};
use super::{Connection};
use crate::protocol::{
deserialize::Deserialize,
deserialize::{Deserialize, DeContext},
packets::{handshake_response::HandshakeResponsePacket, initial::InitialHandshakePacket},
server::Message as ServerMessage,
types::Capabilities,
@ -14,10 +14,8 @@ pub async fn establish<'a, 'b: 'a>(
options: ConnectOptions<'b>,
) -> Result<(), Error> {
let buf = &conn.stream.next_bytes().await?;
let init_packet =
InitialHandshakePacket::deserialize(conn, &mut Decoder::new(&buf))?;
conn.capabilities = init_packet.capabilities;
let mut de_ctx = DeContext::new(conn, &buf);
let _ = InitialHandshakePacket::deserialize(&mut de_ctx)?;
let handshake: HandshakeResponsePacket = HandshakeResponsePacket {
// Minimum client capabilities required to establish connection

View file

@ -1,6 +1,5 @@
use super::protocol::decode::Decoder;
use crate::protocol::{
deserialize::Deserialize,
deserialize::{Deserialize, DeContext},
encode::Encoder,
packets::{com_ping::ComPing, com_quit::ComQuit, ok::OkPacket},
serialize::Serialize,
@ -31,6 +30,9 @@ pub struct Connection {
// Sequence Number
pub seq_no: u8,
// Last sequence number return by MariaDB
pub last_seq_no: u8,
// Server Capabilities
pub capabilities: Capabilities,
@ -46,6 +48,7 @@ impl Connection {
encoder: Encoder::new(1024),
connection_id: -1,
seq_no: 1,
last_seq_no: 0,
capabilities: Capabilities::default(),
status: ServerStatusFlag::default(),
};
@ -83,7 +86,7 @@ impl Connection {
// Ping response must be an OkPacket
let buf = self.stream.next_bytes().await?;
OkPacket::deserialize(self, &mut Decoder::new(&buf))?;
OkPacket::deserialize(&mut DeContext::new(self, &buf))?;
Ok(())
}
@ -115,7 +118,7 @@ impl Connection {
while len > 0 {
let size = rbuf.len();
let message = ServerMessage::deserialize(self, &mut Decoder::new(&rbuf.as_ref().into()))?;
let message = ServerMessage::deserialize(&mut DeContext::new(self, &rbuf.as_ref().into()))?;
len -= size - rbuf.len();
match message {

View file

@ -1,7 +1,22 @@
use super::decode::Decoder;
use failure::Error;
use crate::connection::Connection;
use bytes::Bytes;
pub struct DeContext<'a> {
pub conn: &'a mut Connection,
pub decoder: Decoder<'a>,
}
impl<'a> DeContext<'a> {
pub fn new(conn: &'a mut Connection, buf: &'a Bytes) -> Self {
DeContext {
conn,
decoder: Decoder::new(&buf),
}
}
}
pub trait Deserialize: Sized {
fn deserialize(conn: &mut Connection, decoder: &mut Decoder) -> Result<Self, Error>;
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error>;
}

View file

@ -47,50 +47,49 @@ impl Encoder {
#[inline]
pub fn encode_int_8(&mut self, value: u64) {
self.buf.put_u64_le(value);
self.buf.extend_from_slice(&value.to_le_bytes());
}
#[inline]
pub fn encode_int_4(&mut self, value: u32) {
self.buf.put_u32_le(value);
self.buf.extend_from_slice(&value.to_le_bytes());
}
#[inline]
pub fn encode_int_3(&mut self, value: u32) {
let length = value.to_le_bytes();
self.buf.extend_from_slice(&length[0..3]);
pub fn encode_int_3(&mut self, value: u32) {
self.buf.extend_from_slice(&value.to_le_bytes()[0..3]);
}
#[inline]
pub fn encode_int_2(&mut self, value: u16) {
self.buf.put_u16_le(value);
self.buf.extend_from_slice(&value.to_le_bytes());
}
#[inline]
pub fn encode_int_1(&mut self, value: u8) {
self.buf.put_u8(value);
self.buf.extend_from_slice(&value.to_le_bytes());
}
#[inline]
pub fn encode_int_lenenc(&mut self, value: Option<&usize>) {
if let Some(value) = value {
if *value > U24_MAX && *value <= std::u64::MAX as usize {
self.buf.put_u8(0xFE);
self.buf.push(0xFE);
self.encode_int_8(*value as u64);
} else if *value > std::u16::MAX as usize && *value <= U24_MAX {
self.buf.put_u8(0xFD);
self.buf.push(0xFD);
self.encode_int_3(*value as u32);
} else if *value > std::u8::MAX as usize && *value <= std::u16::MAX as usize {
self.buf.put_u8(0xFC);
self.buf.push(0xFC);
self.encode_int_2(*value as u16);
} else if *value <= std::u8::MAX as usize {
self.buf.put_u8(0xFA);
self.buf.push(0xFA);
self.encode_int_1(*value as u8);
} else {
panic!("Value is too long");
}
} else {
self.buf.put_u8(0xFB);
self.buf.push(0xFB);
}
}

View file

@ -1,6 +1,5 @@
use super::super::{decode::Decoder, deserialize::Deserialize};
use super::super::deserialize::{Deserialize, DeContext};
use failure::Error;
use crate::connection::Connection;
#[derive(Default, Debug)]
pub struct ColumnPacket {
@ -10,7 +9,8 @@ pub struct ColumnPacket {
}
impl Deserialize for ColumnPacket {
fn deserialize(_conn: &mut Connection, decoder: &mut Decoder) -> Result<Self, Error> {
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
let columns = decoder.decode_int_lenenc();

View file

@ -1,10 +1,8 @@
use std::convert::TryFrom;
use bytes::Bytes;
use failure::Error;
use crate::connection::Connection;
use super::super::{
decode::Decoder,
deserialize::Deserialize,
deserialize::{Deserialize, DeContext},
types::{FieldDetailFlag, FieldType},
};
@ -27,7 +25,8 @@ pub struct ColumnDefPacket {
}
impl Deserialize for ColumnDefPacket {
fn deserialize(_conn: &mut Connection, decoder: &mut Decoder) -> Result<Self, Error> {
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();

View file

@ -1,8 +1,7 @@
use std::convert::TryFrom;
use bytes::Bytes;
use failure::Error;
use crate::connection::Connection;
use super::super::{decode::Decoder, deserialize::Deserialize, error_codes::ErrorCode};
use super::super::{deserialize::Deserialize, deserialize::DeContext, error_codes::ErrorCode};
#[derive(Default, Debug)]
pub struct ErrPacket {
@ -19,7 +18,8 @@ pub struct ErrPacket {
}
impl Deserialize for ErrPacket {
fn deserialize(_conn: &mut Connection, decoder: &mut Decoder) -> Result<Self, Error> {
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();

View file

@ -1,11 +1,9 @@
use super::super::{
decode::Decoder,
deserialize::Deserialize,
deserialize::{Deserialize, DeContext},
types::{Capabilities, ServerStatusFlag},
};
use bytes::Bytes;
use failure::{err_msg, Error};
use crate::connection::Connection;
#[derive(Default, Debug)]
pub struct InitialHandshakePacket {
@ -24,7 +22,8 @@ pub struct InitialHandshakePacket {
}
impl Deserialize for InitialHandshakePacket {
fn deserialize(_conn: &mut Connection, decoder: &mut Decoder) -> Result<Self, Error> {
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
@ -80,6 +79,9 @@ impl Deserialize for InitialHandshakePacket {
auth_plugin_name = Some(decoder.decode_string_null()?);
}
ctx.conn.capabilities = capabilities;
ctx.conn.last_seq_no = seq_no;
Ok(InitialHandshakePacket {
length,
seq_no,

View file

@ -1,7 +1,6 @@
use super::super::{decode::Decoder, deserialize::Deserialize, types::ServerStatusFlag};
use super::super::{deserialize::Deserialize, deserialize::DeContext, types::ServerStatusFlag};
use bytes::Bytes;
use failure::Error;
use crate::connection::Connection;
use failure::err_msg;
#[derive(Default, Debug)]
@ -18,16 +17,17 @@ pub struct OkPacket {
}
impl Deserialize for OkPacket {
fn deserialize(_conn: &mut Connection, decoder: &mut Decoder) -> Result<Self, Error> {
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let decoder = &mut ctx.decoder;
// Packet header
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
// Packet body
let packet_header = decoder.decode_int_1();
// if packet_header != 0 && packet_header != 0xFE {
// return Err(err_msg("Packet header is not 0 or 0xFE for OkPacket"));
// }
if packet_header != 0 && packet_header != 0xFE {
return Err(err_msg("Packet header is not 0 or 0xFE for OkPacket"));
}
let affected_rows = decoder.decode_int_lenenc();
let last_insert_id = decoder.decode_int_lenenc();

View file

@ -1,10 +1,7 @@
use bytes::Bytes;
use failure::Error;
use crate::connection::Connection;
use super::super::{
decode::Decoder,
deserialize::Deserialize,
deserialize::{Deserialize, DeContext},
packets::{column::ColumnPacket, column_def::ColumnDefPacket},
};
@ -18,15 +15,15 @@ pub struct ResultSet {
}
impl Deserialize for ResultSet {
fn deserialize(conn: &mut Connection, decoder: &mut Decoder) -> Result<Self, Error> {
let length = decoder.decode_length()?;
let seq_no = decoder.decode_int_1();
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
let length = ctx.decoder.decode_length()?;
let seq_no = ctx.decoder.decode_int_1();
let column_packet = ColumnPacket::deserialize(conn, decoder)?;
let column_packet = ColumnPacket::deserialize(ctx)?;
let columns = if let Some(columns) = column_packet.columns {
(0..columns)
.map(|_| ColumnDefPacket::deserialize(conn, decoder))
.map(|_| ColumnDefPacket::deserialize(ctx))
.filter(Result::is_ok)
.map(Result::unwrap)
.collect::<Vec<ColumnDefPacket>>()
@ -38,19 +35,25 @@ impl Deserialize for ResultSet {
for _ in 0.. {
// if end of buffer stop
if decoder.eof() {
if ctx.decoder.eof() {
break;
}
// Decode each column as string<lenenc>
rows.push(
(0..column_packet.columns.unwrap_or(0))
.map(|_| decoder.decode_string_lenenc())
.map(|_| ctx.decoder.decode_string_lenenc())
.collect::<Vec<Bytes>>(),
)
}
Ok(ResultSet { length, seq_no, column_packet, columns, rows })
Ok(ResultSet {
length,
seq_no,
column_packet,
columns,
rows,
})
}
}

View file

@ -1,12 +1,11 @@
// Reference: https://mariadb.com/kb/en/library/connection
use failure::Error;
use super::{
decode::Decoder,
deserialize::Deserialize,
deserialize::{DeContext, Deserialize},
packets::{err::ErrPacket, initial::InitialHandshakePacket, ok::OkPacket},
};
use crate::connection::Connection;
#[derive(Debug)]
#[non_exhaustive]
@ -17,7 +16,8 @@ pub enum Message {
}
impl Message {
pub fn deserialize(conn: &mut Connection, decoder: &mut Decoder) -> Result<Option<Self>, Error> {
pub fn deserialize(ctx: &mut DeContext) -> Result<Option<Self>, Error> {
let decoder = &mut ctx.decoder;
if decoder.buf.len() < 4 {
return Ok(None);
}
@ -30,8 +30,8 @@ impl Message {
let tag = decoder.buf[4];
Ok(Some(match tag {
0xFF => Message::ErrPacket(ErrPacket::deserialize(conn, decoder)?),
0x00 | 0xFE => Message::OkPacket(OkPacket::deserialize(conn, decoder)?),
0xFF => Message::ErrPacket(ErrPacket::deserialize(ctx)?),
0x00 | 0xFE => Message::OkPacket(OkPacket::deserialize(ctx)?),
_ => unimplemented!(),
}))
}

24
src/main.rs Normal file
View file

@ -0,0 +1,24 @@
#![feature(async_await)]
use mason::{pg::Connection, ConnectOptions};
#[runtime::main]
async fn main() -> Result<(), failure::Error> {
env_logger::try_init()?;
let mut conn =
Connection::establish(ConnectOptions::new().user("postgres").password("password")).await?;
conn.execute("INSERT INTO \"users\" (name) VALUES ($1)")
.bind(b"Joe")
.await?;
conn.prepare("INSERT INTO \"users\" (name) VALUES ($1)")
.bind(b"Joe")
.execute()
.await?;
conn.close().await?;
Ok(())
}