From e1b0d4a63da19f60dd7410a1e07a7889cf773650 Mon Sep 17 00:00:00 2001 From: eyp Date: Mon, 30 Nov 2020 08:41:41 +0100 Subject: [PATCH] #8: Some classes attributes are now local variables --- client/mnist_model_trainer.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/client/mnist_model_trainer.py b/client/mnist_model_trainer.py index 6d82822..872b8dd 100644 --- a/client/mnist_model_trainer.py +++ b/client/mnist_model_trainer.py @@ -9,20 +9,16 @@ def __init__(self, model_params, client_config): self.ACCURACY_THRESHOLD = 0.5 self.training_dataloader = None self.validation_dataloader = None - self.training_dataset = None - self.validation_dataset = None self.client_config = client_config self.model_params = model_params - self.learning_rate = self.client_config.learning_rate - self.epochs = self.client_config.epochs def train_model(self): # print('Initial params:', self.model_params) - self.__load_datasets() - self.training_dataloader = DataLoader(self.training_dataset, batch_size=self.client_config.batch_size) - self.validation_dataloader = DataLoader(self.validation_dataset, batch_size=self.client_config.batch_size) - for epoch in range(self.epochs): + training_dataset, validation_dataset = self.__load_datasets() + self.training_dataloader = DataLoader(training_dataset, batch_size=self.client_config.batch_size) + self.validation_dataloader = DataLoader(validation_dataset, batch_size=self.client_config.batch_size) + for epoch in range(self.client_config.epochs): self.__train_epoch() print('Accuracy of model trained at epoch', epoch + 1, ':', self.__validate_epoch(), end='\n', flush=True) return self.model_params @@ -31,7 +27,7 @@ def __train_epoch(self): for train_data, train_labels in self.training_dataloader: self.__calculate_gradients(train_data, train_labels) for model_param in self.model_params: - model_param.data -= model_param.grad * self.learning_rate + model_param.data -= model_param.grad * self.client_config.learning_rate model_param.grad.zero_() def __validate_epoch(self): @@ -79,12 +75,13 @@ def __load_datasets(self): train_images = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28 * 28) train_labels = tensor([1] * len(threes) + [0] * len(sevens)).unsqueeze(1) - self.training_dataset = list(zip(train_images, train_labels)) + training_dataset = list(zip(train_images, train_labels)) print('Training images shape:', train_images.shape, ', training labels shape:', train_labels.shape) valid_images = torch.cat([valid_three_tensors, valid_seven_tensors]).view(-1, 28 * 28) valid_labels = tensor([1] * len(valid_three_tensors) + [0] * len(valid_seven_tensors)).unsqueeze(1) - self.validation_dataset = list(zip(valid_images, valid_labels)) + validation_dataset = list(zip(valid_images, valid_labels)) print('Dataset ready to be used') sys.stdout.flush() + return training_dataset, validation_dataset