-------------------------------
require 'torch'
require 'xlua'
require 'optim'
-------------------------------
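-- NOTE: this module assumes the caller defines the globals `model`,
-- `criterion`, `opt` (command-line options) and the flattened
-- `parameters`/`gradParameters` (typically obtained via
-- `parameters, gradParameters = model:getParameters()`) before use.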
local trainer = {}

-- Constructor
function trainer:new(data_type, clipping)
   -- CUDA?
   if data_type == 'cuda' then
      model:cuda()
      criterion:cuda()
      -- Log results to files
      trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))
      testLogger = optim.Logger(paths.concat(opt.save, 'test.log'))
   end
   self.clipping = clipping
   -- Log gradient norm to file
   gradLogger = optim.Logger(paths.concat(opt.save, 'grad.log'))
end
function trainer:set_optimizer(optimization)
   if optimization == 'SGD' then
      self.optim_state = {
         learningRate = opt.learningRate,
         weightDecay = opt.weightDecay,
         momentum = opt.momentum,
         learningRateDecay = 1e-7
      }
      self.optim_method = optim.sgd
   elseif optimization == 'ADAM' then
      self.optim_state = {
         learningRate = opt.learningRate
      }
      self.optim_method = optim.adam
   elseif optimization == 'ADAGRAD' then
      self.optim_state = {
         learningRate = opt.learningRate
      }
      self.optim_method = optim.adagrad
   elseif optimization == 'ADADELTA' then
      self.optim_state = {
         learningRate = opt.learningRate
      }
      self.optim_method = optim.adadelta
   elseif optimization == 'RMSPROP' then
      self.optim_state = {
         learningRate = opt.learningRate
      }
      self.optim_method = optim.rmsprop
   else
      error('unknown optimization method: ' .. tostring(optimization))
   end
end
function trainer:train(train_x, train_y)
   -- local vars
   local time = sys.clock()

   -- create closure to evaluate f(X) and df/dX
   local feval = function(x)
      -- reset gradients
      gradParameters:zero()

      -- f is the average of all criterions
      local f = 0

      -- estimate f
      local output = model:forward(train_x)
      local err = criterion:forward(output, train_y)
      f = f + err

      -- estimate df/dW
      local df_do = criterion:backward(output, train_y)
      model:backward(train_x, df_do)

      -- gradient clipping
      gradParameters:clamp(-self.clipping, self.clipping)

      -- normalize gradients and f(X)
      gradParameters:div(#train_x)
      f = f / #train_x

      -- update logger/plot: tracking the gradients
      gradLogger:add{['% grad norm (train set)'] = torch.norm(gradParameters)}

      -- return f and df/dX
      return f, gradParameters
   end

   -- optimize on current mini-batch
   self.optim_method(feval, parameters, self.optim_state)

   -- time taken
   time = sys.clock() - time
   time = time / #train_x
   return time
end
return trainer
-------------------------------
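-- Example usage (a minimal sketch: the model, criterion and `opt` fields
-- below are illustrative placeholders, and `train_x`/`train_y` are assumed
-- to be prepared elsewhere in the project):
--
--   require 'nn'
--   local trainer = require 'train'
--
--   opt = {save = 'results', learningRate = 1e-3, weightDecay = 0,
--          momentum = 0.9, optimization = 'ADAM'}
--   model = nn.Sequential():add(nn.Linear(10, 2))
--   criterion = nn.MSECriterion()
--   parameters, gradParameters = model:getParameters()
--
--   trainer:new('float', 5)                   -- data type, gradient clipping
--   trainer:set_optimizer(opt.optimization)
--   local time_per_sample = trainer:train(train_x, train_y)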