feat(mssql): handle stream flushing

This commit is contained in:
Ryan Leckey 2020-06-06 12:09:15 -07:00
parent c64122c03f
commit 434bfaa76a
9 changed files with 119 additions and 21 deletions

View file

@ -76,6 +76,9 @@ impl MsSqlConnection {
}
}
Ok(Self { stream })
Ok(Self {
stream,
pending_done_count: 0,
})
}
}

View file

@ -7,6 +7,7 @@ use futures_util::TryStreamExt;
use crate::describe::Describe;
use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::mssql::protocol::done::Done;
use crate::mssql::protocol::message::Message;
use crate::mssql::protocol::packet::PacketType;
use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest};
@ -14,7 +15,31 @@ use crate::mssql::protocol::sql_batch::SqlBatch;
use crate::mssql::{MsSql, MsSqlArguments, MsSqlConnection, MsSqlRow};
impl MsSqlConnection {
async fn wait_until_ready(&mut self) -> Result<(), Error> {
if !self.stream.wbuf.is_empty() {
self.stream.flush().await?;
}
while self.pending_done_count > 0 {
if let Message::DoneProc(done) | Message::Done(done) =
self.stream.recv_message().await?
{
// finished RPC procedure *OR* SQL batch
self.handle_done(done);
}
}
Ok(())
}
fn handle_done(&mut self, _: Done) {
self.pending_done_count -= 1;
}
async fn run(&mut self, query: &str, arguments: Option<MsSqlArguments>) -> Result<(), Error> {
self.wait_until_ready().await?;
self.pending_done_count += 1;
if let Some(mut arguments) = arguments {
let proc = Either::Right(Procedure::ExecuteSql);
let mut proc_args = MsSqlArguments::default();
@ -22,12 +47,14 @@ impl MsSqlConnection {
// SQL
proc_args.add_unnamed(query);
// Declarations
// NAME TYPE, NAME TYPE, ...
proc_args.add_unnamed(&*arguments.declarations);
if !arguments.data.is_empty() {
// Declarations
// NAME TYPE, NAME TYPE, ...
proc_args.add_unnamed(&*arguments.declarations);
// Add the list of SQL parameters _after_ our RPC parameters
proc_args.append(&mut arguments);
// Add the list of SQL parameters _after_ our RPC parameters
proc_args.append(&mut arguments);
}
self.stream.write_packet(
PacketType::Rpc,
@ -72,10 +99,19 @@ impl<'c> Executor<'c> for &'c mut MsSqlConnection {
yield v;
}
Message::Done(done) => {
Message::DoneProc(done) => {
self.handle_done(done);
break;
}
Message::DoneInProc(done) => {
// finished SQL query *within* procedure
let v = Either::Left(done.affected_rows);
yield v;
}
Message::Done(done) => {
self.handle_done(done);
break;
}

View file

@ -16,6 +16,9 @@ mod stream;
pub struct MsSqlConnection {
stream: MsSqlStream,
// number of Done* messages that we are currently expecting
pub(crate) pending_done_count: usize,
}
impl Debug for MsSqlConnection {

View file

@ -13,6 +13,7 @@ use crate::mssql::protocol::info::Info;
use crate::mssql::protocol::login_ack::LoginAck;
use crate::mssql::protocol::message::{Message, MessageType};
use crate::mssql::protocol::packet::{PacketHeader, PacketType, Status};
use crate::mssql::protocol::return_status::ReturnStatus;
use crate::mssql::protocol::row::Row;
use crate::mssql::{MsSqlConnectOptions, MsSqlDatabaseError};
use crate::net::MaybeTlsStream;
@ -106,13 +107,15 @@ impl MsSqlStream {
};
let ty = MessageType::get(buf)?;
return Ok(match ty {
let message = match ty {
MessageType::EnvChange => Message::EnvChange(EnvChange::get(buf)?),
MessageType::Info => Message::Info(Info::get(buf)?),
MessageType::Row => Message::Row(Row::get(buf, &self.columns)?),
MessageType::LoginAck => Message::LoginAck(LoginAck::get(buf)?),
MessageType::ReturnStatus => Message::ReturnStatus(ReturnStatus::get(buf)?),
MessageType::Done => Message::Done(Done::get(buf)?),
MessageType::DoneInProc => Message::DoneInProc(Done::get(buf)?),
MessageType::DoneProc => Message::DoneProc(Done::get(buf)?),
MessageType::Error => {
let err = ProtocolError::get(buf)?;
@ -125,7 +128,9 @@ impl MsSqlStream {
ColMetaData::get(buf, &mut self.columns)?;
continue;
}
});
};
return Ok(message);
}
// no packet from the server to iterate (or its empty); fill our buffer

View file

@ -3,21 +3,11 @@ use bytes::{Buf, Bytes};
use crate::error::Error;
// Token Stream Function:
// Indicates the completion status of a SQL statementwithin a stored procedure.
// Token Stream Definition:
// DONEINPROC =
// TokenType
// Status
// CurCmd
// DoneRowCount
#[derive(Debug)]
pub(crate) struct Done {
status: Status,
// The token of the current SQL statement. The token value is provided andcontrolled by the
// The token of the current SQL statement. The token value is provided and controlled by the
// application layer, which utilizes TDS. The TDS layer does not evaluate the value.
cursor_command: u16,

View file

@ -6,6 +6,7 @@ use crate::mssql::protocol::env_change::EnvChange;
use crate::mssql::protocol::error::Error;
use crate::mssql::protocol::info::Info;
use crate::mssql::protocol::login_ack::LoginAck;
use crate::mssql::protocol::return_status::ReturnStatus;
use crate::mssql::protocol::row::Row;
#[derive(Debug)]
@ -14,7 +15,10 @@ pub(crate) enum Message {
LoginAck(LoginAck),
EnvChange(EnvChange),
Done(Done),
DoneInProc(Done),
DoneProc(Done),
Row(Row),
ReturnStatus(ReturnStatus),
}
#[derive(Debug)]
@ -23,9 +27,12 @@ pub(crate) enum MessageType {
LoginAck,
EnvChange,
Done,
DoneProc,
DoneInProc,
Row,
Error,
ColMetaData,
ReturnStatus,
}
impl MessageType {
@ -37,7 +44,10 @@ impl MessageType {
0xad => MessageType::LoginAck,
0xd1 => MessageType::Row,
0xe3 => MessageType::EnvChange,
0x79 => MessageType::ReturnStatus,
0xfd => MessageType::Done,
0xfe => MessageType::DoneProc,
0xff => MessageType::DoneInProc,
ty => {
return Err(err_protocol!(

View file

@ -9,6 +9,7 @@ pub(crate) mod login_ack;
pub(crate) mod message;
pub(crate) mod packet;
pub(crate) mod pre_login;
pub(crate) mod return_status;
pub(crate) mod row;
pub(crate) mod rpc;
pub(crate) mod sql_batch;

View file

@ -0,0 +1,17 @@
use bitflags::bitflags;
use bytes::{Buf, Bytes};
use crate::error::Error;
#[derive(Debug)]
pub(crate) struct ReturnStatus {
value: i32,
}
impl ReturnStatus {
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, Error> {
let value = buf.get_i32_le();
Ok(Self { value })
}
}

View file

@ -1,3 +1,4 @@
use futures::TryStreamExt;
use sqlx::mssql::MsSql;
use sqlx::{Connection, Executor, Row};
use sqlx_core::mssql::MsSqlRow;
@ -40,3 +41,35 @@ async fn it_maths() -> anyhow::Result<()> {
Ok(())
}
#[sqlx_macros::test]
async fn it_executes() -> anyhow::Result<()> {
let mut conn = new::<MsSql>().await?;
let _ = conn
.execute(
r#"
CREATE TABLE #users (id INTEGER PRIMARY KEY);
"#,
)
.await?;
for index in 1..=10_i32 {
let cnt = sqlx::query("INSERT INTO #users (id) VALUES (@p1)")
.bind(index * 2)
.execute(&mut conn)
.await?;
assert_eq!(cnt, 1);
}
let sum: i32 = sqlx::query("SELECT id FROM #users")
.try_map(|row: MsSqlRow| row.try_get::<i32, _>(0))
.fetch(&mut conn)
.try_fold(0_i32, |acc, x| async move { Ok(acc + x) })
.await?;
assert_eq!(sum, 110);
Ok(())
}