Remove receiver

This commit is contained in:
Daniel Akhterov 2019-06-30 14:22:00 -07:00 committed by Daniel Akhterov
parent f107f1de6c
commit 8dcd113517
3 changed files with 223 additions and 134 deletions

View file

@ -14,49 +14,45 @@ use mason_core::ConnectOptions;
use std::io;
use failure::Error;
use bytes::Bytes;
use failure::err_msg;
pub async fn establish<'a, 'b: 'a>(
conn: &'a mut Connection,
options: ConnectOptions<'b>,
) -> Result<(), Error> {
let init_packet = if let Some(message) = conn.incoming.next().await {
conn.sequence_number = message.sequence_number();
match message {
ServerMessage::InitialHandshakePacket(message) => {
Ok(message)
},
_ => Err(failure::err_msg("Incorrect First Packet")),
}
} else {
Err(failure::err_msg("Failed to connect"))
}?;
let init_packet = InitialHandshakePacket::deserialize(&conn.stream.next_bytes().await?)?;
conn.server_capabilities = init_packet.capabilities;
conn.capabilities = init_packet.capabilities;
let handshake: HandshakeResponsePacket = HandshakeResponsePacket {
// Minimum client capabilities required to establish connection
capabilities: Capabilities::CLIENT_PROTOCOL_41,
max_packet_size: 1024,
collation: 0,
extended_capabilities: Some(Capabilities::from_bits_truncate(0)),
username: Bytes::from_static(b"root"),
auth_data: None,
auth_response_len: None,
auth_response: None,
database: None,
auth_plugin_name: None,
conn_attr_len: None,
conn_attr: None,
..Default::default()
};
conn.send(handshake).await?;
if let Some(message) = conn.incoming.next().await {
println!("{:?}", message);
conn.sequence_number = message.sequence_number();
Ok(())
} else {
Err(failure::err_msg("Handshake Failed"))
match conn.stream.next().await? {
Some(ServerMessage::OkPacket(message)) => {
println!("{:?}", message);
conn.seq_no = message.seq_no;
Ok(())
}
Some(ServerMessage::ErrPacket(message)) => {
Err(err_msg(format!("{:?}", message)))
}
Some(message) => {
panic!("Did not receive OkPacket nor ErrPacket");
}
None => {
panic!("Did not recieve packet");
}
}
}
@ -69,29 +65,16 @@ mod test {
#[runtime::test]
async fn it_connects() -> Result<(), Error> {
let mut conn = Connection::establish(ConnectOptions {
host: "localhost",
host: "127.0.0.1",
port: 3306,
user: Some("root"),
database: None,
password: None,
}).await?;
//
// conn.ping().await?;
conn.ping().await?;
if let Some(message) = conn.incoming.next().await {
match message {
ServerMessage::OkPacket(packet) => {
conn.quit().await?;
Ok(())
}
ServerMessage::ErrPacket(packet) => {
Err(err_msg(format!("{:?}", packet)))
}
_ => Err(err_msg("Server Failed"))
}
} else {
Err(err_msg("Server Failed"))
}
Ok(())
}
}

View file

@ -3,16 +3,18 @@ use crate::protocol::{
client::ComQuit,
client::ComPing,
server::Message as ServerMessage,
server::ServerStatusFlag,
server::Capabilities,
server::InitialHandshakePacket,
server::Deserialize
};
use bytes::BytesMut;
use futures::{
channel::mpsc,
io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
SinkExt, StreamExt,
io::{AsyncRead, AsyncWriteExt},
task::{Context, Poll},
Stream,
};
use futures::prelude::*;
use mason_core::ConnectOptions;
use runtime::{net::TcpStream, task::JoinHandle};
use std::io;
@ -21,44 +23,40 @@ use failure::err_msg;
use byteorder::{ByteOrder, LittleEndian, WriteBytesExt};
use crate::protocol::serialize::serialize_length;
use bytes::BufMut;
use bytes::Bytes;
mod establish;
// mod query;
pub struct Connection {
writer: WriteHalf<TcpStream>,
incoming: mpsc::UnboundedReceiver<ServerMessage>,
stream: Framed,
// Buffer used when serializing outgoing messages
wbuf: BytesMut,
// Handle to coroutine reading messages from the stream
receiver: JoinHandle<Result<(), Error>>,
// MariaDB Connection ID
connection_id: i32,
// Sequence Number
sequence_number: u8,
seq_no: u8,
// Server Capabilities
server_capabilities: Capabilities,
capabilities: Capabilities,
// Server status
status: ServerStatusFlag,
}
impl Connection {
pub async fn establish(options: ConnectOptions<'static>) -> Result<Self, Error> {
let stream = TcpStream::connect((options.host, options.port)).await?;
let (reader, writer) = stream.split();
let (tx, rx) = mpsc::unbounded();
let receiver: JoinHandle<Result<(), Error>> = runtime::spawn(receiver(reader, tx));
let stream: Framed = Framed::new(TcpStream::connect((options.host, options.port)).await?);
let mut conn = Self {
stream,
wbuf: BytesMut::with_capacity(1024),
writer,
receiver,
incoming: rx,
connection_id: -1,
sequence_number: 1,
server_capabilities: Capabilities::default(),
seq_no: 1,
capabilities: Capabilities::default(),
status: ServerStatusFlag::default(),
};
establish::establish(&mut conn, options).await?;
@ -81,15 +79,15 @@ impl Connection {
*/
// Reserve space for packet header; Packet Body Length (3 bytes) and sequence number (1 byte)
self.wbuf.extend_from_slice(&[0; 4]);
self.wbuf[3] = self.sequence_number;
self.wbuf[3] = self.seq_no;
message.serialize(&mut self.wbuf, &self.server_capabilities)?;
message.serialize(&mut self.wbuf, &self.capabilities)?;
serialize_length(&mut self.wbuf);
println!("{:?}", self.wbuf);
self.writer.write_all(&self.wbuf).await?;
self.writer.flush().await?;
self.stream.inner.write_all(&self.wbuf).await?;
self.stream.inner.flush().await?;
Ok(())
}
@ -101,72 +99,191 @@ impl Connection {
}
async fn ping(&mut self) -> Result<(), Error> {
self.sequence_number = 0;
self.seq_no = 0;
self.send(ComPing()).await?;
Ok(())
match self.stream.next().await? {
Some(ServerMessage::OkPacket(message)) => {
println!("{:?}", message);
self.seq_no = message.seq_no;
Ok(())
}
Some(ServerMessage::ErrPacket(message)) => {
Err(err_msg(format!("{:?}", message)))
}
Some(message) => {
panic!("Did not receive OkPacket nor ErrPacket");
}
None => {
panic!("Did not recieve packet");
}
}
}
}
async fn receiver(
mut reader: ReadHalf<TcpStream>,
mut sender: mpsc::UnboundedSender<ServerMessage>,
) -> Result<(), Error> {
let mut rbuf = BytesMut::with_capacity(0);
let mut len = 0;
let mut first_packet = true;
struct Framed {
inner: TcpStream,
readable: bool,
eof: bool,
buffer: BytesMut,
}
loop {
// This uses an adaptive system to extend the vector when it fills. We want to
// avoid paying to allocate and zero a huge chunk of memory if the reader only
// has 4 bytes while still making large reads if the reader does have a ton
// of data to return.
impl Framed {
fn new(stream: TcpStream) -> Self {
Self {
readable: false,
eof: false,
inner: stream,
buffer: BytesMut::with_capacity(8 * 1024),
}
}
// See: https://github.com/rust-lang-nursery/futures-rs/blob/master/futures-util/src/io/read_to_end.rs#L50-L54
async fn next_bytes(&mut self) -> Result<Bytes, Error> {
let mut rbuf = BytesMut::with_capacity(0);
let mut len = 0;
let mut packet_len: u32 = 0;
if len == rbuf.len() {
rbuf.reserve(32);
loop {
if len == rbuf.len() {
rbuf.reserve(32);
unsafe {
// Set length to the capacity and efficiently
// zero-out the memory
rbuf.set_len(rbuf.capacity());
reader.initializer().initialize(&mut rbuf[len..]);
unsafe {
// Set length to the capacity and efficiently
// zero-out the memory
rbuf.set_len(rbuf.capacity());
self.inner.initializer().initialize(&mut rbuf[len..]);
}
}
let bytes_read = self.inner.read(&mut rbuf[len..]).await?;
if bytes_read > 0 {
len += bytes_read;
} else {
// Read 0 bytes from the server; end-of-stream
return Ok(Bytes::new());
}
println!("buf len: {:?}", rbuf);
if len > 0 && packet_len == 0 {
packet_len = LittleEndian::read_u24(&rbuf[0..]);
}
// Loop until the length of the buffer is the length of the packet
if packet_len as usize > len {
continue;
} else {
return Ok(rbuf.freeze());
}
}
}
// TODO: Need a select! on a channel that I can trigger to cancel this
let bytes_read = reader.read(&mut rbuf[len..]).await?;
async fn next(&mut self) -> Result<Option<ServerMessage>, Error> {
let mut rbuf = BytesMut::with_capacity(0);
let mut len = 0;
if bytes_read > 0 {
len += bytes_read;
} else {
// Read 0 bytes from the server; end-of-stream
break;
}
loop {
if len == rbuf.len() {
rbuf.reserve(32);
while len > 0 {
let size = rbuf.len();
let message = if first_packet {
ServerMessage::init(&mut rbuf)
unsafe {
// Set length to the capacity and efficiently
// zero-out the memory
rbuf.set_len(rbuf.capacity());
self.inner.initializer().initialize(&mut rbuf[len..]);
}
}
let bytes_read = self.inner.read(&mut rbuf[len..]).await?;
if bytes_read > 0 {
len += bytes_read;
} else {
ServerMessage::deserialize(&mut rbuf)
}?;
len -= size - rbuf.len();
if let Some(message) = message {
first_packet = false;
sender.send(message).await.unwrap();
} else {
// Did not receive enough bytes to
// deserialize a complete message
// Read 0 bytes from the server; end-of-stream
break;
}
while len > 0 {
let size = rbuf.len();
let message = ServerMessage::deserialize(&mut rbuf)?;
len -= size - rbuf.len();
match message {
message @ Some(_) => return Ok(message),
// Did not receive enough bytes to
// deserialize a complete message
None => break,
}
}
}
Err(err_msg("Failed to get next packet"))
}
Ok(())
}
//async fn receiver(
// mut reader: ReadHalf<TcpStream>,
// mut sender: mpsc::UnboundedSender<ServerMessage>,
//) -> Result<(), Error> {
// let mut rbuf = BytesMut::with_capacity(0);
// let mut len = 0;
// let mut first_packet = true;
//
// loop {
// // This uses an adaptive system to extend the vector when it fills. We want to
// // avoid paying to allocate and zero a huge chunk of memory if the reader only
// // has 4 bytes while still making large reads if the reader does have a ton
// // of data to return.
//
// // See: https://github.com/rust-lang-nursery/futures-rs/blob/master/futures-util/src/io/read_to_end.rs#L50-L54
//
// if len == rbuf.len() {
// rbuf.reserve(32);
//
// unsafe {
// // Set length to the capacity and efficiently
// // zero-out the memory
// rbuf.set_len(rbuf.capacity());
// reader.initializer().initialize(&mut rbuf[len..]);
// }
// }
//
// // TODO: Need a select! on a channel that I can trigger to cancel this
// let bytes_read = reader.read(&mut rbuf[len..]).await?;
//
// if bytes_read > 0 {
// len += bytes_read;
// } else {
// // Read 0 bytes from the server; end-of-stream
// break;
// }
//
// while len > 0 {
// let size = rbuf.len();
// let message = if first_packet {
// ServerMessage::init(&mut rbuf)
// } else {
// ServerMessage::deserialize(&mut rbuf)
// }?;
// len -= size - rbuf.len();
//
// if let Some(message) = message {
// first_packet = false;
// sender.send(message).await.unwrap();
// } else {
// // Did not receive enough bytes to
// // deserialize a complete message
// break;
// }
//
// }
//
//
// }
//
// Ok(())
//}

View file

@ -18,17 +18,6 @@ pub enum Message {
ErrPacket(ErrPacket),
}
impl Message {
pub fn sequence_number(&self) -> u8 {
match self {
Message::InitialHandshakePacket(InitialHandshakePacket{ sequence_number, ..}) => sequence_number + 1,
Message::OkPacket(OkPacket{ sequence_number, ..}) => sequence_number + 1,
Message::ErrPacket(ErrPacket { sequence_number, .. }) => sequence_number + 1,
_ => 0
}
}
}
bitflags! {
pub struct Capabilities: u128 {
const CLIENT_MYSQL = 1;
@ -119,7 +108,7 @@ impl Default for ServerStatusFlag {
#[derive(Default, Debug)]
pub struct InitialHandshakePacket {
pub length: u32,
pub sequence_number: u8,
pub seq_no: u8,
pub protocol_version: u8,
pub server_version: Bytes,
pub connection_id: u32,
@ -134,7 +123,7 @@ pub struct InitialHandshakePacket {
#[derive(Default, Debug)]
pub struct OkPacket {
pub sequence_number: u8,
pub seq_no: u8,
pub affected_rows: Option<usize>,
pub last_insert_id: Option<usize>,
pub server_status: ServerStatusFlag,
@ -146,7 +135,7 @@ pub struct OkPacket {
#[derive(Default, Debug)]
pub struct ErrPacket {
pub sequence_number: u8,
pub seq_no: u8,
pub error_code: u16,
pub stage: Option<u8>,
pub max_stage: Option<u8>,
@ -203,9 +192,9 @@ impl Deserialize for InitialHandshakePacket {
let mut index = 0;
let length = deserialize_length(&buf, &mut index)?;
let sequence_number = deserialize_int_1(&buf, &mut index);
let seq_no = deserialize_int_1(&buf, &mut index);
if sequence_number != 0 {
if seq_no != 0 {
return Err(err_msg("Squence Number of Initial Handshake Packet is not 0"));
}
@ -261,7 +250,7 @@ impl Deserialize for InitialHandshakePacket {
Ok(InitialHandshakePacket {
length,
sequence_number,
seq_no,
protocol_version,
server_version,
connection_id,
@ -282,7 +271,7 @@ impl Deserialize for OkPacket {
// Packet header
let length = deserialize_length(&buf, &mut index)?;
let sequence_number = deserialize_int_1(&buf, &mut index);
let seq_no = deserialize_int_1(&buf, &mut index);
// Packet body
let packet_header = deserialize_int_1(&buf, &mut index);
@ -302,7 +291,7 @@ impl Deserialize for OkPacket {
let info = Bytes::from(&buf[index..]);
Ok(OkPacket {
sequence_number,
seq_no,
affected_rows,
last_insert_id,
server_status,
@ -319,7 +308,7 @@ impl Deserialize for ErrPacket {
let mut index = 0;
let length = deserialize_length(&buf, &mut index)?;
let sequence_number = deserialize_int_1(&buf, &mut index);
let seq_no = deserialize_int_1(&buf, &mut index);
let packet_header = deserialize_int_1(&buf, &mut index);
if packet_header != 0xFF {
@ -354,7 +343,7 @@ impl Deserialize for ErrPacket {
}
Ok(ErrPacket {
sequence_number,
seq_no,
error_code,
stage,
max_stage,