40 Commits

Author SHA1 Message Date
1ea75721fd Generate multiple modules, instead of one. 2020-04-26 19:54:22 -07:00
a622aa9cc9 Trying to limit some of the instructions we do in Karatsuba multiplication ... 2020-04-26 19:53:40 -07:00
9ee668daad Replace Variable Strings with Words. 2020-04-12 19:53:42 -07:00
0483bb8692 [CHECKPOINT] Adjust the Karatsuba implementation to abstract Variables from Strings 2020-04-12 19:52:00 -07:00
2baa5f070d Target generation, ideally, to those files we'll need for doing crypto. 2020-04-12 19:51:29 -07:00
f93aa7ffc3 A few small changes to try to make generation faster. 2020-03-02 15:20:33 -08:00
9c76d7e0b4 Switch to a more dynamic, time-based test generation scheme. 2020-03-01 16:56:17 -08:00
71451617f9 Generate one file per type, rather than ... quite a few. 2020-03-01 15:51:35 -08:00
b995c1705f Add some more (in some cases, temporary) ignores to .gitignore. 2020-03-01 13:16:45 -08:00
af983adf1a Define a ModInv trait, and clean up some lingering warnings. 2020-02-09 17:03:33 -08:00
2617609bf6 Fix conversions and modinv. 2020-02-09 17:12:23 -06:00
d8a2e66e7c Sort out signed shifts. 2020-02-07 16:59:15 -06:00
6dd32647b5 Fix signed addition and subtraction. 2020-02-07 16:03:49 -06:00
89c297525a Improve the state of building, moving towards addition working for signed numbers. 2020-01-23 16:09:11 -08:00
ac01aad415 Unnecessary concurrency and terminal silliness. 2020-01-18 10:33:07 -08:00
b3fcd4715e All the infrastructure to eventually to modinv. Don't try to use any of this yet. 2020-01-17 20:44:41 -08:00
e46cfe56d1 Support generating signed numbers. 2020-01-14 12:13:40 -10:00
4383b67c44 Unsigned modular operations. 2020-01-10 09:05:11 -10:00
4b8d0b3f09 Better support for conversions between number types. 2020-01-10 09:04:47 -10:00
8c5f18cb7c Commit the start of the work on modular X before I worry about From. 2020-01-08 15:19:34 -10:00
3e82008189 Clean up an import. 2020-01-08 08:55:07 -10:00
2888164814 Division / modulus! 2020-01-07 18:51:29 -10:00
c1d2922ab2 Fix number printing again. 2020-01-07 18:51:07 -10:00
00e59673f7 Support scaling values by primitive types. 2020-01-06 13:16:11 -08:00
a35d0df6da Multiply works! 2020-01-06 12:25:38 -08:00
d8c752fad3 Start working on generating multiplies via Karatsuba. 2019-12-30 20:38:28 -08:00
e5fa103db0 Support subtraction. 2019-11-27 16:23:52 -08:00
4d724a5a6e Fix printing. 2019-11-27 16:23:13 -08:00
14ccd6c2b9 add a project, since the language-rust extensions haven't been released. 2019-11-27 16:23:00 -08:00
cf80930854 Support addition. 2019-11-27 15:16:15 -08:00
e223950c9f Remove an unnecessary clone. 2019-11-23 21:05:23 -08:00
430401ba54 Fix right shift. 2019-11-23 20:59:21 -08:00
8c4b369911 Fix left shift. 2019-11-23 20:36:21 -08:00
ba587cb37f Start trying to generate shift code. 2019-11-04 17:08:16 -08:00
ed07a0855d Tests, that work! 2019-10-31 18:56:10 -04:00
c52dadcf22 Some commits in the way of cleaning up the Rust and generating module lists. 2019-10-31 16:39:00 -04:00
3b0bd25dfa Clean out the old testdata work. 2019-10-31 16:36:19 -04:00
0dec5815dc Fix a bunch of build errors. 2019-10-24 08:46:42 -07:00
620048bce6 Complete the shift over to language-rust. 2019-10-22 22:06:34 -07:00
2400b10fbc Start working on switching to language-rust as a generator, for fun. 2019-10-22 20:12:08 -07:00
924 changed files with 5180 additions and 3541389 deletions

2
.gitignore vendored
View File

@@ -2,6 +2,8 @@
**/*.rs.bk **/*.rs.bk
Cargo.lock Cargo.lock
.vscode/
generate.hi generate.hi
generate.o generate.o
generate generate

View File

@@ -1,2 +1,3 @@
.ghc.environment* .ghc.environment*
dist-newstyle/ dist-newstyle/
dist/

132
generation/Main.hs Normal file
View File

@@ -0,0 +1,132 @@
module Main
where
import Add(safeAddOps,unsafeAddOps,safeSignedAddOps,unsafeSignedAddOps)
import Base(base)
import BinaryOps(binaryOps)
import Compare(comparisons, signedComparisons)
import Control.Concurrent(forkFinally)
import Control.Concurrent.MVar(MVar, newEmptyMVar, newMVar, putMVar, takeMVar)
import Control.Monad(replicateM, void)
import Conversions(conversions, signedConversions)
import CryptoNum(cryptoNum)
import Control.Monad(forM_,unless)
import Data.List(nub)
import Data.Text.Lazy(Text, pack)
import Division(divisionOps)
import GHC.Conc(getNumCapabilities)
import ModInv(generateModInvOps)
import ModOps(modulusOps)
import Multiply(safeMultiplyOps, unsafeMultiplyOps)
import RustModule(RustModule(suggested),Task(..),generateTasks)
import Scale(safeScaleOps, unsafeScaleOps)
import Shift(shiftOps, signedShiftOps)
import Signed(signedBaseOps)
import Subtract(safeSubtractOps,unsafeSubtractOps,safeSignedSubtractOps,unsafeSignedSubtractOps)
import System.Directory(createDirectoryIfMissing)
import System.Environment(getArgs)
import System.Exit(die)
import System.FilePath(takeDirectory,(</>))
import System.IO(IOMode(..),withFile)
import System.ProgressBar(Label(..), Progress(..), ProgressBar, Timing, defStyle, newProgressBar, stylePrefix, updateProgress)
import System.Random(getStdGen)
rsaWordSizes :: [Word]
rsaWordSizes = [512, 1024, 2048, 3072, 4096, 8192, 15360]
dsaWordSizes :: [Word]
dsaWordSizes = [192, 256, 384, 1024, 2048, 3072]
ecdsaIntSizes :: [Word]
ecdsaIntSizes = [192, 256, 384, 576]
bitsizes :: [Word]
bitsizes = expandSizes initialSet
where
initialSet = nub (rsaWordSizes ++ dsaWordSizes ++ ecdsaIntSizes)
unsignedFiles :: [RustModule]
unsignedFiles = [
base
, binaryOps
, comparisons
, conversions
, cryptoNum
, divisionOps
, generateModInvOps
, modulusOps
, safeAddOps
, safeMultiplyOps
, safeScaleOps
, safeSubtractOps
, shiftOps
, unsafeAddOps
, unsafeMultiplyOps
, unsafeScaleOps
, unsafeSubtractOps
]
signedFiles :: [RustModule]
signedFiles = [
safeSignedAddOps
, safeSignedSubtractOps
, signedBaseOps
, signedComparisons
, signedConversions
, signedShiftOps
, unsafeSignedAddOps
, unsafeSignedSubtractOps
]
allFiles :: [RustModule]
allFiles = unsignedFiles ++ signedFiles
expandSizes :: [Word] -> [Word]
expandSizes ls = bigger
where
bigger = nub (ls ++ concatMap (\ f -> concatMap (\ x -> suggested f x) ls) allFiles)
printLast :: Progress String -> Timing -> Text
printLast prog _ = pack (progressCustom prog)
runThread :: ProgressBar String -> FilePath -> MVar [Task] -> IO (MVar ())
runThread pb outputPath mtaskls =
do res <- newEmptyMVar
void $ forkFinally step (threadDie res)
return res
where
step =
do tasks <- takeMVar mtaskls
case tasks of
[] ->
putMVar mtaskls []
task : rest ->
do putMVar mtaskls rest
let target = outputPath </> outputFile task
createDirectoryIfMissing True (takeDirectory target)
withFile target WriteMode $ \ targetHandle ->
writer task targetHandle
updateProgress pb (\ p -> p{ progressCustom = outputFile task,
progressDone = progressDone p + 1 })
step
threadDie resmv thrRes =
do case thrRes of
Left se -> putStrLn ("Thread died: " ++ show se)
Right () -> return ()
putMVar resmv ()
main :: IO ()
main =
do args <- getArgs
unless (length args == 1) $
die ("generation takes exactly one argument, the target directory")
g <- getStdGen
let style = defStyle{ stylePrefix = Label printLast }
allTasks = generateTasks g allFiles bitsizes
progress = Progress 0 total "starting"
total = length allTasks
pb <- newProgressBar style 60 progress
chan <- newMVar allTasks
count <- getNumCapabilities
threads <- replicateM count (runThread pb (head args) chan)
forM_ threads (\ m -> takeMVar m)

46
generation/Test.hs Normal file
View File

@@ -0,0 +1,46 @@
import Data.Bits hiding (bit)
import GHC.Integer.GMP.Internals
import qualified Karatsuba
import Numeric
import Test.QuickCheck
modular_exponentiation :: Integer -> Integer -> Integer -> Integer
modular_exponentiation x y m = m_e_loop x y 1
where
m_e_loop _ 0 result = result
m_e_loop b e result = m_e_loop b' e' result'
where
b' = (b * b) `mod` m
e' = e `shiftR` 1
result' = if testBit e 0 then (result * b) `mod` m else result
prop_modExpSane :: Integer -> Integer -> Integer -> Property
prop_modExpSane b e m = (m' > 1) ==> modular_exponentiation b' e' m' == powModInteger b' e' m'
where
b' = abs b
e' = abs e
m' = abs m
modexpLR :: Int -> Integer -> Integer -> Integer -> Integer
modexpLR bitsize b e m = go (bitsize - 1) 1
where
go bit r0
| bit < 0 = r0
| testBit e bit = go (bit - 1) r2
| otherwise = go (bit - 1) r1
where
r1 = (r0 * r0) `mod` m
r2 = (r1 * b) `mod` m
prop_modExpLR192 :: Integer -> Integer -> Integer -> Property
prop_modExpLR192 b e m = (m' > 1) ==> modexpLR 192 b' e' m' == powModInteger b' e' m'
where
b' = abs b `mod` (2 ^ (192 :: Integer))
e' = abs e `mod` (2 ^ (192 :: Integer))
m' = abs m `mod` (2 ^ (192 :: Integer))
main :: IO ()
main =
do Karatsuba.runChecks
Karatsuba.runQuickCheck "Modular exponentiation sanity check" prop_modExpSane
Karatsuba.runQuickCheck "ModExp LR 192 works" prop_modExpLR192

2
generation/cabal.project Normal file
View File

@@ -0,0 +1,2 @@
packages: ./
../../language-rust/

View File

@@ -2,7 +2,6 @@ cabal-version: 2.0
-- Initial package description 'generation.cabal' generated by 'cabal -- Initial package description 'generation.cabal' generated by 'cabal
-- init'. For further documentation, see -- init'. For further documentation, see
-- http://haskell.org/cabal/users-guide/ -- http://haskell.org/cabal/users-guide/
name: generation name: generation
version: 0.1.0.0 version: 0.1.0.0
synopsis: Generates the cryptonum Rust library, based on requirements. synopsis: Generates the cryptonum Rust library, based on requirements.
@@ -16,16 +15,48 @@ category: Math
build-type: Simple build-type: Simple
extra-source-files: CHANGELOG.md extra-source-files: CHANGELOG.md
executable generation library
main-is: Main.hs default-language: Haskell2010
other-modules: Base, BinaryOps, Compare, Conversions, CryptoNum, File, Gen, Testing ghc-options: -Wall
-- other-extensions: build-depends: base,
build-depends: base ^>=4.12.0.0,
containers, containers,
directory, directory,
filepath, filepath,
integer-gmp,
language-rust,
largeword,
mtl, mtl,
random QuickCheck,
random,
vector
hs-source-dirs: src hs-source-dirs: src
exposed-modules: Add,
Base,
BinaryOps,
Compare,
Conversions,
CryptoNum,
Division,
Generators,
Karatsuba,
ModInv,
ModOps,
Multiply,
RustModule,
Scale,
Shift,
Signed,
Subtract
executable generation
main-is: Main.hs
default-language: Haskell2010 default-language: Haskell2010
ghc-options: -Wall -threaded -with-rtsopts=-N
build-depends: base, directory, filepath, generation, random, terminal-progress-bar, text
test-suite test-generation
type: exitcode-stdio-1.0
default-language: Haskell2010
main-is: Test.hs
ghc-options: -Wall ghc-options: -Wall
build-depends: base, generation, integer-gmp, QuickCheck

405
generation/src/Add.hs Normal file
View File

@@ -0,0 +1,405 @@
{-# LANGUAGE QuasiQuotes #-}
module Add(
safeAddOps
, unsafeAddOps
, safeSignedAddOps
, unsafeSignedAddOps
)
where
import Data.Bits((.&.))
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Generators
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
import System.Random(RandomGen)
safeAddOps :: RustModule
safeAddOps = RustModule {
predicate = \ me others -> (me + 64) `elem` others,
suggested = \ me -> [me + 64],
outputName = "safe_add",
isUnsigned = True,
generator = declareSafeAddOperators,
testCase = Just generateSafeTest
}
unsafeAddOps :: RustModule
unsafeAddOps = RustModule {
predicate = \ _ _ -> True,
suggested = const [],
outputName = "unsafe_add",
isUnsigned = True,
generator = declareUnsafeAddOperators,
testCase = Just generateUnsafeTest
}
safeSignedAddOps :: RustModule
safeSignedAddOps = RustModule {
predicate = \ me others -> (me + 64) `elem` others,
suggested = \ me -> [me + 64],
outputName = "safe_sadd",
isUnsigned = False,
generator = declareSafeSignedAddOperators,
testCase = Just generateSafeSignedTest
}
unsafeSignedAddOps :: RustModule
unsafeSignedAddOps = RustModule {
predicate = \ _ _ -> True,
suggested = const [],
outputName = "unsafe_sadd",
isUnsigned = False,
generator = declareUnsafeSignedAddOperators,
testCase = Just generateUnsafeSignedTest
}
declareSafeAddOperators :: Word -> [Word] -> SourceFile Span
declareSafeAddOperators bitsize _ =
let sname = mkIdent ("U" ++ show bitsize)
dname = mkIdent ("U" ++ show (bitsize + 64))
fullRippleAdd = makeRippleAdder True (bitsize `div` 64) "res"
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::Add;
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
#[cfg(test)]
use quickcheck::quickcheck;
use crate::unsigned::{$$sname,$$dname};
impl Add for $$sname {
type Output = $$dname;
fn add(self, rhs: $$sname) -> $$dname {
&self + &rhs
}
}
impl<'a> Add<&'a $$sname> for $$sname {
type Output = $$dname;
fn add(self, rhs: &$$sname) -> $$dname {
&self + rhs
}
}
impl<'a> Add<$$sname> for &'a $$sname {
type Output = $$dname;
fn add(self, rhs: $$sname) -> $$dname {
self + &rhs
}
}
impl<'a,'b> Add<&'a $$sname> for &'b $$sname {
type Output = $$dname;
fn add(self, rhs: &$$sname) -> $$dname {
let mut res = $$dname::zero();
$@{fullRippleAdd}
res
}
}
#[cfg(test)]
quickcheck! {
fn addition_symmetric(a: $$sname, b: $$sname) -> bool {
(&a + &b) == (&b + &a)
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("safe_add", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
assert!(!neg0 && !neg1 && !neg2);
let x = $$sname::from_bytes(&xbytes);
let y = $$sname::from_bytes(&ybytes);
let z = $$dname::from_bytes(&zbytes);
assert_eq!(z, x + y);
});
}
|]
declareUnsafeAddOperators :: Word -> [Word] -> SourceFile Span
declareUnsafeAddOperators bitsize _ =
let sname = mkIdent ("U" ++ show bitsize)
fullRippleAdd = makeRippleAdder False (bitsize `div` 64) "self"
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::AddAssign;
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
#[cfg(test)]
use quickcheck::quickcheck;
use crate::unsigned::$$sname;
impl AddAssign for $$sname {
fn add_assign(&mut self, rhs: Self) {
self.add_assign(&rhs);
}
}
impl<'a> AddAssign<&'a $$sname> for $$sname {
fn add_assign(&mut self, rhs: &Self) {
$@{fullRippleAdd}
}
}
#[cfg(test)]
quickcheck! {
fn addition_symmetric(a: $$sname, b: $$sname) -> bool {
let mut side1 = a.clone();
let mut side2 = b.clone();
side1 += b;
side2 += a;
side1 == side2
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("unsafe_add", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
assert!(!neg0 && !neg1 && !neg2);
let mut x = $$sname::from_bytes(&xbytes);
let y = $$sname::from_bytes(&ybytes);
let z = $$sname::from_bytes(&zbytes);
x += &y;
assert_eq!(z, x);
});
}
|]
declareSafeSignedAddOperators :: Word -> [Word] -> SourceFile Span
declareSafeSignedAddOperators bitsize _ =
let sname = mkIdent ("I" ++ show bitsize)
dname = mkIdent ("I" ++ show (bitsize + 64))
testFileLit = Lit [] (Str (testFile False bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::Add;
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
#[cfg(test)]
use quickcheck::quickcheck;
use crate::signed::{$$sname,$$dname};
impl Add for $$sname {
type Output = $$dname;
fn add(self, rhs: $$sname) -> $$dname {
&self + &rhs
}
}
impl<'a> Add<&'a $$sname> for $$sname {
type Output = $$dname;
fn add(self, rhs: &$$sname) -> $$dname {
&self + rhs
}
}
impl<'a> Add<$$sname> for &'a $$sname {
type Output = $$dname;
fn add(self, rhs: $$sname) -> $$dname {
self + &rhs
}
}
impl<'a,'b> Add<&'a $$sname> for &'b $$sname {
type Output = $$dname;
fn add(self, rhs: &$$sname) -> $$dname {
let mut res = $$dname::from(self);
let bigrhs = $$dname::from(rhs);
res += bigrhs;
res
}
}
#[cfg(test)]
quickcheck! {
fn addition_symmetric(a: $$sname, b: $$sname) -> bool {
(&a + &b) == (&b + &a)
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("safe_sadd", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
let mut x = $$sname::from_bytes(&xbytes);
let mut y = $$sname::from_bytes(&ybytes);
let mut z = $$dname::from_bytes(&zbytes);
if *neg0 { x = -x }
if *neg1 { y = -y }
if *neg2 { z = -z }
assert_eq!(z, x + y);
});
}
|]
declareUnsafeSignedAddOperators :: Word -> [Word] -> SourceFile Span
declareUnsafeSignedAddOperators bitsize _ =
let sname = mkIdent ("I" ++ show bitsize)
testFileLit = Lit [] (Str (testFile False bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::AddAssign;
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
#[cfg(test)]
use quickcheck::quickcheck;
use crate::signed::$$sname;
impl AddAssign for $$sname {
fn add_assign(&mut self, rhs: Self) {
self.contents += rhs.contents;
}
}
impl<'a> AddAssign<&'a $$sname> for $$sname {
fn add_assign(&mut self, rhs: &Self) {
self.contents += &rhs.contents;
}
}
#[cfg(test)]
quickcheck! {
fn addition_symmetric(a: $$sname, b: $$sname) -> bool {
let mut side1 = a.clone();
let mut side2 = b.clone();
side1 += b;
side2 += a;
side1 == side2
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("unsafe_sadd", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
let mut x = $$sname::from_bytes(&xbytes);
let mut y = $$sname::from_bytes(&ybytes);
let mut z = $$sname::from_bytes(&zbytes);
if *neg0 { x = -x }
if *neg1 { y = -y }
if *neg2 { z = -z }
x += &y;
assert_eq!(z, x);
});
}
|]
makeRippleAdder :: Bool -> Word -> String -> [Stmt Span]
makeRippleAdder useLastCarry inElems resName =
concatMap (generateRipples useLastCarry (inElems - 1)) [0..inElems-1] ++
concatMap (generateSetters useLastCarry inElems resName) [0..inElems]
generateRipples :: Bool -> Word -> Word -> [Stmt Span]
generateRipples useLastCarry lastI i =
let sumi = mkIdent ("sum" ++ show i)
inCarry = mkIdent ("carry" ++ show (i - 1))
outCarry = mkIdent ("carry" ++ show i)
res = mkIdent ("res" ++ show i)
liti = toLit i
left = mkIdent ("left" ++ show i)
right = mkIdent ("right" ++ show i)
in [
[stmt|let $$left = self.value[$$(liti)] as u128; |]
, [stmt|let $$right = rhs.value[$$(liti)] as u128; |]
, if i == 0
then [stmt| let $$sumi = $$left + $$right; |]
else [stmt| let $$sumi = $$left + $$right + $$inCarry; |]
, [stmt|let $$res = $$sumi as u64; |]
] ++
if not useLastCarry && (i == lastI)
then []
else [[stmt|let $$outCarry = $$sumi >> 64; |]]
generateSetters :: Bool -> Word -> String -> Word -> [Stmt Span]
generateSetters useLastCarry maxI resName i
| not useLastCarry && (maxI == i) = []
| maxI == i =
let res = mkIdent ("carry" ++ show (i - 1))
liti = toLit i
in [[stmt| $$target.value[$$(liti)] = $$res as u64; |]]
| otherwise =
let res = mkIdent ("res" ++ show i)
liti = toLit i
in [[stmt| $$target.value[$$(liti)] = $$res; |]]
where
target = mkIdent resName
generateSafeTest :: RandomGen g => Word -> g -> (Map String String, g)
generateSafeTest size g0 = (tcase, g2)
where
(x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size
tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX (x + y))]
generateUnsafeTest :: RandomGen g => Word -> g -> (Map String String, g)
generateUnsafeTest size g0 = (tcase, g2)
where
(x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size
z = (x + y) .&. ((2 ^ size) - 1)
tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX z)]
generateSafeSignedTest :: RandomGen g => Word -> g -> (Map String String, g)
generateSafeSignedTest size g0 = (tcase, g2)
where
(x, g1) = generateSignedNum g0 size
(y, g2) = generateSignedNum g1 size
tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX (x + y))]
generateUnsafeSignedTest :: RandomGen g => Word -> g -> (Map String String, g)
generateUnsafeSignedTest size g0 = (tcase, g2)
where
(x, g1) = generateSignedNum g0 size
(y, g2) = generateSignedNum g1 size
z = (x + y) .&. ((2 ^ size) - 1)
tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX z)]

View File

