From 7db5140242f8425d81d3aadc8a83b790ce82ea6e Mon Sep 17 00:00:00 2001 From: Yan Peng Date: Mon, 16 Sep 2024 22:56:09 +0000 Subject: [PATCH] Prove three shift lemmas --- Arm/Insts/Common.lean | 188 +++++++++++++++++++++++++----------------- 1 file changed, 114 insertions(+), 74 deletions(-) diff --git a/Arm/Insts/Common.lean b/Arm/Insts/Common.lean index 170ac5df..5ca7ad4a 100644 --- a/Arm/Insts/Common.lean +++ b/Arm/Insts/Common.lean @@ -563,7 +563,7 @@ def rev_elems (n esize : Nat) (x : BitVec n) (h₀ : esize ∣ n) (h₁ : 0 < es BitVec.cast h3 (element ++ rest_ans) termination_by n -example : rev_elems 4 4 0xA#4 (by decide) (by decide) = 0xA#4 := by +example : rev_elems 4 4 0xA#4 (by decide) (by decide) = 0xA#4 := by native_decide example : rev_elems 8 4 0xAB#8 (by decide) (by decide) = 0xBA#8 := by native_decide example : rev_elems 8 4 (rev_elems 8 4 0xAB#8 (by decide) (by decide)) @@ -678,69 +678,26 @@ def shift_right_common_aux termination_by (info.elements - e) -@[app_unexpander BitVec.ofNat] def unexpandBitVecOfNatToHex : Lean.PrettyPrinter.Unexpander - | `($(_) $n:num $i:num) => - let i' := i.getNat - let n' := n.getNat - let trimmed_hex := -- Remove leading zeroes... - String.dropWhile (BitVec.ofNat n' i').toHex - (fun c => c = '0') - -- ... but keep one if the literal is all zeros. - let trimmed_hex := if trimmed_hex.isEmpty then "0" else trimmed_hex - let bv := Lean.Syntax.mkNumLit s!"0x{trimmed_hex}#{n'}" - `($bv:num) - | _ => throw () +-- @[app_unexpander BitVec.ofNat] def unexpandBitVecOfNatToHex : Lean.PrettyPrinter.Unexpander +-- | `($(_) $n:num $i:num) => +-- let i' := i.getNat +-- let n' := n.getNat +-- let trimmed_hex := -- Remove leading zeroes... +-- String.dropWhile (BitVec.ofNat n' i').toHex +-- (fun c => c = '0') +-- -- ... but keep one if the literal is all zeros. +-- let trimmed_hex := if trimmed_hex.isEmpty then "0" else trimmed_hex +-- let bv := Lean.Syntax.mkNumLit s!"0x{trimmed_hex}#{n'}" +-- `($bv:num) +-- | _ => throw () + +-- FIXME: should this be upstreamed? +theorem shift_le (x : Nat) (shift :Nat) : + x >>> shift ≤ x := by + simp only [Nat.shiftRight_eq_div_pow] + exact Nat.div_le_self x (2 ^ shift) -theorem crock1 (x : BitVec 64): - BitVec.ofInt 65 (Int.ofNat x.toNat) = zeroExtend 65 x := by - simp only [Int.ofNat_eq_coe, ofInt_natCast, ofNat_toNat] - -#print BitVec.ofNatLt - -theorem crock4 (x : BitVec n) (s : Nat) : - x.ushiftRight s = BitVec.ofNat n (x.toNat / (2 ^ s)) := by - sorry - -theorem crock5 (x : BitVec 64) : (zeroExtend 65 x).toNat = x.toNat := by - simp only [toNat_truncate, Nat.reducePow] - omega - -theorem crock3 (x : BitVec 64) (shift : Nat) : - extractLsb 63 0 ((zeroExtend 65 x).ushiftRight shift) = x.ushiftRight shift := by - simp only [crock4, extractLsb, extractLsb', Nat.sub_zero, - Nat.reduceAdd, Nat.shiftRight_zero] - have h: - (BitVec.ofNat 65 ((zeroExtend 65 x).toNat / 2 ^ shift)).toNat = - (x.toNat / 2 ^ shift) - := by - rw [crock5] - simp only [toNat_ofNat] - refine Nat.mod_eq_of_lt ?h - refine Nat.div_lt_of_lt_mul ?h.h - have h : x.toNat < 2^64 := by exact isLt x - apply Nat.lt_trans h - have h1 : 2 ^ shift >= 1 := by exact Nat.one_le_two_pow - generalize 2 ^ shift = x at * - omega - exact congrArg (BitVec.ofNat 64) h - -theorem crock2 (shift : Nat) (result : BitVec 128) - (x : BitVec 64) (y : BitVec 64) - : (result &&& 0xffffffffffffffff0000000000000000#128 ||| - zeroExtend 128 (extractLsb 63 0 ((zeroExtend 65 x).ushiftRight shift))) &&& - 0xffffffffffffffff#128 ||| - zeroExtend 128 (extractLsb 63 0 ((zeroExtend 65 y).ushiftRight shift)) <<< 64 = - y.ushiftRight shift ++ x.ushiftRight shift - := by - simp only [crock3] - generalize x.ushiftRight shift = x - generalize y.ushiftRight shift = y - bv_decide - --- (BitVec.cast ⋯ (extractLsb 63 0 operand)).toNat --- --> (operand.toNat % 18446744073709551616) --- extractLsb_toNat is the lemma that turned extractLsb into mod operation in Nat --- This makes it hard to use bv_decide +@[state_simp_rules] theorem shift_right_common_aux_64_2_tff (operand : BitVec 128) (shift : Nat) (result : BitVec 128): shift_right_common_aux 0 @@ -759,6 +716,7 @@ theorem shift_right_common_aux_64_2_tff (operand : BitVec 128) simp only [-- -extractLsb_toNat, state_simp_rules, minimal_theory, + -- FIXME: simply using bitvec_rules will expand out extractLsb and truncate -- bitvec_rules, BitVec.cast_eq, Nat.shiftRight_zero, @@ -781,12 +739,36 @@ theorem shift_right_common_aux_64_2_tff (operand : BitVec 128) Nat.reduceSub, Nat.one_mul, reduceHShiftLeft, - -- ushiftRight_eq, - crock1 + -- Eliminating casting functions + Int.ofNat_eq_coe, ofInt_natCast, ofNat_toNat ] - generalize (extractLsb 63 0 operand) = x - generalize (extractLsb 127 64 operand) = y - apply crock2 + generalize (extractLsb 127 64 operand) = x; simp at x + generalize (extractLsb 63 0 operand) = y; simp at y + have h0 : ∀ (z : BitVec 64), extractLsb 63 0 ((zeroExtend 65 z).ushiftRight shift) + = z.ushiftRight shift := by + intro z + simp only [ushiftRight, toNat_truncate] + have h1: z.toNat % 2 ^ 65 = z.toNat := by omega + simp only [h1] + simp only [Std.Tactic.BVDecide.Normalize.BitVec.ofNatLt_reduce] + simp only [Nat.sub_zero, Nat.reduceAdd, BitVec.extractLsb_ofNat, Nat.shiftRight_zero] + have h2 : z.toNat >>> shift % 2 ^ 65 = z.toNat >>> shift := by + refine Nat.mod_eq_of_lt ?h3 + have h4 : z.toNat >>> shift ≤ z.toNat := by exact shift_le z.toNat shift + omega + simp only [h2] + simp only [h0] + clear h0 + generalize x.ushiftRight shift = p + generalize y.ushiftRight shift = q + -- FIXME: This proof can be simplified once bv_decide supports shift + -- operations with variable offsets + bv_decide + +-- FIXME: where to put this? +theorem ofInt_eq_signExtend (x : BitVec 32) : + BitVec.ofInt 33 x.toInt = signExtend 33 x := by + exact rfl theorem shift_right_common_aux_32_4_fff (operand : BitVec 128) (shift : Nat) (result : BitVec 128): @@ -795,10 +777,68 @@ theorem shift_right_common_aux_32_4_fff (operand : BitVec 128) unsigned := false, round := false, accumulate := false, h := (by omega) } operand 0#128 result = - (sshiftRight (extractLsb' 96 32 operand) shift) - ++ (sshiftRight (extractLsb' 64 32 operand) shift) - ++ (sshiftRight (extractLsb' 32 32 operand) shift) - ++ (sshiftRight (extractLsb' 0 32 operand) shift) := by sorry + (sshiftRight (extractLsb 127 96 operand) shift) + ++ (sshiftRight (extractLsb 95 64 operand) shift) + ++ (sshiftRight (extractLsb 63 32 operand) shift) + ++ (sshiftRight (extractLsb 31 0 operand) shift) := by + unfold shift_right_common_aux + simp only [minimal_theory, bitvec_rules] + unfold shift_right_common_aux + simp only [minimal_theory, bitvec_rules] + unfold shift_right_common_aux + simp only [minimal_theory, bitvec_rules] + unfold shift_right_common_aux + simp only [minimal_theory, bitvec_rules] + unfold shift_right_common_aux + simp only [minimal_theory, bitvec_rules] + simp only [-- -extractLsb_toNat, + state_simp_rules, + minimal_theory, + -- FIXME: simply using bitvec_rules will expand out extractLsb and truncate + -- bitvec_rules, + BitVec.cast_eq, + Nat.shiftRight_zero, + Nat.zero_shiftRight, + Nat.reduceMul, + Nat.reduceAdd, + Nat.add_one_sub_one, + Nat.sub_zero, + reduceAllOnes, + reduceZeroExtend, + Nat.zero_mul, + shiftLeft_zero_eq, + reduceNot, + BitVec.extractLsb_ofNat, + Nat.reducePow, + Nat.zero_mod, + Int.ofNat_emod, + Int.Nat.cast_ofNat_Int, + BitVec.zero_add, + Nat.reduceSub, + Nat.one_mul, + reduceHShiftLeft, + -- Eliminating casting functions + ofInt_eq_signExtend + ] + generalize extractLsb 31 0 operand = a; simp at a + generalize extractLsb 63 32 operand = b; simp at b + generalize extractLsb 95 64 operand = c; simp at c + generalize extractLsb 127 96 operand = d; simp at d + have h : ∀ (x : BitVec 32), + extractLsb 31 0 ((signExtend 33 x).sshiftRight shift) + = x.sshiftRight shift := by + intro x + simp only [sshiftRight, signExtend] + sorry + simp only [h] + clear h + generalize a.sshiftRight shift = a + generalize b.sshiftRight shift = b + generalize c.sshiftRight shift = c + generalize d.sshiftRight shift = d + -- FIXME: This proof can be simplified once bv_decide supports shift + -- operations with variable offsets + bv_decide @[state_simp_rules] def shift_right_common @@ -830,8 +870,8 @@ theorem shift_left_common_aux_64_2 (operand : BitVec 128) unsigned := unsigned, round := round, accumulate := accumulate, h := (by omega)} operand result = - (extractLsb' 0 64 operand <<< shift) - ++ (extractLsb' 64 64 operand <<< shift) := by sorry + (extractLsb 127 64 operand <<< shift) + ++ (extractLsb 63 0 operand <<< shift) := by sorry @[state_simp_rules] def shift_left_common