Skip to content

Commit

Permalink
add support for diy charts
Browse files Browse the repository at this point in the history
  • Loading branch information
“stupidtree” authored and StupidTrees committed Jul 20, 2023
1 parent 0639df5 commit 5104735
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 55 deletions.
4 changes: 2 additions & 2 deletions examples/standalone-mnist-board/board_detached.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from fedlab.visual import fedboard
from fedlab.board import fedboard

fedboard.start(port=8040)
fedboard.start_offline(port=8040)
1 change: 0 additions & 1 deletion examples/standalone-mnist-board/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import torch

from fedlab.contrib.algorithm import SyncServerHandler
from fedlab.contrib.client_sampler.base_sampler import FedSampler
from fedlab.contrib.client_sampler.uniform_sampler import RandomSampler
from fedlab.core.server.handler import ServerHandler
Expand Down
53 changes: 48 additions & 5 deletions examples/standalone-mnist-board/standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any

import torch
from sklearn.manifold import TSNE

sys.path.append("../../")
torch.manual_seed(0)
Expand All @@ -14,6 +15,7 @@
from trainer import StandaloneSerialClientTrainer
from fedlab.models.mlp import MLP
from fedlab.contrib.dataset.pathological_mnist import PathologicalMNIST
import plotly.graph_objects as go

# configuration
parser = argparse.ArgumentParser(description="Standalone training example")
Expand All @@ -39,7 +41,12 @@
trainer.setup_dataset(dataset)
trainer.setup_optim(args.epochs, args.batch_size, args.lr)

# main
pipeline = StandalonePipeline(handler, trainer)


# set up FedLabBoard
# define delegate for additional dataset analysis
class mDelegate(FedBoardDelegate):
def sample_client_data(self, client_id: str, type: str, amount: int) -> tuple[list[Any], list[Any]]:
data = []
Expand All @@ -60,10 +67,46 @@ def read_client_label(self, client_id: str, type: str) -> list[Any]:
res.append(y.detach().cpu().item())
return res

# main
pipeline = StandalonePipeline(handler, trainer)
# set up FedLabBoard
fedboard.setup(mDelegate(), max_round=args.com_round, client_ids=[str(i) for i in range(args.total_client)])
# pipeline.main()

delegate = mDelegate()
fedboard.setup(delegate, max_round=args.com_round, client_ids=[str(i) for i in range(args.total_client)])

# Add diy chart
fedboard.add_section(section='diy', type='normal')


@fedboard.add_chart(section='diy', figure_name='2d-dataset-tsne', span=12)
def diy_chart(selected_clients, selected_colors):
"""
Args:
selected_clients: selected client ids, ['1','2',...'124']
selected_colors: colors of selected clients, ['#ffffff','#982223',...,'#128842']
Returns:
A Plotly Figure
"""
raw = []
client_range = {}
for client_id in selected_clients:
data, label = delegate.sample_client_data(client_id, 'train', 100)
client_range[client_id] = (len(raw), len(raw) + len(data))
raw += data
raw = torch.stack(raw).view(len(raw), -1)
tsne = TSNE(n_components=2, learning_rate=100, random_state=501,
perplexity=min(30.0, len(raw) - 1)).fit_transform(raw)
tsne_data = {cid: tsne[s:e] for cid, (s, e) in client_range.items()}
data = []
for idx, cid in enumerate(selected_clients):
data.append(go.Scatter(
x=tsne_data[cid][:, 0], y=tsne_data[cid][:, 1], mode='markers',
marker=dict(color=selected_colors[idx], size=8, opacity=0.8),
name=f'Client{cid}'
))
tsne_figure = go.Figure(data=data,
layout_title_text=f"Local Dataset 2D t-SNE")
tsne_figure.update_layout(margin=dict(l=48, r=48, b=64, t=80), dict1={"height": 600})
return tsne_figure


# start experiment along with FedBoard
with RuntimeFedBoard(port=8040):
pipeline.main()
15 changes: 7 additions & 8 deletions fedlab/board/builtin/charts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from fedlab.board.front.app import viewModel, _add_section, _add_chart
import plotly.graph_objects as go

Expand All @@ -8,14 +7,14 @@ def _add_built_in_charts():
_add_section('parameters', 'slider')