@@ -1,57 +1,91 @@
{-# LANGUAGE QuasiQuotes #-}
module Base( module Base(
base base
) )
where where
import Control.Monad(forM_) import Language.Rust.Data.Ident
import File import Language.Rust.Data.Position
import Gen import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
base :: File base :: RustModule
base = File { base = RustModule {
predicate = \ _ _ -> True, predicate = \ _ _ -> True,
suggested = const [],
outputName = "base", outputName = "base",
isUnsigned = True,
generator = declareBaseStructure, generator = declareBaseStructure,
testGenerator = Nothing testCase = Nothing
} }
declareBaseStructure :: Word -> Gen () declareBaseStructure :: Word -> [Word] -> SourceFile Span
declareBaseStructure bitsize = declareBaseStructure bitsize _ =
do let name = "U" ++ show bitsize let tname = "U" ++ show bitsize
entries = bitsize `div` 64 entries = bitsize `div` 64
top = entries - 1 sname = mkIdent tname
out "use core::fmt;" entriese = Lit [] (Int Dec (fromIntegral entries) Unsuffixed mempty) mempty
out "use quickcheck::{Arbitrary,Gen};" strname = Lit [] (Str tname Cooked Unsuffixed mempty) mempty
blank debugExp = buildDebugExp 0 entries [expr| f.debug_tuple($$(strname)) |]
out "#[derive(Clone)]" lowerPrints = buildPrints entries "x"
wrapIndent ("pub struct " ++ name) $ upperPrints = buildPrints entries "X"
out ("pub(crate) value: [u64; " ++ show entries ++ "]") in [sourceFile|
blank use core::fmt;
implFor "fmt::Debug" name $ use quickcheck::{Arbitrary,Gen};
wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $
do out ("f.debug_tuple(" ++ show name ++ ")") #[derive(Clone)]
forM_ [0..top] $ \ i -> pub struct $$sname {
out (" .field(&self.value[" ++ show i ++ "])") pub(crate) value: [u64; $$(entriese)]
out " .finish()" }
blank
implFor "fmt::UpperHex" name $ impl fmt::Debug for $$sname {
wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
do forM_ (reverse [1..top]) $ \ i -> $$(debugExp).finish()
out ("write!(f, \"{:X}\", self.value[" ++ show i ++ "])?;") }
out "write!(f, \"{:X}\", self.value[0])" }
blank
implFor "fmt::LowerHex" name $ impl fmt::UpperHex for $$sname {
wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
do forM_ (reverse [1..top]) $ \ i -> $@{upperPrints}
out ("write!(f, \"{:x}\", self.value[" ++ show i ++ "])?;") write!(f, "{:016X}", self.value[0])
out "write!(f, \"{:x}\", self.value[0])" }
blank }
implFor "Arbitrary" name $
wrapIndent "fn arbitrary<G: Gen>(g: &mut G) -> Self" $ impl fmt::LowerHex for $$sname {
do out (name ++ " {") fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
indent $ $@{lowerPrints}
do out ("value: [") write!(f, "{:016x}", self.value[0])
indent $ forM_ [0..top] $ \ _ -> }
out ("g.next_u64(),") }
out ("]")
out ("}") impl Arbitrary for $$sname {
fn arbitrary<G: Gen>(g: &mut G) -> Self {
let mut res = $$sname{ value: [0; $$(entriese)] };
for entry in res.value.iter_mut() {
*entry = g.next_u64();
}
res
}
}
|]
buildDebugExp :: Word -> Word -> Expr Span -> Expr Span
buildDebugExp i top acc
| i == top = acc
| otherwise =
let liti = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty
in buildDebugExp (i + 1) top [expr| $$(acc).field(&self.value[$$(liti)]) |]
buildPrints :: Word -> String -> [Stmt Span]
buildPrints entries printer = go (entries - 1)
where
litStr = Token mempty (LiteralTok (StrTok ("{:016" ++ printer ++ "}")) Nothing)
--Lit [] (Str ("{:" ++ printer ++ "}") Cooked Unsuffixed mempty) mempty
go 0 = []
go x =
let rest = go (x - 1)
curi = Token mempty (LiteralTok (IntegerTok (show x)) Nothing)
-- Lit [] (Int Dec (fromIntegral x) Unsuffixed mempty) mempty
cur = [stmt| write!(f, $$(litStr), self.value[$$(curi)])?; |]
in cur : rest

View File

@@ -1,171 +1,249 @@
{-# LANGUAGE QuasiQuotes #-}
module BinaryOps( module BinaryOps(
binaryOps binaryOps
) )
where where
import Control.Monad(forM_,replicateM_) import Data.Bits(xor,(.&.),(.|.))
import Data.Bits((.&.), (.|.), shiftL, xor) import Data.Map.Strict(Map)
import File import qualified Data.Map.Strict as Map
import Gen import Generators
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
import System.Random(RandomGen)
binaryTestCount :: Int binaryOps :: RustModule
binaryTestCount = 3000 binaryOps = RustModule {
binaryOps :: File
binaryOps = File {
predicate = \ _ _ -> True, predicate = \ _ _ -> True,
suggested = const [],
outputName = "binary", outputName = "binary",
isUnsigned = True,
generator = declareBinaryOperators, generator = declareBinaryOperators,
testGenerator = Just testVectors testCase = Just generateTest
} }
declareBinaryOperators :: Word -> Gen () declareBinaryOperators :: Word -> [Word] -> SourceFile Span
declareBinaryOperators bitsize = declareBinaryOperators bitsize _ =
do let name = "U" ++ show bitsize let struct_name = mkIdent ("U" ++ show bitsize)
entries = bitsize `div` 64 entries = bitsize `div` 64
out "use core::ops::{BitAnd,BitAndAssign};" andOps = generateBinOps "BitAnd" struct_name "bitand" BitAndOp entries
out "use core::ops::{BitOr,BitOrAssign};" orOps = generateBinOps "BitOr" struct_name "bitor" BitOrOp entries
out "use core::ops::{BitXor,BitXorAssign};" xorOps = generateBinOps "BitXor" struct_name "bitxor" BitXorOp entries
out "use core::ops::Not;" baseNegationStmts = negationStatements "self" entries
out "#[cfg(test)]" refNegationStmts = negationStatements "output" entries
out "use crate::CryptoNum;" testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
out "#[cfg(test)]" in [sourceFile|
out "use quickcheck::quickcheck;" use core::ops::{BitAnd,BitAndAssign};
out ("use super::U" ++ show bitsize ++ ";") use core::ops::{BitOr,BitOrAssign};
blank use core::ops::{BitXor,BitXorAssign};
generateBinOps "BitAnd" name "bitand" "&=" entries use core::ops::Not;
blank #[cfg(test)]
generateBinOps "BitOr" name "bitor" "|=" entries use crate::CryptoNum;
blank #[cfg(test)]
generateBinOps "BitXor" name "bitxor" "^=" entries use crate::testing::{build_test_path,run_test};
blank #[cfg(test)]
implFor "Not" name $ use quickcheck::quickcheck;
do out "type Output = Self;" use super::$$struct_name;
blank
wrapIndent "fn not(mut self) -> Self" $
do forM_ [0..entries-1] $ \ i ->
out ("self.value[" ++ show i ++ "] = !self.value[" ++ show i ++ "];")
out "self"
blank
implFor' "Not" ("&'a " ++ name) $
do out ("type Output = " ++ name ++ ";")
blank
wrapIndent ("fn not(self) -> " ++ name) $
do out "let mut output = self.clone();"
forM_ [0..entries-1] $ \ i ->
out ("output.value[" ++ show i ++ "] = !self.value[" ++ show i ++ "];")
out "output"
blank
addBinaryLaws name entries
generateBinOps :: String -> String -> String -> String -> Word -> Gen () $@{andOps}
generateBinOps trait name fun op entries = $@{orOps}
do implFor (trait ++ "Assign") name $ $@{xorOps}
wrapIndent ("fn " ++ fun ++ "_assign(&mut self, rhs: Self)") $
forM_ [0..entries-1] $ \ i ->
out ("self.value[" ++ show i ++ "] "++op++" rhs.value[" ++ show i ++ "];")
blank
implFor' (trait ++ "Assign<&'a " ++ name ++ ">") name $
wrapIndent ("fn " ++ fun ++ "_assign(&mut self, rhs: &Self)") $
forM_ [0..entries-1] $ \ i ->
out ("self.value[" ++ show i ++ "] "++op++" rhs.value[" ++ show i ++ "];")
blank
generateBinOpsFromAssigns trait name fun op
generateBinOpsFromAssigns :: String -> String -> String -> String -> Gen () impl Not for $$struct_name {
generateBinOpsFromAssigns trait name fun op = type Output = Self;
do implFor trait name $
do out "type Output = Self;"
blank
wrapIndent ("fn " ++ fun ++ "(mut self, rhs: Self) -> Self") $
do out ("self " ++ op ++ " rhs;")
out "self"
blank
implFor' (trait ++ "<&'a " ++ name ++ ">") name $
do out "type Output = Self;"
blank
wrapIndent ("fn " ++ fun ++ "(mut self, rhs: &Self) -> Self") $
do out ("self " ++ op ++ " rhs;")
out "self"
blank
implFor' (trait ++ "<" ++ name ++ ">") ("&'a " ++ name) $
do out ("type Output = " ++ name ++ ";")
blank
wrapIndent ("fn " ++ fun ++ "(self, mut rhs: " ++ name ++ ") -> " ++ name) $
do out ("rhs " ++ op ++ " self;")
out "rhs"
blank
implFor'' (trait ++ "<&'a " ++ name ++ ">") ("&'b " ++ name) $
do out ("type Output = " ++ name ++ ";")
blank
wrapIndent ("fn " ++ fun ++ "(self, rhs: &" ++ name ++ ") -> " ++ name) $
do out "let mut output = self.clone();"
out ("output " ++ op ++ " rhs;")
out "output"
addBinaryLaws :: String -> Word -> Gen () fn not(mut self) -> Self {
addBinaryLaws name entries = $@{baseNegationStmts}
do let args3 = "(a: " ++ name ++ ", b: " ++ name ++ ", c: " ++ name ++ ")" self
args2 = "(a: " ++ name ++ ", b: " ++ name ++ ")" }
out "#[cfg(test)]" }
wrapIndent "quickcheck!" $
do wrapIndent ("fn and_associative" ++ args3 ++ " -> bool") $
out ("((&a & &b) & &c) == (&a & (&b & &c))")
blank
wrapIndent ("fn and_commutative" ++ args2 ++ " -> bool") $
out ("(&a & &b) == (&b & &a)")
blank
wrapIndent ("fn and_idempotent" ++ args2 ++ " -> bool") $
out ("(&a & &b) == (&a & &b & &a)")
blank
wrapIndent ("fn xor_associative" ++ args3 ++ " -> bool") $
out ("((&a ^ &b) ^ &c) == (&a ^ (&b ^ &c))")
blank
wrapIndent ("fn xor_commutative" ++ args2 ++ " -> bool") $
out ("(&a ^ &b) == (&b ^ &a)")
blank
wrapIndent ("fn or_associative" ++ args3 ++ " -> bool") $
out ("((&a | &b) | &c) == (&a | (&b | &c))")
blank
wrapIndent ("fn or_commutative" ++ args2 ++ " -> bool") $
out ("(&a | &b) == (&b | &a)")
blank
wrapIndent ("fn or_idempotent" ++ args2 ++ " -> bool") $
out ("(&a | &b) == (&a | &b | &a)")
blank
wrapIndent ("fn and_or_distribution" ++ args3 ++ "-> bool") $
out ("(&a & (&b | &c)) == ((&a & &b) | (&a & &c))")
blank
wrapIndent ("fn xor_clears(a: " ++ name ++ ") -> bool") $
out (name ++ "::zero() == (&a ^ &a)")
blank
wrapIndent ("fn double_neg_ident(a: " ++ name ++ ") -> bool") $
out ("a == !!&a")
blank
wrapIndent ("fn and_ident(a: " ++ name ++ ") -> bool") $
do out ("let ones = !" ++ name ++ "::zero();")
out ("(&a & &ones) == a")
blank
wrapIndent ("fn or_ident(a: " ++ name ++ ") -> bool") $
out ("(&a | " ++ name ++ "::zero()) == a")
wrapIndent ("fn neg_as_xor(a: " ++ name ++ ") -> bool") $
do out ("let ones = " ++ name ++ "{ value: [0xFFFFFFFFFFFFFFFFu64; "
++ show entries ++ "] };")
out ("!&a == (&ones ^ &a)")
impl<'a> Not for &'a $$struct_name {
type Output = $$struct_name;
testVectors :: Word -> Gen () fn not(self) -> Self::Output {
testVectors bitsize = replicateM_ binaryTestCount $ let mut output = self.clone();
do a <- newNum False bitsize $@{refNegationStmts}
b <- newNum False bitsize output
let o = a .|. b }
c = a .&. b }
n = a `xor` ((1 `shiftL` fromIntegral bitsize) - 1)
x = a `xor` b
emitTestVariable 'a' a
emitTestVariable 'b' b
emitTestVariable 'c' c
emitTestVariable 'o' o
emitTestVariable 'n' n
emitTestVariable 'x' x
#[cfg(test)]
quickcheck! {
fn and_associative(a: $$struct_name, b: $$struct_name, c: $$struct_name) -> bool {
((&a & &b) & &c) == (&a & (&b & &c))
}
fn and_commutative(a: $$struct_name, b: $$struct_name) -> bool {
(&a & &b) == (&b & &a)
}
fn and_idempotent(a: $$struct_name, b: $$struct_name) -> bool {
(&a & &b) == (&a & &b & &a)
}
fn xor_associative(a: $$struct_name, b: $$struct_name, c: $$struct_name) -> bool {
((&a ^ &b) ^ &c) == (&a ^ (&b ^ &c))
}
fn xor_commutative(a: $$struct_name, b: $$struct_name) -> bool {
(&a ^ &b) == (&b ^ &a)
}
fn or_associative(a: $$struct_name, b: $$struct_name, c: $$struct_name) -> bool {
((&a | &b) | &c) == (&a | (&b | &c))
}
fn or_commutative(a: $$struct_name, b: $$struct_name) -> bool {
(&a | &b) == (&b | &a)
}
fn or_idempotent(a: $$struct_name, b: $$struct_name) -> bool {
(&a | &b) == (&a | &b | &a)
}
fn and_or_distribution(a: $$struct_name, b: $$struct_name, c: $$struct_name) -> bool {
(&a & (&b | &c)) == ((&a & &b) | (&a & &c))
}
fn xor_clears(a: $$struct_name) -> bool {
$$struct_name::zero() == (&a ^ &a)
}
fn double_neg_ident(a: $$struct_name) -> bool {
a == !!&a
}
fn and_ident(a: $$struct_name) -> bool {
let ones = !$$struct_name::zero();
(&a & &ones) == a
}
fn or_ident(a: $$struct_name) -> bool {
(&a | $$struct_name::zero()) == a
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("binary", $$(testFileLit)), 6, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, abytes) = case.get("a").unwrap();
let (neg3, obytes) = case.get("o").unwrap();
let (neg4, ebytes) = case.get("e").unwrap();
let (neg5, nbytes) = case.get("n").unwrap();
assert!(!neg0 && !neg1 && !neg2 && !neg3 && !neg4 && !neg5);
let x = $$struct_name::from_bytes(&xbytes);
let y = $$struct_name::from_bytes(&ybytes);
let a = $$struct_name::from_bytes(&abytes);
let o = $$struct_name::from_bytes(&obytes);
let e = $$struct_name::from_bytes(&ebytes);
let n = $$struct_name::from_bytes(&nbytes);
assert_eq!(a, &x & &y);
assert_eq!(o, &x | &y);
assert_eq!(e, &x ^ &y);
assert_eq!(n, !x);
});
}
|]
negationStatements :: String -> Word -> [Stmt Span]
negationStatements target entries = map genStatement [0..entries-1]
where
genStatement i =
let idx = toLit i
v = mkIdent target
in [stmt| $$v.value[$$(idx)] = !self.value[$$(idx)]; |]
generateBinOps :: String -> Ident -> String -> BinOp -> Word -> [Item Span]
generateBinOps trait sname func oper entries =
[normAssign, refAssign] ++
generateAllTheVariants traitIdent funcIdent sname oper entries
where
traitIdent = mkIdent trait
assignIdent = mkIdent (trait ++ "Assign")
funcIdent = mkIdent func
funcAssignIdent = mkIdent (func ++ "_assign")
--
normAssign = [item|
impl $$assignIdent for $$sname {
fn $$funcAssignIdent(&mut self, rhs: Self) {
$@{assignStatements}
}
}
|]
refAssign = [item|
impl<'a> $$assignIdent<&'a $$sname> for $$sname {
fn $$funcAssignIdent(&mut self, rhs: &Self) {
$@{assignStatements}
}
}
|]
--
assignStatements :: [Stmt Span]
assignStatements = map genAssign [0..entries-1]
genAssign i =
let idx = toLit i
left = [expr| self.value[$$(idx)] |]
right = [expr| rhs.value[$$(idx)] |]
in Semi (AssignOp [] oper left right mempty) mempty
generateAllTheVariants :: Ident -> Ident -> Ident -> BinOp -> Word -> [Item Span]
generateAllTheVariants traitname func sname oper entries = [
[item|
impl $$traitname for $$sname {
type Output = $$sname;
fn $$func(mut self, rhs: $$sname) -> Self::Output {
$@{assigners_self_rhs}
self
}
}|]
, [item|
impl<'a> $$traitname<&'a $$sname> for $$sname {
type Output = $$sname;
fn $$func(mut self, rhs: &$$sname) -> Self::Output {
$@{assigners_self_rhs}
self
}
}|]
, [item|
impl<'a> $$traitname<$$sname> for &'a $$sname {
type Output = $$sname;
fn $$func(self, mut rhs: $$sname) -> Self::Output {
$@{assigners_rhs_self}
rhs
}
}|]
, [item|
impl<'a,'b> $$traitname<&'a $$sname> for &'b $$sname {
type Output = $$sname;
fn $$func(self, rhs: &$$sname) -> Self::Output {
let mut out = self.clone();
$@{assigners_out_rhs}
out
}
}|]
]
where
assigners_self_rhs = assigners [expr| self |] [expr| rhs |]
assigners_rhs_self = assigners [expr| rhs |] [expr| self |]
assigners_out_rhs = assigners [expr| out |] [expr| rhs |]
assigners left right = map (genAssign left right . toLit) [0..entries-1]
genAssign left right i =
Semi (AssignOp [] oper [expr| $$(left).value[$$(i)] |]
[expr| $$(right).value[$$(i)] |]
mempty) mempty
generateTest :: RandomGen g => Word -> g -> (Map String String, g)
generateTest size g0 = (tcase, g2)
where
(x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size
tcase = Map.fromList [("x", showX x), ("y", showX y),
("a", showX (x .&. y)),
("o", showX (x .|. y)),
("e", showX (x `xor` y)),
("n", showX ( ((2 ^ size) - 1) `xor` x ))]

View File

@@ -1,61 +1,275 @@
module Compare(comparisons) {-# LANGUAGE QuasiQuotes #-}
module Compare(comparisons, signedComparisons)
where where
import Control.Monad(forM_) import Data.Map.Strict(Map)
import File import qualified Data.Map.Strict as Map
import Gen import Generators
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
import System.Random(RandomGen)
comparisons :: File comparisons :: RustModule
comparisons = File { comparisons = RustModule {
predicate = \ _ _ -> True, predicate = \ _ _ -> True,
suggested = const [],
outputName = "compare", outputName = "compare",
isUnsigned = True,
generator = declareComparators, generator = declareComparators,
testGenerator = Nothing testCase = Just generateTest
} }
declareComparators :: Word -> Gen () signedComparisons :: RustModule
declareComparators bitsize = signedComparisons = RustModule {
do let name = "U" ++ show bitsize predicate = \ _ _ -> True,
suggested = const [],
outputName = "scompare",
isUnsigned = False,
generator = declareSignedComparators,
testCase = Just generateSignedTest
}
declareComparators :: Word -> [Word] -> SourceFile Span
declareComparators bitsize _ =
let sname = mkIdent ("U" ++ show bitsize)
entries = bitsize `div` 64 entries = bitsize `div` 64
top = entries - 1 eqStatements = buildEqStatements 0 entries
out "use core::cmp::{Eq,Ordering,PartialEq};" compareExp = buildCompareExp 0 entries
out "#[cfg(test)]" testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
out "use quickcheck::quickcheck;" in [sourceFile|
out ("use super::" ++ name ++ ";") use core::cmp::{Eq,Ordering,PartialEq};
blank #[cfg(test)]
implFor "PartialEq" name $ use crate::CryptoNum;
wrapIndent "fn eq(&self, other: &Self) -> bool" $ #[cfg(test)]
do forM_ (reverse [1..top]) $ \ i -> use crate::testing::{build_test_path,run_test};
out ("self.value[" ++ show i ++ "] == other.value[" ++ show i ++ "] && ") #[cfg(test)]
out "self.value[0] == other.value[0]" use quickcheck::quickcheck;
blank use super::$$sname;
implFor "Eq" name $ return ()
blank impl PartialEq for $$sname {
implFor "Ord" name $ fn eq(&self, other: &Self) -> bool {
wrapIndent "fn cmp(&self, other: &Self) -> Ordering" $ let mut out = true;
do out ("self.value[" ++ show top ++ "].cmp(&other.value[" ++ show top ++ "])") $@{eqStatements}
forM_ (reverse [0..top-1]) $ \ i -> out
out (" .then(self.value[" ++ show i ++ "].cmp(&other.value[" ++ show i ++ "]))") }
blank }
implFor "PartialOrd" name $
wrapIndent "fn partial_cmp(&self, other: &Self) -> Option<Ordering>" $ impl Eq for $$sname {}
out "Some(self.cmp(other))"
blank impl Ord for $$sname {
out "#[cfg(test)]" fn cmp(&self, other: &Self) -> Ordering {
wrapIndent "quickcheck!" $ $$(compareExp)
do let transFun n = "fn " ++ n ++ "(a: " ++ name ++ ", b: " ++ name ++ }
", c: " ++ name ++ ") -> bool" }
wrapIndent (transFun "eq_is_transitive") $
out ("if a == c { a == b && b == c } else { a != b || b != c }") impl PartialOrd for $$sname {
blank fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
wrapIndent (transFun "gt_is_transitive") $ Some(self.cmp(other))
out ("if a > b && b > c { a > c } else { true }") }
blank }
wrapIndent (transFun "ge_is_transitive") $
out ("if a >= b && b >= c { a >= c } else { true }") #[cfg(test)]
blank quickcheck! {
wrapIndent (transFun "lt_is_transitive") $ fn eq_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
out ("if a < b && b < c { a < c } else { true }") if a == c { a == b && b == c } else { a != b || b != c }
blank }
wrapIndent (transFun "le_is_transitive") $
out ("if a <= b && b <= c { a <= c } else { true }") fn gt_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a > b && b > c { a > c } else { true }
}
fn ge_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a >= b && b >= c { a >= c } else { true }
}
fn lt_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a < b && b < c { a < c } else { true }
}
fn le_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a <= b && b <= c { a <= c } else { true }
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("compare", $$(testFileLit)), 8, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, ebytes) = case.get("e").unwrap();
let (neg3, nbytes) = case.get("n").unwrap();
let (neg4, gbytes) = case.get("g").unwrap();
let (neg5, hbytes) = case.get("h").unwrap();
let (neg6, lbytes) = case.get("l").unwrap();
let (neg7, kbytes) = case.get("k").unwrap();
assert!(!neg0 && !neg1 && !neg2 && !neg3 &&
!neg4 && !neg5 && !neg6 && !neg7);
let x = $$sname::from_bytes(&xbytes);
let y = $$sname::from_bytes(&ybytes);
let e = 1 == ebytes[0];
let n = 1 == nbytes[0];
let g = 1 == gbytes[0];
let h = 1 == hbytes[0];
let l = 1 == lbytes[0];
let k = 1 == kbytes[0];
assert_eq!(e, x == y);
assert_eq!(n, x != y);
assert_eq!(g, x > y);
assert_eq!(h, x >= y);
assert_eq!(l, x < y);
assert_eq!(k, x <= y);
});
}
|]
declareSignedComparators :: Word -> [Word] -> SourceFile Span
declareSignedComparators bitsize _ =
let sname = mkIdent ("I" ++ show bitsize)
testFileLit = Lit [] (Str (testFile False bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::cmp::{Eq,Ordering,PartialEq};
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
#[cfg(test)]
use quickcheck::quickcheck;
use super::$$sname;
impl PartialEq for $$sname {
fn eq(&self, other: &Self) -> bool {
&self.contents == &other.contents
}
}
impl Eq for $$sname {}
impl Ord for $$sname {
fn cmp(&self, other: &Self) -> Ordering {
match (self.is_negative(), other.is_negative()) {
(false, false) => self.contents.cmp(&other.contents),
(false, true) => Ordering::Greater,
(true, false) => Ordering::Less,
(true, true) => self.contents.cmp(&other.contents),
}
}
}
impl PartialOrd for $$sname {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
#[cfg(test)]
quickcheck! {
fn eq_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a == c { a == b && b == c } else { a != b || b != c }
}
fn gt_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a > b && b > c { a > c } else { true }
}
fn ge_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a >= b && b >= c { a >= c } else { true }
}
fn lt_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a < b && b < c { a < c } else { true }
}
fn le_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a <= b && b <= c { a <= c } else { true }
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("scompare", $$(testFileLit)), 8, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, ebytes) = case.get("e").unwrap();
let (neg3, nbytes) = case.get("n").unwrap();
let (neg4, gbytes) = case.get("g").unwrap();
let (neg5, hbytes) = case.get("h").unwrap();
let (neg6, lbytes) = case.get("l").unwrap();
let (neg7, kbytes) = case.get("k").unwrap();
assert!(!neg2 && !neg3 && !neg4 && !neg5 && !neg6 && !neg7);
let mut x = $$sname::from_bytes(&xbytes);
let mut y = $$sname::from_bytes(&ybytes);
if *neg0 { x = -x; }
if *neg1 { y = -y; }
let e = 1 == ebytes[0];
let n = 1 == nbytes[0];
let g = 1 == gbytes[0];
let h = 1 == hbytes[0];
let l = 1 == lbytes[0];
let k = 1 == kbytes[0];
assert_eq!(e, x == y);
assert_eq!(n, x != y);
assert_eq!(g, x > y);
assert_eq!(h, x >= y);
assert_eq!(l, x < y);
assert_eq!(k, x <= y);
});
}
|]
buildEqStatements :: Word -> Word -> [Stmt Span]
buildEqStatements i numEntries
| i == (numEntries - 1) =
[[stmt| out &= self.value[$$(x)] == other.value[$$(x)]; |]]
| otherwise =
let rest = buildEqStatements (i + 1) numEntries
cur = [stmt| out &= self.value[$$(x)] == other.value[$$(x)]; |]
in cur:rest
where
x = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty
buildCompareExp :: Word -> Word -> Expr Span
buildCompareExp i numEntries
| i == (numEntries - 1) =
[expr| self.value[$$(x)].cmp(&other.value[$$(x)]) |]
| otherwise =
let rest = buildCompareExp (i + 1) numEntries
in [expr| $$(rest).then(self.value[$$(x)].cmp(&other.value[$$(x)])) |]
where
x = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty
generateTest :: RandomGen g => Word -> g -> (Map String String, g)
generateTest size g0 = (tcase, g2)
where
(x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size
tcase = Map.fromList [("x", showX x), ("y", showX y),
("e", showB (x == y)),
("n", showB (x /= y)),
("g", showB (x > y)),
("h", showB (x >= y)),
("l", showB (x < y)),
("k", showB (x <= y))]
generateSignedTest :: RandomGen g => Word -> g -> (Map String String, g)
generateSignedTest size g0 = (tcase, g2)
where
(x, g1) = generateSignedNum g0 size
(y, g2) = generateSignedNum g1 size
tcase = Map.fromList [("x", showX x), ("y", showX y),
("e", showB (x == y)),
("n", showB (x /= y)),
("g", showB (x > y)),
("h", showB (x >= y)),
("l", showB (x < y)),
("k", showB (x <= y))]

View File

@@ -1,99 +1,739 @@
{-# LANGUAGE QuasiQuotes #-}
module Conversions( module Conversions(
conversions conversions
, signedConversions
) )
where where
import Data.List(intercalate) import Generators
import File import Language.Rust.Data.Ident
import Gen import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
conversions :: File conversions :: RustModule
conversions = File { conversions = RustModule {
predicate = \ _ _ -> True, predicate = \ _ _ -> True,
suggested = const [],
outputName = "conversions", outputName = "conversions",
isUnsigned = True,
generator = declareConversions, generator = declareConversions,
testGenerator = Nothing testCase = Nothing
} }
declareConversions :: Word -> Gen () signedConversions :: RustModule
declareConversions bitsize = signedConversions = RustModule {
do let name = "U" ++ show bitsize predicate = \ _ _ -> True,
suggested = const [],
outputName = "sconversions",
isUnsigned = False,
generator = declareSignedConversions,
testCase = Nothing
}
declareConversions :: Word -> [Word] -> SourceFile Span
declareConversions bitsize otherSizes =
let sname = mkIdent ("U" ++ show bitsize)
entries = bitsize `div` 64 entries = bitsize `div` 64
out "use core::convert::{From,TryFrom};" u8_prims = buildPrimitives sname (mkIdent "u8") entries
out "#[cfg(test)]" u16_prims = buildPrimitives sname (mkIdent "u16") entries
out "use quickcheck::quickcheck;" u32_prims = buildPrimitives sname (mkIdent "u32") entries
out ("use super::" ++ name ++ ";") u64_prims = buildPrimitives sname (mkIdent "u64") entries
blank usz_prims = buildPrimitives sname (mkIdent "usize") entries
buildUnsignedPrimConversions name entries "u8" >> blank u128_prims = generateU128Primitives sname entries
buildUnsignedPrimConversions name entries "u16" >> blank i8_prims = generateSignedPrims sname (mkIdent "u8") (mkIdent "i8")
buildUnsignedPrimConversions name entries "u32" >> blank i16_prims = generateSignedPrims sname (mkIdent "u16") (mkIdent "i16")
buildUnsignedPrimConversions name entries "u64" >> blank i32_prims = generateSignedPrims sname (mkIdent "u32") (mkIdent "i32")
buildUnsignedPrimConversions name entries "usize" >> blank i64_prims = generateSignedPrims sname (mkIdent "u64") (mkIdent "i64")
buildSignedPrimConversions name entries "i8" >> blank isz_prims = buildPrimitives sname (mkIdent "isize") entries
buildSignedPrimConversions name entries "i16" >> blank i128_prims = generateI128Primitives sname
buildSignedPrimConversions name entries "i32" >> blank others = generateCryptonumConversions bitsize otherSizes
buildSignedPrimConversions name entries "i64" >> blank in [sourceFile|
buildSignedPrimConversions name entries "isize" use core::convert::{From,TryFrom};
blank use crate::CryptoNum;
out ("#[cfg(test)]") use crate::ConversionError;
wrapIndent "quickcheck!" $ #[cfg(test)]
do roundTripTest name "u8" >> blank use quickcheck::quickcheck;
roundTripTest name "u16" >> blank use super::super::*;
roundTripTest name "u32" >> blank
roundTripTest name "u64" >> blank
roundTripTest name "usize"
buildUnsignedPrimConversions :: String -> Word -> String -> Gen () $@{u8_prims}
buildUnsignedPrimConversions name entries primtype = $@{u16_prims}
do implFor ("From<" ++ primtype ++ ">") name $ $@{u32_prims}
wrapIndent ("fn from(x: " ++ primtype ++ ") -> Self") $ $@{u64_prims}
do let zeroes = replicate (fromIntegral (entries - 1)) "0," $@{usz_prims}
values = ("x as u64," : zeroes) $@{u128_prims}
out (name ++ " { value: [ ")
indent $ printBy 8 values
out ("] }")
blank
implFor ("From<" ++ name ++ ">") primtype $
wrapIndent ("fn from(x: " ++ name ++ ") -> Self") $
out ("x.value[0] as " ++ primtype)
blank
implFor' ("From<&'a " ++ name ++ ">") primtype $
wrapIndent ("fn from(x: &" ++ name ++ ") -> Self") $
out ("x.value[0] as " ++ primtype)
buildSignedPrimConversions :: String -> Word -> String -> Gen () $@{i8_prims}
buildSignedPrimConversions name entries primtype = $@{i16_prims}
do implFor ("TryFrom<" ++ primtype ++ ">") name $ $@{i32_prims}
do out ("type Error = &'static str;") $@{i64_prims}
blank $@{isz_prims}
wrapIndent ("fn try_from(x: " ++ primtype ++ ") -> Result<Self,Self::Error>") $ $@{i128_prims}
do wrapIndent ("if x < 0") $
out ("return Err(\"Attempt to convert negative number to " ++
name ++ ".\");")
blank
let zeroes = replicate (fromIntegral (entries - 1)) "0,"
values = ("x as u64," : zeroes)
out ("Ok(" ++ name ++ " { value: [ ")
indent $ printBy 8 values
out ("] })")
blank
implFor ("From<" ++ name ++ ">") primtype $
wrapIndent ("fn from(x: " ++ name ++ ") -> Self") $
out ("x.value[0] as " ++ primtype)
blank
implFor' ("From<&'a " ++ name ++ ">") primtype $
wrapIndent ("fn from(x: &" ++ name ++ ") -> Self") $
out ("x.value[0] as " ++ primtype)
roundTripTest :: String -> String -> Gen () $@{others}
roundTripTest name primtype =
wrapIndent ("fn " ++ primtype ++ "_roundtrips(x: " ++ primtype ++ ") -> bool") $ #[cfg(test)]
do out ("let big = " ++ name ++ "::from(x);"); quickcheck! {
out ("let small = " ++ primtype ++ "::from(big);") fn u8_recovers(x: u8) -> bool {
out ("x == small") x == u8::try_from($$sname::from(x)).unwrap()
}
fn u16_recovers(x: u16) -> bool {
x == u16::try_from($$sname::from(x)).unwrap()
}
fn u32_recovers(x: u32) -> bool {
x == u32::try_from($$sname::from(x)).unwrap()
}
fn u64_recovers(x: u64) -> bool {
x == u64::try_from($$sname::from(x)).unwrap()
}
fn usize_recovers(x: usize) -> bool {
x == usize::try_from($$sname::from(x)).unwrap()
}
fn u128_recovers(x: u128) -> bool {
x == u128::try_from($$sname::from(x)).unwrap()
}
}
|]
declareSignedConversions :: Word -> [Word] -> SourceFile Span
declareSignedConversions bitsize otherSizes =
let sname = mkIdent ("I" ++ show bitsize)
uname = mkIdent ("U" ++ show bitsize)
u8_prims = buildUSPrimitives sname (mkIdent "u8")
u16_prims = buildUSPrimitives sname (mkIdent "u16")
u32_prims = buildUSPrimitives sname (mkIdent "u32")
u64_prims = buildUSPrimitives sname (mkIdent "u64")
usz_prims = buildUSPrimitives sname (mkIdent "usize")
i8_prims = buildSSPrimitives sname uname (mkIdent "i8")
i16_prims = buildSSPrimitives sname uname (mkIdent "i16")
i32_prims = buildSSPrimitives sname uname (mkIdent "i32")
i64_prims = buildSSPrimitives sname uname (mkIdent "i64")
isz_prims = buildSSPrimitives sname uname (mkIdent "isize")
s128_prims = generateS128Primitives sname uname
others = generateSignedCryptonumConversions bitsize otherSizes
in [sourceFile|
use core::convert::{From,TryFrom};
use core::{i8,i16,i32,i64,isize};
use crate::CryptoNum;
use crate::ConversionError;
use crate::signed::*;
use crate::unsigned::*;
#[cfg(test)]
use quickcheck::quickcheck;
$@{u8_prims}
$@{u16_prims}
$@{u32_prims}
$@{u64_prims}
$@{usz_prims}
$@{i8_prims}
$@{i16_prims}
$@{i32_prims}
$@{i64_prims}
$@{isz_prims}
$@{s128_prims}
$@{others}
#[cfg(test)]
quickcheck! {
fn u8_recovers(x: u8) -> bool {
x == u8::try_from($$sname::from(x)).unwrap()
}
fn u16_recovers(x: u16) -> bool {
x == u16::try_from($$sname::from(x)).unwrap()
}
fn u32_recovers(x: u32) -> bool {
x == u32::try_from($$sname::from(x)).unwrap()
}
fn u64_recovers(x: u64) -> bool {
x == u64::try_from($$sname::from(x)).unwrap()
}
fn usize_recovers(x: usize) -> bool {
x == usize::try_from($$sname::from(x)).unwrap()
}
fn u128_recovers(x: u128) -> bool {
x == u128::try_from($$sname::from(x)).unwrap()
}
fn i8_recovers(x: i8) -> bool {
x == i8::try_from($$sname::from(x)).unwrap()
}
fn i16_recovers(x: i16) -> bool {
x == i16::try_from($$sname::from(x)).unwrap()
}
fn i32_recovers(x: i32) -> bool {
x == i32::try_from($$sname::from(x)).unwrap()
}
fn i64_recovers(x: i64) -> bool {
x == i64::try_from($$sname::from(x)).unwrap()
}
fn isize_recovers(x: isize) -> bool {
x == isize::try_from($$sname::from(x)).unwrap()
}
fn i128_recovers(x: i128) -> bool {
x == i128::try_from($$sname::from(x)).unwrap()
}
}
|]
generateU128Primitives :: Ident -> Word -> [Item Span]
generateU128Primitives sname entries = [
[item|impl From<u128> for $$sname {
fn from(x: u128) -> Self {
let mut res = $$sname::zero();
res.value[0] = x as u64;
res.value[1] = (x >> 64) as u64;
res
}
}|]
, [item|impl TryFrom<$$sname> for u128 {
type Error = ConversionError;
fn try_from(x: $$sname) -> Result<u128,ConversionError> {
let mut good_conversion = true;
let mut res;
res = (x.value[1] as u128) << 64;
res |= x.value[0] as u128;
$@{testZeros}
if good_conversion {
Ok(res)
} else {
Err(ConversionError::Overflow)
}
}
}|]
, [item|impl<'a> TryFrom<&'a $$sname> for u128 {
type Error = ConversionError;
fn try_from(x: &$$sname) -> Result<u128,ConversionError> {
let mut good_conversion = true;
let mut res;
res = (x.value[1] as u128) << 64;
res |= x.value[0] as u128;
$@{testZeros}
if good_conversion {
Ok(res)
} else {
Err(ConversionError::Overflow)
}
}
}|]
]
where
testZeros = map (zeroTest . toLit) [2..entries-1]
zeroTest i =
[stmt| good_conversion &= x.value[$$(i)] == 0; |]
buildPrimitives :: Ident -> Ident -> Word -> [Item Span]
buildPrimitives sname tname entries = [
[item|impl From<$$tname> for $$sname {
fn from(x: $$tname) -> Self {
let mut res = $$sname::zero();
res.value[0] = x as u64;
res
}
}|]
, [item|impl TryFrom<$$sname> for $$tname {
type Error = ConversionError;
fn try_from(x: $$sname) -> Result<Self,ConversionError> {
let mut good_conversion = true;
let res = x.value[0] as $$tname;
$@{testZeros}
if good_conversion {
Ok(res)
} else {
Err(ConversionError::Overflow)
}
}
}|]
, [item|impl<'a> TryFrom<&'a $$sname> for $$tname {
type Error = ConversionError;
fn try_from(x: &$$sname) -> Result<Self,ConversionError> {
let mut good_conversion = true;
let res = x.value[0] as $$tname;
$@{testZeros}
if good_conversion {
Ok(res)
} else {
Err(ConversionError::Overflow)
}
}
}|]
]
where
testZeros = map (zeroTest . toLit) [1..entries-1]
zeroTest i =
[stmt| good_conversion &= x.value[$$(i)] == 0; |]
generateSignedPrims :: Ident -> Ident -> Ident -> [Item Span]
generateSignedPrims sname unsigned signed = [
[item|impl TryFrom<$$signed> for $$sname {
type Error = ConversionError;
fn try_from(x: $$signed) -> Result<Self,ConversionError> {
let mut res = $$sname::zero();
res.value[0] = x as u64;
if x < 0 {
Err(ConversionError::NegativeToUnsigned)
} else {
Ok(res)
}
}
}|]
, [item|impl TryFrom<$$sname> for $$signed {
type Error = ConversionError;
fn try_from(x: $$sname) -> Result<Self,ConversionError> {
let uns = $$unsigned::try_from(x)?;
Ok($$signed::try_from(uns)?)
}
}|]
, [item|impl<'a> TryFrom<&'a $$sname> for $$signed {
type Error = ConversionError;
fn try_from(x: &$$sname) -> Result<Self,ConversionError> {
let uns = $$unsigned::try_from(x)?;
Ok($$signed::try_from(uns)?)
}
}|]
]
generateI128Primitives :: Ident -> [Item Span]
generateI128Primitives sname = [
[item|impl TryFrom<i128> for $$sname {
type Error = ConversionError;
fn try_from(x: i128) -> Result<Self,ConversionError> {
let mut res = $$sname::zero();
res.value[0] = x as u64;
res.value[1] = ((x as u128) >> 64) as u64;
if x < 0 {
Err(ConversionError::NegativeToUnsigned)
} else {
Ok(res)
}
}
}|]
, [item|impl TryFrom<$$sname> for i128 {
type Error = ConversionError;
fn try_from(x: $$sname) -> Result<Self,ConversionError> {
let uns = u128::try_from(x)?;
Ok(i128::try_from(uns)?)
}
}|]
, [item|impl<'a> TryFrom<&'a $$sname> for i128 {
type Error = ConversionError;
fn try_from(x: &$$sname) -> Result<Self,ConversionError> {
let uns = u128::try_from(x)?;
Ok(i128::try_from(uns)?)
}
}|]
]
generateCryptonumConversions :: Word -> [Word] -> [Item Span]
generateCryptonumConversions source = concatMap convert
where
sName = mkIdent ("U" ++ show source)
--
convert target =
let tName = mkIdent ("U" ++ show target)
sEntries = toLit (source `div` 64)
tEntries = toLit (target `div` 64)
in case compare source target of
LT -> [
[item|
impl<'a> From<&'a $$sName> for $$tName {
fn from(x: &$$sName) -> $$tName {
let mut res = $$tName::zero();
res.value[0..$$(sEntries)].copy_from_slice(&x.value);
res
}
}
|],
[item|
impl From<$$sName> for $$tName {
fn from(x: $$sName) -> $$tName {
$$tName::from(&x)
}
}
|]
]
EQ -> []
GT -> [
[item|
impl<'a> TryFrom<&'a $$sName> for $$tName {
type Error = ConversionError;
fn try_from(x: &$$sName) -> Result<$$tName, ConversionError> {
if x.value.iter().skip($$(tEntries)).all(|x| *x == 0) {
let mut res = $$tName::zero();
res.value.copy_from_slice(&x.value[0..$$(tEntries)]);
Ok(res)
} else {
Err(ConversionError::Overflow)
}
}
}
|],
[item|
impl TryFrom<$$sName> for $$tName {
type Error = ConversionError;
fn try_from(x: $$sName) -> Result<$$tName, ConversionError> {
$$tName::try_from(&x)
}
}
|]
]
buildUSPrimitives :: Ident -> Ident -> [Item Span]
buildUSPrimitives sname prim = [
[item|
impl From<$$prim> for $$sname {
fn from(x: $$prim) -> $$sname {
let mut base = $$sname::zero();
base.contents.value[0] = x as u64;
base
}
}
|]
, [item|
impl<'a> TryFrom<&'a $$sname> for $$prim {
type Error = ConversionError;
fn try_from(x: &$$sname) -> Result<$$prim, ConversionError> {
if (x.contents.value)[1..].iter().any(|v| *v != 0) {
return Err(ConversionError::Overflow);
}
let res64 = x.contents.value[0];
if res64 & 0x8000_0000_0000_0000 != 0 {
return Err(ConversionError::Overflow);
}
Ok(res64 as $$prim)
}
}
|]
, [item|
impl TryFrom<$$sname> for $$prim {
type Error = ConversionError;
fn try_from(x: $$sname) -> Result<$$prim, ConversionError> {
$$prim::try_from(&x)
}
}
|]
]
buildSSPrimitives :: Ident -> Ident -> Ident -> [Item Span]
buildSSPrimitives sname uname prim = [
[item|
impl From<$$prim> for $$sname {
fn from(x: $$prim) -> $$sname {
let mut ures = $$uname::zero();
let topbits = if x < 0 { 0xFFFF_FFFF_FFFF_FFFF } else { 0 };
for x in ures.value.iter_mut() {
*x = topbits;
}
ures.value[0] = (x as i64) as u64;
$$sname{ contents: ures }
}
}
|]
, [item|
impl<'a> TryFrom<&'a $$sname> for $$prim {
type Error = ConversionError;
fn try_from(x: &$$sname) -> Result<$$prim, ConversionError> {
let topbits = if x.is_negative() { 0xFFFF_FFFF_FFFF_FFFF } else { 0 };
if x.contents.value[1..].iter().any(|v| *v != topbits) {
return Err(ConversionError::Overflow);
}
let local_min = $$prim::MIN as i64;
let local_max = $$prim::MAX as i64;
let bottom = x.contents.value[0] as i64;
if (bottom > local_max) || (bottom < local_min) {
Err(ConversionError::Overflow)
} else {
Ok(bottom as $$prim)
}
}
}
|]
, [item|
impl TryFrom<$$sname> for $$prim {
type Error = ConversionError;
fn try_from(x: $$sname) -> Result<$$prim, ConversionError> {
$$prim::try_from(&x)
}
}
|]
]
generateS128Primitives :: Ident -> Ident -> [Item Span]
generateS128Primitives sname uname = [
[item|
impl From<u128> for $$sname {
fn from(x: u128) -> $$sname {
$$sname{ contents: $$uname::from(x) }
}
}
|],
[item|
impl From<i128> for $$sname {
fn from(x: i128) -> $$sname {
let mut basic = $$uname::from(x as u128);
if x < 0 {
for x in basic.value[2..].iter_mut() {
*x = 0xFFFF_FFFF_FFFF_FFFF;
}
}
$$sname{ contents: basic }
}
}
|],
[item|
impl TryFrom<$$sname> for u128 {
type Error = ConversionError;
fn try_from(x: $$sname) -> Result<u128,ConversionError> {
u128::try_from(&x)
}
}
|],
[item|
impl TryFrom<$$sname> for i128 {
type Error = ConversionError;
fn try_from(x: $$sname) -> Result<i128,ConversionError> {
i128::try_from(&x)
}
}
|],
[item|
impl<'a> TryFrom<&'a $$sname> for u128 {
type Error = ConversionError;
fn try_from(x: &$$sname) -> Result<u128,ConversionError> {
if x.is_negative() {
return Err(ConversionError::Overflow);
}
u128::try_from(&x.contents)
}
}
|],
[item|
impl<'a> TryFrom<&'a $$sname> for i128 {
type Error = ConversionError;
fn try_from(x: &$$sname) -> Result<i128,ConversionError> {
let isneg = x.is_negative();
let target_top = if isneg { 0xFFFF_FFFF_FFFF_FFFF } else { 0x0 };
let mut worked = true;
worked &= x.contents.value[2..].iter().all(|v| *v == target_top);
worked &= (x.contents.value[1] >> 63 == 1) == isneg;
let res = ((x.contents.value[1] as u128) << 64) | (x.contents.value[0] as u128);
if worked {
Ok(res as i128)
} else {
Err(ConversionError::Overflow)
}
}
}
|]
]
generateSignedCryptonumConversions :: Word -> [Word] -> [Item Span]
generateSignedCryptonumConversions source otherSizes = concatMap convert otherSizes
where
suName = mkIdent ("U" ++ show source)
ssName = mkIdent ("I" ++ show source)
--
convert target =
let tsName = mkIdent ("I" ++ show target)
tuName = mkIdent ("U" ++ show target)
sEntries = toLit (source `div` 64)
tEntries = toLit (target `div` 64)
sTop = toLit ((source `div` 64) - 1)
extensions = map (\ x ->
let xLit = toLit x
in [stmt| res.contents.value[$$(xLit)] = extension; |])
[(source `div` 64)..((target `div` 64) - 1)]
in case compare source target of
LT -> [
[item|
impl<'a> From<&'a $$ssName> for $$tsName {
fn from(x: &$$ssName) -> $$tsName {
let mut res = $$tsName::zero();
res.contents.value[0..$$(sEntries)].copy_from_slice(&x.contents.value);
let extension = if x.contents.value[$$(sTop)] & 0x8000_0000_0000_0000 == 0 {
0
} else {
0xFFFF_FFFF_FFFF_FFFFu64
};
$@{extensions}
res
}
}
|],
[item|
impl From<$$ssName> for $$tsName {
fn from(x: $$ssName) -> $$tsName {
$$tsName::from(&x)
}
}
|],
[item|
impl<'a> From<&'a $$suName> for $$tsName {
fn from(x: &$$suName) -> $$tsName {
$$tsName{ contents: $$tuName::from(x) }
}
}
|],
[item|
impl From<$$suName> for $$tsName {
fn from(x: $$suName) -> $$tsName {
$$tsName{ contents: $$tuName::from(x) }
}
}
|],
[item|
impl<'a> TryFrom<&'a $$ssName> for $$tuName {
type Error = ConversionError;
fn try_from(x: &$$ssName) -> Result<$$tuName,ConversionError> {
if x.is_negative() {
Err(ConversionError::NegativeToUnsigned)
} else {
Ok($$tuName::from(&x.contents))
}
}
}
|],
[item|
impl TryFrom<$$ssName> for $$tuName {
type Error = ConversionError;
fn try_from(x: $$ssName) -> Result<$$tuName,ConversionError> {
$$tuName::try_from(&x)
}
}
|]
]
EQ -> [
[item|
impl TryFrom<$$tuName> for $$ssName {
type Error = ConversionError;
fn try_from(x: $$tuName) -> Result<$$ssName,ConversionError> {
let res = $$ssName{ contents: x };
if res.is_negative() {
return Err(ConversionError::Overflow);
}
Ok(res)
}
}
|],
[item|
impl<'a> TryFrom<&'a $$tuName> for $$ssName {
type Error = ConversionError;
fn try_from(x: &$$tuName) -> Result<$$ssName,ConversionError> {
$$ssName::try_from(x.clone())
}
}
|],
[item|
impl TryFrom<$$ssName> for $$tuName {
type Error = ConversionError;
fn try_from(x: $$ssName) -> Result<$$tuName,ConversionError> {
if x.is_negative() {
return Err(ConversionError::Overflow);
}
Ok(x.contents)
}
}
|],
[item|
impl<'a> TryFrom<&'a $$ssName> for $$tuName {
type Error = ConversionError;
fn try_from(x: &$$ssName) -> Result<$$tuName,ConversionError> {
$$tuName::try_from(x.clone())
}
}
|]
]
GT -> [
[item|
impl<'a> TryFrom<&'a $$ssName> for $$tsName {
type Error = ConversionError;
fn try_from(x: &$$ssName) -> Result<$$tsName,ConversionError> {
let required_top = if x.is_negative() {
0xFFFF_FFFF_FFFF_FFFF
} else {
0
};
if x.contents.value.iter().skip($$(tEntries)).all(|x| *x == required_top) {
let mut res = $$tsName::zero();
res.contents.value.copy_from_slice(&x.contents.value[0..$$(tEntries)]);
Ok(res)
} else {
Err(ConversionError::Overflow)
}
}
}
|],
[item|
impl TryFrom<$$ssName> for $$tsName {
type Error = ConversionError;
fn try_from(x: $$ssName) -> Result<$$tsName,ConversionError> {
$$tsName::try_from(&x)
}
}
|],
[item|
impl<'a> TryFrom<&'a $$ssName> for $$tuName {
type Error = ConversionError;
fn try_from(x: &$$ssName) -> Result<$$tuName,ConversionError> {
if x.is_negative() {
Err(ConversionError::NegativeToUnsigned)
} else {
$$tuName::try_from(&x.contents)
}
}
}
|],
[item|
impl TryFrom<$$ssName> for $$tuName {
type Error = ConversionError;
fn try_from(x: $$ssName) -> Result<$$tuName,ConversionError> {
$$tuName::try_from(&x)
}
}
|]
]
printBy :: Int -> [String] -> Gen ()
printBy amt xs
| length xs <= amt = out (intercalate " " xs)
| otherwise = printBy amt (take amt xs) >>
printBy amt (drop amt xs)

View File

@@ -1,179 +1,188 @@
{-# LANGUAGE QuasiQuotes #-}
module CryptoNum( module CryptoNum(
cryptoNum cryptoNum
) )
where where
import Control.Monad(forM_) import Data.Bits(testBit)
import File import Data.Map.Strict(Map)
import Gen import qualified Data.Map.Strict as Map
import Generators
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
import System.Random(RandomGen)
cryptoNum :: File cryptoNum :: RustModule
cryptoNum = File { cryptoNum = RustModule {
predicate = \ _ _ -> True, predicate = \ _ _ -> True,
suggested = const [],
outputName = "cryptonum", outputName = "cryptonum",
isUnsigned = True,
generator = declareCryptoNumInstance, generator = declareCryptoNumInstance,
testGenerator = Nothing testCase = Just generateTest
} }
declareCryptoNumInstance :: Word -> Gen () declareCryptoNumInstance :: Word -> [Word] -> SourceFile Span
declareCryptoNumInstance bitsize = declareCryptoNumInstance bitsize _ =
do let name = "U" ++ show bitsize let sname = mkIdent ("U" ++ show bitsize)
entries = bitsize `div` 64 entries = bitsize `div` 64
top = entries - 1 entlit = Lit [] (Int Dec (fromIntegral entries) Unsuffixed mempty) mempty
out "use core::cmp::min;" zeroTests = generateZeroTests 0 entries
out "use crate::CryptoNum;" bitlength = toLit bitsize
out "#[cfg(test)]" bytelen = bitsize `div` 8
out "use crate::testing::{build_test_path,run_test};" bytelenlit = toLit bytelen
out "#[cfg(test)]" bytebuffer = Delimited mempty Bracket (Stream [
out "use quickcheck::{Arbitrary,Gen,quickcheck};" Tree (Token mempty (LiteralTok (IntegerTok "0") Nothing)),
out "#[cfg(test)]" Tree (Token mempty Semicolon),
out "use std::fmt;" Tree (Token mempty (LiteralTok (IntegerTok (show bytelen)) Nothing))
out ("use super::" ++ name ++ ";") ])
blank entrieslit = toLit entries
implFor "CryptoNum" name $ testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
do wrapIndent ("fn zero() -> Self") $ in [sourceFile|
out (name ++ "{ value: [0; " ++ show entries ++ "] }") use core::cmp::min;
blank #[cfg(test)]
wrapIndent ("fn is_zero(&self) -> bool") $ use core::convert::TryFrom;
do forM_ (reverse [1..top]) $ \ i -> use crate::CryptoNum;
out ("self.value[" ++ show i ++ "] == 0 &&") #[cfg(test)]
out "self.value[0] == 0" use crate::testing::{build_test_path,run_test};
blank #[cfg(test)]
wrapIndent ("fn is_even(&self) -> bool") $ use quickcheck::quickcheck;
out "self.value[0] & 0x1 == 0" use super::$$sname;
blank
wrapIndent ("fn is_odd(&self) -> bool") $
out "self.value[0] & 0x1 == 1"
blank
wrapIndent ("fn bit_length() -> usize") $
out (show bitsize)
blank
wrapIndent ("fn mask(&mut self, len: usize)") $
do out ("let dellen = min(len, " ++ show entries ++ ");")
wrapIndent ("for i in dellen.." ++ show entries) $
out ("self.value[i] = 0;")
blank
wrapIndent ("fn testbit(&self, bit: usize) -> bool") $
do out "let idx = bit / 64;"
out "let offset = bit % 64;"
wrapIndent ("if idx >= " ++ show entries) $
out "return false;"
out "(self.value[idx] & (1u64 << offset)) != 0"
blank
wrapIndent ("fn from_bytes(bytes: &[u8]) -> Self") $
do out ("let biggest = min(" ++ show (bitsize `div` 8) ++ ", " ++
"bytes.len()) - 1;")
out ("let mut idx = biggest / 8;")
out ("let mut shift = (biggest % 8) * 8;")
out ("let mut i = 0;")
out ("let mut res = " ++ name ++ "::zero();")
blank
wrapIndent ("while i <= biggest") $
do out ("res.value[idx] |= (bytes[i] as u64) << shift;")
out ("i += 1;")
out ("if shift == 0 {")
indent $
do out "shift = 56;"
out "if idx > 0 { idx -= 1; }"
out ("} else {")
indent $
out "shift -= 8;"
out "}"
blank
out "res"
blank
wrapIndent ("fn to_bytes(&self, bytes: &mut [u8])") $
do let bytes = bitsize `div` 8
out ("if bytes.len() == 0 { return; }")
blank
forM_ [0..bytes-1] $ \ idx ->
do let (validx, shift) = byteShiftInfo idx
out ("let byte" ++ show idx ++ " = (self.value[" ++
show validx ++ "] >> " ++ show shift ++ ")" ++
" as u8;")
blank
out ("let mut idx = min(bytes.len() - 1, " ++ show (bytes - 1) ++ ");")
forM_ [0..bytes-2] $ \ i ->
do out ("bytes[idx] = byte" ++ show i ++ ";")
out ("if idx == 0 { return; }")
out ("idx -= 1;")
out ("bytes[idx] = byte" ++ show (bytes-1) ++ ";")
blank
let bytes = bitsize `div` 8
struct = "Bytes" ++ show bytes
out "#[cfg(test)]"
out "#[derive(Clone)]"
wrapIndent ("struct " ++ struct) $
out ("value: [u8; " ++ show bytes ++ "]")
blank
out "#[cfg(test)]"
implFor "PartialEq" struct $
wrapIndent ("fn eq(&self, other: &Self) -> bool") $
out "self.value.iter().zip(other.value.iter()).all(|(a,b)| a == b)"
blank
out "#[cfg(test)]"
implFor "fmt::Debug" struct $
wrapIndent ("fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result") $
out "f.debug_list().entries(self.value.iter()).finish()"
blank
out "#[cfg(test)]"
implFor "Arbitrary" struct $
wrapIndent ("fn arbitrary<G: Gen>(g: &mut G) -> Self") $
do out ("let mut res = " ++ struct ++ "{ value: [0; " ++ show bytes ++ "] };")
out ("g.fill_bytes(&mut res.value);")
out ("res")
blank
out "#[cfg(test)]"
wrapIndent "quickcheck!" $
do wrapIndent ("fn to_from_ident(x: " ++ name ++ ") -> bool") $
do out ("let mut buffer = [0; " ++ show bytes ++ "];")
out ("x.to_bytes(&mut buffer);");
out ("let y = " ++ name ++ "::from_bytes(&buffer);")
out ("x == y")
blank
wrapIndent ("fn from_to_ident(x: " ++ struct ++ ") -> bool") $
do out ("let val = " ++ name ++ "::from_bytes(&x.value);")
out ("let mut buffer = [0; " ++ show bytes ++ "];")
out ("val.to_bytes(&mut buffer);")
out ("buffer.iter().zip(x.value.iter()).all(|(a,b)| a == b)")
blank
out "#[cfg(test)]"
out "#[allow(non_snake_case)]"
out "#[test]"
wrapIndent "fn KATs()" $
do let name' = pad 5 '0' (show bitsize)
out ("run_test(build_test_path(\"base\",\"" ++ name' ++ "\"), 8, |case| {")
indent $
do out ("let (neg0, xbytes) = case.get(\"x\").unwrap();")
out ("let (neg1, mbytes) = case.get(\"m\").unwrap();")
out ("let (neg2, zbytes) = case.get(\"z\").unwrap();")
out ("let (neg3, ebytes) = case.get(\"e\").unwrap();")
out ("let (neg4, obytes) = case.get(\"o\").unwrap();")
out ("let (neg5, rbytes) = case.get(\"r\").unwrap();")
out ("let (neg6, bbytes) = case.get(\"b\").unwrap();")
out ("let (neg7, tbytes) = case.get(\"t\").unwrap();")
out ("assert!(!neg0&&!neg1&&!neg2&&!neg3&&!neg4&&!neg5&&!neg6&&!neg7);")
out ("let mut x = "++name++"::from_bytes(xbytes);")
out ("let m = "++name++"::from_bytes(mbytes);")
out ("let z = 1 == zbytes[0];")
out ("let e = 1 == ebytes[0];")
out ("let o = 1 == obytes[0];")
out ("let r = "++name++"::from_bytes(rbytes);")
out ("let b = usize::from("++name++"::from_bytes(bbytes));")
out ("let t = 1 == tbytes[0];")
out ("assert_eq!(x.is_zero(), z);")
out ("assert_eq!(x.is_even(), e);")
out ("assert_eq!(x.is_odd(), o);")
out ("assert_eq!(x.testbit(b), t);")
out ("x.mask(usize::from(&m));")
out ("assert_eq!(x, r);")
out ("});")
byteShiftInfo :: Word -> (Word, Word) impl CryptoNum for $$sname {
byteShiftInfo idx = fn zero() -> Self {
(idx `div` 8, (idx `mod` 8) * 8) $$sname{ value: [0; $$(entlit)] }
}
fn is_zero(&self) -> bool {
let mut result = true;
$@{zeroTests}
result
}
fn is_even(&self) -> bool {
self.value[0] & 0x1 == 0
}
fn is_odd(&self) -> bool {
self.value[0] & 0x1 == 1
}
fn bit_length() -> usize {
$$(bitlength)
}
fn mask(&mut self, len: usize) {
let dellen = min(len, $$(entrieslit));
for i in dellen..$$(entrieslit) {
self.value[i] = 0;
}
}
fn testbit(&self, bit: usize) -> bool {
let idx = bit / 64;
let offset = bit % 64;
if idx >= $$(entrieslit) {
return false;
}
(self.value[idx] & (1u64 << offset)) != 0
}
fn from_bytes(bytes: &[u8]) -> Self {
let biggest = min($$(bytelenlit), bytes.len()) - 1;
let mut idx = biggest / 8;
let mut shift = (biggest % 8) * 8;
let mut i = 0;
let mut res = $$sname::zero();
pad :: Int -> Char -> String -> String while i <= biggest {
pad len c str res.value[idx] |= (bytes[i] as u64) << shift;
| length str >= len = str i += 1;
| otherwise = pad len c (c:str) if shift == 0 {
shift = 56;
if idx > 0 {
idx -= 1;
}
} else {
shift -= 8;
}
}
res
}
fn to_bytes(&self, bytes: &mut [u8]) {
let mut idx = 0;
let mut shift = 0;
for x in bytes.iter_mut().take($$(bytelenlit)).rev() {
*x = (self.value[idx] >> shift) as u8;
shift += 8;
if shift == 64 {
idx += 1;
shift = 0;
}
}
}
}
#[cfg(test)]
quickcheck! {
fn to_from_ident(x: $$sname) -> bool {
let mut buffer = $$(bytebuffer);
x.to_bytes(&mut buffer);
let y = $$sname::from_bytes(&buffer);
x == y
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("cryptonum", $$(testFileLit)), 8, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, mbytes) = case.get("m").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
let (neg3, ebytes) = case.get("e").unwrap();
let (neg4, obytes) = case.get("o").unwrap();
let (neg5, rbytes) = case.get("r").unwrap();
let (neg6, bbytes) = case.get("b").unwrap();
let (neg7, tbytes) = case.get("t").unwrap();
assert!(!neg0 && !neg1 && !neg2 && !neg3 &&
!neg4 && !neg5 && !neg6 && !neg7);
let mut x = $$sname::from_bytes(&xbytes);
let z = 1 == zbytes[0];
let e = 1 == ebytes[0];
let o = 1 == obytes[0];
let t = 1 == tbytes[0];
let m = usize::try_from($$sname::from_bytes(&mbytes)).unwrap();
let b = usize::try_from($$sname::from_bytes(&bbytes)).unwrap();
let r = $$sname::from_bytes(&rbytes);
assert_eq!(x.is_zero(), z);
assert_eq!(x.is_even(), e);
assert_eq!(x.is_odd(), o);
assert_eq!(x.testbit(b), t);
x.mask(m);
assert_eq!(x, r);
});
}
|]
generateZeroTests :: Word -> Word -> [Stmt Span]
generateZeroTests i entries
| i == entries = []
| otherwise =
let ilit = toLit i
in [stmt| result &= self.value[$$(ilit)] == 0; |] :
generateZeroTests (i + 1) entries
generateTest :: RandomGen g => Word -> g -> (Map String String, g)
generateTest size g0 = (tcase, g3)
where
(x, g1) = generateNum g0 size
(m, g2) = generateNum g1 size
(b, g3) = generateNum g2 16
m' = m `mod` (fromIntegral size `div` 64)
r = x `mod` (2 ^ (64 * m'))
t = x `testBit` (fromIntegral b)
tcase = Map.fromList [("x", showX x), ("z", showB (x == 0)),
("e", showB (even x)), ("o", showB (odd x)),
("m", showX m'), ("r", showX r),
("b", showX b), ("t", showB t)]

207
generation/src/Division.hs Normal file
View File

@@ -0,0 +1,207 @@
{-# LANGUAGE QuasiQuotes #-}
module Division(divisionOps)
where
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Generators
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
import System.Random(RandomGen)
divisionOps :: RustModule
divisionOps = RustModule {
predicate = \ _ _ -> True,
suggested = const [],
outputName = "divmod",
isUnsigned = True,
generator = declareDivision,
testCase = Just generateDivisionTest
}
declareDivision :: Word -> [Word] -> SourceFile Span
declareDivision size _ =
let sname = mkIdent ("U" ++ show size)
entries = size `div` 64
copyAssign = map doCopy [0..entries-1]
testFileLit = Lit [] (Str (testFile True size) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::{Div, DivAssign};
use core::ops::{Rem, RemAssign};
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use crate::unsigned::$$sname;
use super::super::super::DivMod;
impl DivMod for $$sname {
fn divmod(&self, rhs: &$$sname) -> ($$sname, $$sname) {
let mut q = $$sname::zero();
let mut r = $$sname::zero();
for (ndigit, qdigit) in self.value.iter().rev().zip(q.value.iter_mut().rev()) {
for i in (0..64).rev() {
let mut r1: $$sname = &r << 1u64;
r1.value[0] |= (ndigit >> i) & 1u64;
let mut r2: $$sname = r1.clone();
r2 -= rhs;
let (newr, bit) = if &r1 > rhs {
(r2, 1)
} else {
(r1, 0)
};
r = newr;
*qdigit |= bit << i;
}
}
(q, r)
}
}
impl Div for $$sname {
type Output = $$sname;
fn div(self, rhs: $$sname) -> Self::Output {
let (res, _) = self.divmod(&rhs);
res
}
}
impl<'a> Div<$$sname> for &'a $$sname {
type Output = $$sname;
fn div(self, rhs: $$sname) -> Self::Output {
let (res, _) = self.divmod(&rhs);
res
}
}
impl<'a> Div<&'a $$sname> for $$sname {
type Output = $$sname;
fn div(self, rhs: &$$sname) -> Self::Output {
let (res, _) = self.divmod(rhs);
res
}
}
impl<'a,'b> Div<&'a $$sname> for &'b $$sname {
type Output = $$sname;
fn div(self, rhs: &$$sname) -> Self::Output {
let (res, _) = self.divmod(rhs);
res
}
}
impl DivAssign for $$sname {
fn div_assign(&mut self, rhs: $$sname) {
let (res, _) = self.divmod(&rhs);
$@{copyAssign}
}
}
impl<'a> DivAssign<&'a $$sname> for $$sname {
fn div_assign(&mut self, rhs: &$$sname) {
let (res, _) = self.divmod(rhs);
$@{copyAssign}
}
}
impl Rem for $$sname {
type Output = $$sname;
fn rem(self, rhs: $$sname) -> Self::Output {
let (_, res) = self.divmod(&rhs);
res
}
}
impl<'a> Rem<$$sname> for &'a $$sname {
type Output = $$sname;
fn rem(self, rhs: $$sname) -> Self::Output {
let (_, res) = self.divmod(&rhs);
res
}
}
impl<'a> Rem<&'a $$sname> for $$sname {
type Output = $$sname;
fn rem(self, rhs: &$$sname) -> Self::Output {
let (_, res) = self.divmod(rhs);
res
}
}
impl<'a,'b> Rem<&'a $$sname> for &'b $$sname {
type Output = $$sname;
fn rem(self, rhs: &$$sname) -> Self::Output {
let (_, res) = self.divmod(rhs);
res
}
}
impl RemAssign for $$sname {
fn rem_assign(&mut self, rhs: $$sname) {
let (_, res) = self.divmod(&rhs);
$@{copyAssign}
}
}
impl<'a> RemAssign<&'a $$sname> for $$sname {
fn rem_assign(&mut self, rhs: &$$sname) {
let (_, res) = self.divmod(rhs);
$@{copyAssign}
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("divmod", $$(testFileLit)), 4, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
let (neg3, rbytes) = case.get("r").unwrap();
assert!(!neg0 && !neg1 && !neg2 && !neg3);
let x = $$sname::from_bytes(&xbytes);
let y = $$sname::from_bytes(&ybytes);
let z = $$sname::from_bytes(&zbytes);
let r = $$sname::from_bytes(&rbytes);
let (myz, myr) = x.divmod(&y);
assert_eq!(z, myz);
assert_eq!(r, myr);
assert_eq!(z, &x / &y);
assert_eq!(r, &x % &y);
});
}
|]
doCopy :: Word -> Stmt Span
doCopy i =
let liti = toLit i
in [stmt| self.value[$$(liti)] = res.value[$$(liti)]; |]
generateDivisionTest :: RandomGen g => Word -> g -> (Map String String, g)
generateDivisionTest size g = go g
where
go g0 =
let (x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size
tcase = Map.fromList [("x", showX x), ("y", showX y),
("z", showX (x `div` y)),
("r", showX (x `mod` y))]
in if y == 0
then go g2
else (tcase, g2)

View File

@@ -1,71 +0,0 @@
module File(
File(..),
Task(..),
addModuleTasks,
makeTasks
)
where
import Control.Monad(forM_)
import Data.Char(toUpper)
import Data.List(isPrefixOf)
import qualified Data.Map.Strict as Map
import Gen(Gen,blank,out)
import System.FilePath(takeBaseName,takeDirectory,takeFileName,(</>))
data File = File {
predicate :: Word -> [Word] -> Bool,
outputName :: FilePath,
generator :: Word -> Gen (),
testGenerator :: Maybe (Word -> Gen ())
}
data Task = Task {
outputFile :: FilePath,
fileGenerator :: Gen ()
}
makeTasks :: FilePath -> FilePath ->
Word -> [Word] ->
File ->
[Task]
makeTasks srcBase testBase size allSizes file
| predicate file size allSizes =
let base = Task (srcBase </> ("u" ++ show size) </> outputName file <> ".rs") (generator file size)
in case testGenerator file of
Nothing -> [base]
Just x ->
[base, Task (testBase </> outputName file </> ("U" ++ show size ++ ".test")) (x size)]
| otherwise = []
addModuleTasks :: FilePath -> [Task] -> [Task]
addModuleTasks base baseTasks = unsignedTask : (baseTasks ++ moduleTasks)
where
moduleMap = foldr addModuleInfo Map.empty baseTasks
addModuleInfo task map
| base `isPrefixOf` outputFile task =
Map.insertWith (++) (takeDirectory (outputFile task))
[takeBaseName (outputFile task)]
map
| otherwise = map
moduleTasks = Map.foldrWithKey generateModuleTask [] moduleMap
generateModuleTask directory mods acc = acc ++ [Task {
outputFile = directory </> "mod.rs",
fileGenerator =
do forM_ mods $ \ modle -> out ("mod " ++ modle ++ ";")
blank
out ("pub use base::" ++ upcase (takeFileName directory) ++ ";")
}]
unsignedTask = Task {
outputFile = base </> "unsigned" </> "mod.rs",
fileGenerator =
do forM_ (Map.keys moduleMap) $ \ key ->
out ("mod " ++ takeFileName key ++ ";")
blank
forM_ (Map.keys moduleMap) $ \ key ->
out ("pub use " ++ takeFileName key ++ "::" ++
upcase (takeFileName key) ++ ";")
}
upcase :: String -> String
upcase = map toUpper

View File

@@ -1,115 +0,0 @@
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Gen(
Gen(Gen),
runGen,
gensym,
indent,
blank,
out,
wrapIndent,
implFor,
implFor',
implFor'',
newNum,
TestVariable(..),
)
where
import Control.Monad.RWS.Strict(RWS,evalRWS)
import Control.Monad.State.Class(MonadState,get,put)
import Control.Monad.Writer.Class(MonadWriter,tell)
import Data.Bits(shiftL)
import Data.List(replicate)
import Data.Word(Word)
import Numeric(showHex)
import System.Random(StdGen, newStdGen, random, randomR)
newtype Gen a = Gen { unGen :: RWS () String GenState a}
deriving (Applicative, Functor, Monad, MonadState GenState, MonadWriter String)
tabAmount :: Word
tabAmount = 4
data GenState = GenState {
indentAmount :: Word,
gensymIndex :: Word,
rng :: StdGen
}
initGenState :: IO GenState
initGenState =
do rng0 <- newStdGen
return GenState { indentAmount = 0, gensymIndex = 0, rng = rng0 }
runGen :: FilePath -> Gen a -> IO a
runGen path action =
do state0 <- initGenState
let (res, contents) = evalRWS (unGen action) () state0
writeFile path contents
return res
gensym :: String -> Gen String
gensym prefix =
do gs <- get
let gs' = gs{ gensymIndex = gensymIndex gs + 1 }
put gs'
return (prefix ++ show (gensymIndex gs))
indent :: Gen a -> Gen a
indent action =
do gs <- get
put gs{ indentAmount = indentAmount gs + tabAmount }
res <- action
put gs
return res
blank :: Gen ()
blank = tell "\n"
out :: String -> Gen ()
out val =
do gs <- get
tell (replicate (fromIntegral (indentAmount gs)) ' ')
tell val
tell "\n"
wrapIndent :: String -> Gen a -> Gen a
wrapIndent val middle =
do gs <- get
tell (replicate (fromIntegral (indentAmount gs)) ' ')
tell val
tell " {\n"
res <- indent middle
tell (replicate (fromIntegral (indentAmount gs)) ' ')
tell "}\n"
return res
implFor :: String -> String -> Gen a -> Gen a
implFor trait name middle =
wrapIndent ("impl " ++ trait ++ " for " ++ name) middle
implFor' :: String -> String -> Gen a -> Gen a
implFor' trait name middle =
wrapIndent ("impl<'a> " ++ trait ++ " for " ++ name) middle
implFor'' :: String -> String -> Gen a -> Gen a
implFor'' trait name middle =
wrapIndent ("impl<'a,'b> " ++ trait ++ " for " ++ name) middle
newNum :: Bool -> Word -> Gen Integer
newNum signed bits =
do gs <- get
let rng0 = rng gs
let high = (1 `shiftL` fromIntegral bits) - 1
let (v, rng1) = randomR (0, high) rng0
let (sign, rng2) = random rng1
let v' = if signed && sign then -v else v
put gs{ rng = rng2 }
return v'
class TestVariable a where
emitTestVariable :: Char -> a -> Gen ()
instance TestVariable Integer where
emitTestVariable c v =
out ([c] ++ ": " ++ showHex v "")

View File

@@ -0,0 +1,45 @@
module Generators
where
import Language.Rust.Data.Position
import Language.Rust.Syntax
import Numeric(showHex)
import System.Random(RandomGen,random,randomR)
toLit :: Word -> Expr Span
toLit i = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty
generateNum :: RandomGen g => g -> Word -> (Integer, g)
generateNum g size =
let (x, g') = random g
x' = x `mod` (2 ^ size)
in (x', g')
generateSignedNum :: RandomGen g => g -> Word -> (Integer, g)
generateSignedNum g size =
let biggest = (2 ^ (size - 1)) - 1
smallest = - (2 ^ (size - 1))
(x, g') = randomR (smallest, biggest) g
in (x, g')
modulate :: (Integral a, Integral b) => a -> b -> Integer
modulate x size = x' `mod` (2 ^ size')
where
x', size' :: Integer
size' = fromIntegral size
x' = fromIntegral x
modulate' :: (Num a, Integral a, Integral b) => a -> b -> Integer
modulate' x size = signum x' * ((abs x') `mod` (2 ^ size'))
where
x', size' :: Integer
size' = fromIntegral size
x' = fromIntegral x
showX :: Integer -> String
showX x | x < 0 = "-" ++ showX (abs x)
| otherwise = showHex x ""
showB :: Bool -> String
showB False = "0"
showB True = "1"

712
generation/src/Karatsuba.hs Normal file
View File

@@ -0,0 +1,712 @@
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Karatsuba(
Instruction(..)
, InstructionData(..)
, Variable
, runChecks
, runQuickCheck
, generateInstructions
, variableName
)
where
import Control.Monad.Fail(MonadFail(..))
import Control.Monad.Identity hiding (fail)
import Control.Monad.RWS.Strict hiding (fail)
import Data.Bits
import Data.LargeWord
import Data.List
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Data.Vector(Vector)
import qualified Data.Vector as V
import Data.Word
import Prelude hiding (fail)
import Test.QuickCheck hiding ((.&.))
import Debug.Trace
-- this drives the testing
inputWordSize :: Int
inputWordSize = 5
data InstructionData = InstructionData {
idInstructions :: [Instruction],
idInput1 :: [Variable],
idInput2 :: [Variable],
idOutput :: [Variable]
}
generateInstructions :: Word -> InstructionData
generateInstructions numdigits =
let (baseID, baseInstrs) = runMath $ do (x, xinstrs) <- listen $ V.replicateM (fromIntegral numdigits) (genDigit 1)
(y, yinstrs) <- listen $ V.replicateM (fromIntegral numdigits) (genDigit 1)
res <- karatsuba x y
return InstructionData {
idInstructions = xinstrs ++ yinstrs,
idInput1 = map name (V.toList x),
idInput2 = map name (V.toList y),
idOutput = map name (V.toList res)
}
(preps, reals) = splitAt (length (idInstructions baseID)) baseInstrs
in baseID{ idInstructions = preps ++ simplifyConstants reals }
-- -----------------------------------------------------------------------------
--
-- Instructions that we emit as a result of running Karatsuba, that can be
-- turned into Rust lines.
--
-- -----------------------------------------------------------------------------
newtype Variable = V Word
deriving (Eq, Ord, Show)
variableName :: Variable -> String
variableName (V x) = "t" ++ show x
-- these are in Intel form, as I was corrupted young, so the first argument
-- is the destination and the rest are the arguments.
data Instruction = Add Variable [Variable]
| CastDown Variable Variable
| CastUp Variable Variable
| Complement Variable Variable
| Declare64 Variable Word64
| Declare128 Variable Word128
| Mask Variable Variable Word128
| Multiply Variable [Variable]
| ShiftR Variable Variable Int
deriving (Eq, Show)
class Declarable a where
declare :: Variable -> a -> Instruction
instance Declarable Word64 where
declare n x = Declare64 n x
instance Declarable Word128 where
declare n x = Declare128 n x
type Env = (Map Variable Word64, Map Variable Word128)
step :: Env -> Instruction -> Env
step (env64, env128) i =
case i of
Add outname items ->
(env64, Map.insert outname (sum (map (getv env128) items)) env128)
CastDown outname item ->
(Map.insert outname (fromIntegral (getv env128 item)) env64, env128)
CastUp outname item ->
(env64, Map.insert outname (fromIntegral (getv env64 item)) env128)
Complement outname item ->
(Map.insert outname (complement (getv env64 item)) env64, env128)
Declare64 outname val ->
(Map.insert outname val env64, env128)
Declare128 outname val ->
(env64, Map.insert outname val env128)
Mask outname item mask ->
(env64, Map.insert outname (getv env128 item .&. mask) env128)
Multiply outname items ->
(env64, Map.insert outname (product (map (getv env128) items)) env128)
ShiftR outname item amt ->
(env64, Map.insert outname (getv env128 item `shiftR` amt) env128)
where
getv :: Map Variable a -> Variable -> a
getv env s =
case Map.lookup s env of
Nothing -> error ("Failure to find key '" ++ show s ++ "'")
Just v -> v
run :: Env -> [Instruction] -> Env
run env instrs =
case instrs of
[] -> env
(x:rest) -> run (step env x) rest
simplifyConstants :: [Instruction] -> [Instruction]
simplifyConstants instrs = go instrs Map.empty Map.empty Map.empty
where
go [] _ _ _ = []
go (instr:rest) consts64 consts128 remaps =
case instr of
Add outname items ->
Add outname (map (replace remaps) items) : go rest consts64 consts128 remaps
CastDown outname item ->
CastDown outname (replace remaps item) : go rest consts64 consts128 remaps
CastUp outname item ->
CastUp outname (replace remaps item) : go rest consts64 consts128 remaps
Complement outname item ->
Complement outname (replace remaps item) : go rest consts64 consts128 remaps
Declare64 outname val | Just outname' <- Map.lookup val consts64 ->
go rest consts64 consts128 (Map.insert outname outname' remaps)
Declare64 outname val ->
Declare64 outname val : go rest (Map.insert val outname consts64) consts128 remaps
Declare128 outname val | Just outname' <- Map.lookup val consts128 ->
go rest consts64 consts128 (Map.insert outname outname' remaps)
Declare128 outname val ->
Declare128 outname val : go rest consts64 (Map.insert val outname consts128) remaps
Mask outname item mask ->
Mask outname (replace remaps item) mask : go rest consts64 consts128 remaps
Multiply outname items ->
Multiply outname (map (replace remaps) items) : go rest consts64 consts128 remaps
ShiftR outname item amt ->
ShiftR outname (replace remaps item) amt : go rest consts64 consts128 remaps
replace :: Map Variable Variable -> Variable -> Variable
replace remaps item = Map.findWithDefault item item remaps
-- -----------------------------------------------------------------------------
--
-- The Math monad.
--
-- -----------------------------------------------------------------------------
newtype Math a = Math { unMath :: RWS () [Instruction] Word a }
deriving (Applicative, Functor, Monad,
MonadState Word,
MonadWriter [Instruction])
instance MonadFail Math where
fail s = error ("Math fail: " ++ s)
emit :: Instruction -> Math ()
emit instr = tell [instr]
newVariable :: Math Variable
newVariable =
do x <- state (\ i -> (i, i + 1))
return (V x)
runMath :: Math a -> (a, [Instruction])
runMath m = evalRWS (unMath m) () 0
-- -----------------------------------------------------------------------------
--
-- Primitive mathematics that can run on a Digit
--
-- -----------------------------------------------------------------------------
data Digit size = D {
name :: Variable
, digit :: size
}
deriving (Eq,Show)
genDigit :: Declarable size => size -> Math (Digit size)
genDigit x =
do newName <- newVariable
emit (declare newName x)
return D{
name = newName
, digit = x
}
embiggen :: Digit Word64 -> Math (Digit Word128)
embiggen x =
do newName <- newVariable
emit (CastUp newName (name x))
return (D newName (fromIntegral (digit x)))
bottomBits :: Digit Word128 -> Math (Digit Word64)
bottomBits x =
do newName <- newVariable
emit (CastDown newName (name x))
return (D newName (fromIntegral (digit x)))
oneDigit :: Math (Digit Word64)
oneDigit = genDigit 1
bigZero :: Math (Digit Word128)
bigZero = genDigit 0
(|+|) :: Digit Word128 -> Digit Word128 -> Math (Digit Word128)
(|+|) x y =
do newName <- newVariable
emit (Add newName [name x, name y])
let digval = digit x + digit y
return (D newName digval)
sumDigits :: [Digit Word128] -> Math (Digit Word128)
sumDigits ls =
do newName <- newVariable
emit (Add newName (map name ls))
let digval = sum (map digit ls)
return (D newName digval)
(|*|) :: Digit Word128 -> Digit Word128 -> Math (Digit Word128)
(|*|) x y =
do newName <- newVariable
emit (Multiply newName [name x, name y])
let digval = digit x * digit y
return (D newName digval)
(|>>|) :: Digit Word128 -> Int -> Math (Digit Word128)
(|>>|) x s =
do newName <- newVariable
emit (ShiftR newName (name x) s)
let digval = digit x `shiftR` s
return (D newName digval)
(|&|) :: Digit Word128 -> Word128 -> Math (Digit Word128)
(|&|) x m =
do newName <- newVariable
emit (Mask newName (name x) m)
let digval = digit x .&. m
return (D newName digval)
complementDigit :: Digit Word64 -> Math (Digit Word64)
complementDigit x =
do newName <- newVariable
emit (Complement newName (name x))
return (D newName (complement (digit x)))
-- -----------------------------------------------------------------------------
--
-- Extended mathematics that run on whole numbers
--
-- -----------------------------------------------------------------------------
type Number = Vector (Digit Word64)
convertTo :: Int -> Integer -> Math Number
convertTo sz num = V.fromList `fmap` go sz num
where
go :: Int -> Integer -> Math [Digit Word64]
go 0 _ =
return []
go x v =
do d <- genDigit (fromIntegral v)
rest <- go (x - 1) (v `shiftR` 64)
return (d:rest)
convertFrom :: Number -> Integer
convertFrom n = V.foldr combine 0 n
where
combine x acc = (acc `shiftL` 64) + fromIntegral (digit x)
prop_ConversionWorksInt :: Integer -> Bool
prop_ConversionWorksInt n = n' == back
where
n' = abs n `mod` (2 ^ (inputWordSize * 64))
there = fst (runMath (convertTo inputWordSize n'))
back = convertFrom there
zero :: Int -> Math Number
zero s = V.fromList `fmap` replicateM s (genDigit 0)
empty :: Number -> Bool
empty = null
size :: Number -> Int
size = length
splitDigits :: Int -> Number -> Math (Number, Number)
splitDigits i ls = return (V.splitAt i ls)
prop_SplitDigitsIsntTerrible :: Int -> Int -> Integer -> Bool
prop_SplitDigitsIsntTerrible a b n =
let a' = a `mod` 20
b' = b `mod` 20
(p, l) | a' > b' = (b', a')
| a' < b' = (a', b')
| otherwise = (a' - 1, a')
in fst $ runMath $ do base <- convertTo l n
(left, right) <- splitDigits p base
return (base == (left <> right))
addZeros :: Int -> Number -> Math Number
addZeros x n =
do prefix <- zero x
return (prefix <> n)
prop_AddZerosIsShift :: Int -> Integer -> Bool
prop_AddZerosIsShift x n =
fst $ runMath $ do base <- convertTo inputWordSize n'
added <- addZeros x' base
let shiftVer = n' `shiftL` (x' * 64)
let mine = convertFrom added
return (shiftVer == mine)
where
x' = abs x `mod` inputWordSize
n' = abs n `mod` (2 ^ (inputWordSize * 64))
padTo :: Int -> Number -> Math Number
padTo len num =
do suffix <- zero (len - V.length num)
return (num <> suffix)
prop_PadToWorks :: Int -> Int -> Integer -> Bool
prop_PadToWorks a b num =
fst $ runMath $ do base <- convertTo sz num'
padded <- padTo len base
let newval = convertFrom padded
return (num' == newval)
where
a' = abs a `mod` (inputWordSize * 3)
b' = abs b `mod` (inputWordSize * 3)
(len, sz) | a' >= b' = (max 1 a', max 1 b')
| otherwise = (max 1 b', max 1 a')
num' = abs (num `mod` (2 ^ (64 * sz)))
add2 :: Number -> Number -> Math Number
add2 xs ys
| size xs /= size ys =
fail "Add2 of uneven vectors."
| otherwise =
do let both = V.zip xs ys
nada <- bigZero
(res, carry) <- foldM ripple (V.empty, nada) both
lastDigit <- bottomBits carry
return (res <> V.singleton lastDigit)
where
ripple (res, carry) (x, y) =
do x' <- embiggen x
y' <- embiggen y
bigRes <- sumDigits [x', y', carry]
carry' <- bigRes |>>| 64
newdigit <- bottomBits bigRes
let res' = res <> V.singleton newdigit
return (res', carry')
prop_Add2Works :: Int -> Integer -> Integer -> Bool
prop_Add2Works l n m =
fst $ runMath $ do num1 <- convertTo l' n'
num2 <- convertTo l' m'
res <- add2 num1 num2
let intRes = convertFrom res
return ((intRes == r) && (size res == l' + 1))
where
l' = max 1 (abs l `mod` inputWordSize)
n' = abs n `mod` (2 ^ (l' * 64))
m' = abs m `mod` (2 ^ (l' * 64))
r = n' + m'
add3 :: Number -> Number -> Number -> Math Number
add3 x y z
| size x /= size y =
fail "Unequal lengths in add3 (1)."
| size y /= size z =
fail "Unequal lengths in add3 (2)."
| otherwise =
do let allThem = V.zip3 x y z
nada <- bigZero
(res, carry) <- foldM ripple (V.empty, nada) allThem
lastDigit <- bottomBits carry
return (res <> V.singleton lastDigit)
where
ripple (res, carry) (a, b, c) =
do a' <- embiggen a
b' <- embiggen b
c' <- embiggen c
bigRes <- sumDigits [a', b', c', carry]
carry' <- bigRes |>>| 64
digit' <- bottomBits bigRes
let res' = res <> V.singleton digit'
return (res', carry')
prop_Add3Works :: Int -> Integer -> Integer -> Integer -> Bool
prop_Add3Works l x y z =
fst $ runMath $ do num1 <- convertTo l' x'
num2 <- convertTo l' y'
num3 <- convertTo l' z'
res <- add3 num1 num2 num3
let intRes = convertFrom res
return ((intRes == r) && (size res == l' + 1))
where
l' = max 1 (abs l `mod` inputWordSize)
x' = abs x `mod` (2 ^ (l' * 64))
y' = abs y `mod` (2 ^ (l' * 64))
z' = abs z `mod` (2 ^ (l' * 64))
r = x' + y' + z'
sub2 :: Number -> Number -> Math Number
sub2 x y
| size x /= size y =
fail "Unequal lengths in sub."
| otherwise =
do yinv <- mapM complementDigit y
oned <- oneDigit
one <- padTo (size x) (V.singleton oned)
res <- add3 x yinv one
return (V.take (size x) res)
prop_Sub2Works :: Int -> Integer -> Integer -> Bool
prop_Sub2Works l a b =
fst $ runMath $ do num1 <- convertTo l' x
num2 <- convertTo l' y
res <- sub2 num1 num2
let intRes = convertFrom res
return (intRes == r)
where
l' = max 1 (abs l `mod` inputWordSize)
a' = abs a `mod` (2 ^ (l' * 64))
b' = abs b `mod` (2 ^ (l' * 64))
(x, y) | a' >= b' = (a', b')
| otherwise = (b', a')
r = x - y
-- -----------------------------------------------------------------------------
--
-- Finally, multiplication and Karatsuba
--
-- -----------------------------------------------------------------------------
mul1 :: Number -> Number -> Math Number
mul1 num1 num2
| size num1 /= 1 || size num2 /= 1 =
fail "Called mul1 with !1 digit numbers. Idiot."
| otherwise =
do x' <- embiggen (V.head num1)
y' <- embiggen (V.head num2)
comb <- x' |*| y'
z0 <- bottomBits comb
z1 <- bottomBits =<< (comb |>>| 64)
return (V.fromList [z0, z1])
prop_MulNWorks :: Int -> (Number -> Number -> Math Number) ->
Integer -> Integer ->
Bool
prop_MulNWorks nsize f x y =
fst $ runMath $ do num1 <- convertTo nsize x'
num2 <- convertTo nsize y'
res <- f num1 num2
let resInt = convertFrom res
return ((size res == (nsize * 2)) && (resInt == (x' * y')))
where
x' = abs x `mod` (2 ^ (64 * nsize))
y' = abs y `mod` (2 ^ (64 * nsize))
prop_Mul1Works :: Integer -> Integer -> Bool
prop_Mul1Works = prop_MulNWorks 1 mul1
mul2 :: Number -> Number -> Math Number
mul2 num1 num2
| size num1 /= 2 || size num2 /= 2 =
fail "Called mul2 with !2 digit numbers. Idiot."
| otherwise =
do [l0, l1] <- mapM embiggen (V.toList num1)
[r0, r1] <- mapM embiggen (V.toList num2)
--
l0r0 <- l0 |*| r0
carry0 <- l0r0 |>>| 64
dest0 <- bottomBits l0r0
l1r0 <- l1 |*| r0
l1r0' <- l1r0 |+| carry0
tdest1 <- l1r0' |&| 0xFFFFFFFFFFFFFFFF
tdest2 <- l1r0' |>>| 64
--
l0r1 <- l0 |*| r1
l0r1' <- tdest1 |+| l0r1
dest1 <- bottomBits l0r1'
l1r1 <- l1 |*| r1
l1r1' <- tdest2 |+| l1r1
carry1 <- l0r1' |>>| 64
l1r1'' <- l1r1' |+| carry1
dest2 <- bottomBits l1r1''
dest3 <- bottomBits =<< (l1r1'' |>>| 64)
return (V.fromList [dest0, dest1, dest2, dest3])
prop_Mul2Works :: Integer -> Integer -> Bool
prop_Mul2Works = prop_MulNWorks 2 mul2
mul3 :: Number -> Number -> Math Number
mul3 num1 num2
| size num1 /= 3 || size num2 /= 3 =
fail "Called mul2 with !2 digit numbers. Idiot."
| otherwise =
do [l0, l1, l2] <- mapM embiggen (V.toList num1)
[r0, r1, r2] <- mapM embiggen (V.toList num2)
--
l0r0 <- l0 |*| r0
dest0 <- bottomBits l0r0
carry0 <- l0r0 |>>| 64
l1r0 <- l1 |*| r0
l1r0' <- l1r0 |+| carry0
tdest1 <- l1r0' |&| 0xFFFFFFFFFFFFFFFF
carry1 <- l1r0' |>>| 64
l2r0 <- l2 |*| r0
l2r0' <- l2r0 |+| carry1
tdest2 <- l2r0' |&| 0xFFFFFFFFFFFFFFFF
tdest3 <- l2r0' |>>| 64
--
l0r1 <- l0 |*| r1
l0r1' <- tdest1 |+| l0r1
dest1 <- bottomBits l0r1'
carry2 <- l0r1' |>>| 64
l1r1 <- l1 |*| r1
l1r1' <- sumDigits [l1r1, tdest2, carry2]
tdest2' <- l1r1' |&| 0xFFFFFFFFFFFFFFFF
carry3 <- l1r1' |>>| 64
l2r1 <- l2 |*| r1
l2r1' <- sumDigits [l2r1, tdest3, carry3]
tdest3' <- l2r1' |&| 0xFFFFFFFFFFFFFFFF
tdest4' <- l2r1' |>>| 64
--
l0r2 <- l0 |*| r2
l0r2' <- l0r2 |+| tdest2'
dest2 <- bottomBits l0r2'
carry4 <- l0r2' |>>| 64
l1r2 <- l1 |*| r2
l1r2' <- sumDigits [l1r2, tdest3', carry4]
dest3 <- bottomBits l1r2'
carry5 <- l1r2' |>>| 64
l2r2 <- l2 |*| r2
l2r2' <- sumDigits [l2r2, tdest4', carry5]
dest4 <- bottomBits l2r2'
dest5 <- bottomBits =<< (l2r2' |>>| 64)
return (V.fromList [dest0, dest1, dest2, dest3, dest4, dest5])
prop_Mul3Works :: Integer -> Integer -> Bool
prop_Mul3Works = prop_MulNWorks 3 mul3
karatsuba :: Number -> Number -> Math Number
karatsuba num1 num2
| size num1 /= size num2 =
fail "Uneven numeric lengths!"
| empty num1 =
fail "Got empty nums"
| size num1 == 1 = mul1 num1 num2
| size num1 == 2 = mul2 num1 num2
| size num1 == 3 = mul3 num1 num2
| otherwise =
do let m = min (size num1) (size num2)
m2 = m `div` 2
(low1, high1) <- splitDigits (fromIntegral m2) num1
(low2, high2) <- splitDigits (fromIntegral m2) num2
z0 <- karatsuba low1 low2
let midsize = max (size low1) (size high1)
low1' <- padTo midsize low1
low2' <- padTo midsize low2
high1' <- padTo midsize high1
high2' <- padTo midsize high2
mid1 <- add2 low1' high1'
mid2 <- add2 low2' high2'
z1 <- karatsuba mid1 mid2
z2 <- karatsuba high1 high2
let subsize = max (size z0) (max (size z1) (size z2))
sz0 <- padTo subsize z0
sz1 <- padTo subsize z1
sz2 <- padTo subsize z2
tmp <- sub2 sz1 sz2
z1' <- addZeros m2 =<< sub2 tmp sz0
z2' <- addZeros (m2 * 2) z2
let addsize = max (size z0) (max (size z1') (size z2'))
az0 <- padTo addsize z0
az1 <- padTo addsize z1'
az2 <- padTo addsize z2'
res <- add3 az2 az1 az0
forM_ (V.drop (m * 2) res) $ \ highDigit ->
-- this will only occur when (size res > (m * 2))
when (digit highDigit /= 0) $
fail "High bit found in Karatsuba result"
return (V.take (m * 2) res)
prop_KaratsubaWorks :: Int -> Integer -> Integer -> Bool
prop_KaratsubaWorks l x y =
fst $ runMath $ do num1 <- convertTo l' x'
num2 <- convertTo l' y'
res <- karatsuba num1 num2
let resInt = convertFrom res
sizeOk = size res == (l' * 2)
valOk = resInt == (x' * y')
return (sizeOk && valOk)
where
l' = (abs l `mod` (inputWordSize * 2)) + 2
x' = abs x `mod` (2 ^ (64 * l'))
y' = abs y `mod` (2 ^ (64 * l'))
prop_InstructionsWork' ::
([Instruction] -> [Instruction]) ->
Int ->
Integer ->
Integer ->
Bool
prop_InstructionsWork' f l x y =
let (value, instructions) = runMath $ do numx <- convertTo l' x'
numy <- convertTo l' y'
karatsuba numx numy
resGMP = x' * y'
resKaratsuba = convertFrom value
(endEnvironment, _) = run (Map.empty, Map.empty) (f instructions)
instrVersion = V.map (getv endEnvironment . name) value
in (resGMP == resKaratsuba) && (value == instrVersion)
where
l' = max 1 (abs l `mod` inputWordSize)
x' = abs x `mod` (2 ^ (64 * l'))
y' = abs y `mod` (2 ^ (64 * l'))
getv env n =
case Map.lookup n env of
Nothing -> error ("InstrProp lookup failure: " ++ show n)
Just v -> D n v
prop_InstructionsWork :: Int -> Integer -> Integer -> Bool
prop_InstructionsWork = prop_InstructionsWork' id
prop_SimplifiedInstructionsWork :: Int -> Integer -> Integer -> Bool
prop_SimplifiedInstructionsWork = prop_InstructionsWork' simplifyConstants
prop_InstructionsConsistent :: Int -> Integer -> Integer -> Integer -> Integer -> Bool
prop_InstructionsConsistent l a b x y =
let (_, instrs1) = runMath (karatsuba' a' b')
(_, instrs2) = runMath (karatsuba' x' y')
instrs1' = dropWhile isDeclare64 instrs1
instrs2' = dropWhile isDeclare64 instrs2
in instrs1' == instrs2'
where
l' = max 1 (abs l `mod` inputWordSize)
a' = abs a `mod` (2 ^ (64 * l'))
b' = abs b `mod` (2 ^ (64 * l'))
x' = abs x `mod` (2 ^ (64 * l'))
y' = abs y `mod` (2 ^ (64 * l'))
karatsuba' p q =
do num1 <- convertTo l' p
num2 <- convertTo l' q
karatsuba num1 num2
isDeclare64 i =
case i of
Declare64 _ _ -> True
_ -> False
-- -----------------------------------------------------------------------------
--
-- Test running
--
-- -----------------------------------------------------------------------------
runQuickCheck :: Testable prop => String -> prop -> IO ()
runQuickCheck testname prop =
do putStr testname
quickCheck (withMaxSuccess 1000 prop)
runChecks :: IO ()
runChecks =
do runQuickCheck "Int -> Num -> Int " prop_ConversionWorksInt
runQuickCheck "Split Isn't Dumb " prop_SplitDigitsIsntTerrible
runQuickCheck "More 0s is Shift " prop_AddZerosIsShift
runQuickCheck "PadTo Does That " prop_PadToWorks
runQuickCheck "Add2 Works " prop_Add2Works
runQuickCheck "Add3 Works " prop_Add3Works
runQuickCheck "Sub2 Works " prop_Sub2Works
runQuickCheck "Mul1 Works " prop_Mul1Works
runQuickCheck "Mul2 Works " prop_Mul2Works
runQuickCheck "Mul3 Works " prop_Mul3Works
runQuickCheck "Karatsuba Works " prop_KaratsubaWorks
runQuickCheck "Instructions Work " prop_InstructionsWork
runQuickCheck "Simpl. Instructions Work " prop_SimplifiedInstructionsWork
runQuickCheck "Generation Consistent " prop_InstructionsConsistent

View File

@@ -1,62 +0,0 @@
module Main
where
import Base(base)
import BinaryOps(binaryOps)
import Compare(comparisons)
import Conversions(conversions)
import CryptoNum(cryptoNum)
import Control.Monad(forM_,unless)
import Data.Word(Word)
import File(File,Task(..),addModuleTasks,makeTasks)
import Gen(runGen)
import System.Directory(createDirectoryIfMissing)
import System.Environment(getArgs)
import System.Exit(die)
import System.FilePath(takeDirectory,(</>))
lowestBitsize :: Word
lowestBitsize = 192
highestBitsize :: Word
highestBitsize = 512
bitsizes :: [Word]
bitsizes = [lowestBitsize,lowestBitsize+64..highestBitsize]
unsignedFiles :: [File]
unsignedFiles = [
base
, binaryOps
, comparisons
, conversions
, cryptoNum
]
signedFiles :: [File]
signedFiles = [
]
makeTasks' :: FilePath -> FilePath -> [File] -> [Task]
makeTasks' srcPath testPath files =
concatMap (\ sz -> concatMap (makeTasks srcPath testPath sz bitsizes) files) bitsizes
makeAllTasks :: FilePath -> FilePath -> [Task]
makeAllTasks srcPath testPath = addModuleTasks srcPath $
makeTasks' (srcPath </> "unsigned") testPath unsignedFiles ++
makeTasks' (srcPath </> "signed") testPath signedFiles
main :: IO ()
main =
do args <- getArgs
unless (length args == 1) $
die ("generation takes exactly one argument, the target directory")
let topLevel = head args
srcPath = topLevel </> "src"
testPath = topLevel </> "testdata"
tasks = makeAllTasks srcPath testPath
total = length tasks
forM_ (zip [(1::Word)..] tasks) $ \ (i, task) ->
do putStrLn ("[" ++ show i ++ "/" ++ show total ++ "] " ++ outputFile task)
createDirectoryIfMissing True (takeDirectory (outputFile task))
runGen (outputFile task) (fileGenerator task)

