Switch to IO-based tasks.

This commit is contained in:
2019-05-14 21:51:45 -07:00
parent aaa8dc3497
commit 6c61e1c56c
5 changed files with 27 additions and 25 deletions

View File

@@ -11,11 +11,9 @@ import Data.ByteString.Lazy(ByteString)
import qualified Data.ByteString.Lazy as BSL import qualified Data.ByteString.Lazy as BSL
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
import Math(showX,showBin) import Math(showX,showBin)
import Task(Task(..),Test) import Task(Task(..),liftTest)
import Utils(HashAlg(..),generateHash,showHash) import Utils(HashAlg(..),generateHash,showHash)
import Debug.Trace
dsaSizes :: [(ParameterSizes, Int)] dsaSizes :: [(ParameterSizes, Int)]
dsaSizes = [(L1024_N160, 400), dsaSizes = [(L1024_N160, 400),
(L2048_N224, 100), (L2048_N224, 100),
@@ -38,27 +36,26 @@ signTest :: ParameterSizes -> Int -> Task
signTest sz cnt = Task { signTest sz cnt = Task {
taskName = "DSA " ++ show sz ++ " signing", taskName = "DSA " ++ show sz ++ " signing",
taskFile = "../testdata/dsa/sign" ++ showParam sz ++ ".test", taskFile = "../testdata/dsa/sign" ++ showParam sz ++ ".test",
taskTest = go, taskTest = liftTest go,
taskCount = cnt taskCount = cnt
} }
where where
go :: Test
go (memory, drg0) = go (memory, drg0) =
case generateProbablePrimes sz drg0 sha256 Nothing of case generateProbablePrimes sz drg0 sha256 Nothing of
Left _ -> trace "generate primes" $ goAdvance memory drg0 Left _ -> goAdvance memory drg0
Right (p, q, _, drg1) -> Right (p, q, _, drg1) ->
case generateUnverifiableGenerator p q of case generateUnverifiableGenerator p q of
Nothing -> trace "generate g" $ goAdvance memory drg1 Nothing -> goAdvance memory drg1
Just g -> Just g ->
let params = Params p g q let params = Params p g q
in case generateKeyPairWithParams params drg1 of in case generateKeyPairWithParams params drg1 of
Left _ -> trace "generate key" $ goAdvance memory drg1 Left _ -> goAdvance memory drg1
Right (pub, priv, drg2) -> Right (pub, priv, drg2) ->
let (msg, drg3) = withDRG drg2 $ getRandomBytes =<< ((fromIntegral . BS.head) `fmap` getRandomBytes 1) let (msg, drg3) = withDRG drg2 $ getRandomBytes =<< ((fromIntegral . BS.head) `fmap` getRandomBytes 1)
(hashf, drg4) = withDRG drg3 generateHash (hashf, drg4) = withDRG drg3 generateHash
in case signMessage' (translateHash hashf) kViaRFC6979 drg4 priv (BSL.fromStrict msg) of in case signMessage' (translateHash hashf) kViaRFC6979 drg4 priv (BSL.fromStrict msg) of
Left _ -> Left _ ->
trace "sign failure" $ go (memory, drg4) go (memory, drg4)
Right (sig, drg5) -> Right (sig, drg5) ->
let res = Map.fromList [("p", showX p), let res = Map.fromList [("p", showX p),
("q", showX q), ("q", showX q),

View File

@@ -15,7 +15,7 @@ import qualified Data.ByteString as S
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
import Math(showX,showBin) import Math(showX,showBin)
import RFC6979(generateKStream) import RFC6979(generateKStream)
import Task(Task(..)) import Task(Task(..),liftTest)
import Utils(HashAlg(..),generateHash,runHash,showHash) import Utils(HashAlg(..),generateHash,runHash,showHash)
curves :: [(String, Curve)] curves :: [(String, Curve)]
@@ -29,7 +29,7 @@ negateTest :: String -> Curve -> Task
negateTest name curve = Task { negateTest name curve = Task {
taskName = name ++ " point negation", taskName = name ++ " point negation",
taskFile = "../testdata/ecc/negate/" ++ name ++ ".test", taskFile = "../testdata/ecc/negate/" ++ name ++ ".test",
taskTest = go, taskTest = liftTest go,
taskCount = 1000 taskCount = 1000
} }
where where
@@ -49,7 +49,7 @@ doubleTest :: String -> Curve -> Task
doubleTest name curve = Task { doubleTest name curve = Task {
taskName = name ++ " point doubling", taskName = name ++ " point doubling",
taskFile = "../testdata/ecc/double/" ++ name ++ ".test", taskFile = "../testdata/ecc/double/" ++ name ++ ".test",
taskTest = go, taskTest = liftTest go,
taskCount = 1000 taskCount = 1000
} }
where where
@@ -69,7 +69,7 @@ addTest :: String -> Curve -> Task
addTest name curve = Task { addTest name curve = Task {
taskName = name ++ " point addition", taskName = name ++ " point addition",
taskFile = "../testdata/ecc/add/" ++ name ++ ".test", taskFile = "../testdata/ecc/add/" ++ name ++ ".test",
taskTest = go, taskTest = liftTest go,
taskCount = 1000 taskCount = 1000
} }
where where
@@ -92,7 +92,7 @@ scaleTest :: String -> Curve -> Task
scaleTest name curve = Task { scaleTest name curve = Task {
taskName = name ++ " point scaling", taskName = name ++ " point scaling",
taskFile = "../testdata/ecc/scale/" ++ name ++ ".test", taskFile = "../testdata/ecc/scale/" ++ name ++ ".test",
taskTest = go, taskTest = liftTest go,
taskCount = 1000 taskCount = 1000
} }
where where
@@ -117,7 +117,7 @@ addScaleTest :: String -> Curve -> Task
addScaleTest name curve = Task { addScaleTest name curve = Task {
taskName = name ++ " point addition of two scalings", taskName = name ++ " point addition of two scalings",
taskFile = "../testdata/ecc/add_scale2/" ++ name ++ ".test", taskFile = "../testdata/ecc/add_scale2/" ++ name ++ ".test",
taskTest = go, taskTest = liftTest go,
taskCount = 1000 taskCount = 1000
} }
where where
@@ -144,7 +144,7 @@ signTest :: String -> Curve -> Task
signTest name curve = Task { signTest name curve = Task {
taskName = name ++ " curve signing", taskName = name ++ " curve signing",
taskFile = "../testdata/ecc/sign/" ++ name ++ ".test", taskFile = "../testdata/ecc/sign/" ++ name ++ ".test",
taskTest = go, taskTest = liftTest go,
taskCount = 1000 taskCount = 1000
} }
where where

View File

@@ -15,7 +15,7 @@ import qualified Data.ByteString as S
import Data.Char(toUpper) import Data.Char(toUpper)
import qualified Data.Map.Strict as Map import qualified Data.Map.Strict as Map
import Math(showBin,showX) import Math(showBin,showX)
import Task(Task(..)) import Task(Task(..),liftTest)
import Utils(HashAlg(..), runHash) import Utils(HashAlg(..), runHash)
@@ -88,7 +88,7 @@ rfc6979Test :: HashAlg -> Task
rfc6979Test alg = Task { rfc6979Test alg = Task {
taskName = name ++ " RFC 6979 deterministic k-generation", taskName = name ++ " RFC 6979 deterministic k-generation",
taskFile = "../testdata/rfc6979/" ++ name ++ ".test", taskFile = "../testdata/rfc6979/" ++ name ++ ".test",
taskTest = go, taskTest = liftTest go,
taskCount = 1000 taskCount = 1000
} }
where where

View File

@@ -18,7 +18,7 @@ import Data.Maybe(fromMaybe,isJust)
import Data.Word(Word8) import Data.Word(Word8)
import Database(Database) import Database(Database)
import Math(barrett,computeK,showX,showBin) import Math(barrett,computeK,showX,showBin)
import Task(Task(..)) import Task(Task(..),liftTest)
import Utils(HashAlg(..),generateHash,showHash) import Utils(HashAlg(..),generateHash,showHash)
rsaSizes :: [(Int, Int)] rsaSizes :: [(Int, Int)]
@@ -40,7 +40,7 @@ signTest :: Int -> Int -> Task
signTest sz cnt = Task { signTest sz cnt = Task {
taskName = "RSA " ++ show sz ++ " signing", taskName = "RSA " ++ show sz ++ " signing",
taskFile = "../testdata/rsa/sign" ++ show sz ++ ".test", taskFile = "../testdata/rsa/sign" ++ show sz ++ ".test",
taskTest = go, taskTest = liftTest go,
taskCount = cnt taskCount = cnt
} }
where where
@@ -78,7 +78,7 @@ encryptTest :: Int -> Int -> Task
encryptTest sz cnt = Task { encryptTest sz cnt = Task {
taskName = "RSA " ++ show sz ++ " encryption", taskName = "RSA " ++ show sz ++ " encryption",
taskFile = "../testdata/rsa/encrypt" ++ show sz ++ ".test", taskFile = "../testdata/rsa/encrypt" ++ show sz ++ ".test",
taskTest = go, taskTest = liftTest go,
taskCount = cnt taskCount = cnt
} }
where where

View File

@@ -2,7 +2,8 @@
module Task( module Task(
Test, Test,
Task(..), Task(..),
runTask runTask,
liftTest
) )
where where
@@ -15,7 +16,7 @@ import System.Directory(createDirectoryIfMissing,doesFileExist)
import System.FilePath(takeDirectory) import System.FilePath(takeDirectory)
import System.IO(Handle,IOMode(..),hPutStrLn,withFile) import System.IO(Handle,IOMode(..),hPutStrLn,withFile)
type Test = Database -> (Map.Map String String, Integer, Database) type Test = Database -> IO (Map.Map String String, Integer, Database)
data Task = Task { data Task = Task {
taskName :: String, taskName :: String,
@@ -24,6 +25,10 @@ data Task = Task {
taskCount :: Int taskCount :: Int
} }
liftTest :: (Database -> (Map.Map String String, Integer, Database)) ->
(Database -> IO (Map.Map String String, Integer, Database))
liftTest f db = return (f db)
runTask :: SystemRandom -> Task -> IO SystemRandom runTask :: SystemRandom -> Task -> IO SystemRandom
runTask gen task = runTask gen task =
do createDirectoryIfMissing True (takeDirectory (taskFile task)) do createDirectoryIfMissing True (takeDirectory (taskFile task))
@@ -40,8 +45,8 @@ runTask gen task =
where where
writer :: Handle -> ProgressBar -> Test -> Database -> Int -> IO Database writer :: Handle -> ProgressBar -> Test -> Database -> Int -> IO Database
writer hndl pg runner db x = writer hndl pg runner db x =
do let (output, key, acc@(db',gen')) = runner db do (output, key, acc@(db',gen')) <- runner db
before = Map.findWithDefault [] "RESULT" db' let before = Map.findWithDefault [] "RESULT" db'
if length (filter (== key) before) >= 10 if length (filter (== key) before) >= 10
then writer hndl pg runner acc x then writer hndl pg runner acc x
else do forM_ (Map.toList output) $ \ (outkey, val) -> else do forM_ (Map.toList output) $ \ (outkey, val) ->