diff --git a/generation/Test.hs b/generation/Test.hs index bca2b0a..e2c71bd 100644 --- a/generation/Test.hs +++ b/generation/Test.hs @@ -1,5 +1,4 @@ import Data.Bits hiding (bit) -import Debug.Trace import GHC.Integer.GMP.Internals import qualified Karatsuba import Numeric diff --git a/generation/src/Multiply.hs b/generation/src/Multiply.hs index a1cdf88..b6ccce2 100644 --- a/generation/src/Multiply.hs +++ b/generation/src/Multiply.hs @@ -182,114 +182,130 @@ declareUnsafeMulOperators bitsize _ = -- ----------------------------------------------------------------------------- generateMultiplier :: Bool -> Word -> String -> String -> [Stmt Span] -generateMultiplier fullmul size inName outName = - let readIns = map (load "self" "x") [0..size-1] ++ - map (load inName "y") [0..size-1] - instructions = releaseUnnecessary outVars (generateInstructions size) - outDigits | fullmul = 2 * size - | otherwise = size - outVars = map (("res" ++) . show) [0..outDigits-1] - operations = map translateInstruction instructions - writeOuts = map (store "res") [0..outDigits-1] - in readIns ++ operations ++ writeOuts +generateMultiplier fullmul size inName outName = readIns ++ operations ++ writeOuts where - load rhs vname i = - let liti = toLit i - vec = mkIdent rhs - var = mkIdent (vname ++ show i) - in [stmt| let $$var = $$vec.value[$$(liti)]; |] - store vname i = - let liti = toLit i - vec = mkIdent outName - var = mkIdent (vname ++ show i) - in [stmt| $$vec.value[$$(liti)] = $$var; |] + outDigits | fullmul = size * 2 + | otherwise = size + -- + outVars = map (("res" ++) . show) [0..outDigits-1] + instructionData = generateInstructions size + instrOutputs = take (fromIntegral outDigits) (idOutput instructionData) + instructions = releaseUnnecessary instrOutputs (idInstructions instructionData) + -- + readIns = map (load "self" "x") [0..size-1] ++ map (load inName "y") [0..size-1] + writeOuts = map (store "res") [0..outDigits-1] + -- + env = zip (idInput1 instructionData) (map (\ i -> "x" ++ show i) [0..size-1]) ++ + zip (idInput2 instructionData) (map (\ i -> "y" ++ show i) [0..size-1]) ++ + zip (idOutput instructionData) outVars + operations = map (translateInstruction env) instructions + -- + load rhs vname i = + let liti = toLit i + vec = mkIdent rhs + var = mkIdent (vname ++ show i) + in [stmt| let $$var = $$vec.value[$$(liti)]; |] + store vname i = + let liti = toLit i + vec = mkIdent outName + var = mkIdent (vname ++ show i) + in [stmt| $$vec.value[$$(liti)] = $$var; |] -translateInstruction :: Instruction -> Stmt Span -translateInstruction instr = undefined --- case instr of --- Add outname args -> --- let outid = mkIdent outname --- args' = map (\x -> [expr| $$x |]) (map mkIdent args) --- adds = foldl (\ x y -> [expr| $$(x) + $$(y) |]) --- (head args') --- (tail args') --- in [stmt| let $$outid: u128 = $$(adds); |] --- CastDown outname arg -> --- let outid = mkIdent outname --- inid = mkIdent arg --- in [stmt| let $$outid: u64 = $$inid as u64; |] --- CastUp outname arg -> --- let outid = mkIdent outname --- inid = mkIdent arg --- in [stmt| let $$outid: u128 = $$inid as u128; |] --- Complement outname arg -> --- let outid = mkIdent outname --- inid = mkIdent arg --- in [stmt| let $$outid: u64 = !$$inid; |] --- Declare64 outname arg -> --- let outid = mkIdent outname --- val = toLit (fromIntegral arg) --- in [stmt| let $$outid: u64 = $$(val); |] --- Declare128 outname arg -> --- let outid = mkIdent outname --- val = toLit (fromIntegral arg) --- in [stmt| let $$outid: u128 = $$(val); |] --- Mask outname arg mask -> --- let outid = mkIdent outname --- inid = mkIdent arg --- val = toLit (fromIntegral mask) --- in [stmt| let $$outid: u128 = $$inid & $$(val); |] --- Multiply outname args -> --- let outid = mkIdent outname --- args' = map (\x -> [expr| $$x |]) (map mkIdent args) --- muls = foldl (\ x y -> [expr| $$(x) * $$(y) |]) --- (head args') --- (tail args') --- in [stmt| let $$outid: u128 = $$(muls); |] --- ShiftR outname arg amt -> --- let outid = mkIdent outname --- inid = mkIdent arg --- val = toLit (fromIntegral amt) --- in [stmt| let $$outid: u128 = $$inid >> $$(val); |] +translateInstruction :: [(Variable, String)] -> Instruction -> Stmt Span +translateInstruction env instr = + case instr of + Add outname args -> + let outid = mkIdentO outname + args' = map (\x -> [expr| $$x |]) (map mkIdentI args) + adds = foldl (\ x y -> [expr| $$(x) + $$(y) |]) + (head args') + (tail args') + in [stmt| let $$outid: u128 = $$(adds); |] + CastDown outname arg -> + let outid = mkIdentO outname + inid = mkIdentI arg + in [stmt| let $$outid: u64 = $$inid as u64; |] + CastUp outname arg -> + let outid = mkIdentO outname + inid = mkIdentI arg + in [stmt| let $$outid: u128 = $$inid as u128; |] + Complement outname arg -> + let outid = mkIdentO outname + inid = mkIdentI arg + in [stmt| let $$outid: u64 = !$$inid; |] + Declare64 outname _ | Just inName <- lookup outname env -> + let outid = mkIdent (variableName outname) + inid = mkIdent inName + in [stmt| let $$outid: u64 = $$inid; |] + Declare64 outname arg -> + let outid = mkIdentO outname + val = toLit (fromIntegral arg) + in [stmt| let $$outid: u64 = $$(val); |] + Declare128 outname arg -> + let outid = mkIdentO outname + val = toLit (fromIntegral arg) + in [stmt| let $$outid: u128 = $$(val); |] + Mask outname arg mask -> + let outid = mkIdentO outname + inid = mkIdentI arg + val = toLit (fromIntegral mask) + in [stmt| let $$outid: u128 = $$inid & $$(val); |] + Multiply outname args -> + let outid = mkIdentO outname + args' = map (\x -> [expr| $$x |]) (map mkIdentI args) + muls = foldl (\ x y -> [expr| $$(x) * $$(y) |]) + (head args') + (tail args') + in [stmt| let $$outid: u128 = $$(muls); |] + ShiftR outname arg amt -> + let outid = mkIdentO outname + inid = mkIdentI arg + val = toLit (fromIntegral amt) + in [stmt| let $$outid: u128 = $$inid >> $$(val); |] + where + mkIdentO :: Variable -> Ident + mkIdentO v | Just x <- lookup v env = mkIdent x + | otherwise = mkIdent (variableName v) + mkIdentI :: Variable -> Ident + mkIdentI = mkIdent . variableName -releaseUnnecessary :: [String] -> [Instruction] -> [Instruction] -releaseUnnecessary outkeys instrs = undefined --- go (Set.fromList outkeys) [] rInstrs --- where --- rInstrs = reverse instrs --- -- --- go _ acc [] = acc --- go required acc (cur:rest) --- | outVar cur `Set.member` required = --- go (foldl' (flip Set.insert) required (inVars cur)) (cur:acc) rest --- | otherwise = --- go required acc rest --- ---outVar :: Instruction -> String ---outVar instr = --- case instr of --- Add outname _ -> outname --- CastDown outname _ -> outname --- CastUp outname _ -> outname --- Complement outname _ -> outname --- Declare64 outname _ -> outname --- Declare128 outname _ -> outname --- Mask outname _ _ -> outname --- Multiply outname _ -> outname --- ShiftR outname _ _ -> outname --- ---inVars :: Instruction -> [String] ---inVars instr = --- case instr of --- Add _ args -> args --- CastDown _ arg -> [arg] --- CastUp _ arg -> [arg] --- Complement _ arg -> [arg] --- Declare64 _ _ -> [] --- Declare128 _ _ -> [] --- Mask _ arg _ -> [arg] --- Multiply _ args -> args --- ShiftR _ arg _ -> [arg] +releaseUnnecessary :: [Variable] -> [Instruction] -> [Instruction] +releaseUnnecessary outkeys instrs = go (Set.fromList outkeys) [] rInstrs + where + rInstrs = reverse instrs + -- + go _ acc [] = acc + go required acc (cur:rest) + | outVar cur `Set.member` required = + go (foldl' (flip Set.insert) required (inVars cur)) (cur:acc) rest + | otherwise = + go required acc rest + +outVar :: Instruction -> Variable +outVar instr = + case instr of + Add outname _ -> outname + CastDown outname _ -> outname + CastUp outname _ -> outname + Complement outname _ -> outname + Declare64 outname _ -> outname + Declare128 outname _ -> outname + Mask outname _ _ -> outname + Multiply outname _ -> outname + ShiftR outname _ _ -> outname + +inVars :: Instruction -> [Variable] +inVars instr = + case instr of + Add _ args -> args + CastDown _ arg -> [arg] + CastUp _ arg -> [arg] + Complement _ arg -> [arg] + Declare64 _ _ -> [] + Declare128 _ _ -> [] + Mask _ arg _ -> [arg] + Multiply _ args -> args + ShiftR _ arg _ -> [arg] -- ----------------------------------------------------------------------------- diff --git a/generation/src/RustModule.hs b/generation/src/RustModule.hs index 1a766c2..d7fe617 100644 --- a/generation/src/RustModule.hs +++ b/generation/src/RustModule.hs @@ -10,12 +10,12 @@ module RustModule( import Control.Monad(forM_, unless) import Data.Char(toUpper) -import Data.List(isPrefixOf, partition) +import Data.List(partition) import Data.Map.Strict(Map) import qualified Data.Map.Strict as Map import Data.Maybe(mapMaybe) import Language.Rust.Data.Ident(mkIdent) -import Language.Rust.Data.Position(Span, spanOf) +import Language.Rust.Data.Position(Position(NoPosition), Span(Span)) import Language.Rust.Pretty(writeSourceFile) import Language.Rust.Quote(item, sourceFile) import Language.Rust.Syntax(Item(..), SourceFile(..), Visibility(..)) @@ -64,21 +64,18 @@ generateTasks :: RandomGen g => g -> [RustModule] -> [Word] -> [Task] generateTasks rng modules sizes = allTheFiles where allTheFiles = implementationsAndTests ++ - [lump "src/signed", lump "src/unsigned"] + [lump "i" "src/signed.rs", lump "u" "src/unsigned.rs"] implementationsAndTests = concatMap generateModules sizes -- - lump prefix = - let allFiles = map outputFile implementationsAndTests - files = filter (prefix `isPrefixOf`) allFiles - moduleFiles = map (drop (length prefix + 1)) files - moduleNames = map (takeWhile (/= '.')) moduleFiles + lump prefix file = + let moduleNames = map (\s -> prefix ++ show s) sizes moduleIdents = map mkIdent moduleNames types = map (mkIdent . map toUpper) moduleNames mods = map (\ name -> [item| mod $$name; |]) moduleIdents uses = zipWith (\ mname tname -> [item| pub use $$mname::$$tname; |]) moduleIdents types - file = [sourceFile| $@{mods} $@{uses} |] - in Task (prefix ++ ".rs") (\hndl -> writeSourceFile hndl file) + source = [sourceFile| $@{mods} $@{uses} |] + in Task file (\hndl -> writeSourceFile hndl source) -- generateModules size = let modules' = filter (\m -> predicate m size sizes) modules @@ -92,18 +89,21 @@ generateTasks rng modules sizes = allTheFiles | otherwise = let name = mkIdent (startsWith ++ show size) baseInclude = [item| pub use self::base::$$name; |] - moduleSources = map (generateSubmodule size sizes) modules' - moduleFile | startsWith == "I" = "src/signed/i" ++ show size ++ ".rs" - | otherwise = "src/unsigned/u" ++ show size ++ ".rs" - allSource = SourceFile Nothing [] (baseInclude : moduleSources) - in [Task moduleFile (\ hndl -> writeSourceFile hndl allSource)] + isSigned = startsWith == "I" + moduleSources = map (generateSubmodule isSigned size sizes) modules' + moduleFile | isSigned = "src/signed/i" ++ show size ++ ".rs" + | otherwise = "src/unsigned/u" ++ show size ++ ".rs" + allSource = SourceFile Nothing [] (baseInclude : map fst moduleSources) + in [Task moduleFile (\ hndl -> writeSourceFile hndl allSource)] ++ map snd moduleSources -generateSubmodule :: Word -> [Word] -> RustModule -> Item Span -generateSubmodule size allSizes m = - let SourceFile _ attrs internals = generator m size allSizes +generateSubmodule :: Bool -> Word -> [Word] -> RustModule -> (Item Span, Task) +generateSubmodule isSigned size allSizes m = + let modBody = generator m size allSizes modName = mkIdent (outputName m) - modSpan = spanOf internals - in Mod attrs CrateV modName (Just internals) modSpan + modDecl = Mod [] CrateV modName Nothing (Span NoPosition NoPosition) + modFile | isSigned = "src/signed/i" ++ show size ++ "/" ++ outputName m ++ ".rs" + | otherwise = "src/unsigned/u" ++ show size ++ "/" ++ outputName m ++ ".rs" + in (modDecl, Task modFile (\ hndl -> writeSourceFile hndl modBody)) generateTests :: RandomGen g => Word -> g ->