Skip to content

Commit

Permalink
Use structural recursion and rfl
Browse files Browse the repository at this point in the history
  • Loading branch information
pennyannn committed Sep 16, 2024
1 parent 8f5ea59 commit 82d9460
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 72 deletions.
40 changes: 19 additions & 21 deletions Arm/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ abbrev partInstall (hi lo : Nat) (val : BitVec (hi - lo + 1)) (x : BitVec n): Bi
let x_with_hole := x &&& mask_with_hole
x_with_hole ||| val_aligned

example : (partInstall 3 0 0xC#4 0xAB0D#16 = 0xAB0C#16) := by native_decide
example : (partInstall 3 0 0xC#4 0xAB0D#16 = 0xAB0C#16) := by rfl

def flattenTR {n : Nat} (xs : List (BitVec n)) (i : Nat)
(acc : BitVec len) (H : n > 0) : BitVec len :=
Expand All @@ -308,37 +308,35 @@ def flattenTR {n : Nat} (xs : List (BitVec n)) (i : Nat)
/-- Reverse bits of a bit-vector. -/
def reverse (x : BitVec n) : BitVec n :=
let rec reverseTR (x : BitVec n) (i : Nat) (acc : BitVec n) :=
if i < n then
let xi := extractLsb i i x
have h : i - i + 1 = (n - i - 1) - (n - i - 1) + 1 := by omega
let acc := BitVec.partInstall (n - i - 1) (n - i - 1) (BitVec.cast h xi) acc
reverseTR x (i + 1) acc
else acc
reverseTR x 0 $ BitVec.zero n

example : reverse 0b11101#5 = 0b10111#5 := by
-- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here.
simp [reverse, reverse.reverseTR]
rfl
match i with
| 0 => acc
| j + 1 =>
have h1 : i - 1 - (i - 1) + 1 = 1 := by omega
let xi : BitVec 1 := BitVec.cast h1 $ extractLsb (i - 1) (i - 1) x
have h2 : 1 = (n - i) - (n - i) + 1 := by omega
let acc := BitVec.partInstall (n - i) (n - i) (BitVec.cast h2 xi) acc
reverseTR x j acc
reverseTR x n $ BitVec.zero n

example : reverse 0b11101#5 = 0b10111#5 := by rfl

/-- Split a bit-vector into sub vectors of size e. -/
def split (x : BitVec n) (e : Nat) (h : 0 < e): List (BitVec e) :=
let rec splitTR (x : BitVec n) (e : Nat) (h : 0 < e)
(i : Nat) (acc : List (BitVec e)) : List (BitVec e) :=
if i < n/e then
let lo := i * e
match i with
| 0 => acc
| j + 1 =>
let lo := (n / e - i) * e
let hi := lo + e - 1
have h₀ : hi - lo + 1 = e := by simp only [hi, lo]; omega
let part : BitVec e := BitVec.cast h₀ (extractLsb hi lo x)
let newacc := part :: acc
splitTR x e h (i + 1) newacc
else acc
splitTR x e h 0 []
splitTR x e h j newacc
splitTR x e h (n / e) []

example : split 0xabcd1234#32 8 (by omega) = [0xab#8, 0xcd#8, 0x12#8, 0x34#8] :=
by
-- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here.
simp [split, split.splitTR]
by rfl

/-- Reverse a list of bit vectors and flatten the list. -/
def revflat (x : List (BitVec n)) : BitVec (n * x.length) :=
Expand Down
102 changes: 51 additions & 51 deletions Specs/GCMV8.lean
Original file line number Diff line number Diff line change
Expand Up @@ -37,30 +37,29 @@ def lo (x : BitVec 128) : BitVec 64 :=
def pmult (x: BitVec (m + 1)) (y : BitVec (n + 1)) : BitVec (m + n + 1) :=
let rec pmultTR (x: BitVec (m + 1)) (y : BitVec (n + 1)) (i : Nat)
(acc : BitVec (m + n + 1)) : BitVec (m + n + 1) :=
if i < n + 1 then
match i with
| 0 => acc
| j + 1 =>
let acc := acc <<< 1
have h : m + n + 1 = n + (m + 1) := by omega
let tmp := if getMsbD y i
let tmp := if getMsbD y (n + 1 - i)
then (BitVec.zero n) ++ x
else BitVec.cast h (BitVec.zero (m + n + 1))
let acc := (BitVec.cast h acc) ^^^ tmp
pmultTR x y (i + 1) (BitVec.cast h.symm acc)
else acc
pmultTR x y 0 (BitVec.zero (m + n + 1))
pmultTR x y j (BitVec.cast h.symm acc)
pmultTR x y (n + 1) (BitVec.zero (m + n + 1))

example: pmult 0b1101#4 0b10#2 = 0b11010#5 := by
-- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here.
native_decide
example: pmult 0b1101#4 0b10#2 = 0b11010#5 := by rfl

/-- Degree of x. -/
private def degree (x : BitVec n) : Nat :=
let rec degreeTR (x : BitVec n) (n : Nat) : Nat :=
if n = 0 then 0
else if getLsbD x n then n else degreeTR x (n - 1)
match n with
| 0 => 0
| m + 1 =>
if getLsbD x n then n else degreeTR x m
degreeTR x (n - 1)
example: GCMV8.degree 0b0101#4 = 2 := by
-- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here.
native_decide
example: GCMV8.degree 0b0101#4 = 2 := by rfl

/-- Subtract x from y if y's x-degree-th bit is 1. -/
private def reduce (x : BitVec n) (y : BitVec n) : BitVec n :=
Expand All @@ -70,50 +69,50 @@ private def reduce (x : BitVec n) (y : BitVec n) : BitVec n :=
def pdiv (x: BitVec n) (y : BitVec m) (h : 0 < m): BitVec n :=
let rec pdivTR (x : BitVec n) (y : BitVec m) (i : Nat) (z : BitVec m)
(acc : BitVec n) : BitVec n :=
if i < n then
have h2 : (n - i - 1) - (n - i - 1) + 1 = 1 := by omega
let xi : BitVec 1 := BitVec.cast h2 (extractLsb (n - i - 1) (n - i - 1) x)
match i with
| 0 => acc
| j + 1 =>
have h2 : (i - 1) - (i - 1) + 1 = 1 := by omega
let xi : BitVec 1 := BitVec.cast h2 (extractLsb (i - 1) (i - 1) x)
have h3 : m - 1 - 0 + 1 = m := by omega
let zi : BitVec m :=
BitVec.cast h3 (extractLsb (m - 1) 0 ((GCMV8.reduce y z) ++ xi))
have h1 : GCMV8.degree y - GCMV8.degree y + 1 = 1 := by omega
let bit : BitVec 1 :=
BitVec.cast h1 $ extractLsb (GCMV8.degree y) (GCMV8.degree y) zi
have h4 : 1 = (n - i - 1) - (n - i - 1) + 1 := by omega
have h4 : 1 = (i - 1) - (i - 1) + 1 := by omega
let newacc : BitVec n :=
partInstall (n - i - 1) (n - i - 1) (BitVec.cast h4 bit) acc
pdivTR x y (i + 1) zi newacc
else acc
pdivTR x y 0 (BitVec.zero m) (BitVec.zero n)
partInstall (i - 1) (i - 1) (BitVec.cast h4 bit) acc
pdivTR x y j zi newacc
pdivTR x y n (BitVec.zero m) (BitVec.zero n)

-- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here.
example : pdiv 0b1101#4 0b10#2 (by omega) = 0b110#4 := by native_decide
example : pdiv 0x1a#5 0b10#2 (by omega) = 0b1101#5 := by native_decide
example : pdiv 0b1#1 0b10#2 (by omega) = 0b0#1 := by native_decide
example : pdiv 0b1101#4 0b10#2 (by omega) = 0b110#4 := by rfl
example : pdiv 0x1a#5 0b10#2 (by omega) = 0b1101#5 := by rfl
example : pdiv 0b1#1 0b10#2 (by omega) = 0b0#1 := by rfl

/-- Performs modulus of polynomials over GF(2). -/
def pmod (x : BitVec n) (y : BitVec (m + 1)) (H : 0 < m) : BitVec m :=
let rec pmodTR (x : BitVec n) (y : BitVec (m + 1)) (p : BitVec (m + 1))
(i : Nat) (r : BitVec m) (H : 0 < m) : BitVec m :=
if i < n then
let xi := getLsbD x i
match i with
| 0 => r
| j + 1 =>
let xi := getLsbD x (n - i)
have h : m - 1 + 1 = m := by omega
let tmp : BitVec (m - 1 + 1) :=
if xi
then extractLsb (m - 1) 0 p
else BitVec.cast h.symm (BitVec.zero m)
let r := (BitVec.cast h.symm r) ^^^ tmp
pmodTR x y (GCMV8.reduce y (p <<< 1)) (i + 1) (BitVec.cast h r) H
else r
if y = 0 then 0 else pmodTR x y (GCMV8.reduce y 1) 0 (BitVec.zero m) H
pmodTR x y (GCMV8.reduce y (p <<< 1)) j (BitVec.cast h r) H
if y = 0 then 0 else pmodTR x y (GCMV8.reduce y 1) n (BitVec.zero m) H

-- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here.
example: pmod 0b011#3 0b00#2 (by omega) = 0b0#1 := by native_decide
example: pmod 0b011#3 0b01#2 (by omega) = 0b0#1 := by native_decide
example: pmod 0b011#3 0b10#2 (by omega) = 0b1#1 := by native_decide
example: pmod 0b011#3 0b11#2 (by omega) = 0b0#1 := by native_decide
example: pmod 0b011#3 0b100#3 (by omega) = 0b11#2 := by native_decide
example: pmod 0b011#3 0b1001#4 (by omega) = 0b11#3 := by native_decide
example: pmod 0b011#3 0b00#2 (by omega) = 0b0#1 := by rfl
example: pmod 0b011#3 0b01#2 (by omega) = 0b0#1 := by rfl
example: pmod 0b011#3 0b10#2 (by omega) = 0b1#1 := by rfl
example: pmod 0b011#3 0b11#2 (by omega) = 0b0#1 := by rfl
example: pmod 0b011#3 0b100#3 (by omega) = 0b11#2 := by rfl
example: pmod 0b011#3 0b1001#4 (by omega) = 0b11#3 := by rfl

------------------------------------------------------------------------------
-- Functions related to GCM
Expand Down Expand Up @@ -188,10 +187,8 @@ def GCMInitV8 (H : BitVec 128) : (List (BitVec 128)) :=
[H0_rev, H1, H2_rev, H3_rev, H4, H5_rev, H6_rev,
H7, H8_rev, H9_rev, H10, H11_rev]

-- set_option profiler true in
-- set_option maxRecDepth 20000 in
-- set_option maxHeartbeats 2000000 in
-- unseal pmod.pmodTR degree.degreeTR reverse.reverseTR pmult.pmultTR Nat.bitwise in
set_option maxRecDepth 20000 in
set_option maxHeartbeats 200000 in
example : GCMInitV8 0x66e94bd4ef8a2c3b884cfa59ca342b2e#128 ==
[ 0x1099f4b39468565ccdd297a9df145877#128,
0x62d81a7fe5da3296dd4b631a4b7c0e2b#128,
Expand All @@ -204,7 +201,7 @@ example : GCMInitV8 0x66e94bd4ef8a2c3b884cfa59ca342b2e#128 ==
0x4af32418184aee1eec87cfb0e19d1c4e#128,
0xf109e6e0b31d1eee7d1998bcfc545474#128,
0x7498729da40cd2808c107e5c4f494a9a#128,
0xa47c653dfbeac924d0e417a05fe61ba4#128 ] := by native_decide
0xa47c653dfbeac924d0e417a05fe61ba4#128 ] := by rfl

/-- GCMGmultV8 specification:
H : [128] -- the first element in Htable, not the initial H input to GCMInitV8
Expand All @@ -216,12 +213,13 @@ def GCMGmultV8 (H : BitVec 128) (Xi : List (BitVec 8)) (h : 8 * Xi.length = 128)
let H := (lo H) ++ (hi H)
split (GCMV8.gcm_polyval H (BitVec.cast h (BitVec.flatten Xi))) 8 (by omega)

-- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here.
set_option maxRecDepth 20000 in
set_option maxHeartbeats 200000 in
example : GCMGmultV8 0x1099f4b39468565ccdd297a9df145877#128
[ 0x10#8, 0x54#8, 0x43#8, 0xb0#8, 0x2c#8, 0x4b#8, 0x1f#8, 0x24#8,
0x3b#8, 0xcd#8, 0xd4#8, 0x87#8, 0x16#8, 0x65#8, 0xb3#8, 0x2b#8 ] (by decide) =
[ 0xa2#8, 0xc9#8, 0x9c#8, 0x56#8, 0xeb#8, 0xa7#8, 0x91#8, 0xf6#8,
0x9e#8, 0x15#8, 0xa6#8, 0x00#8, 0x67#8, 0x29#8, 0x7e#8, 0x0f#8 ] := by native_decide
0x9e#8, 0x15#8, 0xa6#8, 0x00#8, 0x67#8, 0x29#8, 0x7e#8, 0x0f#8 ] := by rfl


private def gcm_ghash_block (H : BitVec 128) (Xi : BitVec 128)
Expand All @@ -240,26 +238,28 @@ def GCMGhashV8 (H : BitVec 128) (Xi : List (BitVec 8))
: List (BitVec 8) :=
let rec GCMGhashV8TR {m : Nat} (H : BitVec 128) (Xi : BitVec 128)
(inp : BitVec m) (i : Nat) (h1 : 128 ∣ m) : BitVec 128 :=
if i < m / 128 then
let lo := m - (i + 1) * 128
match i with
| 0 => Xi
| j + 1 =>
let lo := (i - 1) * 128
let hi := lo + 127
have h2 : hi - lo + 1 = 128 := by omega
let inpi : BitVec 128 := BitVec.cast h2 $ extractLsb hi lo inp
let Xj := GCMV8.gcm_ghash_block H Xi inpi
GCMGhashV8TR H Xj inp (i + 1) h1
else Xi
GCMGhashV8TR H Xj inp j h1
have h3 : 1288 * inp.length := by omega
have h4 : 8 * Xi.length = 128 := by omega
let flat_Xi := BitVec.cast h4 $ BitVec.flatten Xi
let flat_inp := BitVec.flatten inp
split (GCMGhashV8TR H flat_Xi flat_inp 0 h3) 8 (by omega)
split (GCMGhashV8TR H flat_Xi flat_inp (8 * inp.length / 128) h3) 8 (by omega)

-- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here.
set_option maxRecDepth 20000 in
set_option maxHeartbeats 200000 in
example : GCMGhashV8 0x1099f4b39468565ccdd297a9df145877#128
[ 0xa2#8, 0xc9#8, 0x9c#8, 0x56#8, 0xeb#8, 0xa7#8, 0x91#8, 0xf6#8,
0x9e#8, 0x15#8, 0xa6#8, 0x00#8, 0x67#8, 0x29#8, 0x7e#8, 0x0f#8 ]
(List.replicate 16 0x2a#8) (by simp) (by simp only [List.length_replicate]; omega) =
[ 0x20#8, 0x60#8, 0x2e#8, 0x75#8, 0x7a#8, 0x4e#8, 0xec#8, 0x90#8,
0xc0#8, 0x9d#8, 0x49#8, 0xfd#8, 0xdc#8, 0xf2#8, 0xc9#8, 0x35#8 ] := by native_decide
0xc0#8, 0x9d#8, 0x49#8, 0xfd#8, 0xdc#8, 0xf2#8, 0xc9#8, 0x35#8 ] := by rfl

end GCMV8

0 comments on commit 82d9460

Please sign in to comment.