Skip to content

Commit

Permalink
[Pallas] Upstream pallas to JAX
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 552963029
  • Loading branch information
sharadmv authored and The jax_triton Authors committed Aug 1, 2023
1 parent b1b11d0 commit d258f8f
Show file tree
Hide file tree
Showing 23 changed files with 32 additions and 7,288 deletions.
8 changes: 4 additions & 4 deletions jax_triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
"""Library for JAX-Triton integrations."""
import jaxlib
from jax._src.lib import gpu_triton
from jax_triton import pallas
from jax_triton import utils
from jax_triton.triton_lib import triton_call
from jax_triton.utils import cdiv
from jax_triton.utils import next_power_of_2
from jax_triton.utils import strides_from_shape
from jax.experimental.pallas import cdiv
from jax.experimental.pallas import next_power_of_2
from jax.experimental.pallas import strides_from_shape
from jax_triton.version import __version__
from jax_triton.version import __version_info__

Expand Down
36 changes: 2 additions & 34 deletions jax_triton/pallas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for pallas, a jaxpr "dialect" for Triton."""
from jax_triton.pallas.core import BlockSpec
from jax_triton.pallas.indexing import ds
from jax_triton.pallas.indexing import dslice
from jax_triton.pallas.indexing import broadcast_to
from jax_triton.pallas.pallas_call import pallas_call
from jax_triton.pallas.pallas_call import pallas_call_p
from jax_triton.pallas.primitives import atomic_add
from jax_triton.pallas.primitives import atomic_and
from jax_triton.pallas.primitives import atomic_cas
from jax_triton.pallas.primitives import atomic_max
from jax_triton.pallas.primitives import atomic_min
from jax_triton.pallas.primitives import atomic_or
from jax_triton.pallas.primitives import atomic_xchg
from jax_triton.pallas.primitives import atomic_xor
from jax_triton.pallas.primitives import dot
from jax_triton.pallas.primitives import load
from jax_triton.pallas.primitives import max_contiguous
from jax_triton.pallas.primitives import multiple_of
from jax_triton.pallas.primitives import program_id
from jax_triton.pallas.primitives import store
from jax_triton.pallas.primitives import swap
from jax_triton.pallas.utils import when
from jax_triton.utils import cdiv

try:
from jax_triton.pallas import triton
except (ImportError, ModuleNotFoundError):
pass

try:
from jax_triton.pallas import mosaic
except (ImportError, ModuleNotFoundError):
pass
"""Points pallas to JAX."""
from jax.experimental.pallas import *
226 changes: 0 additions & 226 deletions jax_triton/pallas/core.py

This file was deleted.

Loading

0 comments on commit d258f8f

Please sign in to comment.