mirror of
https://github.com/launchbadge/sqlx
synced 2024-09-20 06:11:57 +00:00
fix: ensure PG connection is established before using it (#1989)
Fixes #1940.
This commit is contained in:
parent
5e08cd077e
commit
29073cbe84
1 changed files with 40 additions and 21 deletions
|
@ -6,6 +6,7 @@ use either::Either;
|
|||
use futures_channel::mpsc;
|
||||
use futures_core::future::BoxFuture;
|
||||
use futures_core::stream::{BoxStream, Stream};
|
||||
use futures_util::{FutureExt, StreamExt, TryStreamExt};
|
||||
|
||||
use crate::describe::Describe;
|
||||
use crate::error::Error;
|
||||
|
@ -96,6 +97,7 @@ impl PgListener {
|
|||
/// The channel name is quoted here to ensure case sensitivity.
|
||||
pub async fn listen(&mut self, channel: &str) -> Result<(), Error> {
|
||||
self.connection()
|
||||
.await?
|
||||
.execute(&*format!(r#"LISTEN "{}""#, ident(channel)))
|
||||
.await?;
|
||||
|
||||
|
@ -112,11 +114,8 @@ impl PgListener {
|
|||
let beg = self.channels.len();
|
||||
self.channels.extend(channels.into_iter().map(|s| s.into()));
|
||||
|
||||
self.connection
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.execute(&*build_listen_all_query(&self.channels[beg..]))
|
||||
.await?;
|
||||
let query = build_listen_all_query(&self.channels[beg..]);
|
||||
self.connection().await?.execute(&*query).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -124,9 +123,13 @@ impl PgListener {
|
|||
/// Stops listening for notifications on a channel.
|
||||
/// The channel name is quoted here to ensure case sensitivity.
|
||||
pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> {
|
||||
self.connection()
|
||||
.execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
|
||||
.await?;
|
||||
// use RAW connection and do NOT re-connect automatically, since this is not required for
|
||||
// UNLISTEN (we've disconnected anyways)
|
||||
if let Some(connection) = self.connection.as_mut() {
|
||||
connection
|
||||
.execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
|
||||
.await?;
|
||||
}
|
||||
|
||||
if let Some(pos) = self.channels.iter().position(|s| s == channel) {
|
||||
self.channels.remove(pos);
|
||||
|
@ -137,7 +140,11 @@ impl PgListener {
|
|||
|
||||
/// Stops listening for notifications on all channels.
|
||||
pub async fn unlisten_all(&mut self) -> Result<(), Error> {
|
||||
self.connection().execute("UNLISTEN *").await?;
|
||||
// use RAW connection and do NOT re-connect automatically, since this is not required for
|
||||
// UNLISTEN (we've disconnected anyways)
|
||||
if let Some(connection) = self.connection.as_mut() {
|
||||
connection.execute("UNLISTEN *").await?;
|
||||
}
|
||||
|
||||
self.channels.clear();
|
||||
|
||||
|
@ -161,8 +168,11 @@ impl PgListener {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
fn connection(&mut self) -> &mut PgConnection {
|
||||
self.connection.as_mut().unwrap()
|
||||
async fn connection(&mut self) -> Result<&mut PgConnection, Error> {
|
||||
// Ensure we have an active connection to work with.
|
||||
self.connect_if_needed().await?;
|
||||
|
||||
Ok(self.connection.as_mut().unwrap())
|
||||
}
|
||||
|
||||
/// Receives the next notification available from any of the subscribed channels.
|
||||
|
@ -237,10 +247,7 @@ impl PgListener {
|
|||
let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event());
|
||||
|
||||
loop {
|
||||
// Ensure we have an active connection to work with.
|
||||
self.connect_if_needed().await?;
|
||||
|
||||
let next_message = self.connection().stream.recv_unchecked();
|
||||
let next_message = self.connection().await?.stream.recv_unchecked();
|
||||
|
||||
let res = if let Some(ref mut close_event) = close_event {
|
||||
// cancels the wait and returns `Err(PoolClosed)` if the pool is closed
|
||||
|
@ -256,7 +263,7 @@ impl PgListener {
|
|||
// The connection is dead, ensure that it is dropped,
|
||||
// update self state, and loop to try again.
|
||||
Err(Error::Io(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
|
||||
self.buffer_tx = self.connection().stream.notifications.take();
|
||||
self.buffer_tx = self.connection().await?.stream.notifications.take();
|
||||
self.connection = None;
|
||||
|
||||
// lost connection
|
||||
|
@ -277,7 +284,7 @@ impl PgListener {
|
|||
|
||||
// Mark the connection as ready for another query
|
||||
MessageFormat::ReadyForQuery => {
|
||||
self.connection().pending_ready_for_query_count -= 1;
|
||||
self.connection().await?.pending_ready_for_query_count -= 1;
|
||||
}
|
||||
|
||||
// Ignore unexpected messages
|
||||
|
@ -336,7 +343,13 @@ impl<'c> Executor<'c> for &'c mut PgListener {
|
|||
'c: 'e,
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
self.connection().fetch_many(query)
|
||||
futures_util::stream::once(async move {
|
||||
// need some basic type annotation to help the compiler a bit
|
||||
let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query));
|
||||
res
|
||||
})
|
||||
.try_flatten()
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn fetch_optional<'e, 'q: 'e, E: 'q>(
|
||||
|
@ -347,7 +360,7 @@ impl<'c> Executor<'c> for &'c mut PgListener {
|
|||
'c: 'e,
|
||||
E: Execute<'q, Self::Database>,
|
||||
{
|
||||
self.connection().fetch_optional(query)
|
||||
async move { self.connection().await?.fetch_optional(query).await }.boxed()
|
||||
}
|
||||
|
||||
fn prepare_with<'e, 'q: 'e>(
|
||||
|
@ -358,7 +371,13 @@ impl<'c> Executor<'c> for &'c mut PgListener {
|
|||
where
|
||||
'c: 'e,
|
||||
{
|
||||
self.connection().prepare_with(query, parameters)
|
||||
async move {
|
||||
self.connection()
|
||||
.await?
|
||||
.prepare_with(query, parameters)
|
||||
.await
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
|
@ -369,7 +388,7 @@ impl<'c> Executor<'c> for &'c mut PgListener {
|
|||
where
|
||||
'c: 'e,
|
||||
{
|
||||
self.connection().describe(query)
|
||||
async move { self.connection().await?.describe(query).await }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue