feat(mssql): implement enough to get simple queries working

Co-authored-by: Daniel Akhterov <akhterovd@gmail.com>
This commit is contained in:
Ryan Leckey 2020-06-02 20:47:57 -07:00
parent 05eb07e7d4
commit 9a701313bc
33 changed files with 1772 additions and 33 deletions

View file

@ -60,6 +60,7 @@ runtime-tokio = [ "sqlx-core/runtime-tokio", "sqlx-macros/runtime-tokio" ]
postgres = [ "sqlx-core/postgres", "sqlx-macros/postgres" ]
mysql = [ "sqlx-core/mysql", "sqlx-macros/mysql" ]
sqlite = [ "sqlx-core/sqlite", "sqlx-macros/sqlite" ]
mssql = [ "sqlx-core/mssql" ]
# types
bigdecimal = ["sqlx-core/bigdecimal", "sqlx-macros/bigdecimal"]
@ -144,3 +145,12 @@ required-features = [ "postgres" ]
name = "postgres-describe"
path = "tests/postgres/describe.rs"
required-features = [ "postgres" ]
#
# Microsoft SQL Server (MSSQL)
#
[[test]]
name = "mssql"
path = "tests/mssql/mssql.rs"
required-features = [ "mssql" ]

View file

@ -19,7 +19,7 @@ default = [ "runtime-async-std" ]
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel/sink", "futures-util/sink" ]
mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ]
sqlite = [ "libsqlite3-sys" ]
mssql = [ ]
mssql = [ "uuid" ]
# types
all-types = [ "chrono", "time", "bigdecimal", "ipnetwork", "json", "uuid" ]

View file

@ -0,0 +1,91 @@
use crate::error::Error;
use crate::io::Decode;
use crate::mssql::connection::stream::MsSqlStream;
use crate::mssql::protocol::login::Login7;
use crate::mssql::protocol::login_ack::LoginAck;
use crate::mssql::protocol::message::Message;
use crate::mssql::protocol::packet::PacketType;
use crate::mssql::protocol::pre_login::{Encrypt, PreLogin, Version};
use crate::mssql::{MsSqlConnectOptions, MsSqlConnection};
impl MsSqlConnection {
pub(crate) async fn establish(options: &MsSqlConnectOptions) -> Result<Self, Error> {
let mut stream: MsSqlStream = MsSqlStream::connect(options).await?;
// Send PRELOGIN to set up the context for login. The server should immediately
// respond with a PRELOGIN message of its own.
// TODO: Encryption
// TODO: Send the version of SQLx over
stream.write_packet(
PacketType::PreLogin,
PreLogin {
version: Version::default(),
encryption: Encrypt::NOT_SUPPORTED,
..Default::default()
},
);
stream.flush().await?;
let (_, packet) = stream.recv_packet().await?;
let pl = PreLogin::decode(packet)?;
log::trace!(
"acknowledged PRELOGIN from MSSQL v{}.{}.{}",
pl.version.major,
pl.version.minor,
pl.version.build
);
// LOGIN7 defines the authentication rules for use between client and server
stream.write_packet(
PacketType::Tds7Login,
Login7 {
// FIXME: use a version constant
version: 0x74000004, // SQL Server 2012 - SQL Server 2019
client_program_version: 0,
client_pid: 0,
packet_size: 4096,
hostname: "",
username: &options.username,
password: options.password.as_deref().unwrap_or_default(),
app_name: "",
server_name: "",
client_interface_name: "",
language: "",
// FIXME: connect this to options.database
database: "",
client_id: [0; 6],
},
);
stream.flush().await?;
loop {
// NOTE: we should receive an [Error] message if something goes wrong, otherwise,
// all messages are mostly informational (ENVCHANGE, INFO, LOGINACK)
match stream.recv_message().await? {
Message::LoginAck(ack) => {
log::trace!(
"established connection to {} {}",
ack.program_name,
ack.program_version
);
}
Message::Done(_) => {
break;
}
_ => {}
}
}
Ok(Self { stream })
}
}

View file

