Remove postgres::protocol::{Buf, BufMut} and use crate::io::{Buf, BufMut} instead

This commit is contained in:
Ryan Leckey 2019-08-28 11:01:55 -07:00
parent f67421b50d
commit c8559cac84
38 changed files with 737 additions and 419 deletions

View file

@ -31,6 +31,9 @@ memchr = "2.2.1"
tokio = { version = "=0.2.0-alpha.2", default-features = false, features = [ "tcp" ] } tokio = { version = "=0.2.0-alpha.2", default-features = false, features = [ "tcp" ] }
url = "2.1.0" url = "2.1.0"
[dev-dependencies]
matches = "0.1.8"
[profile.release] [profile.release]
lto = true lto = true
codegen-units = 1 codegen-units = 1

131
src/io/buf.rs Normal file
View file

@ -0,0 +1,131 @@
use byteorder::ByteOrder;
use memchr::memchr;
use std::{convert::TryInto, io, mem::size_of, str};
pub trait Buf {
fn advance(&mut self, cnt: usize);
fn get_u8(&mut self) -> io::Result<u8>;
fn get_u16<T: ByteOrder>(&mut self) -> io::Result<u16>;
fn get_u24<T: ByteOrder>(&mut self) -> io::Result<u32>;
fn get_i32<T: ByteOrder>(&mut self) -> io::Result<i32>;
fn get_u32<T: ByteOrder>(&mut self) -> io::Result<u32>;
fn get_u64<T: ByteOrder>(&mut self) -> io::Result<u64>;
// TODO?: Move to mariadb::io::BufExt
fn get_uint<T: ByteOrder>(&mut self, n: usize) -> io::Result<u64>;
// TODO?: Move to mariadb::io::BufExt
fn get_uint_lenenc<T: ByteOrder>(&mut self) -> io::Result<u64>;
fn get_str(&mut self, len: usize) -> io::Result<&str>;
// TODO?: Move to mariadb::io::BufExt
fn get_str_eof(&mut self) -> io::Result<&str>;
fn get_str_nul(&mut self) -> io::Result<&str>;
// TODO?: Move to mariadb::io::BufExt
fn get_str_lenenc<T: ByteOrder>(&mut self) -> io::Result<&str>;
}
impl<'a> Buf for &'a [u8] {
fn advance(&mut self, cnt: usize) {
*self = &self[cnt..];
}
fn get_u8(&mut self) -> io::Result<u8> {
let val = self[0];
self.advance(1);
Ok(val)
}
fn get_u16<T: ByteOrder>(&mut self) -> io::Result<u16> {
let val = T::read_u16(*self);
self.advance(2);
Ok(val)
}
fn get_i32<T: ByteOrder>(&mut self) -> io::Result<i32> {
let val = T::read_i32(*self);
self.advance(4);
Ok(val)
}
fn get_u24<T: ByteOrder>(&mut self) -> io::Result<u32> {
let val = T::read_u24(*self);
self.advance(3);
Ok(val)
}
fn get_u32<T: ByteOrder>(&mut self) -> io::Result<u32> {
let val = T::read_u32(*self);
self.advance(4);
Ok(val)
}
fn get_u64<T: ByteOrder>(&mut self) -> io::Result<u64> {
let val = T::read_u64(*self);
self.advance(8);
Ok(val)
}
fn get_uint<T: ByteOrder>(&mut self, n: usize) -> io::Result<u64> {
let val = T::read_uint(*self, n);
self.advance(n);
Ok(val)
}
fn get_uint_lenenc<T: ByteOrder>(&mut self) -> io::Result<u64> {
Ok(match self.get_u8()? {
0xFC => self.get_u16::<T>()? as u64,
0xFD => self.get_u24::<T>()? as u64,
0xFE => self.get_u64::<T>()? as u64,
// ? 0xFF => panic!("int<lenenc> unprocessable first byte 0xFF"),
value => value as u64,
})
}
fn get_str(&mut self, len: usize) -> io::Result<&str> {
let buf = &self[..len];
self.advance(len);
if cfg!(debug_asserts) {
str::from_utf8(buf).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
} else {
Ok(unsafe { str::from_utf8_unchecked(buf) })
}
}
fn get_str_eof(&mut self) -> io::Result<&str> {
self.get_str(self.len())
}
fn get_str_nul(&mut self) -> io::Result<&str> {
let len = memchr(b'\0', &*self).ok_or(io::ErrorKind::InvalidData)?;
let s = &self.get_str(len + 1)?[..len];
Ok(s)
}
fn get_str_lenenc<T: ByteOrder>(&mut self) -> io::Result<&str> {
let len = self.get_uint_lenenc::<T>()?;
let s = self.get_str(len as usize)?;
Ok(s)
}
}

122
src/io/buf_mut.rs Normal file
View file

@ -0,0 +1,122 @@
use byteorder::ByteOrder;
use memchr::memchr;
use std::{io, mem::size_of, str, u16, u32, u8};
pub trait BufMut {
fn advance(&mut self, cnt: usize);
fn put_u8(&mut self, val: u8);
fn put_u16<T: ByteOrder>(&mut self, val: u16);
fn put_i16<T: ByteOrder>(&mut self, val: i16);
fn put_u24<T: ByteOrder>(&mut self, val: u32);
fn put_i32<T: ByteOrder>(&mut self, val: i32);
fn put_u32<T: ByteOrder>(&mut self, val: u32);
fn put_u64<T: ByteOrder>(&mut self, val: u64);
// TODO: Move to mariadb::io::BufMutExt
fn put_u64_lenenc<T: ByteOrder>(&mut self, val: u64);
fn put_str_nul(&mut self, val: &str);
// TODO: Move to mariadb::io::BufMutExt
fn put_str_lenenc<T: ByteOrder>(&mut self, val: &str);
// TODO: Move to mariadb::io::BufMutExt
fn put_str_eof(&mut self, val: &str);
}
impl BufMut for Vec<u8> {
fn advance(&mut self, cnt: usize) {
self.resize(self.len() + cnt, 0);
}
fn put_u8(&mut self, val: u8) {
self.push(val);
}
fn put_i16<T: ByteOrder>(&mut self, val: i16) {
let mut buf = [0; 4];
T::write_i16(&mut buf, val);
self.extend_from_slice(&buf);
}
fn put_u16<T: ByteOrder>(&mut self, val: u16) {
let mut buf = [0; 2];
T::write_u16(&mut buf, val);
self.extend_from_slice(&buf);
}
fn put_u24<T: ByteOrder>(&mut self, val: u32) {
let mut buf = [0; 3];
T::write_u24(&mut buf, val);
self.extend_from_slice(&buf);
}
fn put_i32<T: ByteOrder>(&mut self, val: i32) {
let mut buf = [0; 4];
T::write_i32(&mut buf, val);
self.extend_from_slice(&buf);
}
fn put_u32<T: ByteOrder>(&mut self, val: u32) {
let mut buf = [0; 4];
T::write_u32(&mut buf, val);
self.extend_from_slice(&buf);
}
fn put_u64<T: ByteOrder>(&mut self, val: u64) {
let mut buf = [0; 8];
T::write_u64(&mut buf, val);
self.extend_from_slice(&buf);
}
fn put_u64_lenenc<T: ByteOrder>(&mut self, value: u64) {
// https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers
if value > 0xFF_FF_FF {
// Integer value is encoded in the next 8 bytes (9 bytes total)
self.push(0xFE);
self.put_u64::<T>(value);
} else if value > u16::MAX as _ {
// Integer value is encoded in the next 3 bytes (4 bytes total)
self.push(0xFD);
self.put_u24::<T>(value as u32);
} else if value > u8::MAX as _ {
// Integer value is encoded in the next 2 bytes (3 bytes total)
self.push(0xFC);
self.put_u16::<T>(value as u16);
} else {
match value {
// If the value is of size u8 and one of the key bytes used in length encoding
// we must put that single byte as a u16
0xFB | 0xFC | 0xFD | 0xFE | 0xFF => {
self.push(0xFC);
self.put_u16::<T>(value as u16);
}
_ => {
self.push(value as u8);
}
}
}
}
fn put_str_eof(&mut self, val: &str) {
self.extend_from_slice(val.as_bytes());
}
fn put_str_nul(&mut self, val: &str) {
self.extend_from_slice(val.as_bytes());
self.push(0);
}
fn put_str_lenenc<T: ByteOrder>(&mut self, val: &str) {
self.put_u64_lenenc::<T>(val.len() as u64);
self.extend_from_slice(val.as_bytes());
}
}

