Fix conversions and modinv.

This commit is contained in:
2020-02-09 17:12:23 -06:00
parent d8a2e66e7c
commit 2617609bf6
3 changed files with 257 additions and 34 deletions

View File

@@ -1,32 +1,45 @@
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
module ModInv(
generateModInvOps
)
where
import Control.Exception(assert)
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import File
import Generators
import GHC.Integer.GMP.Internals(recipModInteger)
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import System.Random(RandomGen)
numTestCases :: Int
numTestCases = 100
generateModInvOps :: File
generateModInvOps = File {
predicate = \ me others -> (me + 64) `elem` others,
outputName = "modinv",
isUnsigned = False,
isUnsigned = True,
generator = declareModInv,
testCase = Nothing
testCase = Just generateModInvTests
}
declareModInv :: Word -> [Word] -> SourceFile Span
declareModInv bitsize _ =
let sname = mkIdent ("I" ++ show bitsize)
let sname = mkIdent ("I" ++ show (bitsize + 64))
uname = mkIdent ("U" ++ show bitsize)
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::convert::TryFrom;
use crate::CryptoNum;
use crate::signed::$$sname;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use crate::unsigned::$$uname;
impl $$uname {
@@ -38,7 +51,7 @@ declareModInv bitsize _ =
return None;
}
let sphi = $$sname::try_from(phi).expect("over/underflow in modinv phi");
let sphi = $$sname::from(phi);
while b.is_negative() {
b += &sphi;
@@ -53,8 +66,8 @@ declareModInv bitsize _ =
fn egcd(&self, rhs: &$$uname) -> ($$sname, $$sname, $$sname) {
// INPUT: two positive integers x and y.
let mut x = $$sname::try_from(self).expect("overflow in modinv base");
let mut y = $$sname::try_from(rhs).expect("overflow in modinv rhs");
let mut x = $$sname::from(self);
let mut y = $$sname::from(rhs);
// OUTPUT: integers a, b, and v such that ax + by = v,
// where v = gcd(x, y).
// 1. g1.
@@ -170,4 +183,149 @@ declareModInv bitsize _ =
}
}
}
|]
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("modinv", $$(testFileLit)), 6, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
let (neg3, abytes) = case.get("a").unwrap();
let (neg4, bbytes) = case.get("b").unwrap();
let (neg5, vbytes) = case.get("v").unwrap();
assert!(!neg0 && !neg1 && !neg2);
let x = $$uname::from_bytes(xbytes);
let y = $$uname::from_bytes(ybytes);
let z = $$uname::from_bytes(zbytes);
let mut a = $$sname::from_bytes(abytes);
let mut b = $$sname::from_bytes(bbytes);
let mut v = $$sname::from_bytes(vbytes);
if *neg3 { a = -a }
if *neg4 { b = -b }
if *neg5 { v = -v }
let (mya, myb, myv) = x.egcd(&y);
assert_eq!(a, mya);
assert_eq!(b, myb);
assert_eq!(v, myv);
assert_eq!(z, x.modinv(&y).expect("Didn't find a modinv?"));
assert_eq!(v == $$sname::from(1u64), x.gcd_is_one(&y));
});
}
|]
generateModInvTests :: RandomGen g => Word -> g -> [Map String String]
generateModInvTests size g = go g numTestCases
where
go _ 0 = []
go g0 i =
let (x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size
z = recipModInteger x y
(a, b, v) = extendedGCD x y
tcase = Map.fromList [("x", showX x), ("y", showX y),
("z", showX z), ("a", showX a),
("b", showX b), ("v", showX v)]
in if z == 0
then go g2 i
else assert (z < y) $
assert ((x * z) `mod` y == 1) $
assert (((a * x) + (b * y)) == v) $
assert (v == gcd x y) $
tcase : go g2 (i - 1)
extendedGCD :: Integer -> Integer -> (Integer, Integer, Integer)
extendedGCD x y = (a, b, g * (v finalState))
where
(x', y', g, initState) = initialState x y 1
finalState = runAlgorithm x' y' initState
a = bigC finalState
b = bigD finalState
data AlgState = AlgState {
u :: Integer,
v :: Integer,
bigA :: Integer,
bigB :: Integer,
bigC :: Integer,
bigD :: Integer
}
initialState :: Integer -> Integer -> Integer -> (Integer, Integer, Integer, AlgState)
initialState x y g | even x && even y = initialState (x `div` 2) (y `div` 2) (g * 2)
| otherwise = (x, y, g, AlgState x y 1 0 0 1)
printState :: AlgState -> IO ()
printState a =
do putStrLn ("u: " ++ showX (u a))
putStrLn ("v: " ++ showX (v a))
putStrLn ("A: " ++ showX (bigA a))
putStrLn ("B: " ++ showX (bigB a))
putStrLn ("C: " ++ showX (bigC a))
putStrLn ("D: " ++ showX (bigD a))
runAlgorithm :: Integer -> Integer -> AlgState -> AlgState
runAlgorithm x y state | u state == 0 = state
| otherwise = runAlgorithm x y state6
where
state4 = step4 x y state
state5 = step5 x y state4
state6 = step6 state5
step4 :: Integer -> Integer -> AlgState -> AlgState
step4 x y input@AlgState{..} | even u = step4 x y input'
| otherwise = input
where
input' = AlgState u' v bigA' bigB' bigC bigD
u' = u `div` 2
bigA' | even bigA && even bigB = bigA `div` 2
| otherwise = (bigA + y) `div` 2
bigB' | even bigA && even bigB = bigB `div` 2
| otherwise = (bigB - x) `div` 2
step5 :: Integer -> Integer -> AlgState -> AlgState
step5 x y input@AlgState{..} | even v = step5 x y input'
| otherwise = input
where
input' = AlgState u v' bigA bigB bigC' bigD'
v' = v `div` 2
bigC' | even bigC && even bigD = bigC `div` 2
| otherwise = (bigC + y) `div` 2
bigD' | even bigC && even bigD = bigD `div` 2
| otherwise = (bigD - x) `div` 2
step6 :: AlgState -> AlgState
step6 AlgState{..}
| u >= v = AlgState (u - v) v (bigA - bigC) (bigB - bigD) bigC bigD
| otherwise = AlgState u (v - u) bigA bigB (bigC - bigA) (bigD - bigB)
_run :: Integer -> Integer -> IO ()
_run inputx inputy =
do let (x, y, g, initState) = initialState inputx inputy 1
finalState <- go x y initState
putStrLn ("-- FINAL STATE -----------------------")
printState finalState
putStrLn ("Final value: " ++ showX (g * v finalState))
putStrLn ("-- RUN ------")
printState (runAlgorithm x y initState)
putStrLn ("-- NORMAL ------")
let (a, b, v) = extendedGCD inputx inputy
putStrLn ("a: " ++ showX a)
putStrLn ("b: " ++ showX b)
putStrLn ("v: " ++ showX v)
where
go x y state =
do putStrLn "-- STATE -----------------------------"
printState state
if u state == 0
then return state
else do let state' = step4 x y state
state'' = step5 x y state'
state''' = step6 state''
go x y state'''