[MySQL] Add initial support for authenticationSupports caching_sha2_password and sha256_password

This commit is contained in:
Ryan Leckey 2019-12-31 14:16:48 -08:00
parent 1d7a2f27c6
commit ce343dee9c
14 changed files with 845 additions and 206 deletions

48
Cargo.lock generated
View file

@ -312,6 +312,11 @@ dependencies = [
"version_check 0.1.5 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "fake-simd"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "fnv"
version = "1.0.6"
@ -760,6 +765,16 @@ dependencies = [
"winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "num-bigint"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"autocfg 0.1.7 (registry+https://github.com/rust-lang/crates.io-index)",
"num-integer 0.1.41 (registry+https://github.com/rust-lang/crates.io-index)",
"num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "num-integer"
version = "0.1.41"
@ -1022,6 +1037,28 @@ dependencies = [
"serde 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "sha-1"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"block-buffer 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)",
"digest 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)",
"fake-simd 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
"opaque-debug 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "sha2"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"block-buffer 0.7.3 (registry+https://github.com/rust-lang/crates.io-index)",
"digest 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)",
"fake-simd 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)",
"opaque-debug 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "slab"
version = "0.4.2"
@ -1064,15 +1101,22 @@ version = "0.1.1"
dependencies = [
"async-std 1.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
"async-stream 0.2.0 (registry+https://github.com/rust-lang/crates.io-index)",
"base64 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)",
"bitflags 1.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
"chrono 0.4.10 (registry+https://github.com/rust-lang/crates.io-index)",
"digest 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-core 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"futures-util 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
"generic-array 0.12.3 (registry+https://github.com/rust-lang/crates.io-index)",
"log 0.4.8 (registry+https://github.com/rust-lang/crates.io-index)",
"matches 0.1.8 (registry+https://github.com/rust-lang/crates.io-index)",
"md-5 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)",
"memchr 2.2.1 (registry+https://github.com/rust-lang/crates.io-index)",
"num-bigint 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)",
"rand 0.7.2 (registry+https://github.com/rust-lang/crates.io-index)",
"sha-1 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)",
"sha2 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)",
"url 2.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
"uuid 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
@ -1425,6 +1469,7 @@ dependencies = [
"checksum dotenv 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)" = "77c90badedccf4105eca100756a0b1289e191f6fcbdadd3cee1d2f614f97da8f"
"checksum either 1.5.3 (registry+https://github.com/rust-lang/crates.io-index)" = "bb1f6b1ce1c140482ea30ddd3335fc0024ac7ee112895426e0a629a6c20adfe3"
"checksum error-chain 0.12.1 (registry+https://github.com/rust-lang/crates.io-index)" = "3ab49e9dcb602294bc42f9a7dfc9bc6e936fca4418ea300dbfb84fe16de0b7d9"
"checksum fake-simd 0.1.2 (registry+https://github.com/rust-lang/crates.io-index)" = "e88a8acf291dafb59c2d96e8f59828f3838bb1a70398823ade51a84de6a6deed"
"checksum fnv 1.0.6 (registry+https://github.com/rust-lang/crates.io-index)" = "2fad85553e09a6f881f739c29f0b00b0f01357c743266d478b68951ce23285f3"
"checksum fuchsia-zircon 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "2e9763c69ebaae630ba35f74888db465e49e259ba1bc0eda7d06f4a067615d82"
"checksum fuchsia-zircon-sys 0.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "3dcaa9ae7725d12cdb85b3ad99a434db70b468c09ded17e012d86b5c1010f7a7"
@ -1475,6 +1520,7 @@ dependencies = [
"checksum mio-uds 0.6.7 (registry+https://github.com/rust-lang/crates.io-index)" = "966257a94e196b11bb43aca423754d87429960a768de9414f3691d6957abf125"
"checksum miow 0.2.1 (registry+https://github.com/rust-lang/crates.io-index)" = "8c1f2f3b1cf331de6896aabf6e9d55dca90356cc9960cca7eaaf408a355ae919"
"checksum net2 0.2.33 (registry+https://github.com/rust-lang/crates.io-index)" = "42550d9fb7b6684a6d404d9fa7250c2eb2646df731d1c06afc06dcee9e1bcf88"
"checksum num-bigint 0.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "f9c3f34cdd24f334cb265d9bf8bfa8a241920d026916785747a92f0e55541a1a"
"checksum num-integer 0.1.41 (registry+https://github.com/rust-lang/crates.io-index)" = "b85e541ef8255f6cf42bbfe4ef361305c6c135d10919ecc26126c4e5ae94bc09"
"checksum num-traits 0.2.10 (registry+https://github.com/rust-lang/crates.io-index)" = "d4c81ffc11c212fa327657cb19dd85eb7419e163b5b076bede2bdb5c974c07e4"
"checksum num_cpus 1.11.1 (registry+https://github.com/rust-lang/crates.io-index)" = "76dac5ed2a876980778b8b85f75a71b6cbf0db0b1232ee12f826bccb00d09d72"
@ -1509,6 +1555,8 @@ dependencies = [
"checksum serde_derive 1.0.104 (registry+https://github.com/rust-lang/crates.io-index)" = "128f9e303a5a29922045a830221b8f78ec74a5f544944f3d5984f8ec3895ef64"
"checksum serde_json 1.0.44 (registry+https://github.com/rust-lang/crates.io-index)" = "48c575e0cc52bdd09b47f330f646cf59afc586e9c4e3ccd6fc1f625b8ea1dad7"
"checksum serde_qs 0.5.2 (registry+https://github.com/rust-lang/crates.io-index)" = "d43eef44996bbe16e99ac720e1577eefa16f7b76b5172165c98ced20ae9903e1"
"checksum sha-1 0.8.1 (registry+https://github.com/rust-lang/crates.io-index)" = "23962131a91661d643c98940b20fcaffe62d776a823247be80a48fcb8b6fce68"
"checksum sha2 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "7b4d8bfd0e469f417657573d8451fb33d16cfe0989359b93baf3a1ffc639543d"
"checksum slab 0.4.2 (registry+https://github.com/rust-lang/crates.io-index)" = "c111b5bd5695e56cffe5129854aa230b39c93a305372fdbb2668ca2394eea9f8"
"checksum smallvec 0.6.13 (registry+https://github.com/rust-lang/crates.io-index)" = "f7b0758c52e15a8b5e3691eae6cc559f08eee9406e548a4477ba4e67770a82b6"
"checksum smallvec 1.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "44e59e0c9fa00817912ae6e4e6e3c4fe04455e75699d06eedc7d85917ed8e8f4"

View file

@ -15,22 +15,29 @@ authors = [
[features]
default = []
unstable = []
postgres = []
mysql = []
postgres = [ "md-5" ]
mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ]
[dependencies]
async-stream = { version = "0.2.0", default-features = false }
async-std = { version = "1.4.0", default-features = false, features = [ "unstable" ] }
async-stream = { version = "0.2.0", default-features = false }
base64 = { version = "0.11.0", default-features = false, optional = true, features = [ "std" ] }
bitflags = { version = "1.2.1", default-features = false }
byteorder = { version = "1.3.2", default-features = false }
chrono = { version = "0.4.10", default-features = false, features = [ "clock" ], optional = true }
digest = { version = "0.8.1", default-features = false, optional = true, features = [ "std" ] }
futures-core = { version = "0.3.1", default-features = false }
futures-util = { version = "0.3.1", default-features = false }
generic-array = { version = "0.12.3", default-features = false, optional = true }
log = { version = "0.4.8", default-features = false }
url = { version = "2.1.0", default-features = false }
byteorder = { version = "1.3.2", default-features = false }
md-5 = { version = "0.8.0", default-features = false, optional = true }
memchr = { version = "2.2.1", default-features = false }
md-5 = { version = "0.8.0", default-features = false }
num-bigint = { version = "0.2.3", default-features = false, optional = true, features = [ "std" ] }
rand = { version = "0.7.2", default-features = false, optional = true, features = [ "std" ] }
sha-1 = { version = "0.8.1", default-features = false, optional = true }
sha2 = { version = "0.8.0", default-features = false, optional = true }
url = { version = "2.1.0", default-features = false }
uuid = { version = "0.8.1", default-features = false, optional = true }
chrono = { version = "0.4.10", default-features = false, features = [ "clock" ], optional = true }
[dev-dependencies]
matches = "0.1.8"

View file

@ -4,6 +4,8 @@ use std::io;
use async_std::net::{Shutdown, TcpStream};
use byteorder::{ByteOrder, LittleEndian};
use futures_core::future::BoxFuture;
use sha1::Sha1;
use sha2::{Digest, Sha256};
use crate::cache::StatementCache;
use crate::connection::Connection;
@ -11,10 +13,18 @@ use crate::executor::Executor;
use crate::io::{Buf, BufMut, BufStream};
use crate::mysql::error::MySqlError;
use crate::mysql::protocol::{
Capabilities, Decode, Encode, EofPacket, ErrPacket, Handshake, HandshakeResponse, OkPacket,
AuthPlugin, AuthSwitch, Capabilities, Decode, Encode, EofPacket, ErrPacket, Handshake,
HandshakeResponse, OkPacket,
};
use crate::mysql::rsa;
use crate::mysql::util::xor_eq;
use crate::url::Url;
// Size before a packet is split
const MAX_PACKET_SIZE: u32 = 1024;
const COLLATE_UTF8MB4_UNICODE_CI: u8 = 224;
/// An asynchronous connection to a [MySql] database.
///
/// The connection string expected by [Connection::open] should be a MySQL connection
@ -23,25 +33,27 @@ use crate::url::Url;
pub struct MySqlConnection {
pub(super) stream: BufStream<TcpStream>,
// Active capabilities of the client _&_ the server
pub(super) capabilities: Capabilities,
// Cache of prepared statements
// Query (String) to StatementId to ColumnMap
pub(super) statement_cache: StatementCache<u32>,
rbuf: Vec<u8>,
// Packets are buffered into a second buffer from the stream
// as we may have compressed or split packets to figure out before
// decoding
pub(super) packet: Vec<u8>,
packet_len: usize,
next_seq_no: u8,
pub(super) ready: bool,
// Packets in a command sequence have an incrementing sequence number
// This number must be 0 at the start of each command
pub(super) next_seq_no: u8,
}
impl MySqlConnection {
pub(super) fn begin_command_phase(&mut self) {
// At the start of the *command phase*, the sequence ID sent from the client
// must be 0
self.next_seq_no = 0;
}
pub(super) fn write(&mut self, packet: impl Encode + std::fmt::Debug) {
/// Write the packet to the stream ( do not send to the server )
pub(crate) fn write(&mut self, packet: impl Encode) {
let buf = self.stream.buffer_mut();
// Allocate room for the header that we write after the packet;
@ -66,51 +78,42 @@ impl MySqlConnection {
self.next_seq_no = self.next_seq_no.wrapping_add(1);
}
async fn receive_ok(&mut self) -> crate::Result<OkPacket> {
let packet = self.receive().await?;
Ok(match packet[0] {
0xfe | 0x00 => OkPacket::decode(packet)?,
0xff => {
return Err(MySqlError(ErrPacket::decode(packet)?).into());
}
id => {
return Err(protocol_err!(
"unexpected packet identifier 0x{:X?} when expecting 0xFE (OK) or 0xFF \
(ERR)",
id
)
.into());
}
})
}
pub(super) async fn receive_eof(&mut self) -> crate::Result<()> {
// When (legacy) EOFs are enabled, the fixed number column definitions are further
// terminated by an EOF packet
if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
let _eof = EofPacket::decode(self.receive().await?)?;
}
/// Send the packet to the database server
pub(crate) async fn send(&mut self, packet: impl Encode) -> crate::Result<()> {
self.write(packet);
self.stream.flush().await?;
Ok(())
}
pub(super) async fn receive(&mut self) -> crate::Result<&[u8]> {
Ok(self
.try_receive()
.await?
.ok_or(io::ErrorKind::UnexpectedEof)?)
/// Send a [HandshakeResponse] packet to the database server
pub(crate) async fn send_handshake_response(
&mut self,
url: &Url,
auth_plugin: &AuthPlugin,
auth_response: &[u8],
) -> crate::Result<()> {
self.send(HandshakeResponse {
client_collation: COLLATE_UTF8MB4_UNICODE_CI,
max_packet_size: MAX_PACKET_SIZE,
username: url.username().unwrap_or("root"),
database: url.database(),
auth_plugin,
auth_response,
})
.await
}
pub(super) async fn try_receive(&mut self) -> crate::Result<Option<&[u8]>> {
self.rbuf.clear();
/// Try to receive a packet from the database server. Returns `None` if the server has sent
/// no data.
pub(crate) async fn try_receive(&mut self) -> crate::Result<Option<()>> {
self.packet.clear();
// Read the packet header which contains the length and the sequence number
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
// https://mariadb.com/kb/en/library/0-packet/#standard-packet
let mut header = ret_if_none!(self.stream.peek(4).await?);
let payload_len = header.get_uint::<LittleEndian>(3)? as usize;
self.packet_len = header.get_uint::<LittleEndian>(3)? as usize;
self.next_seq_no = header.get_u8()?.wrapping_add(1);
self.stream.consume(4);
@ -118,66 +121,221 @@ impl MySqlConnection {
// We must have a separate buffer around the stream as we can't operate directly
// on bytes returned from the stream. We have various kinds of payload manipulation
// that must be handled before decoding.
let mut payload = ret_if_none!(self.stream.peek(payload_len).await?);
self.rbuf.extend_from_slice(payload);
self.stream.consume(payload_len);
let mut payload = ret_if_none!(self.stream.peek(self.packet_len).await?);
self.packet.extend_from_slice(payload);
self.stream.consume(self.packet_len);
// TODO: Implement packet compression
// TODO: Implement packet joining
Ok(Some(&self.rbuf[..payload_len]))
Ok(Some(()))
}
}
impl MySqlConnection {
// TODO: Authentication ?!
pub(super) async fn open(url: crate::Result<Url>) -> crate::Result<Self> {
let url = url?;
let stream = TcpStream::connect((url.host(), url.port(3306))).await?;
/// Receive a complete packet from the database server.
pub(crate) async fn receive(&mut self) -> crate::Result<&mut Self> {
self.try_receive()
.await?
.ok_or(io::ErrorKind::UnexpectedEof)?;
let mut self_ = Self {
stream: BufStream::new(stream),
capabilities: Capabilities::empty(),
rbuf: Vec::with_capacity(8192),
next_seq_no: 0,
statement_cache: StatementCache::new(),
ready: true,
};
Ok(self)
}
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html
// https://mariadb.com/kb/en/connection/
/// Returns a reference to the most recently received packet data
#[inline]
pub(crate) fn packet(&self) -> &[u8] {
&self.packet[..self.packet_len]
}
// First, we receive the Handshake
/// Receive an [EofPacket] if we are supposed to receive them at all.
pub(crate) async fn receive_eof(&mut self) -> crate::Result<()> {
// When (legacy) EOFs are enabled, many things are terminated by an EOF packet
if !self.capabilities.contains(Capabilities::DEPRECATE_EOF) {
let _eof = EofPacket::decode(self.receive().await?.packet())?;
}
let handshake_packet = self_.receive().await?;
let handshake = Handshake::decode(handshake_packet)?;
Ok(())
}
let mut client_capabilities =
Capabilities::PROTOCOL_41 | Capabilities::IGNORE_SPACE | Capabilities::FOUND_ROWS;
/// Receive a [Handshake] packet. When connecting to the database server, this is immediately
/// received from the database server.
pub(crate) async fn receive_handshake(&mut self, url: &Url) -> crate::Result<Handshake> {
let handshake = Handshake::decode(self.receive().await?.packet())?;
let mut client_capabilities = Capabilities::PROTOCOL_41
| Capabilities::IGNORE_SPACE
| Capabilities::FOUND_ROWS
| Capabilities::PLUGIN_AUTH;
if url.database().is_some() {
client_capabilities |= Capabilities::CONNECT_WITH_DB;
}
// Fails if [Capabilities::PROTOCOL_41] is not in [server_capabilities]
self_.capabilities =
self.capabilities =
(client_capabilities & handshake.server_capabilities) | Capabilities::PROTOCOL_41;
// Next we send the response
Ok(handshake)
}
self_.write(HandshakeResponse {
client_collation: 192, // utf8_unicode_ci
max_packet_size: 1024,
username: url.username().unwrap_or("root"),
database: url.database(),
auth_plugin_name: handshake.auth_plugin_name.as_deref(),
auth_response: None,
});
/// Receives an [OkPacket] from the database server. This is called at the end of
/// authentication to confirm the established connection.
pub(crate) fn receive_auth_ok<'a>(
&'a mut self,
plugin: &'a AuthPlugin,
password: &'a str,
nonce: &'a [u8],
) -> BoxFuture<'a, crate::Result<()>> {
Box::pin(async move {
self.receive().await?;
self_.stream.flush().await?;
match self.packet[0] {
0x00 => self.handle_ok().map(drop),
0xfe => self.handle_auth_switch(password).await,
0xff => self.handle_err(),
let _ok = self_.receive_ok().await?;
_ => self.handle_auth_continue(plugin, password, nonce).await,
}
})
}
}
impl MySqlConnection {
pub(crate) fn handle_ok(&mut self) -> crate::Result<OkPacket> {
let ok = OkPacket::decode(self.packet())?;
// An OK signifies the end of the current command sequence
self.next_seq_no = 0;
Ok(ok)
}
pub(crate) fn handle_err<T>(&mut self) -> crate::Result<T> {
let err = ErrPacket::decode(self.packet())?;
// An ERR signifies the end of the current command sequence
self.next_seq_no = 0;
Err(MySqlError(err).into())
}
pub(crate) fn handle_unexpected_packet<T>(&self, id: u8) -> crate::Result<T> {
Err(protocol_err!("unexpected packet identifier 0x{:X?}", id).into())
}
pub(crate) async fn handle_auth_continue(
&mut self,
plugin: &AuthPlugin,
password: &str,
nonce: &[u8],
) -> crate::Result<()> {
match plugin {
AuthPlugin::CachingSha2Password => {
if self.packet[0] == 1 {
match self.packet[1] {
// AUTH_OK
0x03 => {}
// AUTH_CONTINUE
0x04 => {
// client sends an RSA encrypted password
let ct = self.rsa_encrypt(0x02, password, nonce).await?;
self.send(&*ct).await?;
}
auth => {
return Err(protocol_err!("unexpected result from 'fast' authentication 0x{:x} when expecting OK (0x03) or CONTINUE (0x04)", auth).into());
}
}
// ends with server sending either OK_Packet or ERR_Packet
self.receive_auth_ok(plugin, password, nonce)
.await
.map(drop)
} else {
return self.handle_unexpected_packet(self.packet[0]);
}
}
// No other supported auth methods will be called through continue
_ => unreachable!(),
}
}
pub(crate) async fn handle_auth_switch(&mut self, password: &str) -> crate::Result<()> {
let auth = AuthSwitch::decode(self.packet())?;
let auth_response = self
.make_auth_initial_response(&auth.auth_plugin, password, &auth.auth_plugin_data)
.await?;
self.send(&*auth_response).await?;
self.receive_auth_ok(&auth.auth_plugin, password, &auth.auth_plugin_data)
.await
}
pub(crate) async fn make_auth_initial_response(
&mut self,
plugin: &AuthPlugin,
password: &str,
nonce: &[u8],
) -> crate::Result<Vec<u8>> {
match plugin {
AuthPlugin::CachingSha2Password | AuthPlugin::MySqlNativePassword => {
Ok(plugin.scramble(password, nonce))
}
AuthPlugin::Sha256Password => {
// Full RSA exchange and password encrypt up front with no "cache"
Ok(self.rsa_encrypt(0x01, password, nonce).await?.into_vec())
}
}
}
pub(crate) async fn rsa_encrypt(
&mut self,
public_key_request_id: u8,
password: &str,
nonce: &[u8],
) -> crate::Result<Box<[u8]>> {
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/
// TODO: Handle SSL
// client sends a public key request
self.send(&[public_key_request_id][..]).await?;
// server sends a public key response
let mut packet = self.receive().await?.packet();
let rsa_pub_key = &packet[1..];
// The password string data must be NUL terminated
// Note: This is not in the documentation that I could find
let mut pass = password.as_bytes().to_vec();
pass.push(0);
xor_eq(&mut pass, nonce);
// client sends an RSA encrypted password
rsa::encrypt::<Sha1>(rsa_pub_key, &pass)
}
}
impl MySqlConnection {
async fn new(url: &Url) -> crate::Result<Self> {
let stream = TcpStream::connect((url.host(), url.port(3306))).await?;
Ok(Self {
stream: BufStream::new(stream),
capabilities: Capabilities::empty(),
packet: Vec::with_capacity(8192),
packet_len: 0,
next_seq_no: 0,
statement_cache: StatementCache::new(),
})
}
async fn initialize(&mut self) -> crate::Result<()> {
// On connect, we want to establish a modern, Rust-compatible baseline so we
// tweak connection options to enable UTC for TIMESTAMP, UTF-8 for character types, etc.
@ -194,25 +352,72 @@ impl MySqlConnection {
// NO_ZERO_DATE - Don't allow '0000-00-00'. This is invalid in Rust.
// NO_ZERO_IN_DATE - Don't allow 'yyyy-00-00'. This is invalid in Rust.
// NO_ZERO_IN_DATE - Don't allow 'YYYY-00-00'. This is invalid in Rust.
self_.send("SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'))")
// language=MySQL
self.execute_raw("SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'))")
.await?;
// This allows us to assume that the output from a TIMESTAMP field is UTC
self_.send("SET time_zone = 'UTC'").await?;
// language=MySQL
self.execute_raw("SET time_zone = 'UTC'").await?;
// https://mathiasbynens.be/notes/mysql-utf8mb4
self_
.send("SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci")
// language=MySQL
self.execute_raw("SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci")
.await?;
Ok(())
}
}
impl MySqlConnection {
pub(super) async fn open(url: crate::Result<Url>) -> crate::Result<Self> {
let url = url?;
let mut self_ = Self::new(&url).await?;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase.html
// https://mariadb.com/kb/en/connection/
// On connect, server immediately sends the handshake
let handshake = self_.receive_handshake(&url).await?;
// Pre-generate an auth response by using the auth method in the [Handshake]
let password = url.password().unwrap_or_default();
let auth_response = self_
.make_auth_initial_response(
&handshake.auth_plugin,
password,
&handshake.auth_plugin_data,
)
.await?;
self_
.send_handshake_response(&url, &handshake.auth_plugin, &auth_response)
.await?;
// After sending the handshake response with our assumed auth method the server
// will send OK, fail, or tell us to change auth methods
self_
.receive_auth_ok(
&handshake.auth_plugin,
password,
&handshake.auth_plugin_data,
)
.await?;
// After the connection is established, we initialize by configuring a few
// connection parameters
self_.initialize().await?;
Ok(self_)
}
async fn close(mut self) -> crate::Result<()> {
// TODO: Actually tell MySQL that we're closing
self.stream.flush().await?;
self.stream.stream.shutdown(Shutdown::Both)?;

View file

@ -26,7 +26,7 @@ enum OkOrResultSet {
impl MySqlConnection {
async fn ignore_columns(&mut self, count: usize) -> crate::Result<()> {
for _ in 0..count {
let _column = ColumnDefinition::decode(self.receive().await?)?;
let _column = ColumnDefinition::decode(self.receive().await?.packet())?;
}
if count > 0 {
@ -37,35 +37,15 @@ impl MySqlConnection {
}
async fn receive_ok_or_column_count(&mut self) -> crate::Result<OkOrResultSet> {
let packet = self.receive().await?;
self.receive().await?;
match packet[0] {
0xfe if packet.len() < 0xffffff => {
let ok = OkPacket::decode(packet)?;
self.ready = true;
match self.packet[0] {
0x00 | 0xfe if self.packet.len() < 0xffffff => self.handle_ok().map(OkOrResultSet::Ok),
0xff => self.handle_err(),
Ok(OkOrResultSet::Ok(ok))
}
0x00 => {
let ok = OkPacket::decode(packet)?;
self.ready = true;
Ok(OkOrResultSet::Ok(ok))
}
0xff => {
let err = ErrPacket::decode(packet)?;
self.ready = true;
Err(MySqlError(err).into())
}
_ => {
let cc = ColumnCount::decode(packet)?;
Ok(OkOrResultSet::ResultSet(cc))
}
_ => Ok(OkOrResultSet::ResultSet(ColumnCount::decode(
self.packet(),
)?)),
}
}
@ -73,8 +53,8 @@ impl MySqlConnection {
let mut columns: Vec<Type> = Vec::with_capacity(count);
for _ in 0..count {
let packet = self.receive().await?;
let column: ColumnDefinition = ColumnDefinition::decode(packet)?;
let column: ColumnDefinition =
ColumnDefinition::decode(self.receive().await?.packet())?;
columns.push(column.r#type);
}
@ -87,7 +67,7 @@ impl MySqlConnection {
}
async fn wait_for_ready(&mut self) -> crate::Result<()> {
if !self.ready {
if self.next_seq_no != 0 {
while let Some(_step) = self.step(&[], true).await? {
// Drain steps until we hit the end
}
@ -98,21 +78,19 @@ impl MySqlConnection {
async fn prepare(&mut self, query: &str) -> crate::Result<ComStmtPrepareOk> {
// Start by sending a COM_STMT_PREPARE
self.begin_command_phase();
self.write(ComStmtPrepare { query });
self.stream.flush().await?;
self.send(ComStmtPrepare { query }).await?;
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_com_stmt_prepare.html
// First we should receive a COM_STMT_PREPARE_OK
let packet = self.receive().await?;
self.receive().await?;
if packet[0] == 0xff {
if self.packet[0] == 0xff {
// Oops, there was an error in the prepare command
return Err(MySqlError(ErrPacket::decode(packet)?).into());
return self.handle_err();
}
ComStmtPrepareOk::decode(packet)
ComStmtPrepareOk::decode(self.packet())
}
async fn prepare_with_cache(&mut self, query: &str) -> crate::Result<u32> {
@ -132,7 +110,7 @@ impl MySqlConnection {
let mut columns = HashMap::with_capacity(prepare_ok.columns as usize);
let mut index = 0_usize;
for _ in 0..prepare_ok.columns {
let column = ColumnDefinition::decode(self.receive().await?)?;
let column = ColumnDefinition::decode(self.receive().await?.packet())?;
if let Some(name) = column.column_alias.or(column.column) {
columns.insert(name, index);
@ -145,6 +123,9 @@ impl MySqlConnection {
self.receive_eof().await?;
}
// At the end of a command, this should go back to 0
self.next_seq_no = 0;
// Remember our column map in the statement cache
self.statement_cache
.put_columns(prepare_ok.statement_id, columns);
@ -155,73 +136,59 @@ impl MySqlConnection {
// [COM_STMT_EXECUTE]
async fn execute_statement(&mut self, id: u32, args: MySqlArguments) -> crate::Result<()> {
self.begin_command_phase();
self.ready = false;
self.write(ComStmtExecute {
self.send(ComStmtExecute {
cursor: Cursor::NO_CURSOR,
statement_id: id,
params: &args.params,
null_bitmap: &args.null_bitmap,
param_types: &args.param_types,
});
self.stream.flush().await?;
Ok(())
})
.await
}
async fn step(&mut self, columns: &[Type], binary: bool) -> crate::Result<Option<Step>> {
let capabilities = self.capabilities;
let packet = ret_if_none!(self.try_receive().await?);
match packet[0] {
0xfe if packet.len() < 0xffffff => {
// Resultset row can begin with 0xfe byte (when using text protocol
match self.packet[0] {
0xfe if self.packet.len() < 0xffffff => {
// ResultSet row can begin with 0xfe byte (when using text protocol
// with a field length > 0xffffff)
if !capabilities.contains(Capabilities::DEPRECATE_EOF) {
let _eof = EofPacket::decode(packet)?;
self.ready = true;
let _eof = EofPacket::decode(self.packet())?;
return Ok(None);
// An EOF -here- signifies the end of the current command sequence
self.next_seq_no = 0;
Ok(None)
} else {
let ok = OkPacket::decode(packet)?;
self.ready = true;
return Ok(Some(Step::Command(ok.affected_rows)));
self.handle_ok()
.map(|ok| Some(Step::Command(ok.affected_rows)))
}
}
0xff => {
let err = ErrPacket::decode(packet)?;
self.ready = true;
0xff => self.handle_err(),
return Err(MySqlError(err).into());
}
_ => {
return Ok(Some(Step::Row(Row::decode(packet, columns, binary)?)));
}
_ => Ok(Some(Step::Row(Row::decode(
self.packet(),
columns,
binary,
)?))),
}
}
}
impl MySqlConnection {
async fn send(&mut self, query: &str) -> crate::Result<()> {
pub(super) async fn execute_raw(&mut self, query: &str) -> crate::Result<()> {
self.wait_for_ready().await?;
self.begin_command_phase();
self.ready = false;
// enable multi-statement only for this query
self.write(ComQuery { query });
self.stream.flush().await?;
self.send(ComQuery { query }).await?;
// COM_QUERY can terminate before the result set with an ERR or OK packet
let num_columns = match self.receive_ok_or_column_count().await? {
OkOrResultSet::Ok(_) => {
self.next_seq_no = 0;
return Ok(());
}
@ -247,6 +214,8 @@ impl MySqlConnection {
// COM_STMT_EXECUTE can terminate before the result set with an ERR or OK packet
let num_columns = match self.receive_ok_or_column_count().await? {
OkOrResultSet::Ok(ok) => {
self.next_seq_no = 0;
return Ok(ok.affected_rows);
}
@ -275,7 +244,7 @@ impl MySqlConnection {
let mut result_columns = Vec::with_capacity(prepare_ok.columns as usize);
for _ in 0..prepare_ok.params {
let param = ColumnDefinition::decode(self.receive().await?)?;
let param = ColumnDefinition::decode(self.receive().await?.packet())?;
param_types.push(param.r#type.0);
}
@ -284,7 +253,7 @@ impl MySqlConnection {
}
for _ in 0..prepare_ok.columns {
let column = ColumnDefinition::decode(self.receive().await?)?;
let column = ColumnDefinition::decode(self.receive().await?.packet())?;
result_columns.push(Column::<MySql> {
name: column.column_alias.or(column.column),
@ -298,6 +267,9 @@ impl MySqlConnection {
self.receive_eof().await?;
}
// Command sequence is over
self.next_seq_no = 0;
Ok(Describe {
param_types: param_types.into_boxed_slice(),
result_columns: result_columns.into_boxed_slice(),
@ -321,6 +293,7 @@ impl MySqlConnection {
// COM_STMT_EXECUTE can terminate before the result set with an ERR or OK packet
let num_columns = match self.receive_ok_or_column_count().await? {
OkOrResultSet::Ok(_) => {
self.next_seq_no = 0;
return;
}
@ -342,7 +315,7 @@ impl Executor for MySqlConnection {
type Database = super::MySql;
fn send<'e, 'q: 'e>(&'e mut self, query: &'q str) -> BoxFuture<'e, crate::Result<()>> {
Box::pin(self.send(query))
Box::pin(self.execute_raw(query))
}
fn execute<'e, 'q: 'e>(

View file

@ -8,7 +8,9 @@ mod executor;
mod io;
mod protocol;
mod row;
mod rsa;
mod types;
mod util;
pub use database::MySql;

View file

@ -0,0 +1,100 @@
use digest::{Digest, FixedOutput};
use generic_array::GenericArray;
use sha1::Sha1;
use sha2::Sha256;
use crate::mysql::util::xor_eq;
#[derive(Debug)]
pub enum AuthPlugin {
MySqlNativePassword,
CachingSha2Password,
Sha256Password,
}
impl AuthPlugin {
pub(crate) fn from_opt_str(s: Option<&str>) -> crate::Result<AuthPlugin> {
match s {
Some("mysql_native_password") | None => Ok(AuthPlugin::MySqlNativePassword),
Some("caching_sha2_password") => Ok(AuthPlugin::CachingSha2Password),
Some("sha256_password") => Ok(AuthPlugin::Sha256Password),
Some(s) => {
Err(protocol_err!("requires unimplemented authentication plugin: {}", s).into())
}
}
}
pub(crate) fn as_str(&self) -> &'static str {
match self {
AuthPlugin::MySqlNativePassword => "mysql_native_password",
AuthPlugin::CachingSha2Password => "caching_sha2_password",
AuthPlugin::Sha256Password => "sha256_password",
}
}
pub(crate) fn scramble(&self, password: &str, nonce: &[u8]) -> Vec<u8> {
match self {
AuthPlugin::MySqlNativePassword => {
// The [nonce] for mysql_native_password is nul terminated
scramble_sha1(password, &nonce[..(nonce.len() - 1)]).to_vec()
}
AuthPlugin::CachingSha2Password => scramble_sha256(password, nonce).to_vec(),
_ => unimplemented!(),
}
}
}
fn scramble_sha1(
password: &str,
seed: &[u8],
) -> GenericArray<u8, <Sha1 as FixedOutput>::OutputSize> {
// SHA1( password ) ^ SHA1( seed + SHA1( SHA1( password ) ) )
// https://mariadb.com/kb/en/connection/#mysql_native_password-plugin
let mut ctx = Sha1::new();
ctx.input(password);
let mut pw_hash = ctx.result_reset();
ctx.input(&pw_hash);
let pw_hash_hash = ctx.result_reset();
ctx.input(seed);
ctx.input(pw_hash_hash);
let pw_seed_hash_hash = ctx.result();
xor_eq(&mut pw_hash, &pw_seed_hash_hash);
pw_hash
}
fn scramble_sha256(
password: &str,
seed: &[u8],
) -> GenericArray<u8, <Sha256 as FixedOutput>::OutputSize> {
// XOR(SHA256(password), SHA256(seed, SHA256(SHA256(password))))
// https://mariadb.com/kb/en/caching_sha2_password-authentication-plugin/#sha-2-encrypted-password
let mut ctx = Sha256::new();
ctx.input(password);
let mut pw_hash = ctx.result_reset();
ctx.input(&pw_hash);
let pw_hash_hash = ctx.result_reset();
ctx.input(seed);
ctx.input(pw_hash_hash);
let pw_seed_hash_hash = ctx.result();
xor_eq(&mut pw_hash, &pw_seed_hash_hash);
pw_hash
}

View file

@ -0,0 +1,34 @@
use byteorder::LittleEndian;
use crate::io::Buf;
use crate::mysql::protocol::{AuthPlugin, Capabilities, Decode, Status};
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_auth_switch_request.html
#[derive(Debug)]
pub struct AuthSwitch {
pub auth_plugin: AuthPlugin,
pub auth_plugin_data: Box<[u8]>,
}
impl Decode for AuthSwitch {
fn decode(mut buf: &[u8]) -> crate::Result<Self>
where
Self: Sized,
{
let header = buf.get_u8()?;
if header != 0xFE {
return Err(protocol_err!(
"expected AUTH SWITCH (0xFE); received 0x{:X}",
header
))?;
}
let auth_plugin = AuthPlugin::from_opt_str(Some(buf.get_str_nul()?))?;
let auth_plugin_data = buf.get_bytes(buf.len())?.to_owned().into_boxed_slice();
Ok(Self {
auth_plugin_data,
auth_plugin,
})
}
}

View file

@ -1,5 +1,12 @@
use crate::io::BufMut;
use crate::mysql::protocol::Capabilities;
pub trait Encode {
fn encode(&self, buf: &mut Vec<u8>, capabilities: Capabilities);
}
impl Encode for &'_ [u8] {
fn encode(&self, buf: &mut Vec<u8>, _: Capabilities) {
buf.put_bytes(self);
}
}

View file

@ -34,19 +34,3 @@ impl Decode for EofPacket {
})
}
}
//#[cfg(test)]
//mod tests {
// use super::{Capabilities, Decode, ErrPacket, Status};
//
// const ERR_HANDSHAKE_UNKNOWN_DB: &[u8] = b"\xff\x19\x04#42000Unknown database \'unknown\'";
//
// #[test]
// fn it_decodes_ok_handshake() {
// let mut p = ErrPacket::decode(ERR_HANDSHAKE_UNKNOWN_DB).unwrap();
//
// assert_eq!(p.error_code, 1049);
// assert_eq!(&*p.sql_state, "42000");
// assert_eq!(&*p.error_message, "Unknown database \'unknown\'");
// }
//}

View file

@ -1,7 +1,7 @@
use byteorder::LittleEndian;
use crate::io::Buf;
use crate::mysql::protocol::{Capabilities, Decode, Status};
use crate::mysql::protocol::{AuthPlugin, Capabilities, Decode, Status};
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_v10.html
// https://mariadb.com/kb/en/connection/#initial-handshake-packet
@ -13,7 +13,7 @@ pub struct Handshake {
pub server_capabilities: Capabilities,
pub server_default_collation: u8,
pub status: Status,
pub auth_plugin_name: Option<Box<str>>,
pub auth_plugin: AuthPlugin,
pub auth_plugin_data: Box<[u8]>,
}
@ -81,10 +81,10 @@ impl Decode for Handshake {
buf.advance(1);
}
let auth_plugin_name = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
Some(buf.get_str_nul()?.to_owned().into())
let auth_plugin = if capabilities.contains(Capabilities::PLUGIN_AUTH) {
AuthPlugin::from_opt_str(Some(buf.get_str_nul()?))?
} else {
None
AuthPlugin::from_opt_str(None)?
};
Ok(Self {
@ -94,7 +94,7 @@ impl Decode for Handshake {
server_default_collation: char_set,
connection_id,
auth_plugin_data: scramble.into_boxed_slice(),
auth_plugin_name,
auth_plugin,
status,
})
}
@ -102,7 +102,8 @@ impl Decode for Handshake {
#[cfg(test)]
mod tests {
use super::{Capabilities, Decode, Handshake, Status};
use super::{AuthPlugin, Capabilities, Decode, Handshake, Status};
use matches::assert_matches;
const HANDSHAKE_MARIA_DB_10_4_7: &[u8] = b"\n5.5.5-10.4.7-MariaDB-1:10.4.7+maria~bionic\x00\x0b\x00\x00\x00t6L\\j\"dS\x00\xfe\xf7\x08\x02\x00\xff\x81\x15\x00\x00\x00\x00\x00\x00\x07\x00\x00\x00U14Oph9\"<H5n\x00mysql_native_password\x00";
const HANDSHAKE_MYSQL_8_0_18: &[u8] = b"\n8.0.18\x00\x19\x00\x00\x00\x114aB0c\x06g\x00\xff\xff\xff\x02\x00\xff\xc7\x15\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00tL\x03s\x0f[4\rl4. \x00caching_sha2_password\x00";
@ -147,7 +148,7 @@ mod tests {
assert_eq!(p.server_default_collation, 255);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert_eq!(p.auth_plugin_name.as_deref(), Some("caching_sha2_password"));
assert_matches!(p.auth_plugin, AuthPlugin::CachingSha2Password);
assert_eq!(
&*p.auth_plugin_data,
@ -195,7 +196,7 @@ mod tests {
assert_eq!(p.server_default_collation, 8);
assert!(p.status.contains(Status::SERVER_STATUS_AUTOCOMMIT));
assert_eq!(p.auth_plugin_name.as_deref(), Some("mysql_native_password"));
assert_matches!(p.auth_plugin, AuthPlugin::MySqlNativePassword);
assert_eq!(
&*p.auth_plugin_data,

View file

@ -2,7 +2,7 @@ use byteorder::LittleEndian;
use crate::io::BufMut;
use crate::mysql::io::BufMutExt;
use crate::mysql::protocol::{Capabilities, Encode};
use crate::mysql::protocol::{AuthPlugin, Capabilities, Encode};
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_connection_phase_packets_protocol_handshake_response.html
// https://mariadb.com/kb/en/connection/#handshake-response-packet
@ -12,8 +12,8 @@ pub struct HandshakeResponse<'a> {
pub client_collation: u8,
pub username: &'a str,
pub database: Option<&'a str>,
pub auth_plugin_name: Option<&'a str>,
pub auth_response: Option<&'a str>,
pub auth_plugin: &'a AuthPlugin,
pub auth_response: &'a [u8],
}
impl Encode for HandshakeResponse<'_> {
@ -43,15 +43,15 @@ impl Encode for HandshakeResponse<'_> {
if capabilities.contains(Capabilities::PLUGIN_AUTH_LENENC_DATA) {
// auth_response : string<lenenc>
buf.put_str_lenenc::<LittleEndian>(self.auth_response.unwrap_or_default());
buf.put_bytes_lenenc::<LittleEndian>(self.auth_response);
} else {
let auth_response = self.auth_response.unwrap_or_default();
let auth_response = self.auth_response;
// auth_response_length : int<1>
buf.put_u8(auth_response.len() as u8);
// auth_response : string<{auth_response_length}>
buf.put_str(auth_response);
buf.put_bytes(auth_response);
}
if capabilities.contains(Capabilities::CONNECT_WITH_DB) {
@ -63,7 +63,7 @@ impl Encode for HandshakeResponse<'_> {
if capabilities.contains(Capabilities::PLUGIN_AUTH) {
// client_plugin_name : string<NUL>
buf.put_str_nul(self.auth_plugin_name.unwrap_or_default());
buf.put_str_nul(self.auth_plugin.as_str());
}
}
}

View file

@ -8,11 +8,13 @@ mod encode;
pub use decode::Decode;
pub use encode::Encode;
mod auth_plugin;
mod capabilities;
mod field;
mod status;
mod r#type;
pub use auth_plugin::AuthPlugin;
pub use capabilities::Capabilities;
pub use field::FieldFlags;
pub use r#type::Type;
@ -30,6 +32,7 @@ pub use com_stmt_execute::{ComStmtExecute, Cursor};
pub use com_stmt_prepare::ComStmtPrepare;
pub use handshake::Handshake;
mod auth_switch;
mod column_count;
mod column_def;
mod com_stmt_prepare_ok;
@ -39,6 +42,7 @@ mod handshake_response;
mod ok;
mod row;
pub use auth_switch::AuthSwitch;
pub use column_count::ColumnCount;
pub use column_def::ColumnDefinition;
pub use com_stmt_prepare_ok::ComStmtPrepareOk;

265
sqlx-core/src/mysql/rsa.rs Normal file
View file

@ -0,0 +1,265 @@
use digest::{Digest, DynDigest};
use num_bigint::BigUint;
use rand::{thread_rng, Rng};
// This is mostly taken from https://github.com/RustCrypto/RSA/pull/18
// For the love of crypto, please delete as much of this as possible and use the RSA crate
// directly when that PR is merged
pub fn encrypt<D: Digest>(key: &[u8], message: &[u8]) -> crate::Result<Box<[u8]>> {
let key = std::str::from_utf8(key).map_err(|_err| {
// TODO(@abonander): protocol_err doesn't like referring to [err]
protocol_err!("unexpected error decoding what should be UTF-8")
})?;
let key = parse(key)?;
Ok(oaep_encrypt::<_, D>(&mut thread_rng(), &key, message)?.into_boxed_slice())
}
// https://github.com/RustCrypto/RSA/blob/9f1464c43831d422d9903574aad6ab072db9f2b0/src/internals.rs#L12
fn internals_encrypt(key: &PublicKey, m: &BigUint) -> BigUint {
m.modpow(&key.e, &key.n)
}
// https://github.com/RustCrypto/RSA/blob/9f1464c43831d422d9903574aad6ab072db9f2b0/src/internals.rs#L184
fn internals_copy_with_left_pad(dest: &mut [u8], src: &[u8]) {
// left pad with zeros
let padding_bytes = dest.len() - src.len();
for el in dest.iter_mut().take(padding_bytes) {
*el = 0;
}
dest[padding_bytes..].copy_from_slice(src);
}
// https://github.com/RustCrypto/RSA/blob/9f1464c43831d422d9903574aad6ab072db9f2b0/src/oaep.rs#L13
fn internals_inc_counter(counter: &mut [u8]) {
if counter[3] == u8::max_value() {
counter[3] = 0;
} else {
counter[3] += 1;
return;
}
if counter[2] == u8::max_value() {
counter[2] = 0;
} else {
counter[2] += 1;
return;
}
if counter[1] == u8::max_value() {
counter[1] = 0;
} else {
counter[1] += 1;
return;
}
if counter[0] == u8::max_value() {
counter[0] = 0u8;
counter[1] = 0u8;
counter[2] = 0u8;
counter[3] = 0u8;
} else {
counter[0] += 1;
}
}
// https://github.com/RustCrypto/RSA/blob/9f1464c43831d422d9903574aad6ab072db9f2b0/src/oaep.rs#L46
fn oeap_mgf1_xor<D: Digest>(out: &mut [u8], digest: &mut D, seed: &[u8]) {
let mut counter = vec![0u8; 4];
let mut i = 0;
while i < out.len() {
let mut digest_input = vec![0u8; seed.len() + 4];
digest_input[0..seed.len()].copy_from_slice(seed);
digest_input[seed.len()..].copy_from_slice(&counter);
digest.input(digest_input.as_slice());
let digest_output = &*digest.result_reset();
let mut j = 0;
loop {
if j >= digest_output.len() || i >= out.len() {
break;
}
out[i] ^= digest_output[j];
j += 1;
i += 1;
}
internals_inc_counter(counter.as_mut_slice());
}
}
// https://github.com/RustCrypto/RSA/blob/9f1464c43831d422d9903574aad6ab072db9f2b0/src/oaep.rs#L75
fn oaep_encrypt<R: Rng, D: Digest>(
rng: &mut R,
pub_key: &PublicKey,
msg: &[u8],
) -> crate::Result<Vec<u8>> {
// size of [n] in bytes
let k = (pub_key.n.bits() + 7) / 8;
let mut digest = D::new();
let h_size = D::output_size();
if msg.len() > k - 2 * h_size - 2 {
return Err(protocol_err!("mysql: password too long").into());
}
let mut em = vec![0u8; k];
let (_, payload) = em.split_at_mut(1);
let (seed, db) = payload.split_at_mut(h_size);
rng.fill(seed);
// Data block DB = pHash || PS || 01 || M
let db_len = k - h_size - 1;
let p_hash = digest.result_reset();
db[0..h_size].copy_from_slice(&*p_hash);
db[db_len - msg.len() - 1] = 1;
db[db_len - msg.len()..].copy_from_slice(msg);
oeap_mgf1_xor(db, &mut digest, seed);
oeap_mgf1_xor(seed, &mut digest, db);
{
let mut m = BigUint::from_bytes_be(&em);
let mut c = internals_encrypt(pub_key, &m).to_bytes_be();
internals_copy_with_left_pad(&mut em, &c);
}
Ok(em)
}
#[derive(Debug)]
struct PublicKey {
n: BigUint,
e: BigUint,
}
fn parse(key: &str) -> crate::Result<PublicKey> {
// This takes advantage of the knowledge that we know
// we are receiving a PKCS#8 RSA Public Key at all
// times from MySQL
if !key.starts_with("-----BEGIN PUBLIC KEY-----\n") {
return Err(protocol_err!(
"unexpected format for RSA Public Key from MySQL (expected PKCS#8); first line: {:?}",
key.splitn(1, '\n').next()
)
.into());
}
let key_with_trailer = key.trim_start_matches("-----BEGIN PUBLIC KEY-----\n");
let trailer_pos = key_with_trailer.find('-').unwrap_or(0);
let inner_key = key_with_trailer[..trailer_pos].replace('\n', "");
let inner = base64::decode(&inner_key).map_err(|_err| {
// TODO(@abonander): protocol_err doesn't like referring to [err]
protocol_err!("unexpected error decoding what should be base64-encoded data")
})?;
let len = inner.len();
let n_bytes = &inner[(len - 257 - 5)..(len - 5)];
let e_bytes = &inner[(len - 3)..];
let n = BigUint::from_bytes_be(n_bytes);
let e = BigUint::from_bytes_be(e_bytes);
Ok(PublicKey { n, e })
}
#[cfg(test)]
mod tests {
use super::{BigUint, PublicKey};
use rand::rngs::adapter::ReadRng;
use sha1::Sha1;
use sha2::Sha256;
const INPUT: &str = "-----BEGIN PUBLIC KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAv9E+l0oFIoGnZmu6bdil\nI3WK79iug/hukj5QrWRrJVVCHL8rRxNsQGYPvQfXgqEnJW0Rqy2BBebNrnSMduny\nCazz1KM1h57hSI1xHGhg/o82Us1j9fUucKo0Pt3vg7xjVVcN0j1bwr96gEbt6B4Q\nt4eKZBhtle1bgoBcqFBhGfU17cnedSzMUCutM+kXTzzOTplKoqXeJpEZDTX8AP9F\nQ9JkoA22yTn8H2GROIAffm1UQS7DXXjI5OnzBJNs72oNSeK8i72xLkoSdfVw3vCu\ni+mpt4LJgAZLvzc2O4nLzu4Bljb+Mrch34HSWyxOfWzt1v9vpJfEVQ2/VZaIng6U\nUQIDAQAB\n-----END PUBLIC KEY-----\n";
#[test]
fn it_parses() {
let key = super::parse(INPUT).unwrap();
let n = &[
0xbf, 0xd1, 0x3e, 0x97, 0x4a, 0x5, 0x22, 0x81, 0xa7, 0x66, 0x6b, 0xba, 0x6d, 0xd8,
0xa5, 0x23, 0x75, 0x8a, 0xef, 0xd8, 0xae, 0x83, 0xf8, 0x6e, 0x92, 0x3e, 0x50, 0xad,
0x64, 0x6b, 0x25, 0x55, 0x42, 0x1c, 0xbf, 0x2b, 0x47, 0x13, 0x6c, 0x40, 0x66, 0xf,
0xbd, 0x7, 0xd7, 0x82, 0xa1, 0x27, 0x25, 0x6d, 0x11, 0xab, 0x2d, 0x81, 0x5, 0xe6, 0xcd,
0xae, 0x74, 0x8c, 0x76, 0xe9, 0xf2, 0x9, 0xac, 0xf3, 0xd4, 0xa3, 0x35, 0x87, 0x9e,
0xe1, 0x48, 0x8d, 0x71, 0x1c, 0x68, 0x60, 0xfe, 0x8f, 0x36, 0x52, 0xcd, 0x63, 0xf5,
0xf5, 0x2e, 0x70, 0xaa, 0x34, 0x3e, 0xdd, 0xef, 0x83, 0xbc, 0x63, 0x55, 0x57, 0xd,
0xd2, 0x3d, 0x5b, 0xc2, 0xbf, 0x7a, 0x80, 0x46, 0xed, 0xe8, 0x1e, 0x10, 0xb7, 0x87,
0x8a, 0x64, 0x18, 0x6d, 0x95, 0xed, 0x5b, 0x82, 0x80, 0x5c, 0xa8, 0x50, 0x61, 0x19,
0xf5, 0x35, 0xed, 0xc9, 0xde, 0x75, 0x2c, 0xcc, 0x50, 0x2b, 0xad, 0x33, 0xe9, 0x17,
0x4f, 0x3c, 0xce, 0x4e, 0x99, 0x4a, 0xa2, 0xa5, 0xde, 0x26, 0x91, 0x19, 0xd, 0x35,
0xfc, 0x0, 0xff, 0x45, 0x43, 0xd2, 0x64, 0xa0, 0xd, 0xb6, 0xc9, 0x39, 0xfc, 0x1f, 0x61,
0x91, 0x38, 0x80, 0x1f, 0x7e, 0x6d, 0x54, 0x41, 0x2e, 0xc3, 0x5d, 0x78, 0xc8, 0xe4,
0xe9, 0xf3, 0x4, 0x93, 0x6c, 0xef, 0x6a, 0xd, 0x49, 0xe2, 0xbc, 0x8b, 0xbd, 0xb1, 0x2e,
0x4a, 0x12, 0x75, 0xf5, 0x70, 0xde, 0xf0, 0xae, 0x8b, 0xe9, 0xa9, 0xb7, 0x82, 0xc9,
0x80, 0x6, 0x4b, 0xbf, 0x37, 0x36, 0x3b, 0x89, 0xcb, 0xce, 0xee, 0x1, 0x96, 0x36, 0xfe,
0x32, 0xb7, 0x21, 0xdf, 0x81, 0xd2, 0x5b, 0x2c, 0x4e, 0x7d, 0x6c, 0xed, 0xd6, 0xff,
0x6f, 0xa4, 0x97, 0xc4, 0x55, 0xd, 0xbf, 0x55, 0x96, 0x88, 0x9e, 0xe, 0x94, 0x51,
][..];
let e = &[0x1, 0x0, 0x1][..];
assert_eq!(key.n.to_bytes_be(), n);
assert_eq!(key.e.to_bytes_be(), e);
}
#[test]
fn it_encrypts_sha1() {
// https://github.com/pyca/cryptography/blob/master/vectors/cryptography_vectors/asymmetric/RSA/pkcs-1v2-1d2-vec/oaep-int.txt
let n = BigUint::from_bytes_be(&[
0xbb, 0xf8, 0x2f, 0x09, 0x06, 0x82, 0xce, 0x9c, 0x23, 0x38, 0xac, 0x2b, 0x9d, 0xa8,
0x71, 0xf7, 0x36, 0x8d, 0x07, 0xee, 0xd4, 0x10, 0x43, 0xa4, 0x40, 0xd6, 0xb6, 0xf0,
0x74, 0x54, 0xf5, 0x1f, 0xb8, 0xdf, 0xba, 0xaf, 0x03, 0x5c, 0x02, 0xab, 0x61, 0xea,
0x48, 0xce, 0xeb, 0x6f, 0xcd, 0x48, 0x76, 0xed, 0x52, 0x0d, 0x60, 0xe1, 0xec, 0x46,
0x19, 0x71, 0x9d, 0x8a, 0x5b, 0x8b, 0x80, 0x7f, 0xaf, 0xb8, 0xe0, 0xa3, 0xdf, 0xc7,
0x37, 0x72, 0x3e, 0xe6, 0xb4, 0xb7, 0xd9, 0x3a, 0x25, 0x84, 0xee, 0x6a, 0x64, 0x9d,
0x06, 0x09, 0x53, 0x74, 0x88, 0x34, 0xb2, 0x45, 0x45, 0x98, 0x39, 0x4e, 0xe0, 0xaa,
0xb1, 0x2d, 0x7b, 0x61, 0xa5, 0x1f, 0x52, 0x7a, 0x9a, 0x41, 0xf6, 0xc1, 0x68, 0x7f,
0xe2, 0x53, 0x72, 0x98, 0xca, 0x2a, 0x8f, 0x59, 0x46, 0xf8, 0xe5, 0xfd, 0x09, 0x1d,
0xbd, 0xcb,
]);
let e = BigUint::from_bytes_be(&[0x11]);
let pub_key = PublicKey { n, e };
let message = &[
0xd4, 0x36, 0xe9, 0x95, 0x69, 0xfd, 0x32, 0xa7, 0xc8, 0xa0, 0x5b, 0xbc, 0x90, 0xd3,
0x2c, 0x49,
];
let mut seed = &[
0xaa, 0xfd, 0x12, 0xf6, 0x59, 0xca, 0xe6, 0x34, 0x89, 0xb4, 0x79, 0xe5, 0x07, 0x6d,
0xde, 0xc2, 0xf0, 0x6c, 0xb5, 0x8f,
][..];
let mut rng = ReadRng::new(seed);
let cipher_text = super::oaep_encrypt::<_, Sha1>(&mut rng, &pub_key, message).unwrap();
let expected_cipher_text = &[
0x12, 0x53, 0xe0, 0x4d, 0xc0, 0xa5, 0x39, 0x7b, 0xb4, 0x4a, 0x7a, 0xb8, 0x7e, 0x9b,
0xf2, 0xa0, 0x39, 0xa3, 0x3d, 0x1e, 0x99, 0x6f, 0xc8, 0x2a, 0x94, 0xcc, 0xd3, 0x00,
0x74, 0xc9, 0x5d, 0xf7, 0x63, 0x72, 0x20, 0x17, 0x06, 0x9e, 0x52, 0x68, 0xda, 0x5d,
0x1c, 0x0b, 0x4f, 0x87, 0x2c, 0xf6, 0x53, 0xc1, 0x1d, 0xf8, 0x23, 0x14, 0xa6, 0x79,
0x68, 0xdf, 0xea, 0xe2, 0x8d, 0xef, 0x04, 0xbb, 0x6d, 0x84, 0xb1, 0xc3, 0x1d, 0x65,
0x4a, 0x19, 0x70, 0xe5, 0x78, 0x3b, 0xd6, 0xeb, 0x96, 0xa0, 0x24, 0xc2, 0xca, 0x2f,
0x4a, 0x90, 0xfe, 0x9f, 0x2e, 0xf5, 0xc9, 0xc1, 0x40, 0xe5, 0xbb, 0x48, 0xda, 0x95,
0x36, 0xad, 0x87, 0x00, 0xc8, 0x4f, 0xc9, 0x13, 0x0a, 0xde, 0xa7, 0x4e, 0x55, 0x8d,
0x51, 0xa7, 0x4d, 0xdf, 0x85, 0xd8, 0xb5, 0x0d, 0xe9, 0x68, 0x38, 0xd6, 0x06, 0x3e,
0x09, 0x55,
][..];
assert_eq!(&*expected_cipher_text, &*cipher_text);
}
}

View file

@ -0,0 +1,9 @@
// XOR(x, y)
// If len(y) < len(x), wrap around inside y
pub fn xor_eq(x: &mut [u8], y: &[u8]) {
let y_len = y.len();
for i in 0..x.len() {
x[i] ^= y[i % y_len];
}
}