Replies: 7 comments 14 replies
-
I was looking at this closely; it would be very helpful to see the discussion or guidelines on how you are doing the Linen API design spec. If you could share that with us, that would be awesome! Please and thank you!
-
Yes, we are about to suggest that people play around with the Linen abstraction as an alpha release this week. We'll also post some of our design goals, perhaps as a new GitHub discussion. For now, to answer your question, you can take a look at some of our ported examples at https://github.com/google/flax/tree/master/linen_examples. The VAE example most clearly shows how to deal with module instances and methods on them.
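For readers landing here later, a minimal sketch of what defining, initializing, and applying a Linen module looks like. This is my own toy example (the `MLP` module is made up, not taken from the linked examples) and assumes the current Linen API:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

class MLP(nn.Module):  # hypothetical toy module, not from the linked examples
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        return nn.Dense(self.features)(x)

model = MLP(features=10)                           # module instances are plain dataclasses
x = jnp.ones((4, 3))
variables = model.init(jax.random.PRNGKey(0), x)   # returns the parameter pytree
y = model.apply(variables, x)                      # parameters are passed back explicitly
```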
-
@avital I tried the example, but I see that you never actually get hold of a reference to a module instance. Say you have a way to get a pretrained model:

```python
model, pretrained_params = load_pretrained()
```

and you want to create a new model to e.g. perform transfer learning:

```python
class AwesomeClassifier(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = self.somehow_use_pretrained_model(x)
        logits = nn.Dense(10)(x)
        return logits
```

How would you construct this module, and how are the pretrained parameters passed in?
-
@avital

```python
import numpy as np
from jax import random
from jax.config import config
config.enable_omnistaging()

from flax import linen as nn


class Child(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(32)(x)
        x = nn.Dense(10)(x)
        return x


class Parent(nn.Module):
    child: nn.Module

    @nn.compact
    def __call__(self, x):
        x = self.child()(x)
        x = nn.Dense(2)(x)
        return x


def load_pretrained(x):
    key = random.PRNGKey(42)
    child_params = Child().init(key, x)["param"]
    return Child, child_params


x = np.random.uniform(size=(64, 3))
model, pretrained_params = load_pretrained(x)

key = random.PRNGKey(42)
params = Parent(model).init(key, x)["param"]

# Manually add pretrained <== HERE
params["Child_0"] = pretrained_params  # FrozenDict forbids this but you get the idea

y = Parent(model).apply({"param": params}, x)
```

To transfer the pretrained parameters from the standalone `Child` into the `Parent`, I have to manually patch the parameter dict as shown at the `<== HERE` comment above.
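A small follow-up sketch (my own, not from the thread): one way around the FrozenDict restriction is to unfreeze, patch, and re-freeze, assuming `flax.core.freeze`/`unfreeze` and the same `"param"` collection name used in the snippet above:

```python
from flax.core import freeze, unfreeze

# Copy the frozen tree into a mutable dict, patch in the pretrained subtree, re-freeze.
params = unfreeze(Parent(model).init(key, x)["param"])
params["Child_0"] = pretrained_params
params = freeze(params)

y = Parent(model).apply({"param": params}, x)
```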
-
I also think this is very nice, and all frameworks are "forced" into this good habit by JAX. The thing is that they can also be wrongly modified so that they no longer agree with the computation, which is the issue we are seeing here.
I mean yes, BUT modern Python, especially with the help of the
Yeah, Elegy uses the equivalent of

```python
model = nn.Linear(10)
model.init(...)(...)
model.w, model.b  # <== model tracks its own weights so you can pass it around
```

Although you can also get them as a dict, which is needed for actual training:

```python
model.get_parameters()  # {"w": ..., "b": ...}
```

The thing is that if I pass the module to a parent:

```python
parent = Parent(model)
parent.init(...)(...)
parent.get_parameters()  # {"child": {"w": ..., "b": ...}, ...}
```

Elegy is intentionally stateful: if you train the parent, the child's parameters are updated as well, since the child tracks them by reference.
Nice! Will definitely check it out when it's available.
-
@avital This is a bit more meta, but I've been intrigued by the ideas proposed in parallax. I like:

```python
def loss(model, x, y):
    preds = model(x)
    return jnp.mean(jnp.square(y - preds))

grads = jax.grad(loss)(model, x, y)
model = jax.tree_multimap(lambda p, g: p - 0.01 * g, model, grads)
```

Things I don't like:

Anyway, it would be interesting if the JAX ecosystem could join forces to create a standard.
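For comparison, the same functional update pattern can be written against an explicit Flax parameter pytree instead of a self-contained model object. A rough sketch of my own, assuming `net` is some Linen module instance and the collection name matches your Flax version:

```python
import jax
import jax.numpy as jnp

# Assumes: net is a Flax Linen module instance and
# params = net.init(key, x)["params"]   (or "param" in the snippet further up, depending on version)

def loss(params, x, y):
    preds = net.apply({"params": params}, x)
    return jnp.mean(jnp.square(y - preds))

grads = jax.grad(loss)(params, x, y)
# Plain SGD step: params and grads share the same tree structure, so tree_map applies directly.
params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
```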
-
Apologies for reviving this old thread, but would this still be the recommended way of doing transfer learning?
I'd like to be able to freeze / not calculate gradients on the pretrained part of the model.
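Not an official answer, but one common pattern is to zero out the gradients of the pretrained subtree before the optimizer step. A rough sketch assuming the Parent/Child layout and the `loss`/`params` names from the snippets above:

```python
import jax
import jax.numpy as jnp
from flax.core import freeze, unfreeze

grads = jax.grad(loss)(params, x, y)

# Zero the gradients under the pretrained submodule so an SGD/optimizer step leaves it untouched.
grads = unfreeze(grads)
grads["Child_0"] = jax.tree_util.tree_map(jnp.zeros_like, grads["Child_0"])
grads = freeze(grads)
```

Optax also ships utilities such as `multi_transform` for applying different (or no) updates to different parts of the parameter tree, which may be cleaner for larger models.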
-
Hey, I am the main developer of Elegy. This library was initially supposed to be a Keras-like interface on top of an intermediate library like Flax or Haiku. Haiku was chosen at first, but it soon became apparent that if modules don't hold references to each other, things like transfer learning become very painful, since you have to merge parameter structures manually to match the code structure.
So in version 0.2.0, all submodules register themselves under their parent module and you can very simply extract and reuse them, since they also keep track of their own parameters and states. Take a look at this VAE example, where the standalone decoder is extracted by reference from the full VAE in order to generate new samples: https://github.com/poets-ai/elegy/blob/master/examples/mnist_vae.py#L197
But with this change we had to implement our own Module class and port all the layers since we were no longer compatible with Haiku (which was not the initial idea). I recently saw some comments in the Linen API which seem to point in this direction:
https://github.com/google/flax/blob/master/flax/linen/module.py#L51
However, it seems the API is very new and there is no documentation yet. I was wondering if you could give me some details on its capabilities: