feat: add configs for vae example #3254

Merged (2 commits) on Aug 11, 2023
Changes from 1 commit
17 changes: 16 additions & 1 deletion examples/vae/README.md
@@ -5,9 +5,24 @@ This code follows [pytorch/examples/vae](https://github.com/pytorch/examples/blo

```bash
pip install -r requirements.txt
-python main.py
+python main.py --workdir=/tmp/mnist --config=configs/default.py
```

## Overriding hyperparameter configurations

This VAE example allows you to specify a hyperparameter configuration with the
`--config` flag. The configuration flag is defined using
[config_flags](https://github.com/google/ml_collections/tree/master#config-flags),
which also lets individual configuration fields be overridden on the command
line. This can be done as follows:

```shell
python main.py \
--workdir=/tmp/mnist --config=configs/default.py \
--config.learning_rate=0.01 --config.num_epochs=10
```
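Configuration files are plain Python modules whose `get_config()` returns an
`ml_collections.ConfigDict`, so an experiment can ship its own file and be
selected with `--config=configs/<name>.py`. A minimal sketch of such a file,
using a hypothetical `configs/smoke_test.py` purely for illustration:

```python
# configs/smoke_test.py (hypothetical): a smaller configuration for a
# quick local run; it mirrors the fields defined in configs/default.py.
import ml_collections


def get_config():
  """Hyperparameters for a short smoke-test run."""
  config = ml_collections.ConfigDict()
  config.learning_rate = 1e-3
  config.latents = 20
  config.batch_size = 32  # smaller batch than the default 128
  config.num_epochs = 1   # single epoch instead of 30
  return config
```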


## Examples

If you run the code with the above command, you will get some generated images:
28 changes: 28 additions & 0 deletions examples/vae/configs/default.py
@@ -0,0 +1,28 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Default Hyperparameter configuration."""

import ml_collections


def get_config():
"""Get the default hyperparameter configuration."""
config = ml_collections.ConfigDict()

config.learning_rate = 0.001
config.latents = 20
config.batch_size = 128
config.num_epochs = 30
return config
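For illustration, here is how the returned `ConfigDict` behaves in code. This
is a hypothetical snippet, not part of the diff, and it assumes
`configs/default.py` is importable as `default`:

```python
# Hypothetical usage sketch for the ConfigDict returned by get_config().
import default

config = default.get_config()
print(config.learning_rate)  # 0.001, attribute-style access

config.learning_rate = 3e-4  # existing fields can be reassigned in code

# lock() mirrors lock_config=True in DEFINE_config_file: once locked,
# adding a brand-new field raises an error, while existing fields can
# still be reassigned.
config.lock()
config.batch_size = 64       # fine: the field already exists
# config.new_field = 1       # would raise an error: config is locked
```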
40 changes: 22 additions & 18 deletions examples/vae/main.py
@@ -18,30 +18,23 @@
that can be easily tested and imported in Colab.
"""

-from absl import app
-from absl import flags
 import jax
 import tensorflow as tf

 import train

+from absl import app, flags, logging
+from clu import platform
+from ml_collections import config_flags

 FLAGS = flags.FLAGS

-flags.DEFINE_float(
-    'learning_rate',
-    default=1e-3,
-    help='The learning rate for the Adam optimizer.',
-)
-
-flags.DEFINE_integer('batch_size', default=128, help='Batch size for training.')
-
-flags.DEFINE_integer(
-    'num_epochs', default=30, help='Number of training epochs.'
+flags.DEFINE_string('workdir', None, 'Directory to store model data.')
+config_flags.DEFINE_config_file(
+    'config',
+    None,
+    'File path to the training hyperparameter configuration.',
+    lock_config=True,
 )
-
-flags.DEFINE_integer('latents', default=20, help='Number of latent variables.')


def main(argv):
if len(argv) > 1:
@@ -50,9 +43,20 @@ def main(argv):
   # Make sure tf does not allocate gpu memory.
   tf.config.experimental.set_visible_devices([], 'GPU')

-  train.train_and_evaluate(
-      FLAGS.batch_size, FLAGS.learning_rate, FLAGS.num_epochs, FLAGS.latents
+  logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
+  logging.info('JAX local devices: %r', jax.local_devices())
+
+  # Add a note so that we can tell which task is which JAX host.
+  # (Depending on the platform task 0 is not guaranteed to be host 0)
+  platform.work_unit().set_task_status(
+      f'process_index: {jax.process_index()}, '
+      f'process_count: {jax.process_count()}'
   )
+  platform.work_unit().create_artifact(
+      platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
+  )
+
+  train.train_and_evaluate(FLAGS.config, FLAGS.workdir)


if __name__ == '__main__':
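For context, the flag wiring above can be exercised on its own. A minimal
sketch, assuming it lives in a hypothetical `demo.py` next to the configs
directory (not part of this PR):

```python
# demo.py (hypothetical): the same workdir/config flag pattern in isolation.
from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS

flags.DEFINE_string('workdir', None, 'Directory to store model data.')
config_flags.DEFINE_config_file(
    'config',
    None,
    'File path to the training hyperparameter configuration.',
    lock_config=True,
)


def main(argv):
  del argv  # unused
  # FLAGS.config is the ConfigDict returned by get_config(), with any
  # --config.<field>=<value> overrides already applied during flag parsing.
  print('workdir:', FLAGS.workdir)
  print('learning_rate:', FLAGS.config.learning_rate)


if __name__ == '__main__':
  # e.g. python demo.py --workdir=/tmp/mnist --config=configs/default.py
  app.run(main)
```

Running it with an extra `--config.learning_rate=0.01` prints the overridden
value, which is the same mechanism `train_and_evaluate` relies on below.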
41 changes: 21 additions & 20 deletions examples/vae/train.py
@@ -27,19 +27,19 @@
# limitations under the License.
"""Traininga and evaluation logic."""

from absl import logging

from flax import linen as nn
from flax.training import train_state
import input_pipeline
import jax
from jax import random
import jax.numpy as jnp
import ml_collections
import models
import optax
import tensorflow_datasets as tfds

import input_pipeline
import models
import utils as vae_utils
from absl import logging
from jax import random

from flax import linen as nn
from flax.training import train_state


@jax.vmap
@@ -92,7 +92,8 @@ def eval_model(vae):
return nn.apply(eval_model, models.model(latents))({'params': params})


-def train_and_evaluate(batch_size, learning_rate, num_epochs, latents):
+def train_and_evaluate(
+    config: ml_collections.ConfigDict, workdir: str):
"""Train and evaulate pipeline."""
rng = random.PRNGKey(0)
rng, key = random.split(rng)
@@ -101,32 +102,32 @@ def train_and_evaluate(batch_size, learning_rate, num_epochs, latents):
   ds_builder.download_and_prepare()

   logging.info('Initializing dataset.')
-  train_ds = input_pipeline.build_train_set(batch_size, ds_builder)
-  test_ds = input_pipeline.build_test_set(ds_builder)
+  train_ds = input_pipeline.build_train_set(config.batch_size, ds_builder)
+  test_ds = input_pipeline.build_test_set(ds_builder)

   logging.info('Initializing model.')
-  init_data = jnp.ones((batch_size, 784), jnp.float32)
-  params = models.model(latents).init(key, init_data, rng)['params']
+  init_data = jnp.ones((config.batch_size, 784), jnp.float32)
+  params = models.model(config.latents).init(key, init_data, rng)['params']

   state = train_state.TrainState.create(
-      apply_fn=models.model(latents).apply,
+      apply_fn=models.model(config.latents).apply,
       params=params,
-      tx=optax.adam(learning_rate),
+      tx=optax.adam(config.learning_rate),
   )

   rng, z_key, eval_rng = random.split(rng, 3)
-  z = random.normal(z_key, (64, latents))
+  z = random.normal(z_key, (64, config.latents))

-  steps_per_epoch = ds_builder.info.splits['train'].num_examples // batch_size
+  steps_per_epoch = ds_builder.info.splits['train'].num_examples // config.batch_size

-  for epoch in range(num_epochs):
+  for epoch in range(config.num_epochs):
     for _ in range(steps_per_epoch):
       batch = next(train_ds)
       rng, key = random.split(rng)
-      state = train_step(state, batch, key, latents)
+      state = train_step(state, batch, key, config.latents)

     metrics, comparison, sample = eval_f(
-        state.params, test_ds, z, eval_rng, latents
+        state.params, test_ds, z, eval_rng, config.latents
     )
     vae_utils.save_image(
         comparison, f'results/reconstruction_{epoch}.png', nrow=8
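To make the new entry point concrete, here is a sketch of calling
`train_and_evaluate(config, workdir)` directly with an in-memory config. This
is hypothetical code, not part of the diff, and it downloads MNIST through
TFDS on first use:

```python
# Hypothetical smoke test for the new train_and_evaluate(config, workdir)
# signature from examples/vae/train.py.
import ml_collections

import train


def run_smoke_test():
  config = ml_collections.ConfigDict()
  config.learning_rate = 1e-3
  config.latents = 20
  config.batch_size = 32
  config.num_epochs = 1  # keep the run short
  # train.py writes sample images to a local 'results/' directory,
  # so that directory should exist before running.
  train.train_and_evaluate(config, workdir='/tmp/vae_smoke_test')


if __name__ == '__main__':
  run_smoke_test()
```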