fix: audit PgValueRef::get() and usage sites for bad casts

This commit is contained in:
Austin Bonander 2024-08-16 13:19:13 -07:00
parent 26c85240fc
commit af9cce726b
4 changed files with 26 additions and 19 deletions

View file

@ -242,6 +242,9 @@ where
// length of the array axis
let len = buf.get_i32();
let len = usize::try_from(len)
.map_err(|_| format!("overflow converting array len ({len}) to usize"))?;
// the lower bound, we only support arrays starting from "1"
let lower = buf.get_i32();
@ -249,14 +252,12 @@ where
return Err(format!("encountered an array with a lower bound of {lower} in the first dimension; only arrays starting at one are supported").into());
}
let mut elements = Vec::with_capacity(len as usize);
let mut elements = Vec::with_capacity(len);
for _ in 0..len {
elements.push(T::decode(PgValueRef::get(
&mut buf,
format,
element_type_info.clone(),
))?)
let value_ref = PgValueRef::get(&mut buf, format, element_type_info.clone())?;
elements.push(T::decode(value_ref)?);
}
Ok(elements)

View file

@ -350,7 +350,7 @@ where
if !flags.contains(RangeFlags::LB_INF) {
let value =
T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone()))?;
T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone())?)?;
start = if flags.contains(RangeFlags::LB_INC) {
Bound::Included(value)
@ -361,7 +361,7 @@ where
if !flags.contains(RangeFlags::UB_INF) {
let value =
T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone()))?;
T::decode(PgValueRef::get(&mut buf, value.format, element_ty.clone())?)?;
end = if flags.contains(RangeFlags::UB_INC) {
Bound::Included(value)

View file

@ -137,7 +137,7 @@ impl<'r> PgRecordDecoder<'r> {
self.ind += 1;
T::decode(PgValueRef::get(&mut self.buf, self.fmt, element_type))
T::decode(PgValueRef::get(&mut self.buf, self.fmt, element_type)?)
}
PgValueFormat::Text => {

View file

@ -1,11 +1,10 @@
use crate::error::{BoxDynError, UnexpectedNullError};
use crate::{PgTypeInfo, Postgres};
use sqlx_core::bytes::{Buf, Bytes};
pub(crate) use sqlx_core::value::{Value, ValueRef};
use std::borrow::Cow;
use std::str::from_utf8;
pub(crate) use sqlx_core::value::{Value, ValueRef};
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum PgValueFormat {
@ -31,24 +30,31 @@ pub struct PgValue {
}
impl<'r> PgValueRef<'r> {
pub(crate) fn get(buf: &mut &'r [u8], format: PgValueFormat, ty: PgTypeInfo) -> Self {
let mut element_len = buf.get_i32();
pub(crate) fn get(
buf: &mut &'r [u8],
format: PgValueFormat,
ty: PgTypeInfo,
) -> Result<Self, String> {
let element_len = buf.get_i32();
let element_val = if element_len == -1 {
element_len = 0;
None
} else {
Some(&buf[..(element_len as usize)])
let element_len: usize = element_len
.try_into()
.map_err(|_| format!("overflow converting element_len ({element_len}) to usize"))?;
let val = &buf[..element_len];
buf.advance(element_len);
Some(val)
};
buf.advance(element_len as usize);
PgValueRef {
Ok(PgValueRef {
value: element_val,
row: None,
type_info: ty,
format,
}
})
}
pub fn format(&self) -> PgValueFormat {