Skip to content

Commit

Permalink
add: add a simple demo of feadereated learning in ianvs
Browse files Browse the repository at this point in the history
Signed-off-by: Marchons <d12863606746@outlook.com>
Signed-off-by: Marchons <1286360646@qq.com>
  • Loading branch information
Yoda-wu committed Oct 28, 2024
1 parent bef4daf commit b9199de
Show file tree
Hide file tree
Showing 82 changed files with 7,147 additions and 145 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,11 @@ jobs:
python -m pip install ${{github.workspace}}/examples/resources/third_party/*
python -m pip install -r ${{github.workspace}}/requirements.txt
- name: Analysing code of core with pylint
# `--max-positional-arguments=10` is set for Python 3.9 to avoid `R0917: too-many-positional-arguments`.
# See details at https://github.com/kubeedge/ianvs/issues/157
run: |
pylint '${{github.workspace}}/core'
if [ "${{ matrix.python-version }}" = "3.9" ]; then
pylint --max-positional-arguments=10 '${{github.workspace}}/core'
else
pylint '${{github.workspace}}/core'
fi
11 changes: 11 additions & 0 deletions core/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class DatasetFormat(Enum):
File format of inputting dataset.
Currently, file formats are as follows: txt, csv.
"""

CSV = "csv"
TXT = "txt"
JSON = "json"
Expand All @@ -31,16 +32,20 @@ class ParadigmType(Enum):
"""
Algorithm paradigm type.
"""

SINGLE_TASK_LEARNING = "singletasklearning"
INCREMENTAL_LEARNING = "incrementallearning"
MULTIEDGE_INFERENCE = "multiedgeinference"
LIFELONG_LEARNING = "lifelonglearning"
FEDERATED_LEARNING = "federatedlearning"
FEDERATED_CLASS_INCREMENTAL_LEARNING = "federatedclassincrementallearning"


class ModuleType(Enum):
"""
Algorithm module type.
"""

BASEMODEL = "basemodel"

# HEM
Expand All @@ -63,20 +68,26 @@ class ModuleType(Enum):
UNSEEN_SAMPLE_RECOGNITION = "unseen_sample_recognition"
UNSEEN_SAMPLE_RE_RECOGNITION = "unseen_sample_re_recognition"

# FL_AGG
AGGREGATION = "aggregation"


class SystemMetricType(Enum):
"""
System metric type of ianvs.
"""

SAMPLES_TRANSFER_RATIO = "samples_transfer_ratio"
FWT = "FWT"
BWT = "BWT"
TASK_AVG_ACC = "task_avg_acc"
MATRIX = "MATRIX"
FORGET_RATE = "forget_rate"


class TestObjectType(Enum):
"""
Test object type of ianvs.
"""

ALGORITHMS = "algorithms"
50 changes: 33 additions & 17 deletions core/storymanager/rank/rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,12 @@ class Rank:

def __init__(self, config):
self.sort_by: list = []
self.visualization: dict = {
"mode": "selected_only",
"method": "print_table"
}
self.visualization: dict = {"mode": "selected_only", "method": "print_table"}
self.selected_dataitem: dict = {
"paradigms": ["all"],
"modules": ["all"],
"hyperparameters": ["all"],
"metrics": ["all"]
"metrics": ["all"],
}
self.save_mode: str = "selected_and_all"

Expand All @@ -62,15 +59,21 @@ def _parse_config(self, config):

def _check_fields(self):
if not self.sort_by and not isinstance(self.sort_by, list):
raise ValueError(f"rank's sort_by({self.sort_by}) must be provided and be list type.")
raise ValueError(
f"rank's sort_by({self.sort_by}) must be provided and be list type."
)

if not self.visualization and not isinstance(self.visualization, dict):
raise ValueError(f"rank's visualization({self.visualization}) "
f"must be provided and be dict type.")
raise ValueError(
f"rank's visualization({self.visualization}) "
f"must be provided and be dict type."
)

if not self.selected_dataitem and not isinstance(self.selected_dataitem, dict):
raise ValueError(f"rank's selected_dataitem({self.selected_dataitem}) "
f"must be provided and be dict type.")
raise ValueError(
f"rank's selected_dataitem({self.selected_dataitem}) "
f"must be provided and be dict type."
)

if not self.selected_dataitem.get("paradigms"):
raise ValueError("not found paradigms of selected_dataitem in rank.")
Expand All @@ -82,8 +85,10 @@ def _check_fields(self):
raise ValueError("not found metrics of selected_dataitem in rank.")

if not self.save_mode and not isinstance(self.save_mode, list):
raise ValueError(f"rank's save_mode({self.save_mode}) "
f"must be provided and be list type.")
raise ValueError(
f"rank's save_mode({self.save_mode}) "
f"must be provided and be list type."
)

@classmethod
def _get_all_metric_names(cls, test_results) -> list:
Expand Down Expand Up @@ -133,7 +138,7 @@ def _sort_all_df(self, all_df, all_metric_names):

if metric_name not in all_metric_names:
continue

print(metric_name)
sort_metric_list.append(metric_name)
is_ascend_list.append(ele.get(metric_name) == "ascend")

Expand Down Expand Up @@ -198,7 +203,15 @@ def _get_selected(self, test_cases, test_results) -> pd.DataFrame:
if metric_names == ["all"]:
metric_names = self._get_all_metric_names(test_results)

header = ["algorithm", *metric_names, "paradigm", *module_types, *hps_names, "time", "url"]
header = [
"algorithm",
*metric_names,
"paradigm",
*module_types,
*hps_names,
"time",
"url",
]

all_df = copy.deepcopy(self.all_df)
selected_df = pd.DataFrame(all_df, columns=header)
Expand All @@ -220,14 +233,16 @@ def _draw_pictures(self, test_cases, test_results):
for test_case in test_cases:
out_put = test_case.output_dir
test_result = test_results[test_case.id][0]
matrix = test_result.get('Matrix')
#print(out_put)
matrix = test_result.get("Matrix")
# print(out_put)
for key in matrix.keys():
draw_heatmap_picture(out_put, key, matrix[key])

def _prepare(self, test_cases, test_results, output_dir):
all_metric_names = self._get_all_metric_names(test_results)
print(f"in_prepare all_metric_names: {all_metric_names}")
all_hps_names = self._get_all_hps_names(test_cases)
print(f"in_prepare all_hps_names: {all_hps_names}")
all_module_types = self._get_all_module_types(test_cases)
self.all_df_header = [
"algorithm", *all_metric_names,
Expand Down Expand Up @@ -285,4 +300,5 @@ def plot(self):
except Exception as err:
raise RuntimeError(
f"process visualization(method={method}) of "
f"rank file({self.selected_rank_file}) failed, error: {err}.") from err
f"rank file({self.selected_rank_file}) failed, error: {err}."
) from err
18 changes: 17 additions & 1 deletion core/testcasecontroller/algorithm/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
IncrementalLearning,
MultiedgeInference,
LifelongLearning,
FederatedLearning,
FederatedClassIncrementalLearning
)
from core.testcasecontroller.generation_assistant import get_full_combinations

Expand Down Expand Up @@ -64,12 +66,21 @@ def __init__(self, name, config):
"train_ratio": 0.8,
"splitting_method": "default"
}
self.fl_data_setting: dict = {
"train_ratio": 1.0,
"splitting_method": "default",
"data_partition": "iid",
'non_iid_ratio': 0.6,
"label_data_ratio": 1.0
}

self.initial_model_url: str = ""
self.modules: list = []
self.modules_list = None
self._parse_config(config)
self._load_third_party_packages()

# pylint: disable=R0911
def paradigm(self, workspace: str, **kwargs):
"""
get test process of AI algorithm paradigm.
Expand All @@ -91,7 +102,6 @@ def paradigm(self, workspace: str, **kwargs):
# pylint: disable=C0103
for k, v in self.__dict__.items():
config.update({k: v})

if self.paradigm_type == ParadigmType.SINGLE_TASK_LEARNING.value:
return SingleTaskLearning(workspace, **config)

Expand All @@ -104,6 +114,12 @@ def paradigm(self, workspace: str, **kwargs):
if self.paradigm_type == ParadigmType.LIFELONG_LEARNING.value:
return LifelongLearning(workspace, **config)

if self.paradigm_type == ParadigmType.FEDERATED_LEARNING.value:
return FederatedLearning(workspace, **config)

if self.paradigm_type == ParadigmType.FEDERATED_CLASS_INCREMENTAL_LEARNING.value:
return FederatedClassIncrementalLearning(workspace, **config)

return None

def _check_fields(self):
Expand Down
16 changes: 15 additions & 1 deletion core/testcasecontroller/algorithm/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def get_module_instance(self, module_type):
function
"""
print(f'hyperparameters_list: {self.hyperparameters_list}')
class_factory_type = ClassType.GENERAL
if module_type in [ModuleType.HARD_EXAMPLE_MINING.value]:
class_factory_type = ClassType.HEM
Expand All @@ -106,14 +107,27 @@ def get_module_instance(self, module_type):
elif module_type in [ModuleType.UNSEEN_SAMPLE_RECOGNITION.value,
ModuleType.UNSEEN_SAMPLE_RE_RECOGNITION.value]:
class_factory_type = ClassType.UTD
elif module_type in [ModuleType.AGGREGATION.value]:
class_factory_type = ClassType.FL_AGG
agg = None
print(self.url)
if self.url :
try:
utils.load_module(self.url)
agg = ClassFactory.get_cls(
type_name=class_factory_type, t_cls_name=self.name)(**self.hyperparameters)
print(agg)
except Exception as err:
raise RuntimeError(f"module(type={module_type} loads class(name={self.name}) "
f"failed, error: {err}.") from err
return self.name, agg

if self.url:
try:
utils.load_module(self.url)
# pylint: disable=E1134
func = ClassFactory.get_cls(
type_name=class_factory_type, t_cls_name=self.name)(**self.hyperparameters)

return func
except Exception as err:
raise RuntimeError(f"module(type={module_type} loads class(name={self.name}) "
Expand Down
1 change: 1 addition & 0 deletions core/testcasecontroller/algorithm/paradigm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
from .singletask_learning import SingleTaskLearning
from .multiedge_inference import MultiedgeInference
from .lifelong_learning import LifelongLearning
from .federated_learning import FederatedLearning, FederatedClassIncrementalLearning
44 changes: 31 additions & 13 deletions core/testcasecontroller/algorithm/paradigm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

from sedna.core.incremental_learning import IncrementalLearning
from sedna.core.lifelong_learning import LifelongLearning

from core.common.constant import ModuleType, ParadigmType
from .sedna_federated_learning import FederatedLearning


class ParadigmBase:
Expand Down Expand Up @@ -97,33 +97,51 @@ def build_paradigm_job(self, paradigm_type):
return IncrementalLearning(
estimator=self.module_instances.get(ModuleType.BASEMODEL.value),
hard_example_mining=self.module_instances.get(
ModuleType.HARD_EXAMPLE_MINING.value))
ModuleType.HARD_EXAMPLE_MINING.value
),
)

if paradigm_type == ParadigmType.LIFELONG_LEARNING.value:
return LifelongLearning(
estimator=self.module_instances.get(
ModuleType.BASEMODEL.value),
estimator=self.module_instances.get(ModuleType.BASEMODEL.value),
task_definition=self.module_instances.get(
ModuleType.TASK_DEFINITION.value),
ModuleType.TASK_DEFINITION.value
),
task_relationship_discovery=self.module_instances.get(
ModuleType.TASK_RELATIONSHIP_DISCOVERY.value),
ModuleType.TASK_RELATIONSHIP_DISCOVERY.value
),
task_allocation=self.module_instances.get(
ModuleType.TASK_ALLOCATION.value),
ModuleType.TASK_ALLOCATION.value
),
task_remodeling=self.module_instances.get(
ModuleType.TASK_REMODELING.value),
ModuleType.TASK_REMODELING.value
),
inference_integrate=self.module_instances.get(
ModuleType.INFERENCE_INTEGRATE.value),
ModuleType.INFERENCE_INTEGRATE.value
),
task_update_decision=self.module_instances.get(
ModuleType.TASK_UPDATE_DECISION.value),
ModuleType.TASK_UPDATE_DECISION.value
),
unseen_task_allocation=self.module_instances.get(
ModuleType.UNSEEN_TASK_ALLOCATION.value),
ModuleType.UNSEEN_TASK_ALLOCATION.value
),
unseen_sample_recognition=self.module_instances.get(
ModuleType.UNSEEN_SAMPLE_RECOGNITION.value),
ModuleType.UNSEEN_SAMPLE_RECOGNITION.value
),
unseen_sample_re_recognition=self.module_instances.get(
ModuleType.UNSEEN_SAMPLE_RE_RECOGNITION.value)
ModuleType.UNSEEN_SAMPLE_RE_RECOGNITION.value
),
)
# pylint: disable=E1101
if paradigm_type == ParadigmType.MULTIEDGE_INFERENCE.value:
return self.modules_funcs.get(ModuleType.BASEMODEL.value)()

if paradigm_type in [
ParadigmType.FEDERATED_LEARNING.value,
ParadigmType.FEDERATED_CLASS_INCREMENTAL_LEARNING.value,
]:
return FederatedLearning(
estimator=self.module_instances.get(ModuleType.BASEMODEL.value)
)

return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2022 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=missing-module-docstring
from .federated_learning import FederatedLearning
from .federated_class_incremental_learning import FederatedClassIncrementalLearning
Loading

0 comments on commit b9199de

Please sign in to comment.