Files
cryptonum/generation/src/ModInv.hs

332 lines
11 KiB
Haskell
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
module ModInv(
generateModInvOps
)
where
import Control.Exception(assert)
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Generators
import GHC.Integer.GMP.Internals(powModInteger, recipModInteger)
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
import System.Random(RandomGen)
-- | Module description for the generated @modinv@ Rust file.
--
-- The generated code for an unsigned type of @n@ bits refers to a signed type
-- of @n + 64@ bits, so the module only applies when that larger size is also
-- being generated, and it suggests that size when asked.
generateModInvOps :: RustModule
generateModInvOps =
  RustModule
    { predicate   = \me others -> signedSize me `elem` others
    , suggested   = \me -> [signedSize me]
    , outputName  = "modinv"
    , isUnsigned  = True
    , generator   = declareModInv
    , testCase    = Just generateModInvTest
    }
  where
    -- Bit size of the signed companion type used by the generated code.
    signedSize me = me + 64
-- | Build the Rust source implementing @ModularInversion@ for the unsigned
-- type @U\<bitsize\>@, using the signed type @I\<bitsize + 64\>@ for the
-- intermediate Bezout coefficients.
--
-- The generated file contains:
--
--   * @modinv@: modular inverse via the extended GCD, returning @None@ when
--     the modulus is zero or the operands are not coprime;
--   * @egcd@: the binary extended GCD (steps numbered as in the comments,
--     following the classic binary extended gcd formulation), with explicit
--     handling of zero operands so that it always terminates;
--   * @gcd_is_one@: a coprimality check that avoids signed arithmetic;
--   * a known-answer test driven by the vectors from 'generateModInvTest'.
--
-- The second argument (the list of all generated sizes) is unused.
declareModInv :: Word -> [Word] -> SourceFile Span
declareModInv bitsize _ =
  let sname = mkIdent ("I" ++ show (bitsize + 64))
      uname = mkIdent ("U" ++ show bitsize)
      testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
  in [sourceFile|
        use core::convert::TryFrom;
        use crate::{CryptoNum,ModularInversion};
        use crate::signed::$$sname;
        #[cfg(test)]
        use crate::testing::{build_test_path,run_test};
        use crate::unsigned::$$uname;

        impl ModularInversion for $$uname {
            type Signed = $$sname;

            fn modinv(&self, phi: &$$uname) -> Option<$$uname>
            {
                // Nothing is invertible modulo zero. Returning early also
                // keeps the reduction loop below from spinning forever on a
                // zero modulus.
                if phi.is_zero() {
                    return None;
                }

                // egcd yields (a, b, g) with a*phi + b*self = g, so when
                // g == 1 the coefficient b is an inverse of self modulo phi.
                let (_, mut b, g) = phi.egcd(&self);

                if g != $$sname::from(1i64) {
                    return None;
                }

                let sphi = $$sname::from(phi);

                // Normalise b into the range [0, phi).
                while b.is_negative() {
                    b += &sphi;
                }
                while b >= sphi {
                    b -= &sphi;
                }

                Some($$uname::try_from(b).expect("overflow/underflow in modinv result"))
            }

            fn egcd(&self, rhs: &$$uname) -> ($$sname, $$sname, $$sname) {
                // INPUT: two non-negative integers x and y.
                let mut x = $$sname::from(self);
                let mut y = $$sname::from(rhs);
                // OUTPUT: integers a, b, and v such that ax + by = v,
                //         where v = gcd(x, y).

                // The halving loops below never finish on a zero operand
                // (zero is even and stays zero when shifted), so answer those
                // cases directly: gcd(0, y) = y = 0*x + 1*y and
                // gcd(x, 0) = x = 1*x + 0*y.
                if x.is_zero() {
                    return ($$sname::zero(), $$sname::from(1i64), y);
                }
                if y.is_zero() {
                    return ($$sname::from(1i64), $$sname::zero(), x);
                }

                // 1. g <- 1 (kept as a shift count rather than a value).
                let mut gshift: usize = 0;
                // 2. While x and y are both even, do the following:
                //    x <- x/2, y <- y/2, g <- 2g.
                while x.is_even() && y.is_even() {
                    x >>= 1u64;
                    y >>= 1u64;
                    gshift += 1;
                }
                // 3. u <- x, v <- y, A <- 1, B <- 0, C <- 0, D <- 1.
                let mut u = x.clone();
                let mut v = y.clone();
                #[allow(non_snake_case)]
                let mut A = $$sname::from(1i64);
                #[allow(non_snake_case)]
                let mut B = $$sname::zero();
                #[allow(non_snake_case)]
                let mut C = $$sname::zero();
                #[allow(non_snake_case)]
                let mut D = $$sname::from(1i64);
                loop {
                    // 4. While u is even do the following:
                    while u.is_even() {
                        // 4.1 u <- u/2.
                        u >>= 1u64;
                        // 4.2 If A == B == 0 (mod 2) then A <- A/2, B <- B/2;
                        //     otherwise, A <- (A + y)/2, B <- (B - x)/2.
                        if A.is_even() && B.is_even() {
                            A >>= 1u64;
                            B >>= 1u64;
                        } else {
                            A += &y;
                            A >>= 1u64;
                            B -= &x;
                            B >>= 1u64;
                        }
                    }
                    // 5. While v is even do the following:
                    while v.is_even() {
                        // 5.1 v <- v/2.
                        v >>= 1u64;
                        // 5.2 If C == D == 0 (mod 2) then C <- C/2, D <- D/2;
                        //     otherwise, C <- (C + y)/2, D <- (D - x)/2.
                        if C.is_even() && D.is_even() {
                            C >>= 1u64;
                            D >>= 1u64;
                        } else {
                            C += &y;
                            C >>= 1u64;
                            D -= &x;
                            D >>= 1u64;
                        }
                    }
                    // 6. If u >= v then u <- u - v, A <- A - C, B <- B - D;
                    //    otherwise, v <- v - u, C <- C - A, D <- D - B.
                    if u >= v {
                        u -= &v;
                        A -= &C;
                        B -= &D;
                    } else {
                        v -= &u;
                        C -= &A;
                        D -= &B;
                    }
                    // 7. If u = 0, then a <- C, b <- D, and return (a, b, g * v);
                    //    otherwise, go to step 4.
                    if u.is_zero() {
                        return (C, D, v << gshift);
                    }
                }
            }

            fn gcd_is_one(&self, b: &$$uname) -> bool {
                let mut u = self.clone();
                let mut v = b.clone();
                let one = $$uname::from(1u64);

                // gcd(0, n) = n, so a zero operand is coprime only with 1.
                if u.is_zero() {
                    return v == one;
                }

                if v.is_zero() {
                    return u == one;
                }

                // A shared factor of two rules out a gcd of one.
                if u.is_even() && v.is_even() {
                    return false;
                }

                while u.is_even() {
                    u >>= 1u64;
                }

                loop {
                    while v.is_even() {
                        v >>= 1u64;
                    }
                    // u and v guaranteed to be odd right now.
                    if u > v {
                        // make sure that v >= u, so that our subtraction
                        // works out.
                        let t = u;
                        u = v;
                        v = t;
                    }
                    v -= &u;

                    if v.is_zero() {
                        return u == one;
                    }
                }
            }
        }

        #[cfg(test)]
        #[allow(non_snake_case)]
        #[test]
        fn KATs() {
            run_test(build_test_path("modinv", $$(testFileLit)), 6, |case| {
                let (neg0, xbytes) = case.get("x").unwrap();
                let (neg1, ybytes) = case.get("y").unwrap();
                let (neg2, zbytes) = case.get("z").unwrap();
                let (neg3, abytes) = case.get("a").unwrap();
                let (neg4, bbytes) = case.get("b").unwrap();
                let (neg5, vbytes) = case.get("v").unwrap();

                assert!(!neg0 && !neg1 && !neg2);
                let x = $$uname::from_bytes(xbytes);
                let y = $$uname::from_bytes(ybytes);
                let z = $$uname::from_bytes(zbytes);
                let mut a = $$sname::from_bytes(abytes);
                let mut b = $$sname::from_bytes(bbytes);
                let mut v = $$sname::from_bytes(vbytes);

                if *neg3 { a = -a }
                if *neg4 { b = -b }
                if *neg5 { v = -v }

                let (mya, myb, myv) = x.egcd(&y);
                assert_eq!(a, mya);
                assert_eq!(b, myb);
                assert_eq!(v, myv);

                assert_eq!(z, x.modinv(&y).expect("Didn't find a modinv?"));
                assert_eq!(v == $$sname::from(1u64), x.gcd_is_one(&y));
            });
        }
     |]
