Modular exponentiation with Barrett reduction. Seems slow. :(

This commit is contained in:
2018-06-18 12:04:11 -07:00
parent 011ebc0c99
commit b30fe6a75f
20 changed files with 66276 additions and 342 deletions

View File

@@ -5,7 +5,6 @@ use cryptonum::comparison::bignum_ge;
use cryptonum::subtraction::raw_subtraction;
use std::ops::{Add,AddAssign};
#[inline(always)]
pub fn raw_addition(x: &mut [u64], y: &[u64]) -> u64 {
assert_eq!(x.len(), y.len());

View File

@@ -4,7 +4,9 @@ use cryptonum::{U192, U256, U384, U512, U576,
use cryptonum::addition::{ModAdd,raw_addition};
use cryptonum::comparison::{bignum_cmp,bignum_ge};
use cryptonum::division::divmod;
use cryptonum::multiplication::raw_multiplication;
use cryptonum::exponentiation::ModExp;
use cryptonum::multiplication::{ModMul,raw_multiplication};
use cryptonum::squaring::{ModSquare,raw_square};
use cryptonum::subtraction::raw_subtraction;
use std::cmp::{Ordering,min};
use std::fmt;
@@ -63,9 +65,7 @@ macro_rules! generate_barrett_implementations {
}
pub fn reduce(&self, x: &mut $name) {
printvar("x", &x.values);
printvar("m", &self.m);
printvar("u", &self.mu);
assert!(self.reduce_ok(&x));
// 1. q1 ← ⌊x / b^(k-1)⌋, q2 ← q1 · μ, q3 ← ⌊q2 / b^(k+1)⌋.
let mut q1 = [0; $size/32];
shiftr(&x.values, self.k - 1, &mut q1);
@@ -100,6 +100,15 @@ macro_rules! generate_barrett_implementations {
*val = r[idx];
}
}
/// Reports whether `x` is acceptable input for `reduce`: every digit of
/// `x.values` at index `2 * self.k` or above must be zero.
pub fn reduce_ok(&self, x: &$name) -> bool {
    x.values
        .iter()
        .skip(self.k * 2)
        .all(|digit| *digit == 0)
}
}
impl ModAdd<$bname> for $name {
@@ -122,15 +131,122 @@ macro_rules! generate_barrett_implementations {
}
}
}
};
}
fn printvar(name: &'static str, val: &[u64]) {
print!("{}: 0x", name);
for x in val.iter().rev() {
print!("{:016X}", *x);
impl ModMul<$bname> for $name {
    /// Replaces `self` with `(self * x) mod m`, reducing the double-width
    /// product with Barrett reduction driven by the precomputed values
    /// stored in `m` (`m.k`, `m.m` and `m.mu`).
    fn modmul(&mut self, x: &$name, m: &$bname) {
        // Double-width product of the two operands.
        let mut mulres = [0; $size/32];
        raw_multiplication(&self.values, &x.values, &mut mulres);
        // 1. q1 <- floor(x / b^(k-1)), q2 <- q1 * mu, q3 <- floor(q2 / b^(k+1)).
        let mut q1 = [0; $size/32];
        shiftr(&mulres, m.k - 1, &mut q1);
        let mut q2 = [0; $size/16];
        raw_multiplication(&q1, &m.mu, &mut q2);
        let mut q3 = [0; $size/16];
        shiftr(&q2, m.k + 1, &mut q3);
        // 2. r1 <- x mod b^(k+1), r2 <- q3 * m mod b^(k+1), r <- r1 - r2.
        let mut r = [0; $size/16];
        let copylen = min(m.k + 1, mulres.len());
        r[..copylen].copy_from_slice(&mulres[..copylen]);
        let mut r2big = [0; $size/8];
        // The modulus widened to the same digit count as q3 and r.
        let mut mwider = [0; $size/16];
        for i in 0..$size/32 { mwider[i] = m.m[i]; }
        raw_multiplication(&q3, &mwider, &mut r2big);
        let mut r2 = [0; $size/16];
        r2[..m.k + 1].copy_from_slice(&r2big[..m.k + 1]);
        let went_negative = !bignum_ge(&r, &r2);
        raw_subtraction(&mut r, &r2);
        // 3. If r < 0 then r <- r + b^(k+1).
        if went_negative {
            // b^(k+1) must be as wide as `r`: raw_addition asserts that
            // both operands have the same number of digits.
            let mut bk1 = [0; $size/16];
            bk1[m.k + 1] = 1;
            raw_addition(&mut r, &bk1);
        }
        // 4. While r >= m do: r <- r - m.  (r == m must reduce to zero.)
        while bignum_cmp(&r, &mwider) != Ordering::Less {
            raw_subtraction(&mut r, &mwider);
        }
        // Copy the low digits of the fully reduced value back into self.
        for (idx, val) in self.values.iter_mut().enumerate() {
            *val = r[idx];
        }
    }
}
impl ModSquare<$bname> for $name {
    /// Replaces `self` with `(self * self) mod m`, reducing the
    /// double-width square with Barrett reduction driven by the
    /// precomputed values stored in `m` (`m.k`, `m.m` and `m.mu`).
    fn modsq(&mut self, m: &$bname) {
        // Double-width square of the operand.
        let mut sqres = [0; $size/32];
        raw_square(&self.values, &mut sqres);
        // 1. q1 <- floor(x / b^(k-1)), q2 <- q1 * mu, q3 <- floor(q2 / b^(k+1)).
        let mut q1 = [0; $size/32];
        shiftr(&sqres, m.k - 1, &mut q1);
        let mut q2 = [0; $size/16];
        raw_multiplication(&q1, &m.mu, &mut q2);
        let mut q3 = [0; $size/16];
        shiftr(&q2, m.k + 1, &mut q3);
        // 2. r1 <- x mod b^(k+1), r2 <- q3 * m mod b^(k+1), r <- r1 - r2.
        let mut r = [0; $size/16];
        let copylen = min(m.k + 1, sqres.len());
        r[..copylen].copy_from_slice(&sqres[..copylen]);
        let mut r2big = [0; $size/8];
        // The modulus widened to the same digit count as q3 and r.
        let mut mwider = [0; $size/16];
        for i in 0..$size/32 { mwider[i] = m.m[i]; }
        raw_multiplication(&q3, &mwider, &mut r2big);
        let mut r2 = [0; $size/16];
        r2[..m.k + 1].copy_from_slice(&r2big[..m.k + 1]);
        let went_negative = !bignum_ge(&r, &r2);
        raw_subtraction(&mut r, &r2);
        // 3. If r < 0 then r <- r + b^(k+1).
        if went_negative {
            // b^(k+1) must be as wide as `r`: raw_addition asserts that
            // both operands have the same number of digits.
            let mut bk1 = [0; $size/16];
            bk1[m.k + 1] = 1;
            raw_addition(&mut r, &bk1);
        }
        // 4. While r >= m do: r <- r - m.  (r == m must reduce to zero.)
        while bignum_cmp(&r, &mwider) != Ordering::Less {
            raw_subtraction(&mut r, &mwider);
        }
        // Copy the low digits of the fully reduced value back into self.
        for (idx, val) in self.values.iter_mut().enumerate() {
            *val = r[idx];
        }
    }
}
impl ModExp<$bname> for $name {
    /// Replaces `self` with `self^e mod m`, scanning the bits of `e` from
    /// least to most significant and combining `modmul` and `modsq`.
    ///
    /// `self` is first passed through `m.reduce`, which asserts
    /// `reduce_ok` on it.
    fn modexp(&mut self, e: &$name, m: &$bname) {
        // S <- g
        let mut s = self.clone();
        m.reduce(&mut s);
        // A <- 1
        for val in self.values.iter_mut() { *val = 0; }
        self.values[0] = 1;
        // Count the significant bits of e (0 when e == 0), so that no
        // squarings are spent on the zero bits above the top set bit.
        let mut bitlen = 0;
        for (idx, val) in e.values.iter().enumerate() {
            if *val != 0 {
                bitlen = idx * 64 + (64 - val.leading_zeros() as usize);
            }
        }
        // While e != 0 do the following:
        //   If e is odd then A <- A * S
        //   e <- floor(e / 2)
        //   If e != 0 then S <- S * S
        for bit in 0..bitlen {
            if (e.values[bit / 64] >> (bit % 64)) & 1 != 0 {
                self.modmul(&s, m);
            }
            // The square computed after the final bit would never be used.
            if bit + 1 < bitlen {
                s.modsq(m);
            }
        }
        // Return A (left in self).
    }
}
};
}
fn shiftr(x: &[u64], amt: usize, dest: &mut [u64])
@@ -254,6 +370,76 @@ macro_rules! generate_tests {
}
)*
}
#[cfg(test)]
mod modmul {
    use cryptonum::encoding::Decoder;
    use super::*;
    use testing::run_test;

