{-# LANGUAGE QuasiQuotes #-}
-- | Generators for the Rust addition modules (@safe_add@, @unsafe_add@,
-- @safe_sadd@, @unsafe_sadd@) plus the random known-answer test cases used
-- to check them.
module Add(
    safeAddOps
  , unsafeAddOps
  , safeSignedAddOps
  , unsafeSignedAddOps
  )
 where

import Data.Bits((.&.))
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Generators
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Syntax
import RustModule
import System.Random(RandomGen)

-- | Widening unsigned addition: @U\<n\> + U\<n\> -> U\<n+64\>@, written to
-- @safe_add@. The predicate accepts size @n@ only when @n + 64@ appears in
-- the other sizes, and @n + 64@ is the suggested companion size.
safeAddOps :: RustModule
safeAddOps =
  RustModule
    { predicate = \me others -> (me + 64) `elem` others
    , suggested = \me -> [me + 64]
    , outputName = "safe_add"
    , isUnsigned = True
    , generator = declareSafeAddOperators
    , testCase = Just generateSafeTest
    }

-- | In-place unsigned addition (@U\<n\> += U\<n\>@) that drops the final
-- carry, written to @unsafe_add@. Applies to every size unconditionally.
unsafeAddOps :: RustModule
unsafeAddOps =
  RustModule
    { predicate = \_ _ -> True
    , suggested = const []
    , outputName = "unsafe_add"
    , isUnsigned = True
    , generator = declareUnsafeAddOperators
    , testCase = Just generateUnsafeTest
    }

-- | Widening signed addition: @I\<n\> + I\<n\> -> I\<n+64\>@, written to
-- @safe_sadd@. Like 'safeAddOps', it requires the 64-bit-wider size.
safeSignedAddOps :: RustModule
safeSignedAddOps =
  RustModule
    { predicate = \me others -> (me + 64) `elem` others
    , suggested = \me -> [me + 64]
    , outputName = "safe_sadd"
    , isUnsigned = False
    , generator = declareSafeSignedAddOperators
    , testCase = Just generateSafeSignedTest
    }

-- | In-place signed addition (@I\<n\> += I\<n\>@), written to
-- @unsafe_sadd@. Applies to every size unconditionally.
unsafeSignedAddOps :: RustModule
unsafeSignedAddOps =
  RustModule
    { predicate = \_ _ -> True
    , suggested = const []
    , outputName = "unsafe_sadd"
    , isUnsigned = False
    , generator = declareUnsafeSignedAddOperators
    , testCase = Just generateUnsafeSignedTest
    }

-- | Emit the @safe_add@ module for @U\<bitsize\>@.
--
-- All four owned/borrowed operand combinations implement 'Add' with output
-- @U\<bitsize+64\>@. The owned variants forward to the @&a + &b@ impl. That
-- impl zero-initialises the wide result and runs an unrolled ripple-carry
-- adder that keeps the final carry in the extra top limb. The module also
-- contains a QuickCheck commutativity property and a known-answer test. The
-- second argument (the other sizes) is ignored.
--
-- NOTE(review): the limb count is @bitsize `div` 64@, so a bitsize that is
-- not a multiple of 64 would silently drop its high bits. Callers presumably
-- only use multiples of 64.
declareSafeAddOperators :: Word -> [Word] -> SourceFile Span
declareSafeAddOperators bitsize _ =
  let sname = mkIdent ("U" ++ show bitsize)
      dname = mkIdent ("U" ++ show (bitsize + 64))
      fullRippleAdd = makeRippleAdder True (bitsize `div` 64) "res"
      testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
  in [sourceFile|
        use core::ops::Add;
        use crate::CryptoNum;
        #[cfg(test)]
        use crate::testing::{build_test_path,run_test};
        #[cfg(test)]
        use quickcheck::quickcheck;
        use crate::unsigned::{$$sname,$$dname};

        impl Add for $$sname {
            type Output = $$dname;

            fn add(self, rhs: $$sname) -> $$dname {
                &self + &rhs
            }
        }

        impl<'a> Add<&'a $$sname> for $$sname {
            type Output = $$dname;

            fn add(self, rhs: &$$sname) -> $$dname {
                &self + rhs
            }
        }

        impl<'a> Add<$$sname> for &'a $$sname {
            type Output = $$dname;

            fn add(self, rhs: $$sname) -> $$dname {
                self + &rhs
            }
        }

        impl<'a,'b> Add<&'a $$sname> for &'b $$sname {
            type Output = $$dname;

            fn add(self, rhs: &$$sname) -> $$dname {
                let mut res = $$dname::zero();

                $@{fullRippleAdd}
                res
            }
        }

        #[cfg(test)]
        quickcheck! {
            fn addition_symmetric(a: $$sname, b: $$sname) -> bool {
                (&a + &b) == (&b + &a)
            }
        }

        #[cfg(test)]
        #[allow(non_snake_case)]
        #[test]
        fn KATs() {
            run_test(build_test_path("safe_add", $$(testFileLit)), 3, |case| {
                let (neg0, xbytes) = case.get("x").unwrap();
                let (neg1, ybytes) = case.get("y").unwrap();
                let (neg2, zbytes) = case.get("z").unwrap();
                assert!(!neg0 && !neg1 && !neg2);

                let x = $$sname::from_bytes(&xbytes);
                let y = $$sname::from_bytes(&ybytes);
                let z = $$dname::from_bytes(&zbytes);
                assert_eq!(z, x + y);
            });
        }
      |]

-- | Emit the @unsafe_add@ module for @U\<bitsize\>@.
--
-- 'AddAssign' is implemented for owned and borrowed right-hand sides. The
-- owned form forwards to the borrowed one, which ripple-adds directly into
-- @self@ and discards the final carry (the sum wraps at the type's width).
-- The module also contains a QuickCheck commutativity property and a
-- known-answer test. The second argument is ignored.
declareUnsafeAddOperators :: Word -> [Word] -> SourceFile Span
declareUnsafeAddOperators bitsize _ =
  let sname = mkIdent ("U" ++ show bitsize)
      fullRippleAdd = makeRippleAdder False (bitsize `div` 64) "self"
      testFileLit = Lit [] (Str (testFile True bitsize) Cooked Unsuffixed mempty) mempty
  in [sourceFile|
        use core::ops::AddAssign;
        #[cfg(test)]
        use crate::CryptoNum;
        #[cfg(test)]
        use crate::testing::{build_test_path,run_test};
        #[cfg(test)]
        use quickcheck::quickcheck;
        use crate::unsigned::$$sname;

        impl AddAssign for $$sname {
            fn add_assign(&mut self, rhs: Self) {
                self.add_assign(&rhs);
            }
        }

        impl<'a> AddAssign<&'a $$sname> for $$sname {
            fn add_assign(&mut self, rhs: &Self) {
                $@{fullRippleAdd}
            }
        }

        #[cfg(test)]
        quickcheck! {
            fn addition_symmetric(a: $$sname, b: $$sname) -> bool {
                let mut side1 = a.clone();
                let mut side2 = b.clone();

                side1 += b;
                side2 += a;

                side1 == side2
            }
        }

        #[cfg(test)]
        #[allow(non_snake_case)]
        #[test]
        fn KATs() {
            run_test(build_test_path("unsafe_add", $$(testFileLit)), 3, |case| {
                let (neg0, xbytes) = case.get("x").unwrap();
                let (neg1, ybytes) = case.get("y").unwrap();
                let (neg2, zbytes) = case.get("z").unwrap();
                assert!(!neg0 && !neg1 && !neg2);

                let mut x = $$sname::from_bytes(&xbytes);
                let y = $$sname::from_bytes(&ybytes);
                let z = $$sname::from_bytes(&zbytes);
                x += &y;
                assert_eq!(z, x);
            });
        }
      |]

-- | Emit the @safe_sadd@ module for @I\<bitsize\>@.
--
-- All four operand combinations implement 'Add' with output
-- @I\<bitsize+64\>@. The @&a + &b@ impl converts both operands to the wide
-- type with @From@ and adds them there with @+=@.
--
-- NOTE(review): this presumably relies on @AddAssign@ for the wide type
-- (e.g. from @unsafe_sadd@) and on @From<&I\<n\>>@ existing elsewhere;
-- confirm. The KAT negates operands whose sign flag is set. The second
-- argument is ignored.
declareSafeSignedAddOperators :: Word -> [Word] -> SourceFile Span
declareSafeSignedAddOperators bitsize _ =
  let sname = mkIdent ("I" ++ show bitsize)
      dname = mkIdent ("I" ++ show (bitsize + 64))
      testFileLit = Lit [] (Str (testFile False bitsize) Cooked Unsuffixed mempty) mempty
  in [sourceFile|
        use core::ops::Add;
        #[cfg(test)]
        use crate::CryptoNum;
        #[cfg(test)]
        use crate::testing::{build_test_path,run_test};
        #[cfg(test)]
        use quickcheck::quickcheck;
        use crate::signed::{$$sname,$$dname};

        impl Add for $$sname {
            type Output = $$dname;

            fn add(self, rhs: $$sname) -> $$dname {
                &self + &rhs
            }
        }

        impl<'a> Add<&'a $$sname> for $$sname {
            type Output = $$dname;

            fn add(self, rhs: &$$sname) -> $$dname {
                &self + rhs
            }
        }

        impl<'a> Add<$$sname> for &'a $$sname {
            type Output = $$dname;

            fn add(self, rhs: $$sname) -> $$dname {
                self + &rhs
            }
        }

        impl<'a,'b> Add<&'a $$sname> for &'b $$sname {
            type Output = $$dname;

            fn add(self, rhs: &$$sname) -> $$dname {
                let mut res = $$dname::from(self);
                let bigrhs = $$dname::from(rhs);
                res += bigrhs;
                res
            }
        }

        #[cfg(test)]
        quickcheck! {
            fn addition_symmetric(a: $$sname, b: $$sname) -> bool {
                (&a + &b) == (&b + &a)
            }
        }

        #[cfg(test)]
        #[allow(non_snake_case)]
        #[test]
        fn KATs() {
            run_test(build_test_path("safe_sadd", $$(testFileLit)), 3, |case| {
                let (neg0, xbytes) = case.get("x").unwrap();
                let (neg1, ybytes) = case.get("y").unwrap();
                let (neg2, zbytes) = case.get("z").unwrap();

                let mut x = $$sname::from_bytes(&xbytes);
                let mut y = $$sname::from_bytes(&ybytes);
                let mut z = $$dname::from_bytes(&zbytes);
                if *neg0 { x = -x }
                if *neg1 { y = -y }
                if *neg2 { z = -z }
                assert_eq!(z, x + y);
            });
        }
      |]

-- | Emit the @unsafe_sadd@ module for @I\<bitsize\>@.
--
-- 'AddAssign' (owned and borrowed) delegates to @+=@ on the @contents@
-- field of both operands.
--
-- NOTE(review): @contents@ is presumably the unsigned backing value, so the
-- result wraps as two's complement; confirm against the signed type
-- definition. The second argument is ignored.
declareUnsafeSignedAddOperators :: Word -> [Word] -> SourceFile Span
declareUnsafeSignedAddOperators bitsize _ =
  let sname = mkIdent ("I" ++ show bitsize)
      testFileLit = Lit [] (Str (testFile False bitsize) Cooked Unsuffixed mempty) mempty
  in [sourceFile|
        use core::ops::AddAssign;
        #[cfg(test)]
        use crate::CryptoNum;
        #[cfg(test)]
        use crate::testing::{build_test_path,run_test};
        #[cfg(test)]
        use quickcheck::quickcheck;
        use crate::signed::$$sname;

        impl AddAssign for $$sname {
            fn add_assign(&mut self, rhs: Self) {
                self.contents += rhs.contents;
            }
        }

        impl<'a> AddAssign<&'a $$sname> for $$sname {
            fn add_assign(&mut self, rhs: &Self) {
                self.contents += &rhs.contents;
            }
        }

        #[cfg(test)]
        quickcheck! {
            fn addition_symmetric(a: $$sname, b: $$sname) -> bool {
                let mut side1 = a.clone();
                let mut side2 = b.clone();

                side1 += b;
                side2 += a;

                side1 == side2
            }
        }

        #[cfg(test)]
        #[allow(non_snake_case)]
        #[test]
        fn KATs() {
            run_test(build_test_path("unsafe_sadd", $$(testFileLit)), 3, |case| {
                let (neg0, xbytes) = case.get("x").unwrap();
                let (neg1, ybytes) = case.get("y").unwrap();
                let (neg2, zbytes) = case.get("z").unwrap();

                let mut x = $$sname::from_bytes(&xbytes);
                let mut y = $$sname::from_bytes(&ybytes);
                let mut z = $$sname::from_bytes(&zbytes);
                if *neg0 { x = -x }
                if *neg1 { y = -y }
                if *neg2 { z = -z }
                x += &y;
                assert_eq!(z, x);
            });
        }
      |]

-- | Build the unrolled ripple-carry adder statements for @inElems@ 64-bit
-- limbs.
--
-- All per-limb sum/carry computations (reading @self.value@ and
-- @rhs.value@) come first. The stores into @\<resName\>.value@ follow, so
-- adding in place into @self@ never reads a limb it has already
-- overwritten. When @useLastCarry@ is set, the final carry is stored into
-- limb @inElems@; otherwise it is discarded.
--
-- A zero-limb adder emits no statements. Without that case, @inElems - 1@
-- would underflow 'Word' and try to unroll @[0 .. maxBound]@ limbs.
makeRippleAdder :: Bool -> Word -> String -> [Stmt Span]
makeRippleAdder _ 0 _ = []
makeRippleAdder useLastCarry inElems resName =
  concatMap (generateRipples useLastCarry lastLimb) [0 .. lastLimb] ++
  concatMap (generateSetters useLastCarry inElems resName) [0 .. inElems]
  where
    lastLimb = inElems - 1

-- | Name of the Rust local holding the carry out of limb @i@.
carryName :: Word -> Ident
carryName i = mkIdent ("carry" ++ show i)

-- | Statements that add limb @i@ of @self@ and @rhs@.
--
-- Both limbs are widened to @u128@ and summed together with the incoming
-- carry (there is none for limb 0). The low 64 bits are kept in @res\<i\>@,
-- and the carry @sum \>\> 64@ goes into @carry\<i\>@. That carry is omitted
-- for the last limb (@lastI@) when @useLastCarry@ is false, since nothing
-- would read it.
generateRipples :: Bool -> Word -> Word -> [Stmt Span]
generateRipples useLastCarry lastI i =
    [ [stmt| let $$left = self.value[$$(liti)] as u128; |]
    , [stmt| let $$right = rhs.value[$$(liti)] as u128; |]
    , sumStmt
    , [stmt| let $$res = $$sumi as u64; |]
    ] ++ carryStmt
  where
    sumi = mkIdent ("sum" ++ show i)
    outCarry = carryName i
    res = mkIdent ("res" ++ show i)
    liti = toLit i
    left = mkIdent ("left" ++ show i)
    right = mkIdent ("right" ++ show i)
    -- Limb 0 has no incoming carry. Only the other limbs name carry(i - 1),
    -- which also keeps the Word subtraction from underflowing.
    sumStmt
      | i == 0 = [stmt| let $$sumi = $$left + $$right; |]
      | otherwise =
          let inCarry = carryName (i - 1)
          in [stmt| let $$sumi = $$left + $$right + $$inCarry; |]
    carryStmt
      | not useLastCarry && i == lastI = []
      | otherwise = [[stmt| let $$outCarry = $$sumi >> 64; |]]

-- | Statement that stores limb @i@ of the sum into @\<resName\>.value[i]@.
--
-- For @i < maxI@ this stores @res\<i\>@. For @i == maxI@ (one past the last
-- input limb) it stores the final carry when @useLastCarry@ is set, and
-- emits nothing otherwise.
generateSetters :: Bool -> Word -> String -> Word -> [Stmt Span]
generateSetters useLastCarry maxI resName i
  | i /= maxI =
      let res = mkIdent ("res" ++ show i)
      in [[stmt| $$target.value[$$(liti)] = $$res; |]]
  | useLastCarry =
      let res = carryName (i - 1)
      in [[stmt| $$target.value[$$(liti)] = $$res as u64; |]]
  | otherwise = []
  where
    target = mkIdent resName
    liti = toLit i

-- | Random KAT for @safe_add@: two @size@-bit operands from 'generateNum'
-- and their exact (unreduced) sum, keyed @"x"@, @"y"@, @"z"@.
generateSafeTest :: RandomGen g => Word -> g -> (Map String String, g)
generateSafeTest size g0 = (tcase, g2)
  where
    (x, g1) = generateNum g0 size
    (y, g2) = generateNum g1 size
    tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX (x + y))]

