diff --git a/envpool/python/xla_template.py b/envpool/python/xla_template.py index 4238f945..2fd802b8 100644 --- a/envpool/python/xla_template.py +++ b/envpool/python/xla_template.py @@ -21,7 +21,7 @@ from jax import core, dtypes from jax import numpy as jnp from jax.core import ShapedArray -from jax.interpreters import xla +from jax.interpreters import mlir, xla from jax.lib import xla_client @@ -91,12 +91,7 @@ def translation(c: Any, *args: Any, platform: str = "cpu") -> Any: prim.multiple_results = (len(out_specs) > 1) prim.def_impl(partial(xla.apply_primitive, prim)) prim.def_abstract_eval(abstract) - xla.backend_specific_translations["cpu"][prim] = partial( - translation, platform="cpu" - ) - xla.backend_specific_translations["gpu"][prim] = partial( - translation, platform="gpu" - ) + mlir.register_lowering(prim, translation) def call(*args: Any) -> Any: return prim.bind(*args)