WIP [next]: implement generalized query placeholders

Signed-off-by: Austin Bonander <austin@launchbadge.com>
This commit is contained in:
Austin Bonander 2021-02-28 14:34:21 -08:00
parent 1cac2864ec
commit a2eda2de24
No known key found for this signature in database
GPG key ID: 461F7F0F45383F2B
24 changed files with 1304 additions and 107 deletions

20
Cargo.lock generated
View file

@ -274,6 +274,16 @@ dependencies = [
"winapi",
]
[[package]]
name = "combine"
version = "4.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc4369b5e4c0cddf64ad8981c0111e7df4f7078f4d6ba98fb31f2e17c4c57b7e"
dependencies = [
"bytes",
"memchr",
]
[[package]]
name = "concurrent-queue"
version = "1.2.2"
@ -819,6 +829,12 @@ dependencies = [
"winapi",
]
[[package]]
name = "paste"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "acbf547ad0c65e31259204bd90935776d1c693cec2f4ff7abb7a1bbbd40dfe58"
[[package]]
name = "pem"
version = "0.8.3"
@ -1089,6 +1105,7 @@ dependencies = [
"async-std",
"bytes",
"bytestring",
"combine",
"conquer-once",
"crossbeam",
"either",
@ -1148,8 +1165,11 @@ dependencies = [
"log",
"md-5",
"memchr",
"paste",
"percent-encoding",
"sqlx-core",
"sqlx-test",
"tokio",
"url",
]

View file

@ -52,3 +52,4 @@ memchr = "2.3"
conquer-once = { version = "0.3.2", optional = true }
parking_lot = { version = "0.11.1", optional = true }
crossbeam = { version = "0.8.0", optional = true }
combine = "4.5.2"

View file

@ -1,10 +1,11 @@
use std::any;
use either::Either;
use crate::database::HasOutput;
use crate::{encode, Database, Error, Result, TypeEncode, TypeInfo};
use std::borrow::Cow;
use std::fmt::{self, Display, Formatter};
/// A collection of arguments to be applied to a prepared statement.
///
/// This container allows for a heterogeneous list of positional and named
@ -17,10 +18,17 @@ pub struct Arguments<'a, Db: Database> {
positional: Vec<Argument<'a, Db>>,
}
/// The index for a given bind argument; either positional, or named.
#[derive(Debug, PartialEq, Eq)]
pub enum ArgumentIndex<'a> {
Positioned(usize),
Named(Cow<'a, str>),
}
/// A single argument to be applied to a prepared statement.
pub struct Argument<'a, Db: Database> {
unchecked: bool,
parameter: Either<usize, &'a str>,
index: ArgumentIndex<'a>,
// preserved from `T::type_id()`
type_id: Db::TypeId,
@ -38,15 +46,15 @@ pub struct Argument<'a, Db: Database> {
}
impl<'a, Db: Database> Argument<'a, Db> {
fn new<T: 'a + TypeEncode<Db>>(
parameter: Either<usize, &'a str>,
fn new<'b: 'a, T: 'a + TypeEncode<Db>>(
parameter: impl Into<ArgumentIndex<'b>>,
value: &'a T,
unchecked: bool,
) -> Self {
Self {
value,
unchecked,
parameter,
index: parameter.into(),
type_id: T::type_id(),
type_compatible: T::compatible,
rust_type_name: any::type_name::<T>(),
@ -77,7 +85,7 @@ impl<'a, Db: Database> Arguments<'a, Db> {
pub fn add<T: 'a + TypeEncode<Db>>(&mut self, value: &'a T) {
let index = self.positional.len();
self.positional.push(Argument::new(Either::Left(index), value, false));
self.positional.push(Argument::new(index, value, false));
}
/// Add an unchecked value to the end of the arguments collection.
@ -89,17 +97,17 @@ impl<'a, Db: Database> Arguments<'a, Db> {
pub fn add_unchecked<T: 'a + TypeEncode<Db>>(&mut self, value: &'a T) {
let index = self.positional.len();
self.positional.push(Argument::new(Either::Left(index), value, true));
self.positional.push(Argument::new(index, value, true));
}
/// Add a named value to the argument collection.
pub fn add_as<T: 'a + TypeEncode<Db>>(&mut self, name: &'a str, value: &'a T) {
self.named.push(Argument::new(Either::Right(name), value, false));
self.named.push(Argument::new(name, value, false));
}
/// Add an unchecked, named value to the arguments collection.
pub fn add_unchecked_as<T: 'a + TypeEncode<Db>>(&mut self, name: &'a str, value: &'a T) {
self.named.push(Argument::new(Either::Right(name), value, true));
self.named.push(Argument::new(name, value, true));
}
}
@ -155,20 +163,30 @@ impl<'a, Db: Database> Arguments<'a, Db> {
}
/// Returns a reference to the argument at the location, if present.
pub fn get<'x, I: ArgumentIndex<'a, Db>>(&'x self, index: &I) -> Option<&'x Argument<'a, Db>> {
index.get(self)
pub fn get<'x, 'i, I: Into<ArgumentIndex<'i>>>(
&'x self,
index: I,
) -> Option<&'x Argument<'a, Db>> {
let index = index.into();
match index {
ArgumentIndex::Named(_) => &self.named,
ArgumentIndex::Positioned(_) => &self.positional,
}
.iter()
.find(|arg| arg.index == index)
}
}
impl<'a, Db: Database> Argument<'a, Db> {
/// Gets the name of this argument, if it is a named argument, None otherwise
pub fn name<'b>(&'b self) -> Option<&'a str>{
self.parameter.right()
pub fn name(&self) -> Option<&str> {
self.index.name()
}
/// Gets the position of this argument, if it is a positional argument, None otherwise
pub fn position(&self) -> Option<usize>{
self.parameter.left()
pub fn position(&self) -> Option<usize> {
self.index.position()
}
/// Returns the SQL type identifier of the argument.
@ -198,29 +216,96 @@ impl<'a, Db: Database> Argument<'a, Db> {
self.value.encode(ty, out)
};
res.map_err(|source| Error::ParameterEncode {
parameter: self.parameter.map_right(|name| name.to_string().into_boxed_str()),
source,
})
res.map_err(|source| Error::ParameterEncode { parameter: self.index.to_static(), source })
}
pub fn value(&self) -> &(dyn TypeEncode<Db> + 'a) {
self.value
}
}
/// A helper trait used for indexing into an [`Arguments`] collection.
pub trait ArgumentIndex<'a, Db: Database> {
/// Returns a reference to the argument at this location, if present.
fn get<'x>(&self, arguments: &'x Arguments<'a, Db>) -> Option<&'x Argument<'a, Db>>;
}
// access a named argument by name
impl<'a, Db: Database> ArgumentIndex<'a, Db> for str {
fn get<'x>(&self, arguments: &'x Arguments<'a, Db>) -> Option<&'x Argument<'a, Db>> {
arguments.named.iter().find_map(|arg| (arg.parameter.right() == Some(self)).then(|| arg))
impl<'a> From<&'a str> for ArgumentIndex<'a> {
fn from(name: &'a str) -> Self {
ArgumentIndex::Named(name.into())
}
}
// access a positional argument by index
impl<'a, Db: Database> ArgumentIndex<'a, Db> for usize {
fn get<'x>(&self, arguments: &'x Arguments<'a, Db>) -> Option<&'x Argument<'a, Db>> {
arguments.positional.get(*self)
impl<'a> From<&'a String> for ArgumentIndex<'a> {
fn from(name: &'a String) -> Self {
ArgumentIndex::Named(name.into())
}
}
impl From<usize> for ArgumentIndex<'static> {
fn from(position: usize) -> Self {
ArgumentIndex::Positioned(position)
}
}
impl<'a, 'b> From<&'a ArgumentIndex<'b>> for ArgumentIndex<'a> {
fn from(idx: &'a ArgumentIndex<'b>) -> Self {
match idx {
ArgumentIndex::Positioned(pos) => ArgumentIndex::Positioned(*pos),
ArgumentIndex::Named(name) => ArgumentIndex::Named(name.as_ref().into()),
}
}
}
impl<'a> ArgumentIndex<'a> {
pub(crate) fn into_static(self) -> ArgumentIndex<'static> {
match self {
Self::Positioned(pos) => ArgumentIndex::Positioned(pos),
Self::Named(named) => ArgumentIndex::Named(named.into_owned().into()),
}
}
pub(crate) fn to_static(&self) -> ArgumentIndex<'static> {
match self {
Self::Positioned(pos) => ArgumentIndex::Positioned(*pos),
Self::Named(named) => ArgumentIndex::Named((**named).to_owned().into()),
}
}
pub(crate) fn name(&self) -> Option<&str> {
if let Self::Named(s) = self {
Some(s)
} else {
None
}
}
pub(crate) fn position(&self) -> Option<usize> {
if let Self::Positioned(pos) = *self {
Some(pos)
} else {
None
}
}
}
impl Display for ArgumentIndex<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
Self::Positioned(pos) => Display::fmt(pos, f),
Self::Named(named) => Display::fmt(named, f),
}
}
}
impl PartialEq<str> for ArgumentIndex<'_> {
fn eq(&self, other: &str) -> bool {
self == &ArgumentIndex::from(other)
}
}
impl PartialEq<&'_ str> for ArgumentIndex<'_> {
fn eq(&self, other: &&str) -> bool {
self == &ArgumentIndex::from(*other)
}
}
impl PartialEq<usize> for ArgumentIndex<'_> {
fn eq(&self, other: &usize) -> bool {
self == &ArgumentIndex::from(*other)
}
}

