Add some ground work for #1

This commit is contained in:
Ryan Leckey 2019-12-28 03:45:37 -08:00
parent 48c8d4c100
commit f67872cbcf
5 changed files with 67 additions and 15 deletions

View file

@ -48,8 +48,8 @@ pub trait Executor {
} else {
Ok(Some(val))
}
},
None => Ok(None)
}
None => Ok(None),
}
})
}

View file

@ -151,14 +151,6 @@ impl MySqlConnection {
let handshake_packet = self_.receive().await?;
let handshake = Handshake::decode(handshake_packet)?;
// TODO: Capabilities::SECURE_CONNECTION
// TODO: Capabilities::CONNECT_ATTRS
// TODO: Capabilities::PLUGIN_AUTH
// TODO: Capabilities::PLUGIN_AUTH_LENENC_CLIENT_DATA
// TODO: Capabilities::TRANSACTIONS
// TODO: Capabilities::CLIENT_DEPRECATE_EOF
// TODO: Capabilities::COMPRESS
// TODO: Capabilities::ZSTD_COMPRESSION_ALGORITHM
let client_capabilities = Capabilities::PROTOCOL_41
| Capabilities::IGNORE_SPACE
| Capabilities::FOUND_ROWS
@ -176,6 +168,8 @@ impl MySqlConnection {
username: url.username().unwrap_or("root"),
// TODO: Remove the panic!
database: url.database().expect("required database"),
auth_plugin_name: handshake.auth_plugin_name.as_deref(),
auth_response: None,
});
self_.stream.flush().await?;

View file

@ -8,9 +8,8 @@ use crate::describe::{Column, Describe};
use crate::executor::Executor;
use crate::mysql::error::MySqlError;
use crate::mysql::protocol::{
Capabilities, ColumnCount, ColumnDefinition, ComQuery, ComStmtExecute,
ComStmtPrepare, ComStmtPrepareOk, Cursor, Decode, EofPacket, ErrPacket, OkPacket, Row,
Type,
Capabilities, ColumnCount, ColumnDefinition, ComQuery, ComStmtExecute, ComStmtPrepare,
ComStmtPrepareOk, Cursor, Decode, EofPacket, ErrPacket, OkPacket, Row, Type,
};
use crate::mysql::{MySql, MySqlArguments, MySqlConnection, MySqlRow};

View file

@ -105,6 +105,55 @@ mod tests {
use super::{Capabilities, Decode, Handshake, Status};
const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\"<H5n\x00mysql_native_password\x00";
const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00";
#[test]
fn it_decodes_handshake_mysql_8_0_18() {
let mut p = Handshake::decode(HANDSHAKE_MYSQL_8_0_18).unwrap();
assert_eq!(p.protocol_version, 10);
p.server_capabilities.toggle(
Capabilities::MYSQL
| Capabilities::FOUND_ROWS
| Capabilities::LONG_FLAG
| Capabilities::CONNECT_WITH_DB
| Capabilities::NO_SCHEMA
| Capabilities::COMPRESS
| Capabilities::ODBC
| Capabilities::LOCAL_FILES
| Capabilities::IGNORE_SPACE
| Capabilities::PROTOCOL_41
| Capabilities::INTERACTIVE
| Capabilities::SSL
| Capabilities::TRANSACTIONS
| Capabilities::SECURE_CONNECTION
| Capabilities::MULTI_STATEMENTS
| Capabilities::MULTI_RESULTS
| Capabilities::PS_MULTI_RESULTS
| Capabilities::PLUGIN_AUTH
| Capabilities::CONNECT_ATTRS
| Capabilities::PLUGIN_AUTH_LENENC_DATA
| Capabilities::CAN_HANDLE_EXPIRED_PASSWORDS
| Capabilities::SESSION_TRACK
| Capabilities::DEPRECATE_EOF
| Capabilities::ZSTD_COMPRESSION_ALGORITHM
| Capabilities::SSL_VERIFY_SERVER_CERT
| Capabilities::OPTIONAL_RESULTSET_METADATA
| Capabilities::REMEMBER_OPTIONS,
);
assert!(p.server_capabilities.is_empty());
assert_eq!(p.server_default_collation, 255);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert_eq!(p.auth_plugin_name.as_deref(), Some("caching_sha2_password"));
assert_eq!(
&*p.auth_plugin_data,
&[17, 52, 97, 66, 48, 99, 6, 103, 116, 76, 3, 115, 15, 91, 52, 13, 108, 52, 46, 32,]
);
}
#[test]
fn it_decodes_handshake_mariadb_10_4_7() {

View file

@ -12,6 +12,8 @@ pub struct HandshakeResponse<'a> {
pub client_collation: u8,
pub username: &'a str,
pub database: &'a str,
pub auth_plugin_name: Option<&'a str>,
pub auth_response: Option<&'a str>,
}
impl Encode for HandshakeResponse<'_> {
@ -41,17 +43,25 @@ impl Encode for HandshakeResponse<'_> {
if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
// auth_response : string<lenenc>
buf.put_str_lenenc::<LittleEndian>("");
buf.put_str_lenenc::<LittleEndian>(self.auth_response.unwrap_or_default());
} else {
let auth_response = self.auth_response.unwrap_or_default();
// auth_response_length : int<1>
buf.put_u8(0);
buf.put_u8(auth_response.len() as u8);
// auth_response : string<{auth_response_length}>
buf.put_str(auth_response);
}
if capabilities.contains(Capabilities::CONNECT_WITH_DB) {
// database : string<NUL>
buf.put_str_nul(self.database);
}
if capabilities.contains(Capabilities::PLUGIN_AUTH) {
// client_plugin_name : string<NUL>
buf.put_str_nul(self.auth_plugin_name.unwrap_or_default());
}
}
}