{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
-- | Generator for the Rust "modinv" module.
--
-- This module does two things:
--
--   * 'declareModInv' emits, through the @sourceFile@ quasi-quoter, a Rust
--     source file implementing @ModularInversion@ for one unsigned type,
--     together with a known-answer test (@KATs@).
--
--   * 'generateModInvTest' produces the test vectors that the emitted @KATs@
--     test reads, using the pure Haskell reference implementation
--     ('extendedGCD' and its helpers) defined below.
--
-- The Haskell reference ('initialState', 'step4', 'step5', 'step6') performs
-- the same numbered steps as the @egcd@ method in the Rust template, so the
-- @a@, @b@ and @v@ values written to the test file are the ones @egcd@ is
-- expected to return.
module ModInv(
    generateModInvOps
  )
 where

import Control.Exception(assert)
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Generators
import GHC.Integer.GMP.Internals(recipModInteger)
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
import System.Random(RandomGen)

-- | Descriptor for the generated "modinv" module.
--
-- The predicate accepts a size @me@ only when @me + 64@ is also present in
-- @others@. 'declareModInv' names a signed type 64 bits wider than the
-- unsigned type it implements the trait for, which is presumably the reason
-- for this requirement (NOTE(review): confirm the meaning of @others@ against
-- the 'RustModule' definition).
generateModInvOps :: RustModule
generateModInvOps = RustModule {
    predicate = \ me others -> (me + 64) `elem` others,
    outputName = "modinv",
    isUnsigned = True,
    generator = declareModInv,
    testCase = Just generateModInvTest
}

-- | Build the Rust source file implementing @ModularInversion@ for the
-- unsigned type @U\<bitsize\>@, using @I\<bitsize + 64\>@ as the signed working
-- type. The second argument is ignored.
--
-- The emitted @impl@ contains three methods:
--
--   * @modinv@ calls @phi.egcd(&self)@ and takes the second coefficient, i.e.
--     the one that multiplies @self@. It returns @None@ when the gcd is not 1;
--     otherwise it adds @phi@ while the coefficient is negative, subtracts
--     @phi@ once if the coefficient is greater than @phi@, and converts the
--     result back to the unsigned type.
--
--   * @egcd@ is a binary extended gcd; its in-template comments number the
--     steps (these appear to follow Algorithm 14.61 of the Handbook of
--     Applied Cryptography -- NOTE(review): confirm). Instead of tracking the
--     factor @g@ directly it counts the shared factors of two in @gshift@ and
--     shifts @v@ left by that amount on return.
--
--   * @gcd_is_one@ is a subtraction-and-shift gcd that only reports whether
--     the gcd equals one.
--
-- The emitted @KATs@ test reads the keys @x@, @y@, @z@, @a@, @b@ and @v@ from
-- each test case; these are the keys written by 'generateModInvTest'.
declareModInv :: Word -> [Word] -> SourceFile Span
declareModInv bitsize _ =
  let -- Signed type name: 64 bits wider than the unsigned type.
      sname = mkIdent ("I" ++ show (bitsize + 64))
      -- Unsigned type name the trait is implemented for.
      uname = mkIdent ("U" ++ show bitsize)
      -- String literal spliced into the test's build_test_path call
      -- (presumably the test vector file name; see testFile).
      testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
  in [sourceFile|
        use core::convert::TryFrom;
        use crate::{CryptoNum,ModularInversion};
        use crate::signed::$$sname;
        #[cfg(test)]
        use crate::testing::{build_test_path,run_test};
        use crate::unsigned::$$uname;

        impl ModularInversion for $$uname {
            type Signed = $$sname;

            fn modinv(&self, phi: &$$uname) -> Option<$$uname>
            {
                let (_, mut b, g) = phi.egcd(&self);

                if g != $$sname::from(1i64) {
                    return None;
                }

                let sphi = $$sname::from(phi);

                while b.is_negative() {
                    b += &sphi;
                }

                if b > sphi {
                    b -= &sphi;
                }

                Some($$uname::try_from(b).expect("overflow/underflow in modinv result"))
            }

            fn egcd(&self, rhs: &$$uname) -> ($$sname, $$sname, $$sname)
            {
                // INPUT: two positive integers x and y.
                let mut x = $$sname::from(self);
                let mut y = $$sname::from(rhs);
                // OUTPUT: integers a, b, and v such that ax + by = v,
                // where v = gcd(x, y).
                // 1. g←1.
                let mut gshift: usize = 0;
                // 2. While x and y are both even, do the following: x←x/2,
                // y←y/2, g←2g.
                while x.is_even() && y.is_even() {
                    x >>= 1u64;
                    y >>= 1u64;
                    gshift += 1;
                }
                // 3. u←x, v←y, A←1, B←0, C←0, D←1.
                let mut u = x.clone();
                let mut v = y.clone();
                #[allow(non_snake_case)]
                let mut A = $$sname::from(1i64);
                #[allow(non_snake_case)]
                let mut B = $$sname::zero();
                #[allow(non_snake_case)]
                let mut C = $$sname::zero();
                #[allow(non_snake_case)]
                let mut D = $$sname::from(1i64);
                loop {
                    // 4. While u is even do the following:
                    while u.is_even() {
                        // 4.1 u←u/2.
                        u >>= 1u64;
                        // 4.2 If A≡B≡0 (mod 2) then A←A/2, B←B/2; otherwise,
                        // A←(A + y)/2, B←(B − x)/2.
                        if A.is_even() && B.is_even() {
                            A >>= 1u64;
                            B >>= 1u64;
                        } else {
                            A += &y;
                            A >>= 1u64;
                            B -= &x;
                            B >>= 1u64;
                        }
                    }
                    // 5. While v is even do the following:
                    while v.is_even() {
                        // 5.1 v←v/2.
                        v >>= 1u64;
                        // 5.2 If C ≡ D ≡ 0 (mod 2) then C←C/2, D←D/2; otherwise,
                        // C←(C + y)/2, D←(D − x)/2.
                        if C.is_even() && D.is_even() {
                            C >>= 1u64;
                            D >>= 1u64;
                        } else {
                            C += &y;
                            C >>= 1u64;
                            D -= &x;
                            D >>= 1u64;
                        }
                    }
                    // 6. If u≥v then u←u−v, A←A−C,B←B−D;
                    // otherwise,v←v−u, C←C−A, D←D−B.
                    if u >= v {
                        u -= &v;
                        A -= &C;
                        B -= &D;
                    } else {
                        v -= &u;
                        C -= &A;
                        D -= &B;
                    }
                    // 7. If u = 0, then a←C, b←D, and return(a, b, g · v);
                    // otherwise, go to step 4.
                    if u.is_zero() {
                        return (C, D, v << gshift);
                    }
                }
            }

            fn gcd_is_one(&self, b: &$$uname) -> bool {
                let mut u = self.clone();
                let mut v = b.clone();
                let one = $$uname::from(1u64);

                if u.is_zero() {
                    return v == one;
                }

                if v.is_zero() {
                    return u == one;
                }

                if u.is_even() && v.is_even() {
                    return false;
                }

                while u.is_even() {
                    u >>= 1u64;
                }

                loop {
                    while v.is_even() {
                        v >>= 1u64;
                    }
                    // u and v guaranteed to be odd right now.
                    if u > v {
                        // make sure that v > u, so that our subtraction works
                        // out.
                        let t = u;
                        u = v;
                        v = t;
                    }
                    v -= &u;

                    if v.is_zero() {
                        return u == one;
                    }
                }
            }
        }

        #[cfg(test)]
        #[allow(non_snake_case)]
        #[test]
        fn KATs() {
            run_test(build_test_path("modinv", $$(testFileLit)), 6, |case| {
                let (neg0, xbytes) = case.get("x").unwrap();
                let (neg1, ybytes) = case.get("y").unwrap();
                let (neg2, zbytes) = case.get("z").unwrap();
                let (neg3, abytes) = case.get("a").unwrap();
                let (neg4, bbytes) = case.get("b").unwrap();
                let (neg5, vbytes) = case.get("v").unwrap();

                assert!(!neg0 && !neg1 && !neg2);
                let x = $$uname::from_bytes(xbytes);
                let y = $$uname::from_bytes(ybytes);
                let z = $$uname::from_bytes(zbytes);
                let mut a = $$sname::from_bytes(abytes);
                let mut b = $$sname::from_bytes(bbytes);
                let mut v = $$sname::from_bytes(vbytes);

                if *neg3 { a = -a }
                if *neg4 { b = -b }
                if *neg5 { v = -v }

                let (mya, myb, myv) = x.egcd(&y);
                assert_eq!(a, mya);
                assert_eq!(b, myb);
                assert_eq!(v, myv);

                assert_eq!(z, x.modinv(&y).expect("Didn't find a modinv?"));
                assert_eq!(v == $$sname::from(1u64), x.gcd_is_one(&y));
            });
        }
     |]

-- | Produce one test case for the emitted @KATs@ test, together with the
-- advanced random generator.
--
-- The resulting map has the keys @x@, @y@, @z@, @a@, @b@ and @v@:
--
--   * @x@ and @y@ are drawn with 'generateNum' for the given size;
--   * @z@ is @recipModInteger x y@;
--   * @(a, b, v)@ is @extendedGCD x y@.
--
-- A candidate pair with @z == 0@ is discarded and a new pair is drawn from the
-- generator state that followed it. (NOTE(review): integer-gmp documents
-- 'recipModInteger' as returning 0 when no inverse exists -- confirm for the
-- GHC version in use.)
--
-- Before a case is returned, the 'assert' calls check that @z < y@, that
-- @x * z@ is congruent to 1 modulo @y@, that @a * x + b * y == v@, and that
-- @v@ equals @gcd x y@.
generateModInvTest :: RandomGen g => Word -> g -> (Map String String, g)
generateModInvTest size g = go g
 where
  go g0 =
    let (x, g1) = generateNum g0 size
        (y, g2) = generateNum g1 size
        z = recipModInteger x y
        (a, b, v) = extendedGCD x y
        tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX z),
                              ("a", showX a), ("b", showX b), ("v", showX v)]
    in if z == 0
         then go g2
         else assert (z < y) $
              assert ((x * z) `mod` y == 1) $
              assert (((a * x) + (b * y)) == v) $
              assert (v == gcd x y) $
              (tcase, g2)

