Delay importing tensorflow instead of directly depending on it in setup
This will require users to manually install TensorFlow if they happen to use a part of fedjax that depends on it (a relatively small subset of the API).

PiperOrigin-RevId: 406173096
jaehunro authored and fedjax authors committed Oct 28, 2021
1 parent 248fcff commit 6122339
Showing 12 changed files with 63 additions and 20 deletions.
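In practical terms, the change replaces module-level "import tensorflow as tf" statements with a deferred import inside each function that actually touches TensorFlow, so importing fedjax by itself no longer pulls in TensorFlow. Below is a minimal sketch of the pattern applied throughout this diff, assuming only that fedjax.core.util.import_tf behaves as defined later in this commit; the save_to_disk helper name is illustrative, the real change is in serialization.save_state below.

import pickle

from fedjax.core import util  # replaces a module-level import of tensorflow


def save_to_disk(state, path):
  """Pickles state to path; hypothetical mirror of serialization.save_state."""
  # TensorFlow is imported only when this function runs. If it is missing,
  # util.import_tf() raises ModuleNotFoundError with install instructions.
  tf = util.import_tf()
  with tf.io.gfile.GFile(path, 'wb') as f:
    pickle.dump(state, f)

The trade-off is that a missing TensorFlow now surfaces at call time rather than at pip-install time, which is why import_tf spells out the pip commands in its error message.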
3 changes: 2 additions & 1 deletion fedjax/core/client_datasets_benchmark.py
@@ -28,8 +28,8 @@

from absl import app
from fedjax.core import client_datasets
from fedjax.core import util
import numpy as np
import tensorflow as tf

# pylint: disable=cell-var-from-loop

@@ -78,6 +78,7 @@ def bench_client_dataset(preprocess, mode, batch_size=128, num_steps=100):

def bench_tf_dataset(preprocess, mode, batch_size=128, num_steps=100):
"""Benchmarks TF Dataset."""
tf = util.import_tf()
shuffle_buffer = 1000 # size of FAKE_MNIST
dataset = tf.data.Dataset.from_tensor_slices(FAKE_MNIST)
if mode == 'train':
5 changes: 3 additions & 2 deletions fedjax/core/serialization.py
@@ -38,22 +38,23 @@
import pickle

from absl import logging
from fedjax.core import util
import jax
import msgpack
import numpy as np
# TODO(b/188556866): Remove dependency on TensorFlow.
import tensorflow as tf


def save_state(state, path):
"""Saves state to file path."""
tf = util.import_tf()
logging.info('Saving state to %s.', path)
with tf.io.gfile.GFile(path, 'wb') as f:
pickle.dump(state, f)


def load_state(path):
"""Loads saved state from file path."""
tf = util.import_tf()
logging.info('Loading params from %s.', path)
with tf.io.gfile.GFile(path, 'rb') as f:
return pickle.load(f)
3 changes: 2 additions & 1 deletion fedjax/core/sqlite_federated_data_test.py
@@ -21,10 +21,10 @@
from absl.testing import absltest
from fedjax.core import serialization
from fedjax.core import sqlite_federated_data
from fedjax.core import util
import numpy as np
import numpy.testing as npt
import sqlite3
import tensorflow as tf

FLAGS = flags.FLAGS

@@ -251,6 +251,7 @@ class TFFSQLiteClientsIteratorTest(absltest.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
tf = util.import_tf()
path = os.path.join(FLAGS.test_tmpdir, 'test_tff_sqlite_reader.sqlite')
split_name = 'train'

32 changes: 32 additions & 0 deletions fedjax/core/util.py
@@ -22,3 +22,35 @@ def safe_div(a, b):
safe = b != 0
c = a / jnp.where(safe, b, 1)
return jnp.where(safe, c, 0)


# TODO(b/188556866): Remove dependency on TensorFlow.
def import_tf():
"""Imports and returns TensorFlow module if it is installed.
This is to avoid having FedJAX directly depend on TensorFlow since TensorFlow
is large and complex and the majority of FedJAX functions without it.
Returns:
TensorFlow module if it is installed.
Raises:
ModuleNotFoundError: If TensorFlow is not installed along with instructions
on how to install TensorFlow.
"""
try:
import tensorflow # pylint:disable=g-import-not-at-top
return tensorflow
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"""This FedJAX feature requires TensorFlow, but TensorFlow is not installed.
If you do not otherwise need TensorFlow, we recommend installing the smaller
CPU-only version of TensorFlow via
pip install tensorflow-cpu
If you may need to use TensorFlow with GPU, please install the full version via
pip install tensorflow
""") from e
3 changes: 2 additions & 1 deletion fedjax/datasets/cifar100.py
@@ -19,10 +19,10 @@
from fedjax.core import client_datasets
from fedjax.core import federated_data
from fedjax.core import sqlite_federated_data
from fedjax.core import util
from fedjax.datasets import downloads

import numpy as np
import tensorflow as tf

SPLITS = ('train', 'test')
_TFF_SQLITE_COMPRESSED_HEXDIGEST = '23d3916c9caa33395737ee560cc7cb77bbd05fc7b73647ee7be3a7e764172939'
@@ -45,6 +45,7 @@ def cite():


def _parse_tf_examples(vs: List[bytes]) -> client_datasets.Examples:
tf = util.import_tf()
tf_examples = tf.io.parse_example(
vs,
features={
5 changes: 4 additions & 1 deletion fedjax/datasets/stackoverflow.py
@@ -19,9 +19,9 @@
from fedjax.core import client_datasets
from fedjax.core import federated_data
from fedjax.core import sqlite_federated_data
from fedjax.core import util
from fedjax.datasets import downloads
import numpy as np
import tensorflow as tf

SPLITS = ('train', 'held_out', 'test')

@@ -137,6 +137,7 @@ def preprocess_client(

def default_vocab(default_vocab_size) -> List[str]:
"""Loads the deafult stackoverflow vocabulary."""
tf = util.import_tf()
path = 'gs://gresearch/fedjax/stackoverflow/stackoverflow.word_count'
vocab = []
with tf.io.gfile.GFile(path) as f:
@@ -176,6 +177,7 @@ def __init__(self,
OOV labels starting at `default_vocab_size + 3`.
num_oov_buckets: Number of out of vocabulary buckets.
"""
tf = util.import_tf()
if vocab is None:
# Load default vocabulary.
vocab = default_vocab(default_vocab_size)
@@ -197,6 +199,7 @@ def as_preprocess_batch(
Returns:
A function that can be used with FederatedData.preprocess_batch().
"""
tf = util.import_tf()

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
def token_to_ids(tokens):
4 changes: 3 additions & 1 deletion fedjax/training/checkpoint.py
@@ -18,13 +18,14 @@
from typing import Any, List, Optional, Tuple

from fedjax.core import serialization
import tensorflow as tf
from fedjax.core import util

_CHECKPOINT_PREFIX = 'checkpoint_'


def _get_checkpoint_paths(base_path: str) -> List[str]:
"""Returns all checkpoint paths present."""
tf = util.import_tf()
pattern = base_path + r'[0-9]{8}$'
checkpoint_paths = []
for path in tf.io.gfile.glob(base_path + '*'):
@@ -53,6 +54,7 @@ def save_checkpoint(root_dir: str,
round_num: int = 0,
keep: int = 1):
"""Saves checkpoint and cleans up old checkpoints."""
tf = util.import_tf()
base_path = os.path.join(root_dir, _CHECKPOINT_PREFIX)
checkpoint_path = f'{base_path}{round_num:08d}'
serialization.save_state(state, checkpoint_path)
4 changes: 3 additions & 1 deletion fedjax/training/federated_experiment.py
@@ -24,10 +24,10 @@
from fedjax.core import federated_algorithm
from fedjax.core import federated_data
from fedjax.core import models
from fedjax.core import util
from fedjax.training import checkpoint
from fedjax.training import logging as fedjax_logging
import jax.numpy as jnp
import tensorflow as tf


def set_tf_cpu_only():
@@ -36,6 +36,7 @@ def set_tf_cpu_only():
TensorFlow is only used for data loading, so we prevent it from allocating
GPU/TPU memory.
"""
tf = util.import_tf()
tf.config.experimental.set_visible_devices([], 'GPU')
tf.config.experimental.set_visible_devices([], 'TPU')

@@ -177,6 +178,7 @@ def run_federated_experiment(
Returns:
Final state of the input federated algorithm after training.
"""
tf = util.import_tf()
if config.root_dir:
tf.io.gfile.makedirs(config.root_dir)

10 changes: 5 additions & 5 deletions fedjax/training/federated_experiment_test.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for fedjax.legacy.training.federated_experiment."""

import glob
import os.path

from absl.testing import absltest
@@ -26,7 +27,6 @@
import jax
import numpy as np
import numpy.testing as npt
import tensorflow as tf


class FakeClientSampler(client_samplers.ClientSampler):
@@ -123,7 +123,7 @@ def test_checkpoint(self):
config=config)
self.assertEqual(state, (6, 14))
self.assertCountEqual(
tf.io.gfile.glob(os.path.join(config.root_dir, 'checkpoint_*')),
glob.glob(os.path.join(config.root_dir, 'checkpoint_*')),
[os.path.join(config.root_dir, 'checkpoint_00000003')])

with self.subTest('checkpoint restore'):
@@ -135,7 +135,7 @@
config=config)
self.assertEqual(state, (6, 16))
self.assertCountEqual(
tf.io.gfile.glob(os.path.join(config.root_dir, 'checkpoint_*')),
glob.glob(os.path.join(config.root_dir, 'checkpoint_*')),
[os.path.join(config.root_dir, 'checkpoint_00000004')])

def test_periodic_eval_fn_map(self):
@@ -159,7 +159,7 @@
})
self.assertEqual(state, (5, 15))
self.assertCountEqual(
tf.io.gfile.glob(os.path.join(config.root_dir, '*eval*')), [
glob.glob(os.path.join(config.root_dir, '*eval*')), [
os.path.join(config.root_dir, 'test_eval'),
os.path.join(config.root_dir, 'train_eval')
])
@@ -180,7 +180,7 @@ def test_final_eval_fn_map(self):
})
self.assertEqual(state, (5, 15))
self.assertCountEqual(
tf.io.gfile.glob(os.path.join(config.root_dir, 'final_eval.tsv')),
glob.glob(os.path.join(config.root_dir, 'final_eval.tsv')),
[os.path.join(config.root_dir, 'final_eval.tsv')])


4 changes: 2 additions & 2 deletions fedjax/training/logging.py
@@ -18,8 +18,7 @@
from typing import Any, Optional

from absl import logging

import tensorflow as tf
from fedjax.core import util


class Logger:
@@ -44,6 +43,7 @@ def log(self, writer_name: str, metric_name: str, metric_value: Any,
metric_value: Value of metric to log.
round_num: Round number to log.
"""
tf = util.import_tf()
logging.info('round %d %s: %s = %s', round_num, writer_name, metric_name,
metric_value)

9 changes: 5 additions & 4 deletions fedjax/training/logging_test.py
@@ -13,12 +13,13 @@
# limitations under the License.
"""Tests for fedjax.training.logging."""

import tensorflow as tf
import os

from absl.testing import absltest
from fedjax.training import logging


class LoggingTest(tf.test.TestCase):
class LoggingTest(absltest.TestCase):

def test_log_no_root_dir(self):
logger = logging.Logger()
@@ -37,8 +38,8 @@ def test_log_root_dir(self):
logger.log(
writer_name='eval', metric_name='loss', metric_value=5.3, round_num=0)

self.assertCountEqual(['train', 'eval'], tf.io.gfile.listdir(root_dir))
self.assertCountEqual(['train', 'eval'], os.listdir(root_dir))


if __name__ == '__main__':
tf.test.main()
absltest.main()
1 change: 0 additions & 1 deletion setup.py
@@ -41,7 +41,6 @@
'msgpack',
'optax',
'requests',
'tensorflow',
],
python_requires='>=3.6',
classifiers=[