feat(mssql): fix a few bugs and implement Connection::describe

This commit is contained in:
Ryan Leckey 2020-06-07 01:59:59 -07:00
parent 559169cc79
commit ef2527ff3e
27 changed files with 424 additions and 61 deletions

1
.gitattributes vendored Normal file
View file

@ -0,0 +1 @@
* text=auto eol=lf

2
Cargo.lock generated
View file

@ -2112,10 +2112,12 @@ dependencies = [
"md-5",
"memchr",
"num-bigint",
"once_cell",
"parking_lot 0.10.2",
"percent-encoding 2.1.0",
"phf",
"rand",
"regex",
"serde",
"serde_json",
"sha-1",

View file

@ -159,3 +159,8 @@ required-features = [ "mssql" ]
name = "mssql-types"
path = "tests/mssql/types.rs"
required-features = [ "mssql" ]
[[test]]
name = "mssql-describe"
path = "tests/mssql/describe.rs"
required-features = [ "mssql" ]

View file

@ -19,7 +19,7 @@ default = [ "runtime-async-std" ]
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel/sink", "futures-util/sink" ]
mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ]
sqlite = [ "libsqlite3-sys" ]
mssql = [ "uuid", "encoding_rs" ]
mssql = [ "uuid", "encoding_rs", "regex" ]
# types
all-types = [ "chrono", "time", "bigdecimal", "ipnetwork", "json", "uuid" ]
@ -65,11 +65,13 @@ log = { version = "0.4.8", default-features = false }
md-5 = { version = "0.8.0", default-features = false, optional = true }
memchr = { version = "2.3.3", default-features = false }
num-bigint = { version = "0.2.6", default-features = false, optional = true, features = [ "std" ] }
once_cell = "1.4.0"
percent-encoding = "2.1.0"
parking_lot = "0.10.2"
threadpool = "*"
phf = { version = "0.8.0", features = [ "macros" ] }
rand = { version = "0.7.3", default-features = false, optional = true, features = [ "std" ] }
regex = { version = "1.3.9", optional = true }
serde = { version = "1.0.106", features = [ "derive", "rc" ], optional = true }
serde_json = { version = "1.0.51", features = [ "raw_value" ], optional = true }
sha-1 = { version = "0.8.2", default-features = false, optional = true }

View file

