Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
“stupidtree” authored and StupidTrees committed Jul 20, 2023
1 parent 771733b commit 0639df5
Show file tree
Hide file tree
Showing 22 changed files with 1,292 additions and 0 deletions.
Empty file.
3 changes: 3 additions & 0 deletions examples/standalone-mnist-board/board_detached.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from fedlab.visual import fedboard

fedboard.start(port=8040)
114 changes: 114 additions & 0 deletions examples/standalone-mnist-board/handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from copy import deepcopy
from typing import List

import torch

from fedlab.contrib.algorithm import SyncServerHandler
from fedlab.contrib.client_sampler.base_sampler import FedSampler
from fedlab.contrib.client_sampler.uniform_sampler import RandomSampler
from fedlab.core.server.handler import ServerHandler
from fedlab.utils import Logger, Aggregators, SerializationTool


class StandaloneSyncServerHandler(ServerHandler):

def __init__(
self,
model: torch.nn.Module,
global_round: int,
sample_ratio: float,
cuda: bool = False,
device: str = None,
sampler: FedSampler = None,
logger: Logger = None,
):
super(StandaloneSyncServerHandler, self).__init__(model, cuda, device)

self._LOGGER = Logger() if logger is None else logger
assert sample_ratio >= 0.0 and sample_ratio <= 1.0

# basic setting
self.num_clients = 0
self.sample_ratio = sample_ratio
self.sampler = sampler

# client buffer
self.round_clients = max(
1, int(self.sample_ratio * self.num_clients)
) # for dynamic client sampling
self.client_buffer_cache = []

# stop condition
self.global_round = global_round
self.round = 0

@property
def downlink_package(self) -> List[torch.Tensor]:
"""Property for manager layer. Server manager will call this property when activates clients."""
return [self.model_parameters]

@property
def num_clients_per_round(self):
return self.round_clients

@property
def if_stop(self):
""":class:`NetworkManager` keeps monitoring this attribute, and it will stop all related processes and threads when ``True`` returned."""
return self.round >= self.global_round

# for built-in sampler
# @property
# def num_clients_per_round(self):
# return max(1, int(self.sample_ratio * self.num_clients))

def sample_clients(self, num_to_sample=None):
"""Return a list of client rank indices selected randomly. The client ID is from ``0`` to
``self.num_clients -1``."""
# selection = random.sample(range(self.num_clients),
# self.num_clients_per_round)
# If the number of clients per round is not fixed, please change the value of self.sample_ratio correspondly.
# self.sample_ratio = float(len(selection))/self.num_clients
# assert self.num_clients_per_round == len(selection)

if self.sampler is None:
self.sampler = RandomSampler(self.num_clients)
# new version with built-in sampler
num_to_sample = self.round_clients if num_to_sample is None else num_to_sample
sampled = self.sampler.sample(num_to_sample)
self.round_clients = len(sampled)

assert self.num_clients_per_round == len(sampled)
return sorted(sampled)

def global_update(self, buffer):
parameters_list = [ele[0] for ele in buffer]
serialized_parameters = Aggregators.fedavg_aggregate(parameters_list)
SerializationTool.deserialize_model(self._model, serialized_parameters)

def load(self, payload: List[torch.Tensor]) -> bool:
"""Update global model with collected parameters from clients.
Note:
Server handler will call this method when its ``client_buffer_cache`` is full. User can
overwrite the strategy of aggregation to apply on :attr:`model_parameters_list`, and
use :meth:`SerializationTool.deserialize_model` to load serialized parameters after
aggregation into :attr:`self._model`.
Args:
payload (list[torch.Tensor]): A list of tensors passed by manager layer.
"""
assert len(payload) > 0
self.client_buffer_cache.append(deepcopy(payload))

assert len(self.client_buffer_cache) <= self.num_clients_per_round

if len(self.client_buffer_cache) == self.num_clients_per_round:
self.global_update(self.client_buffer_cache)
self.round += 1

# reset cache
self.client_buffer_cache = []

return True # return True to end this round.
else:
return False
3 changes: 3 additions & 0 deletions examples/standalone-mnist-board/launch_eg.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash

python standalone.py --total_client 100 --com_round 10 --sample_ratio 0.1 --batch_size 128 --epochs 3 --lr 0.1
47 changes: 47 additions & 0 deletions examples/standalone-mnist-board/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np

from fedlab.core.client.trainer import SerialClientTrainer
from fedlab.core.server.handler import ServerHandler
from fedlab.board import fedboard


