Trying to limit some of the instructions we do in Karatsuba multiplication ...
This commit is contained in:
@@ -6,9 +6,12 @@
|
||||
{-# LANGUAGE TypeSynonymInstances #-}
|
||||
module Karatsuba(
|
||||
Instruction(..)
|
||||
, InstructionData(..)
|
||||
, Variable
|
||||
, runChecks
|
||||
, runQuickCheck
|
||||
, generateInstructions
|
||||
, variableName
|
||||
)
|
||||
where
|
||||
|
||||
@@ -20,31 +23,37 @@ import Data.LargeWord
|
||||
import Data.List
|
||||
import Data.Map.Strict(Map)
|
||||
import qualified Data.Map.Strict as Map
|
||||
import Data.Vector(Vector, (!?))
|
||||
import Data.Vector(Vector)
|
||||
import qualified Data.Vector as V
|
||||
import Data.Word
|
||||
import Debug.Trace
|
||||
import Prelude hiding (fail)
|
||||
import Test.QuickCheck hiding ((.&.))
|
||||
import Debug.Trace
|
||||
|
||||
-- this drives the testing
|
||||
inputWordSize :: Int
|
||||
inputWordSize = 5
|
||||
|
||||
generateInstructions :: Word -> [Instruction]
|
||||
generateInstructions numdigits =
|
||||
let (baseVec, baseInstrs) = runMath $ do x <- V.replicateM (fromIntegral numdigits) (genDigit 1)
|
||||
y <- V.replicateM (fromIntegral numdigits) (genDigit 1)
|
||||
karatsuba x y
|
||||
in baseInstrs
|
||||
data InstructionData = InstructionData {
|
||||
idInstructions :: [Instruction],
|
||||
idInput1 :: [Variable],
|
||||
idInput2 :: [Variable],
|
||||
idOutput :: [Variable]
|
||||
}
|
||||
|
||||
-- foldl replaceVar baseInstrs varRenames
|
||||
-- where
|
||||
-- x = rename "x" (V.replicate (fromIntegral numdigits) (D "" 1))
|
||||
-- y = rename "y" (V.replicate (fromIntegral numdigits) (D "" 1))
|
||||
-- (baseVec, baseInstrs) = runMath (karatsuba x y)
|
||||
-- res = rename "res" baseVec
|
||||
-- varRenames = zip (map name (V.toList baseVec)) (map name (V.toList res))
|
||||
generateInstructions :: Word -> InstructionData
|
||||
generateInstructions numdigits =
|
||||
let (baseID, baseInstrs) = runMath $ do (x, xinstrs) <- listen $ V.replicateM (fromIntegral numdigits) (genDigit 1)
|
||||
(y, yinstrs) <- listen $ V.replicateM (fromIntegral numdigits) (genDigit 1)
|
||||
res <- karatsuba x y
|
||||
return InstructionData {
|
||||
idInstructions = xinstrs ++ yinstrs,
|
||||
idInput1 = map name (V.toList x),
|
||||
idInput2 = map name (V.toList y),
|
||||
idOutput = map name (V.toList res)
|
||||
}
|
||||
(preps, reals) = splitAt (length (idInstructions baseID)) baseInstrs
|
||||
in baseID{ idInstructions = preps ++ simplifyConstants reals }
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
--
|
||||
@@ -56,6 +65,9 @@ generateInstructions numdigits =
|
||||
newtype Variable = V Word
|
||||
deriving (Eq, Ord, Show)
|
||||
|
||||
variableName :: Variable -> String
|
||||
variableName (V x) = "t" ++ show x
|
||||
|
||||
-- these are in Intel form, as I was corrupted young, so the first argument
|
||||
-- is the destination and the rest are the arguments.
|
||||
data Instruction = Add Variable [Variable]
|
||||
@@ -113,22 +125,48 @@ run env instrs =
|
||||
[] -> env
|
||||
(x:rest) -> run (step env x) rest
|
||||
|
||||
replaceVar :: [Instruction] -> (Variable, Variable) -> [Instruction]
|
||||
replaceVar ls (from, to) = map replace ls
|
||||
simplifyConstants :: [Instruction] -> [Instruction]
|
||||
simplifyConstants instrs = go instrs Map.empty Map.empty Map.empty
|
||||
where
|
||||
replace x =
|
||||
case x of
|
||||
Add outname items -> Add (sub outname) (map sub items)
|
||||
CastDown outname item -> CastDown (sub outname) (sub item)
|
||||
CastUp outname item -> CastUp (sub outname) (sub item)
|
||||
Complement outname item -> Complement (sub outname) (sub item)
|
||||
Declare64 outname val -> Declare64 (sub outname) val
|
||||
Declare128 outname val -> Declare128 (sub outname) val
|
||||
Mask outname item mask -> Mask (sub outname) (sub item) mask
|
||||
Multiply outname items -> Multiply (sub outname) (map sub items)
|
||||
ShiftR outname item amt -> ShiftR (sub outname) (sub item) amt
|
||||
sub x | x == from = to
|
||||
| otherwise = x
|
||||
go [] _ _ _ = []
|
||||
go (instr:rest) consts64 consts128 remaps =
|
||||
case instr of
|
||||
Add outname items ->
|
||||
Add outname (map (replace remaps) items) : go rest consts64 consts128 remaps
|
||||
|
||||
CastDown outname item ->
|
||||
CastDown outname (replace remaps item) : go rest consts64 consts128 remaps
|
||||
|
||||
CastUp outname item ->
|
||||
CastUp outname (replace remaps item) : go rest consts64 consts128 remaps
|
||||
|
||||
Complement outname item ->
|
||||
Complement outname (replace remaps item) : go rest consts64 consts128 remaps
|
||||
|
||||
Declare64 outname val | Just outname' <- Map.lookup val consts64 ->
|
||||
go rest consts64 consts128 (Map.insert outname outname' remaps)
|
||||
|
||||
Declare64 outname val ->
|
||||
Declare64 outname val : go rest (Map.insert val outname consts64) consts128 remaps
|
||||
|
||||
Declare128 outname val | Just outname' <- Map.lookup val consts128 ->
|
||||
go rest consts64 consts128 (Map.insert outname outname' remaps)
|
||||
|
||||
Declare128 outname val ->
|
||||
Declare128 outname val : go rest consts64 (Map.insert val outname consts128) remaps
|
||||
|
||||
Mask outname item mask ->
|
||||
Mask outname (replace remaps item) mask : go rest consts64 consts128 remaps
|
||||
|
||||
Multiply outname items ->
|
||||
Multiply outname (map (replace remaps) items) : go rest consts64 consts128 remaps
|
||||
|
||||
ShiftR outname item amt ->
|
||||
ShiftR outname (replace remaps item) amt : go rest consts64 consts128 remaps
|
||||
|
||||
replace :: Map Variable Variable -> Variable -> Variable
|
||||
replace remaps item = Map.findWithDefault item item remaps
|
||||
|
||||
|
||||
-- -----------------------------------------------------------------------------
|
||||
--
|
||||
@@ -593,14 +631,19 @@ prop_KaratsubaWorks l x y =
|
||||
y' = abs y `mod` (2 ^ (64 * l'))
|
||||
|
||||
|
||||
prop_InstructionsWork :: Int -> Integer -> Integer -> Bool
|
||||
prop_InstructionsWork l x y =
|
||||
prop_InstructionsWork' ::
|
||||
([Instruction] -> [Instruction]) ->
|
||||
Int ->
|
||||
Integer ->
|
||||
Integer ->
|
||||
Bool
|
||||
prop_InstructionsWork' f l x y =
|
||||
let (value, instructions) = runMath $ do numx <- convertTo l' x'
|
||||
numy <- convertTo l' y'
|
||||
karatsuba numx numy
|
||||
resGMP = x' * y'
|
||||
resKaratsuba = convertFrom value
|
||||
(endEnvironment, _) = run (Map.empty, Map.empty) instructions
|
||||
(endEnvironment, _) = run (Map.empty, Map.empty) (f instructions)
|
||||
instrVersion = V.map (getv endEnvironment . name) value
|
||||
in (resGMP == resKaratsuba) && (value == instrVersion)
|
||||
where
|
||||
@@ -612,6 +655,12 @@ prop_InstructionsWork l x y =
|
||||
Nothing -> error ("InstrProp lookup failure: " ++ show n)
|
||||
Just v -> D n v
|
||||
|
||||
prop_InstructionsWork :: Int -> Integer -> Integer -> Bool
|
||||
prop_InstructionsWork = prop_InstructionsWork' id
|
||||
|
||||
prop_SimplifiedInstructionsWork :: Int -> Integer -> Integer -> Bool
|
||||
prop_SimplifiedInstructionsWork = prop_InstructionsWork' simplifyConstants
|
||||
|
||||
prop_InstructionsConsistent :: Int -> Integer -> Integer -> Integer -> Integer -> Bool
|
||||
prop_InstructionsConsistent l a b x y =
|
||||
let (_, instrs1) = runMath (karatsuba' a' b')
|
||||
@@ -659,4 +708,5 @@ runChecks =
|
||||
runQuickCheck "Mul3 Works " prop_Mul3Works
|
||||
runQuickCheck "Karatsuba Works " prop_KaratsubaWorks
|
||||
runQuickCheck "Instructions Work " prop_InstructionsWork
|
||||
runQuickCheck "Simpl. Instructions Work " prop_SimplifiedInstructionsWork
|
||||
runQuickCheck "Generation Consistent " prop_InstructionsConsistent
|
||||
|
||||
Reference in New Issue
Block a user