    $(
        // Feeds every case of the per-type `modmul` vector file through
        // `modmul` and compares the outcome with the stored value `c`.
        #[test]
        #[allow(non_snake_case)]
        fn $name() {
            let vector_path = format!("tests/math/modmul{}.test",
                                      stringify!($name));
            run_test(vector_path.to_string(), 4, |case| {
                // Each entry is a (flag, bytes) pair; the flag is
                // presumably a sign marker and must be clear everywhere.
                let (a_neg, a_raw) = case.get("a").unwrap();
                let (b_neg, b_raw) = case.get("b").unwrap();
                let (c_neg, c_raw) = case.get("c").unwrap();
                let (m_neg, m_raw) = case.get("m").unwrap();
                assert!(!a_neg && !b_neg && !c_neg && !m_neg);

                let mut product = $name::from_bytes(a_raw);
                let multiplier = $name::from_bytes(b_raw);
                let modulus = $name::from_bytes(m_raw);
                let barrett = $bname::new(&modulus);
                let expected = $name::from_bytes(c_raw);

                product.modmul(&multiplier, &barrett);
                assert_eq!(product, expected);
            });
        }
    )*
}
#[cfg(test)]
mod modexp {
    use cryptonum::encoding::{Decoder,raw_decoder};
    use super::*;
    use testing::run_test;

