diff --git a/generation/Main.hs b/generation/Main.hs index c7405f4..d1f9772 100644 --- a/generation/Main.hs +++ b/generation/Main.hs @@ -9,6 +9,7 @@ import Conversions(conversions) import CryptoNum(cryptoNum) import Control.Monad(forM_,unless) import File(File,Task(..),generateTasks) +import Multiply(safeMultiplyOps, unsafeMultiplyOps) import Shift(shiftOps) import Subtract(safeSubtractOps,unsafeSubtractOps) import System.Directory(createDirectoryIfMissing) @@ -35,9 +36,11 @@ unsignedFiles = [ , conversions , cryptoNum , safeAddOps + , safeMultiplyOps , safeSubtractOps , shiftOps , unsafeAddOps + , unsafeMultiplyOps , unsafeSubtractOps ] diff --git a/generation/generation.cabal b/generation/generation.cabal index 86fdcdf..a5a907a 100644 --- a/generation/generation.cabal +++ b/generation/generation.cabal @@ -18,7 +18,7 @@ extra-source-files: CHANGELOG.md library default-language: Haskell2010 ghc-options: -Wall - build-depends: base >= 4.12.0.0, + build-depends: base, containers, directory, filepath, diff --git a/generation/src/Multiply.hs b/generation/src/Multiply.hs index 694b70f..11ee4b3 100644 --- a/generation/src/Multiply.hs +++ b/generation/src/Multiply.hs @@ -6,6 +6,7 @@ module Multiply( where import Data.Bits((.&.)) +import Data.List(union) import Data.Map.Strict(Map) import qualified Data.Map.Strict as Map import File @@ -32,7 +33,7 @@ safeMultiplyOps = File { unsafeMultiplyOps :: File unsafeMultiplyOps = File { - predicate = \ _ _ -> False, + predicate = \ _ _ -> True, outputName = "unsafe_mul", isUnsigned = True, generator = declareUnsafeMulOperators, @@ -43,7 +44,7 @@ declareSafeMulOperators :: Word -> SourceFile Span declareSafeMulOperators bitsize = let sname = mkIdent ("U" ++ show bitsize) dname = mkIdent ("U" ++ show (bitsize * 2)) - fullRippleMul = undefined True (bitsize `div` 64) "res" + fullRippleMul = generateMultiplier True (bitsize `div` 64) "rhs" "res" testFileLit = Lit [] (Str (testFile bitsize) Cooked Unsuffixed mempty) mempty in [sourceFile| use core::ops::Mul; @@ -58,7 +59,7 @@ declareSafeMulOperators bitsize = type Output = $$dname; fn mul(self, rhs: $$sname) -> $$dname { - &self + &rhs + &self * &rhs } } @@ -66,7 +67,7 @@ declareSafeMulOperators bitsize = type Output = $$dname; fn mul(self, rhs: &$$sname) -> $$dname { - &self + rhs + &self * rhs } } @@ -74,7 +75,7 @@ declareSafeMulOperators bitsize = type Output = $$dname; fn mul(self, rhs: $$sname) -> $$dname { - self + &rhs + self * &rhs } } @@ -111,16 +112,100 @@ declareSafeMulOperators bitsize = let y = $$sname::from_bytes(&ybytes); let z = $$dname::from_bytes(&zbytes); - assert_eq!(z, x + y); + assert_eq!(z, x * y); }); } |] declareUnsafeMulOperators :: Word -> SourceFile Span -declareUnsafeMulOperators bitsize = undefined bitsize +declareUnsafeMulOperators bitsize = + let sname = mkIdent ("U" ++ show bitsize) + halfRippleMul = generateMultiplier False (bitsize `div` 64) "rhs" "self" + testFileLit = Lit [] (Str (testFile bitsize) Cooked Unsuffixed mempty) mempty + in [sourceFile| + use core::ops::MulAssign; + #[cfg(test)] + use crate::CryptoNum; + #[cfg(test)] + use crate::testing::{build_test_path,run_test}; + #[cfg(test)] + use quickcheck::quickcheck; + use crate::unsigned::$$sname; + + impl MulAssign for $$sname { + fn mul_assign(&mut self, rhs: $$sname) { + self.mul_assign(&rhs); + } + } + + impl<'a> MulAssign<&'a $$sname> for $$sname { + fn mul_assign(&mut self, rhs: &$$sname) { + $@{halfRippleMul} + } + } + + #[cfg(test)] + quickcheck! { + fn multiplication_symmetric(a: $$sname, b: $$sname) -> bool { + let a2 = a.clone(); + let mut b2 = b.clone(); + let mut a3 = a.clone(); + let b3 = b.clone(); + + b2 *= &a2; + a3 *= b3; + + a3 == b2 + } + } + + #[cfg(test)] + #[allow(non_snake_case)] + #[test] + fn KATs() { + run_test(build_test_path("unsafe_mul", $$(testFileLit)), 3, |case| { + let (neg0, xbytes) = case.get("x").unwrap(); + let (neg1, ybytes) = case.get("y").unwrap(); + let (neg2, zbytes) = case.get("z").unwrap(); + + assert!(!neg0 && !neg1 && !neg2); + let mut x = $$sname::from_bytes(&xbytes); + let y = $$sname::from_bytes(&ybytes); + let z = $$sname::from_bytes(&zbytes); + + x *= y; + assert_eq!(z, x); + }); + } + |] + -- ----------------------------------------------------------------------------- +generateMultiplier :: Bool -> Word -> String -> String -> [Stmt Span] +generateMultiplier fullmul size inName outName = + let readIns = map (load "self" "x") [0..size-1] ++ + map (load inName "y") [0..size-1] + instructions = releaseUnnecessary outVars (generateInstructions size) + outDigits | fullmul = 2 * size + | otherwise = size + outVars = map (("res" ++) . show) [0..outDigits-1] + operations = map translateInstruction instructions + writeOuts = map (store "res") [0..outDigits-1] + in readIns ++ operations ++ writeOuts + where + load rhs vname i = + let liti = toLit i + vec = mkIdent rhs + var = mkIdent (vname ++ show i) + in [stmt| let $$var = $$vec.value[$$(liti)]; |] + store vname i = + let liti = toLit i + vec = mkIdent outName + var = mkIdent (vname ++ show i) + in [stmt| $$vec.value[$$(liti)] = $$var; |] + + translateInstruction :: Instruction -> Stmt Span translateInstruction instr = case instr of @@ -169,12 +254,51 @@ translateInstruction instr = val = toLit (fromIntegral amt) in [stmt| let $$outid: u128 = $$inid >> $$(val); |] +releaseUnnecessary :: [String] -> [Instruction] -> [Instruction] +releaseUnnecessary outkeys instrs = snd (foldl check (outkeys, []) (reverse instrs)) + where + check acc@(required, rest) cur + | outVar cur `elem` required = (union (inVars cur) required, cur : rest) + | otherwise = acc + +outVar :: Instruction -> String +outVar instr = + case instr of + Add outname _ -> outname + CastDown outname _ -> outname + CastUp outname _ -> outname + Complement outname _ -> outname + Declare64 outname _ -> outname + Declare128 outname _ -> outname + Mask outname _ _ -> outname + Multiply outname _ -> outname + ShiftR outname _ _ -> outname + +inVars :: Instruction -> [String] +inVars instr = + case instr of + Add _ args -> args + CastDown _ arg -> [arg] + CastUp _ arg -> [arg] + Complement _ arg -> [arg] + Declare64 _ _ -> [] + Declare128 _ _ -> [] + Mask _ arg _ -> [arg] + Multiply _ args -> args + ShiftR _ arg _ -> [arg] + -- ----------------------------------------------------------------------------- generateSafeTests :: RandomGen g => Word -> g -> [Map String String] generateSafeTests size g = go g numTestCases where - go _ 0 = [] + go _ 0 = [ + Map.fromList [("x", "0"), ("y", "0"), ("z", "0")] + , (let x = (2 ^ size) - 1 + y = (2 ^ size) - 1 + z = x * y + in Map.fromList [("x", showX x), ("y", showX y), ("z", showX z)]) + ] go g0 i = let (x, g1) = generateNum g0 size (y, g2) = generateNum g1 size