Reverse automatic formatting applied by Copybara
PiperOrigin-RevId: 368210967
kho authored and fedjax authors committed Apr 13, 2021
1 parent 32dbdd9 commit 789f81c
Showing 24 changed files with 140 additions and 153 deletions.
3 changes: 2 additions & 1 deletion examples/emnist_simple_fed_avg.py
@@ -49,7 +49,8 @@
flags.DEFINE_string(
'cache_dir', None,
'Cache directory. If specified, files will be downloaded to disk. If '
'unspecified, files are read directly over network.')
'unspecified, files are read directly over network.'
)
flags.DEFINE_bool('only_digits', False, 'Whether to use only digits or not.')
flags.DEFINE_integer('seed', 0, 'Random seed.')
flags.DEFINE_integer('client_batch_size', 20, 'Client local batch size.')
7 changes: 5 additions & 2 deletions fedjax/algorithms/agnostic_fed_avg.py
@@ -106,9 +106,12 @@ def _update_domain_weights(domain_weights: jnp.ndarray,
class AgnosticFedAvg(core.FederatedAlgorithm):
"""Agnostic federated averaging algorithm."""

def __init__(self, federated_data: core.FederatedData, model: core.Model,
def __init__(self,
federated_data: core.FederatedData,
model: core.Model,
client_optimizer: core.Optimizer,
server_optimizer: core.Optimizer, hparams: AgnosticFedAvgHParams,
server_optimizer: core.Optimizer,
hparams: AgnosticFedAvgHParams,
rng_seq: core.PRNGSequence):
"""Initializes HypCluster algorithm.
7 changes: 5 additions & 2 deletions fedjax/algorithms/mime.py
@@ -30,9 +30,12 @@
class Mime(core.FederatedAlgorithm):
"""Mime algorithm."""

def __init__(self, federated_data: core.FederatedData, model: core.Model,
def __init__(self,
federated_data: core.FederatedData,
model: core.Model,
base_optimizer: core.Optimizer,
hparams: mime_lite.MimeLiteHParams, rng_seq: core.PRNGSequence):
hparams: mime_lite.MimeLiteHParams,
rng_seq: core.PRNGSequence):
"""Initializes MimeLite algorithm.
Args:
7 changes: 5 additions & 2 deletions fedjax/algorithms/mime_lite.py
@@ -93,8 +93,11 @@ def compute_gradient(stream: Iterable[core.Batch], params: core.Params,
class MimeLite(core.FederatedAlgorithm):
"""Mime Lite algorithm."""

def __init__(self, federated_data: core.FederatedData, model: core.Model,
base_optimizer: core.Optimizer, hparams: MimeLiteHParams,
def __init__(self,
federated_data: core.FederatedData,
model: core.Model,
base_optimizer: core.Optimizer,
hparams: MimeLiteHParams,
rng_seq: core.PRNGSequence):
"""Initializes MimeLite algorithm.
25 changes: 13 additions & 12 deletions fedjax/core/client_trainer.py
@@ -36,20 +36,22 @@
import jax.numpy as jnp
import jax.random as jrandom


T = TypeVar('T')

FLAGS = flags.FLAGS

flags.DEFINE_enum(
'fedjax_experimental_disable_parallel', 'true', ['auto', 'true', 'false'],
'Set to `false` to parallelize `train_multiple_clients` on '
'multiple local devices via `jax.pmap`. '
'Defaults to `auto`, which will train in parallel only if '
'there is more than one local device available. '
'Training in parallel will automatically drop batches that '
'are not full batch size (the last batch).'
'Set to `true` to disable parallel and train sequentially, '
'meaning adding more devices does not help performance.')

flags.DEFINE_enum('fedjax_experimental_disable_parallel', 'true',
['auto', 'true', 'false'],
'Set to `false` to parallelize `train_multiple_clients` on '
'multiple local devices via `jax.pmap`. '
'Defaults to `auto`, which will train in parallel only if '
'there is more than one local device available. '
'Training in parallel will automatically drop batches that '
'are not full batch size (the last batch).'
'Set to `true` to disable parallel and train sequentially, '
'meaning adding more devices does not help performance.')


class ClientTrainer(Generic[T], metaclass=abc.ABCMeta):
@@ -328,7 +330,6 @@ def masked_func(mask, *args, **kwargs):

return masked_func


# Defaults are input trainer state and rng respectively.
_mask_step = _pmask(_one_step_f, default_argnums=(1, 3))
# static_broadcasted_argnums=1 points to the input ClientTrainer instance since
@@ -383,7 +384,7 @@ def _train_multiple_clients_parallel(
num_iterations = quotient + bool(remainder)
for i in range(num_iterations):
stack_state = init_stack_state
streams = client_data[num_local_devices * i:num_local_devices * (i + 1)]
streams = client_data[num_local_devices * i: num_local_devices * (i + 1)]
# Handle number of clients not divisible by num_devices.
if len(streams) < num_local_devices:
client_count = len(streams)
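The flag text in the `client_trainer.py` hunk above describes three modes. A minimal sketch of that decision logic, assuming a hypothetical helper name (`should_train_in_parallel`); this is not fedjax's actual dispatch code:

```python
import jax

def should_train_in_parallel(disable_parallel: str) -> bool:
  """Hypothetical helper mirroring the flag semantics described above."""
  if disable_parallel == 'true':    # train sequentially; extra devices do not help
    return False
  if disable_parallel == 'false':   # force jax.pmap-based parallel training
    return True
  # 'auto': parallelize only when more than one local device is available.
  return jax.local_device_count() > 1
```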
12 changes: 4 additions & 8 deletions fedjax/core/client_trainer_test.py
@@ -120,13 +120,11 @@ def test_train_single_client_iterator(self):
'testcase_name': 'sequential',
'disable_parallel': 'true',
'expected_num_examples': 20,
},
{
}, {
'testcase_name': 'parallel',
'disable_parallel': 'false',
'expected_num_examples': 18, # (20 // 3) * 3
},
{
}, {
'testcase_name': 'auto',
'disable_parallel': 'auto',
'expected_num_examples': 18, # (20 // 3) * 3
@@ -201,13 +199,11 @@ def test_loop(self):
'testcase_name': 'sequential',
'disable_parallel': 'true',
'expected_num_examples': 20,
},
{
}, {
'testcase_name': 'parallel',
'disable_parallel': 'false',
'expected_num_examples': 18, # (20 // 3) * 3
},
{
}, {
'testcase_name': 'auto',
'disable_parallel': 'auto',
'expected_num_examples': 18, # (20 // 3) * 3
6 changes: 3 additions & 3 deletions fedjax/core/dataset_util.py
@@ -78,9 +78,9 @@ def preprocess_tf_dataset(dataset: tf.data.Dataset,
dataset = dataset.repeat(hparams.num_epochs)
if hparams.shuffle_buffer_size:
dataset = dataset.shuffle(hparams.shuffle_buffer_size)
dataset = (
dataset.batch(hparams.batch_size,
drop_remainder=hparams.drop_remainder).prefetch(1))
dataset = (dataset.batch(hparams.batch_size,
drop_remainder=hparams.drop_remainder)
.prefetch(1))
return dataset.take(hparams.num_batches)


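For reference, a standalone sketch of the pipeline order shown in the `preprocess_tf_dataset` hunk above (repeat, shuffle, batch with `drop_remainder`, prefetch, take), using hypothetical literal values in place of the `hparams` fields:

```python
import tensorflow as tf

# Hypothetical literals stand in for hparams.num_epochs, shuffle_buffer_size,
# batch_size, drop_remainder, and num_batches.
dataset = tf.data.Dataset.range(10)
dataset = dataset.repeat(2)
dataset = dataset.shuffle(4)
dataset = (dataset.batch(3, drop_remainder=True)
           .prefetch(1))
dataset = dataset.take(5)
for batch in dataset.as_numpy_iterator():
  print(batch)  # each batch is an array of 3 elements
```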
2 changes: 0 additions & 2 deletions fedjax/core/evaluation_util.py
@@ -66,10 +66,8 @@ def evaluate_single_client(dataset: dataset_util.DatasetOrIterable,
Returns:
Ordered mapping of client metrics or empty mapping if the dataset is empty.
"""

def compute_batch_metrics(batch):
return model.evaluate(params, batch)

batch_metrics_iterator = map(compute_batch_metrics,
dataset_util.iterate(dataset))
return aggregate_metrics(batch_metrics_iterator)
18 changes: 9 additions & 9 deletions fedjax/core/metrics.py
@@ -57,14 +57,14 @@ def clz_from_iterable(meta, data):
kwargs = dict(meta_args + data_args)
return data_clz(**kwargs)

jax.tree_util.register_pytree_node(data_clz, iterate_clz, clz_from_iterable)
jax.tree_util.register_pytree_node(data_clz,
iterate_clz,
clz_from_iterable)
return data_clz


class Metric(metaclass=abc.ABCMeta):
"""Interface for all metric containers (e.g.
accuracy).
"""Interface for all metric containers (e.g. accuracy).
`Metric` stores intermediate values as well as methods for accumulation and
final result computation.
@@ -209,8 +209,7 @@ def accuracy_fn(targets: jnp.ndarray, preds: jnp.ndarray) -> MeanMetric:
def masked_accuracy_fn(
targets: jnp.ndarray,
preds: jnp.ndarray,
mask_values: Tuple[int, ...] = (),
) -> MeanMetric:
mask_values: Tuple[int, ...] = (),) -> MeanMetric:
"""Computes accuracy after discounting masked values.
Args:
@@ -253,16 +252,17 @@ def masked_accuracy_fn_with_logits_mask(
return MeanMetric.from_values(pred_class == targets, weights=weights)


def masked_count(
targets: jnp.ndarray, mask_values: Tuple[Any, ...] = ()) -> CountMetric:
def masked_count(targets: jnp.ndarray,
mask_values: Tuple[Any, ...] = ()) -> CountMetric:
"""Counts total number of non masked targets."""
weights = jnp.ones_like(targets, dtype=jnp.int32)
for mv in mask_values:
weights *= (targets != mv)
return CountMetric(count=jnp.sum(weights))


def truncation_rate(targets: jnp.ndarray, eos_value: int,
def truncation_rate(targets: jnp.ndarray,
eos_value: int,
pad_value: int) -> MeanMetric:
"""Computes the proportion of sequence examples that were truncated.
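A small worked example of the `masked_count` logic shown in the hunk above, with hypothetical targets and an assumed pad value of 0; the loop mirrors the function body rather than calling the fedjax API:

```python
import jax.numpy as jnp

# Hypothetical padded targets; 0 is assumed to be the pad value.
targets = jnp.array([[3, 5, 0, 0],
                     [2, 0, 0, 0]])
mask_values = (0,)
weights = jnp.ones_like(targets, dtype=jnp.int32)
for mv in mask_values:
  weights *= (targets != mv)   # zero out weights at masked positions
print(jnp.sum(weights))  # 3 non-masked targets
```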
15 changes: 7 additions & 8 deletions fedjax/core/model.py
@@ -57,8 +57,8 @@ class Model:
apply_fn: Function that produces model predictions for input batch.
loss_fn: Loss function that takes input batch and model predictions and
returns scalar loss.
reg_fn: Regularization function that takes parameters in and returns a
scalar regularizer value.
reg_fn: Regularization function that takes parameters in and returns
a scalar regularizer value.
metrics_fn_map: Ordered mapping of metric names to metric functions that
take input batch and model predictions and return metric values.
train_kwargs: Keyword arguments passed to apply for training.
@@ -131,17 +131,16 @@ def create_model_from_haiku(
metrics_fn_map: Optional[Mapping[str, MetricsFn]] = None,
train_kwargs: Optional[Mapping[str, Any]] = None,
test_kwargs: Optional[Mapping[str, Any]] = None,
non_trainable_module_names: Tuple[str] = ()
) -> Model: # pytype: disable=annotation-type-mismatch
non_trainable_module_names: Tuple[str] = ()) -> Model: # pytype: disable=annotation-type-mismatch
"""Creates Model after applying defaults and haiku specific preprocessing.
Args:
transformed_forward_pass: Transformed forward pass from `hk.transform`.
sample_batch: Example input batch used to determine model parameter shapes.
loss_fn: Loss function that takes input batch and model predictions and
returns scalar loss.
reg_fn: Regularization function that takes parameters in and returns a
scalar regularizer value. Defaults to no regularization.
reg_fn: Regularization function that takes parameters in and returns
a scalar regularizer value. Defaults to no regularization.
metrics_fn_map: Ordered mapping of metric names to metric functions that
take input batch and model predictions and return metric values that will
be freezed for immutability. Defaults to empty frozen dictionary.
@@ -190,8 +189,8 @@ def create_model_from_stax(
loss_fn: Loss function that takes input batch and model predictions and
returns scalar loss.
input_key: Key name for the input from batch mapping.
reg_fn: Regularization function that takes parameters in and returns a
scalar regularizer value. Defaults to no regularization.
reg_fn: Regularization function that takes parameters in and returns
a scalar regularizer value. Defaults to no regularization.
metrics_fn_map: Ordered mapping of metric names to metric functions that
take input batch and model predictions and return metric values that will
be freezed for immutability. Defaults to empty frozen dictionary.
14 changes: 7 additions & 7 deletions fedjax/core/regularizers.py
@@ -50,10 +50,11 @@ class L2Regularizer(Regularizer):
regularization strength for each parameter.
"""

def __init__(self,
center_params: Optional[Params] = None,
weight: float = 1.0,
param_weights: Optional[Params] = None):
def __init__(
self,
center_params: Optional[Params] = None,
weight: float = 1.0,
param_weights: Optional[Params] = None):
super().__init__(center_params)
self._weight = weight
self._param_weights = param_weights
@@ -65,8 +66,7 @@ def __call__(self, params: Params) -> float:

if self._param_weights:
param_weight_leaves, _ = jax.tree_flatten(self._param_weights)
return sum(
jnp.vdot(pw, jnp.square(x))
for pw, x in zip(param_weight_leaves, leaves)) * self._weight
return sum(jnp.vdot(pw, jnp.square(x))
for pw, x in zip(param_weight_leaves, leaves)) * self._weight

return sum(jnp.vdot(x, x) for x in leaves) * self._weight
10 changes: 5 additions & 5 deletions fedjax/core/regularizers_test.py
@@ -36,12 +36,12 @@ def test_l2_regularizer_parameter_weight(self):
params = test_util.create_mock_state(seed=0).params
original_output = regularizers.L2Regularizer()(params)

param_weights = jax.tree_map(lambda leaf: 2 * jax.numpy.ones(leaf.shape),
param_weights = jax.tree_map(lambda leaf: 2*jax.numpy.ones(leaf.shape),
params)
output = regularizers.L2Regularizer(
weight=1.0, param_weights=param_weights)(
params)
self.assertAlmostEqual(output, 2 * original_output, delta=1e-5)
output = regularizers.L2Regularizer(weight=1.0,
param_weights=param_weights)(params)
self.assertAlmostEqual(output, 2*original_output,
delta=1e-5)

def test_l2_regularizer_evaluation_with_center(self):
params = test_util.create_mock_state(seed=0).params
18 changes: 14 additions & 4 deletions fedjax/core/test_util_test.py
@@ -22,7 +22,11 @@ class TestUtilsTest(tf.test.TestCase):

def test_create_toy_data(self):
data = test_util.create_toy_data(
num_clients=10, num_clusters=2, num_classes=4, num_examples=5, seed=10)
num_clients=10,
num_clusters=2,
num_classes=4,
num_examples=5,
seed=10)
client_id = data.client_ids[3]
client_data = list(data.create_tf_dataset_for_client(client_id))
self.assertLen(data.client_ids, 10)
@@ -39,9 +43,15 @@ def test_create_toy_model(self):

def test_create_toy_example(self):
data, model = test_util.create_toy_example(
num_clients=10, num_clusters=2, num_classes=4, num_examples=5, seed=10)
batch = next((data.create_tf_dataset_for_client(
data.client_ids[0]).batch(3).as_numpy_iterator()))
num_clients=10,
num_clusters=2,
num_classes=4,
num_examples=5,
seed=10)
batch = next(
(data.create_tf_dataset_for_client(data.client_ids[0])
.batch(3)
.as_numpy_iterator()))
params = model.init_params(next(hk.PRNGSequence(0)))
self.assertTupleEqual(model.apply_fn(params, None, batch).shape, (3, 4))

15 changes: 3 additions & 12 deletions fedjax/core/tree_util_test.py
@@ -28,18 +28,9 @@ def test_tree_broadcast(self):

def test_tree_stack(self):
pytrees = [
{
'x': jnp.array([[0, 0]]),
'y': jnp.array([0])
},
{
'x': jnp.array([[1, 1]]),
'y': jnp.array([1])
},
{
'x': jnp.array([[2, 2]]),
'y': jnp.array([2])
},
{'x': jnp.array([[0, 0]]), 'y': jnp.array([0])},
{'x': jnp.array([[1, 1]]), 'y': jnp.array([1])},
{'x': jnp.array([[2, 2]]), 'y': jnp.array([2])},
]
pytree = tree_util.tree_stack(pytrees)
self.assertAllEqual(pytree['x'], [[[0, 0]], [[1, 1]], [[2, 2]]])
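One plausible way to implement the `tree_stack` behavior this test asserts is to stack corresponding leaves along a new leading axis; the sketch below is consistent with the test but is not necessarily fedjax's own implementation:

```python
import jax
import jax.numpy as jnp

def tree_stack(pytrees):
  # Stack corresponding leaves of the input pytrees along a new leading axis.
  return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *pytrees)

pytrees = [
    {'x': jnp.array([[0, 0]]), 'y': jnp.array([0])},
    {'x': jnp.array([[1, 1]]), 'y': jnp.array([1])},
    {'x': jnp.array([[2, 2]]), 'y': jnp.array([2])},
]
stacked = tree_stack(pytrees)
print(stacked['x'].shape)  # (3, 1, 2)
print(stacked['y'])        # shape (3, 1): [[0], [1], [2]]
```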
8 changes: 5 additions & 3 deletions fedjax/experimental/aggregators/aggregator.py
@@ -74,9 +74,11 @@ class MeanAggregator(Aggregator):
def init_state(self):
return MeanAggregatorState(total_weight=0.0)

def aggregate(self, aggregator_state: MeanAggregatorState,
params_and_weight: Iterable[Tuple[W, float]],
rng_seq: core.PRNGSequence) -> Tuple[MeanAggregatorState, W]:
def aggregate(
self, aggregator_state: MeanAggregatorState,
params_and_weight: Iterable[Tuple[W, float]],
rng_seq: core.PRNGSequence
) -> Tuple[MeanAggregatorState, W]:
"""Returns (weighted) mean of input trees and weights.
Args:
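A self-contained sketch of the weighted mean that `MeanAggregator.aggregate` is documented to return, using hypothetical client parameters; it is an illustration, not the fedjax implementation:

```python
import jax
import jax.numpy as jnp

def weighted_mean(params_and_weight):
  # Weighted mean over a sequence of (pytree, weight) pairs.
  params_and_weight = list(params_and_weight)
  total_weight = sum(w for _, w in params_and_weight)
  scaled = [jax.tree_util.tree_map(lambda x, w=w: x * w, p)
            for p, w in params_and_weight]
  summed = jax.tree_util.tree_map(lambda *leaves: sum(leaves), *scaled)
  mean = jax.tree_util.tree_map(lambda x: x / total_weight, summed)
  return mean, total_weight

mean_params, total = weighted_mean([
    ({'w': jnp.array([1.0, 2.0])}, 1.0),
    ({'w': jnp.array([3.0, 4.0])}, 3.0),
])
print(mean_params['w'], total)  # [2.5 3.5] 4.0
```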
4 changes: 2 additions & 2 deletions fedjax/experimental/aggregators/compression.py
@@ -77,13 +77,13 @@ def quantize_params_and_weight(param_weight_rng):

def binary_stochastic_quantize_with_rng(param):
return binary_stochastic_quantize(param, rng)

return jax.tree_map(binary_stochastic_quantize_with_rng, param), weight

# TODO(theertha): avoid the need to copying the entire sequence to memory.

params_and_weight_rng = zip(params_and_weight, rng_seq)
quantized_p_and_w = map(quantize_params_and_weight, params_and_weight_rng)
quantized_p_and_w = map(quantize_params_and_weight,
params_and_weight_rng)
quantized_p_and_w = list(quantized_p_and_w)
weights = [weight for params, weight in quantized_p_and_w]
new_weight = sum(weights)
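`binary_stochastic_quantize` is assumed here to be the standard per-tensor scheme that snaps each entry to the tensor's min or max with probability proportional to its position between them, so the result is unbiased in expectation; the sketch below illustrates that assumption and is not fedjax's code:

```python
import jax
import jax.numpy as jnp

def binary_stochastic_quantize(param, rng):
  # Snap every entry to the tensor's min or max; the probability of choosing the
  # max grows linearly with the entry's position between min and max, so the
  # quantized tensor equals the original in expectation.
  v_min, v_max = jnp.min(param), jnp.max(param)
  prob_up = (param - v_min) / jnp.maximum(v_max - v_min, 1e-12)
  snap_up = jax.random.uniform(rng, param.shape) < prob_up
  return jnp.where(snap_up, v_max, v_min)

rng = jax.random.PRNGKey(0)
x = jnp.array([0.0, 0.25, 0.5, 1.0])
print(binary_stochastic_quantize(x, rng))  # every entry is either 0.0 or 1.0
```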
6 changes: 3 additions & 3 deletions fedjax/experimental/federated_algorithm.py
@@ -93,9 +93,9 @@ def run_one_round(state, clients):
```
Attributes:
init: Initializes the `AlgorithmState`. Typically, the input to this method
will be the initial model `Params`. This should only be run once at the
beginning of training.
init: Initializes the `AlgorithmState`. Typically, the input to this
method will be the initial model `Params`. This should only be run once
at the beginning of training.
run_one_round: Completes one round of federated training given an input
`AlgorithmState` and a sequence of client identifiers to client datasets.
The output will be a new, updated `AlgorithmState` along with any per
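The docstring in the hunk above describes a driver loop: call `init` once, then call `run_one_round` repeatedly with the current state and a batch of clients. A toy, self-contained sketch of that loop; `ToyAlgorithm`, its state fields, and the client IDs are invented for illustration and are not fedjax APIs:

```python
from typing import List, NamedTuple

class ToyState(NamedTuple):
  round_num: int
  value: float

class ToyAlgorithm:
  """Invented stand-in that only illustrates the init / run_one_round contract."""

  def init(self, initial_value: float) -> ToyState:
    return ToyState(round_num=0, value=initial_value)

  def run_one_round(self, state: ToyState, clients: List[str]) -> ToyState:
    # A real algorithm would train on each client's dataset and aggregate updates.
    return ToyState(round_num=state.round_num + 1,
                    value=state.value + len(clients))

algorithm = ToyAlgorithm()
state = algorithm.init(initial_value=0.0)  # run once at the start of training
for round_num in range(3):
  clients = [f'client_{round_num}_{i}' for i in range(2)]
  state = algorithm.run_one_round(state, clients)
print(state)  # ToyState(round_num=3, value=6.0)
```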
(Diffs for the remaining 7 changed files are not shown.)
