Skip to content


Browse files Browse the repository at this point in the history
  • Loading branch information
“stupidtree” authored and StupidTrees committed Jul 20, 2023
1 parent 771733b commit 0639df5
Show file tree
Hide file tree
Showing 22 changed files with 1,292 additions and 0 deletions.
Empty file.
3 changes: 3 additions & 0 deletions examples/standalone-mnist-board/
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from fedlab.visual import fedboard

114 changes: 114 additions & 0 deletions examples/standalone-mnist-board/
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from copy import deepcopy
from typing import List

import torch

from fedlab.contrib.algorithm import SyncServerHandler
from fedlab.contrib.client_sampler.base_sampler import FedSampler
from fedlab.contrib.client_sampler.uniform_sampler import RandomSampler
from fedlab.core.server.handler import ServerHandler
from fedlab.utils import Logger, Aggregators, SerializationTool

class StandaloneSyncServerHandler(ServerHandler):
    """Synchronous server handler used by the standalone FedBoard example.

    Each round the handler samples clients, buffers one upload per sampled
    client, aggregates the buffered parameters with FedAvg and loads the
    result into the global model.

    Args:
        model (torch.nn.Module): Global model maintained by the server.
        global_round (int): Number of communication rounds to run.
        sample_ratio (float): Fraction of clients per round, in ``[0, 1]``.
        cuda (bool): Forwarded to :class:`ServerHandler`.
        device (str, optional): Forwarded to :class:`ServerHandler`.
        sampler (FedSampler, optional): Client sampler. When ``None`` a
            :class:`RandomSampler` is created on the first call to
            :meth:`sample_clients`.
        logger (Logger, optional): Logger; a default ``Logger()`` is created
            when omitted.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        global_round: int,
        sample_ratio: float,
        cuda: bool = False,
        device: str = None,
        sampler: FedSampler = None,
        logger: Logger = None,
    ):
        super(StandaloneSyncServerHandler, self).__init__(model, cuda, device)

        self._LOGGER = Logger() if logger is None else logger
        assert sample_ratio >= 0.0 and sample_ratio <= 1.0

        # basic setting
        # num_clients starts at 0; the pipeline assigns the real value after
        # construction.
        self.num_clients = 0
        self.sample_ratio = sample_ratio
        self.sampler = sampler

        # client buffer
        # round_clients is recomputed by sample_clients() on every round, so
        # the value derived here is only an initial placeholder (>= 1).
        self.round_clients = max(
            1, int(self.sample_ratio * self.num_clients)
        )  # for dynamic client sampling
        self.client_buffer_cache = []

        # stop condition
        self.global_round = global_round
        self.round = 0

    @property
    def downlink_package(self) -> List[torch.Tensor]:
        """Payload broadcast to clients: a one-element list holding the
        serialized global model parameters."""
        return [self.model_parameters]

    @property
    def num_clients_per_round(self):
        """Number of clients selected in the current round."""
        return self.round_clients

    @property
    def if_stop(self):
        """``True`` once ``global_round`` rounds have been completed."""
        return self.round >= self.global_round

    def sample_clients(self, num_to_sample=None):
        """Select the clients that take part in this round.

        Args:
            num_to_sample (int, optional): How many clients to draw. Defaults
                to the current ``round_clients`` value.

        Returns:
            list: Sorted list of the sampled client ids.
        """
        # The sampler is created lazily because num_clients is only known
        # after the pipeline has attached the trainer.
        if self.sampler is None:
            self.sampler = RandomSampler(self.num_clients)

        num_to_sample = self.round_clients if num_to_sample is None else num_to_sample
        sampled = self.sampler.sample(num_to_sample)
        # Track the actual sample size so load() knows when the round is full.
        self.round_clients = len(sampled)

        assert self.num_clients_per_round == len(sampled)
        return sorted(sampled)

    def global_update(self, buffer):
        """Aggregate the buffered uploads with FedAvg and load the result
        into the global model.

        Args:
            buffer (list): Client uploads; element ``0`` of each upload is
                taken as that client's serialized parameters.
        """
        parameters_list = [ele[0] for ele in buffer]
        serialized_parameters = Aggregators.fedavg_aggregate(parameters_list)
        SerializationTool.deserialize_model(self._model, serialized_parameters)

    def load(self, payload: List[torch.Tensor]) -> bool:
        """Buffer one client upload and aggregate once the round is complete.

        Args:
            payload (list[torch.Tensor]): A list of tensors passed by the
                manager layer (one client's upload).

        Returns:
            bool: ``True`` when this upload completed the round (the global
            model was updated and the buffer reset), ``False`` otherwise.
        """
        assert len(payload) > 0
        # Copy so later in-place changes by the trainer cannot alter the
        # buffered upload.
        self.client_buffer_cache.append(deepcopy(payload))

        assert len(self.client_buffer_cache) <= self.num_clients_per_round

        if len(self.client_buffer_cache) == self.num_clients_per_round:
            self.global_update(self.client_buffer_cache)
            self.round += 1

            # reset cache
            self.client_buffer_cache = []

            return True  # return True to end this round.
        return False
3 changes: 3 additions & 0 deletions examples/standalone-mnist-board/
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

# Launch the standalone MNIST FedBoard example.
# NOTE(review): the script name was missing from this command; standalone.py is assumed — confirm against the example directory.
python standalone.py --total_client 100 --com_round 10 --sample_ratio 0.1 --batch_size 128 --epochs 3 --lr 0.1
47 changes: 47 additions & 0 deletions examples/standalone-mnist-board/
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np

from fedlab.core.client.trainer import SerialClientTrainer
from fedlab.core.server.handler import ServerHandler
from fedlab.board import fedboard

class StandalonePipeline(object):
    """Single-process federated training loop that reports to FedBoard."""

    def __init__(self, handler: ServerHandler, trainer: SerialClientTrainer):
        """Perform standalone simulation process.

        Args:
            handler (ServerHandler): Server-side handler (sampling,
                aggregation, stop condition).
            trainer (SerialClientTrainer): Client-side trainer that runs the
                local training of all sampled clients.
        """
        self.handler = handler
        self.trainer = trainer

        # initialization
        self.handler.num_clients = self.trainer.num_clients

    def main(self):
        """Run rounds until the handler signals stop, logging every round.

        Each round: sample clients, broadcast the global model, train
        locally, feed every upload back to the handler, then log losses and
        client parameters to FedBoard.
        """
        round_idx = 0
        while not self.handler.if_stop:
            # server side
            sampled_clients = self.handler.sample_clients(self.trainer.num_clients)
            broadcast = self.handler.downlink_package

            # client side
            self.trainer.local_process(broadcast, sampled_clients)
            uploads = self.trainer.uplink_package

            # server side
            for pack in uploads:
                self.handler.load(pack)

            # evaluate and log the result to FedBoard
            losses = self.evaluate()
            overall_loss = np.average(list(losses.values()))
            # 'nloss' (negated loss) uses the same key as the per-client
            # metrics below; the previous global key was misspelled 'nlosss'.
            metrics = {'loss': overall_loss, 'nloss': -overall_loss}
            client_metrics = {
                str(cid): {'loss': ls, 'nloss': -ls} for cid, ls in losses.items()
            }
            # Uploads come back in the order the clients were trained, so
            # pair them with the sampled ids rather than positional indices.
            client_params = {
                str(cid): pack for cid, pack in zip(sampled_clients, uploads)
            }
            fedboard.log(round_idx + 1, client_params=client_params,
                         metrics=metrics, main_metric_name='loss',
                         client_metrics=client_metrics)
            round_idx += 1

    def evaluate(self):
        """Return the trainer's per-client loss mapping."""
        return self.trainer.get_loss()
69 changes: 69 additions & 0 deletions examples/standalone-mnist-board/
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import argparse
import sys
from typing import Any

import torch

from fedlab.board import fedboard
from fedlab.board.delegate import FedBoardDelegate
from fedlab.board.fedboard import RuntimeFedBoard
from handler import StandaloneSyncServerHandler
from pipeline import StandalonePipeline
from trainer import StandaloneSerialClientTrainer
from fedlab.models.mlp import MLP
from fedlab.contrib.dataset.pathological_mnist import PathologicalMNIST

# configuration
parser = argparse.ArgumentParser(description="Standalone training example")
parser.add_argument("--total_client", type=int, default=50,
                    help="total number of simulated clients")
parser.add_argument("--com_round", type=int, default=1000,
                    help="number of communication rounds")

parser.add_argument("--sample_ratio", type=float, default=1.0,
                    help="fraction of clients sampled per round, in [0, 1]")
parser.add_argument("--batch_size", type=int, default=64,
                    help="local training batch size")
parser.add_argument("--epochs", type=int, default=1,
                    help="local epochs per round")
parser.add_argument("--lr", type=float, default=0.01,
                    help="local SGD learning rate")

args = parser.parse_args()

model = MLP(784, 10)

# server
handler = StandaloneSyncServerHandler(model, args.com_round, args.sample_ratio)

# client
trainer = StandaloneSerialClientTrainer(model, args.total_client, cuda=False)
dataset = PathologicalMNIST(root='../../datasets/mnist/', path="../../datasets/mnist/", num_clients=args.total_client)
# NOTE(review): the partitioned data under ``path`` presumably has to exist
# already (e.g. via a one-off ``dataset.preprocess()``) — confirm.
# The trainer reads ``self.dataset`` during local training, so the dataset
# must be attached before the pipeline starts.
trainer.setup_dataset(dataset)
trainer.setup_optim(args.epochs, args.batch_size, args.lr)

class mDelegate(FedBoardDelegate):
    """FedBoard delegate that serves client data from the module-level
    ``dataset``."""

    def sample_client_data(self, client_id: str, type: str, amount: int) -> tuple[list[Any], list[Any]]:
        """Return up to ``amount`` samples and their labels for one client.

        Args:
            client_id (str): Client identifier passed to
                ``dataset.get_dataloader``.
            type (str): Dataset split, forwarded as ``type=``.
            amount (int): Maximum number of samples to return.

        Returns:
            tuple[list, list]: ``(data, label)`` lists of equal length.
        """
        data = []
        label = []
        for x, y in dataset.get_dataloader(client_id, batch_size=amount, type=type):
            data.extend(x)
            label.extend(y)
            # The loader is asked for batches of ``amount``, so the first
            # batch is normally enough; stop as soon as the quota is met.
            if len(data) >= amount:
                break
        return data[:amount], label[:amount]

    def read_client_label(self, client_id: str, type: str) -> list[Any]:
        """Return every label of one client's split as a flat list.

        Args:
            client_id (str): Client identifier passed to
                ``dataset.get_dataloader``.
            type (str): Dataset split, forwarded as ``type=``.

        Returns:
            list: One entry per sample.
        """
        res = []
        for _, label in dataset.get_dataloader(client_id, batch_size=args.batch_size, type=type):
            for y in label:
                # Presumably labels arrive as 0-d tensors; convert them to
                # plain Python scalars so they compare and hash by value
                # (they are used as dict keys by the charts) — TODO confirm.
                res.append(y.item() if hasattr(y, "item") else y)
        return res

# main
pipeline = StandalonePipeline(handler, trainer)
# set up FedLabBoard
fedboard.setup(mDelegate(), max_round=args.com_round, client_ids=[str(i) for i in range(args.total_client)])
# Run the training loop inside the board context (port 8040).
# NOTE(review): presumably the context keeps the board serving while the
# experiment runs — confirm against RuntimeFedBoard.
with RuntimeFedBoard(port=8040):
    pipeline.main()
70 changes: 70 additions & 0 deletions examples/standalone-mnist-board/
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from copy import deepcopy

import torch
from tqdm import tqdm

from fedlab.core.client.trainer import SerialClientTrainer
from fedlab.utils import Logger

class StandaloneSerialClientTrainer(SerialClientTrainer):
    """Serial client trainer for the standalone FedBoard example.

    Sampled clients are trained one after another on the same model
    instance. Each client's upload is cached until it is collected through
    :attr:`uplink_package`, and the latest loss per client is kept in
    ``self.loss``.
    """

    def __init__(self, model, num_clients, cuda=False, device=None, logger=None, personal=False) -> None:
        super().__init__(model, num_clients, cuda, device, personal)
        self._LOGGER = Logger() if logger is None else logger
        self.cache = []  # uploads produced since the last uplink_package read
        self.loss = {}  # client id -> loss from its most recent training

    def setup_dataset(self, dataset):
        """Attach the dataset object that provides per-client dataloaders."""
        self.dataset = dataset

    def setup_optim(self, epochs, batch_size, lr):
        """Set up local optimization configuration.

        Args:
            epochs (int): Local epochs.
            batch_size (int): Local batch size.
            lr (float): Learning rate.
        """
        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr
        self.optimizer = torch.optim.SGD(self._model.parameters(), lr)
        self.criterion = torch.nn.CrossEntropyLoss()

    @property
    def uplink_package(self):
        """Return a copy of the cached uploads and clear the cache."""
        package = deepcopy(self.cache)
        self.cache = []
        return package

    def local_process(self, payload, id_list):
        """Train every client in ``id_list`` starting from the broadcast model.

        Args:
            payload (list): Broadcast package; element ``0`` holds the global
                model parameters.
            id_list (list): Ids of the clients to train this round.
        """
        model_parameters = payload[0]
        for cid in tqdm(id_list, desc=">>> Local training"):
            data_loader = self.dataset.get_dataloader(cid, self.batch_size)
            pack, loss = self.train(model_parameters, data_loader)
            # Queue the upload so uplink_package can hand it to the server.
            self.cache.append(pack)
            self.loss[cid] = loss

    def train(self, model_parameters, train_loader):
        """Run local SGD for one client.

        Args:
            model_parameters: Global parameters to start from.
            train_loader: Iterable of ``(data, target)`` batches.

        Returns:
            tuple: ``([model_parameters], total_loss)``, where ``total_loss``
            is the sum of batch losses over the final epoch (``0.0`` when no
            epoch is run).
        """
        # Start every client from the broadcast global model.
        self.set_model(model_parameters)
        self._model.train()

        total_loss = 0.0  # defined up front so epochs == 0 cannot leave it unbound
        for _ in range(self.epochs):
            total_loss = 0.0  # only the last epoch's loss is reported
            for data, target in train_loader:
                if self.cuda:
                    data = data.cuda(self.device)
                    target = target.cuda(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)
                total_loss += loss.detach().cpu().item()

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
        return [self.model_parameters], total_loss

    def get_loss(self):
        """Return the mapping of client id to its most recent loss."""
        return self.loss
1 change: 1 addition & 0 deletions fedlab/board/
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Names exported when this package is star-imported (``from <package> import *``).
__all__ = ['delegate', 'fedboard']
Empty file.
76 changes: 76 additions & 0 deletions fedlab/board/builtin/
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@

from import viewModel, _add_section, _add_chart
import plotly.graph_objects as go

def _add_built_in_charts():
    """Register the built-in FedBoard sections and their charts.

    Creates a ``dataset`` section and a ``parameters`` section, then
    registers four figure callbacks via ``_add_chart``: parameter t-SNE,
    per-client label distribution, per-client dataset size and local-dataset
    t-SNE.
    """
    _add_section('dataset', 'normal')
    _add_section('parameters', 'slider')

    @_add_chart(section='parameters', figure_name='figure_tsne', span=12)
    def update_tsne_figure(value, selected_client):
        # One marker per selected client; tsne_data is indexed by position
        # and read as 2-D points.
        tsne_data = viewModel.client_param_tsne(value, selected_client)
        if tsne_data is not None:
            data = []
            for idx, cid in enumerate(selected_client):
                data.append(go.Scatter(
                    x=[tsne_data[idx, 0]], y=[tsne_data[idx, 1]], mode='markers',
                    marker=dict(color=viewModel.get_color(cid), size=16),
                    name=f'Client{cid}',
                ))
            tsne_figure = go.Figure(data=data,
                                    layout_title_text="Parameters t-SNE")
        else:
            # No embedding available: hand back an empty placeholder.
            tsne_figure = []
        return tsne_figure

    @_add_chart(section='dataset', figure_name='figure_client_classes', span=6)
    def update_data_classes(selected_client):
        # Count, for every label value, how many samples each client holds.
        client_targets = viewModel.get_client_data_report(selected_client, type='train')
        class_sizes: dict[str, dict[str, int]] = {}
        for cid, targets in client_targets.items():
            for y in targets:
                class_sizes.setdefault(y, {c: 0 for c in selected_client})
                class_sizes[y][cid] += 1
        # One horizontal bar trace per class, stacked per client.
        client_classes = go.Figure(
            data=[
                go.Bar(y=[f'Client{c}' for c in selected_client],
                       x=[sizes[c] for c in selected_client],
                       name=f'Class {clz}', orientation='h')
                for clz, sizes in class_sizes.items()
            ],
            layout_title_text="Label Distribution"
        )
        client_classes.update_layout(barmode='stack', margin=dict(l=48, r=48, b=64, t=86))
        return client_classes

    @_add_chart(section='dataset', figure_name='figure_client_sizes', span=6)
    def update_data_sizes(selected_client):
        # Bar height is the number of targets reported for each client.
        client_targets = viewModel.get_client_data_report(selected_client, type='train')
        client_sizes = go.Figure(
            data=[go.Bar(x=[f'Client{n}' for n in client_targets],
                         y=[len(ce) for ce in client_targets.values()],
                         marker=dict(color=[viewModel.get_color(c) for c in selected_client]))],
            layout_title_text="Dataset Sizes"
        )
        client_sizes.update_layout(margin=dict(l=48, r=48, b=64, t=86))
        return client_sizes

    @_add_chart(section='dataset', figure_name='figure_client_data_tsne', span=12)
    def update_data_data_value(selected_client):
        # 3-D scatter of each client's embedded samples (200 is forwarded to
        # the view model; presumably a per-client sample cap — TODO confirm).
        tsne_data = viewModel.get_client_dataset_tsne(selected_client, "train", 200)
        if tsne_data is not None:
            data = [
                go.Scatter3d(
                    x=tsne_data[cid][:, 0], y=tsne_data[cid][:, 1], z=tsne_data[cid][:, 2], mode='markers',
                    marker=dict(color=viewModel.get_color(cid), size=4, opacity=0.8),
                    name=f'Client{cid}',
                )
                for cid in selected_client
            ]
        else:
            data = []
        tsne_figure = go.Figure(data=data,
                                layout_title_text="Local Dataset t-SNE")
        tsne_figure.update_layout(margin=dict(l=48, r=48, b=64, t=64), dict1={"height": 600})
        return tsne_figure

0 comments on commit 0639df5

Please sign in to comment.