FEATURE: mps support #834
Would you mind sharing which lines you modified? The code is already intended to use whatever device you specify.
Thanks. Would you be able to submit a pull request with the changes when you're finished?
Yes!
After digging more, there's actually quite a lot that needs revision. The main issue is that PyTorch's support for MPS is more limited than its support for CUDA, and several of the torch functions and methods the code calls (COO sparse tensors, for example, but others too) are not supported when the tensors are on MPS. For now I'm putting this on the shelf, but if I come back to it and finish, I'll submit a pull request.
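One common workaround for an op that a backend lacks is to run just that step on the CPU and move the result back to the original device. A minimal sketch, assuming the unsupported step is a sparse COO matrix multiply; the helper `sparse_coo_matmul` is hypothetical and not part of Kilosort4, and whether MPS actually lacks a given sparse kernel depends on the PyTorch version installed. (PyTorch also reads the `PYTORCH_ENABLE_MPS_FALLBACK=1` environment variable, which makes many unsupported MPS operators fall back to the CPU automatically; the explicit version below keeps the fallback visible in the code.)

```python
import torch


def sparse_coo_matmul(indices, values, size, dense):
    """Compute (sparse COO matrix) @ dense, on dense.device if the backend supports it."""
    device = dense.device
    try:
        sparse = torch.sparse_coo_tensor(indices.to(device), values.to(device), size)
        return torch.sparse.mm(sparse, dense)
    except RuntimeError:
        # Missing backend kernels raise NotImplementedError, a RuntimeError subclass.
        # Redo the step on the CPU, then move the dense result back to `device`.
        sparse = torch.sparse_coo_tensor(indices.cpu(), values.cpu(), size)
        return torch.sparse.mm(sparse, dense.cpu()).to(device)
```

The CPU round trip costs two device transfers, so it is only worth it for the few ops that genuinely have no kernel on the target backend.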
Eh... okay, I spoke too soon (again). It didn't take that much more work after all. I believe everything is working on MPS now, and I've submitted a pull request.
Thanks again for working on this. I'm going to close the issue for now and move the conversation to the related pull request. For anyone looking into this feature in the future, see #839.
Feature you'd like to see:
The PyTorch code calls `torch.cuda` explicitly, so it does not use MPS even when `device = torch.device("mps")` is passed into the functions. It only took me a few minutes and a few modified lines to get Kilosort4 running on MPS, so this seems like it would be fairly easy to implement.
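A minimal sketch of the device-agnostic pattern this request describes, assuming PyTorch 2.0 or later (for `torch.mps.synchronize` and `torch.mps.empty_cache`). The helper names `pick_device`, `synchronize`, and `empty_cache` are hypothetical; the thread does not show which `torch.cuda` calls Kilosort4 actually makes, so this only illustrates dispatching on `device.type` instead of calling `torch.cuda` unconditionally.

```python
import torch


def pick_device(preferred=None):
    """Return the requested device, or the best available: CUDA, then MPS, then CPU."""
    if preferred is not None:
        return torch.device(preferred)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def synchronize(device):
    """Block until queued work on `device` finishes (no-op on CPU)."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elif device.type == "mps":
        torch.mps.synchronize()


def empty_cache(device):
    """Release cached allocator memory on `device` (no-op on CPU)."""
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "mps":
        torch.mps.empty_cache()


if __name__ == "__main__":
    device = pick_device()
    x = torch.randn(256, 64, device=device)  # allocate directly on the chosen device
    gram = x @ x.T
    synchronize(device)
    empty_cache(device)
    print(device, tuple(gram.shape))
```

Passing `device=` to tensor constructors, rather than creating tensors on the CPU and calling `.cuda()`, is what lets the same code run unchanged on any of the three backends.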