26
src/io/byte_str.rs Normal file
View file

@ -0,0 +1,26 @@
use std::{
ascii::escape_default,
fmt::{self, Debug},
str::from_utf8,
};
// Wrapper type for byte slices that will debug print
// as a binary string
pub struct ByteStr<'a>(pub &'a [u8]);
impl Debug for ByteStr<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "b\"")?;
for &b in self.0 {
let part: Vec<u8> = escape_default(b).collect();
let s = from_utf8(&part).unwrap();
write!(f, "{}", s)?;
}
write!(f, "\"")?;
Ok(())
}
}

View file

@ -1,4 +1,8 @@
#[macro_use] #[macro_use]
mod buf_stream; mod buf_stream;
pub use self::buf_stream::BufStream; mod buf;
mod buf_mut;
mod byte_str;
pub use self::{buf::Buf, buf_mut::BufMut, buf_stream::BufStream, byte_str::ByteStr};

View file

@ -12,7 +12,7 @@ pub async fn execute(conn: &mut PostgresRawConnection) -> Result<u64, Error> {
Message::BindComplete | Message::ParseComplete | Message::DataRow(_) => {} Message::BindComplete | Message::ParseComplete | Message::DataRow(_) => {}
Message::CommandComplete(body) => { Message::CommandComplete(body) => {
rows = body.rows; rows = body.affected_rows();
} }
Message::ReadyForQuery(_) => { Message::ReadyForQuery(_) => {

View file

@ -4,7 +4,8 @@ use super::{
}; };
use crate::{connection::RawConnection, error::Error, io::BufStream, query::QueryParameters}; use crate::{connection::RawConnection, error::Error, io::BufStream, query::QueryParameters};
// use bytes::{BufMut, BytesMut}; // use bytes::{BufMut, BytesMut};
use super::protocol::Buf; use crate::io::Buf;
use byteorder::NetworkEndian;
use futures_core::{future::BoxFuture, stream::BoxStream}; use futures_core::{future::BoxFuture, stream::BoxStream};
use std::{ use std::{
io, io,
@ -69,16 +70,19 @@ impl PostgresRawConnection {
loop { loop {
// Read the message header (id + len) // Read the message header (id + len)
let mut header = ret_if_none!(self.stream.peek(5).await?); let mut header = ret_if_none!(self.stream.peek(5).await?);
log::trace!("recv:header {:?}", bytes::Bytes::from(&*header));
let id = header.get_u8()?; let id = header.get_u8()?;
let len = (header.get_u32()? - 4) as usize; let len = (header.get_u32::<NetworkEndian>()? - 4) as usize;
// Read the message body // Read the message body
self.stream.consume(5); self.stream.consume(5);
let body = ret_if_none!(self.stream.peek(len).await?); let body = ret_if_none!(self.stream.peek(len).await?);
log::trace!("recv {:?}", bytes::Bytes::from(&*body));
let message = match id { let message = match id {
b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)), b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)),
b'D' => Message::DataRow(Box::new(protocol::DataRow::decode(body)?)), b'D' => Message::DataRow(protocol::DataRow::decode(body)?),
b'S' => { b'S' => {
Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?)) Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?))
} }
@ -121,7 +125,14 @@ impl PostgresRawConnection {
} }
pub(super) fn write(&mut self, message: impl Encode) { pub(super) fn write(&mut self, message: impl Encode) {
let pos = self.stream.buffer_mut().len();
message.encode(self.stream.buffer_mut()); message.encode(self.stream.buffer_mut());
log::trace!(
"send {:?}",
bytes::Bytes::from(&self.stream.buffer_mut()[pos..])
);
} }
} }

View file

@ -1,7 +1,6 @@
mod backend; mod backend;
mod connection; mod connection;
// FIXME: Should only be public for benchmarks mod protocol;
pub mod protocol;
mod query; mod query;
mod row; mod row;
pub mod types; pub mod types;

View file

@ -1,4 +1,6 @@
use super::Decode; use super::Decode;
use crate::io::Buf;
use byteorder::NetworkEndian;
use std::io; use std::io;
#[derive(Debug)] #[derive(Debug)]
@ -28,8 +30,10 @@ pub enum Authentication {
GssContinue { data: Box<[u8]> }, GssContinue { data: Box<[u8]> },
/// SASL authentication is required. /// SASL authentication is required.
// FIXME: authentication mechanisms ///
Sasl, /// The message body is a list of SASL authentication mechanisms,
/// in the server's order of preference.
Sasl { mechanisms: Box<[Box<str>]> },
/// This message contains a SASL challenge. /// This message contains a SASL challenge.
SaslContinue { data: Box<[u8]> }, SaslContinue { data: Box<[u8]> },
@ -39,24 +43,100 @@ pub enum Authentication {
} }
impl Decode for Authentication { impl Decode for Authentication {
fn decode(src: &[u8]) -> io::Result<Self> { fn decode(mut buf: &[u8]) -> io::Result<Self> {
Ok(match src[0] { Ok(match buf.get_u32::<NetworkEndian>()? {
0 => Authentication::Ok, 0 => Authentication::Ok,
2 => Authentication::KerberosV5, 2 => Authentication::KerberosV5,
3 => Authentication::CleartextPassword, 3 => Authentication::CleartextPassword,
5 => { 5 => {
let mut salt = [0_u8; 4]; let mut salt = [0_u8; 4];
salt.copy_from_slice(&src[1..5]); salt.copy_from_slice(&buf);
Authentication::Md5Password { salt } Authentication::Md5Password { salt }
} }
6 => Authentication::ScmCredential, 6 => Authentication::ScmCredential,
7 => Authentication::Gss, 7 => Authentication::Gss,
8 => {
let mut data = Vec::with_capacity(buf.len());
data.extend_from_slice(buf);
Authentication::GssContinue {
data: data.into_boxed_slice(),
}
}
9 => Authentication::Sspi, 9 => Authentication::Sspi,
token => unimplemented!("decode not implemented for token: {}", token), 10 => {
let mut mechanisms = Vec::new();
while buf[0] != 0 {
mechanisms.push(buf.get_str_nul()?.into());
}
Authentication::Sasl {
mechanisms: mechanisms.into_boxed_slice(),
}
}
11 => {
let mut data = Vec::with_capacity(buf.len());
data.extend_from_slice(buf);
Authentication::SaslContinue {
data: data.into_boxed_slice(),
}
}
12 => {
let mut data = Vec::with_capacity(buf.len());
data.extend_from_slice(buf);
Authentication::SaslFinal {
data: data.into_boxed_slice(),
}
}
id => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("unknown authentication response: {}", id),
));
}
}) })
} }
} }
#[cfg(test)]
mod tests {
use super::{Authentication, Decode};
use matches::assert_matches;
const AUTH_OK: &[u8] = b"\0\0\0\0";
const AUTH_MD5: &[u8] = b"\0\0\0\x05\x93\x189\x98";
#[test]
fn it_decodes_auth_ok() {
let m = Authentication::decode(AUTH_OK).unwrap();
assert_matches!(m, Authentication::Ok);
}
#[test]
fn it_decodes_auth_md5_password() {
let m = Authentication::decode(AUTH_MD5).unwrap();
assert_matches!(
m,
Authentication::Md5Password {
salt: [147, 24, 57, 152]
}
);
}
}

