diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 7879a0b..62c9387 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -15,13 +15,15 @@ jobs:
        python-version: [3.8, 3.9, 3.11, 3.12]
        include:
          - python-version: 3.12
-           commit: true
+           commit: false
          - python-version: 3.8
            commit: false
          - python-version: 3.9
            commit: false
          - python-version: 3.11
            commit: false
+         - python-version: 3.10
+           commit: true

    steps:
      - name: Checkout repository
@@ -32,15 +34,9 @@ jobs:
        with:
          python-version: ${{ matrix.python-version }}

-     - name: Set up dependencies
-       run: |
-         python -m pip install --upgrade pip
-         pip install pipreqs pip-tools
-         chmod +x requirements.sh
-         sh ./requirements.sh
-
      - name: Install dependencies
        run: |
+         python -m pip install --upgrade pip
          pip install pytest black
          pip install -r requirements.txt

@@ -48,13 +44,10 @@
        run: |
          pytest tests/

-     - name: Run black
-       run: |
-         black .
-
      - name: Commit to repo
        if: matrix.commit == true
        run: |
+         black .
          git config --global user.name 'Jorgedavyd'
          git config --global user.email 'jorged.encyso@gmail.com'
          git diff --exit-code || (git add . && git commit -m "Automatically formatted with black" && git push)
diff --git a/docs/api/htuning.md b/docs/api/htuning.md
index caebeb7..f29b498 100644
--- a/docs/api/htuning.md
+++ b/docs/api/htuning.md
@@ -16,12 +16,12 @@ if __name__ == '__main__':
        model_class = FourierVAE,
        hparam_objective = objective,
        datamodule = NormalModule,
-       valid_metrics = [f"Training/{name}" for name in [
+       valid_metrics = [
            "Pixel",
            "Perceptual",
            "Style",
            "Total variance",
-           "KL Divergence"]],
+           "KL Divergence"],
        directions = ['minimize', 'minimize', 'minimize', 'minimize', 'minimize'],
        precision = 'medium',
        n_trials = 150,
diff --git a/docs/api/transformer.md b/docs/api/transformer.md
deleted file mode 100644
index 80f9bcb..0000000
--- a/docs/api/transformer.md
+++ /dev/null
@@ -1 +0,0 @@
-# Transformers
\ No newline at end of file
diff --git a/lightorch/htuning/optuna.py b/lightorch/htuning/optuna.py
index 0af44bc..8d6f69d 100644
--- a/lightorch/htuning/optuna.py
+++ b/lightorch/htuning/optuna.py
@@ -43,13 +43,20 @@ def objective(trial: optuna.trial.Trial):
        trainer.fit(model, datamodule=dataset)

        if isinstance(valid_metrics, str):
-           return trainer.callback_metrics[valid_metrics].item()
-
-       return (
-           trainer.callback_metrics[valid_metric].item()
-           for valid_metric in valid_metrics
-       )
-
+           if valid_metrics == "hp_metric":
+               return trainer.callback_metrics[valid_metrics].item()
+           return trainer.callback_metrics[f"Training/{valid_metrics}"].item()
+
+       # hp_metric is logged without the Training/ prefix; everything else with it.
+       return [
+           trainer.callback_metrics[
+               metric if metric == "hp_metric" else f"Training/{metric}"
+           ].item()
+           for metric in valid_metrics
+       ]
+
    if "precision" in kwargs:
        torch.set_float32_matmul_precision(precision)
    else:
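With the optuna.py change above, callers now pass bare metric names and the `Training/` prefix is resolved internally; only `hp_metric` is looked up verbatim. A minimal sketch of the resulting call site, reusing the names from the docs example above (FourierVAE, objective, NormalModule come from that example, not from this hunk):

```python
from lightorch.htuning.optuna import htuning

# Bare names now resolve to Training/Pixel, Training/Perceptual, ...;
# "hp_metric" would be looked up without the prefix.
htuning(
    model_class=FourierVAE,      # from the docs example above
    hparam_objective=objective,  # from the docs example above
    datamodule=NormalModule,     # from the docs example above
    valid_metrics=["Pixel", "Perceptual", "Style", "Total variance", "KL Divergence"],
    directions=["minimize"] * 5,
    precision="medium",
    n_trials=150,
)
```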
diff --git a/lightorch/nn/criterions.py b/lightorch/nn/criterions.py
index 719941d..f6890c7 100644
--- a/lightorch/nn/criterions.py
+++ b/lightorch/nn/criterions.py
@@ -27,13 +27,11 @@ def __init__(

 class Loss(LighTorchLoss):
     def __init__(self, *loss) -> None:
+        assert len(set(map(type, loss))) == len(loss), "Each loss term must be an instance of a distinct class."
         super().__init__(
             list(set([*chain.from_iterable([i.labels for i in loss])])),
             _merge_dicts([i.factors for i in loss]),
         )
-        assert len(loss) == len(
-            self.factors
-        ), "Must have the same length of losses as factors"
         self.loss = loss

     def forward(self, **kwargs) -> Tuple[Tensor, ...]:
@@ -41,9 +39,9 @@ def forward(self, **kwargs) -> Tuple[Tensor, ...]:
         out_list = []

         for loss in self.loss:
-            *loss_arg, out_ = loss(**kwargs)
-            out_list.extend(list(*loss_arg))
-            loss_ += out_
+            args = loss(**kwargs)
+            out_list.extend(args[:-1])
+            loss_ += args[-1]

         out_list.append(loss_)
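Given the new distinct-type assert and the `args[:-1]` / `args[-1]` split, a minimal sketch of composing and driving a criterion. The `(1, 1.0)` arguments to `PeakSignalNoiseRatio` follow the `(max value, factor)` pattern of the old test code and are an assumption, as are the `input`/`target` keyword names:

```python
import torch
from lightorch.nn.criterions import Loss, MSELoss, PeakSignalNoiseRatio

# Two distinct classes satisfy the new assert; Loss(MSELoss(), MSELoss())
# would now be rejected because both terms share one type.
criterion = Loss(MSELoss(), PeakSignalNoiseRatio(1, 1.0))

pred, target = torch.rand(4, 3, 8, 8), torch.rand(4, 3, 8, 8)
# forward(**kwargs) returns each term's partial results followed by the
# aggregated total, matching the args[:-1] / args[-1] unpacking above.
*partials, total = criterion(input=pred, target=target)
```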
- """ - - param_groups: Sequence[Dict[str, Sequence[nn.Module]]] = [ - defaultdict(list) * len(triggers) - ] - - for param_group, trigger in zip(param_groups, triggers): - for name, param in self.named_modules(): - if name.startswith(trigger): - param_group["params"].append(param) - - return param_groups - - def _configure_optimizer(self) -> Optimizer: - optimizer_args: Dict[str, Union[float, nn.Module]] = [] - for hparam, param_group in zip( - self.get_hparams(), self.get_param_groups(*self.triggers) - ): - optimizer_args.append(param_group.update(hparam)) - optimizer = VALID_OPTIMIZERS[self.optimizer](optimizer_args) - return optimizer - - def _configure_scheduler(self, optimizer: Optimizer) -> LRScheduler: - if self.scheduler == "onecycle": - return VALID_SCHEDULERS[self.scheduler]( - optimizer, - **self.scheduler_kwargs.update( - {"total_steps": self.trainer.estimated_stepping_batches} - ) - ) - else: - return VALID_SCHEDULERS[self.scheduler](optimizer, **self.scheduler_kwargs) - - def configure_optimizers(self) -> Optimizer | Sequence[Optimizer]: - optimizer = self._configure_optimizer() - scheduler = self._configure_scheduler(optimizer) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": scheduler, - "interval": interval(optimizer), - "frequency": 1, - }, - } - __all__ = ["Module"] diff --git a/lightorch/training/cli.py b/lightorch/training/cli.py index 92df32a..f139bf6 100644 --- a/lightorch/training/cli.py +++ b/lightorch/training/cli.py @@ -1,6 +1,5 @@ -from lightning.pytorch import LightningDataModule -import torch from lightning.pytorch.cli import LightningCLI +import torch def trainer( @@ -15,6 +14,7 @@ def trainer( trainer_defaults={ "deterministic": deterministic, }, + ) diff --git a/lightorch/training/supervised.py b/lightorch/training/supervised.py index 9736152..be596ff 100644 --- a/lightorch/training/supervised.py +++ b/lightorch/training/supervised.py @@ -31,34 +31,91 @@ "linear": LinearLR, } - def interval(algo: LRScheduler) -> str: if isinstance(algo, OneCycleLR): return "step" else: return "epoch" - class Module(LightningModule): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + """ + init: + triggers: Dict[str, Dict[str, float]] -> This is an + interpretative implementation for grouped optimization + where the parameters are stored in groups given a "trigger", + namely, as trigger parameters you can put a string describing + the beginning of the parameters to optimize in a group. + optimizer: str | Optimizer -> Name of the optimizer or an Optimizer instance. + scheduler: str | LRScheduler -> Name of the scheduler or a Scheduler instance. + scheduler_kwargs: Dict[str, Any] -> Arguments of the scheduler. + gradient_clip_algorithm: str -> Gradient clip algorithm [value, norm]. + gradient_clip_val: float -> Clipping value. 
+ """ + def __init__( + self, + *, + optimizer: Union[str, Optimizer], + scheduler: Union[str, LRScheduler] = None, + triggers: Dict[str, Dict[str, float]] = None, + optimizer_kwargs: Dict[str, Any] = None, + scheduler_kwargs: Dict[str, Any] = None, + gradient_clip_algorithm: str = None, + gradient_clip_val: float = None, + **kwargs + ) -> None: + super().__init__() for att in kwargs: setattr(self, att, kwargs[att]) - # Setting up the gradient clipping - self.trainer.gradient_clip_algorithm = self.gradient_clip_algorithm - self.trainer.gradient_clip_val = self.gradient_clip_val + + # Initializing the optimizer and the triggers + self.triggers = triggers + if triggers is not None: + assert optimizer_kwargs is None, 'Not valid optimizer_kwargs parameter for trigger-based setting, include all optimizer parameters in the dictionary with their respective name.' + self.triggers = triggers + else: + if not isinstance(optimizer, Optimizer): + assert optimizer_kwargs is not None, 'Must specify optimizer_kwargs parameter for non-trigger-based setting.' + self.optimizer_kwargs = optimizer_kwargs + else: + assert optimizer_kwargs is None, 'Not valid optimizer_kwargs parameter for initialized optimizer.' + self.optimizer = optimizer + + if isinstance(optimizer, str) or issubclass(optimizer, Optimizer): + self.optimizer = optimizer + else: + if not getattr(self, 'optimizer', False): + raise ValueError(f'Not valid optimizer parameter, expecting str | Optimizer got {type(optimizer)}') + + # Initializing the scheduler + if scheduler is not None: + if isinstance(scheduler, str): + self.scheduler = scheduler + self.scheduler_kwargs = scheduler_kwargs + elif isinstance(scheduler, LRScheduler): + self.scheduler = lambda optimizer: scheduler(optimizer=optimizer, **scheduler_kwargs) + else: + raise ValueError('Not valid scheduler parameter') + else: + assert scheduler_kwargs is None, 'Not valid scheduler_kwargs parameter for NoneType scheduler' + self.scheduler = None + + self.trainer.gradient_clip_algorithm = gradient_clip_algorithm + self.trainer.gradient_clip_val = gradient_clip_val + + def loss_forward(self, batch: Tensor, idx: int) -> Dict[str, Union[Tensor, float]]: + raise NotImplementedError('Should have defined loss_forward method.') def training_step(self, batch: Tensor, idx: int) -> Tensor: - args = self.loss_forward(batch, idx) - return self._compute_training_loss(*args) + kwargs = self.loss_forward(batch, idx) + return self._compute_training_loss(**kwargs) @torch.no_grad() def validation_step(self, batch: Tensor, idx: int) -> None: - args = self.loss_forward(batch, idx) - return self._compute_valid_metrics(*args) + kwargs = self.loss_forward(batch, idx) + return self._compute_valid_metrics(**kwargs) - def _compute_training_loss(self, *args) -> Tensor | Sequence[Tensor]: - args = self.criterion(*args) + def _compute_training_loss(self, **kwargs) -> Union[Tensor, Sequence[Tensor]]: + args = self.criterion(**kwargs) self.log_dict( {f"Training/{k}": v for k, v in zip(self.criterion.labels, args)}, True, @@ -70,78 +127,74 @@ def _compute_training_loss(self, *args) -> Tensor | Sequence[Tensor]: return args[-1] @torch.no_grad() - def _compute_valid_metrics(self, *args) -> None: - args = self.criterion.val_step(*args) + def _compute_valid_metrics(self, **kwargs) -> None: + args = self.criterion(**kwargs) self.log_dict( - {f"Validation/{k}": v for k, v in zip(self.criterion.val_labels, args)}, + {f"Validation/{k}": v for k, v in zip(self.criterion.labels, args)}, True, True, True, True, ) - def 
-    def get_param_groups(self, *triggers) -> Tuple:
+    def get_param_groups(self) -> Union[Sequence[Dict[str, Any]], None]:
         """
         Given a list of "triggers", the param groups are defined.
         """
-
-        param_groups: Sequence[Dict[str, Sequence[nn.Module]]] = [
-            defaultdict(list) * len(triggers)
-        ]
-
-        for param_group, trigger in zip(param_groups, triggers):
-            for name, param in self.named_modules():
-                if name.startswith(trigger):
-                    param_group["params"].append(param)
-
-        return param_groups
-
-    def get_hparams(self) -> Sequence[Dict[str, float]]:
-        return (
-            [
-                {"lr": lr, "weight_decay": wd, "momentum": mom}
-                for lr, wd, mom in zip(
-                    self.learning_rate, self.weight_decay, self.momentum
-                )
-            ]
-            if getattr(self, "momentum", False)
-            else [
-                {"lr": lr, "weight_decay": mom}
-                for lr, mom in zip(self.learning_rate, self.weight_decay)
-            ]
-        )
-
+        if self.triggers is not None:
+            param_groups: Sequence[Dict[str, Any]] = [
+                defaultdict(list) for _ in range(len(self.triggers))
+            ]
+            # Collect the named parameters of each group, then attach the
+            # group's hyperparameters.
+            for param_group, trigger in zip(param_groups, self.triggers):
+                for name, param in self.named_parameters():
+                    if name.startswith(trigger):
+                        param_group["params"].append(param)
+
+                param_group.update(self.triggers[trigger])
+
+            return param_groups
+        return None

     def _configure_optimizer(self) -> Optimizer:
-        optimizer_args: Dict[str, Union[float, nn.Module]] = []
-        for hparam, param_group in zip(
-            self.get_hparams(), self.get_param_groups(*self.triggers)
-        ):
-            optimizer_args.append(param_group.update(hparam))
-        optimizer = VALID_OPTIMIZERS[self.optimizer](optimizer_args)
-        return optimizer
-
-    def _configure_scheduler(self, optimizer: Optimizer) -> LRScheduler:
-        if self.scheduler == "onecycle":
-            return VALID_SCHEDULERS[self.scheduler](
-                optimizer,
-                **self.scheduler_kwargs.update(
-                    {"total_steps": self.trainer.estimated_stepping_batches}
-                ),
-            )
-        else:
-            return VALID_SCHEDULERS[self.scheduler](optimizer, **self.scheduler_kwargs)
-
-    def configure_optimizers(self) -> Optimizer | Sequence[Optimizer]:
+        # Parenthesize the walrus assignment: without the parentheses, params
+        # would be bound to the boolean result of the comparison.
+        if (params := self.get_param_groups()) is not None:
+            if isinstance(self.optimizer, str):
+                return VALID_OPTIMIZERS[self.optimizer](params)
+            elif isinstance(self.optimizer, torch.optim.Optimizer):
+                return self.optimizer
+            elif issubclass(self.optimizer, torch.optim.Optimizer):
+                return self.optimizer(params)
+        else:
+            if isinstance(self.optimizer, Optimizer):
+                return self.optimizer
+            if isinstance(self.optimizer, str):
+                self.optimizer = VALID_OPTIMIZERS[self.optimizer]
+
+            return self.optimizer(self.parameters(), **self.optimizer_kwargs)
+
+    def _configure_scheduler(self, optimizer: Optimizer) -> LRScheduler:
+        if isinstance(self.scheduler, str):
+            if self.scheduler == "onecycle":
+                self.scheduler_kwargs["total_steps"] = self.trainer.estimated_stepping_batches
+            return VALID_SCHEDULERS[self.scheduler](optimizer, **self.scheduler_kwargs)
+        else:
+            return self.scheduler(optimizer)
+
+    def configure_optimizers(self) -> Union[Optimizer, Sequence[Optimizer]]:
         optimizer = self._configure_optimizer()
-        scheduler = self._configure_scheduler(optimizer)
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": {
-                "scheduler": scheduler,
-                "interval": interval(optimizer),
-                "frequency": 1,
-            },
-        }
-
+        if self.scheduler is not None:
+            scheduler = self._configure_scheduler(optimizer)
+            return {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": scheduler,
+                    "interval": interval(scheduler),
+                    "frequency": 1,
+                },
+            }
+        return {"optimizer": optimizer}


 __all__ = ["Module"]
diff --git a/requirements.sh
b/requirements.sh deleted file mode 100755 index ec47f5e..0000000 --- a/requirements.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash - -# Remove existing requirements.txt if it exists -rm -f requirements.txt - -# Generate requirements.in from the lightorch directory -pipreqs lightorch/ --savepath=requirements.in - -# Compile requirements.in to requirements.txt -pip-compile requirements.in - -# Clean up -rm -f requirements.in \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index bfcc793..4c963ba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,132 +1,523 @@ -# -# This file is autogenerated by pip-compile with Python 3.11 -# by the following command: -# -# pip-compile requirements.in -# -aiohttp==3.9.5 - # via fsspec +absl-py==2.1.0 +accelerate==0.29.3 +access==1.1.9 +affine==2.4.0 +aiofiles==23.2.1 +aioftp==0.21.4 +aiohttp==3.9.1 +aiolimiter==1.1.0 aiosignal==1.3.1 - # via aiohttp -attrs==23.2.0 - # via aiohttp +alembic==1.13.1 +altair==5.3.0 +annotated-types==0.6.0 +antlr4-python3-runtime==4.9.3 +anyio==4.3.0 +appdirs==1.4.4 +apturl==0.5.2 +argon2-cffi==23.1.0 +argon2-cffi-bindings==21.2.0 +arrow==1.3.0 +astropy==6.0.0 +astropy-iers-data==0.2023.12.4.0.30.20 +astroquery==0.4.6 +asttokens==2.4.1 +async-lru==2.0.4 +async-timeout==4.0.3 +attrs==23.1.0 +Authlib==1.2.1 +av==11.0.0 +ax-platform==0.4.0 +azure-common==1.1.28 +azure-core==1.30.0 +azure-identity==1.15.0 +azure-mgmt-core==1.4.0 +azure-mgmt-rdbms==10.1.0 +azure-mgmt-resource==23.0.1 +azure-mgmt-subscription==3.1.1 +Babel==2.14.0 +backcall==0.2.0 +backoff==2.2.1 +bcrypt==4.0.1 +beautifulsoup4==4.12.2 +bidict==0.23.0 +black==24.4.0 +bleach==6.1.0 +blessed==1.20.0 +blinker==1.7.0 +blosc2==2.3.2 +boto3==1.33.13 +botocore==1.33.13 +botorch==0.11.0 +branca==0.7.0 +Brotli==1.1.0 +bs4==0.0.1 +build==1.2.1 +CacheControl==0.12.10 +cachetools==5.3.2 +catalyst==22.4 +cdflib==1.2.3 +certifi==2024.2.2 +cffi==1.16.0 +cfgv==3.4.0 +cftime==1.6.3 +chardet==4.0.0 +charset-normalizer==3.3.2 +chromedriver-binary==125.0.6422.76.0 +click==8.1.7 +click-plugins==1.1.1 +cligj==0.7.2 +colorama==0.4.4 +colorlog==6.8.2 +comm==0.2.0 +command-not-found==0.3 +configobj==5.0.6 +contourpy==1.2.0 +corkit==1.0.15 +coverage==7.3.2 +croniter==1.3.15 +cryptography==41.0.7 +cssselect==1.1.0 +cuda-python==12.3.0 +cudf-cu12==23.12.1 +cupshelpers==1.0 +cupy-cuda12x==13.0.0 +cycler==0.12.1 +DateTime==5.3 +dbus-python==1.2.18 +debugpy==1.8.0 +decorator==5.1.1 +deepdiff==7.0.1 +defer==1.0.6 +defusedxml==0.7.1 +deprecation==2.1.0 +difftorch==1.2.2 +distlib==0.3.8 +distro==1.7.0 +distro-info==1.1+ubuntu0.2 +dnspython==2.6.0 +docker-pycreds==0.4.0 +docopt==0.6.2 +docstring_parser==0.16 +docutils==0.21.1 +drms==0.7.0 +duplicity==0.8.21 +editor==1.6.6 einops==0.8.0 - # via -r requirements.in -filelock==3.14.0 - # via - # torch - # triton -frozenlist==1.4.1 - # via - # aiohttp - # aiosignal -fsspec[http]==2024.5.0 - # via - # lightning - # pytorch-lightning - # torch -idna==3.7 - # via yarl -jinja2==3.1.4 - # via torch -lightning==2.2.5 - # via -r requirements.in +email-validator==2.1.0.post1 +esda==2.5.1 +eventlet==0.34.2 +exceptiongroup==1.2.0 +executing==2.0.1 +fastapi==0.111.0 +fastapi-cli==0.0.2 +fasteners==0.14.1 +fastjsonschema==2.19.1 +fastrlock==0.8.2 +filelock==3.13.1 +fiona==1.9.5 +Flask==2.3.3 +flask-babel==4.0.0 +Flask-Compress==1.14 +Flask-Gravatar==0.5.0 +Flask-Login==0.6.3 +Flask-Mail==0.9.1 +Flask-Migrate==4.0.5 +Flask-Paranoid==0.3.0 +Flask-Principal==0.4.0 +Flask-Security-Too==5.2.0 +Flask-SocketIO==5.3.6 +Flask-SQLAlchemy==3.1.1 
+Flask-WTF==1.2.1 +folium==0.15.1 +fonttools==4.46.0 +fqdn==1.5.1 +frozenlist==1.4.0 +fsspec==2023.10.0 +future==0.18.2 +fvcore==0.1.5.post20221221 +GDAL==3.4.1 +gdown==5.1.0 +geographiclib==2.0 +geomagpy==1.1.7 +geopandas==0.14.1 +geopy==2.4.1 +georasters==0.5.27 +giddy==2.3.4 +gitdb==4.0.11 +GitPython==3.1.43 +google-api-core==2.17.1 +google-api-python-client==2.118.0 +google-auth==2.27.0 +google-auth-httplib2==0.2.0 +google-auth-oauthlib==1.1.0 +googleapis-common-protos==1.62.0 +gpustat==1.1.1 +gpytorch==1.11 +greenlet==1.1.2 +grpcio==1.60.0 +gyp==0.1 +h11==0.14.0 +h5netcdf==1.3.0 +h5py==3.10.0 +hiplot==0.1.33 +html5lib==1.1 +httpagentparser==1.9.5 +httpcore==1.0.5 +httplib2==0.20.2 +httptools==0.6.1 +httpx==0.27.0 +huggingface-hub==0.23.0 +hydra-core==1.3.2 +hydra-slayer==0.5.0 +icecream==2.1.3 +identify==2.5.36 +idna==3.3 +imageio==2.33.1 +importlib-metadata==7.0.1 +importlib-resources==6.1.1 +inequality==1.0.1 +iniconfig==2.0.0 +inquirer==3.2.4 +iopath==0.1.10 +ipykernel==6.27.1 +ipython==8.12.3 +ipywidgets==8.1.1 +isodate==0.6.1 +isoduration==20.11.0 +itsdangerous==2.1.2 +jaraco.classes==3.3.1 +jaxtyping==0.2.28 +jedi==0.19.1 +jeepney==0.7.1 +Jinja2==3.1.2 +jmespath==1.0.1 +joblib==1.3.2 +json5==0.9.25 +jsonargparse==4.28.0 +jsonpointer==2.4 +jsonschema==4.22.0 +jsonschema-specifications==2023.12.1 +jupyter-events==0.10.0 +jupyter-lsp==2.2.5 +jupyter_client==8.6.0 +jupyter_core==5.5.0 +jupyter_server==2.14.0 +jupyter_server_terminals==0.5.3 +jupyterlab==4.1.8 +jupyterlab-widgets==3.0.9 +jupyterlab_pygments==0.3.0 +jupyterlab_server==2.27.1 +keyring==24.3.0 +kiwisolver==1.4.5 +language-selector==0.1 +launchpadlib==1.10.16 +lazr.restfulclient==0.14.4 +lazr.uri==1.0.6 +lazy_loader==0.3 +ldap3==2.9.1 +libpysal==4.9.2 +lightning==2.2.4 lightning-utilities==0.11.2 - # via - # lightning - # pytorch-lightning - # torchmetrics -markupsafe==2.1.5 - # via jinja2 +lightorch==0.0.1 +linear-operator==0.5.1 +llvmlite==0.40.1 +lockfile==0.12.2 +lxml==4.9.3 +macaroonbakery==1.3.1 +Mako==1.1.3 +mapclassify==2.6.1 +Markdown==3.5.2 +markdown-it-py==3.0.0 +MarkupSafe==2.1.4 +matplotlib==3.8.2 +matplotlib-inline==0.1.6 +mdurl==0.1.2 +mgwr==2.2.0 +mistune==3.0.2 +momepy==0.7.0 +monotonic==1.6 +more-itertools==8.10.0 mpmath==1.3.0 - # via sympy -multidict==6.0.5 - # via - # aiohttp - # yarl -networkx==3.3 - # via torch -numpy==1.26.4 - # via - # lightning - # pytorch-lightning - # torchmetrics - # torchvision +msal==1.26.0 +msal-extensions==1.1.0 +msgpack==1.0.3 +msrest==0.7.1 +multidict==6.0.4 +multipledispatch==1.0.0 +mypy-extensions==1.0.0 +nbclient==0.10.0 +nbconvert==7.16.4 +nbformat==5.10.4 +ndindex==1.7 +nest-asyncio==1.5.8 +netCDF4==1.6.5 +netifaces==0.11.0 +networkx==3.2.1 +nh3==0.2.17 +nodeenv==1.8.0 +notebook_shim==0.2.4 +numba==0.57.1 +numexpr==2.8.7 +numpy==1.24.4 nvidia-cublas-cu12==12.1.3.1 - # via - # nvidia-cudnn-cu12 - # nvidia-cusolver-cu12 - # torch nvidia-cuda-cupti-cu12==12.1.105 - # via torch nvidia-cuda-nvrtc-cu12==12.1.105 - # via torch nvidia-cuda-runtime-cu12==12.1.105 - # via torch nvidia-cudnn-cu12==8.9.2.26 - # via torch nvidia-cufft-cu12==11.0.2.54 - # via torch nvidia-curand-cu12==10.3.2.106 - # via torch nvidia-cusolver-cu12==11.4.5.107 - # via torch nvidia-cusparse-cu12==12.1.0.106 - # via - # nvidia-cusolver-cu12 - # torch -nvidia-nccl-cu12==2.20.5 - # via torch -nvidia-nvjitlink-cu12==12.5.40 - # via - # nvidia-cusolver-cu12 - # nvidia-cusparse-cu12 +nvidia-ml-py==12.535.133 +nvidia-nccl-cu12==2.18.1 +nvidia-nvjitlink-cu12==12.3.101 nvidia-nvtx-cu12==12.1.105 - # via 
torch -packaging==24.0 - # via - # lightning - # lightning-utilities - # pytorch-lightning - # torchmetrics -pillow==10.3.0 - # via torchvision -pytorch-lightning==2.2.5 - # via lightning -pyyaml==6.0.1 - # via - # lightning - # pytorch-lightning -sympy==1.12.1 - # via torch -torch==2.3.0 - # via - # -r requirements.in - # lightning - # pytorch-lightning - # torchmetrics - # torchvision -torchmetrics==1.4.0.post0 - # via - # lightning - # pytorch-lightning -torchvision==0.18.0 - # via -r requirements.in -tqdm==4.66.4 - # via - # -r requirements.in - # lightning - # pytorch-lightning -triton==2.3.0 - # via torch -typing-extensions==4.12.0 - # via - # lightning - # lightning-utilities - # pytorch-lightning - # torch -yarl==1.9.4 - # via aiohttp - -# The following packages are considered to be unsafe in a requirements file: -# setuptools +nvtx==0.2.8 +oauthlib==3.2.0 +olefile==0.46 +omegaconf==2.3.0 +opencv-python==4.9.0.80 +opt-einsum==3.3.0 +optuna==3.6.1 +optuna-integration==3.6.0 +ordered-set==4.1.0 +orjson==3.10.3 +outcome==1.3.0.post0 +overrides==7.7.0 +packaging==23.2 +paho-mqtt==1.6.1 +pandas==1.5.3 +pandocfilters==1.5.1 +parameterized==0.9.0 +paramiko==2.9.3 +parfive==2.0.2 +parso==0.8.3 +passlib==1.7.4 +pathspec==0.12.1 +patsy==0.5.4 +peewee==3.17.3 +pexpect==4.8.0 +pgadmin4==8.3 +pickleshare==0.7.5 +Pillow==9.0.1 +pip-tools==7.4.1 +pipreqs==0.5.0 +pkginfo==1.10.0 +platformdirs==4.0.0 +plotly==5.22.0 +pluggy==1.5.0 +pointpats==2.4.0 +pooch==1.8.0 +portalocker==2.8.2 +pre-commit==3.7.0 +prometheus_client==0.20.0 +prompt-toolkit==3.0.41 +protobuf==4.23.4 +psutil==5.9.6 +psycopg==3.1.12 +psycopg-binary==3.1.12 +ptyprocess==0.7.0 +PuLP==2.7.0 +pure-eval==0.2.2 +py-cpuinfo==9.0.0 +pyarrow==14.0.2 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pycairo==1.20.1 +pycparser==2.21 +pycups==2.0.1 +pydantic==2.7.1 +pydantic_core==2.18.2 +pydeck==0.9.0 +pyerfa==2.0.1.1 +pyfiglet==1.0.2 +Pygments==2.17.2 +PyGObject==3.42.1 +pyhdf==0.11.3 +PyJWT==2.3.0 +pymacaroons==0.13.0 +PyMySQL==1.1.0 +PyNaCl==1.5.0 +pynvim==0.4.2 +pyotp==2.9.0 +pyparsing==2.4.7 +pypng==0.20220715.0 +pyproj==3.6.1 +pyproject_hooks==1.1.0 +Pypubsub==4.0.3 +pyre-extensions==0.0.30 +pyreadstat==1.2.7 +pyRFC3339==1.1 +pyro-api==0.1.2 +pyro-ppl==1.9.0 +pysal==23.7 +PySocks==1.7.1 +pytest==8.2.1 +python-apt==2.4.0+ubuntu3 +python-dateutil==2.8.2 +python-debian==0.1.43+ubuntu1.1 +python-dotenv==1.0.1 +python-engineio==4.9.0 +python-json-logger==2.0.7 +python-multipart==0.0.9 +python-socketio==5.11.1 +pytorch-ignite==0.5.0.post2 +pytorch-lightning==2.2.4 +pytorchvideo==0.1.5 +pytz==2023.4 +pyvo==1.4.2 +PyWavelets==1.6.0 +pyxdg==0.27 +PyYAML==5.4.1 +pyzmq==25.1.1 +qrcode==7.4.2 +quantecon==0.7.1 +rasterio==1.3.9 +rasterstats==0.19.0 +readchar==4.0.6 +readme_renderer==43.0 +referencing==0.35.1 +reportlab==3.6.8 +requests==2.31.0 +requests-file==1.5.1 +requests-oauthlib==1.3.1 +requests-toolbelt==1.0.0 +rfc3339-validator==0.1.4 +rfc3986==2.0.0 +rfc3986-validator==0.1.1 +rich==13.7.0 +rmm-cu12==23.12.0 +rpds-py==0.18.0 +rsa==4.9 +Rtree==1.1.0 +runs==1.2.2 +s3transfer==0.8.2 +safetensors==0.4.3 +savReaderWriter==3.4.2 +savvy==0.0.1 +scikit-fuzzy==0.4.2 +scikit-image==0.22.0 +scikit-learn==1.3.2 +scipy==1.11.4 +scour==0.38.2 +screen-resolution-extra==0.0.0 +seaborn==0.13.0 +SecretStorage==3.3.1 +segregation==2.5 +selenium==4.18.1 +Send2Trash==1.8.3 +sentry-sdk==2.0.1 +seppy==0.1.11 +setproctitle==1.3.3 +shapely==2.0.2 +shellingham==1.5.4 +simple-websocket==1.0.0 +simplejson==3.19.2 +six==1.16.0 +smmap==5.0.1 +sniffio==1.3.1 +snuggs==1.4.7 
+solo-epd-loader==0.3.6 +sortedcontainers==2.4.0 +soupsieve==2.5 +spacepy==0.4.1 +spaghetti==1.7.4 +speaklater3==1.4 +spglm==1.1.0 +spint==1.0.7 +splot==1.1.5.post1 +spopt==0.6.0 +spreg==1.4.2 +spvcm==0.3.0 +SQLAlchemy==2.0.27 +sqlite-web==0.6.3 +sqlmodel==0.0.18 +sqlparse==0.4.4 +sshtunnel==0.4.0 +stack-data==0.6.3 +starlette==0.37.2 +starsessions==1.3.0 +statsmodels==0.14.0 +streamlit==1.34.0 +sunpy==5.1.0 +sympy==1.12 +systemd-python==234 +tables==3.9.2 +tabulate==0.9.0 +tenacity==8.2.3 +tensorboard==2.15.1 +tensorboard-data-server==0.7.2 +tensorboardX==2.6.2.2 +termcolor==2.4.0 +termdown==1.18.0 +terminado==0.18.1 +terminator==2.1.3 +threadpoolctl==3.2.0 +tifffile==2023.12.9 +tinycss2==1.3.0 +tobler==0.11.2 +toml==0.10.2 +tomli==2.0.1 +toolz==0.12.1 +torch==2.1.1 +torch-tb-profiler==0.4.3 +torchaudio==2.1.1 +torchinfo==1.8.0 +torchmetrics==1.3.0.post0 +torchvision==0.16.1 +tornado==6.4 +tqdm==4.66.1 +traitlets==5.14.0 +trio==0.24.0 +trio-websocket==0.11.1 +triton==2.1.0 +twine==5.0.0 +typeguard==2.13.3 +typer==0.12.3 +types-python-dateutil==2.9.0.20240316 +typeshed_client==2.5.1 +typing-inspect==0.9.0 +typing_extensions==4.10.0 +tzdata==2023.3 +ua-parser==0.18.0 +ubuntu-drivers-common==0.0.0 +ubuntu-pro-client==8001 +ucimlrepo==0.0.7 +ufw==0.36.1 +ujson==5.9.0 +unattended-upgrades==0.1 +uri-template==1.3.0 +uritemplate==4.1.1 +urllib3==2.0.7 +usb-creator==0.3.7 +user-agents==2.2.0 +uvicorn==0.29.0 +uvloop==0.19.0 +viresclient==0.11.3 +virtualenv==20.25.3 +wadllib==1.3.6 +wandb==0.16.6 +watchdog==4.0.0 +watchfiles==0.21.0 +wcwidth==0.2.12 +webcolors==1.13 +webdriver-manager==4.0.1 +webencodings==0.5.1 +websocket-client==1.8.0 +websockets==12.0 +Werkzeug==2.3.8 +widgetsnbextension==4.0.9 +wsproto==1.2.0 +WTForms==3.1.2 +xarray==2023.11.0 +xdg==5 +xgboost==2.0.2 +xkit==0.0.0 +xmod==1.8.1 +xyzservices==2023.10.1 +yacs==0.1.8 +yarg==0.1.9 +yarl==1.9.3 +zeep==4.2.1 +zipp==1.0.0 +zope.interface==6.1 diff --git a/setup.py b/setup.py index 769885c..bf824a8 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ long_description_content_type="text/markdown", author_email=__email__, description="Pytorch & Lightning based framework for research and ml-pipeline automation.", - url="https://github.com/Jorgedavyd/lightorch", + url="https://github.com/Jorgedavyd/LighTorch", license="MIT", install_requires=["lightning", "torch", "torchvision", "optuna", "tqdm"], classifiers=[ @@ -23,9 +23,6 @@ "Intended Audience :: Developers", "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", diff --git a/tests/test_adversarial.py b/tests/test_adversarial.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_nn.py b/tests/test_nn.py index 41bd736..d4936d2 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -192,3 +192,5 @@ def test_monte_carlo() -> None: def test_kan() -> None: # Placeholder for future implementation raise NotImplementedError("KAN test not implemented") + +# transformers, attention, mlp, etc. 
\ No newline at end of file
diff --git a/tests/test_supervised.py b/tests/test_supervised.py
index f2104b4..a38b63d 100644
--- a/tests/test_supervised.py
+++ b/tests/test_supervised.py
@@ -1,10 +1,12 @@
 from lightorch.training.supervised import Module
-from lightorch.nn.criterions import PeakSignalNoiseRatio
+from lightorch.nn.criterions import MSELoss
 from lightorch.htuning.optuna import htuning
-from .utils import Model, create_inputs, DataModule
-
+from .utils import create_inputs, DataModule
+from torch import nn
+import optuna
+import torch
 import random
-from torch import Tensor, nn
+from torch import Tensor

 in_size: int = 32
 input: Tensor = create_inputs(1, in_size)
@@ -15,26 +17,73 @@
 class SupModel(Module):
     def __init__(self, **hparams) -> None:
         super().__init__(**hparams)
-        self.criterion = PeakSignalNoiseRatio(1, randint)
-        self.model = Model(in_size)
-
+        # Criterion
+        self.criterion = MSELoss()
+
+        self.model = nn.Sequential(
+            nn.Linear(10, 5),
+            nn.ReLU(),
+            nn.Linear(5, 1),
+            nn.Sigmoid(),
+        )
+
     def forward(self, input: Tensor) -> Tensor:
         return self.model(input)

+    def loss_forward(self, batch: Tensor, idx: int) -> dict:
+        # Required by the new supervised Module; the keyword names
+        # ("input", "target") are assumed to match MSELoss.forward,
+        # and the all-ones target is a placeholder.
+        out = self(batch)
+        return dict(input=out, target=torch.ones_like(out))
+

-def objective():
-    pass
+def objective1(trial: optuna.trial.Trial) -> dict:
+    return dict(
+        triggers={
+            "model": dict(
+                lr=trial.suggest_float("lr", 1e-4, 1e-1),
+                weight_decay=trial.suggest_float("weight_decay", 1e-4, 1e-1),
+                momentum=trial.suggest_float("momentum", 0.1, 0.7),
+            )
+        },
+        optimizer="sgd",
+        scheduler="onecycle",
+        scheduler_kwargs=dict(max_lr=trial.suggest_float("max_lr", 1e-2, 1e-1)),
+        gradient_clip_algorithm=trial.suggest_categorical("clip_mode", ["value", "norm"]),
+        gradient_clip_val=trial.suggest_float("clip_value", 1e-3, 1e-2),
+    )
+
+
+def objective2(trial: optuna.trial.Trial) -> dict:
+    return dict(
+        optimizer=torch.optim.Adam,
+        optimizer_kwargs=dict(
+            lr=trial.suggest_float("lr", 1e-4, 1e-1),
+            weight_decay=trial.suggest_float("weight_decay", 1e-4, 1e-1),
+        ),
+        scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau,
+        # ReduceLROnPlateau takes no max_lr; tune its decay factor instead.
+        scheduler_kwargs=dict(factor=trial.suggest_float("factor", 0.1, 0.5)),
+        gradient_clip_algorithm=trial.suggest_categorical("clip_mode", ["value", "norm"]),
+        gradient_clip_val=trial.suggest_float("clip_value", 1e-3, 1e-2),
+    )
+

 def test_supervised() -> None:
     htuning(
         model_class=SupModel,
-        hparam_objective=objective,
+        hparam_objective=objective1,
         datamodule=DataModule,
         valid_metrics="MSE",
         datamodule_kwargs=dict(pin_memory=False, num_workers=1, batch_size=1),
-        directions="minimize",
+        directions=["minimize"],
         precision="high",
         n_trials=10,
         trianer_kwargs=dict(fast_dev_run=True),
     )
+
+    htuning(
+        model_class=SupModel,
+        hparam_objective=objective2,
+        datamodule=DataModule,
+        valid_metrics="MSE",
+        datamodule_kwargs=dict(pin_memory=False, num_workers=1, batch_size=1),
+        directions="minimize",
+        precision="medium",
+        n_trials=10,
+        trianer_kwargs=dict(fast_dev_run=True),
+    )
\ No newline at end of file
diff --git a/tests/utils.py b/tests/utils.py
index c0b2cf8..0743e62 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -24,6 +24,10 @@ def __init__(
         self,
     ) -> None:
         pass
+
+    def __len__(self) -> int:
+        return 100
+
+    def __getitem__(self, index) -> Tensor:
+        return torch.randn(10)


 class DataModule(LightningDataModule):
@@ -32,9 +36,6 @@ def __init__(self, batch_size: int, pin_memory: bool = False, num_workers: int =
         self.pin_memory = pin_memory
         self.num_workers = num_workers

-    def setup(self) -> None:
-        self.train_ds
-
     def train_dataloader(self) -> DataLoader:
         return DataLoader(
             Data(),
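For a quick smoke test of the pieces above outside `htuning`, a minimal sketch assuming `SupModel` and `DataModule` as defined in `tests/` (the hyperparameters mirror `objective1` with fixed values):

```python
from lightning.pytorch import Trainer

model = SupModel(
    optimizer="sgd",
    triggers={"model": dict(lr=1e-2, weight_decay=1e-4, momentum=0.9)},
    scheduler="onecycle",
    scheduler_kwargs=dict(max_lr=1e-1),
    gradient_clip_algorithm="norm",
    gradient_clip_val=1e-2,
)
# fast_dev_run runs a single batch through fit, which is enough to exercise
# the trigger-based param groups and the onecycle scheduler wiring.
Trainer(fast_dev_run=True).fit(
    model, datamodule=DataModule(batch_size=1, pin_memory=False, num_workers=1)
)
```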