NeuralODE trajectory API is quite limiting #113

Open
rsanchezgarc opened this issue Apr 22, 2024 · 2 comments

Comments

@rsanchezgarc

Hi,

I am trying to use your package with SuperResModel (from torchcfm.models.unet.unet import SuperResModel) and with other custom models whose forward method takes keyword arguments, but the NeuralODE.trajectory method does not seem to be compatible with those models.

Could you please add a model_kwargs parameter to NeuralODE.trajectory, NeuralODE.forward, etc.?

Thanks!

@kilianFatras
Collaborator

kilianFatras commented Apr 23, 2024

Hello,

This seems to be a problem with Torchdyn. A workaround might be to use torchdiffeq instead, or to write your own custom Euler integration method. Unfortunately, as the NeurIPS deadline is only one month away, I will not have time to look into this issue, especially as it is not really related to TorchCFM but rather to Torchdyn.
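A minimal sketch of the custom Euler route, assuming the vector field is called as model(t, x, **model_kwargs). The function name euler_trajectory is hypothetical and is not part of TorchCFM or Torchdyn. It only uses +, - and *, so it runs on plain Python floats for testing and should also work with PyTorch tensors.

```python
from typing import Any, Callable, List, Sequence


def euler_trajectory(
    model: Callable[..., Any],
    x0: Any,
    t_span: Sequence[Any],
    **model_kwargs: Any,
) -> List[Any]:
    """Hypothetical helper: fixed-step explicit Euler for dx/dt = model(t, x, **model_kwargs).

    Returns the states at every time in t_span, starting with x0.
    """
    xs = [x0]
    x = x0
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        # x_{k+1} = x_k + (t_{k+1} - t_k) * v(t_k, x_k); the extra kwargs are forwarded each step
        x = x + (t1 - t0) * model(t0, x, **model_kwargs)
        xs.append(x)
    return xs


# Toy vector field that needs a keyword argument: constant velocity `speed`.
def constant_velocity(t, x, speed=1.0):
    return speed


traj = euler_trajectory(constant_velocity, 0.0, [0.0, 0.5, 1.0], speed=2.0)
print(traj)  # [0.0, 1.0, 2.0]
```

With PyTorch you would typically call this under torch.no_grad() and combine the returned list with torch.stack. Fixed-step Euler is first order, so it usually needs more steps than an adaptive solver to reach the same accuracy.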

Best,
K.

@atong01
Owner

atong01 commented Apr 24, 2024

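I would implement this by inheriting from the SuperResModel class, i.e.:

from torchcfm.models.unet.unet import SuperResModel

class MySuperResModel(SuperResModel):
    def forward(self, t, x, *args, **kwargs):
        # Solver extras are ignored; set self.model_kwargs (a dict) before integrating
        return super().forward(t, x, **self.model_kwargs)

I think we do this in the conditional example.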
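The same idea can be sketched without subclassing a particular model: a small callable that stores the keyword arguments and exposes the f(t, x) signature that ODE solvers expect. The class name KwargsVectorField is hypothetical (not part of torchcfm or Torchdyn), and silently ignoring extra solver arguments is an assumption about how the solver calls the vector field. torchdiffeq's plain odeint calls func(t, y), so a wrapper like this could be passed to it directly (odeint_adjoint, by contrast, needs an nn.Module). Whether Torchdyn's NeuralODE accepts a plain callable rather than an nn.Module is not verified here; if it does not, the same pattern can be written as an nn.Module subclass that keeps the wrapped model as a submodule.

```python
from typing import Any, Callable


class KwargsVectorField:
    """Hypothetical helper: binds fixed keyword arguments to a model so that
    an ODE solver can call it as f(t, x)."""

    def __init__(self, model: Callable[..., Any], **model_kwargs: Any) -> None:
        self.model = model
        self.model_kwargs = model_kwargs

    def __call__(self, t: Any, x: Any, *args: Any, **kwargs: Any) -> Any:
        # Any extra arguments supplied by the solver are ignored; only the
        # keyword arguments bound at construction time reach the model.
        return self.model(t, x, **self.model_kwargs)


# Toy "model" whose forward needs a keyword argument.
def toy_model(t, x, scale=1.0):
    return scale * x


vf = KwargsVectorField(toy_model, scale=3.0)
print(vf(0.0, 2.0))           # 6.0
print(vf(0.0, 2.0, args={}))  # 6.0, the extra solver argument is ignored
```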
