Fix conversions and modinv.
This commit is contained in:
@@ -47,6 +47,7 @@ unsignedFiles = [
|
||||
, conversions
|
||||
, cryptoNum
|
||||
, divisionOps
|
||||
, generateModInvOps
|
||||
, modulusOps
|
||||
, safeAddOps
|
||||
, safeMultiplyOps
|
||||
@@ -61,8 +62,7 @@ unsignedFiles = [
|
||||
|
||||
signedFiles :: [File]
|
||||
signedFiles = [
|
||||
generateModInvOps
|
||||
, safeSignedAddOps
|
||||
safeSignedAddOps
|
||||
, safeSignedSubtractOps
|
||||
, signedBaseOps
|
||||
, signedComparisons
|
||||
|
||||
@@ -98,7 +98,6 @@ declareSignedConversions :: Word -> [Word] -> SourceFile Span
|
||||
declareSignedConversions bitsize otherSizes =
|
||||
let sname = mkIdent ("I" ++ show bitsize)
|
||||
uname = mkIdent ("U" ++ show bitsize)
|
||||
entries = bitsize `div` 64
|
||||
u8_prims = buildUSPrimitives sname (mkIdent "u8")
|
||||
u16_prims = buildUSPrimitives sname (mkIdent "u16")
|
||||
u32_prims = buildUSPrimitives sname (mkIdent "u32")
|
||||
@@ -109,7 +108,7 @@ declareSignedConversions bitsize otherSizes =
|
||||
i32_prims = buildSSPrimitives sname uname (mkIdent "i32")
|
||||
i64_prims = buildSSPrimitives sname uname (mkIdent "i64")
|
||||
isz_prims = buildSSPrimitives sname uname (mkIdent "isize")
|
||||
s128_prims = generateS128Primitives sname uname entries
|
||||
s128_prims = generateS128Primitives sname uname
|
||||
others = generateSignedCryptonumConversions bitsize otherSizes
|
||||
in [sourceFile|
|
||||
use core::convert::{From,TryFrom};
|
||||
@@ -482,8 +481,8 @@ buildSSPrimitives sname uname prim = [
|
||||
|]
|
||||
]
|
||||
|
||||
generateS128Primitives :: Ident -> Ident -> Word -> [Item Span]
|
||||
generateS128Primitives sname uname entries = [
|
||||
generateS128Primitives :: Ident -> Ident -> [Item Span]
|
||||
generateS128Primitives sname uname = [
|
||||
[item|
|
||||
impl From<u128> for $$sname {
|
||||
fn from(x: u128) -> $$sname {
|
||||
@@ -560,7 +559,8 @@ generateS128Primitives sname uname entries = [
|
||||
generateSignedCryptonumConversions :: Word -> [Word] -> [Item Span]
|
||||
generateSignedCryptonumConversions source otherSizes = concatMap convert otherSizes
|
||||
where
|
||||
sName = mkIdent ("I" ++ show source)
|
||||
suName = mkIdent ("U" ++ show source)
|
||||
ssName = mkIdent ("I" ++ show source)
|
||||
--
|
||||
convert target =
|
||||
let tsName = mkIdent ("I" ++ show target)
|
||||
@@ -575,8 +575,8 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
|
||||
in case compare source target of
|
||||
LT -> [
|
||||
[item|
|
||||
impl<'a> From<&'a $$sName> for $$tsName {
|
||||
fn from(x: &$$sName) -> $$tsName {
|
||||
impl<'a> From<&'a $$ssName> for $$tsName {
|
||||
fn from(x: &$$ssName) -> $$tsName {
|
||||
let mut res = $$tsName::zero();
|
||||
res.contents.value[0..$$(sEntries)].copy_from_slice(&x.contents.value);
|
||||
let extension = if x.contents.value[$$(sTop)] & 0x8000_0000_0000_0000 == 0 {
|
||||
@@ -590,18 +590,32 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
|
||||
}
|
||||
|],
|
||||
[item|
|
||||
impl From<$$sName> for $$tsName {
|
||||
fn from(x: $$sName) -> $$tsName {
|
||||
impl From<$$ssName> for $$tsName {
|
||||
fn from(x: $$ssName) -> $$tsName {
|
||||
$$tsName::from(&x)
|
||||
}
|
||||
}
|
||||
|],
|
||||
[item|
|
||||
impl<'a> TryFrom<&'a $$sName> for $$tuName {
|
||||
impl<'a> From<&'a $$suName> for $$tsName {
|
||||
fn from(x: &$$suName) -> $$tsName {
|
||||
$$tsName{ contents: $$tuName::from(x) }
|
||||
}
|
||||
}
|
||||
|],
|
||||
[item|
|
||||
impl From<$$suName> for $$tsName {
|
||||
fn from(x: $$suName) -> $$tsName {
|
||||
$$tsName{ contents: $$tuName::from(x) }
|
||||
}
|
||||
}
|
||||
|],
|
||||
[item|
|
||||
impl<'a> TryFrom<&'a $$ssName> for $$tuName {
|
||||
type Error = ConversionError;
|
||||
|
||||
|
||||
fn try_from(x: &$$sName) -> Result<$$tuName,ConversionError> {
|
||||
fn try_from(x: &$$ssName) -> Result<$$tuName,ConversionError> {
|
||||
if x.is_negative() {
|
||||
Err(ConversionError::NegativeToUnsigned)
|
||||
} else {
|
||||
@@ -611,10 +625,10 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
|
||||
}
|
||||
|],
|
||||
[item|
|
||||
impl TryFrom<$$sName> for $$tuName {
|
||||
impl TryFrom<$$ssName> for $$tuName {
|
||||
type Error = ConversionError;
|
||||
|
||||
fn try_from(x: $$sName) -> Result<$$tuName,ConversionError> {
|
||||
fn try_from(x: $$ssName) -> Result<$$tuName,ConversionError> {
|
||||
$$tuName::try_from(&x)
|
||||
}
|
||||
}
|
||||
@@ -622,11 +636,11 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
|
||||
]
|
||||
EQ -> [
|
||||
[item|
|
||||
impl TryFrom<$$tuName> for $$sName {
|
||||
impl TryFrom<$$tuName> for $$ssName {
|
||||
type Error = ConversionError;
|
||||
|
||||
fn try_from(x: $$tuName) -> Result<$$sName,ConversionError> {
|
||||
let res = $$sName{ contents: x };
|
||||
fn try_from(x: $$tuName) -> Result<$$ssName,ConversionError> {
|
||||
let res = $$ssName{ contents: x };
|
||||
|
||||
if res.is_negative() {
|
||||
return Err(ConversionError::Overflow);
|
||||
@@ -637,19 +651,19 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
|
||||
}
|
||||
|],
|
||||
[item|
|
||||
impl<'a> TryFrom<&'a $$tuName> for $$sName {
|
||||
impl<'a> TryFrom<&'a $$tuName> for $$ssName {
|
||||
type Error = ConversionError;
|
||||
|
||||
fn try_from(x: &$$tuName) -> Result<$$sName,ConversionError> {
|
||||
$$sName::try_from(x.clone())
|
||||
fn try_from(x: &$$tuName) -> Result<$$ssName,ConversionError> {
|
||||
$$ssName::try_from(x.clone())
|
||||
}
|
||||
}
|
||||
|],
|
||||
[item|
|
||||
impl TryFrom<$$sName> for $$tuName {
|
||||
impl TryFrom<$$ssName> for $$tuName {
|
||||
type Error = ConversionError;
|
||||
|
||||
fn try_from(x: $$sName) -> Result<$$tuName,ConversionError> {
|
||||
fn try_from(x: $$ssName) -> Result<$$tuName,ConversionError> {
|
||||
if x.is_negative() {
|
||||
return Err(ConversionError::Overflow);
|
||||
}
|
||||
@@ -658,15 +672,66 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi
|
||||
}
|
||||
|],
|
||||
[item|
|
||||
impl<'a> TryFrom<&'a $$sName> for $$tuName {
|
||||
impl<'a> TryFrom<&'a $$ssName> for $$tuName {
|
||||
type Error = ConversionError;
|
||||
|
||||
fn try_from(x: &$$sName) -> Result<$$tuName,ConversionError> {
|
||||
fn try_from(x: &$$ssName) -> Result<$$tuName,ConversionError> {
|
||||
$$tuName::try_from(x.clone())
|
||||
}
|
||||
}
|
||||
|]
|
||||
]
|
||||
GT -> [
|
||||
[item|
|
||||
impl<'a> TryFrom<&'a $$ssName> for $$tsName {
|
||||
type Error = ConversionError;
|
||||
|
||||
fn try_from(x: &$$ssName) -> Result<$$tsName,ConversionError> {
|
||||
let required_top = if x.is_negative() {
|
||||
0xFFFF_FFFF_FFFF_FFFF
|
||||
} else {
|
||||
0
|
||||
};
|
||||
if x.contents.value.iter().skip($$(tEntries)).all(|x| *x == required_top) {
|
||||
let mut res = $$tsName::zero();
|
||||
res.contents.value.copy_from_slice(&x.contents.value[0..$$(tEntries)]);
|
||||
Ok(res)
|
||||
} else {
|
||||
Err(ConversionError::Overflow)
|
||||
}
|
||||
}
|
||||
}
|
||||
|],
|
||||
[item|
|
||||
impl TryFrom<$$ssName> for $$tsName {
|
||||
type Error = ConversionError;
|
||||
|
||||
fn try_from(x: $$ssName) -> Result<$$tsName,ConversionError> {
|
||||
$$tsName::try_from(&x)
|
||||
}
|
||||
}
|
||||
|],
|
||||
[item|
|
||||
impl<'a> TryFrom<&'a $$ssName> for $$tuName {
|
||||
type Error = ConversionError;
|
||||
|
||||
fn try_from(x: &$$ssName) -> Result<$$tuName,ConversionError> {
|
||||
if x.is_negative() {
|
||||
Err(ConversionError::NegativeToUnsigned)
|
||||
} else {
|
||||
$$tuName::try_from(&x.contents)
|
||||
}
|
||||
}
|
||||
}
|
||||
|],
|
||||
[item|
|
||||
impl TryFrom<$$ssName> for $$tuName {
|
||||
type Error = ConversionError;
|
||||
|
||||
fn try_from(x: $$ssName) -> Result<$$tuName,ConversionError> {
|
||||
$$tuName::try_from(&x)
|
||||
}
|
||||
}
|
||||
|]
|
||||
]
|
||||
|
||||
|
||||
@@ -1,32 +1,45 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
module ModInv(
|
||||
generateModInvOps
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Exception(assert)
|
||||
import Data.Map.Strict(Map)
|
||||
import qualified Data.Map.Strict as Map
|
||||
import File
|
||||
import Generators
|
||||
import GHC.Integer.GMP.Internals(recipModInteger)
|
||||
import Language.Rust.Data.Ident
|
||||
import Language.Rust.Data.Position
|
||||
import Language.Rust.Quote
|
||||
import Language.Rust.Syntax
|
||||
import System.Random(RandomGen)
|
||||
|
||||
numTestCases :: Int
|
||||
numTestCases = 100
|
||||
|
||||
generateModInvOps :: File
|
||||
generateModInvOps = File {
|
||||
predicate = \ me others -> (me + 64) `elem` others,
|
||||
outputName = "modinv",
|
||||
isUnsigned = False,
|
||||
isUnsigned = True,
|
||||
generator = declareModInv,
|
||||
testCase = Nothing
|
||||
testCase = Just generateModInvTests
|
||||
}
|
||||
|
||||
declareModInv :: Word -> [Word] -> SourceFile Span
|
||||
declareModInv bitsize _ =
|
||||
let sname = mkIdent ("I" ++ show bitsize)
|
||||
let sname = mkIdent ("I" ++ show (bitsize + 64))
|
||||
uname = mkIdent ("U" ++ show bitsize)
|
||||
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
|
||||
in [sourceFile|
|
||||
use core::convert::TryFrom;
|
||||
use crate::CryptoNum;
|
||||
use crate::signed::$$sname;
|
||||
#[cfg(test)]
|
||||
use crate::testing::{build_test_path,run_test};
|
||||
use crate::unsigned::$$uname;
|
||||
|
||||
impl $$uname {
|
||||
@@ -38,7 +51,7 @@ declareModInv bitsize _ =
|
||||
return None;
|
||||
}
|
||||
|
||||
let sphi = $$sname::try_from(phi).expect("over/underflow in modinv phi");
|
||||
let sphi = $$sname::from(phi);
|
||||
|
||||
while b.is_negative() {
|
||||
b += &sphi;
|
||||
@@ -53,8 +66,8 @@ declareModInv bitsize _ =
|
||||
|
||||
fn egcd(&self, rhs: &$$uname) -> ($$sname, $$sname, $$sname) {
|
||||
// INPUT: two positive integers x and y.
|
||||
let mut x = $$sname::try_from(self).expect("overflow in modinv base");
|
||||
let mut y = $$sname::try_from(rhs).expect("overflow in modinv rhs");
|
||||
let mut x = $$sname::from(self);
|
||||
let mut y = $$sname::from(rhs);
|
||||
// OUTPUT: integers a, b, and v such that ax + by = v,
|
||||
// where v = gcd(x, y).
|
||||
// 1. g←1.
|
||||
@@ -170,4 +183,149 @@ declareModInv bitsize _ =
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(non_snake_case)]
|
||||
#[test]
|
||||
fn KATs() {
|
||||
run_test(build_test_path("modinv", $$(testFileLit)), 6, |case| {
|
||||
let (neg0, xbytes) = case.get("x").unwrap();
|
||||
let (neg1, ybytes) = case.get("y").unwrap();
|
||||
let (neg2, zbytes) = case.get("z").unwrap();
|
||||
let (neg3, abytes) = case.get("a").unwrap();
|
||||
let (neg4, bbytes) = case.get("b").unwrap();
|
||||
let (neg5, vbytes) = case.get("v").unwrap();
|
||||
|
||||
assert!(!neg0 && !neg1 && !neg2);
|
||||
let x = $$uname::from_bytes(xbytes);
|
||||
let y = $$uname::from_bytes(ybytes);
|
||||
let z = $$uname::from_bytes(zbytes);
|
||||
let mut a = $$sname::from_bytes(abytes);
|
||||
let mut b = $$sname::from_bytes(bbytes);
|
||||
let mut v = $$sname::from_bytes(vbytes);
|
||||
|
||||
if *neg3 { a = -a }
|
||||
if *neg4 { b = -b }
|
||||
if *neg5 { v = -v }
|
||||
|
||||
let (mya, myb, myv) = x.egcd(&y);
|
||||
assert_eq!(a, mya);
|
||||
assert_eq!(b, myb);
|
||||
assert_eq!(v, myv);
|
||||
|
||||
assert_eq!(z, x.modinv(&y).expect("Didn't find a modinv?"));
|
||||
assert_eq!(v == $$sname::from(1u64), x.gcd_is_one(&y));
|
||||
});
|
||||
}
|
||||
|]
|
||||
|
||||
generateModInvTests :: RandomGen g => Word -> g -> [Map String String]
|
||||
generateModInvTests size g = go g numTestCases
|
||||
where
|
||||
go _ 0 = []
|
||||
go g0 i =
|
||||
let (x, g1) = generateNum g0 size
|
||||
(y, g2) = generateNum g1 size
|
||||
z = recipModInteger x y
|
||||
(a, b, v) = extendedGCD x y
|
||||
tcase = Map.fromList [("x", showX x), ("y", showX y),
|
||||
("z", showX z), ("a", showX a),
|
||||
("b", showX b), ("v", showX v)]
|
||||
in if z == 0
|
||||
then go g2 i
|
||||
else assert (z < y) $
|
||||
assert ((x * z) `mod` y == 1) $
|
||||
assert (((a * x) + (b * y)) == v) $
|
||||
assert (v == gcd x y) $
|
||||
tcase : go g2 (i - 1)
|
||||
|
||||
extendedGCD :: Integer -> Integer -> (Integer, Integer, Integer)
|
||||
extendedGCD x y = (a, b, g * (v finalState))
|
||||
where
|
||||
(x', y', g, initState) = initialState x y 1
|
||||
finalState = runAlgorithm x' y' initState
|
||||
a = bigC finalState
|
||||
b = bigD finalState
|
||||
|
||||
data AlgState = AlgState {
|
||||
u :: Integer,
|
||||
v :: Integer,
|
||||
bigA :: Integer,
|
||||
bigB :: Integer,
|
||||
bigC :: Integer,
|
||||
bigD :: Integer
|
||||
}
|
||||
|
||||
initialState :: Integer -> Integer -> Integer -> (Integer, Integer, Integer, AlgState)
|
||||
initialState x y g | even x && even y = initialState (x `div` 2) (y `div` 2) (g * 2)
|
||||
| otherwise = (x, y, g, AlgState x y 1 0 0 1)
|
||||
|
||||
printState :: AlgState -> IO ()
|
||||
printState a =
|
||||
do putStrLn ("u: " ++ showX (u a))
|
||||
putStrLn ("v: " ++ showX (v a))
|
||||
putStrLn ("A: " ++ showX (bigA a))
|
||||
putStrLn ("B: " ++ showX (bigB a))
|
||||
putStrLn ("C: " ++ showX (bigC a))
|
||||
putStrLn ("D: " ++ showX (bigD a))
|
||||
|
||||
runAlgorithm :: Integer -> Integer -> AlgState -> AlgState
|
||||
runAlgorithm x y state | u state == 0 = state
|
||||
| otherwise = runAlgorithm x y state6
|
||||
where
|
||||
state4 = step4 x y state
|
||||
state5 = step5 x y state4
|
||||
state6 = step6 state5
|
||||
|
||||
step4 :: Integer -> Integer -> AlgState -> AlgState
|
||||
step4 x y input@AlgState{..} | even u = step4 x y input'
|
||||
| otherwise = input
|
||||
where
|
||||
input' = AlgState u' v bigA' bigB' bigC bigD
|
||||
u' = u `div` 2
|
||||
bigA' | even bigA && even bigB = bigA `div` 2
|
||||
| otherwise = (bigA + y) `div` 2
|
||||
bigB' | even bigA && even bigB = bigB `div` 2
|
||||
| otherwise = (bigB - x) `div` 2
|
||||
|
||||
step5 :: Integer -> Integer -> AlgState -> AlgState
|
||||
step5 x y input@AlgState{..} | even v = step5 x y input'
|
||||
| otherwise = input
|
||||
where
|
||||
input' = AlgState u v' bigA bigB bigC' bigD'
|
||||
v' = v `div` 2
|
||||
bigC' | even bigC && even bigD = bigC `div` 2
|
||||
| otherwise = (bigC + y) `div` 2
|
||||
bigD' | even bigC && even bigD = bigD `div` 2
|
||||
| otherwise = (bigD - x) `div` 2
|
||||
|
||||
step6 :: AlgState -> AlgState
|
||||
step6 AlgState{..}
|
||||
| u >= v = AlgState (u - v) v (bigA - bigC) (bigB - bigD) bigC bigD
|
||||
| otherwise = AlgState u (v - u) bigA bigB (bigC - bigA) (bigD - bigB)
|
||||
|
||||
_run :: Integer -> Integer -> IO ()
|
||||
_run inputx inputy =
|
||||
do let (x, y, g, initState) = initialState inputx inputy 1
|
||||
finalState <- go x y initState
|
||||
putStrLn ("-- FINAL STATE -----------------------")
|
||||
printState finalState
|
||||
putStrLn ("Final value: " ++ showX (g * v finalState))
|
||||
putStrLn ("-- RUN ------")
|
||||
printState (runAlgorithm x y initState)
|
||||
putStrLn ("-- NORMAL ------")
|
||||
let (a, b, v) = extendedGCD inputx inputy
|
||||
putStrLn ("a: " ++ showX a)
|
||||
putStrLn ("b: " ++ showX b)
|
||||
putStrLn ("v: " ++ showX v)
|
||||
|
||||
where
|
||||
go x y state =
|
||||
do putStrLn "-- STATE -----------------------------"
|
||||
printState state
|
||||
if u state == 0
|
||||
then return state
|
||||
else do let state' = step4 x y state
|
||||
state'' = step5 x y state'
|
||||
state''' = step6 state''
|
||||
go x y state'''
|
||||
|
||||
Reference in New Issue
Block a user