Skip to content

Commit

Permalink
Migrate to modern JAX
Browse files Browse the repository at this point in the history
  • Loading branch information
RadostW committed Oct 23, 2024
1 parent 2a36ac8 commit 2070785
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
11 changes: 6 additions & 5 deletions pychastic/sde_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import jax.numpy as jnp
import numpy as np
import tqdm
from jax.experimental.host_callback import id_tap
import jax.debug
import jax.tree_util
from pychastic.sde_problem import SDEProblem
from pychastic.vectorized_I_generation import get_wiener_integrals

Expand Down Expand Up @@ -262,7 +263,7 @@ def tap_func(*args,**kwargs):
def chunk_function(chunk_start, wieners_chunk):
# Parameters: chunk_start = (t0, x0, w0) values at beggining of chunk
# wieners_chunk = array of wiener increments
id_tap(tap_func,0)
jax.debug.callback(tap_func,0)
z = jax.lax.scan( scan_func , chunk_start , wieners_chunk )[0] #discard trajectory at chunk resolution
return z, z

Expand All @@ -272,7 +273,7 @@ def get_solution_fragment(starting_state,key):
last_state , (time_values, solution_values, wiener_values) = jax.lax.scan(
chunk_function,
starting_state,
jax.tree_map(lambda x: jnp.reshape(x,(-1,chunk_size)+x.shape[1:]), wiener_integrals)
jax.tree_util.tree_map(lambda x: jnp.reshape(x,(-1,chunk_size)+x.shape[1:]), wiener_integrals)
) #discard carry, remember trajectory

return (
Expand All @@ -291,7 +292,7 @@ def get_solution(key, x0):
jax.random.split(key, number_of_chunks // chunks_per_randomization)
)

return jax.tree_map(lambda x: x.reshape((-1,)+x.shape[2:]),chunked_solution) #combine big chunks into one trajectory
return jax.tree_util.tree_map(lambda x: x.reshape((-1,)+x.shape[2:]),chunked_solution) #combine big chunks into one trajectory

get_solution = jax.vmap(get_solution, in_axes=(0, 0))

Expand Down Expand Up @@ -354,7 +355,7 @@ def solve(self, problem, seed=0, chunk_size=1, chunks_per_randomization = None,
"""
solution = self.solve_many(problem, n_trajectories=1, seed=seed, chunk_size = chunk_size, chunks_per_randomization = chunks_per_randomization, progress_bar = progress_bar)
solution = jax.tree_map(lambda x: x[0], solution)
solution = jax.tree_util.tree_map(lambda x: x[0], solution)
return solution

if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion pychastic/vectorized_I_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def take(tensor, idx, fill=0):
# Non jit-friendly implementation
# illegal = jnp.logical_or(idx > p,idx < 1)
# return tensor[..., idx-1].at[..., illegal].set(fill)
legalized_idx = jnp.clip(idx, a_min=1, a_max=p)
legalized_idx = jnp.clip(idx, min=1, max=p)
illegal_mask = jnp.logical_or(idx > p, idx < 1)
return (
tensor[..., legalized_idx - 1] * (1 - 1 * illegal_mask)
Expand Down

0 comments on commit 2070785

Please sign in to comment.