Skip to content

Commit

Permalink
update jax optimizers syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
patel-zeel committed Sep 11, 2023
1 parent 8e6a9fd commit 638ebb9
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions probml_utils/mix_bernoulli_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jax.lax import scan
from jax.scipy.special import expit, logit
from jax.nn import softmax
from jax.experimental import optimizers
from jax.example_libraries import optimizers

import distrax
from distrax._src.utils import jittable
Expand Down Expand Up @@ -123,7 +123,9 @@ def m_step_per_bernoulli(responsibility):
mu = jnp.sum(responsibility[:, None] * observations, axis=0) / norm_const
return mu, norm_const

mus, ns = vmap(m_step_per_bernoulli, in_axes=(1))(self.responsibilities(observations))
mus, ns = vmap(m_step_per_bernoulli, in_axes=(1))(
self.responsibilities(observations)
)
return ns / n_obs, mus

def fit_em(self, observations, num_of_iters=10):
Expand Down Expand Up @@ -169,7 +171,9 @@ def train_step(params, i):
ll_hist, responsibility_hist = history

ll_hist = jnp.append(ll_hist, self.expected_log_likelihood(observations))
responsibility_hist = jnp.vstack([responsibility_hist, jnp.array([self.responsibilities(observations)])])
responsibility_hist = jnp.vstack(
[responsibility_hist, jnp.array([self.responsibilities(observations)])]
)

return ll_hist, responsibility_hist

Expand Down Expand Up @@ -254,7 +258,9 @@ def update(self, i, opt_state, batch):
loss, grads = value_and_grad(self.loss_fn)(params, batch)
return opt_update(i, grads, opt_state), loss

def fit_sgd(self, observations, batch_size, rng_key=None, optimizer=None, num_epochs=1):
def fit_sgd(
self, observations, batch_size, rng_key=None, optimizer=None, num_epochs=1
):
"""
Fits the model using gradient descent algorithm with the given hyperparameters.
Expand Down Expand Up @@ -314,7 +320,11 @@ def train_step(opt_state, batch):
self.model = (softmax(mixing_coeffs), probs)
self._probs = probs

return opt_state, (losses.mean(), *params, self.responsibilities(observations))
return opt_state, (
losses.mean(),
*params,
self.responsibilities(observations),
)

epochs = split(rng_key, num_epochs)
opt_state, history = scan(epoch_step, opt_state, epochs)
Expand All @@ -339,10 +349,14 @@ def plot(self, n_row, n_col, file_name):
The path where the figure will be stored
"""
if n_row * n_col != len(self.mixing_coeffs):
raise TypeError("The number of rows and columns does not match with the number of component distribution.")
raise TypeError(
"The number of rows and columns does not match with the number of component distribution."
)
fig, axes = plt.subplots(n_row, n_col)

for (coeff, mean), ax in zip(zip(self.mixing_coeffs, self.probs), axes.flatten()):
for (coeff, mean), ax in zip(
zip(self.mixing_coeffs, self.probs), axes.flatten()
):
ax.imshow(mean.reshape(28, 28), cmap=plt.cm.gray)
ax.set_title("%1.2f" % coeff)
ax.axis("off")
Expand Down

0 comments on commit 638ebb9

Please sign in to comment.