mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 14:34:19 +00:00
Use bitflags and memchr
This commit is contained in:
parent
e56f364599
commit
12913139da
4 changed files with 73 additions and 98 deletions
|
@ -16,3 +16,4 @@ log = "0.4"
|
||||||
hex = "0.3.2"
|
hex = "0.3.2"
|
||||||
bytes = "0.4.12"
|
bytes = "0.4.12"
|
||||||
memchr = "2.2.0"
|
memchr = "2.2.0"
|
||||||
|
bitflags = "1.1.0"
|
||||||
|
|
|
@ -1,5 +1,8 @@
|
||||||
#![feature(non_exhaustive, async_await)]
|
#![feature(non_exhaustive, async_await)]
|
||||||
#![allow(clippy::needless_lifetimes)]
|
#![allow(clippy::needless_lifetimes)]
|
||||||
|
|
||||||
|
#[macro_use]
|
||||||
|
extern crate bitflags;
|
||||||
|
|
||||||
// mod connection;
|
// mod connection;
|
||||||
mod protocol;
|
mod protocol;
|
||||||
|
|
|
@ -10,10 +10,10 @@ pub trait Serialize {
|
||||||
|
|
||||||
#[derive(Default, Debug)]
|
#[derive(Default, Debug)]
|
||||||
pub struct SSLRequestPacket {
|
pub struct SSLRequestPacket {
|
||||||
pub capabilities: u32,
|
pub capabilities: Capabilities,
|
||||||
pub max_packet_size: u32,
|
pub max_packet_size: u32,
|
||||||
pub collation: u8,
|
pub collation: u8,
|
||||||
pub extended_capabilities: Option<u32>,
|
pub extended_capabilities: Option<Capabilities>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Serialize for SSLRequestPacket {
|
impl Serialize for SSLRequestPacket {
|
||||||
|
@ -21,13 +21,13 @@ impl Serialize for SSLRequestPacket {
|
||||||
// FIXME: Prepend length of packet in standard packet form
|
// FIXME: Prepend length of packet in standard packet form
|
||||||
// https://mariadb.com/kb/en/library/0-packet
|
// https://mariadb.com/kb/en/library/0-packet
|
||||||
// buf.push(32);
|
// buf.push(32);
|
||||||
LittleEndian::write_u32(buf, self.capabilities);
|
LittleEndian::write_u32(buf, self.capabilities.bits() as u32);
|
||||||
LittleEndian::write_u32(buf, self.max_packet_size);
|
LittleEndian::write_u32(buf, self.max_packet_size);
|
||||||
buf.push(self.collation);
|
buf.push(self.collation);
|
||||||
buf.extend_from_slice(&[0u8;19]);
|
buf.extend_from_slice(&[0u8;19]);
|
||||||
if self.capabilities as u128 & Capabilities::ClientMysql as u128 > 0 {
|
if !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() {
|
||||||
if let Some(capabilities) = self.extended_capabilities {
|
if let Some(capabilities) = self.extended_capabilities {
|
||||||
LittleEndian::write_u32(buf, capabilities);
|
LittleEndian::write_u32(buf, capabilities.bits() as u32);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
buf.extend_from_slice(&[0u8;4]);
|
buf.extend_from_slice(&[0u8;4]);
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
use byteorder::{ByteOrder, LittleEndian};
|
use byteorder::{ByteOrder, LittleEndian};
|
||||||
use failure::Error;
|
use failure::Error;
|
||||||
use std::iter::FromIterator;
|
use std::iter::FromIterator;
|
||||||
|
use bytes::Bytes;
|
||||||
|
|
||||||
pub trait Deserialize: Sized {
|
pub trait Deserialize: Sized {
|
||||||
fn deserialize(buf: &mut Vec<u8>) -> Result<Self, Error>;
|
fn deserialize(buf: &mut Vec<u8>) -> Result<Self, Error>;
|
||||||
|
@ -14,90 +15,74 @@ pub enum Message {
|
||||||
InitialHandshakePacket(InitialHandshakePacket),
|
InitialHandshakePacket(InitialHandshakePacket),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum Capabilities {
|
|
||||||
ClientMysql = 1,
|
bitflags! {
|
||||||
FoundRows = 2,
|
pub struct Capabilities: u128 {
|
||||||
ConnectWithDb = 8,
|
const CLIENT_MYSQL = 1;
|
||||||
Compress = 32,
|
const FOUND_ROWS = 2;
|
||||||
LocalFiles = 128,
|
const CONNECT_WITH_DB = 8;
|
||||||
IgnroeSpace = 256,
|
const COMPRESS = 32;
|
||||||
ClientProtocol41 = 1 << 9,
|
const LOCAL_FILES = 128;
|
||||||
ClientInteractive = 1 << 10,
|
const IGNORE_SPACE = 256;
|
||||||
SSL = 1 << 11,
|
const CLIENT_PROTOCOL_41 = 1 << 9;
|
||||||
Transactions = 1 << 12,
|
const CLIENT_INTERACTIVE = 1 << 10;
|
||||||
SecureConnection = 1 << 13,
|
const SSL = 1 << 11;
|
||||||
MultiStatements = 1 << 16,
|
const TRANSACTIONS = 1 << 12;
|
||||||
MultiResults = 1 << 17,
|
const SECURE_CONNECTION = 1 << 13;
|
||||||
PsMultiResults = 1 << 18,
|
const MULTI_STATEMENTS = 1 << 16;
|
||||||
PluginAuth = 1 << 19,
|
const MULTI_RESULTS = 1 << 17;
|
||||||
ConnectAttrs = 1 << 20,
|
const PS_MULTI_RESULTS = 1 << 18;
|
||||||
PluginAuthLenencClientData = 1 << 21,
|
const PLUGIN_AUTH = 1 << 19;
|
||||||
ClientSessionTrack = 1 << 23,
|
const CONNECT_ATTRS = 1 << 20;
|
||||||
ClientDeprecateEof = 1 << 24,
|
const PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21;
|
||||||
MariaDbClientProgress = 1 << 32,
|
const CLIENT_SESSION_TRACK = 1 << 23;
|
||||||
MariaDbClientComMulti = 1 << 33,
|
const CLIENT_DEPRECATE_EOF = 1 << 24;
|
||||||
MariaClientStmtBulkOperations = 1 << 34,
|
const MARIA_DB_CLIENT_PROGRESS = 1 << 32;
|
||||||
|
const MARIA_DB_CLIENT_COM_MULTI = 1 << 33;
|
||||||
|
const MARIA_CLIENT_STMT_BULK_OPERATIONS = 1 << 34;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Capabilities {
|
||||||
|
fn default() -> Self {
|
||||||
|
Capabilities::CLIENT_MYSQL
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default, Debug)]
|
#[derive(Default, Debug)]
|
||||||
pub struct InitialHandshakePacket {
|
pub struct InitialHandshakePacket {
|
||||||
pub protocol_version: u8,
|
pub protocol_version: u8,
|
||||||
pub server_version: String,
|
pub server_version: Bytes,
|
||||||
pub connection_id: u32,
|
pub connection_id: u32,
|
||||||
pub auth_seed: String,
|
pub auth_seed: Bytes,
|
||||||
pub reserved: u8,
|
pub capabilities: Capabilities,
|
||||||
pub capabilities1: u16,
|
|
||||||
pub collation: u8,
|
pub collation: u8,
|
||||||
pub status: u16,
|
pub status: u16,
|
||||||
pub plugin_data_length: u8,
|
pub plugin_data_length: u8,
|
||||||
pub scramble2: Option<String>,
|
pub scramble2: Option<Bytes>,
|
||||||
pub reserved2: Option<u8>,
|
pub auth_plugin_name: Option<Bytes>,
|
||||||
pub auth_plugin_name: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Deserialize for InitialHandshakePacket {
|
impl Deserialize for InitialHandshakePacket {
|
||||||
fn deserialize(buf: &mut Vec<u8>) -> Result<Self, Error> {
|
fn deserialize(buf: &mut Vec<u8>) -> Result<Self, Error> {
|
||||||
let mut index = 0;
|
let mut index = 0;
|
||||||
let mut null_index = 0;
|
|
||||||
let protocol_version = buf[0] as u8;
|
let protocol_version = buf[0] as u8;
|
||||||
index += 1;
|
index += 1;
|
||||||
|
|
||||||
// Find index of null character
|
let null_index = memchr::memchr(b'\0', &buf[index..]).unwrap();
|
||||||
null_index = index;
|
let server_version = Bytes::from(buf[index..null_index].to_vec());
|
||||||
loop {
|
|
||||||
if buf[null_index] == b'\0' {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
null_index += 1;
|
|
||||||
}
|
|
||||||
let server_version = String::from_iter(
|
|
||||||
buf[index..null_index]
|
|
||||||
.iter()
|
|
||||||
.map(|b| char::from(b.clone()))
|
|
||||||
.collect::<Vec<char>>()
|
|
||||||
.into_iter(),
|
|
||||||
);
|
|
||||||
// Script null character
|
|
||||||
index = null_index + 1;
|
index = null_index + 1;
|
||||||
|
|
||||||
let connection_id = LittleEndian::read_u32(&buf);
|
let connection_id = LittleEndian::read_u32(&buf);
|
||||||
|
|
||||||
// Increment by index by 4 bytes since we read a u32
|
|
||||||
index += 4;
|
index += 4;
|
||||||
|
|
||||||
let auth_seed = String::from_iter(
|
let auth_seed = Bytes::from(buf[index..index + 8].to_vec());
|
||||||
buf[index..index + 8]
|
|
||||||
.iter()
|
|
||||||
.map(|b| char::from(b.clone()))
|
|
||||||
.collect::<Vec<char>>()
|
|
||||||
.into_iter(),
|
|
||||||
);
|
|
||||||
index += 8;
|
index += 8;
|
||||||
|
|
||||||
// Skip reserved byte
|
// Skip reserved byte
|
||||||
index += 1;
|
index += 1;
|
||||||
|
|
||||||
let mut capabilities = LittleEndian::read_u16(&buf[index..]) as u32;
|
let mut capabilities = Capabilities::from_bits(LittleEndian::read_u16(&buf[index..]).into()).unwrap();
|
||||||
index += 2;
|
index += 2;
|
||||||
|
|
||||||
let collation = buf[index];
|
let collation = buf[index];
|
||||||
|
@ -106,11 +91,11 @@ impl Deserialize for InitialHandshakePacket {
|
||||||
let status = LittleEndian::read_u16(&buf[index..]);
|
let status = LittleEndian::read_u16(&buf[index..]);
|
||||||
index += 2;
|
index += 2;
|
||||||
|
|
||||||
capabilities |= LittleEndian::read_u16(&buf[index..]) as u32;
|
capabilities |= Capabilities::from_bits(LittleEndian::read_u16(&buf[index..]).into()).unwrap();
|
||||||
index += 2;
|
index += 2;
|
||||||
|
|
||||||
let mut plugin_data_length = 0;
|
let mut plugin_data_length = 0;
|
||||||
if capabilities as u128 & Capabilities::PluginAuth as u128 > 0 {
|
if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() {
|
||||||
plugin_data_length = buf[index] as u8;
|
plugin_data_length = buf[index] as u8;
|
||||||
}
|
}
|
||||||
index += 1;
|
index += 1;
|
||||||
|
@ -118,46 +103,32 @@ impl Deserialize for InitialHandshakePacket {
|
||||||
// Skip filler
|
// Skip filler
|
||||||
index += 6;
|
index += 6;
|
||||||
|
|
||||||
if capabilities as u128 & Capabilities::ClientMysql as u128 == 0 {
|
if (capabilities & Capabilities::CLIENT_MYSQL).is_empty() {
|
||||||
capabilities |= LittleEndian::read_u32(&buf[index..]);
|
capabilities |= Capabilities::from_bits(LittleEndian::read_u32(&buf[index..]).into()).unwrap();
|
||||||
}
|
}
|
||||||
index += 4;
|
index += 4;
|
||||||
|
|
||||||
let mut scramble2: Option<String> = None;
|
let mut scramble2: Option<Bytes> = None;
|
||||||
let mut auth_plugin_name: Option<String> = None;
|
let mut auth_plugin_name: Option<Bytes> = None;
|
||||||
if capabilities as u128 & Capabilities::SecureConnection as u128 > 0 {
|
if !(capabilities & Capabilities::SECURE_CONNECTION).is_empty() {
|
||||||
// TODO: scramble 2nd part. Length = max(12, plugin_data_length - 9)
|
|
||||||
let len = std::cmp::max(12, plugin_data_length - 9);
|
let len = std::cmp::max(12, plugin_data_length - 9);
|
||||||
scramble2 = Some(String::from_iter(
|
scramble2 = Some(Bytes::from(buf[index..index + len as usize].to_vec()));
|
||||||
buf[index..index + len as usize]
|
|
||||||
.iter()
|
|
||||||
.map(|b| char::from(b.clone()))
|
|
||||||
.collect::<Vec<char>>()
|
|
||||||
.into_iter(),
|
|
||||||
));
|
|
||||||
// Skip length characters + the reserved byte
|
|
||||||
index += len as usize + 1;
|
|
||||||
} else {
|
} else {
|
||||||
// TODO: auth_plugin_name null temrinated string
|
let null_index = memchr::memchr(b'\0', &buf[index..]).unwrap();
|
||||||
// Find index of null character
|
auth_plugin_name = Some(Bytes::from(buf[index..null_index].to_vec()));
|
||||||
null_index = index;
|
|
||||||
loop {
|
|
||||||
if buf[null_index] == b'\0' {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
null_index += 1;
|
|
||||||
}
|
|
||||||
auth_plugin_name = Some(String::from_iter(
|
|
||||||
buf[index..null_index]
|
|
||||||
.iter()
|
|
||||||
.map(|b| char::from(b.clone()))
|
|
||||||
.collect::<Vec<char>>()
|
|
||||||
.into_iter(),
|
|
||||||
));
|
|
||||||
// Script null character
|
|
||||||
index = null_index + 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(InitialHandshakePacket::default())
|
Ok(InitialHandshakePacket {
|
||||||
|
protocol_version,
|
||||||
|
server_version,
|
||||||
|
connection_id,
|
||||||
|
auth_seed,
|
||||||
|
capabilities,
|
||||||
|
collation,
|
||||||
|
status,
|
||||||
|
plugin_data_length,
|
||||||
|
scramble2,
|
||||||
|
auth_plugin_name,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue