Skip to content

Commit

Permalink
Split core.py and several files in an SCC with it into a separate Baz…
Browse files Browse the repository at this point in the history
…el build target.

PiperOrigin-RevId: 520192610
  • Loading branch information
hawkinsp authored and jax authors committed Mar 29, 2023
1 parent 8c4fed6 commit c2d6fcc
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 12 deletions.
26 changes: 22 additions & 4 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -96,20 +96,16 @@ py_library_providing_imports_info(
"_src/array.py",
"_src/callback.py",
"_src/checkify.py",
"_src/core.py",
"_src/custom_batching.py",
"_src/custom_derivatives.py",
"_src/custom_transpose.py",
"_src/debugging.py",
"_src/device_array.py",
"_src/dispatch.py",
"_src/dlpack.py",
"_src/dtypes.py",
"_src/errors.py",
"_src/flatten_util.py",
"_src/__init__.py",
"_src/lax_reference.py",
"_src/linear_util.py",
"_src/maps.py",
"_src/pjit.py",
"_src/prng.py",
Expand Down Expand Up @@ -172,6 +168,7 @@ py_library_providing_imports_info(
":custom_api_util",
":config",
":deprecations",
":core",
":effects",
":environment_info",
":lazy_loader",
Expand Down Expand Up @@ -215,6 +212,27 @@ pytype_strict_library(
],
)

pytype_strict_library(
name = "core",
srcs = [
"_src/core.py",
"_src/dtypes.py",
"_src/errors.py",
"_src/linear_util.py",
],
deps = [
":config",
":effects",
":pretty_printer",
":source_info_util",
":traceback_util",
":tree_util",
":typing",
":util",
"//jax/_src/lib",
] + py_deps("ml_dtypes") + py_deps("numpy"),
)

pytype_strict_library(
name = "custom_api_util",
srcs = ["_src/custom_api_util.py"],
Expand Down
22 changes: 21 additions & 1 deletion jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2387,6 +2387,26 @@ def ax_leaf(l):
return broadcast_prefix(abstracted_axes, args, ax_leaf)


# TODO(phawkins): for some reason mypy cannot determine these overloads are
# non-overlapping. Pytype is happy with them.
@overload
def make_jaxpr(fun: Callable, # type: ignore
static_argnums: Union[int, Iterable[int]] = (),
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
return_shape: Literal[False] = ...,
abstracted_axes: Optional[Any] = None,
) -> Callable[..., core.ClosedJaxpr]:
...

@overload
def make_jaxpr(fun: Callable, # type: ignore
static_argnums: Union[int, Iterable[int]] = (),
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
return_shape: Literal[True] = ...,
abstracted_axes: Optional[Any] = None,
) -> Callable[..., Tuple[core.ClosedJaxpr, Any]]:
...

def make_jaxpr(fun: Callable,
static_argnums: Union[int, Iterable[int]] = (),
axis_env: Optional[Sequence[Tuple[AxisName, int]]] = None,
Expand Down Expand Up @@ -3062,7 +3082,7 @@ def clear_backends():
jax.lib.xla_bridge._backends = {}
dispatch.xla_primitive_callable.cache_clear()
pjit._pjit_lower_cached.cache_clear()
pjit._create_pjit_jaxpr.cache_clear()
pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error
pjit._cpp_pjit_cache.clear()
xc._xla.PjitFunctionCache.clear_all()

Expand Down
5 changes: 3 additions & 2 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@
from jax._src import config as jax_config
from jax._src import effects
from jax._src.config import FLAGS, config
from jax.errors import (ConcretizationTypeError, TracerArrayConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
from jax._src import linear_util as lu

from jax._src import source_info_util
Expand Down
1 change: 1 addition & 0 deletions jax/_src/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ def _check_sharding(aval, s):
pjit.pjit_check_aval_sharding(
(s,), (aval,), "device_put args", allow_uneven_sharding=False)

assert isinstance(aval, core.ShapedArray), aval
s.shard_shape(aval.shape) # should raise an Error if incompatible


Expand Down
4 changes: 2 additions & 2 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def shard_device_array(x, devices, indices, sharding):
shard_arg_handlers[t] = shard_device_array


def batched_device_put(aval: core.AbstractValue,
def batched_device_put(aval: core.ShapedArray,
sharding: jax.sharding.Sharding, xs: Sequence[Any],
devices: Sequence[jax.Device], committed: bool = True):
from jax._src import array
Expand Down Expand Up @@ -1794,7 +1794,7 @@ def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0):
devices = xb.get_backend(backend).get_default_device_assignment(nrep)
assert nrep == len(devices)

aval = xla.abstractify(val) # type: ShapedArray
aval = xla.abstractify(val)
if in_axis is not None:
replicated_aval = aval.update(shape=(axis_size,) + aval.shape)
else:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _canonicalize_python_scalar_dtype(typ, x):
canonicalize_dtype_handlers[core.Token] = identity
canonicalize_dtype_handlers[core.DArray] = identity

def abstractify(x) -> core.AbstractValue:
def abstractify(x) -> Any:
typ = type(x)
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/linear_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def trans1(static_arg, *dynamic_args, **kwargs):
from typing import Any, Tuple, Callable, Optional, NamedTuple
import weakref

from jax.tree_util import tree_map
from jax.config import config
from jax._src.tree_util import tree_map
from jax._src.config import config
from jax._src import core
from jax._src import traceback_util
from jax._src.util import curry
Expand Down

0 comments on commit c2d6fcc

Please sign in to comment.