feat: add MySqlTime, audit mysql::types for panics (#3154)

Also clarifies the handling of `TIME` (we never realized it's used for both time-of-day and signed intervals) and adds appropriate impls for `std::time::Duration`, `time::Duration`, `chrono::TimeDelta`
This commit is contained in:
Austin Bonander 2024-03-30 11:49:12 -07:00 committed by GitHub
parent 1f6642cafa
commit 7102a7a254
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 999 additions and 82 deletions

View file

@ -45,7 +45,20 @@ impl<'de> Decode<'de, &'de [MySqlColumn]> for BinaryRow {
// NOTE: MySQL will never generate NULL types for non-NULL values
let type_info = &column.type_info;
// Unlike Postgres, MySQL does not length-prefix every value in a binary row.
// Values are *either* fixed-length or length-prefixed,
// so we need to inspect the type code to be sure.
let size: usize = match type_info.r#type {
// All fixed-length types.
ColumnType::LongLong => 8,
ColumnType::Long | ColumnType::Int24 => 4,
ColumnType::Short | ColumnType::Year => 2,
ColumnType::Tiny => 1,
ColumnType::Float => 4,
ColumnType::Double => 8,
// Blobs and strings are prefixed with their length,
// which is itself a length-encoded integer.
ColumnType::String
| ColumnType::VarChar
| ColumnType::VarString
@ -61,18 +74,13 @@ impl<'de> Decode<'de, &'de [MySqlColumn]> for BinaryRow {
| ColumnType::Json
| ColumnType::NewDecimal => buf.get_uint_lenenc() as usize,
ColumnType::LongLong => 8,
ColumnType::Long | ColumnType::Int24 => 4,
ColumnType::Short | ColumnType::Year => 2,
ColumnType::Tiny => 1,
ColumnType::Float => 4,
ColumnType::Double => 8,
// Like strings and blobs, these values are variable-length.
// Unlike strings and blobs, however, they exclusively use one byte for length.
ColumnType::Time
| ColumnType::Timestamp
| ColumnType::Date
| ColumnType::Datetime => {
// The size of this type is important for decoding
// Leave the length byte on the front of the value because decoding uses it.
buf[0] as usize + 1
}

View file

@ -2,13 +2,14 @@ use bytes::Buf;
use chrono::{
DateTime, Datelike, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Timelike, Utc,
};
use sqlx_core::database::Database;
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::{BoxDynError, UnexpectedNullError};
use crate::protocol::text::ColumnType;
use crate::type_info::MySqlTypeInfo;
use crate::types::Type;
use crate::types::{MySqlTime, MySqlTimeSign, Type};
use crate::{MySql, MySqlValueFormat, MySqlValueRef};
impl Type<MySql> for DateTime<Utc> {
@ -63,7 +64,7 @@ impl<'r> Decode<'r, MySql> for DateTime<Local> {
impl Type<MySql> for NaiveTime {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::binary(ColumnType::Time)
MySqlTime::type_info()
}
}
@ -75,7 +76,7 @@ impl Encode<'_, MySql> for NaiveTime {
// NaiveTime is not negative
buf.push(0);
// "date on 4 bytes little-endian format" (?)
// Number of days in the interval; always 0 for time-of-day values.
// https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding
buf.extend_from_slice(&[0_u8; 4]);
@ -95,34 +96,18 @@ impl Encode<'_, MySql> for NaiveTime {
}
}
/// Decode from a `TIME` value.
///
/// ### Errors
/// Returns an error if the `TIME` value is negative or exceeds `23:59:59.999999`.
impl<'r> Decode<'r, MySql> for NaiveTime {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => {
let mut buf = value.as_bytes()?;
// data length, expecting 8 or 12 (fractional seconds)
let len = buf.get_u8();
// MySQL specifies that if all of hours, minutes, seconds, microseconds
// are 0 then the length is 0 and no further data is send
// https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
if len == 0 {
return Ok(NaiveTime::from_hms_micro_opt(0, 0, 0, 0)
.expect("expected NaiveTime to construct from all zeroes"));
}
// is negative : int<1>
let is_negative = buf.get_u8();
debug_assert_eq!(is_negative, 0, "Negative dates/times are not supported");
// "date on 4 bytes little-endian format" (?)
// https://mariadb.com/kb/en/resultset-row/#timestamp-binary-encoding
buf.advance(4);
decode_time(len - 5, buf)
// Covers most possible failure modes.
MySqlTime::decode(value)?.try_into()
}
// Retaining this parsing for now as it allows us to cross-check our impl.
MySqlValueFormat::Text => {
let s = value.as_str()?;
NaiveTime::parse_from_str(s, "%H:%M:%S%.f").map_err(Into::into)
@ -131,6 +116,57 @@ impl<'r> Decode<'r, MySql> for NaiveTime {
}
}
impl TryFrom<MySqlTime> for NaiveTime {
type Error = BoxDynError;
fn try_from(time: MySqlTime) -> Result<Self, Self::Error> {
NaiveTime::from_hms_micro_opt(
time.hours(),
time.minutes() as u32,
time.seconds() as u32,
time.microseconds(),
)
.ok_or_else(|| format!("Cannot convert `MySqlTime` value to `NaiveTime`: {time}").into())
}
}
impl From<MySqlTime> for chrono::TimeDelta {
fn from(time: MySqlTime) -> Self {
chrono::TimeDelta::new(time.whole_seconds_signed(), time.subsec_nanos())
.expect("BUG: chrono::TimeDelta should have a greater range than MySqlTime")
}
}
impl TryFrom<chrono::TimeDelta> for MySqlTime {
type Error = BoxDynError;
fn try_from(value: chrono::TimeDelta) -> Result<Self, Self::Error> {
let sign = if value < chrono::TimeDelta::zero() {
MySqlTimeSign::Negative
} else {
MySqlTimeSign::Positive
};
Ok(
// `std::time::Duration` has a greater positive range than `TimeDelta`
// which makes it a great intermediate if you ignore the sign.
MySqlTime::try_from(value.abs().to_std()?)?.with_sign(sign),
)
}
}
impl Type<MySql> for chrono::TimeDelta {
fn type_info() -> MySqlTypeInfo {
MySqlTime::type_info()
}
}
impl<'r> Decode<'r, MySql> for chrono::TimeDelta {
fn decode(value: <MySql as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(MySqlTime::decode(value)?.into())
}
}
impl Type<MySql> for NaiveDate {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::binary(ColumnType::Date)
@ -155,7 +191,14 @@ impl<'r> Decode<'r, MySql> for NaiveDate {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => {
decode_date(&value.as_bytes()?[1..])?.ok_or_else(|| UnexpectedNullError.into())
let buf = value.as_bytes()?;
// Row decoding should have left the length prefix.
if buf.is_empty() {
return Err("empty buffer".into());
}
decode_date(&buf[1..])?.ok_or_else(|| UnexpectedNullError.into())
}
MySqlValueFormat::Text => {
@ -214,6 +257,10 @@ impl<'r> Decode<'r, MySql> for NaiveDateTime {
MySqlValueFormat::Binary => {
let buf = value.as_bytes()?;
if buf.is_empty() {
return Err("empty buffer".into());
}
let len = buf[0];
let date = decode_date(&buf[1..])?.ok_or(UnexpectedNullError)?;

View file

@ -8,6 +8,7 @@ use crate::types::Type;
use crate::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef};
fn real_compatible(ty: &MySqlTypeInfo) -> bool {
// NOTE: `DECIMAL` is explicitly excluded because floating-point numbers have different semantics.
matches!(ty.r#type, ColumnType::Float | ColumnType::Double)
}
@ -53,12 +54,22 @@ impl Decode<'_, MySql> for f32 {
MySqlValueFormat::Binary => {
let buf = value.as_bytes()?;
if buf.len() == 8 {
match buf.len() {
// These functions panic if `buf` is not exactly the right size.
4 => LittleEndian::read_f32(buf),
// MySQL can return 8-byte DOUBLE values for a FLOAT
// We take and truncate to f32 as that's the same behavior as *in* MySQL
LittleEndian::read_f64(buf) as f32
} else {
LittleEndian::read_f32(buf)
// We take and truncate to f32 as that's the same behavior as *in* MySQL,
8 => LittleEndian::read_f64(buf) as f32,
other => {
// Users may try to decode a DECIMAL as floating point;
// inform them why that's a bad idea.
return Err(format!(
"expected a FLOAT as 4 or 8 bytes, got {other} bytes; \
note that decoding DECIMAL as `f32` is not supported \
due to differing semantics"
)
.into());
}
}
}
@ -70,7 +81,26 @@ impl Decode<'_, MySql> for f32 {
impl Decode<'_, MySql> for f64 {
fn decode(value: MySqlValueRef<'_>) -> Result<Self, BoxDynError> {
Ok(match value.format() {
MySqlValueFormat::Binary => LittleEndian::read_f64(value.as_bytes()?),
MySqlValueFormat::Binary => {
let buf = value.as_bytes()?;
// The `read_*` functions panic if `buf` is not exactly the right size.
match buf.len() {
// Allow implicit widening here
4 => LittleEndian::read_f32(buf) as f64,
8 => LittleEndian::read_f64(buf),
other => {
// Users may try to decode a DECIMAL as floating point;
// inform them why that's a bad idea.
return Err(format!(
"expected a DOUBLE as 4 or 8 bytes, got {other} bytes; \
note that decoding DECIMAL as `f64` is not supported \
due to differing semantics"
)
.into());
}
}
}
MySqlValueFormat::Text => value.as_str()?.parse()?,
})
}

View file

@ -95,6 +95,20 @@ fn int_decode(value: MySqlValueRef<'_>) -> Result<i64, BoxDynError> {
MySqlValueFormat::Text => value.as_str()?.parse()?,
MySqlValueFormat::Binary => {
let buf = value.as_bytes()?;
// Check conditions that could cause `read_int()` to panic.
if buf.is_empty() {
return Err("empty buffer".into());
}
if buf.len() > 8 {
return Err(format!(
"expected no more than 8 bytes for integer value, got {}",
buf.len()
)
.into());
}
LittleEndian::read_int(buf, buf.len())
}
})

View file

@ -20,6 +20,8 @@
//! | `IpAddr` | VARCHAR, TEXT |
//! | `Ipv4Addr` | INET4 (MariaDB-only), VARCHAR, TEXT |
//! | `Ipv6Addr` | INET6 (MariaDB-only), VARCHAR, TEXT |
//! | [`MySqlTime`] | TIME (encode and decode full range) |
//! | [`Duration`] | TIME (for decoding positive values only) |
//!
//! ##### Note: `BOOLEAN`/`BOOL` Type
//! MySQL and MariaDB treat `BOOLEAN` as an alias of the `TINYINT` type:
@ -38,6 +40,12 @@
//! Thus, you must use the type override syntax in the query to tell the macros you are expecting
//! a `bool` column. See the docs for `query!()` and `query_as!()` for details on this syntax.
//!
//! ### NOTE: MySQL's `TIME` type is signed
//! MySQL's `TIME` type can be used as either a time-of-day value, or a signed interval.
//! Thus, it may take on negative values.
//!
//! Decoding a [`std::time::Duration`] returns an error if the `TIME` value is negative.
//!
//! ### [`chrono`](https://crates.io/crates/chrono)
//!
//! Requires the `chrono` Cargo feature flag.
@ -48,7 +56,20 @@
//! | `chrono::DateTime<Local>` | TIMESTAMP |
//! | `chrono::NaiveDateTime` | DATETIME |
//! | `chrono::NaiveDate` | DATE |
//! | `chrono::NaiveTime` | TIME |
//! | `chrono::NaiveTime` | TIME (time-of-day only) |
//! | `chrono::TimeDelta` | TIME (decodes full range; see note for encoding) |
//!
//! ### NOTE: MySQL's `TIME` type is dual-purpose
//! MySQL's `TIME` type can be used as either a time-of-day value, or an interval.
//! However, `chrono::NaiveTime` is designed only to represent a time-of-day.
//!
//! Decoding a `TIME` value as `chrono::NaiveTime` will return an error if the value is out of range.
//!
//! The [`MySqlTime`] type supports the full range and it also implements `TryInto<chrono::NaiveTime>`.
//!
//! Decoding a `chrono::TimeDelta` also supports the full range.
//!
//! To encode a `chrono::TimeDelta`, convert it to [`MySqlTime`] first using `TryFrom`/`TryInto`.
//!
//! ### [`time`](https://crates.io/crates/time)
//!
@ -59,7 +80,20 @@
//! | `time::PrimitiveDateTime` | DATETIME |
//! | `time::OffsetDateTime` | TIMESTAMP |
//! | `time::Date` | DATE |
//! | `time::Time` | TIME |
//! | `time::Time` | TIME (time-of-day only) |
//! | `time::Duration` | TIME (decodes full range; see note for encoding) |
//!
//! ### NOTE: MySQL's `TIME` type is dual-purpose
//! MySQL's `TIME` type can be used as either a time-of-day value, or an interval.
//! However, `time::Time` is designed only to represent a time-of-day.
//!
//! Decoding a `TIME` value as `time::Time` will return an error if the value is out of range.
//!
//! The [`MySqlTime`] type supports the full range, and it also implements `TryInto<time::Time>`.
//!
//! Decoding a `time::Duration` also supports the full range.
//!
//! To encode a `time::Duration`, convert it to [`MySqlTime`] first using `TryFrom`/`TryInto`.
//!
//! ### [`bigdecimal`](https://crates.io/crates/bigdecimal)
//! Requires the `bigdecimal` Cargo feature flag.
@ -102,11 +136,14 @@
pub(crate) use sqlx_core::types::*;
pub use mysql_time::{MySqlTime, MySqlTimeError, MySqlTimeSign};
mod bool;
mod bytes;
mod float;
mod inet;
mod int;
mod mysql_time;
mod str;
mod text;
mod uint;

View file

@ -0,0 +1,707 @@
//! The [`MysqlTime`] type.
use crate::protocol::text::ColumnType;
use crate::{MySql, MySqlTypeInfo, MySqlValueFormat};
use bytes::{Buf, BufMut};
use sqlx_core::database::Database;
use sqlx_core::decode::Decode;
use sqlx_core::encode::{Encode, IsNull};
use sqlx_core::error::BoxDynError;
use sqlx_core::types::Type;
use std::cmp::Ordering;
use std::fmt::{Debug, Display, Formatter, Write};
use std::time::Duration;
// Similar to `PgInterval`
/// Container for a MySQL `TIME` value, which may be an interval or a time-of-day.
///
/// Allowed range is `-838:59:59.0` to `838:59:59.0`.
///
/// If this value is used for a time-of-day, the range should be `00:00:00.0` to `23:59:59.999999`.
/// You can use [`Self::is_time_of_day()`] to check this easily.
///
/// * [MySQL Manual 13.2.3: The TIME Type](https://dev.mysql.com/doc/refman/8.3/en/time.html)
/// * [MariaDB Manual: TIME](https://mariadb.com/kb/en/time/)
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct MySqlTime {
pub(crate) sign: MySqlTimeSign,
pub(crate) magnitude: TimeMagnitude,
}
// By using a subcontainer for the actual time magnitude,
// we can still use a derived `Ord` implementation and just flip the comparison for negative values.
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
pub(crate) struct TimeMagnitude {
pub(crate) hours: u32,
pub(crate) minutes: u8,
pub(crate) seconds: u8,
pub(crate) microseconds: u32,
}
const MAGNITUDE_ZERO: TimeMagnitude = TimeMagnitude {
hours: 0,
minutes: 0,
seconds: 0,
microseconds: 0,
};
/// Maximum magnitude (positive or negative).
const MAGNITUDE_MAX: TimeMagnitude = TimeMagnitude {
hours: MySqlTime::HOURS_MAX,
minutes: 59,
seconds: 59,
// Surprisingly this is not 999_999 which is why `MySqlTimeError::SubsecondExcess`.
microseconds: 0,
};
/// The sign for a [`MySqlTime`] type.
#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)]
pub enum MySqlTimeSign {
// The protocol actually specifies negative as 1 and positive as 0,
// but by specifying variants this way we can derive `Ord` and it works as expected.
/// The interval is negative (invalid for time-of-day values).
Negative,
/// The interval is positive, or represents a time-of-day.
Positive,
}
/// Errors returned by [`MySqlTime::new()`].
#[derive(Debug, thiserror::Error)]
pub enum MySqlTimeError {
/// A field of [`MySqlTime`] exceeded its max range.
#[error("`MySqlTime` field `{field}` cannot exceed {max}, got {value}")]
FieldRange {
field: &'static str,
max: u32,
value: u64,
},
/// Error returned for time magnitudes (positive or negative) between `838:59:59.0` and `839:00:00.0`.
///
/// Other range errors should be covered by [`Self::FieldRange`] for the `hours` field.
///
/// For applications which can tolerate rounding, a valid truncated value is provided.
#[error(
"`MySqlTime` cannot exceed +/-838:59:59.000000; got {sign}838:59:59.{microseconds:06}"
)]
SubsecondExcess {
/// The sign of the magnitude.
sign: MySqlTimeSign,
/// The number of microseconds over the maximum.
microseconds: u32,
/// The truncated value,
/// either [`MySqlTime::MIN`] if negative or [`MySqlTime::MAX`] if positive.
truncated: MySqlTime,
},
/// MySQL coerces `-00:00:00` to `00:00:00` but this API considers that an error.
///
/// For applications which can tolerate coercion, you can convert this error to [`MySqlTime::ZERO`].
#[error("attempted to construct a `MySqlTime` value of negative zero")]
NegativeZero,
}
impl MySqlTime {
/// The `MySqlTime` value corresponding to `TIME '0:00:00.0'` (zero).
pub const ZERO: Self = MySqlTime {
sign: MySqlTimeSign::Positive,
magnitude: MAGNITUDE_ZERO,
};
/// The `MySqlTime` value corresponding to `TIME '838:59:59.0'` (max value).
pub const MAX: Self = MySqlTime {
sign: MySqlTimeSign::Positive,
magnitude: MAGNITUDE_MAX,
};
/// The `MySqlTime` value corresponding to `TIME '-838:59:59.0'` (min value).
pub const MIN: Self = MySqlTime {
sign: MySqlTimeSign::Negative,
// Same magnitude, opposite sign.
magnitude: MAGNITUDE_MAX,
};
// The maximums for the other values are self-evident, but not necessarily this one.
pub(crate) const HOURS_MAX: u32 = 838;
/// Construct a [`MySqlTime`] that is valid for use as a `TIME` value.
///
/// ### Errors
/// * [`MySqlTimeError::NegativeZero`] if all fields are 0 but `sign` is [`MySqlSign::Negative`].
/// * [`MySqlTimeError::FieldRange`] if any field is out of range:
/// * `hours > 838`
/// * `minutes > 59`
/// * `seconds > 59`
/// * `microseconds > 999_999`
/// * [`MySqlTimeError::SubsecondExcess`] if the magnitude is less than one second over the maximum.
/// * Durations 839 hours or greater are covered by `FieldRange`.
pub fn new(
sign: MySqlTimeSign,
hours: u32,
minutes: u8,
seconds: u8,
microseconds: u32,
) -> Result<Self, MySqlTimeError> {
macro_rules! check_fields {
($($name:ident: $max:expr),+ $(,)?) => {
$(
if $name > $max {
return Err(MySqlTimeError::FieldRange {
field: stringify!($name),
max: $max as u32,
value: $name as u64
})
}
)+
}
}
check_fields!(
hours: Self::HOURS_MAX,
minutes: 59,
seconds: 59,
microseconds: 999_999
);
let values = TimeMagnitude {
hours,
minutes,
seconds,
microseconds,
};
if sign.is_negative() && values == MAGNITUDE_ZERO {
return Err(MySqlTimeError::NegativeZero);
}
// This is only `true` if less than 1 second over the maximum magnitude
if values > MAGNITUDE_MAX {
return Err(MySqlTimeError::SubsecondExcess {
sign,
microseconds,
truncated: if sign.is_positive() {
Self::MAX
} else {
Self::MIN
},
});
}
Ok(Self {
sign,
magnitude: values,
})
}
/// Update the `sign` of this value.
pub fn with_sign(self, sign: MySqlTimeSign) -> Self {
Self { sign, ..self }
}
/// Return the sign (positive or negative) for this TIME value.
pub fn sign(&self) -> MySqlTimeSign {
self.sign
}
/// Returns `true` if `self` is zero (equal to [`Self::ZERO`]).
pub fn is_zero(&self) -> bool {
self == &Self::ZERO
}
/// Returns `true` if `self` is positive or zero, `false` if negative.
pub fn is_positive(&self) -> bool {
self.sign.is_positive()
}
/// Returns `true` if `self` is negative, `false` if positive or zero.
pub fn is_negative(&self) -> bool {
self.sign.is_positive()
}
/// Returns `true` if this interval is a valid time-of-day.
///
/// If `true`, the sign is positive and `hours` is not greater than 23.
pub fn is_valid_time_of_day(&self) -> bool {
self.sign.is_positive() && self.hours() < 24
}
/// Get the total number of hours in this interval, from 0 to 838.
///
/// If this value represents a time-of-day, the range is 0 to 23.
pub fn hours(&self) -> u32 {
self.magnitude.hours
}
/// Get the number of minutes in this interval, from 0 to 59.
pub fn minutes(&self) -> u8 {
self.magnitude.minutes
}
/// Get the number of seconds in this interval, from 0 to 59.
pub fn seconds(&self) -> u8 {
self.magnitude.seconds
}
/// Get the number of seconds in this interval, from 0 to 999,999.
pub fn microseconds(&self) -> u32 {
self.magnitude.microseconds
}
/// Convert this TIME value to a [`std::time::Duration`].
///
/// Returns `None` if this value is negative (cannot be represented).
pub fn to_duration(&self) -> Option<Duration> {
self.is_positive()
.then(|| Duration::new(self.whole_seconds() as u64, self.subsec_nanos()))
}
/// Get the whole number of seconds (`seconds + (minutes * 60) + (hours * 3600)`) in this time.
///
/// Sign is ignored.
pub(crate) fn whole_seconds(&self) -> u32 {
// If `hours` does not exceed 838 then this cannot overflow.
self.hours() * 3600 + self.minutes() as u32 * 60 + self.seconds() as u32
}
#[cfg_attr(not(any(feature = "time", feature = "chrono")), allow(dead_code))]
pub(crate) fn whole_seconds_signed(&self) -> i64 {
self.whole_seconds() as i64 * self.sign.signum() as i64
}
pub(crate) fn subsec_nanos(&self) -> u32 {
self.microseconds() * 1000
}
fn encoded_len(&self) -> u8 {
if self.is_zero() {
0
} else if self.microseconds() == 0 {
8
} else {
12
}
}
}
impl PartialOrd<MySqlTime> for MySqlTime {
fn partial_cmp(&self, other: &MySqlTime) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for MySqlTime {
fn cmp(&self, other: &Self) -> Ordering {
// If the sides have different signs, we just need to compare those.
if self.sign != other.sign {
return self.sign.cmp(&other.sign);
}
// We've checked that both sides have the same sign
match self.sign {
MySqlTimeSign::Positive => self.magnitude.cmp(&other.magnitude),
// Reverse the comparison for negative values (smaller negative magnitude = greater)
MySqlTimeSign::Negative => other.magnitude.cmp(&self.magnitude),
}
}
}
impl Display for MySqlTime {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let TimeMagnitude {
hours,
minutes,
seconds,
microseconds,
} = self.magnitude;
// Obeys the `+` flag.
Display::fmt(&self.sign(), f)?;
write!(f, "{hours}:{minutes:02}:{seconds:02}")?;
// Write microseconds if not zero or a nonzero precision was explicitly requested.
if f.precision().map_or(microseconds != 0, |it| it != 0) {
f.write_char('.')?;
let mut remaining_precision = f.precision();
let mut remainder = microseconds;
let mut power_of_10 = 10u32.pow(5);
// Write digits from most-significant to least, up to the requested precision.
while remainder > 0 && remaining_precision != Some(0) {
let digit = remainder / power_of_10;
// 1 % 1 = 0
remainder %= power_of_10;
power_of_10 /= 10;
write!(f, "{digit}")?;
if let Some(remaining_precision) = &mut remaining_precision {
*remaining_precision = remaining_precision.saturating_sub(1);
}
}
// If any requested precision remains, pad with zeroes.
if let Some(precision) = remaining_precision.filter(|it| *it != 0) {
write!(f, "{:0precision$}", 0)?;
}
}
Ok(())
}
}
impl Type<MySql> for MySqlTime {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::binary(ColumnType::Time)
}
}
impl<'r> Decode<'r, MySql> for MySqlTime {
fn decode(value: <MySql as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => {
let mut buf = value.as_bytes()?;
// Row decoding should have left the length byte on the front.
if buf.is_empty() {
return Err("empty buffer".into());
}
let length = buf.get_u8();
// MySQL specifies that if all fields are 0 then the length is 0 and no further data is sent
// https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
if length == 0 {
return Ok(Self::ZERO);
}
if !matches!(buf.len(), 8 | 12) {
return Err(format!(
"expected 8 or 12 bytes for TIME value, got {}",
buf.len()
)
.into());
}
let sign = MySqlTimeSign::from_byte(buf.get_u8())?;
// The wire protocol includes days but the text format doesn't. Isn't that crazy?
let days = buf.get_u32_le();
let hours = buf.get_u8();
let minutes = buf.get_u8();
let seconds = buf.get_u8();
let microseconds = if !buf.is_empty() { buf.get_u32_le() } else { 0 };
let whole_hours = days
.checked_mul(24)
.and_then(|days_to_hours| days_to_hours.checked_add(hours as u32))
.ok_or("overflow calculating whole hours from `days * 24 + hours`")?;
Ok(Self::new(
sign,
whole_hours,
minutes,
seconds,
microseconds,
)?)
}
MySqlValueFormat::Text => parse(value.as_str()?),
}
}
}
impl<'q> Encode<'q, MySql> for MySqlTime {
fn encode_by_ref(&self, buf: &mut <MySql as Database>::ArgumentBuffer<'q>) -> IsNull {
if self.is_zero() {
buf.put_u8(0);
return IsNull::No;
}
buf.put_u8(self.encoded_len());
buf.put_u8(self.sign.to_byte());
let TimeMagnitude {
hours: whole_hours,
minutes,
seconds,
microseconds,
} = self.magnitude;
let days = whole_hours / 24;
let hours = (whole_hours % 24) as u8;
buf.put_u32_le(days);
buf.put_u8(hours);
buf.put_u8(minutes);
buf.put_u8(seconds);
if microseconds != 0 {
buf.put_u32_le(microseconds);
}
IsNull::No
}
fn size_hint(&self) -> usize {
self.encoded_len() as usize + 1
}
}
/// Convert [`MySqlTime`] from [`std::time::Duration`].
///
/// ### Note: Precision Truncation
/// [`Duration`] supports nanosecond precision, but MySQL `TIME` values only support microsecond
/// precision.
///
/// For simplicity, higher precision values are truncated when converting.
/// If you prefer another rounding mode instead, you should apply that to the `Duration` first.
///
/// See also: [MySQL Manual, section 13.2.6: Fractional Seconds in Time Values](https://dev.mysql.com/doc/refman/8.3/en/fractional-seconds.html)
///
/// ### Errors:
/// Returns [`MySqlTimeError::FieldRange`] if the given duration is longer than `838:59:59.999999`.
///
impl TryFrom<Duration> for MySqlTime {
type Error = MySqlTimeError;
fn try_from(value: Duration) -> Result<Self, Self::Error> {
let hours = value.as_secs() / 3600;
let rem_seconds = value.as_secs() % 3600;
let minutes = (rem_seconds / 60) as u8;
let seconds = (rem_seconds % 60) as u8;
// Simply divides by 1000
let microseconds = value.subsec_micros();
Self::new(
MySqlTimeSign::Positive,
hours.try_into().map_err(|_| MySqlTimeError::FieldRange {
field: "hours",
max: Self::HOURS_MAX,
value: hours,
})?,
minutes,
seconds,
microseconds,
)
}
}
impl MySqlTimeSign {
fn from_byte(b: u8) -> Result<Self, BoxDynError> {
match b {
0 => Ok(Self::Positive),
1 => Ok(Self::Negative),
other => Err(format!("expected 0 or 1 for TIME sign byte, got {other}").into()),
}
}
fn to_byte(&self) -> u8 {
match self {
// We can't use `#[repr(u8)]` because this is opposite of the ordering we want from `Ord`
Self::Negative => 1,
Self::Positive => 0,
}
}
fn signum(&self) -> i32 {
match self {
Self::Negative => -1,
Self::Positive => 1,
}
}
/// Returns `true` if positive, `false` if negative.
pub fn is_positive(&self) -> bool {
matches!(self, Self::Positive)
}
/// Returns `true` if negative, `false` if positive.
pub fn is_negative(&self) -> bool {
matches!(self, Self::Negative)
}
}
impl Display for MySqlTimeSign {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::Positive if f.sign_plus() => f.write_char('+'),
Self::Negative => f.write_char('-'),
_ => Ok(()),
}
}
}
impl Type<MySql> for Duration {
fn type_info() -> MySqlTypeInfo {
MySqlTime::type_info()
}
}
impl<'r> Decode<'r, MySql> for Duration {
fn decode(value: <MySql as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
let time = MySqlTime::decode(value)?;
time.to_duration().ok_or_else(|| {
format!("`std::time::Duration` can only decode positive TIME values; got {time}").into()
})
}
}
// Not exposing this as a `FromStr` impl currently because `MySqlTime` is not designed to be
// a general interchange type.
fn parse(text: &str) -> Result<MySqlTime, BoxDynError> {
let mut segments = text.split(':');
let hours = segments
.next()
.ok_or("expected hours segment, got nothing")?;
let minutes = segments
.next()
.ok_or("expected minutes segment, got nothing")?;
let seconds = segments
.next()
.ok_or("expected seconds segment, got nothing")?;
// Include the sign in parsing for convenience;
// the allowed range of whole hours is much smaller than `i32`'s positive range.
let hours: i32 = hours
.parse()
.map_err(|e| format!("error parsing hours from {text:?} (segment {hours:?}): {e}"))?;
let sign = if hours.is_negative() {
MySqlTimeSign::Negative
} else {
MySqlTimeSign::Positive
};
let hours = hours.abs() as u32;
let minutes: u8 = minutes
.parse()
.map_err(|e| format!("error parsing minutes from {text:?} (segment {minutes:?}): {e}"))?;
let (seconds, microseconds): (u8, u32) = if let Some((seconds, microseconds)) =
seconds.split_once('.')
{
(
seconds.parse().map_err(|e| {
format!("error parsing seconds from {text:?} (segment {seconds:?}): {e}")
})?,
parse_microseconds(microseconds).map_err(|e| {
format!("error parsing microseconds from {text:?} (segment {microseconds:?}): {e}")
})?,
)
} else {
(
seconds.parse().map_err(|e| {
format!("error parsing seconds from {text:?} (segment {seconds:?}): {e}")
})?,
0,
)
};
Ok(MySqlTime::new(sign, hours, minutes, seconds, microseconds)?)
}
/// Parse microseconds from a fractional seconds string.
fn parse_microseconds(micros: &str) -> Result<u32, BoxDynError> {
const EXPECTED_DIGITS: usize = 6;
match micros.len() {
0 => Err("empty string".into()),
len @ ..=EXPECTED_DIGITS => {
// Fewer than 6 digits, multiply to the correct magnitude
let micros: u32 = micros.parse()?;
Ok(micros * 10u32.pow((EXPECTED_DIGITS - len) as u32))
}
// More digits than expected, truncate
_ => Ok(micros[..EXPECTED_DIGITS].parse()?),
}
}
#[cfg(test)]
mod tests {
use super::MySqlTime;
use crate::types::MySqlTimeSign;
use super::parse_microseconds;
#[test]
fn test_display() {
assert_eq!(MySqlTime::ZERO.to_string(), "0:00:00");
assert_eq!(format!("{:.0}", MySqlTime::ZERO), "0:00:00");
assert_eq!(format!("{:.3}", MySqlTime::ZERO), "0:00:00.000");
assert_eq!(format!("{:.6}", MySqlTime::ZERO), "0:00:00.000000");
assert_eq!(format!("{:.9}", MySqlTime::ZERO), "0:00:00.000000000");
assert_eq!(format!("{:.0}", MySqlTime::MAX), "838:59:59");
assert_eq!(format!("{:.3}", MySqlTime::MAX), "838:59:59.000");
assert_eq!(format!("{:.6}", MySqlTime::MAX), "838:59:59.000000");
assert_eq!(format!("{:.9}", MySqlTime::MAX), "838:59:59.000000000");
assert_eq!(format!("{:+.0}", MySqlTime::MAX), "+838:59:59");
assert_eq!(format!("{:+.3}", MySqlTime::MAX), "+838:59:59.000");
assert_eq!(format!("{:+.6}", MySqlTime::MAX), "+838:59:59.000000");
assert_eq!(format!("{:+.9}", MySqlTime::MAX), "+838:59:59.000000000");
assert_eq!(format!("{:.0}", MySqlTime::MIN), "-838:59:59");
assert_eq!(format!("{:.3}", MySqlTime::MIN), "-838:59:59.000");
assert_eq!(format!("{:.6}", MySqlTime::MIN), "-838:59:59.000000");
assert_eq!(format!("{:.9}", MySqlTime::MIN), "-838:59:59.000000000");
let positive = MySqlTime::new(MySqlTimeSign::Positive, 123, 45, 56, 890011).unwrap();
assert_eq!(positive.to_string(), "123:45:56.890011");
assert_eq!(format!("{positive:.0}"), "123:45:56");
assert_eq!(format!("{positive:.3}"), "123:45:56.890");
assert_eq!(format!("{positive:.6}"), "123:45:56.890011");
assert_eq!(format!("{positive:.9}"), "123:45:56.890011000");
assert_eq!(format!("{positive:+.0}"), "+123:45:56");
assert_eq!(format!("{positive:+.3}"), "+123:45:56.890");
assert_eq!(format!("{positive:+.6}"), "+123:45:56.890011");
assert_eq!(format!("{positive:+.9}"), "+123:45:56.890011000");
let negative = MySqlTime::new(MySqlTimeSign::Negative, 123, 45, 56, 890011).unwrap();
assert_eq!(negative.to_string(), "-123:45:56.890011");
assert_eq!(format!("{negative:.0}"), "-123:45:56");
assert_eq!(format!("{negative:.3}"), "-123:45:56.890");
assert_eq!(format!("{negative:.6}"), "-123:45:56.890011");
assert_eq!(format!("{negative:.9}"), "-123:45:56.890011000");
}
#[test]
fn test_parse_microseconds() {
assert_eq!(parse_microseconds("010").unwrap(), 10_000);
assert_eq!(parse_microseconds("0100000000").unwrap(), 10_000);
assert_eq!(parse_microseconds("890").unwrap(), 890_000);
assert_eq!(parse_microseconds("0890").unwrap(), 89_000);
assert_eq!(
// Case in point about not exposing this:
// we always truncate excess precision because it's simpler than rounding
// and MySQL should never return a higher precision.
parse_microseconds("123456789").unwrap(),
123456,
);
}
}

View file

@ -1,5 +1,6 @@
use byteorder::{ByteOrder, LittleEndian};
use bytes::Buf;
use sqlx_core::database::Database;
use time::macros::format_description;
use time::{Date, OffsetDateTime, PrimitiveDateTime, Time, UtcOffset};
@ -8,7 +9,7 @@ use crate::encode::{Encode, IsNull};
use crate::error::{BoxDynError, UnexpectedNullError};
use crate::protocol::text::ColumnType;
use crate::type_info::MySqlTypeInfo;
use crate::types::Type;
use crate::types::{MySqlTime, MySqlTimeSign, Type};
use crate::{MySql, MySqlValueFormat, MySqlValueRef};
impl Type<MySql> for OffsetDateTime {
@ -52,7 +53,7 @@ impl Encode<'_, MySql> for Time {
// Time is not negative
buf.push(0);
// "date on 4 bytes little-endian format" (?)
// Number of days in the interval; always 0 for time-of-day values.
// https://mariadb.com/kb/en/resultset-row/#teimstamp-binary-encoding
buf.extend_from_slice(&[0_u8; 4]);
@ -76,29 +77,11 @@ impl<'r> Decode<'r, MySql> for Time {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => {
let mut buf = value.as_bytes()?;
// data length, expecting 8 or 12 (fractional seconds)
let len = buf.get_u8();
// MySQL specifies that if all of hours, minutes, seconds, microseconds
// are 0 then the length is 0 and no further data is send
// https://dev.mysql.com/doc/internals/en/binary-protocol-value.html
if len == 0 {
return Ok(Time::MIDNIGHT);
}
// is negative : int<1>
let is_negative = buf.get_u8();
assert_eq!(is_negative, 0, "Negative dates/times are not supported");
// "date on 4 bytes little-endian format" (?)
// https://mariadb.com/kb/en/resultset-row/#timestamp-binary-encoding
buf.advance(4);
decode_time(len - 5, buf)
// Should never panic.
MySqlTime::decode(value)?.try_into()
}
// Retaining this parsing for now as it allows us to cross-check our impl.
MySqlValueFormat::Text => Time::parse(
value.as_str()?,
&format_description!("[hour]:[minute]:[second].[subsecond]"),
@ -108,6 +91,57 @@ impl<'r> Decode<'r, MySql> for Time {
}
}
impl TryFrom<MySqlTime> for Time {
type Error = BoxDynError;
fn try_from(time: MySqlTime) -> Result<Self, Self::Error> {
if !time.is_valid_time_of_day() {
return Err(format!("MySqlTime value out of range for `time::Time`: {time}").into());
}
Ok(Time::from_hms_micro(
// `is_valid_time_of_day()` ensures this won't overflow
time.hours() as u8,
time.minutes(),
time.seconds(),
time.microseconds(),
)?)
}
}
impl From<MySqlTime> for time::Duration {
fn from(time: MySqlTime) -> Self {
time::Duration::new(time.whole_seconds_signed(), time.subsec_nanos() as i32)
}
}
impl TryFrom<time::Duration> for MySqlTime {
type Error = BoxDynError;
fn try_from(value: time::Duration) -> Result<Self, Self::Error> {
let sign = if value.is_negative() {
MySqlTimeSign::Negative
} else {
MySqlTimeSign::Positive
};
// Similar to `TryFrom<chrono::TimeDelta>`, use `std::time::Duration` as an intermediate.
Ok(MySqlTime::try_from(std::time::Duration::try_from(value.abs())?)?.with_sign(sign))
}
}
impl Type<MySql> for time::Duration {
fn type_info() -> MySqlTypeInfo {
MySqlTime::type_info()
}
}
impl<'r> Decode<'r, MySql> for time::Duration {
fn decode(value: <MySql as Database>::ValueRef<'r>) -> Result<Self, BoxDynError> {
Ok(MySqlTime::decode(value)?.into())
}
}
impl Type<MySql> for Date {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::binary(ColumnType::Date)
@ -132,7 +166,14 @@ impl<'r> Decode<'r, MySql> for Date {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => {
Ok(decode_date(&value.as_bytes()?[1..])?.ok_or(UnexpectedNullError)?)
let buf = value.as_bytes()?;
// Row decoding should leave the length byte on the front.
if buf.is_empty() {
return Err("empty buffer".into());
}
Ok(decode_date(&buf[1..])?.ok_or(UnexpectedNullError)?)
}
MySqlValueFormat::Text => {
let s = value.as_str()?;
@ -183,12 +224,18 @@ impl<'r> Decode<'r, MySql> for PrimitiveDateTime {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => {
let buf = value.as_bytes()?;
let len = buf[0];
let date = decode_date(&buf[1..])?.ok_or(UnexpectedNullError)?;
let mut buf = value.as_bytes()?;
if buf.is_empty() {
return Err("empty buffer".into());
}
let len = buf.get_u8();
let date = decode_date(buf)?.ok_or(UnexpectedNullError)?;
let dt = if len > 4 {
date.with_time(decode_time(len - 4, &buf[5..])?)
date.with_time(decode_time(&buf[4..])?)
} else {
date.midnight()
};
@ -255,12 +302,12 @@ fn encode_time(time: &Time, include_micros: bool, buf: &mut Vec<u8>) {
}
}
fn decode_time(len: u8, mut buf: &[u8]) -> Result<Time, BoxDynError> {
fn decode_time(mut buf: &[u8]) -> Result<Time, BoxDynError> {
let hour = buf.get_u8();
let minute = buf.get_u8();
let seconds = buf.get_u8();
let micros = if len > 3 {
let micros = if !buf.is_empty() {
// microseconds : int<EOF>
buf.get_uint_le(buf.len())
} else {

View file

@ -119,6 +119,20 @@ fn uint_decode(value: MySqlValueRef<'_>) -> Result<u64, BoxDynError> {
MySqlValueFormat::Binary => {
let buf = value.as_bytes()?;
// Check conditions that could cause `read_uint()` to panic.
if buf.is_empty() {
return Err("empty buffer".into());
}
if buf.len() > 8 {
return Err(format!(
"expected no more than 8 bytes for unsigned integer value, got {}",
buf.len()
)
.into());
}
LittleEndian::read_uint(buf, buf.len())
}
})

View file

@ -9,6 +9,9 @@ use sqlx::{Executor, Row};
use sqlx::types::Text;
use sqlx::mysql::types::MySqlTime;
use sqlx_mysql::types::MySqlTimeSign;
use sqlx_test::{new, test_type};
test_type!(bool(MySql, "false" == false, "true" == true));
@ -70,34 +73,44 @@ test_type!(uuid_simple<sqlx::types::uuid::fmt::Simple>(MySql,
== sqlx::types::Uuid::parse_str("00000000000000000000000000000000").unwrap().simple()
));
test_type!(mysql_time<MySqlTime>(MySql,
"TIME '00:00:00.000000'" == MySqlTime::ZERO,
"TIME '-00:00:00.000000'" == MySqlTime::ZERO,
"TIME '838:59:59.0'" == MySqlTime::MAX,
"TIME '-838:59:59.0'" == MySqlTime::MIN,
"TIME '123:45:56.890'" == MySqlTime::new(MySqlTimeSign::Positive, 123, 45, 56, 890_000).unwrap(),
"TIME '-123:45:56.890'" == MySqlTime::new(MySqlTimeSign::Negative, 123, 45, 56, 890_000).unwrap(),
"TIME '123:45:56.890011'" == MySqlTime::new(MySqlTimeSign::Positive, 123, 45, 56, 890_011).unwrap(),
"TIME '-123:45:56.890011'" == MySqlTime::new(MySqlTimeSign::Negative, 123, 45, 56, 890_011).unwrap(),
));
#[cfg(feature = "chrono")]
mod chrono {
use sqlx::types::chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
use sqlx::types::chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};
use super::*;
test_type!(chrono_date<NaiveDate>(MySql,
"DATE '2001-01-05'" == NaiveDate::from_ymd(2001, 1, 5),
"DATE '2050-11-23'" == NaiveDate::from_ymd(2050, 11, 23)
"DATE '2001-01-05'" == NaiveDate::from_ymd_opt(2001, 1, 5).unwrap(),
"DATE '2050-11-23'" == NaiveDate::from_ymd_opt(2050, 11, 23).unwrap()
));
test_type!(chrono_time_zero<NaiveTime>(MySql,
"TIME '00:00:00.000000'" == NaiveTime::from_hms_micro(0, 0, 0, 0)
"TIME '00:00:00.000000'" == NaiveTime::from_hms_micro_opt(0, 0, 0, 0).unwrap()
));
test_type!(chrono_time<NaiveTime>(MySql,
"TIME '05:10:20.115100'" == NaiveTime::from_hms_micro(5, 10, 20, 115100)
"TIME '05:10:20.115100'" == NaiveTime::from_hms_micro_opt(5, 10, 20, 115100).unwrap()
));
test_type!(chrono_date_time<NaiveDateTime>(MySql,
"TIMESTAMP '2019-01-02 05:10:20'" == NaiveDate::from_ymd(2019, 1, 2).and_hms(5, 10, 20)
"TIMESTAMP '2019-01-02 05:10:20'" == NaiveDate::from_ymd_opt(2019, 1, 2).unwrap().and_hms_opt(5, 10, 20).unwrap()
));
test_type!(chrono_timestamp<DateTime::<Utc>>(MySql,
"TIMESTAMP '2019-01-02 05:10:20.115100'"
== DateTime::<Utc>::from_utc(
NaiveDate::from_ymd(2019, 1, 2).and_hms_micro(5, 10, 20, 115100),
Utc,
== Utc.from_utc_datetime(
&NaiveDate::from_ymd_opt(2019, 1, 2).unwrap().and_hms_micro_opt(5, 10, 20, 115100).unwrap(),
)
));