Start working on generating multiplies via Karatsuba.
This commit is contained in:
195
generation/src/Multiply.hs
Normal file
195
generation/src/Multiply.hs
Normal file
@@ -0,0 +1,195 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
module Multiply(
|
||||
safeMultiplyOps
|
||||
, unsafeMultiplyOps
|
||||
)
|
||||
where
|
||||
|
||||
import Data.Bits((.&.))
|
||||
import Data.Map.Strict(Map)
|
||||
import qualified Data.Map.Strict as Map
|
||||
import File
|
||||
import Gen(toLit)
|
||||
import Generators
|
||||
import Karatsuba
|
||||
import Language.Rust.Data.Ident
|
||||
import Language.Rust.Data.Position
|
||||
import Language.Rust.Quote
|
||||
import Language.Rust.Syntax
|
||||
import System.Random(RandomGen)
|
||||
|
||||
numTestCases :: Int
|
||||
numTestCases = 3000
|
||||
|
||||
safeMultiplyOps :: File
|
||||
safeMultiplyOps = File {
|
||||
predicate = \ me others -> (me * 2) `elem` others,
|
||||
outputName = "safe_mul",
|
||||
isUnsigned = True,
|
||||
generator = declareSafeMulOperators,
|
||||
testCase = Just generateSafeTests
|
||||
}
|
||||
|
||||
unsafeMultiplyOps :: File
|
||||
unsafeMultiplyOps = File {
|
||||
predicate = \ _ _ -> False,
|
||||
outputName = "unsafe_mul",
|
||||
isUnsigned = True,
|
||||
generator = declareUnsafeMulOperators,
|
||||
testCase = Just generateUnsafeTests
|
||||
}
|
||||
|
||||
declareSafeMulOperators :: Word -> SourceFile Span
|
||||
declareSafeMulOperators bitsize =
|
||||
let sname = mkIdent ("U" ++ show bitsize)
|
||||
dname = mkIdent ("U" ++ show (bitsize * 2))
|
||||
fullRippleMul = undefined True (bitsize `div` 64) "res"
|
||||
testFileLit = Lit [] (Str (testFile bitsize) Cooked Unsuffixed mempty) mempty
|
||||
in [sourceFile|
|
||||
use core::ops::Mul;
|
||||
use crate::CryptoNum;
|
||||
#[cfg(test)]
|
||||
use crate::testing::{build_test_path,run_test};
|
||||
#[cfg(test)]
|
||||
use quickcheck::quickcheck;
|
||||
use crate::unsigned::{$$sname,$$dname};
|
||||
|
||||
impl Mul for $$sname {
|
||||
type Output = $$dname;
|
||||
|
||||
fn mul(self, rhs: $$sname) -> $$dname {
|
||||
&self + &rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Mul<&'a $$sname> for $$sname {
|
||||
type Output = $$dname;
|
||||
|
||||
fn mul(self, rhs: &$$sname) -> $$dname {
|
||||
&self + rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Mul<$$sname> for &'a $$sname {
|
||||
type Output = $$dname;
|
||||
|
||||
fn mul(self, rhs: $$sname) -> $$dname {
|
||||
self + &rhs
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a,'b> Mul<&'a $$sname> for &'b $$sname {
|
||||
type Output = $$dname;
|
||||
|
||||
fn mul(self, rhs: &$$sname) -> $$dname {
|
||||
let mut res = $$dname::zero();
|
||||
|
||||
$@{fullRippleMul}
|
||||
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
quickcheck! {
|
||||
fn multiplication_symmetric(a: $$sname, b: $$sname) -> bool {
|
||||
(&a * &b) == (&b * &a)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(non_snake_case)]
|
||||
#[test]
|
||||
fn KATs() {
|
||||
run_test(build_test_path("safe_mul", $$(testFileLit)), 3, |case| {
|
||||
let (neg0, xbytes) = case.get("x").unwrap();
|
||||
let (neg1, ybytes) = case.get("y").unwrap();
|
||||
let (neg2, zbytes) = case.get("z").unwrap();
|
||||
|
||||
assert!(!neg0 && !neg1 && !neg2);
|
||||
let x = $$sname::from_bytes(&xbytes);
|
||||
let y = $$sname::from_bytes(&ybytes);
|
||||
let z = $$dname::from_bytes(&zbytes);
|
||||
|
||||
assert_eq!(z, x + y);
|
||||
});
|
||||
}
|
||||
|]
|
||||
|
||||
declareUnsafeMulOperators :: Word -> SourceFile Span
|
||||
declareUnsafeMulOperators bitsize = undefined bitsize
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
translateInstruction :: Instruction -> Stmt Span
|
||||
translateInstruction instr =
|
||||
case instr of
|
||||
Add outname args ->
|
||||
let outid = mkIdent outname
|
||||
args' = map (\x -> [expr| $$x |]) (map mkIdent args)
|
||||
adds = foldl (\ x y -> [expr| $$(x) + $$(y) |])
|
||||
(head args')
|
||||
(tail args')
|
||||
in [stmt| let $$outid: u128 = $$(adds); |]
|
||||
CastDown outname arg ->
|
||||
let outid = mkIdent outname
|
||||
inid = mkIdent arg
|
||||
in [stmt| let $$outid: u64 = $$inid as u64; |]
|
||||
CastUp outname arg ->
|
||||
let outid = mkIdent outname
|
||||
inid = mkIdent arg
|
||||
in [stmt| let $$outid: u128 = $$inid as u128; |]
|
||||
Complement outname arg ->
|
||||
let outid = mkIdent outname
|
||||
inid = mkIdent arg
|
||||
in [stmt| let $$outid: u64 = !$$inid; |]
|
||||
Declare64 outname arg ->
|
||||
let outid = mkIdent outname
|
||||
val = toLit (fromIntegral arg)
|
||||
in [stmt| let $$outid: u64 = $$(val); |]
|
||||
Declare128 outname arg ->
|
||||
let outid = mkIdent outname
|
||||
val = toLit (fromIntegral arg)
|
||||
in [stmt| let $$outid: u128 = $$(val); |]
|
||||
Mask outname arg mask ->
|
||||
let outid = mkIdent outname
|
||||
inid = mkIdent arg
|
||||
val = toLit (fromIntegral mask)
|
||||
in [stmt| let $$outid: u128 = $$inid & $$(val); |]
|
||||
Multiply outname args ->
|
||||
let outid = mkIdent outname
|
||||
args' = map (\x -> [expr| $$x |]) (map mkIdent args)
|
||||
muls = foldl (\ x y -> [expr| $$(x) * $$(y) |])
|
||||
(head args')
|
||||
(tail args')
|
||||
in [stmt| let $$outid: u128 = $$(muls); |]
|
||||
ShiftR outname arg amt ->
|
||||
let outid = mkIdent outname
|
||||
inid = mkIdent arg
|
||||
val = toLit (fromIntegral amt)
|
||||
in [stmt| let $$outid: u128 = $$inid >> $$(val); |]
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
|
||||
generateSafeTests :: RandomGen g => Word -> g -> [Map String String]
|
||||
generateSafeTests size g = go g numTestCases
|
||||
where
|
||||
go _ 0 = []
|
||||
go g0 i =
|
||||
let (x, g1) = generateNum g0 size
|
||||
(y, g2) = generateNum g1 size
|
||||
tcase = Map.fromList [("x", showX x), ("y", showX y),
|
||||
("z", showX (x * y))]
|
||||
in tcase : go g2 (i - 1)
|
||||
|
||||
generateUnsafeTests :: RandomGen g => Word -> g -> [Map String String]
|
||||
generateUnsafeTests size g = go g numTestCases
|
||||
where
|
||||
go _ 0 = []
|
||||
go g0 i =
|
||||
let (x, g1) = generateNum g0 size
|
||||
(y, g2) = generateNum g1 size
|
||||
z = (x * y) .&. ((2 ^ size) - 1)
|
||||
tcase = Map.fromList [("x", showX x), ("y", showX y),
|
||||
("z", showX z)]
|
||||
in tcase : go g2 (i - 1)
|
||||
Reference in New Issue
Block a user