
Commit

style: pyink reformat
Signed-off-by: Saurav Maheshkar <sauravvmaheshkar@gmail.com>
SauravMaheshkar committed Aug 9, 2023
1 parent 159a3aa commit db7b176
Showing 2 changed files with 13 additions and 12 deletions.
16 changes: 8 additions & 8 deletions examples/vae/main.py
@@ -18,16 +18,19 @@
that can be easily tested and imported in Colab.
"""

from absl import app
from absl import flags
from absl import logging
from clu import platform
import jax
from ml_collections import config_flags
import tensorflow as tf

import train
from absl import app, flags, logging
from clu import platform
from ml_collections import config_flags


FLAGS = flags.FLAGS

flags.DEFINE_string('workdir', None, 'Directory to store model data.')
config_flags.DEFINE_config_file(
'config',
None,
@@ -52,11 +55,8 @@ def main(argv):
f'process_index: {jax.process_index()}, '
f'process_count: {jax.process_count()}'
)
platform.work_unit().create_artifact(
platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
)

train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
train.train_and_evaluate(FLAGS.config)


if __name__ == '__main__':
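Note: with the 'workdir' flag and the second argument to train_and_evaluate gone, the entry point only needs a config. A rough sketch of the slimmed-down main.py after this commit, assuming the collapsed parts of the diff (flag help text, the TensorFlow/GPU setup) match the usual Flax example boilerplate:

    # Sketch only; help strings and the TF device setup are assumptions,
    # not taken verbatim from this diff.
    import jax
    import tensorflow as tf

    import train
    from absl import app, flags, logging
    from clu import platform
    from ml_collections import config_flags

    FLAGS = flags.FLAGS

    # 'config' is now the only flag; the 'workdir' flag has been removed.
    config_flags.DEFINE_config_file(
        'config', None, 'Path to the training configuration file.'
    )


    def main(argv):
      if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

      # Keep TensorFlow off the accelerators so JAX owns the device memory.
      tf.config.experimental.set_visible_devices([], 'GPU')

      logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())
      platform.work_unit().set_task_status(
          f'process_index: {jax.process_index()}, '
          f'process_count: {jax.process_count()}'
      )

      train.train_and_evaluate(FLAGS.config)


    if __name__ == '__main__':
      app.run(main)

Invocation would then look like python main.py --config=path/to/config.py, with no --workdir flag.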
9 changes: 5 additions & 4 deletions examples/vae/train.py
@@ -92,8 +92,7 @@ def eval_model(vae):
return nn.apply(eval_model, models.model(latents))({'params': params})


def train_and_evaluate(
config: ml_collections.ConfigDict, workdir: str):
def train_and_evaluate(config: ml_collections.ConfigDict):
"""Train and evaulate pipeline."""
rng = random.PRNGKey(0)
rng, key = random.split(rng)
@@ -103,7 +102,7 @@ def train_and_evaluate(

logging.info('Initializing dataset.')
train_ds = input_pipeline.build_train_set(config.batch_size, ds_builder)
test_ds = input_pipeline.build_test_set(config.ds_builder)
test_ds = input_pipeline.build_test_set(ds_builder)

logging.info('Initializing model.')
init_data = jnp.ones((config.batch_size, 784), jnp.float32)
@@ -118,7 +117,9 @@ def train_and_evaluate(
rng, z_key, eval_rng = random.split(rng, 3)
z = random.normal(z_key, (64, config.latents))

steps_per_epoch = ds_builder.info.splits['train'].num_examples // config.batch_size
steps_per_epoch = (
ds_builder.info.splits['train'].num_examples // config.batch_size
)

for epoch in range(config.num_epochs):
for _ in range(steps_per_epoch):
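Since train_and_evaluate now takes only the ConfigDict, a caller just needs to populate the fields this file reads (batch_size, latents, num_epochs, plus whatever the collapsed hunks use). A hedged usage sketch with illustrative values:

    # Hypothetical driver; the field names mirror the config attributes
    # visible in this diff, the values are made up.
    import ml_collections

    import train

    config = ml_collections.ConfigDict()
    config.batch_size = 128
    config.latents = 20
    config.num_epochs = 30
    config.learning_rate = 1e-3  # assumed to exist; not visible in this hunk

    # If the training split is MNIST-sized (60,000 examples), then
    # steps_per_epoch = 60000 // 128 = 468.
    train.train_and_evaluate(config)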
