Skip to content

Commit

Permalink
[pallas] Added MemoryRef and run_scoped to the API docs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 683349061
  • Loading branch information
superbobry authored and Google-ML-Automation committed Oct 7, 2024
1 parent ce2b497 commit 76d5938
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
4 changes: 4 additions & 0 deletions docs/jax.experimental.pallas.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ Classes
GridSpec
Slice

MemoryRef

Functions
---------

Expand All @@ -36,3 +38,5 @@ Functions
atomic_xchg

debug_print

run_scoped
12 changes: 5 additions & 7 deletions jax/_src/pallas/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,15 +819,13 @@ def debug_print_lowering_rule(ctx, *args, **params):
run_scoped_p.multiple_results = True


def run_scoped(f: Callable[..., Any], *types, **kw_types) -> Any:
"""Call the function with allocated references.
def run_scoped(f: Callable[..., Any], *types: Any, **kw_types: Any) -> Any:
"""Calls the function with allocated references and returns the result.
Args:
f: The function that generates the jaxpr.
*types: The types of the function's positional arguments.
**kw_types: The types of the function's keyword arguments.
The positional and keyword arguments describe which reference types
to allocate for each argument. Each backend has its own set of reference
types in addition to :class:`jax.experimental.pallas.MemoryRef`.
"""

flat_types, in_tree = tree_util.tree_flatten((types, kw_types))
flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree)
avals = [t.get_ref_aval() for t in flat_types]
Expand Down
4 changes: 3 additions & 1 deletion jax/experimental/pallas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
from jax._src.pallas.core import CostEstimate
from jax._src.pallas.core import GridSpec
from jax._src.pallas.core import IndexingMode
from jax._src.pallas.core import MemorySpace
from jax._src.pallas.core import MemoryRef
from jax._src.pallas.core import no_block_spec
from jax._src.pallas.core import Unblocked
from jax._src.pallas.core import unblocked
from jax._src.pallas.core import MemorySpace
from jax._src.pallas.pallas_call import pallas_call
from jax._src.pallas.pallas_call import pallas_call_p
from jax._src.pallas.primitives import atomic_add
Expand Down Expand Up @@ -57,4 +58,5 @@
from jax._src.state.indexing import Slice
from jax._src.state.primitives import broadcast_to


ANY = MemorySpace.ANY

0 comments on commit 76d5938

Please sign in to comment.