Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multilingual dataset eval support #26

Merged
merged 7 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/workflows/public-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ jobs:
shell: bash -el {0}
run: |
TEST_DEV=cpu $(which python) tests/test_audio_encoder.py
- name: Evaluate Unit Test
shell: bash -el {0}
run: |
$(which python) tests/test_evaluate.py --dataset librispeech-debug --pipeline WhisperKit
- name: Folder Evaluate Unit Test
shell: bash -el {0}
run: |
$(which python) tests/test_evaluate.py --dataset common_voice_17_0-debug-zip --pipeline WhisperKit --language-subset en
- name: Lint
shell: bash -el {0}
run: |
Expand Down
11 changes: 11 additions & 0 deletions scripts/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,17 @@ def cli():
choices=("WhisperKit", "whisper.cpp", "WhisperMLX", "WhisperOpenAIAPI"),
required=True
)
parser.add_argument(
"--force-language",
action="store_true",
help="If specified, forces the language in each data sample (if available)"
)
parser.add_argument(
"--language-subset",
type=str,
default=None,
help="If specified, filters the dataset for the given language"
)

# Alias the CLI args to match the test scripts
args = parser.parse_args()
Expand Down
11 changes: 4 additions & 7 deletions tests/test_aihub.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@

from whisperkit.android import models as android
from whisperkit.android import utils as aihub_utils
from whisperkit import audio_encoder as apple

from tests.test_audio_encoder import TEST_N_SAMPLES

TEST_VOCAB_SIZE = 51865
TEST_PSNR_THR = 40
Expand Down Expand Up @@ -48,13 +45,12 @@ def setUpClass(cls):
}
}
super().setUpClass()

# Class-level teardown hook for this test case (shown flattened, as in the diff view).
@classmethod
def tearDownClass(cls):
# Clear the class attribute `models` by rebinding it to None
# (presumably populated in setUpClass -- TODO confirm against the full file).
cls.models = None
# Delegate to the parent class so its own class-level teardown still runs.
super().tearDownClass()


def test_torch2torch_correctness(self):
""" Test forward pass functionality and correctness of PyTorch models
"""
Expand All @@ -70,8 +66,9 @@ def test_torch2torch_correctness(self):
logger.info(f"torch2torch model={model_key} PSNR={psnr:.3g}")
else:
logger.info(
f"torch2torch correctness test skipped: Reference model does not exist for {model_key}")

"torch2torch correctness test skipped: "
f"Reference model does not exist for {model_key}"
)

