-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_classifier.py
35 lines (24 loc) · 966 Bytes
/
train_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from __future__ import print_function
import matplotlib
from one_shot_aug import PretrainedClassifier
matplotlib.use('Agg')
import argparse
# import matplotlib.pyplot as plt
from data_loader import *
def parse_args(argv=None):
    """Parse the command line for this training script.

    Args:
        argv: Argument list without the program name; ``None`` means
            ``sys.argv[1:]`` (argparse default).

    Returns:
        argparse.Namespace with ``config``, ``mode``, ``seed`` and ``output``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--config', type=str, default='config', help='config file name in config dir')
    # NOTE: only 'train' is implemented by main(); other values trigger a warning.
    parser.add_argument('--mode', type=str, default='train', help='Mode (train/test/train_and_evaluate)')
    parser.add_argument('--seed', type=int, default=0, help='seed for the random/numpy/torch RNGs')
    parser.add_argument('--output', type=str, default='fine_tuned_best_model.pt',
                        help='path the trained model weights are written to')
    return parser.parse_args(argv)


def seed_everything(seed):
    """Seed every RNG this script may depend on so runs are reproducible.

    Python's ``random`` is always seeded. NumPy and PyTorch are seeded only
    if they can be imported; neither is imported directly by this file.

    Args:
        seed: Integer seed applied to each RNG.
    """
    import random

    random.seed(seed)
    try:
        import numpy as np
    except ImportError:
        pass  # best effort: NumPy is optional here
    else:
        np.random.seed(seed)
    try:
        import torch
    except ImportError:
        pass  # best effort: PyTorch is optional here
    else:
        torch.manual_seed(seed)


def save_model(model, path):
    """Write ``model``'s weights to ``path``.

    Uses the model's own ``save_state_dict(path)`` when it provides one (the
    original behavior). Otherwise it falls back to
    ``torch.save(model.state_dict(), path)``.

    NOTE(review): ``torch.nn.Module`` has no ``save_state_dict`` method, so
    the fallback avoids an AttributeError after training if ``model`` is a
    plain module. Confirm what ``train_model()`` actually returns.

    Args:
        model: Trained model object.
        path: Destination file path.
    """
    save = getattr(model, 'save_state_dict', None)
    if callable(save):
        save(path)
        return
    import torch

    torch.save(model.state_dict(), path)


def main(argv=None):
    """Load the config, fine-tune the pretrained classifier and save it.

    Args:
        argv: Argument list without the program name; ``None`` means
            ``sys.argv[1:]``.
    """
    import sys

    args = parse_args(argv)

    if args.mode != 'train':
        # --mode is advertised in the CLI help, but only training is
        # implemented; say so instead of silently ignoring the flag.
        print(f"Warning: --mode={args.mode!r} is not implemented; running 'train'.",
              file=sys.stderr)

    seed_everything(args.seed)

    # Load the config file named by --config (from the config dir, per the
    # --config help). Config comes from `from data_loader import *`; calling
    # it presumably populates the shared Config object -- TODO confirm.
    Config(args.config)
    print("Config: ", Config)
    if Config.get("description", None):
        print("Config Description")
        for key, value in Config.description.items():
            print(f" - {key}: {value}")

    classifier = PretrainedClassifier()
    trained = classifier.train_model()
    save_model(trained.model, args.output)


if __name__ == '__main__':
    main()