View file

@ -1,4 +1,6 @@
use super::{Buf, Decode}; use super::Decode;
use crate::io::Buf;
use byteorder::NetworkEndian;
use std::io; use std::io;
#[derive(Debug)] #[derive(Debug)]
@ -23,11 +25,9 @@ impl BackendKeyData {
} }
impl Decode for BackendKeyData { impl Decode for BackendKeyData {
fn decode(mut src: &[u8]) -> io::Result<Self> { fn decode(mut buf: &[u8]) -> io::Result<Self> {
debug_assert_eq!(src.len(), 8); let process_id = buf.get_u32::<NetworkEndian>()?;
let secret_key = buf.get_u32::<NetworkEndian>()?;
let process_id = src.get_u32()?;
let secret_key = src.get_u32()?;
Ok(Self { Ok(Self {
process_id, process_id,
@ -39,7 +39,6 @@ impl Decode for BackendKeyData {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{BackendKeyData, Decode}; use super::{BackendKeyData, Decode};
use bytes::Bytes;
const BACKEND_KEY_DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; const BACKEND_KEY_DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+";

View file

@ -1,5 +1,6 @@
use super::{BufMut, Encode}; use super::{Encode};
use byteorder::{BigEndian, ByteOrder}; use crate::io::BufMut;
use byteorder::{BigEndian, ByteOrder, NetworkEndian};
pub struct Bind<'a> { pub struct Bind<'a> {
/// The name of the destination portal (an empty string selects the unnamed portal). /// The name of the destination portal (an empty string selects the unnamed portal).
@ -29,24 +30,32 @@ pub struct Bind<'a> {
impl Encode for Bind<'_> { impl Encode for Bind<'_> {
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'B'); buf.push(b'B');
let pos = buf.len(); let pos = buf.len();
buf.put_int_32(0); // skip over len buf.put_i32::<NetworkEndian>(0); // skip over len
buf.put_str(self.portal); buf.put_str_nul(self.portal);
buf.put_str(self.statement); buf.put_str_nul(self.statement);
buf.put_array_int_16(&self.formats); buf.put_i16::<NetworkEndian>(self.formats.len() as i16);
buf.put_int_16(self.values_len); for &format in self.formats {
buf.put_i16::<NetworkEndian>(format);
}
buf.put(self.values); buf.put_i16::<NetworkEndian>(self.values_len);
buf.put_array_int_16(&self.result_formats); buf.extend_from_slice(self.values);
buf.put_i16::<NetworkEndian>(self.result_formats.len() as i16);
for &format in self.result_formats {
buf.put_i16::<NetworkEndian>(format);
}
// Write-back the len to the beginning of this frame // Write-back the len to the beginning of this frame
let len = buf.len() - pos; let len = buf.len() - pos;
BigEndian::write_i32(&mut buf[pos..], len as i32); NetworkEndian::write_i32(&mut buf[pos..], len as i32);
} }
} }

View file

@ -1,4 +1,6 @@
use super::{BufMut, Encode}; use super::{Encode};
use crate::io::BufMut;
use byteorder::NetworkEndian;
/// Sent instead of [`StartupMessage`] with a new connection to cancel a running query on an existing /// Sent instead of [`StartupMessage`] with a new connection to cancel a running query on an existing
/// connection. /// connection.
@ -14,9 +16,9 @@ pub struct CancelRequest {
impl Encode for CancelRequest { impl Encode for CancelRequest {
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
buf.put_int_32(16); // message length buf.put_i32::<NetworkEndian>(16); // message length
buf.put_int_32(8087_7102); // constant for cancel request buf.put_i32::<NetworkEndian>(8087_7102); // constant for cancel request
buf.put_int_32(self.process_id); buf.put_i32::<NetworkEndian>(self.process_id);
buf.put_int_32(self.secret_key); buf.put_i32::<NetworkEndian>(self.secret_key);
} }
} }

View file

@ -1,4 +1,6 @@
use super::{BufMut, Encode}; use super::{Encode};
use crate::io::BufMut;
use byteorder::NetworkEndian;
#[repr(u8)] #[repr(u8)]
pub enum CloseKind { pub enum CloseKind {
@ -16,14 +18,17 @@ pub struct Close<'a> {
impl Encode for Close<'_> { impl Encode for Close<'_> {
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'C'); buf.push(b'C');
// len + kind + nul + len(string) // len + kind + nul + len(string)
buf.put_int_32((4 + 1 + 1 + self.name.len()) as i32); buf.put_i32::<NetworkEndian>((4 + 1 + 1 + self.name.len()) as i32);
buf.put_byte(match self.kind {
buf.push(match self.kind {
CloseKind::PreparedStatement => b'S', CloseKind::PreparedStatement => b'S',
CloseKind::Portal => b'P', CloseKind::Portal => b'P',
}); });
buf.put_str(self.name);
buf.put_str_nul(self.name);
} }
} }

View file

@ -1,27 +1,37 @@
use super::Decode; use super::Decode;
use memchr::memrchr; use crate::io::Buf;
use std::{io, str}; use std::io;
#[derive(Debug)] #[derive(Debug)]
pub struct CommandComplete { pub struct CommandComplete {
pub rows: u64, affected_rows: u64,
}
impl CommandComplete {
#[inline]
pub fn affected_rows(&self) -> u64 {
self.affected_rows
}
} }
impl Decode for CommandComplete { impl Decode for CommandComplete {
fn decode(src: &[u8]) -> io::Result<Self> { fn decode(mut buf: &[u8]) -> io::Result<Self> {
// TODO: MariaDb/MySQL return 0 for affected rows in a SELECT .. statement.
// PostgreSQL returns a row count. Should we force return 0 for compatibilities sake?
// Attempt to parse the last word in the command tag as an integer // Attempt to parse the last word in the command tag as an integer
// If it can't be parased, the tag is probably "CREATE TABLE" or something // If it can't be parased, the tag is probably "CREATE TABLE" or something
// and we should return 0 rows // and we should return 0 rows
// TODO: Use [atoi] or similar to parse an integer directly from the bytes let rows = buf
.get_str_nul()?
let rows_start = memrchr(b' ', src).unwrap_or(0); .rsplit(' ')
let mut buf = &src[(rows_start + 1)..(src.len() - 1)]; .next()
.and_then(|s| s.parse().ok())
let rows = unsafe { str::from_utf8_unchecked(buf) }; .unwrap_or(0);
Ok(Self { Ok(Self {
rows: rows.parse().unwrap_or(0), affected_rows: rows,
}) })
} }
} }
@ -39,27 +49,27 @@ mod tests {
fn it_decodes_command_complete_for_insert() { fn it_decodes_command_complete_for_insert() {
let message = CommandComplete::decode(COMMAND_COMPLETE_INSERT).unwrap(); let message = CommandComplete::decode(COMMAND_COMPLETE_INSERT).unwrap();
assert_eq!(message.rows, 1); assert_eq!(message.affected_rows(), 1);
} }
#[test] #[test]
fn it_decodes_command_complete_for_update() { fn it_decodes_command_complete_for_update() {
let message = CommandComplete::decode(COMMAND_COMPLETE_UPDATE).unwrap(); let message = CommandComplete::decode(COMMAND_COMPLETE_UPDATE).unwrap();
assert_eq!(message.rows, 512); assert_eq!(message.affected_rows(), 512);
} }
#[test] #[test]
fn it_decodes_command_complete_for_begin() { fn it_decodes_command_complete_for_begin() {
let message = CommandComplete::decode(COMMAND_COMPLETE_BEGIN).unwrap(); let message = CommandComplete::decode(COMMAND_COMPLETE_BEGIN).unwrap();
assert_eq!(message.rows, 0); assert_eq!(message.affected_rows(), 0);
} }
#[test] #[test]
fn it_decodes_command_complete_for_create_table() { fn it_decodes_command_complete_for_create_table() {
let message = CommandComplete::decode(COMMAND_COMPLETE_CREATE_TABLE).unwrap(); let message = CommandComplete::decode(COMMAND_COMPLETE_CREATE_TABLE).unwrap();
assert_eq!(message.rows, 0); assert_eq!(message.affected_rows(), 0);
} }
} }

View file

@ -1,4 +1,6 @@
use super::{BufMut, Encode}; use super::{Encode};
use crate::io::BufMut;
use byteorder::NetworkEndian;
// TODO: Implement Decode and think on an optimal representation // TODO: Implement Decode and think on an optimal representation
@ -19,9 +21,9 @@ pub struct CopyData<'a> {
impl Encode for CopyData<'_> { impl Encode for CopyData<'_> {
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'd'); buf.push(b'd');
// len + nul + len(string) // len + nul + len(string)
buf.put_int_32((4 + 1 + self.data.len()) as i32); buf.put_i32::<NetworkEndian>((4 + 1 + self.data.len()) as i32);
buf.put(&self.data); buf.extend_from_slice(&self.data);
} }
} }

