diff --git a/test-generator/DSA.hs b/test-generator/DSA.hs index 595db29..4e8d09b 100644 --- a/test-generator/DSA.hs +++ b/test-generator/DSA.hs @@ -11,11 +11,9 @@ import Data.ByteString.Lazy(ByteString) import qualified Data.ByteString.Lazy as BSL import qualified Data.Map.Strict as Map import Math(showX,showBin) -import Task(Task(..),Test) +import Task(Task(..),liftTest) import Utils(HashAlg(..),generateHash,showHash) -import Debug.Trace - dsaSizes :: [(ParameterSizes, Int)] dsaSizes = [(L1024_N160, 400), (L2048_N224, 100), @@ -38,27 +36,26 @@ signTest :: ParameterSizes -> Int -> Task signTest sz cnt = Task { taskName = "DSA " ++ show sz ++ " signing", taskFile = "../testdata/dsa/sign" ++ showParam sz ++ ".test", - taskTest = go, + taskTest = liftTest go, taskCount = cnt } where - go :: Test go (memory, drg0) = case generateProbablePrimes sz drg0 sha256 Nothing of - Left _ -> trace "generate primes" $ goAdvance memory drg0 + Left _ -> goAdvance memory drg0 Right (p, q, _, drg1) -> case generateUnverifiableGenerator p q of - Nothing -> trace "generate g" $ goAdvance memory drg1 + Nothing -> goAdvance memory drg1 Just g -> let params = Params p g q in case generateKeyPairWithParams params drg1 of - Left _ -> trace "generate key" $ goAdvance memory drg1 + Left _ -> goAdvance memory drg1 Right (pub, priv, drg2) -> let (msg, drg3) = withDRG drg2 $ getRandomBytes =<< ((fromIntegral . BS.head) `fmap` getRandomBytes 1) (hashf, drg4) = withDRG drg3 generateHash in case signMessage' (translateHash hashf) kViaRFC6979 drg4 priv (BSL.fromStrict msg) of Left _ -> - trace "sign failure" $ go (memory, drg4) + go (memory, drg4) Right (sig, drg5) -> let res = Map.fromList [("p", showX p), ("q", showX q), diff --git a/test-generator/ECDSATesting.hs b/test-generator/ECDSATesting.hs index 5b1f16d..6955f74 100644 --- a/test-generator/ECDSATesting.hs +++ b/test-generator/ECDSATesting.hs @@ -15,7 +15,7 @@ import qualified Data.ByteString as S import qualified Data.Map.Strict as Map import Math(showX,showBin) import RFC6979(generateKStream) -import Task(Task(..)) +import Task(Task(..),liftTest) import Utils(HashAlg(..),generateHash,runHash,showHash) curves :: [(String, Curve)] @@ -29,7 +29,7 @@ negateTest :: String -> Curve -> Task negateTest name curve = Task { taskName = name ++ " point negation", taskFile = "../testdata/ecc/negate/" ++ name ++ ".test", - taskTest = go, + taskTest = liftTest go, taskCount = 1000 } where @@ -49,7 +49,7 @@ doubleTest :: String -> Curve -> Task doubleTest name curve = Task { taskName = name ++ " point doubling", taskFile = "../testdata/ecc/double/" ++ name ++ ".test", - taskTest = go, + taskTest = liftTest go, taskCount = 1000 } where @@ -69,7 +69,7 @@ addTest :: String -> Curve -> Task addTest name curve = Task { taskName = name ++ " point addition", taskFile = "../testdata/ecc/add/" ++ name ++ ".test", - taskTest = go, + taskTest = liftTest go, taskCount = 1000 } where @@ -92,7 +92,7 @@ scaleTest :: String -> Curve -> Task scaleTest name curve = Task { taskName = name ++ " point scaling", taskFile = "../testdata/ecc/scale/" ++ name ++ ".test", - taskTest = go, + taskTest = liftTest go, taskCount = 1000 } where @@ -117,7 +117,7 @@ addScaleTest :: String -> Curve -> Task addScaleTest name curve = Task { taskName = name ++ " point addition of two scalings", taskFile = "../testdata/ecc/add_scale2/" ++ name ++ ".test", - taskTest = go, + taskTest = liftTest go, taskCount = 1000 } where @@ -144,7 +144,7 @@ signTest :: String -> Curve -> Task signTest name curve = Task { taskName = name ++ " curve signing", taskFile = "../testdata/ecc/sign/" ++ name ++ ".test", - taskTest = go, + taskTest = liftTest go, taskCount = 1000 } where diff --git a/test-generator/RFC6979.hs b/test-generator/RFC6979.hs index 2d4e7cc..160daa3 100644 --- a/test-generator/RFC6979.hs +++ b/test-generator/RFC6979.hs @@ -15,7 +15,7 @@ import qualified Data.ByteString as S import Data.Char(toUpper) import qualified Data.Map.Strict as Map import Math(showBin,showX) -import Task(Task(..)) +import Task(Task(..),liftTest) import Utils(HashAlg(..), runHash) @@ -88,7 +88,7 @@ rfc6979Test :: HashAlg -> Task rfc6979Test alg = Task { taskName = name ++ " RFC 6979 deterministic k-generation", taskFile = "../testdata/rfc6979/" ++ name ++ ".test", - taskTest = go, + taskTest = liftTest go, taskCount = 1000 } where diff --git a/test-generator/RSA.hs b/test-generator/RSA.hs index 38775f8..0785ad4 100644 --- a/test-generator/RSA.hs +++ b/test-generator/RSA.hs @@ -18,7 +18,7 @@ import Data.Maybe(fromMaybe,isJust) import Data.Word(Word8) import Database(Database) import Math(barrett,computeK,showX,showBin) -import Task(Task(..)) +import Task(Task(..),liftTest) import Utils(HashAlg(..),generateHash,showHash) rsaSizes :: [(Int, Int)] @@ -40,7 +40,7 @@ signTest :: Int -> Int -> Task signTest sz cnt = Task { taskName = "RSA " ++ show sz ++ " signing", taskFile = "../testdata/rsa/sign" ++ show sz ++ ".test", - taskTest = go, + taskTest = liftTest go, taskCount = cnt } where @@ -78,7 +78,7 @@ encryptTest :: Int -> Int -> Task encryptTest sz cnt = Task { taskName = "RSA " ++ show sz ++ " encryption", taskFile = "../testdata/rsa/encrypt" ++ show sz ++ ".test", - taskTest = go, + taskTest = liftTest go, taskCount = cnt } where diff --git a/test-generator/Task.hs b/test-generator/Task.hs index aa4c3fc..bc679e4 100644 --- a/test-generator/Task.hs +++ b/test-generator/Task.hs @@ -2,7 +2,8 @@ module Task( Test, Task(..), - runTask + runTask, + liftTest ) where @@ -15,7 +16,7 @@ import System.Directory(createDirectoryIfMissing,doesFileExist) import System.FilePath(takeDirectory) import System.IO(Handle,IOMode(..),hPutStrLn,withFile) -type Test = Database -> (Map.Map String String, Integer, Database) +type Test = Database -> IO (Map.Map String String, Integer, Database) data Task = Task { taskName :: String, @@ -24,6 +25,10 @@ data Task = Task { taskCount :: Int } +liftTest :: (Database -> (Map.Map String String, Integer, Database)) -> + (Database -> IO (Map.Map String String, Integer, Database)) +liftTest f db = return (f db) + runTask :: SystemRandom -> Task -> IO SystemRandom runTask gen task = do createDirectoryIfMissing True (takeDirectory (taskFile task)) @@ -40,8 +45,8 @@ runTask gen task = where writer :: Handle -> ProgressBar -> Test -> Database -> Int -> IO Database writer hndl pg runner db x = - do let (output, key, acc@(db',gen')) = runner db - before = Map.findWithDefault [] "RESULT" db' + do (output, key, acc@(db',gen')) <- runner db + let before = Map.findWithDefault [] "RESULT" db' if length (filter (== key) before) >= 10 then writer hndl pg runner acc x else do forM_ (Map.toList output) $ \ (outkey, val) ->