-- composer.lua
require 'optim'
-- class for training a composition model
-- code based on the Torch supervised learning tutorial (https://github.com/torch/tutorials/tree/master/2_supervised)
local Composer = torch.class('nn.Composer')

function Composer:__init(m, config)
  self.module = m
  if config.criterion == 'abs' then
    self.criterion = nn.AbsCriterion()
  elseif config.criterion == 'mse' then
    self.criterion = nn.MSECriterion()
  else
    error("Unknown criterion.")
  end
  self.config = config
  print(self.config)
end
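
-- A minimal sketch of the config table this class reads (the field names are
-- the ones used below; the concrete values are illustrative assumptions, not
-- defaults shipped with this code):
--
--   local config = {
--     criterion      = 'mse',                  -- 'abs' or 'mse'
--     optimizer      = 'adagrad',              -- the only method handled in train()
--     adagrad_config = {learningRate = 1e-2},  -- passed straight to optim.adagrad
--     batchSize      = 32,
--     earlyStopping  = true,                   -- stop when the dev error stops improving
--     extraEpochs    = 5,                      -- patience (epochs) after the best model
--   }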
function Composer:train(trainDataset, devDataset)
  local parameters, gradParameters = self.module:getParameters()
  self.bestModel = nil
  self.bestError = 2^20
  self.extraIndex = 0  -- patience counter for early stopping
  print("# Composer: training")
  -- epoch counter shared with the closures below (kept local instead of the
  -- global used in the original tutorial code)
  local epoch = 1
  local function doTrain(module, criterion, config, trainDataset)
    module:training()
    print("# Epoch ", epoch)
    local trainError = 0
    -- shuffle indices; this means that each pass sees the training examples in a different order
    local shuffledIndices = torch.randperm(trainDataset:size(), 'torch.LongTensor')
    -- go through the dataset
    for t = 1, trainDataset:size(), config.batchSize do
      -- create mini-batch
      local inputs = {}
      local targets = {}
      for i = t, math.min(t + config.batchSize - 1, trainDataset:size()) do
        -- load new sample
        local example = trainDataset[shuffledIndices[i]]
        local input = example[1]
        local target = example[2]
        table.insert(inputs, input)
        table.insert(targets, target)
      end
      -- create closure to evaluate the function and its derivative on the mini-batch
      local feval = function(x)
        -- just in case:
        collectgarbage()
        -- get new parameters
        if x ~= parameters then
          parameters:copy(x)
        end
        -- reset gradients
        gradParameters:zero()
        -- f is the average of all criterions (f is the average error on the mini-batch)
        local f = 0
        -- evaluate on mini-batch
        for i = 1, #inputs do
          -- estimate f for the current input
          local prediction = module:forward(inputs[i])
          -- compute the error between the prediction and the (correct) target using the criterion
          local err = criterion:forward(prediction, targets[i])
          -- add the error to f
          f = f + err
          -- estimate df/dW
          local df_do = criterion:backward(prediction, targets[i])
          module:backward(inputs[i], df_do)
        end
        -- normalize gradients and f(X)
        gradParameters:div(#inputs)
        f = f / #inputs
        trainError = trainError + f
        return f, gradParameters
      end
      -- optimize the current mini-batch
      if config.optimizer == 'adagrad' then
        optim.adagrad(feval, parameters, config.adagrad_config)
      else
        error('unknown optimization method')
      end
    end
    -- average training error over the mini-batches (ceil counts the possibly
    -- smaller final batch; floor would miscount and can divide by zero)
    trainError = trainError / math.ceil(trainDataset:size() / config.batchSize)
    -- next epoch
    epoch = epoch + 1
    return trainError
  end
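
  -- Note on feval above: optim.adagrad calls it with the current parameter
  -- vector and expects back the loss f(x) and the gradient df/dx; the
  -- parameter update itself is performed inside optim.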
  local function doTest(module, criterion, dataset)
    module:evaluate()
    local testError = 0
    for t = 1, dataset:size() do
      local input = dataset[t][1]
      local target = dataset[t][2]
      -- test sample
      local pred = module:forward(input)
      -- compute error
      local err = criterion:forward(pred, target)
      testError = testError + err
    end
    -- average error over the dataset
    testError = testError / dataset:size()
    return testError
  end
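
  -- Note: module:training() / module:evaluate() switch the train-time
  -- behaviour of modules such as nn.Dropout, so evaluation is deterministic.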
  -- alternate training epochs with dev-set evaluation; note that this loop
  -- only terminates via early stopping, so config.earlyStopping should be set
  while true do
    local trainErr = doTrain(self.module, self.criterion, self.config, trainDataset)
    local testErr = doTest(self.module, self.criterion, devDataset)
    print('Train error:\t', string.format("%.6f", trainErr))
    print('Test error:\t', string.format("%.6f", testErr))
    print('Best error:\t', string.format("%.6f", self.bestError))
    -- early stopping when the error on the dev set ceases to decrease
    if self.config.earlyStopping == true then
      if testErr < self.bestError then
        self.bestError = testErr
        self.bestModel = self.module:clone()
        self.extraIndex = 1
      else
        if self.extraIndex < self.config.extraEpochs then
          self.extraIndex = self.extraIndex + 1
        else
          print("# Composer: stopping - reached the maximum number of epochs after the best model")
          print("# Composer: best error: " .. self.bestError)
          self.module = self.bestModel:clone()
          break
        end
      end
    end
  end
end
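
-- A minimal usage sketch (an assumption for illustration, not part of this
-- file): the toy model, dataset constructor, and hyperparameter values below
-- are hypothetical. A dataset only needs a :size() method and integer
-- indexing that returns {input, target} pairs, which is all Composer uses.
--
--   require 'nn'
--   dofile 'composer.lua'
--
--   -- toy composition model: concatenate two 5-d vectors, map to a 5-d output
--   local model = nn.Sequential()
--   model:add(nn.JoinTable(1))
--   model:add(nn.Linear(10, 5))
--
--   local function makeDataset(n)
--     local d = {}
--     function d:size() return n end
--     for i = 1, n do
--       d[i] = {{torch.randn(5), torch.randn(5)}, torch.randn(5)}
--     end
--     return d
--   end
--
--   local composer = nn.Composer(model, {
--     criterion      = 'mse',
--     optimizer      = 'adagrad',
--     adagrad_config = {learningRate = 1e-2},
--     batchSize      = 16,
--     earlyStopping  = true,
--     extraEpochs    = 3,
--   })
--   composer:train(makeDataset(200), makeDataset(50))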