Rewrote xp_sin a bit
csxeba committed Aug 18, 2017
1 parent 620b702 commit 55cd1d1
Showing 2 changed files with 38 additions and 15 deletions.
4 changes: 4 additions & 0 deletions model/layerstack.py
@@ -39,6 +39,10 @@ def feedforward(self, X):
         X = layer.feedforward(X)
         return X
 
+    def get_states(self, unfold=True):
+        hs = [layer.output for layer in self.layers]
+        return np.concatenate(hs) if unfold else hs
+
     def get_weights(self, unfold=True):
         ws = [layer.get_weights(unfold=unfold) for
               layer in self.layers if layer.trainable]
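For orientation, here is a minimal, self-contained sketch of what the new get_states helper presumably does, assuming each layer caches its latest activation in layer.output after feedforward(). The _ToyLayer and _ToyStack classes below are illustrative stand-ins, not the project's real layer stack API.

import numpy as np

class _ToyLayer:
    """Stand-in layer that only carries a cached activation."""
    def __init__(self, width):
        self.output = np.zeros(width)  # would be filled by feedforward()

class _ToyStack:
    """Stand-in for the stack object that owns .layers."""
    def __init__(self, widths):
        self.layers = [_ToyLayer(w) for w in widths]

    def get_states(self, unfold=True):
        # Mirrors the added method: collect per-layer activations and
        # either flatten them into one vector or return the per-layer list.
        hs = [layer.output for layer in self.layers]
        return np.concatenate(hs) if unfold else hs

stack = _ToyStack([120, 120, 1])
print(stack.get_states().shape)             # (241,)
print(len(stack.get_states(unfold=False)))  # 3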
49 changes: 34 additions & 15 deletions xperiments/xp_sin.py
@@ -7,28 +7,47 @@
 
 np.random.seed(1234)
 
-X = np.linspace(-6., 6., 100)[:, None]
-Y = np.sqrt(np.sin(X) + 1)
+rX = np.linspace(-6., 6., 200)[:, None]
+rY = np.sin(rX)
 
-net = BackpropNetwork([DenseLayer(120, activation="relu"),
-                       DenseLayer(120, activation="relu"),
-                       DenseLayer(40, activation="relu"),
+arg = np.arange(len(rX))
+np.random.shuffle(arg)
+targ, varg = arg[:100], arg[100:]
+targ.sort()
+varg.sort()
+
+tX, tY = rX[targ], rY[targ]
+vX, vY = rX[varg], rY[varg]
+
+tX += np.random.randn(*tX.shape) / np.sqrt(tX.size*0.25)
+
+net = BackpropNetwork([DenseLayer(120, activation="tanh"),
+                       DenseLayer(120, activation="tanh"),
                        DenseLayer(1, activation="linear")],
                       input_shape=1, optimizer="adam")
 
-pred = net.predict(X)
+tpred = net.predict(tX)
+vpred = net.predict(vX)
 plt.ion()
-plt.plot(X, Y, "b--")
+plt.plot(tX, tY, "b--", alpha=0.5, label="Training data (noisy)")
+plt.plot(rX, rY, "r--", alpha=0.5, label="Validation data (clean)")
 plt.ylim(-2, 2)
-plt.plot(X, np.ones_like(X), c="black", linestyle="--")
-plt.plot(X, -np.ones_like(X), c="black", linestyle="--")
-plt.plot(X, np.zeros_like(X), c="grey", linestyle="--")
-obj, = plt.plot(X, pred, "r-", linewidth=2)
+plt.plot(rX, np.ones_like(rX), c="black", linestyle="--")
+plt.plot(rX, -np.ones_like(rX), c="black", linestyle="--")
+plt.plot(rX, np.zeros_like(rX), c="grey", linestyle="--")
+tobj, = plt.plot(tX, tpred, "bo", markersize=3, alpha=0.5, label="Training pred")
+vobj, = plt.plot(vX, vpred, "ro", markersize=3, alpha=0.5, label="Validation pred")
+templ = "Batch: {:>5}, tMSE: {:>.4f}, vMSE: {:>.4f}"
+t = plt.title(templ.format(0, 0., 0.))
+plt.legend()
 batchno = 1
 while 1:
-    cost = net.learn_batch(X, Y)
-    pred = net.predict(X)
-    obj.set_data(X, pred)
+    tcost = net.learn_batch(tX, tY)
+    tpred = net.predict(tX)
+    vpred = net.predict(vX)
+    vcost = net.cost(vpred, vY) / len(vpred)
+    tobj.set_data(tX, tpred)
+    vobj.set_data(vX, vpred)
     plt.pause(0.01)
-    plt.title(f"Batch: {batchno:>5}, MSE: {cost:>.4f}")
+    t.set_text(templ.format(batchno, tcost, vcost))
    batchno += 1
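The gist of the rewrite is that the experiment now trains on a noisy 100-point subset of the sine curve and monitors the clean held-out half. Below is a rough NumPy-only sketch of that data preparation under the same seed; the BackpropNetwork, DenseLayer and learn_batch pieces depend on the project's own API and are left out here.

import numpy as np

np.random.seed(1234)

# Clean reference curve: 200 points of sin(x) on [-6, 6].
rX = np.linspace(-6., 6., 200)[:, None]
rY = np.sin(rX)

# Random 100/100 split into training and validation indices, kept sorted
# so the plotted curves stay monotone in x.
arg = np.arange(len(rX))
np.random.shuffle(arg)
targ, varg = np.sort(arg[:100]), np.sort(arg[100:])

tX, tY = rX[targ], rY[targ]  # training split (inputs get noise below)
vX, vY = rX[varg], rY[varg]  # validation split (left clean)

# Perturb only the training inputs; with tX.size == 100 the noise
# standard deviation is 1 / sqrt(100 * 0.25) = 0.2.
tX = tX + np.random.randn(*tX.shape) / np.sqrt(tX.size * 0.25)

print(tX.shape, tY.shape, vX.shape, vY.shape)  # (100, 1) each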
