Use bitflags and memchr

This commit is contained in:
Daniel Akhterov 2019-06-14 13:57:46 -07:00
parent e56f364599
commit 12913139da
4 changed files with 73 additions and 98 deletions

View file

@ -16,3 +16,4 @@ log = "0.4"
hex = "0.3.2" hex = "0.3.2"
bytes = "0.4.12" bytes = "0.4.12"
memchr = "2.2.0" memchr = "2.2.0"
bitflags = "1.1.0"

View file

@ -1,5 +1,8 @@
#![feature(non_exhaustive, async_await)] #![feature(non_exhaustive, async_await)]
#![allow(clippy::needless_lifetimes)] #![allow(clippy::needless_lifetimes)]
#[macro_use]
extern crate bitflags;
// mod connection; // mod connection;
mod protocol; mod protocol;

View file

@ -10,10 +10,10 @@ pub trait Serialize {
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct SSLRequestPacket { pub struct SSLRequestPacket {
pub capabilities: u32, pub capabilities: Capabilities,
pub max_packet_size: u32, pub max_packet_size: u32,
pub collation: u8, pub collation: u8,
pub extended_capabilities: Option<u32>, pub extended_capabilities: Option<Capabilities>,
} }
impl Serialize for SSLRequestPacket { impl Serialize for SSLRequestPacket {
@ -21,13 +21,13 @@ impl Serialize for SSLRequestPacket {
// FIXME: Prepend length of packet in standard packet form // FIXME: Prepend length of packet in standard packet form
// https://mariadb.com/kb/en/library/0-packet // https://mariadb.com/kb/en/library/0-packet
// buf.push(32); // buf.push(32);
LittleEndian::write_u32(buf, self.capabilities); LittleEndian::write_u32(buf, self.capabilities.bits() as u32);
LittleEndian::write_u32(buf, self.max_packet_size); LittleEndian::write_u32(buf, self.max_packet_size);
buf.push(self.collation); buf.push(self.collation);
buf.extend_from_slice(&[0u8;19]); buf.extend_from_slice(&[0u8;19]);
if self.capabilities as u128 & Capabilities::ClientMysql as u128 > 0 { if !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() {
if let Some(capabilities) = self.extended_capabilities { if let Some(capabilities) = self.extended_capabilities {
LittleEndian::write_u32(buf, capabilities); LittleEndian::write_u32(buf, capabilities.bits() as u32);
} }
} else { } else {
buf.extend_from_slice(&[0u8;4]); buf.extend_from_slice(&[0u8;4]);

View file

@ -3,6 +3,7 @@
use byteorder::{ByteOrder, LittleEndian}; use byteorder::{ByteOrder, LittleEndian};
use failure::Error; use failure::Error;
use std::iter::FromIterator; use std::iter::FromIterator;
use bytes::Bytes;
pub trait Deserialize: Sized { pub trait Deserialize: Sized {
fn deserialize(buf: &mut Vec<u8>) -> Result<Self, Error>; fn deserialize(buf: &mut Vec<u8>) -> Result<Self, Error>;
@ -14,90 +15,74 @@ pub enum Message {
InitialHandshakePacket(InitialHandshakePacket), InitialHandshakePacket(InitialHandshakePacket),
} }
pub enum Capabilities {
ClientMysql = 1, bitflags! {
FoundRows = 2, pub struct Capabilities: u128 {
ConnectWithDb = 8, const CLIENT_MYSQL = 1;
Compress = 32, const FOUND_ROWS = 2;
LocalFiles = 128, const CONNECT_WITH_DB = 8;
IgnroeSpace = 256, const COMPRESS = 32;
ClientProtocol41 = 1 << 9, const LOCAL_FILES = 128;
ClientInteractive = 1 << 10, const IGNORE_SPACE = 256;
SSL = 1 << 11, const CLIENT_PROTOCOL_41 = 1 << 9;
Transactions = 1 << 12, const CLIENT_INTERACTIVE = 1 << 10;
SecureConnection = 1 << 13, const SSL = 1 << 11;
MultiStatements = 1 << 16, const TRANSACTIONS = 1 << 12;
MultiResults = 1 << 17, const SECURE_CONNECTION = 1 << 13;
PsMultiResults = 1 << 18, const MULTI_STATEMENTS = 1 << 16;
PluginAuth = 1 << 19, const MULTI_RESULTS = 1 << 17;
ConnectAttrs = 1 << 20, const PS_MULTI_RESULTS = 1 << 18;
PluginAuthLenencClientData = 1 << 21, const PLUGIN_AUTH = 1 << 19;
ClientSessionTrack = 1 << 23, const CONNECT_ATTRS = 1 << 20;
ClientDeprecateEof = 1 << 24, const PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21;
MariaDbClientProgress = 1 << 32, const CLIENT_SESSION_TRACK = 1 << 23;
MariaDbClientComMulti = 1 << 33, const CLIENT_DEPRECATE_EOF = 1 << 24;
MariaClientStmtBulkOperations = 1 << 34, const MARIA_DB_CLIENT_PROGRESS = 1 << 32;
const MARIA_DB_CLIENT_COM_MULTI = 1 << 33;
const MARIA_CLIENT_STMT_BULK_OPERATIONS = 1 << 34;
}
}
impl Default for Capabilities {
fn default() -> Self {
Capabilities::CLIENT_MYSQL
}
} }
#[derive(Default, Debug)] #[derive(Default, Debug)]
pub struct InitialHandshakePacket { pub struct InitialHandshakePacket {
pub protocol_version: u8, pub protocol_version: u8,
pub server_version: String, pub server_version: Bytes,
pub connection_id: u32, pub connection_id: u32,
pub auth_seed: String, pub auth_seed: Bytes,
pub reserved: u8, pub capabilities: Capabilities,
pub capabilities1: u16,
pub collation: u8, pub collation: u8,
pub status: u16, pub status: u16,
pub plugin_data_length: u8, pub plugin_data_length: u8,
pub scramble2: Option<String>, pub scramble2: Option<Bytes>,
pub reserved2: Option<u8>, pub auth_plugin_name: Option<Bytes>,
pub auth_plugin_name: Option<String>,
} }
impl Deserialize for InitialHandshakePacket { impl Deserialize for InitialHandshakePacket {
fn deserialize(buf: &mut Vec<u8>) -> Result<Self, Error> { fn deserialize(buf: &mut Vec<u8>) -> Result<Self, Error> {
let mut index = 0; let mut index = 0;
let mut null_index = 0;
let protocol_version = buf[0] as u8; let protocol_version = buf[0] as u8;
index += 1; index += 1;
// Find index of null character let null_index = memchr::memchr(b'\0', &buf[index..]).unwrap();
null_index = index; let server_version = Bytes::from(buf[index..null_index].to_vec());
loop {
if buf[null_index] == b'\0' {
break;
}
null_index += 1;
}
let server_version = String::from_iter(
buf[index..null_index]
.iter()
.map(|b| char::from(b.clone()))
.collect::<Vec<char>>()
.into_iter(),
);
// Script null character
index = null_index + 1; index = null_index + 1;
let connection_id = LittleEndian::read_u32(&buf); let connection_id = LittleEndian::read_u32(&buf);
// Increment by index by 4 bytes since we read a u32
index += 4; index += 4;
let auth_seed = String::from_iter( let auth_seed = Bytes::from(buf[index..index + 8].to_vec());
buf[index..index + 8]
.iter()
.map(|b| char::from(b.clone()))
.collect::<Vec<char>>()
.into_iter(),
);
index += 8; index += 8;
// Skip reserved byte // Skip reserved byte
index += 1; index += 1;
let mut capabilities = LittleEndian::read_u16(&buf[index..]) as u32; let mut capabilities = Capabilities::from_bits(LittleEndian::read_u16(&buf[index..]).into()).unwrap();
index += 2; index += 2;
let collation = buf[index]; let collation = buf[index];
@ -106,11 +91,11 @@ impl Deserialize for InitialHandshakePacket {
let status = LittleEndian::read_u16(&buf[index..]); let status = LittleEndian::read_u16(&buf[index..]);
index += 2; index += 2;
capabilities |= LittleEndian::read_u16(&buf[index..]) as u32; capabilities |= Capabilities::from_bits(LittleEndian::read_u16(&buf[index..]).into()).unwrap();
index += 2; index += 2;
let mut plugin_data_length = 0; let mut plugin_data_length = 0;
if capabilities as u128 & Capabilities::PluginAuth as u128 > 0 { if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() {
plugin_data_length = buf[index] as u8; plugin_data_length = buf[index] as u8;
} }
index += 1; index += 1;
@ -118,46 +103,32 @@ impl Deserialize for InitialHandshakePacket {
// Skip filler // Skip filler
index += 6; index += 6;
if capabilities as u128 & Capabilities::ClientMysql as u128 == 0 { if (capabilities & Capabilities::CLIENT_MYSQL).is_empty() {
capabilities |= LittleEndian::read_u32(&buf[index..]); capabilities |= Capabilities::from_bits(LittleEndian::read_u32(&buf[index..]).into()).unwrap();
} }
index += 4; index += 4;
let mut scramble2: Option<String> = None; let mut scramble2: Option<Bytes> = None;
let mut auth_plugin_name: Option<String> = None; let mut auth_plugin_name: Option<Bytes> = None;
if capabilities as u128 & Capabilities::SecureConnection as u128 > 0 { if !(capabilities & Capabilities::SECURE_CONNECTION).is_empty() {
// TODO: scramble 2nd part. Length = max(12, plugin_data_length - 9)
let len = std::cmp::max(12, plugin_data_length - 9); let len = std::cmp::max(12, plugin_data_length - 9);
scramble2 = Some(String::from_iter( scramble2 = Some(Bytes::from(buf[index..index + len as usize].to_vec()));
buf[index..index + len as usize]
.iter()
.map(|b| char::from(b.clone()))
.collect::<Vec<char>>()
.into_iter(),
));
// Skip length characters + the reserved byte
index += len as usize + 1;
} else { } else {
// TODO: auth_plugin_name null temrinated string let null_index = memchr::memchr(b'\0', &buf[index..]).unwrap();
// Find index of null character auth_plugin_name = Some(Bytes::from(buf[index..null_index].to_vec()));
null_index = index;
loop {
if buf[null_index] == b'\0' {
break;
}
null_index += 1;
}
auth_plugin_name = Some(String::from_iter(
buf[index..null_index]
.iter()
.map(|b| char::from(b.clone()))
.collect::<Vec<char>>()
.into_iter(),
));
// Script null character
index = null_index + 1;
} }
Ok(InitialHandshakePacket::default()) Ok(InitialHandshakePacket {
protocol_version,
server_version,
connection_id,
auth_seed,
capabilities,
collation,
status,
plugin_data_length,
scramble2,
auth_plugin_name,
})
} }
} }