WIP execution refactors

This commit is contained in:
Austin Bonander 2024-04-16 01:57:09 -07:00
parent 6a4f61e3b3
commit b5b981d54f
11 changed files with 628 additions and 197 deletions

View file

@ -66,7 +66,7 @@ futures-channel = { version = "0.3.19", default-features = false, features = ["s
futures-core = { version = "0.3.19", default-features = false }
futures-io = "0.3.24"
futures-intrusive = "0.5.0"
futures-util = { version = "0.3.19", default-features = false, features = ["alloc", "sink", "io"] }
futures-util = { version = "0.3.30", default-features = false, features = ["alloc", "sink", "io"] }
hex = "0.4.3"
log = { version = "0.4.14", default-features = false }
memchr = { version = "2.4.1", default-features = false }

View file

@ -58,6 +58,7 @@ use std::fmt::Debug;
use crate::arguments::Arguments;
use crate::column::Column;
use crate::connection::Connection;
use crate::result_set::ResultSet;
use crate::row::Row;
use crate::statement::Statement;
@ -82,6 +83,8 @@ pub trait Database: 'static + Sized + Send + Debug {
/// The concrete `QueryResult` implementation for this database.
type QueryResult: 'static + Sized + Send + Sync + Default + Extend<Self::QueryResult>;
type ResultSet: ResultSet;
/// The concrete `Column` implementation for this database.
type Column: Column<Database = Self>;

View file

@ -5,6 +5,7 @@ use std::mem;
use crate::database::Database;
/// The return type of [Encode::encode].
#[must_use = "NULL values may write no data to the argument buffer"]
pub enum IsNull {
/// The value is null; no data was written.
Yes,
@ -17,20 +18,10 @@ pub enum IsNull {
/// Encode a single value to be sent to the database.
pub trait Encode<'q, DB: Database> {
/// Writes the value of `self` into `buf` in the expected format for the database.
#[must_use]
fn encode(self, buf: &mut <DB as Database>::ArgumentBuffer<'q>) -> IsNull
where
Self: Sized,
{
self.encode_by_ref(buf)
}
/// Writes the value of `self` into `buf` without moving `self`.
///
/// Where possible, make use of `encode` instead as it can take advantage of re-using
/// memory.
#[must_use]
fn encode_by_ref(&self, buf: &mut <DB as Database>::ArgumentBuffer<'q>) -> IsNull;
fn produces(&self) -> Option<DB::TypeInfo> {

View file

@ -2,11 +2,11 @@ use crate::database::Database;
use crate::describe::Describe;
use crate::error::Error;
use either::Either;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_util::{future, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
use std::fmt::Debug;
use std::marker::PhantomData;
use crate::query_string::QueryString;
use crate::result_set::ResultSet;
/// A type that contains or can provide a database
/// connection to use for executing queries against the database.
@ -32,110 +32,19 @@ use std::fmt::Debug;
///
pub trait Executor<'c>: Send + Debug + Sized {
type Database: Database;
type ResultSet: ResultSet<
Database = Self::Database,
Row = <Self::Database as Database>::Row
>;
/// Execute the query and return the total number of rows affected.
fn execute<'e, 'q: 'e, E: 'q>(
self,
query: E,
) -> BoxFuture<'e, Result<<Self::Database as Database>::QueryResult, Error>>
where
'c: 'e,
E: Execute<'q, Self::Database>,
{
self.execute_many(query).try_collect().boxed()
}
/// Execute a query as an implicitly prepared statement.
async fn execute_prepared(&mut self, params: ExecutePrepared<'_, Self::Database>) -> Self::ResultSet;
/// Execute multiple queries and return the rows affected from each query, in a stream.
fn execute_many<'e, 'q: 'e, E: 'q>(
self,
query: E,
) -> BoxStream<'e, Result<<Self::Database as Database>::QueryResult, Error>>
where
'c: 'e,
E: Execute<'q, Self::Database>,
{
self.fetch_many(query)
.try_filter_map(|step| async move {
Ok(match step {
Either::Left(rows) => Some(rows),
Either::Right(_) => None,
})
})
.boxed()
}
/// Execute the query and return the generated results as a stream.
fn fetch<'e, 'q: 'e, E: 'q>(
self,
query: E,
) -> BoxStream<'e, Result<<Self::Database as Database>::Row, Error>>
where
'c: 'e,
E: Execute<'q, Self::Database>,
{
self.fetch_many(query)
.try_filter_map(|step| async move {
Ok(match step {
Either::Left(_) => None,
Either::Right(row) => Some(row),
})
})
.boxed()
}
/// Execute multiple queries and return the generated results as a stream
/// from each query, in a stream.
fn fetch_many<'e, 'q: 'e, E: 'q>(
self,
query: E,
) -> BoxStream<
'e,
Result<
Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
Error,
>,
>
where
'c: 'e,
E: Execute<'q, Self::Database>;
/// Execute the query and return all the generated results, collected into a [`Vec`].
fn fetch_all<'e, 'q: 'e, E: 'q>(
self,
query: E,
) -> BoxFuture<'e, Result<Vec<<Self::Database as Database>::Row>, Error>>
where
'c: 'e,
E: Execute<'q, Self::Database>,
{
self.fetch(query).try_collect().boxed()
}
/// Execute the query and returns exactly one row.
fn fetch_one<'e, 'q: 'e, E: 'q>(
self,
query: E,
) -> BoxFuture<'e, Result<<Self::Database as Database>::Row, Error>>
where
'c: 'e,
E: Execute<'q, Self::Database>,
{
self.fetch_optional(query)
.and_then(|row| match row {
Some(row) => future::ok(row),
None => future::err(Error::RowNotFound),
})
.boxed()
}
/// Execute the query and returns at most one row.
fn fetch_optional<'e, 'q: 'e, E: 'q>(
self,
query: E,
) -> BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
where
'c: 'e,
E: Execute<'q, Self::Database>;
/// Execute raw SQL without creating a prepared statement.
///
/// The SQL string may contain multiple statements separated by semicolons (`;`)
/// as well as DDL (`CREATE TABLE`, `ALTER TABLE`, etc.).
async fn execute_raw(&mut self, params: ExecuteRaw<'_, Self::Database>) -> Self::ResultSet;
/// Prepare the SQL query to inspect the type information of its parameters
/// and results.
@ -146,14 +55,14 @@ pub trait Executor<'c>: Send + Debug + Sized {
/// This explicit API is provided to allow access to the statement metadata available after
/// it prepared but before the first row is returned.
#[inline]
fn prepare<'e, 'q: 'e>(
async fn prepare<'e, 'q: 'e>(
self,
query: &'q str,
) -> BoxFuture<'e, Result<<Self::Database as Database>::Statement<'q>, Error>>
) -> Result<<Self::Database as Database>::Statement<'q>, Error>
where
'c: 'e,
{
self.prepare_with(query, &[])
self.prepare_with(query, &[]).await
}
/// Prepare the SQL query, with parameter type information, to inspect the
@ -161,11 +70,11 @@ pub trait Executor<'c>: Send + Debug + Sized {
///
/// Only some database drivers (PostgreSQL, MSSQL) can take advantage of
/// this extra information to influence parameter type inference.
fn prepare_with<'e, 'q: 'e>(
async fn prepare_with<'e, 'q: 'e>(
self,
sql: &'q str,
parameters: &'e [<Self::Database as Database>::TypeInfo],
) -> BoxFuture<'e, Result<<Self::Database as Database>::Statement<'q>, Error>>
) -> Result<<Self::Database as Database>::Statement<'q>, Error>
where
'c: 'e;
@ -183,73 +92,42 @@ pub trait Executor<'c>: Send + Debug + Sized {
'c: 'e;
}
/// A type that may be executed against a database connection.
///
/// Implemented for the following:
///
/// * [`&str`](std::str)
/// * [`Query`](super::query::Query)
///
pub trait Execute<'q, DB: Database>: Send + Sized {
/// Gets the SQL that will be executed.
fn sql(&self) -> &'q str;
/// Gets the previously cached statement, if available.
fn statement(&self) -> Option<&DB::Statement<'q>>;
/// Returns the arguments to be bound against the query string.
/// Arguments struct for [`Executor::execute_prepared()`].
pub struct ExecutePrepared<'q, DB: Database> {
/// The SQL string to execute.
pub query: QueryString<'_>,
/// The bind arguments for the query string; must match the number of placeholders.
pub arguments: <DB as Database>::Arguments<'q>,
/// The maximum number of rows to return.
///
/// Returning `None` for `Arguments` indicates to use a "simple" query protocol and to not
/// prepare the query. Returning `Some(Default::default())` is an empty arguments object that
/// will be prepared (and cached) before execution.
fn take_arguments(&mut self) -> Option<<DB as Database>::Arguments<'q>>;
/// Returns `true` if the statement should be cached.
fn persistent(&self) -> bool;
/// Set to `Some(0)` to just get the result.
pub limit: Option<u64>,
/// The number of rows to request from the database at a time.
///
/// This is the maximum number of rows that will be buffered in-memory.
///
/// This will also be the maximum number of rows that need to be read and discarded should the
/// [`ResultSet`] be dropped early.
pub buffer: Option<usize>,
/// If `true`, prepare the statement with a name and cache it for later re-use.
pub persistent: bool,
_db: PhantomData<DB>
}
// NOTE: `Execute` is explicitly not implemented for String and &String to make it slightly more
// involved to write `conn.execute(format!("SELECT {val}"))`
impl<'q, DB: Database> Execute<'q, DB> for &'q str {
#[inline]
fn sql(&self) -> &'q str {
self
}
#[inline]
fn statement(&self) -> Option<&DB::Statement<'q>> {
None
}
#[inline]
fn take_arguments(&mut self) -> Option<<DB as Database>::Arguments<'q>> {
None
}
#[inline]
fn persistent(&self) -> bool {
true
}
}
impl<'q, DB: Database> Execute<'q, DB> for (&'q str, Option<<DB as Database>::Arguments<'q>>) {
#[inline]
fn sql(&self) -> &'q str {
self.0
}
#[inline]
fn statement(&self) -> Option<&DB::Statement<'q>> {
None
}
#[inline]
fn take_arguments(&mut self) -> Option<<DB as Database>::Arguments<'q>> {
self.1.take()
}
#[inline]
fn persistent(&self) -> bool {
true
}
/// Arguments struct for [`Executor::execute_raw()`].
pub struct ExecuteRaw<'q, DB: Database> {
/// The SQL string to execute.
pub query: QueryString<'_>,
/// The maximum number of rows to return.
///
/// Set to `Some(0)` to just get the result.
pub limit: Option<u64>,
/// The number of rows to request from the database at a time.
///
/// This is the maximum number of rows that will be buffered in-memory.
///
/// This will also be the maximum number of rows that need to be read and discarded should the
/// [`ResultSet`] be dropped early.
pub buffer: Option<usize>,
_db: PhantomData<DB>
}

