420 lines
12 KiB
Haskell
420 lines
12 KiB
Haskell
{-# LANGUAGE QuasiQuotes #-}
|
|
module Add(
|
|
safeAddOps
|
|
, unsafeAddOps
|
|
, safeSignedAddOps
|
|
, unsafeSignedAddOps
|
|
)
|
|
where
|
|
|
|
import Data.Bits((.&.))
|
|
import Data.Map.Strict(Map)
|
|
import qualified Data.Map.Strict as Map
|
|
import File
|
|
import Gen(toLit)
|
|
import Generators
|
|
import Language.Rust.Data.Ident
|
|
import Language.Rust.Data.Position
|
|
import Language.Rust.Quote
|
|
import Language.Rust.Syntax
|
|
import System.Random(RandomGen)
|
|
|
|
-- | How many known-answer test cases each of the test generators in this
-- module emits.
numTestCases :: Int
numTestCases = 3000
|
|
|
|
-- | Description of the generated @safe_add@ file: addition on unsigned
-- numbers, produced by 'declareSafeAddOperators' and tested with the cases
-- from 'generateSafeTests'.
--
-- The predicate holds only when the size 64 bits above the candidate size is
-- a member of the second argument (presumably the list of all sizes being
-- generated -- confirm against the definition of 'File').
safeAddOps :: File
safeAddOps = File
  { outputName = "safe_add"
  , isUnsigned = True
  , predicate  = \ size known -> elem (size + 64) known
  , generator  = declareSafeAddOperators
  , testCase   = Just generateSafeTests
  }
|
|
|
|
-- | Description of the generated @unsafe_add@ file: in-place addition on
-- unsigned numbers, produced by 'declareUnsafeAddOperators' and tested with
-- the cases from 'generateUnsafeTests'.
--
-- The predicate ignores both of its arguments and always holds.
unsafeAddOps :: File
unsafeAddOps = File
  { outputName = "unsafe_add"
  , isUnsigned = True
  , predicate  = const (const True)
  , generator  = declareUnsafeAddOperators
  , testCase   = Just generateUnsafeTests
  }
|
|
|
|
-- | Description of the generated @safe_sadd@ file: addition on signed
-- numbers, produced by 'declareSafeSignedAddOperators' and tested with the
-- cases from 'generateSafeSignedTests'.
--
-- The predicate holds only when the size 64 bits above the candidate size is
-- a member of the second argument (presumably the list of all sizes being
-- generated -- confirm against the definition of 'File').
safeSignedAddOps :: File
safeSignedAddOps = File
  { outputName = "safe_sadd"
  , isUnsigned = False
  , predicate  = \ size known -> elem (size + 64) known
  , generator  = declareSafeSignedAddOperators
  , testCase   = Just generateSafeSignedTests
  }
|
|
|
|
-- | Description of the generated @unsafe_sadd@ file: in-place addition on
-- signed numbers, produced by 'declareUnsafeSignedAddOperators' and tested
-- with the cases from 'generateUnsafeSignedTests'.
--
-- The predicate ignores both of its arguments and always holds.
unsafeSignedAddOps :: File
unsafeSignedAddOps = File
  { outputName = "unsafe_sadd"
  , isUnsigned = False
  , predicate  = const (const True)
  , generator  = declareUnsafeSignedAddOperators
  , testCase   = Just generateUnsafeSignedTests
  }
|
|
|
|
-- | Build the Rust source file that implements widening addition for the
-- unsigned type @U\<bitsize\>@.
--
-- The generated file implements @core::ops::Add@ for all four combinations of
-- owned and borrowed operands, with @Output@ set to the unsigned type that is
-- 64 bits wider.  Three of the impls forward to the reference/reference impl,
-- whose body is the statement list from 'makeRippleAdder' (final carry kept,
-- results stored into the local @res@).
--
-- The file also carries a quickcheck property (@addition_symmetric@) and a
-- known-answer test (@KATs@) that reads @x@, @y@ and @z@ from the test file
-- and asserts @z == x + y@.
--
-- The second argument is ignored.
declareSafeAddOperators :: Word -> [Word] -> SourceFile Span
declareSafeAddOperators bitsize _ =
  [sourceFile|
    use core::ops::Add;
    use crate::CryptoNum;
    #[cfg(test)]
    use crate::testing::{build_test_path,run_test};
    #[cfg(test)]
    use quickcheck::quickcheck;
    use crate::unsigned::{$$sname,$$dname};

    impl Add for $$sname {
      type Output = $$dname;

      fn add(self, rhs: $$sname) -> $$dname {
        &self + &rhs
      }
    }

    impl<'a> Add<&'a $$sname> for $$sname {
      type Output = $$dname;

      fn add(self, rhs: &$$sname) -> $$dname {
        &self + rhs
      }
    }

    impl<'a> Add<$$sname> for &'a $$sname {
      type Output = $$dname;

      fn add(self, rhs: $$sname) -> $$dname {
        self + &rhs
      }
    }

    impl<'a,'b> Add<&'a $$sname> for &'b $$sname {
      type Output = $$dname;

      fn add(self, rhs: &$$sname) -> $$dname {
        let mut res = $$dname::zero();

        $@{adderBody}

        res
      }
    }

    #[cfg(test)]
    quickcheck! {
      fn addition_symmetric(a: $$sname, b: $$sname) -> bool {
        (&a + &b) == (&b + &a)
      }
    }

    #[cfg(test)]
    #[allow(non_snake_case)]
    #[test]
    fn KATs() {
      run_test(build_test_path("safe_add", $$(katFile)), 3, |case| {
        let (neg0, xbytes) = case.get("x").unwrap();
        let (neg1, ybytes) = case.get("y").unwrap();
        let (neg2, zbytes) = case.get("z").unwrap();

        assert!(!neg0 && !neg1 && !neg2);
        let x = $$sname::from_bytes(&xbytes);
        let y = $$sname::from_bytes(&ybytes);
        let z = $$dname::from_bytes(&zbytes);

        assert_eq!(z, x + y);
      });
    }
  |]
 where
  -- Operand type and the 64-bit-wider result type.
  sname = mkIdent ('U' : show bitsize)
  dname = mkIdent ('U' : show (bitsize + 64))
  -- One ripple step per 64-bit limb of the operand type.
  adderBody = makeRippleAdder True (div bitsize 64) "res"
  -- String literal naming the known-answer test file.
  katFile = Lit [] katName mempty
  katName = Str (testFile True bitsize) Cooked Unsuffixed mempty
|
|
|
|
-- | Build the Rust source file that implements in-place addition for the
-- unsigned type @U\<bitsize\>@.
--
-- The generated file implements @core::ops::AddAssign@ for an owned and for a
-- borrowed right-hand side.  The owned impl forwards to the borrowed one,
-- whose body is the statement list from 'makeRippleAdder' with the final
-- carry disabled and @self@ as the destination, so the result keeps the
-- operand's width.
--
-- The file also carries a quickcheck property (@addition_symmetric@) and a
-- known-answer test (@KATs@) that reads @x@, @y@ and @z@ from the test file
-- and asserts that @x += &y@ leaves @x@ equal to @z@.
--
-- The second argument is ignored.
declareUnsafeAddOperators :: Word -> [Word] -> SourceFile Span
declareUnsafeAddOperators bitsize _ =
  [sourceFile|
    use core::ops::AddAssign;
    #[cfg(test)]
    use crate::CryptoNum;
    #[cfg(test)]
    use crate::testing::{build_test_path,run_test};
    #[cfg(test)]
    use quickcheck::quickcheck;
    use crate::unsigned::$$sname;

    impl AddAssign for $$sname {
      fn add_assign(&mut self, rhs: Self) {
        self.add_assign(&rhs);
      }
    }

    impl<'a> AddAssign<&'a $$sname> for $$sname {
      fn add_assign(&mut self, rhs: &Self) {
        $@{adderBody}
      }
    }

    #[cfg(test)]
    quickcheck! {
      fn addition_symmetric(a: $$sname, b: $$sname) -> bool {
        let mut side1 = a.clone();
        let mut side2 = b.clone();

        side1 += b;
        side2 += a;

        side1 == side2
      }
    }

    #[cfg(test)]
    #[allow(non_snake_case)]
    #[test]
    fn KATs() {
      run_test(build_test_path("unsafe_add", $$(katFile)), 3, |case| {
        let (neg0, xbytes) = case.get("x").unwrap();
        let (neg1, ybytes) = case.get("y").unwrap();
        let (neg2, zbytes) = case.get("z").unwrap();

        assert!(!neg0 && !neg1 && !neg2);
        let mut x = $$sname::from_bytes(&xbytes);
        let y = $$sname::from_bytes(&ybytes);
        let z = $$sname::from_bytes(&zbytes);

        x += &y;
        assert_eq!(z, x);
      });
    }
  |]
 where
  -- The single type involved: operand and result share it.
  sname = mkIdent ('U' : show bitsize)
  -- One ripple step per 64-bit limb, written straight back into `self`.
  adderBody = makeRippleAdder False (div bitsize 64) "self"
  -- String literal naming the known-answer test file.
  katFile = Lit [] katName mempty
  katName = Str (testFile True bitsize) Cooked Unsuffixed mempty
|
|
|
|
-- | Build the Rust source file that declares widening addition for the signed
-- type @I\<bitsize\>@.
--
-- The generated file implements @core::ops::Add@ for all four combinations of
-- owned and borrowed operands, with @Output@ set to the signed type that is
-- 64 bits wider.  Three of the impls forward to the reference/reference impl.
--
-- NOTE: the reference/reference impl is still a stub: its generated body is
-- @panic!("add")@.  No ripple adder is spliced in, so this function no longer
-- builds one (the former local @fullRippleAdd@ binding was never used).
--
-- The file also carries a quickcheck property (@addition_symmetric@) and a
-- known-answer test (@KATs@) that reads @x@, @y@ and @z@ from the test file,
-- negates each value whose sign flag is set, and asserts @z == x + y@.
--
-- The second argument is ignored.
declareSafeSignedAddOperators :: Word -> [Word] -> SourceFile Span
declareSafeSignedAddOperators bitsize _ =
  [sourceFile|
    use core::ops::Add;
    use crate::CryptoNum;
    #[cfg(test)]
    use crate::testing::{build_test_path,run_test};
    #[cfg(test)]
    use quickcheck::quickcheck;
    use crate::signed::{$$sname,$$dname};

    impl Add for $$sname {
      type Output = $$dname;

      fn add(self, rhs: $$sname) -> $$dname {
        &self + &rhs
      }
    }

    impl<'a> Add<&'a $$sname> for $$sname {
      type Output = $$dname;

      fn add(self, rhs: &$$sname) -> $$dname {
        &self + rhs
      }
    }

    impl<'a> Add<$$sname> for &'a $$sname {
      type Output = $$dname;

      fn add(self, rhs: $$sname) -> $$dname {
        self + &rhs
      }
    }

    impl<'a,'b> Add<&'a $$sname> for &'b $$sname {
      type Output = $$dname;

      fn add(self, rhs: &$$sname) -> $$dname {
        panic!("add")
      }
    }

    #[cfg(test)]
    quickcheck! {
      fn addition_symmetric(a: $$sname, b: $$sname) -> bool {
        (&a + &b) == (&b + &a)
      }
    }

    #[cfg(test)]
    #[allow(non_snake_case)]
    #[test]
    fn KATs() {
      run_test(build_test_path("safe_sadd", $$(testFileLit)), 3, |case| {
        let (neg0, xbytes) = case.get("x").unwrap();
        let (neg1, ybytes) = case.get("y").unwrap();
        let (neg2, zbytes) = case.get("z").unwrap();

        let mut x = $$sname::from_bytes(&xbytes);
        let mut y = $$sname::from_bytes(&ybytes);
        let mut z = $$dname::from_bytes(&zbytes);
        if neg0 { x = x.negate() }
        if neg1 { y = y.negate() }
        if neg2 { z = z.negate() }

        assert_eq!(z, x + y);
      });
    }
  |]
 where
  -- Operand type and the 64-bit-wider result type.
  sname = mkIdent ("I" ++ show bitsize)
  dname = mkIdent ("I" ++ show (bitsize + 64))
  -- String literal naming the known-answer test file.
  -- NOTE(review): 'testFile' is called with True here, exactly as in the
  -- unsigned generators, although 'safeSignedAddOps' sets isUnsigned = False.
  -- Confirm that this flag does not select between unsigned and signed test
  -- vectors.
  testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
|
|
|
|
-- | Build the Rust source file that declares in-place addition for the signed
-- type @I\<bitsize\>@.
--
-- The generated file implements @core::ops::AddAssign@ for an owned and for a
-- borrowed right-hand side; the owned impl forwards to the borrowed one.
--
-- NOTE: the borrowed impl is still a stub: its generated body is
-- @panic!("add_assign")@.  No ripple adder is spliced in, so this function no
-- longer builds one (the former local @fullRippleAdd@ binding was never
-- used).
--
-- The file also carries a quickcheck property (@addition_symmetric@) and a
-- known-answer test (@KATs@) that reads @x@, @y@ and @z@ from the test file,
-- negates each value whose sign flag is set, and asserts that @x += &y@
-- leaves @x@ equal to @z@.
--
-- The second argument is ignored.
declareUnsafeSignedAddOperators :: Word -> [Word] -> SourceFile Span
declareUnsafeSignedAddOperators bitsize _ =
  [sourceFile|
    use core::ops::AddAssign;
    #[cfg(test)]
    use crate::CryptoNum;
    #[cfg(test)]
    use crate::testing::{build_test_path,run_test};
    #[cfg(test)]
    use quickcheck::quickcheck;
    use crate::signed::$$sname;

    impl AddAssign for $$sname {
      fn add_assign(&mut self, rhs: Self) {
        self.add_assign(&rhs);
      }
    }

    impl<'a> AddAssign<&'a $$sname> for $$sname {
      fn add_assign(&mut self, rhs: &Self) {
        panic!("add_assign")
      }
    }

    #[cfg(test)]
    quickcheck! {
      fn addition_symmetric(a: $$sname, b: $$sname) -> bool {
        let mut side1 = a.clone();
        let mut side2 = b.clone();

        side1 += b;
        side2 += a;

        side1 == side2
      }
    }

    #[cfg(test)]
    #[allow(non_snake_case)]
    #[test]
    fn KATs() {
      run_test(build_test_path("unsafe_sadd", $$(testFileLit)), 3, |case| {
        let (neg0, xbytes) = case.get("x").unwrap();
        let (neg1, ybytes) = case.get("y").unwrap();
        let (neg2, zbytes) = case.get("z").unwrap();

        let mut x = $$sname::from_bytes(&xbytes);
        let mut y = $$sname::from_bytes(&ybytes);
        let mut z = $$sname::from_bytes(&zbytes);
        if neg0 { x = x.negate() }
        if neg1 { y = y.negate() }
        if neg2 { z = z.negate() }

        x += &y;
        assert_eq!(z, x);
      });
    }
  |]
 where
  -- The single type involved: operand and result share it.
  sname = mkIdent ("I" ++ show bitsize)
  -- String literal naming the known-answer test file.
  -- NOTE(review): 'testFile' is called with True here, exactly as in the
  -- unsigned generators, although 'unsafeSignedAddOps' sets
  -- isUnsigned = False.  Confirm that this flag does not select between
  -- unsigned and signed test vectors.
  testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
|
|
|
|
-- | Produce the Rust statements of a ripple-carry adder over @inElems@ limbs.
--
-- The result is the per-limb statements from 'generateRipples' for limbs
-- @0 .. inElems - 1@, followed by the store statements from 'generateSetters'
-- for indices @0 .. inElems@.
--
-- * @useLastCarry@ is passed through to both helpers; it decides whether the
--   carry out of the top limb is computed and stored at index @inElems@.
-- * @inElems@ is the number of limbs in each operand.
-- * @resName@ is the name of the Rust variable that receives the result.
--
-- A limb count of zero yields no statements.  Without that guard,
-- @inElems - 1@ would wrap around to the largest 'Word', so the limb range
-- would cover every 'Word' value instead of being empty.
makeRippleAdder :: Bool -> Word -> String -> [Stmt Span]
makeRippleAdder useLastCarry inElems resName
  | inElems == 0 = []
  | otherwise    = ripples ++ setters
 where
  -- Only evaluated when inElems >= 1, so the subtraction cannot wrap.
  lastI   = inElems - 1
  ripples = concatMap (generateRipples useLastCarry lastI) [0 .. lastI]
  setters = concatMap (generateSetters useLastCarry inElems resName) [0 .. inElems]
|
|
|
|
-- | Produce the Rust statements that add limb @i@ of @self@ and @rhs@.
--
-- The statements bind, in order: @left\<i\>@ and @right\<i\>@ (the two limbs
-- cast to @u128@), @sum\<i\>@ (their sum, plus @carry\<i-1\>@ for every limb
-- but the first), @res\<i\>@ (the sum cast to @u64@) and finally @carry\<i\>@
-- (the sum shifted right by 64).
--
-- The @carry\<i\>@ binding is left out only when @useLastCarry@ is false and
-- @i@ equals @lastI@.
generateRipples :: Bool -> Word -> Word -> [Stmt Span]
generateRipples useLastCarry lastI i =
  [loadLeft, loadRight, addLimbs, keepLow] ++ carryOut
 where
  numbered prefix n = mkIdent (prefix ++ show n)

  lhsLimb  = numbered "left" i
  rhsLimb  = numbered "right" i
  total    = numbered "sum" i
  lowWord  = numbered "res" i
  -- Only demanded when i /= 0, so the Word subtraction never wraps in use.
  carryIn  = numbered "carry" (i - 1)
  carryNow = numbered "carry" i
  limbIdx  = toLit i

  loadLeft  = [stmt| let $$lhsLimb = self.value[$$(limbIdx)] as u128; |]
  loadRight = [stmt| let $$rhsLimb = rhs.value[$$(limbIdx)] as u128; |]

  addLimbs
    | i == 0    = [stmt| let $$total = $$lhsLimb + $$rhsLimb; |]
    | otherwise = [stmt| let $$total = $$lhsLimb + $$rhsLimb + $$carryIn; |]

  keepLow = [stmt| let $$lowWord = $$total as u64; |]

  carryOut
    | useLastCarry || i /= lastI = [ [stmt| let $$carryNow = $$total >> 64; |] ]
    | otherwise                  = []
|
|
|
|
-- | Produce the Rust statement, if any, that stores result element @i@ into
-- the variable named @resName@.
--
-- * For @i@ other than @maxI@, element @i@ is assigned @res\<i\>@.
-- * For @i@ equal to @maxI@ with @useLastCarry@ set, element @i@ is assigned
--   @carry\<i-1\>@ cast to @u64@.
-- * For @i@ equal to @maxI@ without @useLastCarry@, nothing is emitted.
generateSetters :: Bool -> Word -> String -> Word -> [Stmt Span]
generateSetters useLastCarry maxI resName i
  | i /= maxI    = [ [stmt| $$dest.value[$$(slot)] = $$limb; |] ]
  | useLastCarry = [ [stmt| $$dest.value[$$(slot)] = $$topCarry as u64; |] ]
  | otherwise    = []
 where
  dest     = mkIdent resName
  slot     = toLit i
  limb     = mkIdent ("res" ++ show i)
  -- Only demanded in the i == maxI branch.
  topCarry = mkIdent ("carry" ++ show (i - 1))
|
|
|
|
-- | Produce 'numTestCases' known-answer cases for widening unsigned addition.
--
-- Each case draws @x@ and then @y@ with 'generateNum' at the given size,
-- threading the random generator from one draw to the next, and records
-- @x@, @y@ and the full sum @x + y@ as @z@, all rendered with 'showX'.
generateSafeTests :: RandomGen g => Word -> g -> [Map String String]
generateSafeTests size = produce numTestCases
 where
  produce remaining gen
    | remaining == 0 = []
    | otherwise      = kat : produce (remaining - 1) gen''
   where
    (x, gen')  = generateNum gen size
    (y, gen'') = generateNum gen' size
    kat = Map.fromList
            [ ("x", showX x)
            , ("y", showX y)
            , ("z", showX (x + y))
            ]
|
|
|
|
-- | Produce 'numTestCases' known-answer cases for in-place unsigned addition.
--
-- Each case draws @x@ and then @y@ with 'generateNum' at the given size,
-- threading the random generator from one draw to the next.  The expected
-- result @z@ is the sum masked with @2 ^ size - 1@, i.e. only its low @size@
-- bits are kept.  All three values are rendered with 'showX'.
generateUnsafeTests :: RandomGen g => Word -> g -> [Map String String]
generateUnsafeTests size = produce numTestCases
 where
  produce remaining gen
    | remaining == 0 = []
    | otherwise      = kat : produce (remaining - 1) gen''
   where
    (x, gen')  = generateNum gen size
    (y, gen'') = generateNum gen' size
    lowBits    = (2 ^ size) - 1
    z          = (x + y) .&. lowBits
    kat = Map.fromList
            [ ("x", showX x)
            , ("y", showX y)
            , ("z", showX z)
            ]
|
|
|
|
-- | Produce 'numTestCases' known-answer cases for widening signed addition.
--
-- Each case draws @x@ and then @y@ with 'generateSignedNum' at the given
-- size, threading the random generator from one draw to the next, and
-- records @x@, @y@ and the full sum @x + y@ as @z@, all rendered with
-- 'showX'.
generateSafeSignedTests :: RandomGen g => Word -> g -> [Map String String]
generateSafeSignedTests size = produce numTestCases
 where
  produce remaining gen
    | remaining == 0 = []
    | otherwise      = kat : produce (remaining - 1) gen''
   where
    (x, gen')  = generateSignedNum gen size
    (y, gen'') = generateSignedNum gen' size
    kat = Map.fromList
            [ ("x", showX x)
            , ("y", showX y)
            , ("z", showX (x + y))
            ]
|
|
|
|
-- | Produce 'numTestCases' known-answer cases for in-place signed addition.
--
-- Each case draws @x@ and then @y@ with 'generateSignedNum' at the given
-- size, threading the random generator from one draw to the next.  The
-- expected result @z@ is the sum masked with @2 ^ size - 1@.  All three
-- values are rendered with 'showX'.
--
-- NOTE(review): the mask is the same one used for the unsigned cases; confirm
-- that it yields the intended expected value when the sum is negative.
generateUnsafeSignedTests :: RandomGen g => Word -> g -> [Map String String]
generateUnsafeSignedTests size = produce numTestCases
 where
  produce remaining gen
    | remaining == 0 = []
    | otherwise      = kat : produce (remaining - 1) gen''
   where
    (x, gen')  = generateSignedNum gen size
    (y, gen'') = generateSignedNum gen' size
    lowBits    = (2 ^ size) - 1
    z          = (x + y) .&. lowBits
    kat = Map.fromList
            [ ("x", showX x)
            , ("y", showX y)
            , ("z", showX z)
            ]
|