From 2ef760f102a3037d135729dafbe5b99dc3ffe03a Mon Sep 17 00:00:00 2001 From: Yan Peng Date: Fri, 27 Sep 2024 22:01:42 +0000 Subject: [PATCH] Removing extractLsb from bitvec library and memory aliasing proofs --- Arm/BitVec.lean | 86 +++++++++++++++++------------------- Arm/Memory/MemoryProofs.lean | 35 +++++++-------- Arm/State.lean | 4 +- 3 files changed, 58 insertions(+), 67 deletions(-) diff --git a/Arm/BitVec.lean b/Arm/BitVec.lean index c36588d4..9684d057 100644 --- a/Arm/BitVec.lean +++ b/Arm/BitVec.lean @@ -59,10 +59,7 @@ attribute [bitvec_rules] BitVec.getLsbD_truncate attribute [bitvec_rules] BitVec.zeroExtend_zeroExtend_of_le attribute [bitvec_rules] BitVec.truncate_truncate_of_le attribute [bitvec_rules] BitVec.truncate_cast -attribute [bitvec_rules] BitVec.extractLsb_ofFin -attribute [bitvec_rules] BitVec.extractLsb_ofNat attribute [bitvec_rules] BitVec.extractLsb'_toNat -attribute [bitvec_rules] BitVec.extractLsb_toNat attribute [bitvec_rules] BitVec.getLsbD_extract attribute [bitvec_rules] BitVec.toNat_allOnes attribute [bitvec_rules] BitVec.getLsbD_allOnes @@ -546,17 +543,16 @@ theorem zeroExtend_if_false [Decidable p] (x : BitVec n) (zeroExtend (if p then a else b) x) = BitVec.cast h_eq (zeroExtend b x) := by simp only [toNat_eq, toNat_truncate, ← h_eq, toNat_cast] -theorem extractLsb_eq (x : BitVec n) (h : n = n - 1 + 1) : - BitVec.extractLsb (n - 1) 0 x = BitVec.cast h x := by - unfold extractLsb extractLsb' - ext1 - simp [←h] +theorem extractLsb'_eq (x : BitVec n) : + BitVec.extractLsb' 0 n x = x := by + unfold extractLsb' + simp only [Nat.shiftRight_zero, ofNat_toNat, zeroExtend_eq] @[bitvec_rules] -protected theorem extract_lsb_of_zeroExtend (x : BitVec n) (h : j < i) : - extractLsb' 0 (j + 1) (zeroExtend i x) = zeroExtend (j + 1) x := by +protected theorem extractLsb'_of_zeroExtend (x : BitVec n) (h : j ≤ i) : + extractLsb' 0 j (zeroExtend i x) = zeroExtend j x := by apply BitVec.eq_of_getLsbD_eq - simp + -- simp intro k have q : k < i := by omega by_cases h : decide (k ≤ j) <;> simp [q, h] @@ -634,12 +630,11 @@ theorem append_of_extract_general_nat (high low n vn : Nat) (h : vn < 2 ^ n) : done theorem append_of_extract (n : Nat) (v : BitVec n) - (high0 : high = n - low) (low0 : 1 <= low) - (h : high + (low - 1 - 0 + 1) = n) : - BitVec.cast h (zeroExtend high (v >>> low) ++ extractLsb (low - 1) 0 v) = v := by + (high0 : high = n - low) (h : high + low = n) : + BitVec.cast h (zeroExtend high (v >>> low) ++ extractLsb' 0 low v) = v := by ext subst high - have vlt := v.isLt; simp_all only [Nat.sub_zero] + have vlt := v.isLt have := append_of_extract_general_nat (n - low) low n (BitVec.toNat v) vlt have low_le : low ≤ n := by omega simp_all [toNat_zeroExtend, Nat.sub_add_cancel, low_le] @@ -649,17 +644,13 @@ theorem append_of_extract (n : Nat) (v : BitVec n) exact Nat.lt_of_le_of_lt (Nat.div_le_self _ _) vlt done -theorem append_of_extract_general (v : BitVec n) - (low0 : 1 <= low) - (h1 : high = width) - (h2 : (high + low - 1 - 0 + 1) = (width + (low - 1 - 0 + 1))) : - BitVec.cast h1 (zeroExtend high (v >>> low)) ++ extractLsb (low - 1) 0 v = - BitVec.cast h2 (extractLsb (high + low - 1) 0 v) := by +theorem append_of_extract_general (v : BitVec n) : + (zeroExtend high (v >>> low)) ++ extractLsb' 0 low v = + extractLsb' 0 (high + low) v := by ext have := append_of_extract_general_nat high low n (BitVec.toNat v) - have h_vlt := v.isLt; simp_all only [Nat.sub_zero, h1] - simp only [h_vlt, h1, forall_prop_of_true] at this - have low' : 1 ≤ width + low := Nat.le_trans low0 (Nat.le_add_left low width) + have h_vlt := v.isLt; simp_all only [Nat.sub_zero] + simp only [h_vlt, forall_prop_of_true] at this simp_all [toNat_zeroExtend, Nat.sub_add_cancel] rw [Nat.mod_eq_of_lt (b := 2 ^ n)] at this · rw [this] @@ -788,7 +779,7 @@ def genBVPatMatchTest (var : Term) (pat : BVPat) : MacroM Term := do for c in pat.getComponents do let len := c.length if let some bv ← c.toBVLit? then - let test ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var == $bv) + let test ← `(extractLsb' $(quote shift) $(quote len) $var == $bv) result ← `($result && $test) shift := shift + len return result @@ -810,7 +801,7 @@ def declBVPatVars (var : Term) (pat : BVPat) (rhs : Term) : MacroM Term := do for c in pat.getComponents do let len := c.length if let some y ← c.toBVVar? then - let rhs ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var) + let rhs ← `(extractLsb' $(quote shift) $(quote len) $var) result ← `(let $y := $rhs; $result) shift := shift + len return result @@ -934,23 +925,28 @@ Definition to extract the `n`th least significant *Byte* from a bitvector. TODO: this should be named `getLsByte`, or `getLsbByte` (Shilpi prefers this). -/ def extractLsByte (val : BitVec w₁) (n : Nat) : BitVec 8 := - val.extractLsb ((n + 1) * 8 - 1) (n * 8) |> .cast (by omega) + val.extractLsb' (n * 8) 8 theorem extractLsByte_def (val : BitVec w₁) (n : Nat) : - val.extractLsByte n = (val.extractLsb ((n + 1)*8 - 1) (n * 8) |>.cast (by omega)) := rfl + val.extractLsByte n = val.extractLsb' (n * 8) 8 := rfl -- TODO: upstream -theorem extractLsb_or (x y : BitVec w₁) (n : Nat) : - (x ||| y).extractLsb n lo = (x.extractLsb n lo ||| y.extractLsb n lo) := by +theorem extractLsb'_or (x y : BitVec w₁) (n : Nat) : + (x ||| y).extractLsb' lo n = (x.extractLsb' lo n ||| y.extractLsb' lo n) := by apply BitVec.eq_of_getLsbD_eq simp only [getLsbD_extract, getLsbD_or] intros i - by_cases h : (i : Nat) ≤ n - lo - · simp only [h, decide_True, Bool.true_and] - · simp only [h, decide_False, Bool.false_and, Bool.or_self] + simp only [getLsbD_extractLsb', Fin.is_lt, decide_True, getLsbD_or, Bool.true_and] + +-- TODO: upstream +protected theorem extractLsb'_ofNat (x n : Nat) (l lo : Nat) : + extractLsb' lo l (BitVec.ofNat n x) = .ofNat l ((x % 2^n) >>> lo) := by + apply eq_of_getLsbD_eq + intro ⟨i, _lt⟩ + simp [BitVec.ofNat] theorem extractLsByte_zero {w : Nat} : (0#w).extractLsByte i = 0#8 := by - simp only [extractLsByte, BitVec.extractLsb_ofNat, Nat.zero_mod, Nat.zero_shiftRight, cast_ofNat] + simp only [extractLsByte, BitVec.extractLsb'_ofNat, Nat.zero_mod, Nat.zero_shiftRight, cast_ofNat] theorem extractLsByte_ge (h : 8 * a ≥ w₁) (x : BitVec w₁) : x.extractLsByte a = 0#8 := by @@ -958,7 +954,7 @@ theorem extractLsByte_ge (h : 8 * a ≥ w₁) (x : BitVec w₁) : intros i simp only [getLsbD_zero, extractLsByte_def, getLsbD_cast, getLsbD_extract, Bool.and_eq_false_imp, decide_eq_true_eq] - intros _ + simp only [getLsbD_extractLsb', Fin.is_lt, decide_True, Bool.true_and] apply BitVec.getLsbD_ge omega @@ -967,10 +963,13 @@ theorem getLsbD_extractLsByte (val : BitVec w₁) : ((BitVec.extractLsByte val n).getLsbD i) = (decide (i ≤ 7) && val.getLsbD (n * 8 + i)) := by simp only [extractLsByte, getLsbD_cast, getLsbD_extract] - rw [Nat.succ_mul] - simp only [Nat.add_one_sub_one, - Nat.add_sub_cancel_left] - + simp only [getLsbD_extractLsb'] + generalize val.getLsbD (n * 8 + i) = x + by_cases h : i < 8 + · simp only [show (i : Nat) ≤ 7 by omega, decide_True, Bool.true_and, + Bool.and_iff_right_iff_imp, decide_eq_true_eq, h] + · simp only [show ¬(i : Nat) ≤ 7 by omega, decide_False, Bool.false_and, + Bool.and_eq_false_imp, decide_eq_true_eq, h] /-- Two bitvectors of length `n*8` are equal if all their bytes are equal. @@ -994,9 +993,7 @@ theorem eq_of_extractLsByte_eq (x y : BitVec (n * 8)) @bollu: it's not clear if the definition for n=0 is desirable. -/ def extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) : BitVec (n * 8) := - match h : n with - | 0 => 0#0 - | x + 1 => val.extractLsb (base * 8 + n * 8 - 1) (base * 8) |>.cast (by omega) + extractLsb' (base * 8) (n * 8) val @[bitvec_rules] theorem getLsbD_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) : @@ -1009,10 +1006,9 @@ theorem getLsbD_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) simp only [show ¬i < 0 by omega, decide_False, Bool.false_and] · simp only [extractLsBytes, getLsbD_cast, getLsbD_extract, Nat.zero_lt_succ, decide_True, Bool.true_and] - simp only [show base * 8 + (n + 1) * 8 - 1 - base * 8 = (n + 1) * 8 - 1 by omega] by_cases h : i < (n + 1) * 8 - · simp only [show i ≤ (n + 1) * 8 - 1 by omega, decide_True, Bool.true_and, h] - · simp only [show ¬(i ≤ (n + 1) * 8 - 1) by omega, decide_False, Bool.false_and, h] + · simp only [getLsbD_extractLsb', h, decide_True, Bool.true_and] + · simp only [getLsbD_extractLsb', h, decide_False, Bool.false_and] theorem extractLsByte_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) : (BitVec.extractLsBytes val base n).extractLsByte i = diff --git a/Arm/Memory/MemoryProofs.lean b/Arm/Memory/MemoryProofs.lean index 6973512a..e0437500 100644 --- a/Arm/Memory/MemoryProofs.lean +++ b/Arm/Memory/MemoryProofs.lean @@ -63,12 +63,10 @@ theorem read_mem_of_write_mem_bytes_different (hn1 : n <= 2^64) theorem append_byte_of_extract_rest_same_cast (n : Nat) (v : BitVec ((n + 1) * 8)) (hn0 : Nat.succ 0 ≤ n) - (h : (n * 8 + (7 - 0 + 1)) = (n + 1) * 8) : - BitVec.cast h (zeroExtend (n * 8) (v >>> 8) ++ extractLsb 7 0 v) = v := by + (h : (n * 8 + 8) = (n + 1) * 8) : + BitVec.cast h (zeroExtend (n * 8) (v >>> 8) ++ extractLsb' 0 8 v) = v := by apply BitVec.append_of_extract · omega - · omega - · omega done @[state_simp_rules] @@ -85,7 +83,7 @@ theorem read_mem_bytes_of_write_mem_bytes_same (hn1 : n <= 2^64) : case base => simp only [read_mem_bytes, write_mem_bytes, read_mem_of_write_mem_same, BitVec.cast_eq] - have l1 := BitVec.extractLsb_eq v + have l1 := BitVec.extractLsb'_eq v simp only [Nat.reduceSucc, Nat.one_mul, Nat.succ_sub_succ_eq_sub, Nat.sub_zero, Nat.reduceAdd, BitVec.cast_eq, forall_const] at l1 @@ -431,7 +429,7 @@ private theorem write_mem_bytes_of_write_mem_bytes_shadow_general_n2_eq rename_i n n_ih conv in write_mem_bytes (Nat.succ n) .. => simp only [write_mem_bytes] have n_ih' := @n_ih (addr1 + 1#64) val2 (zeroExtend (n * 8) (val1 >>> 8)) - (write_mem addr1 (extractLsb 7 0 val1) s) + (write_mem addr1 (extractLsb' 0 8 val1) s) (by omega) simp only [Nat.succ_sub_succ_eq_sub, Nat.sub_zero] at h3 by_cases h₁ : n = 0 @@ -483,7 +481,7 @@ theorem write_mem_bytes_of_write_mem_bytes_shadow_general theorem read_mem_of_write_mem_bytes_same_first_address (h0 : 0 < n) (h1 : n <= 2^64) (h : 7 - 0 + 1 = 8) : read_mem addr (write_mem_bytes n addr val s) = - BitVec.cast h (extractLsb 7 0 val) := by + BitVec.cast h (extractLsb' 0 8 val) := by unfold write_mem_bytes; simp only [Nat.sub_zero, BitVec.cast_eq] split · contradiction @@ -495,18 +493,16 @@ theorem read_mem_of_write_mem_bytes_same_first_address -- (FIXME) Argh, it's annoying to need this lemma, but using -- BitVec.cast_eq directly was cumbersome. theorem cast_of_extract_eq (v : BitVec p) - (h1 : hi1 = hi2) (h2 : lo1 = lo2) - (h : hi1 - lo1 + 1 = hi2 - lo2 + 1) : - BitVec.cast h (extractLsb hi1 lo1 v) = (extractLsb hi2 lo2 v) := by + (h1 : n1 = n2) (h2 : lo1 = lo2): + BitVec.cast h (extractLsb' lo1 n1 v) = (extractLsb' lo2 n2 v) := by subst_vars simp only [Nat.sub_zero, BitVec.cast_eq] theorem read_mem_bytes_of_write_mem_bytes_subset_same_first_address (h0 : 0 < n1) (h1 : n1 <= 2^64) (h2 : 0 < n2) (h3 : n2 <= 2^64) - (h4 : mem_subset addr (addr + (BitVec.ofNat 64 (n2 - 1))) addr (addr + (BitVec.ofNat 64 (n1 - 1)))) - (h : n2 * 8 - 1 - 0 + 1 = n2 * 8) : + (h4 : mem_subset addr (addr + (BitVec.ofNat 64 (n2 - 1))) addr (addr + (BitVec.ofNat 64 (n1 - 1)))): read_mem_bytes n2 addr (write_mem_bytes n1 addr val s) = - BitVec.cast h (extractLsb ((n2 * 8) - 1) 0 val) := by + extractLsb' 0 (n2 * 8) val := by have rm_lemma := @read_mem_of_write_mem_bytes_same_first_address n1 addr val s h0 h1 simp only [Nat.sub_zero, Nat.reduceAdd, BitVec.cast_eq, forall_const] at rm_lemma induction n2, h2 using Nat.le_induction generalizing n1 addr val s @@ -543,21 +539,20 @@ theorem read_mem_bytes_of_write_mem_bytes_subset_same_first_address erw [Nat.mod_eq_of_lt h3] at hn erw [Nat.mod_eq_of_lt h1] at hn exact hn - rw [n_ih (by omega) (by omega) (by omega) _ (by omega)] - · rw [BitVec.extract_lsb_of_zeroExtend (v >>> 8)] - · have l1 := @BitVec.append_of_extract_general ((n1_1 + 1) * 8) 8 (n*8-1+1) (n*8) v + rw [n_ih (by omega) (by omega) (by omega) _] + · rw [BitVec.extractLsb'_of_zeroExtend (v >>> 8)] + · have l1 := @BitVec.append_of_extract_general ((n1_1 + 1) * 8) (n*8) 8 v simp (config := { decide := true }) only [Nat.zero_lt_succ, Nat.mul_pos_iff_of_pos_left, Nat.succ_sub_succ_eq_sub, Nat.sub_zero, Nat.reduceAdd, Nat.succ.injEq, forall_const] at l1 - rw [l1 (by omega) (by omega)] - · simp only [Nat.add_eq, Nat.sub_zero, BitVec.cast_cast] - apply @cast_of_extract_eq ((n1_1 + 1) * 8) (n * 8 - 1 + 1 + 7) ((n + 1) * 8 - 1) 0 0 <;> + rw [l1] + · apply @cast_of_extract_eq ((n1_1 + 1) * 8) (n * 8 + 8) ((n + 1) * 8) 0 0 <;> omega · omega · have rw_lemma2 := @read_mem_of_write_mem_bytes_same_first_address n1_1 (addr + 1#64) (zeroExtend (n1_1 * 8) (v >>> 8)) - (write_mem addr (extractLsb 7 0 v) s) + (write_mem addr (extractLsb' 0 8 v) s) simp only [Nat.reducePow, Nat.sub_zero, Nat.reduceAdd, BitVec.cast_eq, forall_const] at rw_lemma2 rw [rw_lemma2 (by omega) (by simp only [Nat.reducePow] at h1; omega)] diff --git a/Arm/State.lean b/Arm/State.lean index aaead5a1..d0b8c27a 100644 --- a/Arm/State.lean +++ b/Arm/State.lean @@ -962,7 +962,7 @@ def write_bytes (n : Nat) (addr : BitVec 64) match n with | 0 => m | n' + 1 => - let byte := BitVec.extractLsb 7 0 val + let byte := BitVec.extractLsb' 0 8 val let m := m.write addr byte let val_rest := BitVec.zeroExtend (n' * 8) (val >>> 8) m.write_bytes n' (addr + 1#64) val_rest @@ -988,7 +988,7 @@ and then recursing to write the rest. -/ theorem write_bytes_succ {mem : Memory} : mem.write_bytes (n + 1) addr val = - let byte := BitVec.extractLsb 7 0 val + let byte := BitVec.extractLsb' 0 8 val let mem := mem.write addr byte let val_rest := BitVec.zeroExtend (n * 8) (val >>> 8) mem.write_bytes n (addr + 1#64) val_rest := rfl