From c8559cac84b12f7891083fa6d4f15c0a260a3b59 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Wed, 28 Aug 2019 11:01:55 -0700 Subject: [PATCH] Remove postgres::protocol::{Buf, BufMut} and use crate::io::{Buf, BufMut} instead --- Cargo.toml | 3 + src/io/buf.rs | 131 ++++++++++++++++++ src/io/buf_mut.rs | 122 ++++++++++++++++ src/io/byte_str.rs | 26 ++++ src/io/mod.rs | 6 +- src/postgres/connection/execute.rs | 2 +- src/postgres/connection/mod.rs | 17 ++- src/postgres/mod.rs | 3 +- src/postgres/protocol/authentication.rs | 92 +++++++++++- src/postgres/protocol/backend_key_data.rs | 13 +- src/postgres/protocol/bind.rs | 31 +++-- src/postgres/protocol/cancel_request.rs | 12 +- src/postgres/protocol/close.rs | 15 +- src/postgres/protocol/command_complete.rs | 40 ++++-- src/postgres/protocol/copy_data.rs | 10 +- src/postgres/protocol/copy_done.rs | 8 +- src/postgres/protocol/copy_fail.rs | 10 +- src/postgres/protocol/data_row.rs | 43 +++--- src/postgres/protocol/decode.rs | 81 +---------- src/postgres/protocol/describe.rs | 12 +- src/postgres/protocol/encode.rs | 90 ------------ src/postgres/protocol/execute.rs | 12 +- src/postgres/protocol/flush.rs | 8 +- src/postgres/protocol/message.rs | 2 +- src/postgres/protocol/mod.rs | 24 ++-- .../protocol/notification_response.rs | 28 ++-- .../protocol/parameter_description.rs | 31 ++--- src/postgres/protocol/parameter_status.rs | 44 ++++-- src/postgres/protocol/parse.rs | 18 ++- src/postgres/protocol/password_message.rs | 49 ++++++- src/postgres/protocol/query.rs | 10 +- src/postgres/protocol/ready_for_query.rs | 5 +- src/postgres/protocol/response.rs | 120 ++++++++-------- src/postgres/protocol/startup_message.rs | 14 +- src/postgres/protocol/sync.rs | 8 +- src/postgres/protocol/terminate.rs | 8 +- src/postgres/query.rs | 6 +- src/postgres/row.rs | 2 +- 38 files changed, 737 insertions(+), 419 deletions(-) create mode 100644 src/io/buf.rs create mode 100644 src/io/buf_mut.rs create mode 100644 src/io/byte_str.rs diff --git a/Cargo.toml b/Cargo.toml index 5f9d7549..207cad74 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,9 @@ memchr = "2.2.1" tokio = { version = "=0.2.0-alpha.2", default-features = false, features = [ "tcp" ] } url = "2.1.0" +[dev-dependencies] +matches = "0.1.8" + [profile.release] lto = true codegen-units = 1 diff --git a/src/io/buf.rs b/src/io/buf.rs new file mode 100644 index 00000000..566bacfb --- /dev/null +++ b/src/io/buf.rs @@ -0,0 +1,131 @@ +use byteorder::ByteOrder; +use memchr::memchr; +use std::{convert::TryInto, io, mem::size_of, str}; + +pub trait Buf { + fn advance(&mut self, cnt: usize); + + fn get_u8(&mut self) -> io::Result; + + fn get_u16(&mut self) -> io::Result; + + fn get_u24(&mut self) -> io::Result; + + fn get_i32(&mut self) -> io::Result; + + fn get_u32(&mut self) -> io::Result; + + fn get_u64(&mut self) -> io::Result; + + // TODO?: Move to mariadb::io::BufExt + fn get_uint(&mut self, n: usize) -> io::Result; + + // TODO?: Move to mariadb::io::BufExt + fn get_uint_lenenc(&mut self) -> io::Result; + + fn get_str(&mut self, len: usize) -> io::Result<&str>; + + // TODO?: Move to mariadb::io::BufExt + fn get_str_eof(&mut self) -> io::Result<&str>; + + fn get_str_nul(&mut self) -> io::Result<&str>; + + // TODO?: Move to mariadb::io::BufExt + fn get_str_lenenc(&mut self) -> io::Result<&str>; +} + +impl<'a> Buf for &'a [u8] { + fn advance(&mut self, cnt: usize) { + *self = &self[cnt..]; + } + + fn get_u8(&mut self) -> io::Result { + let val = self[0]; + + self.advance(1); + + Ok(val) + } + + fn get_u16(&mut self) -> io::Result { + let val = T::read_u16(*self); + self.advance(2); + + Ok(val) + } + + fn get_i32(&mut self) -> io::Result { + let val = T::read_i32(*self); + self.advance(4); + + Ok(val) + } + + fn get_u24(&mut self) -> io::Result { + let val = T::read_u24(*self); + self.advance(3); + + Ok(val) + } + + fn get_u32(&mut self) -> io::Result { + let val = T::read_u32(*self); + self.advance(4); + + Ok(val) + } + + fn get_u64(&mut self) -> io::Result { + let val = T::read_u64(*self); + self.advance(8); + + Ok(val) + } + + fn get_uint(&mut self, n: usize) -> io::Result { + let val = T::read_uint(*self, n); + self.advance(n); + + Ok(val) + } + + fn get_uint_lenenc(&mut self) -> io::Result { + Ok(match self.get_u8()? { + 0xFC => self.get_u16::()? as u64, + 0xFD => self.get_u24::()? as u64, + 0xFE => self.get_u64::()? as u64, + // ? 0xFF => panic!("int unprocessable first byte 0xFF"), + value => value as u64, + }) + } + + fn get_str(&mut self, len: usize) -> io::Result<&str> { + let buf = &self[..len]; + + self.advance(len); + + if cfg!(debug_asserts) { + str::from_utf8(buf).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) + } else { + Ok(unsafe { str::from_utf8_unchecked(buf) }) + } + } + + fn get_str_eof(&mut self) -> io::Result<&str> { + self.get_str(self.len()) + } + + fn get_str_nul(&mut self) -> io::Result<&str> { + let len = memchr(b'\0', &*self).ok_or(io::ErrorKind::InvalidData)?; + let s = &self.get_str(len + 1)?[..len]; + + Ok(s) + } + + fn get_str_lenenc(&mut self) -> io::Result<&str> { + let len = self.get_uint_lenenc::()?; + let s = self.get_str(len as usize)?; + + Ok(s) + } +} diff --git a/src/io/buf_mut.rs b/src/io/buf_mut.rs new file mode 100644 index 00000000..0ed2fb01 --- /dev/null +++ b/src/io/buf_mut.rs @@ -0,0 +1,122 @@ +use byteorder::ByteOrder; +use memchr::memchr; +use std::{io, mem::size_of, str, u16, u32, u8}; + +pub trait BufMut { + fn advance(&mut self, cnt: usize); + + fn put_u8(&mut self, val: u8); + + fn put_u16(&mut self, val: u16); + + fn put_i16(&mut self, val: i16); + + fn put_u24(&mut self, val: u32); + + fn put_i32(&mut self, val: i32); + + fn put_u32(&mut self, val: u32); + + fn put_u64(&mut self, val: u64); + + // TODO: Move to mariadb::io::BufMutExt + fn put_u64_lenenc(&mut self, val: u64); + + fn put_str_nul(&mut self, val: &str); + + // TODO: Move to mariadb::io::BufMutExt + fn put_str_lenenc(&mut self, val: &str); + + // TODO: Move to mariadb::io::BufMutExt + fn put_str_eof(&mut self, val: &str); +} + +impl BufMut for Vec { + fn advance(&mut self, cnt: usize) { + self.resize(self.len() + cnt, 0); + } + + fn put_u8(&mut self, val: u8) { + self.push(val); + } + + fn put_i16(&mut self, val: i16) { + let mut buf = [0; 4]; + T::write_i16(&mut buf, val); + self.extend_from_slice(&buf); + } + + fn put_u16(&mut self, val: u16) { + let mut buf = [0; 2]; + T::write_u16(&mut buf, val); + self.extend_from_slice(&buf); + } + + fn put_u24(&mut self, val: u32) { + let mut buf = [0; 3]; + T::write_u24(&mut buf, val); + self.extend_from_slice(&buf); + } + + fn put_i32(&mut self, val: i32) { + let mut buf = [0; 4]; + T::write_i32(&mut buf, val); + self.extend_from_slice(&buf); + } + + fn put_u32(&mut self, val: u32) { + let mut buf = [0; 4]; + T::write_u32(&mut buf, val); + self.extend_from_slice(&buf); + } + + fn put_u64(&mut self, val: u64) { + let mut buf = [0; 8]; + T::write_u64(&mut buf, val); + self.extend_from_slice(&buf); + } + + fn put_u64_lenenc(&mut self, value: u64) { + // https://mariadb.com/kb/en/library/protocol-data-types/#length-encoded-integers + if value > 0xFF_FF_FF { + // Integer value is encoded in the next 8 bytes (9 bytes total) + self.push(0xFE); + self.put_u64::(value); + } else if value > u16::MAX as _ { + // Integer value is encoded in the next 3 bytes (4 bytes total) + self.push(0xFD); + self.put_u24::(value as u32); + } else if value > u8::MAX as _ { + // Integer value is encoded in the next 2 bytes (3 bytes total) + self.push(0xFC); + self.put_u16::(value as u16); + } else { + match value { + // If the value is of size u8 and one of the key bytes used in length encoding + // we must put that single byte as a u16 + 0xFB | 0xFC | 0xFD | 0xFE | 0xFF => { + self.push(0xFC); + self.put_u16::(value as u16); + } + + _ => { + self.push(value as u8); + } + } + } + } + + fn put_str_eof(&mut self, val: &str) { + self.extend_from_slice(val.as_bytes()); + } + + fn put_str_nul(&mut self, val: &str) { + self.extend_from_slice(val.as_bytes()); + self.push(0); + } + + fn put_str_lenenc(&mut self, val: &str) { + self.put_u64_lenenc::(val.len() as u64); + self.extend_from_slice(val.as_bytes()); + } +} diff --git a/src/io/byte_str.rs b/src/io/byte_str.rs new file mode 100644 index 00000000..360b4e12 --- /dev/null +++ b/src/io/byte_str.rs @@ -0,0 +1,26 @@ +use std::{ + ascii::escape_default, + fmt::{self, Debug}, + str::from_utf8, +}; + +// Wrapper type for byte slices that will debug print +// as a binary string +pub struct ByteStr<'a>(pub &'a [u8]); + +impl Debug for ByteStr<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "b\"")?; + + for &b in self.0 { + let part: Vec = escape_default(b).collect(); + let s = from_utf8(&part).unwrap(); + + write!(f, "{}", s)?; + } + + write!(f, "\"")?; + + Ok(()) + } +} diff --git a/src/io/mod.rs b/src/io/mod.rs index 466e39b4..a2098395 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -1,4 +1,8 @@ #[macro_use] mod buf_stream; -pub use self::buf_stream::BufStream; +mod buf; +mod buf_mut; +mod byte_str; + +pub use self::{buf::Buf, buf_mut::BufMut, buf_stream::BufStream, byte_str::ByteStr}; diff --git a/src/postgres/connection/execute.rs b/src/postgres/connection/execute.rs index 6fdd923a..2e12e239 100644 --- a/src/postgres/connection/execute.rs +++ b/src/postgres/connection/execute.rs @@ -12,7 +12,7 @@ pub async fn execute(conn: &mut PostgresRawConnection) -> Result { Message::BindComplete | Message::ParseComplete | Message::DataRow(_) => {} Message::CommandComplete(body) => { - rows = body.rows; + rows = body.affected_rows(); } Message::ReadyForQuery(_) => { diff --git a/src/postgres/connection/mod.rs b/src/postgres/connection/mod.rs index 7aa2c7a1..0b414e80 100644 --- a/src/postgres/connection/mod.rs +++ b/src/postgres/connection/mod.rs @@ -4,7 +4,8 @@ use super::{ }; use crate::{connection::RawConnection, error::Error, io::BufStream, query::QueryParameters}; // use bytes::{BufMut, BytesMut}; -use super::protocol::Buf; +use crate::io::Buf; +use byteorder::NetworkEndian; use futures_core::{future::BoxFuture, stream::BoxStream}; use std::{ io, @@ -69,16 +70,19 @@ impl PostgresRawConnection { loop { // Read the message header (id + len) let mut header = ret_if_none!(self.stream.peek(5).await?); + log::trace!("recv:header {:?}", bytes::Bytes::from(&*header)); + let id = header.get_u8()?; - let len = (header.get_u32()? - 4) as usize; + let len = (header.get_u32::()? - 4) as usize; // Read the message body self.stream.consume(5); let body = ret_if_none!(self.stream.peek(len).await?); + log::trace!("recv {:?}", bytes::Bytes::from(&*body)); let message = match id { b'N' | b'E' => Message::Response(Box::new(protocol::Response::decode(body)?)), - b'D' => Message::DataRow(Box::new(protocol::DataRow::decode(body)?)), + b'D' => Message::DataRow(protocol::DataRow::decode(body)?), b'S' => { Message::ParameterStatus(Box::new(protocol::ParameterStatus::decode(body)?)) } @@ -121,7 +125,14 @@ impl PostgresRawConnection { } pub(super) fn write(&mut self, message: impl Encode) { + let pos = self.stream.buffer_mut().len(); + message.encode(self.stream.buffer_mut()); + + log::trace!( + "send {:?}", + bytes::Bytes::from(&self.stream.buffer_mut()[pos..]) + ); } } diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs index 2716af0f..ea5bfcd9 100644 --- a/src/postgres/mod.rs +++ b/src/postgres/mod.rs @@ -1,7 +1,6 @@ mod backend; mod connection; -// FIXME: Should only be public for benchmarks -pub mod protocol; +mod protocol; mod query; mod row; pub mod types; diff --git a/src/postgres/protocol/authentication.rs b/src/postgres/protocol/authentication.rs index 17642c6a..600cac8d 100644 --- a/src/postgres/protocol/authentication.rs +++ b/src/postgres/protocol/authentication.rs @@ -1,4 +1,6 @@ use super::Decode; +use crate::io::Buf; +use byteorder::NetworkEndian; use std::io; #[derive(Debug)] @@ -28,8 +30,10 @@ pub enum Authentication { GssContinue { data: Box<[u8]> }, /// SASL authentication is required. - // FIXME: authentication mechanisms - Sasl, + /// + /// The message body is a list of SASL authentication mechanisms, + /// in the server's order of preference. + Sasl { mechanisms: Box<[Box]> }, /// This message contains a SASL challenge. SaslContinue { data: Box<[u8]> }, @@ -39,24 +43,100 @@ pub enum Authentication { } impl Decode for Authentication { - fn decode(src: &[u8]) -> io::Result { - Ok(match src[0] { + fn decode(mut buf: &[u8]) -> io::Result { + Ok(match buf.get_u32::()? { 0 => Authentication::Ok, + 2 => Authentication::KerberosV5, + 3 => Authentication::CleartextPassword, 5 => { let mut salt = [0_u8; 4]; - salt.copy_from_slice(&src[1..5]); + salt.copy_from_slice(&buf); Authentication::Md5Password { salt } } 6 => Authentication::ScmCredential, + 7 => Authentication::Gss, + + 8 => { + let mut data = Vec::with_capacity(buf.len()); + data.extend_from_slice(buf); + + Authentication::GssContinue { + data: data.into_boxed_slice(), + } + } + 9 => Authentication::Sspi, - token => unimplemented!("decode not implemented for token: {}", token), + 10 => { + let mut mechanisms = Vec::new(); + + while buf[0] != 0 { + mechanisms.push(buf.get_str_nul()?.into()); + } + + Authentication::Sasl { + mechanisms: mechanisms.into_boxed_slice(), + } + } + + 11 => { + let mut data = Vec::with_capacity(buf.len()); + data.extend_from_slice(buf); + + Authentication::SaslContinue { + data: data.into_boxed_slice(), + } + } + + 12 => { + let mut data = Vec::with_capacity(buf.len()); + data.extend_from_slice(buf); + + Authentication::SaslFinal { + data: data.into_boxed_slice(), + } + } + + id => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("unknown authentication response: {}", id), + )); + } }) } } + +#[cfg(test)] +mod tests { + use super::{Authentication, Decode}; + use matches::assert_matches; + + const AUTH_OK: &[u8] = b"\0\0\0\0"; + const AUTH_MD5: &[u8] = b"\0\0\0\x05\x93\x189\x98"; + + #[test] + fn it_decodes_auth_ok() { + let m = Authentication::decode(AUTH_OK).unwrap(); + + assert_matches!(m, Authentication::Ok); + } + + #[test] + fn it_decodes_auth_md5_password() { + let m = Authentication::decode(AUTH_MD5).unwrap(); + + assert_matches!( + m, + Authentication::Md5Password { + salt: [147, 24, 57, 152] + } + ); + } +} diff --git a/src/postgres/protocol/backend_key_data.rs b/src/postgres/protocol/backend_key_data.rs index 7b6aba1b..fbac8487 100644 --- a/src/postgres/protocol/backend_key_data.rs +++ b/src/postgres/protocol/backend_key_data.rs @@ -1,4 +1,6 @@ -use super::{Buf, Decode}; +use super::Decode; +use crate::io::Buf; +use byteorder::NetworkEndian; use std::io; #[derive(Debug)] @@ -23,11 +25,9 @@ impl BackendKeyData { } impl Decode for BackendKeyData { - fn decode(mut src: &[u8]) -> io::Result { - debug_assert_eq!(src.len(), 8); - - let process_id = src.get_u32()?; - let secret_key = src.get_u32()?; + fn decode(mut buf: &[u8]) -> io::Result { + let process_id = buf.get_u32::()?; + let secret_key = buf.get_u32::()?; Ok(Self { process_id, @@ -39,7 +39,6 @@ impl Decode for BackendKeyData { #[cfg(test)] mod tests { use super::{BackendKeyData, Decode}; - use bytes::Bytes; const BACKEND_KEY_DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; diff --git a/src/postgres/protocol/bind.rs b/src/postgres/protocol/bind.rs index 80295a03..a9896e37 100644 --- a/src/postgres/protocol/bind.rs +++ b/src/postgres/protocol/bind.rs @@ -1,5 +1,6 @@ -use super::{BufMut, Encode}; -use byteorder::{BigEndian, ByteOrder}; +use super::{Encode}; +use crate::io::BufMut; +use byteorder::{BigEndian, ByteOrder, NetworkEndian}; pub struct Bind<'a> { /// The name of the destination portal (an empty string selects the unnamed portal). @@ -29,24 +30,32 @@ pub struct Bind<'a> { impl Encode for Bind<'_> { fn encode(&self, buf: &mut Vec) { - buf.put_byte(b'B'); + buf.push(b'B'); let pos = buf.len(); - buf.put_int_32(0); // skip over len + buf.put_i32::(0); // skip over len - buf.put_str(self.portal); - buf.put_str(self.statement); + buf.put_str_nul(self.portal); + buf.put_str_nul(self.statement); - buf.put_array_int_16(&self.formats); + buf.put_i16::(self.formats.len() as i16); - buf.put_int_16(self.values_len); + for &format in self.formats { + buf.put_i16::(format); + } - buf.put(self.values); + buf.put_i16::(self.values_len); - buf.put_array_int_16(&self.result_formats); + buf.extend_from_slice(self.values); + + buf.put_i16::(self.result_formats.len() as i16); + + for &format in self.result_formats { + buf.put_i16::(format); + } // Write-back the len to the beginning of this frame let len = buf.len() - pos; - BigEndian::write_i32(&mut buf[pos..], len as i32); + NetworkEndian::write_i32(&mut buf[pos..], len as i32); } } diff --git a/src/postgres/protocol/cancel_request.rs b/src/postgres/protocol/cancel_request.rs index 9f8624ee..90169675 100644 --- a/src/postgres/protocol/cancel_request.rs +++ b/src/postgres/protocol/cancel_request.rs @@ -1,4 +1,6 @@ -use super::{BufMut, Encode}; +use super::{Encode}; +use crate::io::BufMut; +use byteorder::NetworkEndian; /// Sent instead of [`StartupMessage`] with a new connection to cancel a running query on an existing /// connection. @@ -14,9 +16,9 @@ pub struct CancelRequest { impl Encode for CancelRequest { fn encode(&self, buf: &mut Vec) { - buf.put_int_32(16); // message length - buf.put_int_32(8087_7102); // constant for cancel request - buf.put_int_32(self.process_id); - buf.put_int_32(self.secret_key); + buf.put_i32::(16); // message length + buf.put_i32::(8087_7102); // constant for cancel request + buf.put_i32::(self.process_id); + buf.put_i32::(self.secret_key); } } diff --git a/src/postgres/protocol/close.rs b/src/postgres/protocol/close.rs index 5d5716c6..837fe88f 100644 --- a/src/postgres/protocol/close.rs +++ b/src/postgres/protocol/close.rs @@ -1,4 +1,6 @@ -use super::{BufMut, Encode}; +use super::{Encode}; +use crate::io::BufMut; +use byteorder::NetworkEndian; #[repr(u8)] pub enum CloseKind { @@ -16,14 +18,17 @@ pub struct Close<'a> { impl Encode for Close<'_> { fn encode(&self, buf: &mut Vec) { - buf.put_byte(b'C'); + buf.push(b'C'); + // len + kind + nul + len(string) - buf.put_int_32((4 + 1 + 1 + self.name.len()) as i32); - buf.put_byte(match self.kind { + buf.put_i32::((4 + 1 + 1 + self.name.len()) as i32); + + buf.push(match self.kind { CloseKind::PreparedStatement => b'S', CloseKind::Portal => b'P', }); - buf.put_str(self.name); + + buf.put_str_nul(self.name); } } diff --git a/src/postgres/protocol/command_complete.rs b/src/postgres/protocol/command_complete.rs index 104e0a80..3fd5b6ab 100644 --- a/src/postgres/protocol/command_complete.rs +++ b/src/postgres/protocol/command_complete.rs @@ -1,27 +1,37 @@ use super::Decode; -use memchr::memrchr; -use std::{io, str}; +use crate::io::Buf; +use std::io; #[derive(Debug)] pub struct CommandComplete { - pub rows: u64, + affected_rows: u64, +} + +impl CommandComplete { + #[inline] + pub fn affected_rows(&self) -> u64 { + self.affected_rows + } } impl Decode for CommandComplete { - fn decode(src: &[u8]) -> io::Result { + fn decode(mut buf: &[u8]) -> io::Result { + // TODO: MariaDb/MySQL return 0 for affected rows in a SELECT .. statement. + // PostgreSQL returns a row count. Should we force return 0 for compatibilities sake? + // Attempt to parse the last word in the command tag as an integer // If it can't be parased, the tag is probably "CREATE TABLE" or something // and we should return 0 rows - // TODO: Use [atoi] or similar to parse an integer directly from the bytes - - let rows_start = memrchr(b' ', src).unwrap_or(0); - let mut buf = &src[(rows_start + 1)..(src.len() - 1)]; - - let rows = unsafe { str::from_utf8_unchecked(buf) }; + let rows = buf + .get_str_nul()? + .rsplit(' ') + .next() + .and_then(|s| s.parse().ok()) + .unwrap_or(0); Ok(Self { - rows: rows.parse().unwrap_or(0), + affected_rows: rows, }) } } @@ -39,27 +49,27 @@ mod tests { fn it_decodes_command_complete_for_insert() { let message = CommandComplete::decode(COMMAND_COMPLETE_INSERT).unwrap(); - assert_eq!(message.rows, 1); + assert_eq!(message.affected_rows(), 1); } #[test] fn it_decodes_command_complete_for_update() { let message = CommandComplete::decode(COMMAND_COMPLETE_UPDATE).unwrap(); - assert_eq!(message.rows, 512); + assert_eq!(message.affected_rows(), 512); } #[test] fn it_decodes_command_complete_for_begin() { let message = CommandComplete::decode(COMMAND_COMPLETE_BEGIN).unwrap(); - assert_eq!(message.rows, 0); + assert_eq!(message.affected_rows(), 0); } #[test] fn it_decodes_command_complete_for_create_table() { let message = CommandComplete::decode(COMMAND_COMPLETE_CREATE_TABLE).unwrap(); - assert_eq!(message.rows, 0); + assert_eq!(message.affected_rows(), 0); } } diff --git a/src/postgres/protocol/copy_data.rs b/src/postgres/protocol/copy_data.rs index 9f0bea4d..a86cacd1 100644 --- a/src/postgres/protocol/copy_data.rs +++ b/src/postgres/protocol/copy_data.rs @@ -1,4 +1,6 @@ -use super::{BufMut, Encode}; +use super::{Encode}; +use crate::io::BufMut; +use byteorder::NetworkEndian; // TODO: Implement Decode and think on an optimal representation @@ -19,9 +21,9 @@ pub struct CopyData<'a> { impl Encode for CopyData<'_> { fn encode(&self, buf: &mut Vec) { - buf.put_byte(b'd'); + buf.push(b'd'); // len + nul + len(string) - buf.put_int_32((4 + 1 + self.data.len()) as i32); - buf.put(&self.data); + buf.put_i32::((4 + 1 + self.data.len()) as i32); + buf.extend_from_slice(&self.data); } } diff --git a/src/postgres/protocol/copy_done.rs b/src/postgres/protocol/copy_done.rs index 9b89c82d..92a45143 100644 --- a/src/postgres/protocol/copy_done.rs +++ b/src/postgres/protocol/copy_done.rs @@ -1,4 +1,6 @@ -use super::{BufMut, Encode}; +use super::{Encode}; +use crate::io::BufMut; +use byteorder::NetworkEndian; // TODO: Implement Decode @@ -7,7 +9,7 @@ pub struct CopyDone; impl Encode for CopyDone { #[inline] fn encode(&self, buf: &mut Vec) { - buf.put_byte(b'c'); - buf.put_int_32(4); + buf.push(b'c'); + buf.put_i32::(4); } } diff --git a/src/postgres/protocol/copy_fail.rs b/src/postgres/protocol/copy_fail.rs index 87b2b995..e606a92b 100644 --- a/src/postgres/protocol/copy_fail.rs +++ b/src/postgres/protocol/copy_fail.rs @@ -1,4 +1,6 @@ -use super::{BufMut, Encode}; +use super::{Encode}; +use crate::io::BufMut; +use byteorder::NetworkEndian; pub struct CopyFail<'a> { pub error: &'a str, @@ -6,9 +8,9 @@ pub struct CopyFail<'a> { impl Encode for CopyFail<'_> { fn encode(&self, buf: &mut Vec) { - buf.put_byte(b'f'); + buf.push(b'f'); // len + nul + len(string) - buf.put_int_32((4 + 1 + self.error.len()) as i32); - buf.put_str(&self.error); + buf.put_i32::((4 + 1 + self.error.len()) as i32); + buf.put_str_nul(&self.error); } } diff --git a/src/postgres/protocol/data_row.rs b/src/postgres/protocol/data_row.rs index 8524a67a..abd96eb7 100644 --- a/src/postgres/protocol/data_row.rs +++ b/src/postgres/protocol/data_row.rs @@ -1,6 +1,7 @@ -use super::{Buf, Decode}; +use super::Decode; +use crate::io::{Buf, ByteStr}; +use byteorder::NetworkEndian; use std::{ - convert::TryInto, fmt::{self, Debug}, io, pin::Pin, @@ -19,16 +20,16 @@ unsafe impl Sync for DataRow {} impl Decode for DataRow { fn decode(mut buf: &[u8]) -> io::Result { - let len = buf.get_u16()? as usize; + let cnt = buf.get_u16::()? as usize; let buffer: Pin> = Pin::new(buf.into()); let mut buf = &*buffer; - let mut values = Vec::with_capacity(len); + let mut values = Vec::with_capacity(cnt); - while values.len() < len { + while values.len() < cnt { // The length of the column value, in bytes (this count does not include itself). // Can be zero. As a special case, -1 indicates a NULL column value. // No value bytes follow in the NULL case. - let value_len = buf.get_i32()?; + let value_len = buf.get_i32::()?; if value_len == -1 { values.push(None); @@ -65,8 +66,16 @@ impl DataRow { } impl Debug for DataRow { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - unimplemented!(); + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "DataRow(")?; + + f.debug_list() + .entries((0..self.len()).map(|i| self.get(i).map(ByteStr))) + .finish()?; + + write!(f, ")")?; + + Ok(()) } } @@ -79,17 +88,17 @@ mod tests { #[test] fn it_decodes_data_row() { - let message = DataRow::decode(DATA_ROW).unwrap(); + let m = DataRow::decode(DATA_ROW).unwrap(); - assert_eq!(message.len(), 3); + assert_eq!(m.len(), 3); - assert_eq!(message.get(0), Some(&b"1"[..])); - assert_eq!(message.get(1), Some(&b"2"[..])); - assert_eq!(message.get(2), Some(&b"3"[..])); - } + assert_eq!(m.get(0), Some(&b"1"[..])); + assert_eq!(m.get(1), Some(&b"2"[..])); + assert_eq!(m.get(2), Some(&b"3"[..])); - #[bench] - fn bench_decode_data_row(b: &mut test::Bencher) { - b.iter(|| DataRow::decode(DATA_ROW).unwrap()); + assert_eq!( + format!("{:?}", m), + "DataRow([Some(b\"1\"), Some(b\"2\"), Some(b\"3\")])" + ); } } diff --git a/src/postgres/protocol/decode.rs b/src/postgres/protocol/decode.rs index f4ac7bad..232e16fe 100644 --- a/src/postgres/protocol/decode.rs +++ b/src/postgres/protocol/decode.rs @@ -1,86 +1,7 @@ -use memchr::memchr; -use std::{convert::TryInto, io, str}; +use std::io; pub trait Decode { fn decode(src: &[u8]) -> io::Result where Self: Sized; } - -#[inline] -pub(crate) fn get_str(src: &[u8]) -> &str { - let end = memchr(b'\0', &src).expect("expected null terminator in UTF-8 string"); - let buf = &src[..end]; - - unsafe { str::from_utf8_unchecked(buf) } -} - -pub trait Buf { - fn advance(&mut self, cnt: usize); - - // An n-bit integer in network byte order (IntN) - fn get_u8(&mut self) -> io::Result; - fn get_u16(&mut self) -> io::Result; - fn get_i32(&mut self) -> io::Result; - fn get_u32(&mut self) -> io::Result; - - // A null-terminated string - fn get_str_null(&mut self) -> io::Result<&str>; -} - -impl<'a> Buf for &'a [u8] { - fn advance(&mut self, cnt: usize) { - *self = &self[cnt..]; - } - - fn get_u8(&mut self) -> io::Result { - let val = self[0]; - - self.advance(1); - - Ok(val) - } - - fn get_u16(&mut self) -> io::Result { - let val: [u8; 2] = (&self[..2]) - .try_into() - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; - - self.advance(2); - - Ok(u16::from_be_bytes(val)) - } - - fn get_i32(&mut self) -> io::Result { - let val: [u8; 4] = (&self[..4]) - .try_into() - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; - - self.advance(4); - - Ok(i32::from_be_bytes(val)) - } - - fn get_u32(&mut self) -> io::Result { - let val: [u8; 4] = (&self[..4]) - .try_into() - .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; - - self.advance(4); - - Ok(u32::from_be_bytes(val)) - } - - fn get_str_null(&mut self) -> io::Result<&str> { - let end = memchr(b'\0', &*self).ok_or(io::ErrorKind::InvalidData)?; - let buf = &self[..end]; - - self.advance(end + 1); - - if cfg!(debug_asserts) { - str::from_utf8(buf).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) - } else { - Ok(unsafe { str::from_utf8_unchecked(buf) }) - } - } -} diff --git a/src/postgres/protocol/describe.rs b/src/postgres/protocol/describe.rs index aae39c16..c5c75601 100644 --- a/src/postgres/protocol/describe.rs +++ b/src/postgres/protocol/describe.rs @@ -1,4 +1,6 @@ -use super::{BufMut, Encode}; +use super::{Encode}; +use crate::io::BufMut; +use byteorder::NetworkEndian; #[repr(u8)] pub enum DescribeKind { @@ -16,14 +18,14 @@ pub struct Describe<'a> { impl Encode for Describe<'_> { fn encode(&self, buf: &mut Vec) { - buf.put_byte(b'D'); + buf.push(b'D'); // len + kind + nul + len(string) - buf.put_int_32((4 + 1 + 1 + self.name.len()) as i32); - buf.put_byte(match self.kind { + buf.put_i32::((4 + 1 + 1 + self.name.len()) as i32); + buf.push(match self.kind { DescribeKind::PreparedStatement => b'S', DescribeKind::Portal => b'P', }); - buf.put_str(self.name); + buf.put_str_nul(self.name); } } diff --git a/src/postgres/protocol/encode.rs b/src/postgres/protocol/encode.rs index 6687936c..05dcdeed 100644 --- a/src/postgres/protocol/encode.rs +++ b/src/postgres/protocol/encode.rs @@ -1,93 +1,3 @@ pub trait Encode { fn encode(&self, buf: &mut Vec); } - -pub trait BufMut { - fn put(&mut self, bytes: &[u8]); - - fn put_byte(&mut self, value: u8); - - fn put_int_16(&mut self, value: i16); - - fn put_uint_16(&mut self, value: u16); - - fn put_int_32(&mut self, value: i32); - - fn put_uint_32(&mut self, value: u32); - - fn put_array_int_16(&mut self, values: &[i16]); - - fn put_array_int_32(&mut self, values: &[i32]); - - fn put_array_uint_32(&mut self, values: &[u32]); - - fn put_str(&mut self, value: &str); -} - -impl BufMut for Vec { - #[inline] - fn put(&mut self, bytes: &[u8]) { - self.extend_from_slice(bytes); - } - - #[inline] - fn put_byte(&mut self, value: u8) { - self.push(value); - } - - #[inline] - fn put_int_16(&mut self, value: i16) { - self.extend_from_slice(&value.to_be_bytes()); - } - - #[inline] - fn put_uint_16(&mut self, value: u16) { - self.extend_from_slice(&value.to_be_bytes()); - } - - #[inline] - fn put_int_32(&mut self, value: i32) { - self.extend_from_slice(&value.to_be_bytes()); - } - - #[inline] - fn put_uint_32(&mut self, value: u32) { - self.extend_from_slice(&value.to_be_bytes()); - } - - #[inline] - fn put_str(&mut self, value: &str) { - self.extend_from_slice(value.as_bytes()); - self.push(0); - } - - #[inline] - fn put_array_int_16(&mut self, values: &[i16]) { - // FIXME: What happens here when len(values) > i16 - self.put_int_16(values.len() as i16); - - for value in values { - self.put_int_16(*value); - } - } - - #[inline] - fn put_array_int_32(&mut self, values: &[i32]) { - // FIXME: What happens here when len(values) > i16 - self.put_int_16(values.len() as i16); - - for value in values { - self.put_int_32(*value); - } - } - - #[inline] - fn put_array_uint_32(&mut self, values: &[u32]) { - // FIXME: What happens here when len(values) > i16 - self.put_int_16(values.len() as i16); - - for value in values { - self.put_uint_32(*value); - } - } -} diff --git a/src/postgres/protocol/execute.rs b/src/postgres/protocol/execute.rs index 94004489..ab961a44 100644 --- a/src/postgres/protocol/execute.rs +++ b/src/postgres/protocol/execute.rs @@ -1,4 +1,6 @@ -use super::{BufMut, Encode}; +use super::{Encode}; +use crate::io::BufMut; +use byteorder::NetworkEndian; pub struct Execute<'a> { /// The name of the portal to execute (an empty string selects the unnamed portal). @@ -11,10 +13,10 @@ pub struct Execute<'a> { impl Encode for Execute<'_> { fn encode(&self, buf: &mut Vec) { - buf.put_byte(b'E'); + buf.push(b'E'); // len + nul + len(string) + limit - buf.put_int_32((4 + 1 + self.portal.len() + 4) as i32); - buf.put_str(&self.portal); - buf.put_int_32(self.limit); + buf.put_i32::((4 + 1 + self.portal.len() + 4) as i32); + buf.put_str_nul(&self.portal); + buf.put_i32::(self.limit); } } diff --git a/src/postgres/protocol/flush.rs b/src/postgres/protocol/flush.rs index 88e8f594..6e08dfb4 100644 --- a/src/postgres/protocol/flush.rs +++ b/src/postgres/protocol/flush.rs @@ -1,11 +1,13 @@ -use super::{BufMut, Encode}; +use super::{Encode}; +use crate::io::BufMut; +use byteorder::NetworkEndian; pub struct Flush; impl Encode for Flush { #[inline] fn encode(&self, buf: &mut Vec) { - buf.put_byte(b'H'); - buf.put_int_32(4); + buf.push(b'H'); + buf.put_i32::(4); } } diff --git a/src/postgres/protocol/message.rs b/src/postgres/protocol/message.rs index 5986df34..c1df45da 100644 --- a/src/postgres/protocol/message.rs +++ b/src/postgres/protocol/message.rs @@ -14,7 +14,7 @@ pub enum Message { BackendKeyData(BackendKeyData), ReadyForQuery(ReadyForQuery), CommandComplete(CommandComplete), - DataRow(Box), + DataRow(DataRow), Response(Box), NotificationResponse(Box), ParseComplete, diff --git a/src/postgres/protocol/mod.rs b/src/postgres/protocol/mod.rs index c41ce47b..5390a882 100644 --- a/src/postgres/protocol/mod.rs +++ b/src/postgres/protocol/mod.rs @@ -32,7 +32,7 @@ pub use self::{ copy_done::CopyDone, copy_fail::CopyFail, describe::Describe, - encode::{BufMut, Encode}, + encode::Encode, execute::Execute, flush::Flush, parse::Parse, @@ -43,30 +43,24 @@ pub use self::{ terminate::Terminate, }; -// TODO: Audit backend protocol - mod authentication; mod backend_key_data; mod command_complete; mod data_row; mod decode; -mod message; mod notification_response; mod parameter_description; mod parameter_status; mod ready_for_query; mod response; +// TODO: Audit backend protocol + +mod message; + pub use self::{ - authentication::Authentication, - backend_key_data::BackendKeyData, - command_complete::CommandComplete, - data_row::DataRow, - decode::{Buf, Decode}, - message::Message, - notification_response::NotificationResponse, - parameter_description::ParameterDescription, - parameter_status::ParameterStatus, - ready_for_query::ReadyForQuery, - response::Response, + authentication::Authentication, backend_key_data::BackendKeyData, + command_complete::CommandComplete, data_row::DataRow, decode::Decode, message::Message, + notification_response::NotificationResponse, parameter_description::ParameterDescription, + parameter_status::ParameterStatus, ready_for_query::ReadyForQuery, response::Response, }; diff --git a/src/postgres/protocol/notification_response.rs b/src/postgres/protocol/notification_response.rs index 58e873f4..cef5d34a 100644 --- a/src/postgres/protocol/notification_response.rs +++ b/src/postgres/protocol/notification_response.rs @@ -1,10 +1,11 @@ -use super::{Buf, Decode}; -use byteorder::{BigEndian, ByteOrder}; +use super::Decode; +use crate::io::Buf; +use byteorder::NetworkEndian; use std::{fmt, io, pin::Pin, ptr::NonNull}; pub struct NotificationResponse { #[used] - storage: Pin>, + buffer: Pin>, pid: u32, channel_name: NonNull, message: NonNull, @@ -44,18 +45,17 @@ impl fmt::Debug for NotificationResponse { } impl Decode for NotificationResponse { - fn decode(mut src: &[u8]) -> io::Result { - let pid = src.get_u32()?; + fn decode(mut buf: &[u8]) -> io::Result { + let pid = buf.get_u32::()?; - // offset from pid=4 - let storage = Pin::new(src.into()); - let mut src: &[u8] = &*storage; + let buffer = Pin::new(buf.into()); + let mut buf: &[u8] = &*buffer; - let channel_name = src.get_str_null()?.into(); - let message = src.get_str_null()?.into(); + let channel_name = buf.get_str_nul()?.into(); + let message = buf.get_str_nul()?.into(); Ok(Self { - storage, + buffer, pid, channel_name, message, @@ -77,5 +77,11 @@ mod tests { assert_eq!(message.pid(), 0x34201002); assert_eq!(message.channel_name(), "TEST-CHANNEL"); assert_eq!(message.message(), "THIS IS A TEST"); + + assert_eq!( + format!("{:?}", message), + "NotificationResponse { pid: 874516482, channel_name: \"TEST-CHANNEL\", message: \ + \"THIS IS A TEST\" }" + ); } } diff --git a/src/postgres/protocol/parameter_description.rs b/src/postgres/protocol/parameter_description.rs index 91a5ab74..3b3328ec 100644 --- a/src/postgres/protocol/parameter_description.rs +++ b/src/postgres/protocol/parameter_description.rs @@ -1,24 +1,23 @@ -use super::{Buf, Decode}; -use byteorder::{BigEndian, ByteOrder}; -use std::{io, mem::size_of}; - -type ObjectId = u32; +use super::Decode; +use crate::io::Buf; +use byteorder::NetworkEndian; +use std::io; #[derive(Debug)] pub struct ParameterDescription { - ids: Box<[ObjectId]>, + ids: Box<[u32]>, } impl Decode for ParameterDescription { - fn decode(mut src: &[u8]) -> io::Result { - let count = src.get_u16()?; - let mut ids = Vec::with_capacity(count as usize); + fn decode(mut buf: &[u8]) -> io::Result { + let cnt = buf.get_u16::()? as usize; + let mut ids = Vec::with_capacity(cnt); - for i in 0..count { - ids.push(src.get_u32()?); + for i in 0..cnt { + ids.push(buf.get_u32::()?); } - Ok(ParameterDescription { + Ok(Self { ids: ids.into_boxed_slice(), }) } @@ -31,8 +30,8 @@ mod test { #[test] fn it_decodes_parameter_description() { - let src = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"; - let desc = ParameterDescription::decode(src).unwrap(); + let buf = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"; + let desc = ParameterDescription::decode(buf).unwrap(); assert_eq!(desc.ids.len(), 2); assert_eq!(desc.ids[0], 0x0000_0000); @@ -41,8 +40,8 @@ mod test { #[test] fn it_decodes_empty_parameter_description() { - let src = b"\x00\x00"; - let desc = ParameterDescription::decode(src).unwrap(); + let buf = b"\x00\x00"; + let desc = ParameterDescription::decode(buf).unwrap(); assert_eq!(desc.ids.len(), 0); } diff --git a/src/postgres/protocol/parameter_status.rs b/src/postgres/protocol/parameter_status.rs index f6d193be..b80afae1 100644 --- a/src/postgres/protocol/parameter_status.rs +++ b/src/postgres/protocol/parameter_status.rs @@ -1,11 +1,16 @@ -use super::decode::{Buf, Decode}; -use std::{io, pin::Pin, ptr::NonNull, str}; +use super::decode::Decode; +use crate::io::Buf; +use std::{ + fmt::{self, Debug}, + io, + pin::Pin, + ptr::NonNull, + str, +}; -// FIXME: Use &str functions for a custom Debug -#[derive(Debug)] pub struct ParameterStatus { #[used] - storage: Pin>, + buffer: Pin>, name: NonNull, value: NonNull, } @@ -29,21 +34,30 @@ impl ParameterStatus { } impl Decode for ParameterStatus { - fn decode(src: &[u8]) -> io::Result { - let storage = Pin::new(src.into()); - let mut src: &[u8] = &*storage; + fn decode(buf: &[u8]) -> io::Result { + let buffer = Pin::new(buf.into()); + let mut buf: &[u8] = &*buffer; - let name = NonNull::from(src.get_str_null()?); - let value = NonNull::from(src.get_str_null()?); + let name = buf.get_str_nul()?.into(); + let value = buf.get_str_nul()?.into(); Ok(Self { - storage, + buffer, name, value, }) } } +impl fmt::Debug for ParameterStatus { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("ParameterStatus") + .field("name", &self.name()) + .field("value", &self.value()) + .finish() + } +} + #[cfg(test)] mod tests { use super::{Decode, ParameterStatus}; @@ -56,10 +70,10 @@ mod tests { assert_eq!(message.name(), "session_authorization"); assert_eq!(message.value(), "postgres"); - } - #[bench] - fn bench_decode_param_status(b: &mut test::Bencher) { - b.iter(|| ParameterStatus::decode(PARAM_STATUS).unwrap()); + assert_eq!( + format!("{:?}", message), + "ParameterStatus { name: \"session_authorization\", value: \"postgres\" }" + ); } } diff --git a/src/postgres/protocol/parse.rs b/src/postgres/protocol/parse.rs index cd38d59d..6ef06e20 100644 --- a/src/postgres/protocol/parse.rs +++ b/src/postgres/protocol/parse.rs @@ -1,4 +1,6 @@ -use super::{BufMut, Encode}; +use super::{Encode}; +use crate::io::BufMut; +use byteorder::NetworkEndian; pub struct Parse<'a> { pub portal: &'a str, @@ -8,15 +10,19 @@ pub struct Parse<'a> { impl Encode for Parse<'_> { fn encode(&self, buf: &mut Vec) { - buf.put_byte(b'P'); + buf.push(b'P'); // len + portal + nul + query + null + len(param_types) + param_types let len = 4 + self.portal.len() + 1 + self.query.len() + 1 + 2 + self.param_types.len() * 4; - buf.put_int_32(len as i32); + buf.put_i32::(len as i32); - buf.put_str(self.portal); - buf.put_str(self.query); + buf.put_str_nul(self.portal); + buf.put_str_nul(self.query); - buf.put_array_uint_32(&self.param_types); + buf.put_i16::(self.param_types.len() as i16); + + for &type_ in self.param_types { + buf.put_u32::(type_); + } } } diff --git a/src/postgres/protocol/password_message.rs b/src/postgres/protocol/password_message.rs index 8da7e782..254dfc1d 100644 --- a/src/postgres/protocol/password_message.rs +++ b/src/postgres/protocol/password_message.rs @@ -1,4 +1,6 @@ -use super::{BufMut, Encode}; +use super::Encode; +use crate::io::BufMut; +use byteorder::NetworkEndian; use md5::{Digest, Md5}; #[derive(Debug)] @@ -13,13 +15,13 @@ pub enum PasswordMessage<'a> { impl Encode for PasswordMessage<'_> { fn encode(&self, buf: &mut Vec) { - buf.put_byte(b'p'); + buf.push(b'p'); match self { PasswordMessage::Cleartext(s) => { // len + password + nul - buf.put_int_32((4 + s.len() + 1) as i32); - buf.put_str(s); + buf.put_u32::((4 + s.len() + 1) as u32); + buf.put_str_nul(s); } PasswordMessage::Md5 { @@ -40,11 +42,44 @@ impl Encode for PasswordMessage<'_> { let salted = hex::encode(hasher.result()); // len + "md5" + (salted) - buf.put_int_32((4 + 3 + salted.len()) as i32); + buf.put_u32::((4 + 3 + salted.len() + 1) as u32); - buf.put(b"md5"); - buf.put(salted.as_bytes()); + buf.extend_from_slice(b"md5"); + buf.extend_from_slice(salted.as_bytes()); + buf.push(0); } } } } + +#[cfg(test)] +mod tests { + use super::{Encode, PasswordMessage}; + + const PASSWORD_CLEAR: &[u8] = b"p\0\0\0\rpassword\0"; + const PASSWORD_MD5: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0"; + + #[test] + fn it_encodes_password_clear() { + let mut buf = Vec::new(); + let m = PasswordMessage::Cleartext("password"); + + m.encode(&mut buf); + + assert_eq!(buf, PASSWORD_CLEAR); + } + + #[test] + fn it_encodes_password_md5() { + let mut buf = Vec::new(); + let m = PasswordMessage::Md5 { + password: "password", + user: "root", + salt: [147, 24, 57, 152], + }; + + m.encode(&mut buf); + + assert_eq!(buf, PASSWORD_MD5); + } +} diff --git a/src/postgres/protocol/query.rs b/src/postgres/protocol/query.rs index c48a6905..014b53cc 100644 --- a/src/postgres/protocol/query.rs +++ b/src/postgres/protocol/query.rs @@ -1,15 +1,17 @@ -use super::{BufMut, Encode}; +use super::{Encode}; +use crate::io::BufMut; +use byteorder::NetworkEndian; pub struct Query<'a>(pub &'a str); impl Encode for Query<'_> { fn encode(&self, buf: &mut Vec) { - buf.put_byte(b'Q'); + buf.push(b'Q'); // len + query + nul - buf.put_int_32((4 + self.0.len() + 1) as i32); + buf.put_i32::((4 + self.0.len() + 1) as i32); - buf.put_str(self.0); + buf.put_str_nul(self.0); } } diff --git a/src/postgres/protocol/ready_for_query.rs b/src/postgres/protocol/ready_for_query.rs index 1bfa4766..83e37187 100644 --- a/src/postgres/protocol/ready_for_query.rs +++ b/src/postgres/protocol/ready_for_query.rs @@ -28,10 +28,9 @@ impl ReadyForQuery { } impl Decode for ReadyForQuery { - fn decode(src: &[u8]) -> io::Result { + fn decode(buf: &[u8]) -> io::Result { Ok(Self { - status: match src[0] { - // FIXME: Variant value are duplicated with declaration + status: match buf[0] { b'I' => TransactionStatus::Idle, b'T' => TransactionStatus::Transaction, b'E' => TransactionStatus::Error, diff --git a/src/postgres/protocol/response.rs b/src/postgres/protocol/response.rs index affd1ed4..6ea3adc7 100644 --- a/src/postgres/protocol/response.rs +++ b/src/postgres/protocol/response.rs @@ -1,4 +1,5 @@ -use super::{decode::get_str, Decode}; +use super::Decode; +use crate::io::Buf; use std::{ fmt, io, pin::Pin, @@ -73,10 +74,9 @@ impl FromStr for Severity { } } -#[derive(Clone)] pub struct Response { #[used] - storage: Pin>, + buffer: Pin>, severity: Severity, code: NonNull, message: NonNull, @@ -225,44 +225,41 @@ impl fmt::Debug for Response { } impl Decode for Response { - fn decode(src: &[u8]) -> io::Result { - let storage: Pin> = Pin::new(src.into()); + fn decode(buf: &[u8]) -> io::Result { + let buffer: Pin> = Pin::new(buf.into()); + let mut buf: &[u8] = &*buffer; - let mut code = None::<&str>; - let mut message = None::<&str>; - let mut severity = None::<&str>; + let mut code = None::>; + let mut message = None::>; + let mut severity = None::>; let mut severity_non_local = None::; - let mut detail = None::<&str>; - let mut hint = None::<&str>; + let mut detail = None::>; + let mut hint = None::>; let mut position = None::; let mut internal_position = None::; - let mut internal_query = None::<&str>; - let mut where_ = None::<&str>; - let mut schema = None::<&str>; - let mut table = None::<&str>; - let mut column = None::<&str>; - let mut data_type = None::<&str>; - let mut constraint = None::<&str>; - let mut file = None::<&str>; + let mut internal_query = None::>; + let mut where_ = None::>; + let mut schema = None::>; + let mut table = None::>; + let mut column = None::>; + let mut data_type = None::>; + let mut constraint = None::>; + let mut file = None::>; let mut line = None::; - let mut routine = None::<&str>; - - let mut idx = 0; + let mut routine = None::>; loop { - let field_type = storage[idx]; - idx += 1; + let field_type = buf.get_u8()?; if field_type == 0 { break; } - let field_value = get_str(&storage[idx..]); - idx += field_value.len() + 1; + let field_value = buf.get_str_nul()?; match field_type { b'S' => { - severity = Some(field_value); + severity = Some(field_value.into()); } b'V' => { @@ -270,19 +267,19 @@ impl Decode for Response { } b'C' => { - code = Some(field_value); + code = Some(field_value.into()); } b'M' => { - message = Some(field_value); + message = Some(field_value.into()); } b'D' => { - detail = Some(field_value); + detail = Some(field_value.into()); } b'H' => { - hint = Some(field_value); + hint = Some(field_value.into()); } b'P' => { @@ -302,35 +299,35 @@ impl Decode for Response { } b'q' => { - internal_query = Some(field_value); + internal_query = Some(field_value.into()); } b'w' => { - where_ = Some(field_value); + where_ = Some(field_value.into()); } b's' => { - schema = Some(field_value); + schema = Some(field_value.into()); } b't' => { - table = Some(field_value); + table = Some(field_value.into()); } b'c' => { - column = Some(field_value); + column = Some(field_value.into()); } b'd' => { - data_type = Some(field_value); + data_type = Some(field_value.into()); } b'n' => { - constraint = Some(field_value); + constraint = Some(field_value.into()); } b'F' => { - file = Some(field_value); + file = Some(field_value.into()); } b'L' => { @@ -342,38 +339,43 @@ impl Decode for Response { } b'R' => { - routine = Some(field_value); + routine = Some(field_value.into()); } _ => { - unimplemented!( - "response message field {:?} not implemented", - field_type as char - ); + // TODO: Should we return these somehow, like in a map? + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("received unknown field in Response: {}", field_type), + )); } } } let severity = severity_non_local - .or_else(move || severity?.parse().ok()) - .expect("`severity` required by protocol"); + .or_else(move || unsafe { severity?.as_ref() }.parse().ok()) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "did not receieve field `severity` for Response", + ) + })?; - let code = NonNull::from(code.expect("`code` required by protocol")); - let message = NonNull::from(message.expect("`message` required by protocol")); - let detail = detail.map(NonNull::from); - let hint = hint.map(NonNull::from); - let internal_query = internal_query.map(NonNull::from); - let where_ = where_.map(NonNull::from); - let schema = schema.map(NonNull::from); - let table = table.map(NonNull::from); - let column = column.map(NonNull::from); - let data_type = data_type.map(NonNull::from); - let constraint = constraint.map(NonNull::from); - let file = file.map(NonNull::from); - let routine = routine.map(NonNull::from); + let code = code.ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "did not receieve field `code` for Response", + ) + })?; + let message = message.ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "did not receieve field `message` for Response", + ) + })?; Ok(Self { - storage, + buffer, severity, code, message, diff --git a/src/postgres/protocol/startup_message.rs b/src/postgres/protocol/startup_message.rs index ba64ace1..630fdf4a 100644 --- a/src/postgres/protocol/startup_message.rs +++ b/src/postgres/protocol/startup_message.rs @@ -1,5 +1,7 @@ -use super::{BufMut, Encode}; +use super::{Encode}; +use crate::io::BufMut; use byteorder::{BigEndian, ByteOrder}; +use byteorder::NetworkEndian; pub struct StartupMessage<'a> { pub params: &'a [(&'a str, &'a str)], @@ -8,17 +10,17 @@ pub struct StartupMessage<'a> { impl Encode for StartupMessage<'_> { fn encode(&self, buf: &mut Vec) { let pos = buf.len(); - buf.put_int_32(0); // skip over len + buf.put_i32::(0); // skip over len // protocol version number (3.0) - buf.put_int_32(196_608); + buf.put_i32::(196_608); for (name, value) in self.params { - buf.put_str(name); - buf.put_str(value); + buf.put_str_nul(name); + buf.put_str_nul(value); } - buf.put_byte(0); + buf.push(0); // Write-back the len to the beginning of this frame let len = buf.len() - pos; diff --git a/src/postgres/protocol/sync.rs b/src/postgres/protocol/sync.rs index 51566fc3..d4ad8a9b 100644 --- a/src/postgres/protocol/sync.rs +++ b/src/postgres/protocol/sync.rs @@ -1,11 +1,13 @@ -use super::{BufMut, Encode}; +use super::{Encode}; +use crate::io::BufMut; +use byteorder::NetworkEndian; pub struct Sync; impl Encode for Sync { #[inline] fn encode(&self, buf: &mut Vec) { - buf.put_byte(b'S'); - buf.put_int_32(4); + buf.push(b'S'); + buf.put_i32::(4); } } diff --git a/src/postgres/protocol/terminate.rs b/src/postgres/protocol/terminate.rs index a0402815..58cedfa9 100644 --- a/src/postgres/protocol/terminate.rs +++ b/src/postgres/protocol/terminate.rs @@ -1,11 +1,13 @@ -use super::{BufMut, Encode}; +use super::{Encode}; +use crate::io::BufMut; +use byteorder::NetworkEndian; pub struct Terminate; impl Encode for Terminate { #[inline] fn encode(&self, buf: &mut Vec) { - buf.put_byte(b'X'); - buf.put_int_32(4); + buf.push(b'X'); + buf.put_i32::(4); } } diff --git a/src/postgres/query.rs b/src/postgres/query.rs index f434be4e..df30bd58 100644 --- a/src/postgres/query.rs +++ b/src/postgres/query.rs @@ -1,13 +1,15 @@ use super::{ - protocol::{self, BufMut}, + protocol, Postgres, PostgresRawConnection, }; use crate::{ + io::BufMut, query::QueryParameters, serialize::{IsNull, ToSql}, types::HasSqlType, }; use byteorder::{BigEndian, ByteOrder}; +use byteorder::NetworkEndian; pub struct PostgresQueryParameters { // OIDs of the bind parameters @@ -40,7 +42,7 @@ impl QueryParameters for PostgresQueryParameters { self.types.push(>::metadata().oid); let pos = self.buf.len(); - self.buf.put_int_32(0); + self.buf.put_i32::(0); let len = if let IsNull::No = value.to_sql(&mut self.buf) { (self.buf.len() - pos - 4) as i32 diff --git a/src/postgres/row.rs b/src/postgres/row.rs index 9822af1a..d115b51e 100644 --- a/src/postgres/row.rs +++ b/src/postgres/row.rs @@ -1,7 +1,7 @@ use super::{protocol::DataRow, Postgres}; use crate::row::Row; -pub struct PostgresRow(pub(crate) Box); +pub struct PostgresRow(pub(crate) DataRow); impl Row for PostgresRow { type Backend = Postgres;