Add a Transaction type to simplify dealing with Transactions

This commit is contained in:
Ryan Leckey 2020-01-03 22:42:10 -08:00
parent 28ed854b03
commit b1a27ddac2
10 changed files with 272 additions and 41 deletions

View file

@ -50,7 +50,9 @@ async fn register(mut req: Request<PgPool>) -> Response {
let body: RegisterRequestBody = req.body_json().await.unwrap();
let hash = hash_password(&body.password).unwrap();
let mut pool = req.state();
// Make a new transaction
let pool = req.state();
let mut tx = pool.begin().await.unwrap();
let rec = sqlx::query!(
r#"
@ -62,12 +64,15 @@ RETURNING id, username, email
body.email,
hash,
)
.fetch_one(&mut pool)
.fetch_one(&mut tx)
.await
.unwrap();
let token = generate_token(rec.id).unwrap();
// Explicitly commit
tx.commit().await.unwrap();
#[derive(serde::Serialize)]
struct RegisterResponseBody {
user: User,

View file

@ -16,6 +16,7 @@ mod database;
mod executor;
mod query;
mod query_as;
mod transaction;
mod url;
#[macro_use]
@ -47,6 +48,7 @@ pub use connection::{Connect, Connection};
pub use executor::Executor;
pub use query::{query, Query};
pub use query_as::{query_as, QueryAs};
pub use transaction::Transaction;
#[doc(hidden)]
pub use query_as::query_as_mapped;

View file

@ -7,7 +7,7 @@ use futures_core::future::BoxFuture;
use sha1::Sha1;
use crate::cache::StatementCache;
use crate::connection::Connection;
use crate::connection::{Connect, Connection};
use crate::io::{Buf, BufMut, BufStream, MaybeTlsStream};
use crate::mysql::error::MySqlError;
use crate::mysql::protocol::{
@ -475,7 +475,7 @@ impl MySqlConnection {
}
impl MySqlConnection {
pub(super) async fn open(url: crate::Result<Url>) -> crate::Result<Self> {
pub(super) async fn establish(url: crate::Result<Url>) -> crate::Result<Self> {
let url = url?;
let mut self_ = Self::new(&url).await?;
@ -598,19 +598,19 @@ impl MySqlConnection {
T: TryInto<Url, Error = crate::Error>,
Self: Sized,
{
Box::pin(MySqlConnection::open(url.try_into()))
Box::pin(MySqlConnection::establish(url.try_into()))
}
}
impl Connect for MySqlConnection {
type Connection = MySqlConnection;
fn connect<T>(url: T) -> BoxFuture<'static, Result<MySqlConnection>>
fn connect<T>(url: T) -> BoxFuture<'static, crate::Result<MySqlConnection>>
where
T: TryInto<Url, Error = crate::Error>,
Self: Sized,
{
Box::pin(PgConnection::open(url.try_into()))
Box::pin(MySqlConnection::establish(url.try_into()))
}
}

View file

@ -26,16 +26,3 @@ pub use row::MySqlRow;
/// An alias for [`Pool`], specialized for **MySQL**.
pub type MySqlPool = super::Pool<MySql>;
use std::convert::TryInto;
use crate::url::Url;
// used in tests and hidden code in examples
#[doc(hidden)]
pub async fn connect<T>(url: T) -> crate::Result<MySqlConnection>
where
T: TryInto<Url, Error = crate::Error>,
{
MySqlConnection::open(url.try_into()).await
}

View file

@ -1,3 +1,5 @@
use std::ops::DerefMut;
use futures_core::{future::BoxFuture, stream::BoxStream};
use futures_util::StreamExt;
@ -9,6 +11,8 @@ use crate::{
Database,
};
use super::PoolConnection;
impl<C> Executor for Pool<C>
where
C: Connection + Connect<Connection = C>,
@ -108,3 +112,45 @@ where
Box::pin(async move { self.acquire().await?.describe(query).await })
}
}
impl<C> Executor for PoolConnection<C>
where
C: Connection + Connect<Connection = C>,
{
type Database = <C as Executor>::Database;
fn send<'e, 'q: 'e>(&'e mut self, commands: &'q str) -> BoxFuture<'e, crate::Result<()>> {
self.deref_mut().send(commands)
}
fn execute<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
args: <<C as Executor>::Database as Database>::Arguments,
) -> BoxFuture<'e, crate::Result<u64>> {
self.deref_mut().execute(query, args)
}
fn fetch<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
args: <<C as Executor>::Database as Database>::Arguments,
) -> BoxStream<'e, crate::Result<<<C as Executor>::Database as Database>::Row>> {
self.deref_mut().fetch(query, args)
}
fn fetch_optional<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
args: <<C as Executor>::Database as Database>::Arguments,
) -> BoxFuture<'e, crate::Result<Option<<<C as Executor>::Database as Database>::Row>>> {
self.deref_mut().fetch_optional(query, args)
}
fn describe<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>> {
self.deref_mut().describe(query)
}
}

