Move to explicit prepare vs. implicit in execute builder

This commit is contained in:
Ryan Leckey 2019-07-17 12:43:46 -07:00
parent 97d6e7482c
commit 8589f0b907
13 changed files with 233 additions and 394 deletions

View file

@ -13,6 +13,7 @@ memchr = "2.2.0"
md-5 = "0.8.0"
itoa = "0.4.4"
hex = "0.3.2"
log = "0.4"
[dev-dependencies]
matches = "0.1.8"

View file

@ -2,6 +2,7 @@ use crate::Encode;
use byteorder::{BigEndian, ByteOrder};
use std::io;
#[derive(Debug)]
pub struct Bind<'a> {
// The name of the destination portal (an empty string selects the unnamed portal).
portal: &'a str,

View file

@ -10,16 +10,15 @@ pub struct CommandComplete {
impl CommandComplete {
pub fn tag(&self) -> &str {
let tag_end = memchr(b' ', &*self.tag).unwrap();
unsafe { str::from_utf8_unchecked(&self.tag[..tag_end]) }
unsafe { str::from_utf8_unchecked(self.tag.as_ref()) }
}
pub fn rows(&self) -> u64 {
let rows_start = memrchr(b' ', &*self.tag).unwrap();
let rows_start = memrchr(b' ', &*self.tag).map_or(0, |i| i + 1);
let rows_s =
unsafe { str::from_utf8_unchecked(&self.tag[(rows_start + 1)..(self.tag.len() - 1)]) };
unsafe { str::from_utf8_unchecked(&self.tag[rows_start..(self.tag.len() - 1)]) };
rows_s.parse().unwrap()
rows_s.parse().unwrap_or(0)
}
}
@ -36,16 +35,28 @@ mod tests {
use bytes::Bytes;
use std::io;
const COMMAND_COMPLETE: &[u8] = b"INSERT 0 512\0";
const COMMAND_COMPLETE_INSERT: &[u8] = b"INSERT 0 512\0";
const COMMAND_COMPLETE_CREATE_TABLE: &[u8] = b"CREATE TABLE\0";
#[test]
fn it_decodes_command_complete() -> io::Result<()> {
let src = Bytes::from_static(COMMAND_COMPLETE);
fn it_decodes_command_complete_for_insert() -> io::Result<()> {
let src = Bytes::from_static(COMMAND_COMPLETE_INSERT);
let message = CommandComplete::decode(src)?;
assert_eq!(message.tag(), "INSERT");
assert_eq!(message.tag(), "INSERT 0 512");
assert_eq!(message.rows(), 512);
Ok(())
}
#[test]
fn it_decodes_command_complete_for_create_table() -> io::Result<()> {
let src = Bytes::from_static(COMMAND_COMPLETE_INSERT);
let message = CommandComplete::decode(src)?;
assert_eq!(message.tag(), "CREATE TABLE");
assert_eq!(message.rows(), 0);
Ok(())
}
}

View file

@ -1,6 +1,7 @@
use crate::Encode;
use std::io;
#[derive(Debug)]
pub struct Execute<'a> {
portal: &'a str,
limit: i32,

View file

@ -50,7 +50,9 @@ impl Message {
let src = src.split_to(len + 1).freeze().slice_from(5);
Ok(Some(match token {
log::trace!("recv {:?}", src);
let message = match token {
b'N' | b'E' => Message::Response(Response::decode(src)?),
b'S' => Message::ParameterStatus(ParameterStatus::decode(src)?),
b'Z' => Message::ReadyForQuery(ReadyForQuery::decode(src)?),
@ -65,6 +67,10 @@ impl Message {
b'n' => Message::NoData,
_ => unimplemented!("decode not implemented for token: {}", token as char),
}))
};
log::trace!("decode {:?}", message);
Ok(Some(message))
}
}

View file

@ -52,7 +52,7 @@ impl Decode for NotificationResponse {
// offset from pid=4
let channel_name = get_str(&storage[4..])?;
// offset = pid + channel_name.len() + \0
let message = get_str(&storage[(4 + channel_name.len() + 1)..])?;
@ -87,4 +87,4 @@ mod tests {
assert_eq!(message.message(), "THIS IS A TEST");
Ok(())
}
}
}

View file

@ -1,6 +1,7 @@
use crate::Encode;
use std::io;
#[derive(Debug)]
pub struct Parse<'a> {
portal: &'a str,
query: &'a str,

View file

@ -1,6 +1,7 @@
use crate::Encode;
use std::io;
#[derive(Debug)]
pub struct Sync;
impl Encode for Sync {

View file

@ -1,5 +1,4 @@
use super::Connection;
use futures::StreamExt;
use sqlx_core::ConnectOptions;
use sqlx_postgres_protocol::{Authentication, Message, PasswordMessage, StartupMessage};
use std::io;
@ -31,10 +30,11 @@ pub async fn establish<'a, 'b: 'a>(
let message = StartupMessage::new(params);
conn.send(message).await?;
conn.send(message);
conn.flush().await?;
while let Some(message) = conn.stream.next().await {
match message? {
while let Some(message) = conn.receive().await? {
match message {
Message::Authentication(Authentication::Ok) => {
// Do nothing; server is just telling us that
// there is no password needed
@ -44,8 +44,8 @@ pub async fn establish<'a, 'b: 'a>(
// FIXME: Should error early (before send) if the user did not supply a password
conn.send(PasswordMessage::cleartext(
options.password.unwrap_or_default(),
))
.await?;
));
conn.flush().await?;
}
Message::Authentication(Authentication::Md5Password { salt }) => {
@ -54,8 +54,8 @@ pub async fn establish<'a, 'b: 'a>(
options.password.unwrap_or_default(),
options.user.unwrap_or_default(),
salt,
))
.await?;
));
conn.flush().await?;
}
Message::BackendKeyData(body) => {

View file

@ -1,291 +0,0 @@
use super::Connection;
use futures::{io::AsyncWrite, ready, Stream};
use sqlx_postgres_protocol::{self as proto, Encode, Parse};
use std::{
future::Future,
io,
pin::Pin,
sync::atomic::Ordering,
task::{Context, Poll},
};
// NOTE: This is a rough draft of the implementation
#[inline]
pub fn execute<'a>(connection: &'a mut Connection, query: &'a str) -> Execute<'a> {
Execute {
connection,
query,
state: ExecuteState::Parse,
rows: 0,
}
}
pub struct Execute<'a> {
connection: &'a mut Connection,
query: &'a str,
state: ExecuteState,
rows: u64,
}
#[derive(Debug)]
enum ExecuteState {
Parse,
Bind,
Execute,
Sync,
SendingParse,
SendingBind,
SendingExecute,
SendingSync,
Flush,
WaitForComplete,
}
impl<'a> Execute<'a> {
#[inline]
pub fn bind(self, value: &'a [u8]) -> Bind<'a, &'a [u8]> {
Bind { ex: self, value }
}
}
fn poll_write_all<W: AsyncWrite + Unpin>(
mut writer: W,
buf: &mut Vec<u8>,
cx: &mut Context,
) -> Poll<io::Result<()>> {
// Derived from https://rust-lang-nursery.github.io/futures-api-docs/0.3.0-alpha.16/src/futures_util/io/write_all.rs.html#26
while !buf.is_empty() {
let n = ready!(Pin::new(&mut writer).poll_write(cx, &*buf))?;
buf.truncate(buf.len() - n);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
}
}
Poll::Ready(Ok(()))
}
fn poll_execute<T: ToSql>(
cx: &mut Context,
conn: &mut Connection,
state: &mut ExecuteState,
query: &str,
values: &T,
out: &mut u64,
) -> Poll<io::Result<u64>> {
loop {
*state = match state {
ExecuteState::Parse => {
conn.wbuf.clear();
let stmt = format!(
"__sqlx#{}",
conn.statement_index.fetch_add(1, Ordering::SeqCst)
);
Parse::new(&stmt, query, &[])
.encode(&mut conn.wbuf)
.unwrap();
ExecuteState::SendingParse
}
ExecuteState::SendingParse => {
ready!(poll_write_all(&mut conn.stream.inner, &mut conn.wbuf, cx))?;
ExecuteState::Bind
}
ExecuteState::Bind => {
conn.wbuf.clear();
// FIXME: Think of a better way to build up a BIND message. Think on how to
// avoid allocation here.
let mut values_buf = Vec::new();
values_buf.extend_from_slice(&values.count().to_be_bytes());
values.to_sql(&mut values_buf);
// FIXME: We need to cache the statement name around
let stmt = format!("__sqlx#{}", conn.statement_index.load(Ordering::SeqCst) - 1);
proto::Bind::new(&stmt, &stmt, &[], &values_buf, &[])
.encode(&mut conn.wbuf)
.unwrap();
ExecuteState::SendingBind
}
ExecuteState::SendingBind => {
ready!(poll_write_all(&mut conn.stream.inner, &mut conn.wbuf, cx))?;
ExecuteState::Execute
}
ExecuteState::Execute => {
conn.wbuf.clear();
// FIXME: We need to cache the statement name around
let stmt = format!("__sqlx#{}", conn.statement_index.load(Ordering::SeqCst) - 1);
proto::Execute::new(&stmt, 0)
.encode(&mut conn.wbuf)
.unwrap();
ExecuteState::SendingExecute
}
ExecuteState::SendingExecute => {
ready!(poll_write_all(&mut conn.stream.inner, &mut conn.wbuf, cx))?;
ExecuteState::Sync
}
ExecuteState::Sync => {
conn.wbuf.clear();
proto::Sync.encode(&mut conn.wbuf).unwrap();
ExecuteState::SendingSync
}
ExecuteState::SendingSync => {
ready!(poll_write_all(&mut conn.stream.inner, &mut conn.wbuf, cx))?;
ExecuteState::Flush
}
ExecuteState::Flush => {
ready!(Pin::new(&mut conn.stream.inner).poll_flush(cx))?;
ExecuteState::WaitForComplete
}
ExecuteState::WaitForComplete => {
while let Some(message) = ready!(Pin::new(&mut conn.stream).poll_next(cx)) {
match message? {
proto::Message::BindComplete | proto::Message::ParseComplete => {
// Indicates successful completion of a phase
}
proto::Message::DataRow(_) => {
// This is EXECUTE so we are ignoring any potential results
}
proto::Message::CommandComplete(body) => {
*out = body.rows();
}
proto::Message::ReadyForQuery(_) => {
// Successful completion of the whole cycle
return Poll::Ready(Ok(*out));
}
message => {
unimplemented!("received {:?} unimplemented message", message);
}
}
}
// FIXME: This is technically reachable if the pg conn is dropped?
unreachable!()
}
}
}
}
impl<'a> Future for Execute<'a> {
type Output = io::Result<u64>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let self_ = self.get_mut();
poll_execute(
cx,
&mut *self_.connection,
&mut self_.state,
&self_.query,
&(),
&mut self_.rows,
)
}
}
// TODO: This should be cleaned up and moved to core; probably needs to be generic over back-end
// TODO: I'm using some trait recursion here.. this should probably not be exposed in core
pub trait ToSql {
/// Converts the value of `self` into the appropriate format, appending it to `out`.
fn to_sql(&self, out: &mut Vec<u8>);
// Count the number of value parameters recursively encoded.
fn count(&self) -> i16;
}
impl<'a> ToSql for () {
#[inline]
fn to_sql(&self, _out: &mut Vec<u8>) {
// Do nothing
}
#[inline]
fn count(&self) -> i16 {
0
}
}
impl<'a> ToSql for &'a [u8] {
#[inline]
fn to_sql(&self, out: &mut Vec<u8>) {
out.extend_from_slice(&(self.len() as i32).to_be_bytes());
out.extend_from_slice(self);
}
#[inline]
fn count(&self) -> i16 {
1
}
}
impl<'a, T: ToSql + 'a, U: ToSql + 'a> ToSql for (T, U) {
#[inline]
fn to_sql(&self, out: &mut Vec<u8>) {
self.0.to_sql(out);
self.1.to_sql(out);
}
#[inline]
fn count(&self) -> i16 {
self.0.count() + self.1.count()
}
}
pub struct Bind<'a, T: ToSql + 'a> {
ex: Execute<'a>,
value: T,
}
impl<'a, T: ToSql + 'a> Bind<'a, T> {
#[inline]
pub fn bind(self, value: &'a [u8]) -> Bind<'a, (T, &'a [u8])> {
Bind {
ex: self.ex,
value: (self.value, value),
}
}
}
impl<'a, T: Unpin + ToSql + 'a> Future for Bind<'a, T> {
type Output = io::Result<u64>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let self_ = self.get_mut();
poll_execute(
cx,
&mut *self_.ex.connection,
&mut self_.ex.state,
&self_.ex.query,
&self_.value,
&mut self_.ex.rows,
)
}
}

