Skip to content

Commit

Permalink
Fix implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts committed Aug 12, 2024
1 parent 18e6ac2 commit bc5b0c4
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions envpool/python/xla_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from typing import Any, Callable, List, Tuple, Union

import numpy as np
from jax import core, dtypes, interpreters
from jax import core, dtypes
from jax import numpy as jnp
from jax.core import ShapedArray
from jax.interpreters import xla
from jax.interpreters import mlir, xla
from jax.lib import xla_client


Expand Down Expand Up @@ -91,8 +91,7 @@ def translation(c: Any, *args: Any, platform: str = "cpu") -> Any:
prim.multiple_results = (len(out_specs) > 1)
prim.def_impl(partial(xla.apply_primitive, prim))
prim.def_abstract_eval(abstract)
interpreters.mlir["cpu"][prim] = partial(translation, platform="cpu")
interpreters.mlir["gpu"][prim] = partial(translation, platform="gpu")
mlir.register_lowering(prim, translation)

def call(*args: Any) -> Any:
return prim.bind(*args)
Expand Down

0 comments on commit bc5b0c4

Please sign in to comment.