Skip to content

Commit

Permalink
nbody scn validation
Browse files Browse the repository at this point in the history
  • Loading branch information
gerkone committed Jul 16, 2023
1 parent ea5f734 commit 2913042
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
10 changes: 5 additions & 5 deletions experiments/nbody/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ def _o3_transform(
)
+ vel_embedding
)

# scalar attribute to 1 by default
node_attributes = e3nn.IrrepsArray(
node_attributes.irreps, node_attributes.array.at[:, 0].set(1.0)
)
if not scn:
# scalar attribute to 1 by default
node_attributes = e3nn.IrrepsArray(
node_attributes.irreps, node_attributes.array.at[:, 0].set(1.0)
)

return SteerableGraphsTuple(
graph=GraphsTuple(
Expand Down
5 changes: 4 additions & 1 deletion experiments/qm9/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ def _to_steerable_graph(
node_attributes = e3nn.IrrepsArray(
attribute_irreps, jnp.pad(jnp.array(data.node_attr), node_attr_pad)
)
node_attributes.array = node_attributes.array.at[:, 0].set(1.0)
# scalar attribute to 1 by default
node_attributes = e3nn.IrrepsArray(
node_attributes.irreps, node_attributes.array.at[:, 0].set(1.0)
)

additional_message_features = e3nn.IrrepsArray(
args.additional_message_irreps,
Expand Down
10 changes: 7 additions & 3 deletions validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,15 @@
args.additional_message_irreps = e3nn.Irreps("2x0e")

# Create hidden irreps
if not args.scn:
attr_irreps = e3nn.Irreps.spherical_harmonics(args.lmax_attributes)
else:
attr_irreps = e3nn.Irrep(f"{args.lmax_attribute}y")

hidden_irreps = weight_balanced_irreps(
scalar_units=args.units,
# attribute irreps
irreps_right=e3nn.Irreps.spherical_harmonics(args.lmax_attributes),
use_sh=True,
irreps_right=attr_irreps,
use_sh=(not args.scn),
lmax=args.lmax_hidden,
)

Expand Down

0 comments on commit 2913042

Please sign in to comment.