@ -1,24 +1,63 @@
use async_stream::try_stream;
use either::Either;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_util::TryStreamExt;
use crate::describe::Describe;
use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::mssql::protocol::message::Message;
use crate::mssql::protocol::packet::PacketType;
use crate::mssql::protocol::sql_batch::SqlBatch;
use crate::mssql::{MsSql, MsSqlConnection, MsSqlRow};
impl MsSqlConnection {
async fn run(&mut self, query: &str) -> Result<(), Error> {
self.stream
.write_packet(PacketType::SqlBatch, SqlBatch { sql: query });
self.stream.flush().await?;
Ok(())
}
}
impl<'c> Executor<'c> for &'c mut MsSqlConnection {
type Database = MsSql;
fn fetch_many<'e, 'q: 'e, E: 'q>(
self,
query: E,
mut query: E,
) -> BoxStream<'e, Result<Either<u64, MsSqlRow>, Error>>
where
'c: 'e,
E: Execute<'q, Self::Database>,
{
unimplemented!()
let s = query.query();
// TODO: let arguments = query.take_arguments();
Box::pin(try_stream! {
self.run(s).await?;
loop {
match self.stream.recv_message().await? {
Message::Row(row) => {
let v = Either::Right(MsSqlRow { row });
yield v;
}
Message::Done(done) => {
let v = Either::Left(done.affected_rows);
yield v;
break;
}
_ => {}
}
}
})
}
fn fetch_optional<'e, 'q: 'e, E: 'q>(
@ -29,7 +68,17 @@ impl<'c> Executor<'c> for &'c mut MsSqlConnection {
'c: 'e,
E: Execute<'q, Self::Database>,
{
unimplemented!()
let mut s = self.fetch_many(query);
Box::pin(async move {
while let Some(v) = s.try_next().await? {
if let Either::Right(r) = v {
return Ok(Some(r));
}
}
Ok(None)
})
}
fn describe<'e, 'q: 'e, E: 'q>(

View file

@ -4,11 +4,16 @@ use futures_core::future::BoxFuture;
use crate::connection::{Connect, Connection};
use crate::error::{BoxDynError, Error};
use crate::mssql::connection::stream::MsSqlStream;
use crate::mssql::{MsSql, MsSqlConnectOptions};
mod establish;
mod executor;
mod stream;
pub struct MsSqlConnection {}
pub struct MsSqlConnection {
stream: MsSqlStream,
}
impl Debug for MsSqlConnection {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
@ -44,6 +49,6 @@ impl Connect for MsSqlConnection {
type Options = MsSqlConnectOptions;
fn connect_with(options: &Self::Options) -> BoxFuture<'_, Result<Self, Error>> {
unimplemented!()
Box::pin(MsSqlConnection::establish(options))
}
}

View file

@ -0,0 +1,141 @@
use std::ops::{Deref, DerefMut};
use bytes::Bytes;
use sqlx_rt::{TcpStream, TlsStream};
use crate::error::Error;
use crate::io::{BufStream, Encode};
use crate::mssql::protocol::col_meta_data::{ColMetaData, ColumnData};
use crate::mssql::protocol::done::Done;
use crate::mssql::protocol::env_change::EnvChange;
use crate::mssql::protocol::info::Info;
use crate::mssql::protocol::login_ack::LoginAck;
use crate::mssql::protocol::message::{Message, MessageType};
use crate::mssql::protocol::packet::{PacketHeader, PacketType, Status};
use crate::mssql::protocol::row::Row;
use crate::mssql::MsSqlConnectOptions;
use crate::net::MaybeTlsStream;
pub(crate) struct MsSqlStream {
inner: BufStream<MaybeTlsStream<TcpStream>>,
// current TabularResult from the server that we are iterating over
response: Option<(PacketHeader, Bytes)>,
// most recent column data from ColMetaData
// we need to store this as its needed when decoding <Row>
columns: Vec<ColumnData>,
}
impl MsSqlStream {
pub(super) async fn connect(options: &MsSqlConnectOptions) -> Result<Self, Error> {
let inner = BufStream::new(MaybeTlsStream::Raw(
TcpStream::connect((&*options.host, options.port)).await?,
));
Ok(Self {
inner,
columns: Vec::new(),
response: None,
})
}
// writes the packet out to the write buffer
// will (eventually) handle packet chunking
pub(super) fn write_packet<'en, T: Encode<'en>>(&mut self, ty: PacketType, payload: T) {
// TODO: Support packet chunking for large packet sizes
// We likely need to double-buffer the writes so we know to chunk
// write out the packet header, leaving room for setting the packet length later
let mut len_offset = 0;
self.inner.write_with(
PacketHeader {
r#type: ty,
status: Status::END_OF_MESSAGE,
length: 0,
server_process_id: 0,
packet_id: 1,
},
&mut len_offset,
);
// write out the payload
self.inner.write(payload);
// overwrite the packet length now that we know it
let len = self.inner.wbuf.len();
self.inner.wbuf[len_offset..(len_offset + 2)].copy_from_slice(&(len as u16).to_be_bytes());
}
// receive the next packet from the database
// blocks until a packet is available
pub(super) async fn recv_packet(&mut self) -> Result<(PacketHeader, Bytes), Error> {
// TODO: Support packet chunking for large packet sizes
let header: PacketHeader = self.inner.read(8).await?;
// NOTE: From what I can tell, the response type from the server should ~always~
// be TabularResult. Here we expect that and die otherwise.
if !matches!(header.r#type, PacketType::TabularResult) {
return Err(err_protocol!(
"received unexpected packet: {:?}",
header.r#type
));
}
let payload_len = (header.length - 8) as usize;
let payload: Bytes = self.inner.read(payload_len).await?;
Ok((header, payload))
}
// receive the next ~message~
// TDS communicates in streams of packets that are themselves streams of messages
pub(super) async fn recv_message(&mut self) -> Result<Message, Error> {
loop {
while self.response.as_ref().map_or(false, |r| !r.1.is_empty()) {
let mut buf = if let Some((_, buf)) = self.response.as_mut() {
buf
} else {
// this shouldn't be reachable but just nope out
// and head to refill our buffer
break;
};
return Ok(match MessageType::get(buf)? {
MessageType::EnvChange => Message::EnvChange(EnvChange::get(buf)?),
MessageType::Info => Message::Info(Info::get(buf)?),
MessageType::Row => Message::Row(Row::get(buf, &self.columns)?),
MessageType::LoginAck => Message::LoginAck(LoginAck::get(buf)?),
MessageType::Done => Message::Done(Done::get(buf)?),
MessageType::ColMetaData => {
// NOTE: there isn't anything to return as the data gets
// consumed by the stream for use in subsequent Row decoding
ColMetaData::get(buf, &mut self.columns)?;
continue;
}
});
}
// no packet from the server to iterate (or its empty); fill our buffer
self.response = Some(self.recv_packet().await?);
}
}
}
impl Deref for MsSqlStream {
type Target = BufStream<MaybeTlsStream<TcpStream>>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DerefMut for MsSqlStream {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}

View file

@ -0,0 +1,43 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::BufExt;
pub trait MsSqlBufExt: Buf {
fn get_utf16_str(&mut self, n: usize) -> Result<String, Error>;
fn get_b_varchar(&mut self) -> Result<String, Error>;
fn get_us_varchar(&mut self) -> Result<String, Error>;
fn get_b_varbyte(&mut self) -> Bytes;
}
impl MsSqlBufExt for Bytes {
fn get_utf16_str(&mut self, mut n: usize) -> Result<String, Error> {
let mut raw = Vec::with_capacity(n * 2);
while n > 0 {
let ch = self.get_u16_le();
raw.push(ch);
n -= 1;
}
String::from_utf16(&raw).map_err(Error::protocol)
}
fn get_b_varchar(&mut self) -> Result<String, Error> {
let size = self.get_u8();
self.get_utf16_str(size as usize)
}
fn get_us_varchar(&mut self) -> Result<String, Error> {
let size = self.get_u16_le();
self.get_utf16_str(size as usize)
}
fn get_b_varbyte(&mut self) -> Bytes {
let size = self.get_u8();
self.get_bytes(size as usize)
}
}

View file

@ -0,0 +1,12 @@
pub trait MsSqlBufMutExt {
fn put_utf16_str(&mut self, s: &str);
}
impl MsSqlBufMutExt for Vec<u8> {
fn put_utf16_str(&mut self, s: &str) {
let mut enc = s.encode_utf16();
while let Some(ch) = enc.next() {
self.extend_from_slice(&ch.to_le_bytes());
}
}
}

View file

@ -0,0 +1,5 @@
mod buf;
mod buf_mut;
pub(crate) use buf::MsSqlBufExt;
pub(crate) use buf_mut::MsSqlBufMutExt;

View file

@ -4,10 +4,13 @@ mod arguments;
mod connection;
mod database;
mod error;
mod io;
mod options;
mod protocol;
mod row;
mod transaction;
mod type_info;
pub mod types;
mod value;
pub use arguments::MsSqlArguments;

View file

@ -1,14 +1,78 @@
use std::str::FromStr;
use url::Url;
use crate::error::BoxDynError;
#[derive(Debug, Clone)]
pub struct MsSqlConnectOptions {}
pub struct MsSqlConnectOptions {
pub(crate) host: String,
pub(crate) port: u16,
pub(crate) username: String,
pub(crate) password: Option<String>,
}
impl Default for MsSqlConnectOptions {
fn default() -> Self {
Self::new()
}
}
impl MsSqlConnectOptions {
pub fn new() -> Self {
Self {
port: 1433,
host: String::from("localhost"),
username: String::from("sa"),
password: None,
}
}
pub fn host(mut self, host: &str) -> Self {
self.host = host.to_owned();
self
}
pub fn port(mut self, port: u16) -> Self {
self.port = port;
self
}
pub fn username(mut self, username: &str) -> Self {
self.username = username.to_owned();
self
}
pub fn password(mut self, password: &str) -> Self {
self.password = Some(password.to_owned());
self
}
}
impl FromStr for MsSqlConnectOptions {
type Err = BoxDynError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
unimplemented!()
let url: Url = s.parse()?;
let mut options = Self::new();
if let Some(host) = url.host_str() {
options = options.host(host);
}
if let Some(port) = url.port() {
options = options.port(port);
}
let username = url.username();
if !username.is_empty() {
options = options.username(username);
}
if let Some(password) = url.password() {
options = options.password(password);
}
Ok(options)
}
}

View file

@ -0,0 +1,117 @@
use bitflags::bitflags;
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::mssql::io::MsSqlBufExt;
use crate::mssql::protocol::type_info::TypeInfo;
#[derive(Debug)]
pub(crate) struct ColMetaData;
#[derive(Debug)]
pub(crate) struct ColumnData {
// The user type ID of the data type of the column. Depending on the TDS version that is used,
// valid values are 0x0000 or 0x00000000, with the exceptions of data type
// TIMESTAMP (0x0050 or 0x00000050) and alias types (greater than 0x00FF or 0x000000FF).
pub(crate) user_type: u32,
pub(crate) flags: Flags,
pub(crate) type_info: TypeInfo,
// TODO: pub(crate) table_name: Option<Vec<String>>,
// TODO: crypto_meta_data: Option<CryptoMetaData>,
// The column name. It contains the column name length and column name.
pub(crate) col_name: String,
}
bitflags! {
pub struct Flags: u16 {
// Its value is 1 if the column is nullable.
const NULLABLE = 0x0001;
// Set to 1 for string columns with binary collation and always for the XML data type.
// Set to 0 otherwise.
const CASE_SEN = 0x0002;
// usUpdateable is a 2-bit field. Its value is 0 if column is read-only, 1 if column is
// read/write and2 if updateable is unknown.
const UPDATEABLE1 = 0x0004;
const UPDATEABLE2 = 0x0008;
// Its value is 1 if the column is an identity column.
const IDENITTY = 0x0010;
// Its value is 1 if the column is a COMPUTED column.
const COMPUTED = 0x0020;
// Its value is 1 if the column is a fixed-length common language runtime
// user-defined type (CLR UDT).
const FIXED_LEN_CLR_TYPE = 0x0100;
// fSparseColumnSet, introduced in TDSversion 7.3.B, is a bit flag. Its value is 1 if the
// column is the special XML column for the sparse column set. For information about using
// column sets, see [MSDN-ColSets]
const SPARSE_COLUMN_SET = 0x0200;
// Its value is 1 if the column is encrypted transparently and
// has to be decrypted to view the plaintext value. This flag is valid when the column
// encryption feature is negotiated between client and server and is turned on.
const ENCRYPTED = 0x0400;
// Its value is 1 if the column is part of a hidden primary key created to support a
// T-SQL SELECT statement containing FOR BROWSE.
const HIDDEN = 0x0800;
// Its value is 1 if the column is part of a primary key for the row
// and the T-SQL SELECT statement contains FOR BROWSE.
const KEY = 0x1000;
// Its value is 1 if it is unknown whether the column might be nullable.
const NULLABLE_UNKNOWN = 0x2000;
}
}
impl ColMetaData {
pub(crate) fn get(buf: &mut Bytes, columns: &mut Vec<ColumnData>) -> Result<(), Error> {
columns.clear();
let mut count = buf.get_u16_le();
if count == 0xffff {
// In the event that the client requested no metadata to be returned, the value of
// Count will be 0xFFFF. This has the same effect on Count as a
// zero value (for example, no ColumnData is sent).
count = 0;
} else {
columns.reserve(count as usize);
}
while count > 0 {
columns.push(ColumnData::get(buf)?);
count -= 1;
}
Ok(())
}
}
impl ColumnData {
fn get(buf: &mut Bytes) -> Result<Self, Error> {
let user_type = buf.get_u32_le();
let flags = Flags::from_bits_truncate(buf.get_u16_le());
let type_info = TypeInfo::get(buf)?;
// TODO: table_name
// TODO: crypto_meta_data
let name = buf.get_b_varchar()?;
Ok(Self {
user_type,
flags,
type_info,
col_name: name,
})
}
}

View file

@ -0,0 +1,64 @@
use bitflags::bitflags;
use bytes::{Buf, Bytes};
use crate::error::Error;
// Token Stream Function:
// Indicates the completion status of a SQL statementwithin a stored procedure.
// Token Stream Definition:
// DONEINPROC =
// TokenType
// Status
// CurCmd
// DoneRowCount
#[derive(Debug)]
pub(crate) struct Done {
status: Status,
// The token of the current SQL statement. The token value is provided andcontrolled by the
// application layer, which utilizes TDS. The TDS layer does not evaluate the value.
cursor_command: u16,
// The count of rows that were affected by the SQL statement. The value of DoneRowCount is
// valid if the value of Status includes DONE_COUNT.
pub(crate) affected_rows: u64, // NOTE: u32 before TDS 7.2
}
impl Done {
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, Error> {
let status = Status::from_bits_truncate(buf.get_u16_le());
let cursor_command = buf.get_u16_le();
let affected_rows = buf.get_u64_le();
Ok(Self {
affected_rows,
status,
cursor_command,
})
}
}
bitflags! {
pub struct Status: u16 {
// This DONEINPROC message is not the final DONE/DONEPROC/DONEINPROC message in
// the response; more data streams are to follow.
const DONE_MORE = 0x0001;
// An error occurred on the current SQL statement or execution of a stored procedure was
// interrupted. A preceding ERROR token SHOULD be sent when this bit is set.
const DONE_ERROR = 0x0002;
// A transaction is in progress.
const DONE_INXACT = 0x0004;
// The DoneRowCount value is valid. This is used to distinguish between a valid value of 0
// for DoneRowCount or just an initialized variable.
const DONE_COUNT = 0x0010;
// Used in place of DONE_ERROR when an error occurred on the current SQL statement that is
// severe enough to require the result set, if any, to be discarded.
const DONE_SRVERROR = 0x0100;
}
}

View file

@ -0,0 +1,54 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::Decode;
use crate::mssql::io::MsSqlBufExt;
#[derive(Debug)]
pub(crate) enum EnvChange {
Database(String),
Language(String),
CharacterSet(String),
PacketSize(String),
UnicodeDataSortingLocalId(String),
UnicodeDataSortingComparisonFlags(String),
SqlCollation(Bytes),
// TDS 7.2+
BeginTransaction,
CommitTransaction,
RollbackTransaction,
EnlistDtcTransaction,
DefectTransaction,
RealTimeLogShipping,
PromoteTransaction,
TransactionManagerAddress,
TransactionEnded,
ResetConnectionCompletionAck,
LoginRequestUserNameAck,
// TDS 7.4+
RoutingInformation,
}
impl EnvChange {
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, Error> {
let len = buf.get_u16_le();
let ty = buf.get_u8();
let mut data = buf.split_to((len - 1) as usize);
Ok(match ty {
1 => EnvChange::Database(data.get_b_varchar()?),
2 => EnvChange::Language(data.get_b_varchar()?),
3 => EnvChange::CharacterSet(data.get_b_varchar()?),
4 => EnvChange::PacketSize(data.get_b_varchar()?),
5 => EnvChange::UnicodeDataSortingLocalId(data.get_b_varchar()?),
6 => EnvChange::UnicodeDataSortingComparisonFlags(data.get_b_varchar()?),
7 => EnvChange::SqlCollation(data.get_b_varbyte()),
_ => {
return Err(err_protocol!("unexpected value {} for ENVCHANGE Type", ty));
}
})
}
}

View file

@ -0,0 +1,40 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::mssql::io::MsSqlBufExt;
#[derive(Debug)]
pub(crate) struct Info {
pub(crate) number: u32,
pub(crate) state: u8,
pub(crate) class: u8,
pub(crate) message: String,
pub(crate) server: String,
pub(crate) procedure: String,
pub(crate) line: u32,
}
impl Info {
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, Error> {
let len = buf.get_u16_le();
let mut data = buf.split_to(len as usize);
let number = data.get_u32_le();
let state = data.get_u8();
let class = data.get_u8();
let message = data.get_us_varchar()?;
let server = data.get_b_varchar()?;
let procedure = data.get_b_varchar()?;
let line = data.get_u32_le();
Ok(Self {
number,
state,
class,
message,
server,
procedure,
line,
})
}
}

View file

@ -0,0 +1,267 @@
use hex::encode;
use std::mem::size_of;
use crate::io::Encode;
use crate::mssql::io::MsSqlBufMutExt;
// Stream definition
// LOGIN7 = Length
// TDSVersion
// PacketSize
// ClientProgVer
// ClientPID
// ConnectionID
// OptionFlags1
// OptionFlags2
// TypeFlags
// OptionFlags3
// ClientTimeZone
// ClientLCID
// OffsetLength
// Data
// FeatureExt
#[derive(Debug)]
pub struct Login7<'a> {
pub version: u32,
pub packet_size: u32,
pub client_program_version: u32,
pub client_pid: u32,
pub hostname: &'a str,
pub username: &'a str,
pub password: &'a str,
pub app_name: &'a str,
pub server_name: &'a str,
pub client_interface_name: &'a str,
pub language: &'a str,
pub database: &'a str,
pub client_id: [u8; 6],
}
impl Encode<'_> for Login7<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
// [Length] The total length of the LOGIN7 structure.
let beg = buf.len();
buf.extend(&0_u32.to_le_bytes());
// [TDSVersion] The highest TDS version supported by the client.
buf.extend(&self.version.to_le_bytes());
// [PacketSize] The packet size being requested by the client.
buf.extend(&self.packet_size.to_le_bytes());
// [ClientProgVer] The version of the **interface** library.
buf.extend(&self.client_program_version.to_le_bytes());
// [ClientPID] The process ID of the client application.
buf.extend(&self.client_pid.to_le_bytes());
// [ConnectionID] The connection ID of the primary server.
buf.extend(&0_u32.to_le_bytes());
// [OptionFlags1]
// 7 | SET_LANG_ON (1) Require a warning message for a language choice statement
// 6 | INIT_DB_FATAL (1) Fail to change to initial database should be fatal
// 5 | USE_DB_ON (1) Require a warning message for a db change statement
// 4 | DUMPLOAD_OFF (0)
// 3-2 | FLOAT_IEEE_754 (0)
// 1 | CHARSET_ASCII (0)
// 0 | ORDER_X86 (0)
buf.push(0b11_10_00_00);
// [OptionsFlags2]
// 6 | INTEGRATED_SECURITY_OFF (0)
// 5-4 | USER_NORMAL (0)
// 3 | <fCacheConnect>
// 2 | <fTransBoundary>
// 1 | ODBC_ON (1)
// 0 | INIT_LANG_FATAL (1)
buf.push(0b00_00_00_11);
// [TypeFlags]
// 2 | <fReadOnlyIntent>
// 1 | OLEDB_OFF (0)
// 0 | SQL_DFLT (0)
buf.push(0);
// [OptionFlags3]
// 4 | <fExtension>
// 3 | <fUnknownCollationHandling>
// 2 | <fUserInstance>
// 1 | <fSendYukonBinaryXML>
// 0 | <fChangePassword>
buf.push(0);
// [ClientTimeZone] This field is not used and can be set to zero.
buf.extend(&0_u32.to_le_bytes());
// [ClientLanguageCodeIdentifier] The language code identifier (LCID) value for
// the client collation.
buf.extend(&0_u32.to_le_bytes());
// [OffsetLength] pre-allocate a space for all offset, length pairs
let mut offsets = buf.len();
buf.resize(buf.len() + 58, 0);
// [Hostname] The client machine name
write_str(buf, &mut offsets, beg, self.hostname);
// [UserName] The client user ID
write_str(buf, &mut offsets, beg, self.username);
// [Password] The password supplied by the client
let password_start = buf.len();
write_str(buf, &mut offsets, beg, self.password);
// Before submitting a password from the client to the server, for every byte in the
// password buffer starting with the position pointed to by ibPassword or
// ibChangePassword, the client SHOULD first swap the four high bits with
// the four low bits and then do a bit-XOR with 0xA5 (10100101).
for i in password_start..buf.len() {
let b = buf[i];
buf[i] = ((b << 4) & 0xf0 | (b >> 4) & 0x0f) ^ 0xa5;
}
// [AppName] The client application name
write_str(buf, &mut offsets, beg, self.app_name);
// [ServerName] The server name
write_str(buf, &mut offsets, beg, self.server_name);
// [Extension] Points to an extension block.
// TODO: Implement to get FeatureExt which should let us use UTF-8
write_offset(buf, &mut offsets, beg);
offsets += 2;
// [CltIntName] The interface library name
write_str(buf, &mut offsets, beg, self.client_interface_name);
// [Language] The initial language (overrides the user IDs language)
write_str(buf, &mut offsets, beg, self.language);
// [Database] The initial database (overrides the user IDs database)
write_str(buf, &mut offsets, beg, self.database);
// [ClientID] The unique client ID. Can be all zero.
buf[offsets..(offsets + 6)].copy_from_slice(&self.client_id);
offsets += 6;
// [SSPI] SSPI data
write_offset(buf, &mut offsets, beg);
offsets += 2;
// [AtchDBFile] The file name for a database that is to be attached
write_offset(buf, &mut offsets, beg);
offsets += 2;
// [ChangePassword] New password for the specified login
write_offset(buf, &mut offsets, beg);
offsets += 2;
// [SSPILong] Used for large SSPI data
offsets += 4;
// Establish the length of the entire structure
let len = buf.len();
buf[beg..beg + 4].copy_from_slice(&((len - beg) as u32).to_le_bytes());
}
}
fn write_offset(buf: &mut Vec<u8>, offsets: &mut usize, beg: usize) {
// The offset must be relative to the beginning of the packet payload, after
// the packet header
let offset = buf.len() - beg;
buf[*offsets..(*offsets + 2)].copy_from_slice(&(offset as u16).to_le_bytes());
*offsets += 2;
}
fn write_str(buf: &mut Vec<u8>, offsets: &mut usize, beg: usize, s: &str) {
// Write the offset
write_offset(buf, offsets, beg);
// Write the length, in UCS-2 characters
buf[*offsets..(*offsets + 2)].copy_from_slice(&(s.len() as u16).to_le_bytes());
*offsets += 2;
// Encode the character sequence as UCS-2 (precursor to UTF16-LE)
buf.put_utf16_str(s);
}
#[test]
fn test_encode_login() {
let mut buf = Vec::new();
let login = Login7 {
version: 0x72090002,
client_program_version: 0x07_00_00_00,
client_pid: 0x0100,
packet_size: 0x1000,
hostname: "skostov1",
username: "sa",
password: "",
app_name: "OSQL-32",
server_name: "",
client_interface_name: "ODBC",
language: "",
database: "",
client_id: [0x00, 0x50, 0x8B, 0xE2, 0xB7, 0x8F],
};
// Adapted from v20191101 of MS-TDS
#[rustfmt::skip]
let expected = vec![
// Packet Header
/* 0x10, 0x01, 0x00, 0x90, 0x00, 0x00, 0x01, 0x00, */
0x88, 0x00, 0x00, 0x00, // Length
0x02, 0x00, 0x09, 0x72, // TDS Version = SQL Server 2005
0x00, 0x10, 0x00, 0x00, // Packet Size = 1048576 or 1 Mi
0x00, 0x00, 0x00, 0x07, // Client Program Version = 7
0x00, 0x01, 0x00, 0x00, // Client PID = 0x01_00_00
0x00, 0x00, 0x00, 0x00, // Connection ID
0xE0, // [OptionFlags1] 0b1110_0000
0x03, // [OptionFlags2] 0b0000_0011
0x00, // [TypeFlags]
0x00, // [OptionFlags3]
0x00, 0x00, 0x00, 0x00, // [ClientTimeZone]
0x00, 0x00, 0x00, 0x00, // [ClientLCID]
0x5E, 0x00, // [ibHostName]
0x08, 0x00, // [cchHostName]
0x6E, 0x00, // [ibUserName]
0x02, 0x00, // [cchUserName]
0x72, 0x00, // [ibPassword]
0x00, 0x00, // [cchPassword]
0x72, 0x00, // [ibAppName]
0x07, 0x00, // [cchAppName]
0x80, 0x00, // [ibServerName]
0x00, 0x00, // [cchServerName]
0x80, 0x00, // [ibUnused]
0x00, 0x00, // [cbUnused]
0x80, 0x00, // [ibCltIntName]
0x04, 0x00, // [cchCltIntName]
0x88, 0x00, // [ibLanguage]
0x00, 0x00, // [cchLanguage]
0x88, 0x00, // [ibDatabase]
0x00, 0x00, // [chDatabase]
0x00, 0x50, 0x8B, // [ClientID]
0xE2, 0xB7, 0x8F,
0x88, 0x00, // [ibSSPI]
0x00, 0x00, // [cchSSPI]
0x88, 0x00, // [ibAtchDBFile]
0x00, 0x00, // [cchAtchDBFile]
0x88, 0x00, // [ibChangePassword]
0x00, 0x00, // [cchChangePassword]
0x00, 0x00, 0x00, 0x00, // [cbSSPILong]
0x73, 0x00, 0x6B, 0x00, 0x6F, 0x00, 0x73, 0x00, 0x74, 0x00, // [Data]
0x6F, 0x00, 0x76, 0x00, 0x31, 0x00, 0x73, 0x00, 0x61, 0x00,
0x4F, 0x00, 0x53, 0x00, 0x51, 0x00, 0x4C, 0x00, 0x2D, 0x00,
0x33, 0x00, 0x32, 0x00, 0x4F, 0x00, 0x44, 0x00, 0x42, 0x00,
0x43, 0x00,
];
login.encode(&mut buf);
assert_eq!(expected, buf);
}

