-- Files
-- cryptonum/generation/src/Karatsuba.hs
--
-- 611 lines
-- 20 KiB
-- Haskell
--
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Karatsuba(
Instruction(..)
, runChecks
, generateInstructions
)
where
import Control.Monad.Fail(MonadFail(..))
import Control.Monad.Identity hiding (fail)
import Control.Monad.RWS.Strict hiding (fail)
import Data.Bits
import Data.LargeWord
import Data.List
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Data.Vector(Vector, (!?))
import qualified Data.Vector as V
import Data.Word
import Prelude hiding (fail)
import Test.QuickCheck hiding ((.&.))
-- | How many 64-bit digits each randomly generated 'Number' has. Every
-- QuickCheck property in this module is driven by this size.
inputWordSize :: Int
inputWordSize = 5
-- | Generate the straight-line instruction sequence that multiplies two
-- @numdigits@-digit numbers using 'karatsuba'.
--
-- The input digits are named @x0@, @x1@, ... and @y0@, @y1@, ... (least
-- significant first). Every digit of the result is renamed to @res0@,
-- @res1@, ... so callers can refer to the output by a predictable name.
--
-- Notes:
--
--   * @numdigits@ must be at least 1. For 0, 'karatsuba' fails with
--     "Got empty nums".
--   * For inputs of four or more digits, 'karatsuba' does not trim its
--     result. There can therefore be more than @2 * numdigits@ @res@
--     digits; the surplus high digits are always zero.
generateInstructions :: Word -> [Instruction]
generateInstructions numdigits = foldl' replaceVar baseInstrs varRenames
  where
    -- Only the variable names matter for code generation; the digit values
    -- are placeholders.
    placeholders = V.replicate (fromIntegral numdigits) (D "" 1)
    x = rename "x" placeholders
    y = rename "y" placeholders
    (baseVec, baseInstrs) = runMath (karatsuba x y)
    res = rename "res" baseVec
    -- (generated name, final "resN" name) for each result digit.
    varRenames = zip (map name (V.toList baseVec)) (map name (V.toList res))
-- -----------------------------------------------------------------------------
--
-- Instructions that we emit as a result of running Karatsuba, that can be
-- turned into Rust lines.
--
-- -----------------------------------------------------------------------------
-- | One straight-line instruction emitted while running Karatsuba. Each one
-- is later turned into a line of Rust.
--
-- Operands use Intel ordering: the first 'String' is always the destination
-- variable, and any further 'String's are source variables. The widths
-- noted below are the ones the reference interpreter 'step' uses.
data Instruction
  = Add String [String]          -- ^ 128-bit dest = sum of 128-bit sources
  | CastDown String String       -- ^ 64-bit dest = 128-bit source, narrowed
  | CastUp String String         -- ^ 128-bit dest = 64-bit source, widened
  | Complement String String     -- ^ 64-bit dest = bitwise NOT of source
  | Declare64 String Word64      -- ^ 64-bit dest = literal
  | Declare128 String Word128    -- ^ 128-bit dest = literal
  | Mask String String Word128   -- ^ 128-bit dest = source AND literal mask
  | Multiply String [String]     -- ^ 128-bit dest = product of sources
  | ShiftR String String Int     -- ^ 128-bit dest = source shifted right
  deriving (Eq, Show)
-- | Digit widths that can be introduced as a literal with an 'Instruction'.
class Declarable a where
  -- | The instruction that declares the named variable holding the value.
  declare :: String -> a -> Instruction

-- | 64-bit literals become 'Declare64'.
instance Declarable Word64 where
  declare = Declare64

-- | 128-bit literals become 'Declare128'.
instance Declarable Word128 where
  declare = Declare128
-- | Interpreter state: the 64-bit variables and the 128-bit variables,
-- each keyed by variable name.
type Env = (Map String Word64, Map String Word128)

