-
Notifications
You must be signed in to change notification settings - Fork 3
/
evaluate.lua
35 lines (31 loc) · 1 KB
/
evaluate.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
require 'torch'
require 'xlua'
local evaluator = {}

--- Run a model over a dataset and report the mean task loss and mean score.
--
-- For every example the model output is decoded by `criterion:predict`, which
-- must return `score, onset, offset`; the decoded pair is then compared with
-- the target through `criterion:task_loss(y[t], {onset, offset})`.
--
-- @param model      object providing `evaluate()` and `forward(input)`;
--                   `evaluate()` is called once before the loop.
-- @param criterion  object providing `predict(output)` and
--                   `task_loss(target, {onset, offset})`.
-- @param x          sequence of inputs, indexed 1..#x.
-- @param y          sequence of targets aligned with `x`; `y[t][1]` and
--                   `y[t][2]` are additionally read for the verbose log line.
-- @param f_n        sequence of names aligned with `x`; only read when
--                   `verbose` is truthy. A missing table or entry is logged
--                   as '<unknown>' instead of raising.
-- @param verbose    when truthy, print one line per example plus the final
--                   averages; otherwise show an `xlua` progress bar.
-- @return mean task loss over all examples (0 when `x` is empty)
-- @return mean score over all examples (0 when `x` is empty)
function evaluator:evaluate(model, criterion, x, y, f_n, verbose)
  local n = #x
  local avg_score = 0
  local cumulative_loss = 0

  model:evaluate()
  for t = 1, n do
    if not verbose then
      xlua.progress(t, n)
    end

    local output = model:forward(x[t])
    local score, onset, offset = criterion:predict(output)
    local loss = criterion:task_loss(y[t], {onset, offset})
    avg_score = avg_score + score
    cumulative_loss = cumulative_loss + loss

    -- per-example report
    if verbose then
      -- Names are purely informational, so a missing one must not abort the run.
      local name = (f_n and f_n[t]) or '<unknown>'
      print('Filename: ' .. name .. ', score: ' .. score .. ', y_hat: [' .. onset .. ', ' .. offset ..'], y: [' .. y[t][1] .. ', ' .. y[t][2] .. '], loss: ' .. loss)
    end
  end

  -- Turn the running sums into means. With no examples both sums are still 0;
  -- skipping the division avoids 0/0 (NaN) leaking out to the caller.
  if n > 0 then
    cumulative_loss = cumulative_loss / n
    avg_score = avg_score / n
  end

  -- summary report
  if verbose then
    print('\nAverage score: ' .. avg_score)
    print('Average cumulative loss: ' .. cumulative_loss)
  end
  return cumulative_loss, avg_score
end
-- Hand the module table back to whoever loads this file (e.g. via require).
return evaluator