View file

@ -0,0 +1,39 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::mssql::io::MsSqlBufExt;
use crate::mssql::protocol::pre_login::Version;
#[derive(Debug)]
pub(crate) struct LoginAck {
pub(crate) interface: u8,
pub(crate) tds_version: u32,
pub(crate) program_name: String,
pub(crate) program_version: Version,
}
impl LoginAck {
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, Error> {
let len = buf.get_u16_le();
let mut data = buf.split_to(len as usize);
let interface = data.get_u8();
let tds_version = data.get_u32_le();
let program_name = data.get_b_varchar()?;
let program_version_major = data.get_u8();
let program_version_minor = data.get_u8();
let program_version_build = data.get_u16();
Ok(Self {
interface,
tds_version,
program_name,
program_version: Version {
major: program_version_major,
minor: program_version_minor,
build: program_version_build,
sub_build: 0,
},
})
}
}

View file

@ -0,0 +1,49 @@
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::mssql::protocol::col_meta_data::ColMetaData;
use crate::mssql::protocol::done::Done;
use crate::mssql::protocol::env_change::EnvChange;
use crate::mssql::protocol::info::Info;
use crate::mssql::protocol::login_ack::LoginAck;
use crate::mssql::protocol::row::Row;
#[derive(Debug)]
pub(crate) enum Message {
Info(Info),
LoginAck(LoginAck),
EnvChange(EnvChange),
Done(Done),
Row(Row),
ColMetaData(ColMetaData),
}
#[derive(Debug)]
pub(crate) enum MessageType {
Info,
LoginAck,
EnvChange,
Done,
Row,
ColMetaData,
}
impl MessageType {
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, Error> {
Ok(match buf.get_u8() {
0x81 => MessageType::ColMetaData,
0xab => MessageType::Info,
0xad => MessageType::LoginAck,
0xd1 => MessageType::Row,
0xe3 => MessageType::EnvChange,
0xfd => MessageType::Done,
ty => {
return Err(err_protocol!(
"unknown value `0x{:02x?}` for message type in token stream",
ty
));
}
})
}
}

