diff --git a/generation/src/CryptoNum.hs b/generation/src/CryptoNum.hs index f7b2731..bee2492 100644 --- a/generation/src/CryptoNum.hs +++ b/generation/src/CryptoNum.hs @@ -24,7 +24,9 @@ declareCryptoNumInstance bitsize = out "#[cfg(test)]" out "use crate::testing::{build_test_path,run_test};" out "#[cfg(test)]" - out "use quickcheck::quickcheck;" + out "use quickcheck::{Arbitrary,Gen,quickcheck};" + out "#[cfg(test)]" + out "use std::fmt;" out ("use super::" ++ name ++ ";") blank implFor "CryptoNum" name $ @@ -96,13 +98,43 @@ declareCryptoNumInstance bitsize = out ("idx -= 1;") out ("bytes[idx] = byte" ++ show (bytes-1) ++ ";") blank + let bytes = bitsize `div` 8 + struct = "Bytes" ++ show bytes + out "#[cfg(test)]" + out "#[derive(Clone)]" + wrapIndent ("struct " ++ struct) $ + out ("value: [u8; " ++ show bytes ++ "]") + blank + out "#[cfg(test)]" + implFor "PartialEq" struct $ + wrapIndent ("fn eq(&self, other: &Self) -> bool") $ + out "self.value.iter().zip(other.value.iter()).all(|(a,b)| a == b)" + blank + out "#[cfg(test)]" + implFor "fmt::Debug" struct $ + wrapIndent ("fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result") $ + out "f.debug_list().entries(self.value.iter()).finish()" + blank + out "#[cfg(test)]" + implFor "Arbitrary" struct $ + wrapIndent ("fn arbitrary(g: &mut G) -> Self") $ + do out ("let mut res = " ++ struct ++ "{ value: [0; " ++ show bytes ++ "] };") + out ("g.fill_bytes(&mut res.value);") + out ("res") + blank out "#[cfg(test)]" wrapIndent "quickcheck!" $ do wrapIndent ("fn to_from_ident(x: " ++ name ++ ") -> bool") $ - do out ("let mut buffer = [0; " ++ show (bitsize `div` 8) ++ "];") + do out ("let mut buffer = [0; " ++ show bytes ++ "];") out ("x.to_bytes(&mut buffer);"); out ("let y = " ++ name ++ "::from_bytes(&buffer);") out ("x == y") + blank + wrapIndent ("fn from_to_ident(x: " ++ struct ++ ") -> bool") $ + do out ("let val = " ++ name ++ "::from_bytes(&x.value);") + out ("let mut buffer = [0; " ++ show bytes ++ "];") + out ("val.to_bytes(&mut buffer);") + out ("buffer.iter().zip(x.value.iter()).all(|(a,b)| a == b)") blank out "#[cfg(test)]" out "#[allow(non_snake_case)]"