Start working on switching to language-rust as a generator, for fun.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
{-# LANGUAGE QuasiQuotes #-}
|
||||
module CryptoNum(
|
||||
cryptoNum
|
||||
)
|
||||
@@ -6,6 +7,11 @@ module CryptoNum(
|
||||
import Control.Monad(forM_)
|
||||
import File
|
||||
import Gen
|
||||
import Language.Rust.Data.Ident
|
||||
import Language.Rust.Data.Position
|
||||
import Language.Rust.Quote
|
||||
import Language.Rust.Pretty
|
||||
import Language.Rust.Syntax
|
||||
|
||||
cryptoNum :: File
|
||||
cryptoNum = File {
|
||||
@@ -16,125 +22,125 @@ cryptoNum = File {
|
||||
|
||||
declareCryptoNumInstance :: Word -> Gen ()
|
||||
declareCryptoNumInstance bitsize =
|
||||
do let name = "U" ++ show bitsize
|
||||
do let sname = mkIdent ("U" ++ show bitsize)
|
||||
entries = bitsize `div` 64
|
||||
entlit = Lit [] (Int Dec (fromIntegral entries) Unsuffixed mempty) mempty
|
||||
top = entries - 1
|
||||
out "use core::cmp::min;"
|
||||
out "use crate::CryptoNum;"
|
||||
out "#[cfg(test)]"
|
||||
out "use crate::testing::{build_test_path,run_test};"
|
||||
out "#[cfg(test)]"
|
||||
out "use quickcheck::quickcheck;"
|
||||
out ("use super::" ++ name ++ ";")
|
||||
blank
|
||||
implFor "CryptoNum" name $
|
||||
do wrapIndent ("fn zero() -> Self") $
|
||||
out (name ++ "{ value: [0; " ++ show entries ++ "] }")
|
||||
blank
|
||||
wrapIndent ("fn is_zero(&self) -> bool") $
|
||||
do forM_ (reverse [1..top]) $ \ i ->
|
||||
out ("self.value[" ++ show i ++ "] == 0 &&")
|
||||
out "self.value[0] == 0"
|
||||
blank
|
||||
wrapIndent ("fn is_even(&self) -> bool") $
|
||||
out "self.value[0] & 0x1 == 0"
|
||||
blank
|
||||
wrapIndent ("fn is_odd(&self) -> bool") $
|
||||
out "self.value[0] & 0x1 == 1"
|
||||
blank
|
||||
wrapIndent ("fn bit_length() -> usize") $
|
||||
out (show bitsize)
|
||||
blank
|
||||
wrapIndent ("fn mask(&mut self, len: usize)") $
|
||||
do out ("let dellen = min(len, " ++ show entries ++ ");")
|
||||
wrapIndent ("for i in dellen.." ++ show entries) $
|
||||
out ("self.value[i] = 0;")
|
||||
blank
|
||||
wrapIndent ("fn testbit(&self, bit: usize) -> bool") $
|
||||
do out "let idx = bit / 64;"
|
||||
out "let offset = bit % 64;"
|
||||
wrapIndent ("if idx >= " ++ show entries) $
|
||||
out "return false;"
|
||||
out "(self.value[idx] & (1u64 << offset)) != 0"
|
||||
blank
|
||||
wrapIndent ("fn from_bytes(bytes: &[u8]) -> Self") $
|
||||
do out ("let biggest = min(" ++ show (bitsize `div` 8) ++ ", " ++
|
||||
"bytes.len()) - 1;")
|
||||
out ("let mut idx = biggest / 8;")
|
||||
out ("let mut shift = (biggest % 8) * 8;")
|
||||
out ("let mut i = 0;")
|
||||
out ("let mut res = " ++ name ++ "::zero();")
|
||||
blank
|
||||
wrapIndent ("while i <= biggest") $
|
||||
do out ("res.value[idx] |= (bytes[i] as u64) << shift;")
|
||||
out ("i += 1;")
|
||||
out ("if shift == 0 {")
|
||||
indent $
|
||||
do out "shift = 56;"
|
||||
out "if idx > 0 { idx -= 1; }"
|
||||
out ("} else {")
|
||||
indent $
|
||||
out "shift -= 8;"
|
||||
out "}"
|
||||
blank
|
||||
out "res"
|
||||
blank
|
||||
wrapIndent ("fn to_bytes(&self, bytes: &mut [u8])") $
|
||||
do let bytes = bitsize `div` 8
|
||||
out ("if bytes.len() == 0 { return; }")
|
||||
blank
|
||||
forM_ [0..bytes-1] $ \ idx ->
|
||||
do let (validx, shift) = byteShiftInfo idx
|
||||
out ("let byte" ++ show idx ++ " = (self.value[" ++
|
||||
show validx ++ "] >> " ++ show shift ++ ")" ++
|
||||
" as u8;")
|
||||
blank
|
||||
out ("let mut idx = min(bytes.len() - 1, " ++ show (bytes - 1) ++ ");")
|
||||
forM_ [0..bytes-2] $ \ i ->
|
||||
do out ("bytes[idx] = byte" ++ show i ++ ";")
|
||||
out ("if idx == 0 { return; }")
|
||||
out ("idx -= 1;")
|
||||
out ("bytes[idx] = byte" ++ show (bytes-1) ++ ";")
|
||||
blank
|
||||
out "#[cfg(test)]"
|
||||
wrapIndent "quickcheck!" $
|
||||
do wrapIndent ("fn to_from_ident(x: " ++ name ++ ") -> bool") $
|
||||
do out ("let mut buffer = [0; " ++ show (bitsize `div` 8) ++ "];")
|
||||
out ("x.to_bytes(&mut buffer);");
|
||||
out ("let y = " ++ name ++ "::from_bytes(&buffer);")
|
||||
out ("x == y")
|
||||
blank
|
||||
out "#[cfg(test)]"
|
||||
out "#[allow(non_snake_case)]"
|
||||
out "#[test]"
|
||||
wrapIndent "fn KATs()" $
|
||||
do let name' = pad 5 '0' (show bitsize)
|
||||
out ("run_test(build_test_path(\"base\",\"" ++ name' ++ "\"), 8, |case| {")
|
||||
indent $
|
||||
do out ("let (neg0, xbytes) = case.get(\"x\").unwrap();")
|
||||
out ("let (neg1, mbytes) = case.get(\"m\").unwrap();")
|
||||
out ("let (neg2, zbytes) = case.get(\"z\").unwrap();")
|
||||
out ("let (neg3, ebytes) = case.get(\"e\").unwrap();")
|
||||
out ("let (neg4, obytes) = case.get(\"o\").unwrap();")
|
||||
out ("let (neg5, rbytes) = case.get(\"r\").unwrap();")
|
||||
out ("let (neg6, bbytes) = case.get(\"b\").unwrap();")
|
||||
out ("let (neg7, tbytes) = case.get(\"t\").unwrap();")
|
||||
out ("assert!(!neg0&&!neg1&&!neg2&&!neg3&&!neg4&&!neg5&&!neg6&&!neg7);")
|
||||
out ("let mut x = "++name++"::from_bytes(xbytes);")
|
||||
out ("let m = "++name++"::from_bytes(mbytes);")
|
||||
out ("let z = 1 == zbytes[0];")
|
||||
out ("let e = 1 == ebytes[0];")
|
||||
out ("let o = 1 == obytes[0];")
|
||||
out ("let r = "++name++"::from_bytes(rbytes);")
|
||||
out ("let b = usize::from("++name++"::from_bytes(bbytes));")
|
||||
out ("let t = 1 == tbytes[0];")
|
||||
out ("assert_eq!(x.is_zero(), z);")
|
||||
out ("assert_eq!(x.is_even(), e);")
|
||||
out ("assert_eq!(x.is_odd(), o);")
|
||||
out ("assert_eq!(x.testbit(b), t);")
|
||||
out ("x.mask(usize::from(&m));")
|
||||
out ("assert_eq!(x, r);")
|
||||
out ("});")
|
||||
zeroTests = generateZeroTests 0 entries
|
||||
bitlength = toLit bitsize
|
||||
bytelen = bitsize `div` 8
|
||||
bytelenlit = toLit bytelen
|
||||
bytebuffer = Delimited mempty Brace (Stream [
|
||||
Tree (Token mempty (LiteralTok (IntegerTok "0") Nothing)),
|
||||
Tree (Token mempty Semicolon),
|
||||
Tree (Token mempty (LiteralTok (IntegerTok (show bytelen)) Nothing))
|
||||
])
|
||||
entrieslit = toLit entries
|
||||
packerLines = generatePackerLines 0 (bitsize `div` 8)
|
||||
out $ show $ pretty' $ [sourceFile|
|
||||
use core::cmp::min;
|
||||
use crate::CryptoNum;
|
||||
#[cfg(test)]
|
||||
use crate::testing::{build_test_path,run_test};
|
||||
#[cfg(test)]
|
||||
use quickcheck::quickcheck;
|
||||
use super::$$sname;
|
||||
|
||||
impl CryptoNum for $$sname {
|
||||
fn zero() -> Self {
|
||||
$$sname{ value: [0; $$(entlit)] }
|
||||
}
|
||||
fn is_zero(&self) -> bool {
|
||||
let mut result = true;
|
||||
$@{zeroTests}
|
||||
result
|
||||
}
|
||||
fn is_even(&self) -> bool {
|
||||
self.value[0] & 0x1 == 0
|
||||
}
|
||||
fn is_off(&self) -> bool {
|
||||
self.value[0] & 0x1 == 1
|
||||
}
|
||||
fn bit_length() -> usize {
|
||||
$$(bitlength)
|
||||
}
|
||||
fn mask(&mut self, len: usize) {
|
||||
let dellen = min(len, $$(entrieslit));
|
||||
for i in dellen..$$(entrieslit) {
|
||||
self.value[i] = 0;
|
||||
}
|
||||
}
|
||||
fn testbit(&self, bit: usize) -> bool {
|
||||
let idx = bit / 64;
|
||||
let offset = bit % 64;
|
||||
if idx >= $$(entrieslit) {
|
||||
return false;
|
||||
}
|
||||
(self.value[idx] & (1u64 << offset)) != 0
|
||||
}
|
||||
fn from_bytes(bytes: &[u8]) -> Self {
|
||||
let biggest = min($$(bytelenlit), bytes.len()) - 1;
|
||||
let mut idx = biggest / 8;
|
||||
let mut shift = (biggest % 8) * 8;
|
||||
let mut i = 0;
|
||||
let mut res = $$sname::zero();
|
||||
|
||||
while i <= biggest {
|
||||
res.value[idx] |= (bytes[i] as u64) << shift;
|
||||
i += 1;
|
||||
if shift == 0 {
|
||||
shift = 56;
|
||||
if idx > 0 {
|
||||
idx -= 1;
|
||||
}
|
||||
} else {
|
||||
shift -= 8;
|
||||
}
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
fn to_bytes(&self, bytes: &mut [u8]) {
|
||||
let mut idx = 0;
|
||||
let mut shift = 0;
|
||||
|
||||
for x in bytes.iter_mut().take($$(bytelenlit)).reverse() {
|
||||
*x = (self.values[idx] >> shift) as u8;
|
||||
shift += 8;
|
||||
if shift == 64 {
|
||||
idx += 1;
|
||||
shift = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
quickcheck! {
|
||||
fn to_from_ident(x: $$sname) -> bool {
|
||||
let mut buffer = $$(bytebuffer);
|
||||
x.to_bytes(&mut buffer);
|
||||
let y = $$sname::from_bytes(&buffer);
|
||||
x == y
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(non_snake_case)]
|
||||
#[test]
|
||||
fn KATs() {
|
||||
run_test(build_test_path("base", stringify!($$sname)), 8, |case| {
|
||||
let (neg0, xbytes) = case.get("x").unwrap();
|
||||
let (neg1, mbytes) = case.get("m").unwrap();
|
||||
let (neg2, zbytes) = case.get("z").unwrap();
|
||||
let (neg3, ebytes) = case.get("e").unwrap();
|
||||
let (neg4, obytes) = case.get("o").unwrap();
|
||||
let (neg5, rbytes) = case.get("r").unwrap();
|
||||
let (neg6, bbytes) = case.get("b").unwrap();
|
||||
let (neg7, tbytes) = case.get("t").unwrap();
|
||||
});
|
||||
}
|
||||
|]
|
||||
|
||||
byteShiftInfo :: Word -> (Word, Word)
|
||||
byteShiftInfo idx =
|
||||
@@ -143,4 +149,25 @@ byteShiftInfo idx =
|
||||
pad :: Int -> Char -> String -> String
|
||||
pad len c str
|
||||
| length str >= len = str
|
||||
| otherwise = pad len c (c:str)
|
||||
| otherwise = pad len c (c:str)
|
||||
|
||||
generateZeroTests :: Word -> Word -> [Stmt Span]
|
||||
generateZeroTests i max
|
||||
| i == max = []
|
||||
| otherwise =
|
||||
let ilit = toLit i
|
||||
in [stmt| result = self.values[$$(ilit)] == 0; |] :
|
||||
generateZeroTests (i + 1) max
|
||||
|
||||
generatePackerLines :: Word -> Word -> [Stmt Span]
|
||||
generatePackerLines i max
|
||||
| i == max = []
|
||||
| otherwise =
|
||||
let ilit = toLit i
|
||||
nextLit = toLit (i + 1)
|
||||
validx = toLit (i `div` 8)
|
||||
shiftx = toLit ((i `mod` 8) * 8)
|
||||
writeLine = [stmt| bytes[$$(ilit)] = (self.values[$$(validx)] >> $$(shiftx)) as u8; |]
|
||||
ifLine = [stmt| if bytes.len() == $$(nextLit) { return; } |]
|
||||
in writeLine : ifLine : generatePackerLines (i + 1) max
|
||||
|
||||
|
||||
Reference in New Issue
Block a user