From 32f8e510b2670f6fb03144063b0d6d16d9b55300 Mon Sep 17 00:00:00 2001 From: eyp Date: Mon, 30 Nov 2020 09:46:20 +0100 Subject: [PATCH] #8: Introduced ModelTrainer abstract class --- client/mnist_model_trainer.py | 28 ++++++------------- client/model_trainer.py | 51 +++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 20 deletions(-) create mode 100644 client/model_trainer.py diff --git a/client/mnist_model_trainer.py b/client/mnist_model_trainer.py index 872b8dd..a4700e2 100644 --- a/client/mnist_model_trainer.py +++ b/client/mnist_model_trainer.py @@ -2,35 +2,23 @@ from .utils import printf from .training_utils import mnist_loss, linear_model +from .model_trainer import ModelTrainer -class MnistModelTrainer: +class MnistModelTrainer(ModelTrainer): def __init__(self, model_params, client_config): + print('Initializing MnistModelTrainer...') self.ACCURACY_THRESHOLD = 0.5 - self.training_dataloader = None - self.validation_dataloader = None - - self.client_config = client_config - self.model_params = model_params - - def train_model(self): - # print('Initial params:', self.model_params) - training_dataset, validation_dataset = self.__load_datasets() - self.training_dataloader = DataLoader(training_dataset, batch_size=self.client_config.batch_size) - self.validation_dataloader = DataLoader(validation_dataset, batch_size=self.client_config.batch_size) - for epoch in range(self.client_config.epochs): - self.__train_epoch() - print('Accuracy of model trained at epoch', epoch + 1, ':', self.__validate_epoch(), end='\n', flush=True) - return self.model_params - - def __train_epoch(self): + super().__init__(model_params, client_config) + + def _ModelTrainer__train_epoch(self): for train_data, train_labels in self.training_dataloader: self.__calculate_gradients(train_data, train_labels) for model_param in self.model_params: model_param.data -= model_param.grad * self.client_config.learning_rate model_param.grad.zero_() - def __validate_epoch(self): + def 
_ModelTrainer__validate_epoch(self): accuracies = [self.__accuracy(linear_model(train_data, weights=self.model_params[0], bias=self.model_params[1]), train_labels) for train_data, train_labels in self.validation_dataloader] @@ -46,7 +34,7 @@ def __accuracy(self, train_data, train_labels): corrections = (predictions > self.ACCURACY_THRESHOLD) == train_labels return corrections.float().mean() - def __load_datasets(self): + def _ModelTrainer__load_datasets(self): print('Loading dataset MNIST_SAMPLE...') path = untar_data(URLs.MNIST_SAMPLE) print('Content of MNIST_SAMPLE:', path.ls()) diff --git a/client/model_trainer.py b/client/model_trainer.py new file mode 100644 index 0000000..853ce10 --- /dev/null +++ b/client/model_trainer.py @@ -0,0 +1,51 @@ +from fastai.vision.all import DataLoader +from abc import ABC, abstractmethod + + +class ModelTrainer(ABC): + """ + This is the base class of model trainers. + If you want to implement a new training, your class must inherit from ModelTrainer and + implement the abstract methods. + See an implementation in mnist_model_trainer.py + """ + + def __init__(self, model_params, client_config): + self.training_dataloader = None + self.validation_dataloader = None + + self.client_config = client_config + self.model_params = model_params + + def train_model(self): + training_dataset, validation_dataset = self.__load_datasets() + self.training_dataloader = DataLoader(training_dataset, batch_size=self.client_config.batch_size) + self.validation_dataloader = DataLoader(validation_dataset, batch_size=self.client_config.batch_size) + for epoch in range(self.client_config.epochs): + self.__train_epoch() + print('Accuracy of model trained at epoch', epoch + 1, ':', self.__validate_epoch(), end='\n', flush=True) + return self.model_params + + @abstractmethod + def __train_epoch(self): + """ + Implements the actual model training. 
It is called once per epoch, i.e. 'client_config.epochs' times. + """ + raise NotImplementedError() + + @abstractmethod + def __validate_epoch(self): + """ + Validates the trained model. + :returns: the accuracy of the model as a float + """ + raise NotImplementedError() + + @abstractmethod + def __load_datasets(self): + """ + Loads the datasets used for training and validating the model in ModelTrainer.train_model(). + :returns: training_dataset, validation_dataset + """ + raise NotImplementedError() +