diff --git a/jax/BUILD b/jax/BUILD index c24010ac1f6d..13d636e634ac 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -96,7 +96,6 @@ py_library_providing_imports_info( "_src/array.py", "_src/callback.py", "_src/checkify.py", - "_src/core.py", "_src/custom_batching.py", "_src/custom_derivatives.py", "_src/custom_transpose.py", @@ -104,12 +103,9 @@ py_library_providing_imports_info( "_src/device_array.py", "_src/dispatch.py", "_src/dlpack.py", - "_src/dtypes.py", - "_src/errors.py", "_src/flatten_util.py", "_src/__init__.py", "_src/lax_reference.py", - "_src/linear_util.py", "_src/maps.py", "_src/pjit.py", "_src/prng.py", @@ -172,6 +168,7 @@ py_library_providing_imports_info( ":custom_api_util", ":config", ":deprecations", + ":core", ":effects", ":environment_info", ":lazy_loader", @@ -215,6 +212,27 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "core", + srcs = [ + "_src/core.py", + "_src/dtypes.py", + "_src/errors.py", + "_src/linear_util.py", + ], + deps = [ + ":config", + ":effects", + ":pretty_printer", + ":source_info_util", + ":traceback_util", + ":tree_util", + ":typing", + ":util", + "//jax/_src/lib", + ] + py_deps("ml_dtypes") + py_deps("numpy"), +) + pytype_strict_library( name = "custom_api_util", srcs = ["_src/custom_api_util.py"], diff --git a/jax/_src/api.py b/jax/_src/api.py index 60ea338b651f..ca464681e68a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2387,6 +2387,26 @@ def ax_leaf(l): return broadcast_prefix(abstracted_axes, args, ax_leaf) +# TODO(phawkins): for some reason mypy cannot determine these overloads are +# non-overlapping. Pytype is happy with them. +@overload +def make_jaxpr(fun: Callable, # type: ignore + static_argnums: Union[int, Iterable[int]] = (), + axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None, + return_shape: Literal[False] = ..., + abstracted_axes: Optional[Any] = None, + ) -> Callable[..., core.ClosedJaxpr]: + ... + +@overload +def make_jaxpr(fun: Callable, # type: ignore + static_argnums: Union[int, Iterable[int]] = (), + axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None, + return_shape: Literal[True] = ..., + abstracted_axes: Optional[Any] = None, + ) -> Callable[..., Tuple[core.ClosedJaxpr, Any]]: + ... + def make_jaxpr(fun: Callable, static_argnums: Union[int, Iterable[int]] = (), axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None, @@ -3062,7 +3082,7 @@ def clear_backends(): jax.lib.xla_bridge._backends = {} dispatch.xla_primitive_callable.cache_clear() pjit._pjit_lower_cached.cache_clear() - pjit._create_pjit_jaxpr.cache_clear() + pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error pjit._cpp_pjit_cache.clear() xc._xla.PjitFunctionCache.clear_all() diff --git a/jax/_src/core.py b/jax/_src/core.py index 61d5fcdbdcc5..df4a2e046b2d 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -40,8 +40,9 @@ from jax._src import config as jax_config from jax._src import effects from jax._src.config import FLAGS, config -from jax.errors import (ConcretizationTypeError, TracerArrayConversionError, - TracerIntegerConversionError, UnexpectedTracerError) +from jax._src.errors import ( + ConcretizationTypeError, TracerArrayConversionError, + TracerIntegerConversionError, UnexpectedTracerError) from jax._src import linear_util as lu from jax._src import source_info_util diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 77b601cdb1d2..7a833b1d76b5 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -600,6 +600,7 @@ def _check_sharding(aval, s): pjit.pjit_check_aval_sharding( (s,), (aval,), "device_put args", allow_uneven_sharding=False) + assert isinstance(aval, core.ShapedArray), aval s.shard_shape(aval.shape) # should raise an Error if incompatible diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 01c835df40ab..4142998eb427 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -443,7 +443,7 @@ def shard_device_array(x, devices, indices, sharding): shard_arg_handlers[t] = shard_device_array -def batched_device_put(aval: core.AbstractValue, +def batched_device_put(aval: core.ShapedArray, sharding: jax.sharding.Sharding, xs: Sequence[Any], devices: Sequence[jax.Device], committed: bool = True): from jax._src import array @@ -1794,7 +1794,7 @@ def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0): devices = xb.get_backend(backend).get_default_device_assignment(nrep) assert nrep == len(devices) - aval = xla.abstractify(val) # type: ShapedArray + aval = xla.abstractify(val) if in_axis is not None: replicated_aval = aval.update(shape=(axis_size,) + aval.shape) else: diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 6623069b2486..2b356e21e965 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -191,7 +191,7 @@ def _canonicalize_python_scalar_dtype(typ, x): canonicalize_dtype_handlers[core.Token] = identity canonicalize_dtype_handlers[core.DArray] = identity -def abstractify(x) -> core.AbstractValue: +def abstractify(x) -> Any: typ = type(x) aval_fn = pytype_aval_mappings.get(typ) if aval_fn: return aval_fn(x) diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index bda8e73e18b6..7fd3e3951718 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -67,8 +67,8 @@ def trans1(static_arg, *dynamic_args, **kwargs): from typing import Any, Tuple, Callable, Optional, NamedTuple import weakref -from jax.tree_util import tree_map -from jax.config import config +from jax._src.tree_util import tree_map +from jax._src.config import config from jax._src import core from jax._src import traceback_util from jax._src.util import curry