Skip to content

Commit

Permalink
Migrate users of tfp.experimental.substrates.jax to import it as
Browse files Browse the repository at this point in the history
tensorflow_probability.substrates.jax and to use the JAX specific
BUILD target.

PiperOrigin-RevId: 618931679
Change-Id: Ic47d7ed13e46336fd4e593725be14b6ebe86b1f7
  • Loading branch information
ThomasColthurst authored and copybara-github committed Mar 25, 2024
1 parent a04759f commit 440df89
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions acme/agents/jax/ail/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
from acme.jax import networks as networks_lib
import jax
import jax.numpy as jnp
import tensorflow_probability as tfp
import tensorflow_probability.substrates.jax as tfp
import tree

tfp = tfp.experimental.substrates.jax

tfd = tfp.distributions

# The loss is a function taking the discriminator, its state, the demo
Expand Down

0 comments on commit 440df89

Please sign in to comment.