Skip to content

Commit

Permalink
Merge branch 'feature-993-updated' of https://github.com/DSGT-DLP/Dee…
Browse files Browse the repository at this point in the history
…p-Learning-Playground into feature-993-updated
  • Loading branch information
karkir0003 committed Nov 20, 2023
2 parents 08ff344 + 69d1477 commit 4f33824
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 24 deletions.
8 changes: 7 additions & 1 deletion training/training/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@


class TrainTestDatasetCreator(ABC):
"Creator that creates train and test PyTorch datasets from a given dataset"
"""
Creator that creates train and test PyTorch datasets from a given dataset.
This class serves as an abstract base class for creating training and testing
datasets compatible with PyTorch's dataset structure. Implementations should
define specific methods for dataset processing and loading.
"""

@abstractmethod
def createTrainDataset(self) -> Dataset:
Expand Down
4 changes: 0 additions & 4 deletions training/training/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ def _train_step(self, inputs: torch.Tensor, labels: torch.Tensor):
self.optimizer.zero_grad() # zero out gradient for each batch
self.model.forward(inputs) # make prediction on input
self._outputs: torch.Tensor = self.model(inputs) # make prediction on input
print('MODEL FORWARD PASS DONE!!!!')
print(f'output dim: {self._outputs.shape}')
print(f'label dim: {labels.shape}')
print(f'loss function used: {self.criterionHandler}')
loss = self.criterionHandler.compute_loss(self._outputs, labels)
loss.backward() # backpropagation
self.optimizer.step() # adjust optimizer weights
Expand Down
16 changes: 1 addition & 15 deletions training/training/routes/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,8 @@ def imageTrain(request: HttpRequest, imageParams: ImageParams):
print(vars(dataCreator))
train_loader = dataCreator.createTrainDataset()
test_loader = dataCreator.createTestDataset()
# train_loader = DataLoader(
# dataCreator.createTrainDataset(),
# batch_size=imageParams.batch_size,
# shuffle=False,
# drop_last=True,
# )

# test_loader = DataLoader(
# dataCreator.createTestDataset(),
# batch_size=imageParams.batch_size,
# shuffle=False,
# drop_last=True,
# )

model = DLModel.fromLayerParamsList(imageParams.user_arch)
print(f'model is: {model}')
# print(f'model is: {model}')
optimizer = getOptimizer(model, imageParams.optimizer_name, 0.05)
criterionHandler = getCriterionHandler(imageParams.criterion)
if imageParams.problem_type == "CLASSIFICATION":
Expand Down
4 changes: 0 additions & 4 deletions training/training/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@
from training.routes.datasets.default.columns import router as default_dataset_router
from training.routes.tabular.tabular import router as tabular_router
from training.routes.image.image import router as image_router
# from training.routes.datasets.default import get_default_datasets_router
# from training.routes.tabular import get_tabular_router
# from training.routes.image import image_router

api = NinjaAPI()


Expand Down

0 comments on commit 4f33824

Please sign in to comment.