331
generation/src/ModInv.hs Normal file
View File

@@ -0,0 +1,331 @@
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RecordWildCards #-}
module ModInv(
generateModInvOps
)
where
import Control.Exception(assert)
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Generators
import GHC.Integer.GMP.Internals(powModInteger, recipModInteger)
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
import System.Random(RandomGen)
generateModInvOps :: RustModule
generateModInvOps = RustModule {
predicate = \ me others -> (me + 64) `elem` others,
suggested = \ me -> [me + 64],
outputName = "modinv",
isUnsigned = True,
generator = declareModInv,
testCase = Just generateModInvTest
}
declareModInv :: Word -> [Word] -> SourceFile Span
declareModInv bitsize _ =
let sname = mkIdent ("I" ++ show (bitsize + 64))
uname = mkIdent ("U" ++ show bitsize)
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::convert::TryFrom;
use crate::{CryptoNum,ModularInversion};
use crate::signed::$$sname;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use crate::unsigned::$$uname;
impl ModularInversion for $$uname {
type Signed = $$sname;
fn modinv(&self, phi: &$$uname) -> Option<$$uname>
{
let (_, mut b, g) = phi.egcd(&self);
if g != $$sname::from(1i64) {
return None;
}
let sphi = $$sname::from(phi);
while b.is_negative() {
b += &sphi;
}
if b > sphi {
b -= &sphi;
}
Some($$uname::try_from(b).expect("overflow/underflow in modinv result"))
}
fn egcd(&self, rhs: &$$uname) -> ($$sname, $$sname, $$sname) {
// INPUT: two positive integers x and y.
let mut x = $$sname::from(self);
let mut y = $$sname::from(rhs);
// OUTPUT: integers a, b, and v such that ax + by = v,
// where v = gcd(x, y).
// 1. g1.
let mut gshift: usize = 0;
// 2. While x and y are both even, do the following: xx/2,
// yy/2, g2g.
while x.is_even() && y.is_even() {
x >>= 1u64;
y >>= 1u64;
gshift += 1;
}
// 3. ux, vy, A1, B0, C0, D1.
let mut u = x.clone();
let mut v = y.clone();
#[allow(non_snake_case)]
let mut A = $$sname::from(1i64);
#[allow(non_snake_case)]
let mut B = $$sname::zero();
#[allow(non_snake_case)]
let mut C = $$sname::zero();
#[allow(non_snake_case)]
let mut D = $$sname::from(1i64);
loop {
// 4. While u is even do the following:
while u.is_even() {
// 4.1 uu/2.
u >>= 1u64;
// 4.2 If AB0 (mod 2) then AA/2, BB/2; otherwise,
// A(A + y)/2, B(B x)/2.
if A.is_even() && B.is_even() {
A >>= 1u64;
B >>= 1u64;
} else {
A += &y;
A >>= 1u64;
B -= &x;
B >>= 1u64;
}
}
// 5. While v is even do the following:
while v.is_even() {
// 5.1 vv/2.
v >>= 1u64;
// 5.2 If C D 0 (mod 2) then CC/2, DD/2; otherwise,
// C(C + y)/2, D(D x)/2.
if C.is_even() && D.is_even() {
C >>= 1u64;
D >>= 1u64;
} else {
C += &y;
C >>= 1u64;
D -= &x;
D >>= 1u64;
}
}
// 6. If uv then uuv, AAC,BBD;
// otherwise,vvu, CCA, DDB.
if u >= v {
u -= &v;
A -= &C;
B -= &D;
} else {
v -= &u;
C -= &A;
D -= &B;
}
// 7. If u = 0, then aC, bD, and return(a, b, g · v);
// otherwise, go to step 4.
if u.is_zero() {
return (C, D, v << gshift);
}
}
}
fn gcd_is_one(&self, b: &$$uname) -> bool {
let mut u = self.clone();
let mut v = b.clone();
let one = $$uname::from(1u64);
if u.is_zero() {
return v == one;
}
if v.is_zero() {
return u == one;
}
if u.is_even() && v.is_even() {
return false;
}
while u.is_even() {
u >>= 1u64;
}
loop {
while v.is_even() {
v >>= 1u64;
}
// u and v guaranteed to be odd right now.
if u > v {
// make sure that v > u, so that our subtraction works
// out.
let t = u;
u = v;
v = t;
}
v -= &u;
if v.is_zero() {
return u == one;
}
}
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("modinv", $$(testFileLit)), 6, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
let (neg3, abytes) = case.get("a").unwrap();
let (neg4, bbytes) = case.get("b").unwrap();
let (neg5, vbytes) = case.get("v").unwrap();
assert!(!neg0 && !neg1 && !neg2);
let x = $$uname::from_bytes(xbytes);
let y = $$uname::from_bytes(ybytes);
let z = $$uname::from_bytes(zbytes);
let mut a = $$sname::from_bytes(abytes);
let mut b = $$sname::from_bytes(bbytes);
let mut v = $$sname::from_bytes(vbytes);
if *neg3 { a = -a }
if *neg4 { b = -b }
if *neg5 { v = -v }
let (mya, myb, myv) = x.egcd(&y);
assert_eq!(a, mya);
assert_eq!(b, myb);
assert_eq!(v, myv);
assert_eq!(z, x.modinv(&y).expect("Didn't find a modinv?"));
assert_eq!(v == $$sname::from(1u64), x.gcd_is_one(&y));
});
}
|]
generateModInvTest :: RandomGen g => Word -> g -> (Map String String, g)
generateModInvTest size g = go g
where
go g0 =
let (x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size
z = recipModInteger x y
(a, b, v) = extendedGCD x y
tcase = Map.fromList [("x", showX x), ("y", showX y),
("z", showX z), ("a", showX a),
("b", showX b), ("v", showX v)]
in if z == 0
then go g2
else assert (z < y) $
assert (powModInteger x z y == 1) $
-- assert ((x * z) `mod` y == 1) $
-- assert (((a * x) + (b * y)) == v) $
-- assert (v == gcd x y) $
(tcase, g2)
extendedGCD :: Integer -> Integer -> (Integer, Integer, Integer)
extendedGCD x y = (a, b, g * (v finalState))
where
(x', y', g, initState) = initialState x y 1
finalState = runAlgorithm x' y' initState
a = bigC finalState
b = bigD finalState
data AlgState = AlgState {
u :: Integer,
v :: Integer,
bigA :: Integer,
bigB :: Integer,
bigC :: Integer,
bigD :: Integer
}
initialState :: Integer -> Integer -> Integer -> (Integer, Integer, Integer, AlgState)
initialState x y g | even x && even y = initialState (x `div` 2) (y `div` 2) (g * 2)
| otherwise = (x, y, g, AlgState x y 1 0 0 1)
printState :: AlgState -> IO ()
printState a =
do putStrLn ("u: " ++ showX (u a))
putStrLn ("v: " ++ showX (v a))
putStrLn ("A: " ++ showX (bigA a))
putStrLn ("B: " ++ showX (bigB a))
putStrLn ("C: " ++ showX (bigC a))
putStrLn ("D: " ++ showX (bigD a))
runAlgorithm :: Integer -> Integer -> AlgState -> AlgState
runAlgorithm x y state | u state == 0 = state
| otherwise = runAlgorithm x y state6
where
state4 = step4 x y state
state5 = step5 x y state4
state6 = step6 state5
step4 :: Integer -> Integer -> AlgState -> AlgState
step4 x y input@AlgState{..} | even u = step4 x y input'
| otherwise = input
where
input' = AlgState u' v bigA' bigB' bigC bigD
u' = u `div` 2
bigA' | even bigA && even bigB = bigA `div` 2
| otherwise = (bigA + y) `div` 2
bigB' | even bigA && even bigB = bigB `div` 2
| otherwise = (bigB - x) `div` 2
step5 :: Integer -> Integer -> AlgState -> AlgState
step5 x y input@AlgState{..} | even v = step5 x y input'
| otherwise = input
where
input' = AlgState u v' bigA bigB bigC' bigD'
v' = v `div` 2
bigC' | even bigC && even bigD = bigC `div` 2
| otherwise = (bigC + y) `div` 2
bigD' | even bigC && even bigD = bigD `div` 2
| otherwise = (bigD - x) `div` 2
step6 :: AlgState -> AlgState
step6 AlgState{..}
| u >= v = AlgState (u - v) v (bigA - bigC) (bigB - bigD) bigC bigD
| otherwise = AlgState u (v - u) bigA bigB (bigC - bigA) (bigD - bigB)
_run :: Integer -> Integer -> IO ()
_run inputx inputy =
do let (x, y, g, initState) = initialState inputx inputy 1
finalState <- go x y initState
putStrLn ("-- FINAL STATE -----------------------")
printState finalState
putStrLn ("Final value: " ++ showX (g * v finalState))
putStrLn ("-- RUN ------")
printState (runAlgorithm x y initState)
putStrLn ("-- NORMAL ------")
let (a, b, v) = extendedGCD inputx inputy
putStrLn ("a: " ++ showX a)
putStrLn ("b: " ++ showX b)
putStrLn ("v: " ++ showX v)
where
go x y state =
do putStrLn "-- STATE -----------------------------"
printState state
if u state == 0
then return state
else do let state' = step4 x y state
state'' = step5 x y state'
state''' = step6 state''
go x y state'''

129
generation/src/ModOps.hs Normal file
View File

@@ -0,0 +1,129 @@
{-# LANGUAGE QuasiQuotes #-}
module ModOps(modulusOps)
where
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Generators
import GHC.Integer.GMP.Internals(powModInteger)
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
import System.Random(RandomGen)
modulusOps :: RustModule
modulusOps = RustModule {
predicate = \ me others -> (me * 2) `elem` others,
suggested = \ me -> [me * 2],
outputName = "modops",
isUnsigned = True,
generator = declareModOps,
testCase = Just generateModulusTest
}
declareModOps :: Word -> [Word] -> SourceFile Span
declareModOps bitsize _ =
let sname = mkIdent ("U" ++ show bitsize)
bname = mkIdent ("U" ++ show (bitsize * 2))
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::convert::TryFrom;
use crate::unsigned::{$$sname, $$bname};
use crate::{DivMod, ModularOperations};
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
impl ModularOperations for $$sname {
fn reduce(&self, m: &$$sname) -> $$sname {
let (_, res) = self.divmod(m);
res
}
fn modmul(&self, y: &$$sname, m: &$$sname) -> $$sname {
let r = self * y;
let bigm = $$bname::from(m);
let bigres = r % bigm;
$$sname::try_from(bigres)
.expect("Mathematics is broken?! (mod returned too big result")
}
fn modsq(&self, m: &$$sname) -> $$sname {
let r = self * self;
let bigm = $$bname::from(m);
let bigres = r % bigm;
$$sname::try_from(bigres)
.expect("Mathematics is broken?! (mod returned too big result")
}
fn modexp(&self, e: &$$sname, m: &$$sname) -> $$sname {
let mut r = $$sname::from(1u64);
let bigm = $$bname::from(m);
for digit in e.value.iter().rev() {
for bit in (0..64).rev() {
r = r.modsq(&m);
let big_possible_r = (&r * self) % &bigm;
let possible_r = $$sname::try_from(big_possible_r)
.expect("Math is broken (again)");
let bit = (*digit >> bit) & 1;
r = if bit == 1 { possible_r } else { r };
}
}
r
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("modops", $$(testFileLit)), 7, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, mbytes) = case.get("m").unwrap();
let (neg3, rbytes) = case.get("r").unwrap();
let (neg4, tbytes) = case.get("t").unwrap();
let (neg5, sbytes) = case.get("s").unwrap();
let (neg6, ebytes) = case.get("e").unwrap();
assert!(!neg0 && !neg1 && !neg2 && !neg3 && !neg4 && !neg5 && !neg6);
let x = $$sname::from_bytes(&xbytes);
let y = $$sname::from_bytes(&ybytes);
let m = $$sname::from_bytes(&mbytes);
let r = $$sname::from_bytes(&rbytes);
let t = $$sname::from_bytes(&tbytes);
let s = $$sname::from_bytes(&sbytes);
let e = $$sname::from_bytes(&ebytes);
assert_eq!(r, x.reduce(&m));
assert_eq!(t, x.modmul(&y, &m));
assert_eq!(s, x.modsq(&m));
assert_eq!(e, x.modexp(&y, &m));
});
}
|]
generateModulusTest :: RandomGen g => Word -> g -> (Map String String, g)
generateModulusTest size g = go g
where
go g0 =
let (x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size
(m, g3) = generateNum g2 size
tcase = Map.fromList [("x", showX x), ("y", showX y),
("m", showX m),
("r", showX (x `mod` m)),
("t", showX ((x * y) `mod` m)),
("s", showX (powModInteger x 2 m)),
("e", showX (powModInteger x y m))
]
in if y < 2
then go g3
else (tcase, g3)

325
generation/src/Multiply.hs Normal file
View File

@@ -0,0 +1,325 @@
{-# LANGUAGE QuasiQuotes #-}
module Multiply(
safeMultiplyOps
, unsafeMultiplyOps
)
where
import Data.Bits((.&.))
import Data.List(foldl')
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Generators
import Karatsuba
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
import System.Random(RandomGen)
safeMultiplyOps :: RustModule
safeMultiplyOps = RustModule {
predicate = \ me others -> (me * 2) `elem` others,
suggested = \ me -> [me * 2],
outputName = "safe_mul",
isUnsigned = True,
generator = declareSafeMulOperators,
testCase = Just generateSafeTest
}
unsafeMultiplyOps :: RustModule
unsafeMultiplyOps = RustModule {
predicate = \ _ _ -> True,
suggested = const [],
outputName = "unsafe_mul",
isUnsigned = True,
generator = declareUnsafeMulOperators,
testCase = Just generateUnsafeTest
}
declareSafeMulOperators :: Word -> [Word] -> SourceFile Span
declareSafeMulOperators bitsize _ =
let sname = mkIdent ("U" ++ show bitsize)
dname = mkIdent ("U" ++ show (bitsize * 2))
fullRippleMul = generateMultiplier True (bitsize `div` 64) "rhs" "res"
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::Mul;
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
#[cfg(test)]
use quickcheck::quickcheck;
use crate::unsigned::{$$sname,$$dname};
impl Mul for $$sname {
type Output = $$dname;
fn mul(self, rhs: $$sname) -> $$dname {
&self * &rhs
}
}
impl<'a> Mul<&'a $$sname> for $$sname {
type Output = $$dname;
fn mul(self, rhs: &$$sname) -> $$dname {
&self * rhs
}
}
impl<'a> Mul<$$sname> for &'a $$sname {
type Output = $$dname;
fn mul(self, rhs: $$sname) -> $$dname {
self * &rhs
}
}
impl<'a,'b> Mul<&'a $$sname> for &'b $$sname {
type Output = $$dname;
fn mul(self, rhs: &$$sname) -> $$dname {
let mut res = $$dname::zero();
$@{fullRippleMul}
res
}
}
#[cfg(test)]
quickcheck! {
fn multiplication_symmetric(a: $$sname, b: $$sname) -> bool {
(&a * &b) == (&b * &a)
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("safe_mul", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
assert!(!neg0 && !neg1 && !neg2);
let x = $$sname::from_bytes(&xbytes);
let y = $$sname::from_bytes(&ybytes);
let z = $$dname::from_bytes(&zbytes);
assert_eq!(z, x * y);
});
}
|]
declareUnsafeMulOperators :: Word -> [Word] -> SourceFile Span
declareUnsafeMulOperators bitsize _ =
let sname = mkIdent ("U" ++ show bitsize)
halfRippleMul = generateMultiplier False (bitsize `div` 64) "rhs" "self"
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::MulAssign;
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
#[cfg(test)]
use quickcheck::quickcheck;
use crate::unsigned::$$sname;
impl MulAssign for $$sname {
fn mul_assign(&mut self, rhs: $$sname) {
self.mul_assign(&rhs);
}
}
impl<'a> MulAssign<&'a $$sname> for $$sname {
fn mul_assign(&mut self, rhs: &$$sname) {
$@{halfRippleMul}
}
}
#[cfg(test)]
quickcheck! {
fn multiplication_symmetric(a: $$sname, b: $$sname) -> bool {
let a2 = a.clone();
let mut b2 = b.clone();
let mut a3 = a.clone();
let b3 = b.clone();
b2 *= &a2;
a3 *= b3;
a3 == b2
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("unsafe_mul", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
assert!(!neg0 && !neg1 && !neg2);
let mut x = $$sname::from_bytes(&xbytes);
let y = $$sname::from_bytes(&ybytes);
let z = $$sname::from_bytes(&zbytes);
x *= y;
assert_eq!(z, x);
});
}
|]
-- -----------------------------------------------------------------------------
generateMultiplier :: Bool -> Word -> String -> String -> [Stmt Span]
generateMultiplier fullmul size inName outName = readIns ++ operations ++ writeOuts
where
outDigits | fullmul = size * 2
| otherwise = size
--
outVars = map (("res" ++) . show) [0..outDigits-1]
instructionData = generateInstructions size
instrOutputs = take (fromIntegral outDigits) (idOutput instructionData)
instructions = releaseUnnecessary instrOutputs (idInstructions instructionData)
--
readIns = map (load "self" "x") [0..size-1] ++ map (load inName "y") [0..size-1]
writeOuts = map (store "res") [0..outDigits-1]
--
env = zip (idInput1 instructionData) (map (\ i -> "x" ++ show i) [0..size-1]) ++
zip (idInput2 instructionData) (map (\ i -> "y" ++ show i) [0..size-1]) ++
zip (idOutput instructionData) outVars
operations = map (translateInstruction env) instructions
--
load rhs vname i =
let liti = toLit i
vec = mkIdent rhs
var = mkIdent (vname ++ show i)
in [stmt| let $$var = $$vec.value[$$(liti)]; |]
store vname i =
let liti = toLit i
vec = mkIdent outName
var = mkIdent (vname ++ show i)
in [stmt| $$vec.value[$$(liti)] = $$var; |]
translateInstruction :: [(Variable, String)] -> Instruction -> Stmt Span
translateInstruction env instr =
case instr of
Add outname args ->
let outid = mkIdentO outname
args' = map (\x -> [expr| $$x |]) (map mkIdentI args)
adds = foldl (\ x y -> [expr| $$(x) + $$(y) |])
(head args')
(tail args')
in [stmt| let $$outid: u128 = $$(adds); |]
CastDown outname arg ->
let outid = mkIdentO outname
inid = mkIdentI arg
in [stmt| let $$outid: u64 = $$inid as u64; |]
CastUp outname arg ->
let outid = mkIdentO outname
inid = mkIdentI arg
in [stmt| let $$outid: u128 = $$inid as u128; |]
Complement outname arg ->
let outid = mkIdentO outname
inid = mkIdentI arg
in [stmt| let $$outid: u64 = !$$inid; |]
Declare64 outname _ | Just inName <- lookup outname env ->
let outid = mkIdent (variableName outname)
inid = mkIdent inName
in [stmt| let $$outid: u64 = $$inid; |]
Declare64 outname arg ->
let outid = mkIdentO outname
val = toLit (fromIntegral arg)
in [stmt| let $$outid: u64 = $$(val); |]
Declare128 outname arg ->
let outid = mkIdentO outname
val = toLit (fromIntegral arg)
in [stmt| let $$outid: u128 = $$(val); |]
Mask outname arg mask ->
let outid = mkIdentO outname
inid = mkIdentI arg
val = toLit (fromIntegral mask)
in [stmt| let $$outid: u128 = $$inid & $$(val); |]
Multiply outname args ->
let outid = mkIdentO outname
args' = map (\x -> [expr| $$x |]) (map mkIdentI args)
muls = foldl (\ x y -> [expr| $$(x) * $$(y) |])
(head args')
(tail args')
in [stmt| let $$outid: u128 = $$(muls); |]
ShiftR outname arg amt ->
let outid = mkIdentO outname
inid = mkIdentI arg
val = toLit (fromIntegral amt)
in [stmt| let $$outid: u128 = $$inid >> $$(val); |]
where
mkIdentO :: Variable -> Ident
mkIdentO v | Just x <- lookup v env = mkIdent x
| otherwise = mkIdent (variableName v)
mkIdentI :: Variable -> Ident
mkIdentI = mkIdent . variableName
releaseUnnecessary :: [Variable] -> [Instruction] -> [Instruction]
releaseUnnecessary outkeys instrs = go (Set.fromList outkeys) [] rInstrs
where
rInstrs = reverse instrs
--
go _ acc [] = acc
go required acc (cur:rest)
| outVar cur `Set.member` required =
go (foldl' (flip Set.insert) required (inVars cur)) (cur:acc) rest
| otherwise =
go required acc rest
outVar :: Instruction -> Variable
outVar instr =
case instr of
Add outname _ -> outname
CastDown outname _ -> outname
CastUp outname _ -> outname
Complement outname _ -> outname
Declare64 outname _ -> outname
Declare128 outname _ -> outname
Mask outname _ _ -> outname
Multiply outname _ -> outname
ShiftR outname _ _ -> outname
inVars :: Instruction -> [Variable]
inVars instr =
case instr of
Add _ args -> args
CastDown _ arg -> [arg]
CastUp _ arg -> [arg]
Complement _ arg -> [arg]
Declare64 _ _ -> []
Declare128 _ _ -> []
Mask _ arg _ -> [arg]
Multiply _ args -> args
ShiftR _ arg _ -> [arg]
-- -----------------------------------------------------------------------------
generateSafeTest :: RandomGen g => Word -> g -> (Map String String, g)
generateSafeTest size g0 = (tcase, g2)
where
(x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size
tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX (x * y))]
generateUnsafeTest :: RandomGen g => Word -> g -> (Map String String, g)
generateUnsafeTest size g0 = (tcase, g2)
where
(x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size
z = (x * y) .&. ((2 ^ size) - 1)
tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX z)]

View File

@@ -0,0 +1,144 @@
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE RankNTypes #-}
module RustModule(
RustModule(..),
Task(..),
generateTasks,
testFile
)
where
import Control.Monad(forM_, unless)
import Data.Char(toUpper)
import Data.List(partition)
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Data.Maybe(mapMaybe)
import Language.Rust.Data.Ident(mkIdent)
import Language.Rust.Data.Position(Position(NoPosition), Span(Span))
import Language.Rust.Pretty(writeSourceFile)
import Language.Rust.Quote(item, sourceFile)
import Language.Rust.Syntax(Item(..), SourceFile(..), Visibility(..))
import System.CPUTime(getCPUTime)
import System.IO(Handle,hPutStrLn)
import System.Random(RandomGen(..))
minimumTestCases :: Int
minimumTestCases = 10
maximumTestCases :: Int
maximumTestCases = 5000
targetTestGenerationTime :: Float
targetTestGenerationTime = 2.0 -- in seconds
targetTestGenerationPicos :: Integer
targetTestGenerationPicos =
floor (targetTestGenerationTime * 1000000000000.0)
data RustModule = RustModule {
predicate :: Word -> [Word] -> Bool,
suggested :: Word -> [Word],
outputName :: String,
isUnsigned :: Bool,
generator :: Word -> [Word] -> SourceFile Span,
testCase :: forall g. RandomGen g => Maybe (Word -> g -> (Map String String, g))
}
data Task = Task {
outputFile :: FilePath,
writer :: Handle -> IO ()
}
testFile :: Bool -> Word -> FilePath
testFile True size = "U" ++ show5 size ++ ".test"
testFile False size = "I" ++ show5 size ++ ".test"
show5 :: Word -> String
show5 = go . show
where
go x | length x < 5 = go ('0' : x)
| otherwise = x
generateTasks :: RandomGen g => g -> [RustModule] -> [Word] -> [Task]
generateTasks rng modules sizes = allTheFiles
where
allTheFiles = implementationsAndTests ++
[lump "i" "src/signed.rs", lump "u" "src/unsigned.rs"]
implementationsAndTests = concatMap generateModules sizes
--
lump prefix file =
let moduleNames = map (\s -> prefix ++ show s) sizes
moduleIdents = map mkIdent moduleNames
types = map (mkIdent . map toUpper) moduleNames
mods = map (\ name -> [item| mod $$name; |]) moduleIdents
uses = zipWith (\ mname tname -> [item| pub use $$mname::$$tname; |])
moduleIdents types
source = [sourceFile| $@{mods} $@{uses} |]
in Task file (\hndl -> writeSourceFile hndl source)
--
generateModules size =
let modules' = filter (\m -> predicate m size sizes) modules
(umodules, smodules) = partition isUnsigned modules'
unsignedTasks = generateImplementations "U" size umodules
signedTasks = generateImplementations "I" size smodules
in unsignedTasks ++ signedTasks ++ mapMaybe (generateTests size rng) modules'
--
generateImplementations startsWith size modules'
| null modules' = []
| otherwise =
let name = mkIdent (startsWith ++ show size)
baseInclude = [item| pub use self::base::$$name; |]
isSigned = startsWith == "I"
moduleSources = map (generateSubmodule isSigned size sizes) modules'
moduleFile | isSigned = "src/signed/i" ++ show size ++ ".rs"
| otherwise = "src/unsigned/u" ++ show size ++ ".rs"
allSource = SourceFile Nothing [] (baseInclude : map fst moduleSources)
in [Task moduleFile (\ hndl -> writeSourceFile hndl allSource)] ++ map snd moduleSources
generateSubmodule :: Bool -> Word -> [Word] -> RustModule -> (Item Span, Task)
generateSubmodule isSigned size allSizes m =
let modBody = generator m size allSizes
modName = mkIdent (outputName m)
modDecl = Mod [] CrateV modName Nothing (Span NoPosition NoPosition)
modFile | isSigned = "src/signed/i" ++ show size ++ "/" ++ outputName m ++ ".rs"
| otherwise = "src/unsigned/u" ++ show size ++ "/" ++ outputName m ++ ".rs"
in (modDecl, Task modFile (\ hndl -> writeSourceFile hndl modBody))
generateTests :: RandomGen g =>
Word -> g ->
RustModule ->
Maybe Task
generateTests size rng m = fmap builder (testCase m)
where
builder testGenerator =
let outFile = "testdata/" ++ outputName m ++ "/" ++ testFile (isUnsigned m) size
testGenAction hndl = writeTestCases hndl (snd (split rng)) (testGenerator size)
in Task outFile testGenAction
writeTestCases :: RandomGen g =>
Handle -> g ->
(g -> (Map String String, g)) ->
IO ()
writeTestCases hndl rng nextTest =
do startTime <- getCPUTime
let stopTime = startTime + targetTestGenerationPicos
go 0 stopTime rng
where
go x endTime g
| x >= maximumTestCases = return ()
| x < minimumTestCases = emit x endTime g
| otherwise =
do now <- getCPUTime
unless (now >= endTime) $
emit x endTime g
--
emit x endTime g =
do let (test, g') = nextTest g
writeTestCase hndl test
go (x + 1) endTime g'
writeTestCase :: Handle -> Map String String -> IO ()
writeTestCase hndl test =
forM_ (Map.toList test) $ \ (key, value) ->
hPutStrLn hndl (key ++ ": " ++ value)

270
generation/src/Scale.hs Normal file
View File

@@ -0,0 +1,270 @@
{-# LANGUAGE QuasiQuotes #-}
module Scale(
safeScaleOps
, unsafeScaleOps
)
where
import Data.Bits((.&.))
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Generators
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
import System.Random(RandomGen)
safeScaleOps :: RustModule
safeScaleOps = RustModule {
predicate = \ me others -> (me + 64) `elem` others,
suggested = \ me -> [me + 64],
outputName = "safe_scale",
isUnsigned = True,
generator = declareSafeScaleOperators,
testCase = Just generateSafeTest
}
unsafeScaleOps :: RustModule
unsafeScaleOps = RustModule {
predicate = \ _ _ -> True,
suggested = const [],
outputName = "unsafe_scale",
isUnsigned = True,
generator = declareUnsafeScaleOperators,
testCase = Just generateUnsafeTest
}
declareSafeScaleOperators :: Word -> [Word] -> SourceFile Span
declareSafeScaleOperators bitsize _ =
let sname = mkIdent ("U" ++ show bitsize)
dname = mkIdent ("U" ++ show (bitsize + 64))
fullRippleScale = generateScaletiplier True (bitsize `div` 64) "rhs" "res"
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::Mul;
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use crate::unsigned::{$$sname,$$dname};
impl Mul<u8> for $$sname {
type Output = $$dname;
fn mul(self, rhs: u8) -> $$dname {
&self * (rhs as u64)
}
}
impl<'a> Mul<u8> for &'a $$sname {
type Output = $$dname;
fn mul(self, rhs: u8) -> $$dname {
self * (rhs as u64)
}
}
impl Mul<u16> for $$sname {
type Output = $$dname;
fn mul(self, rhs: u16) -> $$dname {
&self * (rhs as u64)
}
}
impl<'a> Mul<u16> for &'a $$sname {
type Output = $$dname;
fn mul(self, rhs: u16) -> $$dname {
self * (rhs as u64)
}
}
impl Mul<u32> for $$sname {
type Output = $$dname;
fn mul(self, rhs: u32) -> $$dname {
&self * (rhs as u64)
}
}
impl<'a> Mul<u32> for &'a $$sname {
type Output = $$dname;
fn mul(self, rhs: u32) -> $$dname {
self * (rhs as u64)
}
}
impl Mul<usize> for $$sname {
type Output = $$dname;
fn mul(self, rhs: usize) -> $$dname {
&self * (rhs as u64)
}
}
impl<'a> Mul<usize> for &'a $$sname {
type Output = $$dname;
fn mul(self, rhs: usize) -> $$dname {
self * (rhs as u64)
}
}
impl Mul<u64> for $$sname {
type Output = $$dname;
fn mul(self, rhs: u64) -> $$dname {
&self * (rhs as u64)
}
}
impl<'a> Mul<u64> for &'a $$sname {
type Output = $$dname;
fn mul(self, rhs: u64) -> $$dname {
let mut res = $$dname::zero();
$@{fullRippleScale}
res
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("safe_scale", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
assert!(!neg0 && !neg1 && !neg2);
let x = $$sname::from_bytes(&xbytes);
let y = $$sname::from_bytes(&ybytes);
let z = $$dname::from_bytes(&zbytes);
assert_eq!(z, x * y.value[0]);
});
}
|]
declareUnsafeScaleOperators :: Word -> [Word] -> SourceFile Span
declareUnsafeScaleOperators bitsize _ =
let sname = mkIdent ("U" ++ show bitsize)
halfRippleScale = generateScaletiplier False (bitsize `div` 64) "rhs" "self"
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::MulAssign;
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use crate::unsigned::$$sname;
impl MulAssign<u8> for $$sname {
fn mul_assign(&mut self, rhs: u8) {
self.mul_assign(rhs as u64);
}
}
impl MulAssign<u16> for $$sname {
fn mul_assign(&mut self, rhs: u16) {
self.mul_assign(rhs as u64);
}
}
impl MulAssign<u32> for $$sname {
fn mul_assign(&mut self, rhs: u32) {
self.mul_assign(rhs as u64);
}
}
impl MulAssign<usize> for $$sname {
fn mul_assign(&mut self, rhs: usize) {
self.mul_assign(rhs as u64);
}
}
impl MulAssign<u64> for $$sname {
fn mul_assign(&mut self, rhs: u64) {
$@{halfRippleScale}
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("unsafe_scale", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
assert!(!neg0 && !neg1 && !neg2);
let mut x = $$sname::from_bytes(&xbytes);
let y = $$sname::from_bytes(&ybytes);
let z = $$sname::from_bytes(&zbytes);
x *= y.value[0];
assert_eq!(z, x);
});
}
|]
-- -----------------------------------------------------------------------------
generateScaletiplier :: Bool -> Word -> String -> String -> [Stmt Span]
generateScaletiplier full size input output = loaders ++ [bigy] ++ ripples ++
carryCatch ++ stores
where
outSize | full = size + 1
| otherwise = size
loaders = map load [0..size-1]
bigy = let invar = mkIdent input
in [stmt| let y = $$invar as u128; |]
ripples = map scale [0..size-1]
carryCatch | not full = []
| otherwise = let outvar = mkIdent ("scaled" ++ show size)
lastv = mkIdent ("scaled" ++ show (size - 1))
in [[stmt| let $$outvar = ($$lastv >> 64) as u64; |]]
stores = map store [0..outSize-1]
--
load i =
let var = mkIdent ("x" ++ show i)
liti = toLit i
in [stmt| let $$var = self.value[$$(liti)]; |]
scale i =
let out = mkIdent ("scaled" ++ show i)
x = mkIdent ("x" ++ show i)
y = mkIdent "y"
--
prevName = mkIdent ("scaled" ++ show (i - 1))
prev | i == 0 = toLit 0
| otherwise = [expr| $$prevName >> 64 |]
in [stmt| let $$out = ($$x as u128) * $$y + $$(prev); |]
store i =
let var = mkIdent ("scaled" ++ show i)
out = mkIdent output
liti = toLit i
in [stmt| $$out.value[$$(liti)] = $$var as u64; |]
-- -----------------------------------------------------------------------------
generateSafeTest :: RandomGen g => Word -> g -> (Map String String, g)
generateSafeTest size g0 = (tcase, g2)
where
(x, g1) = generateNum g0 size
(y, g2) = generateNum g1 64
tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX (x * y))]
generateUnsafeTest :: RandomGen g => Word -> g -> (Map String String, g)
generateUnsafeTest size g0 = (tcase, g2)
where
(x, g1) = generateNum g0 size
(y, g2) = generateNum g1 64
z = (x * y) .&. ((2 ^ size) - 1)
tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX z)]