@ -68,36 +68,49 @@ where
}
}
impl<'q, T: 'q + Encode<'q, DB>, DB: Database> Encode<'q, DB> for Option<T> {
#[inline]
fn produces(&self) -> DB::TypeInfo {
if let Some(v) = self {
v.produces()
} else {
T::type_info()
}
}
#[allow(unused_macros)]
macro_rules! impl_encode_for_option {
($DB:ident) => {
impl<'q, T: 'q + crate::encode::Encode<'q, $DB>> crate::encode::Encode<'q, $DB>
for Option<T>
{
#[inline]
fn produces(&self) -> <$DB as crate::database::Database>::TypeInfo {
if let Some(v) = self {
v.produces()
} else {
T::type_info()
}
}
#[inline]
fn encode(self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
if let Some(v) = self {
v.encode(buf)
} else {
IsNull::Yes
}
}
#[inline]
fn encode(
self,
buf: &mut <$DB as crate::database::HasArguments<'q>>::ArgumentBuffer,
) -> crate::encode::IsNull {
if let Some(v) = self {
v.encode(buf)
} else {
crate::encode::IsNull::Yes
}
}
#[inline]
fn encode_by_ref(&self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
if let Some(v) = self {
v.encode_by_ref(buf)
} else {
IsNull::Yes
}
}
#[inline]
fn encode_by_ref(
&self,
buf: &mut <$DB as crate::database::HasArguments<'q>>::ArgumentBuffer,
) -> crate::encode::IsNull {
if let Some(v) = self {
v.encode_by_ref(buf)
} else {
crate::encode::IsNull::Yes
}
}
#[inline]
fn size_hint(&self) -> usize {
self.as_ref().map_or(0, Encode::size_hint)
}
#[inline]
fn size_hint(&self) -> usize {
self.as_ref().map_or(0, crate::encode::Encode::size_hint)
}
}
};
}

View file

@ -31,10 +31,12 @@ pub mod connection;
#[macro_use]
pub mod transaction;
#[macro_use]
pub mod encode;
pub mod database;
pub mod decode;
pub mod describe;
pub mod encode;
pub mod executor;
mod ext;
pub mod from_row;
@ -59,3 +61,7 @@ pub mod sqlite;
#[cfg(feature = "mysql")]
#[cfg_attr(docsrs, doc(cfg(feature = "mysql")))]
pub mod mysql;
#[cfg(feature = "mssql")]
#[cfg_attr(docsrs, doc(cfg(feature = "mssql")))]
pub mod mssql;

View file

@ -2,6 +2,7 @@ use crate::arguments::Arguments;
use crate::encode::Encode;
use crate::mssql::database::MsSql;
use crate::mssql::io::MsSqlBufMutExt;
use crate::mssql::protocol::rpc::StatusFlags;
#[derive(Default)]
pub struct MsSqlArguments {
@ -31,6 +32,19 @@ impl MsSqlArguments {
self.add_named("", value);
}
pub(crate) fn declare<'q, T: Encode<'q, MsSql>>(&mut self, name: &str, initial_value: T) {
let ty = initial_value.produces();
let mut ty_name = String::new();
ty.0.fmt(&mut ty_name);
self.data.put_b_varchar(name); // [ParamName]
self.data.push(StatusFlags::BY_REF_VALUE.bits()); // [StatusFlags]
ty.0.put(&mut self.data); // [TYPE_INFO]
ty.0.put_value(&mut self.data, initial_value); // [ParamLenData]
}
pub(crate) fn append(&mut self, arguments: &mut MsSqlArguments) {
self.ordinal += arguments.ordinal;
self.data.append(&mut arguments.data);

View file

@ -49,8 +49,7 @@ impl MsSqlConnection {
server_name: "",
client_interface_name: "",
language: "",
// FIXME: connect this to options.database
database: "",
database: &*options.database,
client_id: [0; 6],
},
);

View file

@ -3,16 +3,19 @@ use either::Either;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_util::TryStreamExt;
use once_cell::sync::Lazy;
use regex::Regex;
use crate::describe::Describe;
use crate::describe::{Column, Describe};
use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::mssql::protocol::done::Done;
use crate::mssql::protocol::col_meta_data::Flags;
use crate::mssql::protocol::done::{Done, Status};
use crate::mssql::protocol::message::Message;
use crate::mssql::protocol::packet::PacketType;
use crate::mssql::protocol::rpc::{OptionFlags, Procedure, RpcRequest};
use crate::mssql::protocol::sql_batch::SqlBatch;
use crate::mssql::{MsSql, MsSqlArguments, MsSqlConnection, MsSqlRow};
use crate::mssql::{MsSql, MsSqlArguments, MsSqlConnection, MsSqlRow, MsSqlTypeInfo};
impl MsSqlConnection {
pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> {
@ -25,8 +28,10 @@ impl MsSqlConnection {
let message = self.stream.recv_message().await?;
if let Message::DoneProc(done) | Message::Done(done) = message {
// finished RPC procedure *OR* SQL batch
self.handle_done(done);
if !done.status.contains(Status::DONE_MORE) {
// finished RPC procedure *OR* SQL batch
self.handle_done(done);
}
}
}
@ -106,20 +111,23 @@ impl<'c> Executor<'c> for &'c mut MsSqlConnection {
yield v;
}
Message::DoneProc(done) => {
self.handle_done(done);
break;
Message::Done(done) | Message::DoneProc(done) => {
if done.status.contains(Status::DONE_COUNT) {
let v = Either::Left(done.affected_rows);
yield v;
}
if !done.status.contains(Status::DONE_MORE) {
self.handle_done(done);
break;
}
}
Message::DoneInProc(done) => {
// finished SQL query *within* procedure
let v = Either::Left(done.affected_rows);
yield v;
}
Message::Done(done) => {
self.handle_done(done);
break;
if done.status.contains(Status::DONE_COUNT) {
let v = Either::Left(done.affected_rows);
yield v;
}
}
_ => {}
@ -157,6 +165,90 @@ impl<'c> Executor<'c> for &'c mut MsSqlConnection {
'c: 'e,
E: Execute<'q, Self::Database>,
{
unimplemented!()
let s = query.query();
// [sp_prepare] will emit the column meta data
// small issue is that we need to declare all the used placeholders with a "fallback" type
// we currently use regex to collect them; false positives are *okay* but false
// negatives would break the query
let proc = Either::Right(Procedure::Prepare);
// NOTE: this does not support unicode identifiers; as we don't even support
// named parameters (yet) this is probably fine, for now
static PARAMS_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"@p[[:alnum:]]+").unwrap());
let mut params = String::new();
let mut num_params = 0;
for m in PARAMS_RE.captures_iter(s) {
if !params.is_empty() {
params.push_str(",");
}
params.push_str(&m[0]);
// NOTE: this means that a query! of `SELECT @p1` will have the macros believe
// it will return nvarchar(1); this is a greater issue with `query!` that we
// we need to circle back to. This doesn't happen much in practice however.
params.push_str(" nvarchar(1)");
num_params += 1;
}
let params = if params.is_empty() {
None
} else {
Some(&*params)
};
let mut args = MsSqlArguments::default();
args.declare("", 0_i32);
args.add_unnamed(params);
args.add_unnamed(s);
args.add_unnamed(0x0001_i32); // 1 = SEND_METADATA
self.stream.write_packet(
PacketType::Rpc,
RpcRequest {
transaction_descriptor: self.stream.transaction_descriptor,
arguments: &args,
procedure: proc,
options: OptionFlags::empty(),
},
);
Box::pin(async move {
self.stream.flush().await?;
loop {
match self.stream.recv_message().await? {
Message::DoneProc(done) | Message::Done(done) => {
if !done.status.contains(Status::DONE_MORE) {
// done with prepare
break;
}
}
_ => {}
}
}
let mut columns = Vec::with_capacity(self.stream.columns.len());
for col in &self.stream.columns {
columns.push(Column {
name: col.col_name.clone(),
type_info: Some(MsSqlTypeInfo(col.type_info.clone())),
not_null: Some(!col.flags.contains(Flags::NULLABLE)),
});
}
Ok(Describe {
params: vec![None; num_params],
columns,
})
})
}
}

View file

@ -14,6 +14,7 @@ use crate::mssql::protocol::login_ack::LoginAck;
use crate::mssql::protocol::message::{Message, MessageType};
use crate::mssql::protocol::packet::{PacketHeader, PacketType, Status};
use crate::mssql::protocol::return_status::ReturnStatus;
use crate::mssql::protocol::return_value::ReturnValue;
use crate::mssql::protocol::row::Row;
use crate::mssql::{MsSqlConnectOptions, MsSqlDatabaseError};
use crate::net::MaybeTlsStream;
@ -30,7 +31,7 @@ pub(crate) struct MsSqlStream {
// most recent column data from ColMetaData
// we need to store this as its needed when decoding <Row>
columns: Vec<ColumnData>,
pub(crate) columns: Vec<ColumnData>,
}
impl MsSqlStream {
@ -112,6 +113,7 @@ impl MsSqlStream {
};
let ty = MessageType::get(buf)?;
let message = match ty {
MessageType::EnvChange => {
match EnvChange::get(buf)? {
@ -137,6 +139,7 @@ impl MsSqlStream {
MessageType::Row => Message::Row(Row::get(buf, &self.columns)?),
MessageType::LoginAck => Message::LoginAck(LoginAck::get(buf)?),
MessageType::ReturnStatus => Message::ReturnStatus(ReturnStatus::get(buf)?),
MessageType::ReturnValue => Message::ReturnValue(ReturnValue::get(buf)?),
MessageType::Done => Message::Done(Done::get(buf)?),
MessageType::DoneInProc => Message::DoneInProc(Done::get(buf)?),
MessageType::DoneProc => Message::DoneProc(Done::get(buf)?),

View file

@ -9,6 +9,7 @@ pub struct MsSqlConnectOptions {
pub(crate) host: String,
pub(crate) port: u16,
pub(crate) username: String,
pub(crate) database: String,
pub(crate) password: Option<String>,
}
@ -23,6 +24,7 @@ impl MsSqlConnectOptions {
Self {
port: 1433,
host: String::from("localhost"),
database: String::from("master"),
username: String::from("sa"),
password: None,
}
@ -47,6 +49,11 @@ impl MsSqlConnectOptions {
self.password = Some(password.to_owned());
self
}
pub fn database(mut self, database: &str) -> Self {
self.database = database.to_owned();
self
}
}
impl FromStr for MsSqlConnectOptions {
@ -73,6 +80,11 @@ impl FromStr for MsSqlConnectOptions {
options = options.password(password);
}
let path = url.path().trim_start_matches('/');
if !path.is_empty() {
options = options.database(path);
}
Ok(options)
}
}

View file

@ -5,7 +5,7 @@ use crate::error::Error;
#[derive(Debug)]
pub(crate) struct Done {
status: Status,
pub(crate) status: Status,
// The token of the current SQL statement. The token value is provided and controlled by the
// application layer, which utilizes TDS. The TDS layer does not evaluate the value.

View file

@ -3,6 +3,7 @@ use bytes::{Buf, Bytes};
use crate::mssql::protocol::done::Done;
use crate::mssql::protocol::login_ack::LoginAck;
use crate::mssql::protocol::return_status::ReturnStatus;
use crate::mssql::protocol::return_value::ReturnValue;
use crate::mssql::protocol::row::Row;
#[derive(Debug)]
@ -13,6 +14,7 @@ pub(crate) enum Message {
DoneProc(Done),
Row(Row),
ReturnStatus(ReturnStatus),
ReturnValue(ReturnValue),
}
#[derive(Debug)]
@ -27,6 +29,7 @@ pub(crate) enum MessageType {
Error,
ColMetaData,
ReturnStatus,
ReturnValue,
}
impl MessageType {
@ -35,6 +38,7 @@ impl MessageType {
0x81 => MessageType::ColMetaData,
0xaa => MessageType::Error,
0xab => MessageType::Info,
0xac => MessageType::ReturnValue,
0xad => MessageType::LoginAck,
0xd1 => MessageType::Row,
0xe3 => MessageType::EnvChange,

View file

@ -10,6 +10,7 @@ pub(crate) mod message;
pub(crate) mod packet;
pub(crate) mod pre_login;
pub(crate) mod return_status;
pub(crate) mod return_value;
pub(crate) mod row;
pub(crate) mod rpc;
pub(crate) mod sql_batch;

View file

@ -0,0 +1,50 @@
use bitflags::bitflags;
use bytes::{Buf, Bytes};
use crate::error::Error;
use crate::mssql::io::MsSqlBufExt;
use crate::mssql::protocol::col_meta_data::Flags;
use crate::mssql::protocol::type_info::TypeInfo;
#[derive(Debug)]
pub(crate) struct ReturnValue {
param_ordinal: u16,
param_name: String,
status: ReturnValueStatus,
user_type: u32,
flags: Flags,
type_info: TypeInfo,
value: Bytes,
}
bitflags! {
pub(crate) struct ReturnValueStatus: u8 {
// If ReturnValue corresponds to OUTPUT parameter of a stored procedure invocation
const OUTPUT_PARAM = 0x01;
// If ReturnValue corresponds to return value of User Defined Function.
const USER_DEFINED = 0x02;
}
}
impl ReturnValue {
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, Error> {
let ordinal = buf.get_u16_le();
let name = buf.get_b_varchar()?;
let status = ReturnValueStatus::from_bits_truncate(buf.get_u8());
let user_type = buf.get_u32_le();
let flags = Flags::from_bits_truncate(buf.get_u16_le());
let type_info = TypeInfo::get(buf)?;
let value = type_info.get_value(buf);
Ok(Self {
param_ordinal: ordinal,
param_name: name,
status,
user_type,
flags,
type_info,
value,
})
}
}

View file

@ -2,7 +2,7 @@ use bitflags::bitflags;
use bytes::{Buf, Bytes};
use encoding_rs::Encoding;
use crate::encode::Encode;
use crate::encode::{Encode, IsNull};
use crate::error::Error;
use crate::mssql::MsSql;
@ -413,9 +413,13 @@ impl TypeInfo {
let offset = buf.len();
buf.push(0);
let _ = value.encode(buf);
let size = if let IsNull::Yes = value.encode(buf) {
0xFF
} else {
(buf.len() - offset - 1) as u8
};
buf[offset] = (buf.len() - offset - 1) as u8;
buf[offset] = size;
}
pub(crate) fn put_short_len_value<'q, T: Encode<'q, MsSql>>(
@ -426,9 +430,12 @@ impl TypeInfo {
let offset = buf.len();
buf.extend(&0_u16.to_le_bytes());
let _ = value.encode(buf);
let size = if let IsNull::Yes = value.encode(buf) {
0xFFFF
} else {
(buf.len() - offset - 2) as u16
};
let size = (buf.len() - offset - 2) as u16;
buf[offset..(offset + 2)].copy_from_slice(&size.to_le_bytes());
}
@ -436,9 +443,12 @@ impl TypeInfo {
let offset = buf.len();
buf.extend(&0_u32.to_le_bytes());
let _ = value.encode(buf);
let size = if let IsNull::Yes = value.encode(buf) {
0xFFFF_FFFF
} else {
(buf.len() - offset - 4) as u32
};
let size = (buf.len() - offset - 4) as u32;
buf[offset..(offset + 4)].copy_from_slice(&size.to_le_bytes());
}

View file

@ -1,3 +1,38 @@
use crate::encode::{Encode, IsNull};
use crate::mssql::protocol::type_info::{DataType, TypeInfo};
use crate::mssql::{MsSql, MsSqlTypeInfo};
mod float;
mod int;
mod str;
impl<'q, T: 'q + Encode<'q, MsSql>> Encode<'q, MsSql> for Option<T> {
fn produces(&self) -> MsSqlTypeInfo {
if let Some(v) = self {
v.produces()
} else {
// MSSQL requires a special NULL type ID
MsSqlTypeInfo(TypeInfo::new(DataType::Null, 0))
}
}
fn encode(self, buf: &mut Vec<u8>) -> IsNull {
if let Some(v) = self {
v.encode(buf)
} else {
IsNull::Yes
}
}
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
if let Some(v) = self {
v.encode_by_ref(buf)
} else {
IsNull::Yes
}
}
fn size_hint(&self) -> usize {
self.as_ref().map_or(0, Encode::size_hint)
}
}

View file

@ -30,3 +30,7 @@ pub type MySqlPool = crate::pool::Pool<MySql>;
impl_into_arguments_for_arguments!(MySqlArguments);
impl_executor_for_pool_connection!(MySql, MySqlConnection, MySqlRow);
impl_executor_for_transaction!(MySql, MySqlRow);
// required because some databases have a different handling
// of NULL
impl_encode_for_option!(MySql);

View file

@ -33,3 +33,7 @@ pub type PgPool = crate::pool::Pool<Postgres>;
impl_into_arguments_for_arguments!(PgArguments);
impl_executor_for_pool_connection!(Postgres, PgConnection, PgRow);
impl_executor_for_transaction!(Postgres, PgRow);
// required because some databases have a different handling
// of NULL
impl_encode_for_option!(Postgres);

View file

@ -34,3 +34,7 @@ pub type SqlitePool = crate::pool::Pool<Sqlite>;
impl_into_arguments_for_arguments!(SqliteArguments<'q>);
impl_executor_for_pool_connection!(Sqlite, SqliteConnection, SqliteRow);
impl_executor_for_transaction!(Sqlite, SqliteRow);
// required because some databases have a different handling
// of NULL
impl_encode_for_option!(Postgres);

View file

@ -1,3 +1,5 @@
*
!certs/*
!keys/*
!mssql/*.sh
!*/*.sql

View file

@ -180,13 +180,21 @@ services:
#
mssql_2019:
image: mcr.microsoft.com/mssql/server:2019-latest
build:
context: .
dockerfile: mssql/Dockerfile
args:
VERSION: 2019-latest
environment:
ACCEPT_EULA: Y
SA_PASSWORD: Password123!
mssql_2017:
image: mcr.microsoft.com/mssql/server:2017-latest
build:
context: .
dockerfile: mssql/Dockerfile
args:
VERSION: 2017-latest
environment:
ACCEPT_EULA: Y
SA_PASSWORD: Password123!

21
tests/mssql/Dockerfile Normal file
View file

@ -0,0 +1,21 @@
ARG VERSION
FROM mcr.microsoft.com/mssql/server:${VERSION}
# Create a config directory
RUN mkdir -p /usr/config
WORKDIR /usr/config
# Bundle config source
COPY mssql/entrypoint.sh /usr/config/entrypoint.sh
COPY mssql/configure-db.sh /usr/config/configure-db.sh
COPY mssql/setup.sql /usr/config/setup.sql
# Grant permissions for to our scripts to be executable
USER root
RUN chmod +x /usr/config/entrypoint.sh
RUN chmod +x /usr/config/configure-db.sh
RUN chown 10001 /usr/config/entrypoint.sh
RUN chown 10001 /usr/config/configure-db.sh
USER 10001
ENTRYPOINT ["/usr/config/entrypoint.sh"]

View file

@ -0,0 +1,7 @@
#!/usr/bin/env bash
# Wait 60 seconds for SQL Server to start up
sleep 60
# Run the setup script to create the DB and the schema in the DB
/opt/mssql-tools/bin/sqlcmd -S localhost -U sa -P $SA_PASSWORD -d master -i setup.sql

37
tests/mssql/describe.rs Normal file
View file

@ -0,0 +1,37 @@
use sqlx::mssql::MsSql;
use sqlx::{describe::Column, Executor};
use sqlx_test::new;
fn type_names(columns: &[Column<MsSql>]) -> Vec<String> {
columns
.iter()
.filter_map(|col| Some(col.type_info.as_ref()?.to_string()))
.collect()
}
#[sqlx_macros::test]
async fn it_describes_simple() -> anyhow::Result<()> {
let mut conn = new::<MsSql>().await?;
let d = conn.describe("SELECT * FROM tweet").await?;
let columns = d.columns;
assert_eq!(columns[0].name, "id");
assert_eq!(columns[1].name, "text");
assert_eq!(columns[2].name, "is_sent");
assert_eq!(columns[3].name, "owner_id");
assert_eq!(columns[0].not_null, Some(true));
assert_eq!(columns[1].not_null, Some(true));
assert_eq!(columns[2].not_null, Some(true));
assert_eq!(columns[3].not_null, Some(false));
let column_type_names = type_names(&columns);
assert_eq!(column_type_names[0], "bigint");
assert_eq!(column_type_names[1], "nvarchar(max)");
assert_eq!(column_type_names[2], "tinyint");
assert_eq!(column_type_names[3], "bigint");
Ok(())
}

View file

@ -0,0 +1,7 @@
#!/usr/bin/env bash
# Start the script to create the DB and user
/usr/config/configure-db.sh &
# Start SQL Server
/opt/mssql/bin/sqlservr

20
tests/mssql/setup.sql Normal file
View file

@ -0,0 +1,20 @@
IF DB_ID('sqlx') IS NULL
BEGIN
CREATE DATABASE sqlx;
END;
GO
USE sqlx;
GO
IF OBJECT_ID('tweet') IS NULL
BEGIN
CREATE TABLE tweet
(
id BIGINT NOT NULL PRIMARY KEY,
text NVARCHAR(4000) NOT NULL,
is_sent TINYINT NOT NULL DEFAULT 1,
owner_id BIGINT
);
END;
GO