View file

@ -2,21 +2,26 @@
use std::{
fmt,
mem,
ops::{Deref, DerefMut},
sync::Arc,
time::{Duration, Instant},
};
use futures_core::future::BoxFuture;
use crate::connection::{Connect, Connection};
use crate::transaction::Transaction;
use self::inner::SharedPool;
pub use self::options::Builder;
use self::options::Options;
mod executor;
mod inner;
mod options;
pub use self::options::Builder;
/// A pool of database connections.
pub struct Pool<C>(Arc<SharedPool<C>>);
@ -84,6 +89,11 @@ where
})
}
/// Retrieves a new connection and immediately begins a new transaction.
pub async fn begin(&self) -> crate::Result<Transaction<PoolConnection<C>>> {
Ok(Transaction::new(0, self.acquire().await?).await?)
}
/// Ends the use of a connection pool. Prevents any new connections
/// and will close all active connections when they are returned to the pool.
///
@ -172,6 +182,27 @@ where
}
}
impl<C> Connection for PoolConnection<C>
where
C: Connection + Connect<Connection = C>,
{
fn close(mut self) -> BoxFuture<'static, crate::Result<()>> {
Box::pin(async move {
if let Some(live) = self.live.take() {
let raw = live.raw;
// Explicitly close the connection
raw.close().await?;
}
// Forget ourself so it does not go back to the pool
mem::forget(self);
Ok(())
})
}
}
impl<C> Drop for PoolConnection<C>
where
C: Connection + Connect<Connection = C>,

View file

