Skip to content

Commit

Permalink
adapted nbody experiments for scn
Browse files Browse the repository at this point in the history
  • Loading branch information
gerkone committed Jul 16, 2023
1 parent d145155 commit 5eb5644
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 13 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ QM9 is automatically downloaded and processed when running the respective experi
The N-body datasets have to be generated locally from the directory [experiments/nbody/data](experiments/nbody/data) (it will take some time, especially n-body `gravity`)
#### Charged dataset (5 bodies, 10000 training samples)
```
python3 -u generate_dataset.py --simulation=charged
python3 -u generate_dataset.py --simulation=charged --seed=43
```
#### Gravity dataset (100 bodies, 10000 training samples)
```
python3 -u generate_dataset.py --simulation=gravity --n-balls=100
python3 -u generate_dataset.py --simulation=gravity --n-balls=100 --seed=43
```

### Usage
Expand Down
4 changes: 2 additions & 2 deletions experiments/nbody/data/generate_dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
Generate charged and gravity datasets.
charged: python3 generate_dataset.py --simulation=charged --num-train=10000
gravity: python3 generate_dataset.py --simulation=gravity --num-train=10000 --n-balls=100
charged: python3 generate_dataset.py --simulation=charged --num-train=10000 --seed=43
gravity: python3 generate_dataset.py --simulation=gravity --num-train=10000 --n-balls=100 --seed=43
"""
import argparse
import time
Expand Down
28 changes: 20 additions & 8 deletions experiments/nbody/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ def O3Transform(
node_features_irreps: e3nn.Irreps,
edge_features_irreps: e3nn.Irreps,
lmax_attributes: int,
scn: bool = False,
) -> Callable:
"""
Build a transformation function that includes (nbody) O3 attributes to a graph.
"""
attribute_irreps = e3nn.Irreps.spherical_harmonics(lmax_attributes)
if not scn:
attribute_irreps = e3nn.Irreps.spherical_harmonics(lmax_attributes)
else:
attribute_irreps = e3nn.Irrep("1o")

@jax.jit
def _o3_transform(
Expand All @@ -50,12 +54,17 @@ def _o3_transform(
jnp.concatenate((loc - mean_loc, vel, vel_abs), axis=-1),
)

edge_attributes = e3nn.spherical_harmonics(
attribute_irreps, rel_pos, normalize=True, normalization="integral"
)
vel_embedding = e3nn.spherical_harmonics(
attribute_irreps, vel, normalize=True, normalization="integral"
)
if not scn:
edge_attributes = e3nn.spherical_harmonics(
attribute_irreps, rel_pos, normalize=True, normalization="integral"
)
vel_embedding = e3nn.spherical_harmonics(
attribute_irreps, vel, normalize=True, normalization="integral"
)
else:
edge_attributes = e3nn.IrrepsArray(attribute_irreps, rel_pos)
vel_embedding = e3nn.IrrepsArray(attribute_irreps, vel)

# scatter edge attributes
sum_n_node = tree.tree_leaves(nodes)[0].shape[0]
node_attributes = (
Expand Down Expand Up @@ -209,7 +218,10 @@ def setup_nbody_data(
)

o3_transform = O3Transform(
args.node_irreps, args.additional_message_irreps, args.lmax_attributes
args.node_irreps,
args.additional_message_irreps,
args.lmax_attributes,
scn=args.o3_layer == "scn",
)
graph_transform = NbodyGraphTransform(
transform=o3_transform,
Expand Down
2 changes: 1 addition & 1 deletion experiments/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def train(
test_loss = 0
_, test_loss = eval_fn(loader_test, params, segnn_state)
# ignore compilation time
avg_time = avg_time[2:]
avg_time = avg_time[1:] if len(avg_time) > 1 else avg_time
avg_time = sum(avg_time) / len(avg_time)
print(
"Training done.\n"
Expand Down
10 changes: 10 additions & 0 deletions validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@
action="store_true",
help="Use double precision in model",
)
parser.add_argument(
"--scn",
action="store_true",
help="Train SEGNN with the eSCN optimization",
)

# wandb parameters
parser.add_argument(
Expand Down Expand Up @@ -181,6 +186,7 @@
args.node_irreps = e3nn.Irreps("11x0e")
args.output_irreps = e3nn.Irreps("1x0e")
args.additional_message_irreps = e3nn.Irreps("1x0e")
assert not args.scn, "eSCN not implemented for qm9"
elif args.dataset in ["charged", "gravity"]:
args.task = "node"
args.node_irreps = e3nn.Irreps("2x1o + 1x0e")
Expand All @@ -196,6 +202,9 @@
lmax=args.lmax_hidden,
)

args.o3_layer = "scn" if args.scn else "tpl"
del args.scn

# build model
def segnn(x):
return SEGNN(
Expand All @@ -206,6 +215,7 @@ def segnn(x):
pool="avg",
blocks_per_layer=args.blocks,
norm=args.norm,
o3_layer=args.o3_layer,
)(x)

segnn = hk.without_apply_rng(hk.transform_with_state(segnn))
Expand Down

0 comments on commit 5eb5644

Please sign in to comment.