From dc32fab3cbd1c435079257f64a47d4d1fc16b5fa Mon Sep 17 00:00:00 2001
From: Michael Fuerst
Date: Thu, 18 Jan 2018 02:26:22 +0100
Subject: [PATCH] Made training preparation a bit more robust. Users do not
 know the datatypes (users are bad at that).

---
 examples/gru_function_classifier.json       |  2 +-
 examples/gru_function_classifier_example.py | 14 ++++----
 utils/prepare_training.py                   | 37 ++++++++++++++++++---
 3 files changed, 41 insertions(+), 12 deletions(-)

diff --git a/examples/gru_function_classifier.json b/examples/gru_function_classifier.json
index 68ad63c..f858326 100644
--- a/examples/gru_function_classifier.json
+++ b/examples/gru_function_classifier.json
@@ -5,7 +5,7 @@
         "batch_size": 200,
         "iters": 50000,
         "summary_iters": 50,
-        "checkpoint_path": "models/checkpoints/function_classifier"
+        "checkpoint_path": "models/checkpoints/gru_function_classifier/chkpt"
     },
     "arch": {
         "sequence_length": 100,
diff --git a/examples/gru_function_classifier_example.py b/examples/gru_function_classifier_example.py
index f732886..afcf2f8 100644
--- a/examples/gru_function_classifier_example.py
+++ b/examples/gru_function_classifier_example.py
@@ -1,9 +1,9 @@
 import math
-
+import os
 import tensorflow as tf
 
 from datasets.classification.function_generator import function_generator
-from utils.prepare_training import write_tf_records, read_tf_records
+from utils.prepare_training import write_tf_records, read_tf_records, PHASE_TRAIN, PHASE_VALIDATION
 from models.gru_function_classifier import FunctionClassifier
 
 
@@ -17,11 +17,13 @@ def main():
     training_examples_number = 10000
     validation_examples_number = 1000
 
-    if GENERATE_DATA:
+    if GENERATE_DATA or not os.path.exists(data_tmp_folder):
+        if not os.path.exists(data_tmp_folder):
+            os.makedirs(data_tmp_folder)
         # Create training data.
         print("Generating data")
         train_data = function_generator([lambda x, off: math.sin(x / 50.0 + off), lambda x, off: x / 50.0 + off], 100, training_examples_number)
-        validation_data = function_generator([lambda x, off: math.sin(x / 50.0 + off), lambda x, off: x / 50.0 + off], 100, validation_examples_number)
+        validation_data = function_generator([lambda x, off: math.sin(x / 50.0 + off), lambda x, off: x / 50.0 + off], 100, validation_examples_number)
 
         # Write tf records
         print("Writing data")
@@ -33,8 +35,8 @@ def main():
 
     # Load data with tf records.
     print("Loading data")
-    train_features, train_labels = read_tf_records(data_tmp_folder, "train", model.hyper_params.train.batch_size, (100,), (2,), tf.float32, tf.uint8, 4)
-    validation_features, validation_labels = read_tf_records(data_tmp_folder, "validation", model.hyper_params.train.batch_size, (100,), (2,), tf.float32, tf.uint8, 2)
+    train_features, train_labels = read_tf_records(data_tmp_folder, PHASE_TRAIN, model.hyper_params.train.batch_size)
+    validation_features, validation_labels = read_tf_records(data_tmp_folder, PHASE_VALIDATION, model.hyper_params.train.batch_size)
 
     # Limit used gpu memory.
     config = tf.ConfigProto()
diff --git a/utils/prepare_training.py b/utils/prepare_training.py
index cf8490d..763a957 100644
--- a/utils/prepare_training.py
+++ b/utils/prepare_training.py
@@ -1,8 +1,14 @@
 import os
+import numpy as np
 from multiprocessing import Pool
 import tensorflow as tf
 from os import listdir
 from os.path import isfile, join
+import json
+
+
+PHASE_TRAIN = "train"
+PHASE_VALIDATION = "validation"
 
 
 def _bytes_feature(value):
@@ -51,11 +57,20 @@ def _read_tf_record(record_filename, feature_shape, label_shape, feature_type, l
     return feature, label
 
 
-def read_tf_records(folder, phase, batch_size, feature_shape, label_shape, feature_type, label_type, num_threads=4):
+def read_tf_records(folder, phase, batch_size):
+    assert phase == PHASE_TRAIN or phase == PHASE_VALIDATION
+
+    config = json.load(open(os.path.join(folder, 'config.json')))
+    feature_shape = config["feature_shape"]
+    label_shape = config["label_shape"]
+    feature_type = np.dtype(config["feature_dtype"])
+    label_type = np.dtype(config["label_dtype"])
+    num_threads = config["num_threads_" + phase]
+
     filenames = [folder + "/" + f for f in listdir(folder) if isfile(join(folder, f)) and phase in f]
 
     # Create a tf object for the filename list and the readers.
-    filename_queue = tf.train.string_input_producer(filenames, num_epochs=50000)
+    filename_queue = tf.train.string_input_producer(filenames)
     readers = [_read_tf_record(filename_queue, feature_shape, label_shape, feature_type, label_type) for _ in range(num_threads)]
 
     feature_batch, label_batch = tf.train.shuffle_batch_join(
@@ -68,9 +83,21 @@ def read_tf_records(folder, phase, batch_size, feature_shape, label_shape, featu
     return feature_batch, label_batch
 
 
-def write_tf_records(output_folder, num_threads_train, num_threads_validation, train_data, validation_data, preprocess_feature=None, preprocess_label=None):
-    args_train = [(train_data, num_threads_train, i, os.path.join(output_folder, "train_%d.tfrecords" % i ), preprocess_feature, preprocess_label) for i in range(num_threads_train)]
-    args_validation = [(validation_data, num_threads_validation, i, os.path.join(output_folder, "validation_%d.tfrecords" % i ), preprocess_feature, preprocess_label) for i in range(num_threads_validation)]
+def write_tf_records(output_folder,
+                     num_threads_train, num_threads_validation,
+                     train_data, validation_data,
+                     preprocess_feature=None, preprocess_label=None):
+    args_train = [(train_data, num_threads_train, i, os.path.join(output_folder, PHASE_TRAIN + "_%d.tfrecords" % i ), preprocess_feature, preprocess_label) for i in range(num_threads_train)]
+    args_validation = [(validation_data, num_threads_validation, i, os.path.join(output_folder, PHASE_VALIDATION + "_%d.tfrecords" % i ), preprocess_feature, preprocess_label) for i in range(num_threads_validation)]
+
+    sample_feature, sample_label = next(train_data())
+
+    config = {"num_threads_" + PHASE_TRAIN: num_threads_train, "num_threads_" + PHASE_VALIDATION: num_threads_validation,
+              "feature_shape": sample_feature.shape, "label_shape": sample_label.shape,
+              "feature_dtype": sample_feature.dtype.name, "label_dtype": sample_label.dtype.name}
+
+    with open(os.path.join(output_folder, 'config.json'), 'w') as outfile:
+        json.dump(config, outfile)
 
     for arg in args_train + args_validation:
         _write_tf_record_pool_helper(arg)
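
Usage sketch (illustrative; the folder path and shard counts below are
assumptions, and train_data / validation_data are generators as returned by
function_generator in the example above). With this change, write_tf_records
infers the feature and label shapes/dtypes from one sample and stores them,
together with the per-phase thread counts, in <output_folder>/config.json, so
read_tf_records only needs the folder, phase, and batch size:

    from utils.prepare_training import write_tf_records, read_tf_records, PHASE_TRAIN

    # Write 4 train shards and 2 validation shards; shapes and dtypes are
    # inferred from one sample and saved to config.json next to the records.
    write_tf_records("data/.records/example", 4, 2, train_data, validation_data)

    # Read back: shapes, dtypes, and reader-thread count come from config.json.
    features, labels = read_tf_records("data/.records/example", PHASE_TRAIN, 200)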