Skip to content

Commit

Permalink
Use Kaggle Models URL for downloading Perch. This allows loading Vers…
Browse files Browse the repository at this point in the history
…ion 8. Also adds a convenience method for loading the model from just a version number.

PiperOrigin-RevId: 622223336
  • Loading branch information
sdenton4 authored and copybara-github committed Apr 5, 2024
1 parent bf471da commit e27e953
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion chirp/inference/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
import tensorflow.compat.v1 as tf1
import tensorflow_hub as hub

PERCH_TF_HUB_URL = 'https://tfhub.dev/google/bird-vocalization-classifier'
PERCH_TF_HUB_URL = (
'https://www.kaggle.com/models/google/'
'bird-vocalization-classifier/frameworks/TensorFlow2/'
'variations/bird-vocalization-classifier/versions'
)


def model_class_map() -> dict[str, Any]:
Expand Down Expand Up @@ -287,6 +291,10 @@ def from_tfhub(cls, config: config_dict.ConfigDict) -> 'TaxonomyModelTF':
raise ValueError(
'Exactly one of tfhub_version and model_path should be set.'
)
if config.tfhub_version in (5, 6, 7):
# Due to SNAFUs uploading the new model version to KaggleModels,
# some version numbers were skipped.
raise ValueError('TFHub version 5, 6, and 7 do not exist.')

model_url = f'{PERCH_TF_HUB_URL}/{config.tfhub_version}'
# This model behaves exactly like the usual saved_model.
Expand All @@ -303,6 +311,20 @@ def from_tfhub(cls, config: config_dict.ConfigDict) -> 'TaxonomyModelTF':
model=model, class_list=class_lists, batchable=batchable, **config
)

@classmethod
def load_version(
cls, tfhub_version: int, hop_size_s: float = 5.0
) -> 'TaxonomyModelTF':
cfg = config_dict.ConfigDict({
'model_path': '',
'sample_rate': 32000,
'window_size_s': 5.0,
'hop_size_s': hop_size_s,
'target_peak': 0.25,
'tfhub_version': tfhub_version,
})
return cls.from_tfhub(cfg)

@classmethod
def from_config(cls, config: config_dict.ConfigDict) -> 'TaxonomyModelTF':
logging.info('Loading taxonomy model...')
Expand Down

0 comments on commit e27e953

Please sign in to comment.