mirror of
https://github.com/launchbadge/sqlx
synced 2024-09-20 14:21:57 +00:00
WIP: DeContext
This commit is contained in:
parent
3d5590c6c9
commit
4cfb1d46a1
12 changed files with 104 additions and 61 deletions
|
@ -1,6 +1,6 @@
|
|||
use super::{Connection, Decoder};
|
||||
use super::{Connection};
|
||||
use crate::protocol::{
|
||||
deserialize::Deserialize,
|
||||
deserialize::{Deserialize, DeContext},
|
||||
packets::{handshake_response::HandshakeResponsePacket, initial::InitialHandshakePacket},
|
||||
server::Message as ServerMessage,
|
||||
types::Capabilities,
|
||||
|
@ -14,10 +14,8 @@ pub async fn establish<'a, 'b: 'a>(
|
|||
options: ConnectOptions<'b>,
|
||||
) -> Result<(), Error> {
|
||||
let buf = &conn.stream.next_bytes().await?;
|
||||
let init_packet =
|
||||
InitialHandshakePacket::deserialize(conn, &mut Decoder::new(&buf))?;
|
||||
|
||||
conn.capabilities = init_packet.capabilities;
|
||||
let mut de_ctx = DeContext::new(conn, &buf);
|
||||
let _ = InitialHandshakePacket::deserialize(&mut de_ctx)?;
|
||||
|
||||
let handshake: HandshakeResponsePacket = HandshakeResponsePacket {
|
||||
// Minimum client capabilities required to establish connection
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
use super::protocol::decode::Decoder;
|
||||
use crate::protocol::{
|
||||
deserialize::Deserialize,
|
||||
deserialize::{Deserialize, DeContext},
|
||||
encode::Encoder,
|
||||
packets::{com_ping::ComPing, com_quit::ComQuit, ok::OkPacket},
|
||||
serialize::Serialize,
|
||||
|
@ -31,6 +30,9 @@ pub struct Connection {
|
|||
// Sequence Number
|
||||
pub seq_no: u8,
|
||||
|
||||
// Last sequence number return by MariaDB
|
||||
pub last_seq_no: u8,
|
||||
|
||||
// Server Capabilities
|
||||
pub capabilities: Capabilities,
|
||||
|
||||
|
@ -46,6 +48,7 @@ impl Connection {
|
|||
encoder: Encoder::new(1024),
|
||||
connection_id: -1,
|
||||
seq_no: 1,
|
||||
last_seq_no: 0,
|
||||
capabilities: Capabilities::default(),
|
||||
status: ServerStatusFlag::default(),
|
||||
};
|
||||
|
@ -83,7 +86,7 @@ impl Connection {
|
|||
|
||||
// Ping response must be an OkPacket
|
||||
let buf = self.stream.next_bytes().await?;
|
||||
OkPacket::deserialize(self, &mut Decoder::new(&buf))?;
|
||||
OkPacket::deserialize(&mut DeContext::new(self, &buf))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -115,7 +118,7 @@ impl Connection {
|
|||
|
||||
while len > 0 {
|
||||
let size = rbuf.len();
|
||||
let message = ServerMessage::deserialize(self, &mut Decoder::new(&rbuf.as_ref().into()))?;
|
||||
let message = ServerMessage::deserialize(&mut DeContext::new(self, &rbuf.as_ref().into()))?;
|
||||
len -= size - rbuf.len();
|
||||
|
||||
match message {
|
||||
|
|
|
@ -1,7 +1,22 @@
|
|||
use super::decode::Decoder;
|
||||
use failure::Error;
|
||||
use crate::connection::Connection;
|
||||
use bytes::Bytes;
|
||||
|
||||
pub struct DeContext<'a> {
|
||||
pub conn: &'a mut Connection,
|
||||
pub decoder: Decoder<'a>,
|
||||
}
|
||||
|
||||
impl<'a> DeContext<'a> {
|
||||
pub fn new(conn: &'a mut Connection, buf: &'a Bytes) -> Self {
|
||||
DeContext {
|
||||
conn,
|
||||
decoder: Decoder::new(&buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Deserialize: Sized {
|
||||
fn deserialize(conn: &mut Connection, decoder: &mut Decoder) -> Result<Self, Error>;
|
||||
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error>;
|
||||
}
|
||||
|
|
|
@ -47,50 +47,49 @@ impl Encoder {
|
|||
|
||||
#[inline]
|
||||
pub fn encode_int_8(&mut self, value: u64) {
|
||||
self.buf.put_u64_le(value);
|
||||
self.buf.extend_from_slice(&value.to_le_bytes());
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn encode_int_4(&mut self, value: u32) {
|
||||
self.buf.put_u32_le(value);
|
||||
self.buf.extend_from_slice(&value.to_le_bytes());
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn encode_int_3(&mut self, value: u32) {
|
||||
let length = value.to_le_bytes();
|
||||
self.buf.extend_from_slice(&length[0..3]);
|
||||
pub fn encode_int_3(&mut self, value: u32) {
|
||||
self.buf.extend_from_slice(&value.to_le_bytes()[0..3]);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn encode_int_2(&mut self, value: u16) {
|
||||
self.buf.put_u16_le(value);
|
||||
self.buf.extend_from_slice(&value.to_le_bytes());
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn encode_int_1(&mut self, value: u8) {
|
||||
self.buf.put_u8(value);
|
||||
self.buf.extend_from_slice(&value.to_le_bytes());
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn encode_int_lenenc(&mut self, value: Option<&usize>) {
|
||||
if let Some(value) = value {
|
||||
if *value > U24_MAX && *value <= std::u64::MAX as usize {
|
||||
self.buf.put_u8(0xFE);
|
||||
self.buf.push(0xFE);
|
||||
self.encode_int_8(*value as u64);
|
||||
} else if *value > std::u16::MAX as usize && *value <= U24_MAX {
|
||||
self.buf.put_u8(0xFD);
|
||||
self.buf.push(0xFD);
|
||||
self.encode_int_3(*value as u32);
|
||||
} else if *value > std::u8::MAX as usize && *value <= std::u16::MAX as usize {
|
||||
self.buf.put_u8(0xFC);
|
||||
self.buf.push(0xFC);
|
||||
self.encode_int_2(*value as u16);
|
||||
} else if *value <= std::u8::MAX as usize {
|
||||
self.buf.put_u8(0xFA);
|
||||
self.buf.push(0xFA);
|
||||
self.encode_int_1(*value as u8);
|
||||
} else {
|
||||
panic!("Value is too long");
|
||||
}
|
||||
} else {
|
||||
self.buf.put_u8(0xFB);
|
||||
self.buf.push(0xFB);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
use super::super::{decode::Decoder, deserialize::Deserialize};
|
||||
use super::super::deserialize::{Deserialize, DeContext};
|
||||
use failure::Error;
|
||||
use crate::connection::Connection;
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct ColumnPacket {
|
||||
|
@ -10,7 +9,8 @@ pub struct ColumnPacket {
|
|||
}
|
||||
|
||||
impl Deserialize for ColumnPacket {
|
||||
fn deserialize(_conn: &mut Connection, decoder: &mut Decoder) -> Result<Self, Error> {
|
||||
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
|
||||
let decoder = &mut ctx.decoder;
|
||||
let length = decoder.decode_length()?;
|
||||
let seq_no = decoder.decode_int_1();
|
||||
let columns = decoder.decode_int_lenenc();
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
use std::convert::TryFrom;
|
||||
use bytes::Bytes;
|
||||
use failure::Error;
|
||||
use crate::connection::Connection;
|
||||
use super::super::{
|
||||
decode::Decoder,
|
||||
deserialize::Deserialize,
|
||||
deserialize::{Deserialize, DeContext},
|
||||
types::{FieldDetailFlag, FieldType},
|
||||
};
|
||||
|
||||
|
@ -27,7 +25,8 @@ pub struct ColumnDefPacket {
|
|||
}
|
||||
|
||||
impl Deserialize for ColumnDefPacket {
|
||||
fn deserialize(_conn: &mut Connection, decoder: &mut Decoder) -> Result<Self, Error> {
|
||||
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
|
||||
let decoder = &mut ctx.decoder;
|
||||
let length = decoder.decode_length()?;
|
||||
let seq_no = decoder.decode_int_1();
|
||||
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
use std::convert::TryFrom;
|
||||
use bytes::Bytes;
|
||||
use failure::Error;
|
||||
use crate::connection::Connection;
|
||||
use super::super::{decode::Decoder, deserialize::Deserialize, error_codes::ErrorCode};
|
||||
use super::super::{deserialize::Deserialize, deserialize::DeContext, error_codes::ErrorCode};
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct ErrPacket {
|
||||
|
@ -19,7 +18,8 @@ pub struct ErrPacket {
|
|||
}
|
||||
|
||||
impl Deserialize for ErrPacket {
|
||||
fn deserialize(_conn: &mut Connection, decoder: &mut Decoder) -> Result<Self, Error> {
|
||||
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
|
||||
let decoder = &mut ctx.decoder;
|
||||
let length = decoder.decode_length()?;
|
||||
let seq_no = decoder.decode_int_1();
|
||||
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
use super::super::{
|
||||
decode::Decoder,
|
||||
deserialize::Deserialize,
|
||||
deserialize::{Deserialize, DeContext},
|
||||
types::{Capabilities, ServerStatusFlag},
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use failure::{err_msg, Error};
|
||||
use crate::connection::Connection;
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct InitialHandshakePacket {
|
||||
|
@ -24,7 +22,8 @@ pub struct InitialHandshakePacket {
|
|||
}
|
||||
|
||||
impl Deserialize for InitialHandshakePacket {
|
||||
fn deserialize(_conn: &mut Connection, decoder: &mut Decoder) -> Result<Self, Error> {
|
||||
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
|
||||
let decoder = &mut ctx.decoder;
|
||||
let length = decoder.decode_length()?;
|
||||
let seq_no = decoder.decode_int_1();
|
||||
|
||||
|
@ -80,6 +79,9 @@ impl Deserialize for InitialHandshakePacket {
|
|||
auth_plugin_name = Some(decoder.decode_string_null()?);
|
||||
}
|
||||
|
||||
ctx.conn.capabilities = capabilities;
|
||||
ctx.conn.last_seq_no = seq_no;
|
||||
|
||||
Ok(InitialHandshakePacket {
|
||||
length,
|
||||
seq_no,
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
use super::super::{decode::Decoder, deserialize::Deserialize, types::ServerStatusFlag};
|
||||
use super::super::{deserialize::Deserialize, deserialize::DeContext, types::ServerStatusFlag};
|
||||
use bytes::Bytes;
|
||||
use failure::Error;
|
||||
use crate::connection::Connection;
|
||||
use failure::err_msg;
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
|
@ -18,16 +17,17 @@ pub struct OkPacket {
|
|||
}
|
||||
|
||||
impl Deserialize for OkPacket {
|
||||
fn deserialize(_conn: &mut Connection, decoder: &mut Decoder) -> Result<Self, Error> {
|
||||
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
|
||||
let decoder = &mut ctx.decoder;
|
||||
// Packet header
|
||||
let length = decoder.decode_length()?;
|
||||
let seq_no = decoder.decode_int_1();
|
||||
|
||||
// Packet body
|
||||
let packet_header = decoder.decode_int_1();
|
||||
// if packet_header != 0 && packet_header != 0xFE {
|
||||
// return Err(err_msg("Packet header is not 0 or 0xFE for OkPacket"));
|
||||
// }
|
||||
if packet_header != 0 && packet_header != 0xFE {
|
||||
return Err(err_msg("Packet header is not 0 or 0xFE for OkPacket"));
|
||||
}
|
||||
|
||||
let affected_rows = decoder.decode_int_lenenc();
|
||||
let last_insert_id = decoder.decode_int_lenenc();
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
use bytes::Bytes;
|
||||
use failure::Error;
|
||||
use crate::connection::Connection;
|
||||
|
||||
use super::super::{
|
||||
decode::Decoder,
|
||||
deserialize::Deserialize,
|
||||
deserialize::{Deserialize, DeContext},
|
||||
packets::{column::ColumnPacket, column_def::ColumnDefPacket},
|
||||
};
|
||||
|
||||
|
@ -18,15 +15,15 @@ pub struct ResultSet {
|
|||
}
|
||||
|
||||
impl Deserialize for ResultSet {
|
||||
fn deserialize(conn: &mut Connection, decoder: &mut Decoder) -> Result<Self, Error> {
|
||||
let length = decoder.decode_length()?;
|
||||
let seq_no = decoder.decode_int_1();
|
||||
fn deserialize(ctx: &mut DeContext) -> Result<Self, Error> {
|
||||
let length = ctx.decoder.decode_length()?;
|
||||
let seq_no = ctx.decoder.decode_int_1();
|
||||
|
||||
let column_packet = ColumnPacket::deserialize(conn, decoder)?;
|
||||
let column_packet = ColumnPacket::deserialize(ctx)?;
|
||||
|
||||
let columns = if let Some(columns) = column_packet.columns {
|
||||
(0..columns)
|
||||
.map(|_| ColumnDefPacket::deserialize(conn, decoder))
|
||||
.map(|_| ColumnDefPacket::deserialize(ctx))
|
||||
.filter(Result::is_ok)
|
||||
.map(Result::unwrap)
|
||||
.collect::<Vec<ColumnDefPacket>>()
|
||||
|
@ -38,19 +35,25 @@ impl Deserialize for ResultSet {
|
|||
|
||||
for _ in 0.. {
|
||||
// if end of buffer stop
|
||||
if decoder.eof() {
|
||||
if ctx.decoder.eof() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Decode each column as string<lenenc>
|
||||
rows.push(
|
||||
(0..column_packet.columns.unwrap_or(0))
|
||||
.map(|_| decoder.decode_string_lenenc())
|
||||
.map(|_| ctx.decoder.decode_string_lenenc())
|
||||
.collect::<Vec<Bytes>>(),
|
||||
)
|
||||
}
|
||||
|
||||
Ok(ResultSet { length, seq_no, column_packet, columns, rows })
|
||||
Ok(ResultSet {
|
||||
length,
|
||||
seq_no,
|
||||
column_packet,
|
||||
columns,
|
||||
rows,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
// Reference: https://mariadb.com/kb/en/library/connection
|
||||
|
||||
use failure::Error;
|
||||
|
||||
use super::{
|
||||
decode::Decoder,
|
||||
deserialize::Deserialize,
|
||||
deserialize::{DeContext, Deserialize},
|
||||
packets::{err::ErrPacket, initial::InitialHandshakePacket, ok::OkPacket},
|
||||
};
|
||||
use crate::connection::Connection;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[non_exhaustive]
|
||||
|
@ -17,7 +16,8 @@ pub enum Message {
|
|||
}
|
||||
|
||||
impl Message {
|
||||
pub fn deserialize(conn: &mut Connection, decoder: &mut Decoder) -> Result<Option<Self>, Error> {
|
||||
pub fn deserialize(ctx: &mut DeContext) -> Result<Option<Self>, Error> {
|
||||
let decoder = &mut ctx.decoder;
|
||||
if decoder.buf.len() < 4 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
@ -30,8 +30,8 @@ impl Message {
|
|||
let tag = decoder.buf[4];
|
||||
|
||||
Ok(Some(match tag {
|
||||
0xFF => Message::ErrPacket(ErrPacket::deserialize(conn, decoder)?),
|
||||
0x00 | 0xFE => Message::OkPacket(OkPacket::deserialize(conn, decoder)?),
|
||||
0xFF => Message::ErrPacket(ErrPacket::deserialize(ctx)?),
|
||||
0x00 | 0xFE => Message::OkPacket(OkPacket::deserialize(ctx)?),
|
||||
_ => unimplemented!(),
|
||||
}))
|
||||
}
|
||||
|
|
24
src/main.rs
Normal file
24
src/main.rs
Normal file
|
@ -0,0 +1,24 @@
|
|||
#![feature(async_await)]
|
||||
|
||||
use mason::{pg::Connection, ConnectOptions};
|
||||
|
||||
#[runtime::main]
|
||||
async fn main() -> Result<(), failure::Error> {
|
||||
env_logger::try_init()?;
|
||||
|
||||
let mut conn =
|
||||
Connection::establish(ConnectOptions::new().user("postgres").password("password")).await?;
|
||||
|
||||
conn.execute("INSERT INTO \"users\" (name) VALUES ($1)")
|
||||
.bind(b"Joe")
|
||||
.await?;
|
||||
|
||||
conn.prepare("INSERT INTO \"users\" (name) VALUES ($1)")
|
||||
.bind(b"Joe")
|
||||
.execute()
|
||||
.await?;
|
||||
|
||||
conn.close().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
Loading…
Reference in a new issue