Multiply works!
This commit is contained in:
@@ -6,6 +6,7 @@ module Multiply(
|
||||
where
|
||||
|
||||
import Data.Bits((.&.))
|
||||
import Data.List(union)
|
||||
import Data.Map.Strict(Map)
|
||||
import qualified Data.Map.Strict as Map
|
||||
import File
|
||||
@@ -32,7 +33,7 @@ safeMultiplyOps = File {
|
||||
|
||||
unsafeMultiplyOps :: File
|
||||
unsafeMultiplyOps = File {
|
||||
predicate = \ _ _ -> False,
|
||||
predicate = \ _ _ -> True,
|
||||
outputName = "unsafe_mul",
|
||||
isUnsigned = True,
|
||||
generator = declareUnsafeMulOperators,
|
||||
@@ -43,7 +44,7 @@ declareSafeMulOperators :: Word -> SourceFile Span
|
||||
declareSafeMulOperators bitsize =
|
||||
let sname = mkIdent ("U" ++ show bitsize)
|
||||
dname = mkIdent ("U" ++ show (bitsize * 2))
|
||||
fullRippleMul = undefined True (bitsize `div` 64) "res"
|
||||
fullRippleMul = generateMultiplier True (bitsize `div` 64) "rhs" "res"
|
||||
testFileLit = Lit [] (Str (testFile bitsize) Cooked Unsuffixed mempty) mempty
|
||||
in [sourceFile|
|
||||
use core::ops::Mul;
|
||||
@@ -58,7 +59,7 @@ declareSafeMulOperators bitsize =
|
||||
type Output = $$dname;
|
||||
|
||||
fn mul(self, rhs: $$sname) -> $$dname {
|
||||
&self + &rhs
|
||||
&self * &rhs
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +67,7 @@ declareSafeMulOperators bitsize =
|
||||
type Output = $$dname;
|
||||
|
||||
fn mul(self, rhs: &$$sname) -> $$dname {
|
||||
&self + rhs
|
||||
&self * rhs
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +75,7 @@ declareSafeMulOperators bitsize =
|
||||
type Output = $$dname;
|
||||
|
||||
fn mul(self, rhs: $$sname) -> $$dname {
|
||||
self + &rhs
|
||||
self * &rhs
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,16 +112,100 @@ declareSafeMulOperators bitsize =
|
||||
let y = $$sname::from_bytes(&ybytes);
|
||||
let z = $$dname::from_bytes(&zbytes);
|
||||
|
||||
assert_eq!(z, x + y);
|
||||
assert_eq!(z, x * y);
|
||||
});
|
||||
}
|
||||
|]
|
||||
|
||||
declareUnsafeMulOperators :: Word -> SourceFile Span
|
||||
declareUnsafeMulOperators bitsize = undefined bitsize
|
||||
declareUnsafeMulOperators bitsize =
|
||||
let sname = mkIdent ("U" ++ show bitsize)
|
||||
halfRippleMul = generateMultiplier False (bitsize `div` 64) "rhs" "self"
|
||||
testFileLit = Lit [] (Str (testFile bitsize) Cooked Unsuffixed mempty) mempty
|
||||
in [sourceFile|
|
||||
use core::ops::MulAssign;
|
||||
#[cfg(test)]
|
||||
use crate::CryptoNum;
|
||||
#[cfg(test)]
|
||||
use crate::testing::{build_test_path,run_test};
|
||||
#[cfg(test)]
|
||||
use quickcheck::quickcheck;
|
||||
use crate::unsigned::$$sname;
|
||||
|
||||
impl MulAssign for $$sname {
|
||||
fn mul_assign(&mut self, rhs: $$sname) {
|
||||
self.mul_assign(&rhs);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> MulAssign<&'a $$sname> for $$sname {
|
||||
fn mul_assign(&mut self, rhs: &$$sname) {
|
||||
$@{halfRippleMul}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
quickcheck! {
|
||||
fn multiplication_symmetric(a: $$sname, b: $$sname) -> bool {
|
||||
let a2 = a.clone();
|
||||
let mut b2 = b.clone();
|
||||
let mut a3 = a.clone();
|
||||
let b3 = b.clone();
|
||||
|
||||
b2 *= &a2;
|
||||
a3 *= b3;
|
||||
|
||||
a3 == b2
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(non_snake_case)]
|
||||
#[test]
|
||||
fn KATs() {
|
||||
run_test(build_test_path("unsafe_mul", $$(testFileLit)), 3, |case| {
|
||||
let (neg0, xbytes) = case.get("x").unwrap();
|
||||
let (neg1, ybytes) = case.get("y").unwrap();
|
||||
let (neg2, zbytes) = case.get("z").unwrap();
|
||||
|
||||
assert!(!neg0 && !neg1 && !neg2);
|
||||
let mut x = $$sname::from_bytes(&xbytes);
|
||||
let y = $$sname::from_bytes(&ybytes);
|
||||
let z = $$sname::from_bytes(&zbytes);
|
||||
|
||||
x *= y;
|
||||
assert_eq!(z, x);
|
||||
});
|
||||
}
|
||||
|]
|
||||
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
generateMultiplier :: Bool -> Word -> String -> String -> [Stmt Span]
|
||||
generateMultiplier fullmul size inName outName =
|
||||
let readIns = map (load "self" "x") [0..size-1] ++
|
||||
map (load inName "y") [0..size-1]
|
||||
instructions = releaseUnnecessary outVars (generateInstructions size)
|
||||
outDigits | fullmul = 2 * size
|
||||
| otherwise = size
|
||||
outVars = map (("res" ++) . show) [0..outDigits-1]
|
||||
operations = map translateInstruction instructions
|
||||
writeOuts = map (store "res") [0..outDigits-1]
|
||||
in readIns ++ operations ++ writeOuts
|
||||
where
|
||||
load rhs vname i =
|
||||
let liti = toLit i
|
||||
vec = mkIdent rhs
|
||||
var = mkIdent (vname ++ show i)
|
||||
in [stmt| let $$var = $$vec.value[$$(liti)]; |]
|
||||
store vname i =
|
||||
let liti = toLit i
|
||||
vec = mkIdent outName
|
||||
var = mkIdent (vname ++ show i)
|
||||
in [stmt| $$vec.value[$$(liti)] = $$var; |]
|
||||
|
||||
|
||||
translateInstruction :: Instruction -> Stmt Span
|
||||
translateInstruction instr =
|
||||
case instr of
|
||||
@@ -169,12 +254,51 @@ translateInstruction instr =
|
||||
val = toLit (fromIntegral amt)
|
||||
in [stmt| let $$outid: u128 = $$inid >> $$(val); |]
|
||||
|
||||
releaseUnnecessary :: [String] -> [Instruction] -> [Instruction]
|
||||
releaseUnnecessary outkeys instrs = snd (foldl check (outkeys, []) (reverse instrs))
|
||||
where
|
||||
check acc@(required, rest) cur
|
||||
| outVar cur `elem` required = (union (inVars cur) required, cur : rest)
|
||||
| otherwise = acc
|
||||
|
||||
outVar :: Instruction -> String
|
||||
outVar instr =
|
||||
case instr of
|
||||
Add outname _ -> outname
|
||||
CastDown outname _ -> outname
|
||||
CastUp outname _ -> outname
|
||||
Complement outname _ -> outname
|
||||
Declare64 outname _ -> outname
|
||||
Declare128 outname _ -> outname
|
||||
Mask outname _ _ -> outname
|
||||
Multiply outname _ -> outname
|
||||
ShiftR outname _ _ -> outname
|
||||
|
||||
inVars :: Instruction -> [String]
|
||||
inVars instr =
|
||||
case instr of
|
||||
Add _ args -> args
|
||||
CastDown _ arg -> [arg]
|
||||
CastUp _ arg -> [arg]
|
||||
Complement _ arg -> [arg]
|
||||
Declare64 _ _ -> []
|
||||
Declare128 _ _ -> []
|
||||
Mask _ arg _ -> [arg]
|
||||
Multiply _ args -> args
|
||||
ShiftR _ arg _ -> [arg]
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
generateSafeTests :: RandomGen g => Word -> g -> [Map String String]
|
||||
generateSafeTests size g = go g numTestCases
|
||||
where
|
||||
go _ 0 = []
|
||||
go _ 0 = [
|
||||
Map.fromList [("x", "0"), ("y", "0"), ("z", "0")]
|
||||
, (let x = (2 ^ size) - 1
|
||||
y = (2 ^ size) - 1
|
||||
z = x * y
|
||||
in Map.fromList [("x", showX x), ("y", showX y), ("z", showX z)])
|
||||
]
|
||||
go g0 i =
|
||||
let (x, g1) = generateNum g0 size
|
||||
(y, g2) = generateNum g1 size
|
||||
|
||||
Reference in New Issue
Block a user