Skip to content

Commit

Permalink
SQLite backed FederatedData
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 368012230
  • Loading branch information
kho authored and fedjax authors committed Apr 12, 2021
1 parent dddee17 commit 5a3f279
Show file tree
Hide file tree
Showing 6 changed files with 692 additions and 1 deletion.
23 changes: 22 additions & 1 deletion fedjax/experimental/federated_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
"""FederatedData interface for providing access to a federated dataset."""

import abc
from typing import Callable, Iterable, Iterator, Optional, Tuple
import itertools
import random
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple

from fedjax.experimental import client_datasets

Expand Down Expand Up @@ -171,3 +173,22 @@ def preprocess_batch(
self, fn: Callable[[client_datasets.Examples], client_datasets.Examples]
) -> 'FederatedData':
"""Registers a preprocessing function to be called after batching in ClientDatasets."""


# Utility functions useful when implementing FederatedData.


def buffered_shuffle(source: Iterable[Any], buffer_size: int,
rng: random.Random) -> Iterator[Any]:
"""Shuffles an iterable via buffered shuffling."""
it = iter(source)
buf = list(itertools.islice(it, buffer_size))
rng.shuffle(buf)
for i in it:
r, buf[0] = buf[0], i
swap = rng.randrange(buffer_size)
if swap < buffer_size - 1:
buf[swap], buf[0] = buf[0], buf[swap]
yield r
for i in buf:
yield i
170 changes: 170 additions & 0 deletions fedjax/experimental/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Derived from https://github.com/google/flax/blob/master/flax/serialization.py
#
# Copyright 2021 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple serialization scheme based on msgpack.
We support serializing a dict and list of ndarrays holding numeric types, and
bytes (note tuples cannot be serialized since msgpack does not distinguish
tuples from lists during deserialization).
"""
import enum
import jax
import msgpack
import numpy as np

# On-the-wire / disk serialization format

# We encode state-dicts via msgpack, using its custom type extension.
# https://github.com/msgpack/msgpack/blob/master/spec.md
#
# - ndarrays and DeviceArrays of numeric types are serialized to nested
# msgpack-encoded string of (shape-tuple, dtype-name (e.g. 'float32'),
# row-major array-bytes).
#
# - native complex scalars are converted to nested msgpack-encoded tuples
# (real, imag).
#
# - ndarrays of bytes are serialized to nested msgpack-encoded string of
# (shape-tuple, concatenated bytes).


def _ndarray_to_bytes(arr):
"""Save ndarray to simple msgpack encoding."""
if isinstance(arr, jax.xla.DeviceArray):
arr = np.array(arr)
if arr.dtype.hasobject or arr.dtype.isalignedstruct:
raise ValueError('Object and structured dtypes not supported '
'for serialization of ndarrays.')
tpl = (arr.shape, arr.dtype.name, arr.tobytes('C'))
return msgpack.packb(tpl, use_bin_type=True)


def _dtype_from_name(name):
"""Handle JAX bfloat16 dtype correctly."""
if name == b'bfloat16':
return jax.numpy.bfloat16
else:
return np.dtype(name)


def _ndarray_from_bytes(data):
"""Load ndarray from simple msgpack encoding."""
shape, dtype_name, buffer = msgpack.unpackb(data, raw=True)
return np.frombuffer(
buffer, dtype=_dtype_from_name(dtype_name), count=-1, offset=0).reshape(
shape, order='C')


def _bytes_ndarray_to_bytes(x):
shape = x.shape
flat = list(x.flatten())
if flat and not isinstance(flat[0], bytes):
raise ValueError('Only ndarrays holding bytes objects can be serialized.')
tpl = shape, flat
return msgpack.packb(tpl, use_bin_type=True)


def _object_ndarray_from_bytes(data):
shape, flat = msgpack.unpackb(data, raw=True)
return np.array(flat, dtype=object).reshape(shape)


class _MsgpackExtType(enum.IntEnum):
"""Messagepack custom type ids."""
ndarray = 1
native_complex = 2
npscalar = 3
bytes_ndarray = 4


def _msgpack_ext_pack(x):
"""Messagepack encoders for custom types."""
if isinstance(x, np.ndarray) and x.dtype.hasobject:
return msgpack.ExtType(_MsgpackExtType.bytes_ndarray,
_bytes_ndarray_to_bytes(x))
elif isinstance(x, (np.ndarray, jax.xla.DeviceArray)):
return msgpack.ExtType(_MsgpackExtType.ndarray, _ndarray_to_bytes(x))
elif np.issctype(type(x)):
# pack scalar as ndarray
return msgpack.ExtType(_MsgpackExtType.npscalar,
_ndarray_to_bytes(np.asarray(x)))
elif isinstance(x, complex):
return msgpack.ExtType(_MsgpackExtType.native_complex,
msgpack.packb((x.real, x.imag)))
print('falling back', repr(x))
return x


def _msgpack_ext_unpack(code, data):
"""Messagepack decoders for custom types."""
if code == _MsgpackExtType.ndarray:
return _ndarray_from_bytes(data)
elif code == _MsgpackExtType.native_complex:
complex_tuple = msgpack.unpackb(data)
return complex(complex_tuple[0], complex_tuple[1])
elif code == _MsgpackExtType.npscalar:
ar = _ndarray_from_bytes(data)
return ar[()] # unpack ndarray to scalar
elif code == _MsgpackExtType.bytes_ndarray:
return _object_ndarray_from_bytes(data)
return msgpack.ExtType(code, data)


# User-facing API calls:


def msgpack_serialize(pytree):
"""Save data structure to bytes in msgpack format.
Low-level function that only supports python trees with array leaves,
for custom objects use `to_bytes`.
Args:
pytree: python tree of dict, list with python primitives and array leaves.
Returns:
msgpack-encoded bytes of pytree.
"""
return msgpack.packb(pytree, default=_msgpack_ext_pack, strict_types=True)


def msgpack_deserialize(encoded_pytree):
"""Restore data structure from bytes in msgpack format.
Low-level function that only supports python trees with array leaves,
for custom objects use `from_bytes`.
Args:
encoded_pytree: msgpack-encoded bytes of python tree.
Returns:
Python tree of dict, list with python primitive and array leaves.
"""
return msgpack.unpackb(
encoded_pytree, ext_hook=_msgpack_ext_unpack, raw=False)
69 changes: 69 additions & 0 deletions fedjax/experimental/serialization_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for fedjax.google.experimental.serialization."""

