-
Notifications
You must be signed in to change notification settings - Fork 0
/
aggregate_inplace.py
31 lines (25 loc) · 1.06 KB
/
aggregate_inplace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from functools import reduce
from typing import Any, Callable, List, Tuple
import numpy as np
from flwr.common import FitRes, NDArray, NDArrays, parameters_to_ndarrays
from flwr.server.client_proxy import ClientProxy
def aggregate_inplace(results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays:
    """Compute the example-weighted average of the clients' parameters.

    Each client's layers are scaled by ``num_examples / total_examples`` and
    folded into a running per-layer sum, one client at a time, instead of
    collecting every client's arrays before reducing.

    Despite the name, the arrays decoded from ``results`` are not modified:
    the scaling and ``np.add`` steps each produce new arrays.

    Parameters
    ----------
    results : List[Tuple[ClientProxy, FitRes]]
        One ``(client, fit_res)`` pair per client. Only
        ``fit_res.num_examples`` and ``fit_res.parameters`` are read; the
        client proxy itself is ignored.

    Returns
    -------
    NDArrays
        One array per layer holding the weighted average across clients.

    Raises
    ------
    IndexError
        If ``results`` is empty.
    ZeroDivisionError
        If the clients report zero examples in total.
    ValueError
        If a client returns a different number of layers than the first
        client.
    """
    if not results:
        # Same exception type the bare ``results[0]`` lookup would raise,
        # but with a message that names the actual problem.
        raise IndexError("aggregate_inplace() requires at least one result")

    # Count total examples
    num_examples_total = sum(fit_res.num_examples for _, fit_res in results)
    if num_examples_total == 0:
        # Same exception type the division below would raise.
        raise ZeroDivisionError(
            "cannot compute a weighted average: results report zero examples in total"
        )

    # Compute scaling factors for each result
    scaling_factors = [
        fit_res.num_examples / num_examples_total for _, fit_res in results
    ]

    # Seed the running sum with the first client's scaled layers.
    params = [
        scaling_factors[0] * layer
        for layer in parameters_to_ndarrays(results[0][1].parameters)
    ]

    # Fold in every remaining client, pairing each with its own factor.
    for factor, (_, fit_res) in zip(scaling_factors[1:], results[1:]):
        layers = parameters_to_ndarrays(fit_res.parameters)
        if len(layers) != len(params):
            # zip() below would otherwise silently drop the unmatched layers
            # and return a truncated model.
            raise ValueError(
                f"layer count mismatch: expected {len(params)} layers, "
                f"got {len(layers)}"
            )
        params = [
            np.add(running, factor * layer)
            for running, layer in zip(params, layers)
        ]

    return params