mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 14:34:19 +00:00
Optimize encode of StartupMessage
This commit is contained in:
parent
da42be7d0a
commit
f161fa3178
4 changed files with 57 additions and 109 deletions
|
@ -6,7 +6,7 @@ use sqlx_postgres_protocol::{Encode, PasswordMessage, Response, Severity, Startu
|
|||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
c.bench_function("encode Response::builder()", |b| {
|
||||
let mut dst = Vec::new();
|
||||
let mut dst = Vec::with_capacity(1024);
|
||||
b.iter(|| {
|
||||
dst.clear();
|
||||
Response::builder()
|
||||
|
@ -22,7 +22,7 @@ fn criterion_benchmark(c: &mut Criterion) {
|
|||
});
|
||||
|
||||
c.bench_function("encode PasswordMessage::cleartext", |b| {
|
||||
let mut dst = Vec::new();
|
||||
let mut dst = Vec::with_capacity(1024);
|
||||
b.iter(|| {
|
||||
dst.clear();
|
||||
PasswordMessage::cleartext("8e323AMF9YSE9zftFnuhQcvhz7Vf342W4cWU")
|
||||
|
@ -32,31 +32,31 @@ fn criterion_benchmark(c: &mut Criterion) {
|
|||
});
|
||||
|
||||
c.bench_function("encode StartupMessage", |b| {
|
||||
let mut dst = Vec::new();
|
||||
let mut dst = Vec::with_capacity(1024);
|
||||
b.iter(|| {
|
||||
dst.clear();
|
||||
StartupMessage::builder()
|
||||
.param("user", "postgres")
|
||||
.param("database", "postgres")
|
||||
.param("DateStyle", "ISO, MDY")
|
||||
.param("IntervalStyle", "iso_8601")
|
||||
.param("TimeZone", "UTC")
|
||||
.param("extra_float_digits", "3")
|
||||
.param("client_encoding", "UTF-8")
|
||||
.build()
|
||||
.encode(&mut dst)
|
||||
.unwrap();
|
||||
StartupMessage::new(&[
|
||||
("user", "postgres"),
|
||||
("database", "postgres"),
|
||||
("DateStyle", "ISO, MDY"),
|
||||
("IntervalStyle", "iso_8601"),
|
||||
("TimeZone", "UTC"),
|
||||
("extra_float_digits", "3"),
|
||||
("client_encoding", "UTF-8"),
|
||||
])
|
||||
.encode(&mut dst)
|
||||
.unwrap();
|
||||
})
|
||||
});
|
||||
|
||||
c.bench_function("encode Password(MD5)", |b| {
|
||||
let mut dst = Vec::new();
|
||||
let mut dst = Vec::with_capacity(1024);
|
||||
b.iter(|| {
|
||||
dst.clear();
|
||||
PasswordMessage::md5(
|
||||
"8e323AMF9YSE9zftFnuhQcvhz7Vf342W4cWU",
|
||||
"postgres",
|
||||
&[10, 41, 20, 150],
|
||||
[10, 41, 20, 150],
|
||||
)
|
||||
.encode(&mut dst)
|
||||
.unwrap();
|
||||
|
|
|
@ -1,91 +1,46 @@
|
|||
use crate::Encode;
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use std::io;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct StartupMessage {
|
||||
// (major, minor)
|
||||
version: (u16, u16),
|
||||
params: Bytes,
|
||||
pub struct StartupMessage<'a> {
|
||||
params: &'a [(&'a str, &'a str)],
|
||||
}
|
||||
|
||||
impl StartupMessage {
|
||||
pub fn builder() -> StartupMessageBuilder {
|
||||
StartupMessageBuilder::new()
|
||||
impl<'a> StartupMessage<'a> {
|
||||
pub fn new(params: &'a [(&'a str, &'a str)]) -> Self {
|
||||
Self { params }
|
||||
}
|
||||
|
||||
pub fn version(&self) -> (u16, u16) {
|
||||
self.version
|
||||
}
|
||||
|
||||
pub fn params(&self) -> StartupMessageParams<'_> {
|
||||
StartupMessageParams(&*self.params)
|
||||
pub fn params(&self) -> &'a [(&'a str, &'a str)] {
|
||||
self.params
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode for StartupMessage {
|
||||
impl<'a> Encode for StartupMessage<'a> {
|
||||
fn encode(&self, buf: &mut Vec<u8>) -> io::Result<()> {
|
||||
let len = self.params.len() + 8;
|
||||
buf.reserve(len);
|
||||
buf.put_u32_be(len as u32);
|
||||
buf.put_u16_be(self.version.0);
|
||||
buf.put_u16_be(self.version.1);
|
||||
buf.put(&self.params);
|
||||
let pos = buf.len();
|
||||
buf.extend_from_slice(&(0 as u32).to_be_bytes()); // skip over len
|
||||
buf.extend_from_slice(&3_u16.to_be_bytes()); // major version
|
||||
buf.extend_from_slice(&0_u16.to_be_bytes()); // minor version
|
||||
|
||||
for (name, value) in self.params {
|
||||
buf.extend_from_slice(name.as_bytes());
|
||||
buf.push(0);
|
||||
buf.extend_from_slice(value.as_bytes());
|
||||
buf.push(0);
|
||||
}
|
||||
|
||||
buf.push(0);
|
||||
|
||||
// Write-back the len to the beginning of this frame
|
||||
let len = buf.len() - pos;
|
||||
BigEndian::write_u32(&mut buf[pos..], len as u32);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Impl Iterator to iter over params
|
||||
pub struct StartupMessageParams<'a>(&'a [u8]);
|
||||
|
||||
pub struct StartupMessageBuilder {
|
||||
// (major, minor)
|
||||
version: (u16, u16),
|
||||
params: BytesMut,
|
||||
}
|
||||
|
||||
impl Default for StartupMessageBuilder {
|
||||
fn default() -> Self {
|
||||
StartupMessageBuilder {
|
||||
version: (3, 0),
|
||||
params: BytesMut::with_capacity(156),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StartupMessageBuilder {
|
||||
pub fn new() -> Self {
|
||||
StartupMessageBuilder::default()
|
||||
}
|
||||
|
||||
/// Set the protocol version number. Defaults to `3.0`.
|
||||
pub fn version(mut self, major: u16, minor: u16) -> Self {
|
||||
self.version = (major, minor);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn param(mut self, name: &str, value: &str) -> Self {
|
||||
self.params.reserve(name.len() + value.len() + 2);
|
||||
self.params.put(name.as_bytes());
|
||||
self.params.put_u8(0);
|
||||
self.params.put(value.as_bytes());
|
||||
self.params.put_u8(0);
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(mut self) -> StartupMessage {
|
||||
self.params.reserve(1);
|
||||
self.params.put_u8(0);
|
||||
|
||||
StartupMessage {
|
||||
version: self.version,
|
||||
params: self.params.freeze(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::StartupMessage;
|
||||
|
@ -96,10 +51,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn it_encodes_startup_message() -> io::Result<()> {
|
||||
let message = StartupMessage::builder()
|
||||
.param("user", "postgres")
|
||||
.param("database", "postgres")
|
||||
.build();
|
||||
let message = StartupMessage::new(&[("user", "postgres"), ("database", "postgres")]);
|
||||
|
||||
let mut buf = Vec::new();
|
||||
message.encode(&mut buf)?;
|
||||
|
|
|
@ -10,31 +10,26 @@ pub async fn establish<'a, 'b: 'a>(
|
|||
) -> io::Result<()> {
|
||||
// See this doc for more runtime parameters
|
||||
// https://www.postgresql.org/docs/12/runtime-config-client.html
|
||||
let mut message = StartupMessage::builder();
|
||||
|
||||
if let Some(user) = options.user {
|
||||
// FIXME: User is technically required. We should default this like psql does.
|
||||
message = message.param("user", user);
|
||||
}
|
||||
|
||||
if let Some(database) = options.database {
|
||||
message = message.param("database", database);
|
||||
}
|
||||
|
||||
let message = message
|
||||
let params = &[
|
||||
// FIXME: ConnectOptions user and database need to be required parameters and error
|
||||
// before they get here
|
||||
("user", options.user.expect("user is required")),
|
||||
("database", options.database.expect("database is required")),
|
||||
// Sets the display format for date and time values,
|
||||
// as well as the rules for interpreting ambiguous date input values.
|
||||
.param("DateStyle", "ISO, MDY")
|
||||
("DateStyle", "ISO, MDY"),
|
||||
// Sets the display format for interval values.
|
||||
.param("IntervalStyle", "iso_8601")
|
||||
("IntervalStyle", "iso_8601"),
|
||||
// Sets the time zone for displaying and interpreting time stamps.
|
||||
.param("TimeZone", "UTC")
|
||||
("TimeZone", "UTC"),
|
||||
// Adjust postgres to return percise values for floats
|
||||
// NOTE: This is default in postgres 12+
|
||||
.param("extra_float_digits", "3")
|
||||
("extra_float_digits", "3"),
|
||||
// Sets the client-side encoding (character set).
|
||||
.param("client_encoding", "UTF-8")
|
||||
.build();
|
||||
("client_encoding", "UTF-8"),
|
||||
];
|
||||
|
||||
let message = StartupMessage::new(params);
|
||||
|
||||
conn.send(message).await?;
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ async fn main() -> io::Result<()> {
|
|||
.host("127.0.0.1")
|
||||
.port(5432)
|
||||
.user("postgres")
|
||||
.database("postgres")
|
||||
.password("password"),
|
||||
)
|
||||
.await?;
|
||||
|
|
Loading…
Reference in a new issue