Saving and Loading Objax EMLPs yields slightly different predictions #8

Open
mfinzi opened this issue Jun 5, 2021 · 1 comment
Labels: bug (Something isn't working)

mfinzi (Owner) commented Jun 5, 2021

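It appears that loading an EMLP model saved with `objax.io.save_var_collection` yields slightly different predictions than the original model. Running the first script below saves a model, and running the second script loads it into a freshly constructed EMLP, yet the two outputs differ: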

```python
import emlp
from emlp.groups import SO
from emlp.reps import T,V
import numpy as np
import objax

net = emlp.nn.EMLP(T(1)+T(2),T(0),SO(3))
x=np.linspace(0,1,12)
print(net(x).T)

# Saving with file descriptor
with open('net.npz', 'wb') as f:
    objax.io.save_var_collection(f, net.vars())
```

Output: [0.4905533]

```python
import emlp
from emlp.groups import SO
from emlp.reps import T,V
import numpy as np
import objax

# Construct a fresh model with the same architecture as the saving script
net = emlp.nn.EMLP(T(1)+T(2),T(0),SO(3))
x=np.linspace(0,1,12)
# Loading with file descriptor
with open('net.npz', 'rb') as f:
    objax.io.load_var_collection(f, net.vars())

print(net(x).T)
```

Output: [0.4904544]
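To see how reloading can change predictions even though every saved variable is restored exactly, here is a minimal NumPy-only analogue (not EMLP or Objax code). `ToyNet`, `save_state` and `load_state` are hypothetical names. `w` stands in for the Objax variables that `save_var_collection` writes out, and `idx` stands in for randomness drawn from NumPy's global RNG at construction time that is never saved. That second part is an assumption modeled on the explanation later in this thread.

```python
import io
import numpy as np

class ToyNet:
    """Hypothetical stand-in for an EMLP: `w` is saved, `idx` is not."""

    def __init__(self, n=12):
        self.w = np.random.normal(size=n)    # analogue of saved Objax variables
        self.idx = np.random.permutation(n)  # analogue of unsaved NumPy randomness

    def __call__(self, x):
        # A main term plus a smaller term that depends on the unsaved `idx`
        return np.array([self.w @ x + 0.1 * (self.w @ x[self.idx])])

def save_state(f, net):
    # Only `w` is written, mirroring save_var_collection(f, net.vars())
    np.savez(f, w=net.w)

def load_state(f, net):
    # `idx` keeps whatever the new constructor drew
    net.w = np.load(f)["w"]

x = np.linspace(0, 1, 12)

# "Saving script"
net = ToyNet()
y_before = net(x)
buf = io.BytesIO()
save_state(buf, net)

# "Loading script": a fresh model built while the global RNG is in a different
# state, as it would be in a new, unseeded Python process
buf.seek(0)
net2 = ToyNet()
load_state(buf, net2)
y_after = net2(x)

print(y_before, y_after)  # weights match, but the outputs generally differ
```

The saved weights round-trip exactly; the mismatch comes entirely from state that was regenerated rather than loaded.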

mfinzi added the bug label Jun 5, 2021
mfinzi self-assigned this Jun 5, 2021

mfinzi (Owner, Author) commented Jun 5, 2021

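It looks like this is caused by randomness in the bilinear layer that is not stored in Objax state variables, so `save_var_collection` never writes it out and the loading script regenerates it differently. A workaround is to set the same NumPy seed before constructing the model in both scripts: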

```python
import emlp
from emlp.groups import SO
from emlp.reps import T,V
import numpy as np
import objax

np.random.seed(42)
net = emlp.nn.EMLP(T(1)+T(2),T(0),SO(3))
x=np.linspace(0,1,12)
print(net(x).T)

# Saving with file descriptor
with open('net.npz', 'wb') as f:
    objax.io.save_var_collection(f, net.vars())
```

Output: [0.03251831]

```python
import emlp
from emlp.groups import SO
from emlp.reps import T,V
import numpy as np
import objax

# Same seed as the saving script, set before the model is constructed
np.random.seed(42)
net = emlp.nn.EMLP(T(1)+T(2),T(0),SO(3))
x=np.linspace(0,1,12)
# Loading with file descriptor
with open('net.npz', 'rb') as f:
    objax.io.load_var_collection(f, net.vars())

print(net(x).T)
```

Output: [0.03251831]
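One possible refinement of the seeding workaround (not something EMLP or Objax provides) is to store the seed inside the checkpoint, so the loading code cannot accidentally use a different one. The sketch below uses a hypothetical NumPy-only `ToyNet` rather than a real EMLP, and `build`, `save_checkpoint` and `load_checkpoint` are also hypothetical. It assumes, as the comment above suggests, that all unsaved randomness is drawn from NumPy's global RNG during construction. Note that `np.random.seed` resets global state for the rest of the program.

```python
import io
import numpy as np

class ToyNet:
    """Hypothetical stand-in for an EMLP: `w` is saved, `idx` is not."""

    def __init__(self, n=12):
        self.w = np.random.normal(size=n)    # analogue of saved Objax variables
        self.idx = np.random.permutation(n)  # analogue of unsaved NumPy randomness

    def __call__(self, x):
        return np.array([self.w @ x + 0.1 * (self.w @ x[self.idx])])

def build(seed):
    # Same idea as np.random.seed(42) before emlp.nn.EMLP(...) in the scripts above
    np.random.seed(seed)
    return ToyNet()

def save_checkpoint(f, net, seed):
    # Keep the seed next to the weights so the loader can rebuild the unsaved state
    np.savez(f, w=net.w, seed=np.array(seed))

def load_checkpoint(f):
    data = np.load(f)
    net = build(int(data["seed"]))  # regenerates the same `idx`
    net.w = data["w"]               # restores the saved weights
    return net

x = np.linspace(0, 1, 12)
net = build(42)
y_before = net(x)

buf = io.BytesIO()
save_checkpoint(buf, net, seed=42)

np.random.normal(size=100)  # simulate unrelated RNG use before loading
buf.seek(0)
net2 = load_checkpoint(buf)
y_after = net2(x)
```

With a real EMLP this would still rely on the library drawing the same random numbers in the same order during construction, which may not hold across EMLP or NumPy versions.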