319
generation/src/Shift.hs Normal file
View File

@@ -0,0 +1,319 @@
{-# LANGUAGE QuasiQuotes #-}
module Shift(shiftOps, signedShiftOps)
where
import Data.Bits(shiftL,shiftR)
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Generators
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
import System.Random(RandomGen)
shiftOps :: RustModule
shiftOps = RustModule {
predicate = \ _ _ -> True,
suggested = const [],
outputName = "shift",
isUnsigned = True,
generator = declareShiftOperators,
testCase = Just generateTest
}
signedShiftOps :: RustModule
signedShiftOps = RustModule {
predicate = \ _ _ -> True,
suggested = const [],
outputName = "sshift",
isUnsigned = False,
generator = declareSignedShiftOperators,
testCase = Just generateSignedTest
}
declareShiftOperators :: Word -> [Word] -> SourceFile Span
declareShiftOperators bitsize _ =
let struct_name = mkIdent ("U" ++ show bitsize)
entries = bitsize `div` 64
unsignedShifts = generateUnsigneds struct_name
shlUsizeImpls = generateBaseUsizes struct_name
shlActualImpl = concatMap actualShlImplLines [1..entries-1]
shrActualImpl = concatMap (actualShrImplLines False entries) (reverse [0..entries-1])
resAssign = map (reassignSelf False) [0..entries-1]
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
#[cfg(test)]
use core::convert::TryFrom;
use core::ops::{Shl,ShlAssign};
use core::ops::{Shr,ShrAssign};
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use super::$$struct_name;
impl ShlAssign<usize> for $$struct_name {
fn shl_assign(&mut self, rhs: usize) {
let digits = rhs / 64;
let bits = rhs % 64;
let shift = (64 - bits) as u32;
let base0 = if digits == 0 { self.value[0] } else { 0 };
let res0 = base0 << bits;
$@{shlActualImpl}
$@{resAssign}
}
}
impl ShrAssign<usize> for $$struct_name {
fn shr_assign(&mut self, rhs: usize) {
let digits = rhs / 64;
let bits = rhs % 64;
let mask = !(0xFFFFFFFFFFFFFFFFu64 << bits);
let shift = (64 - bits) as u32;
let arith_base = 0;
$@{shrActualImpl}
$@{resAssign}
}
}
$@{shlUsizeImpls}
$@{unsignedShifts}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("shift", $$(testFileLit)), 4, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, sbytes) = case.get("s").unwrap();
let (neg2, lbytes) = case.get("l").unwrap();
let (neg3, rbytes) = case.get("r").unwrap();
assert!(!neg0 && !neg1 && !neg2 && !neg3);
let x = $$struct_name::from_bytes(xbytes);
let s = usize::try_from($$struct_name::from_bytes(sbytes)).unwrap();
let l = $$struct_name::from_bytes(lbytes);
let r = $$struct_name::from_bytes(rbytes);
assert_eq!(l, &x << s);
assert_eq!(r, &x >> s);
});
}
|]
declareSignedShiftOperators :: Word -> [Word] -> SourceFile Span
declareSignedShiftOperators bitsize _ =
let struct_name = mkIdent ("I" ++ show bitsize)
entries = bitsize `div` 64
unsignedShifts = generateUnsigneds struct_name
shlUsizeImpls = generateBaseUsizes struct_name
shrActualImpl = concatMap (actualShrImplLines True entries) (reverse [0..entries-1])
resAssign = map (reassignSelf True) [0..entries-1]
testFileLit = Lit [] (Str (testFile False bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
#[cfg(test)]
use core::convert::TryFrom;
use core::ops::{Shl,ShlAssign};
use core::ops::{Shr,ShrAssign};
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use super::$$struct_name;
impl ShlAssign<usize> for $$struct_name {
fn shl_assign(&mut self, rhs: usize) {
self.contents <<= rhs;
}
}
impl ShrAssign<usize> for $$struct_name {
fn shr_assign(&mut self, rhs: usize) {
let digits = rhs / 64;
let bits = rhs % 64;
let mask = !(0xFFFFFFFFFFFFFFFFu64 << bits);
let shift = (64 - bits) as u32;
let arith_base = if self.is_negative() {
0xFFFF_FFFF_FFFF_FFFFu64
} else {
0
};
$@{shrActualImpl}
$@{resAssign}
}
}
$@{shlUsizeImpls}
$@{unsignedShifts}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("sshift", $$(testFileLit)), 4, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, sbytes) = case.get("s").unwrap();
let (neg2, lbytes) = case.get("l").unwrap();
let (neg3, rbytes) = case.get("r").unwrap();
assert!(!neg1);
let mut x = $$struct_name::from_bytes(xbytes);
let mut l = $$struct_name::from_bytes(lbytes);
let mut r = $$struct_name::from_bytes(rbytes);
if *neg0 { x = -x }
if *neg2 { l = -l }
if *neg3 { r = -r }
let s = usize::try_from($$struct_name::from_bytes(sbytes)).unwrap();
assert_eq!(l, &x << s);
assert_eq!(r, &x >> s);
});
}
|]
actualShlImplLines :: Word -> [Stmt Span]
actualShlImplLines i =
let basei = mkIdent ("base" ++ show i)
basei1 = mkIdent ("base" ++ show (i - 1))
carryi = mkIdent ("carry" ++ show i)
resi = mkIdent ("res" ++ show i)
liti = toLit i
in [
[stmt|let $$basei = if $$(liti) >= digits {
self.value[$$(liti)-digits]
} else {
0
}; |]
, [stmt|let $$carryi = if shift == 64 { 0 } else { $$basei1 >> shift }; |]
, [stmt|let $$resi = ($$basei << bits) | $$carryi; |]
]
actualShrImplLines :: Bool -> Word -> Word -> [Stmt Span]
actualShrImplLines prefix_field entries i =
let basei = mkIdent ("base" ++ show i)
carryi = mkIdent ("carry" ++ show i)
carryi1 = mkIdent ("carry" ++ show (i + 1))
targeti = mkIdent ("target" ++ show i)
resi = mkIdent ("res" ++ show i)
liti = toLit i
litentries = toLit entries
sourceI | prefix_field = [expr| self.contents.value[$$targeti] |]
| otherwise = [expr| self.value[$$targeti] |]
in concat [
[[stmt|let $$targeti = $$(liti) + digits; |]]
, [[stmt|let $$basei = if $$targeti >= $$(litentries) { arith_base } else { $$(sourceI) }; |]]
, if i == (entries - 1)
then [[stmt| let ($$carryi1,_) = (arith_base & mask).overflowing_shl(shift); |]]
else []
, if i == 0
then []
else [[stmt|let ($$carryi,_) = ($$basei & mask).overflowing_shl(shift); |]]
, [[stmt|let $$resi = ($$basei >> bits) | $$carryi1; |]]
]
reassignSelf :: Bool -> Word -> Stmt Span
reassignSelf prefix_field i =
let liti = toLit i
resi = mkIdent ("res" ++ show i)
in if prefix_field
then [stmt| self.contents.value[$$(liti)] = $$resi; |]
else [stmt| self.value[$$(liti)] = $$resi; |]
generateBaseUsizes :: Ident -> [Item Span]
generateBaseUsizes sname =
generateBaseUsize sname (mkIdent "Shl") (mkIdent "shl") (mkIdent "shl_assign") ++
generateBaseUsize sname (mkIdent "Shr") (mkIdent "shr") (mkIdent "shr_assign")
generateBaseUsize :: Ident -> Ident -> Ident -> Ident -> [Item Span]
generateBaseUsize sname tname sfn assign = [
[item|
impl $$tname<usize> for $$sname {
type Output = Self;
fn $$sfn(mut self, rhs: usize) -> $$sname {
self.$$assign(rhs);
self
}
}
|]
, [item|
impl<'a> $$tname<usize> for &'a $$sname {
type Output = $$sname;
fn $$sfn(self, rhs: usize) -> $$sname {
let mut res = self.clone();
res.$$assign(rhs);
res
}
}
|]
]
generateUnsigneds :: Ident -> [Item Span]
generateUnsigneds sname =
concatMap (generateUnsignedImpls sname . mkIdent) ["u8","u16","u32","u64","u128"]
generateUnsignedImpls :: Ident -> Ident -> [Item Span]
generateUnsignedImpls sname rhs =
generateBaseImpls sname (mkIdent "Shl") (mkIdent "shl")
(mkIdent "ShlAssign") (mkIdent "shl_assign") rhs ++
generateBaseImpls sname (mkIdent "Shr") (mkIdent "shr")
(mkIdent "ShrAssign") (mkIdent "shr_assign") rhs
generateBaseImpls :: Ident -> Ident -> Ident -> Ident -> Ident -> Ident -> [Item Span]
generateBaseImpls sname upper_shift lower_shift assign_shift lassign_shift right = [
[item|
impl $$assign_shift<$$right> for $$sname {
fn $$lassign_shift(&mut self, rhs: $$right) {
self.$$lassign_shift(rhs as usize);
}
}
|]
, [item|
impl $$upper_shift<$$right> for $$sname {
type Output = $$sname;
fn $$lower_shift(self, rhs: $$right) -> Self::Output {
self.$$lower_shift(rhs as usize)
}
}
|]
, [item|
impl<'a> $$upper_shift<$$right> for &'a $$sname {
type Output = $$sname;
fn $$lower_shift(self, rhs: $$right) -> $$sname {
self.$$lower_shift(rhs as usize)
}
}
|]
]
generateTest :: RandomGen g => Word -> g -> (Map String String, g)
generateTest size g0 = (tcase, g2)
where
(x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size
s = y `mod` fromIntegral size
l = modulate (x `shiftL` fromIntegral s) size
r = modulate (x `shiftR` fromIntegral s) size
tcase = Map.fromList [("x", showX x), ("s", showX s),
("l", showX l), ("r", showX r)]
generateSignedTest :: RandomGen g => Word -> g -> (Map String String, g)
generateSignedTest size g0 = (tcase, g2)
where
(x, g1) = generateSignedNum g0 size
(y, g2) = generateNum g1 size
s = y `mod` fromIntegral size
l = modulate (x `shiftL` fromIntegral s) size
r = modulate (x `shiftR` fromIntegral s) size
tcase = Map.fromList [("x", showX x), ("s", showX s),
("l", showX l), ("r", showX r)]

166
generation/src/Signed.hs Normal file
View File

@@ -0,0 +1,166 @@
{-# LANGUAGE QuasiQuotes #-}
module Signed(signedBaseOps)
where
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
signedBaseOps :: RustModule
signedBaseOps = RustModule {
predicate = const (const True),
suggested = const [],
outputName = "base",
isUnsigned = False,
generator = declareSigned,
testCase = Nothing
}
declareSigned :: Word -> [Word] -> SourceFile Span
declareSigned bitsize _ =
let sname = mkIdent ("I" ++ show bitsize)
uname = mkIdent ("U" ++ show bitsize)
in [sourceFile|
use core::fmt;
use core::ops::{Neg, Not};
use crate::CryptoNum;
use crate::unsigned::$$uname;
use quickcheck::{Arbitrary,Gen};
#[cfg(test)]
use quickcheck::quickcheck;
#[derive(Clone)]
pub struct $$sname {
pub(crate) contents: $$uname,
}
impl $$sname {
pub fn is_negative(&self) -> bool {
self.contents.value[self.contents.value.len()-1] & 0x8000_0000_0000_0000 != 0
}
}
impl Neg for $$sname {
type Output = $$sname;
fn neg(mut self) -> $$sname {
for x in self.contents.value.iter_mut() {
*x = !*x;
}
let one = $$uname::from(1u64);
self.contents += one;
self
}
}
impl<'a> Neg for &'a $$sname {
type Output = $$sname;
fn neg(self) -> $$sname {
let res = self.clone();
res.neg()
}
}
impl Not for $$sname {
type Output = $$sname;
fn not(mut self) -> $$sname {
for x in self.contents.value.iter_mut() {
*x = !*x;
}
self
}
}
impl<'a> Not for &'a $$sname {
type Output = $$sname;
fn not(self) -> $$sname {
self.clone().not()
}
}
impl fmt::Debug for $$sname {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}", self.contents)
}
}
impl fmt::UpperHex for $$sname {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut temp = self.clone();
if temp.contents.value[temp.contents.value.len()-1] >> 63 == 1 {
write!(f, "-")?;
temp = !temp;
}
write!(f, "{:X}", temp.contents)
}
}
impl fmt::LowerHex for $$sname {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut temp = self.clone();
if temp.contents.value[temp.contents.value.len()-1] >> 63 == 1 {
write!(f, "-")?;
temp = !temp;
}
write!(f, "{:x}", temp.contents)
}
}
impl Arbitrary for $$sname {
fn arbitrary<G: Gen>(g: &mut G) -> $$sname {
$$sname{
contents: $$uname::arbitrary(g),
}
}
}
impl CryptoNum for $$sname {
fn zero() -> $$sname {
$$sname{ contents: $$uname::zero() }
}
fn is_zero(&self) -> bool {
self.contents.is_zero()
}
fn is_even(&self) -> bool {
self.contents.is_even()
}
fn is_odd(&self) -> bool {
self.contents.is_odd()
}
fn bit_length() -> usize {
$$uname::bit_length()
}
fn mask(&mut self, len: usize) {
self.contents.mask(len);
}
fn testbit(&self, bit: usize) -> bool {
self.contents.testbit(bit)
}
fn from_bytes(bytes: &[u8]) -> $$sname {
$$sname{ contents: $$uname::from_bytes(bytes) }
}
fn to_bytes(&self, bytes: &mut [u8]) {
self.contents.to_bytes(bytes);
}
}
#[cfg(test)]
quickcheck! {
fn double_not(x: $$sname) -> bool {
x == !!&x
}
fn double_neg(x: $$sname) -> bool {
x == --&x
}
}
|]

