Support a more complete (and simple) requirements gathering mechanism,

and add support for binary operations.

This version of requirements generation simply generates every numeric
size within a provided range, and then will reject trait implementations
that rely on values outside this range. It should be a little more easy
to reason about, and easier to make local changes as I (inevitably) need
to modify rules.
This commit is contained in:
2019-07-22 08:14:40 -07:00
parent ab465296f2
commit aff88eb2f0
8 changed files with 243 additions and 384 deletions

View File

@@ -0,0 +1,93 @@
module BinaryOps(
binaryOps
)
where
import Control.Monad(forM_)
import File
import Gen
binaryOps :: File
binaryOps = File {
predicate = \ _ _ -> True,
outputName = "binary",
generator = declareBinaryOperators
}
declareBinaryOperators :: Word -> Gen ()
declareBinaryOperators bitsize =
do let name = "U" ++ show bitsize
entries = bitsize `div` 64
out "use core::ops::{BitAnd,BitAndAssign};"
out "use core::ops::{BitOr,BitOrAssign};"
out "use core::ops::{BitXor,BitXorAssign};"
out "use core::ops::Not;"
out ("use super::U" ++ show bitsize ++ ";")
blank
generateBinOps "BitAnd" name "bitand" "&=" entries
blank
generateBinOps "BitOr" name "bitor" "|=" entries
blank
generateBinOps "BitXor" name "bitxor" "^=" entries
blank
implFor "Not" name $
do out "type Output = Self;"
blank
wrapIndent "fn not(mut self) -> Self" $
do forM_ [0..entries-1] $ \ i ->
out ("self.value[" ++ show i ++ "] = !self.value[" ++ show i ++ "];")
out "self"
blank
implFor' "Not" ("&'a " ++ name) $
do out ("type Output = " ++ name ++ ";")
blank
wrapIndent ("fn not(self) -> " ++ name) $
do out "let mut output = self.clone();"
forM_ [0..entries-1] $ \ i ->
out ("output.value[" ++ show i ++ "] = !self.value[" ++ show i ++ "];")
out "output"
generateBinOps :: String -> String -> String -> String -> Word -> Gen ()
generateBinOps trait name fun op entries =
do implFor (trait ++ "Assign") name $
wrapIndent ("fn " ++ fun ++ "_assign(&mut self, rhs: Self)") $
forM_ [0..entries-1] $ \ i ->
out ("self.value[" ++ show i ++ "] "++op++" rhs.value[" ++ show i ++ "];")
blank
implFor' (trait ++ "Assign<&'a " ++ name ++ ">") name $
wrapIndent ("fn " ++ fun ++ "_assign(&mut self, rhs: &Self)") $
forM_ [0..entries-1] $ \ i ->
out ("self.value[" ++ show i ++ "] "++op++" rhs.value[" ++ show i ++ "];")
blank
generateBinOpsFromAssigns trait name fun op
generateBinOpsFromAssigns :: String -> String -> String -> String -> Gen ()
generateBinOpsFromAssigns trait name fun op =
do implFor trait name $
do out "type Output = Self;"
blank
wrapIndent ("fn " ++ fun ++ "(mut self, rhs: Self) -> Self") $
do out ("self " ++ op ++ " rhs;")
out "self"
blank
implFor' (trait ++ "<&'a " ++ name ++ ">") name $
do out "type Output = Self;"
blank
wrapIndent ("fn " ++ fun ++ "(mut self, rhs: &Self) -> Self") $
do out ("self " ++ op ++ " rhs;")
out "self"
blank
implFor' (trait ++ "<" ++ name ++ ">") ("&'a " ++ name) $
do out ("type Output = " ++ name ++ ";")
blank
wrapIndent ("fn " ++ fun ++ "(self, mut rhs: " ++ name ++ ") -> " ++ name) $
do out ("rhs " ++ op ++ " self;")
out "rhs"
blank
implFor'' (trait ++ "<&'a " ++ name ++ ">") ("&'b " ++ name) $
do out ("type Output = " ++ name ++ ";")
blank
wrapIndent ("fn " ++ fun ++ "(self, rhs: &" ++ name ++ ") -> " ++ name) $
do out "let mut output = self.clone();"
out ("output " ++ op ++ " rhs;")
out "output"

