diff --git a/.github/workflows/mlc.yml b/.github/workflows/mlc.yml new file mode 100644 index 0000000000..d6bbfc2b76 --- /dev/null +++ b/.github/workflows/mlc.yml @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# GH actions. +# We use it to cover windows and mac builds +# Jenkins is still the primary CI + +name: CI + +on: + push: + branches: + - mlc + pull_request: + branches: + - mlc + workflow_dispatch: + +concurrency: + group: CI-${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true + +jobs: + MacOS: + if: ${{ github.repository == 'mlc-ai/relax' }} + runs-on: macOS-latest + steps: + - uses: actions/checkout@v2 + with: + submodules: 'recursive' + - name: Set up environment + uses: ./.github/actions/setup + - name: Conda Build + shell: bash -l {0} + run: >- + conda build --output-folder=conda/pkg conda/recipe && + conda install tvm -c ./conda/pkg + - name: Build iOS RPC + run: | + IOS_VERSION="14.0" + CMAKE_FLAGS="-DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_SYSTEM_NAME=iOS \ + -DCMAKE_SYSTEM_VERSION=${IOS_VERSION} \ + -DCMAKE_OSX_SYSROOT=iphonesimulator \ + -DCMAKE_OSX_ARCHITECTURES=x86_64 \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ + -DCMAKE_BUILD_WITH_INSTALL_NAME_DIR=ON \ + -DUSE_IOS_RPC=ON" + + mkdir build-ios-simulator + cd build-ios-simulator + cmake .. ${CMAKE_FLAGS} + cmake --build . 
--target ios_rpc + - name: Test + shell: bash -l {0} + run: >- + python -m pytest -v tests/python/all-platform-minimal-test + - name: Test iOS RPC + shell: bash -l {0} + run: >- + python -m pip install tornado psutil cloudpickle && + export PYTHONPATH=tests/python/contrib:${PYTHONPATH} && + export BUNDLE_ID=org.apache.tvmrpc && + export BUNDLE_PATH=build-ios-simulator/apps/ios_rpc/ios_rpc/src/ios_rpc-build/Release-iphonesimulator/tvmrpc.app && + python -m pytest -v tests/python/contrib/test_rpc_server_device.py + + Windows: + if: ${{ github.repository == 'mlc-ai/relax' }} + runs-on: windows-2019 + steps: + - uses: actions/checkout@v2 + with: + submodules: 'recursive' + - name: Set up environment + uses: ./.github/actions/setup + - name: Conda Build + shell: cmd /C call {0} + run: >- + conda build --output-folder=conda/pkg conda/recipe && + conda install tvm -c ./conda/pkg + - name: Test + shell: cmd /C call {0} + run: >- + python -m pytest -v tests/python/all-platform-minimal-test diff --git a/CMakeLists.txt b/CMakeLists.txt index 47d57d56bd..389879a883 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -843,8 +843,4 @@ if(USE_CUDA AND USE_CUTLASS) install(TARGETS fpA_intB_gemm EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) target_link_libraries(tvm PRIVATE fpA_intB_gemm) target_link_libraries(tvm_runtime PRIVATE fpA_intB_gemm) - - install(TARGETS flash_attn EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) - target_link_libraries(tvm PRIVATE -Wl,--no-as-needed flash_attn) - target_link_libraries(tvm_runtime PRIVATE -Wl,--no-as-needed flash_attn) endif() diff --git a/ci/jenkins/mlc_jenkinsfile.groovy b/ci/jenkins/mlc_jenkinsfile.groovy new file mode 100644 index 0000000000..f7b7ad1bb1 --- /dev/null +++ b/ci/jenkins/mlc_jenkinsfile.groovy @@ -0,0 +1,341 @@ +#!groovy +// -*- mode: groovy -*- + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Jenkins pipeline +// See documents at https://jenkins.io/doc/book/pipeline/jenkinsfile/ + +// ============================= IMPORTANT NOTE ============================= +// To keep things simple +// This file is manually updated to maintain unity branch specific builds. +// Please do not send this file to main + + +import org.jenkinsci.plugins.pipeline.modeldefinition.Utils + +// NOTE: these lines are scanned by docker/dev_common.sh. Please update the regex as needed. 
--> +ci_lint = 'tlcpackstaging/ci_lint:20230504-142417-4d37a0a0' +ci_gpu = 'tlcpackstaging/ci_gpu:20230504-142417-4d37a0a0' +ci_cpu = 'tlcpackstaging/ci_cpu:20230513-200357-e54bbc73' +ci_wasm = 'tlcpack/ci-wasm:v0.72' +ci_i386 = 'tlcpack/ci-i386:v0.75' +ci_qemu = 'tlcpack/ci-qemu:v0.11' +ci_arm = 'tlcpack/ci-arm:v0.08' +ci_hexagon = 'tlcpackstaging/ci_hexagon:20230504-142417-4d37a0a0' +// <--- End of regex-scanned config. + +// Parameters to allow overriding (in Jenkins UI), the images +// to be used by a given build. When provided, they take precedence +// over default values above. +properties([ + parameters([ + string(name: 'ci_lint_param', defaultValue: ''), + string(name: 'ci_cpu_param', defaultValue: ''), + string(name: 'ci_gpu_param', defaultValue: ''), + string(name: 'ci_wasm_param', defaultValue: ''), + string(name: 'ci_i386_param', defaultValue: ''), + string(name: 'ci_qemu_param', defaultValue: ''), + string(name: 'ci_arm_param', defaultValue: ''), + string(name: 'ci_hexagon_param', defaultValue: '') + ]) +]) + +// tvm libraries +tvm_runtime = 'build/libtvm_runtime.so, build/config.cmake' +tvm_lib = 'build/libtvm.so, ' + tvm_runtime +// LLVM upstream lib +tvm_multilib = 'build/libtvm.so, ' + + 'build/libvta_fsim.so, ' + + tvm_runtime + +tvm_multilib_tsim = 'build/libvta_tsim.so, ' + + tvm_multilib + +// command to start a docker container +docker_run = 'docker/bash.sh' +// timeout in minutes +max_time = 240 + +def per_exec_ws(folder) { + return "workspace/exec_${env.EXECUTOR_NUMBER}/" + folder +} + +// initialize source codes +def init_git() { + checkout scm + // Add more info about job node + sh ( + script: "echo NODE_NAME=${env.NODE_NAME}", + label: 'Show executor node info', + ) + retry(5) { + timeout(time: 2, unit: 'MINUTES') { + sh (script: 'git submodule update --init --recursive -f', label: 'Update git submodules') + } + } +} + +def should_skip_slow_tests(pr_number) { + withCredentials([string( + credentialsId: 'tvm-bot-jenkins-reader', + variable: 'GITHUB_TOKEN', + )]) { + // Exit code of 1 means run slow tests, exit code of 0 means skip slow tests + result = sh ( + returnStatus: true, + script: "./tests/scripts/should_run_slow_tests.py --pr '${pr_number}'", + label: 'Check if CI should run slow tests', + ) + } + return result == 0 +} + +def cancel_previous_build() { + // cancel previous build if it is not on main. + if (env.BRANCH_NAME != 'main') { + def buildNumber = env.BUILD_NUMBER as int + // Milestone API allows us to cancel previous build + // with the same milestone number + if (buildNumber > 1) milestone(buildNumber - 1) + milestone(buildNumber) + } +} + +def should_skip_ci(pr_number) { + withCredentials([string( + credentialsId: 'tvm-bot-jenkins-reader', + variable: 'TOKEN', + )]) { + // Exit code of 1 means run full CI (or the script had an error, so run + // full CI just in case). Exit code of 0 means skip CI. 
+ git_skip_ci_code = sh ( + returnStatus: true, + script: "./tests/scripts/git_skip_ci.py --pr '${pr_number}'", + label: 'Check if CI should be skipped', + ) + } + return git_skip_ci_code == 0 +} + +cancel_previous_build() + +def lint() { +stage('Prepare') { + node('CPU-SMALL') { + // When something is provided in ci_*_param, use it, otherwise default with ci_* + ci_lint = params.ci_lint_param ?: ci_lint + ci_cpu = params.ci_cpu_param ?: ci_cpu + ci_gpu = params.ci_gpu_param ?: ci_gpu + ci_wasm = params.ci_wasm_param ?: ci_wasm + ci_i386 = params.ci_i386_param ?: ci_i386 + ci_qemu = params.ci_qemu_param ?: ci_qemu + ci_arm = params.ci_arm_param ?: ci_arm + ci_hexagon = params.ci_hexagon_param ?: ci_hexagon + + sh (script: """ + echo "Docker images being used in this build:" + echo " ci_lint = ${ci_lint}" + echo " ci_cpu = ${ci_cpu}" + echo " ci_gpu = ${ci_gpu}" + echo " ci_wasm = ${ci_wasm}" + echo " ci_i386 = ${ci_i386}" + echo " ci_qemu = ${ci_qemu}" + echo " ci_arm = ${ci_arm}" + echo " ci_hexagon = ${ci_hexagon}" + """, label: 'Docker image names') + } +} + +stage('Sanity Check') { + timeout(time: max_time, unit: 'MINUTES') { + node('CPU') { + ws(per_exec_ws('tvm/sanity')) { + init_git() + is_docs_only_build = sh ( + returnStatus: true, + script: './tests/scripts/git_change_docs.sh', + label: 'Check for docs only changes', + ) + // skip_ci = should_skip_ci(env.CHANGE_ID) + // skip_slow_tests = should_skip_slow_tests(env.CHANGE_ID) + sh ( + script: "${docker_run} ${ci_lint} ./tests/scripts/mlc/task_mlc_lint_cleanup.sh", + label: 'Cleanup before linting', + ) + sh ( + script: "${docker_run} ${ci_lint} ./tests/scripts/task_lint.sh", + label: 'Run lint', + ) + sh ( + script: "${docker_run} ${ci_lint} ./tests/scripts/unity/task_extra_lint.sh", + label: 'Run extra lint', + ) + } + } + } +} +} + +lint() + +// Run make. First try to do an incremental make from a previous workspace in hope to +// accelerate the compilation. If something is wrong, clean the workspace and then +// build from scratch. +def make(docker_type, path, make_flag) { + timeout(time: max_time, unit: 'MINUTES') { + try { + cmake_build(docker_type, path, make_flag) + // always run cpp test when build + // sh "${docker_run} ${docker_type} ./tests/scripts/task_cpp_unittest.sh" + } catch (hudson.AbortException ae) { + // script exited due to user abort, directly throw instead of retry + if (ae.getMessage().contains('script returned exit code 143')) { + throw ae + } + echo 'Incremental compilation failed. Fall back to build from scratch' + sh ( + script: "${docker_run} ${docker_type} ./tests/scripts/task_clean.sh ${path}", + label: 'Clear old cmake workspace', + ) + cmake_build(docker_type, path, make_flag) + cpp_unittest(docker_type) + } + } +} + +// Specifications to Jenkins "stash" command for use with various pack_ and unpack_ functions. +tvm_runtime = 'build/libtvm_runtime.so, build/config.cmake' // use libtvm_runtime.so. +tvm_lib = 'build/libtvm.so, ' + tvm_runtime // use libtvm.so to run the full compiler. 
+// LLVM upstream lib +tvm_multilib = 'build/libtvm.so, ' + + 'build/libvta_fsim.so, ' + + tvm_runtime + +tvm_multilib_tsim = 'build/libvta_tsim.so, ' + + tvm_multilib + +microtvm_tar_gz = 'build/microtvm_template_projects.tar.gz' + +// pack libraries for later use +def pack_lib(name, libs) { + sh (script: """ + echo "Packing ${libs} into ${name}" + echo ${libs} | sed -e 's/,/ /g' | xargs md5sum + """, label: 'Stash libraries and show md5') + stash includes: libs, name: name +} + +// unpack libraries saved before +def unpack_lib(name, libs) { + unstash name + sh (script: """ + echo "Unpacked ${libs} from ${name}" + echo ${libs} | sed -e 's/,/ /g' | xargs md5sum + """, label: 'Unstash libraries and show md5') +} + +// compress microtvm template projects and pack the tar. +def pack_microtvm_template_projects(name) { + sh( + script: 'cd build && tar -czvf microtvm_template_projects.tar.gz microtvm_template_projects/', + label: 'Compress microtvm_template_projects' + ) + pack_lib(name + '-microtvm-libs', microtvm_tar_gz) +} + +def unpack_microtvm_template_projects(name) { + unpack_lib(name + '-microtvm-libs', microtvm_tar_gz) + sh( + script: 'cd build && tar -xzvf microtvm_template_projects.tar.gz', + label: 'Unpack microtvm_template_projects' + ) +} + +def ci_setup(image) { + sh ( + script: "${docker_run} ${image} ./tests/scripts/task_ci_setup.sh", + label: 'Set up CI environment', + ) +} + +def python_unittest(image) { + sh ( + script: "${docker_run} ${image} ./tests/scripts/task_python_unittest.sh", + label: 'Run Python unit tests', + ) +} + +def fsim_test(image) { + sh ( + script: "${docker_run} ${image} ./tests/scripts/task_python_vta_fsim.sh", + label: 'Run VTA tests in FSIM', + ) +} + +def cmake_build(image, path, make_flag) { + sh ( + script: "${docker_run} ${image} ./tests/scripts/mlc/task_mlc_build.sh", + label: 'Run cmake build', + ) +} + +def cpp_unittest(image) { + sh ( + script: "${docker_run} ${image} ./tests/scripts/task_cpp_unittest.sh", + label: 'Build and run C++ tests', + ) +} + +def add_hexagon_permissions() { + sh( + script: 'find build/hexagon_api_output -type f | xargs chmod +x', + label: 'Add execute permissions for hexagon files', + ) +} + +// NOTE: limit tests to relax folder for now to allow us to skip some of the tests +// that are mostly related to changes in main. +// This helps to speedup CI time and reduce CI cost. 
+stage('Build and Test') { + if (is_docs_only_build != 1) { + parallel 'BUILD: GPU': { + node('GPU') { + ws(per_exec_ws('tvm/build-gpu')) { + init_git() + sh "${docker_run} ${ci_gpu} nvidia-smi" + sh "${docker_run} ${ci_gpu} ./tests/scripts/task_config_build_gpu.sh build" + make("${ci_gpu}", 'build', '-j2') + sh "${docker_run} ${ci_gpu} ./tests/scripts/unity/task_python_relax_gpuonly.sh" + } + } + }, + 'BUILD: CPU': { + node('CPU') { + ws(per_exec_ws('tvm/build-cpu')) { + init_git() + sh "${docker_run} ${ci_cpu} ./tests/scripts/task_config_build_cpu.sh build" + make(ci_cpu, 'build', '-j2') + sh "${docker_run} ${ci_cpu} ./tests/scripts/unity/task_python_relax.sh" + } + } + } + } else { + Utils.markStageSkippedForConditional('BUILD: CPU') + } +} diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake index bd3e3b1166..f8aaa2f40d 100644 --- a/cmake/modules/contrib/CUTLASS.cmake +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -21,7 +21,6 @@ if(USE_CUDA AND USE_CUTLASS) set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass) add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm) - add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn) list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc) message(STATUS "Build with CUTLASS") diff --git a/include/tvm/topi/einsum.h b/include/tvm/topi/einsum.h index 5e7813f843..b5756e1200 100644 --- a/include/tvm/topi/einsum.h +++ b/include/tvm/topi/einsum.h @@ -65,13 +65,19 @@ Array InferEinsumShape(const std::string& subscripts, * \param subscripts_str Specifies the subscripts for summation as comma separated list of * subscript labels. * \param inputs Arrays for the operation. + * \param fcompute Specifies the computation expression of the innermost loop. + * \param fcombine Specifies the associative computation involved in constructing + * the commutative reduction. + * \param fidentity Establishes the identity elements for the commutative reduction process. * \param name The name of the operation. * \param tag The tag to mark the operation. * * \return The calculation based on the Einstein summation convention. */ -Tensor einsum(const std::string& subscripts_str, const Array inputs, - std::string name = "T_einsum", std::string tag = kEinsum); +Array einsum(const std::string& subscripts_str, const Array inputs, + PackedFunc fcompute = nullptr, PackedFunc fcombine = nullptr, + PackedFunc fidentity = nullptr, std::string name = "T_einsum", + std::string tag = kEinsum); struct EinsumEquation { /*! @@ -89,6 +95,14 @@ struct EinsumEquation { std::vector inputs; // The output subscript of the Einsum equation. Subscript output; + // The number of outputs. + size_t num_outputs = 0; + + /*! + * \brief Set output subscript of the Einsum equation, and ensure that + * all output subscripts are identical. 
+ */ + void SetOutput(Subscript output_subscript); }; } // namespace topi diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index 9b4fa78127..b6a9517f80 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -159,100 +159,3 @@ def instantiate_attention_template(attrs): ) return substitute_template(template, attrs) - - -def instantiate_flash_attention_template(attrs): - """Return host code for flash attention.""" - - template = """ - int q_head_stride = ${head_dim}; - int k_head_stride = ${head_dim}; - int v_head_stride = ${head_dim}; - int o_head_stride = ${head_dim}; - int q_row_stride = q_head_stride * ${num_heads}; - int k_row_stride = k_head_stride * ${num_heads}; - int v_row_stride = v_head_stride * ${num_heads}; - int o_row_stride = o_head_stride * ${num_heads}; - int q_batch_stride = q_row_stride * ${num_queries}; - int k_batch_stride = k_row_stride * ${num_keys}; - int v_batch_stride = v_row_stride * ${num_keys}; - int o_batch_stride = o_row_stride * ${num_queries}; - - flash_attn::flash_attention_forward( - static_cast(${query}->data), - static_cast(${key}->data), - static_cast(${value}->data), - static_cast(out0->data), - ${num_batches}, - ${num_queries}, - ${num_keys}, - ${num_heads}, - ${num_heads}, - ${head_dim}, - q_batch_stride, - k_batch_stride, - v_batch_stride, - o_batch_stride, - q_head_stride, - k_head_stride, - v_head_stride, - o_head_stride, - q_row_stride, - k_row_stride, - v_row_stride, - o_row_stride, - ${scale}, - ${is_causal}, - nullptr); - """ - - template_stacked = """ - int q_head_stride = ${head_dim}; - int k_head_stride = ${head_dim}; - int v_head_stride = ${head_dim}; - int o_head_stride = ${head_dim}; - int row_stride = q_head_stride * ${num_heads} + - k_head_stride * ${num_heads} + - v_head_stride * ${num_heads}; - int q_row_stride = row_stride; - int k_row_stride = row_stride; - int v_row_stride = row_stride; - int o_row_stride = o_head_stride * ${num_heads}; - - int q_batch_stride = q_row_stride * ${num_queries}; - int k_batch_stride = k_row_stride * ${num_keys}; - int v_batch_stride = v_row_stride * ${num_keys}; - int o_batch_stride = o_row_stride * ${num_queries}; - - flash_attn::flash_attention_forward( - static_cast(${qkv}->data), - static_cast(${qkv}->data) + ${head_dim} * ${num_heads}, - static_cast(${qkv}->data) + ${head_dim} * ${num_heads} * 2, - static_cast(out0->data), - ${num_batches}, - ${num_queries}, - ${num_keys}, - ${num_heads}, - ${num_heads}, - ${head_dim}, - q_batch_stride, - k_batch_stride, - v_batch_stride, - o_batch_stride, - q_head_stride, - k_head_stride, - v_head_stride, - o_head_stride, - q_row_stride, - k_row_stride, - v_row_stride, - o_row_stride, - ${scale}, - ${is_causal}, - nullptr); - """ - - if "qkv" in attrs: - return substitute_template(template_stacked, attrs) - - return substitute_template(template, attrs) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 7b1ab67172..1b8b88bb1d 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -59,7 +59,6 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False): cutlass_util_include = os.path.join(cutlass_root, "tools/util/include") cutlass_attention_include = os.path.join(cutlass_root, "examples/41_fused_multi_head_attention") cutlass_fpA_intB_gemm_include = os.path.join(cutlass_root, "../cutlass_fpA_intB_gemm") - flash_attn_include = 
os.path.join(cutlass_root, "../libflash_attn/include") kwargs = {} kwargs["cc"] = "nvcc" @@ -78,7 +77,6 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False): f"-I{cutlass_util_include}", f"-I{cutlass_attention_include}", f"-I{cutlass_fpA_intB_gemm_include}", - f"-I{flash_attn_include}", ] if use_fast_math: kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID") diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 7133193c1e..bf02d8f7b8 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -29,10 +29,7 @@ from tvm.tir import IntImm from . import _ffi_api as ffi -from .attention_operation import ( - instantiate_attention_template, - instantiate_flash_attention_template, -) +from .attention_operation import instantiate_attention_template from .conv2d_operation import instantiate_conv2d_template from .gemm_operation import instantiate_gemm_template, emit_fp16A_intB_matmul from .layer_norm_operation import instantiate_layer_norm_template @@ -715,6 +712,7 @@ def get_batch_on_arg(arg_name, arg_shape): return CodegenResult(code, headers) elif "attention" in func_name: + headers.append("kernel_forward.h") data_type = dtype_map[annotations["arg0_dtype"]] attrs["qkv_layout"] = annotations["qkv_layout"] @@ -741,86 +739,62 @@ def get_batch_on_arg(arg_name, arg_shape): attrs["head_dim"] = h = annotations["head_dim"] attrs["head_dim_value"] = h_v = annotations["head_dim_value"] attrs["kMaxK"] = max(int(attrs["head_dim"]), int(attrs["head_dim_value"])) + + data_type_size = DataTypeSize[data_type] + if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0: + attrs["kIsAligned"] = True + elif (h % 4 == 0) and (h_v % 4 == 0): + attrs["kIsAligned"] = False + else: + raise NotImplementedError() + if h_v > 64: + attrs["kQueriesPerBlock"] = 32 + attrs["kKeysPerBlock"] = 128 + attrs["kSingleValueIteration"] = h_v <= 128 + else: + attrs["kQueriesPerBlock"] = 64 + attrs["kKeysPerBlock"] = 64 + attrs["kSingleValueIteration"] = True + attrs["output_size"] = f"{b} * {s} * {n} * {h_v}" attrs["scale"] = ( float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"] ) - - use_flash = ( - annotations["ret_dtype"] == "float16" - and "bias" not in attrs - and int(attrs["head_dim"]) <= 256 - and int(attrs["head_dim"]) % 8 == 0 - and int(attrs["head_dim"]) == int(attrs["head_dim_value"]) - # We have not thoroughly validated flash with causal mask yet, so for now we support - # only non-causal cases. 
- and int(annotations["custom_mask_type"]) == 0 - # Flash v2 is currently not supported for sm < 80 - and int(annotations["arch"]) >= 80 - ) - - if use_flash: - headers.append("flash.h") - attrs["is_causal"] = int(annotations["custom_mask_type"]) == 0 - code = instantiate_flash_attention_template(attrs) - else: - headers.append("kernel_forward.h") - - data_type_size = DataTypeSize[data_type] - if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0: - attrs["kIsAligned"] = True - elif (h % 4 == 0) and (h_v % 4 == 0): - attrs["kIsAligned"] = False - else: - raise NotImplementedError() - if h_v > 64: - attrs["kQueriesPerBlock"] = 32 - attrs["kKeysPerBlock"] = 128 - attrs["kSingleValueIteration"] = h_v <= 128 - else: - attrs["kQueriesPerBlock"] = 64 - attrs["kKeysPerBlock"] = 64 - attrs["kSingleValueIteration"] = True - - assert ( - attrs["scale"] > 0 or attrs["scale"] < 0 - ), "Cutlass may generate nan occasionally when scale == 0.0" - attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"]) - attrs["kSupportsDropout"] = False - - attrs["output_size"] = f"{b} * {s} * {n} * {h_v}" - - attrs["custom_mask_type"] = annotations["custom_mask_type"] - - for arg in func_args: - if "workspace" in arg: - attrs["workspace"] = arg - if "bias" in attrs: - attrs["kSupportsBias"] = True - if len(annotations["bias_shape"]) == 4: - strides = "p.num_keys" - if annotations["bias_shape"][2] == 1: - attrs["bias_strideM"] = 0 - else: - attrs["bias_strideM"] = strides - strides = f"p.num_queries * {strides}" - if annotations["bias_shape"][1] == 1: - attrs["bias_strideH"] = 0 - else: - attrs["bias_strideH"] = strides - strides = f"p.num_heads * {strides}" - if annotations["bias_shape"][0] == 1: - attrs["bias_strideB"] = 0 - else: - attrs["bias_strideB"] = strides + attrs["custom_mask_type"] = annotations["custom_mask_type"] + + assert ( + attrs["scale"] > 0 or attrs["scale"] < 0 + ), "Cutlass may generate nan occasionally when scale == 0.0" + attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"]) + attrs["kSupportsDropout"] = False + + for arg in func_args: + if "workspace" in arg: + attrs["workspace"] = arg + if "bias" in attrs: + attrs["kSupportsBias"] = True + if len(annotations["bias_shape"]) == 4: + strides = "p.num_keys" + if annotations["bias_shape"][2] == 1: + attrs["bias_strideM"] = 0 + else: + attrs["bias_strideM"] = strides + strides = f"p.num_queries * {strides}" + if annotations["bias_shape"][1] == 1: + attrs["bias_strideH"] = 0 + else: + attrs["bias_strideH"] = strides + strides = f"p.num_heads * {strides}" + if annotations["bias_shape"][0] == 1: + attrs["bias_strideB"] = 0 else: - raise NotImplementedError() + attrs["bias_strideB"] = strides else: - # To support negative scale in current Cutlass implementation, - # kSupportsBias should be set true, or there are nan's as result. - attrs["kSupportsBias"] = attrs["scale"] < 0 - - code = instantiate_attention_template(attrs) - + raise NotImplementedError() + else: + # To support negative scale in current Cutlass implementation, + # kSupportsBias should be set true, or there are nan's as result. + attrs["kSupportsBias"] = attrs["scale"] < 0 + code = instantiate_attention_template(attrs) return CodegenResult(code, headers) elif "layer_norm" in func_name: headers.append("cutlass/util/device_layernorm.h") diff --git a/python/tvm/topi/einsum.py b/python/tvm/topi/einsum.py index f1f426ec81..cd4ced87d9 100644 --- a/python/tvm/topi/einsum.py +++ b/python/tvm/topi/einsum.py @@ -16,10 +16,12 @@ # under the License. 
# pylint: disable=invalid-name,consider-using-enumerate,redefined-outer-name """Einsum operator""" +from tvm.runtime import convert + from . import cpp -def einsum(subscripts, *operand): +def einsum(subscripts, *operand, fcompute=None, fcombine=None, fidentity=None): """Evaluates the Einstein summation convention on the operands. Parameters ---------- @@ -35,10 +37,77 @@ def einsum(subscripts, *operand): The only difference of einsum between in tvm and numpy is it needs an extra brackets for the tensors. For example, topi.einsum("ij, jk -> ik", (A, B)). + fcompute : function(List[value] -> List[value]) + Specifies the computation expression of the innermost loop. + + fcombine : function(Expr, Expr -> Expr) + Specifies the associative computation involved in constructing the commutative reduction. + + fidentity : function(List[str] -> List[Expr]) + Establishes the identity elements for the commutative reduction process. + + Returns ------- out : tvm.te.Tensor The calculation based on the Einstein summation convention. """ - return cpp.einsum(subscripts, operand) + def wrap_fcompute(fcompute): + if fcompute is None: + return None + + # On the C++ side, fcompute is called with an input of Array, + # and is expected to return an Array. + def wrapped_fcompute(array_var): + args = [array_var[i] for i in range(len(array_var))] + + result = fcompute(*args) + if not isinstance(result, (list, tuple)): + result = [result] + result = convert(result) + return result + + return wrapped_fcompute + + # On the C++ side, fcombine is expected to return an Array. + def wrap_fcombine(fcombine): + if fcombine is None: + return None + + def wrapped_fcombine(x, y): + result = fcombine(x, y) + if not isinstance(result, (list, tuple)): + result = [result] + result = convert(result) + return result + + return wrapped_fcombine + + # On the C++ side, fidentity is called with an input of Array, + # and is expected to return an Array.
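# Usage note (editorial sketch, not part of this module): with the new keyword
# arguments, a matmul whose innermost expression also adds the second operand
# can be written as
#
#   from tvm import te, topi
#   A = te.placeholder((2, 3), name="A")
#   B = te.placeholder((3, 4), name="B")
#   # the innermost loop computes A[i, j] * B[j, k] + B[j, k]; the reduction
#   # over j stays the default sum
#   C = topi.einsum("ij,jk->ik", A, B, fcompute=lambda x_ij, y_jk: x_ij * y_jk + y_jk)
#
# which matches the "ij,jk->ik" case exercised in
# tests/python/topi/python/test_topi_einsum_advanced.py below.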
+ def wrap_fidentity(fidentity): + if fidentity is None: + return None + + def wrapped_fidentity(array_dtype): + dtypes = [array_dtype[i] for i in range(len(array_dtype))] + + result = fidentity(*dtypes) + if not isinstance(result, (list, tuple)): + result = [result] + result = convert(result) + return result + + return wrapped_fidentity + + result = cpp.einsum( + subscripts, + operand, + wrap_fcompute(fcompute), + wrap_fcombine(fcombine), + wrap_fidentity(fidentity), + ) + if len(result) == 1: + result = result[0] + return result diff --git a/src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc b/src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc index 13b7d94706..e9d60553e7 100644 --- a/src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc +++ b/src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc @@ -36,6 +36,7 @@ #define DMLC_USE_LOGGING_LIBRARY #include +#include #include namespace { @@ -45,6 +46,7 @@ static const std::vector default_so_paths = { #elif defined(__ANDROID__) static const std::vector default_so_paths = { "libOpenCL.so", + "libOpenCL-pixel.so", "/system/lib64/libOpenCL.so", "/system/vendor/lib64/libOpenCL.so", "/system/vendor/lib64/egl/libGLES_mali.so", @@ -66,6 +68,9 @@ static const std::vector default_so_paths = {"libOpenCL.so", "/usr/lib32/libOpenCL.so"}; #endif +typedef void (*enableOpenCL_t)(); +typedef void* (*loadOpenCLPointer_t)(const char* name); + class LibOpenCLWrapper { public: static LibOpenCLWrapper& getInstance() { @@ -79,7 +84,11 @@ class LibOpenCLWrapper { #if defined(_WIN32) return GetProcAddress(m_libHandler, funcName); #else - return dlsym(m_libHandler, funcName); + if (loadOpenCLPointer != nullptr) { + return loadOpenCLPointer(funcName); + } else { + return dlsym(m_libHandler, funcName); + } #endif } @@ -98,6 +107,21 @@ class LibOpenCLWrapper { m_libHandler = LoadLibrary(it); #else m_libHandler = dlopen(it, RTLD_LAZY); + + if (std::strcmp(it, "libOpenCL-pixel.so") == 0) { + enableOpenCL_t enableOpenCL = + reinterpret_cast(dlsym(m_libHandler, "enableOpenCL")); + if (enableOpenCL == nullptr) { + continue; + } + enableOpenCL(); + loadOpenCLPointer = + reinterpret_cast(dlsym(m_libHandler, "loadOpenCLPointer")); + if (loadOpenCLPointer == nullptr) { + continue; + } + } + #endif if (m_libHandler != nullptr) return; } @@ -109,6 +133,8 @@ class LibOpenCLWrapper { HMODULE m_libHandler = nullptr; #else void* m_libHandler = nullptr; + loadOpenCLPointer_t loadOpenCLPointer; + #endif }; diff --git a/src/topi/einsum.cc b/src/topi/einsum.cc index c69baf3ebd..1fc1cd8ff6 100644 --- a/src/topi/einsum.cc +++ b/src/topi/einsum.cc @@ -39,19 +39,27 @@ EinsumEquation EinsumEquation::FromString(const std::string& equation) { // Ignore spaces break; case '-': - // Arrow + // Arrow, end of inputs, push current CHECK(!has_arrow) << "Equation can only have one arrow"; CHECK(i + 1 < n && equation[i + 1] == '>') << "Cannot parse the Einsum equation: invalid arrow"; i++; has_arrow = true; - [[fallthrough]]; - case ',': - // Delimiter between inputs, push current and start a new one result.inputs.emplace_back(current); current.clear(); has_ellipsis = false; break; + case ',': + if (has_arrow) { + // Delimiter between outputs, push current and start a new one + result.SetOutput(current); + } else { + // Delimiter between inputs, push current and start a new one + result.inputs.emplace_back(current); + } + current.clear(); + has_ellipsis = false; + break; case '.': // Ellipsis CHECK(!has_ellipsis) << "Ellipsis can only appear once for each input and output"; @@ -72,7 +80,7 @@ 
EinsumEquation EinsumEquation::FromString(const std::string& equation) { if (has_arrow) { // If there is an arrow, the last subscript is the output - result.output = current; + result.SetOutput(current); } else { // Otherwise, the equation is in implicit mode, and the last subscript is an input result.inputs.emplace_back(current); @@ -80,6 +88,7 @@ EinsumEquation EinsumEquation::FromString(const std::string& equation) { // Convert the equation to explicit mode if it is in implicit mode if (!has_arrow) { + Subscript output; // The output of the implicit mode is all repeated labels sorted in alphabetical order and the // ellipsis in the leftmost if it exists in the inputs. std::map label_counts; @@ -90,13 +99,23 @@ EinsumEquation EinsumEquation::FromString(const std::string& equation) { } for (auto [label, count] : label_counts) { if (label == kEllipsis || count == 1) { - result.output.emplace_back(label); + output.emplace_back(label); } } + result.SetOutput(output); } return result; } +void EinsumEquation::SetOutput(Subscript output_subscript) { + if (num_outputs == 0) { + output = output_subscript; + } else { + CHECK(output == output_subscript) << "The output subscript should be the same."; + } + num_outputs++; +} + PrimExpr GetBroadcastedExtent(const PrimExpr& extent1, const PrimExpr& extent2) { const IntImmNode* extent1_imm = extent1.as(); const IntImmNode* extent2_imm = extent2.as(); @@ -204,7 +223,8 @@ class EinsumBuilder { return output_shape_; } - PrimExpr BuildOutputExpr(const Array inputs, const Array& indices) { + Array BuildOutputExpr(const Array inputs, const Array& indices, + PackedFunc fcompute, PackedFunc fcombine, PackedFunc fidentity) { std::unordered_map label_to_index; Array ellipsis_indices; Array reduce_axes; @@ -214,22 +234,84 @@ class EinsumBuilder { auto zero = make_zero(inputs[0]->dtype); - PrimExpr result = zero; + Array results; + Array operands; for (int i = 0, n = static_cast(inputs.size()); i < n; ++i) { - auto term = inputs[i](GetIndicesForOperand(i, label_to_index, ellipsis_indices)); - if (i == 0) { - result = term; - } else { - result = result * term; + tvm::PrimExpr term = inputs[i](GetIndicesForOperand(i, label_to_index, ellipsis_indices)); + operands.push_back(term); + } + + if (fcompute != nullptr) { + // Call customized fcompute + results = fcompute(operands); + CHECK(results.size() == equation_.num_outputs) + << "fcompute is intended to produce " << equation_.num_outputs + << " outputs, but only returns " << results.size(); + } else { + // Default computation: multiply all the operands together. + PrimExpr result; + for (int i = 0, n = static_cast(inputs.size()); i < n; ++i) { + if (i == 0) { + result = operands[i]; + } else { + result = result * operands[i]; + } + } + for (size_t i = 0; i < equation_.num_outputs; ++i) { + results.push_back(result); } } + if (reduce_axes.size() > 0) { - result = sum(result, reduce_axes, {zero}); + results = CreateReduce(results, reduce_axes, fcombine, fidentity); } - return result; + return results; } private: + /*! + * \brief Construct reduce: default is sum. 
+ */ + Array CreateReduce(Array source, Array rdom, PackedFunc fcombine, + PackedFunc fidentity, Span span = Span()) { + Array x_; + Array y_; + Array results; + Array identity_elements; + Array inits = {}; + Array data_types; + for (size_t i = 0; i < source.size(); ++i) { + x_.push_back(Var("x_" + std::to_string(i), source[i].dtype(), span)); + y_.push_back(Var("y_" + std::to_string(i), source[i].dtype(), span)); + data_types.push_back(DLDataType2String(source[i].dtype())); + } + + if (fcombine == nullptr && fidentity == nullptr) { + // Default reduction: sum + for (size_t i = 0; i < source.size(); ++i) { + results.push_back(tir::Add(x_[i], y_[i], span)); + identity_elements.push_back(make_zero(source[i].dtype(), span)); + } + } else if (fcombine != nullptr && fidentity != nullptr) { + // Call customized fcombine and fidentity + if (x_.size() == 1) { + results = fcombine(x_[0], y_[0]); + } else { + results = fcombine(x_, y_); + } + identity_elements = fidentity(data_types); + } else { + CHECK(false) << "Define both fcombine and fidentity simultaneously."; + } + tir::CommReducer combiner = tir::CommReducer(x_, y_, results, identity_elements, span); + Array outputs; + PrimExpr condition = make_const(DataType::Bool(1), true); + for (size_t i = 0; i < source.size(); ++i) { + outputs.push_back(tir::Reduce(combiner, source, rdom, condition, i, inits, span)); + } + return outputs; + } + /*! * \brief Prepare mapping from label (including ellipsis) to the output indices */ @@ -333,8 +415,9 @@ class EinsumBuilder { Optional> ellipsis_shape_; }; -Tensor einsum(const std::string& subscripts_str, const Array inputs, std::string name, - std::string tag) { +Array einsum(const std::string& subscripts_str, const Array inputs, + PackedFunc fcompute, PackedFunc fcombine, PackedFunc fidentity, + std::string name, std::string tag) { EinsumEquation equation = EinsumEquation::FromString(subscripts_str); Array> input_shapes; for (const Tensor& input : inputs) { @@ -344,7 +427,9 @@ Tensor einsum(const std::string& subscripts_str, const Array inputs, std auto output_shape = einsum_builder.InferShape(); return te::compute( output_shape, - [&](const Array& indices) { return einsum_builder.BuildOutputExpr(inputs, indices); }, + [&](const Array& indices) { + return einsum_builder.BuildOutputExpr(inputs, indices, fcompute, fcombine, fidentity); + }, name, tag); } @@ -356,7 +441,7 @@ Array InferEinsumShape(const std::string& subscripts, } TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = einsum(args[0], args[1]); + *rv = einsum(args[0], args[1], args[2], args[3], args[4]); }); } // namespace topi diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 952036584f..61831fa1dc 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -854,7 +854,7 @@ def stacked_attention_size(request): def test_stacked_attention_split_offload(stacked_attention_size): b, s, n, (h, h_v), bias_shape, scale, single_shape = stacked_attention_size - qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float16") + qkv, bias, ref = get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, scale, "float32") if scale == "none": mod = get_relax_stacked_attention_module( qkv, b, s, n, h, h_v, "split", bias, single_shape=single_shape diff --git a/tests/python/topi/python/test_topi_einsum_advanced.py b/tests/python/topi/python/test_topi_einsum_advanced.py new file mode 100644 
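The CommReducer assembled by CreateReduce above is the same construct that TVM exposes to Python as te.comm_reducer. A minimal sketch of the equivalent custom reduction, assuming only the standard te API (this snippet is not part of the patch):

import tvm
from tvm import te

# A product reducer built from an fcombine/fidentity pair, mirroring the
# tir::CommReducer that CreateReduce constructs on the C++ side.
prod = te.comm_reducer(lambda x, y: x * y,
                       lambda t: tvm.tir.const(1, t),
                       name="prod")

n = te.var("n")
A = te.placeholder((n, n), name="A", dtype="float32")
k = te.reduce_axis((0, n), name="k")
# B[i] = product over k of A[i, k], i.e. the "ij->i" einsum with a custom combine.
B = te.compute((n,), lambda i: prod(A[i, k], axis=k), name="B")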
index 0000000000..df38d74651 --- /dev/null +++ b/tests/python/topi/python/test_topi_einsum_advanced.py @@ -0,0 +1,454 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import te +from tvm import topi +from tvm.topi.utils import get_const_tuple + + +def with_tvm(lam, shapes, ops, out_shapes): + """Take numpy arrays as args, convert them to TVM tensors and call `lam`. + Result of lambda is converted back to numpy array and returned. + """ + dev = tvm.cpu(0) + pls = [] # placeholders + vals_nd = [] # initial values + out_nd = [] # output values + for i, (shape, arg) in enumerate(zip(shapes, ops)): + pls.append(te.placeholder(shape, name="pl" + str(i))) + vals_nd.append(tvm.nd.array(arg, dev)) + + outputs = lam(*pls) + if isinstance(outputs, tvm.ir.container.Array): + outputs = [outputs[i] for i in range(len(outputs))] + else: + outputs = [outputs] + for out_shape, out in zip(out_shapes, outputs): + out_nd.append(tvm.nd.array(np.zeros(out_shape).astype(out.dtype), device=dev)) + func = te.create_prim_func(pls + outputs) + m = tvm.build(func, target="llvm") + m(*(vals_nd + out_nd)) + return [out_.numpy() for out_ in out_nd] + + +def verify_einsum(subscripts, shapes, fcompute, fcombine, fidentity, np_lambda, shape_dict={}): + ops = [] # ndarrays to be used as inputs + symbolic_shapes = [] # shapes to declare the placeholders + name_to_var = {} + + def get_concrete_shape(shape): + return [shape_dict[s] if isinstance(s, str) else s for s in shape] + + def get_symblic_shape_var(name, dtype="int32"): + if name not in name_to_var: + name_to_var[name] = te.var(name, dtype=dtype) + return name_to_var[name] + + def get_symbolic_shape(shape): + return [get_symblic_shape_var(s) if isinstance(s, str) else s for s in shape] + + for shape in shapes: + concrete_shape = get_concrete_shape(shape) + tmp = np.random.uniform(low=-1.0, high=1.0, size=concrete_shape).astype(np.float32) + ops.append(tmp) + symbolic_shape = get_symbolic_shape(shape) + symbolic_shapes.append(symbolic_shape) + + np_outs = np_lambda(*ops) + if not isinstance(np_outs, (list, tuple)): + np_outs = [np_outs] + out_shapes = [out_.shape for out_ in np_outs] + + if len(ops) == 1: + tvm_outs = with_tvm( + lambda A: topi.einsum( + subscripts, A, fcompute=fcompute, fcombine=fcombine, fidentity=fidentity + ), + symbolic_shapes, + ops, + out_shapes, + ) + elif len(ops) == 2: + tvm_outs = with_tvm( + lambda A, B: topi.einsum( + subscripts, A, B, fcompute=fcompute, fcombine=fcombine, fidentity=fidentity + ), + symbolic_shapes, + ops, + out_shapes, + ) + elif len(ops) == 3: + tvm_outs = with_tvm( + lambda A, B, C: topi.einsum( + subscripts, A, B, C, fcompute=fcompute, fcombine=fcombine, fidentity=fidentity + ), + symbolic_shapes, + ops, + 
out_shapes, + ) + + assert len(np_outs) == len(tvm_outs) + for c1, c2 in zip(np_outs, tvm_outs): + tvm.testing.assert_allclose(c1, c2, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize( + "equation,shapes,fcompute,fcombine,fidentity,np_lambda", + [ + ("ij->ij, ij", [(5, 5)], None, None, None, lambda A: (A, A)), # cannot have ii in output + ("ij->ij, ij", [(5, 5)], None, None, None, lambda A: (A, A)), + ("...j->...j, ...j", [(5, 5)], None, None, None, lambda A: (A, A)), + ( + "ijk, jil->ijkl, ijkl", + [(3, 4, 5), (4, 3, 2)], + None, + None, + None, + lambda A, B: (np.einsum("ijk, jil->ijkl", A, B), np.einsum("ijk, jil->ijkl", A, B)), + ), + ("ij, ij -> ij, ij", [(1, 4), (2, 4)], None, None, None, lambda A, B: (A * B, A * B)), + ( + "...ij, ...jk -> ...ijk, ...ijk", + [(1, 4), (4, 2)], + None, + None, + None, + lambda A, B: ( + np.einsum("...ij, ...jk->...ijk", A, B), + np.einsum("...ij, ...jk->...ijk", A, B), + ), + ), + ( + "...ij, ...ik -> ...ijk, ...ijk", + [(1, 1, 1, 4), (1, 1, 1, 3)], + None, + None, + None, + lambda A, B: ( + np.einsum("...ij, ...ik -> ...ijk", A, B), + np.einsum("...ij, ...ik -> ...ijk", A, B), + ), + ), + ( + "...ik, ...jk, ...hk -> i...jhk, i...jhk", + [(3, 4, 4), (1, 5, 3, 8, 4), (2, 5, 3, 6, 4)], + None, + None, + None, + lambda A, B, C: ( + np.einsum("...ik, ...jk, ...hk -> i...jhk", A, B, C), + np.einsum("...ik, ...jk, ...hk -> i...jhk", A, B, C), + ), + ), + ( + "ij,jk->ijk, ijk", + [(2, 3), (3, 4)], + None, + None, + None, + lambda A, B: (np.einsum("ij,jk->ijk", A, B), np.einsum("ij,jk->ijk", A, B)), + ), + ( + "ij,jk,km->ijkm, ijkm", + [(2, 3), (3, 4), (4, 5)], + None, + None, + None, + lambda A, B, C: ( + np.einsum("ij,jk,km->ijkm", A, B, C), + np.einsum("ij,jk,km->ijkm", A, B, C), + ), + ), + ], +) +def test_multi_outputs_without_reduction( + equation, shapes, fcompute, fcombine, fidentity, np_lambda +): + verify_einsum(equation, shapes, fcompute, fcombine, fidentity, np_lambda) + + +@pytest.mark.parametrize( + "equation,shapes,fcompute,fcombine,fidentity,np_lambda", + [ + ( + "ii -> ,", + [(5, 5)], + None, + None, + None, + lambda A: (np.einsum("ii->", A), np.einsum("ii->", A)), + ), + ( + "ii->i, i", + [(5, 5)], + None, + None, + None, + lambda A: (np.einsum("ii->i", A), np.einsum("ii->i", A)), + ), + ( + "ij->i, i", + [(5, 5)], + None, + None, + None, + lambda A: (np.einsum("ij->i", A), np.einsum("ij->i", A)), + ), + ( + "...j->..., ...", + [(5, 5)], + None, + None, + None, + lambda A: (np.einsum("...j->...", A), np.einsum("...j->...", A)), + ), + ( + "...j, j->..., ...", + [(5, 5), (5,)], + None, + None, + None, + lambda A, B: (np.einsum("...j, j->...", A, B), np.einsum("...j, j->...", A, B)), + ), + ( + "..., ...-> ..., ...", + [(), (2, 3)], + None, + None, + None, + lambda A, B: (np.einsum("..., ... -> ...", A, B), np.einsum("..., ... 
-> ...", A, B)), + ), + ( + "ijk, jil->kl, kl", + [(3, 4, 5), (4, 3, 2)], + None, + None, + None, + lambda A, B: (np.einsum("ijk, jil->kl", A, B), np.einsum("ijk, jil->kl", A, B)), + ), + ( + "ij, ij -> i, i", + [(1, 4), (2, 4)], + None, + None, + None, + lambda A, B: (np.einsum("ij, ij -> i", A, B), np.einsum("ij, ij -> i", A, B)), + ), + ( + "...ij, ...jk -> ...ik, ...ik", + [(1, 4), (4, 2)], + None, + None, + None, + lambda A, B: ( + np.einsum("...ij, ...jk -> ...ik", A, B), + np.einsum("...ij, ...jk -> ...ik", A, B), + ), + ), + ( + "...ij, ...ik -> ...jk, ...jk", + [(1, 1, 1, 4), (1, 1, 1, 3)], + None, + None, + None, + lambda A, B: ( + np.einsum("...ij, ...ik -> ...jk", A, B), + np.einsum("...ij, ...ik -> ...jk", A, B), + ), + ), + ( + "...ik, ...jk, ...hk -> i...jh, i...jh", + [(3, 4, 4), (1, 5, 3, 8, 4), (2, 5, 3, 6, 4)], + None, + None, + None, + lambda A, B, C: ( + np.einsum("...ik, ...jk, ...hk -> i...jh", A, B, C), + np.einsum("...ik, ...jk, ...hk -> i...jh", A, B, C), + ), + ), + ( + "ij,jk->ik,ik", + [(2, 3), (3, 4)], + None, + None, + None, + lambda A, B: (np.einsum("ij, jk->ik", A, B), np.einsum("ij, jk->ik", A, B)), + ), + ( + "ij,jk,km->im,im", + [(2, 3), (3, 4), (4, 5)], + None, + None, + None, + lambda A, B, C: ( + np.einsum("ij,jk,km->im", A, B, C), + np.einsum("ij,jk,km->im", A, B, C), + ), + ), + ], +) +def test_multi_outpus_with_default_reduction( + equation, shapes, fcompute, fcombine, fidentity, np_lambda +): + verify_einsum(equation, shapes, fcompute, fcombine, fidentity, np_lambda) + + +@pytest.mark.parametrize( + "equation,shapes,fcompute,fcombine,fidentity,np_lambda", + [ + ("ij->ij", [(5, 5)], lambda x_ij: x_ij, None, None, lambda x: x), + ( + "ij->ij", + [(5, 5)], + lambda x_ij: x_ij * x_ij + x_ij, + None, + None, + lambda x: np.power(x, 2) + x, + ), + ("ij->ij", [(5, 5)], lambda x_ij: 1.0, None, None, lambda x: np.ones_like(x)), + ("ij->i", [(5, 5)], lambda x_ij: x_ij * 2, None, None, lambda x: np.einsum("ij->i", 2 * x)), + ( + "ij,jk->ik", + [(2, 3), (3, 4)], + lambda x_ij, y_jk: x_ij * y_jk + y_jk, + None, + None, + lambda x, y: x @ y + np.einsum("jk->k", y), + ), + ( + "ij,jk,km->im", + [(2, 3), (3, 4), (4, 5)], + lambda x_ij, y_jk, z_km: x_ij * y_jk * z_km, + None, + None, + lambda x, y, z: np.einsum("ij,jk,km->im", x, y, z), + ), + ( + ("ij->ij, ij"), + [(5, 5)], + lambda x_ij: (x_ij, x_ij * x_ij), + None, + None, + lambda x: (x, np.power(x, 2)), + ), + ( + "ij->i, i", + [(5, 5)], + lambda x_ij: (x_ij, x_ij * x_ij), + None, + None, + lambda x: (np.sum(x, axis=-1), np.sum(x * x, axis=-1)), + ), + ( + "ij,jk->ik, ik", + [(2, 3), (3, 4)], + lambda x_ij, y_jk: (x_ij * y_jk, x_ij + y_jk), + None, + None, + lambda x, y: ( + x @ y, + np.sum(x, axis=1, keepdims=True) + np.sum(y, axis=0, keepdims=True), + ), + ), + ( + "ij,jk,km->im, im", + [(2, 3), (3, 4), (4, 5)], + lambda x_ij, y_jk, z_km: (x_ij * y_jk * z_km, x_ij * y_jk / z_km), + None, + None, + lambda x, y, z: ( + np.einsum("ij,jk,km->im", x, y, z), + np.einsum("ij,jk,km->im", x, y, 1 / z), + ), + ), + ], +) +def test_customize_compute(equation, shapes, fcompute, fcombine, fidentity, np_lambda): + verify_einsum(equation, shapes, fcompute, fcombine, fidentity, np_lambda) + + +@pytest.mark.parametrize( + "equation,shapes,fcompute,fcombine,fidentity,np_lambda", + [ + ( + "ij->ij", + [(5, 5)], + None, + lambda x, y: x + y, # no accumulate + lambda dtype1: tvm.tir.const(0, dtype1), + lambda x: x, + ), + ( + "ij->i", + [(5, 5)], + None, + lambda x, y: x + y, + lambda dtype1: tvm.tir.const(0, dtype1), + lambda 
x: np.sum(x, axis=1), + ), + ( + "ij->i", + [(5, 5)], + None, + lambda x, y: x * y, + lambda dtype1: tvm.tir.const(1, dtype1), + lambda x: np.prod(x, axis=1), + ), + ( + "ij->i", + [(5, 5)], + lambda x_ij: 2 * x_ij, + lambda x, y: x * y, + lambda dtype1: tvm.tir.const(1, dtype1), + lambda x: np.prod(2 * x, axis=1), + ), + ( + "ij,jk->ik", + [(2, 3), (3, 4)], + lambda x_ij, y_jk: x_ij + y_jk, + lambda x, y: x * y, + lambda dtype1: tvm.tir.const(1, dtype1), + lambda x, y: np.prod(np.expand_dims(x, -1) + np.expand_dims(y, 0), axis=1), + ), + ( + "ij,jk,km->im", + [(2, 3), (3, 4), (4, 5)], + None, + lambda x, y: x + y, + lambda dtype1: tvm.tir.const(0, dtype1), + lambda x, y, z: np.einsum("ij,jk,km->im", x, y, z), + ), + ( + "ij->i, i", + [(5, 5)], + lambda x_ij: (x_ij, x_ij), + lambda x, y: (x[0] + y[0], x[1] * y[1]), + lambda dtype1, dtype2: (tvm.tir.const(0, dtype1), tvm.tir.const(1, dtype2)), + lambda x: (np.sum(x, axis=-1), np.prod(x, axis=-1)), + ), + ], +) +def test_customize_combine(equation, shapes, fcompute, fcombine, fidentity, np_lambda): + verify_einsum(equation, shapes, fcompute, fcombine, fidentity, np_lambda) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/scripts/mlc/task_mlc_build.sh b/tests/scripts/mlc/task_mlc_build.sh new file mode 100755 index 0000000000..c38832677c --- /dev/null +++ b/tests/scripts/mlc/task_mlc_build.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +set -euxo pipefail + +cd build +cmake -DCMAKE_BUILD_TYPE=RelWithDebInfo .. +make -j8 diff --git a/tests/scripts/mlc/task_mlc_lint_cleanup.sh b/tests/scripts/mlc/task_mlc_lint_cleanup.sh new file mode 100755 index 0000000000..a9cacb9805 --- /dev/null +++ b/tests/scripts/mlc/task_mlc_lint_cleanup.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +set -euxo pipefail + +echo "Cleanup before linting..." +# Remove clang-format-index.locok +rm -f .git/clang-format-index.lock
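For completeness, the multi-output custom reduction covered by test_customize_combine can be driven end to end roughly as follows. This is a sketch that mirrors the with_tvm helper above; the shapes, target, and buffer names are arbitrary.

import numpy as np
import tvm
import tvm.testing
from tvm import te, topi

# One einsum call producing both a row-sum and a row-product via custom
# fcombine/fidentity, as in the "ij->i, i" test case.
A = te.placeholder((5, 5), name="A", dtype="float32")
outs = topi.einsum(
    "ij->i, i",
    A,
    fcompute=lambda x_ij: (x_ij, x_ij),
    fcombine=lambda x, y: (x[0] + y[0], x[1] * y[1]),
    fidentity=lambda dt0, dt1: (tvm.tir.const(0, dt0), tvm.tir.const(1, dt1)),
)
outs = [outs[i] for i in range(len(outs))]

func = te.create_prim_func([A] + outs)
mod = tvm.build(func, target="llvm")

a_np = np.random.uniform(size=(5, 5)).astype("float32")
a = tvm.nd.array(a_np)
sum_nd = tvm.nd.array(np.zeros((5,), "float32"))
prod_nd = tvm.nd.array(np.zeros((5,), "float32"))
mod(a, sum_nd, prod_nd)

tvm.testing.assert_allclose(sum_nd.numpy(), a_np.sum(axis=1), rtol=1e-5)
tvm.testing.assert_allclose(prod_nd.numpy(), a_np.prod(axis=1), rtol=1e-5)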