diff --git a/Arm/BitVec.lean b/Arm/BitVec.lean index 0930269a..bde49cf4 100644 --- a/Arm/BitVec.lean +++ b/Arm/BitVec.lean @@ -339,6 +339,14 @@ example : split 0xabcd1234#32 8 (by omega) = [0xab#8, 0xcd#8, 0x12#8, 0x34#8] := /-- Get the width of a bitvector. -/ protected def width (_ : BitVec n) : Nat := n +/-- Convert a bitvector into its hex representation, without leading zeroes. + +See `BitVec.toHex` if you do want the leading zeroes. + +NOTE: returns only the digits, without a `0x` prefix -/ +def toHexWithoutLeadingZeroes {w} (x : BitVec w) : String := + (Nat.toDigits 16 x.toNat).asString + ---------------------------------------------------------------------- attribute [ext] BitVec diff --git a/Arm/Exec.lean b/Arm/Exec.lean index f9bc7ccf..c8dd3db4 100644 --- a/Arm/Exec.lean +++ b/Arm/Exec.lean @@ -166,3 +166,13 @@ theorem run_onestep (s s': ArmState) (n : Nat) (h_nonneg : 0 < n): · cases h_nonneg · rename_i n simp [run] + +/-- helper lemma for automation -/ +theorem stepi_eq_of_fetch_inst_of_decode_raw_inst + (s : ArmState) (addr : BitVec 64) (rawInst : BitVec 32) (inst : ArmInst) + (h_err : r .ERR s = .None) + (h_pc : r .PC s = addr) + (h_fetch : fetch_inst addr s = some rawInst) + (h_decode : decode_raw_inst rawInst = some inst) : + stepi s = exec_inst inst s := by + simp only [stepi, h_err, h_pc, h_fetch, h_decode, read_err, read_pc] diff --git a/Arm/State.lean b/Arm/State.lean index 8315d24a..be289ecb 100644 --- a/Arm/State.lean +++ b/Arm/State.lean @@ -521,6 +521,14 @@ theorem fetch_inst_from_program unfold fetch_inst simp only +theorem fetch_inst_eq_of_prgram_eq_of_map_find + {state : ArmState} {program : Program} + {addr : BitVec 64} {inst? : Option (BitVec 32)} + (h_program : state.program = program) + (h_map : program.find? addr = inst?) : + fetch_inst addr state = inst? 
:= by + rw [fetch_inst, h_program, h_map] + end Load_program_and_fetch_inst ---------------------------------------------------------------------- diff --git a/Proofs/AES-GCM/GCMGmultV8Sym.lean b/Proofs/AES-GCM/GCMGmultV8Sym.lean index 4fc97235..8cee03c3 100644 --- a/Proofs/AES-GCM/GCMGmultV8Sym.lean +++ b/Proofs/AES-GCM/GCMGmultV8Sym.lean @@ -4,9 +4,7 @@ import Tactics.StepThms namespace GCMGmultV8Program -#genStepTheorems gcm_gmult_v8_program thmType:="fetch" `state_simp_rules -#genStepTheorems gcm_gmult_v8_program thmType:="decodeExec" `state_simp_rules -#genStepTheorems gcm_gmult_v8_program thmType:="step" `state_simp_rules +#genStepEqTheorems gcm_gmult_v8_program theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState) (h_s0_program : s0.program = gcm_gmult_v8_program) diff --git a/Proofs/Experiments/Abs.lean b/Proofs/Experiments/Abs.lean index f2e599b6..69ee53e8 100644 --- a/Proofs/Experiments/Abs.lean +++ b/Proofs/Experiments/Abs.lean @@ -6,6 +6,8 @@ Author(s): Shilpi Goel, Siddharth Bhat The goal is to prove that this program implements absolute value correctly. 
-/ import Arm +import Tactics.StepThms +import Tactics.Sym namespace Abs @@ -19,14 +21,21 @@ def program : Program := def spec (x : BitVec 32) : BitVec 32 := BitVec.ofNat 32 x.toInt.natAbs +#genStepEqTheorems program + theorem correct {s0 sf : ArmState} (h_s0_pc : read_pc s0 = 0x4005d0#64) (h_s0_program : s0.program = program) (h_s0_err : read_err s0 = StateError.None) + (h_s0_sp : CheckSPAlignment s0) (h_run : sf = run program.length s0) : read_gpr 32 0 sf = spec (read_gpr 32 0 s0) ∧ - read_err sf = StateError.None := by sorry + read_err sf = StateError.None := by + simp (config := {ground := true}) at h_run + + sym1_n 5 + sorry /-- info: 'Abs.correct' depends on axioms: [propext, sorryAx, Classical.choice, Quot.sound] -/ #guard_msgs in #print axioms correct diff --git a/Proofs/Popcount32.lean b/Proofs/Popcount32.lean index 4ceb358b..fcead4eb 100644 --- a/Proofs/Popcount32.lean +++ b/Proofs/Popcount32.lean @@ -66,14 +66,7 @@ def popcount32_program : Program := (0x40061c#64 , 0xd65f03c0#32)] -- ret -#genStepTheorems popcount32_program thmType:="fetch" - --- #guard_msgs in --- #check popcount32_fetch_0x4005b4 - -#genStepTheorems popcount32_program thmType:="decodeExec" - -#genStepTheorems popcount32_program thmType:="step" `state_simp_rules +#genStepEqTheorems popcount32_program theorem popcount32_sym_no_error (s0 s_final : ArmState) (h_s0_pc : read_pc s0 = 0x4005b4#64) @@ -130,16 +123,15 @@ theorem popcount32_sym_no_error (s0 s_final : ArmState) section Tests /-- -info: popcount32_program.stepi_0x4005c0 (s sn : ArmState) (h_program : s.program = popcount32_program) +info: popcount32_program.stepi_eq_0x4005c0 {s : ArmState} (h_program : s.program = popcount32_program) (h_pc : r StateField.PC s = 4195776#64) (h_err : r StateField.ERR s = StateError.None) : - (sn = stepi s) = - (sn = - w StateField.PC (4195780#64) - (w (StateField.GPR 0#5) - (zeroExtend 64 ((zeroExtend 32 (r (StateField.GPR 0#5) s)).rotateRight 1) &&& 4294967295#64 &&& 2147483647#64) - s)) + stepi s = 
+ w StateField.PC (4195780#64) + (w (StateField.GPR 0#5) + (zeroExtend 64 ((zeroExtend 32 (r (StateField.GPR 0#5) s)).rotateRight 1) &&& 4294967295#64 &&& 2147483647#64) + s) -/ -#guard_msgs in #check popcount32_program.stepi_0x4005c0 +#guard_msgs in #check popcount32_program.stepi_eq_0x4005c0 end Tests diff --git a/Proofs/SHA512/Sha512StepLemmas.lean b/Proofs/SHA512/Sha512StepLemmas.lean index 0613490b..4b246fbd 100644 --- a/Proofs/SHA512/Sha512StepLemmas.lean +++ b/Proofs/SHA512/Sha512StepLemmas.lean @@ -1,17 +1,15 @@ -import Proofs.SHA512.Sha512FetchLemmas -import Proofs.SHA512.Sha512DecodeExecLemmas import Proofs.SHA512.Sha512Program --- import Tests.SHA2.SHA512ProgramTest +import Tactics.StepThms -- set_option trace.gen_step.debug.heartBeats true in -- set_option trace.gen_step.print_names true in set_option maxHeartbeats 2000000 in -#genStepTheorems sha512_program thmType:="step" `state_simp_rules +#genStepEqTheorems sha512_program /-- -info: sha512_program.stepi_0x126c90 (s sn : ArmState) (h_program : s.program = sha512_program) +info: sha512_program.stepi_eq_0x126c90 {s : ArmState} (h_program : s.program = sha512_program) (h_pc : r StateField.PC s = 1207440#64) (h_err : r StateField.ERR s = StateError.None) : - (sn = stepi s) = (sn = w StateField.PC (if ¬r (StateField.GPR 2#5) s = 0#64 then 1205504#64 else 1207444#64) s) + stepi s = w StateField.PC (if ¬r (StateField.GPR 2#5) s = 0#64 then 1205504#64 else 1207444#64) s -/ #guard_msgs in -#check sha512_program.stepi_0x126c90 +#check sha512_program.stepi_eq_0x126c90 diff --git a/Tactics/Common.lean b/Tactics/Common.lean index 88c0d06b..8b5442e8 100644 --- a/Tactics/Common.lean +++ b/Tactics/Common.lean @@ -55,7 +55,7 @@ def getBitVecString? (e : Expr) (hex : Bool := false): MetaM (Option String) := | some ⟨_, value⟩ => if hex then -- We don't want leading zeroes here. 
- return some (Nat.toDigits 16 value.toNat).asString + return some value.toHexWithoutLeadingZeroes else return some (ToString.toString value.toNat) | none => return none @@ -88,13 +88,13 @@ that additionally recognizes: -- TODO: should this be upstreamed to core? def getBitVecValue? (e : Expr) : MetaM (Option ((n : Nat) × BitVec n)) := match_expr e with - | BitVec.ofFin _ i => OptionT.run do - let ⟨n, i⟩ ← getFinValue? i - let n' := Nat.log2 n - if h : n = 2^n' then - return ⟨n', .ofFin (Fin.cast h i)⟩ - else - failure + | BitVec.ofFin w i => OptionT.run do + let w ← getNatValue? w + let v ← do + match_expr i with + | Fin.mk _n v _h => getNatValue? v + | _ => pure (← getFinValue? i).2.val + return ⟨w, BitVec.ofNat w v⟩ | _ => Lean.Meta.getBitVecValue? e /-- Given a ground term `e` of type `Nat`, fully reduce it, @@ -115,21 +115,53 @@ which was obtained by reducing:\n\t{e}" reduce an expression `e` (of type `BitVec w`) to be of the form `?n#w`, and then reflect `?n` to build the meta-level bitvector -/ def reflectBitVecLiteral (w : Nat) (e : Expr) : MetaM (BitVec w) := do - if e.hasFVar then + if e.hasFVar || e.hasMVar then throwError "Expected a ground term, but {e} has free variables" - if let some ⟨n, x⟩ ← _root_.getBitVecValue? e then - if h : n = w then - return x.cast h - else - throwError "Expected a bitvector of width {w}, but\n\t{e}\nhas width {n}" + let some ⟨n, x⟩ ← _root_.getBitVecValue? e + | throwError "Failed to reflect:\n\t{e}\ninto a BitVec" - let x ← mkFreshExprMVar (Expr.const ``Nat []) - let e' ← mkAppM ``BitVec.ofNat #[toExpr w, x] - if (←isDefEq e e') then - return BitVec.ofNat w (← reflectNatLiteral x) + if h : n = w then + return x.cast h else - throwError "Failed to unify, expected:\n\t{e'}\nbut found:\n\t{e'}" + throwError "Expected a bitvector of width {w}, but\n\t{e}\nhas width {n}" + +/-! 
## Hypothesis types -/ +namespace SymContext + +/-- `h_err_type state` returns an Expr for `r .ERR state = .None`, +the expected type of `h_err` -/ +def h_err_type (state : Expr) : Expr := + mkAppN (mkConst ``Eq [1]) #[ + mkConst ``StateError, + mkApp2 (.const ``r []) (.const ``StateField.ERR []) state, + .const ``StateError.None [] + ] + +/-- `h_sp_type state` returns an Expr for `CheckSPAlignment state`, +the expected type of `h_sp` -/ +def h_sp_type (state : Expr) : Expr := + mkApp (.const ``CheckSPAlignment []) state + +/-- `h_program_type state program` returns an Expr for `state.program = program`, +the expected type of `h_program` -/ +def h_program_type (state program : Expr) : Expr := + mkAppN (mkConst ``Eq [1]) #[ + mkConst ``Program, + mkApp (mkConst ``ArmState.program) state, + program + ] + +/-- `h_pc_type state address` returns an Expr for `r .PC state = address
`, +the expected type of `h_pc` -/ +def h_pc_type (state address : Expr) : Expr := + mkAppN (mkConst ``Eq [1]) #[ + mkApp (mkConst ``BitVec) (toExpr 64), + mkApp2 (mkConst ``r) (mkConst ``StateField.PC) state, + address + ] + +end SymContext /-! ## Local Context Search -/ @@ -162,7 +194,7 @@ Throws an error if no such hypothesis could. -/ def findProgramHyp (state : Expr) : MetaM (LocalDecl × Name) := do -- Try to find `h_program`, and infer `program` from it let program ← mkFreshExprMVar none - let h_program_type ← mkEq (← mkAppM ``ArmState.program #[state]) program + let h_program_type := SymContext.h_program_type state program let h_program ← findLocalDeclOfTypeOrError h_program_type -- Assert that `program` is a(n application of a) constant, and find its name diff --git a/Tactics/Reflect/FetchAndDecode.lean b/Tactics/Reflect/FetchAndDecode.lean index 04c6f2e7..c9583f20 100644 --- a/Tactics/Reflect/FetchAndDecode.lean +++ b/Tactics/Reflect/FetchAndDecode.lean @@ -13,24 +13,15 @@ open Elab.Tactic Elab.Term initialize Lean.registerTraceClass `Sym.reduceFetchInst -theorem fetch_inst_eq_of_prgram_eq_of_map_find - {state : ArmState} {program : Program} - {addr : BitVec 64} {inst? : Option (BitVec 32)} - (h_program : state.program = program) - (h_map : program.find? addr = inst?) : - fetch_inst addr state = inst? := by - rw [fetch_inst, h_program, h_map] - -def reduceFetchInst? (addr : Expr) (s : Expr) : +def reduceFetchInst? (addr : BitVec 64) (s : Expr) : MetaM (BitVec 32 × Expr) := do - let addr ← reflectBitVecLiteral 64 addr let ⟨programHyp, program⟩ ← findProgramHyp s let programInfo ← try ProgramInfo.lookupOrGenerate program catch err => throwErrorAt err.getRef "Could not generate ProgramInfo for {program}:\n\n{err.toMessageData}" - let some rawInst := programInfo.getRawInstrAt? addr + let some rawInst := programInfo.getRawInstAt? 
addr | throwError "No instruction found at address {addr}" trace[Sym.reduceFetchInst] "{Lean.checkEmoji} reduced to: {rawInst}" @@ -54,6 +45,7 @@ simproc reduceFetchInst (fetch_inst _ _) := fun e => do trace[Sym.reduceFetchInst] "⚙️ simplifying {e}" let_expr fetch_inst addr s := e | return .continue + let addr ← reflectBitVecLiteral 64 addr try let ⟨x, proof?⟩ ← reduceFetchInst? addr s diff --git a/Tactics/Reflect/ProgramInfo.lean b/Tactics/Reflect/ProgramInfo.lean index 368b74ca..897ec2a3 100644 --- a/Tactics/Reflect/ProgramInfo.lean +++ b/Tactics/Reflect/ProgramInfo.lean @@ -17,81 +17,335 @@ Furthermore, we define a persistent env extension to store `ProgramInfo` in. open Lean Meta Elab.Term +initialize + registerTraceClass `ProgramInfo + +/-- `OnDemand α` is morally an `Option α`, +we use it for values that are computed, and cached, on demand. -/ +inductive InstInfo.OnDemand (α : Type) + /-- a value has not yet been cached, + you should run the relevant computation -/ + | notYetComputed + /-- a value has been cached -/ + | value (value : α) +open InstInfo (OnDemand) + +structure InstInfo where + /-- the raw instruction, as a bitvector -/ + rawInst : BitVec 32 + + /-- the decoded instruction, as a normalized(!) `Expr` of type `ArmInst`. + That is, `decode_raw_inst ` should be def-eq to `some `. + -/ + decodedInst? : OnDemand Expr := + .notYetComputed + + /-- if `instSemantics?` is `⟨sem, type, proof⟩`, then + - `sem` is the instruction semantics, as a simplified expression of type + `ArmState → ArmState`. + + That is, we've ran `simp` on `sem` with our dedicated simp-sets in the hopes + of obtaining only a sequence of `w` and `write_mem`s to the initial state. 
+ However, note that some instructions might have conditional behaviour, + in which case `sem` might still contain `if`s + - `type` is the expression + ```lean + ∀ s (h_program : s.program = ) (h_pc : read_pc s = ) + (h_err : read_err s = .None), + exec_inst s = s + ``` + - `proof` is a proof of type `type` + -/ + instSemantics? : OnDemand (Expr × Expr × Expr) := + .notYetComputed + structure ProgramInfo where - rawProgram : HashMap (BitVec 64) (BitVec 32) + name : Name + instructions : HashMap (BitVec 64) InstInfo + +-------------------------------------------------------------------------------- + +/-! ## InstInfoT -/ + +/-- A monad transformer with `InstInfo` state -/ +abbrev InstInfoT := StateT InstInfo + +namespace InstInfoT +variable {m} [Monad m] + +/-- Return `InstInfo.rawInst` from the state -/ +def getRawInst : InstInfoT m (BitVec 32) := do + return (← get).rawInst + +/-- Return `InstInfo.decodedInst?` from the state if it is `some _`, +or use `f` to compute the relevant expression if it is missing -/ +def getDecodedInst (f : Unit → InstInfoT m Expr) : InstInfoT m Expr := do + let info ← get + match info.decodedInst? with + | .value val => return val + | .notYetComputed => + let val ← f () + set {info with decodedInst? := .value val} + return val + +/-- Return `InstInfo.instSemantics?` from the state if it is `some _`, +or use `f` to compute the relevant expressions if they are missing -/ +def getInstSemantics (f : Unit → InstInfoT m (Expr × Expr × Expr)) : + InstInfoT m (Expr × Expr × Expr) := do + let info ← get + match info.instSemantics? with + | .value val => return val + | .notYetComputed => + let val ← f () + set {info with instSemantics? := .value val} + return val -def ProgramInfo.getRawInstrAt? 
(pi : ProgramInfo) (addr : BitVec 64) : +end InstInfoT + +def InstInfo.ofRawInst (rawInst : BitVec 32) : InstInfo := + { rawInst } + +-------------------------------------------------------------------------------- + +namespace ProgramInfo + +/-- The expression `mkConst pi.name`, +i.e., an expression of this program referred to by name -/ +def expr (pi : ProgramInfo) : Expr := mkConst pi.name + +def getInstInfoAt? (pi : ProgramInfo) (addr : BitVec 64) : + Option InstInfo := + pi.instructions.find? addr + +def getRawInstAt? (pi : ProgramInfo) (addr : BitVec 64) : Option (BitVec 32) := - pi.rawProgram.find? addr + (·.rawInst) <$> pi.getInstInfoAt? addr -/-- Given an `Expr` of type `Program`, generate the basic `ProgramInfo` -/ -partial def ProgramInfo.generateFromExpr (e : Expr) : MetaM ProgramInfo := do +-- TODO: this instance could be upstreamed (after cleaning it up) +instance [BEq α] [Hashable α] : ForIn m (HashMap α β) (α × β) where + forIn map acc f := do + let f := fun (acc : ForInStep _) key val => do + match acc with + | .yield acc => f ⟨key, val⟩ acc + | .done _ => return acc + match ← map.foldM f (ForInStep.yield acc) with + | .done x | .yield x => return x + +/-! 
## ProgramInfo Generation -/ + +/-- Given the name and defining expression of a `Program`, +generate the basic `ProgramInfo` -/ +partial def generateFromExpr (name : Name) (e : Expr) : MetaM ProgramInfo := do + trace[ProgramInfo] "Generating program info for `{name}` from definition:\n\t{e}" let type ← inferType e if !(←isDefEq type (mkConst ``Program)) then throwError "type mismatch: {e} {← mkHasTypeButIsExpectedMsg type (mkConst ``Program)}" - let rec go (rawProgram : HashMap _ _) (e : Expr) : MetaM (HashMap _ _) := do + let rec go (instructions : HashMap _ _) (e : Expr) : MetaM (HashMap _ _) := do let e ← whnfD e match_expr e with | List.cons _ hd tl => do + trace[ProgramInfo] "found address/instruction pair: {hd}" + let hd' ← reduce hd let_expr Prod.mk _ _ addr inst := hd' | throwError "expected `{hd}` to reduce to an application of `Prod.mk`, found:\n\t{hd'}" - let addr ← reflectBitVecLiteral 64 addr - let inst ← reflectBitVecLiteral 32 inst + let addr ← reflectBitVecLiteral 64 (← instantiateMVars addr) + let rawInst ← reflectBitVecLiteral 32 (← instantiateMVars inst) + let rawProgram := + let info := InstInfo.ofRawInst rawInst + instructions.insert addr info - let rawProgram := rawProgram.insert addr inst go rawProgram tl - | List.nil _ => return rawProgram + | List.nil _ => return instructions | _ => throwError "expected `List.cons _ _` or `List.nil`, found:\n\t{e}" return { - rawProgram := ← go ∅ e + name, + instructions := ← go ∅ e } /-- Given the `Name` of a constant of type `Program`, generate the basic `ProgramInfo` -/ -def ProgramInfo.generateFromConstName (program : Name) : MetaM ProgramInfo := do +def generateFromConstName (program : Name) : MetaM ProgramInfo := do let .defnInfo defnInfo ← getConstInfo program | throwError "expected a definition, but {program} is not" - generateFromExpr defnInfo.value + generateFromExpr program defnInfo.value /-! 
## Env Extension -/ -initialize programInfoExt : PersistentEnvExtension (Name × ProgramInfo) (Name × ProgramInfo) (NameMap ProgramInfo) ← +initialize programInfoExt : PersistentEnvExtension (ProgramInfo) (ProgramInfo) (NameMap ProgramInfo) ← registerPersistentEnvExtension { name := `programInfo mkInitial := pure {} addImportedFn := fun _ _ => pure {} - addEntryFn := fun s p => s.insert p.1 p.2 + addEntryFn := fun s p => s.insert p.name p exportEntriesFn := fun m => - let r : Array (Name × _) := m.fold (fun a n p => a.push (n, p)) #[] - r.qsort (fun a b => Name.quickLt a.1 b.1) + let r : Array (ProgramInfo) := + m.fold (fun a name p => a.push {p with name}) #[] + r.qsort (fun a b => Name.quickLt a.name b.name) statsFn := fun s => "program info extension" ++ Format.line ++ "number of local entries: " ++ format s.size } -/-- store a `PogramInfo` for the given `program` in the environment -/ -private def ProgramInfo.store [Monad m] [MonadEnv m] - (program : Name) (pi : ProgramInfo) : m Unit := do - modifyEnv (programInfoExt.addEntry · ⟨program, pi⟩) +/-- persistently store a `ProgramInfo` in the environment -/ +def persistToEnv [Monad m] [MonadEnv m] (pi : ProgramInfo) : m Unit := do + modifyEnv (programInfoExt.addEntry · pi) /-- look up the `ProgramInfo` for a given `program` in the environment, returns `None` if not found -/ -def ProgramInfo.lookup? [Monad m] [MonadEnv m] (program : Name) : +def lookup? [Monad m] [MonadEnv m] (program : Name) : m (Option ProgramInfo) := do let env ← getEnv let state := programInfoExt.getState env return state.find? program /-- look up the `ProgramInfo` for a given `program` in the environment, -or, if none was found, generate (and cache) new program info -/ -def ProgramInfo.lookupOrGenerate (program : Name) : MetaM ProgramInfo := do +or, if none was found, generate new program info. + +If you pass in a value for `expr?`, that is assumed to be the definition for +`program` when generating new program info. 
+If you don't pass in an expr, the definition is found in the environment + +If `persist` is set to true (the default), then the newly generated program info +will be persistently cached in the environment (see `persistToEnv`) -/ +def lookupOrGenerate (program : Name) (expr? : Option Expr := none) + (persist : Bool := true) : + MetaM ProgramInfo := do if let some pi ← lookup? program then return pi else - let pi ← generateFromConstName program - store program pi + let pi ← match expr? with + | some expr => generateFromExpr program expr + | none => generateFromConstName program + if persist then + persistToEnv pi return pi + +end ProgramInfo + +/-! ## `ProgramInfoT` Monad Transformer -/ + +/-- A monad transformer with `ProgramInfo` state -/ +abbrev ProgramInfoT (m : Type → Type) := StateT ProgramInfo m + +namespace ProgramInfoT +variable [Monad m] [MonadEnv m] [MonadError m] + +/-! ### run -/ +section Run +variable [MonadLiftT MetaM m] + +protected def run' (programName : Name) (expr? : Option Expr) (persist : Bool) + (k : ProgramInfoT m α) : m α := do + let pi ← ProgramInfo.lookupOrGenerate programName expr? + let ⟨a, pi⟩ ← StateT.run k pi + if persist then + pi.persistToEnv + return a + +/-- run a `ProgramInfoT m` by looking up, or generating new program info, +by name. + +If `persist` is set to true, then the program info state after +executing `k` will be persistently cached in the environment +(see `persistToEnv`). -/ +def run (programName : Name) (k : ProgramInfoT m α) + (persist : Bool := false) : + m α := + ProgramInfoT.run' programName none persist k + +/-- run a `ProgramInfoT m` by looking up, or generating new program info. +The passed expression is assumed to be the definition of the program. + +If `persist` is set to true (the default), then the program info state after +executing `k` will be persistently cached in the environment +(see `persistToEnv`). 
-/ +def runE (programName : Name) (expr : Expr) (k : ProgramInfoT m α) + (persist : Bool := false) + : m α := + ProgramInfoT.run' programName expr persist k + +end Run + +/-! ### MonadError instance -/ + +instance [Monad m] [i : MonadError m] : MonadError (ProgramInfoT m) where + throw e := i.throw e + tryCatch k f := fun s => i.tryCatch (k s) (fun e => f e s) + getRef := i.getRef + withRef stx k := fun s => i.withRef stx (k s) + add stx msg := i.add stx msg + +/-! ### Wrappers -/ + +/-- Access the info for the instruction at a given address, +or throw an error if none is found -/ +def getInstInfoAt (addr : BitVec 64) : ProgramInfoT m InstInfo := do + let some x := (← StateT.get).getInstInfoAt? addr + | let addr := addr.toHexWithoutLeadingZeroes + throwError "No instruction found at address {addr}" + return x + +/-- Set the instruction info for a particular address -/ +def setInstInfoAt (addr : BitVec 64) (info : InstInfo) : + ProgramInfoT m Unit := do + let pi ← StateT.get + StateT.set {pi with instructions := pi.instructions.insert addr info} + +/-- Run `k` with the instruction info for the given address as initial state, +and store the resulting state at that same address. +Returns the value produced by `k`. Throws an error if the address is invalid -/ +def modifyInstInfoAt (addr : BitVec 64) (k : InstInfoT m α) : + ProgramInfoT m α := do + let info ← getInstInfoAt addr + let ⟨val, info⟩ ← monadLift (StateT.run k info) + setInstInfoAt addr info + return val + +/-! ### InstInfo Accessors -/ + +def getRawInstAt (addr : BitVec 64) : ProgramInfoT m (BitVec 32) := do + return (← getInstInfoAt addr).rawInst + +/-- if `decodedInst?` is `some _` for the instruction info at the given address, +return the cached value. +Otherwise, use `k` to compute the decoded instruction, then +cache and return that new value. +See `InstInfo.decodInst?` for the meaning of this field. 
+ +NOTE: the computed value is only cached in the `ProgramInfoT` monad state, +not yet in the environment. -/ +def getDecodedInstAt (addr : BitVec 64) (k : InstInfo → ProgramInfoT m Expr) : + ProgramInfoT m Expr := do + let info ← getInstInfoAt addr + match info.decodedInst? with + | .value e => return e + | .notYetComputed => + let decodedInst ← k info + setInstInfoAt addr {info with decodedInst? := .value decodedInst} + return decodedInst + +/-- if `instSemantics?` is `.value _` for the instruction info at the given address, +return the cached value. +Otherwise (`.notYetComputed`), use `k` to compute the instruction semantics, then +cache and return that new value. +See `InstInfo.instSemantics?` for the meaning of this field. + +NOTE: the computed value is only cached in the `ProgramInfoT` monad state, +not yet in the environment. -/ +def getInstSemanticsAt (addr : BitVec 64) + (k : InstInfo → ProgramInfoT m (Expr × Expr × Expr)) : + ProgramInfoT m (Expr × Expr × Expr) := do + let info ← getInstInfoAt addr + match info.instSemantics? with + | .value e => return e + | .notYetComputed => + let instSemantics ← k info + setInstInfoAt addr {info with instSemantics? := .value instSemantics} + return instSemantics + + +end ProgramInfoT diff --git a/Tactics/StepThms.lean b/Tactics/StepThms.lean index 2e2f14b0..3f9eb8df 100644 --- a/Tactics/StepThms.lean +++ b/Tactics/StepThms.lean @@ -1,10 +1,18 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. 
+Author(s): Shilpi Goel, Alex Keizer +-/ import Lean import Arm.Map import Arm.Decode import Tactics.Common import Tactics.Simp import Tactics.ChangeHyps +import Tactics.Reflect.ProgramInfo + open Lean Lean.Expr Lean.Meta Lean.Elab Lean.Elab.Command +open SymContext (h_pc_type h_program_type h_err_type) -- NOTE: This is an experimental and probably quite shoddy method of autogenerating -- `stepi` theorems from a program under verification, and things may change @@ -33,11 +41,194 @@ initialize registerTraceClass `gen_step.print_names initialize registerTraceClass `gen_step.debug /- When true, prints the number of heartbeats taken per theorem. -/ initialize registerTraceClass `gen_step.debug.heartBeats +/- When true, prints the time taken at various steps of generation. -/ +initialize registerTraceClass `gen_step.debug.timing + +/-- Assuming that `rawInst` is indeed the right result, construct a proof that + `fetch_inst addr state = some rawInst` +given that `state.program = program` -/ +private def fetchLemma (state program h_program : Expr) + (addr : BitVec 64) (rawInst : BitVec 32) : Expr := + let someRawInst := toExpr (some rawInst) + mkAppN (mkConst ``fetch_inst_eq_of_prgram_eq_of_map_find) #[ + state, + program, + toExpr addr, + someRawInst, + h_program, + mkApp2 (.const ``Eq.refl [1]) + (mkApp (.const ``Option [0]) <| + mkApp (.const ``BitVec []) (toExpr 32)) + someRawInst + ] + +-- /-! ## `reduceDecodeInst` -/ + +/-- `canonicalizeBitVec e` recursively walks over expression `e` to convert any +occurrences of: + `BitVec.ofFin w (Fin.mk x _)` +to the canonical form: + `BitVec.ofNat w x` (i.e., `x#w`) + +Such expressions tend to result from using `reduce` or +`simp` with `{ground := true}`. +You can call `canonicalizeBitVec` after these functions to ensure you don't +needlessly expose `BitVec` internal details -/ +-- TODO: should this canonicalize to `BitVec.ofNatLt` instead, +-- as the current transformation loses information? 
+partial def canonicalizeBitVec (e : Expr) : MetaM Expr := do + match_expr e with + | BitVec.ofFin w i => + let_expr Fin.mk _ x _h := i | fallback + let w ← + if w.hasFVar || w.hasMVar then + pure w + else + withTransparency .all <| reduce w + -- ^^ NOTE: potentially expensive reduction + return mkApp2 (mkConst ``BitVec.ofNat) w x + | _ => fallback + where + fallback : MetaM Expr := do + let fn := e.getAppFn + let args ← e.getAppArgs.mapM canonicalizeBitVec + return mkAppN fn args + +/-- Given an expr `rawInst` of type `BitVec 32`, +return an expr of type `Option ArmInst` representing what `rawInst` decodes to. +The resulting expr is guaranteed to be def-eq to `decode_raw_inst $rawInst` -/ +def reduceDecodeInstExpr (rawInst : Expr) : MetaM Expr := do + let expr := mkApp (mkConst ``decode_raw_inst) rawInst + let expr ← withTransparency .all <| reduce expr + -- ^^ NOTE: possibly expensive reduction + canonicalizeBitVec expr + +/-! ## SymM Monad -/ + +abbrev SymM.CacheKey := BitVec 32 +abbrev SymM.CacheM := MonadCacheT CacheKey Expr MetaM +abbrev SymM := ProgramInfoT <| MonadCacheT SymM.CacheKey Expr MetaM + +@[inherit_doc ProgramInfoT.run] +abbrev SymM.run (name : Name) (k : SymM α) (persist : Bool := true) : MetaM α := + MonadCacheT.run <| ProgramInfoT.run name k persist + +open SymM in +/-- Given a (reflected) raw instruction, +return an expr of type `Option ArmInst` representing what `rawInst` decodes to. +The resulting expr is guaranteed to be def-eq to `decode_raw_inst $rawInst`. + +Results are cached so that the same instruction is not reduced multiple times -/ +def reduceDecodeInst (rawInst : BitVec 32) : CacheM Expr := + checkCache (rawInst) fun _ => + reduceDecodeInstExpr (toExpr rawInst) + +open ProgramInfoT InstInfoT + +/-! 
## reduceStepiToExecInst -/ + +/-- Given a program and an address, and optionally the corresponding +raw and decoded instructions, construct and return first the expression: +``` +∀ {s} (h_program : s.program = ) (h_pc : r .PC s = ) + (h_err : r .ERR s = .None), + stepi s = s`> +``` +and then a proof of this fact. +That is, in + `let ⟨type, value⟩ ← reduceStepi ...` +`value` is an expr whose type is `type` -/ +def reduceStepi (addr : BitVec 64) : SymM (Expr × Expr) := do + let pi : ProgramInfo ← get + let ⟨_, type, proof⟩ ← modifyInstInfoAt addr <| getInstSemantics fun _ => do + let rawInst ← getRawInst + + let inst ← getDecodedInst <| fun _ => do + let optInst ← reduceDecodeInst rawInst + let_expr some _ inst := optInst + | let some := mkConst ``Option.some [1] + throwError "Expected an application of {some}, found:\n\t{optInst}" + pure inst + + withLocalDecl `s .implicit (mkConst ``ArmState) <| fun s => + withLocalDeclD `h_program (h_program_type s pi.expr) <| fun h_program => + withLocalDeclD `h_pc (h_pc_type s (toExpr addr)) <| fun h_pc => + withLocalDeclD `h_err (h_err_type s) <| fun h_err => do + let h_fetch := fetchLemma s pi.expr h_program addr rawInst + let h_decode := + let armInstTy := mkConst ``ArmInst + mkApp2 (mkConst ``Eq.refl [1]) + (mkApp (mkConst ``Option [0]) armInstTy) + (mkApp2 (mkConst ``Option.some [0]) armInstTy inst) + + let proof := -- stepi s = exec_inst s + mkAppN (mkConst ``stepi_eq_of_fetch_inst_of_decode_raw_inst) #[ + s, toExpr addr, toExpr rawInst, inst, + h_err, h_pc, h_fetch, h_decode + ] + let type ← inferType proof + + let (ctx, simprocs) ← do + let localDecls ← do + let hs := #[h_pc, h_err] + pure <| hs.filterMap (← getLCtx).findFVar? 
+ LNSymSimpContext + (config := {decide := true, ground := false}) + (simp_attrs := #[`minimal_theory, `bitvec_rules, `state_simp_rules]) + (decls := localDecls) + (decls_to_unfold := #[``exec_inst]) + + let ⟨simpRes, _⟩ ← simp type ctx simprocs + + let_expr Eq _ _ sem := simpRes.expr + | let eq ← mkEq (← mkFreshExprMVar none) (← mkFreshExprMVar none) + throwError "Failed to normalize instruction semantics. Expected {eq}, but found:\n\t{simpRes.expr}" + let sem ← mkLambdaFVars #[s] sem + + let proof ← simpRes.mkCast proof -- stepi s = + let hs := #[s, h_program, h_pc, h_err] + let proof ← mkLambdaFVars hs proof + let type ← mkForallFVars hs simpRes.expr + return ⟨sem, type, proof⟩ + return ⟨type, proof⟩ + +/-- For each `(addr, instInfo)` pair in the instructions of the current `ProgramInfo`, +call `reduceStepi addr` and add the resulting `(type, value)` to the environment as a +theorem named `pi.name` extended with `stepi_eq_0x` and the address in hex +(without leading zeroes). Timing is reported under `gen_step.debug.timing`. -/ +def genStepEqTheorems : SymM Unit := do + let pi ← get + for ⟨addr, instInfo⟩ in pi.instructions do + let startTime ← IO.monoMsNow + let inst := instInfo.rawInst + + trace[gen_step.debug] "[genStepEqTheorems] Generating theorem for address {addr.toHex} with instruction {inst.toHex}" + let name := let addr_str := addr.toHexWithoutLeadingZeroes + Name.str pi.name ("stepi_eq_0x" ++ addr_str) + let ⟨type, value⟩ ← reduceStepi addr + + trace[gen_step.debug.timing] "[genStepEqTheorems] reduced in: {(← IO.monoMsNow) - startTime}ms" + addDecl <| Declaration.thmDecl { + name, type, value, + levelParams := [] + } + trace[gen_step.debug.timing] "[genStepEqTheorems] added to environment in: {(← IO.monoMsNow) - startTime}ms" + +/-- `#genProgramInfo program` ensures the `ProgramInfo` for `program` +has been generated and persistently cached in the environment -/ +elab "#genProgramInfo" program:ident : command => liftTermElabM do + let _ ← ProgramInfo.lookupOrGenerate program.getId + + +elab "#genStepEqTheorems" program:term : command => liftTermElabM do + let .const name _ ← Elab.Term.elabTerm program (mkConst ``Program) + | throwError "Expected a constant, found: {program}" + + SymM.run name (persist := true) <| + genStepEqTheorems + /- Generate and prove a fetch theorem of 
the following form: ``` -theorem (<program_name> ++ "fetch_0x" ++ <address_str>) (s : ArmState) - (h : s.program = <program_name>) : fetch_inst <address> s = some <raw_inst> +theorem <program_name>.("fetch_0x" ++ <address_str>
) (s : ArmState) + (h : s.program = <program_name>) : fetch_inst <address>
s = some <raw_inst> ``` -/ def genFetchTheorem (program_name : Name) (address_str : String) @@ -110,7 +301,7 @@ def genFetchTheorem (program_name : Name) (address_str : String) /- Generate and prove an exec theorem of the following form: ``` -theorem (<program_name> ++ "exec_0x" ++ <address_str>) (s : ArmState) : +theorem <program_name>.("exec_0x" ++ <address_str>) (s : ArmState) : exec_inst <decoded_inst> s = <simplified_semantics> ``` -/ @@ -359,6 +550,8 @@ def test_program : Program := (0x126514#64 , 0x4ea21c5c#32), -- mov v28.16b, v2.16b (0x126518#64 , 0x4ea31c7d#32)] -- mov v29.16b, v3.16b +#genStepEqTheorems test_program + #genStepTheorems test_program thmType:="fetch" /-- info: test_program.fetch_0x126510 (s : ArmState) (h : s.program = test_program) : @@ -402,6 +595,14 @@ info: test_program.stepi_0x126510 (s sn : ArmState) (h_program : s.program = tes #guard_msgs in #check test_program.stepi_0x126510 +/-- +info: test_program.stepi_eq_0x126510 {s : ArmState} (h_program : s.program = test_program) + (h_pc : r StateField.PC s = 1205520#64) (h_err : r StateField.ERR s = StateError.None) : + stepi s = w StateField.PC (1205524#64) (w (StateField.SFP 27#5) (r (StateField.SFP 1#5) s) s) +-/ +#guard_msgs in +#check test_program.stepi_eq_0x126510 + -- Here's the theorem that we'd actually like to obtain instead of the -- erstwhile test_stepi_0x126510. 
theorem test_stepi_0x126510_desired (s sn : ArmState) diff --git a/Tactics/Sym.lean b/Tactics/Sym.lean index fd147875..4ea216ec 100644 --- a/Tactics/Sym.lean +++ b/Tactics/Sym.lean @@ -103,18 +103,15 @@ def stepiTac (h_step : Ident) (ctx : SymContext) : TacticM Unit := withMainContext do let pc := (Nat.toDigits 16 ctx.pc.toNat).asString -- ^^ The PC in hex - let step_lemma := mkIdent <| Name.str ctx.program s!"stepi_0x{pc}" + let step_lemma := mkIdent <| Name.str ctx.program s!"stepi_eq_0x{pc}" evalTacticAndTrace <|← `(tactic| ( replace $h_step := - (propext_iff.mp + _root_.Eq.trans $h_step ($step_lemma:ident - _ - $ctx.next_state_ident:ident $ctx.h_program_ident:ident $ctx.h_pc_ident:ident - $ctx.h_err_ident:ident)).mp - $h_step + $ctx.h_err_ident:ident) )) elab "stepi_tac" h_step:ident : tactic => do diff --git a/Tactics/SymContext.lean b/Tactics/SymContext.lean index f5cf1b42..42b040bd 100644 --- a/Tactics/SymContext.lean +++ b/Tactics/SymContext.lean @@ -80,18 +80,6 @@ structure SymContext where curr_state_number : Nat := 0 deriving Repr -/-- `h_err_type state` returns an Expr representing `r state = .None`, -the expected type of `h_err` -/ -private def h_err_type (state : Expr) : MetaM Expr := - mkEq - (mkApp2 (.const ``r []) (.const ``StateField.ERR []) state) - (.const ``StateError.None []) - -/-- `h_sp_type state` returns an Expr representing `CheckSPAlignment state`, -the expected type of `h_sp` -/ -private def h_sp_type (state : Expr) : Expr := - mkApp (.const ``CheckSPAlignment []) state - namespace SymContext /-! ## Creating initial contexts -/ @@ -131,7 +119,7 @@ def addGoalsForMissingHypotheses (ctx : SymContext) : TacticM SymContext := let newGoal ← mkFreshMVarId goal := ← do - let goalType ← h_err_type stateExpr + let goalType := h_err_type stateExpr let newGoalExpr ← mkFreshExprMVarWithId newGoal goalType let goal' ← goal.assert h_err? goalType newGoalExpr let ⟨_, goal'⟩ ← goal'.intro1P @@ -225,9 +213,7 @@ def fromLocalContext (state? 
: Option Name) : MetaM SymContext := do -- Then, try to find `h_pc` let pc ← mkFreshExprMVar (← mkAppM ``BitVec #[toExpr 64]) - let h_pc_type ← do - let lhs ← mkAppM ``r #[(.const ``StateField.PC []), stateExpr] - mkEq lhs pc + let h_pc_type := h_pc_type stateExpr pc let h_pc ← findLocalDeclUsernameOfTypeOrError h_pc_type -- Unwrap and reflect `pc` @@ -235,9 +221,9 @@ def fromLocalContext (state? : Option Name) : MetaM SymContext := do let pc ← withErrorContext h_pc h_pc_type <| reflectBitVecLiteral 64 pc -- Attempt to find `h_err` and `h_sp` - let h_err? ← findLocalDeclUsernameOfType? (←h_err_type stateExpr) + let h_err? ← findLocalDeclUsernameOfType? (h_err_type stateExpr) if h_err?.isNone then - trace[Sym] "Could not find local hypothesis of type {←h_err_type stateExpr}" + trace[Sym] "Could not find local hypothesis of type {h_err_type stateExpr}" let h_sp? ← findLocalDeclUsernameOfType? (h_sp_type stateExpr) if h_sp?.isNone then trace[Sym] "Could not find local hypothesis of type {h_sp_type stateExpr}"