-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
31 lines (24 loc) · 1.48 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# -*- coding:utf-8 -*-
from argparse import ArgumentParser
def get_args():
parser = ArgumentParser(description='Hierarchical Math Solver')
parser.add_argument('--cuda', type=str, dest='cuda_id', default=None)
parser.add_argument('--checkpoint', type=str, dest='checkpoint', default=None)
parser.add_argument('--resume', action='store_true', dest='resume', default=False)
parser.add_argument('--log', type=str, dest='log', default=None)
parser.add_argument('--test-log', type=str, dest='test_log', default=None)
parser.add_argument('--seed', type=int, dest='seed', default=10)
parser.add_argument('--run-flag', type=str, dest='run_flag',default='train')
parser.add_argument('--epoch', type=int, dest='epoch', default=80)
parser.add_argument('--batch', type=int, dest='batch', default=64)
parser.add_argument('--lr', type=float, dest='lr', default=1e-3)
parser.add_argument('--weight-decay', type=float, dest='weight_decay', default=1e-5)
parser.add_argument('--step', type=int, dest='step', default=20)
parser.add_argument('--gamma', type=float, dest='gamma', default=0.5)
parser.add_argument('--beam', type=int, dest='beam', default=1)
parser.add_argument('--embed', type=int, dest='embed', default=128)
parser.add_argument('--hidden', type=int, dest='hidden', default=512)
parser.add_argument('--dropout', type=float, dest='dropout', default=0.5)
args = parser.parse_args()
args.use_cuda = args.cuda_id is not None
return args