Add a default scale dtype to AutoScaleConfig (#79)
The default scale dtype can be set with the `AutoScaleConfig(scale_dtype=xxx)` context manager.
A default scale dtype is needed when the dtypes of `data` and `scale` are not the same: if no
such information is provided, the `autoscale` JAX interpreter has no way of deciding which dtype
is the proper one during the initial tracing.
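
For illustration, a minimal usage sketch mirroring the new tests below (assuming `AutoScaleConfig` and `get_autoscale_config` are both exported from `jax_scaled_arithmetics.core`):

```python
import numpy as np

from jax_scaled_arithmetics.core import AutoScaleConfig, get_autoscale_config

# Locally override the default scale dtype; the previous config is restored on exit.
with AutoScaleConfig(scale_dtype=np.float32):
    assert get_autoscale_config().scale_dtype == np.float32

# Outside the context, the default is back to None (no preferred scale dtype).
assert get_autoscale_config().scale_dtype is None
```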
balancap authored Jan 11, 2024
1 parent 621d85e commit 7329bc7
Showing 4 changed files with 41 additions and 8 deletions.
7 changes: 6 additions & 1 deletion jax_scaled_arithmetics/core/interpreters.py
@@ -15,7 +15,7 @@
 )
 from jax._src.util import safe_map

-from .datatype import NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf
+from .datatype import DTypeLike, NDArray, ScaledArray, as_scaled_array_base, is_scaled_leaf
 from .utils import Pow2RoundMode


@@ -25,9 +25,14 @@ class AutoScaleConfig:
     NOTE: this config can be locally changed using a Python context manager:
         `with AutoScaleConfig(...):`
+
+    Args:
+        rounding_mode: Power-of-2 rounding mode.
+        scale_dtype: Scale (default) datatype.
     """

     rounding_mode: Pow2RoundMode = Pow2RoundMode.DOWN
+    scale_dtype: DTypeLike = None

     def __enter__(self):
         global _autoscale_config_stack
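
The hunk above cuts off before `__exit__`, but together with `get_autoscale_config` it implies a plain stack of configs. A minimal self-contained sketch of that pattern (`_ConfigSketch` is a hypothetical stand-in, not the repository's class; only `__enter__` pushing onto a global `_autoscale_config_stack` is visible in the diff):

```python
from dataclasses import dataclass
from typing import Any, List

import numpy as np


@dataclass
class _ConfigSketch:
    """Stand-in for AutoScaleConfig, illustrating the context-manager stack only."""

    scale_dtype: Any = None

    def __enter__(self):
        _config_stack.append(self)  # push: this config becomes the active one
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        _config_stack.pop()  # pop: restore the previously active config


_config_stack: List[_ConfigSketch] = [_ConfigSketch()]  # bottom entry holds the defaults


def _get_config() -> _ConfigSketch:
    return _config_stack[-1]  # the innermost active `with` block wins


with _ConfigSketch(scale_dtype=np.float32):
    assert _get_config().scale_dtype == np.float32
assert _get_config().scale_dtype is None  # default restored on exit
```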
28 changes: 25 additions & 3 deletions jax_scaled_arithmetics/lax/base_scaling_primitives.py
@@ -1,4 +1,5 @@
 # Copyright (c) 2023 Graphcore Ltd. All rights reserved.
+import logging
 from typing import Optional, Sequence, Union

 import numpy as np
@@ -12,6 +13,7 @@
     ScaledArray,
     ScaledPrimitiveType,
     asarray,
+    get_autoscale_config,
     is_static_one_scalar,
     register_scaled_op,
     safe_div,
@@ -163,6 +165,11 @@ def scaled_stop_scaling(values: ScaledArray, dtype: Optional[DTypeLike] = None)
     """


+def get_scale_dtype() -> Optional[DTypeLike]:
+    """Get the scale dtype, if set in the AutoScale config."""
+    return get_autoscale_config().scale_dtype
+
+
 def get_data_scale(values: Array) -> Array:
     """`get_data_scale` primitive call method."""
     return get_data_scale_p.bind(values)
@@ -171,27 +178,42 @@ def get_data_scale(values: Array) -> Array:
 def get_data_scale_impl(values: Array) -> Array:
     if isinstance(values, ScaledArray):
         return (values.data, values.scale)
-    scale = np.ones((), dtype=values.dtype)
+    # Use array dtype for scale by default.
+    scale_dtype = get_scale_dtype() or values.dtype
+    scale = np.ones((), dtype=scale_dtype)
     return values, scale


 def get_data_scale_abstract_eval(values: core.ShapedArray) -> core.ShapedArray:
     if isinstance(values, ScaledArray):
         return (values.data, values.scale)
-    return values, core.ShapedArray((), dtype=values.dtype)
+    # Use array dtype for scale by default.
+    scale_dtype = get_scale_dtype() or values.dtype
+    return values, core.ShapedArray((), dtype=scale_dtype)


 def get_data_scale_mlir_lowering(
     ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]]
 ) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]:
     # Just forwarding `values` term, adding a constant scalar scale(1).
     assert len(args) == 1
-    scale = ir_constant(np.ones((), dtype=ctx.avals_in[0].dtype))
+    assert len(ctx.avals_in) == 1
+    assert len(ctx.avals_out) == 2
+    # Scale dtype "decided" during initial JAX tracing.
+    scale_dtype = ctx.avals_out[1].dtype
+    scale = ir_constant(np.ones((), dtype=scale_dtype))
     return (args[0], scale)


 def scaled_get_data_scale(values: ScaledArray) -> Array:
     """Scaled `get_data_scale` implementation: return scale tensor."""
+    scale_dtype = get_scale_dtype()
+    # A mismatch may create issues (i.e. non-equivalent scale dtypes after the autoscale tracer)!
+    if scale_dtype is not None and scale_dtype != values.scale.dtype:
+        logging.warning(
+            f"AutoScale config scale dtype '{scale_dtype}' does not match the ScaledArray scale dtype "
+            f"'{values.scale.dtype}'; the AutoScale graph transformation may fail because of this."
+        )
     return values.data, values.scale


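
Taken together, the change means `get_data_scale` on a plain (non-scaled) array now derives the scale dtype from the config instead of from the array. A hedged usage sketch, mirroring the updated test below:

```python
import jax.numpy as jnp
import numpy as np

from jax_scaled_arithmetics.core import AutoScaleConfig
from jax_scaled_arithmetics.lax.base_scaling_primitives import get_data_scale

arr = jnp.array([2, 3], dtype=np.float16)

# Without a configured default, the scale dtype follows the array dtype.
data, scale = get_data_scale(arr)
assert scale.dtype == np.float16

# With a configured default, the scale dtype is decoupled from the data dtype.
with AutoScaleConfig(scale_dtype=np.float32):
    data, scale = get_data_scale(arr)
assert scale.dtype == np.float32
```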
4 changes: 3 additions & 1 deletion tests/core/test_interpreter.py
@@ -237,9 +237,11 @@ def test__autoscale_config__default_values(self):
         cfg = get_autoscale_config()
         assert isinstance(cfg, AutoScaleConfig)
         assert cfg.rounding_mode == Pow2RoundMode.DOWN
+        assert cfg.scale_dtype is None

     def test__autoscale_config__context_manager(self):
-        with AutoScaleConfig(rounding_mode=Pow2RoundMode.NONE):
+        with AutoScaleConfig(rounding_mode=Pow2RoundMode.NONE, scale_dtype=np.float32):
             cfg = get_autoscale_config()
             assert isinstance(cfg, AutoScaleConfig)
             assert cfg.rounding_mode == Pow2RoundMode.NONE
+            assert cfg.scale_dtype == np.float32
10 changes: 7 additions & 3 deletions tests/lax/test_base_scaling_primitives.py
@@ -5,7 +5,7 @@
 import numpy.testing as npt
 from absl.testing import parameterized

-from jax_scaled_arithmetics.core import Array, ScaledArray, autoscale, scaled_array
+from jax_scaled_arithmetics.core import Array, AutoScaleConfig, ScaledArray, autoscale, scaled_array
 from jax_scaled_arithmetics.lax.base_scaling_primitives import (
     get_data_scale,
     rebalance,
@@ -146,13 +146,17 @@ class GetDataScalePrimitiveTests(chex.TestCase):
     @chex.variants(with_jit=True, without_jit=True)
     def test__get_data_scale_primitive__proper_result_without_autoscale(self):
         def fn(arr):
-            return get_data_scale(arr)
+            # Set a default scale dtype.
+            with AutoScaleConfig(scale_dtype=np.float32):
+                return get_data_scale(arr)

         fn = self.variant(fn)
         arr = jnp.array([2, 3], dtype=np.float16)
         data, scale = fn(arr)
+        assert data.dtype == np.float16
+        assert scale.dtype == np.float32
         npt.assert_array_equal(data, arr)
-        npt.assert_equal(scale, np.array(1, arr.dtype))
+        npt.assert_equal(scale, np.array(1, np.float32))

     @chex.variants(with_jit=True, without_jit=True)
     def test__get_data_scale_primitive__proper_result_with_autoscale(self):
