Get back to basics, with some basic tests working.

This commit is contained in:
2019-07-30 16:23:14 -07:00
parent 203c23e277
commit 1d8907539d
10 changed files with 384 additions and 173 deletions

56
generation/src/Base.hs Normal file
View File

@@ -0,0 +1,56 @@
module Base(
base
)
where
import Control.Monad(forM_)
import File
import Gen
base :: File
base = File {
predicate = \ _ _ -> True,
outputName = "base",
generator = declareBaseStructure
}
declareBaseStructure :: Word -> Gen ()
declareBaseStructure bitsize =
do let name = "U" ++ show bitsize
entries = bitsize `div` 64
top = entries - 1
out "use core::fmt;"
out "use quickcheck::{Arbitrary,Gen};"
blank
out "#[derive(Clone)]"
wrapIndent ("pub struct " ++ name) $
out ("pub(crate) value: [u64; " ++ show entries ++ "]")
blank
implFor "fmt::Debug" name $
wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $
do out ("f.debug_tuple(" ++ show name ++ ")")
forM_ [0..top] $ \ i ->
out (" .field(&self.value[" ++ show i ++ "])")
out " .finish()"
blank
implFor "fmt::UpperHex" name $
wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $
do forM_ (reverse [1..top]) $ \ i ->
out ("write!(f, \"{:X}\", self.value[" ++ show i ++ "])?;")
out "write!(f, \"{:X}\", self.value[0])"
blank
implFor "fmt::LowerHex" name $
wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $
do forM_ (reverse [1..top]) $ \ i ->
out ("write!(f, \"{:x}\", self.value[" ++ show i ++ "])?;")
out "write!(f, \"{:x}\", self.value[0])"
blank
implFor "Arbitrary" name $
wrapIndent "fn arbitrary<G: Gen>(g: &mut G) -> Self" $
do out (name ++ " {")
indent $
do out ("value: [")
indent $ forM_ [0..top] $ \ _ ->
out ("g.next_u64(),")
out ("]")
out ("}")

60
generation/src/Compare.hs Normal file
View File

@@ -0,0 +1,60 @@
module Compare(comparisons)
where
import Control.Monad(forM_)
import File
import Gen
comparisons :: File
comparisons = File {
predicate = \ _ _ -> True,
outputName = "compare",
generator = declareComparators
}
declareComparators :: Word -> Gen ()
declareComparators bitsize =
do let name = "U" ++ show bitsize
entries = bitsize `div` 64
top = entries - 1
out "use core::cmp::{Eq,Ordering,PartialEq};"
out "#[cfg(test)]"
out "use quickcheck::quickcheck;"
out ("use super::" ++ name ++ ";")
blank
implFor "PartialEq" name $
wrapIndent "fn eq(&self, other: &Self) -> bool" $
do forM_ (reverse [1..top]) $ \ i ->
out ("self.value[" ++ show i ++ "] == other.value[" ++ show i ++ "] && ")
out "self.value[0] == other.value[0]"
blank
implFor "Eq" name $ return ()
blank
implFor "Ord" name $
wrapIndent "fn cmp(&self, other: &Self) -> Ordering" $
do out ("self.value[" ++ show top ++ "].cmp(&other.value[" ++ show top ++ "])")
forM_ (reverse [0..top-1]) $ \ i ->
out (" .then(self.value[" ++ show i ++ "].cmp(&other.value[" ++ show i ++ "]))")
blank
implFor "PartialOrd" name $
wrapIndent "fn partial_cmp(&self, other: &Self) -> Option<Ordering>" $
out "Some(self.cmp(other))"
blank
out "#[cfg(test)]"
wrapIndent "quickcheck!" $
do let transFun n = "fn " ++ n ++ "(a: " ++ name ++ ", b: " ++ name ++
", c: " ++ name ++ ") -> bool"
wrapIndent (transFun "eq_is_transitive") $
out ("if a == c { a == b && b == c } else { a != b || b != c }")
blank
wrapIndent (transFun "gt_is_transitive") $
out ("if a > b && b > c { a > c } else { true }")
blank
wrapIndent (transFun "ge_is_transitive") $
out ("if a >= b && b >= c { a >= c } else { true }")
blank
wrapIndent (transFun "lt_is_transitive") $
out ("if a < b && b < c { a < c } else { true }")
blank
wrapIndent (transFun "le_is_transitive") $
out ("if a <= b && b <= c { a <= c } else { true }")

View File