-- | Execute a single instruction. The destination is bound in the map that
-- matches its width. Reading an unbound source variable raises an 'error'
-- that names the missing key.
step :: Env -> Instruction -> Env
step (vars64, vars128) instr =
  case instr of
    Add dest srcs          -> set128 dest (sum (map (fetch vars128) srcs))
    CastDown dest src      -> set64 dest (fromIntegral (fetch vars128 src))
    CastUp dest src        -> set128 dest (fromIntegral (fetch vars64 src))
    Complement dest src    -> set64 dest (complement (fetch vars64 src))
    Declare64 dest val     -> set64 dest val
    Declare128 dest val    -> set128 dest val
    Mask dest src mask     -> set128 dest (fetch vars128 src .&. mask)
    Multiply dest srcs     -> set128 dest (product (map (fetch vars128) srcs))
    ShiftR dest src amount -> set128 dest (fetch vars128 src `shiftR` amount)
  where
    -- Bind a 64-bit result, leaving the 128-bit map alone.
    set64 key val = (Map.insert key val vars64, vars128)
    -- Bind a 128-bit result, leaving the 64-bit map alone.
    set128 key val = (vars64, Map.insert key val vars128)
    fetch vars key =
      maybe (error ("Failure to find key '" ++ key ++ "'")) id (Map.lookup key vars)
-- | Reference interpreter: execute the instructions in order, starting from
-- the given environment, and return the final environment.
--
-- The fold is strict, and both maps are forced after every 'step'. This
-- keeps the environment from accumulating a chain of unevaluated 'step'
-- applications over long instruction lists. The original hand-written
-- recursion passed @step env x@ along lazily.
run :: Env -> [Instruction] -> Env
run = foldl' strictStep
  where
    strictStep env instr =
      let (env64', env128') = step env instr
      in env64' `seq` env128' `seq` (env64', env128')
-- | Rename the variable @from@ to @to@ wherever it appears in an instruction
-- list, as a destination or as a source. Literal operands (declared values,
-- masks, shift amounts) are left untouched.
replaceVar :: [Instruction] -> (String, String) -> [Instruction]
replaceVar instrs (from, to) = [ renameIn i | i <- instrs ]
  where
    subst v
      | v == from = to
      | otherwise = v
    renameIn instr =
      case instr of
        Add d srcs        -> Add (subst d) (map subst srcs)
        CastDown d s      -> CastDown (subst d) (subst s)
        CastUp d s        -> CastUp (subst d) (subst s)
        Complement d s    -> Complement (subst d) (subst s)
        Declare64 d val   -> Declare64 (subst d) val
        Declare128 d val  -> Declare128 (subst d) val
        Mask d s m        -> Mask (subst d) (subst s) m
        Multiply d srcs   -> Multiply (subst d) (map subst srcs)
        ShiftR d s amount -> ShiftR (subst d) (subst s) amount
-- -----------------------------------------------------------------------------
--
-- The Math monad.
--
-- -----------------------------------------------------------------------------
-- | The monad Karatsuba runs in. The state is an 'Integer' counter used to
-- build fresh variable names, and the writer output is the list of emitted
-- 'Instruction's. The reader environment is unused (@()@).
newtype Math a = Math { unMath :: RWS () [Instruction] Integer a }
  deriving ( Functor
           , Applicative
           , Monad
           , MonadState Integer
           , MonadWriter [Instruction]
           )

-- | Explicit 'fail' calls and failed pattern binds abort with an 'error'
-- carrying the message.
instance MonadFail Math where
  fail msg = error ("Math fail: " ++ msg)
-- | Append one instruction to the output stream.
emit :: Instruction -> Math ()
emit = tell . pure

-- | A fresh variable name: @base@ followed by the current counter value.
-- The counter is then advanced by one.
gensym :: String -> Math String
gensym base =
  do counter <- get
     put (counter + 1)
     return (base ++ show counter)

-- | Run a computation with the name counter starting at 0. Returns its
-- result together with every instruction it emitted, in emission order.
runMath :: Math a -> (a, [Instruction])
runMath computation = evalRWS (unMath computation) () 0
-- -----------------------------------------------------------------------------
--
-- Primitive mathematics that can run on a Digit
--
-- -----------------------------------------------------------------------------
-- | A digit value paired with the name of the variable that holds it in the
-- generated code. This module uses @size@ = 'Word64' for ordinary digits and
-- 'Word128' for widened intermediates.
data Digit size = D
  { name  :: String
  , digit :: size
  }
  deriving (Eq, Show)
