Skip to content

Commit

Permalink
add audio metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
anas-rz committed Mar 1, 2024
1 parent 854338b commit f900919
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 1 deletion.
Empty file.
63 changes: 63 additions & 0 deletions k3_addons/metrics/audio/pesq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import numpy as np
import torch
from keras import ops
from k3_addons.utils.checks import _check_same_shape
from k3_addons.utils.imports import MULTIPROCESSING_AVAILABLE, PESQ_AVAILABLE
from k3_addons.api_export import k3_export


@k3_export(
    [
        "k3_addons.metrics.pesq",
        "k3_addons.metrics.functional.pesq",
        "k3_addons.metrics.audio.pesq",
    ]
)
def perceptual_evaluation_speech_quality(
    preds,
    target,
    fs,
    mode,
    n_processes=1,
):
    """Compute PESQ scores for `preds` against `target` with the `pesq` package.

    The last axis of the inputs is treated as the signal axis: a 1-D input
    is scored as one signal, anything else is flattened to rows of length
    `preds.shape[-1]` and every row is scored separately.

    Args:
        preds: Signal(s) to evaluate. Must have the same shape as `target`.
        target: Reference signal(s). Passed to the backend before `preds`.
        fs: Sampling frequency; must be 8000 or 16000.
        mode: `"wb"` or `"nb"`, forwarded unchanged to the backend.
        n_processes: Forwarded to `pesq_batch` as `n_processor` when
            `MULTIPROCESSING_AVAILABLE` is truthy and the value is not 1;
            otherwise rows are scored one at a time in this process.

    Returns:
        A backend tensor of scores with shape `preds.shape[:-1]`
        (a scalar tensor for 1-D inputs).

    Raises:
        ModuleNotFoundError: If `PESQ_AVAILABLE` is falsy.
        ValueError: If `fs` or `mode` is not one of the accepted values.
    """
    if not PESQ_AVAILABLE:
        raise ModuleNotFoundError(
            "PESQ metric requires that pesq is installed."
            " Install it using `pip install pesq`."
        )
    import pesq as pesq_backend

    if fs not in (8000, 16000):
        raise ValueError(
            f"Expected argument `fs` to either be 8000 or 16000 but got {fs}"
        )
    if mode not in ("wb", "nb"):
        raise ValueError(
            f"Expected argument `mode` to either be 'wb' or 'nb' but got {mode}"
        )
    _check_same_shape(preds, target)

    if len(ops.shape(preds)) == 1:
        pesq_val_np = pesq_backend.pesq(
            fs, ops.convert_to_numpy(target), ops.convert_to_numpy(preds), mode
        )
        # Use the Keras backend conversion (not torch.tensor) so the result
        # type matches the batched branch on every backend.
        return ops.convert_to_tensor(pesq_val_np)

    # Collapse all leading axes so each row is one signal.
    preds_np = ops.convert_to_numpy(ops.reshape(preds, (-1, preds.shape[-1])))
    target_np = ops.convert_to_numpy(ops.reshape(target, (-1, preds.shape[-1])))

    if MULTIPROCESSING_AVAILABLE and n_processes != 1:
        pesq_val_np = pesq_backend.pesq_batch(
            fs, target_np, preds_np, mode, n_processor=n_processes
        )
        pesq_val_np = np.array(pesq_val_np)
    else:
        pesq_val_np = np.empty(shape=(preds_np.shape[0]))
        for b in range(preds_np.shape[0]):
            pesq_val_np[b] = pesq_backend.pesq(
                fs, target_np[b, :], preds_np[b, :], mode
            )

    pesq_val = ops.convert_to_tensor(pesq_val_np)
    # Restore the original leading (batch) axes.
    return ops.reshape(pesq_val, (preds.shape[:-1]))
29 changes: 29 additions & 0 deletions k3_addons/metrics/audio/pesq_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import keras
import pytest
import torch
import numpy as np
from keras import ops
from k3_addons.metrics.audio.pesq import (
perceptual_evaluation_speech_quality as pesq_keras,
)
from torchmetrics.functional.audio.pesq import (
perceptual_evaluation_speech_quality as pesq_torch,
)


@pytest.mark.parametrize(
    "input_shape, fs, mode",
    [
        ((8000,), 16000, "wb"),
        ((8000,), 16000, "nb"),
    ],
)
def test_stoi(input_shape, fs, mode):
    """Check that the Keras PESQ matches the torchmetrics PESQ.

    Despite the function name, this compares `pesq_keras` with
    `pesq_torch` on the same randomly drawn signals.
    """
    keras_preds = keras.random.uniform(input_shape)
    keras_target = keras.random.uniform(input_shape)
    pesq_keras_val = pesq_keras(keras_preds, keras_target, fs, mode)

    # Feed identical sample values to the torchmetrics reference.
    torch_preds = torch.tensor(ops.convert_to_numpy(keras_preds))
    torch_target = torch.tensor(ops.convert_to_numpy(keras_target))
    pesq_torch_val = pesq_torch(torch_preds, torch_target, fs, mode).numpy()

    assert np.allclose(pesq_keras_val, pesq_torch_val, atol=1e-4)
