Hello,
This might be a basic question, but I'm wondering if there's a canonical way to go about this. I have a group of feedforward nets that I am training. Each data point is an array that gets fed into one of these nets, along with an index identifying which net should process it; hence my thinking of the overall net as a 'dict'. (For context, you can think of each data point as the position of an atom, along with an identifier for its element. I want a different net for each element, and to feed each data point to the appropriate net.)
I can write a module for each element-specific network individually.
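A minimal sketch of what I mean (the layer widths and the scalar output are just placeholders, not my real architecture):

```python
import flax.linen as nn

class ElementNetwork(nn.Module):
    """Feedforward net for a single element type."""
    element_index: int
    hidden: int = 64  # placeholder width

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.hidden)(x))
        return nn.Dense(1)(x)  # e.g. a per-atom energy contribution
```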
But I'm not sure how (or whether it's desirable) to wrap these individual nets up into one module I can train. I think it would be something like:
```python
from typing import Sequence

class MoleculeNetwork(nn.Module):
    elementnets: Sequence[ElementNetwork]

    def setup(self):
        self.elementdict = {net.element_index: net for net in self.elementnets}
        # initialize all the parameters here, and put them in a params dict

    def __call__(self, element, datum):
        elementnet = self.elementdict[element]
        return elementnet.apply(params[element], datum)
```
but I'm not entirely sure how to initialize such a net. I guess you could initialize all the params for each element's net in `setup`, then override MoleculeNetwork's `init` to return the dict of params? This could still be trained, since `params` is a pytree and a dict of params would still be a pytree, correct?
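To make that concrete, here's roughly how I imagine building the dict of params (the element indices and shapes here are made up):

```python
import jax
import jax.numpy as jnp

# One net per element index, e.g. H, C, O.
nets = {1: ElementNetwork(element_index=1),
        6: ElementNetwork(element_index=6),
        8: ElementNetwork(element_index=8)}

rng = jax.random.PRNGKey(0)
rngs = jax.random.split(rng, len(nets))
dummy_position = jnp.zeros((3,))
params = {z: net.init(k, dummy_position)
          for (z, net), k in zip(nets.items(), rngs)}

# Since a dict of param pytrees is itself a pytree, I'd expect
# jax.grad and optimizers to traverse it without special handling.
print(jax.tree_util.tree_map(jnp.shape, params))
```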
A related uncertainty for me is how to pass the data in. I understand that we want to write a loss function for a single (x, y) pair, then `vmap` (and then `jit`) it. But I'm thinking of my data points as tuples, i.e. `(element: int, datapoint: Array)` pairs, and I don't think a single-pair loss function that takes something like that can be vmapped, since the arguments have to be JAX NumPy arrays, right?
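Or would something along these lines work? (This is just a guess on my part; it assumes the element indices are first remapped to 0..n-1 so they can select a `lax.switch` branch, and it reuses the `nets` and `params` dicts from above.)

```python
import jax
import jax.numpy as jnp

element_ids = sorted(nets)  # branch i corresponds to element_ids[i]

def single_loss(params, element, datum, target):
    # lax.switch selects a branch from a traced integer, so `element`
    # can be batched just like the data itself.
    branches = [(lambda d, z=z: nets[z].apply(params[z], d))
                for z in element_ids]
    pred = jax.lax.switch(element, branches, datum)
    return jnp.sum((pred - target) ** 2)

# vmap takes multiple arguments (and pytrees of arrays); the inputs
# don't have to be packed into a single array.
batched_loss = jax.vmap(single_loss, in_axes=(None, 0, 0, 0))
loss = jax.jit(lambda p, e, d, t: jnp.mean(batched_loss(p, e, d, t)))
```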
Thanks for your help, and for creating this very cool library.