66
generation/src/File.hs Normal file
View File

@@ -0,0 +1,66 @@
module File(
File(..),
Task(..),
addModuleTasks,
makeTask
)
where
import Control.Monad(forM_)
import Data.Char(toUpper)
import qualified Data.Map.Strict as Map
import Gen(Gen,blank,out)
import System.FilePath(takeBaseName,takeDirectory,takeFileName,(</>))
data File = File {
predicate :: Word -> [Word] -> Bool,
outputName :: FilePath,
generator :: Word -> Gen ()
}
data Task = Task {
outputFile :: FilePath,
fileGenerator :: Gen ()
}
makeTask :: FilePath ->
Word -> [Word] ->
File ->
Maybe Task
makeTask base size allSizes file
| predicate file size allSizes =
Just Task {
outputFile = base </> ("u" ++ show size) </> outputName file <> ".rs",
fileGenerator = generator file size
}
| otherwise =
Nothing
addModuleTasks :: FilePath -> [Task] -> [Task]
addModuleTasks base baseTasks = unsignedTask : (baseTasks ++ moduleTasks)
where
moduleMap = foldr addModuleInfo Map.empty baseTasks
addModuleInfo task =
Map.insertWith (++) (takeDirectory (outputFile task))
[takeBaseName (outputFile task)]
moduleTasks = Map.foldrWithKey generateModuleTask [] moduleMap
generateModuleTask directory mods acc = acc ++ [Task {
outputFile = directory </> "mod.rs",
fileGenerator =
do forM_ mods $ \ modle -> out ("mod " ++ modle ++ ";")
blank
out ("pub use base::" ++ upcase (takeFileName directory) ++ ";")
}]
unsignedTask = Task {
outputFile = base </> "unsigned" </> "mod.rs",
fileGenerator =
do forM_ (Map.keys moduleMap) $ \ key ->
out ("mod " ++ takeFileName key ++ ";")
blank
forM_ (Map.keys moduleMap) $ \ key ->
out ("pub use " ++ takeFileName key ++ "::" ++
upcase (takeFileName key) ++ ";")
}
upcase :: String -> String
upcase = map toUpper

View File

@@ -1,33 +1,54 @@
module Main
where
import BinaryOps(binaryOps)
import Control.Monad(forM_,unless)
import Data.List(sort)
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Data.Maybe(mapMaybe)
import Data.Word(Word)
import File(File,Task(..),addModuleTasks,makeTask)
import Gen(runGen)
import Requirements(Requirement(..), Operation(..), requirements)
import System.Directory(createDirectoryIfMissing)
import System.Environment(getArgs)
import System.Exit(die)
import System.FilePath((</>))
import UnsignedBase(declareBaseStructure,declareBinaryOperators)
import System.FilePath(takeDirectory,(</>))
import UnsignedBase(base)
gatherRequirements :: [Requirement] -> Map Int [Operation]
gatherRequirements = foldr process Map.empty
where process (Req x val) = Map.insertWith (++) x [val]
lowestBitsize :: Word
lowestBitsize = 192
highestBitsize :: Word
highestBitsize = 512
bitsizes :: [Word]
bitsizes = [lowestBitsize,lowestBitsize+64..highestBitsize]
unsignedFiles :: [File]
unsignedFiles = [
base
, binaryOps
]
signedFiles :: [File]
signedFiles = [
]
makeTasks :: FilePath -> [File] -> [Task]
makeTasks basePath files =
concatMap (\ sz -> mapMaybe (makeTask basePath sz bitsizes) files) bitsizes
makeAllTasks :: FilePath -> [Task]
makeAllTasks basePath = addModuleTasks basePath $
makeTasks (basePath </> "unsigned") unsignedFiles ++
makeTasks (basePath </> "signed") signedFiles
main :: IO ()
main =
do args <- getArgs
unless (length args == 1) $
die ("generation takes exactly one argument, the target directory")
let reqs = sort (Map.toList (gatherRequirements requirements))
target = head args
forM_ reqs $ \ (size, opes) ->
do let basedir = target </> "unsigned" </> ("u" ++ show size)
createDirectoryIfMissing True basedir
forM_ reqs $ \ (x, ops) ->
do runGen (basedir </> "mod.rs") (declareBaseStructure size ops)
runGen (basedir </> "binary.rs") (declareBinaryOperators size)
let tasks = makeAllTasks (head args)
total = length tasks
forM_ (zip [(1::Word)..] tasks) $ \ (i, task) ->
do putStrLn ("[" ++ show i ++ "/" ++ show total ++ "] " ++ outputFile task)
createDirectoryIfMissing True (takeDirectory (outputFile task))
runGen (outputFile task) (fileGenerator task)

