fix cancellation issues with PgListener, PgStream::recv() (#3467)

* fix(postgres): make `PgStream::recv_unchecked()` cancel-safe

* fix(postgres): make `PgListener` close the connection on-error

* fix: incorrect math in `BufferedSocket::read_buffered()`
This commit is contained in:
Austin Bonander 2024-08-27 10:54:31 -07:00 committed by GitHub
parent 20ba796b0d
commit e10789d9d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 120 additions and 21 deletions

View file

@ -1,9 +1,9 @@
use crate::error::Error;
use crate::net::Socket; use crate::net::Socket;
use bytes::BytesMut; use bytes::BytesMut;
use std::ops::ControlFlow;
use std::{cmp, io}; use std::{cmp, io};
use crate::error::Error;
use crate::io::{AsyncRead, AsyncReadExt, ProtocolDecode, ProtocolEncode}; use crate::io::{AsyncRead, AsyncReadExt, ProtocolDecode, ProtocolEncode};
// Tokio, async-std, and std all use this as the default capacity for their buffered I/O. // Tokio, async-std, and std all use this as the default capacity for their buffered I/O.
@ -45,8 +45,39 @@ impl<S: Socket> BufferedSocket<S> {
} }
} }
pub async fn read_buffered(&mut self, len: usize) -> io::Result<BytesMut> { pub async fn read_buffered(&mut self, len: usize) -> Result<BytesMut, Error> {
self.read_buf.read(len, &mut self.socket).await self.try_read(|buf| {
Ok(if buf.len() < len {
ControlFlow::Continue(len)
} else {
ControlFlow::Break(buf.split_to(len))
})
})
.await
}
/// Retryable read operation.
///
/// The callback should check the contents of the buffer passed to it and either:
///
/// * Remove a full message from the buffer and return [`ControlFlow::Break`], or:
/// * Return [`ControlFlow::Continue`] with the expected _total_ length of the buffer,
/// _without_ modifying it.
///
/// Cancel-safe as long as the callback does not modify the passed `BytesMut`
/// before returning [`ControlFlow::Continue`].
pub async fn try_read<F, R>(&mut self, mut try_read: F) -> Result<R, Error>
where
F: FnMut(&mut BytesMut) -> Result<ControlFlow<R, usize>, Error>,
{
loop {
let read_len = match try_read(&mut self.read_buf.read)? {
ControlFlow::Continue(read_len) => read_len,
ControlFlow::Break(ret) => return Ok(ret),
};
self.read_buf.read(read_len, &mut self.socket).await?;
}
} }
pub fn write_buffer(&self) -> &WriteBuffer { pub fn write_buffer(&self) -> &WriteBuffer {
@ -244,7 +275,7 @@ impl WriteBuffer {
} }
impl ReadBuffer { impl ReadBuffer {
async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result<BytesMut> { async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result<()> {
// Because of how `BytesMut` works, we should only be shifting capacity back and forth // Because of how `BytesMut` works, we should only be shifting capacity back and forth
// between `read` and `available` unless we have to read an oversize message. // between `read` and `available` unless we have to read an oversize message.
while self.read.len() < len { while self.read.len() < len {
@ -266,7 +297,7 @@ impl ReadBuffer {
self.advance(read); self.advance(read);
} }
Ok(self.drain(len)) Ok(())
} }
fn reserve(&mut self, amt: usize) { fn reserve(&mut self, amt: usize) {
@ -279,10 +310,6 @@ impl ReadBuffer {
self.read.unsplit(self.available.split_to(amt)); self.read.unsplit(self.available.split_to(amt));
} }
fn drain(&mut self, amt: usize) -> BytesMut {
self.read.split_to(amt)
}
fn shrink(&mut self) { fn shrink(&mut self) {
if self.available.capacity() > DEFAULT_BUF_SIZE { if self.available.capacity() > DEFAULT_BUF_SIZE {
// `BytesMut` doesn't have a way to shrink its capacity, // `BytesMut` doesn't have a way to shrink its capacity,

View file

@ -13,11 +13,14 @@ use super::inner::{is_beyond_max_lifetime, DecrementSizeGuard, PoolInner};
use crate::pool::options::PoolConnectionMetadata; use crate::pool::options::PoolConnectionMetadata;
use std::future::Future; use std::future::Future;
const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5);
/// A connection managed by a [`Pool`][crate::pool::Pool]. /// A connection managed by a [`Pool`][crate::pool::Pool].
/// ///
/// Will be returned to the pool on-drop. /// Will be returned to the pool on-drop.
pub struct PoolConnection<DB: Database> { pub struct PoolConnection<DB: Database> {
live: Option<Live<DB>>, live: Option<Live<DB>>,
close_on_drop: bool,
pub(crate) pool: Arc<PoolInner<DB>>, pub(crate) pool: Arc<PoolInner<DB>>,
} }
@ -85,6 +88,16 @@ impl<DB: Database> PoolConnection<DB> {
floating.inner.raw.close().await floating.inner.raw.close().await
} }
/// Close this connection on-drop, instead of returning it to the pool.
///
/// May be used in cases where waiting for the [`.close()`][Self::close] call
/// to complete is unacceptable, but you still want the connection to be closed gracefully
/// so that the server can clean up resources.
#[inline(always)]
pub fn close_on_drop(&mut self) {
self.close_on_drop = true;
}
/// Detach this connection from the pool, allowing it to open a replacement. /// Detach this connection from the pool, allowing it to open a replacement.
/// ///
/// Note that if your application uses a single shared pool, this /// Note that if your application uses a single shared pool, this
@ -140,6 +153,27 @@ impl<DB: Database> PoolConnection<DB> {
} }
} }
} }
fn take_and_close(&mut self) -> impl Future<Output = ()> + Send + 'static {
// float the connection in the pool before we move into the task
// in case the returned `Future` isn't executed, like if it's spawned into a dying runtime
// https://github.com/launchbadge/sqlx/issues/1396
// Type hints seem to be broken by `Option` combinators in IntelliJ Rust right now (6/22).
let floating = self.live.take().map(|live| live.float(self.pool.clone()));
let pool = self.pool.clone();
async move {
if let Some(floating) = floating {
// Don't hold the connection forever if it hangs while trying to close
crate::rt::timeout(CLOSE_ON_DROP_TIMEOUT, floating.close())
.await
.ok();
}
pool.min_connections_maintenance(None).await;
}
}
} }
impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection<DB> { impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection<DB> {
@ -164,6 +198,11 @@ impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection<DB
/// Returns the connection to the [`Pool`][crate::pool::Pool] it was checked-out from. /// Returns the connection to the [`Pool`][crate::pool::Pool] it was checked-out from.
impl<DB: Database> Drop for PoolConnection<DB> { impl<DB: Database> Drop for PoolConnection<DB> {
fn drop(&mut self) { fn drop(&mut self) {
if self.close_on_drop {
crate::rt::spawn(self.take_and_close());
return;
}
// We still need to spawn a task to maintain `min_connections`. // We still need to spawn a task to maintain `min_connections`.
if self.live.is_some() || self.pool.options.min_connections > 0 { if self.live.is_some() || self.pool.options.min_connections > 0 {
crate::rt::spawn(self.return_to_pool()); crate::rt::spawn(self.return_to_pool());
@ -221,6 +260,7 @@ impl<DB: Database> Floating<DB, Live<DB>> {
guard.cancel(); guard.cancel();
PoolConnection { PoolConnection {
live: Some(inner), live: Some(inner),
close_on_drop: false,
pool, pool,
} }
} }

View file

@ -1,11 +1,11 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::ops::{Deref, DerefMut}; use std::ops::{ControlFlow, Deref, DerefMut};
use std::str::FromStr; use std::str::FromStr;
use futures_channel::mpsc::UnboundedSender; use futures_channel::mpsc::UnboundedSender;
use futures_util::SinkExt; use futures_util::SinkExt;
use log::Level; use log::Level;
use sqlx_core::bytes::{Buf, Bytes}; use sqlx_core::bytes::Buf;
use crate::connection::tls::MaybeUpgradeTls; use crate::connection::tls::MaybeUpgradeTls;
use crate::error::Error; use crate::error::Error;
@ -77,16 +77,45 @@ impl PgStream {
} }
pub(crate) async fn recv_unchecked(&mut self) -> Result<ReceivedMessage, Error> { pub(crate) async fn recv_unchecked(&mut self) -> Result<ReceivedMessage, Error> {
// all packets in postgres start with a 5-byte header // NOTE: to not break everything, this should be cancel-safe;
// this header contains the message type and the total length of the message // DO NOT modify `buf` unless a full message has been read
let mut header: Bytes = self.inner.read(5).await?; self.inner
.try_read(|buf| {
// all packets in postgres start with a 5-byte header
// this header contains the message type and the total length of the message
let Some(mut header) = buf.get(..5) else {
return Ok(ControlFlow::Continue(5));
};
let format = BackendMessageFormat::try_from_u8(header.get_u8())?; let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
let size = (header.get_u32() - 4) as usize;
let contents = self.inner.read(size).await?; let message_len = header.get_u32() as usize;
Ok(ReceivedMessage { format, contents }) let expected_len = message_len
.checked_add(1)
// this shouldn't really happen but is mostly a sanity check
.ok_or_else(|| {
err_protocol!("message_len + 1 overflows usize: {message_len}")
})?;
if buf.len() < expected_len {
return Ok(ControlFlow::Continue(expected_len));
}
// `buf` SHOULD NOT be modified ABOVE this line
// pop off the format code since it's not counted in `message_len`
buf.advance(1);
// consume the message, including the length prefix
let mut contents = buf.split_to(message_len).freeze();
// cut off the length prefix
contents.advance(4);
Ok(ControlFlow::Break(ReceivedMessage { format, contents }))
})
.await
} }
// Get the next message from the server // Get the next message from the server

View file

@ -262,8 +262,11 @@ impl PgListener {
if (err.kind() == io::ErrorKind::ConnectionAborted if (err.kind() == io::ErrorKind::ConnectionAborted
|| err.kind() == io::ErrorKind::UnexpectedEof) => || err.kind() == io::ErrorKind::UnexpectedEof) =>
{ {
self.buffer_tx = self.connection().await?.stream.notifications.take(); if let Some(mut conn) = self.connection.take() {
self.connection = None; self.buffer_tx = conn.stream.notifications.take();
// Close the connection in a background task, so we can continue.
conn.close_on_drop();
}
// lost connection // lost connection
return Ok(None); return Ok(None);