diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index 22a996efd64a..bf5ec0dfc8af 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -19,21 +19,15 @@ import dataclasses import functools import hashlib -import itertools import math import os import pathlib -import subprocess -import tempfile import time from typing import Any, Generic, TypeVar import weakref import jax -from jax._src import config -from jax._src import core as jax_core from jax._src.interpreters import mlir -from jax._src.lib import xla_client from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import builtin @@ -42,7 +36,6 @@ from jaxlib.mlir.dialects import llvm from jaxlib.mlir.dialects import memref from jaxlib.mlir.dialects import nvvm -from jaxlib.mlir.passmanager import PassManager import numpy as np from . import profiler