Skip to content

Commit

Permalink
Move dtype canonicalization out of core.AbstractValue subclasses.
Browse files Browse the repository at this point in the history
This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context.

The callers to which canonicalization was added were:
a) all callers of `ConcreteArray` inside the JAX Tree.
b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures.

PiperOrigin-RevId: 414704700
  • Loading branch information
hawkinsp authored and jax authors committed Dec 7, 2021
1 parent 56f029f commit 06cd1fe
Show file tree
Hide file tree
Showing 14 changed files with 56 additions and 38 deletions.
11 changes: 8 additions & 3 deletions jax/_src/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,23 @@ def zeros_like_array(x):
np.complex64, np.complex128,
np.longlong, np.intc}

def canonical_concrete_aval(val, weak_type=None):
return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val,
weak_type=weak_type)

for t in array_types:
core.pytype_aval_mappings[t] = ConcreteArray
core.pytype_aval_mappings[t] = canonical_concrete_aval
ad_util.jaxval_zeros_likers[t] = zeros_like_array

core.literalable_types.update(array_types)

def _zeros_like_python_scalar(t, x):
aval = core.ShapedArray((), dtypes.python_scalar_dtypes[t], weak_type=True)
dtype = dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[t])
aval = core.ShapedArray((), dtype, weak_type=True)
return ad_util.zeros_like_aval(aval)

def _make_concrete_python_scalar(t, x):
return ConcreteArray(
return canonical_concrete_aval(
np.array(x, dtype=dtypes._scalar_type_to_dtype(t, x)),
weak_type=True)

Expand Down
2 changes: 1 addition & 1 deletion jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
local_devices, process_index,
process_count, host_id, host_ids,
host_count, default_backend)
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.core import ShapedArray, raise_to_shaped
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import pxla
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/device_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from jax import core
from jax._src.config import config
from jax._src import abstract_arrays
from jax._src import dtypes
from jax._src import profiler
from jax._src.lib import xla_client as xc
Expand Down Expand Up @@ -306,4 +307,4 @@ class DeletedBuffer(object): pass
device_array_types: List[type] = [xc.Buffer, _DeviceArray]
for _device_array in device_array_types:
core.literalable_types.add(_device_array)
core.pytype_aval_mappings[device_array] = core.ConcreteArray
core.pytype_aval_mappings[device_array] = abstract_arrays.canonical_concrete_aval
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1480,7 +1480,7 @@ def scan(f, init, xs, length=None):
return carry, stacked_y

x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat]
x_dtypes = [x.dtype for x in xs_flat]
x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat]
x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes))

def _create_jaxpr(init):
Expand Down Expand Up @@ -2038,7 +2038,7 @@ def masked(*args):
for new_c, c in zip(new_carry, carry)]
return [i + 1] + new_carry + ys

aval = ShapedArray((), dtypes.int_)
aval = ShapedArray((), dtypes.canonicalize_dtype(dtypes.int_))
const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry])
return _make_closed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals)