-- | Declare a fresh variable holding the literal @x@ and return it as a
-- 'Digit'. The variable is named @nm@ plus a counter.
genDigit :: Declarable size => String -> size -> Math (Digit size)
genDigit nm x =
  do fresh <- gensym nm
     emit (declare fresh x)
     return (D fresh x)
-- | Widen a 64-bit digit to 128 bits by emitting a 'CastUp' into a fresh
-- variable named @big_<source>N@.
embiggen :: Digit Word64 -> Math (Digit Word128)
embiggen x =
  do let src = name x
     dest <- gensym ("big_" ++ src)
     emit (CastUp dest src)
     return (D { name = dest, digit = fromIntegral (digit x) })

-- | Narrow a 128-bit digit to its low 64 bits by emitting a 'CastDown' into
-- a fresh variable named @norm_<source>N@.
bottomBits :: Digit Word128 -> Math (Digit Word64)
bottomBits x =
  do let src = name x
     dest <- gensym ("norm_" ++ src)
     emit (CastDown dest src)
     return (D { name = dest, digit = fromIntegral (digit x) })
-- | A freshly declared 64-bit digit holding 1.
oneDigit :: Math (Digit Word64)
oneDigit = genDigit "one" (1 :: Word64)

-- | A freshly declared 128-bit digit holding 0.
bigZero :: Math (Digit Word128)
bigZero = genDigit "zero" (0 :: Word128)
-- | Add two 128-bit digits into a fresh @plusN@ variable. No overflow check
-- is done here; callers choose operands whose sum fits in 128 bits.
(|+|) :: Digit Word128 -> Digit Word128 -> Math (Digit Word128)
x |+| y =
  do dest <- gensym "plus"
     emit (Add dest [name x, name y])
     return (D dest (digit x + digit y))

-- | Add any number of 128-bit digits with a single 'Add' instruction into a
-- fresh @sumN@ variable.
sumDigits :: [Digit Word128] -> Math (Digit Word128)
sumDigits digits =
  do dest <- gensym "sum"
     emit (Add dest (map name digits))
     return (D dest (sum (map digit digits)))
-- | Multiply two 128-bit digits into a fresh @timesN@ variable.
(|*|) :: Digit Word128 -> Digit Word128 -> Math (Digit Word128)
x |*| y =
  do dest <- gensym "times"
     emit (Multiply dest [name x, name y])
     return (D dest (digit x * digit y))

-- | Logical right shift of a 128-bit digit into a fresh @shiftrN@ variable.
(|>>|) :: Digit Word128 -> Int -> Math (Digit Word128)
x |>>| s =
  do dest <- gensym "shiftr"
     emit (ShiftR dest (name x) s)
     return (D dest (digit x `shiftR` s))

-- | Bitwise AND of a 128-bit digit with a literal mask, into a fresh
-- @masked_<source>N@ variable.
(|&|) :: Digit Word128 -> Word128 -> Math (Digit Word128)
x |&| m =
  do dest <- gensym ("masked_" ++ name x)
     emit (Mask dest (name x) m)
     return (D dest (digit x .&. m))
-- | Bitwise complement of a 64-bit digit, into a fresh @comp_<source>N@
-- variable.
complementDigit :: Digit Word64 -> Math (Digit Word64)
complementDigit x =
  do let src = name x
     dest <- gensym ("comp_" ++ src)
     emit (Complement dest src)
     return (D { name = dest, digit = complement (digit x) })
-- -----------------------------------------------------------------------------
--
-- Extended mathematics that run on whole numbers
--
-- -----------------------------------------------------------------------------
-- | A multi-digit number: a vector of 64-bit digits, least significant digit
-- first (see 'convertFrom').
type Number = Vector (Digit Word64)

-- | Random numbers always have exactly 'inputWordSize' digits, each with an
-- empty variable name.
instance Arbitrary Number where
  arbitrary = V.fromList <$> replicateM inputWordSize (D "" <$> arbitrary)