View file

@ -26,6 +26,15 @@ pub trait Database:
/// The concrete [`TypeId`] implementation for this database.
type TypeId: 'static + PartialEq + Hash + Clone + Copy + Send + Sync;
/// The character used to prefix bind parameter placeholders, e.g. `$` for Postgres, `?` for MySQL, etc.
const PLACEHOLDER_CHAR: char;
/// The indexing type for bind parameters.
///
/// E.g. `Implicit` for MySQL which just does `SELECT 1 FROM foo WHERE bar = ? AND baz = ?`
/// or `OneIndexed` for Postgres which does `SELECT 1 FROM foo WHERE bar = $1 AND baz = $2`
const PARAM_INDEXING: crate::placeholders::ParamIndexing;
}
/// Associates [`Database`] with an `Output` of a generic lifetime.

View file

@ -2,7 +2,9 @@ use std::error::Error as StdError;
use std::fmt::{self, Display, Formatter};
use crate::database::HasOutput;
use crate::Database;
use crate::{Arguments, Database, Type, TypeEncode};
use std::iter::FromIterator;
use std::ops::BitOr;
/// Type returned from [`Encode::encode`] that indicates if the value encoded is the SQL `null` or not.
pub enum IsNull {
@ -20,10 +22,44 @@ pub enum IsNull {
No,
}
#[doc(hidden)]
impl BitOr<IsNull> for IsNull {
type Output = IsNull;
fn bitor(self, rhs: IsNull) -> Self::Output {
use IsNull::*;
match (self, rhs) {
(No, No) => No,
(Yes, No) | (No, Yes) | (Yes, Yes) => Yes,
}
}
}
/// Useful for encoding arrays
#[doc(hidden)]
impl FromIterator<IsNull> for IsNull {
fn from_iter<T: IntoIterator<Item = IsNull>>(iter: T) -> Self {
iter.into_iter().fold(IsNull::No, BitOr::bitor)
}
}
/// A type that can be encoded into a SQL value.
pub trait Encode<Db: Database>: Send + Sync {
/// Encode this value into the specified SQL type.
fn encode(&self, ty: &Db::TypeInfo, out: &mut <Db as HasOutput<'_>>::Output) -> Result;
/// If this type is a vector, get its length.
fn vector_len(&self) -> Option<usize> {
None
}
/// If this type is a vector, add its elements as positional arguments to `arguments`.
///
/// Panics if not a vector.
fn expand_vector<'a>(&'a self, _arguments: &mut Arguments<'a, Db>) {
panic!("not a vector!")
}
}
impl<T: Encode<Db>, Db: Database> Encode<Db> for &T {
@ -31,6 +67,10 @@ impl<T: Encode<Db>, Db: Database> Encode<Db> for &T {
fn encode(&self, ty: &Db::TypeInfo, out: &mut <Db as HasOutput<'_>>::Output) -> Result {
(*self).encode(ty, out)
}
fn vector_len(&self) -> Option<usize> {
(*self).vector_len()
}
}
/// Errors which can occur while encoding a SQL value.

View file

@ -13,10 +13,11 @@ mod database;
pub use client::ClientError;
pub use database::DatabaseError;
use crate::arguments::ArgumentIndex;
use crate::Column;
/// Specialized `Result` type returned from fallible methods within SQLx.
pub type Result<T> = std::result::Result<T, Error>;
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Error type returned for all methods in SQLX.
#[derive(Debug)]
@ -84,7 +85,10 @@ pub enum Error {
/// An error occurred encoding a value for a specific parameter to
/// be sent to the database.
ParameterEncode { parameter: Either<usize, Box<str>>, source: EncodeError },
ParameterEncode { parameter: ArgumentIndex<'static>, source: EncodeError },
/// An error occurred while parsing or expanding the generic placeholder syntax in a query.
Placeholders(crate::placeholders::Error),
}
impl Error {
@ -172,13 +176,11 @@ impl Display for Error {
}
}
Self::ParameterEncode { parameter: Either::Left(index), source } => {
write!(f, "Encode parameter {}: {}", index, source)
Self::ParameterEncode { parameter, source } => {
write!(f, "Encode parameter {}: {}", parameter, source)
}
Self::ParameterEncode { parameter: Either::Right(name), source } => {
write!(f, "Encode parameter `{}`: {}", name, source)
}
Self::Placeholders(e) => e.fmt(f),
}
}
}
@ -188,7 +190,7 @@ impl StdError for Error {
match self {
Self::ConnectOptions { source: Some(source), .. } => Some(&**source),
Self::Network(source) => Some(source),
Self::Placeholders(source) => Some(source),
_ => None,
}
}

View file

@ -49,6 +49,9 @@ pub mod io;
#[doc(hidden)]
pub mod net;
#[doc(hidden)]
pub mod placeholders;
#[doc(hidden)]
#[cfg(feature = "_mock")]
pub mod mock;

View file

@ -0,0 +1,711 @@
//! Parsing support for Generalized Query Placeholders, similar to `println!()` or `format_args!()` syntax.
//!
//! ### Kinds
//!
//! Implicit indexing: `SELECT * FROM foo WHERE id = {} AND bar = {}`
//! where each placeholder implicitly refers to an expression at the equivalent position
//! in the bind arguments list
//!
//! Explicit zero-based indexing: `SELECT * FROM foo WHERE id = {N}` where `N` is an unsigned integer
//! which refers to the Nth expression in the bind arguments list (starting from zero)
//!
//! Arguments capturing:
//!
//! `SELECT * FROM foo WHERE id = {<ident>}` where `<ident>` is a Rust identifier
//! defined in the same scope as the query string (for the macros) or an explicitly named bind argument
//! (for the dynamic interface)
//!
//! `SELECT * FROM foo WHERE id = {<field-expr>}` where `<field-expr>` is a Rust field expression
//! (e.g. `foo.bar.baz`) which resolves in the current scope (for the macros)
//!
//! Repetition interpolated into the query string:
//!
//! * `SELECT * FROM foo WHERE id IN ({+})`
//! * `SELECT * FROM foo WHERE id IN ({N+})`
//! * `SELECT * FROM foo WHERE id IN ({<ident>+})`
//! * `SELECT * FROM foo WHERE id IN ({(<field-expr>)+})`
//!
//! Similar to the above, but where the bind argument corresponding to the placeholder is expected
//! to be an iterable, and the repetition is expanded into the query string at runtime
//! (for databases which don't support binding arrays).
//!
//! For example:
//!
//! ```rust,ignore
//! let foo = [1, 2, 3, 4, 5];
//!
//! sqlx::query!("SELECT * FROM foo WHERE id IN ({foo*}")
//!
//! // would be equivalent to:
//!
//! sqlx::query!("SELECT * FROM foo WHERE id IN ($1, $2, $3, $4, $5)", foo[0], foo[1], foo[2], foo[3], foo[4])
//! ```
//!
//! (Note: for Postgres, binding the array directly instead of using expansion should be preferred
//! as it will not generate a different query string for every arity of iterable passed.)
//!
//! ### Potential Pitfalls to Avoid
//! We want to make sure to avoid trying to parse paired braces inside strings as it could
//! be, e.g. a JSON object literal. We also need to support escaping braces (and erasing the escapes)
//!
use std::borrow::Cow;
use std::fmt::{self, Display, Formatter, Write};
use std::ops::Range;
use crate::arguments::ArgumentIndex;
use crate::{Arguments, Database};
use combine::parser::char::{alpha_num, letter};
use combine::parser::range::{recognize, recognize_with_value, take_while1};
use combine::parser::repeat::{escaped, repeat_skip_until};
use combine::stream::position::{Positioner, RangePositioner, SourcePosition};
use combine::*;
use std::cmp;
/// The number of words (group of characters separated by a space) before and after a given position
/// to give for context. See [`error_context()`].
const NUM_CONTEXT_WORDS: usize = 3;
/// A query parsed for generic placeholders with [`parse_query()`].
pub struct ParsedQuery<'a> {
pub(crate) query: &'a str,
pub(crate) placeholders: Vec<Placeholder<'a>>,
}
/// A single generic placeholder in a query parsed with [`parse_query()`].
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct Placeholder<'a> {
/// The byte range in the source query where this placeholder appears, including the `{}`
pub token: Range<usize>,
/// The identifier for this placeholder.
pub ident: Ident<'a>,
/// The kleene operator for this placeholder. If `Some`, the bind parameter is expected to be a vector.
pub kleene: Option<Kleene>,
}
/// The identifier for a placeholder which connects it to a bind parameter.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum Ident<'a> {
/// An implicitly indexed placeholder, i.e. just `{}`
Implicit,
/// A positionally indexed placeholder, e.g. `{0}`, `{1}`, etc.
Positional(u16),
/// A named placeholder, e.g. `{foo}` would be `Named("foo")`
Named(Cow<'a, str>),
/// A placeholder with a field access expression, e.g. `{(foo.bar.baz)}` would be `Field("foo.bar.baz")`
Field(Cow<'a, str>),
}
/// The optional Kleene operator of a [Placeholder] which changes its expansion.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum Kleene {
// not currently supported
// Question,
// Star,
/// The `+` Kleene operator, e.g. `{foo+}`. Always expands to at least one value.
///
/// A vector of 0 items expands to the literal `NULL` while
/// a non-empty vector expands to a comma-separated list, e.g. `$1, $2, $3`.
Plus,
}
/// The bind parameter indexing type for the given database.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum ParamIndexing {
/// Implicitly indexed bind parameters, e.g. for MySQL
/// which just does `SELECT 1 FROM foo WHERE bar = ? AND baz = ?`
Implicit,
/// Explicitly 1-based indexing of bind parameters, e.g. for Postgres
/// which does `SELECT 1 FROM foo WHERE bar = $1 AND baz = $2`
OneIndexed,
}
/// The type of an individual bind argument.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum ArgumentKind {
/// This bind param is a scalar, i.e. it should expand to only one concrete placeholder.
Scalar,
/// This bind param is a vector, i.e. its expansion is dictated by the `Kleene` operator.
/// The `usize` value is the length of the vector (which may be 0).
///
/// [`ParsedQuery::expand()`] will error if the corresponding [`Placeholder::kleene`] is `None`.
Vector(usize),
}
/// The error type returned by [`parse_query`] and [`ParsedQuery::expand()`]
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
/// An error occurred while parsing the query for generic placeholder syntax.
Parse {
/// The byte position in the query string where the error occurred.
byte_position: usize,
/// The line in the string where the error occurred.
line: i32,
/// The column in the string where the error occurred.
column: i32,
/// The message string, with error and context.
message: String,
/// The context string, which may help with locating the error.
context: String,
},
/// An error occurred while expanding the generic placeholder syntax.
///
/// The string is the error message.
Expand(String),
/// There was a mismatch between query placeholders and bind arguments.
///
/// The string is the error message.
ArgsMismatch(String),
/// An error occurred mapping an individual placeholder to a bind argument.
PlaceholderToArgument {
/// The argument which triggered the error.
argument: ArgumentIndex<'static>,
/// The error message.
message: String,
},
/// One or more generic placeholders was parsed in a non-prepared statement
/// (e.g. a raw query string passed directly to a method of `Executor`)
/// but generic placeholders are only supported when using prepared statements
/// (e.g. `sqlx::query()`).
PreparedStatementsOnly,
}
type Result<T, E = Error> = std::result::Result<T, E>;
impl<'a> ParsedQuery<'a> {
/// Get the parsed list of placeholders.
pub fn placeholders(&self) -> &[Placeholder<'a>] {
&self.placeholders
}
/// Expand the placeholders in this query according to
/// [`DB::PLACEHOLDER_CHAR`][Database::PLACEHOLDER_CHAR] and
/// [`DB::PARAMETER_STYLE`][Database::PARAMETER_STYLE].
///
/// The callback will be invoked for each placeholder and should return the `ArgumentKind`
/// for the corresponding query argument.
///
/// See [`default_get_arg()`] which returns a default implementation of this callback
/// that just looks up the value in an `Arguments` struct or errors.
///
/// A custom callback is only necessary if the database needs to adjust how the value is bound
/// based on the placeholder; e.g. Postgres, which has native support for vectors/arrays, needs
/// to know if the placeholder is expecting a comma-expansion (bind each value separately)
/// or not (bind the array wholesale).
///
/// Returns an error if:
/// * the `get_arg` callback returns an error (will be `Error::PlaceholderToArgument`)
/// * any param is a [`ArgumentKind::Scalar`] but the corresponding [`Placeholder::kleene`] is `Some`
/// * any param is a [`ArgumentKind::Vector`] but the corresponding [`Placeholder::kleene`] is `None`
pub fn expand<
DB: Database,
Arg: FnMut(&ArgumentIndex<'_>, &Placeholder<'a>) -> Result<ArgumentKind, String>,
>(
&self,
get_arg: Arg,
) -> Result<Cow<'a, str>> {
self.expand_inner(DB::PLACEHOLDER_CHAR, DB::PARAM_INDEXING, get_arg)
}
/// Unit-testable version of `expand`
fn expand_inner(
&self,
placeholder_char: char,
indexing: ParamIndexing,
mut get_arg: impl FnMut(&ArgumentIndex<'_>, &Placeholder<'a>) -> Result<ArgumentKind, String>,
) -> Result<Cow<'a, str>> {
macro_rules! err {
($($args:tt)*) => {
Err(Error::Expand(format!($($args)*)))
};
}
// optimization: if we don't have any placeholders to substitute, then just return `self.query`
if self.placeholders.is_empty() {
return Ok(self.query.into());
}
// the current placeholder index; unused if `ParamIndexing::Implicit`
let mut index = match indexing {
ParamIndexing::Implicit => None,
ParamIndexing::OneIndexed => Some(1),
};
let mut push_placeholder = |buf: &mut String| {
buf.push(placeholder_char);
if let Some(ref mut index) = index {
write!(buf, "{}", index).expect("write!() to a string is infallible");
*index += 1;
}
};
let mut out = String::with_capacity(self.query.len());
let mut implicit_pos: usize = 0;
// copy `this .. self.query.len()` to the end of `out` after processing `placeholders`
let mut last_placeholder_end = 0;
for placeholder in &self.placeholders {
// push the chunk of `self.query` between the last placeholder and this one
out.push_str(&self.query[last_placeholder_end..placeholder.token.start]);
last_placeholder_end = placeholder.token.end;
let argument = placeholder.ident.to_arg_index(&mut implicit_pos);
match get_arg(&argument, placeholder).map_err(|e| Error::PlaceholderToArgument {
argument: argument.into_static(),
message: e,
})? {
ArgumentKind::Scalar => {
if placeholder.kleene.is_some() {
return err!("expected vector bind param for {:?}", placeholder);
}
push_placeholder(&mut out);
}
ArgumentKind::Vector(len) => {
let kleene = placeholder.kleene.ok_or_else(|| {
Error::Expand(format!("expected Kleene operator for {:?}", placeholder))
})?;
if len == 0 {
match kleene {
Kleene::Plus => {
out.push_str("NULL");
}
}
continue;
}
let mut comma_needed = false;
for _ in 0..len {
if comma_needed {
out.push_str(", ");
}
push_placeholder(&mut out);
comma_needed = true;
}
}
}
}
out.push_str(&self.query[last_placeholder_end..]);
Ok(out.into())
}
}
impl std::error::Error for Error {}
impl Display for Error {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
use Error::*;
match self {
Parse { line, column, message, context, .. } => {
write!(
f,
"Error parsing placeholders in query at line {}, column {}: {} near {:?}",
line, column, message, context
)
}
Expand(s) => write!(f, "Error expanding placeholders in query: {}", s),
ArgsMismatch(s) => write!(f, "Error matching placeholders to arguments: {}", s),
PlaceholderToArgument { argument, message } => {
write!(f, "Error mapping bind argument {} to a placeholder: {}", argument, message)
}
PreparedStatementsOnly => f.write_str(
"generic placeholders are only supported when using prepared statements",
),
}
}
}
impl From<Error> for crate::Error {
fn from(e: Error) -> Self {
crate::Error::Placeholders(e)
}
}
impl Ident<'_> {
fn to_arg_index(&self, implicit_pos: &mut usize) -> ArgumentIndex<'_> {
match self {
Self::Implicit => {
let ret = *implicit_pos;
*implicit_pos += 1;
ret.into()
}
Self::Positional(pos) => (*pos as usize).into(),
Self::Named(s) => (&**s).into(),
Self::Field(s) => (&**s).into(),
}
}
}
/// similar to combine's `IndexPositioner` but which correctly maintains byte-position
/// and also tracks a `SourcePosition` for user-friendliness
#[derive(Clone, Default, PartialOrd, Ord, PartialEq, Eq, Debug)]
struct StrPosition {
byte_pos: usize,
source_pos: SourcePosition,
}
impl Positioner<char> for StrPosition {
type Position = Self;
type Checkpoint = Self;
fn position(&self) -> Self::Position {
self.clone()
}
fn update(&mut self, token: &char) {
self.byte_pos += token.len_utf8();
self.source_pos.update(token);
}
fn checkpoint(&self) -> Self::Checkpoint {
self.clone()
}
fn reset(&mut self, checkpoint: Self::Checkpoint) {
*self = checkpoint;
}
}
impl<'a> RangePositioner<char, &'a str> for StrPosition {
fn update_range(&mut self, range: &&'a str) {
self.byte_pos += range.len();
self.source_pos.update_range(range);
}
}
impl Display for StrPosition {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
self.source_pos.fmt(f)
}
}
struct DisplayErrors<'a>(Vec<combine::easy::Error<char, &'a str>>);
impl Display for DisplayErrors<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
combine::easy::Error::fmt_errors(&self.0, f)
}
}
pub fn parse_query(query: &str) -> Result<ParsedQuery<'_>> {
let placeholders = parse_query_string(query).map_err(|e| {
let combine::easy::Errors {
position: StrPosition { byte_pos, source_pos: SourcePosition { line, column } },
errors,
} = e;
Error::Parse {
byte_position: byte_pos,
line,
column,
message: DisplayErrors(errors).to_string(),
context: error_context(query, byte_pos).to_string(),
}
})?;
Ok(ParsedQuery { query, placeholders })
}
/// Convenient function to pass to `ParsedQuery::expand()` when no special handling of arguments
/// is necessary.
pub fn default_get_arg<'a, DB: Database>(
args: &'a Arguments<'_, DB>,
) -> impl FnMut(&ArgumentIndex<'_>, &Placeholder<'_>) -> Result<ArgumentKind, String> + 'a {
move |idx, _place| {
let arg = args.get(idx).ok_or("unknown argument")?;
Ok(arg.value().vector_len().map_or(ArgumentKind::Scalar, ArgumentKind::Vector))
}
}
fn parse_query_string(
query: &str,
) -> Result<Vec<Placeholder<'_>>, combine::easy::Errors<char, &'_ str, StrPosition>> {
parse_placeholders()
.easy_parse(combine::stream::position::Stream::with_positioner(
query,
StrPosition::default(),
))
.map(|(placeholders, _)| placeholders)
}
fn parse_placeholders<'a, I: RangeStream<Token = char, Range = &'a str, Position = StrPosition>>(
) -> impl Parser<combine::easy::Stream<I>, Output = Vec<Placeholder<'a>>> {
combine::many(
repeat_skip_until(
combine::choice((one_of("'\"`".chars()).then(escaped_string), any().map(|_| ()))),
attempt(token('{')),
)
.then(|_| parse_placeholder()),
)
}
fn parse_placeholder<'a, I: RangeStream<Token = char, Range = &'a str, Position = StrPosition>>(
) -> impl Parser<I, Output = Placeholder<'a>> {
(
position(),
recognize_with_value(between(
token('{'),
token('}'),
(parse_ident(), optional(parse_kleene())),
)),
)
.map(
|(pos, (range, (ident, kleene))): (
StrPosition,
(&str, (Ident<'_>, Option<Kleene>)),
)| {
let pos = pos.byte_pos;
Placeholder { token: pos..pos + range.len(), ident, kleene }
},
)
}
fn parse_ident<'a, I: RangeStream<Token = char, Range = &'a str>>(
) -> impl Parser<I, Output = Ident<'a>> {
let ident = || (letter().or(token('_')), skip_many(alpha_num().or(token('_'))));
choice((
// explicit positional: `{N...}`
parse_u16().map(Ident::Positional),
// explicit identifier: `{foo...}`
recognize(ident()).map(|ident: &str| Ident::Named(ident.into())),
// field access: `{(foo.bar)...}`
between(
token('('),
token(')'),
recognize((skip_many(attempt((ident(), token('.')))), ident())),
)
.map(|ident: &str| Ident::Field(ident.into())),
// implicit: `{...}`
attempt(optional(parse_kleene())).map(|_| Ident::Implicit),
))
}
fn parse_kleene<I: Stream<Token = char>>() -> impl Parser<I, Output = Kleene> {
// if we decide to support more Kleene operators
// choice((
// token('?').map(|_| Kleene::Question),
// token('*').map(|_| Kleene::Star),
// token('+').map(|_| Kleene::Plus),
// ))
not_followed_by(choice((token('?'), token('*'))))
.message("unsupported Kleene operator")
.then(|_| token('+').map(|_| Kleene::Plus))
}
fn parse_u16<'a, I: RangeStream<Token = char, Range = &'a str>>() -> impl Parser<I, Output = u16> {
from_str(take_while1(|c: char| c.is_digit(10)))
}
fn escaped_string<I: RangeStream<Token = char>>(quote_char: char) -> impl Parser<I, Output = ()>
where
I::Range: combine::stream::Range,
{
(
escaped(take_while1(move |c| c != quote_char && c != '\\'), '\\', token(quote_char)),
token(quote_char),
)
.map(|_| ())
}
/// Give context for the error in `s` at `at`
fn error_context(s: &str, at: usize) -> &str {
// match the _last_ non-whitespace character before one or more spaces
let edge_trigger_whitespace = || {
let mut prev = ' ';
move |c: char| {
let ret = c.is_whitespace() && !prev.is_whitespace();
prev = c;
ret
}
};
// defaults to the beginning of the string
// `cmp::max(Option, Option)` returns the `Some` value if only one is `Some`,
// else it's the max of the two values, or `None` if both are `None`
let start = cmp::max(
{
s[..at]
.rmatch_indices(edge_trigger_whitespace())
.take(NUM_CONTEXT_WORDS)
.last()
.map(|(i, sp)| i + sp.len())
},
// OR the previous newline
s[..at].rfind('\n'),
)
.unwrap_or(0);
// defaults to the end of string
// `cmp::min(Option, Option)` returns `None` if either is `None` so we have to unwrap first
let end = cmp::min(
s[at..]
.match_indices(edge_trigger_whitespace())
.take(NUM_CONTEXT_WORDS)
.last()
.map_or(s.len(), |(i, _s)| at + i),
s[at..].find('\n').map_or(s.len(), |i| at + i),
);
// trim excess whitespace around the context
&s[start..end].trim()
}
#[test]
fn test_parse_query_string() -> Result<(), Box<dyn std::error::Error>> {
use Ident::*;
use Kleene::*;
assert_eq!(
parse_query_string("SELECT 1 FROM foo WHERE bar = {} AND baz = {baz}")?,
[
Placeholder { token: 30..32, ident: Implicit, kleene: None },
Placeholder { token: 43..48, ident: Named("baz".into()), kleene: None }
]
);
assert_eq!(
parse_query_string("SELECT 1 FROM foo WHERE bar IN {(foo.bar)+}")?,
[Placeholder { token: 31..43, ident: Field("foo.bar".into()), kleene: Some(Plus) }]
);
assert_eq!(
parse_query_string(
r#"SELECT 1 FROM foo WHERE quux = '{ "foo": "\'bar\'" }' and bar IN {0}"#
)?,
[Placeholder { token: 65..68, ident: Positional(0), kleene: None }]
);
Ok(())
}
#[test]
fn test_expand_parsed_query() -> Result<()> {
use ArgumentKind::*;
use ParamIndexing::*;
macro_rules! args {
($($ident:expr => $val:expr),*$(,)?) => {
|arg: &ArgumentIndex<'_>, _p: &Placeholder<'_>| -> Result<ArgumentKind, String> {
$(
if *arg == $ident {
return Ok($val);
}
)*
Err(format!("unknown bind arg identifier {:?}", arg))
}
}
}
// Postgres
assert_eq!(
parse_query("SELECT 1 FROM foo WHERE bar = {} AND baz = {baz}")?.expand_inner(
'$',
OneIndexed,
args! {
0usize => Scalar,
"baz" => Scalar
}
)?,
"SELECT 1 FROM foo WHERE bar = $1 AND baz = $2"
);
assert_eq!(
parse_query(
r#"
SELECT 1
FROM foo
WHERE bar IN ({(foo.bar)+})
AND baz IN ({baz+})
AND quux IN ({quux+})"#
)?
.expand_inner(
'$',
OneIndexed,
args! {
"foo.bar" => Vector(3),
"baz" => Vector(0),
"quux" => Vector(1)
}
)?,
r#"
SELECT 1
FROM foo
WHERE bar IN ($1, $2, $3)
AND baz IN (NULL)
AND quux IN ($4)"#
);
assert_eq!(
parse_query(r#"SELECT 1 FROM foo WHERE quux = '{ "foo": "\'bar\'" }' and bar IN {0}"#)?
.expand_inner('$', OneIndexed, args! { 0usize => Scalar })?,
r#"SELECT 1 FROM foo WHERE quux = '{ "foo": "\'bar\'" }' and bar IN $1"#
);
// MySQL
assert_eq!(
parse_query("SELECT 1 FROM foo WHERE bar = {} AND baz = {baz}")?.expand_inner(
'?',
Implicit,
args! {
0usize => Scalar,
"baz" => Scalar,
}
)?,
"SELECT 1 FROM foo WHERE bar = ? AND baz = ?"
);
assert_eq!(
parse_query(
r#"
SELECT 1
FROM foo
WHERE bar IN ({(foo.bar)+})
AND baz IN ({baz+})
AND quux IN ({quux+})"#
)?
.expand_inner(
'?',
Implicit,
args! {
"foo.bar" => Vector(3),
"baz" => Vector(0),
"quux" => Vector(1)
}
)?,
r#"
SELECT 1
FROM foo
WHERE bar IN (?, ?, ?)
AND baz IN (NULL)
AND quux IN (?)"#
);
assert_eq!(
parse_query(r#"SELECT 1 FROM foo WHERE quux = '{ "foo": "\'bar\'" }' and bar IN {0}"#)?
.expand_inner('?', Implicit, args! { 0usize => Scalar })?,
r#"SELECT 1 FROM foo WHERE quux = '{ "foo": "\'bar\'" }' and bar IN ?"#
);
Ok(())
}

View file

@ -1,4 +1,5 @@
use sqlx_core::database::{HasOutput, HasRawValue};
use sqlx_core::placeholders;
use sqlx_core::Database;
use super::{
@ -18,6 +19,8 @@ impl Database for MySql {
type TypeInfo = MySqlTypeInfo;
type TypeId = MySqlTypeId;
const PLACEHOLDER_CHAR: char = '?';
const PARAM_INDEXING: placeholders::ParamIndexing = placeholders::ParamIndexing::Implicit;
}
impl<'x> HasOutput<'x> for MySql {

View file

@ -42,9 +42,17 @@ bitflags = "1.2"
base64 = "0.13.0"
md-5 = "0.9.1"
itoa = "0.4.7"
paste = "1.0.5"
[dev-dependencies]
sqlx-core = { version = "0.6.0-pre", path = "../sqlx-core", features = ["_mock"] }
sqlx-test = { path = "../sqlx-test" }
futures-executor = "0.3.8"
anyhow = "1.0.37"
conquer-once = "0.3.2"
tokio = { version = "1.0", features = ["full"] }
[[test]]
name = "postgres-connection"
path = "tests/connection.rs"
required-features = ["async", "sqlx-core/tokio"]

View file

@ -1,32 +1,39 @@
use sqlx_core::{Result, Runtime};
use sqlx_core::{placeholders, Result, Runtime};
use crate::protocol::backend::{
BackendMessage, BackendMessageType, ParameterDescription, RowDescription,
};
use crate::protocol::frontend::{Describe, Parse, StatementRef, Sync, Target};
use crate::protocol::frontend::{Describe, Parse, StatementId, Sync, Target};
use crate::raw_statement::RawStatement;
use crate::{PgArguments, PgClientError, PgConnection};
use crate::{PgArguments, PgClientError, PgConnection, Postgres};
use sqlx_core::arguments::ArgumentIndex;
use sqlx_core::placeholders::{ArgumentKind, Placeholder};
impl<Rt: Runtime> PgConnection<Rt> {
fn start_raw_prepare(
&mut self,
sql: &str,
sql: &placeholders::ParsedQuery<'_>,
arguments: &PgArguments<'_>,
) -> Result<RawStatement> {
let statement_id = self.next_statement_id;
self.next_statement_id = self.next_statement_id.wrapping_add(1);
let mut has_expansion = false;
let sql =
sql.expand::<Postgres, _>(placeholder_get_argument(arguments, &mut has_expansion))?;
// if the query has a comma-expansion, we don't want to keep it as a named prepared statement
let statement_id = if !has_expansion {
let val = self.next_statement_id;
self.next_statement_id = self.next_statement_id.wrapping_add(1);
StatementId::Named(val)
} else {
StatementId::Unnamed
};
let statement = RawStatement::new(statement_id);
self.stream.write_message(&Parse {
statement: StatementRef::Named(statement.id),
sql,
arguments,
})?;
self.stream.write_message(&Parse { statement: statement.id, sql: &sql, arguments })?;
self.stream.write_message(&Describe {
target: Target::Statement(StatementRef::Named(statement.id)),
})?;
self.stream.write_message(&Describe { target: Target::Statement(statement.id) })?;
self.stream.write_message(&Sync)?;
@ -93,7 +100,7 @@ impl<Rt: Runtime> super::PgConnection<Rt> {
#[cfg(feature = "async")]
pub(crate) async fn raw_prepare_async(
&mut self,
sql: &str,
sql: &placeholders::ParsedQuery<'_>,
arguments: &PgArguments<'_>,
) -> Result<RawStatement>
where
@ -106,7 +113,7 @@ impl<Rt: Runtime> super::PgConnection<Rt> {
#[cfg(feature = "blocking")]
pub(crate) fn raw_prepare_blocking(
&mut self,
sql: &str,
sql: &placeholders::ParsedQuery<'_>,
arguments: &PgArguments<'_>,
) -> Result<RawStatement>
where
@ -126,3 +133,23 @@ macro_rules! raw_prepare {
$self.raw_prepare_async($sql, $arguments).await?
};
}
fn placeholder_get_argument<'b, 'a: 'b>(
arguments: &'b PgArguments<'_>,
has_expansion: &'b mut bool,
) -> impl FnMut(&ArgumentIndex<'_>, &Placeholder<'a>) -> Result<ArgumentKind, String> + 'b {
move |idx, place| {
// note: we don't need to print the argument cause it's included in the outer error
let arg = arguments.get(idx).ok_or("unknown argument")?;
Ok(if place.kleene.is_some() {
let len = arg.value().vector_len().ok_or("expected vector for argument")?;
*has_expansion = true;
ArgumentKind::Vector(len)
} else {
ArgumentKind::Scalar
})
}
}

View file

@ -1,8 +1,11 @@
use sqlx_core::{Execute, Result, Runtime};
use sqlx_core::{placeholders, Arguments, Execute, Result, Runtime};
use crate::protocol::frontend::{self, Bind, PortalRef, Query, StatementRef, Sync};
use crate::protocol::frontend::{self, Bind, PortalRef, Query, StatementId, Sync};
use crate::raw_statement::RawStatement;
use crate::{PgArguments, PgConnection, Postgres};
use sqlx_core::arguments::ArgumentIndex;
use sqlx_core::placeholders::{ArgumentKind, Placeholder};
use std::borrow::Cow;
impl<Rt: Runtime> PgConnection<Rt> {
fn write_raw_query_statement(
@ -13,7 +16,7 @@ impl<Rt: Runtime> PgConnection<Rt> {
// bind values to the prepared statement
self.stream.write_message(&Bind {
portal: PortalRef::Unnamed,
statement: StatementRef::Named(statement.id),
statement: statement.id,
arguments,
parameters: &statement.parameters,
})?;
@ -37,11 +40,17 @@ impl<Rt: Runtime> PgConnection<Rt> {
macro_rules! impl_raw_query {
($(@$blocking:ident)? $self:ident, $query:ident) => {{
let parsed = placeholders::parse_query($query.sql())?;
if let Some(arguments) = $query.arguments() {
let statement = raw_prepare!($(@$blocking)? $self, $query.sql(), arguments);
let statement = raw_prepare!($(@$blocking)? $self, &parsed, arguments);
$self.write_raw_query_statement(&statement, arguments)?;
} else {
if !parsed.placeholders().is_empty() {
return Err(placeholders::Error::PreparedStatementsOnly.into());
}
$self.stream.write_message(&Query { sql: $query.sql() })?;
};

View file

@ -1,4 +1,5 @@
use sqlx_core::database::{HasOutput, HasRawValue};
use sqlx_core::placeholders;
use sqlx_core::Database;
use super::{PgColumn, PgOutput, PgQueryResult, PgRawValue, PgRow, PgTypeId, PgTypeInfo};
@ -16,6 +17,9 @@ impl Database for Postgres {
type TypeInfo = PgTypeInfo;
type TypeId = PgTypeId;
const PLACEHOLDER_CHAR: char = '$';
const PARAM_INDEXING: placeholders::ParamIndexing = placeholders::ParamIndexing::OneIndexed;
}
// 'x: execution

View file

@ -23,7 +23,7 @@ pub(crate) use password::{Password, PasswordMd5};
pub(crate) use portal::PortalRef;
pub(crate) use query::Query;
pub(crate) use startup::Startup;
pub(crate) use statement::StatementRef;
pub(crate) use statement::StatementId;
pub(crate) use sync::Sync;
pub(crate) use target::Target;
pub(crate) use terminate::Terminate;

View file

@ -4,13 +4,13 @@ use sqlx_core::io::Serialize;
use sqlx_core::Result;
use crate::io::PgWriteExt;
use crate::protocol::frontend::{PortalRef, StatementRef};
use crate::protocol::frontend::{PortalRef, StatementId};
use crate::{PgArguments, PgOutput, PgRawValueFormat, PgTypeInfo};
use sqlx_core::encode::IsNull;
pub(crate) struct Bind<'a> {
pub(crate) portal: PortalRef,
pub(crate) statement: StatementRef,
pub(crate) statement: StatementId,
pub(crate) parameters: &'a [PgTypeInfo],
pub(crate) arguments: &'a PgArguments<'a>,
}

View file

@ -4,11 +4,11 @@ use sqlx_core::io::{Serialize, WriteExt};
use sqlx_core::Result;
use crate::io::PgWriteExt;
use crate::protocol::frontend::StatementRef;
use crate::protocol::frontend::StatementId;
use crate::{PgArguments, PgTypeId};
pub(crate) struct Parse<'a> {
pub(crate) statement: StatementRef,
pub(crate) statement: StatementId,
pub(crate) sql: &'a str,
pub(crate) arguments: &'a PgArguments<'a>,
}

View file

@ -2,14 +2,14 @@ use sqlx_core::io::Serialize;
use sqlx_core::Result;
#[derive(Debug, Copy, Clone)]
pub(crate) enum StatementRef {
pub(crate) enum StatementId {
Unnamed,
Named(u32),
}
impl Serialize<'_> for StatementRef {
impl Serialize<'_> for StatementId {
fn serialize_with(&self, buf: &mut Vec<u8>, _: ()) -> Result<()> {
if let StatementRef::Named(id) = self {
if let StatementId::Named(id) = self {
buf.extend_from_slice(b"_sqlx_s_");
itoa::write(&mut *buf, *id).unwrap();

View file

@ -2,14 +2,14 @@ use sqlx_core::io::Serialize;
use sqlx_core::Result;
use crate::io::PgWriteExt;
use crate::protocol::frontend::{PortalRef, StatementRef};
use crate::protocol::frontend::{PortalRef, StatementId};
/// Target a command at a portal *or* statement.
/// Used by [`Describe`] and [`Close`].
#[derive(Debug)]
pub(crate) enum Target {
Portal(PortalRef),
Statement(StatementRef),
Statement(StatementId),
}
impl Serialize<'_> for Target {

View file

@ -1,14 +1,15 @@
use crate::protocol::frontend::StatementId;
use crate::{PgColumn, PgTypeInfo};
#[derive(Debug, Clone)]
pub(crate) struct RawStatement {
pub(crate) id: u32,
pub(crate) id: StatementId,
pub(crate) columns: Vec<PgColumn>,
pub(crate) parameters: Vec<PgTypeInfo>,
}
impl RawStatement {
pub(crate) fn new(id: u32) -> Self {
pub(crate) fn new(id: StatementId) -> Self {
Self { id, columns: Vec::new(), parameters: Vec::new() }
}
}

View file

@ -10,10 +10,72 @@ pub enum PgTypeId {
Name(&'static str),
}
/// Macro to reduce boilerplate for defining constants for `PgTypeId`. See usage below for examples.
macro_rules! type_id {
($(
$(#[$meta:meta])*
$name:ident = $kind:ident ($val:literal) $(, [] = $array_kind:ident ($array_val:literal))?
);* $(;)?) => {
impl PgTypeId {
$(
$(#[$meta])*
pub const $name: Self = Self::$kind($val);
$(
paste::paste! {
#[doc = "An array of [`" $name "`][Self::" $name "]."]
///
/// Maps to either a slice or a vector of the equivalent Rust type.
pub const [<$name _ARRAY>]: Self = Self::$array_kind($array_val);
}
)?
)*
}
impl PgTypeId {
/// Get the name of this type as a string.
#[must_use]
pub (crate) const fn name(self) -> &'static str {
match self {
$(
Self::$name => stringify!($name),
$(
// just appends `[]` to the type name
Self::$array_kind($array_val) => concat!(stringify!($name), "[]"),
)?
)*
Self::Name(name) => name,
_ => "UNKNOWN"
}
}
/// Get the ID of the inner type if the current type is an array.
#[allow(dead_code)]
pub (crate) const fn elem_type(self) -> Option<Self> {
match self {
// only generates an arm if `$array_kind` and `$array_val` are provided
$($(Self::$array_kind($array_val) => Some(Self::$kind($val)),)?)*
_ => None,
}
}
/// Get the type ID for an array of this type, if we know it.
#[allow(dead_code)]
pub (crate) const fn array_type(self) -> Option<Self> {
match self {
// only generates an arm if `$array_kind` and `$array_val` are provided
$($(Self::$name => Some(Self::$array_kind($array_val)),)?)*
_ => None,
}
}
}
};
}
// Data Types
// https://www.postgresql.org/docs/current/datatype.html
impl PgTypeId {
// for OIDs see: https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_type.dat
type_id! {
// Boolean
// https://www.postgresql.org/docs/current/datatype-boolean.html
@ -21,7 +83,7 @@ impl PgTypeId {
///
/// Maps to `bool`.
///
pub const BOOLEAN: Self = Self::Oid(16);
BOOLEAN = Oid(16), [] = Oid(1000); // also defines `BOOLEAN_ARRAY` for the `BOOLEAN[]` type
// Integers
// https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-INT
@ -34,7 +96,7 @@ impl PgTypeId {
///
#[doc(alias = "INT2")]
#[doc(alias = "SMALLSERIAL")]
pub const SMALLINT: Self = Self::Oid(21);
SMALLINT = Oid(21), [] = Oid(1005);
/// A 4-byte integer.
///
@ -44,17 +106,17 @@ impl PgTypeId {
///
#[doc(alias = "INT4")]
#[doc(alias = "SERIAL")]
pub const INTEGER: Self = Self::Oid(23);
INTEGER = Oid(23), [] = Oid(1007);
/// An 8-byte integer.
///
/// Compatible with any primitive integer type.
///
/// Maps to `i64`.
/// Maps to `i64`
///
#[doc(alias = "INT8")]
#[doc(alias = "BIGSERIAL")]
pub const BIGINT: Self = Self::Oid(20);
BIGINT = Oid(20), [] = Oid(1016);
// Arbitrary Precision Numbers
// https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-NUMERIC-DECIMAL
@ -70,7 +132,7 @@ impl PgTypeId {
/// enabled crate features).
///
#[doc(alias = "DECIMAL")]
pub const NUMERIC: Self = Self::Oid(1700);
NUMERIC = Oid(1700), [] = Oid(1231);
// Floating-Point
// https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-FLOAT
@ -82,7 +144,7 @@ impl PgTypeId {
/// Maps to `f32`.
///
#[doc(alias = "FLOAT4")]
pub const REAL: Self = Self::Oid(700);
REAL = Oid(700), [] = Oid(1021);
/// An 8-byte floating-point numeric type.
///
@ -91,34 +153,24 @@ impl PgTypeId {
/// Maps to `f64`.
///
#[doc(alias = "FLOAT8")]
pub const DOUBLE: Self = Self::Oid(701);
DOUBLE = Oid(701), [] = Oid(1022);
/// The `UNKNOWN` Postgres type. Returned for expressions that do not
/// have a type (e.g., `SELECT $1` with no parameter type hint
/// or `SELECT NULL`).
pub const UNKNOWN: Self = Self::Oid(705);
UNKNOWN = Oid(705);
}
impl PgTypeId {
#[must_use]
pub(crate) const fn name(self) -> &'static str {
match self {
Self::BOOLEAN => "BOOLEAN",
Self::SMALLINT => "SMALLINT",
Self::INTEGER => "INTEGER",
Self::BIGINT => "BIGINT",
Self::NUMERIC => "NUMERIC",
Self::REAL => "REAL",
Self::DOUBLE => "DOUBLE",
_ => "UNKNOWN",
pub(crate) const fn oid(self) -> Option<u32> {
if let Self::Oid(oid) = self {
Some(oid)
} else {
None
}
}
pub(crate) const fn is_integer(&self) -> bool {
matches!(*self, Self::SMALLINT | Self::INTEGER | Self::BIGINT)
pub(crate) const fn is_integer(self) -> bool {
matches!(self, Self::SMALLINT | Self::INTEGER | Self::BIGINT)
}
}

View file

@ -1,3 +1,4 @@
mod array;
mod bool;
mod int;
mod null;

View file

@ -0,0 +1,160 @@
use sqlx_core::{encode, Arguments, Database, Encode, Type, TypeEncode};
use crate::{PgOutput, PgTypeId, PgTypeInfo, Postgres};
use sqlx_core::database::HasOutput;
/// Marker trait for types which support being wrapped in array in Postgres.
pub trait PgHasArray {
/// The type ID in Postgres of the array type which has this type as an element.
const ARRAY_TYPE_ID: PgTypeId;
}
impl<T: PgHasArray> Type<Postgres> for &'_ [T] {
fn type_id() -> <Postgres as Database>::TypeId
where
Self: Sized,
{
T::ARRAY_TYPE_ID
}
// TODO: check `PgTypeInfo` for array element type and check compatibility of that
// fn compatible(ty: &<Postgres as Database>::TypeInfo) -> bool
// where
// Self: Sized,
// {
// }
}
impl<T: Type<Postgres> + Encode<Postgres>> Encode<Postgres> for &'_ [T] {
fn encode(
&self,
ty: &<Postgres as Database>::TypeInfo,
out: &mut <Postgres as HasOutput<'_>>::Output,
) -> encode::Result {
encode_array(*self, ty, out)
}
fn vector_len(&self) -> Option<usize> {
Some(self.len())
}
fn expand_vector<'a>(&'a self, arguments: &mut Arguments<'a, Postgres>) {
for elem in *self {
arguments.add(elem);
}
}
}
// Vector
impl<T: PgHasArray> Type<Postgres> for Vec<T> {
fn type_id() -> <Postgres as Database>::TypeId
where
Self: Sized,
{
<&[T]>::type_id()
}
fn compatible(ty: &<Postgres as Database>::TypeInfo) -> bool
where
Self: Sized,
{
<&[T]>::compatible(ty)
}
}
impl<T: Type<Postgres> + Encode<Postgres>> Encode<Postgres> for Vec<T> {
fn encode(
&self,
ty: &<Postgres as Database>::TypeInfo,
out: &mut <Postgres as HasOutput<'_>>::Output,
) -> encode::Result {
encode_array(self.iter(), ty, out)
}
fn vector_len(&self) -> Option<usize> {
Some(self.len())
}
fn expand_vector<'a>(&'a self, arguments: &mut Arguments<'a, Postgres>) {
for elem in self {
arguments.add(elem);
}
}
}
// static-size arrays
impl<T: PgHasArray, const N: usize> Type<Postgres> for [T; N] {
fn type_id() -> <Postgres as Database>::TypeId
where
Self: Sized,
{
<&[T]>::type_id()
}
fn compatible(ty: &<Postgres as Database>::TypeInfo) -> bool
where
Self: Sized,
{
<&[T]>::compatible(ty)
}
}
impl<T: Type<Postgres> + Encode<Postgres>, const N: usize> Encode<Postgres> for [T; N] {
fn encode(
&self,
ty: &<Postgres as Database>::TypeInfo,
out: &mut <Postgres as HasOutput<'_>>::Output,
) -> encode::Result {
encode_array(self.iter(), ty, out)
}
fn vector_len(&self) -> Option<usize> {
Some(self.len())
}
fn expand_vector<'a>(&'a self, arguments: &mut Arguments<'a, Postgres>) {
for elem in self {
arguments.add(elem);
}
}
}
pub fn encode_array<T: TypeEncode<Postgres>, I: IntoIterator<Item = T>>(
array: I,
_ty: &PgTypeInfo,
out: &mut PgOutput<'_>,
) -> encode::Result {
// number of dimensions (1 for now)
out.buffer().extend_from_slice(&1i32.to_be_bytes());
let len_start = out.buffer().len();
// whether or not the array is null (fixup afterward)
out.buffer().extend_from_slice(&[0; 4]);
// FIXME: better error message/avoid the error
let elem_type = T::type_id().oid().ok_or_else(|| {
encode::Error::msg("can only bind an array with elements with a known oid")
})?;
out.buffer().extend_from_slice(&elem_type.to_be_bytes());
let mut count: i32 = 0;
let is_null = array
.into_iter()
.map(|elem| {
count = count
.checked_add(1)
.ok_or_else(|| encode::Error::msg("array length overflows i32"))?;
elem.encode(&PgTypeInfo(T::type_id()), out)
})
.collect::<encode::Result>()?;
// fixup the length
out.buffer()[len_start..][..4].copy_from_slice(&count.to_be_bytes());
Ok(is_null)
}

View file

@ -86,7 +86,7 @@ where
}
macro_rules! impl_type_int {
($ty:ty $(: $real:ty)? => $sql:ident) => {
($ty:ty $(: $real:ty)? => $sql:ident $(, [] => $array_sql:ident)?) => {
impl Type<Postgres> for $ty {
fn type_id() -> PgTypeId {
PgTypeId::$sql
@ -97,6 +97,12 @@ macro_rules! impl_type_int {
}
}
$(
impl super::array::PgHasArray for $ty {
const ARRAY_TYPE_ID: PgTypeId = PgTypeId::$array_sql;
}
)?
impl Encode<Postgres> for $ty {
fn encode(&self, ty: &PgTypeInfo, out: &mut PgOutput<'_>) -> encode::Result {
ensure_not_too_large_or_too_small((*self $(as $real)?).into(), ty)?;
@ -115,9 +121,9 @@ macro_rules! impl_type_int {
};
}
impl_type_int! { i8 => SMALLINT }
impl_type_int! { i16 => SMALLINT }
impl_type_int! { i32 => INTEGER }
impl_type_int! { i8 => SMALLINT, [] => SMALLINT_ARRAY }
impl_type_int! { i16 => SMALLINT, [] => SMALLINT_ARRAY }
impl_type_int! { i32 => INTEGER, [] => INTEGER_ARRAY }
impl_type_int! { i64 => BIGINT }
impl_type_int! { i128 => BIGINT }

View file

@ -0,0 +1,55 @@
use sqlx_core::{Connect, Connection, Executor, Tokio};
use sqlx_postgres::PgArguments;
use sqlx_postgres::PgConnection;
use sqlx_test::assert_cancellation_safe;
use std::env;
#[tokio::test]
async fn test_connect() -> anyhow::Result<()> {
let url = env::var("DATABASE_URL")?;
let mut conn = PgConnection::<Tokio>::connect(&url).await?;
conn.ping().await?;
Ok(())
}
#[tokio::test]
async fn test_select_1() -> anyhow::Result<()> {
let url = env::var("DATABASE_URL")?;
let mut conn = PgConnection::<Tokio>::connect(&url).await?;
let row = conn.fetch_one("SELECT 1").await?;
let col0: i32 = row.try_get(0)?;
assert_eq!(col0, 1);
Ok(())
}
#[tokio::test]
async fn test_generic_placeholders() -> anyhow::Result<()> {
let url = env::var("DATABASE_URL")?;
let mut conn = PgConnection::<Tokio>::connect(&url).await?;
let mut args = PgArguments::new();
args.add(&1i32);
let row = conn.fetch_one(("SELECT {}", args)).await?;
let col0: i32 = row.try_get(0)?;
let mut args = PgArguments::new();
args.add(&[1i32, 2, 3, 4, 5, 6]);
let row = conn
.fetch_one((
"SELECT val FROM generate_series(0, 9, 3) AS vals(val) WHERE val IN ({+})",
args,
))
.await?;
let col0: i32 = row.try_get(0)?;
assert_eq!(col0, 3);
Ok(())
}