-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add support for multiprocess scenarios
- Loading branch information
1 parent
748c238
commit 5005ea0
Showing
24 changed files
with
759 additions
and
389 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import argparse | ||
import sys | ||
from typing import Any | ||
|
||
sys.path.append("../../") | ||
from fedlab.board import fedboard | ||
from fedlab.board.delegate import FedBoardDelegate | ||
from fedlab.board.utils.roles import BOARD_SHOWER | ||
from fedlab.contrib.dataset import PathologicalMNIST | ||
|
||
# Command-line interface: the only option is the port handed to FedBoard.
parser = argparse.ArgumentParser(description="FedBoard example")
parser.add_argument("--port", default="8070", type=str)
args = parser.parse_args()

# Register this process under the shared experiment id with the board role.
fedboard.register(roles=BOARD_SHOWER, id='mtp-01')

# Dataset handle read by the delegate class defined further down this script.
dataset = PathologicalMNIST(
    num_clients=100,
    root='../../datasets/mnist/',
    path="../../datasets/mnist/",
)
|
||
|
||
class mDelegate(FedBoardDelegate):
    """FedBoard delegate backed by the module-level ``dataset``.

    Both methods map a ``(client_rank, client_id)`` pair to a dataset index
    as ``int(client_rank) * 10 + <integer parsed from client_id>``.
    """

    def sample_client_data(
            self, client_id: str, client_rank: str, type: str, amount: int
    ) -> tuple[list[Any], list[Any]]:
        """Return the items of the first batch of ``amount`` samples.

        Args:
            client_id: Dash-separated id; the part after the last ``-`` is
                parsed as an integer.
            client_rank: Rank of the owning process, parsed as an integer.
            type: Forwarded to ``dataset.get_dataloader`` as ``type``.
            amount: Batch size requested from the data loader.

        Returns:
            ``(data, label)`` lists holding the elements of the first batch,
            or two empty lists when the loader yields nothing.
        """
        real_id = int(client_rank) * 10 + int(client_id.split('-')[-1])
        loader = dataset.get_dataloader(real_id, batch_size=amount, type=type)
        # Only the first batch is of interest; the rest is never consumed.
        first_batch = next(iter(loader), None)
        if first_batch is None:
            return [], []
        x, y = first_batch
        return list(x), list(y)

    def read_client_label(self, client_id: str, client_rank: str, type: str) -> list[Any]:
        """Collect every label of one client as plain Python scalars.

        Args:
            client_id: Dash-separated id; its second field is parsed as an
                integer.
            client_rank: Rank of the owning process, parsed as an integer.
            type: Forwarded to ``dataset.get_dataloader`` as ``type``.

        Returns:
            One scalar per label, in loader order (batches of 64).
        """
        real_id = int(client_rank) * 10 + int(client_id.split('-')[1])
        loader = dataset.get_dataloader(real_id, batch_size=64, type=type)
        return [y.detach().cpu().item() for _, labels in loader for y in labels]
|
||
|
||
# Hand a delegate instance to the built-in charts, then start the board on
# the port taken from the command line.
board_delegate = mDelegate()
fedboard.enable_builtin_charts(board_delegate)
fedboard.start_offline(port=args.port)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import argparse | ||
import sys | ||
|
||
import torch | ||
|
||
sys.path.append("../../") | ||
|
||
from fedlab.core.client import PassiveClientManager | ||
from fedlab.core.network import DistNetwork | ||
from pipeline.client_side import ExampleTrainer | ||
from fedlab.contrib.dataset.pathological_mnist import PathologicalMNIST | ||
from fedlab.models.mlp import MLP | ||
from fedlab.board import fedboard | ||
from fedlab.board.utils.roles import CLIENT_HOLDER | ||
|
||
parser = argparse.ArgumentParser(description="Distbelief training example")

