diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1e5762c..7a9805a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,8 +28,8 @@ repos: rev: 22.3.0 hooks: - id: black - - repo: https://github.com/pycqa/pylint - rev: v2.9.6 + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: 'v0.0.265' hooks: - - id: pylint - args: [ "--rcfile=.pylintrc" ] + - id: ruff + exclude: ^datasets/|^data/|^.git/|^venv/ diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index b6d0d82..0000000 --- a/.pylintrc +++ /dev/null @@ -1,449 +0,0 @@ -# This Pylint rcfile contains a best-effort configuration to uphold the -# best-practices and style described in the Google Python style guide: -# https://google.github.io/styleguide/pyguide.html -# -# Its canonical open-source location is: -# https://google.github.io/styleguide/pylintrc - -[MASTER] - -# Files or directories to be skipped. They should be base names, not paths. -ignore=build,docs,projects - -# Files or directories matching the regex patterns are skipped. The regex -# matches against base names, not paths. -ignore-patterns= - -# Pickle collected data for later comparisons. -persistent=no - -# List of plugins (as comma separated values of python modules names) to load, -# usually to register additional checkers. -load-plugins= - -# Use multiple processes to speed up Pylint. -jobs=4 - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - -# A comma-separated list of package or module names from where C extensions may -# be loaded. Extensions are loading into the active Python interpreter and may -# run arbitrary code -extension-pkg-whitelist=pydantic - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED -confidence= - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -#enable= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once).You can also use "--disable=all" to -# disable everything first and then reenable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use"--disable=all --enable=classes -# --disable=W" -disable=C0114, # module docstring - C0115, # class docstring - W0108, # unnecessary lambda - W0621, # redifining from outer scope - C0103, # bad function names - C0415, # import outside toplevel - W0212, # protected access - abstract-method, - apply-builtin, - arguments-differ, - attribute-defined-outside-init, - backtick, - bad-option-value, - basestring-builtin, - buffer-builtin, - c-extension-no-member, - consider-using-enumerate, - cmp-builtin, - cmp-method, - coerce-builtin, - coerce-method, - delslice-method, - div-method, - duplicate-code, - eq-without-hash, - execfile-builtin, - file-builtin, - filter-builtin-not-iterating, - fixme, - getslice-method, - global-statement, - hex-method, - idiv-method, - implicit-str-concat-in-sequence, - import-error, - import-self, - import-star-module-level, - inconsistent-return-statements, - input-builtin, - intern-builtin, - invalid-str-codec, - locally-disabled, - long-builtin, - long-suffix, - map-builtin-not-iterating, - misplaced-comparison-constant, - missing-function-docstring, - metaclass-assignment, - next-method-called, - next-method-defined, - no-absolute-import, - no-else-break, - no-else-continue, - no-else-raise, - no-else-return, - no-init, # added - no-member, - no-name-in-module, - no-self-use, - nonzero-method, - oct-method, - old-division, - old-ne-operator, - old-octal-literal, - old-raise-syntax, - parameter-unpacking, - print-statement, - raising-string, - range-builtin-not-iterating, - raw_input-builtin, - rdiv-method, - reduce-builtin, - relative-import, - reload-builtin, - round-builtin, - setslice-method, - signature-differs, - standarderror-builtin, - suppressed-message, - sys-max-int, - too-few-public-methods, - too-many-ancestors, - too-many-arguments, - too-many-boolean-expressions, - too-many-branches, - too-many-instance-attributes, - too-many-locals, - too-many-nested-blocks, - too-many-public-methods, - too-many-return-statements, - too-many-statements, - trailing-newlines, - unichr-builtin, - unicode-builtin, - unnecessary-pass, - unpacking-in-except, - useless-else-on-loop, - useless-object-inheritance, - useless-suppression, - using-cmp-argument, - wrong-import-order, - xrange-builtin, - zip-builtin-not-iterating, - - -[REPORTS] - -# Set the output format. Available formats are text, parseable, colorized, msvs -# (visual studio) and html. You can also give a reporter class, eg -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Put messages in a separate file for each module / package specified on the -# command line instead of printing them on stdout. Reports (if any) will be -# written in a file name "pylint_global.[txt|html]". This option is deprecated -# and it will be removed in Pylint 2.0. -files-output=no - -# Tells whether to display a full report or only the messages -reports=no - -# Python expression which should return a note less than 10 (10 is the highest -# note). You have access to the variables errors warning, statement which -# respectively contain the number of errors / warnings messages and the total -# number of statements analyzed. This is used by the global evaluation report -# (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details -#msg-template= - - -[BASIC] - -# Good variable names which should always be accepted, separated by a comma -good-names=main,_ - -# Bad variable names which should always be refused, separated by a comma -bad-names= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Include a hint for the correct naming format with invalid-name -include-naming-hint=no - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl - -# Regular expression matching correct function names -function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ - -# Regular expression matching correct variable names -variable-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct constant names -const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Regular expression matching correct attribute names -attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ - -# Regular expression matching correct argument names -argument-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct class attribute names -class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ - -# Regular expression matching correct inline iteration names -inlinevar-rgx=^[a-z][a-z0-9_]*$ - -# Regular expression matching correct class names -class-rgx=^_?[A-Z][a-zA-Z0-9]*$ - -# Regular expression matching correct module names -module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ - -# Regular expression matching correct method names -method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=10 - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager - -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis. It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members= - - -[FORMAT] - -# Maximum number of characters on a single line. -max-line-length=120 - -# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt -# lines made too long by directives to pytype. - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=(?x)( - ^\s*(\#\ )??$| - ^\s*(from\s+\S+\s+)?import\s+.+$) - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=yes - -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check= - -# Maximum number of lines in a module -max-module-lines=99999 - -# String used as indentation unit. The internal Google style guide mandates 2 -# spaces. Google's externaly-published style guide says 4, consistent with -# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google -# projects (like TensorFlow). -indent-string=' ' - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=TODO - - -[STRING] - -# This flag controls whether inconsistent-quotes generates a warning when the -# character used as a quote delimiter is used inconsistently within a module. -check-quote-consistency=yes - - -[VARIABLES] - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# A regular expression matching the name of dummy variables (i.e. expectedly -# not used). -dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid to define new builtins when possible. -additional-builtins= - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_,_cb - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools - - -[LOGGING] - -# Logging modules to check that the string format arguments are in logging -# function parameter format -logging-modules=logging,absl.logging,tensorflow.io.logging - - -[SIMILARITIES] - -# Minimum lines number of a similarity. -min-similarity-lines=4 - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - - -[SPELLING] - -# Spelling dictionary name. Available dictionaries: none. To make it working -# install python-enchant package. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to indicated private dictionary in -# --spelling-private-dict-file option instead of raising a message. -spelling-store-unknown-words=no - - -[IMPORTS] - -# Deprecated modules which should not be used, separated by a comma -deprecated-modules= - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled) -import-graph= - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled) -ext-import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled) -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant, absl - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict, - _fields, - _replace, - _source, - _make - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls, - class_ - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=mcs - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when being caught. Defaults to -# "Exception" -overgeneral-exceptions=StandardError, - Exception, - BaseException diff --git a/README.md b/README.md index cd31af2..0c1fe9d 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ python -m pip install -e . ### GPU support Upgrade `jax` to the gpu version ``` -pip install --upgrade "jax[cuda]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +pip install --upgrade "jax[cuda]>=0.4.6" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html ``` ## Validation @@ -46,7 +46,7 @@ Times are remeasured on Quadro RTX 4000, __model only__ on batches of 100 graphs .0043 21.22 .0045 - 4.47 + 3.77 gravity (position) @@ -85,11 +85,17 @@ QM9 is automatically downloaded and processed when running the respective experi The N-body datasets have to be generated locally from the directory [experiments/nbody/data](experiments/nbody/data) (it will take some time, especially n-body `gravity`) #### Charged dataset (5 bodies, 10000 training samples) ``` -python3 -u generate_dataset.py --simulation=charged +python3 -u generate_dataset.py --simulation=charged --seed=43 ``` #### Gravity dataset (100 bodies, 10000 training samples) ``` -python3 -u generate_dataset.py --simulation=gravity --n-balls=100 +python3 -u generate_dataset.py --simulation=gravity --n-balls=100 --seed=43 +``` + +### Notes +On `jax<=0.4.6`, the `jit`-`pjit` merge can be deactivated making traning faster (on nbody). This looks like an issue with dataloading and the validation training loop implementation and it does not affect SEGNN. +``` +export JAX_JIT_PJIT_API_MERGE=0 ``` ### Usage @@ -111,6 +117,7 @@ python validate.py --dataset=qm9 --epochs=1000 --target=alpha --lmax-hidden=2 -- (configurations used in validation) + ## Acknowledgments - [e3nn_jax](https://github.com/e3nn/e3nn-jax) made this reimplementation possible. - [Artur Toshev](https://github.com/arturtoshev) and [Johannes Brandsetter](https://github.com/brandstetter-johannes), for support. diff --git a/experiments/nbody/data/generate_dataset.py b/experiments/nbody/data/generate_dataset.py index 059bba4..48e83a4 100755 --- a/experiments/nbody/data/generate_dataset.py +++ b/experiments/nbody/data/generate_dataset.py @@ -1,8 +1,8 @@ """ Generate charged and gravity datasets. -charged: python3 generate_dataset.py --simulation=charged --num-train=10000 -gravity: python3 generate_dataset.py --simulation=gravity --num-train=10000 --n-balls=100 +charged: python3 generate_dataset.py --simulation=charged --num-train=10000 --seed=43 +gravity: python3 generate_dataset.py --simulation=gravity --num-train=10000 --n-balls=100 --seed=43 """ import argparse import time diff --git a/experiments/nbody/data/synthetic_sim.py b/experiments/nbody/data/synthetic_sim.py index 706c633..80ccbb2 100755 --- a/experiments/nbody/data/synthetic_sim.py +++ b/experiments/nbody/data/synthetic_sim.py @@ -1,6 +1,3 @@ -import time - -import matplotlib.pyplot as plt import numpy as np @@ -299,49 +296,3 @@ def sample_trajectory(self, T=10000, sample_freq=10): vel_save += np.random.randn(T_save, N, self.dim) * self.noise_var force_save += np.random.randn(T_save, N, self.dim) * self.noise_var return pos_save, vel_save, force_save, mass - - -if __name__ == "__main__": - from tqdm import tqdm - - color_map = "summer" - cmap = plt.get_cmap(color_map) - - np.random.seed(43) - - sim = GravitySim(n_balls=100, loc_std=1) - - t = time.time() - loc, vel, force, mass = sim.sample_trajectory(T=5000, sample_freq=1) - - print("Simulation time: {}".format(time.time() - t)) - plt.figure() - axes = plt.gca() - axes.set_xlim([-4.0, 4.0]) - axes.set_ylim([-4.0, 4.0]) - # for i in range(loc.shape[-2]): - # plt.plot(loc[:, i, 0], loc[:, i, 1], alpha=0.1, linewidth=1) - # plt.plot(loc[0, i, 0], loc[0, i, 1], 'o') - - offset = 4000 - N_frames = loc.shape[0] - offset - N_particles = loc.shape[-2] - - for i in tqdm(range(N_particles)): - color = cmap(i / N_particles) - # for j in range(loc.shape[0]-2): - for j in range(offset, offset + N_frames): - plt.plot( - loc[j : j + 2, i, 0], - loc[j : j + 2, i, 1], - alpha=0.2 + 0.7 * ((j - offset) / N_frames) ** 4, - linewidth=1, - color=color, - ) - plt.plot(loc[-1, i, 0], loc[-1, i, 1], "o", markersize=3, color=color) - plt.axis("off") - # plt.figure() - # energies = [sim._energy(loc[i, :, :], vel[i, :, :], mass, sim.interaction_strength) for i in - # range(loc.shape[0])] - # plt.plot(energies) - plt.show() diff --git a/experiments/nbody/utils.py b/experiments/nbody/utils.py index e00d1ca..1bdf856 100644 --- a/experiments/nbody/utils.py +++ b/experiments/nbody/utils.py @@ -4,7 +4,6 @@ import jax import jax.numpy as jnp import jax.tree_util as tree -import jraph import numpy as np import torch from jraph import GraphsTuple, segment_mean @@ -20,11 +19,15 @@ def O3Transform( node_features_irreps: e3nn.Irreps, edge_features_irreps: e3nn.Irreps, lmax_attributes: int, + scn: bool = False, ) -> Callable: """ Build a transformation function that includes (nbody) O3 attributes to a graph. """ - attribute_irreps = e3nn.Irreps.spherical_harmonics(lmax_attributes) + if not scn: + attribute_irreps = e3nn.Irreps.spherical_harmonics(lmax_attributes) + else: + attribute_irreps = e3nn.Irrep("1o") @jax.jit def _o3_transform( @@ -51,12 +54,17 @@ def _o3_transform( jnp.concatenate((loc - mean_loc, vel, vel_abs), axis=-1), ) - edge_attributes = e3nn.spherical_harmonics( - attribute_irreps, rel_pos, normalize=True, normalization="integral" - ) - vel_embedding = e3nn.spherical_harmonics( - attribute_irreps, vel, normalize=True, normalization="integral" - ) + if not scn: + edge_attributes = e3nn.spherical_harmonics( + attribute_irreps, rel_pos, normalize=True, normalization="integral" + ) + vel_embedding = e3nn.spherical_harmonics( + attribute_irreps, vel, normalize=True, normalization="integral" + ) + else: + edge_attributes = e3nn.IrrepsArray(attribute_irreps, rel_pos) + vel_embedding = e3nn.IrrepsArray(attribute_irreps, vel) + # scatter edge attributes sum_n_node = tree.tree_leaves(nodes)[0].shape[0] node_attributes = ( @@ -66,9 +74,11 @@ def _o3_transform( ) + vel_embedding ) - - # scalar attribute to 1 by default - node_attributes.array = node_attributes.array.at[:, 0].set(1.0) + if not scn: + # scalar attribute to 1 by default + node_attributes = e3nn.IrrepsArray( + node_attributes.irreps, node_attributes.array.at[:, 0].set(1.0) + ) return SteerableGraphsTuple( graph=GraphsTuple( @@ -208,7 +218,10 @@ def setup_nbody_data( ) o3_transform = O3Transform( - args.node_irreps, args.additional_message_irreps, args.lmax_attributes + args.node_irreps, + args.additional_message_irreps, + args.lmax_attributes, + scn=args.o3_layer == "scn", ) graph_transform = NbodyGraphTransform( transform=o3_transform, diff --git a/experiments/qm9/dataset.py b/experiments/qm9/dataset.py index fed7736..3b249e3 100644 --- a/experiments/qm9/dataset.py +++ b/experiments/qm9/dataset.py @@ -308,7 +308,9 @@ def process(self): torch.save(self.collate(data_list), self.processed_paths[0]) def get_O3_attr(self, edge_index, pos, attr_irreps): - """Creates spherical harmonic edge attributes and node attributes for the SEGNN""" + """ + Creates spherical harmonic edge attributes and node attributes for the SEGNN. + """ rel_pos = ( pos[edge_index[0]] - pos[edge_index[1]] ) # pos_j - pos_i (note in edge_index stores tuples like (j,i)) diff --git a/experiments/qm9/utils.py b/experiments/qm9/utils.py index 1007cd0..6186094 100644 --- a/experiments/qm9/utils.py +++ b/experiments/qm9/utils.py @@ -52,7 +52,10 @@ def _to_steerable_graph( node_attributes = e3nn.IrrepsArray( attribute_irreps, jnp.pad(jnp.array(data.node_attr), node_attr_pad) ) - node_attributes.array = node_attributes.array.at[:, 0].set(1.0) + # scalar attribute to 1 by default + node_attributes = e3nn.IrrepsArray( + node_attributes.irreps, node_attributes.array.at[:, 0].set(1.0) + ) additional_message_features = e3nn.IrrepsArray( args.additional_message_irreps, @@ -129,7 +132,8 @@ def setup_qm9_data( target_mean, target_mad = dataset_train.calc_stats() - remove_offsets = lambda t: (t - target_mean) / target_mad + def remove_offsets(t): + return (t - target_mean) / target_mad # not great and very slow due to huge padding loader_train = DataLoader( @@ -158,6 +162,7 @@ def setup_qm9_data( train_trn=remove_offsets, ) - add_offsets = lambda p: p * target_mad + target_mean + def add_offsets(p): + return p * target_mad + target_mean return loader_train, loader_val, loader_test, to_graphs_tuple, add_offsets diff --git a/experiments/train.py b/experiments/train.py index f021131..9e2911e 100644 --- a/experiments/train.py +++ b/experiments/train.py @@ -4,16 +4,16 @@ import haiku as hk import jax -from jax import jit import jax.numpy as jnp import jraph import optax +from jax import jit from segnn_jax import SteerableGraphsTuple @partial(jit, static_argnames=["model_fn", "criterion", "task", "do_mask", "eval_trn"]) -def loss_fn( +def loss_fn_wrapper( params: hk.Params, state: hk.State, st_graph: SteerableGraphsTuple, @@ -27,16 +27,21 @@ def loss_fn( pred, state = model_fn(params, state, st_graph) if eval_trn is not None: pred = eval_trn(pred) - if task == "node": - mask = jraph.get_node_padding_mask(st_graph.graph) - if task == "graph": - mask = jraph.get_graph_padding_mask(st_graph.graph) - # broadcase mask for vector targets - if len(pred.shape) == 2: - mask = mask[:, jnp.newaxis] + if do_mask: - target = target * mask - pred = pred * mask + if task == "node": + mask = jraph.get_node_padding_mask(st_graph.graph) + if task == "graph": + mask = jraph.get_graph_padding_mask(st_graph.graph) + # broadcast mask for vector targets + if len(pred.shape) == 2: + mask = mask[:, jnp.newaxis] + else: + mask = jnp.ones_like(target) + + target = target * mask + pred = pred * mask + assert target.shape == pred.shape return jnp.sum(criterion(pred, target)) / jnp.count_nonzero(mask), state @@ -125,7 +130,7 @@ def train( for e in range(args.epochs): train_loss = 0.0 - train_start = time.perf_counter_ns() + epoch_start = time.perf_counter_ns() for data in loader_train: graph, target = graph_transform(data) loss, params, segnn_state, opt_state = update_fn( @@ -136,10 +141,11 @@ def train( opt_state=opt_state, ) train_loss += loss - train_time = (time.perf_counter_ns() - train_start) / 1e6 train_loss /= len(loader_train) + epoch_time = (time.perf_counter_ns() - epoch_start) / 1e9 + print( - f"[Epoch {e+1:>4}] train loss {train_loss:.6f}, epoch {train_time:.2f}ms", + f"[Epoch {e+1:>4}] train loss {train_loss:.6f}, epoch {epoch_time:.2f}s", end="", ) if e % args.val_freq == 0: @@ -157,7 +163,7 @@ def train( test_loss = 0 _, test_loss = eval_fn(loader_test, params, segnn_state) # ignore compilation time - avg_time = avg_time[2:] + avg_time = avg_time[1:] if len(avg_time) > 1 else avg_time avg_time = sum(avg_time) / len(avg_time) print( "Training done.\n" diff --git a/requirements.txt b/requirements.txt index 2fe59a5..b138783 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html dm-haiku==0.0.9 -e3nn-jax==0.17.4 -jax[cuda]==0.4.8 +e3nn-jax==0.19.3 +jax[cuda] jraph==0.0.6.dev0 numpy>=1.23.4 optax==0.1.3 diff --git a/segnn_jax/__init__.py b/segnn_jax/__init__.py index 64f5182..fdf6c07 100644 --- a/segnn_jax/__init__.py +++ b/segnn_jax/__init__.py @@ -1,4 +1,9 @@ -from .blocks import O3TensorProduct, O3TensorProductGate, O3TensorProductLegacy +from .blocks import ( + O3TensorProduct, + O3TensorProductFC, + O3TensorProductGate, + O3TensorProductSCN, +) from .graph_utils import SteerableGraphsTuple from .irreps_computer import balanced_irreps, weight_balanced_irreps from .segnn import SEGNN, SEGNNLayer @@ -7,11 +12,12 @@ "SEGNN", "SEGNNLayer", "O3TensorProduct", - "O3TensorProductLegacy", + "O3TensorProductFC", + "O3TensorProductSCN", "O3TensorProductGate", "weight_balanced_irreps", "balanced_irreps", "SteerableGraphsTuple", ] -__version__ = "0.6" +__version__ = "0.7" diff --git a/segnn_jax/blocks.py b/segnn_jax/blocks.py index 58e74c1..10aba61 100644 --- a/segnn_jax/blocks.py +++ b/segnn_jax/blocks.py @@ -1,11 +1,13 @@ import warnings +from abc import ABC, abstractmethod from typing import Callable, Optional, Tuple, Union import e3nn_jax as e3nn import haiku as hk import jax import jax.numpy as jnp -from e3nn_jax._src.tensor_products import naive_broadcast_decorator +from e3nn_jax.experimental import linear_shtp as escn +from e3nn_jax.legacy import FunctionalFullyConnectedTensorProduct from .config import config @@ -27,14 +29,8 @@ def uniform_init( ) -class O3TensorProduct(hk.Module): - """ - O(3) equivariant linear parametrized tensor product layer. - - Functionally the same as O3TensorProductLegacy, but around 5-10% faster. - FullyConnectedTensorProduct seems faster than tensor_product + linear: - https://github.com/e3nn/e3nn-jax/releases/tag/0.14.0 - """ +class TensorProduct(hk.Module, ABC): + """O(3) equivariant linear parametrized tensor product layer.""" def __init__( self, @@ -54,10 +50,10 @@ def __init__( name: Name of the linear layer params init_fn: Weight initialization function. Default is uniform. gradient_normalization: Gradient normalization method. Default is "path" - NOTE: gradient_normalization="element" is the default in torch and haiku. + NOTE: gradient_normalization="element" is the default in torch and haiku. path_normalization: Path normalization method. Default is "element" """ - super().__init__(name) + super().__init__(name=name) if not isinstance(output_irreps, e3nn.Irreps): output_irreps = e3nn.Irreps(output_irreps) @@ -77,11 +73,96 @@ def __init__( self.biases = biases and "0e" in self.output_irreps + def _check_input( + self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None + ) -> Tuple[e3nn.IrrepsArray, e3nn.IrrepsArray]: + if not y: + y = e3nn.IrrepsArray("1x0e", jnp.ones((1, 1), dtype=x.dtype)) + + if x.irreps.lmax == 0 and y.irreps.lmax == 0 and self.output_irreps.lmax > 0: + warnings.warn( + f"The specified output irreps ({self.output_irreps}) are not scalars " + "but both operands are. This can have undesired behaviour (NaN). Try " + "redistributing them into scalars or choose higher orders." + ) + + return x, y + + @abstractmethod + def __call__( + self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None, **kwargs + ) -> e3nn.IrrepsArray: + """Applies an O(3) equivariant linear parametrized tensor product layer. + + Args: + x (IrrepsArray): Left tensor + y (IrrepsArray): Right tensor. If None it defaults to np.ones. + + Returns: + The output to the weighted tensor product (IrrepsArray). + """ + raise NotImplementedError + + +class O3TensorProduct(TensorProduct): + """O(3) equivariant linear parametrized tensor product layer. + + Original O3TensorProduct version that uses tensor_product + Linear instead of + FullyConnectedTensorProduct. + From e3nn 0.19.2 (https://github.com/e3nn/e3nn-jax/releases/tag/0.19.2), this is + as fast as FullyConnectedTensorProduct. + """ + + def __init__( + self, + output_irreps: e3nn.Irreps, + *, + biases: bool = True, + name: Optional[str] = None, + init_fn: Optional[InitFn] = None, + gradient_normalization: Optional[Union[str, float]] = "element", + path_normalization: Optional[Union[str, float]] = None, + ): + super().__init__( + output_irreps, + biases=biases, + name=name, + init_fn=init_fn, + gradient_normalization=gradient_normalization, + path_normalization=path_normalization, + ) + + self._linear = e3nn.haiku.Linear( + self.output_irreps, + get_parameter=self.get_parameter, + biases=self.biases, + name=f"{self.name}_linear", + gradient_normalization=self._gradient_normalization, + path_normalization=self._path_normalization, + ) + + def __call__( + self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None + ) -> TensorProductFn: + x, y = self._check_input(x, y) + # tensor product + linear + tp = self._linear(e3nn.tensor_product(x, y)) + return tp + + +class O3TensorProductFC(TensorProduct): + """ + O(3) equivariant linear parametrized tensor product layer. + + Functionally the same as O3TensorProduct, but uses FullyConnectedTensorProduct and + is slightly slower (~5-10%) than tensor_prodict + Linear. + """ + def _build_tensor_product( self, left_irreps: e3nn.Irreps, right_irreps: e3nn.Irreps ) -> Callable: """Build the tensor product function.""" - tp = e3nn.FunctionalFullyConnectedTensorProduct( + tp = FunctionalFullyConnectedTensorProduct( left_irreps, right_irreps, self.output_irreps, @@ -102,10 +183,19 @@ def _build_tensor_product( for ins in tp.instructions ] - def tensor_product(x, y, **kwargs): - return tp.left_right(ws, x, y, **kwargs)._convert(self.output_irreps) + def _tensor_product(x, y, **kwargs): + return tp.left_right(ws, x, y, **kwargs).rechunk(self.output_irreps) + + # naive broadcasting wrapper + # TODO: not the best + def _tp_wrapper(*args): + leading_shape = jnp.broadcast_shapes(*(arg.shape[:-1] for arg in args)) + args = [arg.broadcast_to(leading_shape + (-1,)) for arg in args] + for _ in range(len(leading_shape)): + f = jax.vmap(_tensor_product) + return f(*args) - return naive_broadcast_decorator(tensor_product) + return _tp_wrapper def _build_biases(self) -> Callable: """Build the add bias function.""" @@ -121,37 +211,19 @@ def _build_biases(self) -> Callable: b = e3nn.IrrepsArray(f"{self.output_irreps.count('0e')}x0e", jnp.concatenate(b)) # TODO: could be improved - def _wrapper(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray: + def _bias_wrapper(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray: scalars = x.filter("0e") other = x.filter(drop="0e") return e3nn.concatenate( [scalars + b.broadcast_to(scalars.shape), other], axis=1 ) - return _wrapper + return _bias_wrapper def __call__( self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None, **kwargs ) -> e3nn.IrrepsArray: - """Applies an O(3) equivariant linear parametrized tensor product layer. - - Args: - x (IrrepsArray): Left tensor - y (IrrepsArray): Right tensor. If None it defaults to np.ones. - - Returns: - The output to the weighted tensor product (IrrepsArray). - """ - - if not y: - y = e3nn.IrrepsArray("1x0e", jnp.ones((1, 1), dtype=x.dtype)) - - if x.irreps.lmax == 0 and y.irreps.lmax == 0 and self.output_irreps.lmax > 0: - warnings.warn( - f"The specified output irreps ({self.output_irreps}) are not scalars " - "but both operands are. This can have undesired behaviour (NaN). Try " - "redistributing them into scalars or choose higher orders." - ) + x, y = self._check_input(x, y) tp = self._build_tensor_product(x.irreps, y.irreps) output = tp(x, y, **kwargs) @@ -164,78 +236,64 @@ def __call__( return output -def O3TensorProductLegacy( - output_irreps: e3nn.Irreps, - *, - biases: bool = True, - name: Optional[str] = None, - init_fn: Optional[InitFn] = None, - gradient_normalization: Optional[Union[str, float]] = "element", - path_normalization: Optional[Union[str, float]] = None, -): - """O(3) equivariant linear parametrized tensor product layer. - Legacy version of O3TensorProduct that uses e3nn.haiku.Linear instead of - e3nn.FunctionalFullyConnectedTensorProduct. - - Args: - output_irreps: Output representation - biases: If set ot true will add biases - name: Name of the linear layer params - init_fn: Weight initialization function. Default is uniform. - gradient_normalization: Gradient normalization method. Default is "path" - NOTE: gradient_normalization="element" is the default in torch and haiku. - path_normalization: Path normalization method. Default is "element" - - Returns: - A function that returns the output to the weighted tensor product. +class O3TensorProductSCN(TensorProduct): """ + O(3) equivariant linear parametrized tensor product layer. - if not isinstance(output_irreps, e3nn.Irreps): - output_irreps = e3nn.Irreps(output_irreps) - - if not init_fn: - init_fn = uniform_init - - linear = e3nn.haiku.Linear( - output_irreps, - get_parameter=init_fn, - biases=biases, - name=name, - gradient_normalization=gradient_normalization, - path_normalization=path_normalization, - ) - - def _tensor_product( - x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None - ) -> TensorProductFn: - """Applies an O(3) equivariant linear parametrized tensor product layer. + O3TensorProduct with eSCN optimization for larger spherical harmonic orders. Should + be used without spherical harmonics on the inputs. + """ - Args: - x (IrrepsArray): Left tensor - y (IrrepsArray): Right tensor. If None it defaults to np.ones. + def __init__( + self, + output_irreps: e3nn.Irreps, + *, + biases: bool = True, + name: Optional[str] = None, + init_fn: Optional[InitFn] = None, + gradient_normalization: Optional[Union[str, float]] = None, + path_normalization: Optional[Union[str, float]] = None, + ): + super().__init__( + output_irreps, + biases=biases, + name=name, + init_fn=init_fn, + gradient_normalization=gradient_normalization, + path_normalization=path_normalization, + ) - Returns: - The output to the weighted tensor product (IrrepsArray). - """ + self._linear = e3nn.haiku.Linear( + self.output_irreps, + get_parameter=self.get_parameter, + biases=self.biases, + name=f"{self.name}_linear", + gradient_normalization=self._gradient_normalization, + path_normalization=self._path_normalization, + ) + def _check_input( + self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None + ) -> Tuple[e3nn.IrrepsArray, e3nn.IrrepsArray]: if not y: - y = e3nn.IrrepsArray("1x0e", jnp.ones((1, 1), dtype=x.dtype)) - - if x.irreps.lmax == 0 and y.irreps.lmax == 0 and output_irreps.lmax > 0: - warnings.warn( - f"The specified output irreps ({output_irreps}) are not scalars " - "but both operands are. This can have undesired behaviour (NaN). Try " - "redistributing them into scalars or choose higher orders." - ) - - tp = e3nn.tensor_product(x, y) + raise ValueError("eSCN cannot be used without the right input.") + return super()._check_input(x, y) - return linear(tp) - - return _tensor_product + def __call__( + self, x: e3nn.IrrepsArray, y: Optional[e3nn.IrrepsArray] = None, **kwargs + ) -> e3nn.IrrepsArray: + """Apply the layer. y must not be into spherical harmonics.""" + x, y = self._check_input(x, y) + shtp = e3nn.utils.vmap(escn.shtp, in_axes=(0, 0, None)) + tp = shtp(x, y, self.output_irreps) + return self._linear(tp) -O3Layer = O3TensorProduct if config("o3_layer") == "new" else O3TensorProductLegacy +O3_LAYERS = { + "tpl": O3TensorProduct, + "fctp": O3TensorProductFC, + "scn": O3TensorProductSCN, +} def O3TensorProductGate( @@ -246,6 +304,7 @@ def O3TensorProductGate( gate_activation: Optional[Callable] = None, name: Optional[str] = None, init_fn: Optional[InitFn] = None, + o3_layer: Optional[Union[str, TensorProduct]] = None, ) -> TensorProductFn: """Non-linear (gated) O(3) equivariant linear tensor product layer. @@ -257,9 +316,10 @@ def O3TensorProductGate( scalar_activation: Activation function for scalars gate_activation: Activation function for higher order name: Name of the linear layer params + o3_layer: Tensor product layer type. "tpl", "fctp", "scn" or a custom layer Returns: - Function that applies the gated tensor product layer. + Function that applies the gated tensor product layer """ if not isinstance(output_irreps, e3nn.Irreps): @@ -269,12 +329,23 @@ def O3TensorProductGate( gate_irreps = e3nn.Irreps( f"{output_irreps.num_irreps - output_irreps.count('0e')}x0e" ) + + if o3_layer is None: + o3_layer = config("o3_layer") + + if isinstance(o3_layer, str): + assert o3_layer in O3_LAYERS, f"Unknown O3 layer {o3_layer}." + O3Layer = O3_LAYERS[o3_layer] + else: + O3Layer = o3_layer + tensor_product = O3Layer( (gate_irreps + output_irreps).regroup(), biases=biases, name=name, init_fn=init_fn, ) + if not scalar_activation: scalar_activation = jax.nn.silu if not gate_activation: diff --git a/segnn_jax/config.py b/segnn_jax/config.py index dc0586b..6036a5a 100644 --- a/segnn_jax/config.py +++ b/segnn_jax/config.py @@ -4,7 +4,7 @@ "gradient_normalization": "element", # "element" or "path" "path_normalization": "element", # "element" or "path" "default_dtype": jnp.float32, - "o3_layer": "new", # "new" or "legacy" + "o3_layer": "tpl", # "tpl" (tp + Linear) or "fctp" (FullyConnected) or "scn" (SCN) } diff --git a/segnn_jax/irreps_computer.py b/segnn_jax/irreps_computer.py index c8a3eaa..a00db1e 100644 --- a/segnn_jax/irreps_computer.py +++ b/segnn_jax/irreps_computer.py @@ -30,8 +30,15 @@ def weight_balanced_irreps( scalar_units: int, irreps_right: Irreps, use_sh: bool = True, lmax: int = None ) -> Irreps: """ - Determines left Irreps such that the weighted tensor product irreps_left x irreps_right + Determines irreps_left such that the parametrized tensor product + Linear(tensor_product(irreps_left, irreps_right)) has (at least) scalar_units weights. + + Args: + scalar_units: number of desired weights + irreps_right: irreps of the right tensor + use_sh: whether to use spherical harmonics + lmax: maximum level of spherical harmonics """ # irrep order if lmax is None: diff --git a/segnn_jax/segnn.py b/segnn_jax/segnn.py index a2567a4..4506b93 100644 --- a/segnn_jax/segnn.py +++ b/segnn_jax/segnn.py @@ -6,11 +6,16 @@ import jraph from jax.tree_util import Partial -from .blocks import O3Layer, O3TensorProductGate +from .blocks import O3_LAYERS, O3TensorProduct, O3TensorProductGate, TensorProduct +from .config import config from .graph_utils import SteerableGraphsTuple, pooling -def O3Embedding(embed_irreps: e3nn.Irreps, embed_edges: bool = True) -> Callable: +def O3Embedding( + embed_irreps: e3nn.Irreps, + embed_edges: bool = True, + O3Layer: TensorProduct = O3TensorProduct, +) -> Callable: """Linear steerable embedding. Embeds the graph nodes in the representation space :param embed_irreps:. @@ -18,6 +23,7 @@ def O3Embedding(embed_irreps: e3nn.Irreps, embed_edges: bool = True) -> Callable Args: embed_irreps: Output representation embed_edges: If true also embed edges/message passing features + O3Layer: Type of tensor product layer to use Returns: Function to embed graph nodes (and optionally edges) @@ -27,10 +33,9 @@ def _embedding( st_graph: SteerableGraphsTuple, ) -> SteerableGraphsTuple: graph = st_graph.graph - nodes = O3Layer( - embed_irreps, - name="embedding_nodes", - )(graph.nodes, st_graph.node_attributes) + nodes = O3Layer(embed_irreps, name="embedding_nodes")( + graph.nodes, st_graph.node_attributes + ) st_graph = st_graph._replace(graph=graph._replace(nodes=nodes)) # NOTE edge embedding is not in the original paper but can get good results @@ -57,6 +62,7 @@ def O3Decoder( blocks: int = 1, task: str = "graph", pool: Optional[str] = "avg", + O3Layer: TensorProduct = O3TensorProduct, ): """Steerable pooler and decoder. @@ -65,6 +71,8 @@ def O3Decoder( output_irreps: Output representation blocks: Number of tensor product blocks in the decoder task: Specifies where the output is located. Either 'graph' or 'node' + pool: Pooling method to use. One of 'avg', 'sum', 'none', None + O3Layer: Type of tensor product layer to use Returns: Decoded latent feature space to output space. @@ -88,8 +96,7 @@ def _decoder(st_graph: SteerableGraphsTuple): if task == "graph": # pool over graph - pooled_irreps = (latent_irreps.num_irreps * output_irreps).regroup() - nodes = O3Layer(pooled_irreps, name=f"prepool_{blocks}")( + nodes = O3Layer(latent_irreps, name=f"prepool_{blocks}")( nodes, st_graph.node_attributes ) @@ -103,8 +110,10 @@ def _decoder(st_graph: SteerableGraphsTuple): # post pool mlp (not steerable) for i in range(blocks): - nodes = O3TensorProductGate(pooled_irreps, name=f"postpool_{i}")(nodes) - nodes = O3Layer(output_irreps, name="output")(nodes) + nodes = O3TensorProductGate( + latent_irreps, name=f"postpool_{i}", o3_layer=O3TensorProduct + )(nodes) + nodes = O3TensorProduct(output_irreps, name="output")(nodes) return nodes @@ -125,6 +134,7 @@ def __init__( blocks: int = 2, norm: Optional[str] = None, aggregate_fn: Optional[Callable] = jraph.segment_sum, + O3Layer: TensorProduct = O3TensorProduct, ): """ Initialize the layer. @@ -135,6 +145,7 @@ def __init__( blocks: Number of tensor product blocks in the layer norm: Normalization type. Either be None, 'instance' or 'batch' aggregate_fn: Message aggregation function. Defaults to sum. + O3Layer: Type of tensor product layer to use """ super().__init__(f"layer_{layer_num}") assert norm in ["batch", "instance", "none", None], f"Unknown norm '{norm}'" @@ -143,6 +154,8 @@ def __init__( self._norm = norm self._aggregate_fn = aggregate_fn + self._O3Layer = O3Layer + def _message( self, edge_attribute: e3nn.IrrepsArray, @@ -187,7 +200,7 @@ def _update( x, node_attribute ) # last update layer without activation - update = O3Layer(self._output_irreps, name=f"tp_{self._blocks - 1}")( + update = self._O3Layer(self._output_irreps, name=f"tp_{self._blocks - 1}")( x, node_attribute ) # residual connection @@ -240,6 +253,7 @@ def __init__( task: Optional[str] = "graph", blocks_per_layer: int = 2, embed_msg_features: bool = False, + o3_layer: Optional[Union[str, TensorProduct]] = None, ): """ Initialize the network. @@ -253,6 +267,7 @@ def __init__( task: Specifies where the output is located. Either 'graph' or 'node' blocks_per_layer: Number of tensor product blocks in each message passing embed_msg_features: Set to true to also embed edges/message passing features + o3_layer: Tensor product layer type. "tpl", "fctp", "scn" or a custom layer """ super().__init__() @@ -265,14 +280,25 @@ def __init__( self._norm = norm self._blocks_per_layer = blocks_per_layer + # layer type + if o3_layer is None: + o3_layer = config("o3_layer") + if isinstance(o3_layer, str): + assert o3_layer in O3_LAYERS, f"Unknown O3 layer {o3_layer}." + self._O3Layer = O3_LAYERS[o3_layer] + else: + self._O3Layer = o3_layer + self._embedding = O3Embedding( self._hidden_irreps_units[0], + O3Layer=self._O3Layer, embed_edges=self._embed_msg_features, ) self._decoder = O3Decoder( latent_irreps=self._hidden_irreps_units[-1], output_irreps=output_irreps, + O3Layer=self._O3Layer, task=task, pool=pool, ) diff --git a/setup.cfg b/setup.cfg index 9bd45b6..8163953 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,9 +15,9 @@ packages = segnn_jax python_requires = >=3.8 install_requires = dm_haiku==0.0.9 - e3nn_jax==0.17.4 - jax==0.4.8 - jaxlib==0.4.8 + e3nn_jax==0.19.3 + jax + jaxlib jraph==0.0.6.dev0 numpy>=1.23.4 optax==0.1.3 diff --git a/tests/conftest.py b/tests/conftest.py index 07c921d..5acad1e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,13 +8,15 @@ from segnn_jax import SteerableGraphsTuple +os.environ["CUDA_VISIBLE_DEVICES"] = "" os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" @pytest.fixture def dummy_graph(): - def _rand_graph(n_graphs: int = 1): + def _rand_graph(n_graphs: int = 1, attr_irreps: str = "1x0e + 1x1o"): + attr_irreps = e3nn.Irreps(attr_irreps) return SteerableGraphsTuple( graph=jraph.GraphsTuple( nodes=e3nn.IrrepsArray("1x1o", jnp.ones((n_graphs * 5, 3))), @@ -27,7 +29,9 @@ def _rand_graph(n_graphs: int = 1): ), additional_message_features=None, edge_attributes=None, - node_attributes=e3nn.IrrepsArray("1x0e+1x1o", jnp.ones((n_graphs * 5, 4))), + node_attributes=e3nn.IrrepsArray( + attr_irreps, jnp.ones((n_graphs * 5, attr_irreps.dim)) + ), ) return _rand_graph diff --git a/tests/test_blocks.py b/tests/test_blocks.py index e7a8336..508abf6 100644 --- a/tests/test_blocks.py +++ b/tests/test_blocks.py @@ -1,63 +1,61 @@ import e3nn_jax as e3nn import haiku as hk import pytest -from e3nn_jax.util import assert_equivariant +from e3nn_jax.utils import assert_equivariant -from segnn_jax import O3TensorProduct, O3TensorProductGate, O3TensorProductLegacy +from segnn_jax import ( + O3TensorProduct, + O3TensorProductFC, + O3TensorProductGate, + O3TensorProductSCN, +) @pytest.mark.parametrize("biases", [False, True]) -def test_linear(key, biases): - f = lambda x1, x2: O3TensorProduct("1x1o", biases=biases)(x1, x2) - f = hk.without_apply_rng(hk.transform(f)) - - v = e3nn.normal("1x1o", key, (5,)) - params = f.init(key, v, v) +@pytest.mark.parametrize( + "O3Layer", [O3TensorProduct, O3TensorProductFC, O3TensorProductSCN] +) +def test_linear_layers(key, biases, O3Layer): + def f(x1, x2): + return O3Layer("1x1o", biases=biases)(x1, x2) - wrapper = lambda x1, x2: f.apply(params, x1, x2) - - assert_equivariant( - wrapper, - key, - args_in=(e3nn.normal("1x1o", key, (5,)), e3nn.normal("1x1o", key, (5,))), - ) - - -@pytest.mark.parametrize("biases", [False, True]) -def test_gated(key, biases): - import segnn_jax.blocks - - segnn_jax.blocks.O3Layer = segnn_jax.blocks.O3TensorProduct - - f = lambda x1, x2: O3TensorProductGate("1x1o", biases=biases)(x1, x2) f = hk.without_apply_rng(hk.transform(f)) v = e3nn.normal("1x1o", key, (5,)) params = f.init(key, v, v) - wrapper = lambda x1, x2: f.apply(params, x1, x2) + def wrapper(x1, x2): + return f.apply(params, x1, x2) assert_equivariant( wrapper, key, - args_in=(e3nn.normal("1x1o", key, (5,)), e3nn.normal("1x1o", key, (5,))), + e3nn.normal("1x1o", key, (5,)), + e3nn.normal("1x1o", key, (5,)), ) @pytest.mark.parametrize("biases", [False, True]) -def test_linear_legacy(key, biases): - f = lambda x1, x2: O3TensorProductLegacy("1x1o", biases=biases)(x1, x2) +@pytest.mark.parametrize( + "O3Layer", [O3TensorProduct, O3TensorProductFC, O3TensorProductSCN] +) +def test_gated_layers(key, biases, O3Layer): + def f(x1, x2): + return O3TensorProductGate("1x1o", biases=biases, o3_layer=O3Layer)(x1, x2) + f = hk.without_apply_rng(hk.transform(f)) v = e3nn.normal("1x1o", key, (5,)) params = f.init(key, v, v) - wrapper = lambda x1, x2: f.apply(params, x1, x2) + def wrapper(x1, x2): + return f.apply(params, x1, x2) assert_equivariant( wrapper, key, - args_in=(e3nn.normal("1x1o", key, (5,)), e3nn.normal("1x1o", key, (5,))), + e3nn.normal("1x1o", key, (5,)), + e3nn.normal("1x1o", key, (5,)), ) diff --git a/tests/test_segnn.py b/tests/test_segnn.py index 377faea..68d405f 100644 --- a/tests/test_segnn.py +++ b/tests/test_segnn.py @@ -1,69 +1,62 @@ import e3nn_jax as e3nn import haiku as hk import pytest -from e3nn_jax.util import assert_equivariant +from e3nn_jax.utils import assert_equivariant -from segnn_jax import SEGNN, weight_balanced_irreps +from segnn_jax import ( + SEGNN, + O3TensorProduct, + O3TensorProductFC, + O3TensorProductSCN, + weight_balanced_irreps, +) @pytest.mark.parametrize("task", ["graph", "node"]) @pytest.mark.parametrize("norm", ["none", "instance"]) -def test_equivariance(key, dummy_graph, norm, task): - import segnn_jax.blocks +@pytest.mark.parametrize( + "O3Layer", [O3TensorProduct, O3TensorProductFC, O3TensorProductSCN] +) +def test_segnn_equivariance(key, dummy_graph, task, norm, O3Layer): + scn = O3Layer == O3TensorProductSCN + + hidden_irreps = weight_balanced_irreps( + 8, e3nn.Irreps.spherical_harmonics(1), use_sh=not scn + ) + + def segnn(x): + return SEGNN( + hidden_irreps=hidden_irreps, + output_irreps=e3nn.Irreps("1x1o"), + num_layers=1, + task=task, + norm=norm, + o3_layer=O3Layer, + )(x) - segnn_jax.blocks.O3Layer = segnn_jax.blocks.O3TensorProduct - - segnn = lambda x: SEGNN( - hidden_irreps=weight_balanced_irreps(8, e3nn.Irreps.spherical_harmonics(1)), - output_irreps=e3nn.Irreps("1x1o"), - num_layers=1, - task=task, - norm=norm, - )(x) segnn = hk.without_apply_rng(hk.transform_with_state(segnn)) - graph = dummy_graph() - params, segnn_state = segnn.init(key, graph) - - def wrapper(x): - st_graph = graph._replace( - graph=graph.graph._replace(nodes=x), - node_attributes=e3nn.spherical_harmonics("1x0e+1x1o", x, normalize=True), - ) - y, _ = segnn.apply(params, segnn_state, st_graph) - return e3nn.IrrepsArray("1x1o", y) - - assert_equivariant(wrapper, key, args_in=(e3nn.normal("1x1o", key, (5,)),)) - - -@pytest.mark.parametrize("task", ["graph", "node"]) -@pytest.mark.parametrize("norm", ["none", "instance"]) -def test_equivariance_legacy(key, dummy_graph, norm, task): - import segnn_jax.blocks - - segnn_jax.blocks.O3Layer = segnn_jax.blocks.O3TensorProductLegacy - - segnn = lambda x: SEGNN( - hidden_irreps=weight_balanced_irreps(8, e3nn.Irreps.spherical_harmonics(1)), - output_irreps=e3nn.Irreps("1x1o"), - num_layers=1, - task=task, - norm=norm, - )(x) - segnn = hk.without_apply_rng(hk.transform_with_state(segnn)) + if scn: + attr_irreps = e3nn.Irreps("1x1o") + else: + attr_irreps = e3nn.Irreps("1x0e+1x1o") - graph = dummy_graph() + graph = dummy_graph(attr_irreps=attr_irreps) params, segnn_state = segnn.init(key, graph) def wrapper(x): + if scn: + attrs = e3nn.IrrepsArray(attr_irreps, x.array) + else: + attrs = e3nn.spherical_harmonics(attr_irreps, x, normalize=True) st_graph = graph._replace( graph=graph.graph._replace(nodes=x), - node_attributes=e3nn.spherical_harmonics("1x0e+1x1o", x, normalize=True), + node_attributes=attrs, ) y, _ = segnn.apply(params, segnn_state, st_graph) return e3nn.IrrepsArray("1x1o", y) - assert_equivariant(wrapper, key, args_in=(e3nn.normal("1x1o", key, (5,)),)) + assert_equivariant(wrapper, key, e3nn.normal("1x1o", key, (5,))) if __name__ == "__main__": diff --git a/validate.py b/validate.py index 4eaa14e..1f4a3fd 100644 --- a/validate.py +++ b/validate.py @@ -128,6 +128,11 @@ action="store_true", help="Use double precision in model", ) + parser.add_argument( + "--scn", + action="store_true", + help="Train SEGNN with the eSCN optimization", + ) # wandb parameters parser.add_argument( @@ -181,6 +186,7 @@ args.node_irreps = e3nn.Irreps("11x0e") args.output_irreps = e3nn.Irreps("1x0e") args.additional_message_irreps = e3nn.Irreps("1x0e") + assert not args.scn, "eSCN not implemented for qm9" elif args.dataset in ["charged", "gravity"]: args.task = "node" args.node_irreps = e3nn.Irreps("2x1o + 1x0e") @@ -188,42 +194,56 @@ args.additional_message_irreps = e3nn.Irreps("2x0e") # Create hidden irreps + if not args.scn: + attr_irreps = e3nn.Irreps.spherical_harmonics(args.lmax_attributes) + else: + attr_irreps = e3nn.Irrep(f"{args.lmax_attribute}y") + hidden_irreps = weight_balanced_irreps( scalar_units=args.units, - # attribute irreps - irreps_right=e3nn.Irreps.spherical_harmonics(args.lmax_attributes), - use_sh=True, + irreps_right=attr_irreps, + use_sh=(not args.scn), lmax=args.lmax_hidden, ) + args.o3_layer = "scn" if args.scn else "tpl" + del args.scn + # build model - segnn = lambda x: SEGNN( - hidden_irreps=hidden_irreps, - output_irreps=args.output_irreps, - num_layers=args.layers, - task=args.task, - pool="avg", - blocks_per_layer=args.blocks, - norm=args.norm, - )(x) + def segnn(x): + return SEGNN( + hidden_irreps=hidden_irreps, + output_irreps=args.output_irreps, + num_layers=args.layers, + task=args.task, + pool="avg", + blocks_per_layer=args.blocks, + norm=args.norm, + o3_layer=args.o3_layer, + )(x) + segnn = hk.without_apply_rng(hk.transform_with_state(segnn)) loader_train, loader_val, loader_test, graph_transform, eval_trn = setup_data(args) if args.dataset == "qm9": - from experiments.train import loss_fn + from experiments.train import loss_fn_wrapper - _mae = lambda p, t: jnp.abs(p - t) + def _mae(p, t): + return jnp.abs(p - t) - train_loss = partial(loss_fn, criterion=_mae, task=args.task) - eval_loss = partial(loss_fn, criterion=_mae, eval_trn=eval_trn, task=args.task) + train_loss = partial(loss_fn_wrapper, criterion=_mae, task=args.task) + eval_loss = partial( + loss_fn_wrapper, criterion=_mae, eval_trn=eval_trn, task=args.task + ) if args.dataset in ["charged", "gravity"]: - from experiments.train import loss_fn + from experiments.train import loss_fn_wrapper - _mse = lambda p, t: jnp.power(p - t, 2) + def _mse(p, t): + return jnp.power(p - t, 2) - train_loss = partial(loss_fn, criterion=_mse, do_mask=False) - eval_loss = partial(loss_fn, criterion=_mse, do_mask=False) + train_loss = partial(loss_fn_wrapper, criterion=_mse, do_mask=False) + eval_loss = partial(loss_fn_wrapper, criterion=_mse, do_mask=False) train( key,