Skip to content

Commit

Permalink
feat: memory model proofs for simp_mem [2/3?] (leanprover#108)
Browse files Browse the repository at this point in the history
This PR is stacked on top of
leanprover#105, and is peeled from
leanprover#90. This is the second of 3
anticipated PRs which build the new `simp_mem` tactic.

This PR change `mem` from private to public, because it was found when
writing automation that
```lean
theorem ArmState.read_mem_eq_mem_read : read_mem addr s = s.mem.read addr := rfl
```

exposes `s.mem` to the outside world, on which we perform memory
operations.
So, we think of `s.mem` as a public interface for the memory, whose
*methods* which modify memory will eventually be made private.


### License:

By submitting this pull request, I confirm that my contribution is
made under the terms of the Apache 2.0 license.

---------

Co-authored-by: Shilpi Goel <shigoel@gmail.com>
Co-authored-by: Alex Keizer <alex@keizer.dev>
  • Loading branch information
3 people authored Aug 22, 2024
1 parent 9e05e6a commit fc5aabe
Showing 1 changed file with 61 additions and 50 deletions.
111 changes: 61 additions & 50 deletions Arm/State.lean
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ structure ArmState where
/-- PState -/
private pstate : PState
/-- Memory: maps 64-bit addresses to bytes -/
private mem : Memory
mem : Memory
/--
Program: maps 64-bit addresses to 32-bit instructions.
Note that we have the following assumption baked into our machine model:
Expand Down Expand Up @@ -735,6 +735,14 @@ theorem State.read_mem_bytes_eq_mem_read_bytes (s : ArmState) :
theorem read_bytes_zero_eq (m : Memory) : m.read_bytes 0 addr = 0#0 :=
rfl

@[memory_rules]
theorem read_bytes_one_eq (m : Memory) : m.read_bytes 1 addr = m.read addr := by
simp [read_bytes, read, bitvec_rules]
apply BitVec.eq_of_getLsb_eq
intros i
rw [BitVec.getLsb_append]
simp only [show (i : Nat) < 8 by omega, decide_True, Nat.zero_le, BitVec.getLsb_ge, cond_true]

theorem read_bytes_succ_eq (m : Memory) :
m.read_bytes (n' + 1) addr = (m.read_bytes n' (addr + 1) ++ m.read addr).cast (by omega) := rfl

Expand Down Expand Up @@ -771,20 +779,47 @@ theorem getLsb_read_bytes {n i : Nat} {addr : BitVec 64} {m : Memory} (hn : n
subst hi'
simp only [Nat.add_sub_cancel, Nat.zero_lt_succ, Nat.add_div_right, Nat.add_mod_right]
congr 2
rw [BitVec.add_assoc]
congr
rw [BitVec.add_def]
congr 1
simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]
rw [Nat.mod_eq_of_lt]
· omega
· omega
bv_omega
· simp only [h₂, decide_False, Bool.false_and, Bool.false_eq, Bool.and_eq_false_imp,
decide_eq_true_eq]
intros h₃
omega
bv_omega
· omega

/--
Describe the behaviour of `m.read_bytes` at a byte level granularity.
-/
@[memory_rules]
theorem extractLsByte_read_bytes {n i : Nat} {addr : BitVec 64} {m : Memory} (h : addr.toNat + n ≤ 2^64) :
(m.read_bytes n addr).extractLsByte i =
if i < n then m.read (addr + (BitVec.ofNat 64 i)) else 0#8 := by
apply BitVec.eq_of_getLsb_eq
simp only [BitVec.getLsb_extractLsByte]
intros j
simp only [show (j : Nat) ≤ 7 by omega, decide_True, Bool.true_and]
rw [getLsb_read_bytes]
by_cases h₁ : i * 8 + ↑j < n * 8
· simp only [h₁, decide_True, Bool.true_and]
simp only [show (i < n) by omega, ↓reduceIte]
simp only [show (i * 8 + ↑j) / 8 = i by omega]
simp only [show (i * 8 + ↑j) % 8 = j by omega]
rfl
· simp only [h₁, decide_False, Bool.false_and, Bool.false_eq]
simp only [show ¬(i < n) by omega, ↓reduceIte, BitVec.getLsb_zero]
· omega

/--
Extracting a byte out of a byte returns the value if `i = 0`, and `0#8`
otherwise.
-/
@[memory_rules]
theorem Memory.extractLsByte_read (m : Memory) :
(m.read addr).extractLsByte i = if i = 0 then m.read addr else 0#8 := by
rw [← read_bytes_one_eq]
rw [extractLsByte_read_bytes (by omega)]
by_cases h : i = 0
· simp only [h, Nat.lt_add_one, ↓reduceIte, BitVec.add_zero, read_bytes_one_eq]
· simp only [h, ↓reduceIte, ite_eq_right_iff]
omega

/--
This is a low level theorem.
Expand Down Expand Up @@ -921,49 +956,25 @@ theorem write_bytes_eq_extractLsByte {ix base : BitVec 64} {m : Memory}
case succ n ih =>
simp only [write_bytes]
by_cases hix : ix.toNat = base.toNat
· obtain hix : ix = base := by
apply BitVec.eq_of_toNat_eq hix
· obtain hix : ix = base := by apply BitVec.eq_of_toNat_eq hix
subst hix
simp only [BitVec.sub_self, BitVec.toNat_ofNat, Nat.reducePow, Nat.zero_mod]
rcases n with rfl | n
· simp only [Nat.reduceAdd, Nat.reduceMul, write_bytes_zero]
rw [write_of_eq (ix := ix) rfl]
· simp only [Nat.reduceAdd, Nat.reduceMul, write_bytes_zero,
write_of_eq (ix := ix) rfl]
rfl
· rw [write_bytes_eq_of_le]
· simp only [write_of_eq rfl, BitVec.extractLsByte_def, Nat.reduceAdd, Nat.reduceMul,
Nat.add_one_sub_one, Nat.sub_zero, BitVec.cast_eq]
· rw [BitVec.toNat_add_eq_toNat_add_toNat]
· simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod, Nat.lt_add_one]
· simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega
· rw [BitVec.toNat_add_eq_toNat_add_toNat
(by simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega)]
simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega
· rw [ih]
-- | TODO: make these into some kind of proof automation.
· have h_base_plus_1 : (base + 1#64).toNat = base.toNat + 1 := by
simp only [BitVec.toNat_add, BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]
rw [Nat.mod_eq_of_lt (by omega)]
have h_ix_sub_base_plus_1 : (ix - (base + 1#64)).toNat = ix.toNat - (base + 1#64).toNat := by
rw [BitVec.toNat_sub_eq_toNat_sub_toNat_of_le]
simp only [BitVec.le_def, BitVec.toNat_add, BitVec.toNat_ofNat, Nat.reducePow,
Nat.reduceMod]; omega
have h_ix_sub_base : (ix - base).toNat = ix.toNat - base.toNat := by
rw [BitVec.toNat_sub_eq_toNat_sub_toNat_of_le]
rw [BitVec.le_def]
omega
rw [h_ix_sub_base_plus_1, h_base_plus_1, h_ix_sub_base, Nat.sub_add_eq,
show ix.toNat - base.toNat - 1 = (ix.toNat - base.toNat) - 1 by omega]
apply extractLsByte_zeroExtend_shiftLeft
omega
· rw [BitVec.toNat_add_eq_toNat_add_toNat
(by simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega)]
simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod, ge_iff_le]; omega
· rw [BitVec.toNat_add_eq_toNat_add_toNat
(by simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega)]
simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega
· rw [BitVec.toNat_add_eq_toNat_add_toNat
(by simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega)]
simp only [BitVec.toNat_ofNat, Nat.reducePow, Nat.reduceMod]; omega
· rw [write_bytes_eq_of_le (by bv_omega) (by bv_omega)]
simp only [write_of_eq rfl, BitVec.extractLsByte_def, Nat.reduceAdd, Nat.reduceMul,
Nat.add_one_sub_one, Nat.sub_zero, BitVec.cast_eq]
· rw [ih (by bv_omega) (by bv_omega) (by bv_omega)]
rw [show (ix - (base + 1#64)).toNat = ix.toNat - (base + 1#64).toNat by
bv_omega]
rw [show (base + 1#64).toNat = base.toNat + 1 by bv_omega]
rw [show (ix - base).toNat = ix.toNat - base.toNat by bv_omega]
rw [Nat.sub_add_eq,
show ix.toNat - base.toNat - 1 = (ix.toNat - base.toNat) - 1 by omega]
apply extractLsByte_zeroExtend_shiftLeft
omega

/--
This is a low level theorem.
Expand Down

0 comments on commit fc5aabe

Please sign in to comment.