From 159a3aa748d67fc58f0eacfa73c05a7144738b3a Mon Sep 17 00:00:00 2001
From: Saurav Maheshkar
Date: Wed, 2 Aug 2023 12:16:57 +0530
Subject: [PATCH 1/2] feat: add configs for vae example

Signed-off-by: Saurav Maheshkar
---
 examples/vae/README.md          | 17 +++++++++++++-
 examples/vae/configs/default.py | 28 ++++++++++++++++++++++
 examples/vae/main.py            | 40 +++++++++++++++++---------------
 examples/vae/train.py           | 41 +++++++++++++++++----------------
 4 files changed, 87 insertions(+), 39 deletions(-)
 create mode 100644 examples/vae/configs/default.py

diff --git a/examples/vae/README.md b/examples/vae/README.md
index 99c3f57b0a..325aba8ff1 100644
--- a/examples/vae/README.md
+++ b/examples/vae/README.md
@@ -5,9 +5,24 @@ This code follows [pytorch/examples/vae](https://github.com/pytorch/examples/blo
 
 ```bash
 pip install -r requirements.txt
-python main.py
+python main.py --workdir=/tmp/mnist --config=configs/default.py
 ```
 
+## Overriding Hyperparameter configurations
+
+This VAE example allows specifying a hyperparameter configuration by means of
+the `--config` flag. The configuration flag is defined using
+[config_flags](https://github.com/google/ml_collections/tree/master#config-flags).
+`config_flags` allows overriding individual configuration fields. This can be
+done as follows:
+
+```shell
+python main.py \
+--workdir=/tmp/mnist --config=configs/default.py \
+--config.learning_rate=0.01 --config.num_epochs=10
+```
+
+
 ## Examples
 
 If you run the code by above command, you can get some generated images:
diff --git a/examples/vae/configs/default.py b/examples/vae/configs/default.py
new file mode 100644
index 0000000000..d96eca220e
--- /dev/null
+++ b/examples/vae/configs/default.py
@@ -0,0 +1,28 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Default Hyperparameter configuration."""
+
+import ml_collections
+
+
+def get_config():
+  """Get the default hyperparameter configuration."""
+  config = ml_collections.ConfigDict()
+
+  config.learning_rate = 0.001
+  config.latents = 20
+  config.batch_size = 128
+  config.num_epochs = 30
+  return config
diff --git a/examples/vae/main.py b/examples/vae/main.py
index c1c9bbd3f1..52fb2635a5 100644
--- a/examples/vae/main.py
+++ b/examples/vae/main.py
@@ -18,30 +18,23 @@ that can be easily tested and imported in Colab.
 """
 
-from absl import app
-from absl import flags
 import jax
 import tensorflow as tf
-
 import train
-
+from absl import app, flags, logging
+from clu import platform
+from ml_collections import config_flags
 
 FLAGS = flags.FLAGS
 
-flags.DEFINE_float(
-    'learning_rate',
-    default=1e-3,
-    help='The learning rate for the Adam optimizer.',
-)
-
-flags.DEFINE_integer('batch_size', default=128, help='Batch size for training.')
-
-flags.DEFINE_integer(
-    'num_epochs', default=30, help='Number of training epochs.'
+flags.DEFINE_string('workdir', None, 'Directory to store model data.') +config_flags.DEFINE_config_file( + 'config', + None, + 'File path to the training hyperparameter configuration.', + lock_config=True, ) -flags.DEFINE_integer('latents', default=20, help='Number of latent variables.') - def main(argv): if len(argv) > 1: @@ -50,9 +43,20 @@ def main(argv): # Make sure tf does not allocate gpu memory. tf.config.experimental.set_visible_devices([], 'GPU') - train.train_and_evaluate( - FLAGS.batch_size, FLAGS.learning_rate, FLAGS.num_epochs, FLAGS.latents + logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) + logging.info('JAX local devices: %r', jax.local_devices()) + + # Add a note so that we can tell which task is which JAX host. + # (Depending on the platform task 0 is not guaranteed to be host 0) + platform.work_unit().set_task_status( + f'process_index: {jax.process_index()}, ' + f'process_count: {jax.process_count()}' ) + platform.work_unit().create_artifact( + platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' + ) + + train.train_and_evaluate(FLAGS.config, FLAGS.workdir) if __name__ == '__main__': diff --git a/examples/vae/train.py b/examples/vae/train.py index 6404f07cd5..6a48a5be5b 100644 --- a/examples/vae/train.py +++ b/examples/vae/train.py @@ -27,19 +27,19 @@ # limitations under the License. """Traininga and evaluation logic.""" -from absl import logging - -from flax import linen as nn -from flax.training import train_state +import input_pipeline import jax -from jax import random import jax.numpy as jnp +import ml_collections +import models import optax import tensorflow_datasets as tfds - -import input_pipeline -import models import utils as vae_utils +from absl import logging +from jax import random + +from flax import linen as nn +from flax.training import train_state @jax.vmap @@ -92,7 +92,8 @@ def eval_model(vae): return nn.apply(eval_model, models.model(latents))({'params': params}) -def train_and_evaluate(batch_size, learning_rate, num_epochs, latents): +def train_and_evaluate( + config: ml_collections.ConfigDict, workdir: str): """Train and evaulate pipeline.""" rng = random.PRNGKey(0) rng, key = random.split(rng) @@ -101,32 +102,32 @@ def train_and_evaluate(batch_size, learning_rate, num_epochs, latents): ds_builder.download_and_prepare() logging.info('Initializing dataset.') - train_ds = input_pipeline.build_train_set(batch_size, ds_builder) - test_ds = input_pipeline.build_test_set(ds_builder) + train_ds = input_pipeline.build_train_set(config.batch_size, ds_builder) + test_ds = input_pipeline.build_test_set(config.ds_builder) logging.info('Initializing model.') - init_data = jnp.ones((batch_size, 784), jnp.float32) - params = models.model(latents).init(key, init_data, rng)['params'] + init_data = jnp.ones((config.batch_size, 784), jnp.float32) + params = models.model(config.latents).init(key, init_data, rng)['params'] state = train_state.TrainState.create( - apply_fn=models.model(latents).apply, + apply_fn=models.model(config.latents).apply, params=params, - tx=optax.adam(learning_rate), + tx=optax.adam(config.learning_rate), ) rng, z_key, eval_rng = random.split(rng, 3) - z = random.normal(z_key, (64, latents)) + z = random.normal(z_key, (64, config.latents)) - steps_per_epoch = ds_builder.info.splits['train'].num_examples // batch_size + steps_per_epoch = ds_builder.info.splits['train'].num_examples // config.batch_size - for epoch in range(num_epochs): + for epoch in range(config.num_epochs): for _ in range(steps_per_epoch): 
batch = next(train_ds) rng, key = random.split(rng) - state = train_step(state, batch, key, latents) + state = train_step(state, batch, key, config.latents) metrics, comparison, sample = eval_f( - state.params, test_ds, z, eval_rng, latents + state.params, test_ds, z, eval_rng, config.latents ) vae_utils.save_image( comparison, f'results/reconstruction_{epoch}.png', nrow=8 From db7b1762f8abde1bc32ed2e47a75a1da8d365616 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Wed, 9 Aug 2023 11:47:55 +0530 Subject: [PATCH 2/2] style: pyink reformat Signed-off-by: Saurav Maheshkar --- examples/vae/main.py | 16 ++++++++-------- examples/vae/train.py | 9 +++++---- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/examples/vae/main.py b/examples/vae/main.py index 52fb2635a5..7a708feab7 100644 --- a/examples/vae/main.py +++ b/examples/vae/main.py @@ -18,16 +18,19 @@ that can be easily tested and imported in Colab. """ +from absl import app +from absl import flags +from absl import logging +from clu import platform import jax +from ml_collections import config_flags import tensorflow as tf + import train -from absl import app, flags, logging -from clu import platform -from ml_collections import config_flags + FLAGS = flags.FLAGS -flags.DEFINE_string('workdir', None, 'Directory to store model data.') config_flags.DEFINE_config_file( 'config', None, @@ -52,11 +55,8 @@ def main(argv): f'process_index: {jax.process_index()}, ' f'process_count: {jax.process_count()}' ) - platform.work_unit().create_artifact( - platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir' - ) - train.train_and_evaluate(FLAGS.config, FLAGS.workdir) + train.train_and_evaluate(FLAGS.config) if __name__ == '__main__': diff --git a/examples/vae/train.py b/examples/vae/train.py index 6a48a5be5b..f8f0e618eb 100644 --- a/examples/vae/train.py +++ b/examples/vae/train.py @@ -92,8 +92,7 @@ def eval_model(vae): return nn.apply(eval_model, models.model(latents))({'params': params}) -def train_and_evaluate( - config: ml_collections.ConfigDict, workdir: str): +def train_and_evaluate(config: ml_collections.ConfigDict): """Train and evaulate pipeline.""" rng = random.PRNGKey(0) rng, key = random.split(rng) @@ -103,7 +102,7 @@ def train_and_evaluate( logging.info('Initializing dataset.') train_ds = input_pipeline.build_train_set(config.batch_size, ds_builder) - test_ds = input_pipeline.build_test_set(config.ds_builder) + test_ds = input_pipeline.build_test_set(ds_builder) logging.info('Initializing model.') init_data = jnp.ones((config.batch_size, 784), jnp.float32) @@ -118,7 +117,9 @@ def train_and_evaluate( rng, z_key, eval_rng = random.split(rng, 3) z = random.normal(z_key, (64, config.latents)) - steps_per_epoch = ds_builder.info.splits['train'].num_examples // config.batch_size + steps_per_epoch = ( + ds_builder.info.splits['train'].num_examples // config.batch_size + ) for epoch in range(config.num_epochs): for _ in range(steps_per_epoch):
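
For reference, here is a minimal sketch (not part of the patches above) of how the new configuration plumbing is intended to be used once both patches are applied. It assumes the example's module layout, i.e. that `configs/default.py` exposes `get_config()` and that `train.py` exposes `train_and_evaluate(config)` with the signature from PATCH 2/2; the override values are purely illustrative.

```python
# Run from within examples/vae/ so that `configs` and `train` are importable.
from configs import default as default_config
import train

# Build the default ml_collections.ConfigDict declared in configs/default.py.
config = default_config.get_config()

# Override individual fields programmatically; this mirrors passing
# --config.learning_rate=0.01 --config.num_epochs=10 on the command line.
config.learning_rate = 0.01
config.num_epochs = 10

# Train and evaluate with the overridden configuration.
train.train_and_evaluate(config)
```

The command-line path shown in the README goes through `config_flags.DEFINE_config_file`, which loads the same `get_config()` and applies any `--config.<field>=<value>` overrides during flag parsing, before `main()` hands `FLAGS.config` to `train.train_and_evaluate`.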