mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 06:24:16 +00:00
feat(mysql): support packet splitting (#2665)
* Writing split packets * Reading split packets * Add tests for packet splitting * Fix test for packet splitting
This commit is contained in:
parent
b1387057e5
commit
5ebe296ecb
3 changed files with 86 additions and 11 deletions
|
@ -1,7 +1,7 @@
|
|||
use std::collections::VecDeque;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use bytes::{Buf, Bytes, BytesMut};
|
||||
|
||||
use crate::collation::{CharSet, Collation};
|
||||
use crate::error::Error;
|
||||
|
@ -126,9 +126,7 @@ impl<S: Socket> MySqlStream<S> {
|
|||
.write_with(Packet(payload), (self.capabilities, &mut self.sequence_id));
|
||||
}
|
||||
|
||||
// receive the next packet from the database server
|
||||
// may block (async) on more data from the server
|
||||
pub(crate) async fn recv_packet(&mut self) -> Result<Packet<Bytes>, Error> {
|
||||
async fn recv_packet_part(&mut self) -> Result<Bytes, Error> {
|
||||
// https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html
|
||||
// https://mariadb.com/kb/en/library/0-packet/#standard-packet
|
||||
|
||||
|
@ -142,10 +140,33 @@ impl<S: Socket> MySqlStream<S> {
|
|||
let payload: Bytes = self.socket.read(packet_size).await?;
|
||||
|
||||
// TODO: packet compression
|
||||
// TODO: packet joining
|
||||
|
||||
Ok(payload)
|
||||
}
|
||||
|
||||
// receive the next packet from the database server
|
||||
// may block (async) on more data from the server
|
||||
pub(crate) async fn recv_packet(&mut self) -> Result<Packet<Bytes>, Error> {
|
||||
let payload = self.recv_packet_part().await?;
|
||||
let payload = if payload.len() < 0xFF_FF_FF {
|
||||
payload
|
||||
} else {
|
||||
let mut final_payload = BytesMut::with_capacity(0xFF_FF_FF * 2);
|
||||
final_payload.extend_from_slice(&payload);
|
||||
|
||||
drop(payload); // we don't need the allocation anymore
|
||||
|
||||
let mut last_read = 0xFF_FF_FF;
|
||||
while last_read == 0xFF_FF_FF {
|
||||
let part = self.recv_packet_part().await?;
|
||||
last_read = part.len();
|
||||
final_payload.extend_from_slice(&part);
|
||||
}
|
||||
final_payload.into()
|
||||
};
|
||||
|
||||
if payload
|
||||
.get(0)
|
||||
.first()
|
||||
.ok_or(err_protocol!("Packet empty"))?
|
||||
.eq(&0xff)
|
||||
{
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use std::cmp::min;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use bytes::Bytes;
|
||||
|
@ -19,6 +20,14 @@ where
|
|||
buf: &mut Vec<u8>,
|
||||
(capabilities, sequence_id): (Capabilities, &'stream mut u8),
|
||||
) {
|
||||
let mut next_header = |len: u32| {
|
||||
let mut buf = len.to_le_bytes();
|
||||
buf[3] = *sequence_id;
|
||||
*sequence_id = sequence_id.wrapping_add(1);
|
||||
|
||||
buf
|
||||
};
|
||||
|
||||
// reserve space to write the prefixed length
|
||||
let offset = buf.len();
|
||||
buf.extend(&[0_u8; 4]);
|
||||
|
@ -31,13 +40,25 @@ where
|
|||
let len = buf.len() - offset - 4;
|
||||
let header = &mut buf[offset..];
|
||||
|
||||
// FIXME: Support larger packets
|
||||
assert!(len < 0xFF_FF_FF);
|
||||
header[..4].copy_from_slice(&next_header(min(len, 0xFF_FF_FF) as u32));
|
||||
|
||||
header[..4].copy_from_slice(&(len as u32).to_le_bytes());
|
||||
header[3] = *sequence_id;
|
||||
// add more packets if we need to split the data
|
||||
if len >= 0xFF_FF_FF {
|
||||
let rest = buf.split_off(offset + 4 + 0xFF_FF_FF);
|
||||
let mut chunks = rest.chunks_exact(0xFF_FF_FF);
|
||||
|
||||
*sequence_id = sequence_id.wrapping_add(1);
|
||||
for chunk in chunks.by_ref() {
|
||||
buf.reserve(chunk.len() + 4);
|
||||
buf.extend(&next_header(chunk.len() as u32));
|
||||
buf.extend(chunk);
|
||||
}
|
||||
|
||||
// this will also handle adding a zero sized packet if the data size is a multiple of 0xFF_FF_FF
|
||||
let remainder = chunks.remainder();
|
||||
buf.reserve(remainder.len() + 4);
|
||||
buf.extend(&next_header(remainder.len() as u32));
|
||||
buf.extend(remainder);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -447,6 +447,39 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn it_can_handle_split_packets() -> anyhow::Result<()> {
|
||||
// This will only take effect on new connections
|
||||
new::<MySql>()
|
||||
.await?
|
||||
.execute("SET GLOBAL max_allowed_packet = 4294967297")
|
||||
.await?;
|
||||
|
||||
let mut conn = new::<MySql>().await?;
|
||||
|
||||
conn.execute(
|
||||
r#"
|
||||
CREATE TEMPORARY TABLE large_table (data LONGBLOB);
|
||||
"#,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let data = vec![0x41; 0xFF_FF_FF * 2];
|
||||
|
||||
sqlx::query("INSERT INTO large_table (data) VALUES (?)")
|
||||
.bind(&data)
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
let ret: Vec<u8> = sqlx::query_scalar("SELECT * FROM large_table")
|
||||
.fetch_one(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert_eq!(ret, data);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn test_shrink_buffers() -> anyhow::Result<()> {
|
||||
// We don't really have a good way to test that `.shrink_buffers()` functions as expected
|
||||
|
|
Loading…
Reference in a new issue