Skip to content

Commit

Permalink
Simplify partInstall to use start and len
Browse files Browse the repository at this point in the history
  • Loading branch information
pennyannn committed Oct 4, 2024
1 parent 55f27e3 commit 86e65aa
Show file tree
Hide file tree
Showing 14 changed files with 100 additions and 128 deletions.
17 changes: 8 additions & 9 deletions Arm/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -293,22 +293,21 @@ abbrev ror (x : BitVec n) (r : Nat) : BitVec n :=
abbrev lsb (x : BitVec n) (i : Nat) : BitVec 1 :=
BitVec.extractLsb' i 1 x

abbrev partInstall (hi lo : Nat) (val : BitVec (hi - lo + 1)) (x : BitVec n): BitVec n :=
let mask := allOnes (hi - lo + 1)
let val_aligned := (zeroExtend n val) <<< lo
let mask_with_hole := ~~~ ((zeroExtend n mask) <<< lo)
abbrev partInstall (start len : Nat) (val : BitVec len) (x : BitVec n): BitVec n :=
let mask := allOnes len
let val_aligned := (zeroExtend n val) <<< start
let mask_with_hole := ~~~ ((zeroExtend n mask) <<< start)
let x_with_hole := x &&& mask_with_hole
x_with_hole ||| val_aligned

