Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add list_benchmarks to compare models easily #1253

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
torch>=1.4.0
torchvision>=0.5.0
pyyaml
pandas
26 changes: 23 additions & 3 deletions timm/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import fnmatch
from collections import defaultdict
from copy import deepcopy
import pandas as pd

__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'is_pretrained_cfg_key', 'has_pretrained_cfg_key', 'get_pretrained_cfg_value', 'is_model_pretrained']
Expand Down Expand Up @@ -66,8 +67,8 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)

Example:
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
list_models('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
"""
if module:
all_models = list(_module_to_models[module])
Expand Down Expand Up @@ -96,6 +97,25 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
return list(sorted(models, key=_natural_key))


def list_benchmarks(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
""" Return list of available benchmarks on imagenet, sorted alphabetically

Args:
filter (str) - Wildcard filter string that works with fnmatch
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
pretrained (bool) - Include only models with pretrained weights if True
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)

Example:
list_benchmarks('gluon_resnet*') -- returns a pandas dataframe with all the benchmarks starting with 'gluon_resnet'
"""
models = list_models(filter=filter, module=module, pretrained=pretrained, exclude_filters=exclude_filters, name_matches_cfg=name_matches_cfg)
df = pd.read_csv("https://raw.githubusercontent.com/rwightman/pytorch-image-models/master/results/results-imagenet.csv")

return df[df["model"].isin(models)]


def is_model(model_name):
""" Check if a model name exists
"""
Expand Down Expand Up @@ -156,4 +176,4 @@ def get_pretrained_cfg_value(model_name, cfg_key):
"""
if model_name in _model_pretrained_cfgs:
return _model_pretrained_cfgs[model_name].get(cfg_key, None)
return None
return None