From b3fcd4715e52cfa471ad506811decbc1f68bc018 Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Fri, 17 Jan 2020 20:44:41 -0800 Subject: [PATCH] All the infrastructure to eventually to modinv. Don't try to use any of this yet. --- generation/Main.hs | 21 +++- generation/generation.cabal | 1 + generation/src/Add.hs | 179 ++++++++++++++++++++++++++++ generation/src/Compare.hs | 122 +++++++++++++++++++- generation/src/Conversions.hs | 211 +++++++++++++++++++++++++++++++++- generation/src/Generators.hs | 9 ++ generation/src/ModInv.hs | 173 ++++++++++++++++++++++++++++ generation/src/Shift.hs | 88 +++++++++++++- generation/src/Signed.hs | 30 ++--- generation/src/Subtract.hs | 156 +++++++++++++++++++++++++ 10 files changed, 958 insertions(+), 32 deletions(-) create mode 100644 generation/src/ModInv.hs diff --git a/generation/Main.hs b/generation/Main.hs index 3ab8665..9adf98b 100644 --- a/generation/Main.hs +++ b/generation/Main.hs @@ -1,21 +1,22 @@ module Main where -import Add(safeAddOps,unsafeAddOps) +import Add(safeAddOps,unsafeAddOps,safeSignedAddOps,unsafeSignedAddOps) import Base(base) import BinaryOps(binaryOps) -import Compare(comparisons) -import Conversions(conversions) +import Compare(comparisons, signedComparisons) +import Conversions(conversions, signedConversions) import CryptoNum(cryptoNum) import Control.Monad(forM_,unless) import Division(divisionOps) import File(File,Task(..),generateTasks) +import ModInv(generateModInvOps) import ModOps(modulusOps) import Multiply(safeMultiplyOps, unsafeMultiplyOps) import Scale(safeScaleOps, unsafeScaleOps) -import Shift(shiftOps) +import Shift(shiftOps, signedShiftOps) import Signed(signedBaseOps) -import Subtract(safeSubtractOps,unsafeSubtractOps) +import Subtract(safeSubtractOps,unsafeSubtractOps,safeSignedSubtractOps,unsafeSignedSubtractOps) import System.Directory(createDirectoryIfMissing) import System.Environment(getArgs) import System.Exit(die) @@ -54,7 +55,15 @@ unsignedFiles = [ signedFiles :: [File] signedFiles = [ - signedBaseOps + generateModInvOps + , safeSignedAddOps + , safeSignedSubtractOps + , signedBaseOps + , signedComparisons + , signedConversions + , signedShiftOps + , unsafeSignedAddOps + , unsafeSignedSubtractOps ] allFiles :: [File] diff --git a/generation/generation.cabal b/generation/generation.cabal index a81afc7..5f89a15 100644 --- a/generation/generation.cabal +++ b/generation/generation.cabal @@ -41,6 +41,7 @@ library Gen, Generators, Karatsuba, + ModInv, ModOps, Multiply, Scale, diff --git a/generation/src/Add.hs b/generation/src/Add.hs index 96e7184..5cb76db 100644 --- a/generation/src/Add.hs +++ b/generation/src/Add.hs @@ -2,6 +2,8 @@ module Add( safeAddOps , unsafeAddOps + , safeSignedAddOps + , unsafeSignedAddOps ) where @@ -38,6 +40,24 @@ unsafeAddOps = File { testCase = Just generateUnsafeTests } +safeSignedAddOps :: File +safeSignedAddOps = File { + predicate = \ me others -> (me + 64) `elem` others, + outputName = "safe_sadd", + isUnsigned = False, + generator = declareSafeSignedAddOperators, + testCase = Just generateSafeSignedTests +} + +unsafeSignedAddOps :: File +unsafeSignedAddOps = File { + predicate = \ _ _ -> True, + outputName = "unsafe_sadd", + isUnsigned = False, + generator = declareUnsafeSignedAddOperators, + testCase = Just generateUnsafeSignedTests +} + declareSafeAddOperators :: Word -> [Word] -> SourceFile Span declareSafeAddOperators bitsize _ = let sname = mkIdent ("U" ++ show bitsize) @@ -175,6 +195,142 @@ declareUnsafeAddOperators bitsize _ = } |] +declareSafeSignedAddOperators :: Word -> [Word] -> SourceFile Span +declareSafeSignedAddOperators bitsize _ = + let sname = mkIdent ("I" ++ show bitsize) + dname = mkIdent ("I" ++ show (bitsize + 64)) + fullRippleAdd = makeRippleAdder True (bitsize `div` 64) "res" + testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty + in [sourceFile| + use core::ops::Add; + use crate::CryptoNum; + #[cfg(test)] + use crate::testing::{build_test_path,run_test}; + #[cfg(test)] + use quickcheck::quickcheck; + use crate::signed::{$$sname,$$dname}; + + impl Add for $$sname { + type Output = $$dname; + + fn add(self, rhs: $$sname) -> $$dname { + &self + &rhs + } + } + + impl<'a> Add<&'a $$sname> for $$sname { + type Output = $$dname; + + fn add(self, rhs: &$$sname) -> $$dname { + &self + rhs + } + } + + impl<'a> Add<$$sname> for &'a $$sname { + type Output = $$dname; + + fn add(self, rhs: $$sname) -> $$dname { + self + &rhs + } + } + + impl<'a,'b> Add<&'a $$sname> for &'b $$sname { + type Output = $$dname; + + fn add(self, rhs: &$$sname) -> $$dname { + panic!("add") + } + } + + #[cfg(test)] + quickcheck! { + fn addition_symmetric(a: $$sname, b: $$sname) -> bool { + (&a + &b) == (&b + &a) + } + } + + #[cfg(test)] + #[allow(non_snake_case)] + #[test] + fn KATs() { + run_test(build_test_path("safe_sadd", $$(testFileLit)), 3, |case| { + let (neg0, xbytes) = case.get("x").unwrap(); + let (neg1, ybytes) = case.get("y").unwrap(); + let (neg2, zbytes) = case.get("z").unwrap(); + + let mut x = $$sname::from_bytes(&xbytes); + let mut y = $$sname::from_bytes(&ybytes); + let mut z = $$dname::from_bytes(&zbytes); + if neg0 { x = x.negate() } + if neg1 { y = y.negate() } + if neg2 { z = z.negate() } + + assert_eq!(z, x + y); + }); + } + |] + +declareUnsafeSignedAddOperators :: Word -> [Word] -> SourceFile Span +declareUnsafeSignedAddOperators bitsize _ = + let sname = mkIdent ("I" ++ show bitsize) + fullRippleAdd = makeRippleAdder False (bitsize `div` 64) "self" + testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty + in [sourceFile| + use core::ops::AddAssign; + #[cfg(test)] + use crate::CryptoNum; + #[cfg(test)] + use crate::testing::{build_test_path,run_test}; + #[cfg(test)] + use quickcheck::quickcheck; + use crate::signed::$$sname; + + impl AddAssign for $$sname { + fn add_assign(&mut self, rhs: Self) { + self.add_assign(&rhs); + } + } + + impl<'a> AddAssign<&'a $$sname> for $$sname { + fn add_assign(&mut self, rhs: &Self) { + panic!("add_assign") + } + } + + #[cfg(test)] + quickcheck! { + fn addition_symmetric(a: $$sname, b: $$sname) -> bool { + let mut side1 = a.clone(); + let mut side2 = b.clone(); + + side1 += b; + side2 += a; + + side1 == side2 + } + } + + #[cfg(test)] + #[allow(non_snake_case)] + #[test] + fn KATs() { + run_test(build_test_path("unsafe_sadd", $$(testFileLit)), 3, |case| { + let (neg0, xbytes) = case.get("x").unwrap(); + let (neg1, ybytes) = case.get("y").unwrap(); + let (neg2, zbytes) = case.get("z").unwrap(); + + let mut x = $$sname::from_bytes(&xbytes); + let mut y = $$sname::from_bytes(&ybytes); + let mut z = $$sname::from_bytes(&zbytes); + if neg0 { x = x.negate() } + if neg1 { y = y.negate() } + if neg2 { z = z.negate() } + + x += &y; + assert_eq!(z, x); + }); + } + |] makeRippleAdder :: Bool -> Word -> String -> [Stmt Span] makeRippleAdder useLastCarry inElems resName = @@ -238,3 +394,26 @@ generateUnsafeTests size g = go g numTestCases tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX z)] in tcase : go g2 (i - 1) + +generateSafeSignedTests :: RandomGen g => Word -> g -> [Map String String] +generateSafeSignedTests size g = go g numTestCases + where + go _ 0 = [] + go g0 i = + let (x, g1) = generateSignedNum g0 size + (y, g2) = generateSignedNum g1 size + tcase = Map.fromList [("x", showX x), ("y", showX y), + ("z", showX (x + y))] + in tcase : go g2 (i - 1) + +generateUnsafeSignedTests :: RandomGen g => Word -> g -> [Map String String] +generateUnsafeSignedTests size g = go g numTestCases + where + go _ 0 = [] + go g0 i = + let (x, g1) = generateSignedNum g0 size + (y, g2) = generateSignedNum g1 size + z = (x + y) .&. ((2 ^ size) - 1) + tcase = Map.fromList [("x", showX x), ("y", showX y), + ("z", showX z)] + in tcase : go g2 (i - 1) diff --git a/generation/src/Compare.hs b/generation/src/Compare.hs index bfdfa11..18df5d1 100644 --- a/generation/src/Compare.hs +++ b/generation/src/Compare.hs @@ -1,5 +1,5 @@ {-# LANGUAGE QuasiQuotes #-} -module Compare(comparisons) +module Compare(comparisons, signedComparisons) where import Data.Map.Strict(Map) @@ -24,6 +24,15 @@ comparisons = File { testCase = Just generateTests } +signedComparisons :: File +signedComparisons = File { + predicate = \ _ _ -> True, + outputName = "scompare", + isUnsigned = False, + generator = declareSignedComparators, + testCase = Just generateSignedTests +} + declareComparators :: Word -> [Word] -> SourceFile Span declareComparators bitsize _ = let sname = mkIdent ("U" ++ show bitsize) @@ -121,6 +130,101 @@ declareComparators bitsize _ = } |] +declareSignedComparators :: Word -> [Word] -> SourceFile Span +declareSignedComparators bitsize _ = + let sname = mkIdent ("I" ++ show bitsize) + entries = bitsize `div` 64 + eqStatements = buildEqStatements 0 entries + compareExp = buildCompareExp 0 entries + testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty + in [sourceFile| + use core::cmp::{Eq,Ordering,PartialEq}; + #[cfg(test)] + use crate::CryptoNum; + #[cfg(test)] + use crate::testing::{build_test_path,run_test}; + #[cfg(test)] + use quickcheck::quickcheck; + use super::$$sname; + + impl PartialEq for $$sname { + fn eq(&self, other: &Self) -> bool { + &self.contents == &other.contents + } + } + + impl Eq for $$sname {} + + impl Ord for $$sname { + fn cmp(&self, other: &Self) -> Ordering { + panic!("cmp") + } + } + + impl PartialOrd for $$sname { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + #[cfg(test)] + quickcheck! { + fn eq_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool { + if a == c { a == b && b == c } else { a != b || b != c } + } + + fn gt_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool { + if a > b && b > c { a > c } else { true } + } + + fn ge_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool { + if a >= b && b >= c { a >= c } else { true } + } + + fn lt_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool { + if a < b && b < c { a < c } else { true } + } + + fn le_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool { + if a <= b && b <= c { a <= c } else { true } + } + } + + #[cfg(test)] + #[allow(non_snake_case)] + #[test] + fn KATs() { + run_test(build_test_path("scompare", $$(testFileLit)), 8, |case| { + let (neg0, xbytes) = case.get("x").unwrap(); + let (neg1, ybytes) = case.get("y").unwrap(); + let (neg2, ebytes) = case.get("e").unwrap(); + let (neg3, nbytes) = case.get("n").unwrap(); + let (neg4, gbytes) = case.get("g").unwrap(); + let (neg5, hbytes) = case.get("h").unwrap(); + let (neg6, lbytes) = case.get("l").unwrap(); + let (neg7, kbytes) = case.get("k").unwrap(); + + assert!(!neg0 && !neg1 && !neg2 && !neg3 && + !neg4 && !neg5 && !neg6 && !neg7); + let x = $$sname::from_bytes(&xbytes); + let y = $$sname::from_bytes(&ybytes); + let e = 1 == ebytes[0]; + let n = 1 == nbytes[0]; + let g = 1 == gbytes[0]; + let h = 1 == hbytes[0]; + let l = 1 == lbytes[0]; + let k = 1 == kbytes[0]; + + assert_eq!(e, x == y); + assert_eq!(n, x != y); + assert_eq!(g, x > y); + assert_eq!(h, x >= y); + assert_eq!(l, x < y); + assert_eq!(k, x <= y); + }); + } + |] + buildEqStatements :: Word -> Word -> [Stmt Span] buildEqStatements i numEntries | i == (numEntries - 1) = @@ -157,3 +261,19 @@ generateTests size g = go g numTestCases ("l", showB (x < y)), ("k", showB (x <= y))] in tcase : go g2 (i - 1) + +generateSignedTests :: RandomGen g => Word -> g -> [Map String String] +generateSignedTests size g = go g numTestCases + where + go _ 0 = [] + go g0 i = + let (x, g1) = generateSignedNum g0 size + (y, g2) = generateSignedNum g1 size + tcase = Map.fromList [("x", showX x), ("y", showX y), + ("e", showB (x == y)), + ("n", showB (x /= y)), + ("g", showB (x > y)), + ("h", showB (x >= y)), + ("l", showB (x < y)), + ("k", showB (x <= y))] + in tcase : go g2 (i - 1) diff --git a/generation/src/Conversions.hs b/generation/src/Conversions.hs index b3f8804..9907ba9 100644 --- a/generation/src/Conversions.hs +++ b/generation/src/Conversions.hs @@ -1,6 +1,7 @@ {-# LANGUAGE QuasiQuotes #-} module Conversions( conversions + , signedConversions ) where @@ -20,6 +21,15 @@ conversions = File { testCase = Nothing } +signedConversions :: File +signedConversions = File { + predicate = \ _ _ -> True, + outputName = "sconversions", + isUnsigned = False, + generator = declareSignedConversions, + testCase = Nothing +} + declareConversions :: Word -> [Word] -> SourceFile Span declareConversions bitsize otherSizes = let sname = mkIdent ("U" ++ show bitsize) @@ -84,6 +94,71 @@ declareConversions bitsize otherSizes = } |] +declareSignedConversions :: Word -> [Word] -> SourceFile Span +declareSignedConversions bitsize otherSizes = + let sname = mkIdent ("I" ++ show bitsize) + entries = bitsize `div` 64 + u8_prims = buildUSPrimitives sname (mkIdent "u8") entries + u16_prims = buildUSPrimitives sname (mkIdent "u16") entries + u32_prims = buildUSPrimitives sname (mkIdent "u32") entries + u64_prims = buildUSPrimitives sname (mkIdent "u64") entries + usz_prims = buildUSPrimitives sname (mkIdent "usize") entries + u128_prims = generateUS128Primitives sname entries + i8_prims = buildSSPrimitives sname (mkIdent "i8") entries + i16_prims = buildSSPrimitives sname (mkIdent "i16") entries + i32_prims = buildSSPrimitives sname (mkIdent "i32") entries + i64_prims = buildSSPrimitives sname (mkIdent "i64") entries + isz_prims = buildSSPrimitives sname (mkIdent "isize") entries + i128_prims = generateSS128Primitives sname entries + others = generateSignedCryptonumConversions bitsize otherSizes + in [sourceFile| + use core::convert::{From,TryFrom}; + use crate::CryptoNum; + use crate::ConversionError; + use crate::signed::*; + use crate::unsigned::*; + #[cfg(test)] + use quickcheck::quickcheck; + + $@{u8_prims} + $@{u16_prims} + $@{u32_prims} + $@{u64_prims} + $@{usz_prims} + $@{u128_prims} + + $@{i8_prims} + $@{i16_prims} + $@{i32_prims} + $@{i64_prims} + $@{isz_prims} + $@{i128_prims} + + $@{others} + + #[cfg(test)] + quickcheck! { + fn u8_recovers(x: u8) -> bool { + x == u8::try_from($$sname::from(x)).unwrap() + } + fn u16_recovers(x: u16) -> bool { + x == u16::try_from($$sname::from(x)).unwrap() + } + fn u32_recovers(x: u32) -> bool { + x == u32::try_from($$sname::from(x)).unwrap() + } + fn u64_recovers(x: u64) -> bool { + x == u64::try_from($$sname::from(x)).unwrap() + } + fn usize_recovers(x: usize) -> bool { + x == usize::try_from($$sname::from(x)).unwrap() + } + fn u128_recovers(x: u128) -> bool { + x == u128::try_from($$sname::from(x)).unwrap() + } + } + |] + generateU128Primitives :: Ident -> Word -> [Item Span] generateU128Primitives sname entries = [ [item|impl From for $$sname { @@ -303,4 +378,138 @@ generateCryptonumConversions source otherSizes = concatMap convert otherSizes } |] ] - \ No newline at end of file + +buildUSPrimitives :: Ident -> Ident -> Word -> [Item Span] +buildUSPrimitives sname prim entries = [ + [item| + impl From<$$prim> for $$sname { + fn from(x: $$prim) -> $$sname { + let mut base = $$sname::zero(); + base.contents.value[0] = x as u64; + base + } + } + |] + , [item| + impl<'a> TryFrom<&'a $$sname> for $$prim { + type Error = ConversionError; + + fn try_from(x: &$$sname) -> Result<$$prim, ConversionError> { + if (x.contents.value)[1..].iter().any(|v| *v != 0) { + return Err(ConversionError::Overflow); + } + let res64 = x.contents.value[0]; + if res64 & 0x8000_0000_0000_0000 != 0 { + return Err(ConversionError::Overflow); + } + Ok(res64 as $$prim) + } + } + |] + , [item| + impl TryFrom<$$sname> for $$prim { + type Error = ConversionError; + + fn try_from(x: $$sname) -> Result<$$prim, ConversionError> { + $$prim::try_from(&x) + } + } + |] + ] + +buildSSPrimitives :: Ident -> Ident -> Word -> [Item Span] +buildSSPrimitives sname prim entries = [ + [item| + impl From<$$prim> for $$sname { + fn from(x: $$prim) -> $$sname { + panic!("from_signed") + } + } + |] + , [item| + impl<'a> TryFrom<&'a $$sname> for $$prim { + type Error = ConversionError; + + fn try_from(x: &$$sname) -> Result<$$prim, ConversionError> { + panic!("try_from_signed") + } + } + |] + , [item| + impl TryFrom<$$sname> for $$prim { + type Error = ConversionError; + + fn try_from(x: $$sname) -> Result<$$prim, ConversionError> { + $$prim::try_from(&x) + } + } + |] + ] + +generateUS128Primitives :: Ident -> Word -> [Item Span] +generateUS128Primitives struct entries = [] + +generateSS128Primitives :: Ident -> Word -> [Item Span] +generateSS128Primitives struct entries = [] + +generateSignedCryptonumConversions :: Word -> [Word] -> [Item Span] +generateSignedCryptonumConversions source otherSizes = concatMap convert otherSizes + where + sName = mkIdent ("I" ++ show source) + -- + convert target = + let tsName = mkIdent ("I" ++ show target) + tuName = mkIdent ("U" ++ show target) + sEntries = toLit (source `div` 64) + tEntries = toLit (target `div` 64) + in case compare source target of + LT -> [] + EQ -> [ + [item| + impl TryFrom<$$tuName> for $$sName { + type Error = ConversionError; + + fn try_from(x: $$tuName) -> Result<$$sName,ConversionError> { + let res = $$sName{ contents: x }; + + if res.is_negative() { + return Err(ConversionError::Overflow); + } + + Ok(res) + } + } + |], + [item| + impl<'a> TryFrom<&'a $$tuName> for $$sName { + type Error = ConversionError; + + fn try_from(x: &$$tuName) -> Result<$$sName,ConversionError> { + $$sName::try_from(x.clone()) + } + } + |], + [item| + impl TryFrom<$$sName> for $$tuName { + type Error = ConversionError; + + fn try_from(x: $$sName) -> Result<$$tuName,ConversionError> { + if x.is_negative() { + return Err(ConversionError::Overflow); + } + Ok(x.contents) + } + } + |], + [item| + impl<'a> TryFrom<&'a $$sName> for $$tuName { + type Error = ConversionError; + + fn try_from(x: &$$sName) -> Result<$$tuName,ConversionError> { + $$tuName::try_from(x.clone()) + } + } + |] + ] + GT -> [] + diff --git a/generation/src/Generators.hs b/generation/src/Generators.hs index 3f9a474..c30f79b 100644 --- a/generation/src/Generators.hs +++ b/generation/src/Generators.hs @@ -10,6 +10,15 @@ generateNum g size = x' = x `mod` (2 ^ size) in (x', g') +generateSignedNum :: RandomGen g => g -> Word -> (Integer, g) +generateSignedNum g size = + let (x, g') = random g + s :: Integer + (s, g'') = random g' + x' = x `mod` (2 ^ size) + sign = if even s then 1 else -1 + in (x' * sign, g'') + modulate :: (Integral a, Integral b) => a -> b -> Integer modulate x size = x' `mod` (2 ^ size') where diff --git a/generation/src/ModInv.hs b/generation/src/ModInv.hs new file mode 100644 index 0000000..8f10f7d --- /dev/null +++ b/generation/src/ModInv.hs @@ -0,0 +1,173 @@ +{-# LANGUAGE QuasiQuotes #-} +module ModInv( + generateModInvOps + ) + where + +import File +import Language.Rust.Data.Ident +import Language.Rust.Data.Position +import Language.Rust.Quote +import Language.Rust.Syntax + +generateModInvOps :: File +generateModInvOps = File { + predicate = \ me others -> (me + 64) `elem` others, + outputName = "modinv", + isUnsigned = False, + generator = declareModInv, + testCase = Nothing +} + +declareModInv :: Word -> [Word] -> SourceFile Span +declareModInv bitsize _ = + let sname = mkIdent ("I" ++ show bitsize) + uname = mkIdent ("U" ++ show bitsize) + in [sourceFile| + use core::convert::TryFrom; + use crate::CryptoNum; + use crate::signed::$$sname; + use crate::unsigned::$$uname; + + impl $$uname { + fn modinv(&self, phi: &$$uname) -> Option<$$uname> + { + let (_, mut b, g) = phi.egcd(&self); + + if g != $$sname::from(1i64) { + return None; + } + + let sphi = $$sname::try_from(phi).expect("over/underflow in modinv phi"); + + while b.is_negative() { + b += &sphi; + } + + if b > sphi { + b -= &sphi; + } + + Some($$uname::try_from(b).expect("overflow/underflow in modinv result")) + } + + fn egcd(&self, rhs: &$$uname) -> ($$sname, $$sname, $$sname) { + // INPUT: two positive integers x and y. + let mut x = $$sname::try_from(self).expect("overflow in modinv base"); + let mut y = $$sname::try_from(rhs).expect("overflow in modinv rhs"); + // OUTPUT: integers a, b, and v such that ax + by = v, + // where v = gcd(x, y). + // 1. g←1. + let mut gshift: usize = 0; + // 2. While x and y are both even, do the following: x←x/2, + // y←y/2, g←2g. + while x.is_even() && y.is_even() { + x >>= 1u64; + y >>= 1u64; + gshift += 1; + } + // 3. u←x, v←y, A←1, B←0, C←0, D←1. + let mut u = x.clone(); + let mut v = y.clone(); + #[allow(non_snake_case)] + let mut A = $$sname::from(1i64); + #[allow(non_snake_case)] + let mut B = $$sname::zero(); + #[allow(non_snake_case)] + let mut C = $$sname::zero(); + #[allow(non_snake_case)] + let mut D = $$sname::from(1i64); + loop { + // 4. While u is even do the following: + while u.is_even() { + // 4.1 u←u/2. + u >>= 1u64; + // 4.2 If A≡B≡0 (mod 2) then A←A/2, B←B/2; otherwise, + // A←(A + y)/2, B←(B − x)/2. + if A.is_even() && B.is_even() { + A >>= 1u64; + B >>= 1u64; + } else { + A += &y; + A >>= 1u64; + B -= &x; + B >>= 1u64; + } + } + // 5. While v is even do the following: + while v.is_even() { + // 5.1 v←v/2. + v >>= 1u64; + // 5.2 If C ≡ D ≡ 0 (mod 2) then C←C/2, D←D/2; otherwise, + // C←(C + y)/2, D←(D − x)/2. + if C.is_even() && D.is_even() { + C >>= 1u64; + D >>= 1u64; + } else { + C += &y; + C >>= 1u64; + D -= &x; + D >>= 1u64; + } + } + // 6. If u≥v then u←u−v, A←A−C,B←B−D; + // otherwise,v←v−u, C←C−A, D←D−B. + if u >= v { + u -= &v; + A -= &C; + B -= &D; + } else { + v -= &u; + C -= &A; + D -= &B; + } + // 7. If u = 0, then a←C, b←D, and return(a, b, g · v); + // otherwise, go to step 4. + if u.is_zero() { + return (C, D, v << gshift); + } + } + } + + fn gcd_is_one(&self, b: &$$uname) -> bool { + let mut u = self.clone(); + let mut v = b.clone(); + let one = $$uname::from(1u64); + + if u.is_zero() { + return v == one; + } + + if v.is_zero() { + return u == one; + } + + if u.is_even() && v.is_even() { + return false; + } + + while u.is_even() { + u >>= 1u64; + } + + loop { + while v.is_even() { + v >>= 1u64; + } + // u and v guaranteed to be odd right now. + if u > v { + // make sure that v > u, so that our subtraction works + // out. + let t = u; + u = v; + v = t; + } + v -= &u; + + if v.is_zero() { + return u == one; + } + } + } + } + |] \ No newline at end of file diff --git a/generation/src/Shift.hs b/generation/src/Shift.hs index 338802a..864a10e 100644 --- a/generation/src/Shift.hs +++ b/generation/src/Shift.hs @@ -1,5 +1,5 @@ {-# LANGUAGE QuasiQuotes #-} -module Shift(shiftOps) +module Shift(shiftOps, signedShiftOps) where import Data.Bits(shiftL,shiftR) @@ -26,6 +26,15 @@ shiftOps = File { testCase = Just generateTests } +signedShiftOps :: File +signedShiftOps = File { + predicate = \ _ _ -> True, + outputName = "sshift", + isUnsigned = False, + generator = declareSignedShiftOperators, + testCase = Just generateSignedTests +} + declareShiftOperators :: Word -> [Word] -> SourceFile Span declareShiftOperators bitsize _ = let struct_name = mkIdent ("U" ++ show bitsize) @@ -97,6 +106,68 @@ declareShiftOperators bitsize _ = } |] +declareSignedShiftOperators :: Word -> [Word] -> SourceFile Span +declareSignedShiftOperators bitsize _ = + let struct_name = mkIdent ("I" ++ show bitsize) + entries = bitsize `div` 64 + unsignedShifts = generateUnsigneds struct_name + shlUsizeImpls = generateBaseUsizes struct_name + shlActualImpl = concatMap actualShlImplLines [1..entries-1] + shrActualImpl = concatMap (actualShrImplLines entries) (reverse [0..entries-1]) + resAssign = map reassignSelf [0..entries-1] + testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty + in [sourceFile| + #[cfg(test)] + use core::convert::TryFrom; + use core::ops::{Shl,ShlAssign}; + use core::ops::{Shr,ShrAssign}; + #[cfg(test)] + use crate::CryptoNum; + #[cfg(test)] + use crate::testing::{build_test_path,run_test}; + use super::$$struct_name; + + impl ShlAssign for $$struct_name { + fn shl_assign(&mut self, rhs: usize) { + panic!("shl_assign") + } + } + + impl ShrAssign for $$struct_name { + fn shr_assign(&mut self, rhs: usize) { + panic!("shr_assign") + } + } + + $@{shlUsizeImpls} + $@{unsignedShifts} + + #[cfg(test)] + #[allow(non_snake_case)] + #[test] + fn KATs() { + run_test(build_test_path("shift", $$(testFileLit)), 4, |case| { + let (neg0, xbytes) = case.get("x").unwrap(); + let (neg1, sbytes) = case.get("s").unwrap(); + let (neg2, lbytes) = case.get("l").unwrap(); + let (neg3, rbytes) = case.get("r").unwrap(); + + assert!(!neg1); + let mut x = $$struct_name::from_bytes(xbytes); + let mut l = $$struct_name::from_bytes(lbytes); + let mut r = $$struct_name::from_bytes(rbytes); + + if neg0 { x = x.negate() } + if neg2 { l = l.negate() } + if neg3 { r = r.negate() } + let s = usize::try_from($$struct_name::from_bytes(sbytes)).unwrap(); + + assert_eq!(l, &x << s); + assert_eq!(r, &x >> s); + }); + } + |] + actualShlImplLines :: Word -> [Stmt Span] actualShlImplLines i = let basei = mkIdent ("base" ++ show i) @@ -224,3 +295,18 @@ generateTests size g = go g numTestCases tcase = Map.fromList [("x", showX x), ("s", showX s), ("l", showX l), ("r", showX r)] in tcase : go g2 (i - 1) + + +generateSignedTests :: RandomGen g => Word -> g -> [Map String String] +generateSignedTests size g = go g numTestCases + where + go _ 0 = [] + go g0 i = + let (x, g1) = generateSignedNum g0 size + (y, g2) = generateNum g1 size + s = y `mod` fromIntegral size + l = modulate (x `shiftL` fromIntegral s) size + r = modulate (x `shiftR` fromIntegral s) size + tcase = Map.fromList [("x", showX x), ("s", showX s), + ("l", showX l), ("r", showX r)] + in tcase : go g2 (i - 1) diff --git a/generation/src/Signed.hs b/generation/src/Signed.hs index 219d41c..c3e7d55 100644 --- a/generation/src/Signed.hs +++ b/generation/src/Signed.hs @@ -33,7 +33,13 @@ declareSigned bitsize _ = #[derive(Clone)] pub struct $$sname { - contents: $$uname, + pub(crate) contents: $$uname, + } + + impl $$sname { + pub fn is_negative(&self) -> bool { + self.contents.value[self.contents.value.len()-1] & 0x8000_0000_0000_0000 != 0 + } } impl Neg for $$sname { @@ -109,20 +115,6 @@ declareSigned bitsize _ = } } - impl PartialEq for $$sname { - fn eq(&self, other: &$$sname) -> bool { - let mut all_equal = true; - - for (l, r) in self.contents.value.iter().zip(other.contents.value.iter()) { - all_equal &= *l == *r; - } - - all_equal - } - } - - impl Eq for $$sname { } - impl Arbitrary for $$sname { fn arbitrary(g: &mut G) -> $$sname { $$sname{ @@ -163,14 +155,6 @@ declareSigned bitsize _ = #[cfg(test)] quickcheck! { - fn equality_reflexive(x: $$sname) -> bool { - &x == &x - } - - fn equality_symmetric(x: $$sname, y: $$sname) -> bool { - (&x == &y) == (&y == &x) - } - fn double_not(x: $$sname) -> bool { x == !!&x } diff --git a/generation/src/Subtract.hs b/generation/src/Subtract.hs index 56a7ddb..1aaa413 100644 --- a/generation/src/Subtract.hs +++ b/generation/src/Subtract.hs @@ -2,6 +2,8 @@ module Subtract( safeSubtractOps , unsafeSubtractOps + , safeSignedSubtractOps + , unsafeSignedSubtractOps ) where @@ -29,6 +31,15 @@ safeSubtractOps = File { testCase = Just generateSafeTests } +safeSignedSubtractOps :: File +safeSignedSubtractOps = File { + predicate = \ me others -> (me + 64) `elem` others, + outputName = "safe_ssub", + isUnsigned = True, + generator = declareSafeSignedSubtractOperators, + testCase = Just generateSafeSignedTests +} + unsafeSubtractOps :: File unsafeSubtractOps = File { predicate = \ _ _ -> True, @@ -38,6 +49,15 @@ unsafeSubtractOps = File { testCase = Just generateUnsafeTests } +unsafeSignedSubtractOps :: File +unsafeSignedSubtractOps = File { + predicate = \ _ _ -> True, + outputName = "unsafe_ssub", + isUnsigned = True, + generator = declareUnsafeSignedSubtractOperators, + testCase = Just generateUnsafeSignedTests +} + declareSafeSubtractOperators :: Word -> [Word] -> SourceFile Span declareSafeSubtractOperators bitsize _ = let sname = mkIdent ("U" ++ show bitsize) @@ -106,6 +126,70 @@ declareSafeSubtractOperators bitsize _ = } |] +declareSafeSignedSubtractOperators :: Word -> [Word] -> SourceFile Span +declareSafeSignedSubtractOperators bitsize _ = + let sname = mkIdent ("I" ++ show bitsize) + dname = mkIdent ("I" ++ show (bitsize + 64)) + fullRippleSubtract = makeRippleSubtracter True (bitsize `div` 64) "res" + testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty + in [sourceFile| + use core::ops::Sub; + use crate::CryptoNum; + #[cfg(test)] + use crate::testing::{build_test_path,run_test}; + use crate::signed::{$$sname,$$dname}; + + impl Sub for $$sname { + type Output = $$dname; + + fn sub(self, rhs: $$sname) -> $$dname { + &self - &rhs + } + } + + impl<'a> Sub<&'a $$sname> for $$sname { + type Output = $$dname; + + fn sub(self, rhs: &$$sname) -> $$dname { + &self - rhs + } + } + + impl<'a> Sub<$$sname> for &'a $$sname { + type Output = $$dname; + + fn sub(self, rhs: $$sname) -> $$dname { + self - &rhs + } + } + + impl<'a,'b> Sub<&'a $$sname> for &'b $$sname { + type Output = $$dname; + + fn sub(self, rhs: &$$sname) -> $$dname { + panic!("sub") + } + } + + #[cfg(test)] + #[allow(non_snake_case)] + #[test] + fn KATs() { + run_test(build_test_path("safe_sub", $$(testFileLit)), 3, |case| { + let (neg0, xbytes) = case.get("x").unwrap(); + let (neg1, ybytes) = case.get("y").unwrap(); + let (neg2, zbytes) = case.get("z").unwrap(); + + assert!(!neg0 && !neg1 && !neg2); + let x = $$sname::from_bytes(&xbytes); + let y = $$sname::from_bytes(&ybytes); + let z = $$dname::from_bytes(&zbytes); + + assert_eq!(z, x - y); + }); + } + |] + declareUnsafeSubtractOperators :: Word -> [Word] -> SourceFile Span declareUnsafeSubtractOperators bitsize _ = let sname = mkIdent ("U" ++ show bitsize) @@ -151,6 +235,52 @@ declareUnsafeSubtractOperators bitsize _ = } |] +declareUnsafeSignedSubtractOperators :: Word -> [Word] -> SourceFile Span +declareUnsafeSignedSubtractOperators bitsize _ = + let sname = mkIdent ("I" ++ show bitsize) + fullRippleSubtract = makeRippleSubtracter False (bitsize `div` 64) "self" + testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty + in [sourceFile| + use core::ops::SubAssign; + #[cfg(test)] + use crate::CryptoNum; + #[cfg(test)] + use crate::testing::{build_test_path,run_test}; + use crate::signed::$$sname; + + impl SubAssign for $$sname { + fn sub_assign(&mut self, rhs: Self) { + self.sub_assign(&rhs); + } + } + + impl<'a> SubAssign<&'a $$sname> for $$sname { + fn sub_assign(&mut self, rhs: &Self) { + panic!("sub_assign") + } + } + + #[cfg(test)] + #[allow(non_snake_case)] + #[test] + fn KATs() { + run_test(build_test_path("unsafe_ssub", $$(testFileLit)), 3, |case| { + let (neg0, xbytes) = case.get("x").unwrap(); + let (neg1, ybytes) = case.get("y").unwrap(); + let (neg2, zbytes) = case.get("z").unwrap(); + + let mut x = $$sname::from_bytes(&xbytes); + let mut y = $$sname::from_bytes(&ybytes); + let mut z = $$sname::from_bytes(&zbytes); + if neg0 { x = x.negate(); } + if neg1 { y = y.negate(); } + if neg2 { z = z.negate(); } + + x -= &y; + assert_eq!(z, x); + }); + } + |] makeRippleSubtracter :: Bool -> Word -> String -> [Stmt Span] makeRippleSubtracter useLastCarry inElems resName = @@ -206,6 +336,20 @@ generateSafeTests size g = go g numTestCases ("z", showX r)] in tcase : go g2 (i - 1) +generateSafeSignedTests :: RandomGen g => Word -> g -> [Map String String] +generateSafeSignedTests size g = go g numTestCases + where + go _ 0 = [] + go g0 i = + let (x, g1) = generateSignedNum g0 size + (y, g2) = generateSignedNum g1 size + r | x < y = (2 ^ (size + 64)) + (x - y) + | otherwise = x - y + tcase = Map.fromList [("x", showX x), + ("y", showX y), + ("z", showX r)] + in tcase : go g2 (i - 1) + generateUnsafeTests :: RandomGen g => Word -> g -> [Map String String] generateUnsafeTests size g = go g numTestCases where @@ -217,3 +361,15 @@ generateUnsafeTests size g = go g numTestCases tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX z)] in tcase : go g2 (i - 1) + +generateUnsafeSignedTests :: RandomGen g => Word -> g -> [Map String String] +generateUnsafeSignedTests size g = go g numTestCases + where + go _ 0 = [] + go g0 i = + let (x, g1) = generateSignedNum g0 size + (y, g2) = generateSignedNum g1 size + z = (x - y) .&. ((2 ^ size) - 1) + tcase = Map.fromList [("x", showX x), ("y", showX y), + ("z", showX z)] + in tcase : go g2 (i - 1)