fix: clippy warnings

Austin Bonander 2024-06-05 20:15:44 -07:00
parent 41089f3655
commit bae083cf79
110 changed files with 513 additions and 480 deletions

View file

@@ -34,6 +34,7 @@ jobs:
       rustup update
       rustup component add clippy
       rustup toolchain install beta
+      rustup component add --toolchain beta clippy
   - run: >
       cargo clippy

View file

@ -1,5 +1,13 @@
disallowed-methods = [ [[disallowed-methods]]
# It is *much* too easy to misread `x.min(y)` as "x should be *at least* y" when in fact it path = "core::cmp::Ord::min"
# means the *exact* opposite, and same with `x.max(y)`; use `cmp::{min, max}` instead. reason = '''
"core::cmp::Ord::min", "core::cmp::Ord::max" too easy to misread `x.min(y)` as "let the minimum value of `x` be `y`" when it actually means the exact opposite;
] use `std::cmp::min` instead.
'''
[[disallowed-methods]]
path = "core::cmp::Ord::max"
reason = '''
too easy to misread `x.max(y)` as "let the maximum value of `x` be `y`" when it actually means the exact opposite;
use `std::cmp::max` instead.
'''
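
A quick illustration of the misreading the new lint text warns about (a standalone sketch; the values are made up, not from this commit):

    use std::cmp;

    fn main() {
        let requested: u32 = 3;
        let floor: u32 = 10;

        // Easy to misread as "requested, but at least `floor`" -- `min`
        // actually returns the *smaller* value, so this caps at 3.
        assert_eq!(requested.min(floor), 3);

        // The free functions keep the direction visible at the call site.
        assert_eq!(cmp::min(requested, floor), 3);
        assert_eq!(cmp::max(requested, floor), 10); // "at least `floor`"
    }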

View file

@@ -129,16 +129,15 @@ where
             .build(),
         || {
             connect(db_url).map_err(|e| -> backoff::Error<anyhow::Error> {
-                match e {
-                    sqlx::Error::Io(ref ioe) => match ioe.kind() {
+                if let sqlx::Error::Io(ref ioe) = e {
+                    match ioe.kind() {
                         io::ErrorKind::ConnectionRefused
                         | io::ErrorKind::ConnectionReset
                         | io::ErrorKind::ConnectionAborted => {
                             return backoff::Error::transient(e.into());
                         }
                         _ => (),
-                    },
-                    _ => (),
+                    }
                 }
                 backoff::Error::permanent(e.into())

View file

@@ -83,7 +83,7 @@ impl Metadata {
         self.packages.get(id)
     }

-    pub fn entries<'this>(&'this self) -> btree_map::Iter<'this, MetadataId, Package> {
+    pub fn entries(&self) -> btree_map::Iter<'_, MetadataId, Package> {
         self.packages.iter()
     }

View file

@@ -20,7 +20,7 @@ fn create_file(
     use std::path::PathBuf;

     let mut file_name = file_prefix.to_string();
-    file_name.push_str("_");
+    file_name.push('_');
     file_name.push_str(&description.replace(' ', "_"));
     file_name.push_str(migration_type.suffix());
@@ -120,20 +120,20 @@ pub async fn add(
     if migration_type.is_reversible() {
         create_file(
             migration_source,
-            &file_prefix,
+            file_prefix,
             description,
             MigrationType::ReversibleUp,
         )?;
         create_file(
             migration_source,
-            &file_prefix,
+            file_prefix,
             description,
             MigrationType::ReversibleDown,
         )?;
     } else {
         create_file(
             migration_source,
-            &file_prefix,
+            file_prefix,
             description,
             MigrationType::Simple,
         )?;
@@ -194,7 +194,7 @@ fn short_checksum(checksum: &[u8]) -> String {
 pub async fn info(migration_source: &str, connect_opts: &ConnectOpts) -> anyhow::Result<()> {
     let migrator = Migrator::new(Path::new(migration_source)).await?;
-    let mut conn = crate::connect(&connect_opts).await?;
+    let mut conn = crate::connect(connect_opts).await?;

     conn.ensure_migrations_table().await?;
@@ -300,7 +300,7 @@ pub async fn run(
     let latest_version = applied_migrations
         .iter()
         .max_by(|x, y| x.version.cmp(&y.version))
-        .and_then(|migration| Some(migration.version))
+        .map(|migration| migration.version)
         .unwrap_or(0);
     if let Some(target_version) = target_version {
         if target_version < latest_version {
@@ -326,10 +326,8 @@ pub async fn run(
                 }
             }
             None => {
-                let skip = match target_version {
-                    Some(target_version) if migration.version > target_version => true,
-                    _ => false,
-                };
+                let skip =
+                    target_version.is_some_and(|target_version| migration.version > target_version);

                 let elapsed = if dry_run || skip {
                     Duration::new(0, 0)
@@ -380,7 +378,7 @@ pub async fn revert(
         }
     }

-    let mut conn = crate::connect(&connect_opts).await?;
+    let mut conn = crate::connect(connect_opts).await?;

     conn.ensure_migrations_table().await?;
@@ -395,7 +393,7 @@ pub async fn revert(
     let latest_version = applied_migrations
         .iter()
         .max_by(|x, y| x.version.cmp(&y.version))
-        .and_then(|migration| Some(migration.version))
+        .map(|migration| migration.version)
         .unwrap_or(0);
     if let Some(target_version) = target_version {
         if target_version > latest_version {
@@ -417,10 +415,9 @@ pub async fn revert(
         }

         if applied_migrations.contains_key(&migration.version) {
-            let skip = match target_version {
-                Some(target_version) if migration.version <= target_version => true,
-                _ => false,
-            };
+            let skip =
+                target_version.is_some_and(|target_version| migration.version <= target_version);

             let elapsed = if dry_run || skip {
                 Duration::new(0, 0)
             } else {
@@ -447,7 +444,7 @@ pub async fn revert(
         // Only a single migration will be reverted at a time if no target
         // version is supplied, so we break.
-        if let None = target_version {
+        if target_version.is_none() {
             break;
         }
     }
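
For reference, `Option::is_some_and` (stable since Rust 1.70) is a drop-in replacement for the two-arm `match` removed above; a minimal sketch with made-up versions:

    fn main() {
        let target_version: Option<i64> = Some(5);
        let migration_version: i64 = 7;

        // Old shape:
        let skip_old = match target_version {
            Some(target) if migration_version > target => true,
            _ => false,
        };

        // New shape:
        let skip_new = target_version.is_some_and(|target| migration_version > target);

        assert_eq!(skip_old, skip_new);
    }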

View file

@@ -11,13 +11,13 @@ use std::future;
 impl<'c> Executor<'c> for &'c mut AnyConnection {
     type Database = Any;

-    fn fetch_many<'e, 'q: 'e, E: 'q>(
+    fn fetch_many<'e, 'q: 'e, E>(
         self,
         mut query: E,
     ) -> BoxStream<'e, Result<Either<AnyQueryResult, AnyRow>, Error>>
     where
         'c: 'e,
-        E: Execute<'q, Any>,
+        E: 'q + Execute<'q, Any>,
     {
         let arguments = match query.take_arguments().map_err(Error::Encode) {
             Ok(arguments) => arguments,
@@ -26,13 +26,13 @@ impl<'c> Executor<'c> for &'c mut AnyConnection {
         self.backend.fetch_many(query.sql(), arguments)
     }

-    fn fetch_optional<'e, 'q: 'e, E: 'q>(
+    fn fetch_optional<'e, 'q: 'e, E>(
         self,
         mut query: E,
     ) -> BoxFuture<'e, Result<Option<AnyRow>, Error>>
     where
         'c: 'e,
-        E: Execute<'q, Self::Database>,
+        E: 'q + Execute<'q, Self::Database>,
     {
         let arguments = match query.take_arguments().map_err(Error::Encode) {
             Ok(arguments) => arguments,

View file

@@ -38,7 +38,7 @@ impl<'q> Statement<'q> for AnyStatement<'q> {
     fn parameters(&self) -> Option<Either<&[AnyTypeInfo], usize>> {
         match &self.parameters {
-            Some(Either::Left(types)) => Some(Either::Left(&types)),
+            Some(Either::Left(types)) => Some(Either::Left(types)),
             Some(Either::Right(count)) => Some(Either::Right(*count)),
             None => None,
         }
@@ -57,7 +57,7 @@ impl<'i> ColumnIndex<AnyStatement<'_>> for &'i str {
             .column_names
             .get(*self)
             .ok_or_else(|| Error::ColumnNotFound((*self).into()))
-            .map(|v| *v)
+            .copied()
     }
 }

View file

@@ -7,6 +7,8 @@ use crate::types::Type;
 use std::fmt::{self, Write};

 /// A tuple of arguments to be sent to the database.
+// This lint is designed for general collections, but `Arguments` is not meant to be as such.
+#[allow(clippy::len_without_is_empty)]
 pub trait Arguments<'q>: Send + Sized + Default {
     type Database: Database;

View file

@@ -43,6 +43,10 @@ impl<T> StatementCache<T> {
         self.inner.len()
     }

+    pub fn is_empty(&self) -> bool {
+        self.inner.is_empty()
+    }
+
     /// Removes the least recently used item from the cache.
     pub fn remove_lru(&mut self) -> Option<T> {
         self.inner.remove_lru().map(|(_, v)| v)
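
This is the counterpart to the `#[allow(clippy::len_without_is_empty)]` in the previous file: clippy expects a public `len()` to come with an `is_empty()`, and here the cache simply gains the missing method. A generic sketch of the rule (made-up type, not sqlx code):

    struct Cache<T> {
        items: Vec<T>,
    }

    impl<T> Cache<T> {
        // clippy::len_without_is_empty: a public `len` should be
        // accompanied by `is_empty`.
        pub fn len(&self) -> usize {
            self.items.len()
        }

        pub fn is_empty(&self) -> bool {
            self.items.is_empty()
        }
    }

    fn main() {
        let cache: Cache<u8> = Cache { items: Vec::new() };
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }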

View file

@@ -143,7 +143,7 @@ pub trait Connection: Send {
     {
         let options = url.parse();

-        Box::pin(async move { Ok(Self::connect_with(&options?).await?) })
+        Box::pin(async move { Self::connect_with(&options?).await })
     }

     /// Establish a new database connection with the provided options.
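
The change above is clippy's `needless_question_mark`: `Ok(expr?)` unwraps a `Result` only to immediately rewrap it. A standalone sketch of the pattern (`parse_port` is a made-up example, not sqlx API):

    fn parse_port(s: &str) -> Result<u16, std::num::ParseIntError> {
        // Before: Ok(s.trim().parse::<u16>()?) -- the `?` and `Ok` cancel out.
        s.trim().parse::<u16>()
    }

    fn main() {
        assert_eq!(parse_port(" 3306 "), Ok(3306));
    }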

View file

@@ -34,25 +34,25 @@ pub trait Executor<'c>: Send + Debug + Sized {
     type Database: Database;

     /// Execute the query and return the total number of rows affected.
-    fn execute<'e, 'q: 'e, E: 'q>(
+    fn execute<'e, 'q: 'e, E>(
         self,
         query: E,
     ) -> BoxFuture<'e, Result<<Self::Database as Database>::QueryResult, Error>>
     where
         'c: 'e,
-        E: Execute<'q, Self::Database>,
+        E: 'q + Execute<'q, Self::Database>,
     {
         self.execute_many(query).try_collect().boxed()
     }

     /// Execute multiple queries and return the rows affected from each query, in a stream.
-    fn execute_many<'e, 'q: 'e, E: 'q>(
+    fn execute_many<'e, 'q: 'e, E>(
         self,
         query: E,
     ) -> BoxStream<'e, Result<<Self::Database as Database>::QueryResult, Error>>
     where
         'c: 'e,
-        E: Execute<'q, Self::Database>,
+        E: 'q + Execute<'q, Self::Database>,
     {
         self.fetch_many(query)
             .try_filter_map(|step| async move {
@@ -65,13 +65,13 @@ pub trait Executor<'c>: Send + Debug + Sized {
     }

     /// Execute the query and return the generated results as a stream.
-    fn fetch<'e, 'q: 'e, E: 'q>(
+    fn fetch<'e, 'q: 'e, E>(
         self,
         query: E,
     ) -> BoxStream<'e, Result<<Self::Database as Database>::Row, Error>>
     where
         'c: 'e,
-        E: Execute<'q, Self::Database>,
+        E: 'q + Execute<'q, Self::Database>,
     {
         self.fetch_many(query)
             .try_filter_map(|step| async move {
@@ -85,7 +85,7 @@ pub trait Executor<'c>: Send + Debug + Sized {
     /// Execute multiple queries and return the generated results as a stream
     /// from each query, in a stream.
-    fn fetch_many<'e, 'q: 'e, E: 'q>(
+    fn fetch_many<'e, 'q: 'e, E>(
         self,
         query: E,
     ) -> BoxStream<
@@ -97,28 +97,28 @@ pub trait Executor<'c>: Send + Debug + Sized {
     >
     where
         'c: 'e,
-        E: Execute<'q, Self::Database>;
+        E: 'q + Execute<'q, Self::Database>;

     /// Execute the query and return all the generated results, collected into a [`Vec`].
-    fn fetch_all<'e, 'q: 'e, E: 'q>(
+    fn fetch_all<'e, 'q: 'e, E>(
         self,
         query: E,
     ) -> BoxFuture<'e, Result<Vec<<Self::Database as Database>::Row>, Error>>
     where
         'c: 'e,
-        E: Execute<'q, Self::Database>,
+        E: 'q + Execute<'q, Self::Database>,
     {
         self.fetch(query).try_collect().boxed()
     }

     /// Execute the query and returns exactly one row.
-    fn fetch_one<'e, 'q: 'e, E: 'q>(
+    fn fetch_one<'e, 'q: 'e, E>(
         self,
         query: E,
     ) -> BoxFuture<'e, Result<<Self::Database as Database>::Row, Error>>
     where
         'c: 'e,
-        E: Execute<'q, Self::Database>,
+        E: 'q + Execute<'q, Self::Database>,
     {
         self.fetch_optional(query)
             .and_then(|row| match row {
@@ -129,13 +129,13 @@ pub trait Executor<'c>: Send + Debug + Sized {
     }

     /// Execute the query and returns at most one row.
-    fn fetch_optional<'e, 'q: 'e, E: 'q>(
+    fn fetch_optional<'e, 'q: 'e, E>(
         self,
         query: E,
     ) -> BoxFuture<'e, Result<Option<<Self::Database as Database>::Row>, Error>>
     where
         'c: 'e,
-        E: Execute<'q, Self::Database>;
+        E: 'q + Execute<'q, Self::Database>;

     /// Prepare the SQL query to inspect the type information of its parameters
     /// and results.
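
The signature churn throughout this trait consolidates the bounds on `E`: the old code bounded `E` both in the generic parameter list (`E: 'q`) and in the `where` clause, which newer clippy flags (presumably `multiple_bound_locations`). The two spellings are equivalent; a minimal sketch with made-up names:

    trait Execute<'q> {}

    // Bounds on `E` split across two locations (what clippy flags):
    fn run_split<'e, 'q: 'e, E: 'q>(query: E)
    where
        E: Execute<'q>,
    {
        let _ = query;
    }

    // Same bounds, gathered in one place:
    fn run_gathered<'e, 'q: 'e, E>(query: E)
    where
        E: 'q + Execute<'q>,
    {
        let _ = query;
    }

    struct Query;
    impl<'q> Execute<'q> for Query {}

    fn main() {
        run_split(Query);
        run_gathered(Query);
    }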

View file

@@ -121,10 +121,10 @@ impl<'a, T> Stream for TryAsyncStream<'a, T> {
 #[macro_export]
 macro_rules! try_stream {
     ($($block:tt)*) => {
-        crate::ext::async_stream::TryAsyncStream::new(move |yielder| async move {
+        $crate::ext::async_stream::TryAsyncStream::new(move |yielder| async move {
             // Anti-footgun: effectively pins `yielder` to this future to prevent any accidental
             // move to another task, which could deadlock.
-            let ref yielder = yielder;
+            let yielder = &yielder;

             macro_rules! r#yield {
                 ($v:expr) => {{
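
`$crate` matters once a macro is `#[macro_export]`ed: a bare `crate::` path resolves in whichever crate *expands* the macro, while `$crate::` always resolves in the crate that *defines* it. A toy illustration (names are hypothetical):

    pub mod ext {
        pub fn helper() -> &'static str {
            "defined here"
        }
    }

    #[macro_export]
    macro_rules! call_helper {
        () => {
            // `crate::ext::helper()` would break for downstream callers;
            // `$crate::` pins the path to the defining crate.
            $crate::ext::helper()
        };
    }

    fn main() {
        assert_eq!(call_helper!(), "defined here");
    }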

View file

@@ -36,14 +36,14 @@ impl Hash for UStr {
     fn hash<H: Hasher>(&self, state: &mut H) {
         // Forward the hash to the string representation of this
         // A derive(Hash) encodes the enum discriminant
-        (&**self).hash(state);
+        (**self).hash(state);
     }
 }

 impl Borrow<str> for UStr {
     #[inline]
     fn borrow(&self) -> &str {
-        &**self
+        self
     }
 }
@@ -102,6 +102,6 @@ impl serde::Serialize for UStr {
     where
         S: serde::Serializer,
     {
-        serializer.serialize_str(&self)
+        serializer.serialize_str(self)
     }
 }

View file

@@ -22,7 +22,7 @@ pub trait BufExt: Buf {
 impl BufExt for Bytes {
     fn get_bytes_nul(&mut self) -> Result<Bytes, Error> {
         let nul =
-            memchr(b'\0', &self).ok_or_else(|| err_protocol!("expected NUL in byte sequence"))?;
+            memchr(b'\0', self).ok_or_else(|| err_protocol!("expected NUL in byte sequence"))?;

         let v = self.slice(0..nul);
@@ -40,7 +40,7 @@ impl BufExt for Bytes {
     fn get_str_nul(&mut self) -> Result<String, Error> {
         self.get_bytes_nul().and_then(|bytes| {
-            from_utf8(&*bytes)
+            from_utf8(&bytes)
                 .map(ToOwned::to_owned)
                 .map_err(|err| err_protocol!("{}", err))
         })

View file

@@ -16,7 +16,7 @@
 #![warn(future_incompatible, rust_2018_idioms)]
 #![allow(clippy::needless_doctest_main, clippy::type_complexity)]
 // See `clippy.toml` at the workspace root
-#![deny(clippy::disallowed_method)]
+#![deny(clippy::disallowed_methods)]

 // The only unsafe code in SQLx is that necessary to interact with native APIs like with SQLite,
 // and that can live in its own separate driver crate.
 #![forbid(unsafe_code)]

View file

@@ -106,14 +106,14 @@ impl<'q> QueryLogger<'q> {
         let log_is_enabled = log::log_enabled!(target: "sqlx::query", log_level)
             || private_tracing_dynamic_enabled!(target: "sqlx::query", tracing_level);
         if log_is_enabled {
-            let mut summary = parse_query_summary(&self.sql);
+            let mut summary = parse_query_summary(self.sql);

             let sql = if summary != self.sql {
                 summary.push_str(" …");
                 format!(
                     "\n\n{}\n",
                     sqlformat::format(
-                        &self.sql,
+                        self.sql,
                         &sqlformat::QueryParams::None,
                         sqlformat::FormatOptions::default()
                     )

View file

@@ -55,14 +55,14 @@ pub struct ResolveError {
 // FIXME: paths should just be part of `Migration` but we can't add a field backwards compatibly
 // since it's `#[non_exhaustive]`.
 pub fn resolve_blocking(path: &Path) -> Result<Vec<(Migration, PathBuf)>, ResolveError> {
-    let mut s = fs::read_dir(path).map_err(|e| ResolveError {
+    let s = fs::read_dir(path).map_err(|e| ResolveError {
         message: format!("error reading migration directory {}: {e}", path.display()),
         source: Some(e),
     })?;

     let mut migrations = Vec::new();

-    while let Some(res) = s.next() {
+    for res in s {
         let entry = res.map_err(|e| ResolveError {
             message: format!(
                 "error reading contents of migration directory {}: {e}",

View file

@@ -101,7 +101,7 @@ where
         let this = &mut *self;

         while !this.buf.is_empty() {
-            match this.socket.try_write(&mut this.buf) {
+            match this.socket.try_write(this.buf) {
                 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                     ready!(this.socket.poll_write_ready(cx))?;
                 }
@@ -225,14 +225,12 @@ pub async fn connect_tcp<Ws: WithSocket>(
     // If we reach this point, it means we failed to connect to any of the addresses.
     // Return the last error we encountered, or a custom error if the hostname didn't resolve to any address.
     match last_err {
-        Some(err) => return Err(err.into()),
-        None => {
-            return Err(io::Error::new(
-                io::ErrorKind::AddrNotAvailable,
-                "Hostname did not resolve to any addresses",
-            )
-            .into())
-        }
+        Some(err) => Err(err.into()),
+        None => Err(io::Error::new(
+            io::ErrorKind::AddrNotAvailable,
+            "Hostname did not resolve to any addresses",
+        )
+        .into()),
     }
 }
@@ -249,38 +247,41 @@ pub async fn connect_uds<P: AsRef<Path>, Ws: WithSocket>(
     path: P,
     with_socket: Ws,
 ) -> crate::Result<Ws::Output> {
+    #[cfg(unix)]
+    {
+        #[cfg(feature = "_rt-tokio")]
+        if crate::rt::rt_tokio::available() {
+            use tokio::net::UnixStream;
+
+            let stream = UnixStream::connect(path).await?;
+
+            return Ok(with_socket.with_socket(stream));
+        }
+
+        #[cfg(feature = "_rt-async-std")]
+        {
+            use async_io::Async;
+            use std::os::unix::net::UnixStream;
+
+            let stream = Async::<UnixStream>::connect(path).await?;
+
+            Ok(with_socket.with_socket(stream))
+        }
+
+        #[cfg(not(feature = "_rt-async-std"))]
+        {
+            crate::rt::missing_rt((path, with_socket))
+        }
+    }
+
     #[cfg(not(unix))]
     {
         drop((path, with_socket));

-        return Err(io::Error::new(
+        Err(io::Error::new(
             io::ErrorKind::Unsupported,
             "Unix domain sockets are not supported on this platform",
         )
-        .into());
+        .into())
     }
-
-    #[cfg(all(unix, feature = "_rt-tokio"))]
-    if crate::rt::rt_tokio::available() {
-        use tokio::net::UnixStream;
-
-        let stream = UnixStream::connect(path).await?;
-
-        return Ok(with_socket.with_socket(stream));
-    }
-
-    #[cfg(all(unix, feature = "_rt-async-std"))]
-    {
-        use async_io::Async;
-        use std::os::unix::net::UnixStream;
-
-        let stream = Async::<UnixStream>::connect(path).await?;
-
-        return Ok(with_socket.with_socket(stream));
-    }
-
-    #[cfg(all(unix, not(feature = "_rt-async-std")))]
-    {
-        crate::rt::missing_rt((path, with_socket))
-    }
 }

View file

@@ -241,9 +241,7 @@ impl ServerCertVerifier for NoHostnameTlsVerifier {
             ocsp_response,
             now,
         ) {
-            Err(TlsError::InvalidCertificate(reason))
-                if reason == CertificateError::NotValidForName =>
-            {
+            Err(TlsError::InvalidCertificate(CertificateError::NotValidForName)) => {
                 Ok(ServerCertVerified::assertion())
             }
             res => res,

View file

@@ -15,12 +15,12 @@ where
 {
     type Database = DB;

-    fn fetch_many<'e, 'q: 'e, E: 'q>(
+    fn fetch_many<'e, 'q: 'e, E>(
         self,
         query: E,
     ) -> BoxStream<'e, Result<Either<DB::QueryResult, DB::Row>, Error>>
     where
-        E: Execute<'q, Self::Database>,
+        E: 'q + Execute<'q, Self::Database>,
     {
         let pool = self.clone();
@@ -36,12 +36,12 @@ where
         })
     }

-    fn fetch_optional<'e, 'q: 'e, E: 'q>(
+    fn fetch_optional<'e, 'q: 'e, E>(
         self,
         query: E,
     ) -> BoxFuture<'e, Result<Option<DB::Row>, Error>>
     where
-        E: Execute<'q, Self::Database>,
+        E: 'q + Execute<'q, Self::Database>,
     {
         let pool = self.clone();

View file

@@ -153,7 +153,7 @@ impl<DB: Database> PoolInner<DB> {
             if parent_close_event.as_mut().poll(cx).is_ready() {
                 // Propagate the parent's close event to the child.
-                let _ = self.close();
+                self.mark_closed();
                 return Poll::Ready(Err(Error::PoolClosed));
             }
@@ -208,7 +208,7 @@ impl<DB: Database> PoolInner<DB> {
         let Floating { inner: idle, guard } = floating.into_idle();

-        if !self.idle_conns.push(idle).is_ok() {
+        if self.idle_conns.push(idle).is_err() {
             panic!("BUG: connection queue overflow in release()");
         }
@@ -226,7 +226,7 @@ impl<DB: Database> PoolInner<DB> {
         self: &'a Arc<Self>,
         permit: AsyncSemaphoreReleaser<'a>,
     ) -> Result<DecrementSizeGuard<DB>, AsyncSemaphoreReleaser<'a>> {
-        match self
+        let result = self
             .size
             .fetch_update(Ordering::AcqRel, Ordering::Acquire, |size| {
                 if self.is_closed() {
@@ -235,7 +235,9 @@ impl<DB: Database> PoolInner<DB> {
                 size.checked_add(1)
                     .filter(|size| size <= &self.options.max_connections)
-            }) {
+            });
+
+        match result {
             // we successfully incremented the size
             Ok(_) => Ok(DecrementSizeGuard::from_permit((*self).clone(), permit)),
             // the pool is at max capacity or is closed
@@ -332,10 +334,10 @@ impl<DB: Database> PoolInner<DB> {
     }

     let mut backoff = Duration::from_millis(10);
-    let max_backoff = deadline_as_timeout::<DB>(deadline)? / 5;
+    let max_backoff = deadline_as_timeout(deadline)? / 5;

     loop {
-        let timeout = deadline_as_timeout::<DB>(deadline)?;
+        let timeout = deadline_as_timeout(deadline)?;

         // clone the connect options arc so it can be used without holding the RwLockReadGuard
         // across an async await point
@@ -505,9 +507,9 @@ async fn check_idle_conn<DB: Database>(
 }

 fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
-    // NOTE: use `pool_weak` for the maintenance tasks so
-    // they don't keep `PoolInner` from being dropped.
-    let pool_weak = Arc::downgrade(&pool);
+    // NOTE: use `pool_weak` for the maintenance tasks
+    // so they don't keep `PoolInner` from being dropped.
+    let pool_weak = Arc::downgrade(pool);

     let period = match (pool.options.max_lifetime, pool.options.idle_timeout) {
         (Some(it), None) | (None, Some(it)) => it,

View file

@@ -376,7 +376,7 @@ impl<DB: Database> Pool<DB> {
     /// Retrieves a connection and immediately begins a new transaction.
     pub async fn begin(&self) -> Result<Transaction<'static, DB>, Error> {
-        Ok(Transaction::begin(MaybePoolConnection::PoolConnection(self.acquire().await?)).await?)
+        Transaction::begin(MaybePoolConnection::PoolConnection(self.acquire().await?)).await
     }

     /// Attempts to retrieve a connection and immediately begins a new transaction if successful.
@@ -642,7 +642,7 @@ impl FusedFuture for CloseEvent {
 /// get the time between the deadline and now and use that as our timeout
 ///
 /// returns `Error::PoolTimedOut` if the deadline is in the past
-fn deadline_as_timeout<DB: Database>(deadline: Instant) -> Result<Duration, Error> {
+fn deadline_as_timeout(deadline: Instant) -> Result<Duration, Error> {
     deadline
         .checked_duration_since(Instant::now())
         .ok_or(Error::PoolTimedOut)

View file

@@ -106,7 +106,7 @@ impl<DB: Database> Clone for PoolOptions<DB> {
             max_lifetime: self.max_lifetime,
             idle_timeout: self.idle_timeout,
             fair: self.fair,
-            parent_pool: self.parent_pool.as_ref().map(Pool::clone),
+            parent_pool: self.parent_pool.clone(),
         }
     }
 }

View file

@@ -46,14 +46,14 @@ where
     #[inline]
     fn sql(&self) -> &'q str {
         match self.statement {
-            Either::Right(ref statement) => statement.sql(),
+            Either::Right(statement) => statement.sql(),
             Either::Left(sql) => sql,
         }
     }

     fn statement(&self) -> Option<&DB::Statement<'q>> {
         match self.statement {
-            Either::Right(ref statement) => Some(&statement),
+            Either::Right(statement) => Some(statement),
             Either::Left(_) => None,
         }
     }
@@ -364,7 +364,7 @@ where
         let mut f = self.mapper;
         Map {
             inner: self.inner,
-            mapper: move |row| f(row).and_then(|o| g(o)),
+            mapper: move |row| f(row).and_then(&mut g),
         }
     }

View file

@@ -33,9 +33,9 @@ pub async fn timeout<F: Future>(duration: Duration, f: F) -> Result<F::Output, TimeoutError> {
     #[cfg(feature = "_rt-async-std")]
     {
-        return async_std::future::timeout(duration, f)
+        async_std::future::timeout(duration, f)
             .await
-            .map_err(|_| TimeoutError(()));
+            .map_err(|_| TimeoutError(()))
     }

     #[cfg(not(feature = "_rt-async-std"))]
@@ -50,7 +50,7 @@ pub async fn sleep(duration: Duration) {
     #[cfg(feature = "_rt-async-std")]
     {
-        return async_std::task::sleep(duration).await;
+        async_std::task::sleep(duration).await
     }

     #[cfg(not(feature = "_rt-async-std"))]
@@ -70,7 +70,7 @@ where
     #[cfg(feature = "_rt-async-std")]
     {
-        return JoinHandle::AsyncStd(async_std::task::spawn(fut));
+        JoinHandle::AsyncStd(async_std::task::spawn(fut))
     }

     #[cfg(not(feature = "_rt-async-std"))]
@@ -90,7 +90,7 @@ where
     #[cfg(feature = "_rt-async-std")]
     {
-        return JoinHandle::AsyncStd(async_std::task::spawn_blocking(f));
+        JoinHandle::AsyncStd(async_std::task::spawn_blocking(f))
     }

     #[cfg(not(feature = "_rt-async-std"))]
@@ -105,7 +105,7 @@ pub async fn yield_now() {
     #[cfg(feature = "_rt-async-std")]
     {
-        return async_std::task::yield_now().await;
+        async_std::task::yield_now().await;
     }

     #[cfg(not(feature = "_rt-async-std"))]
@@ -125,13 +125,12 @@ pub fn test_block_on<F: Future>(f: F) -> F::Output {
     #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
     {
-        return async_std::task::block_on(f);
+        async_std::task::block_on(f)
     }

     #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))]
     {
-        drop(f);
-        panic!("at least one of the `runtime-*` features must be enabled")
+        missing_rt(f)
     }
 }

View file

@@ -126,17 +126,15 @@ pub struct AsyncSemaphoreReleaser<'a> {
 impl AsyncSemaphoreReleaser<'_> {
     pub fn disarm(self) {
+        #[cfg(feature = "_rt-tokio")]
+        {
+            self.inner.forget();
+        }
+
         #[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))]
         {
             let mut this = self;
             this.inner.disarm();
-            return;
-        }
-
-        #[cfg(feature = "_rt-tokio")]
-        {
-            self.inner.forget();
-            return;
         }

         #[cfg(not(any(feature = "_rt-async-std", feature = "_rt-tokio")))]

View file

@@ -109,6 +109,7 @@ impl<DB: Database> FixtureSnapshot<DB> {
 /// Implements `ToString` but not `Display` because it uses [`QueryBuilder`] internally,
 /// which appends to an internal string.
+#[allow(clippy::to_string_trait_impl)]
 impl<DB: Database> ToString for Fixture<DB>
 where
     for<'a> <DB as Database>::Arguments<'a>: Default,

View file

@@ -111,7 +111,7 @@ where
     DB: Database,
 {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        (self.fmt)(&self.value, f)
+        (self.fmt)(self.value, f)
     }
 }

View file

@@ -45,16 +45,16 @@ impl<DB: DatabaseExt> CachingDescribeBlocking<DB> {
     where
         for<'a> &'a mut DB::Connection: Executor<'a, Database = DB>,
     {
-        crate::block_on(async {
-            let mut cache = self
-                .connections
-                .lock()
-                .expect("previous panic in describe call");
+        let mut cache = self
+            .connections
+            .lock()
+            .expect("previous panic in describe call");

+        crate::block_on(async {
             let conn = match cache.entry(database_url.to_string()) {
                 hash_map::Entry::Occupied(hit) => hit.into_mut(),
                 hash_map::Entry::Vacant(miss) => {
-                    miss.insert(DB::Connection::connect(&database_url).await?)
+                    miss.insert(DB::Connection::connect(database_url).await?)
                 }
             };

View file

@@ -41,6 +41,7 @@ impl TypeName {
 }

 #[derive(Copy, Clone)]
+#[allow(clippy::enum_variant_names)]
 pub enum RenameAll {
     LowerCase,
     SnakeCase,
@@ -165,7 +166,7 @@ pub fn parse_child_attributes(input: &[Attribute]) -> syn::Result<SqlxChildAttributes> {
             json = true;
         }

-        return Ok(());
+        Ok(())
     })?;

     if json && flatten {
@@ -265,8 +266,8 @@ pub fn check_strong_enum_attributes(
     Ok(attributes)
 }

-pub fn check_struct_attributes<'a>(
-    input: &'a DeriveInput,
+pub fn check_struct_attributes(
+    input: &DeriveInput,
     fields: &Punctuated<Field, Comma>,
 ) -> syn::Result<SqlxContainerAttributes> {
     let attributes = parse_container_attributes(&input.attrs)?;

View file

@@ -95,7 +95,7 @@ fn expand_derive_decode_weak_enum(
     input: &DeriveInput,
     variants: &Punctuated<Variant, Comma>,
 ) -> syn::Result<TokenStream> {
-    let attr = check_weak_enum_attributes(input, &variants)?;
+    let attr = check_weak_enum_attributes(input, variants)?;
     let repr = attr.repr.unwrap();
     let ident = &input.ident;
@@ -142,7 +142,7 @@ fn expand_derive_decode_strong_enum(
     input: &DeriveInput,
     variants: &Punctuated<Variant, Comma>,
 ) -> syn::Result<TokenStream> {
-    let cattr = check_strong_enum_attributes(input, &variants)?;
+    let cattr = check_strong_enum_attributes(input, variants)?;
     let ident = &input.ident;
     let ident_s = ident.to_string();
@@ -154,7 +154,7 @@ fn expand_derive_decode_strong_enum(
             if let Some(rename) = attributes.rename {
                 parse_quote!(#rename => ::std::result::Result::Ok(#ident :: #id),)
             } else if let Some(pattern) = cattr.rename_all {
-                let name = rename_all(&*id.to_string(), pattern);
+                let name = rename_all(&id.to_string(), pattern);
                 parse_quote!(#name => ::std::result::Result::Ok(#ident :: #id),)
             } else {

View file

@@ -20,7 +20,7 @@ pub fn expand_derive_encode(input: &DeriveInput) -> syn::Result<TokenStream> {
             fields: Fields::Unnamed(FieldsUnnamed { unnamed, .. }),
             ..
         }) if unnamed.len() == 1 => {
-            expand_derive_encode_transparent(&input, unnamed.first().unwrap())
+            expand_derive_encode_transparent(input, unnamed.first().unwrap())
         }
         Data::Enum(DataEnum { variants, .. }) => match args.repr {
             Some(_) => expand_derive_encode_weak_enum(input, variants),
@@ -104,7 +104,7 @@ fn expand_derive_encode_weak_enum(
     input: &DeriveInput,
     variants: &Punctuated<Variant, Comma>,
 ) -> syn::Result<TokenStream> {
-    let attr = check_weak_enum_attributes(input, &variants)?;
+    let attr = check_weak_enum_attributes(input, variants)?;
     let repr = attr.repr.unwrap();
     let ident = &input.ident;
@@ -143,7 +143,7 @@ fn expand_derive_encode_strong_enum(
     input: &DeriveInput,
     variants: &Punctuated<Variant, Comma>,
 ) -> syn::Result<TokenStream> {
-    let cattr = check_strong_enum_attributes(input, &variants)?;
+    let cattr = check_strong_enum_attributes(input, variants)?;
     let ident = &input.ident;
@@ -156,7 +156,7 @@ fn expand_derive_encode_strong_enum(
             if let Some(rename) = attributes.rename {
                 value_arms.push(quote!(#ident :: #id => #rename,));
             } else if let Some(pattern) = cattr.rename_all {
-                let name = rename_all(&*id.to_string(), pattern);
+                let name = rename_all(&id.to_string(), pattern);
                 value_arms.push(quote!(#ident :: #id => #name,));
             } else {
@@ -197,7 +197,7 @@ fn expand_derive_encode_struct(
     input: &DeriveInput,
     fields: &Punctuated<Field, Comma>,
 ) -> syn::Result<TokenStream> {
-    check_struct_attributes(input, &fields)?;
+    check_struct_attributes(input, fields)?;

     let mut tts = TokenStream::new();

View file

@@ -65,16 +65,14 @@ fn expand_derive_from_row_struct(
     let container_attributes = parse_container_attributes(&input.attrs)?;

-    let default_instance: Option<Stmt>;
-
-    if container_attributes.default {
+    let default_instance: Option<Stmt> = if container_attributes.default {
         predicates.push(parse_quote!(#ident: ::std::default::Default));
-        default_instance = Some(parse_quote!(
+        Some(parse_quote!(
             let __default = #ident::default();
-        ));
+        ))
     } else {
-        default_instance = None;
-    }
+        None
+    };

     let reads: Vec<Stmt> = fields
         .iter()
View file

@@ -69,11 +69,13 @@ where
                 .expect("failed to start Tokio runtime")
         });

-        return TOKIO_RT.block_on(f);
+        TOKIO_RT.block_on(f)
     }

     #[cfg(all(feature = "_rt-async-std", not(feature = "tokio")))]
-    return async_std::task::block_on(f);
+    {
+        async_std::task::block_on(f)
+    }

     #[cfg(not(any(feature = "_rt-async-std", feature = "tokio")))]
     sqlx_core::rt::missing_rt(f)

View file

@@ -20,7 +20,7 @@ impl ToTokens for QuoteMigrationType {
                 quote! { ::sqlx::migrate::MigrationType::ReversibleDown }
             }
         };
-        tokens.append_all(ts.into_iter());
+        tokens.append_all(ts);
     }
 }
@@ -77,7 +77,7 @@ impl ToTokens for QuoteMigration {
             }
         };

-        tokens.append_all(ts.into_iter());
+        tokens.append_all(ts);
     }
 }

View file

@@ -56,9 +56,9 @@ pub fn quote_args<DB: DatabaseExt>(
         }

         let param_ty =
-            DB::param_type_for_id(&param_ty)
+            DB::param_type_for_id(param_ty)
                 .ok_or_else(|| {
-                    if let Some(feature_gate) = DB::get_feature_gate(&param_ty) {
+                    if let Some(feature_gate) = DB::get_feature_gate(param_ty) {
                         format!(
                             "optional sqlx feature `{}` required for type {} of param #{}",
                             feature_gate,

View file

@@ -85,8 +85,8 @@ impl Metadata {
         let cargo = env("CARGO").expect("`CARGO` must be set");

-        let output = Command::new(&cargo)
-            .args(&["metadata", "--format-version=1", "--no-deps"])
+        let output = Command::new(cargo)
+            .args(["metadata", "--format-version=1", "--no-deps"])
             .current_dir(&self.manifest_dir)
             .env_remove("__CARGO_FIX_PLZ")
             .output()
@@ -190,7 +190,7 @@ pub fn expand_input<'a>(
     };

     for driver in drivers {
-        if data_source.matches_driver(&driver) {
+        if data_source.matches_driver(driver) {
             return (driver.expand)(input, data_source);
         }
     }
@@ -222,7 +222,7 @@ where
     let (query_data, offline): (QueryData<DB>, bool) = match data_source {
         QueryDataSource::Cached(dyn_data) => (QueryData::from_dyn_data(dyn_data)?, true),
         QueryDataSource::Live { database_url, .. } => {
-            let describe = DB::describe_blocking(&input.sql, &database_url)?;
+            let describe = DB::describe_blocking(&input.sql, database_url)?;
             (QueryData::from_describe(&input.sql, describe), false)
         }
     };
@@ -295,13 +295,9 @@ where
         }
     }

-    let record_fields = columns.iter().map(
-        |&output::RustColumn {
-             ref ident,
-             ref type_,
-             ..
-         }| quote!(#ident: #type_,),
-    );
+    let record_fields = columns
+        .iter()
+        .map(|output::RustColumn { ident, type_, .. }| quote!(#ident: #type_,));

     let mut record_tokens = quote! {
         #[derive(Debug)]

View file

@@ -86,7 +86,7 @@ fn column_to_rust<DB: DatabaseExt>(describe: &Describe<DB>, i: usize) -> crate::
     let column = &describe.columns()[i];

     // add raw prefix to all identifiers
-    let decl = ColumnDecl::parse(&column.name())
+    let decl = ColumnDecl::parse(column.name())
         .map_err(|e| format!("column name {:?} is invalid: {}", column.name(), e))?;

     let ColumnOverride { nullability, type_ } = decl.r#override;
@@ -133,10 +133,8 @@ pub fn quote_query_as<DB: DatabaseExt>(
     let instantiations = columns.iter().enumerate().map(
         |(
             i,
-            &RustColumn {
-                ref var_name,
-                ref type_,
-                ..
+            RustColumn {
+                var_name, type_, ..
             },
         )| {
             match (input.checked, type_) {
@@ -221,19 +219,19 @@ pub fn quote_query_scalar<DB: DatabaseExt>(
 }

 fn get_column_type<DB: DatabaseExt>(i: usize, column: &DB::Column) -> TokenStream {
-    let type_info = &*column.type_info();
+    let type_info = column.type_info();

-    <DB as TypeChecking>::return_type_for_id(&type_info).map_or_else(
+    <DB as TypeChecking>::return_type_for_id(type_info).map_or_else(
         || {
             let message =
-                if let Some(feature_gate) = <DB as TypeChecking>::get_feature_gate(&type_info) {
+                if let Some(feature_gate) = <DB as TypeChecking>::get_feature_gate(type_info) {
                     format!(
                         "optional sqlx feature `{feat}` required for type {ty} of {col}",
                         ty = &type_info,
                         feat = feature_gate,
                         col = DisplayColumn {
                             idx: i,
-                            name: &*column.name()
+                            name: column.name()
                         }
                     )
                 } else {
@@ -242,7 +240,7 @@ fn get_column_type<DB: DatabaseExt>(i: usize, column: &DB::Column) -> TokenStream {
                         ty = type_info,
                         col = DisplayColumn {
                             idx: i,
-                            name: &*column.name()
+                            name: column.name()
                         }
                     )
                 };

View file

@@ -22,7 +22,7 @@ impl Column for MySqlColumn {
     }

     fn name(&self) -> &str {
-        &*self.name
+        &self.name
     }

     fn type_info(&self) -> &MySqlTypeInfo {

View file

@@ -87,7 +87,7 @@ fn scramble_sha1(
     let mut pw_hash = ctx.finalize_reset();

-    ctx.update(&pw_hash);
+    ctx.update(pw_hash);

     let pw_hash_hash = ctx.finalize_reset();
@@ -114,7 +114,7 @@ fn scramble_sha256(
     let mut pw_hash = ctx.finalize_reset();

-    ctx.update(&pw_hash);
+    ctx.update(pw_hash);

     let pw_hash_hash = ctx.finalize_reset();
@@ -155,10 +155,10 @@ async fn encrypt_rsa<'s>(
     let (a, b) = (nonce.first_ref(), nonce.last_ref());
     let mut nonce = Vec::with_capacity(a.len() + b.len());
-    nonce.extend_from_slice(&*a);
-    nonce.extend_from_slice(&*b);
+    nonce.extend_from_slice(a);
+    nonce.extend_from_slice(b);

-    xor_eq(&mut pass, &*nonce);
+    xor_eq(&mut pass, &nonce);

     // client sends an RSA encrypted password
     let pkey = parse_rsa_pub_key(rsa_pub_key)?;
@@ -193,5 +193,5 @@ fn parse_rsa_pub_key(key: &[u8]) -> Result<RsaPublicKey, Error> {
     // we are receiving a PKCS#8 RSA Public Key at all
     // times from MySQL
-    RsaPublicKey::from_public_key_pem(&pem).map_err(Error::protocol)
+    RsaPublicKey::from_public_key_pem(pem).map_err(Error::protocol)
 }

View file

@@ -245,13 +245,15 @@ impl MySqlConnection {
 impl<'c> Executor<'c> for &'c mut MySqlConnection {
     type Database = MySql;

-    fn fetch_many<'e, 'q: 'e, E: 'q>(
+    fn fetch_many<'e, 'q, E>(
         self,
         mut query: E,
     ) -> BoxStream<'e, Result<Either<MySqlQueryResult, MySqlRow>, Error>>
     where
         'c: 'e,
         E: Execute<'q, Self::Database>,
+        'q: 'e,
+        E: 'q,
     {
         let sql = query.sql();
         let arguments = query.take_arguments().map_err(Error::Encode);
@@ -270,13 +272,12 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection {
         })
     }

-    fn fetch_optional<'e, 'q: 'e, E: 'q>(
-        self,
-        query: E,
-    ) -> BoxFuture<'e, Result<Option<MySqlRow>, Error>>
+    fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result<Option<MySqlRow>, Error>>
     where
         'c: 'e,
         E: Execute<'q, Self::Database>,
+        'q: 'e,
+        E: 'q,
     {
         let mut s = self.fetch_many(query);
@@ -338,7 +339,7 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection {
             .send_packet(StmtClose { statement: id })
             .await?;

-        let columns = (&*metadata.columns).clone();
+        let columns = (*metadata.columns).clone();

         let nullable = columns
             .iter()
@@ -384,7 +385,7 @@ fn recv_next_result_column(def: &ColumnDefinition, ordinal: usize) -> Result<MySqlColumn, Error> {
         (name, _) => UStr::new(name),
     };

-    let type_info = MySqlTypeInfo::from_column(&def);
+    let type_info = MySqlTypeInfo::from_column(def);

     Ok(MySqlColumn {
         name,

View file

@@ -119,7 +119,7 @@ impl MySqlConnectOptions {
     /// The default behavior when the host is not specified,
     /// is to connect to localhost.
     pub fn host(mut self, host: &str) -> Self {
-        self.host = host.to_owned();
+        host.clone_into(&mut self.host);
        self
     }
@@ -142,7 +142,7 @@ impl MySqlConnectOptions {
     /// Sets the username to connect as.
     pub fn username(mut self, username: &str) -> Self {
-        self.username = username.to_owned();
+        username.clone_into(&mut self.username);
         self
     }
@@ -302,7 +302,7 @@ impl MySqlConnectOptions {
     /// The default character set is `utf8mb4`. This is supported from MySQL 5.5.3.
     /// If you need to connect to an older version, we recommend you to change this to `utf8`.
     pub fn charset(mut self, charset: &str) -> Self {
-        self.charset = charset.to_owned();
+        charset.clone_into(&mut self.charset);
         self
     }
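
`clone_into` is behaviorally the same as `self.host = host.to_owned()`, but it can reuse the existing `String`'s buffer instead of always allocating; this is what clippy's `assigning_clones` lint suggests. A minimal sketch:

    fn main() {
        let mut host = String::from("localhost");

        // Equivalent to `host = "db.example.com".to_owned()`, but may
        // write into `host`'s existing allocation.
        "db.example.com".clone_into(&mut host);

        assert_eq!(host, "db.example.com");
    }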

View file

@@ -22,7 +22,7 @@ impl MySqlConnectOptions {
         let username = url.username();
         if !username.is_empty() {
             options = options.username(
-                &*percent_decode_str(username)
+                &percent_decode_str(username)
                     .decode_utf8()
                     .map_err(Error::config)?,
             );
@@ -30,7 +30,7 @@ impl MySqlConnectOptions {
         if let Some(password) = url.password() {
             options = options.password(
-                &*percent_decode_str(password)
+                &percent_decode_str(password)
                     .decode_utf8()
                     .map_err(Error::config)?,
             );
@@ -52,11 +52,11 @@ impl MySqlConnectOptions {
                 }

                 "charset" => {
-                    options = options.charset(&*value);
+                    options = options.charset(&value);
                 }

                 "collation" => {
-                    options = options.collation(&*value);
+                    options = options.collation(&value);
                 }

                 "sslcert" | "ssl-cert" => options = options.ssl_client_cert(&*value),
@@ -87,12 +87,12 @@ impl MySqlConnectOptions {
             .expect("BUG: generated un-parseable URL");

         if let Some(password) = &self.password {
-            let password = utf8_percent_encode(&password, NON_ALPHANUMERIC).to_string();
+            let password = utf8_percent_encode(password, NON_ALPHANUMERIC).to_string();
             let _ = url.set_password(Some(&password));
         }

         if let Some(database) = &self.database {
-            url.set_path(&database);
+            url.set_path(database);
         }

         let ssl_mode = match self.ssl_mode {
@@ -112,7 +112,7 @@ impl MySqlConnectOptions {
         url.query_pairs_mut().append_pair("charset", &self.charset);

         if let Some(collation) = &self.collation {
-            url.query_pairs_mut().append_pair("charset", &collation);
+            url.query_pairs_mut().append_pair("charset", collation);
         }

         if let Some(ssl_client_cert) = &self.ssl_client_cert {

View file

@@ -4,7 +4,7 @@ use std::str::FromStr;
 /// Options for controlling the desired security state of the connection to the MySQL server.
 ///
 /// It is used by the [`ssl_mode`](super::MySqlConnectOptions::ssl_mode) method.
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, Default)]
 pub enum MySqlSslMode {
     /// Establish an unencrypted connection.
     Disabled,
@@ -13,6 +13,7 @@ pub enum MySqlSslMode {
     /// back to an unencrypted connection if an encrypted connection cannot be established.
     ///
     /// This is the default if `ssl_mode` is not specified.
+    #[default]
     Preferred,

     /// Establish an encrypted connection if the server supports encrypted connections.
@@ -30,12 +31,6 @@ pub enum MySqlSslMode {
     VerifyIdentity,
 }

-impl Default for MySqlSslMode {
-    fn default() -> Self {
-        MySqlSslMode::Preferred
-    }
-}
-
 impl FromStr for MySqlSslMode {
     type Err = Error;
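
`#[derive(Default)]` with a `#[default]` marker on one variant (stable since Rust 1.62) replaces the manual `impl Default` deleted above; a standalone sketch:

    #[derive(Debug, Clone, Copy, Default, PartialEq)]
    enum SslMode {
        Disabled,
        #[default]
        Preferred,
        Required,
    }

    fn main() {
        assert_eq!(SslMode::default(), SslMode::Preferred);
    }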

View file

@@ -3,6 +3,8 @@ use std::str::FromStr;
 use crate::error::Error;

 #[derive(Debug, Copy, Clone)]
+// These have all the same suffix but they match the auth plugin names.
+#[allow(clippy::enum_variant_names)]
 pub enum AuthPlugin {
     MySqlNativePassword,
     CachingSha2Password,

View file

@ -1,5 +1,6 @@
use bytes::buf::Chain; use bytes::buf::Chain;
use bytes::{Buf, Bytes}; use bytes::{Buf, Bytes};
use std::cmp;
use crate::error::Error; use crate::error::Error;
use crate::io::{BufExt, Decode}; use crate::io::{BufExt, Decode};
@ -61,7 +62,7 @@ impl Decode<'_> for Handshake {
} }
let auth_plugin_data_2 = if capabilities.contains(Capabilities::SECURE_CONNECTION) { let auth_plugin_data_2 = if capabilities.contains(Capabilities::SECURE_CONNECTION) {
let len = ((auth_plugin_data_len as isize) - 9).max(12) as usize; let len = cmp::max((auth_plugin_data_len as isize) - 9, 12) as usize;
let v = buf.get_bytes(len); let v = buf.get_bytes(len);
buf.advance(1); // NUL-terminator buf.advance(1); // NUL-terminator

View file

@@ -30,7 +30,7 @@ impl Decode<'_, Capabilities> for ErrPacket {
         if capabilities.contains(Capabilities::PROTOCOL_41) {
             // If the next byte is '#' then we have a SQL STATE
-            if buf.get(0) == Some(&0x23) {
+            if buf.starts_with(b"#") {
                 buf.advance(1);
                 sql_state = Some(buf.get_str(5)?);
             }

View file

@@ -10,8 +10,6 @@ pub(crate) struct Row {
 impl Row {
     pub(crate) fn get(&self, index: usize) -> Option<&[u8]> {
-        self.values[index]
-            .as_ref()
-            .map(|col| &self.storage[(col.start as usize)..(col.end as usize)])
+        self.values[index].clone().map(|col| &self.storage[col])
     }
 }
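
The one-liner works because `Range<usize>` implements `SliceIndex`, so a stored range can index the backing buffer directly; the `clone()` is needed since `Range` is not `Copy`. A minimal sketch under the assumption that `values` holds `Option<Range<usize>>`:

    use std::ops::Range;

    struct Row {
        storage: Vec<u8>,
        values: Vec<Option<Range<usize>>>,
    }

    impl Row {
        fn get(&self, index: usize) -> Option<&[u8]> {
            // Indexing a slice by a `Range<usize>` yields a subslice.
            self.values[index].clone().map(|col| &self.storage[col])
        }
    }

    fn main() {
        let row = Row {
            storage: b"hello world".to_vec(),
            values: vec![Some(0..5), None],
        };
        assert_eq!(row.get(0), Some(&b"hello"[..]));
        assert_eq!(row.get(1), None);
    }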

View file

@@ -46,6 +46,6 @@ impl ColumnIndex<MySqlRow> for &'_ str {
         row.column_names
             .get(*self)
             .ok_or_else(|| Error::ColumnNotFound((*self).into()))
-            .map(|v| *v)
+            .copied()
     }
 }

View file

@@ -55,6 +55,6 @@ impl ColumnIndex<MySqlStatement<'_>> for &'_ str {
             .column_names
             .get(*self)
             .ok_or_else(|| Error::ColumnNotFound((*self).into()))
-            .map(|v| *v)
+            .copied()
     }
 }

View file

@@ -27,10 +27,7 @@ static DO_CLEANUP: AtomicBool = AtomicBool::new(true);
 impl TestSupport for MySql {
     fn test_context(args: &TestArgs) -> BoxFuture<'_, Result<TestContext<Self>, Error>> {
-        Box::pin(async move {
-            let res = test_context(args).await;
-            res
-        })
+        Box::pin(async move { test_context(args).await })
     }

     fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> {
@@ -47,7 +44,7 @@ impl TestSupport for MySql {
             .await?;

         query("delete from _sqlx_test_databases where db_id = ?")
-            .bind(&db_id)
+            .bind(db_id)
             .execute(&mut *conn)
             .await?;
@@ -141,7 +138,7 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<MySql>, Error> {
     }

     query("insert into _sqlx_test_databases(test_path) values (?)")
-        .bind(&args.test_path)
+        .bind(args.test_path)
         .execute(&mut *conn)
         .await?;
@@ -182,7 +179,7 @@ async fn do_cleanup(conn: &mut MySqlConnection, created_before: Duration) -> Result<usize, Error> {
         "select db_id from _sqlx_test_databases \
          where created_at < from_unixtime(?)",
     )
-    .bind(&created_before_as_secs)
+    .bind(created_before_as_secs)
     .fetch_all(&mut *conn)
     .await?;
@@ -221,8 +218,6 @@ async fn do_cleanup(conn: &mut MySqlConnection, created_before: Duration) -> Result<usize, Error> {
         separated.push_bind(db_id);
     }

-    drop(separated);
-
     query.push(")").build().execute(&mut *conn).await?;

     Ok(deleted_db_ids.len())

View file

@ -59,7 +59,7 @@ impl TransactionManager for MySqlTransactionManager {
conn.inner.stream.sequence_id = 0; conn.inner.stream.sequence_id = 0;
conn.inner conn.inner
.stream .stream
.write_packet(Query(&*rollback_ansi_transaction_sql(depth))); .write_packet(Query(&rollback_ansi_transaction_sql(depth)));
conn.inner.transaction_depth = depth - 1; conn.inner.transaction_depth = depth - 1;
} }

View file

@ -25,7 +25,7 @@ impl Type<MySql> for DateTime<Utc> {
/// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). /// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC).
impl Encode<'_, MySql> for DateTime<Utc> { impl Encode<'_, MySql> for DateTime<Utc> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> { fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
Encode::<MySql>::encode(&self.naive_utc(), buf) Encode::<MySql>::encode(self.naive_utc(), buf)
} }
} }
@ -51,7 +51,7 @@ impl Type<MySql> for DateTime<Local> {
/// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). /// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC).
impl Encode<'_, MySql> for DateTime<Local> { impl Encode<'_, MySql> for DateTime<Local> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> { fn encode_by_ref(&self, buf: &mut Vec<u8>) -> Result<IsNull, BoxDynError> {
Encode::<MySql>::encode(&self.naive_utc(), buf) Encode::<MySql>::encode(self.naive_utc(), buf)
} }
} }
@ -318,7 +318,7 @@ fn encode_time(time: &NaiveTime, include_micros: bool, buf: &mut Vec<u8>) {
buf.push(time.second() as u8); buf.push(time.second() as u8);
if include_micros { if include_micros {
buf.extend(&((time.nanosecond() / 1000) as u32).to_le_bytes()); buf.extend((time.nanosecond() / 1000).to_le_bytes());
} }
} }

View file

@ -498,7 +498,7 @@ impl MySqlTimeSign {
} }
} }
fn to_byte(&self) -> u8 { fn to_byte(self) -> u8 {
match self { match self {
// We can't use `#[repr(u8)]` because this is opposite of the ordering we want from `Ord` // We can't use `#[repr(u8)]` because this is opposite of the ordering we want from `Ord`
Self::Negative => 1, Self::Negative => 1,
@ -579,7 +579,7 @@ fn parse(text: &str) -> Result<MySqlTime, BoxDynError> {
MySqlTimeSign::Positive MySqlTimeSign::Positive
}; };
let hours = hours.abs() as u32; let hours = hours.unsigned_abs();
let minutes: u8 = minutes let minutes: u8 = minutes
.parse() .parse()
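The `unsigned_abs()` change above is more than a style fix: `i32::abs()` overflows on `i32::MIN` (a panic in debug builds), while `unsigned_abs()` is total. A small demonstration, assuming `hours` is an `i32` as the old `as u32` cast suggests:

    fn main() {
        let hours: i32 = -838;
        assert_eq!(hours.unsigned_abs(), 838u32);
        // The edge case that `abs() as u32` mishandles: the result below does
        // not fit in an i32, so `i32::MIN.abs()` would overflow here.
        assert_eq!(i32::MIN.unsigned_abs(), 2_147_483_648u32);
    }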

View file

@ -27,7 +27,7 @@ impl Encode<'_, MySql> for OffsetDateTime {
let utc_dt = self.to_offset(UtcOffset::UTC); let utc_dt = self.to_offset(UtcOffset::UTC);
let primitive_dt = PrimitiveDateTime::new(utc_dt.date(), utc_dt.time()); let primitive_dt = PrimitiveDateTime::new(utc_dt.date(), utc_dt.time());
Encode::<MySql>::encode(&primitive_dt, buf) Encode::<MySql>::encode(primitive_dt, buf)
} }
} }
@ -287,8 +287,8 @@ fn decode_date(buf: &[u8]) -> Result<Option<Date>, BoxDynError> {
Date::from_calendar_date( Date::from_calendar_date(
LittleEndian::read_u16(buf) as i32, LittleEndian::read_u16(buf) as i32,
time::Month::try_from(buf[2] as u8)?, time::Month::try_from(buf[2])?,
buf[3] as u8, buf[3],
) )
.map_err(Into::into) .map_err(Into::into)
.map(Some) .map(Some)
@ -300,7 +300,7 @@ fn encode_time(time: &Time, include_micros: bool, buf: &mut Vec<u8>) {
buf.push(time.second()); buf.push(time.second());
if include_micros { if include_micros {
buf.extend(&((time.nanosecond() / 1000) as u32).to_le_bytes()); buf.extend(&(time.nanosecond() / 1000).to_le_bytes());
} }
} }

View file

@ -95,7 +95,7 @@ impl<'r> ValueRef<'r> for MySqlValueRef<'r> {
#[inline] #[inline]
fn is_null(&self) -> bool { fn is_null(&self) -> bool {
is_null(self.value.as_deref(), &self.type_info) is_null(self.value, &self.type_info)
} }
} }
@ -105,7 +105,7 @@ fn is_null(value: Option<&[u8]>, ty: &MySqlTypeInfo) -> bool {
if matches!( if matches!(
ty.r#type, ty.r#type,
ColumnType::Date | ColumnType::Timestamp | ColumnType::Datetime ColumnType::Date | ColumnType::Timestamp | ColumnType::Datetime
) && value.get(0) == Some(&0) ) && value.starts_with(b"\0")
{ {
return true; return true;
} }

View file

@ -32,12 +32,7 @@ pub struct PgArgumentBuffer {
// //
// This currently is only setup to be useful if there is a *fixed-size* slot that needs to be // This currently is only setup to be useful if there is a *fixed-size* slot that needs to be
// tweaked from the input type. However, that's the only use case we currently have. // tweaked from the input type. However, that's the only use case we currently have.
// patches: Vec<Patch>,
patches: Vec<(
usize, // offset
usize, // argument index
Box<dyn Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync>,
)>,
// Whenever an `Encode` impl encounters a `PgTypeInfo` object that does not have an OID // Whenever an `Encode` impl encounters a `PgTypeInfo` object that does not have an OID
// It pushes a "hole" that must be patched later. // It pushes a "hole" that must be patched later.
@ -49,6 +44,13 @@ pub struct PgArgumentBuffer {
type_holes: Vec<(usize, UStr)>, // Vec<{ offset, type_name }> type_holes: Vec<(usize, UStr)>, // Vec<{ offset, type_name }>
} }
struct Patch {
buf_offset: usize,
arg_index: usize,
#[allow(clippy::type_complexity)]
callback: Box<dyn Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync>,
}
/// Implementation of [`Arguments`] for PostgreSQL. /// Implementation of [`Arguments`] for PostgreSQL.
#[derive(Default)] #[derive(Default)]
pub struct PgArguments { pub struct PgArguments {
@ -97,15 +99,15 @@ impl PgArguments {
.. ..
} = self.buffer; } = self.buffer;
for (offset, ty, callback) in patches { for patch in patches {
let buf = &mut buffer[*offset..]; let buf = &mut buffer[patch.buf_offset..];
let ty = &parameters[*ty]; let ty = &parameters[patch.arg_index];
callback(buf, ty); (patch.callback)(buf, ty);
} }
for (offset, name) in type_holes { for (offset, name) in type_holes {
let oid = conn.fetch_type_id_by_name(&*name).await?; let oid = conn.fetch_type_id_by_name(name).await?;
buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes()); buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes());
} }
@ -169,9 +171,13 @@ impl PgArgumentBuffer {
F: Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync, F: Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync,
{ {
let offset = self.len(); let offset = self.len();
let index = self.count; let arg_index = self.count;
self.patches.push((offset, index, Box::new(callback))); self.patches.push(Patch {
buf_offset: offset,
arg_index,
callback: Box::new(callback),
});
} }
// Extends the inner buffer by enough space to have an OID // Extends the inner buffer by enough space to have an OID
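The tuple-to-struct change above is the usual fix for clippy's `type_complexity` lint, and naming the fields also makes the patch loop self-documenting (`patch.buf_offset` instead of `patch.0`). A simplified, self-contained sketch of the same shape; the buffer and type-info machinery are stand-ins:

    struct Patch {
        buf_offset: usize,
        arg_index: usize,
        callback: Box<dyn Fn(&mut [u8], u32) + Send + Sync>,
    }

    fn apply_patches(patches: &[Patch], buffer: &mut [u8], parameters: &[u32]) {
        for patch in patches {
            // Look up the parameter this patch refers to, then let the
            // callback rewrite the buffer slice starting at its offset.
            let ty = parameters[patch.arg_index];
            (patch.callback)(&mut buffer[patch.buf_offset..], ty);
        }
    }

    fn main() {
        let patches = vec![Patch {
            buf_offset: 1,
            arg_index: 0,
            callback: Box::new(|buf, ty| buf[0] = ty as u8),
        }];
        let mut buffer = [0u8; 4];
        apply_patches(&patches, &mut buffer, &[0xAB]);
        assert_eq!(buffer, [0x00, 0xAB, 0x00, 0x00]);
    }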

View file

@ -23,7 +23,7 @@ impl Column for PgColumn {
} }
fn name(&self) -> &str { fn name(&self) -> &str {
&*self.name &self.name
} }
fn type_info(&self) -> &PgTypeInfo { fn type_info(&self) -> &PgTypeInfo {

View file

@ -463,7 +463,7 @@ WHERE rngtypid = $1
}) = explains.first() }) = explains.first()
{ {
nullables.resize(outputs.len(), None); nullables.resize(outputs.len(), None);
visit_plan(&plan, outputs, &mut nullables); visit_plan(plan, outputs, &mut nullables);
} }
Ok(nullables) Ok(nullables)

View file

@ -48,7 +48,7 @@ async fn prepare(
// next we send the PARSE command to the server // next we send the PARSE command to the server
conn.stream.write(Parse { conn.stream.write(Parse {
param_types: &*param_types, param_types: &param_types,
query: sql, query: sql,
statement: id, statement: id,
}); });
@ -63,8 +63,7 @@ async fn prepare(
conn.stream.flush().await?; conn.stream.flush().await?;
// indicates that the SQL query string is now successfully parsed and has semantic validity // indicates that the SQL query string is now successfully parsed and has semantic validity
let _ = conn conn.stream
.stream
.recv_expect(MessageFormat::ParseComplete) .recv_expect(MessageFormat::ParseComplete)
.await?; .await?;
@ -227,7 +226,7 @@ impl PgConnection {
statement, statement,
formats: &[PgValueFormat::Binary], formats: &[PgValueFormat::Binary],
num_params: arguments.types.len() as i16, num_params: arguments.types.len() as i16,
params: &*arguments.buffer, params: &arguments.buffer,
result_formats: &[PgValueFormat::Binary], result_formats: &[PgValueFormat::Binary],
}); });
@ -360,15 +359,19 @@ impl PgConnection {
impl<'c> Executor<'c> for &'c mut PgConnection { impl<'c> Executor<'c> for &'c mut PgConnection {
type Database = Postgres; type Database = Postgres;
fn fetch_many<'e, 'q: 'e, E: 'q>( fn fetch_many<'e, 'q, E>(
self, self,
mut query: E, mut query: E,
) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>> ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
where where
'c: 'e, 'c: 'e,
E: Execute<'q, Self::Database>, E: Execute<'q, Self::Database>,
'q: 'e,
E: 'q,
{ {
let sql = query.sql(); let sql = query.sql();
// False positive: https://github.com/rust-lang/rust-clippy/issues/12560
#[allow(clippy::map_clone)]
let metadata = query.statement().map(|s| Arc::clone(&s.metadata)); let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
let arguments = query.take_arguments().map_err(Error::Encode); let arguments = query.take_arguments().map_err(Error::Encode);
let persistent = query.persistent(); let persistent = query.persistent();
@ -386,15 +389,16 @@ impl<'c> Executor<'c> for &'c mut PgConnection {
}) })
} }
fn fetch_optional<'e, 'q: 'e, E: 'q>( fn fetch_optional<'e, 'q, E>(self, mut query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
self,
mut query: E,
) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
where where
'c: 'e, 'c: 'e,
E: Execute<'q, Self::Database>, E: Execute<'q, Self::Database>,
'q: 'e,
E: 'q,
{ {
let sql = query.sql(); let sql = query.sql();
// False positive: https://github.com/rust-lang/rust-clippy/issues/12560
#[allow(clippy::map_clone)]
let metadata = query.statement().map(|s| Arc::clone(&s.metadata)); let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
let arguments = query.take_arguments().map_err(Error::Encode); let arguments = query.take_arguments().map_err(Error::Encode);
let persistent = query.persistent(); let persistent = query.persistent();

View file

@ -101,7 +101,7 @@ pub(crate) async fn authenticate(
let client_key = mac.finalize().into_bytes(); let client_key = mac.finalize().into_bytes();
// StoredKey := H(ClientKey) // StoredKey := H(ClientKey)
let stored_key = Sha256::digest(&client_key); let stored_key = Sha256::digest(client_key);
// client-final-message-without-proof // client-final-message-without-proof
let client_final_message_wo_proof = format!( let client_final_message_wo_proof = format!(
@ -120,7 +120,7 @@ pub(crate) async fn authenticate(
// ClientSignature := HMAC(StoredKey, AuthMessage) // ClientSignature := HMAC(StoredKey, AuthMessage)
let mut mac = Hmac::<Sha256>::new_from_slice(&stored_key).map_err(Error::protocol)?; let mut mac = Hmac::<Sha256>::new_from_slice(&stored_key).map_err(Error::protocol)?;
mac.update(&auth_message.as_bytes()); mac.update(auth_message.as_bytes());
let client_signature = mac.finalize().into_bytes(); let client_signature = mac.finalize().into_bytes();
@ -139,7 +139,7 @@ pub(crate) async fn authenticate(
// ServerSignature := HMAC(ServerKey, AuthMessage) // ServerSignature := HMAC(ServerKey, AuthMessage)
let mut mac = Hmac::<Sha256>::new_from_slice(&server_key).map_err(Error::protocol)?; let mut mac = Hmac::<Sha256>::new_from_slice(&server_key).map_err(Error::protocol)?;
mac.update(&auth_message.as_bytes()); mac.update(auth_message.as_bytes());
// client-final-message = client-final-message-without-proof "," proof // client-final-message = client-final-message-without-proof "," proof
let mut client_final_message = format!("{client_final_message_wo_proof},{CLIENT_PROOF_ATTR}="); let mut client_final_message = format!("{client_final_message_wo_proof},{CLIENT_PROOF_ATTR}=");
@ -192,7 +192,7 @@ fn gen_nonce() -> String {
fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error> { fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error> {
let mut mac = Hmac::<Sha256>::new_from_slice(s.as_bytes()).map_err(Error::protocol)?; let mut mac = Hmac::<Sha256>::new_from_slice(s.as_bytes()).map_err(Error::protocol)?;
mac.update(&salt); mac.update(salt);
mac.update(&1u32.to_be_bytes()); mac.update(&1u32.to_be_bytes());
let mut u = mac.finalize_reset().into_bytes(); let mut u = mac.finalize_reset().into_bytes();

View file

@ -159,11 +159,10 @@ impl PgStream {
tracing_level tracing_level
); );
if log_is_enabled { if log_is_enabled {
let message = format!("{}", notice.message());
sqlx_core::private_tracing_dynamic_event!( sqlx_core::private_tracing_dynamic_event!(
target: "sqlx::postgres::notice", target: "sqlx::postgres::notice",
tracing_level, tracing_level,
message message = notice.message()
); );
} }
@ -211,7 +210,7 @@ fn parse_server_version(s: &str) -> Option<u32> {
break; break;
} }
} }
_ if ch.is_digit(10) => { _ if ch.is_ascii_digit() => {
if chs.peek().is_none() { if chs.peek().is_none() {
if let Ok(num) = u32::from_str(&s[from..]) { if let Ok(num) = u32::from_str(&s[from..]) {
parts.push(num); parts.push(num);

View file

@ -332,13 +332,15 @@ impl Drop for PgListener {
impl<'c> Executor<'c> for &'c mut PgListener { impl<'c> Executor<'c> for &'c mut PgListener {
type Database = Postgres; type Database = Postgres;
fn fetch_many<'e, 'q: 'e, E: 'q>( fn fetch_many<'e, 'q, E>(
self, self,
query: E, query: E,
) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>> ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
where where
'c: 'e, 'c: 'e,
E: Execute<'q, Self::Database>, E: Execute<'q, Self::Database>,
'q: 'e,
E: 'q,
{ {
futures_util::stream::once(async move { futures_util::stream::once(async move {
// need some basic type annotation to help the compiler a bit // need some basic type annotation to help the compiler a bit
@ -349,13 +351,12 @@ impl<'c> Executor<'c> for &'c mut PgListener {
.boxed() .boxed()
} }
fn fetch_optional<'e, 'q: 'e, E: 'q>( fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
self,
query: E,
) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
where where
'c: 'e, 'c: 'e,
E: Execute<'q, Self::Database>, E: Execute<'q, Self::Database>,
'q: 'e,
E: 'q,
{ {
async move { self.connection().await?.fetch_optional(query).await }.boxed() async move { self.connection().await?.fetch_optional(query).await }.boxed()
} }

View file

@ -162,8 +162,8 @@ impl Decode<'_> for AuthenticationSaslContinue {
Ok(Self { Ok(Self {
iterations, iterations,
salt, salt,
nonce: from_utf8(&*nonce).map_err(Error::protocol)?.to_owned(), nonce: from_utf8(&nonce).map_err(Error::protocol)?.to_owned(),
message: from_utf8(&*buf).map_err(Error::protocol)?.to_owned(), message: from_utf8(&buf).map_err(Error::protocol)?.to_owned(),
}) })
} }
} }

View file

@ -201,7 +201,7 @@ impl PgConnectOptions {
/// .host("localhost"); /// .host("localhost");
/// ``` /// ```
pub fn host(mut self, host: &str) -> Self { pub fn host(mut self, host: &str) -> Self {
self.host = host.to_owned(); host.clone_into(&mut self.host);
self self
} }
@ -243,7 +243,7 @@ impl PgConnectOptions {
/// .username("postgres"); /// .username("postgres");
/// ``` /// ```
pub fn username(mut self, username: &str) -> Self { pub fn username(mut self, username: &str) -> Self {
self.username = username.to_owned(); username.clone_into(&mut self.username);
self self
} }
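`clone_into`, used in both setters above, is clippy's `assigning_clones` suggestion: instead of allocating a new `String` and dropping the old one, it copies into the existing buffer when capacity allows. A sketch:

    fn main() {
        let mut host = String::with_capacity(32);
        host.push_str("localhost");
        let ptr_before = host.as_ptr();

        // Equivalent in effect to `host = "db.internal".to_owned()`, but the
        // existing buffer is reused rather than reallocated.
        "db.internal".clone_into(&mut host);

        assert_eq!(host, "db.internal");
        // Holds here because the reserved capacity suffices for the new value.
        assert_eq!(host.as_ptr(), ptr_before);
    }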

View file

@ -24,7 +24,7 @@ impl PgConnectOptions {
let username = url.username(); let username = url.username();
if !username.is_empty() { if !username.is_empty() {
options = options.username( options = options.username(
&*percent_decode_str(username) &percent_decode_str(username)
.decode_utf8() .decode_utf8()
.map_err(Error::config)?, .map_err(Error::config)?,
); );
@ -32,7 +32,7 @@ impl PgConnectOptions {
if let Some(password) = url.password() { if let Some(password) = url.password() {
options = options.password( options = options.password(
&*percent_decode_str(password) &percent_decode_str(password)
.decode_utf8() .decode_utf8()
.map_err(Error::config)?, .map_err(Error::config)?,
); );
@ -63,32 +63,32 @@ impl PgConnectOptions {
} }
"host" => { "host" => {
if value.starts_with("/") { if value.starts_with('/') {
options = options.socket(&*value); options = options.socket(&*value);
} else { } else {
options = options.host(&*value); options = options.host(&value);
} }
} }
"hostaddr" => { "hostaddr" => {
value.parse::<IpAddr>().map_err(Error::config)?; value.parse::<IpAddr>().map_err(Error::config)?;
options = options.host(&*value) options = options.host(&value)
} }
"port" => options = options.port(value.parse().map_err(Error::config)?), "port" => options = options.port(value.parse().map_err(Error::config)?),
"dbname" => options = options.database(&*value), "dbname" => options = options.database(&value),
"user" => options = options.username(&*value), "user" => options = options.username(&value),
"password" => options = options.password(&*value), "password" => options = options.password(&value),
"application_name" => options = options.application_name(&*value), "application_name" => options = options.application_name(&value),
"options" => { "options" => {
if let Some(options) = options.options.as_mut() { if let Some(options) = options.options.as_mut() {
options.push(' '); options.push(' ');
options.push_str(&*value); options.push_str(&value);
} else { } else {
options.options = Some(value.to_string()); options.options = Some(value.to_string());
} }
@ -112,7 +112,7 @@ impl PgConnectOptions {
pub(crate) fn build_url(&self) -> Url { pub(crate) fn build_url(&self) -> Url {
let host = match &self.socket { let host = match &self.socket {
Some(socket) => { Some(socket) => {
utf8_percent_encode(&*socket.to_string_lossy(), NON_ALPHANUMERIC).to_string() utf8_percent_encode(&socket.to_string_lossy(), NON_ALPHANUMERIC).to_string()
} }
None => self.host.to_owned(), None => self.host.to_owned(),
}; };
@ -124,12 +124,12 @@ impl PgConnectOptions {
.expect("BUG: generated un-parseable URL"); .expect("BUG: generated un-parseable URL");
if let Some(password) = &self.password { if let Some(password) = &self.password {
let password = utf8_percent_encode(&password, NON_ALPHANUMERIC).to_string(); let password = utf8_percent_encode(password, NON_ALPHANUMERIC).to_string();
let _ = url.set_password(Some(&password)); let _ = url.set_password(Some(&password));
} }
if let Some(database) = &self.database { if let Some(database) = &self.database {
url.set_path(&database); url.set_path(database);
} }
let ssl_mode = match self.ssl_mode { let ssl_mode = match self.ssl_mode {

View file

@ -41,7 +41,14 @@ fn load_password_from_file(
username: &str, username: &str,
database: Option<&str>, database: Option<&str>,
) -> Option<String> { ) -> Option<String> {
let file = File::open(&path).ok()?; let file = File::open(&path)
.map_err(|e| {
tracing::warn!(
path = %path.display(),
"Failed to open `.pgpass` file: {e:?}",
);
})
.ok()?;
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
{ {
@ -54,7 +61,7 @@ fn load_password_from_file(
let mode = permissions.mode(); let mode = permissions.mode();
if mode & 0o77 != 0 { if mode & 0o77 != 0 {
tracing::warn!( tracing::warn!(
path = %path.to_string_lossy(), path = %path.display(),
permissions = format!("{mode:o}"), permissions = format!("{mode:o}"),
"Ignoring path. Permissions are not strict enough", "Ignoring path. Permissions are not strict enough",
); );
@ -184,7 +191,7 @@ fn find_next_field<'a>(line: &mut &'a str) -> Option<Cow<'a, str>> {
} }
} }
return None; None
} }
#[cfg(test)] #[cfg(test)]

View file

@ -4,7 +4,7 @@ use std::str::FromStr;
/// Options for controlling the level of protection provided for PostgreSQL SSL connections. /// Options for controlling the level of protection provided for PostgreSQL SSL connections.
/// ///
/// It is used by the [`ssl_mode`](super::PgConnectOptions::ssl_mode) method. /// It is used by the [`ssl_mode`](super::PgConnectOptions::ssl_mode) method.
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy, Default)]
pub enum PgSslMode { pub enum PgSslMode {
/// Only try a non-SSL connection. /// Only try a non-SSL connection.
Disable, Disable,
@ -13,6 +13,9 @@ pub enum PgSslMode {
Allow, Allow,
/// First try an SSL connection; if that fails, try a non-SSL connection. /// First try an SSL connection; if that fails, try a non-SSL connection.
///
/// This is the default if no other mode is specified.
#[default]
Prefer, Prefer,
/// Only try an SSL connection. If a root CA file is present, verify the connection /// Only try an SSL connection. If a root CA file is present, verify the connection
@ -28,12 +31,6 @@ pub enum PgSslMode {
VerifyFull, VerifyFull,
} }
impl Default for PgSslMode {
fn default() -> Self {
PgSslMode::Prefer
}
}
impl FromStr for PgSslMode { impl FromStr for PgSslMode {
type Err = Error; type Err = Error;
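The `Default` change here uses the `#[default]` variant attribute (stable since Rust 1.62), which lets `#[derive(Default)]` work on enums and makes the hand-written impl flagged by clippy's `derivable_impls` unnecessary; the SQLite options further down get the same treatment. A standalone sketch with a stand-in enum:

    #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
    enum SslMode {
        Disable,
        Allow,
        // The derive picks this variant as the default.
        #[default]
        Prefer,
        Require,
    }

    fn main() {
        assert_eq!(SslMode::default(), SslMode::Prefer);
    }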

View file

@ -47,7 +47,7 @@ impl ColumnIndex<PgRow> for &'_ str {
.column_names .column_names
.get(*self) .get(*self)
.ok_or_else(|| Error::ColumnNotFound((*self).into())) .ok_or_else(|| Error::ColumnNotFound((*self).into()))
.map(|v| *v) .copied()
} }
} }

View file

@ -56,7 +56,7 @@ impl ColumnIndex<PgStatement<'_>> for &'_ str {
.column_names .column_names
.get(*self) .get(*self)
.ok_or_else(|| Error::ColumnNotFound((*self).into())) .ok_or_else(|| Error::ColumnNotFound((*self).into()))
.map(|v| *v) .copied()
} }
} }

View file

@ -26,10 +26,7 @@ static DO_CLEANUP: AtomicBool = AtomicBool::new(true);
impl TestSupport for Postgres { impl TestSupport for Postgres {
fn test_context(args: &TestArgs) -> BoxFuture<'_, Result<TestContext<Self>, Error>> { fn test_context(args: &TestArgs) -> BoxFuture<'_, Result<TestContext<Self>, Error>> {
Box::pin(async move { Box::pin(async move { test_context(args).await })
let res = test_context(args).await;
res
})
} }
fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> { fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> {
@ -44,7 +41,7 @@ impl TestSupport for Postgres {
.await?; .await?;
query("delete from _sqlx_test.databases where db_name = $1") query("delete from _sqlx_test.databases where db_name = $1")
.bind(&db_name) .bind(db_name)
.execute(&mut *conn) .execute(&mut *conn)
.await?; .await?;
@ -157,7 +154,7 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<Postgres>, Error> {
returning db_name returning db_name
"#, "#,
) )
.bind(&args.test_path) .bind(args.test_path)
.fetch_one(&mut *conn) .fetch_one(&mut *conn)
.await?; .await?;
@ -190,7 +187,7 @@ async fn do_cleanup(conn: &mut PgConnection, created_before: Duration) -> Result
"select db_name from _sqlx_test.databases \ "select db_name from _sqlx_test.databases \
where created_at < (to_timestamp($1) at time zone 'UTC')", where created_at < (to_timestamp($1) at time zone 'UTC')",
) )
.bind(&created_before) .bind(created_before)
.fetch_all(&mut *conn) .fetch_all(&mut *conn)
.await?; .await?;

View file

@ -561,7 +561,7 @@ impl PgType {
PgType::Money => "MONEY", PgType::Money => "MONEY",
PgType::MoneyArray => "MONEY[]", PgType::MoneyArray => "MONEY[]",
PgType::Void => "VOID", PgType::Void => "VOID",
PgType::Custom(ty) => &*ty.name, PgType::Custom(ty) => &ty.name,
PgType::DeclareWithOid(_) => "?", PgType::DeclareWithOid(_) => "?",
PgType::DeclareWithName(name) => name, PgType::DeclareWithName(name) => name,
} }
@ -661,7 +661,7 @@ impl PgType {
PgType::Money => "money", PgType::Money => "money",
PgType::MoneyArray => "_money", PgType::MoneyArray => "_money",
PgType::Void => "void", PgType::Void => "void",
PgType::Custom(ty) => &*ty.name, PgType::Custom(ty) => &ty.name,
PgType::DeclareWithOid(_) => "?", PgType::DeclareWithOid(_) => "?",
PgType::DeclareWithName(name) => name, PgType::DeclareWithName(name) => name,
} }

View file

@ -156,7 +156,7 @@ where
T: Encode<'q, Postgres> + Type<Postgres>, T: Encode<'q, Postgres> + Type<Postgres>,
{ {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
let type_info = if self.len() < 1 { let type_info = if self.is_empty() {
T::type_info() T::type_info()
} else { } else {
self[0].produces().unwrap_or_else(T::type_info) self[0].produces().unwrap_or_else(T::type_info)

View file

@ -66,7 +66,7 @@ impl Decode<'_, Postgres> for BitVec {
))?; ))?;
} }
let mut bitvec = BitVec::from_bytes(&bytes); let mut bitvec = BitVec::from_bytes(bytes);
// Chop off zeroes from the back. We get bits in bytes, so if // Chop off zeroes from the back. We get bits in bytes, so if
// our bitvec is not in full bytes, extra zeroes are added to // our bitvec is not in full bytes, extra zeroes are added to

View file

@ -24,7 +24,7 @@ impl Encode<'_, Postgres> for NaiveDate {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
// DATE is encoded as the days since epoch // DATE is encoded as the days since epoch
let days = (*self - postgres_epoch_date()).num_days() as i32; let days = (*self - postgres_epoch_date()).num_days() as i32;
Encode::<Postgres>::encode(&days, buf) Encode::<Postgres>::encode(days, buf)
} }
fn size_hint(&self) -> usize { fn size_hint(&self) -> usize {

View file

@ -35,11 +35,11 @@ impl<Tz: TimeZone> PgHasArrayType for DateTime<Tz> {
impl Encode<'_, Postgres> for NaiveDateTime { impl Encode<'_, Postgres> for NaiveDateTime {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
// TIMESTAMP is encoded as the microseconds since the epoch // TIMESTAMP is encoded as the microseconds since the epoch
let us = (*self - postgres_epoch_datetime()) let micros = (*self - postgres_epoch_datetime())
.num_microseconds() .num_microseconds()
.ok_or_else(|| format!("NaiveDateTime out of range for Postgres: {self:?}"))?; .ok_or_else(|| format!("NaiveDateTime out of range for Postgres: {self:?}"))?;
Encode::<Postgres>::encode(&us, buf) Encode::<Postgres>::encode(micros, buf)
} }
fn size_hint(&self) -> usize { fn size_hint(&self) -> usize {

View file

@ -21,11 +21,11 @@ impl PgHasArrayType for NaiveTime {
impl Encode<'_, Postgres> for NaiveTime { impl Encode<'_, Postgres> for NaiveTime {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
// TIME is encoded as the microseconds since midnight // TIME is encoded as the microseconds since midnight
let us = (*self - NaiveTime::default()) let micros = (*self - NaiveTime::default())
.num_microseconds() .num_microseconds()
.ok_or_else(|| format!("Time out of range for PostgreSQL: {self}"))?; .ok_or_else(|| format!("Time out of range for PostgreSQL: {self}"))?;
Encode::<Postgres>::encode(&us, buf) Encode::<Postgres>::encode(micros, buf)
} }
fn size_hint(&self) -> usize { fn size_hint(&self) -> usize {

View file

@ -167,7 +167,7 @@ impl TryFrom<chrono::Duration> for PgInterval {
Ok(Self { Ok(Self {
months: 0, months: 0,
days: 0, days: 0,
microseconds: microseconds, microseconds,
}) })
}, },
) )

View file

@ -75,6 +75,9 @@ impl PgLQuery {
} }
/// creates lquery from an iterator with checking labels /// creates lquery from an iterator with checking labels
// TODO: this should just be removed but I didn't want to bury it in a massive diff
#[deprecated = "renamed to `try_from_iter()`"]
#[allow(clippy::should_implement_trait)]
pub fn from_iter<I, S>(levels: I) -> Result<Self, PgLQueryParseError> pub fn from_iter<I, S>(levels: I) -> Result<Self, PgLQueryParseError>
where where
S: Into<String>, S: Into<String>,
@ -86,6 +89,26 @@ impl PgLQuery {
} }
Ok(lquery) Ok(lquery)
} }
/// Create an `LQUERY` from an iterator of label strings.
///
/// Returns an error if any label fails to parse according to [`PgLQueryLevel::from_str()`].
pub fn try_from_iter<I, S>(levels: I) -> Result<Self, PgLQueryParseError>
where
S: AsRef<str>,
I: IntoIterator<Item = S>,
{
levels
.into_iter()
.map(|level| level.as_ref().parse::<PgLQueryLevel>())
.collect()
}
}
impl FromIterator<PgLQueryLevel> for PgLQuery {
fn from_iter<T: IntoIterator<Item = PgLQueryLevel>>(iter: T) -> Self {
Self::from(iter.into_iter().collect())
}
} }
impl IntoIterator for PgLQuery { impl IntoIterator for PgLQuery {
@ -104,7 +127,7 @@ impl FromStr for PgLQuery {
Ok(Self { Ok(Self {
levels: s levels: s
.split('.') .split('.')
.map(|s| PgLQueryLevel::from_str(s)) .map(PgLQueryLevel::from_str)
.collect::<Result<_, Self::Err>>()?, .collect::<Result<_, Self::Err>>()?,
}) })
} }
@ -244,12 +267,12 @@ impl FromStr for PgLQueryLevel {
b'!' => Ok(PgLQueryLevel::NotNonStar( b'!' => Ok(PgLQueryLevel::NotNonStar(
s[1..] s[1..]
.split('|') .split('|')
.map(|s| PgLQueryVariant::from_str(s)) .map(PgLQueryVariant::from_str)
.collect::<Result<Vec<_>, PgLQueryParseError>>()?, .collect::<Result<Vec<_>, PgLQueryParseError>>()?,
)), )),
_ => Ok(PgLQueryLevel::NonStar( _ => Ok(PgLQueryLevel::NonStar(
s.split('|') s.split('|')
.map(|s| PgLQueryVariant::from_str(s)) .map(PgLQueryVariant::from_str)
.collect::<Result<Vec<_>, PgLQueryParseError>>()?, .collect::<Result<Vec<_>, PgLQueryParseError>>()?,
)), )),
} }
@ -262,10 +285,9 @@ impl FromStr for PgLQueryVariant {
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
let mut label_length = s.len(); let mut label_length = s.len();
let mut rev_iter = s.bytes().rev();
let mut modifiers = PgLQueryVariantFlag::empty(); let mut modifiers = PgLQueryVariantFlag::empty();
while let Some(b) = rev_iter.next() { for b in s.bytes().rev() {
match b { match b {
b'@' => modifiers.insert(PgLQueryVariantFlag::IN_CASE), b'@' => modifiers.insert(PgLQueryVariantFlag::IN_CASE),
b'*' => modifiers.insert(PgLQueryVariantFlag::ANY_END), b'*' => modifiers.insert(PgLQueryVariantFlag::ANY_END),
@ -306,8 +328,8 @@ impl Display for PgLQueryLevel {
PgLQueryLevel::Star(Some(at_least), _) => write!(f, "*{{{at_least},}}"), PgLQueryLevel::Star(Some(at_least), _) => write!(f, "*{{{at_least},}}"),
PgLQueryLevel::Star(_, Some(at_most)) => write!(f, "*{{,{at_most}}}"), PgLQueryLevel::Star(_, Some(at_most)) => write!(f, "*{{,{at_most}}}"),
PgLQueryLevel::Star(_, _) => write!(f, "*"), PgLQueryLevel::Star(_, _) => write!(f, "*"),
PgLQueryLevel::NonStar(variants) => write_variants(f, &variants, false), PgLQueryLevel::NonStar(variants) => write_variants(f, variants, false),
PgLQueryLevel::NotNonStar(variants) => write_variants(f, &variants, true), PgLQueryLevel::NotNonStar(variants) => write_variants(f, variants, true),
} }
} }
} }
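The `from_iter` deprecation above addresses clippy's `should_implement_trait`: an inherent `from_iter` shadows `FromIterator::from_iter` at call sites. The new split gives the fallible path its own name (`try_from_iter`) and reserves the trait impl for the infallible case; note that `try_from_iter` can end in a bare `.collect()` precisely because the `FromIterator` impl exists. A simplified, self-contained sketch with stand-in types:

    use std::str::FromStr;

    #[derive(Debug, PartialEq)]
    struct Level(String);

    impl FromStr for Level {
        type Err = String;
        fn from_str(s: &str) -> Result<Self, Self::Err> {
            if s.is_empty() {
                Err("empty level".into())
            } else {
                Ok(Level(s.to_owned()))
            }
        }
    }

    #[derive(Debug, PartialEq)]
    struct Query {
        levels: Vec<Level>,
    }

    // Infallible construction is the real trait impl...
    impl FromIterator<Level> for Query {
        fn from_iter<T: IntoIterator<Item = Level>>(iter: T) -> Self {
            Query {
                levels: iter.into_iter().collect(),
            }
        }
    }

    impl Query {
        // ...and the fallible constructor gets a distinct name. `collect()`
        // builds `Result<Query, _>` via the `FromIterator` impl above.
        fn try_from_iter<I, S>(levels: I) -> Result<Self, String>
        where
            S: AsRef<str>,
            I: IntoIterator<Item = S>,
        {
            levels
                .into_iter()
                .map(|level| level.as_ref().parse::<Level>())
                .collect()
        }
    }

    fn main() {
        assert!(Query::try_from_iter(["Top", "Science"]).is_ok());
        assert!(Query::try_from_iter(["Top", ""]).is_err());
    }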

View file

@ -27,9 +27,9 @@ pub struct PgLTreeLabel(String);
impl PgLTreeLabel { impl PgLTreeLabel {
pub fn new<S>(label: S) -> Result<Self, PgLTreeParseError> pub fn new<S>(label: S) -> Result<Self, PgLTreeParseError>
where where
String: From<S>, S: Into<String>,
{ {
let label = String::from(label); let label = label.into();
if label.len() <= 256 if label.len() <= 256
&& label && label
.bytes() .bytes()
@ -101,6 +101,9 @@ impl PgLTree {
} }
/// creates ltree from an iterator with checking labels /// creates ltree from an iterator with checking labels
// TODO: this should just be removed but I didn't want to bury it in a massive diff
#[deprecated = "renamed to `try_from_iter()`"]
#[allow(clippy::should_implement_trait)]
pub fn from_iter<I, S>(labels: I) -> Result<Self, PgLTreeParseError> pub fn from_iter<I, S>(labels: I) -> Result<Self, PgLTreeParseError>
where where
String: From<S>, String: From<S>,
@ -113,6 +116,17 @@ impl PgLTree {
Ok(ltree) Ok(ltree)
} }
/// Create an `LTREE` from an iterator of label strings.
///
/// Returns an error if any label fails to parse according to [`PgLTreeLabel::new()`].
pub fn try_from_iter<I, S>(labels: I) -> Result<Self, PgLTreeParseError>
where
S: Into<String>,
I: IntoIterator<Item = S>,
{
labels.into_iter().map(PgLTreeLabel::new).collect()
}
/// push a label to ltree /// push a label to ltree
pub fn push(&mut self, label: PgLTreeLabel) { pub fn push(&mut self, label: PgLTreeLabel) {
self.labels.push(label); self.labels.push(label);
@ -124,6 +138,14 @@ impl PgLTree {
} }
} }
impl FromIterator<PgLTreeLabel> for PgLTree {
fn from_iter<T: IntoIterator<Item = PgLTreeLabel>>(iter: T) -> Self {
Self {
labels: iter.into_iter().collect(),
}
}
}
impl IntoIterator for PgLTree { impl IntoIterator for PgLTree {
type Item = PgLTreeLabel; type Item = PgLTreeLabel;
type IntoIter = std::vec::IntoIter<Self::Item>; type IntoIter = std::vec::IntoIter<Self::Item>;
@ -140,7 +162,7 @@ impl FromStr for PgLTree {
Ok(Self { Ok(Self {
labels: s labels: s
.split('.') .split('.')
.map(|s| PgLTreeLabel::new(s)) .map(PgLTreeLabel::new)
.collect::<Result<Vec<_>, Self::Err>>()?, .collect::<Result<Vec<_>, Self::Err>>()?,
}) })
} }

View file

@ -261,7 +261,7 @@ fn array_compatible<E: Type<Postgres> + ?Sized>(ty: &PgTypeInfo) -> bool {
// we require the declared type to be an _array_ with an // we require the declared type to be an _array_ with an
// element type that is acceptable // element type that is acceptable
if let PgTypeKind::Array(element) = &ty.kind() { if let PgTypeKind::Array(element) = &ty.kind() {
return E::compatible(&element); return E::compatible(element);
} }
false false

View file

@ -445,7 +445,7 @@ where
} }
count += 1; count += 1;
if !(element.is_empty() && !quoted) { if !element.is_empty() || quoted {
let value = Some(T::decode(PgValueRef { let value = Some(T::decode(PgValueRef {
type_info: T::type_info(), type_info: T::type_info(),
format: PgValueFormat::Text, format: PgValueFormat::Text,
@ -515,7 +515,7 @@ fn range_compatible<E: Type<Postgres>>(ty: &PgTypeInfo) -> bool {
// we require the declared type to be a _range_ with an // we require the declared type to be a _range_ with an
// element type that is acceptable // element type that is acceptable
if let PgTypeKind::Range(element) = &ty.kind() { if let PgTypeKind::Range(element) = &ty.kind() {
return E::compatible(&element); return E::compatible(element);
} }
false false

View file

@ -103,9 +103,9 @@ impl From<&'_ Decimal> for PgNumeric {
let groups_diff = scale % 4; let groups_diff = scale % 4;
if groups_diff > 0 { if groups_diff > 0 {
let remainder = 4 - groups_diff as u32; let remainder = 4 - groups_diff as u32;
let power = 10u32.pow(remainder as u32) as u128; let power = 10u32.pow(remainder) as u128;
mantissa = mantissa * power; mantissa *= power;
} }
// Array to store max mantissa of Decimal in Postgres decimal format. // Array to store max mantissa of Decimal in Postgres decimal format.
@ -121,7 +121,7 @@ impl From<&'_ Decimal> for PgNumeric {
digits.reverse(); digits.reverse();
// Weight is number of digits on the left side of the decimal. // Weight is number of digits on the left side of the decimal.
let digits_after_decimal = (scale + 3) as u16 / 4; let digits_after_decimal = (scale + 3) / 4;
let weight = digits.len() as i16 - digits_after_decimal as i16 - 1; let weight = digits.len() as i16 - digits_after_decimal as i16 - 1;
// Remove non-significant zeroes. // Remove non-significant zeroes.

View file

@ -125,7 +125,7 @@ impl Encode<'_, Postgres> for String {
impl<'r> Decode<'r, Postgres> for &'r str { impl<'r> Decode<'r, Postgres> for &'r str {
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> { fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(value.as_str()?) value.as_str()
} }
} }

View file

@ -24,7 +24,7 @@ impl Encode<'_, Postgres> for Date {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
// DATE is encoded as the days since epoch // DATE is encoded as the days since epoch
let days = (*self - PG_EPOCH).whole_days() as i32; let days = (*self - PG_EPOCH).whole_days() as i32;
Encode::<Postgres>::encode(&days, buf) Encode::<Postgres>::encode(days, buf)
} }
fn size_hint(&self) -> usize { fn size_hint(&self) -> usize {

View file

@ -37,8 +37,8 @@ impl PgHasArrayType for OffsetDateTime {
impl Encode<'_, Postgres> for PrimitiveDateTime { impl Encode<'_, Postgres> for PrimitiveDateTime {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
// TIMESTAMP is encoded as the microseconds since the epoch // TIMESTAMP is encoded as the microseconds since the epoch
let us = (*self - PG_EPOCH.midnight()).whole_microseconds() as i64; let micros = (*self - PG_EPOCH.midnight()).whole_microseconds() as i64;
Encode::<Postgres>::encode(&us, buf) Encode::<Postgres>::encode(micros, buf)
} }
fn size_hint(&self) -> usize { fn size_hint(&self) -> usize {
@ -69,10 +69,10 @@ impl<'r> Decode<'r, Postgres> for PrimitiveDateTime {
// This is given for timestamptz for some reason // This is given for timestamptz for some reason
// Postgres already guarantees this to always be UTC // Postgres already guarantees this to always be UTC
if s.contains('+') { if s.contains('+') {
PrimitiveDateTime::parse(&*s, &format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond][offset_hour]"))? PrimitiveDateTime::parse(&s, &format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond][offset_hour]"))?
} else { } else {
PrimitiveDateTime::parse( PrimitiveDateTime::parse(
&*s, &s,
&format_description!( &format_description!(
"[year]-[month]-[day] [hour]:[minute]:[second].[subsecond]" "[year]-[month]-[day] [hour]:[minute]:[second].[subsecond]"
), ),
@ -88,7 +88,7 @@ impl Encode<'_, Postgres> for OffsetDateTime {
let utc = self.to_offset(offset!(UTC)); let utc = self.to_offset(offset!(UTC));
let primitive = PrimitiveDateTime::new(utc.date(), utc.time()); let primitive = PrimitiveDateTime::new(utc.date(), utc.time());
Encode::<Postgres>::encode(&primitive, buf) Encode::<Postgres>::encode(primitive, buf)
} }
fn size_hint(&self) -> usize { fn size_hint(&self) -> usize {

View file

@ -1,5 +1,8 @@
mod date; mod date;
mod datetime; mod datetime;
// Parent module is named after the `time` crate, this module is named after the `TIME` SQL type.
#[allow(clippy::module_inception)]
mod time; mod time;
#[rustfmt::skip] #[rustfmt::skip]

View file

@ -22,8 +22,8 @@ impl PgHasArrayType for Time {
impl Encode<'_, Postgres> for Time { impl Encode<'_, Postgres> for Time {
fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
// TIME is encoded as the microseconds since midnight // TIME is encoded as the microseconds since midnight
let us = (*self - Time::MIDNIGHT).whole_microseconds() as i64; let micros = (*self - Time::MIDNIGHT).whole_microseconds() as i64;
Encode::<Postgres>::encode(&us, buf) Encode::<Postgres>::encode(micros, buf)
} }
fn size_hint(&self) -> usize { fn size_hint(&self) -> usize {

View file

@ -88,38 +88,36 @@ mod chrono {
Ok(PgTimeTz { time, offset }) Ok(PgTimeTz { time, offset })
} }
PgValueFormat::Text => { PgValueFormat::Text => try_parse_timetz(value.as_str()?),
let s = value.as_str()?; }
}
}
let mut tmp = String::with_capacity(11 + s.len()); fn try_parse_timetz(s: &str) -> Result<PgTimeTz<NaiveTime, FixedOffset>, BoxDynError> {
tmp.push_str("2001-07-08 "); let mut tmp = String::with_capacity(11 + s.len());
tmp.push_str(s); tmp.push_str("2001-07-08 ");
tmp.push_str(s);
let dt = 'out: loop { let mut err = None;
let mut err = None;
for fmt in &["%Y-%m-%d %H:%M:%S%.f%#z", "%Y-%m-%d %H:%M:%S%.f"] {
match DateTime::parse_from_str(&tmp, fmt) {
Ok(dt) => {
break 'out dt;
}
Err(error) => {
err = Some(error);
}
}
}
return Err(err.unwrap().into());
};
for fmt in &["%Y-%m-%d %H:%M:%S%.f%#z", "%Y-%m-%d %H:%M:%S%.f"] {
match DateTime::parse_from_str(&tmp, fmt) {
Ok(dt) => {
let time = dt.time(); let time = dt.time();
let offset = *dt.offset(); let offset = *dt.offset();
Ok(PgTimeTz { time, offset }) return Ok(PgTimeTz { time, offset });
}
Err(error) => {
err = Some(error);
} }
} }
} }
Err(err
.expect("BUG: loop should have set `err` to `Some()` before exiting")
.into())
} }
} }
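Hoisting the body into `try_parse_timetz` removes the `loop`/`break 'out` contortion: a plain `for` with an early `return` expresses "try each format, keep the last error". A runnable sketch of that shape, assuming `chrono` as in the diff:

    use chrono::{DateTime, FixedOffset};

    fn parse_with_fallbacks(s: &str) -> Result<DateTime<FixedOffset>, chrono::ParseError> {
        let mut err = None;
        for fmt in &["%Y-%m-%d %H:%M:%S%.f%#z", "%Y-%m-%d %H:%M:%S%.f"] {
            match DateTime::parse_from_str(s, fmt) {
                // The first format that matches wins.
                Ok(dt) => return Ok(dt),
                Err(e) => err = Some(e),
            }
        }
        // The format list is non-empty, so `err` is always `Some` here.
        Err(err.expect("format list is non-empty"))
    }

    fn main() {
        let dt = parse_with_fallbacks("2001-07-08 12:30:00.0+02").unwrap();
        assert_eq!(dt.offset().local_minus_utc(), 2 * 3600);
    }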

View file

@ -19,7 +19,7 @@ impl Column for SqliteColumn {
} }
fn name(&self) -> &str { fn name(&self) -> &str {
&*self.name &self.name
} }
fn type_info(&self) -> &SqliteTypeInfo { fn type_info(&self) -> &SqliteTypeInfo {

View file

@ -15,6 +15,7 @@ use crate::SqliteError;
#[derive(Clone)] #[derive(Clone)]
pub struct Collation { pub struct Collation {
name: Arc<str>, name: Arc<str>,
#[allow(clippy::type_complexity)]
collate: Arc<dyn Fn(&str, &str) -> Ordering + Send + Sync + 'static>, collate: Arc<dyn Fn(&str, &str) -> Ordering + Send + Sync + 'static>,
// SAFETY: these must match the concrete type of `collate` // SAFETY: these must match the concrete type of `collate`
call: unsafe extern "C" fn( call: unsafe extern "C" fn(

View file

@ -24,6 +24,7 @@ use std::time::Duration;
// https://doc.rust-lang.org/stable/std/sync/atomic/index.html#portability // https://doc.rust-lang.org/stable/std/sync/atomic/index.html#portability
static THREAD_ID: AtomicUsize = AtomicUsize::new(0); static THREAD_ID: AtomicUsize = AtomicUsize::new(0);
#[derive(Copy, Clone)]
enum SqliteLoadExtensionMode { enum SqliteLoadExtensionMode {
/// Enables only the C-API, leaving the SQL function disabled. /// Enables only the C-API, leaving the SQL function disabled.
Enable, Enable,
@ -32,7 +33,7 @@ enum SqliteLoadExtensionMode {
} }
impl SqliteLoadExtensionMode { impl SqliteLoadExtensionMode {
fn as_int(self) -> c_int { fn to_int(self) -> c_int {
match self { match self {
SqliteLoadExtensionMode::Enable => 1, SqliteLoadExtensionMode::Enable => 1,
SqliteLoadExtensionMode::DisableAll => 0, SqliteLoadExtensionMode::DisableAll => 0,
@ -101,13 +102,13 @@ impl EstablishParams {
} }
if let Some(vfs) = options.vfs.as_deref() { if let Some(vfs) = options.vfs.as_deref() {
query_params.insert("vfs", &vfs); query_params.insert("vfs", vfs);
} }
if !query_params.is_empty() { if !query_params.is_empty() {
filename = format!( filename = format!(
"file:{}?{}", "file:{}?{}",
percent_encoding::percent_encode(filename.as_bytes(), &NON_ALPHANUMERIC), percent_encoding::percent_encode(filename.as_bytes(), NON_ALPHANUMERIC),
serde_urlencoded::to_string(&query_params).unwrap() serde_urlencoded::to_string(&query_params).unwrap()
); );
flags |= libsqlite3_sys::SQLITE_OPEN_URI; flags |= libsqlite3_sys::SQLITE_OPEN_URI;
@ -174,7 +175,7 @@ impl EstablishParams {
let status = sqlite3_db_config( let status = sqlite3_db_config(
db, db,
SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION,
mode.as_int(), mode.to_int(),
null::<i32>(), null::<i32>(),
); );
@ -294,7 +295,7 @@ impl EstablishParams {
transaction_depth: 0, transaction_depth: 0,
log_settings: self.log_settings.clone(), log_settings: self.log_settings.clone(),
progress_handler_callback: None, progress_handler_callback: None,
update_hook_callback: None update_hook_callback: None,
}) })
} }
} }

View file

@ -68,10 +68,10 @@ impl Iterator for ExecuteIter<'_> {
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
let statement = if self.goto_next { let statement = if self.goto_next {
let mut statement = match self.statement.prepare_next(self.handle) { let statement = match self.statement.prepare_next(self.handle) {
Ok(Some(statement)) => statement, Ok(Some(statement)) => statement,
Ok(None) => return None, Ok(None) => return None,
Err(e) => return Some(Err(e.into())), Err(e) => return Some(Err(e)),
}; };
self.goto_next = false; self.goto_next = false;
@ -83,7 +83,7 @@ impl Iterator for ExecuteIter<'_> {
statement.handle.clear_bindings(); statement.handle.clear_bindings();
match bind(&mut statement.handle, &self.args, self.args_used) { match bind(statement.handle, &self.args, self.args_used) {
Ok(args_used) => self.args_used += args_used, Ok(args_used) => self.args_used += args_used,
Err(e) => return Some(Err(e)), Err(e) => return Some(Err(e)),
} }
@ -98,9 +98,9 @@ impl Iterator for ExecuteIter<'_> {
self.logger.increment_rows_returned(); self.logger.increment_rows_returned();
Some(Ok(Either::Right(SqliteRow::current( Some(Ok(Either::Right(SqliteRow::current(
&statement.handle, statement.handle,
&statement.columns, statement.columns,
&statement.column_names, statement.column_names,
)))) ))))
} }
Ok(false) => { Ok(false) => {

View file

@ -13,13 +13,15 @@ use std::future;
impl<'c> Executor<'c> for &'c mut SqliteConnection { impl<'c> Executor<'c> for &'c mut SqliteConnection {
type Database = Sqlite; type Database = Sqlite;
fn fetch_many<'e, 'q: 'e, E: 'q>( fn fetch_many<'e, 'q, E>(
self, self,
mut query: E, mut query: E,
) -> BoxStream<'e, Result<Either<SqliteQueryResult, SqliteRow>, Error>> ) -> BoxStream<'e, Result<Either<SqliteQueryResult, SqliteRow>, Error>>
where where
'c: 'e, 'c: 'e,
E: Execute<'q, Self::Database>, E: Execute<'q, Self::Database>,
'q: 'e,
E: 'q,
{ {
let sql = query.sql(); let sql = query.sql();
let arguments = match query.take_arguments().map_err(Error::Encode) { let arguments = match query.take_arguments().map_err(Error::Encode) {
@ -36,13 +38,15 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection {
) )
} }
fn fetch_optional<'e, 'q: 'e, E: 'q>( fn fetch_optional<'e, 'q, E>(
self, self,
mut query: E, mut query: E,
) -> BoxFuture<'e, Result<Option<SqliteRow>, Error>> ) -> BoxFuture<'e, Result<Option<SqliteRow>, Error>>
where where
'c: 'e, 'c: 'e,
E: Execute<'q, Self::Database>, E: Execute<'q, Self::Database>,
'q: 'e,
E: 'q,
{ {
let sql = query.sql(); let sql = query.sql();
let arguments = match query.take_arguments().map_err(Error::Encode) { let arguments = match query.take_arguments().map_err(Error::Encode) {

View file

@ -160,7 +160,7 @@ impl ColumnType {
} }
fn map_to_datatype(&self) -> DataType { fn map_to_datatype(&self) -> DataType {
match self { match self {
Self::Single { datatype, .. } => datatype.clone(), Self::Single { datatype, .. } => *datatype,
Self::Record(_) => DataType::Null, //If we're trying to coerce to a regular Datatype, we can assume a Record is invalid for the context Self::Record(_) => DataType::Null, //If we're trying to coerce to a regular Datatype, we can assume a Record is invalid for the context
} }
} }
@ -188,7 +188,7 @@ impl core::fmt::Debug for ColumnType {
let mut column_iter = columns.iter(); let mut column_iter = columns.iter();
if let Some(item) = column_iter.next() { if let Some(item) = column_iter.next() {
write!(f, "{:?}", item)?; write!(f, "{:?}", item)?;
while let Some(item) = column_iter.next() { for item in column_iter {
write!(f, ", {:?}", item)?; write!(f, ", {:?}", item)?;
} }
} }
@ -400,7 +400,7 @@ fn root_block_columns(
); );
} }
return Ok(row_info); Ok(row_info)
} }
struct Sequence(i64); struct Sequence(i64);
@ -544,7 +544,7 @@ impl BranchList {
std::collections::hash_map::Entry::Occupied(entry) => { std::collections::hash_map::Entry::Occupied(entry) => {
//already saw a state identical to this one, so no point in processing it //already saw a state identical to this one, so no point in processing it
state.mem = entry.key().clone(); //replace state.mem since .entry() moved it state.mem = entry.key().clone(); //replace state.mem since .entry() moved it
logger.add_result(state, BranchResult::Dedup(entry.get().clone())); logger.add_result(state, BranchResult::Dedup(*entry.get()));
} }
} }
} }
@ -974,7 +974,7 @@ pub(super) fn explain(
.and_then(|c| c.columns_ref(&state.mem.t, &state.mem.r)) .and_then(|c| c.columns_ref(&state.mem.t, &state.mem.r))
.and_then(|cc| cc.get(&p2)) .and_then(|cc| cc.get(&p2))
.cloned() .cloned()
.unwrap_or_else(|| ColumnType::default()); .unwrap_or_default();
// insert into p3 the datatype of the col // insert into p3 the datatype of the col
state.mem.r.insert(p3, RegDataType::Single(value)); state.mem.r.insert(p3, RegDataType::Single(value));
@ -1123,7 +1123,7 @@ pub(super) fn explain(
OP_OPEN_EPHEMERAL | OP_OPEN_AUTOINDEX | OP_SORTER_OPEN => { OP_OPEN_EPHEMERAL | OP_OPEN_AUTOINDEX | OP_SORTER_OPEN => {
//Create a new pointer which is referenced by p1 //Create a new pointer which is referenced by p1
let table_info = TableDataType { let table_info = TableDataType {
cols: IntMap::from_dense_record(&vec![ColumnType::null(); p2 as usize]), cols: IntMap::from_elem(ColumnType::null(), p2 as usize),
is_empty: Some(true), is_empty: Some(true),
}; };
@ -1376,7 +1376,7 @@ pub(super) fn explain(
state.mem.r.insert( state.mem.r.insert(
p2, p2,
RegDataType::Single(ColumnType::Single { RegDataType::Single(ColumnType::Single {
datatype: opcode_to_type(&opcode), datatype: opcode_to_type(opcode),
nullable: Some(false), nullable: Some(false),
}), }),
); );
@ -1490,8 +1490,7 @@ pub(super) fn explain(
while let Some(result) = result_states.pop() { while let Some(result) = result_states.pop() {
// find the datatype info from each ResultRow execution // find the datatype info from each ResultRow execution
let mut idx = 0; for (idx, this_col) in result.into_iter().enumerate() {
for this_col in result {
let this_type = this_col.map_to_datatype(); let this_type = this_col.map_to_datatype();
let this_nullable = this_col.map_to_nullable(); let this_nullable = this_col.map_to_nullable();
if output.len() == idx { if output.len() == idx {
@ -1513,7 +1512,6 @@ pub(super) fn explain(
} else { } else {
nullable[idx] = this_nullable; nullable[idx] = this_nullable;
} }
idx += 1;
} }
} }

View file

@ -1,3 +1,4 @@
use std::cmp::Ordering;
use std::{fmt::Debug, hash::Hash}; use std::{fmt::Debug, hash::Hash};
/// Simplistic map implementation built on a Vec of Options (index = key) /// Simplistic map implementation built on a Vec of Options (index = key)
@ -65,7 +66,7 @@ impl<V> IntMap<V> {
let item = self.0.get_mut(idx); let item = self.0.get_mut(idx);
match item { match item {
Some(content) => std::mem::replace(content, None), Some(content) => content.take(),
None => None, None => None,
} }
} }
@ -100,7 +101,10 @@ impl<V: Default> IntMap<V> {
} }
impl<V: Clone> IntMap<V> { impl<V: Clone> IntMap<V> {
pub(crate) fn from_dense_record(record: &Vec<V>) -> Self { pub(crate) fn from_elem(elem: V, len: usize) -> Self {
Self(vec![Some(elem); len])
}
pub(crate) fn from_dense_record(record: &[V]) -> Self {
Self(record.iter().cloned().map(Some).collect()) Self(record.iter().cloned().map(Some).collect())
} }
} }
@ -139,21 +143,16 @@ impl<V: Hash> Hash for IntMap<V> {
impl<V: PartialEq> PartialEq for IntMap<V> { impl<V: PartialEq> PartialEq for IntMap<V> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
if !self match self.0.len().cmp(&other.0.len()) {
.0 Ordering::Greater => {
.iter() self.0[..other.0.len()] == other.0
.zip(other.0.iter()) && self.0[other.0.len()..].iter().all(Option::is_none)
.all(|(l, r)| PartialEq::eq(l, r)) }
{ Ordering::Less => {
return false; other.0[..self.0.len()] == self.0
} && other.0[self.0.len()..].iter().all(Option::is_none)
}
if self.0.len() > other.0.len() { Ordering::Equal => self.0 == other.0,
self.0[other.0.len()..].iter().all(Option::is_none)
} else if self.0.len() < other.0.len() {
other.0[self.0.len()..].iter().all(Option::is_none)
} else {
true
} }
} }
} }
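The rewritten `PartialEq` keeps the same semantics — maps are equal when the shared prefix matches and the longer map's tail is all `None` — but states it directly with one `match` on the length comparison instead of a zip followed by length checks. A self-contained sketch demonstrating those semantics:

    use std::cmp::Ordering;

    struct IntMap<V>(Vec<Option<V>>);

    impl<V: PartialEq> PartialEq for IntMap<V> {
        fn eq(&self, other: &Self) -> bool {
            match self.0.len().cmp(&other.0.len()) {
                Ordering::Greater => {
                    self.0[..other.0.len()] == other.0
                        && self.0[other.0.len()..].iter().all(Option::is_none)
                }
                Ordering::Less => {
                    other.0[..self.0.len()] == self.0
                        && other.0[self.0.len()..].iter().all(Option::is_none)
                }
                Ordering::Equal => self.0 == other.0,
            }
        }
    }

    fn main() {
        let a = IntMap(vec![Some(1), None]);
        let b = IntMap(vec![Some(1), None, None, None]);
        // Trailing `None` slots don't affect equality...
        assert!(a == b);
        // ...but differing values in the shared prefix do.
        let c = IntMap(vec![Some(2), None]);
        assert!(a != c);
    }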

View file

@ -4,6 +4,7 @@ use std::fmt::Write;
use std::fmt::{self, Debug, Formatter}; use std::fmt::{self, Debug, Formatter};
use std::os::raw::{c_int, c_void}; use std::os::raw::{c_int, c_void};
use std::panic::catch_unwind; use std::panic::catch_unwind;
use std::ptr;
use std::ptr::NonNull; use std::ptr::NonNull;
use futures_core::future::BoxFuture; use futures_core::future::BoxFuture;
@ -112,7 +113,7 @@ impl ConnectionState {
pub(crate) fn remove_progress_handler(&mut self) { pub(crate) fn remove_progress_handler(&mut self) {
if let Some(mut handler) = self.progress_handler_callback.take() { if let Some(mut handler) = self.progress_handler_callback.take() {
unsafe { unsafe {
sqlite3_progress_handler(self.handle.as_ptr(), 0, None, std::ptr::null_mut()); sqlite3_progress_handler(self.handle.as_ptr(), 0, None, ptr::null_mut());
let _ = { Box::from_raw(handler.0.as_mut()) }; let _ = { Box::from_raw(handler.0.as_mut()) };
} }
} }
@ -121,7 +122,7 @@ impl ConnectionState {
pub(crate) fn remove_update_hook(&mut self) { pub(crate) fn remove_update_hook(&mut self) {
if let Some(mut handler) = self.update_hook_callback.take() { if let Some(mut handler) = self.update_hook_callback.take() {
unsafe { unsafe {
sqlite3_update_hook(self.handle.as_ptr(), None, std::ptr::null_mut()); sqlite3_update_hook(self.handle.as_ptr(), None, ptr::null_mut());
let _ = { Box::from_raw(handler.0.as_mut()) }; let _ = { Box::from_raw(handler.0.as_mut()) };
} }
} }

View file

@ -45,9 +45,9 @@ pub struct QueryPlanLogger<'q, R: Debug + 'static, S: Debug + DebugDiff + 'stati
fn dot_escape_string(value: impl AsRef<str>) -> String { fn dot_escape_string(value: impl AsRef<str>) -> String {
value value
.as_ref() .as_ref()
.replace("\\", "\\\\") .replace('\\', r#"\\"#)
.replace("\"", "'") .replace('"', "'")
.replace("\n", "\\n") .replace('\n', r#"\n"#)
.to_string() .to_string()
} }
@ -76,7 +76,7 @@ impl<R: Debug, S: Debug + DebugDiff, P: Debug> core::fmt::Display for QueryPlanL
let mut instruction_uses: IntMap<Vec<BranchParent>> = Default::default(); let mut instruction_uses: IntMap<Vec<BranchParent>> = Default::default();
for (k, state) in all_states.iter() { for (k, state) in all_states.iter() {
let entry = instruction_uses.get_mut_or_default(&(state.program_i as i64)); let entry = instruction_uses.get_mut_or_default(&(state.program_i as i64));
entry.push(k.clone()); entry.push(*k);
} }
let mut branch_children: std::collections::HashMap<BranchParent, Vec<BranchParent>> = let mut branch_children: std::collections::HashMap<BranchParent, Vec<BranchParent>> =
@ -127,27 +127,27 @@ impl<R: Debug, S: Debug + DebugDiff, P: Debug> core::fmt::Display for QueryPlanL
state_list state_list
.entry(state_diff) .entry(state_diff)
.or_default() .or_default()
.push((curr_ref.clone(), Some(next_ref))); .push((*curr_ref, Some(next_ref)));
} else { } else {
state_list state_list
.entry(Default::default()) .entry(Default::default())
.or_default() .or_default()
.push((curr_ref.clone(), None)); .push((*curr_ref, None));
}; };
if let Some(children) = branch_children.get(curr_ref) { if let Some(children) = branch_children.get(curr_ref) {
for next_ref in children { for next_ref in children {
if let Some(next_state) = all_states.get(&next_ref) { if let Some(next_state) = all_states.get(next_ref) {
let state_diff = next_state.state.diff(&curr_state.state); let state_diff = next_state.state.diff(&curr_state.state);
if !state_diff.is_empty() { if !state_diff.is_empty() {
branched_with_state.insert(next_ref.clone()); branched_with_state.insert(*next_ref);
} }
state_list state_list
.entry(state_diff) .entry(state_diff)
.or_default() .or_default()
.push((curr_ref.clone(), Some(next_ref.clone()))); .push((*curr_ref, Some(*next_ref)));
} }
} }
}; };
@ -176,7 +176,7 @@ impl<R: Debug, S: Debug + DebugDiff, P: Debug> core::fmt::Display for QueryPlanL
for (curr_ref, next_ref) in ref_list { for (curr_ref, next_ref) in ref_list {
if let Some(next_ref) = next_ref { if let Some(next_ref) = next_ref {
let next_program_i = all_states let next_program_i = all_states
.get(&next_ref) .get(next_ref)
.map(|s| s.program_i.to_string()) .map(|s| s.program_i.to_string())
.unwrap_or_default(); .unwrap_or_default();
@ -258,7 +258,7 @@ impl<R: Debug, S: Debug + DebugDiff, P: Debug> core::fmt::Display for QueryPlanL
let mut instruction_list: Vec<(BranchParent, &InstructionHistory<S>)> = Vec::new(); let mut instruction_list: Vec<(BranchParent, &InstructionHistory<S>)> = Vec::new();
if let Some(parent) = self.branch_origins.get(&branch_id) { if let Some(parent) = self.branch_origins.get(&branch_id) {
if let Some(parent_state) = all_states.get(parent) { if let Some(parent_state) = all_states.get(parent) {
instruction_list.push((parent.clone(), parent_state)); instruction_list.push((*parent, parent_state));
} }
} }
if let Some(instructions) = self.branch_operations.get(&branch_id) { if let Some(instructions) = self.branch_operations.get(&branch_id) {
@ -278,11 +278,11 @@ impl<R: Debug, S: Debug + DebugDiff, P: Debug> core::fmt::Display for QueryPlanL
if let Some((cur_ref, _)) = instructions_iter.next() { if let Some((cur_ref, _)) = instructions_iter.next() {
let mut prev_ref = cur_ref; let mut prev_ref = cur_ref;
while let Some((cur_ref, _)) = instructions_iter.next() { for (cur_ref, _) in instructions_iter {
if branched_with_state.contains(&cur_ref) { if branched_with_state.contains(&cur_ref) {
write!( writeln!(
f, f,
"\"b{}p{}\" -> \"b{}p{}_b{}p{}\" -> \"b{}p{}\"\n", "\"b{}p{}\" -> \"b{}p{}_b{}p{}\" -> \"b{}p{}\"",
prev_ref.id, prev_ref.id,
prev_ref.idx, prev_ref.idx,
prev_ref.id, prev_ref.id,
@ -360,7 +360,7 @@ impl<'q, R: Debug, S: Debug + DebugDiff, P: Debug> QueryPlanLogger<'q, R, S, P>
return; return;
} }
let branch: BranchParent = BranchParent::from(state); let branch: BranchParent = BranchParent::from(state);
self.branch_origins.insert(branch.id, parent.clone()); self.branch_origins.insert(branch.id, *parent);
} }
pub fn add_operation<I: Copy>(&mut self, program_i: usize, state: I) pub fn add_operation<I: Copy>(&mut self, program_i: usize, state: I)
@ -402,14 +402,14 @@ impl<'q, R: Debug, S: Debug + DebugDiff, P: Debug> QueryPlanLogger<'q, R, S, P>
return; return;
} }
let mut summary = parse_query_summary(&self.sql); let mut summary = parse_query_summary(self.sql);
let sql = if summary != self.sql { let sql = if summary != self.sql {
let mut summary = parse_query_summary(&self.sql); let mut summary = parse_query_summary(self.sql);
let sql = if summary != self.sql { let sql = if summary != self.sql {
summary.push_str(" …"); summary.push_str(" …");
format!( format!(
"\n\n{}\n", "\n\n{}\n",
sqlformat::format( sqlformat::format(
&self.sql, self.sql,
&sqlformat::QueryParams::None, &sqlformat::QueryParams::None,
sqlformat::FormatOptions::default() sqlformat::FormatOptions::default()
) )

View file

@ -28,8 +28,7 @@ impl MigrateDatabase for Sqlite {
} }
// Opening a connection to sqlite creates the database // Opening a connection to sqlite creates the database
let _ = opts opts.connect()
.connect()
.await? .await?
// Ensure WAL mode tempfiles are cleaned up // Ensure WAL mode tempfiles are cleaned up
.close() .close()

View file

@ -1,8 +1,9 @@
use crate::error::Error; use crate::error::Error;
use std::str::FromStr; use std::str::FromStr;
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqliteAutoVacuum { pub enum SqliteAutoVacuum {
#[default]
None, None,
Full, Full,
Incremental, Incremental,
@ -18,12 +19,6 @@ impl SqliteAutoVacuum {
} }
} }
impl Default for SqliteAutoVacuum {
fn default() -> Self {
SqliteAutoVacuum::None
}
}
impl FromStr for SqliteAutoVacuum { impl FromStr for SqliteAutoVacuum {
type Err = Error; type Err = Error;

Some files were not shown because too many files have changed in this diff.