From 06cd1fedeebe6764de8132f3d158c0e6d700cf38 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 7 Dec 2021 06:12:32 -0800 Subject: [PATCH] Move dtype canonicalization out of core.AbstractValue subclasses. This is a strictly mechanical change that moves abstract value canonicalization out of the core.AbstractValue subclasses and into their callers. This makes it safe to manipulate non-canonical abstract values even inside an -x32 context. The callers to which canonicalization was added were: a) all callers of `ConcreteArray` inside the JAX Tree. b) all callers of `ShapedArray` and `UnshapedArray` that were found to be passing non-canonical dtypes during a global presubmit. These were identified by adding an assertion that the dtype is in fact canonical and fixing all the resulting test failures. PiperOrigin-RevId: 414704700 --- jax/_src/abstract_arrays.py | 11 ++++++++--- jax/_src/api.py | 2 +- jax/_src/device_array.py | 3 ++- jax/_src/lax/control_flow.py | 4 ++-- jax/_src/lax/slicing.py | 11 ++++++----- jax/_src/lax/utils.py | 6 +++--- jax/core.py | 24 +++++++++++++----------- jax/experimental/jax2tf/call_tf.py | 4 ++-- jax/interpreters/masking.py | 3 ++- jax/interpreters/pxla.py | 3 ++- tests/api_test.py | 3 ++- tests/core_test.py | 3 ++- tests/custom_object_test.py | 11 +++++++---- tests/jax_jit_test.py | 6 ++++-- 14 files changed, 56 insertions(+), 38 deletions(-) diff --git a/jax/_src/abstract_arrays.py b/jax/_src/abstract_arrays.py index 51ee2e4272ff..3893fcc43505 100644 --- a/jax/_src/abstract_arrays.py +++ b/jax/_src/abstract_arrays.py @@ -49,18 +49,23 @@ def zeros_like_array(x): np.complex64, np.complex128, np.longlong, np.intc} +def canonical_concrete_aval(val, weak_type=None): + return ConcreteArray(dtypes.canonicalize_dtype(np.result_type(val)), val, + weak_type=weak_type) + for t in array_types: - core.pytype_aval_mappings[t] = ConcreteArray + core.pytype_aval_mappings[t] = canonical_concrete_aval ad_util.jaxval_zeros_likers[t] = zeros_like_array core.literalable_types.update(array_types) def _zeros_like_python_scalar(t, x): - aval = core.ShapedArray((), dtypes.python_scalar_dtypes[t], weak_type=True) + dtype = dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[t]) + aval = core.ShapedArray((), dtype, weak_type=True) return ad_util.zeros_like_aval(aval) def _make_concrete_python_scalar(t, x): - return ConcreteArray( + return canonical_concrete_aval( np.array(x, dtype=dtypes._scalar_type_to_dtype(t, x)), weak_type=True) diff --git a/jax/_src/api.py b/jax/_src/api.py index cf7b2b6f5956..d0cb6ba6ec5a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -69,7 +69,7 @@ local_devices, process_index, process_count, host_id, host_ids, host_count, default_backend) -from jax.core import ConcreteArray, ShapedArray, raise_to_shaped +from jax.core import ShapedArray, raise_to_shaped from jax.interpreters import partial_eval as pe from jax.interpreters import xla from jax.interpreters import pxla diff --git a/jax/_src/device_array.py b/jax/_src/device_array.py index a26b9754bc94..d2e367024196 100644 --- a/jax/_src/device_array.py +++ b/jax/_src/device_array.py @@ -23,6 +23,7 @@ from jax import core from jax._src.config import config +from jax._src import abstract_arrays from jax._src import dtypes from jax._src import profiler from jax._src.lib import xla_client as xc @@ -306,4 +307,4 @@ class DeletedBuffer(object): pass device_array_types: List[type] = [xc.Buffer, _DeviceArray] for _device_array in device_array_types: core.literalable_types.add(_device_array) - core.pytype_aval_mappings[device_array] = core.ConcreteArray + core.pytype_aval_mappings[device_array] = abstract_arrays.canonical_concrete_aval diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py index 781bcd4f2bff..2b8c7f2defba 100644 --- a/jax/_src/lax/control_flow.py +++ b/jax/_src/lax/control_flow.py @@ -1480,7 +1480,7 @@ def scan(f, init, xs, length=None): return carry, stacked_y x_shapes = [masking.padded_shape_as_value(x.shape[1:]) for x in xs_flat] - x_dtypes = [x.dtype for x in xs_flat] + x_dtypes = [dtypes.canonicalize_dtype(x.dtype) for x in xs_flat] x_avals = tuple(_map(ShapedArray, x_shapes, x_dtypes)) def _create_jaxpr(init): @@ -2038,7 +2038,7 @@ def masked(*args): for new_c, c in zip(new_carry, carry)] return [i + 1] + new_carry + ys - aval = ShapedArray((), dtypes.int_) + aval = ShapedArray((), dtypes.canonicalize_dtype(dtypes.int_)) const_avals, carry_avals, x_avals = split_list(jaxpr.in_avals, [num_consts, num_carry]) return _make_closed_jaxpr(masked, [aval] + const_avals + [aval] + carry_avals + x_avals) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 46a090634b15..f1d713c0c946 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1240,7 +1240,8 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers, if core.symbolic_equal_dim(operand.shape[0], 0): output_shape = _gather_shape_rule( core.ShapedArray(operand.shape[1:], operand.dtype), - core.ShapedArray(indices.shape[1:], indices.dtype), + core.ShapedArray(indices.shape[1:], + dtypes.canonicalize_dtype(indices.dtype)), dimension_numbers=dimension_numbers, slice_sizes=slice_sizes, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value) @@ -1456,8 +1457,8 @@ def _scatter_translation_rule(ctx, avals_in, avals_out, operand, indices, if mode == GatherScatterMode.CLIP: clip_fn = xla.lower_fun(_clamp_scatter_indices, multiple_results=False, new_style=True) - indices, = clip_fn(ctx, avals_in, [indices_aval.update(dtype=np.int64)], - operand, indices, updates, dnums=dimension_numbers) + indices, = clip_fn(ctx, avals_in, None, operand, indices, updates, + dnums=dimension_numbers) c = ctx.builder @@ -1477,8 +1478,8 @@ def _scatter_add_translation_rule( if mode == GatherScatterMode.CLIP: clip_fn = xla.lower_fun(_clamp_scatter_indices, multiple_results=False, new_style=True) - indices, = clip_fn(ctx, avals_in, [indices_aval.update(dtype=np.int64)], - operand, indices, updates, dnums=dimension_numbers) + indices, = clip_fn(ctx, avals_in, None, operand, indices, updates, + dnums=dimension_numbers) dtype = operand_aval.dtype scatter_dims = _scatter_dimensions_proto( diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 030858b4932b..b445856b61fa 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -60,8 +60,8 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, least_specialized = _max(map(type, avals), key=operator.attrgetter('array_abstraction_level')) if least_specialized is core.ConcreteArray: - return core.ConcreteArray(prim.impl(*[x.val for x in avals], **kwargs), - weak_type=weak_type) + out = prim.impl(*[x.val for x in avals], **kwargs) + return core.ConcreteArray(out.dtype, out, weak_type=weak_type) elif least_specialized is core.ShapedArray: return core.ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs), weak_type=weak_type, @@ -81,7 +81,7 @@ def standard_multi_result_abstract_eval( weak_types = weak_type_rule(*avals, **kwargs) if least_specialized is core.ConcreteArray: out_vals = prim.impl(*[x.val for x in avals], **kwargs) - return [core.ConcreteArray(val, weak_type=weak_type) + return [core.ConcreteArray(val.dtype, val, weak_type=weak_type) for val, weak_type in safe_zip(out_vals, weak_types)] elif least_specialized is core.ShapedArray: out_shapes = shape_rule(*avals, **kwargs) diff --git a/jax/core.py b/jax/core.py index 59c1cec89684..66183e07e743 100644 --- a/jax/core.py +++ b/jax/core.py @@ -1041,7 +1041,7 @@ class UnshapedArray(AbstractValue): array_abstraction_level = 2 def __init__(self, dtype, weak_type=False): - self.dtype = np.dtype(dtypes.canonicalize_dtype(dtype)) + self.dtype = np.dtype(dtype) self.weak_type = weak_type def update(self, dtype=None, weak_type=None): @@ -1183,19 +1183,20 @@ class ConcreteArray(ShapedArray): __slots__ = ['val'] array_abstraction_level = 0 - def __init__(self, val, weak_type=None): - super().__init__(np.shape(val), np.result_type(val), - weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type) + def __init__(self, dtype, val, weak_type=None): + super().__init__( + np.shape(val), dtype, + weak_type=dtypes.is_weakly_typed(val) if weak_type is None else weak_type) # Note: canonicalized self.dtype doesn't necessarily match self.val + assert self.dtype == dtypes.canonicalize_dtype(np.result_type(val)), (val, dtype) self.val = val assert self.dtype != np.dtype('O'), val - def update(self, val=None, weak_type=None): - if val is None: - val = self.val - if weak_type is None: - weak_type = self.weak_type - return ConcreteArray(val, weak_type) + def update(self, dtype=None, val=None, weak_type=None): + dtype = self.dtype if dtype is None else dtype + val = self.val if val is None else val + weak_type = self.weak_type if weak_type is None else weak_type + return ConcreteArray(dtype, val, weak_type) def __eq__(self, other): if (type(self) is type(other) and self.dtype == other.dtype @@ -1271,7 +1272,8 @@ def raise_to_shaped(aval: AbstractValue, weak_type=None): Bot: lambda aval, _: aval, UnshapedArray: lambda aval, _: aval, ShapedArray: lambda aval, weak_type: ShapedArray( - aval.shape, aval.dtype, weak_type, aval.named_shape) + aval.shape, dtypes.canonicalize_dtype(aval.dtype), weak_type, + aval.named_shape) } ### Operations on shapes and dimension sizes. diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index a0175a2bf64d..c8b08ce581e3 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -351,14 +351,14 @@ def is_fully_known_shape(s): xla_comp_parameter_shapes = xla_comp.program_shape().parameter_shapes() found_parameter_avals = [ core.ShapedArray(found_xla_shape.dimensions(), - found_xla_shape.numpy_dtype()) + dtypes.canonicalize_dtype(found_xla_shape.numpy_dtype())) for found_xla_shape in xla_comp_parameter_shapes ] # Add the captured_inputs to args_flat_sig_tf expected_args_flat_sig_tf = list(args_flat_sig_tf) + list(captured_inputs) expected_parameter_avals = [ core.ShapedArray(tuple(arg_sig.shape.as_list()), - arg_sig.dtype.as_numpy_dtype) + dtypes.canonicalize_dtype(arg_sig.dtype.as_numpy_dtype)) for arg_sig in expected_args_flat_sig_tf] if found_parameter_avals != expected_parameter_avals: msg = ("Compiled TensorFlow function has unexpected parameter types " + diff --git a/jax/interpreters/masking.py b/jax/interpreters/masking.py index edf51a82693b..6333ace7c8e0 100644 --- a/jax/interpreters/masking.py +++ b/jax/interpreters/masking.py @@ -459,7 +459,8 @@ def __init__(self, trace, val, polymorphic_shape): @property def aval(self): - return ShapedArray(self.polymorphic_shape, self.dtype) + return ShapedArray(self.polymorphic_shape, + dtypes.canonicalize_dtype(self.dtype)) @property def dtype(self): diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 03c9019e9e24..c155e4f69f23 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -45,6 +45,7 @@ from jax._src.config import config from jax import core from jax import linear_util as lu +from jax._src import abstract_arrays from jax._src.abstract_arrays import array_types from jax.core import ConcreteArray, ShapedArray from jax._src import device_array @@ -740,7 +741,7 @@ def _register_handlers_for_sharded_device_array(sda): shard_arg_handlers[sda] = _shard_sharded_device_array_slow_path xla.register_constant_handler(sda, _sharded_device_array_constant_handler) - core.pytype_aval_mappings[sda] = ConcreteArray + core.pytype_aval_mappings[sda] = abstract_arrays.canonical_concrete_aval dispatch.device_put_handlers[sda] = dispatch._device_put_array xla.pytype_aval_mappings[sda] = op.attrgetter("aval") xla.canonicalize_dtype_handlers[sda] = identity diff --git a/tests/api_test.py b/tests/api_test.py index a143ee93b5ef..bc43bbadc04b 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2729,7 +2729,8 @@ def test_join_concrete_arrays_with_omnistaging(self): @jit def f(): - core.lattice_join(core.ConcreteArray(x), core.ConcreteArray(y)) + core.lattice_join(core.ConcreteArray(x.dtype, x), + core.ConcreteArray(y.dtype, y)) f() # doesn't crash diff --git a/tests/core_test.py b/tests/core_test.py index f9a11d57862d..95dc2e6c0ed8 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -336,7 +336,8 @@ def f(x, y): def test_concrete_array_string_representation(self): # https://github.com/google/jax/issues/5364 self.assertEqual( - str(core.ConcreteArray(np.array([1], dtype=np.int32))), + str(core.ConcreteArray(np.dtype(np.int32), + np.array([1], dtype=np.int32))), 'ConcreteArray([1], dtype=int32)') diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py index 09681ec7824d..af340dc2ad4d 100644 --- a/tests/custom_object_test.py +++ b/tests/custom_object_test.py @@ -21,6 +21,7 @@ from jax import core, jit, lax, make_jaxpr from jax._src import device_array from jax._src import dispatch +from jax._src import dtypes from jax.interpreters import mlir from jax.interpreters import xla from jax._src.lib.mlir import ir @@ -66,13 +67,15 @@ class AbstractSparseArray(core.ShapedArray): def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False, named_shape=None): - super().__init__(shape, dtype) + super().__init__(shape, dtypes.canonicalize_dtype(dtype)) named_shape = {} if named_shape is None else named_shape self.index_dtype = index_dtype self.nnz = nnz - self.data_aval = core.ShapedArray((nnz,), dtype, weak_type, named_shape) - self.indices_aval = core.ShapedArray((nnz, len(shape)), index_dtype, - named_shape=named_shape) + self.data_aval = core.ShapedArray((nnz,), dtypes.canonicalize_dtype(dtype), + weak_type, named_shape) + self.indices_aval = core.ShapedArray( + (nnz, len(shape)), dtypes.canonicalize_dtype(index_dtype), + named_shape=named_shape) def update(self, shape=None, dtype=None, index_dtype=None, nnz=None, weak_type=None, named_shape=None): diff --git a/tests/jax_jit_test.py b/tests/jax_jit_test.py index 292968a275b5..57fe0fb41874 100644 --- a/tests/jax_jit_test.py +++ b/tests/jax_jit_test.py @@ -56,8 +56,9 @@ def test_device_put_on_numpy_scalars(self, device_put_function): output_buffer = device_put_function(value, device=device) self.assertFalse(output_buffer.aval.weak_type) + dtype = dtypes.canonicalize_dtype(dtype) self.assertEqual(output_buffer.aval, jax.core.ShapedArray((), dtype)) - self.assertEqual(output_buffer.dtype, dtypes.canonicalize_dtype(dtype)) + self.assertEqual(output_buffer.dtype, dtype) @parameterized.parameters([jax.device_put, _cpp_device_put]) def test_device_put_on_numpy_arrays(self, device_put_function): @@ -68,8 +69,9 @@ def test_device_put_on_numpy_arrays(self, device_put_function): output_buffer = device_put_function(value, device=device) self.assertFalse(output_buffer.aval.weak_type) + dtype = dtypes.canonicalize_dtype(dtype) self.assertEqual(output_buffer.aval, jax.core.ShapedArray((3, 4), dtype)) - self.assertEqual(output_buffer.dtype, dtypes.canonicalize_dtype(dtype)) + self.assertEqual(output_buffer.dtype, dtype) np.testing.assert_array_equal(output_buffer, np.zeros((3, 4), dtype=dtype))