Skip to content

Commit

Permalink
ImageEvaluation pickleable
Browse files Browse the repository at this point in the history
  • Loading branch information
MiXaiLL76 committed Oct 31, 2024
1 parent dff1188 commit 1caac51
Showing 1 changed file with 29 additions and 2 deletions.
31 changes: 29 additions & 2 deletions csrc/faster_eval_api/faster_eval_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,34 @@ namespace coco_eval
pybind11::class_<COCOeval::InstanceAnnotation>(m, "InstanceAnnotation")
.def(pybind11::init<uint64_t, double, double, bool, bool, bool>());

pybind11::class_<COCOeval::ImageEvaluation>(m, "ImageEvaluation").def(pybind11::init<>());
pybind11::class_<COCOeval::ImageEvaluation>(m, "ImageEvaluation").def(pybind11::init<>())
.def(py::pickle(
[](const COCOeval::ImageEvaluation &p) {

std::vector<std::tuple<uint64_t, uint64_t, double>> matched_annotations;
for (size_t i = 0; i < p.matched_annotations.size(); i++) {
matched_annotations.push_back(std::make_tuple(p.matched_annotations[i].dt_id, p.matched_annotations[i].gt_id, p.matched_annotations[i].iou));
}

return py::make_tuple(p.detection_matches, p.ground_truth_matches, p.detection_scores, p.ground_truth_ignores, p.detection_ignores, matched_annotations);
},
[](py::tuple t) { // __setstate__
if (t.size() != 6)
throw std::runtime_error("Invalid state!");

COCOeval::ImageEvaluation p;
p.detection_matches = t[0].cast<std::vector<int64_t>>();
p.ground_truth_matches = t[1].cast<std::vector<int64_t>>();
p.detection_scores = t[2].cast<std::vector<double>>();
p.ground_truth_ignores = t[3].cast<std::vector<bool>>();
p.detection_ignores = t[4].cast<std::vector<bool>>();
std::vector<std::tuple<uint64_t, uint64_t, double>> matched_annotations = t[5].cast<std::vector<std::tuple<uint64_t, uint64_t, double>>>();
for (size_t i = 0; i < matched_annotations.size(); i++) {
p.matched_annotations.emplace_back(std::get<0>(matched_annotations[i]), std::get<1>(matched_annotations[i]), std::get<2>(matched_annotations[i]));
}
return p;
}
));

pybind11::class_<COCOeval::Dataset>(m, "Dataset").def(pybind11::init<>())
.def("append", &COCOeval::Dataset::append)
Expand All @@ -73,4 +100,4 @@ namespace coco_eval
#endif
}

} // namespace coco_eval
} // namespace coco_eval

0 comments on commit 1caac51

Please sign in to comment.