Skip to content
This repository has been archived by the owner on Oct 4, 2024. It is now read-only.

Commit

Permalink
Merge branch 'main' into dependabot-pip-gradio-lte-4.0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Nov 1, 2023
2 parents 85f4371 + fa2e88b commit abc9969
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
### Install Lightning

```bash
pip install lightning
pip install lightning[app]
```

### Locally
Expand Down
8 changes: 4 additions & 4 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import os.path as ops

import lightning as L
from lightning.app import CloudCompute, LightningFlow, LightningApp

from quick_start.components import ImageServeGradio, PyTorchLightningScript


class TrainDeploy(L.LightningFlow):
class TrainDeploy(LightningFlow):
def __init__(self):
super().__init__()
self.train_work = PyTorchLightningScript(
script_path=ops.join(ops.dirname(__file__), "./train_script.py"),
script_args=["--trainer.max_epochs=10"],
cloud_compute=L.CloudCompute("cpu-medium", idle_timeout=60),
cloud_compute=CloudCompute("cpu-medium", idle_timeout=60),
)

self.serve_work = ImageServeGradio()
Expand All @@ -32,4 +32,4 @@ def configure_layout(self):
return tabs


app = L.LightningApp(TrainDeploy())
app = LightningApp(TrainDeploy())
8 changes: 4 additions & 4 deletions app_hpo.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os.path as ops

import lightning as L
import optuna
from lightning.app import LightningFlow, CloudCompute, LightningApp
from lightning_hpo import BaseObjective, Optimizer

from quick_start.components import ImageServeGradio, PyTorchLightningScript
Expand All @@ -13,7 +13,7 @@ def distributions():
return {"model.lr": optuna.distributions.LogUniformDistribution(0.0001, 0.1)}


class TrainDeploy(L.LightningFlow):
class TrainDeploy(LightningFlow):
def __init__(self):
super().__init__()
self.train_work = Optimizer(
Expand All @@ -23,7 +23,7 @@ def __init__(self):
n_trials=4,
)

self.serve_work = ImageServeGradio(L.CloudCompute("cpu"))
self.serve_work = ImageServeGradio(CloudCompute("cpu"))

def run(self):
# 1. Run the python script that trains the model
Expand All @@ -39,4 +39,4 @@ def configure_layout(self):
return [tab_1, tab_2]


app = L.LightningApp(TrainDeploy())
app = LightningApp(TrainDeploy())
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ gradio <4.1.0
pyyaml <=6.0.1
protobuf <4.25.0 # 4.21 breaks with wandb, tensorboard, or pytorch-lightning: https://github.com/protocolbuffers/protobuf/issues/10048
websockets
lightning
lightning[app] >=2.1.0
tensorboard
2 changes: 1 addition & 1 deletion tests/integration_app/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from lightning import LightningApp, LightningFlow
from lightning.app import LightningApp, LightningFlow

from quick_start import PyTorchLightningScript, ImageServeGradio

Expand Down
6 changes: 3 additions & 3 deletions train_script.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os

import lightning as L
import torch
import torchvision.transforms as T
from lightning import LightningModule, LightningDataModule
from lightning.pytorch.cli import LightningCLI
from torch import nn
from torch.nn import functional as F
Expand Down Expand Up @@ -35,7 +35,7 @@ def forward(self, x):
return output


class ImageClassifier(L.LightningModule):
class ImageClassifier(LightningModule):
def __init__(self, model=None, lr=1.0, gamma=0.7, batch_size=32):
super().__init__()
self.save_hyperparameters(ignore="model")
Expand Down Expand Up @@ -69,7 +69,7 @@ def configure_optimizers(self):
return torch.optim.Adadelta(self.model.parameters(), lr=self.hparams.lr)


class MNISTDataModule(L.LightningDataModule):
class MNISTDataModule(LightningDataModule):
def __init__(self, batch_size=32):
super().__init__()
self.save_hyperparameters()
Expand Down

0 comments on commit abc9969

Please sign in to comment.