Skip to content

Commit

Permalink
Merge branch 'new_translators' into new_test_suite
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Sep 26, 2024
2 parents cc10a48 + 18bdee9 commit 3588a77
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class ConvertElementTypeTranslator(mapped_base.MappedOperationTranslatorBase):
will perform the type conversion operation.
Note:
The type to cast to id inferred from the output variable and the `new_dtype`
The type to cast to is inferred from the output variable and the `new_dtype`
parameter of the equation is ignored.
"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def gather_translator( # noqa: PLR0914 [too-many-locals] # Can not reduce any
if eqn.params["mode"] != jax_lax.GatherScatterMode.PROMISE_IN_BOUNDS:
raise NotImplementedError(f"The mode {eqn.params['mode']} is not implemented.")

# This is the size of the slice window that is copied. Its length equal the rank
# This is the size of the slice window that is copied. Its length is the rank
# of the source array, dimensions that are excluded from copying are listed
# in `collapsed_slice_dims`.
slice_sizes: Sequence[int] = eqn.params["slice_sizes"]
Expand Down
4 changes: 1 addition & 3 deletions src/jace/translator/primitive_translators/pjit_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ def pjit_translator(
nested_jaxpr: jax_core.ClosedJaxpr = eqn.params["jaxpr"]
in_shardings = eqn.params["in_shardings"]
out_shardings = eqn.params["out_shardings"]
_ = eqn.params["donated_invars"] # Always ignored
_ = eqn.params["keep_unused"]
_ = eqn.params["inline"]
# "donated_invars", "keep_unused", "inline" parameters are just ignored

if not all(in_sharding is jax_sharding.UNSPECIFIED for in_sharding in in_shardings):
raise NotImplementedError("Currently 'pjit' does not support sharding in its input.")
Expand Down

0 comments on commit 3588a77

Please sign in to comment.