diff --git a/graviola/src/low/aarch64/aes.rs b/graviola/src/low/aarch64/aes.rs index b3658630..51d8c702 100644 --- a/graviola/src/low/aarch64/aes.rs +++ b/graviola/src/low/aarch64/aes.rs @@ -38,16 +38,76 @@ impl AesKey { } } - pub(crate) fn round_keys(&self) -> (uint8x16_t, &[uint8x16_t], uint8x16_t) { + pub(crate) fn ctr(&self, initial_counter: &[u8; 16], cipher_inout: &mut [u8]) { + // SAFETY: this crate requires the `aes` & `neon` cpu features + unsafe { self._ctr(initial_counter, cipher_inout) } + } + + #[target_feature(enable = "aes,neon")] + unsafe fn _ctr(&self, initial_counter: &[u8; 16], cipher_inout: &mut [u8]) { + // counter and inc are big endian, so must be vrev32q_u8'd before use (and after increment) + let counter = vld1q_u8(initial_counter.as_ptr().cast()); + let mut counter = vreinterpretq_u32_u8(vrev32q_u8(counter)); + + let inc = vsetq_lane_u8(1, vdupq_n_u8(0), 15); + let inc = vreinterpretq_u32_u8(vrev32q_u8(inc)); + + let mut exact = cipher_inout.chunks_exact_mut(16); + + for cipher in exact.by_ref() { + counter = vaddq_u32(counter, inc); + let mut block = vrev32q_u8(vreinterpretq_u8_u32(counter)); + + let (rks, rkn2, rkn1) = self.round_keys(); + for rk in rks { + block = vaeseq_u8(block, *rk); + block = vaesmcq_u8(block); + } + + let block = vaeseq_u8(block, rkn2); + let block = veorq_u8(block, rkn1); + + let block = veorq_u8(vld1q_u8(cipher.as_ptr().cast()), block); + vst1q_u8(cipher.as_mut_ptr().cast(), block); + } + + let cipher_inout = exact.into_remainder(); + if !cipher_inout.is_empty() { + let mut cipher = [0u8; 16]; + let len = cipher_inout.len(); + debug_assert!(len < 16); + cipher[..len].copy_from_slice(cipher_inout); + + counter = vaddq_u32(counter, inc); + let mut block = vrev32q_u8(vreinterpretq_u8_u32(counter)); + + let (rks, rkn2, rkn1) = self.round_keys(); + for rk in rks { + block = vaeseq_u8(block, *rk); + block = vaesmcq_u8(block); + } + + let block = vaeseq_u8(block, rkn2); + let block = veorq_u8(block, rkn1); + + let block = veorq_u8(vld1q_u8(cipher.as_ptr().cast()), block); + vst1q_u8(cipher.as_mut_ptr().cast(), block); + + cipher_inout.copy_from_slice(&cipher[..len]); + } + } + + /// Returns the round keys: (0..N-2, N-2, N-1) + pub(crate) fn round_keys(&self) -> (&[uint8x16_t], uint8x16_t, uint8x16_t) { match self { Self::Aes128(a128) => ( - a128.round_keys[0], - &a128.round_keys[1..10], + &a128.round_keys[0..9], + a128.round_keys[9], a128.round_keys[10], ), Self::Aes256(a256) => ( - a256.round_keys[0], - &a256.round_keys[1..14], + &a256.round_keys[0..13], + a256.round_keys[13], a256.round_keys[14], ), } diff --git a/graviola/src/low/aarch64/aes_gcm.rs b/graviola/src/low/aarch64/aes_gcm.rs index a865c2ab..b7083504 100644 --- a/graviola/src/low/aarch64/aes_gcm.rs +++ b/graviola/src/low/aarch64/aes_gcm.rs @@ -12,7 +12,7 @@ pub(crate) fn encrypt( cipher_inout: &mut [u8], ) { ghash.add(aad); - cipher(key, initial_counter, cipher_inout); + key.ctr(initial_counter, cipher_inout); ghash.add(cipher_inout); } @@ -25,44 +25,5 @@ pub(crate) fn decrypt( ) { ghash.add(aad); ghash.add(cipher_inout); - cipher(key, initial_counter, cipher_inout); -} - -fn cipher(key: &AesKey, initial_counter: &[u8; 16], cipher_inout: &mut [u8]) { - let mut counter = *initial_counter; - let mut exact = cipher_inout.chunks_exact_mut(16); - - for block in exact.by_ref() { - inc_counter(&mut counter); - ctr(key, &counter, block); - } - - let cipher_inout = exact.into_remainder(); - if !cipher_inout.is_empty() { - let mut block = [0u8; 16]; - let len = cipher_inout.len(); - debug_assert!(len < 16); - block[..len].copy_from_slice(cipher_inout); - - inc_counter(&mut counter); - ctr(key, &counter, &mut block); - - cipher_inout.copy_from_slice(&block[..len]); - } -} - -#[inline] -fn ctr(key: &AesKey, counter: &[u8; 16], cipher_inout: &mut [u8]) { - let mut block = *counter; - key.encrypt_block(&mut block); - for (x, y) in cipher_inout.iter_mut().zip(block.iter()) { - *x ^= *y; - } -} - -#[inline] -fn inc_counter(block: &mut [u8; 16]) { - let c = u32::from_be_bytes(block[12..].try_into().unwrap()); - let c = c.wrapping_add(1); - block[12..].copy_from_slice(&c.to_be_bytes()); + key.ctr(initial_counter, cipher_inout); }