View file

@ -75,7 +75,11 @@ pub mod query_as;
pub mod query_builder;
pub mod query_scalar;
pub mod query_string;
pub mod raw_sql;
pub mod result_set;
pub mod row;
pub mod rt;
pub mod sync;

View file

@ -594,7 +594,7 @@ impl<DB: Database> DecrementSizeGuard<DB> {
pub fn from_permit(pool: Arc<PoolInner<DB>>, permit: AsyncSemaphoreReleaser<'_>) -> Self {
// here we effectively take ownership of the permit
permit.disarm();
permit.consume();
Self::new_permit(pool)
}

View file

@ -9,13 +9,14 @@ use crate::database::{Database, HasStatementCache};
use crate::encode::Encode;
use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::query_string::QueryString;
use crate::statement::Statement;
use crate::types::Type;
/// A single SQL query as a prepared statement. Returned by [`query()`].
#[must_use = "query must be executed to affect database"]
pub struct Query<'q, DB: Database, A> {
pub(crate) statement: Either<&'q str, &'q DB::Statement<'q>>,
pub(crate) statement: Either<QueryString<'q>, &'q DB::Statement<'q>>,
pub(crate) arguments: Option<A>,
pub(crate) database: PhantomData<DB>,
pub(crate) persistent: bool,

View file

@ -0,0 +1,115 @@
use std::sync::Arc;
/// A SQL string that is safe to execute on a database connection.
///
/// A "safe" query string is one that is unlikely to contain a [SQL injection vulnerability][injection].
///
/// In practice, this means a string type that is unlikely to contain dynamic data or user input.
///
/// This is designed to act as a speedbump against naively using `format!()` to add dynamic data
/// or user input to a query, which is a classic vector for SQL injection as SQLx does not
/// provide any sort of escaping or sanitization (which would have to be specially implemented
/// for each database flavor/locale).
///
/// The recommended way to incorporate dynamic data or user input in a query is to use
/// bind parameters, which requires the query to execute as a prepared statement.
/// See [`query()`] for details.
///
/// `&'static str` is the only string type that satisfies the requirements of this trait
/// (ignoring [`String::leak()`] which has niche use-cases) and so is the only string type that
/// natively implements this trait by default.
///
/// For other string types, use [`AssertQuerySafe`] to assert this property.
/// This is the only intended way to pass an owned `String` to [`query()`] and its related functions.
///
/// This trait and `AssertQuerySafe` are intentionally analogous to [`std::panic::UnwindSafe`] and
/// [`std::panic::AssertUnwindSafe`].
///
/// [injection]: https://en.wikipedia.org/wiki/SQL_injection
/// [`query()`]: crate::query::query
pub trait QuerySafeStr<'a> {
///
fn wrap(self) -> QueryString<'a>;
}
impl QuerySafeStr<'static> for &'static str {
fn wrap(self) -> QueryString<'static> {
QueryString(Repr::Slice(self))
}
}
/// Assert that some string type is safe to execute on a database connection.
///
/// Using this API means that **you** have made sure that the string contents do not contain a
/// [SQL injection vulnerability][injection]. It means that, if the string was constructed
/// dynamically, and/or from user input, you have taken care to sanitize the input yourself.
///
/// The maintainers of SQLx take no responsibility for any data leaks or loss resulting from the use
/// of this API. **Use at your own risk.**
///
/// Note that `&'static str` implements [`QuerySafeStr`] directly and so does not need to be wrapped
/// with this type.
///
/// [injection]: https://en.wikipedia.org/wiki/SQL_injection
pub struct AssertQuerySafe<T>(pub T);
impl<'a> QuerySafeStr<'a> for AssertQuerySafe<&'a str> {
fn wrap(self) -> QueryString<'a> {
QueryString(Repr::Slice(self.0))
}
}
impl QuerySafeStr<'static> for AssertQuerySafe<String> {
fn wrap(self) -> QueryString<'static> {
// For `Repr` to not be 4 words wide, we convert `String` to `Box<str>`
QueryString(Repr::Boxed(self.0.into()))
}
}
impl QuerySafeStr<'static> for AssertQuerySafe<Box<str>> {
fn wrap(self) -> QueryString<'static> {
QueryString(Repr::Boxed(self.0))
}
}
// Note: this is not implemented for `Rc<str>` because it would make `QueryString: !Send`.
impl QuerySafeStr<'static> for AssertQuerySafe<Arc<str>> {
fn wrap(self) -> QueryString<'static> {
QueryString(Repr::Arced(self.into()))
}
}
/// A SQL string that is ready to execute on a database connection.
///
/// This is essentially `Cow<'a, str>` but which can be constructed from additional types
/// without copying.
///
/// See [`QuerySafeStr`] for details.
pub struct QueryString<'a>(Repr<'a>);
impl<'a> QuerySafeStr<'a> for QueryString<'a> {
fn wrap(self) -> QueryString<'a> {
self
}
}
impl QueryString<'_> {
pub fn into_static(self) -> QueryString<'static> {
QueryString(match self.0 {
Repr::Slice(s) => Repr::Boxed(s.into()),
Repr::StaticSlice(s) => Repr::StaticSlice(s),
Repr::Boxed(s) => Repr::Boxed(s),
Repr::Arced(s) => Repr::Arced(s),
})
}
}
enum Repr<'a> {
Slice(&'a str),
// We need a variant to memoize when we already have a static string, so we don't copy it.
StaticSlice(&'static str),
// This enum would be 4 words wide if this variant existed. Instead, convert to `Box<str>`.
// Owned(String),
Boxed(Box<str>),
Arced(Arc<str>),
}

144
sqlx-core/src/result_set.rs Normal file
View file

@ -0,0 +1,144 @@
use std::marker::PhantomData;
use std::mem;
use either::Either;
use crate::database::Database;
use crate::from_row::FromRow;
use crate::row::Row;
use crate::sync::spsc;
pub trait ResultSet {
type Database: Database;
type Row;
/// Wait for the server to return the next row in the result set.
///
/// Returns `Ok(None)` when the result set is exhausted. The result of the query
/// will be available from [`.next_result()`][Self::next_result].
///
/// If the query was capable of returning multiple result sets, this will start returning
/// rows again after returning `Ok(None)`.
///
/// This clears the stored result if one was available but `.next_result()` was not called.
async fn next_row(&mut self) -> crate::Result<Option<Self::Row>>;
/// Wait for the next query result, giving the number of rows affected.
///
/// If [`.next_row()`][Self::next_row] returned `Ok(None)`, the result should be cached
/// internally and this should return immediately.
///
/// If there are any rows buffered before the next query result, this will discard them.
async fn next_result(&mut self) -> crate::Result<Option<<Self::Database as Database>::QueryResult>>;
fn map_row<F, R>(self, map: F) -> MapRow<Self, R, F> where F: FnMut(Self::Row) -> crate::Result<R> {
MapRow {
map,
inner: self,
row: PhantomData
}
}
fn map_from_row<T>(self) -> MapFromRow<Self, T>
where
Self::Row: Row<Database = Self::Database>,
T: for<'r> FromRow<'r, Self::Row>
{
self.map_row(|row| T::from_row(&row))
}
async fn collect_rows<T: Default + Extend<Self::Row>>(&mut self) -> crate::Result<T> {
let mut rows_out = T::default();
while let Some(row) = self.next_row().await? {
rows_out.extend(Some(row));
}
Ok(rows_out)
}
}
pub struct MapRow<Rs, Row, F> {
map: F,
inner: Rs,
row: PhantomData<Row>,
}
impl<Rs, Row, F> ResultSet for MapRow<Rs, Row, F>
where
Rs: ResultSet,
F: FnMut(Rs::Row) -> Row
{
type Database = Rs::Database;
type Row = Row;
async fn next_row(&mut self) -> crate::Result<Option<Self::Row>> {
let maybe_row = self.inner.next_row().await?;
Ok(maybe_row.map(&mut self.map))
}
async fn next_result(&mut self) -> crate::Result<Option<<Self::Database as Database>::QueryResult>> {
self.inner.next_result()
}
}
pub type MapFromRow<Rs: ResultSet, Row> = MapRow<Rs, Row, fn(Rs::Row) -> Row>;
pub struct ChannelResultSet<DB: Database> {
flavor: Flavor<DB>,
last_result: Option<DB::QueryResult>,
}
enum Flavor<DB: Database> {
Channel(spsc::Receiver<crate::Result<Either<DB::QueryResult, DB::Row>>>),
Error(crate::Error),
Empty,
}
impl<DB: Database> ResultSet for ChannelResultSet<DB> {
type Database = DB;
type Row = DB::Row;
async fn next_row(&mut self) -> crate::Result<Option<Self::Row>> {
// Clear the previous result if it was ignored.
self.last_result = None;
match self.flavor.recv().await? {
Some(Either::Left(result)) => {
self.last_result = Some(result);
return Ok(None);
}
Some(Either::Right(row)) => Ok(Some(row)),
None => Ok(None),
}
}
async fn next_result(&mut self) -> crate::Result<Option<<Self::Database as Database>::QueryResult>> {
if let Some(result) = self.last_result.take() {
return Ok(Some(result));
}
loop {
match self.flavor.recv().await? {
Some(Either::Left(result)) => return Ok(Some(result)),
// Drop rows until the next result.
Some(Either::Right(_)) => (),
None => return Ok(None),
}
}
}
}
impl<DB: Database> Flavor<DB> {
async fn recv(&mut self) -> crate::Result<Option<Either<DB::QueryResult, DB::Row>>> {
match self {
Self::Channel(chan) => chan.recv().await.transpose(),
Self::Error(_) => {
let Self::Error(e) = mem::replace(self, Self::Empty) else {
unreachable!()
};
Err(e)
}
Self::Empty => Ok(None)
}
}
}

View file

@ -1,15 +1,12 @@
use std::ops::{Deref, DerefMut};
pub mod spsc;
// For types with identical signatures that don't require runtime support,
// we can just arbitrarily pick one to use based on what's enabled.
//
// We'll generally lean towards Tokio's types as those are more featureful
// (including `tokio-console` support) and more widely deployed.
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
pub use async_std::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard};
#[cfg(feature = "_rt-tokio")]
pub use tokio::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard};
pub struct AsyncSemaphore {
// We use the semaphore from futures-intrusive as the one from async-std
// is missing the ability to add arbitrary permits, and is not guaranteed to be fair:
@ -125,7 +122,7 @@ pub struct AsyncSemaphoreReleaser<'a> {
}
impl AsyncSemaphoreReleaser<'_> {
pub fn disarm(self) {
pub fn consume(self) {
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
{
let mut this = self;
@ -143,3 +140,65 @@ impl AsyncSemaphoreReleaser<'_> {
crate::rt::missing_rt(())
}
}
pub struct AsyncMutex<T> {
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
inner: async_std::sync::Mutex<T>,
#[cfg(feature = "_rt-tokio")]
inner: tokio::sync::Mutex<T>,
}
pub struct AsyncMutexGuard<'a, T> {
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
inner: async_std::sync::MutexGuard<'a, T>,
#[cfg(feature = "_rt-tokio")]
inner: tokio::sync::MutexGuard<'a, T>,
}
impl<T> AsyncMutex<T> {
pub fn new(value: T) -> Self {
if cfg!(not(any(feature = "_rt-async-std", feature = "_rt-tokio"))) {
crate::rt::missing_rt(value);
}
AsyncMutex {
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
inner: async_std::sync::Mutex::new(value),
#[cfg(feature = "_rt-tokio")]
inner: tokio::sync::Mutex::new(value),
}
}
pub async fn lock(&self) -> AsyncMutexGuard<'_, T> {
AsyncMutexGuard {
inner: self.inner.lock().await,
}
}
pub fn try_lock(&self) -> Option<AsyncMutexGuard<'_, T>> {
Some(AsyncMutexGuard {
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
inner: self.inner.try_lock()?,
#[cfg(feature = "_rt-tokio")]
inner: self.inner.try_lock().ok()?,
})
}
}
impl<T> Deref for AsyncMutexGuard<'_, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<T> DerefMut for AsyncMutexGuard<'_, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}