View file

@ -1,28 +1,34 @@
use bytes::{BufMut, BytesMut};
use futures::{
io::{AsyncRead, AsyncWriteExt},
io::{AsyncReadExt, AsyncWrite, AsyncWriteExt},
ready,
task::{Context, Poll},
Stream,
Future,
};
use runtime::net::TcpStream;
use sqlx_core::ConnectOptions;
use sqlx_postgres_protocol::{Encode, Message, Terminate};
use std::{fmt::Debug, io, pin::Pin, sync::atomic::AtomicU64};
use std::{fmt::Debug, io, pin::Pin};
mod establish;
mod execute;
mod prepare;
pub struct Connection {
pub(super) stream: Framed<TcpStream>,
pub(super) stream: TcpStream,
// HACK: This is how we currently "name" queries when executing
statement_index: AtomicU64,
// Do we think that there is data in the read buffer to be decoded
stream_readable: bool,
// Buffer used when serializing outgoing messages
// FIXME: Use BytesMut
// Have we reached end-of-file (been disconnected)
stream_eof: bool,
// Buffer used when sending outgoing messages
wbuf: Vec<u8>,
// Buffer used when reading incoming messages
// TODO: Evaluate if we _really_ want to use BytesMut here
rbuf: BytesMut,
// Process ID of the Backend
process_id: u32,
@ -35,10 +41,12 @@ impl Connection {
let stream = TcpStream::connect((options.host, options.port)).await?;
let mut conn = Self {
wbuf: Vec::with_capacity(1024),
stream: Framed::new(stream),
rbuf: BytesMut::with_capacity(1024 * 8),
stream,
stream_readable: false,
stream_eof: false,
process_id: 0,
secret_key: 0,
statement_index: AtomicU64::new(0),
};
establish::establish(&mut conn, options).await?;
@ -46,106 +54,131 @@ impl Connection {
Ok(conn)
}
#[inline]
pub fn execute<'a>(&'a mut self, query: &'a str) -> execute::Execute<'a> {
execute::execute(self, query)
pub fn prepare(&mut self, query: &str) -> prepare::Prepare {
prepare::prepare(self, query)
}
pub async fn close(mut self) -> io::Result<()> {
self.send(Terminate).await?;
self.stream.inner.close().await?;
self.send(Terminate);
self.flush().await?;
self.stream.close().await?;
Ok(())
}
async fn send<T>(&mut self, message: T) -> io::Result<()>
where
T: Encode + Debug,
{
self.wbuf.clear();
message.encode(&mut self.wbuf)?;
self.stream.inner.write_all(&self.wbuf).await?;
self.stream.inner.flush().await?;
Ok(())
}
}
pub(super) struct Framed<S> {
inner: S,
readable: bool,
eof: bool,
buffer: BytesMut,
}
impl<S> Framed<S> {
fn new(stream: S) -> Self {
Self {
readable: false,
eof: false,
inner: stream,
buffer: BytesMut::with_capacity(8 * 1024),
}
}
}
impl<S> Stream for Framed<S>
where
S: AsyncRead + Unpin,
{
type Item = io::Result<Message>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let self_ = Pin::get_mut(self);
// Wait and return the next message to be received from Postgres.
async fn receive(&mut self) -> io::Result<Option<Message>> {
loop {
if self_.readable {
if self_.eof {
return Poll::Ready(None);
}
if self.stream_eof {
// Reached end-of-file on a previous read call.
return Ok(None);
}
if self.stream_readable {
loop {
match Message::decode(&mut self_.buffer)? {
match Message::decode(&mut self.rbuf)? {
Some(Message::ParameterStatus(_body)) => {
// TODO: Not sure what to do with these but ignore
// TODO: not sure what to do with these yet
}
Some(Message::Response(_body)) => {
// TODO: Handle notices and errors
// TODO: Transform Errors+ into an error type and return
// TODO: Log all others
}
Some(message) => {
return Poll::Ready(Some(Ok(message)));
return Ok(Some(message));
}
None => {
self_.readable = false;
// Not enough data in the read buffer to parse a message
self.stream_readable = true;
break;
}
}
}
}
self_.buffer.reserve(32);
// Ensure there is at least 32-bytes of space available
// in the read buffer so we can safely detect end-of-file
self.rbuf.reserve(32);
let n = unsafe {
let b = self_.buffer.bytes_mut();
self_.inner.initializer().initialize(b);
// SAFE: Read data in directly to buffer without zero-initializing the data.
// Postgres is a self-describing format and the TCP frames encode
// length headers. We will never attempt to decode more than we
// received.
let n = self.stream.read(unsafe { self.rbuf.bytes_mut() }).await?;
let n = ready!(Pin::new(&mut self_.inner).poll_read(cx, b))?;
self_.buffer.advance_mut(n);
n
};
// SAFE: After we read in N bytes, we can tell the buffer that it actually
// has that many bytes for the decode routines to look at
unsafe { self.rbuf.set_len(n) }
if n == 0 {
self_.eof = true;
self.stream_eof = true;
}
self_.readable = true;
self.stream_readable = true;
}
}
fn send<T>(&mut self, message: T)
where
T: Encode + Debug,
{
log::trace!("encode {:?}", message);
// TODO: Encoding should not be fallible
message.encode(&mut self.wbuf).unwrap();
}
async fn flush(&mut self) -> io::Result<()> {
// TODO: Find some other way to print a Vec<u8> as an ASCII escaped string
log::trace!("send {:?}", bytes::Bytes::from(&*self.wbuf));
WriteAllVec::new(&mut self.stream, &mut self.wbuf).await?;
self.stream.flush().await?;
Ok(())
}
}
// Derived from: https://rust-lang-nursery.github.io/futures-api-docs/0.3.0-alpha.16/src/futures_util/io/write_all.rs.html#10-13
// With alterations to be more efficient if we're writing from a mutable vector
// that we can erase
// TODO: Move to Core under 'sqlx_core::io' perhaps?
// TODO: Perhaps the futures project wants this?
pub struct WriteAllVec<'a, W: ?Sized + Unpin> {
writer: &'a mut W,
buf: &'a mut Vec<u8>,
}
impl<W: ?Sized + Unpin> Unpin for WriteAllVec<'_, W> {}
impl<'a, W: AsyncWrite + ?Sized + Unpin> WriteAllVec<'a, W> {
pub(super) fn new(writer: &'a mut W, buf: &'a mut Vec<u8>) -> Self {
WriteAllVec { writer, buf }
}
}
impl<W: AsyncWrite + ?Sized + Unpin> Future for WriteAllVec<'_, W> {
type Output = io::Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = &mut *self;
while !this.buf.is_empty() {
let n = ready!(Pin::new(&mut this.writer).poll_write(cx, this.buf))?;
this.buf.truncate(this.buf.len() - n);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()));
}
}
Poll::Ready(Ok(()))
}
}