@ -8,7 +8,7 @@ use rand::Rng;
use sha2::{Digest, Sha256};
use crate::cache::StatementCache;
use crate::connection::Connection;
use crate::connection::{Connect, Connection};
use crate::io::{Buf, BufStream, MaybeTlsStream};
use crate::postgres::protocol::{
self, hi, Authentication, Decode, Encode, Message, SaslInitialResponse, SaslResponse,
@ -334,7 +334,7 @@ impl PgConnection {
}
impl PgConnection {
pub(super) async fn open(url: Result<Url>) -> Result<Self> {
pub(super) async fn establish(url: Result<Url>) -> Result<Self> {
let url = url?;
let stream = MaybeTlsStream::connect(&url, 5432).await?;
@ -402,7 +402,7 @@ impl PgConnection {
T: TryInto<Url, Error = crate::Error>,
Self: Sized,
{
Box::pin(PgConnection::open(url.try_into()))
Box::pin(PgConnection::establish(url.try_into()))
}
}
@ -414,7 +414,7 @@ impl Connect for PgConnection {
T: TryInto<Url, Error = crate::Error>,
Self: Sized,
{
Box::pin(PgConnection::open(url.try_into()))
Box::pin(PgConnection::establish(url.try_into()))
}
}

View file

@ -18,16 +18,3 @@ mod types;
/// An alias for [`Pool`], specialized for **Postgres**.
pub type PgPool = super::Pool<Postgres>;
use std::convert::TryInto;
use crate::url::Url;
// used in tests and hidden code in examples
#[doc(hidden)]
pub async fn connect<T>(url: T) -> crate::Result<PgConnection>
where
T: TryInto<Url, Error = crate::Error>,
{
PgConnection::open(url.try_into()).await
}

View file

@ -0,0 +1,173 @@
use std::ops::{Deref, DerefMut};
use async_std::task;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use crate::database::Database;
use crate::describe::Describe;
use crate::executor::Executor;
use crate::connection::Connection;
pub struct Transaction<T>
where
T: Connection + Send + 'static,
{
inner: Option<T>,
depth: u32,
}
impl<T> Transaction<T>
where
T: Connection + Send + 'static,
{
pub(crate) async fn new(depth: u32, mut inner: T) -> crate::Result<Self> {
if depth == 0 {
inner.send("BEGIN").await?;
} else {
inner
.send(&format!("SAVEPOINT _sqlx_savepoint_{}", depth))
.await?;
}
Ok(Self {
inner: Some(inner),
depth: depth + 1,
})
}
pub async fn begin(mut self) -> crate::Result<Transaction<T>> {
Transaction::new(self.depth, self.inner.take().expect(ERR_FINALIZED)).await
}
pub async fn commit(mut self) -> crate::Result<T> {
let mut inner = self.inner.take().expect(ERR_FINALIZED);
let depth = self.depth;
if depth == 1 {
inner.send("COMMIT").await?;
} else {
inner
.send(&format!("RELEASE SAVEPOINT _sqlx_savepoint_{}", depth - 1))
.await?;
}
Ok(inner)
}
pub async fn rollback(mut self) -> crate::Result<T> {
let mut inner = self.inner.take().expect(ERR_FINALIZED);
let depth = self.depth;
if depth == 1 {
inner.send("ROLLBACK").await?;
} else {
inner
.send(&format!(
"ROLLBACK TO SAVEPOINT _sqlx_savepoint_{}",
depth - 1
))
.await?;
}
Ok(inner)
}
}
const ERR_FINALIZED: &str = "(bug) transaction already finalized";
impl<Conn> Deref for Transaction<Conn>
where
Conn: Connection,
{
type Target = Conn;
fn deref(&self) -> &Self::Target {
self.inner.as_ref().expect(ERR_FINALIZED)
}
}
impl<Conn> DerefMut for Transaction<Conn>
where
Conn: Connection,
{
fn deref_mut(&mut self) -> &mut Self::Target {
self.inner.as_mut().expect(ERR_FINALIZED)
}
}
impl<T> Connection for Transaction<T>
where
T: Connection
{
// Close is equivalent to ROLLBACK followed by CLOSE
fn close(self) -> BoxFuture<'static, crate::Result<()>> {
Box::pin(async move {
self.rollback().await?.close().await
})
}
}
impl<T> Executor for Transaction<T>
where
T: Connection,
{
type Database = T::Database;
fn send<'e, 'q: 'e>(&'e mut self, commands: &'q str) -> BoxFuture<'e, crate::Result<()>> {
self.deref_mut().send(commands)
}
fn execute<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
args: <Self::Database as Database>::Arguments,
) -> BoxFuture<'e, crate::Result<u64>> {
self.deref_mut().execute(query, args)
}
fn fetch<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
args: <Self::Database as Database>::Arguments,
) -> BoxStream<'e, crate::Result<<Self::Database as Database>::Row>> {
self.deref_mut().fetch(query, args)
}
fn fetch_optional<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
args: <Self::Database as Database>::Arguments,
) -> BoxFuture<'e, crate::Result<Option<<Self::Database as Database>::Row>>> {
self.deref_mut().fetch_optional(query, args)
}
fn describe<'e, 'q: 'e>(
&'e mut self,
query: &'q str,
) -> BoxFuture<'e, crate::Result<Describe<Self::Database>>> {
self.deref_mut().describe(query)
}
}
impl<Conn> Drop for Transaction<Conn>
where
Conn: Connection,
{
fn drop(&mut self) {
if self.depth > 0 {
if let Some(mut inner) = self.inner.take() {
task::spawn(async move {
let res = inner.send("ROLLBACK").await;
// If the rollback failed we need to close the inner connection
if res.is_err() {
// This will explicitly forget the connection so it will not
// return to the pool
let _ = inner.close().await;
}
});
}
}
}
}

View file

@ -29,14 +29,14 @@ use query_macros::*;
macro_rules! async_macro (
($db:ident => $expr:expr) => {{
let res: Result<proc_macro2::TokenStream> = task::block_on(async {
use sqlx::Connection;
use sqlx::Connect;
let db_url = Url::parse(&dotenv::var("DATABASE_URL").map_err(|_| "DATABASE_URL not set")?)?;
match db_url.scheme() {
#[cfg(feature = "postgres")]
"postgresql" | "postgres" => {
let $db = sqlx::postgres::PgConnection::open(db_url.as_str())
let $db = sqlx::postgres::PgConnection::connect(db_url.as_str())
.await
.map_err(|e| format!("failed to connect to database: {}", e))?;
@ -50,7 +50,7 @@ macro_rules! async_macro (
).into()),
#[cfg(feature = "mysql")]
"mysql" | "mariadb" => {
let $db = sqlx::mysql::MySqlConnection::open(db_url.as_str())
let $db = sqlx::mysql::MySqlConnection::connect(db_url.as_str())
.await
.map_err(|e| format!("failed to connect to database: {}", e))?;