Skip to content

Commit

Permalink
tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
pennyannn committed Aug 16, 2024
1 parent 9f261d1 commit b08cfac
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 40 deletions.
175 changes: 175 additions & 0 deletions Specs/AESGCMKernel.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author(s): Yan Peng
-/
import Arm.BitVec
import Specs.AESArm
import Specs.GCM
import Specs.AESV8
import Specs.GCMV8

namespace AESGCMKernel

open BitVec

structure AESGCMLoopVars (in_bytes : Nat) where
plaintext : BitVec (in_bytes * 8)
input_ptr : Nat -- how many bytes have been processed
main_end_input_ptr : Nat -- number of bytes to be processed in main loop
Xi : BitVec 128
Htable : List (BitVec 128)
key : AESV8.AESKey

ctr0 : BitVec 128
ctr1 : BitVec 128
ctr2 : BitVec 128
ctr3 : BitVec 128

aes0 : BitVec 128
aes1 : BitVec 128
aes2 : BitVec 128
aes3 : BitVec 128

ciphertext : BitVec (in_bytes * 8)
deriving DecidableEq, Repr

set_option diagnostics true
set_option maxRecDepth 100000

def AESGCMEncKernelLoop {Param : AESArm.KBR} (vars : AESGCMLoopVars in_bytes)
: AESGCMLoopVars in_bytes :=
if input_ptr ≥ main_end_input_ptr then vars
else
let bit_len := in_bytes * 8
let ctr3 := GCM.inc_s 32 vars.ctr2 (by omega) (by omega)
let vars := {vars with ctr3 := ctr3}
let H1 := vars.Htable.get! 0
let H2 := vars.Htable.get! 2
let H3 := vars.Htable.get! 3
let H4 := vars.Htable.get! 5
let Xi := GCMV8.gcm_ghash_block H4 vars.Xi vars.aes0
let Xi := GCMV8.gcm_ghash_block H3 Xi vars.aes1
let Xi := GCMV8.gcm_ghash_block H2 Xi vars.aes2
let Xi := GCMV8.gcm_ghash_block H1 Xi vars.aes3
let start := bit_len - vars.ptr
let block0 := extractLsb (start - 1) (start - 128) vars.plaintext
let block1 := extractLsb (start - 129) (start - 256) vars.plaintext
let block2 := extractLsb (start - 257) (start - 384) vars.plaintext
let block3 := extractLsb (start - 385) (start - 512) vars.plaintext
let aes0 := block0 ^^^
(AESArm.AES_encrypt_with_ks (Param := Param) vars.ctr0 vars.key.rd_key)
let aes1 := block1 ^^^
(AESArm.AES_encrypt_with_ks (Param := Param) vars.ctr1 vars.key.rd_key)
let aes2 := block2 ^^^
(AESArm.AES_encrypt_with_ks (Param := Param) vars.ctr2 vars.key.rd_key)
let aes3 := block3 ^^^
(AESArm.AES_encrypt_with_ks (Param := Param) vars.ctr3 vars.key.rd_key)
let ctr0 := GCM.inc_s 32 ctr3 (by omega) (by omega)
let ctr1 := GCM.inc_s 32 ctr0 (by omega) (by omega)
let ctr2 := GCM.inc_s 32 ctr1 (by omega) (by omega)
let ciphertext := BitVec.partInstall (vars.ptr + 127) vars.ptr aes0 vars.ciphertext
let ciphertext := BitVec.partInstall (vars.ptr + 255) (vars.ptr + 128) aes1 ciphertext
let ciphertext := BitVec.partInstall (vars.ptr + 383) (vars.ptr + 256) aes2 ciphertext
let ciphertext := BitVec.partInstall (vars.ptr + 511) (vars.ptr + 384) aes3 ciphertext
let vars : AESGCMLoopVars in_bytes :=
{ plaintext := vars.plaintext,
ptr := vars.ptr + 128 * 4,
Xi := Xi,
Htable := vars.Htable,
key := vars.key,
ctr0 := ctr0, ctr1 := ctr1,
ctr2 := ctr2, ctr3 := vars.ctr3,
aes0 := aes0, aes1 := aes1,
aes2 := aes2, aes3 := aes3,
ciphertext := ciphertext }
AESGCMEncKernelLoop (Param := Param) vars

