refactor(postgres): make better use of traits to improve protocol handling

This commit is contained in:
Austin Bonander 2024-08-17 04:54:40 -07:00
parent 9b3808b2d5
commit 53766e4659
40 changed files with 1252 additions and 693 deletions

View file

@ -414,7 +414,8 @@ impl<'lock, C: AsMut<PgConnection>> Drop for PgAdvisoryLockGuard<'lock, C> {
// The `async fn` versions can safely use the prepared statement protocol,
// but this is the safest way to queue a query to execute on the next opportunity.
conn.as_mut()
.queue_simple_query(self.lock.get_release_query());
.queue_simple_query(self.lock.get_release_query())
.expect("BUG: PgAdvisoryLock::get_release_query() somehow too long for protocol");
}
}
}

View file

@ -145,6 +145,7 @@ impl<'q> Arguments<'q> for PgArguments {
write!(writer, "${}", self.buffer.count)
}
#[inline(always)]
fn len(&self) -> usize {
self.buffer.count
}

View file

@ -1,5 +1,6 @@
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::io::StatementId;
use crate::message::{ParameterDescription, RowDescription};
use crate::query_as::query_as;
use crate::query_scalar::query_scalar;
@ -27,10 +28,12 @@ enum TypType {
Range,
}
impl TryFrom<u8> for TypType {
impl TryFrom<i8> for TypType {
type Error = ();
fn try_from(t: u8) -> Result<Self, Self::Error> {
fn try_from(t: i8) -> Result<Self, Self::Error> {
let t = u8::try_from(t).or(Err(()))?;
let t = match t {
b'b' => Self::Base,
b'c' => Self::Composite,
@ -66,10 +69,12 @@ enum TypCategory {
Unknown,
}
impl TryFrom<u8> for TypCategory {
impl TryFrom<i8> for TypCategory {
type Error = ();
fn try_from(c: u8) -> Result<Self, Self::Error> {
fn try_from(c: i8) -> Result<Self, Self::Error> {
let c = u8::try_from(c).or(Err(()))?;
let c = match c {
b'A' => Self::Array,
b'B' => Self::Boolean,
@ -209,8 +214,8 @@ impl PgConnection {
.fetch_one(&mut *self)
.await?;
let typ_type = TypType::try_from(typ_type as u8);
let category = TypCategory::try_from(category as u8);
let typ_type = TypType::try_from(typ_type);
let category = TypCategory::try_from(category);
match (typ_type, category) {
(Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await,
@ -416,7 +421,7 @@ WHERE rngtypid = $1
pub(crate) async fn get_nullable_for_columns(
&mut self,
stmt_id: Oid,
stmt_id: StatementId,
meta: &PgStatementMetadata,
) -> Result<Vec<Option<bool>>, Error> {
if meta.columns.is_empty() {
@ -486,13 +491,10 @@ WHERE rngtypid = $1
/// and returns `None` for all others.
async fn nullables_from_explain(
&mut self,
stmt_id: Oid,
stmt_id: StatementId,
params_len: usize,
) -> Result<Vec<Option<bool>>, Error> {
let mut explain = format!(
"EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE sqlx_s_{}",
stmt_id.0
);
let mut explain = format!("EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE {stmt_id}");
let mut comma = false;
if params_len > 0 {

View file

@ -3,11 +3,10 @@ use crate::HashMap;
use crate::common::StatementCache;
use crate::connection::{sasl, stream::PgStream};
use crate::error::Error;
use crate::io::Decode;
use crate::io::StatementId;
use crate::message::{
Authentication, BackendKeyData, MessageFormat, Password, ReadyForQuery, Startup,
Authentication, BackendKeyData, BackendMessageFormat, Password, ReadyForQuery, Startup,
};
use crate::types::Oid;
use crate::{PgConnectOptions, PgConnection};
// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3
@ -44,13 +43,13 @@ impl PgConnection {
params.push(("options", options));
}
stream
.send(Startup {
username: Some(&options.username),
database: options.database.as_deref(),
params: &params,
})
.await?;
stream.write(Startup {
username: Some(&options.username),
database: options.database.as_deref(),
params: &params,
})?;
stream.flush().await?;
// The server then uses this information and the contents of
// its configuration files (such as pg_hba.conf) to determine whether the connection is
@ -64,7 +63,7 @@ impl PgConnection {
loop {
let message = stream.recv().await?;
match message.format {
MessageFormat::Authentication => match message.decode()? {
BackendMessageFormat::Authentication => match message.decode()? {
Authentication::Ok => {
// the authentication exchange is successfully completed
// do nothing; no more information is required to continue
@ -108,7 +107,7 @@ impl PgConnection {
}
},
MessageFormat::BackendKeyData => {
BackendMessageFormat::BackendKeyData => {
// provides secret-key data that the frontend must save if it wants to be
// able to issue cancel requests later
@ -118,10 +117,9 @@ impl PgConnection {
secret_key = data.secret_key;
}
MessageFormat::ReadyForQuery => {
BackendMessageFormat::ReadyForQuery => {
// start-up is completed. The frontend can now issue commands
transaction_status =
ReadyForQuery::decode(message.contents)?.transaction_status;
transaction_status = message.decode::<ReadyForQuery>()?.transaction_status;
break;
}
@ -142,7 +140,7 @@ impl PgConnection {
transaction_status,
transaction_depth: 0,
pending_ready_for_query_count: 0,
next_statement_id: Oid(1),
next_statement_id: StatementId::NAMED_START,
cache_statement: StatementCache::new(options.statement_cache_capacity),
cache_type_oid: HashMap::new(),
cache_type_info: HashMap::new(),

View file

@ -1,13 +1,13 @@
use crate::describe::Describe;
use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::io::{PortalId, StatementId};
use crate::logger::QueryLogger;
use crate::message::{
self, Bind, Close, CommandComplete, DataRow, MessageFormat, ParameterDescription, Parse, Query,
RowDescription,
self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
ParseComplete, Query, RowDescription,
};
use crate::statement::PgStatementMetadata;
use crate::types::Oid;
use crate::{
statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
PgValueFormat, Postgres,
@ -16,6 +16,7 @@ use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_core::Stream;
use futures_util::{pin_mut, TryStreamExt};
use sqlx_core::arguments::Arguments;
use sqlx_core::Either;
use std::{borrow::Cow, sync::Arc};
@ -24,9 +25,9 @@ async fn prepare(
sql: &str,
parameters: &[PgTypeInfo],
metadata: Option<Arc<PgStatementMetadata>>,
) -> Result<(Oid, Arc<PgStatementMetadata>), Error> {
) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
let id = conn.next_statement_id;
conn.next_statement_id.incr_one();
conn.next_statement_id = id.next();
// build a list of type OIDs to send to the database in the PARSE command
// we have not yet started the query sequence, so we are *safe* to cleanly make
@ -42,15 +43,15 @@ async fn prepare(
conn.wait_until_ready().await?;
// next we send the PARSE command to the server
conn.stream.write(Parse {
conn.stream.write_msg(Parse {
param_types: &param_types,
query: sql,
statement: id,
});
})?;
if metadata.is_none() {
// get the statement columns and parameters
conn.stream.write(message::Describe::Statement(id));
conn.stream.write_msg(message::Describe::Statement(id))?;
}
// we ask for the server to immediately send us the result of the PARSE command
@ -58,9 +59,7 @@ async fn prepare(
conn.stream.flush().await?;
// indicates that the SQL query string is now successfully parsed and has semantic validity
conn.stream
.recv_expect(MessageFormat::ParseComplete)
.await?;
conn.stream.recv_expect::<ParseComplete>().await?;
let metadata = if let Some(metadata) = metadata {
// each SYNC produces one READY FOR QUERY
@ -95,18 +94,18 @@ async fn prepare(
}
async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
conn.stream
.recv_expect(MessageFormat::ParameterDescription)
.await
conn.stream.recv_expect().await
}
async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
let rows: Option<RowDescription> = match conn.stream.recv().await? {
// describes the rows that will be returned when the statement is eventually executed
message if message.format == MessageFormat::RowDescription => Some(message.decode()?),
message if message.format == BackendMessageFormat::RowDescription => {
Some(message.decode()?)
}
// no data would be returned if this statement was executed
message if message.format == MessageFormat::NoData => None,
message if message.format == BackendMessageFormat::NoData => None,
message => {
return Err(err_protocol!(
@ -125,12 +124,12 @@ impl PgConnection {
// we need to wait for the [CloseComplete] to be returned from the server
while count > 0 {
match self.stream.recv().await? {
message if message.format == MessageFormat::PortalSuspended => {
message if message.format == BackendMessageFormat::PortalSuspended => {
// there was an open portal
// this can happen if the last time a statement was used it was not fully executed
}
message if message.format == MessageFormat::CloseComplete => {
message if message.format == BackendMessageFormat::CloseComplete => {
// successfully closed the statement (and freed up the server resources)
count -= 1;
}
@ -147,8 +146,11 @@ impl PgConnection {
Ok(())
}
#[inline(always)]
pub(crate) fn write_sync(&mut self) {
self.stream.write(message::Sync);
self.stream
.write_msg(message::Sync)
.expect("BUG: Sync should not be too big for protocol");
// all SYNC messages will return a ReadyForQuery
self.pending_ready_for_query_count += 1;
@ -163,7 +165,7 @@ impl PgConnection {
// optional metadata that was provided by the user, this means they are reusing
// a statement object
metadata: Option<Arc<PgStatementMetadata>>,
) -> Result<(Oid, Arc<PgStatementMetadata>), Error> {
) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
if let Some(statement) = self.cache_statement.get_mut(sql) {
return Ok((*statement).clone());
}
@ -172,7 +174,7 @@ impl PgConnection {
if store_to_cache && self.cache_statement.is_enabled() {
if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) {
self.stream.write(Close::Statement(id));
self.stream.write_msg(Close::Statement(id))?;
self.write_sync();
self.stream.flush().await?;
@ -201,6 +203,14 @@ impl PgConnection {
let mut metadata: Arc<PgStatementMetadata>;
let format = if let Some(mut arguments) = arguments {
// Check this before we write anything to the stream.
let num_params = i16::try_from(arguments.len()).map_err(|_| {
err_protocol!(
"PgConnection::run(): too many arguments for query: {}",
arguments.len()
)
})?;
// prepare the statement if this our first time executing it
// always return the statement ID here
let (statement, metadata_) = self
@ -216,21 +226,21 @@ impl PgConnection {
self.wait_until_ready().await?;
// bind to attach the arguments to the statement and create a portal
self.stream.write(Bind {
portal: None,
self.stream.write_msg(Bind {
portal: PortalId::UNNAMED,
statement,
formats: &[PgValueFormat::Binary],
num_params: arguments.types.len() as i16,
num_params,
params: &arguments.buffer,
result_formats: &[PgValueFormat::Binary],
});
})?;
// executes the portal up to the passed limit
// the protocol-level limit acts nearly identically to the `LIMIT` in SQL
self.stream.write(message::Execute {
portal: None,
self.stream.write_msg(message::Execute {
portal: PortalId::UNNAMED,
limit: limit.into(),
});
})?;
// From https://www.postgresql.org/docs/current/protocol-flow.html:
//
// "An unnamed portal is destroyed at the end of the transaction, or as
@ -240,7 +250,7 @@ impl PgConnection {
// we ask the database server to close the unnamed portal and free the associated resources
// earlier - after the execution of the current query.
self.stream.write(message::Close::Portal(None));
self.stream.write_msg(Close::Portal(PortalId::UNNAMED))?;
// finally, [Sync] asks postgres to process the messages that we sent and respond with
// a [ReadyForQuery] message when it's completely done. Theoretically, we could send
@ -253,7 +263,7 @@ impl PgConnection {
PgValueFormat::Binary
} else {
// Query will trigger a ReadyForQuery
self.stream.write(Query(query));
self.stream.write_msg(Query(query))?;
self.pending_ready_for_query_count += 1;
// metadata starts out as "nothing"
@ -270,12 +280,12 @@ impl PgConnection {
let message = self.stream.recv().await?;
match message.format {
MessageFormat::BindComplete
| MessageFormat::ParseComplete
| MessageFormat::ParameterDescription
| MessageFormat::NoData
BackendMessageFormat::BindComplete
| BackendMessageFormat::ParseComplete
| BackendMessageFormat::ParameterDescription
| BackendMessageFormat::NoData
// unnamed portal has been closed
| MessageFormat::CloseComplete
| BackendMessageFormat::CloseComplete
=> {
// harmless messages to ignore
}
@ -284,7 +294,7 @@ impl PgConnection {
// exactly one of these messages: CommandComplete,
// EmptyQueryResponse (if the portal was created from an
// empty query string), ErrorResponse, or PortalSuspended"
MessageFormat::CommandComplete => {
BackendMessageFormat::CommandComplete => {
// a SQL command completed normally
let cc: CommandComplete = message.decode()?;
@ -295,16 +305,16 @@ impl PgConnection {
}));
}
MessageFormat::EmptyQueryResponse => {
BackendMessageFormat::EmptyQueryResponse => {
// empty query string passed to an unprepared execute
}
// Message::ErrorResponse is handled in self.stream.recv()
// incomplete query execution has finished
MessageFormat::PortalSuspended => {}
BackendMessageFormat::PortalSuspended => {}
MessageFormat::RowDescription => {
BackendMessageFormat::RowDescription => {
// indicates that a *new* set of rows are about to be returned
let (columns, column_names) = self
.handle_row_description(Some(message.decode()?), false)
@ -317,7 +327,7 @@ impl PgConnection {
});
}
MessageFormat::DataRow => {
BackendMessageFormat::DataRow => {
logger.increment_rows_returned();
// one of the set of rows returned by a SELECT, FETCH, etc query
@ -331,7 +341,7 @@ impl PgConnection {
r#yield!(Either::Right(row));
}
MessageFormat::ReadyForQuery => {
BackendMessageFormat::ReadyForQuery => {
// processing of the query string is complete
self.handle_ready_for_query(message)?;
break;

View file

@ -8,9 +8,10 @@ use futures_util::FutureExt;
use crate::common::StatementCache;
use crate::error::Error;
use crate::ext::ustr::UStr;
use crate::io::Decode;
use crate::io::StatementId;
use crate::message::{
Close, Message, MessageFormat, Query, ReadyForQuery, Terminate, TransactionStatus,
BackendMessageFormat, Close, Query, ReadyForQuery, ReceivedMessage, Terminate,
TransactionStatus,
};
use crate::statement::PgStatementMetadata;
use crate::transaction::Transaction;
@ -47,10 +48,10 @@ pub struct PgConnection {
// sequence of statement IDs for use in preparing statements
// in PostgreSQL, the statement is prepared to a user-supplied identifier
next_statement_id: Oid,
next_statement_id: StatementId,
// cache statement by query string to the id and columns
cache_statement: StatementCache<(Oid, Arc<PgStatementMetadata>)>,
cache_statement: StatementCache<(StatementId, Arc<PgStatementMetadata>)>,
// cache user-defined types by id <-> info
cache_type_info: HashMap<Oid, PgTypeInfo>,
@ -82,7 +83,7 @@ impl PgConnection {
while self.pending_ready_for_query_count > 0 {
let message = self.stream.recv().await?;
if let MessageFormat::ReadyForQuery = message.format {
if let BackendMessageFormat::ReadyForQuery = message.format {
self.handle_ready_for_query(message)?;
}
}
@ -91,10 +92,7 @@ impl PgConnection {
}
async fn recv_ready_for_query(&mut self) -> Result<(), Error> {
let r: ReadyForQuery = self
.stream
.recv_expect(MessageFormat::ReadyForQuery)
.await?;
let r: ReadyForQuery = self.stream.recv_expect().await?;
self.pending_ready_for_query_count -= 1;
self.transaction_status = r.transaction_status;
@ -102,9 +100,10 @@ impl PgConnection {
Ok(())
}
fn handle_ready_for_query(&mut self, message: Message) -> Result<(), Error> {
#[inline(always)]
fn handle_ready_for_query(&mut self, message: ReceivedMessage) -> Result<(), Error> {
self.pending_ready_for_query_count -= 1;
self.transaction_status = ReadyForQuery::decode(message.contents)?.transaction_status;
self.transaction_status = message.decode::<ReadyForQuery>()?.transaction_status;
Ok(())
}
@ -112,9 +111,12 @@ impl PgConnection {
/// Queue a simple query (not prepared) to execute the next time this connection is used.
///
/// Used for rolling back transactions and releasing advisory locks.
pub(crate) fn queue_simple_query(&mut self, query: &str) {
#[inline(always)]
pub(crate) fn queue_simple_query(&mut self, query: &str) -> Result<(), Error> {
self.stream.write_msg(Query(query))?;
self.pending_ready_for_query_count += 1;
self.stream.write(Query(query));
Ok(())
}
}
@ -184,7 +186,7 @@ impl Connection for PgConnection {
self.wait_until_ready().await?;
while let Some((id, _)) = self.cache_statement.remove_lru() {
self.stream.write(Close::Statement(id));
self.stream.write_msg(Close::Statement(id))?;
cleared += 1;
}

View file

@ -1,8 +1,6 @@
use crate::connection::stream::PgStream;
use crate::error::Error;
use crate::message::{
Authentication, AuthenticationSasl, MessageFormat, SaslInitialResponse, SaslResponse,
};
use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse};
use crate::PgConnectOptions;
use hmac::{Hmac, Mac};
use rand::Rng;
@ -76,7 +74,7 @@ pub(crate) async fn authenticate(
})
.await?;
let cont = match stream.recv_expect(MessageFormat::Authentication).await? {
let cont = match stream.recv_expect().await? {
Authentication::SaslContinue(data) => data,
auth => {
@ -147,7 +145,7 @@ pub(crate) async fn authenticate(
stream.send(SaslResponse(&client_final_message)).await?;
let data = match stream.recv_expect(MessageFormat::Authentication).await? {
let data = match stream.recv_expect().await? {
Authentication::SaslFinal(data) => data,
auth => {
@ -172,10 +170,10 @@ fn gen_nonce() -> String {
// ;; a valid "value".
let nonce: String = std::iter::repeat(())
.map(|()| {
let mut c = rng.gen_range(0x21..0x7F) as u8;
let mut c = rng.gen_range(0x21u8..0x7F);
while c == 0x2C {
c = rng.gen_range(0x21..0x7F) as u8;
c = rng.gen_range(0x21u8..0x7F);
}
c

View file

@ -9,8 +9,10 @@ use sqlx_core::bytes::{Buf, Bytes};
use crate::connection::tls::MaybeUpgradeTls;
use crate::error::Error;
use crate::io::{Decode, Encode};
use crate::message::{Message, MessageFormat, Notice, Notification, ParameterStatus};
use crate::message::{
BackendMessage, BackendMessageFormat, EncodeMessage, FrontendMessage, Notice, Notification,
ParameterStatus, ReceivedMessage,
};
use crate::net::{self, BufferedSocket, Socket};
use crate::{PgConnectOptions, PgDatabaseError, PgSeverity};
@ -55,59 +57,51 @@ impl PgStream {
})
}
pub(crate) async fn send<'en, T>(&mut self, message: T) -> Result<(), Error>
#[inline(always)]
pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> {
self.write(EncodeMessage(message))
}
pub(crate) async fn send<T>(&mut self, message: T) -> Result<(), Error>
where
T: Encode<'en>,
T: FrontendMessage,
{
self.write(message);
self.write_msg(message)?;
self.flush().await?;
Ok(())
}
// Expect a specific type and format
pub(crate) async fn recv_expect<'de, T: Decode<'de>>(
&mut self,
format: MessageFormat,
) -> Result<T, Error> {
let message = self.recv().await?;
if message.format != format {
return Err(err_protocol!(
"expecting {:?} but received {:?}",
format,
message.format
));
}
message.decode()
pub(crate) async fn recv_expect<B: BackendMessage>(&mut self) -> Result<B, Error> {
self.recv().await?.decode()
}
pub(crate) async fn recv_unchecked(&mut self) -> Result<Message, Error> {
pub(crate) async fn recv_unchecked(&mut self) -> Result<ReceivedMessage, Error> {
// all packets in postgres start with a 5-byte header
// this header contains the message type and the total length of the message
let mut header: Bytes = self.inner.read(5).await?;
let format = MessageFormat::try_from_u8(header.get_u8())?;
let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
let size = (header.get_u32() - 4) as usize;
let contents = self.inner.read(size).await?;
Ok(Message { format, contents })
Ok(ReceivedMessage { format, contents })
}
// Get the next message from the server
// May wait for more data from the server
pub(crate) async fn recv(&mut self) -> Result<Message, Error> {
pub(crate) async fn recv(&mut self) -> Result<ReceivedMessage, Error> {
loop {
let message = self.recv_unchecked().await?;
match message.format {
MessageFormat::ErrorResponse => {
BackendMessageFormat::ErrorResponse => {
// An error returned from the database server.
return Err(PgDatabaseError(message.decode()?).into());
}
MessageFormat::NotificationResponse => {
BackendMessageFormat::NotificationResponse => {
if let Some(buffer) = &mut self.notifications {
let notification: Notification = message.decode()?;
let _ = buffer.send(notification).await;
@ -116,7 +110,7 @@ impl PgStream {
}
}
MessageFormat::ParameterStatus => {
BackendMessageFormat::ParameterStatus => {
// informs the frontend about the current (initial)
// setting of backend parameters
@ -135,7 +129,7 @@ impl PgStream {
continue;
}
MessageFormat::NoticeResponse => {
BackendMessageFormat::NoticeResponse => {
// do we need this to be more configurable?
// if you are reading this comment and think so, open an issue

View file

@ -11,7 +11,8 @@ use crate::error::{Error, Result};
use crate::ext::async_stream::TryAsyncStream;
use crate::io::AsyncRead;
use crate::message::{
CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query,
BackendMessageFormat, CommandComplete, CopyData, CopyDone, CopyFail, CopyInResponse,
CopyOutResponse, CopyResponseData, Query, ReadyForQuery,
};
use crate::pool::{Pool, PoolConnection};
use crate::Postgres;
@ -138,7 +139,7 @@ impl PgPoolCopyExt for Pool<Postgres> {
#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
conn: Option<C>,
response: CopyResponse,
response: CopyResponseData,
}
impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
@ -146,8 +147,8 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
conn.wait_until_ready().await?;
conn.stream.send(Query(statement)).await?;
let response = match conn.stream.recv_expect(MessageFormat::CopyInResponse).await {
Ok(res) => res,
let response = match conn.stream.recv_expect::<CopyInResponse>().await {
Ok(res) => res.0,
Err(e) => {
conn.stream.recv().await?;
return Err(e);
@ -168,7 +169,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
/// Returns the number of columns expected in the input.
pub fn num_columns(&self) -> usize {
assert_eq!(
self.response.num_columns as usize,
self.response.num_columns.unsigned_abs() as usize,
self.response.format_codes.len(),
"num_columns does not match format_codes.len()"
);
@ -261,9 +262,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
match e.code() {
Some(Cow::Borrowed("57014")) => {
// postgres abort received error code
conn.stream
.recv_expect(MessageFormat::ReadyForQuery)
.await?;
conn.stream.recv_expect::<ReadyForQuery>().await?;
Ok(())
}
_ => Err(Error::Database(e)),
@ -283,11 +282,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
.expect("CopyWriter::finish: conn taken illegally");
conn.stream.send(CopyDone).await?;
let cc: CommandComplete = match conn
.stream
.recv_expect(MessageFormat::CommandComplete)
.await
{
let cc: CommandComplete = match conn.stream.recv_expect().await {
Ok(cc) => cc,
Err(e) => {
conn.stream.recv().await?;
@ -295,9 +290,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
}
};
conn.stream
.recv_expect(MessageFormat::ReadyForQuery)
.await?;
conn.stream.recv_expect::<ReadyForQuery>().await?;
Ok(cc.rows_affected())
}
@ -306,9 +299,11 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
fn drop(&mut self) {
if let Some(mut conn) = self.conn.take() {
conn.stream.write(CopyFail::new(
"PgCopyIn dropped without calling finish() or fail()",
));
conn.stream
.write_msg(CopyFail::new(
"PgCopyIn dropped without calling finish() or fail()",
))
.expect("BUG: PgCopyIn abort message should not be too large");
}
}
}
@ -320,24 +315,21 @@ async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
conn.wait_until_ready().await?;
conn.stream.send(Query(statement)).await?;
let _: CopyResponse = conn
.stream
.recv_expect(MessageFormat::CopyOutResponse)
.await?;
let _: CopyOutResponse = conn.stream.recv_expect().await?;
let stream: TryAsyncStream<'c, Bytes> = try_stream! {
loop {
match conn.stream.recv().await {
Err(e) => {
conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?;
conn.stream.recv_expect::<ReadyForQuery>().await?;
return Err(e);
},
Ok(msg) => match msg.format {
MessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
MessageFormat::CopyDone => {
BackendMessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
BackendMessageFormat::CopyDone => {
let _ = msg.decode::<CopyDone>()?;
conn.stream.recv_expect(MessageFormat::CommandComplete).await?;
conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?;
conn.stream.recv_expect::<CommandComplete>().await?;
conn.stream.recv_expect::<ReadyForQuery>().await?;
return Ok(())
},
_ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))

View file

@ -1,54 +1,64 @@
use crate::types::Oid;
use crate::io::{PortalId, StatementId};
pub trait PgBufMutExt {
fn put_length_prefixed<F>(&mut self, f: F)
fn put_length_prefixed<F>(&mut self, f: F) -> Result<(), crate::Error>
where
F: FnOnce(&mut Vec<u8>);
F: FnOnce(&mut Vec<u8>) -> Result<(), crate::Error>;
fn put_statement_name(&mut self, id: Oid);
fn put_statement_name(&mut self, id: StatementId);
fn put_portal_name(&mut self, id: Option<Oid>);
fn put_portal_name(&mut self, id: PortalId);
}
impl PgBufMutExt for Vec<u8> {
// writes a length-prefixed message, this is used when encoding nearly all messages as postgres
// wants us to send the length of the often-variable-sized messages up front
fn put_length_prefixed<F>(&mut self, f: F)
fn put_length_prefixed<F>(&mut self, write_contents: F) -> Result<(), crate::Error>
where
F: FnOnce(&mut Vec<u8>),
F: FnOnce(&mut Vec<u8>) -> Result<(), crate::Error>,
{
// reserve space to write the prefixed length
let offset = self.len();
self.extend(&[0; 4]);
// write the main body of the message
f(self);
let write_result = write_contents(self);
// now calculate the size of what we wrote and set the length value
let size = (self.len() - offset) as i32;
self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes());
let size_result = write_result.and_then(|_| {
let size = self.len() - offset;
i32::try_from(size)
.map_err(|_| err_protocol!("message size out of range for Pg protocol: {size"))
});
match size_result {
Ok(size) => {
// now calculate the size of what we wrote and set the length value
self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes());
Ok(())
}
Err(e) => {
// Put the buffer back to where it was.
self.truncate(offset);
Err(e)
}
}
}
// writes a statement name by ID
#[inline]
fn put_statement_name(&mut self, id: Oid) {
// N.B. if you change this don't forget to update it in ../describe.rs
self.extend(b"sqlx_s_");
self.extend(itoa::Buffer::new().format(id.0).as_bytes());
self.push(0);
fn put_statement_name(&mut self, id: StatementId) {
let _: Result<(), ()> = id.write_name(|s| {
self.extend_from_slice(s.as_bytes());
Ok(())
});
}
// writes a portal name by ID
#[inline]
fn put_portal_name(&mut self, id: Option<Oid>) {
if let Some(id) = id {
self.extend(b"sqlx_p_");
self.extend(itoa::Buffer::new().format(id.0).as_bytes());
}
self.push(0);
fn put_portal_name(&mut self, id: PortalId) {
let _: Result<(), ()> = id.write_name(|s| {
self.extend_from_slice(s.as_bytes());
Ok(())
});
}
}

View file

@ -1,5 +1,130 @@
mod buf_mut;
pub use buf_mut::PgBufMutExt;
use std::fmt;
use std::fmt::{Display, Formatter};
use std::num::{NonZeroU32, Saturating};
pub(crate) use sqlx_core::io::*;
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) struct StatementId(IdInner);
#[allow(dead_code)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) struct PortalId(IdInner);
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct IdInner(Option<NonZeroU32>);
impl StatementId {
pub const UNNAMED: Self = Self(IdInner::UNNAMED);
pub const NAMED_START: Self = Self(IdInner::NAMED_START);
#[cfg(test)]
pub const TEST_VAL: Self = Self(IdInner::TEST_VAL);
const NAME_PREFIX: &'static str = "sqlx_s_";
pub fn next(&self) -> Self {
Self(self.0.next())
}
pub fn name_len(&self) -> Saturating<usize> {
self.0.name_len(Self::NAME_PREFIX)
}
// There's no common trait implemented by `Formatter` and `Vec<u8>` for this purpose;
// we're deliberately avoiding the formatting machinery because it's known to be slow.
pub fn write_name<E>(&self, write: impl FnMut(&str) -> Result<(), E>) -> Result<(), E> {
self.0.write_name(Self::NAME_PREFIX, write)
}
}
impl Display for StatementId {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
self.write_name(|s| f.write_str(s))
}
}
#[allow(dead_code)]
impl PortalId {
// None selects the unnamed portal
pub const UNNAMED: Self = PortalId(IdInner::UNNAMED);
pub const NAMED_START: Self = PortalId(IdInner::NAMED_START);
#[cfg(test)]
pub const TEST_VAL: Self = Self(IdInner::TEST_VAL);
const NAME_PREFIX: &'static str = "sqlx_p_";
/// If ID represents a named portal, return the next ID, wrapping on overflow.
///
/// If this ID represents the unnamed portal, return the same.
pub fn next(&self) -> Self {
Self(self.0.next())
}
/// Calculate the number of bytes that will be written by [`Self::write_name()`].
pub fn name_len(&self) -> Saturating<usize> {
self.0.name_len(Self::NAME_PREFIX)
}
pub fn write_name<E>(&self, write: impl FnMut(&str) -> Result<(), E>) -> Result<(), E> {
self.0.write_name(Self::NAME_PREFIX, write)
}
}
impl IdInner {
const UNNAMED: Self = Self(None);
const NAMED_START: Self = Self(Some(NonZeroU32::MIN));
#[cfg(test)]
pub const TEST_VAL: Self = Self(NonZeroU32::new(1234567890));
#[inline(always)]
fn next(&self) -> Self {
Self(
self.0
.map(|id| id.checked_add(1).unwrap_or(NonZeroU32::MIN)),
)
}
#[inline(always)]
fn name_len(&self, name_prefix: &str) -> Saturating<usize> {
let mut len = Saturating(0);
if let Some(id) = self.0 {
len += name_prefix.len();
// estimate the length of the ID in decimal
// `.ilog10()` can't panic since the value is never zero
len += id.get().ilog10() as usize;
// add one to compensate for `ilog10()` rounding down.
len += 1;
}
// count the NUL terminator
len += 1;
len
}
#[inline(always)]
fn write_name<E>(
&self,
name_prefix: &str,
mut write: impl FnMut(&str) -> Result<(), E>,
) -> Result<(), E> {
if let Some(id) = self.0 {
write(name_prefix)?;
write(itoa::Buffer::new().format(id.get()))?;
}
write("\0")?;
Ok(())
}
}

View file

@ -11,7 +11,7 @@ use sqlx_core::Either;
use crate::describe::Describe;
use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::message::{MessageFormat, Notification};
use crate::message::{BackendMessageFormat, Notification};
use crate::pool::PoolOptions;
use crate::pool::{Pool, PoolConnection};
use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres};
@ -277,12 +277,12 @@ impl PgListener {
match message.format {
// We've received an async notification, return it.
MessageFormat::NotificationResponse => {
BackendMessageFormat::NotificationResponse => {
return Ok(Some(PgNotification(message.decode()?)));
}
// Mark the connection as ready for another query
MessageFormat::ReadyForQuery => {
BackendMessageFormat::ReadyForQuery => {
self.connection().await?.pending_ready_for_query_count -= 1;
}

View file

@ -4,10 +4,10 @@ use memchr::memchr;
use sqlx_core::bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::Decode;
use crate::io::ProtocolDecode;
use crate::message::{BackendMessage, BackendMessageFormat};
use base64::prelude::{Engine as _, BASE64_STANDARD};
// On startup, the server sends an appropriate authentication request message,
// to which the frontend must reply with an appropriate authentication
// response message (such as a password).
@ -60,8 +60,10 @@ pub enum Authentication {
SaslFinal(AuthenticationSaslFinal),
}
impl Decode<'_> for Authentication {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
impl BackendMessage for Authentication {
const FORMAT: BackendMessageFormat = BackendMessageFormat::Authentication;
fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
Ok(match buf.get_u32() {
0 => Authentication::Ok,
@ -129,7 +131,7 @@ pub struct AuthenticationSaslContinue {
pub message: String,
}
impl Decode<'_> for AuthenticationSaslContinue {
impl ProtocolDecode<'_> for AuthenticationSaslContinue {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
let mut iterations: u32 = 4096;
let mut salt = Vec::new();
@ -173,7 +175,7 @@ pub struct AuthenticationSaslFinal {
pub verifier: Vec<u8>,
}
impl Decode<'_> for AuthenticationSaslFinal {
impl ProtocolDecode<'_> for AuthenticationSaslFinal {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
let mut verifier = Vec::new();

View file

@ -2,7 +2,7 @@ use byteorder::{BigEndian, ByteOrder};
use sqlx_core::bytes::Bytes;
use crate::error::Error;
use crate::io::Decode;
use crate::message::{BackendMessage, BackendMessageFormat};
/// Contains cancellation key data. The frontend must save these values if it
/// wishes to be able to issue `CancelRequest` messages later.
@ -15,8 +15,10 @@ pub struct BackendKeyData {
pub secret_key: u32,
}
impl Decode<'_> for BackendKeyData {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
impl BackendMessage for BackendKeyData {
const FORMAT: BackendMessageFormat = BackendMessageFormat::BackendKeyData;
fn decode_body(buf: Bytes) -> Result<Self, Error> {
let process_id = BigEndian::read_u32(&buf);
let secret_key = BigEndian::read_u32(&buf[4..]);
@ -31,7 +33,7 @@ impl Decode<'_> for BackendKeyData {
fn test_decode_backend_key_data() {
const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+";
let m = BackendKeyData::decode(DATA.into()).unwrap();
let m = BackendKeyData::decode_body(DATA.into()).unwrap();
assert_eq!(m.process_id, 10182);
assert_eq!(m.secret_key, 2303903019);
@ -43,6 +45,6 @@ fn bench_decode_backend_key_data(b: &mut test::Bencher) {
const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+";
b.iter(|| {
BackendKeyData::decode(test::black_box(Bytes::from_static(DATA))).unwrap();
BackendKeyData::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap();
});
}

View file

@ -1,15 +1,15 @@
use crate::io::Encode;
use crate::io::PgBufMutExt;
use crate::types::Oid;
use crate::io::{PgBufMutExt, PortalId, StatementId};
use crate::message::{FrontendMessage, FrontendMessageFormat};
use crate::PgValueFormat;
use std::num::Saturating;
#[derive(Debug)]
pub struct Bind<'a> {
/// The ID of the destination portal (`None` selects the unnamed portal).
pub portal: Option<Oid>,
/// The ID of the destination portal (`PortalId::UNNAMED` selects the unnamed portal).
pub portal: PortalId,
/// The id of the source prepared statement.
pub statement: Oid,
pub statement: StatementId,
/// The parameter format codes. Each must presently be zero (text) or one (binary).
///
@ -19,6 +19,8 @@ pub struct Bind<'a> {
pub formats: &'a [PgValueFormat],
/// The number of parameters.
///
/// May be different from `formats.len()`
pub num_params: i16,
/// The value of each parameter, in the indicated format.
@ -33,31 +35,59 @@ pub struct Bind<'a> {
pub result_formats: &'a [PgValueFormat],
}
impl Encode<'_> for Bind<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.push(b'B');
impl FrontendMessage for Bind<'_> {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Bind;
buf.put_length_prefixed(|buf| {
buf.put_portal_name(self.portal);
fn body_size_hint(&self) -> Saturating<usize> {
let mut size = Saturating(0);
size += self.portal.name_len();
size += self.statement.name_len();
buf.put_statement_name(self.statement);
// Parameter formats and length prefix
size += 2;
size += self.formats.len();
buf.extend(&(self.formats.len() as i16).to_be_bytes());
// `num_params`
size += 2;
for &format in self.formats {
buf.extend(&(format as i16).to_be_bytes());
}
size += self.params.len();
buf.extend(&self.num_params.to_be_bytes());
// Result formats and length prefix
size += 2;
size += self.result_formats.len();
buf.extend(self.params);
size
}
buf.extend(&(self.result_formats.len() as i16).to_be_bytes());
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), crate::Error> {
buf.put_portal_name(self.portal);
for &format in self.result_formats {
buf.extend(&(format as i16).to_be_bytes());
}
});
buf.put_statement_name(self.statement);
let formats_len = i16::try_from(self.formats.len()).map_err(|_| {
err_protocol!("too many parameter format codes ({})", self.formats.len())
})?;
buf.extend(formats_len.to_be_bytes());
for &format in self.formats {
buf.extend((format as i16).to_be_bytes());
}
buf.extend(self.num_params.to_be_bytes());
buf.extend(self.params);
let result_formats_len = i16::try_from(self.formats.len())
.map_err(|_| err_protocol!("too many result format codes ({})", self.formats.len()))?;
buf.extend(result_formats_len.to_be_bytes());
for &format in self.result_formats {
buf.extend((format as i16).to_be_bytes());
}
Ok(())
}
}

View file

@ -1,6 +1,6 @@
use crate::io::Encode;
use crate::io::PgBufMutExt;
use crate::types::Oid;
use crate::io::{PgBufMutExt, PortalId, StatementId};
use crate::message::{FrontendMessage, FrontendMessageFormat};
use std::num::Saturating;
const CLOSE_PORTAL: u8 = b'P';
const CLOSE_STATEMENT: u8 = b'S';
@ -8,18 +8,27 @@ const CLOSE_STATEMENT: u8 = b'S';
#[derive(Debug)]
#[allow(dead_code)]
pub enum Close {
Statement(Oid),
// None selects the unnamed portal
Portal(Option<Oid>),
Statement(StatementId),
Portal(PortalId),
}
impl Encode<'_> for Close {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
// 15 bytes for 1-digit statement/portal IDs
buf.reserve(20);
buf.push(b'C');
impl FrontendMessage for Close {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Close;
buf.put_length_prefixed(|buf| match self {
fn body_size_hint(&self) -> Saturating<usize> {
// Either `CLOSE_PORTAL` or `CLOSE_STATEMENT`
let mut size = Saturating(1);
match self {
Close::Statement(id) => size += id.name_len(),
Close::Portal(id) => size += id.name_len(),
}
size
}
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), crate::Error> {
match self {
Close::Statement(id) => {
buf.push(CLOSE_STATEMENT);
buf.put_statement_name(*id);
@ -29,6 +38,8 @@ impl Encode<'_> for Close {
buf.push(CLOSE_PORTAL);
buf.put_portal_name(*id);
}
})
}
Ok(())
}
}

View file

@ -3,7 +3,7 @@ use memchr::memrchr;
use sqlx_core::bytes::Bytes;
use crate::error::Error;
use crate::io::Decode;
use crate::message::{BackendMessage, BackendMessageFormat};
#[derive(Debug)]
pub struct CommandComplete {
@ -12,10 +12,11 @@ pub struct CommandComplete {
tag: Bytes,
}
impl Decode<'_> for CommandComplete {
#[inline]
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
Ok(CommandComplete { tag: buf })
impl BackendMessage for CommandComplete {
const FORMAT: BackendMessageFormat = BackendMessageFormat::CommandComplete;
fn decode_body(bytes: Bytes) -> Result<Self, Error> {
Ok(CommandComplete { tag: bytes })
}
}
@ -35,7 +36,7 @@ impl CommandComplete {
fn test_decode_command_complete_for_insert() {
const DATA: &[u8] = b"INSERT 0 1214\0";
let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap();
let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap();
assert_eq!(cc.rows_affected(), 1214);
}
@ -44,7 +45,7 @@ fn test_decode_command_complete_for_insert() {
fn test_decode_command_complete_for_begin() {
const DATA: &[u8] = b"BEGIN\0";
let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap();
let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap();
assert_eq!(cc.rows_affected(), 0);
}
@ -53,7 +54,7 @@ fn test_decode_command_complete_for_begin() {
fn test_decode_command_complete_for_update() {
const DATA: &[u8] = b"UPDATE 5\0";
let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap();
let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap();
assert_eq!(cc.rows_affected(), 5);
}
@ -64,7 +65,7 @@ fn bench_decode_command_complete(b: &mut test::Bencher) {
const DATA: &[u8] = b"INSERT 0 1214\0";
b.iter(|| {
let _ = CommandComplete::decode(test::black_box(Bytes::from_static(DATA)));
let _ = CommandComplete::decode_body(test::black_box(Bytes::from_static(DATA)));
});
}
@ -73,7 +74,7 @@ fn bench_decode_command_complete(b: &mut test::Bencher) {
fn bench_decode_command_complete_rows_affected(b: &mut test::Bencher) {
const DATA: &[u8] = b"INSERT 0 1214\0";
let data = CommandComplete::decode(Bytes::from_static(DATA)).unwrap();
let data = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap();
b.iter(|| {
let _rows = test::black_box(&data).rows_affected();

View file

@ -1,15 +1,25 @@
use crate::error::Result;
use crate::io::{BufExt, BufMutExt, Decode, Encode};
use sqlx_core::bytes::{Buf, BufMut, Bytes};
use crate::io::BufMutExt;
use crate::message::{
BackendMessage, BackendMessageFormat, FrontendMessage, FrontendMessageFormat,
};
use sqlx_core::bytes::{Buf, Bytes};
use sqlx_core::Error;
use std::num::Saturating;
use std::ops::Deref;
/// The same structure is sent for both `CopyInResponse` and `CopyOutResponse`
pub struct CopyResponse {
pub struct CopyResponseData {
pub format: i8,
pub num_columns: i16,
pub format_codes: Vec<i16>,
}
pub struct CopyInResponse(pub CopyResponseData);
#[allow(dead_code)]
pub struct CopyOutResponse(pub CopyResponseData);
pub struct CopyData<B>(pub B);
pub struct CopyFail {
@ -18,14 +28,15 @@ pub struct CopyFail {
pub struct CopyDone;
impl Decode<'_> for CopyResponse {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self> {
impl CopyResponseData {
#[inline]
fn decode(mut buf: Bytes) -> Result<Self> {
let format = buf.get_i8();
let num_columns = buf.get_i16();
let format_codes = (0..num_columns).map(|_| buf.get_i16()).collect();
Ok(CopyResponse {
Ok(CopyResponseData {
format,
num_columns,
format_codes,
@ -33,40 +44,65 @@ impl Decode<'_> for CopyResponse {
}
}
impl Decode<'_> for CopyData<Bytes> {
fn decode_with(buf: Bytes, _: ()) -> Result<Self> {
// well.. that was easy
Ok(CopyData(buf))
impl BackendMessage for CopyInResponse {
const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyInResponse;
#[inline(always)]
fn decode_body(buf: Bytes) -> std::result::Result<Self, Error> {
Ok(Self(CopyResponseData::decode(buf)?))
}
}
impl<B: Deref<Target = [u8]>> Encode<'_> for CopyData<B> {
fn encode_with(&self, buf: &mut Vec<u8>, _context: ()) {
buf.push(b'd');
buf.put_u32(self.0.len() as u32 + 4);
impl BackendMessage for CopyOutResponse {
const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyOutResponse;
#[inline(always)]
fn decode_body(buf: Bytes) -> std::result::Result<Self, Error> {
Ok(Self(CopyResponseData::decode(buf)?))
}
}
impl BackendMessage for CopyData<Bytes> {
const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyData;
#[inline(always)]
fn decode_body(buf: Bytes) -> std::result::Result<Self, Error> {
Ok(Self(buf))
}
}
impl<B: Deref<Target = [u8]>> FrontendMessage for CopyData<B> {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyData;
#[inline(always)]
fn body_size_hint(&self) -> Saturating<usize> {
Saturating(self.0.len())
}
#[inline(always)]
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
buf.extend_from_slice(&self.0);
Ok(())
}
}
impl Decode<'_> for CopyFail {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self> {
Ok(CopyFail {
message: buf.get_str_nul()?,
})
impl FrontendMessage for CopyFail {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyFail;
#[inline(always)]
fn body_size_hint(&self) -> Saturating<usize> {
Saturating(self.message.len())
}
}
impl Encode<'_> for CopyFail {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
let len = 4 + self.message.len() + 1;
buf.push(b'f'); // to pay respects
buf.put_u32(len as u32);
#[inline(always)]
fn encode_body(&self, buf: &mut Vec<u8>) -> std::result::Result<(), Error> {
buf.put_str_nul(&self.message);
Ok(())
}
}
impl CopyFail {
#[inline(always)]
pub fn new(msg: impl Into<String>) -> CopyFail {
CopyFail {
message: msg.into(),
@ -74,23 +110,32 @@ impl CopyFail {
}
}
impl Decode<'_> for CopyDone {
fn decode_with(buf: Bytes, _: ()) -> Result<Self> {
if buf.is_empty() {
Ok(CopyDone)
} else {
Err(err_protocol!(
"expected no data for CopyDone, got: {:?}",
buf
))
}
impl FrontendMessage for CopyDone {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyDone;
#[inline(always)]
fn body_size_hint(&self) -> Saturating<usize> {
Saturating(0)
}
#[inline(always)]
fn encode_body(&self, _buf: &mut Vec<u8>) -> std::result::Result<(), Error> {
Ok(())
}
}
impl Encode<'_> for CopyDone {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.reserve(4);
buf.push(b'c');
buf.put_u32(4);
impl BackendMessage for CopyDone {
const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyDone;
#[inline(always)]
fn decode_body(bytes: Bytes) -> std::result::Result<Self, Error> {
if !bytes.is_empty() {
// Not fatal but may indicate a protocol change
tracing::debug!(
"Postgres backend returned non-empty message for CopyDone: \"{}\"",
bytes.escape_ascii()
)
}
Ok(CopyDone)
}
}

View file

@ -1,10 +1,9 @@
use std::ops::Range;
use byteorder::{BigEndian, ByteOrder};
use sqlx_core::bytes::Bytes;
use std::ops::Range;
use crate::error::Error;
use crate::io::Decode;
use crate::message::{BackendMessage, BackendMessageFormat};
/// A row of data from the database.
#[derive(Debug)]
@ -26,25 +25,55 @@ impl DataRow {
}
}
impl Decode<'_> for DataRow {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
impl BackendMessage for DataRow {
const FORMAT: BackendMessageFormat = BackendMessageFormat::DataRow;
fn decode_body(buf: Bytes) -> Result<Self, Error> {
if buf.len() < 2 {
return Err(err_protocol!(
"expected at least 2 bytes, got {}",
buf.len()
));
}
let cnt = BigEndian::read_u16(&buf) as usize;
let mut values = Vec::with_capacity(cnt);
let mut offset = 2;
let mut offset: u32 = 2;
for _ in 0..cnt {
let value_start = offset
.checked_add(4)
.ok_or_else(|| err_protocol!("next value start out of range (offset: {offset})"))?;
// widen both to a larger type for a safe comparison
if (buf.len() as u64) < (value_start as u64) {
return Err(err_protocol!(
"expected 4 bytes at offset {offset}, got {}",
(value_start as u64) - (buf.len() as u64)
));
}
// Length of the column value, in bytes (this count does not include itself).
// Can be zero. As a special case, -1 indicates a NULL column value.
// No value bytes follow in the NULL case.
//
// we know `offset` is within range of `buf.len()` from the above check
#[allow(clippy::cast_possible_truncation)]
let length = BigEndian::read_i32(&buf[(offset as usize)..]);
offset += 4;
if length < 0 {
values.push(None);
if let Ok(length) = u32::try_from(length) {
let value_end = value_start.checked_add(length).ok_or_else(|| {
err_protocol!("value_start + length out of range ({offset} + {length})")
})?;
values.push(Some(value_start..value_end));
offset = value_end;
} else {
values.push(Some(offset..(offset + length as u32)));
offset += length as u32;
// Negative values signify NULL
values.push(None);
// `value_start` is actually the next value now.
offset = value_start;
}
}
@ -57,9 +86,22 @@ impl Decode<'_> for DataRow {
#[test]
fn test_decode_data_row() {
const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P";
const DATA: &[u8] = b"\
\x00\x08\
\xff\xff\xff\xff\
\x00\x00\x00\x04\
\x00\x00\x00\n\
\xff\xff\xff\xff\
\x00\x00\x00\x04\
\x00\x00\x00\x14\
\xff\xff\xff\xff\
\x00\x00\x00\x04\
\x00\x00\x00(\
\xff\xff\xff\xff\
\x00\x00\x00\x04\
\x00\x00\x00P";
let row = DataRow::decode(DATA.into()).unwrap();
let row = DataRow::decode_body(DATA.into()).unwrap();
assert_eq!(row.values.len(), 8);
@ -78,7 +120,7 @@ fn test_decode_data_row() {
fn bench_data_row_get(b: &mut test::Bencher) {
const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P";
let row = DataRow::decode(test::black_box(Bytes::from_static(DATA))).unwrap();
let row = DataRow::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap();
b.iter(|| {
let _value = test::black_box(&row).get(3);
@ -91,6 +133,6 @@ fn bench_decode_data_row(b: &mut test::Bencher) {
const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P";
b.iter(|| {
let _ = DataRow::decode(test::black_box(Bytes::from_static(DATA)));
let _ = DataRow::decode_body(test::black_box(Bytes::from_static(DATA)));
});
}

View file

@ -1,127 +1,103 @@
use crate::io::Encode;
use crate::io::PgBufMutExt;
use crate::types::Oid;
use crate::io::{PgBufMutExt, PortalId, StatementId};
use crate::message::{FrontendMessage, FrontendMessageFormat};
use sqlx_core::Error;
use std::num::Saturating;
const DESCRIBE_PORTAL: u8 = b'P';
const DESCRIBE_STATEMENT: u8 = b'S';
// [Describe] will emit both a [RowDescription] and a [ParameterDescription] message
/// Note: will emit both a RowDescription and a ParameterDescription message
#[derive(Debug)]
#[allow(dead_code)]
pub enum Describe {
UnnamedStatement,
Statement(Oid),
UnnamedPortal,
Portal(Oid),
Statement(StatementId),
Portal(PortalId),
}
impl Encode<'_> for Describe {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
// 15 bytes for 1-digit statement/portal IDs
buf.reserve(20);
buf.push(b'D');
impl FrontendMessage for Describe {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Describe;
buf.put_length_prefixed(|buf| {
match self {
// #[likely]
Describe::Statement(id) => {
buf.push(DESCRIBE_STATEMENT);
buf.put_statement_name(*id);
}
fn body_size_hint(&self) -> Saturating<usize> {
// Either `DESCRIBE_PORTAL` or `DESCRIBE_STATEMENT`
let mut size = Saturating(1);
Describe::UnnamedPortal => {
buf.push(DESCRIBE_PORTAL);
buf.push(0);
}
match self {
Describe::Statement(id) => size += id.name_len(),
Describe::Portal(id) => size += id.name_len(),
}
Describe::UnnamedStatement => {
buf.push(DESCRIBE_STATEMENT);
buf.push(0);
}
size
}
Describe::Portal(id) => {
buf.push(DESCRIBE_PORTAL);
buf.put_portal_name(Some(*id));
}
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
match self {
// #[likely]
Describe::Statement(id) => {
buf.push(DESCRIBE_STATEMENT);
buf.put_statement_name(*id);
}
});
Describe::Portal(id) => {
buf.push(DESCRIBE_PORTAL);
buf.put_portal_name(*id);
}
}
Ok(())
}
}
#[test]
fn test_encode_describe_portal() {
const EXPECTED: &[u8] = b"D\0\0\0\x0EPsqlx_p_5\0";
#[cfg(test)]
mod tests {
use crate::message::FrontendMessage;
let mut buf = Vec::new();
let m = Describe::Portal(Oid(5));
use super::{Describe, PortalId, StatementId};
m.encode(&mut buf);
#[test]
fn test_encode_describe_portal() {
const EXPECTED: &[u8] = b"D\0\0\0\x17Psqlx_p_1234567890\0";
assert_eq!(buf, EXPECTED);
}
#[test]
fn test_encode_describe_unnamed_portal() {
const EXPECTED: &[u8] = b"D\0\0\0\x06P\0";
let mut buf = Vec::new();
let m = Describe::UnnamedPortal;
m.encode(&mut buf);
assert_eq!(buf, EXPECTED);
}
#[test]
fn test_encode_describe_statement() {
const EXPECTED: &[u8] = b"D\0\0\0\x0ESsqlx_s_5\0";
let mut buf = Vec::new();
let m = Describe::Statement(Oid(5));
m.encode(&mut buf);
assert_eq!(buf, EXPECTED);
}
#[test]
fn test_encode_describe_unnamed_statement() {
const EXPECTED: &[u8] = b"D\0\0\0\x06S\0";
let mut buf = Vec::new();
let m = Describe::UnnamedStatement;
m.encode(&mut buf);
assert_eq!(buf, EXPECTED);
}
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_encode_describe_portal(b: &mut test::Bencher) {
use test::black_box;
let mut buf = Vec::with_capacity(128);
b.iter(|| {
buf.clear();
black_box(Describe::Portal(5)).encode(&mut buf);
});
}
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_encode_describe_unnamed_statement(b: &mut test::Bencher) {
use test::black_box;
let mut buf = Vec::with_capacity(128);
b.iter(|| {
buf.clear();
black_box(Describe::UnnamedStatement).encode(&mut buf);
});
let mut buf = Vec::new();
let m = Describe::Portal(PortalId::TEST_VAL);
m.encode_msg(&mut buf).unwrap();
assert_eq!(buf, EXPECTED);
}
#[test]
fn test_encode_describe_unnamed_portal() {
const EXPECTED: &[u8] = b"D\0\0\0\x06P\0";
let mut buf = Vec::new();
let m = Describe::Portal(PortalId::UNNAMED);
m.encode_msg(&mut buf).unwrap();
assert_eq!(buf, EXPECTED);
}
#[test]
fn test_encode_describe_statement() {
const EXPECTED: &[u8] = b"D\0\0\0\x17Ssqlx_s_1234567890\0";
let mut buf = Vec::new();
let m = Describe::Statement(StatementId::TEST_VAL);
m.encode_msg(&mut buf).unwrap();
assert_eq!(buf, EXPECTED);
}
#[test]
fn test_encode_describe_unnamed_statement() {
const EXPECTED: &[u8] = b"D\0\0\0\x06S\0";
let mut buf = Vec::new();
let m = Describe::Statement(StatementId::UNNAMED);
m.encode_msg(&mut buf).unwrap();
assert_eq!(buf, EXPECTED);
}
}

View file

@ -1,39 +1,73 @@
use crate::io::Encode;
use crate::io::PgBufMutExt;
use crate::types::Oid;
use std::num::Saturating;
use sqlx_core::Error;
use crate::io::{PgBufMutExt, PortalId};
use crate::message::{FrontendMessage, FrontendMessageFormat};
pub struct Execute {
/// The id of the portal to execute (`None` selects the unnamed portal).
pub portal: Option<Oid>,
/// The id of the portal to execute.
pub portal: PortalId,
/// Maximum number of rows to return, if portal contains a query
/// that returns rows (ignored otherwise). Zero denotes “no limit”.
pub limit: u32,
}
impl Encode<'_> for Execute {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.reserve(20);
buf.push(b'E');
impl FrontendMessage for Execute {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Execute;
buf.put_length_prefixed(|buf| {
buf.put_portal_name(self.portal);
buf.extend(&self.limit.to_be_bytes());
});
fn body_size_hint(&self) -> Saturating<usize> {
let mut size = Saturating(0);
size += self.portal.name_len();
size += 2; // limit
size
}
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
buf.put_portal_name(self.portal);
buf.extend(&self.limit.to_be_bytes());
Ok(())
}
}
#[test]
fn test_encode_execute() {
const EXPECTED: &[u8] = b"E\0\0\0\x11sqlx_p_5\0\0\0\0\x02";
#[cfg(test)]
mod tests {
use crate::io::PortalId;
use crate::message::FrontendMessage;
let mut buf = Vec::new();
let m = Execute {
portal: Some(Oid(5)),
limit: 2,
};
use super::Execute;
m.encode(&mut buf);
#[test]
fn test_encode_execute_named_portal() {
const EXPECTED: &[u8] = b"E\0\0\0\x1Asqlx_p_1234567890\0\0\0\0\x02";
assert_eq!(buf, EXPECTED);
let mut buf = Vec::new();
let m = Execute {
portal: PortalId::TEST_VAL,
limit: 2,
};
m.encode_msg(&mut buf).unwrap();
assert_eq!(buf, EXPECTED);
}
#[test]
fn test_encode_execute_unnamed_portal() {
const EXPECTED: &[u8] = b"E\0\0\0\x09\0\x49\x96\x02\xD2";
let mut buf = Vec::new();
let m = Execute {
portal: PortalId::UNNAMED,
limit: 1234567890,
};
m.encode_msg(&mut buf).unwrap();
assert_eq!(buf, EXPECTED);
}
}

View file

@ -1,17 +1,25 @@
use crate::io::Encode;
// The Flush message does not cause any specific output to be generated,
// but forces the backend to deliver any data pending in its output buffers.
// A Flush must be sent after any extended-query command except Sync, if the
// frontend wishes to examine the results of that command before issuing more commands.
use crate::message::{FrontendMessage, FrontendMessageFormat};
use sqlx_core::Error;
use std::num::Saturating;
/// The Flush message does not cause any specific output to be generated,
/// but forces the backend to deliver any data pending in its output buffers.
///
/// A Flush must be sent after any extended-query command except Sync, if the
/// frontend wishes to examine the results of that command before issuing more commands.
#[derive(Debug)]
pub struct Flush;
impl Encode<'_> for Flush {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.push(b'H');
buf.extend(&4_i32.to_be_bytes());
impl FrontendMessage for Flush {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Flush;
#[inline(always)]
fn body_size_hint(&self) -> Saturating<usize> {
Saturating(0)
}
#[inline(always)]
fn encode_body(&self, _buf: &mut Vec<u8>) -> Result<(), Error> {
Ok(())
}
}

View file

@ -1,7 +1,8 @@
use sqlx_core::bytes::Bytes;
use std::num::Saturating;
use crate::error::Error;
use crate::io::Decode;
use crate::io::PgBufMutExt;
mod authentication;
mod backend_key_data;
@ -17,6 +18,7 @@ mod notification;
mod parameter_description;
mod parameter_status;
mod parse;
mod parse_complete;
mod password;
mod query;
mod ready_for_query;
@ -33,7 +35,7 @@ pub use backend_key_data::BackendKeyData;
pub use bind::Bind;
pub use close::Close;
pub use command_complete::CommandComplete;
pub use copy::{CopyData, CopyDone, CopyFail, CopyResponse};
pub use copy::{CopyData, CopyDone, CopyFail, CopyInResponse, CopyOutResponse, CopyResponseData};
pub use data_row::DataRow;
pub use describe::Describe;
pub use execute::Execute;
@ -43,20 +45,51 @@ pub use notification::Notification;
pub use parameter_description::ParameterDescription;
pub use parameter_status::ParameterStatus;
pub use parse::Parse;
pub use parse_complete::ParseComplete;
pub use password::Password;
pub use query::Query;
pub use ready_for_query::{ReadyForQuery, TransactionStatus};
pub use response::{Notice, PgSeverity};
pub use row_description::RowDescription;
pub use sasl::{SaslInitialResponse, SaslResponse};
use sqlx_core::io::ProtocolEncode;
pub use ssl_request::SslRequest;
pub use startup::Startup;
pub use sync::Sync;
pub use terminate::Terminate;
// Note: we can't use the same enum for both frontend and backend message formats
// because there are duplicated format codes between them.
//
// For example, `Close` (frontend) and `CommandComplete` (backend) both use format code `C`.
// <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Debug, PartialOrd, PartialEq)]
#[repr(u8)]
pub enum MessageFormat {
pub enum FrontendMessageFormat {
Bind = b'B',
Close = b'C',
CopyData = b'd',
CopyDone = b'c',
CopyFail = b'f',
Describe = b'D',
Execute = b'E',
Flush = b'H',
Parse = b'P',
/// This message format is polymorphic. It's used for:
///
/// * Plain password responses
/// * MD5 password responses
/// * SASL responses
/// * GSSAPI/SSPI responses
PasswordPolymorphic = b'p',
Query = b'Q',
Sync = b'S',
Terminate = b'X',
}
#[derive(Debug, PartialOrd, PartialEq)]
#[repr(u8)]
pub enum BackendMessageFormat {
Authentication,
BackendKeyData,
BindComplete,
@ -81,49 +114,116 @@ pub enum MessageFormat {
}
#[derive(Debug)]
pub struct Message {
pub format: MessageFormat,
pub struct ReceivedMessage {
pub format: BackendMessageFormat,
pub contents: Bytes,
}
impl Message {
impl ReceivedMessage {
#[inline]
pub fn decode<'de, T>(self) -> Result<T, Error>
pub fn decode<T>(self) -> Result<T, Error>
where
T: Decode<'de>,
T: BackendMessage,
{
T::decode(self.contents)
if T::FORMAT != self.format {
return Err(err_protocol!(
"Postgres protocol error: expected {:?}, got {:?}",
T::FORMAT,
self.format
));
}
T::decode_body(self.contents).map_err(|e| match e {
Error::Protocol(s) => {
err_protocol!("Postgres protocol error (reading {:?}): {s}", self.format)
}
other => other,
})
}
}
impl MessageFormat {
impl BackendMessageFormat {
pub fn try_from_u8(v: u8) -> Result<Self, Error> {
// https://www.postgresql.org/docs/current/protocol-message-formats.html
Ok(match v {
b'1' => MessageFormat::ParseComplete,
b'2' => MessageFormat::BindComplete,
b'3' => MessageFormat::CloseComplete,
b'C' => MessageFormat::CommandComplete,
b'd' => MessageFormat::CopyData,
b'c' => MessageFormat::CopyDone,
b'G' => MessageFormat::CopyInResponse,
b'H' => MessageFormat::CopyOutResponse,
b'D' => MessageFormat::DataRow,
b'E' => MessageFormat::ErrorResponse,
b'I' => MessageFormat::EmptyQueryResponse,
b'A' => MessageFormat::NotificationResponse,
b'K' => MessageFormat::BackendKeyData,
b'N' => MessageFormat::NoticeResponse,
b'R' => MessageFormat::Authentication,
b'S' => MessageFormat::ParameterStatus,
b'T' => MessageFormat::RowDescription,
b'Z' => MessageFormat::ReadyForQuery,
b'n' => MessageFormat::NoData,
b's' => MessageFormat::PortalSuspended,
b't' => MessageFormat::ParameterDescription,
b'1' => BackendMessageFormat::ParseComplete,
b'2' => BackendMessageFormat::BindComplete,
b'3' => BackendMessageFormat::CloseComplete,
b'C' => BackendMessageFormat::CommandComplete,
b'd' => BackendMessageFormat::CopyData,
b'c' => BackendMessageFormat::CopyDone,
b'G' => BackendMessageFormat::CopyInResponse,
b'H' => BackendMessageFormat::CopyOutResponse,
b'D' => BackendMessageFormat::DataRow,
b'E' => BackendMessageFormat::ErrorResponse,
b'I' => BackendMessageFormat::EmptyQueryResponse,
b'A' => BackendMessageFormat::NotificationResponse,
b'K' => BackendMessageFormat::BackendKeyData,
b'N' => BackendMessageFormat::NoticeResponse,
b'R' => BackendMessageFormat::Authentication,
b'S' => BackendMessageFormat::ParameterStatus,
b'T' => BackendMessageFormat::RowDescription,
b'Z' => BackendMessageFormat::ReadyForQuery,
b'n' => BackendMessageFormat::NoData,
b's' => BackendMessageFormat::PortalSuspended,
b't' => BackendMessageFormat::ParameterDescription,
_ => return Err(err_protocol!("unknown message type: {:?}", v as char)),
})
}
}
pub(crate) trait FrontendMessage: Sized {
/// The format prefix of this message.
const FORMAT: FrontendMessageFormat;
/// Return the amount of space, in bytes, to reserve in the buffer passed to [`Self::encode_body()`].
fn body_size_hint(&self) -> Saturating<usize>;
/// Encode this type as a Frontend message in the Postgres protocol.
///
/// The implementation should *not* include `Self::FORMAT` or the length prefix.
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error>;
#[inline(always)]
#[cfg_attr(not(test), allow(dead_code))]
fn encode_msg(self, buf: &mut Vec<u8>) -> Result<(), Error> {
EncodeMessage(self).encode(buf)
}
}
pub(crate) trait BackendMessage: Sized {
/// The expected message format.
///
/// <https://www.postgresql.org/docs/current/protocol-message-formats.html>
const FORMAT: BackendMessageFormat;
/// Decode this type from a Backend message in the Postgres protocol.
///
/// The format code and length prefix have already been read and are not at the start of `bytes`.
fn decode_body(buf: Bytes) -> Result<Self, Error>;
}
pub struct EncodeMessage<F>(pub F);
impl<F: FrontendMessage> ProtocolEncode<'_, ()> for EncodeMessage<F> {
fn encode_with(&self, buf: &mut Vec<u8>, _context: ()) -> Result<(), Error> {
let mut size_hint = self.0.body_size_hint();
// plus format code and length prefix
size_hint += 5;
// don't panic if `size_hint` is ridiculous
buf.try_reserve(size_hint.0).map_err(|e| {
err_protocol!(
"Postgres protocol: error allocating {} bytes for encoding message {:?}: {e}",
size_hint.0,
F::FORMAT,
)
})?;
buf.push(F::FORMAT as u8);
buf.put_length_prefixed(|buf| self.0.encode_body(buf))
}
}

View file

@ -1,7 +1,8 @@
use sqlx_core::bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::{BufExt, Decode};
use crate::io::BufExt;
use crate::message::{BackendMessage, BackendMessageFormat};
#[derive(Debug)]
pub struct Notification {
@ -10,9 +11,10 @@ pub struct Notification {
pub(crate) payload: Bytes,
}
impl Decode<'_> for Notification {
#[inline]
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
impl BackendMessage for Notification {
const FORMAT: BackendMessageFormat = BackendMessageFormat::NotificationResponse;
fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
let process_id = buf.get_u32();
let channel = buf.get_bytes_nul()?;
let payload = buf.get_bytes_nul()?;
@ -29,7 +31,7 @@ impl Decode<'_> for Notification {
fn test_decode_notification_response() {
const NOTIFICATION_RESPONSE: &[u8] = b"\x34\x20\x10\x02TEST-CHANNEL\0THIS IS A TEST\0";
let message = Notification::decode(Bytes::from(NOTIFICATION_RESPONSE)).unwrap();
let message = Notification::decode_body(Bytes::from(NOTIFICATION_RESPONSE)).unwrap();
assert_eq!(message.process_id, 0x34201002);
assert_eq!(&*message.channel, &b"TEST-CHANNEL"[..]);

View file

@ -2,7 +2,7 @@ use smallvec::SmallVec;
use sqlx_core::bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::Decode;
use crate::message::{BackendMessage, BackendMessageFormat};
use crate::types::Oid;
#[derive(Debug)]
@ -10,8 +10,10 @@ pub struct ParameterDescription {
pub types: SmallVec<[Oid; 6]>,
}
impl Decode<'_> for ParameterDescription {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
impl BackendMessage for ParameterDescription {
const FORMAT: BackendMessageFormat = BackendMessageFormat::ParameterDescription;
fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
let cnt = buf.get_u16();
let mut types = SmallVec::with_capacity(cnt as usize);
@ -27,7 +29,7 @@ impl Decode<'_> for ParameterDescription {
fn test_decode_parameter_description() {
const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00";
let m = ParameterDescription::decode(DATA.into()).unwrap();
let m = ParameterDescription::decode_body(DATA.into()).unwrap();
assert_eq!(m.types.len(), 2);
assert_eq!(m.types[0], Oid(0x0000_0000));
@ -38,7 +40,7 @@ fn test_decode_parameter_description() {
fn test_decode_empty_parameter_description() {
const DATA: &[u8] = b"\x00\x00";
let m = ParameterDescription::decode(DATA.into()).unwrap();
let m = ParameterDescription::decode_body(DATA.into()).unwrap();
assert!(m.types.is_empty());
}
@ -49,6 +51,6 @@ fn bench_decode_parameter_description(b: &mut test::Bencher) {
const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00";
b.iter(|| {
ParameterDescription::decode(test::black_box(Bytes::from_static(DATA))).unwrap();
ParameterDescription::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap();
});
}

View file

@ -1,7 +1,8 @@
use sqlx_core::bytes::Bytes;
use crate::error::Error;
use crate::io::{BufExt, Decode};
use crate::io::BufExt;
use crate::message::{BackendMessage, BackendMessageFormat};
#[derive(Debug)]
pub struct ParameterStatus {
@ -9,8 +10,10 @@ pub struct ParameterStatus {
pub value: String,
}
impl Decode<'_> for ParameterStatus {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
impl BackendMessage for ParameterStatus {
const FORMAT: BackendMessageFormat = BackendMessageFormat::ParameterStatus;
fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
let name = buf.get_str_nul()?;
let value = buf.get_str_nul()?;
@ -22,7 +25,7 @@ impl Decode<'_> for ParameterStatus {
fn test_decode_parameter_status() {
const DATA: &[u8] = b"client_encoding\x00UTF8\x00";
let m = ParameterStatus::decode(DATA.into()).unwrap();
let m = ParameterStatus::decode_body(DATA.into()).unwrap();
assert_eq!(&m.name, "client_encoding");
assert_eq!(&m.value, "UTF8")
@ -32,7 +35,7 @@ fn test_decode_parameter_status() {
fn test_decode_empty_parameter_status() {
const DATA: &[u8] = b"\x00\x00";
let m = ParameterStatus::decode(DATA.into()).unwrap();
let m = ParameterStatus::decode_body(DATA.into()).unwrap();
assert!(m.name.is_empty());
assert!(m.value.is_empty());
@ -44,7 +47,7 @@ fn bench_decode_parameter_status(b: &mut test::Bencher) {
const DATA: &[u8] = b"client_encoding\x00UTF8\x00";
b.iter(|| {
ParameterStatus::decode(test::black_box(Bytes::from_static(DATA))).unwrap();
ParameterStatus::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap();
});
}
@ -52,7 +55,7 @@ fn bench_decode_parameter_status(b: &mut test::Bencher) {
fn test_decode_parameter_status_response() {
const PARAMETER_STATUS_RESPONSE: &[u8] = b"crdb_version\0CockroachDB CCL v21.1.0 (x86_64-unknown-linux-gnu, built 2021/05/17 13:49:40, go1.15.11)\0";
let message = ParameterStatus::decode(Bytes::from(PARAMETER_STATUS_RESPONSE)).unwrap();
let message = ParameterStatus::decode_body(Bytes::from(PARAMETER_STATUS_RESPONSE)).unwrap();
assert_eq!(message.name, "crdb_version");
assert_eq!(

View file

@ -1,11 +1,14 @@
use crate::io::PgBufMutExt;
use crate::io::{BufMutExt, Encode};
use crate::io::BufMutExt;
use crate::io::{PgBufMutExt, StatementId};
use crate::message::{FrontendMessage, FrontendMessageFormat};
use crate::types::Oid;
use sqlx_core::Error;
use std::num::Saturating;
#[derive(Debug)]
pub struct Parse<'a> {
/// The ID of the destination prepared statement.
pub statement: Oid,
pub statement: StatementId,
/// The query string to be parsed.
pub query: &'a str,
@ -16,39 +19,59 @@ pub struct Parse<'a> {
pub param_types: &'a [Oid],
}
impl Encode<'_> for Parse<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.push(b'P');
impl FrontendMessage for Parse<'_> {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Parse;
buf.put_length_prefixed(|buf| {
buf.put_statement_name(self.statement);
fn body_size_hint(&self) -> Saturating<usize> {
let mut size = Saturating(0);
buf.put_str_nul(self.query);
size += self.statement.name_len();
// TODO: Return an error here instead
assert!(self.param_types.len() <= (u16::MAX as usize));
size += self.query.len();
size += 1; // NUL terminator
buf.extend(&(self.param_types.len() as i16).to_be_bytes());
size += 2; // param_types_len
for &oid in self.param_types {
buf.extend(&oid.0.to_be_bytes());
}
})
// `param_types`
size += self.param_types.len().saturating_mul(4);
size
}
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
buf.put_statement_name(self.statement);
buf.put_str_nul(self.query);
let param_types_len = i16::try_from(self.param_types.len()).map_err(|_| {
err_protocol!(
"param_types.len() too large for binary protocol: {}",
self.param_types.len()
)
})?;
buf.extend(param_types_len.to_be_bytes());
for &oid in self.param_types {
buf.extend(oid.0.to_be_bytes());
}
Ok(())
}
}
#[test]
fn test_encode_parse() {
const EXPECTED: &[u8] = b"P\0\0\0\x1dsqlx_s_1\0SELECT $1\0\0\x01\0\0\0\x19";
const EXPECTED: &[u8] = b"P\0\0\0\x26sqlx_s_1234567890\0SELECT $1\0\0\x01\0\0\0\x19";
let mut buf = Vec::new();
let m = Parse {
statement: Oid(1),
statement: StatementId::TEST_VAL,
query: "SELECT $1",
param_types: &[Oid(25)],
};
m.encode(&mut buf);
m.encode_msg(&mut buf).unwrap();
assert_eq!(buf, EXPECTED);
}

View file

@ -0,0 +1,13 @@
use crate::message::{BackendMessage, BackendMessageFormat};
use sqlx_core::bytes::Bytes;
use sqlx_core::Error;
pub struct ParseComplete;
impl BackendMessage for ParseComplete {
const FORMAT: BackendMessageFormat = BackendMessageFormat::ParseComplete;
fn decode_body(_bytes: Bytes) -> Result<Self, Error> {
Ok(ParseComplete)
}
}

View file

@ -1,9 +1,9 @@
use std::fmt::Write;
use crate::io::BufMutExt;
use crate::message::{FrontendMessage, FrontendMessageFormat};
use md5::{Digest, Md5};
use crate::io::PgBufMutExt;
use crate::io::{BufMutExt, Encode};
use sqlx_core::Error;
use std::fmt::Write;
use std::num::Saturating;
#[derive(Debug)]
pub enum Password<'a> {
@ -16,117 +16,138 @@ pub enum Password<'a> {
},
}
impl Password<'_> {
#[inline]
fn len(&self) -> usize {
impl FrontendMessage for Password<'_> {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic;
#[inline(always)]
fn body_size_hint(&self) -> Saturating<usize> {
let mut size = Saturating(0);
match self {
Password::Cleartext(s) => s.len() + 5,
Password::Md5 { .. } => 35 + 5,
}
}
}
impl Encode<'_> for Password<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.reserve(1 + 4 + self.len());
buf.push(b'p');
buf.put_length_prefixed(|buf| {
match self {
Password::Cleartext(password) => {
buf.put_str_nul(password);
}
Password::Md5 {
username,
password,
salt,
} => {
// The actual `PasswordMessage` can be computed in SQL as
// `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`.
// Keep in mind the md5() function returns its result as a hex string.
let mut hasher = Md5::new();
hasher.update(password);
hasher.update(username);
let mut output = String::with_capacity(35);
let _ = write!(output, "{:x}", hasher.finalize_reset());
hasher.update(&output);
hasher.update(salt);
output.clear();
let _ = write!(output, "md5{:x}", hasher.finalize());
buf.put_str_nul(&output);
}
Password::Cleartext(password) => {
// To avoid reporting the exact password length anywhere,
// we deliberately give a bad estimate.
//
// This shouldn't affect performance in the long run.
size += password
.len()
.saturating_add(1) // NUL terminator
.checked_next_power_of_two()
.unwrap_or(usize::MAX);
}
});
Password::Md5 { .. } => {
// "md5<32 hex chars>\0"
size += 36;
}
}
size
}
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
match self {
Password::Cleartext(password) => {
buf.put_str_nul(password);
}
Password::Md5 {
username,
password,
salt,
} => {
// The actual `PasswordMessage` can be computed in SQL as
// `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`.
// Keep in mind the md5() function returns its result as a hex string.
let mut hasher = Md5::new();
hasher.update(password);
hasher.update(username);
let mut output = String::with_capacity(35);
let _ = write!(output, "{:x}", hasher.finalize_reset());
hasher.update(&output);
hasher.update(salt);
output.clear();
let _ = write!(output, "md5{:x}", hasher.finalize());
buf.put_str_nul(&output);
}
}
Ok(())
}
}
#[test]
fn test_encode_clear_password() {
const EXPECTED: &[u8] = b"p\0\0\0\rpassword\0";
#[cfg(test)]
mod tests {
use crate::message::FrontendMessage;
let mut buf = Vec::new();
let m = Password::Cleartext("password");
use super::Password;
m.encode(&mut buf);
#[test]
fn test_encode_clear_password() {
const EXPECTED: &[u8] = b"p\0\0\0\rpassword\0";
assert_eq!(buf, EXPECTED);
}
let mut buf = Vec::new();
let m = Password::Cleartext("password");
#[test]
fn test_encode_md5_password() {
const EXPECTED: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0";
m.encode_msg(&mut buf).unwrap();
let mut buf = Vec::new();
let m = Password::Md5 {
password: "password",
username: "root",
salt: [147, 24, 57, 152],
};
assert_eq!(buf, EXPECTED);
}
m.encode(&mut buf);
#[test]
fn test_encode_md5_password() {
const EXPECTED: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0";
assert_eq!(buf, EXPECTED);
}
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_encode_clear_password(b: &mut test::Bencher) {
use test::black_box;
let mut buf = Vec::with_capacity(128);
b.iter(|| {
buf.clear();
black_box(Password::Cleartext("password")).encode(&mut buf);
});
}
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_encode_md5_password(b: &mut test::Bencher) {
use test::black_box;
let mut buf = Vec::with_capacity(128);
b.iter(|| {
buf.clear();
black_box(Password::Md5 {
let mut buf = Vec::new();
let m = Password::Md5 {
password: "password",
username: "root",
salt: [147, 24, 57, 152],
})
.encode(&mut buf);
});
};
m.encode_msg(&mut buf).unwrap();
assert_eq!(buf, EXPECTED);
}
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_encode_clear_password(b: &mut test::Bencher) {
use test::black_box;
let mut buf = Vec::with_capacity(128);
b.iter(|| {
buf.clear();
black_box(Password::Cleartext("password")).encode_msg(&mut buf);
});
}
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_encode_md5_password(b: &mut test::Bencher) {
use test::black_box;
let mut buf = Vec::with_capacity(128);
b.iter(|| {
buf.clear();
black_box(Password::Md5 {
password: "password",
username: "root",
salt: [147, 24, 57, 152],
})
.encode_msg(&mut buf);
});
}
}

View file

@ -1,27 +1,37 @@
use crate::io::{BufMutExt, Encode};
use crate::io::BufMutExt;
use crate::message::{FrontendMessage, FrontendMessageFormat};
use sqlx_core::Error;
use std::num::Saturating;
#[derive(Debug)]
pub struct Query<'a>(pub &'a str);
impl Encode<'_> for Query<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
let len = 4 + self.0.len() + 1;
impl FrontendMessage for Query<'_> {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Query;
buf.reserve(len + 1);
buf.push(b'Q');
buf.extend(&(len as i32).to_be_bytes());
fn body_size_hint(&self) -> Saturating<usize> {
let mut size = Saturating(0);
size += self.0.len();
size += 1; // NUL terminator
size
}
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
buf.put_str_nul(self.0);
Ok(())
}
}
#[test]
fn test_encode_query() {
const EXPECTED: &[u8] = b"Q\0\0\0\rSELECT 1\0";
const EXPECTED: &[u8] = b"Q\0\0\0\x0DSELECT 1\0";
let mut buf = Vec::new();
let m = Query("SELECT 1");
m.encode(&mut buf);
m.encode_msg(&mut buf).unwrap();
assert_eq!(buf, EXPECTED);
}

View file

@ -1,7 +1,7 @@
use sqlx_core::bytes::Bytes;
use crate::error::Error;
use crate::io::Decode;
use crate::message::{BackendMessage, BackendMessageFormat};
#[derive(Debug)]
#[repr(u8)]
@ -21,8 +21,10 @@ pub struct ReadyForQuery {
pub transaction_status: TransactionStatus,
}
impl Decode<'_> for ReadyForQuery {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
impl BackendMessage for ReadyForQuery {
const FORMAT: BackendMessageFormat = BackendMessageFormat::ReadyForQuery;
fn decode_body(buf: Bytes) -> Result<Self, Error> {
let status = match buf[0] {
b'I' => TransactionStatus::Idle,
b'T' => TransactionStatus::Transaction,
@ -46,7 +48,7 @@ impl Decode<'_> for ReadyForQuery {
fn test_decode_ready_for_query() -> Result<(), Error> {
const DATA: &[u8] = b"E";
let m = ReadyForQuery::decode(Bytes::from_static(DATA))?;
let m = ReadyForQuery::decode_body(Bytes::from_static(DATA))?;
assert!(matches!(m.transaction_status, TransactionStatus::Error));

View file

@ -1,10 +1,13 @@
use std::ops::Range;
use std::str::from_utf8;
use memchr::memchr;
use sqlx_core::bytes::Bytes;
use crate::error::Error;
use crate::io::Decode;
use crate::io::ProtocolDecode;
use crate::message::{BackendMessage, BackendMessageFormat};
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(u8)]
@ -53,8 +56,8 @@ impl TryFrom<&str> for PgSeverity {
pub struct Notice {
storage: Bytes,
severity: PgSeverity,
message: (u16, u16),
code: (u16, u16),
message: Range<usize>,
code: Range<usize>,
}
impl Notice {
@ -65,12 +68,12 @@ impl Notice {
#[inline]
pub fn code(&self) -> &str {
self.get_cached_str(self.code)
self.get_cached_str(self.code.clone())
}
#[inline]
pub fn message(&self) -> &str {
self.get_cached_str(self.message)
self.get_cached_str(self.message.clone())
}
// Field descriptions available here:
@ -84,7 +87,7 @@ impl Notice {
pub fn get_raw(&self, ty: u8) -> Option<&[u8]> {
self.fields()
.filter(|(field, _)| *field == ty)
.map(|(_, (start, end))| &self.storage[start as usize..end as usize])
.map(|(_, range)| &self.storage[range])
.next()
}
}
@ -99,13 +102,13 @@ impl Notice {
}
#[inline]
fn get_cached_str(&self, cache: (u16, u16)) -> &str {
fn get_cached_str(&self, cache: Range<usize>) -> &str {
// unwrap: this cannot fail at this stage
from_utf8(&self.storage[cache.0 as usize..cache.1 as usize]).unwrap()
from_utf8(&self.storage[cache]).unwrap()
}
}
impl Decode<'_> for Notice {
impl ProtocolDecode<'_> for Notice {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
// In order to support PostgreSQL 9.5 and older we need to parse the localized S field.
// Newer versions additionally come with the V field that is guaranteed to be in English.
@ -113,8 +116,8 @@ impl Decode<'_> for Notice {
const DEFAULT_SEVERITY: PgSeverity = PgSeverity::Log;
let mut severity_v = None;
let mut severity_s = None;
let mut message = (0, 0);
let mut code = (0, 0);
let mut message = 0..0;
let mut code = 0..0;
// we cache the three always present fields
// this enables to keep the access time down for the fields most likely accessed
@ -125,7 +128,7 @@ impl Decode<'_> for Notice {
};
for (field, v) in fields {
if message.0 != 0 && code.0 != 0 {
if !(message.is_empty() || code.is_empty()) {
// stop iterating when we have the 3 fields we were looking for
// we assume V (severity) was the first field as it should be
break;
@ -133,7 +136,7 @@ impl Decode<'_> for Notice {
match field {
b'S' => {
severity_s = from_utf8(&buf[v.0 as usize..v.1 as usize])
severity_s = from_utf8(&buf[v.clone()])
// If the error string is not UTF-8, we have no hope of interpreting it,
// localized or not. The `V` field would likely fail to parse as well.
.map_err(|_| notice_protocol_err())?
@ -146,21 +149,19 @@ impl Decode<'_> for Notice {
// Propagate errors here, because V is not localized and
// thus we are missing a possible variant.
severity_v = Some(
from_utf8(&buf[v.0 as usize..v.1 as usize])
from_utf8(&buf[v.clone()])
.map_err(|_| notice_protocol_err())?
.try_into()?,
);
}
b'M' => {
_ = from_utf8(&buf[v.0 as usize..v.1 as usize])
.map_err(|_| notice_protocol_err())?;
_ = from_utf8(&buf[v.clone()]).map_err(|_| notice_protocol_err())?;
message = v;
}
b'C' => {
_ = from_utf8(&buf[v.0 as usize..v.1 as usize])
.map_err(|_| notice_protocol_err())?;
_ = from_utf8(&buf[v.clone()]).map_err(|_| notice_protocol_err())?;
code = v;
}
@ -179,31 +180,46 @@ impl Decode<'_> for Notice {
}
}
impl BackendMessage for Notice {
const FORMAT: BackendMessageFormat = BackendMessageFormat::NoticeResponse;
fn decode_body(buf: Bytes) -> Result<Self, Error> {
// Keeping both impls for now
Self::decode_with(buf, ())
}
}
/// An iterator over each field in the Error (or Notice) response.
struct Fields<'a> {
storage: &'a [u8],
offset: u16,
offset: usize,
}
impl<'a> Iterator for Fields<'a> {
type Item = (u8, (u16, u16));
type Item = (u8, Range<usize>);
fn next(&mut self) -> Option<Self::Item> {
// The fields in the response body are sequentially stored as [tag][string],
// ending in a final, additional [nul]
let ty = self.storage[self.offset as usize];
let ty = *self.storage.get(self.offset)?;
if ty == 0 {
return None;
}
let nul = memchr(b'\0', &self.storage[(self.offset + 1) as usize..])? as u16;
let offset = self.offset;
// Consume the type byte
self.offset = self.offset.checked_add(1)?;
self.offset += nul + 2;
let start = self.offset;
Some((ty, (offset + 1, offset + nul + 1)))
let len = memchr(b'\0', self.storage.get(start..)?)?;
// Neither can overflow as they will always be `<= self.storage.len()`.
let end = self.offset + len;
self.offset = end + 1;
Some((ty, start..end))
}
}

View file

@ -1,7 +1,8 @@
use sqlx_core::bytes::{Buf, Bytes};
use crate::error::Error;
use crate::io::{BufExt, Decode};
use crate::io::BufExt;
use crate::message::{BackendMessage, BackendMessageFormat};
use crate::types::Oid;
#[derive(Debug)]
@ -40,13 +41,30 @@ pub struct Field {
pub format: i16,
}
impl Decode<'_> for RowDescription {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
impl BackendMessage for RowDescription {
const FORMAT: BackendMessageFormat = BackendMessageFormat::RowDescription;
fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
if buf.len() < 2 {
return Err(err_protocol!(
"expected at least 2 bytes, got {}",
buf.len()
));
}
let cnt = buf.get_u16();
let mut fields = Vec::with_capacity(cnt as usize);
for _ in 0..cnt {
let name = buf.get_str_nul()?.to_owned();
if buf.len() < 18 {
return Err(err_protocol!(
"expected at least 18 bytes after field name {name:?}, got {}",
buf.len()
));
}
let relation_id = buf.get_i32();
let relation_attribute_no = buf.get_i16();
let data_type_id = Oid(buf.get_u32());

View file

@ -1,35 +1,69 @@
use crate::io::PgBufMutExt;
use crate::io::{BufMutExt, Encode};
use crate::io::BufMutExt;
use crate::message::{FrontendMessage, FrontendMessageFormat};
use sqlx_core::Error;
use std::num::Saturating;
pub struct SaslInitialResponse<'a> {
pub response: &'a str,
pub plus: bool,
}
impl Encode<'_> for SaslInitialResponse<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.push(b'p');
buf.put_length_prefixed(|buf| {
// name of the SASL authentication mechanism that the client selected
buf.put_str_nul(if self.plus {
"SCRAM-SHA-256-PLUS"
} else {
"SCRAM-SHA-256"
});
impl SaslInitialResponse<'_> {
#[inline(always)]
fn selected_mechanism(&self) -> &'static str {
if self.plus {
"SCRAM-SHA-256-PLUS"
} else {
"SCRAM-SHA-256"
}
}
}
buf.extend(&(self.response.as_bytes().len() as i32).to_be_bytes());
buf.extend(self.response.as_bytes());
});
impl FrontendMessage for SaslInitialResponse<'_> {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic;
#[inline(always)]
fn body_size_hint(&self) -> Saturating<usize> {
let mut size = Saturating(0);
size += self.selected_mechanism().len();
size += 1; // NUL terminator
size += 4; // response_len
size += self.response.len();
size
}
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
// name of the SASL authentication mechanism that the client selected
buf.put_str_nul(self.selected_mechanism());
let response_len = i32::try_from(self.response.len()).map_err(|_| {
err_protocol!(
"SASL Initial Response length too long for protcol: {}",
self.response.len()
)
})?;
buf.extend_from_slice(&response_len.to_be_bytes());
buf.extend_from_slice(self.response.as_bytes());
Ok(())
}
}
pub struct SaslResponse<'a>(pub &'a str);
impl Encode<'_> for SaslResponse<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.push(b'p');
buf.put_length_prefixed(|buf| {
buf.extend(self.0.as_bytes());
});
impl FrontendMessage for SaslResponse<'_> {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic;
fn body_size_hint(&self) -> Saturating<usize> {
Saturating(self.0.len())
}
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
buf.extend(self.0.as_bytes());
Ok(())
}
}

View file

@ -1,23 +1,38 @@
use crate::io::Encode;
use crate::io::ProtocolEncode;
pub struct SslRequest;
impl SslRequest {
pub const BYTES: &'static [u8] = b"\x00\x00\x00\x08\x04\xd2\x16/";
// https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-SSLREQUEST
pub const BYTES: &'static [u8] = b"\x00\x00\x00\x08\x04\xd2\x16\x2f";
}
impl Encode<'_> for SslRequest {
#[inline]
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.extend(&8_u32.to_be_bytes());
buf.extend(&(((1234 << 16) | 5679) as u32).to_be_bytes());
// Cannot impl FrontendMessage because it does not have a format code
impl ProtocolEncode<'_> for SslRequest {
#[inline(always)]
fn encode_with(&self, buf: &mut Vec<u8>, _context: ()) -> Result<(), crate::Error> {
buf.extend_from_slice(Self::BYTES);
Ok(())
}
}
#[test]
fn test_encode_ssl_request() {
let mut buf = Vec::new();
SslRequest.encode(&mut buf);
// Int32(8)
// Length of message contents in bytes, including self.
buf.extend_from_slice(&8_u32.to_be_bytes());
// Int32(80877103)
// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits,
// and 5679 in the least significant 16 bits.
// (To avoid confusion, this code must not be the same as any protocol version number.)
buf.extend_from_slice(&(((1234 << 16) | 5679) as u32).to_be_bytes());
let mut encoded = Vec::new();
SslRequest.encode(&mut encoded).unwrap();
assert_eq!(buf, SslRequest::BYTES);
assert_eq!(buf, encoded);
}

View file

@ -1,5 +1,5 @@
use crate::io::PgBufMutExt;
use crate::io::{BufMutExt, Encode};
use crate::io::{BufMutExt, ProtocolEncode};
// To begin a session, a frontend opens a connection to the server and sends a startup message.
// This message includes the names of the user and of the database the user wants to connect to;
@ -19,8 +19,9 @@ pub struct Startup<'a> {
pub params: &'a [(&'a str, &'a str)],
}
impl Encode<'_> for Startup<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
// Startup cannot impl FrontendMessage because it doesn't have a format code.
impl ProtocolEncode<'_> for Startup<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _context: ()) -> Result<(), crate::Error> {
buf.reserve(120);
buf.put_length_prefixed(|buf| {
@ -47,7 +48,9 @@ impl Encode<'_> for Startup<'_> {
// A zero byte is required as a terminator
// after the last name/value pair.
buf.push(0);
});
Ok(())
})
}
}
@ -68,7 +71,7 @@ fn test_encode_startup() {
params: &[],
};
m.encode(&mut buf);
m.encode(&mut buf).unwrap();
assert_eq!(buf, EXPECTED);
}

View file

@ -1,11 +1,20 @@
use crate::io::Encode;
use crate::message::{FrontendMessage, FrontendMessageFormat};
use sqlx_core::Error;
use std::num::Saturating;
#[derive(Debug)]
pub struct Sync;
impl Encode<'_> for Sync {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.push(b'S');
buf.extend(&4_i32.to_be_bytes());
impl FrontendMessage for Sync {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Sync;
#[inline(always)]
fn body_size_hint(&self) -> Saturating<usize> {
Saturating(0)
}
#[inline(always)]
fn encode_body(&self, _buf: &mut Vec<u8>) -> Result<(), Error> {
Ok(())
}
}

View file

@ -1,10 +1,19 @@
use crate::io::Encode;
use crate::message::{FrontendMessage, FrontendMessageFormat};
use sqlx_core::Error;
use std::num::Saturating;
pub struct Terminate;
impl Encode<'_> for Terminate {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
buf.push(b'X');
buf.extend(&4_u32.to_be_bytes());
impl FrontendMessage for Terminate {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Terminate;
#[inline(always)]
fn body_size_hint(&self) -> Saturating<usize> {
Saturating(0)
}
#[inline(always)]
fn encode_body(&self, _buf: &mut Vec<u8>) -> Result<(), Error> {
Ok(())
}
}

View file

@ -17,7 +17,7 @@ impl TransactionManager for PgTransactionManager {
Box::pin(async move {
let rollback = Rollback::new(conn);
let query = begin_ansi_transaction_sql(rollback.conn.transaction_depth);
rollback.conn.queue_simple_query(&query);
rollback.conn.queue_simple_query(&query)?;
rollback.conn.transaction_depth += 1;
rollback.conn.wait_until_ready().await?;
rollback.defuse();
@ -54,7 +54,8 @@ impl TransactionManager for PgTransactionManager {
fn start_rollback(conn: &mut PgConnection) {
if conn.transaction_depth > 0 {
conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.transaction_depth));
conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.transaction_depth))
.expect("BUG: Rollback query somehow too large for protocol");
conn.transaction_depth -= 1;
}

View file

@ -17,12 +17,6 @@ pub struct Oid(
pub u32,
);
impl Oid {
pub(crate) fn incr_one(&mut self) {
self.0 = self.0.wrapping_add(1);
}
}
impl Type<Postgres> for Oid {
fn type_info() -> PgTypeInfo {
PgTypeInfo::OID