diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py
index e8adb11987..8d9d7d7c24 100644
--- a/training/deepspeech_training/train.py
+++ b/training/deepspeech_training/train.py
@@ -321,8 +321,35 @@ def get_tower_results(iterator, optimizer, dropout_rates):
                     # Retain tower's avg losses
                     tower_avg_losses.append(avg_loss)

+                    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+
+                    # Filter out layers if we want to freeze some
+                    if FLAGS.freeze_source_layers > 0:
+                        filter_vars = []
+                        if FLAGS.freeze_source_layers <= 5:
+                            filter_vars.append("layer_1")
+                        if FLAGS.freeze_source_layers <= 4:
+                            filter_vars.append("layer_2")
+                        if FLAGS.freeze_source_layers <= 3:
+                            filter_vars.append("layer_3")
+                        if FLAGS.freeze_source_layers <= 2:
+                            filter_vars.append("lstm")
+                        if FLAGS.freeze_source_layers <= 1:
+                            filter_vars.append("layer_5")
+
+                        new_train_vars = list(train_vars)
+                        for fv in filter_vars:
+                            for tv in train_vars:
+                                if fv in tv.name:
+                                    new_train_vars.remove(tv)
+                        train_vars = new_train_vars
+                        msg = "Tower {} - Training only variables: {}"
+                        print(msg.format(i, [v.name for v in train_vars]))
+                    else:
+                        print("Tower {} - Training all layers".format(i))
+
                     # Compute gradients for model parameters using tower's mini-batch
-                    gradients = optimizer.compute_gradients(avg_loss)
+                    gradients = optimizer.compute_gradients(avg_loss, var_list=train_vars)

                     # Retain tower's gradients
                     tower_gradients.append(gradients)
@@ -667,7 +694,6 @@ def __call__(self, progress, data, **kwargs):
                 print('-' * 80)

-
         except KeyboardInterrupt:
             pass

     log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time))
diff --git a/training/deepspeech_training/util/checkpoints.py b/training/deepspeech_training/util/checkpoints.py
index 459a4d06c6..881c933109 100644
--- a/training/deepspeech_training/util/checkpoints.py
+++ b/training/deepspeech_training/util/checkpoints.py
@@ -46,6 +46,21 @@ def _load_checkpoint(session, checkpoint_path, allow_drop_layers, allow_lr_init=
                       'tensors. Missing variables: {}'.format(missing_var_names))
             sys.exit(1)

+    if FLAGS.load_frozen_graph:
+        # After training with "freeze_source_layers", the Adam tensors for the frozen layers
+        # no longer exist in the checkpoint because they were never used.
+        # Therefore we have to initialize them again to continue training from such checkpoints.
+        for v in load_vars:
+            if v.op.name not in vars_in_ckpt:
+                if 'Adam' in v.name:
+                    init_vars.add(v)
+                else:
+                    msg = "Tried to load a frozen checkpoint but there was a missing " \
+                          "variable other than the Adam tensors: {}"
+                    log_error(msg.format(v))
+                    sys.exit(1)
+        load_vars -= init_vars
+
     if allow_drop_layers and FLAGS.drop_source_layers > 0:
         # This transfer learning approach requires supplying
         # the layers which we exclude from the source model.
diff --git a/training/deepspeech_training/util/flags.py b/training/deepspeech_training/util/flags.py
index fe78f0b706..a98ed6bc47 100644
--- a/training/deepspeech_training/util/flags.py
+++ b/training/deepspeech_training/util/flags.py
@@ -93,6 +93,8 @@ def create_flags():

     # Transfer Learning

     f.DEFINE_integer('drop_source_layers', 0, 'single integer for how many layers to drop from source model (to drop just output == 1, drop penultimate and output ==2, etc)')
+    f.DEFINE_integer('freeze_source_layers', 0, 'use same value as drop_source_layers to freeze all remaining (non-dropped) source layers; the dropped layers stay trainable')
+    f.DEFINE_boolean('load_frozen_graph', False, 'Needed to load a checkpoint that was trained with the "freeze_source_layers" flag. Allows initialization of the missing optimizer (Adam) tensors.')

     # Exporting
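
For reference, here is a minimal standalone sketch (plain Python, no TensorFlow session required) of how a given `--freeze_source_layers` value maps onto the variable filtering added in `get_tower_results()`. The variable names below are hypothetical stand-ins for the entries returned by `tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)`; only the flag logic itself mirrors the diff.

```python
# Sketch of the "freeze_source_layers" filtering added to get_tower_results().
# The names below are hypothetical stand-ins for the graph's trainable variables.
MOCK_TRAINABLE_VARS = [
    "layer_1/weights", "layer_2/weights", "layer_3/weights",
    "lstm/rnn/kernel", "layer_5/weights", "layer_6/weights",
]


def trainable_after_freezing(freeze_source_layers, train_vars=MOCK_TRAINABLE_VARS):
    """Mirror the flag semantics: a value of N keeps the top N layers trainable
    and freezes everything below them; the output layer (layer_6) is never frozen."""
    filter_vars = []
    if 0 < freeze_source_layers <= 5:
        filter_vars.append("layer_1")
    if 0 < freeze_source_layers <= 4:
        filter_vars.append("layer_2")
    if 0 < freeze_source_layers <= 3:
        filter_vars.append("layer_3")
    if 0 < freeze_source_layers <= 2:
        filter_vars.append("lstm")
    if 0 < freeze_source_layers <= 1:
        filter_vars.append("layer_5")
    # Keep only the variables whose names do not match a frozen-layer prefix.
    return [v for v in train_vars if not any(fv in v for fv in filter_vars)]


if __name__ == "__main__":
    # --freeze_source_layers 1: only the output layer is trained
    print(trainable_after_freezing(1))  # ['layer_6/weights']
    # --freeze_source_layers 2: output and penultimate layer are trained
    print(trainable_after_freezing(2))  # ['layer_5/weights', 'layer_6/weights']
    # --freeze_source_layers 0: nothing is filtered, all layers are trained
    print(trainable_after_freezing(0))  # all six variables
```

With these semantics, combining `--drop_source_layers 1 --freeze_source_layers 1` re-initializes the output layer and trains only it, which is the pairing the new flag's help text suggests.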