diff --git a/README.md b/README.md index e1b2ead..1da4145 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,36 @@ model.fit(x, y) * See the Jupyter notebook at the [example folder](https://github.com/keunwoochoi/kapre/tree/master/examples) +## Tflite compatibility + +The `STFT` layer is not tflite compatible (due to `tf.signal.stft`). To create a tflite +compatible model, first train using the normal `kapre` layers, then create a new +model replacing `STFT` and `Magnitude` with `STFTTflite`, `MagnitudeTflite`. +Tflite compatible layers are restricted to a batch size of 1, which prevents use +of them during training. + +```python +# assumes you have run the one-shot example above. +from kapre import STFTTflite, MagnitudeTflite +model_tflite = Sequential() + +model_tflite.add(STFTTflite(n_fft=2048, win_length=2018, hop_length=1024, + window_name=None, pad_end=False, + input_data_format='channels_last', output_data_format='channels_last', + input_shape=input_shape)) +model_tflite.add(MagnitudeTflite()) +model_tflite.add(MagnitudeToDecibel()) +model_tflite.add(Conv2D(32, (3, 3), strides=(2, 2))) +model_tflite.add(BatchNormalization()) +model_tflite.add(ReLU()) +model_tflite.add(GlobalAveragePooling2D()) +model_tflite.add(Dense(10)) +model_tflite.add(Softmax()) + +# load the trained weights into the tflite compatible model. +model_tflite.set_weights(model.get_weights()) +``` + # Citation Please cite this paper if you use Kapre for your work. diff --git a/docs/release_note.rst b/docs/release_note.rst index a86cd94..aeaf210 100644 --- a/docs/release_note.rst +++ b/docs/release_note.rst @@ -1,6 +1,10 @@ Release Note ^^^^^^^^^^^^ +* 13 Nov 2021 + - 0.3.6 + - bugfix/pad end tflite #131 + * 18 March 2021 - 0.3.5 - Add `kapre.time_frequency_tflite` which uses tflite for a faster CPU inference. 
diff --git a/kapre/__init__.py b/kapre/__init__.py index 14e71c9..a734822 100644 --- a/kapre/__init__.py +++ b/kapre/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.3.5' +__version__ = '0.3.6' VERSION = __version__ from . import composed diff --git a/kapre/tflite_compatible_stft.py b/kapre/tflite_compatible_stft.py index 7ab1000..ada57ef 100644 --- a/kapre/tflite_compatible_stft.py +++ b/kapre/tflite_compatible_stft.py @@ -175,9 +175,10 @@ def stft_tflite(signal, frame_length, frame_step, fft_length, window_fn, pad_end signal = tf.cast(signal, tf.float32) if pad_end: # the number of whole frames + # (NOTE: kenders2000), padding is pre-calculated and thus fixed in graph length_samples = signal.shape[-1] - num_steps_round_up = tf.math.ceil(length_samples / frame_step) - pad_amount = int((num_steps_round_up * frame_step) - length_samples) + num_steps_round_up = int(np.ceil(length_samples / frame_step)) + pad_amount = (num_steps_round_up * frame_step + frame_length - frame_step) - length_samples signal = tf.pad(signal, tf.constant([[0, 0], [0, 0], [0, pad_amount]])) # Make the window be shape (1, frame_length) instead of just frame_length # in an effort to help the tflite broadcast logic. diff --git a/kapre/time_frequency_tflite.py b/kapre/time_frequency_tflite.py index 0011796..8000497 100644 --- a/kapre/time_frequency_tflite.py +++ b/kapre/time_frequency_tflite.py @@ -31,7 +31,13 @@ class STFTTflite(STFT): Ues `stft_tflite` from tflite_compatible_stft.py, this contains a tflite compatible stft (using a rdft), and `fixed_frame()` to window the audio. Tflite does not cope with comple types so real and imaginary parts are stored in extra dim. - Ouput shape is now: (batch, channel, time, re/im) or (batch, time, channel, re/im) + Ouput shape is now: (batch, channel, time, re/im) or (batch, time, channel, re/im). + `MagnitudeTflite`, and `PhaseTflite` are versions of the `Magnitude` and `Phase` + layers that account for this extra dimensionality. 
Currently this layer is + restricted to a batch size of one, for training use the `STFT` layer, and + once complete transfer the weights to a new model, replacing the `STFT` layer + with the `STFTTflite` layer and `Magnitude` and `Phase` layers with + `MagnitudeTflite` and `PhaseTflite` layers. Additionally, it reshapes the output to be a proper 2D batch. diff --git a/setup.py b/setup.py index c2acad3..f945cb9 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='kapre', - version='0.3.5', + version='0.3.6', description='Kapre: Keras Audio Preprocessors. Tensorflow.Keras layers for audio pre-processing in deep learning', author='Keunwoo Choi', url='http://github.com/keunwoochoi/kapre/', diff --git a/tests/test_time_frequency.py b/tests/test_time_frequency.py index 7b1d6a4..af4bf48 100644 --- a/tests/test_time_frequency.py +++ b/tests/test_time_frequency.py @@ -269,8 +269,9 @@ def _get_melgram_model(return_decibel, amin, dynamic_range, input_shape=None): @pytest.mark.parametrize('data_format', ['default', 'channels_first', 'channels_last']) @pytest.mark.parametrize('batch_size', [1, 2]) @pytest.mark.parametrize('win_length', [1000, 512]) +@pytest.mark.parametrize('pad_end', [False, True]) def test_spectrogram_tflite_correctness( - n_fft, hop_length, n_ch, data_format, batch_size, win_length + n_fft, hop_length, n_ch, data_format, batch_size, win_length, pad_end ): def _get_stft_model(following_layer=None, tflite_compatible=False): # compute with kapre @@ -282,7 +283,7 @@ def _get_stft_model(following_layer=None, tflite_compatible=False): win_length=win_length, hop_length=hop_length, window_name=None, - pad_end=False, + pad_end=pad_end, input_data_format=data_format, output_data_format=data_format, input_shape=input_shape, @@ -296,7 +297,7 @@ def _get_stft_model(following_layer=None, tflite_compatible=False): win_length=win_length, hop_length=hop_length, window_name=None, - pad_end=False, + pad_end=pad_end, input_data_format=data_format, 
output_data_format=data_format, input_shape=input_shape,