View file

@ -0,0 +1,12 @@
pub(crate) mod col_meta_data;
pub(crate) mod done;
pub(crate) mod env_change;
pub(crate) mod info;
pub(crate) mod login;
pub(crate) mod login_ack;
pub(crate) mod message;
pub(crate) mod packet;
pub(crate) mod pre_login;
pub(crate) mod row;
pub(crate) mod sql_batch;
pub(crate) mod type_info;

View file

@ -0,0 +1,138 @@
use bitflags::bitflags;
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::{Decode, Encode};
#[derive(Debug)]
pub(crate) struct PacketHeader {
// Type defines the type of message. Type is a 1-byte unsigned char.
pub(crate) r#type: PacketType,
// Status is a bit field used to indicate the message state. Status is a 1-byte unsigned char.
pub(crate) status: Status,
// Length is the size of the packet including the 8 bytes in the packet header.
pub(crate) length: u16,
// The process ID on the server, corresponding to the current connection.
pub(crate) server_process_id: u16,
// Packet ID is used for numbering message packets that contain data in addition to the packet
// header. Packet ID is a 1-byte, unsigned char. Each time packet data is sent, the value of
// PacketID is incremented by 1, modulo 256. This allows the receiver to track the sequence
// of TDS packets for a given message. This value is currently ignored.
pub(crate) packet_id: u8,
}
impl<'s> Encode<'s, &'s mut usize> for PacketHeader {
fn encode_with(&self, buf: &mut Vec<u8>, offset: &'s mut usize) {
buf.push(self.r#type as u8);
buf.push(self.status.bits());
*offset = buf.len();
buf.extend(&self.length.to_be_bytes());
buf.extend(&self.server_process_id.to_be_bytes());
buf.push(self.packet_id);
// window, unused
buf.push(0);
}
}
impl Decode<'_> for PacketHeader {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
Ok(Self {
r#type: PacketType::get(buf.get_u8())?,
status: Status::from_bits_truncate(buf.get_u8()),
length: buf.get_u16(),
server_process_id: buf.get_u16(),
packet_id: buf.get_u8(),
})
}
}
#[derive(Debug, Copy, PartialEq, Clone)]
pub(crate) enum PacketType {
// Pre-login. Should always be #18 unless we decide to try and support pre 7.0 TDS
PreTds7Login = 2,
PreLogin = 18,
SqlBatch = 1,
Rpc = 3,
AttentionSignal = 6,
BulkLoadData = 7,
FederatedAuthToken = 8,
TransactionManagerRequest = 14,
Tds7Login = 16,
Sspi = 17,
TabularResult = 4,
}
impl PacketType {
pub fn get(value: u8) -> Result<Self, Error> {
Ok(match value {
1 => PacketType::SqlBatch,
2 => PacketType::PreTds7Login,
3 => PacketType::Rpc,
4 => PacketType::TabularResult,
6 => PacketType::AttentionSignal,
7 => PacketType::BulkLoadData,
8 => PacketType::FederatedAuthToken,
14 => PacketType::TransactionManagerRequest,
16 => PacketType::Tds7Login,
17 => PacketType::Sspi,
18 => PacketType::PreLogin,
ty => {
return Err(err_protocol!("unknown packet type: {}", ty));
}
})
}
}
// Status is a bit field used to indicate the message state. Status is a 1-byte unsigned char.
// The following Status bit flags are defined.
bitflags! {
pub(crate) struct Status: u8 {
// "Normal" message.
const NORMAL = 0x00;
// End of message (EOM). The packet is the last packet in the whole request.
const END_OF_MESSAGE = 0x01;
// (From client to server) Ignore this event (0x01 MUST also be set).
const IGNORE_EVENT = 0x02;
// RESETCONNECTION
//
// (Introduced in TDS 7.1)
//
// (From client to server) Reset this connection
// before processing event. Only set for event types Batch, RPC, or Transaction Manager
// request. If clients want to set this bit, it MUST be part of the first packet of the
// message. This signals the server to clean up the environment state of the connection
// back to the default environment setting, effectively simulating a logout and a
// subsequent login, and provides server support for connection pooling. This bit SHOULD
// be ignored if it is set in a packet that is not the first packet of the message.
//
// This status bit MUST NOT be set in conjunction with the RESETCONNECTIONSKIPTRAN bit.
// Distributed transactions and isolation levels will not be reset.
const RESET_CONN = 0x08;
// RESETCONNECTIONSKIPTRAN
//
// (Introduced in TDS 7.3)
//
// (From client to server) Reset the
// connection before processing event but do not modify the transaction state (the
// state will remain the same before and after the reset). The transaction in the
// session can be a local transaction that is started from the session or it can
// be a distributed transaction in which the session is enlisted. This status bit
// MUST NOT be set in conjunction with the RESETCONNECTION bit.
// Otherwise identical to RESETCONNECTION.
const RESET_CONN_SKIP_TRAN = 0x10;
}
}

View file

@ -0,0 +1,311 @@
use std::fmt::{self, Display, Formatter};
use bitflags::bitflags;
use bytes::{Buf, Bytes};
use uuid::Uuid;
use crate::error::Error;
use crate::io::{Decode, Encode};
/// A message sent by the client to set up context for login. The server responds to a client
/// `PRELOGIN` message with a message of packet header type `0x04` and the packet data
/// containing a `PRELOGIN` structure.
#[derive(Debug, Default)]
pub(crate) struct PreLogin<'a> {
pub(crate) version: Version,
pub(crate) encryption: Encrypt,
pub(crate) instance: Option<&'a str>,
pub(crate) thread_id: Option<u32>,
pub(crate) trace_id: Option<TraceId>,
pub(crate) multiple_active_result_sets: Option<bool>,
}
impl<'de> Decode<'de> for PreLogin<'de> {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
let mut version = None;
let mut encryption = None;
// TODO: Decode the remainder of the structure
// let mut instance = None;
// let mut thread_id = None;
// let mut trace_id = None;
// let mut multiple_active_result_sets = None;
let mut offsets = buf.clone();
loop {
let token = offsets.get_u8();
match PreLoginOptionToken::get(token) {
Some(token) => {
let offset = offsets.get_u16() as usize;
let size = offsets.get_u16() as usize;
let mut data = &buf[offset..offset + size];
match token {
PreLoginOptionToken::Version => {
let major = data.get_u8();
let minor = data.get_u8();
let build = data.get_u16();
let sub_build = data.get_u16();
version = Some(Version {
major,
minor,
build,
sub_build,
});
}
PreLoginOptionToken::Encryption => {
encryption = Some(Encrypt::from_bits_truncate(data.get_u8()));
}
tok => todo!("{:?}", tok),
}
}
None if token == 0xff => {
break;
}
None => {
return Err(err_protocol!(
"PRELOGIN: unexpected login option token: 0x{:02?}",
token
)
.into());
}
}
}
let version =
version.ok_or(err_protocol!("PRELOGIN: missing required `version` option"))?;
let encryption = encryption.ok_or(err_protocol!(
"PRELOGIN: missing required `encryption` option"
))?;
Ok(Self {
version,
encryption,
..Default::default()
})
}
}
impl Encode<'_> for PreLogin<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
use PreLoginOptionToken::*;
// NOTE: Packet headers are written in MsSqlStream::write
// Rules
// PRELOGIN = (*PRELOGIN_OPTION *PL_OPTION_DATA) / SSL_PAYLOAD
// PRELOGIN_OPTION = (PL_OPTION_TOKEN PL_OFFSET PL_OPTION_LENGTH) / TERMINATOR
// Count the number of set options
let num_options = 2
+ self.instance.map_or(0, |_| 1)
+ self.thread_id.map_or(0, |_| 1)
+ self.trace_id.as_ref().map_or(0, |_| 1)
+ self.multiple_active_result_sets.map_or(0, |_| 1);
// Calculate the length of the option offset block. Each block is 5 bytes and it ends in
// a 1 byte terminator.
let len_offsets = (num_options * 5) + 1;
let mut offsets = buf.len() as usize;
let mut offset = len_offsets as u16;
// Reserve a chunk for the offset block and set the final terminator
buf.resize(buf.len() + len_offsets, 0);
let end_offsets = buf.len() - 1;
buf[end_offsets] = 0xff;
// NOTE: VERSION is a required token, and it MUST be the first token.
Version.put(buf, &mut offsets, &mut offset, 6);
self.version.encode(buf);
Encryption.put(buf, &mut offsets, &mut offset, 1);
buf.push(self.encryption.bits());
if let Some(name) = self.instance {
Instance.put(buf, &mut offsets, &mut offset, name.len() as u16 + 1);
buf.extend_from_slice(name.as_bytes());
buf.push(b'\0');
}
if let Some(id) = self.thread_id {
ThreadId.put(buf, &mut offsets, &mut offset, 4);
buf.extend_from_slice(&id.to_le_bytes());
}
if let Some(trace) = &self.trace_id {
ThreadId.put(buf, &mut offsets, &mut offset, 36);
buf.extend_from_slice(trace.connection_id.as_bytes());
buf.extend_from_slice(trace.activity_id.as_bytes());
buf.extend_from_slice(&trace.activity_seq.to_be_bytes());
}
if let Some(mars) = &self.multiple_active_result_sets {
MultipleActiveResultSets.put(buf, &mut offsets, &mut offset, 1);
buf.push(*mars as u8);
}
}
}
// token value representing the option (PL_OPTION_TOKEN)
#[derive(Debug, Copy, Clone)]
#[repr(u8)]
enum PreLoginOptionToken {
Version = 0x00,
Encryption = 0x01,
Instance = 0x02,
ThreadId = 0x03,
// Multiple Active Result Sets (MARS)
MultipleActiveResultSets = 0x04,
TraceId = 0x05,
}
impl PreLoginOptionToken {
fn put(self, buf: &mut Vec<u8>, pos: &mut usize, offset: &mut u16, len: u16) {
buf[*pos] = self as u8;
*pos += 1;
buf[*pos..(*pos + 2)].copy_from_slice(&offset.to_be_bytes());
*pos += 2;
buf[*pos..(*pos + 2)].copy_from_slice(&len.to_be_bytes());
*pos += 2;
*offset += len;
}
fn get(b: u8) -> Option<Self> {
Some(match b {
0x00 => PreLoginOptionToken::Version,
0x01 => PreLoginOptionToken::Encryption,
0x02 => PreLoginOptionToken::Instance,
0x03 => PreLoginOptionToken::ThreadId,
0x04 => PreLoginOptionToken::MultipleActiveResultSets,
0x05 => PreLoginOptionToken::TraceId,
_ => {
return None;
}
})
}
}
#[derive(Debug)]
pub(crate) struct TraceId {
// client application trace ID (GUID_CONNID)
pub(crate) connection_id: Uuid,
// client application activity ID (GUID_ActivityID)
pub(crate) activity_id: Uuid,
// client application activity sequence (ActivitySequence)
pub(crate) activity_seq: u32,
}
// Version of the sender (UL_VERSION)
#[derive(Debug, Default)]
pub(crate) struct Version {
pub(crate) major: u8,
pub(crate) minor: u8,
pub(crate) build: u16,
// Sub-build number of the sender (US_SUBBUILD)
pub(crate) sub_build: u16,
}
impl Version {
fn encode(&self, buf: &mut Vec<u8>) {
buf.push(self.major);
buf.push(self.minor);
buf.extend(&self.build.to_be_bytes());
buf.extend(&self.sub_build.to_be_bytes());
}
}
impl Display for Version {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
write!(f, "v{}.{}.{}", self.major, self.minor, self.build)
}
}
bitflags! {
/// During the Pre-Login handshake, the client and the server negotiate the
/// wire encryption to be used.
#[derive(Default)]
pub(crate) struct Encrypt: u8 {
/// Encryption is available but on.
const ON = 0x01;
/// Encryption is not available.
const NOT_SUPPORTED = 0x02;
/// Encryption is required.
const REQUIRED = 0x03;
/// The client certificate should be used to authenticate
/// the user in place of a user/password.
const CLIENT_CERT = 0x80;
}
}
#[test]
fn test_encode_pre_login() {
let mut buf = Vec::new();
let pre_login = PreLogin {
version: Version {
major: 9,
minor: 0,
build: 0,
sub_build: 0,
},
encryption: Encrypt::ON,
instance: Some(""),
thread_id: Some(0x00000DB8),
multiple_active_result_sets: Some(true),
..Default::default()
};
// From v20191101 of MS-TDS documentation
#[rustfmt::skip]
let expected = vec![
0x00, 0x00, 0x1A, 0x00, 0x06, 0x01, 0x00, 0x20, 0x00, 0x01, 0x02, 0x00, 0x21, 0x00,
0x01, 0x03, 0x00, 0x22, 0x00, 0x04, 0x04, 0x00, 0x26, 0x00, 0x01, 0xFF, 0x09, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0xB8, 0x0D, 0x00, 0x00, 0x01
];
pre_login.encode(&mut buf);
assert_eq!(expected, buf);
}
#[test]
fn test_decode_pre_login() {
#[rustfmt::skip]
let buffer = Bytes::from_static(&[
0, 0, 11, 0, 6, 1, 0, 17, 0, 1, 255,
14, 0, 12, 209, 0, 0, 0,
]);
let pre_login = PreLogin::decode(buffer).unwrap();
// v14.0.3281
assert_eq!(pre_login.version.major, 14);
assert_eq!(pre_login.version.minor, 0);
assert_eq!(pre_login.version.build, 3281);
assert_eq!(pre_login.version.sub_build, 0);
// ENCRYPT_OFF
assert_eq!(pre_login.encryption.bits(), 0);
}

View file

@ -0,0 +1,37 @@
use std::ops::Range;
use bytes::Bytes;
use crate::error::Error;
use crate::mssql::protocol::col_meta_data::ColumnData;
use crate::mssql::{MsSql, MsSqlTypeInfo};
#[derive(Debug)]
pub(crate) struct Row {
// TODO: Column names?
// FIXME: Columns Vec should be an Arc<_>
pub(crate) column_types: Vec<MsSqlTypeInfo>,
pub(crate) values: Vec<Option<Bytes>>,
}
impl Row {
pub(crate) fn get(buf: &mut Bytes, columns: &[ColumnData]) -> Result<Self, Error> {
let mut values = Vec::with_capacity(columns.len());
let mut column_types = Vec::with_capacity(columns.len());
for column in columns {
column_types.push(MsSqlTypeInfo(column.type_info.clone()));
if column.type_info.is_null() {
values.push(None);
} else {
values.push(Some(buf.split_to(column.type_info.size())));
}
}
Ok(Self {
values,
column_types,
})
}
}

View file

@ -0,0 +1,35 @@
use crate::io::Encode;
use crate::mssql::io::MsSqlBufMutExt;
const HEADER_TRANSACTION_DESCRIPTOR: u16 = 0x00_02;
#[derive(Debug)]
pub(crate) struct SqlBatch<'a> {
pub(crate) sql: &'a str,
}
impl Encode<'_> for SqlBatch<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
// ALL_HEADERS -> TotalLength
buf.extend(&(4_u32 + 18).to_le_bytes()); // 4 + 18
// [Header] Transaction Descriptor
// SQL_BATCH messages require this header
// contains information regarding number of outstanding requests for MARS
buf.extend(&18_u32.to_le_bytes()); // 4 + 2 + 8 + 4
buf.extend(&HEADER_TRANSACTION_DESCRIPTOR.to_le_bytes());
// [TransactionDescriptor] a number that uniquely identifies the current transaction
// TODO: use this once we support transactions, it will be given to us from the
// server ENVCHANGE event
buf.extend(&0_u64.to_le_bytes());
// [OutstandingRequestCount] Number of active requests to MSSQL from the
// same connection
// NOTE: Long-term when we support MARS we need to connect this value correctly
buf.extend(&(1_u32.to_le_bytes()));
// SQLText
buf.put_utf16_str(self.sql);
}
}