-- | Produce one known-answer test case for the generated @modinv@ code.
--
-- Draws random @x@ and @y@ of the requested size until @x@ has an inverse
-- modulo @y@ ('recipModInteger' yields 0 when there is none), then records:
--
--   * @x@, @y@: the operands;
--   * @z@: the inverse of @x@ modulo @y@;
--   * @a@, @b@, @v@: the result of 'extendedGCD' on @x@ and @y@.
--
-- The assertions double-check the vector before it is emitted: @z@ really is
-- the inverse (@x * z = 1 (mod y)@), and @(a, b, v)@ satisfies the Bezout
-- identity with @v@ equal to the gcd. Note that 'assert' is compiled out when
-- GHC is told to ignore assertions (for example with optimisation on).
generateModInvTest :: RandomGen g => Word -> g -> (Map String String, g)
generateModInvTest size = go
  where
    go g0
      | z == 0    = go g2
      | otherwise =
          assert (z < y) $
          assert ((x * z) `mod` y == 1) $
          assert ((a * x) + (b * y) == v) $
          assert (v == gcd x y) $
          (tcase, g2)
      where
        (x, g1)   = generateNum g0 size
        (y, g2)   = generateNum g1 size
        z         = recipModInteger x y
        (a, b, v) = extendedGCD x y
        tcase     = Map.fromList [("x", showX x), ("y", showX y),
                                  ("z", showX z), ("a", showX a),
                                  ("b", showX b), ("v", showX v)]
