Skip to content

Commit

Permalink
#8: Introduced ModelTrainer abstract class
Browse files Browse the repository at this point in the history
  • Loading branch information
eyp committed Nov 30, 2020
1 parent e1b0d4a commit 32f8e51
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 20 deletions.
28 changes: 8 additions & 20 deletions client/mnist_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,23 @@

from .utils import printf
from .training_utils import mnist_loss, linear_model
from .model_trainer import ModelTrainer


class MnistModelTrainer:
class MnistModelTrainer(ModelTrainer):
def __init__(self, model_params, client_config):
print('Initializing MnistModelTrainer...')
self.ACCURACY_THRESHOLD = 0.5
self.training_dataloader = None
self.validation_dataloader = None

self.client_config = client_config
self.model_params = model_params

def train_model(self):
# print('Initial params:', self.model_params)
training_dataset, validation_dataset = self.__load_datasets()
self.training_dataloader = DataLoader(training_dataset, batch_size=self.client_config.batch_size)
self.validation_dataloader = DataLoader(validation_dataset, batch_size=self.client_config.batch_size)
for epoch in range(self.client_config.epochs):
self.__train_epoch()
print('Accuracy of model trained at epoch', epoch + 1, ':', self.__validate_epoch(), end='\n', flush=True)
return self.model_params

def __train_epoch(self):
super().__init__(model_params, client_config)

def _ModelTrainer__train_epoch(self):
for train_data, train_labels in self.training_dataloader:
self.__calculate_gradients(train_data, train_labels)
for model_param in self.model_params:
model_param.data -= model_param.grad * self.client_config.learning_rate
model_param.grad.zero_()

def __validate_epoch(self):
def _ModelTrainer__validate_epoch(self):
accuracies = [self.__accuracy(linear_model(train_data, weights=self.model_params[0], bias=self.model_params[1]), train_labels) for
train_data, train_labels in
self.validation_dataloader]
Expand All @@ -46,7 +34,7 @@ def __accuracy(self, train_data, train_labels):
corrections = (predictions > self.ACCURACY_THRESHOLD) == train_labels
return corrections.float().mean()

def __load_datasets(self):
def _ModelTrainer__load_datasets(self):
print('Loading dataset MNIST_SAMPLE...')
path = untar_data(URLs.MNIST_SAMPLE)
print('Content of MNIST_SAMPLE:', path.ls())
Expand Down
51 changes: 51 additions & 0 deletions client/model_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from fastai.vision.all import DataLoader
from abc import ABC, abstractmethod


class ModelTrainer(ABC):
"""
This is the base class of model trainers.
If you want to implement a new training, your class must inherit from ModelTrainer and
implement the abstract methods.
See an implementation in mnist_model_trainer.py
"""

def __init__(self, model_params, client_config):
self.training_dataloader = None
self.validation_dataloader = None

self.client_config = client_config
self.model_params = model_params

def train_model(self):
training_dataset, validation_dataset = self.__load_datasets()
self.training_dataloader = DataLoader(training_dataset, batch_size=self.client_config.batch_size)
self.validation_dataloader = DataLoader(validation_dataset, batch_size=self.client_config.batch_size)
for epoch in range(self.client_config.epochs):
self.__train_epoch()
print('Accuracy of model trained at epoch', epoch + 1, ':', self.__validate_epoch(), end='\n', flush=True)
return self.model_params

@abstractmethod
def __train_epoch(self):
"""
Implements the actual model training. It will be called the times defined in 'client_config.epochs'
"""
raise NotImplementedError()

@abstractmethod
def __validate_epoch(self):
"""
Validates the training
:returns the accuracy of the model as float
"""
raise NotImplementedError()

@abstractmethod
def __load_datasets(self):
"""
Load the dataset used for training the model in ModelTrainer.train_model().
:returns training_dataset, validation_dataset
"""
raise NotImplementedError()

0 comments on commit 32f8e51

Please sign in to comment.