Skip to content

Commit

Permalink
run network training as subprocesses to free memory
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-louvet committed Aug 5, 2020
1 parent 646e797 commit 767275f
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 24 deletions.
34 changes: 10 additions & 24 deletions ai_playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import os
import pickle
from sklearn.utils import shuffle
from generate_data_dimension import *
from generate_data_dimension import DataDescriptor, DataPoint, DataInstance
import numpy as np
import gc
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
import subprocess
import json

randomSeed = 468643654

Expand Down Expand Up @@ -102,24 +101,11 @@
os.remove(os.path.join(root, file))

for i in [1, 2, 4, 6, 8, 10, 12, 14, 16]:
with open('/tmp/data_file.pkl', 'wb') as dump:
pickle.dump(data, dump, pickle.HIGHEST_PROTOCOL)

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(
8, input_dim=input_shape[0], activation='relu', kernel_initializer='he_uniform'))
for _ in range(i):
model.add(tf.keras.layers.Dense(8, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

csv = tf.keras.callbacks.CSVLogger(
mypath + str(i) + 'layer.csv', separator=',', append=False)
model.summary()
model.compile(optimizer="adam",
loss='binary_crossentropy', metrics=['accuracy'])
model.fit(data, label, validation_split=0.2, batch_size=64,
epochs=epoch_number, shuffle=True, verbose=2, callbacks=[csv])
model.save(mypath + str(i) + 'layer.h5')
del model
del csv
gc.collect()
tf.keras.backend.clear_session()
tf.compat.v1.reset_default_graph()
with open('/tmp/label_file.pkl', 'wb') as dump:
pickle.dump(label, dump, pickle.HIGHEST_PROTOCOL)

subprocess.call(['python3', 'rasScript.py', str(i), str(input_shape[0]),
str(mypath), str(epoch_number), '/tmp/data_file.pkl', '/tmp/label_file.pkl'])
1 change: 1 addition & 0 deletions generate_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def train_and_save(*args, **kwargs):

def full_net_combined(i, input_shape, mypath, epoch_number, data, label):
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
model = build_model(
depth=i, input_shape=input_shape, width=8, activation='relu')
csv = tf.keras.callbacks.CSVLogger(
Expand Down
41 changes: 41 additions & 0 deletions rasScript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pickle
import sys
import gc
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

i = int(sys.argv[1])
input_shape = int(sys.argv[2])
mypath = sys.argv[3]
epoch_number = int(sys.argv[4])


with open(sys.argv[5], 'rb') as input:
data = pickle.load(input)

with open(sys.argv[6], 'rb') as input:
label = pickle.load(input)

print(i, input_shape, epoch_number)

model = tf.keras.Sequential()

model.add(tf.keras.layers.Dense(
8, input_dim=input_shape, activation='relu', kernel_initializer='he_uniform'))
for _ in range(i):
model.add(tf.keras.layers.Dense(8, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

csv = tf.keras.callbacks.CSVLogger(
mypath + str(i) + 'layer.csv', separator=',', append=False)

model.summary()
model.compile(optimizer="adam", loss='binary_crossentropy',
metrics=['accuracy'])
model.fit(data, label, validation_split=0.2, batch_size=64,
epochs=epoch_number, shuffle=True, verbose=2, callbacks=[csv])
model.save(mypath + str(i) + 'layer.h5')
del model
gc.collect()
tf.keras.backend.clear_session()
tf.compat.v1.reset_default_graph()

0 comments on commit 767275f

Please sign in to comment.