simplify pool implementation, run rustfmt

This commit is contained in:
Austin Bonander 2019-11-22 17:06:32 +00:00
parent 8bd768afe8
commit a9fb263520
17 changed files with 139 additions and 147 deletions

View file

@ -16,14 +16,12 @@ postgres = []
mariadb = []
[dependencies]
async-std = { version = "1.1.0", features = ["attributes"] }
async-std = { version = "1.1.0", features = ["attributes", "unstable"] }
async-stream = "0.2.0"
async-trait = "0.1.18"
bitflags = "1.2.1"
byteorder = { version = "1.3.2", default-features = false }
bytes = "0.4.12"
crossbeam-queue = "0.2.0"
crossbeam-utils = { version = "0.7.0", default-features = false }
futures-channel = "0.3.1"
futures-core = "0.3.1"
futures-util = "0.3.1"

View file

@ -15,7 +15,7 @@ pub struct Connection<DB>
where
DB: Backend,
{
live: Option<Live<DB>>,
live: Live<DB>,
pool: Option<Arc<SharedPool<DB>>>,
}
@ -24,25 +24,17 @@ where
DB: Backend,
{
pub(crate) fn new(live: Live<DB>, pool: Option<Arc<SharedPool<DB>>>) -> Self {
Self {
live: Some(live),
pool,
}
Self { live, pool }
}
pub async fn open(url: &str) -> crate::Result<Self> {
let raw = DB::open(url).await?;
let live = Live {
raw,
since: Instant::now(),
};
Ok(Self::new(live, None))
Ok(Self::new(Live::unpooled(raw), None))
}
/// Verifies a connection to the database is still alive.
pub async fn ping(&mut self) -> crate::Result<()> {
self.live.as_mut().expect("released").raw.ping().await
self.live.ping().await
}
/// Analyze the SQL statement and report the inferred bind parameter types and returned
@ -50,12 +42,7 @@ where
///
/// Mainly intended for use by sqlx-macros.
pub async fn describe(&mut self, statement: &str) -> crate::Result<Describe<DB>> {
self.live
.as_mut()
.expect("released")
.raw
.describe(statement)
.await
self.live.describe(statement).await
}
}
@ -73,14 +60,7 @@ where
where
A: IntoQueryParameters<Self::Backend> + Send,
{
Box::pin(async move {
self.live
.as_mut()
.expect("released")
.raw
.execute(query, params.into_params())
.await
})
Box::pin(async move { self.live.execute(query, params.into_params()).await })
}
fn fetch<'c, 'q: 'c, T: 'c, A: 'c>(
@ -93,7 +73,7 @@ where
T: FromSqlRow<Self::Backend> + Send + Unpin,
{
Box::pin(async_stream::try_stream! {
let mut s = self.live.as_mut().expect("released").raw.fetch(query, params.into_params());
let mut s = self.live.fetch(query, params.into_params());
while let Some(row) = s.next().await.transpose()? {
yield T::from_row(row);
@ -113,9 +93,6 @@ where
Box::pin(async move {
let row = self
.live
.as_mut()
.expect("released")
.raw
.fetch_optional(query, params.into_params())
.await?;
@ -123,16 +100,3 @@ where
})
}
}
impl<DB> Drop for Connection<DB>
where
DB: Backend,
{
fn drop(&mut self) {
if let Some(pool) = &self.pool {
if let Some(live) = self.live.take() {
pool.release(live);
}
}
}
}

View file

@ -38,11 +38,11 @@ pub use self::{
compiled::CompiledSql,
connection::Connection,
decode::Decode,
encode::Encode,
error::{Error, Result},
executor::Executor,
pool::Pool,
row::{FromSqlRow, Row},
encode::Encode,
sql::{query, SqlQuery},
types::HasSqlType,
};

View file

@ -1,8 +1,8 @@
use super::MariaDb;
use crate::{
encode::{Encode, IsNull},
mariadb::types::MariaDbTypeMetadata,
query::QueryParameters,
encode::{IsNull, Encode},
types::HasSqlType,
};

View file

@ -1,10 +1,10 @@
use crate::{
encode::IsNull,
mariadb::{
protocol::{FieldType, ParameterFlag},
types::MariaDbTypeMetadata,
},
encode::IsNull,
Decode, HasSqlType, MariaDb, Encode,
Decode, Encode, HasSqlType, MariaDb,
};
impl HasSqlType<[u8]> for MariaDb {
@ -31,7 +31,7 @@ impl Encode<MariaDb> for [u8] {
impl Encode<MariaDb> for Vec<u8> {
fn encode(&self, buf: &mut Vec<u8>) -> IsNull {
<[u8] as Encode<MariaDb>>::to_sql(self, buf)
<[u8] as Encode<MariaDb>>::encode(self, buf)
}
}

View file

@ -1,8 +1,8 @@
use super::{MariaDb, MariaDbTypeMetadata};
use crate::{
decode::Decode,
encode::{Encode, IsNull},
mariadb::protocol::{FieldType, ParameterFlag},
encode::{IsNull, Encode},
types::HasSqlType,
};

View file

@ -1,8 +1,8 @@
use super::{MariaDb, MariaDbTypeMetadata};
use crate::{
decode::Decode,
encode::{Encode, IsNull},
mariadb::protocol::{FieldType, ParameterFlag},
encode::{IsNull, Encode},
types::HasSqlType,
};
use std::str;
@ -37,7 +37,7 @@ impl Encode<MariaDb> for str {
impl Encode<MariaDb> for String {
#[inline]
fn encode(&self, buf: &mut Vec<u8>) -> IsNull {
<str as Encode<MariaDb>>::to_sql(self.as_str(), buf)
<str as Encode<MariaDb>>::encode(self.as_str(), buf)
}
}

View file

@ -1,8 +1,8 @@
use super::{MariaDb, MariaDbTypeMetadata};
use crate::{
decode::Decode,
encode::{Encode, IsNull},
mariadb::protocol::{FieldType, ParameterFlag},
encode::{IsNull, Encode},
types::HasSqlType,
};
use byteorder::{BigEndian, ByteOrder};
@ -102,14 +102,14 @@ impl HasSqlType<f32> for MariaDb {
impl Encode<MariaDb> for f32 {
#[inline]
fn encode(&self, buf: &mut Vec<u8>) -> IsNull {
<i32 as Encode<MariaDb>>::to_sql(&(self.to_bits() as i32), buf)
<i32 as Encode<MariaDb>>::encode(&(self.to_bits() as i32), buf)
}
}
impl Decode<MariaDb> for f32 {
#[inline]
fn decode(buf: Option<&[u8]>) -> Self {
f32::from_bits(<i32 as Decode<MariaDb>>::from_sql(buf) as u32)
f32::from_bits(<i32 as Decode<MariaDb>>::decode(buf) as u32)
}
}
@ -127,13 +127,13 @@ impl HasSqlType<f64> for MariaDb {
impl Encode<MariaDb> for f64 {
#[inline]
fn encode(&self, buf: &mut Vec<u8>) -> IsNull {
<i64 as Encode<MariaDb>>::to_sql(&(self.to_bits() as i64), buf)
<i64 as Encode<MariaDb>>::encode(&(self.to_bits() as i64), buf)
}
}
impl Decode<MariaDb> for f64 {
#[inline]
fn decode(buf: Option<&[u8]>) -> Self {
f64::from_bits(<i64 as Decode<MariaDb>>::from_sql(buf) as u64)
f64::from_bits(<i64 as Decode<MariaDb>>::decode(buf) as u64)
}
}

View file

@ -2,12 +2,13 @@ use crate::{
backend::Backend, connection::Connection, error::Error, executor::Executor,
query::IntoQueryParameters, row::FromSqlRow,
};
use crossbeam_queue::{ArrayQueue, SegQueue};
use futures_channel::oneshot;
use futures_core::{future::BoxFuture, stream::BoxStream};
use futures_util::stream::StreamExt;
use futures_util::{future::FutureExt, stream::StreamExt};
use std::{
future::Future,
marker::PhantomData,
ops::{Deref, DerefMut},
sync::{
atomic::{AtomicU32, AtomicUsize, Ordering},
Arc,
@ -15,6 +16,9 @@ use std::{
time::{Duration, Instant},
};
use async_std::sync::{channel, Receiver, Sender};
use async_std::task;
/// A pool of database connections.
pub struct Pool<DB>(Arc<SharedPool<DB>>)
where
@ -68,7 +72,7 @@ where
/// Returns the number of idle connections.
pub fn idle(&self) -> usize {
self.0.num_idle.load(Ordering::Acquire)
self.0.pool_rx.len()
}
/// Returns the configured maximum pool size.
@ -170,11 +174,9 @@ where
DB: Backend,
{
url: String,
idle: ArrayQueue<Idle<DB>>,
waiters: SegQueue<oneshot::Sender<Live<DB>>>,
pool_rx: Receiver<Idle<DB>>,
pool_tx: Sender<Idle<DB>>,
size: AtomicU32,
num_waiters: AtomicUsize,
num_idle: AtomicUsize,
options: Options,
}
@ -185,26 +187,20 @@ where
async fn new(url: &str, options: Options) -> crate::Result<Self> {
// TODO: Establish [min_idle] connections
let (pool_tx, pool_rx) = channel(options.max_size as usize);
Ok(Self {
url: url.to_owned(),
idle: ArrayQueue::new(options.max_size as usize),
waiters: SegQueue::new(),
pool_rx,
pool_tx,
size: AtomicU32::new(0),
num_idle: AtomicUsize::new(0),
num_waiters: AtomicUsize::new(0),
options,
})
}
#[inline]
fn try_acquire(&self) -> Option<Live<DB>> {
if let Ok(idle) = self.idle.pop() {
self.num_idle.fetch_sub(1, Ordering::AcqRel);
return Some(idle.live);
}
None
Some(self.pool_rx.recv().now_or_never()??.live(&self.pool_tx))
}
async fn acquire(&self) -> crate::Result<Live<DB>> {
@ -219,54 +215,23 @@ where
// Too many open connections
// Wait until one is available
let (sender, receiver) = oneshot::channel();
self.waiters.push(sender);
self.num_waiters.fetch_add(1, Ordering::AcqRel);
// Waiters are not dropped unless the pool is dropped
// which would drop this future
return Ok(receiver
return Ok(self
.pool_rx
.recv()
.await
.expect("waiter dropped without dropping pool"));
.expect("waiter dropped without dropping pool")
.live(&self.pool_tx));
}
if self.size.compare_and_swap(size, size + 1, Ordering::AcqRel) == size {
// Open a new connection and return directly
let raw = DB::open(&self.url).await?;
let live = Live {
raw,
since: Instant::now(),
};
return Ok(live);
return Ok(Live::pooled(raw, &self.pool_tx));
}
}
}
pub(crate) fn release(&self, mut live: Live<DB>) {
if self.num_waiters.load(Ordering::Acquire) > 0 {
while let Ok(waiter) = self.waiters.pop() {
self.num_waiters.fetch_sub(1, Ordering::AcqRel);
live = match waiter.send(live) {
Ok(()) => {
return;
}
Err(live) => live,
};
}
}
let _ = self.idle.push(Idle {
live,
since: Instant::now(),
});
self.num_idle.fetch_add(1, Ordering::AcqRel);
}
}
impl<DB> Executor for Pool<DB>
@ -338,9 +303,7 @@ where
{
Box::pin(async move {
let mut live = self.0.acquire().await?;
let result = live.raw.execute(query, params.into_params()).await;
self.0.release(live);
let result = live.execute(query, params.into_params()).await;
result
})
}
@ -356,14 +319,11 @@ where
{
Box::pin(async_stream::try_stream! {
let mut live = self.0.acquire().await?;
let mut s = live.raw.fetch(query, params.into_params());
let mut s = live.fetch(query, params.into_params());
while let Some(row) = s.next().await.transpose()? {
yield T::from_row(row);
}
drop(s);
self.0.release(live);
})
}
@ -378,29 +338,100 @@ where
{
Box::pin(async move {
let mut live = self.0.acquire().await?;
let row = live.raw.fetch_optional(query, params.into_params()).await?;
self.0.release(live);
let row = live.fetch_optional(query, params.into_params()).await?;
Ok(row.map(T::from_row))
})
}
}
struct Raw<DB> {
pub(crate) inner: DB,
pub(crate) created: Instant,
}
struct Idle<DB>
where
DB: Backend,
{
live: Live<DB>,
raw: Raw<DB>,
#[allow(unused)]
since: Instant,
}
impl<DB: Backend> Idle<DB> {
fn live(self, pool_tx: &Sender<Idle<DB>>) -> Live<DB> {
Live {
raw: Some(self.raw),
pool_tx: Some(pool_tx.clone()),
}
}
}
pub(crate) struct Live<DB>
where
DB: Backend,
{
pub(crate) raw: DB,
#[allow(unused)]
pub(crate) since: Instant,
raw: Option<Raw<DB>>,
pool_tx: Option<Sender<Idle<DB>>>,
}
impl<DB: Backend> Live<DB> {
pub fn unpooled(raw: DB) -> Self {
Live {
raw: Some(Raw {
inner: raw,
created: Instant::now(),
}),
pool_tx: None,
}
}
fn pooled(raw: DB, pool_tx: &Sender<Idle<DB>>) -> Self {
Live {
raw: Some(Raw {
inner: raw,
created: Instant::now(),
}),
pool_tx: Some(pool_tx.clone()),
}
}
pub fn release(mut self) {
self.release_mut()
}
fn release_mut(&mut self) {
// `.release_mut()` will be called twice if `.release()` is called
if let (Some(raw), Some(pool_tx)) = (self.raw.take(), self.pool_tx.as_ref()) {
pool_tx
.send(Idle {
raw,
since: Instant::now(),
})
.now_or_never()
.expect("(bug) connection released into a full pool")
}
}
}
const DEREF_ERR: &str = "(bug) connection already released to pool";
impl<DB: Backend> Deref for Live<DB> {
type Target = DB;
fn deref(&self) -> &DB {
&self.raw.as_ref().expect(DEREF_ERR).inner
}
}
impl<DB: Backend> DerefMut for Live<DB> {
fn deref_mut(&mut self) -> &mut DB {
&mut self.raw.as_mut().expect(DEREF_ERR).inner
}
}
impl<DB: Backend> Drop for Live<DB> {
fn drop(&mut self) {
self.release_mut()
}
}

View file

@ -1,8 +1,8 @@
use super::Postgres;
use crate::{
encode::{Encode, IsNull},
io::BufMut,
query::QueryParameters,
encode::{IsNull, Encode},
types::HasSqlType,
};
use byteorder::{BigEndian, ByteOrder, NetworkEndian};
@ -44,7 +44,7 @@ impl QueryParameters for PostgresQueryParameters {
(self.buf.len() - pos - 4) as i32
} else {
// Write a -1 for the len to indicate NULL
// TODO: It is illegal for [to_sql] to write any data if IsSql::No; fail a debug assertion
// TODO: It is illegal for [encode] to write any data if IsSql::No; fail a debug assertion
-1
};

View file

@ -1,7 +1,7 @@
use crate::{
postgres::types::{PostgresTypeFormat, PostgresTypeMetadata},
encode::IsNull,
Decode, HasSqlType, Postgres, Encode,
postgres::types::{PostgresTypeFormat, PostgresTypeMetadata},
Decode, Encode, HasSqlType, Postgres,
};
impl HasSqlType<[u8]> for Postgres {
@ -29,7 +29,7 @@ impl Encode<Postgres> for [u8] {
impl Encode<Postgres> for Vec<u8> {
fn encode(&self, buf: &mut Vec<u8>) -> IsNull {
<[u8] as Encode<Postgres>>::to_sql(self, buf)
<[u8] as Encode<Postgres>>::encode(self, buf)
}
}

View file

@ -1,7 +1,7 @@
use super::{Postgres, PostgresTypeFormat, PostgresTypeMetadata};
use crate::{
decode::Decode,
encode::{IsNull, Encode},
encode::{Encode, IsNull},
types::HasSqlType,
};

View file

@ -1,7 +1,7 @@
use super::{Postgres, PostgresTypeFormat, PostgresTypeMetadata};
use crate::{
decode::Decode,
encode::{IsNull, Encode},
encode::{Encode, IsNull},
types::HasSqlType,
};
use std::str;
@ -36,7 +36,7 @@ impl Encode<Postgres> for str {
impl Encode<Postgres> for String {
#[inline]
fn encode(&self, buf: &mut Vec<u8>) -> IsNull {
<str as Encode<Postgres>>::to_sql(self.as_str(), buf)
<str as Encode<Postgres>>::encode(self.as_str(), buf)
}
}

View file

@ -1,7 +1,7 @@
use super::{Postgres, PostgresTypeFormat, PostgresTypeMetadata};
use crate::{
decode::Decode,
encode::{IsNull, Encode},
encode::{Encode, IsNull},
types::HasSqlType,
};
use byteorder::{BigEndian, ByteOrder};
@ -101,14 +101,14 @@ impl HasSqlType<f32> for Postgres {
impl Encode<Postgres> for f32 {
#[inline]
fn encode(&self, buf: &mut Vec<u8>) -> IsNull {
<i32 as Encode<Postgres>>::to_sql(&(self.to_bits() as i32), buf)
<i32 as Encode<Postgres>>::encode(&(self.to_bits() as i32), buf)
}
}
impl Decode<Postgres> for f32 {
#[inline]
fn decode(buf: Option<&[u8]>) -> Self {
f32::from_bits(<i32 as Decode<Postgres>>::from_sql(buf) as u32)
f32::from_bits(<i32 as Decode<Postgres>>::decode(buf) as u32)
}
}
@ -126,13 +126,13 @@ impl HasSqlType<f64> for Postgres {
impl Encode<Postgres> for f64 {
#[inline]
fn encode(&self, buf: &mut Vec<u8>) -> IsNull {
<i64 as Encode<Postgres>>::to_sql(&(self.to_bits() as i64), buf)
<i64 as Encode<Postgres>>::encode(&(self.to_bits() as i64), buf)
}
}
impl Decode<Postgres> for f64 {
#[inline]
fn decode(buf: Option<&[u8]>) -> Self {
f64::from_bits(<i64 as Decode<Postgres>>::from_sql(buf) as u64)
f64::from_bits(<i64 as Decode<Postgres>>::decode(buf) as u64)
}
}

View file

@ -3,7 +3,7 @@ use uuid::Uuid;
use super::{Postgres, PostgresTypeFormat, PostgresTypeMetadata};
use crate::{
decode::Decode,
encode::{IsNull, Encode},
encode::{Encode, IsNull},
types::HasSqlType,
};

View file

@ -1,6 +1,6 @@
use crate::{
backend::Backend, error::Error, executor::Executor, query::QueryParameters, row::FromSqlRow,
encode::Encode, types::HasSqlType,
backend::Backend, encode::Encode, error::Error, executor::Executor, query::QueryParameters,
row::FromSqlRow, types::HasSqlType,
};
use futures_core::{future::BoxFuture, stream::BoxStream};

View file

@ -1,8 +1,7 @@
#[async_std::test]
async fn test_sqlx_macro() -> sqlx::Result<()> {
let mut conn =
sqlx::Connection::<sqlx::Postgres>::open("postgres://postgres@127.0.0.1/sqlx_test")
.await?;
sqlx::Connection::<sqlx::Postgres>::open("postgres://postgres@127.0.0.1/sqlx_test").await?;
let uuid: sqlx::types::Uuid = "256ba9c8-0048-11ea-b0f0-8f04859d047e".parse().unwrap();
let accounts = sqlx::query!("SELECT * from accounts where id != $1", None)
.fetch_optional(&mut conn)