From fadaf2a7e3389f205914450e3eab83148258578c Mon Sep 17 00:00:00 2001
From: Daniel
Date: Thu, 13 Aug 2020 14:31:58 +0200
Subject: [PATCH 1/2] Freeze layers for transfer learning.

---
 training/deepspeech_training/train.py        | 30 +++++++++++++++++--
 .../deepspeech_training/util/checkpoints.py  | 15 ++++++++++
 training/deepspeech_training/util/flags.py   |  2 ++
 3 files changed, 45 insertions(+), 2 deletions(-)

diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py
index 94ca7c04d3..cc354ca0d6 100644
--- a/training/deepspeech_training/train.py
+++ b/training/deepspeech_training/train.py
@@ -322,8 +322,35 @@ def get_tower_results(iterator, optimizer, dropout_rates):
                     # Retain tower's avg losses
                     tower_avg_losses.append(avg_loss)
 
+                    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+
+                    # Filter out layers if we want to freeze some
+                    if FLAGS.freeze_source_layers > 0:
+                        filter_vars = []
+                        if FLAGS.freeze_source_layers <= 5:
+                            filter_vars.append("layer_1")
+                        if FLAGS.freeze_source_layers <= 4:
+                            filter_vars.append("layer_2")
+                        if FLAGS.freeze_source_layers <= 3:
+                            filter_vars.append("layer_3")
+                        if FLAGS.freeze_source_layers <= 2:
+                            filter_vars.append("lstm")
+                        if FLAGS.freeze_source_layers <= 1:
+                            filter_vars.append("layer_5")
+
+                        new_train_vars = list(train_vars)
+                        for fv in filter_vars:
+                            for tv in train_vars:
+                                if fv in tv.name:
+                                    new_train_vars.remove(tv)
+                        train_vars = new_train_vars
+                        msg = "Tower {} - Training only variables: {}"
+                        print(msg.format(i, [v.name for v in train_vars]))
+                    else:
+                        print("Tower {} - Training all layers".format(i))
+
                     # Compute gradients for model parameters using tower's mini-batch
-                    gradients = optimizer.compute_gradients(avg_loss)
+                    gradients = optimizer.compute_gradients(avg_loss, var_list=train_vars)
 
                     # Retain tower's gradients
                     tower_gradients.append(gradients)
@@ -671,7 +698,6 @@ def __call__(self, progress, data, **kwargs):
 
                 print('-' * 80)
 
-
         except KeyboardInterrupt:
             pass
         log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
diff --git a/training/deepspeech_training/util/checkpoints.py b/training/deepspeech_training/util/checkpoints.py
index 459a4d06c6..881c933109 100644
--- a/training/deepspeech_training/util/checkpoints.py
+++ b/training/deepspeech_training/util/checkpoints.py
@@ -46,6 +46,21 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
                       'tensors. Missing variables: {}'.format(missing_var_names))
             sys.exit(1)
 
+    if FLAGS.load_frozen_graph:
+        # After training with "freeze_source_layers" the Adam tensors for the frozen layers aren't
+        # existing anymore because they were not used
+        # Therefore we have to initialize them again to continue training on such checkpoints
+        for v in load_vars:
+            if v.op.name not in vars_in_ckpt:
+                if 'Adam' in v.name:
+                    init_vars.add(v)
+                else:
+                    msg = "Tried to load a frozen checkpoint but there was a missing " \
+                          "variable other than the Adam tensors: {}"
+                    log_error(msg.format(v))
+                    sys.exit(1)
+        load_vars -= init_vars
+
     if allow_drop_layers and FLAGS.drop_source_layers > 0:
         # This transfer learning approach requires supplying
         # the layers which we exclude from the source model.
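Reviewer note, not part of the patch: the first hunk freezes layers by excluding their variables from the var_list handed to optimizer.compute_gradients(), so those layers simply stop receiving gradient updates; that is also why the Adam moment tensors for the frozen layers are absent from the resulting checkpoints and have to be re-initialized in the checkpoints.py hunk above. A minimal, self-contained sketch of the name-based selection, in plain Python, with hypothetical variable names standing in for the entries of tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):

    # Sketch only: mirrors the filtering added in train.py, using plain strings
    # instead of TF variables. The names below are hypothetical examples.
    def trainable_after_freeze(var_names, freeze_source_layers):
        if freeze_source_layers <= 0:
            return list(var_names)  # matches the FLAGS.freeze_source_layers > 0 guard
        frozen_keys = []
        if freeze_source_layers <= 5:
            frozen_keys.append("layer_1")
        if freeze_source_layers <= 4:
            frozen_keys.append("layer_2")
        if freeze_source_layers <= 3:
            frozen_keys.append("layer_3")
        if freeze_source_layers <= 2:
            frozen_keys.append("lstm")
        if freeze_source_layers <= 1:
            frozen_keys.append("layer_5")
        # Keep only the variables whose names match none of the frozen layer keys
        return [n for n in var_names if not any(k in n for k in frozen_keys)]

    names = ["layer_1/weights:0", "layer_2/weights:0", "layer_3/weights:0",
             "cudnn_lstm/rnn/weights:0", "layer_5/weights:0", "layer_6/weights:0"]
    print(trainable_after_freeze(names, 1))  # only layer_6 stays trainable
    print(trainable_after_freeze(names, 3))  # lstm, layer_5 and layer_6 stay trainable

In the patch itself, the surviving variables are then passed as var_list to optimizer.compute_gradients(avg_loss, var_list=train_vars).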
diff --git a/training/deepspeech_training/util/flags.py b/training/deepspeech_training/util/flags.py
index cf32159498..3aad35fc72 100644
--- a/training/deepspeech_training/util/flags.py
+++ b/training/deepspeech_training/util/flags.py
@@ -93,6 +93,8 @@ def create_flags():
 
     # Transfer Learning
 
     f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)')
+    f.DEFINE_integer('freeze_source_layers', 0, 'use same value as above to freeze the other layers')
+    f.DEFINE_boolean('load_frozen_graph', False, 'Needed to load a graph checkpoint which was trained with "freeze_source_layers" flag. Allows initialization of missing optimization tensors.')
 
     # Exporting

From a0d559712f74301a7c9ea42b0bd4be9850edbcbe Mon Sep 17 00:00:00 2001
From: Daniel
Date: Thu, 1 Oct 2020 15:22:02 +0200
Subject: [PATCH 2/2] Refactor freezing.

---
 training/deepspeech_training/train.py        | 15 +---
 .../deepspeech_training/util/checkpoints.py  | 78 ++++++++++---------
 training/deepspeech_training/util/flags.py   |  5 +-
 3 files changed, 45 insertions(+), 53 deletions(-)

diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py
index cc354ca0d6..781b2033ef 100644
--- a/training/deepspeech_training/train.py
+++ b/training/deepspeech_training/train.py
@@ -29,7 +29,7 @@
 from .evaluate import evaluate
 from six.moves import zip, range
 from .util.config import Config, initialize_globals
-from .util.checkpoints import load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint
+from .util.checkpoints import drop_freeze_number_to_layers, load_or_init_graph_for_training, load_graph_for_evaluation, reload_best_checkpoint
 from .util.evaluate_tools import save_samples_json
 from .util.feeding import create_dataset, audio_to_features, audiofile_to_features
 from .util.flags import create_flags, FLAGS
@@ -326,18 +326,7 @@ def get_tower_results(iterator, optimizer, dropout_rates):
 
                     # Filter out layers if we want to freeze some
                     if FLAGS.freeze_source_layers > 0:
-                        filter_vars = []
-                        if FLAGS.freeze_source_layers <= 5:
-                            filter_vars.append("layer_1")
-                        if FLAGS.freeze_source_layers <= 4:
-                            filter_vars.append("layer_2")
-                        if FLAGS.freeze_source_layers <= 3:
-                            filter_vars.append("layer_3")
-                        if FLAGS.freeze_source_layers <= 2:
-                            filter_vars.append("lstm")
-                        if FLAGS.freeze_source_layers <= 1:
-                            filter_vars.append("layer_5")
-
+                        filter_vars = drop_freeze_number_to_layers(FLAGS.freeze_source_layers, "freeze")
                         new_train_vars = list(train_vars)
                         for fv in filter_vars:
                             for tv in train_vars:
diff --git a/training/deepspeech_training/util/checkpoints.py b/training/deepspeech_training/util/checkpoints.py
index 881c933109..46bfb4ad67 100644
--- a/training/deepspeech_training/util/checkpoints.py
+++ b/training/deepspeech_training/util/checkpoints.py
@@ -1,9 +1,9 @@
 import sys
-import tensorflow as tf
+
 import tensorflow.compat.v1 as tfv1
 
 from .flags import FLAGS
-from .logging import log_info, log_error, log_warn
+from .logging import log_error, log_info, log_warn
 
 
 def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=True):
@@ -19,47 +19,33 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
     # compatibility with older checkpoints.
     lr_var = set(v for v in load_vars if v.op.name == 'learning_rate')
     if lr_var and ('learning_rate' not in vars_in_ckpt or
-                   (FLAGS.force_initialize_learning_rate and allow_lr_init)):
+            (FLAGS.force_initialize_learning_rate and allow_lr_init)):
         assert len(lr_var) <= 1
         load_vars -= lr_var
         init_vars |= lr_var
 
-    if FLAGS.load_cudnn:
-        # Initialize training from a CuDNN RNN checkpoint
-        # Identify the variables which we cannot load, and set them
-        # for initialization
-        missing_vars = set()
-        for v in load_vars:
-            if v.op.name not in vars_in_ckpt:
-                log_warn('CUDNN variable not found: %s' % (v.op.name))
-                missing_vars.add(v)
+    # After training with "freeze_source_layers" the Adam moment tensors for the frozen layers
+    # are missing because they were not used. This might also occur when loading a cudnn checkpoint
+    # Therefore we have to initialize them again to continue training on such checkpoints
+    print_msg = False
+    for v in load_vars:
+        if v.op.name not in vars_in_ckpt:
+            if 'Adam' in v.name:
                 init_vars.add(v)
+                print_msg = True
+    if print_msg:
+        msg = "Some Adam tensors are missing, they will be initialized automatically."
+        log_info(msg)
+    load_vars -= init_vars
 
-        load_vars -= init_vars
-
-        # Check that the only missing variables (i.e. those to be initialised)
-        # are the Adam moment tensors, if they aren't then we have an issue
-        missing_var_names = [v.op.name for v in missing_vars]
-        if any('Adam' not in v for v in missing_var_names):
-            log_error('Tried to load a CuDNN RNN checkpoint but there were '
-                      'more missing variables than just the Adam moment '
-                      'tensors. Missing variables: {}'.format(missing_var_names))
-            sys.exit(1)
-
-    if FLAGS.load_frozen_graph:
-        # After training with "freeze_source_layers" the Adam tensors for the frozen layers aren't
-        # existing anymore because they were not used
-        # Therefore we have to initialize them again to continue training on such checkpoints
+    if FLAGS.load_cudnn:
+        # Check all required tensors are included in the cudnn checkpoint we want to load
         for v in load_vars:
-            if v.op.name not in vars_in_ckpt:
-                if 'Adam' in v.name:
-                    init_vars.add(v)
-                else:
-                    msg = "Tried to load a frozen checkpoint but there was a missing " \
-                          "variable other than the Adam tensors: {}"
-                    log_error(msg.format(v))
-                    sys.exit(1)
-        load_vars -= init_vars
+            if v.op.name not in vars_in_ckpt and 'Adam' not in v.op.name:
+                msg = 'Tried to load a CuDNN RNN checkpoint but there was a missing' \
+                      ' variable other than an Adam moment tensor: {}'
+                log_error(msg.format(v.op.name))
+                sys.exit(1)
 
     if allow_drop_layers and FLAGS.drop_source_layers > 0:
         # This transfer learning approach requires supplying
@@ -74,7 +60,7 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
                      'dropping only 5 layers.')
            FLAGS.drop_source_layers = 5
 
-        dropped_layers = ['2', '3', 'lstm', '5', '6'][-1 * int(FLAGS.drop_source_layers):]
+        dropped_layers = drop_freeze_number_to_layers(FLAGS.drop_source_layers, "drop")
         # Initialize all variables needed for DS, but not loaded from ckpt
         for v in load_vars:
             if any(layer in v.op.name for layer in dropped_layers):
@@ -90,6 +76,24 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
             session.run(v.initializer)
 
 
+def drop_freeze_number_to_layers(drop_freeze_number, mode):
+    """ Convert number of layers to drop or freeze into layer names """
+
+    if drop_freeze_number >= 6:
+        log_warn('The checkpoint only has 6 layers, but you are trying '
+                 'to drop or freeze all of them or more. Continuing with 5 layers.')
+        drop_freeze_number = 5
+
+    layer_keys = ["layer_1", "layer_2", "layer_3", "lstm", "layer_5", "layer_6"]
+    if mode == "drop":
+        layer_keys = layer_keys[-1 * int(drop_freeze_number):]
+    elif mode == "freeze":
+        layer_keys = layer_keys[:-1 * int(drop_freeze_number)]
+    else:
+        raise ValueError
+    return layer_keys
+
+
 def _checkpoint_path_or_none(checkpoint_filename):
     checkpoint = tfv1.train.get_checkpoint_state(FLAGS.load_checkpoint_dir, checkpoint_filename)
     if not checkpoint:
diff --git a/training/deepspeech_training/util/flags.py b/training/deepspeech_training/util/flags.py
index 3aad35fc72..525d0a1b75 100644
--- a/training/deepspeech_training/util/flags.py
+++ b/training/deepspeech_training/util/flags.py
@@ -92,9 +92,8 @@ def create_flags():
 
     # Transfer Learning
 
-    f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)')
-    f.DEFINE_integer('freeze_source_layers', 0, 'use same value as above to freeze the other layers')
-    f.DEFINE_boolean('load_frozen_graph', False, 'Needed to load a graph checkpoint which was trained with "freeze_source_layers" flag. Allows initialization of missing optimization tensors.')
+    f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output == 2, etc)')
+    f.DEFINE_integer('freeze_source_layers', 0, 'freeze layer weights (to freeze all but output == 1, freeze all but penultimate and output == 2, etc). Normally used in combination with "drop_source_layers" flag and should be used in a two step training (first drop and freeze layers and train a few epochs, second continue without both flags)')
 
     # Exporting
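
Reviewer note, not part of the patch series: after the refactor, --drop_source_layers and --freeze_source_layers share the drop_freeze_number_to_layers() helper, so the same integer N names the last N layers for dropping (re-initialization) while freezing keeps exactly those last N layers trainable. A standalone sketch of that mapping, mirroring the helper added to checkpoints.py (the min() clamp stands in for the log_warn branch):

    # Sketch only: shows what drop_freeze_number_to_layers returns for the two modes,
    # using the same layer order as in checkpoints.py.
    def drop_freeze_number_to_layers(drop_freeze_number, mode):
        layer_keys = ["layer_1", "layer_2", "layer_3", "lstm", "layer_5", "layer_6"]
        n = min(int(drop_freeze_number), 5)  # at most 5 of the 6 layers can be affected
        if mode == "drop":
            return layer_keys[-n:]   # last n layers are re-initialized
        if mode == "freeze":
            return layer_keys[:-n]   # all but the last n layers are frozen
        raise ValueError("mode must be 'drop' or 'freeze'")

    print(drop_freeze_number_to_layers(1, "drop"))    # ['layer_6']
    print(drop_freeze_number_to_layers(1, "freeze"))  # ['layer_1', 'layer_2', 'layer_3', 'lstm', 'layer_5']
    print(drop_freeze_number_to_layers(2, "drop"))    # ['layer_5', 'layer_6']
    print(drop_freeze_number_to_layers(2, "freeze"))  # ['layer_1', 'layer_2', 'layer_3', 'lstm']

Passing the same value to both flags therefore re-initializes the last N layers and trains only them; per the updated help text, this is meant as the first step of a two-step training, with a second run continuing without either flag.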