Skip to content

Commit

Permalink
Fixed dealing with model files not available in package.
Browse files Browse the repository at this point in the history
More model tests.
  • Loading branch information
hendriks73 committed Oct 16, 2024
1 parent 5f01036 commit 50e0b7d
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 5 deletions.
11 changes: 9 additions & 2 deletions tempocnn/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,15 @@ def _extract_from_package(resource):
# ensure cache path exists
cache_path.parent.mkdir(parents=True, exist_ok=True)

data = pkgutil.get_data("tempocnn", resource)
if not data:
data = None
try:
data = pkgutil.get_data("tempocnn", resource)
except FileNotFoundError:
pass

# fallback
if data is None:
logger.info(f"Resource {resource} not found in package")
data = _load_model_from_github(resource)

# write to cache
Expand Down
43 changes: 40 additions & 3 deletions test/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,50 @@ def test_mirex_with_real_data_interpolate(test_track):


@pytest.mark.parametrize(
"model_name", ["cnn", "ismir2018", "fcn", "deeptemp", "shallowtemp", "deepsquare"]
"model_name,tempo",
[
("cnn", 100.0),
("deepsquare_k1", 102.0),
("deepsquare_k2", 100.0),
("deepsquare_k4", 100.0),
("deepsquare_k8", 100.0),
("deepsquare_k16", 100.0),
("deepsquare_k24", 100.0),
("deeptemp_k2", 100.0),
("deeptemp_k4", 100.0),
("deeptemp_k8", 100.0),
("deeptemp_k16", 100.0),
("deeptemp_k24", 100.0),
# dt_maz models are not meant for drumtracks...
("dt_maz_m_fold0", 160.0),
("dt_maz_m_fold1", 93.0),
("dt_maz_m_fold2", 73.0),
("dt_maz_m_fold3", 156.0),
("dt_maz_m_fold4", 83.0),
("dt_maz_v_fold0", 160.0),
("dt_maz_v_fold1", 93.0),
("dt_maz_v_fold2", 73.0),
("dt_maz_v_fold3", 156.0),
("dt_maz_v_fold4", 83.0),
("fcn", 100.0),
("fma2018", 100.0),
("ismir2018", 100.0),
("shallowtemp_k1", 100.0),
("shallowtemp_k2", 100.0),
("shallowtemp_k4", 100.0),
("shallowtemp_k6", 100.0),
("shallowtemp_k8", 100.0),
("shallowtemp_k12", 100.0),
("deeptemp", 100.0),
("shallowtemp", 100.0),
("deepsquare", 100.0),
],
)
def test_tempo_with_real_data(model_name, test_track):
def test_tempo_with_real_data(model_name, tempo, test_track):
features = read_features(test_track)
tempo_classifier = TempoClassifier(model_name)
tempo = tempo_classifier.estimate_tempo(features)
assert tempo == pytest.approx(100.0)
assert tempo == pytest.approx(tempo)


def test_quad_interpol_argmax():
Expand Down

0 comments on commit 50e0b7d

Please sign in to comment.