Skip to content

Commit

Permalink
Fix a bug related to MPS
Browse files Browse the repository at this point in the history
  • Loading branch information
sadra-barikbin committed Sep 4, 2024
1 parent 085e0df commit 3658f95
Showing 1 changed file with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo
rec_thresholds=rec_thresholds,
class_mean=None,
)
precision = torch.double if torch.device(device) != torch.device("mps") else torch.float32
precision = torch.double if torch.device(device).type != "mps" else torch.float32
self.rec_thresholds = self.rec_thresholds.to(device=device, dtype=precision)

@reinit__is_reduced
Expand Down Expand Up @@ -207,12 +207,12 @@ def _compute_recall_and_precision(
indices = torch.argsort(scores, dim=-1, stable=True, descending=True)
tp = TP[..., indices]
tp_summation = tp.cumsum(dim=-1)
if tp_summation.device != torch.device("mps"):
if tp_summation.device.type != "mps":
tp_summation = tp_summation.double()

fp = FP[..., indices]
fp_summation = fp.cumsum(dim=-1)
if fp_summation.device != torch.device("mps"):
if fp_summation.device.type != "mps":
fp_summation = fp_summation.double()

recall = tp_summation / y_true_count
Expand Down Expand Up @@ -342,7 +342,7 @@ def update(self, output: Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, tor
)

scores = pred["scores"][max_best_detections_index]
if self._device == torch.device("mps") and scores.dtype == torch.double:
if self._device.type == "mps" and scores.dtype == torch.double:
scores = scores.to(dtype=torch.float32)
self._scores.append(scores.to(self._device))
self._y_pred_labels.append(pred_labels.to(dtype=torch.int, device=self._device))
Expand All @@ -352,7 +352,7 @@ def _compute(self) -> torch.Tensor:
pred_labels = _cat_and_agg_tensors(self._y_pred_labels, cast(Tuple[int], ()), torch.int, self._device)
TP = _cat_and_agg_tensors(self._tps, (len(self._iou_thresholds),), torch.uint8, self._device)
FP = _cat_and_agg_tensors(self._fps, (len(self._iou_thresholds),), torch.uint8, self._device)
fp_precision = torch.double if self._device != torch.device("mps") else torch.float32
fp_precision = torch.double if self._device.type != "mps" else torch.float32
scores = _cat_and_agg_tensors(self._scores, cast(Tuple[int], ()), fp_precision, self._device)

average_precisions_recalls = -torch.ones(
Expand Down

0 comments on commit 3658f95

Please sign in to comment.