diff --git a/examples/standalone-mnist-board/launch_eg.sh b/examples/standalone-mnist-board/launch_eg.sh index 4630bafd..4542156e 100644 --- a/examples/standalone-mnist-board/launch_eg.sh +++ b/examples/standalone-mnist-board/launch_eg.sh @@ -1,3 +1,3 @@ #!/bin/bash -python standalone.py --total_client 100 --com_round 10 --sample_ratio 0.1 --batch_size 128 --epochs 3 --lr 0.1 +python standalone.py --total_client 100 --com_round 10 --sample_ratio 0.1 --batch_size 128 --epochs 3 --lr 0.1 --port 8040 diff --git a/examples/standalone-mnist-board/standalone.py b/examples/standalone-mnist-board/standalone.py index f6e59abd..bd9446ec 100644 --- a/examples/standalone-mnist-board/standalone.py +++ b/examples/standalone-mnist-board/standalone.py @@ -26,6 +26,7 @@ parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--epochs", type=int, default=1) parser.add_argument("--lr", type=float, default=0.01) +parser.add_argument("--port", type=int, default=8040) args = parser.parse_args() @@ -44,9 +45,11 @@ # main pipeline = StandalonePipeline(handler, trainer) - # set up FedLabBoard -# define delegate for additional dataset analysis +fedboard.setup(max_round=args.com_round, client_ids=[str(i) for i in range(args.total_client)]) + + +# To enable builtin figures, a dataset-reading delegate is required class mDelegate(FedBoardDelegate): def sample_client_data(self, client_id: str, type: str, amount: int) -> tuple[list[Any], list[Any]]: data = [] @@ -69,13 +72,13 @@ def read_client_label(self, client_id: str, type: str) -> list[Any]: delegate = mDelegate() -fedboard.setup(delegate, max_round=args.com_round, client_ids=[str(i) for i in range(args.total_client)]) +fedboard.enable_builtin_charts(delegate) # Add diy chart fedboard.add_section(section='diy', type='normal') -@fedboard.add_chart(section='diy', figure_name='2d-dataset-tsne', span=12) +@fedboard.add_chart(section='diy', figure_name='2d-dataset-tsne', span=1.0) def diy_chart(selected_clients, 
selected_colors): """ Args: @@ -107,6 +110,6 @@ def diy_chart(selected_clients, selected_colors): return tsne_figure -# start experiment along with FedBoard -with RuntimeFedBoard(port=8040): +# Start experiment along with FedBoard +with RuntimeFedBoard(port=args.port): pipeline.main() diff --git a/fedlab/board/builtin/charts.py b/fedlab/board/builtin/charts.py index 551e3766..50d5f5a2 100644 --- a/fedlab/board/builtin/charts.py +++ b/fedlab/board/builtin/charts.py @@ -1,14 +1,16 @@ -from fedlab.board.front.app import viewModel, _add_section, _add_chart import plotly.graph_objects as go +from fedlab.board import fedboard +from fedlab.board.builtin.renderer import client_param_tsne, get_client_dataset_tsne, get_client_data_report -def _add_built_in_charts(): - _add_section('dataset', 'normal') - _add_section('parameters', 'slider') - @_add_chart(section='parameters', figure_name='figure_tsne', span=12) +def add_built_in_charts(): + fedboard.add_section('dataset', 'normal') + fedboard.add_section('parameters', 'slider') + + @fedboard.add_chart(section='parameters', figure_name='figure_tsne', span=1.0) def update_tsne_figure(value, selected_client, selected_colors): - tsne_data = viewModel.client_param_tsne(value, selected_client) + tsne_data = client_param_tsne(value, selected_client) if tsne_data is not None: data = [] for idx, cid in enumerate(selected_client): @@ -23,9 +25,9 @@ def update_tsne_figure(value, selected_client, selected_colors): tsne_figure = [] return tsne_figure - @_add_chart(section='dataset', figure_name='figure_client_classes', span=6) + @fedboard.add_chart(section='dataset', figure_name='figure_client_classes', span=0.5) def update_data_classes(selected_client, selected_colors): - client_targets = viewModel.get_client_data_report(selected_client, type='train') + client_targets = get_client_data_report(selected_client, type='train') class_sizes: dict[str, dict[str, int]] = {} for cid, targets in client_targets.items(): for y in targets: @@ -44,9 
+46,9 @@ def update_data_classes(selected_client, selected_colors): client_classes.update_layout(barmode='stack', margin=dict(l=48, r=48, b=64, t=86)) return client_classes - @_add_chart(section='dataset', figure_name='figure_client_sizes', span=6) + @fedboard.add_chart(section='dataset', figure_name='figure_client_sizes', span=0.5) def update_data_sizes(selected_client, selected_colors): - client_targets = viewModel.get_client_data_report(selected_client, type='train') + client_targets = get_client_data_report(selected_client, type='train') client_sizes = go.Figure( data=[go.Bar(x=[f'Client{n}' for n, _ in client_targets.items()], y=[len(ce) for _, ce in client_targets.items()], @@ -56,9 +58,9 @@ def update_data_sizes(selected_client, selected_colors): client_sizes.update_layout(margin=dict(l=48, r=48, b=64, t=86)) return client_sizes - @_add_chart(section='dataset', figure_name='figure_client_data_tsne', span=12) + @fedboard.add_chart(section='dataset', figure_name='figure_client_data_tsne', span=1.0) def update_data_tsne_value(selected_client, selected_colors): - tsne_data = viewModel.get_client_dataset_tsne(selected_client, "train", 200) + tsne_data = get_client_dataset_tsne(selected_client, "train", 200) if tsne_data is not None: data = [] for idx, cid in enumerate(selected_client): diff --git a/fedlab/board/builtin/renderer.py b/fedlab/board/builtin/renderer.py new file mode 100644 index 00000000..e7221c3e --- /dev/null +++ b/fedlab/board/builtin/renderer.py @@ -0,0 +1,52 @@ +from typing import Any + +import torch +from sklearn.manifold import TSNE + +from fedlab.board import fedboard + + +def client_param_tsne(round: int, client_ids: list[str]): + if len(client_ids) < 2: + return None + client_params: dict[str, Any] = fedboard.read_logged_obj(round, 'client_params') + raw_params = {str(id): param for id, param in client_params.items()} + params_selected = [raw_params[id][0] for id in client_ids if id in raw_params.keys()] + if len(params_selected) < 1: + 
return None + params_selected = torch.stack(params_selected) + params_tsne = TSNE(n_components=2, learning_rate=100, random_state=501, + perplexity=min(30.0, len(params_selected) - 1)).fit_transform( + params_selected) + return params_tsne + + +def get_client_dataset_tsne(client_ids: list[str], type: str, size): + if len(client_ids) < 1: + return None + if not fedboard.get_delegate(): + return None + raw = [] + client_range = {} + for client_id in client_ids: + data, label = fedboard.get_delegate().sample_client_data(client_id, type, size) + client_range[client_id] = (len(raw), len(raw) + len(data)) + raw += data + raw = torch.stack(raw).view(len(raw), -1) + tsne = TSNE(n_components=3, learning_rate=100, random_state=501, + perplexity=min(30.0, len(raw) - 1)).fit_transform(raw) + tsne = {cid: tsne[s:e] for cid, (s, e) in client_range.items()} + return tsne + + +def get_client_data_report(clients_ids: list[str], type: str): + res = {} + for client_id in clients_ids: + def rd(): + if fedboard.get_delegate(): + return fedboard.get_delegate().read_client_label(client_id,type=type) + else: + return {} + obj = fedboard.read_cached_obj('data','partition',f'{client_id}',rd) + res[client_id] = obj + return res diff --git a/fedlab/board/fedboard.py b/fedlab/board/fedboard.py index 3a0c61e8..a2db6666 100644 --- a/fedlab/board/fedboard.py +++ b/fedlab/board/fedboard.py @@ -2,22 +2,40 @@ import json import logging import os -import pickle from os import path from threading import Thread from typing import Any from dash import Dash -from fedlab.board.builtin.charts import _add_built_in_charts +from fedlab.board.builtin.charts import add_built_in_charts from fedlab.board.delegate import FedBoardDelegate -from fedlab.board.front.app import viewModel, create_app, add_callbacks, set_up_layout, _add_chart, _add_section -from fedlab.board.utils.io import _update_meta_file, _clear_log +from fedlab.board.front.app import viewModel, create_app,_set_up_layout, _add_chart, 
_add_section,_add_callbacks +from fedlab.board.utils.io import _update_meta_file, _clear_log, _log_to_fs, _read_log_from_fs, _read_cached_from_fs, \ + _cache_to_fs, _log_to_fs_append, _read_log_from_fs_appended _app: Dash | None = None +_delegate: FedBoardDelegate | None = None +_dir: str = '' -def setup(delegate: FedBoardDelegate, client_ids, max_round, name=None, log_dir=None): +def get_delegate(): + return _delegate + + +def get_log_dir(): + return _dir + + +def setup(client_ids: list[str], max_round: int, name: str = None, log_dir: str = None): + """ + Set up FedBoard + Args: + client_ids: List of client ids + max_round: Max communication round + name: Experiment name + log_dir: Log directory + """ meta_info = { 'name': name, 'max_round': max_round, @@ -31,26 +49,47 @@ def setup(delegate: FedBoardDelegate, client_ids, max_round, name=None, log_dir= _update_meta_file(log_dir, 'meta', meta_info) _update_meta_file(log_dir, 'runtime', {'state': 'START', 'round': 0}) global _app - _add_built_in_charts() - _app = create_app(log_dir, delegate) - add_callbacks(_app) + global _delegate + global _dir + _app = create_app(log_dir) + _dir = log_dir + _add_callbacks(_app) _clear_log(log_dir) +def enable_builtin_charts(delegate: FedBoardDelegate): + """ + Enable builtin charts, including 'parameters' section and 'dataset' section. 
+ A dataset-reading delegate is required to enable these charts + Args: + delegate (FedBoardDelegate): dataset-reading delegate + + """ + global _delegate + _delegate = delegate + add_built_in_charts() + + def start_offline(log_dir=None, port=8080): + """ + Start FedBoard offline (separated from the experiment) + Args: + log_dir: the experiment's log directory + port: the port the board will run on + """ if log_dir is None: calling_file = inspect.stack()[1].filename calling_directory = os.path.dirname(os.path.abspath(calling_file)) log_dir = calling_directory log_dir = path.join(log_dir, '.fedboard/') - _add_built_in_charts() + add_built_in_charts() global _app _app = create_app(log_dir) - add_callbacks(_app) + _add_callbacks(_app) if _app is None: logging.error('FedBoard hasn\'t been initialized!') return - set_up_layout(_app) + _set_up_layout(_app) _app.run(host='0.0.0.0', port=port, debug=False, dev_tools_ui=True, use_reloader=False) @@ -60,14 +99,14 @@ def __init__(self, port): meta_info = { 'port': port, } - _update_meta_file(viewModel.dir, 'meta', meta_info) + _update_meta_file(_dir, 'meta', meta_info) self.port = port def _start_app(self): if _app is None: logging.error('FedBoard hasn\'t been initialized!') return - set_up_layout(_app) + _set_up_layout(_app) _app.run(host='0.0.0.0', port=self.port, debug=False, dev_tools_ui=True, use_reloader=False) def __enter__(self): @@ -78,35 +117,97 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.p1.join() -def log(round: int, client_params: dict[str, Any] = None, metrics: dict[str, Any] = None, - main_metric_name: str = None, client_metrics: dict[str, dict[str, Any]] = None): +def log(round: int, metrics: dict[str, Any] = None, client_metrics: dict[str, dict[str, Any]] = None, + main_metric_name: str = None, client_main_metric_name: str = None, **kwargs): + """ + + Args: + + round (int): Which communication round + metrics (dict): Global performance at this round. 
E.g., {'loss':0.02, 'acc':0.85} + client_metrics (dict): Client performance at this round. E.g., {'Client0':{'loss':0.01, 'acc':0.85}, 'Client1':...} + main_metric_name (str): Main global metric. E.g., 'loss' + client_main_metric_name (str): Main Client metric. E.g., 'acc' + + + Returns: + + """ state = "RUNNING" if round == viewModel.get_max_round(): state = 'DONE' - _update_meta_file(viewModel.dir, section='runtime', dct={'state': state, 'round': round}) - if client_params: - os.makedirs(path.join(viewModel.dir, f'log/params/raw/'), exist_ok=True) - pickle.dump(client_params, open(path.join(viewModel.dir, f'log/params/raw/rd{round}.pkl'), 'wb+')) + _update_meta_file(_dir, section='runtime', dct={'state': state, 'round': round}) + for key, obj in kwargs.items(): + _log_to_fs(_dir, type='params', sub_type=key, name=f'rd{round}', obj=obj) if metrics: - os.makedirs(path.join(viewModel.dir, f'log/performs/'), exist_ok=True) if main_metric_name is None: main_metric_name = list[metrics.keys()][0] metrics['main_name'] = main_metric_name - with open(path.join(viewModel.dir, f'log/performs/overall'), 'a+') as f: - f.write(json.dumps(metrics) + '\n') + _log_to_fs_append(_dir, type='performs', name='overall', obj=metrics) if client_metrics: - os.makedirs(path.join(viewModel.dir, f'log/performs/'), exist_ok=True) - if main_metric_name is None: - main_metric_name = list(client_metrics[list(client_metrics.keys())[0]].keys())[0] + if client_main_metric_name is None: + if main_metric_name: + client_main_metric_name = main_metric_name + else: + client_main_metric_name = list(client_metrics[list(client_metrics.keys())[0]].keys())[0] for cid in client_metrics.keys(): - client_metrics[cid]['main_name'] = main_metric_name - with open(path.join(viewModel.dir, f'log/performs/client'), 'a+') as f: - f.write(json.dumps(client_metrics) + '\n') + client_metrics[cid]['main_name'] = client_main_metric_name + _log_to_fs_append(_dir, type='performs', name='client', obj=client_metrics) def 
add_section(section: str, type: str): + """ + + Args: + section (str): Section name + type (str): Section type, can be 'normal' or 'slider'; when set to 'slider', chart functions additionally receive the slider value + + + Returns: + + """ + assert type in ['normal', 'slider'] _add_section(section=section, type=type) -def add_chart(section=None, figure_name=None, span=6): +def add_chart(section=None, figure_name=None, span=0.5): + """ + Used as a decorator for chart-rendering functions. + For sections with type = 'normal', the function takes input (selected_clients, selected_colors) + For sections with type = 'slider', the function takes input (slider_value, selected_clients, selected_colors) + + Args: + section (str): Section the chart will be added to + figure_name (str) : Chart ID + span (float): Chart span, E.g. 0.6 for 60% row width + + Examples: + + @add_chart('diy', 'slider', 1.0) + def ct(slider_value, selected_clients, selected_colors): + ...render the figure + return figure + + @add_chart('diy2', 'slider', 1.0) + def ct(slider_value, selected_clients, selected_colors): + ...render the figure + return figure + + """ return _add_chart(section=section, figure_name=figure_name, span=span) + + +def read_logged_obj(round: int, type: str): + return _read_log_from_fs(_dir, type='params', sub_type=type, name=f'rd{round}') + + +def read_logged_obj_appended(type: str, name: str, sub_type: str = None): + return _read_log_from_fs_appended(_dir, type=type, name=name, sub_type=sub_type) + + +def read_cached_obj(type: str, sub_type: str, key: str, creator: callable): + obj = _read_cached_from_fs(_dir, type=type, sub_type=sub_type, name=key) + if not obj: + obj = creator() + _cache_to_fs(obj, _dir, type, sub_type, key) + return obj diff --git a/fedlab/board/front/app.py b/fedlab/board/front/app.py index 1b22c92c..2f076ba6 100644 --- a/fedlab/board/front/app.py +++ b/fedlab/board/front/app.py @@ -18,16 +18,16 @@ _charts: dict[str:dict[str, dict]] = {} -def create_app(log_dir, delegate=None): - viewModel.init(log_dir, 
delegate) +def create_app(log_dir): + viewModel.init(log_dir) app = Dash(__name__, title="FedBoard", update_title=None, assets_url_path='assets') return app -def _add_chart(section=None, figure_name=None, span=6): +def _add_chart(section=None, figure_name=None, span=0.5): def ac(func): _charts.setdefault(section, {}) - _charts[section][figure_name] = {'func': func, 'name': figure_name, 'span': span} + _charts[section][figure_name] = {'func': func, 'name': figure_name, 'span': int(12 * span)} return func return ac @@ -99,10 +99,10 @@ def wrapper(value, selected_client, regex, fig_id): return None -def set_up_layout(app: Dash): - tabs = [dmc.Tab('performance', value='performance')] +def _set_up_layout(app: Dash): + tabs = [dmc.Tab('performance', value='performance', style={"font-size": 17})] for sec in _charts.keys(): - tabs.append(dmc.Tab(sec, value=sec)) + tabs.append(dmc.Tab(sec, value=sec, style={"font-size": 17})) tablist = dmc.TabsList(tabs) tabs_pages = [tablist, dmc.TabsPanel(page_performance, value="performance")] for section, type in _section_types.items(): @@ -127,22 +127,21 @@ def set_up_layout(app: Dash): children=[ dmc.Col(selection, span='content'), dmc.Divider(orientation='vertical', mt='md', mb='md', mr='lg'), - dmc.Col(tabs, span='auto')] + dmc.Col(tabs, span='auto', ml='xs')] ) ]) main = dmc.Grid([dmc.Col(card_state, span='content'), dmc.Col( - cyto_graph, span=5), dmc.Col(card_overall_performance, span='auto') + cyto_graph, span='auto'), dmc.Col(card_overall_performance, span=5) , dmc.Col(bottom_page, span=12)]) app.layout = dmc.Container( - [dmc.Header( - height=100, children=[dmc.Grid( - dmc.Col(dmc.Image(src='assets/FedLab-logo.svg', height=56, fit="contain", mt='lg'), - span="content"))], style={"backgroundColor": "#ffffff"}, mb='lg' - ), main], fluid=True) + [dmc.Header(height=110, children=[dmc.Grid( + dmc.Col(dmc.Image(src='assets/FedLab-logo.svg', height=64, fit="contain", mt='lg'), + span=3))], style={"backgroundColor": "#ffffff"}, 
mb='lg' + ), main], fluid=True) -def add_callbacks(app: Dash): +def _add_callbacks(app: Dash): @app.callback( Output("cytoscape", "elements"), Output("cytoscape", "layout"), diff --git a/fedlab/board/front/layout.py b/fedlab/board/front/layout.py index 999b5b4d..d551f534 100644 --- a/fedlab/board/front/layout.py +++ b/fedlab/board/front/layout.py @@ -5,6 +5,7 @@ from dash_iconify import DashIconify OVERVIEW_HEIGHT = 300 +OVERVIEW_WIDTH = 290 card_state = dmc.Card( children=[ dcc.Interval( @@ -49,7 +50,7 @@ withBorder=True, shadow="lg", radius="lg", - style={"width": 270, 'height': OVERVIEW_HEIGHT}, + style={"width": OVERVIEW_WIDTH, 'height': OVERVIEW_HEIGHT}, ) card_overall_performance = dmc.Card( @@ -59,9 +60,9 @@ style={'height': OVERVIEW_HEIGHT}, children=[ dmc.Group(children=[ - dmc.Text("Overall Performance", id='name_overall', size=16, mb='md'), + dmc.Text("Overall Performance", id='name_overall', size=16), dmc.Select(id='select_overall_metrics', size='xs', clearable=False, value="main")] - , position="apart", style={"height": 40}), + , position="apart"), dmc.Space(h='md'), dcc.Graph(id='figure_overall', style={"height": "80%"}, config={'autosizable': False, 'displaylogo': False}) @@ -70,14 +71,11 @@ cyto_graph = dmc.Card( children=[ - # html.H5("显示"), dmc.ChipGroup( [dmc.Chip(x['label'], value=x['value'], size='xs') for x in [ {"label": "COSE", "value": "cose"}, {"label": "Cent", "value": "concentric"}, - # {"label": "环状", "value": "circle"}, {"label": "Breadth", "value": "breadthfirst"}, - # {"label": "随机", "value": "random"}, {"label": "Grid", "value": "grid"}, ]], id="select_cyto_layout", @@ -98,20 +96,19 @@ style={'height': OVERVIEW_HEIGHT}, ) - page_performance = html.Div( children=[ dmc.Group(children=[ dmc.Text("Client Performance", size=17, ml='md'), - dmc.Select(id='select_client_metrics', size='sm', value="main", mb=0, mt='md', ml='md')] - , style={"height": 60}, align='center'), + dmc.Select(id='select_client_metrics', size='sm', value="main", 
mb=0, ml='md')] + , style={"height": 60}, align='center', mt='md'), dmc.Space(h='md'), dcc.Graph(id='figure_client_perform', config={'displaylogo': False}) ] ) selection = html.Div( - style={'width': 260}, + style={'width': OVERVIEW_WIDTH - 16}, children=[dmc.Text('Select Clients', size=18), dmc.Space(h='md'), dmc.Grid( @@ -125,14 +122,16 @@ dmc.ActionIcon( DashIconify(icon='iconoir:list-select', width=24), id="client_selection_check", - color="blue", variant="light", size=32, mr='md', mb='md'), + color="blue", variant="light", size=32, mr='sm'), span='content')], mb='sm', align='center'), dmc.ChipGroup( [], id="client_selection_ms", value=[], + mr='xs', + mb='lg', multiple=True, - mah=400, + mah=100, ) ]) @@ -150,7 +149,7 @@ def _gen_charts_grid(section, charts_config): def build_normal_charts(section, charts): grids = _gen_charts_grid(section, charts) - return dmc.Grid(grids) + return dmc.Grid(grids,mt='md') def build_slider_charts(section, charts): @@ -169,9 +168,9 @@ def build_slider_charts(section, charts): ml='sm', min=1, max=1, - style={'width': '50%', 'height': 60} + style={'width': '70%', 'height': 60} )], - mt=5 + mt='md' , style={"height": 60, "width": "100%"}, align='center'), dmc.Grid(grids)] ) diff --git a/fedlab/board/front/view_model.py b/fedlab/board/front/view_model.py index 9663f719..fcfa995d 100644 --- a/fedlab/board/front/view_model.py +++ b/fedlab/board/front/view_model.py @@ -1,18 +1,13 @@ import json -import os.path -import pickle from os import path from typing import Any import diskcache -import torch from dash import DiskcacheManager -from sklearn.manifold import TSNE -from fedlab.board.delegate import FedBoardDelegate from fedlab.board.utils.color import random_color from fedlab.board.utils.data import encode_int_array -from fedlab.board.utils.io import _read_meta_file +from fedlab.board.utils.io import _read_meta_file, _read_log_from_fs_appended class ViewModel: @@ -21,8 +16,7 @@ def __init__(self): self.setup = False 
self.background_callback_manager = None - def init(self, dir: str, delegate: FedBoardDelegate = None): - self.delegate = delegate + def init(self, dir: str): self.dir = dir self.colors = {id: random_color(int(id)) for id in self.get_client_ids()} self.setup = True @@ -72,12 +66,12 @@ def get_client_ids(self): return res def client_id2index(self, client_id: str) -> int: - res = _read_meta_file(self.dir, "meta", ["client_ids"])['client_ids'] + res: str = _read_meta_file(self.dir, "meta", ["client_ids"])['client_ids'] res: list[str] = json.loads(res) return res.index(client_id) def client_ids2indexes(self, client_ids: list[str]) -> list[int]: - res = _read_meta_file(self.dir, "meta", ["client_ids"])['client_ids'] + res: str = _read_meta_file(self.dir, "meta", ["client_ids"])['client_ids'] res: list[str] = json.loads(res) return [res.index(id) for id in client_ids] @@ -99,60 +93,10 @@ def encode_client_ids(self, client_ids: list[str]): client_indexes = self.client_ids2indexes(client_ids) return encode_int_array(client_indexes) - def client_param_tsne(self, round: int, client_ids: list[str]): - if not os.path.exists(os.path.join(self.dir, f'log/params/raw/rd{round}.pkl')): - return None - if len(client_ids) < 2: - return None - raw_params = {str(id): param for id, param in - pickle.load(open(os.path.join(self.dir, f'log/params/raw/rd{round}.pkl'), 'rb')).items()} - params_selected = [raw_params[id][0] for id in client_ids if id in raw_params.keys()] - if len(params_selected) < 1: - return None - params_selected = torch.stack(params_selected) - params_tsne = TSNE(n_components=2, learning_rate=100, random_state=501, - perplexity=min(30.0, len(params_selected) - 1)).fit_transform( - params_selected) - return params_tsne - - def get_client_dataset_tsne(self, client_ids: list, type: str, size): - if len(client_ids) < 2: - return None - if not self.delegate: - return None - raw = [] - client_range = {} - for client_id in client_ids: - data, label = 
self.delegate.sample_client_data(client_id, type, size) - client_range[client_id] = (len(raw), len(raw) + len(data)) - raw += data - raw = torch.stack(raw).view(len(raw), -1) - tsne = TSNE(n_components=3, learning_rate=100, random_state=501, - perplexity=min(30.0, len(raw) - 1)).fit_transform(raw) - tsne = {cid: tsne[s:e] for cid, (s, e) in client_range.items()} - return tsne - - def get_client_data_report(self, clients_ids: list, type: str): - res = {} - for client_id in clients_ids: - target_file = os.path.join(self.dir, f'cache/data/partition/{client_id}.pkl') - if os.path.exists(target_file): - res[client_id] = pickle.load(open(target_file, 'rb')) - else: - os.makedirs(os.path.join(self.dir, f'cache/data/partition/'), exist_ok=True) - if self.delegate: - res[client_id] = self.delegate.read_client_label(client_id, type=type) - else: - res[client_id] = {} - pickle.dump(res[client_id], open(target_file, 'wb+')) - return res - def get_overall_metrics(self): main_name = "" metrics = [] - if not os.path.exists(path.join(self.dir, f'log/performs/overall')): - return metrics, main_name - log_lines = open(path.join(self.dir, f'log/performs/overall')).readlines() + log_lines = _read_log_from_fs_appended(self.dir, type='performs', name='overall') if len(log_lines) > 1: obj: dict[str, Any] = json.loads(log_lines[-1]) main_name = obj['main_name'] @@ -162,9 +106,7 @@ def get_overall_metrics(self): def get_client_metrics(self): main_name = "" metrics = [] - if not os.path.exists(path.join(self.dir, f'log/performs/client')): - return metrics, main_name - log_lines = open(path.join(self.dir, f'log/performs/client')).readlines() + log_lines = _read_log_from_fs_appended(self.dir, type='performs', name='client') if len(log_lines) > 1: obj: dict[str, dict[str:Any]] = json.loads(log_lines[-1]) if len(obj.keys()) > 0: @@ -176,9 +118,8 @@ def get_client_metrics(self): def get_overall_performance(self): res_all = [] main_name = "" - if not os.path.exists(path.join(self.dir, 
f'log/performs/overall')): - return res_all, main_name - for line in open(path.join(self.dir, f'log/performs/overall')).readlines(): + log_lines = _read_log_from_fs_appended(self.dir, type='performs', name='overall') + for line in log_lines: obj = json.loads(line) main_name = obj['main_name'] res_all.append(obj) @@ -187,9 +128,8 @@ def get_overall_performance(self): def get_client_performance(self, client_ids: list[str]): res = {} main_name = "" - if not os.path.exists(path.join(self.dir, f'log/performs/client')): - return res, main_name - for line in open(path.join(self.dir, f'log/performs/client')).readlines(): + log_lines = _read_log_from_fs_appended(self.dir, type='performs', name='client') + for line in log_lines: obj = json.loads(line) for client_id in client_ids: main_name = obj[client_id]['main_name'] diff --git a/fedlab/board/utils/io.py b/fedlab/board/utils/io.py index 3337c179..3f43479b 100644 --- a/fedlab/board/utils/io.py +++ b/fedlab/board/utils/io.py @@ -1,7 +1,10 @@ import configparser +import json import os +import pickle import shutil from os import path +from typing import Any def _clear_log(dir): @@ -24,6 +27,7 @@ def _update_meta_file(file_root: str, section: str, dct: dict): with open(config_file, 'w') as configfile: config.write(configfile) + def _read_meta_file(file_root: str, section: str, keys): config_file = path.join(file_root, 'experiment.ini') if not os.path.isfile(config_file): @@ -34,3 +38,55 @@ def _read_meta_file(file_root: str, section: str, keys): return None res = {key: config.get(section, key) for key in keys} return res + + +def _log_to_fs(file_root: str, type: str, name: str, obj: Any, sub_type: str = None): + pt = path.join(file_root, f'log/{type}/') + if sub_type: + pt = path.join(pt, f'{sub_type}/') + os.makedirs(pt, exist_ok=True) + pickle.dump(obj, open(path.join(pt, f'{name}.pkl'), 'wb+')) + + +def _log_to_fs_append(file_root: str, type: str, name: str, obj: Any, sub_type: str = None): + pt = path.join(file_root, 
f'log/{type}/') + if sub_type: + pt = path.join(pt, f'{sub_type}/') + os.makedirs(pt, exist_ok=True) + with open(path.join(pt, f'{name}.log'), 'a+') as f: + f.write(json.dumps(obj) + '\n') + + +def _read_log_from_fs(file_root: str, type: str, name: str, sub_type: str = None): + target = path.join(file_root, f'log/{type}/') + if sub_type: + target = path.join(target, f'{sub_type}/') + target = path.join(target, f'{name}.pkl') + try: + return pickle.load(open(target, 'rb')) + except: + return None + + +def _read_log_from_fs_appended(file_root: str, type: str, name: str, sub_type: str = None): + target = path.join(file_root, f'log/{type}/') + if sub_type: + target = path.join(target, f'{sub_type}/') + target = path.join(target, f'{name}.log') + if not os.path.exists(target): + return [] + return open(target).readlines() + + +def _read_cached_from_fs(file_root: str, type: str, sub_type: str, name: str): + target = path.join(file_root, f'cache/{type}/{sub_type}/{name}.pkl') + try: + return pickle.load(open(target, 'rb')) + except: + return None + + +def _cache_to_fs(obj, file_root: str, type: str, sub_type: str, name: str): + os.makedirs(path.join(file_root, f'cache/{type}/{sub_type}/'), exist_ok=True) + target = path.join(file_root, f'cache/{type}/{sub_type}/{name}.pkl') + pickle.dump(obj, open(target, 'wb+'))