Skip to content

Commit

Permalink
Merge pull request #322 from hua-zi/patch-1
Browse files Browse the repository at this point in the history
ditto算法无法使用,修复ditto算法
  • Loading branch information
dunzeng authored Jul 9, 2023
2 parents fa7427b + 0f2262d commit b64fa9d
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions fedlab/contrib/algorithm/ditto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import tqdm
from tqdm import *
from copy import deepcopy

from .basic_server import SyncServerHandler
Expand Down Expand Up @@ -42,7 +43,7 @@ def local_process(self, payload, id_list):
global_model = payload[0]
for id in tqdm(id_list):
# self._LOGGER.info("Local process is running. Training client {}".format(id))
train_loader = self.dataset.get_data_loader(id, batch_size=self.args.batch_size)
train_loader = self.dataset.get_dataloader(id, batch_size=self.batch_size)
self.local_models[id], glb_model = self.train(global_model, self.local_models[id], train_loader)
self.ditto_gmodels.append(deepcopy(glb_model))

Expand All @@ -56,10 +57,10 @@ def train(self, global_model_parameters, local_model_parameters, train_loader):
criterion = torch.nn.CrossEntropyLoss()
SerializationTool.deserialize_model(self._model, global_model_parameters)
self._model.train()
for ep in range(self.args.epochs):
for ep in range(self.epochs):
for data, label in train_loader:
if self.cuda:
data, label = data.cuda(self.gpu), label.cuda(self.gpu)
data, label = data.cuda(self.device), label.cuda(self.device)

preds = self._model(data)
loss = criterion(preds,label)
Expand All @@ -74,21 +75,22 @@ def train(self, global_model_parameters, local_model_parameters, train_loader):

SerializationTool.deserialize_model(self._model, local_model_parameters)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(self._model.parameters(), lr=self.args.lr)
optimizer = torch.optim.SGD(self._model.parameters(), lr=self.lr)

self._model.train()
for ep in range(self.args.epochs):
for ep in range(self.epochs):
for data, label in train_loader:
if self.cuda:
data, label = data.cuda(self.gpu), label.cuda(self.gpu)
data, label = data.cuda(self.device), label.cuda(self.device)

preds = self._model(data)
l1 = criterion(preds,label)
l2 = 0.0
for w0, w in zip(frz_model.parameters(), self._model.parameters()):
l2 += torch.sum(torch.pow(w - w0, 2))

loss = l1 + 0.5 * self.args.mu * l2
# loss = l1 + 0.5 * self.args.mu * l2
loss = l1 + 0.5 * 0.1 * l2 # fedprox 的 mu
optimizer.zero_grad()
loss.backward()
optimizer.step()
Expand Down

0 comments on commit b64fa9d

Please sign in to comment.