-- | Binary extended GCD over non-negative integers.
--
-- Returns @(a, b, v)@ such that @a * x + b * y == v@ and @v == gcd x y@.
--
-- Zero operands are answered directly: the halving steps of the iterative
-- algorithm never finish on a zero value (zero is even and stays zero when
-- halved), so without these cases the function would diverge for @y == 0@
-- and for @x == y == 0@.
extendedGCD :: Integer -> Integer -> (Integer, Integer, Integer)
extendedGCD x y
  | x == 0    = (0, 1, y)
  | y == 0    = (1, 0, x)
  | otherwise = (bigC finalState, bigD finalState, g * v finalState)
  where
    -- Strip the shared power of two (collected in g), then iterate on the
    -- reduced operands.
    (x', y', g, initState) = initialState x y 1
    finalState             = runAlgorithm x' y' initState
-- | Working state of the binary extended GCD.
--
-- 'initialState' starts it as @u = x@, @v = y@, @A = 1@, @B = 0@, @C = 0@,
-- @D = 1@; 'step4', 'step5' and 'step6' then update it.
data AlgState = AlgState
  { u    :: Integer -- ^ first running value, driven to zero by the algorithm
  , v    :: Integer -- ^ second running value
  , bigA :: Integer -- ^ coefficient A, updated together with 'u'
  , bigB :: Integer -- ^ coefficient B, updated together with 'u'
  , bigC :: Integer -- ^ coefficient C, updated together with 'v'
  , bigD :: Integer -- ^ coefficient D, updated together with 'v'
  }
-- | Remove the common power of two from @x@ and @y@.
--
-- Returns the reduced operands, the accumulated factor (the incoming @g@
-- doubled once per halving), and the starting 'AlgState' built from the
-- reduced operands (@u = x@, @v = y@, @A = 1@, @B = 0@, @C = 0@, @D = 1@).
--
-- When both operands are zero there is nothing to strip: they would stay
-- even forever, so that case stops immediately instead of recursing.
initialState :: Integer -> Integer -> Integer -> (Integer, Integer, Integer, AlgState)
initialState x y g
  | bothZero || odd x || odd y = (x, y, g, AlgState x y 1 0 0 1)
  | otherwise                  = initialState (x `div` 2) (y `div` 2) (g * 2)
  where
    bothZero = x == 0 && y == 0