View file

@ -1,4 +1,6 @@
use super::{BufMut, Encode}; use super::{Encode};
use crate::io::BufMut;
use byteorder::NetworkEndian;
// TODO: Implement Decode // TODO: Implement Decode
@ -7,7 +9,7 @@ pub struct CopyDone;
impl Encode for CopyDone { impl Encode for CopyDone {
#[inline] #[inline]
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'c'); buf.push(b'c');
buf.put_int_32(4); buf.put_i32::<NetworkEndian>(4);
} }
} }

View file

@ -1,4 +1,6 @@
use super::{BufMut, Encode}; use super::{Encode};
use crate::io::BufMut;
use byteorder::NetworkEndian;
pub struct CopyFail<'a> { pub struct CopyFail<'a> {
pub error: &'a str, pub error: &'a str,
@ -6,9 +8,9 @@ pub struct CopyFail<'a> {
impl Encode for CopyFail<'_> { impl Encode for CopyFail<'_> {
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'f'); buf.push(b'f');
// len + nul + len(string) // len + nul + len(string)
buf.put_int_32((4 + 1 + self.error.len()) as i32); buf.put_i32::<NetworkEndian>((4 + 1 + self.error.len()) as i32);
buf.put_str(&self.error); buf.put_str_nul(&self.error);
} }
} }

View file

