mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 14:34:19 +00:00
feat(mssql): handle stream flushing
This commit is contained in:
parent
c64122c03f
commit
434bfaa76a
9 changed files with 119 additions and 21 deletions
|
@ -76,6 +76,9 @@ impl MsSqlConnection {
|
|||
}
|
||||
}
|
||||
|
||||
Ok(Self { stream })
|
||||
Ok(Self {
|
||||
stream,
|
||||
pending_done_count: 0,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ use futures_util::TryStreamExt;
|
|||
use crate::describe::Describe;
|
||||
use crate::error::Error;
|
||||
use crate::executor::{Execute, Executor};
|
||||
use crate::mssql::protocol::done::Done;
|
||||
use crate::mssql::protocol::message::Message;
|
||||
use crate::mssql::protocol::packet::PacketType;
|
||||
use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest};
|
||||
|
@ -14,7 +15,31 @@ use crate::mssql::protocol::sql_batch::SqlBatch;
|
|||
use crate::mssql::{MsSql, MsSqlArguments, MsSqlConnection, MsSqlRow};
|
||||
|
||||
impl MsSqlConnection {
|
||||
async fn wait_until_ready(&mut self) -> Result<(), Error> {
|
||||
if !self.stream.wbuf.is_empty() {
|
||||
self.stream.flush().await?;
|
||||
}
|
||||
|
||||
while self.pending_done_count > 0 {
|
||||
if let Message::DoneProc(done) | Message::Done(done) =
|
||||
self.stream.recv_message().await?
|
||||
{
|
||||
// finished RPC procedure *OR* SQL batch
|
||||
self.handle_done(done);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_done(&mut self, _: Done) {
|
||||
self.pending_done_count -= 1;
|
||||
}
|
||||
|
||||
async fn run(&mut self, query: &str, arguments: Option<MsSqlArguments>) -> Result<(), Error> {
|
||||
self.wait_until_ready().await?;
|
||||
self.pending_done_count += 1;
|
||||
|
||||
if let Some(mut arguments) = arguments {
|
||||
let proc = Either::Right(Procedure::ExecuteSql);
|
||||
let mut proc_args = MsSqlArguments::default();
|
||||
|
@ -22,12 +47,14 @@ impl MsSqlConnection {
|
|||
// SQL
|
||||
proc_args.add_unnamed(query);
|
||||
|
||||
// Declarations
|
||||
// NAME TYPE, NAME TYPE, ...
|
||||
proc_args.add_unnamed(&*arguments.declarations);
|
||||
if !arguments.data.is_empty() {
|
||||
// Declarations
|
||||
// NAME TYPE, NAME TYPE, ...
|
||||
proc_args.add_unnamed(&*arguments.declarations);
|
||||
|
||||
// Add the list of SQL parameters _after_ our RPC parameters
|
||||
proc_args.append(&mut arguments);
|
||||
// Add the list of SQL parameters _after_ our RPC parameters
|
||||
proc_args.append(&mut arguments);
|
||||
}
|
||||
|
||||
self.stream.write_packet(
|
||||
PacketType::Rpc,
|
||||
|
@ -72,10 +99,19 @@ impl<'c> Executor<'c> for &'c mut MsSqlConnection {
|
|||
yield v;
|
||||
}
|
||||
|
||||
Message::Done(done) => {
|
||||
Message::DoneProc(done) => {
|
||||
self.handle_done(done);
|
||||
break;
|
||||
}
|
||||
|
||||
Message::DoneInProc(done) => {
|
||||
// finished SQL query *within* procedure
|
||||
let v = Either::Left(done.affected_rows);
|
||||
yield v;
|
||||
}
|
||||
|
||||
Message::Done(done) => {
|
||||
self.handle_done(done);
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
|
@ -16,6 +16,9 @@ mod stream;
|
|||
|
||||
pub struct MsSqlConnection {
|
||||
stream: MsSqlStream,
|
||||
|
||||
// number of Done* messages that we are currently expecting
|
||||
pub(crate) pending_done_count: usize,
|
||||
}
|
||||
|
||||
impl Debug for MsSqlConnection {
|
||||
|
|
|
@ -13,6 +13,7 @@ use crate::mssql::protocol::info::Info;
|
|||
use crate::mssql::protocol::login_ack::LoginAck;
|
||||
use crate::mssql::protocol::message::{Message, MessageType};
|
||||
use crate::mssql::protocol::packet::{PacketHeader, PacketType, Status};
|
||||
use crate::mssql::protocol::return_status::ReturnStatus;
|
||||
use crate::mssql::protocol::row::Row;
|
||||
use crate::mssql::{MsSqlConnectOptions, MsSqlDatabaseError};
|
||||
use crate::net::MaybeTlsStream;
|
||||
|
@ -106,13 +107,15 @@ impl MsSqlStream {
|
|||
};
|
||||
|
||||
let ty = MessageType::get(buf)?;
|
||||
|
||||
return Ok(match ty {
|
||||
let message = match ty {
|
||||
MessageType::EnvChange => Message::EnvChange(EnvChange::get(buf)?),
|
||||
MessageType::Info => Message::Info(Info::get(buf)?),
|
||||
MessageType::Row => Message::Row(Row::get(buf, &self.columns)?),
|
||||
MessageType::LoginAck => Message::LoginAck(LoginAck::get(buf)?),
|
||||
MessageType::ReturnStatus => Message::ReturnStatus(ReturnStatus::get(buf)?),
|
||||
MessageType::Done => Message::Done(Done::get(buf)?),
|
||||
MessageType::DoneInProc => Message::DoneInProc(Done::get(buf)?),
|
||||
MessageType::DoneProc => Message::DoneProc(Done::get(buf)?),
|
||||
|
||||
MessageType::Error => {
|
||||
let err = ProtocolError::get(buf)?;
|
||||
|
@ -125,7 +128,9 @@ impl MsSqlStream {
|
|||
ColMetaData::get(buf, &mut self.columns)?;
|
||||
continue;
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
return Ok(message);
|
||||
}
|
||||
|
||||
// no packet from the server to iterate (or its empty); fill our buffer
|
||||
|
|
|
@ -3,21 +3,11 @@ use bytes::{Buf, Bytes};
|
|||
|
||||
use crate::error::Error;
|
||||
|
||||
// Token Stream Function:
|
||||
// Indicates the completion status of a SQL statementwithin a stored procedure.
|
||||
|
||||
// Token Stream Definition:
|
||||
// DONEINPROC =
|
||||
// TokenType
|
||||
// Status
|
||||
// CurCmd
|
||||
// DoneRowCount
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Done {
|
||||
status: Status,
|
||||
|
||||
// The token of the current SQL statement. The token value is provided andcontrolled by the
|
||||
// The token of the current SQL statement. The token value is provided and controlled by the
|
||||
// application layer, which utilizes TDS. The TDS layer does not evaluate the value.
|
||||
cursor_command: u16,
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ use crate::mssql::protocol::env_change::EnvChange;
|
|||
use crate::mssql::protocol::error::Error;
|
||||
use crate::mssql::protocol::info::Info;
|
||||
use crate::mssql::protocol::login_ack::LoginAck;
|
||||
use crate::mssql::protocol::return_status::ReturnStatus;
|
||||
use crate::mssql::protocol::row::Row;
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -14,7 +15,10 @@ pub(crate) enum Message {
|
|||
LoginAck(LoginAck),
|
||||
EnvChange(EnvChange),
|
||||
Done(Done),
|
||||
DoneInProc(Done),
|
||||
DoneProc(Done),
|
||||
Row(Row),
|
||||
ReturnStatus(ReturnStatus),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -23,9 +27,12 @@ pub(crate) enum MessageType {
|
|||
LoginAck,
|
||||
EnvChange,
|
||||
Done,
|
||||
DoneProc,
|
||||
DoneInProc,
|
||||
Row,
|
||||
Error,
|
||||
ColMetaData,
|
||||
ReturnStatus,
|
||||
}
|
||||
|
||||
impl MessageType {
|
||||
|
@ -37,7 +44,10 @@ impl MessageType {
|
|||
0xad => MessageType::LoginAck,
|
||||
0xd1 => MessageType::Row,
|
||||
0xe3 => MessageType::EnvChange,
|
||||
0x79 => MessageType::ReturnStatus,
|
||||
0xfd => MessageType::Done,
|
||||
0xfe => MessageType::DoneProc,
|
||||
0xff => MessageType::DoneInProc,
|
||||
|
||||
ty => {
|
||||
return Err(err_protocol!(
|
||||
|
|
|
@ -9,6 +9,7 @@ pub(crate) mod login_ack;
|
|||
pub(crate) mod message;
|
||||
pub(crate) mod packet;
|
||||
pub(crate) mod pre_login;
|
||||
pub(crate) mod return_status;
|
||||
pub(crate) mod row;
|
||||
pub(crate) mod rpc;
|
||||
pub(crate) mod sql_batch;
|
||||
|
|
17
sqlx-core/src/mssql/protocol/return_status.rs
Normal file
17
sqlx-core/src/mssql/protocol/return_status.rs
Normal file
|
@ -0,0 +1,17 @@
|
|||
use bitflags::bitflags;
|
||||
use bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct ReturnStatus {
|
||||
value: i32,
|
||||
}
|
||||
|
||||
impl ReturnStatus {
|
||||
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, Error> {
|
||||
let value = buf.get_i32_le();
|
||||
|
||||
Ok(Self { value })
|
||||
}
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
use futures::TryStreamExt;
|
||||
use sqlx::mssql::MsSql;
|
||||
use sqlx::{Connection, Executor, Row};
|
||||
use sqlx_core::mssql::MsSqlRow;
|
||||
|
@ -40,3 +41,35 @@ async fn it_maths() -> anyhow::Result<()> {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[sqlx_macros::test]
|
||||
async fn it_executes() -> anyhow::Result<()> {
|
||||
let mut conn = new::<MsSql>().await?;
|
||||
|
||||
let _ = conn
|
||||
.execute(
|
||||
r#"
|
||||
CREATE TABLE #users (id INTEGER PRIMARY KEY);
|
||||
"#,
|
||||
)
|
||||
.await?;
|
||||
|
||||
for index in 1..=10_i32 {
|
||||
let cnt = sqlx::query("INSERT INTO #users (id) VALUES (@p1)")
|
||||
.bind(index * 2)
|
||||
.execute(&mut conn)
|
||||
.await?;
|
||||
|
||||
assert_eq!(cnt, 1);
|
||||
}
|
||||
|
||||
let sum: i32 = sqlx::query("SELECT id FROM #users")
|
||||
.try_map(|row: MsSqlRow| row.try_get::<i32, _>(0))
|
||||
.fetch(&mut conn)
|
||||
.try_fold(0_i32, |acc, x| async move { Ok(acc + x) })
|
||||
.await?;
|
||||
|
||||
assert_eq!(sum, 110);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue