factor::numeric::Montgomery: Add debug assertions

In debug mode, checks that all arithmetic operations coincide with the
plain-u64 versions, as long as the latter does not overflow.
This commit is contained in:
nicoo 2020-05-30 10:11:05 +02:00
parent 8a4d0d30ad
commit 33e18b4cd3

View file

@ -29,7 +29,8 @@ pub(crate) trait Arithmetic: Copy + Sized {
fn mul(&self, a: Self::I, b: Self::I) -> Self::I;
fn pow(&self, mut a: Self::I, mut b: u64) -> Self::I {
let mut result = self.from_u64(1u64);
let (_a, _b) = (a, b);
let mut result = self.one();
while b > 0 {
if b & 1 != 0 {
result = self.mul(result, a);
@ -37,6 +38,15 @@ pub(crate) trait Arithmetic: Copy + Sized {
a = self.mul(a, a);
b >>= 1;
}
// Check that r (reduced back to the usual representation) equals
// a^b % n, unless the latter computation overflows
debug_assert!(self
.to_u64(_a)
.checked_pow(_b as u32)
.map(|r| r % self.modulus() == self.to_u64(result))
.unwrap_or(true));
result
}
@ -79,10 +89,9 @@ impl Arithmetic for Montgomery {
type I = Wrapping<u64>;
fn new(n: u64) -> Self {
Montgomery {
a: inv_mod_u64(n).wrapping_neg(),
n,
}
let a = inv_mod_u64(n).wrapping_neg();
debug_assert_eq!(n.wrapping_mul(a), 1_u64.wrapping_neg());
Montgomery { a, n }
}
fn modulus(&self) -> u64 {
@ -91,7 +100,10 @@ impl Arithmetic for Montgomery {
fn from_u64(&self, x: u64) -> Self::I {
// TODO: optimise!
Wrapping((((x as u128) << 64) % self.n as u128) as u64)
assert!(x < self.n);
let r = Wrapping((((x as u128) << 64) % self.n as u128) as u64);
debug_assert_eq!(x, self.to_u64(r));
r
}
fn to_u64(&self, n: Self::I) -> u64 {
@ -99,11 +111,43 @@ impl Arithmetic for Montgomery {
}
fn add(&self, a: Self::I, b: Self::I) -> Self::I {
a + b
let r = a + b;
// Check that r (reduced back to the usual representation) equals
// a+b % n
#[cfg(debug_assertions)]
{
let a_r = self.to_u64(a);
let b_r = self.to_u64(b);
let r_r = self.to_u64(r);
let r_2 = (((a_r as u128) + (b_r as u128)) % (self.n as u128)) as u64;
debug_assert_eq!(
r_r, r_2,
"[{}] = {} ≠ {} = {} + {} = [{}] + [{}] mod {}; a = {}",
r, r_r, r_2, a_r, b_r, a, b, self.n, self.a
);
}
r
}
fn mul(&self, a: Self::I, b: Self::I) -> Self::I {
Wrapping(self.reduce((a * b).0))
let r = Wrapping(self.reduce((a * b).0));
// Check that r (reduced back to the usual representation) equals
// a*b % n
#[cfg(debug_assertions)]
{
let a_r = self.to_u64(a);
let b_r = self.to_u64(b);
let r_r = self.to_u64(r);
let r_2 = (((a_r as u128) * (b_r as u128)) % (self.n as u128)) as u64;
debug_assert_eq!(
r_r, r_2,
"[{}] = {} ≠ {} = {} * {} = [{}] * [{}] mod {}; a = {}",
r, r_r, r_2, a_r, b_r, a, b, self.n, self.a
);
}
r
}
}