Skip to content

Commit

Permalink
Removing extractLsb from bitvec library and memory aliasing proofs
Browse files Browse the repository at this point in the history
  • Loading branch information
pennyannn committed Sep 27, 2024
1 parent 313b638 commit 2ef760f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 67 deletions.
86 changes: 41 additions & 45 deletions Arm/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,7 @@ attribute [bitvec_rules] BitVec.getLsbD_truncate
attribute [bitvec_rules] BitVec.zeroExtend_zeroExtend_of_le
attribute [bitvec_rules] BitVec.truncate_truncate_of_le
attribute [bitvec_rules] BitVec.truncate_cast
attribute [bitvec_rules] BitVec.extractLsb_ofFin
attribute [bitvec_rules] BitVec.extractLsb_ofNat
attribute [bitvec_rules] BitVec.extractLsb'_toNat
attribute [bitvec_rules] BitVec.extractLsb_toNat
attribute [bitvec_rules] BitVec.getLsbD_extract
attribute [bitvec_rules] BitVec.toNat_allOnes
attribute [bitvec_rules] BitVec.getLsbD_allOnes
Expand Down Expand Up @@ -546,17 +543,16 @@ theorem zeroExtend_if_false [Decidable p] (x : BitVec n)
(zeroExtend (if p then a else b) x) = BitVec.cast h_eq (zeroExtend b x) := by
simp only [toNat_eq, toNat_truncate, ← h_eq, toNat_cast]

theorem extractLsb_eq (x : BitVec n) (h : n = n - 1 + 1) :
BitVec.extractLsb (n - 1) 0 x = BitVec.cast h x := by
unfold extractLsb extractLsb'
ext1
simp [←h]
theorem extractLsb'_eq (x : BitVec n) :
BitVec.extractLsb' 0 n x = x := by
unfold extractLsb'
simp only [Nat.shiftRight_zero, ofNat_toNat, zeroExtend_eq]

@[bitvec_rules]
protected theorem extract_lsb_of_zeroExtend (x : BitVec n) (h : j < i) :
extractLsb' 0 (j + 1) (zeroExtend i x) = zeroExtend (j + 1) x := by
protected theorem extractLsb'_of_zeroExtend (x : BitVec n) (h : j i) :
extractLsb' 0 j (zeroExtend i x) = zeroExtend j x := by
apply BitVec.eq_of_getLsbD_eq
simp
-- simp
intro k
have q : k < i := by omega
by_cases h : decide (k ≤ j) <;> simp [q, h]
Expand Down Expand Up @@ -634,12 +630,11 @@ theorem append_of_extract_general_nat (high low n vn : Nat) (h : vn < 2 ^ n) :
done

theorem append_of_extract (n : Nat) (v : BitVec n)
(high0 : high = n - low) (low0 : 1 <= low)
(h : high + (low - 1 - 0 + 1) = n) :
BitVec.cast h (zeroExtend high (v >>> low) ++ extractLsb (low - 1) 0 v) = v := by
(high0 : high = n - low) (h : high + low = n) :
BitVec.cast h (zeroExtend high (v >>> low) ++ extractLsb' 0 low v) = v := by
ext
subst high
have vlt := v.isLt; simp_all only [Nat.sub_zero]
have vlt := v.isLt
have := append_of_extract_general_nat (n - low) low n (BitVec.toNat v) vlt
have low_le : low ≤ n := by omega
simp_all [toNat_zeroExtend, Nat.sub_add_cancel, low_le]
Expand All @@ -649,17 +644,13 @@ theorem append_of_extract (n : Nat) (v : BitVec n)
exact Nat.lt_of_le_of_lt (Nat.div_le_self _ _) vlt
done

theorem append_of_extract_general (v : BitVec n)
(low0 : 1 <= low)
(h1 : high = width)
(h2 : (high + low - 1 - 0 + 1) = (width + (low - 1 - 0 + 1))) :
BitVec.cast h1 (zeroExtend high (v >>> low)) ++ extractLsb (low - 1) 0 v =
BitVec.cast h2 (extractLsb (high + low - 1) 0 v) := by
theorem append_of_extract_general (v : BitVec n) :
(zeroExtend high (v >>> low)) ++ extractLsb' 0 low v =
extractLsb' 0 (high + low) v := by
ext
have := append_of_extract_general_nat high low n (BitVec.toNat v)
have h_vlt := v.isLt; simp_all only [Nat.sub_zero, h1]
simp only [h_vlt, h1, forall_prop_of_true] at this
have low' : 1 ≤ width + low := Nat.le_trans low0 (Nat.le_add_left low width)
have h_vlt := v.isLt; simp_all only [Nat.sub_zero]
simp only [h_vlt, forall_prop_of_true] at this
simp_all [toNat_zeroExtend, Nat.sub_add_cancel]
rw [Nat.mod_eq_of_lt (b := 2 ^ n)] at this
· rw [this]
Expand Down Expand Up @@ -788,7 +779,7 @@ def genBVPatMatchTest (var : Term) (pat : BVPat) : MacroM Term := do
for c in pat.getComponents do
let len := c.length
if let some bv ← c.toBVLit? then
let test ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var == $bv)
let test ← `(extractLsb' $(quote shift) $(quote len) $var == $bv)
result ← `($result && $test)
shift := shift + len
return result
Expand All @@ -810,7 +801,7 @@ def declBVPatVars (var : Term) (pat : BVPat) (rhs : Term) : MacroM Term := do
for c in pat.getComponents do
let len := c.length
if let some y ← c.toBVVar? then
let rhs ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var)
let rhs ← `(extractLsb' $(quote shift) $(quote len) $var)
result ← `(let $y := $rhs; $result)
shift := shift + len
return result
Expand Down Expand Up @@ -934,31 +925,36 @@ Definition to extract the `n`th least significant *Byte* from a bitvector.
TODO: this should be named `getLsByte`, or `getLsbByte` (Shilpi prefers this).
-/
def extractLsByte (val : BitVec w₁) (n : Nat) : BitVec 8 :=
val.extractLsb ((n + 1) * 8 - 1) (n * 8) |> .cast (by omega)
val.extractLsb' (n * 8) 8