@@ -0,0 +1,98 @@
module Conversions(
conversions
)
where
import Data.List(intercalate)
import File
import Gen
conversions :: File
conversions = File {
predicate = \ _ _ -> True,
outputName = "conversions",
generator = declareConversions
}
declareConversions :: Word -> Gen ()
declareConversions bitsize =
do let name = "U" ++ show bitsize
entries = bitsize `div` 64
out "use core::convert::{From,TryFrom};"
out "#[cfg(test)]"
out "use quickcheck::quickcheck;"
out ("use super::" ++ name ++ ";")
blank
buildUnsignedPrimConversions name entries "u8" >> blank
buildUnsignedPrimConversions name entries "u16" >> blank
buildUnsignedPrimConversions name entries "u32" >> blank
buildUnsignedPrimConversions name entries "u64" >> blank
buildUnsignedPrimConversions name entries "usize" >> blank
buildSignedPrimConversions name entries "i8" >> blank
buildSignedPrimConversions name entries "i16" >> blank
buildSignedPrimConversions name entries "i32" >> blank
buildSignedPrimConversions name entries "i64" >> blank
buildSignedPrimConversions name entries "isize"
blank
out ("#[cfg(test)]")
wrapIndent "quickcheck!" $
do roundTripTest name "u8" >> blank
roundTripTest name "u16" >> blank
roundTripTest name "u32" >> blank
roundTripTest name "u64" >> blank
roundTripTest name "usize"
buildUnsignedPrimConversions :: String -> Word -> String -> Gen ()
buildUnsignedPrimConversions name entries primtype =
do implFor ("From<" ++ primtype ++ ">") name $
wrapIndent ("fn from(x: " ++ primtype ++ ") -> Self") $
do let zeroes = replicate (fromIntegral (entries - 1)) "0,"
values = ("x as u64," : zeroes)
out (name ++ " { value: [ ")
indent $ printBy 8 values
out ("] }")
blank
implFor ("From<" ++ name ++ ">") primtype $
wrapIndent ("fn from(x: " ++ name ++ ") -> Self") $
out ("x.value[0] as " ++ primtype)
blank
implFor' ("From<&'a " ++ name ++ ">") primtype $
wrapIndent ("fn from(x: &" ++ name ++ ") -> Self") $
out ("x.value[0] as " ++ primtype)
buildSignedPrimConversions :: String -> Word -> String -> Gen ()
buildSignedPrimConversions name entries primtype =
do implFor ("TryFrom<" ++ primtype ++ ">") name $
do out ("type Error = &'static str;")
blank
wrapIndent ("fn try_from(x: " ++ primtype ++ ") -> Result<Self,Self::Error>") $
do wrapIndent ("if x < 0") $
out ("return Err(\"Attempt to convert negative number to " ++
name ++ ".\");")
blank
let zeroes = replicate (fromIntegral (entries - 1)) "0,"
values = ("x as u64," : zeroes)
out ("Ok(" ++ name ++ " { value: [ ")
indent $ printBy 8 values
out ("] })")
blank
implFor ("From<" ++ name ++ ">") primtype $
wrapIndent ("fn from(x: " ++ name ++ ") -> Self") $
out ("x.value[0] as " ++ primtype)
blank
implFor' ("From<&'a " ++ name ++ ">") primtype $
wrapIndent ("fn from(x: &" ++ name ++ ") -> Self") $
out ("x.value[0] as " ++ primtype)
roundTripTest :: String -> String -> Gen ()
roundTripTest name primtype =
wrapIndent ("fn " ++ primtype ++ "_roundtrips(x: " ++ primtype ++ ") -> bool") $
do out ("let big = " ++ name ++ "::from(x);");
out ("let small = " ++ primtype ++ "::from(big);")
out ("x == small")
printBy :: Int -> [String] -> Gen ()
printBy amt xs
| length xs <= amt = out (intercalate " " xs)
| otherwise = printBy amt (take amt xs) >>
printBy amt (drop amt xs)

146
generation/src/CryptoNum.hs Normal file
View File

