mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 06:24:16 +00:00
WIP execution refactors
This commit is contained in:
parent
6a4f61e3b3
commit
b5b981d54f
11 changed files with 628 additions and 197 deletions
|
@ -66,7 +66,7 @@ futures-channel = { version = "0.3.19", default-features = false, features = ["s
|
|||
futures-core = { version = "0.3.19", default-features = false }
|
||||
futures-io = "0.3.24"
|
||||
futures-intrusive = "0.5.0"
|
||||
futures-util = { version = "0.3.19", default-features = false, features = ["alloc", "sink", "io"] }
|
||||
futures-util = { version = "0.3.30", default-features = false, features = ["alloc", "sink", "io"] }
|
||||
hex = "0.4.3"
|
||||
log = { version = "0.4.14", default-features = false }
|
||||
memchr = { version = "2.4.1", default-features = false }
|
||||
|
|
|
@ -58,6 +58,7 @@ use std::fmt::Debug;
|
|||
use crate::arguments::Arguments;
|
||||
use crate::column::Column;
|
||||
use crate::connection::Connection;
|
||||
use crate::result_set::ResultSet;
|
||||
use crate::row::Row;
|
||||
|
||||
use crate::statement::Statement;
|
||||
|
@ -82,6 +83,8 @@ pub trait Database: 'static + Sized + Send + Debug {
|
|||
/// The concrete `QueryResult` implementation for this database.
|
||||
type QueryResult: 'static + Sized + Send + Sync + Default + Extend<Self::QueryResult>;
|
||||
|
||||
type ResultSet: ResultSet;
|
||||
|
||||
/// The concrete `Column` implementation for this database.
|
||||
type Column: Column<Database = Self>;
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ use std::mem;
|
|||
use crate::database::Database;
|
||||
|
||||
/// The return type of [Encode::encode].
|
||||
#[must_use = "NULL values may write no data to the argument buffer"]
|
||||
pub enum IsNull {
|
||||
/// The value is null; no data was written.
|
||||
Yes,
|
||||
|
@ -17,20 +18,10 @@ pub enum IsNull {
|
|||
|
||||
/// Encode a single value to be sent to the database.
|
||||
pub trait Encode<'q, DB: Database> {
|
||||
/// Writes the value of `self` into `buf` in the expected format for the database.
|
||||
#[must_use]
|
||||
fn encode(self, buf: &mut <DB as Database>::ArgumentBuffer<'q>) -> IsNull
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
self.encode_by_ref(buf)
|
||||
}
|
||||
|
||||
/// Writes the value of `self` into `buf` without moving `self`.
|
||||
///
|
||||
/// Where possible, make use of `encode` instead as it can take advantage of re-using
|
||||
/// memory.
|
||||
#[must_use]
|
||||
fn encode_by_ref(&self, buf: &mut <DB as Database>::ArgumentBuffer<'q>) -> IsNull;
|
||||
|
||||
fn produces(&self) -> Option<DB::TypeInfo> {
|
||||
|
|
|
@ -2,11 +2,11 @@ use crate::database::Database;
|
|||
use crate::describe::Describe;
|
||||
use crate::error::Error;
|
||||
|
||||
use either::Either;
|
||||
use futures_core::future::BoxFuture;
|
||||
use futures_core::stream::BoxStream;
|
||||
use futures_util::{future, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
|
||||
use std::fmt::Debug;
|
||||
use std::marker::PhantomData;
|
||||
use crate::query_string::QueryString;
|
||||
use crate::result_set::ResultSet;
|
||||
|
||||
/// A type that contains or can provide a database
|
||||
/// connection to use for executing queries against the database.
|
||||
|
@ -32,110 +32,19 @@ use std::fmt::Debug;
|
|||
///
|
||||
pub trait Executor<'c>: Send + Debug + Sized {
|
||||
type Database: Database;
|
||||
type ResultSet: ResultSet<
|
||||
Database = Self::Database,
|
||||
Row = <Self::Database as Database>::Row
|
||||
>;
|
||||
|
||||
/// Execute the query and return the total number of rows affected.
|
||||
fn execute<'e, 'q: 'e, E: 'q>(
|
||||
self,
|
||||
query: E,
|
||||
) -> BoxFuture<'e, Result<<Self::Database as Database>::QueryResult, Error>>
|
||||
where
|
||||
'c: 'e,
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
self.execute_many(query).try_collect().boxed()
|
||||
}
|
||||
/// Execute a query as an implicitly prepared statement.
|
||||
async fn execute_prepared(&mut self, params: ExecutePrepared<'_, Self::Database>) -> Self::ResultSet;
|
||||
|
||||
/// Execute multiple queries and return the rows affected from each query, in a stream.
|
||||
fn execute_many<'e, 'q: 'e, E: 'q>(
|
||||
self,
|
||||
query: E,
|
||||
) -> BoxStream<'e, Result<<Self::Database as Database>::QueryResult, Error>>
|
||||
where
|
||||
'c: 'e,
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
self.fetch_many(query)
|
||||
.try_filter_map(|step| async move {
|
||||
Ok(match step {
|
||||
Either::Left(rows) => Some(rows),
|
||||
Either::Right(_) => None,
|
||||
})
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// Execute the query and return the generated results as a stream.
|
||||
fn fetch<'e, 'q: 'e, E: 'q>(
|
||||
self,
|
||||
query: E,
|
||||
) -> BoxStream<'e, Result<<Self::Database as Database>::Row, Error>>
|
||||
where
|
||||
'c: 'e,
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
self.fetch_many(query)
|
||||
.try_filter_map(|step| async move {
|
||||
Ok(match step {
|
||||
Either::Left(_) => None,
|
||||
Either::Right(row) => Some(row),
|
||||
})
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// Execute multiple queries and return the generated results as a stream
|
||||
/// from each query, in a stream.
|
||||
fn fetch_many<'e, 'q: 'e, E: 'q>(
|
||||
self,
|
||||
query: E,
|
||||
) -> BoxStream<
|
||||
'e,
|
||||
Result<
|
||||
Either<<Self::Database as Database>::QueryResult, <Self::Database as Database>::Row>,
|
||||
Error,
|
||||
>,
|
||||
>
|
||||
where
|
||||
'c: 'e,
|
||||
E: Execute<'q, Self::Database>;
|
||||
|
||||
/// Execute the query and return all the generated results, collected into a [`Vec`].
|
||||
fn fetch_all<'e, 'q: 'e, E: 'q>(
|
||||
self,
|
||||
query: E,
|
||||
) -> BoxFuture<'e, Result<Vec<<Self::Database as Database>::Row>, Error>>
|
||||
where
|
||||
'c: 'e,
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
self.fetch(query).try_collect().boxed()
|
||||
}
|
||||
|
||||
/// Execute the query and returns exactly one row.
|
||||
fn fetch_one<'e, 'q: 'e, E: 'q>(
|
||||
self,
|
||||
query: E,
|
||||
) -> BoxFuture<'e, Result<<Self::Database as Database>::Row, Error>>
|
||||
where
|
||||
'c: 'e,
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
self.fetch_optional(query)
|
||||
.and_then(|row| match row {
|
||||
Some(row) => future::ok(row),
|
||||
None => future::err(Error::RowNotFound),
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// Execute the query and returns at most one row.
|
||||
fn fetch_optional<'e, 'q: 'e, E: 'q>(
|
||||
self,
|
||||
query: E,
|
||||
) -> BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
|
||||
where
|
||||
'c: 'e,
|
||||
E: Execute<'q, Self::Database>;
|
||||
/// Execute raw SQL without creating a prepared statement.
|
||||
///
|
||||
/// The SQL string may contain multiple statements separated by semicolons (`;`)
|
||||
/// as well as DDL (`CREATE TABLE`, `ALTER TABLE`, etc.).
|
||||
async fn execute_raw(&mut self, params: ExecuteRaw<'_, Self::Database>) -> Self::ResultSet;
|
||||
|
||||
/// Prepare the SQL query to inspect the type information of its parameters
|
||||
/// and results.
|
||||
|
@ -146,14 +55,14 @@ pub trait Executor<'c>: Send + Debug + Sized {
|
|||
/// This explicit API is provided to allow access to the statement metadata available after
|
||||
/// it prepared but before the first row is returned.
|
||||
#[inline]
|
||||
fn prepare<'e, 'q: 'e>(
|
||||
async fn prepare<'e, 'q: 'e>(
|
||||
self,
|
||||
query: &'q str,
|
||||
) -> BoxFuture<'e, Result<<Self::Database as Database>::Statement<'q>, Error>>
|
||||
) -> Result<<Self::Database as Database>::Statement<'q>, Error>
|
||||
where
|
||||
'c: 'e,
|
||||
{
|
||||
self.prepare_with(query, &[])
|
||||
self.prepare_with(query, &[]).await
|
||||
}
|
||||
|
||||
/// Prepare the SQL query, with parameter type information, to inspect the
|
||||
|
@ -161,11 +70,11 @@ pub trait Executor<'c>: Send + Debug + Sized {
|
|||
///
|
||||
/// Only some database drivers (PostgreSQL, MSSQL) can take advantage of
|
||||
/// this extra information to influence parameter type inference.
|
||||
fn prepare_with<'e, 'q: 'e>(
|
||||
async fn prepare_with<'e, 'q: 'e>(
|
||||
self,
|
||||
sql: &'q str,
|
||||
parameters: &'e [<Self::Database as Database>::TypeInfo],
|
||||
) -> BoxFuture<'e, Result<<Self::Database as Database>::Statement<'q>, Error>>
|
||||
) -> Result<<Self::Database as Database>::Statement<'q>, Error>
|
||||
where
|
||||
'c: 'e;
|
||||
|
||||
|
@ -183,73 +92,42 @@ pub trait Executor<'c>: Send + Debug + Sized {
|
|||
'c: 'e;
|
||||
}
|
||||
|
||||
/// A type that may be executed against a database connection.
|
||||
///
|
||||
/// Implemented for the following:
|
||||
///
|
||||
/// * [`&str`](std::str)
|
||||
/// * [`Query`](super::query::Query)
|
||||
///
|
||||
pub trait Execute<'q, DB: Database>: Send + Sized {
|
||||
/// Gets the SQL that will be executed.
|
||||
fn sql(&self) -> &'q str;
|
||||
|
||||
/// Gets the previously cached statement, if available.
|
||||
fn statement(&self) -> Option<&DB::Statement<'q>>;
|
||||
|
||||
/// Returns the arguments to be bound against the query string.
|
||||
/// Arguments struct for [`Executor::execute_prepared()`].
|
||||
pub struct ExecutePrepared<'q, DB: Database> {
|
||||
/// The SQL string to execute.
|
||||
pub query: QueryString<'_>,
|
||||
/// The bind arguments for the query string; must match the number of placeholders.
|
||||
pub arguments: <DB as Database>::Arguments<'q>,
|
||||
/// The maximum number of rows to return.
|
||||
///
|
||||
/// Returning `None` for `Arguments` indicates to use a "simple" query protocol and to not
|
||||
/// prepare the query. Returning `Some(Default::default())` is an empty arguments object that
|
||||
/// will be prepared (and cached) before execution.
|
||||
fn take_arguments(&mut self) -> Option<<DB as Database>::Arguments<'q>>;
|
||||
|
||||
/// Returns `true` if the statement should be cached.
|
||||
fn persistent(&self) -> bool;
|
||||
/// Set to `Some(0)` to just get the result.
|
||||
pub limit: Option<u64>,
|
||||
/// The number of rows to request from the database at a time.
|
||||
///
|
||||
/// This is the maximum number of rows that will be buffered in-memory.
|
||||
///
|
||||
/// This will also be the maximum number of rows that need to be read and discarded should the
|
||||
/// [`ResultSet`] be dropped early.
|
||||
pub buffer: Option<usize>,
|
||||
/// If `true`, prepare the statement with a name and cache it for later re-use.
|
||||
pub persistent: bool,
|
||||
_db: PhantomData<DB>
|
||||
}
|
||||
|
||||
// NOTE: `Execute` is explicitly not implemented for String and &String to make it slightly more
|
||||
// involved to write `conn.execute(format!("SELECT {val}"))`
|
||||
impl<'q, DB: Database> Execute<'q, DB> for &'q str {
|
||||
#[inline]
|
||||
fn sql(&self) -> &'q str {
|
||||
self
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn statement(&self) -> Option<&DB::Statement<'q>> {
|
||||
None
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn take_arguments(&mut self) -> Option<<DB as Database>::Arguments<'q>> {
|
||||
None
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn persistent(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl<'q, DB: Database> Execute<'q, DB> for (&'q str, Option<<DB as Database>::Arguments<'q>>) {
|
||||
#[inline]
|
||||
fn sql(&self) -> &'q str {
|
||||
self.0
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn statement(&self) -> Option<&DB::Statement<'q>> {
|
||||
None
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn take_arguments(&mut self) -> Option<<DB as Database>::Arguments<'q>> {
|
||||
self.1.take()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn persistent(&self) -> bool {
|
||||
true
|
||||
}
|
||||
/// Arguments struct for [`Executor::execute_raw()`].
|
||||
pub struct ExecuteRaw<'q, DB: Database> {
|
||||
/// The SQL string to execute.
|
||||
pub query: QueryString<'_>,
|
||||
/// The maximum number of rows to return.
|
||||
///
|
||||
/// Set to `Some(0)` to just get the result.
|
||||
pub limit: Option<u64>,
|
||||
/// The number of rows to request from the database at a time.
|
||||
///
|
||||
/// This is the maximum number of rows that will be buffered in-memory.
|
||||
///
|
||||
/// This will also be the maximum number of rows that need to be read and discarded should the
|
||||
/// [`ResultSet`] be dropped early.
|
||||
pub buffer: Option<usize>,
|
||||
_db: PhantomData<DB>
|
||||
}
|
||||
|
|
|
@ -75,7 +75,11 @@ pub mod query_as;
|
|||
pub mod query_builder;
|
||||
pub mod query_scalar;
|
||||
|
||||
pub mod query_string;
|
||||
|
||||
pub mod raw_sql;
|
||||
|
||||
pub mod result_set;
|
||||
pub mod row;
|
||||
pub mod rt;
|
||||
pub mod sync;
|
||||
|
|
|
@ -594,7 +594,7 @@ impl<DB: Database> DecrementSizeGuard<DB> {
|
|||
|
||||
pub fn from_permit(pool: Arc<PoolInner<DB>>, permit: AsyncSemaphoreReleaser<'_>) -> Self {
|
||||
// here we effectively take ownership of the permit
|
||||
permit.disarm();
|
||||
permit.consume();
|
||||
Self::new_permit(pool)
|
||||
}
|
||||
|
||||
|
|
|
@ -9,13 +9,14 @@ use crate::database::{Database, HasStatementCache};
|
|||
use crate::encode::Encode;
|
||||
use crate::error::Error;
|
||||
use crate::executor::{Execute, Executor};
|
||||
use crate::query_string::QueryString;
|
||||
use crate::statement::Statement;
|
||||
use crate::types::Type;
|
||||
|
||||
/// A single SQL query as a prepared statement. Returned by [`query()`].
|
||||
#[must_use = "query must be executed to affect database"]
|
||||
pub struct Query<'q, DB: Database, A> {
|
||||
pub(crate) statement: Either<&'q str, &'q DB::Statement<'q>>,
|
||||
pub(crate) statement: Either<QueryString<'q>, &'q DB::Statement<'q>>,
|
||||
pub(crate) arguments: Option<A>,
|
||||
pub(crate) database: PhantomData<DB>,
|
||||
pub(crate) persistent: bool,
|
||||
|
|
115
sqlx-core/src/query_string.rs
Normal file
115
sqlx-core/src/query_string.rs
Normal file
|
@ -0,0 +1,115 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
/// A SQL string that is safe to execute on a database connection.
|
||||
///
|
||||
/// A "safe" query string is one that is unlikely to contain a [SQL injection vulnerability][injection].
|
||||
///
|
||||
/// In practice, this means a string type that is unlikely to contain dynamic data or user input.
|
||||
///
|
||||
/// This is designed to act as a speedbump against naively using `format!()` to add dynamic data
|
||||
/// or user input to a query, which is a classic vector for SQL injection as SQLx does not
|
||||
/// provide any sort of escaping or sanitization (which would have to be specially implemented
|
||||
/// for each database flavor/locale).
|
||||
///
|
||||
/// The recommended way to incorporate dynamic data or user input in a query is to use
|
||||
/// bind parameters, which requires the query to execute as a prepared statement.
|
||||
/// See [`query()`] for details.
|
||||
///
|
||||
/// `&'static str` is the only string type that satisfies the requirements of this trait
|
||||
/// (ignoring [`String::leak()`] which has niche use-cases) and so is the only string type that
|
||||
/// natively implements this trait by default.
|
||||
///
|
||||
/// For other string types, use [`AssertQuerySafe`] to assert this property.
|
||||
/// This is the only intended way to pass an owned `String` to [`query()`] and its related functions.
|
||||
///
|
||||
/// This trait and `AssertQuerySafe` are intentionally analogous to [`std::panic::UnwindSafe`] and
|
||||
/// [`std::panic::AssertUnwindSafe`].
|
||||
///
|
||||
/// [injection]: https://en.wikipedia.org/wiki/SQL_injection
|
||||
/// [`query()`]: crate::query::query
|
||||
pub trait QuerySafeStr<'a> {
|
||||
///
|
||||
fn wrap(self) -> QueryString<'a>;
|
||||
}
|
||||
|
||||
impl QuerySafeStr<'static> for &'static str {
|
||||
fn wrap(self) -> QueryString<'static> {
|
||||
QueryString(Repr::Slice(self))
|
||||
}
|
||||
}
|
||||
|
||||
/// Assert that some string type is safe to execute on a database connection.
|
||||
///
|
||||
/// Using this API means that **you** have made sure that the string contents do not contain a
|
||||
/// [SQL injection vulnerability][injection]. It means that, if the string was constructed
|
||||
/// dynamically, and/or from user input, you have taken care to sanitize the input yourself.
|
||||
///
|
||||
/// The maintainers of SQLx take no responsibility for any data leaks or loss resulting from the use
|
||||
/// of this API. **Use at your own risk.**
|
||||
///
|
||||
/// Note that `&'static str` implements [`QuerySafeStr`] directly and so does not need to be wrapped
|
||||
/// with this type.
|
||||
///
|
||||
/// [injection]: https://en.wikipedia.org/wiki/SQL_injection
|
||||
pub struct AssertQuerySafe<T>(pub T);
|
||||
|
||||
impl<'a> QuerySafeStr<'a> for AssertQuerySafe<&'a str> {
|
||||
fn wrap(self) -> QueryString<'a> {
|
||||
QueryString(Repr::Slice(self.0))
|
||||
}
|
||||
}
|
||||
impl QuerySafeStr<'static> for AssertQuerySafe<String> {
|
||||
fn wrap(self) -> QueryString<'static> {
|
||||
// For `Repr` to not be 4 words wide, we convert `String` to `Box<str>`
|
||||
QueryString(Repr::Boxed(self.0.into()))
|
||||
}
|
||||
}
|
||||
|
||||
impl QuerySafeStr<'static> for AssertQuerySafe<Box<str>> {
|
||||
fn wrap(self) -> QueryString<'static> {
|
||||
QueryString(Repr::Boxed(self.0))
|
||||
}
|
||||
}
|
||||
|
||||
// Note: this is not implemented for `Rc<str>` because it would make `QueryString: !Send`.
|
||||
impl QuerySafeStr<'static> for AssertQuerySafe<Arc<str>> {
|
||||
fn wrap(self) -> QueryString<'static> {
|
||||
QueryString(Repr::Arced(self.into()))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// A SQL string that is ready to execute on a database connection.
|
||||
///
|
||||
/// This is essentially `Cow<'a, str>` but which can be constructed from additional types
|
||||
/// without copying.
|
||||
///
|
||||
/// See [`QuerySafeStr`] for details.
|
||||
pub struct QueryString<'a>(Repr<'a>);
|
||||
|
||||
impl<'a> QuerySafeStr<'a> for QueryString<'a> {
|
||||
fn wrap(self) -> QueryString<'a> {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl QueryString<'_> {
|
||||
pub fn into_static(self) -> QueryString<'static> {
|
||||
QueryString(match self.0 {
|
||||
Repr::Slice(s) => Repr::Boxed(s.into()),
|
||||
Repr::StaticSlice(s) => Repr::StaticSlice(s),
|
||||
Repr::Boxed(s) => Repr::Boxed(s),
|
||||
Repr::Arced(s) => Repr::Arced(s),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
enum Repr<'a> {
|
||||
Slice(&'a str),
|
||||
// We need a variant to memoize when we already have a static string, so we don't copy it.
|
||||
StaticSlice(&'static str),
|
||||
// This enum would be 4 words wide if this variant existed. Instead, convert to `Box<str>`.
|
||||
// Owned(String),
|
||||
Boxed(Box<str>),
|
||||
Arced(Arc<str>),
|
||||
}
|
144
sqlx-core/src/result_set.rs
Normal file
144
sqlx-core/src/result_set.rs
Normal file
|
@ -0,0 +1,144 @@
|
|||
use std::marker::PhantomData;
|
||||
use std::mem;
|
||||
use either::Either;
|
||||
use crate::database::Database;
|
||||
use crate::from_row::FromRow;
|
||||
use crate::row::Row;
|
||||
use crate::sync::spsc;
|
||||
|
||||
pub trait ResultSet {
|
||||
type Database: Database;
|
||||
|
||||
type Row;
|
||||
|
||||
/// Wait for the server to return the next row in the result set.
|
||||
///
|
||||
/// Returns `Ok(None)` when the result set is exhausted. The result of the query
|
||||
/// will be available from [`.next_result()`][Self::next_result].
|
||||
///
|
||||
/// If the query was capable of returning multiple result sets, this will start returning
|
||||
/// rows again after returning `Ok(None)`.
|
||||
///
|
||||
/// This clears the stored result if one was available but `.next_result()` was not called.
|
||||
async fn next_row(&mut self) -> crate::Result<Option<Self::Row>>;
|
||||
|
||||
/// Wait for the next query result, giving the number of rows affected.
|
||||
///
|
||||
/// If [`.next_row()`][Self::next_row] returned `Ok(None)`, the result should be cached
|
||||
/// internally and this should return immediately.
|
||||
///
|
||||
/// If there are any rows buffered before the next query result, this will discard them.
|
||||
async fn next_result(&mut self) -> crate::Result<Option<<Self::Database as Database>::QueryResult>>;
|
||||
|
||||
fn map_row<F, R>(self, map: F) -> MapRow<Self, R, F> where F: FnMut(Self::Row) -> crate::Result<R> {
|
||||
MapRow {
|
||||
map,
|
||||
inner: self,
|
||||
row: PhantomData
|
||||
}
|
||||
}
|
||||
|
||||
fn map_from_row<T>(self) -> MapFromRow<Self, T>
|
||||
where
|
||||
Self::Row: Row<Database = Self::Database>,
|
||||
T: for<'r> FromRow<'r, Self::Row>
|
||||
{
|
||||
self.map_row(|row| T::from_row(&row))
|
||||
}
|
||||
|
||||
async fn collect_rows<T: Default + Extend<Self::Row>>(&mut self) -> crate::Result<T> {
|
||||
let mut rows_out = T::default();
|
||||
|
||||
while let Some(row) = self.next_row().await? {
|
||||
rows_out.extend(Some(row));
|
||||
}
|
||||
|
||||
Ok(rows_out)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MapRow<Rs, Row, F> {
|
||||
map: F,
|
||||
inner: Rs,
|
||||
row: PhantomData<Row>,
|
||||
}
|
||||
|
||||
impl<Rs, Row, F> ResultSet for MapRow<Rs, Row, F>
|
||||
where
|
||||
Rs: ResultSet,
|
||||
F: FnMut(Rs::Row) -> Row
|
||||
{
|
||||
type Database = Rs::Database;
|
||||
type Row = Row;
|
||||
|
||||
async fn next_row(&mut self) -> crate::Result<Option<Self::Row>> {
|
||||
let maybe_row = self.inner.next_row().await?;
|
||||
Ok(maybe_row.map(&mut self.map))
|
||||
}
|
||||
|
||||
async fn next_result(&mut self) -> crate::Result<Option<<Self::Database as Database>::QueryResult>> {
|
||||
self.inner.next_result()
|
||||
}
|
||||
}
|
||||
|
||||
pub type MapFromRow<Rs: ResultSet, Row> = MapRow<Rs, Row, fn(Rs::Row) -> Row>;
|
||||
|
||||
pub struct ChannelResultSet<DB: Database> {
|
||||
flavor: Flavor<DB>,
|
||||
last_result: Option<DB::QueryResult>,
|
||||
}
|
||||
|
||||
enum Flavor<DB: Database> {
|
||||
Channel(spsc::Receiver<crate::Result<Either<DB::QueryResult, DB::Row>>>),
|
||||
Error(crate::Error),
|
||||
Empty,
|
||||
}
|
||||
|
||||
impl<DB: Database> ResultSet for ChannelResultSet<DB> {
|
||||
type Database = DB;
|
||||
type Row = DB::Row;
|
||||
|
||||
async fn next_row(&mut self) -> crate::Result<Option<Self::Row>> {
|
||||
// Clear the previous result if it was ignored.
|
||||
self.last_result = None;
|
||||
|
||||
match self.flavor.recv().await? {
|
||||
Some(Either::Left(result)) => {
|
||||
self.last_result = Some(result);
|
||||
return Ok(None);
|
||||
}
|
||||
Some(Either::Right(row)) => Ok(Some(row)),
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
async fn next_result(&mut self) -> crate::Result<Option<<Self::Database as Database>::QueryResult>> {
|
||||
if let Some(result) = self.last_result.take() {
|
||||
return Ok(Some(result));
|
||||
}
|
||||
|
||||
loop {
|
||||
match self.flavor.recv().await? {
|
||||
Some(Either::Left(result)) => return Ok(Some(result)),
|
||||
// Drop rows until the next result.
|
||||
Some(Either::Right(_)) => (),
|
||||
None => return Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<DB: Database> Flavor<DB> {
|
||||
async fn recv(&mut self) -> crate::Result<Option<Either<DB::QueryResult, DB::Row>>> {
|
||||
match self {
|
||||
Self::Channel(chan) => chan.recv().await.transpose(),
|
||||
Self::Error(_) => {
|
||||
let Self::Error(e) = mem::replace(self, Self::Empty) else {
|
||||
unreachable!()
|
||||
};
|
||||
Err(e)
|
||||
}
|
||||
Self::Empty => Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,15 +1,12 @@
|
|||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
pub mod spsc;
|
||||
|
||||
// For types with identical signatures that don't require runtime support,
|
||||
// we can just arbitrarily pick one to use based on what's enabled.
|
||||
//
|
||||
// We'll generally lean towards Tokio's types as those are more featureful
|
||||
// (including `tokio-console` support) and more widely deployed.
|
||||
|
||||
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
|
||||
pub use async_std::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard};
|
||||
|
||||
#[cfg(feature = "_rt-tokio")]
|
||||
pub use tokio::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard};
|
||||
|
||||
pub struct AsyncSemaphore {
|
||||
// We use the semaphore from futures-intrusive as the one from async-std
|
||||
// is missing the ability to add arbitrary permits, and is not guaranteed to be fair:
|
||||
|
@ -125,7 +122,7 @@ pub struct AsyncSemaphoreReleaser<'a> {
|
|||
}
|
||||
|
||||
impl AsyncSemaphoreReleaser<'_> {
|
||||
pub fn disarm(self) {
|
||||
pub fn consume(self) {
|
||||
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
|
||||
{
|
||||
let mut this = self;
|
||||
|
@ -143,3 +140,65 @@ impl AsyncSemaphoreReleaser<'_> {
|
|||
crate::rt::missing_rt(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AsyncMutex<T> {
|
||||
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
|
||||
inner: async_std::sync::Mutex<T>,
|
||||
|
||||
#[cfg(feature = "_rt-tokio")]
|
||||
inner: tokio::sync::Mutex<T>,
|
||||
}
|
||||
|
||||
pub struct AsyncMutexGuard<'a, T> {
|
||||
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
|
||||
inner: async_std::sync::MutexGuard<'a, T>,
|
||||
|
||||
#[cfg(feature = "_rt-tokio")]
|
||||
inner: tokio::sync::MutexGuard<'a, T>,
|
||||
}
|
||||
|
||||
impl<T> AsyncMutex<T> {
|
||||
pub fn new(value: T) -> Self {
|
||||
if cfg!(not(any(feature = "_rt-async-std", feature = "_rt-tokio"))) {
|
||||
crate::rt::missing_rt(value);
|
||||
}
|
||||
|
||||
AsyncMutex {
|
||||
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
|
||||
inner: async_std::sync::Mutex::new(value),
|
||||
|
||||
#[cfg(feature = "_rt-tokio")]
|
||||
inner: tokio::sync::Mutex::new(value),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn lock(&self) -> AsyncMutexGuard<'_, T> {
|
||||
AsyncMutexGuard {
|
||||
inner: self.inner.lock().await,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn try_lock(&self) -> Option<AsyncMutexGuard<'_, T>> {
|
||||
Some(AsyncMutexGuard {
|
||||
#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
|
||||
inner: self.inner.try_lock()?,
|
||||
|
||||
#[cfg(feature = "_rt-tokio")]
|
||||
inner: self.inner.try_lock().ok()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Deref for AsyncMutexGuard<'_, T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> DerefMut for AsyncMutexGuard<'_, T> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
236
sqlx-core/src/sync/spsc.rs
Normal file
236
sqlx-core/src/sync/spsc.rs
Normal file
|
@ -0,0 +1,236 @@
|
|||
//! A cooperatively bounded SPSC channel.
|
||||
//!
|
||||
//! Senders may either obey the channel capacity but will have to wait when it is exhausted
|
||||
//! ([`Sender::send`]) or ignore the channel capacity when necessary ([`Sender::send_unbounded()`]),
|
||||
//! e.g. in a `Drop` impl where neither blocking nor waiting is acceptable.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::task::Poll;
|
||||
|
||||
use futures_util::task::AtomicWaker;
|
||||
|
||||
use crate::sync::{AsyncMutex, AsyncSemaphore};
|
||||
|
||||
/// The sender side of a cooperatively bounded SPSC channel.
|
||||
///
|
||||
/// The channel is closed when either this or [`Receiver`] is dropped.
|
||||
pub struct Sender<T> {
|
||||
channel: Arc<Channel<T>>,
|
||||
}
|
||||
|
||||
/// The receiver side of a cooperatively bounded SPSC channel.
|
||||
///
|
||||
/// The channel is closed when either this or [`Sender`] is dropped.
|
||||
pub struct Receiver<T> {
|
||||
channel: Arc<Channel<T>>,
|
||||
}
|
||||
|
||||
struct Channel<T> {
|
||||
buffer: Buffer<T>,
|
||||
closed: AtomicBool,
|
||||
capacity: usize,
|
||||
semaphore: AsyncSemaphore,
|
||||
recv_waker: AtomicWaker,
|
||||
// Ensure we only attempt to read from the channel if this was a legitimate wakeup.
|
||||
receiver_woken: AtomicBool,
|
||||
}
|
||||
|
||||
struct Buffer<T> {
|
||||
// Using double-buffering so the sender and receiver aren't constantly fighting
|
||||
front: AsyncMutex<Deque<T>>,
|
||||
back: AsyncMutex<Deque<T>>,
|
||||
write_front: AtomicBool,
|
||||
}
|
||||
|
||||
struct Deque<T> {
|
||||
messages: VecDeque<T>,
|
||||
/// Each value represents the number of bounded sends at the head of the queue.
|
||||
///
|
||||
/// Pushing a message that used a permit should increment the value at the back of this queue,
|
||||
/// or push a new value of 1 if the value is 255 or 0, or the queue is empty.
|
||||
///
|
||||
/// Pushing a message that did _not_ use a permit (unbounded send) should simultaneously
|
||||
/// push a zero value to this queue.
|
||||
///
|
||||
/// Popping a message should decrement the value at the head of this queue,
|
||||
/// removing it when it reaches 0. The receiver should simultaneously add a permit
|
||||
/// back to the channel semaphore.
|
||||
///
|
||||
/// A 0 value at the head of the queue thus indicates an unbounded send,
|
||||
/// which should _not_ result in the release of a permit.
|
||||
permits_used: VecDeque<u8>,
|
||||
}
|
||||
|
||||
impl<T> Sender<T> {
|
||||
/// Send a message or wait for a permit to be released from the receiver.
|
||||
///
|
||||
/// ### Cancel-Safe
|
||||
/// This method is entirely cancel-safe. It is guaranteed to _only_ send the message
|
||||
/// immediately before it returns.
|
||||
///
|
||||
/// Contrast this to [`flume::Sender::send_async()`][flume-cancel-safe], where
|
||||
/// the message _may_ be received any time after the first poll if it does not return `Ready`
|
||||
/// immediately.
|
||||
///
|
||||
/// [flume-cancel-safe]: https://github.com/zesterer/flume/issues/104#issuecomment-1216387210
|
||||
pub async fn send(&mut self, message: T) -> Result<(), T> {
|
||||
if self.channel.closed.load(Ordering::Acquire) {
|
||||
return Err(message);
|
||||
}
|
||||
|
||||
let permit = self.channel.semaphore.acquire(1).await;
|
||||
|
||||
self.send_inner(message, true).map(|_| permit.consume())
|
||||
}
|
||||
|
||||
/// Send a message immediately.
|
||||
///
|
||||
/// ### Note: Only Use when Necessary
|
||||
/// This call ignores channel capacity and should only be used where blocking or waiting
|
||||
/// is not an option, e.g. in a `Drop` impl.
|
||||
pub fn send_unbounded(&mut self, message: T) -> Result<(), T> {
|
||||
self.send_inner(message, false)
|
||||
}
|
||||
|
||||
fn send_inner(&mut self, message: T, permit_used: bool) -> Result<(), T> {
|
||||
if self.channel.closed.load(Ordering::Acquire) {
|
||||
return Err(message);
|
||||
}
|
||||
|
||||
self.channel.buffer.write(message, permit_used);
|
||||
self.channel.receiver_woken.store(true, Ordering::Release);
|
||||
self.channel.recv_waker.wake();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for Sender<T> {
|
||||
fn drop(&mut self) {
|
||||
self.channel.closed.store(true, Ordering::Release);
|
||||
self.channel.recv_waker.wake();
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Receiver<T> {
|
||||
pub async fn recv(&mut self) -> Option<T> {
|
||||
loop {
|
||||
if self.channel.closed.load(Ordering::Acquire) {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(message) = self.channel.buffer.read().await {
|
||||
return Some(message);
|
||||
}
|
||||
|
||||
futures_util::future::poll_fn(|cx| {
|
||||
let ready = self.channel.closed.load(Ordering::Acquire)
|
||||
|| self.channel.receiver_woken.load(Ordering::Acquire);
|
||||
|
||||
// Clear the `receiver_woken` flag.
|
||||
self.channel.receiver_woken.store(false, Ordering::Release);
|
||||
|
||||
if ready {
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
// Ensure the waker is up-to-date every time we're polled.
|
||||
self.channel.recv_waker.register(cx.waker());
|
||||
Poll::Pending
|
||||
}
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for Receiver<T> {
|
||||
fn drop(&mut self) {
|
||||
self.channel.closed.store(true, Ordering::Relaxed);
|
||||
self.channel.recv_waker.take();
|
||||
self.channel.semaphore.release(self.channel.capacity);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Buffer<T> {
|
||||
pub fn write(&self, value: T, permit_used: bool) {
|
||||
let mut side = if self.write_front.load(Ordering::Acquire) {
|
||||
self.front
|
||||
.try_lock()
|
||||
.expect("BUG: receiver has front buffer locked while reading back buffer")
|
||||
} else {
|
||||
self.back
|
||||
.try_lock()
|
||||
.expect("BUG: receiver has back buffer locked while reading front buffer")
|
||||
};
|
||||
|
||||
side.messages.push_back(value);
|
||||
|
||||
if permit_used {
|
||||
side.permits_used = side.permits_used.checked_add(1)
|
||||
.expect("BUG: permits_used overflowed!");
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read(&self) -> Option<(T, bool)> {
|
||||
// If the sender is writing the front, we should read the back and vice versa.
|
||||
let read_back = self.write_front.load(Ordering::Acquire);
|
||||
|
||||
let mut side = if read_back {
|
||||
// If we just swapped buffers, we may need to wait for the sender to release the lock.
|
||||
self.back.lock().await
|
||||
} else {
|
||||
self.front.lock().await
|
||||
};
|
||||
|
||||
let val = side.messages.pop_front();
|
||||
|
||||
// It doesn't actually matter if this exact message actually used a permit,
|
||||
// it just matters that we made room in the channel.
|
||||
let permit_used = side.permits_used.checked_sub(1)
|
||||
.is_some_and(|permits_used| {
|
||||
side.permits_used = permits_used;
|
||||
true
|
||||
});
|
||||
|
||||
// Note: be sure to release the lock before swapping or `write()` will panic.
|
||||
drop(side);
|
||||
|
||||
if val.is_none() {
|
||||
// This side is empty; swap.
|
||||
self.write_front.store(!read_back, Ordering::Release);
|
||||
}
|
||||
|
||||
val.map(|val| (val, permit_used))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn channel<T>(bounded_capacity: usize) -> (Sender<T>, Receiver<T>) {
|
||||
let channel = Arc::new(Channel {
|
||||
buffer: Buffer {
|
||||
front: AsyncMutex::new(Deque {
|
||||
messages: VecDeque::with_capacity(bounded_capacity),
|
||||
permits_used: 0,
|
||||
}),
|
||||
back: AsyncMutex::new(Deque {
|
||||
messages: VecDeque::with_capacity(bounded_capacity),
|
||||
permits_used: 0,
|
||||
}),
|
||||
write_front: true.into()
|
||||
},
|
||||
closed: false.into(),
|
||||
capacity: bounded_capacity,
|
||||
semaphore: AsyncSemaphore::new(true, bounded_capacity),
|
||||
recv_waker: Default::default(),
|
||||
receiver_woken: false.into(),
|
||||
});
|
||||
|
||||
(
|
||||
Sender {
|
||||
channel: channel.clone(),
|
||||
},
|
||||
Receiver {
|
||||
channel
|
||||
}
|
||||
)
|
||||
}
|
Loading…
Reference in a new issue