Skip to content

Commit

Permalink
Cleaning up fedjax.datasets docs and pruning unused dependencies from…
Browse files Browse the repository at this point in the history
… setup
  • Loading branch information
jaehunro committed Oct 26, 2021
1 parent d4ec1a5 commit 69aee78
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 74 deletions.
21 changes: 15 additions & 6 deletions docs/fedjax.datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,39 @@ fedjax.datasets package
.. autosummary::
:nosignatures:

fedjax.datasets.cifar100
fedjax.datasets.emnist
fedjax.datasets.shakespeare
fedjax.datasets.stackoverflow

CIFAR-100
---------

.. automodule:: fedjax.datasets.cifar100
:members: cite, load_data, load_split, preprocess_image
:show-inheritance:

EMNIST
------

.. automodule:: fedjax.datasets.emnist
:members:
:undoc-members:
:members: cite, load_data, load_split, domain_id
:show-inheritance:

Shakespeare
-----------

.. automodule:: fedjax.datasets.shakespeare
:members:
:undoc-members:
:members: cite, load_data, load_split, preprocess_client
:show-inheritance:

Stack Overflow
--------------

.. automodule:: fedjax.datasets.stackoverflow
:members:
:undoc-members:
:members: cite, load_data, load_split
:show-inheritance:

.. autoclass:: fedjax.datasets.stackoverflow.StackoverflowTokenizer
:members:
:special-members: __init__
67 changes: 41 additions & 26 deletions fedjax/datasets/cifar100.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@
_FEDJAX_SQLITE_NUM_BYTES = {'train': 140521472, 'test': 28135424}


def cite():
"""Returns BibTeX citation for the dataset."""
return """@TECHREPORT{Krizhevsky09learningmultiple,
author = {Alex Krizhevsky},
title = {Learning multiple layers of features from tiny images},
institution = {},
year = {2009}
}"""


