
Commit f05e282
#8: Better class design for MnistModelTrainer
eyp committed Nov 28, 2020
1 parent d1baefb commit f05e282
Showing 1 changed file with 35 additions and 38 deletions.
client/mnist_model_trainer.py
@@ -1,8 +1,3 @@
 import sys
-import torch
-import random
-
-from fastai.data.load import DataLoader
 from fastai.vision.all import *
-
 from .utils import printf
@@ -11,6 +6,7 @@
 
 class MnistModelTrainer:
     def __init__(self, model_params, client_config):
+        self.ACCURACY_THRESHOLD = 0.5
         self.training_dataloader = None
         self.validation_dataloader = None
         self.training_dataset = None
@@ -21,7 +17,40 @@ def __init__(self, model_params, client_config):
         self.learning_rate = self.client_config.learning_rate
         self.epochs = self.client_config.epochs
 
-    def load_datasets(self):
+    def train_model(self):
+        # print('Initial params:', self.model_params)
+        self.__load_datasets()
+        self.training_dataloader = DataLoader(self.training_dataset, batch_size=self.client_config.batch_size)
+        self.validation_dataloader = DataLoader(self.validation_dataset, batch_size=self.client_config.batch_size)
+        for epoch in range(self.epochs):
+            self.__train_epoch()
+            print('Accuracy of model trained at epoch', epoch + 1, ':', self.__validate_epoch(), end='\n', flush=True)
+        return self.model_params
+
+    def __train_epoch(self):
+        for train_data, train_labels in self.training_dataloader:
+            self.__calculate_gradients(train_data, train_labels)
+            for model_param in self.model_params:
+                model_param.data -= model_param.grad * self.learning_rate
+                model_param.grad.zero_()
+
+    def __validate_epoch(self):
+        accuracies = [self.__accuracy(linear_model(train_data, weights=self.model_params[0], bias=self.model_params[1]), train_labels) for
+                      train_data, train_labels in
+                      self.validation_dataloader]
+        return round(torch.stack(accuracies).mean().item(), 4)
+
+    def __calculate_gradients(self, train_data, train_labels):
+        predictions = linear_model(train_data, self.model_params[0], self.model_params[1])
+        loss = mnist_loss(predictions, train_labels)
+        loss.backward()
+
+    def __accuracy(self, train_data, train_labels):
+        predictions = train_data.sigmoid()
+        corrections = (predictions > self.ACCURACY_THRESHOLD) == train_labels
+        return corrections.float().mean()
+
+    def __load_datasets(self):
         print('Loading dataset MNIST_SAMPLE...')
         path = untar_data(URLs.MNIST_SAMPLE)
         print('Content of MNIST_SAMPLE:', path.ls())
@@ -59,35 +88,3 @@ def load_datasets(self):
         print('Dataset ready to be used')
         sys.stdout.flush()
 
-    def train_model(self):
-        # print('Initial params:', self.model_params)
-        self.load_datasets()
-        self.training_dataloader = DataLoader(self.training_dataset, batch_size=self.client_config.batch_size)
-        self.validation_dataloader = DataLoader(self.validation_dataset, batch_size=self.client_config.batch_size)
-        for epoch in range(self.epochs):
-            self.train_epoch()
-            print('Accuracy of model trained at epoch', epoch + 1, ':', self.validate_epoch(), end='\n', flush=True)
-        return self.model_params
-
-    def train_epoch(self):
-        for train_data, train_labels in self.training_dataloader:
-            self.calculate_gradients(train_data, train_labels)
-            for model_param in self.model_params:
-                model_param.data -= model_param.grad * self.learning_rate
-                model_param.grad.zero_()
-
-    def validate_epoch(self):
-        accuracies = [self.accuracy(linear_model(train_data, weights=self.model_params[0], bias=self.model_params[1]), train_labels) for
-                      train_data, train_labels in
-                      self.validation_dataloader]
-        return round(torch.stack(accuracies).mean().item(), 4)
-
-    def accuracy(self, train_data, train_labels):
-        predictions = train_data.sigmoid()
-        corrections = (predictions > 0.5) == train_labels
-        return corrections.float().mean()
-
-    def calculate_gradients(self, train_data, train_labels):
-        predictions = linear_model(train_data, self.model_params[0], self.model_params[1])
-        loss = mnist_loss(predictions, train_labels)
-        loss.backward()
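For readers following the diff: linear_model and mnist_loss are called by the trainer but are not defined in this file. A minimal sketch of what they are assumed to look like, in the fastai MNIST_SAMPLE style (the real definitions live elsewhere in the repository and may differ):

# Assumed helper definitions -- NOT part of this commit.
def linear_model(train_data, weights, bias):
    # A single linear layer over flattened 28x28 images: x @ w + b.
    return train_data @ weights + bias

def mnist_loss(predictions, targets):
    # Map raw scores into (0, 1), then average the distance from the
    # binary label: near 0 when confident and correct, near 1 when wrong.
    predictions = predictions.sigmoid()
    return torch.where(targets == 1, 1 - predictions, predictions).mean()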

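A note on the renames themselves: the double leading underscore triggers Python's name mangling, so the helper methods are genuinely hidden from callers rather than just marked private by convention. A hypothetical caller (parameter shapes and the client_config fields learning_rate, epochs, and batch_size are inferred from __init__):

# Hypothetical usage sketch; client_config is assumed to exist.
weights = torch.randn(28 * 28, 1).requires_grad_()
bias = torch.randn(1).requires_grad_()
trainer = MnistModelTrainer([weights, bias], client_config)

trained_params = trainer.train_model()  # public entry point, unchanged
trainer.__train_epoch()                 # AttributeError: the attribute is
                                        # now _MnistModelTrainer__train_epoch

The manual update in __train_epoch is plain SGD; swapping in torch.optim.SGD(self.model_params, lr=self.learning_rate) would be the idiomatic equivalent.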