Skip to content

Commit

Permalink
added torch.no_grad to estimator predict_proba method
Browse files Browse the repository at this point in the history
  • Loading branch information
victormvy committed Feb 26, 2024
1 parent 93705fa commit ef5dfdd
Showing 1 changed file with 27 additions and 27 deletions.
54 changes: 27 additions & 27 deletions dlordinal/estimator/pytorch_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,30 +134,29 @@ def predict_proba(self, X: Union[DataLoader, torch.Tensor]):
if X is None:
raise ValueError("X must be a DataLoader or a torch Tensor")

else:
# check if X is a DataLoader
if isinstance(X, DataLoader):
print("Predicting ...")
self.model.eval()
predictions = []

# Iterate over batches
for _, (X_batch, _) in enumerate(X):
predictions_batch = self._predict_proba(X_batch)
predictions.append(predictions_batch)

# Concatenate predictions
predictions = torch.cat(predictions)
return predictions

# check if X is a torch Tensor
elif isinstance(X, torch.Tensor):
print("Predicting ...")
self.model.eval()
return self._predict_proba(X)
# check if X is a DataLoader
if isinstance(X, DataLoader):
print("Predicting ...")
self.model.eval()
predictions = []

else:
raise ValueError("X must be a DataLoader or a torch Tensor")
# Iterate over batches
for _, (X_batch, _) in enumerate(X):
predictions_batch = self._predict_proba(X_batch)
predictions.append(predictions_batch)

# Concatenate predictions
predictions = torch.cat(predictions)
return predictions

# check if X is a torch Tensor
elif isinstance(X, torch.Tensor):
print("Predicting ...")
self.model.eval()
return self._predict_proba(X)

else:
raise ValueError("X must be a DataLoader or a torch Tensor")

def _predict_proba(self, X):
"""
Expand All @@ -169,10 +168,11 @@ def _predict_proba(self, X):
X : torch.Tensor
The data to predict.
"""
X = X.to(self.device)
pred = self.model(X)
probabilities = F.softmax(pred, dim=1)
return probabilities
with torch.no_grad():
X = X.to(self.device)
pred = self.model(X)
probabilities = F.softmax(pred, dim=1)
return probabilities

def predict(self, X: Union[DataLoader, torch.Tensor]):
"""
Expand Down

0 comments on commit ef5dfdd

Please sign in to comment.