[CHECKPOINT] Adjust the Karatsuba implementation to abstract Variables from Strings

This commit is contained in:
2020-04-12 19:52:00 -07:00
parent 2baa5f070d
commit 0483bb8692
2 changed files with 291 additions and 239 deletions

View File

@@ -23,6 +23,7 @@ import qualified Data.Map.Strict as Map
import Data.Vector(Vector, (!?))
import qualified Data.Vector as V
import Data.Word
import Debug.Trace
import Prelude hiding (fail)
import Test.QuickCheck hiding ((.&.))
@@ -31,13 +32,19 @@ inputWordSize :: Int
inputWordSize = 5
generateInstructions :: Word -> [Instruction]
generateInstructions numdigits = foldl replaceVar baseInstrs varRenames
where
x = rename "x" (V.replicate (fromIntegral numdigits) (D "" 1))
y = rename "y" (V.replicate (fromIntegral numdigits) (D "" 1))
(baseVec, baseInstrs) = runMath (karatsuba x y)
res = rename "res" baseVec
varRenames = zip (map name (V.toList baseVec)) (map name (V.toList res))
generateInstructions numdigits =
let (baseVec, baseInstrs) = runMath $ do x <- V.replicateM (fromIntegral numdigits) (genDigit 1)
y <- V.replicateM (fromIntegral numdigits) (genDigit 1)
karatsuba x y
in baseInstrs
-- foldl replaceVar baseInstrs varRenames
-- where
-- x = rename "x" (V.replicate (fromIntegral numdigits) (D "" 1))
-- y = rename "y" (V.replicate (fromIntegral numdigits) (D "" 1))
-- (baseVec, baseInstrs) = runMath (karatsuba x y)
-- res = rename "res" baseVec
-- varRenames = zip (map name (V.toList baseVec)) (map name (V.toList res))
-- -----------------------------------------------------------------------------
--
@@ -46,28 +53,31 @@ generateInstructions numdigits = foldl replaceVar baseInstrs varRenames
--
-- -----------------------------------------------------------------------------
newtype Variable = V String
deriving (Eq, Ord, Show)
-- these are in Intel form, as I was corrupted young, so the first argument
-- is the destination and the rest are the arguments.
data Instruction = Add String [String]
| CastDown String String
| CastUp String String
| Complement String String
| Declare64 String Word64
| Declare128 String Word128
| Mask String String Word128
| Multiply String [String]
| ShiftR String String Int
data Instruction = Add Variable [Variable]
| CastDown Variable Variable
| CastUp Variable Variable
| Complement Variable Variable
| Declare64 Variable Word64
| Declare128 Variable Word128
| Mask Variable Variable Word128
| Multiply Variable [Variable]
| ShiftR Variable Variable Int
deriving (Eq, Show)
class Declarable a where
declare :: String -> a -> Instruction
declare :: Variable -> a -> Instruction
instance Declarable Word64 where
declare n x = Declare64 n x
declare n x = Declare64 n x
instance Declarable Word128 where
declare n x = Declare128 n x
declare n x = Declare128 n x
type Env = (Map String Word64, Map String Word128)
type Env = (Map Variable Word64, Map Variable Word128)
step :: Env -> Instruction -> Env
step (env64, env128) i =
@@ -91,12 +101,11 @@ step (env64, env128) i =
ShiftR outname item amt ->
(env64, Map.insert outname (getv env128 item `shiftR` amt) env128)
where
getv :: Map Variable a -> Variable -> a
getv env s =
case Map.lookup s env of
Nothing ->
error ("Failure to find key '" ++ s ++ "'")
Just v ->
v
Nothing -> error ("Failure to find key '" ++ show s ++ "'")
Just v -> v
run :: Env -> [Instruction] -> Env
run env instrs =
@@ -104,7 +113,7 @@ run env instrs =
[] -> env
(x:rest) -> run (step env x) rest
replaceVar :: [Instruction] -> (String, String) -> [Instruction]
replaceVar :: [Instruction] -> (Variable, Variable) -> [Instruction]
replaceVar ls (from, to) = map replace ls
where
replace x =
@@ -138,10 +147,10 @@ instance MonadFail Math where
emit :: Instruction -> Math ()
emit instr = tell [instr]
gensym :: String -> Math String
gensym base =
newVariable :: Math Variable
newVariable =
do x <- state (\ i -> (i, i + 1))
return (base ++ show x)
return (V (show x))
runMath :: Math a -> (a, [Instruction])
runMath m = evalRWS (unMath m) () 0
@@ -153,14 +162,14 @@ runMath m = evalRWS (unMath m) () 0
-- -----------------------------------------------------------------------------
data Digit size = D {
name :: String
name :: Variable
, digit :: size
}
deriving (Eq,Show)
genDigit :: Declarable size => String -> size -> Math (Digit size)
genDigit nm x =
do newName <- gensym nm
genDigit :: Declarable size => size -> Math (Digit size)
genDigit x =
do newName <- newVariable
emit (declare newName x)
return D{
name = newName
@@ -169,60 +178,60 @@ genDigit nm x =
embiggen :: Digit Word64 -> Math (Digit Word128)
embiggen x =
do newName <- gensym ("big_" ++ name x)
do newName <- newVariable
emit (CastUp newName (name x))
return (D newName (fromIntegral (digit x)))
bottomBits :: Digit Word128 -> Math (Digit Word64)
bottomBits x =
do newName <- gensym ("norm_" ++ name x)
do newName <- newVariable
emit (CastDown newName (name x))
return (D newName (fromIntegral (digit x)))
oneDigit :: Math (Digit Word64)
oneDigit = genDigit "one" 1
oneDigit = genDigit 1
bigZero :: Math (Digit Word128)
bigZero = genDigit "zero" 0
bigZero = genDigit 0
(|+|) :: Digit Word128 -> Digit Word128 -> Math (Digit Word128)
(|+|) x y =
do newName <- gensym "plus"
do newName <- newVariable
emit (Add newName [name x, name y])
let digval = digit x + digit y
return (D newName digval)
sumDigits :: [Digit Word128] -> Math (Digit Word128)
sumDigits ls =
do newName <- gensym "sum"
do newName <- newVariable
emit (Add newName (map name ls))
let digval = sum (map digit ls)
return (D newName digval)
(|*|) :: Digit Word128 -> Digit Word128 -> Math (Digit Word128)
(|*|) x y =
do newName <- gensym "times"
do newName <- newVariable
emit (Multiply newName [name x, name y])
let digval = digit x * digit y
return (D newName digval)
(|>>|) :: Digit Word128 -> Int -> Math (Digit Word128)
(|>>|) x s =
do newName <- gensym "shiftr"
do newName <- newVariable
emit (ShiftR newName (name x) s)
let digval = digit x `shiftR` s
return (D newName digval)
(|&|) :: Digit Word128 -> Word128 -> Math (Digit Word128)
(|&|) x m =
do newName <- gensym ("masked_" ++ name x)
do newName <- newVariable
emit (Mask newName (name x) m)
let digval = digit x .&. m
return (D newName digval)
complementDigit :: Digit Word64 -> Math (Digit Word64)
complementDigit x =
do newName <- gensym ("comp_" ++ name x)
do newName <- newVariable
emit (Complement newName (name x))
return (D newName (complement (digit x)))
@@ -234,44 +243,31 @@ complementDigit x =
type Number = Vector (Digit Word64)
instance Arbitrary Number where
arbitrary =
do ls <- replicateM inputWordSize (D "" <$> arbitrary)
return (V.fromList ls)
rename :: String -> Number -> Number
rename var num = go 0 num
convertTo :: Int -> Integer -> Math Number
convertTo sz num = V.fromList `fmap` go sz num
where
go :: Word -> Number -> Number
go i v =
case v !? 0 of
Nothing -> V.empty
Just x -> D (var ++ show i) (digit x) `V.cons` go (i + 1) (V.drop 1 v)
convertTo :: Int -> Integer -> Number
convertTo s = pad . V.unfoldrN s next
where
next 0 = Nothing
next x = Just (D{ name = "", digit = fromIntegral x }, x `shiftR` 64)
pad v | V.length v == s = v
| otherwise = pad (v <> V.singleton D{ name = "", digit = 0})
go :: Int -> Integer -> Math [Digit Word64]
go 0 _ =
return []
go x v =
do d <- genDigit (fromIntegral v)
rest <- go (x - 1) (v `shiftR` 64)
return (d:rest)
convertFrom :: Number -> Integer
convertFrom n = V.foldr combine 0 n
where
combine x acc = (acc `shiftL` 64) + fromIntegral (digit x)
prop_ConversionWorksNum :: Number -> Bool
prop_ConversionWorksNum n =
n == convertTo inputWordSize (convertFrom n)
prop_ConversionWorksInt :: Integer -> Bool
prop_ConversionWorksInt n =
n' == convertFrom (convertTo inputWordSize n)
where n' = n `mod` (2 ^ (inputWordSize * 64))
prop_ConversionWorksInt n = n' == back
where
n' = abs n `mod` (2 ^ (inputWordSize * 64))
there = fst (runMath (convertTo inputWordSize n'))
back = convertFrom there
zero :: Int -> Math Number
zero s = V.fromList `fmap` replicateM s (genDigit "zero" 0)
zero s = V.fromList `fmap` replicateM s (genDigit 0)
empty :: Number -> Bool
empty = null
@@ -282,33 +278,50 @@ size = length
splitDigits :: Int -> Number -> Math (Number, Number)
splitDigits i ls = return (V.splitAt i ls)
prop_SplitDigitsIsntTerrible :: Int -> Number -> Bool
prop_SplitDigitsIsntTerrible x n =
let ((left, right), _) = runMath (splitDigits x' n)
in n == (left <> right)
where x' = x `mod` inputWordSize
prop_SplitDigitsIsntTerrible :: Int -> Int -> Integer -> Bool
prop_SplitDigitsIsntTerrible a b n =
let a' = a `mod` 20
b' = b `mod` 20
(p, l) | a' > b' = (b', a')
| a' < b' = (a', b')
| otherwise = (a' - 1, a')
in fst $ runMath $ do base <- convertTo l n
(left, right) <- splitDigits p base
return (base == (left <> right))
addZeros :: Int -> Number -> Math Number
addZeros x n =
do prefix <- zero x
return (prefix <> n)
prop_AddZerosIsShift :: Int -> Number -> Bool
prop_AddZerosIsShift :: Int -> Integer -> Bool
prop_AddZerosIsShift x n =
let x' = abs (x `mod` inputWordSize)
nInt = convertFrom n
shiftVersion = nInt `shiftL` (x' * 64)
addVersion = convertFrom (fst (runMath (addZeros x' n)))
in shiftVersion == addVersion
fst $ runMath $ do base <- convertTo inputWordSize n'
added <- addZeros x' base
let shiftVer = n' `shiftL` (x' * 64)
let mine = convertFrom added
return (shiftVer == mine)
where
x' = abs x `mod` inputWordSize
n' = abs n `mod` (2 ^ (inputWordSize * 64))
padTo :: Int -> Number -> Math Number
padTo len num =
do suffix <- zero (len - V.length num)
return (num <> suffix)
prop_PadToWorks :: Int -> Number -> Property
prop_PadToWorks len num = len >= size num ==>
convertFrom num == convertFrom (fst (runMath (padTo len num)))
prop_PadToWorks :: Int -> Int -> Integer -> Bool
prop_PadToWorks a b num =
fst $ runMath $ do base <- convertTo sz num'
padded <- padTo len base
let newval = convertFrom padded
return (num' == newval)
where
a' = abs a `mod` (inputWordSize * 3)
b' = abs b `mod` (inputWordSize * 3)
(len, sz) | a' >= b' = (max 1 a', max 1 b')
| otherwise = (max 1 b', max 1 a')
num' = abs (num `mod` (2 ^ (64 * sz)))
add2 :: Number -> Number -> Math Number
add2 xs ys
@@ -330,14 +343,18 @@ add2 xs ys
let res' = res <> V.singleton newdigit
return (res', carry')
prop_Add2Works :: Number -> Number -> Bool
prop_Add2Works n m =
let nInt = convertFrom n
mInt = convertFrom m
intRes = nInt + mInt
(numRes, _) = runMath (add2 n m)
numResInt = convertFrom numRes
in (size numRes == inputWordSize + 1) && (intRes == numResInt)
prop_Add2Works :: Int -> Integer -> Integer -> Bool
prop_Add2Works l n m =
fst $ runMath $ do num1 <- convertTo l' n'
num2 <- convertTo l' m'
res <- add2 num1 num2
let intRes = convertFrom res
return ((intRes == r) && (size res == l' + 1))
where
l' = max 1 (abs l `mod` inputWordSize)
n' = abs n `mod` (2 ^ (l' * 64))
m' = abs m `mod` (2 ^ (l' * 64))
r = n' + m'
add3 :: Number -> Number -> Number -> Math Number
add3 x y z
@@ -362,15 +379,20 @@ add3 x y z
let res' = res <> V.singleton digit'
return (res', carry')
prop_Add3Works :: Number -> Number -> Number -> Bool
prop_Add3Works x y z =
let xInt = convertFrom x
yInt = convertFrom y
zInt = convertFrom z
intRes = xInt + yInt + zInt
(numRes, _) = runMath (add3 x y z)
numResInt = convertFrom numRes
in (size numRes == inputWordSize + 1) && (intRes == numResInt)
prop_Add3Works :: Int -> Integer -> Integer -> Integer -> Bool
prop_Add3Works l x y z =
fst $ runMath $ do num1 <- convertTo l' x'
num2 <- convertTo l' y'
num3 <- convertTo l' z'
res <- add3 num1 num2 num3
let intRes = convertFrom res
return ((intRes == r) && (size res == l' + 1))
where
l' = max 1 (abs l `mod` inputWordSize)
x' = abs x `mod` (2 ^ (l' * 64))
y' = abs y `mod` (2 ^ (l' * 64))
z' = abs z `mod` (2 ^ (l' * 64))
r = x' + y' + z'
sub2 :: Number -> Number -> Math Number
sub2 x y
@@ -383,16 +405,20 @@ sub2 x y
res <- add3 x yinv one
return (V.take (size x) res)
prop_Sub2Works :: Number -> Number -> Bool
prop_Sub2Works a b
| convertFrom a < convertFrom b = prop_Sub2Works b a
| otherwise =
let aInt = convertFrom a
bInt = convertFrom b
intRes = aInt - bInt
(numRes, _) = runMath (sub2 a b)
numResInt = convertFrom numRes
in intRes == numResInt
prop_Sub2Works :: Int -> Integer -> Integer -> Bool
prop_Sub2Works l a b =
fst $ runMath $ do num1 <- convertTo l' x
num2 <- convertTo l' y
res <- sub2 num1 num2
let intRes = convertFrom res
return (intRes == r)
where
l' = max 1 (abs l `mod` inputWordSize)
a' = abs a `mod` (2 ^ (l' * 64))
b' = abs b `mod` (2 ^ (l' * 64))
(x, y) | a' >= b' = (a', b')
| otherwise = (b', a')
r = x - y
-- -----------------------------------------------------------------------------
--
@@ -413,20 +439,20 @@ mul1 num1 num2
return (V.fromList [z0, z1])
prop_MulNWorks :: Int -> (Number -> Number -> Math Number) ->
Number -> Number ->
Integer -> Integer ->
Bool
prop_MulNWorks nsize f x y =
let (x', _) = runMath (padTo nsize (V.take nsize x))
(y', _) = runMath (padTo nsize (V.take nsize y))
xInt = convertFrom x'
yInt = convertFrom y'
resInt = xInt * yInt
(resNum, _) = runMath (f x' y')
in (size x' == nsize) && (size y' == nsize) &&
(size resNum == (nsize * 2)) &&
(resInt == convertFrom resNum)
fst $ runMath $ do num1 <- convertTo nsize x'
num2 <- convertTo nsize y'
res <- f num1 num2
let resInt = convertFrom res
return ((size res == (nsize * 2)) && (resInt == (x' * y')))
where
x' = abs x `mod` (2 ^ (64 * nsize))
y' = abs y `mod` (2 ^ (64 * nsize))
prop_Mul1Works :: Number -> Number -> Bool
prop_Mul1Works :: Integer -> Integer -> Bool
prop_Mul1Works = prop_MulNWorks 1 mul1
mul2 :: Number -> Number -> Math Number
@@ -456,7 +482,7 @@ mul2 num1 num2
dest3 <- bottomBits =<< (l1r1'' |>>| 64)
return (V.fromList [dest0, dest1, dest2, dest3])
prop_Mul2Works :: Number -> Number -> Bool
prop_Mul2Works :: Integer -> Integer -> Bool
prop_Mul2Works = prop_MulNWorks 2 mul2
mul3 :: Number -> Number -> Math Number
@@ -506,7 +532,7 @@ mul3 num1 num2
dest5 <- bottomBits =<< (l2r2' |>>| 64)
return (V.fromList [dest0, dest1, dest2, dest3, dest4, dest5])
prop_Mul3Works :: Number -> Number -> Bool
prop_Mul3Works :: Integer -> Integer -> Bool
prop_Mul3Works = prop_MulNWorks 3 mul3
karatsuba :: Number -> Number -> Math Number
@@ -544,43 +570,69 @@ karatsuba num1 num2
az0 <- padTo addsize z0
az1 <- padTo addsize z1'
az2 <- padTo addsize z2'
add3 az2 az1 az0
res <- add3 az2 az1 az0
forM_ (V.drop (m * 2) res) $ \ highDigit ->
-- this will only occur when (size res > (m * 2))
when (digit highDigit /= 0) $
fail "High bit found in Karatsuba result"
return (V.take (m * 2) res)
prop_KaratsubaWorks :: Number -> Number -> Bool
prop_KaratsubaWorks x y =
let shouldBe = convertFrom x * convertFrom y
monad = karatsuba x y
myVersion = convertFrom (fst (runMath monad))
in shouldBe == myVersion
prop_InstructionsWork :: Number -> Number -> Bool
prop_InstructionsWork x y =
let shouldBe = convertFrom x' * convertFrom y'
(mine, instrs) = runMath (karatsuba x' y')
myVersion = convertFrom mine
(endEnv, _) = run startEnv instrs
instrVersion = V.map (getv endEnv . name) mine
in (shouldBe == myVersion) && (mine == instrVersion)
prop_KaratsubaWorks :: Int -> Integer -> Integer -> Bool
prop_KaratsubaWorks l x y =
fst $ runMath $ do num1 <- convertTo l' x'
num2 <- convertTo l' y'
res <- karatsuba num1 num2
let resInt = convertFrom res
sizeOk = size res == (l' * 2)
valOk = resInt == (x' * y')
return (sizeOk && valOk)
where
x' = rename "x" x
y' = rename "y" y
startEnv = (Map.fromList startEnv64, Map.empty)
startEnv64 = map (\ d -> (name d, digit d)) (V.toList (x' <> y'))
l' = (abs l `mod` (inputWordSize * 2)) + 2
x' = abs x `mod` (2 ^ (64 * l'))
y' = abs y `mod` (2 ^ (64 * l'))
prop_InstructionsWork :: Int -> Integer -> Integer -> Bool
prop_InstructionsWork l x y =
let (value, instructions) = runMath $ do numx <- convertTo l' x'
numy <- convertTo l' y'
karatsuba numx numy
resGMP = x' * y'
resKaratsuba = convertFrom value
(endEnvironment, _) = run (Map.empty, Map.empty) instructions
instrVersion = V.map (getv endEnvironment . name) value
in (resGMP == resKaratsuba) && (value == instrVersion)
where
l' = max 1 (abs l `mod` inputWordSize)
x' = abs x `mod` (2 ^ (64 * l'))
y' = abs y `mod` (2 ^ (64 * l'))
getv env n =
case Map.lookup n env of
Nothing -> error ("InstrProp lookup failure: " ++ n)
Nothing -> error ("InstrProp lookup failure: " ++ show n)
Just v -> D n v
prop_InstructionsConsistent :: Number -> Number -> Number -> Number -> Bool
prop_InstructionsConsistent a b x y =
let (_, instrs1) = runMath (karatsuba a' b')
(_, instrs2) = runMath (karatsuba x' y')
in instrs1 == instrs2
prop_InstructionsConsistent :: Int -> Integer -> Integer -> Integer -> Integer -> Bool
prop_InstructionsConsistent l a b x y =
let (_, instrs1) = runMath (karatsuba' a' b')
(_, instrs2) = runMath (karatsuba' x' y')
instrs1' = dropWhile isDeclare64 instrs1
instrs2' = dropWhile isDeclare64 instrs2
in instrs1' == instrs2'
where
a' = rename "p" a
b' = rename "q" b
x' = rename "p" x
y' = rename "q" y
l' = max 1 (abs l `mod` inputWordSize)
a' = abs a `mod` (2 ^ (64 * l'))
b' = abs b `mod` (2 ^ (64 * l'))
x' = abs x `mod` (2 ^ (64 * l'))
y' = abs y `mod` (2 ^ (64 * l'))
karatsuba' p q =
do num1 <- convertTo l' p
num2 <- convertTo l' q
karatsuba num1 num2
isDeclare64 i =
case i of
Declare64 _ _ -> True
_ -> False
-- -----------------------------------------------------------------------------
--
@@ -595,8 +647,7 @@ runQuickCheck testname prop =
runChecks :: IO ()
runChecks =
do runQuickCheck "Num -> Int -> Num " prop_ConversionWorksNum
runQuickCheck "Int -> Num -> Int " prop_ConversionWorksInt
do runQuickCheck "Int -> Num -> Int " prop_ConversionWorksInt
runQuickCheck "Split Isn't Dumb " prop_SplitDigitsIsntTerrible
runQuickCheck "More 0s is Shift " prop_AddZerosIsShift
runQuickCheck "PadTo Does That " prop_PadToWorks