diff --git a/generation/src/Gen.hs b/generation/src/Gen.hs index 421a24d..02bd661 100644 --- a/generation/src/Gen.hs +++ b/generation/src/Gen.hs @@ -7,6 +7,9 @@ module Gen( blank, out, wrapIndent, + implFor, + implFor', + implFor'', ) where @@ -70,4 +73,16 @@ wrapIndent val middle = res <- indent middle tell (replicate (fromIntegral (indentAmount gs)) ' ') tell "}\n" - return res \ No newline at end of file + return res + +implFor :: String -> String -> Gen a -> Gen a +implFor trait name middle = + wrapIndent ("impl " ++ trait ++ " for " ++ name) middle + +implFor' :: String -> String -> Gen a -> Gen a +implFor' trait name middle = + wrapIndent ("impl<'a> " ++ trait ++ " for " ++ name) middle + +implFor'' :: String -> String -> Gen a -> Gen a +implFor'' trait name middle = + wrapIndent ("impl<'a,'b> " ++ trait ++ " for " ++ name) middle \ No newline at end of file diff --git a/generation/src/Main.hs b/generation/src/Main.hs index 373ffbf..e7afc73 100644 --- a/generation/src/Main.hs +++ b/generation/src/Main.hs @@ -11,7 +11,7 @@ import System.Directory(createDirectoryIfMissing) import System.Environment(getArgs) import System.Exit(die) import System.FilePath(()) -import UnsignedBase(declareBaseStructure) +import UnsignedBase(declareBaseStructure,declareBinaryOperators) gatherRequirements :: [Requirement] -> Map Int [Operation] gatherRequirements = foldr process Map.empty @@ -29,4 +29,5 @@ main = createDirectoryIfMissing True basedir forM_ reqs $ \ (x, ops) -> do runGen (basedir "mod.rs") (declareBaseStructure size ops) + runGen (basedir "binary.rs") (declareBinaryOperators size) \ No newline at end of file diff --git a/generation/src/UnsignedBase.hs b/generation/src/UnsignedBase.hs index 77b4a29..6c36e72 100644 --- a/generation/src/UnsignedBase.hs +++ b/generation/src/UnsignedBase.hs @@ -1,5 +1,6 @@ module UnsignedBase( declareBaseStructure + , declareBinaryOperators ) where @@ -16,12 +17,13 @@ declareBaseStructure bitsize ops = out "use core::fmt;" out "use super::super::CryptoNum;" blank - out "mod binary_ops;" + out "mod binary;" blank + out "#[derive(Clone)]" wrapIndent ("pub struct " ++ name) $ out ("value: [u64; " ++ show entries ++ "]") blank - wrapIndent ("impl CryptoNum for " ++ name) $ + implFor "CryptoNum" name $ do wrapIndent ("fn zero() -> Self") $ out (name ++ "{ value: [0; " ++ show entries ++ "] }") blank @@ -51,39 +53,117 @@ declareBaseStructure bitsize ops = out "return false;" out "(self.value[idx] & (1u64 << offset)) != 0" blank - wrapIndent ("impl PartialEq for " ++ name) $ + implFor "PartialEq" name $ wrapIndent "fn eq(&self, other: &Self) -> bool" $ do forM_ (reverse [1..top]) $ \ i -> out ("self.value[" ++ show i ++ "] == other.value[" ++ show i ++ "] && ") out "self.value[0] == other.value[0]" blank - out ("impl Eq for " ++ name ++ " {}") + implFor "Eq" name $ return () blank - wrapIndent ("impl Ord for " ++ name) $ + implFor "Ord" name $ wrapIndent "fn cmp(&self, other: &Self) -> Ordering" $ do out ("self.value[" ++ show top ++ "].cmp(&other.value[" ++ show top ++ "])") forM_ (reverse [0..top-1]) $ \ i -> out (" .then(self.value[" ++ show i ++ "].cmp(&other.value[" ++ show i ++ "]))") blank - wrapIndent ("impl PartialOrd for " ++ name) $ + implFor "PartialOrd" name $ wrapIndent "fn partial_cmp(&self, other: &Self) -> Option" $ out "Some(self.cmp(other))" blank - wrapIndent ("impl fmt::Debug for " ++ name) $ + implFor "fmt::Debug" name $ wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $ do out ("f.debug_tuple(" ++ show name ++ ")") forM_ [0..top] $ \ i -> out (" .field(&self.value[" ++ show i ++ "])") out " .finish()" blank - wrapIndent ("impl fmt::UpperHex for " ++ name) $ + implFor "fmt::UpperHex" name $ wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $ do forM_ (reverse [1..top]) $ \ i -> out ("write!(f, \"{:X}\", self.value[" ++ show i ++ "])?;") out "write!(f, \"{:X}\", self.value[0])" blank - wrapIndent ("impl fmt::LowerHex for " ++ name) $ + implFor "fmt::LowerHex" name $ wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $ do forM_ (reverse [1..top]) $ \ i -> out ("write!(f, \"{:x}\", self.value[" ++ show i ++ "])?;") - out "write!(f, \"{:x}\", self.value[0])" \ No newline at end of file + out "write!(f, \"{:x}\", self.value[0])" + +declareBinaryOperators :: Int -> Gen () +declareBinaryOperators bitsize = + do let name = "U" ++ show bitsize + entries = bitsize `div` 64 + out "use core::ops::{BitAnd,BitAndAssign};" + out "use core::ops::{BitOr,BitOrAssign};" + out "use core::ops::{BitXor,BitXorAssign};" + out "use core::ops::Not;" + out ("use super::U" ++ show bitsize ++ ";") + blank + generateBinOps "BitAnd" name "bitand" "&=" entries + blank + generateBinOps "BitOr" name "bitor" "|=" entries + blank + generateBinOps "BitXor" name "bitxor" "^=" entries + blank + implFor "Not" name $ + do out "type Output = Self;" + blank + wrapIndent "fn not(mut self) -> Self" $ + do forM_ [0..entries-1] $ \ i -> + out ("self.value[" ++ show i ++ "] = !self.value[" ++ show i ++ "];") + out "self" + blank + implFor' "Not" ("&'a " ++ name) $ + do out ("type Output = " ++ name ++ ";") + blank + wrapIndent ("fn not(self) -> " ++ name) $ + do out "let mut output = self.clone();" + forM_ [0..entries-1] $ \ i -> + out ("output.value[" ++ show i ++ "] = !self.value[" ++ show i ++ "];") + out "output" + +generateBinOps :: String -> String -> String -> String -> Int -> Gen () +generateBinOps trait name fun op entries = + do implFor (trait ++ "Assign") name $ + wrapIndent ("fn " ++ fun ++ "_assign(&mut self, rhs: Self)") $ + forM_ [0..entries-1] $ \ i -> + out ("self.value[" ++ show i ++ "] "++op++" rhs.value[" ++ show i ++ "];") + blank + implFor' (trait ++ "Assign<&'a " ++ name ++ ">") name $ + wrapIndent ("fn " ++ fun ++ "_assign(&mut self, rhs: &Self)") $ + forM_ [0..entries-1] $ \ i -> + out ("self.value[" ++ show i ++ "] "++op++" rhs.value[" ++ show i ++ "];") + blank + generateBinOpsFromAssigns trait name fun op + +generateBinOpsFromAssigns :: String -> String -> String -> String -> Gen () +generateBinOpsFromAssigns trait name fun op = + do implFor trait name $ + do out "type Output = Self;" + blank + wrapIndent ("fn " ++ fun ++ "(mut self, rhs: Self) -> Self") $ + do out ("self " ++ op ++ " rhs;") + out "self" + blank + implFor' (trait ++ "<&'a " ++ name ++ ">") name $ + do out "type Output = Self;" + blank + wrapIndent ("fn " ++ fun ++ "(mut self, rhs: &Self) -> Self") $ + do out ("self " ++ op ++ " rhs;") + out "self" + blank + implFor' (trait ++ "<" ++ name ++ ">") ("&'a " ++ name) $ + do out ("type Output = " ++ name ++ ";") + blank + wrapIndent ("fn " ++ fun ++ "(self, mut rhs: " ++ name ++ ") -> " ++ name) $ + do out ("rhs " ++ op ++ " self;") + out "rhs" + blank + implFor'' (trait ++ "<&'a " ++ name ++ ">") ("&'b " ++ name) $ + do out ("type Output = " ++ name ++ ";") + blank + wrapIndent ("fn " ++ fun ++ "(self, rhs: &" ++ name ++ ") -> " ++ name) $ + do out "let mut output = self.clone();" + out ("output " ++ op ++ " rhs;") + out "output"