mirror of
https://github.com/launchbadge/sqlx
synced 2024-11-10 06:24:16 +00:00
refactor(postgres): make better use of traits to improve protocol handling
This commit is contained in:
parent
9b3808b2d5
commit
53766e4659
40 changed files with 1252 additions and 693 deletions
|
@ -414,7 +414,8 @@ impl<'lock, C: AsMut<PgConnection>> Drop for PgAdvisoryLockGuard<'lock, C> {
|
|||
// The `async fn` versions can safely use the prepared statement protocol,
|
||||
// but this is the safest way to queue a query to execute on the next opportunity.
|
||||
conn.as_mut()
|
||||
.queue_simple_query(self.lock.get_release_query());
|
||||
.queue_simple_query(self.lock.get_release_query())
|
||||
.expect("BUG: PgAdvisoryLock::get_release_query() somehow too long for protocol");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -145,6 +145,7 @@ impl<'q> Arguments<'q> for PgArguments {
|
|||
write!(writer, "${}", self.buffer.count)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn len(&self) -> usize {
|
||||
self.buffer.count
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use crate::error::Error;
|
||||
use crate::ext::ustr::UStr;
|
||||
use crate::io::StatementId;
|
||||
use crate::message::{ParameterDescription, RowDescription};
|
||||
use crate::query_as::query_as;
|
||||
use crate::query_scalar::query_scalar;
|
||||
|
@ -27,10 +28,12 @@ enum TypType {
|
|||
Range,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for TypType {
|
||||
impl TryFrom<i8> for TypType {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(t: u8) -> Result<Self, Self::Error> {
|
||||
fn try_from(t: i8) -> Result<Self, Self::Error> {
|
||||
let t = u8::try_from(t).or(Err(()))?;
|
||||
|
||||
let t = match t {
|
||||
b'b' => Self::Base,
|
||||
b'c' => Self::Composite,
|
||||
|
@ -66,10 +69,12 @@ enum TypCategory {
|
|||
Unknown,
|
||||
}
|
||||
|
||||
impl TryFrom<u8> for TypCategory {
|
||||
impl TryFrom<i8> for TypCategory {
|
||||
type Error = ();
|
||||
|
||||
fn try_from(c: u8) -> Result<Self, Self::Error> {
|
||||
fn try_from(c: i8) -> Result<Self, Self::Error> {
|
||||
let c = u8::try_from(c).or(Err(()))?;
|
||||
|
||||
let c = match c {
|
||||
b'A' => Self::Array,
|
||||
b'B' => Self::Boolean,
|
||||
|
@ -209,8 +214,8 @@ impl PgConnection {
|
|||
.fetch_one(&mut *self)
|
||||
.await?;
|
||||
|
||||
let typ_type = TypType::try_from(typ_type as u8);
|
||||
let category = TypCategory::try_from(category as u8);
|
||||
let typ_type = TypType::try_from(typ_type);
|
||||
let category = TypCategory::try_from(category);
|
||||
|
||||
match (typ_type, category) {
|
||||
(Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await,
|
||||
|
@ -416,7 +421,7 @@ WHERE rngtypid = $1
|
|||
|
||||
pub(crate) async fn get_nullable_for_columns(
|
||||
&mut self,
|
||||
stmt_id: Oid,
|
||||
stmt_id: StatementId,
|
||||
meta: &PgStatementMetadata,
|
||||
) -> Result<Vec<Option<bool>>, Error> {
|
||||
if meta.columns.is_empty() {
|
||||
|
@ -486,13 +491,10 @@ WHERE rngtypid = $1
|
|||
/// and returns `None` for all others.
|
||||
async fn nullables_from_explain(
|
||||
&mut self,
|
||||
stmt_id: Oid,
|
||||
stmt_id: StatementId,
|
||||
params_len: usize,
|
||||
) -> Result<Vec<Option<bool>>, Error> {
|
||||
let mut explain = format!(
|
||||
"EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE sqlx_s_{}",
|
||||
stmt_id.0
|
||||
);
|
||||
let mut explain = format!("EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE {stmt_id}");
|
||||
let mut comma = false;
|
||||
|
||||
if params_len > 0 {
|
||||
|
|
|
@ -3,11 +3,10 @@ use crate::HashMap;
|
|||
use crate::common::StatementCache;
|
||||
use crate::connection::{sasl, stream::PgStream};
|
||||
use crate::error::Error;
|
||||
use crate::io::Decode;
|
||||
use crate::io::StatementId;
|
||||
use crate::message::{
|
||||
Authentication, BackendKeyData, MessageFormat, Password, ReadyForQuery, Startup,
|
||||
Authentication, BackendKeyData, BackendMessageFormat, Password, ReadyForQuery, Startup,
|
||||
};
|
||||
use crate::types::Oid;
|
||||
use crate::{PgConnectOptions, PgConnection};
|
||||
|
||||
// https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3
|
||||
|
@ -44,13 +43,13 @@ impl PgConnection {
|
|||
params.push(("options", options));
|
||||
}
|
||||
|
||||
stream
|
||||
.send(Startup {
|
||||
username: Some(&options.username),
|
||||
database: options.database.as_deref(),
|
||||
params: ¶ms,
|
||||
})
|
||||
.await?;
|
||||
stream.write(Startup {
|
||||
username: Some(&options.username),
|
||||
database: options.database.as_deref(),
|
||||
params: ¶ms,
|
||||
})?;
|
||||
|
||||
stream.flush().await?;
|
||||
|
||||
// The server then uses this information and the contents of
|
||||
// its configuration files (such as pg_hba.conf) to determine whether the connection is
|
||||
|
@ -64,7 +63,7 @@ impl PgConnection {
|
|||
loop {
|
||||
let message = stream.recv().await?;
|
||||
match message.format {
|
||||
MessageFormat::Authentication => match message.decode()? {
|
||||
BackendMessageFormat::Authentication => match message.decode()? {
|
||||
Authentication::Ok => {
|
||||
// the authentication exchange is successfully completed
|
||||
// do nothing; no more information is required to continue
|
||||
|
@ -108,7 +107,7 @@ impl PgConnection {
|
|||
}
|
||||
},
|
||||
|
||||
MessageFormat::BackendKeyData => {
|
||||
BackendMessageFormat::BackendKeyData => {
|
||||
// provides secret-key data that the frontend must save if it wants to be
|
||||
// able to issue cancel requests later
|
||||
|
||||
|
@ -118,10 +117,9 @@ impl PgConnection {
|
|||
secret_key = data.secret_key;
|
||||
}
|
||||
|
||||
MessageFormat::ReadyForQuery => {
|
||||
BackendMessageFormat::ReadyForQuery => {
|
||||
// start-up is completed. The frontend can now issue commands
|
||||
transaction_status =
|
||||
ReadyForQuery::decode(message.contents)?.transaction_status;
|
||||
transaction_status = message.decode::<ReadyForQuery>()?.transaction_status;
|
||||
|
||||
break;
|
||||
}
|
||||
|
@ -142,7 +140,7 @@ impl PgConnection {
|
|||
transaction_status,
|
||||
transaction_depth: 0,
|
||||
pending_ready_for_query_count: 0,
|
||||
next_statement_id: Oid(1),
|
||||
next_statement_id: StatementId::NAMED_START,
|
||||
cache_statement: StatementCache::new(options.statement_cache_capacity),
|
||||
cache_type_oid: HashMap::new(),
|
||||
cache_type_info: HashMap::new(),
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
use crate::describe::Describe;
|
||||
use crate::error::Error;
|
||||
use crate::executor::{Execute, Executor};
|
||||
use crate::io::{PortalId, StatementId};
|
||||
use crate::logger::QueryLogger;
|
||||
use crate::message::{
|
||||
self, Bind, Close, CommandComplete, DataRow, MessageFormat, ParameterDescription, Parse, Query,
|
||||
RowDescription,
|
||||
self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
|
||||
ParseComplete, Query, RowDescription,
|
||||
};
|
||||
use crate::statement::PgStatementMetadata;
|
||||
use crate::types::Oid;
|
||||
use crate::{
|
||||
statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
|
||||
PgValueFormat, Postgres,
|
||||
|
@ -16,6 +16,7 @@ use futures_core::future::BoxFuture;
|
|||
use futures_core::stream::BoxStream;
|
||||
use futures_core::Stream;
|
||||
use futures_util::{pin_mut, TryStreamExt};
|
||||
use sqlx_core::arguments::Arguments;
|
||||
use sqlx_core::Either;
|
||||
use std::{borrow::Cow, sync::Arc};
|
||||
|
||||
|
@ -24,9 +25,9 @@ async fn prepare(
|
|||
sql: &str,
|
||||
parameters: &[PgTypeInfo],
|
||||
metadata: Option<Arc<PgStatementMetadata>>,
|
||||
) -> Result<(Oid, Arc<PgStatementMetadata>), Error> {
|
||||
) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
|
||||
let id = conn.next_statement_id;
|
||||
conn.next_statement_id.incr_one();
|
||||
conn.next_statement_id = id.next();
|
||||
|
||||
// build a list of type OIDs to send to the database in the PARSE command
|
||||
// we have not yet started the query sequence, so we are *safe* to cleanly make
|
||||
|
@ -42,15 +43,15 @@ async fn prepare(
|
|||
conn.wait_until_ready().await?;
|
||||
|
||||
// next we send the PARSE command to the server
|
||||
conn.stream.write(Parse {
|
||||
conn.stream.write_msg(Parse {
|
||||
param_types: ¶m_types,
|
||||
query: sql,
|
||||
statement: id,
|
||||
});
|
||||
})?;
|
||||
|
||||
if metadata.is_none() {
|
||||
// get the statement columns and parameters
|
||||
conn.stream.write(message::Describe::Statement(id));
|
||||
conn.stream.write_msg(message::Describe::Statement(id))?;
|
||||
}
|
||||
|
||||
// we ask for the server to immediately send us the result of the PARSE command
|
||||
|
@ -58,9 +59,7 @@ async fn prepare(
|
|||
conn.stream.flush().await?;
|
||||
|
||||
// indicates that the SQL query string is now successfully parsed and has semantic validity
|
||||
conn.stream
|
||||
.recv_expect(MessageFormat::ParseComplete)
|
||||
.await?;
|
||||
conn.stream.recv_expect::<ParseComplete>().await?;
|
||||
|
||||
let metadata = if let Some(metadata) = metadata {
|
||||
// each SYNC produces one READY FOR QUERY
|
||||
|
@ -95,18 +94,18 @@ async fn prepare(
|
|||
}
|
||||
|
||||
async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
|
||||
conn.stream
|
||||
.recv_expect(MessageFormat::ParameterDescription)
|
||||
.await
|
||||
conn.stream.recv_expect().await
|
||||
}
|
||||
|
||||
async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
|
||||
let rows: Option<RowDescription> = match conn.stream.recv().await? {
|
||||
// describes the rows that will be returned when the statement is eventually executed
|
||||
message if message.format == MessageFormat::RowDescription => Some(message.decode()?),
|
||||
message if message.format == BackendMessageFormat::RowDescription => {
|
||||
Some(message.decode()?)
|
||||
}
|
||||
|
||||
// no data would be returned if this statement was executed
|
||||
message if message.format == MessageFormat::NoData => None,
|
||||
message if message.format == BackendMessageFormat::NoData => None,
|
||||
|
||||
message => {
|
||||
return Err(err_protocol!(
|
||||
|
@ -125,12 +124,12 @@ impl PgConnection {
|
|||
// we need to wait for the [CloseComplete] to be returned from the server
|
||||
while count > 0 {
|
||||
match self.stream.recv().await? {
|
||||
message if message.format == MessageFormat::PortalSuspended => {
|
||||
message if message.format == BackendMessageFormat::PortalSuspended => {
|
||||
// there was an open portal
|
||||
// this can happen if the last time a statement was used it was not fully executed
|
||||
}
|
||||
|
||||
message if message.format == MessageFormat::CloseComplete => {
|
||||
message if message.format == BackendMessageFormat::CloseComplete => {
|
||||
// successfully closed the statement (and freed up the server resources)
|
||||
count -= 1;
|
||||
}
|
||||
|
@ -147,8 +146,11 @@ impl PgConnection {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn write_sync(&mut self) {
|
||||
self.stream.write(message::Sync);
|
||||
self.stream
|
||||
.write_msg(message::Sync)
|
||||
.expect("BUG: Sync should not be too big for protocol");
|
||||
|
||||
// all SYNC messages will return a ReadyForQuery
|
||||
self.pending_ready_for_query_count += 1;
|
||||
|
@ -163,7 +165,7 @@ impl PgConnection {
|
|||
// optional metadata that was provided by the user, this means they are reusing
|
||||
// a statement object
|
||||
metadata: Option<Arc<PgStatementMetadata>>,
|
||||
) -> Result<(Oid, Arc<PgStatementMetadata>), Error> {
|
||||
) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
|
||||
if let Some(statement) = self.cache_statement.get_mut(sql) {
|
||||
return Ok((*statement).clone());
|
||||
}
|
||||
|
@ -172,7 +174,7 @@ impl PgConnection {
|
|||
|
||||
if store_to_cache && self.cache_statement.is_enabled() {
|
||||
if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) {
|
||||
self.stream.write(Close::Statement(id));
|
||||
self.stream.write_msg(Close::Statement(id))?;
|
||||
self.write_sync();
|
||||
|
||||
self.stream.flush().await?;
|
||||
|
@ -201,6 +203,14 @@ impl PgConnection {
|
|||
let mut metadata: Arc<PgStatementMetadata>;
|
||||
|
||||
let format = if let Some(mut arguments) = arguments {
|
||||
// Check this before we write anything to the stream.
|
||||
let num_params = i16::try_from(arguments.len()).map_err(|_| {
|
||||
err_protocol!(
|
||||
"PgConnection::run(): too many arguments for query: {}",
|
||||
arguments.len()
|
||||
)
|
||||
})?;
|
||||
|
||||
// prepare the statement if this our first time executing it
|
||||
// always return the statement ID here
|
||||
let (statement, metadata_) = self
|
||||
|
@ -216,21 +226,21 @@ impl PgConnection {
|
|||
self.wait_until_ready().await?;
|
||||
|
||||
// bind to attach the arguments to the statement and create a portal
|
||||
self.stream.write(Bind {
|
||||
portal: None,
|
||||
self.stream.write_msg(Bind {
|
||||
portal: PortalId::UNNAMED,
|
||||
statement,
|
||||
formats: &[PgValueFormat::Binary],
|
||||
num_params: arguments.types.len() as i16,
|
||||
num_params,
|
||||
params: &arguments.buffer,
|
||||
result_formats: &[PgValueFormat::Binary],
|
||||
});
|
||||
})?;
|
||||
|
||||
// executes the portal up to the passed limit
|
||||
// the protocol-level limit acts nearly identically to the `LIMIT` in SQL
|
||||
self.stream.write(message::Execute {
|
||||
portal: None,
|
||||
self.stream.write_msg(message::Execute {
|
||||
portal: PortalId::UNNAMED,
|
||||
limit: limit.into(),
|
||||
});
|
||||
})?;
|
||||
// From https://www.postgresql.org/docs/current/protocol-flow.html:
|
||||
//
|
||||
// "An unnamed portal is destroyed at the end of the transaction, or as
|
||||
|
@ -240,7 +250,7 @@ impl PgConnection {
|
|||
|
||||
// we ask the database server to close the unnamed portal and free the associated resources
|
||||
// earlier - after the execution of the current query.
|
||||
self.stream.write(message::Close::Portal(None));
|
||||
self.stream.write_msg(Close::Portal(PortalId::UNNAMED))?;
|
||||
|
||||
// finally, [Sync] asks postgres to process the messages that we sent and respond with
|
||||
// a [ReadyForQuery] message when it's completely done. Theoretically, we could send
|
||||
|
@ -253,7 +263,7 @@ impl PgConnection {
|
|||
PgValueFormat::Binary
|
||||
} else {
|
||||
// Query will trigger a ReadyForQuery
|
||||
self.stream.write(Query(query));
|
||||
self.stream.write_msg(Query(query))?;
|
||||
self.pending_ready_for_query_count += 1;
|
||||
|
||||
// metadata starts out as "nothing"
|
||||
|
@ -270,12 +280,12 @@ impl PgConnection {
|
|||
let message = self.stream.recv().await?;
|
||||
|
||||
match message.format {
|
||||
MessageFormat::BindComplete
|
||||
| MessageFormat::ParseComplete
|
||||
| MessageFormat::ParameterDescription
|
||||
| MessageFormat::NoData
|
||||
BackendMessageFormat::BindComplete
|
||||
| BackendMessageFormat::ParseComplete
|
||||
| BackendMessageFormat::ParameterDescription
|
||||
| BackendMessageFormat::NoData
|
||||
// unnamed portal has been closed
|
||||
| MessageFormat::CloseComplete
|
||||
| BackendMessageFormat::CloseComplete
|
||||
=> {
|
||||
// harmless messages to ignore
|
||||
}
|
||||
|
@ -284,7 +294,7 @@ impl PgConnection {
|
|||
// exactly one of these messages: CommandComplete,
|
||||
// EmptyQueryResponse (if the portal was created from an
|
||||
// empty query string), ErrorResponse, or PortalSuspended"
|
||||
MessageFormat::CommandComplete => {
|
||||
BackendMessageFormat::CommandComplete => {
|
||||
// a SQL command completed normally
|
||||
let cc: CommandComplete = message.decode()?;
|
||||
|
||||
|
@ -295,16 +305,16 @@ impl PgConnection {
|
|||
}));
|
||||
}
|
||||
|
||||
MessageFormat::EmptyQueryResponse => {
|
||||
BackendMessageFormat::EmptyQueryResponse => {
|
||||
// empty query string passed to an unprepared execute
|
||||
}
|
||||
|
||||
// Message::ErrorResponse is handled in self.stream.recv()
|
||||
|
||||
// incomplete query execution has finished
|
||||
MessageFormat::PortalSuspended => {}
|
||||
BackendMessageFormat::PortalSuspended => {}
|
||||
|
||||
MessageFormat::RowDescription => {
|
||||
BackendMessageFormat::RowDescription => {
|
||||
// indicates that a *new* set of rows are about to be returned
|
||||
let (columns, column_names) = self
|
||||
.handle_row_description(Some(message.decode()?), false)
|
||||
|
@ -317,7 +327,7 @@ impl PgConnection {
|
|||
});
|
||||
}
|
||||
|
||||
MessageFormat::DataRow => {
|
||||
BackendMessageFormat::DataRow => {
|
||||
logger.increment_rows_returned();
|
||||
|
||||
// one of the set of rows returned by a SELECT, FETCH, etc query
|
||||
|
@ -331,7 +341,7 @@ impl PgConnection {
|
|||
r#yield!(Either::Right(row));
|
||||
}
|
||||
|
||||
MessageFormat::ReadyForQuery => {
|
||||
BackendMessageFormat::ReadyForQuery => {
|
||||
// processing of the query string is complete
|
||||
self.handle_ready_for_query(message)?;
|
||||
break;
|
||||
|
|
|
@ -8,9 +8,10 @@ use futures_util::FutureExt;
|
|||
use crate::common::StatementCache;
|
||||
use crate::error::Error;
|
||||
use crate::ext::ustr::UStr;
|
||||
use crate::io::Decode;
|
||||
use crate::io::StatementId;
|
||||
use crate::message::{
|
||||
Close, Message, MessageFormat, Query, ReadyForQuery, Terminate, TransactionStatus,
|
||||
BackendMessageFormat, Close, Query, ReadyForQuery, ReceivedMessage, Terminate,
|
||||
TransactionStatus,
|
||||
};
|
||||
use crate::statement::PgStatementMetadata;
|
||||
use crate::transaction::Transaction;
|
||||
|
@ -47,10 +48,10 @@ pub struct PgConnection {
|
|||
|
||||
// sequence of statement IDs for use in preparing statements
|
||||
// in PostgreSQL, the statement is prepared to a user-supplied identifier
|
||||
next_statement_id: Oid,
|
||||
next_statement_id: StatementId,
|
||||
|
||||
// cache statement by query string to the id and columns
|
||||
cache_statement: StatementCache<(Oid, Arc<PgStatementMetadata>)>,
|
||||
cache_statement: StatementCache<(StatementId, Arc<PgStatementMetadata>)>,
|
||||
|
||||
// cache user-defined types by id <-> info
|
||||
cache_type_info: HashMap<Oid, PgTypeInfo>,
|
||||
|
@ -82,7 +83,7 @@ impl PgConnection {
|
|||
while self.pending_ready_for_query_count > 0 {
|
||||
let message = self.stream.recv().await?;
|
||||
|
||||
if let MessageFormat::ReadyForQuery = message.format {
|
||||
if let BackendMessageFormat::ReadyForQuery = message.format {
|
||||
self.handle_ready_for_query(message)?;
|
||||
}
|
||||
}
|
||||
|
@ -91,10 +92,7 @@ impl PgConnection {
|
|||
}
|
||||
|
||||
async fn recv_ready_for_query(&mut self) -> Result<(), Error> {
|
||||
let r: ReadyForQuery = self
|
||||
.stream
|
||||
.recv_expect(MessageFormat::ReadyForQuery)
|
||||
.await?;
|
||||
let r: ReadyForQuery = self.stream.recv_expect().await?;
|
||||
|
||||
self.pending_ready_for_query_count -= 1;
|
||||
self.transaction_status = r.transaction_status;
|
||||
|
@ -102,9 +100,10 @@ impl PgConnection {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_ready_for_query(&mut self, message: Message) -> Result<(), Error> {
|
||||
#[inline(always)]
|
||||
fn handle_ready_for_query(&mut self, message: ReceivedMessage) -> Result<(), Error> {
|
||||
self.pending_ready_for_query_count -= 1;
|
||||
self.transaction_status = ReadyForQuery::decode(message.contents)?.transaction_status;
|
||||
self.transaction_status = message.decode::<ReadyForQuery>()?.transaction_status;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -112,9 +111,12 @@ impl PgConnection {
|
|||
/// Queue a simple query (not prepared) to execute the next time this connection is used.
|
||||
///
|
||||
/// Used for rolling back transactions and releasing advisory locks.
|
||||
pub(crate) fn queue_simple_query(&mut self, query: &str) {
|
||||
#[inline(always)]
|
||||
pub(crate) fn queue_simple_query(&mut self, query: &str) -> Result<(), Error> {
|
||||
self.stream.write_msg(Query(query))?;
|
||||
self.pending_ready_for_query_count += 1;
|
||||
self.stream.write(Query(query));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -184,7 +186,7 @@ impl Connection for PgConnection {
|
|||
self.wait_until_ready().await?;
|
||||
|
||||
while let Some((id, _)) = self.cache_statement.remove_lru() {
|
||||
self.stream.write(Close::Statement(id));
|
||||
self.stream.write_msg(Close::Statement(id))?;
|
||||
cleared += 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
use crate::connection::stream::PgStream;
|
||||
use crate::error::Error;
|
||||
use crate::message::{
|
||||
Authentication, AuthenticationSasl, MessageFormat, SaslInitialResponse, SaslResponse,
|
||||
};
|
||||
use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse};
|
||||
use crate::PgConnectOptions;
|
||||
use hmac::{Hmac, Mac};
|
||||
use rand::Rng;
|
||||
|
@ -76,7 +74,7 @@ pub(crate) async fn authenticate(
|
|||
})
|
||||
.await?;
|
||||
|
||||
let cont = match stream.recv_expect(MessageFormat::Authentication).await? {
|
||||
let cont = match stream.recv_expect().await? {
|
||||
Authentication::SaslContinue(data) => data,
|
||||
|
||||
auth => {
|
||||
|
@ -147,7 +145,7 @@ pub(crate) async fn authenticate(
|
|||
|
||||
stream.send(SaslResponse(&client_final_message)).await?;
|
||||
|
||||
let data = match stream.recv_expect(MessageFormat::Authentication).await? {
|
||||
let data = match stream.recv_expect().await? {
|
||||
Authentication::SaslFinal(data) => data,
|
||||
|
||||
auth => {
|
||||
|
@ -172,10 +170,10 @@ fn gen_nonce() -> String {
|
|||
// ;; a valid "value".
|
||||
let nonce: String = std::iter::repeat(())
|
||||
.map(|()| {
|
||||
let mut c = rng.gen_range(0x21..0x7F) as u8;
|
||||
let mut c = rng.gen_range(0x21u8..0x7F);
|
||||
|
||||
while c == 0x2C {
|
||||
c = rng.gen_range(0x21..0x7F) as u8;
|
||||
c = rng.gen_range(0x21u8..0x7F);
|
||||
}
|
||||
|
||||
c
|
||||
|
|
|
@ -9,8 +9,10 @@ use sqlx_core::bytes::{Buf, Bytes};
|
|||
|
||||
use crate::connection::tls::MaybeUpgradeTls;
|
||||
use crate::error::Error;
|
||||
use crate::io::{Decode, Encode};
|
||||
use crate::message::{Message, MessageFormat, Notice, Notification, ParameterStatus};
|
||||
use crate::message::{
|
||||
BackendMessage, BackendMessageFormat, EncodeMessage, FrontendMessage, Notice, Notification,
|
||||
ParameterStatus, ReceivedMessage,
|
||||
};
|
||||
use crate::net::{self, BufferedSocket, Socket};
|
||||
use crate::{PgConnectOptions, PgDatabaseError, PgSeverity};
|
||||
|
||||
|
@ -55,59 +57,51 @@ impl PgStream {
|
|||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn send<'en, T>(&mut self, message: T) -> Result<(), Error>
|
||||
#[inline(always)]
|
||||
pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> {
|
||||
self.write(EncodeMessage(message))
|
||||
}
|
||||
|
||||
pub(crate) async fn send<T>(&mut self, message: T) -> Result<(), Error>
|
||||
where
|
||||
T: Encode<'en>,
|
||||
T: FrontendMessage,
|
||||
{
|
||||
self.write(message);
|
||||
self.write_msg(message)?;
|
||||
self.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Expect a specific type and format
|
||||
pub(crate) async fn recv_expect<'de, T: Decode<'de>>(
|
||||
&mut self,
|
||||
format: MessageFormat,
|
||||
) -> Result<T, Error> {
|
||||
let message = self.recv().await?;
|
||||
|
||||
if message.format != format {
|
||||
return Err(err_protocol!(
|
||||
"expecting {:?} but received {:?}",
|
||||
format,
|
||||
message.format
|
||||
));
|
||||
}
|
||||
|
||||
message.decode()
|
||||
pub(crate) async fn recv_expect<B: BackendMessage>(&mut self) -> Result<B, Error> {
|
||||
self.recv().await?.decode()
|
||||
}
|
||||
|
||||
pub(crate) async fn recv_unchecked(&mut self) -> Result<Message, Error> {
|
||||
pub(crate) async fn recv_unchecked(&mut self) -> Result<ReceivedMessage, Error> {
|
||||
// all packets in postgres start with a 5-byte header
|
||||
// this header contains the message type and the total length of the message
|
||||
let mut header: Bytes = self.inner.read(5).await?;
|
||||
|
||||
let format = MessageFormat::try_from_u8(header.get_u8())?;
|
||||
let format = BackendMessageFormat::try_from_u8(header.get_u8())?;
|
||||
let size = (header.get_u32() - 4) as usize;
|
||||
|
||||
let contents = self.inner.read(size).await?;
|
||||
|
||||
Ok(Message { format, contents })
|
||||
Ok(ReceivedMessage { format, contents })
|
||||
}
|
||||
|
||||
// Get the next message from the server
|
||||
// May wait for more data from the server
|
||||
pub(crate) async fn recv(&mut self) -> Result<Message, Error> {
|
||||
pub(crate) async fn recv(&mut self) -> Result<ReceivedMessage, Error> {
|
||||
loop {
|
||||
let message = self.recv_unchecked().await?;
|
||||
|
||||
match message.format {
|
||||
MessageFormat::ErrorResponse => {
|
||||
BackendMessageFormat::ErrorResponse => {
|
||||
// An error returned from the database server.
|
||||
return Err(PgDatabaseError(message.decode()?).into());
|
||||
}
|
||||
|
||||
MessageFormat::NotificationResponse => {
|
||||
BackendMessageFormat::NotificationResponse => {
|
||||
if let Some(buffer) = &mut self.notifications {
|
||||
let notification: Notification = message.decode()?;
|
||||
let _ = buffer.send(notification).await;
|
||||
|
@ -116,7 +110,7 @@ impl PgStream {
|
|||
}
|
||||
}
|
||||
|
||||
MessageFormat::ParameterStatus => {
|
||||
BackendMessageFormat::ParameterStatus => {
|
||||
// informs the frontend about the current (initial)
|
||||
// setting of backend parameters
|
||||
|
||||
|
@ -135,7 +129,7 @@ impl PgStream {
|
|||
continue;
|
||||
}
|
||||
|
||||
MessageFormat::NoticeResponse => {
|
||||
BackendMessageFormat::NoticeResponse => {
|
||||
// do we need this to be more configurable?
|
||||
// if you are reading this comment and think so, open an issue
|
||||
|
||||
|
|
|
@ -11,7 +11,8 @@ use crate::error::{Error, Result};
|
|||
use crate::ext::async_stream::TryAsyncStream;
|
||||
use crate::io::AsyncRead;
|
||||
use crate::message::{
|
||||
CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query,
|
||||
BackendMessageFormat, CommandComplete, CopyData, CopyDone, CopyFail, CopyInResponse,
|
||||
CopyOutResponse, CopyResponseData, Query, ReadyForQuery,
|
||||
};
|
||||
use crate::pool::{Pool, PoolConnection};
|
||||
use crate::Postgres;
|
||||
|
@ -138,7 +139,7 @@ impl PgPoolCopyExt for Pool<Postgres> {
|
|||
#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"]
|
||||
pub struct PgCopyIn<C: DerefMut<Target = PgConnection>> {
|
||||
conn: Option<C>,
|
||||
response: CopyResponse,
|
||||
response: CopyResponseData,
|
||||
}
|
||||
|
||||
impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
|
||||
|
@ -146,8 +147,8 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
|
|||
conn.wait_until_ready().await?;
|
||||
conn.stream.send(Query(statement)).await?;
|
||||
|
||||
let response = match conn.stream.recv_expect(MessageFormat::CopyInResponse).await {
|
||||
Ok(res) => res,
|
||||
let response = match conn.stream.recv_expect::<CopyInResponse>().await {
|
||||
Ok(res) => res.0,
|
||||
Err(e) => {
|
||||
conn.stream.recv().await?;
|
||||
return Err(e);
|
||||
|
@ -168,7 +169,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
|
|||
/// Returns the number of columns expected in the input.
|
||||
pub fn num_columns(&self) -> usize {
|
||||
assert_eq!(
|
||||
self.response.num_columns as usize,
|
||||
self.response.num_columns.unsigned_abs() as usize,
|
||||
self.response.format_codes.len(),
|
||||
"num_columns does not match format_codes.len()"
|
||||
);
|
||||
|
@ -261,9 +262,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
|
|||
match e.code() {
|
||||
Some(Cow::Borrowed("57014")) => {
|
||||
// postgres abort received error code
|
||||
conn.stream
|
||||
.recv_expect(MessageFormat::ReadyForQuery)
|
||||
.await?;
|
||||
conn.stream.recv_expect::<ReadyForQuery>().await?;
|
||||
Ok(())
|
||||
}
|
||||
_ => Err(Error::Database(e)),
|
||||
|
@ -283,11 +282,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
|
|||
.expect("CopyWriter::finish: conn taken illegally");
|
||||
|
||||
conn.stream.send(CopyDone).await?;
|
||||
let cc: CommandComplete = match conn
|
||||
.stream
|
||||
.recv_expect(MessageFormat::CommandComplete)
|
||||
.await
|
||||
{
|
||||
let cc: CommandComplete = match conn.stream.recv_expect().await {
|
||||
Ok(cc) => cc,
|
||||
Err(e) => {
|
||||
conn.stream.recv().await?;
|
||||
|
@ -295,9 +290,7 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
|
|||
}
|
||||
};
|
||||
|
||||
conn.stream
|
||||
.recv_expect(MessageFormat::ReadyForQuery)
|
||||
.await?;
|
||||
conn.stream.recv_expect::<ReadyForQuery>().await?;
|
||||
|
||||
Ok(cc.rows_affected())
|
||||
}
|
||||
|
@ -306,9 +299,11 @@ impl<C: DerefMut<Target = PgConnection>> PgCopyIn<C> {
|
|||
impl<C: DerefMut<Target = PgConnection>> Drop for PgCopyIn<C> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(mut conn) = self.conn.take() {
|
||||
conn.stream.write(CopyFail::new(
|
||||
"PgCopyIn dropped without calling finish() or fail()",
|
||||
));
|
||||
conn.stream
|
||||
.write_msg(CopyFail::new(
|
||||
"PgCopyIn dropped without calling finish() or fail()",
|
||||
))
|
||||
.expect("BUG: PgCopyIn abort message should not be too large");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -320,24 +315,21 @@ async fn pg_begin_copy_out<'c, C: DerefMut<Target = PgConnection> + Send + 'c>(
|
|||
conn.wait_until_ready().await?;
|
||||
conn.stream.send(Query(statement)).await?;
|
||||
|
||||
let _: CopyResponse = conn
|
||||
.stream
|
||||
.recv_expect(MessageFormat::CopyOutResponse)
|
||||
.await?;
|
||||
let _: CopyOutResponse = conn.stream.recv_expect().await?;
|
||||
|
||||
let stream: TryAsyncStream<'c, Bytes> = try_stream! {
|
||||
loop {
|
||||
match conn.stream.recv().await {
|
||||
Err(e) => {
|
||||
conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?;
|
||||
conn.stream.recv_expect::<ReadyForQuery>().await?;
|
||||
return Err(e);
|
||||
},
|
||||
Ok(msg) => match msg.format {
|
||||
MessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
|
||||
MessageFormat::CopyDone => {
|
||||
BackendMessageFormat::CopyData => r#yield!(msg.decode::<CopyData<Bytes>>()?.0),
|
||||
BackendMessageFormat::CopyDone => {
|
||||
let _ = msg.decode::<CopyDone>()?;
|
||||
conn.stream.recv_expect(MessageFormat::CommandComplete).await?;
|
||||
conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?;
|
||||
conn.stream.recv_expect::<CommandComplete>().await?;
|
||||
conn.stream.recv_expect::<ReadyForQuery>().await?;
|
||||
return Ok(())
|
||||
},
|
||||
_ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format))
|
||||
|
|
|
@ -1,54 +1,64 @@
|
|||
use crate::types::Oid;
|
||||
use crate::io::{PortalId, StatementId};
|
||||
|
||||
pub trait PgBufMutExt {
|
||||
fn put_length_prefixed<F>(&mut self, f: F)
|
||||
fn put_length_prefixed<F>(&mut self, f: F) -> Result<(), crate::Error>
|
||||
where
|
||||
F: FnOnce(&mut Vec<u8>);
|
||||
F: FnOnce(&mut Vec<u8>) -> Result<(), crate::Error>;
|
||||
|
||||
fn put_statement_name(&mut self, id: Oid);
|
||||
fn put_statement_name(&mut self, id: StatementId);
|
||||
|
||||
fn put_portal_name(&mut self, id: Option<Oid>);
|
||||
fn put_portal_name(&mut self, id: PortalId);
|
||||
}
|
||||
|
||||
impl PgBufMutExt for Vec<u8> {
|
||||
// writes a length-prefixed message, this is used when encoding nearly all messages as postgres
|
||||
// wants us to send the length of the often-variable-sized messages up front
|
||||
fn put_length_prefixed<F>(&mut self, f: F)
|
||||
fn put_length_prefixed<F>(&mut self, write_contents: F) -> Result<(), crate::Error>
|
||||
where
|
||||
F: FnOnce(&mut Vec<u8>),
|
||||
F: FnOnce(&mut Vec<u8>) -> Result<(), crate::Error>,
|
||||
{
|
||||
// reserve space to write the prefixed length
|
||||
let offset = self.len();
|
||||
self.extend(&[0; 4]);
|
||||
|
||||
// write the main body of the message
|
||||
f(self);
|
||||
let write_result = write_contents(self);
|
||||
|
||||
// now calculate the size of what we wrote and set the length value
|
||||
let size = (self.len() - offset) as i32;
|
||||
self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes());
|
||||
let size_result = write_result.and_then(|_| {
|
||||
let size = self.len() - offset;
|
||||
i32::try_from(size)
|
||||
.map_err(|_| err_protocol!("message size out of range for Pg protocol: {size"))
|
||||
});
|
||||
|
||||
match size_result {
|
||||
Ok(size) => {
|
||||
// now calculate the size of what we wrote and set the length value
|
||||
self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes());
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
// Put the buffer back to where it was.
|
||||
self.truncate(offset);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writes a statement name by ID
|
||||
#[inline]
|
||||
fn put_statement_name(&mut self, id: Oid) {
|
||||
// N.B. if you change this don't forget to update it in ../describe.rs
|
||||
self.extend(b"sqlx_s_");
|
||||
|
||||
self.extend(itoa::Buffer::new().format(id.0).as_bytes());
|
||||
|
||||
self.push(0);
|
||||
fn put_statement_name(&mut self, id: StatementId) {
|
||||
let _: Result<(), ()> = id.write_name(|s| {
|
||||
self.extend_from_slice(s.as_bytes());
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
|
||||
// writes a portal name by ID
|
||||
#[inline]
|
||||
fn put_portal_name(&mut self, id: Option<Oid>) {
|
||||
if let Some(id) = id {
|
||||
self.extend(b"sqlx_p_");
|
||||
|
||||
self.extend(itoa::Buffer::new().format(id.0).as_bytes());
|
||||
}
|
||||
|
||||
self.push(0);
|
||||
fn put_portal_name(&mut self, id: PortalId) {
|
||||
let _: Result<(), ()> = id.write_name(|s| {
|
||||
self.extend_from_slice(s.as_bytes());
|
||||
Ok(())
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,130 @@
|
|||
mod buf_mut;
|
||||
|
||||
pub use buf_mut::PgBufMutExt;
|
||||
use std::fmt;
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::num::{NonZeroU32, Saturating};
|
||||
|
||||
pub(crate) use sqlx_core::io::*;
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct StatementId(IdInner);
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct PortalId(IdInner);
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
struct IdInner(Option<NonZeroU32>);
|
||||
|
||||
impl StatementId {
|
||||
pub const UNNAMED: Self = Self(IdInner::UNNAMED);
|
||||
|
||||
pub const NAMED_START: Self = Self(IdInner::NAMED_START);
|
||||
|
||||
#[cfg(test)]
|
||||
pub const TEST_VAL: Self = Self(IdInner::TEST_VAL);
|
||||
|
||||
const NAME_PREFIX: &'static str = "sqlx_s_";
|
||||
|
||||
pub fn next(&self) -> Self {
|
||||
Self(self.0.next())
|
||||
}
|
||||
|
||||
pub fn name_len(&self) -> Saturating<usize> {
|
||||
self.0.name_len(Self::NAME_PREFIX)
|
||||
}
|
||||
|
||||
// There's no common trait implemented by `Formatter` and `Vec<u8>` for this purpose;
|
||||
// we're deliberately avoiding the formatting machinery because it's known to be slow.
|
||||
pub fn write_name<E>(&self, write: impl FnMut(&str) -> Result<(), E>) -> Result<(), E> {
|
||||
self.0.write_name(Self::NAME_PREFIX, write)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for StatementId {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||
self.write_name(|s| f.write_str(s))
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl PortalId {
|
||||
// None selects the unnamed portal
|
||||
pub const UNNAMED: Self = PortalId(IdInner::UNNAMED);
|
||||
|
||||
pub const NAMED_START: Self = PortalId(IdInner::NAMED_START);
|
||||
|
||||
#[cfg(test)]
|
||||
pub const TEST_VAL: Self = Self(IdInner::TEST_VAL);
|
||||
|
||||
const NAME_PREFIX: &'static str = "sqlx_p_";
|
||||
|
||||
/// If ID represents a named portal, return the next ID, wrapping on overflow.
|
||||
///
|
||||
/// If this ID represents the unnamed portal, return the same.
|
||||
pub fn next(&self) -> Self {
|
||||
Self(self.0.next())
|
||||
}
|
||||
|
||||
/// Calculate the number of bytes that will be written by [`Self::write_name()`].
|
||||
pub fn name_len(&self) -> Saturating<usize> {
|
||||
self.0.name_len(Self::NAME_PREFIX)
|
||||
}
|
||||
|
||||
pub fn write_name<E>(&self, write: impl FnMut(&str) -> Result<(), E>) -> Result<(), E> {
|
||||
self.0.write_name(Self::NAME_PREFIX, write)
|
||||
}
|
||||
}
|
||||
|
||||
impl IdInner {
|
||||
const UNNAMED: Self = Self(None);
|
||||
|
||||
const NAMED_START: Self = Self(Some(NonZeroU32::MIN));
|
||||
|
||||
#[cfg(test)]
|
||||
pub const TEST_VAL: Self = Self(NonZeroU32::new(1234567890));
|
||||
|
||||
#[inline(always)]
|
||||
fn next(&self) -> Self {
|
||||
Self(
|
||||
self.0
|
||||
.map(|id| id.checked_add(1).unwrap_or(NonZeroU32::MIN)),
|
||||
)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn name_len(&self, name_prefix: &str) -> Saturating<usize> {
|
||||
let mut len = Saturating(0);
|
||||
|
||||
if let Some(id) = self.0 {
|
||||
len += name_prefix.len();
|
||||
// estimate the length of the ID in decimal
|
||||
// `.ilog10()` can't panic since the value is never zero
|
||||
len += id.get().ilog10() as usize;
|
||||
// add one to compensate for `ilog10()` rounding down.
|
||||
len += 1;
|
||||
}
|
||||
|
||||
// count the NUL terminator
|
||||
len += 1;
|
||||
|
||||
len
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn write_name<E>(
|
||||
&self,
|
||||
name_prefix: &str,
|
||||
mut write: impl FnMut(&str) -> Result<(), E>,
|
||||
) -> Result<(), E> {
|
||||
if let Some(id) = self.0 {
|
||||
write(name_prefix)?;
|
||||
write(itoa::Buffer::new().format(id.get()))?;
|
||||
}
|
||||
|
||||
write("\0")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ use sqlx_core::Either;
|
|||
use crate::describe::Describe;
|
||||
use crate::error::Error;
|
||||
use crate::executor::{Execute, Executor};
|
||||
use crate::message::{MessageFormat, Notification};
|
||||
use crate::message::{BackendMessageFormat, Notification};
|
||||
use crate::pool::PoolOptions;
|
||||
use crate::pool::{Pool, PoolConnection};
|
||||
use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres};
|
||||
|
@ -277,12 +277,12 @@ impl PgListener {
|
|||
|
||||
match message.format {
|
||||
// We've received an async notification, return it.
|
||||
MessageFormat::NotificationResponse => {
|
||||
BackendMessageFormat::NotificationResponse => {
|
||||
return Ok(Some(PgNotification(message.decode()?)));
|
||||
}
|
||||
|
||||
// Mark the connection as ready for another query
|
||||
MessageFormat::ReadyForQuery => {
|
||||
BackendMessageFormat::ReadyForQuery => {
|
||||
self.connection().await?.pending_ready_for_query_count -= 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -4,10 +4,10 @@ use memchr::memchr;
|
|||
use sqlx_core::bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::Decode;
|
||||
use crate::io::ProtocolDecode;
|
||||
|
||||
use crate::message::{BackendMessage, BackendMessageFormat};
|
||||
use base64::prelude::{Engine as _, BASE64_STANDARD};
|
||||
|
||||
// On startup, the server sends an appropriate authentication request message,
|
||||
// to which the frontend must reply with an appropriate authentication
|
||||
// response message (such as a password).
|
||||
|
@ -60,8 +60,10 @@ pub enum Authentication {
|
|||
SaslFinal(AuthenticationSaslFinal),
|
||||
}
|
||||
|
||||
impl Decode<'_> for Authentication {
|
||||
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
impl BackendMessage for Authentication {
|
||||
const FORMAT: BackendMessageFormat = BackendMessageFormat::Authentication;
|
||||
|
||||
fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
|
||||
Ok(match buf.get_u32() {
|
||||
0 => Authentication::Ok,
|
||||
|
||||
|
@ -129,7 +131,7 @@ pub struct AuthenticationSaslContinue {
|
|||
pub message: String,
|
||||
}
|
||||
|
||||
impl Decode<'_> for AuthenticationSaslContinue {
|
||||
impl ProtocolDecode<'_> for AuthenticationSaslContinue {
|
||||
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
let mut iterations: u32 = 4096;
|
||||
let mut salt = Vec::new();
|
||||
|
@ -173,7 +175,7 @@ pub struct AuthenticationSaslFinal {
|
|||
pub verifier: Vec<u8>,
|
||||
}
|
||||
|
||||
impl Decode<'_> for AuthenticationSaslFinal {
|
||||
impl ProtocolDecode<'_> for AuthenticationSaslFinal {
|
||||
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
let mut verifier = Vec::new();
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ use byteorder::{BigEndian, ByteOrder};
|
|||
use sqlx_core::bytes::Bytes;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::Decode;
|
||||
use crate::message::{BackendMessage, BackendMessageFormat};
|
||||
|
||||
/// Contains cancellation key data. The frontend must save these values if it
|
||||
/// wishes to be able to issue `CancelRequest` messages later.
|
||||
|
@ -15,8 +15,10 @@ pub struct BackendKeyData {
|
|||
pub secret_key: u32,
|
||||
}
|
||||
|
||||
impl Decode<'_> for BackendKeyData {
|
||||
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
impl BackendMessage for BackendKeyData {
|
||||
const FORMAT: BackendMessageFormat = BackendMessageFormat::BackendKeyData;
|
||||
|
||||
fn decode_body(buf: Bytes) -> Result<Self, Error> {
|
||||
let process_id = BigEndian::read_u32(&buf);
|
||||
let secret_key = BigEndian::read_u32(&buf[4..]);
|
||||
|
||||
|
@ -31,7 +33,7 @@ impl Decode<'_> for BackendKeyData {
|
|||
fn test_decode_backend_key_data() {
|
||||
const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+";
|
||||
|
||||
let m = BackendKeyData::decode(DATA.into()).unwrap();
|
||||
let m = BackendKeyData::decode_body(DATA.into()).unwrap();
|
||||
|
||||
assert_eq!(m.process_id, 10182);
|
||||
assert_eq!(m.secret_key, 2303903019);
|
||||
|
@ -43,6 +45,6 @@ fn bench_decode_backend_key_data(b: &mut test::Bencher) {
|
|||
const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+";
|
||||
|
||||
b.iter(|| {
|
||||
BackendKeyData::decode(test::black_box(Bytes::from_static(DATA))).unwrap();
|
||||
BackendKeyData::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap();
|
||||
});
|
||||
}
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
use crate::io::Encode;
|
||||
use crate::io::PgBufMutExt;
|
||||
use crate::types::Oid;
|
||||
use crate::io::{PgBufMutExt, PortalId, StatementId};
|
||||
use crate::message::{FrontendMessage, FrontendMessageFormat};
|
||||
use crate::PgValueFormat;
|
||||
use std::num::Saturating;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Bind<'a> {
|
||||
/// The ID of the destination portal (`None` selects the unnamed portal).
|
||||
pub portal: Option<Oid>,
|
||||
/// The ID of the destination portal (`PortalId::UNNAMED` selects the unnamed portal).
|
||||
pub portal: PortalId,
|
||||
|
||||
/// The id of the source prepared statement.
|
||||
pub statement: Oid,
|
||||
pub statement: StatementId,
|
||||
|
||||
/// The parameter format codes. Each must presently be zero (text) or one (binary).
|
||||
///
|
||||
|
@ -19,6 +19,8 @@ pub struct Bind<'a> {
|
|||
pub formats: &'a [PgValueFormat],
|
||||
|
||||
/// The number of parameters.
|
||||
///
|
||||
/// May be different from `formats.len()`
|
||||
pub num_params: i16,
|
||||
|
||||
/// The value of each parameter, in the indicated format.
|
||||
|
@ -33,31 +35,59 @@ pub struct Bind<'a> {
|
|||
pub result_formats: &'a [PgValueFormat],
|
||||
}
|
||||
|
||||
impl Encode<'_> for Bind<'_> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
buf.push(b'B');
|
||||
impl FrontendMessage for Bind<'_> {
|
||||
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Bind;
|
||||
|
||||
buf.put_length_prefixed(|buf| {
|
||||
buf.put_portal_name(self.portal);
|
||||
fn body_size_hint(&self) -> Saturating<usize> {
|
||||
let mut size = Saturating(0);
|
||||
size += self.portal.name_len();
|
||||
size += self.statement.name_len();
|
||||
|
||||
buf.put_statement_name(self.statement);
|
||||
// Parameter formats and length prefix
|
||||
size += 2;
|
||||
size += self.formats.len();
|
||||
|
||||
buf.extend(&(self.formats.len() as i16).to_be_bytes());
|
||||
// `num_params`
|
||||
size += 2;
|
||||
|
||||
for &format in self.formats {
|
||||
buf.extend(&(format as i16).to_be_bytes());
|
||||
}
|
||||
size += self.params.len();
|
||||
|
||||
buf.extend(&self.num_params.to_be_bytes());
|
||||
// Result formats and length prefix
|
||||
size += 2;
|
||||
size += self.result_formats.len();
|
||||
|
||||
buf.extend(self.params);
|
||||
size
|
||||
}
|
||||
|
||||
buf.extend(&(self.result_formats.len() as i16).to_be_bytes());
|
||||
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), crate::Error> {
|
||||
buf.put_portal_name(self.portal);
|
||||
|
||||
for &format in self.result_formats {
|
||||
buf.extend(&(format as i16).to_be_bytes());
|
||||
}
|
||||
});
|
||||
buf.put_statement_name(self.statement);
|
||||
|
||||
let formats_len = i16::try_from(self.formats.len()).map_err(|_| {
|
||||
err_protocol!("too many parameter format codes ({})", self.formats.len())
|
||||
})?;
|
||||
|
||||
buf.extend(formats_len.to_be_bytes());
|
||||
|
||||
for &format in self.formats {
|
||||
buf.extend((format as i16).to_be_bytes());
|
||||
}
|
||||
|
||||
buf.extend(self.num_params.to_be_bytes());
|
||||
|
||||
buf.extend(self.params);
|
||||
|
||||
let result_formats_len = i16::try_from(self.formats.len())
|
||||
.map_err(|_| err_protocol!("too many result format codes ({})", self.formats.len()))?;
|
||||
|
||||
buf.extend(result_formats_len.to_be_bytes());
|
||||
|
||||
for &format in self.result_formats {
|
||||
buf.extend((format as i16).to_be_bytes());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::io::Encode;
|
||||
use crate::io::PgBufMutExt;
|
||||
use crate::types::Oid;
|
||||
use crate::io::{PgBufMutExt, PortalId, StatementId};
|
||||
use crate::message::{FrontendMessage, FrontendMessageFormat};
|
||||
use std::num::Saturating;
|
||||
|
||||
const CLOSE_PORTAL: u8 = b'P';
|
||||
const CLOSE_STATEMENT: u8 = b'S';
|
||||
|
@ -8,18 +8,27 @@ const CLOSE_STATEMENT: u8 = b'S';
|
|||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub enum Close {
|
||||
Statement(Oid),
|
||||
// None selects the unnamed portal
|
||||
Portal(Option<Oid>),
|
||||
Statement(StatementId),
|
||||
Portal(PortalId),
|
||||
}
|
||||
|
||||
impl Encode<'_> for Close {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
// 15 bytes for 1-digit statement/portal IDs
|
||||
buf.reserve(20);
|
||||
buf.push(b'C');
|
||||
impl FrontendMessage for Close {
|
||||
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Close;
|
||||
|
||||
buf.put_length_prefixed(|buf| match self {
|
||||
fn body_size_hint(&self) -> Saturating<usize> {
|
||||
// Either `CLOSE_PORTAL` or `CLOSE_STATEMENT`
|
||||
let mut size = Saturating(1);
|
||||
|
||||
match self {
|
||||
Close::Statement(id) => size += id.name_len(),
|
||||
Close::Portal(id) => size += id.name_len(),
|
||||
}
|
||||
|
||||
size
|
||||
}
|
||||
|
||||
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), crate::Error> {
|
||||
match self {
|
||||
Close::Statement(id) => {
|
||||
buf.push(CLOSE_STATEMENT);
|
||||
buf.put_statement_name(*id);
|
||||
|
@ -29,6 +38,8 @@ impl Encode<'_> for Close {
|
|||
buf.push(CLOSE_PORTAL);
|
||||
buf.put_portal_name(*id);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,7 @@ use memchr::memrchr;
|
|||
use sqlx_core::bytes::Bytes;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::Decode;
|
||||
use crate::message::{BackendMessage, BackendMessageFormat};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CommandComplete {
|
||||
|
@ -12,10 +12,11 @@ pub struct CommandComplete {
|
|||
tag: Bytes,
|
||||
}
|
||||
|
||||
impl Decode<'_> for CommandComplete {
|
||||
#[inline]
|
||||
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
Ok(CommandComplete { tag: buf })
|
||||
impl BackendMessage for CommandComplete {
|
||||
const FORMAT: BackendMessageFormat = BackendMessageFormat::CommandComplete;
|
||||
|
||||
fn decode_body(bytes: Bytes) -> Result<Self, Error> {
|
||||
Ok(CommandComplete { tag: bytes })
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -35,7 +36,7 @@ impl CommandComplete {
|
|||
fn test_decode_command_complete_for_insert() {
|
||||
const DATA: &[u8] = b"INSERT 0 1214\0";
|
||||
|
||||
let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap();
|
||||
let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap();
|
||||
|
||||
assert_eq!(cc.rows_affected(), 1214);
|
||||
}
|
||||
|
@ -44,7 +45,7 @@ fn test_decode_command_complete_for_insert() {
|
|||
fn test_decode_command_complete_for_begin() {
|
||||
const DATA: &[u8] = b"BEGIN\0";
|
||||
|
||||
let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap();
|
||||
let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap();
|
||||
|
||||
assert_eq!(cc.rows_affected(), 0);
|
||||
}
|
||||
|
@ -53,7 +54,7 @@ fn test_decode_command_complete_for_begin() {
|
|||
fn test_decode_command_complete_for_update() {
|
||||
const DATA: &[u8] = b"UPDATE 5\0";
|
||||
|
||||
let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap();
|
||||
let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap();
|
||||
|
||||
assert_eq!(cc.rows_affected(), 5);
|
||||
}
|
||||
|
@ -64,7 +65,7 @@ fn bench_decode_command_complete(b: &mut test::Bencher) {
|
|||
const DATA: &[u8] = b"INSERT 0 1214\0";
|
||||
|
||||
b.iter(|| {
|
||||
let _ = CommandComplete::decode(test::black_box(Bytes::from_static(DATA)));
|
||||
let _ = CommandComplete::decode_body(test::black_box(Bytes::from_static(DATA)));
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -73,7 +74,7 @@ fn bench_decode_command_complete(b: &mut test::Bencher) {
|
|||
fn bench_decode_command_complete_rows_affected(b: &mut test::Bencher) {
|
||||
const DATA: &[u8] = b"INSERT 0 1214\0";
|
||||
|
||||
let data = CommandComplete::decode(Bytes::from_static(DATA)).unwrap();
|
||||
let data = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap();
|
||||
|
||||
b.iter(|| {
|
||||
let _rows = test::black_box(&data).rows_affected();
|
||||
|
|
|
@ -1,15 +1,25 @@
|
|||
use crate::error::Result;
|
||||
use crate::io::{BufExt, BufMutExt, Decode, Encode};
|
||||
use sqlx_core::bytes::{Buf, BufMut, Bytes};
|
||||
use crate::io::BufMutExt;
|
||||
use crate::message::{
|
||||
BackendMessage, BackendMessageFormat, FrontendMessage, FrontendMessageFormat,
|
||||
};
|
||||
use sqlx_core::bytes::{Buf, Bytes};
|
||||
use sqlx_core::Error;
|
||||
use std::num::Saturating;
|
||||
use std::ops::Deref;
|
||||
|
||||
/// The same structure is sent for both `CopyInResponse` and `CopyOutResponse`
|
||||
pub struct CopyResponse {
|
||||
pub struct CopyResponseData {
|
||||
pub format: i8,
|
||||
pub num_columns: i16,
|
||||
pub format_codes: Vec<i16>,
|
||||
}
|
||||
|
||||
pub struct CopyInResponse(pub CopyResponseData);
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct CopyOutResponse(pub CopyResponseData);
|
||||
|
||||
pub struct CopyData<B>(pub B);
|
||||
|
||||
pub struct CopyFail {
|
||||
|
@ -18,14 +28,15 @@ pub struct CopyFail {
|
|||
|
||||
pub struct CopyDone;
|
||||
|
||||
impl Decode<'_> for CopyResponse {
|
||||
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self> {
|
||||
impl CopyResponseData {
|
||||
#[inline]
|
||||
fn decode(mut buf: Bytes) -> Result<Self> {
|
||||
let format = buf.get_i8();
|
||||
let num_columns = buf.get_i16();
|
||||
|
||||
let format_codes = (0..num_columns).map(|_| buf.get_i16()).collect();
|
||||
|
||||
Ok(CopyResponse {
|
||||
Ok(CopyResponseData {
|
||||
format,
|
||||
num_columns,
|
||||
format_codes,
|
||||
|
@ -33,40 +44,65 @@ impl Decode<'_> for CopyResponse {
|
|||
}
|
||||
}
|
||||
|
||||
impl Decode<'_> for CopyData<Bytes> {
|
||||
fn decode_with(buf: Bytes, _: ()) -> Result<Self> {
|
||||
// well.. that was easy
|
||||
Ok(CopyData(buf))
|
||||
impl BackendMessage for CopyInResponse {
|
||||
const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyInResponse;
|
||||
|
||||
#[inline(always)]
|
||||
fn decode_body(buf: Bytes) -> std::result::Result<Self, Error> {
|
||||
Ok(Self(CopyResponseData::decode(buf)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Deref<Target = [u8]>> Encode<'_> for CopyData<B> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _context: ()) {
|
||||
buf.push(b'd');
|
||||
buf.put_u32(self.0.len() as u32 + 4);
|
||||
impl BackendMessage for CopyOutResponse {
|
||||
const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyOutResponse;
|
||||
|
||||
#[inline(always)]
|
||||
fn decode_body(buf: Bytes) -> std::result::Result<Self, Error> {
|
||||
Ok(Self(CopyResponseData::decode(buf)?))
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendMessage for CopyData<Bytes> {
|
||||
const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyData;
|
||||
|
||||
#[inline(always)]
|
||||
fn decode_body(buf: Bytes) -> std::result::Result<Self, Error> {
|
||||
Ok(Self(buf))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Deref<Target = [u8]>> FrontendMessage for CopyData<B> {
|
||||
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyData;
|
||||
|
||||
#[inline(always)]
|
||||
fn body_size_hint(&self) -> Saturating<usize> {
|
||||
Saturating(self.0.len())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
|
||||
buf.extend_from_slice(&self.0);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<'_> for CopyFail {
|
||||
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self> {
|
||||
Ok(CopyFail {
|
||||
message: buf.get_str_nul()?,
|
||||
})
|
||||
impl FrontendMessage for CopyFail {
|
||||
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyFail;
|
||||
|
||||
#[inline(always)]
|
||||
fn body_size_hint(&self) -> Saturating<usize> {
|
||||
Saturating(self.message.len())
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<'_> for CopyFail {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
let len = 4 + self.message.len() + 1;
|
||||
|
||||
buf.push(b'f'); // to pay respects
|
||||
buf.put_u32(len as u32);
|
||||
#[inline(always)]
|
||||
fn encode_body(&self, buf: &mut Vec<u8>) -> std::result::Result<(), Error> {
|
||||
buf.put_str_nul(&self.message);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl CopyFail {
|
||||
#[inline(always)]
|
||||
pub fn new(msg: impl Into<String>) -> CopyFail {
|
||||
CopyFail {
|
||||
message: msg.into(),
|
||||
|
@ -74,23 +110,32 @@ impl CopyFail {
|
|||
}
|
||||
}
|
||||
|
||||
impl Decode<'_> for CopyDone {
|
||||
fn decode_with(buf: Bytes, _: ()) -> Result<Self> {
|
||||
if buf.is_empty() {
|
||||
Ok(CopyDone)
|
||||
} else {
|
||||
Err(err_protocol!(
|
||||
"expected no data for CopyDone, got: {:?}",
|
||||
buf
|
||||
))
|
||||
}
|
||||
impl FrontendMessage for CopyDone {
|
||||
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyDone;
|
||||
#[inline(always)]
|
||||
fn body_size_hint(&self) -> Saturating<usize> {
|
||||
Saturating(0)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn encode_body(&self, _buf: &mut Vec<u8>) -> std::result::Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<'_> for CopyDone {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
buf.reserve(4);
|
||||
buf.push(b'c');
|
||||
buf.put_u32(4);
|
||||
impl BackendMessage for CopyDone {
|
||||
const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyDone;
|
||||
|
||||
#[inline(always)]
|
||||
fn decode_body(bytes: Bytes) -> std::result::Result<Self, Error> {
|
||||
if !bytes.is_empty() {
|
||||
// Not fatal but may indicate a protocol change
|
||||
tracing::debug!(
|
||||
"Postgres backend returned non-empty message for CopyDone: \"{}\"",
|
||||
bytes.escape_ascii()
|
||||
)
|
||||
}
|
||||
|
||||
Ok(CopyDone)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
use std::ops::Range;
|
||||
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use sqlx_core::bytes::Bytes;
|
||||
use std::ops::Range;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::Decode;
|
||||
use crate::message::{BackendMessage, BackendMessageFormat};
|
||||
|
||||
/// A row of data from the database.
|
||||
#[derive(Debug)]
|
||||
|
@ -26,25 +25,55 @@ impl DataRow {
|
|||
}
|
||||
}
|
||||
|
||||
impl Decode<'_> for DataRow {
|
||||
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
impl BackendMessage for DataRow {
|
||||
const FORMAT: BackendMessageFormat = BackendMessageFormat::DataRow;
|
||||
|
||||
fn decode_body(buf: Bytes) -> Result<Self, Error> {
|
||||
if buf.len() < 2 {
|
||||
return Err(err_protocol!(
|
||||
"expected at least 2 bytes, got {}",
|
||||
buf.len()
|
||||
));
|
||||
}
|
||||
|
||||
let cnt = BigEndian::read_u16(&buf) as usize;
|
||||
|
||||
let mut values = Vec::with_capacity(cnt);
|
||||
let mut offset = 2;
|
||||
let mut offset: u32 = 2;
|
||||
|
||||
for _ in 0..cnt {
|
||||
let value_start = offset
|
||||
.checked_add(4)
|
||||
.ok_or_else(|| err_protocol!("next value start out of range (offset: {offset})"))?;
|
||||
|
||||
// widen both to a larger type for a safe comparison
|
||||
if (buf.len() as u64) < (value_start as u64) {
|
||||
return Err(err_protocol!(
|
||||
"expected 4 bytes at offset {offset}, got {}",
|
||||
(value_start as u64) - (buf.len() as u64)
|
||||
));
|
||||
}
|
||||
|
||||
// Length of the column value, in bytes (this count does not include itself).
|
||||
// Can be zero. As a special case, -1 indicates a NULL column value.
|
||||
// No value bytes follow in the NULL case.
|
||||
//
|
||||
// we know `offset` is within range of `buf.len()` from the above check
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
let length = BigEndian::read_i32(&buf[(offset as usize)..]);
|
||||
offset += 4;
|
||||
|
||||
if length < 0 {
|
||||
values.push(None);
|
||||
if let Ok(length) = u32::try_from(length) {
|
||||
let value_end = value_start.checked_add(length).ok_or_else(|| {
|
||||
err_protocol!("value_start + length out of range ({offset} + {length})")
|
||||
})?;
|
||||
|
||||
values.push(Some(value_start..value_end));
|
||||
offset = value_end;
|
||||
} else {
|
||||
values.push(Some(offset..(offset + length as u32)));
|
||||
offset += length as u32;
|
||||
// Negative values signify NULL
|
||||
values.push(None);
|
||||
// `value_start` is actually the next value now.
|
||||
offset = value_start;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,9 +86,22 @@ impl Decode<'_> for DataRow {
|
|||
|
||||
#[test]
|
||||
fn test_decode_data_row() {
|
||||
const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P";
|
||||
const DATA: &[u8] = b"\
|
||||
\x00\x08\
|
||||
\xff\xff\xff\xff\
|
||||
\x00\x00\x00\x04\
|
||||
\x00\x00\x00\n\
|
||||
\xff\xff\xff\xff\
|
||||
\x00\x00\x00\x04\
|
||||
\x00\x00\x00\x14\
|
||||
\xff\xff\xff\xff\
|
||||
\x00\x00\x00\x04\
|
||||
\x00\x00\x00(\
|
||||
\xff\xff\xff\xff\
|
||||
\x00\x00\x00\x04\
|
||||
\x00\x00\x00P";
|
||||
|
||||
let row = DataRow::decode(DATA.into()).unwrap();
|
||||
let row = DataRow::decode_body(DATA.into()).unwrap();
|
||||
|
||||
assert_eq!(row.values.len(), 8);
|
||||
|
||||
|
@ -78,7 +120,7 @@ fn test_decode_data_row() {
|
|||
fn bench_data_row_get(b: &mut test::Bencher) {
|
||||
const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P";
|
||||
|
||||
let row = DataRow::decode(test::black_box(Bytes::from_static(DATA))).unwrap();
|
||||
let row = DataRow::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap();
|
||||
|
||||
b.iter(|| {
|
||||
let _value = test::black_box(&row).get(3);
|
||||
|
@ -91,6 +133,6 @@ fn bench_decode_data_row(b: &mut test::Bencher) {
|
|||
const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P";
|
||||
|
||||
b.iter(|| {
|
||||
let _ = DataRow::decode(test::black_box(Bytes::from_static(DATA)));
|
||||
let _ = DataRow::decode_body(test::black_box(Bytes::from_static(DATA)));
|
||||
});
|
||||
}
|
||||
|
|
|
@ -1,127 +1,103 @@
|
|||
use crate::io::Encode;
|
||||
use crate::io::PgBufMutExt;
|
||||
use crate::types::Oid;
|
||||
use crate::io::{PgBufMutExt, PortalId, StatementId};
|
||||
use crate::message::{FrontendMessage, FrontendMessageFormat};
|
||||
use sqlx_core::Error;
|
||||
use std::num::Saturating;
|
||||
|
||||
const DESCRIBE_PORTAL: u8 = b'P';
|
||||
const DESCRIBE_STATEMENT: u8 = b'S';
|
||||
|
||||
// [Describe] will emit both a [RowDescription] and a [ParameterDescription] message
|
||||
|
||||
/// Note: will emit both a RowDescription and a ParameterDescription message
|
||||
#[derive(Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub enum Describe {
|
||||
UnnamedStatement,
|
||||
Statement(Oid),
|
||||
|
||||
UnnamedPortal,
|
||||
Portal(Oid),
|
||||
Statement(StatementId),
|
||||
Portal(PortalId),
|
||||
}
|
||||
|
||||
impl Encode<'_> for Describe {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
// 15 bytes for 1-digit statement/portal IDs
|
||||
buf.reserve(20);
|
||||
buf.push(b'D');
|
||||
impl FrontendMessage for Describe {
|
||||
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Describe;
|
||||
|
||||
buf.put_length_prefixed(|buf| {
|
||||
match self {
|
||||
// #[likely]
|
||||
Describe::Statement(id) => {
|
||||
buf.push(DESCRIBE_STATEMENT);
|
||||
buf.put_statement_name(*id);
|
||||
}
|
||||
fn body_size_hint(&self) -> Saturating<usize> {
|
||||
// Either `DESCRIBE_PORTAL` or `DESCRIBE_STATEMENT`
|
||||
let mut size = Saturating(1);
|
||||
|
||||
Describe::UnnamedPortal => {
|
||||
buf.push(DESCRIBE_PORTAL);
|
||||
buf.push(0);
|
||||
}
|
||||
match self {
|
||||
Describe::Statement(id) => size += id.name_len(),
|
||||
Describe::Portal(id) => size += id.name_len(),
|
||||
}
|
||||
|
||||
Describe::UnnamedStatement => {
|
||||
buf.push(DESCRIBE_STATEMENT);
|
||||
buf.push(0);
|
||||
}
|
||||
size
|
||||
}
|
||||
|
||||
Describe::Portal(id) => {
|
||||
buf.push(DESCRIBE_PORTAL);
|
||||
buf.put_portal_name(Some(*id));
|
||||
}
|
||||
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
|
||||
match self {
|
||||
// #[likely]
|
||||
Describe::Statement(id) => {
|
||||
buf.push(DESCRIBE_STATEMENT);
|
||||
buf.put_statement_name(*id);
|
||||
}
|
||||
});
|
||||
|
||||
Describe::Portal(id) => {
|
||||
buf.push(DESCRIBE_PORTAL);
|
||||
buf.put_portal_name(*id);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_describe_portal() {
|
||||
const EXPECTED: &[u8] = b"D\0\0\0\x0EPsqlx_p_5\0";
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::message::FrontendMessage;
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let m = Describe::Portal(Oid(5));
|
||||
use super::{Describe, PortalId, StatementId};
|
||||
|
||||
m.encode(&mut buf);
|
||||
#[test]
|
||||
fn test_encode_describe_portal() {
|
||||
const EXPECTED: &[u8] = b"D\0\0\0\x17Psqlx_p_1234567890\0";
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_describe_unnamed_portal() {
|
||||
const EXPECTED: &[u8] = b"D\0\0\0\x06P\0";
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let m = Describe::UnnamedPortal;
|
||||
|
||||
m.encode(&mut buf);
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_describe_statement() {
|
||||
const EXPECTED: &[u8] = b"D\0\0\0\x0ESsqlx_s_5\0";
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let m = Describe::Statement(Oid(5));
|
||||
|
||||
m.encode(&mut buf);
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_describe_unnamed_statement() {
|
||||
const EXPECTED: &[u8] = b"D\0\0\0\x06S\0";
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let m = Describe::UnnamedStatement;
|
||||
|
||||
m.encode(&mut buf);
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
|
||||
#[cfg(all(test, not(debug_assertions)))]
|
||||
#[bench]
|
||||
fn bench_encode_describe_portal(b: &mut test::Bencher) {
|
||||
use test::black_box;
|
||||
|
||||
let mut buf = Vec::with_capacity(128);
|
||||
|
||||
b.iter(|| {
|
||||
buf.clear();
|
||||
|
||||
black_box(Describe::Portal(5)).encode(&mut buf);
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(all(test, not(debug_assertions)))]
|
||||
#[bench]
|
||||
fn bench_encode_describe_unnamed_statement(b: &mut test::Bencher) {
|
||||
use test::black_box;
|
||||
|
||||
let mut buf = Vec::with_capacity(128);
|
||||
|
||||
b.iter(|| {
|
||||
buf.clear();
|
||||
|
||||
black_box(Describe::UnnamedStatement).encode(&mut buf);
|
||||
});
|
||||
let mut buf = Vec::new();
|
||||
let m = Describe::Portal(PortalId::TEST_VAL);
|
||||
|
||||
m.encode_msg(&mut buf).unwrap();
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_describe_unnamed_portal() {
|
||||
const EXPECTED: &[u8] = b"D\0\0\0\x06P\0";
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let m = Describe::Portal(PortalId::UNNAMED);
|
||||
|
||||
m.encode_msg(&mut buf).unwrap();
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_describe_statement() {
|
||||
const EXPECTED: &[u8] = b"D\0\0\0\x17Ssqlx_s_1234567890\0";
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let m = Describe::Statement(StatementId::TEST_VAL);
|
||||
|
||||
m.encode_msg(&mut buf).unwrap();
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_describe_unnamed_statement() {
|
||||
const EXPECTED: &[u8] = b"D\0\0\0\x06S\0";
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let m = Describe::Statement(StatementId::UNNAMED);
|
||||
|
||||
m.encode_msg(&mut buf).unwrap();
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,39 +1,73 @@
|
|||
use crate::io::Encode;
|
||||
use crate::io::PgBufMutExt;
|
||||
use crate::types::Oid;
|
||||
use std::num::Saturating;
|
||||
|
||||
use sqlx_core::Error;
|
||||
|
||||
use crate::io::{PgBufMutExt, PortalId};
|
||||
use crate::message::{FrontendMessage, FrontendMessageFormat};
|
||||
|
||||
pub struct Execute {
|
||||
/// The id of the portal to execute (`None` selects the unnamed portal).
|
||||
pub portal: Option<Oid>,
|
||||
/// The id of the portal to execute.
|
||||
pub portal: PortalId,
|
||||
|
||||
/// Maximum number of rows to return, if portal contains a query
|
||||
/// that returns rows (ignored otherwise). Zero denotes “no limit”.
|
||||
pub limit: u32,
|
||||
}
|
||||
|
||||
impl Encode<'_> for Execute {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
buf.reserve(20);
|
||||
buf.push(b'E');
|
||||
impl FrontendMessage for Execute {
|
||||
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Execute;
|
||||
|
||||
buf.put_length_prefixed(|buf| {
|
||||
buf.put_portal_name(self.portal);
|
||||
buf.extend(&self.limit.to_be_bytes());
|
||||
});
|
||||
fn body_size_hint(&self) -> Saturating<usize> {
|
||||
let mut size = Saturating(0);
|
||||
|
||||
size += self.portal.name_len();
|
||||
size += 2; // limit
|
||||
|
||||
size
|
||||
}
|
||||
|
||||
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
|
||||
buf.put_portal_name(self.portal);
|
||||
buf.extend(&self.limit.to_be_bytes());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_execute() {
|
||||
const EXPECTED: &[u8] = b"E\0\0\0\x11sqlx_p_5\0\0\0\0\x02";
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::io::PortalId;
|
||||
use crate::message::FrontendMessage;
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let m = Execute {
|
||||
portal: Some(Oid(5)),
|
||||
limit: 2,
|
||||
};
|
||||
use super::Execute;
|
||||
|
||||
m.encode(&mut buf);
|
||||
#[test]
|
||||
fn test_encode_execute_named_portal() {
|
||||
const EXPECTED: &[u8] = b"E\0\0\0\x1Asqlx_p_1234567890\0\0\0\0\x02";
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
let mut buf = Vec::new();
|
||||
let m = Execute {
|
||||
portal: PortalId::TEST_VAL,
|
||||
limit: 2,
|
||||
};
|
||||
|
||||
m.encode_msg(&mut buf).unwrap();
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_execute_unnamed_portal() {
|
||||
const EXPECTED: &[u8] = b"E\0\0\0\x09\0\x49\x96\x02\xD2";
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let m = Execute {
|
||||
portal: PortalId::UNNAMED,
|
||||
limit: 1234567890,
|
||||
};
|
||||
|
||||
m.encode_msg(&mut buf).unwrap();
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,17 +1,25 @@
|
|||
use crate::io::Encode;
|
||||
|
||||
// The Flush message does not cause any specific output to be generated,
|
||||
// but forces the backend to deliver any data pending in its output buffers.
|
||||
|
||||
// A Flush must be sent after any extended-query command except Sync, if the
|
||||
// frontend wishes to examine the results of that command before issuing more commands.
|
||||
use crate::message::{FrontendMessage, FrontendMessageFormat};
|
||||
use sqlx_core::Error;
|
||||
use std::num::Saturating;
|
||||
|
||||
/// The Flush message does not cause any specific output to be generated,
|
||||
/// but forces the backend to deliver any data pending in its output buffers.
|
||||
///
|
||||
/// A Flush must be sent after any extended-query command except Sync, if the
|
||||
/// frontend wishes to examine the results of that command before issuing more commands.
|
||||
#[derive(Debug)]
|
||||
pub struct Flush;
|
||||
|
||||
impl Encode<'_> for Flush {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
buf.push(b'H');
|
||||
buf.extend(&4_i32.to_be_bytes());
|
||||
impl FrontendMessage for Flush {
|
||||
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Flush;
|
||||
|
||||
#[inline(always)]
|
||||
fn body_size_hint(&self) -> Saturating<usize> {
|
||||
Saturating(0)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn encode_body(&self, _buf: &mut Vec<u8>) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use sqlx_core::bytes::Bytes;
|
||||
use std::num::Saturating;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::Decode;
|
||||
use crate::io::PgBufMutExt;
|
||||
|
||||
mod authentication;
|
||||
mod backend_key_data;
|
||||
|
@ -17,6 +18,7 @@ mod notification;
|
|||
mod parameter_description;
|
||||
mod parameter_status;
|
||||
mod parse;
|
||||
mod parse_complete;
|
||||
mod password;
|
||||
mod query;
|
||||
mod ready_for_query;
|
||||
|
@ -33,7 +35,7 @@ pub use backend_key_data::BackendKeyData;
|
|||
pub use bind::Bind;
|
||||
pub use close::Close;
|
||||
pub use command_complete::CommandComplete;
|
||||
pub use copy::{CopyData, CopyDone, CopyFail, CopyResponse};
|
||||
pub use copy::{CopyData, CopyDone, CopyFail, CopyInResponse, CopyOutResponse, CopyResponseData};
|
||||
pub use data_row::DataRow;
|
||||
pub use describe::Describe;
|
||||
pub use execute::Execute;
|
||||
|
@ -43,20 +45,51 @@ pub use notification::Notification;
|
|||
pub use parameter_description::ParameterDescription;
|
||||
pub use parameter_status::ParameterStatus;
|
||||
pub use parse::Parse;
|
||||
pub use parse_complete::ParseComplete;
|
||||
pub use password::Password;
|
||||
pub use query::Query;
|
||||
pub use ready_for_query::{ReadyForQuery, TransactionStatus};
|
||||
pub use response::{Notice, PgSeverity};
|
||||
pub use row_description::RowDescription;
|
||||
pub use sasl::{SaslInitialResponse, SaslResponse};
|
||||
use sqlx_core::io::ProtocolEncode;
|
||||
pub use ssl_request::SslRequest;
|
||||
pub use startup::Startup;
|
||||
pub use sync::Sync;
|
||||
pub use terminate::Terminate;
|
||||
|
||||
// Note: we can't use the same enum for both frontend and backend message formats
|
||||
// because there are duplicated format codes between them.
|
||||
//
|
||||
// For example, `Close` (frontend) and `CommandComplete` (backend) both use format code `C`.
|
||||
// <https://www.postgresql.org/docs/current/protocol-message-formats.html>
|
||||
#[derive(Debug, PartialOrd, PartialEq)]
|
||||
#[repr(u8)]
|
||||
pub enum MessageFormat {
|
||||
pub enum FrontendMessageFormat {
|
||||
Bind = b'B',
|
||||
Close = b'C',
|
||||
CopyData = b'd',
|
||||
CopyDone = b'c',
|
||||
CopyFail = b'f',
|
||||
Describe = b'D',
|
||||
Execute = b'E',
|
||||
Flush = b'H',
|
||||
Parse = b'P',
|
||||
/// This message format is polymorphic. It's used for:
|
||||
///
|
||||
/// * Plain password responses
|
||||
/// * MD5 password responses
|
||||
/// * SASL responses
|
||||
/// * GSSAPI/SSPI responses
|
||||
PasswordPolymorphic = b'p',
|
||||
Query = b'Q',
|
||||
Sync = b'S',
|
||||
Terminate = b'X',
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialOrd, PartialEq)]
|
||||
#[repr(u8)]
|
||||
pub enum BackendMessageFormat {
|
||||
Authentication,
|
||||
BackendKeyData,
|
||||
BindComplete,
|
||||
|
@ -81,49 +114,116 @@ pub enum MessageFormat {
|
|||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Message {
|
||||
pub format: MessageFormat,
|
||||
pub struct ReceivedMessage {
|
||||
pub format: BackendMessageFormat,
|
||||
pub contents: Bytes,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
impl ReceivedMessage {
|
||||
#[inline]
|
||||
pub fn decode<'de, T>(self) -> Result<T, Error>
|
||||
pub fn decode<T>(self) -> Result<T, Error>
|
||||
where
|
||||
T: Decode<'de>,
|
||||
T: BackendMessage,
|
||||
{
|
||||
T::decode(self.contents)
|
||||
if T::FORMAT != self.format {
|
||||
return Err(err_protocol!(
|
||||
"Postgres protocol error: expected {:?}, got {:?}",
|
||||
T::FORMAT,
|
||||
self.format
|
||||
));
|
||||
}
|
||||
|
||||
T::decode_body(self.contents).map_err(|e| match e {
|
||||
Error::Protocol(s) => {
|
||||
err_protocol!("Postgres protocol error (reading {:?}): {s}", self.format)
|
||||
}
|
||||
other => other,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl MessageFormat {
|
||||
impl BackendMessageFormat {
|
||||
pub fn try_from_u8(v: u8) -> Result<Self, Error> {
|
||||
// https://www.postgresql.org/docs/current/protocol-message-formats.html
|
||||
|
||||
Ok(match v {
|
||||
b'1' => MessageFormat::ParseComplete,
|
||||
b'2' => MessageFormat::BindComplete,
|
||||
b'3' => MessageFormat::CloseComplete,
|
||||
b'C' => MessageFormat::CommandComplete,
|
||||
b'd' => MessageFormat::CopyData,
|
||||
b'c' => MessageFormat::CopyDone,
|
||||
b'G' => MessageFormat::CopyInResponse,
|
||||
b'H' => MessageFormat::CopyOutResponse,
|
||||
b'D' => MessageFormat::DataRow,
|
||||
b'E' => MessageFormat::ErrorResponse,
|
||||
b'I' => MessageFormat::EmptyQueryResponse,
|
||||
b'A' => MessageFormat::NotificationResponse,
|
||||
b'K' => MessageFormat::BackendKeyData,
|
||||
b'N' => MessageFormat::NoticeResponse,
|
||||
b'R' => MessageFormat::Authentication,
|
||||
b'S' => MessageFormat::ParameterStatus,
|
||||
b'T' => MessageFormat::RowDescription,
|
||||
b'Z' => MessageFormat::ReadyForQuery,
|
||||
b'n' => MessageFormat::NoData,
|
||||
b's' => MessageFormat::PortalSuspended,
|
||||
b't' => MessageFormat::ParameterDescription,
|
||||
b'1' => BackendMessageFormat::ParseComplete,
|
||||
b'2' => BackendMessageFormat::BindComplete,
|
||||
b'3' => BackendMessageFormat::CloseComplete,
|
||||
b'C' => BackendMessageFormat::CommandComplete,
|
||||
b'd' => BackendMessageFormat::CopyData,
|
||||
b'c' => BackendMessageFormat::CopyDone,
|
||||
b'G' => BackendMessageFormat::CopyInResponse,
|
||||
b'H' => BackendMessageFormat::CopyOutResponse,
|
||||
b'D' => BackendMessageFormat::DataRow,
|
||||
b'E' => BackendMessageFormat::ErrorResponse,
|
||||
b'I' => BackendMessageFormat::EmptyQueryResponse,
|
||||
b'A' => BackendMessageFormat::NotificationResponse,
|
||||
b'K' => BackendMessageFormat::BackendKeyData,
|
||||
b'N' => BackendMessageFormat::NoticeResponse,
|
||||
b'R' => BackendMessageFormat::Authentication,
|
||||
b'S' => BackendMessageFormat::ParameterStatus,
|
||||
b'T' => BackendMessageFormat::RowDescription,
|
||||
b'Z' => BackendMessageFormat::ReadyForQuery,
|
||||
b'n' => BackendMessageFormat::NoData,
|
||||
b's' => BackendMessageFormat::PortalSuspended,
|
||||
b't' => BackendMessageFormat::ParameterDescription,
|
||||
|
||||
_ => return Err(err_protocol!("unknown message type: {:?}", v as char)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait FrontendMessage: Sized {
|
||||
/// The format prefix of this message.
|
||||
const FORMAT: FrontendMessageFormat;
|
||||
|
||||
/// Return the amount of space, in bytes, to reserve in the buffer passed to [`Self::encode_body()`].
|
||||
fn body_size_hint(&self) -> Saturating<usize>;
|
||||
|
||||
/// Encode this type as a Frontend message in the Postgres protocol.
|
||||
///
|
||||
/// The implementation should *not* include `Self::FORMAT` or the length prefix.
|
||||
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error>;
|
||||
|
||||
#[inline(always)]
|
||||
#[cfg_attr(not(test), allow(dead_code))]
|
||||
fn encode_msg(self, buf: &mut Vec<u8>) -> Result<(), Error> {
|
||||
EncodeMessage(self).encode(buf)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) trait BackendMessage: Sized {
|
||||
/// The expected message format.
|
||||
///
|
||||
/// <https://www.postgresql.org/docs/current/protocol-message-formats.html>
|
||||
const FORMAT: BackendMessageFormat;
|
||||
|
||||
/// Decode this type from a Backend message in the Postgres protocol.
|
||||
///
|
||||
/// The format code and length prefix have already been read and are not at the start of `bytes`.
|
||||
fn decode_body(buf: Bytes) -> Result<Self, Error>;
|
||||
}
|
||||
|
||||
pub struct EncodeMessage<F>(pub F);
|
||||
|
||||
impl<F: FrontendMessage> ProtocolEncode<'_, ()> for EncodeMessage<F> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _context: ()) -> Result<(), Error> {
|
||||
let mut size_hint = self.0.body_size_hint();
|
||||
// plus format code and length prefix
|
||||
size_hint += 5;
|
||||
|
||||
// don't panic if `size_hint` is ridiculous
|
||||
buf.try_reserve(size_hint.0).map_err(|e| {
|
||||
err_protocol!(
|
||||
"Postgres protocol: error allocating {} bytes for encoding message {:?}: {e}",
|
||||
size_hint.0,
|
||||
F::FORMAT,
|
||||
)
|
||||
})?;
|
||||
|
||||
buf.push(F::FORMAT as u8);
|
||||
|
||||
buf.put_length_prefixed(|buf| self.0.encode_body(buf))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use sqlx_core::bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::{BufExt, Decode};
|
||||
use crate::io::BufExt;
|
||||
use crate::message::{BackendMessage, BackendMessageFormat};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Notification {
|
||||
|
@ -10,9 +11,10 @@ pub struct Notification {
|
|||
pub(crate) payload: Bytes,
|
||||
}
|
||||
|
||||
impl Decode<'_> for Notification {
|
||||
#[inline]
|
||||
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
impl BackendMessage for Notification {
|
||||
const FORMAT: BackendMessageFormat = BackendMessageFormat::NotificationResponse;
|
||||
|
||||
fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
|
||||
let process_id = buf.get_u32();
|
||||
let channel = buf.get_bytes_nul()?;
|
||||
let payload = buf.get_bytes_nul()?;
|
||||
|
@ -29,7 +31,7 @@ impl Decode<'_> for Notification {
|
|||
fn test_decode_notification_response() {
|
||||
const NOTIFICATION_RESPONSE: &[u8] = b"\x34\x20\x10\x02TEST-CHANNEL\0THIS IS A TEST\0";
|
||||
|
||||
let message = Notification::decode(Bytes::from(NOTIFICATION_RESPONSE)).unwrap();
|
||||
let message = Notification::decode_body(Bytes::from(NOTIFICATION_RESPONSE)).unwrap();
|
||||
|
||||
assert_eq!(message.process_id, 0x34201002);
|
||||
assert_eq!(&*message.channel, &b"TEST-CHANNEL"[..]);
|
||||
|
|
|
@ -2,7 +2,7 @@ use smallvec::SmallVec;
|
|||
use sqlx_core::bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::Decode;
|
||||
use crate::message::{BackendMessage, BackendMessageFormat};
|
||||
use crate::types::Oid;
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -10,8 +10,10 @@ pub struct ParameterDescription {
|
|||
pub types: SmallVec<[Oid; 6]>,
|
||||
}
|
||||
|
||||
impl Decode<'_> for ParameterDescription {
|
||||
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
impl BackendMessage for ParameterDescription {
|
||||
const FORMAT: BackendMessageFormat = BackendMessageFormat::ParameterDescription;
|
||||
|
||||
fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
|
||||
let cnt = buf.get_u16();
|
||||
let mut types = SmallVec::with_capacity(cnt as usize);
|
||||
|
||||
|
@ -27,7 +29,7 @@ impl Decode<'_> for ParameterDescription {
|
|||
fn test_decode_parameter_description() {
|
||||
const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00";
|
||||
|
||||
let m = ParameterDescription::decode(DATA.into()).unwrap();
|
||||
let m = ParameterDescription::decode_body(DATA.into()).unwrap();
|
||||
|
||||
assert_eq!(m.types.len(), 2);
|
||||
assert_eq!(m.types[0], Oid(0x0000_0000));
|
||||
|
@ -38,7 +40,7 @@ fn test_decode_parameter_description() {
|
|||
fn test_decode_empty_parameter_description() {
|
||||
const DATA: &[u8] = b"\x00\x00";
|
||||
|
||||
let m = ParameterDescription::decode(DATA.into()).unwrap();
|
||||
let m = ParameterDescription::decode_body(DATA.into()).unwrap();
|
||||
|
||||
assert!(m.types.is_empty());
|
||||
}
|
||||
|
@ -49,6 +51,6 @@ fn bench_decode_parameter_description(b: &mut test::Bencher) {
|
|||
const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00";
|
||||
|
||||
b.iter(|| {
|
||||
ParameterDescription::decode(test::black_box(Bytes::from_static(DATA))).unwrap();
|
||||
ParameterDescription::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap();
|
||||
});
|
||||
}
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use sqlx_core::bytes::Bytes;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::{BufExt, Decode};
|
||||
use crate::io::BufExt;
|
||||
use crate::message::{BackendMessage, BackendMessageFormat};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ParameterStatus {
|
||||
|
@ -9,8 +10,10 @@ pub struct ParameterStatus {
|
|||
pub value: String,
|
||||
}
|
||||
|
||||
impl Decode<'_> for ParameterStatus {
|
||||
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
impl BackendMessage for ParameterStatus {
|
||||
const FORMAT: BackendMessageFormat = BackendMessageFormat::ParameterStatus;
|
||||
|
||||
fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
|
||||
let name = buf.get_str_nul()?;
|
||||
let value = buf.get_str_nul()?;
|
||||
|
||||
|
@ -22,7 +25,7 @@ impl Decode<'_> for ParameterStatus {
|
|||
fn test_decode_parameter_status() {
|
||||
const DATA: &[u8] = b"client_encoding\x00UTF8\x00";
|
||||
|
||||
let m = ParameterStatus::decode(DATA.into()).unwrap();
|
||||
let m = ParameterStatus::decode_body(DATA.into()).unwrap();
|
||||
|
||||
assert_eq!(&m.name, "client_encoding");
|
||||
assert_eq!(&m.value, "UTF8")
|
||||
|
@ -32,7 +35,7 @@ fn test_decode_parameter_status() {
|
|||
fn test_decode_empty_parameter_status() {
|
||||
const DATA: &[u8] = b"\x00\x00";
|
||||
|
||||
let m = ParameterStatus::decode(DATA.into()).unwrap();
|
||||
let m = ParameterStatus::decode_body(DATA.into()).unwrap();
|
||||
|
||||
assert!(m.name.is_empty());
|
||||
assert!(m.value.is_empty());
|
||||
|
@ -44,7 +47,7 @@ fn bench_decode_parameter_status(b: &mut test::Bencher) {
|
|||
const DATA: &[u8] = b"client_encoding\x00UTF8\x00";
|
||||
|
||||
b.iter(|| {
|
||||
ParameterStatus::decode(test::black_box(Bytes::from_static(DATA))).unwrap();
|
||||
ParameterStatus::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap();
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -52,7 +55,7 @@ fn bench_decode_parameter_status(b: &mut test::Bencher) {
|
|||
fn test_decode_parameter_status_response() {
|
||||
const PARAMETER_STATUS_RESPONSE: &[u8] = b"crdb_version\0CockroachDB CCL v21.1.0 (x86_64-unknown-linux-gnu, built 2021/05/17 13:49:40, go1.15.11)\0";
|
||||
|
||||
let message = ParameterStatus::decode(Bytes::from(PARAMETER_STATUS_RESPONSE)).unwrap();
|
||||
let message = ParameterStatus::decode_body(Bytes::from(PARAMETER_STATUS_RESPONSE)).unwrap();
|
||||
|
||||
assert_eq!(message.name, "crdb_version");
|
||||
assert_eq!(
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
use crate::io::PgBufMutExt;
|
||||
use crate::io::{BufMutExt, Encode};
|
||||
use crate::io::BufMutExt;
|
||||
use crate::io::{PgBufMutExt, StatementId};
|
||||
use crate::message::{FrontendMessage, FrontendMessageFormat};
|
||||
use crate::types::Oid;
|
||||
use sqlx_core::Error;
|
||||
use std::num::Saturating;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Parse<'a> {
|
||||
/// The ID of the destination prepared statement.
|
||||
pub statement: Oid,
|
||||
pub statement: StatementId,
|
||||
|
||||
/// The query string to be parsed.
|
||||
pub query: &'a str,
|
||||
|
@ -16,39 +19,59 @@ pub struct Parse<'a> {
|
|||
pub param_types: &'a [Oid],
|
||||
}
|
||||
|
||||
impl Encode<'_> for Parse<'_> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
buf.push(b'P');
|
||||
impl FrontendMessage for Parse<'_> {
|
||||
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Parse;
|
||||
|
||||
buf.put_length_prefixed(|buf| {
|
||||
buf.put_statement_name(self.statement);
|
||||
fn body_size_hint(&self) -> Saturating<usize> {
|
||||
let mut size = Saturating(0);
|
||||
|
||||
buf.put_str_nul(self.query);
|
||||
size += self.statement.name_len();
|
||||
|
||||
// TODO: Return an error here instead
|
||||
assert!(self.param_types.len() <= (u16::MAX as usize));
|
||||
size += self.query.len();
|
||||
size += 1; // NUL terminator
|
||||
|
||||
buf.extend(&(self.param_types.len() as i16).to_be_bytes());
|
||||
size += 2; // param_types_len
|
||||
|
||||
for &oid in self.param_types {
|
||||
buf.extend(&oid.0.to_be_bytes());
|
||||
}
|
||||
})
|
||||
// `param_types`
|
||||
size += self.param_types.len().saturating_mul(4);
|
||||
|
||||
size
|
||||
}
|
||||
|
||||
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
|
||||
buf.put_statement_name(self.statement);
|
||||
|
||||
buf.put_str_nul(self.query);
|
||||
|
||||
let param_types_len = i16::try_from(self.param_types.len()).map_err(|_| {
|
||||
err_protocol!(
|
||||
"param_types.len() too large for binary protocol: {}",
|
||||
self.param_types.len()
|
||||
)
|
||||
})?;
|
||||
|
||||
buf.extend(param_types_len.to_be_bytes());
|
||||
|
||||
for &oid in self.param_types {
|
||||
buf.extend(oid.0.to_be_bytes());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_parse() {
|
||||
const EXPECTED: &[u8] = b"P\0\0\0\x1dsqlx_s_1\0SELECT $1\0\0\x01\0\0\0\x19";
|
||||
const EXPECTED: &[u8] = b"P\0\0\0\x26sqlx_s_1234567890\0SELECT $1\0\0\x01\0\0\0\x19";
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let m = Parse {
|
||||
statement: Oid(1),
|
||||
statement: StatementId::TEST_VAL,
|
||||
query: "SELECT $1",
|
||||
param_types: &[Oid(25)],
|
||||
};
|
||||
|
||||
m.encode(&mut buf);
|
||||
m.encode_msg(&mut buf).unwrap();
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
|
|
13
sqlx-postgres/src/message/parse_complete.rs
Normal file
13
sqlx-postgres/src/message/parse_complete.rs
Normal file
|
@ -0,0 +1,13 @@
|
|||
use crate::message::{BackendMessage, BackendMessageFormat};
|
||||
use sqlx_core::bytes::Bytes;
|
||||
use sqlx_core::Error;
|
||||
|
||||
pub struct ParseComplete;
|
||||
|
||||
impl BackendMessage for ParseComplete {
|
||||
const FORMAT: BackendMessageFormat = BackendMessageFormat::ParseComplete;
|
||||
|
||||
fn decode_body(_bytes: Bytes) -> Result<Self, Error> {
|
||||
Ok(ParseComplete)
|
||||
}
|
||||
}
|
|
@ -1,9 +1,9 @@
|
|||
use std::fmt::Write;
|
||||
|
||||
use crate::io::BufMutExt;
|
||||
use crate::message::{FrontendMessage, FrontendMessageFormat};
|
||||
use md5::{Digest, Md5};
|
||||
|
||||
use crate::io::PgBufMutExt;
|
||||
use crate::io::{BufMutExt, Encode};
|
||||
use sqlx_core::Error;
|
||||
use std::fmt::Write;
|
||||
use std::num::Saturating;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Password<'a> {
|
||||
|
@ -16,117 +16,138 @@ pub enum Password<'a> {
|
|||
},
|
||||
}
|
||||
|
||||
impl Password<'_> {
|
||||
#[inline]
|
||||
fn len(&self) -> usize {
|
||||
impl FrontendMessage for Password<'_> {
|
||||
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic;
|
||||
|
||||
#[inline(always)]
|
||||
fn body_size_hint(&self) -> Saturating<usize> {
|
||||
let mut size = Saturating(0);
|
||||
|
||||
match self {
|
||||
Password::Cleartext(s) => s.len() + 5,
|
||||
Password::Md5 { .. } => 35 + 5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Encode<'_> for Password<'_> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
buf.reserve(1 + 4 + self.len());
|
||||
buf.push(b'p');
|
||||
|
||||
buf.put_length_prefixed(|buf| {
|
||||
match self {
|
||||
Password::Cleartext(password) => {
|
||||
buf.put_str_nul(password);
|
||||
}
|
||||
|
||||
Password::Md5 {
|
||||
username,
|
||||
password,
|
||||
salt,
|
||||
} => {
|
||||
// The actual `PasswordMessage` can be computed in SQL as
|
||||
// `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`.
|
||||
|
||||
// Keep in mind the md5() function returns its result as a hex string.
|
||||
|
||||
let mut hasher = Md5::new();
|
||||
|
||||
hasher.update(password);
|
||||
hasher.update(username);
|
||||
|
||||
let mut output = String::with_capacity(35);
|
||||
|
||||
let _ = write!(output, "{:x}", hasher.finalize_reset());
|
||||
|
||||
hasher.update(&output);
|
||||
hasher.update(salt);
|
||||
|
||||
output.clear();
|
||||
|
||||
let _ = write!(output, "md5{:x}", hasher.finalize());
|
||||
|
||||
buf.put_str_nul(&output);
|
||||
}
|
||||
Password::Cleartext(password) => {
|
||||
// To avoid reporting the exact password length anywhere,
|
||||
// we deliberately give a bad estimate.
|
||||
//
|
||||
// This shouldn't affect performance in the long run.
|
||||
size += password
|
||||
.len()
|
||||
.saturating_add(1) // NUL terminator
|
||||
.checked_next_power_of_two()
|
||||
.unwrap_or(usize::MAX);
|
||||
}
|
||||
});
|
||||
Password::Md5 { .. } => {
|
||||
// "md5<32 hex chars>\0"
|
||||
size += 36;
|
||||
}
|
||||
}
|
||||
|
||||
size
|
||||
}
|
||||
|
||||
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
|
||||
match self {
|
||||
Password::Cleartext(password) => {
|
||||
buf.put_str_nul(password);
|
||||
}
|
||||
|
||||
Password::Md5 {
|
||||
username,
|
||||
password,
|
||||
salt,
|
||||
} => {
|
||||
// The actual `PasswordMessage` can be computed in SQL as
|
||||
// `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`.
|
||||
|
||||
// Keep in mind the md5() function returns its result as a hex string.
|
||||
|
||||
let mut hasher = Md5::new();
|
||||
|
||||
hasher.update(password);
|
||||
hasher.update(username);
|
||||
|
||||
let mut output = String::with_capacity(35);
|
||||
|
||||
let _ = write!(output, "{:x}", hasher.finalize_reset());
|
||||
|
||||
hasher.update(&output);
|
||||
hasher.update(salt);
|
||||
|
||||
output.clear();
|
||||
|
||||
let _ = write!(output, "md5{:x}", hasher.finalize());
|
||||
|
||||
buf.put_str_nul(&output);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_clear_password() {
|
||||
const EXPECTED: &[u8] = b"p\0\0\0\rpassword\0";
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::message::FrontendMessage;
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let m = Password::Cleartext("password");
|
||||
use super::Password;
|
||||
|
||||
m.encode(&mut buf);
|
||||
#[test]
|
||||
fn test_encode_clear_password() {
|
||||
const EXPECTED: &[u8] = b"p\0\0\0\rpassword\0";
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
let mut buf = Vec::new();
|
||||
let m = Password::Cleartext("password");
|
||||
|
||||
#[test]
|
||||
fn test_encode_md5_password() {
|
||||
const EXPECTED: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0";
|
||||
m.encode_msg(&mut buf).unwrap();
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let m = Password::Md5 {
|
||||
password: "password",
|
||||
username: "root",
|
||||
salt: [147, 24, 57, 152],
|
||||
};
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
|
||||
m.encode(&mut buf);
|
||||
#[test]
|
||||
fn test_encode_md5_password() {
|
||||
const EXPECTED: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0";
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
|
||||
#[cfg(all(test, not(debug_assertions)))]
|
||||
#[bench]
|
||||
fn bench_encode_clear_password(b: &mut test::Bencher) {
|
||||
use test::black_box;
|
||||
|
||||
let mut buf = Vec::with_capacity(128);
|
||||
|
||||
b.iter(|| {
|
||||
buf.clear();
|
||||
|
||||
black_box(Password::Cleartext("password")).encode(&mut buf);
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(all(test, not(debug_assertions)))]
|
||||
#[bench]
|
||||
fn bench_encode_md5_password(b: &mut test::Bencher) {
|
||||
use test::black_box;
|
||||
|
||||
let mut buf = Vec::with_capacity(128);
|
||||
|
||||
b.iter(|| {
|
||||
buf.clear();
|
||||
|
||||
black_box(Password::Md5 {
|
||||
let mut buf = Vec::new();
|
||||
let m = Password::Md5 {
|
||||
password: "password",
|
||||
username: "root",
|
||||
salt: [147, 24, 57, 152],
|
||||
})
|
||||
.encode(&mut buf);
|
||||
});
|
||||
};
|
||||
|
||||
m.encode_msg(&mut buf).unwrap();
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
|
||||
#[cfg(all(test, not(debug_assertions)))]
|
||||
#[bench]
|
||||
fn bench_encode_clear_password(b: &mut test::Bencher) {
|
||||
use test::black_box;
|
||||
|
||||
let mut buf = Vec::with_capacity(128);
|
||||
|
||||
b.iter(|| {
|
||||
buf.clear();
|
||||
|
||||
black_box(Password::Cleartext("password")).encode_msg(&mut buf);
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(all(test, not(debug_assertions)))]
|
||||
#[bench]
|
||||
fn bench_encode_md5_password(b: &mut test::Bencher) {
|
||||
use test::black_box;
|
||||
|
||||
let mut buf = Vec::with_capacity(128);
|
||||
|
||||
b.iter(|| {
|
||||
buf.clear();
|
||||
|
||||
black_box(Password::Md5 {
|
||||
password: "password",
|
||||
username: "root",
|
||||
salt: [147, 24, 57, 152],
|
||||
})
|
||||
.encode_msg(&mut buf);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,27 +1,37 @@
|
|||
use crate::io::{BufMutExt, Encode};
|
||||
use crate::io::BufMutExt;
|
||||
use crate::message::{FrontendMessage, FrontendMessageFormat};
|
||||
use sqlx_core::Error;
|
||||
use std::num::Saturating;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Query<'a>(pub &'a str);
|
||||
|
||||
impl Encode<'_> for Query<'_> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
let len = 4 + self.0.len() + 1;
|
||||
impl FrontendMessage for Query<'_> {
|
||||
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Query;
|
||||
|
||||
buf.reserve(len + 1);
|
||||
buf.push(b'Q');
|
||||
buf.extend(&(len as i32).to_be_bytes());
|
||||
fn body_size_hint(&self) -> Saturating<usize> {
|
||||
let mut size = Saturating(0);
|
||||
|
||||
size += self.0.len();
|
||||
size += 1; // NUL terminator
|
||||
|
||||
size
|
||||
}
|
||||
|
||||
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
|
||||
buf.put_str_nul(self.0);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_query() {
|
||||
const EXPECTED: &[u8] = b"Q\0\0\0\rSELECT 1\0";
|
||||
const EXPECTED: &[u8] = b"Q\0\0\0\x0DSELECT 1\0";
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let m = Query("SELECT 1");
|
||||
|
||||
m.encode(&mut buf);
|
||||
m.encode_msg(&mut buf).unwrap();
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use sqlx_core::bytes::Bytes;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::Decode;
|
||||
use crate::message::{BackendMessage, BackendMessageFormat};
|
||||
|
||||
#[derive(Debug)]
|
||||
#[repr(u8)]
|
||||
|
@ -21,8 +21,10 @@ pub struct ReadyForQuery {
|
|||
pub transaction_status: TransactionStatus,
|
||||
}
|
||||
|
||||
impl Decode<'_> for ReadyForQuery {
|
||||
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
impl BackendMessage for ReadyForQuery {
|
||||
const FORMAT: BackendMessageFormat = BackendMessageFormat::ReadyForQuery;
|
||||
|
||||
fn decode_body(buf: Bytes) -> Result<Self, Error> {
|
||||
let status = match buf[0] {
|
||||
b'I' => TransactionStatus::Idle,
|
||||
b'T' => TransactionStatus::Transaction,
|
||||
|
@ -46,7 +48,7 @@ impl Decode<'_> for ReadyForQuery {
|
|||
fn test_decode_ready_for_query() -> Result<(), Error> {
|
||||
const DATA: &[u8] = b"E";
|
||||
|
||||
let m = ReadyForQuery::decode(Bytes::from_static(DATA))?;
|
||||
let m = ReadyForQuery::decode_body(Bytes::from_static(DATA))?;
|
||||
|
||||
assert!(matches!(m.transaction_status, TransactionStatus::Error));
|
||||
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
use std::ops::Range;
|
||||
use std::str::from_utf8;
|
||||
|
||||
use memchr::memchr;
|
||||
|
||||
use sqlx_core::bytes::Bytes;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::Decode;
|
||||
use crate::io::ProtocolDecode;
|
||||
use crate::message::{BackendMessage, BackendMessageFormat};
|
||||
|
||||
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
|
||||
#[repr(u8)]
|
||||
|
@ -53,8 +56,8 @@ impl TryFrom<&str> for PgSeverity {
|
|||
pub struct Notice {
|
||||
storage: Bytes,
|
||||
severity: PgSeverity,
|
||||
message: (u16, u16),
|
||||
code: (u16, u16),
|
||||
message: Range<usize>,
|
||||
code: Range<usize>,
|
||||
}
|
||||
|
||||
impl Notice {
|
||||
|
@ -65,12 +68,12 @@ impl Notice {
|
|||
|
||||
#[inline]
|
||||
pub fn code(&self) -> &str {
|
||||
self.get_cached_str(self.code)
|
||||
self.get_cached_str(self.code.clone())
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn message(&self) -> &str {
|
||||
self.get_cached_str(self.message)
|
||||
self.get_cached_str(self.message.clone())
|
||||
}
|
||||
|
||||
// Field descriptions available here:
|
||||
|
@ -84,7 +87,7 @@ impl Notice {
|
|||
pub fn get_raw(&self, ty: u8) -> Option<&[u8]> {
|
||||
self.fields()
|
||||
.filter(|(field, _)| *field == ty)
|
||||
.map(|(_, (start, end))| &self.storage[start as usize..end as usize])
|
||||
.map(|(_, range)| &self.storage[range])
|
||||
.next()
|
||||
}
|
||||
}
|
||||
|
@ -99,13 +102,13 @@ impl Notice {
|
|||
}
|
||||
|
||||
#[inline]
|
||||
fn get_cached_str(&self, cache: (u16, u16)) -> &str {
|
||||
fn get_cached_str(&self, cache: Range<usize>) -> &str {
|
||||
// unwrap: this cannot fail at this stage
|
||||
from_utf8(&self.storage[cache.0 as usize..cache.1 as usize]).unwrap()
|
||||
from_utf8(&self.storage[cache]).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl Decode<'_> for Notice {
|
||||
impl ProtocolDecode<'_> for Notice {
|
||||
fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
// In order to support PostgreSQL 9.5 and older we need to parse the localized S field.
|
||||
// Newer versions additionally come with the V field that is guaranteed to be in English.
|
||||
|
@ -113,8 +116,8 @@ impl Decode<'_> for Notice {
|
|||
const DEFAULT_SEVERITY: PgSeverity = PgSeverity::Log;
|
||||
let mut severity_v = None;
|
||||
let mut severity_s = None;
|
||||
let mut message = (0, 0);
|
||||
let mut code = (0, 0);
|
||||
let mut message = 0..0;
|
||||
let mut code = 0..0;
|
||||
|
||||
// we cache the three always present fields
|
||||
// this enables to keep the access time down for the fields most likely accessed
|
||||
|
@ -125,7 +128,7 @@ impl Decode<'_> for Notice {
|
|||
};
|
||||
|
||||
for (field, v) in fields {
|
||||
if message.0 != 0 && code.0 != 0 {
|
||||
if !(message.is_empty() || code.is_empty()) {
|
||||
// stop iterating when we have the 3 fields we were looking for
|
||||
// we assume V (severity) was the first field as it should be
|
||||
break;
|
||||
|
@ -133,7 +136,7 @@ impl Decode<'_> for Notice {
|
|||
|
||||
match field {
|
||||
b'S' => {
|
||||
severity_s = from_utf8(&buf[v.0 as usize..v.1 as usize])
|
||||
severity_s = from_utf8(&buf[v.clone()])
|
||||
// If the error string is not UTF-8, we have no hope of interpreting it,
|
||||
// localized or not. The `V` field would likely fail to parse as well.
|
||||
.map_err(|_| notice_protocol_err())?
|
||||
|
@ -146,21 +149,19 @@ impl Decode<'_> for Notice {
|
|||
// Propagate errors here, because V is not localized and
|
||||
// thus we are missing a possible variant.
|
||||
severity_v = Some(
|
||||
from_utf8(&buf[v.0 as usize..v.1 as usize])
|
||||
from_utf8(&buf[v.clone()])
|
||||
.map_err(|_| notice_protocol_err())?
|
||||
.try_into()?,
|
||||
);
|
||||
}
|
||||
|
||||
b'M' => {
|
||||
_ = from_utf8(&buf[v.0 as usize..v.1 as usize])
|
||||
.map_err(|_| notice_protocol_err())?;
|
||||
_ = from_utf8(&buf[v.clone()]).map_err(|_| notice_protocol_err())?;
|
||||
message = v;
|
||||
}
|
||||
|
||||
b'C' => {
|
||||
_ = from_utf8(&buf[v.0 as usize..v.1 as usize])
|
||||
.map_err(|_| notice_protocol_err())?;
|
||||
_ = from_utf8(&buf[v.clone()]).map_err(|_| notice_protocol_err())?;
|
||||
code = v;
|
||||
}
|
||||
|
||||
|
@ -179,31 +180,46 @@ impl Decode<'_> for Notice {
|
|||
}
|
||||
}
|
||||
|
||||
impl BackendMessage for Notice {
|
||||
const FORMAT: BackendMessageFormat = BackendMessageFormat::NoticeResponse;
|
||||
|
||||
fn decode_body(buf: Bytes) -> Result<Self, Error> {
|
||||
// Keeping both impls for now
|
||||
Self::decode_with(buf, ())
|
||||
}
|
||||
}
|
||||
|
||||
/// An iterator over each field in the Error (or Notice) response.
|
||||
struct Fields<'a> {
|
||||
storage: &'a [u8],
|
||||
offset: u16,
|
||||
offset: usize,
|
||||
}
|
||||
|
||||
impl<'a> Iterator for Fields<'a> {
|
||||
type Item = (u8, (u16, u16));
|
||||
type Item = (u8, Range<usize>);
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
// The fields in the response body are sequentially stored as [tag][string],
|
||||
// ending in a final, additional [nul]
|
||||
|
||||
let ty = self.storage[self.offset as usize];
|
||||
let ty = *self.storage.get(self.offset)?;
|
||||
|
||||
if ty == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let nul = memchr(b'\0', &self.storage[(self.offset + 1) as usize..])? as u16;
|
||||
let offset = self.offset;
|
||||
// Consume the type byte
|
||||
self.offset = self.offset.checked_add(1)?;
|
||||
|
||||
self.offset += nul + 2;
|
||||
let start = self.offset;
|
||||
|
||||
Some((ty, (offset + 1, offset + nul + 1)))
|
||||
let len = memchr(b'\0', self.storage.get(start..)?)?;
|
||||
|
||||
// Neither can overflow as they will always be `<= self.storage.len()`.
|
||||
let end = self.offset + len;
|
||||
self.offset = end + 1;
|
||||
|
||||
Some((ty, start..end))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use sqlx_core::bytes::{Buf, Bytes};
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::io::{BufExt, Decode};
|
||||
use crate::io::BufExt;
|
||||
use crate::message::{BackendMessage, BackendMessageFormat};
|
||||
use crate::types::Oid;
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -40,13 +41,30 @@ pub struct Field {
|
|||
pub format: i16,
|
||||
}
|
||||
|
||||
impl Decode<'_> for RowDescription {
|
||||
fn decode_with(mut buf: Bytes, _: ()) -> Result<Self, Error> {
|
||||
impl BackendMessage for RowDescription {
|
||||
const FORMAT: BackendMessageFormat = BackendMessageFormat::RowDescription;
|
||||
|
||||
fn decode_body(mut buf: Bytes) -> Result<Self, Error> {
|
||||
if buf.len() < 2 {
|
||||
return Err(err_protocol!(
|
||||
"expected at least 2 bytes, got {}",
|
||||
buf.len()
|
||||
));
|
||||
}
|
||||
|
||||
let cnt = buf.get_u16();
|
||||
let mut fields = Vec::with_capacity(cnt as usize);
|
||||
|
||||
for _ in 0..cnt {
|
||||
let name = buf.get_str_nul()?.to_owned();
|
||||
|
||||
if buf.len() < 18 {
|
||||
return Err(err_protocol!(
|
||||
"expected at least 18 bytes after field name {name:?}, got {}",
|
||||
buf.len()
|
||||
));
|
||||
}
|
||||
|
||||
let relation_id = buf.get_i32();
|
||||
let relation_attribute_no = buf.get_i16();
|
||||
let data_type_id = Oid(buf.get_u32());
|
||||
|
|
|
@ -1,35 +1,69 @@
|
|||
use crate::io::PgBufMutExt;
|
||||
use crate::io::{BufMutExt, Encode};
|
||||
use crate::io::BufMutExt;
|
||||
use crate::message::{FrontendMessage, FrontendMessageFormat};
|
||||
use sqlx_core::Error;
|
||||
use std::num::Saturating;
|
||||
|
||||
pub struct SaslInitialResponse<'a> {
|
||||
pub response: &'a str,
|
||||
pub plus: bool,
|
||||
}
|
||||
|
||||
impl Encode<'_> for SaslInitialResponse<'_> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
buf.push(b'p');
|
||||
buf.put_length_prefixed(|buf| {
|
||||
// name of the SASL authentication mechanism that the client selected
|
||||
buf.put_str_nul(if self.plus {
|
||||
"SCRAM-SHA-256-PLUS"
|
||||
} else {
|
||||
"SCRAM-SHA-256"
|
||||
});
|
||||
impl SaslInitialResponse<'_> {
|
||||
#[inline(always)]
|
||||
fn selected_mechanism(&self) -> &'static str {
|
||||
if self.plus {
|
||||
"SCRAM-SHA-256-PLUS"
|
||||
} else {
|
||||
"SCRAM-SHA-256"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
buf.extend(&(self.response.as_bytes().len() as i32).to_be_bytes());
|
||||
buf.extend(self.response.as_bytes());
|
||||
});
|
||||
impl FrontendMessage for SaslInitialResponse<'_> {
|
||||
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic;
|
||||
|
||||
#[inline(always)]
|
||||
fn body_size_hint(&self) -> Saturating<usize> {
|
||||
let mut size = Saturating(0);
|
||||
|
||||
size += self.selected_mechanism().len();
|
||||
size += 1; // NUL terminator
|
||||
|
||||
size += 4; // response_len
|
||||
size += self.response.len();
|
||||
|
||||
size
|
||||
}
|
||||
|
||||
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
|
||||
// name of the SASL authentication mechanism that the client selected
|
||||
buf.put_str_nul(self.selected_mechanism());
|
||||
|
||||
let response_len = i32::try_from(self.response.len()).map_err(|_| {
|
||||
err_protocol!(
|
||||
"SASL Initial Response length too long for protcol: {}",
|
||||
self.response.len()
|
||||
)
|
||||
})?;
|
||||
|
||||
buf.extend_from_slice(&response_len.to_be_bytes());
|
||||
buf.extend_from_slice(self.response.as_bytes());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SaslResponse<'a>(pub &'a str);
|
||||
|
||||
impl Encode<'_> for SaslResponse<'_> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
buf.push(b'p');
|
||||
buf.put_length_prefixed(|buf| {
|
||||
buf.extend(self.0.as_bytes());
|
||||
});
|
||||
impl FrontendMessage for SaslResponse<'_> {
|
||||
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic;
|
||||
|
||||
fn body_size_hint(&self) -> Saturating<usize> {
|
||||
Saturating(self.0.len())
|
||||
}
|
||||
|
||||
fn encode_body(&self, buf: &mut Vec<u8>) -> Result<(), Error> {
|
||||
buf.extend(self.0.as_bytes());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,23 +1,38 @@
|
|||
use crate::io::Encode;
|
||||
use crate::io::ProtocolEncode;
|
||||
|
||||
pub struct SslRequest;
|
||||
|
||||
impl SslRequest {
|
||||
pub const BYTES: &'static [u8] = b"\x00\x00\x00\x08\x04\xd2\x16/";
|
||||
// https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-SSLREQUEST
|
||||
pub const BYTES: &'static [u8] = b"\x00\x00\x00\x08\x04\xd2\x16\x2f";
|
||||
}
|
||||
|
||||
impl Encode<'_> for SslRequest {
|
||||
#[inline]
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
buf.extend(&8_u32.to_be_bytes());
|
||||
buf.extend(&(((1234 << 16) | 5679) as u32).to_be_bytes());
|
||||
// Cannot impl FrontendMessage because it does not have a format code
|
||||
impl ProtocolEncode<'_> for SslRequest {
|
||||
#[inline(always)]
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _context: ()) -> Result<(), crate::Error> {
|
||||
buf.extend_from_slice(Self::BYTES);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_ssl_request() {
|
||||
let mut buf = Vec::new();
|
||||
SslRequest.encode(&mut buf);
|
||||
|
||||
// Int32(8)
|
||||
// Length of message contents in bytes, including self.
|
||||
buf.extend_from_slice(&8_u32.to_be_bytes());
|
||||
|
||||
// Int32(80877103)
|
||||
// The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits,
|
||||
// and 5679 in the least significant 16 bits.
|
||||
// (To avoid confusion, this code must not be the same as any protocol version number.)
|
||||
buf.extend_from_slice(&(((1234 << 16) | 5679) as u32).to_be_bytes());
|
||||
|
||||
let mut encoded = Vec::new();
|
||||
SslRequest.encode(&mut encoded).unwrap();
|
||||
|
||||
assert_eq!(buf, SslRequest::BYTES);
|
||||
assert_eq!(buf, encoded);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::io::PgBufMutExt;
|
||||
use crate::io::{BufMutExt, Encode};
|
||||
use crate::io::{BufMutExt, ProtocolEncode};
|
||||
|
||||
// To begin a session, a frontend opens a connection to the server and sends a startup message.
|
||||
// This message includes the names of the user and of the database the user wants to connect to;
|
||||
|
@ -19,8 +19,9 @@ pub struct Startup<'a> {
|
|||
pub params: &'a [(&'a str, &'a str)],
|
||||
}
|
||||
|
||||
impl Encode<'_> for Startup<'_> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
// Startup cannot impl FrontendMessage because it doesn't have a format code.
|
||||
impl ProtocolEncode<'_> for Startup<'_> {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _context: ()) -> Result<(), crate::Error> {
|
||||
buf.reserve(120);
|
||||
|
||||
buf.put_length_prefixed(|buf| {
|
||||
|
@ -47,7 +48,9 @@ impl Encode<'_> for Startup<'_> {
|
|||
// A zero byte is required as a terminator
|
||||
// after the last name/value pair.
|
||||
buf.push(0);
|
||||
});
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -68,7 +71,7 @@ fn test_encode_startup() {
|
|||
params: &[],
|
||||
};
|
||||
|
||||
m.encode(&mut buf);
|
||||
m.encode(&mut buf).unwrap();
|
||||
|
||||
assert_eq!(buf, EXPECTED);
|
||||
}
|
||||
|
|
|
@ -1,11 +1,20 @@
|
|||
use crate::io::Encode;
|
||||
use crate::message::{FrontendMessage, FrontendMessageFormat};
|
||||
use sqlx_core::Error;
|
||||
use std::num::Saturating;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Sync;
|
||||
|
||||
impl Encode<'_> for Sync {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
buf.push(b'S');
|
||||
buf.extend(&4_i32.to_be_bytes());
|
||||
impl FrontendMessage for Sync {
|
||||
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Sync;
|
||||
|
||||
#[inline(always)]
|
||||
fn body_size_hint(&self) -> Saturating<usize> {
|
||||
Saturating(0)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn encode_body(&self, _buf: &mut Vec<u8>) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,19 @@
|
|||
use crate::io::Encode;
|
||||
use crate::message::{FrontendMessage, FrontendMessageFormat};
|
||||
use sqlx_core::Error;
|
||||
use std::num::Saturating;
|
||||
|
||||
pub struct Terminate;
|
||||
|
||||
impl Encode<'_> for Terminate {
|
||||
fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
|
||||
buf.push(b'X');
|
||||
buf.extend(&4_u32.to_be_bytes());
|
||||
impl FrontendMessage for Terminate {
|
||||
const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Terminate;
|
||||
|
||||
#[inline(always)]
|
||||
fn body_size_hint(&self) -> Saturating<usize> {
|
||||
Saturating(0)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn encode_body(&self, _buf: &mut Vec<u8>) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@ impl TransactionManager for PgTransactionManager {
|
|||
Box::pin(async move {
|
||||
let rollback = Rollback::new(conn);
|
||||
let query = begin_ansi_transaction_sql(rollback.conn.transaction_depth);
|
||||
rollback.conn.queue_simple_query(&query);
|
||||
rollback.conn.queue_simple_query(&query)?;
|
||||
rollback.conn.transaction_depth += 1;
|
||||
rollback.conn.wait_until_ready().await?;
|
||||
rollback.defuse();
|
||||
|
@ -54,7 +54,8 @@ impl TransactionManager for PgTransactionManager {
|
|||
|
||||
fn start_rollback(conn: &mut PgConnection) {
|
||||
if conn.transaction_depth > 0 {
|
||||
conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.transaction_depth));
|
||||
conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.transaction_depth))
|
||||
.expect("BUG: Rollback query somehow too large for protocol");
|
||||
|
||||
conn.transaction_depth -= 1;
|
||||
}
|
||||
|
|
|
@ -17,12 +17,6 @@ pub struct Oid(
|
|||
pub u32,
|
||||
);
|
||||
|
||||
impl Oid {
|
||||
pub(crate) fn incr_one(&mut self) {
|
||||
self.0 = self.0.wrapping_add(1);
|
||||
}
|
||||
}
|
||||
|
||||
impl Type<Postgres> for Oid {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::OID
|
||||
|
|
Loading…
Reference in a new issue