diff --git a/mason-postgres/src/lib.rs b/mason-postgres/src/lib.rs index 2cb84420..68696b5b 100644 --- a/mason-postgres/src/lib.rs +++ b/mason-postgres/src/lib.rs @@ -96,9 +96,7 @@ impl Connection { _password: &'b str, database: &'b str, ) -> io::Result<()> { - let params = [("user", user), ("database", database)]; - - let msg = client::StartupMessage { params: ¶ms }; + let msg = client::StartupMessage { user, database: Some(database) }; msg.serialize(&mut self.buf); self.writer.write_all(&self.buf).await?; diff --git a/mason-postgres/src/protocol/client.rs b/mason-postgres/src/protocol/client.rs index 35ab9835..9fd1a809 100644 --- a/mason-postgres/src/protocol/client.rs +++ b/mason-postgres/src/protocol/client.rs @@ -16,31 +16,63 @@ impl Terminate { #[derive(Debug)] pub struct StartupMessage<'a> { - /// One or more pairs of parameter name and value strings. - /// A zero byte is required as a terminator after the last name/value pair. - /// Parameters can appear in any order. user is required, others are optional. - pub params: &'a [(&'a str, &'a str)], + pub user: &'a str, + pub database: Option<&'a str>, } impl<'a> StartupMessage<'a> { pub fn serialize(&self, buf: &mut Vec) { with_length_prefix(buf, |buf| { - // version: 3 = major, 0 = minor - buf.extend_from_slice(&0x0003_0000_i32.to_be_bytes()); + // protocol version: major = 3, minor = 0 + buf.extend_from_slice(&0x0003_i16.to_be_bytes()); + buf.extend_from_slice(&0x0000_i16.to_be_bytes()); - for (name, value) in self.params { - buf.extend_from_slice(name.as_bytes()); - buf.push(0); - buf.extend_from_slice(value.as_bytes()); - buf.push(0); + write_str(buf, "user"); + write_str(buf, self.user); + + if let Some(database) = self.database { + write_str(buf, "database"); + write_str(buf, database); } + // See this doc for more runtime parameters + // https://www.postgresql.org/docs/12/runtime-config-client.html + + // Sets the display format for date and time values, + // as well as the rules for interpreting ambiguous date input values. + write_str(buf, "DateStyle"); + write_str(buf, "ISO"); + + // Sets the display format for interval values. + write_str(buf, "IntervalStyle"); + write_str(buf, "iso_8601"); + + // Sets the time zone for displaying and interpreting time stamps. + write_str(buf, "TimeZone"); + write_str(buf, "UTC"); + + // Adjust postgres to return percise values for floats + // NOTE: This is default in postgres 12+ + write_str(buf, "extra_float_digits"); + write_str(buf, "3"); + + // Sets the client-side encoding (character set). + write_str(buf, "client_encoding"); + write_str(buf, "UTF-8"); + // A zero byte is required as a terminator after the last name/value pair. buf.push(0); }); } } +// Write a string followed by a null-terminator +#[inline] +fn write_str(buf: &mut Vec, s: &str) { + buf.extend_from_slice(s.as_bytes()); + buf.push(0); +} + // Write a variable amount of data into a buffer and then // prefix that data with the length of what was written fn with_length_prefix(buf: &mut Vec, f: F) @@ -63,18 +95,21 @@ where mod tests { use super::*; - // TODO: Serialize test more messages + // TODO: encode test more messages #[test] fn ser_startup_message() { - let msg = StartupMessage { params: &[("user", "postgres"), ("database", "postgres")] }; + let msg = StartupMessage { user: "postgres", database: None }; let mut buf = Vec::new(); - msg.serialize(&mut buf); + msg.encode(&mut buf); assert_eq!( - "00000029000300007573657200706f73746772657\ - 300646174616261736500706f7374677265730000", + "00000074000300007573657200706f7374677265730044617465537\ + 4796c650049534f00496e74657276616c5374796c650069736f5f38\ + 3630310054696d655a6f6e65005554430065787472615f666c6f617\ + 45f646967697473003300636c69656e745f656e636f64696e670055\ + 54462d380000", hex::encode(buf) ); }