def AESGCMEncKernelHelper {Param : AESArm.KBR}
(in_stream : BitVec (in_bytes * 8))
(Xi : BitVec 128) (ivec : BitVec 128)
(key : AESV8.AESKey) (Htable : List (BitVec 128))
: (List (BitVec 8) × BitVec 128) :=
let input_ptr := 0
-- subtracting 1 from in_bytes because the asssembly wants
-- last four blocks be handled in tail code
let main_end_input_ptr := input_ptr + ((in_bytes - 1) / 64 * 64)
let end_input_ptr := input_ptr + in_bytes
if input_ptr ≥ main_end_input_ptr -- if less or equal to 4 blocks
then sorry -- tail code
else
let ciphertext := BitVec.zero (in_bytes * 8)

have h1 : 128 = Param.block_size := by sorry
let ctr0 := BitVec.cast h1 ivec
let ctr1 := GCM.inc_s 32 ctr0 (by omega) (by omega)
let ctr2 := GCM.inc_s 32 ctr1 (by omega) (by omega)
let ctr3 := GCM.inc_s 32 ctr2 (by omega) (by omega)

let lo := input_ptr * 8
let hi := lo + 127
have h2 : hi - lo + 1 = 128 := by sorry
let block0 := BitVec.cast h2 $ extractLsb 127 0 in_stream
let lo := hi + 1
let hi := lo + 127
let block1 := BitVec.cast h2 $ extractLsb 255 128 in_stream
let lo := hi + 1
let hi := lo + 127
let block2 := BitVec.cast h2 $ extractLsb 383 256 in_stream
let lo := hi + 1
let hi := lo + 127
let block3 := BitVec.cast h2 $ extractLsb 511 384 in_stream

let aes0 := block0 ^^^ (BitVec.cast h1.symm
(AESArm.AES_encrypt_with_ks (Param := Param) ctr0 key.rd_key))
let aes1 := block1 ^^^ (BitVec.cast h1.symm
(AESArm.AES_encrypt_with_ks (Param := Param) ctr1 key.rd_key))
let aes2 := block2 ^^^ (BitVec.cast h1.symm
(AESArm.AES_encrypt_with_ks (Param := Param) ctr2 key.rd_key))
let aes3 := block3 ^^^ (BitVec.cast h1.symm
(AESArm.AES_encrypt_with_ks (Param := Param) ctr3 key.rd_key))

let ciphertext := BitVec.partInstall 127 0 aes0 ciphertext
let ciphertext := BitVec.partInstall 255 128 aes1 ciphertext
let ciphertext := BitVec.partInstall 383 256 aes2 ciphertext
let ciphertext := BitVec.partInstall 511 384 aes3 ciphertext

let ctr3 := BitVec.cast h1.symm ctr3
let ctr4 := GCM.inc_s 32 ctr3 (by omega) (by omega)
let ctr5 := GCM.inc_s 32 ctr4 (by omega) (by omega)
let ctr6 := GCM.inc_s 32 ctr5 (by omega) (by omega)
-- let ctr7 := GCM.inc_s 32 ctr6 (by omega) (by omega)
let vars : AESGCMLoopVars in_bytes :=
{ plaintext := in_stream,
input_ptr := input_ptr + 64,
main_end_input_ptr := main_end_input_ptr,
Xi := Xi,
Htable := Htable,
key := key,
ctr0 := ctr4, ctr1 := ctr5, ctr2 := ctr6, ctr3 := ctr3,
aes0 := aes0, aes1 := aes1, aes2 := aes2, aes3 := aes3,
ciphertext := ciphertext }
let res := AESGCMEncKernelLoop (Param := Param) vars
( split (BitVec.reverse res.ciphertext) 8 (by omega), res.Xi)

