Skip to content

Commit

Permalink
Removing extractLsb from bitvec library and memory aliasing proofs
Browse files Browse the repository at this point in the history
  • Loading branch information
pennyannn committed Sep 30, 2024
1 parent 313b638 commit 5070a68
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 139 deletions.
85 changes: 40 additions & 45 deletions Arm/BitVec.lean
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,7 @@ attribute [bitvec_rules] BitVec.getLsbD_truncate
attribute [bitvec_rules] BitVec.zeroExtend_zeroExtend_of_le
attribute [bitvec_rules] BitVec.truncate_truncate_of_le
attribute [bitvec_rules] BitVec.truncate_cast
attribute [bitvec_rules] BitVec.extractLsb_ofFin
attribute [bitvec_rules] BitVec.extractLsb_ofNat
attribute [bitvec_rules] BitVec.extractLsb'_toNat
attribute [bitvec_rules] BitVec.extractLsb_toNat
attribute [bitvec_rules] BitVec.getLsbD_extract
attribute [bitvec_rules] BitVec.toNat_allOnes
attribute [bitvec_rules] BitVec.getLsbD_allOnes
Expand Down Expand Up @@ -546,17 +543,15 @@ theorem zeroExtend_if_false [Decidable p] (x : BitVec n)
(zeroExtend (if p then a else b) x) = BitVec.cast h_eq (zeroExtend b x) := by
simp only [toNat_eq, toNat_truncate, ← h_eq, toNat_cast]

theorem extractLsb_eq (x : BitVec n) (h : n = n - 1 + 1) :
BitVec.extractLsb (n - 1) 0 x = BitVec.cast h x := by
unfold extractLsb extractLsb'
ext1
simp [←h]
theorem extractLsb'_eq (x : BitVec n) :
BitVec.extractLsb' 0 n x = x := by
unfold extractLsb'
simp only [Nat.shiftRight_zero, ofNat_toNat, zeroExtend_eq]

@[bitvec_rules]
protected theorem extract_lsb_of_zeroExtend (x : BitVec n) (h : j < i) :
extractLsb' 0 (j + 1) (zeroExtend i x) = zeroExtend (j + 1) x := by
protected theorem extractLsb'_of_zeroExtend (x : BitVec n) (h : j i) :
extractLsb' 0 j (zeroExtend i x) = zeroExtend j x := by
apply BitVec.eq_of_getLsbD_eq
simp
intro k
have q : k < i := by omega
by_cases h : decide (k ≤ j) <;> simp [q, h]
Expand Down Expand Up @@ -634,12 +629,11 @@ theorem append_of_extract_general_nat (high low n vn : Nat) (h : vn < 2 ^ n) :
done

theorem append_of_extract (n : Nat) (v : BitVec n)
(high0 : high = n - low) (low0 : 1 <= low)
(h : high + (low - 1 - 0 + 1) = n) :
BitVec.cast h (zeroExtend high (v >>> low) ++ extractLsb (low - 1) 0 v) = v := by
(high0 : high = n - low) (h : high + low = n) :
BitVec.cast h (zeroExtend high (v >>> low) ++ extractLsb' 0 low v) = v := by
ext
subst high
have vlt := v.isLt; simp_all only [Nat.sub_zero]
have vlt := v.isLt
have := append_of_extract_general_nat (n - low) low n (BitVec.toNat v) vlt
have low_le : low ≤ n := by omega
simp_all [toNat_zeroExtend, Nat.sub_add_cancel, low_le]
Expand All @@ -649,17 +643,13 @@ theorem append_of_extract (n : Nat) (v : BitVec n)
exact Nat.lt_of_le_of_lt (Nat.div_le_self _ _) vlt
done

theorem append_of_extract_general (v : BitVec n)
(low0 : 1 <= low)
(h1 : high = width)
(h2 : (high + low - 1 - 0 + 1) = (width + (low - 1 - 0 + 1))) :
BitVec.cast h1 (zeroExtend high (v >>> low)) ++ extractLsb (low - 1) 0 v =
BitVec.cast h2 (extractLsb (high + low - 1) 0 v) := by
theorem append_of_extract_general (v : BitVec n) :
(zeroExtend high (v >>> low)) ++ extractLsb' 0 low v =
extractLsb' 0 (high + low) v := by
ext
have := append_of_extract_general_nat high low n (BitVec.toNat v)
have h_vlt := v.isLt; simp_all only [Nat.sub_zero, h1]
simp only [h_vlt, h1, forall_prop_of_true] at this
have low' : 1 ≤ width + low := Nat.le_trans low0 (Nat.le_add_left low width)
have h_vlt := v.isLt; simp_all only [Nat.sub_zero]
simp only [h_vlt, forall_prop_of_true] at this
simp_all [toNat_zeroExtend, Nat.sub_add_cancel]
rw [Nat.mod_eq_of_lt (b := 2 ^ n)] at this
· rw [this]
Expand Down Expand Up @@ -788,7 +778,7 @@ def genBVPatMatchTest (var : Term) (pat : BVPat) : MacroM Term := do
for c in pat.getComponents do
let len := c.length
if let some bv ← c.toBVLit? then
let test ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var == $bv)
let test ← `(extractLsb' $(quote shift) $(quote len) $var == $bv)
result ← `($result && $test)
shift := shift + len
return result
Expand All @@ -810,7 +800,7 @@ def declBVPatVars (var : Term) (pat : BVPat) (rhs : Term) : MacroM Term := do
for c in pat.getComponents do
let len := c.length
if let some y ← c.toBVVar? then
let rhs ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var)
let rhs ← `(extractLsb' $(quote shift) $(quote len) $var)
result ← `(let $y := $rhs; $result)
shift := shift + len
return result
Expand Down Expand Up @@ -934,31 +924,36 @@ Definition to extract the `n`th least significant *Byte* from a bitvector.
TODO: this should be named `getLsByte`, or `getLsbByte` (Shilpi prefers this).
-/
def extractLsByte (val : BitVec w₁) (n : Nat) : BitVec 8 :=
val.extractLsb ((n + 1) * 8 - 1) (n * 8) |> .cast (by omega)
val.extractLsb' (n * 8) 8

