mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 06:24:16 +00:00
fix cancellation issues with PgListener
, PgStream::recv()
(#3467)
* fix(postgres): make `PgStream::recv_unchecked()` cancel-safe * fix(postgres): make `PgListener` close the connection on-error * fix: incorrect math in `BufferedSocket::read_buffered()`
This commit is contained in:
parent
20ba796b0d
commit
e10789d9d7
4 changed files with 120 additions and 21 deletions
|
@ -1,9 +1,9 @@
|
|||
use crate::error::Error;
|
||||
use crate::net::Socket;
|
||||
use bytes::BytesMut;
|
||||
use std::ops::ControlFlow;
|
||||
use std::{cmp, io};
|
||||
|
||||
use crate::error::Error;
|
||||
|
||||
use crate::io::{AsyncRead, AsyncReadExt, ProtocolDecode, ProtocolEncode};
|
||||
|
||||
// Tokio, async-std, and std all use this as the default capacity for their buffered I/O.
|
||||
|
@ -45,8 +45,39 @@ impl<S: Socket> BufferedSocket<S> {
|
|||
}
|
||||
}
|
||||
|
||||
pub async fn read_buffered(&mut self, len: usize) -> io::Result<BytesMut> {
|
||||
self.read_buf.read(len, &mut self.socket).await
|
||||
pub async fn read_buffered(&mut self, len: usize) -> Result<BytesMut, Error> {
|
||||
self.try_read(|buf| {
|
||||
Ok(if buf.len() < len {
|
||||
ControlFlow::Continue(len)
|
||||
} else {
|
||||
ControlFlow::Break(buf.split_to(len))
|
||||
})
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Retryable read operation.
|
||||
///
|
||||
/// The callback should check the contents of the buffer passed to it and either:
|
||||
///
|
||||
/// * Remove a full message from the buffer and return [`ControlFlow::Break`], or:
|
||||
/// * Return [`ControlFlow::Continue`] with the expected _total_ length of the buffer,
|
||||
/// _without_ modifying it.
|
||||
///
|
||||
/// Cancel-safe as long as the callback does not modify the passed `BytesMut`
|
||||
/// before returning [`ControlFlow::Continue`].
|
||||
pub async fn try_read<F, R>(&mut self, mut try_read: F) -> Result<R, Error>
|
||||
where
|
||||
F: FnMut(&mut BytesMut) -> Result<ControlFlow<R, usize>, Error>,
|
||||
{
|
||||
loop {
|
||||
let read_len = match try_read(&mut self.read_buf.read)? {
|
||||
ControlFlow::Continue(read_len) => read_len,
|
||||
ControlFlow::Break(ret) => return Ok(ret),
|
||||
};
|
||||
|
||||
self.read_buf.read(read_len, &mut self.socket).await?;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_buffer(&self) -> &WriteBuffer {
|
||||
|
@ -244,7 +275,7 @@ impl WriteBuffer {
|
|||
}
|
||||
|
||||
impl ReadBuffer {
|
||||
async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result<BytesMut> {
|
||||
async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result<()> {
|
||||
// Because of how `BytesMut` works, we should only be shifting capacity back and forth
|
||||
// between `read` and `available` unless we have to read an oversize message.
|
||||
while self.read.len() < len {
|
||||
|
@ -266,7 +297,7 @@ impl ReadBuffer {
|
|||
self.advance(read);
|
||||
}
|
||||
|
||||
Ok(self.drain(len))
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn reserve(&mut self, amt: usize) {
|
||||
|
@ -279,10 +310,6 @@ impl ReadBuffer {
|
|||
self.read.unsplit(self.available.split_to(amt));
|
||||
}
|
||||
|
||||
fn drain(&mut self, amt: usize) -> BytesMut {
|
||||
self.read.split_to(amt)
|
||||
}
|
||||
|
||||
fn shrink(&mut self) {
|
||||
if self.available.capacity() > DEFAULT_BUF_SIZE {
|
||||
// `BytesMut` doesn't have a way to shrink its capacity,
|
||||
|
|
|
@ -13,11 +13,14 @@ use super::inner::{is_beyond_max_lifetime, DecrementSizeGuard, PoolInner};
|
|||
use crate::pool::options::PoolConnectionMetadata;
|
||||
use std::future::Future;
|
||||
|
||||
const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
/// A connection managed by a [`Pool`][crate::pool::Pool].
|
||||
///
|
||||
/// Will be returned to the pool on-drop.
|
||||
pub struct PoolConnection<DB: Database> {
|
||||
live: Option<Live<DB>>,
|
||||
close_on_drop: bool,
|
||||
pub(crate) pool: Arc<PoolInner<DB>>,
|
||||
}
|
||||
|
||||
|
@ -85,6 +88,16 @@ impl<DB: Database> PoolConnection<DB> {
|
|||
floating.inner.raw.close().await
|
||||
}
|
||||
|
||||
/// Close this connection on-drop, instead of returning it to the pool.
|
||||
///
|
||||
/// May be used in cases where waiting for the [`.close()`][Self::close] call
|
||||
/// to complete is unacceptable, but you still want the connection to be closed gracefully
|
||||
/// so that the server can clean up resources.
|
||||
#[inline(always)]
|
||||
pub fn close_on_drop(&mut self) {
|
||||
self.close_on_drop = true;
|
||||
}
|
||||
|
||||
/// Detach this connection from the pool, allowing it to open a replacement.
|
||||
///
|
||||
/// Note that if your application uses a single shared pool, this
|
||||
|
@ -140,6 +153,27 @@ impl<DB: Database> PoolConnection<DB> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn take_and_close(&mut self) -> impl Future<Output = ()> + Send + 'static {
|
||||
// float the connection in the pool before we move into the task
|
||||
// in case the returned `Future` isn't executed, like if it's spawned into a dying runtime
|
||||
// https://github.com/launchbadge/sqlx/issues/1396
|
||||
// Type hints seem to be broken by `Option` combinators in IntelliJ Rust right now (6/22).
|
||||
let floating = self.live.take().map(|live| live.float(self.pool.clone()));
|
||||
|
||||
let pool = self.pool.clone();
|
||||
|
||||
async move {
|
||||
if let Some(floating) = floating {
|
||||
// Don't hold the connection forever if it hangs while trying to close
|
||||
crate::rt::timeout(CLOSE_ON_DROP_TIMEOUT, floating.close())
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
|
||||
pool.min_connections_maintenance(None).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection<DB> {
|
||||
|
@ -164,6 +198,11 @@ impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection<DB
|
|||
/// Returns the connection to the [`Pool`][crate::pool::Pool] it was checked-out from.
|
||||
impl<DB: Database> Drop for PoolConnection<DB> {
|
||||
fn drop(&mut self) {
|
||||
if self.close_on_drop {
|
||||
crate::rt::spawn(self.take_and_close());
|
||||
return;
|
||||
}
|
||||
|
||||
// We still need to spawn a task to maintain `min_connections`.
|
||||
if self.live.is_some() || self.pool.options.min_connections > 0 {
|
||||
crate::rt::spawn(self.return_to_pool());
|
||||
|
@ -221,6 +260,7 @@ impl<DB: Database> Floating<DB, Live<DB>> {
|
|||
guard.cancel();
|
||||
PoolConnection {
|
||||
live: Some(inner),
|
||||
close_on_drop: false,
|
||||
pool,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
use std::collections::BTreeMap;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::ops::{ControlFlow, Deref, DerefMut};
|
||||
use std::str::FromStr;
|
||||
|
||||
use futures_channel::mpsc::UnboundedSender;
|
||||
use futures_util::SinkExt;
|
||||
use log::Level;
|
||||
use sqlx_core::bytes::{Buf, Bytes};
|
||||
use sqlx_core::bytes::Buf;
|
||||
|
||||
use crate::connection::tls::MaybeUpgradeTls;
|
||||
use crate::error::Error;
|
||||
|
@ -77,16 +77,45 @@ impl PgStream {
|
|||
}
|
||||
|
||||
pub(crate) async fn recv_unchecked(&mut self) -> Result<ReceivedMessage, Error> {
|
||||
// NOTE: to not break everything, this should be cancel-safe;
|
||||
// DO NOT modify `buf` unless a full message has been read
|
||||
self.inner
|
||||
.try_read(|buf| {
|
||||
// all packets in postgres start with a 5-byte header
|
||||
// this header contains the message type and the total length of the message
|
||||
let mut header: Bytes = self.inner.read(5).await?;
|
||||
let Some(mut header) = buf.get(..5) else {
|
||||
return Ok(ControlFlow::Continue(5));
|
||||
};
|
||||
|
||||
let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
|
||||
let size = (header.get_u32() - 4) as usize;
|
||||
|
||||
let contents = self.inner.read(size).await?;
|
||||
let message_len = header.get_u32() as usize;
|
||||
|
||||
Ok(ReceivedMessage { format, contents })
|
||||
let expected_len = message_len
|
||||
.checked_add(1)
|
||||
// this shouldn't really happen but is mostly a sanity check
|
||||
.ok_or_else(|| {
|
||||
err_protocol!("message_len + 1 overflows usize: {message_len}")
|
||||
})?;
|
||||
|
||||
if buf.len() < expected_len {
|
||||
return Ok(ControlFlow::Continue(expected_len));
|
||||
}
|
||||
|
||||
// `buf` SHOULD NOT be modified ABOVE this line
|
||||
|
||||
// pop off the format code since it's not counted in `message_len`
|
||||
buf.advance(1);
|
||||
|
||||
// consume the message, including the length prefix
|
||||
let mut contents = buf.split_to(message_len).freeze();
|
||||
|
||||
// cut off the length prefix
|
||||
contents.advance(4);
|
||||
|
||||
Ok(ControlFlow::Break(ReceivedMessage { format, contents }))
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
// Get the next message from the server
|
||||
|
|
|
@ -262,8 +262,11 @@ impl PgListener {
|
|||
if (err.kind() == io::ErrorKind::ConnectionAborted
|
||||
|| err.kind() == io::ErrorKind::UnexpectedEof) =>
|
||||
{
|
||||
self.buffer_tx = self.connection().await?.stream.notifications.take();
|
||||
self.connection = None;
|
||||
if let Some(mut conn) = self.connection.take() {
|
||||
self.buffer_tx = conn.stream.notifications.take();
|
||||
// Close the connection in a background task, so we can continue.
|
||||
conn.close_on_drop();
|
||||
}
|
||||
|
||||
// lost connection
|
||||
return Ok(None);
|
||||
|
|
Loading…
Reference in a new issue