From 82d94604399aca3554daea7eaf05e5d11d848558 Mon Sep 17 00:00:00 2001 From: Yan Peng Date: Mon, 16 Sep 2024 19:26:42 +0000 Subject: [PATCH] Use structural recursion and rfl --- Arm/BitVec.lean | 40 +++++++++---------- Specs/GCMV8.lean | 102 +++++++++++++++++++++++------------------------ 2 files changed, 70 insertions(+), 72 deletions(-) diff --git a/Arm/BitVec.lean b/Arm/BitVec.lean index a21a033b..6830f101 100644 --- a/Arm/BitVec.lean +++ b/Arm/BitVec.lean @@ -294,7 +294,7 @@ abbrev partInstall (hi lo : Nat) (val : BitVec (hi - lo + 1)) (x : BitVec n): Bi let x_with_hole := x &&& mask_with_hole x_with_hole ||| val_aligned -example : (partInstall 3 0 0xC#4 0xAB0D#16 = 0xAB0C#16) := by native_decide +example : (partInstall 3 0 0xC#4 0xAB0D#16 = 0xAB0C#16) := by rfl def flattenTR {n : Nat} (xs : List (BitVec n)) (i : Nat) (acc : BitVec len) (H : n > 0) : BitVec len := @@ -308,37 +308,35 @@ def flattenTR {n : Nat} (xs : List (BitVec n)) (i : Nat) /-- Reverse bits of a bit-vector. -/ def reverse (x : BitVec n) : BitVec n := let rec reverseTR (x : BitVec n) (i : Nat) (acc : BitVec n) := - if i < n then - let xi := extractLsb i i x - have h : i - i + 1 = (n - i - 1) - (n - i - 1) + 1 := by omega - let acc := BitVec.partInstall (n - i - 1) (n - i - 1) (BitVec.cast h xi) acc - reverseTR x (i + 1) acc - else acc - reverseTR x 0 $ BitVec.zero n - -example : reverse 0b11101#5 = 0b10111#5 := by - -- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here. - simp [reverse, reverse.reverseTR] - rfl + match i with + | 0 => acc + | j + 1 => + have h1 : i - 1 - (i - 1) + 1 = 1 := by omega + let xi : BitVec 1 := BitVec.cast h1 $ extractLsb (i - 1) (i - 1) x + have h2 : 1 = (n - i) - (n - i) + 1 := by omega + let acc := BitVec.partInstall (n - i) (n - i) (BitVec.cast h2 xi) acc + reverseTR x j acc + reverseTR x n $ BitVec.zero n + +example : reverse 0b11101#5 = 0b10111#5 := by rfl /-- Split a bit-vector into sub vectors of size e. -/ def split (x : BitVec n) (e : Nat) (h : 0 < e): List (BitVec e) := let rec splitTR (x : BitVec n) (e : Nat) (h : 0 < e) (i : Nat) (acc : List (BitVec e)) : List (BitVec e) := - if i < n/e then - let lo := i * e + match i with + | 0 => acc + | j + 1 => + let lo := (n / e - i) * e let hi := lo + e - 1 have h₀ : hi - lo + 1 = e := by simp only [hi, lo]; omega let part : BitVec e := BitVec.cast h₀ (extractLsb hi lo x) let newacc := part :: acc - splitTR x e h (i + 1) newacc - else acc - splitTR x e h 0 [] + splitTR x e h j newacc + splitTR x e h (n / e) [] example : split 0xabcd1234#32 8 (by omega) = [0xab#8, 0xcd#8, 0x12#8, 0x34#8] := - by - -- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here. - simp [split, split.splitTR] + by rfl /-- Reverse a list of bit vectors and flatten the list. -/ def revflat (x : List (BitVec n)) : BitVec (n * x.length) := diff --git a/Specs/GCMV8.lean b/Specs/GCMV8.lean index c2fda1e7..1d783d34 100644 --- a/Specs/GCMV8.lean +++ b/Specs/GCMV8.lean @@ -37,30 +37,29 @@ def lo (x : BitVec 128) : BitVec 64 := def pmult (x: BitVec (m + 1)) (y : BitVec (n + 1)) : BitVec (m + n + 1) := let rec pmultTR (x: BitVec (m + 1)) (y : BitVec (n + 1)) (i : Nat) (acc : BitVec (m + n + 1)) : BitVec (m + n + 1) := - if i < n + 1 then + match i with + | 0 => acc + | j + 1 => let acc := acc <<< 1 have h : m + n + 1 = n + (m + 1) := by omega - let tmp := if getMsbD y i + let tmp := if getMsbD y (n + 1 - i) then (BitVec.zero n) ++ x else BitVec.cast h (BitVec.zero (m + n + 1)) let acc := (BitVec.cast h acc) ^^^ tmp - pmultTR x y (i + 1) (BitVec.cast h.symm acc) - else acc - pmultTR x y 0 (BitVec.zero (m + n + 1)) + pmultTR x y j (BitVec.cast h.symm acc) + pmultTR x y (n + 1) (BitVec.zero (m + n + 1)) -example: pmult 0b1101#4 0b10#2 = 0b11010#5 := by - -- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here. - native_decide +example: pmult 0b1101#4 0b10#2 = 0b11010#5 := by rfl /-- Degree of x. -/ private def degree (x : BitVec n) : Nat := let rec degreeTR (x : BitVec n) (n : Nat) : Nat := - if n = 0 then 0 - else if getLsbD x n then n else degreeTR x (n - 1) + match n with + | 0 => 0 + | m + 1 => + if getLsbD x n then n else degreeTR x m degreeTR x (n - 1) -example: GCMV8.degree 0b0101#4 = 2 := by - -- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here. - native_decide +example: GCMV8.degree 0b0101#4 = 2 := by rfl /-- Subtract x from y if y's x-degree-th bit is 1. -/ private def reduce (x : BitVec n) (y : BitVec n) : BitVec n := @@ -70,50 +69,50 @@ private def reduce (x : BitVec n) (y : BitVec n) : BitVec n := def pdiv (x: BitVec n) (y : BitVec m) (h : 0 < m): BitVec n := let rec pdivTR (x : BitVec n) (y : BitVec m) (i : Nat) (z : BitVec m) (acc : BitVec n) : BitVec n := - if i < n then - have h2 : (n - i - 1) - (n - i - 1) + 1 = 1 := by omega - let xi : BitVec 1 := BitVec.cast h2 (extractLsb (n - i - 1) (n - i - 1) x) + match i with + | 0 => acc + | j + 1 => + have h2 : (i - 1) - (i - 1) + 1 = 1 := by omega + let xi : BitVec 1 := BitVec.cast h2 (extractLsb (i - 1) (i - 1) x) have h3 : m - 1 - 0 + 1 = m := by omega let zi : BitVec m := BitVec.cast h3 (extractLsb (m - 1) 0 ((GCMV8.reduce y z) ++ xi)) have h1 : GCMV8.degree y - GCMV8.degree y + 1 = 1 := by omega let bit : BitVec 1 := BitVec.cast h1 $ extractLsb (GCMV8.degree y) (GCMV8.degree y) zi - have h4 : 1 = (n - i - 1) - (n - i - 1) + 1 := by omega + have h4 : 1 = (i - 1) - (i - 1) + 1 := by omega let newacc : BitVec n := - partInstall (n - i - 1) (n - i - 1) (BitVec.cast h4 bit) acc - pdivTR x y (i + 1) zi newacc - else acc - pdivTR x y 0 (BitVec.zero m) (BitVec.zero n) + partInstall (i - 1) (i - 1) (BitVec.cast h4 bit) acc + pdivTR x y j zi newacc + pdivTR x y n (BitVec.zero m) (BitVec.zero n) --- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here. -example : pdiv 0b1101#4 0b10#2 (by omega) = 0b110#4 := by native_decide -example : pdiv 0x1a#5 0b10#2 (by omega) = 0b1101#5 := by native_decide -example : pdiv 0b1#1 0b10#2 (by omega) = 0b0#1 := by native_decide +example : pdiv 0b1101#4 0b10#2 (by omega) = 0b110#4 := by rfl +example : pdiv 0x1a#5 0b10#2 (by omega) = 0b1101#5 := by rfl +example : pdiv 0b1#1 0b10#2 (by omega) = 0b0#1 := by rfl /-- Performs modulus of polynomials over GF(2). -/ def pmod (x : BitVec n) (y : BitVec (m + 1)) (H : 0 < m) : BitVec m := let rec pmodTR (x : BitVec n) (y : BitVec (m + 1)) (p : BitVec (m + 1)) (i : Nat) (r : BitVec m) (H : 0 < m) : BitVec m := - if i < n then - let xi := getLsbD x i + match i with + | 0 => r + | j + 1 => + let xi := getLsbD x (n - i) have h : m - 1 + 1 = m := by omega let tmp : BitVec (m - 1 + 1) := if xi then extractLsb (m - 1) 0 p else BitVec.cast h.symm (BitVec.zero m) let r := (BitVec.cast h.symm r) ^^^ tmp - pmodTR x y (GCMV8.reduce y (p <<< 1)) (i + 1) (BitVec.cast h r) H - else r - if y = 0 then 0 else pmodTR x y (GCMV8.reduce y 1) 0 (BitVec.zero m) H + pmodTR x y (GCMV8.reduce y (p <<< 1)) j (BitVec.cast h r) H + if y = 0 then 0 else pmodTR x y (GCMV8.reduce y 1) n (BitVec.zero m) H --- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here. -example: pmod 0b011#3 0b00#2 (by omega) = 0b0#1 := by native_decide -example: pmod 0b011#3 0b01#2 (by omega) = 0b0#1 := by native_decide -example: pmod 0b011#3 0b10#2 (by omega) = 0b1#1 := by native_decide -example: pmod 0b011#3 0b11#2 (by omega) = 0b0#1 := by native_decide -example: pmod 0b011#3 0b100#3 (by omega) = 0b11#2 := by native_decide -example: pmod 0b011#3 0b1001#4 (by omega) = 0b11#3 := by native_decide +example: pmod 0b011#3 0b00#2 (by omega) = 0b0#1 := by rfl +example: pmod 0b011#3 0b01#2 (by omega) = 0b0#1 := by rfl +example: pmod 0b011#3 0b10#2 (by omega) = 0b1#1 := by rfl +example: pmod 0b011#3 0b11#2 (by omega) = 0b0#1 := by rfl +example: pmod 0b011#3 0b100#3 (by omega) = 0b11#2 := by rfl +example: pmod 0b011#3 0b1001#4 (by omega) = 0b11#3 := by rfl ------------------------------------------------------------------------------ -- Functions related to GCM @@ -188,10 +187,8 @@ def GCMInitV8 (H : BitVec 128) : (List (BitVec 128)) := [H0_rev, H1, H2_rev, H3_rev, H4, H5_rev, H6_rev, H7, H8_rev, H9_rev, H10, H11_rev] --- set_option profiler true in --- set_option maxRecDepth 20000 in --- set_option maxHeartbeats 2000000 in --- unseal pmod.pmodTR degree.degreeTR reverse.reverseTR pmult.pmultTR Nat.bitwise in +set_option maxRecDepth 20000 in +set_option maxHeartbeats 200000 in example : GCMInitV8 0x66e94bd4ef8a2c3b884cfa59ca342b2e#128 == [ 0x1099f4b39468565ccdd297a9df145877#128, 0x62d81a7fe5da3296dd4b631a4b7c0e2b#128, @@ -204,7 +201,7 @@ example : GCMInitV8 0x66e94bd4ef8a2c3b884cfa59ca342b2e#128 == 0x4af32418184aee1eec87cfb0e19d1c4e#128, 0xf109e6e0b31d1eee7d1998bcfc545474#128, 0x7498729da40cd2808c107e5c4f494a9a#128, - 0xa47c653dfbeac924d0e417a05fe61ba4#128 ] := by native_decide + 0xa47c653dfbeac924d0e417a05fe61ba4#128 ] := by rfl /-- GCMGmultV8 specification: H : [128] -- the first element in Htable, not the initial H input to GCMInitV8 @@ -216,12 +213,13 @@ def GCMGmultV8 (H : BitVec 128) (Xi : List (BitVec 8)) (h : 8 * Xi.length = 128) let H := (lo H) ++ (hi H) split (GCMV8.gcm_polyval H (BitVec.cast h (BitVec.flatten Xi))) 8 (by omega) --- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here. +set_option maxRecDepth 20000 in +set_option maxHeartbeats 200000 in example : GCMGmultV8 0x1099f4b39468565ccdd297a9df145877#128 [ 0x10#8, 0x54#8, 0x43#8, 0xb0#8, 0x2c#8, 0x4b#8, 0x1f#8, 0x24#8, 0x3b#8, 0xcd#8, 0xd4#8, 0x87#8, 0x16#8, 0x65#8, 0xb3#8, 0x2b#8 ] (by decide) = [ 0xa2#8, 0xc9#8, 0x9c#8, 0x56#8, 0xeb#8, 0xa7#8, 0x91#8, 0xf6#8, - 0x9e#8, 0x15#8, 0xa6#8, 0x00#8, 0x67#8, 0x29#8, 0x7e#8, 0x0f#8 ] := by native_decide + 0x9e#8, 0x15#8, 0xa6#8, 0x00#8, 0x67#8, 0x29#8, 0x7e#8, 0x0f#8 ] := by rfl private def gcm_ghash_block (H : BitVec 128) (Xi : BitVec 128) @@ -240,26 +238,28 @@ def GCMGhashV8 (H : BitVec 128) (Xi : List (BitVec 8)) : List (BitVec 8) := let rec GCMGhashV8TR {m : Nat} (H : BitVec 128) (Xi : BitVec 128) (inp : BitVec m) (i : Nat) (h1 : 128 ∣ m) : BitVec 128 := - if i < m / 128 then - let lo := m - (i + 1) * 128 + match i with + | 0 => Xi + | j + 1 => + let lo := (i - 1) * 128 let hi := lo + 127 have h2 : hi - lo + 1 = 128 := by omega let inpi : BitVec 128 := BitVec.cast h2 $ extractLsb hi lo inp let Xj := GCMV8.gcm_ghash_block H Xi inpi - GCMGhashV8TR H Xj inp (i + 1) h1 - else Xi + GCMGhashV8TR H Xj inp j h1 have h3 : 128 ∣ 8 * inp.length := by omega have h4 : 8 * Xi.length = 128 := by omega let flat_Xi := BitVec.cast h4 $ BitVec.flatten Xi let flat_inp := BitVec.flatten inp - split (GCMGhashV8TR H flat_Xi flat_inp 0 h3) 8 (by omega) + split (GCMGhashV8TR H flat_Xi flat_inp (8 * inp.length / 128) h3) 8 (by omega) --- (FIXME) With leanprover/lean4:nightly-2024-08-29, just `rfl` sufficed here. +set_option maxRecDepth 20000 in +set_option maxHeartbeats 200000 in example : GCMGhashV8 0x1099f4b39468565ccdd297a9df145877#128 [ 0xa2#8, 0xc9#8, 0x9c#8, 0x56#8, 0xeb#8, 0xa7#8, 0x91#8, 0xf6#8, 0x9e#8, 0x15#8, 0xa6#8, 0x00#8, 0x67#8, 0x29#8, 0x7e#8, 0x0f#8 ] (List.replicate 16 0x2a#8) (by simp) (by simp only [List.length_replicate]; omega) = [ 0x20#8, 0x60#8, 0x2e#8, 0x75#8, 0x7a#8, 0x4e#8, 0xec#8, 0x90#8, - 0xc0#8, 0x9d#8, 0x49#8, 0xfd#8, 0xdc#8, 0xf2#8, 0xc9#8, 0x35#8 ] := by native_decide + 0xc0#8, 0x9d#8, 0x49#8, 0xfd#8, 0xdc#8, 0xf2#8, 0xc9#8, 0x35#8 ] := by rfl end GCMV8