diff --git a/sqlx-core/src/postgres/listener.rs b/sqlx-core/src/postgres/listener.rs index 03bb34de..1432ae6c 100644 --- a/sqlx-core/src/postgres/listener.rs +++ b/sqlx-core/src/postgres/listener.rs @@ -6,6 +6,7 @@ use either::Either; use futures_channel::mpsc; use futures_core::future::BoxFuture; use futures_core::stream::{BoxStream, Stream}; +use futures_util::{FutureExt, StreamExt, TryStreamExt}; use crate::describe::Describe; use crate::error::Error; @@ -96,6 +97,7 @@ impl PgListener { /// The channel name is quoted here to ensure case sensitivity. pub async fn listen(&mut self, channel: &str) -> Result<(), Error> { self.connection() + .await? .execute(&*format!(r#"LISTEN "{}""#, ident(channel))) .await?; @@ -112,11 +114,8 @@ impl PgListener { let beg = self.channels.len(); self.channels.extend(channels.into_iter().map(|s| s.into())); - self.connection - .as_mut() - .unwrap() - .execute(&*build_listen_all_query(&self.channels[beg..])) - .await?; + let query = build_listen_all_query(&self.channels[beg..]); + self.connection().await?.execute(&*query).await?; Ok(()) } @@ -124,9 +123,13 @@ impl PgListener { /// Stops listening for notifications on a channel. /// The channel name is quoted here to ensure case sensitivity. pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> { - self.connection() - .execute(&*format!(r#"UNLISTEN "{}""#, ident(channel))) - .await?; + // use RAW connection and do NOT re-connect automatically, since this is not required for + // UNLISTEN (we've disconnected anyways) + if let Some(connection) = self.connection.as_mut() { + connection + .execute(&*format!(r#"UNLISTEN "{}""#, ident(channel))) + .await?; + } if let Some(pos) = self.channels.iter().position(|s| s == channel) { self.channels.remove(pos); @@ -137,7 +140,11 @@ impl PgListener { /// Stops listening for notifications on all channels. pub async fn unlisten_all(&mut self) -> Result<(), Error> { - self.connection().execute("UNLISTEN *").await?; + // use RAW connection and do NOT re-connect automatically, since this is not required for + // UNLISTEN (we've disconnected anyways) + if let Some(connection) = self.connection.as_mut() { + connection.execute("UNLISTEN *").await?; + } self.channels.clear(); @@ -161,8 +168,11 @@ impl PgListener { } #[inline] - fn connection(&mut self) -> &mut PgConnection { - self.connection.as_mut().unwrap() + async fn connection(&mut self) -> Result<&mut PgConnection, Error> { + // Ensure we have an active connection to work with. + self.connect_if_needed().await?; + + Ok(self.connection.as_mut().unwrap()) } /// Receives the next notification available from any of the subscribed channels. @@ -237,10 +247,7 @@ impl PgListener { let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event()); loop { - // Ensure we have an active connection to work with. - self.connect_if_needed().await?; - - let next_message = self.connection().stream.recv_unchecked(); + let next_message = self.connection().await?.stream.recv_unchecked(); let res = if let Some(ref mut close_event) = close_event { // cancels the wait and returns `Err(PoolClosed)` if the pool is closed @@ -256,7 +263,7 @@ impl PgListener { // The connection is dead, ensure that it is dropped, // update self state, and loop to try again. Err(Error::Io(err)) if err.kind() == io::ErrorKind::ConnectionAborted => { - self.buffer_tx = self.connection().stream.notifications.take(); + self.buffer_tx = self.connection().await?.stream.notifications.take(); self.connection = None; // lost connection @@ -277,7 +284,7 @@ impl PgListener { // Mark the connection as ready for another query MessageFormat::ReadyForQuery => { - self.connection().pending_ready_for_query_count -= 1; + self.connection().await?.pending_ready_for_query_count -= 1; } // Ignore unexpected messages @@ -336,7 +343,13 @@ impl<'c> Executor<'c> for &'c mut PgListener { 'c: 'e, E: Execute<'q, Self::Database>, { - self.connection().fetch_many(query) + futures_util::stream::once(async move { + // need some basic type annotation to help the compiler a bit + let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query)); + res + }) + .try_flatten() + .boxed() } fn fetch_optional<'e, 'q: 'e, E: 'q>( @@ -347,7 +360,7 @@ impl<'c> Executor<'c> for &'c mut PgListener { 'c: 'e, E: Execute<'q, Self::Database>, { - self.connection().fetch_optional(query) + async move { self.connection().await?.fetch_optional(query).await }.boxed() } fn prepare_with<'e, 'q: 'e>( @@ -358,7 +371,13 @@ impl<'c> Executor<'c> for &'c mut PgListener { where 'c: 'e, { - self.connection().prepare_with(query, parameters) + async move { + self.connection() + .await? + .prepare_with(query, parameters) + .await + } + .boxed() } #[doc(hidden)] @@ -369,7 +388,7 @@ impl<'c> Executor<'c> for &'c mut PgListener { where 'c: 'e, { - self.connection().describe(query) + async move { self.connection().await?.describe(query).await }.boxed() } }