-- | Rename every digit by position to @var0@, @var1@, ... (least significant
-- first), keeping the digit values unchanged.
--
-- Uses 'V.imap', which is O(n). The previous recursive version rebuilt the
-- vector with 'V.cons' at each step, which copies it, for O(n^2) overall.
rename :: String -> Number -> Number
rename var = V.imap (\i d -> D (var ++ show i) (digit d))
-- | Convert an 'Integer' into an @s@-digit 'Number' (least significant
-- digit first, all names empty).
--
-- Only the low @64 * s@ bits are kept; negative values therefore come out
-- in two's complement. The digit vector is zero-padded up to @s@ digits in a
-- single step. A non-positive @s@ yields an empty number; the previous
-- one-digit-at-a-time padding looped forever when @s@ was negative.
convertTo :: Int -> Integer -> Number
convertTo s n = digits <> V.replicate (s - V.length digits) zeroDigit
  where
    -- At most s digits; stops early once the remaining value is 0.
    digits = V.unfoldrN s next n
    next 0 = Nothing
    next x = Just (D { name = "", digit = fromIntegral x }, x `shiftR` 64)
    zeroDigit = D { name = "", digit = 0 }
-- | Interpret a 'Number' as a non-negative 'Integer'. The first digit is the
-- least significant, and each later digit is worth another factor of 2^64.
convertFrom :: Number -> Integer
convertFrom = V.foldr (\d acc -> fromIntegral (digit d) + (acc `shiftL` 64)) 0
-- | Number -> Integer -> Number is the identity on 'inputWordSize'-digit
-- numbers with empty names.
prop_ConversionWorksNum :: Number -> Bool
prop_ConversionWorksNum n = convertTo inputWordSize (convertFrom n) == n

-- | Integer -> Number -> Integer keeps exactly the low
-- @64 * inputWordSize@ bits.
prop_ConversionWorksInt :: Integer -> Bool
prop_ConversionWorksInt n = convertFrom (convertTo inputWordSize n) == truncated
  where
    truncated = n `mod` (2 ^ (inputWordSize * 64))
-- | @s@ freshly declared zero digits (none when @s@ is not positive).
zero :: Int -> Math Number
zero s = V.fromList <$> replicateM s (genDigit "zero" 0)

-- | Whether a number has no digits.
empty :: Number -> Bool
empty = V.null

-- | The number of digits in a number.
size :: Number -> Int
size = V.length

-- | Split a number after its first @i@ digits, low (least significant) part
-- first. Emits no instructions.
splitDigits :: Int -> Number -> Math (Number, Number)
splitDigits i = pure . V.splitAt i
-- | Splitting and then re-concatenating gives back the original number.
prop_SplitDigitsIsntTerrible :: Int -> Number -> Bool
prop_SplitDigitsIsntTerrible x n = (left <> right) == n
  where
    ((left, right), _) = runMath (splitDigits (x `mod` inputWordSize) n)
-- | Prepend @x@ fresh zero digits at the low end. This multiplies the value
-- by 2^(64 * x).
addZeros :: Int -> Number -> Math Number
addZeros x n = (<> n) <$> zero x

-- | 'addZeros' agrees with an 'Integer' left shift by 64 bits per digit.
prop_AddZerosIsShift :: Int -> Number -> Bool
prop_AddZerosIsShift x n = viaAdd == viaShift
  where
    amount = abs (x `mod` inputWordSize)
    viaShift = convertFrom n `shiftL` (amount * 64)
    viaAdd = convertFrom (fst (runMath (addZeros amount n)))
-- | Append fresh zero digits at the high end until the number has @len@
-- digits. The number is unchanged if it is already that long or longer.
padTo :: Int -> Number -> Math Number
padTo len num = (num <>) <$> zero (len - V.length num)

-- | Padding never changes the value.
prop_PadToWorks :: Int -> Number -> Property
prop_PadToWorks len num =
  len >= size num ==>
    convertFrom (fst (runMath (padTo len num))) == convertFrom num
