mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 06:24:16 +00:00
fix cancellation issues with PgListener
, PgStream::recv()
(#3467)
* fix(postgres): make `PgStream::recv_unchecked()` cancel-safe * fix(postgres): make `PgListener` close the connection on-error * fix: incorrect math in `BufferedSocket::read_buffered()`
This commit is contained in:
parent
20ba796b0d
commit
e10789d9d7
4 changed files with 120 additions and 21 deletions
|
@ -1,9 +1,9 @@
|
||||||
|
use crate::error::Error;
|
||||||
use crate::net::Socket;
|
use crate::net::Socket;
|
||||||
use bytes::BytesMut;
|
use bytes::BytesMut;
|
||||||
|
use std::ops::ControlFlow;
|
||||||
use std::{cmp, io};
|
use std::{cmp, io};
|
||||||
|
|
||||||
use crate::error::Error;
|
|
||||||
|
|
||||||
use crate::io::{AsyncRead, AsyncReadExt, ProtocolDecode, ProtocolEncode};
|
use crate::io::{AsyncRead, AsyncReadExt, ProtocolDecode, ProtocolEncode};
|
||||||
|
|
||||||
// Tokio, async-std, and std all use this as the default capacity for their buffered I/O.
|
// Tokio, async-std, and std all use this as the default capacity for their buffered I/O.
|
||||||
|
@ -45,8 +45,39 @@ impl<S: Socket> BufferedSocket<S> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn read_buffered(&mut self, len: usize) -> io::Result<BytesMut> {
|
pub async fn read_buffered(&mut self, len: usize) -> Result<BytesMut, Error> {
|
||||||
self.read_buf.read(len, &mut self.socket).await
|
self.try_read(|buf| {
|
||||||
|
Ok(if buf.len() < len {
|
||||||
|
ControlFlow::Continue(len)
|
||||||
|
} else {
|
||||||
|
ControlFlow::Break(buf.split_to(len))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retryable read operation.
|
||||||
|
///
|
||||||
|
/// The callback should check the contents of the buffer passed to it and either:
|
||||||
|
///
|
||||||
|
/// * Remove a full message from the buffer and return [`ControlFlow::Break`], or:
|
||||||
|
/// * Return [`ControlFlow::Continue`] with the expected _total_ length of the buffer,
|
||||||
|
/// _without_ modifying it.
|
||||||
|
///
|
||||||
|
/// Cancel-safe as long as the callback does not modify the passed `BytesMut`
|
||||||
|
/// before returning [`ControlFlow::Continue`].
|
||||||
|
pub async fn try_read<F, R>(&mut self, mut try_read: F) -> Result<R, Error>
|
||||||
|
where
|
||||||
|
F: FnMut(&mut BytesMut) -> Result<ControlFlow<R, usize>, Error>,
|
||||||
|
{
|
||||||
|
loop {
|
||||||
|
let read_len = match try_read(&mut self.read_buf.read)? {
|
||||||
|
ControlFlow::Continue(read_len) => read_len,
|
||||||
|
ControlFlow::Break(ret) => return Ok(ret),
|
||||||
|
};
|
||||||
|
|
||||||
|
self.read_buf.read(read_len, &mut self.socket).await?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn write_buffer(&self) -> &WriteBuffer {
|
pub fn write_buffer(&self) -> &WriteBuffer {
|
||||||
|
@ -244,7 +275,7 @@ impl WriteBuffer {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ReadBuffer {
|
impl ReadBuffer {
|
||||||
async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result<BytesMut> {
|
async fn read(&mut self, len: usize, socket: &mut impl Socket) -> io::Result<()> {
|
||||||
// Because of how `BytesMut` works, we should only be shifting capacity back and forth
|
// Because of how `BytesMut` works, we should only be shifting capacity back and forth
|
||||||
// between `read` and `available` unless we have to read an oversize message.
|
// between `read` and `available` unless we have to read an oversize message.
|
||||||
while self.read.len() < len {
|
while self.read.len() < len {
|
||||||
|
@ -266,7 +297,7 @@ impl ReadBuffer {
|
||||||
self.advance(read);
|
self.advance(read);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(self.drain(len))
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn reserve(&mut self, amt: usize) {
|
fn reserve(&mut self, amt: usize) {
|
||||||
|
@ -279,10 +310,6 @@ impl ReadBuffer {
|
||||||
self.read.unsplit(self.available.split_to(amt));
|
self.read.unsplit(self.available.split_to(amt));
|
||||||
}
|
}
|
||||||
|
|
||||||
fn drain(&mut self, amt: usize) -> BytesMut {
|
|
||||||
self.read.split_to(amt)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn shrink(&mut self) {
|
fn shrink(&mut self) {
|
||||||
if self.available.capacity() > DEFAULT_BUF_SIZE {
|
if self.available.capacity() > DEFAULT_BUF_SIZE {
|
||||||
// `BytesMut` doesn't have a way to shrink its capacity,
|
// `BytesMut` doesn't have a way to shrink its capacity,
|
||||||
|
|
|
@ -13,11 +13,14 @@ use super::inner::{is_beyond_max_lifetime, DecrementSizeGuard, PoolInner};
|
||||||
use crate::pool::options::PoolConnectionMetadata;
|
use crate::pool::options::PoolConnectionMetadata;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
|
|
||||||
|
const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5);
|
||||||
|
|
||||||
/// A connection managed by a [`Pool`][crate::pool::Pool].
|
/// A connection managed by a [`Pool`][crate::pool::Pool].
|
||||||
///
|
///
|
||||||
/// Will be returned to the pool on-drop.
|
/// Will be returned to the pool on-drop.
|
||||||
pub struct PoolConnection<DB: Database> {
|
pub struct PoolConnection<DB: Database> {
|
||||||
live: Option<Live<DB>>,
|
live: Option<Live<DB>>,
|
||||||
|
close_on_drop: bool,
|
||||||
pub(crate) pool: Arc<PoolInner<DB>>,
|
pub(crate) pool: Arc<PoolInner<DB>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,6 +88,16 @@ impl<DB: Database> PoolConnection<DB> {
|
||||||
floating.inner.raw.close().await
|
floating.inner.raw.close().await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Close this connection on-drop, instead of returning it to the pool.
|
||||||
|
///
|
||||||
|
/// May be used in cases where waiting for the [`.close()`][Self::close] call
|
||||||
|
/// to complete is unacceptable, but you still want the connection to be closed gracefully
|
||||||
|
/// so that the server can clean up resources.
|
||||||
|
#[inline(always)]
|
||||||
|
pub fn close_on_drop(&mut self) {
|
||||||
|
self.close_on_drop = true;
|
||||||
|
}
|
||||||
|
|
||||||
/// Detach this connection from the pool, allowing it to open a replacement.
|
/// Detach this connection from the pool, allowing it to open a replacement.
|
||||||
///
|
///
|
||||||
/// Note that if your application uses a single shared pool, this
|
/// Note that if your application uses a single shared pool, this
|
||||||
|
@ -140,6 +153,27 @@ impl<DB: Database> PoolConnection<DB> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn take_and_close(&mut self) -> impl Future<Output = ()> + Send + 'static {
|
||||||
|
// float the connection in the pool before we move into the task
|
||||||
|
// in case the returned `Future` isn't executed, like if it's spawned into a dying runtime
|
||||||
|
// https://github.com/launchbadge/sqlx/issues/1396
|
||||||
|
// Type hints seem to be broken by `Option` combinators in IntelliJ Rust right now (6/22).
|
||||||
|
let floating = self.live.take().map(|live| live.float(self.pool.clone()));
|
||||||
|
|
||||||
|
let pool = self.pool.clone();
|
||||||
|
|
||||||
|
async move {
|
||||||
|
if let Some(floating) = floating {
|
||||||
|
// Don't hold the connection forever if it hangs while trying to close
|
||||||
|
crate::rt::timeout(CLOSE_ON_DROP_TIMEOUT, floating.close())
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
pool.min_connections_maintenance(None).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection<DB> {
|
impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection<DB> {
|
||||||
|
@ -164,6 +198,11 @@ impl<'c, DB: Database> crate::acquire::Acquire<'c> for &'c mut PoolConnection<DB
|
||||||
/// Returns the connection to the [`Pool`][crate::pool::Pool] it was checked-out from.
|
/// Returns the connection to the [`Pool`][crate::pool::Pool] it was checked-out from.
|
||||||
impl<DB: Database> Drop for PoolConnection<DB> {
|
impl<DB: Database> Drop for PoolConnection<DB> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
|
if self.close_on_drop {
|
||||||
|
crate::rt::spawn(self.take_and_close());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// We still need to spawn a task to maintain `min_connections`.
|
// We still need to spawn a task to maintain `min_connections`.
|
||||||
if self.live.is_some() || self.pool.options.min_connections > 0 {
|
if self.live.is_some() || self.pool.options.min_connections > 0 {
|
||||||
crate::rt::spawn(self.return_to_pool());
|
crate::rt::spawn(self.return_to_pool());
|
||||||
|
@ -221,6 +260,7 @@ impl<DB: Database> Floating<DB, Live<DB>> {
|
||||||
guard.cancel();
|
guard.cancel();
|
||||||
PoolConnection {
|
PoolConnection {
|
||||||
live: Some(inner),
|
live: Some(inner),
|
||||||
|
close_on_drop: false,
|
||||||
pool,
|
pool,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
use std::ops::{Deref, DerefMut};
|
use std::ops::{ControlFlow, Deref, DerefMut};
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
use futures_channel::mpsc::UnboundedSender;
|
use futures_channel::mpsc::UnboundedSender;
|
||||||
use futures_util::SinkExt;
|
use futures_util::SinkExt;
|
||||||
use log::Level;
|
use log::Level;
|
||||||
use sqlx_core::bytes::{Buf, Bytes};
|
use sqlx_core::bytes::Buf;
|
||||||
|
|
||||||
use crate::connection::tls::MaybeUpgradeTls;
|
use crate::connection::tls::MaybeUpgradeTls;
|
||||||
use crate::error::Error;
|
use crate::error::Error;
|
||||||
|
@ -77,16 +77,45 @@ impl PgStream {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn recv_unchecked(&mut self) -> Result<ReceivedMessage, Error> {
|
pub(crate) async fn recv_unchecked(&mut self) -> Result<ReceivedMessage, Error> {
|
||||||
// all packets in postgres start with a 5-byte header
|
// NOTE: to not break everything, this should be cancel-safe;
|
||||||
// this header contains the message type and the total length of the message
|
// DO NOT modify `buf` unless a full message has been read
|
||||||
let mut header: Bytes = self.inner.read(5).await?;
|
self.inner
|
||||||
|
.try_read(|buf| {
|
||||||
|
// all packets in postgres start with a 5-byte header
|
||||||
|
// this header contains the message type and the total length of the message
|
||||||
|
let Some(mut header) = buf.get(..5) else {
|
||||||
|
return Ok(ControlFlow::Continue(5));
|
||||||
|
};
|
||||||
|
|
||||||
let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
|
let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
|
||||||
let size = (header.get_u32() - 4) as usize;
|
|
||||||
|
|
||||||
let contents = self.inner.read(size).await?;
|
let message_len = header.get_u32() as usize;
|
||||||
|
|
||||||
Ok(ReceivedMessage { format, contents })
|
let expected_len = message_len
|
||||||
|
.checked_add(1)
|
||||||
|
// this shouldn't really happen but is mostly a sanity check
|
||||||
|
.ok_or_else(|| {
|
||||||
|
err_protocol!("message_len + 1 overflows usize: {message_len}")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if buf.len() < expected_len {
|
||||||
|
return Ok(ControlFlow::Continue(expected_len));
|
||||||
|
}
|
||||||
|
|
||||||
|
// `buf` SHOULD NOT be modified ABOVE this line
|
||||||
|
|
||||||
|
// pop off the format code since it's not counted in `message_len`
|
||||||
|
buf.advance(1);
|
||||||
|
|
||||||
|
// consume the message, including the length prefix
|
||||||
|
let mut contents = buf.split_to(message_len).freeze();
|
||||||
|
|
||||||
|
// cut off the length prefix
|
||||||
|
contents.advance(4);
|
||||||
|
|
||||||
|
Ok(ControlFlow::Break(ReceivedMessage { format, contents }))
|
||||||
|
})
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the next message from the server
|
// Get the next message from the server
|
||||||
|
|
|
@ -262,8 +262,11 @@ impl PgListener {
|
||||||
if (err.kind() == io::ErrorKind::ConnectionAborted
|
if (err.kind() == io::ErrorKind::ConnectionAborted
|
||||||
|| err.kind() == io::ErrorKind::UnexpectedEof) =>
|
|| err.kind() == io::ErrorKind::UnexpectedEof) =>
|
||||||
{
|
{
|
||||||
self.buffer_tx = self.connection().await?.stream.notifications.take();
|
if let Some(mut conn) = self.connection.take() {
|
||||||
self.connection = None;
|
self.buffer_tx = conn.stream.notifications.take();
|
||||||
|
// Close the connection in a background task, so we can continue.
|
||||||
|
conn.close_on_drop();
|
||||||
|
}
|
||||||
|
|
||||||
// lost connection
|
// lost connection
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
|
|
Loading…
Reference in a new issue