I cannot find a function in JAX that computes the leading singular vector of a matrix without also computing the full singular value decomposition. Essentially, I am looking for a JAX analog of the SciPy function `scipy.sparse.linalg.svds` with `k=1`. Does such a function already exist in JAX?
Replies: 2 comments
Just follow the logic of `scipy.sparse.linalg.svds` with `solver='lobpcg'`.
I think under the hood `svds` is just calling `eigsh` on `A.T @ A`; you could probably try something similar with `jax.experimental.sparse.linalg.lobpcg_standard`.
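A minimal sketch of that idea, in case it helps: the top eigenpair of `A.T @ A` gives the squared leading singular value and the right singular vector, and the left singular vector follows from `A @ v / sigma`. The helper name `leading_singular_triplet` and its `m` default are made up for illustration, and the code assumes `lobpcg_standard(A, X, m)` returns `(eigenvalues, eigenvectors, iterations)` as described in the JAX docs. It forms the Gram matrix explicitly, which is only reasonable when `A.shape[1]` is moderate.

```python
import jax
import jax.numpy as jnp
from jax.experimental.sparse.linalg import lobpcg_standard


def leading_singular_triplet(A, key, m=100):
    """Approximate the largest singular value of A and its singular vectors.

    Hypothetical helper, not part of JAX. Assumes A has at least a handful
    of columns, since lobpcg_standard needs the matrix dimension to be
    several times larger than the number of requested eigenpairs.
    """
    n = A.shape[1]

    # Symmetric positive semi-definite Gram matrix; its top eigenvalue is sigma**2.
    gram = A.T @ A

    # Random (n, 1) starting block: one column because we want a single eigenpair.
    X0 = jax.random.normal(key, (n, 1), dtype=A.dtype)

    # theta has shape (1,), V has shape (n, 1); the third value is the iteration count.
    theta, V, _ = lobpcg_standard(gram, X0, m=m)

    v = V[:, 0]                 # right singular vector
    sigma = jnp.sqrt(theta[0])  # leading singular value
    u = (A @ v) / sigma         # left singular vector
    return u, sigma, v


# Example usage on a random dense matrix.
A_demo = jax.random.normal(jax.random.PRNGKey(0), (100, 30))
u_demo, sigma_demo, v_demo = leading_singular_triplet(A_demo, jax.random.PRNGKey(1))
```

Squaring the matrix also squares its condition number, so the result may be less accurate than a full SVD when the leading singular values are close together. As with any singular vector, the sign of `u` and `v` is arbitrary.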