-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
156 additions
and
1 deletion.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import numpy as np | ||
import torch | ||
from keras import ops | ||
from k3_addons.utils.checks import _check_same_shape | ||
from k3_addons.utils.imports import MULTIPROCESSING_AVAILABLE, PESQ_AVAILABLE | ||
from k3_addons.api_export import k3_export | ||
|
||
|
||
@k3_export( | ||
[ | ||
"k3_addons.metrics.pesq", | ||
"k3_addons.metrics.functional.pesq", | ||
"k3_addons.metrics.audio.pesq", | ||
] | ||
) | ||
def perceptual_evaluation_speech_quality( | ||
preds, | ||
target, | ||
fs, | ||
mode, | ||
n_processes=1, | ||
): | ||
if not PESQ_AVAILABLE: | ||
raise ModuleNotFoundError( | ||
"PESQ metric requires that pesq is installed." | ||
" Install it using `pip install pesq`." | ||
) | ||
import pesq as pesq_backend | ||
|
||
if fs not in (8000, 16000): | ||
raise ValueError( | ||
f"Expected argument `fs` to either be 8000 or 16000 but got {fs}" | ||
) | ||
if mode not in ("wb", "nb"): | ||
raise ValueError( | ||
f"Expected argument `mode` to either be 'wb' or 'nb' but got {mode}" | ||
) | ||
_check_same_shape(preds, target) | ||
|
||
if len(ops.shape(preds)) == 1: | ||
pesq_val_np = pesq_backend.pesq( | ||
fs, ops.convert_to_numpy(target), ops.convert_to_numpy(preds), mode | ||
) | ||
pesq_val = torch.tensor(pesq_val_np) | ||
else: | ||
preds_np = ops.convert_to_numpy(ops.reshape(preds, (-1, preds.shape[-1]))) | ||
target_np = ops.convert_to_numpy(ops.reshape(target, (-1, preds.shape[-1]))) | ||
|
||
if MULTIPROCESSING_AVAILABLE and n_processes != 1: | ||
pesq_val_np = pesq_backend.pesq_batch( | ||
fs, target_np, preds_np, mode, n_processor=n_processes | ||
) | ||
pesq_val_np = np.array(pesq_val_np) | ||
else: | ||
pesq_val_np = np.empty(shape=(preds_np.shape[0])) | ||
for b in range(preds_np.shape[0]): | ||
pesq_val_np[b] = pesq_backend.pesq( | ||
fs, target_np[b, :], preds_np[b, :], mode | ||
) | ||
pesq_val = ops.convert_to_tensor(pesq_val_np) | ||
pesq_val = ops.reshape(pesq_val, (preds.shape[:-1])) | ||
|
||
return pesq_val |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import keras | ||
import pytest | ||
import torch | ||
import numpy as np | ||
from keras import ops | ||
from k3_addons.metrics.audio.pesq import ( | ||
perceptual_evaluation_speech_quality as pesq_keras, | ||
) | ||
from torchmetrics.functional.audio.pesq import ( | ||
perceptual_evaluation_speech_quality as pesq_torch, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shape, fs, mode", | ||
[ | ||
((8000,), 16000, "wb"), | ||
((8000,), 16000, "nb"), | ||
], | ||
) | ||
def test_stoi(input_shape, fs, mode): | ||
inputs = keras.random.uniform(input_shape) | ||
target = keras.random.uniform(input_shape) | ||
stoi_keras_val = pesq_keras(inputs, target, fs, mode) | ||
inputs = torch.tensor(ops.convert_to_numpy(inputs)) | ||
target = torch.tensor(ops.convert_to_numpy(target)) | ||
stoi_torch_val = pesq_torch(inputs, target, fs, mode).numpy() | ||
|
||
assert np.allclose(stoi_keras_val, stoi_torch_val, atol=1e-4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from keras import ops | ||
from k3_addons.utils.imports import PYSOTI_AVAILABLE | ||
|
||
from k3_addons.utils.checks import _check_same_shape | ||
|
||
|
||
def short_time_objective_intelligibility(preds, target, fs, extended=False): | ||
if not PYSOTI_AVAILABLE: | ||
raise ModuleNotFoundError( | ||
"ShortTimeObjectiveIntelligibility metric requires that `pystoi` is installed." | ||
" You can install it using `pip install pystoi`." | ||
) | ||
from pystoi import stoi as stoi_backend | ||
|
||
_check_same_shape(preds, target) | ||
|
||
if len(preds.shape) == 1: | ||
stoi_val_np = stoi_backend( | ||
ops.convert_to_numpy(target), ops.convert_to_numpy(preds), fs, extended | ||
) | ||
stoi_val = ops.convert_to_tensor(stoi_val_np) | ||
else: | ||
preds_np = ops.convert_to_numpy(ops.reshape(preds, (-1, preds.shape[-1]))) | ||
target_np = ops.convert_to_numpy(ops.reshape(target, (-1, preds.shape[-1]))) | ||
stoi_val_np = ops.empty(shape=(preds_np.shape[0])) | ||
for b in range(ops.shape(preds_np)[0]): | ||
stoi_val_np[b] = stoi_backend(target_np[b, :], preds_np[b, :], fs, extended) | ||
stoi_val = ops.convert_to_tensor(stoi_val_np) | ||
stoi_val = ops.reshape(stoi_val, (preds.shape[:-1])) | ||
return stoi_val |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import keras | ||
import pytest | ||
import torch | ||
import numpy as np | ||
from keras import ops | ||
from k3_addons.metrics.audio.stoi import ( | ||
short_time_objective_intelligibility as stoi_keras, | ||
) | ||
from torchmetrics.functional.audio.stoi import ( | ||
short_time_objective_intelligibility as stoi_torch, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shape, fs, extended", | ||
[ | ||
((8000,), 16000, False), | ||
((8000,), 16000, True), | ||
((8000,), 16000, False), | ||
((8000,), 16000, True), | ||
], | ||
) | ||
def test_stoi(input_shape, fs, extended): | ||
inputs = keras.random.uniform(input_shape) | ||
target = keras.random.uniform(input_shape) | ||
stoi_keras_val = stoi_keras(inputs, target, fs, extended) | ||
inputs = torch.tensor(ops.convert_to_numpy(inputs)) | ||
target = torch.tensor(ops.convert_to_numpy(target)) | ||
stoi_torch_val = stoi_torch(inputs, target, fs, extended).numpy() | ||
|
||
assert np.allclose(stoi_keras_val, stoi_torch_val, atol=1e-4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
torch --index-url https://download.pytorch.org/whl/cpu | ||
torchmetrics | ||
tensorflow | ||
jax[cpu] | ||
jax[cpu] | ||
pesq | ||
pystoi |