def test_torch2aihub_performance_and_correctness(self):
""" Test AI Hub compilation and inference job results against local PyTorch test results
Expand Down
1 change: 1 addition & 0 deletions tests/test_audio_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TEST_PSNR_THR = 35

argmaxtools_test_utils.TEST_MIN_SPEEDUP_VS_CPU = 0.95
argmaxtools_test_utils.TEST_SKIP_SPEED_TESTS = True

# WhisperMelSpectrogram constants
# TEST_N_MELS = [80, 128]
Expand Down
26 changes: 21 additions & 5 deletions tests/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import datetime
import json
import os
import pprint
import subprocess
import unittest

Expand All @@ -17,6 +16,7 @@
from whisperkit._constants import EVALS_REPO_ID, MODEL_REPO_ID
from whisperkit.evaluate.datasets import EVAL_DATASETS
from whisperkit.evaluate.evaluate import evaluate
import whisperkit.evaluate.evaluate
from whisperkit.pipelines import get_pipeline_cls
from whisperkit.test_utils import BenchmarkContext

Expand All @@ -36,6 +36,7 @@
TEST_UPLOAD_RESULTS = os.getenv("TEST_UPLOAD_RESULTS", None) or False
TEST_QOI_REFERENCE = os.getenv("TEST_QOI_REFERENCE", None) or None # TODO
AVG_WER_SANITY_CHECK_THR = 0.5
LANGUAGE_SUBSET = None


class TestWhisperPipelineEvaluate(unittest.TestCase):
Expand All @@ -60,21 +61,28 @@ def setUpClass(cls) -> None:
shell=True
).stdout.decode('utf-8').strip()[:7]

inference_context_spec_dict = None
try:
inference_context_spec_dict = cls.inference_context.spec_dict()
except Exception as e:
logger.warning(f"Inference context spec dict failed: {e}")

cls.results = {
"results": evaluate(
cls.pipeline,
dataset_name=TEST_DATASET_NAME,
num_samples=TEST_NUM_SAMPLES,
cache_dir=TEST_CACHE_DIR,
num_proc=TEST_NUM_PROC),
num_proc=TEST_NUM_PROC,
language_subset=LANGUAGE_SUBSET),
"metadata": {
"num_samples": TEST_NUM_SAMPLES,
"num_proc": TEST_NUM_PROC,
"pipeline": TEST_PIPELINE,
"dataset_name": TEST_DATASET_NAME,
"model_version": TEST_MODEL_VERSION,
"whisperkittools_commit_hash": wkt_commit_hash,
"inference_context": cls.inference_context.spec_dict(),
"inference_context": inference_context_spec_dict,
"model_repo_id": MODEL_REPO_ID
}
}
Expand All @@ -96,7 +104,9 @@ def setUpClass(cls) -> None:
results_dir = os.path.join(
TEST_PIPELINE,
TEST_MODEL_VERSION.replace("/", "_"),
TEST_DATASET_NAME
TEST_DATASET_NAME,
"forced" if whisperkit.evaluate.evaluate.FORCE_LANGUAGE else "",
LANGUAGE_SUBSET if LANGUAGE_SUBSET else ""
)
results_fname = datetime.datetime.now().astimezone(
).strftime("%Y-%m-%d_%H:%M:%S_GMT%z") + ".json"
Expand All @@ -122,7 +132,7 @@ def test_evaluate(self):
def main(args):
global TEST_DATASET_NAME, TEST_PIPELINE, TEST_NUM_SAMPLES, TEST_CACHE_DIR, \
TEST_MODEL_VERSION, TEST_CODE_COMMIT_HASH, TEST_MODEL_COMMIT_HASH, \
TEST_NUM_PROC, TEST_UPLOAD_RESULTS, TEST_QOI_REFERENCE
TEST_NUM_PROC, TEST_UPLOAD_RESULTS, TEST_QOI_REFERENCE, LANGUAGE_SUBSET
TEST_DATASET_NAME = args.dataset
TEST_PIPELINE = args.pipeline
TEST_NUM_SAMPLES = args.num_samples
Expand All @@ -132,6 +142,10 @@ def main(args):
TEST_MODEL_COMMIT_HASH = args.model_commit_hash
TEST_NUM_PROC = args.num_proc
TEST_UPLOAD_RESULTS = args.upload_results
LANGUAGE_SUBSET = args.language_subset

# Force language option
whisperkit.evaluate.evaluate.FORCE_LANGUAGE = args.force_language

with argmaxtools_test_utils._get_test_cache_dir(
args.persistent_cache_dir
Expand Down Expand Up @@ -169,6 +183,8 @@ def main(args):
parser.add_argument("--model-commit-hash", type=str, default=None)
parser.add_argument("--num-proc", type=int, default=1)
parser.add_argument("--upload-results", action="store_true")
parser.add_argument("--language-subset", type=str, default=None)
parser.add_argument("--force-language", action="store_true")
args = parser.parse_args()

main(args)
2 changes: 2 additions & 0 deletions tests/test_text_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
TEST_PSNR_THR = 35
TEST_CACHE_DIR = os.getenv("TEST_CACHE_DIR", None) or "/tmp"

argmaxtools_test_utils.TEST_SKIP_SPEED_TESTS = True

# WhisperDecoderContextPrefill constants
TEST_PREFILL_CONSISTENCY_PSNR_THR = 20
TEST_BATCH = 16
Expand Down
108 changes: 107 additions & 1 deletion whisperkit/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
EVAL_DATASETS = [
"earnings22", "librispeech", "librispeech-200",
"earnings22-debug", "librispeech-debug",
"earnings22-12hours"
"earnings22-12hours",
"common_voice_17_0-debug-zip",
"common_voice_17_0-argmax_subset-400"
]
CUSTOM_EVAL_DATASET = os.getenv("EVAL_DATASET", None)
if CUSTOM_EVAL_DATASET is not None:
Expand All @@ -26,3 +28,107 @@
# Upload size ceiling, expressed in bytes (25e6 bytes).
OPENAI_API_MAX_FILE_SIZE = 25e6
# Bit-rate setting used for compressed uploads; the string "12k" denotes kbps.
OPENAI_API_COMPRESSED_UPLOAD_BIT_RATE = "12k"
# Identifier of the repository that holds the test data.
TEST_DATA_REPO = "argmaxinc/whisperkit-test-data"

# Supported languages: 100 short language codes kept in ascending
# alphabetical order, laid out ten per row for easier scanning.
SUPPORTED_LANGUAGES = [
    "af", "am", "ar", "as", "az", "ba", "be", "bg", "bn", "bo",
    "br", "bs", "ca", "cs", "cy", "da", "de", "el", "en", "es",
    "et", "eu", "fa", "fi", "fo", "fr", "gl", "gu", "ha", "haw",
    "he", "hi", "hr", "ht", "hu", "hy", "id", "is", "it", "ja",
    "jw", "ka", "kk", "km", "kn", "ko", "la", "lb", "ln", "lo",
    "lt", "lv", "mg", "mi", "mk", "ml", "mn", "mr", "ms", "mt",
    "my", "ne", "nl", "nn", "no", "oc", "pa", "pl", "ps", "pt",
    "ro", "ru", "sa", "sd", "si", "sk", "sl", "sn", "so", "sq",
    "sr", "su", "sv", "sw", "ta", "te", "tg", "th", "tk", "tl",
    "tr", "tt", "uk", "ur", "uz", "vi", "yi", "yo", "yue", "zh",
]
10 changes: 5 additions & 5 deletions whisperkit/android/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, filter_length=1024, hop_length=512, win_length=None, window='
np.imag(fourier_basis[:self.cutoff, :])])
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])

assert(filter_length >= self.win_length)
assert (filter_length >= self.win_length)
fft_window = get_window(window, self.win_length, fftbins=True)
fft_window = pad_center(fft_window, size=filter_length)
fft_window = torch.from_numpy(fft_window).float()
Expand Down Expand Up @@ -91,14 +91,14 @@ def __init__(self, n_mels=80, n_fft=400, hop_length=160):
)

self.stft = DecomposedSTFT(
filter_length=self.n_fft,
hop_length=self.hop_length,
filter_length=self.n_fft,
hop_length=self.hop_length,
win_length=self.n_fft,
window='hann'
)

def forward(self, audio: tt.WhisperMelSpectrogramInputType) -> tt.WhisperMelSpectrogramOutputType:

transformed = self.stft(audio)
magnitudes = transformed[..., :-1]
mel_spec = self.mel_filters @ magnitudes
Expand All @@ -118,7 +118,7 @@ class WhisperDecoderPostProc(nn.Module):
def forward(self, logits):
TOKEN_TIMESTAMP_BEGIN = 50363
TOKEN_NO_SPEECH = 50361

# logprobs = F.log_softmax(logits, dim=0)
logprobs = torch.log(F.softmax(logits, dim=0))
timestamp_logprob = torch.logsumexp(logprobs[TOKEN_TIMESTAMP_BEGIN:], dim=0)
Expand Down
Loading
Loading