diff --git a/benchmarks/mosaic/BUILD b/benchmarks/mosaic/BUILD index 027da12ce6d3..72aae09af4a2 100644 --- a/benchmarks/mosaic/BUILD +++ b/benchmarks/mosaic/BUILD @@ -49,8 +49,8 @@ jax_test( disable_configs = DISABLED_CONFIGS, tags = ["notap"], deps = [ + "//jax:mosaic_gpu", + "//jax/experimental/mosaic/gpu/examples:matmul", "//third_party/py/google_benchmark", - "//third_party/py/jax:mosaic_gpu", - "//third_party/py/jax/experimental/mosaic/gpu/examples:matmul", ] + py_deps("absl/testing") + py_deps("numpy"), ) diff --git a/docs/cuda_custom_call/BUILD b/docs/cuda_custom_call/BUILD index 93715bdac171..0591eed1fbec 100644 --- a/docs/cuda_custom_call/BUILD +++ b/docs/cuda_custom_call/BUILD @@ -56,8 +56,8 @@ cuda_library( name = "foo_", srcs = ["foo.cu.cc"], deps = [ + "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", - "@local_config_cuda//cuda:cuda_headers", ], ) diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index fccf0cc37048..6e4647b5e491 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_applicable_licenses = ["//third_party/py/jax:license"]) +package(default_applicable_licenses = ["//jax:license"]) licenses(["notice"]) @@ -21,13 +21,13 @@ cc_binary( srcs = ["main.cc"], tags = ["manual"], deps = [ - "//third_party/absl/status:statusor", + "@com_google_absl//absl/status:statusor", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:platform_port", "@xla//xla:literal", "@xla//xla:literal_util", "@xla//xla/pjrt:pjrt_client", "@xla//xla/pjrt/cpu:cpu_client", "@xla//xla/tools:hlo_module_loader", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:platform_port", ], ) diff --git a/jax/BUILD b/jax/BUILD index ec350b4b99a7..574559688c4d 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -76,39 +76,32 @@ package_group( packages = [ # Intentionally avoid jax dependencies on jax.extend. # See https://jax.readthedocs.io/en/latest/jep/15856-jex.html - "//third_party/py/jax/tests/...", + "//tests/...", ] + jax_extend_internal_users, ) package_group( name = "mosaic_users", - packages = [ - "//...", - ] + mosaic_internal_users, + includes = [":internal"], + packages = mosaic_internal_users, ) package_group( name = "pallas_gpu_users", - packages = [ - "//...", - "//learning/brain/research/jax", - ] + pallas_gpu_internal_users, + includes = [":internal"], + packages = pallas_gpu_internal_users, ) package_group( name = "pallas_tpu_users", - packages = [ - "//...", - "//learning/brain/research/jax", - ] + pallas_tpu_internal_users, + includes = [":internal"], + packages = pallas_tpu_internal_users, ) package_group( name = "mosaic_gpu_users", - packages = [ - "//...", - "//learning/brain/research/jax", - ] + mosaic_gpu_internal_users, + includes = [":internal"], + packages = mosaic_gpu_internal_users, ) # JAX-private test utilities. diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 09cc3a81c2c2..7068c0ef6732 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -22,7 +22,7 @@ load( package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) py_library_providing_imports_info( diff --git a/jax/_src/pallas/BUILD b/jax/_src/pallas/BUILD index c0fa02131bc8..4ff7062ac1e8 100644 --- a/jax/_src/pallas/BUILD +++ b/jax/_src/pallas/BUILD @@ -21,7 +21,7 @@ load( package( default_applicable_licenses = [], default_visibility = [ - "//:__subpackages__", + "//jax:internal", ], ) diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index f1616962f349..071f09f3f567 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -20,7 +20,7 @@ load("//jaxlib:jax.bzl", "py_deps") package( default_applicable_licenses = [], default_visibility = [ - "//:__subpackages__", + "//jax:internal", ], ) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 8f351020a86f..9d2dfd8dfa0f 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -24,7 +24,7 @@ load( package( default_applicable_licenses = [], default_visibility = [ - "//:__subpackages__", + "//jax:internal", ], ) diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index 01d2480983d5..c40fb19ec808 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -23,7 +23,7 @@ load( package( default_applicable_licenses = [], default_visibility = [ - "//:__subpackages__", + "//jax:internal", ], ) diff --git a/jax/experimental/jax2tf/g3doc/BUILD b/jax/experimental/jax2tf/g3doc/BUILD index 424d3b8b9e5d..6222b82b3550 100644 --- a/jax/experimental/jax2tf/g3doc/BUILD +++ b/jax/experimental/jax2tf/g3doc/BUILD @@ -15,7 +15,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"], + default_visibility = ["//jax/experimental/jax2tf:__subpackages__"], ) filegroup( diff --git a/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD b/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD index f584ab5d3191..3417c1abf6ac 100644 --- a/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD +++ b/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD @@ -18,7 +18,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"], + default_visibility = ["//jax/experimental/jax2tf:__subpackages__"], ) py_library( diff --git a/jax/experimental/jax2tf/tests/flax_models/BUILD b/jax/experimental/jax2tf/tests/flax_models/BUILD index 19afb4a6877c..d3af9581ae02 100644 --- a/jax/experimental/jax2tf/tests/flax_models/BUILD +++ b/jax/experimental/jax2tf/tests/flax_models/BUILD @@ -19,7 +19,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"], + default_visibility = ["//jax/experimental/jax2tf:__subpackages__"], ) py_library( @@ -27,8 +27,8 @@ py_library( srcs = glob(["*.py"]), srcs_version = "PY3", deps = [ + "//jax", "//third_party/py/flax:core", - "//third_party/py/jax", "//third_party/py/jraph", "//third_party/py/numpy", "//third_party/py/typing_extensions", diff --git a/jax/experimental/mosaic/gpu/examples/BUILD b/jax/experimental/mosaic/gpu/examples/BUILD index 3f9496b38376..6f5af51fbf0f 100644 --- a/jax/experimental/mosaic/gpu/examples/BUILD +++ b/jax/experimental/mosaic/gpu/examples/BUILD @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//jaxlib:jax.bzl", "py_deps") load("@rules_python//python:defs.bzl", "py_library", "py_test") +load("//jaxlib:jax.bzl", "py_deps") licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//third_party/py/jax:mosaic_gpu_users"], + default_visibility = ["//jax:mosaic_gpu_users"], ) exports_files( @@ -27,15 +27,15 @@ exports_files( "flash_attention.py", "matmul.py", ], - visibility = ["//third_party/py/jax:internal"], + visibility = ["//jax:internal"], ) py_library( name = "matmul", srcs = ["matmul.py"], deps = [ - "//third_party/py/jax", - "//third_party/py/jax:mosaic_gpu", + "//jax", + "//jax:mosaic_gpu", ], ) @@ -43,8 +43,8 @@ py_library( name = "flash_attention", srcs = ["flash_attention.py"], deps = [ - "//third_party/py/jax", - "//third_party/py/jax:mosaic_gpu", + "//jax", + "//jax:mosaic_gpu", ], ) @@ -58,8 +58,8 @@ py_test( "requires-gpu-sm90-only", ], deps = [ + "//jax", + "//jax:mosaic_gpu", "//learning/brain/research/jax:gpu_support", - "//third_party/py/jax", - "//third_party/py/jax:mosaic_gpu", ] + py_deps("numpy"), ) diff --git a/jax/tools/build_defs.bzl b/jax/tools/build_defs.bzl index 1540afe42a6a..06f5e69833c5 100644 --- a/jax/tools/build_defs.bzl +++ b/jax/tools/build_defs.bzl @@ -146,9 +146,9 @@ EOF ) if format == "TF": - jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir_with_tensorflow" + jax_to_ir_rule = "//jax/tools:jax_to_ir_with_tensorflow" else: - jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir" + jax_to_ir_rule = "//jax/tools:jax_to_ir" py_binary( name = runner, diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 77b46d6d51aa..ab60b3fadd37 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -14,19 +14,19 @@ # JAX is Autograd and XLA -load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", "py_library_providing_imports_info", "pybind_extension", "pytype_library", ) +load("//jaxlib:symlink_files.bzl", "symlink_files") licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) # This makes xla_extension module accessible from jax._src.lib. @@ -129,13 +129,13 @@ cc_library( hdrs = ["ffi_helpers.h"], features = ["-use_header_modules"], deps = [ - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", ], ) @@ -149,10 +149,10 @@ cc_library( features = ["-use_header_modules"], deps = [ ":kernel_helpers", - "@xla//xla/ffi/api:c_api", - "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/base", "@nanobind", + "@xla//xla/ffi/api:c_api", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -201,10 +201,10 @@ pybind_extension( srcs = ["utils.cc"], module_name = "utils", deps = [ - "@xla//third_party/python_runtime:headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:inlined_vector", "@nanobind", + "@xla//third_party/python_runtime:headers", ], ) @@ -238,6 +238,9 @@ pybind_extension( module_name = "rocm_plugin_extension", deps = [ "//jaxlib:kernel_nanobind_helpers", + "@com_google_absl//absl/status", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", "@xla//third_party/python_runtime:headers", "@xla//xla:status", "@xla//xla:util", @@ -248,9 +251,6 @@ pybind_extension( "@xla//xla/pjrt/c:pjrt_c_api_helpers", "@xla//xla/python:py_client_gpu", "@xla//xla/tsl/python/lib/core:numpy", - "@com_google_absl//absl/status", - "@local_config_rocm//rocm:rocm_headers", - "@nanobind", ], ) diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index 48332ee1a4d2..d3d15c4fc939 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -23,7 +23,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) # LAPACK @@ -36,13 +36,13 @@ cc_library( features = ["-use_header_modules"], deps = [ "//jaxlib:ffi_helpers", - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/service:custom_call_status", ], ) @@ -71,8 +71,8 @@ pybind_extension( deps = [ ":lapack_kernels", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/ffi/api:ffi", "@nanobind", + "@xla//xla/ffi/api:ffi", ], ) diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index bd74be6732fd..a7a47f431a1d 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -26,7 +26,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) cc_library( @@ -37,9 +37,9 @@ cc_library( defines = ["JAX_GPU_CUDA=1"], visibility = ["//visibility:public"], deps = [ - "@xla//xla/tsl/cuda:cupti", "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudnn_header", + "@xla//xla/tsl/cuda:cupti", ], ) @@ -57,9 +57,6 @@ cc_library( features = ["-use_header_modules"], deps = [ ":cuda_vendor", - "@xla//xla/tsl/cuda:cupti", - "@xla//xla/tsl/cuda:cusolver", - "@xla//xla/tsl/cuda:cusparse", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", @@ -69,6 +66,9 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cublas_headers", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cupti", + "@xla//xla/tsl/cuda:cusolver", + "@xla//xla/tsl/cuda:cusparse", ], ) @@ -90,11 +90,11 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", "//jaxlib:handle_pool", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", ], ) @@ -108,9 +108,6 @@ cc_library( ":cuda_make_batch_pointers", ":cuda_vendor", "//jaxlib:kernel_helpers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -122,6 +119,9 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cublas_headers", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", ], ) @@ -145,12 +145,12 @@ pybind_extension( ":cublas_kernels", ":cuda_vendor", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", "@nanobind", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -163,13 +163,13 @@ cc_library( ":cuda_vendor", "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cudnn", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cudnn", ], ) @@ -201,11 +201,11 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", "//jaxlib:handle_pool", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusolver", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusolver", ], ) @@ -218,12 +218,12 @@ cc_library( ":cuda_solver_handle_pool", ":cuda_vendor", "//jaxlib:kernel_helpers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusolver", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusolver", ], ) @@ -238,13 +238,13 @@ cc_library( ":cuda_solver_handle_pool", ":cuda_vendor", "//jaxlib:ffi_helpers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@xla//xla/ffi/api:ffi", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusolver", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", ], ) @@ -272,15 +272,15 @@ pybind_extension( ":cusolver_kernels", ":cusolver_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusolver", - "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", "@nanobind", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusolver", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -293,13 +293,13 @@ cc_library( ":cuda_vendor", "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusparse", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusparse", ], ) @@ -324,9 +324,6 @@ pybind_extension( ":cuda_vendor", ":cusparse_kernels", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusparse", - "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -338,6 +335,9 @@ pybind_extension( "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@nanobind", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusparse", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -354,13 +354,13 @@ cc_library( ":cuda_vendor", "//jaxlib:ffi_helpers", "//jaxlib:kernel_helpers", - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/service:custom_call_status", ], ) @@ -390,10 +390,10 @@ pybind_extension( ":cuda_linalg_kernels", ":cuda_vendor", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/python/lib/core:numpy", "@local_config_cuda//cuda:cuda_headers", "@nanobind", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -409,12 +409,12 @@ cc_library( ":cuda_vendor", "//jaxlib:ffi_helpers", "//jaxlib:kernel_helpers", - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/service:custom_call_status", ], ) @@ -428,9 +428,9 @@ cuda_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", "//jaxlib:kernel_helpers", + "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", - "@local_config_cuda//cuda:cuda_headers", ], ) @@ -447,9 +447,9 @@ pybind_extension( ":cuda_gpu_kernel_helpers", ":cuda_prng_kernels", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/cuda:cudart", "@local_config_cuda//cuda:cuda_headers", "@nanobind", + "@xla//xla/tsl/cuda:cudart", ], ) @@ -483,10 +483,6 @@ cc_library( ":cuda_vendor", ":triton_utils", "//jaxlib/gpu:triton_cc_proto", - "@xla//xla/service:custom_call_status", - "@xla//xla/stream_executor/cuda:cuda_asm_compiler", - "@xla//xla/tsl/cuda:cudart", - "@tsl//tsl/platform:env", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -497,6 +493,10 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform:env", + "@xla//xla/service:custom_call_status", + "@xla//xla/stream_executor/cuda:cuda_asm_compiler", + "@xla//xla/tsl/cuda:cudart", ], ) @@ -556,6 +556,7 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", + "@com_google_absl//absl/base:dynamic_annotations", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cudnn", @@ -563,7 +564,6 @@ cc_library( "@xla//xla/tsl/cuda:cupti", "@xla//xla/tsl/cuda:cusolver", "@xla//xla/tsl/cuda:cusparse", - "@com_google_absl//absl/base:dynamic_annotations", ], ) @@ -594,6 +594,8 @@ pybind_extension( ":versions_helpers", "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", + "@com_google_absl//absl/status:statusor", + "@nanobind", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cudnn", @@ -601,8 +603,6 @@ pybind_extension( "@xla//xla/tsl/cuda:cupti", "@xla//xla/tsl/cuda:cusolver", "@xla//xla/tsl/cuda:cusparse", - "@com_google_absl//absl/status:statusor", - "@nanobind", ], ) diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 706cac6b46d4..f3524ccdf781 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -20,7 +20,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) exports_files(srcs = [ diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 10acec815475..5452520204b8 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_python//python:defs.bzl", "py_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@rules_python//python:defs.bzl", "py_library") licenses(["notice"]) package( default_applicable_licenses = [], default_visibility = [ - "//:__subpackages__", + "//jax:mosaic_users", ], ) @@ -54,6 +54,14 @@ cc_library( # compatible with libtpu deps = [ ":tpu_inc_gen", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ControlFlowDialect", @@ -71,18 +79,10 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorTransforms", + "@tsl//tsl/platform:statusor", "@xla//xla:array", "@xla//xla:shape_util", "@xla//xla:util", - "@tsl//tsl/platform:statusor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", ], ) @@ -192,14 +192,14 @@ cc_library( deps = [ ":tpu_dialect", ":tpu_inc_gen", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@xla//xla:array", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", ], ) diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 20fcf2b4ce74..e5eaeb347137 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -17,7 +17,7 @@ load("//jaxlib:jax.bzl", "pybind_extension") package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:mosaic_gpu_users"], ) py_library( @@ -105,6 +105,12 @@ cc_library( deps = [ ":passes", "//jaxlib/cuda:cuda_vendor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithToLLVM", @@ -142,12 +148,6 @@ cc_library( "@llvm-project//mlir:VectorDialect", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", ], alwayslink = True, ) @@ -168,11 +168,11 @@ pybind_extension( deps = [ "//jaxlib:kernel_nanobind_helpers", "//jaxlib/cuda:cuda_vendor", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cudart", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/synchronization", "@nanobind", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cudart", ], ) @@ -192,7 +192,7 @@ cc_binary( "notap", ], deps = [ - "@xla//xla/tsl/cuda:cudart", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cudart", ], ) diff --git a/jaxlib/mosaic/python/BUILD b/jaxlib/mosaic/python/BUILD index 639e61a89062..48268bfcf30a 100644 --- a/jaxlib/mosaic/python/BUILD +++ b/jaxlib/mosaic/python/BUILD @@ -14,8 +14,8 @@ # Mosaic Python bindings -load("@rules_python//python:defs.bzl", "py_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup") +load("@rules_python//python:defs.bzl", "py_library") gentbl_filegroup( name = "tpu_python_gen_raw", diff --git a/jaxlib/triton/BUILD b/jaxlib/triton/BUILD index 1d994209ffcc..95482e47e864 100644 --- a/jaxlib/triton/BUILD +++ b/jaxlib/triton/BUILD @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//jaxlib:jax.bzl", "if_windows", "pytype_strict_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup") +load("//jaxlib:jax.bzl", "if_windows", "pytype_strict_library") licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) pytype_strict_library( @@ -56,8 +56,8 @@ genrule( out=$(RULEDIR)/$${base//_raw/} echo '# pytype: skip-file' > $${out} && \ cat $${src} | - sed -e 's/^from \\.\\./from jaxlib.mlir\\./g' | - sed -e 's/^from \\./from jaxlib.mlir\\.dialects\\./g' >> $${out} + sed -e 's/^from \\.\\./from jaxlib\\.mlir\\./g' | + sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' >> $${out} done """, ) diff --git a/tests/BUILD b/tests/BUILD index eab1d11287e2..ef5f27f9bccb 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1567,6 +1567,6 @@ filegroup( exclude = [], ) + ["BUILD"], visibility = [ - "//:__subpackages__", + "//jax:internal", ], ) diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index fdb7ad7b0a1f..255b03d3a002 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -68,10 +68,10 @@ jax_test( jax_test( name = "flash_attention", - srcs = ["//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py"], + srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"], disable_backends = DISABLED_BACKENDS, disable_configs = DISABLED_CONFIGS, - main = "//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py", + main = "//jax/experimental/mosaic/gpu/examples:flash_attention.py", tags = ["notap"], deps = [ "//jax:mosaic_gpu",