Skip to content

Commit

Permalink
Rework step theorem generation to be faster and cache intermediate re…
Browse files Browse the repository at this point in the history
…sults to the environment (leanprover#92)

### Description:

Forewarning: the diff of this one is rather large.

I've implemented a new version of step theorem generation, culminating
in the `genStepEqTheorem` function.

In particular:
- We get rid of the different fetch/decode/exec intermediate lemmas,
instead opting to generate the final step theorem in one go. The
bottleneck seems to be kernel checking, so by reducing the number of
theorems sent to the kernel we achieve a good speedup (SHA512 used to
take about 55 seconds for the three generation commands combined, now
it's about 35 seconds).
- In the process, we build a `ProgramInfo` struct, which holds a bunch
of interesting expressions that further proof automation could exploit.
This programInfo is stored in a persistent environment extension, so
that it is persisted in the olean files (hence, making it available to
downstream files).
- While building the previous, I've moved some definitions around and
did other refactors: happy to split those off into their own PR if that
makes reviewing easier.

EDIT: split off leanprover#93 for the `#time` command

### Testing:

`make all` succeeded

### License:

By submitting this pull request, I confirm that my contribution is
made under the terms of the Apache 2.0 license.

---------

Co-authored-by: Siddharth <siddu.druid@gmail.com>
Co-authored-by: Shilpi Goel <shigoel@gmail.com>
  • Loading branch information
3 people authored Aug 19, 2024
1 parent 16b87e6 commit 93f6a71
Show file tree
Hide file tree
Showing 13 changed files with 596 additions and 111 deletions.
8 changes: 8 additions & 0 deletions Arm/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,14 @@ example : split 0xabcd1234#32 8 (by omega) = [0xab#8, 0xcd#8, 0x12#8, 0x34#8] :=
/-- Get the width of a bitvector. -/
protected def width (_ : BitVec n) : Nat := n

/-- Render a bitvector as hexadecimal digits, dropping any leading zeroes.
Use `BitVec.toHex` instead when the leading zeroes should be kept.
NOTE: the result is just the digits; no `0x` prefix is added -/
def toHexWithoutLeadingZeroes {w} (x : BitVec w) : String :=
  let digits := Nat.toDigits 16 x.toNat
  digits.asString

----------------------------------------------------------------------

attribute [ext] BitVec
Expand Down
10 changes: 10 additions & 0 deletions Arm/Exec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,13 @@ theorem run_onestep (s s': ArmState) (n : Nat) (h_nonneg : 0 < n):
· cases h_nonneg
· rename_i n
simp [run]

/-- Helper lemma for automation.

Suppose that for state `s`:
* the error field is `.None` (`h_err`),
* the program counter is `addr` (`h_pc`),
* fetching at `addr` yields the raw instruction `rawInst` (`h_fetch`), and
* `rawInst` decodes to `inst` (`h_decode`).

Then a single `stepi` from `s` is equal to `exec_inst inst s`. -/
theorem stepi_eq_of_fetch_inst_of_decode_raw_inst
(s : ArmState) (addr : BitVec 64) (rawInst : BitVec 32) (inst : ArmInst)
(h_err : r .ERR s = .None)
(h_pc : r .PC s = addr)
(h_fetch : fetch_inst addr s = some rawInst)
(h_decode : decode_raw_inst rawInst = some inst) :
stepi s = exec_inst inst s := by
simp only [stepi, h_err, h_pc, h_fetch, h_decode, read_err, read_pc]
8 changes: 8 additions & 0 deletions Arm/State.lean
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,14 @@ theorem fetch_inst_from_program
unfold fetch_inst
simp only

/-- If the program of `state` is `program` (`h_program`), and looking up `addr`
in `program` yields `inst?` (`h_map`), then `fetch_inst addr state` is `inst?`.

NOTE(review): `prgram` in the name looks like a typo for `program`; it is left
as-is here because renaming would change the identifier that callers refer to. -/
theorem fetch_inst_eq_of_prgram_eq_of_map_find
{state : ArmState} {program : Program}
{addr : BitVec 64} {inst? : Option (BitVec 32)}
(h_program : state.program = program)
(h_map : program.find? addr = inst?) :
fetch_inst addr state = inst? := by
rw [fetch_inst, h_program, h_map]

end Load_program_and_fetch_inst

----------------------------------------------------------------------
Expand Down
4 changes: 1 addition & 3 deletions Proofs/AES-GCM/GCMGmultV8Sym.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ import Tactics.StepThms

namespace GCMGmultV8Program

#genStepTheorems gcm_gmult_v8_program thmType:="fetch" `state_simp_rules
#genStepTheorems gcm_gmult_v8_program thmType:="decodeExec" `state_simp_rules
#genStepTheorems gcm_gmult_v8_program thmType:="step" `state_simp_rules
#genStepEqTheorems gcm_gmult_v8_program

theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
(h_s0_program : s0.program = gcm_gmult_v8_program)
Expand Down
11 changes: 10 additions & 1 deletion Proofs/Experiments/Abs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ Author(s): Shilpi Goel, Siddharth Bhat
The goal is to prove that this program implements absolute value correctly.
-/
import Arm
import Tactics.StepThms
import Tactics.Sym

namespace Abs

Expand All @@ -19,14 +21,21 @@ def program : Program :=

def spec (x : BitVec 32) : BitVec 32 := BitVec.ofNat 32 x.toInt.natAbs

#genStepEqTheorems program

theorem correct
{s0 sf : ArmState}
(h_s0_pc : read_pc s0 = 0x4005d0#64)
(h_s0_program : s0.program = program)
(h_s0_err : read_err s0 = StateError.None)
(h_s0_sp : CheckSPAlignment s0)
(h_run : sf = run program.length s0) :
read_gpr 32 0 sf = spec (read_gpr 32 0 s0) ∧
read_err sf = StateError.None := by sorry
read_err sf = StateError.None := by
simp (config := {ground := true}) at h_run

sym1_n 5
sorry

/-- info: 'Abs.correct' depends on axioms: [propext, sorryAx, Classical.choice, Quot.sound] -/
#guard_msgs in #print axioms correct
Expand Down
24 changes: 8 additions & 16 deletions Proofs/Popcount32.lean
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,7 @@ def popcount32_program : Program :=
(0x40061c#64 , 0xd65f03c0#32)] -- ret


#genStepTheorems popcount32_program thmType:="fetch"

-- #guard_msgs in
-- #check popcount32_fetch_0x4005b4

#genStepTheorems popcount32_program thmType:="decodeExec"

#genStepTheorems popcount32_program thmType:="step" `state_simp_rules
#genStepEqTheorems popcount32_program

theorem popcount32_sym_no_error (s0 s_final : ArmState)
(h_s0_pc : read_pc s0 = 0x4005b4#64)
Expand Down Expand Up @@ -130,16 +123,15 @@ theorem popcount32_sym_no_error (s0 s_final : ArmState)
section Tests

/--
info: popcount32_program.stepi_0x4005c0 (s sn : ArmState) (h_program : s.program = popcount32_program)
info: popcount32_program.stepi_eq_0x4005c0 {s : ArmState} (h_program : s.program = popcount32_program)
(h_pc : r StateField.PC s = 4195776#64) (h_err : r StateField.ERR s = StateError.None) :
(sn = stepi s) =
(sn =
w StateField.PC (4195780#64)
(w (StateField.GPR 0#5)
(zeroExtend 64 ((zeroExtend 32 (r (StateField.GPR 0#5) s)).rotateRight 1) &&& 4294967295#64 &&& 2147483647#64)
s))
stepi s =
w StateField.PC (4195780#64)
(w (StateField.GPR 0#5)
(zeroExtend 64 ((zeroExtend 32 (r (StateField.GPR 0#5) s)).rotateRight 1) &&& 4294967295#64 &&& 2147483647#64)
s)
-/
#guard_msgs in #check popcount32_program.stepi_0x4005c0
#guard_msgs in #check popcount32_program.stepi_eq_0x4005c0

end Tests

Expand Down
12 changes: 5 additions & 7 deletions Proofs/SHA512/Sha512StepLemmas.lean
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import Proofs.SHA512.Sha512FetchLemmas
import Proofs.SHA512.Sha512DecodeExecLemmas
import Proofs.SHA512.Sha512Program
-- import Tests.SHA2.SHA512ProgramTest
import Tactics.StepThms

-- set_option trace.gen_step.debug.heartBeats true in
-- set_option trace.gen_step.print_names true in
set_option maxHeartbeats 2000000 in
#genStepTheorems sha512_program thmType:="step" `state_simp_rules
#genStepEqTheorems sha512_program

/--
info: sha512_program.stepi_0x126c90 (s sn : ArmState) (h_program : s.program = sha512_program)
info: sha512_program.stepi_eq_0x126c90 {s : ArmState} (h_program : s.program = sha512_program)
(h_pc : r StateField.PC s = 1207440#64) (h_err : r StateField.ERR s = StateError.None) :
(sn = stepi s) = (sn = w StateField.PC (if ¬r (StateField.GPR 2#5) s = 0#64 then 1205504#64 else 1207444#64) s)
stepi s = w StateField.PC (if ¬r (StateField.GPR 2#5) s = 0#64 then 1205504#64 else 1207444#64) s
-/
#guard_msgs in
#check sha512_program.stepi_0x126c90
#check sha512_program.stepi_eq_0x126c90
72 changes: 52 additions & 20 deletions Tactics/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def getBitVecString? (e : Expr) (hex : Bool := false): MetaM (Option String) :=
| some ⟨_, value⟩ =>
if hex then
-- We don't want leading zeroes here.
return some (Nat.toDigits 16 value.toNat).asString
return some value.toHexWithoutLeadingZeroes
else
return some (ToString.toString value.toNat)
| none => return none
Expand Down Expand Up @@ -88,13 +88,13 @@ that additionally recognizes:
-- TODO: should this be upstreamed to core?
def getBitVecValue? (e : Expr) : MetaM (Option ((n : Nat) × BitVec n)) :=
match_expr e with
| BitVec.ofFin _ i => OptionT.run do
let ⟨n, i⟩ ← getFinValue? i
let n' := Nat.log2 n
if h : n = 2^n' then
return ⟨n', .ofFin (Fin.cast h i)⟩
else
failure
| BitVec.ofFin w i => OptionT.run do
let w ← getNatValue? w
let v ← do
match_expr i with
| Fin.mk _n v _h => getNatValue? v
| _ => pure (← getFinValue? i).2.val
return ⟨w, BitVec.ofNat w v⟩
| _ => Lean.Meta.getBitVecValue? e

/-- Given a ground term `e` of type `Nat`, fully reduce it,
Expand All @@ -115,21 +115,53 @@ which was obtained by reducing:\n\t{e}"
reduce an expression `e` (of type `BitVec w`) to be of the form `?n#w`,
and then reflect `?n` to build the meta-level bitvector -/
def reflectBitVecLiteral (w : Nat) (e : Expr) : MetaM (BitVec w) := do
if e.hasFVar then
if e.hasFVar || e.hasMVar then
throwError "Expected a ground term, but {e} has free variables"

if let some ⟨n, x⟩ ← _root_.getBitVecValue? e then
if h : n = w then
return x.cast h
else
throwError "Expected a bitvector of width {w}, but\n\t{e}\nhas width {n}"
let some ⟨n, x⟩ ← _root_.getBitVecValue? e
| throwError "Failed to reflect:\n\t{e}\ninto a BitVec"

let x ← mkFreshExprMVar (Expr.const ``Nat [])
let e' ← mkAppM ``BitVec.ofNat #[toExpr w, x]
if (←isDefEq e e') then
return BitVec.ofNat w (← reflectNatLiteral x)
if h : n = w then
return x.cast h
else
throwError "Failed to unify, expected:\n\t{e'}\nbut found:\n\t{e'}"
throwError "Expected a bitvector of width {w}, but\n\t{e}\nhas width {n}"

/-! ## Hypothesis types -/
namespace SymContext

/-- `h_err_type state` builds the `Expr` for the proposition `r .ERR <state> = .None`,
i.e., the type that the hypothesis `h_err` is expected to have -/
def h_err_type (state : Expr) : Expr :=
  let errOfState := mkApp2 (mkConst ``r) (mkConst ``StateField.ERR) state
  mkApp3 (mkConst ``Eq [1]) (mkConst ``StateError) errOfState (mkConst ``StateError.None)

/-- `h_sp_type state` builds the `Expr` for the proposition `CheckSPAlignment <state>`,
i.e., the type that the hypothesis `h_sp` is expected to have -/
def h_sp_type (state : Expr) : Expr :=
  .app (mkConst ``CheckSPAlignment) state

/-- `h_program_type state program` builds the `Expr` for the proposition
`<state>.program = <program>`, i.e., the type that the hypothesis `h_program`
is expected to have -/
def h_program_type (state program : Expr) : Expr :=
  let programOfState := mkApp (mkConst ``ArmState.program) state
  mkApp3 (mkConst ``Eq [1]) (mkConst ``Program) programOfState program

/-- `h_pc_type state address` builds the `Expr` for the proposition
`r .PC <state> = <address>`, i.e., the type that the hypothesis `h_pc`
is expected to have -/
def h_pc_type (state address : Expr) : Expr :=
  let bitVec64 := mkApp (mkConst ``BitVec) (toExpr 64)
  let pcOfState := mkApp2 (mkConst ``r) (mkConst ``StateField.PC) state
  mkApp3 (mkConst ``Eq [1]) bitVec64 pcOfState address

end SymContext

/-! ## Local Context Search -/

Expand Down Expand Up @@ -162,7 +194,7 @@ Throws an error if no such hypothesis could. -/
def findProgramHyp (state : Expr) : MetaM (LocalDecl × Name) := do
-- Try to find `h_program`, and infer `program` from it
let program ← mkFreshExprMVar none
let h_program_type ← mkEq (← mkAppM ``ArmState.program #[state]) program
let h_program_type := SymContext.h_program_type state program
let h_program ← findLocalDeclOfTypeOrError h_program_type

-- Assert that `program` is a(n application of a) constant, and find its name
Expand Down
14 changes: 3 additions & 11 deletions Tactics/Reflect/FetchAndDecode.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,15 @@ open Elab.Tactic Elab.Term
initialize
Lean.registerTraceClass `Sym.reduceFetchInst

theorem fetch_inst_eq_of_prgram_eq_of_map_find
{state : ArmState} {program : Program}
{addr : BitVec 64} {inst? : Option (BitVec 32)}
(h_program : state.program = program)
(h_map : program.find? addr = inst?) :
fetch_inst addr state = inst? := by
rw [fetch_inst, h_program, h_map]

def reduceFetchInst? (addr : Expr) (s : Expr) :
def reduceFetchInst? (addr : BitVec 64) (s : Expr) :
MetaM (BitVec 32 × Expr) := do
let addr ← reflectBitVecLiteral 64 addr
let ⟨programHyp, program⟩ ← findProgramHyp s
let programInfo ← try ProgramInfo.lookupOrGenerate program catch err =>
throwErrorAt
err.getRef
"Could not generate ProgramInfo for {program}:\n\n{err.toMessageData}"

let some rawInst := programInfo.getRawInstrAt? addr
let some rawInst := programInfo.getRawInstAt? addr
| throwError "No instruction found at address {addr}"

trace[Sym.reduceFetchInst] "{Lean.checkEmoji} reduced to: {rawInst}"
Expand All @@ -54,6 +45,7 @@ simproc reduceFetchInst (fetch_inst _ _) := fun e => do
trace[Sym.reduceFetchInst] "⚙️ simplifying {e}"
let_expr fetch_inst addr s := e
| return .continue
let addr ← reflectBitVecLiteral 64 addr

try
let ⟨x, proof?⟩ ← reduceFetchInst? addr s
Expand Down
Loading

0 comments on commit 93f6a71

Please sign in to comment.