Skip to content

Commit

Permalink
updates to README
Browse files Browse the repository at this point in the history
  • Loading branch information
Vivek Khimani committed Nov 13, 2019
1 parent 79d0a27 commit 1dfc2fa
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 10 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
# Implementing Federated Learning using PySyft
- More details to come.

- Dataset - MNIST
- Number of Workers - 32
- Classification Model - CNN (see the details in the models directory)
- Tools Used - PySyft, PyTorch

16 changes: 7 additions & 9 deletions main_fed.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,14 @@ def virtualWorkers():
# Tail of virtualWorkers(): creates the last of the 32 PySyft virtual
# workers, each registered on the global `hook` under a spelled-out id.
thirty_one=sy.VirtualWorker(hook, id="thirty_one")
thirty_two=sy.VirtualWorker(hook, id="thirty_two")

# NOTE(review): this span is a rendered diff, not straight source.  The
# next two statements are the PRE-change code that this commit removed:
# the function returned a 32-element tuple, and module level unpacked it
# into 32 separate names.
return one,two,three,four,five,six,seven,eight,nine,ten,eleven,twelve,thirteen,fourteen,fifteen,sixteen,seventeen,eighteen,nineteen,twenty,twenty_one,twenty_two,twenty_three,twenty_four,twenty_five,twenty_six,twenty_seven,twenty_eight,twenty_nine,thirty,thirty_one,thirty_two

one,two,three,four,five,six,seven,eight,nine,ten,eleven,twelve,thirteen,fourteen,fifteen,sixteen,seventeen,eighteen,nineteen,twenty,twenty_one,twenty_two,twenty_three,twenty_four,twenty_five,twenty_six,twenty_seven,twenty_eight,twenty_nine,thirty,thirty_one,thirty_two = virtualWorkers()

# POST-change code added by this commit: return the workers as a single
# list so callers can index/iterate instead of juggling 32 names.
return [one,two,three,four,five,six,seven,eight,nine,ten,eleven,twelve,thirteen,fourteen,fifteen,sixteen,seventeen,eighteen,nineteen,twenty,twenty_one,twenty_two,twenty_three,twenty_four,twenty_five,twenty_six,twenty_seven,twenty_eight,twenty_nine,thirty,thirty_one,thirty_two]

# Module-level: build the worker list once at import time; the rest of
# main_fed.py (e.g. loadMNISTData) reads this global.
vList = virtualWorkers()

def loadMNISTData():
    """Build the federated MNIST training loader and a local test loader.

    Returns:
        tuple: ``(federated_train_loader, test_loader)`` — the training
        set is sharded across all 32 virtual workers in the module-level
        ``vList``; the test set stays local for centralized evaluation.
    """
    # Standard MNIST normalization constants (same for train and test).
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # .federate() takes a tuple of workers; spread the data across every
    # worker instead of enumerating vList[0] .. vList[31] by hand.
    # (The captured diff carried both the old 32-name call and the new
    # vList-indexed call — only one dataset line belongs here.)
    federated_train_loader = sy.FederatedDataLoader(
        datasets.MNIST('data', train=True, download=True,
                       transform=mnist_transform).federate(tuple(vList)),
        batch_size=Arguments.args.batch_size, shuffle=True,
        **Arguments.kwargs)

    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=mnist_transform),
        batch_size=Arguments.args.test_batch_size, shuffle=True,
        **Arguments.kwargs)

    return federated_train_loader, test_loader
Expand All @@ -67,7 +65,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
# Interior of the federated train() loop (def header is outside this view).
# Move the batch to the target device; with PySyft the tensors may live on
# a remote virtual worker.
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
# NOTE(review): rendered diff — the next line is the PRE-change loss that
# this commit removed; the line after it is the replacement.
loss = F.nll_loss(output, target)
# BUG: nn.CrossEntropyLoss is a Module class, not a function.  Calling
# nn.CrossEntropyLoss(output, target) passes the tensors to the
# constructor (as weight/size_average) and returns a criterion object, so
# the following loss.backward() will fail.  It should be
# F.cross_entropy(output, target) or nn.CrossEntropyLoss()(output, target).
loss = nn.CrossEntropyLoss(output, target)
loss.backward()
optimizer.step()
model.get() # <-- NEW: get the model back
Expand All @@ -85,8 +83,8 @@ def test(args, model, device, test_loader):
# Interior of the evaluation loop (def header is outside this view).
for data, target in test_loader:
    data, target = data.to(device), target.to(device)
    output = model(data)
    # NOTE(review): rendered diff — the nll_loss line below is the
    # PRE-change code this commit removed and the cross_entropy line is
    # its replacement; read as straight source, keeping both would add
    # each batch's loss to test_loss twice.
    test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
    pred = output.argmax(1, keepdim=True) # get the index of the max log-probability
    # NOTE(review): F.cross_entropy expects raw logits while F.nll_loss
    # expects log-probabilities — the switch assumes the CNN no longer
    # applies log_softmax; confirm against the model in models/CNN.
    test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss
    pred = output.argmax(1, keepdim=True) # get the index of the max log-probability
    correct += pred.eq(target.view_as(pred)).sum().item()

# Average the summed per-sample loss over the whole test set.
test_loss /= len(test_loader.dataset)
Expand Down
Binary file modified models/__pycache__/CNN.cpython-36.pyc
Binary file not shown.
1 change: 1 addition & 0 deletions utils/Arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ def __init__(self):
# Interior of Arguments.__init__ (def header is outside this view):
# hyperparameters read elsewhere via Arguments.args.
self.batch_size = 64        # mini-batch size for the federated train loader
self.test_batch_size = 128  # batch size for the local test loader
self.epochs = 50            # total training epochs
# NEW in this commit; presumably per-worker local epochs between
# synchronizations — consumer not visible in this diff, confirm usage.
self.local_epochs = 5
self.lr = 0.01              # optimizer learning rate
self.momentum = 0.5         # optimizer momentum (assumes SGD — TODO confirm)
self.no_cuda = False        # set True to force CPU-only execution
Expand Down
Binary file modified utils/__pycache__/Arguments.cpython-36.pyc
Binary file not shown.

0 comments on commit 1dfc2fa

Please sign in to comment.