    $(
        // Feeds every case of the per-type `bmodexp` vector file through
        // `modexp`, building the Barrett value directly from the stored
        // `k`, `m` and `u` fields rather than recomputing it.
        #[test]
        #[allow(non_snake_case)]
        fn $name() {
            let vector_path = format!("tests/math/bmodexp{}.test",
                                      stringify!($name));
            run_test(vector_path.to_string(), 6, |case| {
                // Each entry is a (flag, bytes) pair; the flag is
                // presumably a sign marker and must be clear everywhere.
                let (b_neg, b_raw) = case.get("b").unwrap();
                let (e_neg, e_raw) = case.get("e").unwrap();
                let (m_neg, m_raw) = case.get("m").unwrap();
                let (k_neg, k_raw) = case.get("k").unwrap();
                let (u_neg, u_raw) = case.get("u").unwrap();
                let (r_neg, r_raw) = case.get("r").unwrap();
                assert!(!b_neg && !e_neg && !m_neg &&
                        !k_neg && !u_neg && !r_neg);

                let mut base = $name::from_bytes(b_raw);
                let exponent = $name::from_bytes(e_raw);

                // `k` is decoded into a single digit and used as a count.
                let mut k_digits = [0; 1];
                raw_decoder(&k_raw, &mut k_digits);
                let mut barrett = $bname{ k: k_digits[0] as usize,
                                          m: [0; $size/32],
                                          mu: [0; $size/32] };
                raw_decoder(&m_raw, &mut barrett.m);
                raw_decoder(&u_raw, &mut barrett.mu);

                let expected = $name::from_bytes(r_raw);
                base.modexp(&exponent, &barrett);
                assert_eq!(base, expected);
            });
        }
    )*
}
}
}

View File

