mirror of
https://github.com/uutils/coreutils
synced 2024-12-14 07:12:44 +00:00
Merge pull request #1554 from nbraud/factor/faster/montgomery32
factor: Refactor and improve performance (plus a few bug fixes)
This commit is contained in:
commit
8cda0f596e
6 changed files with 339 additions and 103 deletions
27
Cargo.lock
generated
27
Cargo.lock
generated
|
@ -649,6 +649,23 @@ dependencies = [
|
|||
"pkg-config 0.3.17 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "0.1.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"paste-impl 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"proc-macro-hack 0.5.16 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paste-impl"
|
||||
version = "0.1.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
dependencies = [
|
||||
"proc-macro-hack 0.5.16 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pkg-config"
|
||||
version = "0.3.17"
|
||||
|
@ -668,6 +685,11 @@ name = "ppv-lite86"
|
|||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-hack"
|
||||
version = "0.5.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.18"
|
||||
|
@ -1335,6 +1357,8 @@ dependencies = [
|
|||
name = "uu_factor"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"num-traits 0.2.12 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"paste 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"quickcheck 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"rand 0.5.6 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"uucore 0.0.4 (git+https://github.com/uutils/uucore.git?branch=canary)",
|
||||
|
@ -2253,9 +2277,12 @@ dependencies = [
|
|||
"checksum numtoa 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "b8f8bdf33df195859076e54ab11ee78a1b208382d3a26ec40d142ffc1ecc49ef"
|
||||
"checksum onig 4.3.3 (registry+https://github.com/rust-lang/crates.io-index)" = "8518fcb2b1b8c2f45f0ad499df4fda6087fc3475ca69a185c173b8315d2fb383"
|
||||
"checksum onig_sys 69.1.0 (registry+https://github.com/rust-lang/crates.io-index)" = "388410bf5fa341f10e58e6db3975f4bea1ac30247dd79d37a9e5ced3cb4cc3b0"
|
||||
"checksum paste 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)" = "45ca20c77d80be666aef2b45486da86238fabe33e38306bd3118fe4af33fa880"
|
||||
"checksum paste-impl 0.1.18 (registry+https://github.com/rust-lang/crates.io-index)" = "d95a7db200b97ef370c8e6de0088252f7e0dfff7d047a28528e47456c0fc98b6"
|
||||
"checksum pkg-config 0.3.17 (registry+https://github.com/rust-lang/crates.io-index)" = "05da548ad6865900e60eaba7f589cc0783590a92e940c26953ff81ddbab2d677"
|
||||
"checksum platform-info 0.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f2fd076acdc7a98374de6e300bf3af675997225bef21aecac2219553f04dd7e8"
|
||||
"checksum ppv-lite86 0.2.8 (registry+https://github.com/rust-lang/crates.io-index)" = "237a5ed80e274dbc66f86bd59c1e25edc039660be53194b5fe0a482e0f2612ea"
|
||||
"checksum proc-macro-hack 0.5.16 (registry+https://github.com/rust-lang/crates.io-index)" = "7e0456befd48169b9f13ef0f0ad46d492cf9d2dbb918bcf38e01eed4ce3ec5e4"
|
||||
"checksum proc-macro2 1.0.18 (registry+https://github.com/rust-lang/crates.io-index)" = "beae6331a816b1f65d04c45b078fd8e6c93e8071771f41b8163255bbd8d7c8fa"
|
||||
"checksum quick-error 1.2.3 (registry+https://github.com/rust-lang/crates.io-index)" = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
|
||||
"checksum quickcheck 0.9.2 (registry+https://github.com/rust-lang/crates.io-index)" = "a44883e74aa97ad63db83c4bf8ca490f02b2fc02f92575e720c8551e843c945f"
|
||||
|
|
|
@ -11,12 +11,18 @@ keywords = ["coreutils", "uutils", "cross-platform", "cli", "utility"]
|
|||
categories = ["command-line-utilities"]
|
||||
edition = "2018"
|
||||
|
||||
[build-dependencies]
|
||||
num-traits = "0.2" # used in src/numerics.rs, which is included by build.rs
|
||||
|
||||
|
||||
[dependencies]
|
||||
num-traits = "0.2"
|
||||
rand = "0.5"
|
||||
uucore = { version="0.0.4", package="uucore", git="https://github.com/uutils/uucore.git", branch="canary" }
|
||||
uucore_procs = { version="0.0.4", package="uucore_procs", git="https://github.com/uutils/uucore.git", branch="canary" }
|
||||
|
||||
[dev-dependencies]
|
||||
paste = "0.1.18"
|
||||
quickcheck = "0.9.2"
|
||||
|
||||
[[bin]]
|
||||
|
|
|
@ -29,7 +29,7 @@ use miller_rabin::is_prime;
|
|||
|
||||
#[path = "src/numeric.rs"]
|
||||
mod numeric;
|
||||
use numeric::inv_mod_u64;
|
||||
use numeric::modular_inverse;
|
||||
|
||||
mod sieve;
|
||||
|
||||
|
@ -57,7 +57,7 @@ fn main() {
|
|||
let mut x = primes.next().unwrap();
|
||||
for next in primes {
|
||||
// format the table
|
||||
let outstr = format!("({}, {}, {}),", x, inv_mod_u64(x), std::u64::MAX / x);
|
||||
let outstr = format!("({}, {}, {}),", x, modular_inverse(x), std::u64::MAX / x);
|
||||
if cols + outstr.len() > MAX_WIDTH {
|
||||
write!(file, "\n {}", outstr).unwrap();
|
||||
cols = 4 + outstr.len();
|
||||
|
|
|
@ -54,12 +54,16 @@ impl fmt::Display for Factors {
|
|||
}
|
||||
}
|
||||
|
||||
fn _factor<A: Arithmetic>(num: u64, f: Factors) -> Factors {
|
||||
fn _factor<A: Arithmetic + miller_rabin::Basis>(num: u64, f: Factors) -> Factors {
|
||||
use miller_rabin::Result::*;
|
||||
|
||||
// Shadow the name, so the recursion automatically goes from “Big” arithmetic to small.
|
||||
let _factor = |n, f| {
|
||||
// TODO: Optimise with 32 and 64b versions
|
||||
_factor::<A>(n, f)
|
||||
if n < (1 << 32) {
|
||||
_factor::<Montgomery<u32>>(n, f)
|
||||
} else {
|
||||
_factor::<A>(n, f)
|
||||
}
|
||||
};
|
||||
|
||||
if num == 1 {
|
||||
|
@ -101,8 +105,11 @@ pub fn factor(mut n: u64) -> Factors {
|
|||
|
||||
let (factors, n) = table::factor(n, factors);
|
||||
|
||||
// TODO: Optimise with 32 and 64b versions
|
||||
_factor::<Montgomery>(n, factors)
|
||||
if n < (1 << 32) {
|
||||
_factor::<Montgomery<u32>>(n, factors)
|
||||
} else {
|
||||
_factor::<Montgomery<u64>>(n, factors)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -2,10 +2,27 @@
|
|||
|
||||
use crate::numeric::*;
|
||||
|
||||
// Small set of bases for the Miller-Rabin prime test, valid for all 64b integers;
|
||||
// discovered by Jim Sinclair on 2011-04-20, see miller-rabin.appspot.com
|
||||
#[allow(clippy::unreadable_literal)]
|
||||
const BASIS: [u64; 7] = [2, 325, 9375, 28178, 450775, 9780504, 1795265022];
|
||||
pub(crate) trait Basis {
|
||||
const BASIS: &'static [u64];
|
||||
}
|
||||
|
||||
impl Basis for Montgomery<u64> {
|
||||
// Small set of bases for the Miller-Rabin prime test, valid for all 64b integers;
|
||||
// discovered by Jim Sinclair on 2011-04-20, see miller-rabin.appspot.com
|
||||
#[allow(clippy::unreadable_literal)]
|
||||
const BASIS: &'static [u64] = &[2, 325, 9375, 28178, 450775, 9780504, 1795265022];
|
||||
}
|
||||
|
||||
impl Basis for Montgomery<u32> {
|
||||
// Small set of bases for the Miller-Rabin prime test, valid for all 32b integers;
|
||||
// discovered by Steve Worley on 2013-05-27, see miller-rabin.appspot.com
|
||||
#[allow(clippy::unreadable_literal)]
|
||||
const BASIS: &'static [u64] = &[
|
||||
4230279247111683200,
|
||||
14694767155120705706,
|
||||
16641139526367750375,
|
||||
];
|
||||
}
|
||||
|
||||
#[derive(Eq, PartialEq)]
|
||||
pub(crate) enum Result {
|
||||
|
@ -23,16 +40,12 @@ impl Result {
|
|||
// Deterministic Miller-Rabin primality-checking algorithm, adapted to extract
|
||||
// (some) dividers; it will fail to factor strong pseudoprimes.
|
||||
#[allow(clippy::many_single_char_names)]
|
||||
pub(crate) fn test<A: Arithmetic>(m: A) -> Result {
|
||||
pub(crate) fn test<A: Arithmetic + Basis>(m: A) -> Result {
|
||||
use self::Result::*;
|
||||
|
||||
let n = m.modulus();
|
||||
if n < 2 {
|
||||
return Pseudoprime;
|
||||
}
|
||||
if n % 2 == 0 {
|
||||
return if n == 2 { Prime } else { Composite(2) };
|
||||
}
|
||||
debug_assert!(n > 1);
|
||||
debug_assert!(n % 2 != 0);
|
||||
|
||||
// n-1 = r 2ⁱ
|
||||
let i = (n - 1).trailing_zeros();
|
||||
|
@ -41,10 +54,10 @@ pub(crate) fn test<A: Arithmetic>(m: A) -> Result {
|
|||
let one = m.one();
|
||||
let minus_one = m.minus_one();
|
||||
|
||||
for _a in BASIS.iter() {
|
||||
for _a in A::BASIS.iter() {
|
||||
let _a = _a % n;
|
||||
if _a == 0 {
|
||||
break;
|
||||
continue;
|
||||
}
|
||||
|
||||
let a = m.from_u64(_a);
|
||||
|
@ -87,48 +100,113 @@ pub(crate) fn test<A: Arithmetic>(m: A) -> Result {
|
|||
// Used by build.rs' tests and debug assertions
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn is_prime(n: u64) -> bool {
|
||||
if n % 2 == 0 {
|
||||
if n < 2 {
|
||||
false
|
||||
} else if n % 2 == 0 {
|
||||
n == 2
|
||||
} else {
|
||||
test::<Montgomery>(Montgomery::new(n)).is_prime()
|
||||
test::<Montgomery<u64>>(Montgomery::new(n)).is_prime()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::is_prime;
|
||||
use super::*;
|
||||
use crate::numeric::{Arithmetic, Montgomery};
|
||||
use quickcheck::quickcheck;
|
||||
use std::iter;
|
||||
const LARGEST_U64_PRIME: u64 = 0xFFFFFFFFFFFFFFC5;
|
||||
|
||||
fn primes() -> impl Iterator<Item = u64> {
|
||||
iter::once(2).chain(odd_primes())
|
||||
}
|
||||
|
||||
fn odd_primes() -> impl Iterator<Item = u64> {
|
||||
use crate::table::{NEXT_PRIME, P_INVS_U64};
|
||||
P_INVS_U64
|
||||
.iter()
|
||||
.map(|(p, _, _)| *p)
|
||||
.chain(iter::once(NEXT_PRIME))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn largest_prime() {
|
||||
assert!(is_prime(LARGEST_U64_PRIME));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn first_primes() {
|
||||
use crate::table::{NEXT_PRIME, P_INVS_U64};
|
||||
for (p, _, _) in P_INVS_U64.iter() {
|
||||
assert!(is_prime(*p), "{} reported composite", p);
|
||||
fn largest_composites() {
|
||||
for i in LARGEST_U64_PRIME + 1..=u64::MAX {
|
||||
assert!(!is_prime(i), "2⁶⁴ - {} reported prime", u64::MAX - i + 1);
|
||||
}
|
||||
assert!(is_prime(NEXT_PRIME));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn two() {
|
||||
assert!(is_prime(2));
|
||||
}
|
||||
|
||||
// TODO: Deduplicate with macro in numeric.rs
|
||||
macro_rules! parametrized_check {
|
||||
( $f:ident ) => {
|
||||
paste::item! {
|
||||
#[test]
|
||||
fn [< $f _ u32 >]() {
|
||||
$f::<Montgomery<u32>>()
|
||||
}
|
||||
#[test]
|
||||
fn [< $f _ u64 >]() {
|
||||
$f::<Montgomery<u64>>()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn first_primes<A: Arithmetic + Basis>() {
|
||||
for p in odd_primes() {
|
||||
assert!(test(A::new(p)).is_prime(), "{} reported composite", p);
|
||||
}
|
||||
}
|
||||
parametrized_check!(first_primes);
|
||||
|
||||
#[test]
|
||||
fn one() {
|
||||
assert!(!is_prime(1));
|
||||
}
|
||||
#[test]
|
||||
fn zero() {
|
||||
assert!(!is_prime(0));
|
||||
}
|
||||
|
||||
fn first_composites<A: Arithmetic + Basis>() {
|
||||
for (p, q) in primes().zip(odd_primes()) {
|
||||
for i in p + 1..q {
|
||||
assert!(!is_prime(i), "{} reported prime", i);
|
||||
}
|
||||
}
|
||||
}
|
||||
parametrized_check!(first_composites);
|
||||
|
||||
#[test]
|
||||
fn issue_1556() {
|
||||
// 10 425 511 = 2441 × 4271
|
||||
assert!(!is_prime(10_425_511));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn small_composites() {
|
||||
use crate::table::P_INVS_U64;
|
||||
|
||||
for i in 0..P_INVS_U64.len() {
|
||||
let (p, _, _) = P_INVS_U64[i];
|
||||
for (q, _, _) in &P_INVS_U64[0..i] {
|
||||
fn small_semiprimes<A: Arithmetic + Basis>() {
|
||||
for p in odd_primes() {
|
||||
for q in odd_primes().take_while(|q| *q <= p) {
|
||||
let n = p * q;
|
||||
assert!(!is_prime(n), "{} = {} × {} reported prime", n, p, q);
|
||||
let m = A::new(n);
|
||||
assert!(!test(m).is_prime(), "{} = {} × {} reported prime", n, p, q);
|
||||
}
|
||||
}
|
||||
}
|
||||
parametrized_check!(small_semiprimes);
|
||||
|
||||
quickcheck! {
|
||||
fn composites(i: u64, j: u64) -> bool {
|
||||
i < 2 || j < 2 || !is_prime(i*j)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,12 @@
|
|||
// * For the full copyright and license information, please view the LICENSE file
|
||||
// * that was distributed with this source code.
|
||||
|
||||
use num_traits::{
|
||||
identities::{One, Zero},
|
||||
int::PrimInt,
|
||||
ops::wrapping::{WrappingMul, WrappingNeg, WrappingSub},
|
||||
};
|
||||
use std::fmt::{Debug, Display};
|
||||
use std::mem::swap;
|
||||
|
||||
// This is incorrectly reported as dead code,
|
||||
|
@ -20,16 +26,17 @@ pub(crate) fn gcd(mut a: u64, mut b: u64) -> u64 {
|
|||
}
|
||||
|
||||
pub(crate) trait Arithmetic: Copy + Sized {
|
||||
type I: Copy + Sized + Eq;
|
||||
// The type of integers mod m, in some opaque representation
|
||||
type ModInt: Copy + Sized + Eq;
|
||||
|
||||
fn new(m: u64) -> Self;
|
||||
fn modulus(&self) -> u64;
|
||||
fn from_u64(&self, n: u64) -> Self::I;
|
||||
fn to_u64(&self, n: Self::I) -> u64;
|
||||
fn add(&self, a: Self::I, b: Self::I) -> Self::I;
|
||||
fn mul(&self, a: Self::I, b: Self::I) -> Self::I;
|
||||
fn from_u64(&self, n: u64) -> Self::ModInt;
|
||||
fn to_u64(&self, n: Self::ModInt) -> u64;
|
||||
fn add(&self, a: Self::ModInt, b: Self::ModInt) -> Self::ModInt;
|
||||
fn mul(&self, a: Self::ModInt, b: Self::ModInt) -> Self::ModInt;
|
||||
|
||||
fn pow(&self, mut a: Self::I, mut b: u64) -> Self::I {
|
||||
fn pow(&self, mut a: Self::ModInt, mut b: u64) -> Self::ModInt {
|
||||
let (_a, _b) = (a, b);
|
||||
let mut result = self.one();
|
||||
while b > 0 {
|
||||
|
@ -54,75 +61,90 @@ pub(crate) trait Arithmetic: Copy + Sized {
|
|||
result
|
||||
}
|
||||
|
||||
fn one(&self) -> Self::I {
|
||||
fn one(&self) -> Self::ModInt {
|
||||
self.from_u64(1)
|
||||
}
|
||||
fn minus_one(&self) -> Self::I {
|
||||
fn minus_one(&self) -> Self::ModInt {
|
||||
self.from_u64(self.modulus() - 1)
|
||||
}
|
||||
fn zero(&self) -> Self::I {
|
||||
fn zero(&self) -> Self::ModInt {
|
||||
self.from_u64(0)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub(crate) struct Montgomery {
|
||||
a: u64,
|
||||
n: u64,
|
||||
pub(crate) struct Montgomery<T: DoubleInt> {
|
||||
a: T,
|
||||
n: T,
|
||||
}
|
||||
|
||||
impl Montgomery {
|
||||
impl<T: DoubleInt> Montgomery<T> {
|
||||
/// computes x/R mod n efficiently
|
||||
fn reduce(&self, x: u128) -> u64 {
|
||||
debug_assert!(x < (self.n as u128) << 64);
|
||||
fn reduce(&self, x: T::DoubleWidth) -> T {
|
||||
let t_bits = T::zero().count_zeros() as usize;
|
||||
|
||||
debug_assert!(x < (self.n.as_double_width()) << t_bits);
|
||||
// TODO: optimiiiiiiise
|
||||
let Montgomery { a, n } = self;
|
||||
let m = (x as u64).wrapping_mul(*a);
|
||||
let nm = (*n as u128) * (m as u128);
|
||||
let (xnm, overflow) = (x as u128).overflowing_add(nm); // x + n*m
|
||||
debug_assert_eq!(xnm % (1 << 64), 0);
|
||||
let m = T::from_double_width(x).wrapping_mul(a);
|
||||
let nm = (n.as_double_width()) * (m.as_double_width());
|
||||
let (xnm, overflow) = x.overflowing_add_(nm); // x + n*m
|
||||
debug_assert_eq!(
|
||||
xnm % (T::DoubleWidth::one() << T::zero().count_zeros() as usize),
|
||||
T::DoubleWidth::zero()
|
||||
);
|
||||
|
||||
// (x + n*m) / R
|
||||
// in case of overflow, this is (2¹²⁸ + xnm)/2⁶⁴ - n = xnm/2⁶⁴ + (2⁶⁴ - n)
|
||||
let y = (xnm >> 64) as u64 + if !overflow { 0 } else { n.wrapping_neg() };
|
||||
let y = T::from_double_width(xnm >> t_bits)
|
||||
+ if !overflow {
|
||||
T::zero()
|
||||
} else {
|
||||
n.wrapping_neg()
|
||||
};
|
||||
|
||||
if y >= *n {
|
||||
y - n
|
||||
y - *n
|
||||
} else {
|
||||
y
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Arithmetic for Montgomery {
|
||||
impl<T: DoubleInt> Arithmetic for Montgomery<T> {
|
||||
// Montgomery transform, R=2⁶⁴
|
||||
// Provides fast arithmetic mod n (n odd, u64)
|
||||
type I = u64;
|
||||
type ModInt = T;
|
||||
|
||||
fn new(n: u64) -> Self {
|
||||
let a = inv_mod_u64(n).wrapping_neg();
|
||||
debug_assert_eq!(n.wrapping_mul(a), 1_u64.wrapping_neg());
|
||||
debug_assert!(T::zero().count_zeros() >= 64 || n < (1 << T::zero().count_zeros() as usize));
|
||||
let n = T::from_u64(n);
|
||||
let a = modular_inverse(n).wrapping_neg();
|
||||
debug_assert_eq!(n.wrapping_mul(&a), T::one().wrapping_neg());
|
||||
Montgomery { a, n }
|
||||
}
|
||||
|
||||
fn modulus(&self) -> u64 {
|
||||
self.n
|
||||
self.n.as_u64()
|
||||
}
|
||||
|
||||
fn from_u64(&self, x: u64) -> Self::I {
|
||||
fn from_u64(&self, x: u64) -> Self::ModInt {
|
||||
// TODO: optimise!
|
||||
assert!(x < self.n);
|
||||
let r = (((x as u128) << 64) % self.n as u128) as u64;
|
||||
debug_assert!(x < self.n.as_u64());
|
||||
let r = T::from_double_width(
|
||||
((T::DoubleWidth::from_u64(x)) << T::zero().count_zeros() as usize)
|
||||
% self.n.as_double_width(),
|
||||
);
|
||||
debug_assert_eq!(x, self.to_u64(r));
|
||||
r
|
||||
}
|
||||
|
||||
fn to_u64(&self, n: Self::I) -> u64 {
|
||||
self.reduce(n as u128)
|
||||
fn to_u64(&self, n: Self::ModInt) -> u64 {
|
||||
self.reduce(n.as_double_width()).as_u64()
|
||||
}
|
||||
|
||||
fn add(&self, a: Self::I, b: Self::I) -> Self::I {
|
||||
let (r, overflow) = a.overflowing_add(b);
|
||||
fn add(&self, a: Self::ModInt, b: Self::ModInt) -> Self::ModInt {
|
||||
let (r, overflow) = a.overflowing_add_(b);
|
||||
|
||||
// In case of overflow, a+b = 2⁶⁴ + r = (2⁶⁴ - n) + r (working mod n)
|
||||
let r = if !overflow {
|
||||
|
@ -138,10 +160,10 @@ impl Arithmetic for Montgomery {
|
|||
// a+b % n
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let a_r = self.to_u64(a);
|
||||
let b_r = self.to_u64(b);
|
||||
let a_r = self.to_u64(a) as u128;
|
||||
let b_r = self.to_u64(b) as u128;
|
||||
let r_r = self.to_u64(r);
|
||||
let r_2 = (((a_r as u128) + (b_r as u128)) % (self.n as u128)) as u64;
|
||||
let r_2 = ((a_r + b_r) % self.n.as_u128()) as u64;
|
||||
debug_assert_eq!(
|
||||
r_r, r_2,
|
||||
"[{}] = {} ≠ {} = {} + {} = [{}] + [{}] mod {}; a = {}",
|
||||
|
@ -151,17 +173,17 @@ impl Arithmetic for Montgomery {
|
|||
r
|
||||
}
|
||||
|
||||
fn mul(&self, a: Self::I, b: Self::I) -> Self::I {
|
||||
let r = self.reduce((a as u128) * (b as u128));
|
||||
fn mul(&self, a: Self::ModInt, b: Self::ModInt) -> Self::ModInt {
|
||||
let r = self.reduce(a.as_double_width() * b.as_double_width());
|
||||
|
||||
// Check that r (reduced back to the usual representation) equals
|
||||
// a*b % n
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let a_r = self.to_u64(a);
|
||||
let b_r = self.to_u64(b);
|
||||
let a_r = self.to_u64(a) as u128;
|
||||
let b_r = self.to_u64(b) as u128;
|
||||
let r_r = self.to_u64(r);
|
||||
let r_2 = (((a_r as u128) * (b_r as u128)) % (self.n as u128)) as u64;
|
||||
let r_2: u64 = ((a_r * b_r) % self.n.as_u128()) as u64;
|
||||
debug_assert_eq!(
|
||||
r_r, r_2,
|
||||
"[{}] = {} ≠ {} = {} * {} = [{}] * [{}] mod {}; a = {}",
|
||||
|
@ -172,35 +194,112 @@ impl Arithmetic for Montgomery {
|
|||
}
|
||||
}
|
||||
|
||||
// NOTE: Trait can be removed once num-traits adds a similar one;
|
||||
// see https://github.com/rust-num/num-traits/issues/168
|
||||
pub(crate) trait OverflowingAdd: Sized {
|
||||
fn overflowing_add_(self, n: Self) -> (Self, bool);
|
||||
}
|
||||
macro_rules! overflowing {
|
||||
($x:ty) => {
|
||||
impl OverflowingAdd for $x {
|
||||
fn overflowing_add_(self, n: Self) -> (Self, bool) {
|
||||
self.overflowing_add(n)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
overflowing!(u32);
|
||||
overflowing!(u64);
|
||||
overflowing!(u128);
|
||||
|
||||
pub(crate) trait Int:
|
||||
Display + Debug + PrimInt + OverflowingAdd + WrappingNeg + WrappingSub + WrappingMul
|
||||
{
|
||||
fn as_u64(&self) -> u64;
|
||||
fn from_u64(n: u64) -> Self;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
fn as_u128(&self) -> u128;
|
||||
}
|
||||
|
||||
pub(crate) trait DoubleInt: Int {
|
||||
/// An integer type with twice the width of `Self`.
|
||||
/// In particular, multiplications (of `Int` values) can be performed in
|
||||
/// `Self::DoubleWidth` without possibility of overflow.
|
||||
type DoubleWidth: Int;
|
||||
|
||||
fn as_double_width(self) -> Self::DoubleWidth;
|
||||
fn from_double_width(n: Self::DoubleWidth) -> Self;
|
||||
}
|
||||
|
||||
macro_rules! int {
|
||||
( $x:ty ) => {
|
||||
impl Int for $x {
|
||||
fn as_u64(&self) -> u64 {
|
||||
*self as u64
|
||||
}
|
||||
fn from_u64(n: u64) -> Self {
|
||||
n as _
|
||||
}
|
||||
#[cfg(debug_assertions)]
|
||||
fn as_u128(&self) -> u128 {
|
||||
*self as u128
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
macro_rules! double_int {
|
||||
( $x:ty, $y:ty ) => {
|
||||
int!($x);
|
||||
impl DoubleInt for $x {
|
||||
type DoubleWidth = $y;
|
||||
|
||||
fn as_double_width(self) -> $y {
|
||||
self as _
|
||||
}
|
||||
fn from_double_width(n: $y) -> $x {
|
||||
n as _
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
double_int!(u32, u64);
|
||||
double_int!(u64, u128);
|
||||
int!(u128);
|
||||
|
||||
// extended Euclid algorithm
|
||||
// precondition: a is odd
|
||||
pub(crate) fn inv_mod_u64(a: u64) -> u64 {
|
||||
assert!(a % 2 == 1, "{} is not odd", a);
|
||||
let mut t = 0u64;
|
||||
let mut newt = 1u64;
|
||||
let mut r = 0u64;
|
||||
pub(crate) fn modular_inverse<T: Int>(a: T) -> T {
|
||||
let zero = T::zero();
|
||||
let one = T::one();
|
||||
debug_assert!(a % (one + one) == one, "{:?} is not odd", a);
|
||||
|
||||
let mut t = zero;
|
||||
let mut newt = one;
|
||||
let mut r = zero;
|
||||
let mut newr = a;
|
||||
|
||||
while newr != 0 {
|
||||
let quot = if r == 0 {
|
||||
while newr != zero {
|
||||
let quot = if r == zero {
|
||||
// special case when we're just starting out
|
||||
// This works because we know that
|
||||
// a does not divide 2^64, so floor(2^64 / a) == floor((2^64-1) / a);
|
||||
std::u64::MAX
|
||||
T::max_value()
|
||||
} else {
|
||||
r
|
||||
} / newr;
|
||||
|
||||
let newtp = t.wrapping_sub(quot.wrapping_mul(newt));
|
||||
let newtp = t.wrapping_sub(".wrapping_mul(&newt));
|
||||
t = newt;
|
||||
newt = newtp;
|
||||
|
||||
let newrp = r.wrapping_sub(quot.wrapping_mul(newr));
|
||||
let newrp = r.wrapping_sub(".wrapping_mul(&newr));
|
||||
r = newr;
|
||||
newr = newrp;
|
||||
}
|
||||
|
||||
assert_eq!(r, 1);
|
||||
debug_assert_eq!(r, one);
|
||||
t
|
||||
}
|
||||
|
||||
|
@ -208,19 +307,37 @@ pub(crate) fn inv_mod_u64(a: u64) -> u64 {
|
|||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_inverter() {
|
||||
// All odd integers from 1 to 20 000
|
||||
let mut test_values = (0..10_000u64).map(|i| 2 * i + 1);
|
||||
|
||||
assert!(test_values.all(|x| x.wrapping_mul(inv_mod_u64(x)) == 1));
|
||||
macro_rules! parametrized_check {
|
||||
( $f:ident ) => {
|
||||
paste::item! {
|
||||
#[test]
|
||||
fn [< $f _ u32 >]() {
|
||||
$f::<u32>()
|
||||
}
|
||||
#[test]
|
||||
fn [< $f _ u64 >]() {
|
||||
$f::<u64>()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_montgomery_add() {
|
||||
fn test_inverter<T: Int>() {
|
||||
// All odd integers from 1 to 20 000
|
||||
let one = T::from(1).unwrap();
|
||||
let two = T::from(2).unwrap();
|
||||
let mut test_values = (0..10_000)
|
||||
.map(|i| T::from(i).unwrap())
|
||||
.map(|i| two * i + one);
|
||||
|
||||
assert!(test_values.all(|x| x.wrapping_mul(&modular_inverse(x)) == one));
|
||||
}
|
||||
parametrized_check!(test_inverter);
|
||||
|
||||
fn test_add<A: DoubleInt>() {
|
||||
for n in 0..100 {
|
||||
let n = 2 * n + 1;
|
||||
let m = Montgomery::new(n);
|
||||
let m = Montgomery::<A>::new(n);
|
||||
for x in 0..n {
|
||||
let m_x = m.from_u64(x);
|
||||
for y in 0..=x {
|
||||
|
@ -231,12 +348,12 @@ mod tests {
|
|||
}
|
||||
}
|
||||
}
|
||||
parametrized_check!(test_add);
|
||||
|
||||
#[test]
|
||||
fn test_montgomery_mult() {
|
||||
fn test_mult<A: DoubleInt>() {
|
||||
for n in 0..100 {
|
||||
let n = 2 * n + 1;
|
||||
let m = Montgomery::new(n);
|
||||
let m = Montgomery::<A>::new(n);
|
||||
for x in 0..n {
|
||||
let m_x = m.from_u64(x);
|
||||
for y in 0..=x {
|
||||
|
@ -246,16 +363,17 @@ mod tests {
|
|||
}
|
||||
}
|
||||
}
|
||||
parametrized_check!(test_mult);
|
||||
|
||||
#[test]
|
||||
fn test_montgomery_roundtrip() {
|
||||
fn test_roundtrip<A: DoubleInt>() {
|
||||
for n in 0..100 {
|
||||
let n = 2 * n + 1;
|
||||
let m = Montgomery::new(n);
|
||||
let m = Montgomery::<A>::new(n);
|
||||
for x in 0..n {
|
||||
let x_ = m.from_u64(x);
|
||||
assert_eq!(x, m.to_u64(x_));
|
||||
}
|
||||
}
|
||||
}
|
||||
parametrized_check!(test_roundtrip);
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue