Feature/sg 1488 yolo nas r integration model #2001

Merged
24 commits
148e145
Added DOTA2 dataset
BloodAxe May 21, 2024
03cbca0
Added DOTA2 dataset
BloodAxe May 21, 2024
b17aba4
Add missing type
BloodAxe May 21, 2024
619e26b
Added dataset setup instructions
BloodAxe May 22, 2024
c779513
Added extreme batch visualization callback
BloodAxe May 22, 2024
75a47e1
OBBDetectionMetrics inherit from DetectionMetrics
BloodAxe May 22, 2024
3427148
Added AbstractOBBDataset
BloodAxe May 22, 2024
27592e7
Added docs to prepare dataset script
BloodAxe May 22, 2024
772c245
Update docs on visualization callback
BloodAxe May 22, 2024
8996c1a
Update docs on dataset config file
BloodAxe May 22, 2024
f54dee9
Model, loss and recipes
BloodAxe May 23, 2024
ca809ea
Addedd export notebook
BloodAxe May 23, 2024
5e67d60
Added export support
BloodAxe May 23, 2024
4e331ef
Merge branch 'refs/heads/feature/SG-1488-YoloNAS-R-integration' into …
BloodAxe May 23, 2024
8569676
Added predict script for generation a submission to DOTA test-dev
BloodAxe May 23, 2024
5165c98
Added missing imports
BloodAxe May 23, 2024
133e507
Make YoloNASRDFLHead inherit a base YoloNASDFLHead
BloodAxe May 27, 2024
9b66319
Make YoloNASRDFLHead inherit a base YoloNASDFLHead
BloodAxe May 27, 2024
df7b859
Added explicit mention of TRT support
BloodAxe May 27, 2024
50d17fd
Added explicit mention of TRT support
BloodAxe May 27, 2024
93d70ff
Added new recipes to test
BloodAxe May 28, 2024
0edba4d
Added test to ensure flat_collate_tensors_with_batch_index works corr…
BloodAxe May 28, 2024
d3d4be1
Move limitations to the top
BloodAxe May 28, 2024
fde191e
Added inheritance
BloodAxe Jun 3, 2024
16 changes: 16 additions & 0 deletions LICENSE.YOLONAS-R.md
@@ -0,0 +1,16 @@
# YOLO-NAS-R License

These model weights or any components comprising the model and the associated documentation (the "Software") is licensed to you by Deci.AI, Inc. ("Deci") under the following terms:
© 2023 – Deci.AI, Inc.

Subject to your full compliance with all of the terms herein, Deci hereby grants you a non-exclusive, revocable, non-sublicensable, non-transferable worldwide and limited right and license to use the Software. If you are using the Deci platform for model optimization, your use of the Software is subject to the Terms of Use available here (the "Terms of Use").

You shall not, without Deci's prior written consent:
(i) resell, lease, sublicense or distribute the Software to any person;
(ii) use the Software to provide third parties with managed services or provide remote access to the Software to any person or compete with Deci in any way;
(iii) represent that you possess any proprietary interest in the Software;
(iv) directly or indirectly, take any action to contest Deci's intellectual property rights or infringe them in any way;
(V) reverse-engineer, decompile, disassemble, alter, enhance, improve, add to, delete from, or otherwise modify, or derive (or attempt to derive) the technology or source code underlying any part of the Software;
(vi) use the Software (or any part thereof) in any illegal, indecent, misleading, harmful, abusive, harassing and/or disparaging manner or for any such purposes. Except as provided under the terms of any separate agreement between you and Deci, including the Terms of Use to the extent applicable, you may not use the Software for any commercial use, including in connection with any models used in a production environment.

DECI PROVIDES THE SOFTWARE "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE OR NON-INFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS OF THE SOFTWARE BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
11 changes: 11 additions & 0 deletions documentation/source/model_zoo.md
@@ -67,6 +67,17 @@ All the available models are listed in the column `Model name`.
> - Latency performance measured for T4 and Jetson Xavier NX with TensorRT, using FP16 precision and batch size 1
> - Latency performance measured for Cascade Lake CPU with OpenVINO, using FP16 precision and batch size 1

### Pretrained Oriented Object Detection Models

| Model | Model Name | Dataset | Resolution | mAP<sup>val<br>0.5 | mAP<sup>test<br>0.5 |
|--------------|----------------|---------|------------|--------------------|---------------------|
| YOLO-NAS-R S | yolo_nas_r_s | DOTA 2 | 1024x1024 | 63.424 | 56.56 |
| YOLO-NAS-R M | yolo_nas_r_m | DOTA 2 | 1024x1024 | 64.647 | 57.31 |
| YOLO-NAS-R L | yolo_nas_r_l | DOTA 2 | 1024x1024 | 66.223 | 59.82 |

> **NOTE:** <br/>
> Latency for YOLO-NAS-R should be nearly identical to that of the original YOLO-NAS, since most layers are exactly the same.

### Pretrained Semantic Segmentation PyTorch Checkpoints

| Model | Model Name | Dataset | Resolution | mIoU | Latency b1<sub>T4</sub> | Latency b1<sub>T4</sub> including IO | Latency (Production)**<sub>Jetson Xavier NX</sub> | Torch Compile Support |
1,180 changes: 1,180 additions & 0 deletions notebooks/YoloNAS_R_Export_to_ONNX.ipynb

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions src/super_gradients/common/object_names.py
@@ -338,6 +338,10 @@ class Models:
    YOLO_NAS_POSE_M = "yolo_nas_pose_m"
    YOLO_NAS_POSE_L = "yolo_nas_pose_l"

    YOLO_NAS_R_S = "yolo_nas_r_s"
    YOLO_NAS_R_M = "yolo_nas_r_m"
    YOLO_NAS_R_L = "yolo_nas_r_l"


class ConcatenatedTensorFormats:
    XYXY_LABEL = "XYXY_LABEL"
@@ -460,3 +464,7 @@ class Processings:
    SegmentationResize = "SegmentationResize"
    SegmentationPadShortToCropSize = "SegmentationPadShortToCropSize"
    SegmentationPadToDivisible = "SegmentationPadToDivisible"
    OBBDetectionLongestMaxSizeRescale = "OBBDetectionLongestMaxSizeRescale"
    OBBDetectionAutoPadding = "OBBDetectionAutoPadding"
    OBBDetectionCenterPadding = "OBBDetectionCenterPadding"
    OBBDetectionBottomRightPadding = "OBBDetectionBottomRightPadding"
204 changes: 204 additions & 0 deletions src/super_gradients/conversion/onnx/obb_nms.py
@@ -0,0 +1,204 @@
from typing import Tuple

import torch
from super_gradients.common.abstractions.abstract_logger import get_logger
from torch import nn, Tensor

logger = get_logger(__name__)


class OBBNMSAndReturnAsBatchedResult(nn.Module):
    __constants__ = ("batch_size", "confidence_threshold", "iou_threshold", "class_agnostic_nms", "num_pre_nms_predictions", "max_predictions_per_image")

    def __init__(
        self,
        confidence_threshold: float,
        iou_threshold: float,
        batch_size: int,
        class_agnostic_nms: bool,
        num_pre_nms_predictions: int,
        max_predictions_per_image: int,
    ):
        """
        Perform NMS on the output of the model and return the results in batched format.
        This module implements the MatrixNMS algorithm for rotated bounding boxes.
        Hence, iou_threshold has a different meaning compared to regular NMS.

        :param confidence_threshold: The confidence threshold to apply to the model output
        :param iou_threshold: The threshold for selecting final detections.
               Unlike regular NMS, in matrix NMS it applies to the product of the predicted
               confidence score and the decay factor of the bounding box (a decay is applied to
               boxes that overlap the current box).
        :param batch_size: A fixed batch size for the model
        :param class_agnostic_nms: If True, NMS will be class agnostic
        :param num_pre_nms_predictions: The number of predictions before NMS step
        :param max_predictions_per_image: Maximum number of predictions per image
        """
        if max_predictions_per_image > num_pre_nms_predictions:
            raise ValueError(
                f"max_predictions_per_image ({max_predictions_per_image}) cannot be greater than num_pre_nms_predictions ({num_pre_nms_predictions})"
            )
        super().__init__()
        self.batch_size = batch_size
        self.class_agnostic_nms = class_agnostic_nms
        self.confidence_threshold = confidence_threshold
        self.iou_threshold = iou_threshold
        self.num_pre_nms_predictions = num_pre_nms_predictions
        self.max_predictions_per_image = max_predictions_per_image

    def forward(self, input) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Take decoded predictions from the model, apply NMS to them and return the results in batched format.

        :param input: A tuple of (pred_boxes, pred_scores), where:
            - pred_boxes is a [B, N, 5] tensor, float32 in CXCYWHR format
            - pred_scores is a [B, N, C] tensor, float32 class scores
        :return: A tuple of 4 tensors (num_predictions, pred_boxes, pred_scores, pred_classes) will be returned:
            - A tensor of [batch_size, 1] containing the number of detections for each image.
            - A tensor of [batch_size, max_predictions_per_image, 5] containing the bounding box coordinates
              for each detection in [cx, cy, w, h, r] format.
            - A tensor of [batch_size, max_predictions_per_image] containing the confidence scores for each detection.
            - A tensor of [batch_size, max_predictions_per_image] containing the class indices for each detection.
            Unused slots in the last three tensors are filled with -1.

        """
        from super_gradients.training.models.detection_models.yolo_nas_r.yolo_nas_r_post_prediction_callback import rboxes_matrix_nms

        pred_boxes, pred_scores = input
        pred_cls_conf, pred_cls_labels = torch.max(pred_scores, dim=2)

        # Apply confidence threshold: scores below it are zeroed out
        pred_cls_conf = torch.masked_fill(pred_cls_conf, mask=pred_cls_conf < self.confidence_threshold, value=0)
        keep = rboxes_matrix_nms(
            rboxes_cxcywhr=pred_boxes,
            scores=pred_cls_conf,
            labels=pred_cls_labels,
            class_agnostic_nms=self.class_agnostic_nms,
            iou_threshold=self.iou_threshold,
            already_sorted=True,
        )
        num_predictions = []
        batched_pred_boxes = []
        batched_pred_scores = []
        batched_pred_classes = []
        for i in range(self.batch_size):
            # Select the detections of the i-th image that survived NMS
            keep_i = keep[i]
            pred_boxes_i = pred_boxes[i][keep_i]
            pred_scores_i = pred_cls_conf[i][keep_i]
            pred_classes_i = pred_cls_labels[i][keep_i]
            num_predictions_i = keep_i.sum()

            # Pad every per-image result with -1 up to a fixed size
            pad_size = self.max_predictions_per_image - pred_boxes_i.size(0)
            pred_boxes_i = torch.nn.functional.pad(pred_boxes_i, [0, 0, 0, pad_size], value=-1, mode="constant")
            pred_scores_i = torch.nn.functional.pad(pred_scores_i, [0, pad_size], value=-1, mode="constant")
            pred_classes_i = torch.nn.functional.pad(pred_classes_i, [0, pad_size], value=-1, mode="constant")

            num_predictions.append(num_predictions_i.reshape(1, 1))
            batched_pred_boxes.append(pred_boxes_i.unsqueeze(0))
            batched_pred_scores.append(pred_scores_i.unsqueeze(0))
            batched_pred_classes.append(pred_classes_i.unsqueeze(0))

        num_predictions = torch.cat(num_predictions, dim=0)
        batched_pred_boxes = torch.cat(batched_pred_boxes, dim=0)
        batched_pred_scores = torch.cat(batched_pred_scores, dim=0)
        batched_pred_classes = torch.cat(batched_pred_classes, dim=0)

        return num_predictions, batched_pred_boxes, batched_pred_scores, batched_pred_classes

    def get_output_names(self):
        return ["num_predictions", "pred_boxes", "pred_scores", "pred_classes"]

    def get_dynamic_axes(self):
        return {}
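
The docstring above says that `iou_threshold` is applied to the product of the confidence score and a decay factor. The sketch below illustrates the general Matrix NMS idea with a linear decay in plain Python. It is not the `rboxes_matrix_nms` implementation from this PR: the function names are hypothetical, and the pairwise IoU matrix is assumed to be precomputed for boxes already sorted by descending score.

```python
from typing import List


def matrix_nms_decay(iou: List[List[float]], eps: float = 1e-6) -> List[float]:
    """Linear Matrix NMS decay factors (illustrative sketch).

    iou[i][j] is the IoU between box i and box j; only entries with i < j are read.
    Boxes are assumed to be sorted by descending score.
    """
    n = len(iou)
    # Largest overlap of each box with any higher-scoring box.
    compensate = [max((iou[k][i] for k in range(i)), default=0.0) for i in range(n)]
    decay = []
    for j in range(n):
        # Each higher-scoring box i suppresses box j in proportion to their overlap,
        # discounted by how strongly box i itself was suppressed.
        factors = [(1.0 - iou[i][j]) / max(1.0 - compensate[i], eps) for i in range(j)]
        decay.append(min(factors, default=1.0))
    return decay


def keep_after_matrix_nms(scores: List[float], iou: List[List[float]], threshold: float) -> List[bool]:
    """Keep a box when score * decay is still above the threshold."""
    decay = matrix_nms_decay(iou)
    return [score * d > threshold for score, d in zip(scores, decay)]


# Box 1 overlaps box 0 heavily (IoU 0.8); box 2 overlaps nothing.
example_iou = [
    [0.0, 0.8, 0.0],
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
]
example_keep = keep_after_matrix_nms([0.9, 0.8, 0.7], example_iou, threshold=0.3)
# example_keep == [True, False, True]
```

Under this reading, a lower threshold keeps more overlapping boxes, which is the opposite of how an IoU threshold behaves in greedy NMS.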


class OBBNMSAndReturnAsFlatResult(nn.Module):
    """
    Apply matrix NMS to rotated boxes and return the selected detections in flat format.

    """

    __constants__ = ("iou_threshold", "confidence_threshold", "batch_size", "class_agnostic_nms", "num_pre_nms_predictions", "max_predictions_per_image")

    def __init__(
        self,
        confidence_threshold: float,
        iou_threshold: float,
        batch_size: int,
        class_agnostic_nms: bool,
        num_pre_nms_predictions: int,
        max_predictions_per_image: int,
    ):
        """
        Perform NMS on the output of the model and return the results in flat format.
        This module implements the MatrixNMS algorithm for rotated bounding boxes.
        Hence, iou_threshold has a different meaning compared to regular NMS.

        :param confidence_threshold: The confidence threshold to apply to the model output
        :param iou_threshold: The threshold for selecting final detections.
            Unlike regular NMS, in matrix NMS it applies to the product of the predicted
            confidence score and the decay factor of the bounding box (a decay is applied to
            boxes that overlap the current box).
        :param batch_size: A fixed batch size for the model
        :param class_agnostic_nms: If True, NMS will be class agnostic
        :param num_pre_nms_predictions: The number of predictions before NMS step
        :param max_predictions_per_image: Maximum number of predictions per image
        """
        super().__init__()
        self.batch_size = batch_size
        self.class_agnostic_nms = class_agnostic_nms
        self.confidence_threshold = confidence_threshold
        self.num_pre_nms_predictions = num_pre_nms_predictions
        self.max_predictions_per_image = max_predictions_per_image
        self.iou_threshold = iou_threshold

    def forward(self, input) -> Tensor:
        """
        Take decoded predictions from the model, apply NMS to them and return the results in flat format.

        :param input: A tuple of (pred_boxes, pred_scores), where:
            - pred_boxes is a [B, N, 5] tensor in CXCYWHR format
            - pred_scores is a [B, N, C] tensor of class scores
        :return: A single tensor of [Nout, 8] shape, where Nout is the total number of detections across all images in the batch.
            Each row will contain [image_index, cx, cy, w, h, r, class confidence, class index] values.
            Each image will have at most max_predictions_per_image detections.

        """
        from super_gradients.training.models.detection_models.yolo_nas_r.yolo_nas_r_post_prediction_callback import rboxes_matrix_nms

        pred_boxes, pred_scores = input
        dtype = pred_scores.dtype
        pred_cls_conf, pred_cls_labels = torch.max(pred_scores, dim=2)

        # Apply confidence threshold: scores below it are zeroed out
        pred_cls_conf = torch.masked_fill(pred_cls_conf, mask=pred_cls_conf < self.confidence_threshold, value=0)
        keep = rboxes_matrix_nms(
            rboxes_cxcywhr=pred_boxes,
            scores=pred_cls_conf,
            labels=pred_cls_labels,
            class_agnostic_nms=self.class_agnostic_nms,
            iou_threshold=self.iou_threshold,
            already_sorted=True,
        )

        flat_results = []
        for i in range(self.batch_size):
            # Select the detections of the i-th image that survived NMS
            keep_i = keep[i]
            selected_boxes = pred_boxes[i][keep_i]
            selected_scores = pred_cls_conf[i][keep_i]
            label_indexes = pred_cls_labels[i][keep_i]
            batch_indexes = torch.full_like(label_indexes, i)

            # Row layout: [image_index, cx, cy, w, h, r, class confidence, class index]
            flat_results_i = torch.cat(
                [batch_indexes.unsqueeze(-1).to(dtype), selected_boxes, selected_scores.unsqueeze(-1), label_indexes.unsqueeze(-1).to(dtype)], dim=1
            )
            flat_results_i = flat_results_i[: self.max_predictions_per_image]
            flat_results.append(flat_results_i)

        flat_results = torch.cat(flat_results, dim=0)
        return flat_results

    def get_output_names(self):
        return ["flat_predictions"]

    def get_dynamic_axes(self):
        return {
            "flat_predictions": {0: "num_predictions"},
        }
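
Because the flat format stacks detections from all images into one `[Nout, 8]` array, a consumer usually has to split the rows back per image. Below is a minimal sketch of that step in plain Python, assuming the row layout documented in `forward` above (`[image_index, cx, cy, w, h, r, class confidence, class index]`). The helper name `group_flat_predictions` and the `Detection` alias are hypothetical and not part of this PR.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

# (rotated box as (cx, cy, w, h, r), confidence, class index)
Detection = Tuple[Tuple[float, float, float, float, float], float, int]


def group_flat_predictions(rows: Sequence[Sequence[float]]) -> Dict[int, List[Detection]]:
    """Group flat [Nout, 8] prediction rows by image index."""
    grouped: Dict[int, List[Detection]] = defaultdict(list)
    for row in rows:
        image_index, cx, cy, w, h, r, score, class_index = row
        # Image and class indices arrive as floats because the whole row shares one dtype.
        grouped[int(image_index)].append(((cx, cy, w, h, r), score, int(class_index)))
    return dict(grouped)


# Two detections for image 0 and one for image 2; image 1 has none and is absent.
example_rows = [
    [0.0, 100.0, 120.0, 40.0, 20.0, 0.10, 0.91, 3.0],
    [0.0, 300.0, 310.0, 60.0, 25.0, 1.20, 0.55, 7.0],
    [2.0, 50.0, 60.0, 10.0, 30.0, 0.00, 0.80, 3.0],
]
example_grouped = group_flat_predictions(example_rows)
```

Rows can be obtained from an array or tensor with `.tolist()`. Images without detections simply have no key in the result, so callers should use `example_grouped.get(i, [])` when iterating over a batch.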