-
Notifications
You must be signed in to change notification settings - Fork 0
/
args.py
26 lines (25 loc) · 2.05 KB
/
args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import argparse
def _str2bool(value):
    """Convert a command-line string such as "True", "false", "1" or "no" to a bool.

    ``argparse`` with ``type=bool`` evaluates ``bool(value)``, which is True for
    every non-empty string (including "False" and "0"), so boolean options need
    an explicit converter like this one.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean spelling
            (argparse turns this into a normal usage error).
    """
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True
    if lowered in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def parse_args(argv=None):
    """Build the training command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, i.e. the original behaviour.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--optim", default="bisam_log", type=str,
                        help="Choose optimizers from sam, bisam_log, bisam_tanh.")
    # A bare `--adaptive` means True; `--adaptive True/False` is also accepted.
    # (The previous `type=bool` treated the string "False" as True.)
    parser.add_argument("--adaptive", default=False, type=_str2bool, nargs="?", const=True,
                        help="True if you want to use the Adaptive SAM.")
    parser.add_argument("--batch_size", default=128, type=int,
                        help="Batch size used in the training and validation loop.")
    parser.add_argument("--depth", default=16, type=int, help="Number of layers.")
    parser.add_argument("--dropout", default=0.0, type=float, help="Dropout rate.")
    parser.add_argument("--epochs", default=200, type=int, help="Total number of epochs.")
    parser.add_argument("--label_smoothing", default=0.1, type=float,
                        help="Use 0.0 for no label smoothing.")
    parser.add_argument("--learning_rate", default=0.1, type=float,
                        help="Base learning rate at the start of the training.")
    parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum.")
    parser.add_argument("--threads", default=2, type=int,
                        help="Number of CPU threads for dataloaders.")
    parser.add_argument("--rho", default=0.05, type=float, help="Rho parameter for SAM.")
    parser.add_argument("--weight_decay", default=0.0005, type=float, help="L2 weight decay.")
    parser.add_argument("--width_factor", default=8, type=int,
                        help="How many times wider compared to normal ResNet.")
    parser.add_argument("--valid", action="store_true",
                        help="Use the validation dataset for finetuning.")
    parser.add_argument("--model", default="resnet56", type=str, help="Choose model.")
    parser.add_argument("--dataset", default="cifar10", type=str, help="Choose dataset.")
    # Raw strings: "\a" in a normal literal is the BEL control character and
    # "\m" is an invalid escape sequence.
    parser.add_argument("--mu", default=1.0, type=float, help=r"\mu for BiSAM.")
    parser.add_argument("--alpha", default=0.1, type=float, help=r"\alpha for BiSAM.")
    parser.add_argument("--adam", action="store_true",
                        help="Use Adam as the base optimizer instead of SGD.")
    parser.add_argument("--seed", default=42, type=int, help="Random seed.")
    return parser.parse_args(argv)