Start working on switching to language-rust as a generator, for fun.

This commit is contained in:
2019-10-22 20:12:08 -07:00
parent d7665acf64
commit 2400b10fbc
9 changed files with 723 additions and 404 deletions

View File

@@ -2,7 +2,6 @@ cabal-version: 2.0
-- Initial package description 'generation.cabal' generated by 'cabal -- Initial package description 'generation.cabal' generated by 'cabal
-- init'. For further documentation, see -- init'. For further documentation, see
-- http://haskell.org/cabal/users-guide/ -- http://haskell.org/cabal/users-guide/
name: generation name: generation
version: 0.1.0.0 version: 0.1.0.0
synopsis: Generates the cryptonum Rust library, based on requirements. synopsis: Generates the cryptonum Rust library, based on requirements.
@@ -20,10 +19,11 @@ executable generation
main-is: Main.hs main-is: Main.hs
other-modules: Base, BinaryOps, Compare, Conversions, CryptoNum, File, Gen other-modules: Base, BinaryOps, Compare, Conversions, CryptoNum, File, Gen
-- other-extensions: -- other-extensions:
build-depends: base ^>=4.12.0.0, build-depends: base >= 4.12.0.0,
containers, containers,
directory, directory,
filepath, filepath,
language-rust,
mtl mtl
hs-source-dirs: src hs-source-dirs: src
default-language: Haskell2010 default-language: Haskell2010

View File

@@ -1,3 +1,4 @@
{-# LANGUAGE QuasiQuotes #-}
module Base( module Base(
base base
) )
@@ -6,6 +7,11 @@ module Base(
import Control.Monad(forM_) import Control.Monad(forM_)
import File import File
import Gen import Gen
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Pretty
import Language.Rust.Syntax
base :: File base :: File
base = File { base = File {
@@ -19,38 +25,90 @@ declareBaseStructure bitsize =
do let name = "U" ++ show bitsize do let name = "U" ++ show bitsize
entries = bitsize `div` 64 entries = bitsize `div` 64
top = entries - 1 top = entries - 1
out "use core::fmt;" sname = mkIdent name
out "use quickcheck::{Arbitrary,Gen};" entriese = Lit [] (Int Dec (fromIntegral entries) Unsuffixed mempty) mempty
blank strname = Lit [] (Str name Cooked Unsuffixed mempty) mempty
out "#[derive(Clone)]" debugExp = buildDebugExp 0 entries [expr| f.debug_tuple($$(strname)) |]
wrapIndent ("pub struct " ++ name) $ lowerPrints = buildPrints entries "x"
out ("pub(crate) value: [u64; " ++ show entries ++ "]") upperPrints = buildPrints entries "X"
blank out $ show $ pretty' $ [sourceFile|
implFor "fmt::Debug" name $ use core::fmt;
wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $ use quickcheck::{Arbitrary,Gen};
do out ("f.debug_tuple(" ++ show name ++ ")")
forM_ [0..top] $ \ i -> #[derive(Clone)]
out (" .field(&self.value[" ++ show i ++ "])") pub struct $$sname {
out " .finish()" value: [u64; $$(entriese)]
blank }
implFor "fmt::UpperHex" name $
wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $ impl fmt::Debug for $$sname {
do forM_ (reverse [1..top]) $ \ i -> fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
out ("write!(f, \"{:X}\", self.value[" ++ show i ++ "])?;") $$(debugExp).finish()
out "write!(f, \"{:X}\", self.value[0])" }
blank }
implFor "fmt::LowerHex" name $
wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $ impl fmt::UpperHex for $$sname {
do forM_ (reverse [1..top]) $ \ i -> fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
out ("write!(f, \"{:x}\", self.value[" ++ show i ++ "])?;") $@{upperPrints}
out "write!(f, \"{:x}\", self.value[0])" write!(f, "{:X}", self.value[0])
blank }
implFor "Arbitrary" name $ }
wrapIndent "fn arbitrary<G: Gen>(g: &mut G) -> Self" $
do out (name ++ " {") impl fmt::LowerHex for $$sname {
indent $ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
do out ("value: [") $@{lowerPrints}
indent $ forM_ [0..top] $ \ _ -> write!(f, "{:x}", self.value[0])
out ("g.next_u64(),") }
out ("]") }
out ("}")
impl Arbitrary for $$sname {
fn arbitrary<G: Gen>(g: &mut G) -> Self {
let mut res = $$sname{ value: [0; $$(entriese)] };
for entry in res.iter_mut() {
*entry = g.next_u64();
}
res
}
}
|]
buildDebugExp :: Word -> Word -> Expr Span -> Expr Span
buildDebugExp i top acc
| i == top = acc
| otherwise =
let liti = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty
in buildDebugExp (i + 1) top [expr| $$(acc).field(&self.value[$$(liti)]) |]
buildPrints :: Word -> String -> [Stmt Span]
buildPrints entries printer = go (entries - 1)
where
litStr = Token mempty (LiteralTok (StrTok ("{:" ++ printer ++ "}")) Nothing)
--Lit [] (Str ("{:" ++ printer ++ "}") Cooked Unsuffixed mempty) mempty
go 0 = []
go x =
let rest = go (x - 1)
curi = Token mempty (LiteralTok (IntegerTok (show x)) Nothing)
-- Lit [] (Int Dec (fromIntegral x) Unsuffixed mempty) mempty
cur = [stmt| write!(f, $$(litStr), self.value[$$(curi)])?; |]
in cur : rest
-- implFor "fmt::UpperHex" name $
-- wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $
-- do forM_ (reverse [1..top]) $ \ i ->
-- out ("write!(f, \"{:X}\", self.value[" ++ show i ++ "])?;")
-- out "write!(f, \"{:X}\", self.value[0])"
-- blank
-- implFor "fmt::LowerHex" name $
-- wrapIndent "fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result" $
-- do forM_ (reverse [1..top]) $ \ i ->
-- out ("write!(f, \"{:x}\", self.value[" ++ show i ++ "])?;")
-- out "write!(f, \"{:x}\", self.value[0])"
-- blank
-- implFor "Arbitrary" name $
-- wrapIndent "fn arbitrary<G: Gen>(g: &mut G) -> Self" $
-- do out (name ++ " {")
-- indent $
-- do out ("value: [")
-- indent $ forM_ [0..top] $ \ _ ->
-- out ("g.next_u64(),")
-- out ("]")
-- out ("}")

