Skip to content

Commit

Permalink
[feat] Implement classification results converter
Browse files Browse the repository at this point in the history
  • Loading branch information
hglee98 committed Dec 18, 2024
1 parent b42a73c commit 06eef36
Showing 1 changed file with 27 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,31 @@ def test_step(self, test_model, batch):
def get_metric_with_all_outputs(self, outputs, phase: Literal['train', 'valid'], metric_factory):
pass

def _convert_result(self, result, class_map):
assert "pred" in result and "images" in result
return_preds = []
for idx in range(len(result['pred'])):
image = result['images'][idx:idx+1]
height, width = image.shape[-2:]
pred = result['pred'][idx]
return_preds.append(
{
"class": int(pred[0]),
"name": class_map[int(pred[0])],
"shape": {
"width": width,
"height": height
}
}
)
return return_preds

def get_predictions(self, results, class_map):
pass
predictions = []
if isinstance(results, list):
for minibatch in results:
predictions.extend(self._convert_result(minibatch, class_map))
elif isinstance(results, dict):
predictions.extend(self._convert_result(results, class_map))

return predictions

0 comments on commit 06eef36

Please sign in to comment.