Skip to content

Commit

Permalink
added verbose parameter to PytorchEstimator and ability to print loss…
Browse files Browse the repository at this point in the history
… per epoch
  • Loading branch information
victormvy committed Mar 26, 2024
1 parent 4397709 commit 2ba8acf
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions dlordinal/estimator/pytorch_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ class PytorchEstimator(BaseEstimator):
A Pytorch device.
max_iter : int
The maximum number of iterations.
verbose : int, default=0
Verbosity level.
If 0, no output is printed.
If 1, a message is printed at the beginning of the training/prediction.
If 2, the epoch progress is printed.
If 3, the loss is also printed.
**kwargs : dict
Additional keyword arguments.
"""
Expand All @@ -34,6 +40,7 @@ def __init__(
optimizer: torch.optim.Optimizer,
device: torch.device,
max_iter: int,
verbose: int = 0,
**kwargs,
):
self.kwargs = kwargs
Expand All @@ -42,6 +49,7 @@ def __init__(
self.optimizer = optimizer
self.device = device
self.max_iter = max_iter
self.verbose = verbose

def fit(
self,
Expand All @@ -59,19 +67,26 @@ def fit(
The training labels, only used if X is a ``torch.Tensor``.
"""

if self.verbose >= 1:
print("Training ...")

# Check if X is a DataLoader
if isinstance(X, DataLoader):
if y is None:
print("Training ...")
self.model.train()

# Iterate over epochs
for epoch in range(self.max_iter):
print(f"Epoch {epoch+1}/{self.max_iter}")
if self.verbose >= 2:
print(f"Epoch {epoch+1}/{self.max_iter}")

# Iterate over batches
loss = 0
for _, (X_batch, y_batch) in enumerate(X):
self._fit(X_batch, y_batch)
loss += self._fit(X_batch, y_batch)
loss /= len(X)
if self.verbose >= 3:
print(f"Loss: {loss}")

else:
raise ValueError("If X is a DataLoader, y must be None")
Expand All @@ -83,13 +98,14 @@ def fit(

# Check if y is a torch Tensor
elif isinstance(y, torch.Tensor):
print("Training ...")
self.model.train()

# Iterate over epochs
for epoch in range(self.max_iter):
print(f"Epoch {epoch+1}/{self.max_iter}")
self._fit(X, y)
if self.verbose >= 2:
print(f"Epoch {epoch+1}/{self.max_iter}")
loss = self._fit(X, y)
print(f"Loss: {loss}")

else:
raise ValueError("y must be a torch.Tensor")
Expand Down Expand Up @@ -122,6 +138,8 @@ def _fit(self, X, y):
loss.backward()
self.optimizer.step()

return loss.item()

def predict_proba(self, X: Union[DataLoader, torch.Tensor]):
"""
predict_proba() is a method that predicts the probability of each class.
Expand All @@ -131,12 +149,14 @@ def predict_proba(self, X: Union[DataLoader, torch.Tensor]):
X : Union[DataLoader, torch.Tensor]
The data to predict.
"""
if self.verbose >= 1:
print("Predicting ...")

if X is None:
raise ValueError("X must be a DataLoader or a torch Tensor")

# check if X is a DataLoader
if isinstance(X, DataLoader):
print("Predicting ...")
self.model.eval()
predictions = []

Expand All @@ -151,7 +171,6 @@ def predict_proba(self, X: Union[DataLoader, torch.Tensor]):

# check if X is a torch Tensor
elif isinstance(X, torch.Tensor):
print("Predicting ...")
self.model.eval()
return self._predict_proba(X)

Expand Down

0 comments on commit 2ba8acf

Please sign in to comment.