-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
771733b
commit 0639df5
Showing
22 changed files
with
1,292 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from fedlab.visual import fedboard | ||
|
||
fedboard.start(port=8040) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
from copy import deepcopy | ||
from typing import List | ||
|
||
import torch | ||
|
||
from fedlab.contrib.algorithm import SyncServerHandler | ||
from fedlab.contrib.client_sampler.base_sampler import FedSampler | ||
from fedlab.contrib.client_sampler.uniform_sampler import RandomSampler | ||
from fedlab.core.server.handler import ServerHandler | ||
from fedlab.utils import Logger, Aggregators, SerializationTool | ||
|
||
|
||
class StandaloneSyncServerHandler(ServerHandler): | ||
|
||
def __init__( | ||
self, | ||
model: torch.nn.Module, | ||
global_round: int, | ||
sample_ratio: float, | ||
cuda: bool = False, | ||
device: str = None, | ||
sampler: FedSampler = None, | ||
logger: Logger = None, | ||
): | ||
super(StandaloneSyncServerHandler, self).__init__(model, cuda, device) | ||
|
||
self._LOGGER = Logger() if logger is None else logger | ||
assert sample_ratio >= 0.0 and sample_ratio <= 1.0 | ||
|
||
# basic setting | ||
self.num_clients = 0 | ||
self.sample_ratio = sample_ratio | ||
self.sampler = sampler | ||
|
||
# client buffer | ||
self.round_clients = max( | ||
1, int(self.sample_ratio * self.num_clients) | ||
) # for dynamic client sampling | ||
self.client_buffer_cache = [] | ||
|
||
# stop condition | ||
self.global_round = global_round | ||
self.round = 0 | ||
|
||
@property | ||
def downlink_package(self) -> List[torch.Tensor]: | ||
"""Property for manager layer. Server manager will call this property when activates clients.""" | ||
return [self.model_parameters] | ||
|
||
@property | ||
def num_clients_per_round(self): | ||
return self.round_clients | ||
|
||
@property | ||
def if_stop(self): | ||
""":class:`NetworkManager` keeps monitoring this attribute, and it will stop all related processes and threads when ``True`` returned.""" | ||
return self.round >= self.global_round | ||
|
||
# for built-in sampler | ||
# @property | ||
# def num_clients_per_round(self): | ||
# return max(1, int(self.sample_ratio * self.num_clients)) | ||
|
||
def sample_clients(self, num_to_sample=None): | ||
"""Return a list of client rank indices selected randomly. The client ID is from ``0`` to | ||
``self.num_clients -1``.""" | ||
# selection = random.sample(range(self.num_clients), | ||
# self.num_clients_per_round) | ||
# If the number of clients per round is not fixed, please change the value of self.sample_ratio correspondly. | ||
# self.sample_ratio = float(len(selection))/self.num_clients | ||
# assert self.num_clients_per_round == len(selection) | ||
|
||
if self.sampler is None: | ||
self.sampler = RandomSampler(self.num_clients) | ||
# new version with built-in sampler | ||
num_to_sample = self.round_clients if num_to_sample is None else num_to_sample | ||
sampled = self.sampler.sample(num_to_sample) | ||
self.round_clients = len(sampled) | ||
|
||
assert self.num_clients_per_round == len(sampled) | ||
return sorted(sampled) | ||
|
||
def global_update(self, buffer): | ||
parameters_list = [ele[0] for ele in buffer] | ||
serialized_parameters = Aggregators.fedavg_aggregate(parameters_list) | ||
SerializationTool.deserialize_model(self._model, serialized_parameters) | ||
|
||
def load(self, payload: List[torch.Tensor]) -> bool: | ||
"""Update global model with collected parameters from clients. | ||
Note: | ||
Server handler will call this method when its ``client_buffer_cache`` is full. User can | ||
overwrite the strategy of aggregation to apply on :attr:`model_parameters_list`, and | ||
use :meth:`SerializationTool.deserialize_model` to load serialized parameters after | ||
aggregation into :attr:`self._model`. | ||
Args: | ||
payload (list[torch.Tensor]): A list of tensors passed by manager layer. | ||
""" | ||
assert len(payload) > 0 | ||
self.client_buffer_cache.append(deepcopy(payload)) | ||
|
||
assert len(self.client_buffer_cache) <= self.num_clients_per_round | ||
|
||
if len(self.client_buffer_cache) == self.num_clients_per_round: | ||
self.global_update(self.client_buffer_cache) | ||
self.round += 1 | ||
|
||
# reset cache | ||
self.client_buffer_cache = [] | ||
|
||
return True # return True to end this round. | ||
else: | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
#!/bin/bash | ||
|
||
python standalone.py --total_client 100 --com_round 10 --sample_ratio 0.1 --batch_size 128 --epochs 3 --lr 0.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import numpy as np | ||
|
||
from fedlab.core.client.trainer import SerialClientTrainer | ||
from fedlab.core.server.handler import ServerHandler | ||
from fedlab.board import fedboard | ||
|
||
|
||
class StandalonePipeline(object): | ||
def __init__(self, handler: ServerHandler, trainer: SerialClientTrainer): | ||
"""Perform standalone simulation process. | ||
Args: | ||
handler (ServerHandler): _description_ | ||
trainer (SerialClientTrainer): _description_ | ||
""" | ||
self.handler = handler | ||
self.trainer = trainer | ||
|
||
# initialization | ||
self.handler.num_clients = self.trainer.num_clients | ||
|
||
def main(self): | ||
round = 0 | ||
while self.handler.if_stop is False: | ||
# server side | ||
sampled_clients = self.handler.sample_clients(self.trainer.num_clients) | ||
broadcast = self.handler.downlink_package | ||
|
||
# client side | ||
self.trainer.local_process(broadcast, sampled_clients) | ||
uploads = self.trainer.uplink_package | ||
|
||
# server side | ||
for pack in uploads: | ||
self.handler.load(pack) | ||
|
||
# evaluate and log the result to FedBoard | ||
losses = self.evaluate() | ||
overall_loss = np.average([l for l in losses.values()]) | ||
metrics = {'loss': overall_loss, 'nlosss': -overall_loss} | ||
client_metrics = {str(id): {'loss': ls, 'nloss': -ls} for id, ls in losses.items()} | ||
fedboard.log(round + 1, client_params={str(id): pack for id, pack in enumerate(uploads)}, | ||
metrics=metrics, main_metric_name='loss', client_metrics=client_metrics) | ||
round += 1 | ||
|
||
def evaluate(self): | ||
return self.trainer.get_loss() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import argparse | ||
import sys | ||
from typing import Any | ||
|
||
import torch | ||
|
||
sys.path.append("../../") | ||
torch.manual_seed(0) | ||
from fedlab.board import fedboard | ||
from fedlab.board.delegate import FedBoardDelegate | ||
from fedlab.board.fedboard import RuntimeFedBoard | ||
from handler import StandaloneSyncServerHandler | ||
from pipeline import StandalonePipeline | ||
from trainer import StandaloneSerialClientTrainer | ||
from fedlab.models.mlp import MLP | ||
from fedlab.contrib.dataset.pathological_mnist import PathologicalMNIST | ||
|
||
# configuration | ||
parser = argparse.ArgumentParser(description="Standalone training example") | ||
parser.add_argument("--total_client", type=int, default=50) | ||
parser.add_argument("--com_round", type=int, default=1000) | ||
|
||
parser.add_argument("--sample_ratio", type=float, default=1.0) | ||
parser.add_argument("--batch_size", type=int, default=64) | ||
parser.add_argument("--epochs", type=int, default=1) | ||
parser.add_argument("--lr", type=float, default=0.01) | ||
|
||
args = parser.parse_args() | ||
|
||
model = MLP(784, 10) | ||
|
||
# server | ||
handler = StandaloneSyncServerHandler(model, args.com_round, args.sample_ratio) | ||
|
||
# client | ||
trainer = StandaloneSerialClientTrainer(model, args.total_client, cuda=False) | ||
dataset = PathologicalMNIST(root='../../datasets/mnist/', path="../../datasets/mnist/", num_clients=args.total_client) | ||
dataset.preprocess() | ||
trainer.setup_dataset(dataset) | ||
trainer.setup_optim(args.epochs, args.batch_size, args.lr) | ||
|
||
|
||
class mDelegate(FedBoardDelegate): | ||
def sample_client_data(self, client_id: str, type: str, amount: int) -> tuple[list[Any], list[Any]]: | ||
data = [] | ||
label = [] | ||
for dt in dataset.get_dataloader(client_id, batch_size=amount, type=type): | ||
x, y = dt | ||
for x_p in x: | ||
data.append(x_p) | ||
for y_p in y: | ||
label.append(y_p) | ||
break | ||
return data, label | ||
|
||
def read_client_label(self, client_id: str, type: str) -> list[Any]: | ||
res = [] | ||
for _, label in dataset.get_dataloader(client_id, batch_size=args.batch_size, type=type): | ||
for y in label: | ||
res.append(y.detach().cpu().item()) | ||
return res | ||
|
||
# main | ||
pipeline = StandalonePipeline(handler, trainer) | ||
# set up FedLabBoard | ||
fedboard.setup(mDelegate(), max_round=args.com_round, client_ids=[str(i) for i in range(args.total_client)]) | ||
# pipeline.main() | ||
with RuntimeFedBoard(port=8040): | ||
pipeline.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from copy import deepcopy | ||
|
||
import torch | ||
from tqdm import tqdm | ||
|
||
from fedlab.core.client.trainer import SerialClientTrainer | ||
from fedlab.utils import Logger | ||
|
||
|
||
class StandaloneSerialClientTrainer(SerialClientTrainer): | ||
|
||
def __init__(self, model, num_clients, cuda=False, device=None, logger=None, personal=False) -> None: | ||
super().__init__(model, num_clients, cuda, device, personal) | ||
self._LOGGER = Logger() if logger is None else logger | ||
self.cache = [] | ||
self.loss = {} | ||
|
||
def setup_dataset(self, dataset): | ||
self.dataset = dataset | ||
|
||
def setup_optim(self, epochs, batch_size, lr): | ||
"""Set up local optimization configuration. | ||
Args: | ||
epochs (int): Local epochs. | ||
batch_size (int): Local batch size. | ||
lr (float): Learning rate. | ||
""" | ||
self.epochs = epochs | ||
self.batch_size = batch_size | ||
self.lr = lr | ||
self.optimizer = torch.optim.SGD(self._model.parameters(), lr) | ||
self.criterion = torch.nn.CrossEntropyLoss() | ||
|
||
@property | ||
def uplink_package(self): | ||
package = deepcopy(self.cache) | ||
self.cache = [] | ||
return package | ||
|
||
def local_process(self, payload, id_list): | ||
model_parameters = payload[0] | ||
for id in tqdm(id_list, desc=">>> Local training"): | ||
data_loader = self.dataset.get_dataloader(id, self.batch_size) | ||
pack, loss = self.train(model_parameters, data_loader) | ||
self.cache.append(pack) | ||
self.loss[id] = loss | ||
|
||
def train(self, model_parameters, train_loader): | ||
|
||
self.set_model(model_parameters) | ||
self._model.train() | ||
|
||
for _ in range(self.epochs): | ||
total_loss = 0 | ||
for data, target in train_loader: | ||
if self.cuda: | ||
data = data.cuda(self.device) | ||
target = target.cuda(self.device) | ||
|
||
output = self.model(data) | ||
loss = self.criterion(output, target) | ||
self.optimizer.zero_grad() | ||
loss.backward() | ||
self.optimizer.step() | ||
total_loss += loss.detach().cpu().item() | ||
return [self.model_parameters], total_loss | ||
|
||
def get_loss(self): | ||
return self.loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
__all__ = ['delegate', 'fedboard'] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
|
||
from fedlab.board.front.app import viewModel, _add_section, _add_chart | ||
import plotly.graph_objects as go | ||
|
||
|
||
def _add_built_in_charts(): | ||
_add_section('dataset', 'normal') | ||
_add_section('parameters', 'slider') | ||
|
||
@_add_chart(section='parameters', figure_name='figure_tsne', span=12) | ||
def update_tsne_figure(value, selected_client): | ||
tsne_data = viewModel.client_param_tsne(value, selected_client) | ||
if tsne_data is not None: | ||
data = [] | ||
for idx, cid in enumerate(selected_client): | ||
data.append(go.Scatter( | ||
x=[tsne_data[idx, 0]], y=[tsne_data[idx, 1]], mode='markers', | ||
marker=dict(color=viewModel.get_color(cid), size=16), | ||
name=f'Client{cid}' | ||
)) | ||
tsne_figure = go.Figure(data=data, | ||
layout_title_text=f"Parameters t-SNE") | ||
else: | ||
tsne_figure = [] | ||
return tsne_figure | ||
|
||
@_add_chart(section='dataset', figure_name='figure_client_classes', span=6) | ||
def update_data_classes(selected_client): | ||
client_targets = viewModel.get_client_data_report(selected_client, type='train') | ||
class_sizes: dict[str, dict[str, int]] = {} | ||
for cid, targets in client_targets.items(): | ||
for y in targets: | ||
class_sizes.setdefault(y, {id: 0 for id in selected_client}) | ||
class_sizes[y][cid] += 1 | ||
client_classes = go.Figure( | ||
data=[ | ||
go.Bar(y=[f'Client{id}' for id in selected_client], | ||
x=[sizes[id] for id in selected_client], | ||
name=f'Class {clz}', orientation='h') | ||
# marker=dict(color=[viewModel.colors[id] for id in selected_client])), | ||
for clz, sizes in class_sizes.items() | ||
], | ||
layout_title_text="Label Distribution" | ||
) | ||
client_classes.update_layout(barmode='stack', margin=dict(l=48, r=48, b=64, t=86)) | ||
return client_classes | ||
|
||
@_add_chart(section='dataset', figure_name='figure_client_sizes', span=6) | ||
def update_data_sizes(selected_client): | ||
client_targets = viewModel.get_client_data_report(selected_client, type='train') | ||
client_sizes = go.Figure( | ||
data=[go.Bar(x=[f'Client{n}' for n, _ in client_targets.items()], | ||
y=[len(ce) for _, ce in client_targets.items()], | ||
marker=dict(color=[viewModel.get_color(id) for id in selected_client]))], | ||
layout_title_text="Dataset Sizes" | ||
) | ||
client_sizes.update_layout(margin=dict(l=48, r=48, b=64, t=86)) | ||
return client_sizes | ||
|
||
@_add_chart(section='dataset', figure_name='figure_client_data_tsne', span=12) | ||
def update_data_data_value(selected_client): | ||
tsne_data = viewModel.get_client_dataset_tsne(selected_client, "train", 200) | ||
if tsne_data is not None: | ||
data = [] | ||
for idx, cid in enumerate(selected_client): | ||
data.append(go.Scatter3d( | ||
x=tsne_data[cid][:, 0], y=tsne_data[cid][:, 1], z=tsne_data[cid][:, 2], mode='markers', | ||
marker=dict(color=viewModel.get_color(cid), size=4, opacity=0.8), | ||
name=f'Client{cid}' | ||
)) | ||
else: | ||
data = [] | ||
tsne_figure = go.Figure(data=data, | ||
layout_title_text=f"Local Dataset t-SNE") | ||
tsne_figure.update_layout(margin=dict(l=48, r=48, b=64, t=64), dict1={"height": 600}) | ||
return tsne_figure |
Oops, something went wrong.