diff --git a/ignite/metrics/vision/object_detection_average_precision_recall.py b/ignite/metrics/vision/object_detection_average_precision_recall.py index 595cef6eb08..a56a18c164f 100644 --- a/ignite/metrics/vision/object_detection_average_precision_recall.py +++ b/ignite/metrics/vision/object_detection_average_precision_recall.py @@ -124,7 +124,7 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo rec_thresholds=rec_thresholds, class_mean=None, ) - precision = torch.double if torch.device(device) != torch.device("mps") else torch.float32 + precision = torch.double if torch.device(device).type != "mps" else torch.float32 self.rec_thresholds = self.rec_thresholds.to(device=device, dtype=precision) @reinit__is_reduced @@ -207,12 +207,12 @@ def _compute_recall_and_precision( indices = torch.argsort(scores, dim=-1, stable=True, descending=True) tp = TP[..., indices] tp_summation = tp.cumsum(dim=-1) - if tp_summation.device != torch.device("mps"): + if tp_summation.device.type != "mps": tp_summation = tp_summation.double() fp = FP[..., indices] fp_summation = fp.cumsum(dim=-1) - if fp_summation.device != torch.device("mps"): + if fp_summation.device.type != "mps": fp_summation = fp_summation.double() recall = tp_summation / y_true_count @@ -342,7 +342,7 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor ) scores = pred["scores"][max_best_detections_index] - if self._device == torch.device("mps") and scores.dtype == torch.double: + if self._device.type == "mps" and scores.dtype == torch.double: scores = scores.to(dtype=torch.float32) self._scores.append(scores.to(self._device)) self._y_pred_labels.append(pred_labels.to(dtype=torch.int, device=self._device)) @@ -352,7 +352,7 @@ def _compute(self) -> torch.Tensor: pred_labels = _cat_and_agg_tensors(self._y_pred_labels, cast(Tuple[int], ()), torch.int, self._device) TP = _cat_and_agg_tensors(self._tps, (len(self._iou_thresholds),), torch.uint8, self._device) FP = _cat_and_agg_tensors(self._fps, (len(self._iou_thresholds),), torch.uint8, self._device) - fp_precision = torch.double if self._device != torch.device("mps") else torch.float32 + fp_precision = torch.double if self._device.type != "mps" else torch.float32 scores = _cat_and_agg_tensors(self._scores, cast(Tuple[int], ()), fp_precision, self._device) average_precisions_recalls = -torch.ones(