Multiply works!

This commit is contained in:
2020-01-06 12:25:38 -08:00
parent d8c752fad3
commit a35d0df6da
3 changed files with 136 additions and 9 deletions

View File

@@ -9,6 +9,7 @@ import Conversions(conversions)
import CryptoNum(cryptoNum)
import Control.Monad(forM_,unless)
import File(File,Task(..),generateTasks)
import Multiply(safeMultiplyOps, unsafeMultiplyOps)
import Shift(shiftOps)
import Subtract(safeSubtractOps,unsafeSubtractOps)
import System.Directory(createDirectoryIfMissing)
@@ -35,9 +36,11 @@ unsignedFiles = [
, conversions
, cryptoNum
, safeAddOps
, safeMultiplyOps
, safeSubtractOps
, shiftOps
, unsafeAddOps
, unsafeMultiplyOps
, unsafeSubtractOps
]

View File

@@ -18,7 +18,7 @@ extra-source-files: CHANGELOG.md
library
default-language: Haskell2010
ghc-options: -Wall
build-depends: base >= 4.12.0.0,
build-depends: base,
containers,
directory,
filepath,

View File

@@ -6,6 +6,7 @@ module Multiply(
where
import Data.Bits((.&.))
import Data.List(union)
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import File
@@ -32,7 +33,7 @@ safeMultiplyOps = File {
unsafeMultiplyOps :: File
unsafeMultiplyOps = File {
predicate = \ _ _ -> False,
predicate = \ _ _ -> True,
outputName = "unsafe_mul",
isUnsigned = True,
generator = declareUnsafeMulOperators,
@@ -43,7 +44,7 @@ declareSafeMulOperators :: Word -> SourceFile Span
declareSafeMulOperators bitsize =
let sname = mkIdent ("U" ++ show bitsize)
dname = mkIdent ("U" ++ show (bitsize * 2))
fullRippleMul = undefined True (bitsize `div` 64) "res"
fullRippleMul = generateMultiplier True (bitsize `div` 64) "rhs" "res"
testFileLit = Lit [] (Str (testFile bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::Mul;
@@ -58,7 +59,7 @@ declareSafeMulOperators bitsize =
type Output = $$dname;
fn mul(self, rhs: $$sname) -> $$dname {
&self + &rhs
&self * &rhs
}
}
@@ -66,7 +67,7 @@ declareSafeMulOperators bitsize =
type Output = $$dname;
fn mul(self, rhs: &$$sname) -> $$dname {
&self + rhs
&self * rhs
}
}
@@ -74,7 +75,7 @@ declareSafeMulOperators bitsize =
type Output = $$dname;
fn mul(self, rhs: $$sname) -> $$dname {
self + &rhs
self * &rhs
}
}
@@ -111,16 +112,100 @@ declareSafeMulOperators bitsize =
let y = $$sname::from_bytes(&ybytes);
let z = $$dname::from_bytes(&zbytes);
assert_eq!(z, x + y);
assert_eq!(z, x * y);
});
}
|]
declareUnsafeMulOperators :: Word -> SourceFile Span
declareUnsafeMulOperators bitsize = undefined bitsize
declareUnsafeMulOperators bitsize =
let sname = mkIdent ("U" ++ show bitsize)
halfRippleMul = generateMultiplier False (bitsize `div` 64) "rhs" "self"
testFileLit = Lit [] (Str (testFile bitsize) Cooked Unsuffixed mempty) mempty
in [sourceFile|
use core::ops::MulAssign;
#[cfg(test)]
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
#[cfg(test)]
use quickcheck::quickcheck;
use crate::unsigned::$$sname;
impl MulAssign for $$sname {
fn mul_assign(&mut self, rhs: $$sname) {
self.mul_assign(&rhs);
}
}
impl<'a> MulAssign<&'a $$sname> for $$sname {
fn mul_assign(&mut self, rhs: &$$sname) {
$@{halfRippleMul}
}
}
#[cfg(test)]
quickcheck! {
fn multiplication_symmetric(a: $$sname, b: $$sname) -> bool {
let a2 = a.clone();
let mut b2 = b.clone();
let mut a3 = a.clone();
let b3 = b.clone();
b2 *= &a2;
a3 *= b3;
a3 == b2
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("unsafe_mul", $$(testFileLit)), 3, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, ybytes) = case.get("y").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
assert!(!neg0 && !neg1 && !neg2);
let mut x = $$sname::from_bytes(&xbytes);
let y = $$sname::from_bytes(&ybytes);
let z = $$sname::from_bytes(&zbytes);
x *= y;
assert_eq!(z, x);
});
}
|]
-- -----------------------------------------------------------------------------
generateMultiplier :: Bool -> Word -> String -> String -> [Stmt Span]
generateMultiplier fullmul size inName outName =
let readIns = map (load "self" "x") [0..size-1] ++
map (load inName "y") [0..size-1]
instructions = releaseUnnecessary outVars (generateInstructions size)
outDigits | fullmul = 2 * size
| otherwise = size
outVars = map (("res" ++) . show) [0..outDigits-1]
operations = map translateInstruction instructions
writeOuts = map (store "res") [0..outDigits-1]
in readIns ++ operations ++ writeOuts
where
load rhs vname i =
let liti = toLit i
vec = mkIdent rhs
var = mkIdent (vname ++ show i)
in [stmt| let $$var = $$vec.value[$$(liti)]; |]
store vname i =
let liti = toLit i
vec = mkIdent outName
var = mkIdent (vname ++ show i)
in [stmt| $$vec.value[$$(liti)] = $$var; |]
translateInstruction :: Instruction -> Stmt Span
translateInstruction instr =
case instr of
@@ -169,12 +254,51 @@ translateInstruction instr =
val = toLit (fromIntegral amt)
in [stmt| let $$outid: u128 = $$inid >> $$(val); |]
releaseUnnecessary :: [String] -> [Instruction] -> [Instruction]
releaseUnnecessary outkeys instrs = snd (foldl check (outkeys, []) (reverse instrs))
where
check acc@(required, rest) cur
| outVar cur `elem` required = (union (inVars cur) required, cur : rest)
| otherwise = acc
outVar :: Instruction -> String
outVar instr =
case instr of
Add outname _ -> outname
CastDown outname _ -> outname
CastUp outname _ -> outname
Complement outname _ -> outname
Declare64 outname _ -> outname
Declare128 outname _ -> outname
Mask outname _ _ -> outname
Multiply outname _ -> outname
ShiftR outname _ _ -> outname
inVars :: Instruction -> [String]
inVars instr =
case instr of
Add _ args -> args
CastDown _ arg -> [arg]
CastUp _ arg -> [arg]
Complement _ arg -> [arg]
Declare64 _ _ -> []
Declare128 _ _ -> []
Mask _ arg _ -> [arg]
Multiply _ args -> args
ShiftR _ arg _ -> [arg]
-- -----------------------------------------------------------------------------
generateSafeTests :: RandomGen g => Word -> g -> [Map String String]
generateSafeTests size g = go g numTestCases
where
go _ 0 = []
go _ 0 = [
Map.fromList [("x", "0"), ("y", "0"), ("z", "0")]
, (let x = (2 ^ size) - 1
y = (2 ^ size) - 1
z = x * y
in Map.fromList [("x", showX x), ("y", showX y), ("z", showX z)])
]
go g0 i =
let (x, g1) = generateNum g0 size
(y, g2) = generateNum g1 size