-
Notifications
You must be signed in to change notification settings - Fork 4
/
clients.py
76 lines (61 loc) · 3.24 KB
/
clients.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader, RandomSampler
from getData import GetDataSet
import random
class client(object):
    """One training participant that owns a dataset and trains a shared model on it.

    Attributes (all assigned in ``__init__``):
        train_ds: the dataset handed to the constructor.
        dev: device that every batch is moved to before the forward pass.
        train_dl: DataLoader over ``train_ds``; ``None`` until ``localUpdate``
            runs, and rebuilt on each call.
        local_parameters: initialised to ``None``; nothing else in this class
            writes to it.
    """

    def __init__(self, trainDataSet, dev):
        self.train_ds = trainDataSet
        self.dev = dev
        self.train_dl = None
        self.local_parameters = None

    def localUpdate(self, size, localEpoch, localBatchSize, Net, lossFun, opti, global_parameters):
        """Load the global weights into ``Net`` and train it on this client's data.

        Args:
            size: passed to ``RandomSampler`` as ``num_samples`` (sampling is
                without replacement).
            localEpoch: number of passes over the sampled data.
            localBatchSize: batch size of the DataLoader.
            Net: model to train; it is modified in place.
            lossFun: callable invoked as ``lossFun(predictions, labels)``.
            opti: optimizer stepped after every batch (presumably built over
                ``Net``'s parameters; this is not checked here).
            global_parameters: state dict loaded into ``Net`` with
                ``strict=True`` before training starts.

        Returns:
            ``Net.state_dict()`` after training. No copy is made here.
        """
        # Every call starts from the supplied global weights.
        Net.load_state_dict(global_parameters, strict=True)

        sampler = RandomSampler(self.train_ds, replacement=False, num_samples=size)
        self.train_dl = DataLoader(self.train_ds, batch_size=localBatchSize, sampler=sampler)

        for _ in range(localEpoch):
            for batch_x, batch_y in self.train_dl:
                batch_x = batch_x.to(self.dev)
                batch_y = batch_y.to(self.dev)
                loss = lossFun(Net(batch_x), batch_y)
                loss.backward()
                opti.step()
                # Gradients are cleared after the step, so the next batch
                # starts from zeroed gradients.
                opti.zero_grad()

        return Net.state_dict()

    def local_val(self):
        """Placeholder: no local validation is implemented."""
        pass
class ClientsGroup(object):
    """Creates every client for a dataset, plus one shared test DataLoader.

    Construction immediately calls ``dataSetBalanceAllocation``, which fills
    ``clients_set`` (keys ``'client0'``, ``'client1'``, ...) and sets
    ``test_data_loader``.

    Args:
        dataSetName: forwarded to ``GetDataSet``.
        isIID: forwarded to ``GetDataSet``; not otherwise used in this class.
        numOfClients: how many ``client`` objects to create.
        dev: stored and handed to each ``client``.
    """

    def __init__(self, dataSetName, isIID, numOfClients, dev):
        self.data_set_name = dataSetName
        self.is_iid = isIID
        self.num_of_clients = numOfClients
        self.dev = dev
        self.clients_set = {}
        self.test_data_loader = None

        self.dataSetBalanceAllocation()

    def dataSetBalanceAllocation(self):
        """Build the test loader and give every client two equally sized shards.

        The training set is cut into contiguous shards of
        ``train_data_size // num_of_clients // 2`` samples. The shard ids are
        shuffled once, and client ``i`` receives the shards at positions
        ``2*i`` and ``2*i + 1`` of that shuffled order.
        """
        dataset = GetDataSet(self.data_set_name, self.is_iid)

        # Test labels are reduced with argmax over dim 1
        # (presumably one-hot encoded; confirm against GetDataSet).
        eval_x = torch.tensor(dataset.test_data)
        eval_y = torch.argmax(torch.tensor(dataset.test_label), dim=1)
        self.test_data_loader = DataLoader(TensorDataset(eval_x, eval_y), batch_size=100, shuffle=False)

        features = dataset.train_data
        targets = dataset.train_label

        # NOTE(review): shard_size is 0 when 2 * num_of_clients exceeds
        # train_data_size, and the floor division below would then fail.
        shard_size = dataset.train_data_size // self.num_of_clients // 2
        shard_order = np.random.permutation(dataset.train_data_size // shard_size)

        for idx in range(self.num_of_clients):
            first_shard = shard_order[idx * 2]
            second_shard = shard_order[idx * 2 + 1]
            spans = [
                slice(shard * shard_size, shard * shard_size + shard_size)
                for shard in (first_shard, second_shard)
            ]

            local_data = np.vstack([features[span] for span in spans])
            local_label = np.vstack([targets[span] for span in spans])
            # Collapse the label rows to class indices, as for the test labels.
            local_label = np.argmax(local_label, axis=1)

            member = client(
                TensorDataset(torch.tensor(local_data), torch.tensor(local_label)),
                self.dev,
            )
            self.clients_set['client{}'.format(idx)] = member
if __name__=="__main__":
    # Manual smoke check: build 100 clients on 'mnist' and print one slice of
    # the local dataset for two of them. The device argument (1) is only
    # stored during construction.
    MyClients = ClientsGroup('mnist', True, 100, 1)
    for key, window in (('client10', slice(0, 100)), ('client11', slice(400, 500))):
        print(MyClients.clients_set[key].train_ds[window])