Improve memory management in clustering_qr.kmeans_plusplus #775

Open
wants to merge 7 commits into base: main

Conversation

RobertoDF
Contributor

This modification avoids creating unnecessary tensors in clustering_qr.kmeans_plusplus, or deletes them immediately. It helps with the OOM errors (#746) that occur at

vexp = 2 * Xg @ Xc.T - (Xc**2).sum(1)

Xg can sometimes be quite big (5 GB in the case where I get OOM); in both of these lines a copy of Xg was created unnecessarily on the GPU.

#Xg = torch.from_numpy(Xd).to(dev)
vtot = (Xg**2).sum(1)

&
vexp = 2 * Xg @ Xc.T - (Xc**2).sum(1)

The change at line 202 does not impact speed. The change at line 167 might impact speed, but not in any noticeable way in my tests; for that reason I didn't extend the reach of the clear_cache arg to the kmeans_plusplus function.

Tested on pytorch 2.1.2 and 2.4.1.
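For illustration, a minimal sketch of the kind of rewrite involved (not necessarily the exact diff in this PR; the function name lean_vtot_vexp is hypothetical, and shapes follow the lines quoted above). It computes the same vtot and vexp while avoiding extra full-size temporaries of Xg on the GPU:

import torch

def lean_vtot_vexp(Xg, Xc):
    # Same result as vtot = (Xg**2).sum(1), but einsum reduces over the
    # feature axis without materializing Xg**2, which would be another
    # tensor as large as Xg.
    vtot = torch.einsum('ij,ij->i', Xg, Xg)
    # Same result as vexp = 2 * Xg @ Xc.T - (Xc**2).sum(1), built from the
    # matmul output with in-place ops so no extra temporaries of shape
    # (n_spikes, n_centroids) are allocated.
    vexp = Xg @ Xc.T
    vexp.mul_(2)
    vexp.sub_((Xc**2).sum(1))
    return vtot, vexp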

@jacobpennington
Collaborator

@RobertoDF Are you able to share the data that you're seeing this problem with so that I can test this myself?

@RobertoDF
Contributor Author

Sure, compressing the files now.

@RobertoDF
Contributor Author

RobertoDF commented Sep 5, 2024

In the zip there is a Jupyter notebook that shows the problem, along with the specific Xd tensor that causes the crash on my machine. I included both the standard and modified versions of kmeans_plusplus. The old one should crash; if you run the new one afterwards, it should run without errors.
https://we.tl/t-40kiuNy3Cd

@RobertoDF
Contributor Author

Just noticed that in the notebook I didn't include the change at the line vtot = (Xg**2).sum(1).

@jacobpennington
Collaborator

@RobertoDF Those are not the files I would need. I mean the full recording, either a .bin file or whatever format you converted from, along with the probe file you used.

@RobertoDF
Contributor Author

This last commit seems to really solve the OOM problems.

@Peyton-D

Hello, I tried to use your last commit, but I'm still getting a CUDA OOM error in the final clustering phase. How much dedicated GPU memory do you have? I have 8 GB, and Kilosort used on average 6-7 GB throughout sorting until crashing at the end.

@RobertoDF
Contributor Author

RobertoDF commented Sep 27, 2024

I have 12 GB. Without the modification I would often get OOM inside the kmeans_plusplus func. Which line exactly is problematic for you? What does the error message say? And what is your recording duration?

@Peyton-D

Thanks for the quick response. Yes, kmeans_plusplus inside of clustering_qr seems to be the cause of the crash every time. My recording duration is 90 min. Here's the problematic line, and the Kilosort log if it helps:

File "C:\miniconda3\envs\kilosort\lib\site-packages\kilosort\clustering_qr.py", line 215, in kmeans_plusplus
mu[j] = Xg[ix].mean(0)

kilosort4_9_26_1700.log

@RobertoDF
Contributor Author

Hmm, I never had a crash at that line. If you use the normal version, not my fork, does it also crash at the same line?

@Peyton-D

Just ran another attempt with the normal version. Here's the problematic line:

File "C:\Users\ColginLab\miniconda3\envs\kilosort\lib\site-packages\kilosort\clustering_qr.py", line 167, in kmeans_plusplus
vtot = (Xg**2).sum(1)

kilosort4_normal_version.log

@RobertoDF
Contributor Author

OK, that line was also problematic for me, and I would indeed expect my change to fix it. But I never had a problem at the line you showed me before. Maybe it can be optimized further, but I won't have time to look into this in the near future. If you have access to a 12 GB GPU, I would expect that to solve the problem.
If you are on Windows, you can try using a debugger to stop at that line and inspect GPU memory via Task Manager.
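As an alternative to Task Manager, a small PyTorch helper can report GPU memory directly from the debugger (a generic utility sketch, not part of this PR; log_gpu_memory is a hypothetical name):

import torch

def log_gpu_memory(tag=""):
    # Bytes currently allocated by tensors vs. reserved by PyTorch's caching
    # allocator on the active CUDA device, reported in GB.
    alloc = torch.cuda.memory_allocated() / 1e9
    reserved = torch.cuda.memory_reserved() / 1e9
    print(f"[{tag}] allocated {alloc:.2f} GB, reserved {reserved:.2f} GB")

Calling it just before the failing line (e.g. log_gpu_memory("before vtot")) shows how close the allocator already is to the card's limit.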

@Peyton-D

Alright, I'll look into getting more GPU memory. Thanks for the help!

@jacobpennington
Collaborator

@RobertoDF Are you able to provide a bit more explanation for the changes you proposed? I can see from other issues that they're helping with some memory problems, but I'm having a hard time finding any information in the PyTorch docs that would explain why these changes prevent copies / otherwise reduce memory usage.

@RobertoDF
Contributor Author

Sure! I went through the code with a debugger breakpoint while checking GPU memory consumption, and substituted lines (while checking that the output stayed identical) until I found a combination that avoided the unnecessary creation of large arrays on the GPU without sacrificing any speed (at least in my tests). Loads of trial and error!
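A sketch of that workflow as a reusable check, assuming old_fn and new_fn are hypothetical stand-ins for the original and modified implementations of a step, and that each returns a tuple of tensors:

import torch

def compare_versions(old_fn, new_fn, *inputs, atol=1e-5):
    # Run both versions, record peak GPU memory for each, and confirm
    # that their outputs are numerically identical.
    torch.cuda.reset_peak_memory_stats()
    out_old = old_fn(*inputs)
    peak_old = torch.cuda.max_memory_allocated()

    torch.cuda.reset_peak_memory_stats()
    out_new = new_fn(*inputs)
    peak_new = torch.cuda.max_memory_allocated()

    for a, b in zip(out_old, out_new):
        assert torch.allclose(a, b, atol=atol), "outputs differ"
    print(f"peak GPU memory: old {peak_old/1e9:.2f} GB, new {peak_new/1e9:.2f} GB")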