from absl.testing import absltest
from fedjax.experimental import serialization
import numpy as np
import numpy.testing as npt


class SerializationTest(absltest.TestCase):

def test_dict(self):
original = {
'int32':
np.arange(4, dtype=np.int32).reshape([2, 2]),
'float64':
-np.arange(4, dtype=np.float64).reshape([1, 4]),
'bytes':
np.array([b'a', b'bc', b'def'], dtype=np.object).reshape([3, 1]),
}
output = serialization.msgpack_deserialize(
serialization.msgpack_serialize(original))
self.assertCountEqual(output, original)
self.assertEqual(output['int32'].dtype, np.int32)
npt.assert_array_equal(output['int32'], original['int32'])
self.assertEqual(output['float64'].dtype, np.float64)
npt.assert_array_equal(output['float64'], original['float64'])
self.assertEqual(output['bytes'].dtype, np.object)
npt.assert_array_equal(output['bytes'], original['bytes'])

def test_nested_list(self):
original = [
np.arange(4, dtype=np.int32).reshape([2, 2]),
[
-np.arange(4, dtype=np.float64).reshape([1, 4]),
[
np.array([b'a', b'bc', b'def'],
dtype=np.object).reshape([3, 1]), []
]
]
]
output = serialization.msgpack_deserialize(
serialization.msgpack_serialize(original))
int32_array, rest = output
self.assertEqual(int32_array.dtype, np.int32)
npt.assert_array_equal(int32_array, original[0])
float64_array, rest = rest
self.assertEqual(float64_array.dtype, np.float64)
npt.assert_array_equal(float64_array, original[1][0])
bytes_array, rest = rest
self.assertEqual(bytes_array.dtype, np.object)
npt.assert_array_equal(bytes_array, original[1][1][0])
self.assertEqual(rest, [])


if __name__ == '__main__':
absltest.main()
Loading

0 comments on commit 5a3f279

Please sign in to comment.