Start working on switching to language-rust as a generator, for fun.

This commit is contained in:
2019-10-22 20:12:08 -07:00
parent d7665acf64
commit 2400b10fbc
9 changed files with 723 additions and 404 deletions

View File

@@ -1,3 +1,4 @@
{-# LANGUAGE QuasiQuotes #-}
module CryptoNum(
cryptoNum
)
@@ -6,6 +7,11 @@ module CryptoNum(
import Control.Monad(forM_)
import File
import Gen
import Language.Rust.Data.Ident
import Language.Rust.Data.Position
import Language.Rust.Quote
import Language.Rust.Pretty
import Language.Rust.Syntax
cryptoNum :: File
cryptoNum = File {
@@ -16,125 +22,125 @@ cryptoNum = File {
declareCryptoNumInstance :: Word -> Gen ()
declareCryptoNumInstance bitsize =
do let name = "U" ++ show bitsize
do let sname = mkIdent ("U" ++ show bitsize)
entries = bitsize `div` 64
entlit = Lit [] (Int Dec (fromIntegral entries) Unsuffixed mempty) mempty
top = entries - 1
out "use core::cmp::min;"
out "use crate::CryptoNum;"
out "#[cfg(test)]"
out "use crate::testing::{build_test_path,run_test};"
out "#[cfg(test)]"
out "use quickcheck::quickcheck;"
out ("use super::" ++ name ++ ";")
blank
implFor "CryptoNum" name $
do wrapIndent ("fn zero() -> Self") $
out (name ++ "{ value: [0; " ++ show entries ++ "] }")
blank
wrapIndent ("fn is_zero(&self) -> bool") $
do forM_ (reverse [1..top]) $ \ i ->
out ("self.value[" ++ show i ++ "] == 0 &&")
out "self.value[0] == 0"
blank
wrapIndent ("fn is_even(&self) -> bool") $
out "self.value[0] & 0x1 == 0"
blank
wrapIndent ("fn is_odd(&self) -> bool") $
out "self.value[0] & 0x1 == 1"
blank
wrapIndent ("fn bit_length() -> usize") $
out (show bitsize)
blank
wrapIndent ("fn mask(&mut self, len: usize)") $
do out ("let dellen = min(len, " ++ show entries ++ ");")
wrapIndent ("for i in dellen.." ++ show entries) $
out ("self.value[i] = 0;")
blank
wrapIndent ("fn testbit(&self, bit: usize) -> bool") $
do out "let idx = bit / 64;"
out "let offset = bit % 64;"
wrapIndent ("if idx >= " ++ show entries) $
out "return false;"
out "(self.value[idx] & (1u64 << offset)) != 0"
blank
wrapIndent ("fn from_bytes(bytes: &[u8]) -> Self") $
do out ("let biggest = min(" ++ show (bitsize `div` 8) ++ ", " ++
"bytes.len()) - 1;")
out ("let mut idx = biggest / 8;")
out ("let mut shift = (biggest % 8) * 8;")
out ("let mut i = 0;")
out ("let mut res = " ++ name ++ "::zero();")
blank
wrapIndent ("while i <= biggest") $
do out ("res.value[idx] |= (bytes[i] as u64) << shift;")
out ("i += 1;")
out ("if shift == 0 {")
indent $
do out "shift = 56;"
out "if idx > 0 { idx -= 1; }"
out ("} else {")
indent $
out "shift -= 8;"
out "}"
blank
out "res"
blank
wrapIndent ("fn to_bytes(&self, bytes: &mut [u8])") $
do let bytes = bitsize `div` 8
out ("if bytes.len() == 0 { return; }")
blank
forM_ [0..bytes-1] $ \ idx ->
do let (validx, shift) = byteShiftInfo idx
out ("let byte" ++ show idx ++ " = (self.value[" ++
show validx ++ "] >> " ++ show shift ++ ")" ++
" as u8;")
blank
out ("let mut idx = min(bytes.len() - 1, " ++ show (bytes - 1) ++ ");")
forM_ [0..bytes-2] $ \ i ->
do out ("bytes[idx] = byte" ++ show i ++ ";")
out ("if idx == 0 { return; }")
out ("idx -= 1;")
out ("bytes[idx] = byte" ++ show (bytes-1) ++ ";")
blank
out "#[cfg(test)]"
wrapIndent "quickcheck!" $
do wrapIndent ("fn to_from_ident(x: " ++ name ++ ") -> bool") $
do out ("let mut buffer = [0; " ++ show (bitsize `div` 8) ++ "];")
out ("x.to_bytes(&mut buffer);");
out ("let y = " ++ name ++ "::from_bytes(&buffer);")
out ("x == y")
blank
out "#[cfg(test)]"
out "#[allow(non_snake_case)]"
out "#[test]"
wrapIndent "fn KATs()" $
do let name' = pad 5 '0' (show bitsize)
out ("run_test(build_test_path(\"base\",\"" ++ name' ++ "\"), 8, |case| {")
indent $
do out ("let (neg0, xbytes) = case.get(\"x\").unwrap();")
out ("let (neg1, mbytes) = case.get(\"m\").unwrap();")
out ("let (neg2, zbytes) = case.get(\"z\").unwrap();")
out ("let (neg3, ebytes) = case.get(\"e\").unwrap();")
out ("let (neg4, obytes) = case.get(\"o\").unwrap();")
out ("let (neg5, rbytes) = case.get(\"r\").unwrap();")
out ("let (neg6, bbytes) = case.get(\"b\").unwrap();")
out ("let (neg7, tbytes) = case.get(\"t\").unwrap();")
out ("assert!(!neg0&&!neg1&&!neg2&&!neg3&&!neg4&&!neg5&&!neg6&&!neg7);")
out ("let mut x = "++name++"::from_bytes(xbytes);")
out ("let m = "++name++"::from_bytes(mbytes);")
out ("let z = 1 == zbytes[0];")
out ("let e = 1 == ebytes[0];")
out ("let o = 1 == obytes[0];")
out ("let r = "++name++"::from_bytes(rbytes);")
out ("let b = usize::from("++name++"::from_bytes(bbytes));")
out ("let t = 1 == tbytes[0];")
out ("assert_eq!(x.is_zero(), z);")
out ("assert_eq!(x.is_even(), e);")
out ("assert_eq!(x.is_odd(), o);")
out ("assert_eq!(x.testbit(b), t);")
out ("x.mask(usize::from(&m));")
out ("assert_eq!(x, r);")
out ("});")
zeroTests = generateZeroTests 0 entries
bitlength = toLit bitsize
bytelen = bitsize `div` 8
bytelenlit = toLit bytelen
bytebuffer = Delimited mempty Brace (Stream [
Tree (Token mempty (LiteralTok (IntegerTok "0") Nothing)),
Tree (Token mempty Semicolon),
Tree (Token mempty (LiteralTok (IntegerTok (show bytelen)) Nothing))
])
entrieslit = toLit entries
packerLines = generatePackerLines 0 (bitsize `div` 8)
out $ show $ pretty' $ [sourceFile|
use core::cmp::min;
use crate::CryptoNum;
#[cfg(test)]
use crate::testing::{build_test_path,run_test};
#[cfg(test)]
use quickcheck::quickcheck;
use super::$$sname;
impl CryptoNum for $$sname {
fn zero() -> Self {
$$sname{ value: [0; $$(entlit)] }
}
fn is_zero(&self) -> bool {
let mut result = true;
$@{zeroTests}
result
}
fn is_even(&self) -> bool {
self.value[0] & 0x1 == 0
}
fn is_off(&self) -> bool {
self.value[0] & 0x1 == 1
}
fn bit_length() -> usize {
$$(bitlength)
}
fn mask(&mut self, len: usize) {
let dellen = min(len, $$(entrieslit));
for i in dellen..$$(entrieslit) {
self.value[i] = 0;
}
}
fn testbit(&self, bit: usize) -> bool {
let idx = bit / 64;
let offset = bit % 64;
if idx >= $$(entrieslit) {
return false;
}
(self.value[idx] & (1u64 << offset)) != 0
}
fn from_bytes(bytes: &[u8]) -> Self {
let biggest = min($$(bytelenlit), bytes.len()) - 1;
let mut idx = biggest / 8;
let mut shift = (biggest % 8) * 8;
let mut i = 0;
let mut res = $$sname::zero();
while i <= biggest {
res.value[idx] |= (bytes[i] as u64) << shift;
i += 1;
if shift == 0 {
shift = 56;
if idx > 0 {
idx -= 1;
}
} else {
shift -= 8;
}
}
res
}
fn to_bytes(&self, bytes: &mut [u8]) {
let mut idx = 0;
let mut shift = 0;
for x in bytes.iter_mut().take($$(bytelenlit)).reverse() {
*x = (self.values[idx] >> shift) as u8;
shift += 8;
if shift == 64 {
idx += 1;
shift = 0;
}
}
}
}
#[cfg(test)]
quickcheck! {
fn to_from_ident(x: $$sname) -> bool {
let mut buffer = $$(bytebuffer);
x.to_bytes(&mut buffer);
let y = $$sname::from_bytes(&buffer);
x == y
}
}
#[cfg(test)]
#[allow(non_snake_case)]
#[test]
fn KATs() {
run_test(build_test_path("base", stringify!($$sname)), 8, |case| {
let (neg0, xbytes) = case.get("x").unwrap();
let (neg1, mbytes) = case.get("m").unwrap();
let (neg2, zbytes) = case.get("z").unwrap();
let (neg3, ebytes) = case.get("e").unwrap();
let (neg4, obytes) = case.get("o").unwrap();
let (neg5, rbytes) = case.get("r").unwrap();
let (neg6, bbytes) = case.get("b").unwrap();
let (neg7, tbytes) = case.get("t").unwrap();
});
}
|]
byteShiftInfo :: Word -> (Word, Word)
byteShiftInfo idx =
@@ -143,4 +149,25 @@ byteShiftInfo idx =
pad :: Int -> Char -> String -> String
pad len c str
| length str >= len = str
| otherwise = pad len c (c:str)
| otherwise = pad len c (c:str)
generateZeroTests :: Word -> Word -> [Stmt Span]
generateZeroTests i max
| i == max = []
| otherwise =
let ilit = toLit i
in [stmt| result = self.values[$$(ilit)] == 0; |] :
generateZeroTests (i + 1) max
generatePackerLines :: Word -> Word -> [Stmt Span]
generatePackerLines i max
| i == max = []
| otherwise =
let ilit = toLit i
nextLit = toLit (i + 1)
validx = toLit (i `div` 8)
shiftx = toLit ((i `mod` 8) * 8)
writeLine = [stmt| bytes[$$(ilit)] = (self.values[$$(validx)] >> $$(shiftx)) as u8; |]
ifLine = [stmt| if bytes.len() == $$(nextLit) { return; } |]
in writeLine : ifLine : generatePackerLines (i + 1) max