-
Notifications
You must be signed in to change notification settings - Fork 0
/
util.py
73 lines (61 loc) · 2.06 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env python
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import pylab as pl
pl.style.use('ggplot')
class RTPlot:
    """Live-updating grid of line plots drawn with ``pylab`` in interactive mode.

    One line is created per subplot; :meth:`update` either appends a point to,
    or replaces the data of, each line and then redraws the figure.
    """

    def __init__(self, plots=(1, 1), pause=0.01):
        """Create the figure and one empty line per subplot.

        Parameters
        ----------
        plots : sequence of two ints
            ``(nrows, ncols)`` unpacked into ``pl.subplots``.
        pause : float
            Seconds passed to ``pl.pause`` after every redraw.
        """
        pl.ion()
        # squeeze=False always returns a 2-D array of Axes, so a 1x1 grid
        # needs no special case before flattening (row-major order).
        fig, axes = pl.subplots(*plots, squeeze=False)
        ax = axes.reshape(-1)
        self.fax = [fig, ax]
        self.pause = pause
        self.plots = plots
        # Per-subplot (xs, ys) buffers, each a pair of fresh lists.
        self.xy = [([], []) for _ in range(ax.size)]
        self.data = [a.plot(*xy)[0] for a, xy in zip(ax, self.xy)]
        fig.canvas.draw()
        # Only a fully constructed plot has a window worth waiting on in close().
        self._closed = False

    def update(self, xya):
        """Feed new data to the subplots and redraw the figure.

        Parameters
        ----------
        xya : sequence
            One entry per subplot, in row-major subplot order. Entries beyond
            the number of subplots are ignored; subplots without an entry keep
            their current data. Each entry is ``(x, y, append)``:

            * ``append`` truthy: the single point ``(x, y)`` is appended to
              that subplot's line;
            * ``append`` falsy: ``x`` and ``y`` (sequences or scalars)
              replace the line's data.
        """
        ax = self.fax[1]
        n = min(len(xya), ax.size)
        for i in range(n):
            entry = xya[i]
            if entry[-1]:
                xs, ys = self.xy[i]
                xs.append(entry[0])
                ys.append(entry[1])
            else:
                # Copy into fresh lists: later appends must not mutate the
                # caller's sequences, and tuples / numpy arrays have no
                # ``append`` at all.
                self.xy[i] = (np.ravel(entry[0]).tolist(),
                              np.ravel(entry[1]).tolist())
        for i in range(n):
            self.data[i].set_data(*self.xy[i])
            ax[i].relim()
            ax[i].autoscale_view(tight=True, scalex=True, scaley=True)
        self.fax[0].canvas.draw()
        pl.pause(self.pause)

    def close(self):
        """Block until a key or mouse button is pressed on the current figure.

        Waits at most once per instance, so an explicit call followed by
        garbage collection (``__del__``) does not block a second time. Does
        nothing if ``__init__`` did not complete.
        """
        if getattr(self, '_closed', True):
            return
        self._closed = True
        pl.waitforbuttonpress()

    def __del__(self):
        # Finalizers may run at interpreter shutdown, when pylab or the GUI
        # backend can already be torn down; an exception escaping __del__ is
        # only printed as "Exception ignored", so this is deliberately
        # best-effort.
        try:
            self.close()
        except Exception:
            pass
class AdaGrad:
    """AdaGrad-style optimizer with stepwise learning-rate decay and optional
    gradient-norm clipping.

    For every variable name the optimizer keeps an accumulator ``G`` (seeded
    with ``init``). Each :meth:`update` adds ``lr * g / sqrt(G)`` to the
    variable and only afterwards accumulates ``g ** 2`` into ``G``. A
    non-zero ``init`` therefore also keeps the very first step away from a
    division by zero.

    NOTE(review): the step is *added* (``v += ...``), i.e. it moves along
    ``+g``; presumably callers pass negated loss gradients or maximize an
    objective -- confirm at the call sites.
    """

    def __init__(self, vns, init=0.1, max_norm=0, decay_step=100,
                 decay_rate=0.95, min_lr=0.01):
        """Set up the per-variable accumulators and the hyper-parameters.

        Parameters
        ----------
        vns : iterable
            Names (keys) of the variables this optimizer will update.
        init : number
            Initial value of every squared-gradient accumulator.
        max_norm : number
            If > 0, gradients whose ``norm()`` exceeds it are rescaled to
            this norm; 0 disables clipping.
        decay_step, decay_rate : number
            The learning rate is multiplied by ``decay_rate`` once every
            ``decay_step`` calls to :meth:`update`.
        min_lr : number
            Lower bound for the decayed learning rate.
        """
        self.G = dict.fromkeys(vns, init)  # name -> accumulated squared grads
        self.max_norm = max_norm
        self.n_step = 0  # number of completed update() calls
        self.decay_step = decay_step
        self.decay_rate = decay_rate
        self.min_lr = min_lr

    def update(self, vgs, lrate=0.1):
        """Apply one optimization step to every ``(name, var, grad)`` triple.

        ``var`` and ``grad`` must expose a ``.data`` attribute (and
        ``grad.data`` a ``.norm()`` method when clipping is enabled). The
        update relies on ``+=`` mutating ``var.data`` in place; these look
        like autograd variables (e.g. PyTorch) -- TODO confirm. ``lrate`` is
        the base rate before decay and the ``min_lr`` floor are applied.
        """
        n_decays = int(self.n_step / self.decay_step)
        step_lr = max(lrate * self.decay_rate ** n_decays, self.min_lr)
        clipping = self.max_norm > 0
        for name, var, grad in vgs:
            param = var.data
            grad = grad.data
            if clipping:
                grad_norm = grad.norm()
                if grad_norm > self.max_norm:
                    grad = grad * self.max_norm / grad_norm
            # The denominator uses the accumulator *before* this step's g**2.
            param += step_lr * grad / np.sqrt(self.G[name])
            self.G[name] += grad ** 2
        self.n_step += 1
if __name__ == '__main__':
    # Demo: stream ten points into each of two vertically stacked subplots.
    # RTPlot.update takes a single argument; each entry is (x, y, append),
    # and append=True adds the point to that subplot's existing line.
    rtp = RTPlot([2, 1], pause=0.002)
    for i in range(10):
        rtp.update([(i, i, True), (i, i, True)])