-- vqa.lua
require 'nn'
require 'torch'
require 'optim'
require 'misc.DataLoader'
require 'misc.word_level'
require 'misc.phrase_level'
require 'misc.ques_level'
require 'misc.recursive_atten'
require 'misc.cnnModel'
require 'misc.optim_updates'
local utils = require 'misc.utils'
require 'xlua'
require 'image'
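
-- VQADemoTorchModel wraps a hierarchical co-attention VQA network: word-,
-- phrase- and question-level encoders feed a recursive attention module on
-- top of VGG image features. __init loads the trained checkpoint; predict
-- answers a free-form question about a single image.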
local TorchModel = torch.class('VQADemoTorchModel')
function TorchModel:__init(vqa_model, cnn_proto, cnn_model, json_file, backend, gpuid)
  self.vqa_model = vqa_model
  self.cnn_proto = cnn_proto
  self.cnn_model = cnn_model
  self.json_file = json_file
  self.backend = backend
  self.gpuid = gpuid
  print(self.vqa_model, self.cnn_proto, self.cnn_model, self.json_file, self.backend)

  -- GPU setup: move to the requested device, optionally with cudnn.
  if self.gpuid >= 0 then
    require 'cutorch'
    require 'cunn'
    if self.backend == 'cudnn' then
      require 'cudnn'
    end
    cutorch.setDevice(self.gpuid) -- assumes gpuid is already 1-indexed (cutorch device ids start at 1)
  end
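
  -- The checkpoint is expected to hold the language-model options (lmOpt)
  -- together with the trained parameter vectors for each sub-module; a few
  -- options are pinned here to match the released demo model.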
  local loaded_checkpoint = torch.load(self.vqa_model)
  local lmOpt = loaded_checkpoint.lmOpt
  lmOpt.hidden_size = 512
  lmOpt.feature_type = 'VGG'
  lmOpt.atten_type = 'Alternating'

  -- CNN feature-extractor options: truncate the VGG net at layer 37 and
  -- emit a 14x14 grid of 512-d image features.
  local cnnOpt = {}
  cnnOpt.cnn_proto = self.cnn_proto
  cnnOpt.cnn_model = self.cnn_model
  cnnOpt.backend = self.backend
  cnnOpt.input_size_image = 512
  cnnOpt.output_size = 512
  cnnOpt.h = 14
  cnnOpt.w = 14
  cnnOpt.layer_num = 37
  -- Load the vocabulary and the answer list, and build the inverse
  -- word -> index mapping used to encode questions.
  local json_data = utils.read_json(self.json_file)
  local ix_to_word = json_data.ix_to_word
  local ix_to_ans = json_data.ix_to_ans
  local word_to_ix = {}
  for ix, word in pairs(ix_to_word) do
    word_to_ix[word] = ix
  end
  -- Build the model: word/phrase/question-level encoders, the recursive
  -- attention head, and the CNN feature extractor.
  local protos = {}
  protos.word = nn.word_level(lmOpt)
  protos.phrase = nn.phrase_level(lmOpt)
  protos.ques = nn.ques_level(lmOpt)
  protos.atten = nn.recursive_atten()
  protos.crit = nn.CrossEntropyCriterion()
  protos.cnn = nn.cnnModel(cnnOpt)
  if self.gpuid >= 0 then
    for k, v in pairs(protos) do v:cuda() end
  end

  -- Flatten each sub-module's parameters and copy in the trained weights.
  local cparams, grad_cparams = protos.cnn:getParameters()
  local wparams, grad_wparams = protos.word:getParameters()
  local pparams, grad_pparams = protos.phrase:getParameters()
  local qparams, grad_qparams = protos.ques:getParameters()
  local aparams, grad_aparams = protos.atten:getParameters()
  print('Loading the weights...')
  wparams:copy(loaded_checkpoint.wparams)
  pparams:copy(loaded_checkpoint.pparams)
  qparams:copy(loaded_checkpoint.qparams)
  aparams:copy(loaded_checkpoint.aparams)

  print('total number of parameters in cnn_model: ', cparams:nElement())
  assert(cparams:nElement() == grad_cparams:nElement())
  print('total number of parameters in word_level: ', wparams:nElement())
  assert(wparams:nElement() == grad_wparams:nElement())
  print('total number of parameters in phrase_level: ', pparams:nElement())
  assert(pparams:nElement() == grad_pparams:nElement())
  print('total number of parameters in ques_level: ', qparams:nElement())
  assert(qparams:nElement() == grad_qparams:nElement())
  protos.ques:shareClones()
  print('total number of parameters in recursive_attention: ', aparams:nElement())
  assert(aparams:nElement() == grad_aparams:nElement())

  self.protos = protos
  self.word_to_ix = word_to_ix
  self.ix_to_ans = ix_to_ans
end
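
-- Answer a free-form question about the image at img_path. Returns a table
-- with the top-5 answer strings and their softmax scores.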
function TorchModel:predict(img_path, question)
  -- Load the image and scale it to the 448x448 input the CNN expects.
  local img = image.load(img_path)
  img = image.scale(img, 448, 448)
  img = img:view(1, img:size(1), img:size(2), img:size(3))

  -- Parse and encode the question (in a simple way): split on whitespace and
  -- map each token to its vocabulary index, falling back to UNK.
  local ques_encode = torch.IntTensor(26):zero()
  local count = 1
  for word in string.gmatch(question, "%S+") do
    -- keep at most 26 tokens, the fixed question length the model expects
    if count <= 26 then
      ques_encode[count] = self.word_to_ix[word] or self.word_to_ix['UNK']
      count = count + 1
    end
  end
  ques_encode = ques_encode:view(1, ques_encode:size(1))
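  -- Note: the question is right-padded with zeros to a fixed length of 26,
  -- presumably the maximum question length used at training time; longer
  -- questions are silently truncated.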

  -- Run the prediction: put every sub-module in inference mode first.
  self.protos.word:evaluate()
  self.protos.phrase:evaluate()
  self.protos.ques:evaluate()
  self.protos.atten:evaluate()
  self.protos.cnn:evaluate()

  local image_raw = utils.prepro(img, false)
  local ques_len = torch.Tensor(1, 1)
  ques_len[1] = count - 1
  if self.gpuid >= 0 then
    image_raw = image_raw:cuda()
    ques_encode = ques_encode:cuda()
    ques_len = ques_len:cuda()
  end

  -- Hierarchical co-attention: word-, phrase- and question-level features
  -- for both the question and the image, fused by the recursive attention
  -- module into a single vector of answer scores.
  local image_feat = self.protos.cnn:forward(image_raw)
  local word_feat, img_feat, w_ques, w_img, mask = unpack(self.protos.word:forward({ques_encode, image_feat}))
  local conv_feat, p_ques, p_img = unpack(self.protos.phrase:forward({word_feat, ques_len, img_feat, mask}))
  local q_ques, q_img = unpack(self.protos.ques:forward({conv_feat, ques_len, img_feat, mask}))
  local feature_ensemble = {w_ques, w_img, p_ques, p_img, q_ques, q_img}
  local out_feat = self.protos.atten:forward(feature_ensemble)

  -- Softmax over the answer scores, then take the five most likely answers.
  local out_feat_sf = nn.SoftMax():forward(out_feat:double()):squeeze()
  local top5, ix = out_feat_sf:topk(5, true)
  local top5_answer = {}
  local softmax_score = {}
  for i = 1, 5 do
    top5_answer[i] = self.ix_to_ans[tostring(ix[i])]
    softmax_score[i] = top5[i]
  end

  local result = {}
  result['top5_answer'] = top5_answer
  result['softmax_score'] = softmax_score
  return result
end
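
-- Example usage (a minimal sketch; the checkpoint, prototxt, caffemodel and
-- json paths below are placeholders, not files shipped with this script):
--
--   local model = VQADemoTorchModel('model/vqa_model.t7',
--                                   'model/vgg_deploy.prototxt',
--                                   'model/vgg.caffemodel',
--                                   'model/vocab.json',
--                                   'nn', -1)  -- CPU; use e.g. 'cudnn', 1 for GPU
--   local result = model:predict('input.jpg', 'what is the man holding ?')
--   print(result.top5_answer[1], result.softmax_score[1])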