
inefficient _validate_input and mistake #12

Open
Trotew opened this issue May 12, 2022 · 4 comments

Comments

@Trotew

Trotew commented May 12, 2022

The function `_validate_input` appears to be incorrect. The code is meant to reject non-monotonic data, but because `prev_t_i` is never updated, the check has no effect. In a CUDA environment this part of the code is also very inefficient.
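A minimal sketch of both fixes, assuming the monotonicity check in `_validate_input` loops over a 1-D time tensor `t` with a running `prev_t_i` as described above. The function names are hypothetical and this is not the package's actual code; treating repeated times as invalid (strictly increasing) is also an assumption. The first version just restores the missing update. The second replaces the loop with one vectorized comparison, so on CUDA there is a single device-to-host synchronization instead of one per element.

```python
import math

import torch


def check_strictly_increasing_loop(t: torch.Tensor) -> None:
    """Hypothetical loop-based check with the missing update restored."""
    prev_t_i = -math.inf
    for t_i in t:
        # On CUDA, each `if` on a tensor forces a device-to-host sync.
        if t_i <= prev_t_i:
            raise ValueError("t must be strictly increasing.")
        prev_t_i = t_i  # the update the report says is missing


def check_strictly_increasing(t: torch.Tensor) -> None:
    """Hypothetical vectorized check over all neighbouring pairs at once."""
    if t.ndim != 1:
        raise ValueError(f"t must be one dimensional, got shape {tuple(t.shape)}.")
    # t[1:] > t[:-1] compares each time with its predecessor in one kernel;
    # bool(...) is the only device-to-host synchronization.
    if not bool((t[1:] > t[:-1]).all()):
        raise ValueError("t must be strictly increasing.")


check_strictly_increasing(torch.tensor([0.0, 1.0, 2.5]))  # passes silently
```

Without the `prev_t_i = t_i` line, `prev_t_i` stays at `-inf`, so for finite times the comparison is always false and an input such as `[0.0, 2.0, 1.0]` is accepted.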

@patrick-kidger
Owner

Good catch. I'd be happy to accept a PR fixing this.

Yep, none of this code is very efficient on GPU. My own use case for this was interpolating data as a preprocessing step prior to training, so it didn't matter.

If you're agnostic to the choice of autodiff framework, then you may like to have a look at Diffrax, which includes code for backward Hermite cubic splines (rather than the natural cubic splines in this package). Unlike the code here, that implementation should be relatively GPU-efficient.

@Trotew
Author

Trotew commented May 17, 2022

Thanks for the recommendation. I have actually built an interpolation module into a PyTorch model that needs to run online, so I am looking for a more efficient open-source PyTorch implementation of cubic spline interpolation.

I am also considering whether to switch to a more efficient interpolation method or to write a more efficient implementation myself.

@patrick-kidger
Owner

patrick-kidger commented May 17, 2022

Right. If you need it to run online, then I'm afraid the implementation in this repository won't be suitable, as natural cubic splines are not causal: the "future" affects the "past".
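To see the non-causality concretely, here is a small sketch built from the textbook natural cubic spline equations for 1-D data. The function names are hypothetical and this is not this package's implementation. The interior second derivatives come from one linear system that couples all knots, so changing only the last observation also changes the interpolated value at `x = 0.5`.

```python
import torch


def natural_cubic_second_derivs(t: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Second derivatives M of the natural cubic spline through (t, y), with M[0] = M[-1] = 0."""
    n = t.shape[0] - 1  # number of intervals
    h = t[1:] - t[:-1]
    slopes = (y[1:] - y[:-1]) / h
    M = torch.zeros_like(y)
    if n < 2:
        return M  # a single interval: the natural spline is a straight line
    # Tridiagonal system for the interior M (stored densely for brevity).
    A = torch.zeros(n - 1, n - 1, dtype=t.dtype, device=t.device)
    rhs = torch.empty(n - 1, dtype=t.dtype, device=t.device)
    for i in range(1, n):
        A[i - 1, i - 1] = 2 * (h[i - 1] + h[i])
        if i > 1:
            A[i - 1, i - 2] = h[i - 1]
        if i < n - 1:
            A[i - 1, i] = h[i]
        rhs[i - 1] = 6 * (slopes[i] - slopes[i - 1])
    # In general each interior M depends on every rhs entry, i.e. on all of y.
    M[1:-1] = torch.linalg.solve(A, rhs)
    return M


def natural_cubic_eval(t, y, M, x):
    """Evaluate the spline at the points x (1-D tensor)."""
    # Interval index i with t[i] <= x < t[i+1], clamped to the valid range.
    i = torch.clamp(torch.searchsorted(t, x, right=True) - 1, 0, t.shape[0] - 2)
    h = t[i + 1] - t[i]
    a, b = t[i + 1] - x, x - t[i]
    return ((M[i] * a**3 + M[i + 1] * b**3) / (6 * h)
            + (y[i] / h - M[i] * h / 6) * a
            + (y[i + 1] / h - M[i + 1] * h / 6) * b)


t = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0], dtype=torch.float64)
y = torch.tensor([0.0, 1.0, 0.0, 1.0, 0.0], dtype=torch.float64)
y_future = y.clone()
y_future[-1] = 5.0  # only the final observation changes
x = torch.tensor([0.5], dtype=torch.float64)
print(natural_cubic_eval(t, y, natural_cubic_second_derivs(t, y), x))
print(natural_cubic_eval(t, y_future, natural_cubic_second_derivs(t, y_future), x))
```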

The backward Hermite cubic splines I mentioned above are probably the appropriate algorithmic tool here, but of course if you're constrained to PyTorch then you'll have to reimplement them yourself. I think torchcde does have an implementation you can use as a starting point, but IIRC it is mistakenly noncausal in the presence of missing data (represented as NaNs). If you want a paper reference for backward Hermite cubic splines, see here.
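A minimal PyTorch sketch of a backward Hermite cubic spline, assuming the formulation in which the slope at each knot is the backward difference ending at that knot (check the paper reference above for the exact definition used there). The function name is hypothetical; this is neither torchcde's nor Diffrax's code, and it does not handle missing data (NaNs). All operations are vectorized, so there is no per-element Python loop or host synchronization, and the piece on each interval only uses the observations up to its right endpoint.

```python
import torch


def backward_hermite_eval(t: torch.Tensor, y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Evaluate a Hermite cubic spline with backward-difference slopes at the points x.

    On [t[i], t[i+1]] the cubic matches y[i] and y[i+1]; its slope at t[i] is
    (y[i] - y[i-1]) / (t[i] - t[i-1]) and its slope at t[i+1] is
    (y[i+1] - y[i]) / (t[i+1] - t[i]). The first interval reuses its own slope
    at the left end. So each piece only uses y[i-1], y[i] and y[i+1].
    """
    slopes = (y[1:] - y[:-1]) / (t[1:] - t[:-1])    # one slope per interval
    m_left = torch.cat([slopes[:1], slopes[:-1]])  # slope at each interval's left end
    m_right = slopes                               # slope at each interval's right end
    i = torch.clamp(torch.searchsorted(t, x, right=True) - 1, 0, t.shape[0] - 2)
    h = t[i + 1] - t[i]
    s = (x - t[i]) / h                             # local coordinate in [0, 1]
    # Standard cubic Hermite basis functions.
    h00 = 2 * s**3 - 3 * s**2 + 1
    h10 = s**3 - 2 * s**2 + s
    h01 = -2 * s**3 + 3 * s**2
    h11 = s**3 - s**2
    return h00 * y[i] + h10 * h * m_left[i] + h01 * y[i + 1] + h11 * h * m_right[i]


t = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0], dtype=torch.float64)
y = torch.tensor([0.0, 1.0, 0.0, 1.0, 0.0], dtype=torch.float64)
y_future = y.clone()
y_future[-1] = 5.0  # a new final observation arrives
x = torch.linspace(0.0, 2.9, 5, dtype=torch.float64)  # points before t[-2]
print(torch.allclose(backward_hermite_eval(t, y, x), backward_hermite_eval(t, y_future, x)))
```

Because adjacent pieces share the same slope at each knot, the result is continuously differentiable, and appending a new observation only changes the final piece.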

@Trotew
Author

Trotew commented May 17, 2022

Really appreciate your help.

I think I need to spend some time researching this problem.


2 participants