deps and version bump
gerkone committed Jul 16, 2023
1 parent 2913042 commit fb0fdc4
Showing 5 changed files with 16 additions and 8 deletions.
9 changes: 8 additions & 1 deletion README.md
@@ -17,7 +17,7 @@ python -m pip install -e .
### GPU support
Upgrade `jax` to the gpu version
```
-pip install --upgrade "jax[cuda]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+pip install --upgrade "jax[cuda]>=0.4.6" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

## Validation
@@ -92,6 +92,12 @@ python3 -u generate_dataset.py --simulation=charged --seed=43
python3 -u generate_dataset.py --simulation=gravity --n-balls=100 --seed=43
```

### Notes
On `jax<=0.4.6`, the `jit`-`pjit` merge can be deactivated, which makes training faster. This looks like an issue with the data loading and the validation training loop implementation; it does not affect SEGNN itself.
```
export JAX_JIT_PJIT_API_MERGE=0
```

### Usage
#### N-body (charged)
```
@@ -111,6 +117,7 @@ python validate.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 --
(configurations used in validation)



## Acknowledgments
- [e3nn_jax](https://github.com/e3nn/e3nn-jax) made this reimplementation possible.
- [Artur Toshev](https://github.com/arturtoshev) and [Johannes Brandsetter](https://github.com/brandstetter-johannes), for support.
7 changes: 4 additions & 3 deletions experiments/train.py
@@ -125,7 +125,7 @@ def train(

for e in range(args.epochs):
train_loss = 0.0
-        train_start = time.perf_counter_ns()
+        epoch_start = time.perf_counter_ns()
for data in loader_train:
graph, target = graph_transform(data)
loss, params, segnn_state, opt_state = update_fn(
@@ -136,10 +136,11 @@ def train(
opt_state=opt_state,
)
train_loss += loss
-            train_time = (time.perf_counter_ns() - train_start) / 1e6
        train_loss /= len(loader_train)
+        epoch_time = (time.perf_counter_ns() - epoch_start) / 1e9

print(
-            f"[Epoch {e+1:>4}] train loss {train_loss:.6f}, epoch {train_time:.2f}ms",
+            f"[Epoch {e+1:>4}] train loss {train_loss:.6f}, epoch {epoch_time:.2f}s",
end="",
)
if e % args.val_freq == 0:
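The train.py change above moves the timer start before the batch loop and reports whole-epoch wall time in seconds (`/ 1e9` converts nanoseconds) instead of milliseconds. A minimal standalone sketch of the same pattern; the loop body and loss values here are illustrative stand-ins for the real `update_fn` calls:

```python
import time


def run_epoch(batch_losses):
    """Time one epoch with a monotonic ns clock, as in the patched train loop."""
    train_loss = 0.0
    epoch_start = time.perf_counter_ns()  # started once, before the batch loop
    for loss in batch_losses:  # stand-in for iterating loader_train / update_fn
        train_loss += loss
    train_loss /= len(batch_losses)
    epoch_time = (time.perf_counter_ns() - epoch_start) / 1e9  # ns -> seconds
    return train_loss, epoch_time


loss, secs = run_epoch([0.5, 0.3, 0.1])
print(f"train loss {loss:.6f}, epoch {secs:.2f}s")
```

`time.perf_counter_ns` is monotonic, so the elapsed value is unaffected by wall-clock adjustments during long epochs.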
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,7 +1,7 @@
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
dm-haiku==0.0.9
e3nn-jax==0.19.3
-jax[cuda]==0.4.13
+jax[cuda]
jraph==0.0.6.dev0
numpy>=1.23.4
optax==0.1.3
2 changes: 1 addition & 1 deletion segnn_jax/__init__.py
@@ -20,4 +20,4 @@
"SteerableGraphsTuple",
]

-__version__ = "0.6"
+__version__ = "0.7"
4 changes: 2 additions & 2 deletions setup.cfg
@@ -16,8 +16,8 @@ python_requires = >=3.8
install_requires =
dm_haiku==0.0.9
e3nn_jax==0.19.3
-    jax==0.4.13
-    jaxlib==0.4.13
+    jax
+    jaxlib
jraph==0.0.6.dev0
numpy>=1.23.4
optax==0.1.3
