Fix conversions and modinv.

This commit is contained in:
2020-02-09 17:12:23 -06:00
parent d8a2e66e7c
commit 2617609bf6
3 changed files with 257 additions and 34 deletions

View File

@@ -47,6 +47,7 @@ unsignedFiles = [
, conversions
, cryptoNum
, divisionOps
, generateModInvOps
, modulusOps
, safeAddOps
, safeMultiplyOps
@@ -61,8 +62,7 @@ unsignedFiles = [
signedFiles :: [File]
signedFiles = [
generateModInvOps
, safeSignedAddOps
safeSignedAddOps
, safeSignedSubtractOps
, signedBaseOps
, signedComparisons
@@ -118,4 +118,4 @@ main =
chan <- newMVar allTasks
count <- getNumCapabilities
threads <- replicateM count (runThread pb (head args) chan)
forM_ threads (\ m -> takeMVar m)
forM_ threads (\ m -> takeMVar m)

View File

@@ -98,7 +98,6 @@ declareSignedConversions :: Word -> [Word] -> SourceFile Span
declareSignedConversions bitsize otherSizes =
let sname = mkIdent ("I" ++ show bitsize)
uname = mkIdent ("U" ++ show bitsize)
entries = bitsize `div` 64
u8_prims = buildUSPrimitives sname (mkIdent "u8")
u16_prims = buildUSPrimitives sname (mkIdent "u16")
u32_prims = buildUSPrimitives sname (mkIdent "u32")
@@ -109,7 +108,7 @@ declareSignedConversions bitsize otherSizes =
i32_prims = buildSSPrimitives sname uname (mkIdent "i32")
i64_prims = buildSSPrimitives sname uname (mkIdent "i64")
isz_prims = buildSSPrimitives sname uname (mkIdent "isize")
s128_prims = generateS128Primitives sname uname entries
s128_prims = generateS128Primitives sname uname
others = generateSignedCryptonumConversions bitsize otherSizes
in [sourceFile|
use core::convert::{From,TryFrom};
@@ -482,8 +481,8 @@ buildSSPrimitives sname uname prim = [
|]
]
generateS128Primitives :: Ident -> Ident -> Word -> [Item Span]
generateS128Primitives sname uname entries = [
generateS128Primitives :: Ident -> Ident -> [Item Span]
generateS128Primitives sname uname = [
[item|
impl From<u128> for $$sname {
fn from(x: u128) -> $$sname {
@@ -560,7 +559,8 @@ generateS128Primitives sname uname entries = [
generateSignedCryptonumConversions :: Word -> [Word] -> [Item Span]
generateSignedCryptonumConversions source otherSizes = concatMap convert otherSizes
where
sName = mkIdent ("I" ++ show source)
suName = mkIdent ("U" ++ show source)
ssName = mkIdent ("I" ++ show source)
--
convert target =
let tsName = mkIdent ("I" ++ show target)
@@ -575,8 +575,8 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
in case compare source target of
LT -> [
[item|
impl<'a> From<&'a $$sName> for $$tsName {
fn from(x: &$$sName) -> $$tsName {
impl<'a> From<&'a $$ssName> for $$tsName {
fn from(x: &$$ssName) -> $$tsName {
let mut res = $$tsName::zero();
res.contents.value[0..$$(sEntries)].copy_from_slice(&x.contents.value);
let extension = if x.contents.value[$$(sTop)] & 0x8000_0000_0000_0000 == 0 {
@@ -590,18 +590,32 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
}
|],
[item|
impl From<$$sName> for $$tsName {
fn from(x: $$sName) -> $$tsName {
impl From<$$ssName> for $$tsName {
fn from(x: $$ssName) -> $$tsName {
$$tsName::from(&x)
}
}
|],
[item|
impl<'a> TryFrom<&'a $$sName> for $$tuName {
impl<'a> From<&'a $$suName> for $$tsName {
fn from(x: &$$suName) -> $$tsName {
$$tsName{ contents: $$tuName::from(x) }
}
}
|],
[item|
impl From<$$suName> for $$tsName {
fn from(x: $$suName) -> $$tsName {
$$tsName{ contents: $$tuName::from(x) }
}
}
|],
[item|
impl<'a> TryFrom<&'a $$ssName> for $$tuName {
type Error = ConversionError;
fn try_from(x: &$$sName) -> Result<$$tuName,ConversionError> {
fn try_from(x: &$$ssName) -> Result<$$tuName,ConversionError> {
if x.is_negative() {
Err(ConversionError::NegativeToUnsigned)
} else {
@@ -611,10 +625,10 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
}
|],
[item|
impl TryFrom<$$sName> for $$tuName {
impl TryFrom<$$ssName> for $$tuName {
type Error = ConversionError;
fn try_from(x: $$sName) -> Result<$$tuName,ConversionError> {
fn try_from(x: $$ssName) -> Result<$$tuName,ConversionError> {
$$tuName::try_from(&x)
}
}
@@ -622,11 +636,11 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
]
EQ -> [
[item|
impl TryFrom<$$tuName> for $$sName {
impl TryFrom<$$tuName> for $$ssName {
type Error = ConversionError;
fn try_from(x: $$tuName) -> Result<$$sName,ConversionError> {
let res = $$sName{ contents: x };
fn try_from(x: $$tuName) -> Result<$$ssName,ConversionError> {
let res = $$ssName{ contents: x };
if res.is_negative() {
return Err(ConversionError::Overflow);
@@ -637,19 +651,19 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
}
|],
[item|
impl<'a> TryFrom<&'a $$tuName> for $$sName {
impl<'a> TryFrom<&'a $$tuName> for $$ssName {
type Error = ConversionError;
fn try_from(x: &$$tuName) -> Result<$$sName,ConversionError> {
$$sName::try_from(x.clone())
fn try_from(x: &$$tuName) -> Result<$$ssName,ConversionError> {
$$ssName::try_from(x.clone())
}
}
|],
[item|
impl TryFrom<$$sName> for $$tuName {
impl TryFrom<$$ssName> for $$tuName {
type Error = ConversionError;
fn try_from(x: $$sName) -> Result<$$tuName,ConversionError> {
fn try_from(x: $$ssName) -> Result<$$tuName,ConversionError> {
if x.is_negative() {
return Err(ConversionError::Overflow);
}
@@ -658,15 +672,66 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
}
|],
[item|
impl<'a> TryFrom<&'a $$sName> for $$tuName {
impl<'a> TryFrom<&'a $$ssName> for $$tuName {
type Error = ConversionError;
fn try_from(x: &$$sName) -> Result<$$tuName,ConversionError> {
fn try_from(x: &$$ssName) -> Result<$$tuName,ConversionError> {
$$tuName::try_from(x.clone())
}
}
|]
]
GT -> [
[item|
impl<'a> TryFrom<&'a $$ssName> for $$tsName {
type Error = ConversionError;
fn try_from(x: &$$ssName) -> Result<$$tsName,ConversionError> {
let required_top = if x.is_negative() {
0xFFFF_FFFF_FFFF_FFFF
} else {
0
};
if x.contents.value.iter().skip($$(tEntries)).all(|x| *x == required_top) {
let mut res = $$tsName::zero();
res.contents.value.copy_from_slice(&x.contents.value[0..$$(tEntries)]);
Ok(res)
} else {
Err(ConversionError::Overflow)
}
}
}
|],
[item|
impl TryFrom<$$ssName> for $$tsName {
type Error = ConversionError;
fn try_from(x: $$ssName) -> Result<$$tsName,ConversionError> {
$$tsName::try_from(&x)
}
}
|],
[item|
impl<'a> TryFrom<&'a $$ssName> for $$tuName {
type Error = ConversionError;
fn try_from(x: &$$ssName) -> Result<$$tuName,ConversionError> {
if x.is_negative() {
Err(ConversionError::NegativeToUnsigned)
} else {
$$tuName::try_from(&x.contents)
}
}
}
|],
[item|
impl TryFrom<$$ssName> for $$tuName {
type Error = ConversionError;
fn try_from(x: $$ssName) -> Result<$$tuName,ConversionError> {
$$tuName::try_from(&x)
}
}
|]
]

View File

@@ -1,32 +1,45 @@
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
module ModInv(
generateModInvOps
)
where
import Control.Exception(assert)
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import File
import Generators
import GHC.Integer.GMP.Internals(recipModInteger)
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import System.Random(RandomGen)
numTestCases :: Int
numTestCases = 100
generateModInvOps :: File
generateModInvOps = File {
predicate = \ me others -> (me + 64) `elem` others,
outputName = "modinv",
isUnsigned = False,
isUnsigned = True,
generator = declareModInv,
testCase = Nothing
testCase = Just generateModInvTests
}
declareModInv :: Word -> [Word] -> SourceFile Span
declareModInv bitsize _ =
let sname = mkIdent ("I" ++ show bitsize)
let sname = mkIdent ("I" ++ show (bitsize + 64))
uname = mkIdent ("U" ++ show bitsize)
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::convert::TryFrom;
use crate::CryptoNum;
use crate::signed::$$sname;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use crate::unsigned::$$uname;
impl $$uname {
@@ -38,7 +51,7 @@ declareModInv bitsize _ =
return None;
}
let sphi = $$sname::try_from(phi).expect("over/underflow in modinv phi");
let sphi = $$sname::from(phi);
while b.is_negative() {
b += &sphi;
@@ -53,8 +66,8 @@ declareModInv bitsize _ =
fn egcd(&self, rhs: &$$uname) -> ($$sname, $$sname, $$sname) {
// INPUT: two positive integers x and y.
let mut x = $$sname::try_from(self).expect("overflow in modinv base");
let mut y = $$sname::try_from(rhs).expect("overflow in modinv rhs");
let mut x = $$sname::from(self);
let mut y = $$sname::from(rhs);
// OUTPUT: integers a, b, and v such that ax + by = v,
// where v = gcd(x, y).
// 1. g1.
@@ -170,4 +183,149 @@ declareModInv bitsize _ =
}
}
}
|]
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("modinv", $$(testFileLit)), 6, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
let (neg3, abytes) = case.get("a").unwrap();
let (neg4, bbytes) = case.get("b").unwrap();
let (neg5, vbytes) = case.get("v").unwrap();
assert!(!neg0 && !neg1 && !neg2);
let x = $$uname::from_bytes(xbytes);
let y = $$uname::from_bytes(ybytes);
let z = $$uname::from_bytes(zbytes);
let mut a = $$sname::from_bytes(abytes);
let mut b = $$sname::from_bytes(bbytes);
let mut v = $$sname::from_bytes(vbytes);
if *neg3 { a = -a }
if *neg4 { b = -b }
if *neg5 { v = -v }
let (mya, myb, myv) = x.egcd(&y);
assert_eq!(a, mya);
assert_eq!(b, myb);
assert_eq!(v, myv);
assert_eq!(z, x.modinv(&y).expect("Didn't find a modinv?"));
assert_eq!(v == $$sname::from(1u64), x.gcd_is_one(&y));
});
}
|]
generateModInvTests :: RandomGen g => Word -> g -> [Map String String]
generateModInvTests size g = go g numTestCases
where
go _ 0 = []
go g0 i =
let (x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size
z = recipModInteger x y
(a, b, v) = extendedGCD x y
tcase = Map.fromList [("x", showX x), ("y", showX y),
("z", showX z), ("a", showX a),
("b", showX b), ("v", showX v)]
in if z == 0
then go g2 i
else assert (z < y) $
assert ((x * z) `mod` y == 1) $
assert (((a * x) + (b * y)) == v) $
assert (v == gcd x y) $
tcase : go g2 (i - 1)
extendedGCD :: Integer -> Integer -> (Integer, Integer, Integer)
extendedGCD x y = (a, b, g * (v finalState))
where
(x', y', g, initState) = initialState x y 1
finalState = runAlgorithm x' y' initState
a = bigC finalState
b = bigD finalState
data AlgState = AlgState {
u :: Integer,
v :: Integer,
bigA :: Integer,
bigB :: Integer,
bigC :: Integer,
bigD :: Integer
}
initialState :: Integer -> Integer -> Integer -> (Integer, Integer, Integer, AlgState)
initialState x y g | even x && even y = initialState (x `div` 2) (y `div` 2) (g * 2)
| otherwise = (x, y, g, AlgState x y 1 0 0 1)
printState :: AlgState -> IO ()
printState a =
do putStrLn ("u: " ++ showX (u a))
putStrLn ("v: " ++ showX (v a))
putStrLn ("A: " ++ showX (bigA a))
putStrLn ("B: " ++ showX (bigB a))
putStrLn ("C: " ++ showX (bigC a))
putStrLn ("D: " ++ showX (bigD a))
runAlgorithm :: Integer -> Integer -> AlgState -> AlgState
runAlgorithm x y state | u state == 0 = state
| otherwise = runAlgorithm x y state6
where
state4 = step4 x y state
state5 = step5 x y state4
state6 = step6 state5
step4 :: Integer -> Integer -> AlgState -> AlgState
step4 x y input@AlgState{..} | even u = step4 x y input'
| otherwise = input
where
input' = AlgState u' v bigA' bigB' bigC bigD
u' = u `div` 2
bigA' | even bigA && even bigB = bigA `div` 2
| otherwise = (bigA + y) `div` 2
bigB' | even bigA && even bigB = bigB `div` 2
| otherwise = (bigB - x) `div` 2
step5 :: Integer -> Integer -> AlgState -> AlgState
step5 x y input@AlgState{..} | even v = step5 x y input'
| otherwise = input
where
input' = AlgState u v' bigA bigB bigC' bigD'
v' = v `div` 2
bigC' | even bigC && even bigD = bigC `div` 2
| otherwise = (bigC + y) `div` 2
bigD' | even bigC && even bigD = bigD `div` 2
| otherwise = (bigD - x) `div` 2
step6 :: AlgState -> AlgState
step6 AlgState{..}
| u >= v = AlgState (u - v) v (bigA - bigC) (bigB - bigD) bigC bigD
| otherwise = AlgState u (v - u) bigA bigB (bigC - bigA) (bigD - bigB)
_run :: Integer -> Integer -> IO ()
_run inputx inputy =
do let (x, y, g, initState) = initialState inputx inputy 1
finalState <- go x y initState
putStrLn ("-- FINAL STATE -----------------------")
printState finalState
putStrLn ("Final value: " ++ showX (g * v finalState))
putStrLn ("-- RUN ------")
printState (runAlgorithm x y initState)
putStrLn ("-- NORMAL ------")
let (a, b, v) = extendedGCD inputx inputy
putStrLn ("a: " ++ showX a)
putStrLn ("b: " ++ showX b)
putStrLn ("v: " ++ showX v)
where
go x y state =
do putStrLn "-- STATE -----------------------------"
printState state
if u state == 0
then return state
else do let state' = step4 x y state
state'' = step5 x y state'
state''' = step6 state''
go x y state'''