diff --git a/examples/vae/README.md b/examples/vae/README.md
index 99c3f57b0a..325aba8ff1 100644
--- a/examples/vae/README.md
+++ b/examples/vae/README.md
@@ -5,9 +5,24 @@ This code follows [pytorch/examples/vae](https://github.com/pytorch/examples/blo
 
 ```bash
 pip install -r requirements.txt
-python main.py
+python main.py --workdir=/tmp/mnist --config=configs/default.py
 ```
 
+## Overriding hyperparameter configurations
+
+This VAE example allows specifying a hyperparameter configuration by means of
+the `--config` flag. The configuration flag is defined using
+[config_flags](https://github.com/google/ml_collections/tree/master#config-flags).
+`config_flags` allows overriding configuration fields, which can be done as
+follows:
+
+```shell
+python main.py \
+--workdir=/tmp/mnist --config=configs/default.py \
+--config.learning_rate=0.01 --config.num_epochs=10
+```
+
+
 ## Examples
 
 If you run the code by above command, you can get some generated images:
diff --git a/examples/vae/configs/default.py b/examples/vae/configs/default.py
new file mode 100644
index 0000000000..d96eca220e
--- /dev/null
+++ b/examples/vae/configs/default.py
@@ -0,0 +1,28 @@
+# Copyright 2023 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Default Hyperparameter configuration."""
+
+import ml_collections
+
+
+def get_config():
+  """Get the default hyperparameter configuration."""
+  config = ml_collections.ConfigDict()
+
+  config.learning_rate = 0.001
+  config.latents = 20
+  config.batch_size = 128
+  config.num_epochs = 30
+  return config
diff --git a/examples/vae/main.py b/examples/vae/main.py
index c1c9bbd3f1..7a708feab7 100644
--- a/examples/vae/main.py
+++ b/examples/vae/main.py
@@ -20,7 +20,10 @@
 from absl import app
 from absl import flags
+from absl import logging
+from clu import platform
 import jax
+from ml_collections import config_flags
 import tensorflow as tf
 
 import train
@@ -28,20 +31,13 @@
 FLAGS = flags.FLAGS
 
-flags.DEFINE_float(
-    'learning_rate',
-    default=1e-3,
-    help='The learning rate for the Adam optimizer.',
+config_flags.DEFINE_config_file(
+    'config',
+    None,
+    'File path to the training hyperparameter configuration.',
+    lock_config=True,
 )
 
-flags.DEFINE_integer('batch_size', default=128, help='Batch size for training.')
-
-flags.DEFINE_integer(
-    'num_epochs', default=30, help='Number of training epochs.'
-)
-
-flags.DEFINE_integer('latents', default=20, help='Number of latent variables.')
-
 
 def main(argv):
   if len(argv) > 1:
@@ -50,10 +46,18 @@ def main(argv):
   # Make sure tf does not allocate gpu memory.
   tf.config.experimental.set_visible_devices([], 'GPU')
 
-  train.train_and_evaluate(
-      FLAGS.batch_size, FLAGS.learning_rate, FLAGS.num_epochs, FLAGS.latents
+  logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
+  logging.info('JAX local devices: %r', jax.local_devices())
+
+  # Add a note so that we can tell which task is which JAX host.
+  # (Depending on the platform task 0 is not guaranteed to be host 0)
+  platform.work_unit().set_task_status(
+      f'process_index: {jax.process_index()}, '
+      f'process_count: {jax.process_count()}'
   )
 
+  train.train_and_evaluate(FLAGS.config)
+
 
 if __name__ == '__main__':
   app.run(main)
diff --git a/examples/vae/train.py b/examples/vae/train.py
index 6404f07cd5..f8f0e618eb 100644
--- a/examples/vae/train.py
+++ b/examples/vae/train.py
@@ -27,19 +27,19 @@ # limitations under the License.
 
 """Traininga and evaluation logic."""
 
-from absl import logging
-
-from flax import linen as nn
-from flax.training import train_state
+import input_pipeline
 import jax
-from jax import random
 import jax.numpy as jnp
+import ml_collections
+import models
 import optax
 import tensorflow_datasets as tfds
-
-import input_pipeline
-import models
 import utils as vae_utils
+from absl import logging
+from jax import random
+
+from flax import linen as nn
+from flax.training import train_state
 
 
 @jax.vmap
@@ -92,7 +92,7 @@ def eval_model(vae):
   return nn.apply(eval_model, models.model(latents))({'params': params})
 
 
-def train_and_evaluate(batch_size, learning_rate, num_epochs, latents):
+def train_and_evaluate(config: ml_collections.ConfigDict):
   """Train and evaulate pipeline."""
   rng = random.PRNGKey(0)
   rng, key = random.split(rng)
@@ -101,32 +101,34 @@ def train_and_evaluate(batch_size, learning_rate, num_epochs, latents):
   ds_builder.download_and_prepare()
 
   logging.info('Initializing dataset.')
-  train_ds = input_pipeline.build_train_set(batch_size, ds_builder)
+  train_ds = input_pipeline.build_train_set(config.batch_size, ds_builder)
   test_ds = input_pipeline.build_test_set(ds_builder)
 
   logging.info('Initializing model.')
-  init_data = jnp.ones((batch_size, 784), jnp.float32)
-  params = models.model(latents).init(key, init_data, rng)['params']
+  init_data = jnp.ones((config.batch_size, 784), jnp.float32)
+  params = models.model(config.latents).init(key, init_data, rng)['params']
 
   state = train_state.TrainState.create(
-      apply_fn=models.model(latents).apply,
+      apply_fn=models.model(config.latents).apply,
       params=params,
-      tx=optax.adam(learning_rate),
+      tx=optax.adam(config.learning_rate),
   )
 
   rng, z_key, eval_rng = random.split(rng, 3)
-  z = random.normal(z_key, (64, latents))
+  z = random.normal(z_key, (64, config.latents))
 
-  steps_per_epoch = ds_builder.info.splits['train'].num_examples // batch_size
+  steps_per_epoch = (
+      ds_builder.info.splits['train'].num_examples // config.batch_size
+  )
 
-  for epoch in range(num_epochs):
+  for epoch in range(config.num_epochs):
     for _ in range(steps_per_epoch):
       batch = next(train_ds)
       rng, key = random.split(rng)
-      state = train_step(state, batch, key, latents)
+      state = train_step(state, batch, key, config.latents)
 
     metrics, comparison, sample = eval_f(
-        state.params, test_ds, z, eval_rng, latents
+        state.params, test_ds, z, eval_rng, config.latents
     )
     vae_utils.save_image(
         comparison, f'results/reconstruction_{epoch}.png', nrow=8