View File

@@ -1,278 +0,0 @@
module Requirements(
Operation(..),
Requirement(..),
requirements
)
where
import Data.List(sort)
data Operation = Add
| BaseOps
| Barretts
| Div
| ModDiv
| ModExp
| ModMul
| ModSq
| Mul
| Scale
| Shifts
| Square
| Sub
| Convert Int
| SignedAdd
| SignedBase
| SignedCmp
| SignedShift
| SignedSub
| SignedMul
| SignedDiv
| SignedModInv
| SignedScale
| SigConvert Int
| SquareRoot
| EGCD
| ModInv
| PrimeGen
| RSA
| DSA
| ECDSA
deriving (Eq, Ord, Show)
data Requirement = Req Int Operation
deriving (Eq, Ord, Show)
data Need = Need Operation (Int -> [Requirement])
needs :: [Need]
needs = [ Need RSA (\ size -> [Req (size `div` 2) Sub,
Req (size `div` 2) Mul,
Req (size `div` 2) PrimeGen,
Req size BaseOps,
Req size ModInv,
Req size ModExp
])
, Need DSA (\ size -> [Req size BaseOps,
Req size Shifts,
Req size Add,
Req size SquareRoot,
Req size PrimeGen,
Req size ModInv,
Req size Mul,
Req (size * 2) Add,
Req (((size * 2) + 64) * 2) Div,
Req size (Convert 512),
Req size (Convert (size + 128)),
Req size (Convert ((size * 2) + 64)),
Req size (Convert (((size * 2) + 64) * 2))
])
, Need ECDSA (\ size -> [Req size SignedSub,
Req size SignedMul,
Req size ModMul,
Req size ModDiv,
Req (size * 2) BaseOps,
Req (size * 2) SignedBase,
Req (size * 2) SignedShift,
Req (size * 2) SignedSub,
Req (size * 2) SignedMul,
Req (size * 2) SignedDiv,
Req ((size * 2) + 64) SignedBase,
Req ((size * 2) + 64) BaseOps,
Req ((size * 2) + 64) SignedAdd,
Req ((size * 2) + 64) SignedShift,
Req ((size * 2) + 64) ModDiv,
Req size (Convert (size * 2)),
Req size (SigConvert (size * 2)),
Req size (Convert ((size * 2) + 64)),
Req size (SigConvert ((size * 2) + 64)),
Req (size * 2) (Convert ((size * 2) + 64)),
Req (size * 2) (SigConvert ((size * 2) + 64)),
Req size (Convert (size * 4)),
Req (size * 4) Div
])
, Need PrimeGen (\ size -> [Req size Div,
Req size Shifts,
Req size ModExp,
Req size EGCD])
, Need Add (\ size -> [Req size BaseOps,
Req (size + 64) BaseOps,
Req size (Convert (size + 64))
])
, Need Barretts (\ size -> [Req size BaseOps,
Req (size + 64) BaseOps,
Req (size * 2) BaseOps,
Req ((size * 2) + 64) BaseOps,
Req size (Convert ((size * 2) + 64)),
Req (size + 64) Mul,
Req ((size * 2) + 64) Add,
Req ((size * 2) + 64) Sub,
Req (size + 64) (Convert ((size * 2) + 64)),
Req ((size * 2) + 64) (Convert ((size + 64) * 2)),
Req (size * 2) (Convert ((size * 2) + 64)),
Req (size + 64) (Convert ((size + 64) * 2)),
Req (size + 64) (Convert (size * 2)),
Req (size * 2) Shifts,
Req ((size + 64) * 2) Shifts,
Req ((size * 2) + 64) Div
])
, Need Div (\ size -> [Req size BaseOps,
Req (size * 2) BaseOps,
Req size (Convert (size * 2)),
Req (size * 2) Sub,
Req size Mul,
Req 192 BaseOps,
Req 192 Mul,
Req 384 BaseOps
])
, Need ModExp (\ size -> [Req size BaseOps,
Req size Barretts,
Req size ModSq,
Req size ModMul,
Req size (Convert (size + 64))
])
, Need ModMul (\ size -> [Req size BaseOps,
Req (size * 2) BaseOps,
Req size Barretts,
Req size Mul,
Req size (Convert (size + 64))
])
, Need ModDiv (\ size -> [Req size ModInv,
Req size SignedModInv,
Req size SignedMul,
Req size SignedDiv,
Req (size * 2) SignedDiv,
Req size (SigConvert (size * 2))
])
, Need ModSq (\ size -> [Req size BaseOps,
Req (size * 2) BaseOps,
Req size Barretts,
Req size Square,
Req (size * 2) Div,
Req size (Convert (size * 2)),
Req size (Convert (size + 64))
])
, Need Mul (\ size -> [Req size BaseOps,
Req size Scale,
Req (size * 2) BaseOps,
Req size (Convert (size * 2))
])
, Need Scale (\ size -> [Req (size + 64) BaseOps])
, Need Shifts (\ size -> [Req size BaseOps
])
, Need Square (\ size -> [Req size BaseOps,
Req (size * 2) BaseOps
])
, Need Sub (\ size -> [Req size BaseOps
])
, Need SignedAdd (\ size -> [Req size SignedBase,
Req size Add,
Req size Sub,
Req (size + 64) SignedBase,
Req (size + 64) BaseOps
])
, Need SignedBase (\ size -> [Req size BaseOps])
, Need SignedCmp (\ size -> [Req size BaseOps])
, Need SignedShift (\ size -> [Req size SignedBase,
Req size BaseOps,
Req size Shifts,
Req size Add
])
, Need SignedSub (\ size -> [Req size SignedBase,
Req (size + 64) SignedBase,
Req (size + 64) BaseOps,
Req size Add,
Req size Sub,
Req (size + 64) Sub,
Req size (Convert (size + 64)),
Req size (SigConvert (size + 64))
])
, Need SignedMul (\ size -> [Req size Mul,
Req size SignedScale,
Req (size * 2) SignedBase,
Req size (SigConvert (size * 2)),
Req size Square
])
, Need SignedDiv (\ size -> [Req size Div,
Req size Add
])
, Need EGCD (\ size -> [Req size SignedBase,
Req size BaseOps,
Req size Shifts,
Req (size + 64) SignedBase,
Req ((size + 64) * 2) SignedBase,
Req size (SigConvert (size + 64)),
Req (size + 64) SignedShift,
Req (size + 64) SignedAdd,
Req (size + 64) SignedSub,
Req (size + 64) SignedCmp,
Req (size + 64) SignedDiv,
Req (size + 64) SignedMul,
Req ((size + 64) * 2) SignedSub,
Req (size + 64) (Convert (((size + 64) * 2) + 64)),
Req (size + 64) (SigConvert (((size + 64) * 2) + 64))
])
, Need ModInv (\ size -> [Req size BaseOps,
Req (size + 64) SignedBase,
Req (size + 64) BaseOps,
Req size (Convert (size + 64)),
Req size EGCD,
Req (size + 64) SignedAdd,
Req size Barretts
])
, Need SignedModInv (\ size -> [
Req size EGCD,
Req size SignedModInv
])
, Need SquareRoot (\ size -> [Req size BaseOps,
Req size Shifts,
Req size Add,
Req size Sub
])
]
newRequirements :: Requirement -> [Requirement]
newRequirements (Req size op) = concatMap go needs ++ [Req size BaseOps]
where
go (Need op2 generator) | op == op2 = generator size
| otherwise = []
rsaSizes :: [Int]
rsaSizes = [512,1024,2048,3072,4096,8192,15360]
dsaSizes :: [Int]
dsaSizes = [192,256,1024,2048,3072]
ecdsaSizes :: [Int]
ecdsaSizes = [192,256,384,576]
baseRequirements :: [Requirement]
baseRequirements = concatMap (\ x -> [Req x RSA]) rsaSizes
++ concatMap (\ x -> [Req x DSA]) dsaSizes
++ concatMap (\ x -> [Req x ECDSA]) ecdsaSizes
++ [Req 192 (Convert 1024), Req 256 (Convert 2048), Req 256 (Convert 3072)] -- used in DSA
++ [Req 384 (Convert 1024), Req 512 (Convert 2048), Req 512 (Convert 3072)] -- used in DSA
++ [Req 192 Add, Req 256 Add, Req 384 Add] -- used for testing
++ [Req 192 Mul, Req 384 Mul] -- used for testing
++ [Req 448 (Convert 512)] -- used for testing
requirements :: [Requirement]
requirements = go baseRequirements
where
step ls = let news = concatMap newRequirements ls
ls' = concatMap sanitizeConverts (news ++ ls)
ls'' = removeDups (sort ls')
in ls''
--
go ls = let ls' = step ls
in if ls == ls' then ls else go ls'
--
removeDups [] = []
removeDups (x:xs) | x `elem` xs = removeDups xs
| otherwise = x : removeDups xs
--
sanitizeConverts (Req x (Convert y))
| x == y = []
| x < y = [Req x (Convert y), Req y BaseOps]
| otherwise = [Req y (Convert x), Req x BaseOps]
sanitizeConverts x = [x]

View File

@@ -1,27 +1,31 @@
module UnsignedBase(
declareBaseStructure
, declareBinaryOperators
base
)
where
import Control.Monad(forM_)
import File
import Gen
import Requirements(Operation)
declareBaseStructure :: Int -> [Operation] -> Gen ()
declareBaseStructure bitsize ops =
base :: File
base = File {
predicate = \ _ _ -> True,
outputName = "base",
generator = declareBaseStructure
}
declareBaseStructure :: Word -> Gen ()
declareBaseStructure bitsize =
do let name = "U" ++ show bitsize
entries = bitsize `div` 64
top = entries - 1
out "use core::cmp::{Eq,Ordering,PartialEq,min};"
out "use core::fmt;"
out "use super::super::CryptoNum;"
blank
out "mod binary;"
out "use super::super::super::CryptoNum;"
blank
out "#[derive(Clone)]"
wrapIndent ("pub struct " ++ name) $
out ("value: [u64; " ++ show entries ++ "]")
out ("pub(crate) value: [u64; " ++ show entries ++ "]")
blank
implFor "CryptoNum" name $
do wrapIndent ("fn zero() -> Self") $
@@ -89,81 +93,32 @@ declareBaseStructure bitsize ops =
do forM_ (reverse [1..top]) $ \ i ->
out ("write!(f, \"{:x}\", self.value[" ++ show i ++ "])?;")
out "write!(f, \"{:x}\", self.value[0])"
declareBinaryOperators :: Int -> Gen ()
declareBinaryOperators bitsize =
do let name = "U" ++ show bitsize
entries = bitsize `div` 64
out "use core::ops::{BitAnd,BitAndAssign};"
out "use core::ops::{BitOr,BitOrAssign};"
out "use core::ops::{BitXor,BitXorAssign};"
out "use core::ops::Not;"
out ("use super::U" ++ show bitsize ++ ";")
blank
generateBinOps "BitAnd" name "bitand" "&=" entries
blank
generateBinOps "BitOr" name "bitor" "|=" entries
blank
generateBinOps "BitXor" name "bitxor" "^=" entries
blank
implFor "Not" name $
do out "type Output = Self;"
blank
wrapIndent "fn not(mut self) -> Self" $
do forM_ [0..entries-1] $ \ i ->
out ("self.value[" ++ show i ++ "] = !self.value[" ++ show i ++ "];")
out "self"
blank
implFor' "Not" ("&'a " ++ name) $
do out ("type Output = " ++ name ++ ";")
blank
wrapIndent ("fn not(self) -> " ++ name) $
do out "let mut output = self.clone();"
forM_ [0..entries-1] $ \ i ->
out ("output.value[" ++ show i ++ "] = !self.value[" ++ show i ++ "];")
out "output"
generateBinOps :: String -> String -> String -> String -> Int -> Gen ()
generateBinOps trait name fun op entries =
do implFor (trait ++ "Assign") name $
wrapIndent ("fn " ++ fun ++ "_assign(&mut self, rhs: Self)") $
forM_ [0..entries-1] $ \ i ->
out ("self.value[" ++ show i ++ "] "++op++" rhs.value[" ++ show i ++ "];")
blank
implFor' (trait ++ "Assign<&'a " ++ name ++ ">") name $
wrapIndent ("fn " ++ fun ++ "_assign(&mut self, rhs: &Self)") $
forM_ [0..entries-1] $ \ i ->
out ("self.value[" ++ show i ++ "] "++op++" rhs.value[" ++ show i ++ "];")
blank
generateBinOpsFromAssigns trait name fun op
generateBinOpsFromAssigns :: String -> String -> String -> String -> Gen ()
generateBinOpsFromAssigns trait name fun op =
do implFor trait name $
do out "type Output = Self;"
blank
wrapIndent ("fn " ++ fun ++ "(mut self, rhs: Self) -> Self") $
do out ("self " ++ op ++ " rhs;")
out "self"
blank
implFor' (trait ++ "<&'a " ++ name ++ ">") name $
do out "type Output = Self;"
blank
wrapIndent ("fn " ++ fun ++ "(mut self, rhs: &Self) -> Self") $
do out ("self " ++ op ++ " rhs;")
out "self"
blank
implFor' (trait ++ "<" ++ name ++ ">") ("&'a " ++ name) $
do out ("type Output = " ++ name ++ ";")
blank
wrapIndent ("fn " ++ fun ++ "(self, mut rhs: " ++ name ++ ") -> " ++ name) $
do out ("rhs " ++ op ++ " self;")
out "rhs"
blank
implFor'' (trait ++ "<&'a " ++ name ++ ">") ("&'b " ++ name) $
do out ("type Output = " ++ name ++ ";")
blank
wrapIndent ("fn " ++ fun ++ "(self, rhs: &" ++ name ++ ") -> " ++ name) $
do out "let mut output = self.clone();"
out ("output " ++ op ++ " rhs;")
out "output"
out "#[test]"
wrapIndent "fn KATs()" $
do out ("run_test(\"testdata/base/" ++ name ++ ".test\", 8, |case| {")
indent $
do out ("let (neg0, xbytes) = case.get(\"x\").unwrap();")
out ("let (neg1, mbytes) = case.get(\"m\").unwrap();")
out ("let (neg2, zbytes) = case.get(\"z\").unwrap();")
out ("let (neg3, ebytes) = case.get(\"e\").unwrap();")
out ("let (neg4, obytes) = case.get(\"o\").unwrap();")
out ("let (neg5, rbytes) = case.get(\"r\").unwrap();")
out ("let (neg6, bbytes) = case.get(\"b\").unwrap();")
out ("let (neg7, tbytes) = case.get(\"t\").unwrap();")
out ("assert!(!neg0&&!neg1&&!neg2&&!neg3&&!neg4&&!neg5&&!neg6&&!neg7);")
out ("let mut x = "++name++"::from_bytes(xbytes);")
out ("let m = "++name++"::from_bytes(mbytes);")
out ("let z = 1 == zbytes[0];")
out ("let e = 1 == ebytes[0];")
out ("let o = 1 == obytes[0];")
out ("let r = "++name++"::from_bytes(rbytes);")
out ("let b = usize::from("++name++"::from_bytes(bbytes));")
out ("let t = 1 == tbytes[0];")
out ("assert_eq!(x.is_zero(), z);")
out ("assert_eq!(x.is_even(), e);")
out ("assert_eq!(x.is_odd(), o);")
out ("assert_eq!(x.testbit(b), t);")
out ("x.mask(usize::from(&m));")
out ("assert_eq!(x, r);")
out ("});")