Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detection head re-implementation for PyNetsPresso compatibility #194

Merged
merged 12 commits into from
Oct 16, 2023
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ No changes to highlight.

## Other Changes:

- Fix Faster R-CNN detection head by `@illian01` in [PR 184](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/184)
- Fix Faster R-CNN detection head by `@illian01` in [PR 184](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/184), [PR 194](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/194)
- Refactoring models/op module by `@illian01` in [PR 189](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/189), [PR 190](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/190)
- Release NetsPresso Trainer colab tutorial `@illian01` in [PR 191](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/191)

Expand Down
9 changes: 3 additions & 6 deletions src/netspresso_trainer/dataloaders/detection/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ def val_transforms_efficientformer(conf_augmentation):
return val_transforms_composed

def create_transform_detection(model_name: str, is_training=False):

if model_name == 'efficientformer':
if is_training:
return train_transforms_efficientformer
return val_transforms_efficientformer
raise ValueError(f"No such model named: {model_name} !!!")
if is_training:
return train_transforms_efficientformer
return val_transforms_efficientformer
48 changes: 24 additions & 24 deletions src/netspresso_trainer/metrics/detection/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,39 +171,39 @@ class DetectionMetric(BaseMetric):
def __init__(self, **kwargs):
super().__init__()

def calibrate(self, pred, target, **kwargs):
def calibrate(self, predictions, targets, **kwargs):
result_dict = {k: 0. for k in self.metric_names}

iou_thresholds = np.linspace(0.5, 0.95, 10)
stats = []

# Gather matching stats for predictions and targets

predicted_objs_bbox, predicted_objs_class, predicted_objs_confidence = pred
true_objs_bbox, true_objs_class = target

true_objs = np.concatenate((true_objs_bbox, true_objs_class[..., np.newaxis]), axis=-1)
predicted_objs = np.concatenate((predicted_objs_bbox, predicted_objs_class[..., np.newaxis], predicted_objs_confidence[..., np.newaxis]), axis=-1)

