diff --git a/.gitignore b/.gitignore index 2ea92832..5c5311bb 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,8 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +dataset/ +initial_model/ # PyInstaller # Usually these files are written by a python script from a template diff --git a/core/common/constant.py b/core/common/constant.py index 5444e289..d8723ae4 100644 --- a/core/common/constant.py +++ b/core/common/constant.py @@ -22,25 +22,31 @@ class DatasetFormat(Enum): File format of inputting dataset. Currently, file formats are as follows: txt, csv. """ + CSV = "csv" TXT = "txt" JSON = "json" + JSONL = "jsonl" class ParadigmType(Enum): """ Algorithm paradigm type. """ + SINGLE_TASK_LEARNING = "singletasklearning" INCREMENTAL_LEARNING = "incrementallearning" MULTIEDGE_INFERENCE = "multiedgeinference" LIFELONG_LEARNING = "lifelonglearning" + FEDERATED_LEARNING = "federatedlearning" + FEDERATED_CLASS_INCREMENTAL_LEARNING = "federatedclassincrementallearning" class ModuleType(Enum): """ Algorithm module type. """ + BASEMODEL = "basemodel" # HEM @@ -63,20 +69,26 @@ class ModuleType(Enum): UNSEEN_SAMPLE_RECOGNITION = "unseen_sample_recognition" UNSEEN_SAMPLE_RE_RECOGNITION = "unseen_sample_re_recognition" + # FL_AGG + AGGREGATION = "aggregation" + class SystemMetricType(Enum): """ System metric type of ianvs. """ + SAMPLES_TRANSFER_RATIO = "samples_transfer_ratio" FWT = "FWT" BWT = "BWT" TASK_AVG_ACC = "task_avg_acc" MATRIX = "MATRIX" + FORGET_RATE = "forget_rate" class TestObjectType(Enum): """ Test object type of ianvs. """ + ALGORITHMS = "algorithms" diff --git a/core/storymanager/rank/rank.py b/core/storymanager/rank/rank.py index 5c3d85d4..ac985c88 100644 --- a/core/storymanager/rank/rank.py +++ b/core/storymanager/rank/rank.py @@ -35,15 +35,12 @@ class Rank: def __init__(self, config): self.sort_by: list = [] - self.visualization: dict = { - "mode": "selected_only", - "method": "print_table" - } + self.visualization: dict = {"mode": "selected_only", "method": "print_table"} self.selected_dataitem: dict = { "paradigms": ["all"], "modules": ["all"], "hyperparameters": ["all"], - "metrics": ["all"] + "metrics": ["all"], } self.save_mode: str = "selected_and_all" @@ -62,15 +59,21 @@ def _parse_config(self, config): def _check_fields(self): if not self.sort_by and not isinstance(self.sort_by, list): - raise ValueError(f"rank's sort_by({self.sort_by}) must be provided and be list type.") + raise ValueError( + f"rank's sort_by({self.sort_by}) must be provided and be list type." + ) if not self.visualization and not isinstance(self.visualization, dict): - raise ValueError(f"rank's visualization({self.visualization}) " - f"must be provided and be dict type.") + raise ValueError( + f"rank's visualization({self.visualization}) " + f"must be provided and be dict type." + ) if not self.selected_dataitem and not isinstance(self.selected_dataitem, dict): - raise ValueError(f"rank's selected_dataitem({self.selected_dataitem}) " - f"must be provided and be dict type.") + raise ValueError( + f"rank's selected_dataitem({self.selected_dataitem}) " + f"must be provided and be dict type." 
+ ) if not self.selected_dataitem.get("paradigms"): raise ValueError("not found paradigms of selected_dataitem in rank.") @@ -82,8 +85,10 @@ def _check_fields(self): raise ValueError("not found metrics of selected_dataitem in rank.") if not self.save_mode and not isinstance(self.save_mode, list): - raise ValueError(f"rank's save_mode({self.save_mode}) " - f"must be provided and be list type.") + raise ValueError( + f"rank's save_mode({self.save_mode}) " + f"must be provided and be list type." + ) @classmethod def _get_all_metric_names(cls, test_results) -> list: @@ -133,7 +138,6 @@ def _sort_all_df(self, all_df, all_metric_names): if metric_name not in all_metric_names: continue - sort_metric_list.append(metric_name) is_ascend_list.append(ele.get(metric_name) == "ascend") @@ -198,7 +202,15 @@ def _get_selected(self, test_cases, test_results) -> pd.DataFrame: if metric_names == ["all"]: metric_names = self._get_all_metric_names(test_results) - header = ["algorithm", *metric_names, "paradigm", *module_types, *hps_names, "time", "url"] + header = [ + "algorithm", + *metric_names, + "paradigm", + *module_types, + *hps_names, + "time", + "url", + ] all_df = copy.deepcopy(self.all_df) selected_df = pd.DataFrame(all_df, columns=header) @@ -220,8 +232,7 @@ def _draw_pictures(self, test_cases, test_results): for test_case in test_cases: out_put = test_case.output_dir test_result = test_results[test_case.id][0] - matrix = test_result.get('Matrix') - #print(out_put) + matrix = test_result.get("Matrix") for key in matrix.keys(): draw_heatmap_picture(out_put, key, matrix[key]) @@ -285,4 +296,5 @@ def plot(self): except Exception as err: raise RuntimeError( f"process visualization(method={method}) of " - f"rank file({self.selected_rank_file}) failed, error: {err}.") from err + f"rank file({self.selected_rank_file}) failed, error: {err}." + ) from err diff --git a/core/testcasecontroller/algorithm/algorithm.py b/core/testcasecontroller/algorithm/algorithm.py index cb2d9b7b..d933eac8 100644 --- a/core/testcasecontroller/algorithm/algorithm.py +++ b/core/testcasecontroller/algorithm/algorithm.py @@ -24,6 +24,8 @@ IncrementalLearning, MultiedgeInference, LifelongLearning, + FederatedLearning, + FederatedClassIncrementalLearning ) from core.testcasecontroller.generation_assistant import get_full_combinations @@ -64,12 +66,24 @@ def __init__(self, name, config): "train_ratio": 0.8, "splitting_method": "default" } + self.fl_data_setting: dict = { + "train_ratio": 1.0, + "splitting_method": "default", + "data_partition": "iid", + 'non_iid_ratio': 0.6, + "label_data_ratio": 1.0 + } + self.initial_model_url: str = "" self.modules: list = [] self.modules_list = None + self.mode: str = "" + self.quantization_type: str = "" + self.llama_quantize_path: str = "" self._parse_config(config) self._load_third_party_packages() + # pylint: disable=R0911 def paradigm(self, workspace: str, **kwargs): """ get test process of AI algorithm paradigm. 
@@ -91,7 +105,6 @@ def paradigm(self, workspace: str, **kwargs): # pylint: disable=C0103 for k, v in self.__dict__.items(): config.update({k: v}) - if self.paradigm_type == ParadigmType.SINGLE_TASK_LEARNING.value: return SingleTaskLearning(workspace, **config) @@ -104,6 +117,12 @@ def paradigm(self, workspace: str, **kwargs): if self.paradigm_type == ParadigmType.LIFELONG_LEARNING.value: return LifelongLearning(workspace, **config) + if self.paradigm_type == ParadigmType.FEDERATED_LEARNING.value: + return FederatedLearning(workspace, **config) + + if self.paradigm_type == ParadigmType.FEDERATED_CLASS_INCREMENTAL_LEARNING.value: + return FederatedClassIncrementalLearning(workspace, **config) + return None def _check_fields(self): diff --git a/core/testcasecontroller/algorithm/module/module.py b/core/testcasecontroller/algorithm/module/module.py index 6d487d97..1772725e 100644 --- a/core/testcasecontroller/algorithm/module/module.py +++ b/core/testcasecontroller/algorithm/module/module.py @@ -86,6 +86,7 @@ def get_module_instance(self, module_type): function """ + print(f'hyperparameters_list: {self.hyperparameters_list}') class_factory_type = ClassType.GENERAL if module_type in [ModuleType.HARD_EXAMPLE_MINING.value]: class_factory_type = ClassType.HEM @@ -106,6 +107,20 @@ def get_module_instance(self, module_type): elif module_type in [ModuleType.UNSEEN_SAMPLE_RECOGNITION.value, ModuleType.UNSEEN_SAMPLE_RE_RECOGNITION.value]: class_factory_type = ClassType.UTD + elif module_type in [ModuleType.AGGREGATION.value]: + class_factory_type = ClassType.FL_AGG + agg = None + print(self.url) + if self.url : + try: + utils.load_module(self.url) + agg = ClassFactory.get_cls( + type_name=class_factory_type, t_cls_name=self.name)(**self.hyperparameters) + print(agg) + except Exception as err: + raise RuntimeError(f"module(type={module_type} loads class(name={self.name}) " + f"failed, error: {err}.") from err + return self.name, agg if self.url: try: @@ -113,7 +128,6 @@ def get_module_instance(self, module_type): # pylint: disable=E1134 func = ClassFactory.get_cls( type_name=class_factory_type, t_cls_name=self.name)(**self.hyperparameters) - return func except Exception as err: raise RuntimeError(f"module(type={module_type} loads class(name={self.name}) " diff --git a/core/testcasecontroller/algorithm/paradigm/__init__.py b/core/testcasecontroller/algorithm/paradigm/__init__.py index c966bd38..5c50b243 100644 --- a/core/testcasecontroller/algorithm/paradigm/__init__.py +++ b/core/testcasecontroller/algorithm/paradigm/__init__.py @@ -17,3 +17,4 @@ from .singletask_learning import SingleTaskLearning from .multiedge_inference import MultiedgeInference from .lifelong_learning import LifelongLearning +from .federated_learning import FederatedLearning, FederatedClassIncrementalLearning diff --git a/core/testcasecontroller/algorithm/paradigm/base.py b/core/testcasecontroller/algorithm/paradigm/base.py index cf36cd4e..e5178e29 100644 --- a/core/testcasecontroller/algorithm/paradigm/base.py +++ b/core/testcasecontroller/algorithm/paradigm/base.py @@ -18,8 +18,8 @@ from sedna.core.incremental_learning import IncrementalLearning from sedna.core.lifelong_learning import LifelongLearning - from core.common.constant import ModuleType, ParadigmType +from .sedna_federated_learning import FederatedLearning class ParadigmBase: @@ -97,33 +97,51 @@ def build_paradigm_job(self, paradigm_type): return IncrementalLearning( estimator=self.module_instances.get(ModuleType.BASEMODEL.value), 
hard_example_mining=self.module_instances.get( - ModuleType.HARD_EXAMPLE_MINING.value)) + ModuleType.HARD_EXAMPLE_MINING.value + ), + ) if paradigm_type == ParadigmType.LIFELONG_LEARNING.value: return LifelongLearning( - estimator=self.module_instances.get( - ModuleType.BASEMODEL.value), + estimator=self.module_instances.get(ModuleType.BASEMODEL.value), task_definition=self.module_instances.get( - ModuleType.TASK_DEFINITION.value), + ModuleType.TASK_DEFINITION.value + ), task_relationship_discovery=self.module_instances.get( - ModuleType.TASK_RELATIONSHIP_DISCOVERY.value), + ModuleType.TASK_RELATIONSHIP_DISCOVERY.value + ), task_allocation=self.module_instances.get( - ModuleType.TASK_ALLOCATION.value), + ModuleType.TASK_ALLOCATION.value + ), task_remodeling=self.module_instances.get( - ModuleType.TASK_REMODELING.value), + ModuleType.TASK_REMODELING.value + ), inference_integrate=self.module_instances.get( - ModuleType.INFERENCE_INTEGRATE.value), + ModuleType.INFERENCE_INTEGRATE.value + ), task_update_decision=self.module_instances.get( - ModuleType.TASK_UPDATE_DECISION.value), + ModuleType.TASK_UPDATE_DECISION.value + ), unseen_task_allocation=self.module_instances.get( - ModuleType.UNSEEN_TASK_ALLOCATION.value), + ModuleType.UNSEEN_TASK_ALLOCATION.value + ), unseen_sample_recognition=self.module_instances.get( - ModuleType.UNSEEN_SAMPLE_RECOGNITION.value), + ModuleType.UNSEEN_SAMPLE_RECOGNITION.value + ), unseen_sample_re_recognition=self.module_instances.get( - ModuleType.UNSEEN_SAMPLE_RE_RECOGNITION.value) + ModuleType.UNSEEN_SAMPLE_RE_RECOGNITION.value + ), ) # pylint: disable=E1101 if paradigm_type == ParadigmType.MULTIEDGE_INFERENCE.value: - return self.modules_funcs.get(ModuleType.BASEMODEL.value)() + return self.module_instances.get(ModuleType.BASEMODEL.value) + + if paradigm_type in [ + ParadigmType.FEDERATED_LEARNING.value, + ParadigmType.FEDERATED_CLASS_INCREMENTAL_LEARNING.value, + ]: + return FederatedLearning( + estimator=self.module_instances.get(ModuleType.BASEMODEL.value) + ) return None diff --git a/core/testcasecontroller/algorithm/paradigm/federated_learning/__init__.py b/core/testcasecontroller/algorithm/paradigm/federated_learning/__init__.py new file mode 100644 index 00000000..55ebbea2 --- /dev/null +++ b/core/testcasecontroller/algorithm/paradigm/federated_learning/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2022 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
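The new `AGGREGATION` module type is resolved through `ClassFactory` under `ClassType.FL_AGG` (see the `module.py` hunk above) and is later invoked as `self.aggregator.aggregate(self.aggregate_clients)`. For orientation, a minimal sketch of such a module follows; the `FedAvg` class itself is a hypothetical illustration and not part of this patch, while the `AggClient` fields (`num_samples`, `weights`) match what the paradigm code below actually passes in:

```python
# Hypothetical illustration only -- not part of this patch.
import numpy as np
from sedna.common.class_factory import ClassFactory, ClassType


@ClassFactory.register(ClassType.FL_AGG, alias="FedAvg")
class FedAvg:
    """Sample-weighted averaging of client weights (classic FedAvg)."""

    def __init__(self, **kwargs):
        # hyperparameters declared in the algorithm config arrive here
        self.hyperparameters = kwargs

    def aggregate(self, clients):
        # clients: list of sedna AggClient objects, each carrying
        # .num_samples and .weights (a list of per-layer numpy arrays)
        total = sum(c.num_samples for c in clients)
        return [
            np.sum([c.weights[i] * (c.num_samples / total) for c in clients], axis=0)
            for i in range(len(clients[0].weights))
        ]
```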
+
+# pylint: disable=missing-module-docstring
+from .federated_learning import FederatedLearning
+from .federated_class_incremental_learning import FederatedClassIncrementalLearning
diff --git a/core/testcasecontroller/algorithm/paradigm/federated_learning/federated_class_incremental_learning.py b/core/testcasecontroller/algorithm/paradigm/federated_learning/federated_class_incremental_learning.py
new file mode 100644
index 00000000..3baaf072
--- /dev/null
+++ b/core/testcasecontroller/algorithm/paradigm/federated_learning/federated_class_incremental_learning.py
@@ -0,0 +1,295 @@
+# Copyright 2022 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Federated Class-Incremental Learning Paradigm"""
+# pylint: disable=C0412
+# pylint: disable=W1203
+# pylint: disable=C0103
+# pylint: disable=C0206
+# pylint: disable=C0201
+import numpy as np
+from core.common.log import LOGGER
+from core.common.constant import ParadigmType, SystemMetricType
+from core.testcasecontroller.metrics.metrics import get_metric_func
+from .federated_learning import FederatedLearning
+
+
+class FederatedClassIncrementalLearning(FederatedLearning):
+    # pylint: disable=too-many-instance-attributes
+    """
+    FederatedClassIncrementalLearning
+    Federated Class-Incremental Learning Paradigm
+    Notes:
+          1. Ianvs serves as testing tools for test objects, e.g., algorithms.
+          2. Ianvs does NOT include code directly on test object.
+          3. Algorithms serve as typical test objects in Ianvs
+          and detailed algorithms are thus NOT included in this Ianvs python file.
+          4. As for the details of example test objects, e.g., algorithms,
+          please refer to third party packages in Ianvs example.
+          For example, AI workflow and interface pls refer to sedna
+          (sedna docs: https://sedna.readthedocs.io/en/latest/api/lib/index.html),
+          and module implementation pls refer to `examples' test algorithms`,
+          e.g., basemodel.py, hard_example_mining.py.
+
+    Parameters
+    ---------
+    workspace: string
+        the output workspace required for the Federated Class-Incremental Learning paradigm.
+    kwargs: dict
+        config required for the test process of the federated class-incremental
+        learning paradigm, e.g.: algorithm modules, dataset, initial network,
+        incremental rounds, network eval config, etc.
+ """ + + def __init__(self, workspace, **kwargs): + super().__init__(workspace, **kwargs) + self.incremental_rounds = kwargs.get("incremental_rounds", 1) + self.system_metric_info = { + SystemMetricType.FORGET_RATE.value: [], + SystemMetricType.TASK_AVG_ACC.value: {}, + } + + self.aggregate_clients = [] + self.train_infos = [] + + self.forget_rate_metrics = [] + self.accuracy_per_round = [] + metrics_dict = kwargs.get("model_eval", {})["model_metric"] + _, accuracy_func = get_metric_func(metrics_dict) + self.accuracy_func = accuracy_func + + def task_definition(self, dataset_files, task_id): + """Define the task for the class incremental learning paradigm + + Args: + dataset_files (list): dataset_files for train data + task_id (int): task id for the current task + + Returns: + list: train dataset in numpy format for each task + """ + LOGGER.info(f"len(dataset_files): {len(dataset_files)}") + # 1. Partition Dataset + train_dataset_files, _ = dataset_files[task_id] + LOGGER.info(f"train_dataset_files: {train_dataset_files}") + train_datasets = self.train_data_partition(train_dataset_files) + LOGGER.info(f"train_datasets: {len(train_datasets)}") + task_size = self.get_task_size(train_datasets) + LOGGER.info(f"task_size: {task_size}") + # 2. According to setting, to split the label and unlabel data for each task + train_datasets = self.split_label_unlabel_data(train_datasets) + # 3. Return the dataset for each task [{label_data, unlabel_data}, ...] + return train_datasets, task_size + + def get_task_size(self, train_datasets): + """get the task size for each task + + Args: + train_datasets (list): train dataset for each client + + Returns: + int: task size for each task + """ + LOGGER.info(f"train_datasets: {len(train_datasets[0])}") + return np.unique(train_datasets[0][1]).shape[0] + + def split_label_unlabel_data(self, train_datasets): + """split train dataset into label and unlabel data for semi-supervised learning + + Args: + train_datasets (list): train dataset for each client + + Returns: + list: the new train dataset for each client that in label and unlabel format + [{label_x: [], label_y: [], unlabel_x: [], unlabel_y: []}, ...] + """ + label_ratio = self.fl_data_setting.get("label_data_ratio") + new_train_datasets = [] + train_dataset_len = len(train_datasets) + for i in range(train_dataset_len): + train_dataset_dict = {} + label_data_number = int(label_ratio * len(train_datasets[i][0])) + # split dataset into label and unlabel data + train_dataset_dict["label_x"] = train_datasets[i][0][:label_data_number] + train_dataset_dict["label_y"] = train_datasets[i][1][:label_data_number] + train_dataset_dict["unlabel_x"] = train_datasets[i][0][label_data_number:] + train_dataset_dict["unlabel_y"] = train_datasets[i][1][label_data_number:] + new_train_datasets.append(train_dataset_dict) + return new_train_datasets + + def init_client(self): + self.clients = [ + self.build_paradigm_job( + ParadigmType.FEDERATED_CLASS_INCREMENTAL_LEARNING.value + ) + for _ in range(self.clients_number) + ] + + def run(self): + """run the Federated Class-Incremental Learning paradigm + This function will run the Federated Class-Incremental Learning paradigm. + 1. initialize the clients + 2. split the dataset into several tasks + 3. train the model on the clients + 4. aggregate the model weights and maybe need to perform some helper function + 5. send the weights to the clients + 6. evaluate the model performance on old classes + 7. 
finally, return the prediction result and system metric information + Returns: + list: prediction result + dict: system metric information + """ + + self.init_client() + dataset_files = self._split_dataset(self.incremental_rounds) + test_dataset_files = self._split_test_dataset(self.incremental_rounds) + LOGGER.info(f"get the dataset_files: {dataset_files}") + forget_rate = self.system_metric_info.get(SystemMetricType.FORGET_RATE.value) + for task_id in range(self.incremental_rounds): + train_datasets, task_size = self.task_definition(dataset_files, task_id) + testdatasets = test_dataset_files[: task_id + 1] + for r in range(self.rounds): + LOGGER.info( + f"Round {r} task id: {task_id} len of train_datasets: {len(train_datasets)}" + ) + self.train( + train_datasets, task_id=task_id, round=r, task_size=task_size + ) + global_weights = self.aggregator.aggregate(self.aggregate_clients) + if hasattr(self.aggregator, "helper_function"): + self.helper_function(self.train_infos) + self.send_weights_to_clients(global_weights) + self.aggregate_clients.clear() + self.train_infos.clear() + forget_rate.append(self.evaluation(testdatasets, task_id)) + test_res = self.predict(self.dataset.test_url) + return test_res, self.system_metric_info + + def _split_test_dataset(self, split_time): + """split test dataset + This function will split a test dataset from test_url into several parts. + Each part will be used for the evaluation of the model after each round. + Args: + split_time (int): the number of split time + + Returns: + list : the test dataset for each round [{x: [], y: []}, ...] + """ + test_dataset = self.dataset.load_data(self.dataset.test_url, "eval") + all_data = len(test_dataset.x) + step = all_data // split_time + test_datasets_files = [] + index = 1 + while index <= split_time: + new_dataset = {} + if index == split_time: + new_dataset["x"] = test_dataset.x[step * (index - 1) :] + new_dataset["y"] = test_dataset.y[step * (index - 1) :] + else: + new_dataset["x"] = test_dataset.x[step * (index - 1) : step * index] + new_dataset["y"] = test_dataset.y[step * (index - 1) : step * index] + test_datasets_files.append(new_dataset) + index += 1 + return test_datasets_files + + def client_train(self, client_idx, train_datasets, validation_datasets, **kwargs): + """client train function that will be called by the thread + + Args: + client_idx (int): client index + train_datasets (list): train dataset for each client + validation_datasets (list): validation dataset for each client + """ + train_info = super().client_train( + client_idx, train_datasets, validation_datasets, **kwargs + ) + with self.lock: + self.train_infos.append(train_info) + + def helper_function(self, train_infos): + """helper function for FCI Method + Many of the FCI algorithms need server to perform some operations + after the training of each round e.g data generation, model update etc. 
+        Args:
+            train_infos (list of dict): the train info that the clients want to send to the server
+        """
+
+        for i in range(self.clients_number):
+            helper_info = self.aggregator.helper_function(train_infos[i])
+            self.clients[i].helper_function(helper_info)
+        LOGGER.info("finish helper function")
+
+    # pylint: disable=too-many-locals
+    # pylint: disable=C0200
+    def evaluation(self, testdataset_files, incremental_round):
+        """evaluate the model performance on old classes
+
+        Args:
+            testdataset_files (list): the test dataset for each round
+            incremental_round (int): the current incremental round index
+
+        Returns:
+            float: forget rate for the current round
+            reference: https://ieeexplore.ieee.org/document/10574196/
+        """
+        if self.accuracy_func is None:
+            raise ValueError("accuracy function is not defined")
+        LOGGER.info("********start evaluation********")
+        if isinstance(testdataset_files, str):
+            testdataset_files = [testdataset_files]
+        job = self.get_global_model()
+        # calculate the seen-class accuracy
+        old_class_acc_list = (
+            []
+        )  # for current round [class_0: acc_0, class_1: acc_1, ....]
+        for index in range(len(testdataset_files)):
+            acc_list = []
+            for data_index in range(len(testdataset_files[index]["x"])):
+                data = testdataset_files[index]["x"][data_index]
+                res = job.inference([data])
+                LOGGER.info(
+                    f"label is {testdataset_files[index]['y'][data_index]}, res is {res}"
+                )
+                acc = self.accuracy_func(
+                    [testdataset_files[index]["y"][data_index]], res
+                )
+                acc_list.append(acc)
+            old_class_acc_list.extend(acc_list)
+        current_forget_rate = 0.0
+        max_acc_sum = 0
+        self.accuracy_per_round.append(old_class_acc_list)
+        self.system_metric_info[SystemMetricType.TASK_AVG_ACC.value]["accuracy"] = (
+            np.mean(old_class_acc_list)
+        )
+        # calculate the forget rate
+        for i in range(len(old_class_acc_list)):
+            max_acc_diff = 0
+            for j in range(incremental_round):
+                acc_per_round = self.accuracy_per_round[j]
+                if i < len(acc_per_round):
+                    max_acc_diff = max(
+                        max_acc_diff, acc_per_round[i] - old_class_acc_list[i]
+                    )
+            max_acc_sum += max_acc_diff
+        current_forget_rate = (
+            max_acc_sum / len(old_class_acc_list) if incremental_round > 0 else 0.0
+        )
+        task_avg_acc = self.system_metric_info[SystemMetricType.TASK_AVG_ACC.value][
+            "accuracy"
+        ]
+        LOGGER.info(
+            f"for current round: {incremental_round} forget rate: {current_forget_rate} "
+            f"task avg acc: {task_avg_acc}"
+        )
+        return current_forget_rate
diff --git a/core/testcasecontroller/algorithm/paradigm/federated_learning/federated_learning.py b/core/testcasecontroller/algorithm/paradigm/federated_learning/federated_learning.py
new file mode 100644
index 00000000..2a714360
--- /dev/null
+++ b/core/testcasecontroller/algorithm/paradigm/federated_learning/federated_learning.py
@@ -0,0 +1,241 @@
+# Copyright 2022 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
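In equation form, the forget rate accumulated by `evaluation` above is the mean, over all seen-class test samples, of each sample's largest accuracy drop relative to earlier rounds. With a_{j,i} the accuracy on seen-class sample i after incremental round j, t the current round, and N the number of seen-class samples (the `max_acc_diff = 0` initialisation clamps per-sample differences at zero):

```latex
\mathrm{forget\_rate}_t
  = \frac{1}{N} \sum_{i=1}^{N} \max\!\Big(0,\; \max_{0 \le j < t} \big(a_{j,i} - a_{t,i}\big)\Big)
```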
+ +"""Federated Learning Paradigm""" +# pylint: disable=C0412 +# pylint: disable=W1203 +# pylint: disable=C0103 +# pylint: disable=C0206 +# pylint: disable=C0201 +# pylint: disable=W1203 +from threading import Thread, RLock + +from sedna.algorithms.aggregation import AggClient +from core.common.log import LOGGER +from core.common.constant import ParadigmType, ModuleType +from core.common.utils import get_file_format +from core.testcasecontroller.algorithm.paradigm.base import ParadigmBase +from core.testenvmanager.dataset.utils import read_data_from_file_to_npy, partition_data + + +class FederatedLearning(ParadigmBase): + # pylint: disable=too-many-instance-attributes + """ + FederatedLearning + Federated Learning Paradigm + Notes: + 1. Ianvs serves as testing tools for test objects, e.g., algorithms. + 2. Ianvs does NOT include code directly on test object. + 3. Algorithms serve as typical test objects in Ianvs + and detailed algorithms are thus NOT included in this Ianvs python file. + 4. As for the details of example test objects, e.g., algorithms, + please refer to third party packages in Ianvs example. + For example, AI workflow and interface pls refer to sedna + (sedna docs: https://sedna.readthedocs.io/en/latest/api/lib/index.html), + and module implementation pls refer to `examples' test algorithms`, + e.g., basemodel.py, hard_example_mining.py. + + Parameters + --------- + workspace: string + the output required for Federated Class-Incremental Learning paradigm. + kwargs: dict + config required for the test process of lifelong learning paradigm, + e.g.: algorithm modules, dataset, initial network, incremental rounds, + network eval config, etc. + """ + + def __init__(self, workspace, **kwargs): + ParadigmBase.__init__(self, workspace, **kwargs) + + self.workspace = workspace + self.kwargs = kwargs + + self.fl_data_setting = kwargs.get("fl_data_setting") + self.rounds = kwargs.get("round", 1) + self.clients = [] + self.lock = RLock() + + self.aggregate_clients = [] + self.clients_number = kwargs.get("client_number", 1) + _, self.aggregator = self.module_instances.get(ModuleType.AGGREGATION.value) + + def init_client(self): + """init clients for the paradigm of federated learning.""" + self.clients = [ + self.build_paradigm_job(ParadigmType.FEDERATED_LEARNING.value) + for i in range(self.clients_number) + ] + + def run(self): + """ + run the test flow of incremental learning paradigm. + + Returns + ------ + test result: numpy.ndarray + system metric info: dict + information needed to compute system metrics. + """ + # init client wait for connection + self.init_client() + dataset_files = self.get_all_train_data() + train_dataset_file, _ = dataset_files[0] + train_datasets = self.train_data_partition(train_dataset_file) + for r in range(self.rounds): + self.train(train_datasets, round=r) + global_weights = self.aggregator.aggregate(self.aggregate_clients) + self.send_weights_to_clients(global_weights) + self.aggregate_clients.clear() + test_res = self.predict(self.dataset.test_url) + return test_res, self.system_metric_info + + def get_all_train_data(self): + """Get all train data for the paradigm of federated learning. + + Returns: + list: train data list + """ + split_time = 1 # only one split ——all the data + return self._split_dataset(split_time) + + def _split_dataset(self, splitting_dataset_times=1): + """spit the dataset using ianvs dataset.split dataset method + + Args: + splitting_dataset_times (int, optional): . Defaults to 1. 
+ + Returns: + list: dataset files + """ + train_dataset_ratio = self.fl_data_setting.get("train_ratio") + splitting_dataset_method = self.fl_data_setting.get("splitting_method") + return self.dataset.split_dataset( + self.dataset.train_url, + get_file_format(self.dataset.train_url), + train_dataset_ratio, + method=splitting_dataset_method, + dataset_types=("model_train", "model_eval"), + output_dir=self.dataset_output_dir(), + times=splitting_dataset_times, + ) + + def train_data_partition(self, train_dataset_file): + """ + Partition the dataset for the class incremental learning paradigm + - i.i.d + - non-i.i.d + """ + LOGGER.info(train_dataset_file) + train_datasets = None + if isinstance(train_dataset_file, str): + train_datasets = self.dataset.load_data(train_dataset_file, "train") + if isinstance(train_dataset_file, list): + train_datasets = [] + for file in train_dataset_file: + train_datasets.append(self.dataset.load_data(file, "train")) + assert train_datasets is not None, "train_dataset is None" + # translate file to real data that can be used in train + # - provide a default method to read data from file to npy + # - can support customized method to read data from file to npy + train_datasets = read_data_from_file_to_npy(train_datasets) + # Partition data to iid or non-iid + train_datasets = partition_data( + train_datasets, + self.clients_number, + self.fl_data_setting.get("data_partition"), + self.fl_data_setting.get("non_iid_ratio"), + ) + return train_datasets + + def client_train(self, client_idx, train_datasets, validation_datasets, **kwargs): + """client train + + Args: + client_idx (int): client index + train_datasets (list): train data for each client + validation_datasets (list): validation data for each client + """ + train_info = self.clients[client_idx].train( + train_datasets[client_idx], validation_datasets, **kwargs + ) + train_info["client_id"] = client_idx + agg_client = AggClient() + agg_client.num_samples = train_info["num_samples"] + agg_client.weights = self.clients[client_idx].get_weights() + with self.lock: + self.aggregate_clients.append(agg_client) + return train_info + + def train(self, train_datasets, **kwargs): + """train——multi-threading to perform client local training + + Args: + train_datasets (list): train data for each client + """ + client_threads = [] + LOGGER.info(f"len(self.clients): {len(self.clients)}") + for idx in range(self.clients_number): + client_thread = Thread( + target=self.client_train, + args=(idx, train_datasets, None), + kwargs=kwargs, + ) + client_thread.start() + client_threads.append(client_thread) + for thread in client_threads: + thread.join() + LOGGER.info("finish training") + + def send_weights_to_clients(self, global_weights): + """send weights to clients + + Args: + global_weights (list): aggregated weights + """ + for client in self.clients: + client.set_weights(global_weights) + LOGGER.info("finish send weights to clients") + + def get_global_model(self): + """get the global model for evaluation + After final round training, and aggregation + the global model can be the first client model + + Returns: + JobBase: sedna_federated_learning.FederatedLearning + """ + return self.clients[0] + + def predict(self, test_dataset_file): + """global test to predict the test dataset + + Args: + test_dataset_file (list): test data + + Returns: + list: test result + """ + test_dataset = None + if isinstance(test_dataset_file, str): + test_dataset = self.dataset.load_data(test_dataset_file, "eval") + if 
isinstance(test_dataset_file, list): + test_dataset = [] + for file in test_dataset_file: + test_dataset.append(self.dataset.load_data(file, "eval")) + assert test_dataset is not None, "test_dataset is None" + LOGGER.info(f" before predict {len(test_dataset.x)}") + job = self.get_global_model() + test_res = job.inference(test_dataset.x) + LOGGER.info(f" after predict {len(test_res)}") + return test_res diff --git a/core/testcasecontroller/algorithm/paradigm/multiedge_inference/multiedge_inference.py b/core/testcasecontroller/algorithm/paradigm/multiedge_inference/multiedge_inference.py index cf8ef521..4085eafd 100644 --- a/core/testcasecontroller/algorithm/paradigm/multiedge_inference/multiedge_inference.py +++ b/core/testcasecontroller/algorithm/paradigm/multiedge_inference/multiedge_inference.py @@ -16,6 +16,10 @@ import os +# pylint: disable=E0401 +import onnx + +from core.common.log import LOGGER from core.common.constant import ParadigmType from core.testcasecontroller.algorithm.paradigm.base import ParadigmBase @@ -63,8 +67,15 @@ def run(self): """ job = self.build_paradigm_job(ParadigmType.MULTIEDGE_INFERENCE.value) - - inference_result = self._inference(job, self.initial_model) + if not job.__dict__.get('model_parallel'): + inference_result = self._inference(job, self.initial_model) + else: + if 'partition' in dir(job): + models_dir, map_info = job.partition(self.initial_model) + else: + models_dir, map_info = self._partition(job.__dict__.get('partition_point_list'), + self.initial_model, os.path.dirname(self.initial_model)) + inference_result = self._inference_mp(job, models_dir, map_info) return inference_result, self.system_metric_info @@ -77,3 +88,26 @@ def _inference(self, job, trained_model): job.load(trained_model) infer_res = job.predict(inference_dataset.x, train_dataset=train_dataset) return infer_res + + def _inference_mp(self, job, models_dir, map_info): + inference_dataset = self.dataset.load_data(self.dataset.test_url, "inference") + inference_output_dir = os.path.join(self.workspace, "output/inference/") + os.environ["RESULT_SAVED_URL"] = inference_output_dir + job.load(models_dir, map_info) + infer_res = job.predict(inference_dataset.x) + return infer_res + + # pylint: disable=W0718, C0103 + def _partition(self, partition_point_list, initial_model_path, sub_model_dir): + map_info = dict({}) + for idx, point in enumerate(partition_point_list): + input_names = point['input_names'] + output_names = point['output_names'] + sub_model_path = sub_model_dir + '/' + 'sub_model_' + str(idx+1) + '.onnx' + try: + onnx.utils.extract_model(initial_model_path, + sub_model_path, input_names, output_names) + except Exception as e: + LOGGER.info(str(e)) + map_info[sub_model_path.split('/')[-1]] = point['device_name'] + return sub_model_dir, map_info diff --git a/core/testcasecontroller/algorithm/paradigm/sedna_federated_learning.py b/core/testcasecontroller/algorithm/paradigm/sedna_federated_learning.py new file mode 100644 index 00000000..3856c7ac --- /dev/null +++ b/core/testcasecontroller/algorithm/paradigm/sedna_federated_learning.py @@ -0,0 +1,69 @@ +# Copyright 2022 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sedna Federated Learning JobBase""" +# pylint: disable=C0412 +import copy +from sedna.core.base import JobBase + + +# pylint: disable=C0412 +class FederatedLearning(JobBase): + """Federated Learning JobBase Represent the Federated Learning Client side""" + + # pylint: disable=too-many-locals + def __init__(self, estimator): + super().__init__(estimator) + cp_estimator = copy.deepcopy(estimator) + self.estimator = cp_estimator + + # pylint: disable=W0221 + def train(self, train_data, vald_data, **kwargs): + """Local training function + + Args: + train_data (list): training data + vald_data (list): validation data (optional) + + Returns: + dict: train info that will be used for aggregation + """ + return self.estimator.train(train_data, vald_data, **kwargs) + + def get_weights(self): + """get the weights of the model + + Returns: + list: weights of the model + """ + return self.estimator.get_weights() + + def set_weights(self, weights): + """set the weights of the model + + Args: + weights (list): set the weights of the model + """ + self.estimator.set_weights(weights) + + def helper_function(self, helper_info): + """helper function that can be used for any purpose + + Args: + helper_info (dict): helper info that generated by the server's helper function + + Returns: + None: None + """ + return self.estimator.helper_function(helper_info) diff --git a/core/testcasecontroller/algorithm/paradigm/singletask_learning/singletask_learning.py b/core/testcasecontroller/algorithm/paradigm/singletask_learning/singletask_learning.py index 19972538..90a22129 100644 --- a/core/testcasecontroller/algorithm/paradigm/singletask_learning/singletask_learning.py +++ b/core/testcasecontroller/algorithm/paradigm/singletask_learning/singletask_learning.py @@ -15,7 +15,7 @@ """Single Task Learning Paradigm""" import os - +import subprocess from core.common.constant import ParadigmType from core.testcasecontroller.algorithm.paradigm.base import ParadigmBase @@ -49,6 +49,11 @@ class SingleTaskLearning(ParadigmBase): def __init__(self, workspace, **kwargs): ParadigmBase.__init__(self, workspace, **kwargs) self.initial_model = kwargs.get("initial_model_url") + self.mode = kwargs.get("mode") + self.quantization_type = kwargs.get("quantization_type") + self.llama_quantize_path = kwargs.get("llama_quantize_path") + if kwargs.get("use_gpu", True): + os.environ["CUDA_VISIBLE_DEVICES"] = "0" def run(self): """ @@ -66,10 +71,43 @@ def run(self): trained_model = self._train(job, self.initial_model) + if trained_model is None: + trained_model = self.initial_model + + if self.mode == 'with_compression': + trained_model = self._compress(trained_model) + inference_result = self._inference(job, trained_model) return inference_result, self.system_metric_info + + def _compress(self, trained_model): + if not os.path.exists(trained_model): + return None + + if self.llama_quantize_path is None or not os.path.exists(self.llama_quantize_path): + return None + + if self.quantization_type is None: + return None + + compressed_model = trained_model.replace('.gguf', f'_{self.quantization_type}.gguf') + + command = [ + 
self.llama_quantize_path, + trained_model, + compressed_model, + self.quantization_type + ] + + try: + subprocess.run(command, check=True) + except subprocess.CalledProcessError as _: + return trained_model + + return compressed_model + def _train(self, job, initial_model): train_output_dir = os.path.join(self.workspace, "output/train/") os.environ["BASE_MODEL_URL"] = initial_model @@ -84,5 +122,8 @@ def _inference(self, job, trained_model): inference_output_dir = os.path.join(self.workspace, "output/inference/") os.environ["RESULT_SAVED_URL"] = inference_output_dir job.load(trained_model) - infer_res = job.predict(inference_dataset.x) + if hasattr(inference_dataset, 'need_other_info'): + infer_res = job.predict(inference_dataset) + else: + infer_res = job.predict(inference_dataset.x) return infer_res diff --git a/core/testcasecontroller/metrics/metrics.py b/core/testcasecontroller/metrics/metrics.py index 4e9c886a..fa348a5a 100644 --- a/core/testcasecontroller/metrics/metrics.py +++ b/core/testcasecontroller/metrics/metrics.py @@ -39,8 +39,7 @@ def samples_transfer_ratio_func(system_metric_info: dict): """ - info = system_metric_info.get( - SystemMetricType.SAMPLES_TRANSFER_RATIO.value) + info = system_metric_info.get(SystemMetricType.SAMPLES_TRANSFER_RATIO.value) inference_num = 0 transfer_num = 0 for inference_data, transfer_data in info: @@ -53,8 +52,7 @@ def compute(key, matrix): """ Compute BWT and FWT scores for a given matrix. """ - print( - f"compute function: key={key}, matrix={matrix}, type(matrix)={type(matrix)}") + print(f"compute function: key={key}, matrix={matrix}, type(matrix)={type(matrix)}") length = len(matrix) accuracy = 0.0 @@ -63,7 +61,7 @@ def compute(key, matrix): flag = True for row in matrix: - if not isinstance(row, list) or len(row) != length-1: + if not isinstance(row, list) or len(row) != length - 1: flag = False break @@ -72,30 +70,29 @@ def compute(key, matrix): fwt_score = np.nan return bwt_score, fwt_score - for i in range(length-1): - for j in range(length-1): - if 'accuracy' in matrix[i+1][j] and 'accuracy' in matrix[i][j]: - accuracy += matrix[i+1][j]['accuracy'] - bwt_score += matrix[i+1][j]['accuracy'] - \ - matrix[i][j]['accuracy'] + for i in range(length - 1): + for j in range(length - 1): + if "accuracy" in matrix[i + 1][j] and "accuracy" in matrix[i][j]: + accuracy += matrix[i + 1][j]["accuracy"] + bwt_score += matrix[i + 1][j]["accuracy"] - matrix[i][j]["accuracy"] - for i in range(0, length-1): - if 'accuracy' in matrix[i][i] and 'accuracy' in matrix[0][i]: - fwt_score += matrix[i][i]['accuracy'] - matrix[0][i]['accuracy'] + for i in range(0, length - 1): + if "accuracy" in matrix[i][i] and "accuracy" in matrix[0][i]: + fwt_score += matrix[i][i]["accuracy"] - matrix[0][i]["accuracy"] - accuracy = accuracy / ((length-1) * (length-1)) - bwt_score = bwt_score / ((length-1) * (length-1)) - fwt_score = fwt_score / (length-1) + accuracy = accuracy / ((length - 1) * (length - 1)) + bwt_score = bwt_score / ((length - 1) * (length - 1)) + fwt_score = fwt_score / (length - 1) print(f"{key} BWT_score: {bwt_score}") print(f"{key} FWT_score: {fwt_score}") my_matrix = [] - for i in range(length-1): + for i in range(length - 1): my_matrix.append([]) - for j in range(length-1): - if 'accuracy' in matrix[i+1][j]: - my_matrix[i].append(matrix[i+1][j]['accuracy']) + for j in range(length - 1): + if "accuracy" in matrix[i + 1][j]: + my_matrix[i].append(matrix[i + 1][j]["accuracy"]) return my_matrix, bwt_score, fwt_score @@ -141,7 +138,16 @@ def 
task_avg_acc_func(system_metric_info: dict): compute task average accuracy """ info = system_metric_info.get(SystemMetricType.TASK_AVG_ACC.value) - return info["accuracy"] + return round(info["accuracy"], 3) + + +def forget_rate_func(system_metric_info: dict): + """ + compute task forget rate + """ + info = system_metric_info.get(SystemMetricType.FORGET_RATE.value) + forget_rate = np.mean(info) + return round(forget_rate, 3) def get_metric_func(metric_dict: dict): @@ -166,10 +172,12 @@ def get_metric_func(metric_dict: dict): try: load_module(url) metric_func = ClassFactory.get_cls( - type_name=ClassType.GENERAL, t_cls_name=name) + type_name=ClassType.GENERAL, t_cls_name=name + ) return name, metric_func except Exception as err: raise RuntimeError( - f"get metric func(url={url}) failed, error: {err}.") from err + f"get metric func(url={url}) failed, error: {err}." + ) from err return name, getattr(sys.modules[__name__], str.lower(name) + "_func") diff --git a/core/testenvmanager/dataset/dataset.py b/core/testenvmanager/dataset/dataset.py index 16bd038f..2edc960f 100644 --- a/core/testenvmanager/dataset/dataset.py +++ b/core/testenvmanager/dataset/dataset.py @@ -16,10 +16,16 @@ import os import tempfile - import pandas as pd -from sedna.datasources import CSVDataParse, TxtDataParse, JSONDataParse - +# pylint: disable=no-name-in-module +# pylint: disable=too-many-instance-attributes +from sedna.datasources import ( + CSVDataParse, + TxtDataParse, + JSONDataParse, + JsonlDataParse, + JSONMetaDataParse, +) from core.common import utils from core.common.constant import DatasetFormat @@ -38,12 +44,28 @@ class Dataset: def __init__(self, config): self.train_url: str = "" self.test_url: str = "" + self.train_index: str = "" + self.test_index: str = "" + self.train_data: str = "" + self.test_data: str = "" + self.train_data_info: str = "" + self.test_data_info: str = "" self.label: str = "" self._parse_config(config) def _check_fields(self): - self._check_dataset_url(self.train_url) - self._check_dataset_url(self.test_url) + if self.train_index: + self._check_dataset_url(self.train_index) + if self.test_index: + self._check_dataset_url(self.test_index) + if self.train_data: + self._check_dataset_url(self.train_data) + if self.test_data: + self._check_dataset_url(self.test_data) + if self.train_data_info: + self._check_dataset_url(self.train_data_info) + if self.test_data_info: + self._check_dataset_url(self.test_data_info) def _parse_config(self, config): for attr, value in config.items(): @@ -55,11 +77,15 @@ def _parse_config(self, config): @classmethod def _check_dataset_url(cls, url): if not utils.is_local_file(url) and not os.path.isabs(url): - raise ValueError(f"dataset file({url}) is not a local file and not a absolute path.") + raise ValueError( + f"dataset file({url}) is not a local file and not a absolute path." + ) file_format = utils.get_file_format(url) if file_format not in [v.value for v in DatasetFormat.__members__.values()]: - raise ValueError(f"dataset file({url})'s format({file_format}) is not supported.") + raise ValueError( + f"dataset file({url})'s format({file_format}) is not supported." 
+ ) @classmethod def _process_txt_index_file(cls, file_url): @@ -79,15 +105,16 @@ def _process_txt_index_file(cls, file_url): tmp_file = os.path.join(tempfile.mkdtemp(), "index.txt") with open(tmp_file, "w", encoding="utf-8") as file: for line in lines: - #copy all the files in the line + # copy all the files in the line line = line.strip() words = line.split(" ") length = len(words) - words[-1] = words[-1] + '\n' + words[-1] = words[-1] + "\n" for i in range(length): file.writelines( - f"{os.path.abspath(os.path.join(root, words[i]))}") - if i < length-1: + f"{os.path.abspath(os.path.join(root, words[i]))}" + ) + if i < length - 1: file.writelines(" ") new_file = tmp_file @@ -103,6 +130,20 @@ def _process_index_file(self, file_url): return None + def _process_data_file(self, file_url): + file_format = utils.get_file_format(file_url) + if file_format == DatasetFormat.JSONL.value: + return file_url + + return None + + def _process_data_info_file(self, file_url): + file_format = utils.get_file_format(file_url) + if file_format == DatasetFormat.JSON.value: + return file_url + + return None + def process_dataset(self): """ process dataset: @@ -111,13 +152,38 @@ def process_dataset(self): in the index file(e.g.: txt index file). """ + if self.train_index: + self.train_url = self._process_index_file(self.train_index) + elif self.train_data: + self.train_url = self._process_data_file(self.train_data) + elif self.train_data_info: + self.train_url = self._process_data_info_file(self.train_data_info) + # raise NotImplementedError('to be done') + else: + raise NotImplementedError('not one of train_index/train_data/train_data_info') + + if self.test_index: + self.test_url = self._process_index_file(self.test_index) + elif self.test_data: + self.test_url = self._process_data_file(self.test_data) + elif self.test_data_info: + self.test_url = self._process_data_info_file(self.test_data_info) + # raise NotImplementedError('to be done') + else: + raise NotImplementedError('not one of test_index/test_data/test_data_info') - self.train_url = self._process_index_file(self.train_url) - self.test_url = self._process_index_file(self.test_url) # pylint: disable=too-many-arguments - def split_dataset(self, dataset_url, dataset_format, ratio, method="default", - dataset_types=None, output_dir=None, times=1): + def split_dataset( + self, + dataset_url, + dataset_format, + ratio, + method="default", + dataset_types=None, + output_dir=None, + times=1, + ): """ split dataset: step1: divide all data N(N = times) times to generate N pieces of data. 
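To make the new train/test source resolution concrete, here is a small illustrative sketch; the field names are the ones added above, while the paths and label are hypothetical:

```python
# Illustrative only -- the paths below are hypothetical.
dataset_config = {
    "train_data": "/data/fci/train.jsonl",        # resolved by _process_data_file (jsonl used as-is)
    "test_data_info": "/data/fci/metadata.json",  # resolved by _process_data_info_file (json metadata)
    "label": "y",
}
# process_dataset() tries, per split and in order:
#   1. *_index     -> _process_index_file (txt index rewritten with absolute paths)
#   2. *_data      -> _process_data_file (jsonl passed through)
#   3. *_data_info -> _process_data_info_file (json passed through; a file named
#      metadata.json is later handled by JSONMetaDataParse in load_data)
# and raises NotImplementedError if none of the three is provided.
```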
@@ -152,30 +218,48 @@ def split_dataset(self, dataset_url, dataset_format, ratio, method="default", """ if method == "default": - return self._splitting_more_times(dataset_url, dataset_format, ratio, - data_types=dataset_types, - output_dir=output_dir, - times=times) + return self._splitting_more_times( + dataset_url, + dataset_format, + ratio, + data_types=dataset_types, + output_dir=output_dir, + times=times, + ) # add new splitting method for semantic segmantation if method == "city_splitting": - return self._city_splitting(dataset_url, dataset_format, ratio, - data_types=dataset_types, - output_dir=output_dir, - times=times) + return self._city_splitting( + dataset_url, + dataset_format, + ratio, + data_types=dataset_types, + output_dir=output_dir, + times=times, + ) if method == "fwt_splitting": - return self._fwt_splitting(dataset_url, dataset_format, ratio, - data_types=dataset_types, - output_dir=output_dir, - times=times) + return self._fwt_splitting( + dataset_url, + dataset_format, + ratio, + data_types=dataset_types, + output_dir=output_dir, + times=times, + ) if method == "hard-example_splitting": - return self._hard_example_splitting(dataset_url, dataset_format, ratio, - data_types=dataset_types, - output_dir=output_dir, - times=times) - - raise ValueError(f"dataset splitting method({method}) is not supported," - f"currently, method supports 'default'.") + return self._hard_example_splitting( + dataset_url, + dataset_format, + ratio, + data_types=dataset_types, + output_dir=output_dir, + times=times, + ) + + raise ValueError( + f"dataset splitting method({method}) is not supported," + f"currently, method supports 'default'." + ) @classmethod def _get_file_url(cls, output_dir, dataset_type, dataset_id, file_format): @@ -210,8 +294,9 @@ def _get_dataset_file(self, data, output_dir, dataset_type, index, dataset_forma return data_file - def _splitting_more_times(self, data_file, data_format, ratio, - data_types=None, output_dir=None, times=1): + def _splitting_more_times( + self, data_file, data_format, ratio, data_types=None, output_dir=None, times=1 + ): if not data_types: data_types = ("train", "eval") @@ -227,24 +312,38 @@ def _splitting_more_times(self, data_file, data_format, ratio, index = 1 while index <= times: if index == times: - new_dataset = all_data[step * (index - 1):] + new_dataset = all_data[step * (index - 1) :] else: - new_dataset = all_data[step * (index - 1):step * index] + new_dataset = all_data[step * (index - 1) : step * index] new_num = len(new_dataset) - data_files.append(( - self._get_dataset_file(new_dataset[:int(new_num * ratio)], output_dir, - data_types[0], index, data_format), - self._get_dataset_file(new_dataset[int(new_num * ratio):], output_dir, - data_types[1], index, data_format))) + data_files.append( + ( + self._get_dataset_file( + new_dataset[: int(new_num * ratio)], + output_dir, + data_types[0], + index, + data_format, + ), + self._get_dataset_file( + new_dataset[int(new_num * ratio) :], + output_dir, + data_types[1], + index, + data_format, + ), + ) + ) index += 1 return data_files - def _fwt_splitting(self, data_file, data_format, ratio, - data_types=None, output_dir=None, times=1): + def _fwt_splitting( + self, data_file, data_format, ratio, data_types=None, output_dir=None, times=1 + ): if not data_types: data_types = ("train", "eval") @@ -257,33 +356,52 @@ def _fwt_splitting(self, data_file, data_format, ratio, all_num = len(all_data) step = int(all_num / times) - data_files.append(( - self._get_dataset_file(all_data[:1], output_dir, 
- data_types[0], 0, data_format), - self._get_dataset_file(all_data[:1], output_dir, - data_types[1], 0, data_format))) + data_files.append( + ( + self._get_dataset_file( + all_data[:1], output_dir, data_types[0], 0, data_format + ), + self._get_dataset_file( + all_data[:1], output_dir, data_types[1], 0, data_format + ), + ) + ) index = 1 while index <= times: if index == times: - new_dataset = all_data[step * (index - 1):] + new_dataset = all_data[step * (index - 1) :] else: - new_dataset = all_data[step * (index - 1):step * index] + new_dataset = all_data[step * (index - 1) : step * index] new_num = len(new_dataset) - data_files.append(( - self._get_dataset_file(new_dataset[:int(new_num * ratio)], output_dir, - data_types[0], index, data_format), - self._get_dataset_file(new_dataset[int(new_num * ratio):], output_dir, - data_types[1], index, data_format))) + data_files.append( + ( + self._get_dataset_file( + new_dataset[: int(new_num * ratio)], + output_dir, + data_types[0], + index, + data_format, + ), + self._get_dataset_file( + new_dataset[int(new_num * ratio) :], + output_dir, + data_types[1], + index, + data_format, + ), + ) + ) index += 1 return data_files # add new splitting method for semantic segmentation - def _city_splitting(self, data_file, data_format, ratio, - data_types=None, output_dir=None, times=1): + def _city_splitting( + self, data_file, data_format, ratio, data_types=None, output_dir=None, times=1 + ): if not data_types: data_types = ("train", "eval") @@ -296,38 +414,67 @@ def _city_splitting(self, data_file, data_format, ratio, index0 = 0 for i, data in enumerate(all_data): - if 'synthia_sim' in data: + if "synthia_sim" in data: continue index0 = i break new_dataset = all_data[:index0] - data_files.append(( - self._get_dataset_file(new_dataset[:int(len(new_dataset) * ratio)], output_dir, - data_types[0], 1, data_format), - self._get_dataset_file(new_dataset[int(len(new_dataset) * ratio):], output_dir, - data_types[1], 1, data_format))) + data_files.append( + ( + self._get_dataset_file( + new_dataset[: int(len(new_dataset) * ratio)], + output_dir, + data_types[0], + 1, + data_format, + ), + self._get_dataset_file( + new_dataset[int(len(new_dataset) * ratio) :], + output_dir, + data_types[1], + 1, + data_format, + ), + ) + ) times = times - 1 - step = int((len(all_data)-index0) / times) + step = int((len(all_data) - index0) / times) index = 1 while index <= times: if index == times: - new_dataset = all_data[index0 + step * (index - 1):] + new_dataset = all_data[index0 + step * (index - 1) :] else: - new_dataset = all_data[index0 + step * (index - 1):index0 + step * index] - - data_files.append(( - self._get_dataset_file(new_dataset[:int(len(new_dataset) * ratio)], output_dir, - data_types[0], index+1, data_format), - self._get_dataset_file(new_dataset[int(len(new_dataset) * ratio):], output_dir, - data_types[1], index+1, data_format))) + new_dataset = all_data[ + index0 + step * (index - 1) : index0 + step * index + ] + + data_files.append( + ( + self._get_dataset_file( + new_dataset[: int(len(new_dataset) * ratio)], + output_dir, + data_types[0], + index + 1, + data_format, + ), + self._get_dataset_file( + new_dataset[int(len(new_dataset) * ratio) :], + output_dir, + data_types[1], + index + 1, + data_format, + ), + ) + ) index += 1 return data_files - def _hard_example_splitting(self, data_file, data_format, ratio, - data_types=None, output_dir=None, times=1): + def _hard_example_splitting( + self, data_file, data_format, ratio, data_types=None, output_dir=None, 
times=1 + ): if not data_types: data_types = ("train", "eval") @@ -339,33 +486,65 @@ def _hard_example_splitting(self, data_file, data_format, ratio, data_files = [] all_num = len(all_data) - step = int(all_num / (times*2)) - data_files.append(( - self._get_dataset_file(all_data[:int((all_num * ratio)/2)], output_dir, - data_types[0], 0, data_format), - self._get_dataset_file(all_data[int((all_num * ratio)/2):int(all_num/2)], output_dir, - data_types[1], 0, data_format))) + step = int(all_num / (times * 2)) + data_files.append( + ( + self._get_dataset_file( + all_data[: int((all_num * ratio) / 2)], + output_dir, + data_types[0], + 0, + data_format, + ), + self._get_dataset_file( + all_data[int((all_num * ratio) / 2) : int(all_num / 2)], + output_dir, + data_types[1], + 0, + data_format, + ), + ) + ) index = 1 while index <= times: if index == times: - new_dataset = all_data[int(all_num/2)+step*(index-1):] + new_dataset = all_data[int(all_num / 2) + step * (index - 1) :] else: - new_dataset = all_data[int(all_num/2)+step*(index-1): int(all_num/2)+step*index] + new_dataset = all_data[ + int(all_num / 2) + + step * (index - 1) : int(all_num / 2) + + step * index + ] new_num = len(new_dataset) - data_files.append(( - self._get_dataset_file(new_dataset[:int(new_num * ratio)], output_dir, - data_types[0], index, data_format), - self._get_dataset_file(new_dataset[int(new_num * ratio):], output_dir, - data_types[1], index, data_format))) + data_files.append( + ( + self._get_dataset_file( + new_dataset[: int(new_num * ratio)], + output_dir, + data_types[0], + index, + data_format, + ), + self._get_dataset_file( + new_dataset[int(new_num * ratio) :], + output_dir, + data_types[1], + index, + data_format, + ), + ) + ) index += 1 return data_files @classmethod - def load_data(cls, file: str, data_type: str, label=None, use_raw=False, feature_process=None): + def load_data( + cls, file: str, data_type: str, label=None, use_raw=False, feature_process=None + ): """ load data @@ -388,6 +567,11 @@ def load_data(cls, file: str, data_type: str, label=None, use_raw=False, feature e.g.: TxtDataParse, CSVDataParse. """ + if file.split('/')[-1] == "metadata.json": + data = JSONMetaDataParse(data_type=data_type, func=feature_process) + data.parse(file) + return data + data_format = utils.get_file_format(file) data = None @@ -397,11 +581,14 @@ def load_data(cls, file: str, data_type: str, label=None, use_raw=False, feature if data_format == DatasetFormat.TXT.value: data = TxtDataParse(data_type=data_type, func=feature_process) - #print(file) data.parse(file, use_raw=use_raw) if data_format == DatasetFormat.JSON.value: data = JSONDataParse(data_type=data_type, func=feature_process) data.parse(file) + if data_format == DatasetFormat.JSONL.value: + data = JsonlDataParse(data_type=data_type, func=feature_process) + data.parse(file) + return data diff --git a/core/testenvmanager/dataset/utils.py b/core/testenvmanager/dataset/utils.py new file mode 100644 index 00000000..1349ad07 --- /dev/null +++ b/core/testenvmanager/dataset/utils.py @@ -0,0 +1,92 @@ +# Copyright 2022 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Dataset utils to read data from file and partition data """
+# pylint: disable=W1203
+import random
+import numpy as np
+from sedna.datasources import BaseDataSource
+from core.common.log import LOGGER
+
+
+def read_data_from_file_to_npy(files: BaseDataSource):
+    """
+    read data from .npy files into numpy arrays
+
+    Parameters
+    ----------
+    files: BaseDataSource
+        data source whose x holds the .npy file urls and y holds the labels.
+
+    Returns
+    -------
+    tuple
+        (x_train, y_train) data and labels as numpy arrays.
+
+    """
+    x_train = []
+    y_train = []
+    for i, file in enumerate(files.x):
+        x_data = np.load(file)
+        # broadcast the per-file label (a numpy string scalar) over all
+        # samples in that file as int32
+        y_data = np.full((x_data.shape[0],), (files.y[i]).astype(np.int32))
+        x_train.append(x_data)
+        y_train.append(y_data)
+    x_train = np.concatenate(x_train, axis=0)
+    y_train = np.concatenate(y_train, axis=0)
+    return x_train, y_train
+
+
+def partition_data(datasets, client_number, data_partition="iid", non_iid_ratio=0.6):
+    """
+    Partition data into clients.
+
+    Parameters
+    ----------
+    datasets: list
+        The list containing the data and labels (x_data, y_data).
+    client_number: int
+        The number of clients.
+    data_partition: str
+        The partition method, either 'iid' or 'non-iid'.
+    non_iid_ratio: float
+        For 'non-iid', the ratio of classes sampled for each client.
+
+    Returns
+    -------
+    list
+        A list of data for each client in numpy array format.
+    """
+    LOGGER.info(data_partition)
+    data = []
+    if data_partition == "iid":
+        x_data = datasets[0]
+        y_data = datasets[1]
+        indices = np.arange(len(x_data))
+        np.random.shuffle(indices)
+        for i in range(client_number):
+            start = i * len(x_data) // client_number
+            end = (i + 1) * len(x_data) // client_number
+            data.append([x_data[indices[start:end]], y_data[indices[start:end]]])
+    elif data_partition == "non-iid":
+        class_num = len(np.unique(datasets[1]))
+        x_data = datasets[0]
+        y_data = datasets[1]
+
+        for i in range(client_number):
+            sample_number = int(class_num * non_iid_ratio)
+            current_class = random.sample(range(class_num), sample_number)
+            LOGGER.info(f"for client{i} the class is {current_class}")
+            # np.isin selects every sample whose label is one of the sampled
+            # classes; comparing y_data == current_class directly would not
+            # broadcast as intended.
+            indices = np.where(np.isin(y_data, current_class))[0]
+            data.append([x_data[indices], y_data[indices]])
+    else:
+        raise ValueError("data_partition must be 'iid' or 'non-iid'")
+    return data
diff --git a/core/testenvmanager/testenv/testenv.py b/core/testenvmanager/testenv/testenv.py
index 4a4fef4d..9e159901 100644
--- a/core/testenvmanager/testenv/testenv.py
+++ b/core/testenvmanager/testenv/testenv.py
@@ -39,11 +39,14 @@ def __init__(self, config):
                 "url": "",
             },
             "threshold": 0.9,
-            "operator": ">"
+            "operator": ">",
         }
         self.metrics = []
         self.incremental_rounds = 2
+        self.round = 1
+        self.client_number = 1
         self.dataset = None
+        self.use_gpu = False  # default false
         self._parse_config(config)

     def _check_fields(self):
@@ -51,8 +54,10 @@ def _check_fields(self):
             raise ValueError(f"not found testenv metrics({self.metrics}).")

         if not isinstance(self.incremental_rounds, int) or self.incremental_rounds < 2:
-            raise ValueError(f"testenv incremental_rounds(value={self.incremental_rounds})"
-                             f" must be int type and not less than 2.")
+            raise ValueError(
+                f"testenv incremental_rounds(value={self.incremental_rounds})"
+                f" must be int type and not less than 2."
+            )

     def _parse_config(self, config):
         config_dict = config[str.lower(TestEnv.__name__)]
@@ -60,6 +65,8 @@ def _parse_config(self, config):
         for k, v in config_dict.items():
             if k == str.lower(Dataset.__name__):
                 self.dataset = Dataset(v)
+            elif k == 'use_gpu':
+                self.use_gpu = bool(v)
             else:
                 if k in self.__dict__:
                     self.__dict__[k] = v
@@ -67,7 +74,7 @@ def _parse_config(self, config):
         self._check_fields()

     def prepare(self):
-        """ prepare env"""
+        """prepare env"""
         try:
             self.dataset.process_dataset()
         except Exception as err:
diff --git a/docs/proposals/algorithms/multi-edge-inference/Heterogeneous Multi-Edge Collaborative Neural Network Inference for High Mobility Scenarios.md b/docs/proposals/algorithms/multi-edge-inference/Heterogeneous Multi-Edge Collaborative Neural Network Inference for High Mobility Scenarios.md
new file mode 100644
index 00000000..17b86e1f
--- /dev/null
+++ b/docs/proposals/algorithms/multi-edge-inference/Heterogeneous Multi-Edge Collaborative Neural Network Inference for High Mobility Scenarios.md
@@ -0,0 +1,142 @@
+# Background
+In high-mobility scenarios such as highways and high-speed railways, the connection between personal terminal devices and cloud servers is significantly weakened. At the same time, artificial intelligence technology has permeated every aspect of our lives, and we increasingly need AI capabilities that are computationally and storage intensive and latency sensitive in exactly these scenarios. For example, even when driving through a tunnel with a weak network environment, we may still need AI capabilities such as image classification and large-model dialogue. Therefore, when edge devices lose their connection to the cloud, offloading AI computing tasks to adjacent edge devices and aggregating their computing power through mutual collaboration, so as to complete computing tasks that traditionally require cloud-edge collaboration, has become an issue worth addressing. This project aims to use multiple heterogeneous computing units on the edge (such as personal mobile phones, tablets, wristbands, laptops, and other computing devices) for collaborative neural network inference, enabling AI tasks to be completed with lower latency on devices closer to the user, thereby enhancing the user experience.
+
+To simulate the workflow of multi-edge inference in the real world, Ianvs, as a distributed collaborative AI benchmarking platform, already supports this AI paradigm and provides a simple simulation benchmark. However, when facing various heterogeneous computing units, Ianvs cannot automatically partition and schedule the computation graph; instead, it requires users to manually partition the graph and decide how to allocate it to suitable computing nodes. This greatly limits the computational resource utilization and flexibility of multi-edge inference, especially in high-mobility scenarios, where the limitation can further degrade the user experience.
+
+Therefore, targeting high-mobility scenarios and heterogeneous computing devices, this proposal offers an automatic partitioning and scheduling framework for neural network computation graphs, thereby enhancing the collaborative and adaptive capabilities of multi-edge inference.
+# Goals
+
+1. Complete the partitioning module in Ianvs, supporting automatic partitioning of neural networks in any ONNX format while taking into account the different computational capabilities of heterogeneous computing units and adaptively achieving load balancing;
+2. Based on the aforementioned module, provide a multi-edge inference benchmarking job in a high-mobility scenario (such as edge-side LLM inference or image recognition) that directly uses the automatic partitioning function of Ianvs, and form a demonstration example.
+# Proposal
+_Heterogeneous Multi-Edge Collaborative Neural Network Inference for High Mobility Scenarios_ builds on the multi-edge inference paradigm supported by Ianvs and adds a sub-module for automatic computation graph partitioning on top, to cope with the heterogeneous computing capabilities of multiple edge devices in high-mobility scenarios. It spares developers from manually partitioning the neural network computation graph, making the multi-edge inference workflow more efficient and productive.
+
+The scope of the system includes:
+
+1. Encapsulating the capabilities of automatic computation graph partitioning into a function, providing it as an extended option for users to customize, and seamlessly integrating it with the existing multi-edge inference workflow;
+2. Providing a multi-edge inference benchmarking job in a high-mobility scenario (such as edge-side LLM inference or image recognition) to verify the effectiveness and benefits of the automatic partitioning module;
+3. Adding judgments to the multi-edge inference paradigm process, which gives the partitioning algorithm significant extensibility: if the user has implemented a custom partition function, the user-defined partitioning algorithm is called first.
+
+Target users include:
+
+1. Beginners: familiarize themselves with distributed synergy AI and multi-edge inference, among other concepts.
+2. Developers: quickly integrate multi-edge inference algorithms into other development environments such as Sedna and test their performance for further optimization.
+# Design Details
+## Process Design
+First, taking the existing tracking_job and reid_job as examples, we analyze the workflow of the two benchmarking jobs, clarify the function call logic of Ianvs, and determine where configuration information is written and where the partition function should be inserted, to ensure high cohesion and low coupling in the overall code. The workflow starts from the main() function in the benchmarking.py file (located in the ianvs/core directory), which reads the user's configuration file reid_job.yaml and creates a BenchmarkingJob. This step parses the configuration parameters of the yaml file and creates instances of classes such as TestEnv, Rank, Simulation, and TestCaseController that match the configuration description.
+
+Subsequently, the run() method of the BenchmarkingJob instance is called, which uses the build_testcases() method of the TestCaseController instance to create test cases. This step actually parses the algorithm configuration specified by _test_object.url_ in the reid_job.yaml file and creates instances of Algorithm and TestCase that match the algorithm configuration description. Then, the run_testcases() method of the TestCaseController instance is called, which ultimately calls the run() method of the corresponding algorithm paradigm, such as the run() method of the MultiedgeInference class instance in this case.
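+
+As a rough sketch, the call chain described above can be condensed as follows. This is illustrative pseudo-usage rather than the literal core code; the import path and the config key are assumptions based on the repository layout:
+
+```python
+# Hypothetical condensation of the benchmarking workflow described above.
+from core.cmd.obj.benchmarkingjob import BenchmarkingJob  # path assumed
+from core.common import utils
+
+config = utils.yaml2dict("reid_job.yaml")          # parse the user's job config
+job = BenchmarkingJob(config["benchmarkingjob"])   # builds TestEnv, Rank, TestCaseController, ...
+job.run()
+# run() roughly does:
+#   test_cases = controller.build_testcases(test_env, test_object)  # -> Algorithm, TestCase
+#   controller.run_testcases(workspace)                             # -> paradigm.run(),
+#                                                                   #    e.g. MultiedgeInference.run()
+```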
+
+In this method, a job instance is created through self.build_paradigm_job(ParadigmType.MULTIEDGE_INFERENCE.value), which is actually an instance of the BaseModel class that the user has written and that was registered in module_instances during configuration parsing. Therefore, all subsequent load() and predict() calls actually invoke the methods defined in the user's own BaseModel class. For example, the user-defined load method below simulates data parallelism in a multi-device scenario:
+
+```python
+# examples/MOT17/multiedge_inference_bench/pedestrian_tracking/testalgorithms/reid/m3l/basemodel.py
+
+def load(self, model_url=None):
+    if model_url:
+        arch = re.compile("_([a-zA-Z]+).pth").search(model_url).group(1)
+        # Create model
+        self.model = models.create(
+            arch, num_features=0, dropout=0, norm=True, BNNeck=True
+        )
+        # use CUDA
+        self.model.cuda()
+        self.model = nn.DataParallel(self.model)
+        if Path(model_url).is_file():
+            checkpoint = torch.load(model_url, map_location=torch.device('cpu'))
+            print("=> Loaded checkpoint '{}'".format(model_url))
+            self.model.load_state_dict(checkpoint["state_dict"])
+        else:
+            raise ValueError("=> No checkpoint found at '{}'".format(model_url))
+    else:
+        raise Exception(f"model url is None")
+```
+
+Based on the above process analysis, we find that the existing multi-edge inference benchmarking job only uses Ianvs to create and manage test cases; the core algorithmic steps, such as multi-device parallelism and model partitioning, are left to the user to implement. It is also worth mentioning that the nn.DataParallel(self.model) used in this case only achieves data parallelism; for scenarios with low computing power on the edge and large models, data parallelism alone is clearly insufficient to support edge inference needs. Therefore, this project needs to implement model-parallel capabilities based on model partitioning and encapsulate these capabilities (partitioning and scheduling) into a function, separated from the user's code, as an optional feature of the multiedge_inference paradigm supported by Ianvs.
+
+The newly added automatic partitioning module will be inserted at the position indicated in the diagram below, forming a new complete flowchart:
+![process](images/process.png)
+
+## Module Design and Code Integration
+From the above process analysis, it follows that to provide automatic graph partitioning and scheduling capabilities within the Ianvs framework, the optimal code integration point is the Algorithm Paradigm module of the Test Case Controller component, specifically the directory core/testcasecontroller/algorithm/paradigm. The current structure of this directory is:
+
+```
+paradigm
+├── __init__.py
+├── base.py
+├── incremental_learning
+│   ├── __init__.py
+│   └── incremental_learning.py
+├── lifelong_learning
+│   ├── __init__.py
+│   └── lifelong_learning.py
+├── multiedge_inference
+│   ├── __init__.py
+│   └── multiedge_inference.py
+└── singletask_learning
+    ├── __init__.py
+    ├── singletask_learning.py
+    ├── singletask_learning_active_boost.py
+    └── singletask_learning_tta.py
+```
+
+Based on the process analysis, this project intends to add a _partition function under the multiedge_inference paradigm and implement our computation graph partitioning and scheduling capabilities within it.
+The total process should include:
+
+- Input: the initial model data and the user-declared devices.yaml file, which contains the number of heterogeneous devices the user simulates on a single machine, information about each device (such as GPU memory and number of GPUs), and the communication bandwidth between devices.
+
+- Parsing: the user-declared devices.yaml file is parsed to obtain device data, and the initial model computation graph is parsed to obtain model data.
+
+- Modeling (optional): the parsed device data and model data are analyzed jointly so that the algorithm can compute a matching list of devices and computational subgraphs.
+
+- Partitioning: the model is partitioned based on the decided computational subgraphs.
+
+- Output: the matching list of devices and computational subgraphs, as well as the partitioned computational subgraphs.
+
+It is worth noting that we have implemented a general interface and a simple partitioning algorithm here (by analyzing the partitioning points specified by the user). More partitioning algorithms will be added in the future, and users can customize their own partition methods in basemodel.py; they only need to comply with the input and output specification defined by the interface:
+
+```
+def partition(self, initial_model):
+    ## 1. parse devices.yaml
+    ## 2. modeling
+    ## 3. partition
+    return models_dir, map_info
+```
+
+Subsequently, the logic in multiedge_inference.py is modified to decide, based on the user's code, whether to use the automatic partitioning capability. If it is enabled, the URL of initial_model and the key information of devices.yaml are passed into the automatic partitioning algorithm, and the returned matching list of devices and computational subgraphs, as well as the partitioned subgraphs, are passed back to the user's code.
+
+Further, the load method of the BaseModel class in the benchmarking job is provided to receive these parameters and use them to complete the multi-edge inference process.
+
+At the same time, a corresponding multi-edge inference benchmarking job for high-mobility scenarios will be provided in the _examples_ folder.
+
+The following diagram illustrates the framework of the entire system after the modifications:
+![framework](images/framework.png)
+
+## Method Design
+![image](images/partition_method.png)
+We implement heterogeneous multi-edge collaborative neural network inference for high-mobility scenarios using the method shown in the figure above.
+
+First, the user-declared devices.yaml file is parsed to obtain device data. Considering future updates to the partitioning algorithm (such as joint modeling based on device capabilities and model computational complexity), we have reserved sufficient fields in devices.yaml to describe information such as device type, memory, frequency, and bandwidth.
+
+Subsequently, based on the matching relationship, the model is partitioned and scheduled onto the matching devices (simulated by the user with Docker containers or GPUs) to achieve the best collaborative effect.
+
+It is worth noting that the parallelism implemented here is model parallelism. When multiple inference tasks are carried out simultaneously, models that finish inference in the current round do not have to wait for the models that have not finished; instead, they can proceed with the inference of the next task, forming a simple pipeline parallelism. More complex and efficient pipeline parallelization strategies are left for future work.
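+
+For illustration only, a devices.yaml under this design might look like the sketch below; the field names are assumptions chosen for readability, not the final schema, since the proposal only reserves fields for device type, memory, frequency, and bandwidth:
+
+```yaml
+# Hypothetical devices.yaml sketch; field names are illustrative only.
+devices:
+  - name: edge-device-1
+    type: gpu
+    memory: 8192      # MB of (GPU) memory
+    frequency: 1.5    # GHz
+  - name: edge-device-2
+    type: cpu
+    memory: 4096
+    frequency: 2.0
+bandwidth:            # inter-device links, in Mbps
+  - between: [edge-device-1, edge-device-2]
+    value: 100
+```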
+ +In this process, the most crucial part is the extraction and modeling of device information and model information. Since the simulation is carried out in a single-machine environment, the device information will be supplemented by the user as a configuration file named devices.yaml, and the memory occupancy and computational cost of each model layer will require the user to implement profiling. We will provide two benchmarking jobs in the examples/imagenet directory to simulate different partitioning algorithms. The manual partitioning algorithm, which is based on predefined partitioning points, is simulated in testalgorithms/manual, and the core partitioning process will be integrated into the ianvs core code as the default partition method. Since in most cases we do not know the cost of each computational subgraph of complex models, we provide an automatic partitioning algorithm based on profiling and memory matching in testalgorithms/automatic to form a comparative experiment with manual partitioning. + +The benchmarking job will use the vit-base-patch16-224 model and the ImageNet dataset to simulate edge-side image recognition tasks and verify the performance comparison between different partitioning methods and the baseline (no partitioning). The specific comparison metrics include accuracy, FPS, peak_memory, and peak_power. The structure of the Benchmarking Job is as follows: + +![bench](images/benchmarking.png) + +## Roadmap +**July** + +- Complete the arbitrary partitioning function of the ONNX computational graph. +- Implement the profiling of some large models such as ViT, Bert, etc. + +**August** + +- Implement the automatic graph scheduling and partitioning algorithm based on Ianvs. + +**September** + +- Implement the multiedge inference benchmarking job based on the automatic scheduling and partitioning of the neural network computational graph and complete the demonstration example. 
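+
+As a pointer for the July roadmap item above, the official onnx package already offers subgraph extraction between named tensors, which is one way the arbitrary ONNX partitioning function could be realized. In the sketch below, the model path and tensor names are placeholders; real partitioning points would come from the manual configuration or the profiling-based algorithm described above:
+
+```python
+from onnx.utils import extract_model
+
+# Cut one ONNX model into two subgraphs at a boundary tensor.
+# "vit.onnx", "blocks.5_output", etc. are illustrative names only.
+extract_model("vit.onnx", "part1.onnx",
+              input_names=["input"], output_names=["blocks.5_output"])
+extract_model("vit.onnx", "part2.onnx",
+              input_names=["blocks.5_output"], output_names=["logits"])
+```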
diff --git a/docs/proposals/algorithms/multi-edge-inference/images/benchmarking.png b/docs/proposals/algorithms/multi-edge-inference/images/benchmarking.png new file mode 100644 index 00000000..0b6a79e3 Binary files /dev/null and b/docs/proposals/algorithms/multi-edge-inference/images/benchmarking.png differ diff --git a/docs/proposals/algorithms/multi-edge-inference/images/framework.png b/docs/proposals/algorithms/multi-edge-inference/images/framework.png new file mode 100644 index 00000000..a00f9153 Binary files /dev/null and b/docs/proposals/algorithms/multi-edge-inference/images/framework.png differ diff --git a/docs/proposals/algorithms/multi-edge-inference/images/multiedge_inference_method.png b/docs/proposals/algorithms/multi-edge-inference/images/multiedge_inference_method.png new file mode 100644 index 00000000..fff64acc Binary files /dev/null and b/docs/proposals/algorithms/multi-edge-inference/images/multiedge_inference_method.png differ diff --git a/docs/proposals/algorithms/multi-edge-inference/images/partition_method.png b/docs/proposals/algorithms/multi-edge-inference/images/partition_method.png new file mode 100644 index 00000000..eefb0c7e Binary files /dev/null and b/docs/proposals/algorithms/multi-edge-inference/images/partition_method.png differ diff --git a/docs/proposals/algorithms/multi-edge-inference/images/process.png b/docs/proposals/algorithms/multi-edge-inference/images/process.png new file mode 100644 index 00000000..0f1cd4ab Binary files /dev/null and b/docs/proposals/algorithms/multi-edge-inference/images/process.png differ diff --git a/docs/proposals/scenarios/Smart_Coding/Smart Coding benchmark suite Proposal.md b/docs/proposals/scenarios/Smart_Coding/Smart Coding benchmark suite Proposal.md new file mode 100644 index 00000000..b612a1ad --- /dev/null +++ b/docs/proposals/scenarios/Smart_Coding/Smart Coding benchmark suite Proposal.md @@ -0,0 +1,178 @@ +# Background +Large Language Models (LLMs) have demonstrated powerful capabilities in tasks such as code generation, automatic programming, and code analysis. However, these models are typically trained on generic code data and often fail to fully leverage the collaboration and feedback from software engineers in real-world scenarios. To construct a more intelligent and efficient code ecosystem, it is necessary to establish a collaborative code dataset and evaluation benchmark to facilitate tight collaboration between LLMs and software engineers. This project aims to build a collaborative code intelligent agent alignment dataset and evaluation benchmark for LLMs based on the open-source edge computing framework KubeEdge-Ianvs. This dataset will include behavioral trajectories, feedback, and iterative processes of software engineers during development, as well as relevant code versions and annotation information. Through this data, we will design evaluation metrics and benchmarks to measure the performance of LLMs in tasks such as code generation, recommendation, and analysis, fostering collaboration between LLMs and software engineers. + +In today's software development practice, large language models (LLMs) show great potential in areas such as code generation, recommendation, and analysis. However, existing models are usually trained on general code bases and lack optimization for specific software engineering tasks. 
Therefore, creating a specific dataset and evaluation benchmark that integrates the actual work experience and feedback of software engineers is crucial to improving the effectiveness of these models in real programming environments.
+# Goals
+1. Build a collaborative code intelligent agent alignment dataset for LLMs
+2. Design a code intelligent agent collaborative evaluation benchmark for LLMs
+3. Integrate the dataset and evaluation benchmark into the KubeEdge-Ianvs framework
+# Proposal
+## Building a large code language model dataset
+
+1. **Behavioral Trajectory During Development**:
+Record the operations performed by software engineers during the development process. These operations may include code writing, code submission, code merging, code review, code refactoring, etc.
+Specific behavioral data may include the development tools used, the code snippets written, submission records, review comments, etc.
+2. **Feedback and Iteration Process**:
+Collect feedback and iteration records of the code from R&D engineers during the development process. This feedback may include code review comments, test results, error reports, improvement suggestions, etc.
+Record the time of feedback, the feedback content, the corresponding code modifications, and the final solutions.
+3. **Code Version and Annotation Information**:
+Record each version of the code and the differences between versions, including added, modified, and deleted code.
+Include detailed code comments and documentation to understand the function, purpose, and design ideas of the code.
+
+## Code Large Language Model Evaluation Benchmark
+1. The benchmark should include common code agent tasks such as code generation, recommendation, and analysis.
+2. The evaluation indicators should cover multiple dimensions such as functionality, reliability, and interpretability, and match the feedback and needs of software engineers.
+3. The benchmark should be able to evaluate the performance of LLMs on collaborative code agent tasks and provide a basis for further algorithm optimization.
+## Integrate datasets and benchmarks into the KubeEdge-Ianvs framework
+
+1. The dataset and benchmark are included as part of the Ianvs framework, with good scalability and integration.
+2. Ensure that the datasets and benchmarks can run efficiently on edge devices of the Ianvs framework and work seamlessly with other functional modules of Ianvs.
+
+`examples/smart_coding` directory structure:
+```
+smart_coding
+└── smart_coding_learning_bench
+    └── smart_co
+        ├── benchmarkingjob.yaml
+        ├── testalgorithms
+        │   └── gen
+        │       ├── basemodel.py
+        │       ├── gen_algorithm.yaml
+        │       ├── op_eval.py
+        └── testenv
+            ├── acc.py
+            └── testenv.yaml
+```
+The content format of the comment test set is as follows:
+```
+{"description": "Add detailed comments to a given code/function.", "code_snippet": "def calculate_area(length, width):\n    return length * width"}
+{"description": "Add detailed comments to a given Python function.",
+ "code_snippet": "def calculate_area(length, width):\n    return length * width",
+ "annotations": [
+    {
+      "line_number": 1,
+      "annotation": "Define a function calculate_area that accepts two parameters: length and width."
+    },
+    {
+      "line_number": 2,
+      "annotation": "Returns the product of length and width, which is the area of the rectangle."
+    }
+ ]}
+```
+
+In this project, I am mainly responsible for the test suite of the code model. For the code model, the main evaluation criteria are the comments and issues in the task requirements. Different projects use different fields as the evaluation criteria for comments; scoring is based on the logic, accuracy, and formatting of the overall code.
+
+The dataset and interface definitions adopt a question-and-answer format: the Question is a line of code or a function, and the Answer is the Comment, i.e., the annotation for that code or function.
+
+The format of the issue test set is as follows:
+```
+{
+    "question": "title",
+    "user_login": "name",
+    "created_at": "time",
+    "updated_at": "time",
+    "body":"This is not possible right now afaik :/\r\n\r\nMaybe we could have something like this ?
wdyt ?\r\n\r\n```python\r\nds = interleave_datasets(\r\n [shuffled_dataset_a, dataset_b],\r\n probabilities=probabilities,\r\n stopping_strategy='all_exhausted',\r\n reshuffle_each_iteration=True,\r\n)", + "answer_1":{"user_login": "name","created_at":"time""updated_at": "time","body":"XXXX"} + "answer_2":{"user_login": "name","created_at":"time""updated_at": "time","body":"XXXX"} + "answer_3":{"user_login": "name","created_at":"time""updated_at": "time","body":"XXXX"} + } +{ + "question": xxx, + "user_login": "XXXX-XX-XX", + "created_at":"XXXX-XX-XX", + "updated_at": "XXXX-XX-XX", + "body":"XXXX....", + "answer_1":{"user_login": "name","created_at":"time""updated_at": "time","body":"XXXX"} + "answer_2":{"user_login": "name","created_at":"time""updated_at": "time","body":"XXXX"} + "answer_3":{"user_login": "name","created_at":"time""updated_at": "time","body":"XXXX"} + } +....... +``` + +The format of the issue test set refers to a simple QA question-answering task. The solution to the problem is answered through LLM. Manual judgment is used in the evaluation stage, and the acc accuracy is calculated in the end. + +The following is the operation flow of the benchmark system based on user input configuration. Because the interface part is written by Meng Zhuo, the general structure is basically consistent with Meng Zhuo. The flowchart shows the data verification, parsing, initialization and other operations. +The difference lies in the reading of the issue data set. In the issue data set, there is only one Question, but there may be multiple Comments, so in the training part, the data reading needs to be adjusted. + + + +# Design Details +## Data collection +1. GitHub: Collect open source project code in various programming languages ​​from GitHub. Use GitHub API or manual retrieval. +2. GitHub Issues: collects problem reports submitted by developers, including bug reports, feature requests, and discussions。 +3. Pull Requests: Collect pull requests submitted by developers, including the review history of function implementation and code modifications. +4. Commit Logs: Extract the project's commit logs, including every change to the code, the committer information, and the commit time. +5. Branches and Merges: Consider branch and merge information to understand the development and merge history of the code. + +### Specific steps +In the early stage, by adding comments to Python files (with key lines and segments as the granularity), collecting classic Python projects, annotating the core code lines or functions in the projects, and organizing the data set, and reading related papers at the same time. + +In the mid-term stage, through the related large models read in the early stage, collecting papers on the issue data set, we began to organize data mainly based on issues, as well as App projects with Python as the main development language. + +In the final stage, after the data set is organized, we began to design test evaluation indicators, test the data set, and write unit tests/integration tests. Test-driven to ensure code correctness. + + +## Project development time plan +| Time plan | Task | +|------------|--------------------------------------------------------------------------| +| July 13th - early August | Read relevant papers; collect data sets, read and understand open source projects, write corresponding project product documents and requirement documents. The requirement documents must include key points such as command, issue, PR, etc. 
of the corresponding project | +| Mid-August to early September | Organize the collected data sets, expand the collection scope, expand the data sets, and write test evaluation indicators for the data sets and large models | +| Mid-September to the end of September | Write unit tests/integration tests and test-driven tests to ensure code correctness. | + + + + + + + + + + + + + + + diff --git a/docs/proposals/scenarios/Smart_Coding/Smart Coding benchmark suite Proposal_zh.md b/docs/proposals/scenarios/Smart_Coding/Smart Coding benchmark suite Proposal_zh.md new file mode 100644 index 00000000..775a85e4 --- /dev/null +++ b/docs/proposals/scenarios/Smart_Coding/Smart Coding benchmark suite Proposal_zh.md @@ -0,0 +1,172 @@ +# 背景 +大型语言模型(LLM)在代码生成、自动编程、代码分析等任务中展现出了强大的能力,但这些模型通常是在通用代码数据上训练的,往往不能充分利用实际场景中软件工程师的协作和反馈。为了构建更加智能高效的代码生态,需要建立协作代码数据集和评测基准,促进LLM与软件工程师的紧密协作。本项目旨在基于开源边缘计算框架KubeEdge-Ianvs构建LLM协作代码智能体对齐数据集和评测基准。该数据集将包括软件工程师在开发过程中的行为轨迹、反馈和迭代过程,以及相关的代码版本和注释信息。通过这些数据,我们将设计评测指标和基准来衡量LLM在代码生成、推荐和分析等任务中的表现,促进LLM与软件工程师之间的协作。 + +在当今的软件开发实践中,大型语言模型(LLM)在代码生成、推荐和分析等领域展现出巨大的潜力。但现有模型通常是在通用代码库上训练的,缺乏针对特定软件工程任务的优化,因此建立融合软件工程师实际工作经验与反馈的特定数据集与评估基准,对提升这些模型在实际编程环境中的应用效果至关重要。 +# Goals +1. 为大模型构建协作代码智能数据集 +2. 为大模型构建代码协同智能评估基准测试 +3. 将数据集和智能评估基准集成到KubeEdge-Ianvs框架中 +# Proposal +## 构建数据集 + +1. **开发过程中的行为轨迹**: +记录软件工程师在开发过程中执行的操作。这些操作可能包括代码编写、代码提交、代码合并、代码审查、代码重构等。 +具体的行为数据可能包括使用的开发工具、编写的代码片段、提交记录、审查意见等。 +2. **反馈及迭代**: +收集研发工程师在开发过程中对代码的反馈和迭代记录,这些反馈可能包括代码审查意见、测试结果、错误报告、改进建议等。 +记录反馈时间、反馈内容、对应的代码修改、最终解决方案。 +3. **代码版本及注释**: +记录每个版本的代码,以及各个版本之间的差异,包括新增、修改、删除的代码。 +包括详细的代码注释和文档,以了解代码的功能、用途、设计思想。 + +## 代码大模型语言评估基准 +1. 评测基准应包括代码生成、推荐和分析等常见的代码智能体任务。 +2. 评测指标应涵盖功能性、可靠性、可解释性等多个维度,并与软件工程师的反馈和需求相匹配。 +3. 评测基准应能够评估LLMs在协作式代码智能体任务上的性能,并为进一步的算法优化提供依据。 +## 将数据集和评测基准集成到KubeEdge-Ianvs框架中 + +1. 将数据集和评测基准作为Ianvs框架的一部分,并提供良好的可扩展性和可集成性。 +2. 确保数据集和评测基准能够在Ianvs框架的边缘设备上高效运行,并与Ianvs的其他功能模块无缝协作. + +# Design Details +## Data collection +1. GitHub: 从GitHub上收集各种编程语言的开源项目代码。通过GitHub API或手动检索。 +2. GitHub Issues: 收集开发者提交的问题报告,包括Bug报告、功能请求和讨论。 +3. Pull Requests: 收集开发者提交的拉取请求,包括功能实现和代码修改的审查历史。 +3. Commit Logs: 提取项目的提交日志,包括代码的每次变更、提交者信息和提交时间。 +4. Branches and Merges: 考虑分支和合并的信息,以理解代码的开发和合并历史。 + +`examples/smart_coding` 目录结构: +``` +smart_coding +└── smart_coding_learning_bench + └── smart_co + ├── benchmarkingjob.yaml + ├── testalgorithms + │ └── gen + │ ├── basemodel.py + │ ├── gen_algorithm.yaml + │ ├── op_eval.py + └── testenv + ├── acc.py + └── testenv.yaml +``` +comment测试集部分内容格式如下: +``` +{"description": "为给定的代码/函数添加详细注释。","code_snippet": "def calculate_area(length, width):\n return length * width",} +{"description": "为给定的Python函数添加详细注释。", + "code_snippet": "def calculate_area(length, width):\n return length * width", + "annotations": [ + { + "line_number": 1, + "annotation": "定义一个函数calculate_area,接受两个参数:length和width。" + }, + { + "line_number": 2, + "annotation": "返回length和width的乘积,即矩形的面积。" + } + ]} +``` + +在本项目中,负责的部分主要是代码大模型的测试套件,对于代码大模型来说,主要就是由任务要求中的comment和issue +而对于comment的评测标准,不同的项目,使用不同的字段,打分的部分,由通过代码整体部分的逻辑性,准确性,以及格式等部分来分别进行打分。 + +数据集部分,接口的定义部分,采用一问一答的方式,Question为某一行代码/某一个函数,Answer也是Comment则是对这个代码或函数的注释。 + +issue测试集部分内容格式如下: +``` +{ + "question": "title", + "user_login": "name", + "created_at":"time", + "updated_at": "time", + "body":"This is not possible right now afaik :/\r\n\r\nMaybe we could have something like this ? 
wdyt ?\r\n\r\n```python\r\nds = interleave_datasets(\r\n [shuffled_dataset_a, dataset_b],\r\n probabilities=probabilities,\r\n stopping_strategy='all_exhausted',\r\n reshuffle_each_iteration=True,\r\n)", + "answer_1": { + "user_login": "name", + "created_at":"time" + "updated_at": "time", + "body":"This is not possible right now afaik :/\r\n\r\nMaybe we could have something like this ? wdyt ?\r\n\r\n```python\r\nds = interleave_datasets(\r\n [shuffled_dataset_a, dataset_b],\r\n probabilities=probabilities,\r\n stopping_strategy='all_exhausted',\r\n reshuffle_each_iteration=True,\r\n)", + }, + "answer_2": { + "user_login": "name", + "created_at":"time" + "updated_at": "time", + "body":"XXXX" + }, + +} +``` +issue部分测试集格式,参照简单的QA问答任务,通过LLM来回答问题的解决办法,在评估阶段使用人工判决,最后来计算acc准确率。 + +如下是基于用户输入配置的基准测试系统的操作流程,这部分因为接口部分是由孟卓同学来写的,所以大体结构基本是与孟卓一致的。流程图展示了数据的验证、解析、初始化等操作。 +其中不同的点在于对issue数据集的读取,issue数据集中,Question只有一个,但Comment可能会有多个,所以在训练部分,数据的读取还有调整。 + +![](https://github.com/safe-b/ianvs/blob/main/docs/proposals/scenarios/Smart_Coding/image/data_process_change.png?raw=true) + +![](https://github.com/safe-b/ianvs/blob/main/docs/proposals/scenarios/Smart_Coding/image/change_part.png?raw=true) + +值得注意的是,该设计同时也兼容对旧版的index数据的支持。仅仅只需要将旧版的train_url和test_url字段改成train_index和test_index即可。 + +在之前的项目中,需要在`testenv.yaml`文件中配置`train_url`和`test_url`索引文件的路径,索引文件中会放 (输入x, 期望输出y) 的文件路径,这个设计是存在一些局限性的。 + +数据集文件格式一个data.json/jsonl/文件就把数据和标签都写进去了,其中政务大模型的数据集格式定义如下 +``` +{"code_snippet": xxx, "comment": xxx} +{"code_snippet": xxx, "comment": xxx} +{"code_snippet": xxx, "comment": xxx} +``` + +在代码大模型中,对于代码和函数添加Comment数据格式定义与政务大模型一致,但issue数据集中还需要对answer的属性进行修改,例如 +``` +{ + "question": xxx, + "user_login": "XXXX", + "created_at":"XXXX-XX-XX", + "updated_at": "XXXX-XX-XX", + "body":"This is not possible right now afaik :/\r\n\r\nMaybe we could have something like this ? wdyt ?\r\n\r\n```python\r\nds = interleave_datasets(\r\n [shuffled_dataset_a, dataset_b],\r\n probabilities=probabilities,\r\n stopping_strategy='all_exhausted',\r\n reshuffle_each_iteration=True,\r\n)", + "answer_1":{"user_login": "name","created_at":"time""updated_at": "time","body":"XXXX"} + "answer_2":{"user_login": "name","created_at":"time""updated_at": "time","body":"XXXX"} + "answer_3":{"user_login": "name","created_at":"time""updated_at": "time","body":"XXXX"} + } +{ + "question": xxx, + "user_login": "XXXX-XX-XX", + "created_at":"XXXX-XX-XX", + "updated_at": "XXXX-XX-XX", + "body":"XXXX....", + "answer_1":{"user_login": "name","created_at":"time""updated_at": "time","body":"XXXX"} + "answer_2":{"user_login": "name","created_at":"time""updated_at": "time","body":"XXXX"} + "answer_3":{"user_login": "name","created_at":"time""updated_at": "time","body":"XXXX"} + } +....... 
+``` + + +### 具体步骤 +前期阶段,通过给python文件加comments (以关键行、段为粒度),搜集经典python项目,对项目中的核心代码行或者函数,进行注释,以此来整理数据集,并同步阅读相关论文。 + +中期阶段,通过前期阅读的相关大模型搜集issue数据集的论文,在此开始着手整理以issue为主的数据,还有以python为主要开发语言的App项目。 + +最终阶段,数据集整理完毕,开始着手设计测试评估指标,针对数据集进行测试,并编写单元测试/集成测试,测试驱动保证代码正确性。 + +## 时间规划 +| 时间规划 | 任务 | +|------------|--------------------------------------------------------------------------| +| 7月13号-8月上旬 | 阅读相关论文;搜集数据集,阅读并了解开源项目,写出相应的项目产品文档,需求文档,需求文档需包含对应项目的command、issue、PR等关键点 | +| 8月中旬-9月上旬 | 对于已搜集的数据集进行整理,并且扩大搜集范围,扩充数据集,编写针对该数据集,大模型的测试评价指标 | +| 9月中旬-9月底 | 编写单元测试/集成测试,测试驱动保证代码正确性 | + + + + + + + + + + + + + + + diff --git a/docs/proposals/scenarios/Smart_Coding/image/change_part.png b/docs/proposals/scenarios/Smart_Coding/image/change_part.png new file mode 100644 index 00000000..480ffdd6 Binary files /dev/null and b/docs/proposals/scenarios/Smart_Coding/image/change_part.png differ diff --git a/docs/proposals/scenarios/Smart_Coding/image/data_process_change.png b/docs/proposals/scenarios/Smart_Coding/image/data_process_change.png new file mode 100644 index 00000000..7de24d7d Binary files /dev/null and b/docs/proposals/scenarios/Smart_Coding/image/data_process_change.png differ diff --git a/docs/proposals/scenarios/llm-benchmark-suite/images/fp16.jpg b/docs/proposals/scenarios/llm-benchmark-suite/images/fp16.jpg new file mode 100644 index 00000000..1f92eeff Binary files /dev/null and b/docs/proposals/scenarios/llm-benchmark-suite/images/fp16.jpg differ diff --git a/docs/proposals/scenarios/llm-benchmark-suite/images/llm-edge-ianvs.png b/docs/proposals/scenarios/llm-benchmark-suite/images/llm-edge-ianvs.png new file mode 100644 index 00000000..10126f3e Binary files /dev/null and b/docs/proposals/scenarios/llm-benchmark-suite/images/llm-edge-ianvs.png differ diff --git a/docs/proposals/scenarios/llm-benchmark-suite/images/prune.jpg b/docs/proposals/scenarios/llm-benchmark-suite/images/prune.jpg new file mode 100644 index 00000000..3ef388b1 Binary files /dev/null and b/docs/proposals/scenarios/llm-benchmark-suite/images/prune.jpg differ diff --git a/docs/proposals/scenarios/llm-benchmark-suite/images/quant.jpg b/docs/proposals/scenarios/llm-benchmark-suite/images/quant.jpg new file mode 100644 index 00000000..d16742c4 Binary files /dev/null and b/docs/proposals/scenarios/llm-benchmark-suite/images/quant.jpg differ diff --git a/docs/proposals/scenarios/llm-benchmark-suite/llm-edge-benchmark-suite.md b/docs/proposals/scenarios/llm-benchmark-suite/llm-edge-benchmark-suite.md new file mode 100644 index 00000000..853cdcb3 --- /dev/null +++ b/docs/proposals/scenarios/llm-benchmark-suite/llm-edge-benchmark-suite.md @@ -0,0 +1,262 @@ +# Large Language Model Edge Benchmark Suite: Implementation on KubeEdge-lanvs + + +## 1 Motivation + +Due to model size and data scale, LLMs are often trained in the cloud. At the same time, since the use of large language models often involves business secrets or user privacy, deploying LLMs on **edge devices** has gradually become a research hotspot. LLMs quantization technology is enabling LLMs **edge inference**. However, the limited resources of **edge devices** have an impact on the inference latency and accuracy of cloud-trained LLMs. Ianvs aims to leverage container resource management capabilities and edge-cloud collaboration capabilities to perform **edge-side** deployment benchmarking for cloud-trained LLMs. + +## 2 Goals + +The primary goal is to develop a benchmarking suite for Large Language Models (LLMs) on **edge devices** using the KubeEdge-Ianvs platform. 
This suite will enable thorough testing and validation of LLMs, focusing on performance, accuracy, and resource utilization on **edge devices**.
+
+### 2.1 Flexibility and Scalability
+Ensure the Test Environment Manager can flexibly handle multiple datasets and performance metrics to cater to diverse testing requirements and scenarios.
+
+### 2.2 Comprehensive Algorithm Evaluation
+Cover testing across multiple machine learning algorithm paradigms, including pre-training, fine-tuning, self-supervised learning, and multi-task learning.
+
+### 2.3 Automation and Continuous Integration
+Integrate CI/CD processes to automate the management and continual updating of test environments and test cases.
+
+## 3 Proposal
+
+The proposal includes developing a benchmark suite that utilizes Ianvs to evaluate the performance of LLMs under various **edge-cloud** configurations. This will include:
+
+### 3.1 Dataset Integration And Accuracy Evaluation
+Integrate widely-used benchmark datasets to evaluate the performance of LLMs on **edge devices** across various tasks and domains. Some key datasets to consider:
+
+1. MMLU (Measuring Massive Multitask Language Understanding):
+- A comprehensive English benchmark for evaluating the knowledge and reasoning abilities of language models across 57 disciplines, from humanities and social sciences to STEM fields.
+- Covers topics ranging from elementary to advanced levels.
+- Assesses a model's general intelligence and cognitive capabilities.
+- Dataset URL: https://github.com/hendrycks/test
+
+2. CMMLU (Measuring Massive Multitask Language Understanding in Chinese):
+- A Chinese-specific benchmark for evaluating language models' knowledge and reasoning abilities in the Chinese language context.
+- Includes 67 topics, from basic disciplines to advanced professional levels, with many tasks having China-specific answers.
+- Covers natural sciences, humanities, social sciences, and daily life knowledge like Chinese driving rules.
+- A fully Chinese-centric testing benchmark for assessing a model's cognitive intelligence in Chinese scenarios.
+- Dataset URL: https://github.com/haonan-li/CMMLU
+
+### 3.2 Algorithm Testing
+
+**Objective**: Conduct inference tests on different LLMs, assessing their performance across a variety of metrics and scenarios.
+
+Below is a brief introduction to the Qwen, LLaMA, and ChatGLM model families under consideration:
+
+1. Qwen-7B, Qwen-13B, Qwen-30B, Qwen-65B
+- A series of Chinese-specific language models with varying parameter sizes (7B, 13B, 30B, 65B).
+- Trained on large-scale Chinese corpora, aiming to provide high-quality language understanding and generation capabilities in Chinese.
+- Can be used for various natural language processing tasks, such as question answering, text summarization, and sentiment analysis.
+- Offers a range of model sizes to accommodate different computational resources and performance requirements.
+
+2. LLaMA-7B, LLaMA-13B, LLaMA-33B, LLaMA-65B
+- A collection of open-source language models developed by Facebook, with parameter sizes ranging from 7B to 65B.
+- Trained on a vast amount of text data, demonstrating strong performance in various language tasks.
+- Designed to be efficient and scalable, allowing for deployment in different environments.
+- Provides a foundation for researchers and developers to build upon and adapt for specific use cases.
+
+3. ChatGLM-6B, ChatGLM2-6B, ChatGLM2-130B
+- A series of conversational language models specifically designed for chatbot applications.
+- Trained on large-scale dialogue datasets to generate human-like responses in conversational contexts.
+- Offers models with 6B parameters (ChatGLM-6B and ChatGLM2-6B) and a larger 130B parameter model (ChatGLM2-130B) for more advanced conversational capabilities.
+- Can be used to build engaging and interactive chatbots for various domains, such as customer support, virtual assistants, and entertainment.
+
+### 3.3 Benchmarking LLMs on Edge Devices
+
+**Objective**: Evaluate the performance of LLMs on **edge devices** by measuring key metrics such as memory usage, CPU load, and bandwidth consumption. This benchmarking helps in understanding the resource requirements and limitations of deploying LLMs on **edge devices** with constrained resources. While the focus is on single-device performance, the insights gained can provide valuable reference points for designing efficient edge-cloud collaborative inference systems in the future.
+
+#### 3.3.1 Inference Speed
+- Measure the time taken for the LLM to generate responses on **edge devices** for various input lengths and types.
+
+- Compare inference speeds across different model sizes (e.g., 7B vs. 13B vs. 65B) and hardware configurations (e.g., CPU vs. GPU, different memory constraints).
+
+- Analyze the relationship between input length and both total and prefill latency.
+
+- Evaluate the impact of optimization techniques (such as INT8 quantization and sparsity methods) on both latency metrics.
+
+- Consider measuring token generation speed (tokens per second) during the decoding phase to complement the latency metrics:
+
+1. Total Latency
+   - Definition: The total time taken from receiving the input to generating the complete output.
+   - Measurement: Record the time from when the input is sent to the model until the final token is generated.
+   - Importance: Represents the overall responsiveness of the system.
+2. Prefill Latency
+   - Definition: The time taken to process the initial input (prompt) before generating the first output token.
+   - Measurement: Record the time from when the input is sent to the model until it is ready to generate the first token.
+   - Importance: Crucial for understanding the model's initial response time, which is especially important for interactive applications.
+
+#### 3.3.2 Resource Utilization
+
+Efficient resource utilization is critical for running LLMs on edge devices with limited computational capabilities. We'll focus on the following metrics:
+
+1. **Throughput**
+- Definition: The number of tokens or inferences the model can process per unit of time.
+- Measurement: Calculate tokens generated per second (TPS) or inferences per second (IPS).
+- Importance: Indicates the model's efficiency in handling multiple requests or generating longer outputs.
+2. **Memory Usage**
+- Peak Memory Consumption: The maximum amount of memory used during inference.
+- Average Memory Usage: The average memory consumption over the course of inference.
+- Memory Efficiency: Ratio of output generated to memory used.
+3. **CPU/GPU Utilization**
+- Average CPU/GPU Usage: The mean percentage of CPU/GPU utilized during inference.
+- CPU/GPU Usage Pattern: Analyze how CPU usage varies during different phases of inference (e.g., prefill vs. decoding).
+
+## 4 Design Details
+
+The architecture of this proposal is shown in the figure below.
We leverage the existing TestEnvManager, TestCaseController and StoryManager in Ianvs.
+
+![Architecture Diagram](./images/llm-ianvs.png)
+1. **TestEnvManager**: add MMLU and CMMLU as LLM benchmarks, and Accuracy, Latency, Throughput, and Bandwidth as metrics.
+
+2. **TestCaseController**: incorporate INT8 quantization, FP16 mixed precision, and sparsity methods.
+
+3. **StoryManager**: show the Leaderboard and Test Report to users.
+
+### 4.1 Opencompass Dataset Integration
+
+The Opencompass dataset provides a comprehensive set of benchmarks for evaluating the performance of various language models. Integrating this dataset will enhance the evaluation capabilities of Ianvs by providing standardized and recognized benchmarks for LLMs.
+
+```
+├── __init__.py
+├── dataset
+│   ├── __init__.py
+│   └── dataset.py
+└── testenv
+    ├── __init__.py
+    └── testenv.py
+```
+To integrate the Opencompass dataset, `dataset/dataset.py` loads its benchmark definitions:
+
+```
+from mmengine.config import read_base
+
+with read_base():
+    from .datasets.siqa.siqa_gen import siqa_datasets
+    from .datasets.winograd.winograd_ppl import winograd_datasets
+
+datasets = [*siqa_datasets, *winograd_datasets]
+```
+
+### 4.2 INT8 Quantization And Mixed Precision
+
+INT8 quantization and FP16 mixed precision are techniques used to optimize the performance of machine learning models:
+
+- INT8 Quantization: Reduces the model size and increases inference speed by converting weights and activations from 32-bit floating point to 8-bit integer.
+
+![quantization](images/quant.jpg)
+
+- FP16 Mixed Precision: Uses 16-bit floating point representation for some operations while keeping others in 32-bit, balancing the trade-off between speed and precision.
+![fp16](./images/fp16.jpg)
+
+### 4.3 Pruning Method to Get a Sparse LLM
+
+Pruning is a technique used to reduce the number of parameters in a model, thereby increasing efficiency:
+
+It selectively removes weights in the neural network that have little impact on the final output.
+
+This results in a sparse model that requires less computation and memory, improving performance on both GPU and CPU.
+
+![prune](./images/prune.jpg)
+
+### 4.4 GPU and CPU Env
+
+To provide comprehensive benchmarking, it is essential to test models in both GPU and CPU environments:
+
+- **GPU Environment**: Benchmarks the performance of models when executed on GPUs, which are optimized for parallel processing and commonly used for training and inference of large models. The setup includes NVIDIA drivers, Docker, CUDA, and cuDNN.
+- **CPU Environment**: Benchmarks the performance of models on CPUs, which are more commonly available and used in various deployment scenarios.
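+
+On the algorithm side, a test algorithm can honor this switch with a small helper. The snippet below is a sketch that assumes PyTorch as the inference runtime; Ianvs itself only parses and stores the flag, as the core code that follows shows:
+
+```python
+import torch
+
+def select_device(use_gpu: bool) -> torch.device:
+    # Fall back to CPU when CUDA is unavailable, so the same benchmark
+    # configuration can run in both environments.
+    if use_gpu and torch.cuda.is_available():
+        return torch.device("cuda")
+    return torch.device("cpu")
+```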
+
+`core/testenvmanager/testenv/testenv.py`
+```py
+def __init__(self, config):
+    self.model_eval = {
+        "model_metric": {
+            "mode": "",
+            "name": "",
+            "url": "",
+        },
+        "threshold": 0.9,
+        "operator": ">"
+    }
+    self.metrics = []
+    self.incremental_rounds = 2
+    self.dataset = None
+    self.use_gpu = False  # default false
+    self._parse_config(config)
+```
+
+```py
+def _parse_config(self, config):
+    config_dict = config[str.lower(TestEnv.__name__)]
+    for k, v in config_dict.items():
+        if k == str.lower(Dataset.__name__):
+            self.dataset = Dataset(v)
+        elif k == 'use_gpu':
+            self.use_gpu = bool(v)  # make sure use_gpu is parsed as a bool
+        else:
+            if k in self.__dict__:
+                self.__dict__[k] = v
+    self._check_fields()
+```
+
+`testenv.yaml`
+
+```yaml
+testenv:
+  use_gpu: true # or false
+```
+
+### 4.5 Building Test Cases (build_testcases Method)
+
+This method accepts a test environment (`test_env`) and a configuration (`test_object`) containing multiple test objects (e.g., different LLMs like Qwen, Llama, ChatGLM, Baichuan).
+
+1. **Parameter Parsing**: First, it parses the `test_object` parameter to extract the configuration containing algorithm information.
+2. **Algorithm Configuration Parsing (_parse_algorithms_config Method)**: It iterates through the list of algorithm configurations and creates an `Algorithm` instance for each algorithm.
+3. **Instantiating Test Cases**: For each `Algorithm` instance, the method creates a `TestCase` instance and adds it to the controller's `test_cases` list.
+
+### 4.6 Algorithm Configuration Parsing (_parse_algorithms_config Method)
+
+This method is responsible for reading and parsing algorithm configuration files and creating `Algorithm` objects.
+
+1. **Configuration File Reading**: It checks if the configuration file exists locally.
+2. **Configuration Conversion**: Uses the `utils.yaml2dict` method to convert the configuration file into a dictionary.
+3. **Algorithm Instantiation**: Creates an `Algorithm` instance for each configuration.
+
+### 4.7 Running Test Cases (run_testcases Method)
+
+This method is responsible for executing all configured test cases.
+
+1. **Execution**: Iterates through the `test_cases` list, calling each `TestCase`'s `run` method, which involves the startup and warm-up of the large model.
+2. **Error Handling**: Captures any exceptions that may occur during execution and logs the error information.
+3. **Result Collection**: Collects the results and execution time for each test case.
+
+## 5 Road Map
+| Time | Activity |
+|-------|-------------|
+| July | Familiarize with the Ianvs platform and prepare the development environment; design the project interface and write a detailed user guide. |
+| August | Develop functionalities to support various dataset formats and model invocation visualization; test various LLMs and generate detailed benchmarking reports. |
+| September | Integrate the benchmarking suite into the model training process for real-time evaluation and optimization.
| + + +## Reference + +- [KubeEdge-Ianvs](https://github.com/kubeedge/ianvs) +- [KubeEdge-Ianvs Benchmark Sample](https://github.com/kubeedge/ianvs/blob/main/examples/robot-cityscapes-synthia/lifelong_learning_bench/semantic-segmentation/README.md) +- [how-to-build-simulation-env](https://github.com/kubeedge/ianvs/blob/main/docs/guides/how-to-build-simulation-env.md) +- [Example LLMs Benchmark List](https://github.com/terryyz/llm-benchmark) diff --git a/docs/proposals/test-reports/Smart Coding benchmark suite Proposal_zh.md b/docs/proposals/test-reports/Smart Coding benchmark suite Proposal_zh.md new file mode 100644 index 00000000..3313acb7 --- /dev/null +++ b/docs/proposals/test-reports/Smart Coding benchmark suite Proposal_zh.md @@ -0,0 +1,129 @@ +# 背景 +大型语言模型(LLM)在代码生成、自动编程、代码分析等任务中展现出了强大的能力,但这些模型通常是在通用代码数据上训练的,往往不能充分利用实际场景中软件工程师的协作和反馈。为了构建更加智能高效的代码生态,需要建立协作代码数据集和评测基准,促进LLM与软件工程师的紧密协作。本项目旨在基于开源边缘计算框架KubeEdge-Ianvs构建LLM协作代码智能体对齐数据集和评测基准。该数据集将包括软件工程师在开发过程中的行为轨迹、反馈和迭代过程,以及相关的代码版本和注释信息。通过这些数据,我们将设计评测指标和基准来衡量LLM在代码生成、推荐和分析等任务中的表现,促进LLM与软件工程师之间的协作。 + +在当今的软件开发实践中,大型语言模型(LLM)在代码生成、推荐和分析等领域展现出巨大的潜力。但现有模型通常是在通用代码库上训练的,缺乏针对特定软件工程任务的优化,因此建立融合软件工程师实际工作经验与反馈的特定数据集与评估基准,对提升这些模型在实际编程环境中的应用效果至关重要。 +# Goals +1. 为大模型构建协作代码智能数据集 +2. 为大模型构建代码协同智能评估基准测试 +3. 将数据集和智能评估基准集成到KubeEdge-Ianvs框架中 +# Proposal +## 构建数据集 + +1. **开发过程中的行为轨迹**: +记录软件工程师在开发过程中执行的操作。这些操作可能包括代码编写、代码提交、代码合并、代码审查、代码重构等。 +具体的行为数据可能包括使用的开发工具、编写的代码片段、提交记录、审查意见等。 +2. **反馈及迭代**: +收集研发工程师在开发过程中对代码的反馈和迭代记录,这些反馈可能包括代码审查意见、测试结果、错误报告、改进建议等。 +记录反馈时间、反馈内容、对应的代码修改、最终解决方案。 +3. **代码版本及注释**: +记录每个版本的代码,以及各个版本之间的差异,包括新增、修改、删除的代码。 +包括详细的代码注释和文档,以了解代码的功能、用途、设计思想。 + +## 代码大模型语言评估基准 +1. 评测基准应包括代码生成、推荐和分析等常见的代码智能体任务。 +2. 评测指标应涵盖功能性、可靠性、可解释性等多个维度,并与软件工程师的反馈和需求相匹配。 +3. 评测基准应能够评估LLMs在协作式代码智能体任务上的性能,并为进一步的算法优化提供依据。 +## 将数据集和评测基准集成到KubeEdge-Ianvs框架中 + +1. 将数据集和评测基准作为Ianvs框架的一部分,并提供良好的可扩展性和可集成性。 +2. 确保数据集和评测基准能够在Ianvs框架的边缘设备上高效运行,并与Ianvs的其他功能模块无缝协作. + +# Design Details +## Data collection +1. GitHub: 从GitHub上收集各种编程语言的开源项目代码。通过GitHub API或手动检索。 +2. GitHub Issues: 收集开发者提交的问题报告,包括Bug报告、功能请求和讨论。 +3. Pull Requests: 收集开发者提交的拉取请求,包括功能实现和代码修改的审查历史。 +3. Commit Logs: 提取项目的提交日志,包括代码的每次变更、提交者信息和提交时间。 +4. 
+
+The `examples/smart_coding` directory structure:
+```
+smart_coding
+└── smart_coding_learning_bench
+    └── smart_co
+        ├── benchmarkingjob.yaml
+        ├── testalgorithms
+        │   └── gen
+        │       ├── basemodel.py
+        │       ├── gen_algorithm.yaml
+        │       ├── op_eval.py
+        └── testenv
+            ├── acc.py
+            └── testenv.yaml
+```
+Part of the comment test set is formatted as follows:
+```
+{"description": "Add detailed comments to the given code/function.", "code_snippet": "def calculate_area(length, width):\n return length * width"}
+{"description": "Add detailed comments to the given Python function.",
+ "code_snippet": "def calculate_area(length, width):\n return length * width",
+ "annotations": [
+    {
+      "line_number": 1,
+      "annotation": "Define a function calculate_area that takes two parameters: length and width."
+    },
+    {
+      "line_number": 2,
+      "annotation": "Return the product of length and width, i.e., the area of the rectangle."
+    }
+ ]}
+```
+
+This project is mainly responsible for the test suite of code LLMs, which centers on the comment and issue tasks in the requirements above. For the comment evaluation criteria, different projects use different fields; scoring is done separately on the logic, the accuracy, and the formatting of the code as a whole.
+
+For the dataset and interface definition, one open question is whether commenting a piece of code or a function should first comment the whole function/code block and then individual lines, or, when a whole block is given, simply produce a single comment for the block as a whole. If the user then wants a comment on one specific line, they can ask a follow-up question targeting that line.
+
+Part of the issue test set is formatted as follows:
+```
+{
+  "issue_id": "issue ID",
+  "repository_name": "GitHub repository name",
+  "label": "type and severity",
+  "issue_description": "title and description of the issue",
+  "code_before": {
+    "file_path": "path of the code file",
+    "code_snippet": "original code before the issue was fixed"
+  },
+  "code_after": {
+    "file_path": "path of the code file",
+    "code_snippet": "code after the fix, including the solution to the issue"
+  },
+  "pull_request_id": "ID of the corresponding PR",
+  "pull_request_description": "description of the PR, explaining the changes made and why"
+}
+```
+The data format for the issue part still needs further discussion.
+
+## Benchmark Format Example
+[Benchmark format referenced from Chen Mengzhuo's proposal](https://github.com/IcyFeather233/ianvs/blob/main/docs/proposals/scenarios/llm-benchmarks/llm-benchmarks.md)
+
+
+### Detailed Steps
+Early stage: collect classic Python projects and add comments (at the granularity of key lines and segments) to the core code lines or functions, so as to build the dataset, while reading related papers in parallel.
+
+Middle stage: building on the papers read earlier about collecting issue datasets with large models, start organizing issue-centric data, plus app projects whose main development language is Python.
+
+Final stage: with the dataset in place, design the test evaluation metrics, run tests against the dataset, and write unit/integration tests so that testing drives code correctness.
+
+## Schedule
+| Schedule | Task |
+|------------|--------------------------------------------------------------------------|
+| July 13 - early August | Read related papers; collect datasets; read and understand the open-source projects; write the product and requirements documents, the latter covering each project's key points such as comments, issues, and PRs |
+| Mid-August - early September | Organize the collected datasets, widen the collection scope to enlarge them, and design LLM evaluation metrics for the dataset |
+| Mid-September - end of September | Write unit/integration tests; testing drives code correctness |
diff --git a/examples/cifar100/fci_ssl/fed_ci_match/algorithm/FedCiMatch.py b/examples/cifar100/fci_ssl/fed_ci_match/algorithm/FedCiMatch.py
new file mode 100644
index 00000000..cc7b1ffc
--- /dev/null
+++ b/examples/cifar100/fci_ssl/fed_ci_match/algorithm/FedCiMatch.py
@@ -0,0 +1,418 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy +import logging +import tensorflow as tf +import keras +import numpy as np +from model import resnet10, resnet18 +from agumentation import * +from data_prepocessor import * + + +def get_one_hot(target, num_classes): + y = tf.one_hot(target, depth=num_classes) + if len(y.shape) == 3: + y = tf.squeeze(y, axis=1) + return y + + +class FedCiMatch: + + def __init__( + self, num_classes, batch_size, epochs, learning_rate, memory_size + ) -> None: + self.num_classes = num_classes + self.batch_size = batch_size + self.epochs = epochs + self.learning_rate = learning_rate + self.memory_size = memory_size + self.task_size = None + self.warm_up_round = 4 + self.accept_threshold = 0.95 + self.old_task_id = -1 + + self.classifier = None + self.feature_extractor = self.build_feature_extractor() + + self.fe_weights_length = 0 + self.labeled_train_loader = None + self.unlabeled_train_loader = None + self.labeled_train_set = None + self.unlabeled_train_set = None + dataset_name = "cifar100" + self.data_preprocessor = Dataset_Preprocessor( + dataset_name, Weak_Augment(dataset_name), RandAugment(dataset_name) + ) + self.last_classes = None + self.current_classes = None + self.learned_classes = [] + self.learned_classes_num = 0 + self.exemplar_set = [] + self.seen_classes = [] + self.best_old_model = None + print(f"self epoch is {self.epochs}") + + def build_feature_extractor(self): + feature_extractor = resnet18() + + feature_extractor.build(input_shape=(None, 32, 32, 3)) + feature_extractor.call(keras.Input(shape=(32, 32, 3))) + return feature_extractor + + def build_classifier(self): + if self.classifier != None: + new_classifier = keras.Sequential( + [ + keras.layers.Dense( + self.num_classes, kernel_initializer="lecun_normal" + ) + ] + ) + new_classifier.build( + input_shape=(None, self.feature_extractor.layers[-2].output_shape[-1]) + ) + new_weights = new_classifier.get_weights() + old_weights = self.classifier.get_weights() + # weight + new_weights[0][0 : old_weights[0].shape[0], 0 : old_weights[0].shape[1]] = ( + old_weights[0] + ) + # bias + new_weights[1][0 : old_weights[1].shape[0]] = old_weights[1] + new_classifier.set_weights(new_weights) + self.classifier = new_classifier + else: + logging.info( + f"input shape is {self.feature_extractor.layers[-2].output_shape[-1]}" + ) + self.classifier = keras.Sequential( + [ + keras.layers.Dense( + self.num_classes, kernel_initializer="lecun_normal" + ) + ] + ) + self.classifier.build( + input_shape=(None, self.feature_extractor.layers[-2].output_shape[-1]) + ) + + logging.info(f"finish ! 
initialize classifier {self.classifier.summary()}") + + def get_weights(self): + weights = [] + fe_weights = self.feature_extractor.get_weights() + self.fe_weights_length = len(fe_weights) + clf_weights = self.classifier.get_weights() + weights.extend(fe_weights) + weights.extend(clf_weights) + return weights + + def set_weights(self, weights): + fe_weights = weights[: self.fe_weights_length] + clf_weights = weights[self.fe_weights_length :] + self.feature_extractor.set_weights(fe_weights) + self.classifier.set_weights(clf_weights) + + def model_call(self, x, training=False): + x = self.feature_extractor(x, training=training) + x = self.classifier(x, training=training) + # x = tf.nn.softmax(x) + return x + + def before_train(self, task_id, round, train_data, task_size): + if self.task_size is None: + self.task_size = task_size + is_new_task = task_id != self.old_task_id + self.is_new_task = is_new_task + if is_new_task: + self.best_old_model = ( + (self.feature_extractor, self.classifier) + if self.classifier is not None + else None + ) + self.is_new_task = True + self.old_task_id = task_id + self.num_classes = self.task_size * (task_id + 1) + logging.info(f"num_classes: {self.num_classes}") + if self.current_classes is not None: + self.last_classes = self.current_classes + # self.build_classifier() + self.current_classes = np.unique(train_data["label_y"]).tolist() + logging.info(f"current_classes: {self.current_classes}") + + self.labeled_train_set = (train_data["label_x"], train_data["label_y"]) + self.unlabeled_train_set = ( + train_data["unlabel_x"], + train_data["unlabel_y"], + ) + logging.info( + f"self.labeled_train_set is None :{self.labeled_train_set is None}" + ) + logging.info( + f"self.unlabeled_train_set is None :{self.unlabeled_train_set is None}" + ) + self.labeled_train_loader, self.unlabeled_train_loader = self.get_train_loader() + + def get_data_size(self): + logging.info( + f"self.labeled_train_set is None :{self.labeled_train_set is None}" + ) + logging.info( + f"self.unlabeled_train_set is None :{self.unlabeled_train_set is None}" + ) + data_size = len(self.labeled_train_set[0]) + len(self.unlabeled_train_set[0]) + logging.info(f"data size: {data_size}") + return data_size + + def get_train_loader(self): + train_x = self.labeled_train_set[0] + train_y = self.labeled_train_set[1] + logging.info( + f"train_x shape: {train_x.shape} and train_y shape: {train_y.shape} and len of exemplar_set: {len(self.exemplar_set)}" + ) + if len(self.exemplar_set) != 0: + for exm_set in self.exemplar_set: + train_x = np.concatenate((train_x, exm_set[0]), axis=0) + label = np.array(exm_set[1]) + train_y = np.concatenate((train_y, label), axis=0) + logging.info( + f"train_x shape: {train_x.shape} and train_y shape: {train_y.shape}" + ) + + logging.info( + f"train_x shape: {train_x.shape} and train_y shape: {train_y.shape}" + ) + label_data_loader = self.data_preprocessor.preprocess_labeled_dataset( + train_x, train_y, self.batch_size + ) + unlabel_data_loader = None + if len(self.unlabeled_train_set[0]) > 0: + unlabel_data_loader = self.data_preprocessor.preprocess_unlabeled_dataset( + self.unlabeled_train_set[0], + self.unlabeled_train_set[1], + self.batch_size, + ) + logging.info( + f"unlabel_x shape: {self.unlabeled_train_set[0].shape} and unlabel_y shape: {self.unlabeled_train_set[1].shape}" + ) + return label_data_loader, unlabel_data_loader + + def build_exemplar(self): + if self.is_new_task and self.current_classes is not None: + self.last_classes = self.current_classes + 
self.learned_classes.extend(self.last_classes) + self.learned_classes_num += len(self.learned_classes) + m = int(self.memory_size / self.num_classes) + self.reduce_exemplar_set(m) + for cls in self.last_classes: + images = self.get_train_data(cls) + self.construct_exemplar_set(images, cls, m) + self.is_new_task = False + + def reduce_exemplar_set(self, m): + for i in range(len(self.exemplar_set)): + old_exemplar_data = self.exemplar_set[i][0][:m] + old_exemplar_label = self.exemplar_set[i][1][:m] + self.exemplar_set[i] = (old_exemplar_data, old_exemplar_label) + + def get_train_data(self, class_id): + images = [] + train_x = self.labeled_train_set[0] + train_y = self.labeled_train_set[1] + for i in range(len(train_x)): + if train_y[i] == class_id: + images.append(train_x[i]) + return images + + def construct_exemplar_set(self, images, class_id, m): + exemplar_data = [] + exemplar_label = [] + class_mean, fe_ouput = self.compute_exemplar_mean(images) + diff = tf.abs(fe_ouput - class_mean) + distance = [float(tf.reduce_sum(dis).numpy()) for dis in diff] + + sorted_index = np.argsort(distance).tolist() + if len(sorted_index) > m: + sorted_index = sorted_index[:m] + exemplar_data = [images[i] for i in sorted_index] + exemplar_label = [class_id] * len(exemplar_data) + self.exemplar_set.append((exemplar_data, exemplar_label)) + + + def compute_exemplar_mean(self, images): + images_data = ( + tf.data.Dataset.from_tensor_slices(images) + .batch(self.batch_size) + .map(lambda x: tf.cast(x, dtype=tf.float32) / 255.0) + ) + fe_output = self.feature_extractor.predict(images_data) + print("fe_output shape:", fe_output.shape) + class_mean = tf.reduce_mean(fe_output, axis=0) + return class_mean, fe_output + + def train(self, round): + # optimizer = keras.optimizers.SGD( + # learning_rate=self.learning_rate, momentum=0.9, weight_decay=0.0001 + # ) + optimizer = keras.optimizers.Adam( + learning_rate=self.learning_rate, weight_decay=0.0001 + ) + q = [] + logging.info(f"is new task: {self.is_new_task}") + if self.is_new_task: + self.build_classifier() + all_params = [] + all_params.extend(self.feature_extractor.trainable_variables) + all_params.extend(self.classifier.trainable_variables) + + for epoch in range(self.epochs): + # following code is for unsupervised learning + # for labeled_data, unlabeled_data in zip( + # self.labeled_train_loader, self.unlabeled_train_loader + # ): + for step, (labeled_x, labeled_y) in enumerate(self.labeled_train_loader): + with tf.GradientTape() as tape: + input = self.feature_extractor(inputs=labeled_x, training=True) + y_pred = self.classifier(inputs=input, training=True) + label_pred = tf.argmax(y_pred, axis=1) + label_pred = tf.cast(label_pred, dtype=tf.int32) + label_pred = tf.reshape(label_pred, labeled_y.shape) + correct = tf.reduce_sum( + tf.cast(tf.equal(label_pred, labeled_y), dtype=tf.int32) + ) + CE_loss = self.supervised_loss(labeled_x, labeled_y) + KD_loss = self.distil_loss(labeled_x, labeled_y) + supervised_loss = CE_loss + + # following code is for unsupervised learning + # if epoch > self.warm_up_round: + # unsupervised_loss = self.unsupervised_loss( + # weak_unlabeled_x, strong_unlabeled_x, unlabeled_x + # ) + # logging.info(f"unsupervised loss: {unsupervised_loss}") + # loss = 0.5 * supervised_loss + 0.5 * unsupervised_loss + # else: + # loss = supervised_loss + loss = CE_loss + KD_loss + logging.info( + f"epoch {epoch} loss: {loss} correct {correct} and total {labeled_x.shape[0]} class is {np.unique(labeled_y)}" + ) + grads = tape.gradient(loss, 
all_params) + optimizer.apply_gradients(zip(grads, all_params)) + + def caculate_pre_update(self): + q = [] + for images, _ in self.labeled_train_loader: + x = self.feature_extractor(images, training=False) + x = self.classifier(x, training=False) + x = tf.nn.sigmoid(x) + q.append(x) + logging.info(f"q shape: {len(q)}") + return q + + def supervised_loss(self, x, y): + input = x + input = self.feature_extractor(input, training=True) + y_pred = self.classifier(input, training=True) + target = get_one_hot(y, self.num_classes) + loss = keras.losses.categorical_crossentropy(target, y_pred, from_logits=True) + logging.info(f"loss shape: {loss.shape}") + loss = tf.reduce_mean(loss) + logging.info(f"CE loss: {loss}") + + return loss + + def distil_loss(self, x, y): + KD_loss = 0 + + if len(self.learned_classes) > 0 and self.best_old_model is not None: + g = self.feature_extractor(x, training=True) + g = self.classifier(g, training=True) + og = self.best_old_model[0](x, training=False) + og = self.best_old_model[1](og, training=False) + sigmoid_og = tf.nn.sigmoid(og) + sigmoid_g = tf.nn.sigmoid(g) + BCELoss = keras.losses.BinaryCrossentropy() + loss = [] + for y in self.learned_classes: + if y not in self.current_classes: + loss.append(BCELoss(sigmoid_og[:, y], sigmoid_g[:, y])) + KD_loss = tf.reduce_sum(loss) + logging.info(f"KD_loss: {KD_loss}") + return KD_loss + + def unsupervised_loss(self, weak_x, strong_x, x): + prob_on_wux = tf.nn.softmax( + self.classifier( + self.feature_extractor(weak_x, training=True), training=True + ) + ) + pseudo_mask = tf.cast( + tf.reduce_max(prob_on_wux, axis=1) > self.accept_threshold, tf.float32 + ) + pse_uy = tf.one_hot( + tf.argmax(prob_on_wux, axis=1), depth=self.num_classes + ).numpy() + prob_on_sux = tf.nn.softmax( + self.classifier( + self.feature_extractor(strong_x, training=True), training=True + ) + ) + loss = keras.losses.categorical_crossentropy(pse_uy, prob_on_sux) + loss = tf.reduce_mean(loss * pseudo_mask) + return loss + + def predict(self, x): + mean = np.array((0.5071, 0.4867, 0.4408), np.float32).reshape(1, 1, -1) + std = np.array((0.2675, 0.2565, 0.2761), np.float32).reshape(1, 1, -1) + x = (tf.cast(x, dtype=tf.float32) / 255.0 - mean) / std + pred = self.classifier(self.feature_extractor(x, training=False)) + prob = tf.nn.softmax(pred, axis=1) + pred = tf.argmax(prob, axis=1) + pred = tf.cast(pred, dtype=tf.int32) + return pred + + def icarl_predict(self, x): + mean = np.array((0.5071, 0.4867, 0.4408), np.float32).reshape(1, 1, -1) + std = np.array((0.2675, 0.2565, 0.2761), np.float32).reshape(1, 1, -1) + x = (tf.cast(x, dtype=tf.float32) / 255.0 - mean) / std + bs = x.shape[0] + print(x.shape) + exemplar_mean = [] + for exemplar in self.exemplar_set: + # features = [] + ex, _ = exemplar + ex = (tf.cast(ex, dtype=tf.float32) / 255.0 - mean) / std + feature = self.feature_extractor(ex, training=False) + feature = feature / tf.norm(feature) + mu_y = tf.reduce_mean(feature, axis=0) + mu_y = mu_y / tf.norm(mu_y) + exemplar_mean.append(mu_y) + means = tf.stack(exemplar_mean) # shape: (num_classes, feature_shape) + means = tf.stack([means] * bs) # shape: (bs, num_classes, feature_shape) + means = tf.transpose( + means, perm=[0, 2, 1] + ) # shape: (bs, feature_shape, num_classes) + feature = self.feature_extractor( + x, training=False + ) # shape (bs , feature_shape) + feature = feature / tf.norm(feature) + feature = tf.expand_dims(feature, axis=2) + feature = tf.tile(feature, [1, 1, self.num_classes]) + dists = tf.pow((feature - means), 2) + dists 
= tf.reduce_sum(dists, axis=1) # shape: (bs, num_classes) + preds = tf.argmin(dists, axis=1) # shape: (bs) + logging.info(f"preds : {preds}") + return preds diff --git a/examples/cifar100/fci_ssl/fed_ci_match/algorithm/aggregation.py b/examples/cifar100/fci_ssl/fed_ci_match/algorithm/aggregation.py new file mode 100644 index 00000000..46c57f19 --- /dev/null +++ b/examples/cifar100/fci_ssl/fed_ci_match/algorithm/aggregation.py @@ -0,0 +1,56 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import numpy as np +from sedna.algorithms.aggregation.aggregation import BaseAggregation +from sedna.common.class_factory import ClassType, ClassFactory + +@ClassFactory.register(ClassType.FL_AGG, "FedAvg") +class FedAvg(BaseAggregation, abc.ABC): + def __init__(self): + super(FedAvg, self).__init__() + + def aggregate(self, clients): + """ + Calculate the average weight according to the number of samples + + Parameters + ---------- + clients: List + All clients in federated learning job + + Returns + ------- + update_weights : Array-like + final weights use to update model layer + """ + + + print("aggregation....") + if not len(clients): + return self.weights + self.total_size = sum([c.num_samples for c in clients]) + old_weight = [np.zeros(np.array(c).shape) for c in + next(iter(clients)).weights] + updates = [] + for inx, row in enumerate(old_weight): + for c in clients: + row += (np.array(c.weights[inx]) * c.num_samples + / self.total_size) + updates.append(row.tolist()) + + print("finish aggregation....") + return [np.array(layer) for layer in updates] diff --git a/examples/cifar100/fci_ssl/fed_ci_match/algorithm/agumentation.py b/examples/cifar100/fci_ssl/fed_ci_match/algorithm/agumentation.py new file mode 100644 index 00000000..89d1bef2 --- /dev/null +++ b/examples/cifar100/fci_ssl/fed_ci_match/algorithm/agumentation.py @@ -0,0 +1,230 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import random +import tensorflow as tf +from PIL import Image, ImageEnhance, ImageOps + + +""" +Reference: https://github.com/heartInsert/randaugment +""" + + +class Rand_Augment: + def __init__(self, Numbers=None, max_Magnitude=None): + self.transforms = [ + "autocontrast", + "equalize", + "rotate", + "solarize", + "color", + "posterize", + "contrast", + "brightness", + "sharpness", + "shearX", + "shearY", + "translateX", + "translateY", + ] + if Numbers is None: + self.Numbers = len(self.transforms) // 2 + else: + self.Numbers = Numbers + if max_Magnitude is None: + self.max_Magnitude = 10 + else: + self.max_Magnitude = max_Magnitude + fillcolor = 128 + self.ranges = { + # these Magnitude range , you must test it yourself , see what will happen after these operation , + # it is no need to obey the value in autoaugment.py + "shearX": np.linspace(0, 0.3, 10), + "shearY": np.linspace(0, 0.3, 10), + "translateX": np.linspace(0, 0.2, 10), + "translateY": np.linspace(0, 0.2, 10), + "rotate": np.linspace(0, 360, 10), + "color": np.linspace(0.0, 0.9, 10), + "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int), + "solarize": np.linspace(256, 231, 10), + "contrast": np.linspace(0.0, 0.5, 10), + "sharpness": np.linspace(0.0, 0.9, 10), + "brightness": np.linspace(0.0, 0.3, 10), + "autocontrast": [0] * 10, + "equalize": [0] * 10, + "invert": [0] * 10, + } + self.func = { + "shearX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), + Image.BICUBIC, + fill=fillcolor, + ), + "shearY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), + Image.BICUBIC, + fill=fillcolor, + ), + "translateX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), + fill=fillcolor, + ), + "translateY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), + fill=fillcolor, + ), + "rotate": lambda img, magnitude: self.rotate_with_fill(img, magnitude), + # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance( + 1 + magnitude * random.choice([-1, 1]) + ), + "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), + "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), + "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( + 1 + magnitude * random.choice([-1, 1]) + ), + "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( + 1 + magnitude * random.choice([-1, 1]) + ), + "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( + 1 + magnitude * random.choice([-1, 1]) + ), + "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), + "equalize": lambda img, magnitude: img, + "invert": lambda img, magnitude: ImageOps.invert(img), + } + + def rand_augment(self): + """Generate a set of distortions. + Args: + N: Number of augmentation transformations to apply sequentially. N is len(transforms)/2 will be best + M: Max_Magnitude for all the transformations. 
should be <= self.max_Magnitude + """ + + M = np.random.randint(0, self.max_Magnitude, self.Numbers) + + sampled_ops = np.random.choice(self.transforms, self.Numbers) + return [(op, Magnitude) for (op, Magnitude) in zip(sampled_ops, M)] + + def __call__(self, image): + operations = self.rand_augment() + for op_name, M in operations: + operation = self.func[op_name] + mag = self.ranges[op_name][M] + image = operation(image, mag) + return image + + def rotate_with_fill(self, img, magnitude): + # I don't know why rotate must change to RGBA , it is copy from Autoaugment - pytorch + rot = img.convert("RGBA").rotate(magnitude) + return Image.composite( + rot, Image.new("RGBA", rot.size, (128,) * 4), rot + ).convert(img.mode) + + def test_single_operation(self, image, op_name, M=-1): + """ + :param image: image + :param op_name: operation name in self.transforms + :param M: -1 stands for the max Magnitude in there operation + :return: + """ + operation = self.func[op_name] + mag = self.ranges[op_name][M] + image = operation(image, mag) + return image + + +class Base_Augment: + def __init__(self, dataset_name: str) -> None: + self.dataset_name = dataset_name + + def __call__(self, images): + return images + + +class Weak_Augment(Base_Augment): + def __init__(self, dataset_name: str) -> None: + super().__init__(dataset_name) + self.augment_impl = self.augment_for_cifar + + def augment_mirror(self, x): + new_images = x.copy() + indices = np.arange(len(new_images)).tolist() + sampled = random.sample( + indices, int(round(0.5 * len(indices))) + ) # flip horizontally 50% + new_images[sampled] = np.fliplr(new_images[sampled]) + return new_images # random shift + + def augment_shift(self, x, w): + y = tf.pad(x, [[0] * 2, [w] * 2, [w] * 2, [0] * 2], mode="REFLECT") + return tf.image.random_crop(y, tf.shape(x)) + + def augment_for_cifar(self, images: np.ndarray): + return self.augment_shift(self.augment_mirror(images), 4) + + def __call__(self, images: np.ndarray): + return self.augment_impl(images) + + +class Strong_Augment(Base_Augment): + def __init__(self, dataset_name: str) -> None: + super().__init__(dataset_name) + + def augment_mirror(self, x): + new_images = x.copy() + indices = np.arange(len(new_images)).tolist() + sampled = random.sample( + indices, int(round(0.5 * len(indices))) + ) # flip horizontally 50% + new_images[sampled] = np.fliplr(new_images[sampled]) + return new_images # random shift + + def augment_shift_mnist(self, x, w): + y = tf.pad(x, [[0] * 2, [w] * 2, [w] * 2], mode="REFLECT") + return tf.image.random_crop(y, tf.shape(x)) + + def __call__(self, images: np.ndarray): + return self.augment_shift_mnist(self.augment_mirror(images), 4) + + +class RandAugment(Base_Augment): + def __init__(self, dataset_name: str) -> None: + super().__init__(dataset_name) + self.rand_augment = Rand_Augment() + self.input_shape = (32, 32, 3) + + def __call__(self, images): + print("images:", images.shape) + + return np.array( + [ + np.array( + self.rand_augment( + Image.fromarray(np.reshape(img, self.input_shape)) + ) + ) + for img in images + ] + ) diff --git a/examples/cifar100/fci_ssl/fed_ci_match/algorithm/algorithm.yaml b/examples/cifar100/fci_ssl/fed_ci_match/algorithm/algorithm.yaml new file mode 100644 index 00000000..0701a660 --- /dev/null +++ b/examples/cifar100/fci_ssl/fed_ci_match/algorithm/algorithm.yaml @@ -0,0 +1,28 @@ +algorithm: + paradigm_type: "federatedclassincrementallearning" + fl_data_setting: + train_ratio: 1.0 + splitting_method: "default" + label_data_ratio: 1.0 + 
data_partition: "iid" + non_iid_ratio: "0.6" + initial_model_url: "/home/wyd/ianvs/project/init_model/cnn.pb" + + modules: + - type: "basemodel" + name: "FediCarl-Client" + url: "./examples/cifar100/fci_ssl/fed_ci_match/algorithm/basemodel.py" + hyperparameters: + - batch_size: + values: + - 128 + - learning_rate: + values: + - 0.001 + - epochs: + values: + - 1 + - type: "aggregation" + name: "FedAvg" + url: "./examples/cifar100/fci_ssl/fed_ci_match/algorithm/aggregation.py" + diff --git a/examples/cifar100/fci_ssl/fed_ci_match/algorithm/basemodel.py b/examples/cifar100/fci_ssl/fed_ci_match/algorithm/basemodel.py new file mode 100644 index 00000000..e351ddf0 --- /dev/null +++ b/examples/cifar100/fci_ssl/fed_ci_match/algorithm/basemodel.py @@ -0,0 +1,79 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys + +sys.path.append(".") +sys.path.append("..") +import os +import numpy as np +import keras +import tensorflow as tf +from sedna.common.class_factory import ClassType, ClassFactory +from model import resnet10 +from FedCiMatch import FedCiMatch +import logging + +os.environ["BACKEND_TYPE"] = "KERAS" +__all__ = ["BaseModel"] +logging.getLogger().setLevel(logging.INFO) + + +@ClassFactory.register(ClassType.GENERAL, alias="FediCarl-Client") +class BaseModel: + def __init__(self, **kwargs) -> None: + self.kwargs = kwargs + self.learning_rate = kwargs.get("learning_rate", 0.001) + self.epochs = kwargs.get("epochs", 1) + self.batch_size = kwargs.get("batch_size", 32) + self.task_size = kwargs.get("task_size", 2) + self.memory_size = kwargs.get("memory_size", 2000) + self.num_classes = 50 # the number of class for the first task + self.FedCiMatch = FedCiMatch( + self.num_classes, + self.batch_size, + self.epochs, + self.learning_rate, + self.memory_size, + ) + self.class_learned = 0 + + def get_weights(self): + print("get weights") + return self.FedCiMatch.get_weights() + + def set_weights(self, weights): + print("set weights") + self.FedCiMatch.set_weights(weights) + + def train(self, train_data, val_data, **kwargs): + task_id = kwargs.get("task_id", 0) + round = kwargs.get("round", 1) + task_size = kwargs.get("task_size", self.task_size) + logging.info(f"in train: {round} task_id: {task_id}") + self.class_learned += self.task_size + self.FedCiMatch.before_train(task_id, round, train_data, task_size) + self.FedCiMatch.train(round) + logging.info(f"update example memory") + self.FedCiMatch.build_exemplar() + return {"num_samples": self.FedCiMatch.get_data_size(), "task_id": task_id} + + def predict(self, data_files, **kwargs): + result = {} + for data in data_files: + x = np.load(data) + logging.info(f"predicting {x.shape}") + res = self.FedCiMatch.predict(x) + result[data] = res.numpy() + print("finish predict") + return result diff --git a/examples/cifar100/fci_ssl/fed_ci_match/algorithm/data_prepocessor.py b/examples/cifar100/fci_ssl/fed_ci_match/algorithm/data_prepocessor.py new file mode 100644 index 00000000..004143ac --- /dev/null +++ 
b/examples/cifar100/fci_ssl/fed_ci_match/algorithm/data_prepocessor.py @@ -0,0 +1,68 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +import numpy as np +from agumentation import Base_Augment + + +class Dataset_Preprocessor: + def __init__( + self, + dataset_name: str, + weak_augment_helper: Base_Augment, + strong_augment_helper: Base_Augment, + ) -> None: + self.weak_augment_helper = weak_augment_helper + self.strong_augment_helper = strong_augment_helper + self.mean = 0.0 + self.std = 1.0 + + if dataset_name == "cifar100": + self.mean = np.array((0.5071, 0.4867, 0.4408), np.float32).reshape(1, 1, -1) + self.std = np.array((0.2675, 0.2565, 0.2761), np.float32).reshape(1, 1, -1) + print(f"mean: {self.mean}, std: {self.std}") + + def preprocess_labeled_dataset(self, x, y, batch_size): + # wx = self.weak_augment_helper(x) + return ( + tf.data.Dataset.from_tensor_slices((x, y)) + .shuffle(100000) + .map( + lambda x, y: ( + (tf.cast(x, dtype=tf.float32) / 255.0 - self.mean) / self.std, + tf.cast(y, dtype=tf.int32), + ) + ) + .batch(batch_size) + ) + + def preprocess_unlabeled_dataset(self, ux, uy, batch_size): + # unlabeled_train_db = tf.data.Dataset.from_tensor_slices((ux, ux, ux, uy)) + + wux = self.weak_augment_helper(ux) + sux = self.strong_augment_helper(ux) + return ( + tf.data.Dataset.from_tensor_slices((ux, wux, sux, uy)) + .shuffle(1000) + .map( + lambda ux, wux, sux, uy: ( + (tf.cast(ux, dtype=tf.float32) / 255.0 - self.mean) / self.std, + (tf.cast(wux, dtype=tf.float32) / 255.0 - self.mean) / self.std, + (tf.cast(sux, dtype=tf.float32) / 255.0 - self.mean) / self.std, + tf.cast(uy, dtype=tf.int32), + ) + ) + .batch(batch_size) + ) diff --git a/examples/cifar100/fci_ssl/fed_ci_match/algorithm/model.py b/examples/cifar100/fci_ssl/fed_ci_match/algorithm/model.py new file mode 100644 index 00000000..7ffeb2dd --- /dev/null +++ b/examples/cifar100/fci_ssl/fed_ci_match/algorithm/model.py @@ -0,0 +1,171 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tensorflow as tf +import keras + + +# Input--conv2D--BN--ReLU--conv2D--BN--ReLU--Output +# \ / +# ------------------------------ +class BasicBlock(keras.layers.Layer): + def __init__(self, filter_num, stride=1): + super(BasicBlock, self).__init__() + + self.conv1 = keras.layers.Conv2D( + filter_num, (3, 3), strides=stride, padding="same" + ) + self.bn1 = keras.layers.BatchNormalization() + self.relu = keras.layers.Activation("relu") + + self.conv2 = keras.layers.Conv2D(filter_num, (3, 3), strides=1, padding="same") + self.bn2 = keras.layers.BatchNormalization() + + if stride != 1: + self.downsample = keras.models.Sequential() + self.downsample.add(keras.layers.Conv2D(filter_num, (1, 1), strides=stride)) + else: + self.downsample = lambda x: x + + def call(self, inputs, training=None): + # [b, h, w, c] + out = self.conv1(inputs) + out = self.bn1(out, training=training) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out, training=training) + + identity = self.downsample(inputs) + + output = keras.layers.add([out, identity]) + output = tf.nn.relu(output) + + return output + + +class ResNet(keras.Model): + def __init__(self, layer_dims): # [2, 2, 2, 2] + super(ResNet, self).__init__() + self.layer_dims = layer_dims + + self.stem = keras.models.Sequential( + [ + keras.layers.Conv2D(64, (3, 3), strides=(1, 1)), + keras.layers.BatchNormalization(), + keras.layers.Activation("relu"), + keras.layers.MaxPool2D( + pool_size=(2, 2), strides=(1, 1), padding="same" + ), + ] + ) + + self.layer1 = self.build_resblock(64, layer_dims[0]) + self.layer2 = self.build_resblock(128, layer_dims[1], stride=2) + self.layer3 = self.build_resblock(256, layer_dims[2], stride=2) + self.layer4 = self.build_resblock(512, layer_dims[3], stride=2) + + # output: [b, 512, h, w], + self.avgpool = keras.layers.GlobalAveragePooling2D() + + def call(self, inputs, training=None): + x = self.stem(inputs, training=training) + + x = self.layer1(x, training=training) + x = self.layer2(x, training=training) + x = self.layer3(x, training=training) + x = self.layer4(x, training=training) + x = self.avgpool(x) + return x + + def build_resblock(self, filter_num, blocks, stride=1): + res_blocks = keras.models.Sequential() + # may down sample + res_blocks.add(BasicBlock(filter_num, stride)) + for _ in range(1, blocks): + res_blocks.add(BasicBlock(filter_num, stride=1)) + return res_blocks + + def get_config(self): + return { + "layer_dims": self.layer_dims, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +class LeNet(keras.Model): + def __init__(self, input_shape, channels=3, num_classes=10): + super(LeNet, self).__init__() + self.input_shape = input_shape + self.channels = channels + self.num_classes = num_classes + + self.conv1 = keras.layers.Conv2D( + 6, + kernel_size=5, + strides=1, + activation="relu", + input_shape=(input_shape, input_shape, channels), + ) + self.pool1 = keras.layers.MaxPool2D(pool_size=2, strides=2) + self.conv2 = keras.layers.Conv2D( + 16, kernel_size=5, strides=1, activation="relu" + ) + self.pool2 = keras.layers.MaxPool2D(pool_size=2, strides=2) + self.flatten = keras.layers.Flatten() + + self.fc1 = keras.layers.Dense(120, activation="relu") + self.fc2 = keras.layers.Dense(84, activation="relu") + self.fc3 = keras.layers.Dense(num_classes, activation="softmax") + + def call(self, inputs, training=None): + x = self.conv1(inputs) + x = self.pool1(x) + x = self.conv2(x) + x = self.pool2(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.fc2(x) + x = 
self.fc3(x) + return x + + def get_config(self): + return { + "input_shape": self.input_shape, + "channels": self.channels, + "num_classes": self.num_classes, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +def lenet5(input_shape, num_classes: int): + return LeNet(input_shape, 3, num_classes) + + +def resnet10(): + return ResNet([1, 1, 1, 1]) + + +def resnet18(): + return ResNet([2, 2, 2, 2]) + + +def resnet34(): + return ResNet([3, 4, 6, 3]) diff --git a/examples/cifar100/fci_ssl/fed_ci_match/benchmarkingjob.yaml b/examples/cifar100/fci_ssl/fed_ci_match/benchmarkingjob.yaml new file mode 100644 index 00000000..5eb8a6de --- /dev/null +++ b/examples/cifar100/fci_ssl/fed_ci_match/benchmarkingjob.yaml @@ -0,0 +1,71 @@ +benchmarkingjob: + # job name of bechmarking; string type; + name: "benchmarkingjob" + # the url address of job workspace that will reserve the output of tests; string type; + workspace: "/home/wyd/ianvs/federated_class_incremental_learning/workspace" + + # the url address of test environment configuration file; string type; + # the file format supports yaml/yml; + testenv: "./examples/cifar100/fci_ssl/fed_ci_match/testenv/testenv.yaml" + + # the configuration of test object + test_object: + # test type; string type; + # currently the option of value is "algorithms",the others will be added in succession. + type: "algorithms" + # test algorithm configuration files; list type; + algorithms: + # algorithm name; string type; + - name: "FediCarl" + # the url address of test algorithm configuration file; string type; + # the file format supports yaml/yml + url: "./examples/cifar100/fci_ssl/fed_ci_match/algorithm/algorithm.yaml" + + # the configuration of ranking leaderboard + rank: + # rank leaderboard with metric of test case's evaluation and order ; list type; + # the sorting priority is based on the sequence of metrics in the list from front to back; + sort_by: [ { "task_avg_acc": "descend" } ] + + # visualization configuration + visualization: + # mode of visualization in the leaderboard; string type; + # There are quite a few possible dataitems in the leaderboard. Not all of them can be shown simultaneously on the screen. + # In the leaderboard, we provide the "selected_only" mode for the user to configure what is shown or is not shown. + mode: "selected_only" + # method of visualization for selected dataitems; string type; + # currently the options of value are as follows: + # 1> "print_table": print selected dataitems; + method: "print_table" + + # selected dataitem configuration + # The user can add his/her interested dataitems in terms of "paradigms", "modules", "hyperparameters" and "metrics", + # so that the selected columns will be shown. 
+  selected_dataitem:
+    # currently the options of value are as follows:
+    #   1> "all": select all paradigms in the leaderboard;
+    #   2> paradigms in the leaderboard, e.g., "singletasklearning"
+    paradigms: [ "all" ]
+    # currently the options of value are as follows:
+    #   1> "all": select all modules in the leaderboard;
+    #   2> modules in the leaderboard, e.g., "basemodel"
+    modules: [ "all" ]
+    # currently the options of value are as follows:
+    #   1> "all": select all hyperparameters in the leaderboard;
+    #   2> hyperparameters in the leaderboard, e.g., "momentum"
+    hyperparameters: [ "all" ]
+    # currently the options of value are as follows:
+    #   1> "all": select all metrics in the leaderboard;
+    #   2> metrics in the leaderboard, e.g., "F1_SCORE"
+    metrics: [ "task_avg_acc", "forget_rate" ]
+
+  # mode of saving selected and all dataitems in workspace `./rank`; string type;
+  # currently the options of value are as follows:
+  #   1> "selected_and_all": save selected and all dataitems;
+  #   2> "selected_only": save selected dataitems;
+  save_mode: "selected_and_all"
diff --git a/examples/cifar100/fci_ssl/fed_ci_match/testenv/acc.py b/examples/cifar100/fci_ssl/fed_ci_match/testenv/acc.py
new file mode 100644
index 00000000..f55961f3
--- /dev/null
+++ b/examples/cifar100/fci_ssl/fed_ci_match/testenv/acc.py
@@ -0,0 +1,39 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+import numpy as np
+from sedna.common.class_factory import ClassFactory, ClassType
+
+__all__ = ['accuracy']
+
+
+@ClassFactory.register(ClassType.GENERAL, alias='accuracy')
+def accuracy(y_true, y_pred, **kwargs):
+    print(f"y_true: {y_true}")
+    y_pred_arr = [val for val in y_pred.values()]
+    y_true_arr = []
+    for i in range(len(y_pred_arr)):
+        y_true_arr.append(np.full(y_pred_arr[i].shape, int(y_true[i])))
+    y_pred = tf.cast(tf.convert_to_tensor(np.concatenate(y_pred_arr, axis=0)), tf.int64)
+    y_true = tf.cast(tf.convert_to_tensor(np.concatenate(y_true_arr, axis=0)), tf.int64)
+    total = tf.shape(y_true)[0]
+    correct = tf.reduce_sum(tf.cast(tf.equal(y_true, y_pred), tf.int32))
+    print(f"correct: {correct}, total: {total}")
+    acc = float(int(correct) / int(total))
+    print(f"acc: {acc}")
+    return acc
diff --git a/examples/cifar100/fci_ssl/fed_ci_match/testenv/testenv.yaml b/examples/cifar100/fci_ssl/fed_ci_match/testenv/testenv.yaml
new file mode 100644
index 00000000..c27c9229
--- /dev/null
+++ b/examples/cifar100/fci_ssl/fed_ci_match/testenv/testenv.yaml
@@ -0,0 +1,38 @@
+testenv:
+  backend: "tensorflow"
+  dataset:
+    name: 'cifar100'
+    # the url address of train dataset index; string type;
+    train_url: "/home/wyd/ianvs/project/data/cifar100/cifar100_train.txt"
+    # the url address of test dataset index; string type;
+    test_url: "/home/wyd/ianvs/project/data/cifar100/cifar100_test.txt"
+
+  # model eval configuration of incremental learning;
+  model_eval:
+    # metric used for model evaluation
+    model_metric:
+      # metric name; string type;
+      name: "accuracy"
+      # the url address of python file
+      url: "/home/wyd/ianvs/project/ianvs/examples/cifar100/fci_ssl/fed_ci_match/testenv/acc.py"
+
+    # condition of triggering the inference model to update
+    # threshold of the condition; types are float/int
+    threshold: 0.01
+    # operator of the condition; string type;
+    # values are ">=", ">", "<=", "<" and "=";
+    operator: "<="
+
+  # metrics configuration for test case's evaluation; list type;
+  metrics:
+    # metric name; string type;
+    # - name: "accuracy"
+    #   # the url address of python file
+    #   url: "/home/wyd/ianvs/project/ianvs/examples/cifar100/fci_ssl/fed_ci_match/testenv/acc.py"
+    - name: "forget_rate"
+    - name: "task_avg_acc"
+  # incremental rounds setting of incremental learning; int type; default value is 2;
+  incremental_rounds: 2
+  round: 1
+  client_number: 5
\ No newline at end of file
diff --git a/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/FedCiMatch.py b/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/FedCiMatch.py
new file mode 100644
index 00000000..77513380
--- /dev/null
+++ b/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/FedCiMatch.py
@@ -0,0 +1,309 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy +import logging +import tensorflow as tf +import keras +import numpy as np +from model import resnet10 +from agumentation import * +from data_prepocessor import * + + +def get_one_hot(target, num_classes): + y = tf.one_hot(target, depth=num_classes) + if len(y.shape) == 3: + y = tf.squeeze(y, axis=1) + return y + + +class FedCiMatch: + + def __init__( + self, num_classes, batch_size, epochs, learning_rate, memory_size + ) -> None: + self.num_classes = num_classes + self.batch_size = batch_size + self.epochs = epochs + self.learning_rate = learning_rate + self.memory_size = memory_size + self.task_size = None + self.warm_up_round = 1 + self.accept_threshold = 0.85 + self.old_task_id = -1 + + self.classifier = None + self.feature_extractor = self._build_feature_extractor() + dataset_name = "cifar100" + self.data_preprocessor = Dataset_Preprocessor( + dataset_name, Weak_Augment(dataset_name), RandAugment(dataset_name) + ) + + self.observed_classes = [] + self.class_mapping = {} + self.class_per_round = [] + self.x_exemplars = [] + self.y_exemplars = [] + self.num_meta_round = 5 + self.beta = 0.1 + self.num_rounds = 100 + + print(f"self epoch is {self.epochs}") + + def _build_feature_extractor(self): + self.global_model = resnet10(is_combined=True) + self.global_model.build(input_shape=(None, 32, 32, 3)) + self.global_model.call(keras.Input(shape=(32, 32, 3))) + feature_extractor = resnet10(is_combined=True) + feature_extractor.build(input_shape=(None, 32, 32, 3)) + feature_extractor.call(keras.Input(shape=(32, 32, 3))) + return feature_extractor + + def _build_classifier(self): + logging.info(f"build classifier with classes {len(self.class_mapping)}") + if self.classifier != None: + new_classifier = keras.Sequential( + [ + keras.layers.Dense( + len(self.class_mapping), kernel_initializer="lecun_normal" + ) + ] + ) + new_classifier.build( + input_shape=(None, self.feature_extractor.layers[-2].output_shape[-1]) + ) + new_weights = new_classifier.get_weights() + old_weights = self.classifier.get_weights() + # weight + new_weights[0][0 : old_weights[0].shape[0], 0 : old_weights[0].shape[1]] = ( + old_weights[0] + ) + # bias + new_weights[1][0 : old_weights[1].shape[0]] = old_weights[1] + new_classifier.set_weights(new_weights) + self.classifier = new_classifier + else: + logging.info( + f"input shape is {self.feature_extractor.layers[-2].output_shape[-1]}" + ) + self.classifier = keras.Sequential( + [ + keras.layers.Dense( + len(self.class_mapping), kernel_initializer="lecun_normal" + ) + ] + ) + self.classifier.build( + input_shape=(None, self.feature_extractor.layers[-2].output_shape[-1]) + ) + + + def get_weights(self): + return self.feature_extractor.get_weights() + + def set_weights(self, weights): + self.feature_extractor.set_weights(weights) + self.global_model.set_weights(weights) + + def get_data_size(self): + data_size = len(self.labeled_train_set[0]) + len(self.unlabeled_train_set[0]) + logging.info(f"data size: {data_size}") + return data_size + + def model_call(self, x, training=False): + x = self.feature_extractor(x, training=training) + x = self.classifier(x, training=training) + return x + + def _build_class_mapping(self): + y_train = self.labeled_train_set[1] + y = np.unique(y_train) + logging.info(f'build class mapping, y is {y}') + for i in y: + if not i in self.class_mapping.keys(): + self.class_mapping[i] = len(self.class_mapping) + self.class_per_round.append([self.class_mapping[i] for i in y]) + logging.info(f'build class mapping, class mapping is 
{self.class_mapping} and class per round is {self.class_per_round}') + + def _mix_with_exemplar(self): + x_train, y_train = self.labeled_train_set + if len(self.x_exemplars) == 0: + return + x_train = np.concatenate([x_train, np.array(self.x_exemplars)], axis=0) + y_train = np.concatenate([y_train, np.array(self.y_exemplars)], axis=0) + self.labeled_train_set = (x_train, y_train) + + def get_train_loader(self): + label_train_loader = self.data_preprocessor.preprocess_labeled_dataset( + self.labeled_train_set[0], self.labeled_train_set[1], self.batch_size + ) + un_label_train_loader = None + if len(self.unlabeled_train_set[0]) > 0: + un_label_train_loader = self.data_preprocessor.preprocess_unlabeled_dataset( + self.unlabeled_train_set[0], self.unlabeled_train_set[1], self.batch_size + ) + return label_train_loader, un_label_train_loader + + def before_train(self, task_id, round, train_data, task_size): + if self.task_size is None: + self.task_size = task_size + self.labeled_train_set = (train_data["label_x"], train_data["label_y"]) + self.unlabeled_train_set = ( + train_data["unlabel_x"], + train_data["unlabel_y"], + ) + self._build_class_mapping() + self._build_classifier() + if task_id > 0: + self._mix_with_exemplar() + self.feature_extractor.initialize_alpha() + self.labeled_train_loader, self.unlabeled_train_loader = self.get_train_loader() + + def train(self, task_id, round): + optimizer = keras.optimizers.SGD(learning_rate=self.learning_rate, momentum=0.9, weight_decay=0.0001) + all_parameter = [] + all_parameter.extend(self.feature_extractor.trainable_variables) + all_parameter.extend(self.classifier.trainable_variables) + + for epoch in range(self.epochs): + for x, y in self.labeled_train_loader: + y = np.array([self.class_mapping[i] for i in y.numpy()]) + tasks = self._split_tasks(x, y) + base_model_weights = self.feature_extractor.get_weights() + meta_model_weights = [] + for task_x, task_y in tasks: + self.feature_extractor.set_weights(base_model_weights) + for _ in range(self.num_meta_round): + with tf.GradientTape() as tape: + base_loss = self._loss(task_x, task_y) + l2_loss = self._loss_l2(self.global_model) + loss = base_loss + l2_loss*0.1 + grads = tape.gradient(loss, all_parameter) + optimizer.apply_gradients(zip(grads, all_parameter)) + meta_model_weights.append(self.feature_extractor.get_weights()) + logging.info(f'Round{round} task{task_id} epoch{epoch} loss is {loss} ') + self._merge_models(round, base_model_weights, meta_model_weights) + + self.feature_extractor.merge_to_local_model() + self.store_exemplars(task_id) + + def evaluate(self): + total_num = 0 + total_correct = 0 + for x,y in self.labeled_train_loader: + logits = self.classifier(self.feature_extractor(x, training=False)) + prob = tf.nn.softmax(logits, axis=1) + pred = tf.argmax(prob, axis=1) + pred = tf.cast(pred, dtype=tf.int32) + pred = tf.reshape(pred, y.shape) + correct = tf.cast(tf.equal(pred, y), dtype=tf.int32) + correct = tf.reduce_sum(correct) + + total_num += x.shape[0] + total_correct += int(correct) + + acc = total_correct / total_num + del total_correct, total_num + return acc + + def _loss(self,x ,y): + feature = self.feature_extractor(x) + prediction = self.classifier(feature) + loss = keras.losses.categorical_crossentropy(tf.one_hot(y, len(self.class_mapping)), prediction, from_logits=True) + return tf.reduce_mean(loss) + + def _loss_l2(self, global_model): + return 0.0 + + def unsupervised_loss(self, sux, wux): + return 0.0 + + def _merge_models(self, round, base_model_weights, 
meta_model_weights):
+        # eta decays over communication rounds, shifting weight from the meta models to the base model
+        eta = np.exp(-self.beta * (round + 1) / self.num_rounds)
+        merged_meta_parameters = [
+            np.average(
+                [meta_model_weights[i][j] for i in range(len(meta_model_weights))],
+                axis=0,
+            )
+            for j in range(len(meta_model_weights[0]))
+        ]
+        self.feature_extractor.set_weights(
+            [
+                eta * l_meta + (1 - eta) * l_base
+                for l_base, l_meta in zip(base_model_weights, merged_meta_parameters)
+            ]
+        )
+
+    def _split_tasks(self, x, y):
+        tasks = []
+        for classes in self.class_per_round:
+            task = None
+            for cl in classes:
+                x_cl = x[y == cl]
+                y_cl = y[y == cl]
+                if task is None:
+                    task = (x_cl, y_cl)
+                else:
+                    task = (
+                        np.concatenate([task[0], x_cl], axis=0),
+                        np.concatenate([task[1], y_cl], axis=0),
+                    )
+            if len(task[0]) > 0:
+                # random_shuffle returns shuffled copies, so the result must be kept
+                task = self.random_shuffle(task[0], task[1])
+            tasks.append(task)
+        return tasks
+
+    def random_shuffle(self, x, y):
+        p = np.random.permutation(len(x))
+        return x[p], y[p]
+
+    def store_exemplars(self, task):
+        x = self.labeled_train_set[0]
+        y = self.labeled_train_set[1]
+        logging.info('Storing exemplars..')
+        new_classes = self.class_per_round[-1]
+        model_classes = np.concatenate(self.class_per_round).tolist()
+        old_classes = model_classes[:(-len(new_classes))]
+        exemplars_per_class = int(self.memory_size / (len(new_classes) + len(old_classes)))
+
+        if task > 0:
+            # shrink the stored exemplars of old classes to the new per-class budget
+            labels = np.array(self.y_exemplars)
+            new_x_exemplars = []
+            new_y_exemplars = []
+            for cl in old_classes:
+                cl_x = np.array(self.x_exemplars)[labels == cl]
+                cl_y = np.array(self.y_exemplars)[labels == cl]
+                new_x_exemplars.extend(cl_x[:exemplars_per_class])
+                new_y_exemplars.extend(cl_y[:exemplars_per_class])
+            self.x_exemplars = new_x_exemplars
+            self.y_exemplars = new_y_exemplars
+
+        for cl in new_classes:
+            # keep the samples whose features are closest to the class mean
+            logging.info(f'Processing class {cl} and y is {y.shape}')
+            cl_x = x[y == cl]
+            cl_y = y[y == cl]
+
+            cl_feat = self.feature_extractor(cl_x)
+            cl_mean = tf.reduce_mean(cl_feat, axis=0)
+
+            diff = tf.abs(cl_feat - cl_mean)
+            distance = [float(tf.reduce_sum(dis).numpy()) for dis in diff]
+
+            sorted_index = np.argsort(distance).tolist()
+            if len(cl_x) > exemplars_per_class:
+                sorted_index = sorted_index[:exemplars_per_class]
+            self.x_exemplars.extend(cl_x[sorted_index])
+            self.y_exemplars.extend(cl_y[sorted_index])
+
+    def predict(self, x):
+        mean = np.array((0.5071, 0.4867, 0.4408), np.float32).reshape(1, 1, -1)
+        std = np.array((0.2675, 0.2565, 0.2761), np.float32).reshape(1, 1, -1)
+        x = (tf.cast(x, dtype=tf.float32) / 255.0 - mean) / std
+        pred = self.classifier(self.feature_extractor(x, training=False))
+        prob = tf.nn.softmax(pred, axis=1)
+        pred = tf.argmax(prob, axis=1)
+        pred = tf.cast(pred, dtype=tf.int32)
+        return pred
\ No newline at end of file
diff --git a/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/aggregation.py b/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/aggregation.py
new file mode 100644
index 00000000..fdf10494
--- /dev/null
+++ b/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/aggregation.py
@@ -0,0 +1,63 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+
+import keras
+import numpy as np
+from sedna.algorithms.aggregation.aggregation import BaseAggregation
+from sedna.common.class_factory import ClassType, ClassFactory
+from model import resnet10
+
+
+@ClassFactory.register(ClassType.FL_AGG, "FedAvg")
+class FedAvg(BaseAggregation, abc.ABC):
+    def __init__(self):
+        super(FedAvg, self).__init__()
+        self.global_feature_extractor = resnet10(True)
+        self.global_feature_extractor.build((None, 32, 32, 3))
+        self.global_feature_extractor.call(keras.Input(shape=(32, 32, 3)))
+
+    def aggregate(self, clients):
+        """
+        Calculate the average weight according to the number of samples.
+
+        Parameters
+        ----------
+        clients: List
+            All clients in the federated learning job
+
+        Returns
+        -------
+        update_weights : Array-like
+            final weights used to update the model layer
+        """
+        print("aggregation....")
+        if not clients:
+            return self.weights
+        self.total_size = sum(c.num_samples for c in clients)
+        # accumulate each layer as a sample-weighted average of client weights
+        old_weight = [np.zeros(np.array(c).shape) for c in
+                      next(iter(clients)).weights]
+        updates = []
+        for inx, row in enumerate(old_weight):
+            for c in clients:
+                row += (np.array(c.weights[inx]) * c.num_samples
+                        / self.total_size)
+            updates.append(row.tolist())
+        global_weights = [np.array(layer) for layer in updates]
+        self.global_feature_extractor.set_weights(global_weights)
+        self.global_feature_extractor.switch_to_global()
+        print("finish aggregation....")
+        return [np.array(layer) for layer in updates]
diff --git a/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/agumentation.py b/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/agumentation.py
new file mode 100644
index 00000000..89d1bef2
--- /dev/null
+++ b/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/agumentation.py
@@ -0,0 +1,230 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import numpy as np +import random +import tensorflow as tf +from PIL import Image, ImageEnhance, ImageOps + + +""" +Reference: https://github.com/heartInsert/randaugment +""" + + +class Rand_Augment: + def __init__(self, Numbers=None, max_Magnitude=None): + self.transforms = [ + "autocontrast", + "equalize", + "rotate", + "solarize", + "color", + "posterize", + "contrast", + "brightness", + "sharpness", + "shearX", + "shearY", + "translateX", + "translateY", + ] + if Numbers is None: + self.Numbers = len(self.transforms) // 2 + else: + self.Numbers = Numbers + if max_Magnitude is None: + self.max_Magnitude = 10 + else: + self.max_Magnitude = max_Magnitude + fillcolor = 128 + self.ranges = { + # these Magnitude range , you must test it yourself , see what will happen after these operation , + # it is no need to obey the value in autoaugment.py + "shearX": np.linspace(0, 0.3, 10), + "shearY": np.linspace(0, 0.3, 10), + "translateX": np.linspace(0, 0.2, 10), + "translateY": np.linspace(0, 0.2, 10), + "rotate": np.linspace(0, 360, 10), + "color": np.linspace(0.0, 0.9, 10), + "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int), + "solarize": np.linspace(256, 231, 10), + "contrast": np.linspace(0.0, 0.5, 10), + "sharpness": np.linspace(0.0, 0.9, 10), + "brightness": np.linspace(0.0, 0.3, 10), + "autocontrast": [0] * 10, + "equalize": [0] * 10, + "invert": [0] * 10, + } + self.func = { + "shearX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), + Image.BICUBIC, + fill=fillcolor, + ), + "shearY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), + Image.BICUBIC, + fill=fillcolor, + ), + "translateX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), + fill=fillcolor, + ), + "translateY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), + fill=fillcolor, + ), + "rotate": lambda img, magnitude: self.rotate_with_fill(img, magnitude), + # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance( + 1 + magnitude * random.choice([-1, 1]) + ), + "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), + "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), + "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( + 1 + magnitude * random.choice([-1, 1]) + ), + "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( + 1 + magnitude * random.choice([-1, 1]) + ), + "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( + 1 + magnitude * random.choice([-1, 1]) + ), + "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), + "equalize": lambda img, magnitude: img, + "invert": lambda img, magnitude: ImageOps.invert(img), + } + + def rand_augment(self): + """Generate a set of distortions. + Args: + N: Number of augmentation transformations to apply sequentially. N is len(transforms)/2 will be best + M: Max_Magnitude for all the transformations. 
should be <= self.max_Magnitude + """ + + M = np.random.randint(0, self.max_Magnitude, self.Numbers) + + sampled_ops = np.random.choice(self.transforms, self.Numbers) + return [(op, Magnitude) for (op, Magnitude) in zip(sampled_ops, M)] + + def __call__(self, image): + operations = self.rand_augment() + for op_name, M in operations: + operation = self.func[op_name] + mag = self.ranges[op_name][M] + image = operation(image, mag) + return image + + def rotate_with_fill(self, img, magnitude): + # I don't know why rotate must change to RGBA , it is copy from Autoaugment - pytorch + rot = img.convert("RGBA").rotate(magnitude) + return Image.composite( + rot, Image.new("RGBA", rot.size, (128,) * 4), rot + ).convert(img.mode) + + def test_single_operation(self, image, op_name, M=-1): + """ + :param image: image + :param op_name: operation name in self.transforms + :param M: -1 stands for the max Magnitude in there operation + :return: + """ + operation = self.func[op_name] + mag = self.ranges[op_name][M] + image = operation(image, mag) + return image + + +class Base_Augment: + def __init__(self, dataset_name: str) -> None: + self.dataset_name = dataset_name + + def __call__(self, images): + return images + + +class Weak_Augment(Base_Augment): + def __init__(self, dataset_name: str) -> None: + super().__init__(dataset_name) + self.augment_impl = self.augment_for_cifar + + def augment_mirror(self, x): + new_images = x.copy() + indices = np.arange(len(new_images)).tolist() + sampled = random.sample( + indices, int(round(0.5 * len(indices))) + ) # flip horizontally 50% + new_images[sampled] = np.fliplr(new_images[sampled]) + return new_images # random shift + + def augment_shift(self, x, w): + y = tf.pad(x, [[0] * 2, [w] * 2, [w] * 2, [0] * 2], mode="REFLECT") + return tf.image.random_crop(y, tf.shape(x)) + + def augment_for_cifar(self, images: np.ndarray): + return self.augment_shift(self.augment_mirror(images), 4) + + def __call__(self, images: np.ndarray): + return self.augment_impl(images) + + +class Strong_Augment(Base_Augment): + def __init__(self, dataset_name: str) -> None: + super().__init__(dataset_name) + + def augment_mirror(self, x): + new_images = x.copy() + indices = np.arange(len(new_images)).tolist() + sampled = random.sample( + indices, int(round(0.5 * len(indices))) + ) # flip horizontally 50% + new_images[sampled] = np.fliplr(new_images[sampled]) + return new_images # random shift + + def augment_shift_mnist(self, x, w): + y = tf.pad(x, [[0] * 2, [w] * 2, [w] * 2], mode="REFLECT") + return tf.image.random_crop(y, tf.shape(x)) + + def __call__(self, images: np.ndarray): + return self.augment_shift_mnist(self.augment_mirror(images), 4) + + +class RandAugment(Base_Augment): + def __init__(self, dataset_name: str) -> None: + super().__init__(dataset_name) + self.rand_augment = Rand_Augment() + self.input_shape = (32, 32, 3) + + def __call__(self, images): + print("images:", images.shape) + + return np.array( + [ + np.array( + self.rand_augment( + Image.fromarray(np.reshape(img, self.input_shape)) + ) + ) + for img in images + ] + ) diff --git a/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/algorithm.yaml b/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/algorithm.yaml new file mode 100644 index 00000000..c9bd7996 --- /dev/null +++ b/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/algorithm.yaml @@ -0,0 +1,27 @@ +algorithm: + paradigm_type: "federatedclassincrementallearning" + fl_data_setting: + train_ratio: 1.0 + splitting_method: "default" + label_data_ratio: 1.0 + 
data_partition: "iid" + initial_model_url: "/home/wyd/ianvs/project/init_model/cnn.pb" + + modules: + - type: "basemodel" + name: "fci_ssl" + url: "./examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/basemodel.py" + hyperparameters: + - batch_size: + values: + - 128 + - learning_rate: + values: + - 0.001 + - epochs: + values: + - 16 + - type: "aggregation" + name: "FedAvg" + url: "./examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/aggregation.py" + diff --git a/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/basemodel.py b/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/basemodel.py new file mode 100644 index 00000000..09932267 --- /dev/null +++ b/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/basemodel.py @@ -0,0 +1,72 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +from sedna.common.class_factory import ClassType, ClassFactory +from model import resnet10 +from FedCiMatch import FedCiMatch +import logging + +os.environ["BACKEND_TYPE"] = "KERAS" +__all__ = ["BaseModel"] +logging.getLogger().setLevel(logging.INFO) + + +@ClassFactory.register(ClassType.GENERAL, alias="fci_ssl") +class BaseModel: + def __init__(self, **kwargs) -> None: + self.kwargs = kwargs + self.learning_rate = kwargs.get("learning_rate", 0.001) + self.epochs = kwargs.get("epochs", 1) + self.batch_size = kwargs.get("batch_size", 32) + self.task_size = kwargs.get("task_size", 10) + self.memory_size = kwargs.get("memory_size", 2000) + self.num_classes = 2 # the number of class for the first task + self.FedCiMatch = FedCiMatch( + self.num_classes, + self.batch_size, + self.epochs, + self.learning_rate, + self.memory_size, + ) + self.class_learned = 0 + + def get_weights(self): + print("get weights") + return self.FedCiMatch.get_weights() + + def set_weights(self, weights): + print("set weights") + self.FedCiMatch.set_weights(weights) + + def train(self, train_data, val_data, **kwargs): + task_id = kwargs.get("task_id", 0) + round = kwargs.get("round", 1) + round = task_id * 1 + round + task_size = kwargs.get("task_size", self.task_size) + logging.info(f"in train: {round} task_id: {task_id}") + self.FedCiMatch.before_train(task_id, round, train_data, task_size) + self.FedCiMatch.train(task_id, round) + return {"num_samples": self.FedCiMatch.get_data_size(), "task_id": task_id} + + def predict(self, data_files, **kwargs): + result = {} + for data in data_files: + x = np.load(data) + logging.info(f"predicting {x.shape}") + res = self.FedCiMatch.predict(x) + result[data] = res.numpy() + print("finish predict") + return result diff --git a/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/data_prepocessor.py b/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/data_prepocessor.py new file mode 100644 index 00000000..15fe199c --- /dev/null +++ b/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/data_prepocessor.py @@ -0,0 +1,52 @@ +# Copyright 2021 The KubeEdge Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +import numpy as np +from agumentation import Base_Augment + +class Dataset_Preprocessor: + def __init__(self, + dataset_name:str, + weak_augment_helper:Base_Augment, + strong_augment_helper:Base_Augment) -> None: + self.weak_augment_helper = weak_augment_helper + self.strong_augment_helper = strong_augment_helper + self.mean = 0.0 + self.std = 1.0 + if dataset_name == 'cifar100': + self.mean = np.array((0.5071, 0.4867, 0.4408), np.float32).reshape(1, 1, -1) + self.std = np.array((0.2675, 0.2565, 0.2761), np.float32).reshape(1, 1, -1) + print(f"mean: {self.mean}, std: {self.std}") + def preprocess_labeled_dataset(self, x, y, batch_size): + return tf.data.Dataset.from_tensor_slices((x, y)).shuffle(100000).map( + lambda x,y:( + (tf.cast(x, dtype=tf.float32) / 255. - self.mean) / self.std, + tf.cast(y, dtype=tf.int32) + ) + ).batch(batch_size) + + + def preprocess_unlabeled_dataset(self, ux, uy, batch_size): + wux = self.weak_augment_helper(ux) + sux = self.strong_augment_helper(ux) + return tf.data.Dataset.from_tensor_slices((ux, wux, sux, uy)).shuffle(1000).map( + lambda ux,wux,sux,uy: ( + (tf.cast(ux, dtype=tf.float32) / 255. - self.mean) / self.std, + (tf.cast(wux, dtype=tf.float32) / 255. - self.mean) / self.std, + (tf.cast(sux, dtype=tf.float32) / 255. - self.mean) / self.std, + tf.cast(uy, dtype=tf.int32) + ) + ).batch(batch_size) + diff --git a/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/model.py b/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/model.py new file mode 100644 index 00000000..a6f08fa8 --- /dev/null +++ b/examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/model.py @@ -0,0 +1,309 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
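The `Conv2D` wrapper in `model.py` below keeps two parallel convolution branches per layer, a local one and a global one, and mixes their outputs with a learnable coefficient `alpha`. A minimal numpy illustration of the mixing rule in `Conv2D.call` (the function name here is hypothetical):

```python
import numpy as np

def alpha_mix(global_out, local_out, alpha):
    # Matches Conv2D.call below: alpha weights the global branch,
    # (1 - alpha) the local branch; alpha == 0 is purely local.
    return alpha * global_out + (1 - alpha) * local_out

g = np.full((2, 2), 1.0)
l = np.full((2, 2), 0.0)
assert np.allclose(alpha_mix(g, l, 0.25), 0.25)
```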
+ +from typing import List +import tensorflow as tf +import numpy as np +import keras +from keras import layers, Sequential + + +class Conv2D(keras.layers.Layer): + def __init__( + self, + is_combined: bool, + alpha: float, + filter_num, + kernel_size, + strides=(1, 1), + padding: str = "valid", + ): + super(Conv2D, self).__init__() + self.is_combined = is_combined + self.alpha = tf.Variable(alpha) + self.conv_local = layers.Conv2D( + filter_num, kernel_size, strides, padding, kernel_initializer="he_normal" + ) + self.conv_global = layers.Conv2D( + filter_num, kernel_size, strides, padding, kernel_initializer="he_normal" + ) + + def call(self, inputs): + return self.alpha * self.conv_global(inputs) + ( + 1 - self.alpha + ) * self.conv_local(inputs) + + def get_alpha(self): + return self.alpha + + def set_alpha(self, alpha): + self.alpha.assign(alpha) + + def get_global_weights(self): + return self.conv_global.get_weights() + + def set_global_weights(self, global_weights): + self.conv_global.set_weights(global_weights) + + def get_global_variables(self): + return self.conv_global.trainable_variables + + def merge_to_local(self): + new_weight = [] + for w_local, w_global in zip( + self.conv_local.get_weights(), self.conv_global.get_weights() + ): + new_weight.append(self.alpha * w_global + (1 - self.alpha) * w_local) + self.conv_local.set_weights(new_weight) + self.alpha.assign(0.0) + + def switch_to_global(self): + self.conv_global.set_weights(self.conv_local.get_weights()) + + +# Input--conv2D--BN--ReLU--conv2D--BN--ReLU--Output +# \ / +# ------------------------------ +class BasicBlock(keras.Model): + def __init__(self, is_combined: bool, filter_num, stride=1): + super(BasicBlock, self).__init__() + + self.filter_num = filter_num + self.stride = stride + + self.conv1 = Conv2D( + is_combined, 0.0, filter_num, (3, 3), strides=stride, padding="same" + ) + self.bn1 = layers.BatchNormalization() + self.relu = layers.Activation("relu") + + self.conv2 = Conv2D( + is_combined, 0.0, filter_num, (3, 3), strides=1, padding="same" + ) + self.bn2 = layers.BatchNormalization() + + if stride != 1: + self.downsample = Sequential() + self.downsample.add( + Conv2D(is_combined, 0.0, filter_num, (1, 1), strides=stride) + ) + else: + self.downsample = lambda x: x + + def call(self, inputs, training=None): + # [b, h, w, c] + out = self.conv1(inputs) + out = self.bn1(out, training=training) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out, training=training) + + identity = self.downsample(inputs) + + output = layers.add([out, identity]) + output = tf.nn.relu(output) + + return output + + +class ResNet(keras.Model): + def __init__(self, is_combined: bool, layer_dims): # [2, 2, 2, 2] + super(ResNet, self).__init__() + + self.is_combined = is_combined + self.stem = Sequential( + [ + Conv2D(is_combined, 0.0, 64, (3, 3), strides=(1, 1)), + layers.BatchNormalization(), + layers.Activation("relu"), + layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding="same"), + ] + ) + + self.layer1 = self.build_resblock(64, layer_dims[0]) + self.layer2 = self.build_resblock(128, layer_dims[1], stride=2) + self.layer3 = self.build_resblock(256, layer_dims[2], stride=2) + self.layer4 = self.build_resblock(512, layer_dims[3], stride=2) + + # output: [b, 512, h, w], + self.avgpool = layers.GlobalAveragePooling2D() + + def call(self, inputs, training=None): + x = self.stem(inputs, training=training) + + x = self.layer1(x, training=training) + x = self.layer2(x, training=training) + x = self.layer3(x, 
training=training)
+        x = self.layer4(x, training=training)
+
+        # [b, c]
+        x = self.avgpool(x)
+        return x
+
+    def build_resblock(self, filter_num, blocks, stride=1):
+        res_blocks = []
+        # the first block may downsample
+        res_blocks.append(BasicBlock(self.is_combined, filter_num, stride))
+        for _ in range(1, blocks):
+            res_blocks.append(BasicBlock(self.is_combined, filter_num, stride=1))
+        return Sequential(res_blocks)
+
+    def get_alpha(self):
+        convs = self._get_all_conv_layers()
+        ret = []
+        for conv in convs:
+            ret.append(conv.get_alpha())
+        return ret
+
+    def set_alpha(self, alpha=0.0):
+        convs = self._get_all_conv_layers()
+        for conv in convs:
+            conv.set_alpha(alpha)
+
+    def merge_to_local_model(self):
+        convs = self._get_all_conv_layers()
+        for conv in convs:
+            conv.merge_to_local()
+
+    def switch_to_global(self):
+        convs = self._get_all_conv_layers()
+        for conv in convs:
+            conv.switch_to_global()
+
+    def initialize_alpha(self):
+        convs = self._get_all_conv_layers()
+        for conv in convs:
+            conv.set_alpha(np.random.random())
+
+    def set_global_model(self, global_model):
+        local_convs = self._get_all_conv_layers()
+        global_convs = global_model._get_all_conv_layers()
+        for local_conv, global_conv in zip(local_convs, global_convs):
+            local_conv.set_global_weights(global_conv.get_global_weights())
+
+    def get_global_variables(self):
+        convs = self._get_all_conv_layers()
+        ret = []
+        for conv in convs:
+            ret.extend(conv.get_global_variables())
+        return ret
+
+    def _get_all_conv_layers(self) -> List[Conv2D]:
+        def get_all_conv_layers_(model):
+            convs = []
+            for i in model.layers:
+                if isinstance(i, Conv2D):
+                    convs.append(i)
+                elif isinstance(i, keras.Model):
+                    convs.extend(get_all_conv_layers_(i))
+            return convs
+
+        return get_all_conv_layers_(self)
+
+
+def resnet10(is_combined=False) -> ResNet:
+    return ResNet(is_combined, [1, 1, 1, 1])
+
+
+def resnet18(is_combined=False) -> ResNet:
+    return ResNet(is_combined, [2, 2, 2, 2])
+
+
+def resnet34(is_combined=False) -> ResNet:
+    return ResNet(is_combined, [3, 4, 6, 3])
+
+
+class LeNet5(keras.Model):
+    def __init__(self):
+        super(LeNet5, self).__init__()
+        self.cnn_layers = keras.Sequential(
+            [
+                Conv2D(True, 0.0, 6, kernel_size=(5, 5), padding="valid"),
+                layers.ReLU(),
+                layers.MaxPool2D(pool_size=(2, 2)),
+                Conv2D(True, 0.0, 16, kernel_size=(5, 5), padding="valid"),
+                layers.ReLU(),
+                layers.MaxPool2D(pool_size=(2, 2)),
+            ]
+        )
+
+        self.flatten = layers.Flatten()
+
+        self.fc_layers = keras.Sequential(
+            [
+                layers.Dense(120),
+                layers.ReLU(),
+                layers.Dense(84),
+                layers.ReLU(),
+            ]
+        )
+
+    def call(self, inputs, training=None):
+        x = self.cnn_layers(inputs, training=training)
+
+        x = self.flatten(x)
+        x = self.fc_layers(x, training=training)
+        return x
+
+    def get_alpha(self):
+        convs = self._get_all_conv_layers()
+        ret = []
+        for conv in convs:
+            ret.append(conv.get_alpha())
+        return ret
+
+    def set_alpha(self, alpha=0.0):
+        convs = self._get_all_conv_layers()
+        for conv in convs:
+            conv.set_alpha(alpha)
+
+    def merge_to_local_model(self):
+        convs = self._get_all_conv_layers()
+        for conv in convs:
+            conv.merge_to_local()
+
+    def switch_to_global(self):
+        convs = self._get_all_conv_layers()
+        for conv in convs:
+            conv.switch_to_global()
+
+    def initialize_alpha(self):
+        convs = self._get_all_conv_layers()
+        for conv in convs:
+            conv.set_alpha(np.random.random())
+
+    def set_global_model(self, global_model):
+        local_convs = self._get_all_conv_layers()
+        global_convs = global_model._get_all_conv_layers()
+        for local_conv, global_conv
in zip(local_convs, global_convs): + local_conv.set_global_weights(global_conv.get_global_weights()) + + def get_global_variables(self): + convs = self._get_all_conv_layers() + ret = [] + for conv in convs: + ret.extend(conv.get_global_variables()) + return ret + + def _get_all_conv_layers(self) -> List[Conv2D]: + def get_all_conv_layers_(model): + convs = [] + for i in model.layers: + if isinstance(i, Conv2D): + convs.append(i) + elif isinstance(i, keras.Model): + convs.extend(get_all_conv_layers_(i)) + return convs + + return get_all_conv_layers_(self) diff --git a/examples/cifar100/fci_ssl/fed_ci_match_v2/benchmarkingjob.yaml b/examples/cifar100/fci_ssl/fed_ci_match_v2/benchmarkingjob.yaml new file mode 100644 index 00000000..b6564058 --- /dev/null +++ b/examples/cifar100/fci_ssl/fed_ci_match_v2/benchmarkingjob.yaml @@ -0,0 +1,71 @@ +benchmarkingjob: + # job name of bechmarking; string type; + name: "benchmarkingjob" + # the url address of job workspace that will reserve the output of tests; string type; + workspace: "/home/wyd/ianvs/federated_class_incremental_learning/workspace" + + # the url address of test environment configuration file; string type; + # the file format supports yaml/yml; + testenv: "./examples/cifar100/fci_ssl/fed_ci_match_v2/testenv/testenv.yaml" + + # the configuration of test object + test_object: + # test type; string type; + # currently the option of value is "algorithms",the others will be added in succession. + type: "algorithms" + # test algorithm configuration files; list type; + algorithms: + # algorithm name; string type; + - name: "fci_ssl_test" + # the url address of test algorithm configuration file; string type; + # the file format supports yaml/yml + url: "./examples/cifar100/fci_ssl/fed_ci_match_v2/algorithm/algorithm.yaml" + + # the configuration of ranking leaderboard + rank: + # rank leaderboard with metric of test case's evaluation and order ; list type; + # the sorting priority is based on the sequence of metrics in the list from front to back; + sort_by: [ { "accuracy": "descend" } ] + + # visualization configuration + visualization: + # mode of visualization in the leaderboard; string type; + # There are quite a few possible dataitems in the leaderboard. Not all of them can be shown simultaneously on the screen. + # In the leaderboard, we provide the "selected_only" mode for the user to configure what is shown or is not shown. + mode: "selected_only" + # method of visualization for selected dataitems; string type; + # currently the options of value are as follows: + # 1> "print_table": print selected dataitems; + method: "print_table" + + # selected dataitem configuration + # The user can add his/her interested dataitems in terms of "paradigms", "modules", "hyperparameters" and "metrics", + # so that the selected columns will be shown. 
+    selected_dataitem:
+      # currently the options of value are as follows:
+      #   1> "all": select all paradigms in the leaderboard;
+      #   2> paradigms in the leaderboard, e.g., "singletasklearning"
+      paradigms: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all modules in the leaderboard;
+      #   2> modules in the leaderboard, e.g., "basemodel"
+      modules: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all hyperparameters in the leaderboard;
+      #   2> hyperparameters in the leaderboard, e.g., "momentum"
+      hyperparameters: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all metrics in the leaderboard;
+      #   2> metrics in the leaderboard, e.g., "F1_SCORE"
+      metrics: [ "accuracy", "forget_rate" ]
+
+  # save mode of selected and all dataitems in workspace `./rank`; string type;
+  # currently the options of value are as follows:
+  #   1> "selected_and_all": save selected and all dataitems;
+  #   2> "selected_only": save selected dataitems;
+  save_mode: "selected_and_all"
+
+
+
+
+
diff --git a/examples/cifar100/fci_ssl/fed_ci_match_v2/testenv/acc.py b/examples/cifar100/fci_ssl/fed_ci_match_v2/testenv/acc.py
new file mode 100644
index 00000000..9dcb119e
--- /dev/null
+++ b/examples/cifar100/fci_ssl/fed_ci_match_v2/testenv/acc.py
@@ -0,0 +1,35 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
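The metric below receives `y_pred` as a dict mapping each test file to an array of per-sample predictions, while `y_true` carries one class id per file; the function broadcasts each file's label across its predictions before comparing. A small illustration of that contract (file names are hypothetical):

```python
import numpy as np

y_pred = {"t0.npy": np.array([3, 3, 1]), "t1.npy": np.array([7, 7, 7])}
y_true = [3, 7]
# Broadcasting the per-file labels gives [3, 3, 3] and [7, 7, 7];
# 5 of the 6 predictions match, so accuracy is 5/6.
flat_true = np.concatenate(
    [np.full(p.shape, t) for p, t in zip(y_pred.values(), y_true)]
)
flat_pred = np.concatenate(list(y_pred.values()))
assert abs((flat_true == flat_pred).mean() - 5 / 6) < 1e-9
```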
+
+import tensorflow as tf
+import numpy as np
+from sedna.common.class_factory import ClassFactory, ClassType
+
+__all__ = ["accuracy"]
+
+
+@ClassFactory.register(ClassType.GENERAL, alias="accuracy")
+def accuracy(y_true, y_pred, **kwargs):
+    y_pred_arr = [val for val in y_pred.values()]
+    y_true_arr = []
+    for i in range(len(y_pred_arr)):
+        y_true_arr.append(np.full(y_pred_arr[i].shape, int(y_true[i])))
+    y_pred = tf.cast(tf.convert_to_tensor(np.concatenate(y_pred_arr, axis=0)), tf.int64)
+    y_true = tf.cast(tf.convert_to_tensor(np.concatenate(y_true_arr, axis=0)), tf.int64)
+    total = tf.shape(y_true)[0]
+    correct = tf.reduce_sum(tf.cast(tf.equal(y_true, y_pred), tf.int32))
+    print(f"correct:{correct}, total:{total}")
+    acc = float(int(correct) / int(total))
+    print(f"acc:{acc}")
+    return acc
diff --git a/examples/cifar100/fci_ssl/fed_ci_match_v2/testenv/testenv.yaml b/examples/cifar100/fci_ssl/fed_ci_match_v2/testenv/testenv.yaml
new file mode 100644
index 00000000..bb906c28
--- /dev/null
+++ b/examples/cifar100/fci_ssl/fed_ci_match_v2/testenv/testenv.yaml
@@ -0,0 +1,37 @@
+testenv:
+  backend: "tensorflow"
+  dataset:
+    name: 'cifar100'
+    # the url address of train dataset index; string type;
+    train_url: "/home/wyd/ianvs/project/data/cifar100/cifar100_train.txt"
+    # the url address of test dataset index; string type;
+    test_url: "/home/wyd/ianvs/project/data/cifar100/cifar100_test.txt"
+
+
+  # network eval configuration of incremental learning;
+  model_eval:
+    # metric used for network evaluation
+    model_metric:
+      # metric name; string type;
+      name: "accuracy"
+      # the url address of python file
+      url: "/home/wyd/ianvs/project/ianvs/examples/cifar100/fci_ssl/fed_ci_match_v2/testenv/acc.py"
+
+    # condition of triggering inference network to update
+    # threshold of the condition; types are float/int
+    threshold: 0.01
+    # operator of the condition; string type;
+    # values are ">=", ">", "<=", "<" and "=";
+    operator: "<="
+
+  # metrics configuration for test case's evaluation; list type;
+  metrics:
+    # metric name; string type;
+    - name: "accuracy"
+      # the url address of python file
+      url: "/home/wyd/ianvs/project/ianvs/examples/cifar100/fci_ssl/fed_ci_match_v2/testenv/acc.py"
+    - name: "forget_rate"
+  # incremental rounds setting of incremental learning; int type; default value is 2;
+  incremental_rounds: 50
+  round: 1
+  client_number: 2
\ No newline at end of file
diff --git a/examples/cifar100/fci_ssl/fedavg/__init__.py b/examples/cifar100/fci_ssl/fedavg/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/cifar100/fci_ssl/fedavg/algorithm/__init__.py b/examples/cifar100/fci_ssl/fedavg/algorithm/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/cifar100/fci_ssl/fedavg/algorithm/aggregation.py b/examples/cifar100/fci_ssl/fedavg/algorithm/aggregation.py
new file mode 100644
index 00000000..1aa52589
--- /dev/null
+++ b/examples/cifar100/fci_ssl/fedavg/algorithm/aggregation.py
@@ -0,0 +1,60 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+from copy import deepcopy
+from typing import List
+
+import numpy as np
+from sedna.algorithms.aggregation.aggregation import BaseAggregation
+from sedna.common.class_factory import ClassType, ClassFactory
+
+
+@ClassFactory.register(ClassType.FL_AGG, "FedAvg")
+class FedAvg(BaseAggregation, abc.ABC):
+    """
+    Federated averaging algorithm.
+    """
+
+    def __init__(self):
+        super(FedAvg, self).__init__()
+
+    def aggregate(self, clients: List):
+        """
+        Calculate the average weight according to the number of samples
+
+        Parameters
+        ----------
+        clients: List
+            All clients in federated learning job
+
+        Returns
+        -------
+        update_weights : Array-like
+            final weights used to update the model layer
+        """
+
+        print("aggregation....")
+        if not clients:
+            return self.weights
+        self.total_size = sum([c.num_samples for c in clients])
+        old_weight = [np.zeros(np.array(layer).shape) for layer in next(iter(clients)).weights]
+        updates = []
+        for inx, row in enumerate(old_weight):
+            for c in clients:
+                row += np.array(c.weights[inx]) * c.num_samples / self.total_size
+            updates.append(row.tolist())
+        self.weights = deepcopy(updates)
+        print("finish aggregation....")
+        return [np.array(layer) for layer in updates]
diff --git a/examples/cifar100/fci_ssl/fedavg/algorithm/algorithm.yaml b/examples/cifar100/fci_ssl/fedavg/algorithm/algorithm.yaml
new file mode 100644
index 00000000..0b4c683f
--- /dev/null
+++ b/examples/cifar100/fci_ssl/fedavg/algorithm/algorithm.yaml
@@ -0,0 +1,49 @@
+algorithm:
+  # paradigm name; string type;
+  # currently the options of value are as follows:
+  #   1> "singletasklearning"
+  #   2> "incrementallearning"
+  paradigm_type: "federatedclassincrementallearning"
+  fl_data_setting:
+    # ratio of training dataset; float type;
+    # the default value is 0.8.
+ train_ratio: 1.0 + # the method of splitting dataset; string type; optional; + # currently the options of value are as follows: + # 1> "default": the dataset is evenly divided based train_ratio; + splitting_method: "default" + label_data_ratio: 1.0 + data_partition: "iid" + # the url address of initial network for network pre-training; string url; + # the url address of initial network; string type; optional; + initial_model_url: "/home/wyd/ianvs/project/init_model/cnn.pb" + # algorithm module configuration in the paradigm; list type; + # incremental rounds setting of incremental learning; int type; default value is 2; + + modules: + # kind of algorithm module; string type; + # currently the options of value are as follows: + # 1> "basemodel" + - type: "basemodel" + # name of python module; string type; + # example: basemodel.py has BaseModel module that the alias is "FPN" for this benchmarking; + name: "fedavg-client" + # the url address of python module; string type; + url: "./examples/cifar100/fci_ssl/fedavg/algorithm/basemodel.py" + + # hyperparameters configuration for the python module; list type; + hyperparameters: + # name of the hyperparameter; string type; + - batch_size: + values: + - 64 + - learning_rate: + values: + - 0.001 + - epochs: + values: + - 1 + - type: "aggregation" + name: "FedAvg" + url: "./examples/cifar100/fci_ssl/fedavg/algorithm/aggregation.py" + diff --git a/examples/cifar100/fci_ssl/fedavg/algorithm/basemodel.py b/examples/cifar100/fci_ssl/fedavg/algorithm/basemodel.py new file mode 100644 index 00000000..7fdc725b --- /dev/null +++ b/examples/cifar100/fci_ssl/fedavg/algorithm/basemodel.py @@ -0,0 +1,211 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
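The `build_classifier` method in `basemodel.py` below grows the Dense head whenever new classes arrive, copying the already-learned kernel columns and biases into the corner of a larger, freshly initialised layer. A numpy sketch of that copy (the function and the Gaussian initialisation are illustrative assumptions, not the exact Keras initializer):

```python
import numpy as np

def expand_head(old_kernel, old_bias, feat_dim, new_num_classes):
    # old_kernel: (feat_dim, old_classes); old_bias: (old_classes,)
    new_kernel = np.random.normal(0.0, 0.01, (feat_dim, new_num_classes))
    new_bias = np.zeros(new_num_classes)
    new_kernel[:, : old_kernel.shape[1]] = old_kernel  # keep learned columns
    new_bias[: old_bias.shape[0]] = old_bias
    return new_kernel, new_bias

k, b = expand_head(np.ones((512, 50)), np.ones(50), 512, 100)
assert k[:, :50].min() == 1.0 and b[50:].max() == 0.0
```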
+ +import os +import logging +import keras +import numpy as np +import tensorflow as tf +from sedna.common.class_factory import ClassType, ClassFactory +from model import resnet10 + +__all__ = ["BaseModel"] +os.environ["BACKEND_TYPE"] = "KERAS" +logging.getLogger().setLevel(logging.INFO) + + +@ClassFactory.register(ClassType.GENERAL, alias="fedavg-client") +class BaseModel: + def __init__(self, **kwargs): + self.kwargs = kwargs + print(f"kwargs: {kwargs}") + self.batch_size = kwargs.get("batch_size", 1) + print(f"batch_size: {self.batch_size}") + self.epochs = kwargs.get("epochs", 1) + self.learning_rate = kwargs.get("learning_rate", 0.001) + self.num_classes = 50 + self.task_size = 50 + self.old_task_id = -1 + self.mean = np.array((0.5071, 0.4867, 0.4408), np.float32).reshape(1, 1, -1) + self.std = np.array((0.2675, 0.2565, 0.2761), np.float32).reshape(1, 1, -1) + self.fe = resnet10() + logging.info(type(self.fe)) + self.classifier = None + self._init_model() + + def _init_model(self): + self.fe.compile( + optimizer="sgd", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + self.fe.call(keras.Input(shape=(32, 32, 3))) + fe_weights = self.fe.get_weights() + self.fe_weights_length = len(fe_weights) + + def load(self, model_url=None): + logging.info(f"load model from {model_url}") + + def _initialize(self): + logging.info(f"initialize finished") + + def get_weights(self): + weights = [] + fe_weights = self.fe.get_weights() + self.fe_weights_length = len(fe_weights) + clf_weights = self.classifier.get_weights() + weights.extend(fe_weights) + weights.extend(clf_weights) + return weights + + def set_weights(self, weights): + fe_weights = weights[: self.fe_weights_length] + clf_weights = weights[self.fe_weights_length :] + self.fe.set_weights(fe_weights) + self.classifier.set_weights(clf_weights) + + def save(self, model_path=""): + logging.info("save model") + + def model_info(self, model_path, result, relpath): + logging.info("model info") + return {} + + def build_classifier(self): + if self.classifier != None: + new_classifier = keras.Sequential( + [ + keras.layers.Dense( + self.num_classes, kernel_initializer="lecun_normal" + ) + ] + ) + new_classifier.build( + input_shape=(None, self.fe.layers[-2].output_shape[-1]) + ) + new_weights = new_classifier.get_weights() + old_weights = self.classifier.get_weights() + # weight + new_weights[0][0 : old_weights[0].shape[0], 0 : old_weights[0].shape[1]] = ( + old_weights[0] + ) + # bias + new_weights[1][0 : old_weights[1].shape[0]] = old_weights[1] + new_classifier.set_weights(new_weights) + self.classifier = new_classifier + else: + logging.info(f"input shape is {self.fe.layers[-2].output_shape[-1]}") + self.classifier = keras.Sequential( + [ + keras.layers.Dense( + self.num_classes, kernel_initializer="lecun_normal" + ) + ] + ) + self.classifier.build( + input_shape=(None, self.fe.layers[-2].output_shape[-1]) + ) + + logging.info(f"finish ! 
initialize classifier {self.classifier.summary()}")
+
+    def train(self, train_data, valid_data, **kwargs):
+        optimizer = keras.optimizers.SGD(
+            learning_rate=self.learning_rate, momentum=0.9, weight_decay=0.0001
+        )
+        round = kwargs.get("round", -1)
+        task_id = kwargs.get("task_id", -1)
+        if self.old_task_id != task_id:
+            self.old_task_id = task_id
+            self.num_classes = self.task_size * (task_id + 1)
+            self.build_classifier()
+        data = (train_data["label_x"], train_data["label_y"])
+        train_db = self.data_process(data)
+        logging.info(train_db)
+        all_params = []
+        all_params.extend(self.fe.trainable_variables)
+        all_params.extend(self.classifier.trainable_variables)
+        for epoch in range(self.epochs):
+            total_loss = 0
+            total_num = 0
+            logging.info(f"Epoch {epoch + 1} / {self.epochs}")
+            logging.info("-" * 50)
+            for x, y in train_db:
+                with tf.GradientTape() as tape:
+                    logits = self.classifier(self.fe(x, training=True), training=True)
+                    loss = tf.reduce_mean(
+                        keras.losses.sparse_categorical_crossentropy(
+                            y, logits, from_logits=True
+                        )
+                    )
+                grads = tape.gradient(loss, all_params)
+                optimizer.apply(grads, all_params)
+                total_loss += loss
+                total_num += 1
+
+            logging.info(
+                f"train round {round}: Epoch {epoch + 1} avg loss: {total_loss / total_num}"
+            )
+        logging.info(f"finish round {round} train")
+        return {"num_samples": data[0].shape[0]}
+
+    def predict(self, data_files, **kwargs):
+        result = {}
+        for data in data_files:
+            x = np.load(data)
+            logging.info(f"predicting {x.shape}")
+            mean = np.array((0.5071, 0.4867, 0.4408), np.float32).reshape(1, 1, -1)
+            std = np.array((0.2675, 0.2565, 0.2761), np.float32).reshape(1, 1, -1)
+            x = (tf.cast(x, dtype=tf.float32) / 255.0 - mean) / std
+            pred = self.classifier(self.fe(x, training=False))
+            prob = tf.nn.softmax(pred, axis=1)
+            pred = tf.argmax(prob, axis=1)
+            pred = tf.cast(pred, dtype=tf.int32)
+            result[data] = pred.numpy()
+        logging.info("finish predict")
+        return result
+
+    def eval(self, data, round, **kwargs):
+        total_num = 0
+        total_correct = 0
+        data = self.data_process(data)
+        for i, (x, y) in enumerate(data):
+            # run the feature extractor followed by the classifier head
+            logits = self.classifier(self.fe(x, training=False), training=False)
+            pred = tf.argmax(logits, axis=1)
+            pred = tf.cast(pred, dtype=tf.int32)
+            correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)
+            correct = tf.reduce_sum(correct)
+            total_num += x.shape[0]
+            total_correct += int(correct)
+        logging.info(f"total_correct: {total_correct}, total_num: {total_num}")
+        acc = total_correct / total_num
+        logging.info(f"finish round {round} evaluate, acc: {acc}")
+        return acc
+
+    def data_process(self, data, **kwargs):
+        assert data is not None, "data is None"
+        # data[0].shape == (50000, 32, 32, 3), data[1].shape == (50000,)
+        return (
+            tf.data.Dataset.from_tensor_slices((data[0], data[1]))
+            .shuffle(100000)
+            .map(
+                lambda x, y: (
+                    (tf.cast(x, dtype=tf.float32) / 255.0 - self.mean) / self.std,
+                    tf.cast(y, dtype=tf.int32),
+                )
+            )
+            .batch(self.batch_size)
+        )
diff --git a/examples/cifar100/fci_ssl/fedavg/algorithm/model.py b/examples/cifar100/fci_ssl/fedavg/algorithm/model.py
new file mode 100644
index 00000000..5e5ea3a5
--- /dev/null
+++ b/examples/cifar100/fci_ssl/fedavg/algorithm/model.py
@@ -0,0 +1,171 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +import keras + + +# Input--conv2D--BN--ReLU--conv2D--BN--ReLU--Output +# \ / +# ------------------------------ +class BasicBlock(keras.layers.Layer): + def __init__(self, filter_num, stride=1): + super(BasicBlock, self).__init__() + + self.conv1 = keras.layers.Conv2D( + filter_num, (3, 3), strides=stride, padding="same" + ) + self.bn1 = keras.layers.BatchNormalization() + self.relu = keras.layers.Activation("relu") + + self.conv2 = keras.layers.Conv2D(filter_num, (3, 3), strides=1, padding="same") + self.bn2 = keras.layers.BatchNormalization() + + if stride != 1: + self.downsample = keras.models.Sequential() + self.downsample.add(keras.layers.Conv2D(filter_num, (1, 1), strides=stride)) + else: + self.downsample = lambda x: x + + def call(self, inputs, training=None): + # [b, h, w, c] + out = self.conv1(inputs) + out = self.bn1(out, training=training) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out, training=training) + + identity = self.downsample(inputs) + + output = keras.layers.add([out, identity]) + output = tf.nn.relu(output) + + return output + + +class ResNet(keras.Model): + def __init__(self, layer_dims): # [2, 2, 2, 2] + super(ResNet, self).__init__() + self.layer_dims = layer_dims + + self.stem = keras.models.Sequential( + [ + keras.layers.Conv2D(64, (3, 3), strides=(1, 1)), + keras.layers.BatchNormalization(), + keras.layers.Activation("relu"), + keras.layers.MaxPool2D( + pool_size=(2, 2), strides=(1, 1), padding="same" + ), + ] + ) + + self.layer1 = self.build_resblock(64, layer_dims[0]) + self.layer2 = self.build_resblock(128, layer_dims[1], stride=2) + self.layer3 = self.build_resblock(256, layer_dims[2], stride=2) + self.layer4 = self.build_resblock(512, layer_dims[3], stride=2) + + # output: [b, 512, h, w], + self.avgpool = keras.layers.GlobalAveragePooling2D() + + def call(self, inputs, training=None): + x = self.stem(inputs, training=training) + + x = self.layer1(x, training=training) + x = self.layer2(x, training=training) + x = self.layer3(x, training=training) + x = self.layer4(x, training=training) + x = self.avgpool(x) + return x + + def build_resblock(self, filter_num, blocks, stride=1): + res_blocks = keras.models.Sequential() + # may down sample + res_blocks.add(BasicBlock(filter_num, stride)) + for _ in range(1, blocks): + res_blocks.add(BasicBlock(filter_num, stride=1)) + return res_blocks + + def get_config(self): + return { + "layer_dims": self.layer_dims, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +class LeNet(keras.Model): + def __init__(self, input_shape, channels=3, num_classes=10): + super(LeNet, self).__init__() + self.input_shape = input_shape + self.channels = channels + self.num_classes = num_classes + + self.conv1 = keras.layers.Conv2D( + 6, + kernel_size=5, + strides=1, + activation="relu", + input_shape=(input_shape, input_shape, channels), + ) + self.pool1 = keras.layers.MaxPool2D(pool_size=2, strides=2) + self.conv2 = keras.layers.Conv2D( + 16, kernel_size=5, strides=1, activation="relu" + ) + self.pool2 = keras.layers.MaxPool2D(pool_size=2, strides=2) 
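+        # With 32x32 RGB inputs, the conv/pool stack above reduces the image
+        # to 5x5x16 = 400 features before the dense layers below.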
+ self.flatten = keras.layers.Flatten() + + self.fc1 = keras.layers.Dense(120, activation="relu") + self.fc2 = keras.layers.Dense(84, activation="relu") + self.fc3 = keras.layers.Dense(num_classes, activation="softmax") + + def call(self, inputs, training=None): + x = self.conv1(inputs) + x = self.pool1(x) + x = self.conv2(x) + x = self.pool2(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + return x + + def get_config(self): + return { + "input_shape": self.input_shape, + "channels": self.channels, + "num_classes": self.num_classes, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +def lenet5(input_shape, num_classes: int): + return LeNet(input_shape, 3, num_classes) + + +def resnet10(): + return ResNet([1, 1, 1, 1]) + + +def resnet18(num_classes: int): + return ResNet([2, 2, 2, 2]) + + +def resnet34(num_classes: int): + return ResNet([3, 4, 6, 3]) diff --git a/examples/cifar100/fci_ssl/fedavg/algorithm/resnet.py b/examples/cifar100/fci_ssl/fedavg/algorithm/resnet.py new file mode 100644 index 00000000..cbc4de72 --- /dev/null +++ b/examples/cifar100/fci_ssl/fedavg/algorithm/resnet.py @@ -0,0 +1,121 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +import keras + + +# Input--conv2D--BN--ReLU--conv2D--BN--ReLU--Output +# \ / +# ------------------------------ +class BasicBlock(keras.layers.Layer): + def __init__(self, filter_num, stride=1): + super(BasicBlock, self).__init__() + + self.conv1 = keras.layers.Conv2D( + filter_num, (3, 3), strides=stride, padding="same" + ) + self.bn1 = keras.layers.BatchNormalization() + self.relu = keras.layers.Activation("relu") + + self.conv2 = keras.layers.Conv2D(filter_num, (3, 3), strides=1, padding="same") + self.bn2 = keras.layers.BatchNormalization() + + if stride != 1: + self.downsample = keras.models.Sequential() + self.downsample.add(keras.layers.Conv2D(filter_num, (1, 1), strides=stride)) + else: + self.downsample = lambda x: x + + def call(self, inputs, training=None): + # [b, h, w, c] + out = self.conv1(inputs) + out = self.bn1(out, training=training) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out, training=training) + + identity = self.downsample(inputs) + + output = keras.layers.add([out, identity]) + output = tf.nn.relu(output) + + return output + + +class ResNet(keras.Model): + def __init__(self, layer_dims, num_classes=10): # [2, 2, 2, 2] + super(ResNet, self).__init__() + self.layer_dims = layer_dims + self.num_classes = num_classes + + self.stem = keras.models.Sequential( + [ + keras.layers.Conv2D(64, (3, 3), strides=(1, 1)), + keras.layers.BatchNormalization(), + keras.layers.Activation("relu"), + keras.layers.MaxPool2D( + pool_size=(2, 2), strides=(1, 1), padding="same" + ), + ] + ) + + self.layer1 = self.build_resblock(64, layer_dims[0]) + self.layer2 = self.build_resblock(128, layer_dims[1], stride=2) + self.layer3 = self.build_resblock(256, layer_dims[2], stride=2) + self.layer4 = 
self.build_resblock(512, layer_dims[3], stride=2)
+
+        # output: [b, 512, h, w],
+        self.avgpool = keras.layers.GlobalAveragePooling2D()
+        # self.fc = keras.layers.Dense(num_classes)
+
+    def call(self, inputs, training=None):
+        x = self.stem(inputs, training=training)
+
+        x = self.layer1(x, training=training)
+        x = self.layer2(x, training=training)
+        x = self.layer3(x, training=training)
+        x = self.layer4(x, training=training)
+
+        # [b, c]
+        x = self.avgpool(x)
+        return x
+
+    def build_resblock(self, filter_num, blocks, stride=1):
+        res_blocks = keras.models.Sequential()
+        # the first block may downsample
+        res_blocks.add(BasicBlock(filter_num, stride))
+        for _ in range(1, blocks):
+            res_blocks.add(BasicBlock(filter_num, stride=1))
+        return res_blocks
+
+    def get_config(self):
+        return {"layer_dims": self.layer_dims, "num_classes": self.num_classes}
+
+    @classmethod
+    def from_config(cls, config):
+        return cls(**config)
+
+
+def resnet10(num_classes: int):
+    return ResNet([1, 1, 1, 1], num_classes)
+
+
+def resnet18(num_classes: int):
+    return ResNet([2, 2, 2, 2], num_classes)
+
+
+def resnet34(num_classes: int):
+    return ResNet([3, 4, 6, 3], num_classes)
diff --git a/examples/cifar100/fci_ssl/fedavg/benchmarkingjob.yaml b/examples/cifar100/fci_ssl/fedavg/benchmarkingjob.yaml
new file mode 100644
index 00000000..6f4b9642
--- /dev/null
+++ b/examples/cifar100/fci_ssl/fedavg/benchmarkingjob.yaml
@@ -0,0 +1,71 @@
+benchmarkingjob:
+  # job name of benchmarking; string type;
+  name: "benchmarkingjob"
+  # the url address of job workspace that will reserve the output of tests; string type;
+  workspace: "/home/wyd/ianvs/federated_class_incremental_learning/workspace"
+
+  # the url address of test environment configuration file; string type;
+  # the file format supports yaml/yml;
+  testenv: "./examples/cifar100/fci_ssl/fedavg/testenv/testenv.yaml"
+
+  # the configuration of test object
+  test_object:
+    # test type; string type;
+    # currently the option of value is "algorithms", the others will be added in succession.
+    type: "algorithms"
+    # test algorithm configuration files; list type;
+    algorithms:
+      # algorithm name; string type;
+      - name: "basic-fedavg"
+        # the url address of test algorithm configuration file; string type;
+        # the file format supports yaml/yml
+        url: "./examples/cifar100/fci_ssl/fedavg/algorithm/algorithm.yaml"
+
+  # the configuration of ranking leaderboard
+  rank:
+    # rank leaderboard with metric of test case's evaluation and order; list type;
+    # the sorting priority is based on the sequence of metrics in the list from front to back;
+    sort_by: [ { "task_avg_acc": "descend" } ]
+
+    # visualization configuration
+    visualization:
+      # mode of visualization in the leaderboard; string type;
+      # There are quite a few possible dataitems in the leaderboard. Not all of them can be shown simultaneously on the screen.
+      # In the leaderboard, we provide the "selected_only" mode for the user to configure what is shown or is not shown.
+      mode: "selected_only"
+      # method of visualization for selected dataitems; string type;
+      # currently the options of value are as follows:
+      #   1> "print_table": print selected dataitems;
+      method: "print_table"
+
+    # selected dataitem configuration
+    # The user can add his/her interested dataitems in terms of "paradigms", "modules", "hyperparameters" and "metrics",
+    # so that the selected columns will be shown.
+    selected_dataitem:
+      # currently the options of value are as follows:
+      #   1> "all": select all paradigms in the leaderboard;
+      #   2> paradigms in the leaderboard, e.g., "singletasklearning"
+      paradigms: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all modules in the leaderboard;
+      #   2> modules in the leaderboard, e.g., "basemodel"
+      modules: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all hyperparameters in the leaderboard;
+      #   2> hyperparameters in the leaderboard, e.g., "momentum"
+      hyperparameters: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all metrics in the leaderboard;
+      #   2> metrics in the leaderboard, e.g., "F1_SCORE"
+      metrics: [ "task_avg_acc", "forget_rate" ]
+
+  # save mode of selected and all dataitems in workspace `./rank`; string type;
+  # currently the options of value are as follows:
+  #   1> "selected_and_all": save selected and all dataitems;
+  #   2> "selected_only": save selected dataitems;
+  save_mode: "selected_and_all"
+
+
+
+
+
diff --git a/examples/cifar100/fci_ssl/fedavg/test.py b/examples/cifar100/fci_ssl/fedavg/test.py
new file mode 100644
index 00000000..5b3cf860
--- /dev/null
+++ b/examples/cifar100/fci_ssl/fedavg/test.py
@@ -0,0 +1,14 @@
+from algorithm.resnet import resnet10
+from algorithm.network import NetWork, incremental_learning
+import copy
+import numpy as np
+fe = resnet10(10)
+model = NetWork(10, fe)
+new_model = copy.deepcopy(model)
+
+x = np.random.rand(1, 32, 32, 3)
+y = np.random.randint(0, 10, 1)
+model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+model.fit(x, y, epochs=1)
+new_model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
+new_model.fit(x, y, epochs=1)
\ No newline at end of file
diff --git a/examples/cifar100/fci_ssl/fedavg/testenv/__init__.py b/examples/cifar100/fci_ssl/fedavg/testenv/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/cifar100/fci_ssl/fedavg/testenv/acc.py b/examples/cifar100/fci_ssl/fedavg/testenv/acc.py
new file mode 100644
index 00000000..0f1eaf1c
--- /dev/null
+++ b/examples/cifar100/fci_ssl/fedavg/testenv/acc.py
@@ -0,0 +1,34 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
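Like its fed_ci_match_v2 counterpart, the metric below is registered with Sedna's `ClassFactory` under the alias "accuracy", which is the name `testenv.yaml` refers to. Assuming `get_cls` is the lookup side of that registry (an assumption about Sedna's API, not shown in this patch), resolution looks roughly like:

```python
from sedna.common.class_factory import ClassFactory, ClassType

# Hedged sketch: assumes ClassFactory.get_cls(type, alias) mirrors the
# ClassFactory.register(...) call used below.
metric_fn = ClassFactory.get_cls(ClassType.GENERAL, "accuracy")
```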
+
+import tensorflow as tf
+import numpy as np
+from sedna.common.class_factory import ClassFactory, ClassType
+
+__all__ = ["accuracy"]
+
+
+@ClassFactory.register(ClassType.GENERAL, alias="accuracy")
+def accuracy(y_true, y_pred, **kwargs):
+    y_pred_arr = [val for val in y_pred.values()]
+    y_true_arr = []
+    for i in range(len(y_pred_arr)):
+        y_true_arr.append(np.full(y_pred_arr[i].shape, int(y_true[i])))
+    y_pred = tf.cast(tf.convert_to_tensor(np.concatenate(y_pred_arr, axis=0)), tf.int64)
+    y_true = tf.cast(tf.convert_to_tensor(np.concatenate(y_true_arr, axis=0)), tf.int64)
+    total = tf.shape(y_true)[0]
+    correct = tf.reduce_sum(tf.cast(tf.equal(y_true, y_pred), tf.int32))
+    acc = float(int(correct) / int(total))
+    print(f"acc:{acc}")
+    return acc
diff --git a/examples/cifar100/fci_ssl/fedavg/testenv/testenv.yaml b/examples/cifar100/fci_ssl/fedavg/testenv/testenv.yaml
new file mode 100644
index 00000000..74eff7c0
--- /dev/null
+++ b/examples/cifar100/fci_ssl/fedavg/testenv/testenv.yaml
@@ -0,0 +1,38 @@
+testenv:
+  backend: "tensorflow"
+  dataset:
+    name: 'cifar100'
+    # the url address of train dataset index; string type;
+    train_url: "/home/wyd/ianvs/project/data/cifar100/cifar100_train.txt"
+    # the url address of test dataset index; string type;
+    test_url: "/home/wyd/ianvs/project/data/cifar100/cifar100_test.txt"
+
+
+  # network eval configuration of incremental learning;
+  model_eval:
+    # metric used for network evaluation
+    model_metric:
+      # metric name; string type;
+      name: "accuracy"
+      # the url address of python file
+      url: "/home/wyd/ianvs/project/ianvs/examples/cifar100/fci_ssl/fedavg/testenv/acc.py"
+
+    # condition of triggering inference network to update
+    # threshold of the condition; types are float/int
+    threshold: 0.01
+    # operator of the condition; string type;
+    # values are ">=", ">", "<=", "<" and "=";
+    operator: "<="
+
+  # metrics configuration for test case's evaluation; list type;
+  metrics:
+    # metric name; string type;
+    - name: "accuracy"
+      # the url address of python file
+      url: "/home/wyd/ianvs/project/ianvs/examples/cifar100/fci_ssl/fedavg/testenv/acc.py"
+    - name: "task_avg_acc"
+    - name: "forget_rate"
+  # incremental rounds setting of incremental learning; int type; default value is 2;
+  incremental_rounds: 2
+  round: 1
+  client_number: 1
\ No newline at end of file
diff --git a/examples/cifar100/fci_ssl/glfc/algorithm/GLFC.py b/examples/cifar100/fci_ssl/glfc/algorithm/GLFC.py
new file mode 100644
index 00000000..469ae720
--- /dev/null
+++ b/examples/cifar100/fci_ssl/glfc/algorithm/GLFC.py
@@ -0,0 +1,447 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
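`GLFC_Client.update_new_set` in the file below keeps a fixed-size exemplar memory and splits it evenly across all classes learned so far (`m = memory_size / learned_classes`). A one-line sketch of that budget:

```python
def exemplar_budget(memory_size: int, learned_classes: int) -> int:
    # Per-class exemplar count shrinks as more classes are learned.
    return memory_size // learned_classes

assert exemplar_budget(2000, 10) == 200
assert exemplar_budget(2000, 50) == 40
```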
+ +import copy +import numpy as np +import tensorflow as tf +import keras +import logging +from agumentation import * +from data_prepocessor import * +from model import resnet10 + + +def get_one_hot(target, num_classes): + y = tf.one_hot(target, depth=num_classes) + if len(y.shape) == 3: + y = tf.squeeze(y, axis=1) + return y + + +class GLFC_Client: + def __init__( + self, + num_classes, + batch_size, + task_size, + memory_size, + epochs, + learning_rate, + encode_model, + ): + self.epochs = epochs + self.learning_rate = learning_rate + + self.encode_model = encode_model + + self.num_classes = num_classes + logging.info(f"num_classes is {num_classes}") + self.batch_size = batch_size + self.task_size = task_size + + self.old_model = None + self.train_set = None + + self.exemplar_set = [] # + self.class_mean_set = [] + self.learned_classes = [] + self.learned_classes_numebr = 0 + self.memory_size = memory_size + + self.old_task_id = -1 + self.current_classes = None + self.last_class = None + self.train_loader = None + self.build_feature_extractor() + self.classifier = None + self.labeled_train_set = None + self.unlabeled_train_set = None + self.data_preprocessor = Dataset_Preprocessor( + "cifar100", Weak_Augment("cifar100"), RandAugment("cifar100") + ) + self.warm_up_epochs = 10 + + def build_feature_extractor(self): + self.feature_extractor = resnet10() + self.feature_extractor.build(input_shape=(None, 32, 32, 3)) + self.feature_extractor.call(keras.Input(shape=(32, 32, 3))) + self.feature_extractor.load_weights( + "examples/cifar100/fci_ssl/glfc/algorithm/feature_extractor.weights.h5" + ) + + def _initialize_classifier(self): + if self.classifier != None: + new_classifier = tf.keras.Sequential( + [ + tf.keras.layers.Dense( + self.num_classes, kernel_initializer="lecun_normal" + ) + ] + ) + new_classifier.build( + input_shape=(None, self.feature_extractor.layers[-2].output_shape[-1]) + ) + new_weights = new_classifier.get_weights() + old_weights = self.classifier.get_weights() + # weight + new_weights[0][0 : old_weights[0].shape[0], 0 : old_weights[0].shape[1]] = ( + old_weights[0] + ) + # bias + new_weights[1][0 : old_weights[1].shape[0]] = old_weights[1] + new_classifier.set_weights(new_weights) + self.classifier = new_classifier + else: + logging.info( + f"input shape is {self.feature_extractor.layers[-2].output_shape[-1]}" + ) + self.classifier = tf.keras.Sequential( + [ + tf.keras.layers.Dense( + self.num_classes, kernel_initializer="lecun_normal" + ) + ] + ) + self.classifier.build( + input_shape=(None, self.feature_extractor.layers[-2].output_shape[-1]) + ) + + logging.info(f"finish ! 
initialize classifier {self.classifier.summary()}") + + def before_train(self, task_id, train_data, class_learned, old_model): + logging.info(f"------before train task_id: {task_id}------") + self.need_update = task_id != self.old_task_id + if self.need_update: + self.old_task_id = task_id + self.num_classes = self.task_size * (task_id + 1) + if self.current_classes is not None: + self.last_class = self.current_classes + logging.info( + f"self.last_class is , {self.last_class}, {self.num_classes} tasksize is {self.task_size}, task_id is {task_id}" + ) + self._initialize_classifier() + self.current_classes = np.unique(train_data["label_y"]).tolist() + self.update_new_set(self.need_update) + self.labeled_train_set = (train_data["label_x"], train_data["label_y"]) + self.unlabeled_train_set = ( + train_data["unlabel_x"], + train_data["unlabel_y"], + ) + if len(old_model) != 0: + self.old_model = old_model[0] + self.labeled_train_set = (train_data["label_x"], train_data["label_y"]) + self.unlabeled_train_set = (train_data["unlabel_x"], train_data["unlabel_y"]) + self.labeled_train_loader, self.unlabeled_train_loader = self._get_train_loader( + True + ) + logging.info( + f"------finish before train task_id: {task_id} {self.current_classes}------" + ) + + def update_new_set(self, need_update): + if need_update and self.last_class is not None: + # update exemplar + self.learned_classes += self.last_class + self.learned_classes_numebr += len(self.last_class) + m = int(self.memory_size / self.learned_classes_numebr) + self._reduce_exemplar_set(m) + for i in self.last_class: + images = self.get_train_set_data(i) + self._construct_exemplar_set(images, i, m) + + def _get_train_loader(self, mix): + self.mean = np.array((0.5071, 0.4867, 0.4408), np.float32).reshape(1, 1, -1) + self.std = np.array((0.2675, 0.2565, 0.2761), np.float32).reshape(1, 1, -1) + train_x = self.labeled_train_set[0] + train_y = self.labeled_train_set[1] + if mix: + for exm_set in self.exemplar_set: + logging.info(f"mix the exemplar{len(exm_set[0])}, {len(exm_set[1])}") + label = np.array(exm_set[1]) + train_x = np.concatenate((train_x, exm_set[0]), axis=0) + train_y = np.concatenate((train_y, label), axis=0) + label_data_loader = self.data_preprocessor.preprocess_labeled_dataset( + train_x, train_y, self.batch_size + ) + unlabel_data_loader = None + if len(self.unlabeled_train_set[0]) > 0: + unlabel_data_loader = self.data_preprocessor.preprocess_unlabeled_dataset( + self.unlabeled_train_set[0], + self.unlabeled_train_set[1], + self.batch_size, + ) + logging.info( + f"unlabel_x shape: {self.unlabeled_train_set[0].shape} and unlabel_y shape: {self.unlabeled_train_set[1].shape}" + ) + return (label_data_loader, unlabel_data_loader) + + def train(self, round): + opt = keras.optimizers.Adam( + learning_rate=self.learning_rate, weight_decay=0.00001 + ) + feature_extractor_params = self.feature_extractor.trainable_variables + classifier_params = self.classifier.trainable_variables + all_params = [] + all_params.extend(feature_extractor_params) + all_params.extend(classifier_params) + + for epoch in range(self.epochs): + # following code is for semi-supervised learning + # for labeled_data, unlabeled_data in zip( + # self.labeled_train_loader, self.unlabeled_train_loader + # ): + # labeled_x, labeled_y = labeled_data + # unlabeled_x, weak_unlabeled_x, strong_unlabeled_x, unlabeled_y = ( + # unlabeled_data + # ) + + # following code is for supervised learning + for step, (x, y) in enumerate(self.labeled_train_loader): + # opt = 
keras.optimizers.SGD(learning_rate=self.learning_rate, weight_decay=0.00001) + with tf.GradientTape() as tape: + supervised_loss = self._compute_loss(x, y) + loss = supervised_loss + + # following code is for semi-supervised learning + # if epoch > self.warm_up_epochs: + # unsupervised_loss = self.unsupervised_loss( + # weak_unlabeled_x, strong_unlabeled_x + # ) + # loss = loss + 0.5 * unsupervised_loss + # logging.info( + # f"supervised loss is {supervised_loss} unsupervised loss is {unsupervised_loss}" + # ) + logging.info( + f"------round{round} epoch{epoch} loss: {loss} and loss dim is {loss.shape}------" + ) + grads = tape.gradient(loss, all_params) + opt.apply_gradients(zip(grads, all_params)) + + logging.info(f"------finish round{round} traning------") + + def model_call(self, x, training=False): + input = self.feature_extractor(inputs=x, training=training) + return self.classifier(inputs=input, training=training) + + def _compute_loss(self, imgs, labels): + logging.info(f"self.old_model is available: {self.old_model is not None}") + y_pred = self.model_call(imgs, training=True) + target = get_one_hot(labels, self.num_classes) + logits = y_pred + pred = tf.argmax(logits, axis=1) + pred = tf.cast(pred, dtype=tf.int32) + pred = tf.reshape(pred, labels.shape) + + y = tf.cast(labels, dtype=tf.int32) + correct = tf.cast(tf.equal(pred, y), dtype=tf.int32) + correct = tf.reduce_sum(correct) + logging.info( + f"current class numbers is {self.num_classes} correct is {correct} and acc is {correct/imgs.shape[0]} tasksize is {self.task_size} self.old_task_id {self.old_task_id}" + ) + if self.old_model == None: + w = self.efficient_old_class_weight(target, labels) + loss = tf.reduce_mean( + keras.losses.categorical_crossentropy(target, y_pred, from_logits=True) + * w + ) + logging.info( + f"in _compute_loss, without old model loss is {loss} and shape is {loss.shape}" + ) + return loss + else: + w = self.efficient_old_class_weight(target, labels) + loss = tf.reduce_mean( + keras.losses.binary_crossentropy(target, y_pred, from_logits=True) * w + ) + distill_target = tf.Variable(get_one_hot(labels, self.num_classes)) + old_target = tf.sigmoid(self.old_model[1](self.old_model[0]((imgs)))) + old_task_size = old_target.shape[1] + distill_target[:, :old_task_size].assign(old_target) + loss_old = tf.reduce_mean( + keras.losses.binary_crossentropy( + distill_target, y_pred, from_logits=True + ) + ) + logging.info(f"loss old is {loss_old}") + return loss + loss_old + + def unsupervised_loss(self, weak_x, strong_x): + self.accept_threshold = 0.95 + prob_on_wux = tf.nn.softmax( + self.classifier( + self.feature_extractor(weak_x, training=True), training=True + ) + ) + pseudo_mask = tf.cast( + (tf.reduce_max(prob_on_wux, axis=1) >= self.accept_threshold), tf.float32 + ) + pse_uy = tf.one_hot( + tf.argmax(prob_on_wux, axis=1), depth=self.num_classes + ).numpy() + prob_on_sux = tf.nn.softmax( + self.classifier( + self.feature_extractor(strong_x, training=True), training=True + ) + ) + loss = keras.losses.categorical_crossentropy(pse_uy, prob_on_sux) + loss = tf.reduce_mean(loss * pseudo_mask) + return loss + + def efficient_old_class_weight(self, output, labels): + pred = tf.sigmoid(output) + N, C = pred.shape + class_mask = tf.zeros([N, C], dtype=tf.float32) + class_mask = tf.Variable(class_mask) + ids = np.zeros([N, 2], dtype=np.int32) + for i in range(N): + ids[i][0] = i + ids[i][1] = labels[i] + updates = tf.ones([N], dtype=tf.float32) + class_mask = tf.tensor_scatter_nd_update(class_mask, ids, updates) + 
target = get_one_hot(labels, self.num_classes)
+        g = tf.abs(target - pred)
+        g = tf.reduce_sum(g * class_mask, axis=1)
+        idx = tf.cast(tf.reshape(labels, (-1, 1)), tf.int32)
+        if len(self.learned_classes) != 0:
+            for i in self.learned_classes:
+                mask = tf.math.not_equal(idx, i)
+                negative_value = tf.cast(tf.fill(tf.shape(idx), -1), tf.int32)
+                idx = tf.where(mask, idx, negative_value)
+            index1 = tf.cast(tf.equal(idx, -1), tf.float32)
+            index2 = tf.cast(tf.not_equal(idx, -1), tf.float32)
+            w1 = tf.where(
+                tf.not_equal(tf.reduce_sum(index1), 0),
+                tf.math.divide(
+                    g * index1, (tf.reduce_sum(g * index1) / tf.reduce_sum(index1))
+                ),
+                tf.zeros_like(g),
+            )
+            w2 = tf.where(
+                tf.not_equal(tf.reduce_sum(index2), 0),
+                tf.math.divide(
+                    g * index2, (tf.reduce_sum(g * index2) / tf.reduce_sum(index2))
+                ),
+                tf.zeros_like(g),
+            )
+            w = w1 + w2
+            return w
+        else:
+            return tf.ones(g.shape, dtype=tf.float32)
+
+    def get_train_set_data(self, class_id):
+        images = []
+        train_x = self.labeled_train_set[0]
+        train_y = self.labeled_train_set[1]
+        for i in range(len(train_x)):
+            if train_y[i] == class_id:
+                images.append(train_x[i])
+        return images
+
+    def get_data_size(self):
+        logging.info(
+            f"self.labeled_train_set is None :{self.labeled_train_set is None}"
+        )
+        logging.info(
+            f"self.unlabeled_train_set is None :{self.unlabeled_train_set is None}"
+        )
+        data_size = len(self.labeled_train_set[0])
+        logging.info(f"data size: {data_size}")
+        return data_size
+
+    def _reduce_exemplar_set(self, m):
+        for i in range(len(self.exemplar_set)):
+            old_exemplar_data = self.exemplar_set[i][0][:m]
+            old_exemplar_label = self.exemplar_set[i][1][:m]
+            self.exemplar_set[i] = (old_exemplar_data, old_exemplar_label)
+
+    def _construct_exemplar_set(self, images, label, m):
+        # herding selection: greedily pick the samples whose running mean
+        # best approximates the class mean in feature space
+        class_mean, fe_output = self.compute_class_mean(images)
+        exemplar = []
+        labels = []
+        now_class_mean = np.zeros((1, 512))
+        for i in range(m):
+            x = class_mean - (now_class_mean + fe_output) / (i + 1)
+            # per-sample distances; without axis=1 the norm collapses to a
+            # scalar and argmin would always return index 0
+            x = np.linalg.norm(x, axis=1)
+            index = np.argmin(x)
+            now_class_mean += fe_output[index]
+            exemplar.append(images[index])
+            labels.append(label)
+        self.exemplar_set.append((exemplar, labels))
+
+    def compute_class_mean(self, images):
+        images_data = tf.data.Dataset.from_tensor_slices(images).batch(self.batch_size)
+        fe_output = self.feature_extractor.predict(images_data)
+        fe_output = tf.nn.l2_normalize(fe_output).numpy()
+        class_mean = tf.reduce_mean(fe_output, axis=0)
+        return class_mean, fe_output
+
+    def proto_grad(self):
+        if not self.need_update:
+            return None
+        self.need_update = False
+        cri_loss = keras.losses.SparseCategoricalCrossentropy()
+        proto = []
+        proto_grad = []
+        logging.info(f"self.
current class is {self.current_classes}") + for i in self.current_classes: + images = self.get_train_set_data(i) + class_mean, fe_output = self.compute_class_mean(images) + dis = np.linalg.norm(class_mean - fe_output, axis=1) + pro_index = np.argmin(dis) + proto.append(images[pro_index]) + + for i in range(len(proto)): + data = proto[i] + data = tf.cast(tf.expand_dims(data, axis=0), tf.float32) + label = self.current_classes[i] + label = tf.constant([label]) + target = get_one_hot(label, self.num_classes) + logging.info( + f"proto_grad target shape is {target.shape} and num_classes is {self.num_classes}" + ) + proto_fe = resnet10() + proto_fe.build(input_shape=(None, 32, 32, 3)) + proto_fe.call(keras.Input(shape=(32, 32, 3))) + proto_fe.set_weights(self.feature_extractor.get_weights()) + proto_clf = copy.deepcopy(self.classifier) + proto_param = proto_fe.trainable_variables + proto_param.extend(proto_clf.trainable_variables) + with tf.GradientTape() as tape: + outputs = self.encode_model(data) + loss_cls = cri_loss(label, outputs) + dy_dx = tape.gradient(loss_cls, self.encode_model.trainable_variables) + original_dy_dx = [tf.identity(grad) for grad in dy_dx] + proto_grad.append(original_dy_dx) + return proto_grad + + def evaluate(self): + logging.info("evaluate") + total_num = 0 + total_correct = 0 + for x, y in self.train_loader: + logits = self.model_call(x, training=False) + pred = tf.argmax(logits, axis=1) + pred = tf.cast(pred, dtype=tf.int32) + pred = tf.reshape(pred, y.shape) + logging.info(pred) + y = tf.cast(y, dtype=tf.int32) + correct = tf.cast(tf.equal(pred, y), dtype=tf.int32) + correct = tf.reduce_sum(correct) + total_num += x.shape[0] + total_correct += int(correct) + acc = total_correct / total_num + del total_correct + logging.info(f"finsih task {self.old_task_id} evaluate, acc: {acc}") + return acc diff --git a/examples/cifar100/fci_ssl/glfc/algorithm/aggregation.py b/examples/cifar100/fci_ssl/glfc/algorithm/aggregation.py new file mode 100644 index 00000000..a61c791b --- /dev/null +++ b/examples/cifar100/fci_ssl/glfc/algorithm/aggregation.py @@ -0,0 +1,82 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
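
Review note: in GLFC.py above, `proto_grad` picks one prototype image per current class (the sample whose feature vector lies closest to the class mean) and shares only the encode-model gradients computed on it. The selection step reduces to a nearest-to-mean search; a sketch under that reading (names illustrative, not part of this patch):

    import numpy as np

    def nearest_to_mean(features: np.ndarray) -> int:
        # index of the sample whose feature vector is closest to the class mean
        class_mean = features.mean(axis=0)
        return int(np.argmin(np.linalg.norm(features - class_mean, axis=1)))

    feats = np.array([[0.0, 0.0], [1.0, 1.0], [0.4, 0.6]])
    print(nearest_to_mean(feats))  # 2
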
+ +import abc +from copy import deepcopy +from typing import List + +import numpy as np +from sedna.algorithms.aggregation.aggregation import BaseAggregation +from sedna.common.class_factory import ClassType, ClassFactory +from proxy_server import ProxyServer + + +@ClassFactory.register(ClassType.FL_AGG, "FedAvg") +class FedAvg(BaseAggregation, abc.ABC): + def __init__(self): + super(FedAvg, self).__init__() + self.proxy_server = ProxyServer( + learning_rate=0.01, num_classes=10, test_data=None + ) + self.task_id = -1 + self.num_classes = 10 + + def aggregate(self, clients): + """ + Calculate the average weight according to the number of samples + + Parameters + ---------- + clients: List + All clients in federated learning job + + Returns + ------- + update_weights : Array-like + final weights use to update model layer + """ + + print("aggregation....") + if not len(clients): + return self.weights + self.total_size = sum([c.num_samples for c in clients]) + old_weight = [np.zeros(np.array(c).shape) for c in next(iter(clients)).weights] + updates = [] + for inx, row in enumerate(old_weight): + for c in clients: + row += np.array(c.weights[inx]) * c.num_samples / self.total_size + updates.append(row.tolist()) + + self.weights = [np.array(layer) for layer in updates] + + print("finish aggregation....") + return self.weights + + def helper_function(self, train_infos, **kwargs): + proto_grad = [] + task_id = -1 + for key, value in train_infos.items(): + if "proto_grad" == key and value is not None: + for grad_i in value: + proto_grad.append(grad_i) + if "task_id" == key: + task_id = max(value, task_id) + self.proxy_server.dataload(proto_grad) + if task_id > self.task_id: + self.task_id = task_id + print(f"incremental num classes is {self.num_classes * (task_id + 1)}") + self.proxy_server.increment_class(self.num_classes * (task_id + 1)) + self.proxy_server.set_weights(self.weights) + print(f"finish set weight for proxy server") + return {"best_old_model": self.proxy_server.model_back()} diff --git a/examples/cifar100/fci_ssl/glfc/algorithm/agumentation.py b/examples/cifar100/fci_ssl/glfc/algorithm/agumentation.py new file mode 100644 index 00000000..89d1bef2 --- /dev/null +++ b/examples/cifar100/fci_ssl/glfc/algorithm/agumentation.py @@ -0,0 +1,230 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
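
Review note: `FedAvg.aggregate` above is a sample-weighted mean over the clients' weight lists, layer by layer. Restated in plain NumPy (a sketch, assuming each client contributes a weight list plus its sample count; the helper name is hypothetical):

    import numpy as np

    def fedavg(client_weights, client_sizes):
        # weighted mean of layer i across clients, weights proportional
        # to each client's sample count
        total = sum(client_sizes)
        return [
            sum(np.asarray(w[i]) * n / total
                for w, n in zip(client_weights, client_sizes))
            for i in range(len(client_weights[0]))
        ]

    w_a = [np.zeros((2, 2)), np.zeros(2)]
    w_b = [np.ones((2, 2)), np.ones(2)]
    print(fedavg([w_a, w_b], [1, 3])[0])  # every entry (0*1 + 1*3)/4 = 0.75
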
+ +import numpy as np +import random +import tensorflow as tf +from PIL import Image, ImageEnhance, ImageOps + + +""" +Reference: https://github.com/heartInsert/randaugment +""" + + +class Rand_Augment: + def __init__(self, Numbers=None, max_Magnitude=None): + self.transforms = [ + "autocontrast", + "equalize", + "rotate", + "solarize", + "color", + "posterize", + "contrast", + "brightness", + "sharpness", + "shearX", + "shearY", + "translateX", + "translateY", + ] + if Numbers is None: + self.Numbers = len(self.transforms) // 2 + else: + self.Numbers = Numbers + if max_Magnitude is None: + self.max_Magnitude = 10 + else: + self.max_Magnitude = max_Magnitude + fillcolor = 128 + self.ranges = { + # these Magnitude range , you must test it yourself , see what will happen after these operation , + # it is no need to obey the value in autoaugment.py + "shearX": np.linspace(0, 0.3, 10), + "shearY": np.linspace(0, 0.3, 10), + "translateX": np.linspace(0, 0.2, 10), + "translateY": np.linspace(0, 0.2, 10), + "rotate": np.linspace(0, 360, 10), + "color": np.linspace(0.0, 0.9, 10), + "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int), + "solarize": np.linspace(256, 231, 10), + "contrast": np.linspace(0.0, 0.5, 10), + "sharpness": np.linspace(0.0, 0.9, 10), + "brightness": np.linspace(0.0, 0.3, 10), + "autocontrast": [0] * 10, + "equalize": [0] * 10, + "invert": [0] * 10, + } + self.func = { + "shearX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), + Image.BICUBIC, + fill=fillcolor, + ), + "shearY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), + Image.BICUBIC, + fill=fillcolor, + ), + "translateX": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), + fill=fillcolor, + ), + "translateY": lambda img, magnitude: img.transform( + img.size, + Image.AFFINE, + (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), + fill=fillcolor, + ), + "rotate": lambda img, magnitude: self.rotate_with_fill(img, magnitude), + # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance( + 1 + magnitude * random.choice([-1, 1]) + ), + "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), + "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), + "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( + 1 + magnitude * random.choice([-1, 1]) + ), + "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( + 1 + magnitude * random.choice([-1, 1]) + ), + "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( + 1 + magnitude * random.choice([-1, 1]) + ), + "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), + "equalize": lambda img, magnitude: img, + "invert": lambda img, magnitude: ImageOps.invert(img), + } + + def rand_augment(self): + """Generate a set of distortions. + Args: + N: Number of augmentation transformations to apply sequentially. N is len(transforms)/2 will be best + M: Max_Magnitude for all the transformations. 
should be <= self.max_Magnitude + """ + + M = np.random.randint(0, self.max_Magnitude, self.Numbers) + + sampled_ops = np.random.choice(self.transforms, self.Numbers) + return [(op, Magnitude) for (op, Magnitude) in zip(sampled_ops, M)] + + def __call__(self, image): + operations = self.rand_augment() + for op_name, M in operations: + operation = self.func[op_name] + mag = self.ranges[op_name][M] + image = operation(image, mag) + return image + + def rotate_with_fill(self, img, magnitude): + # I don't know why rotate must change to RGBA , it is copy from Autoaugment - pytorch + rot = img.convert("RGBA").rotate(magnitude) + return Image.composite( + rot, Image.new("RGBA", rot.size, (128,) * 4), rot + ).convert(img.mode) + + def test_single_operation(self, image, op_name, M=-1): + """ + :param image: image + :param op_name: operation name in self.transforms + :param M: -1 stands for the max Magnitude in there operation + :return: + """ + operation = self.func[op_name] + mag = self.ranges[op_name][M] + image = operation(image, mag) + return image + + +class Base_Augment: + def __init__(self, dataset_name: str) -> None: + self.dataset_name = dataset_name + + def __call__(self, images): + return images + + +class Weak_Augment(Base_Augment): + def __init__(self, dataset_name: str) -> None: + super().__init__(dataset_name) + self.augment_impl = self.augment_for_cifar + + def augment_mirror(self, x): + new_images = x.copy() + indices = np.arange(len(new_images)).tolist() + sampled = random.sample( + indices, int(round(0.5 * len(indices))) + ) # flip horizontally 50% + new_images[sampled] = np.fliplr(new_images[sampled]) + return new_images # random shift + + def augment_shift(self, x, w): + y = tf.pad(x, [[0] * 2, [w] * 2, [w] * 2, [0] * 2], mode="REFLECT") + return tf.image.random_crop(y, tf.shape(x)) + + def augment_for_cifar(self, images: np.ndarray): + return self.augment_shift(self.augment_mirror(images), 4) + + def __call__(self, images: np.ndarray): + return self.augment_impl(images) + + +class Strong_Augment(Base_Augment): + def __init__(self, dataset_name: str) -> None: + super().__init__(dataset_name) + + def augment_mirror(self, x): + new_images = x.copy() + indices = np.arange(len(new_images)).tolist() + sampled = random.sample( + indices, int(round(0.5 * len(indices))) + ) # flip horizontally 50% + new_images[sampled] = np.fliplr(new_images[sampled]) + return new_images # random shift + + def augment_shift_mnist(self, x, w): + y = tf.pad(x, [[0] * 2, [w] * 2, [w] * 2], mode="REFLECT") + return tf.image.random_crop(y, tf.shape(x)) + + def __call__(self, images: np.ndarray): + return self.augment_shift_mnist(self.augment_mirror(images), 4) + + +class RandAugment(Base_Augment): + def __init__(self, dataset_name: str) -> None: + super().__init__(dataset_name) + self.rand_augment = Rand_Augment() + self.input_shape = (32, 32, 3) + + def __call__(self, images): + print("images:", images.shape) + + return np.array( + [ + np.array( + self.rand_augment( + Image.fromarray(np.reshape(img, self.input_shape)) + ) + ) + for img in images + ] + ) diff --git a/examples/cifar100/fci_ssl/glfc/algorithm/algorithm.yaml b/examples/cifar100/fci_ssl/glfc/algorithm/algorithm.yaml new file mode 100644 index 00000000..f09a6e86 --- /dev/null +++ b/examples/cifar100/fci_ssl/glfc/algorithm/algorithm.yaml @@ -0,0 +1,49 @@ +algorithm: + # paradigm name; string type; + # currently the options of value are as follows: + # 1> "singletasklearning" + # 2> "incrementallearning" + paradigm_type: 
"federatedclassincrementallearning" + fl_data_setting: + # ratio of training dataset; float type; + # the default value is 0.8. + train_ratio: 1.0 + # the method of splitting dataset; string type; optional; + # currently the options of value are as follows: + # 1> "default": the dataset is evenly divided based train_ratio; + splitting_method: "default" + label_data_ratio: 1.0 + data_partition: "iid" + # the url address of initial network for network pre-training; string url; + # the url address of initial network; string type; optional; + initial_model_url: "/home/wyd/ianvs/project/init_model/cnn.pb" + # algorithm module configuration in the paradigm; list type; + # incremental rounds setting of incremental learning; int type; default value is 2; + + modules: + # kind of algorithm module; string type; + # currently the options of value are as follows: + # 1> "basemodel" + - type: "basemodel" + # name of python module; string type; + # example: basemodel.py has BaseModel module that the alias is "FPN" for this benchmarking; + name: "GLFCMatch-Client" + # the url address of python module; string type; + url: "./examples/cifar100/fci_ssl/glfc/algorithm/basemodel.py" + + # hyperparameters configuration for the python module; list type; + hyperparameters: + # name of the hyperparameter; string type; + - batch_size: + values: + - 64 + - learning_rate: + values: + - 0.001 + - epochs: + values: + - 20 + - type: "aggregation" + name: "FedAvg" + url: "./examples/cifar100/fci_ssl/glfc/algorithm/aggregation.py" + diff --git a/examples/cifar100/fci_ssl/glfc/algorithm/basemodel.py b/examples/cifar100/fci_ssl/glfc/algorithm/basemodel.py new file mode 100644 index 00000000..e459e8ee --- /dev/null +++ b/examples/cifar100/fci_ssl/glfc/algorithm/basemodel.py @@ -0,0 +1,106 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import numpy as np +import keras +import tensorflow as tf +from sedna.common.class_factory import ClassType, ClassFactory +from model import resnet10, lenet5 +from GLFC import GLFC_Client +import logging + +os.environ["BACKEND_TYPE"] = "KERAS" +__all__ = ["BaseModel"] +logging.getLogger().setLevel(logging.INFO) + + +@ClassFactory.register(ClassType.GENERAL, alias="GLFCMatch-Client") +class BaseModel: + def __init__(self, **kwargs) -> None: + self.kwargs = kwargs + self.learning_rate = kwargs.get("learning_rate", 0.001) + self.epochs = kwargs.get("epochs", 1) + self.batch_size = kwargs.get("batch_size", 32) + self.task_size = kwargs.get("task_size", 10) + self.memory_size = kwargs.get("memory_size", 2000) + self.encode_model = lenet5(32, 100) + self.encode_model.call(keras.Input(shape=(32, 32, 3))) + self.num_classes = 10 # the number of class for the first task + self.GLFC_Client = GLFC_Client( + self.num_classes, + self.batch_size, + self.task_size, + self.memory_size, + self.epochs, + self.learning_rate, + self.encode_model, + ) + self.best_old_model = [] + self.class_learned = 0 + self.fe_weights_length = len(self.GLFC_Client.feature_extractor.get_weights()) + + def get_weights(self): + print("get weights") + weights = [] + fe_weights = self.GLFC_Client.feature_extractor.get_weights() + clf_weights = self.GLFC_Client.classifier.get_weights() + weights.extend(fe_weights) + weights.extend(clf_weights) + return weights + + def set_weights(self, weights): + print("set weights") + fe_weights = weights[: self.fe_weights_length] + + clf_weights = weights[self.fe_weights_length :] + self.GLFC_Client.feature_extractor.set_weights(fe_weights) + self.GLFC_Client.classifier.set_weights(clf_weights) + + def train(self, train_data, val_data, **kwargs): + task_id = kwargs.get("task_id", 0) + round = kwargs.get("round", 1) + logging.info(f"in train: {round} task_id: {task_id}") + self.class_learned += self.task_size + self.GLFC_Client.before_train( + task_id, train_data, self.class_learned, old_model=self.best_old_model + ) + + self.GLFC_Client.train(round) + proto_grad = self.GLFC_Client.proto_grad() + return { + "num_samples": self.GLFC_Client.get_data_size(), + "proto_grad": proto_grad, + "task_id": task_id, + } + + def helper_function(self, helper_info, **kwargs): + self.best_old_model = helper_info["best_old_model"] + if self.best_old_model[1] != None: + self.GLFC_Client.old_model = self.best_old_model[1] + else: + self.GLFC_Client.old_model = self.best_old_model[0] + + def predict(self, datas, **kwargs): + result = {} + mean = np.array((0.5071, 0.4867, 0.4408), np.float32).reshape(1, 1, -1) + std = np.array((0.2675, 0.2565, 0.2761), np.float32).reshape(1, 1, -1) + for data in datas: + x = np.load(data) + x = (tf.cast(x, dtype=tf.float32) / 255.0 - mean) / std + logits = self.GLFC_Client.model_call(x, training=False) + pred = tf.cast(tf.argmax(logits, axis=1), tf.int32) + result[data] = pred.numpy() + print("finish predict") + return result diff --git a/examples/cifar100/fci_ssl/glfc/algorithm/data_prepocessor.py b/examples/cifar100/fci_ssl/glfc/algorithm/data_prepocessor.py new file mode 100644 index 00000000..449e09f2 --- /dev/null +++ b/examples/cifar100/fci_ssl/glfc/algorithm/data_prepocessor.py @@ -0,0 +1,65 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf +import numpy as np +from agumentation import Base_Augment + + +class Dataset_Preprocessor: + def __init__( + self, + dataset_name: str, + weak_augment_helper: Base_Augment, + strong_augment_helper: Base_Augment, + ) -> None: + self.weak_augment_helper = weak_augment_helper + self.strong_augment_helper = strong_augment_helper + self.mean = 0.0 + self.std = 1.0 + if dataset_name == "cifar100": + self.mean = np.array((0.5071, 0.4867, 0.4408), np.float32).reshape(1, 1, -1) + self.std = np.array((0.2675, 0.2565, 0.2761), np.float32).reshape(1, 1, -1) + print(f"mean: {self.mean}, std: {self.std}") + + def preprocess_labeled_dataset(self, x, y, batch_size): + return ( + tf.data.Dataset.from_tensor_slices((x, y)) + .shuffle(100000) + .map( + lambda x, y: ( + (tf.cast(x, dtype=tf.float32) / 255.0 - self.mean) / self.std, + tf.cast(y, dtype=tf.int32), + ) + ) + .batch(batch_size) + ) + + def preprocess_unlabeled_dataset(self, ux, uy, batch_size): + + wux = self.weak_augment_helper(ux) + sux = self.strong_augment_helper(ux) + return ( + tf.data.Dataset.from_tensor_slices((ux, wux, sux, uy)) + .shuffle(100000) + .map( + lambda ux, wux, sux, uy: ( + (tf.cast(ux, dtype=tf.float32) / 255.0 - self.mean) / self.std, + (tf.cast(wux, dtype=tf.float32) / 255.0 - self.mean) / self.std, + (tf.cast(sux, dtype=tf.float32) / 255.0 - self.mean) / self.std, + tf.cast(uy, dtype=tf.int32), + ) + ) + .batch(batch_size) + ) diff --git a/examples/cifar100/fci_ssl/glfc/algorithm/feature_extractor.weights.h5 b/examples/cifar100/fci_ssl/glfc/algorithm/feature_extractor.weights.h5 new file mode 100644 index 00000000..d216d072 Binary files /dev/null and b/examples/cifar100/fci_ssl/glfc/algorithm/feature_extractor.weights.h5 differ diff --git a/examples/cifar100/fci_ssl/glfc/algorithm/model.py b/examples/cifar100/fci_ssl/glfc/algorithm/model.py new file mode 100644 index 00000000..8498eba4 --- /dev/null +++ b/examples/cifar100/fci_ssl/glfc/algorithm/model.py @@ -0,0 +1,205 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
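
Review note: both pipelines in data_prepocessor.py above scale pixels to [0, 1] and then standardize with CIFAR-100 channel statistics. The same transform in NumPy, using the constants from the file (the `normalize` helper is hypothetical):

    import numpy as np

    MEAN = np.array([0.5071, 0.4867, 0.4408], np.float32).reshape(1, 1, 1, -1)
    STD = np.array([0.2675, 0.2565, 0.2761], np.float32).reshape(1, 1, 1, -1)

    def normalize(batch_uint8: np.ndarray) -> np.ndarray:
        # (x / 255 - mean) / std, broadcast over an NHWC batch
        return (batch_uint8.astype(np.float32) / 255.0 - MEAN) / STD

    x = np.zeros((2, 32, 32, 3), np.uint8)
    print(normalize(x).mean(axis=(0, 1, 2)))  # ~[-1.90, -1.90, -1.60]
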
+ +import tensorflow as tf +import keras + + +# Input--conv2D--BN--ReLU--conv2D--BN--ReLU--Output +# \ / +# ------------------------------ +class BasicBlock(keras.layers.Layer): + def __init__(self, filter_num, stride=1): + super(BasicBlock, self).__init__() + + self.conv1 = keras.layers.Conv2D( + filter_num, (3, 3), strides=stride, padding="same" + ) + self.bn1 = keras.layers.BatchNormalization() + self.relu = keras.layers.Activation("relu") + + self.conv2 = keras.layers.Conv2D(filter_num, (3, 3), strides=1, padding="same") + self.bn2 = keras.layers.BatchNormalization() + + if stride != 1: + self.downsample = keras.models.Sequential() + self.downsample.add(keras.layers.Conv2D(filter_num, (1, 1), strides=stride)) + else: + self.downsample = lambda x: x + + def call(self, inputs, training=None): + # [b, h, w, c] + out = self.conv1(inputs) + out = self.bn1(out, training=training) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out, training=training) + + identity = self.downsample(inputs) + + output = keras.layers.add([out, identity]) + output = tf.nn.relu(output) + + return output + + +# restnet +class ResNet(keras.Model): + def __init__(self, layer_dims): # [2, 2, 2, 2] + super(ResNet, self).__init__() + self.layer_dims = layer_dims + + self.stem = keras.models.Sequential( + [ + keras.layers.Conv2D(64, (3, 3), strides=(1, 1)), + keras.layers.BatchNormalization(), + keras.layers.Activation("relu"), + keras.layers.MaxPool2D( + pool_size=(2, 2), strides=(1, 1), padding="same" + ), + ] + ) + + self.layer1 = self.build_resblock(64, layer_dims[0]) + self.layer2 = self.build_resblock(128, layer_dims[1], stride=2) + self.layer3 = self.build_resblock(256, layer_dims[2], stride=2) + self.layer4 = self.build_resblock(512, layer_dims[3], stride=2) + + # output: [b, 512, h, w], + self.avgpool = keras.layers.GlobalAveragePooling2D() + + def call(self, inputs, training=None): + x = self.stem(inputs, training=training) + + x = self.layer1(x, training=training) + x = self.layer2(x, training=training) + x = self.layer3(x, training=training) + x = self.layer4(x, training=training) + x = self.avgpool(x) + return x + + def build_resblock(self, filter_num, blocks, stride=1): + res_blocks = keras.models.Sequential() + # may down sample + res_blocks.add(BasicBlock(filter_num, stride)) + for _ in range(1, blocks): + res_blocks.add(BasicBlock(filter_num, stride=1)) + return res_blocks + + def get_config(self): + return { + "layer_dims": self.layer_dims, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +class LeNet(keras.Model): + def __init__(self, input_shape, channels=3, num_classes=10): + super(LeNet, self).__init__() + self.input_shape = input_shape + self.channels = channels + self.num_classes = num_classes + + self.conv1 = keras.layers.Conv2D( + 6, + kernel_size=5, + strides=1, + activation="relu", + input_shape=(input_shape, input_shape, channels), + kernel_initializer=keras.initializers.RandomUniform( + minval=-0.5, maxval=0.5 + ), + bias_initializer=keras.initializers.RandomUniform(minval=-0.5, maxval=0.5), + ) + + self.pool1 = keras.layers.MaxPool2D(pool_size=2, strides=2) + self.conv2 = keras.layers.Conv2D( + 16, + kernel_size=5, + strides=1, + activation="relu", + kernel_initializer=keras.initializers.RandomUniform( + minval=-0.5, maxval=0.5 + ), + bias_initializer=keras.initializers.RandomUniform(minval=-0.5, maxval=0.5), + ) + self.pool2 = keras.layers.MaxPool2D(pool_size=2, strides=2) + self.flatten = keras.layers.Flatten() + + self.fc1 = 
keras.layers.Dense( + 120, + activation="relu", + kernel_initializer=keras.initializers.RandomUniform( + minval=-0.5, maxval=0.5 + ), + bias_initializer=keras.initializers.RandomUniform(minval=-0.5, maxval=0.5), + ) + self.fc2 = keras.layers.Dense( + 84, + activation="relu", + kernel_initializer=keras.initializers.RandomUniform( + minval=-0.5, maxval=0.5 + ), + bias_initializer=keras.initializers.RandomUniform(minval=-0.5, maxval=0.5), + ) + self.fc3 = keras.layers.Dense( + num_classes, + activation="softmax", + kernel_initializer=keras.initializers.RandomUniform( + minval=-0.5, maxval=0.5 + ), + bias_initializer=keras.initializers.RandomUniform(minval=-0.5, maxval=0.5), + ) + + def call(self, inputs, training=None): + x = self.conv1(inputs) + x = self.pool1(x) + x = self.conv2(x) + x = self.pool2(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + return x + + def get_config(self): + return { + "input_shape": self.input_shape, + "channels": self.channels, + "num_classes": self.num_classes, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +def lenet5(input_shape, num_classes: int): + return LeNet(input_shape, 3, num_classes) + + +def resnet10(): + return ResNet([1, 1, 1, 1]) + + +def resnet18(num_classes: int): + return ResNet([2, 2, 2, 2]) + + +def resnet34(num_classes: int): + return ResNet([3, 4, 6, 3]) diff --git a/examples/cifar100/fci_ssl/glfc/algorithm/proxy_server.py b/examples/cifar100/fci_ssl/glfc/algorithm/proxy_server.py new file mode 100644 index 00000000..69e03901 --- /dev/null +++ b/examples/cifar100/fci_ssl/glfc/algorithm/proxy_server.py @@ -0,0 +1,201 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
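
Review note: `_initialize_classifier` in GLFC.py (and again in proxy_server.py below) grows the Dense head for new classes by copying the old kernel and bias into the leading slice of a freshly initialized, larger layer, so logits for already-learned classes are preserved. The copy step in isolation (a sketch with hypothetical names):

    import numpy as np

    def grow_head(old_kernel, old_bias, new_kernel, new_bias):
        # keep old-class weights, leave new-class weights at their fresh init
        k, b = new_kernel.copy(), new_bias.copy()
        k[: old_kernel.shape[0], : old_kernel.shape[1]] = old_kernel
        b[: old_bias.shape[0]] = old_bias
        return k, b

    k, b = grow_head(np.ones((4, 10)), np.ones(10), np.zeros((4, 20)), np.zeros(20))
    print(k[:, :10].min(), k[:, 10:].max())  # 1.0 0.0
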
+ +import keras +import copy +import numpy as np +import tensorflow as tf +import logging +from model import resnet10, resnet18, resnet34, lenet5 + +logging.getLogger().setLevel(logging.INFO) + + +class ProxyData: + def __init__(self): + self.test_data = [] + self.test_label = [] + + +class ProxyServer: + def __init__(self, learning_rate, num_classes, **kwargs): + self.learning_rate = learning_rate + + self.encode_model = lenet5(32, 100) + + self.monitor_dataset = ProxyData() + self.new_set = [] + self.new_set_label = [] + self.num_classes = num_classes + self.proto_grad = None + + self.best_perf = 0 + + self.num_image = 20 + self.Iteration = 250 + + self.build_model() + self.fe_weights_length = len(self.feature_extractor.get_weights()) + self.classifier = None + self.best_model_1 = None + self.best_model_2 = None + + def build_model(self): + self.feature_extractor = resnet10() + self.feature_extractor.build(input_shape=(None, 32, 32, 3)) + self.feature_extractor.call(keras.Input(shape=(32, 32, 3))) + + def set_weights(self, weights): + print(f"set weights {self.num_classes}") + fe_weights = weights[: self.fe_weights_length] + clf_weights = weights[self.fe_weights_length :] + self.feature_extractor.set_weights(fe_weights) + self.classifier.set_weights(clf_weights) + + def increment_class(self, num_classes): + print(f"increment class {num_classes}") + self.num_classes = num_classes + self._initialize_classifier() + + def _initialize_classifier(self): + if self.classifier != None: + new_classifier = tf.keras.Sequential( + [ + tf.keras.layers.Dense( + self.num_classes, kernel_initializer="lecun_normal" + ) + ] + ) + new_classifier.build( + input_shape=(None, self.feature_extractor.layers[-2].output_shape[-1]) + ) + new_weights = new_classifier.get_weights() + old_weights = self.classifier.get_weights() + # weight + new_weights[0][0 : old_weights[0].shape[0], 0 : old_weights[0].shape[1]] = ( + old_weights[0] + ) + # bias + new_weights[1][0 : old_weights[1].shape[0]] = old_weights[1] + new_classifier.set_weights(new_weights) + self.classifier = new_classifier + else: + + self.classifier = tf.keras.Sequential( + [ + tf.keras.layers.Dense( + self.num_classes, kernel_initializer="lecun_normal" + ) + ] + ) + self.classifier.build( + input_shape=(None, self.feature_extractor.layers[-2].output_shape[-1]) + ) + self.best_model_1 = (self.feature_extractor, self.classifier) + logging.info(f"finish ! 
initialize classifier {self.classifier}") + + def model_back(self): + return [self.best_model_1, self.best_model_2] + + def dataload(self, proto_grad): + self._initialize_classifier() + self.proto_grad = proto_grad + if len(proto_grad) != 0: + self.reconstruction() + self.monitor_dataset.test_data = self.new_set + self.monitor_dataset.test_label = self.new_set_label + self.last_perf = 0 + self.best_model_1 = self.best_model_2 + cur_perf = self.monitor() + logging.info(f"in proxy server, current performance is {cur_perf}") + if cur_perf > self.best_perf: + self.best_perf = cur_perf + self.best_model_2 = (self.feature_extractor, self.classifier) + + def monitor(self): + correct, total = 0, 0 + for x, y in zip( + self.monitor_dataset.test_data, self.monitor_dataset.test_label + ): + y_pred = self.classifier(self.feature_extractor((x))) + + predicts = tf.argmax(y_pred, axis=-1) + predicts = tf.cast(predicts, tf.int32) + logging.info(f"y_pred {predicts} and y {y}") + correct += tf.reduce_sum(tf.cast(tf.equal(predicts, y), dtype=tf.int32)) + total += x.shape[0] + acc = 100 * correct / total + return acc + + def grad2label(self): + proto_grad_label = [] + for w_single in self.proto_grad: + pred = tf.argmin(tf.reduce_sum(w_single[-2], axis=-1), axis=-1) + proto_grad_label.append(pred) + return proto_grad_label + + def reconstruction(self): + self.new_set = [] + self.new_set_label = [] + proto_label = self.grad2label() + proto_label = np.array(proto_label) + class_ratio = np.zeros((1, 100)) + + for i in proto_label: + class_ratio[0][i] += 1 + + for label_i in range(100): + if class_ratio[0][label_i] > 0: + agumentation = [] + + grad_index = np.where(proto_label == label_i) + logging.info(f"grad index : {grad_index} and label is {label_i}") + for j in range(len(grad_index[0])): + grad_true_temp = self.proto_grad[grad_index[0][j]] + + dummy_data = tf.Variable( + np.random.rand(1, 32, 32, 3), trainable=True + ) + label_pred = tf.constant([label_i]) + + opt = keras.optimizers.SGD(learning_rate=0.1) + cri = keras.losses.SparseCategoricalCrossentropy() + + recon_model = copy.deepcopy(self.encode_model) + + for iter in range(self.Iteration): + with tf.GradientTape() as tape0: + with tf.GradientTape() as tape1: + y_pred = recon_model(dummy_data) + loss = cri(label_pred, y_pred) + dummy_dy_dx = tape1.gradient( + loss, recon_model.trainable_variables + ) + + grad_diff = 0 + for gx, gy in zip(dummy_dy_dx, grad_true_temp): + gx = tf.cast(gx, tf.double) + gy = tf.cast(gy, tf.double) + sub_value = tf.subtract(gx, gy) + pow_value = tf.pow(sub_value, 2) + grad_diff += tf.reduce_sum(pow_value) + grad = tape0.gradient(grad_diff, dummy_data) + opt.apply_gradients(zip([grad], [dummy_data])) + + if iter >= self.Iteration - self.num_image: + dummy_data_temp = np.asarray(dummy_data) + agumentation.append(dummy_data_temp) + + self.new_set.extend(agumentation) + self.new_set_label.extend([label_i]) diff --git a/examples/cifar100/fci_ssl/glfc/benchmarkingjob.yaml b/examples/cifar100/fci_ssl/glfc/benchmarkingjob.yaml new file mode 100644 index 00000000..006d14ce --- /dev/null +++ b/examples/cifar100/fci_ssl/glfc/benchmarkingjob.yaml @@ -0,0 +1,71 @@ +benchmarkingjob: + # job name of bechmarking; string type; + name: "benchmarkingjob" + # the url address of job workspace that will reserve the output of tests; string type; + workspace: "/home/wyd/ianvs/federated_class_incremental_learning/workspace" + + # the url address of test environment configuration file; string type; + # the file format supports yaml/yml; + testenv: 
"./examples/cifar100/fci_ssl/glfc/testenv/testenv.yaml" + + # the configuration of test object + test_object: + # test type; string type; + # currently the option of value is "algorithms",the others will be added in succession. + type: "algorithms" + # test algorithm configuration files; list type; + algorithms: + # algorithm name; string type; + - name: "GLFCMatch" + # the url address of test algorithm configuration file; string type; + # the file format supports yaml/yml + url: "./examples/cifar100/fci_ssl/glfc/algorithm/algorithm.yaml" + + # the configuration of ranking leaderboard + rank: + # rank leaderboard with metric of test case's evaluation and order ; list type; + # the sorting priority is based on the sequence of metrics in the list from front to back; + sort_by: [ { "task_avg_acc": "descend" } ] + + # visualization configuration + visualization: + # mode of visualization in the leaderboard; string type; + # There are quite a few possible dataitems in the leaderboard. Not all of them can be shown simultaneously on the screen. + # In the leaderboard, we provide the "selected_only" mode for the user to configure what is shown or is not shown. + mode: "selected_only" + # method of visualization for selected dataitems; string type; + # currently the options of value are as follows: + # 1> "print_table": print selected dataitems; + method: "print_table" + + # selected dataitem configuration + # The user can add his/her interested dataitems in terms of "paradigms", "modules", "hyperparameters" and "metrics", + # so that the selected columns will be shown. + selected_dataitem: + # currently the options of value are as follows: + # 1> "all": select all paradigms in the leaderboard; + # 2> paradigms in the leaderboard, e.g., "singletasklearning" + paradigms: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all modules in the leaderboard; + # 2> modules in the leaderboard, e.g., "basemodel" + modules: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all hyperparameters in the leaderboard; + # 2> hyperparameters in the leaderboard, e.g., "momentum" + hyperparameters: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all metrics in the leaderboard; + # 2> metrics in the leaderboard, e.g., "F1_SCORE" + metrics: [ "task_avg_acc","forget_rate" ] + + # network of save selected and all dataitems in workspace `./rank` ; string type; + # currently the options of value are as follows: + # 1> "selected_and_all": save selected and all dataitems; + # 2> "selected_only": save selected dataitems; + save_mode: "selected_and_all" + + + + + diff --git a/examples/cifar100/fci_ssl/glfc/testenv/acc.py b/examples/cifar100/fci_ssl/glfc/testenv/acc.py new file mode 100644 index 00000000..0fe532d3 --- /dev/null +++ b/examples/cifar100/fci_ssl/glfc/testenv/acc.py @@ -0,0 +1,35 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tensorflow as tf +import numpy as np +from sedna.common.class_factory import ClassFactory, ClassType + +__all__ = ['acc'] + + +@ClassFactory.register(ClassType.GENERAL, alias='accuracy') +def accuracy(y_true, y_pred, **kwargs): + y_pred_arr = [val for val in y_pred.values()] + y_true_arr = [] + for i in range(len(y_pred_arr)): + y_true_arr.append(np.full(y_pred_arr[i].shape, int(y_true[i]))) + y_pred = tf.cast(tf.convert_to_tensor(np.concatenate(y_pred_arr, axis=0)), tf.int64) + y_true = tf.cast(tf.convert_to_tensor(np.concatenate(y_true_arr, axis=0)), tf.int64) + total = tf.shape(y_true)[0] + correct = tf.reduce_sum(tf.cast(tf.equal(y_true, y_pred), tf.int32)) + acc = float(int(correct) / total) + print(f"acc:{acc}") + return acc + diff --git a/examples/cifar100/fci_ssl/glfc/testenv/testenv.yaml b/examples/cifar100/fci_ssl/glfc/testenv/testenv.yaml new file mode 100644 index 00000000..5c5b961e --- /dev/null +++ b/examples/cifar100/fci_ssl/glfc/testenv/testenv.yaml @@ -0,0 +1,31 @@ +testenv: + backend: "tensorflow" + dataset: + name: 'cifar100' + # the url address of train dataset index; string type; + train_url: "/home/wyd/ianvs/project/data/cifar100/cifar100_train.txt" + # the url address of test dataset index; string type; + test_url: "/home/wyd/ianvs/project/data/cifar100/cifar100_test.txt" + + + # network eval configuration of incremental learning; + model_eval: + # metric used for network evaluation + model_metric: + # metric name; string type; + name: "accuracy" + # the url address of python file + url: "/home/wyd/ianvs/project/ianvs/examples/cifar100/fci_ssl/glfc/testenv/acc.py" + + # metrics configuration for test case's evaluation; list type; + metrics: + # metric name; string type; + # - name: "accuracy" + # # the url address of python file + # url: "/home/wyd/ianvs/project/ianvs/examples/cifar100/fci_ssl/glfc/testenv/acc.py" + - name: "forget_rate" + - name: "task_avg_acc" + # incremental rounds setting of incremental learning; int type; default value is 2; + incremental_rounds: 10 + round: 5 + client_number: 5 \ No newline at end of file diff --git a/examples/cifar100/federated_class_incremental_learning/fedavg/__init__.py b/examples/cifar100/federated_class_incremental_learning/fedavg/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/__init__.py b/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/aggregation.py b/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/aggregation.py new file mode 100644 index 00000000..53237aa1 --- /dev/null +++ b/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/aggregation.py @@ -0,0 +1,60 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
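
Review note: the glfc testenv above reports `forget_rate` alongside `task_avg_acc`. A common definition of forgetting, assumed here since ianvs' formula is not shown in this diff, is the mean over old tasks of (best accuracy ever reached on the task) minus (final accuracy on it):

    def average_forgetting(acc_history):
        # acc_history[t][k]: accuracy on task k measured after training task t
        final = acc_history[-1]
        n_old = len(final) - 1
        return sum(
            max(h[k] for h in acc_history[:-1] if k < len(h)) - final[k]
            for k in range(n_old)
        ) / n_old

    print(average_forgetting([[0.9], [0.7, 0.8]]))  # 0.9 - 0.7 = 0.2
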
+ +import abc +from copy import deepcopy +from typing import List + +import numpy as np +from sedna.algorithms.aggregation.aggregation import BaseAggregation +from sedna.common.class_factory import ClassType, ClassFactory + + +@ClassFactory.register(ClassType.FL_AGG, "FedAvg") +class FedAvg(BaseAggregation, abc.ABC): + def __init__(self): + super(FedAvg, self).__init__() + + """ + Federated averaging algorithm + """ + + def aggregate(self, clients: List): + """ + Calculate the average weight according to the number of samples + + Parameters + ---------- + clients: List + All clients in federated learning job + + Returns + ------- + update_weights : Array-like + final weights use to update model layer + """ + + print("aggregation....") + if not len(clients): + return self.weights + self.total_size = sum([c.num_samples for c in clients]) + old_weight = [np.zeros(np.array(c).shape) for c in next(iter(clients)).weights] + updates = [] + for inx, row in enumerate(old_weight): + for c in clients: + row += np.array(c.weights[inx]) * c.num_samples / self.total_size + updates.append(row.tolist()) + self.weights = deepcopy(updates) + print("finish aggregation....") + return updates diff --git a/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/algorithm.yaml b/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/algorithm.yaml new file mode 100644 index 00000000..5c6b87cf --- /dev/null +++ b/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/algorithm.yaml @@ -0,0 +1,49 @@ +algorithm: + # paradigm name; string type; + # currently the options of value are as follows: + # 1> "singletasklearning" + # 2> "incrementallearning" + paradigm_type: "federatedclassincrementallearning" + fl_data_setting: + # ratio of training dataset; float type; + # the default value is 0.8. 
+ train_ratio: 1.0 + # the method of splitting dataset; string type; optional; + # currently the options of value are as follows: + # 1> "default": the dataset is evenly divided based train_ratio; + splitting_method: "default" + label_data_ratio: 1.0 + data_partition: "iid" + # the url address of initial network for network pre-training; string url; + # the url address of initial network; string type; optional; + initial_model_url: "/home/wyd/ianvs/project/init_model/cnn.pb" + # algorithm module configuration in the paradigm; list type; + # incremental rounds setting of incremental learning; int type; default value is 2; + + modules: + # kind of algorithm module; string type; + # currently the options of value are as follows: + # 1> "basemodel" + - type: "basemodel" + # name of python module; string type; + # example: basemodel.py has BaseModel module that the alias is "FPN" for this benchmarking; + name: "fcil" + # the url address of python module; string type; + url: "./examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/basemodel.py" + + # hyperparameters configuration for the python module; list type; + hyperparameters: + # name of the hyperparameter; string type; + - batch_size: + values: + - 32 + - learning_rate: + values: + - 0.001 + - epochs: + values: + - 1 + - type: "aggregation" + name: "FedAvg" + url: "./examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/aggregation.py" + diff --git a/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/basemodel.py b/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/basemodel.py new file mode 100644 index 00000000..d9bc71b5 --- /dev/null +++ b/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/basemodel.py @@ -0,0 +1,162 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
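
Review note: `data_partition: "iid"` in the `fl_data_setting` above spreads training samples uniformly across clients. A minimal illustration of an IID split (ianvs' own splitter may differ; the helper is hypothetical):

    import numpy as np

    def iid_partition(n_samples, n_clients, seed=0):
        # shuffle indices once, then hand out near-equal shards
        idx = np.random.default_rng(seed).permutation(n_samples)
        return np.array_split(idx, n_clients)

    parts = iid_partition(10, 2)
    print([p.tolist() for p in parts])  # two disjoint halves covering 0..9
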
+ +import os +import zipfile +import logging +import keras +import numpy as np +import tensorflow as tf +from sedna.common.class_factory import ClassType, ClassFactory +from resnet import resnet10 +from network import NetWork, incremental_learning + +__all__ = ["BaseModel"] +os.environ["BACKEND_TYPE"] = "KERAS" +logging.getLogger().setLevel(logging.INFO) + + +@ClassFactory.register(ClassType.GENERAL, alias="fcil") +class BaseModel: + def __init__(self, **kwargs): + self.kwargs = kwargs + print(f"kwargs: {kwargs}") + self.batch_size = kwargs.get("batch_size", 1) + print(f"batch_size: {self.batch_size}") + self.epochs = kwargs.get("epochs", 1) + self.lr = kwargs.get("lr", 0.001) + self.optimizer = keras.optimizers.SGD(learning_rate=self.lr) + self.old_task_id = -1 + self.fe = resnet10(10) + logging.info(type(self.fe)) + self.model = NetWork(100, self.fe) + self._init_model() + + def _init_model(self): + self.model.compile( + optimizer="sgd", + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + x = np.random.rand(1, 32, 32, 3) + y = np.random.randint(0, 10, 1) + + self.model.fit(x, y, epochs=1) + + def load(self, model_url=None): + logging.info(f"load model from {model_url}") + extra_model_path = os.path.basename(model_url) + "/model" + with zipfile.ZipFile(model_url, "r") as zip_ref: + zip_ref.extractall(extra_model_path) + self.model = tf.saved_model.load(extra_model_path) + + def _initialize(self): + logging.info(f"initialize finished") + + def get_weights(self): + logging.info(f"get_weights") + weights = [layer.tolist() for layer in self.model.get_weights()] + logging.info(len(weights)) + return weights + + def set_weights(self, weights): + weights = [np.array(layer) for layer in weights] + self.model.set_weights(weights) + logging.info("----------finish set weights-------------") + + def save(self, model_path=""): + logging.info("save model") + + def model_info(self, model_path, result, relpath): + logging.info("model info") + return {} + + def train(self, train_data, valid_data, **kwargs): + round = kwargs.get("round", -1) + self.model.compile( + optimizer=self.optimizer, + loss="sparse_categorical_crossentropy", + metrics=["accuracy"], + ) + logging.info( + f"train data: {train_data['label_x'].shape} {train_data['label_y'].shape}" + ) + train_db = self.data_process(train_data) + logging.info(train_db) + for epoch in range(self.epochs): + total_loss = 0 + total_num = 0 + logging.info(f"Epoch {epoch + 1} / {self.epochs}") + logging.info("-" * 50) + for x, y in train_db: + with tf.GradientTape() as tape: + logits = self.model(x, training=True) + loss = tf.reduce_mean( + keras.losses.sparse_categorical_crossentropy( + y, logits, from_logits=True + ) + ) + grads = tape.gradient(loss, self.model.trainable_variables) + self.optimizer.apply(grads, self.model.trainable_variables) + # self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables)) + total_loss += loss + total_num += 1 + + logging.info( + f"train round {round}: Epoch {epoch + 1} avg loss: {total_loss / total_num}" + ) + logging.info(f"finish round {round} train") + self.eval(train_data, round) + return {"num_samples": train_data["label_x"].shape[0]} + + def predict(self, data, **kwargs): + result = {} + for data in data.x: + x = np.load(data) + logits = self.model(x, training=False) + pred = tf.cast(tf.argmax(logits, axis=1), tf.int32) + result[data] = pred.numpy() + logging.info("finish predict") + return result + + def eval(self, data, round, **kwargs): + total_num = 0 + total_correct = 0 + data = 
diff --git a/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/network.py b/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/network.py
new file mode 100644
index 00000000..20630623
--- /dev/null
+++ b/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/network.py
@@ -0,0 +1,83 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import keras
+import numpy as np
+import tensorflow as tf
+from keras.src.layers import Dense
+from resnet import resnet10
+
+
+class NetWork(keras.Model):
+    def __init__(self, num_classes, feature_extractor):
+        super(NetWork, self).__init__()
+        self.num_classes = num_classes
+        self.feature = feature_extractor
+        self.fc = Dense(num_classes, activation="softmax")
+
+    def call(self, inputs):
+        x = self.feature(inputs)
+        x = self.fc(x)
+        return x
+
+    def feature_extractor(self, inputs):
+        return self.feature.predict(inputs)
+
+    def predict(self, fea_input):
+        return self.fc(fea_input)
+
+    def get_config(self):
+        return {
+            "num_classes": self.num_classes,
+            "feature_extractor": self.feature,
+        }
+
+    @classmethod
+    def from_config(cls, config):
+        return cls(**config)
+
+
+def incremental_learning(old_model: NetWork, num_class):
+    # build a new model with a wider classifier head and warm it up once
+    new_model = NetWork(num_class, resnet10(num_class))
+    x = np.random.rand(1, 32, 32, 3)
+    y = np.random.randint(0, num_class, 1)
+    new_model.compile(
+        optimizer="sgd", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
+    )
+    new_model.fit(x, y, epochs=1)
+    print(old_model.fc.units, new_model.fc.units)
+    # copy the shared feature-extractor layers from the old model
+    for layer in old_model.layers:
+        if hasattr(new_model.feature, layer.name):
+            new_model.feature.__setattr__(layer.name, layer)
+    if num_class > old_model.fc.units:
+        original_use_bias = hasattr(old_model.fc, "bias")
+        print("original_use_bias", original_use_bias)
+        init_class = old_model.fc.units
+        # zero-pad the old classifier weights to cover the new classes
+        new_model.fc.kernel.assign(
+            tf.pad(old_model.fc.kernel, [[0, 0], [0, num_class - init_class]])
+        )
+        if original_use_bias:
+            new_model.fc.bias.assign(
+                tf.pad(old_model.fc.bias, [[0, num_class - init_class]])
+            )
+    new_model.build((None, 32, 32, 3))
+    return new_model
+
+
+def copy_model(model: NetWork):
+    cfg = model.get_config()
+    return model.from_config(cfg)
diff --git a/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/resnet.py b/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/resnet.py
new file mode 100644
index 00000000..48db9122
--- /dev/null
+++ b/examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/resnet.py
@@ -0,0 +1,120 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import tensorflow as tf +import keras + + +# Input--conv2D--BN--ReLU--conv2D--BN--ReLU--Output +# \ / +# ------------------------------ +class BasicBlock(keras.layers.Layer): + def __init__(self, filter_num, stride=1): + super(BasicBlock, self).__init__() + + self.conv1 = keras.layers.Conv2D( + filter_num, (3, 3), strides=stride, padding="same" + ) + self.bn1 = keras.layers.BatchNormalization() + self.relu = keras.layers.Activation("relu") + + self.conv2 = keras.layers.Conv2D(filter_num, (3, 3), strides=1, padding="same") + self.bn2 = keras.layers.BatchNormalization() + + if stride != 1: + self.downsample = keras.models.Sequential() + self.downsample.add(keras.layers.Conv2D(filter_num, (1, 1), strides=stride)) + else: + self.downsample = lambda x: x + + def call(self, inputs, training=None): + # [b, h, w, c] + out = self.conv1(inputs) + out = self.bn1(out, training=training) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out, training=training) + + identity = self.downsample(inputs) + + output = keras.layers.add([out, identity]) + output = tf.nn.relu(output) + + return output + + +class ResNet(keras.Model): + def __init__(self, layer_dims, num_classes=10): # [2, 2, 2, 2] + super(ResNet, self).__init__() + self.layer_dims = layer_dims + self.num_classes = num_classes + + self.stem = keras.models.Sequential( + [ + keras.layers.Conv2D(64, (3, 3), strides=(1, 1)), + keras.layers.BatchNormalization(), + keras.layers.Activation("relu"), + keras.layers.MaxPool2D( + pool_size=(2, 2), strides=(1, 1), padding="same" + ), + ] + ) + + self.layer1 = self.build_resblock(64, layer_dims[0]) + self.layer2 = self.build_resblock(128, layer_dims[1], stride=2) + self.layer3 = self.build_resblock(256, layer_dims[2], stride=2) + self.layer4 = self.build_resblock(512, layer_dims[3], stride=2) + + # output: [b, 512, h, w], + self.avgpool = keras.layers.GlobalAveragePooling2D() + + def call(self, inputs, training=None): + x = self.stem(inputs, training=training) + + x = self.layer1(x, training=training) + x = self.layer2(x, training=training) + x = self.layer3(x, training=training) + x = self.layer4(x, training=training) + + # [b, c] + x = self.avgpool(x) + return x + + def build_resblock(self, filter_num, blocks, stride=1): + res_blocks = keras.models.Sequential() + # may down sample + res_blocks.add(BasicBlock(filter_num, stride)) + for _ in range(1, blocks): + res_blocks.add(BasicBlock(filter_num, stride=1)) + return res_blocks + + def get_config(self): + return {"layer_dims": self.layer_dims, "num_classes": self.num_classes} + + @classmethod + def from_config(cls, config): + return cls(**config) + + +def resnet10(num_classes: int): + return ResNet([1, 1, 1, 1], num_classes) + + +def resnet18(num_classes: int): + return ResNet([2, 2, 2, 2], num_classes) + + +def resnet34(num_classes: int): + return ResNet([3, 4, 6, 3], num_classes) diff --git a/examples/cifar100/federated_class_incremental_learning/fedavg/benchmarkingjob.yaml b/examples/cifar100/federated_class_incremental_learning/fedavg/benchmarkingjob.yaml new file mode 100644 index 00000000..864c9b29 --- /dev/null +++ b/examples/cifar100/federated_class_incremental_learning/fedavg/benchmarkingjob.yaml @@ -0,0 +1,71 @@ +benchmarkingjob: + # job name of bechmarking; string type; + name: "benchmarkingjob" + # the url address of job workspace that will reserve the output of tests; string type; + workspace: "/home/wyd/ianvs/federated_class_incremental_learning/workspace" + + # the url address of test environment configuration file; string 
type; + # the file format supports yaml/yml; + testenv: "./examples/cifar100/federated_class_incremental_learning/fedavg/testenv/testenv.yaml" + + # the configuration of test object + test_object: + # test type; string type; + # currently the option of value is "algorithms",the others will be added in succession. + type: "algorithms" + # test algorithm configuration files; list type; + algorithms: + # algorithm name; string type; + - name: "fcil_test" + # the url address of test algorithm configuration file; string type; + # the file format supports yaml/yml + url: "./examples/cifar100/federated_class_incremental_learning/fedavg/algorithm/algorithm.yaml" + + # the configuration of ranking leaderboard + rank: + # rank leaderboard with metric of test case's evaluation and order ; list type; + # the sorting priority is based on the sequence of metrics in the list from front to back; + sort_by: [ { "accuracy": "descend" } ] + + # visualization configuration + visualization: + # mode of visualization in the leaderboard; string type; + # There are quite a few possible dataitems in the leaderboard. Not all of them can be shown simultaneously on the screen. + # In the leaderboard, we provide the "selected_only" mode for the user to configure what is shown or is not shown. + mode: "selected_only" + # method of visualization for selected dataitems; string type; + # currently the options of value are as follows: + # 1> "print_table": print selected dataitems; + method: "print_table" + + # selected dataitem configuration + # The user can add his/her interested dataitems in terms of "paradigms", "modules", "hyperparameters" and "metrics", + # so that the selected columns will be shown. + selected_dataitem: + # currently the options of value are as follows: + # 1> "all": select all paradigms in the leaderboard; + # 2> paradigms in the leaderboard, e.g., "singletasklearning" + paradigms: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all modules in the leaderboard; + # 2> modules in the leaderboard, e.g., "basemodel" + modules: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all hyperparameters in the leaderboard; + # 2> hyperparameters in the leaderboard, e.g., "momentum" + hyperparameters: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all metrics in the leaderboard; + # 2> metrics in the leaderboard, e.g., "F1_SCORE" + metrics: [ "accuracy" ] + + # network of save selected and all dataitems in workspace `./rank` ; string type; + # currently the options of value are as follows: + # 1> "selected_and_all": save selected and all dataitems; + # 2> "selected_only": save selected dataitems; + save_mode: "selected_and_all" + + + + + diff --git a/examples/cifar100/federated_class_incremental_learning/fedavg/testenv/__init__.py b/examples/cifar100/federated_class_incremental_learning/fedavg/testenv/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/cifar100/federated_class_incremental_learning/fedavg/testenv/acc.py b/examples/cifar100/federated_class_incremental_learning/fedavg/testenv/acc.py new file mode 100644 index 00000000..0fe532d3 --- /dev/null +++ b/examples/cifar100/federated_class_incremental_learning/fedavg/testenv/acc.py @@ -0,0 +1,35 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf
+import numpy as np
+from sedna.common.class_factory import ClassFactory, ClassType
+
+__all__ = ["accuracy"]
+
+
+@ClassFactory.register(ClassType.GENERAL, alias="accuracy")
+def accuracy(y_true, y_pred, **kwargs):
+    # y_pred maps each test file to its per-sample class predictions,
+    # while y_true holds one class label per file
+    y_pred_arr = [val for val in y_pred.values()]
+    y_true_arr = []
+    for i in range(len(y_pred_arr)):
+        y_true_arr.append(np.full(y_pred_arr[i].shape, int(y_true[i])))
+    y_pred = tf.cast(tf.convert_to_tensor(np.concatenate(y_pred_arr, axis=0)), tf.int64)
+    y_true = tf.cast(tf.convert_to_tensor(np.concatenate(y_true_arr, axis=0)), tf.int64)
+    total = tf.shape(y_true)[0]
+    correct = tf.reduce_sum(tf.cast(tf.equal(y_true, y_pred), tf.int32))
+    acc = float(int(correct) / int(total))
+    print(f"acc: {acc}")
+    return acc
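Editor's note: a standalone sketch of the accuracy computation above in plain numpy. `y_pred` is assumed to be a dict mapping each `.npy` test file to per-sample predictions, with one true label per file; the file names and values are illustrative only.

```python
import numpy as np

y_true = ["3", "7"]
y_pred = {"class3.npy": np.array([3, 3, 5]), "class7.npy": np.array([7, 2, 7])}

# flatten the per-file predictions and broadcast each file's label
pred = np.concatenate(list(y_pred.values()))
true = np.concatenate(
    [np.full(p.shape, int(t)) for p, t in zip(y_pred.values(), y_true)]
)
print((pred == true).mean())  # 0.666... -- 4 of 6 samples correct
```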
diff --git a/examples/cifar100/federated_class_incremental_learning/fedavg/testenv/testenv.yaml b/examples/cifar100/federated_class_incremental_learning/fedavg/testenv/testenv.yaml
new file mode 100644
index 00000000..78ae70cd
--- /dev/null
+++ b/examples/cifar100/federated_class_incremental_learning/fedavg/testenv/testenv.yaml
@@ -0,0 +1,36 @@
+testenv:
+  backend: "tensorflow"
+  dataset:
+    name: 'cifar100'
+    # the url address of train dataset index; string type;
+    train_url: "/home/wyd/ianvs/project/data/cifar100/cifar100_train.txt"
+    # the url address of test dataset index; string type;
+    test_url: "/home/wyd/ianvs/project/data/cifar100/cifar100_test.txt"
+
+  # network eval configuration of incremental learning;
+  model_eval:
+    # metric used for network evaluation
+    model_metric:
+      # metric name; string type;
+      name: "accuracy"
+      # the url address of python file
+      url: "/home/wyd/ianvs/project/ianvs/examples/cifar100/federated_class_incremental_learning/fedavg/testenv/acc.py"
+
+    # condition of triggering inference network to update
+    # threshold of the condition; types are float/int
+    threshold: 0.01
+    # operator of the condition; string type;
+    # values are ">=", ">", "<=", "<" and "=";
+    operator: "<="
+
+  # metrics configuration for test case's evaluation; list type;
+  metrics:
+    # metric name; string type;
+    - name: "accuracy"
+      # the url address of python file
+      url: "/home/wyd/ianvs/project/ianvs/examples/cifar100/federated_class_incremental_learning/fedavg/testenv/acc.py"
+
+  # incremental rounds setting of incremental learning; int type; default value is 2;
+  incremental_rounds: 2
+  round: 2
\ No newline at end of file
diff --git a/examples/cifar100/federated_learning/fedavg/algorithm/aggregation.py b/examples/cifar100/federated_learning/fedavg/algorithm/aggregation.py
new file mode 100644
index 00000000..df1de505
--- /dev/null
+++ b/examples/cifar100/federated_learning/fedavg/algorithm/aggregation.py
@@ -0,0 +1,55 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+from copy import deepcopy
+
+import numpy as np
+from sedna.algorithms.aggregation.aggregation import BaseAggregation
+from sedna.common.class_factory import ClassType, ClassFactory
+
+
+@ClassFactory.register(ClassType.FL_AGG, "FedAvg")
+class FedAvg(BaseAggregation, abc.ABC):
+    def __init__(self):
+        super(FedAvg, self).__init__()
+
+    def aggregate(self, clients):
+        """
+        Calculate the average weight according to the number of samples.
+
+        Parameters
+        ----------
+        clients:
+            All clients in the federated learning job
+
+        Returns
+        -------
+        update_weights : Array-like
+            final weights used to update the model layers
+        """
+        print("aggregation....")
+        if not clients:
+            return self.weights
+        self.total_size = sum(c.num_samples for c in clients)
+        # accumulate each layer as a sample-weighted average over all clients
+        old_weight = [
+            np.zeros(np.array(layer).shape) for layer in next(iter(clients)).weights
+        ]
+        updates = []
+        for inx, row in enumerate(old_weight):
+            for c in clients:
+                row += np.array(c.weights[inx]) * c.num_samples / self.total_size
+            updates.append(row.tolist())
+        self.weights = deepcopy(updates)
+        print("finish aggregation....")
+        return updates
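Editor's note: a minimal numeric check of the sample-weighted FedAvg rule implemented above, using two fake clients. The `SimpleNamespace` objects stand in for the client objects (exposing `weights` and `num_samples`) that Ianvs would actually pass in.

```python
from types import SimpleNamespace

import numpy as np

clients = [
    SimpleNamespace(weights=[np.array([[1.0, 2.0]])], num_samples=10),
    SimpleNamespace(weights=[np.array([[3.0, 6.0]])], num_samples=30),
]
total = sum(c.num_samples for c in clients)
# weighted average of the single layer across both clients
avg = sum(np.array(c.weights[0]) * c.num_samples / total for c in clients)
print(avg)  # [[2.5 5. ]] -- pulled toward the client with more samples
```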
diff --git a/examples/cifar100/federated_learning/fedavg/algorithm/algorithm.yaml b/examples/cifar100/federated_learning/fedavg/algorithm/algorithm.yaml
new file mode 100644
index 00000000..7b37eb88
--- /dev/null
+++ b/examples/cifar100/federated_learning/fedavg/algorithm/algorithm.yaml
@@ -0,0 +1,49 @@
+algorithm:
+  # paradigm name; string type;
+  # currently the options of value are as follows:
+  #   1> "singletasklearning"
+  #   2> "incrementallearning"
+  #   3> "federatedlearning"
+  paradigm_type: "federatedlearning"
+  fl_data_setting:
+    # ratio of training dataset; float type;
+    # the default value is 0.8.
+    train_ratio: 1.0
+    # the method of splitting dataset; string type; optional;
+    # currently the options of value are as follows:
+    #   1> "default": the dataset is evenly divided based on train_ratio;
+    splitting_method: "default"
+    label_data_ratio: 1.0
+    data_partition: "iid"
+  # the url address of the initial model for pre-training; string type; optional;
+  initial_model_url: "/home/wyd/ianvs/project/init_model/restnet.pb"
+
+  # algorithm module configuration in the paradigm; list type;
+  modules:
+    # kind of algorithm module; string type;
+    # currently the options of value are as follows:
+    #   1> "basemodel"
+    - type: "basemodel"
+      # name of python module; string type;
+      # example: basemodel.py has a BaseModel module whose alias is "fedavg" for this benchmarking;
+      name: "fedavg"
+      # the url address of python module; string type;
+      url: "./examples/cifar100/federated_learning/fedavg/algorithm/basemodel.py"
+
+      # hyperparameters configuration for the python module; list type;
+      hyperparameters:
+        # name of the hyperparameter; string type;
+        - batch_size:
+            values:
+              - 32
+        - learning_rate:
+            values:
+              - 0.001
+        - epochs:
+            values:
+              - 10
+    - type: "aggregation"
+      name: "FedAvg"
+      url: "./examples/cifar100/federated_learning/fedavg/algorithm/aggregation.py"
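Editor's note: a hypothetical illustration of how the hyperparameter lists above are assumed to reach `BaseModel(**kwargs)` once Ianvs expands each `values` list into concrete test cases; the flattening helper here is the editor's sketch, not Ianvs internals. Note that the basemodel below reads this value as `learning_rate` (with `lr` as a fallback).

```python
# hypothetical expansion of one hyperparameter combination into kwargs
hyperparameters = [
    {"batch_size": {"values": [32]}},
    {"learning_rate": {"values": [0.001]}},
    {"epochs": {"values": [10]}},
]
kwargs = {
    name: spec["values"][0] for hp in hyperparameters for name, spec in hp.items()
}
print(kwargs)  # {'batch_size': 32, 'learning_rate': 0.001, 'epochs': 10}
```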
diff --git a/examples/cifar100/federated_learning/fedavg/algorithm/basemodel.py b/examples/cifar100/federated_learning/fedavg/algorithm/basemodel.py
new file mode 100644
index 00000000..faf9e3fe
--- /dev/null
+++ b/examples/cifar100/federated_learning/fedavg/algorithm/basemodel.py
@@ -0,0 +1,194 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import zipfile
+
+import keras
+import numpy as np
+import tensorflow as tf
+from keras import Sequential
+from keras.src.layers import Conv2D, MaxPooling2D, Flatten, Dropout, Dense
+from sedna.common.class_factory import ClassType, ClassFactory
+
+__all__ = ["BaseModel"]
+os.environ["BACKEND_TYPE"] = "KERAS"
+
+
+@ClassFactory.register(ClassType.GENERAL, alias="fedavg")
+class BaseModel:
+    def __init__(self, **kwargs):
+        self.batch_size = kwargs.get("batch_size", 1)
+        print(f"batch_size: {self.batch_size}")
+        self.epochs = kwargs.get("epochs", 1)
+        # algorithm.yaml supplies this hyperparameter as "learning_rate"
+        self.lr = kwargs.get("learning_rate", kwargs.get("lr", 0.001))
+        self.model = self.build(num_classes=100)
+        self.optimizer = tf.keras.optimizers.SGD(
+            learning_rate=self.lr, weight_decay=0.0001
+        )
+        self._init_model()
+
+    @staticmethod
+    def build(num_classes: int):
+        model = Sequential()
+        model.add(
+            Conv2D(
+                64,
+                kernel_size=(3, 3),
+                activation="relu",
+                strides=(2, 2),
+                input_shape=(32, 32, 3),
+            )
+        )
+        model.add(MaxPooling2D(pool_size=(2, 2)))
+        model.add(Conv2D(32, kernel_size=(3, 3), activation="relu"))
+        model.add(MaxPooling2D(pool_size=(2, 2)))
+        model.add(Flatten())
+        model.add(Dropout(0.25))
+        model.add(Dense(64, activation="relu"))
+        model.add(Dense(32, activation="relu"))
+        model.add(Dropout(0.5))
+        model.add(Dense(num_classes, activation="softmax"))
+
+        model.compile(
+            loss="categorical_crossentropy", optimizer="sgd", metrics=["accuracy"]
+        )
+        return model
+
+    def _init_model(self):
+        self.model.compile(
+            optimizer="sgd",
+            loss="sparse_categorical_crossentropy",
+            metrics=["accuracy"],
+        )
+        # run one dummy sample through the model so that all weights are built
+        x = np.random.rand(1, 32, 32, 3)
+        y = np.random.randint(0, 100, 1)
+        self.model.fit(x, y, epochs=1)
+
+    def load(self, model_url=None):
+        print(f"load model from {model_url}")
+        extra_model_path = os.path.basename(model_url) + "/model"
+        with zipfile.ZipFile(model_url, "r") as zip_ref:
+            zip_ref.extractall(extra_model_path)
+        self.model = tf.saved_model.load(extra_model_path)
+
+    def _initialize(self):
+        print("initialize finished")
+
+    def get_weights(self):
+        print("get_weights")
+        weights = [layer.tolist() for layer in self.model.get_weights()]
+        print(len(weights))
+        return weights
+
+    def set_weights(self, weights):
+        weights = [np.array(layer) for layer in weights]
+        self.model.set_weights(weights)
+        print("----------finish set weights-------------")
+
+    def save(self, model_path=""):
+        print("save model")
+
+    def model_info(self, model_path, result, relpath):
+        print("model info")
+        return {}
+
+    def train(self, train_data, valid_data, **kwargs):
+        round_num = kwargs.get("round", -1)
+        print(f"train data: {train_data[0].shape} {train_data[1].shape}")
+        train_db = self.data_process(train_data)
+        print(train_db)
+        for epoch in range(self.epochs):
+            total_loss = 0
+            total_num = 0
+            print(f"Epoch {epoch + 1} / {self.epochs}")
+            print("-" * 50)
+            for x, y in train_db:
+                with tf.GradientTape() as tape:
+                    logits = self.model(x, training=True)
+                    y_pred = tf.cast(tf.argmax(logits, axis=1), tf.int32)
+                    correct = tf.equal(y_pred, y)
+                    correct = tf.cast(correct, dtype=tf.int32)
+                    correct = tf.reduce_sum(correct)
+                    y = tf.one_hot(y, depth=100)
+                    loss = tf.reduce_mean(
+                        keras.losses.categorical_crossentropy(
+                            y, logits, from_logits=True
+                        )
+                    )
+                print(
+                    f"loss is {loss}, correct {correct} total is {x.shape[0]} acc : {correct / x.shape[0]}"
+                )
+                grads = tape.gradient(loss, self.model.trainable_variables)
+                self.optimizer.apply(grads, self.model.trainable_variables)
+                total_loss += loss
+                total_num += 1
+
+            print(
+                f"train round {round_num}: Epoch {epoch + 1} avg loss: {total_loss / total_num}"
+            )
+        print(f"finish round {round_num} train")
+        return {"num_samples": train_data[0].shape[0]}
+
+    def predict(self, data, **kwargs):
+        result = {}
+        mean = np.array((0.5071, 0.4867, 0.4408), np.float32).reshape(1, 1, -1)
+        std = np.array((0.2675, 0.2565, 0.2761), np.float32).reshape(1, 1, -1)
+        for path in data.x:
+            x = np.load(path)
+            x = (tf.cast(x, dtype=tf.float32) / 255.0 - mean) / std
+            logits = self.model(x, training=False)
+            pred = tf.cast(tf.argmax(logits, axis=1), tf.int32)
+            result[path] = pred.numpy()
+        print("finish predict")
+        return result
+
+    def eval(self, data, round_num, **kwargs):
+        total_num = 0
+        total_correct = 0
+        data = self.data_process(data)
+        print(f"in evaluate data: {data}")
+        for x, y in data:
+            logits = self.model(x, training=False)
+            pred = tf.argmax(logits, axis=1)
+            pred = tf.cast(pred, dtype=tf.int32)
+            pred = tf.reshape(pred, y.shape)
+            correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)
+            correct = tf.reduce_sum(correct)
+            total_num += x.shape[0]
+            total_correct += int(correct)
+        print(f"total_correct: {total_correct}, total_num: {total_num}")
+        acc = total_correct / total_num
+        print(f"finish round {round_num} evaluate, acc: {acc}")
+        return acc
+
+    def data_process(self, data, **kwargs):
+        mean = np.array((0.5071, 0.4867, 0.4408), np.float32).reshape(1, 1, -1)
+        std = np.array((0.2675, 0.2565, 0.2761), np.float32).reshape(1, 1, -1)
+        assert data is not None, "data is None"
+        # data[0].shape = (N, 32, 32, 3), data[1].shape = (N, 1);
+        # only the first 5000 samples of the index are used here
+        return (
+            tf.data.Dataset.from_tensor_slices((data[0][:5000], data[1][:5000]))
+            .shuffle(100000)
+            .map(
+                lambda x, y: (
+                    (tf.cast(x, dtype=tf.float32) / 255.0 - mean) / std,
+                    tf.cast(y, dtype=tf.int32),
+                )
+            )
+            .batch(self.batch_size)
+        )
diff --git a/examples/cifar100/federated_learning/fedavg/benchmarkingjob.yaml b/examples/cifar100/federated_learning/fedavg/benchmarkingjob.yaml
new file mode 100644
index 00000000..6f6e794f
--- /dev/null
+++ b/examples/cifar100/federated_learning/fedavg/benchmarkingjob.yaml
@@ -0,0 +1,71 @@
+benchmarkingjob:
+  # job name of benchmarking; string type;
+  name: "benchmarkingjob"
+  # the url address of job workspace that will reserve the output of tests; string type;
+  workspace: "/home/wyd/ianvs/federated_learning/workspace"
+
+  # the url address of test environment configuration file; string type;
+  # the file format supports yaml/yml;
+  testenv: "./examples/cifar100/federated_learning/fedavg/testenv/testenv.yaml"
+
+  # the configuration of test object
+  test_object:
+    # test type; string type;
+    # currently the option of value is "algorithms"; the others will be added in succession.
+    type: "algorithms"
+    # test algorithm configuration files; list type;
+    algorithms:
+      # algorithm name; string type;
+      - name: "fedavg_test"
+        # the url address of test algorithm configuration file; string type;
+        # the file format supports yaml/yml;
+        url: "./examples/cifar100/federated_learning/fedavg/algorithm/algorithm.yaml"
+
+  # the configuration of ranking leaderboard
+  rank:
+    # rank leaderboard with metric of test case's evaluation and order; list type;
+    # the sorting priority is based on the sequence of metrics in the list from front to back;
+    sort_by: [ { "accuracy": "descend" } ]
+
+    # visualization configuration
+    visualization:
+      # mode of visualization in the leaderboard; string type;
+      # There are quite a few possible dataitems in the leaderboard. Not all of them can be shown simultaneously on the screen.
+      # In the leaderboard, we provide the "selected_only" mode for the user to configure what is shown or is not shown.
+      mode: "selected_only"
+      # method of visualization for selected dataitems; string type;
+      # currently the options of value are as follows:
+      #   1> "print_table": print selected dataitems;
+      method: "print_table"
+
+    # selected dataitem configuration
+    # The user can add his/her interested dataitems in terms of "paradigms", "modules", "hyperparameters" and "metrics",
+    # so that the selected columns will be shown.
+    selected_dataitem:
+      # currently the options of value are as follows:
+      #   1> "all": select all paradigms in the leaderboard;
+      #   2> paradigms in the leaderboard, e.g., "singletasklearning"
+      paradigms: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all modules in the leaderboard;
+      #   2> modules in the leaderboard, e.g., "basemodel"
+      modules: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all hyperparameters in the leaderboard;
+      #   2> hyperparameters in the leaderboard, e.g., "momentum"
+      hyperparameters: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all metrics in the leaderboard;
+      #   2> metrics in the leaderboard, e.g., "F1_SCORE"
+      metrics: [ "accuracy" ]
+
+    # mode of saving selected and all dataitems in workspace `./rank`; string type;
+    # currently the options of value are as follows:
+    #   1> "selected_and_all": save selected and all dataitems;
+    #   2> "selected_only": save selected dataitems;
+    save_mode: "selected_and_all"
diff --git a/examples/cifar100/federated_learning/fedavg/testenv/__init__.py b/examples/cifar100/federated_learning/fedavg/testenv/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/cifar100/federated_learning/fedavg/testenv/acc.py b/examples/cifar100/federated_learning/fedavg/testenv/acc.py
new file mode 100644
index 00000000..e7e9fde9
--- /dev/null
+++ b/examples/cifar100/federated_learning/fedavg/testenv/acc.py
@@ -0,0 +1,33 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import tensorflow as tf +import numpy as np +from sedna.common.class_factory import ClassFactory, ClassType + +__all__ = ["acc"] + + +@ClassFactory.register(ClassType.GENERAL, alias="accuracy") +def accuracy(y_true, y_pred, **kwargs): + y_pred_arr = [val for val in y_pred.values()] + y_true_arr = [] + for i in range(len(y_pred_arr)): + y_true_arr.append(np.full(y_pred_arr[i].shape, int(y_true[i]))) + y_pred = tf.cast(tf.convert_to_tensor(np.concatenate(y_pred_arr, axis=0)), tf.int64) + y_true = tf.cast(tf.convert_to_tensor(np.concatenate(y_true_arr, axis=0)), tf.int64) + total = tf.shape(y_true)[0] + correct = tf.reduce_sum(tf.cast(tf.equal(y_true, y_pred), tf.int32)) + acc = float(int(correct) / total) + print(f"acc:{acc}") + return acc diff --git a/examples/cifar100/federated_learning/fedavg/testenv/testenv.yaml b/examples/cifar100/federated_learning/fedavg/testenv/testenv.yaml new file mode 100644 index 00000000..4e3912b1 --- /dev/null +++ b/examples/cifar100/federated_learning/fedavg/testenv/testenv.yaml @@ -0,0 +1,37 @@ +testenv: + backend: "tensorflow" + dataset: + name: 'cifar100' + # the url address of train dataset index; string type; + train_url: "/home/wyd/ianvs/project/data/cifar100/cifar100_train.txt" + # the url address of test dataset index; string type; + test_url: "/home/wyd/ianvs/project/data/cifar100/cifar100_test.txt" + + + # network eval configuration of incremental learning; + model_eval: + # metric used for network evaluation + model_metric: + # metric name; string type; + name: "accuracy" + # the url address of python file + url: "/home/wyd/ianvs/project/ianvs/examples/cifar100/federated_learning/fedavg/testenv/acc.py" + + # condition of triggering inference network to update + # threshold of the condition; types are float/int + threshold: 0.01 + # operator of the condition; string type; + # values are ">=", ">", "<=", "<" and "="; + operator: "<=" + + # metrics configuration for test case's evaluation; list type; + metrics: + # metric name; string type; + - name: "accuracy" + # the url address of python file + url: "/home/wyd/ianvs/project/ianvs/examples/cifar100/federated_learning/fedavg/testenv/acc.py" + + # incremental rounds setting of incremental learning; int type; default value is 2; + task_size: 10 + incremental_rounds: 10 + round: 200 \ No newline at end of file diff --git a/examples/cifar100/sedna_federated_learning/aggregation_worker/aggregate.py b/examples/cifar100/sedna_federated_learning/aggregation_worker/aggregate.py new file mode 100644 index 00000000..26eabb65 --- /dev/null +++ b/examples/cifar100/sedna_federated_learning/aggregation_worker/aggregate.py @@ -0,0 +1,35 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from sedna.common.config import Context
+from sedna.service.server import AggregationServer
+
+
+def run_server():
+    aggregation_algorithm = Context.get_parameters("aggregation_algorithm", "FedAvg")
+    exit_round = int(Context.get_parameters("exit_round", 3))
+    participants_count = int(Context.get_parameters("participants_count", 1))
+
+    server = AggregationServer(
+        aggregation=aggregation_algorithm,
+        exit_round=exit_round,
+        ws_size=20 * 1024 * 1024,
+        participants_count=participants_count,
+        host="127.0.0.1",
+    )
+    server.start()
+
+
+if __name__ == "__main__":
+    run_server()
diff --git a/examples/cifar100/sedna_federated_learning/train_worker/basemodel.py b/examples/cifar100/sedna_federated_learning/train_worker/basemodel.py
new file mode 100644
index 00000000..60ed50e9
--- /dev/null
+++ b/examples/cifar100/sedna_federated_learning/train_worker/basemodel.py
@@ -0,0 +1,100 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import tensorflow as tf
+from keras.src.layers import Dense, MaxPooling2D, Conv2D, Flatten, Dropout
+from keras.src.models import Sequential
+
+os.environ["BACKEND_TYPE"] = "KERAS"
+
+
+class Estimator:
+    def __init__(self, **kwargs):
+        """Model init"""
+        self.model = self.build()
+        self.has_init = False
+
+    @staticmethod
+    def build():
+        model = Sequential()
+        model.add(
+            Conv2D(
+                64,
+                kernel_size=(3, 3),
+                activation="relu",
+                strides=(2, 2),
+                input_shape=(32, 32, 3),
+            )
+        )
+        model.add(MaxPooling2D(pool_size=(2, 2)))
+        model.add(Conv2D(32, kernel_size=(3, 3), activation="relu"))
+        model.add(MaxPooling2D(pool_size=(2, 2)))
+        model.add(Flatten())
+        model.add(Dropout(0.25))
+        model.add(Dense(64, activation="relu"))
+        model.add(Dense(32, activation="relu"))
+        model.add(Dropout(0.5))
+        # CIFAR-100 has 100 classes; a single-unit softmax head would always output 1.0
+        model.add(Dense(100, activation="softmax"))
+
+        # labels arrive as integer class indices, so use the sparse loss
+        model.compile(
+            loss="sparse_categorical_crossentropy",
+            optimizer="adam",
+            metrics=["accuracy"],
+        )
+        return model
+
+    def train(
+        self,
+        train_data,
+        valid_data=None,
+        epochs=1,
+        batch_size=1,
+        learning_rate=0.01,
+        validation_split=0.2,
+    ):
+        """Model train"""
+        train_loader = (
+            tf.data.Dataset.from_tensor_slices(train_data)
+            .shuffle(500000)
+            .batch(batch_size)
+        )
+        history = self.model.fit(train_loader, epochs=int(epochs))
+        return {k: list(map(float, v)) for k, v in history.history.items()}
+
+    def get_weights(self):
+        return self.model.get_weights()
+
+    def set_weights(self, weights):
+        self.model.set_weights(weights)
+
+    def load_weights(self, model):
+        if not os.path.isfile(model):
+            return
+        return self.model.load_weights(model)
+
+    def predict(self, datas):
+        return self.model.predict(datas)
+
+    def evaluate(self, test_data, **kwargs):
+        pass
+
+    def load(self, model_url):
+        print("load model")
+
+    def save(self, model_path=None):
+        """
+        save model as a single pb file from checkpoint
+        """
+        print("save model")
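Editor's note: a hypothetical local smoke test for the `Estimator` above, assuming the fixed 100-class head. The random arrays mimic the shapes produced by `read_data_from_file_to_npy()` in train.py below; run it from the `train_worker` directory so the `basemodel` import resolves.

```python
import numpy as np
from basemodel import Estimator

# random stand-ins for the CIFAR-100 arrays: (N, 32, 32, 3) images, (N, 1) labels
x = np.random.rand(64, 32, 32, 3).astype(np.float32)
y = np.random.randint(0, 100, size=(64, 1))

est = Estimator()
history = est.train((x, y), epochs=1, batch_size=16)
print(history)  # per-epoch loss/accuracy lists
```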
diff --git a/examples/cifar100/sedna_federated_learning/train_worker/train.py b/examples/cifar100/sedna_federated_learning/train_worker/train.py
new file mode 100644
index 00000000..5471708d
--- /dev/null
+++ b/examples/cifar100/sedna_federated_learning/train_worker/train.py
@@ -0,0 +1,61 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from sedna.core.federated_learning import FederatedLearning
+from sedna.datasources import TxtDataParse
+import numpy as np
+from basemodel import Estimator
+
+
+def read_data_from_file_to_npy(files):
+    """
+    Read data from the parsed index into numpy arrays.
+
+    Parameters
+    ----------
+    files:
+        parsed index whose .x holds .npy file paths and .y holds labels.
+
+    Returns
+    -------
+    tuple
+        (x_train, y_train) numpy arrays.
+    """
+    x_train = []
+    y_train = []
+    for i, file in enumerate(files.x):
+        x = np.load(file)
+        # files.y[i] is a label string, so convert it before broadcasting
+        y = np.full((x.shape[0], 1), int(files.y[i]), dtype=np.int32)
+        x_train.append(x)
+        y_train.append(y)
+    x_train = np.concatenate(x_train, axis=0)
+    y_train = np.concatenate(y_train, axis=0)
+    print(x_train.shape, y_train.shape)
+    return x_train, y_train
+
+
+def main():
+    train_file = "/home/wyd/ianvs/project/data/cifar100/cifar100_train.txt"
+    train_data = TxtDataParse(data_type="train")
+    train_data.parse(train_file)
+    train_data = read_data_from_file_to_npy(train_data)
+    epochs = 3
+    batch_size = 128
+    fl_job = FederatedLearning(estimator=Estimator(), aggregation="FedAvg")
+    fl_job.train(train_data=train_data, epochs=epochs, batch_size=batch_size)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/cifar100/utils.py b/examples/cifar100/utils.py
new file mode 100644
index 00000000..57333e20
--- /dev/null
+++ b/examples/cifar100/utils.py
@@ -0,0 +1,70 @@
+# Copyright 2021 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+import tensorflow as tf
+
+
+def process_cifar100():
+    if not os.path.exists("/home/wyd/ianvs/project/data/cifar100"):
+        os.makedirs("/home/wyd/ianvs/project/data/cifar100")
+    # truncate the index files before regenerating them
+    train_txt = "/home/wyd/ianvs/project/data/cifar100/cifar100_train.txt"
+    with open(train_txt, "w") as f:
+        pass
+    test_txt = "/home/wyd/ianvs/project/data/cifar100/cifar100_test.txt"
+    with open(test_txt, "w") as f:
+        pass
+    # load CIFAR-100 dataset
+    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
+    print(y_test.shape)
+
+    # group samples by class index
+    class_labels = np.unique(y_train)  # get all classes
+    train_class_dict = {label: [] for label in class_labels}
+    test_class_dict = {label: [] for label in class_labels}
+    # organize training data by category
+    for img, label in zip(x_train, y_train):
+        train_class_dict[label[0]].append(img)
+    # organize testing data by category
+    for img, label in zip(x_test, y_test):
+        test_class_dict[label[0]].append(img)
+    # save training data to local files, one .npy per class
+    for label, imgs in train_class_dict.items():
+        data = np.array(imgs)
+        print(data.shape)
+        np.save(
+            f"/home/wyd/ianvs/project/data/cifar100/cifar100_train_index_{label}.npy",
+            data,
+        )
+        with open(train_txt, "a") as f:
+            f.write(
+                f"/home/wyd/ianvs/project/data/cifar100/cifar100_train_index_{label}.npy\t{label}\n"
+            )
+    # save test data to local files, one .npy per class
+    for label, imgs in test_class_dict.items():
+        np.save(
+            f"/home/wyd/ianvs/project/data/cifar100/cifar100_test_index_{label}.npy",
+            np.array(imgs),
+        )
+        with open(test_txt, "a") as f:
+            f.write(
+                f"/home/wyd/ianvs/project/data/cifar100/cifar100_test_index_{label}.npy\t{label}\n"
+            )
+    print("CIFAR-100 has been saved in the Ianvs format")
+
+
+if __name__ == "__main__":
+    process_cifar100()
diff --git a/examples/government/singletask_learning_bench/README.md b/examples/government/singletask_learning_bench/README.md
new file mode 100644
index 00000000..22dfbfed
--- /dev/null
+++ b/examples/government/singletask_learning_bench/README.md
@@ -0,0 +1,104 @@
+# Government Benchmark
+
+## Introduction
+
+This is the work for the domain-specific large model benchmark:
+it constructs a suite for the government sector, including test datasets, evaluation metrics, testing environments, and usage guidelines.
+
+This benchmark consists of two parts: subjective evaluation data and objective evaluation data.
+
+## Design
+
+### Metadata Format
+
+| Name | Field Name | Option | Description |
+| --- | --- | --- | --- |
+| Data Name | dataset | Required | Name of the dataset |
+| Data Description | description | Optional | Dataset description, such as usage scope, sample size, etc. |
+| First-level Dimension | level_1_dim | Required | Should fill in "Single Modal" or "Multi-Modal" |
+| Second-level Dimension | level_2_dim | Required | For "Single Modal", fill in "Text", "Image", or "Audio". For "Multi-Modal", fill in "Text-Image", "Text-Audio", "Image-Audio", or "Text-Image-Audio" |
+| Third-level Dimension | level_3_dim | Optional | Should be filled if all samples in the dataset have the same third-level dimension. If filled, content should be based on the standards shown in the normative reference document |
+| Fourth-level Dimension | level_4_dim | Optional | Should be filled if all samples in the dataset have the same fourth-level dimension. If filled, content should be based on the standards shown in the normative reference document |
+
+metadata example:
+
+```json
+{
+    "dataset": "Medical BenchMark",
+    "description": "xxx",
+    "level_1_dim": "single-modal",
+    "level_2_dim": "text",
+    "level_3_dim": "Q&A",
+    "level_4_dim": "medical"
+}
+```
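Editor's note: a hedged sketch of loading and sanity-checking a metadata file in the format above; the path is a placeholder for wherever the dataset is unpacked.

```python
import json

# placeholder path; adjust to your local dataset layout
with open("dataset/government/objective/test_data/metadata.json", encoding="utf-8") as f:
    metadata = json.load(f)

assert metadata["level_1_dim"] in ("single-modal", "multi-modal")
print(metadata["dataset"], metadata["level_2_dim"])
```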
+
+### Data Format
+
+| Name | Option | Information |
+| --- | --- | --- |
+| prompt | Optional | the background of the LLM testing |
+| query | Required | the testing question |
+| response | Required | the answer to the question |
+| explanation | Optional | the explanation of the answer |
+| judge_prompt | Optional | the prompt of the judge model |
+| level_1_dim | Optional | single-modal or multi-modal |
+| level_2_dim | Optional | single-modal: text, image, video; multi-modal: text-image, text-video, text-image-video |
+| level_3_dim | Required | details |
+| level_4_dim | Required | details |
+
+data example:
+
+```json
+{
+    "prompt": "Please think step by step and answer the question.",
+    "query": "Which one is the correct answer of xxx? A. xxx B. xxx C. xxx D. xxx",
+    "response": "C",
+    "explanation": "xxx",
+    "level_1_dim": "single-modal",
+    "level_2_dim": "text",
+    "level_3_dim": "knowledge Q&A",
+    "level_4_dim": "medical knowledge"
+}
+```
+
+## Change to Core Code
+
+![](./imgs/structure.png)
+
+## Prepare Datasets
+
+You can download the dataset from [Kaggle](https://www.kaggle.com/datasets/kubeedgeianvs/the-government-affairs-dataset-govaff/data?select=government_benchmark)
+
+```
+dataset/government
+├── objective
+│   ├── test_data
+│   │   ├── data.jsonl
+│   │   └── metadata.json
+│   └── train_data
+└── subjective
+    ├── test_data
+    │   ├── data_full.jsonl
+    │   ├── data.jsonl
+    │   └── metadata.json
+    └── train_data
+```
+
+## Prepare Environment
+
+You should modify your sedna package as in this commit: [my sedna repo commit](https://github.com/IcyFeather233/sedna/commit/e13b82363c03dc771fca4922a24798554ca32a9f)
+
+Or you can replace the files in `yourpath/anaconda3/envs/ianvs/lib/python3.x/site-packages/sedna` with `examples/resources/sedna-llm.zip`
+
+## Run Ianvs
+
+### Objective
+
+`ianvs -f examples/government/singletask_learning_bench/objective/benchmarkingjob.yaml`
+
+### Subjective
+
+`ianvs -f examples/government/singletask_learning_bench/subjective/benchmarkingjob.yaml`
\ No newline at end of file
diff --git a/examples/government/singletask_learning_bench/imgs/structure.png b/examples/government/singletask_learning_bench/imgs/structure.png
new file mode 100644
index 00000000..22c1695e
Binary files /dev/null and b/examples/government/singletask_learning_bench/imgs/structure.png differ
diff --git a/examples/government/singletask_learning_bench/objective/benchmarkingjob.yaml b/examples/government/singletask_learning_bench/objective/benchmarkingjob.yaml
new file mode 100644
index 00000000..38c8f2c5
--- /dev/null
+++ b/examples/government/singletask_learning_bench/objective/benchmarkingjob.yaml
@@ -0,0 +1,72 @@
+benchmarkingjob:
+  # job name of benchmarking; string type;
+  name: "benchmarkingjob"
+  # the url address of job workspace that will reserve the output of tests; string type;
+  workspace: "/home/icyfeather/project/ianvs/workspace"
+
+  # the url address of test environment configuration file; string type;
+  # the file format supports yaml/yml;
+  testenv: "./examples/government/singletask_learning_bench/objective/testenv/testenv.yaml"
+
+  # the configuration of test object
+  test_object:
+    # test type; string type;
+    # currently the option of value is "algorithms"; the others will be added in succession.
+ type: "algorithms" + # test algorithm configuration files; list type; + algorithms: + # algorithm name; string type; + - name: "politic_bench_singletask_learning" + # the url address of test algorithm configuration file; string type; + # the file format supports yaml/yml; + url: "./examples/government/singletask_learning_bench/objective/testalgorithms/gen/gen_algorithm.yaml" + + # the configuration of ranking leaderboard + rank: + # rank leaderboard with metric of test case's evaluation and order ; list type; + # the sorting priority is based on the sequence of metrics in the list from front to back; + sort_by: [ { "acc": "descend" } ] + + # visualization configuration + visualization: + # mode of visualization in the leaderboard; string type; + # There are quite a few possible dataitems in the leaderboard. Not all of them can be shown simultaneously on the screen. + # In the leaderboard, we provide the "selected_only" mode for the user to configure what is shown or is not shown. + mode: "selected_only" + # method of visualization for selected dataitems; string type; + # currently the options of value are as follows: + # 1> "print_table": print selected dataitems; + method: "print_table" + + # selected dataitem configuration + # The user can add his/her interested dataitems in terms of "paradigms", "modules", "hyperparameters" and "metrics", + # so that the selected columns will be shown. + selected_dataitem: + # currently the options of value are as follows: + # 1> "all": select all paradigms in the leaderboard; + # 2> paradigms in the leaderboard, e.g., "singletasklearning" + paradigms: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all modules in the leaderboard; + # 2> modules in the leaderboard, e.g., "basemodel" + modules: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all hyperparameters in the leaderboard; + # 2> hyperparameters in the leaderboard, e.g., "momentum" + hyperparameters: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all metrics in the leaderboard; + # 2> metrics in the leaderboard, e.g., "f1_score" + metrics: [ "acc" ] + + # model of save selected and all dataitems in workspace; string type; + # currently the options of value are as follows: + # 1> "selected_and_all": save selected and all dataitems; + # 2> "selected_only": save selected dataitems; + save_mode: "selected_and_all" + + + + + + diff --git a/examples/government/singletask_learning_bench/objective/testalgorithms/gen/basemodel.py b/examples/government/singletask_learning_bench/objective/testalgorithms/gen/basemodel.py new file mode 100644 index 00000000..b6340ec3 --- /dev/null +++ b/examples/government/singletask_learning_bench/objective/testalgorithms/gen/basemodel.py @@ -0,0 +1,105 @@ +# Copyright 2022 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import, division + +import os +import tempfile +import time +import zipfile +import logging + +import numpy as np +import random +from tqdm import tqdm +from sedna.common.config import Context +from sedna.common.class_factory import ClassType, ClassFactory +from core.common.log import LOGGER + + +from transformers import AutoModelForCausalLM, AutoTokenizer +device = "cuda" # the device to load the model onto + + +logging.disable(logging.WARNING) + +__all__ = ["BaseModel"] + +os.environ['BACKEND_TYPE'] = 'TORCH' + + +@ClassFactory.register(ClassType.GENERAL, alias="gen") +class BaseModel: + + def __init__(self, **kwargs): + self.model = AutoModelForCausalLM.from_pretrained( + "/home/icyfeather/models/Qwen2-0.5B-Instruct", + torch_dtype="auto", + device_map="auto" + ) + self.tokenizer = AutoTokenizer.from_pretrained("/home/icyfeather/models/Qwen2-0.5B-Instruct") + + def train(self, train_data, valid_data=None, **kwargs): + LOGGER.info("BaseModel train") + + + def save(self, model_path): + LOGGER.info("BaseModel save") + + def predict(self, data, input_shape=None, **kwargs): + LOGGER.info("BaseModel predict") + LOGGER.info(f"Dataset: {data.dataset_name}") + LOGGER.info(f"Description: {data.description}") + LOGGER.info(f"Data Level 1 Dim: {data.level_1_dim}") + LOGGER.info(f"Data Level 2 Dim: {data.level_2_dim}") + + answer_list = [] + for line in tqdm(data.x, desc="Processing", unit="question"): + # 3-shot + indices = random.sample([i for i, l in enumerate(data.x) if l != line], 3) + history = [] + for idx in indices: + history.append({"role": "user", "content": data.x[idx]}) + history.append({"role": "assistant", "content": data.y[idx]}) + history.append({"role": "user", "content": line}) + response = self._infer(history) + answer_list.append(response) + return answer_list + + def load(self, model_url=None): + LOGGER.info("BaseModel load") + + def evaluate(self, data, model_path, **kwargs): + LOGGER.info("BaseModel evaluate") + + def _infer(self, messages): + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + model_inputs = self.tokenizer([text], return_tensors="pt").to(device) + + generated_ids = self.model.generate( + model_inputs.input_ids, + max_new_tokens=512, + temperature = 0.1, + top_p = 0.9 + ) + generated_ids = [ + output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) + ] + + response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] + return response diff --git a/examples/government/singletask_learning_bench/objective/testalgorithms/gen/gen_algorithm.yaml b/examples/government/singletask_learning_bench/objective/testalgorithms/gen/gen_algorithm.yaml new file mode 100644 index 00000000..3167cbe8 --- /dev/null +++ b/examples/government/singletask_learning_bench/objective/testalgorithms/gen/gen_algorithm.yaml @@ -0,0 +1,18 @@ +algorithm: + # paradigm name; string type; + # currently the options of value are as follows: + # 1> "singletasklearning" + # 2> "incrementallearning" + paradigm_type: "singletasklearning" + + # algorithm module configuration in the paradigm; list type; + modules: + # kind of algorithm module; string type; + # currently the options of value are as follows: + # 1> "basemodel" + - type: "basemodel" + # name of python module; string type; + # example: basemodel.py has BaseModel module that the alias is "FPN" for this benchmarking; + name: "gen" + # the url address of python module; string type; + url: 
"./examples/government/singletask_learning_bench/objective/testalgorithms/gen/basemodel.py" \ No newline at end of file diff --git a/examples/government/singletask_learning_bench/objective/testenv/acc.py b/examples/government/singletask_learning_bench/objective/testenv/acc.py new file mode 100644 index 00000000..a4041f48 --- /dev/null +++ b/examples/government/singletask_learning_bench/objective/testenv/acc.py @@ -0,0 +1,39 @@ +# Copyright 2022 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sedna.common.class_factory import ClassType, ClassFactory + +__all__ = ["acc"] + +def get_last_letter(input_string): + if not input_string or not any(char.isalpha() for char in input_string): + return None + + for char in reversed(input_string): + if 'A' <= char <= 'D': + return char + + return None + + +@ClassFactory.register(ClassType.GENERAL, alias="acc") +def acc(y_true, y_pred): + y_pred = [get_last_letter(pred) for pred in y_pred] + y_true = [get_last_letter(pred) for pred in y_true] + + same_elements = [y_pred[i] == y_true[i] for i in range(len(y_pred))] + + acc = sum(same_elements) / len(same_elements) + + return acc diff --git a/examples/government/singletask_learning_bench/objective/testenv/testenv.yaml b/examples/government/singletask_learning_bench/objective/testenv/testenv.yaml new file mode 100644 index 00000000..e3a13834 --- /dev/null +++ b/examples/government/singletask_learning_bench/objective/testenv/testenv.yaml @@ -0,0 +1,14 @@ +testenv: + # dataset configuration + dataset: + # the url address of train dataset index; string type; + train_data: "/home/icyfeather/Projects/ianvs/dataset/government/objective/train_data/data.jsonl" + # the url address of test dataset index; string type; + test_data_info: "/home/icyfeather/Projects/ianvs/dataset/government/objective/test_data/metadata.json" + + # metrics configuration for test case's evaluation; list type; + metrics: + # metric name; string type; + - name: "acc" + # the url address of python file + url: "./examples/government/singletask_learning_bench/objective/testenv/acc.py" diff --git a/examples/government/singletask_learning_bench/subjective/benchmarkingjob.yaml b/examples/government/singletask_learning_bench/subjective/benchmarkingjob.yaml new file mode 100644 index 00000000..26008c3c --- /dev/null +++ b/examples/government/singletask_learning_bench/subjective/benchmarkingjob.yaml @@ -0,0 +1,72 @@ +benchmarkingjob: + # job name of bechmarking; string type; + name: "benchmarkingjob" + # the url address of job workspace that will reserve the output of tests; string type; + workspace: "/home/icyfeather/project/ianvs/workspace" + + # the url address of test environment configuration file; string type; + # the file format supports yaml/yml; + testenv: "./examples/government/singletask_learning_bench/subjective/testenv/testenv.yaml" + + # the configuration of test object + test_object: + # test type; string type; + # currently the option of value is "algorithms",the others will be added in succession. 
+ type: "algorithms" + # test algorithm configuration files; list type; + algorithms: + # algorithm name; string type; + - name: "politic_bench_singletask_learning" + # the url address of test algorithm configuration file; string type; + # the file format supports yaml/yml; + url: "./examples/government/singletask_learning_bench/subjective/testalgorithms/gen/gen_algorithm.yaml" + + # the configuration of ranking leaderboard + rank: + # rank leaderboard with metric of test case's evaluation and order ; list type; + # the sorting priority is based on the sequence of metrics in the list from front to back; + sort_by: [ { "llm_judgement": "descend" } ] + + # visualization configuration + visualization: + # mode of visualization in the leaderboard; string type; + # There are quite a few possible dataitems in the leaderboard. Not all of them can be shown simultaneously on the screen. + # In the leaderboard, we provide the "selected_only" mode for the user to configure what is shown or is not shown. + mode: "selected_only" + # method of visualization for selected dataitems; string type; + # currently the options of value are as follows: + # 1> "print_table": print selected dataitems; + method: "print_table" + + # selected dataitem configuration + # The user can add his/her interested dataitems in terms of "paradigms", "modules", "hyperparameters" and "metrics", + # so that the selected columns will be shown. + selected_dataitem: + # currently the options of value are as follows: + # 1> "all": select all paradigms in the leaderboard; + # 2> paradigms in the leaderboard, e.g., "singletasklearning" + paradigms: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all modules in the leaderboard; + # 2> modules in the leaderboard, e.g., "basemodel" + modules: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all hyperparameters in the leaderboard; + # 2> hyperparameters in the leaderboard, e.g., "momentum" + hyperparameters: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all metrics in the leaderboard; + # 2> metrics in the leaderboard, e.g., "f1_score" + metrics: [ "llm_judgement" ] + + # model of save selected and all dataitems in workspace; string type; + # currently the options of value are as follows: + # 1> "selected_and_all": save selected and all dataitems; + # 2> "selected_only": save selected dataitems; + save_mode: "selected_and_all" + + + + + + diff --git a/examples/government/singletask_learning_bench/subjective/testalgorithms/gen/basemodel.py b/examples/government/singletask_learning_bench/subjective/testalgorithms/gen/basemodel.py new file mode 100644 index 00000000..ee7f2585 --- /dev/null +++ b/examples/government/singletask_learning_bench/subjective/testalgorithms/gen/basemodel.py @@ -0,0 +1,131 @@ +# Copyright 2022 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import absolute_import, division
+
+import os
+import logging
+
+from tqdm import tqdm
+from sedna.common.class_factory import ClassType, ClassFactory
+from core.common.log import LOGGER
+from openai import OpenAI
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+device = "cuda"  # the device to load the model onto
+
+logging.disable(logging.WARNING)
+
+__all__ = ["BaseModel"]
+
+os.environ['BACKEND_TYPE'] = 'TORCH'
+
+
+@ClassFactory.register(ClassType.GENERAL, alias="gen")
+class BaseModel:
+
+    def __init__(self, **kwargs):
+        self.model = AutoModelForCausalLM.from_pretrained(
+            "/home/icyfeather/models/Qwen2-0.5B-Instruct",
+            torch_dtype="auto",
+            device_map="auto"
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained("/home/icyfeather/models/Qwen2-0.5B-Instruct")
+
+    def train(self, train_data, valid_data=None, **kwargs):
+        LOGGER.info("BaseModel train")
+
+    def save(self, model_path):
+        LOGGER.info("BaseModel save")
+
+    def predict(self, data, input_shape=None, **kwargs):
+        LOGGER.info("BaseModel predict")
+        LOGGER.info(f"Dataset: {data.dataset_name}")
+        LOGGER.info(f"Description: {data.description}")
+        LOGGER.info(f"Data Level 1 Dim: {data.level_1_dim}")
+        LOGGER.info(f"Data Level 2 Dim: {data.level_2_dim}")
+
+        answer_list = []
+        for line in tqdm(data.x, desc="Processing", unit="question"):
+            history = []
+            history.append({"role": "user", "content": line})
+            response = self._infer(history)
+            answer_list.append(response)
+
+        judgement_list = []
+
+        # evaluate each answer with the LLM judge
+        for index in tqdm(range(len(answer_list)), desc="Evaluating", ascii=False, ncols=75):
+            prompt = data.judge_prompts[index] + answer_list[index]
+            judgement = self._openai_generate(prompt)
+            judgement_list.append(judgement)
+
+        return judgement_list
+
+    def load(self, model_url=None):
+        LOGGER.info("BaseModel load")
+
+    def evaluate(self, data, model_path, **kwargs):
+        LOGGER.info("BaseModel evaluate")
+
+    def _infer(self, messages):
+        text = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        model_inputs = self.tokenizer([text], return_tensors="pt").to(device)
+
+        generated_ids = self.model.generate(
+            model_inputs.input_ids,
+            max_new_tokens=512,
+            temperature=0.1,
+            top_p=0.9
+        )
+        generated_ids = [
+            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+        ]
+
+        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        return response
+
+    def _openai_generate(self, user_question, system=None):
+        key = os.getenv("DEEPSEEK_API_KEY")
+        if not key:
+            raise ValueError("You should set DEEPSEEK_API_KEY in your env.")
+        client = OpenAI(api_key=key, base_url="https://api.deepseek.com")
+
+        messages = []
+        if system:
+            messages.append({"role": "system", "content": system})
+        messages.append({"role": "user", "content": user_question})
+
+        response = client.chat.completions.create(
+            model="deepseek-chat",
+            messages=messages,
+            stream=False
+        )
+
+        res = response.choices[0].message.content
+
+        return res
\ No newline at end of file
diff --git a/examples/government/singletask_learning_bench/subjective/testalgorithms/gen/gen_algorithm.yaml b/examples/government/singletask_learning_bench/subjective/testalgorithms/gen/gen_algorithm.yaml
new file mode 100644
index 00000000..f20e9047
--- /dev/null
+++ b/examples/government/singletask_learning_bench/subjective/testalgorithms/gen/gen_algorithm.yaml
@@ -0,0 +1,18 @@
+algorithm:
+  # paradigm name; string type;
+  # currently the options of value are as follows:
+  #   1> "singletasklearning"
+  #   2> "incrementallearning"
+  paradigm_type: "singletasklearning"
+
+  # algorithm module configuration in the paradigm; list type;
+  modules:
+    # kind of algorithm module; string type;
+    # currently the options of value are as follows:
+    #   1> "basemodel"
+    - type: "basemodel"
+      # name of python module; string type;
+      # example: basemodel.py has a BaseModel class registered under the alias "gen" for this benchmarking;
+      name: "gen"
+      # the url address of python module; string type;
+      url: "./examples/government/singletask_learning_bench/subjective/testalgorithms/gen/basemodel.py"
\ No newline at end of file
diff --git a/examples/government/singletask_learning_bench/subjective/testenv/llm_judgement.py b/examples/government/singletask_learning_bench/subjective/testenv/llm_judgement.py
new file mode 100644
index 00000000..97cbc72a
--- /dev/null
+++ b/examples/government/singletask_learning_bench/subjective/testenv/llm_judgement.py
@@ -0,0 +1,42 @@
+# Copyright 2022 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from sedna.common.class_factory import ClassType, ClassFactory
+from core.common.log import LOGGER
+
+__all__ = ["llm_judgement"]
+
+def extract_comprehensive_score(input_str):
+    # extract the overall points awarded by the LLM judge
+    match = re.search(r"'Overall Points': (\d+)", input_str)
+    if match:
+        return int(match.group(1))
+    else:
+        return None
+
+
+@ClassFactory.register(ClassType.GENERAL, alias="llm_judgement")
+def llm_judgement(y_true, y_pred):
+    y_pred = [extract_comprehensive_score(pred) for pred in y_pred]
+
+    valid_scores = [score for score in y_pred if score is not None]
+
+    LOGGER.info(f"Extracted {len(valid_scores)} valid scores from {len(y_pred)} judgements")
+
+    if valid_scores:
+        average_score = sum(valid_scores) / len(valid_scores)
+        return average_score
+    else:
+        return -1
diff --git a/examples/government/singletask_learning_bench/subjective/testenv/testenv.yaml b/examples/government/singletask_learning_bench/subjective/testenv/testenv.yaml
new file mode 100644
index 00000000..f197b2fb
--- /dev/null
+++ b/examples/government/singletask_learning_bench/subjective/testenv/testenv.yaml
@@ -0,0 +1,14 @@
+testenv:
+  # dataset configuration
+  dataset:
+    # the url address of train dataset index; string type;
+    train_data: "/home/icyfeather/Projects/ianvs/dataset/government/subjective/train_data/data.jsonl"
+    # the url address of test dataset index; string type;
+    test_data_info: "/home/icyfeather/Projects/ianvs/dataset/government/subjective/test_data/metadata.json"
+
+  # metrics configuration for test case's evaluation; list type;
+  metrics:
+    # metric name; string type;
+    - name: "llm_judgement"
+      # the url address of python file
+      url: "./examples/government/singletask_learning_bench/subjective/testenv/llm_judgement.py"
diff --git a/examples/imagenet/multiedge_inference_bench/README.md b/examples/imagenet/multiedge_inference_bench/README.md
new file mode 100644
index 00000000..ae0fd01e
--- /dev/null
+++ b/examples/imagenet/multiedge_inference_bench/README.md
@@ -0,0 +1,104 @@
+# Benchmarking of Image Classification for High Mobility Scenarios
+
+In high-mobility scenarios such as highways and high-speed railways, the connection between personal terminal devices and cloud servers is significantly weakened. However, in recent years, artificial intelligence technology has permeated every aspect of our lives, and we also need to use AI technologies that have high computational and storage demands and are sensitive to latency in high-mobility scenarios. For example, even when driving through a tunnel with a weak network environment, we may still need AI capabilities such as image classification. Therefore, in the event that edge devices lose connection with the cloud, offloading AI computing tasks to adjacent edge devices, and aggregating computing power through mutual collaboration between devices so as to complete tasks that traditionally require cloud-edge collaboration, has become an issue worth addressing. This benchmarking job aims to simulate such a scenario: using multiple heterogeneous computing units on the edge (such as personal mobile phones, tablets, bracelets, laptops, and other computing devices) for collaborative ViT inference, enabling image classification to be completed with lower latency on devices closer to the user and thereby enhancing the user experience. After running benchmarking jobs, a report will be generated.
+
+With Ianvs installed and the related environment prepared, users are then able to run the benchmarking process using the following steps. If you haven't installed Ianvs, please refer to [how-to-install-ianvs](../../../docs/guides/how-to-install-ianvs.md).
+
+## Prerequisites
+
+To set up the environment, run the following commands:
+```shell
+cd
+pip install ./examples/resources/third_party/*
+pip install -r requirements.txt
+cd ./examples/imagenet/multiedge_inference_bench/
+pip install -r requirements.txt
+cd
+mkdir dataset initial_model
+```
+Please refer to [this link](https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html) and ensure that the versions of CUDA and cuDNN are compatible with the version of ONNX Runtime.
+
+Note that it is advisable to avoid using lower versions of the ONNX library, as they are very time-consuming when performing computational graph partitioning. The onnx/CUDA/cuDNN versions we used in our tests are as follows:
+![onnx_version](images/onnx_version.png)
+
+## Step 1. Prepare Dataset
+Download the [ImageNet 2012 dataset](https://image-net.org/download.php) and put it under /dataset in the following structure:
+
+```
+dataset
+  |------ILSVRC2012_devkit_t12.tar.gz
+  |------ILSVRC2012_img_val.tar
+```
+Then, you need to process the dataset and generate the _train.txt_ and _test.txt_:
+
+```shell
+cd
+python ./examples/imagenet/multiedge_inference_bench/testalgorithms/manual/dataset.py
+```
+
+## Step 2. Prepare Model
+
+Next, download the pretrained model from [huggingface](https://huggingface.co/optimum/vit-base-patch16-224/tree/main), rename it to vit-base-patch16-224.onnx, and put it under /initial_model/.
+
+## Step 3. Run Benchmarking Job - Manual
+We are now ready to run Ianvs to benchmark image classification for high-mobility scenarios on the ImageNet dataset.
+
+```shell
+ianvs -f ./examples/imagenet/multiedge_inference_bench/classification_job_manual.yaml
+```
+
+The benchmarking process takes a few minutes and varies depending on devices.
+
+## Step 4. Check the Result
+
+Finally, the user can check the result of benchmarking on the console and also in the output path (/ianvs/multiedge_inference_bench/workspace) defined in the benchmarking config file (classification_job_manual.yaml or classification_job_automatic.yaml).
+
+The final output might look like this:
+![result](images/result.png)
+
+You can view the graphical representation of relevant metrics in /ianvs/multiedge_inference_bench/workspace/images/, such as the following:
+![plot](images/plot.png)
+
+To compare how the model runs with and without parallelism in the multiedge inference scenario, you can change the value of --devices_info in basemodel.py to devices_one.yaml to view the relevant metrics when the model runs on a single device.
+
+## Step 5. Run Benchmarking Job - Automatic
+We offer a profiling-based, memory-matching partition algorithm to compare with the method of manually specifying partition points. This method prioritizes memory matching between each computational subgraph and its device. First, we profile the initial model on the CPU to collect the memory usage, number of parameters, computational cost, and input and output data shapes of each layer, as well as the total number of layers and their names in the entire model. To facilitate subsequent integration, we have implemented profiling for three types of transformer models: vit, bert, and deit. Second, based on the profiling results and the device information provided in devices.yaml, we identify the partition points that match each device's memory in a single traversal and partition the model accordingly. A simplified, illustrative sketch of this matching rule follows (the shipped implementation is the `partition` method in testalgorithms/automatic/basemodel.py):
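+
+```python
+# Illustrative sketch only: choose cut points so that each device receives a
+# contiguous run of layers whose memory roughly matches its memory share.
+# profile_data entries carry "memory" and "shape_out" (as in
+# profiler_results.yml); devices carry "memory" (as in devices.yaml).
+def match_partition_points(profile_data, devices):
+    total_model_mem = sum(layer["memory"] for layer in profile_data)
+    total_device_mem = sum(int(dev["memory"]) for dev in devices)
+    # Scale device budgets so that they sum to the model's total footprint.
+    budgets = [int(dev["memory"]) * total_model_mem / total_device_mem
+               for dev in devices]
+    cut_points, acc_mem, dev_idx = [], 0, 0
+    for idx, layer in enumerate(profile_data):
+        acc_mem += layer["memory"]
+        # Cut only where a layer has a single output tensor, once the current
+        # device's budget is used up (the last device takes the remainder).
+        if (dev_idx < len(budgets) - 1
+                and len(layer["shape_out"]) == 1
+                and acc_mem >= budgets[dev_idx]):
+            cut_points.append(idx)
+            acc_mem, dev_idx = 0, dev_idx + 1
+    return cut_points  # layer indices after which to split the model
+```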
+
+You should first run the following command to generate a profiling result:
+```shell
+cd
+python ./examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/profiler.py
+```
+
+Then you will find a profiler_results.yml file in the /examples/imagenet/multiedge_inference_bench/testalgorithms/automatic directory, just like this:
+![profiler_result](images/profiler_results.png)
+
+Then you can run the following command to perform benchmarking:
+```shell
+ianvs -f ./examples/imagenet/multiedge_inference_bench/classification_job_automatic.yaml
+```
+
+After running, you will see the gain of the automatic method compared with the manual method.
+![result](images/auto_result.png)
+
+## Explanation for devices.yaml
+
+This file defines the information of the edge-side devices and the model's partition points. The devices section includes the computing resource type, memory, frequency, and bandwidth of each device. The partition_points section defines the input and output names of each computational subgraph and their mapping to devices. This benchmarking job achieves the partitioning of the computational graph and model parallelism by manually defining partition points. You can implement custom partitioning algorithms based on the rich device information in devices.yaml.
+
+## Custom Partitioning Algorithms
+
+How to partition an ONNX model based on device information is an interesting question. You can solve this issue using greedy algorithms, dynamic programming algorithms, or other innovative graph algorithms to achieve optimal resource utilization and the lowest inference latency.
+
+More partitioning algorithms will be added in the future, and you can customize your own partition method in basemodel.py; it only needs to comply with the input and output specification defined by the following interface (a sketch of one possible custom method follows the interface below):
+
+```python
+def partition(self, initial_model):
+    ## 1. parsing
+    ## 2. modeling
+    ## 3. partition
+    return models_dir, map_info
+```
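+
+As one illustrative starting point (not the shipped implementation), a custom method could take the following shape; it assumes the devices.yaml structure above and elides the actual subgraph export, which the automatic basemodel.py performs with torch.onnx.export:
+
+```python
+from pathlib import Path
+
+def partition(self, initial_model):
+    # Must return (models_dir, map_info): the directory holding the exported
+    # sub-models, and a mapping from sub-model file name to device name.
+    models_dir = str(Path(initial_model).parent.resolve())
+    map_info = {}
+    for i, device in enumerate(self.device_info.get("devices")):
+        # ...decide this device's subgraph here and export it to
+        # models_dir/sub_model_<i+1>.onnx (e.g., with torch.onnx.export)...
+        map_info[f"sub_model_{i + 1}.onnx"] = device.get("name")
+    return models_dir, map_info
+```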
+
+Hope you have a perfect journey in solving this problem!
+
+
diff --git a/examples/imagenet/multiedge_inference_bench/classification_job_automatic.yaml b/examples/imagenet/multiedge_inference_bench/classification_job_automatic.yaml
new file mode 100644
index 00000000..4e236e71
--- /dev/null
+++ b/examples/imagenet/multiedge_inference_bench/classification_job_automatic.yaml
@@ -0,0 +1,72 @@
+benchmarkingjob:
+  # job name of benchmarking; string type;
+  name: "classification_job"
+  # the url address of job workspace that will reserve the output of tests; string type;
+  workspace: "./multiedge_inference_bench/workspace"
+
+  # the url address of test environment configuration file; string type;
+  # the file format supports yaml/yml;
+  testenv: "./examples/imagenet/multiedge_inference_bench/testenv/testenv.yaml"
+
+  # the configuration of test object
+  test_object:
+    # test type; string type;
+    # currently the option of value is "algorithms"; the others will be added in succession.
+    type: "algorithms"
+    # test algorithm configuration files; list type;
+    algorithms:
+      # algorithm name; string type;
+      - name: "classification"
+        # the url address of test algorithm configuration file; string type;
+        # the file format supports yaml/yml;
+        url: "./examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/classification_algorithm.yaml"
+
+  # the configuration of ranking leaderboard
+  rank:
+    # rank leaderboard with metric of test case's evaluation and order; list type;
+    # the sorting priority is based on the sequence of metrics in the list from front to back;
+    sort_by: [ { "mota": "descend" } ]
+
+    # visualization configuration
+    visualization:
+      # mode of visualization in the leaderboard; string type;
+      # There are quite a few possible dataitems in the leaderboard. Not all of them can be shown simultaneously on the screen.
+      # In the leaderboard, we provide the "selected_only" mode for the user to configure what is shown or is not shown.
+      mode: "selected_only"
+      # method of visualization for selected dataitems; string type;
+      # currently the options of value are as follows:
+      #   1> "print_table": print selected dataitems;
+      method: "print_table"
+
+    # selected dataitem configuration
+    # The user can add his/her interested dataitems in terms of "paradigms", "modules", "hyperparameters" and "metrics",
+    # so that the selected columns will be shown.
+    selected_dataitem:
+      # currently the options of value are as follows:
+      #   1> "all": select all paradigms in the leaderboard;
+      #   2> paradigms in the leaderboard, e.g., "singletasklearning"
+      paradigms: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all modules in the leaderboard;
+      #   2> modules in the leaderboard, e.g., "basemodel"
+      modules: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all hyperparameters in the leaderboard;
+      #   2> hyperparameters in the leaderboard, e.g., "momentum"
+      hyperparameters: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all metrics in the leaderboard;
+      #   2> metrics in the leaderboard, e.g., "f1_score"
+      metrics: [ "all" ]
+
+    # mode for saving selected and all dataitems in the workspace; string type;
+    # currently the options of value are as follows:
+    #   1> "selected_and_all": save selected and all dataitems;
+    #   2> "selected_only": save selected dataitems;
+    save_mode: "selected_and_all"
+
+
+
+
+
diff --git a/examples/imagenet/multiedge_inference_bench/classification_job_manual.yaml b/examples/imagenet/multiedge_inference_bench/classification_job_manual.yaml
new file mode 100644
index 00000000..1b251c38
--- /dev/null
+++ b/examples/imagenet/multiedge_inference_bench/classification_job_manual.yaml
@@ -0,0 +1,72 @@
+benchmarkingjob:
+  # job name of benchmarking; string type;
+  name: "classification_job"
+  # the url address of job workspace that will reserve the output of tests; string type;
+  workspace: "./multiedge_inference_bench/workspace"
+
+  # the url address of test environment configuration file; string type;
+  # the file format supports yaml/yml;
+  testenv: "./examples/imagenet/multiedge_inference_bench/testenv/testenv.yaml"
+
+  # the configuration of test object
+  test_object:
+    # test type; string type;
+    # currently the option of value is "algorithms"; the others will be added in succession.
+    type: "algorithms"
+    # test algorithm configuration files; list type;
+    algorithms:
+      # algorithm name; string type;
+      - name: "classification"
+        # the url address of test algorithm configuration file; string type;
+        # the file format supports yaml/yml;
+        url: "./examples/imagenet/multiedge_inference_bench/testalgorithms/manual/classification_algorithm.yaml"
+
+  # the configuration of ranking leaderboard
+  rank:
+    # rank leaderboard with metric of test case's evaluation and order; list type;
+    # the sorting priority is based on the sequence of metrics in the list from front to back;
+    sort_by: [ { "mota": "descend" } ]
+
+    # visualization configuration
+    visualization:
+      # mode of visualization in the leaderboard; string type;
+      # There are quite a few possible dataitems in the leaderboard. Not all of them can be shown simultaneously on the screen.
+      # In the leaderboard, we provide the "selected_only" mode for the user to configure what is shown or is not shown.
+      mode: "selected_only"
+      # method of visualization for selected dataitems; string type;
+      # currently the options of value are as follows:
+      #   1> "print_table": print selected dataitems;
+      method: "print_table"
+
+    # selected dataitem configuration
+    # The user can add his/her interested dataitems in terms of "paradigms", "modules", "hyperparameters" and "metrics",
+    # so that the selected columns will be shown.
+    selected_dataitem:
+      # currently the options of value are as follows:
+      #   1> "all": select all paradigms in the leaderboard;
+      #   2> paradigms in the leaderboard, e.g., "singletasklearning"
+      paradigms: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all modules in the leaderboard;
+      #   2> modules in the leaderboard, e.g., "basemodel"
+      modules: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all hyperparameters in the leaderboard;
+      #   2> hyperparameters in the leaderboard, e.g., "momentum"
+      hyperparameters: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all metrics in the leaderboard;
+      #   2> metrics in the leaderboard, e.g., "f1_score"
+      metrics: [ "all" ]
+
+    # mode for saving selected and all dataitems in the workspace; string type;
+    # currently the options of value are as follows:
+    #   1> "selected_and_all": save selected and all dataitems;
+    #   2> "selected_only": save selected dataitems;
+    save_mode: "selected_and_all"
+
+
+
+
+
diff --git a/examples/imagenet/multiedge_inference_bench/images/auto_result.png b/examples/imagenet/multiedge_inference_bench/images/auto_result.png
new file mode 100644
index 00000000..e75443db
Binary files /dev/null and b/examples/imagenet/multiedge_inference_bench/images/auto_result.png differ
diff --git a/examples/imagenet/multiedge_inference_bench/images/onnx_version.png b/examples/imagenet/multiedge_inference_bench/images/onnx_version.png
new file mode 100644
index 00000000..076e4a16
Binary files /dev/null and b/examples/imagenet/multiedge_inference_bench/images/onnx_version.png differ
diff --git a/examples/imagenet/multiedge_inference_bench/images/plot.png b/examples/imagenet/multiedge_inference_bench/images/plot.png
new file mode 100644
index 00000000..c3f93796
Binary files /dev/null and b/examples/imagenet/multiedge_inference_bench/images/plot.png differ
diff --git a/examples/imagenet/multiedge_inference_bench/images/profiler_results.png b/examples/imagenet/multiedge_inference_bench/images/profiler_results.png
new file mode 100644
index 00000000..9b46bb37
Binary files /dev/null
and b/examples/imagenet/multiedge_inference_bench/images/profiler_results.png differ diff --git a/examples/imagenet/multiedge_inference_bench/images/result.png b/examples/imagenet/multiedge_inference_bench/images/result.png new file mode 100644 index 00000000..9a4ef272 Binary files /dev/null and b/examples/imagenet/multiedge_inference_bench/images/result.png differ diff --git a/examples/imagenet/multiedge_inference_bench/requirements.txt b/examples/imagenet/multiedge_inference_bench/requirements.txt new file mode 100644 index 00000000..18b77692 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/requirements.txt @@ -0,0 +1,113 @@ +absl-py==2.1.0 +addict==2.4.0 +asgiref==3.7.2 +cachetools==5.3.3 +certifi==2024.2.2 +charset-normalizer==3.3.2 +click==8.1.7 +coloredlogs==15.0.1 +colorlog==4.7.2 +cycler==0.11.0 +Cython==3.0.10 +cython-bbox==0.1.5 +fastapi==0.68.2 +filelock==3.12.2 +filterpy==1.4.5 +flatbuffers==24.3.25 +fonttools==4.38.0 +fpdf==1.7.2 +fsspec==2023.1.0 +google-auth==2.29.0 +google-auth-oauthlib==0.4.6 +grpcio==1.62.2 +h11==0.14.0 +h5py==3.8.0 +huggingface-hub==0.16.4 +humanfriendly==10.0 +ianvs==0.1.0 +idna==3.7 +imageio==2.31.2 +importlib-metadata==6.7.0 +joblib==1.2.0 +kiwisolver==1.4.5 +lap==0.4.0 +loguru==0.7.2 +Markdown==3.4.4 +markdown-it-py==2.2.0 +MarkupSafe==2.1.5 +matplotlib==3.5.3 +mdurl==0.1.2 +minio==7.0.4 +mmcv==1.5.0 +mmengine==0.10.4 +motmetrics==1.4.0 +mpmath==1.3.0 +networkx==2.6.3 +ninja==1.11.1.1 +numpy==1.21.6 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cuda-nvrtc-cu11==11.7.99 +nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cudnn-cu11==8.5.0.96 +oauthlib==3.2.2 +onnx==1.14.1 +onnx-simplifier==0.3.5 +onnxoptimizer==0.3.13 +onnxruntime==1.14.1 +onnxruntime-gpu==1.14.1 +opencv-python==4.9.0.80 +packaging==24.0 +pandas==1.3.5 +Pillow==9.5.0 +platformdirs==4.0.0 +prettytable==2.5.0 +protobuf==3.20.3 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pycocotools==2.0.7 +pydantic==1.10.15 +Pygments==2.17.2 +pynvml==11.5.3 +pyparsing==3.1.2 +python-dateutil==2.9.0.post0 +pytz==2024.1 +PyWavelets==1.3.0 +PyYAML==6.0.1 +regex==2024.4.16 +requests==2.31.0 +requests-oauthlib==2.0.0 +rich==13.7.1 +rsa==4.9 +safetensors==0.4.5 +scikit-image==0.19.3 +scikit-learn==1.0.2 +scipy==1.7.3 +seaborn==0.12.2 +six==1.15.0 +starlette==0.14.2 +sympy==1.10.1 +tabulate==0.9.0 +tenacity==8.0.1 +tensorboard==2.11.2 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +termcolor==2.3.0 +thop==0.1.1.post2209072238 +threadpoolctl==3.1.0 +tifffile==2021.11.2 +tokenizers==0.13.3 +tomli==2.0.1 +torch==1.13.1 +torchvision==0.14.1 +tqdm==4.66.4 +transformers==4.30.2 +typing_extensions==4.7.1 +urllib3==2.0.7 +uvicorn==0.14.0 +wcwidth==0.2.13 +websockets==9.1 +Werkzeug==2.2.3 +xmltodict==0.13.0 +yapf==0.40.2 +-e git+https://github.com/ifzhang/ByteTrack.git@d1bf0191adff59bc8fcfeaa0b33d3d1642552a99#egg=yolox +zipp==3.15.0 diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/basemodel.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/basemodel.py new file mode 100644 index 00000000..a35e5483 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/basemodel.py @@ -0,0 +1,205 @@ +# Modified Copyright 2022 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import glob +import os +from collections import OrderedDict +from pathlib import Path +from collections import defaultdict +import time + +from sedna.common.class_factory import ClassType, ClassFactory +from dataset import load_dataset +import model_cfg + +import yaml +import onnxruntime as ort +from torch.utils.data import DataLoader +import torch +import numpy as np +from tqdm import tqdm +import pynvml + + +__all__ = ["BaseModel"] + +# set backend +os.environ["BACKEND_TYPE"] = "ONNX" + + +def make_parser(): + parser = argparse.ArgumentParser("ViT Eval") + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument("--devices_info", default="./devices.yaml", type=str, help="devices conf") + parser.add_argument("--profiler_info", default="./profiler_results.yml", type=str, help="profiler results") + parser.add_argument("--model_parallel", default=True, action="store_true") + parser.add_argument("--split", default="val", type=str, help="split of dataset") + parser.add_argument("--indices", default=None, type=str, help="indices of dataset") + parser.add_argument("--shuffle", default=False, action="store_true", help="shuffle data") + parser.add_argument("--model_name", default="google/vit-base-patch16-224", type=str, help="model name") + parser.add_argument("--dataset_name", default="ImageNet", type=str, help="dataset name") + parser.add_argument("--data_size", default=1000, type=int, help="data size to inference") + # remove conflict with ianvs + parser.add_argument("-f") + return parser + + +@ClassFactory.register(ClassType.GENERAL, alias="Classification") +class BaseModel: + + def __init__(self, **kwargs) -> None: + self.args = make_parser().parse_args() + self.model_parallel = self.args.model_parallel + self.models = [] + self.devices_info_url = str(Path(Path(__file__).parent.resolve(), self.args.devices_info)) + self.device_info = self._parse_yaml(self.devices_info_url) + self.profiler_info_url = str(Path(Path(__file__).parent.resolve(), self.args.profiler_info)) + self.profiler_info = self._parse_yaml(self.profiler_info_url) + self.partition_point_list = [] + return + + ## auto partition by memory usage + def partition(self, initial_model): + map_info = {} + def _partition_model(pre, cur, flag): + print("========= Sub Model {} Partition =========".format(flag)) + model = model_cfg.module_shard_factory(self.args.model_name, initial_model, pre+1, cur+1, 1) + dummy_input = torch.randn(1, *self.profiler_info.get('profile_data')[pre].get("shape_in")[0]) + torch.onnx.export(model, + dummy_input, + str(Path(Path(initial_model).parent.resolve())) + "/sub_model_" + str(flag) + ".onnx", + export_params=True, + opset_version=16, + do_constant_folding=True, + input_names=['input_' + str(pre+1)], + output_names=['output_' + str(cur+1)]) + self.partition_point_list.append({ + 'input_names': ['input_' + str(pre+1)], + 'output_names': ['output_' + str(cur+1)] + }) + map_info["sub_model_" + str(flag) + ".onnx"] = self.device_info.get('devices')[flag-1].get("name") + + layer_list = 
[(layer.get("memory"), len(layer.get("shape_out"))) for layer in self.profiler_info.get('profile_data')]
+        total_model_memory = sum([layer[0] for layer in layer_list])
+        devices_memory = [int(device.get('memory')) for device in self.device_info.get('devices')]
+        total_devices_memory = sum(devices_memory)
+        devices_memory = [per_mem * total_model_memory / total_devices_memory for per_mem in devices_memory]
+
+        flag = 0
+        sum_ = 0
+        pre = 0
+        for cur, layer in enumerate(layer_list):
+            if flag == len(devices_memory)-1:
+                cur = len(layer_list)
+                _partition_model(pre, cur-1, flag+1)
+                break
+            elif layer[1] == 1 and sum_ >= devices_memory[flag]:
+                sum_ = 0
+                flag += 1
+                _partition_model(pre, cur, flag)
+                pre = cur + 1
+            else:
+                sum_ += layer[0]
+                continue
+        return str(Path(Path(initial_model).parent.resolve())), map_info
+
+
+    def load(self, models_dir=None, map_info=None) -> None:
+        cnt = 0
+        for model_name, device in map_info.items():
+            model = models_dir + '/' + model_name
+            if not os.path.exists(model):
+                raise ValueError("=> No model found at '{}'".format(model))
+            if device == 'cpu':
+                session = ort.InferenceSession(model, providers=['CPUExecutionProvider'])
+            elif 'gpu' in device:
+                device_id = int(device.split('-')[-1])
+                session = ort.InferenceSession(model, providers=[('CUDAExecutionProvider', {'device_id': device_id})])
+            else:
+                raise ValueError("Invalid device info: '{}'".format(device))
+            self.models.append({
+                'session': session,
+                'name': model_name,
+                'device': device,
+                'input_names': self.partition_point_list[cnt]['input_names'],
+                'output_names': self.partition_point_list[cnt]['output_names'],
+            })
+            cnt += 1
+            print("=> Loaded onnx model: '{}'".format(model))
+        return
+
+    def predict(self, data, input_shape=None, **kwargs):
+        pynvml.nvmlInit()
+        root = str(Path(data[0]).parents[2])
+        dataset_cfg = {
+            'name': self.args.dataset_name,
+            'root': root,
+            'split': self.args.split,
+            'indices': self.args.indices,
+            'shuffle': self.args.shuffle
+        }
+        data_loader, ids = self._get_eval_loader(dataset_cfg)
+        data_loader = tqdm(data_loader, desc='Evaluating', unit='batch')
+        pred = []
+        inference_time_per_device = defaultdict(int)
+        power_usage_per_device = defaultdict(list)
+        mem_usage_per_device = defaultdict(list)
+        cnt = 0
+        for data, id in zip(data_loader, ids):
+            outputs = data[0].numpy()
+            for model in self.models:
+                start_time = time.time()
+                outputs = model['session'].run(None, {model['input_names'][0]: outputs})[0]
+                end_time = time.time()
+                device = model.get('device')
+                inference_time_per_device[device] += end_time - start_time
+                if 'gpu' in device and cnt % 100 == 0:
+                    handle = pynvml.nvmlDeviceGetHandleByIndex(int(device.split('-')[-1]))
+                    power_usage = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0
+                    memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle).used / (1024**2)
+                    power_usage_per_device[device] += [power_usage]
+                    mem_usage_per_device[device] += [memory_info]
+            max_ids = np.argmax(outputs)
+            pred.append((max_ids, id))
+            cnt += 1
+        data_loader.close()
+        result = {}
+        result["pred"] = pred
+        result["inference_time_per_device"] = inference_time_per_device
+        result["power_usage_per_device"] = power_usage_per_device
+        result["mem_usage_per_device"] = mem_usage_per_device
+        return result
+
+
+    def _get_eval_loader(self, dataset_cfg):
+        model_name = self.args.model_name
+        data_size = self.args.data_size
+        dataset, _, ids = load_dataset(dataset_cfg, model_name, data_size)
+        data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
+        return data_loader, ids
+
+    def _parse_yaml(self, url):
+        """Convert yaml file to the dict."""
+        if url.endswith('.yaml') or url.endswith('.yml'):
+            with open(url, "rb") as file:
+                info_dict = yaml.load(file, Loader=yaml.SafeLoader)
+                return info_dict
+        else:
+            raise RuntimeError('config file must be in YAML format')
\ No newline at end of file
diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/classification_algorithm.yaml b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/classification_algorithm.yaml
new file mode 100644
index 00000000..cea77f57
--- /dev/null
+++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/classification_algorithm.yaml
@@ -0,0 +1,27 @@
+algorithm:
+  # paradigm name; string type;
+  # currently the options of value are as follows:
+  #   1> "singletasklearning"
+  #   2> "incrementallearning"
+  paradigm_type: "multiedgeinference"
+  # the url address of initial model; string type; optional;
+  initial_model_url: "./initial_model/ViT-B_16-224.npz"
+
+  # algorithm module configuration in the paradigm; list type;
+  modules:
+    # kind of algorithm module; string type;
+    # currently the options of value are as follows:
+    #   1> "basemodel"
+    - type: "basemodel"
+      # name of python module; string type;
+      # example: basemodel.py has a BaseModel class registered under the alias "Classification" for this benchmarking;
+      name: "Classification"
+      # the url address of python module; string type;
+      url: "./examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/basemodel.py"
+
+  # hyperparameters configuration for the python module; list type;
+  hyperparameters:
+    # name of the hyperparameter; string type;
+    - batch_size:
+        values:
+          - 1
diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/dataset.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/dataset.py
new file mode 100644
index 00000000..9b4ee16c
--- /dev/null
+++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/dataset.py
@@ -0,0 +1,72 @@
+import logging
+import random
+import shutil
+from typing import Callable, Optional, Sequence
+import os
+
+from torch.utils.data import DataLoader, Dataset, Subset
+from transformers import ViTFeatureExtractor
+from torchvision.datasets import ImageNet
+
+
+def load_dataset_imagenet(feature_extractor: Callable, root: str, split: str='train') -> Dataset:
+    """Get the ImageNet dataset."""
+
+    def transform(img):
+        pixels = feature_extractor(images=img.convert('RGB'), return_tensors='pt')['pixel_values']
+        return pixels[0]
+    return ImageNet(root, split=split, transform=transform)
+
+def load_dataset_subset(dataset: Dataset, indices: Optional[Sequence[int]]=None,
+                        max_size: Optional[int]=None, shuffle: bool=False) -> Dataset:
+    """Get a Dataset subset."""
+    if indices is None:
+        indices = list(range(len(dataset)))
+    if shuffle:
+        random.shuffle(indices)
+    if max_size is not None:
+        indices = indices[:max_size]
+    image_paths = []
+    for index in indices:
+        image_paths.append(dataset.imgs[index][0])
+    return Subset(dataset, indices), image_paths, indices
+
+def load_dataset(dataset_cfg: dict, model_name: str, batch_size: int) -> Dataset:
+    """Load inputs based on model."""
+    def _get_feature_extractor():
+        feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
+        return feature_extractor
+    dataset_name = dataset_cfg['name']
+    dataset_root = dataset_cfg['root']
+    dataset_split = dataset_cfg['split']
+    indices = dataset_cfg['indices']
+    dataset_shuffle = dataset_cfg['shuffle']
+    if dataset_name == 'ImageNet':
+        if dataset_root is None:
+            dataset_root = 'ImageNet'
+            logging.info("Dataset root not set, assuming: %s", dataset_root)
+        feature_extractor = _get_feature_extractor()
+        dataset = load_dataset_imagenet(feature_extractor, dataset_root, split=dataset_split)
+        dataset, paths, ids = load_dataset_subset(dataset, indices=indices, max_size=batch_size,
+                                                  shuffle=dataset_shuffle)
+    return dataset, paths, ids
+
+if __name__ == '__main__':
+    dataset_cfg = {
+        'name': "ImageNet",
+        'root': './dataset',
+        'split': 'val',
+        'indices': None,
+        'shuffle': False,
+    }
+    model_name = "google/vit-base-patch16-224"
+    ## Total images to be inferenced.
+    data_size = 1000
+    dataset, paths, _ = load_dataset(dataset_cfg, model_name, data_size)
+    data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
+    with open('./dataset/train.txt', 'w') as f:
+        for i, (image, label) in enumerate(data_loader):
+            original_path = paths[i].replace('/dataset', '')
+            f.write(f"{original_path} {label.item()}\n")
+    # the test split reuses the same index file
+    shutil.copyfile('./dataset/train.txt', './dataset/test.txt')
\ No newline at end of file
diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/devices.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/devices.py
new file mode 100644
index 00000000..6093e150
--- /dev/null
+++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/devices.py
@@ -0,0 +1,24 @@
+"""Common device configuration."""
+from typing import Tuple, Union
+import torch
+
+# The torch.device to use for computation
+DEVICE = None
+
+def forward_pre_hook_to_device(_module, inputs) \
+        -> Union[Tuple[torch.Tensor], Tuple[Tuple[torch.Tensor]]]:
+    """Move tensors to the compute device (e.g., GPU), if needed."""
+    assert isinstance(inputs, tuple)
+    assert len(inputs) == 1
+    if isinstance(inputs[0], torch.Tensor):
+        inputs = (inputs,)
+    tensors_dev = tuple(t.to(device=DEVICE) for t in inputs[0])
+    return tensors_dev if len(tensors_dev) == 1 else (tensors_dev,)
+
+def forward_hook_to_cpu(_module, _inputs, outputs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
+    """Move tensors to the CPU, if needed."""
+    if isinstance(outputs, torch.Tensor):
+        outputs = (outputs,)
+    assert isinstance(outputs, tuple)
+    tensors_cpu = tuple(t.cpu() for t in outputs)
+    return tensors_cpu[0] if len(tensors_cpu) == 1 else tensors_cpu
diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/devices.yaml b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/devices.yaml
new file mode 100644
index 00000000..1317c4ac
--- /dev/null
+++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/devices.yaml
@@ -0,0 +1,16 @@
+devices:
+  - name: "gpu-0"
+    type: "gpu"
+    memory: "1024"
+    freq: "2.6"
+    bandwidth: "100"
+  - name: "gpu-1"
+    type: "gpu"
+    memory: "1024"
+    freq: "2.6"
+    bandwidth: "80"
+  - name: "gpu-2"
+    type: "gpu"
+    memory: "1024"
+    freq: "2.6"
+    bandwidth: "90"
\ No newline at end of file
diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/model_cfg.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/model_cfg.py
new file mode 100644
index 00000000..df41a426
--- /dev/null
+++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/model_cfg.py
@@ -0,0 +1,93 @@
+"""Model configurations and default parameters."""
+import logging
+from typing import Any, Callable, List, Optional, Tuple
+from transformers import AutoConfig
+from models import
ModuleShard, ModuleShardConfig +from models.transformers import bert, deit, vit +import devices + +_logger = logging.getLogger(__name__) + +_MODEL_CONFIGS = {} + +def _model_cfg_add(name, layers, weights_file, shard_module): + _MODEL_CONFIGS[name] = { + 'name': name, + 'layers': layers, + 'weights_file': weights_file, + 'shard_module': shard_module, + } + +# Transformer blocks can be split 4 ways, e.g., where ViT-Base has 12 layers, we specify 12*4=48 +_model_cfg_add('google/vit-base-patch16-224', 48, './initial_model/ViT-B_16-224.npz', + vit.ViTShardForImageClassification) +_model_cfg_add('google/vit-large-patch16-224', 96, 'ViT-L_16-224.npz', + vit.ViTShardForImageClassification) +# NOTE: This ViT-Huge model doesn't include classification, so the config must be extended +_model_cfg_add('google/vit-huge-patch14-224-in21k', 128, 'ViT-H_14.npz', + vit.ViTShardForImageClassification) +# NOTE: BertModelShard alone doesn't do classification +_model_cfg_add('bert-base-uncased', 48, 'BERT-B.npz', + bert.BertModelShard) +_model_cfg_add('bert-large-uncased', 96, 'BERT-L.npz', + bert.BertModelShard) +_model_cfg_add('textattack/bert-base-uncased-CoLA', 48, 'BERT-B-CoLA.npz', + bert.BertShardForSequenceClassification) +_model_cfg_add('facebook/deit-base-distilled-patch16-224', 48, 'DeiT_B_distilled.npz', + deit.DeiTShardForImageClassification) +_model_cfg_add('facebook/deit-small-distilled-patch16-224', 48, 'DeiT_S_distilled.npz', + deit.DeiTShardForImageClassification) +_model_cfg_add('facebook/deit-tiny-distilled-patch16-224', 48, 'DeiT_T_distilled.npz', + deit.DeiTShardForImageClassification) + +def get_model_names() -> List[str]: + """Get a list of available model names.""" + return list(_MODEL_CONFIGS.keys()) + +def get_model_dict(model_name: str) -> dict: + """Get a model's key/value properties - modify at your own risk.""" + return _MODEL_CONFIGS[model_name] + +def get_model_layers(model_name: str) -> int: + """Get a model's layer count.""" + return _MODEL_CONFIGS[model_name]['layers'] + +def get_model_config(model_name: str) -> Any: + """Get a model's config.""" + # We'll need more complexity if/when we add support for models not from `transformers` + config = AutoConfig.from_pretrained(model_name) + # Config overrides + if model_name == 'google/vit-huge-patch14-224-in21k': + # ViT-Huge doesn't include classification, so we have to set this ourselves + # NOTE: not setting 'id2label' or 'label2id' + config.num_labels = 21843 + return config + +def get_model_default_weights_file(model_name: str) -> str: + """Get a model's default weights file name.""" + return _MODEL_CONFIGS[model_name]['weights_file'] + +def save_model_weights_file(model_name: str, model_file: Optional[str]=None) -> None: + """Save a model's weights file.""" + if model_file is None: + model_file = get_model_default_weights_file(model_name) + # This works b/c all shard implementations have the same save_weights interface + module = _MODEL_CONFIGS[model_name]['shard_module'] + module.save_weights(model_name, model_file) + +def module_shard_factory(model_name: str, model_file: Optional[str], layer_start: int, + layer_end: int, stage: int) -> ModuleShard: + """Get a shard instance on the globally-configured `devices.DEVICE`.""" + # This works b/c all shard implementations have the same constructor interface + if model_file is None: + model_file = get_model_default_weights_file(model_name) + config = get_model_config(model_name) + is_first = layer_start == 1 + is_last = layer_end == get_model_layers(model_name) + shard_config = 
ModuleShardConfig(layer_start=layer_start, layer_end=layer_end, + is_first=is_first, is_last=is_last) + module = _MODEL_CONFIGS[model_name]['shard_module'] + shard = module(config, shard_config, model_file) + _logger.info("======= %s Stage %d =======", module.__name__, stage) + shard.to(device=devices.DEVICE) + return shard \ No newline at end of file diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/__init__.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/__init__.py new file mode 100644 index 00000000..bbb74188 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/__init__.py @@ -0,0 +1,49 @@ +"""Models module.""" +from typing import Any, Tuple, Type, Union +from torch import nn, Tensor + +ModuleShardData: Type = Union[Tensor, Tuple[Tensor, ...]] +"""A module shard input/output type.""" + + +class ModuleShardConfig: + """Base class for shard configurations (distinct from model configurations).""" + # pylint: disable=too-few-public-methods + + def __init__(self, **kwargs: dict): + # Attributes with default values + self.layer_start: int = kwargs.pop('layer_start', 0) + self.layer_end: int = kwargs.pop('layer_end', 0) + self.is_first: bool = kwargs.pop('is_first', False) + self.is_last: bool = kwargs.pop('is_last', False) + + # Attributes without default values + for key, value in kwargs.items(): + setattr(self, key, value) + + +class ModuleShard(nn.Module): + """Abstract parent class for module shards.""" + # pylint: disable=abstract-method + + def __init__(self, config: Any, shard_config: ModuleShardConfig): + super().__init__() + self.config = config + self.shard_config = shard_config + + def has_layer(self, layer: int) -> bool: + """Check if shard has the specified layer.""" + return layer in range(self.shard_config.layer_start, self.shard_config.layer_end + 1) + + +def get_microbatch_size(shard_data: ModuleShardData, verify: bool=False): + """Get the microbatch size from shard data.""" + if isinstance(shard_data, Tensor): + shard_data = (shard_data,) + ubatch_size = 0 if len(shard_data) == 0 else len(shard_data[0]) + if verify: + # Sanity check that tensors are the same length + for tensor in shard_data: + assert isinstance(tensor, Tensor) + assert len(tensor) == ubatch_size + return ubatch_size diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/__init__.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/__init__.py new file mode 100644 index 00000000..532c96da --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/__init__.py @@ -0,0 +1,6 @@ +"""Transformers module.""" +from typing import Tuple, Type, Union +from torch import Tensor + +TransformerShardData: Type = Union[Tensor, Tuple[Tensor, Tensor]] +"""A transformer shard input/output type.""" diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/bert.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/bert.py new file mode 100644 index 00000000..e33fe989 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/bert.py @@ -0,0 +1,219 @@ +"""BERT transformers.""" +from collections.abc import Mapping +import logging +import math +from typing import Union +import numpy as np +import torch +from torch import nn +from transformers import 
BertConfig, BertForSequenceClassification, BertModel +from transformers.models.bert.modeling_bert import ( + BertEmbeddings, BertIntermediate, BertOutput, BertPooler, BertSelfAttention, BertSelfOutput +) +from .. import ModuleShard, ModuleShardConfig +from . import TransformerShardData + + +logger = logging.getLogger(__name__) + + +class BertLayerShard(ModuleShard): + """Module shard based on `BertLayer`.""" + + def __init__(self, config: BertConfig, shard_config: ModuleShardConfig): + super().__init__(config, shard_config) + self.self_attention = None + self.self_output = None + self.intermediate = None + self.output = None + self._build_shard() + + def _build_shard(self): + if self.has_layer(0): + self.self_attention = BertSelfAttention(self.config) + if self.has_layer(1): + self.self_output = BertSelfOutput(self.config) + if self.has_layer(2): + self.intermediate = BertIntermediate(self.config) + if self.has_layer(3): + self.output = BertOutput(self.config) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute layer shard.""" + if self.has_layer(0): + data = (self.self_attention(data)[0], data) + if self.has_layer(1): + data = self.self_output(data[0], data[1]) + if self.has_layer(2): + data = (self.intermediate(data), data) + if self.has_layer(3): + data = self.output(data[0], data[1]) + return data + + +class BertModelShard(ModuleShard): + """Module shard based on `BertModel`.""" + + def __init__(self, config: BertConfig, shard_config: ModuleShardConfig, + model_weights: Union[str, Mapping]): + super().__init__(config, shard_config) + self.embeddings = None + # BertModel uses an encoder here, but we'll just add the layers here instead. + # Since we just do inference, a BertEncoderShard class wouldn't provide real benefit. 
+ self.layers = nn.ModuleList() + self.pooler = None + + logger.debug(">>>> Model name: %s", self.config.name_or_path) + if isinstance(model_weights, str): + logger.debug(">>>> Load weight file: %s", model_weights) + with np.load(model_weights) as weights: + self._build_shard(weights) + else: + self._build_shard(model_weights) + + def _build_shard(self, weights): + if self.shard_config.is_first: + logger.debug(">>>> Load embeddings layer for the first shard") + self.embeddings = BertEmbeddings(self.config) + self.embeddings.eval() + self._load_weights_first(weights) + + layer_curr = self.shard_config.layer_start + while layer_curr <= self.shard_config.layer_end: + layer_id = math.ceil(layer_curr / 4) - 1 + sublayer_start = (layer_curr - 1) % 4 + if layer_id == math.ceil(self.shard_config.layer_end / 4) - 1: + sublayer_end = (self.shard_config.layer_end - 1) % 4 + else: + sublayer_end = 3 + logger.debug(">>>> Load layer %d, sublayers %d-%d", + layer_id, sublayer_start, sublayer_end) + layer_config = ModuleShardConfig(layer_start=sublayer_start, layer_end=sublayer_end) + layer = BertLayerShard(self.config, layer_config) + self._load_weights_layer(weights, layer_id, layer) + self.layers.append(layer) + layer_curr += sublayer_end - sublayer_start + 1 + + if self.shard_config.is_last: + logger.debug(">>>> Load pooler for the last shard") + self.pooler = BertPooler(self.config) + self.pooler.eval() + self._load_weights_last(weights) + + @torch.no_grad() + def _load_weights_first(self, weights): + self.embeddings.position_ids.copy_(torch.from_numpy((weights["embeddings.position_ids"]))) + self.embeddings.word_embeddings.weight.copy_(torch.from_numpy(weights['embeddings.word_embeddings.weight'])) + self.embeddings.position_embeddings.weight.copy_(torch.from_numpy(weights['embeddings.position_embeddings.weight'])) + self.embeddings.token_type_embeddings.weight.copy_(torch.from_numpy(weights['embeddings.token_type_embeddings.weight'])) + self.embeddings.LayerNorm.weight.copy_(torch.from_numpy(weights['embeddings.LayerNorm.weight'])) + self.embeddings.LayerNorm.bias.copy_(torch.from_numpy(weights['embeddings.LayerNorm.bias'])) + + @torch.no_grad() + def _load_weights_last(self, weights): + self.pooler.dense.weight.copy_(torch.from_numpy(weights["pooler.dense.weight"])) + self.pooler.dense.bias.copy_(torch.from_numpy(weights['pooler.dense.bias'])) + + @torch.no_grad() + def _load_weights_layer(self, weights, layer_id, layer): + root = f"encoder.layer.{layer_id}." 
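+        # Descriptive note: each transformer block is modeled as 4 sublayers
+        # (self-attention, attention output, intermediate, output); the keys
+        # below follow the BertModel state_dict naming produced by
+        # save_weights(), and each shard copies only the tensors of the
+        # sublayers it actually owns.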
+ if layer.has_layer(0): + layer.self_attention.query.weight.copy_(torch.from_numpy(weights[root + "attention.self.query.weight"])) + layer.self_attention.key.weight.copy_(torch.from_numpy(weights[root + "attention.self.key.weight"])) + layer.self_attention.value.weight.copy_(torch.from_numpy(weights[root + "attention.self.value.weight"])) + layer.self_attention.query.bias.copy_(torch.from_numpy(weights[root + "attention.self.query.bias"])) + layer.self_attention.key.bias.copy_(torch.from_numpy(weights[root + "attention.self.key.bias"])) + layer.self_attention.value.bias.copy_(torch.from_numpy(weights[root + "attention.self.value.bias"])) + if layer.has_layer(1): + layer.self_output.dense.weight.copy_(torch.from_numpy(weights[root + "attention.output.dense.weight"])) + layer.self_output.LayerNorm.weight.copy_(torch.from_numpy(weights[root + "attention.output.LayerNorm.weight"])) + layer.self_output.dense.bias.copy_(torch.from_numpy(weights[root + "attention.output.dense.bias"])) + layer.self_output.LayerNorm.bias.copy_(torch.from_numpy(weights[root + "attention.output.LayerNorm.bias"])) + if layer.has_layer(2): + layer.intermediate.dense.weight.copy_(torch.from_numpy(weights[root + "intermediate.dense.weight"])) + layer.intermediate.dense.bias.copy_(torch.from_numpy(weights[root + "intermediate.dense.bias"])) + if layer.has_layer(3): + layer.output.dense.weight.copy_(torch.from_numpy(weights[root + "output.dense.weight"])) + layer.output.dense.bias.copy_(torch.from_numpy(weights[root + "output.dense.bias"])) + layer.output.LayerNorm.weight.copy_(torch.from_numpy(weights[root + "output.LayerNorm.weight"])) + layer.output.LayerNorm.bias.copy_(torch.from_numpy(weights[root + "output.LayerNorm.bias"])) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute shard layers.""" + if self.shard_config.is_first: + data = self.embeddings(data) + for layer in self.layers: + data = layer(data) + if self.shard_config.is_last: + data = self.pooler(data) + return data + + @staticmethod + def save_weights(model_name: str, model_file: str) -> None: + """Save the model weights file.""" + model = BertModel.from_pretrained(model_name) + state_dict = model.state_dict() + weights = {} + for key, val in state_dict.items(): + weights[key] = val + np.savez(model_file, **weights) + + +class BertShardForSequenceClassification(ModuleShard): + """Module shard based on `BertForSequenceClassification`.""" + + def __init__(self, config: BertConfig, shard_config: ModuleShardConfig, + model_weights: Union[str, Mapping]): + super().__init__(config, shard_config) + self.bert = None + self.classifier = None + + logger.debug(">>>> Model name: %s", self.config.name_or_path) + if isinstance(model_weights, str): + logger.debug(">>>> Load weight file: %s", model_weights) + with np.load(model_weights) as weights: + self._build_shard(weights) + else: + self._build_shard(model_weights) + + def _build_shard(self, weights): + ## all shards use the inner BERT model + self.bert = BertModelShard(self.config, self.shard_config, + self._extract_weights_bert(weights)) + + if self.shard_config.is_last: + logger.debug(">>>> Load classifier for the last shard") + self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels) + self._load_weights_last(weights) + + def _extract_weights_bert(self, weights): + bert_weights = {} + for key, val in weights.items(): + if key.startswith('bert.'): + bert_weights[key[len('bert.'):]] = val + return bert_weights + + @torch.no_grad() + def 
_load_weights_last(self, weights): + self.classifier.weight.copy_(torch.from_numpy(weights['classifier.weight'])) + self.classifier.bias.copy_(torch.from_numpy(weights['classifier.bias'])) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute shard layers.""" + data = self.bert(data) + if self.shard_config.is_last: + data = self.classifier(data) + return data + + @staticmethod + def save_weights(model_name: str, model_file: str) -> None: + """Save the model weights file.""" + model = BertForSequenceClassification.from_pretrained(model_name) + state_dict = model.state_dict() + weights = {} + for key, val in state_dict.items(): + weights[key] = val + np.savez(model_file, **weights) diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/deit.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/deit.py new file mode 100644 index 00000000..dc6b6144 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/deit.py @@ -0,0 +1,233 @@ +"""DeiT Transformers.""" +from collections.abc import Mapping +import logging +import math +from typing import Optional, Union +import numpy as np +import torch +from torch import nn +from transformers import DeiTConfig +from transformers.models.deit.modeling_deit import DeiTEmbeddings +from transformers.models.vit.modeling_vit import ( + ViTIntermediate, ViTOutput, ViTSelfAttention, ViTSelfOutput +) +from .. import ModuleShard, ModuleShardConfig +from . import TransformerShardData + + +logger = logging.getLogger(__name__) + +_HUB_MODEL_NAMES = { + 'facebook/deit-base-distilled-patch16-224': 'deit_base_distilled_patch16_224', + 'facebook/deit-small-distilled-patch16-224': 'deit_small_distilled_patch16_224', + 'facebook/deit-tiny-distilled-patch16-224': 'deit_tiny_distilled_patch16_224', +} + + +class DeiTLayerShard(ModuleShard): + """Module shard based on `DeiTLayer` (copied from `.vit.ViTLayerShard`).""" + + def __init__(self, config: DeiTConfig, shard_config: ModuleShardConfig): + super().__init__(config, shard_config) + self.layernorm_before = None + self.self_attention = None + self.self_output = None + self.layernorm_after = None + self.intermediate = None + self.output = None + self._build_shard() + + def _build_shard(self): + if self.has_layer(0): + self.layernorm_before = nn.LayerNorm(self.config.hidden_size, + eps=self.config.layer_norm_eps) + self.self_attention = ViTSelfAttention(self.config) + if self.has_layer(1): + self.self_output = ViTSelfOutput(self.config) + if self.has_layer(2): + self.layernorm_after = nn.LayerNorm(self.config.hidden_size, + eps=self.config.layer_norm_eps) + self.intermediate = ViTIntermediate(self.config) + if self.has_layer(3): + self.output = ViTOutput(self.config) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute layer shard.""" + if self.has_layer(0): + data_norm = self.layernorm_before(data) + data = (self.self_attention(data_norm)[0], data) + if self.has_layer(1): + skip = data[1] + data = self.self_output(data[0], skip) + data += skip + if self.has_layer(2): + data_norm = self.layernorm_after(data) + data = (self.intermediate(data_norm), data) + if self.has_layer(3): + data = self.output(data[0], data[1]) + return data + + +class DeiTModelShard(ModuleShard): + """Module shard based on `DeiTModel`.""" + + def __init__(self, config: DeiTConfig, shard_config: ModuleShardConfig, + 
model_weights: Union[str, Mapping]): + super().__init__(config, shard_config) + self.embeddings = None + # DeiTModel uses an encoder here, but we'll just add the layers here instead. + # Since we just do inference, a DeiTEncoderShard class wouldn't provide real benefit. + self.layers = nn.ModuleList() + self.layernorm = None + + logger.debug(">>>> Model name: %s", self.config.name_or_path) + if isinstance(model_weights, str): + logger.debug(">>>> Load weight file: %s", model_weights) + with np.load(model_weights) as weights: + self._build_shard(weights) + else: + self._build_shard(model_weights) + + def _build_shard(self, weights): + if self.shard_config.is_first: + logger.debug(">>>> Load embeddings layer for the first shard") + self.embeddings = DeiTEmbeddings(self.config) + self._load_weights_first(weights) + + layer_curr = self.shard_config.layer_start + while layer_curr <= self.shard_config.layer_end: + layer_id = math.ceil(layer_curr / 4) - 1 + sublayer_start = (layer_curr - 1) % 4 + if layer_id == math.ceil(self.shard_config.layer_end / 4) - 1: + sublayer_end = (self.shard_config.layer_end - 1) % 4 + else: + sublayer_end = 3 + logger.debug(">>>> Load layer %d, sublayers %d-%d", + layer_id, sublayer_start, sublayer_end) + layer_config = ModuleShardConfig(layer_start=sublayer_start, layer_end=sublayer_end) + layer = DeiTLayerShard(self.config, layer_config) + self._load_weights_layer(weights, layer_id, layer) + self.layers.append(layer) + layer_curr += sublayer_end - sublayer_start + 1 + + if self.shard_config.is_last: + logger.debug(">>>> Load layernorm for the last shard") + self.layernorm = nn.LayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps) + self._load_weights_last(weights) + + @torch.no_grad() + def _load_weights_first(self, weights): + self.embeddings.cls_token.copy_(torch.from_numpy(weights["cls_token"])) + self.embeddings.position_embeddings.copy_(torch.from_numpy((weights["pos_embed"]))) + self.embeddings.patch_embeddings.projection.weight.copy_(torch.from_numpy(weights["patch_embed.proj.weight"])) + self.embeddings.patch_embeddings.projection.bias.copy_(torch.from_numpy(weights["patch_embed.proj.bias"])) + + @torch.no_grad() + def _load_weights_last(self, weights): + self.layernorm.weight.copy_(torch.from_numpy(weights["norm.weight"])) + self.layernorm.bias.copy_(torch.from_numpy(weights["norm.bias"])) + + @torch.no_grad() + def _load_weights_layer(self, weights, layer_id, layer): + root = f"blocks.{layer_id}." 
+ embed_dim = self.config.hidden_size + if layer.has_layer(0): + layer.layernorm_before.weight.copy_(torch.from_numpy(weights[root + "norm1.weight"])) + layer.layernorm_before.bias.copy_(torch.from_numpy(weights[root + "norm1.bias"])) + qkv_weight = weights[root + "attn.qkv.weight"] + layer.self_attention.query.weight.copy_(torch.from_numpy(qkv_weight[0:embed_dim,:])) + layer.self_attention.key.weight.copy_(torch.from_numpy(qkv_weight[embed_dim:embed_dim*2,:])) + layer.self_attention.value.weight.copy_(torch.from_numpy(qkv_weight[embed_dim*2:embed_dim*3,:])) + qkv_bias = weights[root + "attn.qkv.bias"] + layer.self_attention.query.bias.copy_(torch.from_numpy(qkv_bias[0:embed_dim,])) + layer.self_attention.key.bias.copy_(torch.from_numpy(qkv_bias[embed_dim:embed_dim*2])) + layer.self_attention.value.bias.copy_(torch.from_numpy(qkv_bias[embed_dim*2:embed_dim*3])) + if layer.has_layer(1): + layer.self_output.dense.weight.copy_(torch.from_numpy(weights[root + "attn.proj.weight"])) + layer.self_output.dense.bias.copy_(torch.from_numpy(weights[root + "attn.proj.bias"])) + if layer.has_layer(2): + layer.layernorm_after.weight.copy_(torch.from_numpy(weights[root + "norm2.weight"])) + layer.layernorm_after.bias.copy_(torch.from_numpy(weights[root + "norm2.bias"])) + layer.intermediate.dense.weight.copy_(torch.from_numpy(weights[root + "mlp.fc1.weight"])) + layer.intermediate.dense.bias.copy_(torch.from_numpy(weights[root + "mlp.fc1.bias"])) + if layer.has_layer(3): + layer.output.dense.weight.copy_(torch.from_numpy(weights[root + "mlp.fc2.weight"])) + layer.output.dense.bias.copy_(torch.from_numpy(weights[root + "mlp.fc2.bias"])) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute shard layers.""" + if self.shard_config.is_first: + data = self.embeddings(data) + for layer in self.layers: + data = layer(data) + if self.shard_config.is_last: + data = self.layernorm(data) + return data + + # NOTE: repo has a dependency on the timm package, which isn't an automatic torch dependency + @staticmethod + def save_weights(model_name: str, model_file: str, hub_repo: str='facebookresearch/deit:main', + hub_model_name: Optional[str]=None) -> None: + """Save the model weights file.""" + if hub_model_name is None: + if model_name in _HUB_MODEL_NAMES: + hub_model_name = _HUB_MODEL_NAMES[model_name] + logger.debug("Mapping model name to torch hub equivalent: %s: %s", model_name, + hub_model_name) + else: + hub_model_name = model_name + model = torch.hub.load(hub_repo, hub_model_name, pretrained=True) + state_dict = model.state_dict() + weights = {} + for key, val in state_dict.items(): + weights[key] = val + np.savez(model_file, **weights) + + +class DeiTShardForImageClassification(ModuleShard): + """Module shard based on `DeiTForImageClassification`.""" + + def __init__(self, config: DeiTConfig, shard_config: ModuleShardConfig, + model_weights: Union[str, Mapping]): + super().__init__(config, shard_config) + self.deit = None + self.classifier = None + + logger.debug(">>>> Model name: %s", self.config.name_or_path) + if isinstance(model_weights, str): + logger.debug(">>>> Load weight file: %s", model_weights) + with np.load(model_weights) as weights: + self._build_shard(weights) + else: + self._build_shard(model_weights) + + def _build_shard(self, weights): + ## all shards use the inner DeiT model + self.deit = DeiTModelShard(self.config, self.shard_config, weights) + + if self.shard_config.is_last: + logger.debug(">>>> Load classifier for the last shard") + 
self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels) if self.config.num_labels > 0 else nn.Identity() + self._load_weights_last(weights) + + @torch.no_grad() + def _load_weights_last(self, weights): + self.classifier.weight.copy_(torch.from_numpy(weights["head.weight"])) + self.classifier.bias.copy_(torch.from_numpy(weights["head.bias"])) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute shard layers.""" + data = self.deit(data) + if self.shard_config.is_last: + data = self.classifier(data[:, 0, :]) + return data + + @staticmethod + def save_weights(model_name: str, model_file: str, hub_repo: str='facebookresearch/deit:main', + hub_model_name: Optional[str]=None) -> None: + """Save the model weights file.""" + DeiTModelShard.save_weights(model_name, model_file, hub_repo=hub_repo, + hub_model_name=hub_model_name) diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/vit.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/vit.py new file mode 100644 index 00000000..7760f5e6 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/models/transformers/vit.py @@ -0,0 +1,232 @@ +"""ViT Transformers.""" +from collections.abc import Mapping +import logging +import math +import os +from typing import Optional, Union +import numpy as np +import requests +import torch +from torch import nn +from transformers import ViTConfig +from transformers.models.vit.modeling_vit import ( + ViTEmbeddings, ViTIntermediate, ViTOutput, ViTSelfAttention, ViTSelfOutput +) +from .. import ModuleShard, ModuleShardConfig +from . import TransformerShardData + + +logger = logging.getLogger(__name__) + +_WEIGHTS_URLS = { + 'google/vit-base-patch16-224': 'https://storage.googleapis.com/vit_models/imagenet21k%2Bimagenet2012/ViT-B_16-224.npz', + 'google/vit-large-patch16-224': 'https://storage.googleapis.com/vit_models/imagenet21k%2Bimagenet2012/ViT-L_16-224.npz', + 'google/vit-huge-patch14-224-in21k': 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', +} + + +class ViTLayerShard(ModuleShard): + """Module shard based on `ViTLayer`.""" + + def __init__(self, config: ViTConfig, shard_config: ModuleShardConfig): + super().__init__(config, shard_config) + self.layernorm_before = None + self.self_attention = None + self.self_output = None + self.layernorm_after = None + self.intermediate = None + self.output = None + self._build_shard() + + def _build_shard(self): + if self.has_layer(0): + self.layernorm_before = nn.LayerNorm(self.config.hidden_size, + eps=self.config.layer_norm_eps) + self.self_attention = ViTSelfAttention(self.config) + if self.has_layer(1): + self.self_output = ViTSelfOutput(self.config) + if self.has_layer(2): + self.layernorm_after = nn.LayerNorm(self.config.hidden_size, + eps=self.config.layer_norm_eps) + self.intermediate = ViTIntermediate(self.config) + if self.has_layer(3): + self.output = ViTOutput(self.config) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute layer shard.""" + if self.has_layer(0): + data_norm = self.layernorm_before(data) + data = (self.self_attention(data_norm)[0], data) + if self.has_layer(1): + skip = data[1] + data = self.self_output(data[0], skip) + data += skip + if self.has_layer(2): + data_norm = self.layernorm_after(data) + data = (self.intermediate(data_norm), data) + if self.has_layer(3): + data = 
self.output(data[0], data[1]) + return data + + +class ViTModelShard(ModuleShard): + """Module shard based on `ViTModel` (no pooling layer).""" + + def __init__(self, config: ViTConfig, shard_config: ModuleShardConfig, + model_weights: Union[str, Mapping]): + super().__init__(config, shard_config) + self.embeddings = None + # ViTModel uses an encoder here, but we'll just add the layers here instead. + # Since we just do inference, a ViTEncoderShard class wouldn't provide real benefit. + self.layers = nn.ModuleList() + self.layernorm = None + + logger.debug(">>>> Model name: %s", self.config.name_or_path) + if isinstance(model_weights, str): + logger.debug(">>>> Load weight file: %s", model_weights) + with np.load(model_weights) as weights: + self._build_shard(weights) + else: + self._build_shard(model_weights) + + def _build_shard(self, weights): + if self.shard_config.is_first: + logger.debug(">>>> Load embeddings layer for the first shard") + self.embeddings = ViTEmbeddings(self.config) + self._load_weights_first(weights) + + layer_curr = self.shard_config.layer_start + while layer_curr <= self.shard_config.layer_end: + layer_id = math.ceil(layer_curr / 4) - 1 + sublayer_start = (layer_curr - 1) % 4 + if layer_id == math.ceil(self.shard_config.layer_end / 4) - 1: + sublayer_end = (self.shard_config.layer_end - 1) % 4 + else: + sublayer_end = 3 + logger.debug(">>>> Load layer %d, sublayers %d-%d", + layer_id, sublayer_start, sublayer_end) + layer_config = ModuleShardConfig(layer_start=sublayer_start, layer_end=sublayer_end) + layer = ViTLayerShard(self.config, layer_config) + self._load_weights_layer(weights, layer_id, layer) + self.layers.append(layer) + layer_curr += sublayer_end - sublayer_start + 1 + + if self.shard_config.is_last: + logger.debug(">>>> Load layernorm for the last shard") + self.layernorm = nn.LayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps) + self._load_weights_last(weights) + + @torch.no_grad() + def _load_weights_first(self, weights): + self.embeddings.cls_token.copy_(torch.from_numpy(weights["cls"])) + self.embeddings.position_embeddings.copy_(torch.from_numpy((weights["Transformer/posembed_input/pos_embedding"]))) + conv_weight = weights["embedding/kernel"] + # O, I, J, K = conv_weight.shape + # conv_weight = conv_weight.reshape(K,J,O,I) + conv_weight = conv_weight.transpose([3, 2, 0, 1]) + self.embeddings.patch_embeddings.projection.weight.copy_(torch.from_numpy(conv_weight)) + self.embeddings.patch_embeddings.projection.bias.copy_(torch.from_numpy(weights["embedding/bias"])) + + @torch.no_grad() + def _load_weights_last(self, weights): + self.layernorm.weight.copy_(torch.from_numpy(weights["Transformer/encoder_norm/scale"])) + self.layernorm.bias.copy_(torch.from_numpy(weights["Transformer/encoder_norm/bias"])) + + @torch.no_grad() + def _load_weights_layer(self, weights, layer_id, layer): + root = f"Transformer/encoderblock_{layer_id}/" + hidden_size = self.config.hidden_size + if layer.has_layer(0): + layer.layernorm_before.weight.copy_(torch.from_numpy(weights[root + "LayerNorm_0/scale"])) + layer.layernorm_before.bias.copy_(torch.from_numpy(weights[root + "LayerNorm_0/bias"])) + layer.self_attention.query.weight.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/query/kernel"]).view(hidden_size, hidden_size).t()) + layer.self_attention.key.weight.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/key/kernel"]).view(hidden_size, hidden_size).t()) + 
layer.self_attention.value.weight.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/value/kernel"]).view(hidden_size, hidden_size).t()) + layer.self_attention.query.bias.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/query/bias"]).view(-1)) + layer.self_attention.key.bias.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/key/bias"]).view(-1)) + layer.self_attention.value.bias.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/value/bias"]).view(-1)) + if layer.has_layer(1): + layer.self_output.dense.weight.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/out/kernel"]).view(hidden_size, hidden_size).t()) + layer.self_output.dense.bias.copy_(torch.from_numpy(weights[root + "MultiHeadDotProductAttention_1/out/bias"]).view(-1)) + if layer.has_layer(2): + layer.layernorm_after.weight.copy_(torch.from_numpy(weights[root + "LayerNorm_2/scale"])) + layer.layernorm_after.bias.copy_(torch.from_numpy(weights[root + "LayerNorm_2/bias"])) + layer.intermediate.dense.weight.copy_(torch.from_numpy(weights[root + "MlpBlock_3/Dense_0/kernel"]).t()) + layer.intermediate.dense.bias.copy_(torch.from_numpy(weights[root + "MlpBlock_3/Dense_0/bias"]).t()) + if layer.has_layer(3): + layer.output.dense.weight.copy_(torch.from_numpy(weights[root + "MlpBlock_3/Dense_1/kernel"]).t()) + layer.output.dense.bias.copy_(torch.from_numpy(weights[root + "MlpBlock_3/Dense_1/bias"]).t()) + + @torch.no_grad() + def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute shard layers.""" + if self.shard_config.is_first: + data = self.embeddings(data) + for layer in self.layers: + data = layer(data) + if self.shard_config.is_last: + data = self.layernorm(data) + return data + + @staticmethod + def save_weights(model_name: str, model_file: str, url: Optional[str]=None, + timeout_sec: Optional[float]=None) -> None: + """Save the model weights file.""" + if url is None: + url = _WEIGHTS_URLS[model_name] + logger.info('Downloading model: %s: %s', model_name, url) + req = requests.get(url, stream=True, timeout=timeout_sec) + req.raise_for_status() + with open(model_file, 'wb') as file: + for chunk in req.iter_content(chunk_size=8192): + if chunk: + file.write(chunk) + file.flush() + os.fsync(file.fileno()) + + +class ViTShardForImageClassification(ModuleShard): + """Module shard based on `ViTForImageClassification`.""" + + def __init__(self, config: ViTConfig, shard_config: ModuleShardConfig, + model_weights: Union[str, Mapping]): + super().__init__(config, shard_config) + self.vit = None + self.classifier = None + + logger.debug(">>>> Model name: %s", self.config.name_or_path) + if isinstance(model_weights, str): + logger.debug(">>>> Load weight file: %s", model_weights) + with np.load(model_weights) as weights: + self._build_shard(weights) + else: + self._build_shard(model_weights) + + def _build_shard(self, weights): + ## all shards use the inner ViT model + self.vit = ViTModelShard(self.config, self.shard_config, weights) + + if self.shard_config.is_last: + logger.debug(">>>> Load classifier for the last shard") + self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels) if self.config.num_labels > 0 else nn.Identity() + self._load_weights_last(weights) + + @torch.no_grad() + def _load_weights_last(self, weights): + self.classifier.weight.copy_(torch.from_numpy(np.transpose(weights["head/kernel"]))) + self.classifier.bias.copy_(torch.from_numpy(weights["head/bias"])) + + @torch.no_grad() 
+ def forward(self, data: TransformerShardData) -> TransformerShardData: + """Compute shard layers.""" + data = self.vit(data) + if self.shard_config.is_last: + data = self.classifier(data[:, 0, :]) + return data + + @staticmethod + def save_weights(model_name: str, model_file: str, url: Optional[str]=None, + timeout_sec: Optional[float]=None) -> None: + """Save the model weights file.""" + ViTModelShard.save_weights(model_name, model_file, url=url, timeout_sec=timeout_sec) diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/profiler.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/profiler.py new file mode 100644 index 00000000..61fdf33f --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/profiler.py @@ -0,0 +1,263 @@ +"""Module shard profiler.""" +import argparse +import gc +import os +import time +import numpy as np +import psutil +import torch +import torch.multiprocessing as mp +import yaml +from transformers import BertTokenizer +import devices +import model_cfg + + +def get_shapes(tensors): + """Get the tensor shapes, excluding the outer dimension (microbatch size).""" + if isinstance(tensors, tuple): + shape = [] + for tensor in tensors: + shape.append(tuple(tensor.shape[1:])) + else: + shape = [tuple(tensors.shape[1:])] + return shape + + +def create_module_shard(module_cfg, stage_cfg): + """Create a module shard.""" + model_name = module_cfg['name'] + model_file = module_cfg['file'] + stage = stage_cfg['stage'] + layer_start = stage_cfg['layer_start'] + layer_end = stage_cfg['layer_end'] + return model_cfg.module_shard_factory(model_name, model_file, layer_start, layer_end, stage) + + +def profile_module_shard(module_cfg, stage_cfg, stage_inputs, warmup, iterations): + """Profile a module shard.""" + process = psutil.Process(os.getpid()) + + # Measure memory (create shard) on the CPU. + # This avoids capturing additional memory overhead when using other devices, like GPUs. + # It's OK if the model fits in DRAM but not on the "device" - we'll just fail later. + # We consider memory requirements to be a property of the model, not the device/platform. + assert devices.DEVICE is None + # Capturing memory behavior in Python is extremely difficult and results are subject to many + # factors beyond our ability to control or reliably detect/infer. + # This works best when run once per process execution with only minimal work done beforehand. 
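+ # Illustrative sketch only (not part of the profiler API): the RSS-delta measurement performed below reduces to + # proc = psutil.Process(os.getpid()) + # gc.collect() # settle the allocator before sampling + # before = proc.memory_info().rss # bytes held before creating the shard + # module = create_module_shard(module_cfg, stage_cfg) + # gc.collect() + # after = proc.memory_info().rss # bytes held after + # mem_mb = (after - before) / 1000000 # same MB convention as the 'memory' result below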
+ gc.collect() + stage_start_mem = process.memory_info().rss / 1000000 + module = create_module_shard(module_cfg, stage_cfg) + gc.collect() + stage_end_mem = process.memory_info().rss / 1000000 + + # Now move the module to the specified device + device = module_cfg['device'] + if device is not None: + devices.DEVICE = torch.device(device) + if devices.DEVICE is not None and devices.DEVICE.type == 'cuda': + torch.cuda.init() + module.to(device=device) + module.register_forward_pre_hook(devices.forward_pre_hook_to_device) + module.register_forward_hook(devices.forward_hook_to_cpu) + + # Measure data input + shape_in = get_shapes(stage_inputs) + + # Optional warmup + if warmup: + module(stage_inputs) + + # Measure timing (execute shard) - includes data movement overhead (performed in hooks) + stage_times = [] + for _ in range(iterations): + stage_start_time = time.time() + stage_outputs = module(stage_inputs) + stage_end_time = time.time() + stage_times.append(stage_end_time - stage_start_time) + stage_time_avg = sum(stage_times) / len(stage_times) + + # Measure data output + shape_out = get_shapes(stage_outputs) + + results = { + 'shape_in': shape_in, + 'shape_out': shape_out, + 'memory': stage_end_mem - stage_start_mem, + 'time': stage_time_avg, + } + return (stage_outputs, results) + + +def profile_module_shard_mp_queue(queue, evt_done, args): + """Multiprocessing target function for `profile_module_shard` which adds output to queue.""" + queue.put(profile_module_shard(*args)) + evt_done.wait() + + +def profile_module_shard_mp(args): + """Run `profile_module_shard` with multiprocessing (for more accurate memory results).""" + # First, a non-optional module warmup in case PyTorch needs to fetch/cache models on first use + print("Performing module warmup...") + proc = mp.Process(target=create_module_shard, args=(args[0], args[1])) + proc.start() + proc.join() + + # Now, the actual profiling + print("Performing module profiling...") + queue = mp.Queue() + # The child process sometimes exits before we read the queue items, even though it should have + # flushed all data to the underlying pipe before that, so use an event to keep it alive. 
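+ # Illustrative sketch of the handshake used below (same names as in this function): + # child: queue.put(result); evt_done.wait() # block until the parent confirms receipt + # parent: result = queue.get(); evt_done.set(); proc.join() + # Without the event, the child process could exit before the parent finishes draining the queue.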
+ evt_done = mp.Event() + proc = mp.Process(target=profile_module_shard_mp_queue, args=(queue, evt_done, args)) + proc.start() + tensors, prof_dict = queue.get() + evt_done.set() + proc.join() + return (tensors, prof_dict) + + +def profile_layers(module_cfg, tensors, layer_start, layer_end, warmup, iterations): + """Profile a shard with layer_start through layer_end.""" + shard = { + 'stage': 0, + 'layer_start': layer_start, + 'layer_end': layer_end, + } + _, prof_dict = profile_module_shard_mp(args=(module_cfg, shard, tensors, warmup, iterations)) + prof_dict['layer'] = 0 + return [prof_dict] + + +def profile_layers_individually(module_cfg, tensors, layer_start, layer_end, warmup, iterations): + """Profile module shards for each layer individually.""" + results = [] + for layer in range(layer_start, layer_end + 1): + shard = { + 'stage': layer, + 'layer_start': layer, + 'layer_end': layer, + } + tensors, prof_dict = profile_module_shard_mp(args=(module_cfg, shard, tensors, warmup, iterations)) + prof_dict['layer'] = layer + results.append(prof_dict) + return results + + +def profile_layers_cumulatively(module_cfg, tensors, layer_start, layer_end, warmup, iterations): + """Profile module shards with increasing numbers of layers.""" + results = [] + for layer in range(1, layer_end + 1): + shard = { + 'stage': layer, + 'layer_start': layer_start, + 'layer_end': layer, + } + _, prof_dict = profile_module_shard_mp(args=(module_cfg, shard, tensors, warmup, iterations)) + prof_dict['layer'] = layer + results.append(prof_dict) + return results + + +def validate_profile_results(profile_results, args, inputs, model_layers, layer_end): + """Validate that we can work with existing profiling results""" + assert profile_results['model_name'] == args.model_name, "model name mismatch with existing results" + dtype = inputs[0].dtype if isinstance(inputs, tuple) else inputs.dtype + assert profile_results['dtype'] == str(dtype), "dtype mismatch with existing results" + assert profile_results['batch_size'] == args.batch_size, "batch size mismatch with existing results" + assert profile_results['layers'] == model_layers, "layer count mismatch with existing results" + # check for overlap with existing results data + for _layer in range(args.layer_start, layer_end + 1): + for _pd in profile_results['profile_data']: + assert _layer != _pd['layer'], "layer to be profiled already in existing results" + + +def main(): + """Main function.""" + parser = argparse.ArgumentParser(description="Module Shard Profiler", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("-o", "--results-yml", default="./examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/profiler_results.yml", type=str, + help="output YAML file") + parser.add_argument("-d", "--device", type=str, default=None, + help="compute device type to use, with optional ordinal, " + "e.g.: 'cpu', 'cuda', 'cuda:1'") + parser.add_argument("-m", "--model-name", type=str, default="google/vit-base-patch16-224", + choices=model_cfg.get_model_names(), + help="the neural network model for loading") + parser.add_argument("-M", "--model-file", type=str, + help="the model file, if not in working directory") + parser.add_argument("-l", "--layer-start", default=1, type=int, help="start layer") + parser.add_argument("-L", "--layer-end", type=int, help="end layer; default: last layer in the model") + parser.add_argument("-s", "--shape-input", type=str, action='append', + help="comma-delimited shape input, e.g., '3,224,224' (required for 
start_layer != 1)") + parser.add_argument("-b", "--batch-size", default=8, type=int, help="batch size") + parser.add_argument("-w", "--warmup", action="store_true", default=True, + help="perform a warmup iteration " + "(strongly recommended, esp. with device='cuda' or iterations>1)") + parser.add_argument("--no-warmup", action="store_false", dest="warmup", + help="don't perform a warmup iteration") + parser.add_argument("-i", "--iterations", default=1, type=int, + help="iterations to average runtime for") + args = parser.parse_args() + + if args.shape_input is not None: + shapes = [] + for shp in args.shape_input: + shapes.append(tuple(int(d) for d in shp.split(','))) + if len(shapes) > 1: + # tuple of tensors + inputs = tuple(torch.randn(args.batch_size, *shp) for shp in shapes) + else: + # single tensor + inputs = torch.randn(args.batch_size, *shapes[0]) + elif args.model_name in ['bert-base-uncased', 'bert-large-uncased']: + with np.load("bert_input.npz") as bert_inputs: + inputs_sentence = list(bert_inputs['input'][0: args.batch_size]) + tokenizer = BertTokenizer.from_pretrained(args.model_name) + inputs = tokenizer(inputs_sentence, padding=True, truncation=True, return_tensors="pt")['input_ids'] + else: + inputs = torch.randn(args.batch_size, 3, 224, 224) + + model_layers = model_cfg.get_model_layers(args.model_name) + layer_end = args.layer_end + if layer_end is None: + layer_end = model_layers + + # get or create profile_results + if os.path.exists(args.results_yml): + print("Using existing results file") + with open(args.results_yml, 'r', encoding='utf-8') as yfile: + profile_results = yaml.safe_load(yfile) + validate_profile_results(profile_results, args, inputs, model_layers, layer_end) + else: + profile_results = { + 'model_name': args.model_name, + 'dtype': str(inputs.dtype), + 'batch_size': args.batch_size, + 'layers': model_layers, + 'profile_data': [], + } + + module_cfg = { + 'device': args.device, + 'name': args.model_name, + 'file': args.model_file, + } + if args.model_file is None: + module_cfg['file'] = model_cfg.get_model_default_weights_file(args.model_name) + # a single shard measurement can be a useful reference + # results = profile_layers(module_cfg, inputs, args.layer_start, layer_end, args.warmup, args.iterations) + # cumulative won't work if the whole model doesn't fit on the device + # results = profile_layers_cumulatively(module_cfg, inputs, args.layer_start, layer_end, args.warmup, args.iterations) + results = profile_layers_individually(module_cfg, inputs, args.layer_start, layer_end, args.warmup, args.iterations) + + # just a dump of the configuration and profiling results + profile_results['profile_data'].extend(results) + profile_results['profile_data'].sort(key=lambda pd: pd['layer']) + with open(args.results_yml, 'w', encoding='utf-8') as yfile: + # yfile is a text-mode stream, so omit encoding= here (PyYAML would emit bytes otherwise) + yaml.safe_dump(profile_results, yfile, default_flow_style=None) + + +if __name__ == "__main__": + main() diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/utils/yaml_files.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/utils/yaml_files.py new file mode 100644 index 00000000..b90459a2 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/utils/yaml_files.py @@ -0,0 +1,49 @@ +"""Manage YAML files.""" +import os +import yaml + + +def _yaml_load_map(file): + if os.path.exists(file): + with open(file, 'r', encoding='utf-8') as yfile: + yml = yaml.safe_load(yfile) + else: + yml = {} + return yml + + +def
yaml_models_load(file) -> dict: + """Load a YAML models file.""" + # models files are a map of model names to yaml_model values. + return _yaml_load_map(file) + + +def yaml_device_types_load(file) -> dict: + """Load a YAML device types file.""" + # device types files are a map of device type names to yaml_device_type values. + return _yaml_load_map(file) + + +def yaml_devices_load(file) -> dict: + """Load a YAML devices file.""" + # devices files are a map of device type names to lists of hosts. + return _yaml_load_map(file) + + +def yaml_device_neighbors_load(file) -> dict: + """Load a YAML device neighbors file.""" + # device neighbors files are a map of neighbor hostnames to yaml_device_neighbors_type values. + return _yaml_load_map(file) + + +def yaml_device_neighbors_world_load(file) -> dict: + """Load a YAML device neighbors world file.""" + # device neighbors world files are a map of hostnames to a map of neighbor hostnames to + # yaml_device_neighbors_type values. + return _yaml_load_map(file) + + +def yaml_save(yml, file): + """Save a YAML file.""" + with open(file, 'w', encoding='utf-8') as yfile: + # the stream is text-mode, so omit encoding= here (PyYAML would write bytes otherwise) + yaml.safe_dump(yml, yfile, default_flow_style=None) diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/utils/yaml_types.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/utils/yaml_types.py new file mode 100644 index 00000000..a5639fe5 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/automatic/utils/yaml_types.py @@ -0,0 +1,82 @@ +"""YAML types.""" +from typing import List, Optional, Union + + +def _assert_list_type(lst, dtype): + assert isinstance(lst, list) + for var in lst: + assert isinstance(var, dtype) + + +def yaml_model(num_layers: int, parameters_in: int, parameters_out: List[int], + mem_MB: Union[List[int], List[float]]) -> dict: + """Create a YAML model.""" + assert isinstance(num_layers, int) + assert isinstance(parameters_in, int) + _assert_list_type(parameters_out, int) + _assert_list_type(mem_MB, (int, float)) + return { + 'layers': num_layers, + 'parameters_in': parameters_in, + 'parameters_out': parameters_out, + 'mem_MB': mem_MB, + } + + +def yaml_model_profile(dtype: str, batch_size: int, time_s: Union[List[int], List[float]]) -> dict: + """Create a YAML model profile.""" + assert isinstance(dtype, str) + assert isinstance(batch_size, int) + _assert_list_type(time_s, (int, float)) + return { + 'dtype': dtype, + 'batch_size': batch_size, + 'time_s': time_s, + } + + +def _assert_model_profile(model_prof): + assert isinstance(model_prof, dict) + for model_prof_prop in model_prof: + # only 'time_s' is supported + assert model_prof_prop == 'time_s' + _assert_list_type(model_prof['time_s'], (int, float)) + + +def _assert_model_profiles(model_profiles): + assert isinstance(model_profiles, dict) + for model in model_profiles: + assert isinstance(model, str) + _assert_model_profile(model_profiles[model]) + + +def yaml_device_type(mem_MB: Union[int, float], bw_Mbps: Union[int, float], + model_profiles: Optional[dict]) -> dict: + """Create a YAML device type.""" + assert isinstance(mem_MB, (int, float)) + assert isinstance(bw_Mbps, (int, float)) + if model_profiles is None: + model_profiles = {} + _assert_model_profiles(model_profiles) + return { + 'mem_MB': mem_MB, + 'bw_Mbps': bw_Mbps, + 'model_profiles': model_profiles, + } + +def yaml_device_neighbors_type(bw_Mbps: Union[int, float]) -> dict: + """Create a YAML device neighbors type.""" + assert isinstance(bw_Mbps, (int,
float)) + return { + 'bw_Mbps': bw_Mbps, + # Currently only one field, but could be extended, e.g., to include latency_{ms,us}. + } + +def yaml_device_neighbors(neighbors: List[str], bws_Mbps: Union[List[int], List[float]]) -> dict: + """Create a YAML device neighbors.""" + _assert_list_type(neighbors, str) + _assert_list_type(bws_Mbps, (int, float)) + return { + neighbor: yaml_device_neighbors_type(bw_Mbps) + for neighbor, bw_Mbps in zip(neighbors, bws_Mbps) + } diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/basemodel.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/basemodel.py new file mode 100644 index 00000000..2efcb901 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/basemodel.py @@ -0,0 +1,154 @@ +# Modified Copyright 2022 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import glob +import os +from collections import OrderedDict +from pathlib import Path +from collections import defaultdict +import time + +from sedna.common.class_factory import ClassType, ClassFactory +from dataset import load_dataset + +import yaml +import onnxruntime as ort +from torch.utils.data import DataLoader +import numpy as np +from tqdm import tqdm +import pynvml + + +__all__ = ["BaseModel"] + +# set backend +os.environ["BACKEND_TYPE"] = "ONNX" + + +def make_parser(): + parser = argparse.ArgumentParser("ViT Eval") + parser.add_argument( + "opts", + help="Modify config options using the command-line", + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument("--devices_info", default="./devices.yaml", type=str, help="devices conf") + parser.add_argument("--model_parallel", default=True, action="store_true") + parser.add_argument("--split", default="val", type=str, help="split of dataset") + parser.add_argument("--indices", default=None, type=str, help="indices of dataset") + parser.add_argument("--shuffle", default=False, action="store_true", help="shuffle data") + parser.add_argument("--model_name", default="google/vit-base-patch16-224", type=str, help="model name") + parser.add_argument("--dataset_name", default="ImageNet", type=str, help="dataset name") + parser.add_argument("--data_size", default=1000, type=int, help="data size to inference") + # remove conflict with ianvs + parser.add_argument("-f") + return parser + + +@ClassFactory.register(ClassType.GENERAL, alias="Classification") +class BaseModel: + + def __init__(self, **kwargs) -> None: + self.args = make_parser().parse_args() + self.devices_info_url = str(Path(Path(__file__).parent.resolve(), self.args.devices_info)) + self.model_parallel = self.args.model_parallel + self.partition_point_list = self._parse_devices_info(self.devices_info_url).get('partition_points') + self.models = [] + return + + + def load(self, models_dir=None, map_info=None) -> None: + cnt = 0 + for model_name, device in map_info.items(): + model = models_dir + '/' + model_name + if not os.path.exists(model): + raise ValueError("=> No 
model found at '{}'".format(model)) + if device == 'cpu': + session = ort.InferenceSession(model, providers=['CPUExecutionProvider']) + elif 'gpu' in device: + device_id = int(device.split('-')[-1]) + session = ort.InferenceSession(model, providers=[('CUDAExecutionProvider', {'device_id': device_id})]) + else: + raise ValueError("Error device info: '{}'".format(device)) + self.models.append({ + 'session': session, + 'name': model_name, + 'device': device, + 'input_names': self.partition_point_list[cnt]['input_names'], + 'output_names': self.partition_point_list[cnt]['output_names'], + }) + cnt += 1 + print("=> Loaded onnx model: '{}'".format(model)) + return + + def predict(self, data, input_shape=None, **kwargs): + pynvml.nvmlInit() + root = str(Path(data[0]).parents[2]) + dataset_cfg = { + 'name': self.args.dataset_name, + 'root': root, + 'split': self.args.split, + 'indices': self.args.indices, + 'shuffle': self.args.shuffle + } + data_loader, ids = self._get_eval_loader(dataset_cfg) + data_loader = tqdm(data_loader, desc='Evaluating', unit='batch') + pred = [] + inference_time_per_device = defaultdict(int) + power_usage_per_device = defaultdict(list) + mem_usage_per_device = defaultdict(list) + cnt = 0 + for data, id in zip(data_loader, ids): + outputs = data[0].numpy() + for model in self.models: + start_time = time.time() + outputs = model['session'].run(None, {model['input_names'][0]: outputs})[0] + end_time = time.time() + device = model.get('device') + inference_time_per_device[device] += end_time - start_time + if 'gpu' in device and cnt % 100 == 0: + handle = pynvml.nvmlDeviceGetHandleByIndex(int(device.split('-')[-1])) + power_usage = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0 + memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle).used / (1024**2) + power_usage_per_device[device] += [power_usage] + mem_usage_per_device[device] += [memory_info] + max_ids = np.argmax(outputs) + pred.append((max_ids, id)) + cnt += 1 + data_loader.close() + result = {} + result["pred"] = pred + result["inference_time_per_device"] = inference_time_per_device + result["power_usage_per_device"] = power_usage_per_device + result["mem_usage_per_device"] = mem_usage_per_device + return result + + + def _get_eval_loader(self, dataset_cfg): + model_name = self.args.model_name + data_size = self.args.data_size + dataset, _, ids = load_dataset(dataset_cfg, model_name, data_size) + data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) + return data_loader, ids + + def _parse_devices_info(self, url): + """Convert yaml file to the dict.""" + if url.endswith('.yaml') or url.endswith('.yml'): + with open(url, "rb") as file: + devices_info_dict = yaml.load(file, Loader=yaml.SafeLoader) + return devices_info_dict + else: + raise RuntimeError('config file must be in yaml format') \ No newline at end of file diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/classification_algorithm.yaml b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/classification_algorithm.yaml new file mode 100644 index 00000000..c8d212f7 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/classification_algorithm.yaml @@ -0,0 +1,27 @@ +algorithm: + # paradigm name; string type; + # currently the options of value are as follows: + # 1> "singletasklearning" + # 2> "incrementallearning" + # 3> "multiedgeinference" + paradigm_type: "multiedgeinference" + # the url address of initial model; string type; optional; + initial_model_url:
"./initial_model/vit-base-patch16-224.onnx" + + # algorithm module configuration in the paradigm; list type; + modules: + # kind of algorithm module; string type; + # currently the options of value are as follows: + # 1> "basemodel" + - type: "basemodel" + # name of python module; string type; + # example: basemodel.py has BaseModel module that the alias is "FPN" for this benchmarking; + name: "Classification" + # the url address of python module; string type; + url: "./examples/imagenet/multiedge_inference_bench/testalgorithms/manual/basemodel.py" + + # hyperparameters configuration for the python module; list type; + hyperparameters: + # name of the hyperparameter; string type; + - batch_size: + values: + - 1 diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/dataset.py b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/dataset.py new file mode 100644 index 00000000..9b4ee16c --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/dataset.py @@ -0,0 +1,71 @@ +import logging +import random +from typing import Callable, Optional, Sequence +import os + +from torch.utils.data import DataLoader, Dataset, Subset +from transformers import ViTFeatureExtractor +from torchvision.datasets import ImageNet + + +def load_dataset_imagenet(feature_extractor: Callable, root: str, split: str='train') -> Dataset: + """Get the ImageNet dataset.""" + + def transform(img): + pixels = feature_extractor(images=img.convert('RGB'), return_tensors='pt')['pixel_values'] + return pixels[0] + return ImageNet(root, split=split, transform=transform) + +def load_dataset_subset(dataset: Dataset, indices: Optional[Sequence[int]]=None, + max_size: Optional[int]=None, shuffle: bool=False) -> Dataset: + """Get a Dataset subset.""" + if indices is None: + indices = list(range(len(dataset))) + if shuffle: + random.shuffle(indices) + if max_size is not None: + indices = indices[:max_size] + image_paths = [] + for index in indices: + image_paths.append(dataset.imgs[index][0]) + return Subset(dataset, indices), image_paths, indices + +def load_dataset(dataset_cfg: dict, model_name: str, batch_size: int) -> Dataset: + """Load inputs based on model.""" + def _get_feature_extractor(): + feature_extractor = ViTFeatureExtractor.from_pretrained(model_name) + return feature_extractor + dataset_name = dataset_cfg['name'] + dataset_root = dataset_cfg['root'] + dataset_split = dataset_cfg['split'] + indices = dataset_cfg['indices'] + dataset_shuffle = dataset_cfg['shuffle'] + if dataset_name == 'ImageNet': + if dataset_root is None: + dataset_root = 'ImageNet' + logging.info("Dataset root not set, assuming: %s", dataset_root) + feature_extractor = _get_feature_extractor() + dataset = load_dataset_imagenet(feature_extractor, dataset_root, split=dataset_split) + dataset, paths, ids = load_dataset_subset(dataset, indices=indices, max_size=batch_size, + shuffle=dataset_shuffle) + return dataset, paths, ids + +if __name__ == '__main__': + dataset_cfg = { + 'name': "ImageNet", + 'root': './dataset', + 'split': 'val', + 'indices': None, + 'shuffle': False, + } + model_name = "google/vit-base-patch16-224" + ## Total images to be inferenced. 
+ data_size = 1000 + dataset, paths, _ = load_dataset(dataset_cfg, model_name, data_size) + data_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) + with open('./dataset/train.txt', 'w') as f: + for i, (image, label) in enumerate(data_loader): + original_path = paths[i].replace('/dataset', '') + f.write(f"{original_path} {label.item()}\n") + os.system('cp ./dataset/train.txt ./dataset/test.txt') # waits for the copy to finish, unlike os.popen \ No newline at end of file diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/devices.yaml b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/devices.yaml new file mode 100644 index 00000000..82be6c7c --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/devices.yaml @@ -0,0 +1,26 @@ +devices: + - name: "gpu-0" + type: "gpu" + memory: "1024" + freq: "2.6" + bandwidth: "100" + - name: "gpu-1" + type: "gpu" + memory: "1024" + freq: "2.6" + bandwidth: "80" + - name: "gpu-2" + type: "gpu" + memory: "2048" + freq: "2.6" + bandwidth: "90" +partition_points: + - input_names: ["pixel_values"] + output_names: ["input.60"] + device_name: "gpu-0" + - input_names: ["input.60"] + output_names: ["input.160"] + device_name: "gpu-1" + - input_names: ["input.160"] + output_names: ["logits"] + device_name: "gpu-2" \ No newline at end of file diff --git a/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/devices_one.yaml b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/devices_one.yaml new file mode 100644 index 00000000..2f4a5640 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testalgorithms/manual/devices_one.yaml @@ -0,0 +1,11 @@ +devices: + - name: "gpu-1" + type: "gpu" + memory: "1024" + freq: "2.6" + bandwidth: "100" + +partition_points: + - input_names: ["pixel_values"] + output_names: ["logits"] + device_name: "gpu-1" \ No newline at end of file diff --git a/examples/imagenet/multiedge_inference_bench/testenv/accuracy.py b/examples/imagenet/multiedge_inference_bench/testenv/accuracy.py new file mode 100644 index 00000000..86f78d69 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testenv/accuracy.py @@ -0,0 +1,14 @@ +from sedna.common.class_factory import ClassType, ClassFactory + +__all__ = ('accuracy',) + +@ClassFactory.register(ClassType.GENERAL, alias="accuracy") +def accuracy(y_true, y_pred, **kwargs): + y_pred = y_pred.get("pred") + total = len(y_pred) + y_true_ = [int(y_true[i].split('/')[-1]) for (_, i) in y_pred] + y_pred_ = [int(i) for (i, _) in y_pred] + correct_predictions = sum(yt == yp for yt, yp in zip(y_true_, y_pred_)) + accuracy = (correct_predictions / total) * 100 + print("Accuracy: {:.2f}%".format(accuracy)) + return accuracy diff --git a/examples/imagenet/multiedge_inference_bench/testenv/fps.py b/examples/imagenet/multiedge_inference_bench/testenv/fps.py new file mode 100644 index 00000000..810e72f6 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testenv/fps.py @@ -0,0 +1,34 @@ +import sys +import os + +from sedna.common.class_factory import ClassType, ClassFactory + +import matplotlib.pyplot as plt + +__all__ = ('fps',) + +@ClassFactory.register(ClassType.GENERAL, alias="fps") +def fps(y_true, y_pred, **kwargs): + total = len(y_pred.get("pred")) + inference_time_per_device = y_pred.get("inference_time_per_device") + plt.figure() + min_fps = sys.maxsize + for device, time in inference_time_per_device.items(): + fps = total / time + plt.bar(device, fps, label=f'{device}') + min_fps = min(fps, min_fps) +
plt.axhline(y=min_fps, color='red', linewidth=2, label='Min FPS') + + plt.xticks(rotation=45) + plt.ylabel('FPS') + plt.xlabel('Device') + plt.legend() + + dir = './multiedge_inference_bench/workspace/classification_job/images/' + if not os.path.exists(dir): + os.makedirs(dir) + from datetime import datetime + now = datetime.now().strftime("%H_%M_%S") + plt.savefig(dir + 'FPS_per_device' + now + '.png') + + return min_fps diff --git a/examples/imagenet/multiedge_inference_bench/testenv/peak_memory.py b/examples/imagenet/multiedge_inference_bench/testenv/peak_memory.py new file mode 100644 index 00000000..c6e6dccb --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testenv/peak_memory.py @@ -0,0 +1,32 @@ +import sys +import os + +from sedna.common.class_factory import ClassType, ClassFactory + +import matplotlib.pyplot as plt + +__all__ = ('peak_memory',) + +@ClassFactory.register(ClassType.GENERAL, alias="peak_memory") +def peak_memory(y_true, y_pred, **kwargs): + mem_usage_per_device = y_pred.get("mem_usage_per_device") + plt.figure() + peak_mem = -sys.maxsize + for device, mem_list in mem_usage_per_device.items(): + plt.bar(device, max(mem_list), label=f'{device}') + peak_mem = max(peak_mem, max(mem_list)) + plt.axhline(y=peak_mem, color='red', linewidth=2, label='Peak Memory') + + plt.xticks(rotation=45) + plt.ylabel('Memory') + plt.xlabel('Device') + plt.legend() + + dir = './multiedge_inference_bench/workspace/classification_job/images/' + if not os.path.exists(dir): + os.makedirs(dir) + from datetime import datetime + now = datetime.now().strftime("%H_%M_%S") + plt.savefig(dir + 'peak_mem_per_device' + now + '.png') + + return peak_mem diff --git a/examples/imagenet/multiedge_inference_bench/testenv/peak_power.py b/examples/imagenet/multiedge_inference_bench/testenv/peak_power.py new file mode 100644 index 00000000..7618e714 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testenv/peak_power.py @@ -0,0 +1,32 @@ +import sys +import os + +from sedna.common.class_factory import ClassType, ClassFactory + +import matplotlib.pyplot as plt + +__all__ = ('peak_power',) + +@ClassFactory.register(ClassType.GENERAL, alias="peak_power") +def peak_power(y_true, y_pred, **kwargs): + power_usage_per_device = y_pred.get("power_usage_per_device") + plt.figure() + peak_power = -sys.maxsize + for device, power_list in power_usage_per_device.items(): + plt.plot(power_list, label=device) + peak_power = max(peak_power, max(power_list)) + plt.axhline(y=peak_power, color='red', linewidth=2, label='Peak Power') + + plt.xticks(rotation=45) + plt.ylabel('Power') + plt.xlabel('Device') + plt.legend() + + dir = './multiedge_inference_bench/workspace/classification_job/images/' + if not os.path.exists(dir): + os.makedirs(dir) + from datetime import datetime + now = datetime.now().strftime("%H_%M_%S") + plt.savefig(dir + 'power_usage_per_device' + now + '.png') + + return peak_power diff --git a/examples/imagenet/multiedge_inference_bench/testenv/testenv.yaml b/examples/imagenet/multiedge_inference_bench/testenv/testenv.yaml new file mode 100644 index 00000000..f2bb70d2 --- /dev/null +++ b/examples/imagenet/multiedge_inference_bench/testenv/testenv.yaml @@ -0,0 +1,25 @@ +testenv: + # dataset configuration + dataset: + # the url address of train dataset index; string type; + train_url: "./dataset/train.txt" + # the url address of test dataset index; string type; + test_url: "./dataset/test.txt" + + # metrics configuration for test case's evaluation; list type; + metrics: + # metric name;
string type; + - name: "accuracy" + # the url address of python file + url: "./examples/imagenet/multiedge_inference_bench/testenv/accuracy.py" + - name: "fps" + # the url address of python file + url: "./examples/imagenet/multiedge_inference_bench/testenv/fps.py" + - name: "peak_memory" + # the url address of python file + url: "./examples/imagenet/multiedge_inference_bench/testenv/peak_memory.py" + - name: "peak_power" + # the url address of python file + url: "./examples/imagenet/multiedge_inference_bench/testenv/peak_power.py" + devices: + - url: "./devices.yaml" \ No newline at end of file diff --git a/examples/llm-edge-benchmark-suite/README.md b/examples/llm-edge-benchmark-suite/README.md new file mode 100644 index 00000000..8ef4ae97 --- /dev/null +++ b/examples/llm-edge-benchmark-suite/README.md @@ -0,0 +1,41 @@ +Large Language Model Edge Benchmark Suite: Implementation on KubeEdge-Ianvs + + +## Dataset + +### Prepare Data + +The dataset for the llm-edge-benchmark-suite example is structured as follows: + +``` +. +├── test_data +│ └── data.jsonl +└── train_data + └── data.jsonl +``` + +`train_data/data.jsonl` is empty; each line of `test_data/data.jsonl` is a JSON object such as: + +``` +{"question": "Which of the following numbers is the smallest prime number?\nA. 0\nB. 1\nC. 2\nD. 4", "answer": "C"} +``` +### Prepare Environment + +```shell +python setup.py install +``` + +### Run Ianvs + + + +```shell +ianvs -f examples/llm-edge-benchmark-suite/single_task_bench/benchmarkingjob.yaml +``` + + +```shell +ianvs -f examples/llm-edge-benchmark-suite/single_task_bench_with_compression/benchmarkingjob.yaml +``` + diff --git a/examples/llm-edge-benchmark-suite/single_task_bench/README.md b/examples/llm-edge-benchmark-suite/single_task_bench/README.md new file mode 100644 index 00000000..3a3835c7 --- /dev/null +++ b/examples/llm-edge-benchmark-suite/single_task_bench/README.md @@ -0,0 +1,2 @@ +Large Language Model Edge Benchmark Suite: Implementation on KubeEdge-Ianvs + diff --git a/examples/llm-edge-benchmark-suite/single_task_bench/benchmarkingjob.yaml b/examples/llm-edge-benchmark-suite/single_task_bench/benchmarkingjob.yaml new file mode 100644 index 00000000..ae2b23c6 --- /dev/null +++ b/examples/llm-edge-benchmark-suite/single_task_bench/benchmarkingjob.yaml @@ -0,0 +1,30 @@ +benchmarkingjob: + name: "benchmarkingjob" + workspace: "./workspace" + + testenv: "./examples/llm-edge-benchmark-suite/single_task_bench/testenv/testenv.yaml" + + test_object: + type: "algorithms" + algorithms: + - name: "llama-cpp" + url: "./examples/llm-edge-benchmark-suite/single_task_bench/testalgorithms/algorithm.yaml" + + rank: + sort_by: + - { "latency": "descend" } + - { "throughput": "ascend" } + - { "mem_usage": "ascend" } + - { "prefill_latency": "ascend" } + + visualization: + mode: "selected_only" + method: "print_table" + + selected_dataitem: + paradigms: [ "all" ] + modules: [ "all" ] + hyperparameters: [ "all" ] + metrics: [ "latency", "throughput", "prefill_latency" ] + + save_mode: "selected_and_all" \ No newline at end of file diff --git a/examples/llm-edge-benchmark-suite/single_task_bench/testalgorithms/algorithm.yaml b/examples/llm-edge-benchmark-suite/single_task_bench/testalgorithms/algorithm.yaml new file mode 100644 index 00000000..d15c4326 --- /dev/null +++ b/examples/llm-edge-benchmark-suite/single_task_bench/testalgorithms/algorithm.yaml @@ -0,0 +1,16 @@ +algorithm: + paradigm_type: "singletasklearningwithcompression" + + initial_model_url: "models/qwen/qwen_1_5_0_5b.gguf" + + modules: + -
type: "basemodel" + name: "LlamaCppModel" + url: "./examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testalgorithms/basemodel.py" + hyperparameters: + - model_path: + values: + - "models/qwen/qwen_1_5_0_5b.gguf" + - n_ctx: + values: + - 2048 \ No newline at end of file diff --git a/examples/llm-edge-benchmark-suite/single_task_bench/testalgorithms/basemodel.py b/examples/llm-edge-benchmark-suite/single_task_bench/testalgorithms/basemodel.py new file mode 100644 index 00000000..477cc61a --- /dev/null +++ b/examples/llm-edge-benchmark-suite/single_task_bench/testalgorithms/basemodel.py @@ -0,0 +1,135 @@ +from sedna.common.class_factory import ClassFactory, ClassType +from llama_cpp import Llama +from contextlib import redirect_stderr +import os +import psutil +import time +import io +import statistics +import logging + +logging.getLogger().setLevel(logging.INFO) + +@ClassFactory.register(ClassType.GENERAL, alias="LlamaCppModel") +class LlamaCppModel: + def __init__(self, **kwargs): + """ + init llama-cpp + """ + model_path = kwargs.get("model_path") + if not model_path: + raise ValueError("Model path is required.") + quantization_type = kwargs.get("quantization_type", None) + if quantization_type: + logging.info(f"Using quantization type: {quantization_type}") + # Init LLM model + self.model = Llama( + model_path=model_path, + n_ctx=kwargs.get("n_ctx", 512), + n_gpu_layers=kwargs.get("n_gpu_layers", 0), + seed=kwargs.get("seed", -1), + f16_kv=kwargs.get("f16_kv", True), + logits_all=kwargs.get("logits_all", False), + vocab_only=kwargs.get("vocab_only", False), + use_mlock=kwargs.get("use_mlock", False), + embedding=kwargs.get("embedding", False), + ) + + def predict(self, data, input_shape=None, **kwargs): + data = data[:10] + process = psutil.Process(os.getpid()) + start_time = time.time() + + results = [] + total_times = [] + prefill_latencies = [] + mem_usages = [] + + for prompt in data: + prompt_start_time = time.time() + + f = io.StringIO() + with redirect_stderr(f): + output = self.model( + prompt=prompt, + max_tokens=kwargs.get("max_tokens", 32), + stop=kwargs.get("stop", ["Q:", "\n"]), + echo=kwargs.get("echo", True), + temperature=kwargs.get("temperature", 0.8), + top_p=kwargs.get("top_p", 0.95), + top_k=kwargs.get("top_k", 40), + repeat_penalty=kwargs.get("repeat_penalty", 1.1), + ) + stdout_output = f.getvalue() + + # parse timing info + timings = self._parse_timings(stdout_output) + prefill_latency = timings.get('prompt_eval_time', 0.0) # ms + generated_text = output['choices'][0]['text'] + + prompt_end_time = time.time() + prompt_total_time = (prompt_end_time - prompt_start_time) * 1000 # convert to ms + + result_with_time = { + "generated_text": generated_text, + "total_time": prompt_total_time, + "prefill_latency": prefill_latency, + "mem_usage":process.memory_info().rss, + } + + results.append(result_with_time) + + predict_dict = { + "results": results, + } + + return predict_dict + + def _parse_timings(self, stdout_output): + import re + timings = {} + for line in stdout_output.split('\n'): + match = re.match(r'llama_print_timings:\s*(.+?)\s*=\s*([0-9\.]+)\s*ms', line) + if match: + key = match.group(1).strip() + value = float(match.group(2)) + + key = key.lower().replace(' ', '_') + timings[key] = value + + return timings + + def evaluate(self, data, model_path=None, **kwargs): + """ + evaluate model + """ + if data is None or data.x is None: + raise ValueError("Evaluation data is None.") + + if model_path: + self.load(model_path) + + # do predict + 
predict_dict = self.predict(data.x, **kwargs) + + # compute metrics + metric = kwargs.get("metric") + if metric is None: + raise ValueError("No metric provided in kwargs.") + + metric_name, metric_func = metric + + if callable(metric_func): + metric_value = metric_func(None, predict_dict["results"]) + return {metric_name: metric_value} + else: + raise ValueError(f"Metric function {metric_name} is not callable or not provided.") + + def save(self, model_path): + pass + + def load(self, model_url): + pass + + def train(self, train_data, valid_data=None, **kwargs): + return \ No newline at end of file diff --git a/examples/llm-edge-benchmark-suite/single_task_bench/testalgorithms/download_model_modelscope.py b/examples/llm-edge-benchmark-suite/single_task_bench/testalgorithms/download_model_modelscope.py new file mode 100644 index 00000000..15f8933c --- /dev/null +++ b/examples/llm-edge-benchmark-suite/single_task_bench/testalgorithms/download_model_modelscope.py @@ -0,0 +1,25 @@ +import argparse +import logging +from modelscope import snapshot_download + +logging.getLogger().setLevel(logging.INFO) + +def download_model(model_id, revision, local_dir): + try: + model_dir = snapshot_download(model_id, revision=revision, cache_dir=local_dir) + logging.info(f"Model successfully downloaded to: {model_dir}") + return model_dir + except Exception as e: + logging.error(f"Error downloading model: {str(e)}") + return None + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Download a model from ModelScope") + parser.add_argument("--model_id", type=str, required=True, help="ModelScope model ID") + parser.add_argument("--revision", type=str, default="master", help="Model revision") + parser.add_argument("--local_dir", type=str, required=True, help="Local directory to save the model") + + args = parser.parse_args() + + download_model(args.model_id, args.revision, args.local_dir) \ No newline at end of file diff --git a/examples/llm-edge-benchmark-suite/single_task_bench/testenv/latency.py b/examples/llm-edge-benchmark-suite/single_task_bench/testenv/latency.py new file mode 100644 index 00000000..c561cc7d --- /dev/null +++ b/examples/llm-edge-benchmark-suite/single_task_bench/testenv/latency.py @@ -0,0 +1,29 @@ +# Copyright 2023 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
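+# Note: every metric in this suite follows the same sedna plugin pattern: a plain function +# registered via ClassFactory whose alias matches the metric "name" in testenv.yaml, e.g. +# @ClassFactory.register(ClassType.GENERAL, alias="latency") +# def latency(y_true, y_pred): ...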
+
+from sedna.common.class_factory import ClassType, ClassFactory
+
+__all__ = ["latency"]
+
+
+@ClassFactory.register(ClassType.GENERAL, alias="latency")
+def latency(y_true, y_pred):
+    results_list = y_pred.get('results', [])
+    num_requests = len(results_list)
+    # Guard against an empty result set to avoid division by zero.
+    if num_requests == 0:
+        return 0.0
+    total_latency = 0.0
+    for result in results_list:
+        total_latency += result['total_time']
+    average_latency = total_latency / num_requests
+    return average_latency
\ No newline at end of file
diff --git a/examples/llm-edge-benchmark-suite/single_task_bench/testenv/mem_usage.py b/examples/llm-edge-benchmark-suite/single_task_bench/testenv/mem_usage.py
new file mode 100644
index 00000000..3d57b672
--- /dev/null
+++ b/examples/llm-edge-benchmark-suite/single_task_bench/testenv/mem_usage.py
@@ -0,0 +1,13 @@
+from sedna.common.class_factory import ClassType, ClassFactory
+
+__all__ = ["mem_usage"]
+
+@ClassFactory.register(ClassType.GENERAL, alias="mem_usage")
+def mem_usage(y_true, y_pred):
+    # mem_usage is the process RSS (bytes) sampled after each request.
+    results_list = y_pred.get('results', [])
+    num_requests = len(results_list)
+    if num_requests == 0:
+        return 0.0
+    total_mem_usage = 0.0
+    for result in results_list:
+        total_mem_usage += result['mem_usage']
+    average_mem_usage = total_mem_usage / num_requests
+    return average_mem_usage
\ No newline at end of file
diff --git a/examples/llm-edge-benchmark-suite/single_task_bench/testenv/prefill_latency.py b/examples/llm-edge-benchmark-suite/single_task_bench/testenv/prefill_latency.py
new file mode 100644
index 00000000..b7743577
--- /dev/null
+++ b/examples/llm-edge-benchmark-suite/single_task_bench/testenv/prefill_latency.py
@@ -0,0 +1,13 @@
+from sedna.common.class_factory import ClassType, ClassFactory
+
+__all__ = ["prefill_latency"]
+
+@ClassFactory.register(ClassType.GENERAL, alias="prefill_latency")
+def prefill_latency(y_true, y_pred):
+    results_list = y_pred.get('results', [])
+    num_requests = len(results_list)
+    if num_requests == 0:
+        return 0.0
+    total_prefill_latency = 0.0
+    for result in results_list:
+        total_prefill_latency += result['prefill_latency']
+    avg_prefill_latency = total_prefill_latency / num_requests
+    return avg_prefill_latency
\ No newline at end of file
diff --git a/examples/llm-edge-benchmark-suite/single_task_bench/testenv/testenv.yaml b/examples/llm-edge-benchmark-suite/single_task_bench/testenv/testenv.yaml
new file mode 100644
index 00000000..e4a6e88a
--- /dev/null
+++ b/examples/llm-edge-benchmark-suite/single_task_bench/testenv/testenv.yaml
@@ -0,0 +1,14 @@
+testenv:
+  dataset:
+    train_data: "ianvs/government/objective/train_data/data.jsonl"
+    test_data: "ianvs/government/objective/test_data/data.jsonl"
+    use_gpu: true
+  metrics:
+    - name: "latency"
+      url: "./examples/llm-edge-benchmark-suite/single_task_bench/testenv/latency.py"
+    - name: "throughput"
+      url: "./examples/llm-edge-benchmark-suite/single_task_bench/testenv/throughput.py"
+    - name: "prefill_latency"
+      url: "./examples/llm-edge-benchmark-suite/single_task_bench/testenv/prefill_latency.py"
+    - name: "mem_usage"
+      url: "./examples/llm-edge-benchmark-suite/single_task_bench/testenv/mem_usage.py"
diff --git a/examples/llm-edge-benchmark-suite/single_task_bench/testenv/throughput.py b/examples/llm-edge-benchmark-suite/single_task_bench/testenv/throughput.py
new file mode 100644
index 00000000..3ad7a05a
--- /dev/null
+++ b/examples/llm-edge-benchmark-suite/single_task_bench/testenv/throughput.py
@@ -0,0 +1,30 @@
+# Copyright 2023 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from sedna.common.class_factory import ClassType, ClassFactory
+
+__all__ = ["throughput"]
+
+@ClassFactory.register(ClassType.GENERAL, alias="throughput")
+def throughput(y_true, y_pred):
+    results_list = y_pred.get('results', [])
+    num_requests = len(results_list)
+    total_latency = 0.0
+    for result in results_list:
+        total_latency += result['total_time']
+    # total_time is measured in milliseconds; convert to seconds so the metric
+    # reports requests per second, and guard against an empty result set.
+    if total_latency == 0.0:
+        return 0.0
+    avg_throughput = num_requests / (total_latency / 1000)
+    return avg_throughput
\ No newline at end of file
diff --git a/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/README.md b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/README.md
new file mode 100644
index 00000000..3a3835c7
--- /dev/null
+++ b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/README.md
@@ -0,0 +1,2 @@
+Large Language Model Edge Benchmark Suite: Implementation on KubeEdge-Ianvs
+
diff --git a/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/benchmarkingjob.yaml b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/benchmarkingjob.yaml
new file mode 100644
index 00000000..ae2b23c6
--- /dev/null
+++ b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/benchmarkingjob.yaml
@@ -0,0 +1,30 @@
+benchmarkingjob:
+  name: "benchmarkingjob"
+  workspace: "./workspace"
+
+  testenv: "./examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/testenv.yaml"
+
+  test_object:
+    type: "algorithms"
+    algorithms:
+      - name: "llama-cpp"
+        url: "./examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testalgorithms/algorithm.yaml"
+
+  rank:
+    # lower is better for latency, mem_usage and prefill_latency, so they rank
+    # in ascending order; higher is better for throughput, so it descends.
+    sort_by:
+      - { "latency": "ascend" }
+      - { "throughput": "descend" }
+      - { "mem_usage": "ascend" }
+      - { "prefill_latency": "ascend" }
+
+    visualization:
+      mode: "selected_only"
+      method: "print_table"
+
+    selected_dataitem:
+      paradigms: [ "all" ]
+      modules: [ "all" ]
+      hyperparameters: [ "all" ]
+      metrics: [ "latency", "throughput", "prefill_latency" ]
+
+    save_mode: "selected_and_all"
\ No newline at end of file
diff --git a/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testalgorithms/algorithm.yaml b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testalgorithms/algorithm.yaml
new file mode 100644
index 00000000..1fdd5d5b
--- /dev/null
+++ b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testalgorithms/algorithm.yaml
@@ -0,0 +1,18 @@
+algorithm:
+  paradigm_type: "singletasklearning_with_compression"
+  mode: "with_compression"
+  initial_model_url: "models/qwen/qwen_1_5_0_5b.gguf"
+  quantization_type: "q8_0"
+  llama_quantize_path: "llama.cpp/llama-quantize"
+  modules:
+    - type: "basemodel"
+      name: "LlamaCppModel"
+      url: "./examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testalgorithms/basemodel.py"
+
+  hyperparameters:
+    - model_path:
+        values:
+          - "models/qwen/qwen_1_5_0_5b.gguf"
+    - n_ctx:
+        values:
+          - 2048
\ No newline at end of file
diff --git a/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testalgorithms/basemodel.py b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testalgorithms/basemodel.py
new file mode 100644
index 00000000..4ad1634b
--- /dev/null
+++ b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testalgorithms/basemodel.py
@@ -0,0 +1,129 @@
+import io
+import os
+import re
+import time
+from contextlib import redirect_stderr
+
+import psutil
+from llama_cpp import Llama
+from sedna.common.class_factory import ClassFactory, ClassType
+
+@ClassFactory.register(ClassType.GENERAL, alias="LlamaCppModel")
+class LlamaCppModel:
+    def __init__(self, **kwargs):
+        """
+        Initialize the llama-cpp model.
+        """
+        model_path = kwargs.get("model_path")
+        if not model_path:
+            raise ValueError("Model path is required.")
+        # Init LLM model (quantization itself is applied by the paradigm
+        # before the model is loaded).
+        self.model = Llama(
+            model_path=model_path,
+            n_ctx=kwargs.get("n_ctx", 512),
+            n_gpu_layers=kwargs.get("n_gpu_layers", 0),
+            seed=kwargs.get("seed", -1),
+            f16_kv=kwargs.get("f16_kv", True),
+            logits_all=kwargs.get("logits_all", False),
+            vocab_only=kwargs.get("vocab_only", False),
+            use_mlock=kwargs.get("use_mlock", False),
+            embedding=kwargs.get("embedding", False),
+        )
+
+    def predict(self, data, input_shape=None, **kwargs):
+        # Only benchmark the first 10 prompts to keep runs short.
+        data = data[:10]
+        process = psutil.Process(os.getpid())
+
+        results = []
+
+        for prompt in data:
+            prompt_start_time = time.time()
+
+            # llama.cpp prints its timing report to stderr; capture it for parsing.
+            f = io.StringIO()
+            with redirect_stderr(f):
+                output = self.model(
+                    prompt=prompt,
+                    max_tokens=kwargs.get("max_tokens", 32),
+                    stop=kwargs.get("stop", ["Q:", "\n"]),
+                    echo=kwargs.get("echo", True),
+                    temperature=kwargs.get("temperature", 0.8),
+                    top_p=kwargs.get("top_p", 0.95),
+                    top_k=kwargs.get("top_k", 40),
+                    repeat_penalty=kwargs.get("repeat_penalty", 1.1),
+                )
+            stderr_output = f.getvalue()
+
+            # Parse the timing info emitted by llama_print_timings.
+            timings = self._parse_timings(stderr_output)
+            prefill_latency = timings.get('prompt_eval_time', 0.0)  # ms
+            generated_text = output['choices'][0]['text']
+
+            prompt_end_time = time.time()
+            prompt_total_time = (prompt_end_time - prompt_start_time) * 1000  # convert to ms
+
+            result_with_time = {
+                "generated_text": generated_text,
+                "total_time": prompt_total_time,
+                "prefill_latency": prefill_latency,
+                "mem_usage": process.memory_info().rss,
+            }
+
+            results.append(result_with_time)
+
+        predict_dict = {
+            "results": results,
+        }
+
+        return predict_dict
+
+    def _parse_timings(self, stderr_output):
+        # Lines look like: "llama_print_timings: prompt eval time = 12.34 ms".
+        timings = {}
+        for line in stderr_output.split('\n'):
+            match = re.match(r'llama_print_timings:\s*(.+?)\s*=\s*([0-9\.]+)\s*ms', line)
+            if match:
+                key = match.group(1).strip().lower().replace(' ', '_')
+                timings[key] = float(match.group(2))
+        return timings
+
+    def evaluate(self, data, model_path=None, **kwargs):
+        """
+        Evaluate the model on the given data.
+        """
+        if data is None or data.x is None:
+            raise ValueError("Evaluation data is None.")
+
+        if model_path:
+            self.load(model_path)
+
+        # do predict
+        predict_dict = self.predict(data.x, **kwargs)
+
+        # compute metrics
+        metric = kwargs.get("metric")
+        if metric is None:
+            raise ValueError("No metric provided in kwargs.")
+
+        metric_name, metric_func = metric
+
+        if not callable(metric_func):
+            raise ValueError(f"Metric function {metric_name} is not callable or not provided.")
+
+        # Pass the full predict_dict: the registered metric functions read
+        # y_pred["results"] themselves.
+        metric_value = metric_func(None, predict_dict)
+        return {metric_name: metric_value}
+
+    def save(self, model_path):
+        pass
+
+    def load(self, model_url):
+        pass
+
+    def train(self, train_data, valid_data=None, **kwargs):
+        return
\ No newline at end of file
diff --git a/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testalgorithms/download_model_modelscope.py b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testalgorithms/download_model_modelscope.py
new file mode 100644
index 00000000..1adda3af
--- /dev/null
+++ b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testalgorithms/download_model_modelscope.py
@@ -0,0 +1,25 @@
+import argparse
+import logging
+
+from modelscope import snapshot_download
+
+logging.getLogger().setLevel(logging.INFO)
+
+def download_model(model_id, revision, local_dir):
+    try:
+        model_dir = snapshot_download(model_id, revision=revision, cache_dir=local_dir)
+        logging.info(f"Model successfully downloaded to: {model_dir}")
+        return model_dir
+    except Exception as e:
+        logging.error(f"Error downloading model: {str(e)}")
+        return None
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Download a model from ModelScope")
+    parser.add_argument("--model_id", type=str, required=True, help="ModelScope model ID")
+    parser.add_argument("--revision", type=str, default="master", help="Model revision")
+    parser.add_argument("--local_dir", type=str, required=True, help="Local directory to save the model")
+
+    args = parser.parse_args()
+
+    download_model(args.model_id, args.revision, args.local_dir)
\ No newline at end of file
diff --git a/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/latency.py b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/latency.py
new file mode 100644
index 00000000..c561cc7d
--- /dev/null
+++ b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/latency.py
@@ -0,0 +1,29 @@
+# Copyright 2023 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
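+
+# total_time values recorded by the base model are wall-clock milliseconds per
+# prompt, so this metric returns the mean per-request latency in ms.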
+
+from sedna.common.class_factory import ClassType, ClassFactory
+
+__all__ = ["latency"]
+
+
+@ClassFactory.register(ClassType.GENERAL, alias="latency")
+def latency(y_true, y_pred):
+    results_list = y_pred.get('results', [])
+    num_requests = len(results_list)
+    # Guard against an empty result set to avoid division by zero.
+    if num_requests == 0:
+        return 0.0
+    total_latency = 0.0
+    for result in results_list:
+        total_latency += result['total_time']
+    average_latency = total_latency / num_requests
+    return average_latency
\ No newline at end of file
diff --git a/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/mem_usage.py b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/mem_usage.py
new file mode 100644
index 00000000..3d57b672
--- /dev/null
+++ b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/mem_usage.py
@@ -0,0 +1,13 @@
+from sedna.common.class_factory import ClassType, ClassFactory
+
+__all__ = ["mem_usage"]
+
+@ClassFactory.register(ClassType.GENERAL, alias="mem_usage")
+def mem_usage(y_true, y_pred):
+    # mem_usage is the process RSS (bytes) sampled after each request.
+    results_list = y_pred.get('results', [])
+    num_requests = len(results_list)
+    if num_requests == 0:
+        return 0.0
+    total_mem_usage = 0.0
+    for result in results_list:
+        total_mem_usage += result['mem_usage']
+    average_mem_usage = total_mem_usage / num_requests
+    return average_mem_usage
\ No newline at end of file
diff --git a/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/prefill_latency.py b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/prefill_latency.py
new file mode 100644
index 00000000..b7743577
--- /dev/null
+++ b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/prefill_latency.py
@@ -0,0 +1,13 @@
+from sedna.common.class_factory import ClassType, ClassFactory
+
+__all__ = ["prefill_latency"]
+
+@ClassFactory.register(ClassType.GENERAL, alias="prefill_latency")
+def prefill_latency(y_true, y_pred):
+    results_list = y_pred.get('results', [])
+    num_requests = len(results_list)
+    if num_requests == 0:
+        return 0.0
+    total_prefill_latency = 0.0
+    for result in results_list:
+        total_prefill_latency += result['prefill_latency']
+    avg_prefill_latency = total_prefill_latency / num_requests
+    return avg_prefill_latency
\ No newline at end of file
diff --git a/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/testenv.yaml b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/testenv.yaml
new file mode 100644
index 00000000..69de256f
--- /dev/null
+++ b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/testenv.yaml
@@ -0,0 +1,14 @@
+testenv:
+  dataset:
+    train_data: "ianvs/government/objective/train_data/data.jsonl"
+    test_data: "ianvs/government/objective/test_data/data.jsonl"
+    use_gpu: false
+  metrics:
+    - name: "latency"
+      url: "./examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/latency.py"
+    - name: "throughput"
+      url: "./examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/throughput.py"
+    - name: "prefill_latency"
+      url: "./examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/prefill_latency.py"
+    - name: "mem_usage"
+      url: "./examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/mem_usage.py"
diff --git a/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/throughput.py b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/throughput.py
new file mode 100644
index 00000000..3ad7a05a
--- /dev/null
+++ b/examples/llm-edge-benchmark-suite/single_task_bench_with_compression/testenv/throughput.py
@@ -0,0 +1,30 @@
+# Copyright 2023 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from sedna.common.class_factory import ClassType, ClassFactory
+
+__all__ = ["throughput"]
+
+@ClassFactory.register(ClassType.GENERAL, alias="throughput")
+def throughput(y_true, y_pred):
+    results_list = y_pred.get('results', [])
+    num_requests = len(results_list)
+    total_latency = 0.0
+    for result in results_list:
+        total_latency += result['total_time']
+    # total_time is measured in milliseconds; convert to seconds so the metric
+    # reports requests per second, and guard against an empty result set.
+    if total_latency == 0.0:
+        return 0.0
+    avg_throughput = num_requests / (total_latency / 1000)
+    return avg_throughput
\ No newline at end of file
diff --git a/examples/llm_simple_qa/README.md b/examples/llm_simple_qa/README.md
new file mode 100644
index 00000000..dbaf845a
--- /dev/null
+++ b/examples/llm_simple_qa/README.md
@@ -0,0 +1,84 @@
+# README
+
+## Simple QA
+
+### Prepare Data
+
+The data of the simple-qa example is structured as follows:
+
+```
+.
+├── test_data
+│   └── data.jsonl
+└── train_data
+    └── data.jsonl
+```
+
+`train_data/data.jsonl` is empty, and `test_data/data.jsonl` holds one JSON object per line:
+
+```
+{"question": "If Xiao Ming has 5 apples, and he gives 3 to Xiao Hua, how many apples does Xiao Ming have left?\nA. 2\nB. 3\nC. 4\nD. 5", "answer": "A"}
+{"question": "Which of the following numbers is the smallest prime number?\nA. 0\nB. 1\nC. 2\nD. 4", "answer": "C"}
+{"question": "A rectangle has a length of 10 centimeters and a width of 5 centimeters, what is its perimeter in centimeters?\nA. 20 centimeters\nB. 30 centimeters\nC. 40 centimeters\nD. 50 centimeters", "answer": "B"}
+{"question": "Which of the following fractions is closest to 1?\nA. 1/2\nB. 3/4\nC. 4/5\nD. 5/6", "answer": "D"}
+{"question": "If a number plus 10 equals 30, what is the number?\nA. 20\nB. 21\nC. 22\nD. 23", "answer": "A"}
+{"question": "Which of the following expressions has the largest result?\nA. 3 + 4\nB. 5 - 2\nC. 6 * 2\nD. 7 ÷ 2", "answer": "C"}
+{"question": "A class has 24 students, and if each student brings 2 books, how many books are there in total?\nA. 48\nB. 36\nC. 24\nD. 12", "answer": "A"}
+{"question": "Which of the following is the correct multiplication rhyme?\nA. Three threes are seven\nB. Four fours are sixteen\nC. Five fives are twenty-five\nD. Six sixes are thirty-six", "answer": "B"}
+{"question": "If one number is three times another number, and this number is 15, what is the other number?\nA. 5\nB. 10\nC. 15\nD. 45", "answer": "A"}
+{"question": "Which of the following shapes has the longest perimeter?\nA. Square\nB. Rectangle\nC. Circle\nD. Triangle", "answer": "C"}
+```
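+
+To sanity-check the dataset before running the benchmark, a minimal sketch like the one below can be used; it assumes only the two fields shown above, and the file path is an example to adjust to your setup:
+
+```
+import json
+
+def load_qa(path):
+    """Read one QA record per line from a jsonl file."""
+    with open(path, "r", encoding="utf-8") as f:
+        return [json.loads(line) for line in f if line.strip()]
+
+samples = load_qa("test_data/data.jsonl")
+print(f"{len(samples)} questions loaded")
+print(samples[0]["question"], "->", samples[0]["answer"])
+```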
Triangle", + "answer": "C" +} +``` + +### Prepare Environment + +You need to install the changed-sedna package, which added `JsonlDataParse` in `sedna.datasources` + +Replace the file in `yourpath/anaconda3/envs/ianvs/lib/python3.x/site-packages/sedna` with `examples/resources/sedna-with-jsonl.zip` + + +### Run Ianvs + +Run the following command: + +`ianvs -f examples/llm/singletask_learning_bench/simple_qa/benchmarkingjob.yaml` + +## OpenCompass Evaluation + +### Prepare Environment + +`pip install examples/resources/opencompass-0.2.5-py3-none-any.whl` + +### Run Evaluation + +`python run_op.py examples/llm/singletask_learning_bench/simple_qa/testalgorithms/gen/op_eval.py` + diff --git a/examples/llm_simple_qa/benchmarkingjob.yaml b/examples/llm_simple_qa/benchmarkingjob.yaml new file mode 100644 index 00000000..78961e52 --- /dev/null +++ b/examples/llm_simple_qa/benchmarkingjob.yaml @@ -0,0 +1,72 @@ +benchmarkingjob: + # job name of bechmarking; string type; + name: "benchmarkingjob" + # the url address of job workspace that will reserve the output of tests; string type; + workspace: "/home/icyfeather/project/ianvs/workspace" + + # the url address of test environment configuration file; string type; + # the file format supports yaml/yml; + testenv: "./examples/llm/singletask_learning_bench/simple_qa/testenv/testenv.yaml" + + # the configuration of test object + test_object: + # test type; string type; + # currently the option of value is "algorithms",the others will be added in succession. + type: "algorithms" + # test algorithm configuration files; list type; + algorithms: + # algorithm name; string type; + - name: "simple_qa_singletask_learning" + # the url address of test algorithm configuration file; string type; + # the file format supports yaml/yml; + url: "./examples/llm/singletask_learning_bench/simple_qa/testalgorithms/gen/gen_algorithm.yaml" + + # the configuration of ranking leaderboard + rank: + # rank leaderboard with metric of test case's evaluation and order ; list type; + # the sorting priority is based on the sequence of metrics in the list from front to back; + sort_by: [ { "acc": "descend" } ] + + # visualization configuration + visualization: + # mode of visualization in the leaderboard; string type; + # There are quite a few possible dataitems in the leaderboard. Not all of them can be shown simultaneously on the screen. + # In the leaderboard, we provide the "selected_only" mode for the user to configure what is shown or is not shown. + mode: "selected_only" + # method of visualization for selected dataitems; string type; + # currently the options of value are as follows: + # 1> "print_table": print selected dataitems; + method: "print_table" + + # selected dataitem configuration + # The user can add his/her interested dataitems in terms of "paradigms", "modules", "hyperparameters" and "metrics", + # so that the selected columns will be shown. 
+    selected_dataitem:
+      # currently the options of value are as follows:
+      #   1> "all": select all paradigms in the leaderboard;
+      #   2> paradigms in the leaderboard, e.g., "singletasklearning"
+      paradigms: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all modules in the leaderboard;
+      #   2> modules in the leaderboard, e.g., "basemodel"
+      modules: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all hyperparameters in the leaderboard;
+      #   2> hyperparameters in the leaderboard, e.g., "momentum"
+      hyperparameters: [ "all" ]
+      # currently the options of value are as follows:
+      #   1> "all": select all metrics in the leaderboard;
+      #   2> metrics in the leaderboard, e.g., "f1_score"
+      metrics: [ "acc" ]
+
+    # mode of saving selected and all dataitems in workspace; string type;
+    # currently the options of value are as follows:
+    #   1> "selected_and_all": save selected and all dataitems;
+    #   2> "selected_only": save selected dataitems;
+    save_mode: "selected_and_all"
+
diff --git a/examples/llm_simple_qa/testalgorithms/gen/basemodel.py b/examples/llm_simple_qa/testalgorithms/gen/basemodel.py
new file mode 100644
index 00000000..fdeedc98
--- /dev/null
+++ b/examples/llm_simple_qa/testalgorithms/gen/basemodel.py
@@ -0,0 +1,98 @@
+# Copyright 2022 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
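+
+# A simple "gen" base model: it wraps a locally downloaded Qwen2-0.5B-Instruct
+# chat model via transformers and only implements predict(); train/save/load
+# are no-ops for this single-task QA benchmark.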
+
+from __future__ import absolute_import, division, print_function
+
+import logging
+import os
+
+from sedna.common.class_factory import ClassType, ClassFactory
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+device = "cuda"  # the device to load the model onto
+
+logging.disable(logging.WARNING)
+
+__all__ = ["BaseModel"]
+
+os.environ['BACKEND_TYPE'] = 'TORCH'
+
+
+@ClassFactory.register(ClassType.GENERAL, alias="gen")
+class BaseModel:
+
+    def __init__(self, **kwargs):
+        # Update this path to your local Qwen2-0.5B-Instruct checkpoint.
+        self.model = AutoModelForCausalLM.from_pretrained(
+            "/home/icyfeather/models/Qwen2-0.5B-Instruct",
+            torch_dtype="auto",
+            device_map="auto"
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained("/home/icyfeather/models/Qwen2-0.5B-Instruct")
+
+    def train(self, train_data, valid_data=None, **kwargs):
+        print("BaseModel doesn't need to train")
+
+    def save(self, model_path):
+        print("BaseModel doesn't need to save")
+
+    def predict(self, data, input_shape=None, **kwargs):
+        print("BaseModel predict")
+        answer_list = []
+        for line in data:
+            response = self._infer(line)
+            answer_list.append(response)
+        return answer_list
+
+    def load(self, model_url=None):
+        print("BaseModel load")
+
+    def evaluate(self, data, model_path, **kwargs):
+        print("BaseModel evaluate")
+
+    def _infer(self, prompt, system=None):
+        if system:
+            messages = [
+                {"role": "system", "content": system},
+                {"role": "user", "content": prompt}
+            ]
+        else:
+            messages = [
+                {"role": "user", "content": prompt}
+            ]
+        text = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        model_inputs = self.tokenizer([text], return_tensors="pt").to(device)
+
+        generated_ids = self.model.generate(
+            model_inputs.input_ids,
+            max_new_tokens=512
+        )
+        # Strip the prompt tokens so only the newly generated answer remains.
+        generated_ids = [
+            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+        ]
+
+        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        return response
diff --git a/examples/llm_simple_qa/testalgorithms/gen/gen_algorithm.yaml b/examples/llm_simple_qa/testalgorithms/gen/gen_algorithm.yaml
new file mode 100644
index 00000000..6536ceb9
--- /dev/null
+++ b/examples/llm_simple_qa/testalgorithms/gen/gen_algorithm.yaml
@@ -0,0 +1,18 @@
+algorithm:
+  # paradigm name; string type;
+  # currently the options of value are as follows:
+  #   1> "singletasklearning"
+  #   2> "incrementallearning"
+  paradigm_type: "singletasklearning"
+
+  # algorithm module configuration in the paradigm; list type;
+  modules:
+    # kind of algorithm module; string type;
+    # currently the options of value are as follows:
+    #   1> "basemodel"
+    - type: "basemodel"
+      # name of python module; string type;
+      # example: basemodel.py has a BaseModel module whose alias is "gen" for this benchmarking;
+      name: "gen"
+      # the url address of python module; string type;
+      url: "./examples/llm_simple_qa/testalgorithms/gen/basemodel.py"
\ No newline at end of file
diff --git a/examples/llm_simple_qa/testalgorithms/gen/op_eval.py b/examples/llm_simple_qa/testalgorithms/gen/op_eval.py
new file mode 100644
index 00000000..dc6d9c04
--- /dev/null
+++ b/examples/llm_simple_qa/testalgorithms/gen/op_eval.py
@@ -0,0 +1,21 @@
+from mmengine.config import read_base
+from opencompass.models import HuggingFacewithChatTemplate
+# import sys
+# sys.path.append('/home/icyfeather/project/ianvs')
+
+with read_base():
+    from core.op_extra.datasets.cmmlu.cmmlu_gen import cmmlu_datasets
+
+datasets = [*cmmlu_datasets]
+
+models = [
+    dict(
+        type=HuggingFacewithChatTemplate,
+        abbr='qwen1.5-1.8b-chat-hf',
+        path='/home/icyfeather/models/Qwen1.5-1.8B-Chat',
+        max_out_len=1024,
+        batch_size=2,
+        run_cfg=dict(num_gpus=1),
+        stop_words=['<|im_end|>', '<|im_start|>'],
+    )
+]
diff --git a/examples/llm_simple_qa/testenv/acc.py b/examples/llm_simple_qa/testenv/acc.py
new file mode 100644
index 00000000..beccdadf
--- /dev/null
+++ b/examples/llm_simple_qa/testenv/acc.py
@@ -0,0 +1,40 @@
+# Copyright 2022 The KubeEdge Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from sedna.common.class_factory import ClassType, ClassFactory
+
+__all__ = ["acc"]
+
+def get_last_letter(input_string):
+    # Return the last option letter (A-D) found in the string, or None if the
+    # string is empty or holds no such letter.
+    if not input_string or not any(char.isalpha() for char in input_string):
+        return None
+
+    for char in reversed(input_string):
+        if 'A' <= char <= 'D':
+            return char
+
+    return None
+
+
+@ClassFactory.register(ClassType.GENERAL, alias="acc")
+def acc(y_true, y_pred):
+    y_pred = [get_last_letter(pred) for pred in y_pred]
+    if not y_pred:
+        return 0.0
+
+    same_elements = [y_pred[i] == y_true[i] for i in range(len(y_pred))]
+
+    return sum(same_elements) / len(same_elements)
diff --git a/examples/llm_simple_qa/testenv/testenv.yaml b/examples/llm_simple_qa/testenv/testenv.yaml
new file mode 100644
index 00000000..0bc7239f
--- /dev/null
+++ b/examples/llm_simple_qa/testenv/testenv.yaml
@@ -0,0 +1,14 @@
+testenv:
+  # dataset configuration
+  dataset:
+    # the url address of train dataset index; string type;
+    train_data: "/home/icyfeather/Projects/ianvs/dataset/llm_simple_qa/train_data/data.jsonl"
+    # the url address of test dataset index; string type;
+    test_data: "/home/icyfeather/Projects/ianvs/dataset/llm_simple_qa/test_data/data.jsonl"
+
+  # metrics configuration for test case's evaluation; list type;
+  metrics:
+    # metric name; string type;
+    - name: "acc"
+      # the url address of python file
+      url: "./examples/llm_simple_qa/testenv/acc.py"
diff --git a/examples/resources/sedna-llm.zip b/examples/resources/sedna-llm.zip
new file mode 100644
index 00000000..8ea3c0d3
Binary files /dev/null and b/examples/resources/sedna-llm.zip differ