@ -1,6 +1,7 @@
use super::{Buf, Decode}; use super::Decode;
use crate::io::{Buf, ByteStr};
use byteorder::NetworkEndian;
use std::{ use std::{
convert::TryInto,
fmt::{self, Debug}, fmt::{self, Debug},
io, io,
pin::Pin, pin::Pin,
@ -19,16 +20,16 @@ unsafe impl Sync for DataRow {}
impl Decode for DataRow { impl Decode for DataRow {
fn decode(mut buf: &[u8]) -> io::Result<Self> { fn decode(mut buf: &[u8]) -> io::Result<Self> {
let len = buf.get_u16()? as usize; let cnt = buf.get_u16::<NetworkEndian>()? as usize;
let buffer: Pin<Box<[u8]>> = Pin::new(buf.into()); let buffer: Pin<Box<[u8]>> = Pin::new(buf.into());
let mut buf = &*buffer; let mut buf = &*buffer;
let mut values = Vec::with_capacity(len); let mut values = Vec::with_capacity(cnt);
while values.len() < len { while values.len() < cnt {
// The length of the column value, in bytes (this count does not include itself). // The length of the column value, in bytes (this count does not include itself).
// Can be zero. As a special case, -1 indicates a NULL column value. // Can be zero. As a special case, -1 indicates a NULL column value.
// No value bytes follow in the NULL case. // No value bytes follow in the NULL case.
let value_len = buf.get_i32()?; let value_len = buf.get_i32::<NetworkEndian>()?;
if value_len == -1 { if value_len == -1 {
values.push(None); values.push(None);
@ -65,8 +66,16 @@ impl DataRow {
} }
impl Debug for DataRow { impl Debug for DataRow {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
unimplemented!(); write!(f, "DataRow(")?;
f.debug_list()
.entries((0..self.len()).map(|i| self.get(i).map(ByteStr)))
.finish()?;
write!(f, ")")?;
Ok(())
} }
} }
@ -79,17 +88,17 @@ mod tests {
#[test] #[test]
fn it_decodes_data_row() { fn it_decodes_data_row() {
let message = DataRow::decode(DATA_ROW).unwrap(); let m = DataRow::decode(DATA_ROW).unwrap();
assert_eq!(message.len(), 3); assert_eq!(m.len(), 3);
assert_eq!(message.get(0), Some(&b"1"[..])); assert_eq!(m.get(0), Some(&b"1"[..]));
assert_eq!(message.get(1), Some(&b"2"[..])); assert_eq!(m.get(1), Some(&b"2"[..]));
assert_eq!(message.get(2), Some(&b"3"[..])); assert_eq!(m.get(2), Some(&b"3"[..]));
}
#[bench] assert_eq!(
fn bench_decode_data_row(b: &mut test::Bencher) { format!("{:?}", m),
b.iter(|| DataRow::decode(DATA_ROW).unwrap()); "DataRow([Some(b\"1\"), Some(b\"2\"), Some(b\"3\")])"
);
} }
} }

View file

@ -1,86 +1,7 @@
use memchr::memchr; use std::io;
use std::{convert::TryInto, io, str};
pub trait Decode { pub trait Decode {
fn decode(src: &[u8]) -> io::Result<Self> fn decode(src: &[u8]) -> io::Result<Self>
where where
Self: Sized; Self: Sized;
} }
#[inline]
pub(crate) fn get_str(src: &[u8]) -> &str {
let end = memchr(b'\0', &src).expect("expected null terminator in UTF-8 string");
let buf = &src[..end];
unsafe { str::from_utf8_unchecked(buf) }
}
pub trait Buf {
fn advance(&mut self, cnt: usize);
// An n-bit integer in network byte order (IntN)
fn get_u8(&mut self) -> io::Result<u8>;
fn get_u16(&mut self) -> io::Result<u16>;
fn get_i32(&mut self) -> io::Result<i32>;
fn get_u32(&mut self) -> io::Result<u32>;
// A null-terminated string
fn get_str_null(&mut self) -> io::Result<&str>;
}
impl<'a> Buf for &'a [u8] {
fn advance(&mut self, cnt: usize) {
*self = &self[cnt..];
}
fn get_u8(&mut self) -> io::Result<u8> {
let val = self[0];
self.advance(1);
Ok(val)
}
fn get_u16(&mut self) -> io::Result<u16> {
let val: [u8; 2] = (&self[..2])
.try_into()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
self.advance(2);
Ok(u16::from_be_bytes(val))
}
fn get_i32(&mut self) -> io::Result<i32> {
let val: [u8; 4] = (&self[..4])
.try_into()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
self.advance(4);
Ok(i32::from_be_bytes(val))
}
fn get_u32(&mut self) -> io::Result<u32> {
let val: [u8; 4] = (&self[..4])
.try_into()
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?;
self.advance(4);
Ok(u32::from_be_bytes(val))
}
fn get_str_null(&mut self) -> io::Result<&str> {
let end = memchr(b'\0', &*self).ok_or(io::ErrorKind::InvalidData)?;
let buf = &self[..end];
self.advance(end + 1);
if cfg!(debug_asserts) {
str::from_utf8(buf).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
} else {
Ok(unsafe { str::from_utf8_unchecked(buf) })
}
}
}

View file

@ -1,4 +1,6 @@
use super::{BufMut, Encode}; use super::{Encode};
use crate::io::BufMut;
use byteorder::NetworkEndian;
#[repr(u8)] #[repr(u8)]
pub enum DescribeKind { pub enum DescribeKind {
@ -16,14 +18,14 @@ pub struct Describe<'a> {
impl Encode for Describe<'_> { impl Encode for Describe<'_> {
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'D'); buf.push(b'D');
// len + kind + nul + len(string) // len + kind + nul + len(string)
buf.put_int_32((4 + 1 + 1 + self.name.len()) as i32); buf.put_i32::<NetworkEndian>((4 + 1 + 1 + self.name.len()) as i32);
buf.put_byte(match self.kind { buf.push(match self.kind {
DescribeKind::PreparedStatement => b'S', DescribeKind::PreparedStatement => b'S',
DescribeKind::Portal => b'P', DescribeKind::Portal => b'P',
}); });
buf.put_str(self.name); buf.put_str_nul(self.name);
} }
} }

View file

@ -1,93 +1,3 @@
pub trait Encode { pub trait Encode {
fn encode(&self, buf: &mut Vec<u8>); fn encode(&self, buf: &mut Vec<u8>);
} }
pub trait BufMut {
fn put(&mut self, bytes: &[u8]);
fn put_byte(&mut self, value: u8);
fn put_int_16(&mut self, value: i16);
fn put_uint_16(&mut self, value: u16);
fn put_int_32(&mut self, value: i32);
fn put_uint_32(&mut self, value: u32);
fn put_array_int_16(&mut self, values: &[i16]);
fn put_array_int_32(&mut self, values: &[i32]);
fn put_array_uint_32(&mut self, values: &[u32]);
fn put_str(&mut self, value: &str);
}
impl BufMut for Vec<u8> {
#[inline]
fn put(&mut self, bytes: &[u8]) {
self.extend_from_slice(bytes);
}
#[inline]
fn put_byte(&mut self, value: u8) {
self.push(value);
}
#[inline]
fn put_int_16(&mut self, value: i16) {
self.extend_from_slice(&value.to_be_bytes());
}
#[inline]
fn put_uint_16(&mut self, value: u16) {
self.extend_from_slice(&value.to_be_bytes());
}
#[inline]
fn put_int_32(&mut self, value: i32) {
self.extend_from_slice(&value.to_be_bytes());
}
#[inline]
fn put_uint_32(&mut self, value: u32) {
self.extend_from_slice(&value.to_be_bytes());
}
#[inline]
fn put_str(&mut self, value: &str) {
self.extend_from_slice(value.as_bytes());
self.push(0);
}
#[inline]
fn put_array_int_16(&mut self, values: &[i16]) {
// FIXME: What happens here when len(values) > i16
self.put_int_16(values.len() as i16);
for value in values {
self.put_int_16(*value);
}
}
#[inline]
fn put_array_int_32(&mut self, values: &[i32]) {
// FIXME: What happens here when len(values) > i16
self.put_int_16(values.len() as i16);
for value in values {
self.put_int_32(*value);
}
}
#[inline]
fn put_array_uint_32(&mut self, values: &[u32]) {
// FIXME: What happens here when len(values) > i16
self.put_int_16(values.len() as i16);
for value in values {
self.put_uint_32(*value);
}
}
}

View file

@ -1,4 +1,6 @@
use super::{BufMut, Encode}; use super::{Encode};
use crate::io::BufMut;
use byteorder::NetworkEndian;
pub struct Execute<'a> { pub struct Execute<'a> {
/// The name of the portal to execute (an empty string selects the unnamed portal). /// The name of the portal to execute (an empty string selects the unnamed portal).
@ -11,10 +13,10 @@ pub struct Execute<'a> {
impl Encode for Execute<'_> { impl Encode for Execute<'_> {
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'E'); buf.push(b'E');
// len + nul + len(string) + limit // len + nul + len(string) + limit
buf.put_int_32((4 + 1 + self.portal.len() + 4) as i32); buf.put_i32::<NetworkEndian>((4 + 1 + self.portal.len() + 4) as i32);
buf.put_str(&self.portal); buf.put_str_nul(&self.portal);
buf.put_int_32(self.limit); buf.put_i32::<NetworkEndian>(self.limit);
} }
} }

View file

@ -1,11 +1,13 @@
use super::{BufMut, Encode}; use super::{Encode};
use crate::io::BufMut;
use byteorder::NetworkEndian;
pub struct Flush; pub struct Flush;
impl Encode for Flush { impl Encode for Flush {
#[inline] #[inline]
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'H'); buf.push(b'H');
buf.put_int_32(4); buf.put_i32::<NetworkEndian>(4);
} }
} }

View file

