All the infrastructure to eventually to modinv. Don't try to use any of this yet.

This commit is contained in:
2020-01-17 20:44:41 -08:00
parent e46cfe56d1
commit b3fcd4715e
10 changed files with 958 additions and 32 deletions

View File

@@ -1,21 +1,22 @@
module Main module Main
where where
import Add(safeAddOps,unsafeAddOps) import Add(safeAddOps,unsafeAddOps,safeSignedAddOps,unsafeSignedAddOps)
import Base(base) import Base(base)
import BinaryOps(binaryOps) import BinaryOps(binaryOps)
import Compare(comparisons) import Compare(comparisons, signedComparisons)
import Conversions(conversions) import Conversions(conversions, signedConversions)
import CryptoNum(cryptoNum) import CryptoNum(cryptoNum)
import Control.Monad(forM_,unless) import Control.Monad(forM_,unless)
import Division(divisionOps) import Division(divisionOps)
import File(File,Task(..),generateTasks) import File(File,Task(..),generateTasks)
import ModInv(generateModInvOps)
import ModOps(modulusOps) import ModOps(modulusOps)
import Multiply(safeMultiplyOps, unsafeMultiplyOps) import Multiply(safeMultiplyOps, unsafeMultiplyOps)
import Scale(safeScaleOps, unsafeScaleOps) import Scale(safeScaleOps, unsafeScaleOps)
import Shift(shiftOps) import Shift(shiftOps, signedShiftOps)
import Signed(signedBaseOps) import Signed(signedBaseOps)
import Subtract(safeSubtractOps,unsafeSubtractOps) import Subtract(safeSubtractOps,unsafeSubtractOps,safeSignedSubtractOps,unsafeSignedSubtractOps)
import System.Directory(createDirectoryIfMissing) import System.Directory(createDirectoryIfMissing)
import System.Environment(getArgs) import System.Environment(getArgs)
import System.Exit(die) import System.Exit(die)
@@ -54,7 +55,15 @@ unsignedFiles = [
signedFiles :: [File] signedFiles :: [File]
signedFiles = [ signedFiles = [
signedBaseOps generateModInvOps
, safeSignedAddOps
, safeSignedSubtractOps
, signedBaseOps
, signedComparisons
, signedConversions
, signedShiftOps
, unsafeSignedAddOps
, unsafeSignedSubtractOps
] ]
allFiles :: [File] allFiles :: [File]

View File

@@ -41,6 +41,7 @@ library
Gen, Gen,
Generators, Generators,
Karatsuba, Karatsuba,
ModInv,
ModOps, ModOps,
Multiply, Multiply,
Scale, Scale,

View File

@@ -2,6 +2,8 @@
module Add( module Add(
safeAddOps safeAddOps
, unsafeAddOps , unsafeAddOps
, safeSignedAddOps
, unsafeSignedAddOps
) )
where where
@@ -38,6 +40,24 @@ unsafeAddOps = File {
testCase = Just generateUnsafeTests testCase = Just generateUnsafeTests
} }
safeSignedAddOps :: File
safeSignedAddOps = File {
predicate = \ me others -> (me + 64) `elem` others,
outputName = "safe_sadd",
isUnsigned = False,
generator = declareSafeSignedAddOperators,
testCase = Just generateSafeSignedTests
}
unsafeSignedAddOps :: File
unsafeSignedAddOps = File {
predicate = \ _ _ -> True,
outputName = "unsafe_sadd",
isUnsigned = False,
generator = declareUnsafeSignedAddOperators,
testCase = Just generateUnsafeSignedTests
}
declareSafeAddOperators :: Word -> [Word] -> SourceFile Span declareSafeAddOperators :: Word -> [Word] -> SourceFile Span
declareSafeAddOperators bitsize _ = declareSafeAddOperators bitsize _ =
let sname = mkIdent ("U" ++ show bitsize) let sname = mkIdent ("U" ++ show bitsize)
@@ -175,6 +195,142 @@ declareUnsafeAddOperators bitsize _ =
} }
|] |]
declareSafeSignedAddOperators :: Word -> [Word] -> SourceFile Span
declareSafeSignedAddOperators bitsize _ =
let sname = mkIdent ("I" ++ show bitsize)
dname = mkIdent ("I" ++ show (bitsize + 64))
fullRippleAdd = makeRippleAdder True (bitsize `div` 64) "res"
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::Add;
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
#[cfg(test)]
use quickcheck::quickcheck;
use crate::signed::{$$sname,$$dname};
impl Add for $$sname {
type Output = $$dname;
fn add(self, rhs: $$sname) -> $$dname {
&self + &rhs
}
}
impl<'a> Add<&'a $$sname> for $$sname {
type Output = $$dname;
fn add(self, rhs: &$$sname) -> $$dname {
&self + rhs
}
}
impl<'a> Add<$$sname> for &'a $$sname {
type Output = $$dname;
fn add(self, rhs: $$sname) -> $$dname {
self + &rhs
}
}
impl<'a,'b> Add<&'a $$sname> for &'b $$sname {
type Output = $$dname;
fn add(self, rhs: &$$sname) -> $$dname {
panic!("add")
}
}
#[cfg(test)]
quickcheck! {
fn addition_symmetric(a: $$sname, b: $$sname) -> bool {
(&a + &b) == (&b + &a)
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("safe_sadd", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
let mut x = $$sname::from_bytes(&xbytes);
let mut y = $$sname::from_bytes(&ybytes);
let mut z = $$dname::from_bytes(&zbytes);
if neg0 { x = x.negate() }
if neg1 { y = y.negate() }
if neg2 { z = z.negate() }
assert_eq!(z, x + y);
});
}
|]
declareUnsafeSignedAddOperators :: Word -> [Word] -> SourceFile Span
declareUnsafeSignedAddOperators bitsize _ =
let sname = mkIdent ("I" ++ show bitsize)
fullRippleAdd = makeRippleAdder False (bitsize `div` 64) "self"
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::AddAssign;
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
#[cfg(test)]
use quickcheck::quickcheck;
use crate::signed::$$sname;
impl AddAssign for $$sname {
fn add_assign(&mut self, rhs: Self) {
self.add_assign(&rhs);
}
}
impl<'a> AddAssign<&'a $$sname> for $$sname {
fn add_assign(&mut self, rhs: &Self) {
panic!("add_assign")
}
}
#[cfg(test)]
quickcheck! {
fn addition_symmetric(a: $$sname, b: $$sname) -> bool {
let mut side1 = a.clone();
let mut side2 = b.clone();
side1 += b;
side2 += a;
side1 == side2
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("unsafe_sadd", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
let mut x = $$sname::from_bytes(&xbytes);
let mut y = $$sname::from_bytes(&ybytes);
let mut z = $$sname::from_bytes(&zbytes);
if neg0 { x = x.negate() }
if neg1 { y = y.negate() }
if neg2 { z = z.negate() }
x += &y;
assert_eq!(z, x);
});
}
|]
makeRippleAdder :: Bool -> Word -> String -> [Stmt Span] makeRippleAdder :: Bool -> Word -> String -> [Stmt Span]
makeRippleAdder useLastCarry inElems resName = makeRippleAdder useLastCarry inElems resName =
@@ -238,3 +394,26 @@ generateUnsafeTests size g = go g numTestCases
tcase = Map.fromList [("x", showX x), ("y", showX y), tcase = Map.fromList [("x", showX x), ("y", showX y),
("z", showX z)] ("z", showX z)]
in tcase : go g2 (i - 1) in tcase : go g2 (i - 1)
generateSafeSignedTests :: RandomGen g => Word -> g -> [Map String String]
generateSafeSignedTests size g = go g numTestCases
where
go _ 0 = []
go g0 i =
let (x, g1) = generateSignedNum g0 size
(y, g2) = generateSignedNum g1 size
tcase = Map.fromList [("x", showX x), ("y", showX y),
("z", showX (x + y))]
in tcase : go g2 (i - 1)
generateUnsafeSignedTests :: RandomGen g => Word -> g -> [Map String String]
generateUnsafeSignedTests size g = go g numTestCases
where
go _ 0 = []
go g0 i =
let (x, g1) = generateSignedNum g0 size
(y, g2) = generateSignedNum g1 size
z = (x + y) .&. ((2 ^ size) - 1)
tcase = Map.fromList [("x", showX x), ("y", showX y),
("z", showX z)]
in tcase : go g2 (i - 1)

