diff --git a/csrc/faster_eval_api/faster_eval_api.cpp b/csrc/faster_eval_api/faster_eval_api.cpp index cddf8c1..e0ef991 100644 --- a/csrc/faster_eval_api/faster_eval_api.cpp +++ b/csrc/faster_eval_api/faster_eval_api.cpp @@ -57,7 +57,34 @@ namespace coco_eval pybind11::class_(m, "InstanceAnnotation") .def(pybind11::init()); - pybind11::class_(m, "ImageEvaluation").def(pybind11::init<>()); + pybind11::class_(m, "ImageEvaluation").def(pybind11::init<>()) + .def(py::pickle( + [](const COCOeval::ImageEvaluation &p) { + + std::vector> matched_annotations; + for (size_t i = 0; i < p.matched_annotations.size(); i++) { + matched_annotations.push_back(std::make_tuple(p.matched_annotations[i].dt_id, p.matched_annotations[i].gt_id, p.matched_annotations[i].iou)); + } + + return py::make_tuple(p.detection_matches, p.ground_truth_matches, p.detection_scores, p.ground_truth_ignores, p.detection_ignores, matched_annotations); + }, + [](py::tuple t) { // __setstate__ + if (t.size() != 6) + throw std::runtime_error("Invalid state!"); + + COCOeval::ImageEvaluation p; + p.detection_matches = t[0].cast>(); + p.ground_truth_matches = t[1].cast>(); + p.detection_scores = t[2].cast>(); + p.ground_truth_ignores = t[3].cast>(); + p.detection_ignores = t[4].cast>(); + std::vector> matched_annotations = t[5].cast>>(); + for (size_t i = 0; i < matched_annotations.size(); i++) { + p.matched_annotations.emplace_back(std::get<0>(matched_annotations[i]), std::get<1>(matched_annotations[i]), std::get<2>(matched_annotations[i])); + } + return p; + } + )); pybind11::class_(m, "Dataset").def(pybind11::init<>()) .def("append", &COCOeval::Dataset::append) @@ -73,4 +100,4 @@ namespace coco_eval #endif } -} // namespace coco_eval +} // namespace coco_eval