Optimize encode of StartupMessage

This commit is contained in:
Ryan Leckey 2019-07-11 10:08:57 -07:00
parent da42be7d0a
commit f161fa3178
4 changed files with 57 additions and 109 deletions

View file

@ -6,7 +6,7 @@ use sqlx_postgres_protocol::{Encode, PasswordMessage, Response, Severity, Startu
fn criterion_benchmark(c: &mut Criterion) {
c.bench_function("encode Response::builder()", |b| {
let mut dst = Vec::new();
let mut dst = Vec::with_capacity(1024);
b.iter(|| {
dst.clear();
Response::builder()
@ -22,7 +22,7 @@ fn criterion_benchmark(c: &mut Criterion) {
});
c.bench_function("encode PasswordMessage::cleartext", |b| {
let mut dst = Vec::new();
let mut dst = Vec::with_capacity(1024);
b.iter(|| {
dst.clear();
PasswordMessage::cleartext("8e323AMF9YSE9zftFnuhQcvhz7Vf342W4cWU")
@ -32,31 +32,31 @@ fn criterion_benchmark(c: &mut Criterion) {
});
c.bench_function("encode StartupMessage", |b| {
let mut dst = Vec::new();
let mut dst = Vec::with_capacity(1024);
b.iter(|| {
dst.clear();
StartupMessage::builder()
.param("user", "postgres")
.param("database", "postgres")
.param("DateStyle", "ISO, MDY")
.param("IntervalStyle", "iso_8601")
.param("TimeZone", "UTC")
.param("extra_float_digits", "3")
.param("client_encoding", "UTF-8")
.build()
.encode(&mut dst)
.unwrap();
StartupMessage::new(&[
("user", "postgres"),
("database", "postgres"),
("DateStyle", "ISO, MDY"),
("IntervalStyle", "iso_8601"),
("TimeZone", "UTC"),
("extra_float_digits", "3"),
("client_encoding", "UTF-8"),
])
.encode(&mut dst)
.unwrap();
})
});
c.bench_function("encode Password(MD5)", |b| {
let mut dst = Vec::new();
let mut dst = Vec::with_capacity(1024);
b.iter(|| {
dst.clear();
PasswordMessage::md5(
"8e323AMF9YSE9zftFnuhQcvhz7Vf342W4cWU",
"postgres",
&[10, 41, 20, 150],
[10, 41, 20, 150],
)
.encode(&mut dst)
.unwrap();

View file

@ -1,91 +1,46 @@
use crate::Encode;
use bytes::{BufMut, Bytes, BytesMut};
use byteorder::{BigEndian, ByteOrder};
use std::io;
#[derive(Debug)]
pub struct StartupMessage {
// (major, minor)
version: (u16, u16),
params: Bytes,
pub struct StartupMessage<'a> {
params: &'a [(&'a str, &'a str)],
}
impl StartupMessage {
pub fn builder() -> StartupMessageBuilder {
StartupMessageBuilder::new()
impl<'a> StartupMessage<'a> {
pub fn new(params: &'a [(&'a str, &'a str)]) -> Self {
Self { params }
}
pub fn version(&self) -> (u16, u16) {
self.version
}
pub fn params(&self) -> StartupMessageParams<'_> {
StartupMessageParams(&*self.params)
pub fn params(&self) -> &'a [(&'a str, &'a str)] {
self.params
}
}
impl Encode for StartupMessage {
impl<'a> Encode for StartupMessage<'a> {
fn encode(&self, buf: &mut Vec<u8>) -> io::Result<()> {
let len = self.params.len() + 8;
buf.reserve(len);
buf.put_u32_be(len as u32);
buf.put_u16_be(self.version.0);
buf.put_u16_be(self.version.1);
buf.put(&self.params);
let pos = buf.len();
buf.extend_from_slice(&(0 as u32).to_be_bytes()); // skip over len
buf.extend_from_slice(&3_u16.to_be_bytes()); // major version
buf.extend_from_slice(&0_u16.to_be_bytes()); // minor version
for (name, value) in self.params {
buf.extend_from_slice(name.as_bytes());
buf.push(0);
buf.extend_from_slice(value.as_bytes());
buf.push(0);
}
buf.push(0);
// Write-back the len to the beginning of this frame
let len = buf.len() - pos;
BigEndian::write_u32(&mut buf[pos..], len as u32);
Ok(())
}
}
// TODO: Impl Iterator to iter over params
pub struct StartupMessageParams<'a>(&'a [u8]);
pub struct StartupMessageBuilder {
// (major, minor)
version: (u16, u16),
params: BytesMut,
}
impl Default for StartupMessageBuilder {
fn default() -> Self {
StartupMessageBuilder {
version: (3, 0),
params: BytesMut::with_capacity(156),
}
}
}
impl StartupMessageBuilder {
pub fn new() -> Self {
StartupMessageBuilder::default()
}
/// Set the protocol version number. Defaults to `3.0`.
pub fn version(mut self, major: u16, minor: u16) -> Self {
self.version = (major, minor);
self
}
pub fn param(mut self, name: &str, value: &str) -> Self {
self.params.reserve(name.len() + value.len() + 2);
self.params.put(name.as_bytes());
self.params.put_u8(0);
self.params.put(value.as_bytes());
self.params.put_u8(0);
self
}
pub fn build(mut self) -> StartupMessage {
self.params.reserve(1);
self.params.put_u8(0);
StartupMessage {
version: self.version,
params: self.params.freeze(),
}
}
}
#[cfg(test)]
mod tests {
use super::StartupMessage;
@ -96,10 +51,7 @@ mod tests {
#[test]
fn it_encodes_startup_message() -> io::Result<()> {
let message = StartupMessage::builder()
.param("user", "postgres")
.param("database", "postgres")
.build();
let message = StartupMessage::new(&[("user", "postgres"), ("database", "postgres")]);
let mut buf = Vec::new();
message.encode(&mut buf)?;

View file

@ -10,31 +10,26 @@ pub async fn establish<'a, 'b: 'a>(
) -> io::Result<()> {
// See this doc for more runtime parameters
// https://www.postgresql.org/docs/12/runtime-config-client.html
let mut message = StartupMessage::builder();
if let Some(user) = options.user {
// FIXME: User is technically required. We should default this like psql does.
message = message.param("user", user);
}
if let Some(database) = options.database {
message = message.param("database", database);
}
let message = message
let params = &[
// FIXME: ConnectOptions user and database need to be required parameters and error
// before they get here
("user", options.user.expect("user is required")),
("database", options.database.expect("database is required")),
// Sets the display format for date and time values,
// as well as the rules for interpreting ambiguous date input values.
.param("DateStyle", "ISO, MDY")
("DateStyle", "ISO, MDY"),
// Sets the display format for interval values.
.param("IntervalStyle", "iso_8601")
("IntervalStyle", "iso_8601"),
// Sets the time zone for displaying and interpreting time stamps.
.param("TimeZone", "UTC")
("TimeZone", "UTC"),
// Adjust postgres to return percise values for floats
// NOTE: This is default in postgres 12+
.param("extra_float_digits", "3")
("extra_float_digits", "3"),
// Sets the client-side encoding (character set).
.param("client_encoding", "UTF-8")
.build();
("client_encoding", "UTF-8"),
];
let message = StartupMessage::new(params);
conn.send(message).await?;

View file

@ -12,6 +12,7 @@ async fn main() -> io::Result<()> {
.host("127.0.0.1")
.port(5432)
.user("postgres")
.database("postgres")
.password("password"),
)
.await?;