Skip to content

Commit

Permalink
Refactor changes in refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
HiroIshida committed Nov 12, 2022
1 parent 740644b commit e3596bd
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions node_script/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from detic_ros.msg import SegmentationInfo


_cv_bridge = CvBridge()


Expand All @@ -28,7 +27,7 @@ class InferenceRawResult:
scores: List[float]
visualization: Optional[VisImage]
header: Header
class_names: List[str]
detected_class_names: List[str]

def get_ros_segmentaion_image(self) -> Image:
seg_img = _cv_bridge.cv2_to_imgmsg(self.segmentation_raw_image, encoding="32SC1")
Expand All @@ -47,12 +46,13 @@ def get_ros_debug_segmentation_img(self) -> Image:
human_friendly_scaling = 255 // self.segmentation_raw_image.max()
new_data = (self.segmentation_raw_image * human_friendly_scaling).astype(np.uint8)
debug_seg_img = _cv_bridge.cv2_to_imgmsg(new_data, encoding="mono8")
assert self.header is not None
debug_seg_img.header = self.header
return debug_seg_img

def get_label_array(self) -> LabelArray:
labels = [Label(id=i + 1, name=self.class_names[i]) for i in self.class_indices]
labels = [Label(id=i + 1, name=name)
for i, name
in zip(self.class_indices, self.detected_class_names)]
lab_arr = LabelArray(header=self.header, labels=labels)
return lab_arr

Expand All @@ -62,8 +62,7 @@ def get_score_array(self) -> VectorArray:

def get_segmentation_info(self) -> SegmentationInfo:
seg_img = self.get_ros_segmentaion_image()
detected_classes_names = [self.class_names[i] for i in self.class_indices]
seg_info = SegmentationInfo(detected_classes=detected_classes_names,
seg_info = SegmentationInfo(detected_classes=self.detected_class_names,
scores=self.scores,
segmentation=seg_img,
header=self.header)
Expand Down Expand Up @@ -131,13 +130,14 @@ def infer(self, msg: Image) -> InferenceRawResult:
data[mask] = (i + 1)

# Get class and score arrays
class_indexes = instances.pred_classes.tolist()
class_indices = instances.pred_classes.tolist()
detected_classes_names = [self.class_names[i] for i in class_indices]
scores = instances.scores.tolist()
result = InferenceRawResult(
data,
class_indexes,
class_indices,
scores,
visualized_output,
msg.header,
self.class_names)
detected_classes_names)
return result

0 comments on commit e3596bd

Please sign in to comment.