theorem extractLsByte_def (val : BitVec w₁) (n : Nat) :
val.extractLsByte n = (val.extractLsb ((n + 1)*8 - 1) (n * 8) |>.cast (by omega)) := rfl
val.extractLsByte n = val.extractLsb' (n * 8) 8 := rfl

-- TODO: upstream
theorem extractLsb_or (x y : BitVec w₁) (n : Nat) :
(x ||| y).extractLsb n lo = (x.extractLsb n lo ||| y.extractLsb n lo) := by
theorem extractLsb'_or (x y : BitVec w₁) (n : Nat) :
(x ||| y).extractLsb' lo n = (x.extractLsb' lo n ||| y.extractLsb' lo n) := by
apply BitVec.eq_of_getLsbD_eq
simp only [getLsbD_extract, getLsbD_or]
intros i
by_cases h : (i : Nat) ≤ n - lo
· simp only [h, decide_True, Bool.true_and]
· simp only [h, decide_False, Bool.false_and, Bool.or_self]
simp only [getLsbD_extractLsb', Fin.is_lt, decide_True, getLsbD_or, Bool.true_and]

-- TODO: upstream
protected theorem extractLsb'_ofNat (x n : Nat) (l lo : Nat) :
extractLsb' lo l (BitVec.ofNat n x) = .ofNat l ((x % 2^n) >>> lo) := by
apply eq_of_getLsbD_eq
intro ⟨i, _lt⟩
simp [BitVec.ofNat]

theorem extractLsByte_zero {w : Nat} : (0#w).extractLsByte i = 0#8 := by
simp only [extractLsByte, BitVec.extractLsb_ofNat, Nat.zero_mod, Nat.zero_shiftRight, cast_ofNat]
simp only [extractLsByte, BitVec.extractLsb'_ofNat, Nat.zero_mod, Nat.zero_shiftRight, cast_ofNat]

theorem extractLsByte_ge (h : 8 * a ≥ w₁) (x : BitVec w₁) :
x.extractLsByte a = 0#8 := by
apply BitVec.eq_of_getLsbD_eq
intros i
simp only [getLsbD_zero, extractLsByte_def,
getLsbD_cast, getLsbD_extract, Bool.and_eq_false_imp, decide_eq_true_eq]
intros _
simp only [getLsbD_extractLsb', Fin.is_lt, decide_True, Bool.true_and]
apply BitVec.getLsbD_ge
omega

Expand All @@ -967,10 +963,13 @@ theorem getLsbD_extractLsByte (val : BitVec w₁) :
((BitVec.extractLsByte val n).getLsbD i) =
(decide (i ≤ 7) && val.getLsbD (n * 8 + i)) := by
simp only [extractLsByte, getLsbD_cast, getLsbD_extract]
rw [Nat.succ_mul]
simp only [Nat.add_one_sub_one,
Nat.add_sub_cancel_left]

simp only [getLsbD_extractLsb']
generalize val.getLsbD (n * 8 + i) = x
by_cases h : i < 8
· simp only [show (i : Nat) ≤ 7 by omega, decide_True, Bool.true_and,
Bool.and_iff_right_iff_imp, decide_eq_true_eq, h]
· simp only [show ¬(i : Nat) ≤ 7 by omega, decide_False, Bool.false_and,
Bool.and_eq_false_imp, decide_eq_true_eq, h]

/--
Two bitvectors of length `n*8` are equal if all their bytes are equal.
Expand All @@ -994,9 +993,7 @@ theorem eq_of_extractLsByte_eq (x y : BitVec (n * 8))
@bollu: it's not clear if the definition for n=0 is desirable.
-/
def extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) : BitVec (n * 8) :=
match h : n with
| 0 => 0#0
| x + 1 => val.extractLsb (base * 8 + n * 8 - 1) (base * 8) |>.cast (by omega)
extractLsb' (base * 8) (n * 8) val

@[bitvec_rules]
theorem getLsbD_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) :
Expand All @@ -1009,10 +1006,9 @@ theorem getLsbD_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat)
simp only [show ¬i < 0 by omega, decide_False, Bool.false_and]
· simp only [extractLsBytes, getLsbD_cast, getLsbD_extract, Nat.zero_lt_succ, decide_True,
Bool.true_and]
simp only [show base * 8 + (n + 1) * 8 - 1 - base * 8 = (n + 1) * 8 - 1 by omega]
by_cases h : i < (n + 1) * 8
· simp only [show i ≤ (n + 1) * 8 - 1 by omega, decide_True, Bool.true_and, h]
· simp only [show ¬(i ≤ (n + 1) * 8 - 1) by omega, decide_False, Bool.false_and, h]
· simp only [getLsbD_extractLsb', h, decide_True, Bool.true_and]
· simp only [getLsbD_extractLsb', h, decide_False, Bool.false_and]

theorem extractLsByte_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) :
(BitVec.extractLsBytes val base n).extractLsByte i =
Expand Down
35 changes: 15 additions & 20 deletions Arm/Memory/MemoryProofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,10 @@ theorem read_mem_of_write_mem_bytes_different (hn1 : n <= 2^64)

theorem append_byte_of_extract_rest_same_cast (n : Nat) (v : BitVec ((n + 1) * 8))
(hn0 : Nat.succ 0 ≤ n)
(h : (n * 8 + (7 - 0 + 1)) = (n + 1) * 8) :
BitVec.cast h (zeroExtend (n * 8) (v >>> 8) ++ extractLsb 7 0 v) = v := by
(h : (n * 8 + 8) = (n + 1) * 8) :
BitVec.cast h (zeroExtend (n * 8) (v >>> 8) ++ extractLsb' 0 8 v) = v := by
apply BitVec.append_of_extract
· omega
· omega
· omega
done

@[state_simp_rules]
Expand All @@ -85,7 +83,7 @@ theorem read_mem_bytes_of_write_mem_bytes_same (hn1 : n <= 2^64) :
case base =>
simp only [read_mem_bytes, write_mem_bytes,
read_mem_of_write_mem_same, BitVec.cast_eq]
have l1 := BitVec.extractLsb_eq v
have l1 := BitVec.extractLsb'_eq v
simp only [Nat.reduceSucc, Nat.one_mul, Nat.succ_sub_succ_eq_sub,
Nat.sub_zero, Nat.reduceAdd, BitVec.cast_eq,
forall_const] at l1
Expand Down Expand Up @@ -431,7 +429,7 @@ private theorem write_mem_bytes_of_write_mem_bytes_shadow_general_n2_eq
rename_i n n_ih
conv in write_mem_bytes (Nat.succ n) .. => simp only [write_mem_bytes]
have n_ih' := @n_ih (addr1 + 1#64) val2 (zeroExtend (n * 8) (val1 >>> 8))
(write_mem addr1 (extractLsb 7 0 val1) s)
(write_mem addr1 (extractLsb' 0 8 val1) s)
(by omega)
simp only [Nat.succ_sub_succ_eq_sub, Nat.sub_zero] at h3
by_cases h₁ : n = 0
Expand Down Expand Up @@ -483,7 +481,7 @@ theorem write_mem_bytes_of_write_mem_bytes_shadow_general
theorem read_mem_of_write_mem_bytes_same_first_address
(h0 : 0 < n) (h1 : n <= 2^64) (h : 7 - 0 + 1 = 8) :
read_mem addr (write_mem_bytes n addr val s) =
BitVec.cast h (extractLsb 7 0 val) := by
BitVec.cast h (extractLsb' 0 8 val) := by
unfold write_mem_bytes; simp only [Nat.sub_zero, BitVec.cast_eq]
split
· contradiction
Expand All @@ -495,18 +493,16 @@ theorem read_mem_of_write_mem_bytes_same_first_address
-- (FIXME) Argh, it's annoying to need this lemma, but using
-- BitVec.cast_eq directly was cumbersome.
theorem cast_of_extract_eq (v : BitVec p)
(h1 : hi1 = hi2) (h2 : lo1 = lo2)
(h : hi1 - lo1 + 1 = hi2 - lo2 + 1) :
BitVec.cast h (extractLsb hi1 lo1 v) = (extractLsb hi2 lo2 v) := by
(h1 : n1 = n2) (h2 : lo1 = lo2):
BitVec.cast h (extractLsb' lo1 n1 v) = (extractLsb' lo2 n2 v) := by
subst_vars
simp only [Nat.sub_zero, BitVec.cast_eq]

theorem read_mem_bytes_of_write_mem_bytes_subset_same_first_address
(h0 : 0 < n1) (h1 : n1 <= 2^64) (h2 : 0 < n2) (h3 : n2 <= 2^64)
(h4 : mem_subset addr (addr + (BitVec.ofNat 64 (n2 - 1))) addr (addr + (BitVec.ofNat 64 (n1 - 1))))
(h : n2 * 8 - 1 - 0 + 1 = n2 * 8) :
(h4 : mem_subset addr (addr + (BitVec.ofNat 64 (n2 - 1))) addr (addr + (BitVec.ofNat 64 (n1 - 1)))):
read_mem_bytes n2 addr (write_mem_bytes n1 addr val s) =
BitVec.cast h (extractLsb ((n2 * 8) - 1) 0 val) := by
extractLsb' 0 (n2 * 8) val := by
have rm_lemma := @read_mem_of_write_mem_bytes_same_first_address n1 addr val s h0 h1
simp only [Nat.sub_zero, Nat.reduceAdd, BitVec.cast_eq, forall_const] at rm_lemma
induction n2, h2 using Nat.le_induction generalizing n1 addr val s
Expand Down Expand Up @@ -543,21 +539,20 @@ theorem read_mem_bytes_of_write_mem_bytes_subset_same_first_address
erw [Nat.mod_eq_of_lt h3] at hn
erw [Nat.mod_eq_of_lt h1] at hn
exact hn
rw [n_ih (by omega) (by omega) (by omega) _ (by omega)]
· rw [BitVec.extract_lsb_of_zeroExtend (v >>> 8)]
· have l1 := @BitVec.append_of_extract_general ((n1_1 + 1) * 8) 8 (n*8-1+1) (n*8) v
rw [n_ih (by omega) (by omega) (by omega) _]
· rw [BitVec.extractLsb'_of_zeroExtend (v >>> 8)]
· have l1 := @BitVec.append_of_extract_general ((n1_1 + 1) * 8) (n*8) 8 v
simp (config := { decide := true }) only [Nat.zero_lt_succ,
Nat.mul_pos_iff_of_pos_left, Nat.succ_sub_succ_eq_sub,
Nat.sub_zero, Nat.reduceAdd, Nat.succ.injEq, forall_const] at l1
rw [l1 (by omega) (by omega)]
· simp only [Nat.add_eq, Nat.sub_zero, BitVec.cast_cast]
apply @cast_of_extract_eq ((n1_1 + 1) * 8) (n * 8 - 1 + 1 + 7) ((n + 1) * 8 - 1) 0 0 <;>
rw [l1]
· apply @cast_of_extract_eq ((n1_1 + 1) * 8) (n * 8 + 8) ((n + 1) * 8) 0 0 <;>
omega
· omega
· have rw_lemma2 := @read_mem_of_write_mem_bytes_same_first_address
n1_1 (addr + 1#64)
(zeroExtend (n1_1 * 8) (v >>> 8))
(write_mem addr (extractLsb 7 0 v) s)
(write_mem addr (extractLsb' 0 8 v) s)
simp only [Nat.reducePow, Nat.sub_zero, Nat.reduceAdd,
BitVec.cast_eq, forall_const] at rw_lemma2
rw [rw_lemma2 (by omega) (by simp only [Nat.reducePow] at h1; omega)]
Expand Down
4 changes: 2 additions & 2 deletions Arm/State.lean
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ def write_bytes (n : Nat) (addr : BitVec 64)
match n with
| 0 => m
| n' + 1 =>
let byte := BitVec.extractLsb 7 0 val
let byte := BitVec.extractLsb' 0 8 val
let m := m.write addr byte
let val_rest := BitVec.zeroExtend (n' * 8) (val >>> 8)
m.write_bytes n' (addr + 1#64) val_rest
Expand All @@ -988,7 +988,7 @@ and then recursing to write the rest.
-/
theorem write_bytes_succ {mem : Memory} :
mem.write_bytes (n + 1) addr val =
let byte := BitVec.extractLsb 7 0 val
let byte := BitVec.extractLsb' 0 8 val
let mem := mem.write addr byte
let val_rest := BitVec.zeroExtend (n * 8) (val >>> 8)
mem.write_bytes n (addr + 1#64) val_rest := rfl
Expand Down

0 comments on commit 2ef760f

Please sign in to comment.