Skip to content

Commit

Permalink
#8: Some classes attributes are now local variables
Browse files Browse the repository at this point in the history
  • Loading branch information
eyp committed Nov 30, 2020
1 parent f05e282 commit e1b0d4a
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions client/mnist_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,16 @@ def __init__(self, model_params, client_config):
self.ACCURACY_THRESHOLD = 0.5
self.training_dataloader = None
self.validation_dataloader = None
self.training_dataset = None
self.validation_dataset = None

self.client_config = client_config
self.model_params = model_params
self.learning_rate = self.client_config.learning_rate
self.epochs = self.client_config.epochs

def train_model(self):
# print('Initial params:', self.model_params)
self.__load_datasets()
self.training_dataloader = DataLoader(self.training_dataset, batch_size=self.client_config.batch_size)
self.validation_dataloader = DataLoader(self.validation_dataset, batch_size=self.client_config.batch_size)
for epoch in range(self.epochs):
training_dataset, validation_dataset = self.__load_datasets()
self.training_dataloader = DataLoader(training_dataset, batch_size=self.client_config.batch_size)
self.validation_dataloader = DataLoader(validation_dataset, batch_size=self.client_config.batch_size)
for epoch in range(self.client_config.epochs):
self.__train_epoch()
print('Accuracy of model trained at epoch', epoch + 1, ':', self.__validate_epoch(), end='\n', flush=True)
return self.model_params
Expand All @@ -31,7 +27,7 @@ def __train_epoch(self):
for train_data, train_labels in self.training_dataloader:
self.__calculate_gradients(train_data, train_labels)
for model_param in self.model_params:
model_param.data -= model_param.grad * self.learning_rate
model_param.data -= model_param.grad * self.client_config.learning_rate
model_param.grad.zero_()

def __validate_epoch(self):
Expand Down Expand Up @@ -79,12 +75,13 @@ def __load_datasets(self):

train_images = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28 * 28)
train_labels = tensor([1] * len(threes) + [0] * len(sevens)).unsqueeze(1)
self.training_dataset = list(zip(train_images, train_labels))
training_dataset = list(zip(train_images, train_labels))
print('Training images shape:', train_images.shape, ', training labels shape:', train_labels.shape)

valid_images = torch.cat([valid_three_tensors, valid_seven_tensors]).view(-1, 28 * 28)
valid_labels = tensor([1] * len(valid_three_tensors) + [0] * len(valid_seven_tensors)).unsqueeze(1)
self.validation_dataset = list(zip(valid_images, valid_labels))
validation_dataset = list(zip(valid_images, valid_labels))
print('Dataset ready to be used')
sys.stdout.flush()
return training_dataset, validation_dataset

0 comments on commit e1b0d4a

Please sign in to comment.