class StandalonePipeline(object):
def __init__(self, handler: ServerHandler, trainer: SerialClientTrainer):
"""Perform standalone simulation process.
Args:
handler (ServerHandler): _description_
trainer (SerialClientTrainer): _description_
"""
self.handler = handler
self.trainer = trainer

# initialization
self.handler.num_clients = self.trainer.num_clients

def main(self):
round = 0
while self.handler.if_stop is False:
# server side
sampled_clients = self.handler.sample_clients(self.trainer.num_clients)
broadcast = self.handler.downlink_package

# client side
self.trainer.local_process(broadcast, sampled_clients)
uploads = self.trainer.uplink_package

# server side
for pack in uploads:
self.handler.load(pack)

# evaluate and log the result to FedBoard
losses = self.evaluate()
overall_loss = np.average([l for l in losses.values()])
metrics = {'loss': overall_loss, 'nlosss': -overall_loss}
client_metrics = {str(id): {'loss': ls, 'nloss': -ls} for id, ls in losses.items()}
fedboard.log(round + 1, client_params={str(id): pack for id, pack in enumerate(uploads)},
metrics=metrics, main_metric_name='loss', client_metrics=client_metrics)
round += 1

def evaluate(self):
return self.trainer.get_loss()
69 changes: 69 additions & 0 deletions examples/standalone-mnist-board/standalone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import argparse
import sys
from typing import Any

import torch

sys.path.append("../../")
torch.manual_seed(0)
from fedlab.board import fedboard
from fedlab.board.delegate import FedBoardDelegate
from fedlab.board.fedboard import RuntimeFedBoard
from handler import StandaloneSyncServerHandler
from pipeline import StandalonePipeline
from trainer import StandaloneSerialClientTrainer
from fedlab.models.mlp import MLP
from fedlab.contrib.dataset.pathological_mnist import PathologicalMNIST

# configuration
parser = argparse.ArgumentParser(description="Standalone training example")
parser.add_argument("--total_client", type=int, default=50)
parser.add_argument("--com_round", type=int, default=1000)

parser.add_argument("--sample_ratio", type=float, default=1.0)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--lr", type=float, default=0.01)

args = parser.parse_args()

model = MLP(784, 10)

# server
handler = StandaloneSyncServerHandler(model, args.com_round, args.sample_ratio)

# client
trainer = StandaloneSerialClientTrainer(model, args.total_client, cuda=False)
dataset = PathologicalMNIST(root='../../datasets/mnist/', path="../../datasets/mnist/", num_clients=args.total_client)
dataset.preprocess()
trainer.setup_dataset(dataset)
trainer.setup_optim(args.epochs, args.batch_size, args.lr)


class mDelegate(FedBoardDelegate):
def sample_client_data(self, client_id: str, type: str, amount: int) -> tuple[list[Any], list[Any]]:
data = []
label = []
for dt in dataset.get_dataloader(client_id, batch_size=amount, type=type):
x, y = dt
for x_p in x:
data.append(x_p)
for y_p in y:
label.append(y_p)
break
return data, label

def read_client_label(self, client_id: str, type: str) -> list[Any]:
res = []
for _, label in dataset.get_dataloader(client_id, batch_size=args.batch_size, type=type):
for y in label:
res.append(y.detach().cpu().item())
return res

# main
pipeline = StandalonePipeline(handler, trainer)
# set up FedLabBoard
fedboard.setup(mDelegate(), max_round=args.com_round, client_ids=[str(i) for i in range(args.total_client)])
# pipeline.main()
with RuntimeFedBoard(port=8040):
pipeline.main()
70 changes: 70 additions & 0 deletions examples/standalone-mnist-board/trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from copy import deepcopy

import torch
from tqdm import tqdm

from fedlab.core.client.trainer import SerialClientTrainer
from fedlab.utils import Logger


class StandaloneSerialClientTrainer(SerialClientTrainer):

def __init__(self, model, num_clients, cuda=False, device=None, logger=None, personal=False) -> None:
super().__init__(model, num_clients, cuda, device, personal)
self._LOGGER = Logger() if logger is None else logger
self.cache = []
self.loss = {}

def setup_dataset(self, dataset):
self.dataset = dataset

def setup_optim(self, epochs, batch_size, lr):
"""Set up local optimization configuration.
Args:
epochs (int): Local epochs.
batch_size (int): Local batch size.
lr (float): Learning rate.
"""
self.epochs = epochs
self.batch_size = batch_size
self.lr = lr
self.optimizer = torch.optim.SGD(self._model.parameters(), lr)
self.criterion = torch.nn.CrossEntropyLoss()