-- | Ripple-carry addition of two equal-length numbers. The result has one
-- more digit than the inputs, holding the final carry.
--
-- Each column is computed in 128 bits as @x + y + carry@. The low 64 bits
-- become the result digit and the bits above 64 become the next carry.
-- Fails (via 'fail') if the lengths differ.
add2 :: Number -> Number -> Math Number
add2 xs ys
  | size xs /= size ys = fail "Add2 of uneven vectors."
  | otherwise =
      do initialCarry <- bigZero
         (digits, finalCarry) <- foldM addColumn (V.empty, initialCarry) (V.zip xs ys)
         topDigit <- bottomBits finalCarry
         return (V.snoc digits topDigit)
  where
    addColumn (acc, carryIn) (a, b) =
      do wideA <- embiggen a
         wideB <- embiggen b
         total <- sumDigits [wideA, wideB, carryIn]
         carryOut <- total |>>| 64
         low <- bottomBits total
         return (V.snoc acc low, carryOut)
-- | 'add2' yields one extra digit and agrees with 'Integer' addition.
prop_Add2Works :: Number -> Number -> Bool
prop_Add2Works n m =
  size total == inputWordSize + 1
    && convertFrom n + convertFrom m == convertFrom total
  where
    total = fst (runMath (add2 n m))
-- | Ripple-carry addition of three equal-length numbers. The result has one
-- more digit than the inputs, holding the final carry.
--
-- Each column sums three digits plus the incoming carry in 128 bits. Fails
-- (via 'fail') if the lengths differ.
add3 :: Number -> Number -> Number -> Math Number
add3 x y z
  | size x /= size y = fail "Unequal lengths in add3 (1)."
  | size y /= size z = fail "Unequal lengths in add3 (2)."
  | otherwise =
      do initialCarry <- bigZero
         (digits, finalCarry) <- foldM addColumn (V.empty, initialCarry) (V.zip3 x y z)
         topDigit <- bottomBits finalCarry
         return (V.snoc digits topDigit)
  where
    addColumn (acc, carryIn) (a, b, c) =
      do wideA <- embiggen a
         wideB <- embiggen b
         wideC <- embiggen c
         total <- sumDigits [wideA, wideB, wideC, carryIn]
         carryOut <- total |>>| 64
         low <- bottomBits total
         return (V.snoc acc low, carryOut)
-- | 'add3' yields one extra digit and agrees with 'Integer' addition.
prop_Add3Works :: Number -> Number -> Number -> Bool
prop_Add3Works x y z =
  size total == inputWordSize + 1
    && convertFrom x + convertFrom y + convertFrom z == convertFrom total
  where
    total = fst (runMath (add3 x y z))
-- | Subtraction modulo 2^(64 * size x), computed in two's complement as
-- @x + complement y + 1@. The result keeps only @size x@ digits. Fails (via
-- 'fail') if the lengths differ.
sub2 :: Number -> Number -> Math Number
sub2 x y
  | size x /= size y = fail "Unequal lengths in sub."
  | otherwise =
      do negY <- mapM complementDigit y
         oneVec <- padTo (size x) . V.singleton =<< oneDigit
         total <- add3 x negY oneVec
         return (V.take (size x) total)
-- | For @a >= b@, 'sub2' agrees with 'Integer' subtraction. If @a < b@ the
-- operands are swapped first.
prop_Sub2Works :: Number -> Number -> Bool
prop_Sub2Works a b
  | aInt < bInt = prop_Sub2Works b a
  | otherwise = convertFrom (fst (runMath (sub2 a b))) == aInt - bInt
  where
    aInt = convertFrom a
    bInt = convertFrom b
-- -----------------------------------------------------------------------------
--
-- Finally, multiplication and Karatsuba
--
-- -----------------------------------------------------------------------------
-- | Multiply two 1-digit numbers into a 2-digit result (low digit first),
-- using one 128-bit product. Fails (via 'fail') unless both inputs have
-- exactly one digit.
mul1 :: Number -> Number -> Math Number
mul1 num1 num2
  | size num1 == 1 && size num2 == 1 =
      do a <- embiggen (V.head num1)
         b <- embiggen (V.head num2)
         wide <- a |*| b
         lowDigit <- bottomBits wide
         highDigit <- bottomBits =<< (wide |>>| 64)
         return (V.fromList [lowDigit, highDigit])
  | otherwise =
      fail "Called mul1 with !1 digit numbers. Idiot."
