Skip to content

Commit

Permalink
add support for multiprocess scenarios
Browse files Browse the repository at this point in the history
  • Loading branch information
StupidTrees committed Jul 20, 2023
1 parent 748c238 commit 5005ea0
Show file tree
Hide file tree
Showing 24 changed files with 759 additions and 389 deletions.
46 changes: 46 additions & 0 deletions examples/scale-mnist-board/board.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import argparse
import sys
from typing import Any

sys.path.append("../../")
from fedlab.board import fedboard
from fedlab.board.delegate import FedBoardDelegate
from fedlab.board.utils.roles import BOARD_SHOWER
from fedlab.contrib.dataset import PathologicalMNIST

parser = argparse.ArgumentParser(description="FedBoard example")
parser.add_argument("--port", type=str, default="8070")
args = parser.parse_args()

fedboard.register(id='mtp-01', roles=BOARD_SHOWER)
dataset = PathologicalMNIST(root='../../datasets/mnist/',
path="../../datasets/mnist/",
num_clients=100)


class mDelegate(FedBoardDelegate):
def sample_client_data(self, client_id: str, client_rank: str, type: str, amount: int) -> tuple[
list[Any], list[Any]]:
data = []
label = []
real_id = int(client_rank) * 10 + int(client_id.split('-')[-1])
for dt in dataset.get_dataloader(real_id, batch_size=amount, type=type):
x, y = dt
for x_p in x:
data.append(x_p)
for y_p in y:
label.append(y_p)
break
return data, label

def read_client_label(self, client_id: str, client_rank: str, type: str) -> list[Any]:
res = []
real_id = int(client_rank) * 10 + int(client_id.split('-')[1])
for _, label in dataset.get_dataloader(real_id, batch_size=64, type=type):
for y in label:
res.append(y.detach().cpu().item())
return res


fedboard.enable_builtin_charts(mDelegate())
fedboard.start_offline(port=args.port)
58 changes: 58 additions & 0 deletions examples/scale-mnist-board/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import argparse
import sys

import torch

sys.path.append("../../")

from fedlab.core.client import PassiveClientManager
from fedlab.core.network import DistNetwork
from pipeline.client_side import ExampleTrainer
from fedlab.contrib.dataset.pathological_mnist import PathologicalMNIST
from fedlab.models.mlp import MLP
from fedlab.board import fedboard
from fedlab.board.utils.roles import CLIENT_HOLDER

parser = argparse.ArgumentParser(description="Distbelief training example")

parser.add_argument("--ip", type=str, default="127.0.0.1")
parser.add_argument("--port", type=str, default="3002")
parser.add_argument("--world_size", type=int)
parser.add_argument("--rank", type=int)
parser.add_argument("--ethernet", type=str, default=None)

parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--epochs", type=int, default=2)
parser.add_argument("--batch_size", type=int, default=100)

args = parser.parse_args()

if torch.cuda.is_available():
args.cuda = True
else:
args.cuda = False

model = MLP(784, 10)

fedboard.register(id='mtp-01', process_rank=args.rank, roles=CLIENT_HOLDER,
client_ids=[f'{args.rank}-{i}' for i in range(10)])

network = DistNetwork(address=(args.ip, args.port),
world_size=args.world_size,
rank=args.rank,
ethernet=args.ethernet)

trainer = ExampleTrainer(rank=args.rank, model=model, num_clients=10, cuda=args.cuda)

dataset = PathologicalMNIST(root='../../datasets/mnist/',
path="../../datasets/mnist/",
num_clients=100)

if args.rank == 1:
dataset.preprocess()

trainer.setup_dataset(dataset)
trainer.setup_optim(args.epochs, args.batch_size, args.lr)

manager_ = PassiveClientManager(trainer=trainer, network=network)
manager_.run()
23 changes: 23 additions & 0 deletions examples/scale-mnist-board/launch_eg.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#!/bin/bash


rm -rf ./.fedboard

python server.py --world_size 11 --round 10 &
echo "server started"
sleep 2

for ((i=1; i<=10; i++))
do
{
echo "client ${i} started"
python client.py --world_size 11 --rank ${i} &
sleep 1
}
done

python board.py &
echo "board started"
sleep 5

wait
46 changes: 46 additions & 0 deletions examples/scale-mnist-board/pipeline/client_side.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from tqdm import tqdm

from fedlab.board import fedboard
from fedlab.contrib.algorithm import SGDSerialClientTrainer


class ExampleTrainer(SGDSerialClientTrainer):

def __init__(self, rank, **kwargs):
super().__init__(**kwargs)
self.rank = rank

def train(self, model_parameters, train_loader):
self.set_model(model_parameters)
self._model.train()
loss = 0
for _ in range(self.epochs):
for data, target in train_loader:
if self.cuda:
data = data.cuda(self.device)
target = target.cuda(self.device)

output = self.model(data)
loss = self.criterion(output, target)

self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

return [self.model_parameters, loss.detach().cpu()]

