Skip to content

Commit

Permalink
Merge pull request #89 from DevLinyan/main
Browse files Browse the repository at this point in the history
minor fix
  • Loading branch information
ChonghaoSima authored Apr 26, 2024
2 parents 6f20873 + 50d2964 commit 8e4fc64
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions challenge/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@ def match_result(self, answer, GT):
# transform string into float
answer_nums = np.array([list(map(float, x.split()))[0] for x in answer_nums]).reshape(-1, 2)
GT_nums = np.array([list(map(float, x.split()))[0] for x in GT_nums]).reshape(-1, 2)

if len(answer_nums) == 0:
return [], 0
length = len(GT_nums)

matched_out = []
true_positives = 0
Expand All @@ -97,23 +95,25 @@ def match_result(self, answer, GT):
for pred in answer_nums:
closest_distance = float('inf')
closest_gt = None
for gt in GT_nums:
closest_id = None
for i, gt in enumerate(GT_nums):
distance = np.sum(np.abs(pred - gt))
if distance < closest_distance:
closest_distance = distance
closest_gt = gt
closest_id = i

if closest_distance < 16:
true_positives += 1
matched_out.append(closest_gt)
GT_nums.remove(closest_gt)
matched_out.append(closest_gt)
GT_nums = np.delete(GT_nums, closest_id, axis=0)
else:
false_positives += 1

false_negatives = len(GT_nums) - true_positives
precision = true_positives / (true_positives + false_positives)
recall = true_positives / (true_positives + false_negatives)
F1 = 2 * precision * recall / (precision + recall)
false_negatives = length - true_positives
precision = true_positives / (true_positives + false_positives + 1e-8)
recall = true_positives / (true_positives + false_negatives + 1e-8)
F1 = 2 * precision * recall / (precision + recall + 1e-8)

return matched_out, F1

Expand Down

0 comments on commit 8e4fc64

Please sign in to comment.