diff --git a/generation/generation.cabal b/generation/generation.cabal index bf06ed1..5c5f675 100644 --- a/generation/generation.cabal +++ b/generation/generation.cabal @@ -2,7 +2,6 @@ cabal-version: 2.0 -- Initial package description 'generation.cabal' generated by 'cabal -- init'. For further documentation, see -- http://haskell.org/cabal/users-guide/ - name: generation version: 0.1.0.0 synopsis: Generates the cryptonum Rust library, based on requirements. @@ -20,10 +19,11 @@ executable generation main-is: Main.hs other-modules: Base, BinaryOps, Compare, Conversions, CryptoNum, File, Gen -- other-extensions: - build-depends: base ^>=4.12.0.0, + build-depends: base >= 4.12.0.0, containers, directory, filepath, + language-rust, mtl hs-source-dirs: src default-language: Haskell2010 diff --git a/generation/src/Base.hs b/generation/src/Base.hs index 7cbcfd2..ac70443 100644 --- a/generation/src/Base.hs +++ b/generation/src/Base.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE QuasiQuotes #-} module Base( base ) @@ -6,6 +7,11 @@ module Base( import Control.Monad(forM_) import File import Gen +import Language.Rust.Data.Ident +import Language.Rust.Data.Position +import Language.Rust.Quote +import Language.Rust.Pretty +import Language.Rust.Syntax base :: File base = File { @@ -19,38 +25,90 @@ declareBaseStructure bitsize = do let name = "U" ++ show bitsize entries = bitsize `div` 64 top = entries - 1 - out "use core::fmt;" - out "use quickcheck::{Arbitrary,Gen};" - blank - out "#[derive(Clone)]" - wrapIndent ("pub struct " ++ name) $ - out ("pub(crate) value: [u64; " ++ show entries ++ "]") - blank - implFor "fmt::Debug" name $ - wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $ - do out ("f.debug_tuple(" ++ show name ++ ")") - forM_ [0..top] $ \ i -> - out (" .field(&self.value[" ++ show i ++ "])") - out " .finish()" - blank - implFor "fmt::UpperHex" name $ - wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $ - do forM_ (reverse [1..top]) $ \ i -> - out ("write!(f, \"{:X}\", self.value[" ++ show i ++ "])?;") - out "write!(f, \"{:X}\", self.value[0])" - blank - implFor "fmt::LowerHex" name $ - wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $ - do forM_ (reverse [1..top]) $ \ i -> - out ("write!(f, \"{:x}\", self.value[" ++ show i ++ "])?;") - out "write!(f, \"{:x}\", self.value[0])" - blank - implFor "Arbitrary" name $ - wrapIndent "fn arbitrary(g: &mut G) -> Self" $ - do out (name ++ " {") - indent $ - do out ("value: [") - indent $ forM_ [0..top] $ \ _ -> - out ("g.next_u64(),") - out ("]") - out ("}") \ No newline at end of file + sname = mkIdent name + entriese = Lit [] (Int Dec (fromIntegral entries) Unsuffixed mempty) mempty + strname = Lit [] (Str name Cooked Unsuffixed mempty) mempty + debugExp = buildDebugExp 0 entries [expr| f.debug_tuple($$(strname)) |] + lowerPrints = buildPrints entries "x" + upperPrints = buildPrints entries "X" + out $ show $ pretty' $ [sourceFile| + use core::fmt; + use quickcheck::{Arbitrary,Gen}; + + #[derive(Clone)] + pub struct $$sname { + value: [u64; $$(entriese)] + } + + impl fmt::Debug for $$sname { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + $$(debugExp).finish() + } + } + + impl fmt::UpperHex for $$sname { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + $@{upperPrints} + write!(f, "{:X}", self.value[0]) + } + } + + impl fmt::LowerHex for $$sname { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + $@{lowerPrints} + write!(f, "{:x}", self.value[0]) + } + } + + impl Arbitrary for $$sname { + fn arbitrary(g: &mut G) -> Self { + let mut res = $$sname{ value: [0; $$(entriese)] }; + for entry in res.iter_mut() { + *entry = g.next_u64(); + } + res + } + } + |] + +buildDebugExp :: Word -> Word -> Expr Span -> Expr Span +buildDebugExp i top acc + | i == top = acc + | otherwise = + let liti = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty + in buildDebugExp (i + 1) top [expr| $$(acc).field(&self.value[$$(liti)]) |] + +buildPrints :: Word -> String -> [Stmt Span] +buildPrints entries printer = go (entries - 1) + where + litStr = Token mempty (LiteralTok (StrTok ("{:" ++ printer ++ "}")) Nothing) + --Lit [] (Str ("{:" ++ printer ++ "}") Cooked Unsuffixed mempty) mempty + go 0 = [] + go x = + let rest = go (x - 1) + curi = Token mempty (LiteralTok (IntegerTok (show x)) Nothing) + -- Lit [] (Int Dec (fromIntegral x) Unsuffixed mempty) mempty + cur = [stmt| write!(f, $$(litStr), self.value[$$(curi)])?; |] + in cur : rest + +-- implFor "fmt::UpperHex" name $ +-- wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $ +-- do forM_ (reverse [1..top]) $ \ i -> +-- out ("write!(f, \"{:X}\", self.value[" ++ show i ++ "])?;") +-- out "write!(f, \"{:X}\", self.value[0])" +-- blank +-- implFor "fmt::LowerHex" name $ +-- wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $ +-- do forM_ (reverse [1..top]) $ \ i -> +-- out ("write!(f, \"{:x}\", self.value[" ++ show i ++ "])?;") +-- out "write!(f, \"{:x}\", self.value[0])" +-- blank +-- implFor "Arbitrary" name $ +-- wrapIndent "fn arbitrary(g: &mut G) -> Self" $ +-- do out (name ++ " {") +-- indent $ +-- do out ("value: [") +-- indent $ forM_ [0..top] $ \ _ -> +-- out ("g.next_u64(),") +-- out ("]") +-- out ("}") diff --git a/generation/src/BinaryOps.hs b/generation/src/BinaryOps.hs index 7732b6d..05da4e9 100644 --- a/generation/src/BinaryOps.hs +++ b/generation/src/BinaryOps.hs @@ -1,11 +1,16 @@ +{-# LANGUAGE QuasiQuotes #-} module BinaryOps( binaryOps ) where -import Control.Monad(forM_) import File import Gen +import Language.Rust.Data.Ident +import Language.Rust.Data.Position +import Language.Rust.Quote +import Language.Rust.Pretty +import Language.Rust.Syntax binaryOps :: File binaryOps = File { @@ -16,130 +21,177 @@ binaryOps = File { declareBinaryOperators :: Word -> Gen () declareBinaryOperators bitsize = - do let name = "U" ++ show bitsize + do let struct_name = mkIdent ("U" ++ show bitsize) entries = bitsize `div` 64 - out "use core::ops::{BitAnd,BitAndAssign};" - out "use core::ops::{BitOr,BitOrAssign};" - out "use core::ops::{BitXor,BitXorAssign};" - out "use core::ops::Not;" - out "#[cfg(test)]" - out "use crate::CryptoNum;" - out "#[cfg(test)]" - out "use quickcheck::quickcheck;" - out ("use super::U" ++ show bitsize ++ ";") - blank - generateBinOps "BitAnd" name "bitand" "&=" entries - blank - generateBinOps "BitOr" name "bitor" "|=" entries - blank - generateBinOps "BitXor" name "bitxor" "^=" entries - blank - implFor "Not" name $ - do out "type Output = Self;" - blank - wrapIndent "fn not(mut self) -> Self" $ - do forM_ [0..entries-1] $ \ i -> - out ("self.value[" ++ show i ++ "] = !self.value[" ++ show i ++ "];") - out "self" - blank - implFor' "Not" ("&'a " ++ name) $ - do out ("type Output = " ++ name ++ ";") - blank - wrapIndent ("fn not(self) -> " ++ name) $ - do out "let mut output = self.clone();" - forM_ [0..entries-1] $ \ i -> - out ("output.value[" ++ show i ++ "] = !self.value[" ++ show i ++ "];") - out "output" - blank - addBinaryLaws name + andOps = generateBinOps "BitAnd" struct_name "bitand" BitAndOp entries + orOps = generateBinOps "BitOr" struct_name "bitor" BitOrOp entries + xorOps = generateBinOps "BitXor" struct_name "bitxor" BitXorOp entries + baseNegationStmts = negationStatements "self" entries + refNegationStmts = negationStatements "output" entries + out $ show $ pretty' $ [sourceFile| + use core::ops::{BitAnd,BitAndAssign}; + use core::ops::{BitOr,BitOrAssign}; + use core::ops::{BitXor,BitXorAssign}; + use core::ops::Not; + #[cfg(test)] + use crate::CryptoNum; + #[cfg(test)] + use quickcheck::quickcheck; + use super::$$struct_name; -generateBinOps :: String -> String -> String -> String -> Word -> Gen () -generateBinOps trait name fun op entries = - do implFor (trait ++ "Assign") name $ - wrapIndent ("fn " ++ fun ++ "_assign(&mut self, rhs: Self)") $ - forM_ [0..entries-1] $ \ i -> - out ("self.value[" ++ show i ++ "] "++op++" rhs.value[" ++ show i ++ "];") - blank - implFor' (trait ++ "Assign<&'a " ++ name ++ ">") name $ - wrapIndent ("fn " ++ fun ++ "_assign(&mut self, rhs: &Self)") $ - forM_ [0..entries-1] $ \ i -> - out ("self.value[" ++ show i ++ "] "++op++" rhs.value[" ++ show i ++ "];") - blank - generateBinOpsFromAssigns trait name fun op + $@{andOps} + $@{orOps} + $@{xorOps} -generateBinOpsFromAssigns :: String -> String -> String -> String -> Gen () -generateBinOpsFromAssigns trait name fun op = - do implFor trait name $ - do out "type Output = Self;" - blank - wrapIndent ("fn " ++ fun ++ "(mut self, rhs: Self) -> Self") $ - do out ("self " ++ op ++ " rhs;") - out "self" - blank - implFor' (trait ++ "<&'a " ++ name ++ ">") name $ - do out "type Output = Self;" - blank - wrapIndent ("fn " ++ fun ++ "(mut self, rhs: &Self) -> Self") $ - do out ("self " ++ op ++ " rhs;") - out "self" - blank - implFor' (trait ++ "<" ++ name ++ ">") ("&'a " ++ name) $ - do out ("type Output = " ++ name ++ ";") - blank - wrapIndent ("fn " ++ fun ++ "(self, mut rhs: " ++ name ++ ") -> " ++ name) $ - do out ("rhs " ++ op ++ " self;") - out "rhs" - blank - implFor'' (trait ++ "<&'a " ++ name ++ ">") ("&'b " ++ name) $ - do out ("type Output = " ++ name ++ ";") - blank - wrapIndent ("fn " ++ fun ++ "(self, rhs: &" ++ name ++ ") -> " ++ name) $ - do out "let mut output = self.clone();" - out ("output " ++ op ++ " rhs;") - out "output" + impl Not for $$struct_name { + type Output = Self; -addBinaryLaws :: String -> Gen () -addBinaryLaws name = - do let args3 = "(a: " ++ name ++ ", b: " ++ name ++ ", c: " ++ name ++ ")" - args2 = "(a: " ++ name ++ ", b: " ++ name ++ ")" - out "#[cfg(test)]" - wrapIndent "quickcheck!" $ - do wrapIndent ("fn and_associative" ++ args3 ++ " -> bool") $ - out ("((&a & &b) & &c) == (&a & (&b & &c))") - blank - wrapIndent ("fn and_commutative" ++ args2 ++ " -> bool") $ - out ("(&a & &b) == (&b & &a)") - blank - wrapIndent ("fn and_idempotent" ++ args2 ++ " -> bool") $ - out ("(&a & &b) == (&a & &b & &a)") - blank - wrapIndent ("fn xor_associative" ++ args3 ++ " -> bool") $ - out ("((&a ^ &b) ^ &c) == (&a ^ (&b ^ &c))") - blank - wrapIndent ("fn xor_commutative" ++ args2 ++ " -> bool") $ - out ("(&a ^ &b) == (&b ^ &a)") - blank - wrapIndent ("fn or_associative" ++ args3 ++ " -> bool") $ - out ("((&a | &b) | &c) == (&a | (&b | &c))") - blank - wrapIndent ("fn or_commutative" ++ args2 ++ " -> bool") $ - out ("(&a | &b) == (&b | &a)") - blank - wrapIndent ("fn or_idempotent" ++ args2 ++ " -> bool") $ - out ("(&a | &b) == (&a | &b | &a)") - blank - wrapIndent ("fn and_or_distribution" ++ args3 ++ "-> bool") $ - out ("(&a & (&b | &c)) == ((&a & &b) | (&a & &c))") - blank - wrapIndent ("fn xor_clears(a: " ++ name ++ ") -> bool") $ - out (name ++ "::zero() == (&a ^ &a)") - blank - wrapIndent ("fn double_neg_ident(a: " ++ name ++ ") -> bool") $ - out ("a == !!&a") - blank - wrapIndent ("fn and_ident(a: " ++ name ++ ") -> bool") $ - do out ("let ones = !" ++ name ++ "::zero();") - out ("(&a & &ones) == a") - blank - wrapIndent ("fn or_ident(a: " ++ name ++ ") -> bool") $ - out ("(&a | " ++ name ++ "::zero()) == a") \ No newline at end of file + fn not(mut self) -> Self { + $@{baseNegationStmts} + self + } + } + + impl<'a> Not for &'a $$struct_name { + type Output = Self; + + fn not(self) -> Self { + let mut output = self.clone(); + $@{refNegationStmts} + output + } + } + + quickcheck! { + fn and_associative(a: $$struct_name, b: $$struct_name, c: $$struct_name) -> bool { + ((&a & &b) & &c) == (&a & (&b & &c)) + } + fn and_commutative(a: $$struct_name, b: $$struct_name) -> bool { + (&a & &b) == (&b & &a) + } + fn and_idempotent(a: $$struct_name, b: $$struct_name) -> bool { + (&a & &b) == (&a & &b & &a) + } + + fn xor_associative(a: $$struct_name, b: $$struct_name, c: $$struct_name) -> bool { + ((&a ^ &b) ^ &c) == (&a ^ (&b ^ &c)) + } + fn xor_commutative(a: $$struct_name, b: $$struct_name) -> bool { + (&a ^ &b) == (&b ^ &a) + } + + fn or_associative(a: $$struct_name, b: $$struct_name, c: $$struct_name) -> bool { + ((&a | &b) & &c) == (&a | (&b | &c)) + } + fn or_commutative(a: $$struct_name, b: $$struct_name) -> bool { + (&a | &b) == (&b | &a) + } + fn or_idempotent(a: $$struct_name, b: $$struct_name) -> bool { + (&a | &b) == (&a | &b | &a) + } + + fn and_or_distribution(a: $$struct_name, b: $$struct_name, c: $$struct_name) -> bool { + (&a & (&b | &c)) == ((&a & &b) | (&a & &c)) + } + fn xor_clears(a: $$struct_name) -> bool { + $$struct_name::zero() == (&a ^ *a) + } + fn double_neg_ident(a: $$struct_name) -> bool { + a == !!$a + } + fn and_ident(a: $$struct_name) -> bool { + let ones = !$$struct_name::zero(); + (&a & &ones) == a + } + fn or_ident(a: $$struct_name) -> bool { + (&a | $$struct_name::zero()) == a + } + } + |] + +negationStatements :: String -> Word -> [Stmt Span] +negationStatements target entries = map genStatement [0..entries-1] + where + genStatement i = + let idx = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty + v = mkIdent target + in [stmt| $$v.value[$$(idx)] = !self.value[$$(idx)]; |] + +generateBinOps :: String -> Ident -> String -> BinOp -> Word -> [Item Span] +generateBinOps trait sname func oper entries = + [normAssign, refAssign] ++ generateAllTheVariants traitIdent funcIdent sname oper + where + traitIdent = mkIdent trait + assignIdent = mkIdent (trait ++ "Assign") + funcIdent = mkIdent func + funcAssignIdent = mkIdent (func ++ "_assign") + -- + normAssign = [item| + impl $$assignIdent for $$sname { + fn $$funcAssignIdent(&mut self, rhs: Self) { + $@{assignStatements} + } + } + |] + refAssign = [item| + impl $$assignIdent<&'a $$sname> for $$sname { + fn $$funcAssignIdent(&mut self, rhs: &Self) { + $@{assignStatements} + } + } + |] + -- + assignStatements :: [Stmt Span] + assignStatements = map genAssign [0..entries-1] + genAssign i = + let idx = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty + left = [expr| self.value[$$(idx)] |] + right = [expr| rhs.value[$$(idx)] |] + in Semi (AssignOp [] oper left right mempty) mempty + +generateAllTheVariants :: Ident -> Ident -> Ident -> BinOp -> [Item Span] +generateAllTheVariants traitname func sname oper = [ + [item| + impl $$traitname for $$sname { + type Output = Self; + + fn $$func(mut self, rhs: Self) -> Self { + $${assigner_self_rhs} + self + } + }|] + , [item| + impl<'a> $$traitname<&'a $$sname> for $$sname { + type Output = Self; + + fn $$func(mut self, rhs: Self) -> Self { + $${assigner_self_rhs} + self + } + }|] + , [item| + impl<'a> $$traitname for &'a $$sname { + type Output = Self; + + fn $$func(mut self, rhs: Self) -> Self { + $${assigner_rhs_self} + self + } + }|] + , [item| + impl<'a,'b> $$traitname<&'a $$sname> for &'b $$sname { + type Output = Self; + + fn $$func(mut self, rhs: Self) -> Self { + let mut out = self.clone(); + $${assigner_out_rhs} + out + } + }|] + ] + where + assigner_self_rhs = assigner [expr| self |] [expr| rhs |] + assigner_rhs_self = assigner [expr| rhs |] [expr| self |] + assigner_out_rhs = assigner [expr| out |] [expr| rhs |] + assigner left right = + Semi (AssignOp [] oper left right mempty) mempty diff --git a/generation/src/Compare.hs b/generation/src/Compare.hs index 79d9d9b..24169db 100644 --- a/generation/src/Compare.hs +++ b/generation/src/Compare.hs @@ -1,9 +1,14 @@ +{-# LANGUAGE QuasiQuotes #-} module Compare(comparisons) where -import Control.Monad(forM_) import File import Gen +import Language.Rust.Data.Ident +import Language.Rust.Data.Position +import Language.Rust.Quote +import Language.Rust.Pretty +import Language.Rust.Syntax comparisons :: File comparisons = File { @@ -14,47 +19,79 @@ comparisons = File { declareComparators :: Word -> Gen () declareComparators bitsize = - do let name = "U" ++ show bitsize + do let sname = mkIdent ("U" ++ show bitsize) entries = bitsize `div` 64 - top = entries - 1 - out "use core::cmp::{Eq,Ordering,PartialEq};" - out "#[cfg(test)]" - out "use quickcheck::quickcheck;" - out ("use super::" ++ name ++ ";") - blank - implFor "PartialEq" name $ - wrapIndent "fn eq(&self, other: &Self) -> bool" $ - do forM_ (reverse [1..top]) $ \ i -> - out ("self.value[" ++ show i ++ "] == other.value[" ++ show i ++ "] && ") - out "self.value[0] == other.value[0]" - blank - implFor "Eq" name $ return () - blank - implFor "Ord" name $ - wrapIndent "fn cmp(&self, other: &Self) -> Ordering" $ - do out ("self.value[" ++ show top ++ "].cmp(&other.value[" ++ show top ++ "])") - forM_ (reverse [0..top-1]) $ \ i -> - out (" .then(self.value[" ++ show i ++ "].cmp(&other.value[" ++ show i ++ "]))") - blank - implFor "PartialOrd" name $ - wrapIndent "fn partial_cmp(&self, other: &Self) -> Option" $ - out "Some(self.cmp(other))" - blank - out "#[cfg(test)]" - wrapIndent "quickcheck!" $ - do let transFun n = "fn " ++ n ++ "(a: " ++ name ++ ", b: " ++ name ++ - ", c: " ++ name ++ ") -> bool" - wrapIndent (transFun "eq_is_transitive") $ - out ("if a == c { a == b && b == c } else { a != b || b != c }") - blank - wrapIndent (transFun "gt_is_transitive") $ - out ("if a > b && b > c { a > c } else { true }") - blank - wrapIndent (transFun "ge_is_transitive") $ - out ("if a >= b && b >= c { a >= c } else { true }") - blank - wrapIndent (transFun "lt_is_transitive") $ - out ("if a < b && b < c { a < c } else { true }") - blank - wrapIndent (transFun "le_is_transitive") $ - out ("if a <= b && b <= c { a <= c } else { true }") \ No newline at end of file + eqStatements = buildEqStatements 0 entries + compareExp = buildCompareExp 0 entries + out $ show $ pretty' $ [sourceFile| + use core::cmp::{Eq,Ordering,PartialEq}; + #[cfg(test)] + use quickcheck::quickcheck; + use super::$$sname; + + impl PartialEq for $$sname { + fn eq(&self, other: &Self) -> bool { + let mut out = true; + $@{eqStatements} + out + } + } + + impl Eq for $$sname {} + + impl Ord for $$sname { + fn cmp(&self, other: &Self) -> Ordering { + $$(compareExp) + } + } + + impl PartialOrd for $$sname { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + #[cfg(test)] + quickcheck! { + fn eq_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool { + if a == c { a == b && b == c } else { a != b || b != c } + } + + fn gt_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool { + if a > b && b > c { a > c } else { true } + } + + fn ge_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool { + if a >= b && b >= c { a >= c } else { true } + } + + fn lt_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool { + if a < b && b < c { a < c } else { true } + } + + fn le_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool { + if a <= b && b <= c { a <= c } else { true } + } + } + |] + +buildEqStatements :: Word -> Word -> [Stmt Span] +buildEqStatements i numEntries + | i == (numEntries - 1) = + [[stmt| out &= self.value[$$(x)] == other.value[$$(x)]; |]] + | otherwise = + let rest = buildEqStatements (i + 1) numEntries + cur = [stmt| out &= self.value[$$(x)] == other.value[$$(x)]; |] + in cur:rest + where + x = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty + +buildCompareExp :: Word -> Word -> Expr Span +buildCompareExp i numEntries + | i == (numEntries - 1) = + [expr| self.value[$$(x)].cmp(&other.value[$$(x)]) |] + | otherwise = + let rest = buildCompareExp (i + 1) numEntries + in [expr| $$(rest).then(self.value[$$(x)].cmp(&other.value[$$(x)])) |] + where + x = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty diff --git a/generation/src/Conversions.hs b/generation/src/Conversions.hs index 72f3fde..d9afbcc 100644 --- a/generation/src/Conversions.hs +++ b/generation/src/Conversions.hs @@ -1,11 +1,16 @@ +{-# LANGUAGE QuasiQuotes #-} module Conversions( conversions ) where -import Data.List(intercalate) import File -import Gen +import Gen(Gen,toLit,out) +import Language.Rust.Data.Ident +import Language.Rust.Data.Position +import Language.Rust.Quote +import Language.Rust.Pretty +import Language.Rust.Syntax conversions :: File conversions = File { @@ -16,83 +21,203 @@ conversions = File { declareConversions :: Word -> Gen () declareConversions bitsize = - do let name = "U" ++ show bitsize - entries = bitsize `div` 64 - out "use core::convert::{From,TryFrom};" - out "#[cfg(test)]" - out "use quickcheck::quickcheck;" - out ("use super::" ++ name ++ ";") - blank - buildUnsignedPrimConversions name entries "u8" >> blank - buildUnsignedPrimConversions name entries "u16" >> blank - buildUnsignedPrimConversions name entries "u32" >> blank - buildUnsignedPrimConversions name entries "u64" >> blank - buildUnsignedPrimConversions name entries "usize" >> blank - buildSignedPrimConversions name entries "i8" >> blank - buildSignedPrimConversions name entries "i16" >> blank - buildSignedPrimConversions name entries "i32" >> blank - buildSignedPrimConversions name entries "i64" >> blank - buildSignedPrimConversions name entries "isize" - blank - out ("#[cfg(test)]") - wrapIndent "quickcheck!" $ - do roundTripTest name "u8" >> blank - roundTripTest name "u16" >> blank - roundTripTest name "u32" >> blank - roundTripTest name "u64" >> blank - roundTripTest name "usize" + do let sname = mkIdent ("U" ++ show bitsize) + entries = bitsize `div` 64 + u8_prims = buildPrimitives sname (mkIdent "u8") entries + u16_prims = buildPrimitives sname (mkIdent "u16") entries + u32_prims = buildPrimitives sname (mkIdent "u32") entries + u64_prims = buildPrimitives sname (mkIdent "u64") entries + u128_prims = generateU128Primitives sname entries + i8_prims = generateSignedPrims sname (mkIdent "u8") (mkIdent "i8") + i16_prims = generateSignedPrims sname (mkIdent "u16") (mkIdent "i16") + i32_prims = generateSignedPrims sname (mkIdent "u32") (mkIdent "i32") + i64_prims = generateSignedPrims sname (mkIdent "u64") (mkIdent "i64") + i128_prims = generateI128Primitives sname + out $ show $ pretty' $ [sourceFile| + use core::convert::{From,TryFrom}; + use core::num::TryFromIntError; + #[cfg(test)] + use quickcheck::quickcheck; + use super::$$sname; + use crate::ConversionError; -buildUnsignedPrimConversions :: String -> Word -> String -> Gen () -buildUnsignedPrimConversions name entries primtype = - do implFor ("From<" ++ primtype ++ ">") name $ - wrapIndent ("fn from(x: " ++ primtype ++ ") -> Self") $ - do let zeroes = replicate (fromIntegral (entries - 1)) "0," - values = ("x as u64," : zeroes) - out (name ++ " { value: [ ") - indent $ printBy 8 values - out ("] }") - blank - implFor ("From<" ++ name ++ ">") primtype $ - wrapIndent ("fn from(x: " ++ name ++ ") -> Self") $ - out ("x.value[0] as " ++ primtype) - blank - implFor' ("From<&'a " ++ name ++ ">") primtype $ - wrapIndent ("fn from(x: &" ++ name ++ ") -> Self") $ - out ("x.value[0] as " ++ primtype) + $@{u8_prims} + $@{u16_prims} + $@{u32_prims} + $@{u64_prims} + $@{u128_prims} -buildSignedPrimConversions :: String -> Word -> String -> Gen () -buildSignedPrimConversions name entries primtype = - do implFor ("TryFrom<" ++ primtype ++ ">") name $ - do out ("type Error = &'static str;") - blank - wrapIndent ("fn try_from(x: " ++ primtype ++ ") -> Result") $ - do wrapIndent ("if x < 0") $ - out ("return Err(\"Attempt to convert negative number to " ++ - name ++ ".\");") - blank - let zeroes = replicate (fromIntegral (entries - 1)) "0," - values = ("x as u64," : zeroes) - out ("Ok(" ++ name ++ " { value: [ ") - indent $ printBy 8 values - out ("] })") - blank - implFor ("From<" ++ name ++ ">") primtype $ - wrapIndent ("fn from(x: " ++ name ++ ") -> Self") $ - out ("x.value[0] as " ++ primtype) - blank - implFor' ("From<&'a " ++ name ++ ">") primtype $ - wrapIndent ("fn from(x: &" ++ name ++ ") -> Self") $ - out ("x.value[0] as " ++ primtype) + $@{i8_prims} + $@{i16_prims} + $@{i32_prims} + $@{i64_prims} + $@{i128_prims} + |] -roundTripTest :: String -> String -> Gen () -roundTripTest name primtype = - wrapIndent ("fn " ++ primtype ++ "_roundtrips(x: " ++ primtype ++ ") -> bool") $ - do out ("let big = " ++ name ++ "::from(x);"); - out ("let small = " ++ primtype ++ "::from(big);") - out ("x == small") +generateU128Primitives :: Ident -> Word -> [Item Span] +generateU128Primitives sname entries = [ + [item|impl From for $$sname { + fn from(x: u128) -> Self { + let mut res = $$sname::zero; + res[0] = x as u64; + res[1] = (x >> 64) as u64; + res + } + }|] + , [item|impl TryFrom<$$sname> for u128 { + type Error = ConversionError; -printBy :: Int -> [String] -> Gen () -printBy amt xs - | length xs <= amt = out (intercalate " " xs) - | otherwise = printBy amt (take amt xs) >> - printBy amt (drop amt xs) \ No newline at end of file + fn try_from(x: $$sname) -> Result { + let mut goodConversion = true; + let mut res = 0; + + res = (x.values[1] as u128) << 64; + res |= x.values[0] as u128; + + $@{testZeros} + if goodConversion { + Ok(res) + } else { + Err(ConversionError::Overflow); + } + } + }|] + , [item|impl<'a> TryFrom<&'a $$sname> for u128 { + type Error = ConversionError; + + fn try_from(x: &$$sname) -> Result { + let mut goodConversion = true; + let mut res = 0; + + res = (x.values[1] as u128) << 64; + res |= x.values[0] as u128; + + $@{testZeros} + if goodConversion { + Ok(res) + } else { + Err(ConversionError::Overflow()); + } + } + }|] + ] + where + testZeros = map (zeroTest . toLit) [2..entries-1] + zeroTest i = + [stmt| goodConversion &= x.values[$$(i)] == 0; |] + +buildPrimitives :: Ident -> Ident -> Word -> [Item Span] +buildPrimitives sname tname entries = [ + [item|impl From<$$tname> for $$sname { + fn from(x: $$tname) -> Self { + let mut res = $$sname::zero(); + res.values[0] = x as u64; + res + } + }|] + , [item|impl TryFrom<$$sname> for $$tname { + type Error = ConversionError; + + fn try_from(x: $$sname) -> Result { + let mut goodConversion = true; + let mut res = 0; + + res = x.values[0] as $$tname; + + $@{testZeros} + if goodConversion { + Ok(res) + } else { + Err(ConversionError::Overflow) + } + } + }|] + , [item|impl<'a> TryFrom<&'a $$sname> for $$tname { + type Error = ConversionError; + + fn try_from(x: &$$sname) -> Result { + let mut goodConversion = true; + let mut res = 0; + + res = x.values[0] as $$tname; + + $@{testZeros} + if goodConversion { + Ok(res) + } else { + Err(ConversionError::Overflow) + } + } + }|] + ] + where + testZeros = map (zeroTest . toLit) [1..entries-1] + zeroTest i = + [stmt| goodConversion &= x.values[$$(i)] == 0; |] + +generateSignedPrims :: Ident -> Ident -> Ident -> [Item Span] +generateSignedPrims sname unsigned signed = [ + [item|impl TryFrom<$$signed> for $$sname { + type Error = ConversionError; + + fn try_from(x: $$signed) -> Result { + let mut res = $$sname::zero(); + res.values[0] = x as u64; + if x < 0 { + Err(ConversionError::NegativeToUnsigned) + } else { + Ok(res) + } + } + }|] + , [item|impl TryFrom<$$sname> for $$signed { + type Error = ConversionError; + + fn try_from(x: $$sname) -> Result { + let uns = $$unsigned::from(x)?; + Ok($$signed::try_from(uns)?) + } + }|] + , [item|impl<'a> TryFrom<&'a $$sname> for $$signed { + type Error = ConversionError; + + fn try_from(x: &$$sname) -> Result { + let uns = $$unsigned::from(x)?; + Ok($$signed::try_from(uns)?) + } + }|] + ] + +generateI128Primitives :: Ident -> [Item Span] +generateI128Primitives sname = [ + [item|impl TryFrom for $$sname { + type Error = ConversionError; + + fn try_from(x: i128) -> Result { + let mut res = $$sname::zero(); + res.values[0] = x as u64; + res.values[1] = ((x as u128) >> 64) as u64; + if x < 0 { + Err(ConversionError::NegativeToUnsigned) + } else { + Ok(res) + } + } + }|] + , [item|impl TryFrom<$$sname> for i128 { + type Error = ConversionError; + + fn try_from(x: $$sname) -> Result { + let uns = u128::from(x)?; + Ok(i128::try_from(uns)?) + } + }|] + , [item|impl<'a> TryFrom<&'a $$sname> for i128 { + type Error = ConversionError; + + fn try_from(x: &$$sname) -> Result { + let uns = u128::from(x)?; + Ok(i128::try_from(uns)?) + } + }|] + ] diff --git a/generation/src/CryptoNum.hs b/generation/src/CryptoNum.hs index f7b2731..3f77073 100644 --- a/generation/src/CryptoNum.hs +++ b/generation/src/CryptoNum.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE QuasiQuotes #-} module CryptoNum( cryptoNum ) @@ -6,6 +7,11 @@ module CryptoNum( import Control.Monad(forM_) import File import Gen +import Language.Rust.Data.Ident +import Language.Rust.Data.Position +import Language.Rust.Quote +import Language.Rust.Pretty +import Language.Rust.Syntax cryptoNum :: File cryptoNum = File { @@ -16,125 +22,125 @@ cryptoNum = File { declareCryptoNumInstance :: Word -> Gen () declareCryptoNumInstance bitsize = - do let name = "U" ++ show bitsize + do let sname = mkIdent ("U" ++ show bitsize) entries = bitsize `div` 64 + entlit = Lit [] (Int Dec (fromIntegral entries) Unsuffixed mempty) mempty top = entries - 1 - out "use core::cmp::min;" - out "use crate::CryptoNum;" - out "#[cfg(test)]" - out "use crate::testing::{build_test_path,run_test};" - out "#[cfg(test)]" - out "use quickcheck::quickcheck;" - out ("use super::" ++ name ++ ";") - blank - implFor "CryptoNum" name $ - do wrapIndent ("fn zero() -> Self") $ - out (name ++ "{ value: [0; " ++ show entries ++ "] }") - blank - wrapIndent ("fn is_zero(&self) -> bool") $ - do forM_ (reverse [1..top]) $ \ i -> - out ("self.value[" ++ show i ++ "] == 0 &&") - out "self.value[0] == 0" - blank - wrapIndent ("fn is_even(&self) -> bool") $ - out "self.value[0] & 0x1 == 0" - blank - wrapIndent ("fn is_odd(&self) -> bool") $ - out "self.value[0] & 0x1 == 1" - blank - wrapIndent ("fn bit_length() -> usize") $ - out (show bitsize) - blank - wrapIndent ("fn mask(&mut self, len: usize)") $ - do out ("let dellen = min(len, " ++ show entries ++ ");") - wrapIndent ("for i in dellen.." ++ show entries) $ - out ("self.value[i] = 0;") - blank - wrapIndent ("fn testbit(&self, bit: usize) -> bool") $ - do out "let idx = bit / 64;" - out "let offset = bit % 64;" - wrapIndent ("if idx >= " ++ show entries) $ - out "return false;" - out "(self.value[idx] & (1u64 << offset)) != 0" - blank - wrapIndent ("fn from_bytes(bytes: &[u8]) -> Self") $ - do out ("let biggest = min(" ++ show (bitsize `div` 8) ++ ", " ++ - "bytes.len()) - 1;") - out ("let mut idx = biggest / 8;") - out ("let mut shift = (biggest % 8) * 8;") - out ("let mut i = 0;") - out ("let mut res = " ++ name ++ "::zero();") - blank - wrapIndent ("while i <= biggest") $ - do out ("res.value[idx] |= (bytes[i] as u64) << shift;") - out ("i += 1;") - out ("if shift == 0 {") - indent $ - do out "shift = 56;" - out "if idx > 0 { idx -= 1; }" - out ("} else {") - indent $ - out "shift -= 8;" - out "}" - blank - out "res" - blank - wrapIndent ("fn to_bytes(&self, bytes: &mut [u8])") $ - do let bytes = bitsize `div` 8 - out ("if bytes.len() == 0 { return; }") - blank - forM_ [0..bytes-1] $ \ idx -> - do let (validx, shift) = byteShiftInfo idx - out ("let byte" ++ show idx ++ " = (self.value[" ++ - show validx ++ "] >> " ++ show shift ++ ")" ++ - " as u8;") - blank - out ("let mut idx = min(bytes.len() - 1, " ++ show (bytes - 1) ++ ");") - forM_ [0..bytes-2] $ \ i -> - do out ("bytes[idx] = byte" ++ show i ++ ";") - out ("if idx == 0 { return; }") - out ("idx -= 1;") - out ("bytes[idx] = byte" ++ show (bytes-1) ++ ";") - blank - out "#[cfg(test)]" - wrapIndent "quickcheck!" $ - do wrapIndent ("fn to_from_ident(x: " ++ name ++ ") -> bool") $ - do out ("let mut buffer = [0; " ++ show (bitsize `div` 8) ++ "];") - out ("x.to_bytes(&mut buffer);"); - out ("let y = " ++ name ++ "::from_bytes(&buffer);") - out ("x == y") - blank - out "#[cfg(test)]" - out "#[allow(non_snake_case)]" - out "#[test]" - wrapIndent "fn KATs()" $ - do let name' = pad 5 '0' (show bitsize) - out ("run_test(build_test_path(\"base\",\"" ++ name' ++ "\"), 8, |case| {") - indent $ - do out ("let (neg0, xbytes) = case.get(\"x\").unwrap();") - out ("let (neg1, mbytes) = case.get(\"m\").unwrap();") - out ("let (neg2, zbytes) = case.get(\"z\").unwrap();") - out ("let (neg3, ebytes) = case.get(\"e\").unwrap();") - out ("let (neg4, obytes) = case.get(\"o\").unwrap();") - out ("let (neg5, rbytes) = case.get(\"r\").unwrap();") - out ("let (neg6, bbytes) = case.get(\"b\").unwrap();") - out ("let (neg7, tbytes) = case.get(\"t\").unwrap();") - out ("assert!(!neg0&&!neg1&&!neg2&&!neg3&&!neg4&&!neg5&&!neg6&&!neg7);") - out ("let mut x = "++name++"::from_bytes(xbytes);") - out ("let m = "++name++"::from_bytes(mbytes);") - out ("let z = 1 == zbytes[0];") - out ("let e = 1 == ebytes[0];") - out ("let o = 1 == obytes[0];") - out ("let r = "++name++"::from_bytes(rbytes);") - out ("let b = usize::from("++name++"::from_bytes(bbytes));") - out ("let t = 1 == tbytes[0];") - out ("assert_eq!(x.is_zero(), z);") - out ("assert_eq!(x.is_even(), e);") - out ("assert_eq!(x.is_odd(), o);") - out ("assert_eq!(x.testbit(b), t);") - out ("x.mask(usize::from(&m));") - out ("assert_eq!(x, r);") - out ("});") + zeroTests = generateZeroTests 0 entries + bitlength = toLit bitsize + bytelen = bitsize `div` 8 + bytelenlit = toLit bytelen + bytebuffer = Delimited mempty Brace (Stream [ + Tree (Token mempty (LiteralTok (IntegerTok "0") Nothing)), + Tree (Token mempty Semicolon), + Tree (Token mempty (LiteralTok (IntegerTok (show bytelen)) Nothing)) + ]) + entrieslit = toLit entries + packerLines = generatePackerLines 0 (bitsize `div` 8) + out $ show $ pretty' $ [sourceFile| + use core::cmp::min; + use crate::CryptoNum; + #[cfg(test)] + use crate::testing::{build_test_path,run_test}; + #[cfg(test)] + use quickcheck::quickcheck; + use super::$$sname; + + impl CryptoNum for $$sname { + fn zero() -> Self { + $$sname{ value: [0; $$(entlit)] } + } + fn is_zero(&self) -> bool { + let mut result = true; + $@{zeroTests} + result + } + fn is_even(&self) -> bool { + self.value[0] & 0x1 == 0 + } + fn is_off(&self) -> bool { + self.value[0] & 0x1 == 1 + } + fn bit_length() -> usize { + $$(bitlength) + } + fn mask(&mut self, len: usize) { + let dellen = min(len, $$(entrieslit)); + for i in dellen..$$(entrieslit) { + self.value[i] = 0; + } + } + fn testbit(&self, bit: usize) -> bool { + let idx = bit / 64; + let offset = bit % 64; + if idx >= $$(entrieslit) { + return false; + } + (self.value[idx] & (1u64 << offset)) != 0 + } + fn from_bytes(bytes: &[u8]) -> Self { + let biggest = min($$(bytelenlit), bytes.len()) - 1; + let mut idx = biggest / 8; + let mut shift = (biggest % 8) * 8; + let mut i = 0; + let mut res = $$sname::zero(); + + while i <= biggest { + res.value[idx] |= (bytes[i] as u64) << shift; + i += 1; + if shift == 0 { + shift = 56; + if idx > 0 { + idx -= 1; + } + } else { + shift -= 8; + } + } + + res + } + fn to_bytes(&self, bytes: &mut [u8]) { + let mut idx = 0; + let mut shift = 0; + + for x in bytes.iter_mut().take($$(bytelenlit)).reverse() { + *x = (self.values[idx] >> shift) as u8; + shift += 8; + if shift == 64 { + idx += 1; + shift = 0; + } + } + } + } + + #[cfg(test)] + quickcheck! { + fn to_from_ident(x: $$sname) -> bool { + let mut buffer = $$(bytebuffer); + x.to_bytes(&mut buffer); + let y = $$sname::from_bytes(&buffer); + x == y + } + } + + #[cfg(test)] + #[allow(non_snake_case)] + #[test] + fn KATs() { + run_test(build_test_path("base", stringify!($$sname)), 8, |case| { + let (neg0, xbytes) = case.get("x").unwrap(); + let (neg1, mbytes) = case.get("m").unwrap(); + let (neg2, zbytes) = case.get("z").unwrap(); + let (neg3, ebytes) = case.get("e").unwrap(); + let (neg4, obytes) = case.get("o").unwrap(); + let (neg5, rbytes) = case.get("r").unwrap(); + let (neg6, bbytes) = case.get("b").unwrap(); + let (neg7, tbytes) = case.get("t").unwrap(); + }); + } + |] byteShiftInfo :: Word -> (Word, Word) byteShiftInfo idx = @@ -143,4 +149,25 @@ byteShiftInfo idx = pad :: Int -> Char -> String -> String pad len c str | length str >= len = str - | otherwise = pad len c (c:str) \ No newline at end of file + | otherwise = pad len c (c:str) + +generateZeroTests :: Word -> Word -> [Stmt Span] +generateZeroTests i max + | i == max = [] + | otherwise = + let ilit = toLit i + in [stmt| result = self.values[$$(ilit)] == 0; |] : + generateZeroTests (i + 1) max + +generatePackerLines :: Word -> Word -> [Stmt Span] +generatePackerLines i max + | i == max = [] + | otherwise = + let ilit = toLit i + nextLit = toLit (i + 1) + validx = toLit (i `div` 8) + shiftx = toLit ((i `mod` 8) * 8) + writeLine = [stmt| bytes[$$(ilit)] = (self.values[$$(validx)] >> $$(shiftx)) as u8; |] + ifLine = [stmt| if bytes.len() == $$(nextLit) { return; } |] + in writeLine : ifLine : generatePackerLines (i + 1) max + diff --git a/generation/src/Gen.hs b/generation/src/Gen.hs index 02bd661..f7d2ee6 100644 --- a/generation/src/Gen.hs +++ b/generation/src/Gen.hs @@ -10,6 +10,7 @@ module Gen( implFor, implFor', implFor'', + toLit ) where @@ -18,6 +19,8 @@ import Control.Monad.State.Class(MonadState,get,put) import Control.Monad.Writer.Class(MonadWriter,tell) import Data.List(replicate) import Data.Word(Word) +import Language.Rust.Data.Position +import Language.Rust.Syntax newtype Gen a = Gen { unGen :: RWS () String GenState a} deriving (Applicative, Functor, Monad, MonadState GenState, MonadWriter String) @@ -85,4 +88,9 @@ implFor' trait name middle = implFor'' :: String -> String -> Gen a -> Gen a implFor'' trait name middle = - wrapIndent ("impl<'a,'b> " ++ trait ++ " for " ++ name) middle \ No newline at end of file + wrapIndent ("impl<'a,'b> " ++ trait ++ " for " ++ name) middle + +toLit :: Word -> Expr Span +toLit i = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty + + diff --git a/generation/src/Main.hs b/generation/src/Main.hs index ca6346e..6bc50b3 100644 --- a/generation/src/Main.hs +++ b/generation/src/Main.hs @@ -8,7 +8,6 @@ import Conversions(conversions) import CryptoNum(cryptoNum) import Control.Monad(forM_,unless) import Data.Maybe(mapMaybe) -import Data.Word(Word) import File(File,Task(..),addModuleTasks,makeTask) import Gen(runGen) import System.Directory(createDirectoryIfMissing) @@ -57,4 +56,4 @@ main = forM_ (zip [(1::Word)..] tasks) $ \ (i, task) -> do putStrLn ("[" ++ show i ++ "/" ++ show total ++ "] " ++ outputFile task) createDirectoryIfMissing True (takeDirectory (outputFile task)) - runGen (outputFile task) (fileGenerator task) \ No newline at end of file + runGen (outputFile task) (fileGenerator task) diff --git a/src/lib.rs b/src/lib.rs index efeb742..de7a747 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,8 @@ pub mod unsigned; #[cfg(test)] mod testing; +use core::num::TryFromIntError; + /// A trait definition for large numbers. pub trait CryptoNum { /// Generate a new value of the given type. @@ -36,3 +38,14 @@ pub trait CryptoNum { fn to_bytes(&self, bytes: &mut [u8]); } +/// An error in conversion of large numbers (either to primitives or to other numbers +pub enum ConversionError { + NegativeToUnsigned, + Overflow +} + +impl From for ConversionError { + fn from(_: TryFromIntError) -> ConversionError { + ConversionError::Overflow + } +}