-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add setup.py, workflows, unit tests, and examples (#1)
* adding setup, requirements, and initial openfl_contrib dir Signed-off-by: kta-intel <kevin.ta@intel.com> * deleting unnecessary files under openfl_contrib, update setup.py, add some basic tests Signed-off-by: kta-intel <kevin.ta@intel.com> * pulling in aggregation_function/core from base openfl library Signed-off-by: kta-intel <kevin.ta@intel.com> * inheriting objects directly from openfl Signed-off-by: kta-intel <kevin.ta@intel.com> * update eden_pipeline.py Signed-off-by: kta-intel <kevin.ta@intel.com> * add contrib workspace Signed-off-by: kta-intel <kevin.ta@intel.com> * add skc_compression workspace, add tests, add workflows Signed-off-by: kta-intel <kevin.ta@intel.com> * update initialize_and_certify_workspace function name Signed-off-by: kta-intel <kevin.ta@intel.com> * remove default workspace directory Signed-off-by: kta-intel <kevin.ta@intel.com> * removing all aggregation and pipeline functions that are not used as example Signed-off-by: kta-intel <kevin.ta@intel.com> * add custom weighted averaging and remove median, add associated tests and workflows Signed-off-by: kta-intel <kevin.ta@intel.com> * add unit test for skc compression, minr fix to custom weighted average test Signed-off-by: kta-intel <kevin.ta@intel.com> * fix version and update python req to match openfl base Signed-off-by: kta-intel <kevin.ta@intel.com> * update copyright Signed-off-by: kta-intel <kevin.ta@intel.com> * additional fixes Signed-off-by: kta-intel <kevin.ta@intel.com> * fix import Signed-off-by: kta-intel <kevin.ta@intel.com> --------- Signed-off-by: kta-intel <kevin.ta@intel.com>
- Loading branch information
Showing
44 changed files
with
1,477 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# This workflow will install Python dependencies, run tests and lint with a single version of Python | ||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions | ||
|
||
name: Lint with Flake8 | ||
|
||
on: | ||
pull_request: | ||
branches: [ develop ] | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
build: | ||
|
||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python 3.8 | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: "3.8" | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install -r requirements-linters.txt | ||
pip install . | ||
- name: Lint with flake8 | ||
run: | | ||
flake8 --show-source |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# This workflow will install Python dependencies, run tests and lint with a single version of Python | ||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions | ||
|
||
name: Pytest and code coverage | ||
|
||
on: | ||
pull_request: | ||
branches: [ develop ] | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
build: | ||
|
||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python 3.8 | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: "3.8" | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install pytest coverage | ||
pip install -r requirements-test.txt | ||
pip install . | ||
- name: Test with pytest and report code coverage | ||
run: | | ||
coverage run -m pytest -rA | ||
coverage report |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# This workflow will install Python dependencies, run tests and lint with a single version of Python | ||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions | ||
|
||
name: Task Runner with Median Aggregation | ||
|
||
on: | ||
pull_request: | ||
branches: [ develop ] | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
build: | ||
strategy: | ||
matrix: | ||
os: ['ubuntu-latest', 'windows-latest'] | ||
python-version: ['3.8','3.9','3.10','3.11'] | ||
runs-on: ${{ matrix.os }} | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install dependencies ubuntu | ||
if: matrix.os == 'ubuntu-latest' | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install . | ||
- name: Install dependencies windows | ||
if: matrix.os == 'windows-latest' | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install . | ||
- name: Test Task Runner API | ||
run: | | ||
python -m tests.github.test_task_runner --workspace torch_cnn_mnist_custom_weighted_average --col1 col1 --col2 col2 --rounds-to-train 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# This workflow will install Python dependencies, run tests and lint with a single version of Python | ||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions | ||
|
||
name: Task Runner with SKC Compression | ||
|
||
on: | ||
pull_request: | ||
branches: [ develop ] | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
build: | ||
strategy: | ||
matrix: | ||
os: ['ubuntu-latest', 'windows-latest'] | ||
python-version: ['3.8','3.9','3.10','3.11'] | ||
runs-on: ${{ matrix.os }} | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Install dependencies ubuntu | ||
if: matrix.os == 'ubuntu-latest' | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install . | ||
- name: Install dependencies windows | ||
if: matrix.os == 'windows-latest' | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install . | ||
- name: Test Task Runner API | ||
run: | | ||
python -m tests.github.test_task_runner --workspace torch_cnn_mnist_skc_compression --col1 col1 --col2 col2 --rounds-to-train 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
name: Ubuntu (latest) | ||
|
||
on: | ||
schedule: | ||
- cron: '0 0 * * *' | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
lint: # from lint.yml | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python 3.8 | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: "3.8" | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install -r requirements-linters.txt | ||
pip install . | ||
- name: Lint with flake8 | ||
run: | | ||
flake8 --show-source | ||
pytest-coverage: # from pytest_coverage.yml | ||
needs: lint | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python 3.8 | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: "3.8" | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install pytest coverage | ||
pip install -r requirements-test.txt | ||
pip install . | ||
- name: Test with pytest and report code coverage | ||
run: | | ||
coverage run -m pytest -rA | ||
coverage report |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
name: Windows (latest) | ||
|
||
on: | ||
schedule: | ||
- cron: '0 0 * * *' | ||
|
||
permissions: | ||
contents: read | ||
|
||
jobs: | ||
pytest-coverage: # from pytest_coverage.yml | ||
runs-on: windows-latest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python 3.8 | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: "3.8" | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install pytest coverage | ||
pip install -r requirements-test.txt | ||
pip install . | ||
- name: Test with pytest and report code coverage | ||
run: | | ||
coverage run -m pytest -rA | ||
coverage report |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
"""openfl base package.""" | ||
from openfl_contrib.__version__ import __version__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
"""openfl-contrib version information.""" | ||
__version__ = '0.1.0' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
"""openfl.interface package.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Aggregation functions package.""" | ||
|
||
from openfl_contrib.interface.aggregation_functions.custom_weighted_average import CustomWeightedAverage |
45 changes: 45 additions & 0 deletions
45
openfl_contrib/interface/aggregation_functions/custom_weighted_average.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Custom Federated averaging module.""" | ||
|
||
import numpy as np | ||
|
||
from openfl.interface.aggregation_functions.core import AggregationFunction | ||
|
||
|
||
class CustomWeightedAverage(AggregationFunction): | ||
"""Weighted average aggregation.""" | ||
|
||
def call(self, local_tensors, *_) -> np.ndarray: | ||
"""Aggregate tensors. | ||
Args: | ||
local_tensors(list[openfl.utilities.LocalTensor]): List of local tensors to aggregate. | ||
db_iterator: iterator over history of all tensors. Columns: | ||
- 'tensor_name': name of the tensor. | ||
Examples for `torch.nn.Module`s: 'conv1.weight', 'fc2.bias'. | ||
- 'round': 0-based number of round corresponding to this tensor. | ||
- 'tags': tuple of tensor tags. Tags that can appear: | ||
- 'model' indicates that the tensor is a model parameter. | ||
- 'trained' indicates that tensor is a part of a training result. | ||
These tensors are passed to the aggregator node after local learning. | ||
- 'aggregated' indicates that tensor is a result of aggregation. | ||
These tensors are sent to collaborators for the next round. | ||
- 'delta' indicates that value is a difference between rounds | ||
for a specific tensor. | ||
also one of the tags is a collaborator name | ||
if it corresponds to a result of a local task. | ||
- 'nparray': value of the tensor. | ||
tensor_name: name of the tensor | ||
fl_round: round number | ||
tags: tuple of tags for this tensor | ||
Returns: | ||
np.ndarray: aggregated tensor | ||
""" | ||
tensors, weights = zip(*[(x.tensor, x.weight) for x in local_tensors]) | ||
|
||
total_weight = sum(weights) | ||
weighted_sum = np.sum([tensor * weight for tensor, weight in zip(tensors, weights)], axis=0) | ||
return weighted_sum / total_weight |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
"""openfl.pipelines module.""" | ||
|
||
from openfl_contrib.pipelines.skc_pipeline import SKCPipeline |
Oops, something went wrong.