@_add_chart(section='parameters', figure_name='figure_tsne', span=12)
def update_tsne_figure(value, selected_client):
def update_tsne_figure(value, selected_client, selected_colors):
tsne_data = viewModel.client_param_tsne(value, selected_client)
if tsne_data is not None:
data = []
for idx, cid in enumerate(selected_client):
data.append(go.Scatter(
x=[tsne_data[idx, 0]], y=[tsne_data[idx, 1]], mode='markers',
marker=dict(color=viewModel.get_color(cid), size=16),
marker=dict(color=selected_colors[idx], size=16),
name=f'Client{cid}'
))
tsne_figure = go.Figure(data=data,
Expand All @@ -25,7 +24,7 @@ def update_tsne_figure(value, selected_client):
return tsne_figure

@_add_chart(section='dataset', figure_name='figure_client_classes', span=6)
def update_data_classes(selected_client):
def update_data_classes(selected_client, selected_colors):
client_targets = viewModel.get_client_data_report(selected_client, type='train')
class_sizes: dict[str, dict[str, int]] = {}
for cid, targets in client_targets.items():
Expand All @@ -46,26 +45,26 @@ def update_data_classes(selected_client):
return client_classes

@_add_chart(section='dataset', figure_name='figure_client_sizes', span=6)
def update_data_sizes(selected_client):
def update_data_sizes(selected_client, selected_colors):
client_targets = viewModel.get_client_data_report(selected_client, type='train')
client_sizes = go.Figure(
data=[go.Bar(x=[f'Client{n}' for n, _ in client_targets.items()],
y=[len(ce) for _, ce in client_targets.items()],
marker=dict(color=[viewModel.get_color(id) for id in selected_client]))],
marker=dict(color=selected_colors))],
layout_title_text="Dataset Sizes"
)
client_sizes.update_layout(margin=dict(l=48, r=48, b=64, t=86))
return client_sizes

