Commit
* add vision.ops.boxes module
* code format
Showing 5 changed files with 596 additions and 0 deletions.
@@ -0,0 +1,13 @@
from .boxes import (
    batched_nms,
    box_area,
    box_convert,
    box_iou,
    clip_boxes_to_image,
    complete_box_iou,
    distance_box_iou,
    generalized_box_iou,
    masks_to_boxes,
    nms,
    remove_small_boxes,
)
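
The re-exported names above form the public surface of the new ops package. As a rough usage sketch (assuming the package is importable as `vision.ops`, per the commit message, and that `nms` and `box_iou` keep the torchvision-style signatures their names suggest; neither implementation is shown in this commit):

import oneflow as torch
from vision.ops import box_iou, nms  # assumed import path, see note above

# Two overlapping boxes in (x1, y1, x2, y2) format with per-box scores.
boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0],
                      [1.0, 1.0, 11.0, 11.0]])
scores = torch.tensor([0.9, 0.8])

iou = box_iou(boxes, boxes)     # pairwise IoU matrix of shape [2, 2]
keep = nms(boxes, scores, 0.5)  # indices of the boxes kept after suppression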
@@ -0,0 +1,81 @@
import oneflow as torch
from oneflow import Tensor


def _box_cxcywh_to_xyxy(boxes: Tensor) -> Tensor:
    """
    Converts bounding boxes from (cx, cy, w, h) format to (x1, y1, x2, y2) format.
    (cx, cy) refers to center of bounding box
    (w, h) are width and height of bounding box
    Args:
        boxes (Tensor[N, 4]): boxes in (cx, cy, w, h) format which will be converted.
    Returns:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format.
    """
    # We need to change all 4 of them so some temporary variable is needed.
    cx, cy, w, h = boxes.unbind(-1)
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h

    boxes = torch.stack((x1, y1, x2, y2), dim=-1)

    return boxes


def _box_xyxy_to_cxcywh(boxes: Tensor) -> Tensor:
    """
    Converts bounding boxes from (x1, y1, x2, y2) format to (cx, cy, w, h) format.
    (x1, y1) refer to top left of bounding box
    (x2, y2) refer to bottom right of bounding box
    Args:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format which will be converted.
    Returns:
        boxes (Tensor[N, 4]): boxes in (cx, cy, w, h) format.
    """
    x1, y1, x2, y2 = boxes.unbind(-1)
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    w = x2 - x1
    h = y2 - y1

    boxes = torch.stack((cx, cy, w, h), dim=-1)

    return boxes


def _box_xywh_to_xyxy(boxes: Tensor) -> Tensor:
    """
    Converts bounding boxes from (x, y, w, h) format to (x1, y1, x2, y2) format.
    (x, y) refers to top left of bounding box.
    (w, h) refers to width and height of box.
    Args:
        boxes (Tensor[N, 4]): boxes in (x, y, w, h) which will be converted.
    Returns:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) format.
    """
    x, y, w, h = boxes.unbind(-1)
    boxes = torch.stack([x, y, x + w, y + h], dim=-1)
    return boxes


def _box_xyxy_to_xywh(boxes: Tensor) -> Tensor:
    """
    Converts bounding boxes from (x1, y1, x2, y2) format to (x, y, w, h) format.
    (x1, y1) refer to top left of bounding box
    (x2, y2) refer to bottom right of bounding box
    Args:
        boxes (Tensor[N, 4]): boxes in (x1, y1, x2, y2) which will be converted.
    Returns:
        boxes (Tensor[N, 4]): boxes in (x, y, w, h) format.
    """
    x1, y1, x2, y2 = boxes.unbind(-1)
    w = x2 - x1
    h = y2 - y1
    boxes = torch.stack((x1, y1, w, h), dim=-1)
    return boxes
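
A minimal round-trip sketch of the two center-format helpers above (they are private, so this is only a sanity check; presumably the public box_convert dispatches to them, though its implementation is not among the files shown):

import oneflow as torch

# One box centered at (5, 5) with width 4 and height 6, in (cx, cy, w, h) format.
cxcywh = torch.tensor([[5.0, 5.0, 4.0, 6.0]])

xyxy = _box_cxcywh_to_xyxy(cxcywh)  # tensor([[3., 2., 7., 8.]])
back = _box_xyxy_to_cxcywh(xyxy)    # tensor([[5., 5., 4., 6.]]), i.e. the original box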
@@ -0,0 +1,111 @@
from typing import List, Optional, Tuple, Union

import oneflow as torch
from oneflow import nn, Tensor


def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in a list
    """
    # TODO add back the assert
    # assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)


def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
    # Prefix each box with the index of the image it came from, producing
    # RoI rows of the form (batch_index, x1, y1, x2, y2).
    concat_boxes = _cat([b for b in boxes], dim=0)
    temp = []
    for i, b in enumerate(boxes):
        temp.append(torch.full_like(b[:, :1], i))
    ids = _cat(temp, dim=0)
    rois = torch.cat([ids, concat_boxes], dim=1)
    return rois


def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
    if isinstance(boxes, (list, tuple)):
        for _tensor in boxes:
            torch._assert(
                _tensor.size(1) == 4,
                "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]",
            )
    elif isinstance(boxes, torch.Tensor):
        torch._assert(
            boxes.size(1) == 5, "The boxes tensor shape is not correct as Tensor[K, 5]"
        )
    else:
        torch._assert(
            False, "boxes is expected to be a Tensor[K, 5] or a List[Tensor[L, 4]]"
        )
    return


def split_normalization_params(
    model: nn.Module, norm_classes: Optional[List[type]] = None
) -> Tuple[List[Tensor], List[Tensor]]:
    # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
    if not norm_classes:
        norm_classes = [
            nn.modules.batchnorm._BatchNorm,
            nn.LayerNorm,
            nn.GroupNorm,
            nn.modules.instancenorm._InstanceNorm,
        ]

    for t in norm_classes:
        if not issubclass(t, nn.Module):
            raise ValueError(f"Class {t} is not a subclass of nn.Module.")

    classes = tuple(norm_classes)

    norm_params = []
    other_params = []
    for module in model.modules():
        if next(module.children(), None):
            other_params.extend(
                p for p in module.parameters(recurse=False) if p.requires_grad
            )
        elif isinstance(module, classes):
            norm_params.extend(p for p in module.parameters() if p.requires_grad)
        else:
            other_params.extend(p for p in module.parameters() if p.requires_grad)
    return norm_params, other_params


def _upcast(t: Tensor) -> Tensor:
    # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
    if t.is_floating_point():
        return t if t.dtype in (torch.float32, torch.float64) else t.float()
    else:
        return t if t.dtype in (torch.int32, torch.int64) else t.int()


def _upcast_non_float(t: Tensor) -> Tensor:
    # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
    if t.dtype not in (torch.float32, torch.float64):
        return t.float()
    return t


def _loss_inter_union(
    boxes1: torch.Tensor, boxes2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Returns the elementwise intersection and union areas of two sets of
    # (x1, y1, x2, y2) boxes, paired row by row.
    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

    # Intersection keypoints
    xkis1 = torch.max(x1, x1g)
    ykis1 = torch.max(y1, y1g)
    xkis2 = torch.min(x2, x2g)
    ykis2 = torch.min(y2, y2g)

    intsctk = torch.zeros_like(x1)
    mask = (ykis2 > ykis1) & (xkis2 > xkis1)
    intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
    unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk

    return intsctk, unionk
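
To illustrate how _loss_inter_union is meant to be consumed, here is a hedged sketch of turning its outputs into a plain IoU loss for row-matched boxes (the actual distance_box_iou / complete_box_iou code is not among the files shown):

import oneflow as torch

# Two sets of (x1, y1, x2, y2) boxes, paired row by row.
boxes1 = torch.tensor([[0.0, 0.0, 4.0, 4.0], [0.0, 0.0, 2.0, 2.0]])
boxes2 = torch.tensor([[2.0, 2.0, 6.0, 6.0], [4.0, 4.0, 6.0, 6.0]])

inter, union = _loss_inter_union(boxes1, boxes2)
eps = 1e-7                    # guard against division by zero for degenerate boxes
iou = inter / (union + eps)   # roughly [0.1429, 0.0] for the boxes above
loss = 1.0 - iou              # the usual starting point for IoU-based box losses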