Skip to content

Commit

Permalink
try by-4
Browse files Browse the repository at this point in the history
  • Loading branch information
ctz committed Oct 27, 2024
1 parent 708a52a commit fad7cf0
Showing 1 changed file with 114 additions and 3 deletions.
117 changes: 114 additions & 3 deletions graviola/src/low/aarch64/aes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,41 @@ impl AesKey {
let inc = vsetq_lane_u8(1, vdupq_n_u8(0), 15);
let inc = vreinterpretq_u32_u8(vrev32q_u8(inc));

let mut exact = cipher_inout.chunks_exact_mut(16);
let mut quads = cipher_inout.chunks_exact_mut(64);

for cipher in exact.by_ref() {
for cipher4 in quads.by_ref() {
counter = vaddq_u32(counter, inc);
let blocka = vrev32q_u8(vreinterpretq_u8_u32(counter));
counter = vaddq_u32(counter, inc);
let blockb = vrev32q_u8(vreinterpretq_u8_u32(counter));
counter = vaddq_u32(counter, inc);
let blockc = vrev32q_u8(vreinterpretq_u8_u32(counter));
counter = vaddq_u32(counter, inc);
let blockd = vrev32q_u8(vreinterpretq_u8_u32(counter));

let (blocka, blockb, blockc, blockd) = match self {
Self::Aes128(a128) => {
_aes128_4_blocks(&a128.round_keys, blocka, blockb, blockc, blockd)
}
Self::Aes256(a256) => {
_aes256_4_blocks(&a256.round_keys, blocka, blockb, blockc, blockd)
}
};

let blocka = veorq_u8(vld1q_u8(cipher4.as_ptr().add(0).cast()), blocka);
let blockb = veorq_u8(vld1q_u8(cipher4.as_ptr().add(16).cast()), blockb);
let blockc = veorq_u8(vld1q_u8(cipher4.as_ptr().add(32).cast()), blockc);
let blockd = veorq_u8(vld1q_u8(cipher4.as_ptr().add(48).cast()), blockd);

vst1q_u8(cipher4.as_mut_ptr().add(0).cast(), blocka);
vst1q_u8(cipher4.as_mut_ptr().add(16).cast(), blockb);
vst1q_u8(cipher4.as_mut_ptr().add(32).cast(), blockc);
vst1q_u8(cipher4.as_mut_ptr().add(48).cast(), blockd);
}

let mut singles = quads.into_remainder().chunks_exact_mut(16);

for cipher in singles.by_ref() {
counter = vaddq_u32(counter, inc);
let block = vrev32q_u8(vreinterpretq_u8_u32(counter));

Expand All @@ -66,7 +98,7 @@ impl AesKey {
vst1q_u8(cipher.as_mut_ptr().cast(), block);
}

let cipher_inout = exact.into_remainder();
let cipher_inout = singles.into_remainder();
if !cipher_inout.is_empty() {
let mut cipher = [0u8; 16];
let len = cipher_inout.len();
Expand Down Expand Up @@ -243,6 +275,50 @@ unsafe fn _aes128_block(round_keys: &[uint8x16_t; 11], block: uint8x16_t) -> uin
veorq_u8(block, round_keys[10])
}

// Applies one full AES round to four block states in place, all under the
// same round key: each state is passed through `vaeseq_u8` with `$rk` and the
// result is then passed through `vaesmcq_u8`.
//
// `$b1`..`$b4` must name mutable bindings, since each one is reassigned.
// `$rk` is expanded once per block, so it should be a cheap, side-effect-free
// expression.
macro_rules! round_4 {
    ($b1:ident, $b2:ident, $b3:ident, $b4:ident, $rk:expr) => {
        $b1 = vaesmcq_u8(vaeseq_u8($b1, $rk));
        $b2 = vaesmcq_u8(vaeseq_u8($b2, $rk));
        $b3 = vaesmcq_u8(vaeseq_u8($b3, $rk));
        $b4 = vaesmcq_u8(vaeseq_u8($b4, $rk));
    };
}

/// Runs the AES-128 round sequence over four block states in lockstep, all
/// using the same expanded key.
///
/// `round_keys[0..=8]` each drive one `round_4!` step (`vaeseq_u8` followed by
/// `vaesmcq_u8`), `round_keys[9]` is used for a final `vaeseq_u8` with no
/// `vaesmcq_u8`, and `round_keys[10]` is XORed in last.
///
/// Returns the four resulting states in the same order as the inputs.
///
/// # Safety
///
/// This function is compiled with the `aes` target feature enabled; the
/// caller must ensure the running CPU supports it.
#[target_feature(enable = "aes")]
#[inline]
unsafe fn _aes128_4_blocks(
    round_keys: &[uint8x16_t; 11],
    mut b1: uint8x16_t,
    mut b2: uint8x16_t,
    mut b3: uint8x16_t,
    mut b4: uint8x16_t,
) -> (uint8x16_t, uint8x16_t, uint8x16_t, uint8x16_t) {
    // Split the schedule into the nine keys used by full rounds and the two
    // keys consumed by the shortened last round.
    let [full_rounds @ .., penultimate_key, final_key] = round_keys;

    // Fixed trip count: one iteration per full-round key, in schedule order.
    for rk in full_rounds {
        round_4!(b1, b2, b3, b4, *rk);
    }

    // Last round: no `vaesmcq_u8` step.
    b1 = vaeseq_u8(b1, *penultimate_key);
    b2 = vaeseq_u8(b2, *penultimate_key);
    b3 = vaeseq_u8(b3, *penultimate_key);
    b4 = vaeseq_u8(b4, *penultimate_key);

    (
        veorq_u8(b1, *final_key),
        veorq_u8(b2, *final_key),
        veorq_u8(b3, *final_key),
        veorq_u8(b4, *final_key),
    )
}

#[target_feature(enable = "aes")]
unsafe fn aes256_block(round_keys: &[uint8x16_t; 15], block_inout: &mut [u8]) {
let block = vld1q_u8(block_inout.as_ptr() as *const _);
Expand Down Expand Up @@ -283,6 +359,41 @@ unsafe fn _aes256_block(round_keys: &[uint8x16_t; 15], block: uint8x16_t) -> uin
veorq_u8(block, round_keys[14])
}

/// Runs the AES-256 round sequence over four block states in lockstep, all
/// using the same expanded key.
///
/// `round_keys[0..=12]` each drive one `round_4!` step (`vaeseq_u8` followed
/// by `vaesmcq_u8`), `round_keys[13]` is used for a final `vaeseq_u8` with no
/// `vaesmcq_u8`, and `round_keys[14]` is XORed in last.
///
/// Returns the four resulting states in the same order as the inputs.
///
/// # Safety
///
/// This function is compiled with the `aes` target feature enabled; the
/// caller must ensure the running CPU supports it.
#[target_feature(enable = "aes")]
#[inline]
unsafe fn _aes256_4_blocks(
    round_keys: &[uint8x16_t; 15],
    mut b1: uint8x16_t,
    mut b2: uint8x16_t,
    mut b3: uint8x16_t,
    mut b4: uint8x16_t,
) -> (uint8x16_t, uint8x16_t, uint8x16_t, uint8x16_t) {
    // Split the schedule into the thirteen keys used by full rounds and the
    // two keys consumed by the shortened last round.
    let [full_rounds @ .., penultimate_key, final_key] = round_keys;

    // Fixed trip count: one iteration per full-round key, in schedule order.
    for rk in full_rounds {
        round_4!(b1, b2, b3, b4, *rk);
    }

    // Last round: no `vaesmcq_u8` step.
    b1 = vaeseq_u8(b1, *penultimate_key);
    b2 = vaeseq_u8(b2, *penultimate_key);
    b3 = vaeseq_u8(b3, *penultimate_key);
    b4 = vaeseq_u8(b4, *penultimate_key);

    (
        veorq_u8(b1, *final_key),
        veorq_u8(b2, *final_key),
        veorq_u8(b3, *final_key),
        veorq_u8(b4, *final_key),
    )
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit fad7cf0

Please sign in to comment.