diff --git a/client/mnist_model_trainer.py b/client/mnist_model_trainer.py index 52e0dd2..6d82822 100644 --- a/client/mnist_model_trainer.py +++ b/client/mnist_model_trainer.py @@ -1,8 +1,3 @@ -import sys -import torch -import random - -from fastai.data.load import DataLoader from fastai.vision.all import * from .utils import printf @@ -11,6 +6,7 @@ class MnistModelTrainer: def __init__(self, model_params, client_config): + self.ACCURACY_THRESHOLD = 0.5 self.training_dataloader = None self.validation_dataloader = None self.training_dataset = None @@ -21,7 +17,40 @@ def __init__(self, model_params, client_config): self.learning_rate = self.client_config.learning_rate self.epochs = self.client_config.epochs - def load_datasets(self): + def train_model(self): + # print('Initial params:', self.model_params) + self.__load_datasets() + self.training_dataloader = DataLoader(self.training_dataset, batch_size=self.client_config.batch_size) + self.validation_dataloader = DataLoader(self.validation_dataset, batch_size=self.client_config.batch_size) + for epoch in range(self.epochs): + self.__train_epoch() + print('Accuracy of model trained at epoch', epoch + 1, ':', self.__validate_epoch(), end='\n', flush=True) + return self.model_params + + def __train_epoch(self): + for train_data, train_labels in self.training_dataloader: + self.__calculate_gradients(train_data, train_labels) + for model_param in self.model_params: + model_param.data -= model_param.grad * self.learning_rate + model_param.grad.zero_() + + def __validate_epoch(self): + accuracies = [self.__accuracy(linear_model(train_data, weights=self.model_params[0], bias=self.model_params[1]), train_labels) for + train_data, train_labels in + self.validation_dataloader] + return round(torch.stack(accuracies).mean().item(), 4) + + def __calculate_gradients(self, train_data, train_labels): + predictions = linear_model(train_data, self.model_params[0], self.model_params[1]) + loss = mnist_loss(predictions, train_labels) + loss.backward() + + def __accuracy(self, train_data, train_labels): + predictions = train_data.sigmoid() + corrections = (predictions > self.ACCURACY_THRESHOLD) == train_labels + return corrections.float().mean() + + def __load_datasets(self): print('Loading dataset MNIST_SAMPLE...') path = untar_data(URLs.MNIST_SAMPLE) print('Content of MNIST_SAMPLE:', path.ls()) @@ -59,35 +88,3 @@ def load_datasets(self): print('Dataset ready to be used') sys.stdout.flush() - def train_model(self): - # print('Initial params:', self.model_params) - self.load_datasets() - self.training_dataloader = DataLoader(self.training_dataset, batch_size=self.client_config.batch_size) - self.validation_dataloader = DataLoader(self.validation_dataset, batch_size=self.client_config.batch_size) - for epoch in range(self.epochs): - self.train_epoch() - print('Accuracy of model trained at epoch', epoch + 1, ':', self.validate_epoch(), end='\n', flush=True) - return self.model_params - - def train_epoch(self): - for train_data, train_labels in self.training_dataloader: - self.calculate_gradients(train_data, train_labels) - for model_param in self.model_params: - model_param.data -= model_param.grad * self.learning_rate - model_param.grad.zero_() - - def validate_epoch(self): - accuracies = [self.accuracy(linear_model(train_data, weights=self.model_params[0], bias=self.model_params[1]), train_labels) for - train_data, train_labels in - self.validation_dataloader] - return round(torch.stack(accuracies).mean().item(), 4) - - def accuracy(self, train_data, train_labels): - predictions = train_data.sigmoid() - corrections = (predictions > 0.5) == train_labels - return corrections.float().mean() - - def calculate_gradients(self, train_data, train_labels): - predictions = linear_model(train_data, self.model_params[0], self.model_params[1]) - loss = mnist_loss(predictions, train_labels) - loss.backward()