diff --git a/k3_addons/metrics/audio/__init__.py b/k3_addons/metrics/audio/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/k3_addons/metrics/audio/pesq.py b/k3_addons/metrics/audio/pesq.py new file mode 100644 index 0000000..ee5303e --- /dev/null +++ b/k3_addons/metrics/audio/pesq.py @@ -0,0 +1,63 @@ +import numpy as np +import torch +from keras import ops +from k3_addons.utils.checks import _check_same_shape +from k3_addons.utils.imports import MULTIPROCESSING_AVAILABLE, PESQ_AVAILABLE +from k3_addons.api_export import k3_export + + +@k3_export( + [ + "k3_addons.metrics.pesq", + "k3_addons.metrics.functional.pesq", + "k3_addons.metrics.audio.pesq", + ] +) +def perceptual_evaluation_speech_quality( + preds, + target, + fs, + mode, + n_processes=1, +): + if not PESQ_AVAILABLE: + raise ModuleNotFoundError( + "PESQ metric requires that pesq is installed." + " Install it using `pip install pesq`." + ) + import pesq as pesq_backend + + if fs not in (8000, 16000): + raise ValueError( + f"Expected argument `fs` to either be 8000 or 16000 but got {fs}" + ) + if mode not in ("wb", "nb"): + raise ValueError( + f"Expected argument `mode` to either be 'wb' or 'nb' but got {mode}" + ) + _check_same_shape(preds, target) + + if len(ops.shape(preds)) == 1: + pesq_val_np = pesq_backend.pesq( + fs, ops.convert_to_numpy(target), ops.convert_to_numpy(preds), mode + ) + pesq_val = torch.tensor(pesq_val_np) + else: + preds_np = ops.convert_to_numpy(ops.reshape(preds, (-1, preds.shape[-1]))) + target_np = ops.convert_to_numpy(ops.reshape(target, (-1, preds.shape[-1]))) + + if MULTIPROCESSING_AVAILABLE and n_processes != 1: + pesq_val_np = pesq_backend.pesq_batch( + fs, target_np, preds_np, mode, n_processor=n_processes + ) + pesq_val_np = np.array(pesq_val_np) + else: + pesq_val_np = np.empty(shape=(preds_np.shape[0])) + for b in range(preds_np.shape[0]): + pesq_val_np[b] = pesq_backend.pesq( + fs, target_np[b, :], preds_np[b, :], mode + ) + pesq_val = ops.convert_to_tensor(pesq_val_np) + pesq_val = ops.reshape(pesq_val, (preds.shape[:-1])) + + return pesq_val diff --git a/k3_addons/metrics/audio/pesq_test.py b/k3_addons/metrics/audio/pesq_test.py new file mode 100644 index 0000000..9a1a8ec --- /dev/null +++ b/k3_addons/metrics/audio/pesq_test.py @@ -0,0 +1,29 @@ +import keras +import pytest +import torch +import numpy as np +from keras import ops +from k3_addons.metrics.audio.pesq import ( + perceptual_evaluation_speech_quality as pesq_keras, +) +from torchmetrics.functional.audio.pesq import ( + perceptual_evaluation_speech_quality as pesq_torch, +) + + +@pytest.mark.parametrize( + "input_shape, fs, mode", + [ + ((8000,), 16000, "wb"), + ((8000,), 16000, "nb"), + ], +) +def test_stoi(input_shape, fs, mode): + inputs = keras.random.uniform(input_shape) + target = keras.random.uniform(input_shape) + stoi_keras_val = pesq_keras(inputs, target, fs, mode) + inputs = torch.tensor(ops.convert_to_numpy(inputs)) + target = torch.tensor(ops.convert_to_numpy(target)) + stoi_torch_val = pesq_torch(inputs, target, fs, mode).numpy() + + assert np.allclose(stoi_keras_val, stoi_torch_val, atol=1e-4) diff --git a/k3_addons/metrics/audio/stoi.py b/k3_addons/metrics/audio/stoi.py new file mode 100644 index 0000000..d76b58d --- /dev/null +++ b/k3_addons/metrics/audio/stoi.py @@ -0,0 +1,30 @@ +from keras import ops +from k3_addons.utils.imports import PYSOTI_AVAILABLE + +from k3_addons.utils.checks import _check_same_shape + + +def short_time_objective_intelligibility(preds, target, fs, extended=False): + if not PYSOTI_AVAILABLE: + raise ModuleNotFoundError( + "ShortTimeObjectiveIntelligibility metric requires that `pystoi` is installed." + " You can install it using `pip install pystoi`." + ) + from pystoi import stoi as stoi_backend + + _check_same_shape(preds, target) + + if len(preds.shape) == 1: + stoi_val_np = stoi_backend( + ops.convert_to_numpy(target), ops.convert_to_numpy(preds), fs, extended + ) + stoi_val = ops.convert_to_tensor(stoi_val_np) + else: + preds_np = ops.convert_to_numpy(ops.reshape(preds, (-1, preds.shape[-1]))) + target_np = ops.convert_to_numpy(ops.reshape(target, (-1, preds.shape[-1]))) + stoi_val_np = ops.empty(shape=(preds_np.shape[0])) + for b in range(ops.shape(preds_np)[0]): + stoi_val_np[b] = stoi_backend(target_np[b, :], preds_np[b, :], fs, extended) + stoi_val = ops.convert_to_tensor(stoi_val_np) + stoi_val = ops.reshape(stoi_val, (preds.shape[:-1])) + return stoi_val diff --git a/k3_addons/metrics/audio/stoi_test.py b/k3_addons/metrics/audio/stoi_test.py new file mode 100644 index 0000000..1a22ef9 --- /dev/null +++ b/k3_addons/metrics/audio/stoi_test.py @@ -0,0 +1,31 @@ +import keras +import pytest +import torch +import numpy as np +from keras import ops +from k3_addons.metrics.audio.stoi import ( + short_time_objective_intelligibility as stoi_keras, +) +from torchmetrics.functional.audio.stoi import ( + short_time_objective_intelligibility as stoi_torch, +) + + +@pytest.mark.parametrize( + "input_shape, fs, extended", + [ + ((8000,), 16000, False), + ((8000,), 16000, True), + ((8000,), 16000, False), + ((8000,), 16000, True), + ], +) +def test_stoi(input_shape, fs, extended): + inputs = keras.random.uniform(input_shape) + target = keras.random.uniform(input_shape) + stoi_keras_val = stoi_keras(inputs, target, fs, extended) + inputs = torch.tensor(ops.convert_to_numpy(inputs)) + target = torch.tensor(ops.convert_to_numpy(target)) + stoi_torch_val = stoi_torch(inputs, target, fs, extended).numpy() + + assert np.allclose(stoi_keras_val, stoi_torch_val, atol=1e-4) diff --git a/requirements/requirements_test.txt b/requirements/requirements_test.txt index 494f270..cbf6c33 100644 --- a/requirements/requirements_test.txt +++ b/requirements/requirements_test.txt @@ -1,4 +1,6 @@ torch --index-url https://download.pytorch.org/whl/cpu torchmetrics tensorflow -jax[cpu] \ No newline at end of file +jax[cpu] +pesq +pystoi \ No newline at end of file