MySQL UDS Support

- Adds support for Unix Domain Sockets on MySQL
- Allows setting the socket path in PostgreSQL connection options
- ... and MySQL connection options
This commit is contained in:
Julius de Bruijn 2020-06-25 15:31:26 +02:00 committed by Ryan Leckey
parent e3483230e0
commit 868dc3dd5b
7 changed files with 79 additions and 17 deletions

View file

@ -27,7 +27,7 @@ impl<T> StatementCache<T> {
pub fn insert(&mut self, k: &str, v: T) -> Option<T> {
let mut lru_item = None;
if self.inner.capacity() == self.len() && !self.inner.contains_key(k) {
if self.capacity() == self.len() && !self.contains_key(k) {
lru_item = self.remove_lru();
} else if self.contains_key(k) {
lru_item = self.inner.remove(k);
@ -49,7 +49,7 @@ impl<T> StatementCache<T> {
}
/// Clear all cached statements from the cache.
#[cfg(any(feature = "sqlite"))]
#[cfg(feature = "sqlite")]
pub fn clear(&mut self) {
self.inner.clear();
}

View file

@ -1,5 +1,4 @@
use std::fmt::{self, Debug, Formatter};
use std::net::Shutdown;
use std::sync::Arc;
use futures_core::future::BoxFuture;
@ -57,7 +56,7 @@ impl Connection for MySqlConnection {
fn close(mut self) -> BoxFuture<'static, Result<(), Error>> {
Box::pin(async move {
self.stream.send_packet(Quit).await?;
self.stream.shutdown(Shutdown::Both)?;
self.stream.shutdown()?;
Ok(())
})

View file

@ -1,7 +1,6 @@
use std::ops::{Deref, DerefMut};
use bytes::{Buf, Bytes};
use sqlx_rt::TcpStream;
use crate::error::Error;
use crate::io::{BufStream, Decode, Encode};
@ -9,10 +8,10 @@ use crate::mysql::io::MySqlBufExt;
use crate::mysql::protocol::response::{EofPacket, ErrPacket, OkPacket, Status};
use crate::mysql::protocol::{Capabilities, Packet};
use crate::mysql::{MySqlConnectOptions, MySqlDatabaseError};
use crate::net::MaybeTlsStream;
use crate::net::{MaybeTlsStream, Socket};
pub struct MySqlStream {
stream: BufStream<MaybeTlsStream<TcpStream>>,
stream: BufStream<MaybeTlsStream<Socket>>,
pub(super) capabilities: Capabilities,
pub(crate) sequence_id: u8,
pub(crate) busy: Busy,
@ -31,7 +30,10 @@ pub(crate) enum Busy {
impl MySqlStream {
pub(super) async fn connect(options: &MySqlConnectOptions) -> Result<Self, Error> {
let stream = TcpStream::connect((&*options.host, options.port)).await?;
let socket = match options.socket {
Some(ref path) => Socket::connect_uds(path).await?,
None => Socket::connect(&options.host, options.port).await?,
};
let mut capabilities = Capabilities::PROTOCOL_41
| Capabilities::IGNORE_SPACE
@ -54,7 +56,7 @@ impl MySqlStream {
busy: Busy::NotBusy,
capabilities,
sequence_id: 0,
stream: BufStream::new(MaybeTlsStream::Raw(stream)),
stream: BufStream::new(MaybeTlsStream::Raw(socket)),
})
}
@ -178,7 +180,7 @@ impl MySqlStream {
}
impl Deref for MySqlStream {
type Target = BufStream<MaybeTlsStream<TcpStream>>;
type Target = BufStream<MaybeTlsStream<Socket>>;
fn deref(&self) -> &Self::Target {
&self.stream

View file

@ -75,6 +75,7 @@ impl FromStr for MySqlSslMode {
/// | `ssl-mode` | `PREFERRED` | Determines whether or with what priority a secure SSL TCP/IP connection will be negotiated. See [`MySqlSslMode`]. |
/// | `ssl-ca` | `None` | Sets the name of a file containing a list of trusted SSL Certificate Authorities. |
/// | `statement-cache-capacity` | `100` | The maximum number of prepared statements stored in the cache. Set to `0` to disable. |
/// | `socket` | `None` | Path to the unix domain socket, which will be used instead of TCP if set. |
///
/// # Example
///
@ -106,6 +107,7 @@ impl FromStr for MySqlSslMode {
pub struct MySqlConnectOptions {
pub(crate) host: String,
pub(crate) port: u16,
pub(crate) socket: Option<PathBuf>,
pub(crate) username: String,
pub(crate) password: Option<String>,
pub(crate) database: Option<String>,
@ -126,6 +128,7 @@ impl MySqlConnectOptions {
Self {
port: 3306,
host: String::from("localhost"),
socket: None,
username: String::from("root"),
password: None,
database: None,
@ -152,6 +155,15 @@ impl MySqlConnectOptions {
self
}
/// Pass a path to a Unix socket. This changes the connection stream from
/// TCP to UDS.
///
/// By default set to `None`.
pub fn socket(mut self, path: impl AsRef<Path>) -> Self {
self.socket = Some(path.as_ref().to_path_buf());
self
}
/// Sets the username to connect as.
pub fn username(mut self, username: &str) -> Self {
self.username = username.to_owned();
@ -258,6 +270,10 @@ impl FromStr for MySqlConnectOptions {
options = options.statement_cache_capacity(value.parse()?);
}
"socket" => {
options = options.socket(&*value);
}
_ => {}
}
}

View file

@ -2,6 +2,7 @@
use std::io;
use std::net::Shutdown;
use std::path::Path;
use std::pin::Pin;
use std::task::{Context, Poll};
@ -26,14 +27,27 @@ impl Socket {
if host.starts_with('/') {
// if the host starts with a forward slash, assume that this is a request
// to connect to a local socket
sqlx_rt::UnixStream::connect(format!("{}/.s.PGSQL.{}", host, port))
.await
.map(Socket::Unix)
Self::connect_uds(&format!("{}/.s.PGSQL.{}", host, port)).await
} else {
TcpStream::connect((host, port)).await.map(Socket::Tcp)
}
}
#[cfg(unix)]
pub async fn connect_uds(path: impl AsRef<Path>) -> io::Result<Self> {
sqlx_rt::UnixStream::connect(path.as_ref())
.await
.map(Socket::Unix)
}
#[cfg(not(unix))]
pub async fn connect_uds(_: impl AsRef<Path>) -> io::Result<Self> {
Err(io::Error(
io::ErrorKind::Other,
"Unix domain sockets are not supported outside Unix platforms.",
))
}
pub fn shutdown(&self) -> io::Result<()> {
match self {
Socket::Tcp(s) => s.shutdown(Shutdown::Both),

View file

@ -31,9 +31,15 @@ pub struct PgStream {
impl PgStream {
pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
let inner = BufStream::new(MaybeTlsStream::Raw(
Socket::connect(&options.host, options.port).await?,
));
let socket = match options.socket {
Some(ref path) => {
Socket::connect_uds(&format!("{}/.s.PGSQL.{}", path.display(), options.port))
.await?
}
None => Socket::connect(&options.host, options.port).await?,
};
let inner = BufStream::new(MaybeTlsStream::Raw(socket));
Ok(Self {
inner,

View file

@ -76,7 +76,7 @@ impl FromStr for PgSslMode {
/// | `sslmode` | `prefer` | Determines whether or with what priority a secure SSL TCP/IP connection will be negotiated. See [`PgSqlSslMode`]. |
/// | `sslrootcert` | `None` | Sets the name of a file containing a list of trusted SSL Certificate Authorities. |
/// | `statement-cache-capacity` | `100` | The maximum number of prepared statements stored in the cache. Set to `0` to disable. |
///
/// | `host` | `None` | Path to the directory containing a PostgreSQL unix domain socket, which will be used instead of TCP if set. |
///
/// The URI scheme designator can be either `postgresql://` or `postgres://`.
/// Each of the URI parts is optional.
@ -121,6 +121,7 @@ impl FromStr for PgSslMode {
pub struct PgConnectOptions {
pub(crate) host: String,
pub(crate) port: u16,
pub(crate) socket: Option<PathBuf>,
pub(crate) username: String,
pub(crate) password: Option<String>,
pub(crate) database: Option<String>,
@ -166,6 +167,7 @@ impl PgConnectOptions {
PgConnectOptions {
port,
host,
socket: None,
username: var("PGUSER").ok().unwrap_or_else(whoami::username),
password: var("PGPASSWORD").ok(),
database: var("PGDATABASE").ok(),
@ -215,6 +217,25 @@ impl PgConnectOptions {
self
}
/// Sets a custom path to a directory containing a unix domain socket,
/// switching the connection method from TCP to the corresponding socket.
///
/// By default set to `None`.
#[cfg(unix)]
pub fn socket(mut self, path: impl AsRef<Path>) -> Self {
self.socket = Some(path.as_ref().to_path_buf());
self
}
/// Sets a custom path to a directory containing a unix domain socket,
/// switching the connection method from TCP to the corresponding socket.
///
/// By default set to `None`.
#[cfg(not(unix))]
pub fn socket(mut self, _: impl AsRef<Path>) -> Self {
self
}
/// Sets the username to connect as.
///
/// Defaults to be the same as the operating system name of
@ -373,6 +394,10 @@ impl FromStr for PgConnectOptions {
options = options.statement_cache_capacity(value.parse()?);
}
"host" => {
options = options.socket(&*value);
}
_ => {}
}
}