Skip to content

Commit

Permalink
Kapre 0.3.0 (#81)
Browse files Browse the repository at this point in the history
* draft is done. backend test is done

* more work on spectrogram test

* all works well except saving filterbank layer

* fixed filterbank issue, add more test

* update readme

* update readme

* update notebook

* fix code in readme

* add scripts

* add pip install on readme

Co-authored-by: Keunwoo Choi <keunwoochoi@KCs-qmul-mbp.local>
Co-authored-by: keunwoochoi <gnuchoi+github@gmail.com`>
  • Loading branch information
3 people authored Aug 15, 2020

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent bb0ea2c commit 8cdbb16
Showing 22 changed files with 1,149 additions and 2,425 deletions.
369 changes: 39 additions & 330 deletions README.md

Large diffs are not rendered by default.

495 changes: 0 additions & 495 deletions examples/example_codes.ipynb

This file was deleted.

304 changes: 304 additions & 0 deletions examples/how-to-use-Kapre.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# How to use Kapre - example"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2020/8/14\n",
"Tensorflow: 2.3.0\n",
"Librosa: 0.8.0\n",
"Image data format: channels_last\n",
"Kapre: 0.3.0-rc\n"
]
}
],
"source": [
"import librosa\n",
"import kapre\n",
"import tensorflow as tf\n",
"from tensorflow.keras.models import Sequential\n",
"import numpy as np\n",
"\n",
"from datetime import datetime\n",
"now = datetime.now()\n",
"\n",
"print('%s/%s/%s' % (now.year, now.month, now.day))\n",
"print('Tensorflow: {}'.format(tf.__version__))\n",
"print('Librosa: {}'.format(librosa.__version__))\n",
"print('Image data format: {}'.format(tf.keras.backend.image_data_format()))\n",
"print('Kapre: {}'.format(kapre.__version__))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Loading an mp3 file"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Audio length: 453888 samples, 10.29 seconds. \n",
"Audio sample rate: 44100 Hz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/keunwoochoi/miniconda3/envs/kapre/lib/python3.7/site-packages/librosa/core/audio.py:162: UserWarning: PySoundFile failed. Trying audioread instead.\n",
" warnings.warn(\"PySoundFile failed. Trying audioread instead.\")\n"
]
}
],
"source": [
"src, sr = librosa.load('../srcs/bensound-cute.mp3', sr=None, mono=True)\n",
"print('Audio length: %d samples, %04.2f seconds. \\n' % (len(src), len(src) / sr) +\n",
" 'Audio sample rate: %d Hz' % sr)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Trim it and make it a 2d.\n",
"\n",
"If your file is mono, librosa.load returns a 1D array. Kapre always expects 2d array, so make it 2d.\n",
"\n",
"On my computer, I use default `image_data_format == 'channels_last'`. I'll keep it in that way for the audio data, too."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The shape of an item (44100, 1)\n"
]
}
],
"source": [
"len_second = 1.0 # Let's trim it to make it quick\n",
"src = src[:int(sr*len_second)]\n",
"src = np.expand_dims(src, axis=1)\n",
"input_shape = src.shape\n",
"print('The shape of an item', input_shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Let's make it a batch of 4 items\n",
"\n",
"to make it more like a proper dataset. You should have many files indeed."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The shape of a batch: (4, 44100, 1)\n"
]
}
],
"source": [
"x = np.array([src] * 4)\n",
"print('The shape of a batch: ',x.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A Keras model\n",
"\n",
"A simple model with 10-class and single-label classification."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_5\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"stft-layer (STFT) (None, 42, 1025, 1) 0 \n",
"=================================================================\n",
"Total params: 0\n",
"Trainable params: 0\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"from kapre.time_frequency import STFT, Magnitude, MagnitudeToDecibel\n",
"\n",
"\n",
"model = Sequential()\n",
"# A STFT layer\n",
"model.add(STFT(n_fft=2048, win_length=2018, hop_length=1024,\n",
" window_fn=None, pad_end=False,\n",
" input_data_format='channels_last', output_data_format='channels_last',\n",
" input_shape=input_shape,\n",
" name='stft-layer'))\n",
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- The model has no trainable parameters because `STFT` layer uses `tf.signal.stft()` function which is just an implementation of FFT-based short-time Fourier transform.\n",
"- The output shape is `(batch, time, frequency, channels)`. \n",
" - `42` (time) is the number of STFT frames. A shorter hop length would make it (nearly) proportionally longer. If `pad_end=True`, the input audio signals become a little longer, hence the number of frames would get longer, too.\n",
" - `1025` is the number of STFT bins and decided as `n_fft // 2 + 1`. \n",
" - `1` channel: because the input signal was single-channel.\n",
"- The output of `STFT` layer is `complex` number."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's add more layers like a real model!"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU, GlobalAveragePooling2D, Dense, Softmax"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"model.add(Magnitude())\n",
"model.add(MagnitudeToDecibel())\n",
"model.add(Conv2D(32, (3, 3), strides=(2, 2)))\n",
"model.add(BatchNormalization())\n",
"model.add(ReLU())\n",
"model.add(GlobalAveragePooling2D())\n",
"model.add(Dense(10))\n",
"model.add(Softmax())\n",
"\n",
"# Compile the model\n",
"model.compile('adam', 'categorical_crossentropy') # if single-label classification\n"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_5\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"stft-layer (STFT) (None, 42, 1025, 1) 0 \n",
"_________________________________________________________________\n",
"magnitude (Magnitude) (None, 42, 1025, 1) 0 \n",
"_________________________________________________________________\n",
"magnitude_to_decibel (Magnit (None, 42, 1025, 1) 0 \n",
"_________________________________________________________________\n",
"conv2d_1 (Conv2D) (None, 20, 512, 32) 320 \n",
"_________________________________________________________________\n",
"batch_normalization (BatchNo (None, 20, 512, 32) 128 \n",
"_________________________________________________________________\n",
"re_lu (ReLU) (None, 20, 512, 32) 0 \n",
"_________________________________________________________________\n",
"global_average_pooling2d (Gl (None, 32) 0 \n",
"_________________________________________________________________\n",
"dense (Dense) (None, 10) 330 \n",
"_________________________________________________________________\n",
"softmax (Softmax) (None, 10) 0 \n",
"=================================================================\n",
"Total params: 778\n",
"Trainable params: 714\n",
"Non-trainable params: 64\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- I added `Magnitude()` which is a simple `abs()` operation on the complex numbers.\n",
"- `MagnitudeToDecibel` maps the numbers to a decibel scale."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
278 changes: 0 additions & 278 deletions examples/prepare audio.ipynb

This file was deleted.

8 changes: 2 additions & 6 deletions kapre/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
__version__ = '0.2.0'
__version__ = '0.3.0'
VERSION = __version__

from . import time_frequency
from . import composed
from . import backend
from . import backend_keras

from . import augmentation
from . import filterbank
from . import utils
63 changes: 0 additions & 63 deletions kapre/augmentation.py

This file was deleted.

180 changes: 77 additions & 103 deletions kapre/backend.py
Original file line number Diff line number Diff line change
@@ -1,134 +1,107 @@
"""
Kapre backend functions
=======================\
| Some backend functions that mainly use numpy.
| Functions with Keras' backend is in ``backend_keras.py``.
Notes
-----
* Don't forget to use ``K.float()``! Otherwise numpy uses float64.
* Some functions are copied-and-pasted from librosa (to reduce dependency), but
later I realised it'd be better to just use it.
* TODO: remove copied code and use librosa.
"""
from tensorflow.keras import backend as K
import tensorflow as tf
import numpy as np
import librosa

EPS = 1e-7


def eps():
return EPS

def magnitude_to_decibel(x, ref_value=1.0, amin=1e-5, dynamic_range=80.0):
"""
Similar to `librosa.amplitude_to_db` with `ref=1.0` and `top_db=dynamic_range`
def mel(sr, n_dft, n_mels=128, fmin=0.0, fmax=None, htk=False, norm='slaney'):
"""[np] create a filterbank matrix to combine stft bins into mel-frequency bins
use Slaney (said Librosa)
Args:
x (tensor): float tensor. Can be batch or not. Something like magnitude of STFT.
ref_value (float): an input value that would become 0 dB in the result.
For spectrogram magnitudes, ref_value=1.0 usually make the decibel-sclaed output to be around zero
if the input audio was in [-1, 1].
amin (float): the noise floor of the input. An input that is smaller than `amin`, it's converted to `amin.
dynamic_range (float): range of the resulting value. E.g., if the maximum magnitude is 30 dB,
the noise floor of the output would become (30 - dynamic_range) dB
n_mels: numbre of mel bands
fmin : lowest frequency [Hz]
fmax : highest frequency [Hz]
If `None`, use `sr / 2.0`
"""
return librosa.filters.mel(
sr=sr, n_fft=n_dft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
).astype(K.floatx())

def _log10(x):
return tf.math.log(x) / tf.math.log(tf.constant(10, dtype=x.dtype))

def get_stft_kernels(n_dft):
"""[np] Return dft kernels for real/imagnary parts assuming
the input . is real.
An asymmetric hann window is used (scipy.signal.hann).
if K.ndim(x) > 1: # we assume x is batch in this case
max_axis = tuple(range(K.ndim(x))[1:])
else:
max_axis = None

Parameters
----------
n_dft : int > 0 and power of 2 [scalar]
Number of dft components.
if amin is None:
amin = 1e-5

Returns
-------
| dft_real_kernels : np.ndarray [shape=(nb_filter, 1, 1, n_win)]
| dft_imag_kernels : np.ndarray [shape=(nb_filter, 1, 1, n_win)]
log_spec = 10.0 * _log10(tf.math.maximum(x, amin))
log_spec = log_spec - 10.0 * _log10(tf.math.maximum(amin, ref_value))

* nb_filter = n_dft/2 + 1
* n_win = n_dft
"""
assert n_dft > 1 and ((n_dft & (n_dft - 1)) == 0), (
'n_dft should be > 1 and power of 2, but n_dft == %d' % n_dft
log_spec = tf.math.maximum(
log_spec, tf.math.reduce_max(log_spec, axis=max_axis, keepdims=True) - dynamic_range
)

nb_filter = int(n_dft // 2 + 1)
return log_spec

# prepare DFT filters
timesteps = np.array(range(n_dft))
w_ks = np.arange(nb_filter) * 2 * np.pi / float(n_dft)
dft_real_kernels = np.cos(w_ks.reshape(-1, 1) * timesteps.reshape(1, -1))
dft_imag_kernels = -np.sin(w_ks.reshape(-1, 1) * timesteps.reshape(1, -1))

# windowing DFT filters
dft_window = librosa.filters.get_window('hann', n_dft, fftbins=True) # _hann(n_dft, sym=False)
dft_window = dft_window.astype(K.floatx())
dft_window = dft_window.reshape((1, -1))
dft_real_kernels = np.multiply(dft_real_kernels, dft_window)
dft_imag_kernels = np.multiply(dft_imag_kernels, dft_window)
def filterbank_mel(
sample_rate, n_freq, n_mels=128, f_min=0.0, f_max=None, htk=False, norm='slaney'
):
"""A wrapper for librosa.filters.mel that additionally does transpose and tensor conversion
dft_real_kernels = dft_real_kernels.transpose()
dft_imag_kernels = dft_imag_kernels.transpose()
dft_real_kernels = dft_real_kernels[:, np.newaxis, np.newaxis, :]
dft_imag_kernels = dft_imag_kernels[:, np.newaxis, np.newaxis, :]
Args:
sample_rate (int): sample rate of the input audio
n_freq (int): number of frequency bins in the input STFT magnitude.
n_mels (int): the number of mel bands
f_min (float): lowest frequency that is going to be included in the mel filterbank (Hertz)
f_max (float): highest frequency that is going to be included in the mel filterbank (Hertz)
htk (bool): whether to use `htk` formula or not
norm: The default, 'slaney', would normalize the the mel weights by the width of the mel band.
return dft_real_kernels.astype(K.floatx()), dft_imag_kernels.astype(K.floatx())


def filterbank_mel(sr, n_freq, n_mels=128, fmin=0.0, fmax=None, htk=False, norm='slaney'):
"""[np] """
return mel(
sr, (n_freq - 1) * 2, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
Return:
Mel filterbank tensor. Shape=(n_freq, n_mels)
"""
filterbank = librosa.filters.mel(
sr=sample_rate,
n_fft=(n_freq - 1) * 2,
n_mels=n_mels,
fmin=f_min,
fmax=f_max,
htk=htk,
norm=norm,
).astype(K.floatx())
return tf.convert_to_tensor(filterbank.T)


def filterbank_log(
sr, n_freq, n_bins=84, bins_per_octave=12, fmin=None, spread=0.125
): # pragma: no cover
"""[np] Approximate a constant-Q filter bank for a fixed-window STFT.
def filterbank_log(sample_rate, n_freq, n_bins=84, bins_per_octave=12, f_min=None, spread=0.125):
"""Approximate a constant-Q filter bank for a fixed-window STFT.
Each filter is a log-normal window centered at the corresponding frequency.
Note: `logfrequency` in librosa 0.4 (deprecated), so copy-and-pasted,
`tuning` was removed, `n_freq` instead of `n_fft`.
Parameters
----------
sr : number > 0 [scalar]
audio sampling rate
n_freq : int > 0 [scalar]
number of frequency bins
n_bins : int > 0 [scalar]
Number of bins. Defaults to 84 (7 octaves).
Args:
sample_rate (int): audio sampling rate
n_freq (int): number of the input frequency bins. E.g., `n_fft / 2 + 1`
n_bins (int): number of the resulting log-frequency bins. Defaults to 84 (7 octaves).
bins_per_octave (int): number of bins per octave. Defaults to 12 (semitones).
f_min (float): lowest frequency that is going to be included in the log filterbank. Defaults to `C1 ~= 32.70`
spread (float): spread of each filter, as a fraction of a bin.
bins_per_octave : int > 0 [scalar]
Number of bins per octave. Defaults to 12 (semitones).
fmin : float > 0 [scalar]
Minimum frequency bin. Defaults to `C1 ~= 32.70`
spread : float > 0 [scalar]
Spread of each filter, as a fraction of a bin.
Returns
-------
C : np.ndarray [shape=(n_bins, 1 + n_fft/2)]
log-frequency filter bank.
Returns:
log-frequency filterbank tensor. Shape=(n_freq, n_bins)
"""

if fmin is None:
fmin = 32.70319566
if f_min is None:
f_min = 32.70319566

f_max = f_min * 2 ** (n_bins / bins_per_octave)
if f_max > sample_rate // 2:
raise RuntimeError(
'Maximum frequency of log filterbank should be lower or equal to the maximum'
'frequency of the input (defined by its sample rate), '
'but f_max=%f and maximum frequency is %f. \n'
'Fix it by reducing n_bins, increasing bins_per_octave and/or reducing f_min.\n'
'You can also do it by increasing sample_rate but it means you need to upsample'
'the input audio data, too.' % (f_max, sample_rate)
)

# What's the shape parameter for our log-normal filters?
sigma = float(spread) / bins_per_octave
@@ -137,11 +110,11 @@ def filterbank_log(
basis = np.zeros((n_bins, n_freq))

# Get log frequencies of bins
log_freqs = np.log2(librosa.fft_frequencies(sr, (n_freq - 1) * 2)[1:])
log_freqs = np.log2(librosa.fft_frequencies(sample_rate, (n_freq - 1) * 2)[1:])

for i in range(n_bins):
# What's the center (median) frequency of this filter?
c_freq = fmin * (2.0 ** (float(i) / bins_per_octave))
c_freq = f_min * (2.0 ** (float(i) / bins_per_octave))

# Place a log-normal window around c_freq
basis[i, 1:] = np.exp(
@@ -150,5 +123,6 @@ def filterbank_log(

# Normalize the filters
basis = librosa.util.normalize(basis, norm=1, axis=1)
basis = basis.astype(K.floatx())

return basis.astype(K.floatx())
return tf.convert_to_tensor(basis.T)
25 changes: 0 additions & 25 deletions kapre/backend_keras.py

This file was deleted.

187 changes: 187 additions & 0 deletions kapre/composed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
from .time_frequency import STFT, Magnitude, Phase, MagnitudeToDecibel, ApplyFilterbank
from . import backend

from tensorflow.keras import Sequential


def get_melspectrogram_layer(
input_shape=None,
n_fft=2048,
win_length=None,
hop_length=None,
window_fn=None,
pad_end=False,
sample_rate=22050,
n_mels=128,
mel_f_min=0.0,
mel_f_max=None,
mel_htk=False,
mel_norm='slaney',
return_decibel=False,
db_amin=1e-5,
db_ref_value=1.0,
db_dynamic_range=80.0,
input_data_format='default',
output_data_format='default',
):
"""A function that retunrs a melspectrogram layer, which is a Sequential model consists of
`STFT`, `Magnitude`, `ApplyFilterbank(_mel_filterbank)`, and optionally `MagnitudeToDecibel`.
Args:
input_shape (None or tuple of integers): input shape of the model if this melspectrogram layer is
is the first layer of your model (see `keras.model.Sequential()` for more details)
n_fft (int): number of FFT points in `STFT`
win_length (int): window length of `STFT`
hop_length (int): hop length of `STFT`
window_fn (function or None): windowing function of `STFT`.
Defaults to `None`, which would follow tf.signal.stft default (hann window at the moment)
pad_end (bool): whether to pad the input signal at the end in `STFT`.
sample_rate (int): sample rate of the input audio
n_mels (int): number of mel bins in the mel filterbank
mel_f_min (float): lowest frequency of the mel filterbank
mel_f_max (float): highest frequency of the mel filterbank
mel_htk (bool): whether to follow the htk mel filterbank fomula or not
mel_norm ('slaney' or int): normalization policy of the mel filterbank triangles
return_decibel (bool): whether to apply decibel scaling at the end
db_amin (float): noise floor of decibel scaling input. See `MagnitudeToDecibel` for more details.
db_ref_value (float): reference value of decibel scaling. See `MagnitudeToDecibel` for more details.
db_dynamic_range (float): dynamic range of the decibel scaling result.
input_data_format (str): the audio data format of input waveform batch.
`'channels_last'` if it's `(batch, time, channels)`
`'channels_first'` if it's `(batch, channels, time)`
Defaults to the setting of your Keras configuration. (tf.keras.backend.image_data_format())
output_data_format (str): the data format of output mel spectrogram.
`'channels_last'` if you want `(batch, time, frequency, channels)`
`'channels_first'` if you want `(batch, channels, time, frequency)`
Defaults to the setting of your Keras configuration. (tf.keras.backend.image_data_format())
"""
waveform_to_stft = STFT(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
window_fn=window_fn,
pad_end=pad_end,
input_data_format=input_data_format,
output_data_format=output_data_format,
input_shape=input_shape,
)

stft_to_stftm = Magnitude()

kwargs = {
'sample_rate': sample_rate,
'n_freq': n_fft // 2 + 1,
'n_mels': n_mels,
'f_min': mel_f_min,
'f_max': mel_f_max,
'htk': mel_htk,
'norm': mel_norm,
}
stftm_to_melgram = ApplyFilterbank(
type='mel', filterbank_kwargs=kwargs, data_format=output_data_format
)

layers = [waveform_to_stft, stft_to_stftm, stftm_to_melgram]
if return_decibel:
mag_to_decibel = MagnitudeToDecibel(
ref_value=db_ref_value, amin=db_amin, dynamic_range=db_dynamic_range
)
layers.append(mag_to_decibel)

return Sequential(layers)


def get_log_frequency_spectrogram_layer(
input_shape=None,
n_fft=2048,
win_length=None,
hop_length=None,
window_fn=None,
pad_end=False,
sample_rate=22050,
log_n_bins=84,
log_f_min=None,
log_bins_per_octave=12,
log_spread=0.125,
return_decibel=False,
db_amin=1e-5,
db_ref_value=1.0,
db_dynamic_range=80.0,
input_data_format='default',
output_data_format='default',
):
"""A function that retunrs a log-frequency STFT layer, which is a Sequential model consists of
`STFT`, `Magnitude`, `ApplyFilterbank(_log_filterbank)`, and optionally `MagnitudeToDecibel`.
Args:
input_shape (None or tuple of integers): input shape of the model if this melspectrogram layer is
is the first layer of your model (see `keras.model.Sequential()` for more details)
n_fft (int): number of FFT points in `STFT`
win_length (int): window length of `STFT`
hop_length (int): hop length of `STFT`
window_fn (function or None): windowing function of `STFT`.
Defaults to `None`, which would follow tf.signal.stft default (hann window at the moment)
pad_end (bool): whether to pad the input signal at the end in `STFT`.
sample_rate (int): sample rate of the input audio
log_n_bins (int): number of the bins in the log-frequency filterbank
log_f_min (float): lowest frequency of the filterbank
log_bins_per_octave (int): number of bins in each octave in the filterbank
log_spread (float): spread constant (Q value) in the log filterbank.
return_decibel (bool): whether to apply decibel scaling at the end
db_amin (float): noise floor of decibel scaling input. See `MagnitudeToDecibel` for more details.
db_ref_value (float): reference value of decibel scaling. See `MagnitudeToDecibel` for more details.
db_dynamic_range (float): dynamic range of the decibel scaling result.
input_data_format (str): the audio data format of input waveform batch.
`'channels_last'` if it's `(batch, time, channels)`
`'channels_first'` if it's `(batch, channels, time)`
Defaults to the setting of your Keras configuration. (tf.keras.backend.image_data_format())
output_data_format (str): the data format of output mel spectrogram.
`'channels_last'` if you want `(batch, time, frequency, channels)`
`'channels_first'` if you want `(batch, channels, time, frequency)`
Defaults to the setting of your Keras configuration. (tf.keras.backend.image_data_format())
"""
waveform_to_stft = STFT(
n_fft=n_fft,
win_length=win_length,
hop_length=hop_length,
window_fn=window_fn,
pad_end=pad_end,
input_data_format=input_data_format,
output_data_format=output_data_format,
input_shape=input_shape,
)

stft_to_stftm = Magnitude()

_log_filterbank = backend.filterbank_log(
sample_rate=sample_rate,
n_freq=n_fft // 2 + 1,
n_bins=log_n_bins,
bins_per_octave=log_bins_per_octave,
f_min=log_f_min,
spread=log_spread,
)
kwargs = {
'sample_rate': sample_rate,
'n_freq': n_fft // 2 + 1,
'n_bins': log_n_bins,
'bins_per_octave': log_bins_per_octave,
'f_min': log_f_min,
'spread': log_spread,
}

stftm_to_loggram = ApplyFilterbank(
type='log', filterbank_kwargs=kwargs, data_format=output_data_format
)

layers = [waveform_to_stft, stft_to_stftm, stftm_to_loggram]

if return_decibel:
mag_to_decibel = MagnitudeToDecibel(
ref_value=db_ref_value, amin=db_amin, dynamic_range=db_dynamic_range
)
layers.append(mag_to_decibel)

return Sequential(layers)
148 changes: 0 additions & 148 deletions kapre/filterbank.py

This file was deleted.

560 changes: 227 additions & 333 deletions kapre/time_frequency.py

Large diffs are not rendered by default.

228 changes: 0 additions & 228 deletions kapre/utils.py

This file was deleted.

2 changes: 2 additions & 0 deletions scripts/apply-black.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
black kapre
black tests
5 changes: 5 additions & 0 deletions scripts/upload-to-pypi.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/bash

python setup.py sdist
pip install twine
twine upload dist/*
11 changes: 4 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
@@ -2,27 +2,24 @@

setup(
name='kapre',
version='0.2.0',
description='Kapre: Keras Audio Preprocessors. Keras layers for audio pre-processing in deep learning',
version='0.3.0',
description='Kapre: Keras Audio Preprocessors. Tensorflow.Keras layers for audio pre-processing in deep learning',
author='Keunwoo Choi',
url='http://github.com/keunwoochoi/kapre/',
author_email='gnuchoi@gmail.com',
license='MIT',
packages=['kapre'],
package_data={
'kapre': [
'tests/fblog_8000_512.npy',
'tests/speech_test_file.npz',
'tests/test_audio_mel_g0.npy',
'tests/test_audio_stft_g0.npy',
]
},
include_package_data=True,
install_requires=[
'numpy >= 1.8.0',
'librosa >= 0.7.2',
'tensorflow >= 1.15',
'tensorflow >= 2.0',
],
keywords='audio music deep learning keras',
keywords='audio music speech sound deep learning keras tensorflow',
zip_safe=False,
)
Binary file removed tests/fblog_8000_512.npy
Binary file not shown.
Binary file removed tests/test_audio_mel_g0.npy
Binary file not shown.
Binary file removed tests/test_audio_stft_g0.npy
Binary file not shown.
117 changes: 74 additions & 43 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,95 @@
import os
import pytest
from kapre import backend as KPB
from tensorflow.keras import backend as K
import numpy as np
import librosa
from tensorflow.keras import backend as K
from kapre import backend as KPB
from kapre.backend import magnitude_to_decibel

TOL = 1e-5


def test_amplitude_to_decibel():
"""test for backend_keras.amplitude_to_decibel"""
from kapre.backend_keras import amplitude_to_decibel
@pytest.mark.parametrize('dynamic_range', [80.0, 120.0])
def test_magnitude_to_decibel(dynamic_range):
"""test for backend_keras.magnitude_to_decibel"""

x = np.array([[1e-20, 1e-5, 1e-3, 5e-2], [0.3, 1.0, 20.5, 9999]]) # random positive numbers

amin = 1e-5
dynamic_range = 80.0

x_decibel = 10 * np.log10(np.maximum(x, amin))
x_decibel = x_decibel - np.max(x_decibel, axis=(1,), keepdims=True)
x_decibel_ref = np.maximum(x_decibel, -1 * dynamic_range)
x_decibel_ref = np.stack(
(
librosa.power_to_db(x[0], amin=amin, ref=1.0, top_db=dynamic_range),
librosa.power_to_db(x[1], amin=amin, ref=1.0, top_db=dynamic_range),
),
axis=0,
)

x_var = K.variable(x)
x_decibel_kapre = amplitude_to_decibel(x_var, amin, dynamic_range)

assert np.allclose(K.eval(x_decibel_kapre), x_decibel_ref, atol=TOL)


def test_mel():
"""test for backend.mel_frequencies
For librosa wrappers, it only tests the data type of returned value
"""
assert KPB.mel(sr=22050, n_dft=512).dtype == K.floatx()


def test_get_stft_kernels():
"""test for backend.get_stft_kernels"""
n_dft = 4
real_kernels, imag_kernels = KPB.get_stft_kernels(n_dft)
x_decibel_kapre = magnitude_to_decibel(
x_var, ref_value=1.0, amin=amin, dynamic_range=dynamic_range
)

real_kernels_ref = np.array(
[[[[0.0, 0.0, 0.0]]], [[[0.5, 0.0, -0.5]]], [[[1.0, -1.0, 1.0]]], [[[0.5, 0.0, -0.5]]]],
dtype=K.floatx(),
np.testing.assert_allclose(K.eval(x_decibel_kapre), x_decibel_ref, atol=TOL)


@pytest.mark.parametrize('sample_rate', [44100, 22050])
@pytest.mark.parametrize('n_freq', [1025, 257])
@pytest.mark.parametrize('n_mels', [32, 128])
@pytest.mark.parametrize('f_min', [0.0, 200])
@pytest.mark.parametrize('f_max_ratio', [1.0, 0.5])
@pytest.mark.parametrize('htk', [True, False])
@pytest.mark.parametrize('norm', [None, 'slaney', 1.0])
def test_mel(sample_rate, n_freq, n_mels, f_min, f_max_ratio, htk, norm):
f_max = int(f_max_ratio * (sample_rate // 2))
mel_fb = KPB.filterbank_mel(
sample_rate=sample_rate,
n_freq=n_freq,
n_mels=n_mels,
f_min=f_min,
f_max=f_max,
htk=htk,
norm=norm,
)
imag_kernels_ref = np.array(
[[[[0.0, 0.0, 0.0]]], [[[0.0, -0.5, 0.0]]], [[[0.0, 0.0, 0.0]]], [[[0.0, 0.5, 0.0]]]],
dtype=K.floatx(),
mel_fb = mel_fb.numpy()

mel_fb_ref = librosa.filters.mel(
sr=sample_rate,
n_fft=(n_freq - 1) * 2,
n_mels=n_mels,
fmin=f_min,
fmax=f_max,
htk=htk,
norm=norm,
).T

assert mel_fb.dtype == K.floatx()
assert mel_fb.shape == (n_freq, n_mels)
np.testing.assert_allclose(mel_fb_ref, mel_fb)


@pytest.mark.parametrize('sample_rate', [44100, 22050])
@pytest.mark.parametrize('n_freq', [1025, 257])
@pytest.mark.parametrize('n_bins', [32, 84])
@pytest.mark.parametrize('bins_per_octave', [8, 12, 36])
@pytest.mark.parametrize('f_min', [1.0, 0.5])
@pytest.mark.parametrize('spread', [0.5, 0.125])
def test_filterbank_log(sample_rate, n_freq, n_bins, bins_per_octave, f_min, spread):
"""It only tests if the function is a valid wrapper"""
log_fb = KPB.filterbank_log(
sample_rate=sample_rate,
n_freq=n_freq,
n_bins=n_bins,
bins_per_octave=bins_per_octave,
f_min=f_min,
spread=spread,
)

assert real_kernels.shape == (n_dft, 1, 1, n_dft // 2 + 1)
assert imag_kernels.shape == (n_dft, 1, 1, n_dft // 2 + 1)
assert np.allclose(real_kernels, real_kernels_ref, atol=TOL)
assert np.allclose(imag_kernels, imag_kernels_ref, atol=TOL)
assert log_fb.dtype == K.floatx()
assert log_fb.shape == (n_freq, n_bins)


def test_filterbank_log():
"""test for backend.filterback_log"""
fblog_ref = np.load(os.path.join(os.path.dirname(__file__), 'fblog_8000_512.npy'))
fblog = KPB.filterbank_log(sr=8000, n_freq=512)
assert fblog.shape == fblog_ref.shape
assert np.allclose(fblog, fblog_ref, atol=TOL)
@pytest.mark.xfail()
def test_fb_log_fail():
_ = KPB.filterbank_log(sample_rate=22050, n_freq=513, n_bins=300, bins_per_octave=12)


if __name__ == '__main__':
507 changes: 228 additions & 279 deletions tests/test_time_frequency.py

Large diffs are not rendered by default.

85 changes: 0 additions & 85 deletions tests/test_utils.py

This file was deleted.

2 changes: 0 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -2,8 +2,6 @@
envlist = py37,python3.6,black
skipsdist = False
usedevelop = True
indexserver =
spotify = https://artifactory.spotify.net/artifactory/api/pypi/pypi/simple

[testenv]
whitelist_externals = python

0 comments on commit 8cdbb16

Please sign in to comment.