From 185d7a9fd954b0617824675c5467e1b9a2c31648 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Tue, 19 Oct 2021 06:48:34 -0700
Subject: [PATCH] Delete xla_bridge.dtype_to_etype, replace it with
 jax.interpreters.xla.dtype_to_primitive_type.

The new version does *not* canonicalize dtypes. We should be canonicalizing
dtypes as part of tracing to a jaxpr, not in any way as part of XLA lowering.
In all cases, as best I can tell, the dtypes from the callers are already
canonical anyway.

jax.interpreters.xla is also a better location: I'm not even sure why we have
a bunch of random things in xla_bridge any more, so it makes sense to
consolidate them in xla.py along with the other registrations for things like
avals.

Also delete the unused function xla_bridge.supported_numpy_dtypes.

PiperOrigin-RevId: 404246574
---
 jax/_src/lax/lax.py                           | 67 +++++++++++--------
 jax/_src/lax/parallel.py                      |  3 +-
 jax/_src/lib/xla_bridge.py                    | 12 ----
 jax/experimental/djax.py                      |  2 +-
 jax/experimental/jax2tf/call_tf.py            |  2 +-
 jax/experimental/jax2tf/impl_no_xla.py        |  3 +-
 .../jax2tf/tests/primitive_harness.py         |  8 ++-
 jax/experimental/maps.py                      |  6 +-
 jax/interpreters/pxla.py                      |  6 +-
 jax/interpreters/xla.py                       | 28 ++++++++
 10 files changed, 85 insertions(+), 52 deletions(-)

diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 478ff4dd0bc1..2eee69109a7f 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -648,6 +648,8 @@ def conv_general_dilated(
     padding = padtype_to_pads(
         np.take(lhs.shape, lhs_perm)[2:], effective_rhs_shape,  # type: ignore[index]
         window_strides, padding)
+  preferred_element_type = (None if preferred_element_type is None else
+                            np.dtype(preferred_element_type))
   return conv_general_dilated_p.bind(
       lhs, rhs, window_strides=tuple(window_strides), padding=tuple(padding),
       lhs_dilation=tuple(lhs_dilation), rhs_dilation=tuple(rhs_dilation),
@@ -684,7 +686,8 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
   """
   if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and core.symbolic_equal_dim(lhs.shape[-1], rhs.shape[0]):
     return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())),
-                       precision=precision, preferred_element_type=preferred_element_type)
+                       precision=precision,
+                       preferred_element_type=preferred_element_type)
   else:
     raise TypeError("Incompatible shapes for dot: got {} and {}.".format(
         lhs.shape, rhs.shape))
@@ -722,6 +725,8 @@ def dot_general(lhs: Array, rhs: Array, dimension_numbers: DotDimensionNumbers,
   contract_dims_seq, batch_dims_seq = dimension_numbers
   contract_dims = tuple(map(tuple, contract_dims_seq))  # type: ignore
   batch_dims = tuple(map(tuple, batch_dims_seq))  # type: ignore
+  preferred_element_type = (None if preferred_element_type is None else
+                            np.dtype(preferred_element_type))
   return dot_general_p.bind(lhs, rhs,
                             dimension_numbers=(contract_dims, batch_dims),
                             precision=canonicalize_precision(precision),
@@ -3032,7 +3037,7 @@ def _convert_element_type_translation_rule(ctx, avals_in, avals_out, operand, *,
   if (dtypes.issubdtype(old_dtype, np.complexfloating) and
       not dtypes.issubdtype(new_dtype, np.complexfloating)):
     operand = xops.Real(operand)
-  new_etype = xla_client.dtype_to_etype(new_dtype)
+  new_etype = xla.dtype_to_primitive_type(new_dtype)
   return [xops.ConvertElementType(operand, new_element_type=new_etype)]

 def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type):
@@ -3082,7 +3087,7 @@ def _bitcast_convert_type_dtype_rule(operand, *, new_dtype):

 def _bitcast_convert_type_translation_rule(ctx, avals_in, avals_out, operand, *,
                                            new_dtype):
-  new_etype = xla_bridge.dtype_to_etype(new_dtype)
+  new_etype = xla.dtype_to_primitive_type(new_dtype)
   return [xops.BitcastConvertType(operand, new_element_type=new_etype)]

 bitcast_convert_type_p = standard_primitive(
@@ -3307,8 +3312,9 @@ def _conv_general_dilated_translation_rule(
     if preferred_element_type is not None:
       # Convert complex dtype to types used for real and imaginary parts
       assert np.issubdtype(preferred_element_type, np.complexfloating)
-      preferred_element_type = xla_client.dtype_to_etype(
-          np.float64 if preferred_element_type == np.complex128 else np.float32)
+      preferred_element_type = xla.dtype_to_primitive_type(np.dtype(
+          np.float64 if preferred_element_type == np.complex128
+          else np.float32))

     conv = lambda x, y: xops.ConvGeneralDilated(
         x, y, window_strides, padding, lhs_dilation, rhs_dilation,
@@ -3323,7 +3329,7 @@ def _conv_general_dilated_translation_rule(
     return [xops.Complex(xops.Sub(k1, k3), xops.Add(k1, k2))]

   if preferred_element_type is not None:
-    preferred_element_type = xla_client.dtype_to_etype(preferred_element_type)
+    preferred_element_type = xla.dtype_to_primitive_type(preferred_element_type)

   return [xops.ConvGeneralDilated(
       lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
@@ -3666,7 +3672,7 @@ def _dot_general_translation_rule(ctx, avals_in, avals_out, lhs, rhs, *,
                                   dimension_numbers, precision,
                                   preferred_element_type: Optional[DType]):
   if preferred_element_type is not None:
-    preferred_element_type = xla_client.dtype_to_etype(preferred_element_type)
+    preferred_element_type = xla.dtype_to_primitive_type(preferred_element_type)
   return [xops.DotGeneral(lhs, rhs,
                           xc.make_dot_dimension_numbers(dimension_numbers),
                           precision_config=_precision_config(precision),
@@ -3676,14 +3682,17 @@ def _dot_general_cpu_translation_rule(ctx, avals_in, avals_out, lhs, rhs, *,
                                       dimension_numbers, precision,
                                       preferred_element_type: Optional[DType]):
   if preferred_element_type is not None:
-    preferred_element_type = xla_client.dtype_to_etype(preferred_element_type)
+    preferred_element_type = xla.dtype_to_primitive_type(preferred_element_type)

   # TODO(b/195364460): Work around slow XLA/CPU implementation of float16 matmul
   if avals_in[0].dtype == np.float16:
-    lhs = xops.ConvertElementType(lhs, xla_client.dtype_to_etype(np.float32))
-    rhs = xops.ConvertElementType(rhs, xla_client.dtype_to_etype(np.float32))
-    preferred_element_type = (preferred_element_type or
-                              xla_client.dtype_to_etype(np.float16))
+    lhs = xops.ConvertElementType(
+        lhs, xla.dtype_to_primitive_type(np.dtype(np.float32)))
+    rhs = xops.ConvertElementType(
+        rhs, xla.dtype_to_primitive_type(np.dtype(np.float32)))
+    preferred_element_type = (
+        preferred_element_type or
+        xla.dtype_to_primitive_type(np.dtype(np.float16)))

   return [xops.DotGeneral(lhs, rhs,
                           xc.make_dot_dimension_numbers(dimension_numbers),
@@ -4738,7 +4747,7 @@ def _gather_translation_rule(ctx, avals_in, avals_out, operand, indices, *,
     intarray = partial(np.array, dtype=np.int64)
     operand_dims = intarray(operand_aval.shape)
     indices = xops.ConvertElementType(
-        indices, xb.dtype_to_etype(np.int64))
+        indices, xla.dtype_to_primitive_type(dtypes.canonicalize_dtype(np.int64)))
     num_batch_dims = len(indices_aval.shape) - 1

     upper_bound = operand_dims[intarray(dnums.start_index_map)]
@@ -6254,15 +6263,15 @@ def _select_and_gather_add_shape_rule(
                          window_dilation)

 _UINT_DTYPES = {
-  16: np.uint16,
-  32: np.uint32,
-  64: np.uint64,
+  16: np.dtype(np.uint16),
+  32: np.dtype(np.uint32),
+  64: np.dtype(np.uint64),
 }

 _INT_DTYPES = {
-  16: np.int16,
-  32: np.int32,
-  64: np.int64,
+  16: np.dtype(np.int16),
+  32: np.dtype(np.int32),
+  64: np.dtype(np.int64),
 }

 def _select_and_gather_add_translation(
@@ -6272,7 +6281,7 @@ def _select_and_gather_add_translation(
   c = ctx.builder
   tangents_aval, operand_aval, = avals_in
   dtype = operand_aval.dtype
-  etype = xla_client.dtype_to_etype(dtype)
+  etype = xla.dtype_to_primitive_type(dtype)
   nbits = dtypes.finfo(dtype).bits

   assert nbits <= max_bits
@@ -6287,8 +6296,8 @@ def _select_and_gather_add_translation(
     # 2k-bit unsigned integer using bit tricks.
     word_dtype = _UINT_DTYPES[nbits]
     double_word_dtype = _UINT_DTYPES[nbits * 2]
-    word_type = xla_client.dtype_to_etype(word_dtype)
-    double_word_type = xla_client.dtype_to_etype(double_word_dtype)
+    word_type = xla.dtype_to_primitive_type(word_dtype)
+    double_word_type = xla.dtype_to_primitive_type(double_word_dtype)

     # Packs two values into a tuple.
     def pack(a, b):
@@ -6323,7 +6332,7 @@ def snd(t):
     nmant = r_nbits - nexp - 1

     double_word_dtype = word_dtype = _UINT_DTYPES[nbits]
-    word_type = xla_client.dtype_to_etype(word_dtype)
+    word_type = xla.dtype_to_primitive_type(word_dtype)

     # Packs two values into a tuple.
     def pack(a, b):
@@ -6497,7 +6506,7 @@ def _float_to_int_for_sort(x):
   signed = bitcast_convert_type(x, signed_dtype)
   unsigned = bitcast_convert_type(x, unsigned_dtype)
   flipped = bitcast_convert_type(
-      sub(unsigned_dtype(np.iinfo(signed_dtype).max), unsigned), signed_dtype)
+      sub(unsigned_dtype.type(np.iinfo(signed_dtype).max), unsigned), signed_dtype)
   return select(lt(signed, _zero(signed)), flipped, signed)

 # Default comparator that sorts the operands lexicographically on the
@@ -6845,7 +6854,7 @@ def _rng_bit_generator_translation_rule(
   # TODO(mattjj): the BitcastConvertType segfaults on GPU
   # TODO(mattjj): remove fallback when minimum jaxlib is 0.1.72 or newer
   if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
-    u64_etype = xc.dtype_to_etype(dtypes.dtype('uint64'))
+    u64_etype = xla.dtype_to_primitive_type(dtypes.dtype('uint64'))
     key = xops.BitcastConvertType(xops.Reshape(key, (2, 2)), u64_etype)
   else:
     key = _convert_4xU32_to_2xU64_without_bitcast(c, key)
@@ -6853,14 +6862,14 @@ def _rng_bit_generator_translation_rule(
       c, xops.RngBitGenerator(algorithm, key, xla_shape))
   if key_dtype == dtypes.dtype('uint32'):
     if jaxlib_version >= (0, 1, 72) and not backend_is_gpu:
-      u32_etype = xc.dtype_to_etype(dtypes.dtype('uint32'))
+      u32_etype = xla.dtype_to_primitive_type(dtypes.dtype('uint32'))
       out_key = xops.Reshape(xops.BitcastConvertType(out_key, u32_etype), (4,))
     else:
       out_key = _convert_2xU64_to_4xU32_without_bitcast(c, out_key)
   return [out_key, out_vals]

 def _convert_4xU32_to_2xU64_without_bitcast(c, key):
-  u64_etype = xc.dtype_to_etype(dtypes.dtype('uint64'))
+  u64_etype = xla.dtype_to_primitive_type(dtypes.dtype('uint64'))
   new_key = xb.constant(c, np.zeros(2, dtype=np.dtype('uint64')),
                         canonicalize_types=False)
   _32 = xb.constant(c, np.uint64(32), canonicalize_types=False)
@@ -6872,7 +6881,7 @@ def _convert_4xU32_to_2xU64_without_bitcast(c, key):
   return new_key

 def _convert_2xU64_to_4xU32_without_bitcast(c, key):
-  u32_etype = xc.dtype_to_etype(dtypes.dtype('uint32'))
+  u32_etype = xla.dtype_to_primitive_type(dtypes.dtype('uint32'))
   new_key = xb.constant(c, np.zeros(4, dtype=np.dtype('uint32')))
   _32 = xb.constant(c, np.uint64(32), canonicalize_types=False)
   for i in [0, 1]:
@@ -6937,7 +6946,7 @@ def _iota_abstract_eval(*, dtype, shape, dimension):

 def _iota_translation_rule(ctx, avals_in, avals_out, *, dtype, shape,
                            dimension):
-  etype = xla_client.dtype_to_etype(dtype)
+  etype = xla.dtype_to_primitive_type(dtype)
   xla_shape = xc.Shape.array_shape(etype, shape)
   return [xops.Iota(ctx.builder, xla_shape, dimension)]
diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py
index d3b7549dff3b..d9f0a57efd60 100644
--- a/jax/_src/lax/parallel.py
+++ b/jax/_src/lax/parallel.py
@@ -1325,7 +1325,8 @@ def _build_axis_index_lowering(c, axis_name, axis_env):
                             dtype=np.uint32))
   mod = xb.constant(c, np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
   unsigned_index = xops.Rem(xops.Div(xops.ReplicaId(c), div), mod)
-  return xops.ConvertElementType(unsigned_index, xb.dtype_to_etype(np.int32))
+  return xops.ConvertElementType(
+      unsigned_index, xla.dtype_to_primitive_type(np.dtype(np.int32)))

 def _axis_index_translation_rule(ctx, avals_in, avals_out, *, axis_name):
   return [_build_axis_index_lowering(ctx.builder, axis_name, ctx.axis_env)]
diff --git a/jax/_src/lib/xla_bridge.py b/jax/_src/lib/xla_bridge.py
index d65f52094ba1..bd20b87365c3 100644
--- a/jax/_src/lib/xla_bridge.py
+++ b/jax/_src/lib/xla_bridge.py
@@ -427,18 +427,6 @@ def host_ids(backend=None):

 ### utility functions

-@util.memoize
-def dtype_to_etype(dtype):
-  """Convert from dtype to canonical etype (reading config.x64_enabled)."""
-  return xla_client.dtype_to_etype(dtypes.canonicalize_dtype(dtype))
-
-
-@util.memoize
-def supported_numpy_dtypes():
-  return {dtypes.canonicalize_dtype(dtype)
-          for dtype in xla_client.XLA_ELEMENT_TYPE_TO_DTYPE.values()}
-
-
 # TODO(mattjj,frostig): try to remove this function
 def normalize_to_xla_dtypes(val):
   """Normalize dtypes in a value."""
diff --git a/jax/experimental/djax.py b/jax/experimental/djax.py
index c954b26d5688..69f0d0529a57 100644
--- a/jax/experimental/djax.py
+++ b/jax/experimental/djax.py
@@ -1357,7 +1357,7 @@ def _iota_translation_rule(c, dims, avals, operands, *, size=None):
     shape = aval.shape
   else:
     shape = ()
-  etype = xc.dtype_to_etype(np.dtype('int32'))
+  etype = xla.dtype_to_primitive_type(np.dtype('int32'))
   xla_shape = xc.Shape.array_shape(etype, (*shape, size))
   return [[xops.Iota(c, xla_shape, len(shape))]]
 translations[iota_p] = _iota_translation_rule
diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py
index 8d162292607c..40ad5cefb140 100644
--- a/jax/experimental/jax2tf/call_tf.py
+++ b/jax/experimental/jax2tf/call_tf.py
@@ -396,7 +396,7 @@ def post_process_result(idx: int, res_aval: core.ShapedArray, res_shape: xla.Xla
     if res_aval.dtype != res_shape.numpy_dtype():
       res_op = xops.ConvertElementType(
           res_op,
-          new_element_type=xla_client.dtype_to_etype(res_aval.dtype))
+          new_element_type=xla.dtype_to_primitive_type(res_aval.dtype))
     return res_op

   results = [
diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py
index 1bc0db62087e..963420aee699 100644
--- a/jax/experimental/jax2tf/impl_no_xla.py
+++ b/jax/experimental/jax2tf/impl_no_xla.py
@@ -134,7 +134,8 @@ def error(msg):
     raise error("Unimplemented support for batch_group_count != 1 "
                 f"(found {batch_group_count})")

-  if preferred_element_type is not None and preferred_element_type != lhs.dtype:
+  if (preferred_element_type is not None and
+      preferred_element_type != lhs.dtype.as_numpy_dtype):
     raise error("Unimplemented support for preferred_element_type")

   lhs, rhs = _transpose_for_tf_conv(lhs, rhs, dimension_numbers)
diff --git a/jax/experimental/jax2tf/tests/primitive_harness.py b/jax/experimental/jax2tf/tests/primitive_harness.py
index f326a6478798..0526a1575023 100644
--- a/jax/experimental/jax2tf/tests/primitive_harness.py
+++ b/jax/experimental/jax2tf/tests/primitive_harness.py
@@ -492,7 +492,8 @@ def _make_convert_element_type_harness(name,
      "convert_element_type",
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_olddtype={jtu.dtype_str(dtype)}_newdtype={jtu.dtype_str(new_dtype)}",
      lambda arg: (lax.convert_element_type_p.bind(
-          arg, new_dtype=new_dtype, weak_type=False)), [RandArg(shape, dtype)],
+          arg, new_dtype=np.dtype(new_dtype), weak_type=False)),
+      [RandArg(shape, dtype)],
      shape=shape,
      dtype=dtype,
      new_dtype=new_dtype)
@@ -660,7 +661,8 @@ def _make_bitcast_convert_type_harness(name,
   define(
      "bitcast_convert_type",
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_newdtype={np.dtype(new_dtype).name}",
-      lambda x: (lax.bitcast_convert_type_p.bind(x, new_dtype=new_dtype)),
+      lambda x: lax.bitcast_convert_type_p.bind(x,
+                                                new_dtype=np.dtype(new_dtype)),
      [RandArg(shape, dtype)],
      shape=shape,
      dtype=dtype,
@@ -856,7 +858,7 @@ def _make_iota_harness(name, *, shape=(2, 3), dtype=np.float32, dimension=0):
      lax.iota_p,
      f"{name}_shape={jtu.format_shape_dtype_string(shape, dtype)}_dimension={dimension}",
      lambda dtype, shape, dim:
-      (lax.iota_p.bind(dtype=dtype, shape=shape, dimension=dim)),
+      (lax.iota_p.bind(dtype=np.dtype(dtype), shape=shape, dimension=dim)),
      [StaticArg(dtype),
       StaticArg(shape),
       StaticArg(dimension)],
diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py
index b1597bac4ad4..04fc4657a6ca 100644
--- a/jax/experimental/maps.py
+++ b/jax/experimental/maps.py
@@ -1396,7 +1396,8 @@ def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend):
   convert_bool = (np.issubdtype(x_dtype, np.bool_)
                   and xb.get_backend(backend).platform in ('cpu', 'gpu'))
   if convert_bool:
-    x = xops.ConvertElementType(x, xb.dtype_to_etype(np.float32))
+    x = xops.ConvertElementType(
+        x, xla.dtype_to_primitive_type(np.dtype(np.float32)))

   tile_shape = list(xla_shape.dimensions())
   shape = list(tile_shape)
@@ -1413,7 +1414,8 @@ def _xla_untile(c, axis_env, x, out_axes, axis_sizes, backend):
   # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
   if convert_bool:
     nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32)))
-    out = xops.ConvertElementType(nonzero, xb.dtype_to_etype(np.bool_))
+    out = xops.ConvertElementType(
+        nonzero, xla.dtype_to_primitive_type(np.dtype(np.bool_)))
   return out

 def _xmap_translation_rule_spmd(c, axis_env,
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 047099f2f10c..ac102dca8754 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -1339,7 +1339,8 @@ def _xla_unshard(c, aval, axis_env, out_axis, x, backend):
     convert_bool = (np.issubdtype(aval.dtype, np.bool_)
                     and xb.get_backend(backend).platform in ('cpu', 'gpu'))
     if convert_bool:
-      x = xops.ConvertElementType(x, xb.dtype_to_etype(np.float32))
+      x = xops.ConvertElementType(
+          x, xla.dtype_to_primitive_type(np.dtype(np.float32)))

     xla_shape = c.get_shape(x)
     dims = list(xla_shape.dimensions())
@@ -1360,7 +1361,8 @@ def _xla_unshard(c, aval, axis_env, out_axis, x, backend):
     # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
     if convert_bool:
       nonzero = xops.Ne(out, xb.constant(c, np.array(0, dtype=np.float32)))
-      out = xops.ConvertElementType(nonzero, xb.dtype_to_etype(np.bool_))
+      out = xops.ConvertElementType(
+          nonzero, xla.dtype_to_primitive_type(np.dtype(np.bool_)))
     return out
   else:
     raise TypeError((aval, c.get_shape(x)))
diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py
index 6705d873b645..b7e1eab54bc9 100644
--- a/jax/interpreters/xla.py
+++ b/jax/interpreters/xla.py
@@ -128,6 +128,34 @@ def make_op_metadata(primitive: core.Primitive,

 ### handlers

+_dtype_to_primitive_type: Dict[np.dtype, xc.PrimitiveType] = {
+    np.dtype('bool'): xc.PrimitiveType.PRED,
+    np.dtype('int8'): xc.PrimitiveType.S8,
+    np.dtype('int16'): xc.PrimitiveType.S16,
+    np.dtype('int32'): xc.PrimitiveType.S32,
+    np.dtype('int64'): xc.PrimitiveType.S64,
+    np.dtype('uint8'): xc.PrimitiveType.U8,
+    np.dtype('uint16'): xc.PrimitiveType.U16,
+    np.dtype('uint32'): xc.PrimitiveType.U32,
+    np.dtype('uint64'): xc.PrimitiveType.U64,
+    np.dtype(dtypes.bfloat16): xc.PrimitiveType.BF16,
+    np.dtype('float16'): xc.PrimitiveType.F16,
+    np.dtype('float32'): xc.PrimitiveType.F32,
+    np.dtype('float64'): xc.PrimitiveType.F64,
+    np.dtype('complex64'): xc.PrimitiveType.C64,
+    np.dtype('complex128'): xc.PrimitiveType.C128,
+}
+
+def dtype_to_primitive_type(dtype: np.dtype) -> xc.PrimitiveType:
+  """Converts a NumPy dtype into an XLA PrimitiveType."""
+  # Many things (e.g., strings, scalar types) can be compared with NumPy dtypes,
+  # but may not hash correctly. Make sure we have a true np.dtype.
+  assert isinstance(dtype, np.dtype), type(dtype)
+  try:
+    return _dtype_to_primitive_type[dtype]
+  except KeyError as err:
+    raise TypeError(f"No XLA lowering for NumPy dtype: {dtype}") from err
+
 xb.register_constant_handler(core.Unit, lambda c, *_: _make_unit_constant(c))

 def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[XlaShape]:
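
Afterword (editor's note, not part of the commit): the contract of the new
dtype_to_primitive_type helper is easy to miss when skimming the final hunk,
so here is a minimal, runnable Python sketch of it. The PrimitiveType enum and
the three-entry table below are stand-ins for xla_client.PrimitiveType and the
full fifteen-entry table in the patch; everything else mirrors the hunk above.

  import enum

  import numpy as np


  class PrimitiveType(enum.Enum):
    # Stand-in for xla_client.PrimitiveType; the names match the XLA enum.
    S32 = 'S32'
    S64 = 'S64'
    F32 = 'F32'

  _dtype_to_primitive_type = {
      np.dtype('int32'): PrimitiveType.S32,
      np.dtype('int64'): PrimitiveType.S64,
      np.dtype('float32'): PrimitiveType.F32,
  }

  def dtype_to_primitive_type(dtype: np.dtype) -> PrimitiveType:
    """Converts a NumPy dtype into a (stand-in) XLA PrimitiveType."""
    # As in the patch: strings and scalar types ('int32', np.int32, ...) compare
    # equal to np.dtype('int32') but hash differently, so they would silently
    # miss in the dict lookup. The assert forces callers to pass a true np.dtype.
    assert isinstance(dtype, np.dtype), type(dtype)
    try:
      return _dtype_to_primitive_type[dtype]
    except KeyError as err:
      raise TypeError(f"No XLA lowering for NumPy dtype: {dtype}") from err

  # Callers wrap or canonicalize at the call site, as the lax.py hunks now do:
  print(dtype_to_primitive_type(np.dtype(np.float32)))  # PrimitiveType.F32
  print(dtype_to_primitive_type(np.dtype('int64')))     # PrimitiveType.S64

The behavioral change relative to the deleted xla_bridge.dtype_to_etype is the
absence of dtypes.canonicalize_dtype: under the default 32-bit mode the old
helper silently mapped np.int64 to S32, whereas the new one returns S64 and
leaves canonicalization to tracing, which is the point of the commit message
above.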