View file

@ -0,0 +1,63 @@
use super::Connection;
use sqlx_postgres_protocol::{Bind, Execute, Message, Parse, Sync};
use std::io;
pub struct Prepare<'a> {
connection: &'a mut Connection,
}
#[inline]
pub fn prepare<'a, 'b>(connection: &'a mut Connection, query: &'b str) -> Prepare<'a> {
// TODO: Use a hash map to cache the parse
// TODO: Use named statements
connection.send(Parse::new("", query, &[]));
Prepare { connection }
}
impl<'a> Prepare<'a> {
#[inline]
pub fn bind<'b>(self, value: &'b [u8]) -> Self {
// TODO: Encode the value here onto the wbuf
self
}
#[inline]
pub async fn execute(self) -> io::Result<u64> {
// FIXME: Break this up into BindHeader, BindValue, and BindTrailer
self.connection.send(Bind::new("", "", &[], &[0, 0], &[]));
self.connection.send(Execute::new("", 0));
self.connection.send(Sync);
self.connection.flush().await?;
let mut rows = 0;
while let Some(message) = self.connection.receive().await? {
match message {
Message::BindComplete | Message::ParseComplete => {
// Indicates successful completion of a phase
}
Message::DataRow(_) => {
// This is EXECUTE so we are ignoring any potential results
}
Message::CommandComplete(body) => {
rows = body.rows();
}
Message::ReadyForQuery(_) => {
// Successful completion of the whole cycle
return Ok(rows);
}
message => {
unimplemented!("received {:?} unimplemented message", message);
}
}
}
// FIXME: This is an end-of-file error. How we should bubble this up here?
unreachable!()
}
}

View file

@ -20,11 +20,23 @@ async fn main() -> io::Result<()> {
)
.await?;
conn.execute("INSERT INTO \"users\" (name) VALUES ($1)")
.bind(b"Joe")
conn.prepare(
r#"
CREATE TABLE IF NOT EXISTS users (
id BIGSERIAL PRIMARY KEY,
name TEXT NOT NULL
);
"#,
)
.execute()
.await?;
conn.prepare("INSERT INTO users (name) VALUES ('George')")
// .bind(b"Joe")
.execute()
.await?;
let count = conn.execute("SELECT name FROM users").await?;
let count = conn.prepare("SELECT name FROM users").execute().await?;
println!("users: {}", count);
conn.close().await?;