Online mixing of noise audio data in the training step #2622

Open

wants to merge 32 commits into base: master

Changes from 1 commit

Commits (32)
681f470
Remove comments check from alphabet
carlfm01 Jun 5, 2019
421243d
Remove sort from feeding
carlfm01 Jun 5, 2019
d08efad
Remove sort from evaluate tools
carlfm01 Jun 5, 2019
b0a14b5
Merge pull request #1 from carlfm01/master
carlfm01 Jun 29, 2019
ba1a587
Remove TF dependency
carlfm01 Jun 29, 2019
aebd08d
[ADD] mix noise audio
mychiux413 Dec 30, 2019
d255c3f
[FIX] add missing file decoded_augmentation.py
mychiux413 Dec 30, 2019
ec25136
mix noise works, but performance is bad
mychiux413 Dec 31, 2019
484134e
[MOD] use tf.Dataset to cache noise audio
mychiux413 Dec 31, 2019
4f24f08
rename decoded -> audio
mychiux413 Dec 31, 2019
1f57ece
[FIX] don't create tf.Dataset in other tf.Dataset's pipeline
mychiux413 Jan 2, 2020
66cc7c4
limit audio signal between +-1.0
mychiux413 Jan 13, 2020
b7eb0f4
[FIX] switch shuffle/map for memory cost, replace cache with prefetch…
mychiux413 Feb 11, 2020
ccae7cc
[MOD] limit the buffer size of .shuffle() to protect memory usage
mychiux413 Feb 17, 2020
8cc95f9
[ADD] bin/normalize_noise_audio.py
mychiux413 Feb 19, 2020
9e2648a
[MOD] mix noise into complete audio
mychiux413 Feb 21, 2020
2269514
[ADD] dev/test dataset can also mix noise [MOD] use SNR to balance no…
mychiux413 Mar 6, 2020
0b8147c
[ADD] use dbfs and SNR to determine the balance of audio/noise, add o…
mychiux413 Mar 16, 2020
42bc45b
[FIX] audiofile_to_features & samples_to_mfccs return 3 values now, a…
mychiux413 Mar 19, 2020
289722d
Fix issues.
Mar 29, 2020
9334e79
Save invalid files.
Mar 29, 2020
25736e0
Merge remote-tracking branch 'noiseaug/more-augment-options' into noi…
Mar 29, 2020
40b431b
Fix merging errors.
Mar 29, 2020
f7d1279
[FIX] replace tqdm with prograssbar [ADD] separate speech/noise mixin…
mychiux413 Mar 31, 2020
7792226
Merge branch 'no-sort' into more-augment-options
carlfm01 Apr 2, 2020
c4c3ced
Merge #f7d1279.
Apr 12, 2020
c151b1d
Merge branch 'master' into noisetest
Apr 17, 2020
c089b7f
Fix merge not detecting moved scripts.
Apr 17, 2020
491a4b0
Undo personal changes.
Apr 17, 2020
735cbbb
Merge branch 'master' of https://github.com/mozilla/DeepSpeech into n…
Apr 23, 2020
2fa91e8
To recover the incorrect merge
mychiux413 May 12, 2020
6b820bb
Merge pull request #1 from DanBmh/noiseaugmaster
mychiux413 May 14, 2020
[ADD] use dbfs and SNR to determine the balance of audio/noise, add option to dump audio into tensorboard [FIX] correct gain db formula
mychiux413 committed Mar 16, 2020
commit 0b8147ce8c4a1906de80f1db793b8aa63dc15045
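
The "correct gain db formula" part of this commit can be sanity-checked outside TensorFlow: dbfs is computed here as 10·log10(mean power), which equals the usual 20·log10(rms), and a dB gain is converted to an amplitude factor with 10**(db/20) (10**(db/10) would be a power ratio). A minimal NumPy sketch, not part of the PR, using a hypothetical test signal:

import numpy as np

# hypothetical 1-second 440 Hz tone at 16 kHz, peak amplitude 0.25
x = 0.25 * np.sin(2.0 * np.pi * 440.0 * np.arange(16000) / 16000.0)

dbfs_from_power = 10.0 * np.log10(np.mean(x ** 2))          # what the new audio_to_dbfs computes per chunk
dbfs_from_rms = 20.0 * np.log10(np.sqrt(np.mean(x ** 2)))   # equivalent rms form
assert np.isclose(dbfs_from_power, dbfs_from_rms)

gain_db = -6.0
amplitude_ratio = 10.0 ** (gain_db / 20.0)                  # amplitude gain, matching the /20.0 exponent in this commit
y = x * amplitude_ratio
assert np.isclose(10.0 * np.log10(np.mean(y ** 2)), dbfs_from_power + gain_db)
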
33 changes: 24 additions & 9 deletions DeepSpeech.py
@@ -218,7 +218,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
the decoded result and the batch's original Y.
'''
# Obtain the next batch of data
batch_filenames, (batch_x, batch_seq_len), batch_y = iterator.get_next()
batch_filenames, (batch_x, batch_seq_len), batch_y, review_audio = iterator.get_next()
Contributor:

I like the review_audio idea.


if FLAGS.use_cudnn_rnn:
rnn_impl = rnn_impl_cudnn_rnn
@@ -238,7 +238,7 @@ def calculate_mean_edit_distance_and_loss(iterator, dropout, reuse):
avg_loss = tf.reduce_mean(input_tensor=total_loss)

# Finally we return the average loss
return avg_loss, non_finite_files
return avg_loss, non_finite_files, review_audio


# Adam Optimization
@@ -299,7 +299,7 @@ def get_tower_results(iterator, optimizer, dropout_rates):
with tf.name_scope('tower_%d' % i):
# Calculate the avg_loss and mean_edit_distance and retrieve the decoded
# batch along with the original batch's labels (Y) of this tower
avg_loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)
avg_loss, non_finite_files, review_audio = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=i > 0)

# Allow for variables to be re-used by the next tower
tfv1.get_variable_scope().reuse_variables()
@@ -316,6 +316,8 @@ def get_tower_results(iterator, optimizer, dropout_rates):
tower_non_finite_files.append(non_finite_files)

avg_loss_across_towers = tf.reduce_mean(input_tensor=tower_avg_losses, axis=0)
if FLAGS.augmentation_review_audio_steps:
tfv1.summary.audio(name='step_audio', tensor=review_audio, sample_rate=16000, collections=['step_audio_summaries'])
Contributor:

Please change sample_rate=16000 to sample_rate=FLAGS.audio_sample_rate.

Contributor Author:

fixed
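
For reference, the call after applying the reviewer's request presumably reads as follows (a sketch of the fixed line, not the committed diff):

tfv1.summary.audio(name='step_audio',
                   tensor=review_audio,
                   sample_rate=FLAGS.audio_sample_rate,  # instead of the hard-coded 16000
                   collections=['step_audio_summaries'])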

tfv1.summary.scalar(name='step_loss', tensor=avg_loss_across_towers, collections=['step_summaries'])

all_non_finite_files = tf.concat(tower_non_finite_files, axis=0)
@@ -437,7 +439,7 @@ def train():
enable_cache=FLAGS.feature_cache and do_cache_dataset,
cache_path=FLAGS.feature_cache,
train_phase=True,
noise_dirs=FLAGS.audio_aug_mix_noise_walk_train_dirs)
noise_dirs_or_files=FLAGS.audio_aug_mix_noise_train_dirs_or_files)

iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set),
tfv1.data.get_output_shapes(train_set),
@@ -448,7 +450,7 @@ def train():

if FLAGS.dev_files:
dev_csvs = FLAGS.dev_files.split(',')
dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False, noise_dirs=FLAGS.audio_aug_mix_noise_walk_dev_dirs) for csv in dev_csvs]
dev_sets = [create_dataset([csv], batch_size=FLAGS.dev_batch_size, train_phase=False, noise_dirs_or_files=FLAGS.audio_aug_mix_noise_dev_dirs_or_files) for csv in dev_csvs]
dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets]

# Dropout
@@ -484,6 +486,7 @@ def train():
apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step)

# Summaries
step_audio_summaries_op = tfv1.summary.merge_all('step_audio_summaries')
step_summaries_op = tfv1.summary.merge_all('step_summaries')
step_summary_writers = {
'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120),
@@ -594,11 +597,20 @@ def __call__(self, progress, data, **kwargs):
session.run(init_op)

# Batch loop

i_audio_steps = 0
Contributor:

Better: audio_summary_steps

Contributor Author:

modified

while True:
try:
_, current_step, batch_loss, problem_files, step_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)
step_audio_summary = None
if i_audio_steps < FLAGS.augmentation_review_audio_steps and epoch == 0:
_, current_step, batch_loss, problem_files, step_summary, step_audio_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op, step_audio_summaries_op],
feed_dict=feed_dict)
i_audio_steps += 1
else:
_, current_step, batch_loss, problem_files, step_summary = \
session.run([train_op, global_step, loss, non_finite_files, step_summaries_op],
feed_dict=feed_dict)
except tf.errors.OutOfRangeError:
break

@@ -612,6 +624,9 @@ def __call__(self, progress, data, **kwargs):

pbar.update(step_count)

if step_audio_summary is not None:
step_summary_writer.add_summary(step_audio_summary, current_step)

step_summary_writer.add_summary(step_summary, current_step)

if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs:
@@ -674,7 +689,7 @@ def __call__(self, progress, data, **kwargs):


def test():
samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading, noise_dirs=FLAGS.audio_aug_mix_noise_walk_test_dirs)
samples = evaluate(FLAGS.test_files.split(','), create_model, try_loading, noise_dirs_or_files=FLAGS.audio_aug_mix_noise_test_dirs_or_files)
if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
json.dump(samples, open(FLAGS.test_output_file, 'w'), default=float)
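
The new FLAGS used above (audio_aug_mix_noise_{train,dev,test}_dirs_or_files and augmentation_review_audio_steps) are referenced throughout this diff but defined elsewhere. A hypothetical sketch of how they could be declared with absl flags; the flag names come from this diff, while the defaults and help strings are illustrative only, not the PR's actual util/flags.py entries:

import absl.flags as flags

flags.DEFINE_string('audio_aug_mix_noise_train_dirs_or_files', '',
                    'comma-separated directories and/or CSV files providing noise wavs for training')
flags.DEFINE_string('audio_aug_mix_noise_dev_dirs_or_files', '',
                    'noise sources mixed into the dev set')
flags.DEFINE_string('audio_aug_mix_noise_test_dirs_or_files', '',
                    'noise sources mixed into the test set')
flags.DEFINE_integer('augmentation_review_audio_steps', 0,
                     'number of initial training steps whose augmented audio is written to TensorBoard')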
6 changes: 3 additions & 3 deletions evaluate.py
@@ -41,7 +41,7 @@ def sparse_tuple_to_texts(sp_tuple, alphabet):
return [alphabet.decode(res) for res in results]


def evaluate(test_csvs, create_model, try_loading, noise_dirs=None):
def evaluate(test_csvs, create_model, try_loading, noise_dirs_or_files=None):
Contributor:

noise_sources

Contributor Author:

modified

if FLAGS.lm_binary_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
@@ -50,13 +50,13 @@ def evaluate(test_csvs, create_model, try_loading, noise_dirs=None):
scorer = None

test_csvs = FLAGS.test_files.split(',')
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False, noise_dirs=noise_dirs) for csv in test_csvs]
test_sets = [create_dataset([csv], batch_size=FLAGS.test_batch_size, train_phase=False, noise_dirs_or_files=noise_dirs_or_files) for csv in test_csvs]
iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(test_sets[0]),
tfv1.data.get_output_shapes(test_sets[0]),
output_classes=tfv1.data.get_output_classes(test_sets[0]))
test_init_ops = [iterator.make_initializer(test_set) for test_set in test_sets]

batch_wav_filename, (batch_x, batch_x_len), batch_y = iterator.get_next()
batch_wav_filename, (batch_x, batch_x_len), batch_y, _ = iterator.get_next()

# One rate per layer
no_dropout = [None] * 6
220 changes: 169 additions & 51 deletions util/audio_augmentation.py
@@ -6,89 +6,207 @@
from tensorflow.python.ops import gen_audio_ops as contrib_audio
import os
from util.logging import log_info
from util.config import Config

DBFS_COEF = 20.0 / np.log(10.0)

DBFS_COEF = 10.0 / np.log(10.0)

def get_dbfs(wav_filename):
def filename_to_audio(wav_filename):
r"""Decode `wab_filename` and return the audio

Args:
wav_filename: A str, the path of wav file

Returns:
A 2-D Tensor with shape [`time-steps`, 1].
"""
samples = tf.io.read_file(wav_filename)
decoded = contrib_audio.decode_wav(samples, desired_channels=1)
rms = tf.sqrt(tf.reduce_mean(tf.square(decoded.audio)))
dbfs = DBFS_COEF * tf.math.log(rms)
return dbfs
return decoded.audio

def audio_to_dbfs(audio, sample_rate=16000, chunk_ms=100, reduce_funcs=tf.reduce_mean):
r"""Separately measure the chunks dbfs of `audio`, then return the statistics values through `reduce_funcs

Args:
audio: A 2-D Tensor with shape [`time-steps`, 1].
sample_rate: An integer, specifying the audio sample rate used to determine the chunk size for dbfs measurement.
chunk_ms: An integer in milliseconds, specifying the size of each chunk for separate dbfs measurement, default is `100ms`
reduce_funcs: A function or a list of functions, specifying the statistic(s) applied to the chunks, default is tf.reduce_mean

Returns:
A float or a list of floats, depending on whether reduce_funcs is a function or a list of functions
"""
assert chunk_ms % 10 == 0, 'chunk_ms must be a multiple of 10'

audio_len = tf.shape(audio)[0]
chunk_len = tf.math.floordiv(sample_rate, tf.math.floordiv(1000, chunk_ms)) # default: 1600
n_chunks = tf.math.floordiv(audio_len, chunk_len)
trim_audio_len = tf.multiply(n_chunks, chunk_len)
audio = audio[:trim_audio_len]
splits = tf.reshape(audio, shape=[n_chunks, -1])

squares = tf.square(splits)
means = tf.reduce_mean(squares, axis=1)

# the statistics functions must execute before tf.log(), or the gain db would be wrong
if not isinstance(reduce_funcs, list):
reduces = reduce_funcs(means)
return DBFS_COEF * tf.math.log(reduces + 1e-8)

def create_noise_iterator(noise_dirs):
"""noise_dirs: `str` or `list`"""
if isinstance(noise_dirs, str):
noise_dirs = noise_dirs.split(',')
reduces = [reduce_func(means) for reduce_func in reduce_funcs]
return [DBFS_COEF * tf.math.log(reduce + 1e-8) for reduce in reduces]

noise_filenames = tf.convert_to_tensor(
list(collect_noise_filenames(noise_dirs)),
dtype=tf.string)
log_info("Collect {} noise files for mixing audio".format(
noise_filenames.shape[0]))

def extract_dbfs(wav_filename):
return wav_filename, get_dbfs(wav_filename)
def create_noise_iterator(noise_dirs_or_files, read_csvs_func):
Contributor:

I think this has to be refactored to use functionality from sample_collections.

r"""Create an iterator to yield audio

Args:
noise_dirs_or_files: A list/tuple of str, the collection source of wav filenames.
read_csvs_func: A function; pass the `read_csvs()` function from `util/feeding.py` (passed in to avoid a circular import).

Returns:
A one-shot iterator of audio, each element a 2-D Tensor with shape [`time-steps`, 1]; use `<iter>.get_next()` to get the Tensor.
"""
if isinstance(noise_dirs_or_files, str):
noise_dirs_or_files = noise_dirs_or_files.split(',')

noise_filenames = tf.convert_to_tensor(list(collect_noise_filenames(noise_dirs_or_files, read_csvs_func)), dtype=tf.string)
log_info("Collect {} noise files for mixing audio".format(noise_filenames.shape[0]))

noise_dataset = (tf.data.Dataset.from_tensor_slices(noise_filenames)
.map(extract_dbfs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.cache()
.shuffle(min(noise_filenames.shape[0], 102400))
.map(noise_file_to_audio, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.map(filename_to_audio, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.prefetch(tfv1.data.experimental.AUTOTUNE)
.repeat())
noise_iterator = tfv1.data.make_one_shot_iterator(noise_dataset)
return noise_iterator


def collect_noise_filenames(walk_dirs):
assert isinstance(walk_dirs, list)
def collect_noise_filenames(dirs_or_files, read_csvs_func):
Contributor:

Should be moved to bin/normalize_noise_audio.py. This tool should then be responsible for creating a regular sample collection (.csv or .sdb).
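
A minimal sketch of the direction suggested here, assuming a helper in bin/normalize_noise_audio.py that walks the noise directories once and writes a plain .csv with a wav_filename column (which collect_noise_filenames can already consume through read_csvs_func). The function below is illustrative, not the PR's code:

import csv
import os

def write_noise_csv(noise_dirs, out_csv):
    # Illustrative helper: index all .wav files under noise_dirs into a sample collection CSV.
    with open(out_csv, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(['wav_filename'])
        for noise_dir in noise_dirs:
            for dirpath, _, filenames in os.walk(noise_dir):
                for filename in filenames:
                    if filename.endswith('.wav'):
                        writer.writerow([os.path.join(dirpath, filename)])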

r"""Collect wav filenames from directories or csv files

for d in walk_dirs:
for dirpath, _, filenames in os.walk(d):
for filename in filenames:
if filename.endswith('.wav'):
yield os.path.join(dirpath, filename)
Args:
dirs_or_files: A list/tuple of str, the collection source of wav filenames.
read_csvs_func: A function; pass the `read_csvs()` function from `util/feeding.py` (passed in to avoid a circular import).

Returns:
An iterator of str, yielding every filename ending with `.wav`, or listed under the `wav_filename` column of a DataFrame
"""

def noise_file_to_audio(noise_file, noise_dbfs):
samples = tf.io.read_file(noise_file)
decoded = contrib_audio.decode_wav(samples, desired_channels=1)
return decoded.audio, noise_dbfs
assert isinstance(dirs_or_files, (list, tuple))

for dir_or_file in dirs_or_files:
assert os.path.exists(dir_or_file)
if os.path.isdir(dir_or_file):
for dirpath, _, filenames in os.walk(dir_or_file):
for filename in filenames:
if filename.endswith('.wav'):
yield os.path.join(dirpath, filename)
elif os.path.isfile(dir_or_file):
df = read_csvs_func([dir_or_file])
for filename in df['wav_filename']:
yield filename


def augment_noise(audio,
audio_dbfs,
noise,
noise_dbfs,
max_audio_gain_db=5,
min_audio_gain_db=-10,
max_snr_db=30,
min_snr_db=5):
decoded_audio_len = tf.shape(audio)[0]
decoded_noise_len = tf.shape(noise)[0]
min_audio_dbfs=0.0,
max_audio_dbfs=-35.0,
min_snr_db=3.0,
max_snr_db=30.0,
limit_audio_peak_dbfs=7.0,
limit_noise_peak_dbfs=3.0,
sample_rate=16000):
r"""Mix audio Tensor with noise Tensor

If the noise is shorter than the audio, the noise is automatically tiled until it exceeds the audio length.
The process then randomly chooses a segment of the noise that completely covers the audio,
i.e. the chosen noise segment and the audio have equal shape.

multiply = tf.math.floordiv(decoded_audio_len, decoded_noise_len) + 1
noise_audio_tile = tf.tile(noise, [multiply, 1])
Args:
audio: A 2-D Tensor with shape [`time-steps`, 1].
noise: A 2-D Tensor with shape [`time-steps`, 1].
min_audio_dbfs: A float in dbfs, specifying the `minimum` target volume of the audio when gaining it.
max_audio_dbfs: A float in dbfs, specifying the `maximum` target volume of the audio when gaining it.
min_snr_db: A float in db, specifying the minimum signal-to-noise ratio when gaining the audio and the noise.
max_snr_db: A float in db, specifying the maximum signal-to-noise ratio when gaining the audio and the noise.
limit_audio_peak_dbfs: A float, the maximum allowed per-chunk audio dbfs; the audio will not be gained above this value.
limit_noise_peak_dbfs: A float, the maximum allowed per-chunk noise dbfs; the noise will not be gained above this value.
sample_rate: An integer, specifying the audio sample rate used to determine the chunk size for dbfs measurement.

# Now, decoded_noise_len must > decoded_audio_len
decoded_noise_len = tf.shape(noise_audio_tile)[0]
Returns:
A 2-D Tensor with shape [`time-steps`, 1]. Has the same type and shape as `audio`.
"""

mix_decoded_start_point = tfv1.random_uniform([], minval=0, maxval=decoded_noise_len-decoded_audio_len, dtype=tf.int32)
mix_decoded_end_point = mix_decoded_start_point + decoded_audio_len
extract_noise_decoded = noise_audio_tile[mix_decoded_start_point:mix_decoded_end_point, :]
audio_len = tf.shape(audio)[0]
noise_len = tf.shape(noise)[0]

audio_gain_db = tfv1.random_uniform([], minval=min_audio_gain_db, maxval=max_audio_gain_db)
target_audio_dbfs = audio_dbfs + audio_gain_db
audio_gain_ratio = tf.math.pow(10.0, audio_gain_db / 10)
audio_mean_dbfs, audio_max_dbfs = audio_to_dbfs(audio, sample_rate, reduce_funcs=[tf.reduce_mean, tf.reduce_max])

multiply = tf.math.floordiv(audio_len, noise_len) + 1
noise_tile = tf.tile(noise, [multiply, 1])


# Now, noise_len must > audio_len
noise_tile_len = tf.shape(noise_tile)[0]

mix_decoded_start_point = tfv1.random_uniform([], minval=0, maxval=noise_tile_len-audio_len, dtype=tf.int32)
mix_decoded_end_point = mix_decoded_start_point + audio_len
extract_noise = noise_tile[mix_decoded_start_point:mix_decoded_end_point, :]

extract_noise_mean_dbfs, extract_noise_max_dbfs = audio_to_dbfs(extract_noise, sample_rate, reduce_funcs=[tf.reduce_mean, tf.reduce_max])

target_audio_dbfs = tfv1.random_uniform([], minval=min_audio_dbfs, maxval=max_audio_dbfs)

audio_gain_db = target_audio_dbfs - audio_mean_dbfs

# limit audio peak
audio_gain_db = tf.minimum(limit_audio_peak_dbfs - audio_max_dbfs, audio_gain_db)
target_audio_dbfs = audio_mean_dbfs + audio_gain_db

audio_gain_ratio = tf.math.pow(10.0, audio_gain_db / 20.0)

# target_snr_db := target_audio_dbfs - target_noise_dbfs
target_snr_db = tfv1.random_uniform([], minval=min_snr_db, maxval=max_snr_db)

target_noise_dbfs = target_audio_dbfs - target_snr_db
noise_gain_db = target_noise_dbfs - noise_dbfs
noise_gain_ratio = tf.math.pow(10.0, noise_gain_db / 10)
mixed_audio = tf.multiply(audio, audio_gain_ratio) + tf.multiply(extract_noise_decoded, noise_gain_ratio)
noise_gain_db = target_noise_dbfs - extract_noise_mean_dbfs

# limit noise peak
noise_gain_db = tf.minimum(limit_noise_peak_dbfs - extract_noise_max_dbfs, noise_gain_db)
noise_gain_ratio = tf.math.pow(10.0, noise_gain_db / 20.0)

mixed_audio = tf.multiply(audio, audio_gain_ratio) + tf.multiply(extract_noise, noise_gain_ratio)

mixed_audio = tf.maximum(tf.minimum(mixed_audio, 1.0), -1.0)

return mixed_audio

def gla(spectrogram):
r"""Use Griffin-Lim algorithm to reconstruct audio and fix iteration=10 to not waste too much performance in prefetch

Args:
spectrogram: A 3-D Tensor with shape [1, `time-steps`, `features`].
Returns:
A 2-D Tensor with shape [`time-steps`, 1], which is a reconstructed audio from spectrogram.
"""
frame_length = int(Config.audio_window_samples)
frame_step = int(Config.audio_step_samples)
fft_length = 512
spectrogram = tf.reshape(spectrogram, shape=[1, -1, 257])
abs_spectrogram = tf.abs(spectrogram)

def reconstruct_phases(prev_phases):
xi = tf.complex(abs_spectrogram, 0.0) * prev_phases
audio = tf.signal.inverse_stft(xi, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length)
next_xi = tf.signal.stft(audio, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length)
next_phases = tf.math.exp(tf.complex(0.0, tf.angle(next_xi)))
return next_phases

rands = tfv1.random_uniform(tf.shape(spectrogram), dtype=tf.float32)
phases = tf.math.exp(tf.complex(0.0, 2.0 * np.pi * rands))

reconstructed_phases = tf.while_loop(lambda _: True, reconstruct_phases, [phases], maximum_iterations=10)
xi = tf.complex(abs_spectrogram, 0.0) * reconstructed_phases
audio = tf.signal.inverse_stft(xi, frame_length=frame_length, frame_step=frame_step, fft_length=fft_length)
return tf.transpose(audio)
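
To show how the pieces above compose, here is a hedged sketch of wiring create_noise_iterator and augment_noise into a tf.data pipeline. It is not the PR's util/feeding.py code; the paths and constants are illustrative, and only the imported names come from this diff:

import tensorflow as tf

from util.audio_augmentation import create_noise_iterator, augment_noise
from util.feeding import read_csvs

# One shared iterator over the noise sources (directories and/or CSV files).
noise_iterator = create_noise_iterator('/data/noise_wavs,/data/noise_index.csv', read_csvs)

def mix_noise(audio):
    # audio: 2-D float Tensor [time-steps, 1], a decoded wav scaled to [-1.0, 1.0]
    noise = noise_iterator.get_next()
    return augment_noise(audio, noise, min_snr_db=3.0, max_snr_db=30.0, sample_rate=16000)

# Applied during train_phase only, before feature extraction, e.g.:
# dataset = dataset.map(mix_noise, num_parallel_calls=tf.data.experimental.AUTOTUNE)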