inefficient _validate_input and mistake #12
Good catch. I'd be happy to accept a PR fixing this. Yep, none of this code is very efficient on GPU. My own use case for this was interpolating data as a preprocessing step prior to training, so it didn't matter. If you're agnostic to the choice of autodiff framework, then you may like to have a look at Diffrax, which includes code for backward Hermite cubic splines (rather than the natural cubic splines in this package). Unlike the code here, that should be relatively GPU-efficient.
Thanks for the recommendation. I actually designed an interpolation module in a PyTorch model that needs to run online, so I am looking for a more efficient open-source PyTorch implementation of cubic spline interpolation. I am also considering whether to switch to a more efficient interpolation method or to write a more efficient implementation myself.
Right. If you need it to run online, then the implementation in this repository won't be suitable, I'm afraid, as natural cubic splines don't satisfy that property: the "future" affects the "past". The backward Hermite cubic splines I mentioned above are probably the appropriate algorithmic tool here, but of course if you're constrained to PyTorch then you'll have to reimplement them yourself. I think torchcde does have an implementation you can use as a starting point, but IIRC it is mistakenly noncausal in the presence of missing data (represented as NaNs). More broadly, if you want a paper reference for backward Hermite cubic splines, see here.
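For reference, here is a minimal PyTorch sketch of Hermite cubic interpolation with backward differences, based on the standard construction rather than copied from Diffrax or torchcde. On each interval [t_i, t_{i+1}] the cubic matches x_i and x_{i+1}, and its derivative at every knot is the backward difference (x_i - x_{i-1}) / (t_i - t_{i-1}). Assumptions, all mine: the function names `backward_hermite_coeffs` and `evaluate` are hypothetical, `t` is strictly increasing, there is no missing data (no NaN handling at all), and the first piece reuses its own slope as its starting derivative, which makes it linear.

```python
import torch


def backward_hermite_coeffs(t: torch.Tensor, x: torch.Tensor):
    """Hypothetical helper: per-interval cubic coefficients.

    t: strictly increasing times, shape (length,)
    x: observations without NaNs, shape (length, channels)
    Returns (a, b, c, d), each of shape (length - 1, channels); on [t[i], t[i+1]]
    the interpolant is a + b*s + c*s**2 + d*s**3 with s = u - t[i].
    """
    dt = (t[1:] - t[:-1]).unsqueeze(-1)     # interval widths, (length-1, 1)
    slopes = (x[1:] - x[:-1]) / dt          # backward difference at t[i+1]
    # Start derivative of piece i is the backward difference at t[i].
    # Assumption: piece 0 has no predecessor, so it reuses its own slope.
    deriv0 = torch.cat([slopes[:1], slopes[:-1]], dim=0)
    deriv1 = slopes                         # end derivative of piece i
    d_deriv = deriv1 - deriv0
    a = x[:-1]
    b = deriv0
    c = 2 * d_deriv / dt
    d = -d_deriv / dt**2
    return a, b, c, d


def evaluate(t: torch.Tensor, coeffs, u: torch.Tensor) -> torch.Tensor:
    """Evaluate the spline at times u, shape (N,); returns (N, channels)."""
    a, b, c, d = coeffs
    # Piece index = number of interior knots <= u, so u == t[i] starts piece i.
    idx = torch.bucketize(u, t[1:-1], right=True)
    s = (u - t[idx]).unsqueeze(-1)
    # Horner form of a + b*s + c*s**2 + d*s**3.
    return a[idx] + s * (b[idx] + s * (c[idx] + s * d[idx]))
```

The piece on [t_i, t_{i+1}] depends only on x_{i-1}, x_i and x_{i+1}, so appending a new observation only adds a final piece and never changes earlier ones. In a natural cubic spline, by contrast, the coefficients come from a global linear solve over the whole series. Everything here is elementwise or a single `bucketize`, so it should also avoid per-element host synchronization on GPU.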
Really appreciate your help. I think I need to spend some time researching this problem.
The function _validate_input seems to be incorrect. The code tries to reject non-monotonic times, but since prev_t_i is never updated inside the loop, every element is compared against the same initial value, so the check doesn't work. And in a CUDA environment, this part of the code is very inefficient: looping over the tensor element by element means every comparison forces a device-to-host synchronization.