View File

@@ -1,11 +1,16 @@
{-# LANGUAGE QuasiQuotes #-}
module BinaryOps( module BinaryOps(
binaryOps binaryOps
) )
where where
import Control.Monad(forM_)
import File import File
import Gen import Gen
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Pretty
import Language.Rust.Syntax
binaryOps :: File binaryOps :: File
binaryOps = File { binaryOps = File {
@@ -16,130 +21,177 @@ binaryOps = File {
declareBinaryOperators :: Word -> Gen () declareBinaryOperators :: Word -> Gen ()
declareBinaryOperators bitsize = declareBinaryOperators bitsize =
do let name = "U" ++ show bitsize do let struct_name = mkIdent ("U" ++ show bitsize)
entries = bitsize `div` 64 entries = bitsize `div` 64
out "use core::ops::{BitAnd,BitAndAssign};" andOps = generateBinOps "BitAnd" struct_name "bitand" BitAndOp entries
out "use core::ops::{BitOr,BitOrAssign};" orOps = generateBinOps "BitOr" struct_name "bitor" BitOrOp entries
out "use core::ops::{BitXor,BitXorAssign};" xorOps = generateBinOps "BitXor" struct_name "bitxor" BitXorOp entries
out "use core::ops::Not;" baseNegationStmts = negationStatements "self" entries
out "#[cfg(test)]" refNegationStmts = negationStatements "output" entries
out "use crate::CryptoNum;" out $ show $ pretty' $ [sourceFile|
out "#[cfg(test)]" use core::ops::{BitAnd,BitAndAssign};
out "use quickcheck::quickcheck;" use core::ops::{BitOr,BitOrAssign};
out ("use super::U" ++ show bitsize ++ ";") use core::ops::{BitXor,BitXorAssign};
blank use core::ops::Not;
generateBinOps "BitAnd" name "bitand" "&=" entries #[cfg(test)]
blank use crate::CryptoNum;
generateBinOps "BitOr" name "bitor" "|=" entries #[cfg(test)]
blank use quickcheck::quickcheck;
generateBinOps "BitXor" name "bitxor" "^=" entries use super::$$struct_name;
blank
implFor "Not" name $
do out "type Output = Self;"
blank
wrapIndent "fn not(mut self) -> Self" $
do forM_ [0..entries-1] $ \ i ->
out ("self.value[" ++ show i ++ "] = !self.value[" ++ show i ++ "];")
out "self"
blank
implFor' "Not" ("&'a " ++ name) $
do out ("type Output = " ++ name ++ ";")
blank
wrapIndent ("fn not(self) -> " ++ name) $
do out "let mut output = self.clone();"
forM_ [0..entries-1] $ \ i ->
out ("output.value[" ++ show i ++ "] = !self.value[" ++ show i ++ "];")
out "output"
blank
addBinaryLaws name
generateBinOps :: String -> String -> String -> String -> Word -> Gen () $@{andOps}
generateBinOps trait name fun op entries = $@{orOps}
do implFor (trait ++ "Assign") name $ $@{xorOps}
wrapIndent ("fn " ++ fun ++ "_assign(&mut self, rhs: Self)") $
forM_ [0..entries-1] $ \ i ->
out ("self.value[" ++ show i ++ "] "++op++" rhs.value[" ++ show i ++ "];")
blank
implFor' (trait ++ "Assign<&'a " ++ name ++ ">") name $
wrapIndent ("fn " ++ fun ++ "_assign(&mut self, rhs: &Self)") $
forM_ [0..entries-1] $ \ i ->
out ("self.value[" ++ show i ++ "] "++op++" rhs.value[" ++ show i ++ "];")
blank
generateBinOpsFromAssigns trait name fun op
generateBinOpsFromAssigns :: String -> String -> String -> String -> Gen () impl Not for $$struct_name {
generateBinOpsFromAssigns trait name fun op = type Output = Self;
do implFor trait name $
do out "type Output = Self;"
blank
wrapIndent ("fn " ++ fun ++ "(mut self, rhs: Self) -> Self") $
do out ("self " ++ op ++ " rhs;")
out "self"
blank
implFor' (trait ++ "<&'a " ++ name ++ ">") name $
do out "type Output = Self;"
blank
wrapIndent ("fn " ++ fun ++ "(mut self, rhs: &Self) -> Self") $
do out ("self " ++ op ++ " rhs;")
out "self"
blank
implFor' (trait ++ "<" ++ name ++ ">") ("&'a " ++ name) $
do out ("type Output = " ++ name ++ ";")
blank
wrapIndent ("fn " ++ fun ++ "(self, mut rhs: " ++ name ++ ") -> " ++ name) $
do out ("rhs " ++ op ++ " self;")
out "rhs"
blank
implFor'' (trait ++ "<&'a " ++ name ++ ">") ("&'b " ++ name) $
do out ("type Output = " ++ name ++ ";")
blank
wrapIndent ("fn " ++ fun ++ "(self, rhs: &" ++ name ++ ") -> " ++ name) $
do out "let mut output = self.clone();"
out ("output " ++ op ++ " rhs;")
out "output"
addBinaryLaws :: String -> Gen () fn not(mut self) -> Self {
addBinaryLaws name = $@{baseNegationStmts}
do let args3 = "(a: " ++ name ++ ", b: " ++ name ++ ", c: " ++ name ++ ")" self
args2 = "(a: " ++ name ++ ", b: " ++ name ++ ")" }
out "#[cfg(test)]" }
wrapIndent "quickcheck!" $
do wrapIndent ("fn and_associative" ++ args3 ++ " -> bool") $ impl<'a> Not for &'a $$struct_name {
out ("((&a & &b) & &c) == (&a & (&b & &c))") type Output = Self;
blank
wrapIndent ("fn and_commutative" ++ args2 ++ " -> bool") $ fn not(self) -> Self {
out ("(&a & &b) == (&b & &a)") let mut output = self.clone();
blank $@{refNegationStmts}
wrapIndent ("fn and_idempotent" ++ args2 ++ " -> bool") $ output
out ("(&a & &b) == (&a & &b & &a)") }
blank }
wrapIndent ("fn xor_associative" ++ args3 ++ " -> bool") $
out ("((&a ^ &b) ^ &c) == (&a ^ (&b ^ &c))") quickcheck! {
blank fn and_associative(a: $$struct_name, b: $$struct_name, c: $$struct_name) -> bool {
wrapIndent ("fn xor_commutative" ++ args2 ++ " -> bool") $ ((&a & &b) & &c) == (&a & (&b & &c))
out ("(&a ^ &b) == (&b ^ &a)") }
blank fn and_commutative(a: $$struct_name, b: $$struct_name) -> bool {
wrapIndent ("fn or_associative" ++ args3 ++ " -> bool") $ (&a & &b) == (&b & &a)
out ("((&a | &b) | &c) == (&a | (&b | &c))") }
blank fn and_idempotent(a: $$struct_name, b: $$struct_name) -> bool {
wrapIndent ("fn or_commutative" ++ args2 ++ " -> bool") $ (&a & &b) == (&a & &b & &a)
out ("(&a | &b) == (&b | &a)") }
blank
wrapIndent ("fn or_idempotent" ++ args2 ++ " -> bool") $ fn xor_associative(a: $$struct_name, b: $$struct_name, c: $$struct_name) -> bool {
out ("(&a | &b) == (&a | &b | &a)") ((&a ^ &b) ^ &c) == (&a ^ (&b ^ &c))
blank }
wrapIndent ("fn and_or_distribution" ++ args3 ++ "-> bool") $ fn xor_commutative(a: $$struct_name, b: $$struct_name) -> bool {
out ("(&a & (&b | &c)) == ((&a & &b) | (&a & &c))") (&a ^ &b) == (&b ^ &a)
blank }
wrapIndent ("fn xor_clears(a: " ++ name ++ ") -> bool") $
out (name ++ "::zero() == (&a ^ &a)") fn or_associative(a: $$struct_name, b: $$struct_name, c: $$struct_name) -> bool {
blank ((&a | &b) & &c) == (&a | (&b | &c))
wrapIndent ("fn double_neg_ident(a: " ++ name ++ ") -> bool") $ }
out ("a == !!&a") fn or_commutative(a: $$struct_name, b: $$struct_name) -> bool {
blank (&a | &b) == (&b | &a)
wrapIndent ("fn and_ident(a: " ++ name ++ ") -> bool") $ }
do out ("let ones = !" ++ name ++ "::zero();") fn or_idempotent(a: $$struct_name, b: $$struct_name) -> bool {
out ("(&a & &ones) == a") (&a | &b) == (&a | &b | &a)
blank }
wrapIndent ("fn or_ident(a: " ++ name ++ ") -> bool") $
out ("(&a | " ++ name ++ "::zero()) == a") fn and_or_distribution(a: $$struct_name, b: $$struct_name, c: $$struct_name) -> bool {
(&a & (&b | &c)) == ((&a & &b) | (&a & &c))
}
fn xor_clears(a: $$struct_name) -> bool {
$$struct_name::zero() == (&a ^ *a)
}
fn double_neg_ident(a: $$struct_name) -> bool {
a == !!$a
}
fn and_ident(a: $$struct_name) -> bool {
let ones = !$$struct_name::zero();
(&a & &ones) == a
}
fn or_ident(a: $$struct_name) -> bool {
(&a | $$struct_name::zero()) == a
}
}
|]
negationStatements :: String -> Word -> [Stmt Span]
negationStatements target entries = map genStatement [0..entries-1]
where
genStatement i =
let idx = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty
v = mkIdent target
in [stmt| $$v.value[$$(idx)] = !self.value[$$(idx)]; |]
generateBinOps :: String -> Ident -> String -> BinOp -> Word -> [Item Span]
generateBinOps trait sname func oper entries =
[normAssign, refAssign] ++ generateAllTheVariants traitIdent funcIdent sname oper
where
traitIdent = mkIdent trait
assignIdent = mkIdent (trait ++ "Assign")
funcIdent = mkIdent func
funcAssignIdent = mkIdent (func ++ "_assign")
--
normAssign = [item|
impl $$assignIdent for $$sname {
fn $$funcAssignIdent(&mut self, rhs: Self) {
$@{assignStatements}
}
}
|]
refAssign = [item|
impl $$assignIdent<&'a $$sname> for $$sname {
fn $$funcAssignIdent(&mut self, rhs: &Self) {
$@{assignStatements}
}
}
|]
--
assignStatements :: [Stmt Span]
assignStatements = map genAssign [0..entries-1]
genAssign i =
let idx = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty
left = [expr| self.value[$$(idx)] |]
right = [expr| rhs.value[$$(idx)] |]
in Semi (AssignOp [] oper left right mempty) mempty
generateAllTheVariants :: Ident -> Ident -> Ident -> BinOp -> [Item Span]
generateAllTheVariants traitname func sname oper = [
[item|
impl $$traitname for $$sname {
type Output = Self;
fn $$func(mut self, rhs: Self) -> Self {
$${assigner_self_rhs}
self
}
}|]
, [item|
impl<'a> $$traitname<&'a $$sname> for $$sname {
type Output = Self;
fn $$func(mut self, rhs: Self) -> Self {
$${assigner_self_rhs}
self
}
}|]
, [item|
impl<'a> $$traitname for &'a $$sname {
type Output = Self;
fn $$func(mut self, rhs: Self) -> Self {
$${assigner_rhs_self}
self
}
}|]
, [item|
impl<'a,'b> $$traitname<&'a $$sname> for &'b $$sname {
type Output = Self;
fn $$func(mut self, rhs: Self) -> Self {
let mut out = self.clone();
$${assigner_out_rhs}
out
}
}|]
]
where
assigner_self_rhs = assigner [expr| self |] [expr| rhs |]
assigner_rhs_self = assigner [expr| rhs |] [expr| self |]
assigner_out_rhs = assigner [expr| out |] [expr| rhs |]
assigner left right =
Semi (AssignOp [] oper left right mempty) mempty

View File

@@ -1,9 +1,14 @@
{-# LANGUAGE QuasiQuotes #-}
module Compare(comparisons) module Compare(comparisons)
where where
import Control.Monad(forM_)
import File import File
import Gen import Gen
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Pretty
import Language.Rust.Syntax
comparisons :: File comparisons :: File
comparisons = File { comparisons = File {
@@ -14,47 +19,79 @@ comparisons = File {
declareComparators :: Word -> Gen () declareComparators :: Word -> Gen ()
declareComparators bitsize = declareComparators bitsize =
do let name = "U" ++ show bitsize do let sname = mkIdent ("U" ++ show bitsize)
entries = bitsize `div` 64 entries = bitsize `div` 64
top = entries - 1 eqStatements = buildEqStatements 0 entries
out "use core::cmp::{Eq,Ordering,PartialEq};" compareExp = buildCompareExp 0 entries
out "#[cfg(test)]" out $ show $ pretty' $ [sourceFile|
out "use quickcheck::quickcheck;" use core::cmp::{Eq,Ordering,PartialEq};
out ("use super::" ++ name ++ ";") #[cfg(test)]
blank use quickcheck::quickcheck;
implFor "PartialEq" name $ use super::$$sname;
wrapIndent "fn eq(&self, other: &Self) -> bool" $
do forM_ (reverse [1..top]) $ \ i -> impl PartialEq for $$sname {
out ("self.value[" ++ show i ++ "] == other.value[" ++ show i ++ "] && ") fn eq(&self, other: &Self) -> bool {
out "self.value[0] == other.value[0]" let mut out = true;
blank $@{eqStatements}
implFor "Eq" name $ return () out
blank }
implFor "Ord" name $ }
wrapIndent "fn cmp(&self, other: &Self) -> Ordering" $
do out ("self.value[" ++ show top ++ "].cmp(&other.value[" ++ show top ++ "])") impl Eq for $$sname {}
forM_ (reverse [0..top-1]) $ \ i ->
out (" .then(self.value[" ++ show i ++ "].cmp(&other.value[" ++ show i ++ "]))") impl Ord for $$sname {
blank fn cmp(&self, other: &Self) -> Ordering {
implFor "PartialOrd" name $ $$(compareExp)
wrapIndent "fn partial_cmp(&self, other: &Self) -> Option<Ordering>" $ }
out "Some(self.cmp(other))" }
blank
out "#[cfg(test)]" impl PartialOrd for $$sname {
wrapIndent "quickcheck!" $ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
do let transFun n = "fn " ++ n ++ "(a: " ++ name ++ ", b: " ++ name ++ Some(self.cmp(other))
", c: " ++ name ++ ") -> bool" }
wrapIndent (transFun "eq_is_transitive") $ }
out ("if a == c { a == b && b == c } else { a != b || b != c }")
blank #[cfg(test)]
wrapIndent (transFun "gt_is_transitive") $ quickcheck! {
out ("if a > b && b > c { a > c } else { true }") fn eq_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
blank if a == c { a == b && b == c } else { a != b || b != c }
wrapIndent (transFun "ge_is_transitive") $ }
out ("if a >= b && b >= c { a >= c } else { true }")
blank fn gt_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
wrapIndent (transFun "lt_is_transitive") $ if a > b && b > c { a > c } else { true }
out ("if a < b && b < c { a < c } else { true }") }
blank
wrapIndent (transFun "le_is_transitive") $ fn ge_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
out ("if a <= b && b <= c { a <= c } else { true }") if a >= b && b >= c { a >= c } else { true }
}
fn lt_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a < b && b < c { a < c } else { true }
}
fn le_is_transitive(a: $$sname, b: $$sname, c: $$sname) -> bool {
if a <= b && b <= c { a <= c } else { true }
}
}
|]
buildEqStatements :: Word -> Word -> [Stmt Span]
buildEqStatements i numEntries
| i == (numEntries - 1) =
[[stmt| out &= self.value[$$(x)] == other.value[$$(x)]; |]]
| otherwise =
let rest = buildEqStatements (i + 1) numEntries
cur = [stmt| out &= self.value[$$(x)] == other.value[$$(x)]; |]
in cur:rest
where
x = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty
buildCompareExp :: Word -> Word -> Expr Span
buildCompareExp i numEntries
| i == (numEntries - 1) =
[expr| self.value[$$(x)].cmp(&other.value[$$(x)]) |]
| otherwise =
let rest = buildCompareExp (i + 1) numEntries
in [expr| $$(rest).then(self.value[$$(x)].cmp(&other.value[$$(x)])) |]
where
x = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty

View File

@@ -1,11 +1,16 @@
{-# LANGUAGE QuasiQuotes #-}
module Conversions( module Conversions(
conversions conversions
) )
where where
import Data.List(intercalate)
import File import File
import Gen import Gen(Gen,toLit,out)
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Pretty
import Language.Rust.Syntax
conversions :: File conversions :: File
conversions = File { conversions = File {
@@ -16,83 +21,203 @@ conversions = File {
declareConversions :: Word -> Gen () declareConversions :: Word -> Gen ()
declareConversions bitsize = declareConversions bitsize =
do let name = "U" ++ show bitsize do let sname = mkIdent ("U" ++ show bitsize)
entries = bitsize `div` 64 entries = bitsize `div` 64
out "use core::convert::{From,TryFrom};" u8_prims = buildPrimitives sname (mkIdent "u8") entries
out "#[cfg(test)]" u16_prims = buildPrimitives sname (mkIdent "u16") entries
out "use quickcheck::quickcheck;" u32_prims = buildPrimitives sname (mkIdent "u32") entries
out ("use super::" ++ name ++ ";") u64_prims = buildPrimitives sname (mkIdent "u64") entries
blank u128_prims = generateU128Primitives sname entries
buildUnsignedPrimConversions name entries "u8" >> blank i8_prims = generateSignedPrims sname (mkIdent "u8") (mkIdent "i8")
buildUnsignedPrimConversions name entries "u16" >> blank i16_prims = generateSignedPrims sname (mkIdent "u16") (mkIdent "i16")
buildUnsignedPrimConversions name entries "u32" >> blank i32_prims = generateSignedPrims sname (mkIdent "u32") (mkIdent "i32")
buildUnsignedPrimConversions name entries "u64" >> blank i64_prims = generateSignedPrims sname (mkIdent "u64") (mkIdent "i64")
buildUnsignedPrimConversions name entries "usize" >> blank i128_prims = generateI128Primitives sname
buildSignedPrimConversions name entries "i8" >> blank out $ show $ pretty' $ [sourceFile|
buildSignedPrimConversions name entries "i16" >> blank use core::convert::{From,TryFrom};
buildSignedPrimConversions name entries "i32" >> blank use core::num::TryFromIntError;
buildSignedPrimConversions name entries "i64" >> blank #[cfg(test)]
buildSignedPrimConversions name entries "isize" use quickcheck::quickcheck;
blank use super::$$sname;
out ("#[cfg(test)]") use crate::ConversionError;
wrapIndent "quickcheck!" $
do roundTripTest name "u8" >> blank
roundTripTest name "u16" >> blank
roundTripTest name "u32" >> blank
roundTripTest name "u64" >> blank
roundTripTest name "usize"
buildUnsignedPrimConversions :: String -> Word -> String -> Gen () $@{u8_prims}
buildUnsignedPrimConversions name entries primtype = $@{u16_prims}
do implFor ("From<" ++ primtype ++ ">") name $ $@{u32_prims}
wrapIndent ("fn from(x: " ++ primtype ++ ") -> Self") $ $@{u64_prims}
do let zeroes = replicate (fromIntegral (entries - 1)) "0," $@{u128_prims}
values = ("x as u64," : zeroes)
out (name ++ " { value: [ ")
indent $ printBy 8 values
out ("] }")
blank
implFor ("From<" ++ name ++ ">") primtype $
wrapIndent ("fn from(x: " ++ name ++ ") -> Self") $
out ("x.value[0] as " ++ primtype)
blank
implFor' ("From<&'a " ++ name ++ ">") primtype $
wrapIndent ("fn from(x: &" ++ name ++ ") -> Self") $
out ("x.value[0] as " ++ primtype)
buildSignedPrimConversions :: String -> Word -> String -> Gen () $@{i8_prims}
buildSignedPrimConversions name entries primtype = $@{i16_prims}
do implFor ("TryFrom<" ++ primtype ++ ">") name $ $@{i32_prims}
do out ("type Error = &'static str;") $@{i64_prims}
blank $@{i128_prims}
wrapIndent ("fn try_from(x: " ++ primtype ++ ") -> Result<Self,Self::Error>") $ |]
do wrapIndent ("if x < 0") $
out ("return Err(\"Attempt to convert negative number to " ++
name ++ ".\");")
blank
let zeroes = replicate (fromIntegral (entries - 1)) "0,"
values = ("x as u64," : zeroes)
out ("Ok(" ++ name ++ " { value: [ ")
indent $ printBy 8 values
out ("] })")
blank
implFor ("From<" ++ name ++ ">") primtype $
wrapIndent ("fn from(x: " ++ name ++ ") -> Self") $
out ("x.value[0] as " ++ primtype)
blank
implFor' ("From<&'a " ++ name ++ ">") primtype $
wrapIndent ("fn from(x: &" ++ name ++ ") -> Self") $
out ("x.value[0] as " ++ primtype)
roundTripTest :: String -> String -> Gen () generateU128Primitives :: Ident -> Word -> [Item Span]
roundTripTest name primtype = generateU128Primitives sname entries = [
wrapIndent ("fn " ++ primtype ++ "_roundtrips(x: " ++ primtype ++ ") -> bool") $ [item|impl From<u128> for $$sname {
do out ("let big = " ++ name ++ "::from(x);"); fn from(x: u128) -> Self {
out ("let small = " ++ primtype ++ "::from(big);") let mut res = $$sname::zero;
out ("x == small") res[0] = x as u64;
res[1] = (x >> 64) as u64;
res
}
}|]
, [item|impl TryFrom<$$sname> for u128 {
type Error = ConversionError;
printBy :: Int -> [String] -> Gen () fn try_from(x: $$sname) -> Result<u128,ConversionError> {
printBy amt xs let mut goodConversion = true;
| length xs <= amt = out (intercalate " " xs) let mut res = 0;
| otherwise = printBy amt (take amt xs) >>
printBy amt (drop amt xs) res = (x.values[1] as u128) << 64;
res |= x.values[0] as u128;
$@{testZeros}
if goodConversion {
Ok(res)
} else {
Err(ConversionError::Overflow);
}
}
}|]
, [item|impl<'a> TryFrom<&'a $$sname> for u128 {
type Error = ConversionError;
fn try_from(x: &$$sname) -> Result<u128,ConversionError> {
let mut goodConversion = true;
let mut res = 0;
res = (x.values[1] as u128) << 64;
res |= x.values[0] as u128;
$@{testZeros}
if goodConversion {
Ok(res)
} else {
Err(ConversionError::Overflow());
}
}
}|]
]
where
testZeros = map (zeroTest . toLit) [2..entries-1]
zeroTest i =
[stmt| goodConversion &= x.values[$$(i)] == 0; |]
buildPrimitives :: Ident -> Ident -> Word -> [Item Span]
buildPrimitives sname tname entries = [
[item|impl From<$$tname> for $$sname {
fn from(x: $$tname) -> Self {
let mut res = $$sname::zero();
res.values[0] = x as u64;
res
}
}|]
, [item|impl TryFrom<$$sname> for $$tname {
type Error = ConversionError;
fn try_from(x: $$sname) -> Result<Self,ConversionError> {
let mut goodConversion = true;
let mut res = 0;
res = x.values[0] as $$tname;
$@{testZeros}
if goodConversion {
Ok(res)
} else {
Err(ConversionError::Overflow)
}
}
}|]
, [item|impl<'a> TryFrom<&'a $$sname> for $$tname {
type Error = ConversionError;
fn try_from(x: &$$sname) -> Result<Self,ConversionError> {
let mut goodConversion = true;
let mut res = 0;
res = x.values[0] as $$tname;
$@{testZeros}
if goodConversion {
Ok(res)
} else {
Err(ConversionError::Overflow)
}
}
}|]
]
where
testZeros = map (zeroTest . toLit) [1..entries-1]
zeroTest i =
[stmt| goodConversion &= x.values[$$(i)] == 0; |]
generateSignedPrims :: Ident -> Ident -> Ident -> [Item Span]
generateSignedPrims sname unsigned signed = [
[item|impl TryFrom<$$signed> for $$sname {
type Error = ConversionError;
fn try_from(x: $$signed) -> Result<Self,ConversionError> {
let mut res = $$sname::zero();
res.values[0] = x as u64;
if x < 0 {
Err(ConversionError::NegativeToUnsigned)
} else {
Ok(res)
}
}
}|]
, [item|impl TryFrom<$$sname> for $$signed {
type Error = ConversionError;
fn try_from(x: $$sname) -> Result<Self,ConversionError> {
let uns = $$unsigned::from(x)?;
Ok($$signed::try_from(uns)?)
}
}|]
, [item|impl<'a> TryFrom<&'a $$sname> for $$signed {
type Error = ConversionError;
fn try_from(x: &$$sname) -> Result<Self,ConversionError> {
let uns = $$unsigned::from(x)?;
Ok($$signed::try_from(uns)?)
}
}|]
]
generateI128Primitives :: Ident -> [Item Span]
generateI128Primitives sname = [
[item|impl TryFrom<i128> for $$sname {
type Error = ConversionError;
fn try_from(x: i128) -> Result<Self,ConversionError> {
let mut res = $$sname::zero();
res.values[0] = x as u64;
res.values[1] = ((x as u128) >> 64) as u64;
if x < 0 {
Err(ConversionError::NegativeToUnsigned)
} else {
Ok(res)
}
}
}|]
, [item|impl TryFrom<$$sname> for i128 {
type Error = ConversionError;
fn try_from(x: $$sname) -> Result<Self,ConversionError> {
let uns = u128::from(x)?;
Ok(i128::try_from(uns)?)
}
}|]
, [item|impl<'a> TryFrom<&'a $$sname> for i128 {
type Error = ConversionError;
fn try_from(x: &$$sname) -> Result<Self,ConversionError> {
let uns = u128::from(x)?;
Ok(i128::try_from(uns)?)
}
}|]
]

View File

@@ -1,3 +1,4 @@
{-# LANGUAGE QuasiQuotes #-}
module CryptoNum( module CryptoNum(
cryptoNum cryptoNum
) )
@@ -6,6 +7,11 @@ module CryptoNum(
import Control.Monad(forM_) import Control.Monad(forM_)
import File import File
import Gen import Gen
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Pretty
import Language.Rust.Syntax
cryptoNum :: File cryptoNum :: File
cryptoNum = File { cryptoNum = File {
@@ -16,125 +22,125 @@ cryptoNum = File {
declareCryptoNumInstance :: Word -> Gen () declareCryptoNumInstance :: Word -> Gen ()
declareCryptoNumInstance bitsize = declareCryptoNumInstance bitsize =
do let name = "U" ++ show bitsize do let sname = mkIdent ("U" ++ show bitsize)
entries = bitsize `div` 64 entries = bitsize `div` 64
entlit = Lit [] (Int Dec (fromIntegral entries) Unsuffixed mempty) mempty
top = entries - 1 top = entries - 1
out "use core::cmp::min;" zeroTests = generateZeroTests 0 entries
out "use crate::CryptoNum;" bitlength = toLit bitsize
out "#[cfg(test)]" bytelen = bitsize `div` 8
out "use crate::testing::{build_test_path,run_test};" bytelenlit = toLit bytelen
out "#[cfg(test)]" bytebuffer = Delimited mempty Brace (Stream [
out "use quickcheck::quickcheck;" Tree (Token mempty (LiteralTok (IntegerTok "0") Nothing)),
out ("use super::" ++ name ++ ";") Tree (Token mempty Semicolon),
blank Tree (Token mempty (LiteralTok (IntegerTok (show bytelen)) Nothing))
implFor "CryptoNum" name $ ])
do wrapIndent ("fn zero() -> Self") $ entrieslit = toLit entries
out (name ++ "{ value: [0; " ++ show entries ++ "] }") packerLines = generatePackerLines 0 (bitsize `div` 8)
blank out $ show $ pretty' $ [sourceFile|
wrapIndent ("fn is_zero(&self) -> bool") $ use core::cmp::min;
do forM_ (reverse [1..top]) $ \ i -> use crate::CryptoNum;
out ("self.value[" ++ show i ++ "] == 0 &&") #[cfg(test)]
out "self.value[0] == 0" use crate::testing::{build_test_path,run_test};
blank #[cfg(test)]
wrapIndent ("fn is_even(&self) -> bool") $ use quickcheck::quickcheck;
out "self.value[0] & 0x1 == 0" use super::$$sname;
blank
wrapIndent ("fn is_odd(&self) -> bool") $ impl CryptoNum for $$sname {
out "self.value[0] & 0x1 == 1" fn zero() -> Self {
blank $$sname{ value: [0; $$(entlit)] }
wrapIndent ("fn bit_length() -> usize") $ }
out (show bitsize) fn is_zero(&self) -> bool {
blank let mut result = true;
wrapIndent ("fn mask(&mut self, len: usize)") $ $@{zeroTests}
do out ("let dellen = min(len, " ++ show entries ++ ");") result
wrapIndent ("for i in dellen.." ++ show entries) $ }
out ("self.value[i] = 0;") fn is_even(&self) -> bool {
blank self.value[0] & 0x1 == 0
wrapIndent ("fn testbit(&self, bit: usize) -> bool") $ }
do out "let idx = bit / 64;" fn is_off(&self) -> bool {
out "let offset = bit % 64;" self.value[0] & 0x1 == 1
wrapIndent ("if idx >= " ++ show entries) $ }
out "return false;" fn bit_length() -> usize {
out "(self.value[idx] & (1u64 << offset)) != 0" $$(bitlength)
blank }
wrapIndent ("fn from_bytes(bytes: &[u8]) -> Self") $ fn mask(&mut self, len: usize) {
do out ("let biggest = min(" ++ show (bitsize `div` 8) ++ ", " ++ let dellen = min(len, $$(entrieslit));
"bytes.len()) - 1;") for i in dellen..$$(entrieslit) {
out ("let mut idx = biggest / 8;") self.value[i] = 0;
out ("let mut shift = (biggest % 8) * 8;") }
out ("let mut i = 0;") }
out ("let mut res = " ++ name ++ "::zero();") fn testbit(&self, bit: usize) -> bool {
blank let idx = bit / 64;
wrapIndent ("while i <= biggest") $ let offset = bit % 64;
do out ("res.value[idx] |= (bytes[i] as u64) << shift;") if idx >= $$(entrieslit) {
out ("i += 1;") return false;
out ("if shift == 0 {") }
indent $ (self.value[idx] & (1u64 << offset)) != 0
do out "shift = 56;" }
out "if idx > 0 { idx -= 1; }" fn from_bytes(bytes: &[u8]) -> Self {
out ("} else {") let biggest = min($$(bytelenlit), bytes.len()) - 1;
indent $ let mut idx = biggest / 8;
out "shift -= 8;" let mut shift = (biggest % 8) * 8;
out "}" let mut i = 0;
blank let mut res = $$sname::zero();
out "res"
blank while i <= biggest {
wrapIndent ("fn to_bytes(&self, bytes: &mut [u8])") $ res.value[idx] |= (bytes[i] as u64) << shift;
do let bytes = bitsize `div` 8 i += 1;
out ("if bytes.len() == 0 { return; }") if shift == 0 {
blank shift = 56;
forM_ [0..bytes-1] $ \ idx -> if idx > 0 {
do let (validx, shift) = byteShiftInfo idx idx -= 1;
out ("let byte" ++ show idx ++ " = (self.value[" ++ }
show validx ++ "] >> " ++ show shift ++ ")" ++ } else {
" as u8;") shift -= 8;
blank }
out ("let mut idx = min(bytes.len() - 1, " ++ show (bytes - 1) ++ ");") }
forM_ [0..bytes-2] $ \ i ->
do out ("bytes[idx] = byte" ++ show i ++ ";") res
out ("if idx == 0 { return; }") }
out ("idx -= 1;") fn to_bytes(&self, bytes: &mut [u8]) {
out ("bytes[idx] = byte" ++ show (bytes-1) ++ ";") let mut idx = 0;
blank let mut shift = 0;
out "#[cfg(test)]"
wrapIndent "quickcheck!" $ for x in bytes.iter_mut().take($$(bytelenlit)).reverse() {
do wrapIndent ("fn to_from_ident(x: " ++ name ++ ") -> bool") $ *x = (self.values[idx] >> shift) as u8;
do out ("let mut buffer = [0; " ++ show (bitsize `div` 8) ++ "];") shift += 8;
out ("x.to_bytes(&mut buffer);"); if shift == 64 {
out ("let y = " ++ name ++ "::from_bytes(&buffer);") idx += 1;
out ("x == y") shift = 0;
blank }
out "#[cfg(test)]" }
out "#[allow(non_snake_case)]" }
out "#[test]" }
wrapIndent "fn KATs()" $
do let name' = pad 5 '0' (show bitsize) #[cfg(test)]
out ("run_test(build_test_path(\"base\",\"" ++ name' ++ "\"), 8, |case| {") quickcheck! {
indent $ fn to_from_ident(x: $$sname) -> bool {
do out ("let (neg0, xbytes) = case.get(\"x\").unwrap();") let mut buffer = $$(bytebuffer);
out ("let (neg1, mbytes) = case.get(\"m\").unwrap();") x.to_bytes(&mut buffer);
out ("let (neg2, zbytes) = case.get(\"z\").unwrap();") let y = $$sname::from_bytes(&buffer);
out ("let (neg3, ebytes) = case.get(\"e\").unwrap();") x == y
out ("let (neg4, obytes) = case.get(\"o\").unwrap();") }
out ("let (neg5, rbytes) = case.get(\"r\").unwrap();") }
out ("let (neg6, bbytes) = case.get(\"b\").unwrap();")
out ("let (neg7, tbytes) = case.get(\"t\").unwrap();") #[cfg(test)]
out ("assert!(!neg0&&!neg1&&!neg2&&!neg3&&!neg4&&!neg5&&!neg6&&!neg7);") #[allow(non_snake_case)]
out ("let mut x = "++name++"::from_bytes(xbytes);") #[test]
out ("let m = "++name++"::from_bytes(mbytes);") fn KATs() {
out ("let z = 1 == zbytes[0];") run_test(build_test_path("base", stringify!($$sname)), 8, |case| {
out ("let e = 1 == ebytes[0];") let (neg0, xbytes) = case.get("x").unwrap();
out ("let o = 1 == obytes[0];") let (neg1, mbytes) = case.get("m").unwrap();
out ("let r = "++name++"::from_bytes(rbytes);") let (neg2, zbytes) = case.get("z").unwrap();
out ("let b = usize::from("++name++"::from_bytes(bbytes));") let (neg3, ebytes) = case.get("e").unwrap();
out ("let t = 1 == tbytes[0];") let (neg4, obytes) = case.get("o").unwrap();
out ("assert_eq!(x.is_zero(), z);") let (neg5, rbytes) = case.get("r").unwrap();
out ("assert_eq!(x.is_even(), e);") let (neg6, bbytes) = case.get("b").unwrap();
out ("assert_eq!(x.is_odd(), o);") let (neg7, tbytes) = case.get("t").unwrap();
out ("assert_eq!(x.testbit(b), t);") });
out ("x.mask(usize::from(&m));") }
out ("assert_eq!(x, r);") |]
out ("});")
byteShiftInfo :: Word -> (Word, Word) byteShiftInfo :: Word -> (Word, Word)
byteShiftInfo idx = byteShiftInfo idx =
@@ -143,4 +149,25 @@ byteShiftInfo idx =
pad :: Int -> Char -> String -> String pad :: Int -> Char -> String -> String
pad len c str pad len c str
| length str >= len = str | length str >= len = str
| otherwise = pad len c (c:str) | otherwise = pad len c (c:str)
generateZeroTests :: Word -> Word -> [Stmt Span]
generateZeroTests i max
| i == max = []
| otherwise =
let ilit = toLit i
in [stmt| result = self.values[$$(ilit)] == 0; |] :
generateZeroTests (i + 1) max
generatePackerLines :: Word -> Word -> [Stmt Span]
generatePackerLines i max
| i == max = []
| otherwise =
let ilit = toLit i
nextLit = toLit (i + 1)
validx = toLit (i `div` 8)
shiftx = toLit ((i `mod` 8) * 8)
writeLine = [stmt| bytes[$$(ilit)] = (self.values[$$(validx)] >> $$(shiftx)) as u8; |]
ifLine = [stmt| if bytes.len() == $$(nextLit) { return; } |]
in writeLine : ifLine : generatePackerLines (i + 1) max

View File

@@ -10,6 +10,7 @@ module Gen(
implFor, implFor,
implFor', implFor',
implFor'', implFor'',
toLit
) )
where where
@@ -18,6 +19,8 @@ import Control.Monad.State.Class(MonadState,get,put)
import Control.Monad.Writer.Class(MonadWriter,tell) import Control.Monad.Writer.Class(MonadWriter,tell)
import Data.List(replicate) import Data.List(replicate)
import Data.Word(Word) import Data.Word(Word)
import Language.Rust.Data.Position
import Language.Rust.Syntax
newtype Gen a = Gen { unGen :: RWS () String GenState a} newtype Gen a = Gen { unGen :: RWS () String GenState a}
deriving (Applicative, Functor, Monad, MonadState GenState, MonadWriter String) deriving (Applicative, Functor, Monad, MonadState GenState, MonadWriter String)
@@ -85,4 +88,9 @@ implFor' trait name middle =
implFor'' :: String -> String -> Gen a -> Gen a implFor'' :: String -> String -> Gen a -> Gen a
implFor'' trait name middle = implFor'' trait name middle =
wrapIndent ("impl<'a,'b> " ++ trait ++ " for " ++ name) middle wrapIndent ("impl<'a,'b> " ++ trait ++ " for " ++ name) middle
toLit :: Word -> Expr Span
toLit i = Lit [] (Int Dec (fromIntegral i) Unsuffixed mempty) mempty

View File

@@ -8,7 +8,6 @@ import Conversions(conversions)
import CryptoNum(cryptoNum) import CryptoNum(cryptoNum)
import Control.Monad(forM_,unless) import Control.Monad(forM_,unless)
import Data.Maybe(mapMaybe) import Data.Maybe(mapMaybe)
import Data.Word(Word)
import File(File,Task(..),addModuleTasks,makeTask) import File(File,Task(..),addModuleTasks,makeTask)
import Gen(runGen) import Gen(runGen)
import System.Directory(createDirectoryIfMissing) import System.Directory(createDirectoryIfMissing)
@@ -57,4 +56,4 @@ main =
forM_ (zip [(1::Word)..] tasks) $ \ (i, task) -> forM_ (zip [(1::Word)..] tasks) $ \ (i, task) ->
do putStrLn ("[" ++ show i ++ "/" ++ show total ++ "] " ++ outputFile task) do putStrLn ("[" ++ show i ++ "/" ++ show total ++ "] " ++ outputFile task)
createDirectoryIfMissing True (takeDirectory (outputFile task)) createDirectoryIfMissing True (takeDirectory (outputFile task))
runGen (outputFile task) (fileGenerator task) runGen (outputFile task) (fileGenerator task)

View File

@@ -4,6 +4,8 @@ pub mod unsigned;
#[cfg(test)] #[cfg(test)]
mod testing; mod testing;
use core::num::TryFromIntError;
/// A trait definition for large numbers. /// A trait definition for large numbers.
pub trait CryptoNum { pub trait CryptoNum {
/// Generate a new value of the given type. /// Generate a new value of the given type.
@@ -36,3 +38,14 @@ pub trait CryptoNum {
fn to_bytes(&self, bytes: &mut [u8]); fn to_bytes(&self, bytes: &mut [u8]);
} }
/// An error in conversion of large numbers (either to primitives or to other numbers
pub enum ConversionError {
NegativeToUnsigned,
Overflow
}
impl From<TryFromIntError> for ConversionError {
fn from(_: TryFromIntError) -> ConversionError {
ConversionError::Overflow
}
}