example : (partInstall 3 0 0xC#4 0xAB0D#16 = 0xAB0C#16) := rfl
example : (partInstall 0 4 0xC#4 0xAB0D#16 = 0xAB0C#16) := rfl

def flattenTR {n : Nat} (xs : List (BitVec n)) (i : Nat)
(acc : BitVec len) (H : n > 0) : BitVec len :=
match xs with
| [] => acc
| x :: rest =>
have h : n = (i * n + n - 1 - i * n + 1) := by omega
let new_acc := (BitVec.partInstall (i * n + n - 1) (i * n) (BitVec.cast h x) acc)
let new_acc := (BitVec.partInstall (i * n) n x acc)
flattenTR rest (i + 1) new_acc H

/-- Reverse bits of a bit-vector. -/
Expand All @@ -317,8 +316,8 @@ def reverse (x : BitVec n) : BitVec n :=
match i with
| 0 => acc
| j + 1 =>
let xi : BitVec 1 := extractLsb' (i - 1) 1 x
let acc := BitVec.partInstall (n - i) (n - i) (xi.cast (by omega)) acc
let xi := extractLsb' (i - 1) 1 x
let acc := BitVec.partInstall (n - i) 1 xi acc
reverseTR x j acc
reverseTR x n $ BitVec.zero n

Expand Down
10 changes: 4 additions & 6 deletions Arm/Insts/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -633,12 +633,10 @@ def elem_get (vector : BitVec n) (e : Nat) (size : Nat) : BitVec size :=
the `e`'th element in the `vector`. -/
@[state_simp_rules]
def elem_set (vector : BitVec n) (e : Nat) (size : Nat)
(value : BitVec size) (h: size > 0): BitVec n :=
(value : BitVec size) : BitVec n :=
-- assert (e+1)*size <= n
let lo := e * size
let hi := lo + size - 1
have h : size = hi - lo + 1 := by simp only [hi, lo]; omega
BitVec.partInstall hi lo (BitVec.cast h value) vector
BitVec.partInstall lo size value vector

----------------------------------------------------------------------

Expand Down Expand Up @@ -681,7 +679,7 @@ def shift_right_common_aux
let elem := Int_with_unsigned info.unsigned $ elem_get operand e info.esize
let shift_elem := RShr info.unsigned elem info.shift info.round
let acc_elem := elem_get operand2 e info.esize + shift_elem
let result := elem_set result e info.esize acc_elem info.h
let result := elem_set result e info.esize acc_elem
have _ : info.elements - (e + 1) < info.elements - e := by omega
shift_right_common_aux (e + 1) info operand operand2 result
termination_by (info.elements - e)
Expand All @@ -703,7 +701,7 @@ def shift_left_common_aux
else
let elem := elem_get operand e info.esize
let shift_elem := elem <<< info.shift
let result := elem_set result e info.esize shift_elem info.h
let result := elem_set result e info.esize shift_elem
have _ : info.elements - (e + 1) < info.elements - e := by omega
shift_left_common_aux (e + 1) info operand result
termination_by (info.elements - e)
Expand Down
3 changes: 1 addition & 2 deletions Arm/Insts/DPI/Move_wide_imm.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def exec_move_wide_imm (inst : Move_wide_imm_cls) (s : ArmState) : ArmState :=
let result := if inst.opc = 0b11#2
then read_gpr datasize inst.Rd s
else BitVec.zero datasize
have h : 16 = pos + 15 - pos + 1 := by omega
let result := partInstall (pos + 15) pos (BitVec.cast h inst.imm16) result
let result := partInstall pos 16 inst.imm16 result
let result := if inst.opc = 0b00#2 then ~~~result else result
-- State Update
let s := write_gpr datasize inst.Rd result s
Expand Down
2 changes: 1 addition & 1 deletion Arm/Insts/DPI/PC_rel_addressing.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def exec_pc_rel_addressing (inst : PC_rel_addressing_cls) (s : ArmState) : ArmSt
let result := if inst.op = 0#1 then
orig_pc + imm -- ADR
else
(BitVec.partInstall 11 0 0#12 orig_pc) + imm
(BitVec.partInstall 0 12 0#12 orig_pc) + imm
-- State Updates
let s := write_gpr_zr 64 inst.Rd result s
let s := write_pc (orig_pc + 4#64) s
Expand Down
21 changes: 8 additions & 13 deletions Arm/Insts/DPSFP/Advanced_simd_copy.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@ namespace DPSFP
open BitVec

def dup_aux (e : Nat) (elements : Nat) (esize : Nat)
(element : BitVec esize) (result : BitVec datasize) (H : 0 < esize) : BitVec datasize :=
if h₀ : elements <= e then
(element : BitVec esize) (result : BitVec datasize) : BitVec datasize :=
if elements <= e then
result
else
let result := elem_set result e esize element H
have h : elements - (e + 1) < elements - e := by omega
dup_aux (e + 1) elements esize element result H
let result := elem_set result e esize element
dup_aux (e + 1) elements esize element result
termination_by (elements - e)

@[state_simp_rules]
Expand All @@ -38,9 +37,8 @@ def exec_dup_element (inst : Advanced_simd_copy_cls) (s : ArmState) : ArmState :
let datasize := 64 <<< inst.Q.toNat
let elements := datasize / esize
let operand := read_sfp idxdsize inst.Rn s
have h₀ : esize > 0 := by apply zero_lt_shift_left_pos (by decide)
let element := elem_get operand index esize
let result := dup_aux 0 elements esize element (BitVec.zero datasize) h₀
let result := dup_aux 0 elements esize element (BitVec.zero datasize)
-- State Updates
let s := write_pc ((read_pc s) + 4#64) s
let s := write_sfp datasize inst.Rd result s
Expand All @@ -56,8 +54,7 @@ def exec_dup_general (inst : Advanced_simd_copy_cls) (s : ArmState) : ArmState :
let datasize := 64 <<< inst.Q.toNat
let elements := datasize / esize
let element := read_gpr esize inst.Rn s
have h₀ : 0 < esize := by apply zero_lt_shift_left_pos (by decide)
let result := dup_aux 0 elements esize element (BitVec.zero datasize) h₀
let result := dup_aux 0 elements esize element (BitVec.zero datasize)
-- State Updates
let s := write_pc ((read_pc s) + 4#64) s
let s := write_sfp datasize inst.Rd result s
Expand All @@ -75,9 +72,8 @@ def exec_ins_element (inst : Advanced_simd_copy_cls) (s : ArmState) : ArmState :
let esize := 8 <<< size
let operand := read_sfp idxdsize inst.Rn s
let result := read_sfp 128 inst.Rd s
have h₀ : esize > 0 := by apply zero_lt_shift_left_pos (by decide)
let elem := elem_get operand src_index esize
let result := elem_set result dst_index esize elem h₀
let result := elem_set result dst_index esize elem
-- State Updates
let s := write_pc ((read_pc s) + 4#64) s
let s := write_sfp 128 inst.Rd result s
Expand All @@ -93,8 +89,7 @@ def exec_ins_general (inst : Advanced_simd_copy_cls) (s : ArmState) : ArmState :
let esize := 8 <<< size
let element := read_gpr esize inst.Rn s
let result := read_sfp 128 inst.Rd s
have h₀ : esize > 0 := by apply zero_lt_shift_left_pos (by decide)
let result := elem_set result index esize element h₀
let result := elem_set result index esize element
-- State Updates
let s := write_pc ((read_pc s) + 4#64) s
let s := write_sfp 128 inst.Rd result s
Expand Down
14 changes: 6 additions & 8 deletions Arm/Insts/DPSFP/Advanced_simd_permute.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@ open BitVec

def trn_aux (p : Nat) (pairs : Nat) (esize : Nat) (part : Nat)
(operand1 : BitVec datasize) (operand2 : BitVec datasize)
(result : BitVec datasize) (h : esize > 0) : BitVec datasize :=
if h₀ : pairs <= p then
(result : BitVec datasize) : BitVec datasize :=
if pairs <= p then
result
else
let idx_from := 2 * p + part
let op1_part := elem_get operand1 idx_from esize
let op2_part := elem_get operand2 idx_from esize
let result := elem_set result (2 * p) esize op1_part h
let result := elem_set result (2 * p + 1) esize op2_part h
have h₁ : pairs - (p + 1) < pairs - p := by omega
trn_aux (p + 1) pairs esize part operand1 operand2 result h
let result := elem_set result (2 * p) esize op1_part
let result := elem_set result (2 * p + 1) esize op2_part
trn_aux (p + 1) pairs esize part operand1 operand2 result
termination_by (pairs - p)

@[state_simp_rules]
Expand All @@ -43,8 +42,7 @@ def exec_trn (inst : Advanced_simd_permute_cls) (s : ArmState) : ArmState :=
let pairs := elements / 2
let operand1 := read_sfp datasize inst.Rn s
let operand2 := read_sfp datasize inst.Rm s
have h : esize > 0 := by apply zero_lt_shift_left_pos (by decide)
let result := trn_aux 0 pairs esize part operand1 operand2 (BitVec.zero datasize) h
let result := trn_aux 0 pairs esize part operand1 operand2 (BitVec.zero datasize)
-- Update States
let s := write_sfp datasize inst.Rd result s
let s := write_pc ((read_pc s) + 4#64) s
Expand Down
9 changes: 3 additions & 6 deletions Arm/Insts/DPSFP/Advanced_simd_table_lookup.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ def create_table (i : Nat) (regs : Nat) (Rn : BitVec 5) (table : BitVec (128 * r
table
else
let val := read_sfp 128 Rn s
have h₁ : 128 = 128 * i + 127 - 128 * i + 1 := by omega
let table := BitVec.partInstall (128 * i + 127) (128 * i) (BitVec.cast h₁ val) table
let table := BitVec.partInstall (128 * i) 128 val table
let Rn := (Rn + 1) % 32
have h₂ : regs - (i + 1) < regs - i := by omega
create_table (i + 1) regs Rn table s
Expand All @@ -31,18 +30,16 @@ def create_table (i : Nat) (regs : Nat) (Rn : BitVec 5) (table : BitVec (128 * r
def tblx_aux (i : Nat) (elements : Nat) (indices : BitVec datasize)
(regs : Nat) (table : BitVec (128 * regs)) (result: BitVec datasize)
: BitVec datasize :=
if h₀ : elements <= i then
if elements <= i then
result
else
have h₁ : 8 > 0 := by decide
let index := (elem_get indices i 8).toNat
let result :=
if index < 16 * regs then
let val := elem_get table index 8
elem_set result i 8 val h₁
elem_set result i 8 val
else
result
have h₂ : elements - (i + 1) < elements - i := by omega
tblx_aux (i + 1) elements indices regs table result
termination_by (elements - i)

Expand Down
13 changes: 5 additions & 8 deletions Arm/Insts/DPSFP/Advanced_simd_three_different.lean
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,16 @@ def polynomial_mult (op1 : BitVec m) (op2 : BitVec n) : BitVec (m+n) :=
polynomial_mult_aux 0 result op1 extended_op2

def pmull_op (e : Nat) (esize : Nat) (elements : Nat) (x : BitVec n)
(y : BitVec n) (result : BitVec (n*2)) (H : 0 < esize) : BitVec (n*2) :=
if h₀ : elements <= e then
(y : BitVec n) (result : BitVec (n*2)) : BitVec (n*2) :=
if elements <= e then
result
else
let element1 := elem_get x e esize
let element2 := elem_get y e esize
let elem_result := polynomial_mult element1 element2
have h₁ : esize + esize = 2 * esize := by omega
have h₂ : 2 * esize > 0 := by omega
let result := elem_set result e (2 * esize) (BitVec.cast h₁ elem_result) h₂
have _ : elements - (e + 1) < elements - e := by omega
pmull_op (e + 1) esize elements x y result H
let result := elem_set result e (2 * esize) (BitVec.cast h₁ elem_result)
pmull_op (e + 1) esize elements x y result
termination_by (elements - e)

@[state_simp_rules]
Expand All @@ -54,14 +52,13 @@ def exec_pmull (inst : Advanced_simd_three_different_cls) (s : ArmState) : ArmSt
write_err (StateError.Illegal s!"Illegal {inst} encountered!") s
else
let esize := 8 <<< inst.size.toNat
have h₀ : 0 < esize := by apply zero_lt_shift_left_pos (by decide)
let datasize := 64
let part := inst.Q.toNat
let elements := datasize / esize
let operand1 := Vpart_read inst.Rn part datasize s
let operand2 := Vpart_read inst.Rm part datasize s
let result :=
pmull_op 0 esize elements operand1 operand2 (BitVec.zero (2*datasize)) h₀
pmull_op 0 esize elements operand1 operand2 (BitVec.zero (2*datasize))
let s := write_sfp (datasize*2) inst.Rd result s
let s := write_pc ((read_pc s) + 4#64) s
s
Expand Down
18 changes: 7 additions & 11 deletions Arm/Insts/DPSFP/Advanced_simd_three_same.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,24 @@ open BitVec

def binary_vector_op_aux (e : Nat) (elems : Nat) (esize : Nat)
(op : BitVec esize → BitVec esize → BitVec esize)
(x : BitVec n) (y : BitVec n) (result : BitVec n)
(H : esize > 0) : BitVec n :=
if h₀ : elems ≤ e then
(x : BitVec n) (y : BitVec n) (result : BitVec n) : BitVec n :=
if elems ≤ e then
result
else
have h₁ : e < elems := by omega
let element1 := elem_get x e esize
let element2 := elem_get y e esize
let elem_result := op element1 element2
let result := elem_set result e esize elem_result H
have ht1 : elems - (e + 1) < elems - e := by omega
binary_vector_op_aux (e + 1) elems esize op x y result H
let result := elem_set result e esize elem_result
binary_vector_op_aux (e + 1) elems esize op x y result
termination_by (elems - e)

/--
Perform pairwise op on esize-bit slices of x and y
-/
@[state_simp_rules]
def binary_vector_op (esize : Nat) (op : BitVec esize → BitVec esize → BitVec esize)
(x : BitVec n) (y : BitVec n) (H : 0 < esize) : BitVec n :=
binary_vector_op_aux 0 (n / esize) esize op x y (BitVec.zero n) H
(x : BitVec n) (y : BitVec n) : BitVec n :=
binary_vector_op_aux 0 (n / esize) esize op x y (BitVec.zero n)

@[state_simp_rules]
def exec_binary_vector (inst : Advanced_simd_three_same_cls) (s : ArmState) : ArmState :=
Expand All @@ -48,12 +45,11 @@ def exec_binary_vector (inst : Advanced_simd_three_same_cls) (s : ArmState) : Ar
else
let datasize := if inst.Q = 1#1 then 128 else 64
let esize := 8 <<< (BitVec.toNat inst.size)
have h_esize : 0 < esize := by simp [esize]; apply zero_lt_shift_left_pos (by decide)
let sub_op := inst.U = 1
let operand1 := read_sfp datasize inst.Rn s
let operand2 := read_sfp datasize inst.Rm s
let op := if sub_op then BitVec.sub else BitVec.add
let result := binary_vector_op esize op operand1 operand2 h_esize
let result := binary_vector_op esize op operand1 operand2
let s := write_sfp datasize inst.Rd result s
s

Expand Down
21 changes: 9 additions & 12 deletions Specs/AESCommon.lean
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ def SubBytes_aux (i : Nat) (op : BitVec 128) (out : BitVec 128)
let i := 16 - i
let idx := (extractLsb' (i * 8) 8 op).toNat
let val := extractLsb' (idx * 8) 8 $ BitVec.flatten SBOX
have h₁ : 8 = i * 8 + 7 - i * 8 + 1 := by omega
let out := BitVec.partInstall (i * 8 + 7) (i * 8) (BitVec.cast h₁ val) out
let out := BitVec.partInstall (i * 8) 8 val out
SubBytes_aux i' op out

def SubBytes (op : BitVec 128) : BitVec 128 :=
Expand All @@ -66,20 +65,18 @@ def MixColumns_aux (c : Nat)
| 0 => (out0, out1, out2, out3)
| c' + 1 =>
let lo := (4 - c) * 8
let hi := lo + 7
let in0_byte := extractLsb' lo 8 in0
let in1_byte := extractLsb' lo 8 in1
let in2_byte := extractLsb' lo 8 in2
let in3_byte := extractLsb' lo 8 in3
have h : 8 = hi - lo + 1 := by omega
let val0 := BitVec.cast h $ FFmul02 in0_byte ^^^ FFmul03 in1_byte ^^^ in2_byte ^^^ in3_byte
let out0 := BitVec.partInstall hi lo val0 out0
let val1 := BitVec.cast h $ FFmul02 in1_byte ^^^ FFmul03 in2_byte ^^^ in3_byte ^^^ in0_byte
let out1 := BitVec.partInstall hi lo val1 out1
let val2 := BitVec.cast h $ FFmul02 in2_byte ^^^ FFmul03 in3_byte ^^^ in0_byte ^^^ in1_byte
let out2 := BitVec.partInstall hi lo val2 out2
let val3 := BitVec.cast h $ FFmul02 in3_byte ^^^ FFmul03 in0_byte ^^^ in1_byte ^^^ in2_byte
let out3 := BitVec.partInstall hi lo val3 out3
let val0 := FFmul02 in0_byte ^^^ FFmul03 in1_byte ^^^ in2_byte ^^^ in3_byte
let out0 := BitVec.partInstall lo 8 val0 out0
let val1 := FFmul02 in1_byte ^^^ FFmul03 in2_byte ^^^ in3_byte ^^^ in0_byte
let out1 := BitVec.partInstall lo 8 val1 out1
let val2 := FFmul02 in2_byte ^^^ FFmul03 in3_byte ^^^ in0_byte ^^^ in1_byte
let out2 := BitVec.partInstall lo 8 val2 out2
let val3 := FFmul02 in3_byte ^^^ FFmul03 in0_byte ^^^ in1_byte ^^^ in2_byte
let out3 := BitVec.partInstall lo 8 val3 out3
MixColumns_aux c' in0 in1 in2 in3 out0 out1 out2 out3 FFmul02 FFmul03

def MixColumns (op : BitVec 128) (FFmul02 : BitVec 8 -> BitVec 8)
Expand Down
3 changes: 1 addition & 2 deletions Specs/AESV8.lean
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ def AESHWCtr32EncryptBlocks_helper {Param : AESArm.KBR} (in_blocks : BitVec m)
else
let lo := m - (i + 1) * 128
let hi := lo + 127
have h5 : hi - lo + 1 = 128 := by omega
let curr_block : BitVec 128 := BitVec.extractLsb' lo 128 in_blocks
have h4 : 128 = Param.block_size := by
cases h3
Expand All @@ -126,7 +125,7 @@ def AESHWCtr32EncryptBlocks_helper {Param : AESArm.KBR} (in_blocks : BitVec m)
(Param := Param) (BitVec.cast h4 ivec_rev) key.rd_key
let res_block := rev_elems 128 8 res_block (by decide) (by decide)
let res_block := res_block ^^^ curr_block
let new_acc := BitVec.partInstall hi lo (BitVec.cast h5.symm res_block) acc
let new_acc := BitVec.partInstall lo 128 res_block acc
AESHWCtr32EncryptBlocks_helper (Param := Param)
in_blocks (i + 1) len key (ivec + 1#128) new_acc h1 h2 h3
termination_by (len - i)
Expand Down
4 changes: 1 addition & 3 deletions Specs/GCM.lean
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,9 @@ def GCTR_aux (CIPH : Cipher (n := 128) (m := m))
Y
else
let lo := (n - i - 1) * 128
let hi := lo + 127
have h : 128 = hi - lo + 1 := by omega
let Xi := extractLsb' lo 128 X
let Yi := Xi ^^^ CIPH ICB K
let Y := BitVec.partInstall hi lo (BitVec.cast h Yi) Y
let Y := BitVec.partInstall lo 128 Yi Y
let ICB := inc_s 32 ICB (by omega)
GCTR_aux CIPH (i + 1) n K ICB X Y
termination_by (n - i)
Expand Down
2 changes: 1 addition & 1 deletion Specs/GCMV8.lean
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def pdiv (x: BitVec n) (y : BitVec m): BitVec n :=
let zi := extractLsb' 0 m ((GCMV8.reduce y z) ++ xi)
let bit := extractLsb' (GCMV8.degree y) 1 zi
let newacc : BitVec n :=
partInstall (i - 1) (i - 1) (bit.cast (by omega)) acc
partInstall (i - 1) 1 bit acc
pdivTR x y j zi newacc
pdivTR x y n (BitVec.zero m) (BitVec.zero n)

Expand Down
Loading

0 comments on commit 86e65aa

Please sign in to comment.