Skip to content

Commit

Permalink
feat: efficient axiomatization of effects (leanprover#129)
Browse files Browse the repository at this point in the history
### Description:

We've built a new datastructure `AxEffects`, which stores a map from
`StateField` to two `Expr`s: the current value of that field, and a
proof that `r <field> <currentState> = <value>`, together with other
assorted proofs related to axiomatic effects.

We then incorporate this `AxEffects` into `sym_n`, replacing and
superseding `intro_fetch_decode`. This has two benefits: `AxEffects` is
more complete, and it is *much* faster. Fully simulating
`gcm_gmult_v8_program` (assuming step theorems have been pre-generated)
used to take around 3.5 seconds, with this PR it takes only 0.5 seconds:
a 7x speedup!

- In the process, we've proven that `Abs` correctly implements its spec
- No longer is a `CheckSPAlignment` goal added by default. Instead, we
only add `CheckSPAlignment` goals if the initial state was known to be
aligned, and there is some write to the SP of which alignment could not
be automatically proven. For now, this goal is added eagerly, without
trying to determine if this fact is actually needed for further
simulation. Trying to be smarter about this is left for future PRs.

### Testing:

`make all` works locally

### License:

By submitting this pull request, I confirm that my contribution is
made under the terms of the Apache 2.0 license.

---------

Co-authored-by: Siddharth <siddu.druid@gmail.com>
Co-authored-by: Shilpi Goel <shigoel@gmail.com>
  • Loading branch information
3 people authored Sep 4, 2024
1 parent b0a10ee commit a52baba
Show file tree
Hide file tree
Showing 11 changed files with 985 additions and 37 deletions.
5 changes: 5 additions & 0 deletions Arm/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,11 @@ protected theorem extract_lsb_of_zeroExtend (x : BitVec n) (h : j < i) :
simp_all
omega

@[bitvec_rules, simp]
theorem zero_append {w} (x : BitVec 0) (y : BitVec w) :
x ++ y = y.cast (by simp) := by
apply eq_of_getLsb_eq; simp; omega

@[bitvec_rules]
theorem empty_bitvector_append_left
(x : BitVec n) (h : 0 + n = n) :
Expand Down
14 changes: 14 additions & 0 deletions Arm/Insts/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,20 @@ theorem CheckSPAlignment_AddWithCarry_64_4 (st : ArmState) (y : BitVec 64) (carr
simp_all only [CheckSPAlignment, read_gpr, zeroExtend_eq, Nat.sub_zero, add_eq,
Aligned_AddWithCarry_64_4]

@[state_simp_rules]
theorem CheckSPAlignment_of_r_sp_eq {s s' : ArmState}
(h_eq : r (StateField.GPR 31#5) s' = r (StateField.GPR 31#5) s)
(h_sp : CheckSPAlignment s) :
CheckSPAlignment s' := by
simpa only [CheckSPAlignment, read_gpr, h_eq] using h_sp

@[state_simp_rules]
theorem CheckSPAlignment_of_r_sp_aligned {s : ArmState} {value}
(h_eq : r (StateField.GPR 31#5) s = value)
(h_aligned : Aligned value 4) :
CheckSPAlignment s := by
simp only [CheckSPAlignment, read_gpr, h_eq, zeroExtend_eq, h_aligned]

----------------------------------------------------------------------

inductive ShiftType where
Expand Down
85 changes: 83 additions & 2 deletions Arm/State.lean
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ inductive PFlag where
| Z : PFlag
| C : PFlag
| V : PFlag
deriving DecidableEq, Repr
deriving DecidableEq, Repr, Hashable

instance : ToString PFlag :=
fun p => match p with
Expand Down Expand Up @@ -255,7 +255,50 @@ inductive StateField where
| PC : StateField
| FLAG : PFlag → StateField
| ERR : StateField
deriving DecidableEq, Repr
deriving DecidableEq, Repr, Hashable

namespace StateField

/-- general purpose register `x31` is used as stack pointer -/
@[state_simp_rules]
abbrev SP := GPR 31#5

/- Might eventually be used to maintain a `O(1)` access `StateField` map,
by using `StateField.toFin` to index into a sparse array. -/
/-- `StateField` is equivalent to `Fin (32 + 32 + 1 + 4 + 1)` -/
def toFin : StateField → Fin 70
| .GPR x => x.toFin.castAdd 38
| .SFP x => (x.toFin.addNat 32).castAdd 6
| .PC => 64
| .FLAG .N => 65
| .FLAG .Z => 66
| .FLAG .C => 67
| .FLAG .V => 68
| .ERR => 69

section ToExpr
open Lean PFlag

instance : ToExpr PFlag where
toTypeExpr := mkConst ``PFlag
toExpr := fun
| N => mkConst ``N
| V => mkConst ``V
| C => mkConst ``C
| Z => mkConst ``Z

instance : ToExpr StateField where
toTypeExpr := mkConst ``StateField
toExpr := fun
| GPR x => mkApp (mkConst ``GPR) (toExpr x)
| SFP x => mkApp (mkConst ``SFP) (toExpr x)
| PC => mkConst ``PC
| FLAG fl => mkApp (mkConst ``FLAG) (toExpr fl)
| ERR => mkConst ``ERR

end ToExpr

end StateField

instance : ToString StateField :=
fun s => match s with
Expand Down Expand Up @@ -645,6 +688,14 @@ theorem read_mem_bytes_of_w :
rw [n_ih]
done

@[state_simp_rules]
theorem read_mem_bytes_w_of_read_mem_eq
(h : ∀ n addr, read_mem_bytes n addr s₁ = read_mem_bytes n addr s₂)
(fld val n₁ addr₁) :
read_mem_bytes n₁ addr₁ (w fld val s₁)
= read_mem_bytes n₁ addr₁ s₂ := by
simp only [read_mem_bytes_of_w, h]

@[state_simp_rules]
theorem write_mem_bytes_program {n : Nat} (addr : BitVec 64) (bytes : BitVec (n * 8)):
(write_mem_bytes n addr bytes s).program = s.program := by
Expand Down Expand Up @@ -1022,3 +1073,33 @@ by_cases h : ix < base
· simp only [h₂, ↓reduceIte, BitVec.getLsb_extractLsByte]

end Memory

/-! ## Helper lemma for `AxEffects` -/

@[state_simp_rules]
theorem Memory.eq_of_read_mem_bytes_eq {m₁ m₂ : Memory}
(h : ∀ n addr, m₁.read_bytes n addr = m₂.read_bytes n addr) :
m₁ = m₂ := by
funext i
specialize (h 1 i)
simp only [Nat.reduceMul, read_bytes, Nat.reduceAdd, read, read_store,
BitVec.cast_eq] at h
rw [BitVec.zero_append, BitVec.zero_append] at h
simpa only [Nat.reduceAdd, BitVec.cast_eq] using h

theorem mem_eq_iff_read_mem_bytes_eq {s₁ s₂ : ArmState} :
s₁.mem = s₂.mem
↔ ∀ n addr, read_mem_bytes n addr s₁ = read_mem_bytes n addr s₂ := by
simp only [memory_rules]
constructor
· intro h _ _; rw[h]
· exact Memory.eq_of_read_mem_bytes_eq

theorem read_mem_bytes_write_mem_bytes_of_read_mem_eq
(h : ∀ n addr, read_mem_bytes n addr s₁ = read_mem_bytes n addr s₂)
(n₂ addr₂ val n₁ addr₁) :
read_mem_bytes n₁ addr₁ (write_mem_bytes n₂ addr₂ val s₁)
= read_mem_bytes n₁ addr₁ (write_mem_bytes n₂ addr₂ val s₂) := by
revert n₁ addr₁
simp only [← mem_eq_iff_read_mem_bytes_eq] at h ⊢
simp only [memory_rules, h]
34 changes: 29 additions & 5 deletions Proofs/Experiments/Abs/Abs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The goal is to prove that this program implements absolute value correctly.
import Arm
import Tactics.StepThms
import Tactics.Sym

import Tactics.CSE
namespace Abs

def program : Program :=
Expand All @@ -19,7 +19,16 @@ def program : Program :=
(0x4005dc#64, 0x4a000020#32), -- eor w0, w1, w0
(0x4005e0#64, 0xd65f03c0#32)] -- ret

def spec (x : BitVec 32) : BitVec 32 := BitVec.ofNat 32 x.toInt.natAbs
def spec (x : BitVec 32) : BitVec 32 :=
-- We prefer the current definition as opposed to:
-- BitVec.ofNat 32 x.toInt.natAbs
-- because the above has functions like `toInt` that do not play well with
-- bitblasting/LeanSAT.
let msb := BitVec.extractLsb 31 31 x
if msb == 0#1 then
x
else
(0#32 - x)

#genStepEqTheorems program

Expand All @@ -29,15 +38,30 @@ theorem correct
(h_s0_program : s0.program = program)
(h_s0_err : read_err s0 = StateError.None)
(h_s0_sp : CheckSPAlignment s0)
(h_run : sf = run program.length s0) :
(h_run : sf = run (program.length) s0) :
read_gpr 32 0 sf = spec (read_gpr 32 0 s0) ∧
read_err sf = StateError.None := by
simp (config := {ground := true}) at h_run

sym_n 5
sorry
simp only [run] at h_run
subst sf
apply And.intro
· simp only [read_gpr, BitVec.ofNat_eq_ofNat]
simp (config := {decide := true}) only [
h_s5_non_effects,
h_s4_x0,
h_s3_x1, h_s3_non_effects,
h_s2_x0, h_s2_non_effects,
h_s1_x1, h_s1_non_effects]
generalize r (StateField.GPR 0#5) s0 = x
simp only [spec, AddWithCarry]
split <;> bv_decide
· assumption

/-- info: 'Abs.correct' depends on axioms: [propext, sorryAx, Classical.choice, Quot.sound] -/
/--
info: 'Abs.correct' depends on axioms: [propext, Classical.choice, Lean.ofReduceBool, Quot.sound]
-/
#guard_msgs in #print axioms correct

end Abs
13 changes: 13 additions & 0 deletions Tactics/Attr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,16 @@ initialize

-- enable tracing for `sym_n` tactic and related components
registerTraceClass `Tactic.sym

-- enable extra checks for debugging `sym_n`,
-- see `AxEffects.validate` for more detail on what is being type-checked
registerOption `Tactic.sym.debug {
defValue := true
descr := "enable/disable type-checking of internal state during execution \
of the `sym_n` tactic, throwing an error if mal-formed expressions were \
created, indicating a bug in the implementation of `sym_n`.
This is an internal option for debugging purposes, end users should \
generally not set this option, unless they are reporting a bug with \
`sym_n`"
}
35 changes: 35 additions & 0 deletions Tactics/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,28 @@ def reflectBitVecLiteral (w : Nat) (e : Expr) : MetaM (BitVec w) := do
else
throwError "Expected a bitvector of width {w}, but\n\t{e}\nhas width {n}"

def reflectPFLag (e : Expr) : MetaM PFlag :=
match_expr e with
| PFlag.N => pure .N
| PFlag.Z => pure .Z
| PFlag.C => pure .C
| PFlag.V => pure .V
| _ =>
let pflag := mkConst ``PFlag
throwError "Expected a `{pflag}` constructor, found:\n {e}"

/-- Reflect a concrete `StateField` -/
def reflectStateField (e : Expr) : MetaM StateField :=
match_expr e with
| StateField.GPR x => StateField.GPR <$> reflectBitVecLiteral _ x
| StateField.SFP x => StateField.SFP <$> reflectBitVecLiteral _ x
| StateField.PC => pure StateField.PC
| StateField.FLAG f => StateField.FLAG <$> reflectPFLag f
| StateField.ERR => pure StateField.ERR
| _ =>
let sf := mkConst ``StateField
throwError "Expected a `{sf}` constructor, found:\n {e}"

/-! ## Hypothesis types -/
namespace SymContext

Expand Down Expand Up @@ -224,3 +246,16 @@ def findProgramHyp (state : Expr) : MetaM (LocalDecl × Name) := do
throwError "Expected a constant, found:\n\t{program}"

return ⟨h_program, program⟩

/-! ## Expr Builders -/

/-- Return the expression for `ArmState` -/
def mkArmState : Expr := mkConst ``ArmState

/-- Return `x = y`, given expressions `x` and `y` of type `ArmState` -/
def mkEqArmState (x y : Expr) : Expr :=
mkApp3 (.const ``Eq [1]) mkArmState x y

/-- Return a proof of type `x = x`, where `x : ArmState` -/
def mkEqReflArmState (x : Expr) : Expr :=
mkApp2 (.const ``Eq.refl [1]) mkArmState x
Loading

0 comments on commit a52baba

Please sign in to comment.