Establish a solid foundation with mason-postgres

This commit is contained in:
Ryan Leckey 2019-06-10 21:57:08 -07:00
parent 1e983a2a8b
commit fda6d9ef81
11 changed files with 335 additions and 179 deletions

View file

@ -1,7 +1,8 @@
[workspace]
members = [
".",
"mason-postgres"
"mason-core",
"mason-postgres",
]
[package]
@ -14,6 +15,7 @@ edition = "2018"
[dependencies]
runtime = "=0.3.0-alpha.4"
mason-core = { path = "mason-core" }
mason-postgres = { path = "mason-postgres" }
failure = "0.1"
env_logger = "0.6.1"

9
mason-core/Cargo.toml Normal file
View file

@ -0,0 +1,9 @@
[package]
name = "mason-core"
version = "0.0.0"
authors = ["Ryan Leckey <leckey.ryan@gmail.com>"]
license = "MIT OR Apache-2.0"
description = "Shared types and traits for Mason."
edition = "2018"
[dependencies]

View file

@ -0,0 +1,39 @@
pub struct ConnectOptions<'a> {
pub host: &'a str,
pub port: u16,
pub user: Option<&'a str>,
pub database: Option<&'a str>,
pub password: Option<&'a str>,
}
impl<'a> Default for ConnectOptions<'a> {
#[inline]
fn default() -> Self {
Self { host: "localhost", port: 5432, user: None, database: None, password: None }
}
}
impl<'a> ConnectOptions<'a> {
#[inline]
pub fn new() -> Self {
Self::default()
}
#[inline]
pub fn user(mut self, user: &'a str) -> Self {
self.user = Some(user);
self
}
#[inline]
pub fn database(mut self, database: &'a str) -> Self {
self.database = Some(database);
self
}
#[inline]
pub fn password(mut self, password: &'a str) -> Self {
self.password = Some(password);
self
}
}

3
mason-core/src/lib.rs Normal file
View file

@ -0,0 +1,3 @@
pub use connection::ConnectOptions;
mod connection;

View file

@ -7,6 +7,7 @@ description = "PostgreSQL database driver for dbx."
edition = "2018"
[dependencies]
mason-core = { path = "../mason-core" }
runtime = "=0.3.0-alpha.4"
futures-preview = "=0.3.0-alpha.16"
failure = "0.1"
@ -15,3 +16,4 @@ log = "0.4"
hex = "0.3.2"
bytes = "0.4.12"
memchr = "2.2.0"
md-5 = "0.8.0"

View file

