feat(mysql): support packet splitting (#2665)

* Writing split packets

* Reading split packets

* Add tests for packet splitting

* Fix test for packet splitting
This commit is contained in:
tk2217 2023-10-11 15:59:06 -05:00 committed by GitHub
parent b1387057e5
commit 5ebe296ecb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 86 additions and 11 deletions

View file

@ -1,7 +1,7 @@
use std::collections::VecDeque;
use std::ops::{Deref, DerefMut};
use bytes::{Buf, Bytes};
use bytes::{Buf, Bytes, BytesMut};
use crate::collation::{CharSet, Collation};
use crate::error::Error;
@ -126,9 +126,7 @@ impl<S: Socket> MySqlStream<S> {
.write_with(Packet(payload), (self.capabilities, &mut self.sequence_id));
}
// receive the next packet from the database server
// may block (async) on more data from the server
pub(crate) async fn recv_packet(&mut self) -> Result<Packet<Bytes>, Error> {
async fn recv_packet_part(&mut self) -> Result<Bytes, Error> {
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
// https://mariadb.com/kb/en/library/0-packet/#standard-packet
@ -142,10 +140,33 @@ impl<S: Socket> MySqlStream<S> {
let payload: Bytes = self.socket.read(packet_size).await?;
// TODO: packet compression
// TODO: packet joining
Ok(payload)
}
// receive the next packet from the database server
// may block (async) on more data from the server
pub(crate) async fn recv_packet(&mut self) -> Result<Packet<Bytes>, Error> {
let payload = self.recv_packet_part().await?;
let payload = if payload.len() < 0xFF_FF_FF {
payload
} else {
let mut final_payload = BytesMut::with_capacity(0xFF_FF_FF * 2);
final_payload.extend_from_slice(&payload);
drop(payload); // we don't need the allocation anymore
let mut last_read = 0xFF_FF_FF;
while last_read == 0xFF_FF_FF {
let part = self.recv_packet_part().await?;
last_read = part.len();
final_payload.extend_from_slice(&part);
}
final_payload.into()
};
if payload
.get(0)
.first()
.ok_or(err_protocol!("Packet empty"))?
.eq(&0xff)
{

View file

@ -1,3 +1,4 @@
use std::cmp::min;
use std::ops::{Deref, DerefMut};
use bytes::Bytes;
@ -19,6 +20,14 @@ where
buf: &mut Vec<u8>,
(capabilities, sequence_id): (Capabilities, &'stream mut u8),
) {
let mut next_header = |len: u32| {
let mut buf = len.to_le_bytes();
buf[3] = *sequence_id;
*sequence_id = sequence_id.wrapping_add(1);
buf
};
// reserve space to write the prefixed length
let offset = buf.len();
buf.extend(&[0_u8; 4]);
@ -31,13 +40,25 @@ where
let len = buf.len() - offset - 4;
let header = &mut buf[offset..];
// FIXME: Support larger packets
assert!(len < 0xFF_FF_FF);
header[..4].copy_from_slice(&next_header(min(len, 0xFF_FF_FF) as u32));
header[..4].copy_from_slice(&(len as u32).to_le_bytes());
header[3] = *sequence_id;
// add more packets if we need to split the data
if len >= 0xFF_FF_FF {
let rest = buf.split_off(offset + 4 + 0xFF_FF_FF);
let mut chunks = rest.chunks_exact(0xFF_FF_FF);
*sequence_id = sequence_id.wrapping_add(1);
for chunk in chunks.by_ref() {
buf.reserve(chunk.len() + 4);
buf.extend(&next_header(chunk.len() as u32));
buf.extend(chunk);
}
// this will also handle adding a zero sized packet if the data size is a multiple of 0xFF_FF_FF
let remainder = chunks.remainder();
buf.reserve(remainder.len() + 4);
buf.extend(&next_header(remainder.len() as u32));
buf.extend(remainder);
}
}
}

View file

@ -447,6 +447,39 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> {
Ok(())
}
#[sqlx_macros::test]
async fn it_can_handle_split_packets() -> anyhow::Result<()> {
// This will only take effect on new connections
new::<MySql>()
.await?
.execute("SET GLOBAL max_allowed_packet = 4294967297")
.await?;
let mut conn = new::<MySql>().await?;
conn.execute(
r#"
CREATE TEMPORARY TABLE large_table (data LONGBLOB);
"#,
)
.await?;
let data = vec![0x41; 0xFF_FF_FF * 2];
sqlx::query("INSERT INTO large_table (data) VALUES (?)")
.bind(&data)
.execute(&mut conn)
.await?;
let ret: Vec<u8> = sqlx::query_scalar("SELECT * FROM large_table")
.fetch_one(&mut conn)
.await?;
assert_eq!(ret, data);
Ok(())
}
#[sqlx_macros::test]
async fn test_shrink_buffers() -> anyhow::Result<()> {
// We don't really have a good way to test that `.shrink_buffers()` functions as expected