From 61223394892db3ca63768c4e39eb2caefcc02b65 Mon Sep 17 00:00:00 2001 From: Jae Ro Date: Thu, 28 Oct 2021 10:44:25 -0700 Subject: [PATCH] Delay importing tensorflow instead of directly depending on it in setup This will require users to manually install tensorflow if they happen to be using a part of fedjax that depends on tensorflow (a relatively small subset of the api). PiperOrigin-RevId: 406173096 --- fedjax/core/client_datasets_benchmark.py | 3 +- fedjax/core/serialization.py | 5 +-- fedjax/core/sqlite_federated_data_test.py | 3 +- fedjax/core/util.py | 32 ++++++++++++++++++++ fedjax/datasets/cifar100.py | 3 +- fedjax/datasets/stackoverflow.py | 5 ++- fedjax/training/checkpoint.py | 4 ++- fedjax/training/federated_experiment.py | 4 ++- fedjax/training/federated_experiment_test.py | 10 +++--- fedjax/training/logging.py | 4 +-- fedjax/training/logging_test.py | 9 +++--- setup.py | 1 - 12 files changed, 63 insertions(+), 20 deletions(-) diff --git a/fedjax/core/client_datasets_benchmark.py b/fedjax/core/client_datasets_benchmark.py index 48ff0ea..e6d0d17 100644 --- a/fedjax/core/client_datasets_benchmark.py +++ b/fedjax/core/client_datasets_benchmark.py @@ -28,8 +28,8 @@ from absl import app from fedjax.core import client_datasets +from fedjax.core import util import numpy as np -import tensorflow as tf # pylint: disable=cell-var-from-loop @@ -78,6 +78,7 @@ def bench_client_dataset(preprocess, mode, batch_size=128, num_steps=100): def bench_tf_dataset(preprocess, mode, batch_size=128, num_steps=100): """Benchmarks TF Dataset.""" + tf = util.import_tf() shuffle_buffer = 1000 # size of FAKE_MNIST dataset = tf.data.Dataset.from_tensor_slices(FAKE_MNIST) if mode == 'train': diff --git a/fedjax/core/serialization.py b/fedjax/core/serialization.py index ef41cb4..6848307 100644 --- a/fedjax/core/serialization.py +++ b/fedjax/core/serialization.py @@ -38,15 +38,15 @@ import pickle from absl import logging +from fedjax.core import util import jax import msgpack import numpy as np -# TODO(b/188556866): Remove dependency on TensorFlow. -import tensorflow as tf def save_state(state, path): """Saves state to file path.""" + tf = util.import_tf() logging.info('Saving state to %s.', path) with tf.io.gfile.GFile(path, 'wb') as f: pickle.dump(state, f) @@ -54,6 +54,7 @@ def save_state(state, path): def load_state(path): """Loads saved state from file path.""" + tf = util.import_tf() logging.info('Loading params from %s.', path) with tf.io.gfile.GFile(path, 'rb') as f: return pickle.load(f) diff --git a/fedjax/core/sqlite_federated_data_test.py b/fedjax/core/sqlite_federated_data_test.py index 17310a8..7159743 100644 --- a/fedjax/core/sqlite_federated_data_test.py +++ b/fedjax/core/sqlite_federated_data_test.py @@ -21,10 +21,10 @@ from absl.testing import absltest from fedjax.core import serialization from fedjax.core import sqlite_federated_data +from fedjax.core import util import numpy as np import numpy.testing as npt import sqlite3 -import tensorflow as tf FLAGS = flags.FLAGS @@ -251,6 +251,7 @@ class TFFSQLiteClientsIteratorTest(absltest.TestCase): @classmethod def setUpClass(cls): super().setUpClass() + tf = util.import_tf() path = os.path.join(FLAGS.test_tmpdir, 'test_tff_sqlite_reader.sqlite') split_name = 'train' diff --git a/fedjax/core/util.py b/fedjax/core/util.py index 5aa10b2..3a559e6 100644 --- a/fedjax/core/util.py +++ b/fedjax/core/util.py @@ -22,3 +22,35 @@ def safe_div(a, b): safe = b != 0 c = a / jnp.where(safe, b, 1) return jnp.where(safe, c, 0) + + +# TODO(b/188556866): Remove dependency on TensorFlow. +def import_tf(): + """Imports and returns TensorFlow module if it is installed. + + This is to avoid having FedJAX directly depend on TensorFlow since TensorFlow + is large and complex and the majority of FedJAX functions without it. + + Returns: + TensorFlow module if it is installed. + + Raises: + ModuleNotFoundError: If TensorFlow is not installed along with instructions + on how to install TensorFlow. + """ + try: + import tensorflow # pylint:disable=g-import-not-at-top + return tensorflow + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + """This FedJAX feature requires TensorFlow, but TensorFlow is not installed. + +If you do not otherwise need TensorFlow, we recommend installing the smaller +CPU-only version of TensorFlow via + + pip install tensorflow-cpu + +If you may need to use TensorFlow with GPU, please install the full version via + + pip install tensorflow +""") from e diff --git a/fedjax/datasets/cifar100.py b/fedjax/datasets/cifar100.py index bc1fd75..62f2454 100644 --- a/fedjax/datasets/cifar100.py +++ b/fedjax/datasets/cifar100.py @@ -19,10 +19,10 @@ from fedjax.core import client_datasets from fedjax.core import federated_data from fedjax.core import sqlite_federated_data +from fedjax.core import util from fedjax.datasets import downloads import numpy as np -import tensorflow as tf SPLITS = ('train', 'test') _TFF_SQLITE_COMPRESSED_HEXDIGEST = '23d3916c9caa33395737ee560cc7cb77bbd05fc7b73647ee7be3a7e764172939' @@ -45,6 +45,7 @@ def cite(): def _parse_tf_examples(vs: List[bytes]) -> client_datasets.Examples: + tf = util.import_tf() tf_examples = tf.io.parse_example( vs, features={ diff --git a/fedjax/datasets/stackoverflow.py b/fedjax/datasets/stackoverflow.py index 57321e4..70e597b 100644 --- a/fedjax/datasets/stackoverflow.py +++ b/fedjax/datasets/stackoverflow.py @@ -19,9 +19,9 @@ from fedjax.core import client_datasets from fedjax.core import federated_data from fedjax.core import sqlite_federated_data +from fedjax.core import util from fedjax.datasets import downloads import numpy as np -import tensorflow as tf SPLITS = ('train', 'held_out', 'test') @@ -137,6 +137,7 @@ def preprocess_client( def default_vocab(default_vocab_size) -> List[str]: """Loads the deafult stackoverflow vocabulary.""" + tf = util.import_tf() path = 'gs://gresearch/fedjax/stackoverflow/stackoverflow.word_count' vocab = [] with tf.io.gfile.GFile(path) as f: @@ -176,6 +177,7 @@ def __init__(self, OOV labels starting at `default_vocab_size + 3`. num_oov_buckets: Number of out of vocabulary buckets. """ + tf = util.import_tf() if vocab is None: # Load default vocabulary. vocab = default_vocab(default_vocab_size) @@ -197,6 +199,7 @@ def as_preprocess_batch( Returns: A function that can be used with FederatedData.preprocess_batch(). """ + tf = util.import_tf() @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)]) def token_to_ids(tokens): diff --git a/fedjax/training/checkpoint.py b/fedjax/training/checkpoint.py index 443cc9f..d35745f 100644 --- a/fedjax/training/checkpoint.py +++ b/fedjax/training/checkpoint.py @@ -18,13 +18,14 @@ from typing import Any, List, Optional, Tuple from fedjax.core import serialization -import tensorflow as tf +from fedjax.core import util _CHECKPOINT_PREFIX = 'checkpoint_' def _get_checkpoint_paths(base_path: str) -> List[str]: """Returns all checkpoint paths present.""" + tf = util.import_tf() pattern = base_path + r'[0-9]{8}$' checkpoint_paths = [] for path in tf.io.gfile.glob(base_path + '*'): @@ -53,6 +54,7 @@ def save_checkpoint(root_dir: str, round_num: int = 0, keep: int = 1): """Saves checkpoint and cleans up old checkpoints.""" + tf = util.import_tf() base_path = os.path.join(root_dir, _CHECKPOINT_PREFIX) checkpoint_path = f'{base_path}{round_num:08d}' serialization.save_state(state, checkpoint_path) diff --git a/fedjax/training/federated_experiment.py b/fedjax/training/federated_experiment.py index c091b50..28752b1 100644 --- a/fedjax/training/federated_experiment.py +++ b/fedjax/training/federated_experiment.py @@ -24,10 +24,10 @@ from fedjax.core import federated_algorithm from fedjax.core import federated_data from fedjax.core import models +from fedjax.core import util from fedjax.training import checkpoint from fedjax.training import logging as fedjax_logging import jax.numpy as jnp -import tensorflow as tf def set_tf_cpu_only(): @@ -36,6 +36,7 @@ def set_tf_cpu_only(): TensorFlow is only used for data loading, so we prevent it from allocating GPU/TPU memory. """ + tf = util.import_tf() tf.config.experimental.set_visible_devices([], 'GPU') tf.config.experimental.set_visible_devices([], 'TPU') @@ -177,6 +178,7 @@ def run_federated_experiment( Returns: Final state of the input federated algortihm after training. """ + tf = util.import_tf() if config.root_dir: tf.io.gfile.makedirs(config.root_dir) diff --git a/fedjax/training/federated_experiment_test.py b/fedjax/training/federated_experiment_test.py index 79e1392..3c67fa4 100644 --- a/fedjax/training/federated_experiment_test.py +++ b/fedjax/training/federated_experiment_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for fedjax.legacy.training.federated_experiment.""" +import glob import os.path from absl.testing import absltest @@ -26,7 +27,6 @@ import jax import numpy as np import numpy.testing as npt -import tensorflow as tf class FakeClientSampler(client_samplers.ClientSampler): @@ -123,7 +123,7 @@ def test_checkpoint(self): config=config) self.assertEqual(state, (6, 14)) self.assertCountEqual( - tf.io.gfile.glob(os.path.join(config.root_dir, 'checkpoint_*')), + glob.glob(os.path.join(config.root_dir, 'checkpoint_*')), [os.path.join(config.root_dir, 'checkpoint_00000003')]) with self.subTest('checkpoint restore'): @@ -135,7 +135,7 @@ def test_checkpoint(self): config=config) self.assertEqual(state, (6, 16)) self.assertCountEqual( - tf.io.gfile.glob(os.path.join(config.root_dir, 'checkpoint_*')), + glob.glob(os.path.join(config.root_dir, 'checkpoint_*')), [os.path.join(config.root_dir, 'checkpoint_00000004')]) def test_periodic_eval_fn_map(self): @@ -159,7 +159,7 @@ def test_periodic_eval_fn_map(self): }) self.assertEqual(state, (5, 15)) self.assertCountEqual( - tf.io.gfile.glob(os.path.join(config.root_dir, '*eval*')), [ + glob.glob(os.path.join(config.root_dir, '*eval*')), [ os.path.join(config.root_dir, 'test_eval'), os.path.join(config.root_dir, 'train_eval') ]) @@ -180,7 +180,7 @@ def test_final_eval_fn_map(self): }) self.assertEqual(state, (5, 15)) self.assertCountEqual( - tf.io.gfile.glob(os.path.join(config.root_dir, 'final_eval.tsv')), + glob.glob(os.path.join(config.root_dir, 'final_eval.tsv')), [os.path.join(config.root_dir, 'final_eval.tsv')]) diff --git a/fedjax/training/logging.py b/fedjax/training/logging.py index 73b8c6e..65f7978 100644 --- a/fedjax/training/logging.py +++ b/fedjax/training/logging.py @@ -18,8 +18,7 @@ from typing import Any, Optional from absl import logging - -import tensorflow as tf +from fedjax.core import util class Logger: @@ -44,6 +43,7 @@ def log(self, writer_name: str, metric_name: str, metric_value: Any, metric_value: Value of metric to log. round_num: Round number to log. """ + tf = util.import_tf() logging.info('round %d %s: %s = %s', round_num, writer_name, metric_name, metric_value) diff --git a/fedjax/training/logging_test.py b/fedjax/training/logging_test.py index 83f923d..19b12ae 100644 --- a/fedjax/training/logging_test.py +++ b/fedjax/training/logging_test.py @@ -13,12 +13,13 @@ # limitations under the License. """Tests for fedjax.training.logging.""" -import tensorflow as tf +import os +from absl.testing import absltest from fedjax.training import logging -class LoggingTest(tf.test.TestCase): +class LoggingTest(absltest.TestCase): def test_log_no_root_dir(self): logger = logging.Logger() @@ -37,8 +38,8 @@ def test_log_root_dir(self): logger.log( writer_name='eval', metric_name='loss', metric_value=5.3, round_num=0) - self.assertCountEqual(['train', 'eval'], tf.io.gfile.listdir(root_dir)) + self.assertCountEqual(['train', 'eval'], os.listdir(root_dir)) if __name__ == '__main__': - tf.test.main() + absltest.main() diff --git a/setup.py b/setup.py index 74c59c8..ead7315 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,6 @@ 'msgpack', 'optax', 'requests', - 'tensorflow', ], python_requires='>=3.6', classifiers=[