View file

@ -0,0 +1,78 @@
use crate::error::Error;
use bytes::{Buf, Bytes};
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum DataType {
// Fixed-length data types
// https://docs.microsoft.com/en-us/openspecs/sql_server_protocols/ms-sstds/d33ef17b-7e53-4380-ad11-2ba42c8dda8d
Null = 0x1f,
TinyInt = 0x30,
Bit = 0x32,
SmallInt = 0x34,
Int = 0x38,
SmallDateTime = 0x3a,
Real = 0x3b,
Money = 0x3c,
DateTime = 0x3d,
Float = 0x3e,
SmallMoney = 0x7a,
BigInt = 0x7f,
}
// http://msdn.microsoft.com/en-us/library/dd358284.aspx
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct TypeInfo {
pub(crate) ty: DataType,
}
impl TypeInfo {
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, Error> {
let ty = DataType::get(buf)?;
Ok(Self { ty })
}
pub(crate) fn is_null(&self) -> bool {
matches!(self.ty, DataType::Null)
}
pub(crate) fn size(&self) -> usize {
match self.ty {
DataType::Null => 0,
DataType::TinyInt => 1,
DataType::Bit => 1,
DataType::SmallInt => 2,
DataType::Int => 4,
DataType::SmallDateTime => 4,
DataType::Real => 4,
DataType::Money => 4,
DataType::DateTime => 8,
DataType::Float => 8,
DataType::SmallMoney => 4,
DataType::BigInt => 8,
}
}
}
impl DataType {
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, Error> {
Ok(match buf.get_u8() {
0x1f => DataType::Null,
0x30 => DataType::TinyInt,
0x32 => DataType::Bit,
0x34 => DataType::SmallInt,
0x38 => DataType::Int,
0x3a => DataType::SmallDateTime,
0x3b => DataType::Real,
0x3c => DataType::Money,
0x3d => DataType::DateTime,
0x3e => DataType::Float,
0x7a => DataType::SmallMoney,
0x7f => DataType::BigInt,
ty => {
return Err(err_protocol!("unknown data type 0x{:02x}", ty));
}
})
}
}