-- | Reference extended gcd. Returns @(a, b, d)@ where @a@ and @b@ are the
-- final @bigC@ and @bigD@ of the algorithm state, and @d@ is the final @v@
-- multiplied by the power of two that 'initialState' removed from both
-- inputs. 'generateModInvTest' asserts that @a * x + b * y == d@ and that
-- @d == gcd x y@ for the values it generates.
extendedGCD :: Integer -> Integer -> (Integer, Integer, Integer)
extendedGCD x y = (a, b, g * (v finalState))
 where
  (x', y', g, initState) = initialState x y 1
  finalState             = runAlgorithm x' y' initState
  a                      = bigC finalState
  b                      = bigD finalState

-- | Working state of the reference algorithm. The fields correspond to the
-- variables @u@, @v@, @A@, @B@, @C@ and @D@ of the @egcd@ method in the Rust
-- template.
data AlgState = AlgState {
    u    :: Integer,
    v    :: Integer,
    bigA :: Integer,
    bigB :: Integer,
    bigC :: Integer,
    bigD :: Integer
}

-- | Halve @x@ and @y@ (doubling @g@) for as long as both are even, then return
-- the reduced @x@ and @y@, the accumulated @g@, and the starting state
-- @u = x, v = y, A = 1, B = 0, C = 0, D = 1@.
--
-- NOTE(review): with @x == 0@ and @y == 0@ the first guard is always true and
-- this never terminates.
initialState :: Integer -> Integer -> Integer -> (Integer, Integer, Integer, AlgState)
initialState x y g
  | even x && even y = initialState (x `div` 2) (y `div` 2) (g * 2)
  | otherwise        = (x, y, g, AlgState x y 1 0 0 1)