@ -14,7 +14,7 @@ pub enum Message {
BackendKeyData(BackendKeyData), BackendKeyData(BackendKeyData),
ReadyForQuery(ReadyForQuery), ReadyForQuery(ReadyForQuery),
CommandComplete(CommandComplete), CommandComplete(CommandComplete),
DataRow(Box<DataRow>), DataRow(DataRow),
Response(Box<Response>), Response(Box<Response>),
NotificationResponse(Box<NotificationResponse>), NotificationResponse(Box<NotificationResponse>),
ParseComplete, ParseComplete,

View file

@ -32,7 +32,7 @@ pub use self::{
copy_done::CopyDone, copy_done::CopyDone,
copy_fail::CopyFail, copy_fail::CopyFail,
describe::Describe, describe::Describe,
encode::{BufMut, Encode}, encode::Encode,
execute::Execute, execute::Execute,
flush::Flush, flush::Flush,
parse::Parse, parse::Parse,
@ -43,30 +43,24 @@ pub use self::{
terminate::Terminate, terminate::Terminate,
}; };
// TODO: Audit backend protocol
mod authentication; mod authentication;
mod backend_key_data; mod backend_key_data;
mod command_complete; mod command_complete;
mod data_row; mod data_row;
mod decode; mod decode;
mod message;
mod notification_response; mod notification_response;
mod parameter_description; mod parameter_description;
mod parameter_status; mod parameter_status;
mod ready_for_query; mod ready_for_query;
mod response; mod response;
// TODO: Audit backend protocol
mod message;
pub use self::{ pub use self::{
authentication::Authentication, authentication::Authentication, backend_key_data::BackendKeyData,
backend_key_data::BackendKeyData, command_complete::CommandComplete, data_row::DataRow, decode::Decode, message::Message,
command_complete::CommandComplete, notification_response::NotificationResponse, parameter_description::ParameterDescription,
data_row::DataRow, parameter_status::ParameterStatus, ready_for_query::ReadyForQuery, response::Response,
decode::{Buf, Decode},
message::Message,
notification_response::NotificationResponse,
parameter_description::ParameterDescription,
parameter_status::ParameterStatus,
ready_for_query::ReadyForQuery,
response::Response,
}; };

View file

@ -1,10 +1,11 @@
use super::{Buf, Decode}; use super::Decode;
use byteorder::{BigEndian, ByteOrder}; use crate::io::Buf;
use byteorder::NetworkEndian;
use std::{fmt, io, pin::Pin, ptr::NonNull}; use std::{fmt, io, pin::Pin, ptr::NonNull};
pub struct NotificationResponse { pub struct NotificationResponse {
#[used] #[used]
storage: Pin<Vec<u8>>, buffer: Pin<Vec<u8>>,
pid: u32, pid: u32,
channel_name: NonNull<str>, channel_name: NonNull<str>,
message: NonNull<str>, message: NonNull<str>,
@ -44,18 +45,17 @@ impl fmt::Debug for NotificationResponse {
} }
impl Decode for NotificationResponse { impl Decode for NotificationResponse {
fn decode(mut src: &[u8]) -> io::Result<Self> { fn decode(mut buf: &[u8]) -> io::Result<Self> {
let pid = src.get_u32()?; let pid = buf.get_u32::<NetworkEndian>()?;
// offset from pid=4 let buffer = Pin::new(buf.into());
let storage = Pin::new(src.into()); let mut buf: &[u8] = &*buffer;
let mut src: &[u8] = &*storage;
let channel_name = src.get_str_null()?.into(); let channel_name = buf.get_str_nul()?.into();
let message = src.get_str_null()?.into(); let message = buf.get_str_nul()?.into();
Ok(Self { Ok(Self {
storage, buffer,
pid, pid,
channel_name, channel_name,
message, message,
@ -77,5 +77,11 @@ mod tests {
assert_eq!(message.pid(), 0x34201002); assert_eq!(message.pid(), 0x34201002);
assert_eq!(message.channel_name(), "TEST-CHANNEL"); assert_eq!(message.channel_name(), "TEST-CHANNEL");
assert_eq!(message.message(), "THIS IS A TEST"); assert_eq!(message.message(), "THIS IS A TEST");
assert_eq!(
format!("{:?}", message),
"NotificationResponse { pid: 874516482, channel_name: \"TEST-CHANNEL\", message: \
\"THIS IS A TEST\" }"
);
} }
} }

View file

@ -1,24 +1,23 @@
use super::{Buf, Decode}; use super::Decode;
use byteorder::{BigEndian, ByteOrder}; use crate::io::Buf;
use std::{io, mem::size_of}; use byteorder::NetworkEndian;
use std::io;
type ObjectId = u32;
#[derive(Debug)] #[derive(Debug)]
pub struct ParameterDescription { pub struct ParameterDescription {
ids: Box<[ObjectId]>, ids: Box<[u32]>,
} }
impl Decode for ParameterDescription { impl Decode for ParameterDescription {
fn decode(mut src: &[u8]) -> io::Result<Self> { fn decode(mut buf: &[u8]) -> io::Result<Self> {
let count = src.get_u16()?; let cnt = buf.get_u16::<NetworkEndian>()? as usize;
let mut ids = Vec::with_capacity(count as usize); let mut ids = Vec::with_capacity(cnt);
for i in 0..count { for i in 0..cnt {
ids.push(src.get_u32()?); ids.push(buf.get_u32::<NetworkEndian>()?);
} }
Ok(ParameterDescription { Ok(Self {
ids: ids.into_boxed_slice(), ids: ids.into_boxed_slice(),
}) })
} }
@ -31,8 +30,8 @@ mod test {
#[test] #[test]
fn it_decodes_parameter_description() { fn it_decodes_parameter_description() {
let src = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"; let buf = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00";
let desc = ParameterDescription::decode(src).unwrap(); let desc = ParameterDescription::decode(buf).unwrap();
assert_eq!(desc.ids.len(), 2); assert_eq!(desc.ids.len(), 2);
assert_eq!(desc.ids[0], 0x0000_0000); assert_eq!(desc.ids[0], 0x0000_0000);
@ -41,8 +40,8 @@ mod test {
#[test] #[test]
fn it_decodes_empty_parameter_description() { fn it_decodes_empty_parameter_description() {
let src = b"\x00\x00"; let buf = b"\x00\x00";
let desc = ParameterDescription::decode(src).unwrap(); let desc = ParameterDescription::decode(buf).unwrap();
assert_eq!(desc.ids.len(), 0); assert_eq!(desc.ids.len(), 0);
} }

View file

@ -1,11 +1,16 @@
use super::decode::{Buf, Decode}; use super::decode::Decode;
use std::{io, pin::Pin, ptr::NonNull, str}; use crate::io::Buf;
use std::{
fmt::{self, Debug},
io,
pin::Pin,
ptr::NonNull,
str,
};
// FIXME: Use &str functions for a custom Debug
#[derive(Debug)]
pub struct ParameterStatus { pub struct ParameterStatus {
#[used] #[used]
storage: Pin<Box<[u8]>>, buffer: Pin<Box<[u8]>>,
name: NonNull<str>, name: NonNull<str>,
value: NonNull<str>, value: NonNull<str>,
} }
@ -29,21 +34,30 @@ impl ParameterStatus {
} }
impl Decode for ParameterStatus { impl Decode for ParameterStatus {
fn decode(src: &[u8]) -> io::Result<Self> { fn decode(buf: &[u8]) -> io::Result<Self> {
let storage = Pin::new(src.into()); let buffer = Pin::new(buf.into());
let mut src: &[u8] = &*storage; let mut buf: &[u8] = &*buffer;
let name = NonNull::from(src.get_str_null()?); let name = buf.get_str_nul()?.into();
let value = NonNull::from(src.get_str_null()?); let value = buf.get_str_nul()?.into();
Ok(Self { Ok(Self {
storage, buffer,
name, name,
value, value,
}) })
} }
} }
impl fmt::Debug for ParameterStatus {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("ParameterStatus")
.field("name", &self.name())
.field("value", &self.value())
.finish()
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{Decode, ParameterStatus}; use super::{Decode, ParameterStatus};
@ -56,10 +70,10 @@ mod tests {
assert_eq!(message.name(), "session_authorization"); assert_eq!(message.name(), "session_authorization");
assert_eq!(message.value(), "postgres"); assert_eq!(message.value(), "postgres");
}
#[bench] assert_eq!(
fn bench_decode_param_status(b: &mut test::Bencher) { format!("{:?}", message),
b.iter(|| ParameterStatus::decode(PARAM_STATUS).unwrap()); "ParameterStatus { name: \"session_authorization\", value: \"postgres\" }"
);
} }
} }

View file

@ -1,4 +1,6 @@
use super::{BufMut, Encode}; use super::{Encode};
use crate::io::BufMut;
use byteorder::NetworkEndian;
pub struct Parse<'a> { pub struct Parse<'a> {
pub portal: &'a str, pub portal: &'a str,
@ -8,15 +10,19 @@ pub struct Parse<'a> {
impl Encode for Parse<'_> { impl Encode for Parse<'_> {
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'P'); buf.push(b'P');
// len + portal + nul + query + null + len(param_types) + param_types // len + portal + nul + query + null + len(param_types) + param_types
let len = 4 + self.portal.len() + 1 + self.query.len() + 1 + 2 + self.param_types.len() * 4; let len = 4 + self.portal.len() + 1 + self.query.len() + 1 + 2 + self.param_types.len() * 4;
buf.put_int_32(len as i32); buf.put_i32::<NetworkEndian>(len as i32);
buf.put_str(self.portal); buf.put_str_nul(self.portal);
buf.put_str(self.query); buf.put_str_nul(self.query);
buf.put_array_uint_32(&self.param_types); buf.put_i16::<NetworkEndian>(self.param_types.len() as i16);
for &type_ in self.param_types {
buf.put_u32::<NetworkEndian>(type_);
}
} }
} }

View file

@ -1,4 +1,6 @@
use super::{BufMut, Encode}; use super::Encode;
use crate::io::BufMut;
use byteorder::NetworkEndian;
use md5::{Digest, Md5}; use md5::{Digest, Md5};
#[derive(Debug)] #[derive(Debug)]
@ -13,13 +15,13 @@ pub enum PasswordMessage<'a> {
impl Encode for PasswordMessage<'_> { impl Encode for PasswordMessage<'_> {
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'p'); buf.push(b'p');
match self { match self {
PasswordMessage::Cleartext(s) => { PasswordMessage::Cleartext(s) => {
// len + password + nul // len + password + nul
buf.put_int_32((4 + s.len() + 1) as i32); buf.put_u32::<NetworkEndian>((4 + s.len() + 1) as u32);
buf.put_str(s); buf.put_str_nul(s);
} }
PasswordMessage::Md5 { PasswordMessage::Md5 {
@ -40,11 +42,44 @@ impl Encode for PasswordMessage<'_> {
let salted = hex::encode(hasher.result()); let salted = hex::encode(hasher.result());
// len + "md5" + (salted) // len + "md5" + (salted)
buf.put_int_32((4 + 3 + salted.len()) as i32); buf.put_u32::<NetworkEndian>((4 + 3 + salted.len() + 1) as u32);
buf.put(b"md5"); buf.extend_from_slice(b"md5");
buf.put(salted.as_bytes()); buf.extend_from_slice(salted.as_bytes());
buf.push(0);
} }
} }
} }
} }
#[cfg(test)]
mod tests {
use super::{Encode, PasswordMessage};
const PASSWORD_CLEAR: &[u8] = b"p\0\0\0\rpassword\0";
const PASSWORD_MD5: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0";
#[test]
fn it_encodes_password_clear() {
let mut buf = Vec::new();
let m = PasswordMessage::Cleartext("password");
m.encode(&mut buf);
assert_eq!(buf, PASSWORD_CLEAR);
}
#[test]
fn it_encodes_password_md5() {
let mut buf = Vec::new();
let m = PasswordMessage::Md5 {
password: "password",
user: "root",
salt: [147, 24, 57, 152],
};
m.encode(&mut buf);
assert_eq!(buf, PASSWORD_MD5);
}
}

View file

@ -1,15 +1,17 @@
use super::{BufMut, Encode}; use super::{Encode};
use crate::io::BufMut;
use byteorder::NetworkEndian;
pub struct Query<'a>(pub &'a str); pub struct Query<'a>(pub &'a str);
impl Encode for Query<'_> { impl Encode for Query<'_> {
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'Q'); buf.push(b'Q');
// len + query + nul // len + query + nul
buf.put_int_32((4 + self.0.len() + 1) as i32); buf.put_i32::<NetworkEndian>((4 + self.0.len() + 1) as i32);
buf.put_str(self.0); buf.put_str_nul(self.0);
} }
} }

