-
Notifications
You must be signed in to change notification settings - Fork 31
/
UnionFindWithUndo.go
181 lines (159 loc) · 4.58 KB
/
UnionFindWithUndo.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
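// UnionFindWithUndo
// https://nyaannyaan.github.io/library/data-structure/rollback-union-find.hpp
// Union-find (disjoint set) with undo and rollback ("time travel").
// API (type UnionFindArrayWithUndo):
// NewUnionFindArrayWithUndo(n int32): creates n singleton sets 0..n-1.
// Union(x, y int32): merges the sets containing x and y.
// Find(x int32): returns the root of the set containing x.
// IsConnected(x, y int32): reports whether x and y are in the same set.
// Undo(): undoes the most recent Union call; a Union that merged nothing must also be undone.
// Snapshot(): saves the current state internally.
// !After Snapshot(), call Rollback(-1) to return to that state.
// Rollback(state int32): rolls back to the given state.
// If state is -1, rolls back to the state saved by Snapshot().
// Otherwise, rolls back to the state right after the state-th Union call.
// GetState():
// returns the current state, i.e. the number of Union calls still in effect.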
package main
import (
"fmt"
"sort"
"strings"
)
// main exercises the rollback union-find: merges, an undo, a snapshot, and
// two rollbacks, printing the structure after each step.
func main() {
	uf := NewUnionFindArrayWithUndo(10)

	// printRoots shows the current root of each of the first four elements.
	printRoots := func() {
		fmt.Println(uf.Find(0), uf.Find(1), uf.Find(2), uf.Find(3))
	}

	uf.Union(0, 1)
	uf.Union(2, 3)
	printRoots()

	uf.Union(0, 2)
	printRoots()

	uf.Undo()
	printRoots()
	fmt.Println(uf)

	uf.Snapshot()
	fmt.Println(uf)

	uf.Union(0, 2)
	fmt.Println(uf)
	uf.Union(4, 5)
	fmt.Println(uf)

	// Back to the snapshot, then all the way back to the initial state.
	uf.Rollback(-1)
	fmt.Println(uf)
	uf.Rollback(0)
	fmt.Println(uf, uf.Part)
}
// NewUnionFindArrayWithUndo creates a rollback-capable union-find over the
// elements 0..n-1, each starting in its own singleton set.
func NewUnionFindArrayWithUndo(n int32) *UnionFindArrayWithUndo {
	data := make([]int32, n)
	for i := range data {
		data[i] = -1
	}
	return &UnionFindArrayWithUndo{Part: n, n: n, data: data}
}

// historyItem records that data[a] held the value b before a Union call
// modified it, so Undo can restore it.
type historyItem struct{ a, b int32 }

// UnionFindArrayWithUndo is a union-find (disjoint-set) structure whose Union
// calls can be undone in LIFO order. Find does no path compression, because
// compression would rewrite entries that Undo cannot restore.
type UnionFindArrayWithUndo struct {
	Part      int32 // current number of disjoint sets
	n         int32 // number of elements
	innerSnap int32 // state id saved by Snapshot
	// data[i] < 0 means i is a root and -data[i] is the size of its set;
	// data[i] >= 0 is the parent of i.
	data []int32
	// history holds exactly two entries (root, previous data value) per Union
	// call. Values, not pointers, so appending does not allocate per entry.
	history []historyItem
}

// Undo reverts the most recent Union call. A Union that found its arguments
// already connected still recorded history, so it must be undone as well.
// Undo reports false if there is nothing to undo.
func (uf *UnionFindArrayWithUndo) Undo() bool {
	k := len(uf.history)
	if k < 2 {
		return false
	}
	first, second := uf.history[k-2], uf.history[k-1]
	uf.history = uf.history[:k-2]
	// Restore in reverse recording order; for a no-op Union both entries name
	// the same root and the first-recorded value ends up in place.
	uf.data[second.a] = second.b
	uf.data[first.a] = first.b
	if first.a != second.a {
		// Two distinct roots means that Union merged two sets.
		uf.Part++
	}
	return true
}

// Snapshot saves the current state internally.
//
// After Snapshot, call Rollback(-1) to return to this state.
func (uf *UnionFindArrayWithUndo) Snapshot() {
	uf.innerSnap = int32(len(uf.history) >> 1)
}

// Rollback undoes Union calls until the structure is in the given state.
//
// A state of -1 means the state saved by the last Snapshot. Any other value
// is a state id as returned by GetState. Rollback reports false, changing
// nothing, if the target state is negative or later than the current state.
func (uf *UnionFindArrayWithUndo) Rollback(state int32) bool {
	if state == -1 {
		state = uf.innerSnap
	}
	// Validate before converting to a history length: doubling an int32
	// first could overflow and let very negative states pass the check.
	if state < 0 || int(state) > uf.GetState() {
		return false
	}
	target := int(state) * 2
	for len(uf.history) > target {
		uf.Undo()
	}
	return true
}

// GetState returns the current state id, i.e. the number of Union calls
// currently in effect (including ones that merged nothing).
func (uf *UnionFindArrayWithUndo) GetState() int {
	return len(uf.history) >> 1
}

// Reset undoes every recorded Union, returning to the initial state.
func (uf *UnionFindArrayWithUndo) Reset() {
	for len(uf.history) > 0 {
		uf.Undo()
	}
}

// Union merges the sets containing x and y (union by size) and reports
// whether a merge happened. History is recorded even when x and y are
// already connected, so every Union call corresponds to exactly one Undo.
func (uf *UnionFindArrayWithUndo) Union(x, y int32) bool {
	x, y = uf.Find(x), uf.Find(y)
	uf.history = append(uf.history, historyItem{x, uf.data[x]}, historyItem{y, uf.data[y]})
	if x == y {
		return false
	}
	// Roots store -size, so the larger set has the smaller value; make it x.
	if uf.data[x] > uf.data[y] {
		x, y = y, x
	}
	uf.data[x] += uf.data[y]
	uf.data[y] = x
	uf.Part--
	return true
}

// Find returns the root of the set containing x, without path compression.
func (uf *UnionFindArrayWithUndo) Find(x int32) int32 {
	cur := x
	for uf.data[cur] >= 0 {
		cur = uf.data[cur]
	}
	return cur
}

// SetPart overwrites the stored number of disjoint sets.
func (uf *UnionFindArrayWithUndo) SetPart(part int32) { uf.Part = part }

// IsConnected reports whether x and y are in the same set.
func (uf *UnionFindArrayWithUndo) IsConnected(x, y int32) bool { return uf.Find(x) == uf.Find(y) }

// GetSize returns the size of the set containing x.
func (uf *UnionFindArrayWithUndo) GetSize(x int32) int32 { return -uf.data[uf.Find(x)] }

// GetGroups maps each root to the members of its set, in increasing order.
func (uf *UnionFindArrayWithUndo) GetGroups() map[int32][]int32 {
	groups := make(map[int32][]int32)
	for i := int32(0); i < uf.n; i++ {
		root := uf.Find(i)
		groups[root] = append(groups[root], i)
	}
	return groups
}

// String lists every set (sorted by root) followed by the set count.
func (uf *UnionFindArrayWithUndo) String() string {
	sb := []string{"UnionFindArray:"}
	groups := uf.GetGroups()
	keys := make([]int32, 0, len(groups))
	for k := range groups {
		keys = append(keys, k)
	}
	sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
	for _, root := range keys {
		sb = append(sb, fmt.Sprintf("%d: %v", root, groups[root]))
	}
	sb = append(sb, fmt.Sprintf("Part: %d", uf.Part))
	return strings.Join(sb, "\n")
}