-- | Debugging aid: print every field of a state, one per line, rendered with
-- 'showX'.
printState :: AlgState -> IO ()
printState a =
  do putStrLn ("u: " ++ showX (u a))
     putStrLn ("v: " ++ showX (v a))
     putStrLn ("A: " ++ showX (bigA a))
     putStrLn ("B: " ++ showX (bigB a))
     putStrLn ("C: " ++ showX (bigC a))
     putStrLn ("D: " ++ showX (bigD a))

-- | Apply 'step4', 'step5' and 'step6' in order, repeatedly, until @u@ is
-- zero, and return the state at that point. Unlike the loop in the Rust
-- template, the @u == 0@ check happens before each round rather than after
-- it, so a state that already has @u == 0@ is returned unchanged.
runAlgorithm :: Integer -> Integer -> AlgState -> AlgState
runAlgorithm x y state
  | u state == 0 = state
  | otherwise    = runAlgorithm x y state6
 where
  state4 = step4 x y state
  state5 = step5 x y state4
  state6 = step6 state5

-- | While @u@ is even: halve @u@, and halve @A@ and @B@ if both are even,
-- otherwise replace them with @(A + y) / 2@ and @(B - x) / 2@. @v@, @C@ and
-- @D@ are left untouched.
step4 :: Integer -> Integer -> AlgState -> AlgState
step4 x y input@AlgState{..}
  | even u    = step4 x y input'
  | otherwise = input
 where
  input' = AlgState u' v bigA' bigB' bigC bigD
  u'     = u `div` 2
  bigA' | even bigA && even bigB = bigA `div` 2
        | otherwise              = (bigA + y) `div` 2
  bigB' | even bigA && even bigB = bigB `div` 2
        | otherwise              = (bigB - x) `div` 2