30 changes: 30 additions & 0 deletions k3_addons/metrics/audio/stoi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from keras import ops
from k3_addons.utils.imports import PYSOTI_AVAILABLE

from k3_addons.utils.checks import _check_same_shape


def short_time_objective_intelligibility(preds, target, fs, extended=False):
    """Compute STOI scores for `preds` against `target` with `pystoi`.

    The last axis of the inputs is treated as the signal axis: a 1-D input
    is scored as one signal, anything else is flattened to rows of length
    `preds.shape[-1]` and every row is scored separately.

    Args:
        preds: Signal(s) to evaluate. Must have the same shape as `target`.
        target: Reference signal(s). Passed to the backend before `preds`.
        fs: Sampling frequency, forwarded unchanged to `pystoi.stoi`.
        extended: Forwarded unchanged to `pystoi.stoi`.

    Returns:
        A backend tensor of scores with shape `preds.shape[:-1]`
        (a scalar tensor for 1-D inputs).

    Raises:
        ModuleNotFoundError: If `PYSOTI_AVAILABLE` is falsy.
    """
    if not PYSOTI_AVAILABLE:
        raise ModuleNotFoundError(
            "ShortTimeObjectiveIntelligibility metric requires that `pystoi` is installed."
            " You can install it using `pip install pystoi`."
        )
    import numpy as np
    from pystoi import stoi as stoi_backend

    _check_same_shape(preds, target)

    if len(preds.shape) == 1:
        stoi_val_np = stoi_backend(
            ops.convert_to_numpy(target), ops.convert_to_numpy(preds), fs, extended
        )
        return ops.convert_to_tensor(stoi_val_np)

    # Collapse all leading axes so each row is one signal.
    preds_np = ops.convert_to_numpy(ops.reshape(preds, (-1, preds.shape[-1])))
    target_np = ops.convert_to_numpy(ops.reshape(target, (-1, preds.shape[-1])))

    # The scratch buffer must be a NumPy array: it is filled by item
    # assignment, which backend tensors do not support on every backend.
    n_signals = preds_np.shape[0]
    stoi_val_np = np.empty(shape=(n_signals,))
    for b in range(n_signals):
        stoi_val_np[b] = stoi_backend(target_np[b, :], preds_np[b, :], fs, extended)

    stoi_val = ops.convert_to_tensor(stoi_val_np)
    # Restore the original leading (batch) axes.
    return ops.reshape(stoi_val, (preds.shape[:-1]))
31 changes: 31 additions & 0 deletions k3_addons/metrics/audio/stoi_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import keras
import pytest
import torch
import numpy as np
from keras import ops
from k3_addons.metrics.audio.stoi import (
short_time_objective_intelligibility as stoi_keras,
)
from torchmetrics.functional.audio.stoi import (
short_time_objective_intelligibility as stoi_torch,
)


@pytest.mark.parametrize(
    "input_shape, fs, extended",
    [
        ((8000,), 16000, False),
        ((8000,), 16000, True),
    ],
)
def test_stoi(input_shape, fs, extended):
    """Check that the Keras STOI matches the torchmetrics STOI.

    Both implementations receive the same randomly drawn signals, once
    with `extended=False` and once with `extended=True`.
    """
    inputs = keras.random.uniform(input_shape)
    target = keras.random.uniform(input_shape)
    # Convert explicitly so the comparison does not rely on NumPy being
    # able to coerce whatever tensor type the active backend returns.
    stoi_keras_val = ops.convert_to_numpy(stoi_keras(inputs, target, fs, extended))

    # Feed identical sample values to the torchmetrics reference.
    inputs = torch.tensor(ops.convert_to_numpy(inputs))
    target = torch.tensor(ops.convert_to_numpy(target))
    stoi_torch_val = stoi_torch(inputs, target, fs, extended).numpy()

    assert np.allclose(stoi_keras_val, stoi_torch_val, atol=1e-4)
4 changes: 3 additions & 1 deletion requirements/requirements_test.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
torch --index-url https://download.pytorch.org/whl/cpu
torchmetrics
tensorflow
jax[cpu]
jax[cpu]
pesq
pystoi

0 comments on commit f900919

Please sign in to comment.