View file

@ -28,10 +28,9 @@ impl ReadyForQuery {
} }
impl Decode for ReadyForQuery { impl Decode for ReadyForQuery {
fn decode(src: &[u8]) -> io::Result<Self> { fn decode(buf: &[u8]) -> io::Result<Self> {
Ok(Self { Ok(Self {
status: match src[0] { status: match buf[0] {
// FIXME: Variant value are duplicated with declaration
b'I' => TransactionStatus::Idle, b'I' => TransactionStatus::Idle,
b'T' => TransactionStatus::Transaction, b'T' => TransactionStatus::Transaction,
b'E' => TransactionStatus::Error, b'E' => TransactionStatus::Error,

View file

@ -1,4 +1,5 @@
use super::{decode::get_str, Decode}; use super::Decode;
use crate::io::Buf;
use std::{ use std::{
fmt, io, fmt, io,
pin::Pin, pin::Pin,
@ -73,10 +74,9 @@ impl FromStr for Severity {
} }
} }
#[derive(Clone)]
pub struct Response { pub struct Response {
#[used] #[used]
storage: Pin<Box<[u8]>>, buffer: Pin<Box<[u8]>>,
severity: Severity, severity: Severity,
code: NonNull<str>, code: NonNull<str>,
message: NonNull<str>, message: NonNull<str>,
@ -225,44 +225,41 @@ impl fmt::Debug for Response {
} }
impl Decode for Response { impl Decode for Response {
fn decode(src: &[u8]) -> io::Result<Self> { fn decode(buf: &[u8]) -> io::Result<Self> {
let storage: Pin<Box<[u8]>> = Pin::new(src.into()); let buffer: Pin<Box<[u8]>> = Pin::new(buf.into());
let mut buf: &[u8] = &*buffer;
let mut code = None::<&str>; let mut code = None::<NonNull<str>>;
let mut message = None::<&str>; let mut message = None::<NonNull<str>>;
let mut severity = None::<&str>; let mut severity = None::<NonNull<str>>;
let mut severity_non_local = None::<Severity>; let mut severity_non_local = None::<Severity>;
let mut detail = None::<&str>; let mut detail = None::<NonNull<str>>;
let mut hint = None::<&str>; let mut hint = None::<NonNull<str>>;
let mut position = None::<usize>; let mut position = None::<usize>;
let mut internal_position = None::<usize>; let mut internal_position = None::<usize>;
let mut internal_query = None::<&str>; let mut internal_query = None::<NonNull<str>>;
let mut where_ = None::<&str>; let mut where_ = None::<NonNull<str>>;
let mut schema = None::<&str>; let mut schema = None::<NonNull<str>>;
let mut table = None::<&str>; let mut table = None::<NonNull<str>>;
let mut column = None::<&str>; let mut column = None::<NonNull<str>>;
let mut data_type = None::<&str>; let mut data_type = None::<NonNull<str>>;
let mut constraint = None::<&str>; let mut constraint = None::<NonNull<str>>;
let mut file = None::<&str>; let mut file = None::<NonNull<str>>;
let mut line = None::<usize>; let mut line = None::<usize>;
let mut routine = None::<&str>; let mut routine = None::<NonNull<str>>;
let mut idx = 0;
loop { loop {
let field_type = storage[idx]; let field_type = buf.get_u8()?;
idx += 1;
if field_type == 0 { if field_type == 0 {
break; break;
} }
let field_value = get_str(&storage[idx..]); let field_value = buf.get_str_nul()?;
idx += field_value.len() + 1;
match field_type { match field_type {
b'S' => { b'S' => {
severity = Some(field_value); severity = Some(field_value.into());
} }
b'V' => { b'V' => {
@ -270,19 +267,19 @@ impl Decode for Response {
} }
b'C' => { b'C' => {
code = Some(field_value); code = Some(field_value.into());
} }
b'M' => { b'M' => {
message = Some(field_value); message = Some(field_value.into());
} }
b'D' => { b'D' => {
detail = Some(field_value); detail = Some(field_value.into());
} }
b'H' => { b'H' => {
hint = Some(field_value); hint = Some(field_value.into());
} }
b'P' => { b'P' => {
@ -302,35 +299,35 @@ impl Decode for Response {
} }
b'q' => { b'q' => {
internal_query = Some(field_value); internal_query = Some(field_value.into());
} }
b'w' => { b'w' => {
where_ = Some(field_value); where_ = Some(field_value.into());
} }
b's' => { b's' => {
schema = Some(field_value); schema = Some(field_value.into());
} }
b't' => { b't' => {
table = Some(field_value); table = Some(field_value.into());
} }
b'c' => { b'c' => {
column = Some(field_value); column = Some(field_value.into());
} }
b'd' => { b'd' => {
data_type = Some(field_value); data_type = Some(field_value.into());
} }
b'n' => { b'n' => {
constraint = Some(field_value); constraint = Some(field_value.into());
} }
b'F' => { b'F' => {
file = Some(field_value); file = Some(field_value.into());
} }
b'L' => { b'L' => {
@ -342,38 +339,43 @@ impl Decode for Response {
} }
b'R' => { b'R' => {
routine = Some(field_value); routine = Some(field_value.into());
} }
_ => { _ => {
unimplemented!( // TODO: Should we return these somehow, like in a map?
"response message field {:?} not implemented", return Err(io::Error::new(
field_type as char io::ErrorKind::InvalidData,
); format!("received unknown field in Response: {}", field_type),
));
} }
} }
} }
let severity = severity_non_local let severity = severity_non_local
.or_else(move || severity?.parse().ok()) .or_else(move || unsafe { severity?.as_ref() }.parse().ok())
.expect("`severity` required by protocol"); .ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
"did not receieve field `severity` for Response",
)
})?;
let code = NonNull::from(code.expect("`code` required by protocol")); let code = code.ok_or_else(|| {
let message = NonNull::from(message.expect("`message` required by protocol")); io::Error::new(
let detail = detail.map(NonNull::from); io::ErrorKind::InvalidData,
let hint = hint.map(NonNull::from); "did not receieve field `code` for Response",
let internal_query = internal_query.map(NonNull::from); )
let where_ = where_.map(NonNull::from); })?;
let schema = schema.map(NonNull::from); let message = message.ok_or_else(|| {
let table = table.map(NonNull::from); io::Error::new(
let column = column.map(NonNull::from); io::ErrorKind::InvalidData,
let data_type = data_type.map(NonNull::from); "did not receieve field `message` for Response",
let constraint = constraint.map(NonNull::from); )
let file = file.map(NonNull::from); })?;
let routine = routine.map(NonNull::from);
Ok(Self { Ok(Self {
storage, buffer,
severity, severity,
code, code,
message, message,

View file

@ -1,5 +1,7 @@
use super::{BufMut, Encode}; use super::{Encode};
use crate::io::BufMut;
use byteorder::{BigEndian, ByteOrder}; use byteorder::{BigEndian, ByteOrder};
use byteorder::NetworkEndian;
pub struct StartupMessage<'a> { pub struct StartupMessage<'a> {
pub params: &'a [(&'a str, &'a str)], pub params: &'a [(&'a str, &'a str)],
@ -8,17 +10,17 @@ pub struct StartupMessage<'a> {
impl Encode for StartupMessage<'_> { impl Encode for StartupMessage<'_> {
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
let pos = buf.len(); let pos = buf.len();
buf.put_int_32(0); // skip over len buf.put_i32::<NetworkEndian>(0); // skip over len
// protocol version number (3.0) // protocol version number (3.0)
buf.put_int_32(196_608); buf.put_i32::<NetworkEndian>(196_608);
for (name, value) in self.params { for (name, value) in self.params {
buf.put_str(name); buf.put_str_nul(name);
buf.put_str(value); buf.put_str_nul(value);
} }
buf.put_byte(0); buf.push(0);
// Write-back the len to the beginning of this frame // Write-back the len to the beginning of this frame
let len = buf.len() - pos; let len = buf.len() - pos;

View file

@ -1,11 +1,13 @@
use super::{BufMut, Encode}; use super::{Encode};
use crate::io::BufMut;
use byteorder::NetworkEndian;
pub struct Sync; pub struct Sync;
impl Encode for Sync { impl Encode for Sync {
#[inline] #[inline]
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'S'); buf.push(b'S');
buf.put_int_32(4); buf.put_i32::<NetworkEndian>(4);
} }
} }

View file

@ -1,11 +1,13 @@
use super::{BufMut, Encode}; use super::{Encode};
use crate::io::BufMut;
use byteorder::NetworkEndian;
pub struct Terminate; pub struct Terminate;
impl Encode for Terminate { impl Encode for Terminate {
#[inline] #[inline]
fn encode(&self, buf: &mut Vec<u8>) { fn encode(&self, buf: &mut Vec<u8>) {
buf.put_byte(b'X'); buf.push(b'X');
buf.put_int_32(4); buf.put_i32::<NetworkEndian>(4);
} }
} }

View file

@ -1,13 +1,15 @@
use super::{ use super::{
protocol::{self, BufMut}, protocol,
Postgres, PostgresRawConnection, Postgres, PostgresRawConnection,
}; };
use crate::{ use crate::{
io::BufMut,
query::QueryParameters, query::QueryParameters,
serialize::{IsNull, ToSql}, serialize::{IsNull, ToSql},
types::HasSqlType, types::HasSqlType,
}; };
use byteorder::{BigEndian, ByteOrder}; use byteorder::{BigEndian, ByteOrder};
use byteorder::NetworkEndian;
pub struct PostgresQueryParameters { pub struct PostgresQueryParameters {
// OIDs of the bind parameters // OIDs of the bind parameters
@ -40,7 +42,7 @@ impl QueryParameters for PostgresQueryParameters {
self.types.push(<Postgres as HasSqlType<T>>::metadata().oid); self.types.push(<Postgres as HasSqlType<T>>::metadata().oid);
let pos = self.buf.len(); let pos = self.buf.len();
self.buf.put_int_32(0); self.buf.put_i32::<NetworkEndian>(0);
let len = if let IsNull::No = value.to_sql(&mut self.buf) { let len = if let IsNull::No = value.to_sql(&mut self.buf) {
(self.buf.len() - pos - 4) as i32 (self.buf.len() - pos - 4) as i32

View file

@ -1,7 +1,7 @@
use super::{protocol::DataRow, Postgres}; use super::{protocol::DataRow, Postgres};
use crate::row::Row; use crate::row::Row;
pub struct PostgresRow(pub(crate) Box<DataRow>); pub struct PostgresRow(pub(crate) DataRow);
impl Row for PostgresRow { impl Row for PostgresRow {
type Backend = Postgres; type Backend = Postgres;