diff --git a/generation/Main.hs b/generation/Main.hs index a007b75..7b40ae4 100644 --- a/generation/Main.hs +++ b/generation/Main.hs @@ -47,6 +47,7 @@ unsignedFiles = [ , conversions , cryptoNum , divisionOps + , generateModInvOps , modulusOps , safeAddOps , safeMultiplyOps @@ -61,8 +62,7 @@ unsignedFiles = [ signedFiles :: [File] signedFiles = [ - generateModInvOps - , safeSignedAddOps + safeSignedAddOps , safeSignedSubtractOps , signedBaseOps , signedComparisons @@ -118,4 +118,4 @@ main = chan <- newMVar allTasks count <- getNumCapabilities threads <- replicateM count (runThread pb (head args) chan) - forM_ threads (\ m -> takeMVar m) \ No newline at end of file + forM_ threads (\ m -> takeMVar m) diff --git a/generation/src/Conversions.hs b/generation/src/Conversions.hs index c038d73..3871f77 100644 --- a/generation/src/Conversions.hs +++ b/generation/src/Conversions.hs @@ -98,7 +98,6 @@ declareSignedConversions :: Word -> [Word] -> SourceFile Span declareSignedConversions bitsize otherSizes = let sname = mkIdent ("I" ++ show bitsize) uname = mkIdent ("U" ++ show bitsize) - entries = bitsize `div` 64 u8_prims = buildUSPrimitives sname (mkIdent "u8") u16_prims = buildUSPrimitives sname (mkIdent "u16") u32_prims = buildUSPrimitives sname (mkIdent "u32") @@ -109,7 +108,7 @@ declareSignedConversions bitsize otherSizes = i32_prims = buildSSPrimitives sname uname (mkIdent "i32") i64_prims = buildSSPrimitives sname uname (mkIdent "i64") isz_prims = buildSSPrimitives sname uname (mkIdent "isize") - s128_prims = generateS128Primitives sname uname entries + s128_prims = generateS128Primitives sname uname others = generateSignedCryptonumConversions bitsize otherSizes in [sourceFile| use core::convert::{From,TryFrom}; @@ -482,8 +481,8 @@ buildSSPrimitives sname uname prim = [ |] ] -generateS128Primitives :: Ident -> Ident -> Word -> [Item Span] -generateS128Primitives sname uname entries = [ +generateS128Primitives :: Ident -> Ident -> [Item Span] +generateS128Primitives sname uname = [ [item| impl From for $$sname { fn from(x: u128) -> $$sname { @@ -560,7 +559,8 @@ generateS128Primitives sname uname entries = [ generateSignedCryptonumConversions :: Word -> [Word] -> [Item Span] generateSignedCryptonumConversions source otherSizes = concatMap convert otherSizes where - sName = mkIdent ("I" ++ show source) + suName = mkIdent ("U" ++ show source) + ssName = mkIdent ("I" ++ show source) -- convert target = let tsName = mkIdent ("I" ++ show target) @@ -575,8 +575,8 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi in case compare source target of LT -> [ [item| - impl<'a> From<&'a $$sName> for $$tsName { - fn from(x: &$$sName) -> $$tsName { + impl<'a> From<&'a $$ssName> for $$tsName { + fn from(x: &$$ssName) -> $$tsName { let mut res = $$tsName::zero(); res.contents.value[0..$$(sEntries)].copy_from_slice(&x.contents.value); let extension = if x.contents.value[$$(sTop)] & 0x8000_0000_0000_0000 == 0 { @@ -590,18 +590,32 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi } |], [item| - impl From<$$sName> for $$tsName { - fn from(x: $$sName) -> $$tsName { + impl From<$$ssName> for $$tsName { + fn from(x: $$ssName) -> $$tsName { $$tsName::from(&x) } } |], [item| - impl<'a> TryFrom<&'a $$sName> for $$tuName { + impl<'a> From<&'a $$suName> for $$tsName { + fn from(x: &$$suName) -> $$tsName { + $$tsName{ contents: $$tuName::from(x) } + } + } + |], + [item| + impl From<$$suName> for $$tsName { + fn from(x: $$suName) -> $$tsName { + $$tsName{ contents: $$tuName::from(x) } + } + } + |], + [item| + impl<'a> TryFrom<&'a $$ssName> for $$tuName { type Error = ConversionError; - fn try_from(x: &$$sName) -> Result<$$tuName,ConversionError> { + fn try_from(x: &$$ssName) -> Result<$$tuName,ConversionError> { if x.is_negative() { Err(ConversionError::NegativeToUnsigned) } else { @@ -611,10 +625,10 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi } |], [item| - impl TryFrom<$$sName> for $$tuName { + impl TryFrom<$$ssName> for $$tuName { type Error = ConversionError; - fn try_from(x: $$sName) -> Result<$$tuName,ConversionError> { + fn try_from(x: $$ssName) -> Result<$$tuName,ConversionError> { $$tuName::try_from(&x) } } @@ -622,11 +636,11 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi ] EQ -> [ [item| - impl TryFrom<$$tuName> for $$sName { + impl TryFrom<$$tuName> for $$ssName { type Error = ConversionError; - fn try_from(x: $$tuName) -> Result<$$sName,ConversionError> { - let res = $$sName{ contents: x }; + fn try_from(x: $$tuName) -> Result<$$ssName,ConversionError> { + let res = $$ssName{ contents: x }; if res.is_negative() { return Err(ConversionError::Overflow); @@ -637,19 +651,19 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi } |], [item| - impl<'a> TryFrom<&'a $$tuName> for $$sName { + impl<'a> TryFrom<&'a $$tuName> for $$ssName { type Error = ConversionError; - fn try_from(x: &$$tuName) -> Result<$$sName,ConversionError> { - $$sName::try_from(x.clone()) + fn try_from(x: &$$tuName) -> Result<$$ssName,ConversionError> { + $$ssName::try_from(x.clone()) } } |], [item| - impl TryFrom<$$sName> for $$tuName { + impl TryFrom<$$ssName> for $$tuName { type Error = ConversionError; - fn try_from(x: $$sName) -> Result<$$tuName,ConversionError> { + fn try_from(x: $$ssName) -> Result<$$tuName,ConversionError> { if x.is_negative() { return Err(ConversionError::Overflow); } @@ -658,15 +672,66 @@ generateSignedCryptonumConversions source otherSizes = concatMap convert otherSi } |], [item| - impl<'a> TryFrom<&'a $$sName> for $$tuName { + impl<'a> TryFrom<&'a $$ssName> for $$tuName { type Error = ConversionError; - fn try_from(x: &$$sName) -> Result<$$tuName,ConversionError> { + fn try_from(x: &$$ssName) -> Result<$$tuName,ConversionError> { $$tuName::try_from(x.clone()) } } |] ] GT -> [ + [item| + impl<'a> TryFrom<&'a $$ssName> for $$tsName { + type Error = ConversionError; + + fn try_from(x: &$$ssName) -> Result<$$tsName,ConversionError> { + let required_top = if x.is_negative() { + 0xFFFF_FFFF_FFFF_FFFF + } else { + 0 + }; + if x.contents.value.iter().skip($$(tEntries)).all(|x| *x == required_top) { + let mut res = $$tsName::zero(); + res.contents.value.copy_from_slice(&x.contents.value[0..$$(tEntries)]); + Ok(res) + } else { + Err(ConversionError::Overflow) + } + } + } + |], + [item| + impl TryFrom<$$ssName> for $$tsName { + type Error = ConversionError; + + fn try_from(x: $$ssName) -> Result<$$tsName,ConversionError> { + $$tsName::try_from(&x) + } + } + |], + [item| + impl<'a> TryFrom<&'a $$ssName> for $$tuName { + type Error = ConversionError; + + fn try_from(x: &$$ssName) -> Result<$$tuName,ConversionError> { + if x.is_negative() { + Err(ConversionError::NegativeToUnsigned) + } else { + $$tuName::try_from(&x.contents) + } + } + } + |], + [item| + impl TryFrom<$$ssName> for $$tuName { + type Error = ConversionError; + + fn try_from(x: $$ssName) -> Result<$$tuName,ConversionError> { + $$tuName::try_from(&x) + } + } + |] ] diff --git a/generation/src/ModInv.hs b/generation/src/ModInv.hs index 8f10f7d..b0fdf96 100644 --- a/generation/src/ModInv.hs +++ b/generation/src/ModInv.hs @@ -1,32 +1,45 @@ -{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE RecordWildCards #-} module ModInv( generateModInvOps ) where +import Control.Exception(assert) +import Data.Map.Strict(Map) +import qualified Data.Map.Strict as Map import File +import Generators +import GHC.Integer.GMP.Internals(recipModInteger) import Language.Rust.Data.Ident import Language.Rust.Data.Position import Language.Rust.Quote import Language.Rust.Syntax +import System.Random(RandomGen) + +numTestCases :: Int +numTestCases = 100 generateModInvOps :: File generateModInvOps = File { predicate = \ me others -> (me + 64) `elem` others, outputName = "modinv", - isUnsigned = False, + isUnsigned = True, generator = declareModInv, - testCase = Nothing + testCase = Just generateModInvTests } declareModInv :: Word -> [Word] -> SourceFile Span declareModInv bitsize _ = - let sname = mkIdent ("I" ++ show bitsize) + let sname = mkIdent ("I" ++ show (bitsize + 64)) uname = mkIdent ("U" ++ show bitsize) + testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty in [sourceFile| use core::convert::TryFrom; use crate::CryptoNum; use crate::signed::$$sname; + #[cfg(test)] + use crate::testing::{build_test_path,run_test}; use crate::unsigned::$$uname; impl $$uname { @@ -38,7 +51,7 @@ declareModInv bitsize _ = return None; } - let sphi = $$sname::try_from(phi).expect("over/underflow in modinv phi"); + let sphi = $$sname::from(phi); while b.is_negative() { b += &sphi; @@ -53,8 +66,8 @@ declareModInv bitsize _ = fn egcd(&self, rhs: &$$uname) -> ($$sname, $$sname, $$sname) { // INPUT: two positive integers x and y. - let mut x = $$sname::try_from(self).expect("overflow in modinv base"); - let mut y = $$sname::try_from(rhs).expect("overflow in modinv rhs"); + let mut x = $$sname::from(self); + let mut y = $$sname::from(rhs); // OUTPUT: integers a, b, and v such that ax + by = v, // where v = gcd(x, y). // 1. g←1. @@ -170,4 +183,149 @@ declareModInv bitsize _ = } } } - |] \ No newline at end of file + + #[cfg(test)] + #[allow(non_snake_case)] + #[test] + fn KATs() { + run_test(build_test_path("modinv", $$(testFileLit)), 6, |case| { + let (neg0, xbytes) = case.get("x").unwrap(); + let (neg1, ybytes) = case.get("y").unwrap(); + let (neg2, zbytes) = case.get("z").unwrap(); + let (neg3, abytes) = case.get("a").unwrap(); + let (neg4, bbytes) = case.get("b").unwrap(); + let (neg5, vbytes) = case.get("v").unwrap(); + + assert!(!neg0 && !neg1 && !neg2); + let x = $$uname::from_bytes(xbytes); + let y = $$uname::from_bytes(ybytes); + let z = $$uname::from_bytes(zbytes); + let mut a = $$sname::from_bytes(abytes); + let mut b = $$sname::from_bytes(bbytes); + let mut v = $$sname::from_bytes(vbytes); + + if *neg3 { a = -a } + if *neg4 { b = -b } + if *neg5 { v = -v } + + let (mya, myb, myv) = x.egcd(&y); + assert_eq!(a, mya); + assert_eq!(b, myb); + assert_eq!(v, myv); + + assert_eq!(z, x.modinv(&y).expect("Didn't find a modinv?")); + assert_eq!(v == $$sname::from(1u64), x.gcd_is_one(&y)); + }); + } + |] + +generateModInvTests :: RandomGen g => Word -> g -> [Map String String] +generateModInvTests size g = go g numTestCases + where + go _ 0 = [] + go g0 i = + let (x, g1) = generateNum g0 size + (y, g2) = generateNum g1 size + z = recipModInteger x y + (a, b, v) = extendedGCD x y + tcase = Map.fromList [("x", showX x), ("y", showX y), + ("z", showX z), ("a", showX a), + ("b", showX b), ("v", showX v)] + in if z == 0 + then go g2 i + else assert (z < y) $ + assert ((x * z) `mod` y == 1) $ + assert (((a * x) + (b * y)) == v) $ + assert (v == gcd x y) $ + tcase : go g2 (i - 1) + +extendedGCD :: Integer -> Integer -> (Integer, Integer, Integer) +extendedGCD x y = (a, b, g * (v finalState)) + where + (x', y', g, initState) = initialState x y 1 + finalState = runAlgorithm x' y' initState + a = bigC finalState + b = bigD finalState + +data AlgState = AlgState { + u :: Integer, + v :: Integer, + bigA :: Integer, + bigB :: Integer, + bigC :: Integer, + bigD :: Integer +} + +initialState :: Integer -> Integer -> Integer -> (Integer, Integer, Integer, AlgState) +initialState x y g | even x && even y = initialState (x `div` 2) (y `div` 2) (g * 2) + | otherwise = (x, y, g, AlgState x y 1 0 0 1) + +printState :: AlgState -> IO () +printState a = + do putStrLn ("u: " ++ showX (u a)) + putStrLn ("v: " ++ showX (v a)) + putStrLn ("A: " ++ showX (bigA a)) + putStrLn ("B: " ++ showX (bigB a)) + putStrLn ("C: " ++ showX (bigC a)) + putStrLn ("D: " ++ showX (bigD a)) + +runAlgorithm :: Integer -> Integer -> AlgState -> AlgState +runAlgorithm x y state | u state == 0 = state + | otherwise = runAlgorithm x y state6 + where + state4 = step4 x y state + state5 = step5 x y state4 + state6 = step6 state5 + +step4 :: Integer -> Integer -> AlgState -> AlgState +step4 x y input@AlgState{..} | even u = step4 x y input' + | otherwise = input + where + input' = AlgState u' v bigA' bigB' bigC bigD + u' = u `div` 2 + bigA' | even bigA && even bigB = bigA `div` 2 + | otherwise = (bigA + y) `div` 2 + bigB' | even bigA && even bigB = bigB `div` 2 + | otherwise = (bigB - x) `div` 2 + +step5 :: Integer -> Integer -> AlgState -> AlgState +step5 x y input@AlgState{..} | even v = step5 x y input' + | otherwise = input + where + input' = AlgState u v' bigA bigB bigC' bigD' + v' = v `div` 2 + bigC' | even bigC && even bigD = bigC `div` 2 + | otherwise = (bigC + y) `div` 2 + bigD' | even bigC && even bigD = bigD `div` 2 + | otherwise = (bigD - x) `div` 2 + +step6 :: AlgState -> AlgState +step6 AlgState{..} + | u >= v = AlgState (u - v) v (bigA - bigC) (bigB - bigD) bigC bigD + | otherwise = AlgState u (v - u) bigA bigB (bigC - bigA) (bigD - bigB) + +_run :: Integer -> Integer -> IO () +_run inputx inputy = + do let (x, y, g, initState) = initialState inputx inputy 1 + finalState <- go x y initState + putStrLn ("-- FINAL STATE -----------------------") + printState finalState + putStrLn ("Final value: " ++ showX (g * v finalState)) + putStrLn ("-- RUN ------") + printState (runAlgorithm x y initState) + putStrLn ("-- NORMAL ------") + let (a, b, v) = extendedGCD inputx inputy + putStrLn ("a: " ++ showX a) + putStrLn ("b: " ++ showX b) + putStrLn ("v: " ++ showX v) + + where + go x y state = + do putStrLn "-- STATE -----------------------------" + printState state + if u state == 0 + then return state + else do let state' = step4 x y state + state'' = step5 x y state' + state''' = step6 state'' + go x y state'''