parser.add_argument("--ip", type=str, default="127.0.0.1")
parser.add_argument("--port", type=str, default="3002")
# Neither value has a usable default: leaving them out used to pass None on
# to fedboard.register / DistNetwork, so fail early at argument parsing.
parser.add_argument("--world_size", type=int, required=True)
parser.add_argument("--rank", type=int, required=True)
parser.add_argument("--ethernet", type=str, default=None)

parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--epochs", type=int, default=2)
parser.add_argument("--batch_size", type=int, default=100)

args = parser.parse_args()

# Use CUDA whenever torch reports it as available.
args.cuda = torch.cuda.is_available()

model = MLP(784, 10)

# Register this process with FedBoard as the holder of ten clients whose ids
# are "<rank>-0" ... "<rank>-9".
fedboard.register(id='mtp-01', process_rank=args.rank, roles=CLIENT_HOLDER,
                  client_ids=[f'{args.rank}-{i}' for i in range(10)])

network = DistNetwork(address=(args.ip, args.port),
                      world_size=args.world_size,
                      rank=args.rank,
                      ethernet=args.ethernet)

trainer = ExampleTrainer(rank=args.rank, model=model, num_clients=10, cuda=args.cuda)

dataset = PathologicalMNIST(root='../../datasets/mnist/',
                            path="../../datasets/mnist/",
                            num_clients=100)

# Only the rank-1 process runs the preprocessing step.
# NOTE(review): other ranks presumably rely on its output already being on
# disk when they load data — confirm there is no start-up race.
if args.rank == 1:
    dataset.preprocess()

trainer.setup_dataset(dataset)
trainer.setup_optim(args.epochs, args.batch_size, args.lr)

manager_ = PassiveClientManager(trainer=trainer, network=network)
manager_.run()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#!/bin/bash
# Launch the multiprocess example: one server, ten clients (ranks 1-10) and
# the board process, all in the background, then wait for them to exit.

# Remove the .fedboard directory left in the working directory, if any.
rm -rf ./.fedboard

python server.py --world_size 11 --round 10 &
echo "server started"
sleep 2

i=1
while [ "${i}" -le 10 ]; do
    echo "client ${i} started"
    python client.py --world_size 11 --rank "${i}" &
    sleep 1
    i=$((i + 1))
done

python board.py &
echo "board started"
sleep 5

