From ba587cb37f2fc74397cf8a438a1db9433675aa7d Mon Sep 17 00:00:00 2001 From: Adam Wick Date: Mon, 4 Nov 2019 17:08:16 -0800 Subject: [PATCH] Start trying to generate shift code. --- generation/generation.cabal | 2 +- generation/src/Generators.hs | 14 +++ generation/src/Main.hs | 2 + generation/src/Shift.hs | 192 +++++++++++++++++++++++++++++++++++ 4 files changed, 209 insertions(+), 1 deletion(-) create mode 100644 generation/src/Shift.hs diff --git a/generation/generation.cabal b/generation/generation.cabal index 4225bc5..2873de2 100644 --- a/generation/generation.cabal +++ b/generation/generation.cabal @@ -17,7 +17,7 @@ extra-source-files: CHANGELOG.md executable generation main-is: Main.hs - other-modules: Base, BinaryOps, Compare, Conversions, CryptoNum, File, Gen, Generators + other-modules: Base, BinaryOps, Compare, Conversions, CryptoNum, File, Gen, Generators, Shift -- other-extensions: build-depends: base >= 4.12.0.0, containers, diff --git a/generation/src/Generators.hs b/generation/src/Generators.hs index aa92f1a..3f9a474 100644 --- a/generation/src/Generators.hs +++ b/generation/src/Generators.hs @@ -10,6 +10,20 @@ generateNum g size = x' = x `mod` (2 ^ size) in (x', g') +modulate :: (Integral a, Integral b) => a -> b -> Integer +modulate x size = x' `mod` (2 ^ size') + where + x', size' :: Integer + size' = fromIntegral size + x' = fromIntegral x + +modulate' :: (Num a, Integral a, Integral b) => a -> b -> Integer +modulate' x size = signum x' * ((abs x') `mod` (2 ^ size')) + where + x', size' :: Integer + size' = fromIntegral size + x' = fromIntegral x + showX :: Integer -> String showX x | x < 0 = "-" ++ showX (abs x) | otherwise = showHex x "" diff --git a/generation/src/Main.hs b/generation/src/Main.hs index 5573a65..44d42be 100644 --- a/generation/src/Main.hs +++ b/generation/src/Main.hs @@ -8,6 +8,7 @@ import Conversions(conversions) import CryptoNum(cryptoNum) import Control.Monad(forM_,unless) import File(File,Task(..),generateTasks) +import Shift(shiftOps) import System.Directory(createDirectoryIfMissing) import System.Environment(getArgs) import System.Exit(die) @@ -31,6 +32,7 @@ unsignedFiles = [ , comparisons , conversions , cryptoNum + , shiftOps ] signedFiles :: [File] diff --git a/generation/src/Shift.hs b/generation/src/Shift.hs new file mode 100644 index 0000000..6d78c70 --- /dev/null +++ b/generation/src/Shift.hs @@ -0,0 +1,192 @@ +{-# LANGUAGE QuasiQuotes #-} +module Shift(shiftOps) + where + +import Data.Bits(shiftL,shiftR) +import Data.Map.Strict(Map) +import qualified Data.Map.Strict as Map +import File +import Gen(toLit) +import Generators +import Language.Rust.Data.Ident +import Language.Rust.Data.Position +import Language.Rust.Quote +import Language.Rust.Syntax +import System.Random(RandomGen) + +numTestCases :: Int +numTestCases = 3000 + +shiftOps :: File +shiftOps = File { + predicate = \ _ _ -> True, + outputName = "shift", + isUnsigned = True, + generator = declareShiftOperators, + testCase = Just generateTests +} + +declareShiftOperators :: Word -> SourceFile Span +declareShiftOperators bitsize = + let struct_name = mkIdent ("U" ++ show bitsize) + entries = bitsize `div` 64 + unsignedShifts = generateUnsigneds struct_name + shlUsizeImpls = generateBaseUsizes struct_name + shlActualImpl = concatMap (actualImplLines (toLit entries)) [1..entries-1] + testFileLit = Lit [] (Str (testFile bitsize) Cooked Unsuffixed mempty) mempty + in [sourceFile| + #[cfg(test)] + use core::convert::TryFrom; + use core::ops::{Shl,ShlAssign}; + use core::ops::{Shr,ShrAssign}; + #[cfg(test)] + use crate::CryptoNum; + #[cfg(test)] + use crate::testing::{build_test_path,run_test}; + use super::$$struct_name; + + impl ShlAssign for $$struct_name { + fn shl_assign(&mut self, rhs: usize) { + let copy = self.clone(); + let digits = rhs / 64; + let bits = rhs % 64; + let mask = !(0xFFFF_FFFF_FFFF_FFFF << bits); + let shift = (64 - bits) as u32; + + let base0 = if digits == 0 { copy.value[0] } else { 0 }; + self.value[0] = base0 << bits; + $@{shlActualImpl} + } + } + + impl ShrAssign for $$struct_name { + fn shr_assign(&mut self, _rhs: usize) { + panic!("base shl_assign"); + } + } + + $@{shlUsizeImpls} + $@{unsignedShifts} + + #[cfg(test)] + #[allow(non_snake_case)] + #[test] + fn KATs() { + run_test(build_test_path("shift", $$(testFileLit)), 4, |case| { + let (neg0, xbytes) = case.get("x").unwrap(); + let (neg1, sbytes) = case.get("s").unwrap(); + let (neg2, lbytes) = case.get("l").unwrap(); + let (neg3, rbytes) = case.get("r").unwrap(); + + assert!(!neg0 && !neg1 && !neg2 && !neg3); + let x = $$struct_name::from_bytes(xbytes); + let s = usize::try_from($$struct_name::from_bytes(sbytes)).unwrap(); + let l = $$struct_name::from_bytes(lbytes); + let r = $$struct_name::from_bytes(rbytes); + + assert_eq!(l, &x << s); + assert_eq!(r, &x >> s); + }); + } + |] + +actualImplLines :: Expr Span -> Word -> [Stmt Span] +actualImplLines entries i = + let basei = mkIdent ("base" ++ show i) + basei1 = mkIdent ("base" ++ show (i - 1)) + carryi = mkIdent ("carry" ++ show i) + liti = toLit i + in [ + [stmt|let $$basei = if $$(liti) > $$(entries) { + copy.value[$$(liti)-digits] + } else { + 0 + }; |] + , [stmt|let ($$carryi,_) = ($$basei1 & mask).overflowing_shl(shift); |] + , [stmt|self.value[$$(liti)] = ($$basei << bits) | $$carryi; |] + ] + +generateBaseUsizes :: Ident -> [Item Span] +generateBaseUsizes sname = + generateBaseUsize sname (mkIdent "Shl") (mkIdent "shl") (mkIdent "shl_assign") ++ + generateBaseUsize sname (mkIdent "Shr") (mkIdent "shr") (mkIdent "shr_assign") + +generateBaseUsize :: Ident -> Ident -> Ident -> Ident -> [Item Span] +generateBaseUsize sname tname sfn assign = [ + [item| + impl $$tname for $$sname { + type Output = Self; + + fn $$sfn(mut self, rhs: usize) -> $$sname { + self.$$assign(rhs); + self + } + } + |] + , [item| + impl<'a> $$tname for &'a $$sname { + type Output = $$sname; + + fn $$sfn(self, rhs: usize) -> $$sname { + let mut res = self.clone(); + res.$$assign(rhs); + res + } + } + |] + ] + + +generateUnsigneds :: Ident -> [Item Span] +generateUnsigneds sname = + concatMap (generateUnsignedImpls sname . mkIdent) ["u8","u16","u32","u64","u128"] + +generateUnsignedImpls :: Ident -> Ident -> [Item Span] +generateUnsignedImpls sname rhs = + generateBaseImpls sname (mkIdent "Shl") (mkIdent "shl") + (mkIdent "ShlAssign") (mkIdent "shl_assign") rhs ++ + generateBaseImpls sname (mkIdent "Shr") (mkIdent "shr") + (mkIdent "ShrAssign") (mkIdent "shr_assign") rhs + +generateBaseImpls :: Ident -> Ident -> Ident -> Ident -> Ident -> Ident -> [Item Span] +generateBaseImpls sname upper_shift lower_shift assign_shift lassign_shift right = [ + [item| + impl $$assign_shift<$$right> for $$sname { + fn $$lassign_shift(&mut self, rhs: $$right) { + self.$$lassign_shift(rhs as usize); + } + } + |] + , [item| + impl $$upper_shift<$$right> for $$sname { + type Output = $$sname; + + fn $$lower_shift(self, rhs: $$right) -> Self::Output { + self.$$lower_shift(rhs as usize) + } + } + |] + , [item| + impl<'a> $$upper_shift<$$right> for &'a $$sname { + type Output = $$sname; + + fn $$lower_shift(self, rhs: $$right) -> $$sname { + self.$$lower_shift(rhs as usize) + } + } + |] + ] + +generateTests :: RandomGen g => Word -> g -> [Map String String] +generateTests size g = go g numTestCases + where + go _ 0 = [] + go g0 i = + let (x, g1) = generateNum g0 size + (y, g2) = generateNum g1 size + s = y `mod` fromIntegral size + l = modulate (x `shiftL` fromIntegral s) size + r = modulate (x `shiftR` fromIntegral s) size + tcase = Map.fromList [("x", showX x), ("s", showX s), + ("l", showX l), ("r", showX r)] + in tcase : go g2 (i - 1)