236
sqlx-core/src/sync/spsc.rs Normal file
View file

@ -0,0 +1,236 @@
//! A cooperatively bounded SPSC channel.
//!
//! Senders may either obey the channel capacity but will have to wait when it is exhausted
//! ([`Sender::send`]) or ignore the channel capacity when necessary ([`Sender::send_unbounded()`]),
//! e.g. in a `Drop` impl where neither blocking nor waiting is acceptable.
use std::collections::VecDeque;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::Poll;
use futures_util::task::AtomicWaker;
use crate::sync::{AsyncMutex, AsyncSemaphore};
/// The sender side of a cooperatively bounded SPSC channel.
///
/// The channel is closed when either this or [`Receiver`] is dropped.
pub struct Sender<T> {
channel: Arc<Channel<T>>,
}
/// The receiver side of a cooperatively bounded SPSC channel.
///
/// The channel is closed when either this or [`Sender`] is dropped.
pub struct Receiver<T> {
channel: Arc<Channel<T>>,
}
struct Channel<T> {
buffer: Buffer<T>,
closed: AtomicBool,
capacity: usize,
semaphore: AsyncSemaphore,
recv_waker: AtomicWaker,
// Ensure we only attempt to read from the channel if this was a legitimate wakeup.
receiver_woken: AtomicBool,
}
struct Buffer<T> {
// Using double-buffering so the sender and receiver aren't constantly fighting
front: AsyncMutex<Deque<T>>,
back: AsyncMutex<Deque<T>>,
write_front: AtomicBool,
}
struct Deque<T> {
messages: VecDeque<T>,
/// Each value represents the number of bounded sends at the head of the queue.
///
/// Pushing a message that used a permit should increment the value at the back of this queue,
/// or push a new value of 1 if the value is 255 or 0, or the queue is empty.
///
/// Pushing a message that did _not_ use a permit (unbounded send) should simultaneously
/// push a zero value to this queue.
///
/// Popping a message should decrement the value at the head of this queue,
/// removing it when it reaches 0. The receiver should simultaneously add a permit
/// back to the channel semaphore.
///
/// A 0 value at the head of the queue thus indicates an unbounded send,
/// which should _not_ result in the release of a permit.
permits_used: VecDeque<u8>,
}
impl<T> Sender<T> {
/// Send a message or wait for a permit to be released from the receiver.
///
/// ### Cancel-Safe
/// This method is entirely cancel-safe. It is guaranteed to _only_ send the message
/// immediately before it returns.
///
/// Contrast this to [`flume::Sender::send_async()`][flume-cancel-safe], where
/// the message _may_ be received any time after the first poll if it does not return `Ready`
/// immediately.
///
/// [flume-cancel-safe]: https://github.com/zesterer/flume/issues/104#issuecomment-1216387210
pub async fn send(&mut self, message: T) -> Result<(), T> {
if self.channel.closed.load(Ordering::Acquire) {
return Err(message);
}
let permit = self.channel.semaphore.acquire(1).await;
self.send_inner(message, true).map(|_| permit.consume())
}
/// Send a message immediately.
///
/// ### Note: Only Use when Necessary
/// This call ignores channel capacity and should only be used where blocking or waiting
/// is not an option, e.g. in a `Drop` impl.
pub fn send_unbounded(&mut self, message: T) -> Result<(), T> {
self.send_inner(message, false)
}
fn send_inner(&mut self, message: T, permit_used: bool) -> Result<(), T> {
if self.channel.closed.load(Ordering::Acquire) {
return Err(message);
}
self.channel.buffer.write(message, permit_used);
self.channel.receiver_woken.store(true, Ordering::Release);
self.channel.recv_waker.wake();
Ok(())
}
}
impl<T> Drop for Sender<T> {
fn drop(&mut self) {
self.channel.closed.store(true, Ordering::Release);
self.channel.recv_waker.wake();
}
}
impl<T> Receiver<T> {
pub async fn recv(&mut self) -> Option<T> {
loop {
if self.channel.closed.load(Ordering::Acquire) {
return None;
}
if let Some(message) = self.channel.buffer.read().await {
return Some(message);
}
futures_util::future::poll_fn(|cx| {
let ready = self.channel.closed.load(Ordering::Acquire)
|| self.channel.receiver_woken.load(Ordering::Acquire);
// Clear the `receiver_woken` flag.
self.channel.receiver_woken.store(false, Ordering::Release);
if ready {
Poll::Ready(())
} else {
// Ensure the waker is up-to-date every time we're polled.
self.channel.recv_waker.register(cx.waker());
Poll::Pending
}
})
.await;
}
}
}
impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
self.channel.closed.store(true, Ordering::Relaxed);
self.channel.recv_waker.take();
self.channel.semaphore.release(self.channel.capacity);
}
}
impl<T> Buffer<T> {
pub fn write(&self, value: T, permit_used: bool) {
let mut side = if self.write_front.load(Ordering::Acquire) {
self.front
.try_lock()
.expect("BUG: receiver has front buffer locked while reading back buffer")
} else {
self.back
.try_lock()
.expect("BUG: receiver has back buffer locked while reading front buffer")
};
side.messages.push_back(value);
if permit_used {
side.permits_used = side.permits_used.checked_add(1)
.expect("BUG: permits_used overflowed!");
}
}
pub async fn read(&self) -> Option<(T, bool)> {
// If the sender is writing the front, we should read the back and vice versa.
let read_back = self.write_front.load(Ordering::Acquire);
let mut side = if read_back {
// If we just swapped buffers, we may need to wait for the sender to release the lock.
self.back.lock().await
} else {
self.front.lock().await
};
let val = side.messages.pop_front();
// It doesn't actually matter if this exact message actually used a permit,
// it just matters that we made room in the channel.
let permit_used = side.permits_used.checked_sub(1)
.is_some_and(|permits_used| {
side.permits_used = permits_used;
true
});
// Note: be sure to release the lock before swapping or `write()` will panic.
drop(side);
if val.is_none() {
// This side is empty; swap.
self.write_front.store(!read_back, Ordering::Release);
}
val.map(|val| (val, permit_used))
}
}
pub fn channel<T>(bounded_capacity: usize) -> (Sender<T>, Receiver<T>) {
let channel = Arc::new(Channel {
buffer: Buffer {
front: AsyncMutex::new(Deque {
messages: VecDeque::with_capacity(bounded_capacity),
permits_used: 0,
}),
back: AsyncMutex::new(Deque {
messages: VecDeque::with_capacity(bounded_capacity),
permits_used: 0,
}),
write_front: true.into()
},
closed: false.into(),
capacity: bounded_capacity,
semaphore: AsyncSemaphore::new(true, bounded_capacity),
recv_waker: Default::default(),
receiver_woken: false.into(),
});
(
Sender {
channel: channel.clone(),
},
Receiver {
channel
}
)
}