From fc5aabed5fe174872f2279c506d992f144682588 Mon Sep 17 00:00:00 2001 From: Siddharth Date: Thu, 22 Aug 2024 10:37:54 -0500 Subject: [PATCH] feat: memory model proofs for simp_mem [2/3?] (#108) This PR is stacked on top of https://github.com/leanprover/LNSym/pull/105, and is peeled from https://github.com/leanprover/LNSym/pull/90. This is the second of 3 anticipated PRs which build the new `simp_mem` tactic. This PR change `mem` from private to public, because it was found when writing automation that ```lean theorem ArmState.read_mem_eq_mem_read : read_mem addr s = s.mem.read addr := rfl ``` exposes `s.mem` to the outside world, on which we perform memory operations. So, we think of `s.mem` as a public interface for the memory, whose *methods* which modify memory will eventually be made private. ### License: By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. --------- Co-authored-by: Shilpi Goel Co-authored-by: Alex Keizer --- Arm/State.lean | 111 +++++++++++++++++++++++++++---------------------- 1 file changed, 61 insertions(+), 50 deletions(-) diff --git a/Arm/State.lean b/Arm/State.lean index be289ecb..7720ae74 100644 --- a/Arm/State.lean +++ b/Arm/State.lean @@ -135,7 +135,7 @@ structure ArmState where /-- PState -/ private pstate : PState /-- Memory: maps 64-bit addresses to bytes -/ - private mem : Memory + mem : Memory /-- Program: maps 64-bit addresses to 32-bit instructions. Note that we have the following assumption baked into our machine model: @@ -735,6 +735,14 @@ theorem State.read_mem_bytes_eq_mem_read_bytes (s : ArmState) : theorem read_bytes_zero_eq (m : Memory) : m.read_bytes 0 addr = 0#0 := rfl +@[memory_rules] +theorem read_bytes_one_eq (m : Memory) : m.read_bytes 1 addr = m.read addr := by + simp [read_bytes, read, bitvec_rules] + apply BitVec.eq_of_getLsb_eq + intros i + rw [BitVec.getLsb_append] + simp only [show (i : Nat) < 8 by omega, decide_True, Nat.zero_le, BitVec.getLsb_ge, cond_true] + theorem read_bytes_succ_eq (m : Memory) : m.read_bytes (n' + 1) addr = (m.read_bytes n' (addr + 1) ++ m.read addr).cast (by omega) := rfl @@ -771,20 +779,47 @@ theorem getLsb_read_bytes {n i : Nat} {addr : BitVec 64} {m : Memory} (hn : n subst hi' simp only [Nat.add_sub_cancel, Nat.zero_lt_succ, Nat.add_div_right, Nat.add_mod_right] congr 2 - rw [BitVec.add_assoc] - congr - rw [BitVec.add_def] - congr 1 - simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod] - rw [Nat.mod_eq_of_lt] - · omega - · omega + bv_omega · simp only [h₂, decide_False, Bool.false_and, Bool.false_eq, Bool.and_eq_false_imp, decide_eq_true_eq] - intros h₃ - omega + bv_omega · omega +/-- +Describe the behaviour of `m.read_bytes` at a byte level granularity. +-/ +@[memory_rules] +theorem extractLsByte_read_bytes {n i : Nat} {addr : BitVec 64} {m : Memory} (h : addr.toNat + n ≤ 2^64) : + (m.read_bytes n addr).extractLsByte i = + if i < n then m.read (addr + (BitVec.ofNat 64 i)) else 0#8 := by + apply BitVec.eq_of_getLsb_eq + simp only [BitVec.getLsb_extractLsByte] + intros j + simp only [show (j : Nat) ≤ 7 by omega, decide_True, Bool.true_and] + rw [getLsb_read_bytes] + by_cases h₁ : i * 8 + ↑j < n * 8 + · simp only [h₁, decide_True, Bool.true_and] + simp only [show (i < n) by omega, ↓reduceIte] + simp only [show (i * 8 + ↑j) / 8 = i by omega] + simp only [show (i * 8 + ↑j) % 8 = j by omega] + rfl + · simp only [h₁, decide_False, Bool.false_and, Bool.false_eq] + simp only [show ¬(i < n) by omega, ↓reduceIte, BitVec.getLsb_zero] + · omega + +/-- +Extracting a byte out of a byte returns the value if `i = 0`, and `0#8` +otherwise. +-/ +@[memory_rules] +theorem Memory.extractLsByte_read (m : Memory) : + (m.read addr).extractLsByte i = if i = 0 then m.read addr else 0#8 := by + rw [← read_bytes_one_eq] + rw [extractLsByte_read_bytes (by omega)] + by_cases h : i = 0 + · simp only [h, Nat.lt_add_one, ↓reduceIte, BitVec.add_zero, read_bytes_one_eq] + · simp only [h, ↓reduceIte, ite_eq_right_iff] + omega /-- This is a low level theorem. @@ -921,49 +956,25 @@ theorem write_bytes_eq_extractLsByte {ix base : BitVec 64} {m : Memory} case succ n ih => simp only [write_bytes] by_cases hix : ix.toNat = base.toNat - · obtain hix : ix = base := by - apply BitVec.eq_of_toNat_eq hix + · obtain hix : ix = base := by apply BitVec.eq_of_toNat_eq hix subst hix simp only [BitVec.sub_self, BitVec.toNat_ofNat, Nat.reducePow, Nat.zero_mod] rcases n with rfl | n - · simp only [Nat.reduceAdd, Nat.reduceMul, write_bytes_zero] - rw [write_of_eq (ix := ix) rfl] + · simp only [Nat.reduceAdd, Nat.reduceMul, write_bytes_zero, + write_of_eq (ix := ix) rfl] rfl - · rw [write_bytes_eq_of_le] - · simp only [write_of_eq rfl, BitVec.extractLsByte_def, Nat.reduceAdd, Nat.reduceMul, - Nat.add_one_sub_one, Nat.sub_zero, BitVec.cast_eq] - · rw [BitVec.toNat_add_eq_toNat_add_toNat] - · simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod, Nat.lt_add_one] - · simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega - · rw [BitVec.toNat_add_eq_toNat_add_toNat - (by simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega)] - simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega - · rw [ih] - -- | TODO: make these into some kind of proof automation. - · have h_base_plus_1 : (base + 1#64).toNat = base.toNat + 1 := by - simp only [BitVec.toNat_add, BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod] - rw [Nat.mod_eq_of_lt (by omega)] - have h_ix_sub_base_plus_1 : (ix - (base + 1#64)).toNat = ix.toNat - (base + 1#64).toNat := by - rw [BitVec.toNat_sub_eq_toNat_sub_toNat_of_le] - simp only [BitVec.le_def, BitVec.toNat_add, BitVec.toNat_ofNat, Nat.reducePow, - Nat.reduceMod]; omega - have h_ix_sub_base : (ix - base).toNat = ix.toNat - base.toNat := by - rw [BitVec.toNat_sub_eq_toNat_sub_toNat_of_le] - rw [BitVec.le_def] - omega - rw [h_ix_sub_base_plus_1, h_base_plus_1, h_ix_sub_base, Nat.sub_add_eq, - show ix.toNat - base.toNat - 1 = (ix.toNat - base.toNat) - 1 by omega] - apply extractLsByte_zeroExtend_shiftLeft - omega - · rw [BitVec.toNat_add_eq_toNat_add_toNat - (by simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega)] - simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod, ge_iff_le]; omega - · rw [BitVec.toNat_add_eq_toNat_add_toNat - (by simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega)] - simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega - · rw [BitVec.toNat_add_eq_toNat_add_toNat - (by simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega)] - simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega + · rw [write_bytes_eq_of_le (by bv_omega) (by bv_omega)] + simp only [write_of_eq rfl, BitVec.extractLsByte_def, Nat.reduceAdd, Nat.reduceMul, + Nat.add_one_sub_one, Nat.sub_zero, BitVec.cast_eq] + · rw [ih (by bv_omega) (by bv_omega) (by bv_omega)] + rw [show (ix - (base + 1#64)).toNat = ix.toNat - (base + 1#64).toNat by + bv_omega] + rw [show (base + 1#64).toNat = base.toNat + 1 by bv_omega] + rw [show (ix - base).toNat = ix.toNat - base.toNat by bv_omega] + rw [Nat.sub_add_eq, + show ix.toNat - base.toNat - 1 = (ix.toNat - base.toNat) - 1 by omega] + apply extractLsByte_zeroExtend_shiftLeft + omega /-- This is a low level theorem.