Trying to limit some of the instructions we do in Karatsuba multiplication ...

This commit is contained in:
2020-04-26 19:53:40 -07:00
parent 9ee668daad
commit a622aa9cc9

View File

@@ -6,9 +6,12 @@
{-# LANGUAGE TypeSynonymInstances #-}
module Karatsuba(
Instruction(..)
, InstructionData(..)
, Variable
, runChecks
, runQuickCheck
, generateInstructions
, variableName
)
where
@@ -20,31 +23,37 @@ import Data.LargeWord
import Data.List
import Data.Map.Strict(Map)
import qualified Data.Map.Strict as Map
import Data.Vector(Vector, (!?))
import Data.Vector(Vector)
import qualified Data.Vector as V
import Data.Word
import Debug.Trace
import Prelude hiding (fail)
import Test.QuickCheck hiding ((.&.))
import Debug.Trace
-- this drives the testing
inputWordSize :: Int
inputWordSize = 5
generateInstructions :: Word -> [Instruction]
generateInstructions numdigits =
let (baseVec, baseInstrs) = runMath $ do x <- V.replicateM (fromIntegral numdigits) (genDigit 1)
y <- V.replicateM (fromIntegral numdigits) (genDigit 1)
karatsuba x y
in baseInstrs
data InstructionData = InstructionData {
idInstructions :: [Instruction],
idInput1 :: [Variable],
idInput2 :: [Variable],
idOutput :: [Variable]
}
-- foldl replaceVar baseInstrs varRenames
-- where
-- x = rename "x" (V.replicate (fromIntegral numdigits) (D "" 1))
-- y = rename "y" (V.replicate (fromIntegral numdigits) (D "" 1))
-- (baseVec, baseInstrs) = runMath (karatsuba x y)
-- res = rename "res" baseVec
-- varRenames = zip (map name (V.toList baseVec)) (map name (V.toList res))
generateInstructions :: Word -> InstructionData
generateInstructions numdigits =
let (baseID, baseInstrs) = runMath $ do (x, xinstrs) <- listen $ V.replicateM (fromIntegral numdigits) (genDigit 1)
(y, yinstrs) <- listen $ V.replicateM (fromIntegral numdigits) (genDigit 1)
res <- karatsuba x y
return InstructionData {
idInstructions = xinstrs ++ yinstrs,
idInput1 = map name (V.toList x),
idInput2 = map name (V.toList y),
idOutput = map name (V.toList res)
}
(preps, reals) = splitAt (length (idInstructions baseID)) baseInstrs
in baseID{ idInstructions = preps ++ simplifyConstants reals }
-- -----------------------------------------------------------------------------
--
@@ -56,6 +65,9 @@ generateInstructions numdigits =
newtype Variable = V Word
deriving (Eq, Ord, Show)
variableName :: Variable -> String
variableName (V x) = "t" ++ show x
-- these are in Intel form, as I was corrupted young, so the first argument
-- is the destination and the rest are the arguments.
data Instruction = Add Variable [Variable]
@@ -113,22 +125,48 @@ run env instrs =
[] -> env
(x:rest) -> run (step env x) rest
replaceVar :: [Instruction] -> (Variable, Variable) -> [Instruction]
replaceVar ls (from, to) = map replace ls
simplifyConstants :: [Instruction] -> [Instruction]
simplifyConstants instrs = go instrs Map.empty Map.empty Map.empty
where
replace x =
case x of
Add outname items -> Add (sub outname) (map sub items)
CastDown outname item -> CastDown (sub outname) (sub item)
CastUp outname item -> CastUp (sub outname) (sub item)
Complement outname item -> Complement (sub outname) (sub item)
Declare64 outname val -> Declare64 (sub outname) val
Declare128 outname val -> Declare128 (sub outname) val
Mask outname item mask -> Mask (sub outname) (sub item) mask
Multiply outname items -> Multiply (sub outname) (map sub items)
ShiftR outname item amt -> ShiftR (sub outname) (sub item) amt
sub x | x == from = to
| otherwise = x
go [] _ _ _ = []
go (instr:rest) consts64 consts128 remaps =
case instr of
Add outname items ->
Add outname (map (replace remaps) items) : go rest consts64 consts128 remaps
CastDown outname item ->
CastDown outname (replace remaps item) : go rest consts64 consts128 remaps
CastUp outname item ->
CastUp outname (replace remaps item) : go rest consts64 consts128 remaps
Complement outname item ->
Complement outname (replace remaps item) : go rest consts64 consts128 remaps
Declare64 outname val | Just outname' <- Map.lookup val consts64 ->
go rest consts64 consts128 (Map.insert outname outname' remaps)
Declare64 outname val ->
Declare64 outname val : go rest (Map.insert val outname consts64) consts128 remaps
Declare128 outname val | Just outname' <- Map.lookup val consts128 ->
go rest consts64 consts128 (Map.insert outname outname' remaps)
Declare128 outname val ->
Declare128 outname val : go rest consts64 (Map.insert val outname consts128) remaps
Mask outname item mask ->
Mask outname (replace remaps item) mask : go rest consts64 consts128 remaps
Multiply outname items ->
Multiply outname (map (replace remaps) items) : go rest consts64 consts128 remaps
ShiftR outname item amt ->
ShiftR outname (replace remaps item) amt : go rest consts64 consts128 remaps
replace :: Map Variable Variable -> Variable -> Variable
replace remaps item = Map.findWithDefault item item remaps
-- -----------------------------------------------------------------------------
--
@@ -593,14 +631,19 @@ prop_KaratsubaWorks l x y =
y' = abs y `mod` (2 ^ (64 * l'))
prop_InstructionsWork :: Int -> Integer -> Integer -> Bool
prop_InstructionsWork l x y =
prop_InstructionsWork' ::
([Instruction] -> [Instruction]) ->
Int ->
Integer ->
Integer ->
Bool
prop_InstructionsWork' f l x y =
let (value, instructions) = runMath $ do numx <- convertTo l' x'
numy <- convertTo l' y'
karatsuba numx numy
resGMP = x' * y'
resKaratsuba = convertFrom value
(endEnvironment, _) = run (Map.empty, Map.empty) instructions
(endEnvironment, _) = run (Map.empty, Map.empty) (f instructions)
instrVersion = V.map (getv endEnvironment . name) value
in (resGMP == resKaratsuba) && (value == instrVersion)
where
@@ -612,6 +655,12 @@ prop_InstructionsWork l x y =
Nothing -> error ("InstrProp lookup failure: " ++ show n)
Just v -> D n v
prop_InstructionsWork :: Int -> Integer -> Integer -> Bool
prop_InstructionsWork = prop_InstructionsWork' id
prop_SimplifiedInstructionsWork :: Int -> Integer -> Integer -> Bool
prop_SimplifiedInstructionsWork = prop_InstructionsWork' simplifyConstants
prop_InstructionsConsistent :: Int -> Integer -> Integer -> Integer -> Integer -> Bool
prop_InstructionsConsistent l a b x y =
let (_, instrs1) = runMath (karatsuba' a' b')
@@ -659,4 +708,5 @@ runChecks =
runQuickCheck "Mul3 Works " prop_Mul3Works
runQuickCheck "Karatsuba Works " prop_KaratsubaWorks
runQuickCheck "Instructions Work " prop_InstructionsWork
runQuickCheck "Simpl. Instructions Work " prop_SimplifiedInstructionsWork
runQuickCheck "Generation Consistent " prop_InstructionsConsistent