View file

@ -1,8 +1,11 @@
use crate::error::Error;
use crate::mssql::protocol::row::Row as ProtocolRow;
use crate::mssql::{MsSql, MsSqlValueRef};
use crate::row::{ColumnIndex, Row};
pub struct MsSqlRow {}
pub struct MsSqlRow {
pub(crate) row: ProtocolRow,
}
impl crate::row::private_row::Sealed for MsSqlRow {}
@ -11,13 +14,19 @@ impl Row for MsSqlRow {
#[inline]
fn len(&self) -> usize {
unimplemented!()
self.row.values.len()
}
fn try_get_raw<I>(&self, index: I) -> Result<MsSqlValueRef<'_>, Error>
where
I: ColumnIndex<Self>,
{
unimplemented!()
let index = index.index(self)?;
let value = MsSqlValueRef {
data: self.row.values[index].as_ref(),
type_info: self.row.column_types[index].clone(),
};
Ok(value)
}
}

View file

@ -1,9 +1,10 @@
use std::fmt::{self, Display, Formatter};
use crate::mssql::protocol::type_info::TypeInfo as ProtocolTypeInfo;
use crate::type_info::TypeInfo;
#[derive(Debug, Clone)]
pub struct MsSqlTypeInfo {}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MsSqlTypeInfo(pub(crate) ProtocolTypeInfo);
impl TypeInfo for MsSqlTypeInfo {}
@ -12,11 +13,3 @@ impl Display for MsSqlTypeInfo {
unimplemented!()
}
}
impl PartialEq<MsSqlTypeInfo> for MsSqlTypeInfo {
fn eq(&self, other: &MsSqlTypeInfo) -> bool {
unimplemented!()
}
}
impl Eq for MsSqlTypeInfo {}

View file

@ -0,0 +1,20 @@
use byteorder::{ByteOrder, LittleEndian};
use crate::database::{Database, HasValueRef};
use crate::decode::Decode;
use crate::error::BoxDynError;
use crate::mssql::protocol::type_info::{DataType, TypeInfo};
use crate::mssql::{MsSql, MsSqlTypeInfo, MsSqlValueRef};
use crate::types::Type;
impl Type<MsSql> for i32 {
fn type_info() -> MsSqlTypeInfo {
MsSqlTypeInfo(TypeInfo { ty: DataType::Int })
}
}
impl Decode<'_, MsSql> for i32 {
fn decode(value: MsSqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(LittleEndian::read_i32(value.as_bytes()?))
}
}

View file

@ -0,0 +1 @@
mod int;

View file

@ -1,48 +1,70 @@
use std::borrow::Cow;
use std::marker::PhantomData;
use bytes::Bytes;
use crate::database::HasValueRef;
use crate::error::{BoxDynError, UnexpectedNullError};
use crate::mssql::{MsSql, MsSqlTypeInfo};
use crate::value::{Value, ValueRef};
/// Implementation of [`ValueRef`] for MSSQL.
#[derive(Clone)]
pub struct MsSqlValueRef<'r> {
phantom: PhantomData<&'r ()>,
pub(crate) type_info: MsSqlTypeInfo,
pub(crate) data: Option<&'r Bytes>,
}
impl<'r> MsSqlValueRef<'r> {
pub(crate) fn as_bytes(&self) -> Result<&'r [u8], BoxDynError> {
match &self.data {
Some(v) => Ok(v),
None => Err(UnexpectedNullError.into()),
}
}
}
impl ValueRef<'_> for MsSqlValueRef<'_> {
type Database = MsSql;
fn to_owned(&self) -> MsSqlValue {
unimplemented!()
MsSqlValue {
data: self.data.cloned(),
type_info: self.type_info.clone(),
}
}
fn type_info(&self) -> Option<Cow<'_, MsSqlTypeInfo>> {
unimplemented!()
Some(Cow::Borrowed(&self.type_info))
}
fn is_null(&self) -> bool {
unimplemented!()
self.data.is_none()
}
}
/// Implementation of [`Value`] for MSSQL.
#[derive(Clone)]
pub struct MsSqlValue {}
pub struct MsSqlValue {
pub(crate) type_info: MsSqlTypeInfo,
pub(crate) data: Option<Bytes>,
}
impl Value for MsSqlValue {
type Database = MsSql;
fn as_ref(&self) -> <Self::Database as HasValueRef<'_>>::ValueRef {
unimplemented!()
fn as_ref(&self) -> MsSqlValueRef<'_> {
MsSqlValueRef {
data: self.data.as_ref(),
type_info: self.type_info.clone(),
}
}
fn type_info(&self) -> Option<Cow<'_, MsSqlTypeInfo>> {
unimplemented!()
Some(Cow::Borrowed(&self.type_info))
}
fn is_null(&self) -> bool {
unimplemented!()
self.data.is_none()
}
}

View file

@ -26,6 +26,10 @@ pub use sqlx_core::error::{self, BoxDynError, Error, Result};
#[cfg_attr(docsrs, doc(cfg(feature = "mysql")))]
pub use sqlx_core::mysql::{self, MySql, MySqlConnection, MySqlPool};
#[cfg(feature = "mssql")]
#[cfg_attr(docsrs, doc(cfg(feature = "mssql")))]
pub use sqlx_core::mssql::{self, MsSql, MsSqlConnection, MsSqlPool};
#[cfg(feature = "postgres")]
#[cfg_attr(docsrs, doc(cfg(feature = "postgres")))]
pub use sqlx_core::postgres::{self, PgConnection, PgPool, Postgres};

View file

@ -180,13 +180,13 @@ services:
#
mssql_2019:
image: microsoft-mssql-server:2019-latest
image: mcr.microsoft.com/mssql/server:2019-latest
environment:
ACCEPT_EULA: Y
SA_PASSWORD: Password123!
mssql_2017:
image: microsoft-mssql-server:2017-latest
image: mcr.microsoft.com/mssql/server:2017-latest
environment:
ACCEPT_EULA: Y
SA_PASSWORD: Password123!

27
tests/mssql/mssql.rs Normal file
View file

@ -0,0 +1,27 @@
use sqlx::mssql::MsSql;
use sqlx::{Connection, Executor, Row};
use sqlx_core::mssql::MsSqlRow;
use sqlx_test::new;
#[sqlx_macros::test]
async fn it_connects() -> anyhow::Result<()> {
let mut conn = new::<MsSql>().await?;
conn.ping().await?;
conn.close().await?;
Ok(())
}
#[sqlx_macros::test]
async fn it_can_select_1() -> anyhow::Result<()> {
let mut conn = new::<MsSql>().await?;
let row: MsSqlRow = conn.fetch_one("SELECT 4").await?;
let v: i32 = row.try_get(0)?;
assert_eq!(v, 4);
Ok(())
}

View file

@ -8,7 +8,6 @@ async fn it_connects() -> anyhow::Result<()> {
let mut conn = new::<MySql>().await?;
conn.ping().await?;
conn.close().await?;
Ok(())