diff --git a/mason-mariadb/Cargo.toml b/mason-mariadb/Cargo.toml index e89e5140..652c795d 100644 --- a/mason-mariadb/Cargo.toml +++ b/mason-mariadb/Cargo.toml @@ -16,3 +16,4 @@ log = "0.4" hex = "0.3.2" bytes = "0.4.12" memchr = "2.2.0" +bitflags = "1.1.0" diff --git a/mason-mariadb/src/lib.rs b/mason-mariadb/src/lib.rs index 8a8bf704..2e760c7f 100644 --- a/mason-mariadb/src/lib.rs +++ b/mason-mariadb/src/lib.rs @@ -1,5 +1,8 @@ #![feature(non_exhaustive, async_await)] #![allow(clippy::needless_lifetimes)] +#[macro_use] +extern crate bitflags; + // mod connection; mod protocol; diff --git a/mason-mariadb/src/protocol/client.rs b/mason-mariadb/src/protocol/client.rs index 07e2c5a7..78691e80 100644 --- a/mason-mariadb/src/protocol/client.rs +++ b/mason-mariadb/src/protocol/client.rs @@ -10,10 +10,10 @@ pub trait Serialize { #[derive(Default, Debug)] pub struct SSLRequestPacket { - pub capabilities: u32, + pub capabilities: Capabilities, pub max_packet_size: u32, pub collation: u8, - pub extended_capabilities: Option, + pub extended_capabilities: Option, } impl Serialize for SSLRequestPacket { @@ -21,13 +21,13 @@ impl Serialize for SSLRequestPacket { // FIXME: Prepend length of packet in standard packet form // https://mariadb.com/kb/en/library/0-packet // buf.push(32); - LittleEndian::write_u32(buf, self.capabilities); + LittleEndian::write_u32(buf, self.capabilities.bits() as u32); LittleEndian::write_u32(buf, self.max_packet_size); buf.push(self.collation); buf.extend_from_slice(&[0u8;19]); - if self.capabilities as u128 & Capabilities::ClientMysql as u128 > 0 { + if !(self.capabilities & Capabilities::CLIENT_MYSQL).is_empty() { if let Some(capabilities) = self.extended_capabilities { - LittleEndian::write_u32(buf, capabilities); + LittleEndian::write_u32(buf, capabilities.bits() as u32); } } else { buf.extend_from_slice(&[0u8;4]); diff --git a/mason-mariadb/src/protocol/server.rs b/mason-mariadb/src/protocol/server.rs index ec62eda2..ba6e7ca4 100644 --- a/mason-mariadb/src/protocol/server.rs +++ b/mason-mariadb/src/protocol/server.rs @@ -3,6 +3,7 @@ use byteorder::{ByteOrder, LittleEndian}; use failure::Error; use std::iter::FromIterator; +use bytes::Bytes; pub trait Deserialize: Sized { fn deserialize(buf: &mut Vec) -> Result; @@ -14,90 +15,74 @@ pub enum Message { InitialHandshakePacket(InitialHandshakePacket), } -pub enum Capabilities { - ClientMysql = 1, - FoundRows = 2, - ConnectWithDb = 8, - Compress = 32, - LocalFiles = 128, - IgnroeSpace = 256, - ClientProtocol41 = 1 << 9, - ClientInteractive = 1 << 10, - SSL = 1 << 11, - Transactions = 1 << 12, - SecureConnection = 1 << 13, - MultiStatements = 1 << 16, - MultiResults = 1 << 17, - PsMultiResults = 1 << 18, - PluginAuth = 1 << 19, - ConnectAttrs = 1 << 20, - PluginAuthLenencClientData = 1 << 21, - ClientSessionTrack = 1 << 23, - ClientDeprecateEof = 1 << 24, - MariaDbClientProgress = 1 << 32, - MariaDbClientComMulti = 1 << 33, - MariaClientStmtBulkOperations = 1 << 34, + +bitflags! { + pub struct Capabilities: u128 { + const CLIENT_MYSQL = 1; + const FOUND_ROWS = 2; + const CONNECT_WITH_DB = 8; + const COMPRESS = 32; + const LOCAL_FILES = 128; + const IGNORE_SPACE = 256; + const CLIENT_PROTOCOL_41 = 1 << 9; + const CLIENT_INTERACTIVE = 1 << 10; + const SSL = 1 << 11; + const TRANSACTIONS = 1 << 12; + const SECURE_CONNECTION = 1 << 13; + const MULTI_STATEMENTS = 1 << 16; + const MULTI_RESULTS = 1 << 17; + const PS_MULTI_RESULTS = 1 << 18; + const PLUGIN_AUTH = 1 << 19; + const CONNECT_ATTRS = 1 << 20; + const PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21; + const CLIENT_SESSION_TRACK = 1 << 23; + const CLIENT_DEPRECATE_EOF = 1 << 24; + const MARIA_DB_CLIENT_PROGRESS = 1 << 32; + const MARIA_DB_CLIENT_COM_MULTI = 1 << 33; + const MARIA_CLIENT_STMT_BULK_OPERATIONS = 1 << 34; + } +} + +impl Default for Capabilities { + fn default() -> Self { + Capabilities::CLIENT_MYSQL + } } #[derive(Default, Debug)] pub struct InitialHandshakePacket { pub protocol_version: u8, - pub server_version: String, + pub server_version: Bytes, pub connection_id: u32, - pub auth_seed: String, - pub reserved: u8, - pub capabilities1: u16, + pub auth_seed: Bytes, + pub capabilities: Capabilities, pub collation: u8, pub status: u16, pub plugin_data_length: u8, - pub scramble2: Option, - pub reserved2: Option, - pub auth_plugin_name: Option, + pub scramble2: Option, + pub auth_plugin_name: Option, } impl Deserialize for InitialHandshakePacket { fn deserialize(buf: &mut Vec) -> Result { let mut index = 0; - let mut null_index = 0; let protocol_version = buf[0] as u8; index += 1; - // Find index of null character - null_index = index; - loop { - if buf[null_index] == b'\0' { - break; - } - null_index += 1; - } - let server_version = String::from_iter( - buf[index..null_index] - .iter() - .map(|b| char::from(b.clone())) - .collect::>() - .into_iter(), - ); - // Script null character + let null_index = memchr::memchr(b'\0', &buf[index..]).unwrap(); + let server_version = Bytes::from(buf[index..null_index].to_vec()); index = null_index + 1; let connection_id = LittleEndian::read_u32(&buf); - - // Increment by index by 4 bytes since we read a u32 index += 4; - let auth_seed = String::from_iter( - buf[index..index + 8] - .iter() - .map(|b| char::from(b.clone())) - .collect::>() - .into_iter(), - ); + let auth_seed = Bytes::from(buf[index..index + 8].to_vec()); index += 8; // Skip reserved byte index += 1; - let mut capabilities = LittleEndian::read_u16(&buf[index..]) as u32; + let mut capabilities = Capabilities::from_bits(LittleEndian::read_u16(&buf[index..]).into()).unwrap(); index += 2; let collation = buf[index]; @@ -106,11 +91,11 @@ impl Deserialize for InitialHandshakePacket { let status = LittleEndian::read_u16(&buf[index..]); index += 2; - capabilities |= LittleEndian::read_u16(&buf[index..]) as u32; + capabilities |= Capabilities::from_bits(LittleEndian::read_u16(&buf[index..]).into()).unwrap(); index += 2; let mut plugin_data_length = 0; - if capabilities as u128 & Capabilities::PluginAuth as u128 > 0 { + if !(capabilities & Capabilities::PLUGIN_AUTH).is_empty() { plugin_data_length = buf[index] as u8; } index += 1; @@ -118,46 +103,32 @@ impl Deserialize for InitialHandshakePacket { // Skip filler index += 6; - if capabilities as u128 & Capabilities::ClientMysql as u128 == 0 { - capabilities |= LittleEndian::read_u32(&buf[index..]); + if (capabilities & Capabilities::CLIENT_MYSQL).is_empty() { + capabilities |= Capabilities::from_bits(LittleEndian::read_u32(&buf[index..]).into()).unwrap(); } index += 4; - let mut scramble2: Option = None; - let mut auth_plugin_name: Option = None; - if capabilities as u128 & Capabilities::SecureConnection as u128 > 0 { - // TODO: scramble 2nd part. Length = max(12, plugin_data_length - 9) + let mut scramble2: Option = None; + let mut auth_plugin_name: Option = None; + if !(capabilities & Capabilities::SECURE_CONNECTION).is_empty() { let len = std::cmp::max(12, plugin_data_length - 9); - scramble2 = Some(String::from_iter( - buf[index..index + len as usize] - .iter() - .map(|b| char::from(b.clone())) - .collect::>() - .into_iter(), - )); - // Skip length characters + the reserved byte - index += len as usize + 1; + scramble2 = Some(Bytes::from(buf[index..index + len as usize].to_vec())); } else { - // TODO: auth_plugin_name null temrinated string - // Find index of null character - null_index = index; - loop { - if buf[null_index] == b'\0' { - break; - } - null_index += 1; - } - auth_plugin_name = Some(String::from_iter( - buf[index..null_index] - .iter() - .map(|b| char::from(b.clone())) - .collect::>() - .into_iter(), - )); - // Script null character - index = null_index + 1; + let null_index = memchr::memchr(b'\0', &buf[index..]).unwrap(); + auth_plugin_name = Some(Bytes::from(buf[index..null_index].to_vec())); } - Ok(InitialHandshakePacket::default()) + Ok(InitialHandshakePacket { + protocol_version, + server_version, + connection_id, + auth_seed, + capabilities, + collation, + status, + plugin_data_length, + scramble2, + auth_plugin_name, + }) } }