@@ -1,306 +0,0 @@
use std::cmp::Ordering;
/// Compares two equal-length digit slices as unsigned numbers whose most
/// significant digit is stored at the highest index.
///
/// Two empty slices compare as `Ordering::Equal`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[inline(always)]
pub fn generic_cmp(a: &[u64], b: &[u64]) -> Ordering {
    assert!(a.len() == b.len());
    // Walk from the top digit downwards; the first differing digit decides.
    a.iter().rev().cmp(b.iter().rev())
}
/// Returns `true` when `generic_cmp` does not rank `a` above `b` (a <= b).
fn le(a: &[u64], b: &[u64]) -> bool {
    match generic_cmp(a, b) {
        Ordering::Greater => false,
        Ordering::Less | Ordering::Equal => true,
    }
}
/// Returns `true` when `generic_cmp` does not rank `a` below `b` (a >= b).
fn ge(a: &[u64], b: &[u64]) -> bool {
    match generic_cmp(a, b) {
        Ordering::Less => false,
        Ordering::Greater | Ordering::Equal => true,
    }
}
/// ANDs every digit of `b` into the digit of `a` at the same index.
///
/// Panics if the two slices differ in length.
#[inline(always)]
pub fn generic_bitand(a: &mut [u64], b: &[u64]) {
    assert!(a.len() == b.len());
    for (dst, src) in a.iter_mut().zip(b.iter()) {
        *dst &= *src;
    }
}
/// ORs every digit of `b` into the digit of `a` at the same index.
///
/// Panics if the two slices differ in length.
#[inline(always)]
pub fn generic_bitor(a: &mut [u64], b: &[u64]) {
    assert!(a.len() == b.len());
    for (dst, src) in a.iter_mut().zip(b.iter()) {
        *dst |= *src;
    }
}
/// XORs every digit of `b` into the digit of `a` at the same index.
///
/// Panics if the two slices differ in length.
#[inline(always)]
pub fn generic_bitxor(a: &mut [u64], b: &[u64]) {
    assert!(a.len() == b.len());
    for (dst, src) in a.iter_mut().zip(b.iter()) {
        *dst ^= *src;
    }
}
/// Flips every bit of every digit of `a`, in place.
#[inline(always)]
pub fn generic_not(a: &mut [u64]) {
    a.iter_mut().for_each(|digit| *digit = !*digit);
}
/// Writes `orig` shifted left by `amount` bits into `a`.
///
/// Digits are ordered least significant first; bits pushed past the top
/// digit are discarded and vacated low positions are filled with zeros.
///
/// Panics if `a` and `orig` differ in length.
#[inline(always)]
pub fn generic_shl(a: &mut [u64], orig: &[u64], amount: usize) {
    assert!(a.len() == orig.len());
    let whole_digits = amount / 64;
    let bit_offset = amount % 64;
    for (idx, out) in a.iter_mut().enumerate() {
        *out = if idx < whole_digits {
            // Entirely below the shifted-in value.
            0
        } else {
            let src = idx - whole_digits;
            // Bits spilling over from the next-lower source digit, if any.
            let spill = if bit_offset == 0 || src == 0 {
                0
            } else {
                orig[src - 1] >> (64 - bit_offset)
            };
            (orig[src] << bit_offset) | spill
        };
    }
}
/// Writes `orig` shifted right by `amount` bits into `a`.
///
/// Digits are ordered least significant first; bits shifted below digit 0
/// are discarded and vacated high positions are filled with zeros.
///
/// Panics if `a` and `orig` differ in length.
#[inline(always)]
pub fn generic_shr(a: &mut [u64], orig: &[u64], amount: usize) {
    assert!(a.len() == orig.len());
    let whole_digits = amount / 64;
    let bit_offset = amount % 64;
    // Source digits past the end of `orig` read as zero.
    let digit_at = |idx: usize| orig.get(idx).cloned().unwrap_or(0);
    for (idx, out) in a.iter_mut().enumerate() {
        let low_part = digit_at(idx + whole_digits) >> bit_offset;
        let above = digit_at(idx + whole_digits + 1);
        // Bits dropping down from the next-higher source digit.
        let high_part = if bit_offset == 0 {
            0
        } else {
            above << (64 - bit_offset)
        };
        *out = low_part | high_part;
    }
}
/// Adds `b` into `a` digit by digit (least significant digit first),
/// rippling the carry upwards. A carry out of the top digit is dropped,
/// so the sum wraps.
///
/// Panics if the two slices differ in length.
#[inline(always)]
pub fn generic_add(a: &mut [u64], b: &[u64]) {
    assert!(a.len() == b.len());
    let mut carry = false;
    for (lhs, rhs) in a.iter_mut().zip(b.iter()) {
        let (partial, overflow_digits) = lhs.overflowing_add(*rhs);
        let (sum, overflow_carry) = partial.overflowing_add(carry as u64);
        *lhs = sum;
        // At most one of the two additions can overflow.
        carry = overflow_digits || overflow_carry;
    }
}
/// Subtracts `b` from `a` in place, digit by digit (least significant
/// digit first), rippling the borrow upwards. A borrow out of the top
/// digit is dropped, so the result wraps exactly as two's-complement
/// subtraction would.
///
/// Works directly on `a` without temporary buffers and accepts empty
/// slices.
///
/// # Panics
///
/// Panics if the two slices differ in length.
#[inline(always)]
pub fn generic_sub(a: &mut [u64], b: &[u64]) {
    assert!(a.len() == b.len());
    let mut borrow = false;
    for (lhs, rhs) in a.iter_mut().zip(b.iter()) {
        let (partial, underflow_digits) = lhs.overflowing_sub(*rhs);
        let (diff, underflow_borrow) = partial.overflowing_sub(borrow as u64);
        *lhs = diff;
        // At most one of the two subtractions can underflow.
        borrow = underflow_digits || underflow_borrow;
    }
}
/// Multiplies `orig` by `b` and stores the low `a.len()` digits of the
/// product in `a` (digits are least significant first; anything that would
/// land above the top digit is discarded).
///
/// This is grade-school multiplication. For two 4 digit numbers:
///
/// ```text
///                l0c3        l0c2        l0c1        l0c0    [orig]
///           x    l1c3        l1c2        l1c1        l1c0    [b]
///           ------------------------------------------------
///            (l0c3*l1c0) (l0c2*l1c0) (l0c1*l1c0) (l0c0*l1c0)
///            (l0c2*l1c1) (l0c1*l1c1) (l0c0*l1c1)
///            (l0c1*l1c2) (l0c0*l1c2)
///            (l0c0*l1c3)
/// ```
///
/// Each row is accumulated straight into `a` as it is produced, so no
/// intermediate table is allocated.
///
/// # Panics
///
/// Panics unless `a`, `orig` and `b` all have the same length and `a`
/// holds the same digits as `orig` on entry.
#[inline(always)]
pub fn generic_mul(a: &mut [u64], orig: &[u64], b: &[u64]) {
    assert!(a.len() == orig.len());
    assert!(a.len() == b.len());
    assert!(a == orig);
    let width = a.len();
    // `a` doubles as the accumulator; `orig` keeps the left operand.
    for digit in a.iter_mut() {
        *digit = 0;
    }
    for (line, &multiplier) in b.iter().enumerate() {
        let right = multiplier as u128;
        let mut carry: u128 = 0;
        // Row `line` starts at column `line`; columns past the end of `a`
        // are dropped along with the final carry.
        for col in 0..(width - line) {
            let left = orig[col] as u128;
            // Cannot overflow: (2^64-1)^2 + 2*(2^64-1) == 2^128 - 1.
            let total = (a[col + line] as u128) + left * right + carry;
            a[col + line] = total as u64;
            carry = total >> 64;
        }
    }
}
/// Divides `inx` by `iny`, writing the quotient to `outq` and the
/// remainder to `outr`. All four slices hold digits least significant
/// first and must have the same length.
///
/// # Panics
///
/// Panics if the slice lengths differ, or if `iny` is zero.
#[inline(always)]
pub fn generic_div(inx: &[u64], iny: &[u64],
                   outq: &mut [u64], outr: &mut [u64])
{
    assert!(inx.len() == iny.len());
    assert!(inx.len() == outq.len());
    assert!(inx.len() == outr.len());
    // This algorithm is from the Handbook of Applied Cryptography, Chapter 14,
    // algorithm 14.20. It has a couple assumptions about the inputs, namely that
    // n >= t >= 1 and y[t] != 0, where n and t refer to the number of digits in
    // the numbers. Which means that if we used the inputs unmodified, we can't
    // divide by single-digit numbers.
    //
    // To deal with this, we multiply inx and iny by 2^64, so that we push out
    // t by one.
    //
    // In addition, this algorithm starts to go badly when y[t] is very small
    // and x[n] is very large. Really, really badly. This can be fixed by
    // ensuring that the top bit is set in y[t], which we can achieve by
    // shifting everyone over a maximum of 63 bits.
    //
    // What this means is, just for safety, we add a 0 at the beginning and
    // end of each number.
    let mut y = iny.to_vec();
    let mut x = inx.to_vec();
    y.insert(0,0); y.push(0);
    x.insert(0,0); x.push(0);
    // 0. Compute 'n' and 't'
    let n = x.len() - 1;
    let mut t = y.len() - 1;
    while (t > 0) && (y[t] == 0) { t -= 1 }
    assert!(y[t] != 0); // this is where division by zero will fire
    // 0.5. Figure out a shift we can do such that the high bit of y[t] is
    //      set, and then shift x and y left by that much.
    let additional_shift: usize = y[t].leading_zeros() as usize;
    let origx = x.clone();
    let origy = y.clone();
    generic_shl(&mut x, &origx, additional_shift);
    generic_shl(&mut y, &origy, additional_shift);
    // 1. For j from 0 to (n - 1) do: q_j <- 0
    let mut q = vec![0u64; y.len()];
    // 2. While (x >= yb^(n-t)) do the following:
    //      q_(n-t) <- q_(n-t) + 1
    //      x <- x - yb^(n-t)
    let mut ybnt = y.clone();
    generic_shl(&mut ybnt, &y, 64 * (n - t));
    while ge(&x, &ybnt) {
        q[n-t] = q[n-t] + 1;
        generic_sub(&mut x, &ybnt);
    }
    // 3. For i from n down to (t + 1) do the following:
    let mut i = n;
    while i >= (t + 1) {
        // 3.1. if x_i = y_t, then set q_(i-t-1) <- b - 1; otherwise set
        //      q_(i-t-1) <- floor((x_i * b + x_(i-1)) / y_t).
        if x[i] == y[t] {
            q[i-t-1] = 0xFFFFFFFFFFFFFFFF;
        } else {
            let top = ((x[i] as u128) << 64) + (x[i-1] as u128);
            let bot = y[t] as u128;
            let solution = top / bot;
            q[i-t-1] = solution as u64;
        }
        // 3.2. While (q_(i-t-1) * (y_t * b + y_(t-1)) >
        //             x_i * b^2 + x_(i-1) * b + x_(i-2)) do:
        //        q_(i-t-1) <- q_(i-t-1) - 1.
        loop {
            let mut left = vec![0u64; x.len()];
            left[0] = q[i - t - 1];
            // The two leading digits of y: y_t * b + y_(t-1). Both digits
            // are needed; with y_(t-1) alone the test almost never fires
            // and an estimate that is two too large would survive, which
            // the single correction in step 3.4 cannot repair.
            let mut leftright = vec![0u64; x.len()];
            leftright[0] = y[t-1];
            leftright[1] = y[t];
            let copy = left.clone();
            generic_mul(&mut left, &copy, &leftright);
            let mut right = vec![0u64; x.len()];
            right[0] = x[i-2];
            right[1] = x[i-1];
            right[2] = x[i];
            if le(&left, &right) {
                break
            }
            q[i - t - 1] -= 1;
        }
        // 3.3. x <- x - q_(i - t - 1) * y * b^(i-t-1)
        let mut right = vec![0u64; y.len()];
        right[i - t - 1] = q[i - t - 1];
        let rightclone = right.clone();
        generic_mul(&mut right, &rightclone, &y);
        let wentnegative = generic_cmp(&x, &right) == Ordering::Less;
        generic_sub(&mut x, &right);
        // 3.4. if x < 0 then set x <- x + yb^(i-t-1) and
        //      q_(i-t-1) <- q_(i-t-1) - 1
        if wentnegative {
            let mut ybit1 = y.to_vec();
            generic_shl(&mut ybit1, &y, 64 * (i - t - 1));
            generic_add(&mut x, &ybit1);
            q[i - t - 1] -= 1;
        }
        i -= 1;
    }
    // 4. r <- x
    let finalx = x.clone();
    generic_shr(&mut x, &finalx, additional_shift);
    for i in 0..outr.len() {
        outr[i] = x[i + 1]; // note that for the remainder, we're dividing by
                            // our normalization value.
    }
    // 5. return (q,r)
    for i in 0..outq.len() {
        outq[i] = q[i];
    }
}

