diff --git a/client/client.py b/client/client.py index 24e5857..a9538ed 100644 --- a/client/client.py +++ b/client/client.py @@ -1,12 +1,10 @@ import sys import requests -import random from os import environ from requests.exceptions import Timeout -from fastai.vision.all import * -from .utils import printf, model_params_to_request_params +from .utils import model_params_to_request_params from .mnist_model_trainer import MnistModelTrainer from .client_status import ClientStatus from .config import DEFAULT_SERVER_URL diff --git a/client/mnist_model_trainer.py b/client/mnist_model_trainer.py index a4700e2..9c69a2d 100644 --- a/client/mnist_model_trainer.py +++ b/client/mnist_model_trainer.py @@ -2,23 +2,23 @@ from .utils import printf from .training_utils import mnist_loss, linear_model -from .model_trainer import ModelTrainer +from .pytorch_model_trainer import PyTorchModelTrainer -class MnistModelTrainer(ModelTrainer): +class MnistModelTrainer(PyTorchModelTrainer): def __init__(self, model_params, client_config): print('Initializing MnistModelTrainer...') self.ACCURACY_THRESHOLD = 0.5 super().__init__(model_params, client_config) - def _ModelTrainer__train_epoch(self): + def _PyTorchModelTrainer__train_epoch(self): for train_data, train_labels in self.training_dataloader: self.__calculate_gradients(train_data, train_labels) for model_param in self.model_params: model_param.data -= model_param.grad * self.client_config.learning_rate model_param.grad.zero_() - def _ModelTrainer__validate_epoch(self): + def _PyTorchModelTrainer__validate_epoch(self): accuracies = [self.__accuracy(linear_model(train_data, weights=self.model_params[0], bias=self.model_params[1]), train_labels) for train_data, train_labels in self.validation_dataloader] @@ -34,7 +34,7 @@ def __accuracy(self, train_data, train_labels): corrections = (predictions > self.ACCURACY_THRESHOLD) == train_labels return corrections.float().mean() - def _ModelTrainer__load_datasets(self): + def _PyTorchModelTrainer__load_datasets(self): print('Loading dataset MNIST_SAMPLE...') path = untar_data(URLs.MNIST_SAMPLE) print('Content of MNIST_SAMPLE:', path.ls()) diff --git a/client/model_trainer.py b/client/pytorch_model_trainer.py similarity index 98% rename from client/model_trainer.py rename to client/pytorch_model_trainer.py index 853ce10..88cd24b 100644 --- a/client/model_trainer.py +++ b/client/pytorch_model_trainer.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod -class ModelTrainer(ABC): +class PyTorchModelTrainer(ABC): """ This is the base class of model trainers. If you want to implement a new training, your class must inherit from ModelTrainer and