-- | While @v@ is even: halve @v@, and halve @C@ and @D@ if both are even,
-- otherwise replace them with @(C + y) / 2@ and @(D - x) / 2@. @u@, @A@ and
-- @B@ are left untouched.
--
-- NOTE(review): if @v@ is 0 on entry it stays 0 (and therefore even) after
-- halving, so this does not terminate; that happens when 'extendedGCD' is
-- given a non-zero @x@ and @y == 0@.
step5 :: Integer -> Integer -> AlgState -> AlgState
step5 x y input@AlgState{..}
  | even v    = step5 x y input'
  | otherwise = input
 where
  input' = AlgState u v' bigA bigB bigC' bigD'
  v'     = v `div` 2
  bigC' | even bigC && even bigD = bigC `div` 2
        | otherwise              = (bigC + y) `div` 2
  bigD' | even bigC && even bigD = bigD `div` 2
        | otherwise              = (bigD - x) `div` 2

-- | If @u >= v@, subtract @v@, @C@ and @D@ from @u@, @A@ and @B@; otherwise
-- subtract @u@, @A@ and @B@ from @v@, @C@ and @D@.
step6 :: AlgState -> AlgState
step6 AlgState{..}
  | u >= v    = AlgState (u - v) v (bigA - bigC) (bigB - bigD) bigC bigD
  | otherwise = AlgState u (v - u) bigA bigB (bigC - bigA) (bigD - bigB)

-- | Debugging aid, not exported: trace the algorithm on the given inputs.
-- It prints every intermediate state while stepping manually, then the final
-- state and the resulting gcd value, then the final state computed by
-- 'runAlgorithm', and finally the triple returned by 'extendedGCD'.
_run :: Integer -> Integer -> IO ()
_run inputx inputy =
  do let (x, y, g, initState) = initialState inputx inputy 1
     finalState <- go x y initState
     putStrLn ("-- FINAL STATE -----------------------")
     printState finalState
     putStrLn ("Final value: " ++ showX (g * v finalState))
     putStrLn ("-- RUN ------")
     printState (runAlgorithm x y initState)
     putStrLn ("-- NORMAL ------")
     -- From here on, 'v' is the local gcd value and shadows the record field.
     let (a, b, v) = extendedGCD inputx inputy
     putStrLn ("a: " ++ showX a)
     putStrLn ("b: " ++ showX b)
     putStrLn ("v: " ++ showX v)
 where
  -- Same loop as 'runAlgorithm', but printing the state before every round.
  go x y state =
    do putStrLn "-- STATE -----------------------------"
       printState state
       if u state == 0
         then return state
         else do let state'   = step4 x y state
                     state''  = step5 x y state'
                     state''' = step6 state''
                 go x y state'''