diff --git a/lightorch/nn/sequential/residual.py b/lightorch/nn/sequential/residual.py
index 089522e..38a42b7 100644
--- a/lightorch/nn/sequential/residual.py
+++ b/lightorch/nn/sequential/residual.py
@@ -4,9 +4,7 @@
 class _Residual(nn.Module):
-    def __init__(
-        self, module: nn.Module, n_layers: int
-    ):
+    def __init__(self, module: nn.Module, n_layers: int):
         super().__init__()
         self.model = nn.ModuleList([module for _ in range(n_layers)])
diff --git a/lightorch/nn/transformer/transformer.py b/lightorch/nn/transformer/transformer.py
index 1142b54..f2e93e8 100644
--- a/lightorch/nn/transformer/transformer.py
+++ b/lightorch/nn/transformer/transformer.py
@@ -78,7 +78,9 @@ def __init__(
         fc: Optional[nn.Module] = None,
         n_layers: int = 1,
     ) -> None:
-        assert (encoder is not None or decoder is not None), "Not valid parameters, must be at least one encoder or decoder."
+        assert (
+            encoder is not None or decoder is not None
+        ), "Not valid parameters, must be at least one encoder or decoder."
         super().__init__()
         self.embedding = embedding_layer
         self.pe = positional_encoding
@@ -133,8 +135,8 @@ def __init__(
     def _single_forward(
         self,
-        cell_1:TransformerCell,
-        cell_2:TransformerCell,
+        cell_1: TransformerCell,
+        cell_2: TransformerCell,
         head_1: Tensor,
         head_2: Tensor,
     ) -> Tuple[Tensor, Tensor]:
@@ -149,9 +151,7 @@ def _single_forward(
         return out0, out1

-    def forward(
-        self, head_1: Tensor, head_2: Tensor
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, head_1: Tensor, head_2: Tensor) -> Tuple[Tensor, Tensor]:
         for cell_1, cell_2 in zip(self.cell_1, self.cell_2):
             head_1, head_2 = self._single_forward(cell_1, cell_2, head_1, head_2)
diff --git a/lightorch/training/supervised.py b/lightorch/training/supervised.py
index ce6f913..a00abde 100644
--- a/lightorch/training/supervised.py
+++ b/lightorch/training/supervised.py
@@ -38,15 +38,17 @@ def interval(algo: LRScheduler) -> str:
     else:
         return "epoch"

-class Module(LightningModule):
-    """
-    init: triggers: Dict[str, Dict[str, float]] -> This is an interpretative implementation for grouped optimization where the parameters are stored in groups given a "trigger", namely, as trigger parameters you can put a string describing the beginning of the parameters to optimize in a group.
-    optimizer: str | Optimizer -> Name of the optimizer or an Optimizer instance.
-    scheduler: str | LRScheduler -> Name of the scheduler or a Scheduler instance.
-    scheduler_kwargs: Dict[str, Any] -> Arguments of the scheduler.
-    gradient_clip_algorithm: str -> Gradient clip algorithm [value, norm].
-    gradient_clip_val: float -> Clipping value.
+
+class Module(LightningModule):
+    """
+    init: triggers: Dict[str, Dict[str, float]] -> This is an interpretative implementation for grouped optimization where the parameters are stored in groups given a "trigger", namely, as trigger parameters you can put a string describing the beginning of the parameters to optimize in a group.
+    optimizer: str | Optimizer -> Name of the optimizer or an Optimizer instance.
+    scheduler: str | LRScheduler -> Name of the scheduler or a Scheduler instance.
+    scheduler_kwargs: Dict[str, Any] -> Arguments of the scheduler.
+    gradient_clip_algorithm: str -> Gradient clip algorithm [value, norm].
+    gradient_clip_val: float -> Clipping value.
""" + def __init__( self, *, @@ -189,7 +191,9 @@ def _configure_scheduler(self, optimizer: Optimizer) -> LRScheduler: else: return self.scheduler(optimizer) - def configure_optimizers(self) -> Dict[str, Union[Optimizer, Dict[str, Union[float, int, LRScheduler]]]]: + def configure_optimizers( + self, + ) -> Dict[str, Union[Optimizer, Dict[str, Union[float, int, LRScheduler]]]]: optimizer = self._configure_optimizer() if self.scheduler is not None: scheduler = self._configure_scheduler(optimizer) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..4e898e3 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,156 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile ./requirements.in +# +aiohttp==3.9.5 + # via fsspec +aiosignal==1.3.1 + # via aiohttp +alembic==1.13.2 + # via optuna +async-timeout==4.0.3 + # via aiohttp +attrs==23.2.0 + # via aiohttp +colorlog==6.8.2 + # via optuna +einops==0.8.0 + # via -r ./requirements.in +filelock==3.15.4 + # via + # torch + # triton +frozenlist==1.4.1 + # via + # aiohttp + # aiosignal +fsspec[http]==2024.6.1 + # via + # lightning + # pytorch-lightning + # torch +greenlet==3.0.3 + # via sqlalchemy +idna==3.7 + # via yarl +jinja2==3.1.4 + # via torch +lightning==2.3.3 + # via -r ./requirements.in +lightning-utilities==0.11.3.post0 + # via + # lightning + # pytorch-lightning + # torchmetrics +mako==1.3.5 + # via alembic +markupsafe==2.1.5 + # via + # jinja2 + # mako +mpmath==1.3.0 + # via sympy +multidict==6.0.5 + # via + # aiohttp + # yarl +networkx==3.3 + # via torch +numpy==2.0.0 + # via + # lightning + # optuna + # pytorch-lightning + # torchmetrics + # torchvision +nvidia-cublas-cu12==12.1.3.1 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.1.105 + # via torch +nvidia-cuda-nvrtc-cu12==12.1.105 + # via torch +nvidia-cuda-runtime-cu12==12.1.105 + # via torch +nvidia-cudnn-cu12==8.9.2.26 + # via torch +nvidia-cufft-cu12==11.0.2.54 + # via torch +nvidia-curand-cu12==10.3.2.106 + # via torch +nvidia-cusolver-cu12==11.4.5.107 + # via torch +nvidia-cusparse-cu12==12.1.0.106 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-nccl-cu12==2.20.5 + # via torch +nvidia-nvjitlink-cu12==12.5.82 + # via + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 + # via torch +optuna==3.6.1 + # via -r ./requirements.in +packaging==24.1 + # via + # lightning + # lightning-utilities + # optuna + # pytorch-lightning + # torchmetrics +pillow==10.4.0 + # via torchvision +pytorch-lightning==2.3.3 + # via lightning +pyyaml==6.0.1 + # via + # lightning + # optuna + # pytorch-lightning +sqlalchemy==2.0.31 + # via + # alembic + # optuna +sympy==1.13.0 + # via torch +torch==2.3.1 + # via + # -r ./requirements.in + # lightning + # pytorch-lightning + # torchmetrics + # torchvision +torchmetrics==1.4.0.post0 + # via + # lightning + # pytorch-lightning +torchvision==0.18.1 + # via -r ./requirements.in +tqdm==4.66.4 + # via + # -r ./requirements.in + # lightning + # optuna + # pytorch-lightning +triton==2.3.1 + # via torch +typing-extensions==4.12.2 + # via + # alembic + # lightning + # lightning-utilities + # pytorch-lightning + # sqlalchemy + # torch +yarl==1.9.4 + # via aiohttp + +# The following packages are considered to be unsafe in a requirements file: +# setuptools