From 40b561cc72c8b45ccbf30a3a520efca34c18796f Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 17 Nov 2022 18:43:33 -0800 Subject: [PATCH] Project import generated by Copybara. PiperOrigin-RevId: 489359552 --- .bazelrc | 245 +++++++++++++++++++++++++++ README.md | 16 ++ WORKSPACE | 40 +++++ australis/australis_computation.cc | 2 +- australis/exporter.py | 6 +- australis/tests/donate_arg_pjit.py | 46 ----- australis/tests/higher_arity_pjit.py | 45 ----- australis/tests/multi_return_pjit.py | 47 ----- australis/tests/pjit_test.cc | 6 +- australis/tests/pjit_test_fns.py | 104 ++++++++++++ australis/tests/simple_pjit.py | 50 ------ australis/tests/tuple_pjit.py | 51 ------ example/BUILD | 40 +++++ example/flax_example.cc | 55 ++++++ example/flax_jit.py | 72 ++++++++ 15 files changed, 578 insertions(+), 247 deletions(-) create mode 100644 .bazelrc delete mode 100644 australis/tests/donate_arg_pjit.py delete mode 100644 australis/tests/higher_arity_pjit.py delete mode 100644 australis/tests/multi_return_pjit.py create mode 100644 australis/tests/pjit_test_fns.py delete mode 100644 australis/tests/simple_pjit.py delete mode 100644 australis/tests/tuple_pjit.py create mode 100644 example/BUILD create mode 100644 example/flax_example.cc create mode 100644 example/flax_jit.py diff --git a/.bazelrc b/.bazelrc new file mode 100644 index 0000000..a77e53b --- /dev/null +++ b/.bazelrc @@ -0,0 +1,245 @@ +############################################################################ +# All default build options below. + +# Sets the default Apple platform to macOS. +build --apple_platform_type=macos +build --macos_minimum_os=10.14 + +# Make Bazel print out all options from rc files. +build --announce_rc + +build --define open_source_build=true + +build --spawn_strategy=standalone + +build --enable_platform_specific_config + +build --experimental_cc_shared_library + +# Disable enabled-by-default TensorFlow features that we don't care about. +build --define=no_aws_support=true +build --define=no_gcp_support=true +build --define=no_hdfs_support=true +build --define=no_kafka_support=true +build --define=no_ignite_support=true + +build --define=grpc_no_ares=true + +build --define=tsl_link_protobuf=true + +build -c opt + +build --config=short_logs + +build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir. + +# Later Bazel flag values override earlier values; if CUDA/ROCM/TPU are enabled, +# these values are overridden. +build --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=false +build --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=false +build --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=false + +########################################################################### + +build:posix --copt=-fvisibility=hidden +build:posix --copt=-Wno-sign-compare +build:posix --cxxopt=-std=c++17 +build:posix --host_cxxopt=-std=c++17 + +build:avx_posix --copt=-mavx +build:avx_posix --host_copt=-mavx + +build:avx_windows --copt=/arch=AVX + +build:avx_linux --copt=-mavx +build:avx_linux --host_copt=-mavx + +build:native_arch_posix --copt=-march=native +build:native_arch_posix --host_copt=-march=native + +build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1 + +build:cuda --repo_env TF_NEED_CUDA=1 +# "sm" means we emit only cubin, which is forward compatible within a GPU generation. +# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. 
+build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_52,sm_60,sm_70,compute_80" +build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain +build:cuda --@local_config_cuda//:enable_cuda +build:cuda --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true +build:cuda --define=xla_python_enable_gpu=true + +build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain +build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true +build:rocm --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true +build:rocm --define=xla_python_enable_gpu=true +build:rocm --repo_env TF_NEED_ROCM=1 +build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030" + +build:nonccl --define=no_nccl_support=true + +# Tensorflow uses M_* math constants that only get defined by MSVC headers if +# _USE_MATH_DEFINES is defined. +build:windows --copt=/D_USE_MATH_DEFINES +build:windows --host_copt=/D_USE_MATH_DEFINES +# Make sure to include as little of windows.h as possible +build:windows --copt=-DWIN32_LEAN_AND_MEAN +build:windows --host_copt=-DWIN32_LEAN_AND_MEAN +build:windows --copt=-DNOGDI +build:windows --host_copt=-DNOGDI +# https://devblogs.microsoft.com/cppblog/announcing-full-support-for-a-c-c-conformant-preprocessor-in-msvc/ +# otherwise, there will be some compiling error due to preprocessing. +build:windows --copt=/Zc:preprocessor +build:windows --cxxopt=/std:c++17 +build:windows --host_cxxopt=/std:c++17 +# Generate PDB files, to generate useful PDBs, in opt compilation_mode +# --copt /Z7 is needed. +build:windows --linkopt=/DEBUG +build:windows --host_linkopt=/DEBUG +build:windows --linkopt=/OPT:REF +build:windows --host_linkopt=/OPT:REF +build:windows --linkopt=/OPT:ICF +build:windows --host_linkopt=/OPT:ICF +build:windows --incompatible_strict_action_env=true + +build:linux --config=posix +build:linux --copt=-Wno-unknown-warning-option +# Workaround for gcc 10+ warnings related to upb. +# See https://github.com/tensorflow/tensorflow/issues/39467 +build:linux --copt=-Wno-stringop-truncation +build:linux --copt=-Wno-array-parameter + +build:macos --config=posix + +# Suppress all warning messages. +build:short_logs --output_filter=DONT_MATCH_ANYTHING + +build:tpu --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=true +build:tpu --define=with_tpu_support=true + +build:plugin_device --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=true + +######################################################################### +# RBE config options below. 
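+#
+# A typical invocation picks one of the leaf configs defined below, e.g.
+# (assuming access to the remote build project and valid application-default
+# credentials for --google_default_credentials):
+#
+#   bazel build --config=rbe_cpu_linux_py310 //...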
+# Flag to enable remote config +common --experimental_repo_remote_exec + +build:rbe --repo_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1 +build:rbe --google_default_credentials +build:rbe --bes_backend=buildeventservice.googleapis.com +build:rbe --bes_results_url="https://source.cloud.google.com/results/invocations" +build:rbe --bes_timeout=600s +build:rbe --define=EXECUTOR=remote +build:rbe --distinct_host_configuration=false +build:rbe --flaky_test_attempts=3 +build:rbe --jobs=200 +build:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com +build:rbe --remote_timeout=3600 +build:rbe --spawn_strategy=remote,worker,standalone,local +test:rbe --test_env=USER=anon +# Attempt to minimize the amount of data transfer between bazel and the remote +# workers: +build:rbe --remote_download_toplevel + +build:rbe_linux --config=rbe +build:rbe_linux --action_env=PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin" +build:rbe_linux --host_javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8 +build:rbe_linux --javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8 +build:rbe_linux --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8 +build:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8 + +# Non-rbe settings we should include because we do not run configure +build:rbe_linux --config=avx_linux +build:rbe_linux --linkopt=-lrt +build:rbe_linux --host_linkopt=-lrt +build:rbe_linux --linkopt=-lm +build:rbe_linux --host_linkopt=-lm + +# Use the GPU toolchain until the CPU one is ready. +# https://github.com/bazelbuild/bazel/issues/13623 +build:rbe_cpu_linux_base --config=rbe_linux +build:rbe_cpu_linux_base --host_crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain" +build:rbe_cpu_linux_base --crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain" +build:rbe_cpu_linux_base --extra_toolchains="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform" +build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform" +build:rbe_cpu_linux_base --platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform" + +build:rbe_cpu_linux_py37 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.7" +build:rbe_cpu_linux_py37 --python_path="/usr/local/bin/python3.7" +build:rbe_cpu_linux_py38 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.8" +build:rbe_cpu_linux_py38 --python_path="/usr/local/bin/python3.8" +build:rbe_cpu_linux_py39 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.9" +build:rbe_cpu_linux_py39 --python_path="/usr/local/bin/python3.9" +build:rbe_cpu_linux_py310 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.10" +build:rbe_cpu_linux_py310 --python_path="/usr/local/bin/python3.10" +build:rbe_cpu_linux_py311 --config=rbe_cpu_linux_base 
--repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.11" +build:rbe_cpu_linux_py311 --python_path="/usr/local/bin/python3.11" + +build:rbe_linux_cuda_base --config=rbe_linux +build:rbe_linux_cuda_base --config=cuda +build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1 + +build:rbe_linux_cuda11.1_nvcc_base --config=rbe_linux_cuda_base +build:rbe_linux_cuda11.1_nvcc_base --action_env=TF_CUDA_VERSION=11 +build:rbe_linux_cuda11.1_nvcc_base --action_env=TF_CUDNN_VERSION=8 +build:rbe_linux_cuda11.1_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.1" +build:rbe_linux_cuda11.1_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib" +build:rbe_linux_cuda11.1_nvcc_base --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" +test:rbe_linux_cuda11.1_nvcc_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64" +build:rbe_linux_cuda11.1_nvcc_base --host_crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda11.1_nvcc_base --crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda11.1_nvcc_base --extra_toolchains="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cuda11.1_nvcc_base --extra_execution_platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_platform//:platform" +build:rbe_linux_cuda11.1_nvcc_base --host_platform="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_platform//:platform" +build:rbe_linux_cuda11.1_nvcc_base --platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_platform//:platform" +build:rbe_linux_cuda11.1_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda" +build:rbe_linux_cuda11.1_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_tensorrt" +build:rbe_linux_cuda11.1_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_nccl" +build:rbe_linux_cuda11.1_nvcc_py3.7 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.7" +build:rbe_linux_cuda11.1_nvcc_py3.7 --python_path="/usr/local/bin/python3.7" +build:rbe_linux_cuda11.1_nvcc_py3.8 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.8" +build:rbe_linux_cuda11.1_nvcc_py3.8 --python_path="/usr/local/bin/python3.8" +build:rbe_linux_cuda11.1_nvcc_py3.9 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.9" +build:rbe_linux_cuda11.1_nvcc_py3.9 --python_path="/usr/local/bin/python3.9" +build:rbe_linux_cuda11.1_nvcc_py3.10 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.10" +build:rbe_linux_cuda11.1_nvcc_py3.10 --python_path="/usr/local/bin/python3.10" +build:rbe_linux_cuda11.1_nvcc_py3.11 --config=rbe_linux_cuda11.1_nvcc_base 
--repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.11" +build:rbe_linux_cuda11.1_nvcc_py3.11 --python_path="/usr/local/bin/python3.11" + +build:rbe_linux_cuda11.4_nvcc_base --config=rbe_linux_cuda_base +build:rbe_linux_cuda11.4_nvcc_base --action_env=TF_CUDA_VERSION=11 +build:rbe_linux_cuda11.4_nvcc_base --action_env=TF_CUDNN_VERSION=8 +build:rbe_linux_cuda11.4_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.4" +build:rbe_linux_cuda11.4_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib" +build:rbe_linux_cuda11.4_nvcc_base --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" +build:rbe_linux_cuda11.4_nvcc_base --host_crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda11.4_nvcc_base --crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda11.4_nvcc_base --extra_toolchains="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cuda11.4_nvcc_base --extra_execution_platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform" +build:rbe_linux_cuda11.4_nvcc_base --host_platform="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform" +build:rbe_linux_cuda11.4_nvcc_base --platforms="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform" +build:rbe_linux_cuda11.4_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda" +build:rbe_linux_cuda11.4_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_tensorrt" +build:rbe_linux_cuda11.4_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_nccl" +build:rbe_linux_cuda11.4_nvcc_py3.7 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.7" +build:rbe_linux_cuda11.4_nvcc_py3.7 --python_path="/usr/local/bin/python3.7" +build:rbe_linux_cuda11.4_nvcc_py3.8 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.8" +build:rbe_linux_cuda11.4_nvcc_py3.8 --python_path="/usr/local/bin/python3.8" +build:rbe_linux_cuda11.4_nvcc_py3.9 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.9" +build:rbe_linux_cuda11.4_nvcc_py3.9 --python_path="/usr/local/bin/python3.9" +build:rbe_linux_cuda11.4_nvcc_py3.10 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.10" +build:rbe_linux_cuda11.4_nvcc_py3.10 --python_path="/usr/local/bin/python3.10" +build:rbe_linux_cuda11.4_nvcc_py3.11 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.11" +build:rbe_linux_cuda11.4_nvcc_py3.11 --python_path="/usr/local/bin/python3.11" + +# These you may need to change for your own GCP project. 
+build:tensorflow_testing_rbe --project_id=tensorflow-testing
+common:tensorflow_testing_rbe_linux --remote_instance_name=projects/tensorflow-testing/instances/default_instance
+build:tensorflow_testing_rbe_linux --config=tensorflow_testing_rbe
+#############################################################################
+
+# Load `.jax_configure.bazelrc` file written by build.py
+try-import %workspace%/.jax_configure.bazelrc
+
+# Load rc file with user-specific options.
+try-import %workspace%/.bazelrc.user
diff --git a/README.md b/README.md
index 7a55d9d..bd06349 100644
--- a/README.md
+++ b/README.md
@@ -2,3 +2,19 @@
 Australis is a library that abstracts away the PJRT runtime API.
 
 We want to provide a nicer, C++-friendly API in order to decouple users.
+
+To build the example:
+
+```bash
+# Because australis essentially links in jaxlib in order to access the XLA compiler,
+# follow the latest prerequisite instructions for building jaxlib: https://jax.readthedocs.io/en/latest/developer.html.
+# Also install jax with the proper jaxlib for your target architecture (precompiled is fine).
+git clone https://github.com/jax-ml/australis.git
+mkdir build
+cd australis
+wget https://github.com/bazelbuild/bazel/releases/download/5.1.1/bazel-5.1.1-linux-x86_64
+chmod +x bazel-5.1.1-linux-x86_64
+sudo apt-get install protobuf-compiler
+pip install . flax
+./bazel-5.1.1-linux-x86_64 build --config=cuda --check_visibility=false //example:flax_example
+```
diff --git a/WORKSPACE b/WORKSPACE
index e69de29..08e8ab2 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -0,0 +1,40 @@
+workspace(name = "australis")
+
+load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
+load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
+
+git_repository(
+    name = "jax",
+    commit = "0324cac8882e3ea1b2148818ee2322e2b96696da",
+    remote = "https://github.com/google/jax.git",
+)
+
+# To update TensorFlow to a new revision,
+# a) update URL and strip_prefix to the new git commit hash
+# b) get the sha256 hash of the commit by running:
+#    curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum
+#    and update the sha256 with the result.
+http_archive(
+    name = "org_tensorflow",
+    sha256 = "bfd40279b247d2d0b0dc5c5a776b595c9d4979889dcf0529c85fe9f6ff7a5255",
+    strip_prefix = "tensorflow-c21f137bc42450f10f7d04f9d263852827afd079",
+    urls = [
+        "https://github.com/tensorflow/tensorflow/archive/c21f137bc42450f10f7d04f9d263852827afd079.tar.gz",
+    ],
+)
+
+# For development, one can use a local TF repository instead.
+# local_repository(
+#    name = "org_tensorflow",
+#    path = "tensorflow",
+# )
+
+# Initialize TensorFlow's external dependencies.
+load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
+tf_workspace3()
+load("@org_tensorflow//tensorflow:workspace2.bzl", "tf_workspace2")
+tf_workspace2()
+load("@org_tensorflow//tensorflow:workspace1.bzl", "tf_workspace1")
+tf_workspace1()
+load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0")
+tf_workspace0()
diff --git a/australis/australis_computation.cc b/australis/australis_computation.cc
index d9e7ded..a0cde7e 100644
--- a/australis/australis_computation.cc
+++ b/australis/australis_computation.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include -#include "google/protobuf/io/coded_stream.h" #include "australis/executable.pb.h" #include "australis/petri.pb.h" +#include "google/protobuf/io/coded_stream.h" namespace aux::internal { diff --git a/australis/exporter.py b/australis/exporter.py index 5f15922..9c5a0da 100644 --- a/australis/exporter.py +++ b/australis/exporter.py @@ -40,7 +40,7 @@ def square(x): import jax.experimental.maps from jax.interpreters import pxla from jax.stages import Lowered -from google3.third_party.australis import executable_pb2 +from australis import executable_pb2 import numpy as np import sys @@ -78,9 +78,11 @@ class FakeClient: def __init__(self, n, platform): self.platform = platform - self.process_index = 0 self.devices = [FakeDevice(self, i, platform) for i in range(n)] + def process_index(self): + return 0 + def device_count(self): return len(self.devices) diff --git a/australis/tests/donate_arg_pjit.py b/australis/tests/donate_arg_pjit.py deleted file mode 100644 index 902d97f..0000000 --- a/australis/tests/donate_arg_pjit.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Stage out a jitted function with multiple arguments and argument donation.""" - -import numpy as np - -import jax -import jax.numpy as jnp -import jax.experimental.pjit -from jax.interpreters.pxla import PartitionSpec as P -from jax.stages import Lowered -from jax.experimental.australis import exporter - - -def f(x, y): - return (x[0] + x[1] * y, x[0]) # Avoid untupling. - - -def lower() -> Lowered: - device_mesh = np.array(exporter.fake_devices(2, 'tpu')) - mesh_axis_names = ('x',) - fn = jax.experimental.pjit.pjit( - f, - in_axis_resources=((P('x'), P('x')), None), - out_axis_resources=P('x'), - donate_argnums=0) - with jax.experimental.maps.Mesh(device_mesh, mesh_axis_names): - return fn.lower((jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32')), - jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32'))), - jax.ShapeDtypeStruct((), jnp.dtype('float32'))) - - -if __name__ == '__main__': - exporter.run(lower) diff --git a/australis/tests/higher_arity_pjit.py b/australis/tests/higher_arity_pjit.py deleted file mode 100644 index ffc82cc..0000000 --- a/australis/tests/higher_arity_pjit.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Stage out a jitted function with multiple arguments. 
-""" - -import numpy as np - -import jax -import jax.numpy as jnp -import jax.experimental.pjit -from jax.interpreters.pxla import PartitionSpec as P -from jax.stages import Lowered -from jax.experimental.australis import exporter - - -def f(x, y): - return x[0] + x[1] * y - - -def lower() -> Lowered: - device_mesh = np.array(exporter.fake_devices(2, 'tpu')) - mesh_axis_names = ('x',) - fn = jax.experimental.pjit.pjit( - f, in_axis_resources=((P('x'), P('x')), None), out_axis_resources=P('x')) - with jax.experimental.maps.Mesh(device_mesh, mesh_axis_names): - return fn.lower( - (jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32')), - jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32'))), - jax.ShapeDtypeStruct((), jnp.dtype('float32'))) - - -if __name__ == '__main__': - exporter.run(lower) diff --git a/australis/tests/multi_return_pjit.py b/australis/tests/multi_return_pjit.py deleted file mode 100644 index a6a3e56..0000000 --- a/australis/tests/multi_return_pjit.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Stage out a pjitted function with multiple arguments. -""" - -import numpy as np - -import jax -import jax.numpy as jnp -import jax.experimental.pjit -from jax.interpreters.pxla import PartitionSpec as P -from jax.stages import Lowered -from jax.experimental.australis import exporter - - -def f(x, y): - return (x[0] + x[1] * y, x[0], x[1]), y + 1, x - - -def lower() -> Lowered: - device_mesh = np.array(exporter.fake_devices(2, 'tpu')) - mesh_axis_names = ('x',) - fn = jax.experimental.pjit.pjit( - f, - in_axis_resources=((P('x'), P('x')), None), - out_axis_resources=(P('x'), None, P('x'))) - with jax.experimental.maps.Mesh(device_mesh, mesh_axis_names): - return fn.lower( - (jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32')), - jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32'))), - jax.ShapeDtypeStruct((), jnp.dtype('float32'))) - - -if __name__ == '__main__': - exporter.run(lower) diff --git a/australis/tests/pjit_test.cc b/australis/tests/pjit_test.cc index 7ed1412..63697ef 100644 --- a/australis/tests/pjit_test.cc +++ b/australis/tests/pjit_test.cc @@ -22,11 +22,7 @@ limitations under the License. #include "absl/status/status.h" #include "australis/australis.h" #include "australis/petri.h" -#include "australis/tests/donate_arg_pjit.h" -#include "australis/tests/higher_arity_pjit.h" -#include "australis/tests/multi_return_pjit.h" -#include "australis/tests/simple_pjit.h" -#include "australis/tests/tuple_pjit.h" +#include "australis/tests/pjit_test_fns.h" namespace { diff --git a/australis/tests/pjit_test_fns.py b/australis/tests/pjit_test_fns.py new file mode 100644 index 0000000..ce6e45f --- /dev/null +++ b/australis/tests/pjit_test_fns.py @@ -0,0 +1,104 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The user configuration used to stage out a user's model.
+
+Note: this might in the future be automatically code-generated from bzl rules.
+"""
+
+import functools
+import numpy as np
+import sys
+
+import jax
+import jax.experimental.pjit
+from jax.interpreters.pxla import PartitionSpec as P
+import jax.numpy as jnp
+
+from jax.stages import Lowered
+
+from jax.experimental.australis import exporter
+
+
+@functools.partial(
+    jax.experimental.pjit.pjit,
+    in_axis_resources=((P('x'), P('x')), None),
+    out_axis_resources=P('x'),
+    donate_argnums=0)
+def donate_arg_fn(x, y):
+  return (x[0] + x[1] * y, x[0])  # Avoid untupling.
+
+
+@functools.partial(
+    jax.experimental.pjit.pjit,
+    in_axis_resources=P('x'),
+    out_axis_resources=P('x'))
+def simple_fn(x):
+  return 2 * x
+
+
+@functools.partial(
+    jax.experimental.pjit.pjit,
+    in_axis_resources=((P('x'), P('x')),),
+    out_axis_resources=P('x'))
+def tuple_fn(x):
+  return x[0] + x[1]
+
+
+@functools.partial(
+    jax.experimental.pjit.pjit,
+    in_axis_resources=((P('x'), P('x')), None),
+    out_axis_resources=(P('x'), None, P('x')))
+def multi_return_fn(x, y):
+  return (x[0] + x[1] * y, x[0], x[1]), y + 1, x
+
+
+@functools.partial(
+    jax.experimental.pjit.pjit,
+    in_axis_resources=((P('x'), P('x')), None),
+    out_axis_resources=P('x'))
+def higher_arity_fn(x, y):
+  return x[0] + x[1] * y
+
+
+def lower():  # Returns a list of (name, Lowered) pairs consumed by exporter.run.
+  device_mesh = np.array(exporter.fake_devices(2, 'tpu'))
+  mesh_axis_names = ('x',)
+  print('device_mesh ', device_mesh, file=sys.stderr)
+  with jax.experimental.maps.Mesh(device_mesh, mesh_axis_names):
+    return [
+        ('simple_pjit',
+         simple_fn.lower(jax.ShapeDtypeStruct((2, 2), jnp.dtype('float32')))),
+        ('tuple_pjit',
+         tuple_fn.lower((jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32')),
+                         jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32'))))),
+        ('multi_return_pjit',
+         multi_return_fn.lower(
+             (jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32')),
+              jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32'))),
+             jax.ShapeDtypeStruct((), jnp.dtype('float32')))),
+        ('higher_arity_pjit',
+         higher_arity_fn.lower(
+             (jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32')),
+              jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32'))),
+             jax.ShapeDtypeStruct((), jnp.dtype('float32')))),
+        ('donate_arg_pjit',
+         donate_arg_fn.lower(
+             (jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32')),
+              jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32'))),
+             jax.ShapeDtypeStruct((), jnp.dtype('float32')))),
+    ]
+
+
+if __name__ == '__main__':
+  exporter.run(lower)
diff --git a/australis/tests/simple_pjit.py b/australis/tests/simple_pjit.py
deleted file mode 100644
index fcd6f83..0000000
--- a/australis/tests/simple_pjit.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2022 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The user configuration used to stage out a user's model. - -Note: this might in the future be automatically code-generated from bzl rules. -""" - -import numpy as np -import sys - -import jax -import jax.experimental.pjit -from jax.interpreters.pxla import PartitionSpec as P -import jax.numpy as jnp - -from jax.stages import Lowered - -from jax.experimental.australis import exporter - - -def f(x): - return 2 * x - - -def lower() -> Lowered: - device_mesh = np.array(exporter.fake_devices(2, 'tpu')) - mesh_axis_names = ('x',) - print('device_mesh ', device_mesh, file=sys.stderr) - fn = jax.experimental.pjit.pjit( - f, in_axis_resources=P('x'), out_axis_resources=P('x')) - with jax.experimental.maps.Mesh(device_mesh, mesh_axis_names): - staged: Lowered = fn.lower( - jax.ShapeDtypeStruct((2, 2), jnp.dtype('float32'))) - return staged - - -if __name__ == '__main__': - exporter.run(lower) diff --git a/australis/tests/tuple_pjit.py b/australis/tests/tuple_pjit.py deleted file mode 100644 index e125e33..0000000 --- a/australis/tests/tuple_pjit.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The user configuration used to stage out a user's model. - -Note: this might in the future be automatically code-generated from bzl rules. -""" - -import numpy as np -import sys - -import jax -import jax.experimental.pjit -from jax.interpreters.pxla import PartitionSpec as P -import jax.numpy as jnp - -from jax.stages import Lowered - -from jax.experimental.australis import exporter - - -def f(x): - return x[0] + x[1] - - -def lower() -> Lowered: - device_mesh = np.array(exporter.fake_devices(2, 'tpu')) - mesh_axis_names = ('x',) - print('device_mesh ', device_mesh, file=sys.stderr) - fn = jax.experimental.pjit.pjit( - f, in_axis_resources=((P('x'), P('x')),), out_axis_resources=P('x')) - with jax.experimental.maps.Mesh(device_mesh, mesh_axis_names): - staged: Lowered = fn.lower( - (jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32')), - jax.ShapeDtypeStruct((2, 3), jnp.dtype('float32')))) - return staged - - -if __name__ == '__main__': - exporter.run(lower) diff --git a/example/BUILD b/example/BUILD new file mode 100644 index 0000000..81ad454 --- /dev/null +++ b/example/BUILD @@ -0,0 +1,40 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Tests for australis.
+
+load("@australis//:australis.bzl", "australis")
+
+licenses(["notice"])
+
+australis(
+    name = "flax_jit",
+    cc_namespace = "australis::test",
+    py_deps = [],  # Internal flax deps
+)
+
+cc_test(
+    name = "flax_example",
+    srcs = ["flax_example.cc"],
+    linkstatic = 1,
+    deps = [
+        ":flax_jit_cc",
+        "@australis//australis",
+        "@australis//australis:cpu_support",
+        "@australis//australis:gpu_support",
+        "@australis//australis:petri",
+        "@com_google_absl//absl/status",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
diff --git a/example/flax_example.cc b/example/flax_example.cc
new file mode 100644
index 0000000..d7c624c
--- /dev/null
+++ b/example/flax_example.cc
@@ -0,0 +1,59 @@
+/* Copyright 2022 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <tuple>
+#include <utility>
+
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
+#include "australis/australis.h"
+#include "example/flax_jit.h"
+#include "australis/petri.h"
+
+namespace aux {
+namespace {
+
+absl::StatusOr<std::tuple<PTree, PTree>> Unpack2Tuple(
+    absl::StatusOr<PTree> input) {
+  auto tmp = *PTree::DestructureTuple(std::move(input));
+  if (tmp.size() != 2) {
+    return absl::InvalidArgumentError(absl::StrCat("Wrong size: ", tmp.size()));
+  }
+  return std::tuple(std::move(tmp[0]), std::move(tmp[1]));
+}
+
+TEST(Flax, OptimizerStep) {
+  auto client = *Client::GetDefault();
+  auto dev = client.LocalDevices()[0];
+  auto init_fn = *australis::test::FlaxInit::Load(client);
+  auto optimizer_step_fn = *australis::test::FlaxOptimizerStep::Load(client);
+
+  auto [params, opt_state] = *Unpack2Tuple(init_fn());
+
+  const float inputs[] = {0, 1, 2, 3, 4, 5, 6, 7};
+  auto x = *PTree::BufferRN(inputs, {2, 4}, dev);
+  std::tie(params, opt_state) =
+      *Unpack2Tuple(optimizer_step_fn(params, opt_state, x));
+
+  // TODO(parkers): Properly encode the flax weights structure (or an
+  // approximation of it) into the new executable proto type.
+  EXPECT_EQ(
+      "((((Buffer(f32[16]), Buffer(f32[4,16])), (Buffer(f32[16]), "
+      "Buffer(f32[16,16])), (Buffer(f32[16]), Buffer(f32[16,16])))))",
+      params.ToString());
+}
+
+}  // namespace
+}  // namespace aux
diff --git a/example/flax_jit.py b/example/flax_jit.py
new file mode 100644
index 0000000..fc5b2bf
--- /dev/null
+++ b/example/flax_jit.py
@@ -0,0 +1,72 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Stage out a flax model."""
+
+from typing import List, Sequence, Tuple
+
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+import optax
+
+from jax.stages import Lowered
+from australis import exporter
+
+
+class Sequential(nn.Module):
+  layers: Sequence[nn.Module]
+
+  @nn.compact
+  def __call__(self, x):
+    for layer in self.layers:
+      x = layer(x)
+    return x
+
+
+def lower() -> List[Tuple[str, Lowered]]:
+  model = Sequential(layers=[nn.Dense(16), nn.Dense(16), nn.Dense(16)])
+
+  tx = optax.adam(0.1)
+
+  @jax.jit
+  def init():
+    init_rng = jax.random.PRNGKey(0)
+    params = model.init(init_rng, jnp.ones((2, 4), dtype=jnp.float32))
+    return params, tx.init(params)
+
+  init_fn = init.lower()
+
+  @jax.jit
+  def optimizer_step(params, opt_state, x):
+
+    def fwd(params):
+      return model.apply(params, x).sum()
+
+    grads = jax.grad(fwd)(params)
+    updates, opt_state = tx.update(grads, opt_state)
+    params = optax.apply_updates(params, updates)
+    return params, opt_state
+
+  params, opt_state = jax.eval_shape(init)
+  optimizer_step_lowered = optimizer_step.lower(
+      params, opt_state, jax.ShapeDtypeStruct((2, 4), jnp.float32))
+
+  return [
+      ("flax_init", init_fn),
+      ("flax_optimizer_step", optimizer_step_lowered),
+  ]
+
+
+if __name__ == "__main__":
+  exporter.run(lower)
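
Usage note: new stage-out modules follow the same pattern as example/flax_jit.py
and australis/tests/pjit_test_fns.py: define the jitted functions, return a list
of (name, Lowered) pairs from lower(), and hand lower to exporter.run, which the
australis() bzl macro invokes at build time to generate the C++ bindings. A
minimal sketch under those assumptions (the module, function, and export names
here are illustrative, not part of this patch):

    # my_jit.py -- hypothetical stage-out module for an australis(name = "my_jit", ...) target.
    import jax
    import jax.numpy as jnp

    from jax.stages import Lowered
    from australis import exporter


    @jax.jit
    def double(x):
      return 2 * x


    def lower():
      # Lower `double` for a single f32[4] argument; exporter.run consumes the
      # resulting (name, Lowered) pairs and serializes them for C++ consumers.
      return [
          ('double', double.lower(jax.ShapeDtypeStruct((4,), jnp.dtype('float32')))),
      ]


    if __name__ == '__main__':
      exporter.run(lower)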