diff --git a/dlordinal/estimator/pytorch_estimator.py b/dlordinal/estimator/pytorch_estimator.py index 081d919..17e6b34 100644 --- a/dlordinal/estimator/pytorch_estimator.py +++ b/dlordinal/estimator/pytorch_estimator.py @@ -23,6 +23,12 @@ class PytorchEstimator(BaseEstimator): A Pytorch device. max_iter : int The maximum number of iterations. + verbose : int, default=0 + Verbosity level. + If 0, no output is printed. + If 1, a message is printed at the beginning of the training/prediction. + If 2, the epoch progress is printed. + If 3, the loss is also printed. **kwargs : dict Additional keyword arguments. """ @@ -34,6 +40,7 @@ def __init__( optimizer: torch.optim.Optimizer, device: torch.device, max_iter: int, + verbose: int = 0, **kwargs, ): self.kwargs = kwargs @@ -42,6 +49,7 @@ def __init__( self.optimizer = optimizer self.device = device self.max_iter = max_iter + self.verbose = verbose def fit( self, @@ -59,19 +67,26 @@ def fit( The training labels, only used if X is a ``torch.Tensor``. """ + if self.verbose >= 1: + print("Training ...") + # Check if X is a DataLoader if isinstance(X, DataLoader): if y is None: - print("Training ...") self.model.train() # Iterate over epochs for epoch in range(self.max_iter): - print(f"Epoch {epoch+1}/{self.max_iter}") + if self.verbose >= 2: + print(f"Epoch {epoch+1}/{self.max_iter}") # Iterate over batches + loss = 0 for _, (X_batch, y_batch) in enumerate(X): - self._fit(X_batch, y_batch) + loss += self._fit(X_batch, y_batch) + loss /= len(X) + if self.verbose >= 3: + print(f"Loss: {loss}") else: raise ValueError("If X is a DataLoader, y must be None") @@ -83,13 +98,14 @@ def fit( # Check if y is a torch Tensor elif isinstance(y, torch.Tensor): - print("Training ...") self.model.train() # Iterate over epochs for epoch in range(self.max_iter): - print(f"Epoch {epoch+1}/{self.max_iter}") - self._fit(X, y) + if self.verbose >= 2: + print(f"Epoch {epoch+1}/{self.max_iter}") + loss = self._fit(X, y) + print(f"Loss: {loss}") else: raise ValueError("y must be a torch.Tensor") @@ -122,6 +138,8 @@ def _fit(self, X, y): loss.backward() self.optimizer.step() + return loss.item() + def predict_proba(self, X: Union[DataLoader, torch.Tensor]): """ predict_proba() is a method that predicts the probability of each class. @@ -131,12 +149,14 @@ def predict_proba(self, X: Union[DataLoader, torch.Tensor]): X : Union[DataLoader, torch.Tensor] The data to predict. """ + if self.verbose >= 1: + print("Predicting ...") + if X is None: raise ValueError("X must be a DataLoader or a torch Tensor") # check if X is a DataLoader if isinstance(X, DataLoader): - print("Predicting ...") self.model.eval() predictions = [] @@ -151,7 +171,6 @@ def predict_proba(self, X: Union[DataLoader, torch.Tensor]): # check if X is a torch Tensor elif isinstance(X, torch.Tensor): - print("Predicting ...") self.model.eval() return self._predict_proba(X)