Start experimenting with full generation of all of the numeric types.
Previously, we used a little bit of generation to drive a lot of Rust macros. This works, but it's a little confusing to read and write. In addition, we used a lot of implementations with variable timings based on their input, which isn't great for crypto. This is the start of an attempt to just generate all of the relevant Rust code directly, and to use timing-channel resistant implementations for most of the routines.
This commit is contained in:
73
generation/src/Gen.hs
Normal file
73
generation/src/Gen.hs
Normal file
@@ -0,0 +1,73 @@
|
||||
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
|
||||
module Gen(
|
||||
Gen(Gen),
|
||||
runGen,
|
||||
gensym,
|
||||
indent,
|
||||
blank,
|
||||
out,
|
||||
wrapIndent,
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Monad.RWS.Strict(RWS,evalRWS)
|
||||
import Control.Monad.State.Class(MonadState,get,put)
|
||||
import Control.Monad.Writer.Class(MonadWriter,tell)
|
||||
import Data.List(replicate)
|
||||
import Data.Word(Word)
|
||||
|
||||
newtype Gen a = Gen { unGen :: RWS () String GenState a}
|
||||
deriving (Applicative, Functor, Monad, MonadState GenState, MonadWriter String)
|
||||
|
||||
tabAmount :: Word
|
||||
tabAmount = 4
|
||||
|
||||
data GenState = GenState {
|
||||
indentAmount :: Word,
|
||||
gensymIndex :: Word
|
||||
}
|
||||
|
||||
initGenState :: GenState
|
||||
initGenState = GenState { indentAmount = 0, gensymIndex = 0 }
|
||||
|
||||
runGen :: FilePath -> Gen a -> IO a
|
||||
runGen path action =
|
||||
do let (res, contents) = evalRWS (unGen action) () initGenState
|
||||
writeFile path contents
|
||||
return res
|
||||
|
||||
gensym :: String -> Gen String
|
||||
gensym prefix =
|
||||
do gs <- get
|
||||
let gs' = gs{ gensymIndex = gensymIndex gs + 1 }
|
||||
put gs'
|
||||
return (prefix ++ show (gensymIndex gs))
|
||||
|
||||
indent :: Gen a -> Gen a
|
||||
indent action =
|
||||
do gs <- get
|
||||
put gs{ indentAmount = indentAmount gs + tabAmount }
|
||||
res <- action
|
||||
put gs
|
||||
return res
|
||||
|
||||
blank :: Gen ()
|
||||
blank = tell "\n"
|
||||
|
||||
out :: String -> Gen ()
|
||||
out val =
|
||||
do gs <- get
|
||||
tell (replicate (fromIntegral (indentAmount gs)) ' ')
|
||||
tell val
|
||||
tell "\n"
|
||||
|
||||
wrapIndent :: String -> Gen a -> Gen a
|
||||
wrapIndent val middle =
|
||||
do gs <- get
|
||||
tell (replicate (fromIntegral (indentAmount gs)) ' ')
|
||||
tell val
|
||||
tell " {\n"
|
||||
res <- indent middle
|
||||
tell (replicate (fromIntegral (indentAmount gs)) ' ')
|
||||
tell "}\n"
|
||||
return res
|
||||
32
generation/src/Main.hs
Normal file
32
generation/src/Main.hs
Normal file
@@ -0,0 +1,32 @@
|
||||
module Main
|
||||
where
|
||||
|
||||
import Control.Monad(forM_,unless)
|
||||
import Data.List(sort)
|
||||
import Data.Map.Strict(Map)
|
||||
import qualified Data.Map.Strict as Map
|
||||
import Gen(runGen)
|
||||
import Requirements(Requirement(..), Operation(..), requirements)
|
||||
import System.Directory(createDirectoryIfMissing)
|
||||
import System.Environment(getArgs)
|
||||
import System.Exit(die)
|
||||
import System.FilePath((</>))
|
||||
import UnsignedBase(declareBaseStructure)
|
||||
|
||||
gatherRequirements :: [Requirement] -> Map Int [Operation]
|
||||
gatherRequirements = foldr process Map.empty
|
||||
where process (Req x val) = Map.insertWith (++) x [val]
|
||||
|
||||
main :: IO ()
|
||||
main =
|
||||
do args <- getArgs
|
||||
unless (length args == 1) $
|
||||
die ("generation takes exactly one argument, the target directory")
|
||||
let reqs = sort (Map.toList (gatherRequirements requirements))
|
||||
target = head args
|
||||
forM_ reqs $ \ (size, opes) ->
|
||||
do let basedir = target </> "unsigned" </> ("u" ++ show size)
|
||||
createDirectoryIfMissing True basedir
|
||||
forM_ reqs $ \ (x, ops) ->
|
||||
do runGen (basedir </> "mod.rs") (declareBaseStructure size ops)
|
||||
|
||||
278
generation/src/Requirements.hs
Normal file
278
generation/src/Requirements.hs
Normal file
@@ -0,0 +1,278 @@
|
||||
module Requirements(
|
||||
Operation(..),
|
||||
Requirement(..),
|
||||
requirements
|
||||
)
|
||||
where
|
||||
|
||||
import Data.List(sort)
|
||||
|
||||
data Operation = Add
|
||||
| BaseOps
|
||||
| Barretts
|
||||
| Div
|
||||
| ModDiv
|
||||
| ModExp
|
||||
| ModMul
|
||||
| ModSq
|
||||
| Mul
|
||||
| Scale
|
||||
| Shifts
|
||||
| Square
|
||||
| Sub
|
||||
| Convert Int
|
||||
| SignedAdd
|
||||
| SignedBase
|
||||
| SignedCmp
|
||||
| SignedShift
|
||||
| SignedSub
|
||||
| SignedMul
|
||||
| SignedDiv
|
||||
| SignedModInv
|
||||
| SignedScale
|
||||
| SigConvert Int
|
||||
| SquareRoot
|
||||
| EGCD
|
||||
| ModInv
|
||||
| PrimeGen
|
||||
| RSA
|
||||
| DSA
|
||||
| ECDSA
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
data Requirement = Req Int Operation
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
data Need = Need Operation (Int -> [Requirement])
|
||||
|
||||
needs :: [Need]
|
||||
needs = [ Need RSA (\ size -> [Req (size `div` 2) Sub,
|
||||
Req (size `div` 2) Mul,
|
||||
Req (size `div` 2) PrimeGen,
|
||||
Req size BaseOps,
|
||||
Req size ModInv,
|
||||
Req size ModExp
|
||||
])
|
||||
, Need DSA (\ size -> [Req size BaseOps,
|
||||
Req size Shifts,
|
||||
Req size Add,
|
||||
Req size SquareRoot,
|
||||
Req size PrimeGen,
|
||||
Req size ModInv,
|
||||
Req size Mul,
|
||||
Req (size * 2) Add,
|
||||
Req (((size * 2) + 64) * 2) Div,
|
||||
Req size (Convert 512),
|
||||
Req size (Convert (size + 128)),
|
||||
Req size (Convert ((size * 2) + 64)),
|
||||
Req size (Convert (((size * 2) + 64) * 2))
|
||||
])
|
||||
, Need ECDSA (\ size -> [Req size SignedSub,
|
||||
Req size SignedMul,
|
||||
Req size ModMul,
|
||||
Req size ModDiv,
|
||||
Req (size * 2) BaseOps,
|
||||
Req (size * 2) SignedBase,
|
||||
Req (size * 2) SignedShift,
|
||||
Req (size * 2) SignedSub,
|
||||
Req (size * 2) SignedMul,
|
||||
Req (size * 2) SignedDiv,
|
||||
Req ((size * 2) + 64) SignedBase,
|
||||
Req ((size * 2) + 64) BaseOps,
|
||||
Req ((size * 2) + 64) SignedAdd,
|
||||
Req ((size * 2) + 64) SignedShift,
|
||||
Req ((size * 2) + 64) ModDiv,
|
||||
Req size (Convert (size * 2)),
|
||||
Req size (SigConvert (size * 2)),
|
||||
Req size (Convert ((size * 2) + 64)),
|
||||
Req size (SigConvert ((size * 2) + 64)),
|
||||
Req (size * 2) (Convert ((size * 2) + 64)),
|
||||
Req (size * 2) (SigConvert ((size * 2) + 64)),
|
||||
Req size (Convert (size * 4)),
|
||||
Req (size * 4) Div
|
||||
])
|
||||
, Need PrimeGen (\ size -> [Req size Div,
|
||||
Req size Shifts,
|
||||
Req size ModExp,
|
||||
Req size EGCD])
|
||||
, Need Add (\ size -> [Req size BaseOps,
|
||||
Req (size + 64) BaseOps,
|
||||
Req size (Convert (size + 64))
|
||||
])
|
||||
, Need Barretts (\ size -> [Req size BaseOps,
|
||||
Req (size + 64) BaseOps,
|
||||
Req (size * 2) BaseOps,
|
||||
Req ((size * 2) + 64) BaseOps,
|
||||
Req size (Convert ((size * 2) + 64)),
|
||||
Req (size + 64) Mul,
|
||||
Req ((size * 2) + 64) Add,
|
||||
Req ((size * 2) + 64) Sub,
|
||||
Req (size + 64) (Convert ((size * 2) + 64)),
|
||||
Req ((size * 2) + 64) (Convert ((size + 64) * 2)),
|
||||
Req (size * 2) (Convert ((size * 2) + 64)),
|
||||
Req (size + 64) (Convert ((size + 64) * 2)),
|
||||
Req (size + 64) (Convert (size * 2)),
|
||||
Req (size * 2) Shifts,
|
||||
Req ((size + 64) * 2) Shifts,
|
||||
Req ((size * 2) + 64) Div
|
||||
])
|
||||
, Need Div (\ size -> [Req size BaseOps,
|
||||
Req (size * 2) BaseOps,
|
||||
Req size (Convert (size * 2)),
|
||||
Req (size * 2) Sub,
|
||||
Req size Mul,
|
||||
Req 192 BaseOps,
|
||||
Req 192 Mul,
|
||||
Req 384 BaseOps
|
||||
])
|
||||
, Need ModExp (\ size -> [Req size BaseOps,
|
||||
Req size Barretts,
|
||||
Req size ModSq,
|
||||
Req size ModMul,
|
||||
Req size (Convert (size + 64))
|
||||
])
|
||||
, Need ModMul (\ size -> [Req size BaseOps,
|
||||
Req (size * 2) BaseOps,
|
||||
Req size Barretts,
|
||||
Req size Mul,
|
||||
Req size (Convert (size + 64))
|
||||
])
|
||||
, Need ModDiv (\ size -> [Req size ModInv,
|
||||
Req size SignedModInv,
|
||||
Req size SignedMul,
|
||||
Req size SignedDiv,
|
||||
Req (size * 2) SignedDiv,
|
||||
Req size (SigConvert (size * 2))
|
||||
])
|
||||
, Need ModSq (\ size -> [Req size BaseOps,
|
||||
Req (size * 2) BaseOps,
|
||||
Req size Barretts,
|
||||
Req size Square,
|
||||
Req (size * 2) Div,
|
||||
Req size (Convert (size * 2)),
|
||||
Req size (Convert (size + 64))
|
||||
])
|
||||
, Need Mul (\ size -> [Req size BaseOps,
|
||||
Req size Scale,
|
||||
Req (size * 2) BaseOps,
|
||||
Req size (Convert (size * 2))
|
||||
])
|
||||
, Need Scale (\ size -> [Req (size + 64) BaseOps])
|
||||
, Need Shifts (\ size -> [Req size BaseOps
|
||||
])
|
||||
, Need Square (\ size -> [Req size BaseOps,
|
||||
Req (size * 2) BaseOps
|
||||
])
|
||||
, Need Sub (\ size -> [Req size BaseOps
|
||||
])
|
||||
, Need SignedAdd (\ size -> [Req size SignedBase,
|
||||
Req size Add,
|
||||
Req size Sub,
|
||||
Req (size + 64) SignedBase,
|
||||
Req (size + 64) BaseOps
|
||||
])
|
||||
, Need SignedBase (\ size -> [Req size BaseOps])
|
||||
, Need SignedCmp (\ size -> [Req size BaseOps])
|
||||
, Need SignedShift (\ size -> [Req size SignedBase,
|
||||
Req size BaseOps,
|
||||
Req size Shifts,
|
||||
Req size Add
|
||||
])
|
||||
, Need SignedSub (\ size -> [Req size SignedBase,
|
||||
Req (size + 64) SignedBase,
|
||||
Req (size + 64) BaseOps,
|
||||
Req size Add,
|
||||
Req size Sub,
|
||||
Req (size + 64) Sub,
|
||||
Req size (Convert (size + 64)),
|
||||
Req size (SigConvert (size + 64))
|
||||
])
|
||||
, Need SignedMul (\ size -> [Req size Mul,
|
||||
Req size SignedScale,
|
||||
Req (size * 2) SignedBase,
|
||||
Req size (SigConvert (size * 2)),
|
||||
Req size Square
|
||||
])
|
||||
, Need SignedDiv (\ size -> [Req size Div,
|
||||
Req size Add
|
||||
])
|
||||
, Need EGCD (\ size -> [Req size SignedBase,
|
||||
Req size BaseOps,
|
||||
Req size Shifts,
|
||||
Req (size + 64) SignedBase,
|
||||
Req ((size + 64) * 2) SignedBase,
|
||||
Req size (SigConvert (size + 64)),
|
||||
Req (size + 64) SignedShift,
|
||||
Req (size + 64) SignedAdd,
|
||||
Req (size + 64) SignedSub,
|
||||
Req (size + 64) SignedCmp,
|
||||
Req (size + 64) SignedDiv,
|
||||
Req (size + 64) SignedMul,
|
||||
Req ((size + 64) * 2) SignedSub,
|
||||
Req (size + 64) (Convert (((size + 64) * 2) + 64)),
|
||||
Req (size + 64) (SigConvert (((size + 64) * 2) + 64))
|
||||
])
|
||||
, Need ModInv (\ size -> [Req size BaseOps,
|
||||
Req (size + 64) SignedBase,
|
||||
Req (size + 64) BaseOps,
|
||||
Req size (Convert (size + 64)),
|
||||
Req size EGCD,
|
||||
Req (size + 64) SignedAdd,
|
||||
Req size Barretts
|
||||
])
|
||||
, Need SignedModInv (\ size -> [
|
||||
Req size EGCD,
|
||||
Req size SignedModInv
|
||||
])
|
||||
, Need SquareRoot (\ size -> [Req size BaseOps,
|
||||
Req size Shifts,
|
||||
Req size Add,
|
||||
Req size Sub
|
||||
])
|
||||
]
|
||||
|
||||
newRequirements :: Requirement -> [Requirement]
|
||||
newRequirements (Req size op) = concatMap go needs ++ [Req size BaseOps]
|
||||
where
|
||||
go (Need op2 generator) | op == op2 = generator size
|
||||
| otherwise = []
|
||||
|
||||
rsaSizes :: [Int]
|
||||
rsaSizes = [512,1024,2048,3072,4096,8192,15360]
|
||||
|
||||
dsaSizes :: [Int]
|
||||
dsaSizes = [192,256,1024,2048,3072]
|
||||
|
||||
ecdsaSizes :: [Int]
|
||||
ecdsaSizes = [192,256,384,576]
|
||||
|
||||
baseRequirements :: [Requirement]
|
||||
baseRequirements = concatMap (\ x -> [Req x RSA]) rsaSizes
|
||||
++ concatMap (\ x -> [Req x DSA]) dsaSizes
|
||||
++ concatMap (\ x -> [Req x ECDSA]) ecdsaSizes
|
||||
++ [Req 192 (Convert 1024), Req 256 (Convert 2048), Req 256 (Convert 3072)] -- used in DSA
|
||||
++ [Req 384 (Convert 1024), Req 512 (Convert 2048), Req 512 (Convert 3072)] -- used in DSA
|
||||
++ [Req 192 Add, Req 256 Add, Req 384 Add] -- used for testing
|
||||
++ [Req 192 Mul, Req 384 Mul] -- used for testing
|
||||
++ [Req 448 (Convert 512)] -- used for testing
|
||||
|
||||
requirements :: [Requirement]
|
||||
requirements = go baseRequirements
|
||||
where
|
||||
step ls = let news = concatMap newRequirements ls
|
||||
ls' = concatMap sanitizeConverts (news ++ ls)
|
||||
ls'' = removeDups (sort ls')
|
||||
in ls''
|
||||
--
|
||||
go ls = let ls' = step ls
|
||||
in if ls == ls' then ls else go ls'
|
||||
--
|
||||
removeDups [] = []
|
||||
removeDups (x:xs) | x `elem` xs = removeDups xs
|
||||
| otherwise = x : removeDups xs
|
||||
--
|
||||
sanitizeConverts (Req x (Convert y))
|
||||
| x == y = []
|
||||
| x < y = [Req x (Convert y), Req y BaseOps]
|
||||
| otherwise = [Req y (Convert x), Req x BaseOps]
|
||||
sanitizeConverts x = [x]
|
||||
89
generation/src/UnsignedBase.hs
Normal file
89
generation/src/UnsignedBase.hs
Normal file
@@ -0,0 +1,89 @@
|
||||
module UnsignedBase(
|
||||
declareBaseStructure
|
||||
)
|
||||
where
|
||||
|
||||
import Control.Monad(forM_)
|
||||
import Gen
|
||||
import Requirements(Operation)
|
||||
|
||||
declareBaseStructure :: Int -> [Operation] -> Gen ()
|
||||
declareBaseStructure bitsize ops =
|
||||
do let name = "U" ++ show bitsize
|
||||
entries = bitsize `div` 64
|
||||
top = entries - 1
|
||||
out "use core::cmp::{Eq,Ordering,PartialEq,min};"
|
||||
out "use core::fmt;"
|
||||
out "use super::super::CryptoNum;"
|
||||
blank
|
||||
out "mod binary_ops;"
|
||||
blank
|
||||
wrapIndent ("pub struct " ++ name) $
|
||||
out ("value: [u64; " ++ show entries ++ "]")
|
||||
blank
|
||||
wrapIndent ("impl CryptoNum for " ++ name) $
|
||||
do wrapIndent ("fn zero() -> Self") $
|
||||
out (name ++ "{ value: [0; " ++ show entries ++ "] }")
|
||||
blank
|
||||
wrapIndent ("fn is_zero(&self) -> bool") $
|
||||
do forM_ (reverse [1..top]) $ \ i ->
|
||||
out ("self.value[" ++ show i ++ "] == 0 &&")
|
||||
out "self.value[0] == 0"
|
||||
blank
|
||||
wrapIndent ("fn is_even(&self) -> bool") $
|
||||
out "self.value[0] & 0x1 == 0"
|
||||
blank
|
||||
wrapIndent ("fn is_odd(&self) -> bool") $
|
||||
out "self.value[0] & 0x1 == 0"
|
||||
blank
|
||||
wrapIndent ("fn bit_length() -> usize") $
|
||||
out (show bitsize)
|
||||
blank
|
||||
wrapIndent ("fn mask(&mut self, len: usize)") $
|
||||
do out ("let dellen = min(len, " ++ show entries ++ ");")
|
||||
wrapIndent ("for i in dellen.." ++ show entries) $
|
||||
out ("self.value[i] = 0;")
|
||||
blank
|
||||
wrapIndent ("fn testbit(&self, bit: usize) -> bool") $
|
||||
do out "let idx = bit / 64;"
|
||||
out "let offset = bit % 64;"
|
||||
wrapIndent ("if idx >= " ++ show entries) $
|
||||
out "return false;"
|
||||
out "(self.value[idx] & (1u64 << offset)) != 0"
|
||||
blank
|
||||
wrapIndent ("impl PartialEq for " ++ name) $
|
||||
wrapIndent "fn eq(&self, other: &Self) -> bool" $
|
||||
do forM_ (reverse [1..top]) $ \ i ->
|
||||
out ("self.value[" ++ show i ++ "] == other.value[" ++ show i ++ "] && ")
|
||||
out "self.value[0] == other.value[0]"
|
||||
blank
|
||||
out ("impl Eq for " ++ name ++ " {}")
|
||||
blank
|
||||
wrapIndent ("impl Ord for " ++ name) $
|
||||
wrapIndent "fn cmp(&self, other: &Self) -> Ordering" $
|
||||
do out ("self.value[" ++ show top ++ "].cmp(&other.value[" ++ show top ++ "])")
|
||||
forM_ (reverse [0..top-1]) $ \ i ->
|
||||
out (" .then(self.value[" ++ show i ++ "].cmp(&other.value[" ++ show i ++ "]))")
|
||||
blank
|
||||
wrapIndent ("impl PartialOrd for " ++ name) $
|
||||
wrapIndent "fn partial_cmp(&self, other: &Self) -> Option<Ordering>" $
|
||||
out "Some(self.cmp(other))"
|
||||
blank
|
||||
wrapIndent ("impl fmt::Debug for " ++ name) $
|
||||
wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $
|
||||
do out ("f.debug_tuple(" ++ show name ++ ")")
|
||||
forM_ [0..top] $ \ i ->
|
||||
out (" .field(&self.value[" ++ show i ++ "])")
|
||||
out " .finish()"
|
||||
blank
|
||||
wrapIndent ("impl fmt::UpperHex for " ++ name) $
|
||||
wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $
|
||||
do forM_ (reverse [1..top]) $ \ i ->
|
||||
out ("write!(f, \"{:X}\", self.value[" ++ show i ++ "])?;")
|
||||
out "write!(f, \"{:X}\", self.value[0])"
|
||||
blank
|
||||
wrapIndent ("impl fmt::LowerHex for " ++ name) $
|
||||
wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $
|
||||
do forM_ (reverse [1..top]) $ \ i ->
|
||||
out ("write!(f, \"{:x}\", self.value[" ++ show i ++ "])?;")
|
||||
out "write!(f, \"{:x}\", self.value[0])"
|
||||
Reference in New Issue
Block a user