Skip to content

Commit

Permalink
code optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
StupidTrees committed Jul 20, 2023
1 parent 5104735 commit 748c238
Show file tree
Hide file tree
Showing 9 changed files with 300 additions and 148 deletions.
2 changes: 1 addition & 1 deletion examples/standalone-mnist-board/launch_eg.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/bin/bash

python standalone.py --total_client 100 --com_round 10 --sample_ratio 0.1 --batch_size 128 --epochs 3 --lr 0.1
python standalone.py --total_client 100 --com_round 10 --sample_ratio 0.1 --batch_size 128 --epochs 3 --lr 0.1 --port 8040
15 changes: 9 additions & 6 deletions examples/standalone-mnist-board/standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--port", type=int, default=8040)

args = parser.parse_args()

Expand All @@ -44,9 +45,11 @@
# main
pipeline = StandalonePipeline(handler, trainer)


# set up FedLabBoard
# define delegate for additional dataset analysis
fedboard.setup(max_round=args.com_round, client_ids=[str(i) for i in range(args.total_client)])


# To enable builtin figures, a dataset-reading delegate is required
class mDelegate(FedBoardDelegate):
def sample_client_data(self, client_id: str, type: str, amount: int) -> tuple[list[Any], list[Any]]:
data = []
Expand All @@ -69,13 +72,13 @@ def read_client_label(self, client_id: str, type: str) -> list[Any]:


delegate = mDelegate()
fedboard.setup(delegate, max_round=args.com_round, client_ids=[str(i) for i in range(args.total_client)])
fedboard.enable_builtin_charts(delegate)

# Add diy chart
fedboard.add_section(section='diy', type='normal')


@fedboard.add_chart(section='diy', figure_name='2d-dataset-tsne', span=12)
@fedboard.add_chart(section='diy', figure_name='2d-dataset-tsne', span=1.0)
def diy_chart(selected_clients, selected_colors):
"""
Args:
Expand Down Expand Up @@ -107,6 +110,6 @@ def diy_chart(selected_clients, selected_colors):
return tsne_figure


# start experiment along with FedBoard
with RuntimeFedBoard(port=8040):
# Start experiment along with FedBoard
with RuntimeFedBoard(port=args.port):
pipeline.main()
26 changes: 14 additions & 12 deletions fedlab/board/builtin/charts.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from fedlab.board.front.app import viewModel, _add_section, _add_chart
import plotly.graph_objects as go

from fedlab.board import fedboard
from fedlab.board.builtin.renderer import client_param_tsne, get_client_dataset_tsne, get_client_data_report

def _add_built_in_charts():
_add_section('dataset', 'normal')
_add_section('parameters', 'slider')

@_add_chart(section='parameters', figure_name='figure_tsne', span=12)
def add_built_in_charts():
fedboard.add_section('dataset', 'normal')
fedboard.add_section('parameters', 'slider')

@fedboard.add_chart(section='parameters', figure_name='figure_tsne', span=1.0)
def update_tsne_figure(value, selected_client, selected_colors):
tsne_data = viewModel.client_param_tsne(value, selected_client)
tsne_data = client_param_tsne(value, selected_client)
if tsne_data is not None:
data = []
for idx, cid in enumerate(selected_client):
Expand All @@ -23,9 +25,9 @@ def update_tsne_figure(value, selected_client, selected_colors):
tsne_figure = []
return tsne_figure

@_add_chart(section='dataset', figure_name='figure_client_classes', span=6)
@fedboard.add_chart(section='dataset', figure_name='figure_client_classes', span=0.5)
def update_data_classes(selected_client, selected_colors):
client_targets = viewModel.get_client_data_report(selected_client, type='train')
client_targets = get_client_data_report(selected_client, type='train')
class_sizes: dict[str, dict[str, int]] = {}
for cid, targets in client_targets.items():
for y in targets:
Expand All @@ -44,9 +46,9 @@ def update_data_classes(selected_client, selected_colors):
client_classes.update_layout(barmode='stack', margin=dict(l=48, r=48, b=64, t=86))
return client_classes

@_add_chart(section='dataset', figure_name='figure_client_sizes', span=6)
@fedboard.add_chart(section='dataset', figure_name='figure_client_sizes', span=0.5)
def update_data_sizes(selected_client, selected_colors):
client_targets = viewModel.get_client_data_report(selected_client, type='train')
client_targets = get_client_data_report(selected_client, type='train')
client_sizes = go.Figure(
data=[go.Bar(x=[f'Client{n}' for n, _ in client_targets.items()],
y=[len(ce) for _, ce in client_targets.items()],
Expand All @@ -56,9 +58,9 @@ def update_data_sizes(selected_client, selected_colors):
client_sizes.update_layout(margin=dict(l=48, r=48, b=64, t=86))
return client_sizes

@_add_chart(section='dataset', figure_name='figure_client_data_tsne', span=12)
@fedboard.add_chart(section='dataset', figure_name='figure_client_data_tsne', span=1.0)
def update_data_tsne_value(selected_client, selected_colors):
tsne_data = viewModel.get_client_dataset_tsne(selected_client, "train", 200)
tsne_data = get_client_dataset_tsne(selected_client, "train", 200)
if tsne_data is not None:
data = []
for idx, cid in enumerate(selected_client):
Expand Down
52 changes: 52 additions & 0 deletions fedlab/board/builtin/renderer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Any

import torch
from sklearn.manifold import TSNE

from fedlab.board import fedboard


def client_param_tsne(round: int, client_ids: list[str]):
if len(client_ids) < 2:
return None
client_params: dict[str, Any] = fedboard.read_logged_obj(round, 'client_params')
raw_params = {str(id): param for id, param in client_params.items()}
params_selected = [raw_params[id][0] for id in client_ids if id in raw_params.keys()]
if len(params_selected) < 1:
return None
params_selected = torch.stack(params_selected)
params_tsne = TSNE(n_components=2, learning_rate=100, random_state=501,
perplexity=min(30.0, len(params_selected) - 1)).fit_transform(
params_selected)
return params_tsne


def get_client_dataset_tsne(client_ids: list[str], type: str, size):
if len(client_ids) < 1:
return None
if not fedboard.get_delegate():
return None
raw = []
client_range = {}
for client_id in client_ids:
data, label = fedboard.get_delegate().sample_client_data(client_id, type, size)
client_range[client_id] = (len(raw), len(raw) + len(data))
raw += data
raw = torch.stack(raw).view(len(raw), -1)
tsne = TSNE(n_components=3, learning_rate=100, random_state=501,
perplexity=min(30.0, len(raw) - 1)).fit_transform(raw)
tsne = {cid: tsne[s:e] for cid, (s, e) in client_range.items()}
return tsne


def get_client_data_report(clients_ids: list[str], type: str):
res = {}
for client_id in clients_ids:
def rd():
if fedboard.get_delegate():
return fedboard.get_delegate().read_client_label(client_id,type=type)
else:
return {}
obj = fedboard.read_cached_obj('data','partition',f'{client_id}',rd)
res[client_id] = obj
return res
159 changes: 130 additions & 29 deletions fedlab/board/fedboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,40 @@
import json
import logging
import os
import pickle
from os import path
from threading import Thread
from typing import Any

from dash import Dash

from fedlab.board.builtin.charts import _add_built_in_charts
from fedlab.board.builtin.charts import add_built_in_charts
from fedlab.board.delegate import FedBoardDelegate
from fedlab.board.front.app import viewModel, create_app, add_callbacks, set_up_layout, _add_chart, _add_section
from fedlab.board.utils.io import _update_meta_file, _clear_log
from fedlab.board.front.app import viewModel, create_app,_set_up_layout, _add_chart, _add_section,_add_callbacks
from fedlab.board.utils.io import _update_meta_file, _clear_log, _log_to_fs, _read_log_from_fs, _read_cached_from_fs, \
_cache_to_fs, _log_to_fs_append, _read_log_from_fs_appended

_app: Dash | None = None
_delegate: FedBoardDelegate | None = None
_dir: str = ''


def setup(delegate: FedBoardDelegate, client_ids, max_round, name=None, log_dir=None):
def get_delegate():
return _delegate


def get_log_dir():
return _dir


def setup(client_ids: list[str], max_round: int, name: str = None, log_dir: str = None):
"""
Set up FedBoard
Args:
client_ids: List of client ids
max_round: Max communication round
name: Experiment name
log_dir: Log directory
"""
meta_info = {
'name': name,
'max_round': max_round,
Expand All @@ -31,26 +49,47 @@ def setup(delegate: FedBoardDelegate, client_ids, max_round, name=None, log_dir=
_update_meta_file(log_dir, 'meta', meta_info)
_update_meta_file(log_dir, 'runtime', {'state': 'START', 'round': 0})
global _app
_add_built_in_charts()
_app = create_app(log_dir, delegate)
add_callbacks(_app)
global _delegate
global _dir
_app = create_app(log_dir)
_dir = log_dir
_add_callbacks(_app)
_clear_log(log_dir)


def enable_builtin_charts(delegate: FedBoardDelegate):
"""
Enable builtin charts, including 'parameters' section and 'dataset' section.
A dataset-reading delegate is required to enable these charts
Args:
delegate (FedBoardDelegate): dataset-reading delegate
"""
global _delegate
_delegate = delegate
add_built_in_charts()


def start_offline(log_dir=None, port=8080):
"""
Start Fedboard offline (seperated from the experiment)
Args:
log_dir: the experiment's log directory
port: Which port will the board run in
"""
if log_dir is None:
calling_file = inspect.stack()[1].filename
calling_directory = os.path.dirname(os.path.abspath(calling_file))
log_dir = calling_directory
log_dir = path.join(log_dir, '.fedboard/')
_add_built_in_charts()
add_built_in_charts()
global _app
_app = create_app(log_dir)
add_callbacks(_app)
_add_callbacks(_app)
if _app is None:
logging.error('FedBoard hasn\'t been initialized!')
return
set_up_layout(_app)
_set_up_layout(_app)
_app.run(host='0.0.0.0', port=port, debug=False, dev_tools_ui=True, use_reloader=False)


Expand All @@ -60,14 +99,14 @@ def __init__(self, port):
meta_info = {
'port': port,
}
_update_meta_file(viewModel.dir, 'meta', meta_info)
_update_meta_file(_dir, 'meta', meta_info)
self.port = port

def _start_app(self):
if _app is None:
logging.error('FedBoard hasn\'t been initialized!')
return
set_up_layout(_app)
_set_up_layout(_app)
_app.run(host='0.0.0.0', port=self.port, debug=False, dev_tools_ui=True, use_reloader=False)

def __enter__(self):
Expand All @@ -78,35 +117,97 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.p1.join()


def log(round: int, client_params: dict[str, Any] = None, metrics: dict[str, Any] = None,
main_metric_name: str = None, client_metrics: dict[str, dict[str, Any]] = None):
def log(round: int, metrics: dict[str, Any] = None, client_metrics: dict[str, dict[str, Any]] = None,
main_metric_name: str = None, client_main_metric_name: str = None, **kwargs):
"""
Args:
round (int): Which communication round
metrics (dict): Global performance at this round. E.g., {'loss':0.02, 'acc':0.85}
client_metrics (dict): Client performance at this round. E.g., {'Client0':{'loss':0.01, 'acc':0.85}, 'Client1':...}
main_metric_name (str): Main global metric. E.g., 'loss'
client_main_metric_name (str): Main Client metric. E.g., 'acc'
Returns:
"""
state = "RUNNING"
if round == viewModel.get_max_round():
state = 'DONE'
_update_meta_file(viewModel.dir, section='runtime', dct={'state': state, 'round': round})
if client_params:
os.makedirs(path.join(viewModel.dir, f'log/params/raw/'), exist_ok=True)
pickle.dump(client_params, open(path.join(viewModel.dir, f'log/params/raw/rd{round}.pkl'), 'wb+'))
_update_meta_file(_dir, section='runtime', dct={'state': state, 'round': round})
for key, obj in kwargs.items():
_log_to_fs(_dir, type='params', sub_type=key, name=f'rd{round}', obj=obj)
if metrics:
os.makedirs(path.join(viewModel.dir, f'log/performs/'), exist_ok=True)
if main_metric_name is None:
main_metric_name = list[metrics.keys()][0]
metrics['main_name'] = main_metric_name
with open(path.join(viewModel.dir, f'log/performs/overall'), 'a+') as f:
f.write(json.dumps(metrics) + '\n')
_log_to_fs_append(_dir, type='performs', name='overall', obj=metrics)
if client_metrics:
os.makedirs(path.join(viewModel.dir, f'log/performs/'), exist_ok=True)
if main_metric_name is None:
main_metric_name = list(client_metrics[list(client_metrics.keys())[0]].keys())[0]
if client_main_metric_name is None:
if main_metric_name:
client_main_metric_name = main_metric_name
else:
client_main_metric_name = list(client_metrics[list(client_metrics.keys())[0]].keys())[0]
for cid in client_metrics.keys():
client_metrics[cid]['main_name'] = main_metric_name
with open(path.join(viewModel.dir, f'log/performs/client'), 'a+') as f:
f.write(json.dumps(client_metrics) + '\n')
client_metrics[cid]['main_name'] = client_main_metric_name
_log_to_fs_append(_dir, type='performs', name='client', obj=client_metrics)


def add_section(section: str, type: str):
"""
Args:
section (str): Section name
type (str): Section type, can be 'normal' and 'slider', when set to 'slider', additional
Returns:
"""
assert type in ['normal', 'slider']
_add_section(section=section, type=type)


def add_chart(section=None, figure_name=None, span=6):
def add_chart(section=None, figure_name=None, span=0.5):
"""
Used as decorators for other functions,
For sections with type = 'normal', the function takes input (selected_clients, selected_colors)
For sections with type = 'slider', the function takes input (slider_value, selected_clients, selected_colors)
Args:
section (str): Section the chart will be added to
figure_name (str) : Chart ID
span (float): Chart span, E.g. 0.6 for 60% row width
Examples:
@add_chart('diy', 'slider', 1.0)
def ct(slider_value, selected_clients, selected_colors):
...render the figure
return figure
@add_chart('diy2', 'slider', 1.0)
def ct(slider_value, selected_clients, selected_colors):
...render the figure
return figure
"""
return _add_chart(section=section, figure_name=figure_name, span=span)


def read_logged_obj(round: int, type: str):
return _read_log_from_fs(_dir, type='params', sub_type=type, name=f'rd{round}')


def read_logged_obj_appended(type: str, name: str, sub_type: str = None):
return _read_log_from_fs_appended(_dir, type=type, name=name, sub_type=sub_type)


def read_cached_obj(type: str, sub_type: str, key: str, creator: callable):
obj = _read_cached_from_fs(_dir, type=type, sub_type=sub_type, name=key)
if not obj:
obj = creator()
_cache_to_fs(obj, _dir, type, sub_type, key)
return obj
Loading

0 comments on commit 748c238

Please sign in to comment.