mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 06:24:16 +00:00
Remove receiver
This commit is contained in:
parent
f107f1de6c
commit
8dcd113517
3 changed files with 223 additions and 134 deletions
|
@ -14,49 +14,45 @@ use mason_core::ConnectOptions;
|
|||
use std::io;
|
||||
use failure::Error;
|
||||
use bytes::Bytes;
|
||||
use failure::err_msg;
|
||||
|
||||
pub async fn establish<'a, 'b: 'a>(
|
||||
conn: &'a mut Connection,
|
||||
options: ConnectOptions<'b>,
|
||||
) -> Result<(), Error> {
|
||||
let init_packet = if let Some(message) = conn.incoming.next().await {
|
||||
conn.sequence_number = message.sequence_number();
|
||||
match message {
|
||||
ServerMessage::InitialHandshakePacket(message) => {
|
||||
Ok(message)
|
||||
},
|
||||
_ => Err(failure::err_msg("Incorrect First Packet")),
|
||||
}
|
||||
} else {
|
||||
Err(failure::err_msg("Failed to connect"))
|
||||
}?;
|
||||
let init_packet = InitialHandshakePacket::deserialize(&conn.stream.next_bytes().await?)?;
|
||||
|
||||
conn.server_capabilities = init_packet.capabilities;
|
||||
conn.capabilities = init_packet.capabilities;
|
||||
|
||||
let handshake: HandshakeResponsePacket = HandshakeResponsePacket {
|
||||
// Minimum client capabilities required to establish connection
|
||||
capabilities: Capabilities::CLIENT_PROTOCOL_41,
|
||||
max_packet_size: 1024,
|
||||
collation: 0,
|
||||
extended_capabilities: Some(Capabilities::from_bits_truncate(0)),
|
||||
username: Bytes::from_static(b"root"),
|
||||
auth_data: None,
|
||||
auth_response_len: None,
|
||||
auth_response: None,
|
||||
database: None,
|
||||
auth_plugin_name: None,
|
||||
conn_attr_len: None,
|
||||
conn_attr: None,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
conn.send(handshake).await?;
|
||||
|
||||
if let Some(message) = conn.incoming.next().await {
|
||||
println!("{:?}", message);
|
||||
conn.sequence_number = message.sequence_number();
|
||||
Ok(())
|
||||
} else {
|
||||
Err(failure::err_msg("Handshake Failed"))
|
||||
match conn.stream.next().await? {
|
||||
Some(ServerMessage::OkPacket(message)) => {
|
||||
println!("{:?}", message);
|
||||
conn.seq_no = message.seq_no;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Some(ServerMessage::ErrPacket(message)) => {
|
||||
Err(err_msg(format!("{:?}", message)))
|
||||
}
|
||||
|
||||
Some(message) => {
|
||||
panic!("Did not receive OkPacket nor ErrPacket");
|
||||
}
|
||||
|
||||
None => {
|
||||
panic!("Did not recieve packet");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -69,29 +65,16 @@ mod test {
|
|||
#[runtime::test]
|
||||
async fn it_connects() -> Result<(), Error> {
|
||||
let mut conn = Connection::establish(ConnectOptions {
|
||||
host: "localhost",
|
||||
host: "127.0.0.1",
|
||||
port: 3306,
|
||||
user: Some("root"),
|
||||
database: None,
|
||||
password: None,
|
||||
}).await?;
|
||||
//
|
||||
// conn.ping().await?;
|
||||
|
||||
conn.ping().await?;
|
||||
|
||||
if let Some(message) = conn.incoming.next().await {
|
||||
match message {
|
||||
ServerMessage::OkPacket(packet) => {
|
||||
conn.quit().await?;
|
||||
Ok(())
|
||||
}
|
||||
ServerMessage::ErrPacket(packet) => {
|
||||
Err(err_msg(format!("{:?}", packet)))
|
||||
}
|
||||
_ => Err(err_msg("Server Failed"))
|
||||
}
|
||||
} else {
|
||||
Err(err_msg("Server Failed"))
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -3,16 +3,18 @@ use crate::protocol::{
|
|||
client::ComQuit,
|
||||
client::ComPing,
|
||||
server::Message as ServerMessage,
|
||||
server::ServerStatusFlag,
|
||||
server::Capabilities,
|
||||
server::InitialHandshakePacket,
|
||||
server::Deserialize
|
||||
};
|
||||
use bytes::BytesMut;
|
||||
use futures::{
|
||||
channel::mpsc,
|
||||
io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
|
||||
SinkExt, StreamExt,
|
||||
io::{AsyncRead, AsyncWriteExt},
|
||||
task::{Context, Poll},
|
||||
Stream,
|
||||
};
|
||||
use futures::prelude::*;
|
||||
use mason_core::ConnectOptions;
|
||||
use runtime::{net::TcpStream, task::JoinHandle};
|
||||
use std::io;
|
||||
|
@ -21,44 +23,40 @@ use failure::err_msg;
|
|||
use byteorder::{ByteOrder, LittleEndian, WriteBytesExt};
|
||||
use crate::protocol::serialize::serialize_length;
|
||||
use bytes::BufMut;
|
||||
use bytes::Bytes;
|
||||
|
||||
mod establish;
|
||||
// mod query;
|
||||
|
||||
pub struct Connection {
|
||||
writer: WriteHalf<TcpStream>,
|
||||
incoming: mpsc::UnboundedReceiver<ServerMessage>,
|
||||
stream: Framed,
|
||||
|
||||
// Buffer used when serializing outgoing messages
|
||||
wbuf: BytesMut,
|
||||
|
||||
// Handle to coroutine reading messages from the stream
|
||||
receiver: JoinHandle<Result<(), Error>>,
|
||||
|
||||
// MariaDB Connection ID
|
||||
connection_id: i32,
|
||||
|
||||
// Sequence Number
|
||||
sequence_number: u8,
|
||||
seq_no: u8,
|
||||
|
||||
// Server Capabilities
|
||||
server_capabilities: Capabilities,
|
||||
capabilities: Capabilities,
|
||||
|
||||
// Server status
|
||||
status: ServerStatusFlag,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub async fn establish(options: ConnectOptions<'static>) -> Result<Self, Error> {
|
||||
let stream = TcpStream::connect((options.host, options.port)).await?;
|
||||
let (reader, writer) = stream.split();
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
let receiver: JoinHandle<Result<(), Error>> = runtime::spawn(receiver(reader, tx));
|
||||
let stream: Framed = Framed::new(TcpStream::connect((options.host, options.port)).await?);
|
||||
let mut conn = Self {
|
||||
stream,
|
||||
wbuf: BytesMut::with_capacity(1024),
|
||||
writer,
|
||||
receiver,
|
||||
incoming: rx,
|
||||
connection_id: -1,
|
||||
sequence_number: 1,
|
||||
server_capabilities: Capabilities::default(),
|
||||
seq_no: 1,
|
||||
capabilities: Capabilities::default(),
|
||||
status: ServerStatusFlag::default(),
|
||||
};
|
||||
|
||||
establish::establish(&mut conn, options).await?;
|
||||
|
@ -81,15 +79,15 @@ impl Connection {
|
|||
*/
|
||||
// Reserve space for packet header; Packet Body Length (3 bytes) and sequence number (1 byte)
|
||||
self.wbuf.extend_from_slice(&[0; 4]);
|
||||
self.wbuf[3] = self.sequence_number;
|
||||
self.wbuf[3] = self.seq_no;
|
||||
|
||||
message.serialize(&mut self.wbuf, &self.server_capabilities)?;
|
||||
message.serialize(&mut self.wbuf, &self.capabilities)?;
|
||||
serialize_length(&mut self.wbuf);
|
||||
|
||||
println!("{:?}", self.wbuf);
|
||||
|
||||
self.writer.write_all(&self.wbuf).await?;
|
||||
self.writer.flush().await?;
|
||||
self.stream.inner.write_all(&self.wbuf).await?;
|
||||
self.stream.inner.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -101,72 +99,191 @@ impl Connection {
|
|||
}
|
||||
|
||||
async fn ping(&mut self) -> Result<(), Error> {
|
||||
self.sequence_number = 0;
|
||||
self.seq_no = 0;
|
||||
self.send(ComPing()).await?;
|
||||
|
||||
Ok(())
|
||||
match self.stream.next().await? {
|
||||
Some(ServerMessage::OkPacket(message)) => {
|
||||
println!("{:?}", message);
|
||||
self.seq_no = message.seq_no;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Some(ServerMessage::ErrPacket(message)) => {
|
||||
Err(err_msg(format!("{:?}", message)))
|
||||
}
|
||||
|
||||
Some(message) => {
|
||||
panic!("Did not receive OkPacket nor ErrPacket");
|
||||
}
|
||||
|
||||
None => {
|
||||
panic!("Did not recieve packet");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn receiver(
|
||||
mut reader: ReadHalf<TcpStream>,
|
||||
mut sender: mpsc::UnboundedSender<ServerMessage>,
|
||||
) -> Result<(), Error> {
|
||||
let mut rbuf = BytesMut::with_capacity(0);
|
||||
let mut len = 0;
|
||||
let mut first_packet = true;
|
||||
struct Framed {
|
||||
inner: TcpStream,
|
||||
readable: bool,
|
||||
eof: bool,
|
||||
buffer: BytesMut,
|
||||
}
|
||||
|
||||
loop {
|
||||
// This uses an adaptive system to extend the vector when it fills. We want to
|
||||
// avoid paying to allocate and zero a huge chunk of memory if the reader only
|
||||
// has 4 bytes while still making large reads if the reader does have a ton
|
||||
// of data to return.
|
||||
impl Framed {
|
||||
fn new(stream: TcpStream) -> Self {
|
||||
Self {
|
||||
readable: false,
|
||||
eof: false,
|
||||
inner: stream,
|
||||
buffer: BytesMut::with_capacity(8 * 1024),
|
||||
}
|
||||
}
|
||||
|
||||
// See: https://github.com/rust-lang-nursery/futures-rs/blob/master/futures-util/src/io/read_to_end.rs#L50-L54
|
||||
async fn next_bytes(&mut self) -> Result<Bytes, Error> {
|
||||
let mut rbuf = BytesMut::with_capacity(0);
|
||||
let mut len = 0;
|
||||
let mut packet_len: u32 = 0;
|
||||
|
||||
if len == rbuf.len() {
|
||||
rbuf.reserve(32);
|
||||
loop {
|
||||
if len == rbuf.len() {
|
||||
rbuf.reserve(32);
|
||||
|
||||
unsafe {
|
||||
// Set length to the capacity and efficiently
|
||||
// zero-out the memory
|
||||
rbuf.set_len(rbuf.capacity());
|
||||
reader.initializer().initialize(&mut rbuf[len..]);
|
||||
unsafe {
|
||||
// Set length to the capacity and efficiently
|
||||
// zero-out the memory
|
||||
rbuf.set_len(rbuf.capacity());
|
||||
self.inner.initializer().initialize(&mut rbuf[len..]);
|
||||
}
|
||||
}
|
||||
|
||||
let bytes_read = self.inner.read(&mut rbuf[len..]).await?;
|
||||
|
||||
if bytes_read > 0 {
|
||||
len += bytes_read;
|
||||
} else {
|
||||
// Read 0 bytes from the server; end-of-stream
|
||||
return Ok(Bytes::new());
|
||||
}
|
||||
|
||||
println!("buf len: {:?}", rbuf);
|
||||
|
||||
if len > 0 && packet_len == 0 {
|
||||
packet_len = LittleEndian::read_u24(&rbuf[0..]);
|
||||
}
|
||||
|
||||
// Loop until the length of the buffer is the length of the packet
|
||||
if packet_len as usize > len {
|
||||
continue;
|
||||
} else {
|
||||
return Ok(rbuf.freeze());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Need a select! on a channel that I can trigger to cancel this
|
||||
let bytes_read = reader.read(&mut rbuf[len..]).await?;
|
||||
async fn next(&mut self) -> Result<Option<ServerMessage>, Error> {
|
||||
let mut rbuf = BytesMut::with_capacity(0);
|
||||
let mut len = 0;
|
||||
|
||||
if bytes_read > 0 {
|
||||
len += bytes_read;
|
||||
} else {
|
||||
// Read 0 bytes from the server; end-of-stream
|
||||
break;
|
||||
}
|
||||
loop {
|
||||
if len == rbuf.len() {
|
||||
rbuf.reserve(32);
|
||||
|
||||
while len > 0 {
|
||||
let size = rbuf.len();
|
||||
let message = if first_packet {
|
||||
ServerMessage::init(&mut rbuf)
|
||||
unsafe {
|
||||
// Set length to the capacity and efficiently
|
||||
// zero-out the memory
|
||||
rbuf.set_len(rbuf.capacity());
|
||||
self.inner.initializer().initialize(&mut rbuf[len..]);
|
||||
}
|
||||
}
|
||||
|
||||
let bytes_read = self.inner.read(&mut rbuf[len..]).await?;
|
||||
|
||||
if bytes_read > 0 {
|
||||
len += bytes_read;
|
||||
} else {
|
||||
ServerMessage::deserialize(&mut rbuf)
|
||||
}?;
|
||||
len -= size - rbuf.len();
|
||||
|
||||
if let Some(message) = message {
|
||||
first_packet = false;
|
||||
sender.send(message).await.unwrap();
|
||||
} else {
|
||||
// Did not receive enough bytes to
|
||||
// deserialize a complete message
|
||||
// Read 0 bytes from the server; end-of-stream
|
||||
break;
|
||||
}
|
||||
|
||||
while len > 0 {
|
||||
let size = rbuf.len();
|
||||
let message = ServerMessage::deserialize(&mut rbuf)?;
|
||||
len -= size - rbuf.len();
|
||||
|
||||
match message {
|
||||
message @ Some(_) => return Ok(message),
|
||||
// Did not receive enough bytes to
|
||||
// deserialize a complete message
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Err(err_msg("Failed to get next packet"))
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
//async fn receiver(
|
||||
// mut reader: ReadHalf<TcpStream>,
|
||||
// mut sender: mpsc::UnboundedSender<ServerMessage>,
|
||||
//) -> Result<(), Error> {
|
||||
// let mut rbuf = BytesMut::with_capacity(0);
|
||||
// let mut len = 0;
|
||||
// let mut first_packet = true;
|
||||
//
|
||||
// loop {
|
||||
// // This uses an adaptive system to extend the vector when it fills. We want to
|
||||
// // avoid paying to allocate and zero a huge chunk of memory if the reader only
|
||||
// // has 4 bytes while still making large reads if the reader does have a ton
|
||||
// // of data to return.
|
||||
//
|
||||
// // See: https://github.com/rust-lang-nursery/futures-rs/blob/master/futures-util/src/io/read_to_end.rs#L50-L54
|
||||
//
|
||||
// if len == rbuf.len() {
|
||||
// rbuf.reserve(32);
|
||||
//
|
||||
// unsafe {
|
||||
// // Set length to the capacity and efficiently
|
||||
// // zero-out the memory
|
||||
// rbuf.set_len(rbuf.capacity());
|
||||
// reader.initializer().initialize(&mut rbuf[len..]);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // TODO: Need a select! on a channel that I can trigger to cancel this
|
||||
// let bytes_read = reader.read(&mut rbuf[len..]).await?;
|
||||
//
|
||||
// if bytes_read > 0 {
|
||||
// len += bytes_read;
|
||||
// } else {
|
||||
// // Read 0 bytes from the server; end-of-stream
|
||||
// break;
|
||||
// }
|
||||
//
|
||||
// while len > 0 {
|
||||
// let size = rbuf.len();
|
||||
// let message = if first_packet {
|
||||
// ServerMessage::init(&mut rbuf)
|
||||
// } else {
|
||||
// ServerMessage::deserialize(&mut rbuf)
|
||||
// }?;
|
||||
// len -= size - rbuf.len();
|
||||
//
|
||||
// if let Some(message) = message {
|
||||
// first_packet = false;
|
||||
// sender.send(message).await.unwrap();
|
||||
// } else {
|
||||
// // Did not receive enough bytes to
|
||||
// // deserialize a complete message
|
||||
// break;
|
||||
// }
|
||||
//
|
||||
// }
|
||||
//
|
||||
//
|
||||
// }
|
||||
//
|
||||
// Ok(())
|
||||
//}
|
||||
|
|
|
@ -18,17 +18,6 @@ pub enum Message {
|
|||
ErrPacket(ErrPacket),
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn sequence_number(&self) -> u8 {
|
||||
match self {
|
||||
Message::InitialHandshakePacket(InitialHandshakePacket{ sequence_number, ..}) => sequence_number + 1,
|
||||
Message::OkPacket(OkPacket{ sequence_number, ..}) => sequence_number + 1,
|
||||
Message::ErrPacket(ErrPacket { sequence_number, .. }) => sequence_number + 1,
|
||||
_ => 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bitflags! {
|
||||
pub struct Capabilities: u128 {
|
||||
const CLIENT_MYSQL = 1;
|
||||
|
@ -119,7 +108,7 @@ impl Default for ServerStatusFlag {
|
|||
#[derive(Default, Debug)]
|
||||
pub struct InitialHandshakePacket {
|
||||
pub length: u32,
|
||||
pub sequence_number: u8,
|
||||
pub seq_no: u8,
|
||||
pub protocol_version: u8,
|
||||
pub server_version: Bytes,
|
||||
pub connection_id: u32,
|
||||
|
@ -134,7 +123,7 @@ pub struct InitialHandshakePacket {
|
|||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct OkPacket {
|
||||
pub sequence_number: u8,
|
||||
pub seq_no: u8,
|
||||
pub affected_rows: Option<usize>,
|
||||
pub last_insert_id: Option<usize>,
|
||||
pub server_status: ServerStatusFlag,
|
||||
|
@ -146,7 +135,7 @@ pub struct OkPacket {
|
|||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct ErrPacket {
|
||||
pub sequence_number: u8,
|
||||
pub seq_no: u8,
|
||||
pub error_code: u16,
|
||||
pub stage: Option<u8>,
|
||||
pub max_stage: Option<u8>,
|
||||
|
@ -203,9 +192,9 @@ impl Deserialize for InitialHandshakePacket {
|
|||
let mut index = 0;
|
||||
|
||||
let length = deserialize_length(&buf, &mut index)?;
|
||||
let sequence_number = deserialize_int_1(&buf, &mut index);
|
||||
let seq_no = deserialize_int_1(&buf, &mut index);
|
||||
|
||||
if sequence_number != 0 {
|
||||
if seq_no != 0 {
|
||||
return Err(err_msg("Squence Number of Initial Handshake Packet is not 0"));
|
||||
}
|
||||
|
||||
|
@ -261,7 +250,7 @@ impl Deserialize for InitialHandshakePacket {
|
|||
|
||||
Ok(InitialHandshakePacket {
|
||||
length,
|
||||
sequence_number,
|
||||
seq_no,
|
||||
protocol_version,
|
||||
server_version,
|
||||
connection_id,
|
||||
|
@ -282,7 +271,7 @@ impl Deserialize for OkPacket {
|
|||
|
||||
// Packet header
|
||||
let length = deserialize_length(&buf, &mut index)?;
|
||||
let sequence_number = deserialize_int_1(&buf, &mut index);
|
||||
let seq_no = deserialize_int_1(&buf, &mut index);
|
||||
|
||||
// Packet body
|
||||
let packet_header = deserialize_int_1(&buf, &mut index);
|
||||
|
@ -302,7 +291,7 @@ impl Deserialize for OkPacket {
|
|||
let info = Bytes::from(&buf[index..]);
|
||||
|
||||
Ok(OkPacket {
|
||||
sequence_number,
|
||||
seq_no,
|
||||
affected_rows,
|
||||
last_insert_id,
|
||||
server_status,
|
||||
|
@ -319,7 +308,7 @@ impl Deserialize for ErrPacket {
|
|||
let mut index = 0;
|
||||
|
||||
let length = deserialize_length(&buf, &mut index)?;
|
||||
let sequence_number = deserialize_int_1(&buf, &mut index);
|
||||
let seq_no = deserialize_int_1(&buf, &mut index);
|
||||
|
||||
let packet_header = deserialize_int_1(&buf, &mut index);
|
||||
if packet_header != 0xFF {
|
||||
|
@ -354,7 +343,7 @@ impl Deserialize for ErrPacket {
|
|||
}
|
||||
|
||||
Ok(ErrPacket {
|
||||
sequence_number,
|
||||
seq_no,
|
||||
error_code,
|
||||
stage,
|
||||
max_stage,
|
||||
|
|
Loading…
Reference in a new issue