From d937d98d0566061f896ba2fd19edeee3de13f5b6 Mon Sep 17 00:00:00 2001 From: Junho Shin Date: Wed, 11 Oct 2023 12:11:11 +0900 Subject: [PATCH 01/11] Override save func to save separately --- src/netspresso_trainer/pipelines/detection.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/src/netspresso_trainer/pipelines/detection.py b/src/netspresso_trainer/pipelines/detection.py index ec256bf5d..9ac9c5eda 100644 --- a/src/netspresso_trainer/pipelines/detection.py +++ b/src/netspresso_trainer/pipelines/detection.py @@ -1,5 +1,7 @@ +import copy import logging import os +from pathlib import Path import numpy as np import torch @@ -7,6 +9,8 @@ from ..models.utils import DetectionModelOutput from .base import BasePipeline +from ..utils.fx import save_graphmodule +from ..utils.onnx import save_onnx logger = logging.getLogger("netspresso_trainer") @@ -117,3 +121,55 @@ def get_metric_with_all_outputs(self, outputs): pred_bbox, pred_confidence = preds[..., :4], preds[..., -1] # (N x 4), (N,) self.metric_factory.calc((pred_bbox, preds_indices, pred_confidence), (targets, targets_indices), phase='valid') + + def save_checkpoint(self, epoch: int): + + # Check whether the valid loss is minimum at this epoch + valid_losses = {epoch: record['valid_losses'].get('total') for epoch, record in self.training_history.items() + if 'valid_losses' in record} + best_epoch = min(valid_losses, key=valid_losses.get) + save_best_model = best_epoch == epoch + + model = self.model.module if hasattr(self.model, 'module') else self.model + if self.save_dtype == torch.float16: + model = copy.deepcopy(model).type(self.save_dtype) + result_dir = self.train_logger.result_dir + model_path = Path(result_dir) / f"{self.task}_{self.model_name}_epoch_{epoch}.ext" + best_model_path = Path(result_dir) / f"{self.task}_{self.model_name}_best.ext" + optimizer_path = Path(result_dir) / f"{self.task}_{self.model_name}_epoch_{epoch}_optimzer.pth" + + if self.save_optimizer_state: + optimizer = self.optimizer.module if hasattr(self.optimizer, 'module') else self.optimizer + save_dict = {'optimizer': optimizer.state_dict(), 'start_epoch_at_one': self.start_epoch_at_one, 'last_epoch': epoch} + torch.save(save_dict, optimizer_path) + logger.debug(f"Optimizer state saved at {str(optimizer_path)}") + + if self.is_graphmodule_training: + # Just save graphmodule checkpoint + torch.save(model, model_path.with_suffix(".pt")) + logger.debug(f"PyTorch FX model saved at {str(model_path.with_suffix('.pt'))}") + if save_best_model: + save_onnx(model, best_model_path.with_suffix(".onnx"), sample_input=self.sample_input.type(self.save_dtype)) + logger.info(f"ONNX model converting and saved at {str(best_model_path.with_suffix('.onnx'))}") + torch.save(model, best_model_path.with_suffix(".pt")) + logger.info(f"Best model saved at {str(best_model_path.with_suffix('.pt'))}") + return + torch.save(model.state_dict(), model_path.with_suffix(".pth")) + logger.debug(f"PyTorch model saved at {str(model_path.with_suffix('.pth'))}") + if save_best_model: + torch.save(model.state_dict(), best_model_path.with_suffix(".pth")) + logger.info(f"Best model saved at {str(best_model_path.with_suffix('.pth'))}") + + try: + save_onnx(model, best_model_path.with_suffix(".onnx"), sample_input=self.sample_input.type(self.save_dtype)) + logger.info(f"ONNX model converting and saved at {str(best_model_path.with_suffix('.onnx'))}") + + # fx backbone + save_graphmodule(model.backbone, (model_path.parent / f"{best_model_path.stem}_backbone_fx").with_suffix(".pt")) 
+ logger.info(f"PyTorch FX model tracing and saved at {str(best_model_path.with_suffix('.pt'))}") + # save head separately + torch.save(model.head.state_dict(), (model_path.parent / f"{best_model_path.stem}_head").with_suffix(".pth")) + logger.info(f"Detection head saved at {str(best_model_path.with_suffix('.pth'))}") + except Exception as e: + logger.error(e) + pass From 28ee9bae7c3dba98913b0b30bd8455705c015021 Mon Sep 17 00:00:00 2001 From: Junho Shin Date: Wed, 11 Oct 2023 12:20:47 +0900 Subject: [PATCH 02/11] Revert nms process --- .../experimental/detection/roi_heads.py | 45 +------------------ .../detection/experimental/detection/rpn.py | 33 ++------------ 2 files changed, 6 insertions(+), 72 deletions(-) diff --git a/src/netspresso_trainer/models/heads/detection/experimental/detection/roi_heads.py b/src/netspresso_trainer/models/heads/detection/experimental/detection/roi_heads.py index 83dcc9a67..0eeaa064f 100644 --- a/src/netspresso_trainer/models/heads/detection/experimental/detection/roi_heads.py +++ b/src/netspresso_trainer/models/heads/detection/experimental/detection/roi_heads.py @@ -211,48 +211,8 @@ def postprocess_detections( all_labels = [] # Apply Non-maximum suppression - # Now, it only implemented on batch size 1 - boxes, scores, image_shape = pred_boxes_list[0], pred_scores_list[0], image_shapes - - boxes = det_utils.clip_boxes_to_image(boxes, image_shape) - - # create labels for each prediction - labels = (torch.ones_like(class_logits[0, :], dtype=torch.int64).cumsum(0) - 1).to(device) - labels = labels.view(1, -1).expand_as(scores) - - # remove predictions with the background label - boxes = boxes[:, 1:] - scores = scores[:, 1:] - labels = labels[:, 1:] - - # batch everything, by making every class prediction be a separate instance - boxes = boxes.reshape(-1, 4) - scores = scores.reshape(-1) - labels = labels.reshape(-1) - - # remove low scoring boxes - inds = torch.where(scores > self.score_thresh)[0] - boxes, scores, labels = boxes[inds], scores[inds], labels[inds] - - # remove empty boxes - keep = box_ops.remove_small_boxes(boxes, min_size=1e-2) - boxes, scores, labels = boxes[keep], scores[keep], labels[keep] - - # non-maximum suppression, independently done per class - keep = det_utils._batched_nms_vanilla(boxes, scores, labels, self.nms_thresh, self.class_ids) - # keep only topk scoring predictions - keep = keep[: self.detections_per_img] - boxes, scores, labels = boxes[keep], scores[keep], labels[keep] - - all_boxes.append(boxes) - all_scores.append(scores) - all_labels.append(labels) - - # TODO - # Apply NMS on various batch - ''' - for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes): - boxes = det_utils.clip_boxes_to_image(boxes, image_shape) + for boxes, scores in zip(pred_boxes_list, pred_scores_list): + boxes = det_utils.clip_boxes_to_image(boxes, image_shapes) # create labels for each prediction labels = torch.arange(num_classes, device=device) @@ -285,7 +245,6 @@ def postprocess_detections( all_boxes.append(boxes) all_scores.append(scores) all_labels.append(labels) - ''' return all_boxes, all_scores, all_labels diff --git a/src/netspresso_trainer/models/heads/detection/experimental/detection/rpn.py b/src/netspresso_trainer/models/heads/detection/experimental/detection/rpn.py index 067464dd5..734b4b7f7 100644 --- a/src/netspresso_trainer/models/heads/detection/experimental/detection/rpn.py +++ b/src/netspresso_trainer/models/heads/detection/experimental/detection/rpn.py @@ -235,36 +235,11 @@ def filter_proposals( 
final_boxes = [] final_scores = [] - - # Apply Non-maximum suppression - # Now, it only implemented on batch size 1 - boxes, scores, lvl, img_shape = proposals[0], objectness_prob[0], levels[0], image_shapes - boxes = det_utils.clip_boxes_to_image(boxes, img_shape) - - # remove small boxes - keep = box_ops.remove_small_boxes(boxes, self.min_size) - boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep] - - # remove low scoring boxes - # use >= for Backwards compatibility - keep = torch.where(scores >= self.score_thresh)[0] - boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep] - # non-maximum suppression, independently done per level - keep = det_utils._batched_nms_vanilla(boxes, scores, lvl, self.nms_thresh, list(range(len(num_anchors_per_level)))) - - # keep only topk scoring predictions - keep = keep[: self.post_nms_top_n()] - boxes, scores = boxes[keep], scores[keep] - - final_boxes.append(boxes) - final_scores.append(scores) + # Apply Non-maximum suppression - # TODO - # Apply NMS on various batch - ''' - for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes): - boxes = box_ops.clip_boxes_to_image(boxes, img_shape) + for boxes, scores, lvl in zip(proposals, objectness_prob, levels,): + boxes = box_ops.clip_boxes_to_image(boxes, image_shapes) # remove small boxes keep = box_ops.remove_small_boxes(boxes, self.min_size) @@ -284,7 +259,7 @@ def filter_proposals( final_boxes.append(boxes) final_scores.append(scores) - ''' + return final_boxes, final_scores def forward( From 52b4c499d73cabb4bb2774aae3693107bdf0635c Mon Sep 17 00:00:00 2001 From: Junho Shin Date: Wed, 11 Oct 2023 18:46:59 +0900 Subject: [PATCH 03/11] Apply metric update --- .../metrics/detection/metric.py | 48 +++++++++---------- src/netspresso_trainer/pipelines/detection.py | 41 +++++++--------- 2 files changed, 42 insertions(+), 47 deletions(-) diff --git a/src/netspresso_trainer/metrics/detection/metric.py b/src/netspresso_trainer/metrics/detection/metric.py index 774072765..17a83845d 100644 --- a/src/netspresso_trainer/metrics/detection/metric.py +++ b/src/netspresso_trainer/metrics/detection/metric.py @@ -171,39 +171,39 @@ class DetectionMetric(BaseMetric): def __init__(self, **kwargs): super().__init__() - def calibrate(self, pred, target, **kwargs): + def calibrate(self, predictions, targets, **kwargs): result_dict = {k: 0. 
for k in self.metric_names} iou_thresholds = np.linspace(0.5, 0.95, 10) stats = [] # Gather matching stats for predictions and targets - - predicted_objs_bbox, predicted_objs_class, predicted_objs_confidence = pred - true_objs_bbox, true_objs_class = target - - true_objs = np.concatenate((true_objs_bbox, true_objs_class[..., np.newaxis]), axis=-1) - predicted_objs = np.concatenate((predicted_objs_bbox, predicted_objs_class[..., np.newaxis], predicted_objs_confidence[..., np.newaxis]), axis=-1) - - if predicted_objs.shape[0] == 0 and true_objs.shape[0]: - stats.append( - ( - np.zeros((0, iou_thresholds.size), dtype=bool), - *np.zeros((2, 0)), - true_objs[:, 4], + for pred, target in zip(predictions, targets): + predicted_objs_bbox, predicted_objs_class, predicted_objs_confidence = pred['post_boxes'], pred['post_labels'], pred['post_scores'] + true_objs_bbox, true_objs_class = target['boxes'], target['labels'] + + true_objs = np.concatenate((true_objs_bbox, true_objs_class[..., np.newaxis]), axis=-1) + predicted_objs = np.concatenate((predicted_objs_bbox, predicted_objs_class[..., np.newaxis], predicted_objs_confidence[..., np.newaxis]), axis=-1) + + if predicted_objs.shape[0] == 0 and true_objs.shape[0]: + stats.append( + ( + np.zeros((0, iou_thresholds.size), dtype=bool), + *np.zeros((2, 0)), + true_objs[:, 4], + ) ) - ) - if true_objs.shape[0]: - matches = match_detection_batch(predicted_objs, true_objs, iou_thresholds) - stats.append( - ( - matches, - predicted_objs[:, 5], - predicted_objs[:, 4], - true_objs[:, 4], + if true_objs.shape[0]: + matches = match_detection_batch(predicted_objs, true_objs, iou_thresholds) + stats.append( + ( + matches, + predicted_objs[:, 5], + predicted_objs[:, 4], + true_objs[:, 4], + ) ) - ) # Compute average precisions if any matches exist if stats: diff --git a/src/netspresso_trainer/pipelines/detection.py b/src/netspresso_trainer/pipelines/detection.py index 9ac9c5eda..fef0bd9ed 100644 --- a/src/netspresso_trainer/pipelines/detection.py +++ b/src/netspresso_trainer/pipelines/detection.py @@ -21,6 +21,9 @@ def __init__(self, conf, task, model_name, model, devices, train_dataloader, eva train_dataloader, eval_dataloader, class_map, **kwargs) self.num_classes = train_dataloader.dataset.num_classes + if kwargs['is_graphmodule_training']: + pass + def train_step(self, batch): self.model.train() images, labels, bboxes = batch['pixel_values'], batch['label'], batch['bbox'] @@ -59,9 +62,11 @@ def train_step(self, batch): self.loss_factory.backward() self.optimizer.step() - # TODO: metric update - # out = {k: v.detach() for k, v in out.items()} - # self.metric_factory(out['pred'], target=targets, mode='train') + pred = [{'post_boxes': b.detach().cpu().numpy(), 'post_labels': l.detach().cpu().numpy(), 'post_scores': c.detach().cpu().numpy()} + for b, l, c in zip(out['post_boxes'], out['post_labels'], out['post_scores'])] + targets = [{'boxes': target['boxes'].detach().cpu().numpy(), 'labels': target['labels'].detach().cpu().numpy()} + for target in targets] + self.metric_factory(pred, target=targets, phase='train') if self.conf.distributed: torch.distributed.barrier() @@ -70,16 +75,20 @@ def valid_step(self, batch): self.model.eval() images, labels, bboxes = batch['pixel_values'], batch['label'], batch['bbox'] images = images.to(self.devices) - #targets = [{"boxes": box.to(self.devices), "labels": label.to(self.devices)} - # for box, label in zip(bboxes, labels)] + targets = [{"boxes": box.to(self.devices), "labels": label.to(self.devices)} + for box, label in 
zip(bboxes, labels)] out = self.model(images) + + pred = [{'post_boxes': b.detach().cpu().numpy(), 'post_labels': l.detach().cpu().numpy(), 'post_scores': c.detach().cpu().numpy()} + for b, l, c in zip(out['post_boxes'], out['post_labels'], out['post_scores'])] + targets = [{'boxes': target['boxes'].detach().cpu().numpy(), 'labels': target['labels'].detach().cpu().numpy()} + for target in targets] + self.metric_factory(pred, target=targets, phase='valid') + # TODO: compute loss for validation #self.loss_factory.calc(out, target=targets, phase='valid') - # TODO: metric update - # self.metric_factory(out['pred'], (labels, bboxes), mode='valid') - if self.conf.distributed: torch.distributed.barrier() @@ -106,21 +115,7 @@ def test_step(self, batch): return results def get_metric_with_all_outputs(self, outputs): - targets = np.empty((0, 4)) - preds = np.empty((0, 5)) # with confidence score - targets_indices = np.empty(0) - preds_indices = np.empty(0) - for output_batch in outputs: - for detection, class_idx in output_batch['target']: - targets = np.vstack([targets, detection]) - targets_indices = np.append(targets_indices, class_idx) - - for detection, class_idx in output_batch['pred']: - preds = np.vstack([preds, detection]) - preds_indices = np.append(preds_indices, class_idx) - - pred_bbox, pred_confidence = preds[..., :4], preds[..., -1] # (N x 4), (N,) - self.metric_factory.calc((pred_bbox, preds_indices, pred_confidence), (targets, targets_indices), phase='valid') + pass def save_checkpoint(self, epoch: int): From 222758e29f646e8594036592f4acd51708d31370 Mon Sep 17 00:00:00 2001 From: Junho Shin Date: Wed, 11 Oct 2023 19:55:18 +0900 Subject: [PATCH 04/11] Remove if statement on detection transform --- .../dataloaders/detection/transforms.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/netspresso_trainer/dataloaders/detection/transforms.py b/src/netspresso_trainer/dataloaders/detection/transforms.py index 3600db78a..ac3090a59 100644 --- a/src/netspresso_trainer/dataloaders/detection/transforms.py +++ b/src/netspresso_trainer/dataloaders/detection/transforms.py @@ -35,9 +35,6 @@ def val_transforms_efficientformer(conf_augmentation): return val_transforms_composed def create_transform_detection(model_name: str, is_training=False): - - if model_name == 'efficientformer': - if is_training: - return train_transforms_efficientformer - return val_transforms_efficientformer - raise ValueError(f"No such model named: {model_name} !!!") + if is_training: + return train_transforms_efficientformer + return val_transforms_efficientformer From c443d70116d4e6067691ef48e4fca0fe3216ede1 Mon Sep 17 00:00:00 2001 From: Junho Shin Date: Wed, 11 Oct 2023 19:55:36 +0900 Subject: [PATCH 05/11] Enable fx model retrain --- src/netspresso_trainer/pipelines/detection.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/netspresso_trainer/pipelines/detection.py b/src/netspresso_trainer/pipelines/detection.py index fef0bd9ed..6a7c12999 100644 --- a/src/netspresso_trainer/pipelines/detection.py +++ b/src/netspresso_trainer/pipelines/detection.py @@ -7,10 +7,11 @@ import torch from omegaconf import OmegaConf -from ..models.utils import DetectionModelOutput +from ..models.utils import DetectionModelOutput, load_from_checkpoint from .base import BasePipeline from ..utils.fx import save_graphmodule from ..utils.onnx import save_onnx +from ..models import build_model logger = logging.getLogger("netspresso_trainer") @@ -21,8 +22,14 @@ def __init__(self, 
conf, task, model_name, model, devices, train_dataloader, eva train_dataloader, eval_dataloader, class_map, **kwargs) self.num_classes = train_dataloader.dataset.num_classes + # Re-compose torch.fx backbone and nn.Module head + # To load head weights, config should have head_checkpoint value. if kwargs['is_graphmodule_training']: - pass + model = build_model(conf.model, task, self.num_classes, None, conf.augmentation.img_size) + model.backbone = self.model + model.head = load_from_checkpoint(model.head, conf.model.head_checkpoint) + model = model.to(device=devices) + self.model = model def train_step(self, batch): self.model.train() @@ -146,8 +153,12 @@ def save_checkpoint(self, epoch: int): if save_best_model: save_onnx(model, best_model_path.with_suffix(".onnx"), sample_input=self.sample_input.type(self.save_dtype)) logger.info(f"ONNX model converting and saved at {str(best_model_path.with_suffix('.onnx'))}") - torch.save(model, best_model_path.with_suffix(".pt")) + + torch.save(model.backbone, best_model_path.with_suffix(".pt")) logger.info(f"Best model saved at {str(best_model_path.with_suffix('.pt'))}") + # save head separately + torch.save(model.head.state_dict(), (model_path.parent / f"{best_model_path.stem}_head").with_suffix(".pth")) + logger.info(f"Detection head saved at {str(best_model_path.with_suffix('.pth'))}") return torch.save(model.state_dict(), model_path.with_suffix(".pth")) logger.debug(f"PyTorch model saved at {str(model_path.with_suffix('.pth'))}") From 641597712ece3aac717b0427f1721346dafbc646 Mon Sep 17 00:00:00 2001 From: Junho Shin Date: Fri, 13 Oct 2023 10:17:24 +0900 Subject: [PATCH 06/11] Update validation loss --- src/netspresso_trainer/pipelines/detection.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/netspresso_trainer/pipelines/detection.py b/src/netspresso_trainer/pipelines/detection.py index 6a7c12999..08280ceaf 100644 --- a/src/netspresso_trainer/pipelines/detection.py +++ b/src/netspresso_trainer/pipelines/detection.py @@ -64,11 +64,13 @@ def train_step(self, batch): out.update(roi_features) out.update({'labels': labels, 'regression_targets': regression_targets}) + # Compute loss self.loss_factory.calc(out, target=targets, phase='train') self.loss_factory.backward() self.optimizer.step() + # Update metrics pred = [{'post_boxes': b.detach().cpu().numpy(), 'post_labels': l.detach().cpu().numpy(), 'post_scores': c.detach().cpu().numpy()} for b, l, c in zip(out['post_boxes'], out['post_labels'], out['post_scores'])] targets = [{'boxes': target['boxes'].detach().cpu().numpy(), 'labels': target['labels'].detach().cpu().numpy()} @@ -81,21 +83,28 @@ def train_step(self, batch): def valid_step(self, batch): self.model.eval() images, labels, bboxes = batch['pixel_values'], batch['label'], batch['bbox'] + bboxes = [b.to(self.devices) for b in bboxes] + labels = [l.to(self.devices) for l in labels] images = images.to(self.devices) - targets = [{"boxes": box.to(self.devices), "labels": label.to(self.devices)} - for box, label in zip(bboxes, labels)] + targets = [{"boxes": box, "labels": label} for box, label in zip(bboxes, labels)] out = self.model(images) + # Compute loss + head = self.model.head + matched_idxs, labels = head.roi_heads.assign_targets_to_proposals(out['boxes'], bboxes, labels) + matched_gt_boxes = [bbox[idx] for idx, bbox in zip(matched_idxs, bboxes)] + regression_targets = head.roi_heads.box_coder.encode(matched_gt_boxes, out['boxes']) + out.update({'labels': labels, 'regression_targets': 
regression_targets}) + self.loss_factory.calc(out, target=targets, phase='valid') + + # Update metrics pred = [{'post_boxes': b.detach().cpu().numpy(), 'post_labels': l.detach().cpu().numpy(), 'post_scores': c.detach().cpu().numpy()} for b, l, c in zip(out['post_boxes'], out['post_labels'], out['post_scores'])] targets = [{'boxes': target['boxes'].detach().cpu().numpy(), 'labels': target['labels'].detach().cpu().numpy()} for target in targets] self.metric_factory(pred, target=targets, phase='valid') - # TODO: compute loss for validation - #self.loss_factory.calc(out, target=targets, phase='valid') - if self.conf.distributed: torch.distributed.barrier() From a66d6d163adc0965ae1a8a895142a1123d7aaa01 Mon Sep 17 00:00:00 2001 From: Junho Shin Date: Fri, 13 Oct 2023 15:35:44 +0900 Subject: [PATCH 07/11] Revert and fix metric update --- src/netspresso_trainer/pipelines/detection.py | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/src/netspresso_trainer/pipelines/detection.py b/src/netspresso_trainer/pipelines/detection.py index 08280ceaf..fd94002ff 100644 --- a/src/netspresso_trainer/pipelines/detection.py +++ b/src/netspresso_trainer/pipelines/detection.py @@ -53,7 +53,7 @@ def train_step(self, batch): # generate proposals for training proposals = rpn_features['boxes'] - proposals, matched_idxs, labels, regression_targets = head.roi_heads.select_training_samples(proposals, targets) + proposals, matched_idxs, roi_head_labels, regression_targets = head.roi_heads.select_training_samples(proposals, targets) # forward to roi head roi_features = head.roi_heads(features, proposals, head.image_size) @@ -62,7 +62,7 @@ def train_step(self, batch): out = DetectionModelOutput() out.update(rpn_features) out.update(roi_features) - out.update({'labels': labels, 'regression_targets': regression_targets}) + out.update({'labels': roi_head_labels, 'regression_targets': regression_targets}) # Compute loss self.loss_factory.calc(out, target=targets, phase='train') @@ -70,13 +70,6 @@ def train_step(self, batch): self.loss_factory.backward() self.optimizer.step() - # Update metrics - pred = [{'post_boxes': b.detach().cpu().numpy(), 'post_labels': l.detach().cpu().numpy(), 'post_scores': c.detach().cpu().numpy()} - for b, l, c in zip(out['post_boxes'], out['post_labels'], out['post_scores'])] - targets = [{'boxes': target['boxes'].detach().cpu().numpy(), 'labels': target['labels'].detach().cpu().numpy()} - for target in targets] - self.metric_factory(pred, target=targets, phase='train') - if self.conf.distributed: torch.distributed.barrier() @@ -92,19 +85,12 @@ def valid_step(self, batch): # Compute loss head = self.model.head - matched_idxs, labels = head.roi_heads.assign_targets_to_proposals(out['boxes'], bboxes, labels) + matched_idxs, roi_head_labels = head.roi_heads.assign_targets_to_proposals(out['boxes'], bboxes, labels) matched_gt_boxes = [bbox[idx] for idx, bbox in zip(matched_idxs, bboxes)] regression_targets = head.roi_heads.box_coder.encode(matched_gt_boxes, out['boxes']) - out.update({'labels': labels, 'regression_targets': regression_targets}) + out.update({'labels': roi_head_labels, 'regression_targets': regression_targets}) self.loss_factory.calc(out, target=targets, phase='valid') - # Update metrics - pred = [{'post_boxes': b.detach().cpu().numpy(), 'post_labels': l.detach().cpu().numpy(), 'post_scores': c.detach().cpu().numpy()} - for b, l, c in zip(out['post_boxes'], out['post_labels'], out['post_scores'])] - targets = [{'boxes': 
target['boxes'].detach().cpu().numpy(), 'labels': target['labels'].detach().cpu().numpy()} - for target in targets] - self.metric_factory(pred, target=targets, phase='valid') - if self.conf.distributed: torch.distributed.barrier() @@ -131,8 +117,23 @@ def test_step(self, batch): return results def get_metric_with_all_outputs(self, outputs): - pass - + pred = list() + targets = list() + for output_batch in outputs: + for detection, class_idx in output_batch['target']: + target_on_image = dict() + target_on_image['boxes'] = detection + target_on_image['labels'] = class_idx + targets.append(target_on_image) + + for detection, class_idx in output_batch['pred']: + pred_on_image = dict() + pred_on_image['post_boxes'] = detection[..., :4] + pred_on_image['post_scores'] = detection[..., -1] + pred_on_image['post_labels'] = class_idx + pred.append(pred_on_image) + self.metric_factory(pred, target=targets, phase='valid') + def save_checkpoint(self, epoch: int): # Check whether the valid loss is minimum at this epoch From d9403855336be8a491371adeb9f7e3f1494bd9d0 Mon Sep 17 00:00:00 2001 From: Junho Shin Date: Fri, 13 Oct 2023 16:03:41 +0900 Subject: [PATCH 08/11] Fix save name --- src/netspresso_trainer/pipelines/detection.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/netspresso_trainer/pipelines/detection.py b/src/netspresso_trainer/pipelines/detection.py index fd94002ff..b5fbe7041 100644 --- a/src/netspresso_trainer/pipelines/detection.py +++ b/src/netspresso_trainer/pipelines/detection.py @@ -158,17 +158,19 @@ def save_checkpoint(self, epoch: int): if self.is_graphmodule_training: # Just save graphmodule checkpoint - torch.save(model, model_path.with_suffix(".pt")) - logger.debug(f"PyTorch FX model saved at {str(model_path.with_suffix('.pt'))}") + torch.save(model, (model_path.parent / f"{model_path.stem}_backbone").with_suffix(".pth")) + logger.debug(f"PyTorch FX model saved at {(model_path.parent / f'{model_path.stem}_backbone').with_suffix('.pth')}") + torch.save(model.head.state_dict(), (model_path.parent / f"{model_path.stem}_head").with_suffix(".pth")) + logger.info(f"Detection head saved at {(model_path.parent / f'{model_path.stem}_head').with_suffix('.pth')}") if save_best_model: save_onnx(model, best_model_path.with_suffix(".onnx"), sample_input=self.sample_input.type(self.save_dtype)) logger.info(f"ONNX model converting and saved at {str(best_model_path.with_suffix('.onnx'))}") - torch.save(model.backbone, best_model_path.with_suffix(".pt")) - logger.info(f"Best model saved at {str(best_model_path.with_suffix('.pt'))}") + torch.save(model.backbone, (model_path.parent / f"{best_model_path.stem}_backbone").with_suffix(".pt")) + logger.info(f"Best model saved at {(model_path.parent / f'{best_model_path.stem}_backbone').with_suffix('.pt')}") # save head separately torch.save(model.head.state_dict(), (model_path.parent / f"{best_model_path.stem}_head").with_suffix(".pth")) - logger.info(f"Detection head saved at {str(best_model_path.with_suffix('.pth'))}") + logger.info(f"Detection head saved at {(model_path.parent / f'{best_model_path.stem}_head').with_suffix('.pth')}") return torch.save(model.state_dict(), model_path.with_suffix(".pth")) logger.debug(f"PyTorch model saved at {str(model_path.with_suffix('.pth'))}") @@ -182,10 +184,10 @@ def save_checkpoint(self, epoch: int): # fx backbone save_graphmodule(model.backbone, (model_path.parent / f"{best_model_path.stem}_backbone_fx").with_suffix(".pt")) - logger.info(f"PyTorch FX model tracing and 
saved at {str(best_model_path.with_suffix('.pt'))}") + logger.info(f"PyTorch FX model tracing and saved at {(model_path.parent / f'{best_model_path.stem}_backbone_fx').with_suffix('.pt')}") # save head separately torch.save(model.head.state_dict(), (model_path.parent / f"{best_model_path.stem}_head").with_suffix(".pth")) - logger.info(f"Detection head saved at {str(best_model_path.with_suffix('.pth'))}") + logger.info(f"Detection head saved at {(model_path.parent / f'{best_model_path.stem}_head').with_suffix('.pth')}") except Exception as e: logger.error(e) pass From 02c34d17ded8f2e685eebfe13526daf4cbd93086 Mon Sep 17 00:00:00 2001 From: Junho Shin Date: Fri, 13 Oct 2023 17:55:51 +0900 Subject: [PATCH 09/11] Ruff fix --- src/netspresso_trainer/pipelines/detection.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/netspresso_trainer/pipelines/detection.py b/src/netspresso_trainer/pipelines/detection.py index b5fbe7041..9c158fd8d 100644 --- a/src/netspresso_trainer/pipelines/detection.py +++ b/src/netspresso_trainer/pipelines/detection.py @@ -7,11 +7,11 @@ import torch from omegaconf import OmegaConf +from ..models import build_model from ..models.utils import DetectionModelOutput, load_from_checkpoint -from .base import BasePipeline from ..utils.fx import save_graphmodule from ..utils.onnx import save_onnx -from ..models import build_model +from .base import BasePipeline logger = logging.getLogger("netspresso_trainer") @@ -76,8 +76,8 @@ def train_step(self, batch): def valid_step(self, batch): self.model.eval() images, labels, bboxes = batch['pixel_values'], batch['label'], batch['bbox'] - bboxes = [b.to(self.devices) for b in bboxes] - labels = [l.to(self.devices) for l in labels] + bboxes = [bbox.to(self.devices) for bbox in bboxes] + labels = [label.to(self.devices) for label in labels] images = images.to(self.devices) targets = [{"boxes": box, "labels": label} for box, label in zip(bboxes, labels)] @@ -117,17 +117,17 @@ def test_step(self, batch): return results def get_metric_with_all_outputs(self, outputs): - pred = list() - targets = list() + pred = [] + targets = [] for output_batch in outputs: for detection, class_idx in output_batch['target']: - target_on_image = dict() + target_on_image = {} target_on_image['boxes'] = detection target_on_image['labels'] = class_idx targets.append(target_on_image) for detection, class_idx in output_batch['pred']: - pred_on_image = dict() + pred_on_image = {} pred_on_image['post_boxes'] = detection[..., :4] pred_on_image['post_scores'] = detection[..., -1] pred_on_image['post_labels'] = class_idx From 61353325ac879e971fae4597a4fa50abb91619ce Mon Sep 17 00:00:00 2001 From: Junho Shin Date: Fri, 13 Oct 2023 18:09:31 +0900 Subject: [PATCH 10/11] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b1871ef6..c67408044 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ No changes to highlight. 
## Other Changes: -- Fix Faster R-CNN detection head by `@illian01` in [PR 184](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/184) +- Fix Faster R-CNN detection head by `@illian01` in [PR 184](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/184), [PR 194](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/194) - Refactoring models/op module by `@illian01` in [PR 189](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/189), [PR 190](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/190) - Release NetsPresso Trainer colab tutorial `@illian01` in [PR 191](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/191) From 3dd140673dc91ebe380a5dfe3cd4880b85ddb636 Mon Sep 17 00:00:00 2001 From: Junho Shin Date: Mon, 16 Oct 2023 11:24:28 +0900 Subject: [PATCH 11/11] Fix metric_factory __call__ to calc --- src/netspresso_trainer/pipelines/detection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/netspresso_trainer/pipelines/detection.py b/src/netspresso_trainer/pipelines/detection.py index 9c158fd8d..8ff1a522c 100644 --- a/src/netspresso_trainer/pipelines/detection.py +++ b/src/netspresso_trainer/pipelines/detection.py @@ -132,7 +132,7 @@ def get_metric_with_all_outputs(self, outputs): pred_on_image['post_scores'] = detection[..., -1] pred_on_image['post_labels'] = class_idx pred.append(pred_on_image) - self.metric_factory(pred, target=targets, phase='valid') + self.metric_factory.calc(pred, target=targets, phase='valid') def save_checkpoint(self, epoch: int):
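Note (not part of the patches above): after PATCH 05-08, the best detection checkpoint is split into two artifacts — an FX-traced backbone stored as a whole module (*_best_backbone_fx.pt) and the Faster R-CNN head stored as a plain state_dict (*_best_head.pth) — which the trainer re-assembles through build_model() and load_from_checkpoint(model.head, conf.model.head_checkpoint) when is_graphmodule_training is set. The following is a minimal sketch of reloading those artifacts outside the trainer; the result directory and file names are assumptions that only follow the naming pattern used in save_checkpoint(), not paths produced by an actual run.

from pathlib import Path

import torch

# Assumed output location and names; adjust to your own result directory
# and to the task/model_name prefix used by the trainer.
result_dir = Path("outputs/detection_efficientformer")
backbone_path = result_dir / "detection_efficientformer_best_backbone_fx.pt"
head_path = result_dir / "detection_efficientformer_best_head.pth"

# The backbone was exported with save_graphmodule(); assuming that helper
# stores the traced torch.fx GraphModule via torch.save (as the .pt
# extension suggests), the whole module can be loaded back directly.
backbone = torch.load(backbone_path, map_location="cpu")

# The head was saved as a state_dict only, so it has to be loaded into a
# head built from the same model config (the trainer does this with
# build_model() followed by load_from_checkpoint()).
head_state_dict = torch.load(head_path, map_location="cpu")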