Skip to content

Commit

Permalink
fix evaluate.py
Browse files Browse the repository at this point in the history
  • Loading branch information
awkrail committed Oct 7, 2024
1 parent e2cc675 commit 23a1c9f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ python training/train.py --model moment_detr --dataset tvsum --feature clip_slow
#### Evaluation
The evaluation command is:
```
python training/evaluate.py --model MODEL --dataset DATASET --feature FEATURE --split {val,test} --model_path MODEL_PATH --eval_path EVAL_PATH
python training/evaluate.py --model MODEL --dataset DATASET --feature FEATURE --split {val,test} --model_path MODEL_PATH --eval_path EVAL_PATH [--domain DOMAIN]
```
(**Example 1**) Evaluating Moment DETR w/ CLIP+Slowfast on the QVHighlights val set:
```
Expand Down
31 changes: 19 additions & 12 deletions training/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def start_inference(opt, domain=None):
logger.info("metrics_no_nms {}".format(pprint.pformat(metrics["brief"], indent=4)))


def check_valid_combination(dataset, feature):
def check_valid_combination(dataset, feature, domain):
dataset_feature_map = {
'qvhighlight': ['resnet_glove', 'clip', 'clip_slowfast', 'clip_slowfast_pann'],
'qvhighlight_pretrain': ['resnet_glove', 'clip', 'clip_slowfast', 'clip_slowfast_pann'],
Expand All @@ -403,7 +403,16 @@ def check_valid_combination(dataset, feature):
'youtube_highlight': ['clip', 'clip_slowfast'],
'clotho-moment': ['clap'],
}
return feature in dataset_feature_map[dataset]

domain_map = {
'tvsum': ['BK', 'BT', 'DS', 'FM', 'GA', 'MS', 'PK', 'PR', 'VT', 'VU'],
'youtube_highlight': ['dog', 'gymnastics', 'parkour', 'skating', 'skiing', 'surfing'],
}

if dataset in domain_map:
return feature in dataset_feature_map[dataset] and domain in domain_map[dataset]
else:
return feature in dataset_feature_map[dataset]


if __name__ == '__main__':
Expand All @@ -421,26 +430,24 @@ def check_valid_combination(dataset, feature):
parser.add_argument('--model_path', type=str, required=True, help='saved model path')
parser.add_argument('--split', type=str, required=True, choices=['val', 'test'], help='val or test')
parser.add_argument('--eval_path', type=str, required=True, help='evaluation data')
args = parser.parse_args()
parser.add_argument('--domain', '-dm', type=str,
choices=['BK', 'BT', 'DS', 'FM', 'GA', 'MS', 'PK', 'PR', 'VT', 'VU',
'dog', 'gymnastics', 'parkour', 'skating', 'skiing', 'surfing'],
help='domain for highlight detection dataset (e.g., BK for TVSum, dog for YouTube Highlight).')

is_valid = check_valid_combination(args.dataset, args.feature)
args = parser.parse_args()
is_valid = check_valid_combination(args.dataset, args.feature, args.domain)

if is_valid:
option_manager = BaseOptions(args.model, args.dataset, args.feature)
option_manager = BaseOptions(args.model, args.dataset, args.feature, args.domain)
option_manager.parse()
opt = option_manager.option
os.makedirs(opt.results_dir, exist_ok=True)

opt.model_path = args.model_path
opt.eval_split_name = args.split
opt.eval_path = args.eval_path

if 'domains' in opt:
for domain in opt.domains:
opt.results_dir = os.path.join(opt.results_dir, domain)
start_inference(opt, domain=domain)
else:
start_inference(opt)
start_inference(opt, domain=args.domain)

else:
raise ValueError('The combination of dataset and feature is invalid: dataset={}, feature={}'.format(args.dataset, args.feature))

0 comments on commit 23a1c9f

Please sign in to comment.