View File

@@ -1,5 +1,5 @@
{-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE QuasiQuotes #-}
module Compare(comparisons) module Compare(comparisons, signedComparisons)
where where
import Data.Map.Strict(Map) import Data.Map.Strict(Map)
@@ -24,6 +24,15 @@ comparisons = File {
testCase = Just generateTests testCase = Just generateTests
} }
signedComparisons :: File
signedComparisons = File {
predicate = \ _ _ -> True,
outputName = "scompare",
isUnsigned = False,
generator = declareSignedComparators,
testCase = Just generateSignedTests
}
declareComparators :: Word -> [Word] -> SourceFile Span declareComparators :: Word -> [Word] -> SourceFile Span
declareComparators bitsize _ = declareComparators bitsize _ =
let sname = mkIdent ("U" ++ show bitsize) let sname = mkIdent ("U" ++ show bitsize)
@@ -121,6 +130,101 @@ declareComparators bitsize _ =
} }
|] |]
declareSignedComparators :: Word -> [Word] -> SourceFile Span
declareSignedComparators bitsize _ =
let sname = mkIdent ("I" ++ show bitsize)
entries = bitsize `div` 64
eqStatements = buildEqStatements 0 entries
compareExp = buildCompareExp 0 entries
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::cmp::{Eq,Ordering,PartialEq};
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
#[cfg(test)]
use quickcheck::quickcheck;
use super::$$sname;
impl PartialEq for $$sname {
fn eq(&self, other: &Self) -> bool {
&self.contents == &other.contents
}
}
impl Eq for $$sname {}
impl Ord for $$sname {
fn cmp(&self, other: &Self) -> Ordering {
panic!("cmp")
}
}
impl PartialOrd for $$sname {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
#[cfg(test)]
quickcheck! {
fn eq_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a == c { a == b && b == c } else { a != b || b != c }
}
fn gt_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a > b && b > c { a > c } else { true }
}
fn ge_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a >= b && b >= c { a >= c } else { true }
}
fn lt_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a < b && b < c { a < c } else { true }
}
fn le_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a <= b && b <= c { a <= c } else { true }
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("scompare", $$(testFileLit)), 8, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, ebytes) = case.get("e").unwrap();
let (neg3, nbytes) = case.get("n").unwrap();
let (neg4, gbytes) = case.get("g").unwrap();
let (neg5, hbytes) = case.get("h").unwrap();
let (neg6, lbytes) = case.get("l").unwrap();
let (neg7, kbytes) = case.get("k").unwrap();
assert!(!neg0 && !neg1 && !neg2 && !neg3 &&
!neg4 && !neg5 && !neg6 && !neg7);
let x = $$sname::from_bytes(&xbytes);
let y = $$sname::from_bytes(&ybytes);
let e = 1 == ebytes[0];
let n = 1 == nbytes[0];
let g = 1 == gbytes[0];
let h = 1 == hbytes[0];
let l = 1 == lbytes[0];
let k = 1 == kbytes[0];
assert_eq!(e, x == y);
assert_eq!(n, x != y);
assert_eq!(g, x > y);
assert_eq!(h, x >= y);
assert_eq!(l, x < y);
assert_eq!(k, x <= y);
});
}
|]
buildEqStatements :: Word -> Word -> [Stmt Span] buildEqStatements :: Word -> Word -> [Stmt Span]
buildEqStatements i numEntries buildEqStatements i numEntries
| i == (numEntries - 1) = | i == (numEntries - 1) =
@@ -157,3 +261,19 @@ generateTests size g = go g numTestCases
("l", showB (x < y)), ("l", showB (x < y)),
("k", showB (x <= y))] ("k", showB (x <= y))]
in tcase : go g2 (i - 1) in tcase : go g2 (i - 1)
generateSignedTests :: RandomGen g => Word -> g -> [Map String String]
generateSignedTests size g = go g numTestCases
where
go _ 0 = []
go g0 i =
let (x, g1) = generateSignedNum g0 size
(y, g2) = generateSignedNum g1 size
tcase = Map.fromList [("x", showX x), ("y", showX y),
("e", showB (x == y)),
("n", showB (x /= y)),
("g", showB (x > y)),
("h", showB (x >= y)),
("l", showB (x < y)),
("k", showB (x <= y))]
in tcase : go g2 (i - 1)