def local_process(self, payload, id_list):
model_parameters = payload[0]
global_round = int(payload[-1].detach().cpu().item())
client_metrics = {}
client_parameters = {}
for id in tqdm(id_list, desc=">>> Local training"):
data_loader = self.dataset.get_dataloader(id, self.batch_size)
pack = self.train(model_parameters, data_loader)
self.cache.append(pack)
global_id = f'{self.rank}-{id % 10}'
client_parameters[global_id] = pack[0]
loss = float(pack[-1].numpy())
client_metrics[global_id] = {'loss': loss, 'nloss': -loss}

fedboard.log(global_round, client_metrics=client_metrics, client_params=client_parameters)
57 changes: 57 additions & 0 deletions examples/scale-mnist-board/pipeline/server_side.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import threading

import torch

from fedlab.board import fedboard
from fedlab.contrib.algorithm import SyncServerHandler
from fedlab.contrib.client_sampler.uniform_sampler import RandomSampler
from fedlab.core.server import SynchronousServerManager
from fedlab.utils import MessageCode


class ExampleManager(SynchronousServerManager):
def activate_clients(self, round):
self._LOGGER.info("Client activation procedure")
clients_this_round = self._handler.sample_clients()
rank_dict = self.coordinator.map_id_list(clients_this_round)

self._LOGGER.info("Client id list: {}".format(clients_this_round))

for rank, values in rank_dict.items():
downlink_package = self._handler.downlink_package
downlink_package.append(torch.tensor(round))
id_list = torch.Tensor(values).to(downlink_package[0].dtype)
self._network.send(content=[id_list] + downlink_package,
message_code=MessageCode.ParameterUpdate,
dst=rank)

def main_loop(self):
rd = 1
while self._handler.if_stop is not True:
activator = threading.Thread(target=self.activate_clients, args=[rd])
activator.start()
total_loss = 0
while True:
sender_rank, message_code, payload = self._network.recv()
if message_code == MessageCode.ParameterUpdate:
if self._handler.load(payload):
break
total_loss += payload[1].numpy()
# self._handler.evaluate()
metric = {'loss': total_loss, 'nloss': -total_loss}
fedboard.log(rd, metrics=metric)
rd += 1


class ExampleHandler(SyncServerHandler):

def sample_clients(self, num_to_sample=None):
if self.sampler is None:
self.sampler = RandomSampler(self.num_clients)
# new version with built-in sampler
self.round_clients = max(1, int(self.sample_ratio * self.num_clients))
sampled = self.sampler.sample(self.round_clients)
self.round_clients = len(sampled)

assert self.num_clients_per_round == len(sampled)
return sorted(sampled)
40 changes: 40 additions & 0 deletions examples/scale-mnist-board/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import argparse
import sys

sys.path.append("../../")

from fedlab.core.network import DistNetwork
from fedlab.models.mlp import MLP
from fedlab.board import fedboard
from fedlab.board.utils.roles import SERVER
from pipeline.server_side import ExampleHandler, ExampleManager

parser = argparse.ArgumentParser(description='FL server example')

parser.add_argument('--ip', type=str, default="127.0.0.1")
parser.add_argument('--port', type=str, default="3002")
parser.add_argument('--world_size', type=int)
parser.add_argument('--ethernet', type=str, default=None)

parser.add_argument('--round', type=int)
parser.add_argument('--sample', type=float, default=0.5)

args = parser.parse_args()

model = MLP(784, 10)

handler = ExampleHandler(model,
global_round=args.round,
sample_ratio=args.sample,

cuda=False)

network = DistNetwork(address=(args.ip, args.port),
world_size=args.world_size,
rank=0)

manager_ = ExampleManager(network=network, handler=handler, mode="GLOBAL")

fedboard.register(id='mtp-01', max_round=args.round, roles=SERVER)

manager_.run()
3 changes: 3 additions & 0 deletions examples/scale-mnist-board/stop_all.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ps -ef |grep client.py |grep -v grep |awk '{print "kill -9 "$2}' | sh
ps -ef |grep server.py |grep -v grep |awk '{print "kill -9 "$2}' | sh
ps -ef |grep board.py |grep -v grep |awk '{print "kill -9 "$2}' | sh
30 changes: 30 additions & 0 deletions examples/standalone-mnist-board/board.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Any

from fedlab.board.delegate import FedBoardDelegate


class ExampleDelegate(FedBoardDelegate):

def __init__(self, dataset):
super().__init__()
self.dataset = dataset

def sample_client_data(self, client_id: str, client_rank: str, type: str, amount: int) -> tuple[
list[Any], list[Any]]:
data = []
label = []
for dt in self.dataset.get_dataloader(client_id, batch_size=amount, type=type):
x, y = dt
for x_p in x:
data.append(x_p)
for y_p in y:
label.append(y_p)
break
return data, label

def read_client_label(self, client_id: str, client_rank: str, type: str) -> list[Any]:
res = []
for _, label in self.dataset.get_dataloader(client_id, batch_size=args.batch_size, type=type):
for y in label:
res.append(y.detach().cpu().item())
return res
3 changes: 0 additions & 3 deletions examples/standalone-mnist-board/board_detached.py

This file was deleted.

Loading

0 comments on commit 5005ea0

Please sign in to comment.