if predicted_objs.shape[0] == 0 and true_objs.shape[0]:
stats.append(
(
np.zeros((0, iou_thresholds.size), dtype=bool),
*np.zeros((2, 0)),
true_objs[:, 4],
for pred, target in zip(predictions, targets):
predicted_objs_bbox, predicted_objs_class, predicted_objs_confidence = pred['post_boxes'], pred['post_labels'], pred['post_scores']
true_objs_bbox, true_objs_class = target['boxes'], target['labels']

true_objs = np.concatenate((true_objs_bbox, true_objs_class[..., np.newaxis]), axis=-1)
predicted_objs = np.concatenate((predicted_objs_bbox, predicted_objs_class[..., np.newaxis], predicted_objs_confidence[..., np.newaxis]), axis=-1)

if predicted_objs.shape[0] == 0 and true_objs.shape[0]:
stats.append(
(
np.zeros((0, iou_thresholds.size), dtype=bool),
*np.zeros((2, 0)),
true_objs[:, 4],
)
)
)

if true_objs.shape[0]:
matches = match_detection_batch(predicted_objs, true_objs, iou_thresholds)
stats.append(
(
matches,
predicted_objs[:, 5],
predicted_objs[:, 4],
true_objs[:, 4],
if true_objs.shape[0]:
matches = match_detection_batch(predicted_objs, true_objs, iou_thresholds)
stats.append(
(
matches,
predicted_objs[:, 5],
predicted_objs[:, 4],
true_objs[:, 4],
)
)
)

# Compute average precisions if any matches exist
if stats:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,48 +211,8 @@ def postprocess_detections(
all_labels = []

# Apply Non-maximum suppression
# Now, it only implemented on batch size 1
boxes, scores, image_shape = pred_boxes_list[0], pred_scores_list[0], image_shapes

boxes = det_utils.clip_boxes_to_image(boxes, image_shape)

# create labels for each prediction
labels = (torch.ones_like(class_logits[0, :], dtype=torch.int64).cumsum(0) - 1).to(device)
labels = labels.view(1, -1).expand_as(scores)

# remove predictions with the background label
boxes = boxes[:, 1:]
scores = scores[:, 1:]
labels = labels[:, 1:]

# batch everything, by making every class prediction be a separate instance
boxes = boxes.reshape(-1, 4)
scores = scores.reshape(-1)
labels = labels.reshape(-1)

# remove low scoring boxes
inds = torch.where(scores > self.score_thresh)[0]
boxes, scores, labels = boxes[inds], scores[inds], labels[inds]

# remove empty boxes
keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

# non-maximum suppression, independently done per class
keep = det_utils._batched_nms_vanilla(boxes, scores, labels, self.nms_thresh, self.class_ids)
# keep only topk scoring predictions
keep = keep[: self.detections_per_img]
boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

all_boxes.append(boxes)
all_scores.append(scores)
all_labels.append(labels)

# TODO
# Apply NMS on various batch
'''
for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
boxes = det_utils.clip_boxes_to_image(boxes, image_shape)
for boxes, scores in zip(pred_boxes_list, pred_scores_list):
boxes = det_utils.clip_boxes_to_image(boxes, image_shapes)

# create labels for each prediction
labels = torch.arange(num_classes, device=device)
Expand Down Expand Up @@ -285,7 +245,6 @@ def postprocess_detections(
all_boxes.append(boxes)
all_scores.append(scores)
all_labels.append(labels)
'''

return all_boxes, all_scores, all_labels

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,36 +235,11 @@ def filter_proposals(

final_boxes = []
final_scores = []

# Apply Non-maximum suppression
# Now, it only implemented on batch size 1
boxes, scores, lvl, img_shape = proposals[0], objectness_prob[0], levels[0], image_shapes
boxes = det_utils.clip_boxes_to_image(boxes, img_shape)

# remove small boxes
keep = box_ops.remove_small_boxes(boxes, self.min_size)
boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

# remove low scoring boxes
# use >= for Backwards compatibility
keep = torch.where(scores >= self.score_thresh)[0]
boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

# non-maximum suppression, independently done per level
keep = det_utils._batched_nms_vanilla(boxes, scores, lvl, self.nms_thresh, list(range(len(num_anchors_per_level))))

# keep only topk scoring predictions
keep = keep[: self.post_nms_top_n()]
boxes, scores = boxes[keep], scores[keep]

final_boxes.append(boxes)
final_scores.append(scores)
# Apply Non-maximum suppression

# TODO
# Apply NMS on various batch
'''
for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
boxes = box_ops.clip_boxes_to_image(boxes, img_shape)
for boxes, scores, lvl in zip(proposals, objectness_prob, levels,):
boxes = box_ops.clip_boxes_to_image(boxes, image_shapes)

# remove small boxes
keep = box_ops.remove_small_boxes(boxes, self.min_size)
Expand All @@ -284,7 +259,7 @@ def filter_proposals(

final_boxes.append(boxes)
final_scores.append(scores)
'''

return final_boxes, final_scores

def forward(
Expand Down
122 changes: 98 additions & 24 deletions src/netspresso_trainer/pipelines/detection.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import copy
import logging
import os
from pathlib import Path

import numpy as np
import torch
from omegaconf import OmegaConf

from ..models.utils import DetectionModelOutput
from ..models import build_model
from ..models.utils import DetectionModelOutput, load_from_checkpoint
from ..utils.fx import save_graphmodule
from ..utils.onnx import save_onnx
from .base import BasePipeline

logger = logging.getLogger("netspresso_trainer")
Expand All @@ -17,6 +22,15 @@ def __init__(self, conf, task, model_name, model, devices, train_dataloader, eva
train_dataloader, eval_dataloader, class_map, **kwargs)
self.num_classes = train_dataloader.dataset.num_classes

# Re-compose torch.fx backbone and nn.Module head
# To load head weights, config should have head_checkpoint value.
if kwargs['is_graphmodule_training']:
model = build_model(conf.model, task, self.num_classes, None, conf.augmentation.img_size)
model.backbone = self.model
model.head = load_from_checkpoint(model.head, conf.model.head_checkpoint)
model = model.to(device=devices)
self.model = model

def train_step(self, batch):
self.model.train()
images, labels, bboxes = batch['pixel_values'], batch['label'], batch['bbox']
Expand All @@ -39,7 +53,7 @@ def train_step(self, batch):

# generate proposals for training
proposals = rpn_features['boxes']
proposals, matched_idxs, labels, regression_targets = head.roi_heads.select_training_samples(proposals, targets)
proposals, matched_idxs, roi_head_labels, regression_targets = head.roi_heads.select_training_samples(proposals, targets)

# forward to roi head
roi_features = head.roi_heads(features, proposals, head.image_size)
Expand All @@ -48,33 +62,34 @@ def train_step(self, batch):
out = DetectionModelOutput()
out.update(rpn_features)
out.update(roi_features)
out.update({'labels': labels, 'regression_targets': regression_targets})
out.update({'labels': roi_head_labels, 'regression_targets': regression_targets})

# Compute loss
self.loss_factory.calc(out, target=targets, phase='train')

self.loss_factory.backward()
self.optimizer.step()

# TODO: metric update
# out = {k: v.detach() for k, v in out.items()}
# self.metric_factory(out['pred'], target=targets, mode='train')

if self.conf.distributed:
torch.distributed.barrier()

def valid_step(self, batch):
self.model.eval()
images, labels, bboxes = batch['pixel_values'], batch['label'], batch['bbox']
bboxes = [bbox.to(self.devices) for bbox in bboxes]
labels = [label.to(self.devices) for label in labels]
images = images.to(self.devices)
#targets = [{"boxes": box.to(self.devices), "labels": label.to(self.devices)}
# for box, label in zip(bboxes, labels)]
targets = [{"boxes": box, "labels": label} for box, label in zip(bboxes, labels)]

out = self.model(images)
# TODO: compute loss for validation
#self.loss_factory.calc(out, target=targets, phase='valid')

# TODO: metric update
# self.metric_factory(out['pred'], (labels, bboxes), mode='valid')
# Compute loss
head = self.model.head
matched_idxs, roi_head_labels = head.roi_heads.assign_targets_to_proposals(out['boxes'], bboxes, labels)
matched_gt_boxes = [bbox[idx] for idx, bbox in zip(matched_idxs, bboxes)]
regression_targets = head.roi_heads.box_coder.encode(matched_gt_boxes, out['boxes'])
out.update({'labels': roi_head_labels, 'regression_targets': regression_targets})
self.loss_factory.calc(out, target=targets, phase='valid')

if self.conf.distributed:
torch.distributed.barrier()
Expand Down Expand Up @@ -102,18 +117,77 @@ def test_step(self, batch):
return results

def get_metric_with_all_outputs(self, outputs):
targets = np.empty((0, 4))
preds = np.empty((0, 5)) # with confidence score
targets_indices = np.empty(0)
preds_indices = np.empty(0)
pred = []
targets = []
for output_batch in outputs:
for detection, class_idx in output_batch['target']:
targets = np.vstack([targets, detection])
targets_indices = np.append(targets_indices, class_idx)
target_on_image = {}
target_on_image['boxes'] = detection
target_on_image['labels'] = class_idx
targets.append(target_on_image)

for detection, class_idx in output_batch['pred']:
preds = np.vstack([preds, detection])
preds_indices = np.append(preds_indices, class_idx)

pred_bbox, pred_confidence = preds[..., :4], preds[..., -1] # (N x 4), (N,)
self.metric_factory.calc((pred_bbox, preds_indices, pred_confidence), (targets, targets_indices), phase='valid')
pred_on_image = {}
pred_on_image['post_boxes'] = detection[..., :4]
pred_on_image['post_scores'] = detection[..., -1]
pred_on_image['post_labels'] = class_idx
pred.append(pred_on_image)
self.metric_factory.calc(pred, target=targets, phase='valid')

def save_checkpoint(self, epoch: int):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@illian01

If we override a task-specific `save_checkpoint` function, it may be better to declare it as an abstractmethod in the base pipeline.
Otherwise, it may be hard for someone to find where it is defined.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@deepkyu
I overrode the `save_checkpoint` method because we have to separate the backbone and head for the detection task.
Thus, this overridden method won't be needed once we serve the fx model as an integrated one.
I think it would be better to remove this method when we can compress and benchmark the integrated detection task model.
I will change `save_checkpoint` to an abstractmethod if we decide, through internal discussion, to serve the detection model as is (backbone and head separately).


# Body of save_checkpoint(epoch): writes the per-epoch checkpoint and, when this
# epoch has the lowest recorded validation loss, the "best" artifacts as well.
# Check whether the valid loss is minimum at this epoch
# Only epochs whose training_history record contains 'valid_losses' are compared.
valid_losses = {epoch: record['valid_losses'].get('total') for epoch, record in self.training_history.items()
if 'valid_losses' in record}
best_epoch = min(valid_losses, key=valid_losses.get)
save_best_model = best_epoch == epoch

# Unwrap a wrapper that exposes the underlying model as `.module`
# (presumably DataParallel/DistributedDataParallel — TODO confirm).
model = self.model.module if hasattr(self.model, 'module') else self.model
# For float16 saving, cast a deep copy rather than the live model object.
if self.save_dtype == torch.float16:
model = copy.deepcopy(model).type(self.save_dtype)
result_dir = self.train_logger.result_dir
# ".ext" is a placeholder suffix: below, these paths are only used through
# .parent, .stem and .with_suffix(), so ".ext" never reaches a written file name.
model_path = Path(result_dir) / f"{self.task}_{self.model_name}_epoch_{epoch}.ext"
best_model_path = Path(result_dir) / f"{self.task}_{self.model_name}_best.ext"
# NOTE(review): "optimzer" looks like a typo of "optimizer"; it is part of the
# saved file name, so renaming it would change where optimizer states are written.
optimizer_path = Path(result_dir) / f"{self.task}_{self.model_name}_epoch_{epoch}_optimzer.pth"

# Optionally persist the optimizer state together with epoch bookkeeping.
if self.save_optimizer_state:
optimizer = self.optimizer.module if hasattr(self.optimizer, 'module') else self.optimizer
save_dict = {'optimizer': optimizer.state_dict(), 'start_epoch_at_one': self.start_epoch_at_one, 'last_epoch': epoch}
torch.save(save_dict, optimizer_path)
logger.debug(f"Optimizer state saved at {str(optimizer_path)}")

# Graphmodule training path: backbone and head are saved as separate files,
# then the method returns early and skips the state_dict path further below.
if self.is_graphmodule_training:
# Just save graphmodule checkpoint
# NOTE(review): this pickles the whole `model` object under a "_backbone" name,
# whereas the best-model branch below saves `model.backbone` — confirm which is intended.
torch.save(model, (model_path.parent / f"{model_path.stem}_backbone").with_suffix(".pth"))
logger.debug(f"PyTorch FX model saved at {(model_path.parent / f'{model_path.stem}_backbone').with_suffix('.pth')}")
torch.save(model.head.state_dict(), (model_path.parent / f"{model_path.stem}_head").with_suffix(".pth"))
logger.info(f"Detection head saved at {(model_path.parent / f'{model_path.stem}_head').with_suffix('.pth')}")
if save_best_model:
# Unlike the non-graphmodule path below, this ONNX export is not wrapped in
# try/except, so an export error propagates to the caller.
save_onnx(model, best_model_path.with_suffix(".onnx"), sample_input=self.sample_input.type(self.save_dtype))
logger.info(f"ONNX model converting and saved at {str(best_model_path.with_suffix('.onnx'))}")

torch.save(model.backbone, (model_path.parent / f"{best_model_path.stem}_backbone").with_suffix(".pt"))
logger.info(f"Best model saved at {(model_path.parent / f'{best_model_path.stem}_backbone').with_suffix('.pt')}")
# save head separately
torch.save(model.head.state_dict(), (model_path.parent / f"{best_model_path.stem}_head").with_suffix(".pth"))
logger.info(f"Detection head saved at {(model_path.parent / f'{best_model_path.stem}_head').with_suffix('.pth')}")
return
# Regular (non-graphmodule) path: save the full model state_dict for this epoch.
torch.save(model.state_dict(), model_path.with_suffix(".pth"))
logger.debug(f"PyTorch model saved at {str(model_path.with_suffix('.pth'))}")
if save_best_model:
torch.save(model.state_dict(), best_model_path.with_suffix(".pth"))
logger.info(f"Best model saved at {str(best_model_path.with_suffix('.pth'))}")

# Best-effort export: any exception here is logged and not re-raised. The three
# saves share one try block, so a failure in an earlier step skips the later ones.
try:
save_onnx(model, best_model_path.with_suffix(".onnx"), sample_input=self.sample_input.type(self.save_dtype))
logger.info(f"ONNX model converting and saved at {str(best_model_path.with_suffix('.onnx'))}")

# fx backbone
save_graphmodule(model.backbone, (model_path.parent / f"{best_model_path.stem}_backbone_fx").with_suffix(".pt"))
logger.info(f"PyTorch FX model tracing and saved at {(model_path.parent / f'{best_model_path.stem}_backbone_fx').with_suffix('.pt')}")
# save head separately
torch.save(model.head.state_dict(), (model_path.parent / f"{best_model_path.stem}_head").with_suffix(".pth"))
logger.info(f"Detection head saved at {(model_path.parent / f'{best_model_path.stem}_head').with_suffix('.pth')}")
except Exception as e:
logger.error(e)
pass