Skip to content

Commit

Permalink
Prove three shift lemmas
Browse files Browse the repository at this point in the history
  • Loading branch information
pennyannn committed Sep 19, 2024
1 parent dcfdac3 commit 52cf70e
Showing 1 changed file with 157 additions and 74 deletions.
231 changes: 157 additions & 74 deletions Arm/Insts/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def rev_elems (n esize : Nat) (x : BitVec n) (h₀ : esize ∣ n) (h₁ : 0 < es
BitVec.cast h3 (element ++ rest_ans)
termination_by n

example : rev_elems 4 4 0xA#4 (by decide) (by decide) = 0xA#4 := by
example : rev_elems 4 4 0xA#4 (by decide) (by decide) = 0xA#4 := by
native_decide
example : rev_elems 8 4 0xAB#8 (by decide) (by decide) = 0xBA#8 := by native_decide
example : rev_elems 8 4 (rev_elems 8 4 0xAB#8 (by decide) (by decide))
Expand Down Expand Up @@ -678,69 +678,26 @@ def shift_right_common_aux
termination_by (info.elements - e)


@[app_unexpander BitVec.ofNat] def unexpandBitVecOfNatToHex : Lean.PrettyPrinter.Unexpander
| `($(_) $n:num $i:num) =>
let i' := i.getNat
let n' := n.getNat
let trimmed_hex := -- Remove leading zeroes...
String.dropWhile (BitVec.ofNat n' i').toHex
(fun c => c = '0')
-- ... but keep one if the literal is all zeros.
let trimmed_hex := if trimmed_hex.isEmpty then "0" else trimmed_hex
let bv := Lean.Syntax.mkNumLit s!"0x{trimmed_hex}#{n'}"
`($bv:num)
| _ => throw ()
-- @[app_unexpander BitVec.ofNat] def unexpandBitVecOfNatToHex : Lean.PrettyPrinter.Unexpander
-- | `($(_) $n:num $i:num) =>
-- let i' := i.getNat
-- let n' := n.getNat
-- let trimmed_hex := -- Remove leading zeroes...
-- String.dropWhile (BitVec.ofNat n' i').toHex
-- (fun c => c = '0')
-- -- ... but keep one if the literal is all zeros.
-- let trimmed_hex := if trimmed_hex.isEmpty then "0" else trimmed_hex
-- let bv := Lean.Syntax.mkNumLit s!"0x{trimmed_hex}#{n'}"
-- `($bv:num)
-- | _ => throw ()

-- FIXME: should this be upstreamed?
theorem shift_le (x : Nat) (shift :Nat) :
x >>> shift ≤ x := by
simp only [Nat.shiftRight_eq_div_pow]
exact Nat.div_le_self x (2 ^ shift)

theorem crock1 (x : BitVec 64):
BitVec.ofInt 65 (Int.ofNat x.toNat) = zeroExtend 65 x := by
simp only [Int.ofNat_eq_coe, ofInt_natCast, ofNat_toNat]

#print BitVec.ofNatLt

theorem crock4 (x : BitVec n) (s : Nat) :
x.ushiftRight s = BitVec.ofNat n (x.toNat / (2 ^ s)) := by
sorry

theorem crock5 (x : BitVec 64) : (zeroExtend 65 x).toNat = x.toNat := by
simp only [toNat_truncate, Nat.reducePow]
omega

theorem crock3 (x : BitVec 64) (shift : Nat) :
extractLsb 63 0 ((zeroExtend 65 x).ushiftRight shift) = x.ushiftRight shift := by
simp only [crock4, extractLsb, extractLsb', Nat.sub_zero,
Nat.reduceAdd, Nat.shiftRight_zero]
have h:
(BitVec.ofNat 65 ((zeroExtend 65 x).toNat / 2 ^ shift)).toNat =
(x.toNat / 2 ^ shift)
:= by
rw [crock5]
simp only [toNat_ofNat]
refine Nat.mod_eq_of_lt ?h
refine Nat.div_lt_of_lt_mul ?h.h
have h : x.toNat < 2^64 := by exact isLt x
apply Nat.lt_trans h
have h1 : 2 ^ shift >= 1 := by exact Nat.one_le_two_pow
generalize 2 ^ shift = x at *
omega
exact congrArg (BitVec.ofNat 64) h

theorem crock2 (shift : Nat) (result : BitVec 128)
(x : BitVec 64) (y : BitVec 64)
: (result &&& 0xffffffffffffffff0000000000000000#128 |||
zeroExtend 128 (extractLsb 63 0 ((zeroExtend 65 x).ushiftRight shift))) &&&
0xffffffffffffffff#128 |||
zeroExtend 128 (extractLsb 63 0 ((zeroExtend 65 y).ushiftRight shift)) <<< 64 =
y.ushiftRight shift ++ x.ushiftRight shift
:= by
simp only [crock3]
generalize x.ushiftRight shift = x
generalize y.ushiftRight shift = y
bv_decide

-- (BitVec.cast ⋯ (extractLsb 63 0 operand)).toNat
-- --> (operand.toNat % 18446744073709551616)
-- extractLsb_toNat is the lemma that turned extractLsb into mod operation in Nat
-- This makes it hard to use bv_decide
@[state_simp_rules]
theorem shift_right_common_aux_64_2_tff (operand : BitVec 128)
(shift : Nat) (result : BitVec 128):
shift_right_common_aux 0
Expand All @@ -759,6 +716,7 @@ theorem shift_right_common_aux_64_2_tff (operand : BitVec 128)
simp only [-- -extractLsb_toNat,
state_simp_rules,
minimal_theory,
-- FIXME: simply using bitvec_rules will expand out extractLsb and truncate
-- bitvec_rules,
BitVec.cast_eq,
Nat.shiftRight_zero,
Expand All @@ -781,12 +739,53 @@ theorem shift_right_common_aux_64_2_tff (operand : BitVec 128)
Nat.reduceSub,
Nat.one_mul,
reduceHShiftLeft,
-- ushiftRight_eq,
crock1
-- Eliminating casting functions
Int.ofNat_eq_coe, ofInt_natCast, ofNat_toNat
]
generalize (extractLsb 63 0 operand) = x
generalize (extractLsb 127 64 operand) = y
apply crock2
generalize (extractLsb 127 64 operand) = x; simp at x
generalize (extractLsb 63 0 operand) = y; simp at y
have h0 : ∀ (z : BitVec 64), extractLsb 63 0 ((zeroExtend 65 z).ushiftRight shift)
= z.ushiftRight shift := by
intro z
simp only [ushiftRight, toNat_truncate]
have h1: z.toNat % 2 ^ 65 = z.toNat := by omega
simp only [h1]
simp only [Std.Tactic.BVDecide.Normalize.BitVec.ofNatLt_reduce]
simp only [Nat.sub_zero, Nat.reduceAdd, BitVec.extractLsb_ofNat, Nat.shiftRight_zero]
have h2 : z.toNat >>> shift % 2 ^ 65 = z.toNat >>> shift := by
refine Nat.mod_eq_of_lt ?h3
have h4 : z.toNat >>> shift ≤ z.toNat := by exact shift_le z.toNat shift
omega
simp only [h2]
simp only [h0]
clear h0
generalize x.ushiftRight shift = p
generalize y.ushiftRight shift = q
-- FIXME: This proof can be simplified once bv_decide supports shift
-- operations with variable offsets
bv_decide

-- FIXME: where to put this?
theorem ofInt_eq_signExtend (x : BitVec 32) :
BitVec.ofInt 33 x.toInt = signExtend 33 x := by
exact rfl

-- FIXME: where to put this?
theorem msb_signExtend (x : BitVec w) (hw: w < w'):
(signExtend w' x).msb = x.msb := by
rcases w' with rfl | w'
· simp only [show w = 0 by omega,
msb_eq_getLsbD_last, Nat.zero_sub, Nat.le_refl,
getLsbD_ge]
· simp only [msb_eq_getLsbD_last, Nat.add_one_sub_one,
getLsbD_signExtend, Nat.lt_add_one,
decide_True, Bool.true_and, ite_eq_right_iff]
by_cases h : w' < w
· rcases w with rfl | w
· simp
· simp only [h, Nat.add_one_sub_one, true_implies]
omega
· simp [h]

theorem shift_right_common_aux_32_4_fff (operand : BitVec 128)
(shift : Nat) (result : BitVec 128):
Expand All @@ -795,10 +794,86 @@ theorem shift_right_common_aux_32_4_fff (operand : BitVec 128)
unsigned := false, round := false, accumulate := false,
h := (by omega) }
operand 0#128 result =
(sshiftRight (extractLsb' 96 32 operand) shift)
++ (sshiftRight (extractLsb' 64 32 operand) shift)
++ (sshiftRight (extractLsb' 32 32 operand) shift)
++ (sshiftRight (extractLsb' 0 32 operand) shift) := by sorry
(sshiftRight (extractLsb 127 96 operand) shift)
++ (sshiftRight (extractLsb 95 64 operand) shift)
++ (sshiftRight (extractLsb 63 32 operand) shift)
++ (sshiftRight (extractLsb 31 0 operand) shift) := by
unfold shift_right_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_right_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_right_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_right_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_right_common_aux
simp only [minimal_theory, bitvec_rules]
simp only [-- -extractLsb_toNat,
state_simp_rules,
minimal_theory,
-- FIXME: simply using bitvec_rules will expand out extractLsb and truncate
-- bitvec_rules,
BitVec.cast_eq,
Nat.shiftRight_zero,
Nat.zero_shiftRight,
Nat.reduceMul,
Nat.reduceAdd,
Nat.add_one_sub_one,
Nat.sub_zero,
reduceAllOnes,
reduceZeroExtend,
Nat.zero_mul,
shiftLeft_zero_eq,
reduceNot,
BitVec.extractLsb_ofNat,
Nat.reducePow,
Nat.zero_mod,
Int.ofNat_emod,
Int.Nat.cast_ofNat_Int,
BitVec.zero_add,
Nat.reduceSub,
Nat.one_mul,
reduceHShiftLeft,
-- Eliminating casting functions
ofInt_eq_signExtend
]
generalize extractLsb 31 0 operand = a; simp at a
generalize extractLsb 63 32 operand = b; simp at b
generalize extractLsb 95 64 operand = c; simp at c
generalize extractLsb 127 96 operand = d; simp at d
have h : ∀ (x : BitVec 32),
extractLsb 31 0 ((signExtend 33 x).sshiftRight shift)
= x.sshiftRight shift := by
intros x
apply eq_of_getLsbD_eq; intros i; simp at i
simp only [getLsbD_sshiftRight]
simp only [Nat.sub_zero, Nat.reduceAdd, getLsbD_extract, Nat.zero_add,
getLsbD_sshiftRight, getLsbD_signExtend]
simp only [show (i : Nat) ≤ 31 by omega,
decide_True, Bool.true_and]
simp only [show ¬33 ≤ (i : Nat) by omega,
decide_False, Bool.not_false, Bool.true_and]
simp only [show ¬32 ≤ (i : Nat) by omega,
decide_False, Bool.not_false, Bool.true_and]
by_cases h : s + (i : Nat) < 32
· simp only [h, reduceIte]
simp only [show s + (i : Nat) < 33 by omega,
↓reduceIte, decide_True, Bool.true_and]
· simp only [h, reduceIte]
have icases : s + (i : Nat) = 3232 < s + (i : Nat) := by omega
rcases icases with (h' | h')
· simp only [h', Nat.lt_add_one, ↓reduceIte, decide_True, Bool.true_and]
· simp only [show ¬(s + (i : Nat) < 33) by omega, ↓reduceIte]
apply msb_signExtend; trivial
simp only [h]
clear h
generalize a.sshiftRight shift = a
generalize b.sshiftRight shift = b
generalize c.sshiftRight shift = c
generalize d.sshiftRight shift = d
-- FIXME: This proof can be simplified once bv_decide supports shift
-- operations with variable offsets
bv_decide

@[state_simp_rules]
def shift_right_common
Expand Down Expand Up @@ -830,8 +905,16 @@ theorem shift_left_common_aux_64_2 (operand : BitVec 128)
unsigned := unsigned, round := round, accumulate := accumulate,
h := (by omega)}
operand result =
(extractLsb' 0 64 operand <<< shift)
++ (extractLsb' 64 64 operand <<< shift) := by sorry
(extractLsb 127 64 operand <<< shift)
++ (extractLsb 63 0 operand <<< shift) := by
unfold shift_left_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_left_common_aux
simp only [minimal_theory, bitvec_rules]
unfold shift_left_common_aux
simp only [minimal_theory, bitvec_rules]
simp only [state_simp_rules, minimal_theory, bitvec_rules]
bv_decide

@[state_simp_rules]
def shift_left_common
Expand Down

0 comments on commit 52cf70e

Please sign in to comment.