-
Notifications
You must be signed in to change notification settings - Fork 57
/
laia-decode
executable file
·116 lines (104 loc) · 3.3 KB
/
laia-decode
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env th
require 'laia'
-- The batcher is created before the options are parsed, so that it can
-- register its own command-line options on the parser below.
local batcher = laia.RandomBatcher()

-- Command-line interface of the tool. The description is shown in the help
-- message, so it summarizes what the script does with its two arguments.
local parser = laia.argparse(){
  name = 'laia-decode',
  description = 'Decode the images listed in a file using the best model ' ..
    'stored in a checkpoint. For each image, prints one line with the ' ..
    'image ID followed by the framewise-decoded symbols.'
}

-- Register laia.Version options
laia.Version():registerOptions(parser)
-- Register laia.log options.
laia.log.registerOptions(parser)
-- Register cudnn options, only if available.
if cudnn then cudnn.registerOptions(parser, true) end
-- Register batcher options.
batcher:registerOptions(parser)

-- Positional arguments.
parser:argument('checkpoint',
  'Path to the file containing the trained checkpoint/model.')
parser:argument('image_list',
  'Path to the file containing the list of images to decode.')

-- Optional arguments: (flags, help text, default value, converter).
parser:option(
  '--seed -s', 'Seed for random numbers generation.',
  0, laia.toint)
parser:option(
  '--gpu', 'If gpu>0, uses the specified GPU, otherwise uses the CPU.',
  1, laia.toint)
parser:option(
  '--auto_width_factor', 'If true, sets the width factor for the batchers ' ..
  'automatically, from the size of the pooling layers.',
  false, laia.toboolean)
  :argname('<bool>')
parser:option(
  '--batch_size -b', 'Batch size', 16, laia.toint)
  :ge(1)  -- The batch size must be at least 1.
parser:option(
  '--symbols_table', 'Path of the file containing the symbols table.', '')
  :argname('<file>')

-- Parse options
local opts = parser:parse()
-- Seed the random number generators with the value given by --seed.
laia.manualSeed(opts.seed)

-- Fetch the *BEST* model kept in the checkpoint file; abort when the
-- checkpoint does not hold any model.
local model = laia.Checkpoint()
  :load(opts.checkpoint)
  :Best()
  :getModel()
assert(model ~= nil, 'No model was found in the checkpoint file!')

-- Place the model on the requested device.
local run_on_gpu = opts.gpu > 0
if run_on_gpu then
  -- Running on a GPU: make sure the required packages were loaded first.
  assert(cutorch ~= nil, 'Package cutorch is required in order to use the GPU.')
  assert(nn ~= nil, 'Package nn is required in order to use the GPU.')
  cutorch.setDevice(opts.gpu)
  model = model:cuda()
  -- When cudnn_force_convert=true, switch every layer that supports it to
  -- its cuDNN implementation.
  if cudnn and cudnn.force_convert then
    cudnn.convert(model, cudnn)
  end
else
  -- Running on the CPU. The conversion to float should be redundant, but it
  -- is kept as a safety measure.
  model = model:float()
end

-- The model is only used for inference: switch it to evaluation mode.
model:evaluate()
-- Load the symbols table, if one was given with --symbols_table.
-- `symbols_table` stays nil otherwise, and the decoding loop then prints the
-- numeric ids instead of the symbols.
local symbols_table
if opts.symbols_table ~= '' then
  -- Only the third value returned by laia.read_symbols_table is needed.
  -- The first two are received in locals so that they are discarded without
  -- being written to a global variable named `_`.
  local _, _, loaded_table = laia.read_symbols_table(opts.symbols_table)
  symbols_table = loaded_table
end
-- Configure the batcher and feed it with the list of images to decode.
if opts.auto_width_factor then
  -- Derive the width factor from the model rather than from the options.
  local auto_factor = laia.getWidthFactor(model)
  batcher:setOptions{ width_factor = auto_factor }
  laia.log.info('Batcher width factor was automatically set to %d',
    auto_factor)
end

batcher:load(opts.image_list)

-- The epoch is only reset when a strictly positive seed was given.
-- NOTE(review): presumably this reshuffles the samples; confirm against
-- laia.RandomBatcher.
if opts.seed > 0 then batcher:epochReset() end
-- Decode the image list batch by batch and write one line per image to the
-- standard output: the image ID followed by its space-separated labels.

-- Formats a single decoded label, preceded by a space. The label is mapped
-- through the symbols table when one was loaded; otherwise its numeric id
-- is printed.
local function format_label(label)
  if symbols_table then
    -- Print symbols
    return string.format(' %s', symbols_table[label])
  end
  -- Print id's
  return string.format(' %d', label)
end

-- Number of samples visited so far, across all batches.
local num_processed = 0
for _ = 1, batcher:numSamples(), opts.batch_size do
  -- Fetch the next batch of images together with their IDs.
  local images, _, _, image_ids = batcher:next(opts.batch_size)
  if opts.gpu > 0 then
    images = images:cuda()
  end

  -- Run the network and decode its output for the whole batch.
  local net_output = model:forward(images)
  local hypotheses = laia.framewise_decode(opts.batch_size, net_output)

  for i = 1, opts.batch_size do
    num_processed = num_processed + 1
    -- The last batch can contain more entries than there are samples left
    -- to print, so stop as soon as every sample has been emitted.
    if num_processed > batcher:numSamples() then
      break
    end

    local hypothesis = hypotheses[i]
    io.write(string.format('%s', image_ids[i]))
    for t = 1, #hypothesis do
      io.write(format_label(hypothesis[t]))
    end
    io.write('\n')
  end
end