Skip to content

Commit

Permalink
#8: ModelTrainer class renamed to PyTorchModelTrainer
Browse files Browse the repository at this point in the history
  • Loading branch information
eyp committed Dec 1, 2020
1 parent 9eebaa1 commit 987a853
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
4 changes: 1 addition & 3 deletions client/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import sys
import requests
import random

from os import environ

from requests.exceptions import Timeout
from fastai.vision.all import *
from .utils import printf, model_params_to_request_params
from .utils import model_params_to_request_params
from .mnist_model_trainer import MnistModelTrainer
from .client_status import ClientStatus
from .config import DEFAULT_SERVER_URL
Expand Down
10 changes: 5 additions & 5 deletions client/mnist_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@

from .utils import printf
from .training_utils import mnist_loss, linear_model
from .model_trainer import ModelTrainer
from .pytorch_model_trainer import PyTorchModelTrainer


class MnistModelTrainer(ModelTrainer):
class MnistModelTrainer(PyTorchModelTrainer):
def __init__(self, model_params, client_config):
print('Initializing MnistModelTrainer...')
self.ACCURACY_THRESHOLD = 0.5
super().__init__(model_params, client_config)

def _ModelTrainer__train_epoch(self):
def _PyTorchModelTrainer__train_epoch(self):
for train_data, train_labels in self.training_dataloader:
self.__calculate_gradients(train_data, train_labels)
for model_param in self.model_params:
model_param.data -= model_param.grad * self.client_config.learning_rate
model_param.grad.zero_()

def _ModelTrainer__validate_epoch(self):
def _PyTorchModelTrainer__validate_epoch(self):
accuracies = [self.__accuracy(linear_model(train_data, weights=self.model_params[0], bias=self.model_params[1]), train_labels) for
train_data, train_labels in
self.validation_dataloader]
Expand All @@ -34,7 +34,7 @@ def __accuracy(self, train_data, train_labels):
corrections = (predictions > self.ACCURACY_THRESHOLD) == train_labels
return corrections.float().mean()

def _ModelTrainer__load_datasets(self):
def _PyTorchModelTrainer__load_datasets(self):
print('Loading dataset MNIST_SAMPLE...')
path = untar_data(URLs.MNIST_SAMPLE)
print('Content of MNIST_SAMPLE:', path.ls())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from abc import ABC, abstractmethod


class ModelTrainer(ABC):
class PyTorchModelTrainer(ABC):
"""
This is the base class of model trainers.
If you want to implement a new training, your class must inherit from ModelTrainer and
Expand Down

0 comments on commit 987a853

Please sign in to comment.