def AESGCMEncKernel (in_blocks : List (BitVec 8))
(in_bits : BitVec 64) (Xi : BitVec 128) (ivec : BitVec 128)
(key : AESV8.AESKey) (Htable : List (BitVec 128))
(h1 : key.rounds = 10#64 ∨ key.rounds = 12#64 ∨ key.rounds = 14#64)
(h2: 128 ∣ in_bits.toNat) -- in_blocks only contains whole blocks
(h3: 8 * in_blocks.length = in_bits.toNat)
: (List (BitVec 8) × BitVec 128) :=
let in_bytes := in_bits.toNat/8
have h: 8 * in_blocks.reverse.length = in_bytes * 8 := by
simp only [List.length_reverse, h3, in_bytes]; omega
let in_stream : BitVec (in_bytes * 8) :=
BitVec.cast h $ BitVec.flatten $ List.reverse in_blocks
let p := AESV8.KBR_from_AESKey key h1
AESGCMEncKernelHelper (Param := p) in_stream Xi ivec key Htable

def AESGCMDecKernel (in_blocks : List (BitVec 8)) (in_bits : BitVec 64)
(Xi : BitVec 128) (ivec : BitVec 128) (key : AESKey)
(Htable : List (BitVec 128)) : (List (BitVec 8) × BitVec 128) :=
sorry

end AESGCMKernel
72 changes: 33 additions & 39 deletions Specs/AESV8.lean
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,39 @@ example : let res :=
}
AESHWSetEncryptKey 0#128 (by simp) = res := by native_decide


