
Commit

Jorgedavyd committed Jul 10, 2024
2 parents 0c8942b + dd4fd8e commit 21699bb
Showing 4 changed files with 176 additions and 18 deletions.
lightorch/nn/sequential/residual.py (4 changes: 1 addition & 3 deletions)
@@ -4,9 +4,7 @@


 class _Residual(nn.Module):
-    def __init__(
-        self, module: nn.Module, n_layers: int
-    ):
+    def __init__(self, module: nn.Module, n_layers: int):
         super().__init__()
         self.model = nn.ModuleList([module for _ in range(n_layers)])

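For orientation, the constructor above only builds the repeated stack; _Residual's forward is outside this hunk. Below is a minimal, self-contained sketch of how such a wrapper is commonly used, with an additive skip connection assumed for illustration; the class name ResidualSketch and the forward logic are hypothetical, not taken from this commit:

import torch
from torch import nn

class ResidualSketch(nn.Module):
    # Hypothetical illustration only: repeats one block n_layers times, exactly as
    # in the diff above, and assumes an additive skip around each application.
    def __init__(self, module: nn.Module, n_layers: int):
        super().__init__()
        # Note: the list comprehension repeats the same module instance,
        # so the stacked layers share weights in this sketch.
        self.model = nn.ModuleList([module for _ in range(n_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.model:
            x = x + layer(x)  # assumed residual connection
        return x

block = nn.Sequential(nn.Linear(32, 32), nn.GELU())
out = ResidualSketch(block, n_layers=4)(torch.randn(8, 32))  # -> shape (8, 32)
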
lightorch/nn/transformer/transformer.py (12 changes: 6 additions & 6 deletions)
@@ -78,7 +78,9 @@ def __init__(
         fc: Optional[nn.Module] = None,
         n_layers: int = 1,
     ) -> None:
-        assert (encoder is not None or decoder is not None), "Not valid parameters, must be at least one encoder or decoder."
+        assert (
+            encoder is not None or decoder is not None
+        ), "Not valid parameters, must be at least one encoder or decoder."
         super().__init__()
         self.embedding = embedding_layer
         self.pe = positional_encoding
@@ -133,8 +135,8 @@ def __init__(

     def _single_forward(
         self,
-        cell_1:TransformerCell,
-        cell_2:TransformerCell,
+        cell_1: TransformerCell,
+        cell_2: TransformerCell,
         head_1: Tensor,
         head_2: Tensor,
     ) -> Tuple[Tensor, Tensor]:
@@ -149,9 +151,7 @@ def _single_forward(

         return out0, out1

-    def forward(
-        self, head_1: Tensor, head_2: Tensor
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, head_1: Tensor, head_2: Tensor) -> Tuple[Tensor, Tensor]:
         for cell_1, cell_2 in zip(self.cell_1, self.cell_2):
             head_1, head_2 = self._single_forward(cell_1, cell_2, head_1, head_2)

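The forward loop above walks self.cell_1 and self.cell_2 in lockstep, letting each pair of cells exchange the two streams head_1 and head_2. A self-contained sketch of that two-stream pattern follows; TransformerCell's internals are not visible in this diff, so CellSketch below is a hypothetical stand-in (a simple concatenate-and-project step), not the library's implementation:

import torch
from torch import nn, Tensor
from typing import Tuple

class CellSketch(nn.Module):
    # Hypothetical stand-in for TransformerCell: maps (own stream, other stream) -> own stream.
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(2 * dim, dim)

    def forward(self, own: Tensor, other: Tensor) -> Tensor:
        return self.proj(torch.cat([own, other], dim=-1))

class TwoStreamSketch(nn.Module):
    def __init__(self, dim: int, n_layers: int):
        super().__init__()
        self.cell_1 = nn.ModuleList([CellSketch(dim) for _ in range(n_layers)])
        self.cell_2 = nn.ModuleList([CellSketch(dim) for _ in range(n_layers)])

    def _single_forward(
        self, cell_1: CellSketch, cell_2: CellSketch, head_1: Tensor, head_2: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Each cell updates its own stream while reading the other; the real cell
        # presumably uses attention here rather than concatenation.
        return cell_1(head_1, head_2), cell_2(head_2, head_1)

    def forward(self, head_1: Tensor, head_2: Tensor) -> Tuple[Tensor, Tensor]:
        # Same lockstep iteration as in the hunk above.
        for cell_1, cell_2 in zip(self.cell_1, self.cell_2):
            head_1, head_2 = self._single_forward(cell_1, cell_2, head_1, head_2)
        return head_1, head_2

y1, y2 = TwoStreamSketch(dim=32, n_layers=2)(torch.randn(4, 16, 32), torch.randn(4, 16, 32))
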
lightorch/training/supervised.py (22 changes: 13 additions & 9 deletions)
@@ -38,15 +38,17 @@ def interval(algo: LRScheduler) -> str:
     else:
         return "epoch"

-class Module(LightningModule):
-    """
-    init: triggers: Dict[str, Dict[str, float]] -> This is an interpretative implementation for grouped optimization where the parameters are stored in groups given a "trigger", namely, as trigger parameters you can put a string describing the beginning of the parameters to optimize in a group.
-    optimizer: str | Optimizer -> Name of the optimizer or an Optimizer instance.
-    scheduler: str | LRScheduler -> Name of the scheduler or a Scheduler instance.
-    scheduler_kwargs: Dict[str, Any] -> Arguments of the scheduler.
-    gradient_clip_algorithm: str -> Gradient clip algorithm [value, norm].
-    gradient_clip_val: float -> Clipping value.
+
+class Module(LightningModule):
+    """
+    init: triggers: Dict[str, Dict[str, float]] -> This is an interpretative implementation for grouped optimization where the parameters are stored in groups given a "trigger", namely, as trigger parameters you can put a string describing the beginning of the parameters to optimize in a group.
+    optimizer: str | Optimizer -> Name of the optimizer or an Optimizer instance.
+    scheduler: str | LRScheduler -> Name of the scheduler or a Scheduler instance.
+    scheduler_kwargs: Dict[str, Any] -> Arguments of the scheduler.
+    gradient_clip_algorithm: str -> Gradient clip algorithm [value, norm].
+    gradient_clip_val: float -> Clipping value.
+    """

     def __init__(
         self,
         *,
@@ -189,7 +191,9 @@ def _configure_scheduler(self, optimizer: Optimizer) -> LRScheduler:
         else:
             return self.scheduler(optimizer)

-    def configure_optimizers(self) -> Dict[str, Union[Optimizer, Dict[str, Union[float, int, LRScheduler]]]]:
+    def configure_optimizers(
+        self,
+    ) -> Dict[str, Union[Optimizer, Dict[str, Union[float, int, LRScheduler]]]]:
         optimizer = self._configure_optimizer()
         if self.scheduler is not None:
             scheduler = self._configure_scheduler(optimizer)
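
The docstring in the first hunk names the keyword arguments this Module expects. As a rough usage sketch only: the argument names come from that docstring, but the specific string values ("adam", "onecycle"), the idea of forwarding everything through super().__init__, and the example network are assumptions, not code from this commit:

from torch import nn, Tensor
from lightorch.training.supervised import Module

class AutoencoderTask(Module):
    # Hypothetical subclass; only the keyword names are taken from the docstring above.
    def __init__(self) -> None:
        super().__init__(
            optimizer="adam",                # name of the optimizer (or an Optimizer instance)
            scheduler="onecycle",            # name of the scheduler (or an LRScheduler instance)
            scheduler_kwargs={"max_lr": 1e-3, "total_steps": 10_000},
            triggers={"encoder": {"lr": 1e-3}, "decoder": {"lr": 1e-4}},  # per-group settings keyed by parameter-name prefix
            gradient_clip_algorithm="norm",  # [value, norm]
            gradient_clip_val=1.0,
        )
        self.encoder = nn.Linear(64, 16)
        self.decoder = nn.Linear(16, 64)

    def forward(self, x: Tensor) -> Tensor:
        return self.decoder(self.encoder(x))

The return annotation in the second hunk presumably corresponds to the usual Lightning convention of returning the optimizer together with an lr_scheduler entry, with its interval chosen by the interval() helper shown above.
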
requirements.txt (156 changes: 156 additions & 0 deletions)
@@ -0,0 +1,156 @@
#
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
# pip-compile ./requirements.in
#
aiohttp==3.9.5
# via fsspec
aiosignal==1.3.1
# via aiohttp
alembic==1.13.2
# via optuna
async-timeout==4.0.3
# via aiohttp
attrs==23.2.0
# via aiohttp
colorlog==6.8.2
# via optuna
einops==0.8.0
# via -r ./requirements.in
filelock==3.15.4
# via
# torch
# triton
frozenlist==1.4.1
# via
# aiohttp
# aiosignal
fsspec[http]==2024.6.1
# via
# lightning
# pytorch-lightning
# torch
greenlet==3.0.3
# via sqlalchemy
idna==3.7
# via yarl
jinja2==3.1.4
# via torch
lightning==2.3.3
# via -r ./requirements.in
lightning-utilities==0.11.3.post0
# via
# lightning
# pytorch-lightning
# torchmetrics
mako==1.3.5
# via alembic
markupsafe==2.1.5
# via
# jinja2
# mako
mpmath==1.3.0
# via sympy
multidict==6.0.5
# via
# aiohttp
# yarl
networkx==3.3
# via torch
numpy==2.0.0
# via
# lightning
# optuna
# pytorch-lightning
# torchmetrics
# torchvision
nvidia-cublas-cu12==12.1.3.1
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via
# nvidia-cusolver-cu12
# torch
nvidia-nccl-cu12==2.20.5
# via torch
nvidia-nvjitlink-cu12==12.5.82
# via
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
optuna==3.6.1
# via -r ./requirements.in
packaging==24.1
# via
# lightning
# lightning-utilities
# optuna
# pytorch-lightning
# torchmetrics
pillow==10.4.0
# via torchvision
pytorch-lightning==2.3.3
# via lightning
pyyaml==6.0.1
# via
# lightning
# optuna
# pytorch-lightning
sqlalchemy==2.0.31
# via
# alembic
# optuna
sympy==1.13.0
# via torch
torch==2.3.1
# via
# -r ./requirements.in
# lightning
# pytorch-lightning
# torchmetrics
# torchvision
torchmetrics==1.4.0.post0
# via
# lightning
# pytorch-lightning
torchvision==0.18.1
# via -r ./requirements.in
tqdm==4.66.4
# via
# -r ./requirements.in
# lightning
# optuna
# pytorch-lightning
triton==2.3.1
# via torch
typing-extensions==4.12.2
# via
# alembic
# lightning
# lightning-utilities
# pytorch-lightning
# sqlalchemy
# torch
yarl==1.9.4
# via aiohttp

# The following packages are considered to be unsafe in a requirements file:
# setuptools
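
The "# via -r ./requirements.in" markers above identify the direct dependencies, so the requirements.in behind this lockfile presumably lists just the following top-level packages (a reconstructed sketch; exact pins, ordering, and comments are assumptions):

# ./requirements.in (hypothetical reconstruction from the "via" annotations)
einops
lightning
optuna
torch
torchvision
tqdm

Re-running the command from the header (pip-compile ./requirements.in, under Python 3.10) would regenerate the pinned requirements.txt.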
