diff --git a/generation/src/BinaryOps.hs b/generation/src/BinaryOps.hs index f2cd991..ff7b2f2 100644 --- a/generation/src/BinaryOps.hs +++ b/generation/src/BinaryOps.hs @@ -4,12 +4,20 @@ module BinaryOps( ) where +import Data.Bits(xor,(.&.),(.|.)) +import Data.Map.Strict(Map) +import qualified Data.Map.Strict as Map import File import Gen(toLit) +import Generators import Language.Rust.Data.Ident import Language.Rust.Data.Position import Language.Rust.Quote import Language.Rust.Syntax +import System.Random(RandomGen) + +numTestCases :: Int +numTestCases = 3000 binaryOps :: File binaryOps = File { @@ -17,7 +25,7 @@ binaryOps = File { outputName = "binary", isUnsigned = True, generator = declareBinaryOperators, - testCase = Nothing + testCase = Just generateTests } declareBinaryOperators :: Word -> SourceFile Span @@ -29,6 +37,7 @@ declareBinaryOperators bitsize = xorOps = generateBinOps "BitXor" struct_name "bitxor" BitXorOp entries baseNegationStmts = negationStatements "self" entries refNegationStmts = negationStatements "output" entries + testFileLit = Lit [] (Str (testFile bitsize) Cooked Unsuffixed mempty) mempty in [sourceFile| use core::ops::{BitAnd,BitAndAssign}; use core::ops::{BitOr,BitOrAssign}; @@ -37,6 +46,8 @@ declareBinaryOperators bitsize = #[cfg(test)] use crate::CryptoNum; #[cfg(test)] + use crate::testing::{build_test_path,run_test}; + #[cfg(test)] use quickcheck::quickcheck; use super::$$struct_name; @@ -109,6 +120,33 @@ declareBinaryOperators bitsize = (&a | $$struct_name::zero()) == a } } + + #[cfg(test)] + #[allow(non_snake_case)] + #[test] + fn KATs() { + run_test(build_test_path("binary", $$(testFileLit)), 6, |case| { + let (neg0, xbytes) = case.get("x").unwrap(); + let (neg1, ybytes) = case.get("y").unwrap(); + let (neg2, abytes) = case.get("a").unwrap(); + let (neg3, obytes) = case.get("o").unwrap(); + let (neg4, ebytes) = case.get("e").unwrap(); + let (neg5, nbytes) = case.get("n").unwrap(); + + assert!(!neg0 && !neg1 && !neg2 && !neg3 && !neg4 && !neg5); + let x = $$struct_name::from_bytes(&xbytes); + let y = $$struct_name::from_bytes(&ybytes); + let a = $$struct_name::from_bytes(&abytes); + let o = $$struct_name::from_bytes(&obytes); + let e = $$struct_name::from_bytes(&ebytes); + let n = $$struct_name::from_bytes(&nbytes); + + assert_eq!(a, &x & &y); + assert_eq!(o, &x | &y); + assert_eq!(e, &x ^ &y); + assert_eq!(n, !x); + }); + } |] negationStatements :: String -> Word -> [Stmt Span] @@ -201,3 +239,17 @@ generateAllTheVariants traitname func sname oper entries = [ Semi (AssignOp [] oper [expr| $$(left).value[$$(i)] |] [expr| $$(right).value[$$(i)] |] mempty) mempty + +generateTests :: RandomGen g => Word -> g -> [Map String String] +generateTests size g = go g numTestCases + where + go _ 0 = [] + go g0 i = + let (x, g1) = generateNum g0 size + (y, g2) = generateNum g1 size + tcase = Map.fromList [("x", showX x), ("y", showX y), + ("a", showX (x .&. y)), + ("o", showX (x .|. y)), + ("e", showX (x `xor` y)), + ("n", showX ( ((2 ^ size) - 1) `xor` x ))] + in tcase : go g2 (i - 1) diff --git a/generation/src/Compare.hs b/generation/src/Compare.hs index 2c5d189..9ebd39a 100644 --- a/generation/src/Compare.hs +++ b/generation/src/Compare.hs @@ -2,11 +2,18 @@ module Compare(comparisons) where +import Data.Map.Strict(Map) +import qualified Data.Map.Strict as Map import File +import Generators import Language.Rust.Data.Ident import Language.Rust.Data.Position import Language.Rust.Quote import Language.Rust.Syntax +import System.Random(RandomGen) + +numTestCases :: Int +numTestCases = 3000 comparisons :: File comparisons = File { @@ -14,7 +21,7 @@ comparisons = File { outputName = "compare", isUnsigned = True, generator = declareComparators, - testCase = Nothing + testCase = Just generateTests } declareComparators :: Word -> SourceFile Span @@ -23,9 +30,14 @@ declareComparators bitsize = entries = bitsize `div` 64 eqStatements = buildEqStatements 0 entries compareExp = buildCompareExp 0 entries + testFileLit = Lit [] (Str (testFile bitsize) Cooked Unsuffixed mempty) mempty in [sourceFile| use core::cmp::{Eq,Ordering,PartialEq}; #[cfg(test)] + use crate::CryptoNum; + #[cfg(test)] + use crate::testing::{build_test_path,run_test}; + #[cfg(test)] use quickcheck::quickcheck; use super::$$sname; @@ -73,7 +85,41 @@ declareComparators bitsize = if a <= b && b <= c { a <= c } else { true } } } - |] + + #[cfg(test)] + #[allow(non_snake_case)] + #[test] + fn KATs() { + run_test(build_test_path("compare", $$(testFileLit)), 8, |case| { + let (neg0, xbytes) = case.get("x").unwrap(); + let (neg1, ybytes) = case.get("y").unwrap(); + let (neg2, ebytes) = case.get("e").unwrap(); + let (neg3, nbytes) = case.get("n").unwrap(); + let (neg4, gbytes) = case.get("g").unwrap(); + let (neg5, hbytes) = case.get("h").unwrap(); + let (neg6, lbytes) = case.get("l").unwrap(); + let (neg7, kbytes) = case.get("k").unwrap(); + + assert!(!neg0 && !neg1 && !neg2 && !neg3 && + !neg4 && !neg5 && !neg6 && !neg7); + let x = $$sname::from_bytes(&xbytes); + let y = $$sname::from_bytes(&ybytes); + let e = 1 == ebytes[0]; + let n = 1 == nbytes[0]; + let g = 1 == gbytes[0]; + let h = 1 == hbytes[0]; + let l = 1 == lbytes[0]; + let k = 1 == kbytes[0]; + + assert_eq!(e, x == y); + assert_eq!(n, x != y); + assert_eq!(g, x > y); + assert_eq!(h, x >= y); + assert_eq!(l, x < y); + assert_eq!(k, x <= y); + }); + } + |] buildEqStatements :: Word -> Word -> [Stmt Span] buildEqStatements i numEntries @@ -95,3 +141,19 @@ buildCompareExp i numEntries in [expr| $$(rest).then(self.value[$$(x)].cmp(&other.value[$$(x)])) |] where x = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty + +generateTests :: RandomGen g => Word -> g -> [Map String String] +generateTests size g = go g numTestCases + where + go _ 0 = [] + go g0 i = + let (x, g1) = generateNum g0 size + (y, g2) = generateNum g1 size + tcase = Map.fromList [("x", showX x), ("y", showX y), + ("e", showB (x == y)), + ("n", showB (x /= y)), + ("g", showB (x > y)), + ("h", showB (x >= y)), + ("l", showB (x < y)), + ("k", showB (x <= y))] + in tcase : go g2 (i - 1) diff --git a/generation/src/Conversions.hs b/generation/src/Conversions.hs index ee98bd4..99e890c 100644 --- a/generation/src/Conversions.hs +++ b/generation/src/Conversions.hs @@ -39,8 +39,10 @@ declareConversions bitsize = in [sourceFile| use core::convert::{From,TryFrom}; use crate::CryptoNum; - use super::$$sname; use crate::ConversionError; + #[cfg(test)] + use quickcheck::quickcheck; + use super::$$sname; $@{u8_prims} $@{u16_prims} @@ -55,6 +57,28 @@ declareConversions bitsize = $@{i64_prims} $@{isz_prims} $@{i128_prims} + + #[cfg(test)] + quickcheck! { + fn u8_recovers(x: u8) -> bool { + x == u8::try_from($$sname::from(x)).unwrap() + } + fn u16_recovers(x: u16) -> bool { + x == u16::try_from($$sname::from(x)).unwrap() + } + fn u32_recovers(x: u32) -> bool { + x == u32::try_from($$sname::from(x)).unwrap() + } + fn u64_recovers(x: u64) -> bool { + x == u64::try_from($$sname::from(x)).unwrap() + } + fn usize_recovers(x: usize) -> bool { + x == usize::try_from($$sname::from(x)).unwrap() + } + fn u128_recovers(x: u128) -> bool { + x == u128::try_from($$sname::from(x)).unwrap() + } + } |] generateU128Primitives :: Ident -> Word -> [Item Span] diff --git a/generation/src/CryptoNum.hs b/generation/src/CryptoNum.hs index d2840f8..c9145b8 100644 --- a/generation/src/CryptoNum.hs +++ b/generation/src/CryptoNum.hs @@ -43,6 +43,7 @@ declareCryptoNumInstance bitsize = Tree (Token mempty (LiteralTok (IntegerTok (show bytelen)) Nothing)) ]) entrieslit = toLit entries + testFileLit = Lit [] (Str (testFile bitsize) Cooked Unsuffixed mempty) mempty in [sourceFile| use core::cmp::min; #[cfg(test)] @@ -137,7 +138,7 @@ declareCryptoNumInstance bitsize = #[allow(non_snake_case)] #[test] fn KATs() { - run_test(build_test_path("base", stringify!($$sname)), 8, |case| { + run_test(build_test_path("cryptonum", $$(testFileLit)), 8, |case| { let (neg0, xbytes) = case.get("x").unwrap(); let (neg1, mbytes) = case.get("m").unwrap(); let (neg2, zbytes) = case.get("z").unwrap(); @@ -184,7 +185,7 @@ generateTests size g = go g numTestCases (m, g2) = generateNum g1 size (b, g3) = generateNum g2 16 m' = m `mod` (fromIntegral size `div` 64) - r = m `mod` (2 ^ (64 * m')) + r = x `mod` (2 ^ (64 * m')) t = x `testBit` (fromIntegral b) tcase = Map.fromList [("x", showX x), ("z", showB (x == 0)), ("e", showB (even x)), ("o", showB (odd x)), diff --git a/generation/src/File.hs b/generation/src/File.hs index 9231f50..fe606c5 100644 --- a/generation/src/File.hs +++ b/generation/src/File.hs @@ -3,7 +3,8 @@ module File( File(..), Task(..), - generateTasks + generateTasks, + testFile ) where @@ -34,6 +35,9 @@ data Task = Task { writer :: Handle -> IO () } +testFile :: Word -> FilePath +testFile size = "U" ++ show5 size ++ ".test" + show5 :: Word -> String show5 = go . show where @@ -65,8 +69,7 @@ generateTasks rng files sizes = basicTasks ++ moduleTasks mainTask : tasks Just caseGenerator -> let testTask = Task { - outputFile = "testdata" outputName file - ("U" ++ show5 size ++ ".test"), + outputFile = "testdata" outputName file testFile size, writer = \ hndl -> writeTestCase hndl (caseGenerator size myg) } in testTask : mainTask : tasks diff --git a/src/testing.rs b/src/testing.rs index e355f6f..23530d9 100644 --- a/src/testing.rs +++ b/src/testing.rs @@ -22,7 +22,7 @@ fn next_value_set(line: &str) -> (String, bool, Vec) let key = items.next().unwrap(); let valbits = items.next().unwrap(); let neg = valbits.contains('-'); - let valbitsnoneg = valbits.trim_start_matches("-"); + let valbitsnoneg = valbits.trim_start_matches('-'); let mut nibble_iter = valbitsnoneg.chars().rev(); let mut val = Vec::new();