Skip to content

Commit

Permalink
perf: optimize Nat.Linear.Expr.toPoly (#6062)
Browse files Browse the repository at this point in the history
  • Loading branch information
nomeata authored Nov 13, 2024
1 parent 6b811f8 commit 970261b
Showing 1 changed file with 30 additions and 14 deletions.
44 changes: 30 additions & 14 deletions src/Init/Data/Nat/Linear.lean
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,16 @@ def Poly.combineAux (fuel : Nat) (p₁ p₂ : Poly) : Poly :=
def Poly.combine (p₁ p₂ : Poly) : Poly :=
combineAux hugeFuel p₁ p₂

def Expr.toPoly : Expr → Poly
| Expr.num k => bif k == 0 then [] else [ (k, fixedVar) ]
| Expr.var i => [(1, i)]
| Expr.add a b => a.toPoly ++ b.toPoly
| Expr.mulL k a => a.toPoly.mul k
| Expr.mulR a k => a.toPoly.mul k
def Expr.toPoly (e : Expr) := go 1 e []
where
-- Implementation note: This assembles the result using difference lists
-- to avoid `++` on lists.
go (coeff : Nat) : Expr → (Poly → Poly)
| Expr.num k => bif k == 0 then id else ((coeff * k, fixedVar) :: ·)
| Expr.var i => ((coeff, i) :: ·)
| Expr.add a b => go coeff a ∘ go coeff b
| Expr.mulL k a
| Expr.mulR a k => bif k == 0 then id else go (coeff * k) a

def Poly.norm (p : Poly) : Poly :=
p.sort.fuse
Expand Down Expand Up @@ -516,13 +520,25 @@ theorem Poly.denote_combine (ctx : Context) (p₁ p₂ : Poly) : (p₁.combine p

attribute [local simp] Poly.denote_combine

theorem Expr.denote_toPoly_go (ctx : Context) (e : Expr) :
(toPoly.go k e p).denote ctx = k * e.denote ctx + p.denote ctx := by
induction k, e using Expr.toPoly.go.induct generalizing p with
| case1 k k' =>
simp only [toPoly.go]
by_cases h : k' == 0
· simp [h, eq_of_beq h]
· simp [h, Var.denote]
| case2 k i => simp [toPoly.go]
| case3 k a b iha ihb => simp [toPoly.go, iha, ihb]
| case4 k k' a ih
| case5 k a k' ih =>
simp only [toPoly.go, denote, mul_eq]
by_cases h : k' == 0
· simp [h, eq_of_beq h]
· simp [h, cond_false, ih, Nat.mul_assoc]

theorem Expr.denote_toPoly (ctx : Context) (e : Expr) : e.toPoly.denote ctx = e.denote ctx := by
induction e with
| num k => by_cases h : k == 0 <;> simp [toPoly, h, Var.denote]; simp [eq_of_beq h]
| var i => simp [toPoly]
| add a b iha ihb => simp [toPoly, iha, ihb]
| mulL k a ih => simp [toPoly, ih, -Poly.mul]
| mulR k a ih => simp [toPoly, ih, -Poly.mul]
simp [toPoly, Expr.denote_toPoly_go]

attribute [local simp] Expr.denote_toPoly

Expand Down Expand Up @@ -554,8 +570,8 @@ theorem ExprCnstr.denote_toPoly (ctx : Context) (c : ExprCnstr) : c.toPoly.denot
cases c; rename_i eq lhs rhs
simp [ExprCnstr.denote, PolyCnstr.denote, ExprCnstr.toPoly];
by_cases h : eq = true <;> simp [h]
· simp [Poly.denote_eq, Expr.toPoly]
· simp [Poly.denote_le, Expr.toPoly]
· simp [Poly.denote_eq]
· simp [Poly.denote_le]

attribute [local simp] ExprCnstr.denote_toPoly

Expand Down

0 comments on commit 970261b

Please sign in to comment.