Merge pull request #1554 from nbraud/factor/faster/montgomery32

factor: Refactor and improve performance (plus a few bug fixes)
This commit is contained in:
Roy Ivy III 2020-07-24 11:29:47 -05:00
commit 8cda0f596e
6 changed files with 339 additions and 103 deletions

27
Cargo.lock generated
View file

@ -649,6 +649,23 @@ dependencies = [
"pkg-config 0.3.17 (registry+https://github.com/rust-lang/crates.io-index)", "pkg-config 0.3.17 (registry+https://github.com/rust-lang/crates.io-index)",
] ]
[[package]]
name = "paste"
version = "0.1.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"paste-impl 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)",
"proc-macro-hack 0.5.16 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "paste-impl"
version = "0.1.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"proc-macro-hack 0.5.16 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]] [[package]]
name = "pkg-config" name = "pkg-config"
version = "0.3.17" version = "0.3.17"
@ -668,6 +685,11 @@ name = "ppv-lite86"
version = "0.2.8" version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "proc-macro-hack"
version = "0.5.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.18" version = "1.0.18"
@ -1335,6 +1357,8 @@ dependencies = [
name = "uu_factor" name = "uu_factor"
version = "0.0.1" version = "0.0.1"
dependencies = [ dependencies = [
"num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)",
"paste 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)",
"quickcheck 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)", "quickcheck 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)",
"rand 0.5.6 (registry+https://github.com/rust-lang/crates.io-index)", "rand 0.5.6 (registry+https://github.com/rust-lang/crates.io-index)",
"uucore 0.0.4 (git+https://github.com/uutils/uucore.git?branch=canary)", "uucore 0.0.4 (git+https://github.com/uutils/uucore.git?branch=canary)",
@ -2253,9 +2277,12 @@ dependencies = [
"checksum numtoa 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b8f8bdf33df195859076e54ab11ee78a1b208382d3a26ec40d142ffc1ecc49ef" "checksum numtoa 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b8f8bdf33df195859076e54ab11ee78a1b208382d3a26ec40d142ffc1ecc49ef"
"checksum onig 4.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "8518fcb2b1b8c2f45f0ad499df4fda6087fc3475ca69a185c173b8315d2fb383" "checksum onig 4.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "8518fcb2b1b8c2f45f0ad499df4fda6087fc3475ca69a185c173b8315d2fb383"
"checksum onig_sys 69.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388410bf5fa341f10e58e6db3975f4bea1ac30247dd79d37a9e5ced3cb4cc3b0" "checksum onig_sys 69.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388410bf5fa341f10e58e6db3975f4bea1ac30247dd79d37a9e5ced3cb4cc3b0"
"checksum paste 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)" = "45ca20c77d80be666aef2b45486da86238fabe33e38306bd3118fe4af33fa880"
"checksum paste-impl 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)" = "d95a7db200b97ef370c8e6de0088252f7e0dfff7d047a28528e47456c0fc98b6"
"checksum pkg-config 0.3.17 (registry+https://github.com/rust-lang/crates.io-index)" = "05da548ad6865900e60eaba7f589cc0783590a92e940c26953ff81ddbab2d677" "checksum pkg-config 0.3.17 (registry+https://github.com/rust-lang/crates.io-index)" = "05da548ad6865900e60eaba7f589cc0783590a92e940c26953ff81ddbab2d677"
"checksum platform-info 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f2fd076acdc7a98374de6e300bf3af675997225bef21aecac2219553f04dd7e8" "checksum platform-info 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f2fd076acdc7a98374de6e300bf3af675997225bef21aecac2219553f04dd7e8"
"checksum ppv-lite86 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)" = "237a5ed80e274dbc66f86bd59c1e25edc039660be53194b5fe0a482e0f2612ea" "checksum ppv-lite86 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)" = "237a5ed80e274dbc66f86bd59c1e25edc039660be53194b5fe0a482e0f2612ea"
"checksum proc-macro-hack 0.5.16 (registry+https://github.com/rust-lang/crates.io-index)" = "7e0456befd48169b9f13ef0f0ad46d492cf9d2dbb918bcf38e01eed4ce3ec5e4"
"checksum proc-macro2 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)" = "beae6331a816b1f65d04c45b078fd8e6c93e8071771f41b8163255bbd8d7c8fa" "checksum proc-macro2 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)" = "beae6331a816b1f65d04c45b078fd8e6c93e8071771f41b8163255bbd8d7c8fa"
"checksum quick-error 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" "checksum quick-error 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
"checksum quickcheck 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)" = "a44883e74aa97ad63db83c4bf8ca490f02b2fc02f92575e720c8551e843c945f" "checksum quickcheck 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)" = "a44883e74aa97ad63db83c4bf8ca490f02b2fc02f92575e720c8551e843c945f"

View file

@ -11,12 +11,18 @@ keywords = ["coreutils", "uutils", "cross-platform", "cli", "utility"]
categories = ["command-line-utilities"] categories = ["command-line-utilities"]
edition = "2018" edition = "2018"
[build-dependencies]
num-traits = "0.2" # used in src/numeric.rs, which is included by build.rs
[dependencies] [dependencies]
num-traits = "0.2"
rand = "0.5" rand = "0.5"
uucore = { version="0.0.4", package="uucore", git="https://github.com/uutils/uucore.git", branch="canary" } uucore = { version="0.0.4", package="uucore", git="https://github.com/uutils/uucore.git", branch="canary" }
uucore_procs = { version="0.0.4", package="uucore_procs", git="https://github.com/uutils/uucore.git", branch="canary" } uucore_procs = { version="0.0.4", package="uucore_procs", git="https://github.com/uutils/uucore.git", branch="canary" }
[dev-dependencies] [dev-dependencies]
paste = "0.1.18"
quickcheck = "0.9.2" quickcheck = "0.9.2"
[[bin]] [[bin]]

View file

@ -29,7 +29,7 @@ use miller_rabin::is_prime;
#[path = "src/numeric.rs"] #[path = "src/numeric.rs"]
mod numeric; mod numeric;
use numeric::inv_mod_u64; use numeric::modular_inverse;
mod sieve; mod sieve;
@ -57,7 +57,7 @@ fn main() {
let mut x = primes.next().unwrap(); let mut x = primes.next().unwrap();
for next in primes { for next in primes {
// format the table // format the table
let outstr = format!("({}, {}, {}),", x, inv_mod_u64(x), std::u64::MAX / x); let outstr = format!("({}, {}, {}),", x, modular_inverse(x), std::u64::MAX / x);
if cols + outstr.len() > MAX_WIDTH { if cols + outstr.len() > MAX_WIDTH {
write!(file, "\n {}", outstr).unwrap(); write!(file, "\n {}", outstr).unwrap();
cols = 4 + outstr.len(); cols = 4 + outstr.len();

View file

@ -54,12 +54,16 @@ impl fmt::Display for Factors {
} }
} }
fn _factor<A: Arithmetic>(num: u64, f: Factors) -> Factors { fn _factor<A: Arithmetic + miller_rabin::Basis>(num: u64, f: Factors) -> Factors {
use miller_rabin::Result::*; use miller_rabin::Result::*;
// Shadow the name, so the recursion automatically goes from “Big” arithmetic to small. // Shadow the name, so the recursion automatically goes from “Big” arithmetic to small.
let _factor = |n, f| { let _factor = |n, f| {
// TODO: Optimise with 32 and 64b versions if n < (1 << 32) {
_factor::<Montgomery<u32>>(n, f)
} else {
_factor::<A>(n, f) _factor::<A>(n, f)
}
}; };
if num == 1 { if num == 1 {
@ -101,8 +105,11 @@ pub fn factor(mut n: u64) -> Factors {
let (factors, n) = table::factor(n, factors); let (factors, n) = table::factor(n, factors);
// TODO: Optimise with 32 and 64b versions if n < (1 << 32) {
_factor::<Montgomery>(n, factors) _factor::<Montgomery<u32>>(n, factors)
} else {
_factor::<Montgomery<u64>>(n, factors)
}
} }
#[cfg(test)] #[cfg(test)]

View file

@ -2,10 +2,27 @@
use crate::numeric::*; use crate::numeric::*;
// Small set of bases for the Miller-Rabin prime test, valid for all 64b integers; pub(crate) trait Basis {
// discovered by Jim Sinclair on 2011-04-20, see miller-rabin.appspot.com const BASIS: &'static [u64];
#[allow(clippy::unreadable_literal)] }
const BASIS: [u64; 7] = [2, 325, 9375, 28178, 450775, 9780504, 1795265022];
// Witness set picked up by `test` (through `A::BASIS`) when the modulus is
// handled with 64-bit Montgomery arithmetic.
impl Basis for Montgomery<u64> {
// Small set of bases for the Miller-Rabin prime test, valid for all 64b integers;
// discovered by Jim Sinclair on 2011-04-20, see miller-rabin.appspot.com
#[allow(clippy::unreadable_literal)]
const BASIS: &'static [u64] = &[2, 325, 9375, 28178, 450775, 9780504, 1795265022];
}
// Witness set picked up by `test` (through `A::BASIS`) when the modulus is
// handled with 32-bit Montgomery arithmetic.
impl Basis for Montgomery<u32> {
// Small set of bases for the Miller-Rabin prime test, valid for all 32b integers;
// discovered by Steve Worley on 2013-05-27, see miller-rabin.appspot.com
//
// The bases themselves do not fit in 32 bits, which is why `BASIS` is a slice
// of `u64`: `test` reduces each base modulo n before converting it.
#[allow(clippy::unreadable_literal)]
const BASIS: &'static [u64] = &[
4230279247111683200,
14694767155120705706,
16641139526367750375,
];
}
#[derive(Eq, PartialEq)] #[derive(Eq, PartialEq)]
pub(crate) enum Result { pub(crate) enum Result {
@ -23,16 +40,12 @@ impl Result {
// Deterministic Miller-Rabin primality-checking algorithm, adapted to extract // Deterministic Miller-Rabin primality-checking algorithm, adapted to extract
// (some) dividers; it will fail to factor strong pseudoprimes. // (some) dividers; it will fail to factor strong pseudoprimes.
#[allow(clippy::many_single_char_names)] #[allow(clippy::many_single_char_names)]
pub(crate) fn test<A: Arithmetic>(m: A) -> Result { pub(crate) fn test<A: Arithmetic + Basis>(m: A) -> Result {
use self::Result::*; use self::Result::*;
let n = m.modulus(); let n = m.modulus();
if n < 2 { debug_assert!(n > 1);
return Pseudoprime; debug_assert!(n % 2 != 0);
}
if n % 2 == 0 {
return if n == 2 { Prime } else { Composite(2) };
}
// n-1 = r 2ⁱ // n-1 = r 2ⁱ
let i = (n - 1).trailing_zeros(); let i = (n - 1).trailing_zeros();
@ -41,10 +54,10 @@ pub(crate) fn test<A: Arithmetic>(m: A) -> Result {
let one = m.one(); let one = m.one();
let minus_one = m.minus_one(); let minus_one = m.minus_one();
for _a in BASIS.iter() { for _a in A::BASIS.iter() {
let _a = _a % n; let _a = _a % n;
if _a == 0 { if _a == 0 {
break; continue;
} }
let a = m.from_u64(_a); let a = m.from_u64(_a);
@ -87,48 +100,113 @@ pub(crate) fn test<A: Arithmetic>(m: A) -> Result {
// Used by build.rs' tests and debug assertions // Used by build.rs' tests and debug assertions
#[allow(dead_code)] #[allow(dead_code)]
pub(crate) fn is_prime(n: u64) -> bool { pub(crate) fn is_prime(n: u64) -> bool {
if n % 2 == 0 { if n < 2 {
false
} else if n % 2 == 0 {
n == 2 n == 2
} else { } else {
test::<Montgomery>(Montgomery::new(n)).is_prime() test::<Montgomery<u64>>(Montgomery::new(n)).is_prime()
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::is_prime; use super::*;
use crate::numeric::{Arithmetic, Montgomery};
use quickcheck::quickcheck;
use std::iter;
const LARGEST_U64_PRIME: u64 = 0xFFFFFFFFFFFFFFC5; const LARGEST_U64_PRIME: u64 = 0xFFFFFFFFFFFFFFC5;
/// Yields 2 first, then everything produced by `odd_primes()`, in order.
fn primes() -> impl Iterator<Item = u64> {
    let two = iter::once(2);
    two.chain(odd_primes())
}
/// Yields the first component of every `P_INVS_U64` entry, followed by
/// `NEXT_PRIME`.
fn odd_primes() -> impl Iterator<Item = u64> {
    use crate::table::{NEXT_PRIME, P_INVS_U64};

    // Only the first field of each table entry is needed here.
    let tabulated = P_INVS_U64.iter().map(|entry| entry.0);
    tabulated.chain(iter::once(NEXT_PRIME))
}
#[test] #[test]
fn largest_prime() { fn largest_prime() {
assert!(is_prime(LARGEST_U64_PRIME)); assert!(is_prime(LARGEST_U64_PRIME));
} }
#[test] #[test]
fn first_primes() { fn largest_composites() {
use crate::table::{NEXT_PRIME, P_INVS_U64}; for i in LARGEST_U64_PRIME + 1..=u64::MAX {
for (p, _, _) in P_INVS_U64.iter() { assert!(!is_prime(i), "2⁶⁴ - {} reported prime", u64::MAX - i + 1);
assert!(is_prime(*p), "{} reported composite", p);
} }
assert!(is_prime(NEXT_PRIME));
} }
#[test]
fn two() {
assert!(is_prime(2));
}
// TODO: Deduplicate with macro in numeric.rs
// Expands to two `#[test]` functions, `<f>_u32` and `<f>_u64`, which call the
// generic check `$f` with `Montgomery<u32>` and `Montgomery<u64>` respectively.
macro_rules! parametrized_check {
( $f:ident ) => {
// `paste::item!` is what lets `[< $f _ u32 >]` be pasted into a single identifier.
paste::item! {
#[test]
fn [< $f _ u32 >]() {
$f::<Montgomery<u32>>()
}
#[test]
fn [< $f _ u64 >]() {
$f::<Montgomery<u64>>()
}
}
};
}
/// Checks that `test`, run with the arithmetic `A`, accepts every prime
/// produced by `odd_primes()`.
fn first_primes<A: Arithmetic + Basis>() {
    odd_primes().for_each(|p| {
        let verdict = test(A::new(p));
        assert!(verdict.is_prime(), "{} reported composite", p);
    });
}
parametrized_check!(first_primes);
#[test]
fn one() {
assert!(!is_prime(1));
}
#[test]
fn zero() {
assert!(!is_prime(0));
}
fn first_composites<A: Arithmetic + Basis>() {
for (p, q) in primes().zip(odd_primes()) {
for i in p + 1..q {
assert!(!is_prime(i), "{} reported prime", i);
}
}
}
parametrized_check!(first_composites);
#[test] #[test]
fn issue_1556() { fn issue_1556() {
// 10 425 511 = 2441 × 4271 // 10 425 511 = 2441 × 4271
assert!(!is_prime(10_425_511)); assert!(!is_prime(10_425_511));
} }
#[test] fn small_semiprimes<A: Arithmetic + Basis>() {
fn small_composites() { for p in odd_primes() {
use crate::table::P_INVS_U64; for q in odd_primes().take_while(|q| *q <= p) {
for i in 0..P_INVS_U64.len() {
let (p, _, _) = P_INVS_U64[i];
for (q, _, _) in &P_INVS_U64[0..i] {
let n = p * q; let n = p * q;
assert!(!is_prime(n), "{} = {} × {} reported prime", n, p, q); let m = A::new(n);
assert!(!test(m).is_prime(), "{} = {} × {} reported prime", n, p, q);
} }
} }
} }
parametrized_check!(small_semiprimes);
quickcheck! {
fn composites(i: u64, j: u64) -> bool {
i < 2 || j < 2 || !is_prime(i*j)
}
}
} }

View file

@ -6,6 +6,12 @@
// * For the full copyright and license information, please view the LICENSE file // * For the full copyright and license information, please view the LICENSE file
// * that was distributed with this source code. // * that was distributed with this source code.
use num_traits::{
identities::{One, Zero},
int::PrimInt,
ops::wrapping::{WrappingMul, WrappingNeg, WrappingSub},
};
use std::fmt::{Debug, Display};
use std::mem::swap; use std::mem::swap;
// This is incorrectly reported as dead code, // This is incorrectly reported as dead code,
@ -20,16 +26,17 @@ pub(crate) fn gcd(mut a: u64, mut b: u64) -> u64 {
} }
pub(crate) trait Arithmetic: Copy + Sized { pub(crate) trait Arithmetic: Copy + Sized {
type I: Copy + Sized + Eq; // The type of integers mod m, in some opaque representation
type ModInt: Copy + Sized + Eq;
fn new(m: u64) -> Self; fn new(m: u64) -> Self;
fn modulus(&self) -> u64; fn modulus(&self) -> u64;
fn from_u64(&self, n: u64) -> Self::I; fn from_u64(&self, n: u64) -> Self::ModInt;
fn to_u64(&self, n: Self::I) -> u64; fn to_u64(&self, n: Self::ModInt) -> u64;
fn add(&self, a: Self::I, b: Self::I) -> Self::I; fn add(&self, a: Self::ModInt, b: Self::ModInt) -> Self::ModInt;
fn mul(&self, a: Self::I, b: Self::I) -> Self::I; fn mul(&self, a: Self::ModInt, b: Self::ModInt) -> Self::ModInt;
fn pow(&self, mut a: Self::I, mut b: u64) -> Self::I { fn pow(&self, mut a: Self::ModInt, mut b: u64) -> Self::ModInt {
let (_a, _b) = (a, b); let (_a, _b) = (a, b);
let mut result = self.one(); let mut result = self.one();
while b > 0 { while b > 0 {
@ -54,75 +61,90 @@ pub(crate) trait Arithmetic: Copy + Sized {
result result
} }
fn one(&self) -> Self::I { fn one(&self) -> Self::ModInt {
self.from_u64(1) self.from_u64(1)
} }
fn minus_one(&self) -> Self::I { fn minus_one(&self) -> Self::ModInt {
self.from_u64(self.modulus() - 1) self.from_u64(self.modulus() - 1)
} }
fn zero(&self) -> Self::I { fn zero(&self) -> Self::ModInt {
self.from_u64(0) self.from_u64(0)
} }
} }
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
pub(crate) struct Montgomery { pub(crate) struct Montgomery<T: DoubleInt> {
a: u64, a: T,
n: u64, n: T,
} }
impl Montgomery { impl<T: DoubleInt> Montgomery<T> {
/// computes x/R mod n efficiently /// computes x/R mod n efficiently
fn reduce(&self, x: u128) -> u64 { fn reduce(&self, x: T::DoubleWidth) -> T {
debug_assert!(x < (self.n as u128) << 64); let t_bits = T::zero().count_zeros() as usize;
debug_assert!(x < (self.n.as_double_width()) << t_bits);
// TODO: optimiiiiiiise // TODO: optimiiiiiiise
let Montgomery { a, n } = self; let Montgomery { a, n } = self;
let m = (x as u64).wrapping_mul(*a); let m = T::from_double_width(x).wrapping_mul(a);
let nm = (*n as u128) * (m as u128); let nm = (n.as_double_width()) * (m.as_double_width());
let (xnm, overflow) = (x as u128).overflowing_add(nm); // x + n*m let (xnm, overflow) = x.overflowing_add_(nm); // x + n*m
debug_assert_eq!(xnm % (1 << 64), 0); debug_assert_eq!(
xnm % (T::DoubleWidth::one() << T::zero().count_zeros() as usize),
T::DoubleWidth::zero()
);
// (x + n*m) / R // (x + n*m) / R
// in case of overflow, this is (2¹²⁸ + xnm)/2⁶⁴ - n = xnm/2⁶⁴ + (2⁶⁴ - n) // in case of overflow, this is (2¹²⁸ + xnm)/2⁶⁴ - n = xnm/2⁶⁴ + (2⁶⁴ - n)
let y = (xnm >> 64) as u64 + if !overflow { 0 } else { n.wrapping_neg() }; let y = T::from_double_width(xnm >> t_bits)
+ if !overflow {
T::zero()
} else {
n.wrapping_neg()
};
if y >= *n { if y >= *n {
y - n y - *n
} else { } else {
y y
} }
} }
} }
impl Arithmetic for Montgomery { impl<T: DoubleInt> Arithmetic for Montgomery<T> {
// Montgomery transform, R=2⁶⁴ // Montgomery transform, R=2⁶⁴
// Provides fast arithmetic mod n (n odd, u64) // Provides fast arithmetic mod n (n odd, u64)
type I = u64; type ModInt = T;
fn new(n: u64) -> Self { fn new(n: u64) -> Self {
let a = inv_mod_u64(n).wrapping_neg(); debug_assert!(T::zero().count_zeros() >= 64 || n < (1 << T::zero().count_zeros() as usize));
debug_assert_eq!(n.wrapping_mul(a), 1_u64.wrapping_neg()); let n = T::from_u64(n);
let a = modular_inverse(n).wrapping_neg();
debug_assert_eq!(n.wrapping_mul(&a), T::one().wrapping_neg());
Montgomery { a, n } Montgomery { a, n }
} }
fn modulus(&self) -> u64 { fn modulus(&self) -> u64 {
self.n self.n.as_u64()
} }
fn from_u64(&self, x: u64) -> Self::I { fn from_u64(&self, x: u64) -> Self::ModInt {
// TODO: optimise! // TODO: optimise!
assert!(x < self.n); debug_assert!(x < self.n.as_u64());
let r = (((x as u128) << 64) % self.n as u128) as u64; let r = T::from_double_width(
((T::DoubleWidth::from_u64(x)) << T::zero().count_zeros() as usize)
% self.n.as_double_width(),
);
debug_assert_eq!(x, self.to_u64(r)); debug_assert_eq!(x, self.to_u64(r));
r r
} }
fn to_u64(&self, n: Self::I) -> u64 { fn to_u64(&self, n: Self::ModInt) -> u64 {
self.reduce(n as u128) self.reduce(n.as_double_width()).as_u64()
} }
fn add(&self, a: Self::I, b: Self::I) -> Self::I { fn add(&self, a: Self::ModInt, b: Self::ModInt) -> Self::ModInt {
let (r, overflow) = a.overflowing_add(b); let (r, overflow) = a.overflowing_add_(b);
// In case of overflow, a+b = 2⁶⁴ + r = (2⁶⁴ - n) + r (working mod n) // In case of overflow, a+b = 2⁶⁴ + r = (2⁶⁴ - n) + r (working mod n)
let r = if !overflow { let r = if !overflow {
@ -138,10 +160,10 @@ impl Arithmetic for Montgomery {
// a+b % n // a+b % n
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
let a_r = self.to_u64(a); let a_r = self.to_u64(a) as u128;
let b_r = self.to_u64(b); let b_r = self.to_u64(b) as u128;
let r_r = self.to_u64(r); let r_r = self.to_u64(r);
let r_2 = (((a_r as u128) + (b_r as u128)) % (self.n as u128)) as u64; let r_2 = ((a_r + b_r) % self.n.as_u128()) as u64;
debug_assert_eq!( debug_assert_eq!(
r_r, r_2, r_r, r_2,
"[{}] = {} ≠ {} = {} + {} = [{}] + [{}] mod {}; a = {}", "[{}] = {} ≠ {} = {} + {} = [{}] + [{}] mod {}; a = {}",
@ -151,17 +173,17 @@ impl Arithmetic for Montgomery {
r r
} }
fn mul(&self, a: Self::I, b: Self::I) -> Self::I { fn mul(&self, a: Self::ModInt, b: Self::ModInt) -> Self::ModInt {
let r = self.reduce((a as u128) * (b as u128)); let r = self.reduce(a.as_double_width() * b.as_double_width());
// Check that r (reduced back to the usual representation) equals // Check that r (reduced back to the usual representation) equals
// a*b % n // a*b % n
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
let a_r = self.to_u64(a); let a_r = self.to_u64(a) as u128;
let b_r = self.to_u64(b); let b_r = self.to_u64(b) as u128;
let r_r = self.to_u64(r); let r_r = self.to_u64(r);
let r_2 = (((a_r as u128) * (b_r as u128)) % (self.n as u128)) as u64; let r_2: u64 = ((a_r * b_r) % self.n.as_u128()) as u64;
debug_assert_eq!( debug_assert_eq!(
r_r, r_2, r_r, r_2,
"[{}] = {} ≠ {} = {} * {} = [{}] * [{}] mod {}; a = {}", "[{}] = {} ≠ {} = {} * {} = [{}] * [{}] mod {}; a = {}",
@ -172,35 +194,112 @@ impl Arithmetic for Montgomery {
} }
} }
// NOTE: This trait can go away once num-traits provides an equivalent;
// see https://github.com/rust-num/num-traits/issues/168
/// Trait-level access to the inherent `overflowing_add` of primitive integers,
/// so that generic code can call it.
pub(crate) trait OverflowingAdd: Sized {
    /// Returns `self + n` (wrapped on overflow) together with a flag telling
    /// whether the addition overflowed.
    fn overflowing_add_(self, n: Self) -> (Self, bool);
}

// Implements `OverflowingAdd` for each listed type by forwarding to the
// inherent `overflowing_add`.
macro_rules! overflowing {
    ($($int:ty),+ $(,)?) => {
        $(
            impl OverflowingAdd for $int {
                fn overflowing_add_(self, n: Self) -> (Self, bool) {
                    <$int>::overflowing_add(self, n)
                }
            }
        )+
    };
}
overflowing!(u32, u64, u128);
/// Integer types usable by the generic arithmetic in this file: the listed
/// num-traits capabilities plus conversions to and from fixed-width integers.
/// Implemented below (through `int!`) for `u32`, `u64` and `u128`.
pub(crate) trait Int:
Display + Debug + PrimInt + OverflowingAdd + WrappingNeg + WrappingSub + WrappingMul
{
/// Converts the value to `u64`.
fn as_u64(&self) -> u64;
/// Builds a value of this type from a `u64`.
// NOTE(review): the `int!` implementations use a plain `as` cast, so a type
// narrower than `u64` keeps only the low bits of `n`.
fn from_u64(n: u64) -> Self;
/// Converts the value to `u128`; only compiled when debug assertions are on.
#[cfg(debug_assertions)]
fn as_u128(&self) -> u128;
}
/// An `Int` for which an integer type twice as wide is available.
pub(crate) trait DoubleInt: Int {
/// An integer type with twice the width of `Self`.
/// In particular, multiplications (of `Int` values) can be performed in
/// `Self::DoubleWidth` without possibility of overflow.
type DoubleWidth: Int;
/// Converts `self` to the double-width type.
fn as_double_width(self) -> Self::DoubleWidth;
/// Converts a double-width value back to `Self`.
// NOTE(review): the `double_int!` implementations use an `as` cast, i.e. they
// keep the low half of `n` and drop the rest.
fn from_double_width(n: Self::DoubleWidth) -> Self;
}
// Implements `Int` for a primitive integer type; every conversion is a plain
// `as` cast.
macro_rules! int {
    ( $t:ty ) => {
        impl Int for $t {
            fn as_u64(&self) -> u64 {
                let value: $t = *self;
                value as u64
            }

            fn from_u64(n: u64) -> Self {
                n as $t
            }

            #[cfg(debug_assertions)]
            fn as_u128(&self) -> u128 {
                let value: $t = *self;
                value as u128
            }
        }
    };
}
// Implements `Int` for `$narrow`, then `DoubleInt` with `$wide` as its
// double-width companion; both conversions are plain `as` casts.
macro_rules! double_int {
    ( $narrow:ty, $wide:ty ) => {
        int!($narrow);

        impl DoubleInt for $narrow {
            type DoubleWidth = $wide;

            fn as_double_width(self) -> Self::DoubleWidth {
                self as $wide
            }

            fn from_double_width(n: Self::DoubleWidth) -> Self {
                n as $narrow
            }
        }
    };
}
double_int!(u32, u64);
double_int!(u64, u128);
int!(u128);
// extended Euclid algorithm // extended Euclid algorithm
// precondition: a is odd // precondition: a is odd
pub(crate) fn inv_mod_u64(a: u64) -> u64 { pub(crate) fn modular_inverse<T: Int>(a: T) -> T {
assert!(a % 2 == 1, "{} is not odd", a); let zero = T::zero();
let mut t = 0u64; let one = T::one();
let mut newt = 1u64; debug_assert!(a % (one + one) == one, "{:?} is not odd", a);
let mut r = 0u64;
let mut t = zero;
let mut newt = one;
let mut r = zero;
let mut newr = a; let mut newr = a;
while newr != 0 { while newr != zero {
let quot = if r == 0 { let quot = if r == zero {
// special case when we're just starting out // special case when we're just starting out
// This works because we know that // This works because we know that
// a does not divide 2^64, so floor(2^64 / a) == floor((2^64-1) / a); // a does not divide 2^64, so floor(2^64 / a) == floor((2^64-1) / a);
std::u64::MAX T::max_value()
} else { } else {
r r
} / newr; } / newr;
let newtp = t.wrapping_sub(quot.wrapping_mul(newt)); let newtp = t.wrapping_sub(&quot.wrapping_mul(&newt));
t = newt; t = newt;
newt = newtp; newt = newtp;
let newrp = r.wrapping_sub(quot.wrapping_mul(newr)); let newrp = r.wrapping_sub(&quot.wrapping_mul(&newr));
r = newr; r = newr;
newr = newrp; newr = newrp;
} }
assert_eq!(r, 1); debug_assert_eq!(r, one);
t t
} }
@ -208,19 +307,37 @@ pub(crate) fn inv_mod_u64(a: u64) -> u64 {
mod tests { mod tests {
use super::*; use super::*;
macro_rules! parametrized_check {
( $f:ident ) => {
paste::item! {
#[test] #[test]
fn test_inverter() { fn [< $f _ u32 >]() {
// All odd integers from 1 to 20 000 $f::<u32>()
let mut test_values = (0..10_000u64).map(|i| 2 * i + 1); }
#[test]
assert!(test_values.all(|x| x.wrapping_mul(inv_mod_u64(x)) == 1)); fn [< $f _ u64 >]() {
$f::<u64>()
}
}
};
} }
#[test] fn test_inverter<T: Int>() {
fn test_montgomery_add() { // All odd integers from 1 to 20 000
let one = T::from(1).unwrap();
let two = T::from(2).unwrap();
let mut test_values = (0..10_000)
.map(|i| T::from(i).unwrap())
.map(|i| two * i + one);
assert!(test_values.all(|x| x.wrapping_mul(&modular_inverse(x)) == one));
}
parametrized_check!(test_inverter);
fn test_add<A: DoubleInt>() {
for n in 0..100 { for n in 0..100 {
let n = 2 * n + 1; let n = 2 * n + 1;
let m = Montgomery::new(n); let m = Montgomery::<A>::new(n);
for x in 0..n { for x in 0..n {
let m_x = m.from_u64(x); let m_x = m.from_u64(x);
for y in 0..=x { for y in 0..=x {
@ -231,12 +348,12 @@ mod tests {
} }
} }
} }
parametrized_check!(test_add);
#[test] fn test_mult<A: DoubleInt>() {
fn test_montgomery_mult() {
for n in 0..100 { for n in 0..100 {
let n = 2 * n + 1; let n = 2 * n + 1;
let m = Montgomery::new(n); let m = Montgomery::<A>::new(n);
for x in 0..n { for x in 0..n {
let m_x = m.from_u64(x); let m_x = m.from_u64(x);
for y in 0..=x { for y in 0..=x {
@ -246,16 +363,17 @@ mod tests {
} }
} }
} }
parametrized_check!(test_mult);
#[test] fn test_roundtrip<A: DoubleInt>() {
fn test_montgomery_roundtrip() {
for n in 0..100 { for n in 0..100 {
let n = 2 * n + 1; let n = 2 * n + 1;
let m = Montgomery::new(n); let m = Montgomery::<A>::new(n);
for x in 0..n { for x in 0..n {
let x_ = m.from_u64(x); let x_ = m.from_u64(x);
assert_eq!(x, m.to_u64(x_)); assert_eq!(x, m.to_u64(x_));
} }
} }
} }
parametrized_check!(test_roundtrip);
} }