Hi @cgarciae -- thank you so much for the detailed note, and I apologize for the late reply. I wanted to jot down some (incomplete) parts of an answer, as a starting point, after chatting with @andsteing about this a couple of weeks ago.

Both @jheek and I, in the past, explored a design that's similar to the one you propose (the one in Treex). And I agree with you that simply calling methods seems cleaner than The current Flax Linen design allows you to get access to stateful module instances (via

Here's the general problem, though. Once you have stateful module instances as a first-class recommended practice, you run into JAX rough edges and "footguns" -- you can accidentally make big mistakes and not even know about them. Here is one example (pseudo-code):
This example seems to work, and sort-of does. But it masks a big problem -- now the variables in

Ultimately, there's a fundamental tension between Python object semantics and JAX jitted functions. Flax takes the approach of intentionally exposing only pure functions, so that there's less opportunity to accidentally write code that "seems like it should work but doesn't". We should still be thinking about some of your comments above, such as the ergonomics of replacing parameters within a module subtree. Maybe we could think of a purely functional wrapper, like a "replace parameters" method of sorts. I haven't thought this through carefully.
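The tension between Python object semantics and jitted functions can be shown with a minimal sketch. This is not the pseudo-code example referred to in this reply; it uses a hypothetical `Counter` class in place of a Flax module, and relies only on the fact that `jax.jit` traces a function once per input signature and then reuses the compiled result.

```python
import jax
import jax.numpy as jnp


class Counter:
    """Hypothetical stateful object; stands in for a module that owns its state."""

    def __init__(self):
        self.calls = 0  # ordinary Python attribute, invisible to JAX

    def step(self, x):
        self.calls += 1  # side effect: executed only while JAX traces the function
        return x + 1.0


counter = Counter()
jitted_step = jax.jit(counter.step)

x = jnp.zeros(())
for _ in range(3):
    x = jitted_step(x)

# The numerical result is correct, but the Python-side state is stale:
# the method body ran during tracing only, not on every call.
print(float(x), counter.calls)
```

Nothing raises an error here, which is what makes this kind of mistake easy to miss.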
-
I just want to add that I really appreciate the effort you took to compare Flax and Treex in such a detailed way! I agree with you that things in Treex look much simpler and more "PyTorch"-like, so in a sense more normal and closer to what you would expect. However, I also agree with Avital that having pure functions is absolutely fundamental to make sure you interoperate well with JAX. If you don't do this, you will definitely run into problems that are hard to catch. That said, I do think you raise a few great points regarding the Flax API that can/should be improved. Here are two replies:
-
Hey @avital and @marcvanzee, thanks for your responses! I've recently been thinking of an alternative method for the "Transferring Parameters" problem which might fit within Flax's model, although I don't know some of the internals, so I hope this is not too crazy. If this idea makes sense I would be happy to port it to a FLIP.

The bind_init API
The core idea behind
Using pretrained models
Performing parameter transfer this way is more streamlined, as you don't actually have to know the names of where this will be used, and the pattern should be easy to remember:

class MyModule(Module):
    pretrained_module: PretrainedModule
    linear: Linear

    def __call__(self, x):
        x = self.pretrained_module(x)
        x = self.linear(x)
        return x

pretrained_variables, pretrained_module = load_pretrained(...)
# bind_init is the proposed API: attach the pretrained variables to the module
pretrained_module = pretrained_module.bind_init(pretrained_variables)
module = MyModule(pretrained_module)
variables = module.init(...)

Loading a pretrained module inside init
@dataclass
class MyModule(Module):
    pretrained_module: PretrainedModule
    linear: Linear

    def __post_init__(self):
        # load and bind inside the module, so the caller never sees the loading code
        pretrained_variables, pretrained_module = load_pretrained(...)
        self.pretrained_module = pretrained_module.bind_init(pretrained_variables)
        ...

    def __call__(self, x):
        x = self.pretrained_module(x)
        x = self.linear(x)
        return x

module = MyModule()
variables = module.init(...)

Extracting a submodule
This one is a bit different because we will actually call

class VAE(Module):
    encoder: Encoder
    decoder: Decoder

    # let's pretend it's this simple
    def __call__(self, x):
        z = self.encoder(x)
        y = self.decoder(z)
        return y

module = VAE()
variables = module.init(...)
# do some training and stuff
...
# get new structure
decoder = module.bind_init(variables).decoder
decoder_variables = decoder.init(...)  # should use all the existing ones
...  # maybe you serialize it and load it in another process
y = decoder.apply(decoder_variables, z, ...)

As this example shows, it would actually be expected that init binding is applied recursively.

Notes
-
Hey Flax team! I had a chat with @avital a while ago and he invited me to post about things that I believe are usability issues in Flax so here it is.
In this discussion I will be comparing Flax with the new family of Pytree-based Module systems that includes Treex and Equinox, since they seem to improve usability while being roughly as powerful (there are known edge cases though). Please feel free to suggest edits; I am writing this in "vanilla" Flax and might not know the more advanced mechanics.
Goal
Figure out if there is a way for Flax to simplify some of the presented cases.
Calling Modules
From a user perspective this is probably the most noticeable difference between Flax and traditional module systems like Pytorch and Keras.
Normal Call
Flax
When calling the top module in Flax using apply you have a lot of control at the cost of being significantly more verbose. You also get an asymmetry between calling modules inside other modules and calling them on the outside.
Treex
Pytree Modules tend to behave like regular Python objects: since they contain their own parameters and can be jitted, they generally don't need apply-like functions and rely on __call__ directly.
Calling Methods
This example is the same as the previous one, but the methods some_method and another_method will be called instead of the __call__ method.
Flax
To achieve this in Flax you use the method argument in apply, which lets you call a method; the tradeoffs are the same as in the previous example.
Note: I've never done this but I assume methods are called normally when inside other modules.
Treex
As in the previous example, in Pytree Modules you just call the method directly.
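A rough sketch of what "just call the method" looks like for a Pytree Module. This does not use Treex or Equinox themselves; it hand-rolls a hypothetical `Linear` class with `jax.tree_util.register_pytree_node_class`, which is the general mechanism such libraries build on (their actual implementations may differ).

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Linear:
    """Hypothetical pytree module that carries its own parameters."""

    def __init__(self, w, b):
        self.w = w
        self.b = b

    def tree_flatten(self):
        # Parameters are the pytree leaves; there is no static data here.
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def __call__(self, x):
        return x @ self.w + self.b

    def some_method(self, x):
        return jnp.tanh(self(x))


@jax.jit
def forward(module, x):
    # The module is an ordinary jit argument, so methods are called directly.
    return module.some_method(x)


layer = Linear(jnp.ones((3, 4)), jnp.zeros(4))
y = forward(layer, jnp.ones((2, 3)))
```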
Transferring Parameters
Transferring parameters from pretrained modules in Flax is one of the areas that will probably cause more friction for new users coming from Pytorch or Keras as having a parameter structure (variables) separate from the computational structure (the module) can be both a blessing and a curse.
Using pretrained modules
This first use case is about performing the typical Transfer Learning task of loading a pretrained model and fine-tuning it with an added linear classifier on top.
Flax
The tricky thing in Flax is inserting the pretrained parameters into the correct place on the new parameter structure as names will matter here and have to be known in advance.
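To make the naming issue concrete, here is a small sketch that uses plain nested dicts as a stand-in for Flax variables (which are nested mappings keyed by collection and then by submodule name). The helper `insert_pretrained` and all key names are hypothetical, and real variables would hold arrays rather than lists.

```python
import copy

# Stand-ins for what load_pretrained(...) and module.init(...) might return.
pretrained_variables = {"params": {"conv": {"kernel": [1.0, 2.0]}}}
fresh_variables = {
    "params": {
        "pretrained_module": {"conv": {"kernel": [0.0, 0.0]}},
        "linear": {"kernel": [0.5]},
    }
}


def insert_pretrained(variables, name, pretrained):
    """Return a copy of `variables` with the subtree `name` replaced."""
    if name not in variables["params"]:
        # The target name has to match the attribute name in the new module.
        raise KeyError(f"no submodule named {name!r} in params")
    new_variables = copy.deepcopy(variables)
    new_variables["params"][name] = copy.deepcopy(pretrained["params"])
    return new_variables


variables = insert_pretrained(fresh_variables, "pretrained_module", pretrained_variables)
```

A typo in the name is only caught here because the helper checks for it explicitly; with a bare dict assignment it would silently add a new, unused entry.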
Treex
Pytree Modules don't require anything special; it works like in PyTorch / Keras.
Loading a pretrained module inside __init__
This is a continuation of the previous case, but the loading step happens inside the new module. The motivation is that sometimes the loading code is abstracted away from the user; Keras "applications" tend to do something like this.
Flax
I don't know if this is possible in Flax.
Treex
For Pytree Modules the only thing that changes from the previous case is that the pretrained module is not passed as a parameter to the constructor but instead loaded and assigned within it.
Extracting a submodule
Here we extract a submodule as its own "top-level" module.
Flax
The tricky thing in Flax, again, is getting the parameters into the structure of their new module; names matter here and there is no editor support to get this right.
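As a sketch of the manual step involved, again with plain nested dicts standing in for Flax variables: the helper `extract_submodule_variables` and the collection and key names below are made up for illustration.

```python
def extract_submodule_variables(variables, name):
    """Pull the subtree called `name` out of every collection that has one."""
    return {
        collection: tree[name]
        for collection, tree in variables.items()
        if name in tree
    }


# Stand-in for the variables of a trained VAE-like module.
vae_variables = {
    "params": {
        "encoder": {"dense": {"kernel": [0.1]}},
        "decoder": {"dense": {"kernel": [0.2]}},
    },
    "batch_stats": {"decoder": {"norm": {"mean": [0.0]}}},
}

# The string "decoder" has to match the attribute name used in the parent module.
decoder_variables = extract_submodule_variables(vae_variables, "decoder")
```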
Treex
Nothing special for Pytree Modules, just getting hold of the reference is enough.
Static State
Static state is often used to control, amongst other things, how Modules behave during training vs testing.
Flax
In Flax you do this by propagating static state all the way down from the jitted functions to the inner Modules.
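A minimal sketch of that propagation in plain JAX, without any Flax modules: the functions `inner_layer`, `model` and `forward` are hypothetical, and `static_argnames` is used here as the keyword-based counterpart of `static_argnums`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def inner_layer(x, training):
    # A Python-level branch; legal under jit only because `training` is static.
    return x * 0.5 if training else x


def model(params, x, training):
    h = x @ params["w"]
    return inner_layer(h, training)  # the flag is threaded down by hand


@partial(jax.jit, static_argnames="training")
def forward(params, x, training):
    return model(params, x, training)


params = {"w": jnp.ones((3, 2))}
x = jnp.ones((1, 3))
train_out = forward(params, x, training=True)
eval_out = forward(params, x, training=False)
```

Every function between the jitted entry point and the layer that needs the flag has to accept and forward it.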
Treex
Pytree Modules can keep this state around in the static part of the Pytree, so they can keep track of it for the user; JAX will recompile upon change, so static_argnums is not needed.
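A sketch of how the static part of a pytree can carry such a flag, using a hypothetical `Scale` class registered by hand (Treex and Equinox provide their own machinery for this). The `trace_log` list is only there to make retracing observable.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Scale:
    """Hypothetical module whose `training` flag is static pytree data."""

    def __init__(self, factor, training=True):
        self.factor = factor
        self.training = training

    def tree_flatten(self):
        # `factor` is a leaf; `training` goes into the static auxiliary data.
        return (self.factor,), self.training

    @classmethod
    def tree_unflatten(cls, training, children):
        return cls(children[0], training)

    def __call__(self, x):
        return x * self.factor if self.training else x


trace_log = []  # records the flag each time the function is traced


@jax.jit
def forward(module, x):
    trace_log.append(module.training)
    return module(x)


x = jnp.ones(3)
train_out = forward(Scale(jnp.asarray(0.5), training=True), x)
eval_out = forward(Scale(jnp.asarray(0.5), training=False), x)  # flag changed: retrace
forward(Scale(jnp.asarray(0.5), training=True), x)              # seen before: cache hit
```

Because the flag is part of the pytree structure, changing it gives `jax.jit` a different cache key, which is why no `static_argnums` appears anywhere.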