WIP feat: reintroduce Pool

Signed-off-by: Austin Bonander <austin@launchbadge.com>
This commit is contained in:
Austin Bonander 2021-04-20 19:10:50 -07:00
parent e7664d8d19
commit e516bf66ed
No known key found for this signature in database
GPG key ID: 4E7DA63E66AFC37E
12 changed files with 1544 additions and 5 deletions

27
Cargo.lock generated
View file

@ -429,6 +429,21 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f55667319111d593ba876406af7c409c0ebb44dc4be6132a783ccf163ea14c1"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.13"
@ -436,6 +451,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c2dd2df839b57db9ab69c2c9d8f3e8c81984781937fe2807dc6dcf3b2ad2939"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
@ -488,6 +504,12 @@ dependencies = [
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c5629433c555de3d82861a7a4e3794a4c40040390907cfbfd7143a92a426c23"
[[package]]
name = "futures-task"
version = "0.3.13"
@ -500,9 +522,11 @@ version = "0.3.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1812c7ab8aedf8d6f2701a43e1243acdbcc2b36ab26e2ad421eb99ac963d96d1"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
@ -1088,7 +1112,10 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
name = "sqlx"
version = "0.6.0-pre"
dependencies = [
"crossbeam-queue",
"futures",
"futures-util",
"parking_lot",
"sqlx-core",
"sqlx-mysql",
"sqlx-postgres",

30
script/enforce-new-mod-style.sh Executable file
View file

@ -0,0 +1,30 @@
#!/usr/bin/env bash
# This script scans the project for `mod.rs` files and exits with a nonzero code if it finds any.
#
# You can also call it with `--fix` to replace any `mod.rs` files with their 2018 edition equivalents.
# The new files will be staged for commit for convenience.
FILES=$(find ./ -name mod.rs -print)
if [[ -z $FILES ]]; then
exit 0
fi
if [ "$1" != "--fix" ]; then
echo 'This project uses the Rust 2018 module style. mod.rs files are forbidden.'
echo "Execute \`$0 --fix\` to replace these with their 2018 equivalents and stage for commit."
echo 'Found mod.rs files:'
echo "$FILES"
exit 1
fi
echo 'Fixing Rust 2018 Module Style'
while read -r file; do
dest="$(dirname $file).rs"
echo "$file -> $dest"
mv $file $dest
git add $dest
done <<< $FILES

View file

@ -25,7 +25,10 @@ pub enum Error {
/// The database URL is malformed or contains invalid or unsupported
/// values for one or more options; a value of [`ConnectOptions`] failed
/// to be parsed.
ConnectOptions { message: Cow<'static, str>, source: Option<Box<dyn StdError + Send + Sync>> },
ConnectOptions {
message: Cow<'static, str>,
source: Option<Box<dyn StdError + Send + Sync>>,
},
/// An error that was returned from the database, normally from the
/// execution of a SQL command.
@ -65,6 +68,8 @@ pub enum Error {
///
Closed,
AcquireTimedOut,
/// An error occurred decoding a SQL value from the database.
Decode(DecodeError),
@ -72,19 +77,31 @@ pub enum Error {
Encode(EncodeError),
/// An attempt to access a column by index past the end of the row.
ColumnIndexOutOfBounds { index: usize, len: usize },
ColumnIndexOutOfBounds {
index: usize,
len: usize,
},
/// An attempt to access a column by name where no such column is
/// present in the row.
ColumnNotFound { name: Box<str> },
ColumnNotFound {
name: Box<str>,
},
/// An error occurred decoding a SQL value of a specific column
/// from the database.
ColumnDecode { column_index: usize, column_name: Box<str>, source: DecodeError },
ColumnDecode {
column_index: usize,
column_name: Box<str>,
source: DecodeError,
},
/// An error occurred encoding a value for a specific parameter to
/// be sent to the database.
ParameterEncode { parameter: Either<usize, Box<str>>, source: EncodeError },
ParameterEncode {
parameter: Either<usize, Box<str>>,
source: EncodeError,
},
}
impl Error {
@ -144,6 +161,8 @@ impl Display for Error {
Self::Closed => f.write_str("Connection or pool is closed"),
Self::AcquireTimedOut => f.write_str("Timeout on acquiring a connection from Pool"),
Self::Decode(error) => {
write!(f, "Decode: {}", error)
}

View file

@ -26,6 +26,8 @@ mod tokio_;
pub use actix_::Actix;
#[cfg(feature = "async-std")]
pub use async_std_::AsyncStd;
use std::future::Future;
use std::time::Duration;
#[cfg(feature = "tokio")]
pub use tokio_::Tokio;
@ -82,6 +84,15 @@ pub trait Runtime: 'static + Send + Sync + Sized + Debug {
fn connect_unix_async(path: &Path) -> BoxFuture<'_, io::Result<Self::UnixStream>>
where
Self: Async;
#[doc(hidden)]
#[cfg(all(unix, feature = "async"))]
fn timeout_async<'a, F: Future + Send + 'a>(
fut: F,
timeout: Duration,
) -> BoxFuture<'a, Option<F::Output>>
where
Self: Async;
}
/// Marks a [`Runtime`] as being capable of handling asynchronous execution.

View file

@ -12,6 +12,8 @@ use futures_util::{AsyncReadExt, AsyncWriteExt, FutureExt, TryFutureExt};
use crate::io::Stream;
use crate::{Async, Runtime};
use std::future::Future;
use std::time::{Duration, Instant};
/// Provides [`Runtime`] for [**Tokio**](https://tokio.rs). Supports only non-blocking operation.
///
@ -55,6 +57,14 @@ impl Runtime for Tokio {
fn connect_unix_async(path: &Path) -> BoxFuture<'_, io::Result<Self::UnixStream>> {
UnixStream::connect(path).map_ok(Compat::new).boxed()
}
#[doc(hidden)]
fn timeout_async<'a, F: Future + Send + 'a>(
fut: F,
timeout: Duration,
) -> BoxFuture<'a, Option<F::Output>> {
Box::pin(_tokio::time::timeout(timeout.into(), fut).map(Result::ok))
}
}
impl Async for Tokio {}

View file

@ -28,6 +28,9 @@ async-std = ["async", "sqlx-core/async-std"]
actix = ["async", "sqlx-core/actix"]
tokio = ["async", "sqlx-core/tokio"]
# Connection Pool
pool = ["crossbeam-queue", "parking_lot"]
# MySQL
mysql = ["sqlx-mysql"]
mysql-async = ["async", "mysql", "sqlx-mysql/async"]
@ -38,8 +41,15 @@ postgres = ["sqlx-postgres"]
postgres-async = ["async", "postgres", "sqlx-postgres/async"]
postgres-blocking = ["blocking", "postgres", "sqlx-postgres/blocking"]
[dependencies]
sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core" }
sqlx-mysql = { version = "0.6.0-pre", path = "../sqlx-mysql", optional = true }
sqlx-postgres = { version = "0.6.0-pre", path = "../sqlx-postgres", optional = true }
futures-util = { version = "0.3", optional = true, features = ["io"] }
crossbeam-queue = { version = "0.3.1", optional = true }
parking_lot = { version = "0.11", optional = true }
[dev-dependencies]
futures = "0.3.5"

View file

@ -46,6 +46,9 @@
#[cfg(feature = "blocking")]
pub mod blocking;
#[cfg(feature = "pool")]
pub mod pool;
mod query;
mod query_as;
mod runtime;

214
sqlx/src/pool.rs Normal file
View file

@ -0,0 +1,214 @@
use crate::pool::connection::{Idle, Pooled};
use crate::pool::options::PoolOptions;
use crate::pool::shared::{SharedPool, TryAcquireResult};
use crate::pool::wait_list::WaitList;
use crate::{Connect, Connection, DefaultRuntime, Runtime};
use crossbeam_queue::ArrayQueue;
use std::sync::atomic::AtomicU32;
use std::sync::Arc;
use std::time::{Duration, Instant};
mod connection;
mod options;
mod shared;
mod wait_list;
/// SQLx's integrated, runtime-agnostic connection pool.
pub struct Pool<Rt: Runtime, C: Connection<Rt>> {
shared: Arc<SharedPool<Rt, C>>,
}
impl<Rt: Runtime, C: Connection<Rt>> Pool<Rt, C> {
/// Construct a new pool with default configuration and given connection URI.
///
/// Connection will not be attempted until the first call to [`Self::acquire()`].
/// The only error that may be returned is [`Error::ConnectOptions`][crate::Error::ConnectOptions]
/// if the passed URI fails to parse or contains invalid options.
///
/// If you want to eagerly connect on construction of the pool, use [`Self::connect`]
/// instead.
///
/// See also:
/// * [`Self::new_with()`]
/// * [`Self::builder()`] (alias of [`PoolOptions::new()`]) and [`PoolOptions::build()`]
pub fn new(uri: &str) -> crate::Result<Self> {
Self::builder().build(uri)
}
/// Construct a new pool with default configuration and given connection options.
///
/// Connection will not be attempted until the first call to [`Self::acquire()`].
/// The only error that may be returned is [`Error::ConnectOptions`][crate::Error::ConnectOptions]
/// if the passed URI fails to parse or contains invalid options.
///
/// If you want to eagerly connect on construction of the pool, use [`Self::connect_with()`]
/// instead.
///
/// See also:
/// * [`Self::builder()`] (alias of [`PoolOptions::new()`]) and [`PoolOptions::build_with()`]
pub fn new_with(connect_options: <C as Connect<Rt>>::Options) -> Self {
Self::builder().build_with(connect_options)
}
/// A helpful alias for [`PoolOptions::new()`].
pub fn builder() -> PoolOptions<Rt, C> {
PoolOptions::new()
}
}
#[cfg(feature = "async")]
impl<Rt: crate::Async, C: Connection<Rt>> Pool<Rt, C> {
/// Construct a new pool with default configuration and given connection URI
/// and establish at least one connection to ensure the URI is valid.
///
/// See also:
/// * [`Self::connect_with()`]
/// * [`Self::builder()`] (alias of [`PoolOptions::new()`]) and [`PoolOptions::connect()`]
pub async fn connect(uri: &str) -> crate::Result<Self> {
Self::builder().connect(uri).await
}
/// Construct a new pool with default configuration and given connection options
/// and establish at least one connection to ensure the latter are valid.
///
/// See also:
/// * [`Self::builder()`] (alias of [`PoolOptions::new()`]) and [`PoolOptions::connect_with()`]
pub async fn connect_with(connect_options: <C as Connect<Rt>>::Options) -> crate::Result<Self> {
Self::builder().connect_with(connect_options).await
}
/// Acquire a connection from the pool.
///
/// This will either wait until a connection is released by another task (via the `Drop` impl on
/// [`Pooled`]) or, if the pool is not yet at its maximum size, opening a new connection.
///
/// If an acquire timeout is configured via [`PoolOptions::acquire_timeout()`], this will wait
/// at most the given duration before returning [`Error::AcquireTimedOut`][crate::Error::AcquireTimedOut].
///
/// See also:
/// * [`Self::acquire_timeout()`]
/// * [`PoolOptions::max_connections()`]
/// * [`PoolOptions::acquire_timeout()`]
pub async fn acquire(&self) -> crate::Result<Pooled<Rt, C>> {
if let Some(timeout) = self.shared.pool_options.acquire_timeout {
self.acquire_timeout(timeout).await
} else {
self.acquire_inner().await
}
}
/// Acquire a connection from the pool, waiting at most the given duration.
///
/// This will either wait until a connection is released by another task (via the `Drop` impl on
/// [`Pooled`]) or, if the pool is not yet at its maximum size, opening a new connection.
///
/// If the given duration elapses, this will return
/// [`Error::AcquireTimedOut`][crate::Error::AcquireTimedOut].
pub async fn acquire_timeout(&self, timeout: Duration) -> crate::Result<Pooled<Rt, C>> {
Rt::timeout_async(timeout, self.acquire_inner())
.await
.ok_or(crate::Error::AcquireTimedOut)?
}
async fn acquire_inner(&self) -> crate::Result<Pooled<Rt, C>> {
let mut acquire_permit = None;
loop {
match self.shared.try_acquire(acquire_permit.take()) {
TryAcquireResult::Acquired(mut conn) => {
match self.shared.on_acquire_async(&mut conn).await {
Ok(()) => return Ok(conn.attach(&self.shared)),
Err(e) => {
log::info!("error from before_acquire: {:?}", e);
}
}
}
TryAcquireResult::Connect(permit) => self.shared.connect_async(permit).await,
TryAcquireResult::Wait => {
acquire_permit = Some(self.shared.wait_async().await);
}
TryAcquireResult::PoolClosed => Err(crate::Error::Closed),
}
}
}
}
#[cfg(feature = "blocking")]
impl<C: Connection<crate::Blocking>> Pool<crate::Blocking, C> {
/// Construct a new pool with default configuration and given connection URI
/// and establish at least one connection to ensure the URI is valid.
///
/// See also:
/// * [`Self::connect_with()`]
/// * [`Self::builder()`] (alias of [`PoolOptions::new()`]) and [`PoolOptions::connect()`]
pub fn connect(uri: &str) -> crate::Result<Self> {
Self::builder().connect(uri)
}
/// Construct a new pool with default configuration and given connection options
/// and establish at least one connection to ensure the latter are valid.
///
/// See also:
/// * [`Self::builder()`] (alias of [`PoolOptions::new()`]) and [`PoolOptions::connect_with()`]
pub fn connect_with(
connect_options: <C as Connect<crate::Blocking>>::Options,
) -> crate::Result<Self> {
Self::builder().connect_with(connect_options)
}
/// Acquire a connection from the pool.
///
/// This will either wait until a connection is released by another thread (via the `Drop` impl on
/// [`Pooled`]) or, if the pool is not yet at its maximum size, opening a new connection.
///
/// If an acquire timeout is configured via [`PoolOptions::acquire_timeout()`], this will wait
/// at most the given duration before returning [`Error::AcquireTimedOut`][crate::Error::AcquireTimedOut].
///
/// See also:
/// * [`Self::acquire_timeout()`]
/// * [`PoolOptions::max_connections()`]
/// * [`PoolOptions::acquire_timeout()`]
pub fn acquire(&self) -> crate::Result<Pooled<crate::Blocking, C>> {
self.acquire_inner(self.shared.pool_options.acquire_timeout)
}
/// Acquire a connection from the pool, waiting at most the given duration.
///
/// This will either wait until a connection is released by another thread (via the `Drop` impl on
/// [`Pooled`]) or, if the pool is not yet at its maximum size, opening a new connection.
///
/// If the given duration elapses, this will return
/// [`Error::AcquireTimedOut`][crate::Error::AcquireTimedOut].
pub fn acquire_timeout(&self, timeout: Duration) -> crate::Result<Pooled<crate::Blocking, C>> {
self.acquire_inner(Some(timeout))
}
fn acquire_inner(
&self,
timeout: Option<Duration>,
) -> crate::Result<Pooled<crate::Blocking, C>> {
let mut acquire_permit = None;
let deadline = timeout.map(|timeout| Instant::now() + timeout);
loop {
match self.shared.try_acquire(acquire_permit.take()) {
TryAcquireResult::Acquired(mut conn) => {
match self.shared.on_acquire_blocking(&mut conn) {
Ok(()) => return Ok(conn.attach(&self.shared)),
Err(e) => {
log::info!("error from before_acquire: {:?}", e);
}
}
}
TryAcquireResult::Connect(permit) => self.shared.connect_blocking(permit),
TryAcquireResult::Wait => {
acquire_permit = Some(
self.shared.wait_blocking(deadline).ok_or(crate::Error::AcquireTimedOut)?,
);
}
TryAcquireResult::PoolClosed => Err(crate::Error::Closed),
}
}
}
}

156
sqlx/src/pool/connection.rs Normal file
View file

@ -0,0 +1,156 @@
use super::shared::{DecrementSizeGuard, SharedPool};
use crate::{Connection, Runtime};
use std::fmt::{self, Debug, Formatter};
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::time::Instant;
/// A connection managed by a [`Pool`][crate::pool::Pool].
///
/// Will be returned to the pool on-drop.
pub struct Pooled<Rt: Runtime, C: Connection<Rt>> {
live: Option<C>,
pub(crate) pool: Arc<SharedPool<Rt, C>>,
}
pub(super) struct Live<Rt: Runtime, C: Connection<Rt>> {
pub(super) raw: C,
pub(super) created: Instant,
_rt: PhantomData<Rt>,
}
pub(super) struct Idle<Rt: Runtime, C: Connection<Rt>> {
pub(super) live: Live<Rt, C>,
pub(super) since: Instant,
}
/// RAII wrapper for connections being handled by functions that may drop them
pub(super) struct Floating<'pool, C> {
inner: C,
guard: DecrementSizeGuard<'pool>,
}
const DEREF_ERR: &str = "(bug) connection already released to pool";
impl<Rt: Runtime, C: Connection<Rt>> Debug for Pooled<Rt, C> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
// TODO: Show the type name of the connection ?
f.debug_struct("PoolConnection").finish()
}
}
impl<Rt: Runtime, C: Connection<Rt>> Deref for Pooled<Rt, C> {
type Target = C;
fn deref(&self) -> &Self::Target {
&self.live.as_ref().expect(DEREF_ERR).raw
}
}
impl<Rt: Runtime, C: Connection<Rt>> DerefMut for Pooled<Rt, C> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.live.as_mut().expect(DEREF_ERR).raw
}
}
impl<Rt: Runtime, C: Connection<Rt>> Pooled<Rt, C> {
/// Explicitly release a connection from the pool
pub fn release(mut self) -> C {
self.live.take().expect("PoolConnection double-dropped").float(&self.pool).detach()
}
}
/// Returns the connection to the [`Pool`][crate::pool::Pool] it was checked-out from.
impl<Rt: Runtime, C: Connection<Rt>> Drop for Pooled<Rt, C> {
fn drop(&mut self) {
if let Some(live) = self.live.take() {
self.pool.release(live);
}
}
}
impl<Rt: Runtime, C: Connection<Rt>> Live<Rt, C> {
pub fn float(self, guard: DecrementSizeGuard<'_>) -> Floating<'_, Self> {
Floating { inner: self, guard }
}
pub fn into_idle(self) -> Idle<Rt, C> {
Idle { live: self, since: Instant::now() }
}
}
impl<Rt: Runtime, C: Connection<Rt>> Idle<Rt, C> {
pub fn float(self, guard: DecrementSizeGuard<'_>) -> Floating<'_, Self> {
Floating { inner: self, guard }
}
}
impl<Rt: Runtime, C: Connection<Rt>> Deref for Idle<Rt, C> {
type Target = Live<Rt, C>;
fn deref(&self) -> &Self::Target {
&self.live
}
}
impl<Rt: Runtime, C: Connection<Rt>> DerefMut for Idle<Rt, C> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.live
}
}
impl<'s, C> Floating<'s, C> {
pub fn into_leakable(self) -> C {
self.guard.cancel();
self.inner
}
pub fn same_pool(&self, other: &SharedPool<Rt, C>) -> bool {
self.guard.same_pool(other)
}
}
impl<'s, Rt: Runtime, C: Connection<C>> Floating<'s, Live<Rt, C>> {
pub fn attach(self, pool: &Arc<SharedPool<Rt, C>>) -> Pooled<Rt, C> {
let Floating { inner, guard } = self;
debug_assert!(guard.same_pool(pool), "BUG: attaching connection to different pool");
guard.cancel();
Pooled { live: Some(inner), pool: Arc::clone(pool) }
}
pub fn detach(self) -> C {
self.inner.raw
}
pub fn into_idle(self) -> Floating<'s, Idle<Rt, C>> {
Floating { inner: self.inner.into_idle(), guard: self.guard }
}
}
impl<'s, Rt: Runtime, C: Connection<Rt>> Floating<'s, Idle<Rt, C>> {
pub fn into_live(self) -> Floating<'s, Live<Rt, C>> {
Floating { inner: self.inner.live, guard: self.guard }
}
pub async fn close(self) -> crate::Result<()> {
// `guard` is dropped as intended
self.inner.live.raw.close().await
}
}
impl<C> Deref for Floating<'_, C> {
type Target = C;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<C> DerefMut for Floating<'_, C> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}

