Skip to content

Commit

Permalink
GC seems to pass for LSTM o.o
Browse files Browse the repository at this point in the history
  • Loading branch information
csxeba committed May 7, 2019
1 parent 9ae1120 commit 5876042
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
8 changes: 4 additions & 4 deletions xperiments/xp_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
inshape, outshape = X.shape[1:], Y.shape[1:]

network = BackpropNetwork(input_shape=inshape, layerstack=[
DenseLayer(64, activation="tanh", trainable=1),
DenseLayer(32, activation="tanh", trainable=1),
DenseLayer(outshape, activation="sigmoid", trainable=1)
], cost="bxent", optimizer="sgd")
DenseLayer(32, activation="sigmoid", trainable=1),
DenseLayer(32, activation="sigmoid", trainable=1),
DenseLayer(outshape, activation="linear", trainable=1)
], cost="mse", optimizer="sgd")
network.fit(X[5:], Y[5:], epochs=1, batch_size=len(X)-5, verbose=0)

gcsuite = GradientCheck(network, epsilon=1e-7)
Expand Down
13 changes: 6 additions & 7 deletions xperiments/xp_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@
from brainforge.layers import LSTM, DenseLayer
from brainforge.gradientcheck import GradientCheck

np.random.seed(1337)
# np.random.seed(1337)

DSHAPE = 10, 1, 15
OUTSHP = 10, 15
DSHAPE = 20, 10, 1
OUTSHP = 20, 1
X = np.random.randn(*DSHAPE)
Y = np.random.randn(*OUTSHP)

net = BackpropNetwork(input_shape=DSHAPE[1:], layerstack=[
LSTM(32, activation="tanh"),
DenseLayer(10, activation="tanh", trainable=False),
DenseLayer(OUTSHP[1:], activation="linear", trainable=False)
LSTM(16, activation="tanh"),
DenseLayer(OUTSHP[1:], activation="linear", trainable=0)
], cost="mse", optimizer="sgd")

net.fit(X, Y, epochs=1, verbose=0)
GradientCheck(net, epsilon=1e-6, display=True).run(X, Y, throw=True)
GradientCheck(net, display=True).run(X, Y, throw=True)
2 changes: 1 addition & 1 deletion xperiments/xp_pggymin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from matplotlib import pyplot

from brainforge import BackpropNetwork
from brainforge.learner import BackpropNetwork
from brainforge.layers import DenseLayer
from brainforge.optimization import Momentum
from brainforge.reinforcement import PG, AgentConfig
Expand Down

0 comments on commit 5876042

Please sign in to comment.