-- | Random KAT for @unsafe_add@: like 'generateSafeTest', but the sum is
-- masked to its low @size@ bits to match the discarded final carry.
generateUnsafeTest :: RandomGen g => Word -> g -> (Map String String, g)
generateUnsafeTest size g0 = (tcase, g2)
  where
    (x, g1) = generateNum g0 size
    (y, g2) = generateNum g1 size
    z = (x + y) .&. ((2 ^ size) - 1)
    tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX z)]

-- | Random KAT for @safe_sadd@: two signed @size@-bit operands from
-- 'generateSignedNum' and their exact sum.
generateSafeSignedTest :: RandomGen g => Word -> g -> (Map String String, g)
generateSafeSignedTest size g0 = (tcase, g2)
  where
    (x, g1) = generateSignedNum g0 size
    (y, g2) = generateSignedNum g1 size
    tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX (x + y))]

-- | Random KAT for @unsafe_sadd@: two signed operands and their sum masked
-- to the low @size@ bits.
--
-- NOTE(review): assuming 'generateSignedNum' yields 'Integer', @.&.@ treats
-- negative sums as infinite two's complement, so @z@ is the wrapped bit
-- pattern as a non-negative number.
generateUnsafeSignedTest :: RandomGen g => Word -> g -> (Map String String, g)
generateUnsafeSignedTest size g0 = (tcase, g2)
  where
    (x, g1) = generateSignedNum g0 size
    (y, g2) = generateSignedNum g1 size
    z = (x + y) .&. ((2 ^ size) - 1)
    tcase = Map.fromList [("x", showX x), ("y", showX y), ("z", showX z)]