View File

@@ -1,6 +1,7 @@
{-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE QuasiQuotes #-}
module Conversions( module Conversions(
conversions conversions
, signedConversions
) )
where where
@@ -20,6 +21,15 @@ conversions = File {
testCase = Nothing testCase = Nothing
} }
signedConversions :: File
signedConversions = File {
predicate = \ _ _ -> True,
outputName = "sconversions",
isUnsigned = False,
generator = declareSignedConversions,
testCase = Nothing
}
declareConversions :: Word -> [Word] -> SourceFile Span declareConversions :: Word -> [Word] -> SourceFile Span
declareConversions bitsize otherSizes = declareConversions bitsize otherSizes =
let sname = mkIdent ("U" ++ show bitsize) let sname = mkIdent ("U" ++ show bitsize)
@@ -84,6 +94,71 @@ declareConversions bitsize otherSizes =
} }
|] |]
declareSignedConversions :: Word -> [Word] -> SourceFile Span
declareSignedConversions bitsize otherSizes =
let sname = mkIdent ("I" ++ show bitsize)
entries = bitsize `div` 64
u8_prims = buildUSPrimitives sname (mkIdent "u8") entries
u16_prims = buildUSPrimitives sname (mkIdent "u16") entries
u32_prims = buildUSPrimitives sname (mkIdent "u32") entries
u64_prims = buildUSPrimitives sname (mkIdent "u64") entries
usz_prims = buildUSPrimitives sname (mkIdent "usize") entries
u128_prims = generateUS128Primitives sname entries
i8_prims = buildSSPrimitives sname (mkIdent "i8") entries
i16_prims = buildSSPrimitives sname (mkIdent "i16") entries
i32_prims = buildSSPrimitives sname (mkIdent "i32") entries
i64_prims = buildSSPrimitives sname (mkIdent "i64") entries
isz_prims = buildSSPrimitives sname (mkIdent "isize") entries
i128_prims = generateSS128Primitives sname entries
others = generateSignedCryptonumConversions bitsize otherSizes
in [sourceFile|
use core::convert::{From,TryFrom};
use crate::CryptoNum;
use crate::ConversionError;
use crate::signed::*;
use crate::unsigned::*;
#[cfg(test)]
use quickcheck::quickcheck;
$@{u8_prims}
$@{u16_prims}
$@{u32_prims}
$@{u64_prims}
$@{usz_prims}
$@{u128_prims}
$@{i8_prims}
$@{i16_prims}
$@{i32_prims}
$@{i64_prims}
$@{isz_prims}
$@{i128_prims}
$@{others}
#[cfg(test)]
quickcheck! {
fn u8_recovers(x: u8) -> bool {
x == u8::try_from($$sname::from(x)).unwrap()
}
fn u16_recovers(x: u16) -> bool {
x == u16::try_from($$sname::from(x)).unwrap()
}
fn u32_recovers(x: u32) -> bool {
x == u32::try_from($$sname::from(x)).unwrap()
}
fn u64_recovers(x: u64) -> bool {
x == u64::try_from($$sname::from(x)).unwrap()
}
fn usize_recovers(x: usize) -> bool {
x == usize::try_from($$sname::from(x)).unwrap()
}
fn u128_recovers(x: u128) -> bool {
x == u128::try_from($$sname::from(x)).unwrap()
}
}
|]
generateU128Primitives :: Ident -> Word -> [Item Span] generateU128Primitives :: Ident -> Word -> [Item Span]
generateU128Primitives sname entries = [ generateU128Primitives sname entries = [
[item|impl From<u128> for $$sname { [item|impl From<u128> for $$sname {
@@ -304,3 +379,137 @@ generateCryptonumConversions source otherSizes = concatMap convert otherSizes
|] |]
] ]
buildUSPrimitives :: Ident -> Ident -> Word -> [Item Span]
buildUSPrimitives sname prim entries = [
[item|
impl From<$$prim> for $$sname {
fn from(x: $$prim) -> $$sname {
let mut base = $$sname::zero();
base.contents.value[0] = x as u64;
base
}
}
|]
, [item|
impl<'a> TryFrom<&'a $$sname> for $$prim {
type Error = ConversionError;
fn try_from(x: &$$sname) -> Result<$$prim, ConversionError> {
if (x.contents.value)[1..].iter().any(|v| *v != 0) {
return Err(ConversionError::Overflow);
}
let res64 = x.contents.value[0];
if res64 & 0x8000_0000_0000_0000 != 0 {
return Err(ConversionError::Overflow);
}
Ok(res64 as $$prim)
}
}
|]
, [item|
impl TryFrom<$$sname> for $$prim {
type Error = ConversionError;
fn try_from(x: $$sname) -> Result<$$prim, ConversionError> {
$$prim::try_from(&x)
}
}
|]
]
buildSSPrimitives :: Ident -> Ident -> Word -> [Item Span]
buildSSPrimitives sname prim entries = [
[item|
impl From<$$prim> for $$sname {
fn from(x: $$prim) -> $$sname {
panic!("from_signed")
}
}
|]
, [item|
impl<'a> TryFrom<&'a $$sname> for $$prim {
type Error = ConversionError;
fn try_from(x: &$$sname) -> Result<$$prim, ConversionError> {
panic!("try_from_signed")
}
}
|]
, [item|
impl TryFrom<$$sname> for $$prim {
type Error = ConversionError;
fn try_from(x: $$sname) -> Result<$$prim, ConversionError> {
$$prim::try_from(&x)
}
}
|]
]
generateUS128Primitives :: Ident -> Word -> [Item Span]
generateUS128Primitives struct entries = []
generateSS128Primitives :: Ident -> Word -> [Item Span]
generateSS128Primitives struct entries = []
generateSignedCryptonumConversions :: Word -> [Word] -> [Item Span]
generateSignedCryptonumConversions source otherSizes = concatMap convert otherSizes
where
sName = mkIdent ("I" ++ show source)
--
convert target =
let tsName = mkIdent ("I" ++ show target)
tuName = mkIdent ("U" ++ show target)
sEntries = toLit (source `div` 64)
tEntries = toLit (target `div` 64)
in case compare source target of
LT -> []
EQ -> [
[item|
impl TryFrom<$$tuName> for $$sName {
type Error = ConversionError;
fn try_from(x: $$tuName) -> Result<$$sName,ConversionError> {
let res = $$sName{ contents: x };
if res.is_negative() {
return Err(ConversionError::Overflow);
}
Ok(res)
}
}
|],
[item|
impl<'a> TryFrom<&'a $$tuName> for $$sName {
type Error = ConversionError;
fn try_from(x: &$$tuName) -> Result<$$sName,ConversionError> {
$$sName::try_from(x.clone())
}
}
|],
[item|
impl TryFrom<$$sName> for $$tuName {
type Error = ConversionError;
fn try_from(x: $$sName) -> Result<$$tuName,ConversionError> {
if x.is_negative() {
return Err(ConversionError::Overflow);
}
Ok(x.contents)
}
}
|],
[item|
impl<'a> TryFrom<&'a $$sName> for $$tuName {
type Error = ConversionError;
fn try_from(x: &$$sName) -> Result<$$tuName,ConversionError> {
$$tuName::try_from(x.clone())
}
}
|]
]
GT -> []

View File

@@ -10,6 +10,15 @@ generateNum g size =
x' = x `mod` (2 ^ size) x' = x `mod` (2 ^ size)
in (x', g') in (x', g')
generateSignedNum :: RandomGen g => g -> Word -> (Integer, g)
generateSignedNum g size =
let (x, g') = random g
s :: Integer
(s, g'') = random g'
x' = x `mod` (2 ^ size)
sign = if even s then 1 else -1
in (x' * sign, g'')
modulate :: (Integral a, Integral b) => a -> b -> Integer modulate :: (Integral a, Integral b) => a -> b -> Integer
modulate x size = x' `mod` (2 ^ size') modulate x size = x' `mod` (2 ^ size')
where where

173
generation/src/ModInv.hs Normal file
View File

@@ -0,0 +1,173 @@
{-# LANGUAGE QuasiQuotes #-}
module ModInv(
generateModInvOps
)
where
import File
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
generateModInvOps :: File
generateModInvOps = File {
predicate = \ me others -> (me + 64) `elem` others,
outputName = "modinv",
isUnsigned = False,
generator = declareModInv,
testCase = Nothing
}
declareModInv :: Word -> [Word] -> SourceFile Span
declareModInv bitsize _ =
let sname = mkIdent ("I" ++ show bitsize)
uname = mkIdent ("U" ++ show bitsize)
in [sourceFile|
use core::convert::TryFrom;
use crate::CryptoNum;
use crate::signed::$$sname;
use crate::unsigned::$$uname;
impl $$uname {
fn modinv(&self, phi: &$$uname) -> Option<$$uname>
{
let (_, mut b, g) = phi.egcd(&self);
if g != $$sname::from(1i64) {
return None;
}
let sphi = $$sname::try_from(phi).expect("over/underflow in modinv phi");
while b.is_negative() {
b += &sphi;
}
if b > sphi {
b -= &sphi;
}
Some($$uname::try_from(b).expect("overflow/underflow in modinv result"))
}
fn egcd(&self, rhs: &$$uname) -> ($$sname, $$sname, $$sname) {
// INPUT: two positive integers x and y.
let mut x = $$sname::try_from(self).expect("overflow in modinv base");
let mut y = $$sname::try_from(rhs).expect("overflow in modinv rhs");
// OUTPUT: integers a, b, and v such that ax + by = v,
// where v = gcd(x, y).
// 1. g1.
let mut gshift: usize = 0;
// 2. While x and y are both even, do the following: xx/2,
// yy/2, g2g.
while x.is_even() && y.is_even() {
x >>= 1u64;
y >>= 1u64;
gshift += 1;
}
// 3. ux, vy, A1, B0, C0, D1.
let mut u = x.clone();
let mut v = y.clone();
#[allow(non_snake_case)]
let mut A = $$sname::from(1i64);
#[allow(non_snake_case)]
let mut B = $$sname::zero();
#[allow(non_snake_case)]
let mut C = $$sname::zero();
#[allow(non_snake_case)]
let mut D = $$sname::from(1i64);
loop {
// 4. While u is even do the following:
while u.is_even() {
// 4.1 uu/2.
u >>= 1u64;
// 4.2 If AB0 (mod 2) then AA/2, BB/2; otherwise,
// A(A + y)/2, B(B x)/2.
if A.is_even() && B.is_even() {
A >>= 1u64;
B >>= 1u64;
} else {
A += &y;
A >>= 1u64;
B -= &x;
B >>= 1u64;
}
}
// 5. While v is even do the following:
while v.is_even() {
// 5.1 vv/2.
v >>= 1u64;
// 5.2 If C D 0 (mod 2) then CC/2, DD/2; otherwise,
// C(C + y)/2, D(D x)/2.
if C.is_even() && D.is_even() {
C >>= 1u64;
D >>= 1u64;
} else {
C += &y;
C >>= 1u64;
D -= &x;
D >>= 1u64;
}
}
// 6. If uv then uuv, AAC,BBD;
// otherwise,vvu, CCA, DDB.
if u >= v {
u -= &v;
A -= &C;
B -= &D;
} else {
v -= &u;
C -= &A;
D -= &B;
}
// 7. If u = 0, then aC, bD, and return(a, b, g · v);
// otherwise, go to step 4.
if u.is_zero() {
return (C, D, v << gshift);
}
}
}
fn gcd_is_one(&self, b: &$$uname) -> bool {
let mut u = self.clone();
let mut v = b.clone();
let one = $$uname::from(1u64);
if u.is_zero() {
return v == one;
}
if v.is_zero() {
return u == one;
}
if u.is_even() && v.is_even() {
return false;
}
while u.is_even() {
u >>= 1u64;
}
loop {
while v.is_even() {
v >>= 1u64;
}
// u and v guaranteed to be odd right now.
if u > v {
// make sure that v > u, so that our subtraction works
// out.
let t = u;
u = v;
v = t;
}
v -= &u;
if v.is_zero() {
return u == one;
}
}
}
}
|]

View File

@@ -1,5 +1,5 @@
{-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE QuasiQuotes #-}
module Shift(shiftOps) module Shift(shiftOps, signedShiftOps)
where where
import Data.Bits(shiftL,shiftR) import Data.Bits(shiftL,shiftR)
@@ -26,6 +26,15 @@ shiftOps = File {
testCase = Just generateTests testCase = Just generateTests
} }
signedShiftOps :: File
signedShiftOps = File {
predicate = \ _ _ -> True,
outputName = "sshift",
isUnsigned = False,
generator = declareSignedShiftOperators,
testCase = Just generateSignedTests
}
declareShiftOperators :: Word -> [Word] -> SourceFile Span declareShiftOperators :: Word -> [Word] -> SourceFile Span
declareShiftOperators bitsize _ = declareShiftOperators bitsize _ =
let struct_name = mkIdent ("U" ++ show bitsize) let struct_name = mkIdent ("U" ++ show bitsize)
@@ -97,6 +106,68 @@ declareShiftOperators bitsize _ =
} }
|] |]
declareSignedShiftOperators :: Word -> [Word] -> SourceFile Span
declareSignedShiftOperators bitsize _ =
let struct_name = mkIdent ("I" ++ show bitsize)
entries = bitsize `div` 64
unsignedShifts = generateUnsigneds struct_name
shlUsizeImpls = generateBaseUsizes struct_name
shlActualImpl = concatMap actualShlImplLines [1..entries-1]
shrActualImpl = concatMap (actualShrImplLines entries) (reverse [0..entries-1])
resAssign = map reassignSelf [0..entries-1]
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
#[cfg(test)]
use core::convert::TryFrom;
use core::ops::{Shl,ShlAssign};
use core::ops::{Shr,ShrAssign};
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use super::$$struct_name;
impl ShlAssign<usize> for $$struct_name {
fn shl_assign(&mut self, rhs: usize) {
panic!("shl_assign")
}
}
impl ShrAssign<usize> for $$struct_name {
fn shr_assign(&mut self, rhs: usize) {
panic!("shr_assign")
}
}
$@{shlUsizeImpls}
$@{unsignedShifts}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("shift", $$(testFileLit)), 4, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, sbytes) = case.get("s").unwrap();
let (neg2, lbytes) = case.get("l").unwrap();
let (neg3, rbytes) = case.get("r").unwrap();
assert!(!neg1);
let mut x = $$struct_name::from_bytes(xbytes);
let mut l = $$struct_name::from_bytes(lbytes);
let mut r = $$struct_name::from_bytes(rbytes);
if neg0 { x = x.negate() }
if neg2 { l = l.negate() }
if neg3 { r = r.negate() }
let s = usize::try_from($$struct_name::from_bytes(sbytes)).unwrap();
assert_eq!(l, &x << s);
assert_eq!(r, &x >> s);
});
}
|]
actualShlImplLines :: Word -> [Stmt Span] actualShlImplLines :: Word -> [Stmt Span]
actualShlImplLines i = actualShlImplLines i =
let basei = mkIdent ("base" ++ show i) let basei = mkIdent ("base" ++ show i)
@@ -224,3 +295,18 @@ generateTests size g = go g numTestCases
tcase = Map.fromList [("x", showX x), ("s", showX s), tcase = Map.fromList [("x", showX x), ("s", showX s),
("l", showX l), ("r", showX r)] ("l", showX l), ("r", showX r)]
in tcase : go g2 (i - 1) in tcase : go g2 (i - 1)
generateSignedTests :: RandomGen g => Word -> g -> [Map String String]
generateSignedTests size g = go g numTestCases
where
go _ 0 = []
go g0 i =
let (x, g1) = generateSignedNum g0 size
(y, g2) = generateNum g1 size
s = y `mod` fromIntegral size
l = modulate (x `shiftL` fromIntegral s) size
r = modulate (x `shiftR` fromIntegral s) size
tcase = Map.fromList [("x", showX x), ("s", showX s),
("l", showX l), ("r", showX r)]
in tcase : go g2 (i - 1)