356
generation/src/Subtract.hs Normal file
View File

@@ -0,0 +1,356 @@
{-# LANGUAGE QuasiQuotes #-}
module Subtract(
safeSubtractOps
, unsafeSubtractOps
, safeSignedSubtractOps
, unsafeSignedSubtractOps
)
where
import Data.Bits((.&.))
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Generators
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
import System.Random(RandomGen)
safeSubtractOps :: RustModule
safeSubtractOps = RustModule {
predicate = \ me others -> (me + 64) `elem` others,
suggested = \ me -> [me + 64],
outputName = "safe_sub",
isUnsigned = True,
generator = declareSafeSubtractOperators,
testCase = Just generateSafeTest
}
safeSignedSubtractOps :: RustModule
safeSignedSubtractOps = RustModule {
predicate = \ me others -> (me + 64) `elem` others,
suggested = \ me -> [me + 64],
outputName = "safe_ssub",
isUnsigned = False,
generator = declareSafeSignedSubtractOperators,
testCase = Just generateSafeSignedTest
}
unsafeSubtractOps :: RustModule
unsafeSubtractOps = RustModule {
predicate = \ _ _ -> True,
suggested = const [],
outputName = "unsafe_sub",
isUnsigned = True,
generator = declareUnsafeSubtractOperators,
testCase = Just generateUnsafeTest
}
unsafeSignedSubtractOps :: RustModule
unsafeSignedSubtractOps = RustModule {
predicate = \ _ _ -> True,
suggested = const [],
outputName = "unsafe_ssub",
isUnsigned = False,
generator = declareUnsafeSignedSubtractOperators,
testCase = Just generateUnsafeSignedTest
}
declareSafeSubtractOperators :: Word -> [Word] -> SourceFile Span
declareSafeSubtractOperators bitsize _ =
let sname = mkIdent ("U" ++ show bitsize)
dname = mkIdent ("U" ++ show (bitsize + 64))
fullRippleSubtract = makeRippleSubtracter True (bitsize `div` 64) "res"
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::Sub;
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use crate::unsigned::{$$sname,$$dname};
impl Sub for $$sname {
type Output = $$dname;
fn sub(self, rhs: $$sname) -> $$dname {
&self - &rhs
}
}
impl<'a> Sub<&'a $$sname> for $$sname {
type Output = $$dname;
fn sub(self, rhs: &$$sname) -> $$dname {
&self - rhs
}
}
impl<'a> Sub<$$sname> for &'a $$sname {
type Output = $$dname;
fn sub(self, rhs: $$sname) -> $$dname {
self - &rhs
}
}
impl<'a,'b> Sub<&'a $$sname> for &'b $$sname {
type Output = $$dname;
fn sub(self, rhs: &$$sname) -> $$dname {
let mut res = $$dname::zero();
$@{fullRippleSubtract}
res
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("safe_sub", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
assert!(!neg0 && !neg1 && !neg2);
let x = $$sname::from_bytes(&xbytes);
let y = $$sname::from_bytes(&ybytes);
let z = $$dname::from_bytes(&zbytes);
assert_eq!(z, x - y);
});
}
|]
declareSafeSignedSubtractOperators :: Word -> [Word] -> SourceFile Span
declareSafeSignedSubtractOperators bitsize _ =
let sname = mkIdent ("I" ++ show bitsize)
dname = mkIdent ("I" ++ show (bitsize + 64))
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::Sub;
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use crate::signed::{$$sname,$$dname};
impl Sub for $$sname {
type Output = $$dname;
fn sub(self, rhs: $$sname) -> $$dname {
&self - &rhs
}
}
impl<'a> Sub<&'a $$sname> for $$sname {
type Output = $$dname;
fn sub(self, rhs: &$$sname) -> $$dname {
&self - rhs
}
}
impl<'a> Sub<$$sname> for &'a $$sname {
type Output = $$dname;
fn sub(self, rhs: $$sname) -> $$dname {
self - &rhs
}
}
impl<'a,'b> Sub<&'a $$sname> for &'b $$sname {
type Output = $$dname;
fn sub(self, rhs: &$$sname) -> $$dname {
$$dname{ contents: &self.contents - &rhs.contents }
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("safe_sub", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
assert!(!neg0 && !neg1 && !neg2);
let x = $$sname::from_bytes(&xbytes);
let y = $$sname::from_bytes(&ybytes);
let z = $$dname::from_bytes(&zbytes);
assert_eq!(z, x - y);
});
}
|]
declareUnsafeSubtractOperators :: Word -> [Word] -> SourceFile Span
declareUnsafeSubtractOperators bitsize _ =
let sname = mkIdent ("U" ++ show bitsize)
fullRippleSubtract = makeRippleSubtracter False (bitsize `div` 64) "self"
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::SubAssign;
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use crate::unsigned::$$sname;
impl SubAssign for $$sname {
fn sub_assign(&mut self, rhs: Self) {
self.sub_assign(&rhs);
}
}
impl<'a> SubAssign<&'a $$sname> for $$sname {
fn sub_assign(&mut self, rhs: &Self) {
$@{fullRippleSubtract}
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("unsafe_sub", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
assert!(!neg0 && !neg1 && !neg2);
let mut x = $$sname::from_bytes(&xbytes);
let y = $$sname::from_bytes(&ybytes);
let z = $$sname::from_bytes(&zbytes);
x -= &y;
assert_eq!(z, x);
});
}
|]
declareUnsafeSignedSubtractOperators :: Word -> [Word] -> SourceFile Span
declareUnsafeSignedSubtractOperators bitsize _ =
let sname = mkIdent ("I" ++ show bitsize)
testFileLit = Lit [] (Str (testFile False bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::SubAssign;
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use crate::signed::$$sname;
impl SubAssign for $$sname {
fn sub_assign(&mut self, rhs: Self) {
self.sub_assign(&rhs);
}
}
impl<'a> SubAssign<&'a $$sname> for $$sname {
fn sub_assign(&mut self, rhs: &Self) {
self.contents -= &rhs.contents;
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("unsafe_ssub", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
let mut x = $$sname::from_bytes(&xbytes);
let mut y = $$sname::from_bytes(&ybytes);
let mut z = $$sname::from_bytes(&zbytes);
if *neg0 { x = -x; }
if *neg1 { y = -y; }
if *neg2 { z = -z; }
x -= &y;
assert_eq!(z, x);
});
}
|]
makeRippleSubtracter :: Bool -> Word -> String -> [Stmt Span]
makeRippleSubtracter useLastCarry inElems resName =
concatMap (generateRipples useLastCarry (inElems - 1)) [0..inElems-1] ++
concatMap (generateSetters useLastCarry inElems resName) [0..inElems]
generateRipples :: Bool -> Word -> Word -> [Stmt Span]
generateRipples useLastCarry lastI i =
let sumi = mkIdent ("sum" ++ show i)
inCarry = mkIdent ("carry" ++ show (i - 1))
outCarry = mkIdent ("carry" ++ show i)
res = mkIdent ("res" ++ show i)
liti = toLit i
left = mkIdent ("left" ++ show i)
right = mkIdent ("right" ++ show i)
in [
[stmt|let $$left = self.value[$$(liti)] as u128; |]
, [stmt|let $$right = !rhs.value[$$(liti)] as u128; |]
, if i == 0
then [stmt| let $$sumi = $$left + $$right + 1; |]
else [stmt| let $$sumi = $$left + $$right + $$inCarry; |]
, [stmt|let $$res = $$sumi as u64; |]
] ++
if not useLastCarry && (i == lastI)
then []
else [[stmt|let $$outCarry = $$sumi >> 64; |]]
generateSetters :: Bool -> Word -> String -> Word -> [Stmt Span]
generateSetters useLastCarry maxI resName i
| not useLastCarry && (maxI == i) = []
| maxI == i =
let res = mkIdent ("carry" ++ show (i - 1))
liti = toLit i
in [[stmt| $$target.value[$$(liti)] = (0xFFFFFFFFFFFFFFFFu128 + $$res) as u64; |]]
| otherwise =
let res = mkIdent ("res" ++ show i)
liti = toLit i
in [[stmt| $$target.value[$$(liti)] = $$res; |]]
where
target = mkIdent resName
generateSafeTest :: RandomGen g => Word -> g -> (Map String String, g)
generateSafeTest size g0 = (tcase, g2)
where
(x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size
r | x < y = (2 ^ (size + 64)) + (x - y)
| otherwise = x - y
tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX r)]
generateSafeSignedTest :: RandomGen g => Word -> g -> (Map String String, g)
generateSafeSignedTest size g0 = (tcase, g2)
where
(x, g1) = generateSignedNum g0 size
(y, g2) = generateSignedNum g1 size
r | x < y = (2 ^ (size + 64)) + (x - y)
| otherwise = x - y
tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX r)]
generateUnsafeTest :: RandomGen g => Word -> g -> (Map String String, g)
generateUnsafeTest size g0 = (tcase, g2)
where
(x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size
z = (x - y) .&. ((2 ^ size) - 1)
tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX z)]
generateUnsafeSignedTest :: RandomGen g => Word -> g -> (Map String String, g)
generateUnsafeSignedTest size g0 = (tcase, g2)
where
(x, g1) = generateSignedNum g0 size
(y, g2) = generateSignedNum g1 size
z = (x - y) .&. ((2 ^ size) - 1)
tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX z)]

