From 50e0b7d211be7e5a352c5e33f2a50d1e03766527 Mon Sep 17 00:00:00 2001 From: Hendrik Schreiber Date: Wed, 16 Oct 2024 11:13:52 +0200 Subject: [PATCH] Fixed dealing with model files not available in package. More model tests. --- tempocnn/classifier.py | 11 +++++++++-- test/test_classifier.py | 43 ++++++++++++++++++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/tempocnn/classifier.py b/tempocnn/classifier.py index 6099696..4e63fc0 100644 --- a/tempocnn/classifier.py +++ b/tempocnn/classifier.py @@ -308,8 +308,15 @@ def _extract_from_package(resource): # ensure cache path exists cache_path.parent.mkdir(parents=True, exist_ok=True) - data = pkgutil.get_data("tempocnn", resource) - if not data: + data = None + try: + data = pkgutil.get_data("tempocnn", resource) + except FileNotFoundError: + pass + + # fallback + if data is None: + logger.info(f"Resource {resource} not found in package") data = _load_model_from_github(resource) # write to cache diff --git a/test/test_classifier.py b/test/test_classifier.py index 8c5f34a..83a561a 100644 --- a/test/test_classifier.py +++ b/test/test_classifier.py @@ -103,13 +103,50 @@ def test_mirex_with_real_data_interpolate(test_track): @pytest.mark.parametrize( - "model_name", ["cnn", "ismir2018", "fcn", "deeptemp", "shallowtemp", "deepsquare"] + "model_name,tempo", + [ + ("cnn", 100.0), + ("deepsquare_k1", 102.0), + ("deepsquare_k2", 100.0), + ("deepsquare_k4", 100.0), + ("deepsquare_k8", 100.0), + ("deepsquare_k16", 100.0), + ("deepsquare_k24", 100.0), + ("deeptemp_k2", 100.0), + ("deeptemp_k4", 100.0), + ("deeptemp_k8", 100.0), + ("deeptemp_k16", 100.0), + ("deeptemp_k24", 100.0), + # dt_maz models are not meant for drumtracks... + ("dt_maz_m_fold0", 160.0), + ("dt_maz_m_fold1", 93.0), + ("dt_maz_m_fold2", 73.0), + ("dt_maz_m_fold3", 156.0), + ("dt_maz_m_fold4", 83.0), + ("dt_maz_v_fold0", 160.0), + ("dt_maz_v_fold1", 93.0), + ("dt_maz_v_fold2", 73.0), + ("dt_maz_v_fold3", 156.0), + ("dt_maz_v_fold4", 83.0), + ("fcn", 100.0), + ("fma2018", 100.0), + ("ismir2018", 100.0), + ("shallowtemp_k1", 100.0), + ("shallowtemp_k2", 100.0), + ("shallowtemp_k4", 100.0), + ("shallowtemp_k6", 100.0), + ("shallowtemp_k8", 100.0), + ("shallowtemp_k12", 100.0), + ("deeptemp", 100.0), + ("shallowtemp", 100.0), + ("deepsquare", 100.0), + ], ) -def test_tempo_with_real_data(model_name, test_track): +def test_tempo_with_real_data(model_name, tempo, test_track): features = read_features(test_track) tempo_classifier = TempoClassifier(model_name) tempo = tempo_classifier.estimate_tempo(features) - assert tempo == pytest.approx(100.0) + assert tempo == pytest.approx(tempo) def test_quad_interpol_argmax():