From e9a562f89ade258e87db9c07bc757747484c24f3 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Thu, 11 Jun 2020 03:34:03 -0700 Subject: [PATCH] fix(mysql): handle MySQL sending more or less bytes than we expect for an integer type --- sqlx-core/src/mysql/types/int.rs | 54 +++++++++++++++---------------- sqlx-core/src/mysql/types/uint.rs | 54 +++++++++++++++---------------- 2 files changed, 54 insertions(+), 54 deletions(-) diff --git a/sqlx-core/src/mysql/types/int.rs b/sqlx-core/src/mysql/types/int.rs index aa54b7c7..a09c4252 100644 --- a/sqlx-core/src/mysql/types/int.rs +++ b/sqlx-core/src/mysql/types/int.rs @@ -1,3 +1,5 @@ +use std::convert::TryInto; + use byteorder::{ByteOrder, LittleEndian}; use crate::decode::Decode; @@ -7,17 +9,6 @@ use crate::mysql::protocol::text::{ColumnFlags, ColumnType}; use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueFormat, MySqlValueRef}; use crate::types::Type; -fn int_accepts(ty: &MySqlTypeInfo) -> bool { - matches!( - ty.r#type, - ColumnType::Tiny - | ColumnType::Short - | ColumnType::Long - | ColumnType::Int24 - | ColumnType::LongLong - ) && !ty.flags.contains(ColumnFlags::UNSIGNED) -} - impl Type for i8 { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Tiny) @@ -74,16 +65,34 @@ impl Encode<'_, MySql> for i64 { } } +fn int_accepts(ty: &MySqlTypeInfo) -> bool { + matches!( + ty.r#type, + ColumnType::Tiny + | ColumnType::Short + | ColumnType::Long + | ColumnType::Int24 + | ColumnType::LongLong + ) && !ty.flags.contains(ColumnFlags::UNSIGNED) +} + +fn int_decode(value: MySqlValueRef<'_>) -> Result { + Ok(match value.format() { + MySqlValueFormat::Text => value.as_str()?.parse()?, + MySqlValueFormat::Binary => { + let buf = value.as_bytes()?; + LittleEndian::read_int(buf, buf.len()) + } + }) +} + impl Decode<'_, MySql> for i8 { fn accepts(ty: &MySqlTypeInfo) -> bool { int_accepts(ty) } fn decode(value: MySqlValueRef<'_>) -> Result { - Ok(match value.format() { - MySqlValueFormat::Binary => value.as_bytes()?[0] as i8, - MySqlValueFormat::Text => value.as_str()?.parse()?, - }) + int_decode(value)?.try_into().map_err(Into::into) } } @@ -93,10 +102,7 @@ impl Decode<'_, MySql> for i16 { } fn decode(value: MySqlValueRef<'_>) -> Result { - Ok(match value.format() { - MySqlValueFormat::Binary => LittleEndian::read_i16(value.as_bytes()?), - MySqlValueFormat::Text => value.as_str()?.parse()?, - }) + int_decode(value)?.try_into().map_err(Into::into) } } @@ -106,10 +112,7 @@ impl Decode<'_, MySql> for i32 { } fn decode(value: MySqlValueRef<'_>) -> Result { - Ok(match value.format() { - MySqlValueFormat::Binary => LittleEndian::read_i32(value.as_bytes()?), - MySqlValueFormat::Text => value.as_str()?.parse()?, - }) + int_decode(value)?.try_into().map_err(Into::into) } } @@ -119,9 +122,6 @@ impl Decode<'_, MySql> for i64 { } fn decode(value: MySqlValueRef<'_>) -> Result { - Ok(match value.format() { - MySqlValueFormat::Binary => LittleEndian::read_i64(value.as_bytes()?), - MySqlValueFormat::Text => value.as_str()?.parse()?, - }) + int_decode(value)?.try_into().map_err(Into::into) } } diff --git a/sqlx-core/src/mysql/types/uint.rs b/sqlx-core/src/mysql/types/uint.rs index 35c09f59..4a54fc9d 100644 --- a/sqlx-core/src/mysql/types/uint.rs +++ b/sqlx-core/src/mysql/types/uint.rs @@ -1,3 +1,5 @@ +use std::convert::TryInto; + use byteorder::{ByteOrder, LittleEndian}; use crate::decode::Decode; @@ -15,17 +17,6 @@ fn uint_type_info(ty: ColumnType) -> MySqlTypeInfo { } } -fn uint_accepts(ty: &MySqlTypeInfo) -> bool { - matches!( - ty.r#type, - ColumnType::Tiny - | ColumnType::Short - | ColumnType::Long - | ColumnType::Int24 - | ColumnType::LongLong - ) && ty.flags.contains(ColumnFlags::UNSIGNED) -} - impl Type for u8 { fn type_info() -> MySqlTypeInfo { uint_type_info(ColumnType::Tiny) @@ -82,16 +73,34 @@ impl Encode<'_, MySql> for u64 { } } +fn uint_accepts(ty: &MySqlTypeInfo) -> bool { + matches!( + ty.r#type, + ColumnType::Tiny + | ColumnType::Short + | ColumnType::Long + | ColumnType::Int24 + | ColumnType::LongLong + ) && ty.flags.contains(ColumnFlags::UNSIGNED) +} + +fn uint_decode(value: MySqlValueRef<'_>) -> Result { + Ok(match value.format() { + MySqlValueFormat::Text => value.as_str()?.parse()?, + MySqlValueFormat::Binary => { + let buf = value.as_bytes()?; + LittleEndian::read_uint(buf, buf.len()) + } + }) +} + impl Decode<'_, MySql> for u8 { fn accepts(ty: &MySqlTypeInfo) -> bool { uint_accepts(ty) } fn decode(value: MySqlValueRef<'_>) -> Result { - Ok(match value.format() { - MySqlValueFormat::Binary => value.as_bytes()?[0] as u8, - MySqlValueFormat::Text => value.as_str()?.parse()?, - }) + uint_decode(value)?.try_into().map_err(Into::into) } } @@ -101,10 +110,7 @@ impl Decode<'_, MySql> for u16 { } fn decode(value: MySqlValueRef<'_>) -> Result { - Ok(match value.format() { - MySqlValueFormat::Binary => LittleEndian::read_u16(value.as_bytes()?), - MySqlValueFormat::Text => value.as_str()?.parse()?, - }) + uint_decode(value)?.try_into().map_err(Into::into) } } @@ -114,10 +120,7 @@ impl Decode<'_, MySql> for u32 { } fn decode(value: MySqlValueRef<'_>) -> Result { - Ok(match value.format() { - MySqlValueFormat::Binary => LittleEndian::read_u32(value.as_bytes()?), - MySqlValueFormat::Text => value.as_str()?.parse()?, - }) + uint_decode(value)?.try_into().map_err(Into::into) } } @@ -127,9 +130,6 @@ impl Decode<'_, MySql> for u64 { } fn decode(value: MySqlValueRef<'_>) -> Result { - Ok(match value.format() { - MySqlValueFormat::Binary => LittleEndian::read_u64(value.as_bytes()?), - MySqlValueFormat::Text => value.as_str()?.parse()?, - }) + uint_decode(value)?.try_into().map_err(Into::into) } }