# Block until every background job has finished.
wait
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from tqdm import tqdm | ||
|
||
from fedlab.board import fedboard | ||
from fedlab.contrib.algorithm import SGDSerialClientTrainer | ||
|
||
|
||
class ExampleTrainer(SGDSerialClientTrainer):
    """Serial SGD trainer that also reports per-client results to FedBoard.

    Args:
        rank: Rank of the process hosting this trainer; used as the prefix
            of the client ids logged to FedBoard.
        **kwargs: Forwarded unchanged to ``SGDSerialClientTrainer``.
    """

    def __init__(self, rank, **kwargs):
        super().__init__(**kwargs)
        self.rank = rank

    def train(self, model_parameters, train_loader):
        """Run ``self.epochs`` passes over ``train_loader``.

        Args:
            model_parameters: Parameters loaded into the model via
                ``set_model`` before training starts.
            train_loader: Iterable yielding ``(data, target)`` batches.

        Returns:
            ``[self.model_parameters, loss]`` where ``loss`` is the detached
            CPU loss of the last processed batch, or a zero tensor when no
            batch was processed.
        """
        # Imported locally because this module does not import torch at
        # file level.
        import torch

        self.set_model(model_parameters)
        self._model.train()
        # Start from a tensor (not the int 0) so that ``loss.detach()`` below
        # still works when the loader is empty or ``self.epochs`` is 0.
        loss = torch.tensor(0.0)
        for _ in range(self.epochs):
            for data, target in train_loader:
                if self.cuda:
                    data = data.cuda(self.device)
                    target = target.cuda(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        return [self.model_parameters, loss.detach().cpu()]

    def local_process(self, payload, id_list):
        """Train every client in ``id_list`` and log the results to FedBoard.

        Args:
            payload: Sequence whose first element holds the model parameters
                and whose last element is a tensor holding the round number.
            id_list: Client indices to train; each one is passed to
                ``self.dataset.get_dataloader``.
        """
        model_parameters = payload[0]
        global_round = int(payload[-1].detach().cpu().item())
        client_metrics = {}
        client_parameters = {}
        for client_idx in tqdm(id_list, desc=">>> Local training"):
            data_loader = self.dataset.get_dataloader(client_idx, self.batch_size)
            pack = self.train(model_parameters, data_loader)
            self.cache.append(pack)
            # FedBoard ids have the form "<rank>-<index modulo 10>".
            global_id = f'{self.rank}-{client_idx % 10}'
            client_parameters[global_id] = pack[0]
            loss = float(pack[-1].numpy())
            client_metrics[global_id] = {'loss': loss, 'nloss': -loss}

        fedboard.log(global_round, client_metrics=client_metrics, client_params=client_parameters)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import threading | ||
|
||
import torch | ||
|
||
from fedlab.board import fedboard | ||
from fedlab.contrib.algorithm import SyncServerHandler | ||
from fedlab.contrib.client_sampler.uniform_sampler import RandomSampler | ||
from fedlab.core.server import SynchronousServerManager | ||
from fedlab.utils import MessageCode | ||
|
||
|
||
class ExampleManager(SynchronousServerManager):
    """Synchronous server manager that tags each downlink with the round
    number and logs the summed client loss of every round to FedBoard."""

    def activate_clients(self, round):
        """Sample clients and send each involved rank its ids and package.

        Args:
            round: Round number appended, as a tensor, to the end of the
                package sent to every rank.
        """
        self._LOGGER.info("Client activation procedure")
        clients_this_round = self._handler.sample_clients()
        rank_dict = self.coordinator.map_id_list(clients_this_round)

        self._LOGGER.info("Client id list: {}".format(clients_this_round))

        for rank, values in rank_dict.items():
            # Work on a copy so the round marker is never appended to a list
            # the handler may hand out again for the next rank.
            downlink_package = list(self._handler.downlink_package)
            downlink_package.append(torch.tensor(round))
            id_list = torch.Tensor(values).to(downlink_package[0].dtype)
            self._network.send(content=[id_list] + downlink_package,
                               message_code=MessageCode.ParameterUpdate,
                               dst=rank)

    def main_loop(self):
        """Run rounds until the handler reports it should stop.

        Each round starts client activation in a separate thread, receives
        messages until ``self._handler.load`` returns a true value, then
        logs the summed loss (and its negation) for that round.
        """
        rd = 1
        while self._handler.if_stop is not True:
            activator = threading.Thread(target=self.activate_clients, args=[rd])
            activator.start()
            total_loss = 0.0
            while True:
                sender_rank, message_code, payload = self._network.recv()
                if message_code == MessageCode.ParameterUpdate:
                    # Count the loss BEFORE load(): load() returning True
                    # ends the round, and breaking first would drop the loss
                    # of the payload that completed it.
                    # NOTE(review): payload[1] is presumably the single-value
                    # loss tensor uploaded by the trainer — confirm.
                    total_loss += payload[1].item()
                    if self._handler.load(payload):
                        break
            metric = {'loss': total_loss, 'nloss': -total_loss}
            fedboard.log(rd, metrics=metric)
            rd += 1
|
||
|
||
class ExampleHandler(SyncServerHandler):
    """Server handler that draws each round's clients with a ``RandomSampler``."""

    def sample_clients(self, num_to_sample=None):
        """Sample the clients for one round and return their ids sorted.

        The number requested is ``int(sample_ratio * num_clients)``, but at
        least 1. ``num_to_sample`` is accepted and not used.
        """
        # Create the sampler lazily on first use.
        if self.sampler is None:
            self.sampler = RandomSampler(self.num_clients)

        requested = int(self.sample_ratio * self.num_clients)
        self.round_clients = requested if requested > 1 else 1

        chosen = self.sampler.sample(self.round_clients)
        self.round_clients = len(chosen)

        assert self.num_clients_per_round == len(chosen)
        return sorted(chosen)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import argparse | ||
import sys | ||
|
||
sys.path.append("../../") | ||
|
||
from fedlab.core.network import DistNetwork | ||
from fedlab.models.mlp import MLP | ||
from fedlab.board import fedboard | ||
from fedlab.board.utils.roles import SERVER | ||
from pipeline.server_side import ExampleHandler, ExampleManager | ||
|
||
parser = argparse.ArgumentParser(description='FL server example')

parser.add_argument('--ip', type=str, default="127.0.0.1")
parser.add_argument('--port', type=str, default="3002")
# Neither value has a usable default: leaving them out used to pass None on
# to DistNetwork / the handler, so fail early at argument parsing.
parser.add_argument('--world_size', type=int, required=True)
parser.add_argument('--ethernet', type=str, default=None)

parser.add_argument('--round', type=int, required=True)
parser.add_argument('--sample', type=float, default=0.5)

args = parser.parse_args()

model = MLP(784, 10)

handler = ExampleHandler(model,
                         global_round=args.round,
                         sample_ratio=args.sample,
                         cuda=False)

# The server always takes rank 0.
network = DistNetwork(address=(args.ip, args.port),
                      world_size=args.world_size,
                      rank=0)

manager_ = ExampleManager(network=network, handler=handler, mode="GLOBAL")

# Register with FedBoard under the shared experiment id before the run starts.
fedboard.register(id='mtp-01', max_round=args.round, roles=SERVER)

manager_.run()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Force-kill every process whose `ps -ef` line mentions one of the example
# scripts; the grep process itself is filtered out before the kill commands
# are generated and piped to sh.
for script in client.py server.py board.py; do
    ps -ef | grep ${script} | grep -v grep | awk '{print "kill -9 "$2}' | sh
done
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from typing import Any | ||
|
||
from fedlab.board.delegate import FedBoardDelegate | ||
|
||
|
||
class ExampleDelegate(FedBoardDelegate):
    """FedBoard delegate that reads client samples and labels from a dataset.

    Args:
        dataset: Object exposing ``get_dataloader(client_id, batch_size=...,
            type=...)``.
        batch_size: Batch size used when reading all labels of a client.
            The method previously read ``args.batch_size``, but no ``args``
            exists in this module, so the value is now held on the instance.
    """

    def __init__(self, dataset, batch_size=64):
        super().__init__()
        self.dataset = dataset
        self._label_batch_size = batch_size

    def sample_client_data(self, client_id: str, client_rank: str, type: str, amount: int) -> tuple[
        list[Any], list[Any]]:
        """Return the items of the first batch of ``amount`` samples.

        Args:
            client_id: Passed unchanged to ``dataset.get_dataloader``.
            client_rank: Not used by this implementation.
            type: Forwarded to ``dataset.get_dataloader`` as ``type``.
            amount: Batch size requested from the data loader.

        Returns:
            ``(data, label)`` lists holding the elements of the first batch,
            or two empty lists when the loader yields nothing.
        """
        data = []
        label = []
        for x, y in self.dataset.get_dataloader(client_id, batch_size=amount, type=type):
            data.extend(x)
            label.extend(y)
            # Only the first batch is sampled.
            break
        return data, label

    def read_client_label(self, client_id: str, client_rank: str, type: str) -> list[Any]:
        """Collect every label of one client as plain Python scalars.

        Args:
            client_id: Passed unchanged to ``dataset.get_dataloader``.
            client_rank: Not used by this implementation.
            type: Forwarded to ``dataset.get_dataloader`` as ``type``.

        Returns:
            One scalar per label, in loader order.
        """
        loader = self.dataset.get_dataloader(
            client_id, batch_size=self._label_batch_size, type=type)
        return [y.detach().cpu().item() for _, labels in loader for y in labels]
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.