Skip to content

Commit

Permalink
Try another formulation of AES-CTR
Browse files Browse the repository at this point in the history
  • Loading branch information
ctz committed Oct 27, 2024
1 parent 9b02e3b commit 655adc6
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 46 deletions.
70 changes: 65 additions & 5 deletions graviola/src/low/aarch64/aes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,76 @@ impl AesKey {
}
}

pub(crate) fn round_keys(&self) -> (uint8x16_t, &[uint8x16_t], uint8x16_t) {
pub(crate) fn ctr(&self, initial_counter: &[u8; 16], cipher_inout: &mut [u8]) {
// SAFETY: this crate requires the `aes` & `neon` cpu features
unsafe { self._ctr(initial_counter, cipher_inout) }
}

#[target_feature(enable = "aes,neon")]
unsafe fn _ctr(&self, initial_counter: &[u8; 16], cipher_inout: &mut [u8]) {
// counter and inc are big endian, so must be vrev32q_u8'd before use (and after increment)
let counter = vld1q_u8(initial_counter.as_ptr().cast());
let mut counter = vreinterpretq_u32_u8(vrev32q_u8(counter));

let inc = vsetq_lane_u8(1, vdupq_n_u8(0), 15);
let inc = vreinterpretq_u32_u8(vrev32q_u8(inc));

let mut exact = cipher_inout.chunks_exact_mut(16);

for cipher in exact.by_ref() {
counter = vaddq_u32(counter, inc);
let mut block = vrev32q_u8(vreinterpretq_u8_u32(counter));

let (rks, rkn2, rkn1) = self.round_keys();
for rk in rks {
block = vaeseq_u8(block, *rk);
block = vaesmcq_u8(block);
}

let block = vaeseq_u8(block, rkn2);
let block = veorq_u8(block, rkn1);

let block = veorq_u8(vld1q_u8(cipher.as_ptr().cast()), block);
vst1q_u8(cipher.as_mut_ptr().cast(), block);
}

let cipher_inout = exact.into_remainder();
if !cipher_inout.is_empty() {
let mut cipher = [0u8; 16];
let len = cipher_inout.len();
debug_assert!(len < 16);
cipher[..len].copy_from_slice(cipher_inout);

counter = vaddq_u32(counter, inc);
let mut block = vrev32q_u8(vreinterpretq_u8_u32(counter));

let (rks, rkn2, rkn1) = self.round_keys();
for rk in rks {
block = vaeseq_u8(block, *rk);
block = vaesmcq_u8(block);
}

let block = vaeseq_u8(block, rkn2);
let block = veorq_u8(block, rkn1);

let block = veorq_u8(vld1q_u8(cipher.as_ptr().cast()), block);
vst1q_u8(cipher.as_mut_ptr().cast(), block);

cipher_inout.copy_from_slice(&cipher[..len]);
}
}

/// Returns the round keys: (0..N-2, N-2, N-1)
pub(crate) fn round_keys(&self) -> (&[uint8x16_t], uint8x16_t, uint8x16_t) {
match self {
Self::Aes128(a128) => (
a128.round_keys[0],
&a128.round_keys[1..10],
&a128.round_keys[0..9],
a128.round_keys[9],
a128.round_keys[10],
),
Self::Aes256(a256) => (
a256.round_keys[0],
&a256.round_keys[1..14],
&a256.round_keys[0..13],
a256.round_keys[13],
a256.round_keys[14],
),
}
Expand Down
43 changes: 2 additions & 41 deletions graviola/src/low/aarch64/aes_gcm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub(crate) fn encrypt(
cipher_inout: &mut [u8],
) {
ghash.add(aad);
cipher(key, initial_counter, cipher_inout);
key.ctr(initial_counter, cipher_inout);
ghash.add(cipher_inout);
}

Expand All @@ -25,44 +25,5 @@ pub(crate) fn decrypt(
) {
ghash.add(aad);
ghash.add(cipher_inout);
cipher(key, initial_counter, cipher_inout);
}

fn cipher(key: &AesKey, initial_counter: &[u8; 16], cipher_inout: &mut [u8]) {
let mut counter = *initial_counter;
let mut exact = cipher_inout.chunks_exact_mut(16);

for block in exact.by_ref() {
inc_counter(&mut counter);
ctr(key, &counter, block);
}

let cipher_inout = exact.into_remainder();
if !cipher_inout.is_empty() {
let mut block = [0u8; 16];
let len = cipher_inout.len();
debug_assert!(len < 16);
block[..len].copy_from_slice(cipher_inout);

inc_counter(&mut counter);
ctr(key, &counter, &mut block);

cipher_inout.copy_from_slice(&block[..len]);
}
}

#[inline]
fn ctr(key: &AesKey, counter: &[u8; 16], cipher_inout: &mut [u8]) {
let mut block = *counter;
key.encrypt_block(&mut block);
for (x, y) in cipher_inout.iter_mut().zip(block.iter()) {
*x ^= *y;
}
}

#[inline]
fn inc_counter(block: &mut [u8; 16]) {
let c = u32::from_be_bytes(block[12..].try_into().unwrap());
let c = c.wrapping_add(1);
block[12..].copy_from_slice(&c.to_be_bytes());
key.ctr(initial_counter, cipher_inout);
}

0 comments on commit 655adc6

Please sign in to comment.