341
sqlx/src/pool/options.rs Normal file
View file

@ -0,0 +1,341 @@
use crate::pool::shared::SharedPool;
use crate::pool::Pool;
use crate::{Connect, ConnectOptions, Connection, Runtime};
use std::cmp;
use std::fmt::{self, Debug, Formatter};
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Configuration options/builder for constructing a [`Pool`].
///
/// See the source of [`Self::new()`] for the current defaults.
pub struct PoolOptions<Rt: Runtime, C: Connection<Rt>> {
// general options
pub(crate) max_connections: u32,
pub(crate) acquire_timeout: Option<Duration>,
pub(crate) min_connections: u32,
pub(crate) max_lifetime: Option<Duration>,
pub(crate) idle_timeout: Option<Duration>,
// callback functions (any runtime)
pub(crate) after_release: Option<Box<dyn Fn(&mut C) -> bool + 'static + Send + Sync>>,
// callback functions (async)
#[cfg(feature = "async")]
pub(crate) after_connect_async: Option<
Box<
dyn Fn(&mut C) -> futures_util::BoxFuture<'_, crate::Result<()>>
+ Send
+ Sync
+ 'static,
>,
>,
#[cfg(feature = "async")]
pub(crate) before_acquire_async: Option<
Box<
dyn Fn(&mut C) -> futures_util::BoxFuture<'_, crate::Result<()>>
+ Send
+ Sync
+ 'static,
>,
>,
//callback functions (blocking)
#[cfg(feature = "blocking")]
pub(crate) after_connect_blocking:
Option<Box<dyn Fn(&mut C) -> crate::Result<()> + Send + Sync + 'static>>,
#[cfg(feature = "blocking")]
pub(crate) before_acquire_blocking:
Option<Box<dyn Fn(&mut C) -> crate::Result<()> + Send + Sync + 'static>>,
// to satisfy the orphan type params check
_rt: PhantomData<Rt>,
}
impl<Rt: Runtime, C: Connection<Rt>> Default for PoolOptions<Rt, C> {
fn default() -> Self {
Self::new()
}
}
impl<Rt: Runtime, C: Connection<Rt>> PoolOptions<Rt, C> {
/// Create a new `PoolOptions` with some arbitrary, but sane, default values.
///
/// See the source of this method for the current values.
pub fn new() -> Self {
Self {
min_connections: 0,
max_connections: 10,
acquire_timeout: Some(Duration::from_secs(60)),
idle_timeout: Some(Duration::from_secs(10 * 60)),
max_lifetime: Some(Duration::from_secs(30 * 60)),
after_release: None,
#[cfg(feature = "async")]
after_connect_async: None,
#[cfg(feature = "async")]
before_acquire_async: None,
#[cfg(feature = "blocking")]
after_connect_blocking: None,
#[cfg(feature = "blocking")]
before_acquire_blocking: None,
_rt: PhantomData,
}
.test_before_acquire()
}
/// Set the minimum number of connections that this pool should maintain at all times.
///
/// When the pool size drops below this amount, new connections are established automatically
/// in the background.
///
/// See the source of [`Self::new()`] for the default value.
pub fn min_connections(mut self, min: u32) -> Self {
self.min_connections = min;
self
}
/// Set the maximum number of connections that this pool should maintain.
///
/// See the source of [`Self::new()`] for the default value.
pub fn max_connections(mut self, max: u32) -> Self {
self.max_connections = max;
self
}
/// Set the amount of time a task should wait while attempting to acquire a connection.
///
/// If this timeout elapses, [`Pool::acquire()`] will return an error.
///
/// If set to `None`, [`Pool::acquire()`] will wait as long as it takes to acquire
/// a new connection.
///
/// See the source of [`Self::new()`] for the default value.
pub fn acquire_timeout(mut self, timeout: impl Into<Option<Duration>>) -> Self {
self.acquire_timeout = timeout.into();
self
}
/// Set the maximum lifetime of individual connections.
///
/// Any connection with a lifetime greater than this will be closed.
///
/// When set to `None`, all connections live until either reaped by [`idle_timeout`]
/// or explicitly disconnected.
///
/// Long-lived connections are not recommended due to the unfortunate reality of memory/resource
/// leaks on the database-side. It is better to retire connections periodically
/// (even if only once daily) to allow the database the opportunity to clean up data structures
/// (parse trees, query metadata caches, thread-local storage, etc.) that are associated with a
/// session.
///
/// See the source of [`Self::new()`] for the default value.
///
/// [`idle_timeout`]: Self::idle_timeout
pub fn max_lifetime(mut self, lifetime: impl Into<Option<Duration>>) -> Self {
self.max_lifetime = lifetime.into();
self
}
/// Set a maximum idle duration for individual connections.
///
/// Any connection with an idle duration longer than this will be closed.
///
/// For usage-based database server billing, this can be a cost saver.
///
/// See the source of [`Self::new()`] for the default value.
pub fn idle_timeout(mut self, timeout: impl Into<Option<Duration>>) -> Self {
self.idle_timeout = timeout.into();
self
}
/// If enabled, the health of a connection will be verified by a call to [`Connection::ping`]
/// before returning the connection.
///
/// This overrides a previous callback set to [Self::before_acquire] and is also overridden by
/// `before_acquire`.
pub fn test_before_acquire(mut self) -> Self {
#[cfg(feature = "async")]
self.before_acquire_async = Some(Box::new(Connection::ping));
#[cfg(feature = "blocking")]
todo!("Connection doesn't have a ping_blocking()");
self
}
pub fn after_release<F>(mut self, callback: F) -> Self
where
F: Fn(&mut C) -> bool + 'static + Send + Sync,
{
self.after_release = Some(Box::new(callback));
self
}
/// Creates a new pool from this configuration.
///
/// Note that **this does not immediately connect to the database**;
/// this call will only error if the URI fails to parse.
///
/// A connection will first be established either on the first call to
/// [`Pool::acquire()`][super::Pool::acquire()] or,
/// if [`self.min_connections`][Self::min_connections] is nonzero,
/// when the background monitor task (async runtime) or thread (blocking runtime) is spawned.
///
/// If you prefer to establish a minimum number of connections on startup to ensure a valid
/// configuration, use [`.connect()`][Self::connect()] instead.
///
/// See [`Self::build_with()`] for a version that lets you pass a [`ConnectOptions`].
pub fn build(self, uri: &str) -> crate::Result<Pool<Rt, C>> {
Ok(self.build_with(uri.parse()?))
}
/// Creates a new pool from this configuration.
///
/// Note that **this does not immediately connect to the database**;
/// this method call is infallible.
///
/// A connection will first be established either on the first call to
/// [`Pool::acquire()`][super::Pool::acquire()] or,
/// if [`self.min_connections`][Self::min_connections] is nonzero,
/// when the background monitor task (async runtime) or thread (blocking runtime) is spawned.
///
/// If you prefer to establish at least one connections on startup to ensure a valid
/// configuration, use [`.connect_with()`][Self::connect_with()] instead.
pub fn build_with(self, options: <C as Connect<Rt>>::Options) -> Pool<Rt, C> {
Pool { shared: SharedPool::new(self, options).into() }
}
}
#[cfg(feature = "async")]
impl<Rt: crate::Async, C: Connection<Rt>> PoolOptions<Rt, C> {
/// Perform an action after connecting to the database.
pub fn after_connect<F>(mut self, callback: F) -> Self
where
for<'c> F:
Fn(&'c mut C) -> futures_util::BoxFuture<'c, crate::Result<()>> + Send + Sync + 'static,
{
self.after_connect_async = Some(Box::new(callback));
self
}
/// If set, this callback is executed with a connection that has been acquired from the idle
/// queue.
///
/// If the callback returns `Ok`, the acquired connection is returned to the caller. If
/// it returns `Err`, the error is logged and the caller attempts to acquire another connection.
///
/// This overrides [`Self::test_before_acquire()`].
pub fn before_acquire<F>(mut self, callback: F) -> Self
where
for<'c> F: Fn(&'c mut C) -> futures_util::BoxFuture<'c, crate::Result<bool>>
+ Send
+ Sync
+ 'static,
{
self.before_acquire_async = Some(Box::new(callback));
self
}
/// Creates a new pool from this configuration and immediately establishes
/// [`self.min_connections`][Self::min_connections()],
/// or just one connection if `min_connections == 0`.
///
/// Returns an error if the URI fails to parse or an error occurs while establishing a connection.
///
/// See [`Self::connect_with()`] for a version that lets you pass a [`ConnectOptions`].
///
/// If you do not want to connect immediately on startup,
/// use [`.build()`][Self::build()] instead.
pub async fn connect(self, uri: &str) -> crate::Result<Pool<Rt, C>> {
self.connect_with(uri.parse()?).await
}
/// Creates a new pool from this configuration and immediately establishes
/// [`self.min_connections`][Self::min_connections()],
/// or just one connection if `min_connections == 0`.
///
/// Returns an error if an error occurs while establishing a connection.
///
/// If you do not want to connect immediately on startup,
/// use [`.build_with()`][Self::build_with()] instead.
pub async fn connect_with(
self,
options: <C as Connect<Rt>>::Options,
) -> crate::Result<Pool<Rt, C>> {
let mut shared = SharedPool::new(self, options);
shared.init_min_connections_async().await?;
Ok(Pool { shared: shared.into() })
}
}
#[cfg(feature = "blocking")]
impl<C: Connection<crate::Blocking>> PoolOptions<crate::Blocking, C> {
/// Perform an action after connecting to the database.
pub fn after_connect(
mut self,
callback: impl Fn(&mut C) -> crate::Result<()> + Send + Sync + 'static,
) -> Self {
self.after_connect_blocking = Some(Box::new(callback));
self
}
/// If set, this callback is executed with a connection that has been acquired from the idle
/// queue.
///
/// If the callback returns `Ok`, the acquired connection is returned to the caller. If
/// it returns `Err`, the error is logged and the caller attempts to acquire another connection.
///
/// This overrides [`Self::test_before_acquire()`].
pub fn before_acquire<F>(
mut self,
callback: impl Fn(&mut C) -> crate::Result<bool> + Send + Sync + 'static,
) -> Self {
self.before_acquire_blocking = Some(Box::new(callback));
self
}
/// Creates a new pool from this configuration and immediately establishes
/// [`self.min_connections`][Self::min_connections()],
/// or just one connection if `min_connections == 0`.
///
/// Returns an error if the URI fails to parse or an error occurs while establishing a connection.
///
/// See [`Self::connect_with()`] for a version that lets you pass a [`ConnectOptions`].
///
/// If you do not want to connect immediately on startup,
/// use [`.build()`][Self::build()] instead.
pub fn connect(self, uri: &str) -> crate::Result<Pool<Rt, C>> {
self.connect_with(uri.parse()?)
}
/// Creates a new pool from this configuration and immediately establishes
/// [`self.min_connections`][Self::min_connections()],
/// or just one connection if `min_connections == 0`.
///
/// Returns an error if an error occurs while establishing a connection.
///
/// If you do not want to connect immediately on startup,
/// use [`.build_with()`][Self::build_with()] instead.
pub fn connect_with(self, options: <C as Connect<Rt>>::Options) -> crate::Result<Pool<Rt, C>> {
let mut shared = SharedPool::new(self, options);
shared.init_min_connections_blocking()?;
Ok(Pool { shared: shared.into() })
}
}
impl<Rt: Runtime, C: Connection<Rt>> Debug for PoolOptions<Rt, C> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("PoolOptions")
.field("max_connections", &self.max_connections)
.field("min_connections", &self.min_connections)
.field("acquire_timeout", &self.acquire_timeout)
.field("max_lifetime", &self.max_lifetime)
.field("idle_timeout", &self.idle_timeout)
.finish()
}
}