-- | Check a fixed-size multiplier @f@. Both random inputs are cut or padded
-- to exactly @nsize@ digits; the result must have @2 * nsize@ digits and
-- equal the 'Integer' product.
prop_MulNWorks :: Int -> (Number -> Number -> Math Number) ->
                  Number -> Number ->
                  Bool
prop_MulNWorks nsize f x y =
  and [ size x' == nsize
      , size y' == nsize
      , size resNum == nsize * 2
      , convertFrom x' * convertFrom y' == convertFrom resNum
      ]
  where
    resize n = fst (runMath (padTo nsize (V.take nsize n)))
    x' = resize x
    y' = resize y
    resNum = fst (runMath (f x' y'))

-- | 'mul1' agrees with 'Integer' multiplication on 1-digit inputs.
prop_Mul1Works :: Number -> Number -> Bool
prop_Mul1Works = prop_MulNWorks 1 mul1
-- | Schoolbook multiplication of two 2-digit numbers, giving a 4-digit
-- result (least significant digit first).
--
-- All digits are widened to 128 bits. A 64x64-bit product plus at most two
-- further 64-bit values therefore fits in one intermediate:
-- (2^64-1)^2 + 2*(2^64-1) = 2^128-1.
-- Fails (via 'fail') unless both inputs have exactly two digits.
mul2 :: Number -> Number -> Math Number
mul2 num1 num2
  | size num1 /= 2 || size num2 /= 2 =
      fail "Called mul2 with !2 digit numbers. Idiot."
  | otherwise =
      do [l0, l1] <- mapM embiggen (V.toList num1)
         [r0, r1] <- mapM embiggen (V.toList num2)
         --
         -- Row 0: num1 * r0. dest0 is final. tdest1 and tdest2 are the
         -- partial digits at positions 1 and 2.
         l0r0 <- l0 |*| r0
         carry0 <- l0r0 |>>| 64
         dest0 <- bottomBits l0r0
         l1r0 <- l1 |*| r0
         l1r0' <- l1r0 |+| carry0
         tdest1 <- l1r0' |&| 0xFFFFFFFFFFFFFFFF
         tdest2 <- l1r0' |>>| 64
         --
         -- Row 1: add num1 * r1, one digit up, into the partials. The
         -- carry out of position 1 feeds position 2.
         l0r1 <- l0 |*| r1
         l0r1' <- tdest1 |+| l0r1
         dest1 <- bottomBits l0r1'
         l1r1 <- l1 |*| r1
         l1r1' <- tdest2 |+| l1r1
         carry1 <- l0r1' |>>| 64
         l1r1'' <- l1r1' |+| carry1
         dest2 <- bottomBits l1r1''
         dest3 <- bottomBits =<< (l1r1'' |>>| 64)
         return (V.fromList [dest0, dest1, dest2, dest3])
