Online mixing of noise audio data in the training step #2622
base: master
Changes from 13 commits
@@ -218,7 +218,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
     the decoded result and the batch's original Y.
     '''
     # Obtain the next batch of data
-    batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
+    batch_filenames, (batch_x, batch_seq_len), batch_y, review_audio = iterator.get_next()
 
     if FLAGS.use_cudnn_rnn:
         rnn_impl = rnn_impl_cudnn_rnn
@@ -238,7 +238,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
     avg_loss = tf.reduce_mean(input_tensor=total_loss)
 
     # Finally we return the average loss
-    return avg_loss, non_finite_files
+    return avg_loss, non_finite_files, review_audio
 
 
 # Adam Optimization
@@ -299,7 +299,7 @@ def get_tower_results(iterator, optimizer, dropout_rates):
         with tf.name_scope('tower_%d' % i):
             # Calculate the avg_loss and mean_edit_distance and retrieve the decoded
             # batch along with the original batch's labels (Y) of this tower
-            avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
+            avg_loss, non_finite_files, review_audio = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
 
             # Allow for variables to be re-used by the next tower
             tfv1.get_variable_scope().reuse_variables()
@@ -316,6 +316,8 @@ def get_tower_results(iterator, optimizer, dropout_rates):
             tower_non_finite_files.append(non_finite_files)
 
     avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
+    if FLAGS.augmentation_review_audio_steps:
+        tfv1.summary.audio(name='step_audio', tensor=review_audio, sample_rate=16000, collections=['step_audio_summaries'])
 
     tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])
 
     all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)

Review comment: Please change …
Reply: fixed
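For context on the two added lines: the audio summary is registered under its own collection, 'step_audio_summaries', so it can be merged and fetched separately from the cheap per-step scalar summaries and evaluated only for the few steps being reviewed. A minimal, self-contained sketch of that collection-scoped summary pattern (illustrative only; the placeholder below stands in for the real training graph):

    import numpy as np
    import tensorflow.compat.v1 as tfv1  # TF1-style graph mode, as in the training code
    tfv1.disable_eager_execution()

    # Placeholder standing in for the batch of augmented audio, shape [batch, samples].
    review_audio = tfv1.placeholder(tfv1.float32, [None, None], name='review_audio')

    # The audio summary goes into its own collection so the default per-step merge
    # does not pull the (large) waveform data in on every batch.
    tfv1.summary.audio(name='step_audio', tensor=review_audio, sample_rate=16000,
                       collections=['step_audio_summaries'])
    tfv1.summary.scalar(name='step_loss', tensor=tfv1.constant(0.0), collections=['step_summaries'])

    step_audio_summaries_op = tfv1.summary.merge_all('step_audio_summaries')
    step_summaries_op = tfv1.summary.merge_all('step_summaries')

    with tfv1.Session() as session:
        writer = tfv1.summary.FileWriter('/tmp/summary_demo', session.graph)
        batch = np.random.uniform(-1.0, 1.0, size=(2, 16000)).astype(np.float32)
        # Scalar summaries are fetched every step; the audio summary only on review steps.
        scalar_summary = session.run(step_summaries_op)
        audio_summary = session.run(step_audio_summaries_op, feed_dict={review_audio: batch})
        writer.add_summary(scalar_summary, global_step=0)
        writer.add_summary(audio_summary, global_step=0)
        writer.close()

In the PR the merged audio op is only added to the fetch list for the first FLAGS.augmentation_review_audio_steps batches of the first epoch, which keeps the overhead negligible.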
@@ -436,7 +438,8 @@ def train():
                                batch_size=FLAGS.train_batch_size,
                                enable_cache=FLAGS.feature_cache and do_cache_dataset,
                                cache_path=FLAGS.feature_cache,
-                               train_phase=True)
+                               train_phase=True,
+                               noise_dirs_or_files=FLAGS.audio_aug_mix_noise_train_dirs_or_files)
 
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
                                                  tfv1.data.get_output_shapes(train_set),
@@ -447,7 +450,7 @@ def train():
 
     if FLAGS.dev_files:
         dev_csvs = FLAGS.dev_files.split(',')
-        dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False) for csv in dev_csvs]
+        dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False, noise_dirs_or_files=FLAGS.audio_aug_mix_noise_dev_dirs_or_files) for csv in dev_csvs]
         dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]
 
     # Dropout
@@ -483,6 +486,7 @@ def train():
     apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)
 
     # Summaries
+    step_audio_summaries_op = tfv1.summary.merge_all('step_audio_summaries')
     step_summaries_op = tfv1.summary.merge_all('step_summaries')
     step_summary_writers = {
         'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
@@ -593,11 +597,20 @@ def __call__(self, progress, data, **kwargs):
         session.run(init_op)
 
         # Batch loop
+        i_audio_steps = 0
         while True:
             try:
-                _, current_step, batch_loss, problem_files, step_summary = \
-                    session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
-                                feed_dict=feed_dict)
+                step_audio_summary = None
+                if i_audio_steps < FLAGS.augmentation_review_audio_steps and epoch == 0:
+                    _, current_step, batch_loss, problem_files, step_summary, step_audio_summary = \
+                        session.run([train_op, global_step, loss, non_finite_files, step_summaries_op, step_audio_summaries_op],
+                                    feed_dict=feed_dict)
+                    i_audio_steps += 1
+                else:
+                    _, current_step, batch_loss, problem_files, step_summary = \
+                        session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
+                                    feed_dict=feed_dict)
             except tf.errors.OutOfRangeError:
                 break

Review comment: Better: …
Reply: modified
@@ -611,6 +624,9 @@ def __call__(self, progress, data, **kwargs):
 
             pbar.update(step_count)
 
+            if step_audio_summary is not None:
+                step_summary_writer.add_summary(step_audio_summary, current_step)
+
             step_summary_writer.add_summary(step_summary, current_step)
 
             if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
@@ -673,7 +689,7 @@ def __call__(self, progress, data, **kwargs):
 
 
 def test():
-    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading)
+    samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading, noise_dirs_or_files=FLAGS.audio_aug_mix_noise_test_dirs_or_files)
     if FLAGS.test_output_file:
         # Save decoded tuples as JSON, converting NumPy floats to Python floats
         json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
@@ -896,7 +912,7 @@ def do_single_file_inference(input_file_path):
         print('Could not load checkpoint from {}'.format(FLAGS.checkpoint_dir))
         sys.exit(1)
 
-    features, features_len = audiofile_to_features(input_file_path)
+    features, features_len = audiofile_to_features(input_file_path, 0.0)
 
     previous_state_c = np.zeros([1, Config.n_cell_dim])
     previous_state_h = np.zeros([1, Config.n_cell_dim])

Review comment: ?
Reply: fixed
New file (noise dataset preparation script):

@@ -0,0 +1,172 @@
from __future__ import absolute_import, division, print_function

Review comment (on the __future__ import): How about …

# Make sure we can import stuff from util/
# This script needs to be run from the root of the DeepSpeech repository

from util.feeding import secs_to_hours
from librosa import get_duration
from multiprocessing import Pool
from functools import partial
import math
import argparse
import sys
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))
try:
    import tqdm
except ImportError as err:
    print('[ImportError] try `pip install tqdm`')
    raise err

Review comment (on the tqdm import): We want to keep dependencies at a minimum. Please change this to the way we deal with progress notifications in the importers.
Review comment: You should be able to replace it with: …
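The reviewer's suggested replacement is elided above. One possible dependency-free substitute (a sketch under that assumption, not the actual snippet from the review) is a simple counter written to stderr in place of tqdm:

    import sys

    def print_progress(iterable, total):
        # Minimal tqdm stand-in: overwrite a single status line as results arrive.
        for done, item in enumerate(iterable, start=1):
            print('\rProcessed {}/{}'.format(done, total), end='', file=sys.stderr, flush=True)
            yield item
        print('', file=sys.stderr)

Used in main() it would read: `for _ in print_progress(pool.imap_unordered(convert_func, filenames), len(filenames)): pass`.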
try:
    from pydub import AudioSegment
except ImportError as err:
    print('[ImportError] try `sudo apt-get install ffmpeg && pip install pydub`')
    raise err

Review comment (on the pydub import): We want to keep dependencies at a minimum. Please check if your required functionality couldn't be covered by e.g. …


def detect_silence(sound: AudioSegment, silence_threshold=-50.0, chunk_size=10):
    start_trim = 0  # ms
    sound_size = len(sound)
    assert chunk_size > 0  # to avoid infinite loop
    while sound[start_trim:(start_trim + chunk_size)].dBFS < silence_threshold and start_trim < sound_size:
        start_trim += chunk_size

    end_trim = sound_size
    while sound[(end_trim - chunk_size):end_trim].dBFS < silence_threshold and end_trim > 0:
        end_trim -= chunk_size

    start_trim = min(sound_size, start_trim)
    end_trim = max(0, end_trim)

    return min([start_trim, end_trim]), max([start_trim, end_trim])


def trim_silence_audio(sound: AudioSegment, silence_threshold=-50.0, chunk_size=10):
    start_trim, end_trim = detect_silence(sound, silence_threshold, chunk_size)
    return sound[start_trim:end_trim]
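A quick usage sketch for the two helpers above (the file name is hypothetical, and the helpers are assumed to be in scope): load a clip with pydub, strip leading and trailing audio quieter than -50 dBFS, and write the result back out.

    from pydub import AudioSegment

    sound = AudioSegment.from_file('noise_sample.wav')  # hypothetical input file
    trimmed = trim_silence_audio(sound, silence_threshold=-50.0, chunk_size=10)
    print('before: {:.2f}s  after: {:.2f}s'.format(sound.duration_seconds, trimmed.duration_seconds))
    trimmed.export('noise_sample_trimmed.wav', format='wav')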
def convert(filename, dst_dirpath, dirpath, normalize, trim_silence,
            min_duration_seconds, max_duration_seconds):
    if not filename.endswith(('.wav', '.raw')):
        return

    filepath = os.path.join(dirpath, filename)
    if filename.endswith('.wav'):
        sound: AudioSegment = AudioSegment.from_file(filepath)
    else:
        try:
            sound: AudioSegment = AudioSegment.from_raw(filepath,
                                                        sample_width=2,
                                                        frame_rate=44100,
                                                        channels=1)
        except Exception as err:  # pylint: disable=broad-except
            print('Retrying conversion: {}'.format(err))
            try:
                sound: AudioSegment = AudioSegment.from_raw(filepath,
                                                            sample_width=2,
                                                            frame_rate=48000,
                                                            channels=1)
            except Exception as err:  # pylint: disable=broad-except
                print('Skipping file {}, got error: {}'.format(filepath, err))
                return
    try:
        sound = sound.set_frame_rate(16000)
    except Exception as err:  # pylint: disable=broad-except
        print('Skipping {}'.format(err))
        return

    n_splits = max(1, math.ceil(sound.duration_seconds / max_duration_seconds))
    chunk_duration_ms = math.ceil(len(sound) / n_splits)
    chunks = []

    for i in range(n_splits):
        end_ms = min((i + 1) * chunk_duration_ms, len(sound))
        chunk = sound[(i * chunk_duration_ms):end_ms]
        chunks.append(chunk)

    for i, chunk in enumerate(chunks):
        dst_path = os.path.join(dst_dirpath, str(i) + '_' + filename)
        if dst_path.endswith('.raw'):
            dst_path = dst_path[:-4] + '.wav'

        if os.path.exists(dst_path):
            print('Audio already exists: {}'.format(dst_path))
            return

        if normalize:
            chunk = chunk.normalize()
            if chunk.dBFS < -30.0:
                chunk = chunk.compress_dynamic_range().normalize()
            if chunk.dBFS < -30.0:
                chunk = chunk.compress_dynamic_range().normalize()
        if trim_silence:
            chunk = trim_silence_audio(chunk)

        if chunk.duration_seconds < min_duration_seconds:
            return
        chunk.export(dst_path, format='wav')

Review comment (on convert): Please check how this is covered by, or could be merged into, the current audio.py.
Review comment (on the 44.1 kHz raw conversion): Please take …
Review comment (on the retry): Why?
Review comment (on set_frame_rate(16000)): Please make this command-line configurable.
Review comment (on the chunk splitting): Great idea to split noise into chunks to limit wasted overlap during augmentation!
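To make the chunking arithmetic in convert() concrete (hypothetical numbers): a 75-second noise file with --max_sec 30 is cut into three chunks of roughly 25 seconds each instead of two 30-second chunks plus a 15-second leftover.

    import math

    duration_seconds = 75.0      # hypothetical noise file length
    max_duration_seconds = 30.0  # --max_sec

    n_splits = max(1, math.ceil(duration_seconds / max_duration_seconds))  # ceil(2.5) -> 3
    chunk_duration_ms = math.ceil(duration_seconds * 1000 / n_splits)      # 25000 ms per chunk
    print(n_splits, chunk_duration_ms)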
def get_noise_duration(dst_dir):
    duration = 0.0
    file_num = 0
    for dirpath, _, filenames in os.walk(dst_dir):
        for f in filenames:
            if not f.endswith('.wav'):
                continue
            duration += get_duration(filename=os.path.join(dirpath, f))
            file_num += 1
    return duration, file_num
def main(src_dir,
         dst_dir,
         min_duration_seconds,
         max_duration_seconds,
         normalize=True,
         trim_silence=True):
    assert os.path.exists(src_dir)
    if not os.path.exists(dst_dir):
        os.makedirs(dst_dir, exist_ok=False)
    src_dir = os.path.abspath(src_dir)
    dst_dir = os.path.abspath(dst_dir)

    for dirpath, _, filenames in os.walk(src_dir):
        dirpath = os.path.abspath(dirpath)
        dst_dirpath = os.path.join(
            dst_dir, dirpath.replace(src_dir, '').lstrip('/'))

        print('Converting directory: {} -> {}'.format(dirpath, dst_dirpath))
        if not os.path.exists(dst_dirpath):
            os.makedirs(dst_dirpath, exist_ok=False)

        convert_func = partial(convert,
                               dst_dirpath=dst_dirpath,
                               dirpath=dirpath,
                               normalize=normalize,
                               trim_silence=trim_silence,
                               min_duration_seconds=min_duration_seconds,
                               max_duration_seconds=max_duration_seconds)

        pool = Pool(processes=None)
        for _ in tqdm.tqdm(pool.imap_unordered(convert_func, filenames), total=len(filenames)):
            pass
if __name__ == "__main__":
    PARSER = argparse.ArgumentParser(description='Optimize noise files')
    PARSER.add_argument('--from_dir', help='Convert wav from directory', type=str)
    PARSER.add_argument('--to_dir', help='save wav to directory', type=str)
    PARSER.add_argument('--min_sec', help='min duration seconds of saved file', type=float, default=1.0)
    PARSER.add_argument('--max_sec', help='max duration seconds of saved file', type=float, default=30.0)
    PARSER.add_argument('--normalize', action='store_true', help='Normalize sound range, default is true', default=True)
    PARSER.add_argument('--trim', action='store_true', help='Trim silence, default is true', default=True)
    PARAMS = PARSER.parse_args()

    main(PARAMS.from_dir, PARAMS.to_dir, PARAMS.min_sec, PARAMS.max_sec, PARAMS.normalize, PARAMS.trim)

    DURATION, FILE_NUM = get_noise_duration(PARAMS.to_dir)
    print("Your noise dataset has {} files and a duration of {}\n".format(FILE_NUM, secs_to_hours(DURATION)))

Review comment (on --to_dir): This tool should also be able to produce SDBs like our SDB tool. I'll put up a PR for changing classes …
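A usage sketch of the entry points above (directory names are hypothetical; the script has to run from the DeepSpeech repository root so that util/ is importable):

    # Split, resample, normalize, and trim everything under raw_noise/ into
    # 1-30 second 16 kHz wav chunks under prepared_noise/, then report totals.
    main('raw_noise', 'prepared_noise',
         min_duration_seconds=1.0,
         max_duration_seconds=30.0,
         normalize=True,
         trim_silence=True)

    duration, file_num = get_noise_duration('prepared_noise')
    print('{} files, total duration {}'.format(file_num, secs_to_hours(duration)))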
@@ -41,7 +41,7 @@ def sparse_tuple_to_texts(sp_tuple, alphabet):
     return [alphabet.decode(res) for res in results]
 
 
-def evaluate(test_csvs, create_model, try_loading):
+def evaluate(test_csvs, create_model, try_loading, noise_dirs_or_files=None):

Review comment: …
Reply: modified

     if FLAGS.lm_binary_path:
         scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
                         FLAGS.lm_binary_path, FLAGS.lm_trie_path,

@@ -50,13 +50,13 @@ def evaluate(test_csvs, create_model, try_loading):
         scorer = None
 
     test_csvs = FLAGS.test_files.split(',')
-    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False) for csv in test_csvs]
+    test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False, noise_dirs_or_files=noise_dirs_or_files) for csv in test_csvs]
     iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
                                                  tfv1.data.get_output_shapes(test_sets[0]),
                                                  output_classes=tfv1.data.get_output_classes(test_sets[0]))
     test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]
 
-    batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
+    batch_wav_filename, (batch_x, batch_x_len), batch_y, _ = iterator.get_next()
 
     # One rate per layer
     no_dropout = [None] * 6
Review comment: I like the review_audio idea.
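The dataset-side mixing itself is not visible in this diff (it lives in the feeding/augmentation code), but as a rough, generic illustration of what online noise mixing usually amounts to (not necessarily this PR's exact implementation): a noise clip is tiled or cropped to the speech length, scaled to a randomly drawn signal-to-noise ratio, and added to the speech waveform on the fly.

    import numpy as np

    def mix_noise(speech, noise, snr_db):
        """Mix noise into speech at the given SNR; both are float32 waveforms in [-1, 1]."""
        # Tile or crop the noise so it covers the whole speech clip.
        if len(noise) < len(speech):
            noise = np.tile(noise, int(np.ceil(len(speech) / len(noise))))
        noise = noise[:len(speech)]

        speech_power = np.mean(speech ** 2) + 1e-10
        noise_power = np.mean(noise ** 2) + 1e-10
        # Choose a scale so that 10 * log10(speech_power / (scale**2 * noise_power)) == snr_db.
        scale = np.sqrt(speech_power / (noise_power * 10.0 ** (snr_db / 10.0)))
        return np.clip(speech + scale * noise, -1.0, 1.0)

    # Example with a random SNR between 5 and 20 dB.
    rng = np.random.default_rng(0)
    speech = rng.uniform(-0.5, 0.5, 16000).astype(np.float32)  # stands in for 1 s of speech
    noise = rng.uniform(-0.5, 0.5, 8000).astype(np.float32)    # stands in for 0.5 s of noise
    mixed = mix_noise(speech, noise, snr_db=rng.uniform(5, 20))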