From 248fe890d35d4c4c5efc398e2ad9f76d82d28258 Mon Sep 17 00:00:00 2001 From: Sadra Barikbin Date: Wed, 4 Sep 2024 02:54:24 +0330 Subject: [PATCH] Use if check on torch.double usages for MPS backend --- ignite/metrics/mean_average_precision.py | 12 ++++++---- ...ject_detection_average_precision_recall.py | 24 ++++++++++++++----- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index 71896512a73..03f84cc8c6c 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -102,7 +102,7 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens ).where(rec_thresh_indices != recall.size(-1), 0) recall = rec_thresholds recall_differential = recall.diff( - dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=recall.device, dtype=torch.double) + dim=-1, prepend=torch.zeros((*recall.shape[:-1], 1), device=recall.device, dtype=recall.dtype) ) return torch.sum(recall_differential * precision, dim=-1) @@ -327,7 +327,9 @@ def _compute_recall_and_precision( `(recall, precision)` """ indices = torch.argsort(y_pred, stable=True, descending=True) - tp_summation = y_true[indices].cumsum(dim=0).double() + tp_summation = y_true[indices].cumsum(dim=0) + if tp_summation.device != torch.device("mps"): + tp_summation = tp_summation.double() # Adopted from Scikit-learn's implementation unique_scores_indices = torch.nonzero( @@ -360,8 +362,8 @@ def compute(self) -> Union[torch.Tensor, float]: torch.long if self._type == "multiclass" else torch.uint8, self._device, ) - - y_pred = _cat_and_agg_tensors(self._y_pred, (num_classes,), torch.double, self._device) + fp_precision = torch.double if self._device != torch.device("mps") else torch.float32 + y_pred = _cat_and_agg_tensors(self._y_pred, (num_classes,), fp_precision, self._device) if self._type == "multiclass": y_true = to_onehot(y_true, num_classes=num_classes).T @@ -369,7 +371,7 @@ def compute(self) -> Union[torch.Tensor, float]: y_true = y_true.reshape(1, -1) y_pred = y_pred.view(1, -1) y_true_positive_count = y_true.sum(dim=-1) - average_precisions = torch.zeros_like(y_true_positive_count, device=self._device, dtype=torch.double) + average_precisions = torch.zeros_like(y_true_positive_count, device=self._device, dtype=fp_precision) for cls in range(y_true_positive_count.size(0)): recall, precision = self._compute_recall_and_precision(y_true[cls], y_pred[cls], y_true_positive_count[cls]) average_precisions[cls] = self._compute_average_precision(recall, precision) diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index 0ae07e924a5..f6111af1d1e 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -104,13 +104,19 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo except ImportError: raise ModuleNotFoundError("This metric requires torchvision to be installed.") + precision = torch.double if not torch.device(device) != torch.device("mps") else torch.float32 + if iou_thresholds is None: - iou_thresholds = torch.linspace(0.5, 0.95, 10, dtype=torch.double) + iou_thresholds = torch.linspace(0.5, 0.95, 10, device=device, dtype=precision) self._iou_thresholds = self._setup_thresholds(iou_thresholds, "iou_thresholds") + self._iou_thresholds = self._iou_thresholds.to(device=device, dtype=precision) if rec_thresholds is None: - rec_thresholds = torch.linspace(0, 1, 101, device=device, dtype=torch.double) + rec_thresholds = torch.linspace(0, 1, 101, device=device, dtype=precision) + + self._rec_thresholds = self._setup_thresholds(rec_thresholds, "rec_thresholds") + self._rec_thresholds = self._rec_thresholds.to(device=device, dtype=precision) self._num_classes = num_classes self._area_range = area_range @@ -204,9 +210,14 @@ def _compute_recall_and_precision( """ indices = torch.argsort(scores, dim=-1, stable=True, descending=True) tp = TP[..., indices] - tp_summation = tp.cumsum(dim=-1).double() + tp_summation = tp.cumsum(dim=-1) + if tp_summation.device != torch.device("mps"): + tp_summation = tp_summation.double() + fp = FP[..., indices] - fp_summation = fp.cumsum(dim=-1).double() + fp_summation = fp.cumsum(dim=-1) + if fp_summation.device != torch.device("mps"): + fp_summation = fp_summation.double() recall = tp_summation / y_true_count predicted_positive = tp_summation + fp_summation @@ -342,12 +353,13 @@ def _compute(self) -> torch.Tensor: pred_labels = _cat_and_agg_tensors(self._y_pred_labels, cast(Tuple[int], ()), torch.long, self._device) TP = _cat_and_agg_tensors(self._tps, (len(self._iou_thresholds),), torch.uint8, self._device) FP = _cat_and_agg_tensors(self._fps, (len(self._iou_thresholds),), torch.uint8, self._device) - scores = _cat_and_agg_tensors(self._scores, cast(Tuple[int], ()), torch.double, self._device) + fp_precision = torch.double if self._device != torch.device("mps") else torch.float32 + scores = _cat_and_agg_tensors(self._scores, cast(Tuple[int], ()), fp_precision, self._device) average_precisions_recalls = -torch.ones( (2, self._num_classes, len(self._iou_thresholds)), device=self._device, - dtype=torch.double, + dtype=fp_precision, ) for cls in range(self._num_classes): if self._y_true_count[cls] == 0: