diff --git a/src/cryptonum/complete_arith.rs b/src/cryptonum/complete_arith.rs index eb6b467..9b828b5 100644 --- a/src/cryptonum/complete_arith.rs +++ b/src/cryptonum/complete_arith.rs @@ -48,3 +48,79 @@ macro_rules! derive_arithmetic_operators } } } + +macro_rules! derive_shift_operators +{ + ($type: ident, $asncl: ident, $cl: ident, + $asnfn: ident, $fn: ident, + $base: ident) => + { + impl $asncl<$base> for $type { + fn $asnfn(&mut self, rhs: $base) { + self.$asnfn(rhs as u64); + } + } + + derive_shifts_from_shift_assign!($type, $asncl, $cl, + $asnfn, $fn, + $base); + } +} + +macro_rules! derive_shifts_from_shift_assign +{ + ($type: ident, $asncl: ident, $cl: ident, + $asnfn: ident, $fn: ident, + $base: ident) => + { + impl $cl<$base> for $type { + type Output = $type; + + fn $fn(self, rhs: $base) -> $type { + let mut copy = self.clone(); + copy.$asnfn(rhs); + copy + } + } + + impl<'a> $cl<$base> for &'a $type { + type Output = $type; + + fn $fn(self, rhs: $base) -> $type { + let mut copy = self.clone(); + copy.$asnfn(rhs); + copy + } + } + } +} + +macro_rules! derive_signed_shift_operators +{ + ($type: ident, $base: ident, $signed_base: ident) => { + impl ShlAssign<$signed_base> for $type { + fn shl_assign(&mut self, rhs: $signed_base) { + if rhs < 0 { + self.shr_assign(-rhs); + } else { + self.shl_assign(rhs as $base); + } + } + } + + impl ShrAssign<$signed_base> for $type { + fn shr_assign(&mut self, rhs: $signed_base) { + if rhs < 0 { + self.shl_assign(-rhs); + } else { + self.shr_assign(rhs); + } + } + } + + derive_shifts_from_shift_assign!($type, ShlAssign, Shl, + shl_assign, shl, $signed_base); + derive_shifts_from_shift_assign!($type, ShrAssign, Shr, + shr_assign, shr, $signed_base); + } +} diff --git a/src/cryptonum/mod.rs b/src/cryptonum/mod.rs index 1dace16..1b3f3b5 100644 --- a/src/cryptonum/mod.rs +++ b/src/cryptonum/mod.rs @@ -309,6 +309,100 @@ derive_arithmetic_operators!(UCN, BitOr, bitor, BitOrAssign, bitor_assign); derive_arithmetic_operators!(UCN, BitXor, bitxor, BitXorAssign, bitxor_assign); derive_arithmetic_operators!(UCN, BitAnd, bitand, BitAndAssign, bitand_assign); +//------------------------------------------------------------------------------ +// +// Shifts +// +//------------------------------------------------------------------------------ + +impl ShlAssign for UCN { + fn shl_assign(&mut self, rhs: u64) { + let mut digits = rhs / 64; + let bits = rhs % 64; + let mut carry = 0; + + // ripple the bit-level shift through + if bits != 0 { + for x in self.contents.iter_mut() { + let new_carry = *x >> (64 - bits); + *x = (*x << bits) | carry; + carry = new_carry; + } + } + + // if we pulled some stuff off the end, add it back + if carry != 0 { + self.contents.push(carry); + } + + // add the appropriate digits on the low side + while digits > 0 { + self.contents.insert(0,0); + digits -= 1; + } + } +} + +impl Shl for UCN { + type Output = UCN; + + fn shl(self, rhs: u64) -> UCN { + let mut copy = self.clone(); + copy.shl_assign(rhs); + copy + } +} + +derive_shift_operators!(UCN, ShlAssign, Shl, shl_assign, shl, usize); +derive_shift_operators!(UCN, ShlAssign, Shl, shl_assign, shl, u32); +derive_shift_operators!(UCN, ShlAssign, Shl, shl_assign, shl, u16); +derive_shift_operators!(UCN, ShlAssign, Shl, shl_assign, shl, u8); + +impl ShrAssign for UCN { + fn shr_assign(&mut self, rhs: u64) { + let mut digits = rhs / 64; + let bits = rhs % 64; + + // remove the appropriate digits on the low side + while digits > 0 { + self.contents.remove(0); + digits -= 1; + } + // ripple the shifts over + let mut carry = 0; + let mask = !(0xFFFFFFFFFFFFFFFF << bits); + + for x in self.contents.iter_mut().rev() { + let base = *x >> bits; + let (new_carry, _) = (*x & mask).overflowing_shl((64-bits) as u32); + *x = base | carry; + carry = new_carry; + } + // in this case, we just junk the extra carry bits + } +} + +impl Shr for UCN { + type Output = UCN; + + fn shr(self, rhs: u64) -> UCN { + let mut copy = self.clone(); + copy.shr_assign(rhs); + copy + } +} + +derive_shift_operators!(UCN, ShrAssign, Shr, shr_assign, shr, usize); +derive_shift_operators!(UCN, ShrAssign, Shr, shr_assign, shr, u32); +derive_shift_operators!(UCN, ShrAssign, Shr, shr_assign, shr, u16); +derive_shift_operators!(UCN, ShrAssign, Shr, shr_assign, shr, u8); + +derive_signed_shift_operators!(UCN, usize, isize); +derive_signed_shift_operators!(UCN, u64, i64); +derive_signed_shift_operators!(UCN, u32, i32); +derive_signed_shift_operators!(UCN, u16, i16); +derive_signed_shift_operators!(UCN, u8, i8); + //------------------------------------------------------------------------------ // // Tests! @@ -475,6 +569,12 @@ mod test { let effs = UCN{ contents: contents }; (&a & &effs) == a } + fn shl_identity(a: UCN) -> bool { + (&a << 0) == a + } + fn shr_identity(a: UCN) -> bool { + (&a << 0) == a + } } quickcheck! { @@ -488,6 +588,9 @@ mod test { let zero = UCN{ contents: vec![] }; (&a & &zero) == zero } + fn shl_shr_annihilate(a: UCN, b: u8) -> bool { + ((&a << b) >> b) == a + } fn xor_inverse(a: UCN, b: UCN) -> bool { ((&a ^ &b) ^ &b) == a } @@ -521,4 +624,5 @@ mod test { (!(&a2 | &b2)) == (!a2 & !b2) } } + }