Merge pull request #194 from Nota-NetsPresso/177-detection-head-re-impl
Detection head re-implementation for PyNetsPresso compatibility
illian01 authored Oct 16, 2023
2 parents 6c9eb5a + 3dd1406 commit 50ae650
Showing 6 changed files with 132 additions and 127 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -14,7 +14,7 @@ No changes to highlight.
 
 ## Other Changes:
 
-- Fix Faster R-CNN detection head by `@illian01` in [PR 184](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/184)
+- Fix Faster R-CNN detection head by `@illian01` in [PR 184](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/184), [PR 194](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/194)
 - Refactoring models/op module by `@illian01` in [PR 189](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/189), [PR 190](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/190)
 - Release NetsPresso Trainer colab tutorial by `@illian01` in [PR 191](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/191)
 - Parameterize activation function of BasicBlock and Bottleneck by `@illian01` in [PR193](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/193)
9 changes: 3 additions & 6 deletions src/netspresso_trainer/dataloaders/detection/transforms.py
@@ -35,9 +35,6 @@ def val_transforms_efficientformer(conf_augmentation):
     return val_transforms_composed
 
 def create_transform_detection(model_name: str, is_training=False):
-
-    if model_name == 'efficientformer':
-        if is_training:
-            return train_transforms_efficientformer
-        return val_transforms_efficientformer
-    raise ValueError(f"No such model named: {model_name} !!!")
+    if is_training:
+        return train_transforms_efficientformer
+    return val_transforms_efficientformer
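
As a quick reference (not part of the diff), a minimal sketch of the dispatch pattern this simplification leaves behind — the stand-in transform functions below are hypothetical; only the factory shape mirrors the code above:

```python
# Illustrative sketch only: every detection model now shares one train/val transform pair,
# so the factory switches on is_training and keeps model_name purely for API compatibility.
def train_transforms(conf_augmentation):   # stand-in for train_transforms_efficientformer
    return f"train transforms for img_size={conf_augmentation['img_size']}"

def val_transforms(conf_augmentation):     # stand-in for val_transforms_efficientformer
    return f"val transforms for img_size={conf_augmentation['img_size']}"

def create_transform_detection(model_name: str, is_training: bool = False):
    return train_transforms if is_training else val_transforms

transform_fn = create_transform_detection("efficientformer", is_training=True)
print(transform_fn({"img_size": 512}))
```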
48 changes: 24 additions & 24 deletions src/netspresso_trainer/metrics/detection/metric.py
@@ -171,39 +171,39 @@ class DetectionMetric(BaseMetric):
     def __init__(self, **kwargs):
         super().__init__()
 
-    def calibrate(self, pred, target, **kwargs):
+    def calibrate(self, predictions, targets, **kwargs):
         result_dict = {k: 0. for k in self.metric_names}
 
         iou_thresholds = np.linspace(0.5, 0.95, 10)
         stats = []
 
         # Gather matching stats for predictions and targets
-
-        predicted_objs_bbox, predicted_objs_class, predicted_objs_confidence = pred
-        true_objs_bbox, true_objs_class = target
-
-        true_objs = np.concatenate((true_objs_bbox, true_objs_class[..., np.newaxis]), axis=-1)
-        predicted_objs = np.concatenate((predicted_objs_bbox, predicted_objs_class[..., np.newaxis], predicted_objs_confidence[..., np.newaxis]), axis=-1)
-
-        if predicted_objs.shape[0] == 0 and true_objs.shape[0]:
-            stats.append(
-                (
-                    np.zeros((0, iou_thresholds.size), dtype=bool),
-                    *np.zeros((2, 0)),
-                    true_objs[:, 4],
-                )
-            )
-
-        if true_objs.shape[0]:
-            matches = match_detection_batch(predicted_objs, true_objs, iou_thresholds)
-            stats.append(
-                (
-                    matches,
-                    predicted_objs[:, 5],
-                    predicted_objs[:, 4],
-                    true_objs[:, 4],
-                )
-            )
+        for pred, target in zip(predictions, targets):
+            predicted_objs_bbox, predicted_objs_class, predicted_objs_confidence = pred['post_boxes'], pred['post_labels'], pred['post_scores']
+            true_objs_bbox, true_objs_class = target['boxes'], target['labels']
+
+            true_objs = np.concatenate((true_objs_bbox, true_objs_class[..., np.newaxis]), axis=-1)
+            predicted_objs = np.concatenate((predicted_objs_bbox, predicted_objs_class[..., np.newaxis], predicted_objs_confidence[..., np.newaxis]), axis=-1)
+
+            if predicted_objs.shape[0] == 0 and true_objs.shape[0]:
+                stats.append(
+                    (
+                        np.zeros((0, iou_thresholds.size), dtype=bool),
+                        *np.zeros((2, 0)),
+                        true_objs[:, 4],
+                    )
+                )
+
+            if true_objs.shape[0]:
+                matches = match_detection_batch(predicted_objs, true_objs, iou_thresholds)
+                stats.append(
+                    (
+                        matches,
+                        predicted_objs[:, 5],
+                        predicted_objs[:, 4],
+                        true_objs[:, 4],
+                    )
+                )
 
         # Compute average precisions if any matches exist
         if stats:
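
For readers of the hunk above: the reworked `calibrate` now iterates per image over dict-shaped predictions and targets. A self-contained sketch of that interface (the array values are made up; only the keys and shapes follow the new code):

```python
import numpy as np

# Hypothetical single-image entries mirroring the keys the new calibrate() loop reads.
predictions = [{
    'post_boxes':  np.array([[10., 20., 50., 60.]]),  # (N_pred, 4) boxes after postprocessing
    'post_labels': np.array([0.]),                    # (N_pred,)  class indices
    'post_scores': np.array([0.9]),                   # (N_pred,)  confidence scores
}]
targets = [{
    'boxes':  np.array([[12., 18., 48., 62.]]),       # (N_gt, 4)
    'labels': np.array([0.]),                         # (N_gt,)
}]

# Per image, the metric stacks class (and score) columns onto the boxes before matching:
for pred, target in zip(predictions, targets):
    true_objs = np.concatenate((target['boxes'], target['labels'][..., np.newaxis]), axis=-1)
    predicted_objs = np.concatenate((pred['post_boxes'], pred['post_labels'][..., np.newaxis],
                                     pred['post_scores'][..., np.newaxis]), axis=-1)
    print(true_objs.shape, predicted_objs.shape)  # (1, 5) (1, 6)
```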
@@ -211,48 +211,8 @@ def postprocess_detections(
         all_labels = []
 
         # Apply Non-maximum suppression
-        # Now, it only implemented on batch size 1
-        boxes, scores, image_shape = pred_boxes_list[0], pred_scores_list[0], image_shapes
-
-        boxes = det_utils.clip_boxes_to_image(boxes, image_shape)
-
-        # create labels for each prediction
-        labels = (torch.ones_like(class_logits[0, :], dtype=torch.int64).cumsum(0) - 1).to(device)
-        labels = labels.view(1, -1).expand_as(scores)
-
-        # remove predictions with the background label
-        boxes = boxes[:, 1:]
-        scores = scores[:, 1:]
-        labels = labels[:, 1:]
-
-        # batch everything, by making every class prediction be a separate instance
-        boxes = boxes.reshape(-1, 4)
-        scores = scores.reshape(-1)
-        labels = labels.reshape(-1)
-
-        # remove low scoring boxes
-        inds = torch.where(scores > self.score_thresh)[0]
-        boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
-
-        # remove empty boxes
-        keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
-        boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
-
-        # non-maximum suppression, independently done per class
-        keep = det_utils._batched_nms_vanilla(boxes, scores, labels, self.nms_thresh, self.class_ids)
-        # keep only topk scoring predictions
-        keep = keep[: self.detections_per_img]
-        boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
-
-        all_boxes.append(boxes)
-        all_scores.append(scores)
-        all_labels.append(labels)
-
-        # TODO
-        # Apply NMS on various batch
-        '''
-        for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes):
-            boxes = det_utils.clip_boxes_to_image(boxes, image_shape)
+        for boxes, scores in zip(pred_boxes_list, pred_scores_list):
+            boxes = det_utils.clip_boxes_to_image(boxes, image_shapes)
 
             # create labels for each prediction
             labels = torch.arange(num_classes, device=device)
@@ -285,7 +245,6 @@ def postprocess_detections(
             all_boxes.append(boxes)
             all_scores.append(scores)
             all_labels.append(labels)
-        '''
 
         return all_boxes, all_scores, all_labels
 
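
A rough, self-contained sketch of the per-image NMS loop that replaces the batch-size-1 path above. It uses torchvision's public box ops instead of the repo's `det_utils` helpers, and the thresholds are illustrative defaults, so treat it as an outline rather than the PR's code:

```python
import torch
from torchvision.ops import batched_nms, clip_boxes_to_image, remove_small_boxes

def postprocess_per_image(pred_boxes_list, pred_scores_list, pred_labels_list, image_shape,
                          score_thresh=0.05, nms_thresh=0.5, detections_per_img=100):
    all_boxes, all_scores, all_labels = [], [], []
    for boxes, scores, labels in zip(pred_boxes_list, pred_scores_list, pred_labels_list):
        boxes = clip_boxes_to_image(boxes, image_shape)

        # drop low-scoring predictions, then degenerate (tiny) boxes
        keep = torch.where(scores > score_thresh)[0]
        boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
        keep = remove_small_boxes(boxes, min_size=1e-2)
        boxes, scores, labels = boxes[keep], scores[keep], labels[keep]

        # NMS independently per class, then keep only the top-k detections
        keep = batched_nms(boxes, scores, labels, nms_thresh)[:detections_per_img]
        all_boxes.append(boxes[keep])
        all_scores.append(scores[keep])
        all_labels.append(labels[keep])
    return all_boxes, all_scores, all_labels

# Toy usage: one image, two heavily overlapping boxes of the same class
boxes = torch.tensor([[0., 0., 10., 10.], [1., 1., 11., 11.]])
scores = torch.tensor([0.9, 0.8])
labels = torch.tensor([1, 1])
print(postprocess_per_image([boxes], [scores], [labels], image_shape=(32, 32)))
```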
@@ -235,36 +235,11 @@ def filter_proposals(
 
         final_boxes = []
         final_scores = []
 
-        # Apply Non-maximum suppression
-        # Now, it only implemented on batch size 1
-        boxes, scores, lvl, img_shape = proposals[0], objectness_prob[0], levels[0], image_shapes
-        boxes = det_utils.clip_boxes_to_image(boxes, img_shape)
-
-        # remove small boxes
-        keep = box_ops.remove_small_boxes(boxes, self.min_size)
-        boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
-
-        # remove low scoring boxes
-        # use >= for Backwards compatibility
-        keep = torch.where(scores >= self.score_thresh)[0]
-        boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
-
-        # non-maximum suppression, independently done per level
-        keep = det_utils._batched_nms_vanilla(boxes, scores, lvl, self.nms_thresh, list(range(len(num_anchors_per_level))))
-
-        # keep only topk scoring predictions
-        keep = keep[: self.post_nms_top_n()]
-        boxes, scores = boxes[keep], scores[keep]
-
-        final_boxes.append(boxes)
-        final_scores.append(scores)
+        # Apply Non-maximum suppression
 
-        # TODO
-        # Apply NMS on various batch
-        '''
-        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
-            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)
+        for boxes, scores, lvl in zip(proposals, objectness_prob, levels,):
+            boxes = box_ops.clip_boxes_to_image(boxes, image_shapes)
 
             # remove small boxes
             keep = box_ops.remove_small_boxes(boxes, self.min_size)
@@ -284,7 +259,7 @@
 
             final_boxes.append(boxes)
             final_scores.append(scores)
-        '''
+
         return final_boxes, final_scores
 
     def forward(
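
The RPN-side change follows the same pattern: loop over images and run clip → remove-small → score threshold → per-level NMS → top-N. A hedged sketch with plain torchvision ops and assumed default thresholds, not the repo's `det_utils` implementation:

```python
import torch
from torchvision.ops import batched_nms, clip_boxes_to_image, remove_small_boxes

def filter_proposals_per_image(proposals, objectness_prob, levels, image_shape,
                               min_size=1e-3, score_thresh=0.0, nms_thresh=0.7, post_nms_top_n=1000):
    final_boxes, final_scores = [], []
    for boxes, scores, lvl in zip(proposals, objectness_prob, levels):
        boxes = clip_boxes_to_image(boxes, image_shape)

        keep = remove_small_boxes(boxes, min_size)
        boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

        keep = torch.where(scores >= score_thresh)[0]  # >= for backwards compatibility
        boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

        # NMS independently per feature-pyramid level, then keep the top-N proposals
        keep = batched_nms(boxes, scores, lvl, nms_thresh)[:post_nms_top_n]
        final_boxes.append(boxes[keep])
        final_scores.append(scores[keep])
    return final_boxes, final_scores

# Toy usage: one image, two proposals on the same pyramid level
props = torch.tensor([[0., 0., 20., 20.], [2., 2., 22., 22.]])
probs = torch.tensor([0.7, 0.6])
lvls = torch.tensor([0, 0])
print(filter_proposals_per_image([props], [probs], [lvls], image_shape=(64, 64)))
```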
122 changes: 98 additions & 24 deletions src/netspresso_trainer/pipelines/detection.py
@@ -1,11 +1,16 @@
+import copy
 import logging
 import os
+from pathlib import Path
 
 import numpy as np
 import torch
 from omegaconf import OmegaConf
 
-from ..models.utils import DetectionModelOutput
+from ..models import build_model
+from ..models.utils import DetectionModelOutput, load_from_checkpoint
+from ..utils.fx import save_graphmodule
+from ..utils.onnx import save_onnx
 from .base import BasePipeline
 
 logger = logging.getLogger("netspresso_trainer")
@@ -17,6 +22,15 @@ def __init__(self, conf, task, model_name, model, devices, train_dataloader, eva
                          train_dataloader, eval_dataloader, class_map, **kwargs)
         self.num_classes = train_dataloader.dataset.num_classes
 
+        # Re-compose torch.fx backbone and nn.Module head
+        # To load head weights, config should have head_checkpoint value.
+        if kwargs['is_graphmodule_training']:
+            model = build_model(conf.model, task, self.num_classes, None, conf.augmentation.img_size)
+            model.backbone = self.model
+            model.head = load_from_checkpoint(model.head, conf.model.head_checkpoint)
+            model = model.to(device=devices)
+            self.model = model
+
     def train_step(self, batch):
         self.model.train()
         images, labels, bboxes = batch['pixel_values'], batch['label'], batch['bbox']
@@ -39,7 +53,7 @@ def train_step(self, batch):
 
         # generate proposals for training
         proposals = rpn_features['boxes']
-        proposals, matched_idxs, labels, regression_targets = head.roi_heads.select_training_samples(proposals, targets)
+        proposals, matched_idxs, roi_head_labels, regression_targets = head.roi_heads.select_training_samples(proposals, targets)
 
         # forward to roi head
         roi_features = head.roi_heads(features, proposals, head.image_size)
@@ -48,33 +62,34 @@ def train_step(self, batch):
         out = DetectionModelOutput()
         out.update(rpn_features)
         out.update(roi_features)
-        out.update({'labels': labels, 'regression_targets': regression_targets})
+        out.update({'labels': roi_head_labels, 'regression_targets': regression_targets})
 
         # Compute loss
         self.loss_factory.calc(out, target=targets, phase='train')
 
         self.loss_factory.backward()
         self.optimizer.step()
 
-        # TODO: metric update
-        # out = {k: v.detach() for k, v in out.items()}
-        # self.metric_factory(out['pred'], target=targets, mode='train')
-
         if self.conf.distributed:
             torch.distributed.barrier()
 
     def valid_step(self, batch):
         self.model.eval()
         images, labels, bboxes = batch['pixel_values'], batch['label'], batch['bbox']
         bboxes = [bbox.to(self.devices) for bbox in bboxes]
         labels = [label.to(self.devices) for label in labels]
         images = images.to(self.devices)
-        #targets = [{"boxes": box.to(self.devices), "labels": label.to(self.devices)}
-        #           for box, label in zip(bboxes, labels)]
+        targets = [{"boxes": box, "labels": label} for box, label in zip(bboxes, labels)]
 
         out = self.model(images)
-        # TODO: compute loss for validation
-        #self.loss_factory.calc(out, target=targets, phase='valid')
 
-        # TODO: metric update
-        # self.metric_factory(out['pred'], (labels, bboxes), mode='valid')
+        # Compute loss
+        head = self.model.head
+        matched_idxs, roi_head_labels = head.roi_heads.assign_targets_to_proposals(out['boxes'], bboxes, labels)
+        matched_gt_boxes = [bbox[idx] for idx, bbox in zip(matched_idxs, bboxes)]
+        regression_targets = head.roi_heads.box_coder.encode(matched_gt_boxes, out['boxes'])
+        out.update({'labels': roi_head_labels, 'regression_targets': regression_targets})
+        self.loss_factory.calc(out, target=targets, phase='valid')
 
         if self.conf.distributed:
             torch.distributed.barrier()
@@ -102,18 +117,77 @@ def test_step(self, batch):
         return results
 
     def get_metric_with_all_outputs(self, outputs):
-        targets = np.empty((0, 4))
-        preds = np.empty((0, 5)) # with confidence score
-        targets_indices = np.empty(0)
-        preds_indices = np.empty(0)
+        pred = []
+        targets = []
         for output_batch in outputs:
             for detection, class_idx in output_batch['target']:
-                targets = np.vstack([targets, detection])
-                targets_indices = np.append(targets_indices, class_idx)
+                target_on_image = {}
+                target_on_image['boxes'] = detection
+                target_on_image['labels'] = class_idx
+                targets.append(target_on_image)
 
             for detection, class_idx in output_batch['pred']:
-                preds = np.vstack([preds, detection])
-                preds_indices = np.append(preds_indices, class_idx)
-
-        pred_bbox, pred_confidence = preds[..., :4], preds[..., -1] # (N x 4), (N,)
-        self.metric_factory.calc((pred_bbox, preds_indices, pred_confidence), (targets, targets_indices), phase='valid')
+                pred_on_image = {}
+                pred_on_image['post_boxes'] = detection[..., :4]
+                pred_on_image['post_scores'] = detection[..., -1]
+                pred_on_image['post_labels'] = class_idx
+                pred.append(pred_on_image)
+        self.metric_factory.calc(pred, target=targets, phase='valid')
+
+    def save_checkpoint(self, epoch: int):
+
+        # Check whether the valid loss is minimum at this epoch
+        valid_losses = {epoch: record['valid_losses'].get('total') for epoch, record in self.training_history.items()
+                        if 'valid_losses' in record}
+        best_epoch = min(valid_losses, key=valid_losses.get)
+        save_best_model = best_epoch == epoch
+
+        model = self.model.module if hasattr(self.model, 'module') else self.model
+        if self.save_dtype == torch.float16:
+            model = copy.deepcopy(model).type(self.save_dtype)
+        result_dir = self.train_logger.result_dir
+        model_path = Path(result_dir) / f"{self.task}_{self.model_name}_epoch_{epoch}.ext"
+        best_model_path = Path(result_dir) / f"{self.task}_{self.model_name}_best.ext"
+        optimizer_path = Path(result_dir) / f"{self.task}_{self.model_name}_epoch_{epoch}_optimzer.pth"
+
+        if self.save_optimizer_state:
+            optimizer = self.optimizer.module if hasattr(self.optimizer, 'module') else self.optimizer
+            save_dict = {'optimizer': optimizer.state_dict(), 'start_epoch_at_one': self.start_epoch_at_one, 'last_epoch': epoch}
+            torch.save(save_dict, optimizer_path)
+            logger.debug(f"Optimizer state saved at {str(optimizer_path)}")
+
+        if self.is_graphmodule_training:
+            # Just save graphmodule checkpoint
+            torch.save(model, (model_path.parent / f"{model_path.stem}_backbone").with_suffix(".pth"))
+            logger.debug(f"PyTorch FX model saved at {(model_path.parent / f'{model_path.stem}_backbone').with_suffix('.pth')}")
+            torch.save(model.head.state_dict(), (model_path.parent / f"{model_path.stem}_head").with_suffix(".pth"))
+            logger.info(f"Detection head saved at {(model_path.parent / f'{model_path.stem}_head').with_suffix('.pth')}")
+            if save_best_model:
+                save_onnx(model, best_model_path.with_suffix(".onnx"), sample_input=self.sample_input.type(self.save_dtype))
+                logger.info(f"ONNX model converting and saved at {str(best_model_path.with_suffix('.onnx'))}")
+
+                torch.save(model.backbone, (model_path.parent / f"{best_model_path.stem}_backbone").with_suffix(".pt"))
+                logger.info(f"Best model saved at {(model_path.parent / f'{best_model_path.stem}_backbone').with_suffix('.pt')}")
+                # save head separately
+                torch.save(model.head.state_dict(), (model_path.parent / f"{best_model_path.stem}_head").with_suffix(".pth"))
+                logger.info(f"Detection head saved at {(model_path.parent / f'{best_model_path.stem}_head').with_suffix('.pth')}")
+            return
+        torch.save(model.state_dict(), model_path.with_suffix(".pth"))
+        logger.debug(f"PyTorch model saved at {str(model_path.with_suffix('.pth'))}")
+        if save_best_model:
+            torch.save(model.state_dict(), best_model_path.with_suffix(".pth"))
+            logger.info(f"Best model saved at {str(best_model_path.with_suffix('.pth'))}")
+
+            try:
+                save_onnx(model, best_model_path.with_suffix(".onnx"), sample_input=self.sample_input.type(self.save_dtype))
+                logger.info(f"ONNX model converting and saved at {str(best_model_path.with_suffix('.onnx'))}")
+
+                # fx backbone
+                save_graphmodule(model.backbone, (model_path.parent / f"{best_model_path.stem}_backbone_fx").with_suffix(".pt"))
+                logger.info(f"PyTorch FX model tracing and saved at {(model_path.parent / f'{best_model_path.stem}_backbone_fx').with_suffix('.pt')}")
+                # save head separately
+                torch.save(model.head.state_dict(), (model_path.parent / f"{best_model_path.stem}_head").with_suffix(".pth"))
+                logger.info(f"Detection head saved at {(model_path.parent / f'{best_model_path.stem}_head').with_suffix('.pth')}")
+            except Exception as e:
+                logger.error(e)
+                pass

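To make the new checkpoint layout concrete, here is a self-contained sketch of the save/re-compose cycle: the fx-traced backbone is saved as a whole module and the detection head only as a state_dict, mirroring the naming scheme in save_checkpoint and the re-composition in DetectionPipeline.__init__. The toy modules and file names are assumptions for illustration, not code from this PR:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

# Stand-ins for the real backbone (a torch.fx GraphModule) and detection head.
backbone = symbolic_trace(nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()))
head = nn.Linear(8, 4)

# Mirror save_checkpoint(): backbone saved as a full (picklable) GraphModule, head as a state_dict.
torch.save(backbone, "detection_toy_best_backbone_fx.pt")
torch.save(head.state_dict(), "detection_toy_best_head.pth")

# Later (e.g. after a PyNetsPresso round-trip of the backbone), rebuild the head and re-compose,
# roughly what __init__ does when is_graphmodule_training is set and conf.model.head_checkpoint
# points at the saved head file.
restored_backbone = torch.load("detection_toy_best_backbone_fx.pt", weights_only=False)  # full module pickle
restored_head = nn.Linear(8, 4)
restored_head.load_state_dict(torch.load("detection_toy_best_head.pth"))
```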