266
sqlx/src/pool/shared.rs Normal file
View file

@ -0,0 +1,266 @@
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::Arc;
use std::time::Instant;
use std::{cmp, mem, ptr};
use crossbeam_queue::ArrayQueue;
use crate::pool::connection::{Floating, Idle, Pooled};
use crate::pool::options::PoolOptions;
use crate::pool::wait_list::WaitList;
use crate::{Acquire, Connect, Connection, Runtime};
pub struct SharedPool<Rt: Runtime, C: Connection<Rt>> {
idle: ArrayQueue<Idle<Rt, C>>,
wait_list: WaitList,
size: AtomicU32,
is_closed: AtomicBool,
pub(crate) pool_options: PoolOptions<Rt, C>,
connect_options: <C as Connect<Rt>>::Options,
}
/// RAII guard returned by `Pool::try_increment_size()` and others.
///
/// Will decrement the pool size if dropped, to avoid semantically "leaking" connections
/// (where the pool thinks it has more connections than it does).
pub struct DecrementSizeGuard<'pool> {
size: &'pool AtomicU32,
wait_list: &'pool WaitList,
dropped: bool,
}
// NOTE: neither of these may be `Copy` or `Clone`!
pub struct ConnectPermit<'pool>(DecrementSizeGuard<'pool>);
pub struct AcquirePermit<'pool>(&'pool AtomicU32); // just need a pointer to compare for sanity check
/// Returned by `SharedPool::try_acquire()`.
///
/// Compared to SQLx <= 0.5, the process of acquiring a connection is broken into distinct steps
/// in order to facilitate both blocking and nonblocking versions.
pub enum TryAcquireResult<'pool, Rt: Runtime, C: Connection<Rt>> {
/// A connection has been acquired from the idle queue.
///
/// Depending on the pool settings, it may still need to be tested for liveness before being
/// returned to the user.
Acquired(Floating<'pool, Idle<Rt, C>>),
/// The pool's current size dropped below its maximum and a new connection may be opened.
///
/// Call `.connect_async()` or `.connect_blocking()` with the given permit.
Connect(ConnectPermit<'pool>),
/// The task or thread should wait and call `.try_acquire()` again.
///
/// The inner value is the same `AcquirePermit` that was passed to `.try_acquire()`.
Wait,
/// The pool is closed; the attempt to acquire the connection should return an error.
PoolClosed,
}
impl<Rt: Runtime, C: Connection<Rt>> SharedPool<Rt, C> {
pub fn new(
pool_options: PoolOptions<Rt, C>,
connect_options: <C as Connect<Rt>>::Options,
) -> Self {
Self {
idle: ArrayQueue::new(pool_options.max_connections as usize),
wait_list: WaitList::new(),
size: AtomicU32::new(0),
is_closed: AtomicBool::new(false),
pool_options,
connect_options,
}
}
#[inline]
pub fn is_closed(&self) -> bool {
self.is_closed.load(Ordering::Acquire)
}
/// Attempt to acquire a connection.
///
/// If `permit` is `Some`,
pub fn try_acquire(&self, permit: Option<AcquirePermit<'_>>) -> TryAcquireResult<'_, C> {
use TryAcquireResult::*;
assert!(
permit.map_or(true, |permit| ptr::eq(&self.size, permit.0)),
"BUG: given AcquirePermit is from a different pool"
);
if self.is_closed() {
return PoolClosed;
}
// if the user has an `AcquirePermit`, then they've already waited at least once
// and we should try to get them a connection immediately if possible;
//
// otherwise, we can immediately return a connection or `ConnectPermit` if no one is waiting
if permit.is_some() || self.wait_list.is_empty() {
// try to pull a connection from the idle queue
if let Some(idle) = self.idle.pop() {
return Acquired(idle.float(self));
}
// try to bump `self.size`
if let Some(guard) = self.try_increment_size() {
return Connect(ConnectPermit(guard));
}
}
// check again after the others to make sure
if self.is_closed() {
return PoolClosed;
}
Wait
}
/// Attempt to increment the current size, failing if it would exceed the maximum size.
fn try_increment_size(&self) -> Option<DecrementSizeGuard<'_>> {
if self.is_closed() {
return None;
}
self.size
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |size| {
(size < todo!("self.options.max_connections")).then(|| size + 1)
})
.ok()
.map(|_| DecrementSizeGuard::new(self))
}
}
#[cfg(feature = "async")]
impl<Rt: crate::Async, C: Connection<Rt>> SharedPool<Rt, C> {
pub async fn wait_async(&self) -> AcquirePermit {
self.wait_list.wait().await;
AcquirePermit(&self.size)
}
pub async fn connect_async(
self: &Arc<Self>,
permit: ConnectPermit,
) -> crate::Result<Pooled<Rt, C>>
where
C: crate::Connect<Rt>,
{
assert!(permit.0.same_pool(self), "BUG: ConnectPermit is from a different pool!");
let mut conn = crate::Connect::connect_with(&self.connect_options)
.await
.map(|c| Floating::new_live(c, permit.0))?;
if let Some(ref after_connect) = self.pool_options.after_connect_async {
after_connect(&mut conn).await?;
}
Ok(conn.attach(self))
}
pub async fn on_acquire_async(
self: &Arc<Self>,
conn: &mut Floating<'_, C>,
) -> crate::Result<()> {
assert!(conn.same_pool(self), "BUG: connection is from a different pool");
if let Some(ref before_acquire) = self.pool_options.before_acquire_async {
before_acquire(conn).await?;
}
Ok(())
}
pub async fn init_min_connections_async<Rt: Runtime, C: Connection<Rt>>(
&mut self,
) -> crate::Result<()> {
for _ in 0..cmp::max(self.pool_options.min_connections, 1) {
// this guard will prevent us from exceeding `max_size`
if let Some(guard) = self.try_increment_size() {
// [connect] will raise an error when past deadline
let conn = self.connect_async(ConnectPermit(guard)).await?;
let is_ok = self.idle.push(conn.into_idle().into_leakable()).is_ok();
if !is_ok {
panic!("BUG: connection queue overflow in init_min_connections");
}
}
}
Ok(())
}
}
#[cfg(feature = "blocking")]
impl<C: Connection<crate::Blocking>> SharedPool<crate::Blocking, C> {
pub fn wait_blocking(&self, deadline: Option<Instant>) -> Option<AcquirePermit<'_>> {
self.wait_list.wait().block_on(deadline).then(|| AcquirePermit(&self.size))
}
pub fn connect_blocking(
self: &Arc<Self>,
permit: ConnectPermit<'_>,
) -> crate::Result<Pooled<crate::Blocking, C>>
where
C: crate::blocking::Connect<crate::Blocking>,
{
assert!(permit.0.same_pool(self), "BUG: ConnectPermit is from a different pool!");
crate::blocking::Connect::connect_with(&self.connect_options)
.map(|c| Floating::new_live(c, permit.0).attach(self))
}
pub fn on_acquire_blocking(self: &Arc<Self>, conn: &mut Floating<'_, C>) -> crate::Result<()> {
assert!(conn.same_pool(self), "BUG: connection is from a different pool");
if let Some(ref before_acquire) = self.pool_options.before_acquire_blocking {
before_acquire(conn)?;
}
Ok(())
}
pub fn init_min_connections_blocking<Rt: Runtime, C: Connection<Rt>>(
&mut self,
) -> crate::Result<()> {
for _ in 0..cmp::max(self.pool_options.min_connections, 1) {
// this guard will prevent us from exceeding `max_size`
if let Some(guard) = self.try_increment_size() {
// [connect] will raise an error when past deadline
let conn = self.connect_blocking(ConnectPermit(guard))?;
let is_ok = self.idle.push(conn.into_idle().into_leakable()).is_ok();
if !is_ok {
panic!("BUG: connection queue overflow in init_min_connections");
}
}
}
Ok(())
}
}
impl<'pool> DecrementSizeGuard<'pool> {
fn new<Rt: Runtime, C: Connection<Rt>>(pool: &'pool SharedPool<Rt, C>) -> Self {
Self { size: &pool.size, wait_list: &pool.wait_list, dropped: false }
}
/// Return `true` if the internal references point to the same fields in `SharedPool`.
pub fn same_pool<Rt: Runtime, C: Connection<Rt>>(
&self,
pool: &'pool SharedPool<Rt, C>,
) -> bool {
ptr::eq(self.size, &pool.size) && ptr::eq(self.wait_list, &pool.wait_list)
}
pub fn cancel(self) {
mem::forget(self);
}
}
impl Drop for DecrementSizeGuard<'_> {
fn drop(&mut self) {
assert!(!self.dropped, "double-dropped!");
self.dropped = true;
self.size.fetch_sub(1, Ordering::SeqCst);
self.wait_list.wake_one();
}
}

