This repository has been archived by the owner on Aug 12, 2020. It is now read-only.

Commit

Made training preparation a bit more robust.
The user does not know the datatypes (users are bad at that).
penguinmenac3 committed Jan 18, 2018
1 parent d71f3df commit dc32fab
Showing 3 changed files with 41 additions and 12 deletions.
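
In effect, the commit moves the shape and dtype bookkeeping out of the call site: write_tf_records now stores this metadata in a config.json next to the record files, and read_tf_records reads it back instead of taking it as arguments. A before/after sketch of the call site, taken from the example diff below:

    # Before: the caller had to restate shapes, dtypes and reader threads by hand.
    train_features, train_labels = read_tf_records(
        data_tmp_folder, "train", model.hyper_params.train.batch_size,
        (100,), (2,), tf.float32, tf.uint8, 4)

    # After: only folder, phase and batch size; everything else comes from config.json.
    train_features, train_labels = read_tf_records(
        data_tmp_folder, PHASE_TRAIN, model.hyper_params.train.batch_size)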
2 changes: 1 addition & 1 deletion examples/gru_function_classifier.json
@@ -5,7 +5,7 @@
     "batch_size": 200,
     "iters": 50000,
     "summary_iters": 50,
-    "checkpoint_path": "models/checkpoints/function_classifier"
+    "checkpoint_path": "models/checkpoints/gru_function_classifier/chkpt"
   },
   "arch": {
     "sequence_length": 100,
14 changes: 8 additions & 6 deletions examples/gru_function_classifier_example.py
@@ -1,9 +1,9 @@
 import math
 
+import os
 import tensorflow as tf
 
 from datasets.classification.function_generator import function_generator
-from utils.prepare_training import write_tf_records, read_tf_records
+from utils.prepare_training import write_tf_records, read_tf_records, PHASE_TRAIN, PHASE_VALIDATION
 from models.gru_function_classifier import FunctionClassifier


@@ -17,11 +17,13 @@ def main():
     training_examples_number = 10000
     validation_examples_number = 1000
 
-    if GENERATE_DATA:
+    if GENERATE_DATA or not os.path.exists(data_tmp_folder):
+        if not os.path.exists(data_tmp_folder):
+            os.makedirs(data_tmp_folder)
         # Create training data.
         print("Generating data")
         train_data = function_generator([lambda x, off: math.sin(x / 50.0 + off), lambda x, off: x / 50.0 + off], 100, training_examples_number)
-        validation_data = function_generator([lambda x, off: math.sin(x / 50.0 + off), lambda x, off: x / 50.0 + off], 100, validation_examples_number)
+        validation_data = function_generator([lambda x, off: math.sin(x / 50.0 + off), lambda x, off: x / 50.0 + off], 100, validation_examples_number)
 
         # Write tf records
         print("Writing data")
@@ -33,8 +35,8 @@ def main():
 
     # Load data with tf records.
     print("Loading data")
-    train_features, train_labels = read_tf_records(data_tmp_folder, "train", model.hyper_params.train.batch_size, (100,), (2,), tf.float32, tf.uint8, 4)
-    validation_features, validation_labels = read_tf_records(data_tmp_folder, "validation", model.hyper_params.train.batch_size, (100,), (2,), tf.float32, tf.uint8, 2)
+    train_features, train_labels = read_tf_records(data_tmp_folder, PHASE_TRAIN, model.hyper_params.train.batch_size)
+    validation_features, validation_labels = read_tf_records(data_tmp_folder, PHASE_VALIDATION, model.hyper_params.train.batch_size)
 
     # Limit used gpu memory.
     config = tf.ConfigProto()
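
For this example, the config.json that write_tf_records now drops next to the records would hold roughly the following (values inferred from the old explicit call above; the file itself is written as unformatted JSON):

    {
        "num_threads_train": 4,
        "num_threads_validation": 2,
        "feature_shape": [100],
        "label_shape": [2],
        "feature_dtype": "float32",
        "label_dtype": "uint8"
    }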
37 changes: 32 additions & 5 deletions utils/prepare_training.py
@@ -1,8 +1,14 @@
 import os
+import numpy as np
 from multiprocessing import Pool
 import tensorflow as tf
 from os import listdir
 from os.path import isfile, join
+import json
+
+
+PHASE_TRAIN = "train"
+PHASE_VALIDATION = "validation"
 
 
 def _bytes_feature(value):
@@ -51,11 +57,20 @@ def _read_tf_record(record_filename, feature_shape, label_shape, feature_type, label_type):
     return feature, label
 
 
-def read_tf_records(folder, phase, batch_size, feature_shape, label_shape, feature_type, label_type, num_threads=4):
+def read_tf_records(folder, phase, batch_size):
+    assert phase == PHASE_TRAIN or phase == PHASE_VALIDATION
+
+    config = json.load(open(os.path.join(folder, 'config.json')))
+    feature_shape = config["feature_shape"]
+    label_shape = config["label_shape"]
+    feature_type = np.dtype(config["feature_dtype"])
+    label_type = np.dtype(config["label_dtype"])
+    num_threads = config["num_threads_" + phase]
+
     filenames = [folder + "/" + f for f in listdir(folder) if isfile(join(folder, f)) and phase in f]
 
     # Create a tf object for the filename list and the readers.
-    filename_queue = tf.train.string_input_producer(filenames, num_epochs=50000)
+    filename_queue = tf.train.string_input_producer(filenames)
     readers = [_read_tf_record(filename_queue, feature_shape, label_shape, feature_type, label_type) for _ in range(num_threads)]
 
     feature_batch, label_batch = tf.train.shuffle_batch_join(
@@ -68,9 +83,21 @@ def read_tf_records(folder, phase, batch_size, feature_shape, label_shape, feature_type, label_type, num_threads=4):
     return feature_batch, label_batch
 
 
-def write_tf_records(output_folder, num_threads_train, num_threads_validation, train_data, validation_data, preprocess_feature=None, preprocess_label=None):
-    args_train = [(train_data, num_threads_train, i, os.path.join(output_folder, "train_%d.tfrecords" % i ), preprocess_feature, preprocess_label) for i in range(num_threads_train)]
-    args_validation = [(validation_data, num_threads_validation, i, os.path.join(output_folder, "validation_%d.tfrecords" % i ), preprocess_feature, preprocess_label) for i in range(num_threads_validation)]
+def write_tf_records(output_folder,
+                     num_threads_train, num_threads_validation,
+                     train_data, validation_data,
+                     preprocess_feature=None, preprocess_label=None):
+    args_train = [(train_data, num_threads_train, i, os.path.join(output_folder, PHASE_TRAIN + "_%d.tfrecords" % i ), preprocess_feature, preprocess_label) for i in range(num_threads_train)]
+    args_validation = [(validation_data, num_threads_validation, i, os.path.join(output_folder, PHASE_VALIDATION + "_%d.tfrecords" % i ), preprocess_feature, preprocess_label) for i in range(num_threads_validation)]
+
+    sample_feature, sample_label = next(train_data())
+
+    config = {"num_threads_" + PHASE_TRAIN: num_threads_train, "num_threads_" + PHASE_VALIDATION: num_threads_validation,
+              "feature_shape": sample_feature.shape, "label_shape": sample_label.shape,
+              "feature_dtype": sample_feature.dtype.name, "label_dtype": sample_label.dtype.name}
+
+    with open(os.path.join(output_folder, 'config.json'), 'w') as outfile:
+        json.dump(config, outfile)
 
     for arg in args_train + args_validation:
         _write_tf_record_pool_helper(arg)
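
Storing dtype.name works because numpy dtype names round-trip cleanly through strings, which is what read_tf_records relies on when it calls np.dtype(config["feature_dtype"]). A minimal standalone check (assumes nothing beyond numpy):

    import numpy as np

    x = np.zeros((100,), dtype=np.float32)
    name = x.dtype.name            # "float32" -- a plain, JSON-serializable string
    assert np.dtype(name) == x.dtype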
