Merge pull request #3254 from SauravMaheshkar:saurav/vae
PiperOrigin-RevId: 555925898
Flax Authors committed Aug 11, 2023
2 parents 5030149 + db7b176 commit d05aeab
Showing 4 changed files with 83 additions and 34 deletions.
17 changes: 16 additions & 1 deletion examples/vae/README.md
@@ -5,9 +5,24 @@ This code follows [pytorch/examples/vae](https://github.com/pytorch/examples/blo

```bash
pip install -r requirements.txt
python main.py
python main.py --workdir=/tmp/mnist --config=configs/default.py
```

## Overriding Hyperparameter configurations

This VAE example allows specifying a hyperparameter configuration by setting the
`--config` flag. The configuration flag is defined using
[config_flags](https://github.com/google/ml_collections/tree/master#config-flags).
`config_flags` also allows overriding individual configuration fields. This can be
done as follows:

```shell
python main.py \
--workdir=/tmp/mnist --config=configs/default.py \
--config.learning_rate=0.01 --config.num_epochs=10
```
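
Conceptually, each `--config.<field>=<value>` override simply assigns to the
corresponding field of the `ConfigDict` returned by `get_config()`. A minimal
illustration (a sketch using only `ml_collections`, not part of the example's code):

```python
# Hypothetical illustration: a --config.<field>=<value> override is equivalent
# to assigning the field on the ConfigDict returned by get_config().
import ml_collections

config = ml_collections.ConfigDict(dict(learning_rate=1e-3, num_epochs=30))
config.learning_rate = 0.01  # what --config.learning_rate=0.01 does
config.num_epochs = 10       # what --config.num_epochs=10 does
```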


## Examples

If you run the code with the above command, you will get some generated images:
28 changes: 28 additions & 0 deletions examples/vae/configs/default.py
@@ -0,0 +1,28 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Default Hyperparameter configuration."""

import ml_collections


def get_config():
"""Get the default hyperparameter configuration."""
config = ml_collections.ConfigDict()

config.learning_rate = 0.001
config.latents = 20
config.batch_size = 128
config.num_epochs = 30
return config
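
As a usage note, the `ConfigDict` built here can also be tweaked programmatically and
passed straight to the refactored entry point. A minimal smoke-test sketch, assuming
`examples/vae/` is the working directory so that `train` and `configs.default` are
importable:

```python
# Hypothetical smoke-test driver; function names come from this diff, the
# import paths are assumptions about the working directory.
from configs import default
import train

config = default.get_config()
config.num_epochs = 1   # shorten the run for a quick check
config.batch_size = 32
train.train_and_evaluate(config)
```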
32 changes: 18 additions & 14 deletions examples/vae/main.py
@@ -20,28 +20,24 @@

from absl import app
from absl import flags
from absl import logging
from clu import platform
import jax
from ml_collections import config_flags
import tensorflow as tf

import train


FLAGS = flags.FLAGS

flags.DEFINE_float(
'learning_rate',
default=1e-3,
help='The learning rate for the Adam optimizer.',
config_flags.DEFINE_config_file(
'config',
None,
'File path to the training hyperparameter configuration.',
lock_config=True,
)

flags.DEFINE_integer('batch_size', default=128, help='Batch size for training.')

flags.DEFINE_integer(
'num_epochs', default=30, help='Number of training epochs.'
)

flags.DEFINE_integer('latents', default=20, help='Number of latent variables.')


def main(argv):
if len(argv) > 1:
@@ -50,10 +46,18 @@ def main(argv):
# Make sure tf does not allocate gpu memory.
tf.config.experimental.set_visible_devices([], 'GPU')

train.train_and_evaluate(
FLAGS.batch_size, FLAGS.learning_rate, FLAGS.num_epochs, FLAGS.latents
logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
logging.info('JAX local devices: %r', jax.local_devices())

# Add a note so that we can tell which task is which JAX host.
# (Depending on the platform task 0 is not guaranteed to be host 0)
platform.work_unit().set_task_status(
f'process_index: {jax.process_index()}, '
f'process_count: {jax.process_count()}'
)

train.train_and_evaluate(FLAGS.config)


if __name__ == '__main__':
app.run(main)
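
One behavioural note on `lock_config=True` above (a sketch of what it is assumed to do):
the loaded `ConfigDict` is locked, so command-line overrides may change existing fields
but cannot introduce new ones.

```python
# Standalone illustration of a locked ConfigDict; not part of the example itself.
import ml_collections

config = ml_collections.ConfigDict({'learning_rate': 1e-3})
config.lock()
config.learning_rate = 1e-2  # fine: the field already exists
# config.dropout = 0.1       # would fail: new fields cannot be added while locked
```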
40 changes: 21 additions & 19 deletions examples/vae/train.py
@@ -27,19 +27,19 @@
# limitations under the License.
"""Traininga and evaluation logic."""

from absl import logging

from flax import linen as nn
from flax.training import train_state
import input_pipeline
import jax
from jax import random
import jax.numpy as jnp
import ml_collections
import models
import optax
import tensorflow_datasets as tfds

import input_pipeline
import models
import utils as vae_utils
from absl import logging
from jax import random

from flax import linen as nn
from flax.training import train_state


@jax.vmap
@@ -92,7 +92,7 @@ def eval_model(vae):
return nn.apply(eval_model, models.model(latents))({'params': params})


def train_and_evaluate(batch_size, learning_rate, num_epochs, latents):
def train_and_evaluate(config: ml_collections.ConfigDict):
"""Train and evaulate pipeline."""
rng = random.PRNGKey(0)
rng, key = random.split(rng)
@@ -101,32 +101,34 @@ def train_and_evaluate(batch_size, learning_rate, num_epochs, latents):
ds_builder.download_and_prepare()

logging.info('Initializing dataset.')
train_ds = input_pipeline.build_train_set(batch_size, ds_builder)
train_ds = input_pipeline.build_train_set(config.batch_size, ds_builder)
test_ds = input_pipeline.build_test_set(ds_builder)

logging.info('Initializing model.')
init_data = jnp.ones((batch_size, 784), jnp.float32)
params = models.model(latents).init(key, init_data, rng)['params']
init_data = jnp.ones((config.batch_size, 784), jnp.float32)
params = models.model(config.latents).init(key, init_data, rng)['params']

state = train_state.TrainState.create(
apply_fn=models.model(latents).apply,
apply_fn=models.model(config.latents).apply,
params=params,
tx=optax.adam(learning_rate),
tx=optax.adam(config.learning_rate),
)

rng, z_key, eval_rng = random.split(rng, 3)
z = random.normal(z_key, (64, latents))
z = random.normal(z_key, (64, config.latents))

steps_per_epoch = ds_builder.info.splits['train'].num_examples // batch_size
steps_per_epoch = (
ds_builder.info.splits['train'].num_examples // config.batch_size
)

for epoch in range(num_epochs):
for epoch in range(config.num_epochs):
for _ in range(steps_per_epoch):
batch = next(train_ds)
rng, key = random.split(rng)
state = train_step(state, batch, key, latents)
state = train_step(state, batch, key, config.latents)

metrics, comparison, sample = eval_f(
state.params, test_ds, z, eval_rng, latents
state.params, test_ds, z, eval_rng, config.latents
)
vae_utils.save_image(
comparison, f'results/reconstruction_{epoch}.png', nrow=8
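
For readers new to JAX's PRNG handling, the `rng, key = random.split(rng)` call in the
training loop above follows the usual pattern of carrying one key forward and deriving a
fresh subkey per step; a standalone sketch:

```python
# Standalone illustration of per-step PRNG key splitting, as used in the loop above.
from jax import random

rng = random.PRNGKey(0)
for step in range(3):
  rng, key = random.split(rng)      # carry rng forward, use key for this step
  noise = random.normal(key, (4,))  # each step sees different random values
  print(step, noise)
```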
