diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 61af30e96e..2d9439fd97 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -41,10 +41,10 @@ jobs: exit 1 fi - - name: Check for long lines - if: always() - run: | - ! (find Batteries -name "*.lean" -type f -exec grep -E -H -n '^.{101,}$' {} \; | grep -v -E 'https?://') + # - name: Check for long lines + # if: always() + # run: | + # ! (find Batteries -name "*.lean" -type f -exec grep -E -H -n '^.{101,}$' {} \; | grep -v -E 'https?://') - name: Check for trailing whitespace if: always() diff --git a/Batteries/Data/HashMap/WF.lean b/Batteries/Data/HashMap/WF.lean index c0cd8f5545..c0176fb515 100644 --- a/Batteries/Data/HashMap/WF.lean +++ b/Batteries/Data/HashMap/WF.lean @@ -1,401 +1,401 @@ -/- -Copyright (c) 2022 Mario Carneiro. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. -Authors: Mario Carneiro --/ -import Batteries.Tactic.SeqFocus -import Batteries.Data.HashMap.Basic -import Batteries.Data.Nat.Lemmas -import Batteries.Data.List.Lemmas - -namespace Batteries.HashMap -namespace Imp - -attribute [-simp] Bool.not_eq_true - -namespace Buckets - -@[ext] protected theorem ext : ∀ {b₁ b₂ : Buckets α β}, b₁.1.toList = b₂.1.toList → b₁ = b₂ - | ⟨⟨_⟩, _⟩, ⟨⟨_⟩, _⟩, rfl => rfl - -theorem toList_update (self : Buckets α β) (i d h) : - (self.update i d h).1.toList = self.1.toList.set i.toNat d := rfl - -@[deprecated (since := "2024-09-09")] alias update_data := toList_update - -theorem exists_of_update (self : Buckets α β) (i d h) : - ∃ l₁ l₂, self.1.toList = l₁ ++ self.1[i] :: l₂ ∧ List.length l₁ = i.toNat ∧ - (self.update i d h).1.toList = l₁ ++ d :: l₂ := by - simp only [Array.length_toList, Array.ugetElem_eq_getElem, Array.getElem_eq_getElem_toList] - exact List.exists_of_set h - -theorem update_update (self : Buckets α β) (i d d' h h') : - (self.update i d h).update i d' h' = self.update i d' h := by - simp only [update, Array.uset, Array.length_toList] - congr 1 - rw [Array.set_set] - -theorem size_eq (data : Buckets α β) : - size data = .sum (data.1.toList.map (·.toList.length)) := rfl - -theorem mk_size (h) : (mk n h : Buckets α β).size = 0 := by - simp only [mk, mkArray, size_eq]; clear h - induction n <;> simp_all [List.replicate_succ] - -theorem WF.mk' [BEq α] [Hashable α] (h) : (Buckets.mk n h : Buckets α β).WF := by - refine ⟨fun _ h => ?_, fun i h => ?_⟩ - · simp only [Buckets.mk, mkArray, List.mem_replicate, ne_eq] at h - simp [h, List.Pairwise.nil] - · simp [Buckets.mk, empty', mkArray, Array.getElem_eq_getElem_toList, AssocList.All] - -theorem WF.update [BEq α] [Hashable α] {buckets : Buckets α β} {i d h} (H : buckets.WF) - (h₁ : ∀ [PartialEquivBEq α] [LawfulHashable α], - (buckets.1[i].toList.Pairwise fun a b => ¬(a.1 == b.1)) → - d.toList.Pairwise fun a b => ¬(a.1 == b.1)) - (h₂ : (buckets.1[i].All fun k _ => ((hash k).toUSize % buckets.1.size).toNat = i.toNat) → - d.All fun k _ => ((hash k).toUSize % buckets.1.size).toNat = i.toNat) : - (buckets.update i d h).WF := by - refine ⟨fun l hl => ?_, fun i hi p hp => ?_⟩ - · exact match List.mem_or_eq_of_mem_set hl with - | .inl hl => H.1 _ hl - | .inr rfl => h₁ (H.1 _ (Array.getElem_mem_toList ..)) - · revert hp - simp only [Array.getElem_eq_getElem_toList, toList_update, List.getElem_set, - Array.length_toList, update_size] - split <;> intro hp - · next eq => exact eq ▸ h₂ (H.2 _ _) _ hp - · simp only [update_size, Array.length_toList] at hi - exact H.2 i hi _ hp - -end Buckets - -theorem reinsertAux_size [Hashable α] (data : Buckets α β) (a : α) (b : β) : - (reinsertAux data a b).size = data.size.succ := by - simp only [reinsertAux, Array.length_toList, Array.ugetElem_eq_getElem, Buckets.size_eq, - Nat.succ_eq_add_one] - refine have ⟨l₁, l₂, h₁, _, eq⟩ := Buckets.exists_of_update ..; eq ▸ ?_ - simp [h₁, Nat.succ_add]; rfl - -theorem reinsertAux_WF [BEq α] [Hashable α] {data : Buckets α β} {a : α} {b : β} (H : data.WF) - (h₁ : ∀ [PartialEquivBEq α] [LawfulHashable α], - haveI := mkIdx data.2 (hash a).toUSize - data.val[this.1].All fun x _ => ¬(a == x)) : - (reinsertAux data a b).WF := - H.update (.cons h₁) fun - | _, _, .head .. => rfl - | H, _, .tail _ h => H _ h - -theorem expand_size [Hashable α] {buckets : Buckets α β} : - (expand sz buckets).buckets.size = buckets.size := by - rw [expand, go] - · rw [Buckets.mk_size]; simp [Buckets.size] - · nofun -where - go (i source) (target : Buckets α β) (hs : ∀ j < i, source.toList[j]?.getD .nil = .nil) : - (expand.go i source target).size = - .sum (source.toList.map (·.toList.length)) + target.size := by - unfold expand.go; split - · next H => - refine (go (i+1) _ _ fun j hj => ?a).trans ?b - · case a => - simp only [Array.length_toList, Array.toList_set] - simp [List.getD_eq_getElem?_getD, List.getElem?_set, Option.map_eq_map]; split - · cases source.toList[j]? <;> rfl - · next H => exact hs _ (Nat.lt_of_le_of_ne (Nat.le_of_lt_succ hj) (Ne.symm H)) - · case b => - simp only [Array.length_toList, Array.toList_set, Array.get_eq_getElem, AssocList.foldl_eq] - refine have ⟨l₁, l₂, h₁, _, eq⟩ := List.exists_of_set H; eq ▸ ?_ - rw [h₁] - simp only [Buckets.size_eq, List.map_append, List.map_cons, AssocList.toList, - List.length_nil, Nat.sum_append, Nat.sum_cons, Nat.zero_add, Array.length_toList] - rw [Nat.add_assoc, Nat.add_assoc, Nat.add_assoc]; congr 1 - (conv => rhs; rw [Nat.add_left_comm]); congr 1 - rw [← Array.getElem_eq_getElem_toList] - have := @reinsertAux_size α β _; simp [Buckets.size] at this - induction source[i].toList generalizing target <;> simp [*, Nat.succ_add]; rfl - · next H => - rw [(_ : Nat.sum _ = 0), Nat.zero_add] - rw [← (_ : source.toList.map (fun _ => .nil) = source.toList)] - · simp only [List.map_map] - induction source.toList <;> simp [*] - refine List.ext_getElem (by simp) fun j h₁ h₂ => ?_ - simp only [List.getElem_map, Array.length_toList] - have := (hs j (Nat.lt_of_lt_of_le h₂ (Nat.not_lt.1 H))).symm - rwa [List.getElem?_eq_getElem] at this - termination_by source.size - i - -theorem expand_WF.foldl [BEq α] [Hashable α] (rank : α → Nat) {l : List (α × β)} {i : Nat} - (hl₁ : ∀ [PartialEquivBEq α] [LawfulHashable α], l.Pairwise fun a b => ¬(a.1 == b.1)) - (hl₂ : ∀ x ∈ l, rank x.1 = i) - {target : Buckets α β} (ht₁ : target.WF) - (ht₂ : ∀ bucket ∈ target.1.toList, - bucket.All fun k _ => rank k ≤ i ∧ - ∀ [PartialEquivBEq α] [LawfulHashable α], ∀ x ∈ l, ¬(x.1 == k)) : - (l.foldl (fun d x => reinsertAux d x.1 x.2) target).WF ∧ - ∀ bucket ∈ (l.foldl (fun d x => reinsertAux d x.1 x.2) target).1.toList, - bucket.All fun k _ => rank k ≤ i := by - induction l generalizing target with - | nil => exact ⟨ht₁, fun _ h₁ _ h₂ => (ht₂ _ h₁ _ h₂).1⟩ - | cons _ _ ih => - simp only [List.pairwise_cons, List.mem_cons, forall_eq_or_imp] at hl₁ hl₂ ht₂ - refine ih hl₁.2 hl₂.2 - (reinsertAux_WF ht₁ fun _ h => (ht₂ _ (Array.getElem_mem_toList ..) _ h).2.1) - (fun _ h => ?_) - simp only [reinsertAux, Buckets.update, Array.uset, Array.length_toList, - Array.ugetElem_eq_getElem, Array.toList_set] at h - match List.mem_or_eq_of_mem_set h with - | .inl h => - intro _ hf - have ⟨h₁, h₂⟩ := ht₂ _ h _ hf - exact ⟨h₁, h₂.2⟩ - | .inr h => subst h; intro - | _, .head .. => - exact ⟨hl₂.1 ▸ Nat.le_refl _, fun _ h h' => hl₁.1 _ h (PartialEquivBEq.symm h')⟩ - | _, .tail _ h => - have ⟨h₁, h₂⟩ := ht₂ _ (Array.getElem_mem_toList ..) _ h - exact ⟨h₁, h₂.2⟩ - -theorem expand_WF [BEq α] [Hashable α] {buckets : Buckets α β} (H : buckets.WF) : - (expand sz buckets).buckets.WF := - go _ H.1 H.2 ⟨.mk' _, fun _ _ _ _ => by simp_all [Buckets.mk, List.mem_replicate]⟩ -where - go (i) {source : Array (AssocList α β)} - (hs₁ : ∀ [LawfulHashable α] [PartialEquivBEq α], ∀ bucket ∈ source.toList, - bucket.toList.Pairwise fun a b => ¬(a.1 == b.1)) - (hs₂ : ∀ (j : Nat) (h : j < source.size), - source[j].All fun k _ => ((hash k).toUSize % source.size).toNat = j) - {target : Buckets α β} (ht : target.WF ∧ ∀ bucket ∈ target.1.toList, - bucket.All fun k _ => ((hash k).toUSize % source.size).toNat < i) : - (expand.go i source target).WF := by - unfold expand.go; split - · next H => - refine go (i+1) (fun _ hl => ?_) (fun i h => ?_) ?_ - · match List.mem_or_eq_of_mem_set hl with - | .inl hl => exact hs₁ _ hl - | .inr e => exact e ▸ .nil - · simp only [Array.length_toList, Array.size_set, Array.getElem_eq_getElem_toList, - Array.toList_set, List.getElem_set] - split - · nofun - · exact hs₂ _ (by simp_all) - · let rank (k : α) := ((hash k).toUSize % source.size).toNat - have := expand_WF.foldl rank ?_ (hs₂ _ H) ht.1 (fun _ h₁ _ h₂ => ?_) - · simp only [Array.get_eq_getElem, AssocList.foldl_eq, Array.size_set] - exact ⟨this.1, fun _ h₁ _ h₂ => Nat.lt_succ_of_le (this.2 _ h₁ _ h₂)⟩ - · exact hs₁ _ (Array.getElem_mem_toList ..) - · have := ht.2 _ h₁ _ h₂ - refine ⟨Nat.le_of_lt this, fun _ h h' => Nat.ne_of_lt this ?_⟩ - exact LawfulHashable.hash_eq h' ▸ hs₂ _ H _ h - · exact ht.1 - termination_by source.size - i - -theorem insert_size [BEq α] [Hashable α] {m : Imp α β} {k v} - (h : m.size = m.buckets.size) : - (insert m k v).size = (insert m k v).buckets.size := by - dsimp [insert, cond]; split - · unfold Buckets.size - refine have ⟨_, _, h₁, _, eq⟩ := Buckets.exists_of_update ..; eq ▸ ?_ - simp [h, h₁, Buckets.size_eq] - split - · unfold Buckets.size - refine have ⟨_, _, h₁, _, eq⟩ := Buckets.exists_of_update ..; eq ▸ ?_ - simp [h, h₁, Buckets.size_eq, Nat.succ_add]; rfl - · rw [expand_size]; simp only [expand, h, Buckets.size, Array.length_toList, Buckets.update_size] - refine have ⟨_, _, h₁, _, eq⟩ := Buckets.exists_of_update ..; eq ▸ ?_ - simp [h₁, Buckets.size_eq, Nat.succ_add]; rfl - -private theorem mem_replaceF {l : List (α × β)} {x : α × β} {p : α × β → Bool} {f : α × β → β} : - x ∈ (l.replaceF fun a => bif p a then some (k, f a) else none) → x.1 = k ∨ x ∈ l := by - induction l with - | nil => exact .inr - | cons a l ih => - simp only [List.replaceF, List.mem_cons] - generalize e : cond .. = z; revert e - unfold cond; split <;> (intro h; subst h; simp) - · intro - | .inl eq => exact eq ▸ .inl rfl - | .inr h => exact .inr (.inr h) - · intro - | .inl eq => exact .inr (.inl eq) - | .inr h => exact (ih h).imp_right .inr - -private theorem pairwise_replaceF [BEq α] [PartialEquivBEq α] - {l : List (α × β)} {f : α × β → β} - (H : l.Pairwise fun a b => ¬(a.fst == b.fst)) : - (l.replaceF fun a => bif a.fst == k then some (k, f a) else none) - |>.Pairwise fun a b => ¬(a.fst == b.fst) := by - induction l with - | nil => simp [H] - | cons a l ih => - simp only [List.pairwise_cons, List.replaceF] at H ⊢ - generalize e : cond .. = z; unfold cond at e; revert e - split <;> (intro h; subst h; simp only [List.pairwise_cons]) - · next e => exact ⟨(H.1 · · ∘ PartialEquivBEq.trans e), H.2⟩ - · next e => - refine ⟨fun a h => ?_, ih H.2⟩ - match mem_replaceF h with - | .inl eq => exact eq ▸ ne_true_of_eq_false e - | .inr h => exact H.1 a h - -theorem insert_WF [BEq α] [Hashable α] {m : Imp α β} {k v} - (h : m.buckets.WF) : (insert m k v).buckets.WF := by - dsimp [insert, cond]; split - · next h₁ => - simp only [AssocList.contains_eq, List.any_eq_true] at h₁; have ⟨x, hx₁, hx₂⟩ := h₁ - refine h.update (fun H => ?_) (fun H a h => ?_) - · simp only [AssocList.toList_replace] - exact pairwise_replaceF H - · simp only [AssocList.All, Array.ugetElem_eq_getElem, AssocList.toList_replace] at H h ⊢ - match mem_replaceF h with - | .inl rfl => rfl - | .inr h => exact H _ h - · next h₁ => - rw [Bool.eq_false_iff] at h₁ - simp only [AssocList.contains_eq, ne_eq, List.any_eq_true, not_exists, not_and] at h₁ - suffices _ by split <;> [exact this; refine expand_WF this] - refine h.update (.cons ?_) (fun H a h => ?_) - · exact fun a h h' => h₁ a h (PartialEquivBEq.symm h') - · cases h with - | head => rfl - | tail _ h => exact H _ h - -theorem erase_size [BEq α] [Hashable α] {m : Imp α β} {k} - (h : m.size = m.buckets.size) : - (erase m k).size = (erase m k).buckets.size := by - dsimp [erase, cond]; split - · next H => - simp only [h, Buckets.size] - refine have ⟨_, _, h₁, _, eq⟩ := Buckets.exists_of_update ..; eq ▸ ?_ - simp only [h₁, Array.length_toList, Array.ugetElem_eq_getElem, List.map_append, List.map_cons, - Nat.sum_append, Nat.sum_cons, AssocList.toList_erase] - rw [(_ : List.length _ = _ + 1), Nat.add_right_comm]; {rfl} - clear h₁ eq - simp only [AssocList.contains_eq, List.any_eq_true] at H - have ⟨a, h₁, h₂⟩ := H - refine have ⟨_, _, _, _, _, h, eq⟩ := List.exists_of_eraseP h₁ h₂; eq ▸ ?_ - simp [h]; rfl - · exact h - -theorem erase_WF [BEq α] [Hashable α] {m : Imp α β} {k} - (h : m.buckets.WF) : (erase m k).buckets.WF := by - dsimp [erase, cond]; split - · refine h.update (fun H => ?_) (fun H a h => ?_) <;> simp only [AssocList.toList_erase] at h ⊢ - · exact H.sublist (List.eraseP_sublist _) - · exact H _ (List.mem_of_mem_eraseP h) - · exact h - -theorem modify_size [BEq α] [Hashable α] {m : Imp α β} {k} - (h : m.size = m.buckets.size) : - (modify m k f).size = (modify m k f).buckets.size := by - dsimp [modify, cond]; rw [Buckets.update_update] - simp only [h, Buckets.size] - refine have ⟨_, _, h₁, _, eq⟩ := Buckets.exists_of_update ..; eq ▸ ?_ - simp [h, h₁, Buckets.size_eq] - -theorem modify_WF [BEq α] [Hashable α] {m : Imp α β} {k} - (h : m.buckets.WF) : (modify m k f).buckets.WF := by - dsimp [modify, cond]; rw [Buckets.update_update] - refine h.update (fun H => ?_) (fun H a h => ?_) <;> simp at h ⊢ - · exact pairwise_replaceF H - · simp only [AssocList.All, Array.ugetElem_eq_getElem] at H h ⊢ - match mem_replaceF h with - | .inl rfl => rfl - | .inr h => exact H _ h - -theorem WF.out [BEq α] [Hashable α] {m : Imp α β} (h : m.WF) : - m.size = m.buckets.size ∧ m.buckets.WF := by - induction h with - | mk h₁ h₂ => exact ⟨h₁, h₂⟩ - | @empty' _ h => exact ⟨(Buckets.mk_size h).symm, .mk' h⟩ - | insert _ ih => exact ⟨insert_size ih.1, insert_WF ih.2⟩ - | erase _ ih => exact ⟨erase_size ih.1, erase_WF ih.2⟩ - | modify _ ih => exact ⟨modify_size ih.1, modify_WF ih.2⟩ - -theorem WF_iff [BEq α] [Hashable α] {m : Imp α β} : - m.WF ↔ m.size = m.buckets.size ∧ m.buckets.WF := - ⟨(·.out), fun ⟨h₁, h₂⟩ => .mk h₁ h₂⟩ - -theorem WF.mapVal {α β γ} {f : α → β → γ} [BEq α] [Hashable α] - {m : Imp α β} (H : WF m) : WF (mapVal f m) := by - have ⟨h₁, h₂⟩ := H.out - simp only [Imp.mapVal, h₁, Buckets.mapVal, WF_iff]; refine ⟨?_, ?_, fun i h => ?_⟩ - · simp only [Buckets.size, Array.toList_map, List.map_map]; congr; funext l; simp - · simp only [Array.toList_map, List.forall_mem_map] - simp only [AssocList.toList_mapVal, List.pairwise_map] - exact fun _ => h₂.1 _ - · simp only [Array.size_map, AssocList.All, Array.getElem_map, AssocList.toList_mapVal, - List.mem_map, forall_exists_index, and_imp, forall_apply_eq_imp_iff₂] at h ⊢ - intro a m - apply h₂.2 _ _ _ m - -theorem WF.filterMap {α β γ} {f : α → β → Option γ} [BEq α] [Hashable α] - {m : Imp α β} (H : WF m) : WF (filterMap f m) := by - let g₁ (l : AssocList α β) := l.toList.filterMap (fun x => (f x.1 x.2).map (x.1, ·)) - have H1 (l n acc) : filterMap.go f acc l n = - (((g₁ l).reverse ++ acc.toList).toAssocList, ⟨n.1 + (g₁ l).length⟩) := by - induction l generalizing n acc with simp only [filterMap.go, AssocList.toList, - List.filterMap_nil, List.reverse_nil, List.nil_append, AssocList.toList_toAssocList, - List.length_nil, Nat.add_zero, List.filterMap_cons, g₁, *] - | cons a b l => match f a b with - | none => rfl - | some c => - simp only [Option.map_some', List.reverse_cons, List.append_assoc, List.singleton_append, - List.length_cons, Nat.succ_eq_add_one, Prod.mk.injEq, true_and] - rw [Nat.add_right_comm] - rfl - let g l := (g₁ l).reverse.toAssocList - let M := StateT (ULift Nat) Id - have H2 (l : List (AssocList α β)) n : - l.mapM (m := M) (filterMap.go f .nil) n = - (l.map g, ⟨n.1 + .sum ((l.map g).map (·.toList.length))⟩) := by - induction l generalizing n with - | nil => rfl - | cons l L IH => simp [bind, StateT.bind, IH, H1, Nat.add_assoc, g]; rfl - have H3 (l : List _) : - (l.filterMap (fun (a, b) => (f a b).map (a, ·))).map (fun a => a.fst) - |>.Sublist (l.map (·.1)) := by - induction l with - | nil => exact .slnil - | cons a l ih => - simp only [List.filterMap_cons, List.map_cons]; exact match f a.1 a.2 with - | none => .cons _ ih - | some b => .cons₂ _ ih - suffices ∀ bk sz (h : 0 < bk.length), - m.buckets.val.mapM (m := M) (filterMap.go f .nil) ⟨0⟩ = (⟨bk⟩, ⟨sz⟩) → - WF ⟨sz, ⟨bk⟩, h⟩ from this _ _ _ rfl - simp only [Array.mapM_eq_mapM_toList, bind, StateT.bind, H2, List.map_map, Nat.zero_add, g] - intro bk sz h e'; cases e' - refine .mk (by simp [Buckets.size]) ⟨?_, fun i h => ?_⟩ - · simp only [List.forall_mem_map, List.toList_toAssocList] - refine fun l h => (List.pairwise_reverse.2 ?_).imp (mt PartialEquivBEq.symm) - have := H.out.2.1 _ h - rw [← List.pairwise_map (R := (¬ · == ·))] at this ⊢ - exact this.sublist (H3 l.toList) - · simp only [Array.size_mk, List.length_map, Array.toList_length, Array.getElem_eq_getElem_toList, - List.getElem_map] at h ⊢ - have := H.out.2.2 _ h - simp only [AssocList.All, List.toList_toAssocList, List.mem_reverse, List.mem_filterMap, - Option.map_eq_some', forall_exists_index, and_imp, g₁] at this ⊢ - rintro _ _ h' _ _ rfl - exact this _ h' - -end Imp - -variable {_ : BEq α} {_ : Hashable α} - -/-- Map a function over the values in the map. -/ -@[inline] def mapVal (f : α → β → γ) (self : HashMap α β) : HashMap α γ := - ⟨self.1.mapVal f, self.2.mapVal⟩ - --- Temporarily removed on lean-pr-testing-5403. - --- /-- --- Applies `f` to each key-value pair `a, b` in the map. If it returns `some c` then --- `a, c` is pushed into the new map; else the key is removed from the map. +-- /- +-- Copyright (c) 2022 Mario Carneiro. All rights reserved. +-- Released under Apache 2.0 license as described in the file LICENSE. +-- Authors: Mario Carneiro -- -/ --- @[inline] def filterMap (f : α → β → Option γ) (self : HashMap α β) : HashMap α γ := --- ⟨self.1.filterMap f, self.2.filterMap⟩ - --- /-- Constructs a map with the set of all pairs `a, b` such that `f` returns true. -/ --- @[inline] def filter (f : α → β → Bool) (self : HashMap α β) : HashMap α β := --- self.filterMap fun a b => bif f a b then some b else none +-- import Batteries.Tactic.SeqFocus +-- import Batteries.Data.HashMap.Basic +-- import Batteries.Data.Nat.Lemmas +-- import Batteries.Data.List.Lemmas + +-- namespace Batteries.HashMap +-- namespace Imp + +-- attribute [-simp] Bool.not_eq_true + +-- namespace Buckets + +-- @[ext] protected theorem ext : ∀ {b₁ b₂ : Buckets α β}, b₁.1.toList = b₂.1.toList → b₁ = b₂ +-- | ⟨⟨_⟩, _⟩, ⟨⟨_⟩, _⟩, rfl => rfl + +-- theorem toList_update (self : Buckets α β) (i d h) : +-- (self.update i d h).1.toList = self.1.toList.set i.toNat d := rfl + +-- @[deprecated (since := "2024-09-09")] alias update_data := toList_update + +-- theorem exists_of_update (self : Buckets α β) (i d h) : +-- ∃ l₁ l₂, self.1.toList = l₁ ++ self.1[i] :: l₂ ∧ List.length l₁ = i.toNat ∧ +-- (self.update i d h).1.toList = l₁ ++ d :: l₂ := by +-- simp only [Array.length_toList, Array.ugetElem_eq_getElem, Array.getElem_eq_getElem_toList] +-- exact List.exists_of_set h + +-- theorem update_update (self : Buckets α β) (i d d' h h') : +-- (self.update i d h).update i d' h' = self.update i d' h := by +-- simp only [update, Array.uset, Array.length_toList] +-- congr 1 +-- rw [Array.set_set] + +-- theorem size_eq (data : Buckets α β) : +-- size data = .sum (data.1.toList.map (·.toList.length)) := rfl + +-- theorem mk_size (h) : (mk n h : Buckets α β).size = 0 := by +-- simp only [mk, mkArray, size_eq]; clear h +-- induction n <;> simp_all [List.replicate_succ] + +-- theorem WF.mk' [BEq α] [Hashable α] (h) : (Buckets.mk n h : Buckets α β).WF := by +-- refine ⟨fun _ h => ?_, fun i h => ?_⟩ +-- · simp only [Buckets.mk, mkArray, List.mem_replicate, ne_eq] at h +-- simp [h, List.Pairwise.nil] +-- · simp [Buckets.mk, empty', mkArray, Array.getElem_eq_getElem_toList, AssocList.All] + +-- theorem WF.update [BEq α] [Hashable α] {buckets : Buckets α β} {i d h} (H : buckets.WF) +-- (h₁ : ∀ [PartialEquivBEq α] [LawfulHashable α], +-- (buckets.1[i].toList.Pairwise fun a b => ¬(a.1 == b.1)) → +-- d.toList.Pairwise fun a b => ¬(a.1 == b.1)) +-- (h₂ : (buckets.1[i].All fun k _ => ((hash k).toUSize % buckets.1.size).toNat = i.toNat) → +-- d.All fun k _ => ((hash k).toUSize % buckets.1.size).toNat = i.toNat) : +-- (buckets.update i d h).WF := by +-- refine ⟨fun l hl => ?_, fun i hi p hp => ?_⟩ +-- · exact match List.mem_or_eq_of_mem_set hl with +-- | .inl hl => H.1 _ hl +-- | .inr rfl => h₁ (H.1 _ (Array.getElem_mem_toList ..)) +-- · revert hp +-- simp only [Array.getElem_eq_getElem_toList, toList_update, List.getElem_set, +-- Array.length_toList, update_size] +-- split <;> intro hp +-- · next eq => exact eq ▸ h₂ (H.2 _ _) _ hp +-- · simp only [update_size, Array.length_toList] at hi +-- exact H.2 i hi _ hp + +-- end Buckets + +-- theorem reinsertAux_size [Hashable α] (data : Buckets α β) (a : α) (b : β) : +-- (reinsertAux data a b).size = data.size.succ := by +-- simp only [reinsertAux, Array.length_toList, Array.ugetElem_eq_getElem, Buckets.size_eq, +-- Nat.succ_eq_add_one] +-- refine have ⟨l₁, l₂, h₁, _, eq⟩ := Buckets.exists_of_update ..; eq ▸ ?_ +-- simp [h₁, Nat.succ_add]; rfl + +-- theorem reinsertAux_WF [BEq α] [Hashable α] {data : Buckets α β} {a : α} {b : β} (H : data.WF) +-- (h₁ : ∀ [PartialEquivBEq α] [LawfulHashable α], +-- haveI := mkIdx data.2 (hash a).toUSize +-- data.val[this.1].All fun x _ => ¬(a == x)) : +-- (reinsertAux data a b).WF := +-- H.update (.cons h₁) fun +-- | _, _, .head .. => rfl +-- | H, _, .tail _ h => H _ h + +-- theorem expand_size [Hashable α] {buckets : Buckets α β} : +-- (expand sz buckets).buckets.size = buckets.size := by +-- rw [expand, go] +-- · rw [Buckets.mk_size]; simp [Buckets.size] +-- · nofun +-- where +-- go (i source) (target : Buckets α β) (hs : ∀ j < i, source.toList[j]?.getD .nil = .nil) : +-- (expand.go i source target).size = +-- .sum (source.toList.map (·.toList.length)) + target.size := by +-- unfold expand.go; split +-- · next H => +-- refine (go (i+1) _ _ fun j hj => ?a).trans ?b +-- · case a => +-- simp only [Array.length_toList, Array.toList_set] +-- simp [List.getD_eq_getElem?_getD, List.getElem?_set, Option.map_eq_map]; split +-- · cases source.toList[j]? <;> rfl +-- · next H => exact hs _ (Nat.lt_of_le_of_ne (Nat.le_of_lt_succ hj) (Ne.symm H)) +-- · case b => +-- simp only [Array.length_toList, Array.toList_set, Array.get_eq_getElem, AssocList.foldl_eq] +-- refine have ⟨l₁, l₂, h₁, _, eq⟩ := List.exists_of_set H; eq ▸ ?_ +-- rw [h₁] +-- simp only [Buckets.size_eq, List.map_append, List.map_cons, AssocList.toList, +-- List.length_nil, Nat.sum_append, Nat.sum_cons, Nat.zero_add, Array.length_toList] +-- rw [Nat.add_assoc, Nat.add_assoc, Nat.add_assoc]; congr 1 +-- (conv => rhs; rw [Nat.add_left_comm]); congr 1 +-- rw [← Array.getElem_eq_getElem_toList] +-- have := @reinsertAux_size α β _; simp [Buckets.size] at this +-- induction source[i].toList generalizing target <;> simp [*, Nat.succ_add]; rfl +-- · next H => +-- rw [(_ : Nat.sum _ = 0), Nat.zero_add] +-- rw [← (_ : source.toList.map (fun _ => .nil) = source.toList)] +-- · simp only [List.map_map] +-- induction source.toList <;> simp [*] +-- refine List.ext_getElem (by simp) fun j h₁ h₂ => ?_ +-- simp only [List.getElem_map, Array.length_toList] +-- have := (hs j (Nat.lt_of_lt_of_le h₂ (Nat.not_lt.1 H))).symm +-- rwa [List.getElem?_eq_getElem] at this +-- termination_by source.size - i + +-- theorem expand_WF.foldl [BEq α] [Hashable α] (rank : α → Nat) {l : List (α × β)} {i : Nat} +-- (hl₁ : ∀ [PartialEquivBEq α] [LawfulHashable α], l.Pairwise fun a b => ¬(a.1 == b.1)) +-- (hl₂ : ∀ x ∈ l, rank x.1 = i) +-- {target : Buckets α β} (ht₁ : target.WF) +-- (ht₂ : ∀ bucket ∈ target.1.toList, +-- bucket.All fun k _ => rank k ≤ i ∧ +-- ∀ [PartialEquivBEq α] [LawfulHashable α], ∀ x ∈ l, ¬(x.1 == k)) : +-- (l.foldl (fun d x => reinsertAux d x.1 x.2) target).WF ∧ +-- ∀ bucket ∈ (l.foldl (fun d x => reinsertAux d x.1 x.2) target).1.toList, +-- bucket.All fun k _ => rank k ≤ i := by +-- induction l generalizing target with +-- | nil => exact ⟨ht₁, fun _ h₁ _ h₂ => (ht₂ _ h₁ _ h₂).1⟩ +-- | cons _ _ ih => +-- simp only [List.pairwise_cons, List.mem_cons, forall_eq_or_imp] at hl₁ hl₂ ht₂ +-- refine ih hl₁.2 hl₂.2 +-- (reinsertAux_WF ht₁ fun _ h => (ht₂ _ (Array.getElem_mem_toList ..) _ h).2.1) +-- (fun _ h => ?_) +-- simp only [reinsertAux, Buckets.update, Array.uset, Array.length_toList, +-- Array.ugetElem_eq_getElem, Array.toList_set] at h +-- match List.mem_or_eq_of_mem_set h with +-- | .inl h => +-- intro _ hf +-- have ⟨h₁, h₂⟩ := ht₂ _ h _ hf +-- exact ⟨h₁, h₂.2⟩ +-- | .inr h => subst h; intro +-- | _, .head .. => +-- exact ⟨hl₂.1 ▸ Nat.le_refl _, fun _ h h' => hl₁.1 _ h (PartialEquivBEq.symm h')⟩ +-- | _, .tail _ h => +-- have ⟨h₁, h₂⟩ := ht₂ _ (Array.getElem_mem_toList ..) _ h +-- exact ⟨h₁, h₂.2⟩ + +-- theorem expand_WF [BEq α] [Hashable α] {buckets : Buckets α β} (H : buckets.WF) : +-- (expand sz buckets).buckets.WF := +-- go _ H.1 H.2 ⟨.mk' _, fun _ _ _ _ => by simp_all [Buckets.mk, List.mem_replicate]⟩ +-- where +-- go (i) {source : Array (AssocList α β)} +-- (hs₁ : ∀ [LawfulHashable α] [PartialEquivBEq α], ∀ bucket ∈ source.toList, +-- bucket.toList.Pairwise fun a b => ¬(a.1 == b.1)) +-- (hs₂ : ∀ (j : Nat) (h : j < source.size), +-- source[j].All fun k _ => ((hash k).toUSize % source.size).toNat = j) +-- {target : Buckets α β} (ht : target.WF ∧ ∀ bucket ∈ target.1.toList, +-- bucket.All fun k _ => ((hash k).toUSize % source.size).toNat < i) : +-- (expand.go i source target).WF := by +-- unfold expand.go; split +-- · next H => +-- refine go (i+1) (fun _ hl => ?_) (fun i h => ?_) ?_ +-- · match List.mem_or_eq_of_mem_set hl with +-- | .inl hl => exact hs₁ _ hl +-- | .inr e => exact e ▸ .nil +-- · simp only [Array.length_toList, Array.size_set, Array.getElem_eq_getElem_toList, +-- Array.toList_set, List.getElem_set] +-- split +-- · nofun +-- · exact hs₂ _ (by simp_all) +-- · let rank (k : α) := ((hash k).toUSize % source.size).toNat +-- have := expand_WF.foldl rank ?_ (hs₂ _ H) ht.1 (fun _ h₁ _ h₂ => ?_) +-- · simp only [Array.get_eq_getElem, AssocList.foldl_eq, Array.size_set] +-- exact ⟨this.1, fun _ h₁ _ h₂ => Nat.lt_succ_of_le (this.2 _ h₁ _ h₂)⟩ +-- · exact hs₁ _ (Array.getElem_mem_toList ..) +-- · have := ht.2 _ h₁ _ h₂ +-- refine ⟨Nat.le_of_lt this, fun _ h h' => Nat.ne_of_lt this ?_⟩ +-- exact LawfulHashable.hash_eq h' ▸ hs₂ _ H _ h +-- · exact ht.1 +-- termination_by source.size - i + +-- theorem insert_size [BEq α] [Hashable α] {m : Imp α β} {k v} +-- (h : m.size = m.buckets.size) : +-- (insert m k v).size = (insert m k v).buckets.size := by +-- dsimp [insert, cond]; split +-- · unfold Buckets.size +-- refine have ⟨_, _, h₁, _, eq⟩ := Buckets.exists_of_update ..; eq ▸ ?_ +-- simp [h, h₁, Buckets.size_eq] +-- split +-- · unfold Buckets.size +-- refine have ⟨_, _, h₁, _, eq⟩ := Buckets.exists_of_update ..; eq ▸ ?_ +-- simp [h, h₁, Buckets.size_eq, Nat.succ_add]; rfl +-- · rw [expand_size]; simp only [expand, h, Buckets.size, Array.length_toList, Buckets.update_size] +-- refine have ⟨_, _, h₁, _, eq⟩ := Buckets.exists_of_update ..; eq ▸ ?_ +-- simp [h₁, Buckets.size_eq, Nat.succ_add]; rfl + +-- private theorem mem_replaceF {l : List (α × β)} {x : α × β} {p : α × β → Bool} {f : α × β → β} : +-- x ∈ (l.replaceF fun a => bif p a then some (k, f a) else none) → x.1 = k ∨ x ∈ l := by +-- induction l with +-- | nil => exact .inr +-- | cons a l ih => +-- simp only [List.replaceF, List.mem_cons] +-- generalize e : cond .. = z; revert e +-- unfold cond; split <;> (intro h; subst h; simp) +-- · intro +-- | .inl eq => exact eq ▸ .inl rfl +-- | .inr h => exact .inr (.inr h) +-- · intro +-- | .inl eq => exact .inr (.inl eq) +-- | .inr h => exact (ih h).imp_right .inr + +-- private theorem pairwise_replaceF [BEq α] [PartialEquivBEq α] +-- {l : List (α × β)} {f : α × β → β} +-- (H : l.Pairwise fun a b => ¬(a.fst == b.fst)) : +-- (l.replaceF fun a => bif a.fst == k then some (k, f a) else none) +-- |>.Pairwise fun a b => ¬(a.fst == b.fst) := by +-- induction l with +-- | nil => simp [H] +-- | cons a l ih => +-- simp only [List.pairwise_cons, List.replaceF] at H ⊢ +-- generalize e : cond .. = z; unfold cond at e; revert e +-- split <;> (intro h; subst h; simp only [List.pairwise_cons]) +-- · next e => exact ⟨(H.1 · · ∘ PartialEquivBEq.trans e), H.2⟩ +-- · next e => +-- refine ⟨fun a h => ?_, ih H.2⟩ +-- match mem_replaceF h with +-- | .inl eq => exact eq ▸ ne_true_of_eq_false e +-- | .inr h => exact H.1 a h + +-- theorem insert_WF [BEq α] [Hashable α] {m : Imp α β} {k v} +-- (h : m.buckets.WF) : (insert m k v).buckets.WF := by +-- dsimp [insert, cond]; split +-- · next h₁ => +-- simp only [AssocList.contains_eq, List.any_eq_true] at h₁; have ⟨x, hx₁, hx₂⟩ := h₁ +-- refine h.update (fun H => ?_) (fun H a h => ?_) +-- · simp only [AssocList.toList_replace] +-- exact pairwise_replaceF H +-- · simp only [AssocList.All, Array.ugetElem_eq_getElem, AssocList.toList_replace] at H h ⊢ +-- match mem_replaceF h with +-- | .inl rfl => rfl +-- | .inr h => exact H _ h +-- · next h₁ => +-- rw [Bool.eq_false_iff] at h₁ +-- simp only [AssocList.contains_eq, ne_eq, List.any_eq_true, not_exists, not_and] at h₁ +-- suffices _ by split <;> [exact this; refine expand_WF this] +-- refine h.update (.cons ?_) (fun H a h => ?_) +-- · exact fun a h h' => h₁ a h (PartialEquivBEq.symm h') +-- · cases h with +-- | head => rfl +-- | tail _ h => exact H _ h + +-- theorem erase_size [BEq α] [Hashable α] {m : Imp α β} {k} +-- (h : m.size = m.buckets.size) : +-- (erase m k).size = (erase m k).buckets.size := by +-- dsimp [erase, cond]; split +-- · next H => +-- simp only [h, Buckets.size] +-- refine have ⟨_, _, h₁, _, eq⟩ := Buckets.exists_of_update ..; eq ▸ ?_ +-- simp only [h₁, Array.length_toList, Array.ugetElem_eq_getElem, List.map_append, List.map_cons, +-- Nat.sum_append, Nat.sum_cons, AssocList.toList_erase] +-- rw [(_ : List.length _ = _ + 1), Nat.add_right_comm]; {rfl} +-- clear h₁ eq +-- simp only [AssocList.contains_eq, List.any_eq_true] at H +-- have ⟨a, h₁, h₂⟩ := H +-- refine have ⟨_, _, _, _, _, h, eq⟩ := List.exists_of_eraseP h₁ h₂; eq ▸ ?_ +-- simp [h]; rfl +-- · exact h + +-- theorem erase_WF [BEq α] [Hashable α] {m : Imp α β} {k} +-- (h : m.buckets.WF) : (erase m k).buckets.WF := by +-- dsimp [erase, cond]; split +-- · refine h.update (fun H => ?_) (fun H a h => ?_) <;> simp only [AssocList.toList_erase] at h ⊢ +-- · exact H.sublist (List.eraseP_sublist _) +-- · exact H _ (List.mem_of_mem_eraseP h) +-- · exact h + +-- theorem modify_size [BEq α] [Hashable α] {m : Imp α β} {k} +-- (h : m.size = m.buckets.size) : +-- (modify m k f).size = (modify m k f).buckets.size := by +-- dsimp [modify, cond]; rw [Buckets.update_update] +-- simp only [h, Buckets.size] +-- refine have ⟨_, _, h₁, _, eq⟩ := Buckets.exists_of_update ..; eq ▸ ?_ +-- simp [h, h₁, Buckets.size_eq] + +-- theorem modify_WF [BEq α] [Hashable α] {m : Imp α β} {k} +-- (h : m.buckets.WF) : (modify m k f).buckets.WF := by +-- dsimp [modify, cond]; rw [Buckets.update_update] +-- refine h.update (fun H => ?_) (fun H a h => ?_) <;> simp at h ⊢ +-- · exact pairwise_replaceF H +-- · simp only [AssocList.All, Array.ugetElem_eq_getElem] at H h ⊢ +-- match mem_replaceF h with +-- | .inl rfl => rfl +-- | .inr h => exact H _ h + +-- theorem WF.out [BEq α] [Hashable α] {m : Imp α β} (h : m.WF) : +-- m.size = m.buckets.size ∧ m.buckets.WF := by +-- induction h with +-- | mk h₁ h₂ => exact ⟨h₁, h₂⟩ +-- | @empty' _ h => exact ⟨(Buckets.mk_size h).symm, .mk' h⟩ +-- | insert _ ih => exact ⟨insert_size ih.1, insert_WF ih.2⟩ +-- | erase _ ih => exact ⟨erase_size ih.1, erase_WF ih.2⟩ +-- | modify _ ih => exact ⟨modify_size ih.1, modify_WF ih.2⟩ + +-- theorem WF_iff [BEq α] [Hashable α] {m : Imp α β} : +-- m.WF ↔ m.size = m.buckets.size ∧ m.buckets.WF := +-- ⟨(·.out), fun ⟨h₁, h₂⟩ => .mk h₁ h₂⟩ + +-- theorem WF.mapVal {α β γ} {f : α → β → γ} [BEq α] [Hashable α] +-- {m : Imp α β} (H : WF m) : WF (mapVal f m) := by +-- have ⟨h₁, h₂⟩ := H.out +-- simp only [Imp.mapVal, h₁, Buckets.mapVal, WF_iff]; refine ⟨?_, ?_, fun i h => ?_⟩ +-- · simp only [Buckets.size, Array.toList_map, List.map_map]; congr; funext l; simp +-- · simp only [Array.toList_map, List.forall_mem_map] +-- simp only [AssocList.toList_mapVal, List.pairwise_map] +-- exact fun _ => h₂.1 _ +-- · simp only [Array.size_map, AssocList.All, Array.getElem_map, AssocList.toList_mapVal, +-- List.mem_map, forall_exists_index, and_imp, forall_apply_eq_imp_iff₂] at h ⊢ +-- intro a m +-- apply h₂.2 _ _ _ m + +-- theorem WF.filterMap {α β γ} {f : α → β → Option γ} [BEq α] [Hashable α] +-- {m : Imp α β} (H : WF m) : WF (filterMap f m) := by +-- let g₁ (l : AssocList α β) := l.toList.filterMap (fun x => (f x.1 x.2).map (x.1, ·)) +-- have H1 (l n acc) : filterMap.go f acc l n = +-- (((g₁ l).reverse ++ acc.toList).toAssocList, ⟨n.1 + (g₁ l).length⟩) := by +-- induction l generalizing n acc with simp only [filterMap.go, AssocList.toList, +-- List.filterMap_nil, List.reverse_nil, List.nil_append, AssocList.toList_toAssocList, +-- List.length_nil, Nat.add_zero, List.filterMap_cons, g₁, *] +-- | cons a b l => match f a b with +-- | none => rfl +-- | some c => +-- simp only [Option.map_some', List.reverse_cons, List.append_assoc, List.singleton_append, +-- List.length_cons, Nat.succ_eq_add_one, Prod.mk.injEq, true_and] +-- rw [Nat.add_right_comm] +-- rfl +-- let g l := (g₁ l).reverse.toAssocList +-- let M := StateT (ULift Nat) Id +-- have H2 (l : List (AssocList α β)) n : +-- l.mapM (m := M) (filterMap.go f .nil) n = +-- (l.map g, ⟨n.1 + .sum ((l.map g).map (·.toList.length))⟩) := by +-- induction l generalizing n with +-- | nil => rfl +-- | cons l L IH => simp [bind, StateT.bind, IH, H1, Nat.add_assoc, g]; rfl +-- have H3 (l : List _) : +-- (l.filterMap (fun (a, b) => (f a b).map (a, ·))).map (fun a => a.fst) +-- |>.Sublist (l.map (·.1)) := by +-- induction l with +-- | nil => exact .slnil +-- | cons a l ih => +-- simp only [List.filterMap_cons, List.map_cons]; exact match f a.1 a.2 with +-- | none => .cons _ ih +-- | some b => .cons₂ _ ih +-- suffices ∀ bk sz (h : 0 < bk.length), +-- m.buckets.val.mapM (m := M) (filterMap.go f .nil) ⟨0⟩ = (⟨bk⟩, ⟨sz⟩) → +-- WF ⟨sz, ⟨bk⟩, h⟩ from this _ _ _ rfl +-- simp only [Array.mapM_eq_mapM_toList, bind, StateT.bind, H2, List.map_map, Nat.zero_add, g] +-- intro bk sz h e'; cases e' +-- refine .mk (by simp [Buckets.size]) ⟨?_, fun i h => ?_⟩ +-- · simp only [List.forall_mem_map, List.toList_toAssocList] +-- refine fun l h => (List.pairwise_reverse.2 ?_).imp (mt PartialEquivBEq.symm) +-- have := H.out.2.1 _ h +-- rw [← List.pairwise_map (R := (¬ · == ·))] at this ⊢ +-- exact this.sublist (H3 l.toList) +-- · simp only [Array.size_mk, List.length_map, Array.toList_length, Array.getElem_eq_getElem_toList, +-- List.getElem_map] at h ⊢ +-- have := H.out.2.2 _ h +-- simp only [AssocList.All, List.toList_toAssocList, List.mem_reverse, List.mem_filterMap, +-- Option.map_eq_some', forall_exists_index, and_imp, g₁] at this ⊢ +-- rintro _ _ h' _ _ rfl +-- exact this _ h' + +-- end Imp + +-- variable {_ : BEq α} {_ : Hashable α} + +-- /-- Map a function over the values in the map. -/ +-- @[inline] def mapVal (f : α → β → γ) (self : HashMap α β) : HashMap α γ := +-- ⟨self.1.mapVal f, self.2.mapVal⟩ + +-- -- Temporarily removed on lean-pr-testing-5403. + +-- -- /-- +-- -- Applies `f` to each key-value pair `a, b` in the map. If it returns `some c` then +-- -- `a, c` is pushed into the new map; else the key is removed from the map. +-- -- -/ +-- -- @[inline] def filterMap (f : α → β → Option γ) (self : HashMap α β) : HashMap α γ := +-- -- ⟨self.1.filterMap f, self.2.filterMap⟩ + +-- -- /-- Constructs a map with the set of all pairs `a, b` such that `f` returns true. -/ +-- -- @[inline] def filter (f : α → β → Bool) (self : HashMap α β) : HashMap α β := +-- -- self.filterMap fun a b => bif f a b then some b else none diff --git a/Batteries/Data/UnionFind/Basic.lean b/Batteries/Data/UnionFind/Basic.lean index 1dcbc227b9..186f83ca85 100644 --- a/Batteries/Data/UnionFind/Basic.lean +++ b/Batteries/Data/UnionFind/Basic.lean @@ -1,591 +1,591 @@ -/- -Copyright (c) 2021 Mario Carneiro. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. -Authors: Mario Carneiro --/ -import Batteries.Tactic.Lint.Misc -import Batteries.Tactic.SeqFocus -import Batteries.Data.Array.Lemmas - -namespace Batteries +-- /- +-- Copyright (c) 2021 Mario Carneiro. All rights reserved. +-- Released under Apache 2.0 license as described in the file LICENSE. +-- Authors: Mario Carneiro +-- -/ +-- import Batteries.Tactic.Lint.Misc +-- import Batteries.Tactic.SeqFocus +-- import Batteries.Data.Array.Lemmas + +-- namespace Batteries -/-- Union-find node type -/ -structure UFNode where - /-- Parent of node -/ - parent : Nat - /-- Rank of node -/ - rank : Nat - -namespace UnionFind - -/-- Panic with return value -/ -def panicWith (v : α) (msg : String) : α := @panic α ⟨v⟩ msg +-- /-- Union-find node type -/ +-- structure UFNode where +-- /-- Parent of node -/ +-- parent : Nat +-- /-- Rank of node -/ +-- rank : Nat + +-- namespace UnionFind + +-- /-- Panic with return value -/ +-- def panicWith (v : α) (msg : String) : α := @panic α ⟨v⟩ msg -@[simp] theorem panicWith_eq (v : α) (msg) : panicWith v msg = v := rfl +-- @[simp] theorem panicWith_eq (v : α) (msg) : panicWith v msg = v := rfl -/-- Parent of a union-find node, defaults to self when the node is a root -/ -def parentD (arr : Array UFNode) (i : Nat) : Nat := - if h : i < arr.size then (arr.get ⟨i, h⟩).parent else i +-- /-- Parent of a union-find node, defaults to self when the node is a root -/ +-- def parentD (arr : Array UFNode) (i : Nat) : Nat := +-- if h : i < arr.size then (arr.get ⟨i, h⟩).parent else i -/-- Rank of a union-find node, defaults to 0 when the node is a root -/ -def rankD (arr : Array UFNode) (i : Nat) : Nat := - if h : i < arr.size then (arr.get ⟨i, h⟩).rank else 0 +-- /-- Rank of a union-find node, defaults to 0 when the node is a root -/ +-- def rankD (arr : Array UFNode) (i : Nat) : Nat := +-- if h : i < arr.size then (arr.get ⟨i, h⟩).rank else 0 -theorem parentD_eq {arr : Array UFNode} {i} : parentD arr i.1 = (arr.get i).parent := dif_pos _ +-- theorem parentD_eq {arr : Array UFNode} {i} : parentD arr i.1 = (arr.get i).parent := dif_pos _ -theorem parentD_eq' {arr : Array UFNode} {i} (h) : - parentD arr i = (arr.get ⟨i, h⟩).parent := dif_pos _ +-- theorem parentD_eq' {arr : Array UFNode} {i} (h) : +-- parentD arr i = (arr.get ⟨i, h⟩).parent := dif_pos _ -theorem rankD_eq {arr : Array UFNode} {i} : rankD arr i.1 = (arr.get i).rank := dif_pos _ +-- theorem rankD_eq {arr : Array UFNode} {i} : rankD arr i.1 = (arr.get i).rank := dif_pos _ -theorem rankD_eq' {arr : Array UFNode} {i} (h) : rankD arr i = (arr.get ⟨i, h⟩).rank := dif_pos _ +-- theorem rankD_eq' {arr : Array UFNode} {i} (h) : rankD arr i = (arr.get ⟨i, h⟩).rank := dif_pos _ -theorem parentD_of_not_lt : ¬i < arr.size → parentD arr i = i := (dif_neg ·) +-- theorem parentD_of_not_lt : ¬i < arr.size → parentD arr i = i := (dif_neg ·) -theorem lt_of_parentD : parentD arr i ≠ i → i < arr.size := - Decidable.not_imp_comm.1 parentD_of_not_lt +-- theorem lt_of_parentD : parentD arr i ≠ i → i < arr.size := +-- Decidable.not_imp_comm.1 parentD_of_not_lt -theorem parentD_set {arr : Array UFNode} {x v i} : - parentD (arr.set x v) i = if x.1 = i then v.parent else parentD arr i := by - rw [parentD]; simp only [Array.size_set, Array.get_eq_getElem, parentD] - split - · split <;> simp_all - · split <;> [(subst i; cases ‹¬_› x.2); rfl] +-- theorem parentD_set {arr : Array UFNode} {x v i} : +-- parentD (arr.set x v) i = if x.1 = i then v.parent else parentD arr i := by +-- rw [parentD]; simp only [Array.size_set, Array.get_eq_getElem, parentD] +-- split +-- · split <;> simp_all +-- · split <;> [(subst i; cases ‹¬_› x.2); rfl] -theorem rankD_set {arr : Array UFNode} {x v i} : - rankD (arr.set x v) i = if x.1 = i then v.rank else rankD arr i := by - rw [rankD]; simp only [Array.size_set, Array.get_eq_getElem, rankD] - split - · split <;> simp_all - · split <;> [(subst i; cases ‹¬_› x.2); rfl] +-- theorem rankD_set {arr : Array UFNode} {x v i} : +-- rankD (arr.set x v) i = if x.1 = i then v.rank else rankD arr i := by +-- rw [rankD]; simp only [Array.size_set, Array.get_eq_getElem, rankD] +-- split +-- · split <;> simp_all +-- · split <;> [(subst i; cases ‹¬_› x.2); rfl] -end UnionFind +-- end UnionFind -open UnionFind +-- open UnionFind -/-- ### Union-find data structure +-- /-- ### Union-find data structure -The `UnionFind` structure is an implementation of disjoint-set data structure -that uses path compression to make the primary operations run in amortized -nearly linear time. The nodes of a `UnionFind` structure `s` are natural -numbers smaller than `s.size`. The structure associates with a canonical -representative from its equivalence class. The structure can be extended -using the `push` operation and equivalence classes can be updated using the -`union` operation. +-- The `UnionFind` structure is an implementation of disjoint-set data structure +-- that uses path compression to make the primary operations run in amortized +-- nearly linear time. The nodes of a `UnionFind` structure `s` are natural +-- numbers smaller than `s.size`. The structure associates with a canonical +-- representative from its equivalence class. The structure can be extended +-- using the `push` operation and equivalence classes can be updated using the +-- `union` operation. -The main operations for `UnionFind` are: - -* `empty`/`mkEmpty` are used to create a new empty structure. -* `size` returns the size of the data structure. -* `push` adds a new node to a structure, unlinked to any other node. -* `union` links two nodes of the data structure, joining their equivalence - classes, and performs path compression. -* `find` returns the canonical representative of a node and updates the data - structure using path compression. -* `root` returns the canonical representative of a node without altering the - data structure. -* `checkEquiv` checks whether two nodes have the same canonical representative - and updates the structure using path compression. - -Most use cases should prefer `find` over `root` to benefit from the speedup from path-compression. - -The main operations use `Fin s.size` to represent nodes of the union-find structure. -Some alternatives are provided: - -* `unionN`, `findN`, `rootN`, `checkEquivN` use `Fin n` with a proof that `n = s.size`. -* `union!`, `find!`, `root!`, `checkEquiv!` use `Nat` and panic when the indices are out of bounds. -* `findD`, `rootD`, `checkEquivD` use `Nat` and treat out of bound indices as isolated nodes. - -The noncomputable relation `UnionFind.Equiv` is provided to use the equivalence relation from a -`UnionFind` structure in the context of proofs. --/ -structure UnionFind where - /-- Array of union-find nodes -/ - arr : Array UFNode - /-- Validity for parent nodes -/ - parentD_lt : ∀ {i}, i < arr.size → parentD arr i < arr.size - /-- Validity for rank -/ - rankD_lt : ∀ {i}, parentD arr i ≠ i → rankD arr i < rankD arr (parentD arr i) - -namespace UnionFind - -/-- Size of union-find structure. -/ -@[inline] abbrev size (self : UnionFind) := self.arr.size - -/-- Create an empty union-find structure with specific capacity -/ -def mkEmpty (c : Nat) : UnionFind where - arr := Array.mkEmpty c - parentD_lt := nofun - rankD_lt := nofun - -/-- Empty union-find structure -/ -def empty := mkEmpty 0 - -instance : EmptyCollection UnionFind := ⟨.empty⟩ - -/-- Parent of union-find node -/ -abbrev parent (self : UnionFind) (i : Nat) : Nat := parentD self.arr i - -theorem parent'_lt (self : UnionFind) (i : Fin self.size) : - (self.arr.get i).parent < self.size := by - simp only [← parentD_eq, parentD_lt, Fin.is_lt, Array.length_toList] - -theorem parent_lt (self : UnionFind) (i : Nat) : self.parent i < self.size ↔ i < self.size := by - simp only [parentD]; split <;> simp only [*, parent'_lt] - -/-- Rank of union-find node -/ -abbrev rank (self : UnionFind) (i : Nat) : Nat := rankD self.arr i - -theorem rank_lt {self : UnionFind} {i : Nat} : self.parent i ≠ i → - self.rank i < self.rank (self.parent i) := by simpa only [rank] using self.rankD_lt - -theorem rank'_lt (self : UnionFind) (i : Fin self.size) : (self.arr.get i).parent ≠ i → - self.rank i < self.rank (self.arr.get i).parent := by - simpa only [← parentD_eq] using self.rankD_lt - -/-- Maximum rank of nodes in a union-find structure -/ -noncomputable def rankMax (self : UnionFind) := self.arr.foldr (max ·.rank) 0 + 1 - -theorem rank'_lt_rankMax (self : UnionFind) (i : Fin self.size) : - (self.arr.get i).rank < self.rankMax := by - let rec go : ∀ {l} {x : UFNode}, x ∈ l → x.rank ≤ List.foldr (max ·.rank) 0 l - | a::l, _, List.Mem.head _ => by dsimp; apply Nat.le_max_left - | a::l, _, .tail _ h => by dsimp; exact Nat.le_trans (go h) (Nat.le_max_right ..) - simp only [Array.get_eq_getElem, rankMax, Array.foldr_eq_foldr_toList] - exact Nat.lt_succ.2 <| go (self.arr.toList.get_mem i.1 i.2) - -theorem rankD_lt_rankMax (self : UnionFind) (i : Nat) : - rankD self.arr i < self.rankMax := by - simp [rankD]; split <;> [apply rank'_lt_rankMax; apply Nat.succ_pos] - -theorem lt_rankMax (self : UnionFind) (i : Nat) : self.rank i < self.rankMax := rankD_lt_rankMax .. - -theorem push_rankD (arr : Array UFNode) : rankD (arr.push ⟨arr.size, 0⟩) i = rankD arr i := by - simp only [rankD, Array.size_push, Array.get_eq_getElem, Array.get_push, dite_eq_ite] - split <;> split <;> first | simp | cases ‹¬_› (Nat.lt_succ_of_lt ‹_›) - -theorem push_parentD (arr : Array UFNode) : parentD (arr.push ⟨arr.size, 0⟩) i = parentD arr i := by - simp only [parentD, Array.size_push, Array.get_eq_getElem, Array.get_push, dite_eq_ite] - split <;> split <;> try simp - · exact Nat.le_antisymm (Nat.ge_of_not_lt ‹_›) (Nat.le_of_lt_succ ‹_›) - · cases ‹¬_› (Nat.lt_succ_of_lt ‹_›) - -/-- Add a new node to a union-find structure, unlinked with any other nodes -/ -def push (self : UnionFind) : UnionFind where - arr := self.arr.push ⟨self.arr.size, 0⟩ - parentD_lt {i} := by - simp only [Array.size_push, push_parentD]; simp only [parentD, Array.get_eq_getElem] - split <;> [exact fun _ => Nat.lt_succ_of_lt (self.parent'_lt _); exact id] - rankD_lt := by simp only [push_parentD, ne_eq, push_rankD]; exact self.rank_lt - -/-- Root of a union-find node. -/ -def root (self : UnionFind) (x : Fin self.size) : Fin self.size := - let y := (self.arr.get x).parent - if h : y = x then - x - else - have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ h) - self.root ⟨y, self.parent'_lt x⟩ -termination_by self.rankMax - self.rank x - -@[inherit_doc root] -def rootN (self : UnionFind) (x : Fin n) (h : n = self.size) : Fin n := - match n, h with | _, rfl => self.root x - -/-- Root of a union-find node. Panics if index is out of bounds. -/ -def root! (self : UnionFind) (x : Nat) : Nat := - if h : x < self.size then self.root ⟨x, h⟩ else panicWith x "index out of bounds" - -/-- Root of a union-find node. Returns input if index is out of bounds. -/ -def rootD (self : UnionFind) (x : Nat) : Nat := - if h : x < self.size then self.root ⟨x, h⟩ else x - -@[nolint unusedHavesSuffices] -theorem parent_root (self : UnionFind) (x : Fin self.size) : - (self.arr.get (self.root x)).parent = self.root x := by - rw [root]; split <;> [assumption; skip] - have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›) - apply parent_root -termination_by self.rankMax - self.rank x - -theorem parent_rootD (self : UnionFind) (x : Nat) : - self.parent (self.rootD x) = self.rootD x := by - rw [rootD] - split - · simp [parentD, parent_root, -Array.get_eq_getElem] - · simp [parentD_of_not_lt, *] - -@[nolint unusedHavesSuffices] -theorem rootD_parent (self : UnionFind) (x : Nat) : self.rootD (self.parent x) = self.rootD x := by - simp only [rootD, Array.length_toList, parent_lt] - split - · simp only [parentD, ↓reduceDIte, *] - (conv => rhs; rw [root]); split - · rw [root, dif_pos] <;> simp_all - · simp - · simp only [not_false_eq_true, parentD_of_not_lt, *] - -theorem rootD_lt {self : UnionFind} {x : Nat} : self.rootD x < self.size ↔ x < self.size := by - simp only [rootD, Array.length_toList]; split <;> simp [*] - -@[nolint unusedHavesSuffices] -theorem rootD_eq_self {self : UnionFind} {x : Nat} : self.rootD x = x ↔ self.parent x = x := by - refine ⟨fun h => by rw [← h, parent_rootD], fun h => ?_⟩ - rw [rootD]; split <;> [rw [root, dif_pos (by rwa [parent, parentD_eq' ‹_›] at h)]; rfl] - -theorem rootD_rootD {self : UnionFind} {x : Nat} : self.rootD (self.rootD x) = self.rootD x := - rootD_eq_self.2 (parent_rootD ..) - -theorem rootD_ext {m1 m2 : UnionFind} - (H : ∀ x, m1.parent x = m2.parent x) {x} : m1.rootD x = m2.rootD x := by - if h : m2.parent x = x then - rw [rootD_eq_self.2 h, rootD_eq_self.2 ((H _).trans h)] - else - have := Nat.sub_lt_sub_left (m2.lt_rankMax x) (m2.rank_lt h) - rw [← rootD_parent, H, rootD_ext H, rootD_parent] -termination_by m2.rankMax - m2.rank x - -theorem le_rank_root {self : UnionFind} {x : Nat} : self.rank x ≤ self.rank (self.rootD x) := by - if h : self.parent x = x then - rw [rootD_eq_self.2 h]; exact Nat.le_refl .. - else - have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank_lt h) - rw [← rootD_parent] - exact Nat.le_trans (Nat.le_of_lt (self.rank_lt h)) le_rank_root -termination_by self.rankMax - self.rank x - -theorem lt_rank_root {self : UnionFind} {x : Nat} : - self.rank x < self.rank (self.rootD x) ↔ self.parent x ≠ x := by - refine ⟨fun h h' => Nat.ne_of_lt h (by rw [rootD_eq_self.2 h']), fun h => ?_⟩ - rw [← rootD_parent] - exact Nat.lt_of_lt_of_le (self.rank_lt h) le_rank_root - -/-- Auxiliary data structure for find operation -/ -structure FindAux (n : Nat) where - /-- Array of nodes -/ - s : Array UFNode - /-- Index of root node -/ - root : Fin n - /-- Size requirement -/ - size_eq : s.size = n - -/-- Auxiliary function for find operation -/ -def findAux (self : UnionFind) (x : Fin self.size) : FindAux self.size := - let y := (self.arr.get x).parent - if h : y = x then - ⟨self.arr, x, rfl⟩ - else - have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ h) - let ⟨arr₁, root, H⟩ := self.findAux ⟨y, self.parent'_lt x⟩ - ⟨arr₁.modify x fun s => { s with parent := root }, root, by simp [H]⟩ -termination_by self.rankMax - self.rank x - -@[nolint unusedHavesSuffices] -theorem findAux_root {self : UnionFind} {x : Fin self.size} : - (findAux self x).root = self.root x := by - rw [findAux, root] - simp only [Array.length_toList, Array.get_eq_getElem, dite_eq_ite] - split <;> simp only - have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›) - exact findAux_root -termination_by self.rankMax - self.rank x - -@[nolint unusedHavesSuffices] -theorem findAux_s {self : UnionFind} {x : Fin self.size} : - (findAux self x).s = if (self.arr.get x).parent = x then self.arr else - (self.findAux ⟨_, self.parent'_lt x⟩).s.modify x fun s => - { s with parent := self.rootD x } := by - rw [show self.rootD _ = (self.findAux ⟨_, self.parent'_lt x⟩).root from _] - · rw [findAux]; split <;> rfl - · rw [← rootD_parent, parent, parentD_eq] - simp only [rootD, Array.get_eq_getElem, Array.length_toList, findAux_root] - apply dif_pos - exact parent'_lt .. - -set_option linter.deprecated false in -theorem rankD_findAux {self : UnionFind} {x : Fin self.size} : - rankD (findAux self x).s i = self.rank i := by - if h : i < self.size then - rw [findAux_s]; split <;> [rfl; skip] - have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›) - have := lt_of_parentD (by rwa [parentD_eq]) - rw [rankD_eq' (by simp [FindAux.size_eq, h])] - rw [Array.get_modify (by rwa [FindAux.size_eq])] - split <;> simp [← rankD_eq, rankD_findAux (x := ⟨_, self.parent'_lt x⟩), -Array.get_eq_getElem] - else - simp only [rankD, Array.data_length, Array.get_eq_getElem, rank] - rw [dif_neg (by rwa [FindAux.size_eq]), dif_neg h] -termination_by self.rankMax - self.rank x - -set_option linter.deprecated false in -theorem parentD_findAux {self : UnionFind} {x : Fin self.size} : - parentD (findAux self x).s i = - if i = x then self.rootD x else parentD (self.findAux ⟨_, self.parent'_lt x⟩).s i := by - rw [findAux_s]; split <;> [split; skip] - · subst i; rw [rootD_eq_self.2 _] <;> simp [parentD_eq, *, -Array.get_eq_getElem] - · rw [findAux_s]; simp [*, -Array.get_eq_getElem] - · next h => - rw [parentD]; split <;> rename_i h' - · rw [Array.get_modify (by simpa using h')] - simp only [Array.data_length, @eq_comm _ i] - split <;> simp [← parentD_eq, -Array.get_eq_getElem] - · rw [if_neg (mt (by rintro rfl; simp [FindAux.size_eq]) h')] - rw [parentD, dif_neg]; simpa using h' - -theorem parentD_findAux_rootD {self : UnionFind} {x : Fin self.size} : - parentD (findAux self x).s (self.rootD x) = self.rootD x := by - rw [parentD_findAux]; split <;> [rfl; rename_i h] - rw [rootD_eq_self, parent, parentD_eq] at h - have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›) - rw [← rootD_parent, parent, parentD_eq] - exact parentD_findAux_rootD (x := ⟨_, self.parent'_lt x⟩) -termination_by self.rankMax - self.rank x - -theorem parentD_findAux_lt {self : UnionFind} {x : Fin self.size} (h : i < self.size) : - parentD (findAux self x).s i < self.size := by - if h' : (self.arr.get x).parent = x then - rw [findAux_s, if_pos h']; apply self.parentD_lt h - else - rw [parentD_findAux] - split - · simp [rootD_lt] - · have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›) - apply parentD_findAux_lt h -termination_by self.rankMax - self.rank x - -theorem parentD_findAux_or (self : UnionFind) (x : Fin self.size) (i) : - parentD (findAux self x).s i = self.rootD i ∧ self.rootD i = self.rootD x ∨ - parentD (findAux self x).s i = self.parent i := by - if h' : (self.arr.get x).parent = x then - rw [findAux_s, if_pos h']; exact .inr rfl - else - rw [parentD_findAux] - split - · simp [*] - · have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›) - exact (parentD_findAux_or self ⟨_, self.parent'_lt x⟩ i).imp_left <| .imp_right fun h => by - simp only [h, ← parentD_eq, rootD_parent, Array.length_toList] -termination_by self.rankMax - self.rank x - -theorem lt_rankD_findAux {self : UnionFind} {x : Fin self.size} : - parentD (findAux self x).s i ≠ i → - self.rank i < self.rank (parentD (findAux self x).s i) := by - if h' : (self.arr.get x).parent = x then - rw [findAux_s, if_pos h']; apply self.rank_lt - else - rw [parentD_findAux]; split <;> rename_i h <;> intro h' - · subst i; rwa [lt_rank_root, Ne, ← rootD_eq_self] - · have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›) - apply lt_rankD_findAux h' -termination_by self.rankMax - self.rank x - -/-- Find root of a union-find node, updating the structure using path compression. -/ -def find (self : UnionFind) (x : Fin self.size) : - (s : UnionFind) × {_root : Fin s.size // s.size = self.size} := - let r := self.findAux x - { 1.arr := r.s - 2.1.val := r.root - 1.parentD_lt := fun h => by - simp only [Array.length_toList, FindAux.size_eq] at * - exact parentD_findAux_lt h - 1.rankD_lt := fun h => by rw [rankD_findAux, rankD_findAux]; exact lt_rankD_findAux h - 2.1.isLt := show _ < r.s.size by rw [r.size_eq]; exact r.root.2 - 2.2 := by simp [size, r.size_eq] } - -@[inherit_doc find] -def findN (self : UnionFind) (x : Fin n) (h : n = self.size) : UnionFind × Fin n := - match n, h with | _, rfl => match self.find x with | ⟨s, r, h⟩ => (s, Fin.cast h r) - -/-- Find root of a union-find node, updating the structure using path compression. - Panics if index is out of bounds. -/ -def find! (self : UnionFind) (x : Nat) : UnionFind × Nat := - if h : x < self.size then - match self.find ⟨x, h⟩ with | ⟨s, r, _⟩ => (s, r) - else - panicWith (self, x) "index out of bounds" - -/-- Find root of a union-find node, updating the structure using path compression. - Returns inputs unchanged when index is out of bounds. -/ -def findD (self : UnionFind) (x : Nat) : UnionFind × Nat := - if h : x < self.size then - match self.find ⟨x, h⟩ with | ⟨s, r, _⟩ => (s, r) - else - (self, x) - -@[simp] theorem find_size (self : UnionFind) (x : Fin self.size) : - (self.find x).1.size = self.size := by simp [find, size, FindAux.size_eq] - -@[simp] theorem find_root_2 (self : UnionFind) (x : Fin self.size) : - (self.find x).2.1.1 = self.rootD x := by simp [find, findAux_root, rootD] - -@[simp] theorem find_parent_1 (self : UnionFind) (x : Fin self.size) : - (self.find x).1.parent x = self.rootD x := by - simp only [parent, Array.length_toList, find] - rw [parentD_findAux, if_pos rfl] - -theorem find_parent_or (self : UnionFind) (x : Fin self.size) (i) : - (self.find x).1.parent i = self.rootD i ∧ self.rootD i = self.rootD x ∨ - (self.find x).1.parent i = self.parent i := parentD_findAux_or .. - -@[simp] theorem find_root_1 (self : UnionFind) (x : Fin self.size) (i : Nat) : - (self.find x).1.rootD i = self.rootD i := by - if h : (self.find x).1.parent i = i then - rw [rootD_eq_self.2 h] - obtain ⟨h1, _⟩ | h1 := find_parent_or self x i <;> rw [h1] at h - · rw [h] - · rw [rootD_eq_self.2 h] - else - have := Nat.sub_lt_sub_left ((self.find x).1.lt_rankMax _) ((self.find x).1.rank_lt h) - rw [← rootD_parent, find_root_1 self x ((self.find x).1.parent i)] - obtain ⟨h1, _⟩ | h1 := find_parent_or self x i - · rw [h1, rootD_rootD] - · rw [h1, rootD_parent] -termination_by (self.find x).1.rankMax - (self.find x).1.rank i -decreasing_by exact this -- why is this needed? It is way slower without it - -/-- Link two union-find nodes -/ -def linkAux (self : Array UFNode) (x y : Fin self.size) : Array UFNode := - if x.1 = y then - self - else - let nx := self.get x - let ny := self.get y - if ny.rank < nx.rank then - self.set y {ny with parent := x} - else - let arr₁ := self.set x {nx with parent := y} - if nx.rank = ny.rank then - arr₁.set ⟨y, by simp [arr₁]⟩ {ny with rank := ny.rank + 1} - else - arr₁ - -theorem setParentBump_rankD_lt {arr : Array UFNode} {x y : Fin arr.size} - (hroot : (arr.get x).rank < (arr.get y).rank ∨ (arr.get y).parent = y) - (H : (arr.get x).rank ≤ (arr.get y).rank) {i : Nat} - (rankD_lt : parentD arr i ≠ i → rankD arr i < rankD arr (parentD arr i)) - (hP : parentD arr' i = if x.1 = i then y.1 else parentD arr i) - (hR : ∀ {i}, rankD arr' i = - if y.1 = i ∧ (arr.get x).rank = (arr.get y).rank then - (arr.get y).rank + 1 - else rankD arr i) : - ¬parentD arr' i = i → rankD arr' i < rankD arr' (parentD arr' i) := by - simp [hP, hR, -Array.get_eq_getElem] at *; split <;> rename_i h₁ <;> [simp [← h₁]; skip] <;> - split <;> rename_i h₂ <;> intro h - · simp [h₂] at h - · simp only [rankD_eq, Array.get_eq_getElem] - split <;> rename_i h₃ - · rw [← h₃]; apply Nat.lt_succ_self - · exact Nat.lt_of_le_of_ne H h₃ - · cases h₂.1 - simp only [h₂.2, false_or, Nat.lt_irrefl] at hroot - simp only [hroot, parentD_eq, not_true_eq_false] at h - · have := rankD_lt h - split <;> rename_i h₃ - · rw [← rankD_eq, h₃.1]; exact Nat.lt_succ_of_lt this - · exact this - -theorem setParent_rankD_lt {arr : Array UFNode} {x y : Fin arr.size} - (h : (arr.get x).rank < (arr.get y).rank) {i : Nat} - (rankD_lt : parentD arr i ≠ i → rankD arr i < rankD arr (parentD arr i)) : - let arr' := arr.set x ⟨y, (arr.get x).rank⟩ - parentD arr' i ≠ i → rankD arr' i < rankD arr' (parentD arr' i) := - setParentBump_rankD_lt (.inl h) (Nat.le_of_lt h) rankD_lt parentD_set - (by simp [rankD_set, Nat.ne_of_lt h, rankD_eq, -Array.get_eq_getElem]) - -@[simp] theorem linkAux_size : (linkAux self x y).size = self.size := by - simp only [linkAux, Array.get_eq_getElem] - split <;> [rfl; split] <;> [skip; split] <;> simp - -/-- Link a union-find node to a root node. -/ -def link (self : UnionFind) (x y : Fin self.size) (yroot : self.parent y = y) : UnionFind where - arr := linkAux self.arr x y - parentD_lt h := by - simp only [Array.length_toList, linkAux_size] at * - simp only [linkAux, Array.get_eq_getElem] - split <;> [skip; split <;> [skip; split]] - · exact self.parentD_lt h - · rw [parentD_set]; split <;> [exact x.2; exact self.parentD_lt h] - · rw [parentD_set]; split - · exact self.parent'_lt _ - · rw [parentD_set]; split <;> [exact y.2; exact self.parentD_lt h] - · rw [parentD_set]; split <;> [exact y.2; exact self.parentD_lt h] - rankD_lt := by - rw [parent, parentD_eq] at yroot - simp only [linkAux, Array.get_eq_getElem, ne_eq] - split <;> [skip; split <;> [skip; split]] - · exact self.rankD_lt - · exact setParent_rankD_lt ‹_› self.rankD_lt - · refine setParentBump_rankD_lt (.inr yroot) (Nat.le_of_eq ‹_›) self.rankD_lt (by - simp only [parentD_set, ite_eq_right_iff] - rintro rfl - simp [*, parentD_eq]) fun {i} => ?_ - simp only [rankD_set, Fin.eta, Array.get_eq_getElem] - split - · simp_all - · simp_all only [Array.get_eq_getElem, Array.length_toList, Nat.lt_irrefl, not_false_eq_true, - and_true, ite_false, ite_eq_right_iff] - rintro rfl - simp [rankD_eq, *] - · exact setParent_rankD_lt (Nat.lt_of_le_of_ne (Nat.not_lt.1 ‹_›) ‹_›) self.rankD_lt - -@[inherit_doc link] -def linkN (self : UnionFind) (x y : Fin n) (yroot : self.parent y = y) (h : n = self.size) : - UnionFind := match n, h with | _, rfl => self.link x y yroot - -/-- Link a union-find node to a root node. Panics if either index is out of bounds. -/ -def link! (self : UnionFind) (x y : Nat) (yroot : self.parent y = y) : UnionFind := - if h : x < self.size ∧ y < self.size then - self.link ⟨x, h.1⟩ ⟨y, h.2⟩ yroot - else - panicWith self "index out of bounds" - -/-- Link two union-find nodes, uniting their respective classes. -/ -def union (self : UnionFind) (x y : Fin self.size) : UnionFind := - let ⟨self₁, rx, ex⟩ := self.find x - have hy := by rw [ex]; exact y.2 - match eq : self₁.find ⟨y, hy⟩ with - | ⟨self₂, ry, ey⟩ => - self₂.link ⟨rx, by rw [ey]; exact rx.2⟩ ry <| by - have := find_root_1 self₁ ⟨y, hy⟩ (⟨y, hy⟩ : Fin _) - rw [← find_root_2, eq] at this; simp at this - rw [← this, parent_rootD] - -@[inherit_doc union] -def unionN (self : UnionFind) (x y : Fin n) (h : n = self.size) : UnionFind := - match n, h with | _, rfl => self.union x y - -/-- Link two union-find nodes, uniting their respective classes. -Panics if either index is out of bounds. -/ -def union! (self : UnionFind) (x y : Nat) : UnionFind := - if h : x < self.size ∧ y < self.size then - self.union ⟨x, h.1⟩ ⟨y, h.2⟩ - else - panicWith self "index out of bounds" - -/-- Check whether two union-find nodes are equivalent, updating structure using path compression. -/ -def checkEquiv (self : UnionFind) (x y : Fin self.size) : UnionFind × Bool := - let ⟨s, ⟨r₁, _⟩, h⟩ := self.find x - let ⟨s, ⟨r₂, _⟩, _⟩ := s.find (h ▸ y) - (s, r₁ == r₂) - -@[inherit_doc checkEquiv] -def checkEquivN (self : UnionFind) (x y : Fin n) (h : n = self.size) : UnionFind × Bool := - match n, h with | _, rfl => self.checkEquiv x y - -/-- Check whether two union-find nodes are equivalent, updating structure using path compression. -Panics if either index is out of bounds. -/ -def checkEquiv! (self : UnionFind) (x y : Nat) : UnionFind × Bool := - if h : x < self.size ∧ y < self.size then - self.checkEquiv ⟨x, h.1⟩ ⟨y, h.2⟩ - else - panicWith (self, false) "index out of bounds" - -/-- Check whether two union-find nodes are equivalent with path compression, -returns `x == y` if either index is out of bounds -/ -def checkEquivD (self : UnionFind) (x y : Nat) : UnionFind × Bool := - let (s, x) := self.findD x - let (s, y) := s.findD y - (s, x == y) - -/-- Equivalence relation from a `UnionFind` structure -/ -def Equiv (self : UnionFind) (a b : Nat) : Prop := self.rootD a = self.rootD b +-- The main operations for `UnionFind` are: + +-- * `empty`/`mkEmpty` are used to create a new empty structure. +-- * `size` returns the size of the data structure. +-- * `push` adds a new node to a structure, unlinked to any other node. +-- * `union` links two nodes of the data structure, joining their equivalence +-- classes, and performs path compression. +-- * `find` returns the canonical representative of a node and updates the data +-- structure using path compression. +-- * `root` returns the canonical representative of a node without altering the +-- data structure. +-- * `checkEquiv` checks whether two nodes have the same canonical representative +-- and updates the structure using path compression. + +-- Most use cases should prefer `find` over `root` to benefit from the speedup from path-compression. + +-- The main operations use `Fin s.size` to represent nodes of the union-find structure. +-- Some alternatives are provided: + +-- * `unionN`, `findN`, `rootN`, `checkEquivN` use `Fin n` with a proof that `n = s.size`. +-- * `union!`, `find!`, `root!`, `checkEquiv!` use `Nat` and panic when the indices are out of bounds. +-- * `findD`, `rootD`, `checkEquivD` use `Nat` and treat out of bound indices as isolated nodes. + +-- The noncomputable relation `UnionFind.Equiv` is provided to use the equivalence relation from a +-- `UnionFind` structure in the context of proofs. +-- -/ +-- structure UnionFind where +-- /-- Array of union-find nodes -/ +-- arr : Array UFNode +-- /-- Validity for parent nodes -/ +-- parentD_lt : ∀ {i}, i < arr.size → parentD arr i < arr.size +-- /-- Validity for rank -/ +-- rankD_lt : ∀ {i}, parentD arr i ≠ i → rankD arr i < rankD arr (parentD arr i) + +-- namespace UnionFind + +-- /-- Size of union-find structure. -/ +-- @[inline] abbrev size (self : UnionFind) := self.arr.size + +-- /-- Create an empty union-find structure with specific capacity -/ +-- def mkEmpty (c : Nat) : UnionFind where +-- arr := Array.mkEmpty c +-- parentD_lt := nofun +-- rankD_lt := nofun + +-- /-- Empty union-find structure -/ +-- def empty := mkEmpty 0 + +-- instance : EmptyCollection UnionFind := ⟨.empty⟩ + +-- /-- Parent of union-find node -/ +-- abbrev parent (self : UnionFind) (i : Nat) : Nat := parentD self.arr i + +-- theorem parent'_lt (self : UnionFind) (i : Fin self.size) : +-- (self.arr.get i).parent < self.size := by +-- simp only [← parentD_eq, parentD_lt, Fin.is_lt, Array.length_toList] + +-- theorem parent_lt (self : UnionFind) (i : Nat) : self.parent i < self.size ↔ i < self.size := by +-- simp only [parentD]; split <;> simp only [*, parent'_lt] + +-- /-- Rank of union-find node -/ +-- abbrev rank (self : UnionFind) (i : Nat) : Nat := rankD self.arr i + +-- theorem rank_lt {self : UnionFind} {i : Nat} : self.parent i ≠ i → +-- self.rank i < self.rank (self.parent i) := by simpa only [rank] using self.rankD_lt + +-- theorem rank'_lt (self : UnionFind) (i : Fin self.size) : (self.arr.get i).parent ≠ i → +-- self.rank i < self.rank (self.arr.get i).parent := by +-- simpa only [← parentD_eq] using self.rankD_lt + +-- /-- Maximum rank of nodes in a union-find structure -/ +-- noncomputable def rankMax (self : UnionFind) := self.arr.foldr (max ·.rank) 0 + 1 + +-- theorem rank'_lt_rankMax (self : UnionFind) (i : Fin self.size) : +-- (self.arr.get i).rank < self.rankMax := by +-- let rec go : ∀ {l} {x : UFNode}, x ∈ l → x.rank ≤ List.foldr (max ·.rank) 0 l +-- | a::l, _, List.Mem.head _ => by dsimp; apply Nat.le_max_left +-- | a::l, _, .tail _ h => by dsimp; exact Nat.le_trans (go h) (Nat.le_max_right ..) +-- simp only [Array.get_eq_getElem, rankMax, Array.foldr_eq_foldr_toList] +-- exact Nat.lt_succ.2 <| go (self.arr.toList.get_mem i.1 i.2) + +-- theorem rankD_lt_rankMax (self : UnionFind) (i : Nat) : +-- rankD self.arr i < self.rankMax := by +-- simp [rankD]; split <;> [apply rank'_lt_rankMax; apply Nat.succ_pos] + +-- theorem lt_rankMax (self : UnionFind) (i : Nat) : self.rank i < self.rankMax := rankD_lt_rankMax .. + +-- theorem push_rankD (arr : Array UFNode) : rankD (arr.push ⟨arr.size, 0⟩) i = rankD arr i := by +-- simp only [rankD, Array.size_push, Array.get_eq_getElem, Array.get_push, dite_eq_ite] +-- split <;> split <;> first | simp | cases ‹¬_› (Nat.lt_succ_of_lt ‹_›) + +-- theorem push_parentD (arr : Array UFNode) : parentD (arr.push ⟨arr.size, 0⟩) i = parentD arr i := by +-- simp only [parentD, Array.size_push, Array.get_eq_getElem, Array.get_push, dite_eq_ite] +-- split <;> split <;> try simp +-- · exact Nat.le_antisymm (Nat.ge_of_not_lt ‹_›) (Nat.le_of_lt_succ ‹_›) +-- · cases ‹¬_› (Nat.lt_succ_of_lt ‹_›) + +-- /-- Add a new node to a union-find structure, unlinked with any other nodes -/ +-- def push (self : UnionFind) : UnionFind where +-- arr := self.arr.push ⟨self.arr.size, 0⟩ +-- parentD_lt {i} := by +-- simp only [Array.size_push, push_parentD]; simp only [parentD, Array.get_eq_getElem] +-- split <;> [exact fun _ => Nat.lt_succ_of_lt (self.parent'_lt _); exact id] +-- rankD_lt := by simp only [push_parentD, ne_eq, push_rankD]; exact self.rank_lt + +-- /-- Root of a union-find node. -/ +-- def root (self : UnionFind) (x : Fin self.size) : Fin self.size := +-- let y := (self.arr.get x).parent +-- if h : y = x then +-- x +-- else +-- have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ h) +-- self.root ⟨y, self.parent'_lt x⟩ +-- termination_by self.rankMax - self.rank x + +-- @[inherit_doc root] +-- def rootN (self : UnionFind) (x : Fin n) (h : n = self.size) : Fin n := +-- match n, h with | _, rfl => self.root x + +-- /-- Root of a union-find node. Panics if index is out of bounds. -/ +-- def root! (self : UnionFind) (x : Nat) : Nat := +-- if h : x < self.size then self.root ⟨x, h⟩ else panicWith x "index out of bounds" + +-- /-- Root of a union-find node. Returns input if index is out of bounds. -/ +-- def rootD (self : UnionFind) (x : Nat) : Nat := +-- if h : x < self.size then self.root ⟨x, h⟩ else x + +-- @[nolint unusedHavesSuffices] +-- theorem parent_root (self : UnionFind) (x : Fin self.size) : +-- (self.arr.get (self.root x)).parent = self.root x := by +-- rw [root]; split <;> [assumption; skip] +-- have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›) +-- apply parent_root +-- termination_by self.rankMax - self.rank x + +-- theorem parent_rootD (self : UnionFind) (x : Nat) : +-- self.parent (self.rootD x) = self.rootD x := by +-- rw [rootD] +-- split +-- · simp [parentD, parent_root, -Array.get_eq_getElem] +-- · simp [parentD_of_not_lt, *] + +-- @[nolint unusedHavesSuffices] +-- theorem rootD_parent (self : UnionFind) (x : Nat) : self.rootD (self.parent x) = self.rootD x := by +-- simp only [rootD, Array.length_toList, parent_lt] +-- split +-- · simp only [parentD, ↓reduceDIte, *] +-- (conv => rhs; rw [root]); split +-- · rw [root, dif_pos] <;> simp_all +-- · simp +-- · simp only [not_false_eq_true, parentD_of_not_lt, *] + +-- theorem rootD_lt {self : UnionFind} {x : Nat} : self.rootD x < self.size ↔ x < self.size := by +-- simp only [rootD, Array.length_toList]; split <;> simp [*] + +-- @[nolint unusedHavesSuffices] +-- theorem rootD_eq_self {self : UnionFind} {x : Nat} : self.rootD x = x ↔ self.parent x = x := by +-- refine ⟨fun h => by rw [← h, parent_rootD], fun h => ?_⟩ +-- rw [rootD]; split <;> [rw [root, dif_pos (by rwa [parent, parentD_eq' ‹_›] at h)]; rfl] + +-- theorem rootD_rootD {self : UnionFind} {x : Nat} : self.rootD (self.rootD x) = self.rootD x := +-- rootD_eq_self.2 (parent_rootD ..) + +-- theorem rootD_ext {m1 m2 : UnionFind} +-- (H : ∀ x, m1.parent x = m2.parent x) {x} : m1.rootD x = m2.rootD x := by +-- if h : m2.parent x = x then +-- rw [rootD_eq_self.2 h, rootD_eq_self.2 ((H _).trans h)] +-- else +-- have := Nat.sub_lt_sub_left (m2.lt_rankMax x) (m2.rank_lt h) +-- rw [← rootD_parent, H, rootD_ext H, rootD_parent] +-- termination_by m2.rankMax - m2.rank x + +-- theorem le_rank_root {self : UnionFind} {x : Nat} : self.rank x ≤ self.rank (self.rootD x) := by +-- if h : self.parent x = x then +-- rw [rootD_eq_self.2 h]; exact Nat.le_refl .. +-- else +-- have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank_lt h) +-- rw [← rootD_parent] +-- exact Nat.le_trans (Nat.le_of_lt (self.rank_lt h)) le_rank_root +-- termination_by self.rankMax - self.rank x + +-- theorem lt_rank_root {self : UnionFind} {x : Nat} : +-- self.rank x < self.rank (self.rootD x) ↔ self.parent x ≠ x := by +-- refine ⟨fun h h' => Nat.ne_of_lt h (by rw [rootD_eq_self.2 h']), fun h => ?_⟩ +-- rw [← rootD_parent] +-- exact Nat.lt_of_lt_of_le (self.rank_lt h) le_rank_root + +-- /-- Auxiliary data structure for find operation -/ +-- structure FindAux (n : Nat) where +-- /-- Array of nodes -/ +-- s : Array UFNode +-- /-- Index of root node -/ +-- root : Fin n +-- /-- Size requirement -/ +-- size_eq : s.size = n + +-- /-- Auxiliary function for find operation -/ +-- def findAux (self : UnionFind) (x : Fin self.size) : FindAux self.size := +-- let y := (self.arr.get x).parent +-- if h : y = x then +-- ⟨self.arr, x, rfl⟩ +-- else +-- have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ h) +-- let ⟨arr₁, root, H⟩ := self.findAux ⟨y, self.parent'_lt x⟩ +-- ⟨arr₁.modify x fun s => { s with parent := root }, root, by simp [H]⟩ +-- termination_by self.rankMax - self.rank x + +-- @[nolint unusedHavesSuffices] +-- theorem findAux_root {self : UnionFind} {x : Fin self.size} : +-- (findAux self x).root = self.root x := by +-- rw [findAux, root] +-- simp only [Array.length_toList, Array.get_eq_getElem, dite_eq_ite] +-- split <;> simp only +-- have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›) +-- exact findAux_root +-- termination_by self.rankMax - self.rank x + +-- @[nolint unusedHavesSuffices] +-- theorem findAux_s {self : UnionFind} {x : Fin self.size} : +-- (findAux self x).s = if (self.arr.get x).parent = x then self.arr else +-- (self.findAux ⟨_, self.parent'_lt x⟩).s.modify x fun s => +-- { s with parent := self.rootD x } := by +-- rw [show self.rootD _ = (self.findAux ⟨_, self.parent'_lt x⟩).root from _] +-- · rw [findAux]; split <;> rfl +-- · rw [← rootD_parent, parent, parentD_eq] +-- simp only [rootD, Array.get_eq_getElem, Array.length_toList, findAux_root] +-- apply dif_pos +-- exact parent'_lt .. + +-- set_option linter.deprecated false in +-- theorem rankD_findAux {self : UnionFind} {x : Fin self.size} : +-- rankD (findAux self x).s i = self.rank i := by +-- if h : i < self.size then +-- rw [findAux_s]; split <;> [rfl; skip] +-- have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›) +-- have := lt_of_parentD (by rwa [parentD_eq]) +-- rw [rankD_eq' (by simp [FindAux.size_eq, h])] +-- rw [Array.get_modify (by rwa [FindAux.size_eq])] +-- split <;> simp [← rankD_eq, rankD_findAux (x := ⟨_, self.parent'_lt x⟩), -Array.get_eq_getElem] +-- else +-- simp only [rankD, Array.data_length, Array.get_eq_getElem, rank] +-- rw [dif_neg (by rwa [FindAux.size_eq]), dif_neg h] +-- termination_by self.rankMax - self.rank x + +-- set_option linter.deprecated false in +-- theorem parentD_findAux {self : UnionFind} {x : Fin self.size} : +-- parentD (findAux self x).s i = +-- if i = x then self.rootD x else parentD (self.findAux ⟨_, self.parent'_lt x⟩).s i := by +-- rw [findAux_s]; split <;> [split; skip] +-- · subst i; rw [rootD_eq_self.2 _] <;> simp [parentD_eq, *, -Array.get_eq_getElem] +-- · rw [findAux_s]; simp [*, -Array.get_eq_getElem] +-- · next h => +-- rw [parentD]; split <;> rename_i h' +-- · rw [Array.get_modify (by simpa using h')] +-- simp only [Array.data_length, @eq_comm _ i] +-- split <;> simp [← parentD_eq, -Array.get_eq_getElem] +-- · rw [if_neg (mt (by rintro rfl; simp [FindAux.size_eq]) h')] +-- rw [parentD, dif_neg]; simpa using h' + +-- theorem parentD_findAux_rootD {self : UnionFind} {x : Fin self.size} : +-- parentD (findAux self x).s (self.rootD x) = self.rootD x := by +-- rw [parentD_findAux]; split <;> [rfl; rename_i h] +-- rw [rootD_eq_self, parent, parentD_eq] at h +-- have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›) +-- rw [← rootD_parent, parent, parentD_eq] +-- exact parentD_findAux_rootD (x := ⟨_, self.parent'_lt x⟩) +-- termination_by self.rankMax - self.rank x + +-- theorem parentD_findAux_lt {self : UnionFind} {x : Fin self.size} (h : i < self.size) : +-- parentD (findAux self x).s i < self.size := by +-- if h' : (self.arr.get x).parent = x then +-- rw [findAux_s, if_pos h']; apply self.parentD_lt h +-- else +-- rw [parentD_findAux] +-- split +-- · simp [rootD_lt] +-- · have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›) +-- apply parentD_findAux_lt h +-- termination_by self.rankMax - self.rank x + +-- theorem parentD_findAux_or (self : UnionFind) (x : Fin self.size) (i) : +-- parentD (findAux self x).s i = self.rootD i ∧ self.rootD i = self.rootD x ∨ +-- parentD (findAux self x).s i = self.parent i := by +-- if h' : (self.arr.get x).parent = x then +-- rw [findAux_s, if_pos h']; exact .inr rfl +-- else +-- rw [parentD_findAux] +-- split +-- · simp [*] +-- · have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›) +-- exact (parentD_findAux_or self ⟨_, self.parent'_lt x⟩ i).imp_left <| .imp_right fun h => by +-- simp only [h, ← parentD_eq, rootD_parent, Array.length_toList] +-- termination_by self.rankMax - self.rank x + +-- theorem lt_rankD_findAux {self : UnionFind} {x : Fin self.size} : +-- parentD (findAux self x).s i ≠ i → +-- self.rank i < self.rank (parentD (findAux self x).s i) := by +-- if h' : (self.arr.get x).parent = x then +-- rw [findAux_s, if_pos h']; apply self.rank_lt +-- else +-- rw [parentD_findAux]; split <;> rename_i h <;> intro h' +-- · subst i; rwa [lt_rank_root, Ne, ← rootD_eq_self] +-- · have := Nat.sub_lt_sub_left (self.lt_rankMax x) (self.rank'_lt _ ‹_›) +-- apply lt_rankD_findAux h' +-- termination_by self.rankMax - self.rank x + +-- /-- Find root of a union-find node, updating the structure using path compression. -/ +-- def find (self : UnionFind) (x : Fin self.size) : +-- (s : UnionFind) × {_root : Fin s.size // s.size = self.size} := +-- let r := self.findAux x +-- { 1.arr := r.s +-- 2.1.val := r.root +-- 1.parentD_lt := fun h => by +-- simp only [Array.length_toList, FindAux.size_eq] at * +-- exact parentD_findAux_lt h +-- 1.rankD_lt := fun h => by rw [rankD_findAux, rankD_findAux]; exact lt_rankD_findAux h +-- 2.1.isLt := show _ < r.s.size by rw [r.size_eq]; exact r.root.2 +-- 2.2 := by simp [size, r.size_eq] } + +-- @[inherit_doc find] +-- def findN (self : UnionFind) (x : Fin n) (h : n = self.size) : UnionFind × Fin n := +-- match n, h with | _, rfl => match self.find x with | ⟨s, r, h⟩ => (s, Fin.cast h r) + +-- /-- Find root of a union-find node, updating the structure using path compression. +-- Panics if index is out of bounds. -/ +-- def find! (self : UnionFind) (x : Nat) : UnionFind × Nat := +-- if h : x < self.size then +-- match self.find ⟨x, h⟩ with | ⟨s, r, _⟩ => (s, r) +-- else +-- panicWith (self, x) "index out of bounds" + +-- /-- Find root of a union-find node, updating the structure using path compression. +-- Returns inputs unchanged when index is out of bounds. -/ +-- def findD (self : UnionFind) (x : Nat) : UnionFind × Nat := +-- if h : x < self.size then +-- match self.find ⟨x, h⟩ with | ⟨s, r, _⟩ => (s, r) +-- else +-- (self, x) + +-- @[simp] theorem find_size (self : UnionFind) (x : Fin self.size) : +-- (self.find x).1.size = self.size := by simp [find, size, FindAux.size_eq] + +-- @[simp] theorem find_root_2 (self : UnionFind) (x : Fin self.size) : +-- (self.find x).2.1.1 = self.rootD x := by simp [find, findAux_root, rootD] + +-- @[simp] theorem find_parent_1 (self : UnionFind) (x : Fin self.size) : +-- (self.find x).1.parent x = self.rootD x := by +-- simp only [parent, Array.length_toList, find] +-- rw [parentD_findAux, if_pos rfl] + +-- theorem find_parent_or (self : UnionFind) (x : Fin self.size) (i) : +-- (self.find x).1.parent i = self.rootD i ∧ self.rootD i = self.rootD x ∨ +-- (self.find x).1.parent i = self.parent i := parentD_findAux_or .. + +-- @[simp] theorem find_root_1 (self : UnionFind) (x : Fin self.size) (i : Nat) : +-- (self.find x).1.rootD i = self.rootD i := by +-- if h : (self.find x).1.parent i = i then +-- rw [rootD_eq_self.2 h] +-- obtain ⟨h1, _⟩ | h1 := find_parent_or self x i <;> rw [h1] at h +-- · rw [h] +-- · rw [rootD_eq_self.2 h] +-- else +-- have := Nat.sub_lt_sub_left ((self.find x).1.lt_rankMax _) ((self.find x).1.rank_lt h) +-- rw [← rootD_parent, find_root_1 self x ((self.find x).1.parent i)] +-- obtain ⟨h1, _⟩ | h1 := find_parent_or self x i +-- · rw [h1, rootD_rootD] +-- · rw [h1, rootD_parent] +-- termination_by (self.find x).1.rankMax - (self.find x).1.rank i +-- decreasing_by exact this -- why is this needed? It is way slower without it + +-- /-- Link two union-find nodes -/ +-- def linkAux (self : Array UFNode) (x y : Fin self.size) : Array UFNode := +-- if x.1 = y then +-- self +-- else +-- let nx := self.get x +-- let ny := self.get y +-- if ny.rank < nx.rank then +-- self.set y {ny with parent := x} +-- else +-- let arr₁ := self.set x {nx with parent := y} +-- if nx.rank = ny.rank then +-- arr₁.set ⟨y, by simp [arr₁]⟩ {ny with rank := ny.rank + 1} +-- else +-- arr₁ + +-- theorem setParentBump_rankD_lt {arr : Array UFNode} {x y : Fin arr.size} +-- (hroot : (arr.get x).rank < (arr.get y).rank ∨ (arr.get y).parent = y) +-- (H : (arr.get x).rank ≤ (arr.get y).rank) {i : Nat} +-- (rankD_lt : parentD arr i ≠ i → rankD arr i < rankD arr (parentD arr i)) +-- (hP : parentD arr' i = if x.1 = i then y.1 else parentD arr i) +-- (hR : ∀ {i}, rankD arr' i = +-- if y.1 = i ∧ (arr.get x).rank = (arr.get y).rank then +-- (arr.get y).rank + 1 +-- else rankD arr i) : +-- ¬parentD arr' i = i → rankD arr' i < rankD arr' (parentD arr' i) := by +-- simp [hP, hR, -Array.get_eq_getElem] at *; split <;> rename_i h₁ <;> [simp [← h₁]; skip] <;> +-- split <;> rename_i h₂ <;> intro h +-- · simp [h₂] at h +-- · simp only [rankD_eq, Array.get_eq_getElem] +-- split <;> rename_i h₃ +-- · rw [← h₃]; apply Nat.lt_succ_self +-- · exact Nat.lt_of_le_of_ne H h₃ +-- · cases h₂.1 +-- simp only [h₂.2, false_or, Nat.lt_irrefl] at hroot +-- simp only [hroot, parentD_eq, not_true_eq_false] at h +-- · have := rankD_lt h +-- split <;> rename_i h₃ +-- · rw [← rankD_eq, h₃.1]; exact Nat.lt_succ_of_lt this +-- · exact this + +-- theorem setParent_rankD_lt {arr : Array UFNode} {x y : Fin arr.size} +-- (h : (arr.get x).rank < (arr.get y).rank) {i : Nat} +-- (rankD_lt : parentD arr i ≠ i → rankD arr i < rankD arr (parentD arr i)) : +-- let arr' := arr.set x ⟨y, (arr.get x).rank⟩ +-- parentD arr' i ≠ i → rankD arr' i < rankD arr' (parentD arr' i) := +-- setParentBump_rankD_lt (.inl h) (Nat.le_of_lt h) rankD_lt parentD_set +-- (by simp [rankD_set, Nat.ne_of_lt h, rankD_eq, -Array.get_eq_getElem]) + +-- @[simp] theorem linkAux_size : (linkAux self x y).size = self.size := by +-- simp only [linkAux, Array.get_eq_getElem] +-- split <;> [rfl; split] <;> [skip; split] <;> simp + +-- /-- Link a union-find node to a root node. -/ +-- def link (self : UnionFind) (x y : Fin self.size) (yroot : self.parent y = y) : UnionFind where +-- arr := linkAux self.arr x y +-- parentD_lt h := by +-- simp only [Array.length_toList, linkAux_size] at * +-- simp only [linkAux, Array.get_eq_getElem] +-- split <;> [skip; split <;> [skip; split]] +-- · exact self.parentD_lt h +-- · rw [parentD_set]; split <;> [exact x.2; exact self.parentD_lt h] +-- · rw [parentD_set]; split +-- · exact self.parent'_lt _ +-- · rw [parentD_set]; split <;> [exact y.2; exact self.parentD_lt h] +-- · rw [parentD_set]; split <;> [exact y.2; exact self.parentD_lt h] +-- rankD_lt := by +-- rw [parent, parentD_eq] at yroot +-- simp only [linkAux, Array.get_eq_getElem, ne_eq] +-- split <;> [skip; split <;> [skip; split]] +-- · exact self.rankD_lt +-- · exact setParent_rankD_lt ‹_› self.rankD_lt +-- · refine setParentBump_rankD_lt (.inr yroot) (Nat.le_of_eq ‹_›) self.rankD_lt (by +-- simp only [parentD_set, ite_eq_right_iff] +-- rintro rfl +-- simp [*, parentD_eq]) fun {i} => ?_ +-- simp only [rankD_set, Fin.eta, Array.get_eq_getElem] +-- split +-- · simp_all +-- · simp_all only [Array.get_eq_getElem, Array.length_toList, Nat.lt_irrefl, not_false_eq_true, +-- and_true, ite_false, ite_eq_right_iff] +-- rintro rfl +-- simp [rankD_eq, *] +-- · exact setParent_rankD_lt (Nat.lt_of_le_of_ne (Nat.not_lt.1 ‹_›) ‹_›) self.rankD_lt + +-- @[inherit_doc link] +-- def linkN (self : UnionFind) (x y : Fin n) (yroot : self.parent y = y) (h : n = self.size) : +-- UnionFind := match n, h with | _, rfl => self.link x y yroot + +-- /-- Link a union-find node to a root node. Panics if either index is out of bounds. -/ +-- def link! (self : UnionFind) (x y : Nat) (yroot : self.parent y = y) : UnionFind := +-- if h : x < self.size ∧ y < self.size then +-- self.link ⟨x, h.1⟩ ⟨y, h.2⟩ yroot +-- else +-- panicWith self "index out of bounds" + +-- /-- Link two union-find nodes, uniting their respective classes. -/ +-- def union (self : UnionFind) (x y : Fin self.size) : UnionFind := +-- let ⟨self₁, rx, ex⟩ := self.find x +-- have hy := by rw [ex]; exact y.2 +-- match eq : self₁.find ⟨y, hy⟩ with +-- | ⟨self₂, ry, ey⟩ => +-- self₂.link ⟨rx, by rw [ey]; exact rx.2⟩ ry <| by +-- have := find_root_1 self₁ ⟨y, hy⟩ (⟨y, hy⟩ : Fin _) +-- rw [← find_root_2, eq] at this; simp at this +-- rw [← this, parent_rootD] + +-- @[inherit_doc union] +-- def unionN (self : UnionFind) (x y : Fin n) (h : n = self.size) : UnionFind := +-- match n, h with | _, rfl => self.union x y + +-- /-- Link two union-find nodes, uniting their respective classes. +-- Panics if either index is out of bounds. -/ +-- def union! (self : UnionFind) (x y : Nat) : UnionFind := +-- if h : x < self.size ∧ y < self.size then +-- self.union ⟨x, h.1⟩ ⟨y, h.2⟩ +-- else +-- panicWith self "index out of bounds" + +-- /-- Check whether two union-find nodes are equivalent, updating structure using path compression. -/ +-- def checkEquiv (self : UnionFind) (x y : Fin self.size) : UnionFind × Bool := +-- let ⟨s, ⟨r₁, _⟩, h⟩ := self.find x +-- let ⟨s, ⟨r₂, _⟩, _⟩ := s.find (h ▸ y) +-- (s, r₁ == r₂) + +-- @[inherit_doc checkEquiv] +-- def checkEquivN (self : UnionFind) (x y : Fin n) (h : n = self.size) : UnionFind × Bool := +-- match n, h with | _, rfl => self.checkEquiv x y + +-- /-- Check whether two union-find nodes are equivalent, updating structure using path compression. +-- Panics if either index is out of bounds. -/ +-- def checkEquiv! (self : UnionFind) (x y : Nat) : UnionFind × Bool := +-- if h : x < self.size ∧ y < self.size then +-- self.checkEquiv ⟨x, h.1⟩ ⟨y, h.2⟩ +-- else +-- panicWith (self, false) "index out of bounds" + +-- /-- Check whether two union-find nodes are equivalent with path compression, +-- returns `x == y` if either index is out of bounds -/ +-- def checkEquivD (self : UnionFind) (x y : Nat) : UnionFind × Bool := +-- let (s, x) := self.findD x +-- let (s, y) := s.findD y +-- (s, x == y) + +-- /-- Equivalence relation from a `UnionFind` structure -/ +-- def Equiv (self : UnionFind) (a b : Nat) : Prop := self.rootD a = self.rootD b diff --git a/Batteries/Data/UnionFind/Lemmas.lean b/Batteries/Data/UnionFind/Lemmas.lean index a42ece4508..a07ebfdbd5 100644 --- a/Batteries/Data/UnionFind/Lemmas.lean +++ b/Batteries/Data/UnionFind/Lemmas.lean @@ -1,139 +1,139 @@ -/- -Copyright (c) 2021 Mario Carneiro. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. -Authors: Mario Carneiro --/ -import Batteries.Data.UnionFind.Basic - -namespace Batteries.UnionFind - -@[simp] theorem arr_empty : empty.arr = #[] := rfl -@[simp] theorem parent_empty : empty.parent a = a := rfl -@[simp] theorem rank_empty : empty.rank a = 0 := rfl -@[simp] theorem rootD_empty : empty.rootD a = a := rfl - -@[simp] theorem arr_push {m : UnionFind} : m.push.arr = m.arr.push ⟨m.arr.size, 0⟩ := rfl - -@[simp] theorem parentD_push {arr : Array UFNode} : - parentD (arr.push ⟨arr.size, 0⟩) a = parentD arr a := by - simp [parentD]; split <;> split <;> try simp [Array.get_push, *] - · next h1 h2 => - simp [Nat.lt_succ] at h1 h2 - exact Nat.le_antisymm h2 h1 - · next h1 h2 => cases h1 (Nat.lt_succ_of_lt h2) - -@[simp] theorem parent_push {m : UnionFind} : m.push.parent a = m.parent a := by simp [parent] - -@[simp] theorem rankD_push {arr : Array UFNode} : - rankD (arr.push ⟨arr.size, 0⟩) a = rankD arr a := by - simp [rankD]; split <;> split <;> try simp [Array.get_push, *] - next h1 h2 => cases h1 (Nat.lt_succ_of_lt h2) - -@[simp] theorem rank_push {m : UnionFind} : m.push.rank a = m.rank a := by simp [rank] - -@[simp] theorem rankMax_push {m : UnionFind} : m.push.rankMax = m.rankMax := by simp [rankMax] - -@[simp] theorem root_push {self : UnionFind} : self.push.rootD x = self.rootD x := - rootD_ext fun _ => parent_push - -@[simp] theorem arr_link : (link self x y yroot).arr = linkAux self.arr x y := rfl - -theorem parentD_linkAux {self} {x y : Fin self.size} : - parentD (linkAux self x y) i = - if x.1 = y then - parentD self i - else - if (self.get y).rank < (self.get x).rank then - if y = i then x else parentD self i - else - if x = i then y else parentD self i := by - dsimp only [linkAux]; split <;> [rfl; split] <;> [rw [parentD_set]; split] <;> rw [parentD_set] - split <;> [(subst i; rwa [if_neg, parentD_eq]); rw [parentD_set]] - -theorem parent_link {self} {x y : Fin self.size} (yroot) {i} : - (link self x y yroot).parent i = - if x.1 = y then - self.parent i - else - if self.rank y < self.rank x then - if y = i then x else self.parent i - else - if x = i then y else self.parent i := by - simp [rankD_eq]; exact parentD_linkAux - -theorem root_link {self : UnionFind} {x y : Fin self.size} - (xroot : self.parent x = x) (yroot : self.parent y = y) : - ∃ r, (r = x ∨ r = y) ∧ ∀ i, - (link self x y yroot).rootD i = - if self.rootD i = x ∨ self.rootD i = y then r.1 else self.rootD i := by - if h : x.1 = y then - refine ⟨x, .inl rfl, fun i => ?_⟩ - rw [rootD_ext (m2 := self) (fun _ => by rw [parent_link, if_pos h])] - split <;> [obtain _ | _ := ‹_› <;> simp [*]; rfl] - else - have {x y : Fin self.size} - (xroot : self.parent x = x) (yroot : self.parent y = y) {m : UnionFind} - (hm : ∀ i, m.parent i = if y = i then x.1 else self.parent i) : - ∃ r, (r = x ∨ r = y) ∧ ∀ i, - m.rootD i = if self.rootD i = x ∨ self.rootD i = y then r.1 else self.rootD i := by - let rec go (i) : - m.rootD i = if self.rootD i = x ∨ self.rootD i = y then x.1 else self.rootD i := by - if h : m.parent i = i then - rw [rootD_eq_self.2 h]; rw [hm i] at h; split at h - · rw [if_pos, h]; simp [← h, rootD_eq_self, xroot] - · rw [rootD_eq_self.2 ‹_›]; split <;> [skip; rfl] - next h' => exact h'.resolve_right (Ne.symm ‹_›) - else - have _ := Nat.sub_lt_sub_left (m.lt_rankMax i) (m.rank_lt h) - rw [← rootD_parent, go (m.parent i)] - rw [hm i]; split <;> [subst i; rw [rootD_parent]] - simp [rootD_eq_self.2 xroot, rootD_eq_self.2 yroot] - termination_by m.rankMax - m.rank i - exact ⟨x, .inl rfl, go⟩ - if hr : self.rank y < self.rank x then - exact this xroot yroot fun i => by simp [parent_link, h, hr] - else - simpa (config := {singlePass := true}) [or_comm] using - this yroot xroot fun i => by simp [parent_link, h, hr] - -nonrec theorem Equiv.rfl : Equiv self a a := rfl -theorem Equiv.symm : Equiv self a b → Equiv self b a := .symm -theorem Equiv.trans : Equiv self a b → Equiv self b c → Equiv self a c := .trans - -@[simp] theorem equiv_empty : Equiv empty a b ↔ a = b := by simp [Equiv] - -@[simp] theorem equiv_push : Equiv self.push a b ↔ Equiv self a b := by simp [Equiv] - -@[simp] theorem equiv_rootD : Equiv self (self.rootD a) a := by simp [Equiv, rootD_rootD] -@[simp] theorem equiv_rootD_l : Equiv self (self.rootD a) b ↔ Equiv self a b := by - simp [Equiv, rootD_rootD] -@[simp] theorem equiv_rootD_r : Equiv self a (self.rootD b) ↔ Equiv self a b := by - simp [Equiv, rootD_rootD] - -theorem equiv_find : Equiv (self.find x).1 a b ↔ Equiv self a b := by simp [Equiv, find_root_1] - -theorem equiv_link {self : UnionFind} {x y : Fin self.size} - (xroot : self.parent x = x) (yroot : self.parent y = y) : - Equiv (link self x y yroot) a b ↔ - Equiv self a b ∨ Equiv self a x ∧ Equiv self y b ∨ Equiv self a y ∧ Equiv self x b := by - have {m : UnionFind} {x y : Fin self.size} - (xroot : self.rootD x = x) (yroot : self.rootD y = y) - (hm : ∀ i, m.rootD i = if self.rootD i = x ∨ self.rootD i = y then x.1 else self.rootD i) : - Equiv m a b ↔ - Equiv self a b ∨ Equiv self a x ∧ Equiv self y b ∨ Equiv self a y ∧ Equiv self x b := by - simp [Equiv, hm, xroot, yroot] - by_cases h1 : rootD self a = x <;> by_cases h2 : rootD self b = x <;> - simp [h1, h2, imp_false, Decidable.not_not] - · simp [h2, Ne.symm h2]; split <;> simp [@eq_comm _ _ (rootD self b), *] - · by_cases h1 : rootD self a = y <;> by_cases h2 : rootD self b = y <;> - simp [h1, h2, @eq_comm _ _ (rootD self b), *] - obtain ⟨r, ha, hr⟩ := root_link xroot yroot; revert hr - rw [← rootD_eq_self] at xroot yroot - obtain rfl | rfl := ha - · exact this xroot yroot - · simpa [or_comm, and_comm] using this yroot xroot - -theorem equiv_union {self : UnionFind} {x y : Fin self.size} : - Equiv (union self x y) a b ↔ - Equiv self a b ∨ Equiv self a x ∧ Equiv self y b ∨ Equiv self a y ∧ Equiv self x b := by - simp [union]; rw [equiv_link (by simp [← rootD_eq_self, rootD_rootD])]; simp [equiv_find] +-- /- +-- Copyright (c) 2021 Mario Carneiro. All rights reserved. +-- Released under Apache 2.0 license as described in the file LICENSE. +-- Authors: Mario Carneiro +-- -/ +-- import Batteries.Data.UnionFind.Basic + +-- namespace Batteries.UnionFind + +-- @[simp] theorem arr_empty : empty.arr = #[] := rfl +-- @[simp] theorem parent_empty : empty.parent a = a := rfl +-- @[simp] theorem rank_empty : empty.rank a = 0 := rfl +-- @[simp] theorem rootD_empty : empty.rootD a = a := rfl + +-- @[simp] theorem arr_push {m : UnionFind} : m.push.arr = m.arr.push ⟨m.arr.size, 0⟩ := rfl + +-- @[simp] theorem parentD_push {arr : Array UFNode} : +-- parentD (arr.push ⟨arr.size, 0⟩) a = parentD arr a := by +-- simp [parentD]; split <;> split <;> try simp [Array.get_push, *] +-- · next h1 h2 => +-- simp [Nat.lt_succ] at h1 h2 +-- exact Nat.le_antisymm h2 h1 +-- · next h1 h2 => cases h1 (Nat.lt_succ_of_lt h2) + +-- @[simp] theorem parent_push {m : UnionFind} : m.push.parent a = m.parent a := by simp [parent] + +-- @[simp] theorem rankD_push {arr : Array UFNode} : +-- rankD (arr.push ⟨arr.size, 0⟩) a = rankD arr a := by +-- simp [rankD]; split <;> split <;> try simp [Array.get_push, *] +-- next h1 h2 => cases h1 (Nat.lt_succ_of_lt h2) + +-- @[simp] theorem rank_push {m : UnionFind} : m.push.rank a = m.rank a := by simp [rank] + +-- @[simp] theorem rankMax_push {m : UnionFind} : m.push.rankMax = m.rankMax := by simp [rankMax] + +-- @[simp] theorem root_push {self : UnionFind} : self.push.rootD x = self.rootD x := +-- rootD_ext fun _ => parent_push + +-- @[simp] theorem arr_link : (link self x y yroot).arr = linkAux self.arr x y := rfl + +-- theorem parentD_linkAux {self} {x y : Fin self.size} : +-- parentD (linkAux self x y) i = +-- if x.1 = y then +-- parentD self i +-- else +-- if (self.get y).rank < (self.get x).rank then +-- if y = i then x else parentD self i +-- else +-- if x = i then y else parentD self i := by +-- dsimp only [linkAux]; split <;> [rfl; split] <;> [rw [parentD_set]; split] <;> rw [parentD_set] +-- split <;> [(subst i; rwa [if_neg, parentD_eq]); rw [parentD_set]] + +-- theorem parent_link {self} {x y : Fin self.size} (yroot) {i} : +-- (link self x y yroot).parent i = +-- if x.1 = y then +-- self.parent i +-- else +-- if self.rank y < self.rank x then +-- if y = i then x else self.parent i +-- else +-- if x = i then y else self.parent i := by +-- simp [rankD_eq]; exact parentD_linkAux + +-- theorem root_link {self : UnionFind} {x y : Fin self.size} +-- (xroot : self.parent x = x) (yroot : self.parent y = y) : +-- ∃ r, (r = x ∨ r = y) ∧ ∀ i, +-- (link self x y yroot).rootD i = +-- if self.rootD i = x ∨ self.rootD i = y then r.1 else self.rootD i := by +-- if h : x.1 = y then +-- refine ⟨x, .inl rfl, fun i => ?_⟩ +-- rw [rootD_ext (m2 := self) (fun _ => by rw [parent_link, if_pos h])] +-- split <;> [obtain _ | _ := ‹_› <;> simp [*]; rfl] +-- else +-- have {x y : Fin self.size} +-- (xroot : self.parent x = x) (yroot : self.parent y = y) {m : UnionFind} +-- (hm : ∀ i, m.parent i = if y = i then x.1 else self.parent i) : +-- ∃ r, (r = x ∨ r = y) ∧ ∀ i, +-- m.rootD i = if self.rootD i = x ∨ self.rootD i = y then r.1 else self.rootD i := by +-- let rec go (i) : +-- m.rootD i = if self.rootD i = x ∨ self.rootD i = y then x.1 else self.rootD i := by +-- if h : m.parent i = i then +-- rw [rootD_eq_self.2 h]; rw [hm i] at h; split at h +-- · rw [if_pos, h]; simp [← h, rootD_eq_self, xroot] +-- · rw [rootD_eq_self.2 ‹_›]; split <;> [skip; rfl] +-- next h' => exact h'.resolve_right (Ne.symm ‹_›) +-- else +-- have _ := Nat.sub_lt_sub_left (m.lt_rankMax i) (m.rank_lt h) +-- rw [← rootD_parent, go (m.parent i)] +-- rw [hm i]; split <;> [subst i; rw [rootD_parent]] +-- simp [rootD_eq_self.2 xroot, rootD_eq_self.2 yroot] +-- termination_by m.rankMax - m.rank i +-- exact ⟨x, .inl rfl, go⟩ +-- if hr : self.rank y < self.rank x then +-- exact this xroot yroot fun i => by simp [parent_link, h, hr] +-- else +-- simpa (config := {singlePass := true}) [or_comm] using +-- this yroot xroot fun i => by simp [parent_link, h, hr] + +-- nonrec theorem Equiv.rfl : Equiv self a a := rfl +-- theorem Equiv.symm : Equiv self a b → Equiv self b a := .symm +-- theorem Equiv.trans : Equiv self a b → Equiv self b c → Equiv self a c := .trans + +-- @[simp] theorem equiv_empty : Equiv empty a b ↔ a = b := by simp [Equiv] + +-- @[simp] theorem equiv_push : Equiv self.push a b ↔ Equiv self a b := by simp [Equiv] + +-- @[simp] theorem equiv_rootD : Equiv self (self.rootD a) a := by simp [Equiv, rootD_rootD] +-- @[simp] theorem equiv_rootD_l : Equiv self (self.rootD a) b ↔ Equiv self a b := by +-- simp [Equiv, rootD_rootD] +-- @[simp] theorem equiv_rootD_r : Equiv self a (self.rootD b) ↔ Equiv self a b := by +-- simp [Equiv, rootD_rootD] + +-- theorem equiv_find : Equiv (self.find x).1 a b ↔ Equiv self a b := by simp [Equiv, find_root_1] + +-- theorem equiv_link {self : UnionFind} {x y : Fin self.size} +-- (xroot : self.parent x = x) (yroot : self.parent y = y) : +-- Equiv (link self x y yroot) a b ↔ +-- Equiv self a b ∨ Equiv self a x ∧ Equiv self y b ∨ Equiv self a y ∧ Equiv self x b := by +-- have {m : UnionFind} {x y : Fin self.size} +-- (xroot : self.rootD x = x) (yroot : self.rootD y = y) +-- (hm : ∀ i, m.rootD i = if self.rootD i = x ∨ self.rootD i = y then x.1 else self.rootD i) : +-- Equiv m a b ↔ +-- Equiv self a b ∨ Equiv self a x ∧ Equiv self y b ∨ Equiv self a y ∧ Equiv self x b := by +-- simp [Equiv, hm, xroot, yroot] +-- by_cases h1 : rootD self a = x <;> by_cases h2 : rootD self b = x <;> +-- simp [h1, h2, imp_false, Decidable.not_not] +-- · simp [h2, Ne.symm h2]; split <;> simp [@eq_comm _ _ (rootD self b), *] +-- · by_cases h1 : rootD self a = y <;> by_cases h2 : rootD self b = y <;> +-- simp [h1, h2, @eq_comm _ _ (rootD self b), *] +-- obtain ⟨r, ha, hr⟩ := root_link xroot yroot; revert hr +-- rw [← rootD_eq_self] at xroot yroot +-- obtain rfl | rfl := ha +-- · exact this xroot yroot +-- · simpa [or_comm, and_comm] using this yroot xroot + +-- theorem equiv_union {self : UnionFind} {x y : Fin self.size} : +-- Equiv (union self x y) a b ↔ +-- Equiv self a b ∨ Equiv self a x ∧ Equiv self y b ∨ Equiv self a y ∧ Equiv self x b := by +-- simp [union]; rw [equiv_link (by simp [← rootD_eq_self, rootD_rootD])]; simp [equiv_find] diff --git a/Batteries/StdDeprecations.lean b/Batteries/StdDeprecations.lean index cb49a3bdbf..6582f6714a 100644 --- a/Batteries/StdDeprecations.lean +++ b/Batteries/StdDeprecations.lean @@ -48,16 +48,16 @@ alias Std.compareOfLessAndEq_eq_lt := Batteries.compareOfLessAndEq_eq_lt @[deprecated (since := "2024-05-07")] alias Std.mkRBMap := Batteries.mkRBMap @[deprecated (since := "2024-05-07")] alias Std.BinomialHeap := Batteries.BinomialHeap @[deprecated (since := "2024-05-07")] alias Std.mkBinomialHeap := Batteries.mkBinomialHeap -@[deprecated (since := "2024-05-07")] alias Std.UFNode := Batteries.UFNode -@[deprecated (since := "2024-05-07")] alias Std.UnionFind := Batteries.UnionFind +-- @[deprecated (since := "2024-05-07")] alias Std.UFNode := Batteries.UFNode +-- @[deprecated (since := "2024-05-07")] alias Std.UnionFind := Batteries.UnionFind --- Check that these generate usable deprecated hints --- when referring to names inside these namespaces. -set_option warningAsError true in -/-- -error: `Std.UnionFind` has been deprecated, use `Batteries.UnionFind` instead ---- -error: unknown constant 'Std.UnionFind.find' --/ -#guard_msgs in -#eval Std.UnionFind.find +-- -- Check that these generate usable deprecated hints +-- -- when referring to names inside these namespaces. +-- set_option warningAsError true in +-- /-- +-- error: `Std.UnionFind` has been deprecated, use `Batteries.UnionFind` instead +-- --- +-- error: unknown constant 'Std.UnionFind.find' +-- -/ +-- #guard_msgs in +-- #eval Std.UnionFind.find