fix: audit PgValueRef::get() and usage sites for bad casts

This commit is contained in:
Austin Bonander 2024-08-16 13:19:13 -07:00
parent 26c85240fc
commit af9cce726b
4 changed files with 26 additions and 19 deletions

View file

@ -242,6 +242,9 @@ where
// length of the array axis // length of the array axis
let len = buf.get_i32(); let len = buf.get_i32();
let len = usize::try_from(len)
.map_err(|_| format!("overflow converting array len ({len}) to usize"))?;
// the lower bound, we only support arrays starting from "1" // the lower bound, we only support arrays starting from "1"
let lower = buf.get_i32(); let lower = buf.get_i32();
@ -249,14 +252,12 @@ where
return Err(format!("encountered an array with a lower bound of {lower} in the first dimension; only arrays starting at one are supported").into()); return Err(format!("encountered an array with a lower bound of {lower} in the first dimension; only arrays starting at one are supported").into());
} }
let mut elements = Vec::with_capacity(len as usize); let mut elements = Vec::with_capacity(len);
for _ in 0..len { for _ in 0..len {
elements.push(T::decode(PgValueRef::get( let value_ref = PgValueRef::get(&mut buf, format, element_type_info.clone())?;
&mut buf,
format, elements.push(T::decode(value_ref)?);
element_type_info.clone(),
))?)
} }
Ok(elements) Ok(elements)

View file

@ -350,7 +350,7 @@ where
if !flags.contains(RangeFlags::LB_INF) { if !flags.contains(RangeFlags::LB_INF) {
let value = let value =
T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone()))?; T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone())?)?;
start = if flags.contains(RangeFlags::LB_INC) { start = if flags.contains(RangeFlags::LB_INC) {
Bound::Included(value) Bound::Included(value)
@ -361,7 +361,7 @@ where
if !flags.contains(RangeFlags::UB_INF) { if !flags.contains(RangeFlags::UB_INF) {
let value = let value =
T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone()))?; T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone())?)?;
end = if flags.contains(RangeFlags::UB_INC) { end = if flags.contains(RangeFlags::UB_INC) {
Bound::Included(value) Bound::Included(value)

View file

@ -137,7 +137,7 @@ impl<'r> PgRecordDecoder<'r> {
self.ind += 1; self.ind += 1;
T::decode(PgValueRef::get(&mut self.buf, self.fmt, element_type)) T::decode(PgValueRef::get(&mut self.buf, self.fmt, element_type)?)
} }
PgValueFormat::Text => { PgValueFormat::Text => {

View file

@ -1,11 +1,10 @@
use crate::error::{BoxDynError, UnexpectedNullError}; use crate::error::{BoxDynError, UnexpectedNullError};
use crate::{PgTypeInfo, Postgres}; use crate::{PgTypeInfo, Postgres};
use sqlx_core::bytes::{Buf, Bytes}; use sqlx_core::bytes::{Buf, Bytes};
pub(crate) use sqlx_core::value::{Value, ValueRef};
use std::borrow::Cow; use std::borrow::Cow;
use std::str::from_utf8; use std::str::from_utf8;
pub(crate) use sqlx_core::value::{Value, ValueRef};
#[derive(Debug, Clone, Copy, Eq, PartialEq)] #[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[repr(u8)] #[repr(u8)]
pub enum PgValueFormat { pub enum PgValueFormat {
@ -31,24 +30,31 @@ pub struct PgValue {
} }
impl<'r> PgValueRef<'r> { impl<'r> PgValueRef<'r> {
pub(crate) fn get(buf: &mut &'r [u8], format: PgValueFormat, ty: PgTypeInfo) -> Self { pub(crate) fn get(
let mut element_len = buf.get_i32(); buf: &mut &'r [u8],
format: PgValueFormat,
ty: PgTypeInfo,
) -> Result<Self, String> {
let element_len = buf.get_i32();
let element_val = if element_len == -1 { let element_val = if element_len == -1 {
element_len = 0;
None None
} else { } else {
Some(&buf[..(element_len as usize)]) let element_len: usize = element_len
.try_into()
.map_err(|_| format!("overflow converting element_len ({element_len}) to usize"))?;
let val = &buf[..element_len];
buf.advance(element_len);
Some(val)
}; };
buf.advance(element_len as usize); Ok(PgValueRef {
PgValueRef {
value: element_val, value: element_val,
row: None, row: None,
type_info: ty, type_info: ty,
format, format,
} })
} }
pub fn format(&self) -> PgValueFormat { pub fn format(&self) -> PgValueFormat {