fix(core): async-stream crate seems to lose the stream if the stream owns the object we are streaming

hand-rolled a copy of the idea behind AsyncStream and things seem to work
This commit is contained in:
Ryan Leckey 2020-06-09 02:16:47 -07:00
parent 2677046a3b
commit e1d22a1840
20 changed files with 271 additions and 167 deletions

44
Cargo.lock generated
View file

@ -51,9 +51,9 @@ dependencies = [
[[package]]
name = "ahash"
version = "0.3.5"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f3e0bf23f51883cce372d5d5892211236856e4bb37fb942e1eb135ee0f146e3"
checksum = "e8fd72866655d1904d6b0997d0b07ba561047d070fbe29de039031c641b61217"
[[package]]
name = "aho-corasick"
@ -144,27 +144,6 @@ dependencies = [
"wasm-bindgen-futures",
]
[[package]]
name = "async-stream"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22068c0c19514942eefcfd4daf8976ef1aad84e61539f95cd200c35202f80af5"
dependencies = [
"async-stream-impl",
"futures-core",
]
[[package]]
name = "async-stream-impl"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25f9db3b38af870bf7e5cc649167533b493928e50744e2c30ae350230b414670"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "async-task"
version = "3.0.0"
@ -173,9 +152,9 @@ checksum = "c17772156ef2829aadc587461c7753af20b7e8db1529bc66855add962a3b35d3"
[[package]]
name = "async-trait"
version = "0.1.33"
version = "0.1.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f1c13101a3224fb178860ae372a031ce350bbd92d39968518f016744dde0bf7"
checksum = "89cb5d814ab2a47fd66d3266e9efccb53ca4c740b7451043b8ffcf9a6208f3f8"
dependencies = [
"proc-macro2",
"quote",
@ -1217,7 +1196,7 @@ checksum = "f5e374eff525ce1c5b7687c4cef63943e7686524a387933ad27ca7ec43779cb3"
dependencies = [
"log",
"mio",
"miow 0.3.4",
"miow 0.3.5",
"winapi 0.3.8",
]
@ -1246,9 +1225,9 @@ dependencies = [
[[package]]
name = "miow"
version = "0.3.4"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22dfdd1d51b2639a5abd17ed07005c3af05fb7a2a3b1a1d0d7af1000a520c1c7"
checksum = "07b88fb9795d4d36d62a012dfbf49a8f5cf12751f36d31a9dbe66d528e58979e"
dependencies = [
"socket2",
"winapi 0.3.8",
@ -1619,9 +1598,9 @@ checksum = "7e0456befd48169b9f13ef0f0ad46d492cf9d2dbb918bcf38e01eed4ce3ec5e4"
[[package]]
name = "proc-macro-nested"
version = "0.1.4"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e946095f9d3ed29ec38de908c22f95d9ac008e424c7bcae54c75a79c527c694"
checksum = "0afe1bd463b9e9ed51d0e0f0b50b6b146aec855c56fd182bb242388710a9b6de"
[[package]]
name = "proc-macro2"
@ -1640,9 +1619,9 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
[[package]]
name = "quote"
version = "1.0.6"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54a21852a652ad6f610c9510194f398ff6f8692e334fd1145fed931f7fbe44ea"
checksum = "aa563d17ecb180e500da1cfd2b028310ac758de548efdd203e18f283af693f37"
dependencies = [
"proc-macro2",
]
@ -2058,7 +2037,6 @@ dependencies = [
name = "sqlx-core"
version = "0.4.0-pre"
dependencies = [
"async-stream",
"atoi",
"base64 0.12.1",
"bigdecimal",

View file

@ -38,7 +38,6 @@ offline = [ "serde" ]
[dependencies]
atoi = "0.3.2"
sqlx-rt = { path = "../sqlx-rt", version = "0.1.0-pre" }
async-stream = { version = "0.2.1", default-features = false }
base64 = { version = "0.12.1", default-features = false, optional = true, features = [ "std" ] }
bigdecimal_ = { version = "0.1.0", optional = true, package = "bigdecimal" }
bitflags = { version = "1.2.1", default-features = false }
@ -51,9 +50,9 @@ crossbeam-utils = { version = "0.7.2", default-features = false }
digest = { version = "0.8.1", default-features = false, optional = true, features = [ "std" ] }
encoding_rs = { version = "0.8.23", optional = true }
either = "1.5.3"
futures-channel = { version = "0.3.5", default-features = false, features = [ "alloc", "std" ] }
futures-channel = { version = "0.3.5", default-features = false, features = [ "sink", "alloc", "std" ] }
futures-core = { version = "0.3.5", default-features = false }
futures-util = "0.3.5"
futures-util = { version = "0.3.5", features = [ "sink" ] }
generic-array = { version = "0.12.3", default-features = false, optional = true }
hashbrown = "0.7.2"
hex = "0.4.2"

View file

@ -0,0 +1,73 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_channel::mpsc;
use futures_core::future::BoxFuture;
use futures_core::stream::Stream;
use futures_util::{pin_mut, FutureExt, SinkExt};
use crate::error::Error;
pub struct TryAsyncStream<'a, T> {
receiver: mpsc::Receiver<Result<T, Error>>,
future: BoxFuture<'a, Result<(), Error>>,
}
impl<'a, T> TryAsyncStream<'a, T> {
pub fn new<F, Fut>(f: F) -> Self
where
F: FnOnce(mpsc::Sender<Result<T, Error>>) -> Fut + Send,
Fut: 'a + Future<Output = Result<(), Error>> + Send,
T: 'a + Send,
{
let (mut sender, receiver) = mpsc::channel(1);
let future = f(sender.clone());
let future = async move {
if let Err(error) = future.await {
let _ = sender.send(Err(error));
}
Ok(())
}
.fuse()
.boxed();
Self { future, receiver }
}
}
impl<'a, T> Stream for TryAsyncStream<'a, T> {
type Item = Result<T, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let future = &mut self.future;
pin_mut!(future);
// the future is fused so its safe to call forever
// the future advances our "stream"
// the future should be polled in tandem with the stream receiver
let _ = future.poll(cx);
let receiver = &mut self.receiver;
pin_mut!(receiver);
// then we check to see if we have anything to return
return receiver.poll_next(cx);
}
}
macro_rules! try_stream2 {
($($block:tt)*) => {
crate::ext::async_stream::TryAsyncStream::new(move |mut sender| async move {
macro_rules! r#yield {
($v:expr) => {
let _ = futures_util::sink::SinkExt::send(&mut sender, Ok($v)).await;
}
}
{$($block)*}
})
}
}

View file

@ -1,2 +1,5 @@
pub mod maybe_owned;
pub mod ustr;
#[macro_use]
pub mod async_stream;

View file

@ -17,6 +17,9 @@
#[cfg(feature = "bigdecimal")]
extern crate bigdecimal_ as bigdecimal;
#[macro_use]
mod ext;
#[macro_use]
pub mod error;
@ -38,7 +41,6 @@ pub mod database;
pub mod decode;
pub mod describe;
pub mod executor;
mod ext;
pub mod from_row;
mod io;
mod net;

View file

@ -1,4 +1,3 @@
use async_stream::try_stream;
use either::Either;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
@ -101,7 +100,7 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection {
let s = query.query();
let arguments = query.take_arguments();
Box::pin(try_stream! {
Box::pin(try_stream2! {
self.run(s, arguments).await?;
loop {
@ -109,14 +108,12 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection {
match message {
Message::Row(row) => {
let v = Either::Right(MssqlRow { row });
yield v;
r#yield!(Either::Right(MssqlRow { row }));
}
Message::Done(done) | Message::DoneProc(done) => {
if done.status.contains(Status::DONE_COUNT) {
let v = Either::Left(done.affected_rows);
yield v;
r#yield!(Either::Left(done.affected_rows));
}
if !done.status.contains(Status::DONE_MORE) {
@ -127,14 +124,15 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection {
Message::DoneInProc(done) => {
if done.status.contains(Status::DONE_COUNT) {
let v = Either::Left(done.affected_rows);
yield v;
r#yield!(Either::Left(done.affected_rows));
}
}
_ => {}
}
}
Ok(())
})
}

View file

@ -1,6 +1,5 @@
use std::sync::Arc;
use async_stream::try_stream;
use bytes::Bytes;
use either::Either;
use futures_core::future::BoxFuture;
@ -12,6 +11,7 @@ use crate::describe::{Column, Describe};
use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::ext::ustr::UStr;
use crate::mysql::connection::stream::Busy;
use crate::mysql::io::MySqlBufExt;
use crate::mysql::protocol::response::Status;
use crate::mysql::protocol::statement::{
@ -111,7 +111,7 @@ impl MySqlConnection {
arguments: Option<MySqlArguments>,
) -> Result<impl Stream<Item = Result<Either<u64, MySqlRow>, Error>> + 'c, Error> {
self.stream.wait_until_ready().await?;
self.stream.busy = true;
self.stream.busy = Busy::Result;
let format = if let Some(arguments) = arguments {
let statement = self.prepare(query).await?;
@ -132,30 +132,30 @@ impl MySqlConnection {
MySqlValueFormat::Text
};
Ok(try_stream! {
Ok(Box::pin(try_stream2! {
loop {
// query response is a meta-packet which may be one of:
// Ok, Err, ResultSet, or (unhandled) LocalInfileRequest
let mut packet = self.stream.recv_packet().await?;
let packet = self.stream.recv_packet().await?;
if packet[0] == 0x00 || packet[0] == 0xff {
// first packet in a query response is OK or ERR
// this indicates either a successful query with no rows at all or a failed query
let ok = packet.ok()?;
let v = Either::Left(ok.affected_rows);
yield v;
r#yield!(Either::Left(ok.affected_rows));
if ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
// more result sets exist, continue to the next one
continue;
}
self.stream.busy = false;
return;
self.stream.busy = Busy::NotBusy;
return Ok(());
}
// otherwise, this first packet is the start of the result-set metadata,
self.stream.busy = Busy::Row;
self.recv_result_metadata(packet).await?;
// finally, there will be none or many result-rows
@ -164,17 +164,16 @@ impl MySqlConnection {
if packet[0] == 0xfe && packet.len() < 9 {
let eof = packet.eof(self.stream.capabilities)?;
let v = Either::Left(0);
yield v;
r#yield!(Either::Left(0));
if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
// more result sets exist, continue to the next one
self.stream.busy = Busy::Result;
break;
}
self.stream.busy = false;
return;
self.stream.busy = Busy::NotBusy;
return Ok(());
}
let row = match format {
@ -189,10 +188,10 @@ impl MySqlConnection {
column_names: Arc::clone(&self.scratch_row_column_names),
});
yield v;
r#yield!(v);
}
}
})
}))
}
}
@ -210,13 +209,15 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection {
let s = query.query();
let arguments = query.take_arguments();
Box::pin(try_stream! {
Box::pin(try_stream2! {
let s = self.run(s, arguments).await?;
pin_mut!(s);
while let Some(v) = s.try_next().await? {
yield v;
r#yield!(v);
}
Ok(())
})
}

View file

@ -20,7 +20,7 @@ mod executor;
mod stream;
mod tls;
pub(crate) use stream::MySqlStream;
pub(crate) use stream::{Busy, MySqlStream};
const COLLATE_UTF8MB4_UNICODE_CI: u8 = 224;
@ -120,10 +120,10 @@ impl Connect for MySqlConnection {
// https://mathiasbynens.be/notes/mysql-utf8mb4
conn.execute(r#"
SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'));
SET time_zone = '+00:00';
SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci;
"#).await?;
SET sql_mode=(SELECT CONCAT(@@sql_mode, ',PIPES_AS_CONCAT,NO_ENGINE_SUBSTITUTION,NO_ZERO_DATE,NO_ZERO_IN_DATE'));
SET time_zone = '+00:00';
SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci;
"#).await?;
Ok(conn)
})

View file

@ -5,7 +5,8 @@ use sqlx_rt::TcpStream;
use crate::error::Error;
use crate::io::{BufStream, Decode, Encode};
use crate::mysql::protocol::response::{EofPacket, ErrPacket, OkPacket};
use crate::mysql::io::MySqlBufExt;
use crate::mysql::protocol::response::{EofPacket, ErrPacket, OkPacket, Status};
use crate::mysql::protocol::{Capabilities, Packet};
use crate::mysql::{MySqlConnectOptions, MySqlDatabaseError};
use crate::net::MaybeTlsStream;
@ -14,7 +15,18 @@ pub struct MySqlStream {
stream: BufStream<MaybeTlsStream<TcpStream>>,
pub(super) capabilities: Capabilities,
pub(crate) sequence_id: u8,
pub(crate) busy: bool,
pub(crate) busy: Busy,
}
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum Busy {
NotBusy,
// waiting for a result set
Result,
// waiting for a row within a result set
Row,
}
impl MySqlStream {
@ -39,7 +51,7 @@ impl MySqlStream {
}
Ok(Self {
busy: false,
busy: Busy::NotBusy,
capabilities,
sequence_id: 0,
stream: BufStream::new(MaybeTlsStream::Raw(stream)),
@ -51,19 +63,33 @@ impl MySqlStream {
self.stream.flush().await?;
}
if self.busy {
loop {
while self.busy != Busy::NotBusy {
while self.busy == Busy::Row {
let packet = self.recv_packet().await?;
match packet[0] {
0x00 | 0xfe if packet.len() < 9 => {
// OK or EOF packet
self.busy = false;
break;
}
_ => {
// Something else; skip
if packet[0] == 0xfe && packet.len() < 9 {
let eof = packet.eof(self.capabilities)?;
self.busy = if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
Busy::Result
} else {
Busy::NotBusy
};
}
}
while self.busy == Busy::Result {
let packet = self.recv_packet().await?;
if packet[0] == 0x00 || packet[0] == 0xff {
let ok = packet.ok()?;
if !ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) {
self.busy = Busy::NotBusy;
}
} else {
self.busy = Busy::Row;
self.skip_result_metadata(packet).await?;
}
}
}
@ -107,7 +133,7 @@ impl MySqlStream {
// TODO: packet joining
if payload[0] == 0xff {
self.busy = false;
self.busy = Busy::NotBusy;
// instead of letting this packet be looked at everywhere, we check here
// and emit a proper Error
@ -137,6 +163,18 @@ impl MySqlStream {
self.recv().await.map(Some)
}
}
async fn skip_result_metadata(&mut self, mut packet: Packet<Bytes>) -> Result<(), Error> {
let num_columns: u64 = packet.get_uint_lenenc(); // column count
for _ in 0..num_columns {
let _ = self.recv_packet().await?;
}
self.maybe_recv_eof().await?;
Ok(())
}
}
impl Deref for MySqlStream {

View file

@ -2,6 +2,7 @@ use futures_core::future::BoxFuture;
use crate::error::Error;
use crate::executor::Executor;
use crate::mysql::connection::Busy;
use crate::mysql::protocol::text::Query;
use crate::mysql::{MySql, MySqlConnection};
use crate::transaction::{
@ -40,7 +41,7 @@ impl TransactionManager for MySqlTransactionManager {
}
fn start_rollback(conn: &mut MySqlConnection, depth: usize) {
conn.stream.busy = true;
conn.stream.busy = Busy::Result;
conn.stream.sequence_id = 0;
conn.stream
.write_packet(Query(&*rollback_ansi_transaction_sql(depth)));

View file

@ -1,4 +1,3 @@
use async_stream::try_stream;
use either::Either;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
@ -25,13 +24,15 @@ where
{
let pool = self.clone();
Box::pin(try_stream! {
Box::pin(try_stream2! {
let mut conn = pool.acquire().await?;
let mut s = conn.fetch_many(query);
for v in s.try_next().await? {
yield v;
while let Some(v) = s.try_next().await? {
r#yield!(v);
}
Ok(())
})
}
@ -80,7 +81,7 @@ macro_rules! impl_executor_for_pool_connection {
'c: 'e,
E: crate::executor::Execute<'q, $DB>,
{
(&mut **self).fetch_many(query)
(**self).fetch_many(query)
}
#[inline]
@ -92,7 +93,7 @@ macro_rules! impl_executor_for_pool_connection {
'c: 'e,
E: crate::executor::Execute<'q, $DB>,
{
(&mut **self).fetch_optional(query)
(**self).fetch_optional(query)
}
#[doc(hidden)]
@ -108,7 +109,7 @@ macro_rules! impl_executor_for_pool_connection {
'c: 'e,
E: crate::executor::Execute<'q, $DB>,
{
(&mut **self).describe(query)
(**self).describe(query)
}
}
};

View file

@ -1,10 +1,9 @@
//! Connection pool for SQLx database connections.
use std::{
fmt,
sync::Arc,
time::{Duration, Instant},
};
use std::fmt;
use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::database::Database;
use crate::error::Error;
@ -50,8 +49,9 @@ impl<DB: Database> Pool<DB> {
/// Retrieves a connection from the pool.
///
/// Waits for at most the configured connection timeout before returning an error.
pub async fn acquire(&self) -> Result<PoolConnection<DB>, Error> {
self.0.acquire().await.map(|conn| conn.attach(&self.0))
pub fn acquire(&self) -> impl Future<Output = Result<PoolConnection<DB>, Error>> + 'static {
let shared = self.0.clone();
async move { shared.acquire().await.map(|conn| conn.attach(&shared)) }
}
/// Attempts to retrieve a connection from the pool if there is one available.

View file

@ -1,4 +1,3 @@
use async_stream::try_stream;
use either::Either;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
@ -181,7 +180,7 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
self.pending_ready_for_query_count += 1;
self.stream.flush().await?;
Ok(try_stream! {
Ok(try_stream2! {
loop {
let message = self.stream.recv().await?;
@ -194,7 +193,7 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
// a SQL command completed normally
let cc: CommandComplete = message.decode()?;
yield Either::Left(cc.rows_affected());
r#yield!(Either::Left(cc.rows_affected()));
}
MessageFormat::EmptyQueryResponse => {
@ -218,7 +217,7 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
column_names: Arc::clone(&self.scratch_row_column_names),
};
yield Either::Right(row);
r#yield!(Either::Right(row));
}
MessageFormat::ReadyForQuery => {
@ -235,6 +234,8 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
}
}
}
Ok(())
})
}
}
@ -253,13 +254,15 @@ impl<'c> Executor<'c> for &'c mut PgConnection {
let s = query.query();
let arguments = query.take_arguments();
Box::pin(try_stream! {
Box::pin(try_stream2! {
let s = self.run(s, arguments, 0).await?;
pin_mut!(s);
while let Some(s) = s.try_next().await? {
yield s;
while let Some(v) = s.try_next().await? {
r#yield!(v);
}
Ok(())
})
}

View file

@ -2,7 +2,6 @@ use std::fmt::{self, Debug};
use std::io;
use std::str::from_utf8;
use async_stream::try_stream;
use futures_channel::mpsc;
use futures_core::future::BoxFuture;
use futures_core::stream::{BoxStream, Stream};
@ -185,10 +184,9 @@ impl PgListener {
/// Consume this listener, returning a `Stream` of notifications.
pub fn into_stream(mut self) -> impl Stream<Item = Result<PgNotification, Error>> + Unpin {
Box::pin(try_stream! {
Box::pin(try_stream2! {
loop {
let notification = self.recv().await?;
yield notification;
r#yield!(self.recv().await?);
}
})
}

View file

@ -1,6 +1,5 @@
use std::marker::PhantomData;
use async_stream::try_stream;
use either::Either;
use futures_core::stream::BoxStream;
use futures_util::{future, StreamExt, TryFutureExt, TryStreamExt};
@ -238,17 +237,19 @@ where
F: 'e,
O: 'e,
{
Box::pin(try_stream! {
Box::pin(try_stream2! {
let mut s = executor.fetch_many(self.inner);
while let Some(v) = s.try_next().await? {
match v {
Either::Left(v) => yield Either::Left(v),
r#yield!(match v {
Either::Left(v) => Either::Left(v),
Either::Right(row) => {
let mapped = (self.mapper)(row)?;
yield Either::Right(mapped);
Either::Right((self.mapper)(row)?)
}
}
});
}
Ok(())
})
}

View file

@ -1,6 +1,5 @@
use std::marker::PhantomData;
use async_stream::try_stream;
use either::Either;
use futures_core::stream::BoxStream;
use futures_util::{StreamExt, TryStreamExt};
@ -82,17 +81,17 @@ where
O: 'e,
A: 'e,
{
Box::pin(try_stream! {
Box::pin(try_stream2! {
let mut s = executor.fetch_many(self.inner);
while let Some(v) = s.try_next().await? {
match v {
Either::Left(v) => yield Either::Left(v),
Either::Right(row) => {
let mapped = O::from_row(&row)?;
yield Either::Right(mapped);
}
}
r#yield!(match v {
Either::Left(v) => Either::Left(v),
Either::Right(row) => Either::Right(O::from_row(&row)?),
});
}
Ok(())
})
}

View file

@ -1,6 +1,5 @@
use std::sync::Arc;
use async_stream::try_stream;
use either::Either;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
@ -85,7 +84,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
let s = query.query();
let arguments = query.take_arguments();
Box::pin(try_stream! {
Box::pin(try_stream2! {
let SqliteConnection {
handle: ref mut conn,
ref mut statements,
@ -125,8 +124,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
match worker.step(handle).await? {
Either::Left(changes) => {
let v = Either::Left(changes);
yield v;
r#yield!(Either::Left(changes));
break;
}
@ -140,12 +138,14 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
let v = Either::Right(row);
*last_row_values = Some(weak_values_ref);
yield v;
r#yield!(v);
}
}
}
}
}
Ok(())
})
}

View file

@ -1,5 +1,5 @@
use futures::TryStreamExt;
use sqlx::mysql::{MySql, MySqlRow, MySqlPool};
use sqlx::mysql::{MySql, MySqlPool, MySqlRow};
use sqlx::{Connection, Executor, Row};
use sqlx_test::new;
use std::env;
@ -74,50 +74,25 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY);
Ok(())
}
#[sqlx_macros::test]
async fn it_executes_2() -> anyhow::Result<()> {
async fn it_executes_with_pool() -> anyhow::Result<()> {
let pool: MySqlPool = MySqlPool::builder()
.min_size(2)
.max_size(2)
.build(&env::var("DATABASE_URL")?)
.test_on_acquire(false)
.build(&dotenv::var("DATABASE_URL")?)
.await?;
let mut conn = pool.acquire().await?;
let rows = pool.fetch_all("SELECT 1; SELECT 2").await?;
#[derive(Debug, sqlx::FromRow)]
struct User { id: i32 };
assert_eq!(rows.len(), 2);
let _ = sqlx::query(
r#"
CREATE TABLE users (id INTEGER PRIMARY KEY);
"#,
)
.execute(&mut conn)
.await?;
let count = pool
.fetch("SELECT 1; SELECT 2")
.try_fold(0, |acc, _| async move { Ok(acc + 1) })
.await?;
for index in 1..=10_i32 {
let cnt = sqlx::query("INSERT INTO users (id) VALUES (?)")
.bind(index)
.execute(&mut conn)
.await?;
assert_eq!(cnt, 1);
}
let users: Vec<User> = sqlx::query_as::<MySql, User>(
"SELECT
id
FROM
users"
)
.fetch_all(&pool)
.await?;
assert_eq!(users.len(), 10);
sqlx::query("drop table users;").execute(&mut conn).await?;
assert_eq!(count, 2);
Ok(())
}

View file

@ -1,7 +1,7 @@
use futures::TryStreamExt;
use sqlx::postgres::PgRow;
use sqlx::postgres::{PgDatabaseError, PgErrorPosition, PgPool, PgSeverity};
use sqlx::{postgres::Postgres, Connection, Executor, Row};
use sqlx_core::postgres::{PgDatabaseError, PgErrorPosition, PgSeverity};
use sqlx_test::new;
#[sqlx_macros::test]
@ -99,6 +99,22 @@ CREATE TEMPORARY TABLE users (id INTEGER PRIMARY KEY);
Ok(())
}
#[sqlx_macros::test]
async fn it_executes_with_pool() -> anyhow::Result<()> {
let pool: PgPool = PgPool::builder()
.min_size(2)
.max_size(2)
.test_on_acquire(false)
.build(&dotenv::var("DATABASE_URL")?)
.await?;
let rows = pool.fetch_all("SELECT 1; SElECT 2").await?;
assert_eq!(rows.len(), 2);
Ok(())
}
// https://github.com/launchbadge/sqlx/issues/104
#[sqlx_macros::test]
async fn it_can_return_interleaved_nulls_issue_104() -> anyhow::Result<()> {

View file

@ -1,5 +1,7 @@
use futures::TryStreamExt;
use sqlx::{query, sqlite::Sqlite, Connect, Connection, Executor, Row, SqliteConnection};
use sqlx::{
query, sqlite::Sqlite, Connect, Connection, Executor, Row, SqliteConnection, SqlitePool,
};
use sqlx_test::new;
#[sqlx_macros::test]
@ -87,6 +89,22 @@ async fn it_fetches_in_loop() -> anyhow::Result<()> {
Ok(())
}
#[sqlx_macros::test]
async fn it_executes_with_pool() -> anyhow::Result<()> {
let pool: SqlitePool = SqlitePool::builder()
.min_size(2)
.max_size(2)
.test_on_acquire(false)
.build(&dotenv::var("DATABASE_URL")?)
.await?;
let rows = pool.fetch_all("SELECT 1; SElECT 2").await?;
assert_eq!(rows.len(), 2);
Ok(())
}
#[sqlx_macros::test]
async fn it_opens_in_memory() -> anyhow::Result<()> {
// If the filename is ":memory:", then a private, temporary in-memory database