@ -0,0 +1,203 @@
use crate::protocol::{
client::{PasswordMessage, Serialize, Terminate, StartupMessage},
server::Message as ServerMessage,
};
use bytes::BytesMut;
use futures::{
channel::mpsc,
io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
SinkExt, StreamExt,
};
use mason_core::ConnectOptions;
use md5::{Digest, Md5};
use runtime::{net::TcpStream, task::JoinHandle};
use std::{io};
pub struct Connection {
writer: WriteHalf<TcpStream>,
incoming: mpsc::UnboundedReceiver<ServerMessage>,
// Buffer used when serializing outgoing messages
wbuf: Vec<u8>,
// Handle to coroutine reading messages from the stream
receiver: JoinHandle<io::Result<()>>,
// Process ID of the Backend
process_id: i32,
// Backend-unique key to use to send a cancel query message to the server
secret_key: i32,
}
impl Connection {
pub async fn establish(options: ConnectOptions<'_>) -> io::Result<Self> {
let stream = TcpStream::connect((options.host, options.port)).await?;
let (reader, writer) = stream.split();
let (tx, rx) = mpsc::unbounded();
let receiver = runtime::spawn(receiver(reader, tx));
let mut conn = Self {
wbuf: Vec::with_capacity(1024),
writer,
receiver,
incoming: rx,
process_id: -1,
secret_key: -1,
};
// See this doc for more runtime parameters
// https://www.postgresql.org/docs/12/runtime-config-client.html
let params = [
("user", options.user),
("database", options.database),
// TODO: Expose this property perhaps?
(
"application_name",
Some(concat!(env!("CARGO_PKG_NAME"), "/v", env!("CARGO_PKG_VERSION"))),
),
// Sets the display format for date and time values,
// as well as the rules for interpreting ambiguous date input values.
("DateStyle", Some("ISO, MDY")),
// Sets the display format for interval values.
("IntervalStyle", Some("iso_8601")),
// Sets the time zone for displaying and interpreting time stamps.
("TimeZone", Some("UTC")),
// Adjust postgres to return percise values for floats
// NOTE: This is default in postgres 12+
("extra_float_digits", Some("3")),
// Sets the client-side encoding (character set).
("client_encoding", Some("UTF-8")),
];
conn.send(StartupMessage { params: &params }).await?;
while let Some(message) = conn.incoming.next().await {
match message {
ServerMessage::AuthenticationOk => {
// Do nothing; server is just telling us that
// there is no password needed
}
ServerMessage::AuthenticationMd5Password(body) => {
// Hash password|username
// FIXME: ConnectOptions should prepare a default user
let pass_user =
md5(options.password.unwrap_or_default(), options.user.unwrap_or_default());
let with_salt = md5(pass_user, body.salt());
let password = format!("md5{}", with_salt);
conn.send(PasswordMessage { password: &password }).await?;
}
ServerMessage::BackendKeyData(body) => {
conn.process_id = body.process_id();
conn.secret_key = body.secret_key();
}
ServerMessage::ReadyForQuery(_) => {
// Good to go
break;
}
_ => {
unimplemented!("received {:?} unimplemented message", message);
}
}
}
Ok(conn)
}
pub async fn close(mut self) -> io::Result<()> {
self.send(Terminate).await?;
self.writer.close().await?;
self.receiver.await?;
Ok(())
}
// Send client-serializable message to the server
async fn send<S>(&mut self, message: S) -> io::Result<()>
where
S: Serialize,
{
self.wbuf.clear();
message.serialize(&mut self.wbuf);
self.writer.write_all(&self.wbuf).await?;
self.writer.flush().await?;
Ok(())
}
}
async fn receiver(
mut reader: ReadHalf<TcpStream>,
mut sender: mpsc::UnboundedSender<ServerMessage>,
) -> io::Result<()> {
let mut rbuf = BytesMut::with_capacity(0);
let mut len = 0;
loop {
// This uses an adaptive system to extend the vector when it fills. We want to
// avoid paying to allocate and zero a huge chunk of memory if the reader only
// has 4 bytes while still making large reads if the reader does have a ton
// of data to return.
// See: https://github.com/rust-lang-nursery/futures-rs/blob/master/futures-util/src/io/read_to_end.rs#L50-L54
if len == rbuf.len() {
rbuf.reserve(32);
unsafe {
// Set length to the capacity and efficiently
// zero-out the memory
rbuf.set_len(rbuf.capacity());
reader.initializer().initialize(&mut rbuf[len..]);
}
}
// TODO: Need a select! on a channel that I can trigger to cancel this
let cnt = reader.read(&mut rbuf[len..]).await?;
if cnt > 0 {
len += cnt;
} else {
// Read 0 bytes from the server; end-of-stream
break;
}
while len > 0 {
let size = rbuf.len();
let message = ServerMessage::deserialize(&mut rbuf)?;
len -= size - rbuf.len();
// TODO: Some messages should be kept behind here
match message {
Some(ServerMessage::ParameterStatus(body)) => {
log::debug!("parameter {} = {}", body.name()?, body.value()?);
}
Some(message) => {
// TODO: Handle this error?
sender.send(message).await.unwrap();
}
None => {
// Did not receive enough bytes to
// deserialize a complete message
break;
}
}
}
}
Ok(())
}
#[inline]
fn md5(a: impl AsRef<[u8]>, b: impl AsRef<[u8]>) -> String {
hex::encode(Md5::new().chain(a).chain(b).result())
}

View file

