Fixing BigDecimal conversion for PostgreSQL

Now working properly with numbers, such as `0.01` and `0.012`.
This commit is contained in:
Julius de Bruijn 2020-10-30 11:27:49 +01:00 committed by Austin Bonander
parent 25e72925fa
commit a0007b4e98
2 changed files with 46 additions and 40 deletions

View file

@ -1,7 +1,7 @@
use std::cmp;
use std::convert::{TryFrom, TryInto};
use bigdecimal::BigDecimal;
use bigdecimal::{BigDecimal, ToPrimitive, Zero};
use num_bigint::{BigInt, Sign};
use crate::decode::Decode;
@ -77,65 +77,64 @@ impl TryFrom<&'_ BigDecimal> for PgNumeric {
type Error = BoxDynError;
fn try_from(decimal: &BigDecimal) -> Result<Self, BoxDynError> {
let base_10_to_10000 = |chunk: &[u8]| chunk.iter().fold(0i16, |a, &d| a * 10 + d as i16);
if decimal.is_zero() {
return Ok(PgNumeric::Number {
sign: PgNumericSign::Positive,
scale: 0,
weight: 0,
digits: vec![],
});
}
// NOTE: this unfortunately copies the BigInt internally
let (integer, exp) = decimal.as_bigint_and_exponent();
// this routine is specifically optimized for base-10
// FIXME: is there a way to iterate over the digits to avoid the Vec allocation
let (sign, base_10) = integer.to_radix_be(10);
// weight is positive power of 10000
// exp is the negative power of 10
let weight_10 = base_10.len() as i64 - exp;
// scale is only nonzero when we have fractional digits
// since `exp` is the _negative_ decimal exponent, it tells us
// exactly what our scale should be
let scale: i16 = cmp::max(0, exp).try_into()?;
// there's an implicit +1 offset in the interpretation
let weight: i16 = if weight_10 <= 0 {
weight_10 / 4 - 1
} else {
// the `-1` is a fix for an off by 1 error (4 digits should still be 0 weight)
(weight_10 - 1) / 4
}
.try_into()?;
let (sign, uint) = integer.into_parts();
let mut mantissa = uint.to_u128().unwrap();
let digits_len = if base_10.len() % 4 != 0 {
base_10.len() / 4 + 1
} else {
base_10.len() / 4
};
// If our scale is not a multiple of 4, we need to go to the next
// multiple.
let groups_diff = scale % 4;
if groups_diff > 0 {
let remainder = 4 - groups_diff as u32;
let power = 10u32.pow(remainder as u32) as u128;
let offset = weight_10.rem_euclid(4) as usize;
let mut digits = Vec::with_capacity(digits_len);
if let Some(first) = base_10.get(..offset) {
if offset != 0 {
digits.push(base_10_to_10000(first));
}
mantissa = mantissa * power;
}
if let Some(rest) = base_10.get(offset..) {
digits.extend(
rest.chunks(4)
.map(|chunk| base_10_to_10000(chunk) * 10i16.pow(4 - chunk.len() as u32)),
);
// Array to store max mantissa of Decimal in Postgres decimal format.
let mut digits = Vec::with_capacity(8);
// Convert to base-10000.
while mantissa != 0 {
digits.push((mantissa % 10_000) as i16);
mantissa /= 10_000;
}
// Change the endianness.
digits.reverse();
// Weight is number of digits on the left side of the decimal.
let digits_after_decimal = (scale + 3) as u16 / 4;
let weight = digits.len() as i16 - digits_after_decimal as i16 - 1;
// Remove non-significant zeroes.
while let Some(&0) = digits.last() {
digits.pop();
}
let sign = match sign {
Sign::Plus | Sign::NoSign => PgNumericSign::Positive,
Sign::Minus => PgNumericSign::Negative,
};
Ok(PgNumeric::Number {
sign: match sign {
Sign::Plus | Sign::NoSign => PgNumericSign::Positive,
Sign::Minus => PgNumericSign::Negative,
},
sign,
scale,
weight,
digits,

View file

@ -396,7 +396,14 @@ test_type!(bigdecimal<sqlx::types::BigDecimal>(Postgres,
"10000::numeric" == "10000".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.1::numeric" == "0.1".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.01::numeric" == "0.01".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.012::numeric" == "0.012".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.0123::numeric" == "0.0123".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.01234::numeric" == "0.01234".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.012345::numeric" == "0.012345".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.0123456::numeric" == "0.0123456".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.01234567::numeric" == "0.01234567".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.012345678::numeric" == "0.012345678".parse::<sqlx::types::BigDecimal>().unwrap(),
"0.0123456789::numeric" == "0.0123456789".parse::<sqlx::types::BigDecimal>().unwrap(),
"12.34::numeric" == "12.34".parse::<sqlx::types::BigDecimal>().unwrap(),
"12345.6789::numeric" == "12345.6789".parse::<sqlx::types::BigDecimal>().unwrap(),
));