@@ -0,0 +1,146 @@
module CryptoNum(
cryptoNum
)
where
import Control.Monad(forM_)
import File
import Gen
cryptoNum :: File
cryptoNum = File {
predicate = \ _ _ -> True,
outputName = "cryptonum",
generator = declareCryptoNumInstance
}
declareCryptoNumInstance :: Word -> Gen ()
declareCryptoNumInstance bitsize =
do let name = "U" ++ show bitsize
entries = bitsize `div` 64
top = entries - 1
out "use core::cmp::min;"
out "use crate::CryptoNum;"
out "#[cfg(test)]"
out "use crate::testing::{build_test_path,run_test};"
out "#[cfg(test)]"
out "use quickcheck::quickcheck;"
out ("use super::" ++ name ++ ";")
blank
implFor "CryptoNum" name $
do wrapIndent ("fn zero() -> Self") $
out (name ++ "{ value: [0; " ++ show entries ++ "] }")
blank
wrapIndent ("fn is_zero(&self) -> bool") $
do forM_ (reverse [1..top]) $ \ i ->
out ("self.value[" ++ show i ++ "] == 0 &&")
out "self.value[0] == 0"
blank
wrapIndent ("fn is_even(&self) -> bool") $
out "self.value[0] & 0x1 == 0"
blank
wrapIndent ("fn is_odd(&self) -> bool") $
out "self.value[0] & 0x1 == 1"
blank
wrapIndent ("fn bit_length() -> usize") $
out (show bitsize)
blank
wrapIndent ("fn mask(&mut self, len: usize)") $
do out ("let dellen = min(len, " ++ show entries ++ ");")
wrapIndent ("for i in dellen.." ++ show entries) $
out ("self.value[i] = 0;")
blank
wrapIndent ("fn testbit(&self, bit: usize) -> bool") $
do out "let idx = bit / 64;"
out "let offset = bit % 64;"
wrapIndent ("if idx >= " ++ show entries) $
out "return false;"
out "(self.value[idx] & (1u64 << offset)) != 0"
blank
wrapIndent ("fn from_bytes(bytes: &[u8]) -> Self") $
do out ("let biggest = min(" ++ show (bitsize `div` 8) ++ ", " ++
"bytes.len()) - 1;")
out ("let mut idx = biggest / 8;")
out ("let mut shift = (biggest % 8) * 8;")
out ("let mut i = 0;")
out ("let mut res = " ++ name ++ "::zero();")
blank
wrapIndent ("while i <= biggest") $
do out ("res.value[idx] |= (bytes[i] as u64) << shift;")
out ("i += 1;")
out ("if shift == 0 {")
indent $
do out "shift = 56;"
out "if idx > 0 { idx -= 1; }"
out ("} else {")
indent $
out "shift -= 8;"
out "}"
blank
out "res"
blank
wrapIndent ("fn to_bytes(&self, bytes: &mut [u8])") $
do let bytes = bitsize `div` 8
out ("if bytes.len() == 0 { return; }")
blank
forM_ [0..bytes-1] $ \ idx ->
do let (validx, shift) = byteShiftInfo idx
out ("let byte" ++ show idx ++ " = (self.value[" ++
show validx ++ "] >> " ++ show shift ++ ")" ++
" as u8;")
blank
out ("let mut idx = min(bytes.len() - 1, " ++ show (bytes - 1) ++ ");")
forM_ [0..bytes-2] $ \ i ->
do out ("bytes[idx] = byte" ++ show i ++ ";")
out ("if idx == 0 { return; }")
out ("idx -= 1;")
out ("bytes[idx] = byte" ++ show (bytes-1) ++ ";")
blank
out "#[cfg(test)]"
wrapIndent "quickcheck!" $
do wrapIndent ("fn to_from_ident(x: " ++ name ++ ") -> bool") $
do out ("let mut buffer = [0; " ++ show (bitsize `div` 8) ++ "];")
out ("x.to_bytes(&mut buffer);");
out ("let y = " ++ name ++ "::from_bytes(&buffer);")
out ("x == y")
blank
out "#[cfg(test)]"
out "#[allow(non_snake_case)]"
out "#[test]"
wrapIndent "fn KATs()" $
do let name' = pad 5 '0' (show bitsize)
out ("run_test(build_test_path(\"base\",\"" ++ name' ++ "\"), 8, |case| {")
indent $
do out ("let (neg0, xbytes) = case.get(\"x\").unwrap();")
out ("let (neg1, mbytes) = case.get(\"m\").unwrap();")
out ("let (neg2, zbytes) = case.get(\"z\").unwrap();")
out ("let (neg3, ebytes) = case.get(\"e\").unwrap();")
out ("let (neg4, obytes) = case.get(\"o\").unwrap();")
out ("let (neg5, rbytes) = case.get(\"r\").unwrap();")
out ("let (neg6, bbytes) = case.get(\"b\").unwrap();")
out ("let (neg7, tbytes) = case.get(\"t\").unwrap();")
out ("assert!(!neg0&&!neg1&&!neg2&&!neg3&&!neg4&&!neg5&&!neg6&&!neg7);")
out ("let mut x = "++name++"::from_bytes(xbytes);")
out ("let m = "++name++"::from_bytes(mbytes);")
out ("let z = 1 == zbytes[0];")
out ("let e = 1 == ebytes[0];")
out ("let o = 1 == obytes[0];")
out ("let r = "++name++"::from_bytes(rbytes);")
out ("let b = usize::from("++name++"::from_bytes(bbytes));")
out ("let t = 1 == tbytes[0];")
out ("assert_eq!(x.is_zero(), z);")
out ("assert_eq!(x.is_even(), e);")
out ("assert_eq!(x.is_odd(), o);")
out ("assert_eq!(x.testbit(b), t);")
out ("x.mask(usize::from(&m));")
out ("assert_eq!(x, r);")
out ("});")
byteShiftInfo :: Word -> (Word, Word)
byteShiftInfo idx =
(idx `div` 8, (idx `mod` 8) * 8)
pad :: Int -> Char -> String -> String
pad len c str
| length str >= len = str
| otherwise = pad len c (c:str)

View File

@@ -1,7 +1,11 @@
module Main
where
import Base(base)
import BinaryOps(binaryOps)
import Compare(comparisons)
import Conversions(conversions)
import CryptoNum(cryptoNum)
import Control.Monad(forM_,unless)
import Data.Maybe(mapMaybe)
import Data.Word(Word)
@@ -11,7 +15,6 @@ import System.Directory(createDirectoryIfMissing)
import System.Environment(getArgs)
import System.Exit(die)
import System.FilePath(takeDirectory,(</>))
import UnsignedBase(base)
lowestBitsize :: Word
lowestBitsize = 192
@@ -26,6 +29,9 @@ unsignedFiles :: [File]
unsignedFiles = [
base
, binaryOps
, comparisons
, conversions
, cryptoNum
]
signedFiles :: [File]

View File

@@ -1,160 +0,0 @@
module UnsignedBase(
base
)
where
import Control.Monad(forM_)
import Data.List(intercalate)
import File
import Gen
base :: File
base = File {
predicate = \ _ _ -> True,
outputName = "base",
generator = declareBaseStructure
}
declareBaseStructure :: Word -> Gen ()
declareBaseStructure bitsize =
do let name = "U" ++ show bitsize
entries = bitsize `div` 64
top = entries - 1
out "use core::cmp::{Eq,Ordering,PartialEq,min};"
out "use core::fmt;"
out "use super::super::super::CryptoNum;"
blank
out "#[derive(Clone)]"
wrapIndent ("pub struct " ++ name) $
out ("pub(crate) value: [u64; " ++ show entries ++ "]")
blank
implFor "CryptoNum" name $
do wrapIndent ("fn zero() -> Self") $
out (name ++ "{ value: [0; " ++ show entries ++ "] }")
blank
wrapIndent ("fn is_zero(&self) -> bool") $
do forM_ (reverse [1..top]) $ \ i ->
out ("self.value[" ++ show i ++ "] == 0 &&")
out "self.value[0] == 0"
blank
wrapIndent ("fn is_even(&self) -> bool") $
out "self.value[0] & 0x1 == 0"
blank
wrapIndent ("fn is_odd(&self) -> bool") $
out "self.value[0] & 0x1 == 0"
blank
wrapIndent ("fn bit_length() -> usize") $
out (show bitsize)
blank
wrapIndent ("fn mask(&mut self, len: usize)") $
do out ("let dellen = min(len, " ++ show entries ++ ");")
wrapIndent ("for i in dellen.." ++ show entries) $
out ("self.value[i] = 0;")
blank
wrapIndent ("fn testbit(&self, bit: usize) -> bool") $
do out "let idx = bit / 64;"
out "let offset = bit % 64;"
wrapIndent ("if idx >= " ++ show entries) $
out "return false;"
out "(self.value[idx] & (1u64 << offset)) != 0"
blank
wrapIndent ("fn from_bytes(&self, bytes: &[u8]) -> Self") $
do let bytes = bitsize `div` 8;
forM_ [0..bytes-1] $ \ idx ->
out ("let byte" ++ show idx ++ " = " ++
"if " ++ show idx ++ " < bytes.len() { " ++
"bytes[" ++ show idx ++ "] as u64 } else { 0 };")
blank
let byteNames = map (\ x -> "byte" ++ show x) [0..bytes-1]
byteGroups = groupCount 8 (reverse byteNames)
forM_ (zip byteGroups [0..bytes-1]) $ \ (byteGroup, idx) ->
do let shiftAmts = [0,8..56]
shifts = zipWith (\ n s -> n ++ " << " ++ show s)
byteGroup shiftAmts
shift0 = head shifts
shiftL = last shifts
middles = reverse (drop 1 (reverse (drop 1 shifts)))
prefix = "let word" ++ show idx ++ " "
blankPrefix = map (const ' ') prefix
out (prefix ++ " = " ++ shift0)
forM_ middles $ \ s -> out (blankPrefix ++ " | " ++ s)
out (blankPrefix ++ " | " ++ shiftL ++ ";")
blank
wrapIndent name $
do out ("value: [")
let vwords = map (\ x -> "word" ++ show x) [0..top]
linewords = groupCount 4 vwords
vlines = map (intercalate ", ") linewords
forM_ vlines $ \ l -> out (" " ++ l ++ ",")
out ("]")
blank
implFor "PartialEq" name $
wrapIndent "fn eq(&self, other: &Self) -> bool" $
do forM_ (reverse [1..top]) $ \ i ->
out ("self.value[" ++ show i ++ "] == other.value[" ++ show i ++ "] && ")
out "self.value[0] == other.value[0]"
blank
implFor "Eq" name $ return ()
blank
implFor "Ord" name $
wrapIndent "fn cmp(&self, other: &Self) -> Ordering" $
do out ("self.value[" ++ show top ++ "].cmp(&other.value[" ++ show top ++ "])")
forM_ (reverse [0..top-1]) $ \ i ->
out (" .then(self.value[" ++ show i ++ "].cmp(&other.value[" ++ show i ++ "]))")
blank
implFor "PartialOrd" name $
wrapIndent "fn partial_cmp(&self, other: &Self) -> Option<Ordering>" $
out "Some(self.cmp(other))"
blank
implFor "fmt::Debug" name $
wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $
do out ("f.debug_tuple(" ++ show name ++ ")")
forM_ [0..top] $ \ i ->
out (" .field(&self.value[" ++ show i ++ "])")
out " .finish()"
blank
implFor "fmt::UpperHex" name $
wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $
do forM_ (reverse [1..top]) $ \ i ->
out ("write!(f, \"{:X}\", self.value[" ++ show i ++ "])?;")
out "write!(f, \"{:X}\", self.value[0])"
blank
implFor "fmt::LowerHex" name $
wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $
do forM_ (reverse [1..top]) $ \ i ->
out ("write!(f, \"{:x}\", self.value[" ++ show i ++ "])?;")
out "write!(f, \"{:x}\", self.value[0])"
blank
out "#[test]"
wrapIndent "fn KATs()" $
do out ("run_test(\"testdata/base/" ++ name ++ ".test\", 8, |case| {")
indent $
do out ("let (neg0, xbytes) = case.get(\"x\").unwrap();")
out ("let (neg1, mbytes) = case.get(\"m\").unwrap();")
out ("let (neg2, zbytes) = case.get(\"z\").unwrap();")
out ("let (neg3, ebytes) = case.get(\"e\").unwrap();")
out ("let (neg4, obytes) = case.get(\"o\").unwrap();")
out ("let (neg5, rbytes) = case.get(\"r\").unwrap();")
out ("let (neg6, bbytes) = case.get(\"b\").unwrap();")
out ("let (neg7, tbytes) = case.get(\"t\").unwrap();")
out ("assert!(!neg0&&!neg1&&!neg2&&!neg3&&!neg4&&!neg5&&!neg6&&!neg7);")
out ("let mut x = "++name++"::from_bytes(xbytes);")
out ("let m = "++name++"::from_bytes(mbytes);")
out ("let z = 1 == zbytes[0];")
out ("let e = 1 == ebytes[0];")
out ("let o = 1 == obytes[0];")
out ("let r = "++name++"::from_bytes(rbytes);")
out ("let b = usize::from("++name++"::from_bytes(bbytes));")
out ("let t = 1 == tbytes[0];")
out ("assert_eq!(x.is_zero(), z);")
out ("assert_eq!(x.is_even(), e);")
out ("assert_eq!(x.is_odd(), o);")
out ("assert_eq!(x.testbit(b), t);")
out ("x.mask(usize::from(&m));")
out ("assert_eq!(x, r);")
out ("});")
groupCount :: Int -> [a] -> [[a]]
groupCount x ls
| x >= length ls = [ls]
| otherwise = take x ls : groupCount x (drop x ls)