feat: add configs for vae example
Signed-off-by: Saurav Maheshkar <sauravvmaheshkar@gmail.com>
SauravMaheshkar committed Aug 2, 2023
1 parent 8da8c46 commit 159a3aa
Showing 4 changed files with 87 additions and 39 deletions.
17 changes: 16 additions & 1 deletion examples/vae/README.md
@@ -5,9 +5,24 @@ This code follows [pytorch/examples/vae](https://github.com/pytorch/examples/blo

```bash
pip install -r requirements.txt
python main.py
python main.py --workdir=/tmp/mnist --config=configs/default.py
```

## Overriding Hyperparameter configurations

This VAE example allows specifying a hyperparameter configuration by setting the
`--config` flag. The configuration flag is defined using
[config_flags](https://github.com/google/ml_collections/tree/master#config-flags),
which also allows overriding individual configuration fields. This can be done as
follows:

```shell
python main.py \
--workdir=/tmp/mnist --config=configs/default.py \
--config.learning_rate=0.01 --config.num_epochs=10
```


## Examples

If you run the code with the above command, you get some generated images:
28 changes: 28 additions & 0 deletions examples/vae/configs/default.py
@@ -0,0 +1,28 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Default Hyperparameter configuration."""

import ml_collections


def get_config():
"""Get the default hyperparameter configuration."""
config = ml_collections.ConfigDict()

config.learning_rate = 0.001
config.latents = 20
config.batch_size = 128
config.num_epochs = 30
return config
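
For reference, a minimal sketch (not part of this commit) of how a config defined this way can be loaded and overridden programmatically, e.g. from a notebook, instead of via the `--config.<field>` flags; the `configs.default` import path assumes `examples/vae` is the working directory:

```python
# Hypothetical usage sketch: load the default config and override fields
# directly, mirroring what --config.learning_rate=0.01 does on the command line.
from configs import default

config = default.get_config()   # returns an ml_collections.ConfigDict
config.learning_rate = 0.01     # same effect as --config.learning_rate=0.01
config.num_epochs = 10          # same effect as --config.num_epochs=10
print(config)                   # ConfigDict prints as YAML-like key: value pairs
```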
40 changes: 22 additions & 18 deletions examples/vae/main.py
@@ -18,30 +18,23 @@
that can be easily tested and imported in Colab.
"""

from absl import app
from absl import flags
import jax
import tensorflow as tf

import train

from absl import app, flags, logging
from clu import platform
from ml_collections import config_flags

FLAGS = flags.FLAGS

flags.DEFINE_float(
'learning_rate',
default=1e-3,
help='The learning rate for the Adam optimizer.',
)

flags.DEFINE_integer('batch_size', default=128, help='Batch size for training.')

flags.DEFINE_integer(
'num_epochs', default=30, help='Number of training epochs.'
flags.DEFINE_string('workdir', None, 'Directory to store model data.')
config_flags.DEFINE_config_file(
'config',
None,
'File path to the training hyperparameter configuration.',
lock_config=True,
)

flags.DEFINE_integer('latents', default=20, help='Number of latent variables.')


def main(argv):
if len(argv) > 1:
@@ -50,9 +43,20 @@ def main(argv):
# Make sure tf does not allocate gpu memory.
tf.config.experimental.set_visible_devices([], 'GPU')

train.train_and_evaluate(
FLAGS.batch_size, FLAGS.learning_rate, FLAGS.num_epochs, FLAGS.latents
logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
logging.info('JAX local devices: %r', jax.local_devices())

# Add a note so that we can tell which task is which JAX host.
# (Depending on the platform task 0 is not guaranteed to be host 0)
platform.work_unit().set_task_status(
f'process_index: {jax.process_index()}, '
f'process_count: {jax.process_count()}'
)
platform.work_unit().create_artifact(
platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
)

train.train_and_evaluate(FLAGS.config, FLAGS.workdir)


if __name__ == '__main__':
41 changes: 21 additions & 20 deletions examples/vae/train.py
@@ -27,19 +27,19 @@
# limitations under the License.
"""Traininga and evaluation logic."""

from absl import logging

from flax import linen as nn
from flax.training import train_state
import input_pipeline
import jax
from jax import random
import jax.numpy as jnp
import ml_collections
import models
import optax
import tensorflow_datasets as tfds

import input_pipeline
import models
import utils as vae_utils
from absl import logging
from jax import random

from flax import linen as nn
from flax.training import train_state


@jax.vmap
@@ -92,7 +92,8 @@ def eval_model(vae):
return nn.apply(eval_model, models.model(latents))({'params': params})


def train_and_evaluate(batch_size, learning_rate, num_epochs, latents):
def train_and_evaluate(
config: ml_collections.ConfigDict, workdir: str):
"""Train and evaulate pipeline."""
rng = random.PRNGKey(0)
rng, key = random.split(rng)
@@ -101,32 +102,32 @@ def train_and_evaluate(batch_size, learning_rate, num_epochs, latents):
ds_builder.download_and_prepare()

logging.info('Initializing dataset.')
train_ds = input_pipeline.build_train_set(batch_size, ds_builder)
test_ds = input_pipeline.build_test_set(ds_builder)
train_ds = input_pipeline.build_train_set(config.batch_size, ds_builder)
test_ds = input_pipeline.build_test_set(ds_builder)

logging.info('Initializing model.')
init_data = jnp.ones((batch_size, 784), jnp.float32)
params = models.model(latents).init(key, init_data, rng)['params']
init_data = jnp.ones((config.batch_size, 784), jnp.float32)
params = models.model(config.latents).init(key, init_data, rng)['params']

state = train_state.TrainState.create(
apply_fn=models.model(latents).apply,
apply_fn=models.model(config.latents).apply,
params=params,
tx=optax.adam(learning_rate),
tx=optax.adam(config.learning_rate),
)

rng, z_key, eval_rng = random.split(rng, 3)
z = random.normal(z_key, (64, latents))
z = random.normal(z_key, (64, config.latents))

steps_per_epoch = ds_builder.info.splits['train'].num_examples // batch_size
steps_per_epoch = ds_builder.info.splits['train'].num_examples // config.batch_size

for epoch in range(num_epochs):
for epoch in range(config.num_epochs):
for _ in range(steps_per_epoch):
batch = next(train_ds)
rng, key = random.split(rng)
state = train_step(state, batch, key, latents)
state = train_step(state, batch, key, config.latents)

metrics, comparison, sample = eval_f(
state.params, test_ds, z, eval_rng, latents
state.params, test_ds, z, eval_rng, config.latents
)
vae_utils.save_image(
comparison, f'results/reconstruction_{epoch}.png', nrow=8
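
Taken together, the refactored training pipeline can also be called directly, e.g. from a Colab cell. A minimal sketch, assuming the example's modules (`train`, `configs.default`) are importable from `examples/vae`:

```python
# Hypothetical smoke run of the new train_and_evaluate signature; not part of
# this commit. Downloads the dataset via TFDS on first use.
import train
from configs import default

config = default.get_config()
config.num_epochs = 1   # keep the run short for a quick check
train.train_and_evaluate(config, workdir='/tmp/mnist')
```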
