Skip to content

Commit

Permalink
Clean up BUILD files.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 667604964
  • Loading branch information
hawkinsp authored and jax authors committed Aug 26, 2024
1 parent 550607a commit 6d1f51e
Show file tree
Hide file tree
Showing 24 changed files with 134 additions and 141 deletions.
4 changes: 2 additions & 2 deletions benchmarks/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ jax_test(
disable_configs = DISABLED_CONFIGS,
tags = ["notap"],
deps = [
"//jax:mosaic_gpu",
"//jax/experimental/mosaic/gpu/examples:matmul",
"//third_party/py/google_benchmark",
"//third_party/py/jax:mosaic_gpu",
"//third_party/py/jax/experimental/mosaic/gpu/examples:matmul",
] + py_deps("absl/testing") + py_deps("numpy"),
)
2 changes: 1 addition & 1 deletion docs/cuda_custom_call/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ cuda_library(
name = "foo_",
srcs = ["foo.cu.cc"],
deps = [
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@local_config_cuda//cuda:cuda_headers",
],
)
8 changes: 4 additions & 4 deletions examples/jax_cpp/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

package(default_applicable_licenses = ["//third_party/py/jax:license"])
package(default_applicable_licenses = ["//jax:license"])

licenses(["notice"])

Expand All @@ -21,13 +21,13 @@ cc_binary(
srcs = ["main.cc"],
tags = ["manual"],
deps = [
"//third_party/absl/status:statusor",
"@com_google_absl//absl/status:statusor",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:platform_port",
"@xla//xla:literal",
"@xla//xla:literal_util",
"@xla//xla/pjrt:pjrt_client",
"@xla//xla/pjrt/cpu:cpu_client",
"@xla//xla/tools:hlo_module_loader",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:platform_port",
],
)
25 changes: 9 additions & 16 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -76,39 +76,32 @@ package_group(
packages = [
# Intentionally avoid jax dependencies on jax.extend.
# See https://jax.readthedocs.io/en/latest/jep/15856-jex.html
"//third_party/py/jax/tests/...",
"//tests/...",
] + jax_extend_internal_users,
)

package_group(
name = "mosaic_users",
packages = [
"//...",
] + mosaic_internal_users,
includes = [":internal"],
packages = mosaic_internal_users,
)

package_group(
name = "pallas_gpu_users",
packages = [
"//...",
"//learning/brain/research/jax",
] + pallas_gpu_internal_users,
includes = [":internal"],
packages = pallas_gpu_internal_users,
)

package_group(
name = "pallas_tpu_users",
packages = [
"//...",
"//learning/brain/research/jax",
] + pallas_tpu_internal_users,
includes = [":internal"],
packages = pallas_tpu_internal_users,
)

package_group(
name = "mosaic_gpu_users",
packages = [
"//...",
"//learning/brain/research/jax",
] + mosaic_gpu_internal_users,
includes = [":internal"],
packages = mosaic_gpu_internal_users,
)

