How to design classes that can load pretrained checkpoints. #2454
-
Interesting, I guess the goal is to eliminate the need for the separate configuration class and directly define the model without a separate module. The new API looks fine to me. I didn't really understand the issue with dataclasses (which avoid repetitive code and errors). Can you just decorate it directly with …?
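For reference, the plain dataclass-style module definition being alluded to — fields declared as attributes, no custom `__init__` (example mine):

```python
import flax.linen as nn

# Default linen style: the module *is* a dataclass; linen generates
# __init__ (including the parent/name plumbing) from the fields.
class AddConstant(nn.Module):
    a: int

    def __call__(self, x):
        return x + self.a

module = AddConstant(a=1)
y = module.apply({}, 3)  # -> 4
```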
-
Hey @patrickvonplaten, this would be amazing! I've been using the …

Regarding custom `__init__`s:

Point 1 is required because Flax internally re-instantiates modules by calling `__init__` with the `parent` and `name` keyword arguments (see lines 972 to 974 in e320e11). Example:

```python
import flax.linen as nn

class GoodCustomInit(nn.Module):
    a: int

    def __init__(self, a: int, parent=None, name=None):
        self.a = a
        self.parent = parent
        self.name = name

    def __call__(self, x):
        return x + self.a

module = GoodCustomInit(1)
y = module.apply({}, 3)
```

Notice you must additionally provide and set the `parent` and `name` attributes. If you don't, `apply` fails when Flax re-instantiates the module:

```python
class BadCustomInit(nn.Module):
    a: int

    def __init__(self, a: int):
        self.a = a

    def __call__(self, x):
        return x + self.a

module = BadCustomInit(1)
y = module.apply({}, 3)  # TypeError: __init__() got an unexpected keyword argument 'parent'
```

Additionally, beware of how mixins interact with Flax (e.g. #1409), and in general take into account how inheritance works in dataclasses (check out some patterns that don't work here). A strategy that might work:
Example:

```python
import flax.linen as nn

class A(nn.Module):
    a: int

    def __init__(self, a: int, parent=None, name=None):
        self.a = a
        self.parent = parent
        self.name = name

    def __call__(self, x):
        return x + self.a

class B(A):
    b: int

    def __init__(self, a: int, b: int, **kwargs):
        self.b = b
        super().__init__(a, **kwargs)

    def __call__(self, x):
        return super().__call__(x) + self.b

module = B(1, 2)
y = module.apply({}, 3)
```
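As an aside, a minimal example (mine, plain `dataclasses`, not from the thread) of the inheritance pattern mentioned above that does not work:

```python
from dataclasses import dataclass

@dataclass
class Base:
    a: int = 1  # field with a default

@dataclass
class Child(Base):  # TypeError at class creation:
    b: int          # non-default argument 'b' follows default argument 'a'
```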
-
To avoid all the mess I've mentioned about having custom `__init__`s, a new decorator could be applied to the class directly:

```python
@new_register_to_config
class BertModel(nn.Module):
    ...
```
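A rough sketch (mine — `new_register_to_config` is hypothetical) of what such a class decorator could do: derive the config from the module's auto-generated dataclass fields, so no custom `__init__` is needed:

```python
import dataclasses
import flax.linen as nn

def new_register_to_config(cls):
    # Hypothetical: read the config off the module's dataclass fields,
    # so no custom __init__ (and no parent/name plumbing) is needed.
    def to_config(self):
        return {
            f.name: getattr(self, f.name)
            for f in dataclasses.fields(self)
            if f.name not in ("parent", "name")  # skip Flax-internal fields
        }
    cls.to_config = to_config
    return cls

@new_register_to_config
class Toy(nn.Module):
    features: int
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features, use_bias=self.use_bias)(x)

module = Toy(features=4)
print(module.to_config())  # {'features': 4, 'use_bias': True}
```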
-
Hey,
We would like to integrate Flax/JAX into the diffusers repo: huggingface/diffusers#475.
Similar to Transformers, a very important aspect of `diffusers` is the ability to load pretrained checkpoints easily into the classes, so we would like to have the following functionality.
Because of this we decided on the following design in Transformers (defining it here as design 1).
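For concreteness, design 1 in Transformers looks like this (sketch; checkpoint name illustrative):

```python
import jax.numpy as jnp
from transformers import FlaxBertModel

# Design 1: from_pretrained returns a wrapper that owns the weights;
# calling it reads model.params implicitly, which is stateful-ish.
model = FlaxBertModel.from_pretrained("bert-base-uncased")
input_ids = jnp.ones((1, 8), dtype="i4")
outputs = model(input_ids)  # uses model.params under the hood
```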
In hindsight, @patil-suraj and I had the feeling that this might have been the wrong design, as it breaks the "stateless" assumption of JAX models. This is the reason we added a second API to transformers only recently: huggingface/transformers#16148
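For illustration, a sketch of that second API (checkpoint name illustrative):

```python
import jax.numpy as jnp
from transformers import FlaxBertModel

# Second API: skip materializing weights on the wrapper and get the
# params back separately, then pass them in explicitly per call.
model, params = FlaxBertModel.from_pretrained("bert-base-uncased", _do_init=False)
input_ids = jnp.ones((1, 8), dtype="i4")
outputs = model(input_ids, params=params)
```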
Having thought about this a bit more, we would like to change the design in diffusers from the start and only adopt the "stateless" solution. This would mean we would like to do something like the following (design 2).
Note that this would mean that the API for JAX would be something like:
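A hedged sketch of that API — `FlaxUNet2DModel` and the checkpoint are assumed names for this sketch, shapes illustrative:

```python
import jax.numpy as jnp
from diffusers import FlaxUNet2DModel  # assumed class name

# Design 2: the module itself is stateless; from_pretrained returns
# (module, params), and params are threaded through apply explicitly.
model, params = FlaxUNet2DModel.from_pretrained("fusing/unet-ldm-dummy")
sample = jnp.zeros((1, 3, 32, 32))     # illustrative shapes
timestep = jnp.ones((1,), dtype="i4")
out = model.apply({"params": params}, sample, timestep)
```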
Does this API make sense to you?
The main questions would be:
1.) Is it ok to define an `__init__(...)` in an `nn.Module`? Note that we need to do this so that our `register_to_config` function works as expected: https://github.com/huggingface/diffusers/blob/25a51b63ca75e1351069bee87a0fb3df5abb89c3/src/diffusers/models/unet_2d.py#L58 (a simplified sketch of this pattern follows below)
2.) What do you think about the design opinion here in general? Would you also favor 2.) over 1.) to keep "stateless-ness" intact?
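For context on question 1, a simplified sketch (mine, not the actual diffusers source) of the `register_to_config` pattern: the decorator wraps `__init__`, inspects its arguments, and records them as the config — which is why a real `__init__` must exist:

```python
import functools
import inspect

def register_to_config(init):
    """Simplified sketch: capture the arguments passed to __init__
    and stash them as the instance's config."""
    @functools.wraps(init)
    def wrapped_init(self, *args, **kwargs):
        bound = inspect.signature(init).bind(self, *args, **kwargs)
        bound.apply_defaults()
        config = {k: v for k, v in bound.arguments.items() if k != "self"}
        # object.__setattr__ sidesteps frozen/managed attribute handling
        object.__setattr__(self, "_internal_config", config)
        init(self, *args, **kwargs)
    return wrapped_init

class MyModel:
    @register_to_config
    def __init__(self, sample_size=32, in_channels=3):
        self.sample_size = sample_size
        self.in_channels = in_channels

m = MyModel(sample_size=64)
print(m._internal_config)  # {'sample_size': 64, 'in_channels': 3}
```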
cc @marcvanzee @jheek @cgarciae @levskaya @jekbradbury @boris @borisdayma