Breaking the equivariance when using the haiku.Module class #23

zdcao121 opened this issue Oct 2, 2023 · 2 comments

zdcao121 commented Oct 2, 2023

Hello, it's a great project!
I tried to use EMLP with dm-haiku and wrote the code in two different ways. The first uses emlp.nn.haiku directly, and the second wraps it in a haiku.Module subclass. I found that the second version breaks the equivariance of the network, even though, as far as I can tell, the two versions describe the same architecture. Could you tell me whether something is wrong with the way I am using EMLP?

Attached is the code. The first version is:

import emlp.nn.haiku as ehk
import haiku as hk
from emlp.reps import V
from emlp.groups import SO
from jax import random
import jax.numpy as jnp

n = 10
dim = 3

G = SO(dim)

rep_in = n*V(G)
rep_out = n*V(G)

# ehk.EMLP returns a forward function that hk.transform can wrap directly
model = ehk.EMLP(rep_in, rep_out, group=G, num_layers=2, ch=256)
net = hk.without_apply_rng(hk.transform(model))

key = random.PRNGKey(0)
x = random.normal(key, (n*dim,))
params = net.init(random.PRNGKey(42), x)

v = net.apply(params, x)

# equivariance check: v(rho(g) x) should equal rho(g) v(x)
g = G.sample()
x_1 = rep_in.rho(g)@x
v_1 = net.apply(params, x_1)

v_2 = rep_out.rho(g)@v
print(f"v(𝜌(g)x) =\n{v_1}")
print(f"𝜌(g)v(x) =\n{v_2}")
assert jnp.allclose(v_1, v_2)

and the second version is:

import emlp.nn.haiku as ehk
from emlp.reps import V
from emlp.groups import SO
import haiku as hk
from jax import random
import jax.numpy as jnp


class test_EMLP(hk.Module):
    def __init__(self, n, dim, group, num_layers, ch, name=None):
        super().__init__(name=name)
        self.n = n
        self.dim = dim
        self.group = group(dim)
        self.rep_in = self.n*V(self.group)
        self.rep_out = self.n*V(self.group)
        self.num_layers = num_layers
        self.ch = ch

        # build the EMLP when the module is constructed
        self.e_mlp = self.make_e_mlp()

    def make_e_mlp(self):
        return ehk.EMLP(self.rep_in,
                        self.rep_out,
                        group=self.group,
                        num_layers=self.num_layers,
                        ch=self.ch)

    def __call__(self, x):
        return self.e_mlp(x)


def forward_fn(x):
    model = test_EMLP(n=10, dim=3, group=SO, num_layers=2, ch=256)
    return model(x)

net = hk.without_apply_rng(hk.transform(forward_fn))

n = 10
dim = 3
G = SO(dim)
rep_in = n*V(G)
rep_out = n*V(G)


key = random.PRNGKey(1)
x = random.normal(key, (n*dim,))
params = net.init(random.PRNGKey(42), x)

v = net.apply(params, x)

# same equivariance check as in the first version
g = G.sample()
x_1 = rep_in.rho(g)@x
v_1 = net.apply(params, x_1)

v_2 = rep_out.rho(g)@v
print(f"v(𝜌(g)x) =\n{v_1}")
print(f"𝜌(g)v(x) =\n{v_2}")
assert jnp.allclose(v_1, v_2)

mfinzi (Owner) commented Oct 3, 2023

Yeah, I believe this is because in Haiku the module's __init__ gets called on every forward pass of the model.
In EMLP, for the BiLinear layer, a random subset of all possible bilinear interactions is chosen to limit the size and computational cost of the layer. However, this random subset will be different on different instantiations of the model (unless the random seed is held fixed), so constructing the EMLP this way with Haiku will actually be evaluating slightly different models.
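
As a toy illustration of this Haiku behavior (not EMLP itself; the np.random draw below just stands in for the random choice of bilinear interactions):

import numpy as np
import haiku as hk
import jax.numpy as jnp
from jax import random

class NoisyInit(hk.Module):
    # __init__ draws Python-level randomness, standing in for the random
    # subset of bilinear interactions chosen when an EMLP is constructed
    def __init__(self, name=None):
        super().__init__(name=name)
        self.shift = np.random.randn()

    def __call__(self, x):
        return x + self.shift

toy_net = hk.without_apply_rng(hk.transform(lambda x: NoisyInit()(x)))
toy_x = jnp.ones(3)
toy_params = toy_net.init(random.PRNGKey(0), toy_x)
# __init__ runs again on every apply, so the shift (and the output) changes
print(toy_net.apply(toy_params, toy_x))
print(toy_net.apply(toy_params, toy_x))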

You can check by calling net.apply(params, x) on your test_EMLP multiple times with the same params and x.
If you look in emlp/nn/haiku.py, the way I get around this is to have the relevant constructors be standard functions which return the input-output mapping through a Haiku module. This way we can ensure that the precomputation of the equivariant bases (and the choice of random subsets) is only performed once. (You may notice that your test_EMLP is much slower than ehk.EMLP because it must compute the equivariant basis multiple times, although the caching may ameliorate this somewhat.)
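
Concretely, that repeated-apply check would look something like this (reusing net, params, and x from your second snippet):

out_a = net.apply(params, x)
out_b = net.apply(params, x)
# if a different bilinear subset is drawn on each forward pass, these two
# outputs can disagree even though the params and input are identical;
# with the ehk.EMLP version they should match
print("repeated applies agree:", jnp.allclose(out_a, out_b))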

If you want to use it inside another Haiku module, my advice would be one of the following:

  • fix the random seed (still not ideal, because it will end up repeating some fixed computations)
  • write the module constructor as a stateless function, as with ehk.EMLP (see the sketch after this list)
  • possibly there is some Haiku-specific mechanism for storing this precomputed state, and one could explore that
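
For the second option, a minimal sketch of what I mean, reusing rep_in, rep_out, G, and x from the snippets above: build the EMLP once, outside the function passed to hk.transform, so its setup (including the random choice of bilinear interactions) happens a single time.

# ehk.EMLP is a stateless constructor; calling it here, outside the
# transformed function, follows the pattern of the first snippet
emlp_fn = ehk.EMLP(rep_in, rep_out, group=G, num_layers=2, ch=256)

def forward_fn(x):
    # other Haiku modules can be composed with emlp_fn inside this function
    return emlp_fn(x)

net = hk.without_apply_rng(hk.transform(forward_fn))
params = net.init(random.PRNGKey(42), x)
v = net.apply(params, x)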

zdcao121 (Author) commented Oct 6, 2023

Thanks for your quick reply! I will try it soon.
