From a6a82773d315b4decd6d55383b11f0a7e6c53471 Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Sun, 30 Dec 2018 17:13:01 -0800 Subject: [PATCH] Add additional support for GCD on signed numbers. --- src/signed/egcd.rs | 72 ++++++++++++++++++++++++++++++++++------- test-generator/Math.hs | 17 +++++++++- test-generator/Tests.hs | 12 ++++--- 3 files changed, 85 insertions(+), 16 deletions(-) diff --git a/src/signed/egcd.rs b/src/signed/egcd.rs index 5222330..80dfdba 100644 --- a/src/signed/egcd.rs +++ b/src/signed/egcd.rs @@ -1,5 +1,5 @@ /// GCD computations, with extended information -pub trait EGCD { +pub trait EGCD + Sized>: Sized { /// Compute the extended GCD for this value and the given value. /// If the inputs to this function are x (self) and y (the argument), /// and the results are (a, b, g), then (a * x) + (b * y) = g. @@ -8,11 +8,51 @@ pub trait EGCD { /// have a GCD of 1. This is a slightly faster version of calling /// `egcd` and testing the result, because it can ignore some /// intermediate values. - fn gcd_is_one(&self, &Self) -> bool; + fn gcd_is_one(&self, rhs: &Self) -> bool { + let (_, _, g) = self.egcd(rhs); + g == T::from(1u64) + } } macro_rules! egcd_impls { - ($sname: ident, $name: ident, $ssmall: ident) => { + ($sname: ident, $name: ident, $ssmall: ident, $bigger: ident) => { + impl EGCD<$sname> for $ssmall { + fn egcd(&self, rhs: &$ssmall) -> ($sname, $sname, $sname) { + // This slower version works when we can't guarantee that the + // inputs are positive. + let mut s = $sname::from(0i64); + let mut old_s = $sname::from(1i64); + let mut t = $sname::from(1i64); + let mut old_t = $sname::from(0i64); + let mut r = $sname::from(rhs); + let mut old_r = $sname::from(self); + + while !r.is_zero() { + let quotient = &old_r / &r; + + let prov_r = r.clone(); + let prov_s = s.clone(); + let prov_t = t.clone(); + + // FIXME: Make this less copying, although I suspect the + // division above is the more major problem. + r = $sname::from($bigger::from(old_r) - (r * "ient)); + s = $sname::from($bigger::from(old_s) - (s * "ient)); + t = $sname::from($bigger::from(old_t) - (t * "ient)); + + old_r = prov_r; + old_s = prov_s; + old_t = prov_t; + } + + if old_r.is_negative() { + (old_s.negate(), old_t.negate(), old_r.negate()) + } else { + (old_s, old_t, old_r) + } + } + } + impl EGCD<$sname> for $name { fn egcd(&self, rhs: &$name) -> ($sname, $sname, $sname) { // INPUT: two positive integers x and y. @@ -160,17 +200,27 @@ macro_rules! generate_egcd_tests { let (nega, abytes) = case.get("a").unwrap(); let (negb, bbytes) = case.get("b").unwrap(); - assert!(!negx && !negy); - let x = $uname::from_bytes(xbytes); - let y = $uname::from_bytes(ybytes); let v = $sname64::new(*negv, $uname64::from_bytes(vbytes)); let a = $sname64::new(*nega, $uname64::from_bytes(abytes)); let b = $sname64::new(*negb, $uname64::from_bytes(bbytes)); - let (mya, myb, myv) = x.egcd(&y); - assert_eq!(v, myv, "GCD test"); - assert_eq!(a, mya, "X factor test"); - assert_eq!(b, myb, "Y factor tst"); - assert_eq!(x.gcd_is_one(&y), (myv == $sname64::from(1i64))); + + if *negx || *negy { + let x = $sname::new(*negx, $uname::from_bytes(xbytes)); + let y = $sname::new(*negy, $uname::from_bytes(ybytes)); + let (mya, myb, myv) = x.egcd(&y); + assert_eq!(v, myv, "Signed GCD test"); + assert_eq!(a, mya, "Signed X factor test"); + assert_eq!(b, myb, "Signed Y factor tst"); + assert_eq!(x.gcd_is_one(&y), (myv == $sname64::from(1i64))); + } else { + let x = $uname::from_bytes(xbytes); + let y = $uname::from_bytes(ybytes); + let (mya, myb, myv) = x.egcd(&y); + assert_eq!(v, myv, "GCD test"); + assert_eq!(a, mya, "X factor test"); + assert_eq!(b, myb, "Y factor tst"); + assert_eq!(x.gcd_is_one(&y), (myv == $sname64::from(1i64))); + } }); }; } diff --git a/test-generator/Math.hs b/test-generator/Math.hs index ba33bc4..b3fd2e2 100644 --- a/test-generator/Math.hs +++ b/test-generator/Math.hs @@ -1,6 +1,6 @@ {-# LANGUAGE RecordWildCards #-} module Math( - extendedGCD + extendedGCD, safeGCD , barrett, computeK, base , modulate, modulate' , isqrt @@ -31,6 +31,21 @@ printState a = putStrLn ("C: " ++ showX (bigC a)) putStrLn ("D: " ++ showX (bigD a)) +safeGCD :: Integer -> Integer -> (Integer, Integer, Integer) +safeGCD self rhs = go (self, 1, 0) (rhs, 0, 1) + where + go (v, a, b) (0, _, _) = if v < 0 + then (-a, -b, -v) + else (a, b, v) + go old new = go new (step old new) + -- + step (old_r, old_s, old_t) (r, s, t) = + let quotient = old_r `div` r + r' = old_r - (r * quotient) + s' = old_s - (s * quotient) + t' = old_t - (t * quotient) + in (r', s', t') + extendedGCD :: Integer -> Integer -> (Integer, Integer, Integer) extendedGCD x y = (a, b, g * (v finalState)) where diff --git a/test-generator/Tests.hs b/test-generator/Tests.hs index 4fbb2da..f83b0cf 100644 --- a/test-generator/Tests.hs +++ b/test-generator/Tests.hs @@ -271,13 +271,17 @@ sigmulTest size memory0 = egcdTest :: Test egcdTest size memory0 = - let (x, memory1) = generateNum memory0 "x" size - (y, memory2) = generateNum memory1 "y" size - (a, b, v) = extendedGCD x y + let (x, memory1) = genSign (generateNum memory0 "x" size) + (y, memory2) = genSign (generateNum memory1 "y" size) + (a, b, v) = if (x >= 0) && (y >= 0) + then extendedGCD x y + else safeGCD x y res = Map.fromList [("x", showX x), ("y", showX y), ("a", showX a), ("b", showX b), ("v", showX v)] - in assert (v == gcd x y) (res, v, memory2) + in assert (((a * x) + (b * y)) == v) $ + assert (v == gcd x y) $ + (res, v, memory2) moddivTest :: Test moddivTest size memoryIn =