Add length and sequence number to packets

This commit is contained in:
Daniel Akhterov 2019-06-14 14:33:57 -07:00
parent 12913139da
commit 902970ddca
2 changed files with 41 additions and 5 deletions

View file

@ -10,6 +10,7 @@ pub trait Serialize {
#[derive(Default, Debug)]
pub struct SSLRequestPacket {
pub sequence_number: u8,
pub capabilities: Capabilities,
pub max_packet_size: u32,
pub collation: u8,
@ -18,13 +19,24 @@ pub struct SSLRequestPacket {
impl Serialize for SSLRequestPacket {
fn serialize(&self, buf: &mut Vec<u8>) {
// FIXME: Prepend length of packet in standard packet form
// https://mariadb.com/kb/en/library/0-packet
// buf.push(32);
// Temporary storage for length: 3 bytes
buf.push(0);
buf.push(0);
buf.push(0);
// Sequence Numer
buf.push(0);
LittleEndian::write_u32(buf, self.capabilities.bits() as u32);
LittleEndian::write_u32(buf, self.max_packet_size);
buf.push(self.collation);
buf.extend_from_slice(&[0u8;19]);
if !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() {
if let Some(capabilities) = self.extended_capabilities {
LittleEndian::write_u32(buf, capabilities.bits() as u32);
@ -32,5 +44,11 @@ impl Serialize for SSLRequestPacket {
} else {
buf.extend_from_slice(&[0u8;4]);
}
// Get length in little endian bytes
// packet length = byte[0] + (byte[1]<<8) + (byte[2]<<16)
buf[0] = buf.len().to_le_bytes()[0];
buf[1] = buf.len().to_le_bytes()[1];
buf[2] = buf.len().to_le_bytes()[2];
}
}

View file

@ -1,8 +1,7 @@
// Reference: https://mariadb.com/kb/en/library/connection
use byteorder::{ByteOrder, LittleEndian};
use failure::Error;
use std::iter::FromIterator;
use failure::{Error, err_msg};
use bytes::Bytes;
pub trait Deserialize: Sized {
@ -51,6 +50,8 @@ impl Default for Capabilities {
#[derive(Default, Debug)]
pub struct InitialHandshakePacket {
pub length: u32,
pub sequence_number: u8,
pub protocol_version: u8,
pub server_version: Bytes,
pub connection_id: u32,
@ -66,7 +67,22 @@ pub struct InitialHandshakePacket {
impl Deserialize for InitialHandshakePacket {
fn deserialize(buf: &mut Vec<u8>) -> Result<Self, Error> {
let mut index = 0;
let protocol_version = buf[0] as u8;
let length = (buf[0] + (buf[1]<<8) + (buf[2]<<16)) as u32;
index += 3;
if buf.len() != length as usize {
return Err(err_msg("Lengths to do not match"));
}
let sequence_number = buf[index];
index += 1;
if sequence_number != 0 {
return Err(err_msg("Squence Number of Initial Handshake Packet is not 0"));
}
let protocol_version = buf[index] as u8;
index += 1;
let null_index = memchr::memchr(b'\0', &buf[index..]).unwrap();
@ -119,6 +135,8 @@ impl Deserialize for InitialHandshakePacket {
}
Ok(InitialHandshakePacket {
length,
sequence_number,
protocol_version,
server_version,
connection_id,