Skip to content

Commit

Permalink
What is broken cannot be reforged
Browse files Browse the repository at this point in the history
  • Loading branch information
csxeba committed Dec 20, 2017
1 parent 60ce37a commit e85d08c
Showing 1 changed file with 26 additions and 17 deletions.
43 changes: 26 additions & 17 deletions xperiments/xp_petofi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,31 @@
from brainforge.optimization import RMSprop

from keras.models import Sequential
from keras.layers import LSTM as kLSTM, Dense as kDense
from keras.layers import LSTM as kLSTM, Dense as kDense, BatchNormalization

data = LazyText(os.path.expanduser("~/tmp/RIS.txt.txt"), n_gram=1, timestep=5)
data = LazyText(os.path.expanduser("~/Prog/data/txt/scripts.txt"), n_gram=1, timestep=10)
inshape, outshape = data.neurons_required
net = BackpropNetwork(input_shape=inshape, layerstack=[
LSTM(60, activation="tanh"),
DenseLayer(60, activation="tanh"),
DenseLayer(outshape, activation="softmax")
], cost="xent", optimizer=RMSprop(eta=0.01))

net.fit_generator(data.batchgen(20), lessons_per_epoch=data.N)

keras = Sequential([
kLSTM(60, input_shape=inshape, activation="tanh"),
kDense(60, activation="tanh"),
kDense(outshape, activation="softmax")
])
keras.compile(optimizer="rmsprop", loss="categorical_crossentropy", metrics=["acc"])
keras.fit_generator(data.batchgen(20), steps_per_epoch=data.N, epochs=30)


def run_brainforge():
net = BackpropNetwork(input_shape=inshape, layerstack=[
LSTM(60, activation="tanh"),
DenseLayer(60, activation="tanh"),
DenseLayer(outshape, activation="softmax")
], cost="xent", optimizer=RMSprop(eta=0.01))

net.fit_generator(data.batchgen(20), lessons_per_epoch=data.N)


def run_keras():
keras = Sequential([
kLSTM(120, input_shape=inshape, activation="relu"), BatchNormalization(),
kDense(60, activation="relu"), BatchNormalization(),
kDense(outshape, activation="softmax")
])
keras.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["acc"])
keras.fit_generator(data.batchgen(50), steps_per_epoch=data.N, epochs=30)


if __name__ == '__main__':
run_keras()

0 comments on commit e85d08c

Please sign in to comment.