I am hoping to do sparse matrix operations on some data because I am finding that dense matrix operations are taking too long. In particular, I want to use the `spsolve()` function, but I found that only the CSR data structure can be used with it. Furthermore, I'm having trouble even getting …

I made a new conda environment to test this out, using the most recent versions of jaxlib and jax, and sure enough I can now call …

Any ideas on what I can try here? It seems like a very tricky spot to be in, where my CUDA drivers are too old to be used with …
I also just saw in a recent post that while sparse matrices may reduce memory usage, they might not significantly improve runtime. If so, are there any optimizations I can make to speed up my code? The objective function I'm using in a training loop (built with the optax library) is given below. These matrices can get pretty large, and I think the costliest line is the …
I'm not sure if people here have much experience with optax, but I thought I'd attach the training loop as well in case anyone has ideas on how to speed it up. Switching from AdamW to other optimizers does speed things up, but AdamW seems to be the most stable and reliable, so I've stuck with it.
That's correct. This function lowers to a GPU primitive that only supports CSR input.
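To make "CSR input" concrete: `jax.experimental.sparse.linalg.spsolve` takes the matrix as three separate arrays, `data`, `indices`, and `indptr`, rather than as a sparse-array object. The sketch below is plain Python (no JAX), assuming a small dense matrix stored as nested lists; `dense_to_csr` and `csr_matvec` are hypothetical helpers that only illustrate what those three arrays hold.

```python
# Minimal sketch of the CSR (compressed sparse row) layout in plain Python.
# Assumption: this mirrors the (data, indices, indptr) triple that
# jax.experimental.sparse.linalg.spsolve expects; it is not JAX code.

def dense_to_csr(dense):
    """Convert a dense matrix (list of rows) to CSR arrays."""
    data, indices, indptr = [], [], [0]
    for row in dense:
        for col, value in enumerate(row):
            if value != 0:
                data.append(value)   # the nonzero value
                indices.append(col)  # its column index
        # Row i occupies data[indptr[i]:indptr[i + 1]].
        indptr.append(len(data))
    return data, indices, indptr


def csr_matvec(data, indices, indptr, x):
    """Compute y = A @ x using only the CSR arrays."""
    y = []
    for i in range(len(indptr) - 1):
        start, end = indptr[i], indptr[i + 1]
        y.append(sum(data[k] * x[indices[k]] for k in range(start, end)))
    return y


A = [
    [4.0, 0.0, 1.0],
    [0.0, 3.0, 0.0],
    [2.0, 0.0, 5.0],
]
data, indices, indptr = dense_to_csr(A)
print(data)     # [4.0, 1.0, 3.0, 2.0, 5.0]
print(indices)  # [0, 2, 1, 0, 2]
print(indptr)   # [0, 2, 3, 5]
print(csr_matvec(data, indices, indptr, [1.0, 1.0, 1.0]))  # [5.0, 3.0, 7.0]
```

Because `indptr` stores row offsets, a row-oriented kernel such as a solver can find each row's nonzeros without scanning the whole matrix, which is one reason GPU libraries like cuSPARSE/cuSOLVER favor this layout.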
That's correct. Most of our existing JAX sparse functionality is built for BCOO. But BCSR is getting more complete, and that should work with `spsolve()`, so that's worth a try.
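BCOO and BCSR differ mainly in how row positions are stored: COO keeps an explicit row index for every nonzero, while CSR compresses those row indices into offsets. In JAX, `jax.experimental.sparse.BCSR.from_bcoo` performs this conversion, and a `BCSR` matrix exposes its `.data`, `.indices`, and `.indptr` arrays. The plain-Python sketch below (no JAX; `coo_to_csr` is a hypothetical helper, and it assumes there are no duplicate `(row, col)` entries) shows what that conversion amounts to.

```python
# Minimal sketch, in plain Python, of converting COO triplets to CSR arrays.
# Assumption: this only illustrates the idea behind turning a COO-style matrix
# (like JAX's BCOO) into CSR-style arrays (like BCSR); it is not JAX's code.
# Assumption: no duplicate (row, col) pairs, so nothing needs to be summed.

def coo_to_csr(rows, cols, vals, n_rows):
    """Convert COO (row, col, value) triplets to (data, indices, indptr)."""
    # Sort entries row-major so each row's nonzeros become contiguous.
    order = sorted(range(len(vals)), key=lambda k: (rows[k], cols[k]))
    data = [vals[k] for k in order]
    indices = [cols[k] for k in order]
    # Count nonzeros per row, then take a running sum to get row offsets.
    counts = [0] * n_rows
    for r in rows:
        counts[r] += 1
    indptr = [0]
    for c in counts:
        indptr.append(indptr[-1] + c)
    return data, indices, indptr


# Nonzeros of [[4, 0, 1], [0, 3, 0], [2, 0, 5]], listed in arbitrary order.
rows = [2, 0, 1, 0, 2]
cols = [2, 2, 1, 0, 0]
vals = [5.0, 1.0, 3.0, 4.0, 2.0]
print(coo_to_csr(rows, cols, vals, n_rows=3))
# ([4.0, 1.0, 3.0, 2.0, 5.0], [0, 2, 1, 0, 2], [0, 2, 3, 5])
```

The explicit per-entry row list (`rows`) disappears in the output; only the `indptr` offsets remain, which is the form a CSR-only primitive can consume.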
Unlikely, because the GPU lowering for spsolve requires CSR.
Yeah, CUDA versioning is a real pain. Cusparse/cusolver in particular seems to be pretty unstable between releases, which makes it difficult for a library like JAX to target effectively because things break with little warning, and standing up a comprehensive CI test matrix across CUDA/cusparse/cusolver/cudnn/hardware versions is too expensive to justify.
No, no real suggestions. This is why JAX sparse is still in `jax.experimental`.

Hope that helps!