@ -1,143 +1,7 @@
#![feature(non_exhaustive, async_await)]
#![allow(clippy::needless_lifetimes)]
use crate::protocol::{client, server};
use bytes::BytesMut;
use futures::{
channel::mpsc,
io::{AsyncRead, AsyncReadExt, AsyncWriteExt, WriteHalf},
SinkExt, StreamExt,
};
use runtime::{net::TcpStream, task::JoinHandle};
use std::io;
mod connection;
mod protocol;
pub mod protocol;
pub struct Connection {
buf: Vec<u8>,
writer: WriteHalf<TcpStream>,
incoming: mpsc::Receiver<server::Message>,
reader: Option<JoinHandle<io::Result<()>>>,
}
impl Connection {
pub async fn open(address: &str) -> io::Result<Self> {
let stream = TcpStream::connect(address).await?;
let (mut reader, writer) = stream.split();
// FIXME: What's a good buffer size here?
let (mut tx, rx) = mpsc::channel(1024);
let reader = runtime::spawn(async move {
let mut buf = BytesMut::with_capacity(0);
let mut len = 0;
'reader: loop {
if len == buf.len() {
unsafe {
buf.reserve(32);
buf.set_len(buf.capacity());
reader.initializer().initialize(&mut buf[len..]);
}
}
let num = reader.read(&mut buf[len..]).await?;
if num > 0 {
len += num;
}
while len > 0 && !buf.is_empty() {
let size = buf.len();
let msg = server::Message::deserialize(&mut buf)?;
let removed = size - buf.len();
len -= removed;
match msg {
Some(server::Message::ParameterStatus(body)) => {
// FIXME: Proper log
log::info!("{:?}", body);
}
Some(msg) => {
tx.send(msg).await.unwrap();
}
None => {
// We have _some_ amount of data but not enough to
// deserialize anything
break;
}
}
}
// FIXME: This probably doesn't make sense
if len == 0 && !buf.is_empty() {
// Hit end-of-stream
break 'reader;
}
}
Ok(())
});
Ok(Self {
// FIXME: What's a good buffer size here?
buf: Vec::with_capacity(1024),
writer,
reader: Some(reader),
incoming: rx,
})
}
pub async fn startup<'a, 'b: 'a>(
&'a mut self,
user: &'b str,
_password: &'b str,
database: &'b str,
) -> io::Result<()> {
let msg = client::StartupMessage { user, database: Some(database) };
msg.serialize(&mut self.buf);
self.writer.write_all(&self.buf).await?;
self.buf.clear();
self.writer.flush().await?;
// FIXME: We _actually_ want to wait until ReadyForQuery, ErrorResponse, AuthX, etc.
while let Some(message) = self.incoming.next().await {
match message {
server::Message::AuthenticationOk => {
// Do nothing; server is just telling us "you're in"
}
server::Message::ReadyForQuery(_) => {
// Good to go
break;
}
_ => {}
}
}
Ok(())
}
pub async fn terminate(&mut self) -> io::Result<()> {
let msg = client::Terminate {};
msg.serialize(&mut self.buf);
self.writer.write_all(&self.buf).await?;
self.buf.clear();
self.writer.flush().await?;
self.writer.close().await?;
if let Some(reader) = self.reader.take() {
reader.await?;
}
Ok(())
}
}
pub use connection::Connection;

View file

