fix: ensure PG connection is established before using it (#1989)

Fixes #1940.
This commit is contained in:
Marco Neumann 2022-07-27 22:58:36 +02:00 committed by GitHub
parent 5e08cd077e
commit 29073cbe84
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -6,6 +6,7 @@ use either::Either;
use futures_channel::mpsc;
use futures_core::future::BoxFuture;
use futures_core::stream::{BoxStream, Stream};
use futures_util::{FutureExt, StreamExt, TryStreamExt};
use crate::describe::Describe;
use crate::error::Error;
@ -96,6 +97,7 @@ impl PgListener {
/// The channel name is quoted here to ensure case sensitivity.
pub async fn listen(&mut self, channel: &str) -> Result<(), Error> {
self.connection()
.await?
.execute(&*format!(r#"LISTEN "{}""#, ident(channel)))
.await?;
@ -112,11 +114,8 @@ impl PgListener {
let beg = self.channels.len();
self.channels.extend(channels.into_iter().map(|s| s.into()));
self.connection
.as_mut()
.unwrap()
.execute(&*build_listen_all_query(&self.channels[beg..]))
.await?;
let query = build_listen_all_query(&self.channels[beg..]);
self.connection().await?.execute(&*query).await?;
Ok(())
}
@ -124,9 +123,13 @@ impl PgListener {
/// Stops listening for notifications on a channel.
/// The channel name is quoted here to ensure case sensitivity.
pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> {
self.connection()
.execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
.await?;
// use RAW connection and do NOT re-connect automatically, since this is not required for
// UNLISTEN (we've disconnected anyways)
if let Some(connection) = self.connection.as_mut() {
connection
.execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
.await?;
}
if let Some(pos) = self.channels.iter().position(|s| s == channel) {
self.channels.remove(pos);
@ -137,7 +140,11 @@ impl PgListener {
/// Stops listening for notifications on all channels.
pub async fn unlisten_all(&mut self) -> Result<(), Error> {
self.connection().execute("UNLISTEN *").await?;
// use RAW connection and do NOT re-connect automatically, since this is not required for
// UNLISTEN (we've disconnected anyways)
if let Some(connection) = self.connection.as_mut() {
connection.execute("UNLISTEN *").await?;
}
self.channels.clear();
@ -161,8 +168,11 @@ impl PgListener {
}
#[inline]
fn connection(&mut self) -> &mut PgConnection {
self.connection.as_mut().unwrap()
async fn connection(&mut self) -> Result<&mut PgConnection, Error> {
// Ensure we have an active connection to work with.
self.connect_if_needed().await?;
Ok(self.connection.as_mut().unwrap())
}
/// Receives the next notification available from any of the subscribed channels.
@ -237,10 +247,7 @@ impl PgListener {
let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event());
loop {
// Ensure we have an active connection to work with.
self.connect_if_needed().await?;
let next_message = self.connection().stream.recv_unchecked();
let next_message = self.connection().await?.stream.recv_unchecked();
let res = if let Some(ref mut close_event) = close_event {
// cancels the wait and returns `Err(PoolClosed)` if the pool is closed
@ -256,7 +263,7 @@ impl PgListener {
// The connection is dead, ensure that it is dropped,
// update self state, and loop to try again.
Err(Error::Io(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
self.buffer_tx = self.connection().stream.notifications.take();
self.buffer_tx = self.connection().await?.stream.notifications.take();
self.connection = None;
// lost connection
@ -277,7 +284,7 @@ impl PgListener {
// Mark the connection as ready for another query
MessageFormat::ReadyForQuery => {
self.connection().pending_ready_for_query_count -= 1;
self.connection().await?.pending_ready_for_query_count -= 1;
}
// Ignore unexpected messages
@ -336,7 +343,13 @@ impl<'c> Executor<'c> for &'c mut PgListener {
'c: 'e,
E: Execute<'q, Self::Database>,
{
self.connection().fetch_many(query)
futures_util::stream::once(async move {
// need some basic type annotation to help the compiler a bit
let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query));
res
})
.try_flatten()
.boxed()
}
fn fetch_optional<'e, 'q: 'e, E: 'q>(
@ -347,7 +360,7 @@ impl<'c> Executor<'c> for &'c mut PgListener {
'c: 'e,
E: Execute<'q, Self::Database>,
{
self.connection().fetch_optional(query)
async move { self.connection().await?.fetch_optional(query).await }.boxed()
}
fn prepare_with<'e, 'q: 'e>(
@ -358,7 +371,13 @@ impl<'c> Executor<'c> for &'c mut PgListener {
where
'c: 'e,
{
self.connection().prepare_with(query, parameters)
async move {
self.connection()
.await?
.prepare_with(query, parameters)
.await
}
.boxed()
}
#[doc(hidden)]
@ -369,7 +388,7 @@ impl<'c> Executor<'c> for &'c mut PgListener {
where
'c: 'e,
{
self.connection().describe(query)
async move { self.connection().await?.describe(query).await }.boxed()
}
}