Fix an issue in Barrett reduction.

This commit is contained in:
2018-10-04 20:00:46 -07:00
parent 78750598a5
commit fe43949684
2 changed files with 46 additions and 43 deletions

View File

@@ -50,8 +50,10 @@ needs = [ Need ModExp (\ size -> [Req size ModMul
,Req (size + 64) Mul ,Req (size + 64) Mul
,Req (size * 2) (Convert ((size * 2) + 64)) ,Req (size * 2) (Convert ((size * 2) + 64))
,Req ((size * 2) + 64) Shifts ,Req ((size * 2) + 64) Shifts
,Req ((size * 2) + 128) Shifts
,Req ((size * 2) + 64) Div ,Req ((size * 2) + 64) Div
,Req (size + 64) (Convert (size * 2)) ,Req (size + 64) (Convert (size * 2))
,Req (size + 64) (Convert ((size * 2) + 128))
,Req ((size * 2) + 64) ,Req ((size * 2) + 64)
(Convert ((size * 2) + 128)) (Convert ((size * 2) + 128))
]) ])
@@ -254,7 +256,7 @@ generateAllTheTests =
let (db3, gen3) = emptyDatabase gen2 let (db3, gen3) = emptyDatabase gen2
generateTests Barretts "barrett_reduce" db3 $ \ size memory0 -> generateTests Barretts "barrett_reduce" db3 $ \ size memory0 ->
let (m, memory1) = generateNum memory0 "m" size let (m, memory1) = generateNum memory0 "m" size
(x, memory2) = generateNum memory1 "x" (min size (2 * k * 64)) (x, memory2) = generateNum memory1 "x" (min (2 * size) (2 * k * 64))
k = computeK m k = computeK m
u = barrett m u = barrett m
r = x `mod` m r = x `mod` m

View File

@@ -8,51 +8,52 @@ macro_rules! barrett_impl {
impl $bar { impl $bar {
pub fn new(m: $name) -> $bar { pub fn new(m: $name) -> $bar {
// Step #1: Figure out k // Step #1: Figure out k
let mut k = 0; let mut k = 0;
for i in 0..m.value.len() { for i in 0..m.value.len() {
if m.value[i] != 0 { if m.value[i] != 0 {
k = i; k = i;
} }
} }
k += 1; k += 1;
// Step #2: Compute b // Step #2: Compute b
let mut b = $dbl64::zero(); let mut b = $dbl64::zero();
b.value[2*k] = 1; b.value[2*k] = 1;
// Step #3: Divide b by m. // Step #3: Divide b by m.
let bigm = $dbl64::from(&m); let bigm = $dbl64::from(&m);
let quot = b / &bigm; let quot = b / &bigm;
let resm = $name64::from(&m); let resm = $name64::from(&m);
let mu = $name64::from(&quot); let mu = $name64::from(&quot);
// Done! // Done!
$bar { k: k, m: resm, mu: mu } $bar { k: k, m: resm, mu: mu }
} }
pub fn reduce(&self, x: &$dbl) -> $name { pub fn reduce(&self, x: &$dbl) -> $name {
let m2: $dbl64 = $dbl64::from(&self.m); // 1. q1←⌊x/bk1⌋, q2←q1 · μ, q3←⌊q2/bk+1⌋.
// 1. q1←⌊x/bk1⌋, q2←q1 · μ, q3←⌊q2/bk+1⌋. let q1: $name64 = $name64::from(x >> ((self.k - 1) * 64));
let q1: $name64 = $name64::from(x >> ((self.k - 1) * 64)); let q2 = q1 * &self.mu;
let q2: $dbl64 = $dbl64::from(q1 * &self.mu); let q3: $name64 = $name64::from(q2 >> ((self.k + 1) * 64));
let q3: $name64 = $name64::from(q2 >> ((self.k + 1) * 64)); // 2. r1←x mod bk+1, r2←q3 · m mod bk+1, r←r1 r2.
// 2. r1←x mod bk+1, r2←q3 · m mod bk+1, r←r1 r2. let mut r: $dbl64 = $dbl64::from(x);
let mut r: $dbl64 = $dbl64::from(x); r.mask(self.k + 1);
r.mask(self.k + 1); let mut r2: $dbl64 = $dbl64::from(q3 * &self.m);
let mut r2: $dbl64 = $dbl64::from(q3 * &self.m); r2.mask(self.k + 1);
r2.mask(self.k + 1); let went_negative = &r < &r2;
let went_negative = &r < &r2; r -= &r2;
r -= &r2; // 3. If r<0 then r←r+bk+1.
// 3. If r<0 then r←r+bk+1. if went_negative {
if went_negative { let mut bk1 = $dbl64::zero();
let mut bk1 = $dbl64::zero(); bk1.value[self.k+1] = 1;
bk1.value[self.k+1] = 1; // this may overflow, and we should probably be OK with it.
r += &bk1; r += &bk1;
} }
// 4. While r≥m do: r←rm. // 4. While r≥m do: r←rm.
while &r > &m2 { let m2 = $dbl64::from(&self.m);
r -= &m2; while &r > &m2 {
} r -= &m2;
// Done! }
$name::from(&r) // Done!
$name::from(&r)
} }
} }