452
sqlx/src/pool/wait_list.rs Normal file
View file

@ -0,0 +1,452 @@
// see `SAFETY:` annotations
#![allow(unsafe_code)]
use parking_lot::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Context, Poll, Waker};
use std::thread::{self, Thread};
use std::time::Instant;
/// An intrusive list of waiting tasks.
///
/// Tasks wait by calling `.wait().await` for async code or `.wait().block_on(deadline)`
/// for blocking code where `deadline` is `Option<Instant>`
pub struct WaitList(RwLock<ListInner>);
struct ListInner {
// NOTE: these must either both be null or both be pointing to a node
/// The head of the list; if NULL then the list is empty.
head: *mut Node,
/// The tail of the list; if NULL then the list is empty.
tail: *mut Node,
}
// SAFETY: access to `Node` pointers must be protected by a lock
// this could potentially be made lock-free but the critical sections are short
// so using a lightweight RwLock like from `parking_lot` seemed reasonable
unsafe impl Send for ListInner {}
unsafe impl Sync for ListInner {}
impl WaitList {
pub fn new() -> Self {
WaitList(RwLock::new(ListInner { head: ptr::null_mut(), tail: ptr::null_mut() }))
}
pub fn is_empty(&self) -> bool {
let inner = self.0.read();
inner.head.is_null() && inner.tail.is_null()
}
pub fn wake_one(&self) {
self.0.read().wake(false)
}
pub fn wake_all(&self) {
self.0.read().wake(true)
}
/// Wait in this waitlist for a call to either `.wake_one()` or `.wake_all()`.
///
/// The returned handle may either be `.await`ed for async code, or you can call
/// `.block_on(deadline)` for blocking code, where `deadline` is the optional `Instant`
/// at which to stop waiting.
pub fn wait(&self) -> Wait<'_> {
Wait { list: &self.0, node: None, actually_woken: bool, _not_unpin: PhantomPinned }
}
}
impl ListInner {
/// Wake either one or all nodes in the list.
fn wake(&self, all: bool) {
let mut node_p: *const Node = inner.head;
// SAFETY: `node_p` is not dangling as long as we have at least a shared lock
// (implied by having `&self`)
while let Some(node) = unsafe { node_p.as_ref() } {
// `.wake()` only returns `true` if the node was not already woken
if node.wake() && !all {
break;
}
node_p = node.next;
}
}
}
pub struct Wait<'a> {
list: &'a RwLock<ListInner>,
/// SAFETY: `Node` must not be modified without a lock
/// SAFETY: `Node` may not be moved once it's entered in the list
node: Option<Node>,
actually_woken: bool,
_not_unpin: PhantomPinned,
}
/// cancel-safe
impl<'a> Future for Wait<'a> {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let node = self.get_node(|| Wake::Waker(cx.waker().clone()));
if node.woken.load(Ordering::Acquire) {
// SAFETY: not moving out of `self` here
unsafe { self.get_unchecked_mut().actually_woken = true }
Poll::Ready(())
} else {
let wake = RwLock::upgradable_read(&node.wake);
// make sure our `Waker` is up to date;
// the waker may change if the task moves between threads
if !wake.waker_eq(cx.waker()) {
*RwLockUpgradableReadGuard::upgrade(wake) = Wake::Waker(cx.waker().clone());
}
Poll::Pending
}
}
}
impl<'a> Wait<'a> {
/// Insert a node into the parent `WaitList` referred to by `self` and return it.
///
/// The provided closure should return the appropriate `Wake` variant for waking the calling
/// task.
fn get_node(self: Pin<&mut Self>, get_wake: impl FnOnce() -> Wake) -> &Node {
// SAFETY: `this.node` must not be moved once it's entered in the list
let this = unsafe { self.get_unchecked_mut() };
if let Some(ref node) = this.node {
node
} else {
// FIXME: use `Option::insert()` when stable
let node = this.node.get_or_insert_with(|| Node::new(get_wake()));
// SAFETY: we need an exclusive lock to modify the list
let mut list = this.list.write();
if list.head.is_null() {
// sanity check; see `ListInner` definition
assert!(list.tail.is_null());
// the list is empty so insert this node as both the head and tail
list.head = node;
list.tail = node;
} else {
// sanity check; see `ListInner` definition
assert!(!list.tail.is_null());
// the list is nonempty so insert this node as the tail
// SAFETY: `list.tail` is not null because of the above assert and
// not dangling as long as we have an exclusive lock for modifying the list
// (or any nodes in it)
unsafe {
// set the `next` pointer of the previous tail to this node
(*list.tail).next = node;
}
node.prev = list.tail;
list.tail = node;
}
node
}
}
/// Block until woken.
///
/// Returns `true` if we were woken without the deadline elapsing, `false` if the deadline elapsed.
/// If no deadline is set then this always returns `true` but *will block* until woken.
#[cfg(feature = "blocking")]
pub fn block_on(mut self, deadline: Option<Instant>) -> bool {
// SAFETY:`self.node` may not be moved once entered in the list (`.get_node()` is called)
let mut this = unsafe { Pin::new_unchecked(&mut self) };
let node = this.as_mut().get_node(|| Wake::Thread(thread::current()));
while !node.woken.load(Ordering::Acquire) {
if let Some(deadline) = deadline {
let now = Instant::now();
if deadline < now {
return false;
} else {
// N.B. may wake spuriously
thread::park_timeout(deadline - now);
}
} else {
// N.B. may return spuriously
thread::park();
}
}
// SAFETY: we're not moving out of `this` here
unsafe {
this.get_unchecked_mut().actually_woken = true;
}
true
}
}
// SAFETY: since futures must be pinned to be polled we can be sure that `Drop::drop()` is called
// because there's no way to leak a future without the memory location remaining valid for the
// life of the program:
// * can't be moved into `mem::forget()` or an Rc-cycle because it's pinned
// * leaking `Pin<Box<Wait>>` or via Rc-cycle keeps it around forever, perfectly fine
// * aborting the program means it's not our problem anymore
//
// The only way this could cause memory issues is if the *thread* is aborted without unwinding
// or aborting the process, which doesn't have a safe API in Rust and the C APIs for canceling
// threads don't recommend doing it either for similar reasons.
// * https://man7.org/linux/man-pages/man3/pthread_exit.3.html#DESCRIPTION
// * https://docs.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-exitthread#remarks
//
// However, if Rust were to gain a safe API for instantly exiting a thread it would completely break
// the assumptions that the `Pin` API are built on so it's not something for us to worry about
// specifically.
impl<'a> Drop for Wait<'a> {
fn drop(&mut self) {
if let Some(node) = &self.node {
// if we were inserted into the list then remove the node from the list,
// linking the previous node (if applicable) to the next node (if applicable)
// SAFETY: we must have an exclusive lock while we're futzing with the list
let mut list = self.list.write();
// SAFETY: `prev` cannot be dangling while we have an exclusive lock
if let Some(prev) = unsafe { node.prev.as_mut() } {
// set the `next` pointer of the previous node to this node's `next` pointer
// note: `node.next` may be null which means we're the tail of the list
prev.next = node.next;
} else {
// we were the head of the list so we set the head to the next node
list.head = node.next;
}
// SAFETY: `next` cannot be dangling while we have an exclusive lock
if let Some(next) = unsafe { node.next.as_mut() } {
// set the `prev` pointer of the next node to this node's `prev` pointer
// note: `node.prev` may be null which means we're the head of the list
next.prev = node.prev;
} else {
// we were the tail of the list so we set the tail to the previous node
list.tail = node.prev;
}
// sanity check; see `ListInner` definition
assert_eq!(list.head.is_null(), list.tail.is_null());
// if this node was marked woken but we didn't actually wake,
// then we need to wake the next node in the list
if node.woken.load(Ordering::Acquire) && !self.actually_woken {
// we don't need an exclusive lock anymore
RwLockWriteGuard::downgrade(list).wake(false);
}
}
}
}
struct Node {
/// The previous node in the list. If NULL, then this node is the head of the list.
prev: *mut Node,
/// The next node in the list. If NULL, then this node is the tail of the list.
next: *mut Node,
woken: AtomicBool,
wake: RwLock<Wake>,
}
// SAFETY: access to `Node` pointers must be protected by a lock
unsafe impl Send for Node {}
unsafe impl Sync for Node {}
impl Node {
fn new(wake: Wake) -> Self {
Node {
prev: ptr::null_mut(),
next: ptr::null_mut(),
woken: AtomicBool::new(false),
wake: RwLock::new(wake),
}
}
/// Returns `true` if this node was woken by this call, `false` otherwise.
fn wake(&self) -> bool {
let do_wake =
self.woken.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire).is_ok();
if do_wake {
match &*self.wake.read() {
Wake::Waker(waker) => waker.wake_by_ref(),
#[cfg(feature = "blocking")]
Wake::Thread(thread) => thread.unpark(),
}
}
do_wake
}
}
enum Wake {
Waker(Waker),
#[cfg(feature = "blocking")]
Thread(Thread),
}
impl Wake {
fn waker_eq(&self, waker: &Waker) -> bool {
match self {
Self::Waker(waker_) => waker_.will_wake(waker),
#[cfg(feature = "blocking")]
_ => false,
}
}
}
// note: this test should take about 2 minutes to run!
#[test]
#[cfg(feature = "blocking")]
fn test_wait_list_blocking() {
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
const NUM_THREADS: u64 = 200;
let list = Arc::new(WaitList::new());
let mut threads = Vec::new();
// create an arbitrary pattern of deadlines; some of these may elapse, others may not
// the ultimate goal of this test is to make sure that no threads _deadlock_ or segfault
for i in 1..NUM_THREADS {
let ms = i + i * 25 % 100;
let deadline = (i < 100).then(|| Instant::now() + Duration::from_millis(ms));
let list = Arc::new(list.clone());
let thread = Arc::new(AtomicBool::new(false));
threads.push((thread.clone(), deadline));
thread::spawn(move || {
list.wait().block_on(deadline);
thread.store(true, Ordering::Release);
});
}
//
for _ in 1..NUM_THREADS {
thread::sleep(Duration::from_millis(5));
list.wake_one();
}
// wait enough time for all timeouts to elapse
thread::sleep(Duration::from_secs(60));
for (i, (thread, deadline)) in threads.iter().enumerate() {
assert!(
thread.load(Ordering::Acquire),
"thread {} did not exit; deadline: {:?}",
i,
deadline
);
}
}
// #[cfg(all(test, feature = "async"))]
// mod test_async {
// use super::WaitList;
//
// #[cfg(feature = "tokio")]
//
// async fn test_waiter_list() {
// use futures::future::{join_all, FutureExt};
// use futures::pin_mut;
// use std::sync::Arc;
// use std::time::Duration;
//
// let list = Arc::new(WaitList::new());
// let mut tasks = Vec::new();
//
// for _ in 0..1000 {
// let list = list.clone();
//
// tasks.push(spawn(async move {
// list.wait().await;
//
// list.wait().await;
// }));
// }
//
// let waker = async {
// loop {
// list.wake_one();
// yield_now().await;
// }
// }
// .fuse();
//
// let timeout = timeout(Duration::from_secs(10), join_all(tasks)).fuse();
//
// pin_mut!(waker);
// pin_mut!(timeout);
//
// futures::select_biased!(
// res = timeout => res.expect("all tasks should have exited by now"),
// _ = waker => unreachable!("waker shouldn't have quit"),
// );
// }
// }
//
// // N.B. test will run forever
// #[test]
// #[ignore]
// fn test_waiter_list_forever() {
// use async_std::{
// future::{timeout, Future},
// task,
// };
// use futures::future::poll_fn;
// use futures::pin_mut;
// use futures::stream::{FuturesUnordered, StreamExt};
// use std::sync::Arc;
// use std::thread;
// use std::time::Duration;
//
// let list = Arc::new(WaitList::new());
//
// let list_ = list.clone();
// task::spawn(async move {
// let mut unordered = FuturesUnordered::new();
//
// loop {
// unordered.push(WaitList::wait(&list_));
// let _ = timeout(Duration::from_millis(50), unordered.next()).await;
// }
// });
//
// let list_ = list.clone();
// task::spawn(poll_fn::<(), _>(move |cx| {
// let yielder = task::yield_now();
// pin_mut!(yielder);
// let _ = yielder.poll(cx);
//
// let park = WaitList::wait(&list_);
// pin_mut!(park);
// let _ = park.poll(cx);
//
// Poll::Pending
// }));
//
// for num in (0..5).cycle() {
// for _ in 0..num {
// list.wake_one();
// }
//
// thread::sleep(Duration::from_millis(50));
// }
// }