Skip to content

Commit

Permalink
Fix val batch bug
Browse files Browse the repository at this point in the history
  • Loading branch information
danielward27 committed Jun 27, 2022
1 parent 9750703 commit b9f9092
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion flowjax/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def step(flow, optimizer, opt_state, x, condition=None):
batches = range(0, val_args[0].shape[0] - batch_size, batch_size)
for i in batches:
batch = tuple(a[i : i + batch_size] for a in val_args)
epoch_val_loss += loss(flow, *val_args).item() / len(batches)
epoch_val_loss += loss(flow, *batch).item() / len(batches)

losses["train"].append(epoch_train_loss)
losses["val"].append(epoch_val_loss)
Expand Down

0 comments on commit b9f9092

Please sign in to comment.