From ea6f98ab4af88f655171a1d429cd4589f0c76c0e Mon Sep 17 00:00:00 2001 From: cbrom_a Date: Tue, 16 Oct 2018 15:09:01 +0300 Subject: [PATCH] Refer to issue #77 changes to _pad, _pad_2d when **is_mulaw_quantize** is True - padding value is changed to mulaw_quantize(0, quantize_channels) --- train.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 11a6e94..cc2fb68 100644 --- a/train.py +++ b/train.py @@ -90,9 +90,9 @@ def _pad(seq, max_len, constant_values=0): mode='constant', constant_values=constant_values) -def _pad_2d(x, max_len, b_pad=0): +def _pad_2d(x, max_len, b_pad=0, constant_values=0): x = np.pad(x, [(b_pad, max_len - len(x) - b_pad), (0, 0)], - mode="constant", constant_values=0) + mode="constant", constant_values=constant_values) return x @@ -417,9 +417,10 @@ def collate_fn(batch): # (B, T, C) # pad for time-axis if is_mulaw_quantize(hparams.input_type): + padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels) x_batch = np.array([_pad_2d(np_utils.to_categorical( x[0], num_classes=hparams.quantize_channels), - max_input_len) for x in batch], dtype=np.float32) + max_input_len, padding_value) for x in batch], dtype=np.float32) else: x_batch = np.array([_pad_2d(x[0].reshape(-1, 1), max_input_len) for x in batch], dtype=np.float32) @@ -427,7 +428,8 @@ def collate_fn(batch): # (B, T) if is_mulaw_quantize(hparams.input_type): - y_batch = np.array([_pad(x[0], max_input_len) for x in batch], dtype=np.int) + padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels) + y_batch = np.array([_pad(x[0], max_input_len, constant_values=padding_value) for x in batch], dtype=np.int) else: y_batch = np.array([_pad(x[0], max_input_len) for x in batch], dtype=np.float32) assert len(y_batch.shape) == 2