
Commit

fixing: indentation error
Jorgedavyd committed Jul 10, 2024
1 parent c3bba50 commit b8bfcf6
Showing 1 changed file with 15 additions and 7 deletions.
lightorch/training/supervised.py (22 changes: 15 additions & 7 deletions)
@@ -1,5 +1,5 @@
from lightning.pytorch import LightningModule
-from typing import Union, Sequence, Any, Tuple, Dict
+from typing import Union, Sequence, Any, Dict, List
from torch import Tensor, nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
@@ -38,7 +38,15 @@ def interval(algo: LRScheduler) -> str:
else:
return "epoch"

-class Module(LightningModule): """ init: triggers: Dict[str, Dict[str, float]] -> This is an interpretative implementation for grouped optimization where the parameters are stored in groups given a "trigger", namely, as trigger parameters you can put a string describing the beginning of the parameters to optimize in a group. optimizer: str | Optimizer -> Name of the optimizer or an Optimizer instance. scheduler: str | LRScheduler -> Name of the scheduler or a Scheduler instance. scheduler_kwargs: Dict[str, Any] -> Arguments of the scheduler. gradient_clip_algorithm: str -> Gradient clip algorithm [value, norm]. gradient_clip_val: float -> Clipping value. """
+class Module(LightningModule):
+    """
+    init: triggers: Dict[str, Dict[str, float]] -> This is an interpretative implementation for grouped optimization where the parameters are stored in groups given a "trigger", namely, as trigger parameters you can put a string describing the beginning of the parameters to optimize in a group.
+    optimizer: str | Optimizer -> Name of the optimizer or an Optimizer instance.
+    scheduler: str | LRScheduler -> Name of the scheduler or a Scheduler instance.
+    scheduler_kwargs: Dict[str, Any] -> Arguments of the scheduler.
+    gradient_clip_algorithm: str -> Gradient clip algorithm [value, norm].
+    gradient_clip_val: float -> Clipping value.
+    """
def __init__(
self,
*,
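For orientation, a minimal usage sketch of the constructor arguments described in the new docstring. The subclass, layer names, trigger strings, and hyperparameter values are hypothetical, and the string names passed for optimizer and scheduler are assumptions about how the library resolves names, not confirmed values:

from torch import nn
from lightorch.training.supervised import Module

class MyModel(Module):  # hypothetical subclass, not part of this commit
    def __init__(self):
        super().__init__(
            optimizer="adam",                # name of an optimizer or an Optimizer (assumed name)
            scheduler="onecycle",            # name of a scheduler or an LRScheduler (assumed name)
            scheduler_kwargs={"max_lr": 1e-3, "total_steps": 10_000},
            triggers={
                "encoder": {"lr": 1e-4},     # groups parameters whose names start with "encoder"
                "decoder": {"lr": 1e-3},     # groups parameters whose names start with "decoder"
            },
            gradient_clip_algorithm="norm",  # "value" or "norm"
            gradient_clip_val=1.0,
        )
        # The class also relies on self.criterion (see _compute_training_loss);
        # how the criterion is supplied is not shown in this diff.
        self.encoder = nn.Linear(32, 64)
        self.decoder = nn.Linear(64, 32)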
@@ -100,7 +108,7 @@ def __init__(
def loss_forward(self, batch: Tensor, idx: int) -> Dict[str, Union[Tensor, float]]:
raise NotImplementedError("Should have defined loss_forward method.")

-def training_step(self, batch: Tensor, idx: int) -> Tensor:
+def training_step(self, batch: Tensor, idx: int) -> Union[float, Tensor]:
kwargs = self.loss_forward(batch, idx)
return self._compute_training_loss(**kwargs)
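The contract implied here: loss_forward returns a dict of keyword arguments, training_step forwards it to _compute_training_loss, which unpacks it into self.criterion(**kwargs) and logs one value per entry in criterion.labels under the Training/ prefix. A hedged sketch of an override, with placeholder keyword names that would have to match whatever the criterion expects:

# Continuing the hypothetical subclass above (methods shown unbound for brevity).
# Type names (Tensor, Dict, Union) come from the imports at the top of this file.
def forward(self, x: Tensor) -> Tensor:
    return self.decoder(self.encoder(x))

def loss_forward(self, batch: Tensor, idx: int) -> Dict[str, Union[Tensor, float]]:
    # Assumes the dataloader yields (input, target) pairs; the keyword names
    # "input" and "target" are placeholders chosen for this sketch.
    x, y = batch
    return {"input": self(x), "target": y}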

@@ -109,7 +117,7 @@ def validation_step(self, batch: Tensor, idx: int) -> None:
kwargs = self.loss_forward(batch, idx)
return self._compute_valid_metrics(**kwargs)

-def _compute_training_loss(self, **kwargs) -> Union[Tensor, Sequence[Tensor]]:
+def _compute_training_loss(self, **kwargs) -> Union[float, Tensor]:
args = self.criterion(**kwargs)
self.log_dict(
{f"Training/{k}": v for k, v in zip(self.criterion.labels, args)},
@@ -132,12 +140,12 @@ def _compute_valid_metrics(self, **kwargs) -> None:
True,
)

-def get_param_groups(self) -> Tuple:
+def get_param_groups(self) -> List[Dict[str, Union[nn.Module, List[float]]]]:
"""
Given a list of "triggers", the param groups are defined.
"""
if self.triggers is not None:
-param_groups: Sequence[Dict[str, Sequence[nn.Module]]] = [
+param_groups: Sequence[Dict[str, List[nn.Module]]] = [
defaultdict(list) for _ in range(len(self.triggers))
]
# Update the model parameters per group and finally add the
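The trigger mechanism described in the docstring amounts to prefix matching on parameter names. A standalone sketch of that idea, not the repository's exact implementation:

from typing import Any, Dict, List
from torch import nn

def group_by_trigger(model: nn.Module, triggers: Dict[str, Dict[str, float]]) -> List[Dict[str, Any]]:
    # One optimizer param group per trigger: collect the parameters whose names
    # start with the trigger string and attach that group's hyperparameters.
    groups: List[Dict[str, Any]] = []
    for trigger, hyperparams in triggers.items():
        params = [p for name, p in model.named_parameters() if name.startswith(trigger)]
        groups.append({"params": params, **hyperparams})
    return groups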
@@ -181,7 +189,7 @@ def _configure_scheduler(self, optimizer: Optimizer) -> LRScheduler:
else:
return self.scheduler(optimizer)

-def configure_optimizers(self) -> Union[Optimizer, Sequence[Optimizer]]:
+def configure_optimizers(self) -> Dict[str, Union[Optimizer, Dict[str, Union[float, int, LRScheduler]]]]:
optimizer = self._configure_optimizer()
if self.scheduler is not None:
scheduler = self._configure_scheduler(optimizer)
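The new return annotation matches the dictionary format Lightning accepts for pairing an optimizer with a scheduler. A sketch of that shape; the interval and frequency values are illustrative, since the actual return statement is not shown in the visible part of the diff:

from typing import Any, Dict
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

def lightning_optimizer_config(optimizer: Optimizer, scheduler: LRScheduler) -> Dict[str, Any]:
    # Dictionary format accepted by LightningModule.configure_optimizers.
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "step",  # or "epoch"; the file's interval() helper chooses between the two
            "frequency": 1,
        },
    }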