View File

@@ -38,10 +38,10 @@ macro_rules! generate_exponentiators
while mask != 0 {
if e.values[i] & mask != 0 {
self.modmul(&s, &m);
self.modmul(&s, m);
}
mask <<= 1;
s.modsq(&m);
s.modsq(m);
}
}
// Return A
@@ -94,7 +94,40 @@ macro_rules! generate_tests {
}
)*
}
#[cfg(test)]
mod varrett_modular {
    use cryptonum::encoding::Decoder;
    use super::*;
    use testing::run_test;

    $(
        // Feeds every case of the per-type `modexp` vector file through
        // `modexp` with the modulus passed as a plain number. Marked
        // `#[ignore]`, so it only runs when ignored tests are requested.
        #[test]
        #[allow(non_snake_case)]
        #[ignore]
        fn $name() {
            let vector_path = format!("tests/math/modexp{}.test",
                                      stringify!($name));
            run_test(vector_path.to_string(), 4, |case| {
                // Each entry is a (flag, bytes) pair; the flag is
                // presumably a sign marker and must be clear everywhere.
                let (b_neg, b_raw) = case.get("b").unwrap();
                let (e_neg, e_raw) = case.get("e").unwrap();
                let (m_neg, m_raw) = case.get("m").unwrap();
                let (r_neg, r_raw) = case.get("r").unwrap();
                assert!(!b_neg && !e_neg && !m_neg && !r_neg);

                let mut base = $name::from_bytes(b_raw);
                let exponent = $name::from_bytes(e_raw);
                let modulus = $name::from_bytes(m_raw);
                let expected = $name::from_bytes(r_raw);

                base.modexp(&exponent, &modulus);
                assert_eq!(base, expected);
            });
        }
    )*
}
}
}
generate_tests!(U192, U256, U384, U512, U576, U1024, U2048, U3072, U4096, U8192, U15360);
generate_tests!(U192, U256, U384, U512, U576, U1024, U2048,
U3072, U4096, U8192, U15360);

View File

@@ -9,7 +9,6 @@ pub trait ModMul<T=Self> {
}
// This is algorithm 14.12 from "Handbook of Applied Cryptography"
#[inline(always)]
pub fn raw_multiplication(x: &[u64], y: &[u64], w: &mut [u64])
{
assert_eq!(x.len(), y.len());

View File

@@ -7,7 +7,6 @@ pub trait ModSquare<T=Self>
}
// This is algorithm 14.16 from "Handbook of Applied Cryptography".
#[inline(always)]
pub fn raw_square(x: &[u64], result: &mut [u64])
{
assert_eq!(x.len() * 2, result.len());

View File

@@ -4,7 +4,6 @@ use cryptonum::{U192, U256, U384, U512, U576,
use cryptonum::addition::raw_addition;
use std::ops::{Sub,SubAssign};
#[inline(always)]
pub fn raw_subtraction(x: &mut [u64], y: &[u64])
{
assert_eq!(x.len(), y.len());

View File

@@ -5,6 +5,7 @@ import qualified Data.Map.Strict as Map
import GHC.Integer.GMP.Internals(powModInteger)
import Numeric(showHex)
import Prelude hiding (log)
import System.Environment(getArgs)
import System.IO(hFlush,stdout,IOMode(..),withFile,Handle,hClose,hPutStrLn)
import System.Random(StdGen,newStdGen,random)
@@ -18,6 +19,7 @@ testTypes = [("addition", addTest),
("squaring", squareTest),
("modsq", modsqTest),
("modexp", modexpTest),
("bmodexp", bmodexpTest),
("division", divTest),
("barrett_gen", barrettGenTest),
("barrett_reduce", barrettReduceTest)
@@ -189,6 +191,24 @@ barrettReduceTest bitsize gen0 = (res, gen2)
("u", showHex u ""),
("r", showHex r "")]
-- | Generate one Barrett modular-exponentiation test case: random base
-- @b@, exponent @e@ and modulus @m@ (passed through 'splitMod' for the
-- given bit size), the Barrett parameters @k@ and @u@ computed from the
-- modulus, and the expected result @r = b^e mod m@.
--
-- The returned generator is the one left over after all three random
-- draws, so consecutive test cases do not reuse the same random value.
bmodexpTest :: Int -> StdGen -> (Map String String, StdGen)
bmodexpTest bitsize gen0 = (res, gen3)
 where
  (b, gen1)  = random gen0
  (e, gen2)  = random gen1
  (m, gen3)  = random gen2
  [b',e',m'] = splitMod bitsize [b,e,m]
  k          = computeK m'
  u          = barrett bitsize m'
  r          = powModInteger b' e' m'
  res        = Map.fromList [("b", showHex b' ""),
                             ("e", showHex e' ""),
                             ("m", showHex m' ""),
                             ("k", showHex k  ""),
                             ("u", showHex u  ""),
                             ("r", showHex r  "")]
barrett :: Int -> Integer -> Integer
barrett bitsize m = (b ^ (2 * k)) `div` m
where
@@ -218,10 +238,15 @@ generateData hndl generator gen () =
main :: IO ()
main =
forM_ testTypes $ \ (testName, testFun) ->
do args <- getArgs
let tests = if null args
then testTypes
else filter (\ (name,_) -> name `elem` args) testTypes
forM_ tests $ \ (testName, testFun) ->
forM_ bitSizes $ \ bitsize ->
do log ("Generating " ++ show bitsize ++ "-bit " ++ testName ++ " tests")
withFile (testName ++ "U" ++ show bitsize ++ ".test") WriteMode $ \ hndl ->
do log ("Generating "++show bitsize++"-bit "++testName++" tests")
withFile (testName ++ "U" ++ show bitsize ++ ".test") WriteMode $
\ hndl ->
do gen <- newStdGen
foldM_ (generateData hndl (testFun bitsize))
gen

6000
tests/math/bmodexpU1024.test Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

6000
tests/math/bmodexpU192.test Normal file

File diff suppressed because it is too large Load Diff

6000
tests/math/bmodexpU2048.test Normal file

File diff suppressed because it is too large Load Diff

6000
tests/math/bmodexpU256.test Normal file

File diff suppressed because it is too large Load Diff

6000
tests/math/bmodexpU3072.test Normal file

File diff suppressed because it is too large Load Diff

6000
tests/math/bmodexpU384.test Normal file

File diff suppressed because it is too large Load Diff

6000
tests/math/bmodexpU4096.test Normal file

File diff suppressed because it is too large Load Diff

6000
tests/math/bmodexpU512.test Normal file

File diff suppressed because it is too large Load Diff

6000
tests/math/bmodexpU576.test Normal file

File diff suppressed because it is too large Load Diff

6000
tests/math/bmodexpU8192.test Normal file

File diff suppressed because it is too large Load Diff