Refactor ADM classes #488
[ghstack-poisoned]
@ebsmothers has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Codecov Report
Additional details and impacted files
@@            Coverage Diff             @@
##   gh/ebsmothers/19/base    #488    +/-   ##
==============================================
+ Coverage      74.01%    74.05%    +0.04%
==============================================
  Files            207       207
  Lines          14203     14225      +22
==============================================
+ Hits           10512     10535      +23
+ Misses          3691      3690       -1
Looks great overall and structurally good. I left some comments on naming and parameters.
@@ -24,25 +24,55 @@
    SinusoidalPositionEmbeddings,
)

DEFAULT_EMBED_NAME = "clip_image"
I think this should be a parameter for ADMUNet instead of a global variable. If you look at the transform and adapter classes, they all take the conditional input key as "*_field"; for this case it could be "conditional_field". Also, I don't think the default name should be based on CLIP, as that's a dalle input and not generic to ADM. It should be something generic like "context" or "condition".
Updated to context, left clip_image in the builder. Lmk if that's what you have in mind
CLIP is a dalle2 input and is not specific to ADM. So I think we should remove the name from everywhere. Also I still don't believe this should be a global variable.
Ditched the global variable and used "context" for default in both the ADMUNet class and the builder
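To illustrate the change discussed in this thread, here is a minimal sketch of passing the conditional input key as a constructor argument with a generic default, rather than reading it from a module-level constant. Everything in it is hypothetical: `ToyConditionalModel`, `toy_model_builder`, and the parameter name `conditional_field` are stand-ins chosen for illustration and are not the actual ADMUNet code or its final signature.

```python
from typing import Dict, List, Optional


class ToyConditionalModel:
    """Hypothetical stand-in for a conditional model; not the real ADMUNet."""

    def __init__(self, conditional_field: str = "context") -> None:
        # The key lives on the instance, so two models can read different
        # entries of the same conditional-input dict without a shared global.
        self.conditional_field = conditional_field

    def forward(
        self,
        x: List[float],
        conditional_inputs: Optional[Dict[str, List[float]]] = None,
    ) -> List[float]:
        cond = (conditional_inputs or {}).get(self.conditional_field)
        if not cond:
            # No conditioning available under this key: pass x through.
            return list(x)
        # Toy "conditioning": shift x by the mean of the conditioning vector.
        shift = sum(cond) / len(cond)
        return [v + shift for v in x]


def toy_model_builder(conditional_field: str = "context") -> ToyConditionalModel:
    """Hypothetical builder that forwards the key instead of hard-coding it."""
    return ToyConditionalModel(conditional_field=conditional_field)


inputs = {"context": [1.0, 3.0], "clip_image": [10.0, 10.0]}
default_model = toy_model_builder()
clip_model = toy_model_builder(conditional_field="clip_image")
out_default = default_model.forward([0.0, 1.0], inputs)
out_clip = clip_model.forward([0.0, 1.0], inputs)
```

With this shape, a caller that feeds CLIP image embeddings can still pass `conditional_field="clip_image"` explicitly, while the default stays generic.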
@@ -453,7 +479,7 @@ def adm_unet(
    time_embed_dim: int = 512,
    cond_embed_dim: int = 2048,
    clip_embed_dim: int = 768,
These two variables should be just "embed_*" without clip
Differential Revision: [D50288361](https://our.internmc.facebook.com/intern/diff/D50288361) [ghstack-poisoned]
It looks good now. Thanks for merging these two classes!
@ebsmothers merged this pull request in 367130e.
Stack from ghstack (oldest at bottom):
Differential Revision: D50288361