def KBR_from_AESKey (key : AESKey)
(h : key.rounds = 10#64 ∨ key.rounds = 12#64 ∨ key.rounds = 14#64)
: AESArm.KBR :=
if key.rounds = 10 then AESArm.AES128KBR
else if key.rounds = 12 then AESArm.AES192KBR
else AESArm.AES256KBR

theorem KBR_from_AESKey_blocksize (key : AESKey)
(h : key.rounds = 10#64 ∨ key.rounds = 12#64 ∨ key.rounds = 14#64)
: AESArm.KBR.block_size (KBR_from_AESKey key h) = 128 := by
simp only [ KBR_from_AESKey,
AESArm.AES128KBR, AESArm.AES192KBR,
AESArm.AES256KBR, AESArm.BlockSize ]
cases h
· rename_i h; simp [h]
· rename_i h; cases h
· rename_i h; simp [h]
· rename_i h; simp [h]

def flat_rev_block (in_block : List (BitVec 8))
(h : 8 * in_block.length = 128) : BitVec 128 :=
let in_block := BitVec.flatten in_block
rev_elems 128 8 (BitVec.cast h in_block) (by decide) (by decide)

def AESHWEncrypt (in_block : List (BitVec 8)) (key : AESKey)
(h1 : key.rounds = 10#64 ∨ key.rounds = 12#64 ∨ key.rounds = 14#64)
(h2 : 8 * in_block.length = 128)
: List (BitVec 8) :=
let p : AESArm.KBR :=
if key.rounds = 10 then AESArm.AES128KBR
else if key.rounds = 12 then AESArm.AES192KBR
else AESArm.AES256KBR
let p : AESArm.KBR := KBR_from_AESKey key h1
-- AESArm.AES_encrypt_with_ks is little-endian
let in_block := BitVec.flatten in_block
let in_block :=
rev_elems 128 8 (BitVec.cast h2 in_block) (by decide) (by decide)
have h : p.block_size = 128 := by
simp only [ p, AESArm.AES128KBR, AESArm.AES192KBR,
AESArm.AES256KBR, AESArm.BlockSize ]
cases h1
· rename_i h; simp [h]
· rename_i h; cases h
· rename_i h; simp [h]
· rename_i h; simp [h]
let in_block := flat_rev_block in_block h2
have h : p.block_size = 128 := by apply KBR_from_AESKey_blocksize
let res_block :=
AESArm.AES_encrypt_with_ks (Param := p)
(BitVec.cast h.symm in_block) key.rd_key
Expand Down Expand Up @@ -103,10 +116,7 @@ example : let in_block := List.replicate 16 0#8

def AESHWCtr32EncryptBlocks_helper {Param : AESArm.KBR} (in_blocks : BitVec m)
(i : Nat) (len : Nat) (key : AESKey) (ivec : BitVec 128) (acc : BitVec m)
(h1 : 128 ∣ m) (h2 : m / 128 = len)
(h3 : Param = AESArm.AES128KBR
∨ Param = AESArm.AES192KBR
∨ Param = AESArm.AES256KBR)
(h1 : 128 ∣ m) (h2 : m / 128 = len) (h3 : Param.block_size = 128)
: BitVec m :=
if i >= len then acc
else
Expand All @@ -115,16 +125,10 @@ def AESHWCtr32EncryptBlocks_helper {Param : AESArm.KBR} (in_blocks : BitVec m)
have h5 : hi - lo + 1 = 128 := by omega
let curr_block : BitVec 128 :=
BitVec.cast h5 $ BitVec.extractLsb hi lo in_blocks
have h4 : 128 = Param.block_size := by
cases h3
· rename_i h; simp only [h, AESArm.AES128KBR, AESArm.BlockSize]
· rename_i h; cases h
· rename_i h; simp only [h, AESArm.AES192KBR, AESArm.BlockSize]
· rename_i h; simp only [h, AESArm.AES256KBR, AESArm.BlockSize]
let ivec_rev := rev_elems 128 8 ivec (by decide) (by decide)
let res_block : BitVec 128 :=
BitVec.cast h4.symm $ AESArm.AES_encrypt_with_ks
(Param := Param) (BitVec.cast h4 ivec_rev) key.rd_key
BitVec.cast h3 $ AESArm.AES_encrypt_with_ks
(Param := Param) (BitVec.cast h3.symm ivec_rev) key.rd_key
let res_block := rev_elems 128 8 res_block (by decide) (by decide)
let res_block := res_block ^^^ curr_block
let new_acc := BitVec.partInstall hi lo (BitVec.cast h5.symm res_block) acc
Expand All @@ -137,18 +141,8 @@ def AESHWCtr32EncryptBlocks (in_blocks : List (BitVec 8)) (len : Nat)
(h1 : key.rounds = 10#64 ∨ key.rounds = 12#64 ∨ key.rounds = 14#64)
(h2 : 16 ∣ in_blocks.length) (h3 : in_blocks.length / 16 = len)
: List (BitVec 8) :=
let p : AESArm.KBR :=
if key.rounds = 10 then AESArm.AES128KBR
else if key.rounds = 12 then AESArm.AES192KBR
else AESArm.AES256KBR
have h : p = AESArm.AES128KBR
∨ p = AESArm.AES192KBR
∨ p = AESArm.AES256KBR := by
cases h1
· rename_i h; simp only [p, h]; simp
· rename_i h; cases h
· rename_i h; simp only [p, h]; simp
· rename_i h; simp only [p, h]; simp
let p : AESArm.KBR := KBR_from_AESKey key h1
have h : p.block_size = 128 := by apply KBR_from_AESKey_blocksize
let res := AESHWCtr32EncryptBlocks_helper (Param := p)
(BitVec.flatten in_blocks) 0 len key ivec
(BitVec.zero (8 * in_blocks.length)) (by omega) (by omega) h
Expand Down
2 changes: 1 addition & 1 deletion Specs/GCMV8.lean
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ example : GCMGmultV8 0x1099f4b39468565ccdd297a9df145877#128
0x9e#8, 0x15#8, 0xa6#8, 0x00#8, 0x67#8, 0x29#8, 0x7e#8, 0x0f#8 ] := by rfl


private def gcm_ghash_block (H : BitVec 128) (Xi : BitVec 128)
def gcm_ghash_block (H : BitVec 128) (Xi : BitVec 128)
(inp : BitVec 128) : BitVec 128 :=
let H := (lo H) ++ (hi H)
GCMV8.gcm_polyval H (Xi ^^^ inp)
Expand Down

0 comments on commit b08cfac

Please sign in to comment.