View File

@@ -1,4 +0,0 @@
module Testing(
)
where

View File

@@ -0,0 +1,11 @@
{
"folders": [
{
"path": "/Users/awick/projects/cryptonum/generation"
},
{
"path": "/Users/awick/projects/cryptonum/src"
}
],
"settings": {}
}

2
src/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
signed/**/*.rs
unsigned/**/*.rs

View File

@@ -4,6 +4,8 @@ pub mod unsigned;
#[cfg(test)] #[cfg(test)]
mod testing; mod testing;
use core::num::TryFromIntError;
/// A trait definition for large numbers. /// A trait definition for large numbers.
pub trait CryptoNum { pub trait CryptoNum {
/// Generate a new value of the given type. /// Generate a new value of the given type.
@@ -36,3 +38,57 @@ pub trait CryptoNum {
fn to_bytes(&self, bytes: &mut [u8]); fn to_bytes(&self, bytes: &mut [u8]);
} }
/// Provides the ability to do a simultaneous division and modulus operation;
/// this is used as the implementation of division and multiplication, and
/// so you can save time doing both at once if you need them.
///
/// WARNING: There has been some effort made to make this have a constant-time
/// implementation, but it does use a single conditional inside an otherwise-
/// constant time loop. There may be unforeseen timing effects of this, or
/// the compiler may do something funny to "optimize" some math.
pub trait DivMod: Sized {
/// Divide and modulus as a single operation. The first element of the tuple
/// is the quotient, the second is the modulus.
fn divmod(&self, rhs: &Self) -> (Self, Self);
}
/// Provides support for a variety of modular mathematical operations, as beloved
/// by cryptographers. Note that modular inversion and GCD calculations are shoved
/// off into another trait, because they operate on slightly different number
/// types.
pub trait ModularOperations<Modulus=Self> {
// reduce the current value by the provided modulus
fn reduce(&self, m: &Modulus) -> Self;
// multiply this value by the provided one, modulo the modulus
fn modmul(&self, rhs: &Self, m: &Modulus) -> Self;
// square the provided number, modulo the modulus
fn modsq(&self, m: &Modulus) -> Self;
// modular exponentiation!
fn modexp(&self, e: &Self, m: &Modulus) -> Self;
}
/// Provide support for modular inversion and GCD operations, which are useful
/// here and there. We provide default implementations for `modinv` and
/// `gcd_is_one`, based on the implementation of `egcd`. The built-in versions
/// explicitly define the latter, though, to improve performance.
pub trait ModularInversion: Sized {
type Signed;
fn modinv(&self, phi: &Self) -> Option<Self>;
fn egcd(&self, rhs: &Self) -> (Self::Signed, Self::Signed, Self::Signed);
fn gcd_is_one(&self, b: &Self) -> bool;
}
/// An error in conversion of large numbers (either to primitives or to other numbers
#[derive(Debug)]
pub enum ConversionError {
NegativeToUnsigned,
Overflow
}
impl From<TryFromIntError> for ConversionError {
fn from(_: TryFromIntError) -> ConversionError {
ConversionError::Overflow
}
}

View File

@@ -22,7 +22,7 @@ fn next_value_set(line: &str) -> (String, bool, Vec<u8>)
let key = items.next().unwrap(); let key = items.next().unwrap();
let valbits = items.next().unwrap(); let valbits = items.next().unwrap();
let neg = valbits.contains('-'); let neg = valbits.contains('-');
let valbitsnoneg = valbits.trim_start_matches("-"); let valbitsnoneg = valbits.trim_start_matches('-');
let mut nibble_iter = valbitsnoneg.chars().rev(); let mut nibble_iter = valbitsnoneg.chars().rev();
let mut val = Vec::new(); let mut val = Vec::new();
@@ -63,7 +63,6 @@ fn next_test_case(contents: &mut Lines, lines: usize) ->
pub fn run_test<F>(fname: PathBuf, i: usize, f: F) pub fn run_test<F>(fname: PathBuf, i: usize, f: F)
where F: Fn(HashMap<String,(bool,Vec<u8>)>) where F: Fn(HashMap<String,(bool,Vec<u8>)>)
{ {
println!("fname: {:?}", fname);
let mut file = File::open(fname).unwrap(); let mut file = File::open(fname).unwrap();
let mut contents = String::new(); let mut contents = String::new();
file.read_to_string(&mut contents).unwrap(); file.read_to_string(&mut contents).unwrap();

3003
testdata/add/00192.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/00256.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/00320.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/00384.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/00448.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/00512.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/00576.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/00640.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/00768.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/00832.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/00896.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/00960.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/01024.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/01088.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/01152.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/01216.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/01280.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/01600.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/01664.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/01728.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/01792.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/02048.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/02112.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/02176.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/02432.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/02496.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/02560.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/03072.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/03136.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/03200.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/04096.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/04160.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/04224.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/06144.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/06208.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/06272.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/07744.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/08256.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/08320.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/15424.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/15488.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/16448.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/16512.test vendored

File diff suppressed because it is too large Load Diff

3003
testdata/add/30784.test vendored

File diff suppressed because one or more lines are too long

3003
testdata/add/30848.test vendored

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More