options.py
import argparse
import torch
import torch.nn as nn
parser = argparse.ArgumentParser()
parser.add_argument('--max_epoch', type=int, default=30)
parser.add_argument('--learning_rate', type=float, default=0.0001)
parser.add_argument('--D_learning_rate', type=float, default=0.0001)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--n_critic', type=int, default=5) # hyperparameter k in the paper
parser.add_argument('--random_seed', type=int, default=10)
parser.add_argument('--model_save_file', default='./save/model')
parser.add_argument('--init_save_file', default='./save/init_model')
# Enhancement1: domain label smoothing
parser.add_argument('--label_smooth', action='store_true', default=True, help='apply label smoothing to domain labels')
parser.add_argument('--eps', type=float, default=0.2, help='eps for domain label smooth')
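# Illustration only (not part of the original training code): one common way the
# --label_smooth / --eps options are realized is by spreading eps of the target
# probability mass uniformly over all domains. The helper below is a hypothetical
# sketch; its name and signature do not appear elsewhere in this repo.
def _smooth_domain_labels(one_hot, eps, num_domains):
    # e.g. eps=0.2, num_domains=4: a hard target of 1.0 becomes 0.85
    return (1.0 - eps) * one_hot + eps / num_domains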
parser.add_argument('--lambd_rplr', type=float, default=1, help='coefficient of the pseudo-label loss')
parser.add_argument('--lambd', type=float, default=0.0001)
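# Hypothetical sketch of how the two coefficients above are typically combined;
# the individual loss terms below are placeholders, not definitions from this repo:
#   total_loss = clf_loss + opt.lambd_rplr * pseudo_label_loss + opt.lambd * domain_adv_loss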
parser.add_argument('--dataset', default='prep-amazon') # prep-amazon or fdu-mtl
parser.add_argument('--prep_amazon_file', default='../data/prep-amazon/amazon.pkl')
parser.add_argument('--fdu_mtl_dir', default='fdu-mtl/') # default='../data/fdu-mtl/'
# pre-shuffle for cross validation data split
parser.add_argument('--use_preshuffle', dest='use_preshuffle', action='store_true', default=True)
parser.add_argument('--amazon_preshuffle_file', default='../data/prep-amazon/amazon-shuffle-indices.pkl')
# for preprocessed amazon dataset; set to -1 to use 30000
parser.add_argument('--feature_num', type=int, default=5000) # 5000d feature vector
# labeled domains: if not set, will use default domains for the dataset
parser.add_argument('--domains', type=str, nargs='+', default=[])
parser.add_argument('--unlabeled_domains', type=str, nargs='+', default=[])
parser.add_argument('--dev_domains', type=str, nargs='+', default=[])
parser.add_argument('--emb_filename', default='w2v/word2vec.txt')
parser.add_argument('--kfold', type=int, default=5) # cross-validation (n>=3)
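# Hypothetical illustration of --kfold combined with the pre-shuffled indices
# above: fold i serves as the dev split while the rest is used for training.
# The function name and its use of a plain index list are assumptions.
def _kfold_splits(indices, k):
    fold_size = len(indices) // k
    for i in range(k):
        dev = indices[i * fold_size:(i + 1) * fold_size]
        train = indices[:i * fold_size] + indices[(i + 1) * fold_size:]
        yield train, dev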
parser.add_argument('--max_seq_len', type=int, default=0) # set to <=0 to not truncate
# which data to be used as unlabeled data: train, unlabeled, or both
parser.add_argument('--unlabeled_data', type=str, default='both')
parser.add_argument('--test_only', dest='test_only', action='store_true')
parser.add_argument('--fix_emb', action='store_true', default=False)
parser.add_argument('--random_emb', action='store_true', default=False)
# Feature extractor model: dan, lstm, or cnn (for the FDU-MTL dataset), or mlp (for the prep-amazon dataset)
parser.add_argument('--model', default='mlp')
# for LSTM model
parser.add_argument('--attn', default='dot') # attention mechanism (for LSTM): avg, last, dot
parser.add_argument('--bdrnn', dest='bdrnn', action='store_true', default=True)  # bi-directional LSTM
parser.add_argument('--no_bdrnn', dest='bdrnn', action='store_false')  # disable bi-directional LSTM
# for DAN model
parser.add_argument('--sum_pooling', dest='sum_pooling', action='store_true')
parser.add_argument('--avg_pooling', dest='sum_pooling', action='store_false')
# for CNN model
parser.add_argument('--kernel_num', type=int, default=200)
parser.add_argument('--kernel_sizes', type=int, nargs='+', default=[3,4,5])
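# Illustrative sketch (an assumed architecture, not copied from this repo) of how
# --kernel_num and --kernel_sizes usually parameterize a text CNN: one Conv1d per
# kernel size, max-pooled over time and concatenated.
class _TextCNNSketch(nn.Module):
    def __init__(self, emb_dim, kernel_num, kernel_sizes):
        super().__init__()
        self.convs = nn.ModuleList(
            nn.Conv1d(emb_dim, kernel_num, k) for k in kernel_sizes)

    def forward(self, x):  # x: (batch, emb_dim, seq_len)
        feats = [conv(x).max(dim=2).values for conv in self.convs]
        return torch.cat(feats, dim=1)  # (batch, kernel_num * len(kernel_sizes))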
# for MLP model
# F size: feature_num -> F_hidden_sizes -> shared/domain_hidden_size
parser.add_argument('--F_hidden_sizes', type=int, nargs='+', default=[1000, 500])
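# Hypothetical sketch of the MLP feature extractor implied by the comment above:
# feature_num -> F_hidden_sizes -> shared/domain_hidden_size, with the configured
# activation between layers. Defined here only to illustrate the size chain.
def _build_mlp(in_size, hidden_sizes, out_size, act):
    layers, sizes = [], [in_size] + hidden_sizes + [out_size]
    for a, b in zip(sizes[:-1], sizes[1:]):
        layers += [nn.Linear(a, b), act]
    return nn.Sequential(*layers)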
# gr (gradient reversal; the NLL loss in the paper), bs (boundary seeking), or l2
# the BS loss is not discussed in the paper, but it is nearly equivalent to the GR (NLL) loss
# parser.add_argument('--loss', default='gr')
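# Sketch of the standard gradient-reversal layer behind the 'gr' option; this
# class is illustrative and is not referenced by the rest of this file.
class _GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        # identity in the forward pass
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # reverse (and scale) the gradient in the backward pass
        return grad_output.neg() * ctx.lambd, None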
parser.add_argument('--shared_hidden_size', type=int, default=128)
parser.add_argument('--domain_hidden_size', type=int, default=64)
parser.add_argument('--activation', default='relu') # relu, leaky
parser.add_argument('--F_layers', type=int, default=1)
parser.add_argument('--C_layers', type=int, default=1)
parser.add_argument('--D_layers', type=int, default=1)
parser.add_argument('--wgan_trick', dest='wgan_trick', action='store_true', default=True)
parser.add_argument('--no_wgan_trick', dest='wgan_trick', action='store_false')
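# Hypothetical loop fragment showing the usual WGAN-style schedule implied by
# --n_critic and --wgan_trick: k discriminator updates per feature-extractor
# update. d_step and f_step are placeholder callables, not names from this repo.
def _adversarial_epoch(loader, d_step, f_step, n_critic):
    for batch in loader:
        for _ in range(n_critic):
            d_step(batch)
        f_step(batch)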
# batch normalization
parser.add_argument('--F_bn', dest='F_bn', action='store_true', default=False)
parser.add_argument('--no_F_bn', dest='F_bn', action='store_false')
parser.add_argument('--C_bn', dest='C_bn', action='store_true', default=True)
parser.add_argument('--no_C_bn', dest='C_bn', action='store_false')
parser.add_argument('--D_bn', dest='D_bn', action='store_true', default=True)
parser.add_argument('--no_D_bn', dest='D_bn', action='store_false')
parser.add_argument('--dropout', type=float, default=0.4)
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--debug', action='store_true')
opt = parser.parse_args()
# automatically prepared options
if not torch.cuda.is_available():
    opt.device = 'cpu'
if len(opt.domains) == 0:
    # use default domains
    if opt.dataset.lower() == 'prep-amazon':
        opt.domains = ['books', 'dvd', 'electronics', 'kitchen']
    elif opt.dataset.lower() == 'fdu-mtl':
        opt.domains = ['books', 'electronics', 'dvd', 'kitchen_housewares', 'apparel', 'camera_photo', 'health_personal_care', 'music', 'toys_games', 'video', 'baby', 'magazines', 'software', 'sports_outdoors', 'imdb', 'MR']
    else:
        raise Exception(f'Unknown dataset {opt.dataset}')
opt.all_domains = opt.domains + opt.unlabeled_domains
if len(opt.dev_domains) == 0:
    opt.dev_domains = opt.all_domains
opt.max_kernel_size = max(opt.kernel_sizes)
if opt.activation.lower() == 'relu':
    opt.act_unit = nn.ReLU()
elif opt.activation.lower() == 'leaky':
    opt.act_unit = nn.LeakyReLU()
else:
    raise Exception(f'Unknown activation function {opt.activation}')