# JAX-private test utilities.
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/lib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ load(

package(
default_applicable_licenses = [],
default_visibility = ["//:__subpackages__"],
default_visibility = ["//jax:internal"],
)

py_library_providing_imports_info(
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pallas/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ load(
package(
default_applicable_licenses = [],
default_visibility = [
"//:__subpackages__",
"//jax:internal",
],
)

Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ load("//jaxlib:jax.bzl", "py_deps")
package(
default_applicable_licenses = [],
default_visibility = [
"//:__subpackages__",
"//jax:internal",
],
)

Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic_gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ load(
package(
default_applicable_licenses = [],
default_visibility = [
"//:__subpackages__",
"//jax:internal",
],
)

Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pallas/triton/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ load(
package(
default_applicable_licenses = [],
default_visibility = [
"//:__subpackages__",
"//jax:internal",
],
)

Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/g3doc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ licenses(["notice"])

package(
default_applicable_licenses = [],
default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"],
default_visibility = ["//jax/experimental/jax2tf:__subpackages__"],
)

filegroup(
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/back_compat_testdata/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ licenses(["notice"])

package(
default_applicable_licenses = [],
default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"],
default_visibility = ["//jax/experimental/jax2tf:__subpackages__"],
)

py_library(
Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/tests/flax_models/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ licenses(["notice"])

package(
default_applicable_licenses = [],
default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"],
default_visibility = ["//jax/experimental/jax2tf:__subpackages__"],
)

py_library(
name = "flax_models",
srcs = glob(["*.py"]),
srcs_version = "PY3",
deps = [
"//jax",
"//third_party/py/flax:core",
"//third_party/py/jax",
"//third_party/py/jraph",
"//third_party/py/numpy",
"//third_party/py/typing_extensions",
Expand Down
18 changes: 9 additions & 9 deletions jax/experimental/mosaic/gpu/examples/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("//jaxlib:jax.bzl", "py_deps")
load("@rules_python//python:defs.bzl", "py_library", "py_test")
load("//jaxlib:jax.bzl", "py_deps")

licenses(["notice"])

package(
default_applicable_licenses = [],
default_visibility = ["//third_party/py/jax:mosaic_gpu_users"],
default_visibility = ["//jax:mosaic_gpu_users"],
)

exports_files(
srcs = [
"flash_attention.py",
"matmul.py",
],
visibility = ["//third_party/py/jax:internal"],
visibility = ["//jax:internal"],
)

py_library(
name = "matmul",
srcs = ["matmul.py"],
deps = [
"//third_party/py/jax",
"//third_party/py/jax:mosaic_gpu",
"//jax",
"//jax:mosaic_gpu",
],
)

py_library(
name = "flash_attention",
srcs = ["flash_attention.py"],
deps = [
"//third_party/py/jax",
"//third_party/py/jax:mosaic_gpu",
"//jax",
"//jax:mosaic_gpu",
],
)

Expand All @@ -58,8 +58,8 @@ py_test(
"requires-gpu-sm90-only",
],
deps = [
"//jax",
"//jax:mosaic_gpu",
"//learning/brain/research/jax:gpu_support",
"//third_party/py/jax",
"//third_party/py/jax:mosaic_gpu",
] + py_deps("numpy"),
)
4 changes: 2 additions & 2 deletions jax/tools/build_defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ EOF
)

if format == "TF":
jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir_with_tensorflow"
jax_to_ir_rule = "//jax/tools:jax_to_ir_with_tensorflow"
else:
jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir"
jax_to_ir_rule = "//jax/tools:jax_to_ir"

py_binary(
name = runner,
Expand Down
20 changes: 10 additions & 10 deletions jaxlib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@

# JAX is Autograd and XLA

load("//jaxlib:symlink_files.bzl", "symlink_files")
load(
"//jaxlib:jax.bzl",
"py_library_providing_imports_info",
"pybind_extension",
"pytype_library",
)
load("//jaxlib:symlink_files.bzl", "symlink_files")

licenses(["notice"])

package(
default_applicable_licenses = [],
default_visibility = ["//:__subpackages__"],
default_visibility = ["//jax:internal"],
)

# This makes xla_extension module accessible from jax._src.lib.
Expand Down Expand Up @@ -129,13 +129,13 @@ cc_library(
hdrs = ["ffi_helpers.h"],
features = ["-use_header_modules"],
deps = [
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
],
)

Expand All @@ -149,10 +149,10 @@ cc_library(
features = ["-use_header_modules"],
deps = [
":kernel_helpers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/tsl/python/lib/core:numpy",
"@com_google_absl//absl/base",
"@nanobind",
"@xla//xla/ffi/api:c_api",
"@xla//xla/tsl/python/lib/core:numpy",
],
)

Expand Down Expand Up @@ -201,10 +201,10 @@ pybind_extension(
srcs = ["utils.cc"],
module_name = "utils",
deps = [
"@xla//third_party/python_runtime:headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:inlined_vector",
"@nanobind",
"@xla//third_party/python_runtime:headers",
],
)

Expand Down Expand Up @@ -238,6 +238,9 @@ pybind_extension(
module_name = "rocm_plugin_extension",
deps = [
"//jaxlib:kernel_nanobind_helpers",
"@com_google_absl//absl/status",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
"@xla//third_party/python_runtime:headers",
"@xla//xla:status",
"@xla//xla:util",
Expand All @@ -248,9 +251,6 @@ pybind_extension(
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
"@xla//xla/python:py_client_gpu",
"@xla//xla/tsl/python/lib/core:numpy",
"@com_google_absl//absl/status",
"@local_config_rocm//rocm:rocm_headers",
"@nanobind",
],
)

Expand Down
10 changes: 5 additions & 5 deletions jaxlib/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ licenses(["notice"])

package(
default_applicable_licenses = [],
default_visibility = ["//:__subpackages__"],
default_visibility = ["//jax:internal"],
)

# LAPACK
Expand All @@ -36,13 +36,13 @@ cc_library(
features = ["-use_header_modules"],
deps = [
"//jaxlib:ffi_helpers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:dynamic_annotations",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_status",
],
)

Expand Down Expand Up @@ -71,8 +71,8 @@ pybind_extension(
deps = [
":lapack_kernels",
"//jaxlib:kernel_nanobind_helpers",
"@xla//xla/ffi/api:ffi",
"@nanobind",
"@xla//xla/ffi/api:ffi",
],
)

Expand Down
Loading

0 comments on commit 6d1f51e

Please sign in to comment.