Fix tree_map warning with latest JAX. (#105)
Using `jax.tree_util.tree_map` directly.
balancap authored Apr 4, 2024
1 parent 16cc0c1 commit d06fe30
Showing 7 changed files with 11 additions and 12 deletions.
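For context (not part of the commit itself): recent JAX releases emit a deprecation warning for the top-level `jax.tree_map` alias, and `jax.tree_util.tree_map` is the drop-in replacement this commit switches to. A minimal sketch of the difference, using an illustrative `params` pytree and a `bfloat16` cast that are not taken from the repository:

```python
import jax
import jax.numpy as jnp

# Illustrative pytree; any nested dict/list/tuple of arrays behaves the same way.
params = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}

# Deprecated top-level alias: triggers a DeprecationWarning on recent JAX versions.
# params = jax.tree_map(lambda v: v.astype(jnp.bfloat16), params)

# Canonical call used throughout this commit; behaviour is identical,
# including support for the optional `is_leaf` predicate.
params = jax.tree_util.tree_map(lambda v: v.astype(jnp.bfloat16), params)
```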
4 changes: 2 additions & 2 deletions experiments/mnist/mnist_classifier_from_scratch.py
@@ -104,7 +104,7 @@ def data_stream():
params = init_random_params(param_scale, layer_sizes)
# Transform parameters to `ScaledArray` and proper dtype.
params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
- params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
+ params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

@jit
@jsa.autoscale
@@ -118,7 +118,7 @@ def update(params, batch):
batch = next(batches)
# Scaled micro-batch + training dtype cast.
batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
- batch = jax.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)
+ batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)

with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
params = update(params, batch)
4 changes: 2 additions & 2 deletions experiments/mnist/mnist_classifier_from_scratch_fp8.py
@@ -131,7 +131,7 @@ def data_stream():
params = init_random_params(param_scale, layer_sizes)
# Transform parameters to `ScaledArray` and proper dtype.
params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))
- params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
+ params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

@jit
@jsa.autoscale
@@ -145,7 +145,7 @@ def update(params, batch):
batch = next(batches)
# Scaled micro-batch + training dtype cast.
batch = jsa.as_scaled_array(batch, scale=scale_dtype(1))
- batch = jax.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)
+ batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)

with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
params = update(params, batch)
6 changes: 3 additions & 3 deletions experiments/mnist/optax_cifar_training.py
@@ -118,15 +118,15 @@ def data_stream():

batches = data_stream()
params = init_random_params(param_scale, layer_sizes)
- params = jax.tree_map(lambda v: v.astype(training_dtype), params)
+ params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params)
# Transform parameters to `ScaledArray` and proper dtype.
optimizer = optax.adam(learning_rate=lr, eps=1e-5)
opt_state = optimizer.init(params)

if use_autoscale:
params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale))

- params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)
+ params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf)

@jit
@autoscale
@@ -143,7 +143,7 @@ def update(params, batch, opt_state):
# Scaled micro-batch + training dtype cast.
if use_autoscale:
batch = jsa.as_scaled_array(batch, scale=scale_dtype(param_scale))
- batch = jax.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)
+ batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf)

with jsa.AutoScaleConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype):
params, opt_state = update(params, batch, opt_state)
4 changes: 2 additions & 2 deletions jax_scaled_arithmetics/core/datatype.py
@@ -216,7 +216,7 @@ def as_scaled_array(val: Any, scale: Optional[ArrayLike] = None) -> ScaledArray:
Returns:
Scaled array instance.
"""
- return jax.tree_map(lambda x: as_scaled_array_base(x, scale), val, is_leaf=is_scaled_leaf)
+ return jax.tree_util.tree_map(lambda x: as_scaled_array_base(x, scale), val, is_leaf=is_scaled_leaf)


def asarray_base(val: Any, dtype: DTypeLike = None) -> GenericArray:
@@ -239,7 +239,7 @@ def asarray(val: Any, dtype: DTypeLike = None) -> GenericArray:
Args:
dtype: Optional dtype of the final array.
"""
- return jax.tree_map(lambda x: asarray_base(x, dtype), val, is_leaf=is_scaled_leaf)
+ return jax.tree_util.tree_map(lambda x: asarray_base(x, dtype), val, is_leaf=is_scaled_leaf)


def is_numpy_scalar_or_array(val):
2 changes: 1 addition & 1 deletion jax_scaled_arithmetics/core/interpreters.py
@@ -313,7 +313,7 @@ def wrapped(*args, **kwargs):
if len(kwargs) > 0:
raise NotImplementedError("`autoscale` JAX interpreter not supporting named tensors at present.")

- aval_args = jax.tree_map(_get_aval, args, is_leaf=is_scaled_leaf)
+ aval_args = jax.tree_util.tree_map(_get_aval, args, is_leaf=is_scaled_leaf)
# Get jaxpr of unscaled/normal graph. Getting output Pytree shape as well.
closed_jaxpr, outshape = jax.make_jaxpr(fun, return_shape=True)(*aval_args, **kwargs)
out_leaves, out_pytree = jax.tree_util.tree_flatten(outshape)
2 changes: 1 addition & 1 deletion tests/core/test_interpreter.py
@@ -242,7 +242,7 @@ def test__autoscale_decorator__proper_graph_transformation_and_result(self, fn,
scaled_fn = self.variant(autoscale(fn))
scaled_output = scaled_fn(*inputs)
# Normal JAX path, without scaled arrays.
- raw_inputs = jax.tree_map(np.asarray, inputs, is_leaf=is_scaled_leaf)
+ raw_inputs = jax.tree_util.tree_map(np.asarray, inputs, is_leaf=is_scaled_leaf)
expected_output = self.variant(fn)(*raw_inputs)

# Do we re-construct properly the output type (i.e. handling Pytree properly)?
1 change: 0 additions & 1 deletion tests/core/test_pow2.py
@@ -132,7 +132,6 @@ def test__get_mantissa__proper_value__multi_dtypes(self, val_mant, dtype):
assert val_mant.dtype == val.dtype
assert val_mant.shape == ()
assert type(val_mant) in {type(val), np.ndarray}
- print(mant, val_mant, dtype)
npt.assert_equal(val_mant, mant)
# Should be consistent with `pow2_round_down`. bitwise, not approximation.
npt.assert_equal(mant * pow2_round_down(val), val)
