Fix conversions and modinv.

This commit is contained in:
2020-02-09 17:12:23 -06:00
parent d8a2e66e7c
commit 2617609bf6
3 changed files with 257 additions and 34 deletions

View File

@@ -47,6 +47,7 @@ unsignedFiles = [
, conversions , conversions
, cryptoNum , cryptoNum
, divisionOps , divisionOps
, generateModInvOps
, modulusOps , modulusOps
, safeAddOps , safeAddOps
, safeMultiplyOps , safeMultiplyOps
@@ -61,8 +62,7 @@ unsignedFiles = [
signedFiles :: [File] signedFiles :: [File]
signedFiles = [ signedFiles = [
generateModInvOps safeSignedAddOps
, safeSignedAddOps
, safeSignedSubtractOps , safeSignedSubtractOps
, signedBaseOps , signedBaseOps
, signedComparisons , signedComparisons
@@ -118,4 +118,4 @@ main =
chan <- newMVar allTasks chan <- newMVar allTasks
count <- getNumCapabilities count <- getNumCapabilities
threads <- replicateM count (runThread pb (head args) chan) threads <- replicateM count (runThread pb (head args) chan)
forM_ threads (\ m -> takeMVar m) forM_ threads (\ m -> takeMVar m)

View File

@@ -98,7 +98,6 @@ declareSignedConversions :: Word -> [Word] -> SourceFile Span
declareSignedConversions bitsize otherSizes = declareSignedConversions bitsize otherSizes =
let sname = mkIdent ("I" ++ show bitsize) let sname = mkIdent ("I" ++ show bitsize)
uname = mkIdent ("U" ++ show bitsize) uname = mkIdent ("U" ++ show bitsize)
entries = bitsize `div` 64
u8_prims = buildUSPrimitives sname (mkIdent "u8") u8_prims = buildUSPrimitives sname (mkIdent "u8")
u16_prims = buildUSPrimitives sname (mkIdent "u16") u16_prims = buildUSPrimitives sname (mkIdent "u16")
u32_prims = buildUSPrimitives sname (mkIdent "u32") u32_prims = buildUSPrimitives sname (mkIdent "u32")
@@ -109,7 +108,7 @@ declareSignedConversions bitsize otherSizes =
i32_prims = buildSSPrimitives sname uname (mkIdent "i32") i32_prims = buildSSPrimitives sname uname (mkIdent "i32")
i64_prims = buildSSPrimitives sname uname (mkIdent "i64") i64_prims = buildSSPrimitives sname uname (mkIdent "i64")
isz_prims = buildSSPrimitives sname uname (mkIdent "isize") isz_prims = buildSSPrimitives sname uname (mkIdent "isize")
s128_prims = generateS128Primitives sname uname entries s128_prims = generateS128Primitives sname uname
others = generateSignedCryptonumConversions bitsize otherSizes others = generateSignedCryptonumConversions bitsize otherSizes
in [sourceFile| in [sourceFile|
use core::convert::{From,TryFrom}; use core::convert::{From,TryFrom};
@@ -482,8 +481,8 @@ buildSSPrimitives sname uname prim = [
|] |]
] ]
generateS128Primitives :: Ident -> Ident -> Word -> [Item Span] generateS128Primitives :: Ident -> Ident -> [Item Span]
generateS128Primitives sname uname entries = [ generateS128Primitives sname uname = [
[item| [item|
impl From<u128> for $$sname { impl From<u128> for $$sname {
fn from(x: u128) -> $$sname { fn from(x: u128) -> $$sname {
@@ -560,7 +559,8 @@ generateS128Primitives sname uname entries = [
generateSignedCryptonumConversions :: Word -> [Word] -> [Item Span] generateSignedCryptonumConversions :: Word -> [Word] -> [Item Span]
generateSignedCryptonumConversions source otherSizes = concatMap convert otherSizes generateSignedCryptonumConversions source otherSizes = concatMap convert otherSizes
where where
sName = mkIdent ("I" ++ show source) suName = mkIdent ("U" ++ show source)
ssName = mkIdent ("I" ++ show source)
-- --
convert target = convert target =
let tsName = mkIdent ("I" ++ show target) let tsName = mkIdent ("I" ++ show target)
@@ -575,8 +575,8 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
in case compare source target of in case compare source target of
LT -> [ LT -> [
[item| [item|
impl<'a> From<&'a $$sName> for $$tsName { impl<'a> From<&'a $$ssName> for $$tsName {
fn from(x: &$$sName) -> $$tsName { fn from(x: &$$ssName) -> $$tsName {
let mut res = $$tsName::zero(); let mut res = $$tsName::zero();
res.contents.value[0..$$(sEntries)].copy_from_slice(&x.contents.value); res.contents.value[0..$$(sEntries)].copy_from_slice(&x.contents.value);
let extension = if x.contents.value[$$(sTop)] & 0x8000_0000_0000_0000 == 0 { let extension = if x.contents.value[$$(sTop)] & 0x8000_0000_0000_0000 == 0 {
@@ -590,18 +590,32 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
} }
|], |],
[item| [item|
impl From<$$sName> for $$tsName { impl From<$$ssName> for $$tsName {
fn from(x: $$sName) -> $$tsName { fn from(x: $$ssName) -> $$tsName {
$$tsName::from(&x) $$tsName::from(&x)
} }
} }
|], |],
[item| [item|
impl<'a> TryFrom<&'a $$sName> for $$tuName { impl<'a> From<&'a $$suName> for $$tsName {
fn from(x: &$$suName) -> $$tsName {
$$tsName{ contents: $$tuName::from(x) }
}
}
|],
[item|
impl From<$$suName> for $$tsName {
fn from(x: $$suName) -> $$tsName {
$$tsName{ contents: $$tuName::from(x) }
}
}
|],
[item|
impl<'a> TryFrom<&'a $$ssName> for $$tuName {
type Error = ConversionError; type Error = ConversionError;
fn try_from(x: &$$sName) -> Result<$$tuName,ConversionError> { fn try_from(x: &$$ssName) -> Result<$$tuName,ConversionError> {
if x.is_negative() { if x.is_negative() {
Err(ConversionError::NegativeToUnsigned) Err(ConversionError::NegativeToUnsigned)
} else { } else {
@@ -611,10 +625,10 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
} }
|], |],
[item| [item|
impl TryFrom<$$sName> for $$tuName { impl TryFrom<$$ssName> for $$tuName {
type Error = ConversionError; type Error = ConversionError;
fn try_from(x: $$sName) -> Result<$$tuName,ConversionError> { fn try_from(x: $$ssName) -> Result<$$tuName,ConversionError> {
$$tuName::try_from(&x) $$tuName::try_from(&x)
} }
} }
@@ -622,11 +636,11 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
] ]
EQ -> [ EQ -> [
[item| [item|
impl TryFrom<$$tuName> for $$sName { impl TryFrom<$$tuName> for $$ssName {
type Error = ConversionError; type Error = ConversionError;
fn try_from(x: $$tuName) -> Result<$$sName,ConversionError> { fn try_from(x: $$tuName) -> Result<$$ssName,ConversionError> {
let res = $$sName{ contents: x }; let res = $$ssName{ contents: x };
if res.is_negative() { if res.is_negative() {
return Err(ConversionError::Overflow); return Err(ConversionError::Overflow);
@@ -637,19 +651,19 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
} }
|], |],
[item| [item|
impl<'a> TryFrom<&'a $$tuName> for $$sName { impl<'a> TryFrom<&'a $$tuName> for $$ssName {
type Error = ConversionError; type Error = ConversionError;
fn try_from(x: &$$tuName) -> Result<$$sName,ConversionError> { fn try_from(x: &$$tuName) -> Result<$$ssName,ConversionError> {
$$sName::try_from(x.clone()) $$ssName::try_from(x.clone())
} }
} }
|], |],
[item| [item|
impl TryFrom<$$sName> for $$tuName { impl TryFrom<$$ssName> for $$tuName {
type Error = ConversionError; type Error = ConversionError;
fn try_from(x: $$sName) -> Result<$$tuName,ConversionError> { fn try_from(x: $$ssName) -> Result<$$tuName,ConversionError> {
if x.is_negative() { if x.is_negative() {
return Err(ConversionError::Overflow); return Err(ConversionError::Overflow);
} }
@@ -658,15 +672,66 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
} }
|], |],
[item| [item|
impl<'a> TryFrom<&'a $$sName> for $$tuName { impl<'a> TryFrom<&'a $$ssName> for $$tuName {
type Error = ConversionError; type Error = ConversionError;
fn try_from(x: &$$sName) -> Result<$$tuName,ConversionError> { fn try_from(x: &$$ssName) -> Result<$$tuName,ConversionError> {
$$tuName::try_from(x.clone()) $$tuName::try_from(x.clone())
} }
} }
|] |]
] ]
GT -> [ GT -> [
[item|
impl<'a> TryFrom<&'a $$ssName> for $$tsName {
type Error = ConversionError;
fn try_from(x: &$$ssName) -> Result<$$tsName,ConversionError> {
let required_top = if x.is_negative() {
0xFFFF_FFFF_FFFF_FFFF
} else {
0
};
if x.contents.value.iter().skip($$(tEntries)).all(|x| *x == required_top) {
let mut res = $$tsName::zero();
res.contents.value.copy_from_slice(&x.contents.value[0..$$(tEntries)]);
Ok(res)
} else {
Err(ConversionError::Overflow)
}
}
}
|],
[item|
impl TryFrom<$$ssName> for $$tsName {
type Error = ConversionError;
fn try_from(x: $$ssName) -> Result<$$tsName,ConversionError> {
$$tsName::try_from(&x)
}
}
|],
[item|
impl<'a> TryFrom<&'a $$ssName> for $$tuName {
type Error = ConversionError;
fn try_from(x: &$$ssName) -> Result<$$tuName,ConversionError> {
if x.is_negative() {
Err(ConversionError::NegativeToUnsigned)
} else {
$$tuName::try_from(&x.contents)
}
}
}
|],
[item|
impl TryFrom<$$ssName> for $$tuName {
type Error = ConversionError;
fn try_from(x: $$ssName) -> Result<$$tuName,ConversionError> {
$$tuName::try_from(&x)
}
}
|]
] ]

View File

@@ -1,32 +1,45 @@
{-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
module ModInv( module ModInv(
generateModInvOps generateModInvOps
) )
where where
import Control.Exception(assert)
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import File import File
import Generators
import GHC.Integer.GMP.Internals(recipModInteger)
import Language.Rust.Data.Ident import Language.Rust.Data.Ident
import Language.Rust.Data.Position import Language.Rust.Data.Position
import Language.Rust.Quote import Language.Rust.Quote
import Language.Rust.Syntax import Language.Rust.Syntax
import System.Random(RandomGen)
numTestCases :: Int
numTestCases = 100
generateModInvOps :: File generateModInvOps :: File
generateModInvOps = File { generateModInvOps = File {
predicate = \ me others -> (me + 64) `elem` others, predicate = \ me others -> (me + 64) `elem` others,
outputName = "modinv", outputName = "modinv",
isUnsigned = False, isUnsigned = True,
generator = declareModInv, generator = declareModInv,
testCase = Nothing testCase = Just generateModInvTests
} }
declareModInv :: Word -> [Word] -> SourceFile Span declareModInv :: Word -> [Word] -> SourceFile Span
declareModInv bitsize _ = declareModInv bitsize _ =
let sname = mkIdent ("I" ++ show bitsize) let sname = mkIdent ("I" ++ show (bitsize + 64))
uname = mkIdent ("U" ++ show bitsize) uname = mkIdent ("U" ++ show bitsize)
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile| in [sourceFile|
use core::convert::TryFrom; use core::convert::TryFrom;
use crate::CryptoNum; use crate::CryptoNum;
use crate::signed::$$sname; use crate::signed::$$sname;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use crate::unsigned::$$uname; use crate::unsigned::$$uname;
impl $$uname { impl $$uname {
@@ -38,7 +51,7 @@ declareModInv bitsize _ =
return None; return None;
} }
let sphi = $$sname::try_from(phi).expect("over/underflow in modinv phi"); let sphi = $$sname::from(phi);
while b.is_negative() { while b.is_negative() {
b += &sphi; b += &sphi;
@@ -53,8 +66,8 @@ declareModInv bitsize _ =
fn egcd(&self, rhs: &$$uname) -> ($$sname, $$sname, $$sname) { fn egcd(&self, rhs: &$$uname) -> ($$sname, $$sname, $$sname) {
// INPUT: two positive integers x and y. // INPUT: two positive integers x and y.
let mut x = $$sname::try_from(self).expect("overflow in modinv base"); let mut x = $$sname::from(self);
let mut y = $$sname::try_from(rhs).expect("overflow in modinv rhs"); let mut y = $$sname::from(rhs);
// OUTPUT: integers a, b, and v such that ax + by = v, // OUTPUT: integers a, b, and v such that ax + by = v,
// where v = gcd(x, y). // where v = gcd(x, y).
// 1. g1. // 1. g1.
@@ -170,4 +183,149 @@ declareModInv bitsize _ =
} }
} }
} }
|]
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("modinv", $$(testFileLit)), 6, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
let (neg3, abytes) = case.get("a").unwrap();
let (neg4, bbytes) = case.get("b").unwrap();
let (neg5, vbytes) = case.get("v").unwrap();
assert!(!neg0 && !neg1 && !neg2);
let x = $$uname::from_bytes(xbytes);
let y = $$uname::from_bytes(ybytes);
let z = $$uname::from_bytes(zbytes);
let mut a = $$sname::from_bytes(abytes);
let mut b = $$sname::from_bytes(bbytes);
let mut v = $$sname::from_bytes(vbytes);
if *neg3 { a = -a }
if *neg4 { b = -b }
if *neg5 { v = -v }
let (mya, myb, myv) = x.egcd(&y);
assert_eq!(a, mya);
assert_eq!(b, myb);
assert_eq!(v, myv);
assert_eq!(z, x.modinv(&y).expect("Didn't find a modinv?"));
assert_eq!(v == $$sname::from(1u64), x.gcd_is_one(&y));
});
}
|]
generateModInvTests :: RandomGen g => Word -> g -> [Map String String]
generateModInvTests size g = go g numTestCases
where
go _ 0 = []
go g0 i =
let (x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size
z = recipModInteger x y
(a, b, v) = extendedGCD x y
tcase = Map.fromList [("x", showX x), ("y", showX y),
("z", showX z), ("a", showX a),
("b", showX b), ("v", showX v)]
in if z == 0
then go g2 i
else assert (z < y) $
assert ((x * z) `mod` y == 1) $
assert (((a * x) + (b * y)) == v) $
assert (v == gcd x y) $
tcase : go g2 (i - 1)
extendedGCD :: Integer -> Integer -> (Integer, Integer, Integer)
extendedGCD x y = (a, b, g * (v finalState))
where
(x', y', g, initState) = initialState x y 1
finalState = runAlgorithm x' y' initState
a = bigC finalState
b = bigD finalState
data AlgState = AlgState {
u :: Integer,
v :: Integer,
bigA :: Integer,
bigB :: Integer,
bigC :: Integer,
bigD :: Integer
}
initialState :: Integer -> Integer -> Integer -> (Integer, Integer, Integer, AlgState)
initialState x y g | even x && even y = initialState (x `div` 2) (y `div` 2) (g * 2)
| otherwise = (x, y, g, AlgState x y 1 0 0 1)
printState :: AlgState -> IO ()
printState a =
do putStrLn ("u: " ++ showX (u a))
putStrLn ("v: " ++ showX (v a))
putStrLn ("A: " ++ showX (bigA a))
putStrLn ("B: " ++ showX (bigB a))
putStrLn ("C: " ++ showX (bigC a))
putStrLn ("D: " ++ showX (bigD a))
runAlgorithm :: Integer -> Integer -> AlgState -> AlgState
runAlgorithm x y state | u state == 0 = state
| otherwise = runAlgorithm x y state6
where
state4 = step4 x y state
state5 = step5 x y state4
state6 = step6 state5
step4 :: Integer -> Integer -> AlgState -> AlgState
step4 x y input@AlgState{..} | even u = step4 x y input'
| otherwise = input
where
input' = AlgState u' v bigA' bigB' bigC bigD
u' = u `div` 2
bigA' | even bigA && even bigB = bigA `div` 2
| otherwise = (bigA + y) `div` 2
bigB' | even bigA && even bigB = bigB `div` 2
| otherwise = (bigB - x) `div` 2
step5 :: Integer -> Integer -> AlgState -> AlgState
step5 x y input@AlgState{..} | even v = step5 x y input'
| otherwise = input
where
input' = AlgState u v' bigA bigB bigC' bigD'
v' = v `div` 2
bigC' | even bigC && even bigD = bigC `div` 2
| otherwise = (bigC + y) `div` 2
bigD' | even bigC && even bigD = bigD `div` 2
| otherwise = (bigD - x) `div` 2
step6 :: AlgState -> AlgState
step6 AlgState{..}
| u >= v = AlgState (u - v) v (bigA - bigC) (bigB - bigD) bigC bigD
| otherwise = AlgState u (v - u) bigA bigB (bigC - bigA) (bigD - bigB)
_run :: Integer -> Integer -> IO ()
_run inputx inputy =
do let (x, y, g, initState) = initialState inputx inputy 1
finalState <- go x y initState
putStrLn ("-- FINAL STATE -----------------------")
printState finalState
putStrLn ("Final value: " ++ showX (g * v finalState))
putStrLn ("-- RUN ------")
printState (runAlgorithm x y initState)
putStrLn ("-- NORMAL ------")
let (a, b, v) = extendedGCD inputx inputy
putStrLn ("a: " ++ showX a)
putStrLn ("b: " ++ showX b)
putStrLn ("v: " ++ showX v)
where
go x y state =
do putStrLn "-- STATE -----------------------------"
printState state
if u state == 0
then return state
else do let state' = step4 x y state
state'' = step5 x y state'
state''' = step6 state''
go x y state'''