View File

@@ -33,7 +33,13 @@ declareSigned bitsize _ =
#[derive(Clone)] #[derive(Clone)]
pub struct $$sname { pub struct $$sname {
contents: $$uname, pub(crate) contents: $$uname,
}
impl $$sname {
pub fn is_negative(&self) -> bool {
self.contents.value[self.contents.value.len()-1] & 0x8000_0000_0000_0000 != 0
}
} }
impl Neg for $$sname { impl Neg for $$sname {
@@ -109,20 +115,6 @@ declareSigned bitsize _ =
} }
} }
impl PartialEq for $$sname {
fn eq(&self, other: &$$sname) -> bool {
let mut all_equal = true;
for (l, r) in self.contents.value.iter().zip(other.contents.value.iter()) {
all_equal &= *l == *r;
}
all_equal
}
}
impl Eq for $$sname { }
impl Arbitrary for $$sname { impl Arbitrary for $$sname {
fn arbitrary<G: Gen>(g: &mut G) -> $$sname { fn arbitrary<G: Gen>(g: &mut G) -> $$sname {
$$sname{ $$sname{
@@ -163,14 +155,6 @@ declareSigned bitsize _ =
#[cfg(test)] #[cfg(test)]
quickcheck! { quickcheck! {
fn equality_reflexive(x: $$sname) -> bool {
&x == &x
}
fn equality_symmetric(x: $$sname, y: $$sname) -> bool {
(&x == &y) == (&y == &x)
}
fn double_not(x: $$sname) -> bool { fn double_not(x: $$sname) -> bool {
x == !!&x x == !!&x
} }

View File

@@ -2,6 +2,8 @@
module Subtract( module Subtract(
safeSubtractOps safeSubtractOps
, unsafeSubtractOps , unsafeSubtractOps
, safeSignedSubtractOps
, unsafeSignedSubtractOps
) )
where where
@@ -29,6 +31,15 @@ safeSubtractOps = File {
testCase = Just generateSafeTests testCase = Just generateSafeTests
} }
safeSignedSubtractOps :: File
safeSignedSubtractOps = File {
predicate = \ me others -> (me + 64) `elem` others,
outputName = "safe_ssub",
isUnsigned = True,
generator = declareSafeSignedSubtractOperators,
testCase = Just generateSafeSignedTests
}
unsafeSubtractOps :: File unsafeSubtractOps :: File
unsafeSubtractOps = File { unsafeSubtractOps = File {
predicate = \ _ _ -> True, predicate = \ _ _ -> True,
@@ -38,6 +49,15 @@ unsafeSubtractOps = File {
testCase = Just generateUnsafeTests testCase = Just generateUnsafeTests
} }
unsafeSignedSubtractOps :: File
unsafeSignedSubtractOps = File {
predicate = \ _ _ -> True,
outputName = "unsafe_ssub",
isUnsigned = True,
generator = declareUnsafeSignedSubtractOperators,
testCase = Just generateUnsafeSignedTests
}
declareSafeSubtractOperators :: Word -> [Word] -> SourceFile Span declareSafeSubtractOperators :: Word -> [Word] -> SourceFile Span
declareSafeSubtractOperators bitsize _ = declareSafeSubtractOperators bitsize _ =
let sname = mkIdent ("U" ++ show bitsize) let sname = mkIdent ("U" ++ show bitsize)
@@ -106,6 +126,70 @@ declareSafeSubtractOperators bitsize _ =
} }
|] |]
declareSafeSignedSubtractOperators :: Word -> [Word] -> SourceFile Span
declareSafeSignedSubtractOperators bitsize _ =
let sname = mkIdent ("I" ++ show bitsize)
dname = mkIdent ("I" ++ show (bitsize + 64))
fullRippleSubtract = makeRippleSubtracter True (bitsize `div` 64) "res"
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::Sub;
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use crate::signed::{$$sname,$$dname};
impl Sub for $$sname {
type Output = $$dname;
fn sub(self, rhs: $$sname) -> $$dname {
&self - &rhs
}
}
impl<'a> Sub<&'a $$sname> for $$sname {
type Output = $$dname;
fn sub(self, rhs: &$$sname) -> $$dname {
&self - rhs
}
}
impl<'a> Sub<$$sname> for &'a $$sname {
type Output = $$dname;
fn sub(self, rhs: $$sname) -> $$dname {
self - &rhs
}
}
impl<'a,'b> Sub<&'a $$sname> for &'b $$sname {
type Output = $$dname;
fn sub(self, rhs: &$$sname) -> $$dname {
panic!("sub")
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("safe_sub", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
assert!(!neg0 && !neg1 && !neg2);
let x = $$sname::from_bytes(&xbytes);
let y = $$sname::from_bytes(&ybytes);
let z = $$dname::from_bytes(&zbytes);
assert_eq!(z, x - y);
});
}
|]
declareUnsafeSubtractOperators :: Word -> [Word] -> SourceFile Span declareUnsafeSubtractOperators :: Word -> [Word] -> SourceFile Span
declareUnsafeSubtractOperators bitsize _ = declareUnsafeSubtractOperators bitsize _ =
let sname = mkIdent ("U" ++ show bitsize) let sname = mkIdent ("U" ++ show bitsize)
@@ -151,6 +235,52 @@ declareUnsafeSubtractOperators bitsize _ =
} }
|] |]
declareUnsafeSignedSubtractOperators :: Word -> [Word] -> SourceFile Span
declareUnsafeSignedSubtractOperators bitsize _ =
let sname = mkIdent ("I" ++ show bitsize)
fullRippleSubtract = makeRippleSubtracter False (bitsize `div` 64) "self"
testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::SubAssign;
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
use crate::signed::$$sname;
impl SubAssign for $$sname {
fn sub_assign(&mut self, rhs: Self) {
self.sub_assign(&rhs);
}
}
impl<'a> SubAssign<&'a $$sname> for $$sname {
fn sub_assign(&mut self, rhs: &Self) {
panic!("sub_assign")
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("unsafe_ssub", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
let mut x = $$sname::from_bytes(&xbytes);
let mut y = $$sname::from_bytes(&ybytes);
let mut z = $$sname::from_bytes(&zbytes);
if neg0 { x = x.negate(); }
if neg1 { y = y.negate(); }
if neg2 { z = z.negate(); }
x -= &y;
assert_eq!(z, x);
});
}
|]
makeRippleSubtracter :: Bool -> Word -> String -> [Stmt Span] makeRippleSubtracter :: Bool -> Word -> String -> [Stmt Span]
makeRippleSubtracter useLastCarry inElems resName = makeRippleSubtracter useLastCarry inElems resName =
@@ -206,6 +336,20 @@ generateSafeTests size g = go g numTestCases
("z", showX r)] ("z", showX r)]
in tcase : go g2 (i - 1) in tcase : go g2 (i - 1)
generateSafeSignedTests :: RandomGen g => Word -> g -> [Map String String]
generateSafeSignedTests size g = go g numTestCases
where
go _ 0 = []
go g0 i =
let (x, g1) = generateSignedNum g0 size
(y, g2) = generateSignedNum g1 size
r | x < y = (2 ^ (size + 64)) + (x - y)
| otherwise = x - y
tcase = Map.fromList [("x", showX x),
("y", showX y),
("z", showX r)]
in tcase : go g2 (i - 1)
generateUnsafeTests :: RandomGen g => Word -> g -> [Map String String] generateUnsafeTests :: RandomGen g => Word -> g -> [Map String String]
generateUnsafeTests size g = go g numTestCases generateUnsafeTests size g = go g numTestCases
where where
@@ -217,3 +361,15 @@ generateUnsafeTests size g = go g numTestCases
tcase = Map.fromList [("x", showX x), ("y", showX y), tcase = Map.fromList [("x", showX x), ("y", showX y),
("z", showX z)] ("z", showX z)]
in tcase : go g2 (i - 1) in tcase : go g2 (i - 1)
generateUnsafeSignedTests :: RandomGen g => Word -> g -> [Map String String]
generateUnsafeSignedTests size g = go g numTestCases
where
go _ 0 = []
go g0 i =
let (x, g1) = generateSignedNum g0 size
(y, g2) = generateSignedNum g1 size
z = (x - y) .&. ((2 ^ size) - 1)
tcase = Map.fromList [("x", showX x), ("y", showX y),
("z", showX z)]
in tcase : go g2 (i - 1)