
OT-CFM performs worse on conditional generation tasks #117

Open
rlee3359 opened this issue May 28, 2024 · 1 comment
rlee3359 commented May 28, 2024

Thanks very much for this code base, it's been a great way to learn about flow matching. I have a question regarding conditional generation with OT-CFM.

When testing different FM approaches on my own data, I noticed that OT-CFM trains significantly slower and tends to perform much worse on tasks with conditioning. In an effort to isolate this problem I tried conditional MNIST, comparing OT-CFM with FM (using the example provided).

After a single epoch of training, I visualized the generations of both approaches with one step and with dopri5. FM is on the left, OT-CFM is on the right.

One-step generation (Euler with 1 step):

Adaptive generation with dopri5:

After one epoch of training, FM produces much nicer generations both with a single sampling step and with dopri5. Even with longer training, FM continues to outperform OT-CFM (it converges much faster).

After reading more, I noticed that both OT-CFM and Multisample Flow Matching papers only report results for unconditional generation, while papers doing conditional generation such as Stable Diffusion 3 and Flow Matching in Latent Space seem to use standard flow matching without batch optimal transport.

I wonder if the authors have studied this, whether there are any results for OT-CFM on conditional tasks, or whether there is a reason OT-CFM should not work in this setting. My intuition was that adding conditioning makes the combinatorial space of the OT plan extremely hard to approximate from the limited samples in the batch, and that this would be further exacerbated if the conditioning is not simple class labels but continuous values (for example, language embeddings for text-to-image generation).

I would greatly appreciate any insight on this, and any pointers to an approach that is applicable to conditional generation. Thank you!
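To make the pairing behavior concrete, here is a minimal toy sketch (my own construction for illustration, separate from the training code and from torchcfm): for uniform weights and equal batch sizes, exact minibatch OT with squared-Euclidean cost reduces to an assignment problem, and the resulting pairing depends only on the geometry of the samples, never on the conditioning labels — so the labels must simply be permuted along with the re-paired targets.

```python
# Toy sketch of minibatch exact OT pairing via the Hungarian algorithm.
# All names here are illustrative; this is not torchcfm code.
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(0)
batch = 8
x0 = rng.normal(size=(batch, 2))         # noise samples
x1 = rng.normal(size=(batch, 2)) + 5.0   # "data" samples
y = np.arange(batch)                     # per-sample conditioning labels

# Squared-Euclidean cost matrix; with uniform weights and equal batch
# sizes, exact OT reduces to an assignment problem.
cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(-1)
rows, cols = linear_sum_assignment(cost)

# After re-pairing, x0[i] is matched with x1[cols[i]]. The labels are
# permuted the same way; the cost (and hence the pairing) never saw y.
y1 = y[cols]
print(y1)
```

The point is that nothing in the cost matrix involves y, so when conditioning matters, OT can match a noise sample to a target from a different condition's "neighborhood" purely because it is geometrically close.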

The code tweaks for this were (with the torchcfm imports added for context):

from torchcfm.conditional_flow_matching import (
    ExactOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
)

sigma = 0.0
if args.fm_method == "fm":
    FM = TargetConditionalFlowMatcher(sigma=sigma)
elif args.fm_method == "otcfm":
    FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)

# ... inside the training loop:
if args.fm_method == "fm":
    # Independent coupling: x0 and x1 keep their original pairing,
    # so the labels are unchanged.
    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
    y1 = y
elif args.fm_method == "otcfm":
    # Batch OT re-pairs x0 with x1, so the labels must be permuted
    # along with x1 (returned here as y1).
    t, xt, ut, _, y1 = FM.guided_sample_location_and_conditional_flow(x0, x1, y1=y)
atong01 (Owner) commented May 31, 2024

Hi @rlee3359,

Interesting finding!

We have not explored OT in this setting very much. We do use it in a text-conditioned model in our most recent work on protein generation (see https://arxiv.org/abs/2405.20313), but we did not test the extent to which OT helps there, as it worked so well in that conditional setting.

> My intuition was that adding conditioning makes the combinatorial space of the OT plan extremely hard to approximate from the limited samples in the batch, and that this would be further exacerbated if the conditioning is not simple class labels but continuous values (for example, language embeddings for text-to-image generation).

I'm not sure about this intuition: even if the OT plan is approximated poorly, that should just fall back to random pairings. Here, though, the OT pairing seems actively harmful.

It's also quite interesting that the one-step generations are all the same for FM but not for OT-CFM, and that the dopri5 samples look more uniform for FM (line thickness especially). I suspect FM first learns some averaged image, whereas OT-CFM may be forced to directly predict the diverse images from the noise, which is probably difficult to learn, especially early in training.

Happy to discuss sometime if you're interested.

--Alex