-- | 'mul2' agrees with 'Integer' multiplication on 2-digit inputs.
prop_Mul2Works :: Number -> Number -> Bool
prop_Mul2Works x y = prop_MulNWorks 2 mul2 x y
-- | Schoolbook multiplication of two 3-digit numbers, giving a 6-digit
-- result (least significant digit first).
--
-- The work proceeds row by row over the digits of @num2@. Each 64x64-bit
-- product is formed in a 128-bit intermediate, and at most two further
-- 64-bit values (a partial digit and a carry) are added to it. No
-- intermediate can therefore exceed 2^128 - 1.
-- Fails (via 'fail') unless both inputs have exactly three digits.
mul3 :: Number -> Number -> Math Number
mul3 num1 num2
  | size num1 /= 3 || size num2 /= 3 =
      fail "Called mul3 with !3 digit numbers. Idiot."
  | otherwise =
      do [l0, l1, l2] <- mapM embiggen (V.toList num1)
         [r0, r1, r2] <- mapM embiggen (V.toList num2)
         --
         -- Row 0: num1 * r0. dest0 is final. tdest1..tdest3 are the
         -- partial digits at positions 1..3.
         l0r0 <- l0 |*| r0
         dest0 <- bottomBits l0r0
         carry0 <- l0r0 |>>| 64
         l1r0 <- l1 |*| r0
         l1r0' <- l1r0 |+| carry0
         tdest1 <- l1r0' |&| 0xFFFFFFFFFFFFFFFF
         carry1 <- l1r0' |>>| 64
         l2r0 <- l2 |*| r0
         l2r0' <- l2r0 |+| carry1
         tdest2 <- l2r0' |&| 0xFFFFFFFFFFFFFFFF
         tdest3 <- l2r0' |>>| 64
         --
         -- Row 1: add num1 * r1, one digit up, into the partials. dest1 is
         -- final. tdest2'..tdest4' are the updated partials.
         l0r1 <- l0 |*| r1
         l0r1' <- tdest1 |+| l0r1
         dest1 <- bottomBits l0r1'
         carry2 <- l0r1' |>>| 64
         l1r1 <- l1 |*| r1
         l1r1' <- sumDigits [l1r1, tdest2, carry2]
         tdest2' <- l1r1' |&| 0xFFFFFFFFFFFFFFFF
         carry3 <- l1r1' |>>| 64
         l2r1 <- l2 |*| r1
         l2r1' <- sumDigits [l2r1, tdest3, carry3]
         tdest3' <- l2r1' |&| 0xFFFFFFFFFFFFFFFF
         tdest4' <- l2r1' |>>| 64
         --
         -- Row 2: add num1 * r2, two digits up. Every remaining digit of
         -- the result is now final.
         l0r2 <- l0 |*| r2
         l0r2' <- l0r2 |+| tdest2'
         dest2 <- bottomBits l0r2'
         carry4 <- l0r2' |>>| 64
         l1r2 <- l1 |*| r2
         l1r2' <- sumDigits [l1r2, tdest3', carry4]
         dest3 <- bottomBits l1r2'
         carry5 <- l1r2' |>>| 64
         l2r2 <- l2 |*| r2
         l2r2' <- sumDigits [l2r2, tdest4', carry5]
         dest4 <- bottomBits l2r2'
         dest5 <- bottomBits =<< (l2r2' |>>| 64)
         return (V.fromList [dest0, dest1, dest2, dest3, dest4, dest5])
-- | 'mul3' agrees with 'Integer' multiplication on 3-digit inputs.
prop_Mul3Works :: Number -> Number -> Bool
prop_Mul3Works x y = prop_MulNWorks 3 mul3 x y
-- | Multiply two equal-length numbers with Karatsuba's algorithm, emitting
-- the instructions that compute the product.
--
-- Inputs of one to three digits go straight to the schoolbook 'mul1',
-- 'mul2' or 'mul3'. Larger inputs are split after @m2 = size `div` 2@
-- digits into a low half and a high half (digits are least significant
-- first). With @B = 2^64@, the product is then assembled as
--
-- > z2 * B^(2*m2) + (z1 - z2 - z0) * B^m2 + z0
--
-- where @z0 = low1*low2@, @z2 = high1*high2@ and
-- @z1 = (low1+high1)*(low2+high2)@.
--
-- The result is not trimmed. For inputs of four or more digits it has more
-- than @2 * size@ digits, and the surplus high digits are zero.
-- Fails (via 'fail') on inputs of unequal length or empty inputs.
karatsuba :: Number -> Number -> Math Number
karatsuba num1 num2
  | size num1 /= size num2 =
      fail "Uneven numeric lengths!"
  | empty num1 =
      fail "Got empty nums"
  | size num1 == 1 = mul1 num1 num2
  | size num1 == 2 = mul2 num1 num2
  | size num1 == 3 = mul3 num1 num2
  | otherwise =
      do let m = min (size num1) (size num2)
             m2 = m `div` 2
         -- low = first m2 (least significant) digits; high = the rest, so
         -- the high half is at least as long as the low half.
         (low1, high1) <- splitDigits (fromIntegral m2) num1
         (low2, high2) <- splitDigits (fromIntegral m2) num2
         z0 <- karatsuba low1 low2
         -- add2 requires equal lengths, so pad both halves to the longer one.
         let midsize = max (size low1) (size high1)
         low1' <- padTo midsize low1
         low2' <- padTo midsize low2
         high1' <- padTo midsize high1
         high2' <- padTo midsize high2
         mid1 <- add2 low1' high1'
         mid2 <- add2 low2' high2'
         z1 <- karatsuba mid1 mid2
         z2 <- karatsuba high1 high2
         -- sub2 also requires equal lengths; pad all partial products to the
         -- widest. z1 - z2 - z0 is non-negative, so the modular subtraction
         -- is exact.
         let subsize = max (size z0) (max (size z1) (size z2))
         sz0 <- padTo subsize z0
         sz1 <- padTo subsize z1
         sz2 <- padTo subsize z2
         tmp <- sub2 sz1 sz2
         -- Shift the middle term up by m2 digits and z2 up by 2*m2 digits.
         z1' <- addZeros m2 =<< sub2 tmp sz0
         z2' <- addZeros (m2 * 2) z2
         -- Final three-way sum; add3 needs equal lengths as well.
         let addsize = max (size z0) (max (size z1') (size z2'))
         az0 <- padTo addsize z0
         az1 <- padTo addsize z1'
         az2 <- padTo addsize z2'
         add3 az2 az1 az0
-- | The value computed by 'karatsuba' equals the 'Integer' product.
prop_KaratsubaWorks :: Number -> Number -> Bool
prop_KaratsubaWorks x y =
  convertFrom x * convertFrom y == convertFrom (fst (runMath (karatsuba x y)))
-- | Replaying the emitted instructions with the reference interpreter 'run'
-- must reproduce exactly the digits (names and values) that 'karatsuba'
-- computed, and those digits must equal the 'Integer' product.
prop_InstructionsWork :: Number -> Number -> Bool
prop_InstructionsWork x y =
  expected == convertFrom result && result == replayed
  where
    x' = rename "x" x
    y' = rename "y" y
    expected = convertFrom x' * convertFrom y'
    (result, instrs) = runMath (karatsuba x' y')
    -- The inputs are the only variables not declared by the instructions.
    inputs = Map.fromList [ (name d, digit d) | d <- V.toList (x' <> y') ]
    (finalEnv64, _) = run (inputs, Map.empty) instrs
    replayed = V.map (fetchDigit . name) result
    fetchDigit n =
      case Map.lookup n finalEnv64 of
        Nothing -> error ("InstrProp lookup failure: " ++ n)
        Just v -> D n v
-- | The emitted instructions depend only on the input names and sizes, not
-- on the digit values.
prop_InstructionsConsistent :: Number -> Number -> Number -> Number -> Bool
prop_InstructionsConsistent a b x y = instrsFor a b == instrsFor x y
  where
    instrsFor p q = snd (runMath (karatsuba (rename "p" p) (rename "q" q)))
-- -----------------------------------------------------------------------------
--
-- Test running
--
-- -----------------------------------------------------------------------------
-- | Print a label (with no trailing newline), then check the property,
-- requiring 1000 successful tests.
runQuickCheck :: Testable prop => String -> prop -> IO ()
runQuickCheck testname prop =
  putStr testname >> quickCheck (withMaxSuccess 1000 prop)
-- | Run every property in this module in order, one labelled line each.
runChecks :: IO ()
runChecks = sequence_
  [ runQuickCheck "Num -> Int -> Num " prop_ConversionWorksNum
  , runQuickCheck "Int -> Num -> Int " prop_ConversionWorksInt
  , runQuickCheck "Split Isn't Dumb " prop_SplitDigitsIsntTerrible
  , runQuickCheck "More 0s is Shift " prop_AddZerosIsShift
  , runQuickCheck "PadTo Does That " prop_PadToWorks
  , runQuickCheck "Add2 Works " prop_Add2Works
  , runQuickCheck "Add3 Works " prop_Add3Works
  , runQuickCheck "Sub2 Works " prop_Sub2Works
  , runQuickCheck "Mul1 Works " prop_Mul1Works
  , runQuickCheck "Mul2 Works " prop_Mul2Works
  , runQuickCheck "Mul3 Works " prop_Mul3Works
  , runQuickCheck "Karatsuba Works " prop_KaratsubaWorks
  , runQuickCheck "Instructions Work " prop_InstructionsWork
  , runQuickCheck "Generation Consistent " prop_InstructionsConsistent
  ]