mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 06:24:16 +00:00
Establish a solid foundation with mason-postgres
This commit is contained in:
parent
1e983a2a8b
commit
fda6d9ef81
11 changed files with 335 additions and 179 deletions
|
@ -1,7 +1,8 @@
|
|||
[workspace]
|
||||
members = [
|
||||
".",
|
||||
"mason-postgres"
|
||||
"mason-core",
|
||||
"mason-postgres",
|
||||
]
|
||||
|
||||
[package]
|
||||
|
@ -14,6 +15,7 @@ edition = "2018"
|
|||
|
||||
[dependencies]
|
||||
runtime = "=0.3.0-alpha.4"
|
||||
mason-core = { path = "mason-core" }
|
||||
mason-postgres = { path = "mason-postgres" }
|
||||
failure = "0.1"
|
||||
env_logger = "0.6.1"
|
||||
|
|
9
mason-core/Cargo.toml
Normal file
9
mason-core/Cargo.toml
Normal file
|
@ -0,0 +1,9 @@
|
|||
[package]
|
||||
name = "mason-core"
|
||||
version = "0.0.0"
|
||||
authors = ["Ryan Leckey <leckey.ryan@gmail.com>"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
description = "Shared types and traits for Mason."
|
||||
edition = "2018"
|
||||
|
||||
[dependencies]
|
39
mason-core/src/connection.rs
Normal file
39
mason-core/src/connection.rs
Normal file
|
@ -0,0 +1,39 @@
|
|||
pub struct ConnectOptions<'a> {
|
||||
pub host: &'a str,
|
||||
pub port: u16,
|
||||
pub user: Option<&'a str>,
|
||||
pub database: Option<&'a str>,
|
||||
pub password: Option<&'a str>,
|
||||
}
|
||||
|
||||
impl<'a> Default for ConnectOptions<'a> {
|
||||
#[inline]
|
||||
fn default() -> Self {
|
||||
Self { host: "localhost", port: 5432, user: None, database: None, password: None }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ConnectOptions<'a> {
|
||||
#[inline]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn user(mut self, user: &'a str) -> Self {
|
||||
self.user = Some(user);
|
||||
self
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn database(mut self, database: &'a str) -> Self {
|
||||
self.database = Some(database);
|
||||
self
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn password(mut self, password: &'a str) -> Self {
|
||||
self.password = Some(password);
|
||||
self
|
||||
}
|
||||
}
|
3
mason-core/src/lib.rs
Normal file
3
mason-core/src/lib.rs
Normal file
|
@ -0,0 +1,3 @@
|
|||
pub use connection::ConnectOptions;
|
||||
|
||||
mod connection;
|
|
@ -7,6 +7,7 @@ description = "PostgreSQL database driver for dbx."
|
|||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
mason-core = { path = "../mason-core" }
|
||||
runtime = "=0.3.0-alpha.4"
|
||||
futures-preview = "=0.3.0-alpha.16"
|
||||
failure = "0.1"
|
||||
|
@ -15,3 +16,4 @@ log = "0.4"
|
|||
hex = "0.3.2"
|
||||
bytes = "0.4.12"
|
||||
memchr = "2.2.0"
|
||||
md-5 = "0.8.0"
|
||||
|
|
203
mason-postgres/src/connection.rs
Normal file
203
mason-postgres/src/connection.rs
Normal file
|
@ -0,0 +1,203 @@
|
|||
use crate::protocol::{
|
||||
client::{PasswordMessage, Serialize, Terminate, StartupMessage},
|
||||
server::Message as ServerMessage,
|
||||
};
|
||||
use bytes::BytesMut;
|
||||
use futures::{
|
||||
channel::mpsc,
|
||||
io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf},
|
||||
SinkExt, StreamExt,
|
||||
};
|
||||
use mason_core::ConnectOptions;
|
||||
use md5::{Digest, Md5};
|
||||
use runtime::{net::TcpStream, task::JoinHandle};
|
||||
use std::{io};
|
||||
|
||||
pub struct Connection {
|
||||
writer: WriteHalf<TcpStream>,
|
||||
incoming: mpsc::UnboundedReceiver<ServerMessage>,
|
||||
|
||||
// Buffer used when serializing outgoing messages
|
||||
wbuf: Vec<u8>,
|
||||
|
||||
// Handle to coroutine reading messages from the stream
|
||||
receiver: JoinHandle<io::Result<()>>,
|
||||
|
||||
// Process ID of the Backend
|
||||
process_id: i32,
|
||||
|
||||
// Backend-unique key to use to send a cancel query message to the server
|
||||
secret_key: i32,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub async fn establish(options: ConnectOptions<'_>) -> io::Result<Self> {
|
||||
let stream = TcpStream::connect((options.host, options.port)).await?;
|
||||
let (reader, writer) = stream.split();
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
let receiver = runtime::spawn(receiver(reader, tx));
|
||||
let mut conn = Self {
|
||||
wbuf: Vec::with_capacity(1024),
|
||||
writer,
|
||||
receiver,
|
||||
incoming: rx,
|
||||
process_id: -1,
|
||||
secret_key: -1,
|
||||
};
|
||||
|
||||
// See this doc for more runtime parameters
|
||||
// https://www.postgresql.org/docs/12/runtime-config-client.html
|
||||
let params = [
|
||||
("user", options.user),
|
||||
("database", options.database),
|
||||
// TODO: Expose this property perhaps?
|
||||
(
|
||||
"application_name",
|
||||
Some(concat!(env!("CARGO_PKG_NAME"), "/v", env!("CARGO_PKG_VERSION"))),
|
||||
),
|
||||
// Sets the display format for date and time values,
|
||||
// as well as the rules for interpreting ambiguous date input values.
|
||||
("DateStyle", Some("ISO, MDY")),
|
||||
// Sets the display format for interval values.
|
||||
("IntervalStyle", Some("iso_8601")),
|
||||
// Sets the time zone for displaying and interpreting time stamps.
|
||||
("TimeZone", Some("UTC")),
|
||||
// Adjust postgres to return percise values for floats
|
||||
// NOTE: This is default in postgres 12+
|
||||
("extra_float_digits", Some("3")),
|
||||
// Sets the client-side encoding (character set).
|
||||
("client_encoding", Some("UTF-8")),
|
||||
];
|
||||
|
||||
conn.send(StartupMessage { params: ¶ms }).await?;
|
||||
|
||||
while let Some(message) = conn.incoming.next().await {
|
||||
match message {
|
||||
ServerMessage::AuthenticationOk => {
|
||||
// Do nothing; server is just telling us that
|
||||
// there is no password needed
|
||||
}
|
||||
|
||||
ServerMessage::AuthenticationMd5Password(body) => {
|
||||
// Hash password|username
|
||||
// FIXME: ConnectOptions should prepare a default user
|
||||
let pass_user =
|
||||
md5(options.password.unwrap_or_default(), options.user.unwrap_or_default());
|
||||
|
||||
let with_salt = md5(pass_user, body.salt());
|
||||
let password = format!("md5{}", with_salt);
|
||||
|
||||
conn.send(PasswordMessage { password: &password }).await?;
|
||||
}
|
||||
|
||||
ServerMessage::BackendKeyData(body) => {
|
||||
conn.process_id = body.process_id();
|
||||
conn.secret_key = body.secret_key();
|
||||
}
|
||||
|
||||
ServerMessage::ReadyForQuery(_) => {
|
||||
// Good to go
|
||||
break;
|
||||
}
|
||||
|
||||
_ => {
|
||||
unimplemented!("received {:?} unimplemented message", message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
pub async fn close(mut self) -> io::Result<()> {
|
||||
self.send(Terminate).await?;
|
||||
self.writer.close().await?;
|
||||
self.receiver.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Send client-serializable message to the server
|
||||
async fn send<S>(&mut self, message: S) -> io::Result<()>
|
||||
where
|
||||
S: Serialize,
|
||||
{
|
||||
self.wbuf.clear();
|
||||
|
||||
message.serialize(&mut self.wbuf);
|
||||
|
||||
self.writer.write_all(&self.wbuf).await?;
|
||||
self.writer.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn receiver(
|
||||
mut reader: ReadHalf<TcpStream>,
|
||||
mut sender: mpsc::UnboundedSender<ServerMessage>,
|
||||
) -> io::Result<()> {
|
||||
let mut rbuf = BytesMut::with_capacity(0);
|
||||
let mut len = 0;
|
||||
|
||||
loop {
|
||||
// This uses an adaptive system to extend the vector when it fills. We want to
|
||||
// avoid paying to allocate and zero a huge chunk of memory if the reader only
|
||||
// has 4 bytes while still making large reads if the reader does have a ton
|
||||
// of data to return.
|
||||
|
||||
// See: https://github.com/rust-lang-nursery/futures-rs/blob/master/futures-util/src/io/read_to_end.rs#L50-L54
|
||||
|
||||
if len == rbuf.len() {
|
||||
rbuf.reserve(32);
|
||||
|
||||
unsafe {
|
||||
// Set length to the capacity and efficiently
|
||||
// zero-out the memory
|
||||
rbuf.set_len(rbuf.capacity());
|
||||
reader.initializer().initialize(&mut rbuf[len..]);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Need a select! on a channel that I can trigger to cancel this
|
||||
let cnt = reader.read(&mut rbuf[len..]).await?;
|
||||
|
||||
if cnt > 0 {
|
||||
len += cnt;
|
||||
} else {
|
||||
// Read 0 bytes from the server; end-of-stream
|
||||
break;
|
||||
}
|
||||
|
||||
while len > 0 {
|
||||
let size = rbuf.len();
|
||||
let message = ServerMessage::deserialize(&mut rbuf)?;
|
||||
len -= size - rbuf.len();
|
||||
|
||||
// TODO: Some messages should be kept behind here
|
||||
match message {
|
||||
Some(ServerMessage::ParameterStatus(body)) => {
|
||||
log::debug!("parameter {} = {}", body.name()?, body.value()?);
|
||||
}
|
||||
|
||||
Some(message) => {
|
||||
// TODO: Handle this error?
|
||||
sender.send(message).await.unwrap();
|
||||
}
|
||||
|
||||
None => {
|
||||
// Did not receive enough bytes to
|
||||
// deserialize a complete message
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn md5(a: impl AsRef<[u8]>, b: impl AsRef<[u8]>) -> String {
|
||||
hex::encode(Md5::new().chain(a).chain(b).result())
|
||||
}
|
|
@ -1,143 +1,7 @@
|
|||
#![feature(non_exhaustive, async_await)]
|
||||
#![allow(clippy::needless_lifetimes)]
|
||||
|
||||
use crate::protocol::{client, server};
|
||||
use bytes::BytesMut;
|
||||
use futures::{
|
||||
channel::mpsc,
|
||||
io::{AsyncRead, AsyncReadExt, AsyncWriteExt, WriteHalf},
|
||||
SinkExt, StreamExt,
|
||||
};
|
||||
use runtime::{net::TcpStream, task::JoinHandle};
|
||||
use std::io;
|
||||
mod connection;
|
||||
mod protocol;
|
||||
|
||||
pub mod protocol;
|
||||
|
||||
pub struct Connection {
|
||||
buf: Vec<u8>,
|
||||
writer: WriteHalf<TcpStream>,
|
||||
incoming: mpsc::Receiver<server::Message>,
|
||||
reader: Option<JoinHandle<io::Result<()>>>,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub async fn open(address: &str) -> io::Result<Self> {
|
||||
let stream = TcpStream::connect(address).await?;
|
||||
let (mut reader, writer) = stream.split();
|
||||
|
||||
// FIXME: What's a good buffer size here?
|
||||
let (mut tx, rx) = mpsc::channel(1024);
|
||||
|
||||
let reader = runtime::spawn(async move {
|
||||
let mut buf = BytesMut::with_capacity(0);
|
||||
let mut len = 0;
|
||||
|
||||
'reader: loop {
|
||||
if len == buf.len() {
|
||||
unsafe {
|
||||
buf.reserve(32);
|
||||
buf.set_len(buf.capacity());
|
||||
reader.initializer().initialize(&mut buf[len..]);
|
||||
}
|
||||
}
|
||||
|
||||
let num = reader.read(&mut buf[len..]).await?;
|
||||
if num > 0 {
|
||||
len += num;
|
||||
}
|
||||
|
||||
while len > 0 && !buf.is_empty() {
|
||||
let size = buf.len();
|
||||
let msg = server::Message::deserialize(&mut buf)?;
|
||||
|
||||
let removed = size - buf.len();
|
||||
len -= removed;
|
||||
|
||||
match msg {
|
||||
Some(server::Message::ParameterStatus(body)) => {
|
||||
// FIXME: Proper log
|
||||
log::info!("{:?}", body);
|
||||
}
|
||||
|
||||
Some(msg) => {
|
||||
tx.send(msg).await.unwrap();
|
||||
}
|
||||
|
||||
None => {
|
||||
// We have _some_ amount of data but not enough to
|
||||
// deserialize anything
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: This probably doesn't make sense
|
||||
if len == 0 && !buf.is_empty() {
|
||||
// Hit end-of-stream
|
||||
break 'reader;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
// FIXME: What's a good buffer size here?
|
||||
buf: Vec::with_capacity(1024),
|
||||
writer,
|
||||
reader: Some(reader),
|
||||
incoming: rx,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn startup<'a, 'b: 'a>(
|
||||
&'a mut self,
|
||||
user: &'b str,
|
||||
_password: &'b str,
|
||||
database: &'b str,
|
||||
) -> io::Result<()> {
|
||||
let msg = client::StartupMessage { user, database: Some(database) };
|
||||
msg.serialize(&mut self.buf);
|
||||
|
||||
self.writer.write_all(&self.buf).await?;
|
||||
self.buf.clear();
|
||||
|
||||
self.writer.flush().await?;
|
||||
|
||||
// FIXME: We _actually_ want to wait until ReadyForQuery, ErrorResponse, AuthX, etc.
|
||||
|
||||
while let Some(message) = self.incoming.next().await {
|
||||
match message {
|
||||
server::Message::AuthenticationOk => {
|
||||
// Do nothing; server is just telling us "you're in"
|
||||
}
|
||||
|
||||
server::Message::ReadyForQuery(_) => {
|
||||
// Good to go
|
||||
break;
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn terminate(&mut self) -> io::Result<()> {
|
||||
let msg = client::Terminate {};
|
||||
msg.serialize(&mut self.buf);
|
||||
|
||||
self.writer.write_all(&self.buf).await?;
|
||||
self.buf.clear();
|
||||
|
||||
self.writer.flush().await?;
|
||||
self.writer.close().await?;
|
||||
|
||||
if let Some(reader) = self.reader.take() {
|
||||
reader.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
pub use connection::Connection;
|
||||
|
|
|
@ -4,11 +4,15 @@ use byteorder::{BigEndian, ByteOrder};
|
|||
// https://www.postgresql.org/docs/devel/protocol-message-formats.html
|
||||
// https://www.postgresql.org/docs/devel/protocol-message-types.html
|
||||
|
||||
pub trait Serialize {
|
||||
fn serialize(&self, buf: &mut Vec<u8>);
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Terminate;
|
||||
|
||||
impl Terminate {
|
||||
pub fn serialize(&self, buf: &mut Vec<u8>) {
|
||||
impl Serialize for Terminate {
|
||||
fn serialize(&self, buf: &mut Vec<u8>) {
|
||||
buf.push(b'X');
|
||||
buf.push(4);
|
||||
}
|
||||
|
@ -16,56 +20,44 @@ impl Terminate {
|
|||
|
||||
#[derive(Debug)]
|
||||
pub struct StartupMessage<'a> {
|
||||
pub user: &'a str,
|
||||
pub database: Option<&'a str>,
|
||||
pub params: &'a [(&'a str, Option<&'a str>)],
|
||||
}
|
||||
|
||||
impl<'a> StartupMessage<'a> {
|
||||
pub fn serialize(&self, buf: &mut Vec<u8>) {
|
||||
impl<'a> Serialize for StartupMessage<'a> {
|
||||
fn serialize(&self, buf: &mut Vec<u8>) {
|
||||
with_length_prefix(buf, |buf| {
|
||||
// protocol version: major = 3, minor = 0
|
||||
buf.extend_from_slice(&0x0003_i16.to_be_bytes());
|
||||
buf.extend_from_slice(&0x0000_i16.to_be_bytes());
|
||||
|
||||
write_str(buf, "user");
|
||||
write_str(buf, self.user);
|
||||
|
||||
if let Some(database) = self.database {
|
||||
write_str(buf, "database");
|
||||
write_str(buf, database);
|
||||
for (name, value) in self.params {
|
||||
if let Some(value) = value {
|
||||
write_str(buf, name);
|
||||
write_str(buf, value);
|
||||
}
|
||||
}
|
||||
|
||||
// See this doc for more runtime parameters
|
||||
// https://www.postgresql.org/docs/12/runtime-config-client.html
|
||||
|
||||
// Sets the display format for date and time values,
|
||||
// as well as the rules for interpreting ambiguous date input values.
|
||||
write_str(buf, "DateStyle");
|
||||
write_str(buf, "ISO");
|
||||
|
||||
// Sets the display format for interval values.
|
||||
write_str(buf, "IntervalStyle");
|
||||
write_str(buf, "iso_8601");
|
||||
|
||||
// Sets the time zone for displaying and interpreting time stamps.
|
||||
write_str(buf, "TimeZone");
|
||||
write_str(buf, "UTC");
|
||||
|
||||
// Adjust postgres to return percise values for floats
|
||||
// NOTE: This is default in postgres 12+
|
||||
write_str(buf, "extra_float_digits");
|
||||
write_str(buf, "3");
|
||||
|
||||
// Sets the client-side encoding (character set).
|
||||
write_str(buf, "client_encoding");
|
||||
write_str(buf, "UTF-8");
|
||||
|
||||
// A zero byte is required as a terminator after the last name/value pair.
|
||||
buf.push(0);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PasswordMessage<'a> {
|
||||
pub password: &'a str,
|
||||
}
|
||||
|
||||
impl<'a> Serialize for PasswordMessage<'a> {
|
||||
fn serialize(&self, buf: &mut Vec<u8>) {
|
||||
buf.push(b'p');
|
||||
|
||||
with_length_prefix(buf, |buf| {
|
||||
write_str(buf, self.password);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Write a string followed by a null-terminator
|
||||
#[inline]
|
||||
fn write_str(buf: &mut Vec<u8>, s: &str) {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use byteorder::{BigEndian, ByteOrder};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use std::io;
|
||||
use std::{io, str};
|
||||
|
||||
// Reference
|
||||
// https://www.postgresql.org/docs/devel/protocol-message-formats.html
|
||||
|
@ -76,6 +76,9 @@ impl Message {
|
|||
|
||||
b'R' => match BigEndian::read_i32(&buf[idx..]) {
|
||||
0 => Message::AuthenticationOk,
|
||||
5 => Message::AuthenticationMd5Password(AuthenticationMd5Password {
|
||||
salt: buf.slice_from(idx + 4),
|
||||
}),
|
||||
|
||||
code => {
|
||||
unimplemented!("unknown response code received: {:x}", code);
|
||||
|
@ -96,7 +99,14 @@ impl Message {
|
|||
|
||||
#[derive(Debug)]
|
||||
pub struct AuthenticationMd5Password {
|
||||
pub(super) salt: [u8; 4],
|
||||
pub(super) salt: Bytes,
|
||||
}
|
||||
|
||||
impl AuthenticationMd5Password {
|
||||
#[inline]
|
||||
pub fn salt(&self) -> &[u8] {
|
||||
&self.salt
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -111,6 +121,18 @@ pub struct BackendKeyData {
|
|||
pub(super) secret_key: i32,
|
||||
}
|
||||
|
||||
impl BackendKeyData {
|
||||
#[inline]
|
||||
pub fn process_id(&self) -> i32 {
|
||||
self.process_id
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn secret_key(&self) -> i32 {
|
||||
self.secret_key
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CommandComplete {
|
||||
pub(super) tag: Bytes,
|
||||
|
@ -133,6 +155,18 @@ pub struct ParameterStatus {
|
|||
pub(super) value: Bytes,
|
||||
}
|
||||
|
||||
impl ParameterStatus {
|
||||
#[inline]
|
||||
pub fn name(&self) -> io::Result<&str> {
|
||||
Ok(str::from_utf8(&self.name).map_err(|_| io::ErrorKind::InvalidInput)?)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn value(&self) -> io::Result<&str> {
|
||||
Ok(str::from_utf8(&self.value).map_err(|_| io::ErrorKind::InvalidInput)?)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ReadyForQuery {
|
||||
pub(super) status: u8,
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
pub use mason_core::ConnectOptions;
|
||||
pub use mason_postgres as pg;
|
||||
|
|
|
@ -1,8 +1,15 @@
|
|||
#![feature(async_await)]
|
||||
|
||||
use mason::{pg::Connection, ConnectOptions};
|
||||
|
||||
#[runtime::main]
|
||||
async fn main() -> Result<(), failure::Error> {
|
||||
env_logger::try_init()?;
|
||||
|
||||
let conn =
|
||||
Connection::establish(ConnectOptions::new().user("postgres").password("password")).await?;
|
||||
|
||||
conn.close().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue