refactor(postgres): make better use of traits to improve protocol handling

This commit is contained in:
Austin Bonander 2024-08-17 04:54:40 -07:00
parent 9b3808b2d5
commit 53766e4659
40 changed files with 1252 additions and 693 deletions

View file

@ -414,7 +414,8 @@ impl<'lock, C: AsMut<PgConnection>> Drop for PgAdvisoryLockGuard<'lock, C> {
// The `async fn` versions can safely use the prepared statement protocol, // The `async fn` versions can safely use the prepared statement protocol,
// but this is the safest way to queue a query to execute on the next opportunity. // but this is the safest way to queue a query to execute on the next opportunity.
conn.as_mut() conn.as_mut()
.queue_simple_query(self.lock.get_release_query()); .queue_simple_query(self.lock.get_release_query())
.expect("BUG: PgAdvisoryLock::get_release_query() somehow too long for protocol");
} }
} }
} }

View file

@ -145,6 +145,7 @@ impl<'q> Arguments<'q> for PgArguments {
write!(writer, "${}", self.buffer.count) write!(writer, "${}", self.buffer.count)
} }
#[inline(always)]
fn len(&self) -> usize { fn len(&self) -> usize {
self.buffer.count self.buffer.count
} }

View file

@ -1,5 +1,6 @@
use crate::error::Error; use crate::error::Error;
use crate::ext::ustr::UStr; use crate::ext::ustr::UStr;
use crate::io::StatementId;
use crate::message::{ParameterDescription, RowDescription}; use crate::message::{ParameterDescription, RowDescription};
use crate::query_as::query_as; use crate::query_as::query_as;
use crate::query_scalar::query_scalar; use crate::query_scalar::query_scalar;
@ -27,10 +28,12 @@ enum TypType {
Range, Range,
} }
impl TryFrom<u8> for TypType { impl TryFrom<i8> for TypType {
type Error = (); type Error = ();
fn try_from(t: u8) -> Result<Self, Self::Error> { fn try_from(t: i8) -> Result<Self, Self::Error> {
let t = u8::try_from(t).or(Err(()))?;
let t = match t { let t = match t {
b'b' => Self::Base, b'b' => Self::Base,
b'c' => Self::Composite, b'c' => Self::Composite,
@ -66,10 +69,12 @@ enum TypCategory {
Unknown, Unknown,
} }
impl TryFrom<u8> for TypCategory { impl TryFrom<i8> for TypCategory {
type Error = (); type Error = ();
fn try_from(c: u8) -> Result<Self, Self::Error> { fn try_from(c: i8) -> Result<Self, Self::Error> {
let c = u8::try_from(c).or(Err(()))?;
let c = match c { let c = match c {
b'A' => Self::Array, b'A' => Self::Array,
b'B' => Self::Boolean, b'B' => Self::Boolean,
@ -209,8 +214,8 @@ impl PgConnection {
.fetch_one(&mut *self) .fetch_one(&mut *self)
.await?; .await?;
let typ_type = TypType::try_from(typ_type as u8); let typ_type = TypType::try_from(typ_type);
let category = TypCategory::try_from(category as u8); let category = TypCategory::try_from(category);
match (typ_type, category) { match (typ_type, category) {
(Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await, (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await,
@ -416,7 +421,7 @@ WHERE rngtypid = $1
pub(crate) async fn get_nullable_for_columns( pub(crate) async fn get_nullable_for_columns(
&mut self, &mut self,
stmt_id: Oid, stmt_id: StatementId,
meta: &PgStatementMetadata, meta: &PgStatementMetadata,
) -> Result<Vec<Option<bool>>, Error> { ) -> Result<Vec<Option<bool>>, Error> {
if meta.columns.is_empty() { if meta.columns.is_empty() {
@ -486,13 +491,10 @@ WHERE rngtypid = $1
/// and returns `None` for all others. /// and returns `None` for all others.
async fn nullables_from_explain( async fn nullables_from_explain(
&mut self, &mut self,
stmt_id: Oid, stmt_id: StatementId,
params_len: usize, params_len: usize,
) -> Result<Vec<Option<bool>>, Error> { ) -> Result<Vec<Option<bool>>, Error> {
let mut explain = format!( let mut explain = format!("EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE {stmt_id}");
"EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE sqlx_s_{}",
stmt_id.0
);
let mut comma = false; let mut comma = false;
if params_len > 0 { if params_len > 0 {

View file

@ -3,11 +3,10 @@ use crate::HashMap;
use crate::common::StatementCache; use crate::common::StatementCache;
use crate::connection::{sasl, stream::PgStream}; use crate::connection::{sasl, stream::PgStream};
use crate::error::Error; use crate::error::Error;
use crate::io::Decode; use crate::io::StatementId;
use crate::message::{ use crate::message::{
Authentication, BackendKeyData, MessageFormat, Password, ReadyForQuery, Startup, Authentication, BackendKeyData, BackendMessageFormat, Password, ReadyForQuery, Startup,
}; };
use crate::types::Oid;
use crate::{PgConnectOptions, PgConnection}; use crate::{PgConnectOptions, PgConnection};
// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3 // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3
@ -44,13 +43,13 @@ impl PgConnection {
params.push(("options", options)); params.push(("options", options));
} }
stream stream.write(Startup {
.send(Startup { username: Some(&options.username),
username: Some(&options.username), database: options.database.as_deref(),
database: options.database.as_deref(), params: &params,
params: &params, })?;
})
.await?; stream.flush().await?;
// The server then uses this information and the contents of // The server then uses this information and the contents of
// its configuration files (such as pg_hba.conf) to determine whether the connection is // its configuration files (such as pg_hba.conf) to determine whether the connection is
@ -64,7 +63,7 @@ impl PgConnection {
loop { loop {
let message = stream.recv().await?; let message = stream.recv().await?;
match message.format { match message.format {
MessageFormat::Authentication => match message.decode()? { BackendMessageFormat::Authentication => match message.decode()? {
Authentication::Ok => { Authentication::Ok => {
// the authentication exchange is successfully completed // the authentication exchange is successfully completed
// do nothing; no more information is required to continue // do nothing; no more information is required to continue
@ -108,7 +107,7 @@ impl PgConnection {
} }
}, },
MessageFormat::BackendKeyData => { BackendMessageFormat::BackendKeyData => {
// provides secret-key data that the frontend must save if it wants to be // provides secret-key data that the frontend must save if it wants to be
// able to issue cancel requests later // able to issue cancel requests later
@ -118,10 +117,9 @@ impl PgConnection {
secret_key = data.secret_key; secret_key = data.secret_key;
} }
MessageFormat::ReadyForQuery => { BackendMessageFormat::ReadyForQuery => {
// start-up is completed. The frontend can now issue commands // start-up is completed. The frontend can now issue commands
transaction_status = transaction_status = message.decode::<ReadyForQuery>()?.transaction_status;
ReadyForQuery::decode(message.contents)?.transaction_status;
break; break;
} }
@ -142,7 +140,7 @@ impl PgConnection {
transaction_status, transaction_status,
transaction_depth: 0, transaction_depth: 0,
pending_ready_for_query_count: 0, pending_ready_for_query_count: 0,
next_statement_id: Oid(1), next_statement_id: StatementId::NAMED_START,
cache_statement: StatementCache::new(options.statement_cache_capacity), cache_statement: StatementCache::new(options.statement_cache_capacity),
cache_type_oid: HashMap::new(), cache_type_oid: HashMap::new(),
cache_type_info: HashMap::new(), cache_type_info: HashMap::new(),

View file

@ -1,13 +1,13 @@
use crate::describe::Describe; use crate::describe::Describe;
use crate::error::Error; use crate::error::Error;
use crate::executor::{Execute, Executor}; use crate::executor::{Execute, Executor};
use crate::io::{PortalId, StatementId};
use crate::logger::QueryLogger; use crate::logger::QueryLogger;
use crate::message::{ use crate::message::{
self, Bind, Close, CommandComplete, DataRow, MessageFormat, ParameterDescription, Parse, Query, self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
RowDescription, ParseComplete, Query, RowDescription,
}; };
use crate::statement::PgStatementMetadata; use crate::statement::PgStatementMetadata;
use crate::types::Oid;
use crate::{ use crate::{
statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo, statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
PgValueFormat, Postgres, PgValueFormat, Postgres,
@ -16,6 +16,7 @@ use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream; use futures_core::stream::BoxStream;
use futures_core::Stream; use futures_core::Stream;
use futures_util::{pin_mut, TryStreamExt}; use futures_util::{pin_mut, TryStreamExt};
use sqlx_core::arguments::Arguments;
use sqlx_core::Either; use sqlx_core::Either;
use std::{borrow::Cow, sync::Arc}; use std::{borrow::Cow, sync::Arc};
@ -24,9 +25,9 @@ async fn prepare(
sql: &str, sql: &str,
parameters: &[PgTypeInfo], parameters: &[PgTypeInfo],
metadata: Option<Arc<PgStatementMetadata>>, metadata: Option<Arc<PgStatementMetadata>>,
) -> Result<(Oid, Arc<PgStatementMetadata>), Error> { ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
let id = conn.next_statement_id; let id = conn.next_statement_id;
conn.next_statement_id.incr_one(); conn.next_statement_id = id.next();
// build a list of type OIDs to send to the database in the PARSE command // build a list of type OIDs to send to the database in the PARSE command
// we have not yet started the query sequence, so we are *safe* to cleanly make // we have not yet started the query sequence, so we are *safe* to cleanly make
@ -42,15 +43,15 @@ async fn prepare(
conn.wait_until_ready().await?; conn.wait_until_ready().await?;
// next we send the PARSE command to the server // next we send the PARSE command to the server
conn.stream.write(Parse { conn.stream.write_msg(Parse {
param_types: &param_types, param_types: &param_types,
query: sql, query: sql,
statement: id, statement: id,
}); })?;
if metadata.is_none() { if metadata.is_none() {
// get the statement columns and parameters // get the statement columns and parameters
conn.stream.write(message::Describe::Statement(id)); conn.stream.write_msg(message::Describe::Statement(id))?;
} }
// we ask for the server to immediately send us the result of the PARSE command // we ask for the server to immediately send us the result of the PARSE command
@ -58,9 +59,7 @@ async fn prepare(
conn.stream.flush().await?; conn.stream.flush().await?;
// indicates that the SQL query string is now successfully parsed and has semantic validity // indicates that the SQL query string is now successfully parsed and has semantic validity
conn.stream conn.stream.recv_expect::<ParseComplete>().await?;
.recv_expect(MessageFormat::ParseComplete)
.await?;
let metadata = if let Some(metadata) = metadata { let metadata = if let Some(metadata) = metadata {
// each SYNC produces one READY FOR QUERY // each SYNC produces one READY FOR QUERY
@ -95,18 +94,18 @@ async fn prepare(
} }
async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> { async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
conn.stream conn.stream.recv_expect().await
.recv_expect(MessageFormat::ParameterDescription)
.await
} }
async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> { async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
let rows: Option<RowDescription> = match conn.stream.recv().await? { let rows: Option<RowDescription> = match conn.stream.recv().await? {
// describes the rows that will be returned when the statement is eventually executed // describes the rows that will be returned when the statement is eventually executed
message if message.format == MessageFormat::RowDescription => Some(message.decode()?), message if message.format == BackendMessageFormat::RowDescription => {
Some(message.decode()?)
}
// no data would be returned if this statement was executed // no data would be returned if this statement was executed
message if message.format == MessageFormat::NoData => None, message if message.format == BackendMessageFormat::NoData => None,
message => { message => {
return Err(err_protocol!( return Err(err_protocol!(
@ -125,12 +124,12 @@ impl PgConnection {
// we need to wait for the [CloseComplete] to be returned from the server // we need to wait for the [CloseComplete] to be returned from the server
while count > 0 { while count > 0 {
match self.stream.recv().await? { match self.stream.recv().await? {
message if message.format == MessageFormat::PortalSuspended => { message if message.format == BackendMessageFormat::PortalSuspended => {
// there was an open portal // there was an open portal
// this can happen if the last time a statement was used it was not fully executed // this can happen if the last time a statement was used it was not fully executed
} }
message if message.format == MessageFormat::CloseComplete => { message if message.format == BackendMessageFormat::CloseComplete => {
// successfully closed the statement (and freed up the server resources) // successfully closed the statement (and freed up the server resources)
count -= 1; count -= 1;
} }
@ -147,8 +146,11 @@ impl PgConnection {
Ok(()) Ok(())
} }
#[inline(always)]
pub(crate) fn write_sync(&mut self) { pub(crate) fn write_sync(&mut self) {
self.stream.write(message::Sync); self.stream
.write_msg(message::Sync)
.expect("BUG: Sync should not be too big for protocol");
// all SYNC messages will return a ReadyForQuery // all SYNC messages will return a ReadyForQuery
self.pending_ready_for_query_count += 1; self.pending_ready_for_query_count += 1;
@ -163,7 +165,7 @@ impl PgConnection {
// optional metadata that was provided by the user, this means they are reusing // optional metadata that was provided by the user, this means they are reusing
// a statement object // a statement object
metadata: Option<Arc<PgStatementMetadata>>, metadata: Option<Arc<PgStatementMetadata>>,
) -> Result<(Oid, Arc<PgStatementMetadata>), Error> { ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
if let Some(statement) = self.cache_statement.get_mut(sql) { if let Some(statement) = self.cache_statement.get_mut(sql) {
return Ok((*statement).clone()); return Ok((*statement).clone());
} }
@ -172,7 +174,7 @@ impl PgConnection {
if store_to_cache && self.cache_statement.is_enabled() { if store_to_cache && self.cache_statement.is_enabled() {
if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) { if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) {
self.stream.write(Close::Statement(id)); self.stream.write_msg(Close::Statement(id))?;
self.write_sync(); self.write_sync();
self.stream.flush().await?; self.stream.flush().await?;
@ -201,6 +203,14 @@ impl PgConnection {
let mut metadata: Arc<PgStatementMetadata>; let mut metadata: Arc<PgStatementMetadata>;
let format = if let Some(mut arguments) = arguments { let format = if let Some(mut arguments) = arguments {
// Check this before we write anything to the stream.
let num_params = i16::try_from(arguments.len()).map_err(|_| {
err_protocol!(
"PgConnection::run(): too many arguments for query: {}",
arguments.len()
)
})?;
// prepare the statement if this our first time executing it // prepare the statement if this our first time executing it
// always return the statement ID here // always return the statement ID here
let (statement, metadata_) = self let (statement, metadata_) = self
@ -216,21 +226,21 @@ impl PgConnection {
self.wait_until_ready().await?; self.wait_until_ready().await?;
// bind to attach the arguments to the statement and create a portal // bind to attach the arguments to the statement and create a portal
self.stream.write(Bind { self.stream.write_msg(Bind {
portal: None, portal: PortalId::UNNAMED,
statement, statement,
formats: &[PgValueFormat::Binary], formats: &[PgValueFormat::Binary],
num_params: arguments.types.len() as i16, num_params,
params: &arguments.buffer, params: &arguments.buffer,
result_formats: &[PgValueFormat::Binary], result_formats: &[PgValueFormat::Binary],
}); })?;
// executes the portal up to the passed limit // executes the portal up to the passed limit
// the protocol-level limit acts nearly identically to the `LIMIT` in SQL // the protocol-level limit acts nearly identically to the `LIMIT` in SQL
self.stream.write(message::Execute { self.stream.write_msg(message::Execute {
portal: None, portal: PortalId::UNNAMED,
limit: limit.into(), limit: limit.into(),
}); })?;
// From https://www.postgresql.org/docs/current/protocol-flow.html: // From https://www.postgresql.org/docs/current/protocol-flow.html:
// //
// "An unnamed portal is destroyed at the end of the transaction, or as // "An unnamed portal is destroyed at the end of the transaction, or as
@ -240,7 +250,7 @@ impl PgConnection {
// we ask the database server to close the unnamed portal and free the associated resources // we ask the database server to close the unnamed portal and free the associated resources
// earlier - after the execution of the current query. // earlier - after the execution of the current query.
self.stream.write(message::Close::Portal(None)); self.stream.write_msg(Close::Portal(PortalId::UNNAMED))?;
// finally, [Sync] asks postgres to process the messages that we sent and respond with // finally, [Sync] asks postgres to process the messages that we sent and respond with
// a [ReadyForQuery] message when it's completely done. Theoretically, we could send // a [ReadyForQuery] message when it's completely done. Theoretically, we could send
@ -253,7 +263,7 @@ impl PgConnection {
PgValueFormat::Binary PgValueFormat::Binary
} else { } else {
// Query will trigger a ReadyForQuery // Query will trigger a ReadyForQuery
self.stream.write(Query(query)); self.stream.write_msg(Query(query))?;
self.pending_ready_for_query_count += 1; self.pending_ready_for_query_count += 1;
// metadata starts out as "nothing" // metadata starts out as "nothing"
@ -270,12 +280,12 @@ impl PgConnection {
let message = self.stream.recv().await?; let message = self.stream.recv().await?;
match message.format { match message.format {
MessageFormat::BindComplete BackendMessageFormat::BindComplete
| MessageFormat::ParseComplete | BackendMessageFormat::ParseComplete
| MessageFormat::ParameterDescription | BackendMessageFormat::ParameterDescription
| MessageFormat::NoData | BackendMessageFormat::NoData
// unnamed portal has been closed // unnamed portal has been closed
| MessageFormat::CloseComplete | BackendMessageFormat::CloseComplete
=> { => {
// harmless messages to ignore // harmless messages to ignore
} }
@ -284,7 +294,7 @@ impl PgConnection {
// exactly one of these messages: CommandComplete, // exactly one of these messages: CommandComplete,
// EmptyQueryResponse (if the portal was created from an // EmptyQueryResponse (if the portal was created from an
// empty query string), ErrorResponse, or PortalSuspended" // empty query string), ErrorResponse, or PortalSuspended"
MessageFormat::CommandComplete => { BackendMessageFormat::CommandComplete => {
// a SQL command completed normally // a SQL command completed normally
let cc: CommandComplete = message.decode()?; let cc: CommandComplete = message.decode()?;
@ -295,16 +305,16 @@ impl PgConnection {
})); }));
} }
MessageFormat::EmptyQueryResponse => { BackendMessageFormat::EmptyQueryResponse => {
// empty query string passed to an unprepared execute // empty query string passed to an unprepared execute
} }
// Message::ErrorResponse is handled in self.stream.recv() // Message::ErrorResponse is handled in self.stream.recv()
// incomplete query execution has finished // incomplete query execution has finished
MessageFormat::PortalSuspended => {} BackendMessageFormat::PortalSuspended => {}
MessageFormat::RowDescription => { BackendMessageFormat::RowDescription => {
// indicates that a *new* set of rows are about to be returned // indicates that a *new* set of rows are about to be returned
let (columns, column_names) = self let (columns, column_names) = self
.handle_row_description(Some(message.decode()?), false) .handle_row_description(Some(message.decode()?), false)
@ -317,7 +327,7 @@ impl PgConnection {
}); });
} }
MessageFormat::DataRow => { BackendMessageFormat::DataRow => {
logger.increment_rows_returned(); logger.increment_rows_returned();
// one of the set of rows returned by a SELECT, FETCH, etc query // one of the set of rows returned by a SELECT, FETCH, etc query
@ -331,7 +341,7 @@ impl PgConnection {
r#yield!(Either::Right(row)); r#yield!(Either::Right(row));
} }
MessageFormat::ReadyForQuery => { BackendMessageFormat::ReadyForQuery => {
// processing of the query string is complete // processing of the query string is complete
self.handle_ready_for_query(message)?; self.handle_ready_for_query(message)?;
break; break;

View file

@ -8,9 +8,10 @@ use futures_util::FutureExt;
use crate::common::StatementCache; use crate::common::StatementCache;
use crate::error::Error; use crate::error::Error;
use crate::ext::ustr::UStr; use crate::ext::ustr::UStr;
use crate::io::Decode; use crate::io::StatementId;
use crate::message::{ use crate::message::{
Close, Message, MessageFormat, Query, ReadyForQuery, Terminate, TransactionStatus, BackendMessageFormat, Close, Query, ReadyForQuery, ReceivedMessage, Terminate,
TransactionStatus,
}; };
use crate::statement::PgStatementMetadata; use crate::statement::PgStatementMetadata;
use crate::transaction::Transaction; use crate::transaction::Transaction;
@ -47,10 +48,10 @@ pub struct PgConnection {
// sequence of statement IDs for use in preparing statements // sequence of statement IDs for use in preparing statements
// in PostgreSQL, the statement is prepared to a user-supplied identifier // in PostgreSQL, the statement is prepared to a user-supplied identifier
next_statement_id: Oid, next_statement_id: StatementId,
// cache statement by query string to the id and columns // cache statement by query string to the id and columns
cache_statement: StatementCache<(Oid, Arc<PgStatementMetadata>)>, cache_statement: StatementCache<(StatementId, Arc<PgStatementMetadata>)>,
// cache user-defined types by id <-> info // cache user-defined types by id <-> info
cache_type_info: HashMap<Oid, PgTypeInfo>, cache_type_info: HashMap<Oid, PgTypeInfo>,
@ -82,7 +83,7 @@ impl PgConnection {
while self.pending_ready_for_query_count > 0 { while self.pending_ready_for_query_count > 0 {
let message = self.stream.recv().await?; let message = self.stream.recv().await?;
if let MessageFormat::ReadyForQuery = message.format { if let BackendMessageFormat::ReadyForQuery = message.format {
self.handle_ready_for_query(message)?; self.handle_ready_for_query(message)?;
} }
} }
@ -91,10 +92,7 @@ impl PgConnection {
} }
async fn recv_ready_for_query(&mut self) -> Result<(), Error> { async fn recv_ready_for_query(&mut self) -> Result<(), Error> {
let r: ReadyForQuery = self let r: ReadyForQuery = self.stream.recv_expect().await?;
.stream
.recv_expect(MessageFormat::ReadyForQuery)
.await?;
self.pending_ready_for_query_count -= 1; self.pending_ready_for_query_count -= 1;
self.transaction_status = r.transaction_status; self.transaction_status = r.transaction_status;
@ -102,9 +100,10 @@ impl PgConnection {
Ok(()) Ok(())
} }
fn handle_ready_for_query(&mut self, message: Message) -> Result<(), Error> { #[inline(always)]
fn handle_ready_for_query(&mut self, message: ReceivedMessage) -> Result<(), Error> {
self.pending_ready_for_query_count -= 1; self.pending_ready_for_query_count -= 1;
self.transaction_status = ReadyForQuery::decode(message.contents)?.transaction_status; self.transaction_status = message.decode::<ReadyForQuery>()?.transaction_status;
Ok(()) Ok(())
} }
@ -112,9 +111,12 @@ impl PgConnection {
/// Queue a simple query (not prepared) to execute the next time this connection is used. /// Queue a simple query (not prepared) to execute the next time this connection is used.
/// ///
/// Used for rolling back transactions and releasing advisory locks. /// Used for rolling back transactions and releasing advisory locks.
pub(crate) fn queue_simple_query(&mut self, query: &str) { #[inline(always)]
pub(crate) fn queue_simple_query(&mut self, query: &str) -> Result<(), Error> {
self.stream.write_msg(Query(query))?;
self.pending_ready_for_query_count += 1; self.pending_ready_for_query_count += 1;
self.stream.write(Query(query));
Ok(())
} }
} }
@ -184,7 +186,7 @@ impl Connection for PgConnection {
self.wait_until_ready().await?; self.wait_until_ready().await?;
while let Some((id, _)) = self.cache_statement.remove_lru() { while let Some((id, _)) = self.cache_statement.remove_lru() {
self.stream.write(Close::Statement(id)); self.stream.write_msg(Close::Statement(id))?;
cleared += 1; cleared += 1;
} }

View file

@ -1,8 +1,6 @@
use crate::connection::stream::PgStream; use crate::connection::stream::PgStream;
use crate::error::Error; use crate::error::Error;
use crate::message::{ use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse};
Authentication, AuthenticationSasl, MessageFormat, SaslInitialResponse, SaslResponse,
};
use crate::PgConnectOptions; use crate::PgConnectOptions;
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
use rand::Rng; use rand::Rng;
@ -76,7 +74,7 @@ pub(crate) async fn authenticate(
}) })
.await?; .await?;
let cont = match stream.recv_expect(MessageFormat::Authentication).await? { let cont = match stream.recv_expect().await? {
Authentication::SaslContinue(data) => data, Authentication::SaslContinue(data) => data,
auth => { auth => {
@ -147,7 +145,7 @@ pub(crate) async fn authenticate(
stream.send(SaslResponse(&client_final_message)).await?; stream.send(SaslResponse(&client_final_message)).await?;
let data = match stream.recv_expect(MessageFormat::Authentication).await? { let data = match stream.recv_expect().await? {
Authentication::SaslFinal(data) => data, Authentication::SaslFinal(data) => data,
auth => { auth => {
@ -172,10 +170,10 @@ fn gen_nonce() -> String {
// ;; a valid "value". // ;; a valid "value".
let nonce: String = std::iter::repeat(()) let nonce: String = std::iter::repeat(())
.map(|()| { .map(|()| {
let mut c = rng.gen_range(0x21..0x7F) as u8; let mut c = rng.gen_range(0x21u8..0x7F);
while c == 0x2C { while c == 0x2C {
c = rng.gen_range(0x21..0x7F) as u8; c = rng.gen_range(0x21u8..0x7F);
} }
c c

View file

@ -9,8 +9,10 @@ use sqlx_core::bytes::{Buf, Bytes};
use crate::connection::tls::MaybeUpgradeTls; use crate::connection::tls::MaybeUpgradeTls;
use crate::error::Error; use crate::error::Error;
use crate::io::{Decode, Encode}; use crate::message::{
use crate::message::{Message, MessageFormat, Notice, Notification, ParameterStatus}; BackendMessage, BackendMessageFormat, EncodeMessage, FrontendMessage, Notice, Notification,
ParameterStatus, ReceivedMessage,
};
use crate::net::{self, BufferedSocket, Socket}; use crate::net::{self, BufferedSocket, Socket};
use crate::{PgConnectOptions, PgDatabaseError, PgSeverity}; use crate::{PgConnectOptions, PgDatabaseError, PgSeverity};
@ -55,59 +57,51 @@ impl PgStream {
}) })
} }
pub(crate) async fn send<'en, T>(&mut self, message: T) -> Result<(), Error> #[inline(always)]
pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> {
self.write(EncodeMessage(message))
}
pub(crate) async fn send<T>(&mut self, message: T) -> Result<(), Error>
where where
T: Encode<'en>, T: FrontendMessage,
{ {
self.write(message); self.write_msg(message)?;
self.flush().await?; self.flush().await?;
Ok(()) Ok(())
} }
// Expect a specific type and format // Expect a specific type and format
pub(crate) async fn recv_expect<'de, T: Decode<'de>>( pub(crate) async fn recv_expect<B: BackendMessage>(&mut self) -> Result<B, Error> {
&mut self, self.recv().await?.decode()
format: MessageFormat,
) -> Result<T, Error> {
let message = self.recv().await?;
if message.format != format {
return Err(err_protocol!(
"expecting {:?} but received {:?}",
format,
message.format
));
}
message.decode()
} }
pub(crate) async fn recv_unchecked(&mut self) -> Result<Message, Error> { pub(crate) async fn recv_unchecked(&mut self) -> Result<ReceivedMessage, Error> {
// all packets in postgres start with a 5-byte header // all packets in postgres start with a 5-byte header
// this header contains the message type and the total length of the message // this header contains the message type and the total length of the message
let mut header: Bytes = self.inner.read(5).await?; let mut header: Bytes = self.inner.read(5).await?;
let format = MessageFormat::try_from_u8(header.get_u8())?; let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
let size = (header.get_u32() - 4) as usize; let size = (header.get_u32() - 4) as usize;
let contents = self.inner.read(size).await?; let contents = self.inner.read(size).await?;
Ok(Message { format, contents }) Ok(ReceivedMessage { format, contents })
} }
// Get the next message from the server // Get the next message from the server
// May wait for more data from the server // May wait for more data from the server
pub(crate) async fn recv(&mut self) -> Result<Message, Error> { pub(crate) async fn recv(&mut self) -> Result<ReceivedMessage, Error> {
loop { loop {
let message = self.recv_unchecked().await?; let message = self.recv_unchecked().await?;
match message.format { match message.format {
MessageFormat::ErrorResponse => { BackendMessageFormat::ErrorResponse => {
// An error returned from the database server. // An error returned from the database server.
return Err(PgDatabaseError(message.decode()?).into()); return Err(PgDatabaseError(message.decode()?).into());
} }
MessageFormat::NotificationResponse => { BackendMessageFormat::NotificationResponse => {
if let Some(buffer) = &mut self.notifications { if let Some(buffer) = &mut self.notifications {
let notification: Notification = message.decode()?; let notification: Notification = message.decode()?;
let _ = buffer.send(notification).await; let _ = buffer.send(notification).await;
@ -116,7 +110,7 @@ impl PgStream {
} }
} }
MessageFormat::ParameterStatus => { BackendMessageFormat::ParameterStatus => {
// informs the frontend about the current (initial) // informs the frontend about the current (initial)
// setting of backend parameters // setting of backend parameters
@ -135,7 +129,7 @@ impl PgStream {
continue; continue;
} }
MessageFormat::NoticeResponse => { BackendMessageFormat::NoticeResponse => {
// do we need this to be more configurable? // do we need this to be more configurable?
// if you are reading this comment and think so, open an issue // if you are reading this comment and think so, open an issue

View file

@ -11,7 +11,8 @@ use crate::error::{Error, Result};
use crate::ext::async_stream::TryAsyncStream; use crate::ext::async_stream::TryAsyncStream;
use crate::io::AsyncRead; use crate::io::AsyncRead;
use crate::message::{ use crate::message::{
CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query, BackendMessageFormat, CommandComplete, CopyData, CopyDone, CopyFail, CopyInResponse,
CopyOutResponse, CopyResponseData, Query, ReadyForQuery,
}; };
use crate::pool::{Pool, PoolConnection}; use crate::pool::{Pool, PoolConnection};
use crate::Postgres; use crate::Postgres;
@ -138,7 +139,7 @@ impl PgPoolCopyExt for Pool<Postgres> {
#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"] #[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> { pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
conn: Option<C>, conn: Option<C>,
response: CopyResponse, response: CopyResponseData,
} }
impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> { impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
@ -146,8 +147,8 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
conn.wait_until_ready().await?; conn.wait_until_ready().await?;
conn.stream.send(Query(statement)).await?; conn.stream.send(Query(statement)).await?;
let response = match conn.stream.recv_expect(MessageFormat::CopyInResponse).await { let response = match conn.stream.recv_expect::<CopyInResponse>().await {
Ok(res) => res, Ok(res) => res.0,
Err(e) => { Err(e) => {
conn.stream.recv().await?; conn.stream.recv().await?;
return Err(e); return Err(e);
@ -168,7 +169,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
/// Returns the number of columns expected in the input. /// Returns the number of columns expected in the input.
pub fn num_columns(&self) -> usize { pub fn num_columns(&self) -> usize {
assert_eq!( assert_eq!(
self.response.num_columns as usize, self.response.num_columns.unsigned_abs() as usize,
self.response.format_codes.len(), self.response.format_codes.len(),
"num_columns does not match format_codes.len()" "num_columns does not match format_codes.len()"
); );
@ -261,9 +262,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
match e.code() { match e.code() {
Some(Cow::Borrowed("57014")) => { Some(Cow::Borrowed("57014")) => {
// postgres abort received error code // postgres abort received error code
conn.stream conn.stream.recv_expect::<ReadyForQuery>().await?;
.recv_expect(MessageFormat::ReadyForQuery)
.await?;
Ok(()) Ok(())
} }
_ => Err(Error::Database(e)), _ => Err(Error::Database(e)),
@ -283,11 +282,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
.expect("CopyWriter::finish: conn taken illegally"); .expect("CopyWriter::finish: conn taken illegally");
conn.stream.send(CopyDone).await?; conn.stream.send(CopyDone).await?;
let cc: CommandComplete = match conn let cc: CommandComplete = match conn.stream.recv_expect().await {
.stream
.recv_expect(MessageFormat::CommandComplete)
.await
{
Ok(cc) => cc, Ok(cc) => cc,
Err(e) => { Err(e) => {
conn.stream.recv().await?; conn.stream.recv().await?;
@ -295,9 +290,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
} }
}; };
conn.stream conn.stream.recv_expect::<ReadyForQuery>().await?;
.recv_expect(MessageFormat::ReadyForQuery)
.await?;
Ok(cc.rows_affected()) Ok(cc.rows_affected())
} }
@ -306,9 +299,11 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> { impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(mut conn) = self.conn.take() { if let Some(mut conn) = self.conn.take() {
conn.stream.write(CopyFail::new( conn.stream
"PgCopyIn dropped without calling finish() or fail()", .write_msg(CopyFail::new(
)); "PgCopyIn dropped without calling finish() or fail()",
))
.expect("BUG: PgCopyIn abort message should not be too large");
} }
} }
} }
@ -320,24 +315,21 @@ async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
conn.wait_until_ready().await?; conn.wait_until_ready().await?;
conn.stream.send(Query(statement)).await?; conn.stream.send(Query(statement)).await?;
let _: CopyResponse = conn let _: CopyOutResponse = conn.stream.recv_expect().await?;
.stream
.recv_expect(MessageFormat::CopyOutResponse)
.await?;
let stream: TryAsyncStream<'c, Bytes> = try_stream! { let stream: TryAsyncStream<'c, Bytes> = try_stream! {
loop { loop {
match conn.stream.recv().await { match conn.stream.recv().await {
Err(e) => { Err(e) => {
conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?; conn.stream.recv_expect::<ReadyForQuery>().await?;
return Err(e); return Err(e);
}, },
Ok(msg) => match msg.format { Ok(msg) => match msg.format {
MessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0), BackendMessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
MessageFormat::CopyDone => { BackendMessageFormat::CopyDone => {
let _ = msg.decode::<CopyDone>()?; let _ = msg.decode::<CopyDone>()?;
conn.stream.recv_expect(MessageFormat::CommandComplete).await?; conn.stream.recv_expect::<CommandComplete>().await?;
conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?; conn.stream.recv_expect::<ReadyForQuery>().await?;
return Ok(()) return Ok(())
}, },
_ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format)) _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))

View file

@ -1,54 +1,64 @@
use crate::types::Oid; use crate::io::{PortalId, StatementId};
pub trait PgBufMutExt { pub trait PgBufMutExt {
fn put_length_prefixed<F>(&mut self, f: F) fn put_length_prefixed<F>(&mut self, f: F) -> Result<(), crate::Error>
where where
F: FnOnce(&mut Vec<u8>); F: FnOnce(&mut Vec<u8>) -> Result<(), crate::Error>;
fn put_statement_name(&mut self, id: Oid); fn put_statement_name(&mut self, id: StatementId);
fn put_portal_name(&mut self, id: Option<Oid>); fn put_portal_name(&mut self, id: PortalId);
} }
impl PgBufMutExt for Vec<u8> { impl PgBufMutExt for Vec<u8> {
// writes a length-prefixed message, this is used when encoding nearly all messages as postgres // writes a length-prefixed message, this is used when encoding nearly all messages as postgres
// wants us to send the length of the often-variable-sized messages up front // wants us to send the length of the often-variable-sized messages up front
fn put_length_prefixed<F>(&mut self, f: F) fn put_length_prefixed<F>(&mut self, write_contents: F) -> Result<(), crate::Error>
where where
F: FnOnce(&mut Vec<u8>), F: FnOnce(&mut Vec<u8>) -> Result<(), crate::Error>,
{ {
// reserve space to write the prefixed length // reserve space to write the prefixed length
let offset = self.len(); let offset = self.len();
self.extend(&[0; 4]); self.extend(&[0; 4]);
// write the main body of the message // write the main body of the message
f(self); let write_result = write_contents(self);
// now calculate the size of what we wrote and set the length value let size_result = write_result.and_then(|_| {
let size = (self.len() - offset) as i32; let size = self.len() - offset;
self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes()); i32::try_from(size)
.map_err(|_| err_protocol!("message size out of range for Pg protocol: {size"))
});
match size_result {
Ok(size) => {
// now calculate the size of what we wrote and set the length value
self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes());
Ok(())
}
Err(e) => {
// Put the buffer back to where it was.
self.truncate(offset);
Err(e)
}
}
} }
// writes a statement name by ID // writes a statement name by ID
#[inline] #[inline]
fn put_statement_name(&mut self, id: Oid) { fn put_statement_name(&mut self, id: StatementId) {
// N.B. if you change this don't forget to update it in ../describe.rs let _: Result<(), ()> = id.write_name(|s| {
self.extend(b"sqlx_s_"); self.extend_from_slice(s.as_bytes());
Ok(())
self.extend(itoa::Buffer::new().format(id.0).as_bytes()); });
self.push(0);
} }
// writes a portal name by ID // writes a portal name by ID
#[inline] #[inline]
fn put_portal_name(&mut self, id: Option<Oid>) { fn put_portal_name(&mut self, id: PortalId) {
if let Some(id) = id { let _: Result<(), ()> = id.write_name(|s| {
self.extend(b"sqlx_p_"); self.extend_from_slice(s.as_bytes());
Ok(())
self.extend(itoa::Buffer::new().format(id.0).as_bytes()); });
}
self.push(0);
} }
} }

View file

@ -1,5 +1,130 @@
mod buf_mut; mod buf_mut;
pub use buf_mut::PgBufMutExt; pub use buf_mut::PgBufMutExt;
use std::fmt;
use std::fmt::{Display, Formatter};
use std::num::{NonZeroU32, Saturating};
pub(crate) use sqlx_core::io::*; pub(crate) use sqlx_core::io::*;
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) struct StatementId(IdInner);
#[allow(dead_code)]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) struct PortalId(IdInner);
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct IdInner(Option<NonZeroU32>);
impl StatementId {
pub const UNNAMED: Self = Self(IdInner::UNNAMED);
pub const NAMED_START: Self = Self(IdInner::NAMED_START);
#[cfg(test)]
pub const TEST_VAL: Self = Self(IdInner::TEST_VAL);
const NAME_PREFIX: &'static str = "sqlx_s_";
pub fn next(&self) -> Self {
Self(self.0.next())
}
pub fn name_len(&self) -> Saturating<usize> {
self.0.name_len(Self::NAME_PREFIX)
}
// There's no common trait implemented by `Formatter` and `Vec<u8>` for this purpose;
// we're deliberately avoiding the formatting machinery because it's known to be slow.
pub fn write_name<E>(&self, write: impl FnMut(&str) -> Result<(), E>) -> Result<(), E> {
self.0.write_name(Self::NAME_PREFIX, write)
}
}
impl Display for StatementId {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
self.write_name(|s| f.write_str(s))
}
}
#[allow(dead_code)]
impl PortalId {
// None selects the unnamed portal
pub const UNNAMED: Self = PortalId(IdInner::UNNAMED);
pub const NAMED_START: Self = PortalId(IdInner::NAMED_START);
#[cfg(test)]
pub const TEST_VAL: Self = Self(IdInner::TEST_VAL);
const NAME_PREFIX: &'static str = "sqlx_p_";
/// If ID represents a named portal, return the next ID, wrapping on overflow.
///
/// If this ID represents the unnamed portal, return the same.
pub fn next(&self) -> Self {
Self(self.0.next())
}
/// Calculate the number of bytes that will be written by [`Self::write_name()`].
pub fn name_len(&self) -> Saturating<usize> {
self.0.name_len(Self::NAME_PREFIX)
}
pub fn write_name<E>(&self, write: impl FnMut(&str) -> Result<(), E>) -> Result<(), E> {
self.0.write_name(Self::NAME_PREFIX, write)
}
}
impl IdInner {
const UNNAMED: Self = Self(None);
const NAMED_START: Self = Self(Some(NonZeroU32::MIN));
#[cfg(test)]
pub const TEST_VAL: Self = Self(NonZeroU32::new(1234567890));
#[inline(always)]
fn next(&self) -> Self {
Self(
self.0
.map(|id| id.checked_add(1).unwrap_or(NonZeroU32::MIN)),
)
}
#[inline(always)]
fn name_len(&self, name_prefix: &str) -> Saturating<usize> {
let mut len = Saturating(0);
if let Some(id) = self.0 {
len += name_prefix.len();
// estimate the length of the ID in decimal
// `.ilog10()` can't panic since the value is never zero
len += id.get().ilog10() as usize;
// add one to compensate for `ilog10()` rounding down.
len += 1;
}
// count the NUL terminator
len += 1;
len
}
#[inline(always)]
fn write_name<E>(
&self,
name_prefix: &str,
mut write: impl FnMut(&str) -> Result<(), E>,
) -> Result<(), E> {
if let Some(id) = self.0 {
write(name_prefix)?;
write(itoa::Buffer::new().format(id.get()))?;
}
write("\0")?;
Ok(())
}
}

View file

@ -11,7 +11,7 @@ use sqlx_core::Either;
use crate::describe::Describe; use crate::describe::Describe;
use crate::error::Error; use crate::error::Error;
use crate::executor::{Execute, Executor}; use crate::executor::{Execute, Executor};
use crate::message::{MessageFormat, Notification}; use crate::message::{BackendMessageFormat, Notification};
use crate::pool::PoolOptions; use crate::pool::PoolOptions;
use crate::pool::{Pool, PoolConnection}; use crate::pool::{Pool, PoolConnection};
use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres}; use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres};
@ -277,12 +277,12 @@ impl PgListener {
match message.format { match message.format {
// We've received an async notification, return it. // We've received an async notification, return it.
MessageFormat::NotificationResponse => { BackendMessageFormat::NotificationResponse => {
return Ok(Some(PgNotification(message.decode()?))); return Ok(Some(PgNotification(message.decode()?)));
} }
// Mark the connection as ready for another query // Mark the connection as ready for another query
MessageFormat::ReadyForQuery => { BackendMessageFormat::ReadyForQuery => {
self.connection().await?.pending_ready_for_query_count -= 1; self.connection().await?.pending_ready_for_query_count -= 1;
} }

View file

@ -4,10 +4,10 @@ use memchr::memchr;
use sqlx_core::bytes::{Buf, Bytes}; use sqlx_core::bytes::{Buf, Bytes};
use crate::error::Error; use crate::error::Error;
use crate::io::Decode; use crate::io::ProtocolDecode;
use crate::message::{BackendMessage, BackendMessageFormat};
use base64::prelude::{Engine as _, BASE64_STANDARD}; use base64::prelude::{Engine as _, BASE64_STANDARD};
// On startup, the server sends an appropriate authentication request message, // On startup, the server sends an appropriate authentication request message,
// to which the frontend must reply with an appropriate authentication // to which the frontend must reply with an appropriate authentication
// response message (such as a password). // response message (such as a password).
@ -60,8 +60,10 @@ pub enum Authentication {
SaslFinal(AuthenticationSaslFinal), SaslFinal(AuthenticationSaslFinal),
} }
impl Decode<'_> for Authentication { impl BackendMessage for Authentication {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> { const FORMAT: BackendMessageFormat = BackendMessageFormat::Authentication;
fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
Ok(match buf.get_u32() { Ok(match buf.get_u32() {
0 => Authentication::Ok, 0 => Authentication::Ok,
@ -129,7 +131,7 @@ pub struct AuthenticationSaslContinue {
pub message: String, pub message: String,
} }
impl Decode<'_> for AuthenticationSaslContinue { impl ProtocolDecode<'_> for AuthenticationSaslContinue {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> { fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
let mut iterations: u32 = 4096; let mut iterations: u32 = 4096;
let mut salt = Vec::new(); let mut salt = Vec::new();
@ -173,7 +175,7 @@ pub struct AuthenticationSaslFinal {
pub verifier: Vec<u8>, pub verifier: Vec<u8>,
} }
impl Decode<'_> for AuthenticationSaslFinal { impl ProtocolDecode<'_> for AuthenticationSaslFinal {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> { fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
let mut verifier = Vec::new(); let mut verifier = Vec::new();

View file

@ -2,7 +2,7 @@ use byteorder::{BigEndian, ByteOrder};
use sqlx_core::bytes::Bytes; use sqlx_core::bytes::Bytes;
use crate::error::Error; use crate::error::Error;
use crate::io::Decode; use crate::message::{BackendMessage, BackendMessageFormat};
/// Contains cancellation key data. The frontend must save these values if it /// Contains cancellation key data. The frontend must save these values if it
/// wishes to be able to issue `CancelRequest` messages later. /// wishes to be able to issue `CancelRequest` messages later.
@ -15,8 +15,10 @@ pub struct BackendKeyData {
pub secret_key: u32, pub secret_key: u32,
} }
impl Decode<'_> for BackendKeyData { impl BackendMessage for BackendKeyData {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> { const FORMAT: BackendMessageFormat = BackendMessageFormat::BackendKeyData;
fn decode_body(buf: Bytes) -> Result<Self, Error> {
let process_id = BigEndian::read_u32(&buf); let process_id = BigEndian::read_u32(&buf);
let secret_key = BigEndian::read_u32(&buf[4..]); let secret_key = BigEndian::read_u32(&buf[4..]);
@ -31,7 +33,7 @@ impl Decode<'_> for BackendKeyData {
fn test_decode_backend_key_data() { fn test_decode_backend_key_data() {
const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+";
let m = BackendKeyData::decode(DATA.into()).unwrap(); let m = BackendKeyData::decode_body(DATA.into()).unwrap();
assert_eq!(m.process_id, 10182); assert_eq!(m.process_id, 10182);
assert_eq!(m.secret_key, 2303903019); assert_eq!(m.secret_key, 2303903019);
@ -43,6 +45,6 @@ fn bench_decode_backend_key_data(b: &mut test::Bencher) {
const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+";
b.iter(|| { b.iter(|| {
BackendKeyData::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); BackendKeyData::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap();
}); });
} }

View file

@ -1,15 +1,15 @@
use crate::io::Encode; use crate::io::{PgBufMutExt, PortalId, StatementId};
use crate::io::PgBufMutExt; use crate::message::{FrontendMessage, FrontendMessageFormat};
use crate::types::Oid;
use crate::PgValueFormat; use crate::PgValueFormat;
use std::num::Saturating;
#[derive(Debug)] #[derive(Debug)]
pub struct Bind<'a> { pub struct Bind<'a> {
/// The ID of the destination portal (`None` selects the unnamed portal). /// The ID of the destination portal (`PortalId::UNNAMED` selects the unnamed portal).
pub portal: Option<Oid>, pub portal: PortalId,
/// The id of the source prepared statement. /// The id of the source prepared statement.
pub statement: Oid, pub statement: StatementId,
/// The parameter format codes. Each must presently be zero (text) or one (binary). /// The parameter format codes. Each must presently be zero (text) or one (binary).
/// ///
@ -19,6 +19,8 @@ pub struct Bind<'a> {
pub formats: &'a [PgValueFormat], pub formats: &'a [PgValueFormat],
/// The number of parameters. /// The number of parameters.
///
/// May be different from `formats.len()`
pub num_params: i16, pub num_params: i16,
/// The value of each parameter, in the indicated format. /// The value of each parameter, in the indicated format.
@ -33,31 +35,59 @@ pub struct Bind<'a> {
pub result_formats: &'a [PgValueFormat], pub result_formats: &'a [PgValueFormat],
} }
impl Encode<'_> for Bind<'_> { impl FrontendMessage for Bind<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Bind;
buf.push(b'B');
buf.put_length_prefixed(|buf| { fn body_size_hint(&self) -> Saturating<usize> {
buf.put_portal_name(self.portal); let mut size = Saturating(0);
size += self.portal.name_len();
size += self.statement.name_len();
buf.put_statement_name(self.statement); // Parameter formats and length prefix
size += 2;
size += self.formats.len();
buf.extend(&(self.formats.len() as i16).to_be_bytes()); // `num_params`
size += 2;
for &format in self.formats { size += self.params.len();
buf.extend(&(format as i16).to_be_bytes());
}
buf.extend(&self.num_params.to_be_bytes()); // Result formats and length prefix
size += 2;
size += self.result_formats.len();
buf.extend(self.params); size
}
buf.extend(&(self.result_formats.len() as i16).to_be_bytes()); fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), crate::Error> {
buf.put_portal_name(self.portal);
for &format in self.result_formats { buf.put_statement_name(self.statement);
buf.extend(&(format as i16).to_be_bytes());
} let formats_len = i16::try_from(self.formats.len()).map_err(|_| {
}); err_protocol!("too many parameter format codes ({})", self.formats.len())
})?;
buf.extend(formats_len.to_be_bytes());
for &format in self.formats {
buf.extend((format as i16).to_be_bytes());
}
buf.extend(self.num_params.to_be_bytes());
buf.extend(self.params);
let result_formats_len = i16::try_from(self.formats.len())
.map_err(|_| err_protocol!("too many result format codes ({})", self.formats.len()))?;
buf.extend(result_formats_len.to_be_bytes());
for &format in self.result_formats {
buf.extend((format as i16).to_be_bytes());
}
Ok(())
} }
} }

View file

@ -1,6 +1,6 @@
use crate::io::Encode; use crate::io::{PgBufMutExt, PortalId, StatementId};
use crate::io::PgBufMutExt; use crate::message::{FrontendMessage, FrontendMessageFormat};
use crate::types::Oid; use std::num::Saturating;
const CLOSE_PORTAL: u8 = b'P'; const CLOSE_PORTAL: u8 = b'P';
const CLOSE_STATEMENT: u8 = b'S'; const CLOSE_STATEMENT: u8 = b'S';
@ -8,18 +8,27 @@ const CLOSE_STATEMENT: u8 = b'S';
#[derive(Debug)] #[derive(Debug)]
#[allow(dead_code)] #[allow(dead_code)]
pub enum Close { pub enum Close {
Statement(Oid), Statement(StatementId),
// None selects the unnamed portal Portal(PortalId),
Portal(Option<Oid>),
} }
impl Encode<'_> for Close { impl FrontendMessage for Close {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Close;
// 15 bytes for 1-digit statement/portal IDs
buf.reserve(20);
buf.push(b'C');
buf.put_length_prefixed(|buf| match self { fn body_size_hint(&self) -> Saturating<usize> {
// Either `CLOSE_PORTAL` or `CLOSE_STATEMENT`
let mut size = Saturating(1);
match self {
Close::Statement(id) => size += id.name_len(),
Close::Portal(id) => size += id.name_len(),
}
size
}
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), crate::Error> {
match self {
Close::Statement(id) => { Close::Statement(id) => {
buf.push(CLOSE_STATEMENT); buf.push(CLOSE_STATEMENT);
buf.put_statement_name(*id); buf.put_statement_name(*id);
@ -29,6 +38,8 @@ impl Encode<'_> for Close {
buf.push(CLOSE_PORTAL); buf.push(CLOSE_PORTAL);
buf.put_portal_name(*id); buf.put_portal_name(*id);
} }
}) }
Ok(())
} }
} }

View file

@ -3,7 +3,7 @@ use memchr::memrchr;
use sqlx_core::bytes::Bytes; use sqlx_core::bytes::Bytes;
use crate::error::Error; use crate::error::Error;
use crate::io::Decode; use crate::message::{BackendMessage, BackendMessageFormat};
#[derive(Debug)] #[derive(Debug)]
pub struct CommandComplete { pub struct CommandComplete {
@ -12,10 +12,11 @@ pub struct CommandComplete {
tag: Bytes, tag: Bytes,
} }
impl Decode<'_> for CommandComplete { impl BackendMessage for CommandComplete {
#[inline] const FORMAT: BackendMessageFormat = BackendMessageFormat::CommandComplete;
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
Ok(CommandComplete { tag: buf }) fn decode_body(bytes: Bytes) -> Result<Self, Error> {
Ok(CommandComplete { tag: bytes })
} }
} }
@ -35,7 +36,7 @@ impl CommandComplete {
fn test_decode_command_complete_for_insert() { fn test_decode_command_complete_for_insert() {
const DATA: &[u8] = b"INSERT 0 1214\0"; const DATA: &[u8] = b"INSERT 0 1214\0";
let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap();
assert_eq!(cc.rows_affected(), 1214); assert_eq!(cc.rows_affected(), 1214);
} }
@ -44,7 +45,7 @@ fn test_decode_command_complete_for_insert() {
fn test_decode_command_complete_for_begin() { fn test_decode_command_complete_for_begin() {
const DATA: &[u8] = b"BEGIN\0"; const DATA: &[u8] = b"BEGIN\0";
let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap();
assert_eq!(cc.rows_affected(), 0); assert_eq!(cc.rows_affected(), 0);
} }
@ -53,7 +54,7 @@ fn test_decode_command_complete_for_begin() {
fn test_decode_command_complete_for_update() { fn test_decode_command_complete_for_update() {
const DATA: &[u8] = b"UPDATE 5\0"; const DATA: &[u8] = b"UPDATE 5\0";
let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap();
assert_eq!(cc.rows_affected(), 5); assert_eq!(cc.rows_affected(), 5);
} }
@ -64,7 +65,7 @@ fn bench_decode_command_complete(b: &mut test::Bencher) {
const DATA: &[u8] = b"INSERT 0 1214\0"; const DATA: &[u8] = b"INSERT 0 1214\0";
b.iter(|| { b.iter(|| {
let _ = CommandComplete::decode(test::black_box(Bytes::from_static(DATA))); let _ = CommandComplete::decode_body(test::black_box(Bytes::from_static(DATA)));
}); });
} }
@ -73,7 +74,7 @@ fn bench_decode_command_complete(b: &mut test::Bencher) {
fn bench_decode_command_complete_rows_affected(b: &mut test::Bencher) { fn bench_decode_command_complete_rows_affected(b: &mut test::Bencher) {
const DATA: &[u8] = b"INSERT 0 1214\0"; const DATA: &[u8] = b"INSERT 0 1214\0";
let data = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); let data = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap();
b.iter(|| { b.iter(|| {
let _rows = test::black_box(&data).rows_affected(); let _rows = test::black_box(&data).rows_affected();

View file

@ -1,15 +1,25 @@
use crate::error::Result; use crate::error::Result;
use crate::io::{BufExt, BufMutExt, Decode, Encode}; use crate::io::BufMutExt;
use sqlx_core::bytes::{Buf, BufMut, Bytes}; use crate::message::{
BackendMessage, BackendMessageFormat, FrontendMessage, FrontendMessageFormat,
};
use sqlx_core::bytes::{Buf, Bytes};
use sqlx_core::Error;
use std::num::Saturating;
use std::ops::Deref; use std::ops::Deref;
/// The same structure is sent for both `CopyInResponse` and `CopyOutResponse` /// The same structure is sent for both `CopyInResponse` and `CopyOutResponse`
pub struct CopyResponse { pub struct CopyResponseData {
pub format: i8, pub format: i8,
pub num_columns: i16, pub num_columns: i16,
pub format_codes: Vec<i16>, pub format_codes: Vec<i16>,
} }
pub struct CopyInResponse(pub CopyResponseData);
#[allow(dead_code)]
pub struct CopyOutResponse(pub CopyResponseData);
pub struct CopyData<B>(pub B); pub struct CopyData<B>(pub B);
pub struct CopyFail { pub struct CopyFail {
@ -18,14 +28,15 @@ pub struct CopyFail {
pub struct CopyDone; pub struct CopyDone;
impl Decode<'_> for CopyResponse { impl CopyResponseData {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self> { #[inline]
fn decode(mut buf: Bytes) -> Result<Self> {
let format = buf.get_i8(); let format = buf.get_i8();
let num_columns = buf.get_i16(); let num_columns = buf.get_i16();
let format_codes = (0..num_columns).map(|_| buf.get_i16()).collect(); let format_codes = (0..num_columns).map(|_| buf.get_i16()).collect();
Ok(CopyResponse { Ok(CopyResponseData {
format, format,
num_columns, num_columns,
format_codes, format_codes,
@ -33,40 +44,65 @@ impl Decode<'_> for CopyResponse {
} }
} }
impl Decode<'_> for CopyData<Bytes> { impl BackendMessage for CopyInResponse {
fn decode_with(buf: Bytes, _: ()) -> Result<Self> { const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyInResponse;
// well.. that was easy
Ok(CopyData(buf)) #[inline(always)]
fn decode_body(buf: Bytes) -> std::result::Result<Self, Error> {
Ok(Self(CopyResponseData::decode(buf)?))
} }
} }
impl<B: Deref<Target = [u8]>> Encode<'_> for CopyData<B> { impl BackendMessage for CopyOutResponse {
fn encode_with(&self, buf: &mut Vec<u8>, _context: ()) { const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyOutResponse;
buf.push(b'd');
buf.put_u32(self.0.len() as u32 + 4); #[inline(always)]
fn decode_body(buf: Bytes) -> std::result::Result<Self, Error> {
Ok(Self(CopyResponseData::decode(buf)?))
}
}
impl BackendMessage for CopyData<Bytes> {
const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyData;
#[inline(always)]
fn decode_body(buf: Bytes) -> std::result::Result<Self, Error> {
Ok(Self(buf))
}
}
impl<B: Deref<Target = [u8]>> FrontendMessage for CopyData<B> {
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyData;
#[inline(always)]
fn body_size_hint(&self) -> Saturating<usize> {
Saturating(self.0.len())
}
#[inline(always)]
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
buf.extend_from_slice(&self.0); buf.extend_from_slice(&self.0);
Ok(())
} }
} }
impl Decode<'_> for CopyFail { impl FrontendMessage for CopyFail {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self> { const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyFail;
Ok(CopyFail {
message: buf.get_str_nul()?, #[inline(always)]
}) fn body_size_hint(&self) -> Saturating<usize> {
Saturating(self.message.len())
} }
}
impl Encode<'_> for CopyFail { #[inline(always)]
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { fn encode_body(&self, buf: &mut Vec<u8>) -> std::result::Result<(), Error> {
let len = 4 + self.message.len() + 1;
buf.push(b'f'); // to pay respects
buf.put_u32(len as u32);
buf.put_str_nul(&self.message); buf.put_str_nul(&self.message);
Ok(())
} }
} }
impl CopyFail { impl CopyFail {
#[inline(always)]
pub fn new(msg: impl Into<String>) -> CopyFail { pub fn new(msg: impl Into<String>) -> CopyFail {
CopyFail { CopyFail {
message: msg.into(), message: msg.into(),
@ -74,23 +110,32 @@ impl CopyFail {
} }
} }
impl Decode<'_> for CopyDone { impl FrontendMessage for CopyDone {
fn decode_with(buf: Bytes, _: ()) -> Result<Self> { const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyDone;
if buf.is_empty() { #[inline(always)]
Ok(CopyDone) fn body_size_hint(&self) -> Saturating<usize> {
} else { Saturating(0)
Err(err_protocol!( }
"expected no data for CopyDone, got: {:?}",
buf #[inline(always)]
)) fn encode_body(&self, _buf: &mut Vec<u8>) -> std::result::Result<(), Error> {
} Ok(())
} }
} }
impl Encode<'_> for CopyDone { impl BackendMessage for CopyDone {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyDone;
buf.reserve(4);
buf.push(b'c'); #[inline(always)]
buf.put_u32(4); fn decode_body(bytes: Bytes) -> std::result::Result<Self, Error> {
if !bytes.is_empty() {
// Not fatal but may indicate a protocol change
tracing::debug!(
"Postgres backend returned non-empty message for CopyDone: \"{}\"",
bytes.escape_ascii()
)
}
Ok(CopyDone)
} }
} }

View file

@ -1,10 +1,9 @@
use std::ops::Range;
use byteorder::{BigEndian, ByteOrder}; use byteorder::{BigEndian, ByteOrder};
use sqlx_core::bytes::Bytes; use sqlx_core::bytes::Bytes;
use std::ops::Range;
use crate::error::Error; use crate::error::Error;
use crate::io::Decode; use crate::message::{BackendMessage, BackendMessageFormat};
/// A row of data from the database. /// A row of data from the database.
#[derive(Debug)] #[derive(Debug)]
@ -26,25 +25,55 @@ impl DataRow {
} }
} }
impl Decode<'_> for DataRow { impl BackendMessage for DataRow {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> { const FORMAT: BackendMessageFormat = BackendMessageFormat::DataRow;
fn decode_body(buf: Bytes) -> Result<Self, Error> {
if buf.len() < 2 {
return Err(err_protocol!(
"expected at least 2 bytes, got {}",
buf.len()
));
}
let cnt = BigEndian::read_u16(&buf) as usize; let cnt = BigEndian::read_u16(&buf) as usize;
let mut values = Vec::with_capacity(cnt); let mut values = Vec::with_capacity(cnt);
let mut offset = 2; let mut offset: u32 = 2;
for _ in 0..cnt { for _ in 0..cnt {
let value_start = offset
.checked_add(4)
.ok_or_else(|| err_protocol!("next value start out of range (offset: {offset})"))?;
// widen both to a larger type for a safe comparison
if (buf.len() as u64) < (value_start as u64) {
return Err(err_protocol!(
"expected 4 bytes at offset {offset}, got {}",
(value_start as u64) - (buf.len() as u64)
));
}
// Length of the column value, in bytes (this count does not include itself). // Length of the column value, in bytes (this count does not include itself).
// Can be zero. As a special case, -1 indicates a NULL column value. // Can be zero. As a special case, -1 indicates a NULL column value.
// No value bytes follow in the NULL case. // No value bytes follow in the NULL case.
//
// we know `offset` is within range of `buf.len()` from the above check
#[allow(clippy::cast_possible_truncation)]
let length = BigEndian::read_i32(&buf[(offset as usize)..]); let length = BigEndian::read_i32(&buf[(offset as usize)..]);
offset += 4;
if length < 0 { if let Ok(length) = u32::try_from(length) {
values.push(None); let value_end = value_start.checked_add(length).ok_or_else(|| {
err_protocol!("value_start + length out of range ({offset} + {length})")
})?;
values.push(Some(value_start..value_end));
offset = value_end;
} else { } else {
values.push(Some(offset..(offset + length as u32))); // Negative values signify NULL
offset += length as u32; values.push(None);
// `value_start` is actually the next value now.
offset = value_start;
} }
} }
@ -57,9 +86,22 @@ impl Decode<'_> for DataRow {
#[test] #[test]
fn test_decode_data_row() { fn test_decode_data_row() {
const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P"; const DATA: &[u8] = b"\
\x00\x08\
\xff\xff\xff\xff\
\x00\x00\x00\x04\
\x00\x00\x00\n\
\xff\xff\xff\xff\
\x00\x00\x00\x04\
\x00\x00\x00\x14\
\xff\xff\xff\xff\
\x00\x00\x00\x04\
\x00\x00\x00(\
\xff\xff\xff\xff\
\x00\x00\x00\x04\
\x00\x00\x00P";
let row = DataRow::decode(DATA.into()).unwrap(); let row = DataRow::decode_body(DATA.into()).unwrap();
assert_eq!(row.values.len(), 8); assert_eq!(row.values.len(), 8);
@ -78,7 +120,7 @@ fn test_decode_data_row() {
fn bench_data_row_get(b: &mut test::Bencher) { fn bench_data_row_get(b: &mut test::Bencher) {
const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P"; const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P";
let row = DataRow::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); let row = DataRow::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap();
b.iter(|| { b.iter(|| {
let _value = test::black_box(&row).get(3); let _value = test::black_box(&row).get(3);
@ -91,6 +133,6 @@ fn bench_decode_data_row(b: &mut test::Bencher) {
const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P"; const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P";
b.iter(|| { b.iter(|| {
let _ = DataRow::decode(test::black_box(Bytes::from_static(DATA))); let _ = DataRow::decode_body(test::black_box(Bytes::from_static(DATA)));
}); });
} }

View file

@ -1,127 +1,103 @@
use crate::io::Encode; use crate::io::{PgBufMutExt, PortalId, StatementId};
use crate::io::PgBufMutExt; use crate::message::{FrontendMessage, FrontendMessageFormat};
use crate::types::Oid; use sqlx_core::Error;
use std::num::Saturating;
const DESCRIBE_PORTAL: u8 = b'P'; const DESCRIBE_PORTAL: u8 = b'P';
const DESCRIBE_STATEMENT: u8 = b'S'; const DESCRIBE_STATEMENT: u8 = b'S';
// [Describe] will emit both a [RowDescription] and a [ParameterDescription] message /// Note: will emit both a RowDescription and a ParameterDescription message
#[derive(Debug)] #[derive(Debug)]
#[allow(dead_code)] #[allow(dead_code)]
pub enum Describe { pub enum Describe {
UnnamedStatement, Statement(StatementId),
Statement(Oid), Portal(PortalId),
UnnamedPortal,
Portal(Oid),
} }
impl Encode<'_> for Describe { impl FrontendMessage for Describe {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Describe;
// 15 bytes for 1-digit statement/portal IDs
buf.reserve(20);
buf.push(b'D');
buf.put_length_prefixed(|buf| { fn body_size_hint(&self) -> Saturating<usize> {
match self { // Either `DESCRIBE_PORTAL` or `DESCRIBE_STATEMENT`
// #[likely] let mut size = Saturating(1);
Describe::Statement(id) => {
buf.push(DESCRIBE_STATEMENT);
buf.put_statement_name(*id);
}
Describe::UnnamedPortal => { match self {
buf.push(DESCRIBE_PORTAL); Describe::Statement(id) => size += id.name_len(),
buf.push(0); Describe::Portal(id) => size += id.name_len(),
} }
Describe::UnnamedStatement => { size
buf.push(DESCRIBE_STATEMENT); }
buf.push(0);
}
Describe::Portal(id) => { fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
buf.push(DESCRIBE_PORTAL); match self {
buf.put_portal_name(Some(*id)); // #[likely]
} Describe::Statement(id) => {
buf.push(DESCRIBE_STATEMENT);
buf.put_statement_name(*id);
} }
});
Describe::Portal(id) => {
buf.push(DESCRIBE_PORTAL);
buf.put_portal_name(*id);
}
}
Ok(())
} }
} }
#[test] #[cfg(test)]
fn test_encode_describe_portal() { mod tests {
const EXPECTED: &[u8] = b"D\0\0\0\x0EPsqlx_p_5\0"; use crate::message::FrontendMessage;
let mut buf = Vec::new(); use super::{Describe, PortalId, StatementId};
let m = Describe::Portal(Oid(5));
m.encode(&mut buf); #[test]
fn test_encode_describe_portal() {
const EXPECTED: &[u8] = b"D\0\0\0\x17Psqlx_p_1234567890\0";
assert_eq!(buf, EXPECTED); let mut buf = Vec::new();
} let m = Describe::Portal(PortalId::TEST_VAL);
#[test] m.encode_msg(&mut buf).unwrap();
fn test_encode_describe_unnamed_portal() {
const EXPECTED: &[u8] = b"D\0\0\0\x06P\0"; assert_eq!(buf, EXPECTED);
}
let mut buf = Vec::new();
let m = Describe::UnnamedPortal; #[test]
fn test_encode_describe_unnamed_portal() {
m.encode(&mut buf); const EXPECTED: &[u8] = b"D\0\0\0\x06P\0";
assert_eq!(buf, EXPECTED); let mut buf = Vec::new();
} let m = Describe::Portal(PortalId::UNNAMED);
#[test] m.encode_msg(&mut buf).unwrap();
fn test_encode_describe_statement() {
const EXPECTED: &[u8] = b"D\0\0\0\x0ESsqlx_s_5\0"; assert_eq!(buf, EXPECTED);
}
let mut buf = Vec::new();
let m = Describe::Statement(Oid(5)); #[test]
fn test_encode_describe_statement() {
m.encode(&mut buf); const EXPECTED: &[u8] = b"D\0\0\0\x17Ssqlx_s_1234567890\0";
assert_eq!(buf, EXPECTED); let mut buf = Vec::new();
} let m = Describe::Statement(StatementId::TEST_VAL);
#[test] m.encode_msg(&mut buf).unwrap();
fn test_encode_describe_unnamed_statement() {
const EXPECTED: &[u8] = b"D\0\0\0\x06S\0"; assert_eq!(buf, EXPECTED);
}
let mut buf = Vec::new();
let m = Describe::UnnamedStatement; #[test]
fn test_encode_describe_unnamed_statement() {
m.encode(&mut buf); const EXPECTED: &[u8] = b"D\0\0\0\x06S\0";
assert_eq!(buf, EXPECTED); let mut buf = Vec::new();
} let m = Describe::Statement(StatementId::UNNAMED);
#[cfg(all(test, not(debug_assertions)))] m.encode_msg(&mut buf).unwrap();
#[bench]
fn bench_encode_describe_portal(b: &mut test::Bencher) { assert_eq!(buf, EXPECTED);
use test::black_box; }
let mut buf = Vec::with_capacity(128);
b.iter(|| {
buf.clear();
black_box(Describe::Portal(5)).encode(&mut buf);
});
}
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_encode_describe_unnamed_statement(b: &mut test::Bencher) {
use test::black_box;
let mut buf = Vec::with_capacity(128);
b.iter(|| {
buf.clear();
black_box(Describe::UnnamedStatement).encode(&mut buf);
});
} }

View file

@ -1,39 +1,73 @@
use crate::io::Encode; use std::num::Saturating;
use crate::io::PgBufMutExt;
use crate::types::Oid; use sqlx_core::Error;
use crate::io::{PgBufMutExt, PortalId};
use crate::message::{FrontendMessage, FrontendMessageFormat};
pub struct Execute { pub struct Execute {
/// The id of the portal to execute (`None` selects the unnamed portal). /// The id of the portal to execute.
pub portal: Option<Oid>, pub portal: PortalId,
/// Maximum number of rows to return, if portal contains a query /// Maximum number of rows to return, if portal contains a query
/// that returns rows (ignored otherwise). Zero denotes “no limit”. /// that returns rows (ignored otherwise). Zero denotes “no limit”.
pub limit: u32, pub limit: u32,
} }
impl Encode<'_> for Execute { impl FrontendMessage for Execute {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Execute;
buf.reserve(20);
buf.push(b'E');
buf.put_length_prefixed(|buf| { fn body_size_hint(&self) -> Saturating<usize> {
buf.put_portal_name(self.portal); let mut size = Saturating(0);
buf.extend(&self.limit.to_be_bytes());
}); size += self.portal.name_len();
size += 2; // limit
size
}
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
buf.put_portal_name(self.portal);
buf.extend(&self.limit.to_be_bytes());
Ok(())
} }
} }
#[test] #[cfg(test)]
fn test_encode_execute() { mod tests {
const EXPECTED: &[u8] = b"E\0\0\0\x11sqlx_p_5\0\0\0\0\x02"; use crate::io::PortalId;
use crate::message::FrontendMessage;
let mut buf = Vec::new(); use super::Execute;
let m = Execute {
portal: Some(Oid(5)),
limit: 2,
};
m.encode(&mut buf); #[test]
fn test_encode_execute_named_portal() {
const EXPECTED: &[u8] = b"E\0\0\0\x1Asqlx_p_1234567890\0\0\0\0\x02";
assert_eq!(buf, EXPECTED); let mut buf = Vec::new();
let m = Execute {
portal: PortalId::TEST_VAL,
limit: 2,
};
m.encode_msg(&mut buf).unwrap();
assert_eq!(buf, EXPECTED);
}
#[test]
fn test_encode_execute_unnamed_portal() {
const EXPECTED: &[u8] = b"E\0\0\0\x09\0\x49\x96\x02\xD2";
let mut buf = Vec::new();
let m = Execute {
portal: PortalId::UNNAMED,
limit: 1234567890,
};
m.encode_msg(&mut buf).unwrap();
assert_eq!(buf, EXPECTED);
}
} }

View file

@ -1,17 +1,25 @@
use crate::io::Encode; use crate::message::{FrontendMessage, FrontendMessageFormat};
use sqlx_core::Error;
// The Flush message does not cause any specific output to be generated, use std::num::Saturating;
// but forces the backend to deliver any data pending in its output buffers.
// A Flush must be sent after any extended-query command except Sync, if the
// frontend wishes to examine the results of that command before issuing more commands.
/// The Flush message does not cause any specific output to be generated,
/// but forces the backend to deliver any data pending in its output buffers.
///
/// A Flush must be sent after any extended-query command except Sync, if the
/// frontend wishes to examine the results of that command before issuing more commands.
#[derive(Debug)] #[derive(Debug)]
pub struct Flush; pub struct Flush;
impl Encode<'_> for Flush { impl FrontendMessage for Flush {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Flush;
buf.push(b'H');
buf.extend(&4_i32.to_be_bytes()); #[inline(always)]
fn body_size_hint(&self) -> Saturating<usize> {
Saturating(0)
}
#[inline(always)]
fn encode_body(&self, _buf: &mut Vec<u8>) -> Result<(), Error> {
Ok(())
} }
} }

View file

@ -1,7 +1,8 @@
use sqlx_core::bytes::Bytes; use sqlx_core::bytes::Bytes;
use std::num::Saturating;
use crate::error::Error; use crate::error::Error;
use crate::io::Decode; use crate::io::PgBufMutExt;
mod authentication; mod authentication;
mod backend_key_data; mod backend_key_data;
@ -17,6 +18,7 @@ mod notification;
mod parameter_description; mod parameter_description;
mod parameter_status; mod parameter_status;
mod parse; mod parse;
mod parse_complete;
mod password; mod password;
mod query; mod query;
mod ready_for_query; mod ready_for_query;
@ -33,7 +35,7 @@ pub use backend_key_data::BackendKeyData;
pub use bind::Bind; pub use bind::Bind;
pub use close::Close; pub use close::Close;
pub use command_complete::CommandComplete; pub use command_complete::CommandComplete;
pub use copy::{CopyData, CopyDone, CopyFail, CopyResponse}; pub use copy::{CopyData, CopyDone, CopyFail, CopyInResponse, CopyOutResponse, CopyResponseData};
pub use data_row::DataRow; pub use data_row::DataRow;
pub use describe::Describe; pub use describe::Describe;
pub use execute::Execute; pub use execute::Execute;
@ -43,20 +45,51 @@ pub use notification::Notification;
pub use parameter_description::ParameterDescription; pub use parameter_description::ParameterDescription;
pub use parameter_status::ParameterStatus; pub use parameter_status::ParameterStatus;
pub use parse::Parse; pub use parse::Parse;
pub use parse_complete::ParseComplete;
pub use password::Password; pub use password::Password;
pub use query::Query; pub use query::Query;
pub use ready_for_query::{ReadyForQuery, TransactionStatus}; pub use ready_for_query::{ReadyForQuery, TransactionStatus};
pub use response::{Notice, PgSeverity}; pub use response::{Notice, PgSeverity};
pub use row_description::RowDescription; pub use row_description::RowDescription;
pub use sasl::{SaslInitialResponse, SaslResponse}; pub use sasl::{SaslInitialResponse, SaslResponse};
use sqlx_core::io::ProtocolEncode;
pub use ssl_request::SslRequest; pub use ssl_request::SslRequest;
pub use startup::Startup; pub use startup::Startup;
pub use sync::Sync; pub use sync::Sync;
pub use terminate::Terminate; pub use terminate::Terminate;
// Note: we can't use the same enum for both frontend and backend message formats
// because there are duplicated format codes between them.
//
// For example, `Close` (frontend) and `CommandComplete` (backend) both use format code `C`.
// <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Debug, PartialOrd, PartialEq)] #[derive(Debug, PartialOrd, PartialEq)]
#[repr(u8)] #[repr(u8)]
pub enum MessageFormat { pub enum FrontendMessageFormat {
Bind = b'B',
Close = b'C',
CopyData = b'd',
CopyDone = b'c',
CopyFail = b'f',
Describe = b'D',
Execute = b'E',
Flush = b'H',
Parse = b'P',
/// This message format is polymorphic. It's used for:
///
/// * Plain password responses
/// * MD5 password responses
/// * SASL responses
/// * GSSAPI/SSPI responses
PasswordPolymorphic = b'p',
Query = b'Q',
Sync = b'S',
Terminate = b'X',
}
#[derive(Debug, PartialOrd, PartialEq)]
#[repr(u8)]
pub enum BackendMessageFormat {
Authentication, Authentication,
BackendKeyData, BackendKeyData,
BindComplete, BindComplete,
@ -81,49 +114,116 @@ pub enum MessageFormat {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct Message { pub struct ReceivedMessage {
pub format: MessageFormat, pub format: BackendMessageFormat,
pub contents: Bytes, pub contents: Bytes,
} }
impl Message { impl ReceivedMessage {
#[inline] #[inline]
pub fn decode<'de, T>(self) -> Result<T, Error> pub fn decode<T>(self) -> Result<T, Error>
where where
T: Decode<'de>, T: BackendMessage,
{ {
T::decode(self.contents) if T::FORMAT != self.format {
return Err(err_protocol!(
"Postgres protocol error: expected {:?}, got {:?}",
T::FORMAT,
self.format
));
}
T::decode_body(self.contents).map_err(|e| match e {
Error::Protocol(s) => {
err_protocol!("Postgres protocol error (reading {:?}): {s}", self.format)
}
other => other,
})
} }
} }
impl MessageFormat { impl BackendMessageFormat {
pub fn try_from_u8(v: u8) -> Result<Self, Error> { pub fn try_from_u8(v: u8) -> Result<Self, Error> {
// https://www.postgresql.org/docs/current/protocol-message-formats.html // https://www.postgresql.org/docs/current/protocol-message-formats.html
Ok(match v { Ok(match v {
b'1' => MessageFormat::ParseComplete, b'1' => BackendMessageFormat::ParseComplete,
b'2' => MessageFormat::BindComplete, b'2' => BackendMessageFormat::BindComplete,
b'3' => MessageFormat::CloseComplete, b'3' => BackendMessageFormat::CloseComplete,
b'C' => MessageFormat::CommandComplete, b'C' => BackendMessageFormat::CommandComplete,
b'd' => MessageFormat::CopyData, b'd' => BackendMessageFormat::CopyData,
b'c' => MessageFormat::CopyDone, b'c' => BackendMessageFormat::CopyDone,
b'G' => MessageFormat::CopyInResponse, b'G' => BackendMessageFormat::CopyInResponse,
b'H' => MessageFormat::CopyOutResponse, b'H' => BackendMessageFormat::CopyOutResponse,
b'D' => MessageFormat::DataRow, b'D' => BackendMessageFormat::DataRow,
b'E' => MessageFormat::ErrorResponse, b'E' => BackendMessageFormat::ErrorResponse,
b'I' => MessageFormat::EmptyQueryResponse, b'I' => BackendMessageFormat::EmptyQueryResponse,
b'A' => MessageFormat::NotificationResponse, b'A' => BackendMessageFormat::NotificationResponse,
b'K' => MessageFormat::BackendKeyData, b'K' => BackendMessageFormat::BackendKeyData,
b'N' => MessageFormat::NoticeResponse, b'N' => BackendMessageFormat::NoticeResponse,
b'R' => MessageFormat::Authentication, b'R' => BackendMessageFormat::Authentication,
b'S' => MessageFormat::ParameterStatus, b'S' => BackendMessageFormat::ParameterStatus,
b'T' => MessageFormat::RowDescription, b'T' => BackendMessageFormat::RowDescription,
b'Z' => MessageFormat::ReadyForQuery, b'Z' => BackendMessageFormat::ReadyForQuery,
b'n' => MessageFormat::NoData, b'n' => BackendMessageFormat::NoData,
b's' => MessageFormat::PortalSuspended, b's' => BackendMessageFormat::PortalSuspended,
b't' => MessageFormat::ParameterDescription, b't' => BackendMessageFormat::ParameterDescription,
_ => return Err(err_protocol!("unknown message type: {:?}", v as char)), _ => return Err(err_protocol!("unknown message type: {:?}", v as char)),
}) })
} }
} }
pub(crate) trait FrontendMessage: Sized {
/// The format prefix of this message.
const FORMAT: FrontendMessageFormat;
/// Return the amount of space, in bytes, to reserve in the buffer passed to [`Self::encode_body()`].
fn body_size_hint(&self) -> Saturating<usize>;
/// Encode this type as a Frontend message in the Postgres protocol.
///
/// The implementation should *not* include `Self::FORMAT` or the length prefix.
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error>;
#[inline(always)]
#[cfg_attr(not(test), allow(dead_code))]
fn encode_msg(self, buf: &mut Vec<u8>) -> Result<(), Error> {
EncodeMessage(self).encode(buf)
}
}
pub(crate) trait BackendMessage: Sized {
/// The expected message format.
///
/// <https://www.postgresql.org/docs/current/protocol-message-formats.html>
const FORMAT: BackendMessageFormat;
/// Decode this type from a Backend message in the Postgres protocol.
///
/// The format code and length prefix have already been read and are not at the start of `bytes`.
fn decode_body(buf: Bytes) -> Result<Self, Error>;
}
pub struct EncodeMessage<F>(pub F);
impl<F: FrontendMessage> ProtocolEncode<'_, ()> for EncodeMessage<F> {
fn encode_with(&self, buf: &mut Vec<u8>, _context: ()) -> Result<(), Error> {
let mut size_hint = self.0.body_size_hint();
// plus format code and length prefix
size_hint += 5;
// don't panic if `size_hint` is ridiculous
buf.try_reserve(size_hint.0).map_err(|e| {
err_protocol!(
"Postgres protocol: error allocating {} bytes for encoding message {:?}: {e}",
size_hint.0,
F::FORMAT,
)
})?;
buf.push(F::FORMAT as u8);
buf.put_length_prefixed(|buf| self.0.encode_body(buf))
}
}

View file

@ -1,7 +1,8 @@
use sqlx_core::bytes::{Buf, Bytes}; use sqlx_core::bytes::{Buf, Bytes};
use crate::error::Error; use crate::error::Error;
use crate::io::{BufExt, Decode}; use crate::io::BufExt;
use crate::message::{BackendMessage, BackendMessageFormat};
#[derive(Debug)] #[derive(Debug)]
pub struct Notification { pub struct Notification {
@ -10,9 +11,10 @@ pub struct Notification {
pub(crate) payload: Bytes, pub(crate) payload: Bytes,
} }
impl Decode<'_> for Notification { impl BackendMessage for Notification {
#[inline] const FORMAT: BackendMessageFormat = BackendMessageFormat::NotificationResponse;
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
let process_id = buf.get_u32(); let process_id = buf.get_u32();
let channel = buf.get_bytes_nul()?; let channel = buf.get_bytes_nul()?;
let payload = buf.get_bytes_nul()?; let payload = buf.get_bytes_nul()?;
@ -29,7 +31,7 @@ impl Decode<'_> for Notification {
fn test_decode_notification_response() { fn test_decode_notification_response() {
const NOTIFICATION_RESPONSE: &[u8] = b"\x34\x20\x10\x02TEST-CHANNEL\0THIS IS A TEST\0"; const NOTIFICATION_RESPONSE: &[u8] = b"\x34\x20\x10\x02TEST-CHANNEL\0THIS IS A TEST\0";
let message = Notification::decode(Bytes::from(NOTIFICATION_RESPONSE)).unwrap(); let message = Notification::decode_body(Bytes::from(NOTIFICATION_RESPONSE)).unwrap();
assert_eq!(message.process_id, 0x34201002); assert_eq!(message.process_id, 0x34201002);
assert_eq!(&*message.channel, &b"TEST-CHANNEL"[..]); assert_eq!(&*message.channel, &b"TEST-CHANNEL"[..]);

View file

@ -2,7 +2,7 @@ use smallvec::SmallVec;
use sqlx_core::bytes::{Buf, Bytes}; use sqlx_core::bytes::{Buf, Bytes};
use crate::error::Error; use crate::error::Error;
use crate::io::Decode; use crate::message::{BackendMessage, BackendMessageFormat};
use crate::types::Oid; use crate::types::Oid;
#[derive(Debug)] #[derive(Debug)]
@ -10,8 +10,10 @@ pub struct ParameterDescription {
pub types: SmallVec<[Oid; 6]>, pub types: SmallVec<[Oid; 6]>,
} }
impl Decode<'_> for ParameterDescription { impl BackendMessage for ParameterDescription {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> { const FORMAT: BackendMessageFormat = BackendMessageFormat::ParameterDescription;
fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
let cnt = buf.get_u16(); let cnt = buf.get_u16();
let mut types = SmallVec::with_capacity(cnt as usize); let mut types = SmallVec::with_capacity(cnt as usize);
@ -27,7 +29,7 @@ impl Decode<'_> for ParameterDescription {
fn test_decode_parameter_description() { fn test_decode_parameter_description() {
const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"; const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00";
let m = ParameterDescription::decode(DATA.into()).unwrap(); let m = ParameterDescription::decode_body(DATA.into()).unwrap();
assert_eq!(m.types.len(), 2); assert_eq!(m.types.len(), 2);
assert_eq!(m.types[0], Oid(0x0000_0000)); assert_eq!(m.types[0], Oid(0x0000_0000));
@ -38,7 +40,7 @@ fn test_decode_parameter_description() {
fn test_decode_empty_parameter_description() { fn test_decode_empty_parameter_description() {
const DATA: &[u8] = b"\x00\x00"; const DATA: &[u8] = b"\x00\x00";
let m = ParameterDescription::decode(DATA.into()).unwrap(); let m = ParameterDescription::decode_body(DATA.into()).unwrap();
assert!(m.types.is_empty()); assert!(m.types.is_empty());
} }
@ -49,6 +51,6 @@ fn bench_decode_parameter_description(b: &mut test::Bencher) {
const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"; const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00";
b.iter(|| { b.iter(|| {
ParameterDescription::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); ParameterDescription::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap();
}); });
} }

View file

@ -1,7 +1,8 @@
use sqlx_core::bytes::Bytes; use sqlx_core::bytes::Bytes;
use crate::error::Error; use crate::error::Error;
use crate::io::{BufExt, Decode}; use crate::io::BufExt;
use crate::message::{BackendMessage, BackendMessageFormat};
#[derive(Debug)] #[derive(Debug)]
pub struct ParameterStatus { pub struct ParameterStatus {
@ -9,8 +10,10 @@ pub struct ParameterStatus {
pub value: String, pub value: String,
} }
impl Decode<'_> for ParameterStatus { impl BackendMessage for ParameterStatus {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> { const FORMAT: BackendMessageFormat = BackendMessageFormat::ParameterStatus;
fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
let name = buf.get_str_nul()?; let name = buf.get_str_nul()?;
let value = buf.get_str_nul()?; let value = buf.get_str_nul()?;
@ -22,7 +25,7 @@ impl Decode<'_> for ParameterStatus {
fn test_decode_parameter_status() { fn test_decode_parameter_status() {
const DATA: &[u8] = b"client_encoding\x00UTF8\x00"; const DATA: &[u8] = b"client_encoding\x00UTF8\x00";
let m = ParameterStatus::decode(DATA.into()).unwrap(); let m = ParameterStatus::decode_body(DATA.into()).unwrap();
assert_eq!(&m.name, "client_encoding"); assert_eq!(&m.name, "client_encoding");
assert_eq!(&m.value, "UTF8") assert_eq!(&m.value, "UTF8")
@ -32,7 +35,7 @@ fn test_decode_parameter_status() {
fn test_decode_empty_parameter_status() { fn test_decode_empty_parameter_status() {
const DATA: &[u8] = b"\x00\x00"; const DATA: &[u8] = b"\x00\x00";
let m = ParameterStatus::decode(DATA.into()).unwrap(); let m = ParameterStatus::decode_body(DATA.into()).unwrap();
assert!(m.name.is_empty()); assert!(m.name.is_empty());
assert!(m.value.is_empty()); assert!(m.value.is_empty());
@ -44,7 +47,7 @@ fn bench_decode_parameter_status(b: &mut test::Bencher) {
const DATA: &[u8] = b"client_encoding\x00UTF8\x00"; const DATA: &[u8] = b"client_encoding\x00UTF8\x00";
b.iter(|| { b.iter(|| {
ParameterStatus::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); ParameterStatus::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap();
}); });
} }
@ -52,7 +55,7 @@ fn bench_decode_parameter_status(b: &mut test::Bencher) {
fn test_decode_parameter_status_response() { fn test_decode_parameter_status_response() {
const PARAMETER_STATUS_RESPONSE: &[u8] = b"crdb_version\0CockroachDB CCL v21.1.0 (x86_64-unknown-linux-gnu, built 2021/05/17 13:49:40, go1.15.11)\0"; const PARAMETER_STATUS_RESPONSE: &[u8] = b"crdb_version\0CockroachDB CCL v21.1.0 (x86_64-unknown-linux-gnu, built 2021/05/17 13:49:40, go1.15.11)\0";
let message = ParameterStatus::decode(Bytes::from(PARAMETER_STATUS_RESPONSE)).unwrap(); let message = ParameterStatus::decode_body(Bytes::from(PARAMETER_STATUS_RESPONSE)).unwrap();
assert_eq!(message.name, "crdb_version"); assert_eq!(message.name, "crdb_version");
assert_eq!( assert_eq!(

View file

@ -1,11 +1,14 @@
use crate::io::PgBufMutExt; use crate::io::BufMutExt;
use crate::io::{BufMutExt, Encode}; use crate::io::{PgBufMutExt, StatementId};
use crate::message::{FrontendMessage, FrontendMessageFormat};
use crate::types::Oid; use crate::types::Oid;
use sqlx_core::Error;
use std::num::Saturating;
#[derive(Debug)] #[derive(Debug)]
pub struct Parse<'a> { pub struct Parse<'a> {
/// The ID of the destination prepared statement. /// The ID of the destination prepared statement.
pub statement: Oid, pub statement: StatementId,
/// The query string to be parsed. /// The query string to be parsed.
pub query: &'a str, pub query: &'a str,
@ -16,39 +19,59 @@ pub struct Parse<'a> {
pub param_types: &'a [Oid], pub param_types: &'a [Oid],
} }
impl Encode<'_> for Parse<'_> { impl FrontendMessage for Parse<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Parse;
buf.push(b'P');
buf.put_length_prefixed(|buf| { fn body_size_hint(&self) -> Saturating<usize> {
buf.put_statement_name(self.statement); let mut size = Saturating(0);
buf.put_str_nul(self.query); size += self.statement.name_len();
// TODO: Return an error here instead size += self.query.len();
assert!(self.param_types.len() <= (u16::MAX as usize)); size += 1; // NUL terminator
buf.extend(&(self.param_types.len() as i16).to_be_bytes()); size += 2; // param_types_len
for &oid in self.param_types { // `param_types`
buf.extend(&oid.0.to_be_bytes()); size += self.param_types.len().saturating_mul(4);
}
}) size
}
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
buf.put_statement_name(self.statement);
buf.put_str_nul(self.query);
let param_types_len = i16::try_from(self.param_types.len()).map_err(|_| {
err_protocol!(
"param_types.len() too large for binary protocol: {}",
self.param_types.len()
)
})?;
buf.extend(param_types_len.to_be_bytes());
for &oid in self.param_types {
buf.extend(oid.0.to_be_bytes());
}
Ok(())
} }
} }
#[test] #[test]
fn test_encode_parse() { fn test_encode_parse() {
const EXPECTED: &[u8] = b"P\0\0\0\x1dsqlx_s_1\0SELECT $1\0\0\x01\0\0\0\x19"; const EXPECTED: &[u8] = b"P\0\0\0\x26sqlx_s_1234567890\0SELECT $1\0\0\x01\0\0\0\x19";
let mut buf = Vec::new(); let mut buf = Vec::new();
let m = Parse { let m = Parse {
statement: Oid(1), statement: StatementId::TEST_VAL,
query: "SELECT $1", query: "SELECT $1",
param_types: &[Oid(25)], param_types: &[Oid(25)],
}; };
m.encode(&mut buf); m.encode_msg(&mut buf).unwrap();
assert_eq!(buf, EXPECTED); assert_eq!(buf, EXPECTED);
} }

View file

@ -0,0 +1,13 @@
use crate::message::{BackendMessage, BackendMessageFormat};
use sqlx_core::bytes::Bytes;
use sqlx_core::Error;
pub struct ParseComplete;
impl BackendMessage for ParseComplete {
const FORMAT: BackendMessageFormat = BackendMessageFormat::ParseComplete;
fn decode_body(_bytes: Bytes) -> Result<Self, Error> {
Ok(ParseComplete)
}
}

View file

@ -1,9 +1,9 @@
use std::fmt::Write; use crate::io::BufMutExt;
use crate::message::{FrontendMessage, FrontendMessageFormat};
use md5::{Digest, Md5}; use md5::{Digest, Md5};
use sqlx_core::Error;
use crate::io::PgBufMutExt; use std::fmt::Write;
use crate::io::{BufMutExt, Encode}; use std::num::Saturating;
#[derive(Debug)] #[derive(Debug)]
pub enum Password<'a> { pub enum Password<'a> {
@ -16,117 +16,138 @@ pub enum Password<'a> {
}, },
} }
impl Password<'_> { impl FrontendMessage for Password<'_> {
#[inline] const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic;
fn len(&self) -> usize {
#[inline(always)]
fn body_size_hint(&self) -> Saturating<usize> {
let mut size = Saturating(0);
match self { match self {
Password::Cleartext(s) => s.len() + 5, Password::Cleartext(password) => {
Password::Md5 { .. } => 35 + 5, // To avoid reporting the exact password length anywhere,
} // we deliberately give a bad estimate.
} //
} // This shouldn't affect performance in the long run.
size += password
impl Encode<'_> for Password<'_> { .len()
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { .saturating_add(1) // NUL terminator
buf.reserve(1 + 4 + self.len()); .checked_next_power_of_two()
buf.push(b'p'); .unwrap_or(usize::MAX);
buf.put_length_prefixed(|buf| {
match self {
Password::Cleartext(password) => {
buf.put_str_nul(password);
}
Password::Md5 {
username,
password,
salt,
} => {
// The actual `PasswordMessage` can be computed in SQL as
// `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`.
// Keep in mind the md5() function returns its result as a hex string.
let mut hasher = Md5::new();
hasher.update(password);
hasher.update(username);
let mut output = String::with_capacity(35);
let _ = write!(output, "{:x}", hasher.finalize_reset());
hasher.update(&output);
hasher.update(salt);
output.clear();
let _ = write!(output, "md5{:x}", hasher.finalize());
buf.put_str_nul(&output);
}
} }
}); Password::Md5 { .. } => {
// "md5<32 hex chars>\0"
size += 36;
}
}
size
}
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
match self {
Password::Cleartext(password) => {
buf.put_str_nul(password);
}
Password::Md5 {
username,
password,
salt,
} => {
// The actual `PasswordMessage` can be computed in SQL as
// `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`.
// Keep in mind the md5() function returns its result as a hex string.
let mut hasher = Md5::new();
hasher.update(password);
hasher.update(username);
let mut output = String::with_capacity(35);
let _ = write!(output, "{:x}", hasher.finalize_reset());
hasher.update(&output);
hasher.update(salt);
output.clear();
let _ = write!(output, "md5{:x}", hasher.finalize());
buf.put_str_nul(&output);
}
}
Ok(())
} }
} }
#[test] #[cfg(test)]
fn test_encode_clear_password() { mod tests {
const EXPECTED: &[u8] = b"p\0\0\0\rpassword\0"; use crate::message::FrontendMessage;
let mut buf = Vec::new(); use super::Password;
let m = Password::Cleartext("password");
m.encode(&mut buf); #[test]
fn test_encode_clear_password() {
const EXPECTED: &[u8] = b"p\0\0\0\rpassword\0";
assert_eq!(buf, EXPECTED); let mut buf = Vec::new();
} let m = Password::Cleartext("password");
#[test] m.encode_msg(&mut buf).unwrap();
fn test_encode_md5_password() {
const EXPECTED: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0";
let mut buf = Vec::new(); assert_eq!(buf, EXPECTED);
let m = Password::Md5 { }
password: "password",
username: "root",
salt: [147, 24, 57, 152],
};
m.encode(&mut buf); #[test]
fn test_encode_md5_password() {
const EXPECTED: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0";
assert_eq!(buf, EXPECTED); let mut buf = Vec::new();
} let m = Password::Md5 {
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_encode_clear_password(b: &mut test::Bencher) {
use test::black_box;
let mut buf = Vec::with_capacity(128);
b.iter(|| {
buf.clear();
black_box(Password::Cleartext("password")).encode(&mut buf);
});
}
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_encode_md5_password(b: &mut test::Bencher) {
use test::black_box;
let mut buf = Vec::with_capacity(128);
b.iter(|| {
buf.clear();
black_box(Password::Md5 {
password: "password", password: "password",
username: "root", username: "root",
salt: [147, 24, 57, 152], salt: [147, 24, 57, 152],
}) };
.encode(&mut buf);
}); m.encode_msg(&mut buf).unwrap();
assert_eq!(buf, EXPECTED);
}
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_encode_clear_password(b: &mut test::Bencher) {
use test::black_box;
let mut buf = Vec::with_capacity(128);
b.iter(|| {
buf.clear();
black_box(Password::Cleartext("password")).encode_msg(&mut buf);
});
}
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_encode_md5_password(b: &mut test::Bencher) {
use test::black_box;
let mut buf = Vec::with_capacity(128);
b.iter(|| {
buf.clear();
black_box(Password::Md5 {
password: "password",
username: "root",
salt: [147, 24, 57, 152],
})
.encode_msg(&mut buf);
});
}
} }

View file

@ -1,27 +1,37 @@
use crate::io::{BufMutExt, Encode}; use crate::io::BufMutExt;
use crate::message::{FrontendMessage, FrontendMessageFormat};
use sqlx_core::Error;
use std::num::Saturating;
#[derive(Debug)] #[derive(Debug)]
pub struct Query<'a>(pub &'a str); pub struct Query<'a>(pub &'a str);
impl Encode<'_> for Query<'_> { impl FrontendMessage for Query<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Query;
let len = 4 + self.0.len() + 1;
buf.reserve(len + 1); fn body_size_hint(&self) -> Saturating<usize> {
buf.push(b'Q'); let mut size = Saturating(0);
buf.extend(&(len as i32).to_be_bytes());
size += self.0.len();
size += 1; // NUL terminator
size
}
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
buf.put_str_nul(self.0); buf.put_str_nul(self.0);
Ok(())
} }
} }
#[test] #[test]
fn test_encode_query() { fn test_encode_query() {
const EXPECTED: &[u8] = b"Q\0\0\0\rSELECT 1\0"; const EXPECTED: &[u8] = b"Q\0\0\0\x0DSELECT 1\0";
let mut buf = Vec::new(); let mut buf = Vec::new();
let m = Query("SELECT 1"); let m = Query("SELECT 1");
m.encode(&mut buf); m.encode_msg(&mut buf).unwrap();
assert_eq!(buf, EXPECTED); assert_eq!(buf, EXPECTED);
} }

View file

@ -1,7 +1,7 @@
use sqlx_core::bytes::Bytes; use sqlx_core::bytes::Bytes;
use crate::error::Error; use crate::error::Error;
use crate::io::Decode; use crate::message::{BackendMessage, BackendMessageFormat};
#[derive(Debug)] #[derive(Debug)]
#[repr(u8)] #[repr(u8)]
@ -21,8 +21,10 @@ pub struct ReadyForQuery {
pub transaction_status: TransactionStatus, pub transaction_status: TransactionStatus,
} }
impl Decode<'_> for ReadyForQuery { impl BackendMessage for ReadyForQuery {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> { const FORMAT: BackendMessageFormat = BackendMessageFormat::ReadyForQuery;
fn decode_body(buf: Bytes) -> Result<Self, Error> {
let status = match buf[0] { let status = match buf[0] {
b'I' => TransactionStatus::Idle, b'I' => TransactionStatus::Idle,
b'T' => TransactionStatus::Transaction, b'T' => TransactionStatus::Transaction,
@ -46,7 +48,7 @@ impl Decode<'_> for ReadyForQuery {
fn test_decode_ready_for_query() -> Result<(), Error> { fn test_decode_ready_for_query() -> Result<(), Error> {
const DATA: &[u8] = b"E"; const DATA: &[u8] = b"E";
let m = ReadyForQuery::decode(Bytes::from_static(DATA))?; let m = ReadyForQuery::decode_body(Bytes::from_static(DATA))?;
assert!(matches!(m.transaction_status, TransactionStatus::Error)); assert!(matches!(m.transaction_status, TransactionStatus::Error));

View file

@ -1,10 +1,13 @@
use std::ops::Range;
use std::str::from_utf8; use std::str::from_utf8;
use memchr::memchr; use memchr::memchr;
use sqlx_core::bytes::Bytes; use sqlx_core::bytes::Bytes;
use crate::error::Error; use crate::error::Error;
use crate::io::Decode; use crate::io::ProtocolDecode;
use crate::message::{BackendMessage, BackendMessageFormat};
#[derive(Debug, Copy, Clone, Eq, PartialEq)] #[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(u8)] #[repr(u8)]
@ -53,8 +56,8 @@ impl TryFrom<&str> for PgSeverity {
pub struct Notice { pub struct Notice {
storage: Bytes, storage: Bytes,
severity: PgSeverity, severity: PgSeverity,
message: (u16, u16), message: Range<usize>,
code: (u16, u16), code: Range<usize>,
} }
impl Notice { impl Notice {
@ -65,12 +68,12 @@ impl Notice {
#[inline] #[inline]
pub fn code(&self) -> &str { pub fn code(&self) -> &str {
self.get_cached_str(self.code) self.get_cached_str(self.code.clone())
} }
#[inline] #[inline]
pub fn message(&self) -> &str { pub fn message(&self) -> &str {
self.get_cached_str(self.message) self.get_cached_str(self.message.clone())
} }
// Field descriptions available here: // Field descriptions available here:
@ -84,7 +87,7 @@ impl Notice {
pub fn get_raw(&self, ty: u8) -> Option<&[u8]> { pub fn get_raw(&self, ty: u8) -> Option<&[u8]> {
self.fields() self.fields()
.filter(|(field, _)| *field == ty) .filter(|(field, _)| *field == ty)
.map(|(_, (start, end))| &self.storage[start as usize..end as usize]) .map(|(_, range)| &self.storage[range])
.next() .next()
} }
} }
@ -99,13 +102,13 @@ impl Notice {
} }
#[inline] #[inline]
fn get_cached_str(&self, cache: (u16, u16)) -> &str { fn get_cached_str(&self, cache: Range<usize>) -> &str {
// unwrap: this cannot fail at this stage // unwrap: this cannot fail at this stage
from_utf8(&self.storage[cache.0 as usize..cache.1 as usize]).unwrap() from_utf8(&self.storage[cache]).unwrap()
} }
} }
impl Decode<'_> for Notice { impl ProtocolDecode<'_> for Notice {
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> { fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
// In order to support PostgreSQL 9.5 and older we need to parse the localized S field. // In order to support PostgreSQL 9.5 and older we need to parse the localized S field.
// Newer versions additionally come with the V field that is guaranteed to be in English. // Newer versions additionally come with the V field that is guaranteed to be in English.
@ -113,8 +116,8 @@ impl Decode<'_> for Notice {
const DEFAULT_SEVERITY: PgSeverity = PgSeverity::Log; const DEFAULT_SEVERITY: PgSeverity = PgSeverity::Log;
let mut severity_v = None; let mut severity_v = None;
let mut severity_s = None; let mut severity_s = None;
let mut message = (0, 0); let mut message = 0..0;
let mut code = (0, 0); let mut code = 0..0;
// we cache the three always present fields // we cache the three always present fields
// this enables to keep the access time down for the fields most likely accessed // this enables to keep the access time down for the fields most likely accessed
@ -125,7 +128,7 @@ impl Decode<'_> for Notice {
}; };
for (field, v) in fields { for (field, v) in fields {
if message.0 != 0 && code.0 != 0 { if !(message.is_empty() || code.is_empty()) {
// stop iterating when we have the 3 fields we were looking for // stop iterating when we have the 3 fields we were looking for
// we assume V (severity) was the first field as it should be // we assume V (severity) was the first field as it should be
break; break;
@ -133,7 +136,7 @@ impl Decode<'_> for Notice {
match field { match field {
b'S' => { b'S' => {
severity_s = from_utf8(&buf[v.0 as usize..v.1 as usize]) severity_s = from_utf8(&buf[v.clone()])
// If the error string is not UTF-8, we have no hope of interpreting it, // If the error string is not UTF-8, we have no hope of interpreting it,
// localized or not. The `V` field would likely fail to parse as well. // localized or not. The `V` field would likely fail to parse as well.
.map_err(|_| notice_protocol_err())? .map_err(|_| notice_protocol_err())?
@ -146,21 +149,19 @@ impl Decode<'_> for Notice {
// Propagate errors here, because V is not localized and // Propagate errors here, because V is not localized and
// thus we are missing a possible variant. // thus we are missing a possible variant.
severity_v = Some( severity_v = Some(
from_utf8(&buf[v.0 as usize..v.1 as usize]) from_utf8(&buf[v.clone()])
.map_err(|_| notice_protocol_err())? .map_err(|_| notice_protocol_err())?
.try_into()?, .try_into()?,
); );
} }
b'M' => { b'M' => {
_ = from_utf8(&buf[v.0 as usize..v.1 as usize]) _ = from_utf8(&buf[v.clone()]).map_err(|_| notice_protocol_err())?;
.map_err(|_| notice_protocol_err())?;
message = v; message = v;
} }
b'C' => { b'C' => {
_ = from_utf8(&buf[v.0 as usize..v.1 as usize]) _ = from_utf8(&buf[v.clone()]).map_err(|_| notice_protocol_err())?;
.map_err(|_| notice_protocol_err())?;
code = v; code = v;
} }
@ -179,31 +180,46 @@ impl Decode<'_> for Notice {
} }
} }
impl BackendMessage for Notice {
const FORMAT: BackendMessageFormat = BackendMessageFormat::NoticeResponse;
fn decode_body(buf: Bytes) -> Result<Self, Error> {
// Keeping both impls for now
Self::decode_with(buf, ())
}
}
/// An iterator over each field in the Error (or Notice) response. /// An iterator over each field in the Error (or Notice) response.
struct Fields<'a> { struct Fields<'a> {
storage: &'a [u8], storage: &'a [u8],
offset: u16, offset: usize,
} }
impl<'a> Iterator for Fields<'a> { impl<'a> Iterator for Fields<'a> {
type Item = (u8, (u16, u16)); type Item = (u8, Range<usize>);
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
// The fields in the response body are sequentially stored as [tag][string], // The fields in the response body are sequentially stored as [tag][string],
// ending in a final, additional [nul] // ending in a final, additional [nul]
let ty = self.storage[self.offset as usize]; let ty = *self.storage.get(self.offset)?;
if ty == 0 { if ty == 0 {
return None; return None;
} }
let nul = memchr(b'\0', &self.storage[(self.offset + 1) as usize..])? as u16; // Consume the type byte
let offset = self.offset; self.offset = self.offset.checked_add(1)?;
self.offset += nul + 2; let start = self.offset;
Some((ty, (offset + 1, offset + nul + 1))) let len = memchr(b'\0', self.storage.get(start..)?)?;
// Neither can overflow as they will always be `<= self.storage.len()`.
let end = self.offset + len;
self.offset = end + 1;
Some((ty, start..end))
} }
} }

View file

@ -1,7 +1,8 @@
use sqlx_core::bytes::{Buf, Bytes}; use sqlx_core::bytes::{Buf, Bytes};
use crate::error::Error; use crate::error::Error;
use crate::io::{BufExt, Decode}; use crate::io::BufExt;
use crate::message::{BackendMessage, BackendMessageFormat};
use crate::types::Oid; use crate::types::Oid;
#[derive(Debug)] #[derive(Debug)]
@ -40,13 +41,30 @@ pub struct Field {
pub format: i16, pub format: i16,
} }
impl Decode<'_> for RowDescription { impl BackendMessage for RowDescription {
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> { const FORMAT: BackendMessageFormat = BackendMessageFormat::RowDescription;
fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
if buf.len() < 2 {
return Err(err_protocol!(
"expected at least 2 bytes, got {}",
buf.len()
));
}
let cnt = buf.get_u16(); let cnt = buf.get_u16();
let mut fields = Vec::with_capacity(cnt as usize); let mut fields = Vec::with_capacity(cnt as usize);
for _ in 0..cnt { for _ in 0..cnt {
let name = buf.get_str_nul()?.to_owned(); let name = buf.get_str_nul()?.to_owned();
if buf.len() < 18 {
return Err(err_protocol!(
"expected at least 18 bytes after field name {name:?}, got {}",
buf.len()
));
}
let relation_id = buf.get_i32(); let relation_id = buf.get_i32();
let relation_attribute_no = buf.get_i16(); let relation_attribute_no = buf.get_i16();
let data_type_id = Oid(buf.get_u32()); let data_type_id = Oid(buf.get_u32());

View file

@ -1,35 +1,69 @@
use crate::io::PgBufMutExt; use crate::io::BufMutExt;
use crate::io::{BufMutExt, Encode}; use crate::message::{FrontendMessage, FrontendMessageFormat};
use sqlx_core::Error;
use std::num::Saturating;
pub struct SaslInitialResponse<'a> { pub struct SaslInitialResponse<'a> {
pub response: &'a str, pub response: &'a str,
pub plus: bool, pub plus: bool,
} }
impl Encode<'_> for SaslInitialResponse<'_> { impl SaslInitialResponse<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { #[inline(always)]
buf.push(b'p'); fn selected_mechanism(&self) -> &'static str {
buf.put_length_prefixed(|buf| { if self.plus {
// name of the SASL authentication mechanism that the client selected "SCRAM-SHA-256-PLUS"
buf.put_str_nul(if self.plus { } else {
"SCRAM-SHA-256-PLUS" "SCRAM-SHA-256"
} else { }
"SCRAM-SHA-256" }
}); }
buf.extend(&(self.response.as_bytes().len() as i32).to_be_bytes()); impl FrontendMessage for SaslInitialResponse<'_> {
buf.extend(self.response.as_bytes()); const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic;
});
#[inline(always)]
fn body_size_hint(&self) -> Saturating<usize> {
let mut size = Saturating(0);
size += self.selected_mechanism().len();
size += 1; // NUL terminator
size += 4; // response_len
size += self.response.len();
size
}
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
// name of the SASL authentication mechanism that the client selected
buf.put_str_nul(self.selected_mechanism());
let response_len = i32::try_from(self.response.len()).map_err(|_| {
err_protocol!(
"SASL Initial Response length too long for protcol: {}",
self.response.len()
)
})?;
buf.extend_from_slice(&response_len.to_be_bytes());
buf.extend_from_slice(self.response.as_bytes());
Ok(())
} }
} }
pub struct SaslResponse<'a>(pub &'a str); pub struct SaslResponse<'a>(pub &'a str);
impl Encode<'_> for SaslResponse<'_> { impl FrontendMessage for SaslResponse<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic;
buf.push(b'p');
buf.put_length_prefixed(|buf| { fn body_size_hint(&self) -> Saturating<usize> {
buf.extend(self.0.as_bytes()); Saturating(self.0.len())
}); }
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
buf.extend(self.0.as_bytes());
Ok(())
} }
} }

View file

@ -1,23 +1,38 @@
use crate::io::Encode; use crate::io::ProtocolEncode;
pub struct SslRequest; pub struct SslRequest;
impl SslRequest { impl SslRequest {
pub const BYTES: &'static [u8] = b"\x00\x00\x00\x08\x04\xd2\x16/"; // https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-SSLREQUEST
pub const BYTES: &'static [u8] = b"\x00\x00\x00\x08\x04\xd2\x16\x2f";
} }
impl Encode<'_> for SslRequest { // Cannot impl FrontendMessage because it does not have a format code
#[inline] impl ProtocolEncode<'_> for SslRequest {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { #[inline(always)]
buf.extend(&8_u32.to_be_bytes()); fn encode_with(&self, buf: &mut Vec<u8>, _context: ()) -> Result<(), crate::Error> {
buf.extend(&(((1234 << 16) | 5679) as u32).to_be_bytes()); buf.extend_from_slice(Self::BYTES);
Ok(())
} }
} }
#[test] #[test]
fn test_encode_ssl_request() { fn test_encode_ssl_request() {
let mut buf = Vec::new(); let mut buf = Vec::new();
SslRequest.encode(&mut buf);
// Int32(8)
// Length of message contents in bytes, including self.
buf.extend_from_slice(&8_u32.to_be_bytes());
// Int32(80877103)
// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits,
// and 5679 in the least significant 16 bits.
// (To avoid confusion, this code must not be the same as any protocol version number.)
buf.extend_from_slice(&(((1234 << 16) | 5679) as u32).to_be_bytes());
let mut encoded = Vec::new();
SslRequest.encode(&mut encoded).unwrap();
assert_eq!(buf, SslRequest::BYTES); assert_eq!(buf, SslRequest::BYTES);
assert_eq!(buf, encoded);
} }

View file

@ -1,5 +1,5 @@
use crate::io::PgBufMutExt; use crate::io::PgBufMutExt;
use crate::io::{BufMutExt, Encode}; use crate::io::{BufMutExt, ProtocolEncode};
// To begin a session, a frontend opens a connection to the server and sends a startup message. // To begin a session, a frontend opens a connection to the server and sends a startup message.
// This message includes the names of the user and of the database the user wants to connect to; // This message includes the names of the user and of the database the user wants to connect to;
@ -19,8 +19,9 @@ pub struct Startup<'a> {
pub params: &'a [(&'a str, &'a str)], pub params: &'a [(&'a str, &'a str)],
} }
impl Encode<'_> for Startup<'_> { // Startup cannot impl FrontendMessage because it doesn't have a format code.
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { impl ProtocolEncode<'_> for Startup<'_> {
fn encode_with(&self, buf: &mut Vec<u8>, _context: ()) -> Result<(), crate::Error> {
buf.reserve(120); buf.reserve(120);
buf.put_length_prefixed(|buf| { buf.put_length_prefixed(|buf| {
@ -47,7 +48,9 @@ impl Encode<'_> for Startup<'_> {
// A zero byte is required as a terminator // A zero byte is required as a terminator
// after the last name/value pair. // after the last name/value pair.
buf.push(0); buf.push(0);
});
Ok(())
})
} }
} }
@ -68,7 +71,7 @@ fn test_encode_startup() {
params: &[], params: &[],
}; };
m.encode(&mut buf); m.encode(&mut buf).unwrap();
assert_eq!(buf, EXPECTED); assert_eq!(buf, EXPECTED);
} }

View file

@ -1,11 +1,20 @@
use crate::io::Encode; use crate::message::{FrontendMessage, FrontendMessageFormat};
use sqlx_core::Error;
use std::num::Saturating;
#[derive(Debug)] #[derive(Debug)]
pub struct Sync; pub struct Sync;
impl Encode<'_> for Sync { impl FrontendMessage for Sync {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Sync;
buf.push(b'S');
buf.extend(&4_i32.to_be_bytes()); #[inline(always)]
fn body_size_hint(&self) -> Saturating<usize> {
Saturating(0)
}
#[inline(always)]
fn encode_body(&self, _buf: &mut Vec<u8>) -> Result<(), Error> {
Ok(())
} }
} }

View file

@ -1,10 +1,19 @@
use crate::io::Encode; use crate::message::{FrontendMessage, FrontendMessageFormat};
use sqlx_core::Error;
use std::num::Saturating;
pub struct Terminate; pub struct Terminate;
impl Encode<'_> for Terminate { impl FrontendMessage for Terminate {
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) { const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Terminate;
buf.push(b'X');
buf.extend(&4_u32.to_be_bytes()); #[inline(always)]
fn body_size_hint(&self) -> Saturating<usize> {
Saturating(0)
}
#[inline(always)]
fn encode_body(&self, _buf: &mut Vec<u8>) -> Result<(), Error> {
Ok(())
} }
} }

View file

@ -17,7 +17,7 @@ impl TransactionManager for PgTransactionManager {
Box::pin(async move { Box::pin(async move {
let rollback = Rollback::new(conn); let rollback = Rollback::new(conn);
let query = begin_ansi_transaction_sql(rollback.conn.transaction_depth); let query = begin_ansi_transaction_sql(rollback.conn.transaction_depth);
rollback.conn.queue_simple_query(&query); rollback.conn.queue_simple_query(&query)?;
rollback.conn.transaction_depth += 1; rollback.conn.transaction_depth += 1;
rollback.conn.wait_until_ready().await?; rollback.conn.wait_until_ready().await?;
rollback.defuse(); rollback.defuse();
@ -54,7 +54,8 @@ impl TransactionManager for PgTransactionManager {
fn start_rollback(conn: &mut PgConnection) { fn start_rollback(conn: &mut PgConnection) {
if conn.transaction_depth > 0 { if conn.transaction_depth > 0 {
conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.transaction_depth)); conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.transaction_depth))
.expect("BUG: Rollback query somehow too large for protocol");
conn.transaction_depth -= 1; conn.transaction_depth -= 1;
} }

View file

@ -17,12 +17,6 @@ pub struct Oid(
pub u32, pub u32,
); );
impl Oid {
pub(crate) fn incr_one(&mut self) {
self.0 = self.0.wrapping_add(1);
}
}
impl Type<Postgres> for Oid { impl Type<Postgres> for Oid {
fn type_info() -> PgTypeInfo { fn type_info() -> PgTypeInfo {
PgTypeInfo::OID PgTypeInfo::OID