-- | Print every field of an 'AlgState', one per line, each rendered with
-- 'showX' after its label.
printState :: AlgState -> IO ()
printState st = mapM_ emit labelled
  where
    emit (label, field) = putStrLn (label ++ showX (field st))
    labelled =
      [ ("u: ", u)
      , ("v: ", v)
      , ("A: ", bigA)
      , ("B: ", bigB)
      , ("C: ", bigC)
      , ("D: ", bigD)
      ]
-- | Repeat one round of 'step4', 'step5' and 'step6' (in that order) until
-- the state's 'u' reaches zero. A state whose 'u' is already zero is
-- returned unchanged.
runAlgorithm :: Integer -> Integer -> AlgState -> AlgState
runAlgorithm x y = until finished oneRound
  where
    finished = (== 0) . u
    oneRound = step6 . step5 x y . step4 x y
-- | Step 4: while 'u' is even, halve it and adjust @A@ and @B@ to match.
--
-- If @A@ and @B@ are both even they are simply halved; otherwise they become
-- @(A + y) / 2@ and @(B - x) / 2@. 'v', @C@ and @D@ are left alone.
step4 :: Integer -> Integer -> AlgState -> AlgState
step4 x y st
  | odd (u st) = st
  | otherwise  = step4 x y halved
  where
    halved   = st { u = u st `div` 2, bigA = nextA, bigB = nextB }
    bothEven = even (bigA st) && even (bigB st)
    nextA
      | bothEven  = bigA st `div` 2
      | otherwise = (bigA st + y) `div` 2
    nextB
      | bothEven  = bigB st `div` 2
      | otherwise = (bigB st - x) `div` 2
-- | Step 5: while 'v' is even, halve it and adjust @C@ and @D@ to match.
--
-- If @C@ and @D@ are both even they are simply halved; otherwise they become
-- @(C + y) / 2@ and @(D - x) / 2@. 'u', @A@ and @B@ are left alone.
step5 :: Integer -> Integer -> AlgState -> AlgState
step5 x y = loop
  where
    loop st@(AlgState _ vCur _ _ cCur dCur)
      | even vCur = loop (st { v = vCur `div` 2, bigC = nextC, bigD = nextD })
      | otherwise = st
      where
        (nextC, nextD)
          | even cCur && even dCur = (cCur `div` 2, dCur `div` 2)
          | otherwise              = ((cCur + y) `div` 2, (dCur - x) `div` 2)
-- | Step 6: subtract the smaller running value from the larger.
--
-- When @u >= v@, @u@, @A@ and @B@ are reduced by @v@, @C@ and @D@;
-- otherwise @v@, @C@ and @D@ are reduced by @u@, @A@ and @B@.
step6 :: AlgState -> AlgState
step6 (AlgState uCur vCur a b c d) =
  if uCur >= vCur
    then AlgState (uCur - vCur) vCur (a - c) (b - d) c d
    else AlgState uCur (vCur - uCur) a b (c - a) (d - b)
-- | Debugging aid: run the extended GCD on two inputs and print the state
-- before every round, then the final state and value, then the result of
-- 'runAlgorithm' on the same start state, then the result of 'extendedGCD'.
_run :: Integer -> Integer -> IO ()
_run inputx inputy = do
  let (x, y, g, initState) = initialState inputx inputy 1
  finalState <- traceRounds x y initState
  putStrLn "-- FINAL STATE -----------------------"
  printState finalState
  putStrLn ("Final value: " ++ showX (g * v finalState))
  putStrLn "-- RUN ------"
  printState (runAlgorithm x y initState)
  putStrLn "-- NORMAL ------"
  let (coefA, coefB, gcdValue) = extendedGCD inputx inputy
  mapM_ putStrLn
    [ "a: " ++ showX coefA
    , "b: " ++ showX coefB
    , "v: " ++ showX gcdValue
    ]
  where
    -- Same loop as 'runAlgorithm', but printing the state before each round.
    traceRounds x y state = do
      putStrLn "-- STATE -----------------------------"
      printState state
      if u state == 0
        then return state
        else traceRounds x y (step6 (step5 x y (step4 x y state)))