Commit 217915b

Merge pull request #3 from gerkone/qm9-fixes

Pooling irreps fix and e3nn version

gerkone committed Jul 27, 2023
2 parents a9f8af0 + b448c80
Showing 7 changed files with 80 additions and 45 deletions.
16 changes: 3 additions & 13 deletions experiments/qm9/utils.py
@@ -116,19 +116,9 @@ def setup_qm9_data(
feature_type=args.feature_type,
)

max_batch_nodes = int(
max(
sum(d.top_n_nodes(args.batch_size))
for d in [dataset_train, dataset_val, dataset_test]
)
)

max_batch_edges = int(
max(
sum(d.top_n_edges(args.batch_size))
for d in [dataset_train, dataset_val, dataset_test]
)
)
# 0.8 (un)safety factor for rejitting
max_batch_nodes = int(0.8 * sum(dataset_test.top_n_nodes(args.batch_size)))
max_batch_edges = int(0.8 * sum(dataset_test.top_n_edges(args.batch_size)))

target_mean, target_mad = dataset_train.calc_stats()

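Note: max_batch_nodes and max_batch_edges set a fixed padding budget for batching. Under jax.jit a new input shape forces a recompile, so batches are padded up to a constant node/edge count; the 0.8 "(un)safety" factor presumably accepts an occasional oversized batch (and rejit) in exchange for smaller padded shapes on the rest. A minimal sketch of the padding step this budget feeds, assuming jraph's padding API (pad_batch and its argument names are illustrative, not part of this commit):

import jraph

def pad_batch(graph: jraph.GraphsTuple, max_nodes: int, max_edges: int) -> jraph.GraphsTuple:
    # jraph.pad_with_graphs appends dummy graph(s) holding the leftover
    # nodes/edges, so every batch reaches the same padded shape and the
    # jitted train step compiles only once.
    return jraph.pad_with_graphs(graph, n_node=max_nodes, n_edge=max_edges)
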
9 changes: 4 additions & 5 deletions experiments/train.py
@@ -12,15 +12,14 @@
from segnn_jax import SteerableGraphsTuple


@partial(jit, static_argnames=["model_fn", "criterion", "task", "do_mask", "eval_trn"])
@partial(jit, static_argnames=["model_fn", "criterion", "do_mask", "eval_trn"])
def loss_fn_wrapper(
params: hk.Params,
state: hk.State,
st_graph: SteerableGraphsTuple,
target: jnp.ndarray,
model_fn: Callable,
criterion: Callable,
task: str = "node",
do_mask: bool = True,
eval_trn: Callable = None,
) -> Tuple[float, hk.State]:
@@ -29,9 +28,9 @@ def loss_fn_wrapper(
pred = eval_trn(pred)

if do_mask:
if task == "node":
if target.shape == st_graph.graph.nodes.shape:
mask = jraph.get_node_padding_mask(st_graph.graph)
if task == "graph":
else:
mask = jraph.get_graph_padding_mask(st_graph.graph)
# broadcast mask for vector targets
if len(pred.shape) == 2:
@@ -140,7 +139,7 @@ def train(
target=target,
opt_state=opt_state,
)
train_loss += loss
train_loss += jax.block_until_ready(loss)
train_loss /= len(loader_train)
epoch_time = (time.perf_counter_ns() - epoch_start) / 1e9

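Note: removing task from loss_fn_wrapper drops a static jit argument; node- and graph-level targets are now distinguished by shape, since a per-node target matches the shape of the node array. A hedged sketch of the same dispatch in isolation (pick_mask is illustrative, not part of this commit):

import jraph

def pick_mask(graph: jraph.GraphsTuple, target):
    # Per-node targets share the node array's shape; anything else is
    # treated as a per-graph target.
    if target.shape == graph.nodes.shape:
        return jraph.get_node_padding_mask(graph)
    return jraph.get_graph_padding_mask(graph)

The added jax.block_until_ready(loss) matters for the epoch timer: JAX dispatches device work asynchronously, so without blocking on the result the timer can stop before the computation has actually finished.
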
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,6 +1,6 @@
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
dm-haiku==0.0.9
e3nn-jax==0.19.3
dm-haiku>=0.0.9
e3nn-jax>=0.17.4
jax[cuda]
jraph==0.0.6.dev0
numpy>=1.23.4
37 changes: 31 additions & 6 deletions segnn_jax/blocks.py
@@ -6,8 +6,16 @@
import haiku as hk
import jax
import jax.numpy as jnp
from e3nn_jax.experimental import linear_shtp as escn
from e3nn_jax.legacy import FunctionalFullyConnectedTensorProduct

try:
from e3nn_jax.experimental import linear_shtp as escn
except ImportError:
escn = None

try:
from e3nn_jax import FunctionalFullyConnectedTensorProduct
except ImportError:
from e3nn_jax.legacy import FunctionalFullyConnectedTensorProduct # type: ignore

from .config import config

@@ -136,11 +144,19 @@ def __init__(
self.output_irreps,
get_parameter=self.get_parameter,
biases=self.biases,
name=f"{self.name}_linear",
gradient_normalization=self._gradient_normalization,
path_normalization=self._path_normalization,
)

def _check_input(
self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None
) -> Tuple[e3nn.IrrepsArray, e3nn.IrrepsArray]:
x, y = super()._check_input(x, y)
miss = self.output_irreps.filter(drop=e3nn.tensor_product(x.irreps, y.irreps))
if len(miss) > 0:
warnings.warn(f"Output irreps: '{miss}' are unreachable and were ignored.")
return x, y

def __call__(
self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None
) -> TensorProductFn:
@@ -263,11 +279,16 @@ def __init__(
path_normalization=path_normalization,
)

if escn is None:
raise ImportError(
"eSCN is available from e3nn-jax>=0.17.4. "
f"Your version: {e3nn.__version__}"
)

self._linear = e3nn.haiku.Linear(
self.output_irreps,
get_parameter=self.get_parameter,
biases=self.biases,
name=f"{self.name}_linear",
gradient_normalization=self._gradient_normalization,
path_normalization=self._path_normalization,
)
@@ -280,7 +301,7 @@ def _check_input(
return super()._check_input(x, y)

def __call__(
self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None, **kwargs
self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None
) -> e3nn.IrrepsArray:
"""Apply the layer. y must not be into spherical harmonics."""
x, y = self._check_input(x, y)
@@ -355,6 +376,10 @@ def _gated_tensor_product(
x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None, **kwargs
) -> e3nn.IrrepsArray:
tp = tensor_product(x, y, **kwargs)
return e3nn.gate(tp, even_act=scalar_activation, odd_gate_act=gate_activation)
# skip gate if the gating scalars are not reachable
if len(gate_irreps.filter(drop=tp.irreps)) > 0:
return tp
else:
return e3nn.gate(tp, scalar_activation, odd_gate_act=gate_activation)

return _gated_tensor_product
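
Note: the guarded imports keep the module importable across e3nn-jax releases: linear_shtp (eSCN) only appeared in e3nn-jax>=0.17.4, and FunctionalFullyConnectedTensorProduct later moved into e3nn_jax.legacy, so both locations are tried and the eSCN layer only raises when it is actually constructed. The new _check_input warning and the gate skip rely on the same fact: a tensor product of two sets of irreps can only produce certain output irreps, and anything else (including the 0e gating scalars) is unreachable. A hedged sketch of that reachability check, mirroring the calls used above (the concrete irreps are illustrative):

import e3nn_jax as e3nn

x_irreps = e3nn.Irreps("1x1o")
y_irreps = e3nn.Irreps("1x1o")
# full tensor product of two odd vectors spans 1x0e+1x1e+1x2e
reachable = e3nn.tensor_product(x_irreps, y_irreps)
target = e3nn.Irreps("1x0e + 1x0o")
missing = target.filter(drop=reachable)  # 1x0o cannot be produced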
49 changes: 36 additions & 13 deletions segnn_jax/segnn.py
@@ -62,6 +62,7 @@ def O3Decoder(
blocks: int = 1,
task: str = "graph",
pool: Optional[str] = "avg",
pooled_irreps: Optional[e3nn.Irreps] = None,
O3Layer: TensorProduct = O3TensorProduct,
):
"""Steerable pooler and decoder.
@@ -72,6 +73,7 @@
blocks: Number of tensor product blocks in the decoder
task: Specifies where the output is located. Either 'graph' or 'node'
pool: Pooling method to use. One of 'avg', 'sum', 'none', None
pooled_irreps: Pooled irreps. When left None the original implementation is used
O3Layer: Type of tensor product layer to use
Returns:
@@ -81,6 +83,12 @@
assert task in ["node", "graph"], f"Unknown task {task}"
assert pool in ["avg", "sum", "none", None], f"Unknown pooling '{pool}'"

# NOTE: original implementation restricted final layers to pooled_irreps.
# This way gates cannot be applied in the post pool block when returning vectors,
# because the gating scalars cannot be reached.
if pooled_irreps is None:
pooled_irreps = (output_irreps * latent_irreps.num_irreps).regroup()

def _decoder(st_graph: SteerableGraphsTuple):
nodes = st_graph.graph.nodes
# pre pool block
@@ -96,7 +104,7 @@ def _decoder(st_graph: SteerableGraphsTuple):

if task == "graph":
# pool over graph
nodes = O3Layer(latent_irreps, name=f"prepool_{blocks}")(
nodes = O3Layer(pooled_irreps, name=f"prepool_{blocks}")(
nodes, st_graph.node_attributes
)

@@ -111,7 +119,7 @@ def _decoder(st_graph: SteerableGraphsTuple):
# post pool mlp (not steerable)
for i in range(blocks):
nodes = O3TensorProductGate(
latent_irreps, name=f"postpool_{i}", o3_layer=O3TensorProduct
pooled_irreps, name=f"postpool_{i}", o3_layer=O3TensorProduct
)(nodes)
nodes = O3TensorProduct(output_irreps, name="output")(nodes)

@@ -134,6 +142,7 @@ def __init__(
blocks: int = 2,
norm: Optional[str] = None,
aggregate_fn: Optional[Callable] = jraph.segment_sum,
residual: bool = True,
O3Layer: TensorProduct = O3TensorProduct,
):
"""
@@ -145,6 +154,7 @@
blocks: Number of tensor product blocks in the layer
norm: Normalization type. Either be None, 'instance' or 'batch'
aggregate_fn: Message aggregation function. Defaults to sum.
residual: If true, use residual connections
O3Layer: Type of tensor product layer to use
"""
super().__init__(f"layer_{layer_num}")
@@ -153,6 +163,7 @@
self._blocks = blocks
self._norm = norm
self._aggregate_fn = aggregate_fn
self._residual = residual

self._O3Layer = O3Layer

@@ -204,7 +215,10 @@ def _update(
x, node_attribute
)
# residual connection
nodes += update
if self._residual:
nodes += update
else:
nodes = update
# message norm
if self._norm in ["batch", "instance"]:
nodes = e3nn.haiku.BatchNorm(
@@ -271,10 +285,13 @@ def __init__(
"""
super().__init__()

if isinstance(hidden_irreps, e3nn.Irreps):
self._hidden_irreps_units = num_layers * [hidden_irreps]
else:
self._hidden_irreps_units = hidden_irreps
if not isinstance(output_irreps, e3nn.Irreps):
output_irreps = e3nn.Irreps(output_irreps)
if not isinstance(hidden_irreps, e3nn.Irreps):
hidden_irreps = e3nn.Irreps(hidden_irreps)

self._hidden_irreps = hidden_irreps
self._num_layers = num_layers

self._embed_msg_features = embed_msg_features
self._norm = norm
@@ -290,28 +307,34 @@ def __init__(
self._O3Layer = o3_layer

self._embedding = O3Embedding(
self._hidden_irreps_units[0],
self._hidden_irreps,
O3Layer=self._O3Layer,
embed_edges=self._embed_msg_features,
)

pooled_irreps = None
if task == "graph" and "0e" not in output_irreps:
# NOTE: different from original. This way proper gates are always applied
pooled_irreps = hidden_irreps

self._decoder = O3Decoder(
latent_irreps=self._hidden_irreps_units[-1],
latent_irreps=self._hidden_irreps,
output_irreps=output_irreps,
O3Layer=self._O3Layer,
task=task,
pool=pool,
pooled_irreps=pooled_irreps,
)

def __call__(self, st_graph: SteerableGraphsTuple) -> jnp.array:
# node (and edge) embedding
st_graph = self._embedding(st_graph)

# message passing
for n, hrp in enumerate(self._hidden_irreps_units):
st_graph = SEGNNLayer(output_irreps=hrp, layer_num=n, norm=self._norm)(
st_graph
)
for n in range(self._num_layers):
st_graph = SEGNNLayer(
output_irreps=self._hidden_irreps, layer_num=n, norm=self._norm
)(st_graph)

# decoder/pooler
nodes = self._decoder(st_graph)
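Note: pooled_irreps decouples the pre-pool projection from the latent irreps. The original implementation pooled into (output_irreps * latent_irreps.num_irreps).regroup(), which contains no 0e channels when the target is, say, a pure vector, so the post-pool gates had no scalars to gate; falling back to the hidden irreps keeps scalar channels around until the final output layer. A hedged sketch of the two choices (the concrete irreps are illustrative):

import e3nn_jax as e3nn

hidden_irreps = e3nn.Irreps("32x0e + 16x1o")
output_irreps = e3nn.Irreps("1x1o")  # vector target, no 0e scalars

# original behaviour: output irreps repeated per latent irrep -> 48x1o,
# which leaves the postpool gates without gating scalars
original_pool = (output_irreps * hidden_irreps.num_irreps).regroup()
# fixed path for scalar-free graph outputs: pool in the hidden irreps,
# so 0e channels stay reachable through the postpool blocks
fixed_pool = hidden_irreps

The other changes in this file are smaller: the residual flag toggles the skip connection in _update, and the per-layer list of hidden irreps is replaced by a single hidden_irreps shared by all num_layers message-passing layers.
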
4 changes: 2 additions & 2 deletions setup.cfg
@@ -14,8 +14,8 @@ classifiers =
packages = segnn_jax
python_requires = >=3.8
install_requires =
dm_haiku==0.0.9
e3nn_jax==0.19.3
dm_haiku>=0.0.9
e3nn_jax>=0.17.4
jax
jaxlib
jraph==0.0.6.dev0
6 changes: 2 additions & 4 deletions validate.py
@@ -232,10 +232,8 @@ def segnn(x):
def _mae(p, t):
return jnp.abs(p - t)

train_loss = partial(loss_fn_wrapper, criterion=_mae, task=args.task)
eval_loss = partial(
loss_fn_wrapper, criterion=_mae, eval_trn=eval_trn, task=args.task
)
train_loss = partial(loss_fn_wrapper, criterion=_mae)
eval_loss = partial(loss_fn_wrapper, criterion=_mae, eval_trn=eval_trn)
if args.dataset in ["charged", "gravity"]:
from experiments.train import loss_fn_wrapper

