
Refactor ADM classes #488

Closed · wants to merge 4 commits

Conversation

ebsmothers (Contributor) commented on Oct 13, 2023:

[ghstack-poisoned]
@facebook-github-bot added the CLA Signed label on Oct 13, 2023.
@ebsmothers has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

codecov-commenter commented on Oct 13, 2023:

Codecov Report

Attention: 2 lines in your changes are missing coverage. Please review.

Comparison is base (1fd96dc) 74.01% compared to head (8c0850f) 74.05%.

Additional details and impacted files
@@                    Coverage Diff                    @@
##           gh/ebsmothers/19/base     #488      +/-   ##
=========================================================
+ Coverage                  74.01%   74.05%   +0.04%     
=========================================================
  Files                        207      207              
  Lines                      14203    14225      +22     
=========================================================
+ Hits                       10512    10535      +23     
+ Misses                      3691     3690       -1     
Files                                                  Coverage Δ
tests/diffusion_labs/test_adm.py                       100.00% <100.00%> (ø)
...dal/diffusion_labs/models/dalle2/dalle2_decoder.py  100.00% <ø> (ø)
...chmultimodal/diffusion_labs/models/adm_unet/adm.py   98.70% <96.36%> (+0.74%) ⬆️


@ebsmothers has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

pbontrager (Contributor) left a comment:

Looks great overall, and structurally good. I left some comments around naming and parameters.

@@ -24,25 +24,55 @@
SinusoidalPositionEmbeddings,
)

DEFAULT_EMBED_NAME = "clip_image"

pbontrager (Contributor):

I think this should be a parameter for ADMUnet instead of a global variable. If you look at the transform and adapter classes, they all take the conditional input key as "*_field"; for this case it could be "conditional_field". Also, I don't think the default name should be based on CLIP, as that's a DALLE-2 input and not generic to ADM. It should be something generic like "context" or "condition".

ebsmothers (Contributor, Author):

Updated to "context" and left clip_image in the builder. Let me know if that's what you have in mind.

pbontrager (Contributor):

CLIP is a DALLE-2 input and is not specific to ADM, so I think we should remove the name from everywhere. Also, I still don't believe this should be a global variable.

ebsmothers (Contributor, Author):

Ditched the global variable and used "context" as the default in both the ADMUNet class and the builder.
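
For reference, here is a minimal sketch of the agreed-upon shape, assuming a hypothetical `cond_field` parameter name (the thread suggests "conditional_field"; the actual name and signature in torchmultimodal may differ):

```python
import torch
from torch import nn


class ADMUNet(nn.Module):
    """Illustrative stand-in only: the conditional input key is a
    constructor argument with a generic default ("context"), rather
    than a module-level DEFAULT_EMBED_NAME = "clip_image" constant."""

    def __init__(self, embed_dim: int = 768, cond_field: str = "context"):
        super().__init__()
        self.cond_field = cond_field
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(
        self, x: torch.Tensor, conditional_inputs: dict[str, torch.Tensor]
    ) -> torch.Tensor:
        # Look up the conditioning embedding by the configured key.
        cond = conditional_inputs[self.cond_field]
        return x + self.proj(cond)


def adm_unet(embed_dim: int = 768, cond_field: str = "context") -> ADMUNet:
    # The generic builder keeps the generic default; a DALLE-2-specific
    # builder can pass cond_field="clip_image" explicitly.
    return ADMUNet(embed_dim=embed_dim, cond_field=cond_field)
```

A DALLE-2 caller would then pass `cond_field="clip_image"` at construction time, keeping the CLIP-specific name out of the generic ADM code.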

Review thread on torchmultimodal/diffusion_labs/models/adm_unet/adm.py (outdated, resolved).
@@ -453,7 +479,7 @@ def adm_unet(
time_embed_dim: int = 512,
cond_embed_dim: int = 2048,
clip_embed_dim: int = 768,

pbontrager (Contributor):

These two variables should be just "embed_*", without the clip prefix.
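
As a rough illustration of the suggested rename (the full builder signature isn't visible in this thread, so treat the function names and any parameters beyond those shown in the hunk above as assumptions):

```python
# Before: a CLIP-prefixed parameter leaks a DALLE-2 detail into the
# generic ADM builder.
def adm_unet_before(
    time_embed_dim: int = 512,
    cond_embed_dim: int = 2048,
    clip_embed_dim: int = 768,
) -> None:
    ...


# After: generic embed_* naming; CLIP-specific values are supplied by the
# dalle2_decoder builder instead of being baked into the defaults here.
def adm_unet_after(
    time_embed_dim: int = 512,
    cond_embed_dim: int = 2048,
    embed_dim: int = 768,
) -> None:
    ...
```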

@ebsmothers has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@ebsmothers has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

pbontrager (Contributor) left a comment:

It looks good now. Thanks for merging these two classes!

@ebsmothers merged this pull request in 367130e.

@facebook-github-bot deleted the gh/ebsmothers/19/head branch on October 20, 2023 at 14:21.
Labels: CLA Signed, Merged
4 participants