diff --git a/projects/benchmark/classification/benchmark.py b/projects/benchmark/classification/benchmark.py index 33d4b498..aa08ac4b 100644 --- a/projects/benchmark/classification/benchmark.py +++ b/projects/benchmark/classification/benchmark.py @@ -23,6 +23,7 @@ def param_count(model): def get_mean_std(mode="imagenet_default_mean_std"): + mode = mode.upper() if mode == "IMAGENET_DEFAULT_MEAN_STD": mean = (0.485, 0.456, 0.406) std = (0.229, 0.224, 0.225)