Multiply works!
This commit is contained in:
@@ -9,6 +9,7 @@ import Conversions(conversions)
|
|||||||
import CryptoNum(cryptoNum)
|
import CryptoNum(cryptoNum)
|
||||||
import Control.Monad(forM_,unless)
|
import Control.Monad(forM_,unless)
|
||||||
import File(File,Task(..),generateTasks)
|
import File(File,Task(..),generateTasks)
|
||||||
|
import Multiply(safeMultiplyOps, unsafeMultiplyOps)
|
||||||
import Shift(shiftOps)
|
import Shift(shiftOps)
|
||||||
import Subtract(safeSubtractOps,unsafeSubtractOps)
|
import Subtract(safeSubtractOps,unsafeSubtractOps)
|
||||||
import System.Directory(createDirectoryIfMissing)
|
import System.Directory(createDirectoryIfMissing)
|
||||||
@@ -35,9 +36,11 @@ unsignedFiles = [
|
|||||||
, conversions
|
, conversions
|
||||||
, cryptoNum
|
, cryptoNum
|
||||||
, safeAddOps
|
, safeAddOps
|
||||||
|
, safeMultiplyOps
|
||||||
, safeSubtractOps
|
, safeSubtractOps
|
||||||
, shiftOps
|
, shiftOps
|
||||||
, unsafeAddOps
|
, unsafeAddOps
|
||||||
|
, unsafeMultiplyOps
|
||||||
, unsafeSubtractOps
|
, unsafeSubtractOps
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ extra-source-files: CHANGELOG.md
|
|||||||
library
|
library
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
ghc-options: -Wall
|
ghc-options: -Wall
|
||||||
build-depends: base >= 4.12.0.0,
|
build-depends: base,
|
||||||
containers,
|
containers,
|
||||||
directory,
|
directory,
|
||||||
filepath,
|
filepath,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ module Multiply(
|
|||||||
where
|
where
|
||||||
|
|
||||||
import Data.Bits((.&.))
|
import Data.Bits((.&.))
|
||||||
|
import Data.List(union)
|
||||||
import Data.Map.Strict(Map)
|
import Data.Map.Strict(Map)
|
||||||
import qualified Data.Map.Strict as Map
|
import qualified Data.Map.Strict as Map
|
||||||
import File
|
import File
|
||||||
@@ -32,7 +33,7 @@ safeMultiplyOps = File {
|
|||||||
|
|
||||||
unsafeMultiplyOps :: File
|
unsafeMultiplyOps :: File
|
||||||
unsafeMultiplyOps = File {
|
unsafeMultiplyOps = File {
|
||||||
predicate = \ _ _ -> False,
|
predicate = \ _ _ -> True,
|
||||||
outputName = "unsafe_mul",
|
outputName = "unsafe_mul",
|
||||||
isUnsigned = True,
|
isUnsigned = True,
|
||||||
generator = declareUnsafeMulOperators,
|
generator = declareUnsafeMulOperators,
|
||||||
@@ -43,7 +44,7 @@ declareSafeMulOperators :: Word -> SourceFile Span
|
|||||||
declareSafeMulOperators bitsize =
|
declareSafeMulOperators bitsize =
|
||||||
let sname = mkIdent ("U" ++ show bitsize)
|
let sname = mkIdent ("U" ++ show bitsize)
|
||||||
dname = mkIdent ("U" ++ show (bitsize * 2))
|
dname = mkIdent ("U" ++ show (bitsize * 2))
|
||||||
fullRippleMul = undefined True (bitsize `div` 64) "res"
|
fullRippleMul = generateMultiplier True (bitsize `div` 64) "rhs" "res"
|
||||||
testFileLit = Lit [] (Str (testFile bitsize) Cooked Unsuffixed mempty) mempty
|
testFileLit = Lit [] (Str (testFile bitsize) Cooked Unsuffixed mempty) mempty
|
||||||
in [sourceFile|
|
in [sourceFile|
|
||||||
use core::ops::Mul;
|
use core::ops::Mul;
|
||||||
@@ -58,7 +59,7 @@ declareSafeMulOperators bitsize =
|
|||||||
type Output = $$dname;
|
type Output = $$dname;
|
||||||
|
|
||||||
fn mul(self, rhs: $$sname) -> $$dname {
|
fn mul(self, rhs: $$sname) -> $$dname {
|
||||||
&self + &rhs
|
&self * &rhs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,7 +67,7 @@ declareSafeMulOperators bitsize =
|
|||||||
type Output = $$dname;
|
type Output = $$dname;
|
||||||
|
|
||||||
fn mul(self, rhs: &$$sname) -> $$dname {
|
fn mul(self, rhs: &$$sname) -> $$dname {
|
||||||
&self + rhs
|
&self * rhs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,7 +75,7 @@ declareSafeMulOperators bitsize =
|
|||||||
type Output = $$dname;
|
type Output = $$dname;
|
||||||
|
|
||||||
fn mul(self, rhs: $$sname) -> $$dname {
|
fn mul(self, rhs: $$sname) -> $$dname {
|
||||||
self + &rhs
|
self * &rhs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -111,16 +112,100 @@ declareSafeMulOperators bitsize =
|
|||||||
let y = $$sname::from_bytes(&ybytes);
|
let y = $$sname::from_bytes(&ybytes);
|
||||||
let z = $$dname::from_bytes(&zbytes);
|
let z = $$dname::from_bytes(&zbytes);
|
||||||
|
|
||||||
assert_eq!(z, x + y);
|
assert_eq!(z, x * y);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|]
|
|]
|
||||||
|
|
||||||
declareUnsafeMulOperators :: Word -> SourceFile Span
|
declareUnsafeMulOperators :: Word -> SourceFile Span
|
||||||
declareUnsafeMulOperators bitsize = undefined bitsize
|
declareUnsafeMulOperators bitsize =
|
||||||
|
let sname = mkIdent ("U" ++ show bitsize)
|
||||||
|
halfRippleMul = generateMultiplier False (bitsize `div` 64) "rhs" "self"
|
||||||
|
testFileLit = Lit [] (Str (testFile bitsize) Cooked Unsuffixed mempty) mempty
|
||||||
|
in [sourceFile|
|
||||||
|
use core::ops::MulAssign;
|
||||||
|
#[cfg(test)]
|
||||||
|
use crate::CryptoNum;
|
||||||
|
#[cfg(test)]
|
||||||
|
use crate::testing::{build_test_path,run_test};
|
||||||
|
#[cfg(test)]
|
||||||
|
use quickcheck::quickcheck;
|
||||||
|
use crate::unsigned::$$sname;
|
||||||
|
|
||||||
|
impl MulAssign for $$sname {
|
||||||
|
fn mul_assign(&mut self, rhs: $$sname) {
|
||||||
|
self.mul_assign(&rhs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> MulAssign<&'a $$sname> for $$sname {
|
||||||
|
fn mul_assign(&mut self, rhs: &$$sname) {
|
||||||
|
$@{halfRippleMul}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
quickcheck! {
|
||||||
|
fn multiplication_symmetric(a: $$sname, b: $$sname) -> bool {
|
||||||
|
let a2 = a.clone();
|
||||||
|
let mut b2 = b.clone();
|
||||||
|
let mut a3 = a.clone();
|
||||||
|
let b3 = b.clone();
|
||||||
|
|
||||||
|
b2 *= &a2;
|
||||||
|
a3 *= b3;
|
||||||
|
|
||||||
|
a3 == b2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
#[test]
|
||||||
|
fn KATs() {
|
||||||
|
run_test(build_test_path("unsafe_mul", $$(testFileLit)), 3, |case| {
|
||||||
|
let (neg0, xbytes) = case.get("x").unwrap();
|
||||||
|
let (neg1, ybytes) = case.get("y").unwrap();
|
||||||
|
let (neg2, zbytes) = case.get("z").unwrap();
|
||||||
|
|
||||||
|
assert!(!neg0 && !neg1 && !neg2);
|
||||||
|
let mut x = $$sname::from_bytes(&xbytes);
|
||||||
|
let y = $$sname::from_bytes(&ybytes);
|
||||||
|
let z = $$sname::from_bytes(&zbytes);
|
||||||
|
|
||||||
|
x *= y;
|
||||||
|
assert_eq!(z, x);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|]
|
||||||
|
|
||||||
|
|
||||||
-- -----------------------------------------------------------------------------
|
-- -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
generateMultiplier :: Bool -> Word -> String -> String -> [Stmt Span]
|
||||||
|
generateMultiplier fullmul size inName outName =
|
||||||
|
let readIns = map (load "self" "x") [0..size-1] ++
|
||||||
|
map (load inName "y") [0..size-1]
|
||||||
|
instructions = releaseUnnecessary outVars (generateInstructions size)
|
||||||
|
outDigits | fullmul = 2 * size
|
||||||
|
| otherwise = size
|
||||||
|
outVars = map (("res" ++) . show) [0..outDigits-1]
|
||||||
|
operations = map translateInstruction instructions
|
||||||
|
writeOuts = map (store "res") [0..outDigits-1]
|
||||||
|
in readIns ++ operations ++ writeOuts
|
||||||
|
where
|
||||||
|
load rhs vname i =
|
||||||
|
let liti = toLit i
|
||||||
|
vec = mkIdent rhs
|
||||||
|
var = mkIdent (vname ++ show i)
|
||||||
|
in [stmt| let $$var = $$vec.value[$$(liti)]; |]
|
||||||
|
store vname i =
|
||||||
|
let liti = toLit i
|
||||||
|
vec = mkIdent outName
|
||||||
|
var = mkIdent (vname ++ show i)
|
||||||
|
in [stmt| $$vec.value[$$(liti)] = $$var; |]
|
||||||
|
|
||||||
|
|
||||||
translateInstruction :: Instruction -> Stmt Span
|
translateInstruction :: Instruction -> Stmt Span
|
||||||
translateInstruction instr =
|
translateInstruction instr =
|
||||||
case instr of
|
case instr of
|
||||||
@@ -169,12 +254,51 @@ translateInstruction instr =
|
|||||||
val = toLit (fromIntegral amt)
|
val = toLit (fromIntegral amt)
|
||||||
in [stmt| let $$outid: u128 = $$inid >> $$(val); |]
|
in [stmt| let $$outid: u128 = $$inid >> $$(val); |]
|
||||||
|
|
||||||
|
releaseUnnecessary :: [String] -> [Instruction] -> [Instruction]
|
||||||
|
releaseUnnecessary outkeys instrs = snd (foldl check (outkeys, []) (reverse instrs))
|
||||||
|
where
|
||||||
|
check acc@(required, rest) cur
|
||||||
|
| outVar cur `elem` required = (union (inVars cur) required, cur : rest)
|
||||||
|
| otherwise = acc
|
||||||
|
|
||||||
|
outVar :: Instruction -> String
|
||||||
|
outVar instr =
|
||||||
|
case instr of
|
||||||
|
Add outname _ -> outname
|
||||||
|
CastDown outname _ -> outname
|
||||||
|
CastUp outname _ -> outname
|
||||||
|
Complement outname _ -> outname
|
||||||
|
Declare64 outname _ -> outname
|
||||||
|
Declare128 outname _ -> outname
|
||||||
|
Mask outname _ _ -> outname
|
||||||
|
Multiply outname _ -> outname
|
||||||
|
ShiftR outname _ _ -> outname
|
||||||
|
|
||||||
|
inVars :: Instruction -> [String]
|
||||||
|
inVars instr =
|
||||||
|
case instr of
|
||||||
|
Add _ args -> args
|
||||||
|
CastDown _ arg -> [arg]
|
||||||
|
CastUp _ arg -> [arg]
|
||||||
|
Complement _ arg -> [arg]
|
||||||
|
Declare64 _ _ -> []
|
||||||
|
Declare128 _ _ -> []
|
||||||
|
Mask _ arg _ -> [arg]
|
||||||
|
Multiply _ args -> args
|
||||||
|
ShiftR _ arg _ -> [arg]
|
||||||
|
|
||||||
-- -----------------------------------------------------------------------------
|
-- -----------------------------------------------------------------------------
|
||||||
|
|
||||||
generateSafeTests :: RandomGen g => Word -> g -> [Map String String]
|
generateSafeTests :: RandomGen g => Word -> g -> [Map String String]
|
||||||
generateSafeTests size g = go g numTestCases
|
generateSafeTests size g = go g numTestCases
|
||||||
where
|
where
|
||||||
go _ 0 = []
|
go _ 0 = [
|
||||||
|
Map.fromList [("x", "0"), ("y", "0"), ("z", "0")]
|
||||||
|
, (let x = (2 ^ size) - 1
|
||||||
|
y = (2 ^ size) - 1
|
||||||
|
z = x * y
|
||||||
|
in Map.fromList [("x", showX x), ("y", showX y), ("z", showX z)])
|
||||||
|
]
|
||||||
go g0 i =
|
go g0 i =
|
||||||
let (x, g1) = generateNum g0 size
|
let (x, g1) = generateNum g0 size
|
||||||
(y, g2) = generateNum g1 size
|
(y, g2) = generateNum g1 size
|
||||||
|
|||||||
Reference in New Issue
Block a user