Expand Down
11 changes: 6 additions & 5 deletions jax/_src/lax/slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,7 +1240,8 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
if core.symbolic_equal_dim(operand.shape[0], 0):
output_shape = _gather_shape_rule(
core.ShapedArray(operand.shape[1:], operand.dtype),
core.ShapedArray(indices.shape[1:], indices.dtype),
core.ShapedArray(indices.shape[1:],
dtypes.canonicalize_dtype(indices.dtype)),
dimension_numbers=dimension_numbers, slice_sizes=slice_sizes,
unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
mode=mode, fill_value=fill_value)
Expand Down Expand Up @@ -1456,8 +1457,8 @@ def _scatter_translation_rule(ctx, avals_in, avals_out, operand, indices,
if mode == GatherScatterMode.CLIP:
clip_fn = xla.lower_fun(_clamp_scatter_indices, multiple_results=False,
new_style=True)
indices, = clip_fn(ctx, avals_in, [indices_aval.update(dtype=np.int64)],
operand, indices, updates, dnums=dimension_numbers)
indices, = clip_fn(ctx, avals_in, None, operand, indices, updates,
dnums=dimension_numbers)

c = ctx.builder

Expand All @@ -1477,8 +1478,8 @@ def _scatter_add_translation_rule(
if mode == GatherScatterMode.CLIP:
clip_fn = xla.lower_fun(_clamp_scatter_indices, multiple_results=False,
new_style=True)
indices, = clip_fn(ctx, avals_in, [indices_aval.update(dtype=np.int64)],
operand, indices, updates, dnums=dimension_numbers)
indices, = clip_fn(ctx, avals_in, None, operand, indices, updates,
dnums=dimension_numbers)

dtype = operand_aval.dtype
scatter_dims = _scatter_dimensions_proto(
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/lax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
least_specialized = _max(map(type, avals),
key=operator.attrgetter('array_abstraction_level'))
if least_specialized is core.ConcreteArray:
return core.ConcreteArray(prim.impl(*[x.val for x in avals], **kwargs),
weak_type=weak_type)
out = prim.impl(*[x.val for x in avals], **kwargs)
return core.ConcreteArray(out.dtype, out, weak_type=weak_type)
elif least_specialized is core.ShapedArray:
return core.ShapedArray(shape_rule(*avals, **kwargs),
dtype_rule(*avals, **kwargs), weak_type=weak_type,
Expand All @@ -81,7 +81,7 @@ def standard_multi_result_abstract_eval(
weak_types = weak_type_rule(*avals, **kwargs)
if least_specialized is core.ConcreteArray:
out_vals = prim.impl(*[x.val for x in avals], **kwargs)
return [core.ConcreteArray(val, weak_type=weak_type)
return [core.ConcreteArray(val.dtype, val, weak_type=weak_type)
for val, weak_type in safe_zip(out_vals, weak_types)]
elif least_specialized is core.ShapedArray:
out_shapes = shape_rule(*avals, **kwargs)
Expand Down
24 changes: 13 additions & 11 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,7 +1041,7 @@ class UnshapedArray(AbstractValue):
array_abstraction_level = 2

def __init__(self, dtype, weak_type=False):
self.dtype = np.dtype(dtypes.canonicalize_dtype(dtype))
self.dtype = np.dtype(dtype)
self.weak_type = weak_type

def update(self, dtype=None, weak_type=None):
Expand Down Expand Up @@ -1183,19 +1183,20 @@ class ConcreteArray(ShapedArray):
__slots__ = ['val']
array_abstraction_level = 0

def __init__(self, val, weak_type=None):
super().__init__(np.shape(val), np.result_type(val),
weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type)
def __init__(self, dtype, val, weak_type=None):
super().__init__(
np.shape(val), dtype,
weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type)
# Note: canonicalized self.dtype doesn't necessarily match self.val
assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype)
self.val = val
assert self.dtype != np.dtype('O'), val

def update(self, val=None, weak_type=None):
if val is None:
val = self.val
if weak_type is None:
weak_type = self.weak_type
return ConcreteArray(val, weak_type)
def update(self, dtype=None, val=None, weak_type=None):
dtype = self.dtype if dtype is None else dtype
val = self.val if val is None else val
weak_type = self.weak_type if weak_type is None else weak_type
return ConcreteArray(dtype, val, weak_type)

def __eq__(self, other):
if (type(self) is type(other) and self.dtype == other.dtype
Expand Down Expand Up @@ -1271,7 +1272,8 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None):
Bot: lambda aval, _: aval,
UnshapedArray: lambda aval, _: aval,
ShapedArray: lambda aval, weak_type: ShapedArray(
aval.shape, aval.dtype, weak_type, aval.named_shape)
aval.shape, dtypes.canonicalize_dtype(aval.dtype), weak_type,
aval.named_shape)
}

### Operations on shapes and dimension sizes.
Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/call_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,14 +351,14 @@ def is_fully_known_shape(s):
xla_comp_parameter_shapes = xla_comp.program_shape().parameter_shapes()
found_parameter_avals = [
core.ShapedArray(found_xla_shape.dimensions(),
found_xla_shape.numpy_dtype())
dtypes.canonicalize_dtype(found_xla_shape.numpy_dtype()))
for found_xla_shape in xla_comp_parameter_shapes
]
# Add the captured_inputs to args_flat_sig_tf
expected_args_flat_sig_tf = list(args_flat_sig_tf) + list(captured_inputs)
expected_parameter_avals = [
core.ShapedArray(tuple(arg_sig.shape.as_list()),
arg_sig.dtype.as_numpy_dtype)
dtypes.canonicalize_dtype(arg_sig.dtype.as_numpy_dtype))
for arg_sig in expected_args_flat_sig_tf]
if found_parameter_avals != expected_parameter_avals:
msg = ("Compiled TensorFlow function has unexpected parameter types " +
Expand Down
3 changes: 2 additions & 1 deletion jax/interpreters/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,8 @@ def __init__(self, trace, val, polymorphic_shape):

@property
def aval(self):
return ShapedArray(self.polymorphic_shape, self.dtype)
return ShapedArray(self.polymorphic_shape,
dtypes.canonicalize_dtype(self.dtype))

@property
def dtype(self):
Expand Down
3 changes: 2 additions & 1 deletion jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from jax._src.config import config
from jax import core
from jax import linear_util as lu
from jax._src import abstract_arrays
from jax._src.abstract_arrays import array_types
from jax.core import ConcreteArray, ShapedArray
from jax._src import device_array
Expand Down Expand Up @@ -740,7 +741,7 @@ def _register_handlers_for_sharded_device_array(sda):
shard_arg_handlers[sda] = _shard_sharded_device_array_slow_path
xla.register_constant_handler(sda, _sharded_device_array_constant_handler)

core.pytype_aval_mappings[sda] = ConcreteArray
core.pytype_aval_mappings[sda] = abstract_arrays.canonical_concrete_aval
dispatch.device_put_handlers[sda] = dispatch._device_put_array
xla.pytype_aval_mappings[sda] = op.attrgetter("aval")
xla.canonicalize_dtype_handlers[sda] = identity
Expand Down
3 changes: 2 additions & 1 deletion tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2729,7 +2729,8 @@ def test_join_concrete_arrays_with_omnistaging(self):

@jit
def f():
core.lattice_join(core.ConcreteArray(x), core.ConcreteArray(y))
core.lattice_join(core.ConcreteArray(x.dtype, x),
core.ConcreteArray(y.dtype, y))

f() # doesn't crash

Expand Down
3 changes: 2 additions & 1 deletion tests/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ def f(x, y):
def test_concrete_array_string_representation(self):
# https://github.com/google/jax/issues/5364
self.assertEqual(
str(core.ConcreteArray(np.array([1], dtype=np.int32))),
str(core.ConcreteArray(np.dtype(np.int32),
np.array([1], dtype=np.int32))),
'ConcreteArray([1], dtype=int32)')


Expand Down
11 changes: 7 additions & 4 deletions tests/custom_object_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from jax import core, jit, lax, make_jaxpr
from jax._src import device_array
from jax._src import dispatch
from jax._src import dtypes
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.lib.mlir import ir
Expand Down Expand Up @@ -66,13 +67,15 @@ class AbstractSparseArray(core.ShapedArray):

def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False,
named_shape=None):
super().__init__(shape, dtype)
super().__init__(shape, dtypes.canonicalize_dtype(dtype))
named_shape = {} if named_shape is None else named_shape
self.index_dtype = index_dtype
self.nnz = nnz
self.data_aval = core.ShapedArray((nnz,), dtype, weak_type, named_shape)
self.indices_aval = core.ShapedArray((nnz, len(shape)), index_dtype,
named_shape=named_shape)
self.data_aval = core.ShapedArray((nnz,), dtypes.canonicalize_dtype(dtype),
weak_type, named_shape)
self.indices_aval = core.ShapedArray(
(nnz, len(shape)), dtypes.canonicalize_dtype(index_dtype),
named_shape=named_shape)

def update(self, shape=None, dtype=None, index_dtype=None, nnz=None,
weak_type=None, named_shape=None):
Expand Down
6 changes: 4 additions & 2 deletions tests/jax_jit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ def test_device_put_on_numpy_scalars(self, device_put_function):
output_buffer = device_put_function(value, device=device)

self.assertFalse(output_buffer.aval.weak_type)
dtype = dtypes.canonicalize_dtype(dtype)
self.assertEqual(output_buffer.aval, jax.core.ShapedArray((), dtype))
self.assertEqual(output_buffer.dtype, dtypes.canonicalize_dtype(dtype))
self.assertEqual(output_buffer.dtype, dtype)

@parameterized.parameters([jax.device_put, _cpp_device_put])
def test_device_put_on_numpy_arrays(self, device_put_function):
Expand All @@ -68,8 +69,9 @@ def test_device_put_on_numpy_arrays(self, device_put_function):
output_buffer = device_put_function(value, device=device)

self.assertFalse(output_buffer.aval.weak_type)
dtype = dtypes.canonicalize_dtype(dtype)
self.assertEqual(output_buffer.aval, jax.core.ShapedArray((3, 4), dtype))
self.assertEqual(output_buffer.dtype, dtypes.canonicalize_dtype(dtype))
self.assertEqual(output_buffer.dtype, dtype)
np.testing.assert_array_equal(output_buffer, np.zeros((3, 4),
dtype=dtype))

Expand Down

0 comments on commit 06cd1fe

Please sign in to comment.