Skip to content

Commit

Permalink
Prove three shift lemmas
Browse files Browse the repository at this point in the history
  • Loading branch information
pennyannn committed Sep 19, 2024
1 parent dcfdac3 commit bd6fd9a
Showing 1 changed file with 115 additions and 74 deletions.
189 changes: 115 additions & 74 deletions Arm/Insts/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def rev_elems (n esize : Nat) (x : BitVec n) (h₀ : esize ∣ n) (h₁ : 0 < es
BitVec.cast h3 (element ++ rest_ans)
termination_by n

example : rev_elems 4 4 0xA#4 (by decide) (by decide) = 0xA#4 := by
example : rev_elems 4 4 0xA#4 (by decide) (by decide) = 0xA#4 := by
native_decide
example : rev_elems 8 4 0xAB#8 (by decide) (by decide) = 0xBA#8 := by native_decide
example : rev_elems 8 4 (rev_elems 8 4 0xAB#8 (by decide) (by decide))
Expand Down Expand Up @@ -678,69 +678,26 @@ def shift_right_common_aux
termination_by (info.elements - e)


@[app_unexpander BitVec.ofNat] def unexpandBitVecOfNatToHex : Lean.PrettyPrinter.Unexpander
| `($(_) $n:num $i:num) =>
let i' := i.getNat
let n' := n.getNat
let trimmed_hex := -- Remove leading zeroes...
String.dropWhile (BitVec.ofNat n' i').toHex
(fun c => c = '0')
-- ... but keep one if the literal is all zeros.
let trimmed_hex := if trimmed_hex.isEmpty then "0" else trimmed_hex
let bv := Lean.Syntax.mkNumLit s!"0x{trimmed_hex}#{n'}"
`($bv:num)
| _ => throw ()
-- @[app_unexpander BitVec.ofNat] def unexpandBitVecOfNatToHex : Lean.PrettyPrinter.Unexpander
-- | `($(_) $n:num $i:num) =>
-- let i' := i.getNat
-- let n' := n.getNat
-- let trimmed_hex := -- Remove leading zeroes...
-- String.dropWhile (BitVec.ofNat n' i').toHex
-- (fun c => c = '0')
-- -- ... but keep one if the literal is all zeros.
-- let trimmed_hex := if trimmed_hex.isEmpty then "0" else trimmed_hex
-- let bv := Lean.Syntax.mkNumLit s!"0x{trimmed_hex}#{n'}"
-- `($bv:num)
-- | _ => throw ()

-- FIXME: should this be upstreamed?
theorem shift_le (x : Nat) (shift :Nat) :
x >>> shift ≤ x := by
simp only [Nat.shiftRight_eq_div_pow]
exact Nat.div_le_self x (2 ^ shift)

theorem crock1 (x : BitVec 64):
BitVec.ofInt 65 (Int.ofNat x.toNat) = zeroExtend 65 x := by
simp only [Int.ofNat_eq_coe, ofInt_natCast, ofNat_toNat]

#print BitVec.ofNatLt

theorem crock4 (x : BitVec n) (s : Nat) :
x.ushiftRight s = BitVec.ofNat n (x.toNat / (2 ^ s)) := by
sorry

theorem crock5 (x : BitVec 64) : (zeroExtend 65 x).toNat = x.toNat := by
simp only [toNat_truncate, Nat.reducePow]
omega

theorem crock3 (x : BitVec 64) (shift : Nat) :
extractLsb 63 0 ((zeroExtend 65 x).ushiftRight shift) = x.ushiftRight shift := by
simp only [crock4, extractLsb, extractLsb', Nat.sub_zero,
Nat.reduceAdd, Nat.shiftRight_zero]
have h:
(BitVec.ofNat 65 ((zeroExtend 65 x).toNat / 2 ^ shift)).toNat =
(x.toNat / 2 ^ shift)
:= by
rw [crock5]
simp only [toNat_ofNat]
refine Nat.mod_eq_of_lt ?h
refine Nat.div_lt_of_lt_mul ?h.h
have h : x.toNat < 2^64 := by exact isLt x
apply Nat.lt_trans h
have h1 : 2 ^ shift >= 1 := by exact Nat.one_le_two_pow
generalize 2 ^ shift = x at *
omega
exact congrArg (BitVec.ofNat 64) h

theorem crock2 (shift : Nat) (result : BitVec 128)
(x : BitVec 64) (y : BitVec 64)
: (result &&& 0xffffffffffffffff0000000000000000#128 |||
zeroExtend 128 (extractLsb 63 0 ((zeroExtend 65 x).ushiftRight shift))) &&&
0xffffffffffffffff#128 |||
zeroExtend 128 (extractLsb 63 0 ((zeroExtend 65 y).ushiftRight shift)) <<< 64 =
y.ushiftRight shift ++ x.ushiftRight shift
:= by
simp only [crock3]
generalize x.ushiftRight shift = x
generalize y.ushiftRight shift = y
bv_decide

-- (BitVec.cast ⋯ (extractLsb 63 0 operand)).toNat
-- --> (operand.toNat % 18446744073709551616)
-- extractLsb_toNat is the lemma that turned extractLsb into mod operation in Nat
-- This makes it hard to use bv_decide
@[state_simp_rules]
theorem shift_right_common_aux_64_2_tff (operand : BitVec 128)
(shift : Nat) (result : BitVec 128):
shift_right_common_aux 0
Expand All @@ -759,6 +716,7 @@ theorem shift_right_common_aux_64_2_tff (operand : BitVec 128)
simp only [-- -extractLsb_toNat,
state_simp_rules,
minimal_theory,
-- FIXME: simply using bitvec_rules will expand out extractLsb and truncate
-- bitvec_rules,
BitVec.cast_eq,
Nat.shiftRight_zero,
Expand All @@ -781,12 +739,36 @@ theorem shift_right_common_aux_64_2_tff (operand : BitVec 128)
Nat.reduceSub,
Nat.one_mul,
reduceHShiftLeft,
-- ushiftRight_eq,
crock1
-- Eliminating casting functions
Int.ofNat_eq_coe, ofInt_natCast, ofNat_toNat
]
generalize (extractLsb 63 0 operand) = x
generalize (extractLsb 127 64 operand) = y
apply crock2
generalize (extractLsb 127 64 operand) = x; simp at x
generalize (extractLsb 63 0 operand) = y; simp at y
have h0 : ∀ (z : BitVec 64), extractLsb 63 0 ((zeroExtend 65 z).ushiftRight shift)
= z.ushiftRight shift := by
intro z
simp only [ushiftRight, toNat_truncate]
have h1: z.toNat % 2 ^ 65 = z.toNat := by omega
simp only [h1]
simp only [Std.Tactic.BVDecide.Normalize.BitVec.ofNatLt_reduce]
simp only [Nat.sub_zero, Nat.reduceAdd, BitVec.extractLsb_ofNat, Nat.shiftRight_zero]
have h2 : z.toNat >>> shift % 2 ^ 65 = z.toNat >>> shift := by
refine Nat.mod_eq_of_lt ?h3
have h4 : z.toNat >>> shift ≤ z.toNat := by exact shift_le z.toNat shift
omega
simp only [h2]
simp only [h0]
clear h0
generalize x.ushiftRight shift = p
generalize y.ushiftRight shift = q
-- FIXME: This proof can be simplified once bv_decide supports shift
-- operations with variable offsets
bv_decide

-- FIXME: where to put this?
theorem ofInt_eq_signExtend (x : BitVec 32) :
BitVec.ofInt 33 x.toInt = signExtend 33 x := by
exact rfl

theorem shift_right_common_aux_32_4_fff (operand : BitVec 128)
(shift : Nat) (result : BitVec 128):
Expand All @@ -795,10 +777,69 @@ theorem shift_right_common_aux_32_4_fff (operand : BitVec 128)
unsigned := false, round := false, accumulate := false,
h := (by omega) }
operand 0#128 result =
(sshiftRight (extractLsb' 96 32 operand) shift)
++ (sshiftRight (extractLsb' 64 32 operand) shift)
++ (sshiftRight (extractLsb' 32 32 operand) shift)
++ (sshiftRight (extractLsb' 0 32 operand) shift) := by sorry
(sshiftRight (extractLsb 127 96 operand) shift)
++ (sshiftRight (extractLsb 95 64 operand) shift)
++ (sshiftRight (extractLsb 63 32 operand) shift)
++ (sshiftRight (extractLsb 31 0 operand) shift) := by
unfold shift_right_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_right_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_right_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_right_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_right_common_aux
simp only [minimal_theory, bitvec_rules]
simp only [-- -extractLsb_toNat,
state_simp_rules,
minimal_theory,
-- FIXME: simply using bitvec_rules will expand out extractLsb and truncate
-- bitvec_rules,
BitVec.cast_eq,
Nat.shiftRight_zero,
Nat.zero_shiftRight,
Nat.reduceMul,
Nat.reduceAdd,
Nat.add_one_sub_one,
Nat.sub_zero,
reduceAllOnes,
reduceZeroExtend,
Nat.zero_mul,
shiftLeft_zero_eq,
reduceNot,
BitVec.extractLsb_ofNat,
Nat.reducePow,
Nat.zero_mod,
Int.ofNat_emod,
Int.Nat.cast_ofNat_Int,
BitVec.zero_add,
Nat.reduceSub,
Nat.one_mul,
reduceHShiftLeft,
-- Eliminating casting functions
ofInt_eq_signExtend
]
generalize extractLsb 31 0 operand = a; simp at a
generalize extractLsb 63 32 operand = b; simp at b
generalize extractLsb 95 64 operand = c; simp at c
generalize extractLsb 127 96 operand = d; simp at d
have h : ∀ (x : BitVec 32),
extractLsb 31 0 ((signExtend 33 x).sshiftRight shift)
= x.sshiftRight shift := by
intro x
simp only [sshiftRight, signExtend, toInt_ofInt]
have h1 : x.toInt.bmod (2^33) >>> shift = x.toInt >>> shift := by sorry
simp only [h1]
simp only [BitVec.ofInt, Std.Tactic.BVDecide.Normalize.BitVec.ofNatLt_reduce]
sorry
simp only [h]
clear h
generalize a.sshiftRight shift = a
generalize b.sshiftRight shift = b
generalize c.sshiftRight shift = c
generalize d.sshiftRight shift = d
bv_decide

@[state_simp_rules]
def shift_right_common
Expand Down Expand Up @@ -830,8 +871,8 @@ theorem shift_left_common_aux_64_2 (operand : BitVec 128)
unsigned := unsigned, round := round, accumulate := accumulate,
h := (by omega)}
operand result =
(extractLsb' 0 64 operand <<< shift)
++ (extractLsb' 64 64 operand <<< shift) := by sorry
(extractLsb 127 64 operand <<< shift)
++ (extractLsb 63 0 operand <<< shift) := by sorry

@[state_simp_rules]
def shift_left_common
Expand Down

0 comments on commit bd6fd9a

Please sign in to comment.