Hi all! I am posting this in the discussion section because I'm not sure whether this is a bug or whether I'm making a simple mistake (also, I think I found a workaround), but I'd like to understand what is going on if possible. In short, I'm optimizing a function and jitting the function that computes the gradient. As optimization proceeds, memory usage grows steadily and substantially. If I don't use jit at all, memory stays flat. The code I'm actually running is quite large, but here's a "small" example:
```python
import jax
import jax.experimental.optimizers as opt
import jax.numpy as np
from tqdm import tqdm


def f(A, eps):
    A = np.tril(A)
    x = np.dot(A, eps)
    out = np.dot(x, x) - np.log(np.dot(np.diag(A), np.diag(A)))
    return out


# Vectorize f over a batch of eps samples, sharing the same A.
f_vec = jax.vmap(f, in_axes=(None, 0))


def f_avg(A, eps):
    out = f_vec(A, eps)
    # Return the mean loss twice: the first output is differentiated,
    # the second is passed through as the aux value.
    return out.mean(), out.mean()


gf = jax.jit(jax.grad(f_avg, has_aux=True))

rng_key = jax.random.PRNGKey(1)
dim = 2000
A = np.tril(np.ones((dim, dim)))
init_fun, update_fun, get_params = opt.adam(0.01)
opt_state = init_fun(A)
update_fun = jax.jit(update_fun)

losses = []
for i in tqdm(range(100000)):
    rng_key, _ = jax.random.split(rng_key)
    A = get_params(opt_state)
    eps = jax.random.normal(rng_key, shape=(50, dim))
    g, loss = gf(A, eps)  # g is the gradient, loss is the aux output
    losses.append(loss)
    opt_state = update_fun(i, g, opt_state)
```
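In case it helps anyone reproduce the growth, here is a minimal way to log the process's resident memory from inside the loop. This is not part of my original script and assumes `psutil` is available:

```python
import os
import psutil

_process = psutil.Process(os.getpid())

def log_memory(step):
    # Print the resident set size (host memory) of this process in MiB.
    rss_mib = _process.memory_info().rss / 2**20
    print(f"step {step}: RSS = {rss_mib:.1f} MiB")
```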
The possible solutions I found:
1. As I said before, if I don't apply jit when defining gf, memory stays flat. However, this is not ideal, since everything becomes much slower.
2. If I keep jit but replace the line `losses.append(loss)` with `losses.append(loss.item())`, things also appear to be fine (at least in the examples I tried). This makes me think that I may be missing something simple... (see the sketch after this list).
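For concreteness, here is the loop with workaround 2 applied. This is just a sketch of the change; my understanding (which may be wrong) is that `.item()` blocks until the computation finishes and stores a plain Python float, so the list no longer holds references to device arrays:

```python
losses = []
for i in tqdm(range(100000)):
    rng_key, _ = jax.random.split(rng_key)
    A = get_params(opt_state)
    eps = jax.random.normal(rng_key, shape=(50, dim))
    g, loss = gf(A, eps)
    # Convert the 0-d device array to a Python float before storing it;
    # this also forces the (possibly asynchronous) computation to complete.
    losses.append(loss.item())
    opt_state = update_fun(i, g, opt_state)
```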
I am running this on a cluster. To my surprise, I am not able to reproduce this strange memory behavior when running locally on my own computer: there, memory always stays low, regardless of whether I use jit or add `.item()`.
I am using JAX version 0.2.9 and jaxlib version 0.1.61 (both on the cluster and locally), with slightly different Python versions (3.9.1 on the cluster, 3.9.2 locally) and different operating systems (CentOS Linux 7 on the cluster, macOS 10.14.5 locally).
Any suggestions about potential causes would be highly appreciated!
Thanks!