@property
def uplink_package(self):
package = deepcopy(self.cache)
self.cache = []
return package

def local_process(self, payload, id_list):
model_parameters = payload[0]
for id in tqdm(id_list, desc=">>> Local training"):
data_loader = self.dataset.get_dataloader(id, self.batch_size)
pack, loss = self.train(model_parameters, data_loader)
self.cache.append(pack)
self.loss[id] = loss

def train(self, model_parameters, train_loader):

self.set_model(model_parameters)
self._model.train()

for _ in range(self.epochs):
total_loss = 0
for data, target in train_loader:
if self.cuda:
data = data.cuda(self.device)
target = target.cuda(self.device)

output = self.model(data)
loss = self.criterion(output, target)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
total_loss += loss.detach().cpu().item()
return [self.model_parameters], total_loss

def get_loss(self):
return self.loss
1 change: 1 addition & 0 deletions fedlab/board/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__all__ = ['delegate', 'fedboard']
Empty file.
76 changes: 76 additions & 0 deletions fedlab/board/builtin/charts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@

from fedlab.board.front.app import viewModel, _add_section, _add_chart
import plotly.graph_objects as go


def _add_built_in_charts():
_add_section('dataset', 'normal')
_add_section('parameters', 'slider')

@_add_chart(section='parameters', figure_name='figure_tsne', span=12)
def update_tsne_figure(value, selected_client):
tsne_data = viewModel.client_param_tsne(value, selected_client)
if tsne_data is not None:
data = []
for idx, cid in enumerate(selected_client):
data.append(go.Scatter(
x=[tsne_data[idx, 0]], y=[tsne_data[idx, 1]], mode='markers',
marker=dict(color=viewModel.get_color(cid), size=16),
name=f'Client{cid}'
))
tsne_figure = go.Figure(data=data,
layout_title_text=f"Parameters t-SNE")
else:
tsne_figure = []
return tsne_figure

@_add_chart(section='dataset', figure_name='figure_client_classes', span=6)
def update_data_classes(selected_client):
client_targets = viewModel.get_client_data_report(selected_client, type='train')
class_sizes: dict[str, dict[str, int]] = {}
for cid, targets in client_targets.items():
for y in targets:
class_sizes.setdefault(y, {id: 0 for id in selected_client})
class_sizes[y][cid] += 1
client_classes = go.Figure(
data=[
go.Bar(y=[f'Client{id}' for id in selected_client],
x=[sizes[id] for id in selected_client],
name=f'Class {clz}', orientation='h')
# marker=dict(color=[viewModel.colors[id] for id in selected_client])),
for clz, sizes in class_sizes.items()
],
layout_title_text="Label Distribution"
)
client_classes.update_layout(barmode='stack', margin=dict(l=48, r=48, b=64, t=86))
return client_classes

@_add_chart(section='dataset', figure_name='figure_client_sizes', span=6)
def update_data_sizes(selected_client):
client_targets = viewModel.get_client_data_report(selected_client, type='train')
client_sizes = go.Figure(
data=[go.Bar(x=[f'Client{n}' for n, _ in client_targets.items()],
y=[len(ce) for _, ce in client_targets.items()],
marker=dict(color=[viewModel.get_color(id) for id in selected_client]))],
layout_title_text="Dataset Sizes"
)
client_sizes.update_layout(margin=dict(l=48, r=48, b=64, t=86))
return client_sizes

@_add_chart(section='dataset', figure_name='figure_client_data_tsne', span=12)
def update_data_data_value(selected_client):
tsne_data = viewModel.get_client_dataset_tsne(selected_client, "train", 200)
if tsne_data is not None:
data = []
for idx, cid in enumerate(selected_client):
data.append(go.Scatter3d(
x=tsne_data[cid][:, 0], y=tsne_data[cid][:, 1], z=tsne_data[cid][:, 2], mode='markers',
marker=dict(color=viewModel.get_color(cid), size=4, opacity=0.8),
name=f'Client{cid}'
))
else:
data = []
tsne_figure = go.Figure(data=data,
layout_title_text=f"Local Dataset t-SNE")
tsne_figure.update_layout(margin=dict(l=48, r=48, b=64, t=64), dict1={"height": 600})
return tsne_figure
Loading

0 comments on commit 0639df5

Please sign in to comment.