forked from Shawn1993/cnn-text-classification-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
executable file
·117 lines (100 loc) · 5.06 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#! /usr/bin/env python
import os
import argparse
import datetime
import torch
import torchtext.data as data
import torchtext.datasets as datasets
import model
import train
import mydatasets
# Command-line interface. Options use single-dash long names (e.g. -batch-size);
# argparse exposes them as underscore attributes on `args` (args.batch_size).
# Note: parsing happens at import time, so this module is meant to be run as a script.
parser = argparse.ArgumentParser(description='CNN text classifier')
# learning
parser.add_argument('-lr', type=float, default=0.001, help='initial learning rate [default: 0.001]')
parser.add_argument('-epochs', type=int, default=256, help='number of epochs for train [default: 256]')
parser.add_argument('-batch-size', type=int, default=64, help='batch size for training [default: 64]')
parser.add_argument('-log-interval', type=int, default=1, help='how many steps to wait before logging training status [default: 1]')
parser.add_argument('-test-interval', type=int, default=100, help='how many steps to wait before testing [default: 100]')
parser.add_argument('-save-interval', type=int, default=500, help='how many steps to wait before saving [default: 500]')
parser.add_argument('-save-dir', type=str, default='snapshot', help='where to save the snapshot [default: snapshot]')
# data
parser.add_argument('-shuffle', action='store_true', default=False, help='shuffle the data every epoch')
# model
parser.add_argument('-dropout', type=float, default=0.5, help='the probability for dropout [default: 0.5]')
parser.add_argument('-max-norm', type=float, default=3.0, help='l2 constraint of parameters [default: 3.0]')
parser.add_argument('-embed-dim', type=int, default=128, help='number of embedding dimension [default: 128]')
parser.add_argument('-kernel-num', type=int, default=100, help='number of each kind of kernel [default: 100]')
# Parsed later into a list of ints by splitting on commas.
parser.add_argument('-kernel-sizes', type=str, default='3,4,5', help='comma-separated kernel size to use for convolution [default: 3,4,5]')
parser.add_argument('-static', action='store_true', default=False, help='fix the embedding')
# device
parser.add_argument('-device', type=int, default=-1, help='device to use for iterate data, -1 means cpu [default: -1]')
parser.add_argument('-no-cuda', action='store_true', default=False, help='disable the gpu')
# option
parser.add_argument('-snapshot', type=str, default=None, help='filename of model snapshot [default: None]')
parser.add_argument('-predict', type=str, default=None, help='predict the sentence given')
parser.add_argument('-test', action='store_true', default=False, help='evaluate on the test set instead of training')
args = parser.parse_args()
# load SST dataset
def sst(text_field, label_field, **kwargs):
    """Load the fine-grained SST splits and wrap them in bucket iterators.

    The vocabularies of ``text_field`` and ``label_field`` are built in place
    from the train, dev and test splits together. The training iterator uses
    ``args.batch_size``; the dev and test iterators each serve their whole
    split as one batch. Any extra keyword arguments are forwarded unchanged
    to ``data.BucketIterator.splits``.

    Returns:
        ``(train_iter, dev_iter, test_iter)``.
    """
    train_split, dev_split, test_split = datasets.SST.splits(
        text_field, label_field, fine_grained=True)
    splits = (train_split, dev_split, test_split)
    # NOTE(review): the text vocabulary also sees dev/test tokens -- confirm intended.
    for field in (text_field, label_field):
        field.build_vocab(*splits)
    sizes = (args.batch_size, len(dev_split), len(test_split))
    train_iter, dev_iter, test_iter = data.BucketIterator.splits(
        splits, batch_sizes=sizes, **kwargs)
    return train_iter, dev_iter, test_iter
# load MR dataset
def mr(text_field, label_field, **kwargs):
    """Load the MR train/dev splits and wrap them in plain iterators.

    The vocabularies of ``text_field`` and ``label_field`` are built in place
    from the train and dev splits together. The training iterator uses
    ``args.batch_size``; the dev iterator serves the whole dev split as one
    batch. Extra keyword arguments are forwarded to ``data.Iterator.splits``.

    Returns:
        ``(train_iter, dev_iter)``.
    """
    train_split, dev_split = mydatasets.MR.splits(text_field, label_field)
    for field in (text_field, label_field):
        field.build_vocab(train_split, dev_split)
    sizes = (args.batch_size, len(dev_split))
    train_iter, dev_iter = data.Iterator.splits(
        (train_split, dev_split), batch_sizes=sizes, **kwargs)
    return train_iter, dev_iter
# load data
print("\nLoading data...")
text_field = data.Field(lower=True)
label_field = data.Field(sequential=False)
# Honour the -device option for batch placement instead of hard-coding the CPU.
# Fall back to the CPU (-1) whenever CUDA is disabled (-no-cuda) or unavailable,
# so the default invocation behaves exactly as before.
_data_on_gpu = (not args.no_cuda) and torch.cuda.is_available()
_data_device = args.device if _data_on_gpu else -1
train_iter, dev_iter = mr(text_field, label_field, device=_data_device, repeat=False)
# To use SST (which also provides a test split for -test), use this line instead:
#train_iter, dev_iter, test_iter = sst(text_field, label_field, device=_data_device, repeat=False)
# update args and print
# Derive model hyper-parameters from the vocabularies just built.
args.embed_num = len(text_field.vocab)
# presumably "- 1" discounts a special token (e.g. <unk>) in the label vocab -- confirm
args.class_num = len(label_field.vocab) - 1
args.cuda = not args.no_cuda and torch.cuda.is_available()
delattr(args, 'no_cuda')  # superseded by args.cuda
args.kernel_sizes = list(map(int, args.kernel_sizes.split(',')))
# Each run saves into its own timestamped sub-directory of -save-dir.
_run_stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
args.save_dir = os.path.join(args.save_dir, _run_stamp)
print("\nParameters:")
for _name in sorted(vars(args)):
    print("\t{}={}".format(_name.upper(), getattr(args, _name)))
# model
if args.snapshot is None:
    cnn = model.CNN_Text(args)
else:
    print('\nLoading model from [%s]...' % args.snapshot)
    # NOTE(security): torch.load unpickles the stored object, which can run
    # arbitrary code -- only load snapshots from trusted sources.
    try:
        cnn = torch.load(args.snapshot)
    except OSError:
        # Only a missing/unreadable file maps to this message; other failures
        # (corrupt or incompatible snapshot) keep their own traceback.
        print("Sorry, This snapshot doesn't exist.")
        raise SystemExit(1)
if args.cuda:
    cnn = cnn.cuda()
# train or predict
if args.predict is not None:
    # Classify a single sentence given on the command line.
    label = train.predict(args.predict, cnn, text_field, label_field, args.cuda)
    print('\n[Text] {}\n[Label] {}\n'.format(args.predict, label))
elif args.test:
    # test_iter is bound only when the data was loaded with sst(); mr() returns
    # just train/dev iterators. Check for it explicitly instead of wrapping
    # train.eval in a catch-all that would also hide genuine evaluation errors.
    if 'test_iter' in globals():
        train.eval(test_iter, cnn, args)
    else:
        print("\nSorry. The test dataset doesn't exist.\n")
else:
    print()
    try:
        train.train(train_iter, dev_iter, cnn, args)
    except KeyboardInterrupt:
        # Ctrl-C ends training with a short message instead of a traceback.
        print('-' * 89)
        print('Exiting from training early')