From 53766e4659367e87eeef8a90ef63d5dc91b0e039 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Sat, 17 Aug 2024 04:54:40 -0700 Subject: [PATCH] refactor(postgres): make better use of traits to improve protocol handling --- sqlx-postgres/src/advisory_lock.rs | 3 +- sqlx-postgres/src/arguments.rs | 1 + sqlx-postgres/src/connection/describe.rs | 26 +- sqlx-postgres/src/connection/establish.rs | 30 ++- sqlx-postgres/src/connection/executor.rs | 92 +++---- sqlx-postgres/src/connection/mod.rs | 30 +-- sqlx-postgres/src/connection/sasl.rs | 12 +- sqlx-postgres/src/connection/stream.rs | 50 ++-- sqlx-postgres/src/copy.rs | 48 ++-- sqlx-postgres/src/io/buf_mut.rs | 62 +++-- sqlx-postgres/src/io/mod.rs | 125 ++++++++++ sqlx-postgres/src/listener.rs | 6 +- sqlx-postgres/src/message/authentication.rs | 14 +- sqlx-postgres/src/message/backend_key_data.rs | 12 +- sqlx-postgres/src/message/bind.rs | 76 ++++-- sqlx-postgres/src/message/close.rs | 37 ++- sqlx-postgres/src/message/command_complete.rs | 21 +- sqlx-postgres/src/message/copy.rs | 127 ++++++---- sqlx-postgres/src/message/data_row.rs | 72 ++++-- sqlx-postgres/src/message/describe.rs | 188 +++++++-------- sqlx-postgres/src/message/execute.rs | 80 ++++-- sqlx-postgres/src/message/flush.rs | 30 ++- sqlx-postgres/src/message/mod.rs | 162 ++++++++++--- sqlx-postgres/src/message/notification.rs | 12 +- .../src/message/parameter_description.rs | 14 +- sqlx-postgres/src/message/parameter_status.rs | 17 +- sqlx-postgres/src/message/parse.rs | 61 +++-- sqlx-postgres/src/message/parse_complete.rs | 13 + sqlx-postgres/src/message/password.rs | 227 ++++++++++-------- sqlx-postgres/src/message/query.rs | 28 ++- sqlx-postgres/src/message/ready_for_query.rs | 10 +- sqlx-postgres/src/message/response.rs | 66 +++-- sqlx-postgres/src/message/row_description.rs | 24 +- sqlx-postgres/src/message/sasl.rs | 76 ++++-- sqlx-postgres/src/message/ssl_request.rs | 31 ++- sqlx-postgres/src/message/startup.rs | 13 +- sqlx-postgres/src/message/sync.rs | 19 +- sqlx-postgres/src/message/terminate.rs | 19 +- sqlx-postgres/src/transaction.rs | 5 +- sqlx-postgres/src/types/oid.rs | 6 - 40 files changed, 1252 insertions(+), 693 deletions(-) create mode 100644 sqlx-postgres/src/message/parse_complete.rs diff --git a/sqlx-postgres/src/advisory_lock.rs b/sqlx-postgres/src/advisory_lock.rs index 98274413..d1aef176 100644 --- a/sqlx-postgres/src/advisory_lock.rs +++ b/sqlx-postgres/src/advisory_lock.rs @@ -414,7 +414,8 @@ impl<'lock, C: AsMut> Drop for PgAdvisoryLockGuard<'lock, C> { // The `async fn` versions can safely use the prepared statement protocol, // but this is the safest way to queue a query to execute on the next opportunity. conn.as_mut() - .queue_simple_query(self.lock.get_release_query()); + .queue_simple_query(self.lock.get_release_query()) + .expect("BUG: PgAdvisoryLock::get_release_query() somehow too long for protocol"); } } } diff --git a/sqlx-postgres/src/arguments.rs b/sqlx-postgres/src/arguments.rs index 2e7d5fd9..bc7e861c 100644 --- a/sqlx-postgres/src/arguments.rs +++ b/sqlx-postgres/src/arguments.rs @@ -145,6 +145,7 @@ impl<'q> Arguments<'q> for PgArguments { write!(writer, "${}", self.buffer.count) } + #[inline(always)] fn len(&self) -> usize { self.buffer.count } diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index e53a054a..d9c55201 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -1,5 +1,6 @@ use crate::error::Error; use crate::ext::ustr::UStr; +use crate::io::StatementId; use crate::message::{ParameterDescription, RowDescription}; use crate::query_as::query_as; use crate::query_scalar::query_scalar; @@ -27,10 +28,12 @@ enum TypType { Range, } -impl TryFrom for TypType { +impl TryFrom for TypType { type Error = (); - fn try_from(t: u8) -> Result { + fn try_from(t: i8) -> Result { + let t = u8::try_from(t).or(Err(()))?; + let t = match t { b'b' => Self::Base, b'c' => Self::Composite, @@ -66,10 +69,12 @@ enum TypCategory { Unknown, } -impl TryFrom for TypCategory { +impl TryFrom for TypCategory { type Error = (); - fn try_from(c: u8) -> Result { + fn try_from(c: i8) -> Result { + let c = u8::try_from(c).or(Err(()))?; + let c = match c { b'A' => Self::Array, b'B' => Self::Boolean, @@ -209,8 +214,8 @@ impl PgConnection { .fetch_one(&mut *self) .await?; - let typ_type = TypType::try_from(typ_type as u8); - let category = TypCategory::try_from(category as u8); + let typ_type = TypType::try_from(typ_type); + let category = TypCategory::try_from(category); match (typ_type, category) { (Ok(TypType::Domain), _) => self.fetch_domain_by_oid(oid, base_type, name).await, @@ -416,7 +421,7 @@ WHERE rngtypid = $1 pub(crate) async fn get_nullable_for_columns( &mut self, - stmt_id: Oid, + stmt_id: StatementId, meta: &PgStatementMetadata, ) -> Result>, Error> { if meta.columns.is_empty() { @@ -486,13 +491,10 @@ WHERE rngtypid = $1 /// and returns `None` for all others. async fn nullables_from_explain( &mut self, - stmt_id: Oid, + stmt_id: StatementId, params_len: usize, ) -> Result>, Error> { - let mut explain = format!( - "EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE sqlx_s_{}", - stmt_id.0 - ); + let mut explain = format!("EXPLAIN (VERBOSE, FORMAT JSON) EXECUTE {stmt_id}"); let mut comma = false; if params_len > 0 { diff --git a/sqlx-postgres/src/connection/establish.rs b/sqlx-postgres/src/connection/establish.rs index 83b9843a..a730f5c1 100644 --- a/sqlx-postgres/src/connection/establish.rs +++ b/sqlx-postgres/src/connection/establish.rs @@ -3,11 +3,10 @@ use crate::HashMap; use crate::common::StatementCache; use crate::connection::{sasl, stream::PgStream}; use crate::error::Error; -use crate::io::Decode; +use crate::io::StatementId; use crate::message::{ - Authentication, BackendKeyData, MessageFormat, Password, ReadyForQuery, Startup, + Authentication, BackendKeyData, BackendMessageFormat, Password, ReadyForQuery, Startup, }; -use crate::types::Oid; use crate::{PgConnectOptions, PgConnection}; // https://www.postgresql.org/docs/current/protocol-flow.html#id-1.10.5.7.3 @@ -44,13 +43,13 @@ impl PgConnection { params.push(("options", options)); } - stream - .send(Startup { - username: Some(&options.username), - database: options.database.as_deref(), - params: ¶ms, - }) - .await?; + stream.write(Startup { + username: Some(&options.username), + database: options.database.as_deref(), + params: ¶ms, + })?; + + stream.flush().await?; // The server then uses this information and the contents of // its configuration files (such as pg_hba.conf) to determine whether the connection is @@ -64,7 +63,7 @@ impl PgConnection { loop { let message = stream.recv().await?; match message.format { - MessageFormat::Authentication => match message.decode()? { + BackendMessageFormat::Authentication => match message.decode()? { Authentication::Ok => { // the authentication exchange is successfully completed // do nothing; no more information is required to continue @@ -108,7 +107,7 @@ impl PgConnection { } }, - MessageFormat::BackendKeyData => { + BackendMessageFormat::BackendKeyData => { // provides secret-key data that the frontend must save if it wants to be // able to issue cancel requests later @@ -118,10 +117,9 @@ impl PgConnection { secret_key = data.secret_key; } - MessageFormat::ReadyForQuery => { + BackendMessageFormat::ReadyForQuery => { // start-up is completed. The frontend can now issue commands - transaction_status = - ReadyForQuery::decode(message.contents)?.transaction_status; + transaction_status = message.decode::()?.transaction_status; break; } @@ -142,7 +140,7 @@ impl PgConnection { transaction_status, transaction_depth: 0, pending_ready_for_query_count: 0, - next_statement_id: Oid(1), + next_statement_id: StatementId::NAMED_START, cache_statement: StatementCache::new(options.statement_cache_capacity), cache_type_oid: HashMap::new(), cache_type_info: HashMap::new(), diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index bb73db1e..d2f6bcdd 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -1,13 +1,13 @@ use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; +use crate::io::{PortalId, StatementId}; use crate::logger::QueryLogger; use crate::message::{ - self, Bind, Close, CommandComplete, DataRow, MessageFormat, ParameterDescription, Parse, Query, - RowDescription, + self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse, + ParseComplete, Query, RowDescription, }; use crate::statement::PgStatementMetadata; -use crate::types::Oid; use crate::{ statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo, PgValueFormat, Postgres, @@ -16,6 +16,7 @@ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::{pin_mut, TryStreamExt}; +use sqlx_core::arguments::Arguments; use sqlx_core::Either; use std::{borrow::Cow, sync::Arc}; @@ -24,9 +25,9 @@ async fn prepare( sql: &str, parameters: &[PgTypeInfo], metadata: Option>, -) -> Result<(Oid, Arc), Error> { +) -> Result<(StatementId, Arc), Error> { let id = conn.next_statement_id; - conn.next_statement_id.incr_one(); + conn.next_statement_id = id.next(); // build a list of type OIDs to send to the database in the PARSE command // we have not yet started the query sequence, so we are *safe* to cleanly make @@ -42,15 +43,15 @@ async fn prepare( conn.wait_until_ready().await?; // next we send the PARSE command to the server - conn.stream.write(Parse { + conn.stream.write_msg(Parse { param_types: ¶m_types, query: sql, statement: id, - }); + })?; if metadata.is_none() { // get the statement columns and parameters - conn.stream.write(message::Describe::Statement(id)); + conn.stream.write_msg(message::Describe::Statement(id))?; } // we ask for the server to immediately send us the result of the PARSE command @@ -58,9 +59,7 @@ async fn prepare( conn.stream.flush().await?; // indicates that the SQL query string is now successfully parsed and has semantic validity - conn.stream - .recv_expect(MessageFormat::ParseComplete) - .await?; + conn.stream.recv_expect::().await?; let metadata = if let Some(metadata) = metadata { // each SYNC produces one READY FOR QUERY @@ -95,18 +94,18 @@ async fn prepare( } async fn recv_desc_params(conn: &mut PgConnection) -> Result { - conn.stream - .recv_expect(MessageFormat::ParameterDescription) - .await + conn.stream.recv_expect().await } async fn recv_desc_rows(conn: &mut PgConnection) -> Result, Error> { let rows: Option = match conn.stream.recv().await? { // describes the rows that will be returned when the statement is eventually executed - message if message.format == MessageFormat::RowDescription => Some(message.decode()?), + message if message.format == BackendMessageFormat::RowDescription => { + Some(message.decode()?) + } // no data would be returned if this statement was executed - message if message.format == MessageFormat::NoData => None, + message if message.format == BackendMessageFormat::NoData => None, message => { return Err(err_protocol!( @@ -125,12 +124,12 @@ impl PgConnection { // we need to wait for the [CloseComplete] to be returned from the server while count > 0 { match self.stream.recv().await? { - message if message.format == MessageFormat::PortalSuspended => { + message if message.format == BackendMessageFormat::PortalSuspended => { // there was an open portal // this can happen if the last time a statement was used it was not fully executed } - message if message.format == MessageFormat::CloseComplete => { + message if message.format == BackendMessageFormat::CloseComplete => { // successfully closed the statement (and freed up the server resources) count -= 1; } @@ -147,8 +146,11 @@ impl PgConnection { Ok(()) } + #[inline(always)] pub(crate) fn write_sync(&mut self) { - self.stream.write(message::Sync); + self.stream + .write_msg(message::Sync) + .expect("BUG: Sync should not be too big for protocol"); // all SYNC messages will return a ReadyForQuery self.pending_ready_for_query_count += 1; @@ -163,7 +165,7 @@ impl PgConnection { // optional metadata that was provided by the user, this means they are reusing // a statement object metadata: Option>, - ) -> Result<(Oid, Arc), Error> { + ) -> Result<(StatementId, Arc), Error> { if let Some(statement) = self.cache_statement.get_mut(sql) { return Ok((*statement).clone()); } @@ -172,7 +174,7 @@ impl PgConnection { if store_to_cache && self.cache_statement.is_enabled() { if let Some((id, _)) = self.cache_statement.insert(sql, statement.clone()) { - self.stream.write(Close::Statement(id)); + self.stream.write_msg(Close::Statement(id))?; self.write_sync(); self.stream.flush().await?; @@ -201,6 +203,14 @@ impl PgConnection { let mut metadata: Arc; let format = if let Some(mut arguments) = arguments { + // Check this before we write anything to the stream. + let num_params = i16::try_from(arguments.len()).map_err(|_| { + err_protocol!( + "PgConnection::run(): too many arguments for query: {}", + arguments.len() + ) + })?; + // prepare the statement if this our first time executing it // always return the statement ID here let (statement, metadata_) = self @@ -216,21 +226,21 @@ impl PgConnection { self.wait_until_ready().await?; // bind to attach the arguments to the statement and create a portal - self.stream.write(Bind { - portal: None, + self.stream.write_msg(Bind { + portal: PortalId::UNNAMED, statement, formats: &[PgValueFormat::Binary], - num_params: arguments.types.len() as i16, + num_params, params: &arguments.buffer, result_formats: &[PgValueFormat::Binary], - }); + })?; // executes the portal up to the passed limit // the protocol-level limit acts nearly identically to the `LIMIT` in SQL - self.stream.write(message::Execute { - portal: None, + self.stream.write_msg(message::Execute { + portal: PortalId::UNNAMED, limit: limit.into(), - }); + })?; // From https://www.postgresql.org/docs/current/protocol-flow.html: // // "An unnamed portal is destroyed at the end of the transaction, or as @@ -240,7 +250,7 @@ impl PgConnection { // we ask the database server to close the unnamed portal and free the associated resources // earlier - after the execution of the current query. - self.stream.write(message::Close::Portal(None)); + self.stream.write_msg(Close::Portal(PortalId::UNNAMED))?; // finally, [Sync] asks postgres to process the messages that we sent and respond with // a [ReadyForQuery] message when it's completely done. Theoretically, we could send @@ -253,7 +263,7 @@ impl PgConnection { PgValueFormat::Binary } else { // Query will trigger a ReadyForQuery - self.stream.write(Query(query)); + self.stream.write_msg(Query(query))?; self.pending_ready_for_query_count += 1; // metadata starts out as "nothing" @@ -270,12 +280,12 @@ impl PgConnection { let message = self.stream.recv().await?; match message.format { - MessageFormat::BindComplete - | MessageFormat::ParseComplete - | MessageFormat::ParameterDescription - | MessageFormat::NoData + BackendMessageFormat::BindComplete + | BackendMessageFormat::ParseComplete + | BackendMessageFormat::ParameterDescription + | BackendMessageFormat::NoData // unnamed portal has been closed - | MessageFormat::CloseComplete + | BackendMessageFormat::CloseComplete => { // harmless messages to ignore } @@ -284,7 +294,7 @@ impl PgConnection { // exactly one of these messages: CommandComplete, // EmptyQueryResponse (if the portal was created from an // empty query string), ErrorResponse, or PortalSuspended" - MessageFormat::CommandComplete => { + BackendMessageFormat::CommandComplete => { // a SQL command completed normally let cc: CommandComplete = message.decode()?; @@ -295,16 +305,16 @@ impl PgConnection { })); } - MessageFormat::EmptyQueryResponse => { + BackendMessageFormat::EmptyQueryResponse => { // empty query string passed to an unprepared execute } // Message::ErrorResponse is handled in self.stream.recv() // incomplete query execution has finished - MessageFormat::PortalSuspended => {} + BackendMessageFormat::PortalSuspended => {} - MessageFormat::RowDescription => { + BackendMessageFormat::RowDescription => { // indicates that a *new* set of rows are about to be returned let (columns, column_names) = self .handle_row_description(Some(message.decode()?), false) @@ -317,7 +327,7 @@ impl PgConnection { }); } - MessageFormat::DataRow => { + BackendMessageFormat::DataRow => { logger.increment_rows_returned(); // one of the set of rows returned by a SELECT, FETCH, etc query @@ -331,7 +341,7 @@ impl PgConnection { r#yield!(Either::Right(row)); } - MessageFormat::ReadyForQuery => { + BackendMessageFormat::ReadyForQuery => { // processing of the query string is complete self.handle_ready_for_query(message)?; break; diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 1c7a4682..9003dcb3 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -8,9 +8,10 @@ use futures_util::FutureExt; use crate::common::StatementCache; use crate::error::Error; use crate::ext::ustr::UStr; -use crate::io::Decode; +use crate::io::StatementId; use crate::message::{ - Close, Message, MessageFormat, Query, ReadyForQuery, Terminate, TransactionStatus, + BackendMessageFormat, Close, Query, ReadyForQuery, ReceivedMessage, Terminate, + TransactionStatus, }; use crate::statement::PgStatementMetadata; use crate::transaction::Transaction; @@ -47,10 +48,10 @@ pub struct PgConnection { // sequence of statement IDs for use in preparing statements // in PostgreSQL, the statement is prepared to a user-supplied identifier - next_statement_id: Oid, + next_statement_id: StatementId, // cache statement by query string to the id and columns - cache_statement: StatementCache<(Oid, Arc)>, + cache_statement: StatementCache<(StatementId, Arc)>, // cache user-defined types by id <-> info cache_type_info: HashMap, @@ -82,7 +83,7 @@ impl PgConnection { while self.pending_ready_for_query_count > 0 { let message = self.stream.recv().await?; - if let MessageFormat::ReadyForQuery = message.format { + if let BackendMessageFormat::ReadyForQuery = message.format { self.handle_ready_for_query(message)?; } } @@ -91,10 +92,7 @@ impl PgConnection { } async fn recv_ready_for_query(&mut self) -> Result<(), Error> { - let r: ReadyForQuery = self - .stream - .recv_expect(MessageFormat::ReadyForQuery) - .await?; + let r: ReadyForQuery = self.stream.recv_expect().await?; self.pending_ready_for_query_count -= 1; self.transaction_status = r.transaction_status; @@ -102,9 +100,10 @@ impl PgConnection { Ok(()) } - fn handle_ready_for_query(&mut self, message: Message) -> Result<(), Error> { + #[inline(always)] + fn handle_ready_for_query(&mut self, message: ReceivedMessage) -> Result<(), Error> { self.pending_ready_for_query_count -= 1; - self.transaction_status = ReadyForQuery::decode(message.contents)?.transaction_status; + self.transaction_status = message.decode::()?.transaction_status; Ok(()) } @@ -112,9 +111,12 @@ impl PgConnection { /// Queue a simple query (not prepared) to execute the next time this connection is used. /// /// Used for rolling back transactions and releasing advisory locks. - pub(crate) fn queue_simple_query(&mut self, query: &str) { + #[inline(always)] + pub(crate) fn queue_simple_query(&mut self, query: &str) -> Result<(), Error> { + self.stream.write_msg(Query(query))?; self.pending_ready_for_query_count += 1; - self.stream.write(Query(query)); + + Ok(()) } } @@ -184,7 +186,7 @@ impl Connection for PgConnection { self.wait_until_ready().await?; while let Some((id, _)) = self.cache_statement.remove_lru() { - self.stream.write(Close::Statement(id)); + self.stream.write_msg(Close::Statement(id))?; cleared += 1; } diff --git a/sqlx-postgres/src/connection/sasl.rs b/sqlx-postgres/src/connection/sasl.rs index 11f36eec..729cc1fc 100644 --- a/sqlx-postgres/src/connection/sasl.rs +++ b/sqlx-postgres/src/connection/sasl.rs @@ -1,8 +1,6 @@ use crate::connection::stream::PgStream; use crate::error::Error; -use crate::message::{ - Authentication, AuthenticationSasl, MessageFormat, SaslInitialResponse, SaslResponse, -}; +use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse}; use crate::PgConnectOptions; use hmac::{Hmac, Mac}; use rand::Rng; @@ -76,7 +74,7 @@ pub(crate) async fn authenticate( }) .await?; - let cont = match stream.recv_expect(MessageFormat::Authentication).await? { + let cont = match stream.recv_expect().await? { Authentication::SaslContinue(data) => data, auth => { @@ -147,7 +145,7 @@ pub(crate) async fn authenticate( stream.send(SaslResponse(&client_final_message)).await?; - let data = match stream.recv_expect(MessageFormat::Authentication).await? { + let data = match stream.recv_expect().await? { Authentication::SaslFinal(data) => data, auth => { @@ -172,10 +170,10 @@ fn gen_nonce() -> String { // ;; a valid "value". let nonce: String = std::iter::repeat(()) .map(|()| { - let mut c = rng.gen_range(0x21..0x7F) as u8; + let mut c = rng.gen_range(0x21u8..0x7F); while c == 0x2C { - c = rng.gen_range(0x21..0x7F) as u8; + c = rng.gen_range(0x21u8..0x7F); } c diff --git a/sqlx-postgres/src/connection/stream.rs b/sqlx-postgres/src/connection/stream.rs index 0cbf405d..a7c7d1ae 100644 --- a/sqlx-postgres/src/connection/stream.rs +++ b/sqlx-postgres/src/connection/stream.rs @@ -9,8 +9,10 @@ use sqlx_core::bytes::{Buf, Bytes}; use crate::connection::tls::MaybeUpgradeTls; use crate::error::Error; -use crate::io::{Decode, Encode}; -use crate::message::{Message, MessageFormat, Notice, Notification, ParameterStatus}; +use crate::message::{ + BackendMessage, BackendMessageFormat, EncodeMessage, FrontendMessage, Notice, Notification, + ParameterStatus, ReceivedMessage, +}; use crate::net::{self, BufferedSocket, Socket}; use crate::{PgConnectOptions, PgDatabaseError, PgSeverity}; @@ -55,59 +57,51 @@ impl PgStream { }) } - pub(crate) async fn send<'en, T>(&mut self, message: T) -> Result<(), Error> + #[inline(always)] + pub(crate) fn write_msg(&mut self, message: impl FrontendMessage) -> Result<(), Error> { + self.write(EncodeMessage(message)) + } + + pub(crate) async fn send(&mut self, message: T) -> Result<(), Error> where - T: Encode<'en>, + T: FrontendMessage, { - self.write(message); + self.write_msg(message)?; self.flush().await?; Ok(()) } // Expect a specific type and format - pub(crate) async fn recv_expect<'de, T: Decode<'de>>( - &mut self, - format: MessageFormat, - ) -> Result { - let message = self.recv().await?; - - if message.format != format { - return Err(err_protocol!( - "expecting {:?} but received {:?}", - format, - message.format - )); - } - - message.decode() + pub(crate) async fn recv_expect(&mut self) -> Result { + self.recv().await?.decode() } - pub(crate) async fn recv_unchecked(&mut self) -> Result { + pub(crate) async fn recv_unchecked(&mut self) -> Result { // all packets in postgres start with a 5-byte header // this header contains the message type and the total length of the message let mut header: Bytes = self.inner.read(5).await?; - let format = MessageFormat::try_from_u8(header.get_u8())?; + let format = BackendMessageFormat::try_from_u8(header.get_u8())?; let size = (header.get_u32() - 4) as usize; let contents = self.inner.read(size).await?; - Ok(Message { format, contents }) + Ok(ReceivedMessage { format, contents }) } // Get the next message from the server // May wait for more data from the server - pub(crate) async fn recv(&mut self) -> Result { + pub(crate) async fn recv(&mut self) -> Result { loop { let message = self.recv_unchecked().await?; match message.format { - MessageFormat::ErrorResponse => { + BackendMessageFormat::ErrorResponse => { // An error returned from the database server. return Err(PgDatabaseError(message.decode()?).into()); } - MessageFormat::NotificationResponse => { + BackendMessageFormat::NotificationResponse => { if let Some(buffer) = &mut self.notifications { let notification: Notification = message.decode()?; let _ = buffer.send(notification).await; @@ -116,7 +110,7 @@ impl PgStream { } } - MessageFormat::ParameterStatus => { + BackendMessageFormat::ParameterStatus => { // informs the frontend about the current (initial) // setting of backend parameters @@ -135,7 +129,7 @@ impl PgStream { continue; } - MessageFormat::NoticeResponse => { + BackendMessageFormat::NoticeResponse => { // do we need this to be more configurable? // if you are reading this comment and think so, open an issue diff --git a/sqlx-postgres/src/copy.rs b/sqlx-postgres/src/copy.rs index 98efbba0..347877c3 100644 --- a/sqlx-postgres/src/copy.rs +++ b/sqlx-postgres/src/copy.rs @@ -11,7 +11,8 @@ use crate::error::{Error, Result}; use crate::ext::async_stream::TryAsyncStream; use crate::io::AsyncRead; use crate::message::{ - CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query, + BackendMessageFormat, CommandComplete, CopyData, CopyDone, CopyFail, CopyInResponse, + CopyOutResponse, CopyResponseData, Query, ReadyForQuery, }; use crate::pool::{Pool, PoolConnection}; use crate::Postgres; @@ -138,7 +139,7 @@ impl PgPoolCopyExt for Pool { #[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"] pub struct PgCopyIn> { conn: Option, - response: CopyResponse, + response: CopyResponseData, } impl> PgCopyIn { @@ -146,8 +147,8 @@ impl> PgCopyIn { conn.wait_until_ready().await?; conn.stream.send(Query(statement)).await?; - let response = match conn.stream.recv_expect(MessageFormat::CopyInResponse).await { - Ok(res) => res, + let response = match conn.stream.recv_expect::().await { + Ok(res) => res.0, Err(e) => { conn.stream.recv().await?; return Err(e); @@ -168,7 +169,7 @@ impl> PgCopyIn { /// Returns the number of columns expected in the input. pub fn num_columns(&self) -> usize { assert_eq!( - self.response.num_columns as usize, + self.response.num_columns.unsigned_abs() as usize, self.response.format_codes.len(), "num_columns does not match format_codes.len()" ); @@ -261,9 +262,7 @@ impl> PgCopyIn { match e.code() { Some(Cow::Borrowed("57014")) => { // postgres abort received error code - conn.stream - .recv_expect(MessageFormat::ReadyForQuery) - .await?; + conn.stream.recv_expect::().await?; Ok(()) } _ => Err(Error::Database(e)), @@ -283,11 +282,7 @@ impl> PgCopyIn { .expect("CopyWriter::finish: conn taken illegally"); conn.stream.send(CopyDone).await?; - let cc: CommandComplete = match conn - .stream - .recv_expect(MessageFormat::CommandComplete) - .await - { + let cc: CommandComplete = match conn.stream.recv_expect().await { Ok(cc) => cc, Err(e) => { conn.stream.recv().await?; @@ -295,9 +290,7 @@ impl> PgCopyIn { } }; - conn.stream - .recv_expect(MessageFormat::ReadyForQuery) - .await?; + conn.stream.recv_expect::().await?; Ok(cc.rows_affected()) } @@ -306,9 +299,11 @@ impl> PgCopyIn { impl> Drop for PgCopyIn { fn drop(&mut self) { if let Some(mut conn) = self.conn.take() { - conn.stream.write(CopyFail::new( - "PgCopyIn dropped without calling finish() or fail()", - )); + conn.stream + .write_msg(CopyFail::new( + "PgCopyIn dropped without calling finish() or fail()", + )) + .expect("BUG: PgCopyIn abort message should not be too large"); } } } @@ -320,24 +315,21 @@ async fn pg_begin_copy_out<'c, C: DerefMut + Send + 'c>( conn.wait_until_ready().await?; conn.stream.send(Query(statement)).await?; - let _: CopyResponse = conn - .stream - .recv_expect(MessageFormat::CopyOutResponse) - .await?; + let _: CopyOutResponse = conn.stream.recv_expect().await?; let stream: TryAsyncStream<'c, Bytes> = try_stream! { loop { match conn.stream.recv().await { Err(e) => { - conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?; + conn.stream.recv_expect::().await?; return Err(e); }, Ok(msg) => match msg.format { - MessageFormat::CopyData => r#yield!(msg.decode::>()?.0), - MessageFormat::CopyDone => { + BackendMessageFormat::CopyData => r#yield!(msg.decode::>()?.0), + BackendMessageFormat::CopyDone => { let _ = msg.decode::()?; - conn.stream.recv_expect(MessageFormat::CommandComplete).await?; - conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?; + conn.stream.recv_expect::().await?; + conn.stream.recv_expect::().await?; return Ok(()) }, _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format)) diff --git a/sqlx-postgres/src/io/buf_mut.rs b/sqlx-postgres/src/io/buf_mut.rs index b5688f3b..ff6fe03d 100644 --- a/sqlx-postgres/src/io/buf_mut.rs +++ b/sqlx-postgres/src/io/buf_mut.rs @@ -1,54 +1,64 @@ -use crate::types::Oid; +use crate::io::{PortalId, StatementId}; pub trait PgBufMutExt { - fn put_length_prefixed(&mut self, f: F) + fn put_length_prefixed(&mut self, f: F) -> Result<(), crate::Error> where - F: FnOnce(&mut Vec); + F: FnOnce(&mut Vec) -> Result<(), crate::Error>; - fn put_statement_name(&mut self, id: Oid); + fn put_statement_name(&mut self, id: StatementId); - fn put_portal_name(&mut self, id: Option); + fn put_portal_name(&mut self, id: PortalId); } impl PgBufMutExt for Vec { // writes a length-prefixed message, this is used when encoding nearly all messages as postgres // wants us to send the length of the often-variable-sized messages up front - fn put_length_prefixed(&mut self, f: F) + fn put_length_prefixed(&mut self, write_contents: F) -> Result<(), crate::Error> where - F: FnOnce(&mut Vec), + F: FnOnce(&mut Vec) -> Result<(), crate::Error>, { // reserve space to write the prefixed length let offset = self.len(); self.extend(&[0; 4]); // write the main body of the message - f(self); + let write_result = write_contents(self); - // now calculate the size of what we wrote and set the length value - let size = (self.len() - offset) as i32; - self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes()); + let size_result = write_result.and_then(|_| { + let size = self.len() - offset; + i32::try_from(size) + .map_err(|_| err_protocol!("message size out of range for Pg protocol: {size")) + }); + + match size_result { + Ok(size) => { + // now calculate the size of what we wrote and set the length value + self[offset..(offset + 4)].copy_from_slice(&size.to_be_bytes()); + Ok(()) + } + Err(e) => { + // Put the buffer back to where it was. + self.truncate(offset); + Err(e) + } + } } // writes a statement name by ID #[inline] - fn put_statement_name(&mut self, id: Oid) { - // N.B. if you change this don't forget to update it in ../describe.rs - self.extend(b"sqlx_s_"); - - self.extend(itoa::Buffer::new().format(id.0).as_bytes()); - - self.push(0); + fn put_statement_name(&mut self, id: StatementId) { + let _: Result<(), ()> = id.write_name(|s| { + self.extend_from_slice(s.as_bytes()); + Ok(()) + }); } // writes a portal name by ID #[inline] - fn put_portal_name(&mut self, id: Option) { - if let Some(id) = id { - self.extend(b"sqlx_p_"); - - self.extend(itoa::Buffer::new().format(id.0).as_bytes()); - } - - self.push(0); + fn put_portal_name(&mut self, id: PortalId) { + let _: Result<(), ()> = id.write_name(|s| { + self.extend_from_slice(s.as_bytes()); + Ok(()) + }); } } diff --git a/sqlx-postgres/src/io/mod.rs b/sqlx-postgres/src/io/mod.rs index 1a6d0702..f90db85d 100644 --- a/sqlx-postgres/src/io/mod.rs +++ b/sqlx-postgres/src/io/mod.rs @@ -1,5 +1,130 @@ mod buf_mut; pub use buf_mut::PgBufMutExt; +use std::fmt; +use std::fmt::{Display, Formatter}; +use std::num::{NonZeroU32, Saturating}; pub(crate) use sqlx_core::io::*; + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub(crate) struct StatementId(IdInner); + +#[allow(dead_code)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub(crate) struct PortalId(IdInner); + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +struct IdInner(Option); + +impl StatementId { + pub const UNNAMED: Self = Self(IdInner::UNNAMED); + + pub const NAMED_START: Self = Self(IdInner::NAMED_START); + + #[cfg(test)] + pub const TEST_VAL: Self = Self(IdInner::TEST_VAL); + + const NAME_PREFIX: &'static str = "sqlx_s_"; + + pub fn next(&self) -> Self { + Self(self.0.next()) + } + + pub fn name_len(&self) -> Saturating { + self.0.name_len(Self::NAME_PREFIX) + } + + // There's no common trait implemented by `Formatter` and `Vec` for this purpose; + // we're deliberately avoiding the formatting machinery because it's known to be slow. + pub fn write_name(&self, write: impl FnMut(&str) -> Result<(), E>) -> Result<(), E> { + self.0.write_name(Self::NAME_PREFIX, write) + } +} + +impl Display for StatementId { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + self.write_name(|s| f.write_str(s)) + } +} + +#[allow(dead_code)] +impl PortalId { + // None selects the unnamed portal + pub const UNNAMED: Self = PortalId(IdInner::UNNAMED); + + pub const NAMED_START: Self = PortalId(IdInner::NAMED_START); + + #[cfg(test)] + pub const TEST_VAL: Self = Self(IdInner::TEST_VAL); + + const NAME_PREFIX: &'static str = "sqlx_p_"; + + /// If ID represents a named portal, return the next ID, wrapping on overflow. + /// + /// If this ID represents the unnamed portal, return the same. + pub fn next(&self) -> Self { + Self(self.0.next()) + } + + /// Calculate the number of bytes that will be written by [`Self::write_name()`]. + pub fn name_len(&self) -> Saturating { + self.0.name_len(Self::NAME_PREFIX) + } + + pub fn write_name(&self, write: impl FnMut(&str) -> Result<(), E>) -> Result<(), E> { + self.0.write_name(Self::NAME_PREFIX, write) + } +} + +impl IdInner { + const UNNAMED: Self = Self(None); + + const NAMED_START: Self = Self(Some(NonZeroU32::MIN)); + + #[cfg(test)] + pub const TEST_VAL: Self = Self(NonZeroU32::new(1234567890)); + + #[inline(always)] + fn next(&self) -> Self { + Self( + self.0 + .map(|id| id.checked_add(1).unwrap_or(NonZeroU32::MIN)), + ) + } + + #[inline(always)] + fn name_len(&self, name_prefix: &str) -> Saturating { + let mut len = Saturating(0); + + if let Some(id) = self.0 { + len += name_prefix.len(); + // estimate the length of the ID in decimal + // `.ilog10()` can't panic since the value is never zero + len += id.get().ilog10() as usize; + // add one to compensate for `ilog10()` rounding down. + len += 1; + } + + // count the NUL terminator + len += 1; + + len + } + + #[inline(always)] + fn write_name( + &self, + name_prefix: &str, + mut write: impl FnMut(&str) -> Result<(), E>, + ) -> Result<(), E> { + if let Some(id) = self.0 { + write(name_prefix)?; + write(itoa::Buffer::new().format(id.get()))?; + } + + write("\0")?; + + Ok(()) + } +} diff --git a/sqlx-postgres/src/listener.rs b/sqlx-postgres/src/listener.rs index ca4f78a2..43bd3c8f 100644 --- a/sqlx-postgres/src/listener.rs +++ b/sqlx-postgres/src/listener.rs @@ -11,7 +11,7 @@ use sqlx_core::Either; use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; -use crate::message::{MessageFormat, Notification}; +use crate::message::{BackendMessageFormat, Notification}; use crate::pool::PoolOptions; use crate::pool::{Pool, PoolConnection}; use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres}; @@ -277,12 +277,12 @@ impl PgListener { match message.format { // We've received an async notification, return it. - MessageFormat::NotificationResponse => { + BackendMessageFormat::NotificationResponse => { return Ok(Some(PgNotification(message.decode()?))); } // Mark the connection as ready for another query - MessageFormat::ReadyForQuery => { + BackendMessageFormat::ReadyForQuery => { self.connection().await?.pending_ready_for_query_count -= 1; } diff --git a/sqlx-postgres/src/message/authentication.rs b/sqlx-postgres/src/message/authentication.rs index 2e55c11f..3a3cf7ff 100644 --- a/sqlx-postgres/src/message/authentication.rs +++ b/sqlx-postgres/src/message/authentication.rs @@ -4,10 +4,10 @@ use memchr::memchr; use sqlx_core::bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::Decode; +use crate::io::ProtocolDecode; +use crate::message::{BackendMessage, BackendMessageFormat}; use base64::prelude::{Engine as _, BASE64_STANDARD}; - // On startup, the server sends an appropriate authentication request message, // to which the frontend must reply with an appropriate authentication // response message (such as a password). @@ -60,8 +60,10 @@ pub enum Authentication { SaslFinal(AuthenticationSaslFinal), } -impl Decode<'_> for Authentication { - fn decode_with(mut buf: Bytes, _: ()) -> Result { +impl BackendMessage for Authentication { + const FORMAT: BackendMessageFormat = BackendMessageFormat::Authentication; + + fn decode_body(mut buf: Bytes) -> Result { Ok(match buf.get_u32() { 0 => Authentication::Ok, @@ -129,7 +131,7 @@ pub struct AuthenticationSaslContinue { pub message: String, } -impl Decode<'_> for AuthenticationSaslContinue { +impl ProtocolDecode<'_> for AuthenticationSaslContinue { fn decode_with(buf: Bytes, _: ()) -> Result { let mut iterations: u32 = 4096; let mut salt = Vec::new(); @@ -173,7 +175,7 @@ pub struct AuthenticationSaslFinal { pub verifier: Vec, } -impl Decode<'_> for AuthenticationSaslFinal { +impl ProtocolDecode<'_> for AuthenticationSaslFinal { fn decode_with(buf: Bytes, _: ()) -> Result { let mut verifier = Vec::new(); diff --git a/sqlx-postgres/src/message/backend_key_data.rs b/sqlx-postgres/src/message/backend_key_data.rs index d89df65f..f2dc2f23 100644 --- a/sqlx-postgres/src/message/backend_key_data.rs +++ b/sqlx-postgres/src/message/backend_key_data.rs @@ -2,7 +2,7 @@ use byteorder::{BigEndian, ByteOrder}; use sqlx_core::bytes::Bytes; use crate::error::Error; -use crate::io::Decode; +use crate::message::{BackendMessage, BackendMessageFormat}; /// Contains cancellation key data. The frontend must save these values if it /// wishes to be able to issue `CancelRequest` messages later. @@ -15,8 +15,10 @@ pub struct BackendKeyData { pub secret_key: u32, } -impl Decode<'_> for BackendKeyData { - fn decode_with(buf: Bytes, _: ()) -> Result { +impl BackendMessage for BackendKeyData { + const FORMAT: BackendMessageFormat = BackendMessageFormat::BackendKeyData; + + fn decode_body(buf: Bytes) -> Result { let process_id = BigEndian::read_u32(&buf); let secret_key = BigEndian::read_u32(&buf[4..]); @@ -31,7 +33,7 @@ impl Decode<'_> for BackendKeyData { fn test_decode_backend_key_data() { const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; - let m = BackendKeyData::decode(DATA.into()).unwrap(); + let m = BackendKeyData::decode_body(DATA.into()).unwrap(); assert_eq!(m.process_id, 10182); assert_eq!(m.secret_key, 2303903019); @@ -43,6 +45,6 @@ fn bench_decode_backend_key_data(b: &mut test::Bencher) { const DATA: &[u8] = b"\0\0'\xc6\x89R\xc5+"; b.iter(|| { - BackendKeyData::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); + BackendKeyData::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap(); }); } diff --git a/sqlx-postgres/src/message/bind.rs b/sqlx-postgres/src/message/bind.rs index b8db9679..83631fea 100644 --- a/sqlx-postgres/src/message/bind.rs +++ b/sqlx-postgres/src/message/bind.rs @@ -1,15 +1,15 @@ -use crate::io::Encode; -use crate::io::PgBufMutExt; -use crate::types::Oid; +use crate::io::{PgBufMutExt, PortalId, StatementId}; +use crate::message::{FrontendMessage, FrontendMessageFormat}; use crate::PgValueFormat; +use std::num::Saturating; #[derive(Debug)] pub struct Bind<'a> { - /// The ID of the destination portal (`None` selects the unnamed portal). - pub portal: Option, + /// The ID of the destination portal (`PortalId::UNNAMED` selects the unnamed portal). + pub portal: PortalId, /// The id of the source prepared statement. - pub statement: Oid, + pub statement: StatementId, /// The parameter format codes. Each must presently be zero (text) or one (binary). /// @@ -19,6 +19,8 @@ pub struct Bind<'a> { pub formats: &'a [PgValueFormat], /// The number of parameters. + /// + /// May be different from `formats.len()` pub num_params: i16, /// The value of each parameter, in the indicated format. @@ -33,31 +35,59 @@ pub struct Bind<'a> { pub result_formats: &'a [PgValueFormat], } -impl Encode<'_> for Bind<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.push(b'B'); +impl FrontendMessage for Bind<'_> { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Bind; - buf.put_length_prefixed(|buf| { - buf.put_portal_name(self.portal); + fn body_size_hint(&self) -> Saturating { + let mut size = Saturating(0); + size += self.portal.name_len(); + size += self.statement.name_len(); - buf.put_statement_name(self.statement); + // Parameter formats and length prefix + size += 2; + size += self.formats.len(); - buf.extend(&(self.formats.len() as i16).to_be_bytes()); + // `num_params` + size += 2; - for &format in self.formats { - buf.extend(&(format as i16).to_be_bytes()); - } + size += self.params.len(); - buf.extend(&self.num_params.to_be_bytes()); + // Result formats and length prefix + size += 2; + size += self.result_formats.len(); - buf.extend(self.params); + size + } - buf.extend(&(self.result_formats.len() as i16).to_be_bytes()); + fn encode_body(&self, buf: &mut Vec) -> Result<(), crate::Error> { + buf.put_portal_name(self.portal); - for &format in self.result_formats { - buf.extend(&(format as i16).to_be_bytes()); - } - }); + buf.put_statement_name(self.statement); + + let formats_len = i16::try_from(self.formats.len()).map_err(|_| { + err_protocol!("too many parameter format codes ({})", self.formats.len()) + })?; + + buf.extend(formats_len.to_be_bytes()); + + for &format in self.formats { + buf.extend((format as i16).to_be_bytes()); + } + + buf.extend(self.num_params.to_be_bytes()); + + buf.extend(self.params); + + let result_formats_len = i16::try_from(self.formats.len()) + .map_err(|_| err_protocol!("too many result format codes ({})", self.formats.len()))?; + + buf.extend(result_formats_len.to_be_bytes()); + + for &format in self.result_formats { + buf.extend((format as i16).to_be_bytes()); + } + + Ok(()) } } diff --git a/sqlx-postgres/src/message/close.rs b/sqlx-postgres/src/message/close.rs index 0ffa638c..172f244c 100644 --- a/sqlx-postgres/src/message/close.rs +++ b/sqlx-postgres/src/message/close.rs @@ -1,6 +1,6 @@ -use crate::io::Encode; -use crate::io::PgBufMutExt; -use crate::types::Oid; +use crate::io::{PgBufMutExt, PortalId, StatementId}; +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use std::num::Saturating; const CLOSE_PORTAL: u8 = b'P'; const CLOSE_STATEMENT: u8 = b'S'; @@ -8,18 +8,27 @@ const CLOSE_STATEMENT: u8 = b'S'; #[derive(Debug)] #[allow(dead_code)] pub enum Close { - Statement(Oid), - // None selects the unnamed portal - Portal(Option), + Statement(StatementId), + Portal(PortalId), } -impl Encode<'_> for Close { - fn encode_with(&self, buf: &mut Vec, _: ()) { - // 15 bytes for 1-digit statement/portal IDs - buf.reserve(20); - buf.push(b'C'); +impl FrontendMessage for Close { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Close; - buf.put_length_prefixed(|buf| match self { + fn body_size_hint(&self) -> Saturating { + // Either `CLOSE_PORTAL` or `CLOSE_STATEMENT` + let mut size = Saturating(1); + + match self { + Close::Statement(id) => size += id.name_len(), + Close::Portal(id) => size += id.name_len(), + } + + size + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), crate::Error> { + match self { Close::Statement(id) => { buf.push(CLOSE_STATEMENT); buf.put_statement_name(*id); @@ -29,6 +38,8 @@ impl Encode<'_> for Close { buf.push(CLOSE_PORTAL); buf.put_portal_name(*id); } - }) + } + + Ok(()) } } diff --git a/sqlx-postgres/src/message/command_complete.rs b/sqlx-postgres/src/message/command_complete.rs index c2c8e158..eb33c512 100644 --- a/sqlx-postgres/src/message/command_complete.rs +++ b/sqlx-postgres/src/message/command_complete.rs @@ -3,7 +3,7 @@ use memchr::memrchr; use sqlx_core::bytes::Bytes; use crate::error::Error; -use crate::io::Decode; +use crate::message::{BackendMessage, BackendMessageFormat}; #[derive(Debug)] pub struct CommandComplete { @@ -12,10 +12,11 @@ pub struct CommandComplete { tag: Bytes, } -impl Decode<'_> for CommandComplete { - #[inline] - fn decode_with(buf: Bytes, _: ()) -> Result { - Ok(CommandComplete { tag: buf }) +impl BackendMessage for CommandComplete { + const FORMAT: BackendMessageFormat = BackendMessageFormat::CommandComplete; + + fn decode_body(bytes: Bytes) -> Result { + Ok(CommandComplete { tag: bytes }) } } @@ -35,7 +36,7 @@ impl CommandComplete { fn test_decode_command_complete_for_insert() { const DATA: &[u8] = b"INSERT 0 1214\0"; - let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); + let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap(); assert_eq!(cc.rows_affected(), 1214); } @@ -44,7 +45,7 @@ fn test_decode_command_complete_for_insert() { fn test_decode_command_complete_for_begin() { const DATA: &[u8] = b"BEGIN\0"; - let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); + let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap(); assert_eq!(cc.rows_affected(), 0); } @@ -53,7 +54,7 @@ fn test_decode_command_complete_for_begin() { fn test_decode_command_complete_for_update() { const DATA: &[u8] = b"UPDATE 5\0"; - let cc = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); + let cc = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap(); assert_eq!(cc.rows_affected(), 5); } @@ -64,7 +65,7 @@ fn bench_decode_command_complete(b: &mut test::Bencher) { const DATA: &[u8] = b"INSERT 0 1214\0"; b.iter(|| { - let _ = CommandComplete::decode(test::black_box(Bytes::from_static(DATA))); + let _ = CommandComplete::decode_body(test::black_box(Bytes::from_static(DATA))); }); } @@ -73,7 +74,7 @@ fn bench_decode_command_complete(b: &mut test::Bencher) { fn bench_decode_command_complete_rows_affected(b: &mut test::Bencher) { const DATA: &[u8] = b"INSERT 0 1214\0"; - let data = CommandComplete::decode(Bytes::from_static(DATA)).unwrap(); + let data = CommandComplete::decode_body(Bytes::from_static(DATA)).unwrap(); b.iter(|| { let _rows = test::black_box(&data).rows_affected(); diff --git a/sqlx-postgres/src/message/copy.rs b/sqlx-postgres/src/message/copy.rs index db0e7398..837d849a 100644 --- a/sqlx-postgres/src/message/copy.rs +++ b/sqlx-postgres/src/message/copy.rs @@ -1,15 +1,25 @@ use crate::error::Result; -use crate::io::{BufExt, BufMutExt, Decode, Encode}; -use sqlx_core::bytes::{Buf, BufMut, Bytes}; +use crate::io::BufMutExt; +use crate::message::{ + BackendMessage, BackendMessageFormat, FrontendMessage, FrontendMessageFormat, +}; +use sqlx_core::bytes::{Buf, Bytes}; +use sqlx_core::Error; +use std::num::Saturating; use std::ops::Deref; /// The same structure is sent for both `CopyInResponse` and `CopyOutResponse` -pub struct CopyResponse { +pub struct CopyResponseData { pub format: i8, pub num_columns: i16, pub format_codes: Vec, } +pub struct CopyInResponse(pub CopyResponseData); + +#[allow(dead_code)] +pub struct CopyOutResponse(pub CopyResponseData); + pub struct CopyData(pub B); pub struct CopyFail { @@ -18,14 +28,15 @@ pub struct CopyFail { pub struct CopyDone; -impl Decode<'_> for CopyResponse { - fn decode_with(mut buf: Bytes, _: ()) -> Result { +impl CopyResponseData { + #[inline] + fn decode(mut buf: Bytes) -> Result { let format = buf.get_i8(); let num_columns = buf.get_i16(); let format_codes = (0..num_columns).map(|_| buf.get_i16()).collect(); - Ok(CopyResponse { + Ok(CopyResponseData { format, num_columns, format_codes, @@ -33,40 +44,65 @@ impl Decode<'_> for CopyResponse { } } -impl Decode<'_> for CopyData { - fn decode_with(buf: Bytes, _: ()) -> Result { - // well.. that was easy - Ok(CopyData(buf)) +impl BackendMessage for CopyInResponse { + const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyInResponse; + + #[inline(always)] + fn decode_body(buf: Bytes) -> std::result::Result { + Ok(Self(CopyResponseData::decode(buf)?)) } } -impl> Encode<'_> for CopyData { - fn encode_with(&self, buf: &mut Vec, _context: ()) { - buf.push(b'd'); - buf.put_u32(self.0.len() as u32 + 4); +impl BackendMessage for CopyOutResponse { + const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyOutResponse; + + #[inline(always)] + fn decode_body(buf: Bytes) -> std::result::Result { + Ok(Self(CopyResponseData::decode(buf)?)) + } +} + +impl BackendMessage for CopyData { + const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyData; + + #[inline(always)] + fn decode_body(buf: Bytes) -> std::result::Result { + Ok(Self(buf)) + } +} + +impl> FrontendMessage for CopyData { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyData; + + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + Saturating(self.0.len()) + } + + #[inline(always)] + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { buf.extend_from_slice(&self.0); + Ok(()) } } -impl Decode<'_> for CopyFail { - fn decode_with(mut buf: Bytes, _: ()) -> Result { - Ok(CopyFail { - message: buf.get_str_nul()?, - }) +impl FrontendMessage for CopyFail { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyFail; + + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + Saturating(self.message.len()) } -} -impl Encode<'_> for CopyFail { - fn encode_with(&self, buf: &mut Vec, _: ()) { - let len = 4 + self.message.len() + 1; - - buf.push(b'f'); // to pay respects - buf.put_u32(len as u32); + #[inline(always)] + fn encode_body(&self, buf: &mut Vec) -> std::result::Result<(), Error> { buf.put_str_nul(&self.message); + Ok(()) } } impl CopyFail { + #[inline(always)] pub fn new(msg: impl Into) -> CopyFail { CopyFail { message: msg.into(), @@ -74,23 +110,32 @@ impl CopyFail { } } -impl Decode<'_> for CopyDone { - fn decode_with(buf: Bytes, _: ()) -> Result { - if buf.is_empty() { - Ok(CopyDone) - } else { - Err(err_protocol!( - "expected no data for CopyDone, got: {:?}", - buf - )) - } +impl FrontendMessage for CopyDone { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::CopyDone; + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + Saturating(0) + } + + #[inline(always)] + fn encode_body(&self, _buf: &mut Vec) -> std::result::Result<(), Error> { + Ok(()) } } -impl Encode<'_> for CopyDone { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.reserve(4); - buf.push(b'c'); - buf.put_u32(4); +impl BackendMessage for CopyDone { + const FORMAT: BackendMessageFormat = BackendMessageFormat::CopyDone; + + #[inline(always)] + fn decode_body(bytes: Bytes) -> std::result::Result { + if !bytes.is_empty() { + // Not fatal but may indicate a protocol change + tracing::debug!( + "Postgres backend returned non-empty message for CopyDone: \"{}\"", + bytes.escape_ascii() + ) + } + + Ok(CopyDone) } } diff --git a/sqlx-postgres/src/message/data_row.rs b/sqlx-postgres/src/message/data_row.rs index 3e08d22f..ae9d0d9b 100644 --- a/sqlx-postgres/src/message/data_row.rs +++ b/sqlx-postgres/src/message/data_row.rs @@ -1,10 +1,9 @@ -use std::ops::Range; - use byteorder::{BigEndian, ByteOrder}; use sqlx_core::bytes::Bytes; +use std::ops::Range; use crate::error::Error; -use crate::io::Decode; +use crate::message::{BackendMessage, BackendMessageFormat}; /// A row of data from the database. #[derive(Debug)] @@ -26,25 +25,55 @@ impl DataRow { } } -impl Decode<'_> for DataRow { - fn decode_with(buf: Bytes, _: ()) -> Result { +impl BackendMessage for DataRow { + const FORMAT: BackendMessageFormat = BackendMessageFormat::DataRow; + + fn decode_body(buf: Bytes) -> Result { + if buf.len() < 2 { + return Err(err_protocol!( + "expected at least 2 bytes, got {}", + buf.len() + )); + } + let cnt = BigEndian::read_u16(&buf) as usize; let mut values = Vec::with_capacity(cnt); - let mut offset = 2; + let mut offset: u32 = 2; for _ in 0..cnt { + let value_start = offset + .checked_add(4) + .ok_or_else(|| err_protocol!("next value start out of range (offset: {offset})"))?; + + // widen both to a larger type for a safe comparison + if (buf.len() as u64) < (value_start as u64) { + return Err(err_protocol!( + "expected 4 bytes at offset {offset}, got {}", + (value_start as u64) - (buf.len() as u64) + )); + } + // Length of the column value, in bytes (this count does not include itself). // Can be zero. As a special case, -1 indicates a NULL column value. // No value bytes follow in the NULL case. + // + // we know `offset` is within range of `buf.len()` from the above check + #[allow(clippy::cast_possible_truncation)] let length = BigEndian::read_i32(&buf[(offset as usize)..]); - offset += 4; - if length < 0 { - values.push(None); + if let Ok(length) = u32::try_from(length) { + let value_end = value_start.checked_add(length).ok_or_else(|| { + err_protocol!("value_start + length out of range ({offset} + {length})") + })?; + + values.push(Some(value_start..value_end)); + offset = value_end; } else { - values.push(Some(offset..(offset + length as u32))); - offset += length as u32; + // Negative values signify NULL + values.push(None); + // `value_start` is actually the next value now. + offset = value_start; } } @@ -57,9 +86,22 @@ impl Decode<'_> for DataRow { #[test] fn test_decode_data_row() { - const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P"; + const DATA: &[u8] = b"\ + \x00\x08\ + \xff\xff\xff\xff\ + \x00\x00\x00\x04\ + \x00\x00\x00\n\ + \xff\xff\xff\xff\ + \x00\x00\x00\x04\ + \x00\x00\x00\x14\ + \xff\xff\xff\xff\ + \x00\x00\x00\x04\ + \x00\x00\x00(\ + \xff\xff\xff\xff\ + \x00\x00\x00\x04\ + \x00\x00\x00P"; - let row = DataRow::decode(DATA.into()).unwrap(); + let row = DataRow::decode_body(DATA.into()).unwrap(); assert_eq!(row.values.len(), 8); @@ -78,7 +120,7 @@ fn test_decode_data_row() { fn bench_data_row_get(b: &mut test::Bencher) { const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P"; - let row = DataRow::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); + let row = DataRow::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap(); b.iter(|| { let _value = test::black_box(&row).get(3); @@ -91,6 +133,6 @@ fn bench_decode_data_row(b: &mut test::Bencher) { const DATA: &[u8] = b"\x00\x08\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\n\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00\x14\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00(\xff\xff\xff\xff\x00\x00\x00\x04\x00\x00\x00P"; b.iter(|| { - let _ = DataRow::decode(test::black_box(Bytes::from_static(DATA))); + let _ = DataRow::decode_body(test::black_box(Bytes::from_static(DATA))); }); } diff --git a/sqlx-postgres/src/message/describe.rs b/sqlx-postgres/src/message/describe.rs index 382f6e70..d6ea7e89 100644 --- a/sqlx-postgres/src/message/describe.rs +++ b/sqlx-postgres/src/message/describe.rs @@ -1,127 +1,103 @@ -use crate::io::Encode; -use crate::io::PgBufMutExt; -use crate::types::Oid; +use crate::io::{PgBufMutExt, PortalId, StatementId}; +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use sqlx_core::Error; +use std::num::Saturating; const DESCRIBE_PORTAL: u8 = b'P'; const DESCRIBE_STATEMENT: u8 = b'S'; -// [Describe] will emit both a [RowDescription] and a [ParameterDescription] message - +/// Note: will emit both a RowDescription and a ParameterDescription message #[derive(Debug)] #[allow(dead_code)] pub enum Describe { - UnnamedStatement, - Statement(Oid), - - UnnamedPortal, - Portal(Oid), + Statement(StatementId), + Portal(PortalId), } -impl Encode<'_> for Describe { - fn encode_with(&self, buf: &mut Vec, _: ()) { - // 15 bytes for 1-digit statement/portal IDs - buf.reserve(20); - buf.push(b'D'); +impl FrontendMessage for Describe { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Describe; - buf.put_length_prefixed(|buf| { - match self { - // #[likely] - Describe::Statement(id) => { - buf.push(DESCRIBE_STATEMENT); - buf.put_statement_name(*id); - } + fn body_size_hint(&self) -> Saturating { + // Either `DESCRIBE_PORTAL` or `DESCRIBE_STATEMENT` + let mut size = Saturating(1); - Describe::UnnamedPortal => { - buf.push(DESCRIBE_PORTAL); - buf.push(0); - } + match self { + Describe::Statement(id) => size += id.name_len(), + Describe::Portal(id) => size += id.name_len(), + } - Describe::UnnamedStatement => { - buf.push(DESCRIBE_STATEMENT); - buf.push(0); - } + size + } - Describe::Portal(id) => { - buf.push(DESCRIBE_PORTAL); - buf.put_portal_name(Some(*id)); - } + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { + match self { + // #[likely] + Describe::Statement(id) => { + buf.push(DESCRIBE_STATEMENT); + buf.put_statement_name(*id); } - }); + + Describe::Portal(id) => { + buf.push(DESCRIBE_PORTAL); + buf.put_portal_name(*id); + } + } + + Ok(()) } } -#[test] -fn test_encode_describe_portal() { - const EXPECTED: &[u8] = b"D\0\0\0\x0EPsqlx_p_5\0"; +#[cfg(test)] +mod tests { + use crate::message::FrontendMessage; - let mut buf = Vec::new(); - let m = Describe::Portal(Oid(5)); + use super::{Describe, PortalId, StatementId}; - m.encode(&mut buf); + #[test] + fn test_encode_describe_portal() { + const EXPECTED: &[u8] = b"D\0\0\0\x17Psqlx_p_1234567890\0"; - assert_eq!(buf, EXPECTED); -} - -#[test] -fn test_encode_describe_unnamed_portal() { - const EXPECTED: &[u8] = b"D\0\0\0\x06P\0"; - - let mut buf = Vec::new(); - let m = Describe::UnnamedPortal; - - m.encode(&mut buf); - - assert_eq!(buf, EXPECTED); -} - -#[test] -fn test_encode_describe_statement() { - const EXPECTED: &[u8] = b"D\0\0\0\x0ESsqlx_s_5\0"; - - let mut buf = Vec::new(); - let m = Describe::Statement(Oid(5)); - - m.encode(&mut buf); - - assert_eq!(buf, EXPECTED); -} - -#[test] -fn test_encode_describe_unnamed_statement() { - const EXPECTED: &[u8] = b"D\0\0\0\x06S\0"; - - let mut buf = Vec::new(); - let m = Describe::UnnamedStatement; - - m.encode(&mut buf); - - assert_eq!(buf, EXPECTED); -} - -#[cfg(all(test, not(debug_assertions)))] -#[bench] -fn bench_encode_describe_portal(b: &mut test::Bencher) { - use test::black_box; - - let mut buf = Vec::with_capacity(128); - - b.iter(|| { - buf.clear(); - - black_box(Describe::Portal(5)).encode(&mut buf); - }); -} - -#[cfg(all(test, not(debug_assertions)))] -#[bench] -fn bench_encode_describe_unnamed_statement(b: &mut test::Bencher) { - use test::black_box; - - let mut buf = Vec::with_capacity(128); - - b.iter(|| { - buf.clear(); - - black_box(Describe::UnnamedStatement).encode(&mut buf); - }); + let mut buf = Vec::new(); + let m = Describe::Portal(PortalId::TEST_VAL); + + m.encode_msg(&mut buf).unwrap(); + + assert_eq!(buf, EXPECTED); + } + + #[test] + fn test_encode_describe_unnamed_portal() { + const EXPECTED: &[u8] = b"D\0\0\0\x06P\0"; + + let mut buf = Vec::new(); + let m = Describe::Portal(PortalId::UNNAMED); + + m.encode_msg(&mut buf).unwrap(); + + assert_eq!(buf, EXPECTED); + } + + #[test] + fn test_encode_describe_statement() { + const EXPECTED: &[u8] = b"D\0\0\0\x17Ssqlx_s_1234567890\0"; + + let mut buf = Vec::new(); + let m = Describe::Statement(StatementId::TEST_VAL); + + m.encode_msg(&mut buf).unwrap(); + + assert_eq!(buf, EXPECTED); + } + + #[test] + fn test_encode_describe_unnamed_statement() { + const EXPECTED: &[u8] = b"D\0\0\0\x06S\0"; + + let mut buf = Vec::new(); + let m = Describe::Statement(StatementId::UNNAMED); + + m.encode_msg(&mut buf).unwrap(); + + assert_eq!(buf, EXPECTED); + } } diff --git a/sqlx-postgres/src/message/execute.rs b/sqlx-postgres/src/message/execute.rs index 3550ae78..f82b7884 100644 --- a/sqlx-postgres/src/message/execute.rs +++ b/sqlx-postgres/src/message/execute.rs @@ -1,39 +1,73 @@ -use crate::io::Encode; -use crate::io::PgBufMutExt; -use crate::types::Oid; +use std::num::Saturating; + +use sqlx_core::Error; + +use crate::io::{PgBufMutExt, PortalId}; +use crate::message::{FrontendMessage, FrontendMessageFormat}; pub struct Execute { - /// The id of the portal to execute (`None` selects the unnamed portal). - pub portal: Option, + /// The id of the portal to execute. + pub portal: PortalId, /// Maximum number of rows to return, if portal contains a query /// that returns rows (ignored otherwise). Zero denotes “no limit”. pub limit: u32, } -impl Encode<'_> for Execute { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.reserve(20); - buf.push(b'E'); +impl FrontendMessage for Execute { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Execute; - buf.put_length_prefixed(|buf| { - buf.put_portal_name(self.portal); - buf.extend(&self.limit.to_be_bytes()); - }); + fn body_size_hint(&self) -> Saturating { + let mut size = Saturating(0); + + size += self.portal.name_len(); + size += 2; // limit + + size + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { + buf.put_portal_name(self.portal); + buf.extend(&self.limit.to_be_bytes()); + + Ok(()) } } -#[test] -fn test_encode_execute() { - const EXPECTED: &[u8] = b"E\0\0\0\x11sqlx_p_5\0\0\0\0\x02"; +#[cfg(test)] +mod tests { + use crate::io::PortalId; + use crate::message::FrontendMessage; - let mut buf = Vec::new(); - let m = Execute { - portal: Some(Oid(5)), - limit: 2, - }; + use super::Execute; - m.encode(&mut buf); + #[test] + fn test_encode_execute_named_portal() { + const EXPECTED: &[u8] = b"E\0\0\0\x1Asqlx_p_1234567890\0\0\0\0\x02"; - assert_eq!(buf, EXPECTED); + let mut buf = Vec::new(); + let m = Execute { + portal: PortalId::TEST_VAL, + limit: 2, + }; + + m.encode_msg(&mut buf).unwrap(); + + assert_eq!(buf, EXPECTED); + } + + #[test] + fn test_encode_execute_unnamed_portal() { + const EXPECTED: &[u8] = b"E\0\0\0\x09\0\x49\x96\x02\xD2"; + + let mut buf = Vec::new(); + let m = Execute { + portal: PortalId::UNNAMED, + limit: 1234567890, + }; + + m.encode_msg(&mut buf).unwrap(); + + assert_eq!(buf, EXPECTED); + } } diff --git a/sqlx-postgres/src/message/flush.rs b/sqlx-postgres/src/message/flush.rs index fc21d3f1..d1dfabbf 100644 --- a/sqlx-postgres/src/message/flush.rs +++ b/sqlx-postgres/src/message/flush.rs @@ -1,17 +1,25 @@ -use crate::io::Encode; - -// The Flush message does not cause any specific output to be generated, -// but forces the backend to deliver any data pending in its output buffers. - -// A Flush must be sent after any extended-query command except Sync, if the -// frontend wishes to examine the results of that command before issuing more commands. +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use sqlx_core::Error; +use std::num::Saturating; +/// The Flush message does not cause any specific output to be generated, +/// but forces the backend to deliver any data pending in its output buffers. +/// +/// A Flush must be sent after any extended-query command except Sync, if the +/// frontend wishes to examine the results of that command before issuing more commands. #[derive(Debug)] pub struct Flush; -impl Encode<'_> for Flush { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.push(b'H'); - buf.extend(&4_i32.to_be_bytes()); +impl FrontendMessage for Flush { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Flush; + + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + Saturating(0) + } + + #[inline(always)] + fn encode_body(&self, _buf: &mut Vec) -> Result<(), Error> { + Ok(()) } } diff --git a/sqlx-postgres/src/message/mod.rs b/sqlx-postgres/src/message/mod.rs index ef1dbfab..e62f9beb 100644 --- a/sqlx-postgres/src/message/mod.rs +++ b/sqlx-postgres/src/message/mod.rs @@ -1,7 +1,8 @@ use sqlx_core::bytes::Bytes; +use std::num::Saturating; use crate::error::Error; -use crate::io::Decode; +use crate::io::PgBufMutExt; mod authentication; mod backend_key_data; @@ -17,6 +18,7 @@ mod notification; mod parameter_description; mod parameter_status; mod parse; +mod parse_complete; mod password; mod query; mod ready_for_query; @@ -33,7 +35,7 @@ pub use backend_key_data::BackendKeyData; pub use bind::Bind; pub use close::Close; pub use command_complete::CommandComplete; -pub use copy::{CopyData, CopyDone, CopyFail, CopyResponse}; +pub use copy::{CopyData, CopyDone, CopyFail, CopyInResponse, CopyOutResponse, CopyResponseData}; pub use data_row::DataRow; pub use describe::Describe; pub use execute::Execute; @@ -43,20 +45,51 @@ pub use notification::Notification; pub use parameter_description::ParameterDescription; pub use parameter_status::ParameterStatus; pub use parse::Parse; +pub use parse_complete::ParseComplete; pub use password::Password; pub use query::Query; pub use ready_for_query::{ReadyForQuery, TransactionStatus}; pub use response::{Notice, PgSeverity}; pub use row_description::RowDescription; pub use sasl::{SaslInitialResponse, SaslResponse}; +use sqlx_core::io::ProtocolEncode; pub use ssl_request::SslRequest; pub use startup::Startup; pub use sync::Sync; pub use terminate::Terminate; +// Note: we can't use the same enum for both frontend and backend message formats +// because there are duplicated format codes between them. +// +// For example, `Close` (frontend) and `CommandComplete` (backend) both use format code `C`. +// #[derive(Debug, PartialOrd, PartialEq)] #[repr(u8)] -pub enum MessageFormat { +pub enum FrontendMessageFormat { + Bind = b'B', + Close = b'C', + CopyData = b'd', + CopyDone = b'c', + CopyFail = b'f', + Describe = b'D', + Execute = b'E', + Flush = b'H', + Parse = b'P', + /// This message format is polymorphic. It's used for: + /// + /// * Plain password responses + /// * MD5 password responses + /// * SASL responses + /// * GSSAPI/SSPI responses + PasswordPolymorphic = b'p', + Query = b'Q', + Sync = b'S', + Terminate = b'X', +} + +#[derive(Debug, PartialOrd, PartialEq)] +#[repr(u8)] +pub enum BackendMessageFormat { Authentication, BackendKeyData, BindComplete, @@ -81,49 +114,116 @@ pub enum MessageFormat { } #[derive(Debug)] -pub struct Message { - pub format: MessageFormat, +pub struct ReceivedMessage { + pub format: BackendMessageFormat, pub contents: Bytes, } -impl Message { +impl ReceivedMessage { #[inline] - pub fn decode<'de, T>(self) -> Result + pub fn decode(self) -> Result where - T: Decode<'de>, + T: BackendMessage, { - T::decode(self.contents) + if T::FORMAT != self.format { + return Err(err_protocol!( + "Postgres protocol error: expected {:?}, got {:?}", + T::FORMAT, + self.format + )); + } + + T::decode_body(self.contents).map_err(|e| match e { + Error::Protocol(s) => { + err_protocol!("Postgres protocol error (reading {:?}): {s}", self.format) + } + other => other, + }) } } -impl MessageFormat { +impl BackendMessageFormat { pub fn try_from_u8(v: u8) -> Result { // https://www.postgresql.org/docs/current/protocol-message-formats.html Ok(match v { - b'1' => MessageFormat::ParseComplete, - b'2' => MessageFormat::BindComplete, - b'3' => MessageFormat::CloseComplete, - b'C' => MessageFormat::CommandComplete, - b'd' => MessageFormat::CopyData, - b'c' => MessageFormat::CopyDone, - b'G' => MessageFormat::CopyInResponse, - b'H' => MessageFormat::CopyOutResponse, - b'D' => MessageFormat::DataRow, - b'E' => MessageFormat::ErrorResponse, - b'I' => MessageFormat::EmptyQueryResponse, - b'A' => MessageFormat::NotificationResponse, - b'K' => MessageFormat::BackendKeyData, - b'N' => MessageFormat::NoticeResponse, - b'R' => MessageFormat::Authentication, - b'S' => MessageFormat::ParameterStatus, - b'T' => MessageFormat::RowDescription, - b'Z' => MessageFormat::ReadyForQuery, - b'n' => MessageFormat::NoData, - b's' => MessageFormat::PortalSuspended, - b't' => MessageFormat::ParameterDescription, + b'1' => BackendMessageFormat::ParseComplete, + b'2' => BackendMessageFormat::BindComplete, + b'3' => BackendMessageFormat::CloseComplete, + b'C' => BackendMessageFormat::CommandComplete, + b'd' => BackendMessageFormat::CopyData, + b'c' => BackendMessageFormat::CopyDone, + b'G' => BackendMessageFormat::CopyInResponse, + b'H' => BackendMessageFormat::CopyOutResponse, + b'D' => BackendMessageFormat::DataRow, + b'E' => BackendMessageFormat::ErrorResponse, + b'I' => BackendMessageFormat::EmptyQueryResponse, + b'A' => BackendMessageFormat::NotificationResponse, + b'K' => BackendMessageFormat::BackendKeyData, + b'N' => BackendMessageFormat::NoticeResponse, + b'R' => BackendMessageFormat::Authentication, + b'S' => BackendMessageFormat::ParameterStatus, + b'T' => BackendMessageFormat::RowDescription, + b'Z' => BackendMessageFormat::ReadyForQuery, + b'n' => BackendMessageFormat::NoData, + b's' => BackendMessageFormat::PortalSuspended, + b't' => BackendMessageFormat::ParameterDescription, _ => return Err(err_protocol!("unknown message type: {:?}", v as char)), }) } } + +pub(crate) trait FrontendMessage: Sized { + /// The format prefix of this message. + const FORMAT: FrontendMessageFormat; + + /// Return the amount of space, in bytes, to reserve in the buffer passed to [`Self::encode_body()`]. + fn body_size_hint(&self) -> Saturating; + + /// Encode this type as a Frontend message in the Postgres protocol. + /// + /// The implementation should *not* include `Self::FORMAT` or the length prefix. + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error>; + + #[inline(always)] + #[cfg_attr(not(test), allow(dead_code))] + fn encode_msg(self, buf: &mut Vec) -> Result<(), Error> { + EncodeMessage(self).encode(buf) + } +} + +pub(crate) trait BackendMessage: Sized { + /// The expected message format. + /// + /// + const FORMAT: BackendMessageFormat; + + /// Decode this type from a Backend message in the Postgres protocol. + /// + /// The format code and length prefix have already been read and are not at the start of `bytes`. + fn decode_body(buf: Bytes) -> Result; +} + +pub struct EncodeMessage(pub F); + +impl ProtocolEncode<'_, ()> for EncodeMessage { + fn encode_with(&self, buf: &mut Vec, _context: ()) -> Result<(), Error> { + let mut size_hint = self.0.body_size_hint(); + // plus format code and length prefix + size_hint += 5; + + // don't panic if `size_hint` is ridiculous + buf.try_reserve(size_hint.0).map_err(|e| { + err_protocol!( + "Postgres protocol: error allocating {} bytes for encoding message {:?}: {e}", + size_hint.0, + F::FORMAT, + ) + })?; + + buf.push(F::FORMAT as u8); + + buf.put_length_prefixed(|buf| self.0.encode_body(buf)) + } +} diff --git a/sqlx-postgres/src/message/notification.rs b/sqlx-postgres/src/message/notification.rs index 34303908..7bf02983 100644 --- a/sqlx-postgres/src/message/notification.rs +++ b/sqlx-postgres/src/message/notification.rs @@ -1,7 +1,8 @@ use sqlx_core::bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::{BufExt, Decode}; +use crate::io::BufExt; +use crate::message::{BackendMessage, BackendMessageFormat}; #[derive(Debug)] pub struct Notification { @@ -10,9 +11,10 @@ pub struct Notification { pub(crate) payload: Bytes, } -impl Decode<'_> for Notification { - #[inline] - fn decode_with(mut buf: Bytes, _: ()) -> Result { +impl BackendMessage for Notification { + const FORMAT: BackendMessageFormat = BackendMessageFormat::NotificationResponse; + + fn decode_body(mut buf: Bytes) -> Result { let process_id = buf.get_u32(); let channel = buf.get_bytes_nul()?; let payload = buf.get_bytes_nul()?; @@ -29,7 +31,7 @@ impl Decode<'_> for Notification { fn test_decode_notification_response() { const NOTIFICATION_RESPONSE: &[u8] = b"\x34\x20\x10\x02TEST-CHANNEL\0THIS IS A TEST\0"; - let message = Notification::decode(Bytes::from(NOTIFICATION_RESPONSE)).unwrap(); + let message = Notification::decode_body(Bytes::from(NOTIFICATION_RESPONSE)).unwrap(); assert_eq!(message.process_id, 0x34201002); assert_eq!(&*message.channel, &b"TEST-CHANNEL"[..]); diff --git a/sqlx-postgres/src/message/parameter_description.rs b/sqlx-postgres/src/message/parameter_description.rs index 8d525d05..8aa361a8 100644 --- a/sqlx-postgres/src/message/parameter_description.rs +++ b/sqlx-postgres/src/message/parameter_description.rs @@ -2,7 +2,7 @@ use smallvec::SmallVec; use sqlx_core::bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::Decode; +use crate::message::{BackendMessage, BackendMessageFormat}; use crate::types::Oid; #[derive(Debug)] @@ -10,8 +10,10 @@ pub struct ParameterDescription { pub types: SmallVec<[Oid; 6]>, } -impl Decode<'_> for ParameterDescription { - fn decode_with(mut buf: Bytes, _: ()) -> Result { +impl BackendMessage for ParameterDescription { + const FORMAT: BackendMessageFormat = BackendMessageFormat::ParameterDescription; + + fn decode_body(mut buf: Bytes) -> Result { let cnt = buf.get_u16(); let mut types = SmallVec::with_capacity(cnt as usize); @@ -27,7 +29,7 @@ impl Decode<'_> for ParameterDescription { fn test_decode_parameter_description() { const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"; - let m = ParameterDescription::decode(DATA.into()).unwrap(); + let m = ParameterDescription::decode_body(DATA.into()).unwrap(); assert_eq!(m.types.len(), 2); assert_eq!(m.types[0], Oid(0x0000_0000)); @@ -38,7 +40,7 @@ fn test_decode_parameter_description() { fn test_decode_empty_parameter_description() { const DATA: &[u8] = b"\x00\x00"; - let m = ParameterDescription::decode(DATA.into()).unwrap(); + let m = ParameterDescription::decode_body(DATA.into()).unwrap(); assert!(m.types.is_empty()); } @@ -49,6 +51,6 @@ fn bench_decode_parameter_description(b: &mut test::Bencher) { const DATA: &[u8] = b"\x00\x02\x00\x00\x00\x00\x00\x00\x05\x00"; b.iter(|| { - ParameterDescription::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); + ParameterDescription::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap(); }); } diff --git a/sqlx-postgres/src/message/parameter_status.rs b/sqlx-postgres/src/message/parameter_status.rs index 37abe4e3..d979d189 100644 --- a/sqlx-postgres/src/message/parameter_status.rs +++ b/sqlx-postgres/src/message/parameter_status.rs @@ -1,7 +1,8 @@ use sqlx_core::bytes::Bytes; use crate::error::Error; -use crate::io::{BufExt, Decode}; +use crate::io::BufExt; +use crate::message::{BackendMessage, BackendMessageFormat}; #[derive(Debug)] pub struct ParameterStatus { @@ -9,8 +10,10 @@ pub struct ParameterStatus { pub value: String, } -impl Decode<'_> for ParameterStatus { - fn decode_with(mut buf: Bytes, _: ()) -> Result { +impl BackendMessage for ParameterStatus { + const FORMAT: BackendMessageFormat = BackendMessageFormat::ParameterStatus; + + fn decode_body(mut buf: Bytes) -> Result { let name = buf.get_str_nul()?; let value = buf.get_str_nul()?; @@ -22,7 +25,7 @@ impl Decode<'_> for ParameterStatus { fn test_decode_parameter_status() { const DATA: &[u8] = b"client_encoding\x00UTF8\x00"; - let m = ParameterStatus::decode(DATA.into()).unwrap(); + let m = ParameterStatus::decode_body(DATA.into()).unwrap(); assert_eq!(&m.name, "client_encoding"); assert_eq!(&m.value, "UTF8") @@ -32,7 +35,7 @@ fn test_decode_parameter_status() { fn test_decode_empty_parameter_status() { const DATA: &[u8] = b"\x00\x00"; - let m = ParameterStatus::decode(DATA.into()).unwrap(); + let m = ParameterStatus::decode_body(DATA.into()).unwrap(); assert!(m.name.is_empty()); assert!(m.value.is_empty()); @@ -44,7 +47,7 @@ fn bench_decode_parameter_status(b: &mut test::Bencher) { const DATA: &[u8] = b"client_encoding\x00UTF8\x00"; b.iter(|| { - ParameterStatus::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); + ParameterStatus::decode_body(test::black_box(Bytes::from_static(DATA))).unwrap(); }); } @@ -52,7 +55,7 @@ fn bench_decode_parameter_status(b: &mut test::Bencher) { fn test_decode_parameter_status_response() { const PARAMETER_STATUS_RESPONSE: &[u8] = b"crdb_version\0CockroachDB CCL v21.1.0 (x86_64-unknown-linux-gnu, built 2021/05/17 13:49:40, go1.15.11)\0"; - let message = ParameterStatus::decode(Bytes::from(PARAMETER_STATUS_RESPONSE)).unwrap(); + let message = ParameterStatus::decode_body(Bytes::from(PARAMETER_STATUS_RESPONSE)).unwrap(); assert_eq!(message.name, "crdb_version"); assert_eq!( diff --git a/sqlx-postgres/src/message/parse.rs b/sqlx-postgres/src/message/parse.rs index 6bcbdb6b..3e77c302 100644 --- a/sqlx-postgres/src/message/parse.rs +++ b/sqlx-postgres/src/message/parse.rs @@ -1,11 +1,14 @@ -use crate::io::PgBufMutExt; -use crate::io::{BufMutExt, Encode}; +use crate::io::BufMutExt; +use crate::io::{PgBufMutExt, StatementId}; +use crate::message::{FrontendMessage, FrontendMessageFormat}; use crate::types::Oid; +use sqlx_core::Error; +use std::num::Saturating; #[derive(Debug)] pub struct Parse<'a> { /// The ID of the destination prepared statement. - pub statement: Oid, + pub statement: StatementId, /// The query string to be parsed. pub query: &'a str, @@ -16,39 +19,59 @@ pub struct Parse<'a> { pub param_types: &'a [Oid], } -impl Encode<'_> for Parse<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.push(b'P'); +impl FrontendMessage for Parse<'_> { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Parse; - buf.put_length_prefixed(|buf| { - buf.put_statement_name(self.statement); + fn body_size_hint(&self) -> Saturating { + let mut size = Saturating(0); - buf.put_str_nul(self.query); + size += self.statement.name_len(); - // TODO: Return an error here instead - assert!(self.param_types.len() <= (u16::MAX as usize)); + size += self.query.len(); + size += 1; // NUL terminator - buf.extend(&(self.param_types.len() as i16).to_be_bytes()); + size += 2; // param_types_len - for &oid in self.param_types { - buf.extend(&oid.0.to_be_bytes()); - } - }) + // `param_types` + size += self.param_types.len().saturating_mul(4); + + size + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { + buf.put_statement_name(self.statement); + + buf.put_str_nul(self.query); + + let param_types_len = i16::try_from(self.param_types.len()).map_err(|_| { + err_protocol!( + "param_types.len() too large for binary protocol: {}", + self.param_types.len() + ) + })?; + + buf.extend(param_types_len.to_be_bytes()); + + for &oid in self.param_types { + buf.extend(oid.0.to_be_bytes()); + } + + Ok(()) } } #[test] fn test_encode_parse() { - const EXPECTED: &[u8] = b"P\0\0\0\x1dsqlx_s_1\0SELECT $1\0\0\x01\0\0\0\x19"; + const EXPECTED: &[u8] = b"P\0\0\0\x26sqlx_s_1234567890\0SELECT $1\0\0\x01\0\0\0\x19"; let mut buf = Vec::new(); let m = Parse { - statement: Oid(1), + statement: StatementId::TEST_VAL, query: "SELECT $1", param_types: &[Oid(25)], }; - m.encode(&mut buf); + m.encode_msg(&mut buf).unwrap(); assert_eq!(buf, EXPECTED); } diff --git a/sqlx-postgres/src/message/parse_complete.rs b/sqlx-postgres/src/message/parse_complete.rs new file mode 100644 index 00000000..3051f5ff --- /dev/null +++ b/sqlx-postgres/src/message/parse_complete.rs @@ -0,0 +1,13 @@ +use crate::message::{BackendMessage, BackendMessageFormat}; +use sqlx_core::bytes::Bytes; +use sqlx_core::Error; + +pub struct ParseComplete; + +impl BackendMessage for ParseComplete { + const FORMAT: BackendMessageFormat = BackendMessageFormat::ParseComplete; + + fn decode_body(_bytes: Bytes) -> Result { + Ok(ParseComplete) + } +} diff --git a/sqlx-postgres/src/message/password.rs b/sqlx-postgres/src/message/password.rs index ba8b5ac6..4eaaeb15 100644 --- a/sqlx-postgres/src/message/password.rs +++ b/sqlx-postgres/src/message/password.rs @@ -1,9 +1,9 @@ -use std::fmt::Write; - +use crate::io::BufMutExt; +use crate::message::{FrontendMessage, FrontendMessageFormat}; use md5::{Digest, Md5}; - -use crate::io::PgBufMutExt; -use crate::io::{BufMutExt, Encode}; +use sqlx_core::Error; +use std::fmt::Write; +use std::num::Saturating; #[derive(Debug)] pub enum Password<'a> { @@ -16,117 +16,138 @@ pub enum Password<'a> { }, } -impl Password<'_> { - #[inline] - fn len(&self) -> usize { +impl FrontendMessage for Password<'_> { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic; + + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + let mut size = Saturating(0); + match self { - Password::Cleartext(s) => s.len() + 5, - Password::Md5 { .. } => 35 + 5, - } - } -} - -impl Encode<'_> for Password<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.reserve(1 + 4 + self.len()); - buf.push(b'p'); - - buf.put_length_prefixed(|buf| { - match self { - Password::Cleartext(password) => { - buf.put_str_nul(password); - } - - Password::Md5 { - username, - password, - salt, - } => { - // The actual `PasswordMessage` can be computed in SQL as - // `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`. - - // Keep in mind the md5() function returns its result as a hex string. - - let mut hasher = Md5::new(); - - hasher.update(password); - hasher.update(username); - - let mut output = String::with_capacity(35); - - let _ = write!(output, "{:x}", hasher.finalize_reset()); - - hasher.update(&output); - hasher.update(salt); - - output.clear(); - - let _ = write!(output, "md5{:x}", hasher.finalize()); - - buf.put_str_nul(&output); - } + Password::Cleartext(password) => { + // To avoid reporting the exact password length anywhere, + // we deliberately give a bad estimate. + // + // This shouldn't affect performance in the long run. + size += password + .len() + .saturating_add(1) // NUL terminator + .checked_next_power_of_two() + .unwrap_or(usize::MAX); } - }); + Password::Md5 { .. } => { + // "md5<32 hex chars>\0" + size += 36; + } + } + + size + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { + match self { + Password::Cleartext(password) => { + buf.put_str_nul(password); + } + + Password::Md5 { + username, + password, + salt, + } => { + // The actual `PasswordMessage` can be computed in SQL as + // `concat('md5', md5(concat(md5(concat(password, username)), random-salt)))`. + + // Keep in mind the md5() function returns its result as a hex string. + + let mut hasher = Md5::new(); + + hasher.update(password); + hasher.update(username); + + let mut output = String::with_capacity(35); + + let _ = write!(output, "{:x}", hasher.finalize_reset()); + + hasher.update(&output); + hasher.update(salt); + + output.clear(); + + let _ = write!(output, "md5{:x}", hasher.finalize()); + + buf.put_str_nul(&output); + } + } + + Ok(()) } } -#[test] -fn test_encode_clear_password() { - const EXPECTED: &[u8] = b"p\0\0\0\rpassword\0"; +#[cfg(test)] +mod tests { + use crate::message::FrontendMessage; - let mut buf = Vec::new(); - let m = Password::Cleartext("password"); + use super::Password; - m.encode(&mut buf); + #[test] + fn test_encode_clear_password() { + const EXPECTED: &[u8] = b"p\0\0\0\rpassword\0"; - assert_eq!(buf, EXPECTED); -} + let mut buf = Vec::new(); + let m = Password::Cleartext("password"); -#[test] -fn test_encode_md5_password() { - const EXPECTED: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0"; + m.encode_msg(&mut buf).unwrap(); - let mut buf = Vec::new(); - let m = Password::Md5 { - password: "password", - username: "root", - salt: [147, 24, 57, 152], - }; + assert_eq!(buf, EXPECTED); + } - m.encode(&mut buf); + #[test] + fn test_encode_md5_password() { + const EXPECTED: &[u8] = b"p\0\0\0(md53e2c9d99d49b201ef867a36f3f9ed62c\0"; - assert_eq!(buf, EXPECTED); -} - -#[cfg(all(test, not(debug_assertions)))] -#[bench] -fn bench_encode_clear_password(b: &mut test::Bencher) { - use test::black_box; - - let mut buf = Vec::with_capacity(128); - - b.iter(|| { - buf.clear(); - - black_box(Password::Cleartext("password")).encode(&mut buf); - }); -} - -#[cfg(all(test, not(debug_assertions)))] -#[bench] -fn bench_encode_md5_password(b: &mut test::Bencher) { - use test::black_box; - - let mut buf = Vec::with_capacity(128); - - b.iter(|| { - buf.clear(); - - black_box(Password::Md5 { + let mut buf = Vec::new(); + let m = Password::Md5 { password: "password", username: "root", salt: [147, 24, 57, 152], - }) - .encode(&mut buf); - }); + }; + + m.encode_msg(&mut buf).unwrap(); + + assert_eq!(buf, EXPECTED); + } + + #[cfg(all(test, not(debug_assertions)))] + #[bench] + fn bench_encode_clear_password(b: &mut test::Bencher) { + use test::black_box; + + let mut buf = Vec::with_capacity(128); + + b.iter(|| { + buf.clear(); + + black_box(Password::Cleartext("password")).encode_msg(&mut buf); + }); + } + + #[cfg(all(test, not(debug_assertions)))] + #[bench] + fn bench_encode_md5_password(b: &mut test::Bencher) { + use test::black_box; + + let mut buf = Vec::with_capacity(128); + + b.iter(|| { + buf.clear(); + + black_box(Password::Md5 { + password: "password", + username: "root", + salt: [147, 24, 57, 152], + }) + .encode_msg(&mut buf); + }); + } } diff --git a/sqlx-postgres/src/message/query.rs b/sqlx-postgres/src/message/query.rs index 8f49aabc..788d7808 100644 --- a/sqlx-postgres/src/message/query.rs +++ b/sqlx-postgres/src/message/query.rs @@ -1,27 +1,37 @@ -use crate::io::{BufMutExt, Encode}; +use crate::io::BufMutExt; +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use sqlx_core::Error; +use std::num::Saturating; #[derive(Debug)] pub struct Query<'a>(pub &'a str); -impl Encode<'_> for Query<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { - let len = 4 + self.0.len() + 1; +impl FrontendMessage for Query<'_> { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Query; - buf.reserve(len + 1); - buf.push(b'Q'); - buf.extend(&(len as i32).to_be_bytes()); + fn body_size_hint(&self) -> Saturating { + let mut size = Saturating(0); + + size += self.0.len(); + size += 1; // NUL terminator + + size + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { buf.put_str_nul(self.0); + Ok(()) } } #[test] fn test_encode_query() { - const EXPECTED: &[u8] = b"Q\0\0\0\rSELECT 1\0"; + const EXPECTED: &[u8] = b"Q\0\0\0\x0DSELECT 1\0"; let mut buf = Vec::new(); let m = Query("SELECT 1"); - m.encode(&mut buf); + m.encode_msg(&mut buf).unwrap(); assert_eq!(buf, EXPECTED); } diff --git a/sqlx-postgres/src/message/ready_for_query.rs b/sqlx-postgres/src/message/ready_for_query.rs index 21e6540d..a1f6761b 100644 --- a/sqlx-postgres/src/message/ready_for_query.rs +++ b/sqlx-postgres/src/message/ready_for_query.rs @@ -1,7 +1,7 @@ use sqlx_core::bytes::Bytes; use crate::error::Error; -use crate::io::Decode; +use crate::message::{BackendMessage, BackendMessageFormat}; #[derive(Debug)] #[repr(u8)] @@ -21,8 +21,10 @@ pub struct ReadyForQuery { pub transaction_status: TransactionStatus, } -impl Decode<'_> for ReadyForQuery { - fn decode_with(buf: Bytes, _: ()) -> Result { +impl BackendMessage for ReadyForQuery { + const FORMAT: BackendMessageFormat = BackendMessageFormat::ReadyForQuery; + + fn decode_body(buf: Bytes) -> Result { let status = match buf[0] { b'I' => TransactionStatus::Idle, b'T' => TransactionStatus::Transaction, @@ -46,7 +48,7 @@ impl Decode<'_> for ReadyForQuery { fn test_decode_ready_for_query() -> Result<(), Error> { const DATA: &[u8] = b"E"; - let m = ReadyForQuery::decode(Bytes::from_static(DATA))?; + let m = ReadyForQuery::decode_body(Bytes::from_static(DATA))?; assert!(matches!(m.transaction_status, TransactionStatus::Error)); diff --git a/sqlx-postgres/src/message/response.rs b/sqlx-postgres/src/message/response.rs index ec3c8808..d6e43e08 100644 --- a/sqlx-postgres/src/message/response.rs +++ b/sqlx-postgres/src/message/response.rs @@ -1,10 +1,13 @@ +use std::ops::Range; use std::str::from_utf8; use memchr::memchr; + use sqlx_core::bytes::Bytes; use crate::error::Error; -use crate::io::Decode; +use crate::io::ProtocolDecode; +use crate::message::{BackendMessage, BackendMessageFormat}; #[derive(Debug, Copy, Clone, Eq, PartialEq)] #[repr(u8)] @@ -53,8 +56,8 @@ impl TryFrom<&str> for PgSeverity { pub struct Notice { storage: Bytes, severity: PgSeverity, - message: (u16, u16), - code: (u16, u16), + message: Range, + code: Range, } impl Notice { @@ -65,12 +68,12 @@ impl Notice { #[inline] pub fn code(&self) -> &str { - self.get_cached_str(self.code) + self.get_cached_str(self.code.clone()) } #[inline] pub fn message(&self) -> &str { - self.get_cached_str(self.message) + self.get_cached_str(self.message.clone()) } // Field descriptions available here: @@ -84,7 +87,7 @@ impl Notice { pub fn get_raw(&self, ty: u8) -> Option<&[u8]> { self.fields() .filter(|(field, _)| *field == ty) - .map(|(_, (start, end))| &self.storage[start as usize..end as usize]) + .map(|(_, range)| &self.storage[range]) .next() } } @@ -99,13 +102,13 @@ impl Notice { } #[inline] - fn get_cached_str(&self, cache: (u16, u16)) -> &str { + fn get_cached_str(&self, cache: Range) -> &str { // unwrap: this cannot fail at this stage - from_utf8(&self.storage[cache.0 as usize..cache.1 as usize]).unwrap() + from_utf8(&self.storage[cache]).unwrap() } } -impl Decode<'_> for Notice { +impl ProtocolDecode<'_> for Notice { fn decode_with(buf: Bytes, _: ()) -> Result { // In order to support PostgreSQL 9.5 and older we need to parse the localized S field. // Newer versions additionally come with the V field that is guaranteed to be in English. @@ -113,8 +116,8 @@ impl Decode<'_> for Notice { const DEFAULT_SEVERITY: PgSeverity = PgSeverity::Log; let mut severity_v = None; let mut severity_s = None; - let mut message = (0, 0); - let mut code = (0, 0); + let mut message = 0..0; + let mut code = 0..0; // we cache the three always present fields // this enables to keep the access time down for the fields most likely accessed @@ -125,7 +128,7 @@ impl Decode<'_> for Notice { }; for (field, v) in fields { - if message.0 != 0 && code.0 != 0 { + if !(message.is_empty() || code.is_empty()) { // stop iterating when we have the 3 fields we were looking for // we assume V (severity) was the first field as it should be break; @@ -133,7 +136,7 @@ impl Decode<'_> for Notice { match field { b'S' => { - severity_s = from_utf8(&buf[v.0 as usize..v.1 as usize]) + severity_s = from_utf8(&buf[v.clone()]) // If the error string is not UTF-8, we have no hope of interpreting it, // localized or not. The `V` field would likely fail to parse as well. .map_err(|_| notice_protocol_err())? @@ -146,21 +149,19 @@ impl Decode<'_> for Notice { // Propagate errors here, because V is not localized and // thus we are missing a possible variant. severity_v = Some( - from_utf8(&buf[v.0 as usize..v.1 as usize]) + from_utf8(&buf[v.clone()]) .map_err(|_| notice_protocol_err())? .try_into()?, ); } b'M' => { - _ = from_utf8(&buf[v.0 as usize..v.1 as usize]) - .map_err(|_| notice_protocol_err())?; + _ = from_utf8(&buf[v.clone()]).map_err(|_| notice_protocol_err())?; message = v; } b'C' => { - _ = from_utf8(&buf[v.0 as usize..v.1 as usize]) - .map_err(|_| notice_protocol_err())?; + _ = from_utf8(&buf[v.clone()]).map_err(|_| notice_protocol_err())?; code = v; } @@ -179,31 +180,46 @@ impl Decode<'_> for Notice { } } +impl BackendMessage for Notice { + const FORMAT: BackendMessageFormat = BackendMessageFormat::NoticeResponse; + + fn decode_body(buf: Bytes) -> Result { + // Keeping both impls for now + Self::decode_with(buf, ()) + } +} + /// An iterator over each field in the Error (or Notice) response. struct Fields<'a> { storage: &'a [u8], - offset: u16, + offset: usize, } impl<'a> Iterator for Fields<'a> { - type Item = (u8, (u16, u16)); + type Item = (u8, Range); fn next(&mut self) -> Option { // The fields in the response body are sequentially stored as [tag][string], // ending in a final, additional [nul] - let ty = self.storage[self.offset as usize]; + let ty = *self.storage.get(self.offset)?; if ty == 0 { return None; } - let nul = memchr(b'\0', &self.storage[(self.offset + 1) as usize..])? as u16; - let offset = self.offset; + // Consume the type byte + self.offset = self.offset.checked_add(1)?; - self.offset += nul + 2; + let start = self.offset; - Some((ty, (offset + 1, offset + nul + 1))) + let len = memchr(b'\0', self.storage.get(start..)?)?; + + // Neither can overflow as they will always be `<= self.storage.len()`. + let end = self.offset + len; + self.offset = end + 1; + + Some((ty, start..end)) } } diff --git a/sqlx-postgres/src/message/row_description.rs b/sqlx-postgres/src/message/row_description.rs index 32121386..3f3155ed 100644 --- a/sqlx-postgres/src/message/row_description.rs +++ b/sqlx-postgres/src/message/row_description.rs @@ -1,7 +1,8 @@ use sqlx_core::bytes::{Buf, Bytes}; use crate::error::Error; -use crate::io::{BufExt, Decode}; +use crate::io::BufExt; +use crate::message::{BackendMessage, BackendMessageFormat}; use crate::types::Oid; #[derive(Debug)] @@ -40,13 +41,30 @@ pub struct Field { pub format: i16, } -impl Decode<'_> for RowDescription { - fn decode_with(mut buf: Bytes, _: ()) -> Result { +impl BackendMessage for RowDescription { + const FORMAT: BackendMessageFormat = BackendMessageFormat::RowDescription; + + fn decode_body(mut buf: Bytes) -> Result { + if buf.len() < 2 { + return Err(err_protocol!( + "expected at least 2 bytes, got {}", + buf.len() + )); + } + let cnt = buf.get_u16(); let mut fields = Vec::with_capacity(cnt as usize); for _ in 0..cnt { let name = buf.get_str_nul()?.to_owned(); + + if buf.len() < 18 { + return Err(err_protocol!( + "expected at least 18 bytes after field name {name:?}, got {}", + buf.len() + )); + } + let relation_id = buf.get_i32(); let relation_attribute_no = buf.get_i16(); let data_type_id = Oid(buf.get_u32()); diff --git a/sqlx-postgres/src/message/sasl.rs b/sqlx-postgres/src/message/sasl.rs index 77d0bf8d..9d393189 100644 --- a/sqlx-postgres/src/message/sasl.rs +++ b/sqlx-postgres/src/message/sasl.rs @@ -1,35 +1,69 @@ -use crate::io::PgBufMutExt; -use crate::io::{BufMutExt, Encode}; +use crate::io::BufMutExt; +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use sqlx_core::Error; +use std::num::Saturating; pub struct SaslInitialResponse<'a> { pub response: &'a str, pub plus: bool, } -impl Encode<'_> for SaslInitialResponse<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.push(b'p'); - buf.put_length_prefixed(|buf| { - // name of the SASL authentication mechanism that the client selected - buf.put_str_nul(if self.plus { - "SCRAM-SHA-256-PLUS" - } else { - "SCRAM-SHA-256" - }); +impl SaslInitialResponse<'_> { + #[inline(always)] + fn selected_mechanism(&self) -> &'static str { + if self.plus { + "SCRAM-SHA-256-PLUS" + } else { + "SCRAM-SHA-256" + } + } +} - buf.extend(&(self.response.as_bytes().len() as i32).to_be_bytes()); - buf.extend(self.response.as_bytes()); - }); +impl FrontendMessage for SaslInitialResponse<'_> { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic; + + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + let mut size = Saturating(0); + + size += self.selected_mechanism().len(); + size += 1; // NUL terminator + + size += 4; // response_len + size += self.response.len(); + + size + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { + // name of the SASL authentication mechanism that the client selected + buf.put_str_nul(self.selected_mechanism()); + + let response_len = i32::try_from(self.response.len()).map_err(|_| { + err_protocol!( + "SASL Initial Response length too long for protcol: {}", + self.response.len() + ) + })?; + + buf.extend_from_slice(&response_len.to_be_bytes()); + buf.extend_from_slice(self.response.as_bytes()); + + Ok(()) } } pub struct SaslResponse<'a>(pub &'a str); -impl Encode<'_> for SaslResponse<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.push(b'p'); - buf.put_length_prefixed(|buf| { - buf.extend(self.0.as_bytes()); - }); +impl FrontendMessage for SaslResponse<'_> { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::PasswordPolymorphic; + + fn body_size_hint(&self) -> Saturating { + Saturating(self.0.len()) + } + + fn encode_body(&self, buf: &mut Vec) -> Result<(), Error> { + buf.extend(self.0.as_bytes()); + Ok(()) } } diff --git a/sqlx-postgres/src/message/ssl_request.rs b/sqlx-postgres/src/message/ssl_request.rs index fa57faf0..09c88622 100644 --- a/sqlx-postgres/src/message/ssl_request.rs +++ b/sqlx-postgres/src/message/ssl_request.rs @@ -1,23 +1,38 @@ -use crate::io::Encode; +use crate::io::ProtocolEncode; pub struct SslRequest; impl SslRequest { - pub const BYTES: &'static [u8] = b"\x00\x00\x00\x08\x04\xd2\x16/"; + // https://www.postgresql.org/docs/current/protocol-message-formats.html#PROTOCOL-MESSAGE-FORMATS-SSLREQUEST + pub const BYTES: &'static [u8] = b"\x00\x00\x00\x08\x04\xd2\x16\x2f"; } -impl Encode<'_> for SslRequest { - #[inline] - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.extend(&8_u32.to_be_bytes()); - buf.extend(&(((1234 << 16) | 5679) as u32).to_be_bytes()); +// Cannot impl FrontendMessage because it does not have a format code +impl ProtocolEncode<'_> for SslRequest { + #[inline(always)] + fn encode_with(&self, buf: &mut Vec, _context: ()) -> Result<(), crate::Error> { + buf.extend_from_slice(Self::BYTES); + Ok(()) } } #[test] fn test_encode_ssl_request() { let mut buf = Vec::new(); - SslRequest.encode(&mut buf); + + // Int32(8) + // Length of message contents in bytes, including self. + buf.extend_from_slice(&8_u32.to_be_bytes()); + + // Int32(80877103) + // The SSL request code. The value is chosen to contain 1234 in the most significant 16 bits, + // and 5679 in the least significant 16 bits. + // (To avoid confusion, this code must not be the same as any protocol version number.) + buf.extend_from_slice(&(((1234 << 16) | 5679) as u32).to_be_bytes()); + + let mut encoded = Vec::new(); + SslRequest.encode(&mut encoded).unwrap(); assert_eq!(buf, SslRequest::BYTES); + assert_eq!(buf, encoded); } diff --git a/sqlx-postgres/src/message/startup.rs b/sqlx-postgres/src/message/startup.rs index 83869584..1c6d735a 100644 --- a/sqlx-postgres/src/message/startup.rs +++ b/sqlx-postgres/src/message/startup.rs @@ -1,5 +1,5 @@ use crate::io::PgBufMutExt; -use crate::io::{BufMutExt, Encode}; +use crate::io::{BufMutExt, ProtocolEncode}; // To begin a session, a frontend opens a connection to the server and sends a startup message. // This message includes the names of the user and of the database the user wants to connect to; @@ -19,8 +19,9 @@ pub struct Startup<'a> { pub params: &'a [(&'a str, &'a str)], } -impl Encode<'_> for Startup<'_> { - fn encode_with(&self, buf: &mut Vec, _: ()) { +// Startup cannot impl FrontendMessage because it doesn't have a format code. +impl ProtocolEncode<'_> for Startup<'_> { + fn encode_with(&self, buf: &mut Vec, _context: ()) -> Result<(), crate::Error> { buf.reserve(120); buf.put_length_prefixed(|buf| { @@ -47,7 +48,9 @@ impl Encode<'_> for Startup<'_> { // A zero byte is required as a terminator // after the last name/value pair. buf.push(0); - }); + + Ok(()) + }) } } @@ -68,7 +71,7 @@ fn test_encode_startup() { params: &[], }; - m.encode(&mut buf); + m.encode(&mut buf).unwrap(); assert_eq!(buf, EXPECTED); } diff --git a/sqlx-postgres/src/message/sync.rs b/sqlx-postgres/src/message/sync.rs index bc30114e..56f44987 100644 --- a/sqlx-postgres/src/message/sync.rs +++ b/sqlx-postgres/src/message/sync.rs @@ -1,11 +1,20 @@ -use crate::io::Encode; +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use sqlx_core::Error; +use std::num::Saturating; #[derive(Debug)] pub struct Sync; -impl Encode<'_> for Sync { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.push(b'S'); - buf.extend(&4_i32.to_be_bytes()); +impl FrontendMessage for Sync { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Sync; + + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + Saturating(0) + } + + #[inline(always)] + fn encode_body(&self, _buf: &mut Vec) -> Result<(), Error> { + Ok(()) } } diff --git a/sqlx-postgres/src/message/terminate.rs b/sqlx-postgres/src/message/terminate.rs index 98e41fdb..39f8ff6e 100644 --- a/sqlx-postgres/src/message/terminate.rs +++ b/sqlx-postgres/src/message/terminate.rs @@ -1,10 +1,19 @@ -use crate::io::Encode; +use crate::message::{FrontendMessage, FrontendMessageFormat}; +use sqlx_core::Error; +use std::num::Saturating; pub struct Terminate; -impl Encode<'_> for Terminate { - fn encode_with(&self, buf: &mut Vec, _: ()) { - buf.push(b'X'); - buf.extend(&4_u32.to_be_bytes()); +impl FrontendMessage for Terminate { + const FORMAT: FrontendMessageFormat = FrontendMessageFormat::Terminate; + + #[inline(always)] + fn body_size_hint(&self) -> Saturating { + Saturating(0) + } + + #[inline(always)] + fn encode_body(&self, _buf: &mut Vec) -> Result<(), Error> { + Ok(()) } } diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index 02028624..b9330d52 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -17,7 +17,7 @@ impl TransactionManager for PgTransactionManager { Box::pin(async move { let rollback = Rollback::new(conn); let query = begin_ansi_transaction_sql(rollback.conn.transaction_depth); - rollback.conn.queue_simple_query(&query); + rollback.conn.queue_simple_query(&query)?; rollback.conn.transaction_depth += 1; rollback.conn.wait_until_ready().await?; rollback.defuse(); @@ -54,7 +54,8 @@ impl TransactionManager for PgTransactionManager { fn start_rollback(conn: &mut PgConnection) { if conn.transaction_depth > 0 { - conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.transaction_depth)); + conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.transaction_depth)) + .expect("BUG: Rollback query somehow too large for protocol"); conn.transaction_depth -= 1; } diff --git a/sqlx-postgres/src/types/oid.rs b/sqlx-postgres/src/types/oid.rs index caa90dfc..04c5ef83 100644 --- a/sqlx-postgres/src/types/oid.rs +++ b/sqlx-postgres/src/types/oid.rs @@ -17,12 +17,6 @@ pub struct Oid( pub u32, ); -impl Oid { - pub(crate) fn incr_one(&mut self) { - self.0 = self.0.wrapping_add(1); - } -} - impl Type for Oid { fn type_info() -> PgTypeInfo { PgTypeInfo::OID