@_add_chart(section='dataset', figure_name='figure_client_data_tsne', span=12)
def update_data_data_value(selected_client):
def update_data_tsne_value(selected_client, selected_colors):
tsne_data = viewModel.get_client_dataset_tsne(selected_client, "train", 200)
if tsne_data is not None:
data = []
for idx, cid in enumerate(selected_client):
data.append(go.Scatter3d(
x=tsne_data[cid][:, 0], y=tsne_data[cid][:, 1], z=tsne_data[cid][:, 2], mode='markers',
marker=dict(color=viewModel.get_color(cid), size=4, opacity=0.8),
marker=dict(color=selected_colors[idx], size=4, opacity=0.8),
name=f'Client{cid}'
))
else:
Expand Down
27 changes: 18 additions & 9 deletions fedlab/board/fedboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,16 @@ def setup(delegate: FedBoardDelegate, client_ids, max_round, name=None, log_dir=
_clear_log(log_dir)


def start(port=8080):
def start_offline(log_dir=None, port=8080):
if log_dir is None:
calling_file = inspect.stack()[1].filename
calling_directory = os.path.dirname(os.path.abspath(calling_file))
log_dir = calling_directory
log_dir = path.join(log_dir, '.fedboard/')
_add_built_in_charts()
global _app
_app = create_app(log_dir)
add_callbacks(_app)
if _app is None:
logging.error('FedBoard hasn\'t been initialized!')
return
Expand Down Expand Up @@ -76,27 +85,27 @@ def log(round: int, client_params: dict[str, Any] = None, metrics: dict[str, Any
state = 'DONE'
_update_meta_file(viewModel.dir, section='runtime', dct={'state': state, 'round': round})
if client_params:
os.makedirs(path.join(viewModel.dir, f'params/raw/'), exist_ok=True)
pickle.dump(client_params, open(path.join(viewModel.dir, f'params/raw/rd{round}.pkl'), 'wb+'))
os.makedirs(path.join(viewModel.dir, f'log/params/raw/'), exist_ok=True)
pickle.dump(client_params, open(path.join(viewModel.dir, f'log/params/raw/rd{round}.pkl'), 'wb+'))
if metrics:
os.makedirs(path.join(viewModel.dir, f'performs/'), exist_ok=True)
os.makedirs(path.join(viewModel.dir, f'log/performs/'), exist_ok=True)
if main_metric_name is None:
main_metric_name = list[metrics.keys()][0]
metrics['main_name'] = main_metric_name
with open(path.join(viewModel.dir, f'performs/overall'), 'a+') as f:
with open(path.join(viewModel.dir, f'log/performs/overall'), 'a+') as f:
f.write(json.dumps(metrics) + '\n')
if client_metrics:
os.makedirs(path.join(viewModel.dir, f'performs/'), exist_ok=True)
os.makedirs(path.join(viewModel.dir, f'log/performs/'), exist_ok=True)
if main_metric_name is None:
main_metric_name = list(client_metrics[list(client_metrics.keys())[0]].keys())[0]
for cid in client_metrics.keys():
client_metrics[cid]['main_name'] = main_metric_name
with open(path.join(viewModel.dir, f'performs/client'), 'a+') as f:
with open(path.join(viewModel.dir, f'log/performs/client'), 'a+') as f:
f.write(json.dumps(client_metrics) + '\n')


def add_sections(section: str, type: str):
_add_section(section=section,type=type)
def add_section(section: str, type: str):
_add_section(section=section, type=type)


def add_chart(section=None, figure_name=None, span=6):
Expand Down
6 changes: 4 additions & 2 deletions fedlab/board/front/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def wrapper(selected_client, regex, fig_id):
os.makedirs(cached_path, exist_ok=True)
if os.path.exists(cached_file):
return pickle.load(open(cached_file, 'rb'))
fig = dic[fig_id]['func'](selected_client)
selected_colors = [viewModel.get_color(id) for id in selected_client]
fig = dic[fig_id]['func'](selected_client, selected_colors)
pickle.dump(fig, open(cached_file, 'wb'))
return fig
return None
Expand All @@ -91,7 +92,8 @@ def wrapper(value, selected_client, regex, fig_id):
os.makedirs(cached_path, exist_ok=True)
if os.path.exists(cached_file):
return pickle.load(open(cached_file, 'rb'))
fig = dic[fig_id]['func'](value, selected_client)
selected_colors = [viewModel.get_color(id) for id in selected_client]
fig = dic[fig_id]['func'](value, selected_client, selected_colors)
pickle.dump(fig, open(cached_file, 'wb'))
return fig
return None
Expand Down
39 changes: 13 additions & 26 deletions fedlab/board/front/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def init(self, dir: str, delegate: FedBoardDelegate = None):
self.dir = dir
self.colors = {id: random_color(int(id)) for id in self.get_client_ids()}
self.setup = True
cache = diskcache.Cache(path.join(dir, "cache/"))
cache = diskcache.Cache(path.join(dir, "dash-cache/"))
self.background_callback_manager = DiskcacheManager(cache)

def get_color(self, client_id):
Expand Down Expand Up @@ -100,38 +100,26 @@ def encode_client_ids(self, client_ids: list[str]):
return encode_int_array(client_indexes)

def client_param_tsne(self, round: int, client_ids: list[str]):
if not os.path.exists(os.path.join(self.dir, f'params/raw/rd{round}.pkl')):
if not os.path.exists(os.path.join(self.dir, f'log/params/raw/rd{round}.pkl')):
return None
if len(client_ids) < 2:
return None
raw_params = {str(id): param for id, param in
pickle.load(open(os.path.join(self.dir, f'params/raw/rd{round}.pkl'), 'rb')).items()}
# fn = self.encode_client_ids(client_ids)
# target_file = os.path.join(self.dir, f'params/tsne/rd{round}/{fn}.pkl')
# if os.path.exists(target_file):
# return pickle.load(open(target_file, 'rb'))
os.makedirs(os.path.join(self.dir, f'params/tsne/rd{round}/'), exist_ok=True)
pickle.load(open(os.path.join(self.dir, f'log/params/raw/rd{round}.pkl'), 'rb')).items()}
params_selected = [raw_params[id][0] for id in client_ids if id in raw_params.keys()]
if len(params_selected) < 1:
return None
params_selected = torch.stack(params_selected)
params_tsne = TSNE(n_components=2, learning_rate=100, random_state=501,
perplexity=min(30.0, len(params_selected) - 1)).fit_transform(
params_selected)
# pickle.dump(params_tsne, open(target_file, 'wb+'))
return params_tsne

def get_client_dataset_tsne(self, client_ids: list, type: str, size):
if len(client_ids) < 2:
return None
if not self.delegate:
return None
# client_indexes = self.client_ids2indexes(client_ids)
# fn = encode_int_array(client_indexes)
# target_file = os.path.join(self.dir, f'data/tsne/{fn}.pkl')
# if os.path.exists(target_file):
# return pickle.load(open(target_file, 'rb'))
os.makedirs(os.path.join(self.dir, f'data/tsne/'), exist_ok=True)
raw = []
client_range = {}
for client_id in client_ids:
Expand All @@ -142,17 +130,16 @@ def get_client_dataset_tsne(self, client_ids: list, type: str, size):
tsne = TSNE(n_components=3, learning_rate=100, random_state=501,
perplexity=min(30.0, len(raw) - 1)).fit_transform(raw)
tsne = {cid: tsne[s:e] for cid, (s, e) in client_range.items()}
# pickle.dump(tsne, open(target_file, 'wb+'))
return tsne

def get_client_data_report(self, clients_ids: list, type: str):
res = {}
for client_id in clients_ids:
target_file = os.path.join(self.dir, f'data/partition/{client_id}.pkl')
target_file = os.path.join(self.dir, f'cache/data/partition/{client_id}.pkl')
if os.path.exists(target_file):
res[client_id] = pickle.load(open(target_file, 'rb'))
else:
os.makedirs(os.path.join(self.dir, f'data/partition/'), exist_ok=True)
os.makedirs(os.path.join(self.dir, f'cache/data/partition/'), exist_ok=True)
if self.delegate:
res[client_id] = self.delegate.read_client_label(client_id, type=type)
else:
Expand All @@ -163,9 +150,9 @@ def get_client_data_report(self, clients_ids: list, type: str):
def get_overall_metrics(self):
main_name = ""
metrics = []
if not os.path.exists(path.join(self.dir, f'performs/overall')):
if not os.path.exists(path.join(self.dir, f'log/performs/overall')):
return metrics, main_name
log_lines = open(path.join(self.dir, f'performs/overall')).readlines()
log_lines = open(path.join(self.dir, f'log/performs/overall')).readlines()
if len(log_lines) > 1:
obj: dict[str, Any] = json.loads(log_lines[-1])
main_name = obj['main_name']
Expand All @@ -175,9 +162,9 @@ def get_overall_metrics(self):
def get_client_metrics(self):
main_name = ""
metrics = []
if not os.path.exists(path.join(self.dir, f'performs/client')):
if not os.path.exists(path.join(self.dir, f'log/performs/client')):
return metrics, main_name
log_lines = open(path.join(self.dir, f'performs/client')).readlines()
log_lines = open(path.join(self.dir, f'log/performs/client')).readlines()
if len(log_lines) > 1:
obj: dict[str, dict[str:Any]] = json.loads(log_lines[-1])
if len(obj.keys()) > 0:
Expand All @@ -189,9 +176,9 @@ def get_client_metrics(self):
def get_overall_performance(self):
res_all = []
main_name = ""
if not os.path.exists(path.join(self.dir, f'performs/overall')):
if not os.path.exists(path.join(self.dir, f'log/performs/overall')):
return res_all, main_name
for line in open(path.join(self.dir, f'performs/overall')).readlines():
for line in open(path.join(self.dir, f'log/performs/overall')).readlines():
obj = json.loads(line)
main_name = obj['main_name']
res_all.append(obj)
Expand All @@ -200,9 +187,9 @@ def get_overall_performance(self):
def get_client_performance(self, client_ids: list[str]):
res = {}
main_name = ""
if not os.path.exists(path.join(self.dir, f'performs/client')):
if not os.path.exists(path.join(self.dir, f'log/performs/client')):
return res, main_name
for line in open(path.join(self.dir, f'performs/client')).readlines():
for line in open(path.join(self.dir, f'log/performs/client')).readlines():
obj = json.loads(line)
for client_id in client_ids:
main_name = obj[client_id]['main_name']
Expand Down
4 changes: 2 additions & 2 deletions fedlab/board/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@


def _clear_log(dir):
shutil.rmtree(path.join(dir, f'performs/'), ignore_errors=True)
shutil.rmtree(path.join(dir, f'params/'), ignore_errors=True)
shutil.rmtree(path.join(dir, f'log/'), ignore_errors=True)
shutil.rmtree(path.join(dir, f'cache/'), ignore_errors=True)


def _update_meta_file(file_root: str, section: str, dct: dict):
Expand Down

0 comments on commit 5104735

Please sign in to comment.