From fbc05ee5ac34a019427929704d0099573ef1e4f1 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 29 Mar 2023 09:22:34 -0700 Subject: [PATCH] Remove global_arg_shapes from pmap since it was only used for sharded_jit and sharded_jit was removed from JAX a long time ago PiperOrigin-RevId: 520356179 --- CHANGELOG.md | 3 +++ docs/jaxpr.rst | 1 - jax/_src/api.py | 48 ++++++++++------------------------- jax/_src/interpreters/pxla.py | 36 ++++++-------------------- jax/_src/lax/lax.py | 3 +-- tests/host_callback_test.py | 1 - 6 files changed, 26 insertions(+), 66 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c31d7d8ffe2..8c99a7825803 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,9 @@ Remember to align the itemized text with the first line of an item within a list * CUDA 11.4 support has been dropped. JAX GPU wheels only support CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built from source. + * `global_arg_shapes` argument of pmap only worked with sharded_jit and has + been removed from pmap. Please migrate to pjit and remove global_arg_shapes + from pmap. ## jaxlib 0.4.8 diff --git a/docs/jaxpr.rst b/docs/jaxpr.rst index 5b01c105e9a1..d33b1748001f 100644 --- a/docs/jaxpr.rst +++ b/docs/jaxpr.rst @@ -460,7 +460,6 @@ captured using the ``xla_pmap`` primitive. Consider this example in (k,) } devices=None donated_invars=(False, False) - global_arg_shapes=(None,) global_axis_size=1 in_axes=(None, 0) is_explicit_global_axis_size=False diff --git a/jax/_src/api.py b/jax/_src/api.py index ca464681e68a..76214c6c331f 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1437,12 +1437,6 @@ def pmap( For more details on buffer donation see the `FAQ `_. - global_arg_shapes: Optional, must be set when using pmap(sharded_jit) and - the partitioned values span multiple processes. The global cross-process - per-replica shape of each argument, i.e. does not include the leading - pmapped dimension. Can be None for replicated arguments. This API is - likely to change in the future. - Returns: A parallelized version of ``fun`` with arguments that correspond to those of ``fun`` but with extra array axes at positions indicated by ``in_axes`` and @@ -1565,6 +1559,12 @@ def pmap( >>> print(f2(jnp.array([2., 3.]))) # doctest: +SKIP [ 13. 13.] """ + if global_arg_shapes is not None: + raise ValueError( + "global_arg_shapes only worked with sharded_jit which has long been" + " removed from JAX. Please migrate to pjit and remove global_arg_shapes" + " from pmap.") + if FLAGS.experimental_cpp_pmap: func = _cpp_pmap else: @@ -1579,8 +1579,7 @@ def pmap( devices=devices, backend=backend, axis_size=axis_size, - donate_argnums=donate_argnums, - global_arg_shapes=global_arg_shapes) + donate_argnums=donate_argnums) class PmapCallInfo(NamedTuple): @@ -1591,7 +1590,6 @@ class PmapCallInfo(NamedTuple): donated_invars: Sequence[bool] in_axes_flat: Sequence[Optional[int]] local_axis_size: int - global_arg_shapes_flat: Sequence[Optional[Tuple[int, ...]]] out_axes_thunk: HashableFunction devices: Optional[Sequence[xc.Device]] global_axis_size: int @@ -1628,7 +1626,7 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str, return global_axis_size def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, - donate_tuple, global_arg_shapes, in_devices, backend_name, + donate_tuple, in_devices, backend_name, axis_size, args, kwargs): if in_devices is not None and len(in_devices) == 0: raise ValueError("'devices' argument to pmap must be non-empty, or None.") @@ -1651,15 +1649,8 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, dyn_in_axes = tuple(in_axes[i] for i in dyn_argnums) else: dyn_in_axes = in_axes - dyn_global_arg_shapes = global_arg_shapes - - if isinstance(global_arg_shapes, tuple): - dyn_global_arg_shapes = tuple(global_arg_shapes[i] for i in dyn_argnums) - else: - dyn_global_arg_shapes = global_arg_shapes else: dyn_args, dyn_in_axes = args, in_axes - dyn_global_arg_shapes = global_arg_shapes args, in_tree = tree_flatten((dyn_args, kwargs)) if donate_tuple and not config.jax_debug_nans: @@ -1667,9 +1658,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, else: donated_invars = (False,) * len(args) in_axes_flat = tuple(flatten_axes("pmap in_axes", in_tree, (dyn_in_axes, 0))) - global_arg_shapes_flat = tuple(flatten_axes( - "pmap global_arg_shapes", in_tree, (dyn_global_arg_shapes, None), - kws=True)) local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap") f, res_paths = result_paths(f) @@ -1709,7 +1697,6 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, donated_invars=donated_invars, in_axes_flat=in_axes_flat, local_axis_size=local_axis_size, - global_arg_shapes_flat=global_arg_shapes_flat, out_axes_thunk=out_axes_thunk, devices=None if in_devices is None else tuple(in_devices), global_axis_size=global_axis_size, @@ -1727,12 +1714,11 @@ def _get_f_mapped( backend: Optional[str], axis_size: Optional[int], donate_tuple: Tuple[int, ...], - global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]], ): def pmap_f(*args, **kwargs): p = _prepare_pmap( fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, - global_arg_shapes, devices, backend, axis_size, args, kwargs) + devices, backend, axis_size, args, kwargs) for arg in p.flat_args: dispatch.check_arg(arg) out = pxla.xla_pmap( @@ -1741,7 +1727,6 @@ def pmap_f(*args, **kwargs): devices=p.devices, in_axes=p.in_axes_flat, out_axes_thunk=p.out_axes_thunk, name=p.flat_fun.__name__, donated_invars=p.donated_invars, - global_arg_shapes=p.global_arg_shapes_flat, is_explicit_global_axis_size=p.is_explicit_global_axis_size) return p.out_tree, out @@ -1780,7 +1765,6 @@ def _python_pmap( backend: Optional[str] = None, axis_size: Optional[int] = None, donate_argnums: Union[int, Iterable[int]] = (), - global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None, ) -> stages.Wrapped: """The Python only implementation.""" axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap( @@ -1799,7 +1783,6 @@ def pmap_f(*args, **kwargs): devices=devices, backend=backend, axis_size=axis_size, - global_arg_shapes=global_arg_shapes, donate_tuple=donate_tuple) out_tree, out_flat = f_pmapped_(*args, **kwargs) @@ -1807,7 +1790,7 @@ def pmap_f(*args, **kwargs): pmap_f.lower = _pmap_lower( fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices, - backend, axis_size, global_arg_shapes, donate_tuple) + backend, axis_size, donate_tuple) return cast(stages.Wrapped, pmap_f) @@ -1842,7 +1825,6 @@ def _cpp_pmap( backend: Optional[str] = None, axis_size: Optional[int] = None, donate_argnums: Union[int, Iterable[int]] = (), - global_arg_shapes: Optional[Tuple[Tuple[int, ...], ...]] = None, ) -> Any: axis_name, static_broadcasted_tuple, donate_tuple = _shared_code_pmap( fun, axis_name, static_broadcasted_argnums, donate_argnums, in_axes, @@ -1852,7 +1834,7 @@ def _cpp_pmap( @api_boundary def cache_miss(*args, **kwargs): p = _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple, - donate_tuple, global_arg_shapes, devices, backend, + donate_tuple, devices, backend, axis_size, args, kwargs) for arg in p.flat_args: dispatch.check_arg(arg) @@ -1867,7 +1849,6 @@ def cache_miss(*args, **kwargs): out_axes_thunk=p.out_axes_thunk, name=p.flat_fun.__name__, donated_invars=p.donated_invars, - global_arg_shapes=p.global_arg_shapes_flat, is_explicit_global_axis_size=p.is_explicit_global_axis_size, ) @@ -1939,13 +1920,13 @@ def cache_miss(*args, **kwargs): pmap_f.lower = _pmap_lower( fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, devices, - backend, axis_size, global_arg_shapes, donate_tuple) + backend, axis_size, donate_tuple) return pmap_f def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple, - devices, backend, axis_size, global_arg_shapes, donate_tuple): # noqa: F811 + devices, backend, axis_size, donate_tuple): # noqa: F811 """Make a ``lower`` method for pmapped functions.""" # If the function we returned from ``pmap`` were a class instance, # this might naturally be a method, with ``fun`` as a ``self`` and @@ -1966,7 +1947,7 @@ def lower(*args, _experimental_lowering_platform: Optional[str] = None, """ p = _prepare_pmap( fun, in_axes, out_axes, static_broadcasted_tuple, donate_tuple, - global_arg_shapes, devices, backend, axis_size, args, kwargs) + devices, backend, axis_size, args, kwargs) abstract_args = list(map(shaped_abstractify, p.flat_args)) computation = pxla.lower_parallel_callable( p.flat_fun, backend, axis_name, @@ -1976,7 +1957,6 @@ def lower(*args, _experimental_lowering_platform: Optional[str] = None, in_axes=p.in_axes_flat, out_axes_thunk=p.out_axes_thunk, donated_invars=p.donated_invars, - global_arg_shapes=p.global_arg_shapes_flat, is_explicit_global_axis_size=p.is_explicit_global_axis_size, avals=abstract_args, lowering_platform=_experimental_lowering_platform) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 4142998eb427..e421728d1cc4 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -745,25 +745,22 @@ def xla_pmap_impl_lazy( in_axes: Sequence[Optional[int]], out_axes_thunk: Callable[[], Sequence[Optional[int]]], donated_invars: Sequence[bool], - global_arg_shapes: Sequence[Optional[Tuple[int, ...]]], is_explicit_global_axis_size: bool, ) -> Callable: if (config.jax_disable_jit and config.jax_eager_pmap and - not is_explicit_global_axis_size and not any(d for d in donated_invars) - and not all(g is not None for g in global_arg_shapes)): + not is_explicit_global_axis_size and not any(d for d in donated_invars)): def _emap_apply_fn(*args): return _emap_impl(fun, *args, backend=backend, axis_name=axis_name, axis_size=axis_size, global_axis_size=global_axis_size, devices=devices, name=name, in_axes=in_axes, out_axes_thunk=out_axes_thunk, donated_invars=donated_invars, - global_arg_shapes=global_arg_shapes, is_explicit_global_axis_size=is_explicit_global_axis_size) return _emap_apply_fn abstract_args = unsafe_map(xla.abstractify, args) compiled_fun, fingerprint = parallel_callable( fun, backend, axis_name, axis_size, global_axis_size, devices, name, - in_axes, out_axes_thunk, donated_invars, global_arg_shapes, + in_axes, out_axes_thunk, donated_invars, is_explicit_global_axis_size, *abstract_args) # Don't re-abstractify args unless logging is enabled for performance. @@ -793,15 +790,12 @@ def _emap_impl(fun: lu.WrappedFun, *args, in_axes: Sequence[Optional[int]], out_axes_thunk: Callable[[], Sequence[Optional[int]]], donated_invars: Sequence[bool], - global_arg_shapes: Sequence[Optional[Tuple[int, ...]]], is_explicit_global_axis_size: bool, ): from jax._src import array # TODO(sharadmv,mattjj): implement these cases if any(d for d in donated_invars): raise NotImplementedError("Buffer donation not supported in eager pmap.") - if any(g is not None for g in global_arg_shapes): - raise NotImplementedError("Global arg shapes not supported in eager pmap.") if is_explicit_global_axis_size: raise NotImplementedError("Non-default global_axis_size not supported in " "eager pmap.") @@ -1029,12 +1023,11 @@ def parallel_callable(fun: lu.WrappedFun, in_axes: Sequence[Optional[int]], out_axes_thunk: Callable[[], Sequence[Optional[int]]], donated_invars: Sequence[bool], - global_arg_shapes: Sequence[Optional[Tuple[int, ...]]], is_explicit_global_axis_size: bool, *avals): pmap_computation = lower_parallel_callable( fun, backend_name, axis_name, axis_size, global_axis_size, devices, name, - in_axes, out_axes_thunk, donated_invars, global_arg_shapes, + in_axes, out_axes_thunk, donated_invars, is_explicit_global_axis_size, avals, lowering_platform=None) pmap_executable = pmap_computation.compile() return WeakRefList([pmap_executable.unsafe_call, pmap_executable.fingerprint]) @@ -1091,26 +1084,17 @@ def find_replicas(jaxpr, axis_size, global_axis_size): def stage_parallel_callable( pci: ParallelCallableInfo, - fun: lu.WrappedFun, - global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]): + fun: lu.WrappedFun): sharded_avals = tuple( shard_aval(pci.axis_size, axis, aval) if axis is not None else aval for axis, aval in safe_zip(pci.in_axes, pci.avals)) - if any(s is not None for s in global_arg_shapes): - # TODO(skye): we could take this branch unconditionally if we handled - # grad of global_arg_shapes correctly. - global_sharded_avals = [ - aval.update(shape=shape) if shape is not None else aval - for shape, aval in safe_zip(global_arg_shapes, sharded_avals)] - else: - global_sharded_avals = sharded_avals # type: ignore with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): # type: ignore with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} " "for pmap in {elapsed_time} sec", event=dispatch.JAXPR_TRACE_EVENT): jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final( - fun, global_sharded_avals, pe.debug_info_final(fun, "pmap")) + fun, sharded_avals, pe.debug_info_final(fun, "pmap")) jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info) jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) @@ -1133,7 +1117,7 @@ def stage_parallel_callable( num_global_shards = replicas.num_global_replicas * parts.num_partitions shards = ShardInfo( - sharded_avals, out_sharded_avals, global_sharded_avals, + sharded_avals, out_sharded_avals, sharded_avals, num_local_shards, num_global_shards) return jaxpr, consts, replicas, parts, shards @@ -1158,7 +1142,6 @@ def lower_parallel_callable( in_axes: Iterable[Optional[int]], out_axes_thunk: Callable[[], Sequence[Optional[int]]], donated_invars: Sequence[bool], - global_arg_shapes: Sequence[Optional[Tuple[int, ...]]], is_explicit_global_axis_size: bool, avals: Sequence[core.AbstractValue], *, @@ -1197,8 +1180,7 @@ def lower_parallel_callable( pci = ParallelCallableInfo( name, backend, axis_name, axis_size, global_axis_size, devices, in_axes, out_axes_thunk, avals) - jaxpr, consts, replicas, parts, shards = stage_parallel_callable( - pci, fun, global_arg_shapes) + jaxpr, consts, replicas, parts, shards = stage_parallel_callable(pci, fun) if logger.isEnabledFor(logging.DEBUG): logger.debug("sharded_avals: %s", shards.sharded_avals) logger.debug("global_sharded_avals: %s", shards.global_sharded_avals) @@ -1976,7 +1958,6 @@ def _pmap_dce_rule(used_outputs, eqn): eqn.params['global_axis_size'], None): new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs) _, donated_invars = partition_list(used_inputs, eqn.params['donated_invars']) - # TODO(yashkatariya,mattjj): Handle global_arg_shapes here too. _, in_axes = partition_list(used_inputs, eqn.params['in_axes']) _, out_axes = partition_list(used_outputs, eqn.params['out_axes']) new_params = dict(eqn.params, call_jaxpr=new_jaxpr, @@ -2095,8 +2076,7 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, pl def _pmap_lowering(ctx, *in_nodes, axis_name, axis_size, global_axis_size, devices, name, call_jaxpr, backend=None, in_axes, out_axes, - donated_invars, global_arg_shapes, - is_explicit_global_axis_size): + donated_invars, is_explicit_global_axis_size): del donated_invars # Unused. xla.check_backend_matches(backend, ctx.module_context.platform) # We in-line here rather than generating a Call HLO as in the xla_call diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index df988a04ba47..59d7b83daefc 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4272,14 +4272,13 @@ def _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs): _identity_fn, None, (), (), sharded_dim, sharded_dim) p = api._prepare_pmap( _identity_fn, sharded_dim, sharded_dim, static_broadcasted_tuple, - donate_tuple, None, None, None, None, args, kwargs) + donate_tuple, None, None, None, args, kwargs) out_flat = pxla.xla_pmap_impl( p.flat_fun, *p.flat_args, backend=None, axis_name=axis_name, axis_size=p.local_axis_size, global_axis_size=p.global_axis_size, devices=p.devices, in_axes=p.in_axes_flat, out_axes_thunk=p.out_axes_thunk, name=p.flat_fun.__name__, donated_invars=p.donated_invars, - global_arg_shapes=p.global_arg_shapes_flat, is_explicit_global_axis_size=p.is_explicit_global_axis_size, ) return tree_util.tree_unflatten(p.out_tree(), out_flat) diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 6128192e6eb7..3314fce7d9ba 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -2999,7 +2999,6 @@ def f(xv): in (c, f, g) } devices=None donated_invars=(False, False, False) - global_arg_shapes=(None,) global_axis_size=None in_axes=(0, 0, 0) name=