diff --git a/flowvision/models/detection/fcos.py b/flowvision/models/detection/fcos.py
index 34979bae..a1c01d1d 100644
--- a/flowvision/models/detection/fcos.py
+++ b/flowvision/models/detection/fcos.py
@@ -14,6 +14,8 @@
 from flowvision.layers import boxes as box_ops
 from flowvision.layers import misc as misc_nn_ops
 from flowvision.layers import LastLevelP6P7
+from ..utils import load_state_dict_from_url
+from .transform import _resize_boxes, paste_masks_in_image, _resize_keypoints
 from . import det_utils
 from .anchor_utils import AnchorGenerator
 from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
@@ -21,10 +23,13 @@
 from ..registry import ModelCreator
 
 
+model_urls = {
+    "fcos_resnet50_fpn_coco": "http://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/fcos_resnet50_fpn/model.pth"
+}
+
 class FCOSHead(nn.Module):
     """
     A regression and classification head for use in FCOS.
-
     Args:
         in_channels (int): number of channels of the input feature
         num_anchors (int): number of anchors to be predicted
@@ -153,7 +158,6 @@ def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
 class FCOSClassificationHead(nn.Module):
     """
     A classification head for use in FCOS.
-
     Args:
         in_channels (int): number of channels of the input feature.
         num_anchors (int): number of anchors to be predicted.
@@ -225,7 +229,6 @@ def forward(self, x: List[Tensor]) -> Tensor:
 class FCOSRegressionHead(nn.Module):
     """
     A regression head for use in FCOS.
-
     Args:
         in_channels (int): number of channels of the input feature
         num_anchors (int): number of anchors to be predicted
@@ -297,21 +300,16 @@ def forward(self, x: List[Tensor]) -> Tuple[Tensor, Tensor]:
 class FCOS(nn.Module):
     """
     Implements FCOS.
-
     The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
     image, and should be in 0-1 range. Different images can have different sizes.
-
     The behavior of the model changes depending if it is in training or evaluation mode.
-
     During training, the model expects both the input tensors, as well as a targets (list of dictionary),
     containing:
         - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
           ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
         - labels (Int64Tensor[N]): the class label for each ground-truth box
-
     The model returns a Dict[Tensor] during training, containing the classification, regression
     and centerness losses.
-
     During inference, the model requires only the input tensors, and returns the post-processed
     predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
     follows:
@@ -319,7 +317,6 @@ class FCOS(nn.Module):
           ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
         - labels (Int64Tensor[N]): the predicted labels for each image
         - scores (Tensor[N]): the scores for each prediction
-
     Args:
         backbone (nn.Module): the network used to compute the features for the model.
             It should contain an out_channels attribute, which indicates the number of output
@@ -534,6 +531,8 @@ def postprocess_detections(
 
                 # keep only topk scoring predictions
                 num_topk = min(self.topk_candidates, topk_idxs.size(0))
+                if num_topk <= 0:
+                    continue
                 scores_per_level, idxs = scores_per_level.topk(num_topk)
                 topk_idxs = topk_idxs[idxs]
@@ -542,7 +541,7 @@ def postprocess_detections(
                 boxes_per_level = self.box_coder.decode_single(
                     box_regression_per_level[anchor_idxs],
-                    anchors_per_level[anchors_idxs],
+                    anchors_per_level[anchor_idxs],
                 )
                 boxes_per_level = box_ops.clip_boxes_to_image(
                     boxes_per_level, image_shape
@@ -552,6 +551,16 @@ def postprocess_detections(
                 image_scores.append(scores_per_level)
                 image_labels.append(labels_per_level)
 
+            if len(image_boxes) <= 0:
+                detections.append(
+                    {
+                        "boxes": flow.tensor(image_boxes),
+                        "scores": flow.tensor(image_scores),
+                        "labels": flow.tensor(image_labels),
+                    }
+                )
+                continue
+
             image_boxes = flow.cat(image_boxes, dim=0)
             image_scores = flow.cat(image_scores, dim=0)
             image_labels = flow.cat(image_labels, dim=0)
@@ -572,6 +581,28 @@ def postprocess_detections(
 
         return detections
 
+    def postprocess_bbox(self, result, image_shapes, original_image_sizes):
+        if self.training:
+            return result
+        for i, (pred, im_s, o_im_s) in enumerate(
+            zip(result, image_shapes, original_image_sizes)
+        ):
+            boxes = pred["boxes"]
+            if len(boxes) <= 0:
+                result[i]["boxes"] = boxes
+                continue
+            boxes = _resize_boxes(boxes, im_s, o_im_s)
+            result[i]["boxes"] = boxes
+            if "masks" in pred:
+                masks = pred["masks"]
+                masks = paste_masks_in_image(masks, boxes, o_im_s)
+                result[i]["masks"] = masks
+            if "keypoints" in pred:
+                keypoints = pred["keypoints"]
+                keypoints = _resize_keypoints(keypoints, im_s, o_im_s)
+                result[i]["keypoints"] = keypoints
+        return result
+
     def forward(
         self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None,
     ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
@@ -579,7 +610,6 @@ def forward(
         Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
-
         Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
@@ -668,9 +698,9 @@ def forward(
 
             # compute the detections
             detections = self.postprocess_detections(
-                split_head_outputs, split_anchors, images.image_size
+                split_head_outputs, split_anchors, images.image_sizes
             )
-            detections = self.transform.postprocess(
+            detections = self.postprocess_bbox(
                 detections, images.image_sizes, original_image_sizes
             )
@@ -710,7 +740,7 @@ def _fcos_resnet_fpn(
         state_dict = load_state_dict_from_url(
             model_urls[weights_name], progress=progress
         )
-        model.load_state_dict(state_dict)
+        model.load_state_dict(state_dict["model"])
         det_utils.overwrite_eps(model, 0.0)
     return model
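
[Review note, not part of the patch] The fcos.py changes do three things: load the released checkpoint (which stores the weights under a "model" key, hence `state_dict["model"]`), fix the `anchors_idxs`/`image_size` typos, and make inference safe for images that yield no candidate boxes, where `flow.cat` on an empty list used to fail. A minimal sketch of driving that empty path follows; the builder name `fcos_resnet50_fpn` is an assumption inferred from the "fcos_resnet50_fpn_coco" weights key and is not shown in this diff.

    # Sketch only: exercise the new empty-detection branches.
    import oneflow as flow
    from flowvision.models.detection import fcos_resnet50_fpn  # assumed entry point

    model = fcos_resnet50_fpn(pretrained=True)
    model.eval()

    # A constant image typically scores below score_thresh at every location,
    # so postprocess_detections hits the `num_topk <= 0` / empty `image_boxes`
    # branches and returns empty tensors instead of raising.
    images = [flow.zeros(3, 224, 224)]
    with flow.no_grad():
        predictions = model(images)
    print(predictions[0]["boxes"].shape)  # empty boxes tensor when nothing is found
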
diff --git a/flowvision/models/detection/retinanet.py b/flowvision/models/detection/retinanet.py
index dcb083ea..374b8a92 100644
--- a/flowvision/models/detection/retinanet.py
+++ b/flowvision/models/detection/retinanet.py
@@ -11,6 +11,7 @@
 from ..utils import load_state_dict_from_url
+from .transform import _resize_boxes, paste_masks_in_image, _resize_keypoints
 from . import det_utils
 from .anchor_utils import AnchorGenerator
 from .transform import GeneralizedRCNNTransform
@@ -39,7 +40,6 @@ def _sum(x: List[Tensor]) -> Tensor:
 class RetinaNetHead(nn.Module):
     """
     A regression and classification head for use in RetinaNet.
-
     Args:
         in_channels (int): number of channels of the input feature
         num_anchors (int): number of anchors to be predicted
@@ -79,7 +79,6 @@ def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
 class RetinaNetClassificationHead(nn.Module):
     """
     A classification head for use in RetinaNet.
-
     Args:
         in_channels (int): number of channels of the input feature
         num_anchors (int): number of anchors to be predicted
@@ -177,7 +176,6 @@ def forward(self, x: List[Tensor]) -> Tensor:
 class RetinaNetRegressionHead(nn.Module):
     """
     A regression head for use in RetinaNet.
-
     Args:
         in_channels (int): number of channels of the input feature
         num_anchors (int): number of anchors to be predicted
@@ -277,21 +275,16 @@ def forward(self, x: List[Tensor]) -> Tensor:
 class RetinaNet(nn.Module):
     """
     Implements RetinaNet.
-
     The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
     image, and should be in 0-1 range. Different images can have different sizes.
-
     The behavior of the model changes depending if it is in training or evaluation mode.
-
     During training, the model expects both the input tensors, as well as a targets (list of dictionary),
     containing:
         - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
           ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
         - labels (Int64Tensor[N]): the class label for each ground-truth box
-
     The model returns a Dict[Tensor] during training, containing the classification and regression
     losses.
-
     During inference, the model requires only the input tensors, and returns the post-processed
     predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
     follows:
@@ -299,7 +292,6 @@ class RetinaNet(nn.Module):
           ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
         - labels (Int64Tensor[N]): the predicted labels for each image
         - scores (Tensor[N]): the scores for each prediction
-
     Args:
         backbone (nn.Module): the network used to compute the features for the model.
             It should contain an out_channels attribute, which indicates the number of output
@@ -467,6 +459,8 @@ def postprocess_detections(
 
                 # keep only topk scoring predictions
                 num_topk = min(self.topk_candidates, topk_idxs.size(0))
+                if num_topk <= 0:
+                    continue
                 scores_per_level, idxs = scores_per_level.topk(num_topk)
                 topk_idxs = topk_idxs[idxs]
@@ -485,6 +479,15 @@ def postprocess_detections(
                 image_scores.append(scores_per_level)
                 image_labels.append(labels_per_level)
 
+            if len(image_boxes) <= 0:
+                detections.append(
+                    {
+                        "boxes": flow.tensor(image_boxes),
+                        "scores": flow.tensor(image_scores),
+                        "labels": flow.tensor(image_labels),
+                    }
+                )
+                continue
             image_boxes = flow.cat(image_boxes, dim=0)
             image_scores = flow.cat(image_scores, dim=0)
             image_labels = flow.cat(image_labels, dim=0)
@@ -505,6 +508,29 @@ def postprocess_detections(
 
         return detections
 
+
+    def postprocess_bbox(self, result, image_shapes, original_image_sizes):
+        if self.training:
+            return result
+        for i, (pred, im_s, o_im_s) in enumerate(
+            zip(result, image_shapes, original_image_sizes)
+        ):
+            boxes = pred["boxes"]
+            if len(boxes) <= 0:
+                result[i]["boxes"] = boxes
+                continue
+            boxes = _resize_boxes(boxes, im_s, o_im_s)
+            result[i]["boxes"] = boxes
+            if "masks" in pred:
+                masks = pred["masks"]
+                masks = paste_masks_in_image(masks, boxes, o_im_s)
+                result[i]["masks"] = masks
+            if "keypoints" in pred:
+                keypoints = pred["keypoints"]
+                keypoints = _resize_keypoints(keypoints, im_s, o_im_s)
+                result[i]["keypoints"] = keypoints
+        return result
+
     def forward(
         self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None
     ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
@@ -512,13 +538,11 @@ def forward(
         Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
-
         Returns:
            result (list[BoxList] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns list[BoxList] contains additional fields
                like `scores`, `labels` and `mask` (for Mask R-CNN models).
-
         """
         if self.training and targets is None:
             raise ValueError("In training mode, targets should be passed")
@@ -611,7 +635,7 @@ def forward(
             detections = self.postprocess_detections(
                 split_head_outputs, split_anchors, images.image_sizes
             )
-            detections = self.transform.postprocess(
+            detections = self.postprocess_bbox(
                 detections, images.image_sizes, original_image_sizes
             )
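
[Review note, not part of the patch] `postprocess_bbox`, added identically to FCOS and RetinaNet, replaces the call to `GeneralizedRCNNTransform.postprocess`: it performs the same mapping from the resized network input (`im_s`) back to the original image size (`o_im_s`), but returns early for empty box lists instead of crashing. A rough sketch of the ratio arithmetic `_resize_boxes` is expected to perform; the real helper lives in flowvision/models/detection/transform.py and is not shown here.

    import oneflow as flow

    def resize_boxes_sketch(boxes, original_size, new_size):
        # One scale factor per axis: target extent divided by source extent.
        ratio_h = new_size[0] / original_size[0]
        ratio_w = new_size[1] / original_size[1]
        xmin, ymin, xmax, ymax = boxes.unbind(1)
        return flow.stack(
            (xmin * ratio_w, ymin * ratio_h, xmax * ratio_w, ymax * ratio_h), dim=1
        )

With this reading, `_resize_boxes(boxes, im_s, o_im_s)` inside `postprocess_bbox` scales predictions from the transform's working resolution back to the caller's original coordinates.
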
diff --git a/projects/detection/coco_eval.py b/projects/detection/coco_eval.py
index b816f2c9..0b8f0ae6 100644
--- a/projects/detection/coco_eval.py
+++ b/projects/detection/coco_eval.py
@@ -66,7 +66,7 @@ def prepare(self, predictions, iou_type):
     def prepare_for_coco_detection(self, predictions):
         coco_results = []
         for original_id, prediction in predictions.items():
-            if len(prediction) == 0:
+            if len(prediction) == 0 or len(prediction["boxes"]) == 0:
                 continue
 
             boxes = prediction["boxes"]
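
[Review note, not part of the patch] The coco_eval.py guard complements the detector changes: predictions that now come back with an empty "boxes" tensor are skipped before COCO conversion, since unpacking zero boxes would produce no valid annotations. For context, a sketch of the full loop as it reads after the patch; the inline xyxy-to-xywh unpacking follows the usual COCO conversion and is assumed, not shown in the diff.

    import oneflow as flow

    def prepare_for_coco_detection_sketch(predictions):
        coco_results = []
        for original_id, prediction in predictions.items():
            # Skip images with no prediction dict at all, and images whose
            # detector returned zero boxes (the new empty-detection path).
            if len(prediction) == 0 or len(prediction["boxes"]) == 0:
                continue
            xmin, ymin, xmax, ymax = prediction["boxes"].unbind(1)  # xyxy
            boxes = flow.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1).tolist()
            scores = prediction["scores"].tolist()
            labels = prediction["labels"].tolist()
            coco_results.extend(
                {
                    "image_id": original_id,
                    "category_id": labels[k],
                    "bbox": box,
                    "score": scores[k],
                }
                for k, box in enumerate(boxes)
            )
        return coco_results
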