theorem extractLsByte_def (val : BitVec w₁) (n : Nat) :
val.extractLsByte n = (val.extractLsb ((n + 1)*8 - 1) (n * 8) |>.cast (by omega)) := rfl
val.extractLsByte n = val.extractLsb' (n * 8) 8 := rfl

-- TODO: upstream
theorem extractLsb_or (x y : BitVec w₁) (n : Nat) :
(x ||| y).extractLsb n lo = (x.extractLsb n lo ||| y.extractLsb n lo) := by
theorem extractLsb'_or (x y : BitVec w₁) (n : Nat) :
(x ||| y).extractLsb' lo n = (x.extractLsb' lo n ||| y.extractLsb' lo n) := by
apply BitVec.eq_of_getLsbD_eq
simp only [getLsbD_extract, getLsbD_or]
intros i
by_cases h : (i : Nat) ≤ n - lo
· simp only [h, decide_True, Bool.true_and]
· simp only [h, decide_False, Bool.false_and, Bool.or_self]
simp only [getLsbD_extractLsb', Fin.is_lt, decide_True, getLsbD_or, Bool.true_and]

-- TODO: upstream
protected theorem extractLsb'_ofNat (x n : Nat) (l lo : Nat) :
extractLsb' lo l (BitVec.ofNat n x) = .ofNat l ((x % 2^n) >>> lo) := by
apply eq_of_getLsbD_eq
intro ⟨i, _lt⟩
simp [BitVec.ofNat]

theorem extractLsByte_zero {w : Nat} : (0#w).extractLsByte i = 0#8 := by
simp only [extractLsByte, BitVec.extractLsb_ofNat, Nat.zero_mod, Nat.zero_shiftRight, cast_ofNat]
simp only [extractLsByte, BitVec.extractLsb'_ofNat, Nat.zero_mod, Nat.zero_shiftRight, cast_ofNat]

theorem extractLsByte_ge (h : 8 * a ≥ w₁) (x : BitVec w₁) :
x.extractLsByte a = 0#8 := by
apply BitVec.eq_of_getLsbD_eq
intros i
simp only [getLsbD_zero, extractLsByte_def,
getLsbD_cast, getLsbD_extract, Bool.and_eq_false_imp, decide_eq_true_eq]
intros _
simp only [getLsbD_extractLsb', Fin.is_lt, decide_True, Bool.true_and]
apply BitVec.getLsbD_ge
omega

Expand All @@ -967,10 +962,13 @@ theorem getLsbD_extractLsByte (val : BitVec w₁) :
((BitVec.extractLsByte val n).getLsbD i) =
(decide (i ≤ 7) && val.getLsbD (n * 8 + i)) := by
simp only [extractLsByte, getLsbD_cast, getLsbD_extract]
rw [Nat.succ_mul]
simp only [Nat.add_one_sub_one,
Nat.add_sub_cancel_left]

simp only [getLsbD_extractLsb']
generalize val.getLsbD (n * 8 + i) = x
by_cases h : i < 8
· simp only [show (i : Nat) ≤ 7 by omega, decide_True, Bool.true_and,
Bool.and_iff_right_iff_imp, decide_eq_true_eq, h]
· simp only [show ¬(i : Nat) ≤ 7 by omega, decide_False, Bool.false_and,
Bool.and_eq_false_imp, decide_eq_true_eq, h]

/--
Two bitvectors of length `n*8` are equal if all their bytes are equal.
Expand All @@ -994,9 +992,7 @@ theorem eq_of_extractLsByte_eq (x y : BitVec (n * 8))
@bollu: it's not clear if the definition for n=0 is desirable.
-/
def extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) : BitVec (n * 8) :=
match h : n with
| 0 => 0#0
| x + 1 => val.extractLsb (base * 8 + n * 8 - 1) (base * 8) |>.cast (by omega)
extractLsb' (base * 8) (n * 8) val

@[bitvec_rules]
theorem getLsbD_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) :
Expand All @@ -1009,10 +1005,9 @@ theorem getLsbD_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat)
simp only [show ¬i < 0 by omega, decide_False, Bool.false_and]
· simp only [extractLsBytes, getLsbD_cast, getLsbD_extract, Nat.zero_lt_succ, decide_True,
Bool.true_and]
simp only [show base * 8 + (n + 1) * 8 - 1 - base * 8 = (n + 1) * 8 - 1 by omega]
by_cases h : i < (n + 1) * 8
· simp only [show i ≤ (n + 1) * 8 - 1 by omega, decide_True, Bool.true_and, h]
· simp only [show ¬(i ≤ (n + 1) * 8 - 1) by omega, decide_False, Bool.false_and, h]
· simp only [getLsbD_extractLsb', h, decide_True, Bool.true_and]
· simp only [getLsbD_extractLsb', h, decide_False, Bool.false_and]

theorem extractLsByte_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) :
(BitVec.extractLsBytes val base n).extractLsByte i =
Expand Down
Loading

0 comments on commit 5070a68

Please sign in to comment.