@ -4,11 +4,15 @@ use byteorder::{BigEndian, ByteOrder};
// https://www.postgresql.org/docs/devel/protocol-message-formats.html
// https://www.postgresql.org/docs/devel/protocol-message-types.html
pub trait Serialize {
fn serialize(&self, buf: &mut Vec<u8>);
}
#[derive(Debug)]
pub struct Terminate;
impl Terminate {
pub fn serialize(&self, buf: &mut Vec<u8>) {
impl Serialize for Terminate {
fn serialize(&self, buf: &mut Vec<u8>) {
buf.push(b'X');
buf.push(4);
}
@ -16,56 +20,44 @@ impl Terminate {
#[derive(Debug)]
pub struct StartupMessage<'a> {
pub user: &'a str,
pub database: Option<&'a str>,
pub params: &'a [(&'a str, Option<&'a str>)],
}
impl<'a> StartupMessage<'a> {
pub fn serialize(&self, buf: &mut Vec<u8>) {
impl<'a> Serialize for StartupMessage<'a> {
fn serialize(&self, buf: &mut Vec<u8>) {
with_length_prefix(buf, |buf| {
// protocol version: major = 3, minor = 0
buf.extend_from_slice(&0x0003_i16.to_be_bytes());
buf.extend_from_slice(&0x0000_i16.to_be_bytes());
write_str(buf, "user");
write_str(buf, self.user);
if let Some(database) = self.database {
write_str(buf, "database");
write_str(buf, database);
for (name, value) in self.params {
if let Some(value) = value {
write_str(buf, name);
write_str(buf, value);
}
}
// See this doc for more runtime parameters
// https://www.postgresql.org/docs/12/runtime-config-client.html
// Sets the display format for date and time values,
// as well as the rules for interpreting ambiguous date input values.
write_str(buf, "DateStyle");
write_str(buf, "ISO");
// Sets the display format for interval values.
write_str(buf, "IntervalStyle");
write_str(buf, "iso_8601");
// Sets the time zone for displaying and interpreting time stamps.
write_str(buf, "TimeZone");
write_str(buf, "UTC");
// Adjust postgres to return percise values for floats
// NOTE: This is default in postgres 12+
write_str(buf, "extra_float_digits");
write_str(buf, "3");
// Sets the client-side encoding (character set).
write_str(buf, "client_encoding");
write_str(buf, "UTF-8");
// A zero byte is required as a terminator after the last name/value pair.
buf.push(0);
});
}
}
#[derive(Debug)]
pub struct PasswordMessage<'a> {
pub password: &'a str,
}
impl<'a> Serialize for PasswordMessage<'a> {
fn serialize(&self, buf: &mut Vec<u8>) {
buf.push(b'p');
with_length_prefix(buf, |buf| {
write_str(buf, self.password);
});
}
}
// Write a string followed by a null-terminator
#[inline]
fn write_str(buf: &mut Vec<u8>, s: &str) {

View file

@ -1,6 +1,6 @@
use byteorder::{BigEndian, ByteOrder};
use bytes::{Bytes, BytesMut};
use std::io;
use std::{io, str};
// Reference
// https://www.postgresql.org/docs/devel/protocol-message-formats.html
@ -76,6 +76,9 @@ impl Message {
b'R' => match BigEndian::read_i32(&buf[idx..]) {
0 => Message::AuthenticationOk,
5 => Message::AuthenticationMd5Password(AuthenticationMd5Password {
salt: buf.slice_from(idx + 4),
}),
code => {
unimplemented!("unknown response code received: {:x}", code);
@ -96,7 +99,14 @@ impl Message {
#[derive(Debug)]
pub struct AuthenticationMd5Password {
pub(super) salt: [u8; 4],
pub(super) salt: Bytes,
}
impl AuthenticationMd5Password {
#[inline]
pub fn salt(&self) -> &[u8] {
&self.salt
}
}
#[derive(Debug)]
@ -111,6 +121,18 @@ pub struct BackendKeyData {
pub(super) secret_key: i32,
}
impl BackendKeyData {
#[inline]
pub fn process_id(&self) -> i32 {
self.process_id
}
#[inline]
pub fn secret_key(&self) -> i32 {
self.secret_key
}
}
#[derive(Debug)]
pub struct CommandComplete {
pub(super) tag: Bytes,
@ -133,6 +155,18 @@ pub struct ParameterStatus {
pub(super) value: Bytes,
}
impl ParameterStatus {
#[inline]
pub fn name(&self) -> io::Result<&str> {
Ok(str::from_utf8(&self.name).map_err(|_| io::ErrorKind::InvalidInput)?)
}
#[inline]
pub fn value(&self) -> io::Result<&str> {
Ok(str::from_utf8(&self.value).map_err(|_| io::ErrorKind::InvalidInput)?)
}
}
#[derive(Debug)]
pub struct ReadyForQuery {
pub(super) status: u8,

View file

@ -1 +1,2 @@
pub use mason_core::ConnectOptions;
pub use mason_postgres as pg;

View file

@ -1,8 +1,15 @@
#![feature(async_await)]
use mason::{pg::Connection, ConnectOptions};
#[runtime::main]
async fn main() -> Result<(), failure::Error> {
env_logger::try_init()?;
let conn =
Connection::establish(ConnectOptions::new().user("postgres").password("password")).await?;
conn.close().await?;
Ok(())
}