diff --git a/Arm/BitVec.lean b/Arm/BitVec.lean index c36588d4..53a4015d 100644 --- a/Arm/BitVec.lean +++ b/Arm/BitVec.lean @@ -59,10 +59,7 @@ attribute [bitvec_rules] BitVec.getLsbD_truncate attribute [bitvec_rules] BitVec.zeroExtend_zeroExtend_of_le attribute [bitvec_rules] BitVec.truncate_truncate_of_le attribute [bitvec_rules] BitVec.truncate_cast -attribute [bitvec_rules] BitVec.extractLsb_ofFin -attribute [bitvec_rules] BitVec.extractLsb_ofNat attribute [bitvec_rules] BitVec.extractLsb'_toNat -attribute [bitvec_rules] BitVec.extractLsb_toNat attribute [bitvec_rules] BitVec.getLsbD_extract attribute [bitvec_rules] BitVec.toNat_allOnes attribute [bitvec_rules] BitVec.getLsbD_allOnes @@ -546,17 +543,15 @@ theorem zeroExtend_if_false [Decidable p] (x : BitVec n) (zeroExtend (if p then a else b) x) = BitVec.cast h_eq (zeroExtend b x) := by simp only [toNat_eq, toNat_truncate, ← h_eq, toNat_cast] -theorem extractLsb_eq (x : BitVec n) (h : n = n - 1 + 1) : - BitVec.extractLsb (n - 1) 0 x = BitVec.cast h x := by - unfold extractLsb extractLsb' - ext1 - simp [←h] +theorem extractLsb'_eq (x : BitVec n) : + BitVec.extractLsb' 0 n x = x := by + unfold extractLsb' + simp only [Nat.shiftRight_zero, ofNat_toNat, zeroExtend_eq] @[bitvec_rules] -protected theorem extract_lsb_of_zeroExtend (x : BitVec n) (h : j < i) : - extractLsb' 0 (j + 1) (zeroExtend i x) = zeroExtend (j + 1) x := by +protected theorem extractLsb'_of_zeroExtend (x : BitVec n) (h : j ≤ i) : + extractLsb' 0 j (zeroExtend i x) = zeroExtend j x := by apply BitVec.eq_of_getLsbD_eq - simp intro k have q : k < i := by omega by_cases h : decide (k ≤ j) <;> simp [q, h] @@ -634,12 +629,11 @@ theorem append_of_extract_general_nat (high low n vn : Nat) (h : vn < 2 ^ n) : done theorem append_of_extract (n : Nat) (v : BitVec n) - (high0 : high = n - low) (low0 : 1 <= low) - (h : high + (low - 1 - 0 + 1) = n) : - BitVec.cast h (zeroExtend high (v >>> low) ++ extractLsb (low - 1) 0 v) = v := by + (high0 : high = n - low) (h : high + low = n) : + BitVec.cast h (zeroExtend high (v >>> low) ++ extractLsb' 0 low v) = v := by ext subst high - have vlt := v.isLt; simp_all only [Nat.sub_zero] + have vlt := v.isLt have := append_of_extract_general_nat (n - low) low n (BitVec.toNat v) vlt have low_le : low ≤ n := by omega simp_all [toNat_zeroExtend, Nat.sub_add_cancel, low_le] @@ -649,17 +643,13 @@ theorem append_of_extract (n : Nat) (v : BitVec n) exact Nat.lt_of_le_of_lt (Nat.div_le_self _ _) vlt done -theorem append_of_extract_general (v : BitVec n) - (low0 : 1 <= low) - (h1 : high = width) - (h2 : (high + low - 1 - 0 + 1) = (width + (low - 1 - 0 + 1))) : - BitVec.cast h1 (zeroExtend high (v >>> low)) ++ extractLsb (low - 1) 0 v = - BitVec.cast h2 (extractLsb (high + low - 1) 0 v) := by +theorem append_of_extract_general (v : BitVec n) : + (zeroExtend high (v >>> low)) ++ extractLsb' 0 low v = + extractLsb' 0 (high + low) v := by ext have := append_of_extract_general_nat high low n (BitVec.toNat v) - have h_vlt := v.isLt; simp_all only [Nat.sub_zero, h1] - simp only [h_vlt, h1, forall_prop_of_true] at this - have low' : 1 ≤ width + low := Nat.le_trans low0 (Nat.le_add_left low width) + have h_vlt := v.isLt; simp_all only [Nat.sub_zero] + simp only [h_vlt, forall_prop_of_true] at this simp_all [toNat_zeroExtend, Nat.sub_add_cancel] rw [Nat.mod_eq_of_lt (b := 2 ^ n)] at this · rw [this] @@ -788,7 +778,7 @@ def genBVPatMatchTest (var : Term) (pat : BVPat) : MacroM Term := do for c in pat.getComponents do let len := c.length if let some bv ← c.toBVLit? then - let test ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var == $bv) + let test ← `(extractLsb' $(quote shift) $(quote len) $var == $bv) result ← `($result && $test) shift := shift + len return result @@ -810,7 +800,7 @@ def declBVPatVars (var : Term) (pat : BVPat) (rhs : Term) : MacroM Term := do for c in pat.getComponents do let len := c.length if let some y ← c.toBVVar? then - let rhs ← `(extractLsb $(quote (shift + (len - 1))) $(quote shift) $var) + let rhs ← `(extractLsb' $(quote shift) $(quote len) $var) result ← `(let $y := $rhs; $result) shift := shift + len return result @@ -934,23 +924,28 @@ Definition to extract the `n`th least significant *Byte* from a bitvector. TODO: this should be named `getLsByte`, or `getLsbByte` (Shilpi prefers this). -/ def extractLsByte (val : BitVec w₁) (n : Nat) : BitVec 8 := - val.extractLsb ((n + 1) * 8 - 1) (n * 8) |> .cast (by omega) + val.extractLsb' (n * 8) 8 theorem extractLsByte_def (val : BitVec w₁) (n : Nat) : - val.extractLsByte n = (val.extractLsb ((n + 1)*8 - 1) (n * 8) |>.cast (by omega)) := rfl + val.extractLsByte n = val.extractLsb' (n * 8) 8 := rfl -- TODO: upstream -theorem extractLsb_or (x y : BitVec w₁) (n : Nat) : - (x ||| y).extractLsb n lo = (x.extractLsb n lo ||| y.extractLsb n lo) := by +theorem extractLsb'_or (x y : BitVec w₁) (n : Nat) : + (x ||| y).extractLsb' lo n = (x.extractLsb' lo n ||| y.extractLsb' lo n) := by apply BitVec.eq_of_getLsbD_eq simp only [getLsbD_extract, getLsbD_or] intros i - by_cases h : (i : Nat) ≤ n - lo - · simp only [h, decide_True, Bool.true_and] - · simp only [h, decide_False, Bool.false_and, Bool.or_self] + simp only [getLsbD_extractLsb', Fin.is_lt, decide_True, getLsbD_or, Bool.true_and] + +-- TODO: upstream +protected theorem extractLsb'_ofNat (x n : Nat) (l lo : Nat) : + extractLsb' lo l (BitVec.ofNat n x) = .ofNat l ((x % 2^n) >>> lo) := by + apply eq_of_getLsbD_eq + intro ⟨i, _lt⟩ + simp [BitVec.ofNat] theorem extractLsByte_zero {w : Nat} : (0#w).extractLsByte i = 0#8 := by - simp only [extractLsByte, BitVec.extractLsb_ofNat, Nat.zero_mod, Nat.zero_shiftRight, cast_ofNat] + simp only [extractLsByte, BitVec.extractLsb'_ofNat, Nat.zero_mod, Nat.zero_shiftRight, cast_ofNat] theorem extractLsByte_ge (h : 8 * a ≥ w₁) (x : BitVec w₁) : x.extractLsByte a = 0#8 := by @@ -958,7 +953,7 @@ theorem extractLsByte_ge (h : 8 * a ≥ w₁) (x : BitVec w₁) : intros i simp only [getLsbD_zero, extractLsByte_def, getLsbD_cast, getLsbD_extract, Bool.and_eq_false_imp, decide_eq_true_eq] - intros _ + simp only [getLsbD_extractLsb', Fin.is_lt, decide_True, Bool.true_and] apply BitVec.getLsbD_ge omega @@ -967,10 +962,13 @@ theorem getLsbD_extractLsByte (val : BitVec w₁) : ((BitVec.extractLsByte val n).getLsbD i) = (decide (i ≤ 7) && val.getLsbD (n * 8 + i)) := by simp only [extractLsByte, getLsbD_cast, getLsbD_extract] - rw [Nat.succ_mul] - simp only [Nat.add_one_sub_one, - Nat.add_sub_cancel_left] - + simp only [getLsbD_extractLsb'] + generalize val.getLsbD (n * 8 + i) = x + by_cases h : i < 8 + · simp only [show (i : Nat) ≤ 7 by omega, decide_True, Bool.true_and, + Bool.and_iff_right_iff_imp, decide_eq_true_eq, h] + · simp only [show ¬(i : Nat) ≤ 7 by omega, decide_False, Bool.false_and, + Bool.and_eq_false_imp, decide_eq_true_eq, h] /-- Two bitvectors of length `n*8` are equal if all their bytes are equal. @@ -994,9 +992,7 @@ theorem eq_of_extractLsByte_eq (x y : BitVec (n * 8)) @bollu: it's not clear if the definition for n=0 is desirable. -/ def extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) : BitVec (n * 8) := - match h : n with - | 0 => 0#0 - | x + 1 => val.extractLsb (base * 8 + n * 8 - 1) (base * 8) |>.cast (by omega) + extractLsb' (base * 8) (n * 8) val @[bitvec_rules] theorem getLsbD_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) : @@ -1009,10 +1005,9 @@ theorem getLsbD_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) simp only [show ¬i < 0 by omega, decide_False, Bool.false_and] · simp only [extractLsBytes, getLsbD_cast, getLsbD_extract, Nat.zero_lt_succ, decide_True, Bool.true_and] - simp only [show base * 8 + (n + 1) * 8 - 1 - base * 8 = (n + 1) * 8 - 1 by omega] by_cases h : i < (n + 1) * 8 - · simp only [show i ≤ (n + 1) * 8 - 1 by omega, decide_True, Bool.true_and, h] - · simp only [show ¬(i ≤ (n + 1) * 8 - 1) by omega, decide_False, Bool.false_and, h] + · simp only [getLsbD_extractLsb', h, decide_True, Bool.true_and] + · simp only [getLsbD_extractLsb', h, decide_False, Bool.false_and] theorem extractLsByte_extractLsBytes (val : BitVec w) (base : Nat) (n : Nat) (i : Nat) : (BitVec.extractLsBytes val base n).extractLsByte i = diff --git a/Arm/Memory/MemoryProofs.lean b/Arm/Memory/MemoryProofs.lean index 6973512a..07fc4f1a 100644 --- a/Arm/Memory/MemoryProofs.lean +++ b/Arm/Memory/MemoryProofs.lean @@ -63,12 +63,10 @@ theorem read_mem_of_write_mem_bytes_different (hn1 : n <= 2^64) theorem append_byte_of_extract_rest_same_cast (n : Nat) (v : BitVec ((n + 1) * 8)) (hn0 : Nat.succ 0 ≤ n) - (h : (n * 8 + (7 - 0 + 1)) = (n + 1) * 8) : - BitVec.cast h (zeroExtend (n * 8) (v >>> 8) ++ extractLsb 7 0 v) = v := by + (h : (n * 8 + 8) = (n + 1) * 8) : + BitVec.cast h (zeroExtend (n * 8) (v >>> 8) ++ extractLsb' 0 8 v) = v := by apply BitVec.append_of_extract · omega - · omega - · omega done @[state_simp_rules] @@ -85,7 +83,7 @@ theorem read_mem_bytes_of_write_mem_bytes_same (hn1 : n <= 2^64) : case base => simp only [read_mem_bytes, write_mem_bytes, read_mem_of_write_mem_same, BitVec.cast_eq] - have l1 := BitVec.extractLsb_eq v + have l1 := BitVec.extractLsb'_eq v simp only [Nat.reduceSucc, Nat.one_mul, Nat.succ_sub_succ_eq_sub, Nat.sub_zero, Nat.reduceAdd, BitVec.cast_eq, forall_const] at l1 @@ -431,7 +429,7 @@ private theorem write_mem_bytes_of_write_mem_bytes_shadow_general_n2_eq rename_i n n_ih conv in write_mem_bytes (Nat.succ n) .. => simp only [write_mem_bytes] have n_ih' := @n_ih (addr1 + 1#64) val2 (zeroExtend (n * 8) (val1 >>> 8)) - (write_mem addr1 (extractLsb 7 0 val1) s) + (write_mem addr1 (extractLsb' 0 8 val1) s) (by omega) simp only [Nat.succ_sub_succ_eq_sub, Nat.sub_zero] at h3 by_cases h₁ : n = 0 @@ -483,7 +481,7 @@ theorem write_mem_bytes_of_write_mem_bytes_shadow_general theorem read_mem_of_write_mem_bytes_same_first_address (h0 : 0 < n) (h1 : n <= 2^64) (h : 7 - 0 + 1 = 8) : read_mem addr (write_mem_bytes n addr val s) = - BitVec.cast h (extractLsb 7 0 val) := by + BitVec.cast h (extractLsb' 0 8 val) := by unfold write_mem_bytes; simp only [Nat.sub_zero, BitVec.cast_eq] split · contradiction @@ -495,18 +493,16 @@ theorem read_mem_of_write_mem_bytes_same_first_address -- (FIXME) Argh, it's annoying to need this lemma, but using -- BitVec.cast_eq directly was cumbersome. theorem cast_of_extract_eq (v : BitVec p) - (h1 : hi1 = hi2) (h2 : lo1 = lo2) - (h : hi1 - lo1 + 1 = hi2 - lo2 + 1) : - BitVec.cast h (extractLsb hi1 lo1 v) = (extractLsb hi2 lo2 v) := by + (h1 : n1 = n2) (h2 : lo1 = lo2): + BitVec.cast h (extractLsb' lo1 n1 v) = (extractLsb' lo2 n2 v) := by subst_vars simp only [Nat.sub_zero, BitVec.cast_eq] theorem read_mem_bytes_of_write_mem_bytes_subset_same_first_address (h0 : 0 < n1) (h1 : n1 <= 2^64) (h2 : 0 < n2) (h3 : n2 <= 2^64) - (h4 : mem_subset addr (addr + (BitVec.ofNat 64 (n2 - 1))) addr (addr + (BitVec.ofNat 64 (n1 - 1)))) - (h : n2 * 8 - 1 - 0 + 1 = n2 * 8) : + (h4 : mem_subset addr (addr + (BitVec.ofNat 64 (n2 - 1))) addr (addr + (BitVec.ofNat 64 (n1 - 1)))): read_mem_bytes n2 addr (write_mem_bytes n1 addr val s) = - BitVec.cast h (extractLsb ((n2 * 8) - 1) 0 val) := by + extractLsb' 0 (n2 * 8) val := by have rm_lemma := @read_mem_of_write_mem_bytes_same_first_address n1 addr val s h0 h1 simp only [Nat.sub_zero, Nat.reduceAdd, BitVec.cast_eq, forall_const] at rm_lemma induction n2, h2 using Nat.le_induction generalizing n1 addr val s @@ -543,21 +539,20 @@ theorem read_mem_bytes_of_write_mem_bytes_subset_same_first_address erw [Nat.mod_eq_of_lt h3] at hn erw [Nat.mod_eq_of_lt h1] at hn exact hn - rw [n_ih (by omega) (by omega) (by omega) _ (by omega)] - · rw [BitVec.extract_lsb_of_zeroExtend (v >>> 8)] - · have l1 := @BitVec.append_of_extract_general ((n1_1 + 1) * 8) 8 (n*8-1+1) (n*8) v + rw [n_ih (by omega) (by omega) (by omega) _] + · rw [BitVec.extractLsb'_of_zeroExtend (v >>> 8)] + · have l1 := @BitVec.append_of_extract_general ((n1_1 + 1) * 8) (n*8) 8 v simp (config := { decide := true }) only [Nat.zero_lt_succ, Nat.mul_pos_iff_of_pos_left, Nat.succ_sub_succ_eq_sub, Nat.sub_zero, Nat.reduceAdd, Nat.succ.injEq, forall_const] at l1 - rw [l1 (by omega) (by omega)] - · simp only [Nat.add_eq, Nat.sub_zero, BitVec.cast_cast] - apply @cast_of_extract_eq ((n1_1 + 1) * 8) (n * 8 - 1 + 1 + 7) ((n + 1) * 8 - 1) 0 0 <;> + rw [l1] + · apply @cast_of_extract_eq ((n1_1 + 1) * 8) (n * 8 + 8) ((n + 1) * 8) 0 0 <;> omega · omega · have rw_lemma2 := @read_mem_of_write_mem_bytes_same_first_address n1_1 (addr + 1#64) (zeroExtend (n1_1 * 8) (v >>> 8)) - (write_mem addr (extractLsb 7 0 v) s) + (write_mem addr (extractLsb' 0 8 v) s) simp only [Nat.reducePow, Nat.sub_zero, Nat.reduceAdd, BitVec.cast_eq, forall_const] at rw_lemma2 rw [rw_lemma2 (by omega) (by simp only [Nat.reducePow] at h1; omega)] @@ -636,22 +631,15 @@ theorem BitVec.to_nat_zero_lt_sub_64 (x y : BitVec 64) (h : ¬x = y) : theorem read_mem_of_write_mem_bytes_subset (h0 : 0 < n) (h1 : n <= 2^64) - (h2 : mem_subset addr2 addr2 addr1 (addr1 + (BitVec.ofNat 64 (n - 1)))) - (h : ((BitVec.toNat (addr2 - addr1) + 1) * 8 - 1 - - BitVec.toNat (addr2 - addr1) * 8 + 1) = 8) : + (h2 : mem_subset addr2 addr2 addr1 (addr1 + (BitVec.ofNat 64 (n - 1)))): read_mem addr2 (write_mem_bytes n addr1 val s) = - BitVec.cast h - (extractLsb - ((BitVec.toNat (addr2 - addr1) + 1) * 8 - 1) - (BitVec.toNat (addr2 - addr1) * 8) - val) := by + extractLsb' (BitVec.toNat (addr2 - addr1) * 8) 8 val := by induction n generalizing addr1 addr2 s case zero => contradiction case succ => rename_i n' n_ih simp_all only [write_mem_bytes, Nat.succ.injEq, Nat.zero_lt_succ, Nat.succ_sub_succ_eq_sub, Nat.sub_zero] - have cast_lemma := @cast_of_extract_eq by_cases h₀ : n' = 0 case pos => simp_all only [Nat.lt_irrefl, Nat.zero_le, Nat.zero_sub, @@ -659,20 +647,22 @@ theorem read_mem_of_write_mem_bytes_subset false_implies, implies_true] subst_vars simp only [write_mem_bytes, read_mem_of_write_mem_same] - rw [←cast_lemma] <;> bv_omega + simp only [Nat.reduceAdd, Nat.reduceMul, BitVec.sub_self, + toNat_ofNat, Nat.reducePow, Nat.zero_mod, Nat.zero_mul] case neg => -- (n' ≠ 0) by_cases h₁ : addr2 = addr1 case pos => -- (n' ≠ 0) and (addr2 = addr1) subst_vars rw [read_mem_of_write_mem_bytes_different (by omega)] · simp only [read_mem_of_write_mem_same] - rw [←cast_lemma] <;> bv_omega + simp only [BitVec.sub_self, toNat_ofNat, Nat.reducePow, + Nat.zero_mod, Nat.zero_mul] · rw [mem_separate_contiguous_regions_one_address _ (by omega)] case neg => -- (addr2 ≠ addr1) rw [n_ih] · ext -- simp only [bv_toNat] - simp only [toNat_cast, extractLsb, extractLsb', toNat_zeroExtend] + simp only [toNat_cast, extractLsb', toNat_zeroExtend] simp only [toNat_ushiftRight] simp_all only [toNat_ofNat, toNat_ofNatLt] simp only [BitVec.sub_of_add_is_sub_sub, Nat.succ_sub_succ_eq_sub, @@ -700,7 +690,6 @@ theorem read_mem_of_write_mem_bytes_subset · omega · rw [addr_add_one_add_m_sub_one _ _ (by omega) (by omega)] rw [mem_subset_one_addr_neq h₁ h2] - · omega done theorem read_mem_bytes_of_write_mem_bytes_subset_helper1 (a i : Nat) @@ -785,35 +774,26 @@ private theorem read_mem_bytes_of_write_mem_bytes_subset_n2_lt (h0 : 0 < n1) (h1 : n1 <= 2^64) (h2 : 0 < n2) (h3 : n2 < 2^64) (h4 : mem_subset addr2 (addr2 + (BitVec.ofNat 64 (n2 - 1))) addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))) (h5 : mem_legal addr2 (addr2 + (BitVec.ofNat 64 (n2 - 1)))) - (h6 : mem_legal addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))) - (h : ((BitVec.toNat (addr2 - addr1) + n2) * 8 - 1 - BitVec.toNat (addr2 - addr1) * 8 + 1) - = n2 * 8) : + (h6 : mem_legal addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))) : read_mem_bytes n2 addr2 (write_mem_bytes n1 addr1 val s) = - BitVec.cast h - (extractLsb ((((addr2 - addr1).toNat + n2) * 8) - 1) ((addr2 - addr1).toNat * 8) val) := by + extractLsb' ((addr2 - addr1).toNat * 8) (n2 * 8) val := by induction n2, h2 using Nat.le_induction generalizing addr1 addr2 val s case base => simp only [Nat.reduceSucc, Nat.succ_sub_succ_eq_sub, Nat.sub_self, BitVec.add_zero] at h4 - simp_all only [read_mem_bytes, BitVec.cast_eq] - have h' : (BitVec.toNat (addr2 - addr1) + 1) * 8 - 1 - BitVec.toNat (addr2 - addr1) * 8 + 1 = 8 := by - omega - rw [read_mem_of_write_mem_bytes_subset h0 h1 h4 h'] + simp_all only [read_mem_bytes] + rw [read_mem_of_write_mem_bytes_subset h0 h1 h4] apply BitVec.empty_bitvector_append_left - decide case succ => rename_i n h2' n_ih by_cases h_addr : addr1 = addr2 case pos => -- (addr1 = addr2) subst addr2 - have h' : (n + 1) * 8 - 1 - 0 + 1 = (n + 1) * 8 := by omega have := @read_mem_bytes_of_write_mem_bytes_subset_same_first_address n1 (n + 1) addr1 val s - h0 h1 (by omega) (by omega) h4 h' + h0 h1 (by omega) (by omega) h4 rw [this] - ext - simp only [Nat.sub_zero, BitVec.cast_eq, extractLsb_toNat, - Nat.shiftRight_zero, toNat_cast, BitVec.sub_self, - toNat_ofNat, Nat.zero_mod, Nat.zero_mul, Nat.zero_add] + simp only [BitVec.sub_self, toNat_ofNat, Nat.reducePow, + Nat.zero_mod, Nat.zero_mul] case neg => -- (addr1 ≠ addr2) simp only [read_mem_bytes, Nat.add_eq, Nat.add_zero] simp only [Nat.succ_sub_succ_eq_sub, Nat.sub_zero] at h4 @@ -832,25 +812,21 @@ private theorem read_mem_bytes_of_write_mem_bytes_subset_n2_lt have l2 := @first_address_is_subset_of_region addr2 (BitVec.ofNat 64 n) have l3 := mem_subset_trans l2 h4 simp only [l3, forall_const] at l1 - rw [l1 (by omega)] + rw [l1] simp only [Nat.succ_sub_succ_eq_sub, Nat.sub_zero] at h5 have n_ih' := @n_ih (addr2 + 1#64) addr1 val s (by omega) simp only [h_sub, forall_const] at n_ih' rw [mem_legal_lemma h2'] at n_ih' - · simp only [forall_const] at n_ih' - have h' : (BitVec.toNat (addr2 + 1#64 - addr1) + n) * 8 - 1 - - BitVec.toNat (addr2 + 1#64 - addr1) * 8 + 1 = - n * 8 := by - omega - rw [n_ih' h6 h'] + · simp only [h6, true_implies] at n_ih' + rw [n_ih'] ext - simp only [extractLsb, extractLsb', toNat_ofNat, toNat_cast, + simp only [extractLsb', toNat_ofNat, toNat_cast, BitVec.add_of_sub_sub_of_add] simp only [toNat_add (addr2 - addr1) 1#64, Nat.add_eq, Nat.add_zero, toNat_ofNat, Nat.add_mod_mod, cast_ofNat, toNat_append] have := @addr_diff_upper_bound_lemma n1 n addr1 addr2 h0 h1 (by omega) (by omega) h6 h5 h4 - rw [read_mem_bytes_of_write_mem_bytes_subset_helper2] <;> assumption + rw [read_mem_bytes_of_write_mem_bytes_subset_helper2] <;> sorry · omega · assumption done @@ -886,37 +862,21 @@ theorem entire_memory_subset_legal_regions_eq_addr simp_all [mem_subset, mem_legal] bv_omega -private theorem read_mem_bytes_of_write_mem_bytes_subset_n2_eq_alt_helper (val : BitVec (x * 8)) - (h0 : 0 < x) - (h : (BitVec.toNat (addr2 - addr2) + x) * 8 - 1 - - BitVec.toNat (addr2 - addr2) * 8 + 1 - = - x * 8) : +private theorem read_mem_bytes_of_write_mem_bytes_subset_n2_eq_alt_helper (val : BitVec (x * 8)): val = - BitVec.cast h - (extractLsb ((BitVec.toNat (addr2 - addr2) + x) * 8 - 1) - (BitVec.toNat (addr2 - addr2) * 8) val) := by + extractLsb' ((BitVec.toNat (addr2 - addr2)) * 8) (x * 8) val := by ext - simp only [extractLsb, extractLsb', BitVec.sub_self, toNat_ofNat, - Nat.zero_mod, Nat.zero_mul, Nat.shiftRight_zero, - ofNat_toNat, toNat_cast, toNat_truncate, Nat.zero_add, - Nat.sub_zero] - rw [Nat.mod_eq_of_lt] - rw [Nat.sub_add_cancel] - · exact val.isLt - · omega - done + simp only [BitVec.sub_self, toNat_ofNat, Nat.zero_mod, + Nat.zero_mul, extractLsb'_toNat, + Nat.shiftRight_zero, toNat_mod_cancel] private theorem read_mem_bytes_of_write_mem_bytes_subset_n2_eq_alt (h0 : 0 < n1) (h1 : n1 <= my_pow 2 64) (h2 : 0 < n2) (h3 : n2 = my_pow 2 64) (h4 : mem_subset addr2 (addr2 + (BitVec.ofNat 64 (n2 - 1))) addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))) (h5 : mem_legal addr2 (addr2 + (BitVec.ofNat 64 (n2 - 1)))) - (h6 : mem_legal addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))) - (h : ((BitVec.toNat (addr2 - addr1) + n2) * 8 - 1 - BitVec.toNat (addr2 - addr1) * 8 + 1) - = n2 * 8) : + (h6 : mem_legal addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))): read_mem_bytes n2 addr2 (write_mem_bytes n1 addr1 val s) = - BitVec.cast h - (extractLsb ((((addr2 - addr1).toNat + n2) * 8) - 1) ((addr2 - addr1).toNat * 8) val) := by + extractLsb' ((addr2 - addr1).toNat * 8) (n2 * 8) val := by subst n2 have l0 := @entire_memory_subset_of_only_itself n1 addr2 addr1 h1 h4 subst n1 @@ -924,7 +884,6 @@ private theorem read_mem_bytes_of_write_mem_bytes_subset_n2_eq_alt subst addr1 rw [read_mem_bytes_of_write_mem_bytes_same] · apply read_mem_bytes_of_write_mem_bytes_subset_n2_eq_alt_helper - simp [my_pow_2_gt_zero] · unfold my_pow; decide @[state_simp_rules] @@ -932,19 +891,16 @@ theorem read_mem_bytes_of_write_mem_bytes_subset (h0 : 0 < n1) (h1 : n1 <= 2^64) (h2 : 0 < n2) (h3 : n2 <= 2^64) (h4 : mem_subset addr2 (addr2 + (BitVec.ofNat 64 (n2 - 1))) addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))) (h5 : mem_legal addr2 (addr2 + (BitVec.ofNat 64 (n2 - 1)))) - (h6 : mem_legal addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))) - (h : ((BitVec.toNat (addr2 - addr1) + n2) * 8 - 1 - - BitVec.toNat (addr2 - addr1) * 8 + 1) - = n2 * 8) : + (h6 : mem_legal addr1 (addr1 + (BitVec.ofNat 64 (n1 - 1)))): read_mem_bytes n2 addr2 (write_mem_bytes n1 addr1 val s) = - BitVec.cast h - (extractLsb - ((((addr2 - addr1).toNat + n2) * 8) - 1) + (extractLsb' ((addr2 - addr1).toNat * 8) + (n2 * 8) val) := by by_cases h₀ : n2 = 2^64 case pos => - apply read_mem_bytes_of_write_mem_bytes_subset_n2_eq_alt h0 + apply read_mem_bytes_of_write_mem_bytes_subset_n2_eq_alt + · exact h0 · unfold my_pow; exact h1 · exact h2 · unfold my_pow; exact h₀ @@ -988,8 +944,8 @@ private theorem write_mem_bytes_irrelevant_helper (h : n * 8 + 8 = (n + 1) * 8) done private theorem extract_byte_of_read_mem_bytes_succ (n : Nat) : - extractLsb 7 0 (read_mem_bytes (n + 1) addr s) = read_mem addr s := by - simp only [read_mem_bytes, Nat.add_eq, Nat.add_zero, toNat_eq, extractLsb_toNat, + extractLsb' 0 8 (read_mem_bytes (n + 1) addr s) = read_mem addr s := by + simp only [read_mem_bytes, Nat.add_eq, Nat.add_zero, toNat_eq, extractLsb'_toNat, toNat_cast, toNat_append, Nat.shiftRight_zero, Nat.reduceAdd] generalize read_mem addr s = y generalize (read_mem_bytes n (addr + 1#64) s) = x diff --git a/Arm/State.lean b/Arm/State.lean index aaead5a1..d0b8c27a 100644 --- a/Arm/State.lean +++ b/Arm/State.lean @@ -962,7 +962,7 @@ def write_bytes (n : Nat) (addr : BitVec 64) match n with | 0 => m | n' + 1 => - let byte := BitVec.extractLsb 7 0 val + let byte := BitVec.extractLsb' 0 8 val let m := m.write addr byte let val_rest := BitVec.zeroExtend (n' * 8) (val >>> 8) m.write_bytes n' (addr + 1#64) val_rest @@ -988,7 +988,7 @@ and then recursing to write the rest. -/ theorem write_bytes_succ {mem : Memory} : mem.write_bytes (n + 1) addr val = - let byte := BitVec.extractLsb 7 0 val + let byte := BitVec.extractLsb' 0 8 val let mem := mem.write addr byte let val_rest := BitVec.zeroExtend (n * 8) (val >>> 8) mem.write_bytes n (addr + 1#64) val_rest := rfl