def _parse_tf_examples(vs: List[bytes]) -> client_datasets.Examples:
tf_examples = tf.io.parse_example(
vs,
Expand All @@ -52,9 +62,10 @@ def load_split(split: str,
"""Loads a cifar100 split.
Features:
image: [N, 32, 32, 3] uint8 pixels.
coarse_label: [N] int64 coarse labels in the range [0, 20).
label: [N] int64 labels in the range [0, 100).
- image: [N, 32, 32, 3] uint8 pixels.
- coarse_label: [N] int64 coarse labels in the range [0, 20).
- label: [N] int64 labels in the range [0, 100).
Args:
split: Name of the split. One of SPLITS.
Expand Down Expand Up @@ -105,30 +116,33 @@ def load_data(
"""Loads partially preprocessed cifar100 splits.
Features:
x: [N, 32, 32, 3] uint8 pixels.
y: [N] int32 labels in the range [0, 100).
- x: [N, 32, 32, 3] uint8 pixels.
- y: [N] int32 labels in the range [0, 100).
Additional preprocessing (e.g. centering and normalizing) depends on whether
a split is used for training or eval. For example,
```
import functools
from fedjax.datasets import cifar100
# Load partially preprocessed splits.
train, test = cifar100.load_data()
# Preprocessing for training.
train_for_train = train.preprocess_batch(
functools.partial(preprocess_batch, is_train=True))
# Preprocessing for eval.
train_for_eval = train.preprocess_batch(
functools.partial(preprocess_batch, is_train=False))
test = test.preprocess_batch(
functools.partial(preprocess_batch, is_train=False))
```
a split is used for training or eval. For example,::
import functools
from fedjax.datasets import cifar100
# Load partially preprocessed splits.
train, test = cifar100.load_data()
# Preprocessing for training.
train_for_train = train.preprocess_batch(
functools.partial(preprocess_batch, is_train=True))
# Preprocessing for eval.
train_for_eval = train.preprocess_batch(
functools.partial(preprocess_batch, is_train=False))
test = test.preprocess_batch(
functools.partial(preprocess_batch, is_train=False))
Features after final preprocessing:
x: [N, 32, 32, 3] float32 preprocessed pixels.
y: [N] int32 labels in the range [0, 100).
- x: [N, 32, 32, 3] float32 preprocessed pixels.
- y: [N] int32 labels in the range [0, 100).
NOTE: ``preprocess_batch`` is just a convenience wrapper around :meth:`preprocess_image`
so that it can be used with :meth:`fedjax.FederatedData.preprocess_batch`.
Args:
mode: 'sqlite'.
Expand All @@ -150,9 +164,7 @@ def preprocess_client(
return {'x': examples['image'], 'y': examples['label'].astype(np.int32)}


# Preprocessing procedure and values taken from
# https://github.com/kuangliu/pytorch-cifar/blob/49b7aa97b0c12fe0d4054e670403a16b6b834ddd/main.py#L30-L40
#

# Mean and stddev computed assuming RGB values lie in [0,1].
CIFAR100_PIXELS_MEAN = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)
CIFAR100_PIXELS_INVERSE_STDDEV = (
Expand All @@ -162,6 +174,9 @@ def preprocess_client(
def preprocess_image(image: np.ndarray, is_train: bool) -> np.ndarray:
"""Augments and preprocesses CIFAR-100 images by cropping, flipping, and normalizing.
Preprocessing procedure and values taken from
`pytorch-cifar <https://github.com/kuangliu/pytorch-cifar/blob/49b7aa97b0c12fe0d4054e670403a16b6b834ddd/main.py#L30-L40>`_.
Args:
image: [N, 32, 32, 3] uint8 pixels.
is_train: Whether we are preprocessing for training or eval.
Expand Down
35 changes: 25 additions & 10 deletions fedjax/datasets/emnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@


def cite():
"""Returns BibTeX citation for the dataset."""
return """@inproceedings{cohen2017emnist,
title={EMNIST: Extending MNIST to handwritten letters},
author={Cohen, Gregory and Afshar, Saeed and Tapson, Jonathan and
Expand All @@ -43,8 +44,9 @@ def load_split(split: str,
"""Loads an unprocessed federated emnist split.
Features:
pixels: [N, 28, 28] float32 image pixels.
label: [N] int32 classification label.
- pixels: [N, 28, 28] float32 image pixels.
- label: [N] int32 classification label.
Args:
split: Name of the split. One of SPLITS.
Expand Down Expand Up @@ -73,12 +75,24 @@ def load_split(split: str,


def domain_id(client_id: federated_data.ClientId) -> int:
"""Returns domain id for client id."""
"""Returns domain id for client id.
Domain ids are based on the NIST data source, where examples were collected
from two sources:
Bethesda high school (HIGH_SCHOOL) and Census Bureau in Suitland (CENSUS).
For more details, see the
`NIST documentation <https://s3.amazonaws.com/nist-srd/SD19/sd19_users_guide_edition_2.pdf>`_.
Args:
client_id: Client id of the format
``[16-byte hex hash]:f[4-digit integer]_[2-digit integer]`` or
``f[4-digit integer]_[2-digit integer]``.
Returns:
Domain id that is 0 (HIGH_SCHOOL) or 1 (CENSUS).
"""
# client ids are of the following format:
# - sqlite: "[16-byte hex hash]:f[4-digit integer]_[2-digit integer]"
#
# These domain ids are based on NIST data source. For more details, see
# https://s3.amazonaws.com/nist-srd/SD19/sd19_users_guide_edition_2.pdf.
if len(client_id) == 25:
cid = int(client_id[18:22])
elif len(client_id) == 8:
Expand All @@ -87,7 +101,7 @@ def domain_id(client_id: federated_data.ClientId) -> int:
raise ValueError(f'Invalid client_id: {client_id!r}')
if 2100 <= cid and cid <= 2599:
return 0 # HIGH_SCHOOL.
return 1 # CENSUS_FIELD.
return 1 # CENSUS.


def preprocess_client(
Expand Down Expand Up @@ -122,9 +136,10 @@ def load_data(
"""Loads processed EMNIST train and test splits.
Features:
x: [N, 28, 28, 1] float32 flipped image pixels.
y: [N] int32 classification label.
domain_id: [N] int32 domain id (see domain_id()).
- x: [N, 28, 28, 1] float32 flipped image pixels.
- y: [N] int32 classification label.
- domain_id: [N] int32 domain id (see :meth:`domain_id`).
Args:
only_digits: Whether to only load the digits data.
Expand Down
39 changes: 22 additions & 17 deletions fedjax/datasets/shakespeare.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@


def cite():
"""Returns BibTeX citation for the dataset."""
return """@inproceedings{mcmahan2017communication,
title={Communication-efficient learning of deep networks from
decentralized data},
Expand All @@ -44,7 +45,8 @@ def load_split(split: str,
"""Loads a shakespeare split.
Features:
snippets: [N] bytes array of snippet text.
- snippets: [N] bytes array of snippet text.
Args:
split: Name of the split. One of SPLITS.
Expand Down Expand Up @@ -74,11 +76,15 @@ def load_data(
) -> Tuple[federated_data.FederatedData, federated_data.FederatedData]:
"""Loads preprocessed shakespeare splits.
Preprocessing is done using :meth:`fedjax.FederatedData.preprocess_client`
and :meth:`preprocess_client`.
Features (M below is possibly different from N in load_split):
x: [M, sequence_length] int32 input labels, in the range of [0,
shakespeare.VOCAB_SIZE)
y: [M, sequence_length] int32 output labels, in the range of [0,
shakespeare.VOCAB_SIZE)
- x: [M, sequence_length] int32 input labels, in the range of [0,
shakespeare.VOCAB_SIZE)
- y: [M, sequence_length] int32 output labels, in the range of [0,
shakespeare.VOCAB_SIZE)
Args:
sequence_length: The fixed sequence length after preprocessing.
Expand Down Expand Up @@ -138,22 +144,21 @@ def preprocess_client(client_id: federated_data.ClientId,
"""Turns snippets into sequences of integer labels.
Features (M below is possibly different from N in load_split):
x: [M, sequence_length] int32 input labels, in the range of [0,
shakespeare.VOCAB_SIZE)
y: [M, sequence_length] int32 output labels, in the range of [0,
shakespeare.VOCAB_SIZE)
- x: [M, sequence_length] int32 input labels, in the range of [0,
shakespeare.VOCAB_SIZE)
- y: [M, sequence_length] int32 output labels, in the range of [0,
shakespeare.VOCAB_SIZE)
All snippets in a client dataset are first joined into a single sequence (with
BOS/EOS added), and then split into pairs of `sequence_length` chunks for
language model training. For example, with sequence_length=3,
`[b'ABCD', b'E']` becomes
```
Input sequences: [[BOS, A, B], [C, D, EOS], [BOS, E, PAD]]
Output seqeunces: [[A, B, C], [D, EOS, BOS], [E, EOS, PAD]]
```
Note: This is not equivalent to
https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation#load_and_preprocess_the_federated_shakespeare_data
`[b'ABCD', b'E']` becomes::
Input sequences: [[BOS, A, B], [C, D, EOS], [BOS, E, PAD]]
Output seqeunces: [[A, B, C], [D, EOS, BOS], [E, EOS, PAD]]
Note: This is not equivalent to the `TensorFlow Federated text generation tutorial <https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation#load_and_preprocess_the_federated_shakespeare_data>`_
(The processing logic there loses ~1/sequence_length portion of the tokens).
Args:
Expand Down
30 changes: 17 additions & 13 deletions fedjax/datasets/stackoverflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@


def cite():
"""Returns BibTeX citation for the dataset."""
return """@misc{stackoverflow2019,
title={TensorFlow Federated Stack Overflow dataset},
author={The TensorFlow Federated Authors.},
Expand All @@ -42,13 +43,14 @@ def load_split(split: str,
All bytes arrays are stored with `dtype=np.object`.
Features:
creation_date: [N] bytes array. Textual timestamp, e.g.
b'2018-02-28 19:06:18.34 UTC'.
title: [N] bytes array. The title of a post.
score: [N] int64 array. The score of a post.
tags: [N] bytes array. '|' separated list of tags, e.g. b'mysql|join'.
tokens: [N] bytes array. Space separated list of tokens.
type: [N] bytes array. Either b'question' or b'answer'.
- creation_date: [N] bytes array. Textual timestamp, e.g.
b'2018-02-28 19:06:18.34 UTC'.
- title: [N] bytes array. The title of a post.
- score: [N] int64 array. The score of a post.
- tags: [N] bytes array. '|' separated list of tags, e.g. b'mysql|join'.
- tokens: [N] bytes array. Space separated list of tokens.
- type: [N] bytes array. Either b'question' or b'answer'.
Args:
split: Name of the split. One of SPLITS.
Expand Down Expand Up @@ -79,8 +81,9 @@ def load_data(
"""Loads partially preprocessed stackoverflow splits.
Features:
domain_id: [N] int32 domain id derived from type (question = 0; answer = 1).
tokens: [N] bytes array. Space separated list of tokens.
- domain_id: [N] int32 domain id derived from type (question = 0; answer = 1).
- tokens: [N] bytes array. Space separated list of tokens.
To convert `tokens` into padded/truncated integer labels, use a
StackoverflowTokenizer. For example,::
Expand All @@ -101,9 +104,10 @@ def load_data(
tokenizer.as_preprocess_batch(eval_max_length))
Features after tokenization:
domain_id: Same as before.
x: [N, max_length] int32 array of padded/truncated input labels.
y: [N, max_length] int32 array of padded/truncated output labels.
- domain_id: Same as before.
- x: [N, max_length] int32 array of padded/truncated input labels.
- y: [N, max_length] int32 array of padded/truncated output labels.
Args:
mode: 'sqlite'.
Expand Down Expand Up @@ -146,7 +150,7 @@ def default_vocab(default_vocab_size) -> List[str]:
class StackoverflowTokenizer:
"""Tokenizer for the `tokens` feature in stackoverflow.
See load_data() for examples.
See :meth:`load_data` for examples.
"""
PAD = 0
BOS = 1
Expand Down
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,11 @@
install_requires=[
'absl-py',
'dm-haiku',
'immutabledict',
'jax',
'jaxlib',
'msgpack',
'optax',
'requests',
'tensorflow-federated',
],
python_requires='>=3.6',
classifiers=[
Expand Down

0 comments on commit 69aee78

Please sign in to comment.