From ae89c1e56da5ac38e68575c0baec047c442b266d Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Mon, 25 Sep 2023 04:32:22 +0300 Subject: [PATCH 01/13] [OpenCL] Don't initialize OpenCL runtime on host (#15745) * [OpenCL] Don't initialize OpenCL runtime on host After adding OpenCL wrapper, it is possible to build TVM with OpenCL support also on the host which doesn't have OpenCL libraries. But if you want to compile OpenCL module for a remote device on such host machine then you will see an error that OpenCL lib cannot be open. To avoid such problem, we need to call OpenCL functions only in runtime. So function for initializing OpenCL workspace was removed from OpenCLModuleNode. And a new function `IsProgramCreated` was added. The last function is necessary to prepare vectors with OpenCL programs, associated with OpenCL devices. Previously it was done during OpenCLModule initialization. So, now we create such vectors only in runtime after getting list of available OpenCL devices. * Call workspace init function before all OpenCL API calls --- src/runtime/opencl/opencl_common.h | 6 ++++-- src/runtime/opencl/opencl_device_api.cc | 4 ++++ src/runtime/opencl/opencl_module.cc | 22 +++++++++++++++------- src/runtime/opencl/opencl_module.h | 1 + 4 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index a303141357..8c1607c4e5 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -220,7 +220,7 @@ struct BufferDescriptor; class OpenCLWorkspace : public DeviceAPI { public: // type key - std::string type_key; + std::string type_key{"opencl"}; // available platforms std::vector platform_ids; // map platform to its context @@ -253,7 +253,7 @@ class OpenCLWorkspace : public DeviceAPI { // Initialize the device. void Init(const std::string& type_key, const std::string& device_type, const std::string& platform_name = ""); - virtual void Init() { Init("opencl", "gpu"); } + virtual void Init() { Init(this->type_key, "gpu"); } // Check whether the context is OpenCL or not. virtual bool IsOpenCLDevice(Device dev) { return dev.device_type == kDLOpenCL; } // get the queue of the device @@ -465,6 +465,8 @@ class OpenCLModuleNode : public OpenCLModuleNodeBase { : OpenCLModuleNodeBase(fmap), data_(data), fmt_(fmt), source_(source) {} PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + // Return true if OpenCL program for the requested function and device was created + bool IsProgramCreated(const std::string& func_name, int device_id); void SaveToFile(const String& file_name, const String& format) final; void SaveToBinary(dmlc::Stream* stream) final; void SetPreCompiledPrograms(const std::string& bytes); diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 35e77eb6d1..fb9adc2757 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -111,6 +111,7 @@ OpenCLWorkspace* OpenCLWorkspace::Global() { } cl_device_id OpenCLWorkspace::GetCLDeviceID(int device_id) { + this->Init(); ICHECK_LT(device_id, devices.size()) << "Invalid device id " << device_id << ". " << GetError(); return devices[device_id]; } @@ -210,6 +211,7 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) void* OpenCLWorkspace::CreateHostPtrIfEnabled(cl::BufferDescriptor* desc, Device dev, size_t size) { #if defined(OPENCL_ENABLE_HOST_PTR) + this->Init(); cl_int err_code; desc->host_ptr = reinterpret_cast( clEnqueueMapBuffer(this->GetQueue(dev), desc->buffer, CL_TRUE, CL_MAP_WRITE, 0, @@ -300,6 +302,7 @@ void OpenCLWorkspace::FreeTextureWorkspace(Device dev, void* ptr) { } void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { + this->Init(); size_t nbytes = GetDataSize(*from); ICHECK_EQ(nbytes, GetDataSize(*to)); ICHECK(IsContiguous(*from) && IsContiguous(*to)) @@ -379,6 +382,7 @@ void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHand } void OpenCLWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { + this->Init(); ICHECK(stream == nullptr); OPENCL_CALL(clFinish(this->GetQueue(dev))); } diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 6829d46d43..567b7ad88a 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -185,7 +185,6 @@ String OpenCLModuleNode::GetSource(const String& format) { void OpenCLModuleNode::Init() { workspace_ = GetGlobalWorkspace(); - workspace_->Init(); // initialize the kernel id, need to lock global table. std::lock_guard lock(workspace_->mu); for (const auto& kv : fmap_) { @@ -208,10 +207,17 @@ void OpenCLModuleNode::Init() { << "delimiter was found."; ICHECK_EQ(fmap_.size(), parsed_kernels_.size()) << "The number of parsed kernel sources does not match the number of kernel functions"; +} + +bool OpenCLModuleNode::IsProgramCreated(const std::string& func_name, int device_id) { + auto size = programs_[func_name].size(); + if (size > 0 && programs_[func_name][device_id] != nullptr) return true; + auto dev_size = GetGlobalWorkspace()->devices.size(); + ICHECK(device_id < static_cast(dev_size)) + << "Device id " << device_id << " is bigger than number of available devices"; // zero initialize cl_program pointers for each device kernel - for (auto& kv : parsed_kernels_) { - programs_.insert({kv.first, std::vector(workspace_->devices.size(), nullptr)}); - } + if (size == 0) programs_[func_name].resize(dev_size, nullptr); + return false; } cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, @@ -220,7 +226,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre int device_id = t->device.device_id; auto did = w->GetCLDeviceID(device_id); auto platform = w->device_to_platform[did]; - if (programs_[func_name][device_id] == nullptr) { + if (!IsProgramCreated(func_name, device_id)) { // create program if (fmt_ == "cl") { const char* s = parsed_kernels_[func_name].c_str(); @@ -268,6 +274,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre } void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) { + workspace_->Init(); std::string data = bytes; dmlc::MemoryStringStream reader(&data); dmlc::Stream* strm = &reader; @@ -280,7 +287,7 @@ void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) { std::vector bin_vector; strm->Read(&name); strm->Read(&bin_vector); - if (programs_[name][device_id] == nullptr) { + if (!IsProgramCreated(name, device_id)) { cl_int err = 0; cl_int binaryStatus; size_t binarySize = bin_vector.size(); @@ -310,6 +317,7 @@ void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) { } std::string OpenCLModuleNode::GetPreCompiledPrograms() { + workspace_->Init(); std::string data; dmlc::MemoryStringStream writer(&data); dmlc::Stream* strm = &writer; @@ -319,7 +327,7 @@ std::string OpenCLModuleNode::GetPreCompiledPrograms() { cl::OpenCLThreadEntry* t = workspace_->GetThreadEntry(); int device_id = t->device.device_id; t->kernel_table.resize(workspace_->num_registered_kernels); - if (programs_[std::string(name)][device_id] == nullptr) { + if (!IsProgramCreated(name, device_id)) { InstallKernel(workspace_, t, name, kid_map_[name]); } size_t size; diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 834f53510e..22fc119e03 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -42,6 +42,7 @@ namespace runtime { * \param data The module data. * \param fmt The format of the data, can be "clbin", "cl" * \param fmap The map function information map of each function. + * \param source Generated OpenCL kernels. */ Module OpenCLModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string source); From cde83e108880f96bc5df58a8e92c02a9aeff26a6 Mon Sep 17 00:00:00 2001 From: Siva Date: Mon, 25 Sep 2023 21:26:45 +0530 Subject: [PATCH 02/13] [TVMC] enable dumping imported modules too (#15779) Now we can dump the imported modules source too like device code. --- python/tvm/driver/tvmc/compiler.py | 2 ++ tests/python/driver/tvmc/test_compiler.py | 1 + 2 files changed, 3 insertions(+) diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 3496136470..4273674180 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -445,6 +445,8 @@ def compile_model( # TODO lib.get_source call have inconsistent behavior for unsupported # formats (@leandron). dumps[source_type] = lib.get_source(source_type) + for smod in lib.imported_modules: + dumps[smod.type_key] = smod.get_source() # Create a new tvmc model package object from the graph definition. package_path = tvmc_model.export_package( diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 6165099558..ca56172cb7 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -798,6 +798,7 @@ def test_compile_opencl(tflite_mobilenet_v1_0_25_128): assert type(tvmc_package.lib_path) is str assert type(tvmc_package.params) is bytearray assert os.path.exists(dumps_path) + assert path.exists("{}.{}".format(tvmc_package.package_path, "opencl")) @tvm.testing.requires_cmsisnn From dfd525bda5acdecb148220e32a7028e48797b10d Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 26 Sep 2023 05:26:31 -0700 Subject: [PATCH 03/13] Revert "[TensorIR][Visitor] Visit buffer members in `match_buffer`'s in block visitor functions (#15153) (#15816) * Revert "[TensorIR][Visitor] Visit buffer members in `match_buffer`'s in block visitor functions (#15153)" --- src/tir/ir/stmt_functor.cc | 32 ++------------ ...test_tir_transform_unify_thread_binding.py | 43 ------------------- 2 files changed, 4 insertions(+), 71 deletions(-) diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 7d1fe9f8dd..1c15f95826 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -135,15 +135,8 @@ void StmtVisitor::VisitStmt_(const BlockNode* op) { VisitArray(op->reads, fvisit_buffer_region); VisitArray(op->writes, fvisit_buffer_region); VisitArray(op->match_buffers, - [this, fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) { + [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) { fvisit_buffer_region(match_buffer_region->source); - this->VisitExpr(match_buffer_region->buffer->elem_offset); - VisitArray(match_buffer_region->buffer->strides, - [this](const PrimExpr& e) { this->VisitExpr(e); }); - VisitArray(match_buffer_region->buffer->shape, - [this](const PrimExpr& e) { this->VisitExpr(e); }); - VisitArray(match_buffer_region->buffer->axis_separators, - [this](const IntImm& e) { this->VisitExpr(e); }); }); if (op->init.defined()) { this->VisitStmt(op->init.value()); @@ -245,28 +238,11 @@ class StmtMutator::Internal { static Array Mutate(StmtMutator* self, const Array& arr) { auto fmutate = [self](const MatchBufferRegion& match_buffer_region) { - const Buffer& buffer = match_buffer_region->buffer; Array region = Mutate(self, match_buffer_region->source->region); - PrimExpr elem_offset = self->VisitExpr(buffer->elem_offset); - Array strides = Mutate(self, buffer->strides); - Array shape = Mutate(self, buffer->shape); - Array axis_separators = - MutateArray(self, buffer->axis_separators, - [self](const IntImm& e) { return Downcast(self->VisitExpr(e)); }); - - if (elem_offset.same_as(buffer->elem_offset) && strides.same_as(buffer->strides) && - shape.same_as(buffer->shape) && axis_separators.same_as(buffer->axis_separators)) { - if (region.same_as(match_buffer_region->source->region)) { - return match_buffer_region; - } else { - return MatchBufferRegion(buffer, - BufferRegion(match_buffer_region->source->buffer, region)); - } + if (region.same_as(match_buffer_region->source->region)) { + return match_buffer_region; } else { - Buffer new_buffer(buffer->data, buffer->dtype, shape, strides, elem_offset, buffer->name, - buffer->data_alignment, buffer->offset_factor, buffer->buffer_type, - axis_separators, buffer->span); - return MatchBufferRegion(new_buffer, + return MatchBufferRegion(match_buffer_region->buffer, BufferRegion(match_buffer_region->source->buffer, region)); } }; diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py b/tests/python/unittest/test_tir_transform_unify_thread_binding.py index d42adfcee4..9ee8643312 100644 --- a/tests/python/unittest/test_tir_transform_unify_thread_binding.py +++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py @@ -258,45 +258,6 @@ def unified_element_wise_implicit_block(a: T.handle, b: T.handle, c: T.handle) - ) -@T.prim_func -def match_buffer_with_elem_offset( - A: T.Buffer((8, 10, 8), "float32"), I: T.Buffer((4,), "int32"), offset: T.int32 -) -> None: - for i in T.thread_binding(0, 4, "blockIdx.x"): - for j in range(2): - with T.block(): - T.writes(A[I[i], offset, j * 4 : j * 4 + 4]) - sub_A = T.match_buffer( - A[I[i], offset, j * 4 : j * 4 + 4], - (4), - elem_offset=I[i] * 80 + offset * 8 + j * 4, - ) - for ji in range(0, 4): - sub_A[j * 4 + ji] = 1 - - -@T.prim_func -def unified_match_buffer_with_elem_offset( - A: T.Buffer((8, 10, 8), "float32"), I: T.Buffer((4,), "int32"), offset: T.int32 -) -> None: - for blockIdx_x in T.thread_binding(4, thread="blockIdx.x"): - for j in range(2): - with T.block(""): - T.reads(I[blockIdx_x]) - T.writes(A[I[blockIdx_x], offset, j * 4 : j * 4 + 4]) - sub_A = T.match_buffer( - A[I[blockIdx_x], offset, j * 4 : j * 4 + 4], - (4,), - elem_offset=I[blockIdx_x] * 80 + offset * 8 + j * 4, - ) - for ji in range(4): - i = T.int32() - sub_A_1 = T.Buffer( - (4,), data=sub_A.data, elem_offset=I[i] * 80 + offset * 8 + j * 4 - ) - sub_A_1[j * 4 + ji] = T.float32(1) - - def test_thread_x(): _check(element_wise_thread_x, unified_element_wise_thread_x) @@ -327,10 +288,6 @@ def test_implicit_block(): _check(element_wise_implicit_block, unified_element_wise_implicit_block) -def test_match_buffer_with_elem_offset(): - _check(match_buffer_with_elem_offset, unified_match_buffer_with_elem_offset) - - def test_inner_binding_with_annotation(): @T.prim_func def inner_binding_with_annotation(A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")): From d5fab9e4fb3b6f06a07b7c4a887d1f7db9ed19a1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 26 Sep 2023 15:06:53 -0500 Subject: [PATCH 04/13] [TVMScript] Use environment variable TVM_BLACK_FORMAT for .show() (#15762) Prior to this commit, the default behavior of the `black_format` argument in TVMScript printing has changed back and forth, based on conflicting user preferences. This commit allows the default to be specified by each using using the `TVM_BLACK_FORMAT` environment variable. If unspecified in a `obj.show()` method call, this environment variable is used to determine the default. --- python/tvm/runtime/script_printer.py | 30 +++++++++++++++++++++++++--- python/tvm/tir/schedule/schedule.py | 25 +++++++++++------------ python/tvm/tir/schedule/trace.py | 9 ++++++++- 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index e7846c0680..dff5429111 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Configuration of TVMScript printer""" +import os from typing import Dict, List, Optional, Sequence from tvm._ffi import get_global_func, register_object @@ -199,7 +200,7 @@ def script( def show( self, style: Optional[str] = None, - black_format: bool = False, + black_format: Optional[bool] = None, *, name: Optional[str] = None, show_meta: bool = False, @@ -226,8 +227,26 @@ def show( style : str, optional Pygmentize printing style, auto-detected if None. See `tvm.script.highlight.cprint` for more details. - black_format: bool - If true, use the formatter Black to format the TVMScript + + black_format: Optional[bool] + + If true, use the formatter Black to format the TVMScript. + If false, do not apply the auto-formatter. + + If None (default), determine the behavior based on the + environment variable "TVM_BLACK_FORMAT". If this + environment variable is unset, set to the empty string, or + set to the integer zero, black auto-formatting will be + disabled. If the environment variable is set to a + non-zero integer, black auto-formatting will be enabled. + + Note that the "TVM_BLACK_FORMAT" environment variable only + applies to the `.show()` method, and not the underlying + `.script()` method. The `.show()` method is intended for + human-readable output based on individual user + preferences, while the `.script()` method is intended to + provided a consistent output regardless of environment. + name : Optional[str] = None The name of the object show_meta : bool = False @@ -263,11 +282,16 @@ def show( Object to be underlined obj_to_annotate : Optional[Dict[Object, str]] = None Object to be annotated + """ from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel cprint, ) + if black_format is None: + env = os.environ.get("TVM_BLACK_FORMAT") + black_format = env and int(env) + cprint( self.script( name=name, diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 85d1f19bba..5df3a486ca 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """The TensorIR schedule class""" +import inspect from typing import Callable, Dict, List, Optional, Tuple, Union from tvm._ffi import register_object as _register_object @@ -268,26 +269,24 @@ def fork_seed(self) -> int: """ return _ffi_api.ScheduleForkSeed(self) # type: ignore # pylint: disable=no-member - def show(self, style: Optional[str] = None, black_format: bool = False) -> None: + def show(self, *args, **kwargs) -> None: """A sugar for print highlighted TVM script. - Parameters - ---------- - style : str, optional - - Pygmentize printing style, auto-detected if None. See - `tvm.script.highlight.cprint` for more details. - - black_format: bool - - If true, use the formatter Black to format the TVMScript + All parameters are forwarded to the underlying `Module.show` + and `Trace.show` methods. """ mod = self.mod if mod is not None: - mod.show(style=style, black_format=black_format) + mod.show(*args, **kwargs) + trace = self.trace if trace is not None: - trace.show(style=style, black_format=black_format) + # Trace.show only supports the style and black_format arguments + param_binding = inspect.signature(mod.show).bind(*args, **kwargs) + param_binding.apply_defaults() + bound_args = param_binding.arguments + + trace.show(style=bound_args["style"], black_format=bound_args["black_format"]) ########## Lookup ########## diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py index e317304993..cb8d5ce997 100644 --- a/python/tvm/tir/schedule/trace.py +++ b/python/tvm/tir/schedule/trace.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """An execution trace of a scheduling program""" +import os from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from tvm._ffi import register_object as _register_object @@ -274,10 +275,16 @@ def show(self, style: Optional[str] = None, black_format: bool = False) -> None: black_format: bool - If true, use the formatter Black to format the TVMScript + If true, use the formatter Black to format the TVMScript. + If None, determine based on the "TVM_BLACK_FORMAT" environment + variable. """ from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel cprint, ) + if black_format is None: + env = os.environ.get("TVM_BLACK_FORMAT") + black_format = bool(env and int(env)) + cprint(str(self), style=style, black_format=black_format) From c318fa8632ba50efed764730da59c2ae588898ce Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 27 Sep 2023 09:13:32 +0100 Subject: [PATCH 05/13] [Docker] Install oneflow from PyPi (#15819) Installing oneflow from the current link (https://release.oneflow.info) seems to be broken as reported in #15754, which is impacting other unrelated changes in CI. This commit attempts to fix the install by using a version from PyPi. --- docker/install/ubuntu_install_oneflow.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/install/ubuntu_install_oneflow.sh b/docker/install/ubuntu_install_oneflow.sh index 3eb6b7d89b..04fccd5b9c 100755 --- a/docker/install/ubuntu_install_oneflow.sh +++ b/docker/install/ubuntu_install_oneflow.sh @@ -22,4 +22,4 @@ set -o pipefail pip3 install flowvision==0.1.0 -python3 -m pip install -f https://release.oneflow.info oneflow==0.7.0+cpu +python3 -m pip install oneflow==0.7.0 From cf8521ad5c4052e94103ca66dd782c5a4a1bc137 Mon Sep 17 00:00:00 2001 From: Balint Cristian Date: Wed, 27 Sep 2023 22:09:53 +0300 Subject: [PATCH 06/13] [Target] LLVM helper functions for any target info (#15761) --- cmake/modules/LLVM.cmake | 3 + python/tvm/target/codegen.py | 93 +++++++-- python/tvm/target/x86.py | 28 +-- python/tvm/topi/x86/batch_matmul.py | 8 +- python/tvm/topi/x86/dense.py | 9 +- python/tvm/topi/x86/dense_alter_op.py | 5 +- .../space_generator/space_generator.cc | 19 +- src/relay/qnn/op/requantize.cc | 6 +- src/relay/qnn/op/requantize_config.h | 6 +- src/target/llvm/codegen_x86_64.cc | 39 +--- src/target/llvm/llvm_instance.cc | 154 +++++++++++++- src/target/llvm/llvm_instance.h | 30 +++ src/target/llvm/llvm_module.cc | 197 +++++------------- tests/python/relay/test_op_level2.py | 5 +- .../relay/test_op_qnn_conv2_transpose.py | 2 +- tests/python/relay/test_op_qnn_conv2d.py | 4 +- .../python/relay/test_pass_alter_op_layout.py | 6 +- tests/python/relay/test_pass_qnn_legalize.py | 14 +- .../python/target/test_llvm_features_info.py | 104 +++++++++ tests/python/target/test_x86_features.py | 155 +++++++------- .../unittest/test_target_codegen_llvm.py | 2 +- 21 files changed, 538 insertions(+), 351 deletions(-) create mode 100644 tests/python/target/test_llvm_features_info.py diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake index 6c21356ae8..6fb74fc1ef 100644 --- a/cmake/modules/LLVM.cmake +++ b/cmake/modules/LLVM.cmake @@ -29,6 +29,9 @@ add_definitions(-DDMLC_USE_FOPEN64=0 -DNDEBUG=1) # It may be a boolean or a string if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN}) find_llvm(${USE_LLVM}) + if (${TVM_LLVM_VERSION} LESS 60) + message(FATAL_ERROR "LLVM version 6.0 or greater is required.") + endif() include_directories(SYSTEM ${LLVM_INCLUDE_DIRS}) add_definitions(${LLVM_DEFINITIONS}) message(STATUS "Build with LLVM " ${LLVM_PACKAGE_VERSION}) diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py index 5d43f4ae24..1a2efd4efa 100644 --- a/python/tvm/target/codegen.py +++ b/python/tvm/target/codegen.py @@ -17,6 +17,7 @@ """Code generation related functions.""" from . import _ffi_api from .target import Target +from ..ir.container import Array def build_module(mod, target): @@ -39,6 +40,30 @@ def build_module(mod, target): return _ffi_api.Build(mod, target) +def target_has_features(cpu_features, target=None): + """Check CPU features for the target's `-mtriple` and `-mcpu` and `-mattr`. + + Parameters + ---------- + target : Target + The TVM target. + cpu_features : str or Array + CPU Feature(s) to check. + + Returns + ------- + has_features : bool + True if target has the feature(s). + """ + assert isinstance(target, Target) or target is None + assert isinstance(cpu_features, (Array, list, tuple, str)) + has_feats = True + cpu_features = [cpu_features] if isinstance(cpu_features, str) else cpu_features + for feat in cpu_features: + has_feats &= _ffi_api.target_has_feature(feat, target) + return has_feats + + def llvm_lookup_intrinsic_id(name): """Lookup LLVM intrinsic id by name. @@ -71,36 +96,76 @@ def llvm_get_intrinsic_name(intrin_id: int) -> str: return _ffi_api.llvm_get_intrinsic_name(intrin_id) -def llvm_x86_get_archlist(only64bit=False): - """Get X86 CPU name list. +def llvm_get_targets(): + """Get LLVM target list. + + Parameters + ---------- + + Returns + ------- + llvm_targets : list[str] + List of available LLVM targets. + """ + return _ffi_api.llvm_get_targets() + + +def llvm_get_cpu_archlist(target=None): + """Get CPU architectures for the target's `-mtriple`. + + Parameters + ---------- + target : Target + The TVM target. + + Returns + ------- + cpu_archlist : list[str] + List of available CPU architectures. + """ + assert isinstance(target, Target) or target is None + return _ffi_api.llvm_get_cpu_archlist(target) + + +def llvm_get_cpu_features(target=None): + """Get CPU features for the target's `-mtriple` and `-mcpu` and considering `-mattr`. Parameters ---------- - only64bit : bool - Filter 64bit architectures. + target : Target + The TVM target. Returns ------- - features : list[str] - String list of X86 architectures. + cpu_features : list[str] + List of available CPU features. """ - return _ffi_api.llvm_x86_get_archlist(only64bit) + assert isinstance(target, Target) or target is None + return _ffi_api.llvm_get_cpu_features(target) -def llvm_x86_get_features(cpu_name): - """Get X86 CPU features. +def llvm_cpu_has_features(cpu_features, target=None): + """Check CPU features for the target's `-mtriple` and `-mcpu` and considering `-mattr`. Parameters ---------- - cpu_name : string - X86 CPU name (e.g. "skylake"). + target : Target + The TVM target. + cpu_features : str or Array + CPU Feature(s) to check. Returns ------- - features : list[str] - String list of X86 CPU features. + has_features : bool + True if target CPU has the feature(s). """ - return _ffi_api.llvm_x86_get_features(cpu_name) + assert isinstance(target, Target) or target is None + assert isinstance(cpu_features, (Array, list, tuple, str)) + has_feats = True + cpu_features = [cpu_features] if isinstance(cpu_features, str) else cpu_features + for feat in cpu_features: + has_feats &= _ffi_api.llvm_cpu_has_feature(feat, target) + return has_feats def llvm_version_major(allow_none=False): diff --git a/python/tvm/target/x86.py b/python/tvm/target/x86.py index a3dcb62e8a..c040eface8 100644 --- a/python/tvm/target/x86.py +++ b/python/tvm/target/x86.py @@ -16,30 +16,7 @@ # under the License. """Common x86 related utilities""" from .._ffi import register_func -from . import _ffi_api -from ..ir.container import Array - - -@register_func("tvm.target.x86.target_has_features") -def target_has_features(features, target=None): - """Check X86 CPU features. - Parameters - ---------- - features : str or Array - Feature(s) to check. - target : Target - Optional TVM target, default `None` use the global context target. - Returns - ------- - has_feats : bool - True if feature(s) are in the target arch. - """ - has_feats = True - assert isinstance(features, (Array, str)) - features = [features] if isinstance(features, str) else features - for feat in features: - has_feats &= _ffi_api.llvm_x86_has_feature(feat, target) - return has_feats +from .codegen import target_has_features @register_func("tvm.topi.x86.utils.get_simd_32bit_lanes") @@ -53,9 +30,6 @@ def get_simd_32bit_lanes(): The optimal vector length of CPU from the global context target. """ vec_len = 4 - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" if target_has_features(["avx512bw", "avx512f"]): vec_len = 16 elif target_has_features("avx2"): diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 85accab87b..e103133230 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -21,7 +21,7 @@ from tvm import autotvm, te from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cblas, mkl -from tvm.target.x86 import target_has_features +from tvm.target.codegen import target_has_features from .. import generic, nn from ..transform import layout_transform @@ -38,9 +38,6 @@ def batch_matmul_int8_compute(cfg, x, y, *_): packed_y = layout_transform(y, "BNK", packed_y_layout) _, n_o, _, n_i, _ = packed_y.shape ak = te.reduce_axis((0, k), name="k") - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" if target_has_features(["avx512bw", "avx512f"]): attrs_info = {"schedule_rule": "batch_matmul_int8"} else: @@ -241,9 +238,6 @@ def _callback(op): layout_trans = op.input_tensors[1] if target_has_features("amx-int8"): batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0], layout_trans) - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" elif target_has_features(["avx512bw", "avx512f"]): batch_matmul_int8_schedule(cfg, s, op.output(0), outs[0], layout_trans) diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 2437b1a695..4151ea0b70 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -23,7 +23,8 @@ from tvm import autotvm, te from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cblas, dnnl, mkl -from tvm.target.x86 import get_simd_32bit_lanes, target_has_features +from tvm.target.x86 import get_simd_32bit_lanes +from tvm.target.codegen import target_has_features from .. import generic, tag from ..utils import get_const_tuple, traverse_inline @@ -303,9 +304,6 @@ def _callback(op): if "dense_int8" in op.tag: if target_has_features("amx-int8"): dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" elif target_has_features(["avx512bw", "avx512f"]): dense_int8_schedule(cfg, s, op.output(0), outs[0]) @@ -318,9 +316,6 @@ def dense_int8_compute(cfg, X, packed_w, bias=None): m, k = X.shape n_o, _, n_i, _ = packed_w.shape ak = te.reduce_axis((0, k), name="k") - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" if target_has_features(["avx512bw", "avx512f"]): target_attr = {"schedule_rule": "meta_schedule.x86.dense_int8"} else: diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index ef6df7dd2c..0e9b1f7b65 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -19,7 +19,7 @@ import tvm from tvm import autotvm, relay, te -from tvm.target.x86 import target_has_features +from tvm.target.codegen import target_has_features from .. import nn from ..nn import dense_alter_layout @@ -28,9 +28,6 @@ def check_int8_applicable(x, y, allow_padding=False): - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" simd_avai = target_has_features(["avx512bw", "avx512f"]) simd_avai |= target_has_features("amx-int8") # TODO(vvchernov): may be also target_has_features("avx2") or lower? diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 4657f962f3..73df303f72 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -24,20 +24,17 @@ namespace meta_schedule { String GetRuleKindFromTarget(const Target& target) { if (target->kind->name == "llvm") { - static const PackedFunc* llvm_x86_has_feature_fn_ptr = - runtime::Registry::Get("target.llvm_x86_has_feature"); - ICHECK(llvm_x86_has_feature_fn_ptr != nullptr) - << "The `target.llvm_x86_has_feature` func is not in tvm registry."; - bool have_avx512vnni = (*llvm_x86_has_feature_fn_ptr)("avx512vnni", target); - bool have_avxvnni = (*llvm_x86_has_feature_fn_ptr)("avxvnni", target); + static const PackedFunc* target_has_feature_fn_ptr = + runtime::Registry::Get("target.target_has_feature"); + ICHECK(target_has_feature_fn_ptr != nullptr) + << "The `target.target_has_feature` func is not in tvm registry."; + bool have_avx512vnni = (*target_has_feature_fn_ptr)("avx512vnni", target); + bool have_avxvnni = (*target_has_feature_fn_ptr)("avxvnni", target); if (have_avx512vnni || have_avxvnni) { return "vnni"; } else { - // avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - // avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - // + llvm.x86.avx512.pmaddw.d.512" - bool have_avx512f = (*llvm_x86_has_feature_fn_ptr)("avx512f", target); - bool have_avx512bw = (*llvm_x86_has_feature_fn_ptr)("avx512bw", target); + bool have_avx512f = (*target_has_feature_fn_ptr)("avx512f", target); + bool have_avx512bw = (*target_has_feature_fn_ptr)("avx512bw", target); if (have_avx512bw && have_avx512f) { return "avx512"; } diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index b57710b266..2dd74e1321 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -121,9 +121,9 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs, } bool has_current_target_sse41_support() { - auto llvm_x86_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.llvm_x86_has_feature"); - ICHECK(llvm_x86_has_feature_fn_ptr) << "Function target.llvm_x86_has_feature not found"; - return (*llvm_x86_has_feature_fn_ptr)("sse4.1", Target::Current(true)); + auto target_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.target_has_feature"); + ICHECK(target_has_feature_fn_ptr) << "Function target.target_has_feature not found"; + return (*target_has_feature_fn_ptr)("sse4.1", Target::Current(true)); } /* diff --git a/src/relay/qnn/op/requantize_config.h b/src/relay/qnn/op/requantize_config.h index 956bc3533b..a4238fa498 100644 --- a/src/relay/qnn/op/requantize_config.h +++ b/src/relay/qnn/op/requantize_config.h @@ -61,10 +61,10 @@ class RequantizeConfigNode : public Object { // For the x86 architecture, the float32 computation is expected to give significant speedup, // with little loss in the accuracy of the requantize operation. auto target = Target::Current(true); - auto llvm_x86_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.llvm_x86_has_feature"); - ICHECK(llvm_x86_has_feature_fn_ptr) << "Function target.llvm_x86_has_feature not found"; + auto target_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.target_has_feature"); + ICHECK(target_has_feature_fn_ptr) << "Function target.target_has_feature not found"; if (target.defined() && target->kind->name == "llvm") { - if ((*llvm_x86_has_feature_fn_ptr)("sse4.1", target)) { + if ((*target_has_feature_fn_ptr)("sse4.1", target)) { return "float32"; } } diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index efe15c5c4a..1872d64d71 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -29,9 +29,7 @@ #if TVM_LLVM_VERSION >= 100 #include #endif -#include #include -#include #include #include @@ -43,38 +41,6 @@ namespace tvm { namespace codegen { -namespace { -bool TargetHasFeature(const llvm::TargetMachine& tm, const std::string& feature) { - // MCSubTargetInfo::checkFeatures was added in LLVM 6.0 -#if TVM_LLVM_VERSION >= 60 - const auto* MCInfo = tm.getMCSubtargetInfo(); - return MCInfo->checkFeatures(std::string("+") + feature); -#else - return false; - // TODO(tulloch) - enable this block, need to figure out how to reimplement - // this given visibility constraints, similar to - // https://github.com/rust-lang/rust/pull/31709 - - // Copied from - // https://github.com/llvm-mirror/llvm/blob/5136df4/lib/MC/MCSubtargetInfo.cpp#L78-L88. - - // auto checkFeatures = [&](const std::string FS) { - // llvm::SubtargetFeatures T(FS); - // llvm::FeatureBitset Set, All; - // for (std::string F : T.getFeatures()) { - // llvm::SubtargetFeatures::ApplyFeatureFlag(Set, F, MCInfo->ProcFeatures); - // if (F[0] == '-') { - // F[0] = '+'; - // } - // llvm::SubtargetFeatures::ApplyFeatureFlag(All, F, MCInfo->ProcFeatures); - // } - // return (MCInfo->getFeatureBits() & All) == Set; - // }; - // return checkFeatures(MCInfo, std::string("+") + feature); -#endif -} -} // namespace - class CodeGenX86_64 final : public CodeGenCPU { public: llvm::Value* VisitExpr_(const CastNode* op) override; @@ -92,9 +58,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { const auto to = op->dtype; if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) { ICHECK_EQ(from.lanes(), to.lanes()); - llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine(); - const auto has_avx512 = TargetHasFeature(*tm, "avx512f"); + const auto has_avx512 = llvm_target_->TargetHasCPUFeature("avx512f"); if (from.lanes() >= 16 && has_avx512) { return CallVectorIntrin( @@ -111,7 +76,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { #if TVM_LLVM_VERSION <= 100 // The intrinsic x86_vcvtph2ps_256 was removed in LLVM 11. - const auto has_f16c = TargetHasFeature(*tm, "f16c"); + const auto has_f16c = llvm_target_->TargetHasCPUFeature("f16c"); if (from.lanes() >= 8 && has_f16c) { return CallVectorIntrin(llvm::Intrinsic::x86_vcvtph2ps_256, 8, diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index 2aa190ad70..e270a9b66c 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -33,6 +33,7 @@ #include #include #include +#include #if TVM_LLVM_VERSION >= 140 #include #else @@ -66,6 +67,27 @@ #include #include +#if TVM_LLVM_VERSION < 180 +namespace llvm { +#if TVM_LLVM_VERSION < 170 +// SubtargetSubTypeKV view +template MCSubtargetInfo::*Member> +struct ArchViewer { + friend ArrayRef& archViewer(MCSubtargetInfo Obj) { return Obj.*Member; } +}; +template struct ArchViewer<&MCSubtargetInfo::ProcDesc>; +ArrayRef& archViewer(MCSubtargetInfo); +#endif +// SubtargetFeatureKV view +template MCSubtargetInfo::*Member> +struct FeatViewer { + friend ArrayRef& featViewer(MCSubtargetInfo Obj) { return Obj.*Member; } +}; +template struct FeatViewer<&MCSubtargetInfo::ProcFeatures>; +ArrayRef& featViewer(MCSubtargetInfo); +} // namespace llvm +#endif + namespace tvm { namespace codegen { @@ -175,6 +197,17 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) { attrs_.push_back(s); } } + // llvm module target + if (target->kind->name == "llvm") { + // legalize -mcpu with the target -mtriple + auto arches = GetAllLLVMTargetArches(); + bool has_arch = + std::any_of(arches.begin(), arches.end(), [&](const auto& var) { return var == cpu_; }); + if (!has_arch) { + LOG(FATAL) << "LLVM cpu architecture `-mcpu=" << cpu_ + << "` is not valid in `-mtriple=" << triple_ << "`"; + } + } if (const Optional>& v = target->GetAttr>("cl-opt")) { llvm::StringMap& options = llvm::cl::getRegisteredOptions(); @@ -288,19 +321,59 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& scope, const std::string& target_st LLVMTargetInfo::~LLVMTargetInfo() = default; +static const llvm::Target* CreateLLVMTargetInstance(const std::string triple, + const bool allow_missing = true) { + std::string error; + // create LLVM instance + // required mimimum: llvm::InitializeAllTargets() + const llvm::Target* llvm_instance = llvm::TargetRegistry::lookupTarget(triple, error); + if (!allow_missing && !llvm_instance) { + ICHECK(llvm_instance) << "LLVM instance error: `" << error << "`"; + } + + return llvm_instance; +} + +static llvm::TargetMachine* CreateLLVMTargetMachine( + const llvm::Target* llvm_instance, const std::string& triple, const std::string& cpu, + const std::string& features, const llvm::TargetOptions& target_options, + const llvm::Reloc::Model& reloc_model, const llvm::CodeModel::Model& code_model, + const llvm::CodeGenOpt::Level& opt_level) { + llvm::TargetMachine* tm = llvm_instance->createTargetMachine( + triple, cpu, features, target_options, reloc_model, code_model, opt_level); + ICHECK(tm != nullptr); + + return tm; +} + +static const llvm::MCSubtargetInfo* GetLLVMSubtargetInfo(const std::string& triple, + const std::string& cpu_name, + const std::string& feats) { + // create a LLVM instance + auto llvm_instance = CreateLLVMTargetInstance(triple, true); + // create a target machine + // required minimum: llvm::InitializeAllTargetMCs() + llvm::TargetOptions target_options; + auto tm = CreateLLVMTargetMachine(llvm_instance, triple, cpu_name, feats, target_options, + llvm::Reloc::Static, llvm::CodeModel::Small, + llvm::CodeGenOpt::Level(0)); + // create subtarget info module + const llvm::MCSubtargetInfo* MCInfo = tm->getMCSubtargetInfo(); + + return MCInfo; +} + llvm::TargetMachine* LLVMTargetInfo::GetOrCreateTargetMachine(bool allow_missing) { if (target_machine_) return target_machine_.get(); std::string error; - if (const llvm::Target* llvm_instance = llvm::TargetRegistry::lookupTarget(triple_, error)) { + if (const llvm::Target* llvm_instance = CreateLLVMTargetInstance(triple_, allow_missing)) { llvm::TargetMachine* tm = - llvm_instance->createTargetMachine(triple_, cpu_, GetTargetFeatureString(), target_options_, - reloc_model_, code_model_, opt_level_); + CreateLLVMTargetMachine(llvm_instance, triple_, cpu_, GetTargetFeatureString(), + target_options_, reloc_model_, code_model_, opt_level_); target_machine_ = std::unique_ptr(tm); } - if (!allow_missing) { - ICHECK(target_machine_ != nullptr) << error; - } + ICHECK(target_machine_ != nullptr); return target_machine_.get(); } @@ -662,6 +735,75 @@ void LLVMTargetInfo::GetOptionValue(LLVMTargetInfo::Option* opt) const { } } +const Array LLVMTargetInfo::GetAllLLVMTargets() const { + Array llvm_targets; + // iterate all archtypes + for (auto a = llvm::Triple::ArchType(llvm::Triple::ArchType::UnknownArch + 1); + a < llvm::Triple::ArchType::LastArchType; a = llvm::Triple::ArchType(a + 1)) { + std::string target_name = llvm::Triple::getArchTypeName(a).str(); + // get valid target + if (CreateLLVMTargetInstance(target_name + "--", true)) { + llvm_targets.push_back(target_name); + } + } + + return llvm_targets; +} + +const Array LLVMTargetInfo::GetAllLLVMTargetArches() const { + Array cpu_arches; + // get the subtarget info module + const auto MCInfo = GetLLVMSubtargetInfo(triple_, "", ""); + if (!MCInfo) { + return cpu_arches; + } + // get all arches + llvm::ArrayRef llvm_arches = +#if TVM_LLVM_VERSION < 170 + llvm::archViewer(*(llvm::MCSubtargetInfo*)MCInfo); +#else + MCInfo->getAllProcessorDescriptions(); +#endif + for (const auto& arch : llvm_arches) { + cpu_arches.push_back(arch.Key); + } + + return cpu_arches; +} + +const Array LLVMTargetInfo::GetAllLLVMCpuFeatures() const { + std::string feats = ""; + for (const auto& attr : attrs_) { + feats += feats.empty() ? attr : ("," + attr); + } + // get the subtarget info module + const auto MCInfo = GetLLVMSubtargetInfo(triple_, cpu_.c_str(), feats); + // get all features for CPU + llvm::ArrayRef llvm_features = +#if TVM_LLVM_VERSION < 180 + llvm::featViewer(*(llvm::MCSubtargetInfo*)MCInfo); +#else + MCInfo->getAllProcessorFeatures(); +#endif + Array cpu_features; + for (const auto& feat : llvm_features) { + if (MCInfo->checkFeatures("+" + std::string(feat.Key))) { + cpu_features.push_back(feat.Key); + } + } + + return cpu_features; +} + +const bool LLVMTargetInfo::TargetHasCPUFeature(const std::string& feature) const { + // lookup features for `-mcpu` + auto feats = GetAllLLVMCpuFeatures(); + bool has_feature = + std::any_of(feats.begin(), feats.end(), [&](const auto& var) { return var == feature; }); + + return has_feature; +} + // LLVMTarget bool LLVMTarget::modified_llvm_state_ = false; diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h index 217db63aad..ac08008b80 100644 --- a/src/target/llvm/llvm_instance.h +++ b/src/target/llvm/llvm_instance.h @@ -266,6 +266,36 @@ class LLVMTargetInfo { */ bool MatchesGlobalState() const; + /*! + * \brief Get all supported targets from the LLVM backend + * \return list with all valid targets + */ + const Array GetAllLLVMTargets() const; + + /*! + * \brief Get all CPU arches from target + * \return list with all valid cpu architectures + * \note The arches are fetched from the LLVM backend using the target `-mtriple`. + */ + const Array GetAllLLVMTargetArches() const; + + /*! + * \brief Get all CPU features from target + * \return list with all valid cpu features + * \note The features are fetched from the LLVM backend using the target `-mtriple` + * and the `-mcpu` architecture, but also consider the `-mattr` attributes. + */ + const Array GetAllLLVMCpuFeatures() const; + + /*! + * \brief Check the target if has a specific cpu feature + * \param feature string with the feature to check + * \return true or false + * \note The feature is checked in the LLVM backend for the target `-mtriple` + * and `-mcpu` architecture, but also consider the `-mattr` attributes. + */ + const bool TargetHasCPUFeature(const std::string& feature) const; + protected: /*! * \brief Get the current value of given LLVM option diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 168163c416..05a7df230f 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -45,16 +45,6 @@ #include #include #include -#if TVM_LLVM_VERSION < 110 -#include -#include -#else -#if TVM_LLVM_VERSION < 170 -#include -#else -#include -#endif -#endif #include #include #include @@ -87,25 +77,6 @@ #include "codegen_llvm.h" #include "llvm_instance.h" -#if TVM_LLVM_VERSION < 110 -namespace llvm { -// SubtargetSubTypeKV view -template MCSubtargetInfo::*Member> -struct ArchViewer { - friend ArrayRef& archViewer(MCSubtargetInfo Obj) { return Obj.*Member; } -}; -template struct ArchViewer<&MCSubtargetInfo::ProcDesc>; -ArrayRef& archViewer(MCSubtargetInfo); -// SubtargetFeatureKV view -template MCSubtargetInfo::*Member> -struct FeatViewer { - friend ArrayRef& featViewer(MCSubtargetInfo Obj) { return Obj.*Member; } -}; -template struct FeatViewer<&MCSubtargetInfo::ProcFeatures>; -ArrayRef& featViewer(MCSubtargetInfo); -} // namespace llvm -#endif - namespace tvm { namespace codegen { @@ -514,131 +485,69 @@ TVM_REGISTER_GLOBAL("target.llvm_get_intrinsic_name").set_body_typed([](int64_t #endif }); -#if TVM_LLVM_VERSION < 110 -static const llvm::MCSubtargetInfo* llvm_compat_get_subtargetinfo(const std::string triple, - const std::string cpu_name) { - std::string error; - llvm::InitializeAllTargets(); - llvm::InitializeAllTargetMCs(); - // create a LLVM x86 instance - auto* llvm_instance = llvm::TargetRegistry::lookupTarget(triple, error); - // create a target machine - llvm::TargetOptions target_options; - auto RM = llvm::Optional(); - auto* tm = llvm_instance->createTargetMachine(triple, cpu_name.c_str(), "", target_options, RM); - // create subtarget info module - const llvm::MCSubtargetInfo* MCInfo = tm->getMCSubtargetInfo(); - - return MCInfo; -} - -static const Array llvm_compat_get_archlist(const std::string triple) { - // get the subtarget info module - const auto* MCInfo = llvm_compat_get_subtargetinfo(triple, ""); - // get all X86 arches - llvm::ArrayRef x86_arches = - llvm::archViewer(*(llvm::MCSubtargetInfo*)MCInfo); - Array cpu_arches; - for (auto& arch : x86_arches) { - cpu_arches.push_back(arch.Key); - } - return cpu_arches; -} - -static const Array llvm_compat_get_features(const std::string triple, - const std::string cpu_name) { - // get the subtarget info module - const auto* MCInfo = llvm_compat_get_subtargetinfo(triple, cpu_name.c_str()); - // get all features - llvm::ArrayRef x86_features = - llvm::featViewer(*(llvm::MCSubtargetInfo*)MCInfo); - // only targeted CPU features - Array cpu_features; - for (auto& feat : x86_features) { - if (MCInfo->checkFeatures("+" + std::string(feat.Key))) { - cpu_features.push_back(feat.Key); - } - } - return cpu_features; -} -#endif +TVM_REGISTER_GLOBAL("target.llvm_get_targets").set_body_typed([]() -> Array { + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, "llvm"); + return llvm_backend.GetAllLLVMTargets(); +}); -TVM_REGISTER_GLOBAL("target.llvm_x86_get_archlist") - .set_body_typed([](bool only64bit) -> Array { - Array cpu_arches; -#if TVM_LLVM_VERSION < 110 - cpu_arches = llvm_compat_get_archlist("x86_64--"); -#else - llvm::SmallVector x86_arches; - llvm::X86::fillValidCPUArchList(x86_arches, only64bit); - for (auto& arch : x86_arches) { - cpu_arches.push_back(arch.str()); +TVM_REGISTER_GLOBAL("target.llvm_get_cpu_archlist") + .set_body_typed([](const Target& target) -> Array { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return Array{}; + } } -#endif - return cpu_arches; + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, use_target); + return llvm_backend.GetAllLLVMTargetArches(); }); -TVM_REGISTER_GLOBAL("target.llvm_x86_get_features") - .set_body_typed([](std::string cpu_name) -> Array { - Array cpu_features; -#if TVM_LLVM_VERSION < 110 - cpu_features = llvm_compat_get_features("x86_64--", cpu_name); -#else - llvm::SmallVector x86_features; - llvm::X86::getFeaturesForCPU(cpu_name, x86_features); - for (auto& feat : x86_features) { - cpu_features.push_back(feat.str()); +TVM_REGISTER_GLOBAL("target.llvm_get_cpu_features") + .set_body_typed([](const Target& target) -> Array { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return Array{}; + } } -#endif - return cpu_features; + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, use_target); + return llvm_backend.GetAllLLVMCpuFeatures(); }); -TVM_REGISTER_GLOBAL("target.llvm_x86_has_feature") - .set_body_typed([](String feature, const Target& target) -> bool { - // target argument is optional (nullptr or None) - // if not explicit then use the current context target - Optional mcpu = target.defined() ? target->GetAttr("mcpu") - : Target::Current(false)->GetAttr("mcpu"); - Optional> mattr = target.defined() - ? target->GetAttr>("mattr") - : Target::Current(false)->GetAttr>("mattr"); - String name = target.defined() ? target->kind->name : Target::Current(false)->kind->name; - // lookup only for `llvm` targets having -mcpu - if ((name != "llvm") || !mcpu) { - return false; - } - // lookup in -mattr flags - bool is_in_mattr = - !mattr ? false - : std::any_of(mattr.value().begin(), mattr.value().end(), - [&](const String& var) { return var == ("+" + feature); }); -#if TVM_LLVM_VERSION < 110 - auto x86_arches = llvm_compat_get_archlist("x86_64--"); - // decline on invalid arch (avoid llvm assertion) - if (!std::any_of(x86_arches.begin(), x86_arches.end(), - [&](const String& var) { return var == mcpu.value(); })) { - return false; +TVM_REGISTER_GLOBAL("target.llvm_cpu_has_feature") + .set_body_typed([](const String feature, const Target& target) -> bool { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return false; + } } - // lookup in -mcpu llvm architecture flags - auto cpu_features = llvm_compat_get_features("x86_64--", mcpu.value()); + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, use_target); + auto cpu_features = llvm_backend.GetAllLLVMCpuFeatures(); bool has_feature = std::any_of(cpu_features.begin(), cpu_features.end(), - [&](const String& var) { return var == feature; }); -#else - llvm::SmallVector x86_arches; - llvm::X86::fillValidCPUArchList(x86_arches, false); - // decline on invalid arch (avoid llvm assertion) - if (!std::any_of(x86_arches.begin(), x86_arches.end(), - [&](const llvm::StringRef& var) { return var == mcpu.value().c_str(); })) { - return false; + [&](auto& var) { return var == feature; }); + return has_feature; + }); + +TVM_REGISTER_GLOBAL("target.target_has_feature") + .set_body_typed([](const String feature, const Target& target) -> bool { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return false; + } } - // lookup in -mcpu llvm architecture flags - llvm::SmallVector x86_features; - llvm::X86::getFeaturesForCPU(mcpu.value().c_str(), x86_features); - bool has_feature = - std::any_of(x86_features.begin(), x86_features.end(), - [&](const llvm::StringRef& var) { return var == feature.c_str(); }); -#endif - return has_feature || is_in_mattr; + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_target(*llvm_instance, use_target); + return llvm_target.TargetHasCPUFeature(feature); }); TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int { diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 0a0ae561ab..bd984d32e6 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1837,7 +1837,10 @@ def test_depthwise_conv2d_int8(): wdata = np.random.rand(*kernel_shape) * 10 parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))} - targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"] + targets = [ + "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512", + "llvm -mtriple=x86_64-linux-gnu -mcpu=cascadelake", + ] llvm_version = tvm.target.codegen.llvm_version_major() for target in targets: if llvm_version >= 8: diff --git a/tests/python/relay/test_op_qnn_conv2_transpose.py b/tests/python/relay/test_op_qnn_conv2_transpose.py index ec273eb2f7..18ad68df9e 100644 --- a/tests/python/relay/test_op_qnn_conv2_transpose.py +++ b/tests/python/relay/test_op_qnn_conv2_transpose.py @@ -644,7 +644,7 @@ def test_broadcast_layout(): func = relay.Function(relay.analysis.free_vars(func), func) mod = tvm.IRModule.from_expr(func) with tvm.transform.PassContext(opt_level=3): - libs = relay.build(mod, "llvm -mcpu=skylake-avx512") + libs = relay.build(mod, "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512") def test_non_scalar_input_scale_zp(): diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index 3736350cbf..e7d2c8941b 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -948,7 +948,9 @@ def test_broadcast_layout(): func = relay.Function(relay.analysis.free_vars(func), func) mod = tvm.IRModule.from_expr(func) with tvm.transform.PassContext(opt_level=3): - graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512") + graph, lib, params = relay.build( + mod, "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512" + ) def test_depthwise_depth_multiplier(): diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 829c1d6ae4..87065b2d27 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -886,7 +886,7 @@ def before(): from tvm import topi def alter_conv2d(attrs, inputs, tinfos, out_type): - with tvm.target.Target("llvm -mcpu=core-avx2"): + with tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=core-avx2"): return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type) def expected(): @@ -1373,7 +1373,7 @@ def expected(): y = relay.Function(analysis.free_vars(y), y) return y - target = "llvm -mcpu=core-avx2" + target = "llvm -mtriple=x86_64-linux-gnu -mcpu=core-avx2" with tvm.target.Target(target): with TempOpAttr( "nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout @@ -1441,7 +1441,7 @@ def expected(): ) return relay.Function(analysis.free_vars(dense), dense) - with tvm.target.Target("llvm -mcpu=core-avx2"): + with tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=core-avx2"): with TempOpAttr( "nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout ): diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index 73ba9c2208..a32100ea20 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -138,7 +138,10 @@ def _get_mod(data_dtype, kernel_dtype): # Check transformations for platforms with fast Int8 support. ############################################################# # Check that Intel AVX512 (with or w/o VNNI) gets picked up. - for target in ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]: + for target in [ + "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512", + "llvm -mtriple=x86_64-linux-gnu -mcpu=cascadelake", + ]: with tvm.target.Target(target): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) @@ -170,7 +173,7 @@ def _get_mod(data_dtype, kernel_dtype): # Check transformations for platforms with fast Int8 support. ############################################################# # Check no transformation for Intel AVX512. - with tvm.target.Target("llvm -mcpu=skylake-avx512"): + with tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512"): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) assert tvm.ir.structural_equal(mod, legalized_mod) @@ -232,7 +235,10 @@ def _get_mod(data_dtype, kernel_dtype): # Check transformations for platforms with fast Int8 support. ############################################################# # Check that Intel AVX512 (with or w/o VNNI) gets picked up. - for target in ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]: + for target in [ + "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512", + "llvm -mtriple=x86_64-linux-gnu -mcpu=cascadelake", + ]: with tvm.target.Target(target): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) @@ -264,7 +270,7 @@ def _get_mod(data_dtype, kernel_dtype): # Check transformations for platforms with fast Int8 support. ############################################################# # Check no transformation for Intel AVX512. - with tvm.target.Target("llvm -mcpu=skylake-avx512"): + with tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512"): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) assert tvm.ir.structural_equal(mod, legalized_mod) diff --git a/tests/python/target/test_llvm_features_info.py b/tests/python/target/test_llvm_features_info.py new file mode 100644 index 0000000000..1be71331dd --- /dev/null +++ b/tests/python/target/test_llvm_features_info.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm.target import _ffi_api, codegen, Target + +LLVM_VERSION = codegen.llvm_version_major() + + +def test_llvm_targets(): + + ## + ## check LLVM backend + ## + + # check blank results + assert len(codegen.llvm_get_targets()) + # check ffi vs python + assert str(codegen.llvm_get_targets()) == str(_ffi_api.llvm_get_targets()) + + # check LLVM target -mcpu legality + try: + tvm.target.codegen.llvm_get_cpu_features( + tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=dummy") + ) + assert False + except tvm.error.TVMError as e: + msg = str(e) + assert ( + msg.find( + "TVMError: LLVM cpu architecture `-mcpu=dummy` is not valid in `-mtriple=x86_64-linux-gnu`" + ) + != -1 + ) + + +min_llvm_version, llvm_target, cpu_arch, cpu_features, is_supported = tvm.testing.parameters( + (-1, "x86_64", "sandybridge", "sse4.1", True), + (-1, "x86_64", "ivybridge", ["sse4.1", "ssse3"], True), + (-1, "x86_64", "ivybridge", ["sse4.1", "ssse3", "avx512bw"], False), + # 32bit vs 64bit + (-1, "aarch64", "cortex-a55", "neon", True), + (-1, "aarch64", "cortex-a55", "dotprod", True), + (-1, "aarch64", "cortex-a55", "dsp", False), + (-1, "arm", "cortex-a55", "dsp", True), + (-1, "aarch64", "cortex-a55", ["neon", "dotprod"], True), + (-1, "aarch64", "cortex-a55", ["neon", "dotprod", "dsp"], False), + (-1, "arm", "cortex-a55", ["neon", "dotprod"], True), + (-1, "aarch64", "cortex-a55", ["neon", "dotprod", "dsp"], False), + (-1, "arm", "cortex-a55", ["neon", "dotprod", "dsp"], True), +) + + +def test_target_features(min_llvm_version, llvm_target, cpu_arch, cpu_features, is_supported): + + target = Target("llvm -mtriple=%s-- -mcpu=%s" % (llvm_target, cpu_arch)) + + ## + ## legalize llvm_target + ## + + assert llvm_target in codegen.llvm_get_targets() + + ## + ## legalize cpu_arch + ## + + ### with context + with target: + assert cpu_arch in codegen.llvm_get_cpu_archlist() + ### no context but with expicit target + assert cpu_arch in codegen.llvm_get_cpu_archlist(target) + # check ffi vs python + assert str(codegen.llvm_get_cpu_archlist(target)) == str(_ffi_api.llvm_get_cpu_archlist(target)) + + ## + ## check has_features + ## + + ### with context + with target: + assert codegen.llvm_cpu_has_features(cpu_features) == is_supported + ### no context but with expicit target + assert codegen.llvm_cpu_has_features(cpu_features, target) == is_supported + # check ffi vs python + for feat in cpu_features: + assert str(codegen.llvm_cpu_has_features(feat, target)) == str( + _ffi_api.llvm_cpu_has_feature(feat, target) + ) diff --git a/tests/python/target/test_x86_features.py b/tests/python/target/test_x86_features.py index 31a823b504..ef1ab9b423 100644 --- a/tests/python/target/test_x86_features.py +++ b/tests/python/target/test_x86_features.py @@ -18,85 +18,89 @@ import tvm from tvm.target import _ffi_api, codegen, Target -from tvm.target.x86 import target_has_features +from tvm.target.codegen import target_has_features LLVM_VERSION = codegen.llvm_version_major() min_llvm_version, tvm_target, x86_feature, is_supported = tvm.testing.parameters( # sse4.1 - (-1, "llvm -mcpu=btver2", "sse4a", True), - (-1, "llvm -mcpu=penryn", "sse4.1", True), - (-1, "llvm -mcpu=silvermont", "sse4.2", True), - (11, "llvm -mcpu=slm", "sse4.2", True), - (-1, "llvm -mcpu=goldmont", "sse4.2", True), - (-1, "llvm -mcpu=goldmont-plus", "sse4.2", True), - (-1, "llvm -mcpu=tremont", "sse4.2", True), - (-1, "llvm -mcpu=nehalem", "sse4.2", True), - (11, "llvm -mcpu=corei7", "sse4.2", True), - (-1, "llvm -mcpu=westmere", "sse4.2", True), - (-1, "llvm -mcpu=bdver1", "sse4.2", True), - (-1, "llvm -mcpu=bdver2", "sse4.2", True), - (-1, "llvm -mcpu=bdver3", "sse4.2", True), - (11, "llvm -mcpu=x86-64-v2", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=btver2", "sse4a", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=penryn", "sse4.1", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=silvermont", "sse4.2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=slm", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=goldmont", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=goldmont-plus", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=tremont", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=nehalem", "sse4.2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=corei7", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=westmere", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=bdver1", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=bdver2", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=bdver3", "sse4.2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=x86-64-v2", "sse4.2", True), # avx - (-1, "llvm -mcpu=sandybridge", "avx", True), - (11, "llvm -mcpu=corei7-avx", "avx", True), - (-1, "llvm -mcpu=ivybridge", "avx", True), - (11, "llvm -mcpu=core-avx-i", "avx", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=sandybridge", "avx", True), + (11, "llvm -mtriple=x86_64-- -mcpu=corei7-avx", "avx", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=ivybridge", "avx", True), + (11, "llvm -mtriple=x86_64-- -mcpu=core-avx-i", "avx", True), # avx2 - (-1, "llvm -mcpu=haswell", "avx2", True), - (11, "llvm -mcpu=core-avx2", "avx2", True), - (-1, "llvm -mcpu=broadwell", "avx2", True), - (-1, "llvm -mcpu=skylake", "avx2", True), - (-1, "llvm -mcpu=bdver4", "avx2", True), - (-1, "llvm -mcpu=znver1", "avx2", True), - (-1, "llvm -mcpu=znver2", "avx2", True), - (11, "llvm -mcpu=znver3", "avx2", True), - (11, "llvm -mcpu=x86-64-v3", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=haswell", "avx2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=core-avx2", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=broadwell", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=skylake", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=bdver4", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=znver1", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=znver2", "avx2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=znver3", "avx2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=x86-64-v3", "avx2", True), # avx512bw - (-1, "llvm -mcpu=skylake-avx512", "avx512bw", True), - (11, "llvm -mcpu=skx", "avx512bw", True), - (11, "llvm -mcpu=knl", "avx512bw", False), - (-1, "llvm -mcpu=knl", "avx512f", True), - (11, "llvm -mcpu=knl", ["avx512bw", "avx512f"], False), - (11, "llvm -mcpu=knl", ("avx512bw", "avx512f"), False), - (-1, "llvm -mcpu=knl", "avx512cd", True), - (11, "llvm -mcpu=knl", ["avx512cd", "avx512f"], True), - (11, "llvm -mcpu=knl", ("avx512cd", "avx512f"), True), - (-1, "llvm -mcpu=knl", "avx512er", True), - (-1, "llvm -mcpu=knl", "avx512pf", True), - (11, "llvm -mcpu=knm", "avx512bw", False), - (-1, "llvm -mcpu=knm", "avx512f", True), - (-1, "llvm -mcpu=knm", "avx512cd", True), - (-1, "llvm -mcpu=knm", "avx512er", True), - (-1, "llvm -mcpu=knm", "avx512pf", True), - (11, "llvm -mcpu=x86-64-v4", "avx512bw", True), - (-1, "llvm -mcpu=cannonlake", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=skylake-avx512", "avx512bw", True), + (11, "llvm -mtriple=x86_64-- -mcpu=skx", "avx512bw", True), + (11, "llvm -mtriple=x86_64-- -mcpu=knl", "avx512bw", False), + (-1, "llvm -mtriple=x86_64-- -mcpu=knl", "avx512f", True), + (11, "llvm -mtriple=x86_64-- -mcpu=knl", ["avx512bw", "avx512f"], False), + (11, "llvm -mtriple=x86_64-- -mcpu=knl", ("avx512bw", "avx512f"), False), + (-1, "llvm -mtriple=x86_64-- -mcpu=knl", "avx512cd", True), + (11, "llvm -mtriple=x86_64-- -mcpu=knl", ["avx512cd", "avx512f"], True), + (11, "llvm -mtriple=x86_64-- -mcpu=knl", ("avx512cd", "avx512f"), True), + (-1, "llvm -mtriple=x86_64-- -mcpu=knl", "avx512er", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=knl", "avx512pf", True), + (11, "llvm -mtriple=x86_64-- -mcpu=knm", "avx512bw", False), + (-1, "llvm -mtriple=x86_64-- -mcpu=knm", "avx512f", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=knm", "avx512cd", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=knm", "avx512er", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=knm", "avx512pf", True), + (11, "llvm -mtriple=x86_64-- -mcpu=x86-64-v4", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=cannonlake", "avx512bw", True), # explicit enumeration of VNNI capable due to collision with alderlake - (11, "llvm -mcpu=alderlake", "avx512bw", False), - (-1, "llvm -mcpu=cascadelake", "avx512bw", True), - (-1, "llvm -mcpu=icelake-client", "avx512bw", True), - (-1, "llvm -mcpu=icelake-server", "avx512bw", True), - (11, "llvm -mcpu=rocketlake", "avx512bw", True), - (-1, "llvm -mcpu=tigerlake", "avx512bw", True), - (-1, "llvm -mcpu=cooperlake", "avx512bw", True), - (11, "llvm -mcpu=sapphirerapids", "avx512bw", True), + (11, "llvm -mtriple=x86_64-- -mcpu=alderlake", "avx512bw", False), + (-1, "llvm -mtriple=x86_64-- -mcpu=cascadelake", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=icelake-client", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=icelake-server", "avx512bw", True), + (11, "llvm -mtriple=x86_64-- -mcpu=rocketlake", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=tigerlake", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=cooperlake", "avx512bw", True), + (11, "llvm -mtriple=x86_64-- -mcpu=sapphirerapids", "avx512bw", True), # avx512vnni - (11, "llvm -mcpu=alderlake", "avx512vnni", False), - (11, "llvm -mcpu=alderlake", "avxvnni", True), - (-1, "llvm -mcpu=cascadelake", "avx512vnni", True), - (-1, "llvm -mcpu=icelake-client", "avx512vnni", True), - (-1, "llvm -mcpu=icelake-server", "avx512vnni", True), - (11, "llvm -mcpu=rocketlake", "avx512vnni", True), - (-1, "llvm -mcpu=tigerlake", "avx512vnni", True), - (-1, "llvm -mcpu=cooperlake", "avx512vnni", True), - (11, "llvm -mcpu=sapphirerapids", "avx512vnni", True), + (11, "llvm -mtriple=x86_64-- -mcpu=alderlake", "avx512vnni", False), + (11, "llvm -mtriple=x86_64-- -mcpu=alderlake", "avxvnni", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=cascadelake", "avx512vnni", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=icelake-client", "avx512vnni", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=icelake-server", "avx512vnni", True), + (11, "llvm -mtriple=x86_64-- -mcpu=rocketlake", "avx512vnni", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=tigerlake", "avx512vnni", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=cooperlake", "avx512vnni", True), + (11, "llvm -mtriple=x86_64-- -mcpu=sapphirerapids", "avx512vnni", True), # amx-int8 - (11, "llvm -mcpu=sapphirerapids", "amx-int8", True), + (11, "llvm -mtriple=x86_64-- -mcpu=sapphirerapids", "amx-int8", True), # generic CPU (no features) but with extra -mattr - (-1, "llvm -mcpu=x86-64 -mattr=+sse4.1,+avx2", "avx2", True), - (-1, "llvm -mcpu=x86-64 -mattr=+sse4.1,+avx2", "sse4.1", True), - (-1, "llvm -mcpu=x86-64 -mattr=+sse4.1,+avx2", "ssse3", False), + (-1, "llvm -mtriple=x86_64-- -mcpu=x86-64 -mattr=+sse4.1,+avx2", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=x86-64 -mattr=+sse4.1,+avx2", "sse4.1", True), + # enabling +sse4.1 implies ssse3 presence in LLVM + (-1, "llvm -mtriple=x86_64-- -mcpu=x86-64 -mattr=+sse4.1,+avx2", "ssse3", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=ivybridge -mattr=-ssse3", "ssse3", False), + # disabling avx512f (foundation) also disables avx512bw + (-1, "llvm -mtriple=x86_64-- -mcpu=cascadelake -mattr=-avx512f", "avx512bw", False), ) @@ -135,7 +139,7 @@ def test_x86_target_features(min_llvm_version, tvm_target, x86_feature, is_suppo if isinstance(x86_feature, str): # check for feature via the ffi llvm api (no explicit target, no context target) try: - assert _ffi_api.llvm_x86_has_feature(x86_feature, None) == is_supported + assert _ffi_api.target_has_feature(x86_feature, None) == is_supported assert False except tvm.error.InternalError as e: msg = str(e) @@ -154,7 +158,7 @@ def test_x86_target_features(min_llvm_version, tvm_target, x86_feature, is_suppo assert target_has_features(x86_feature, Target(tvm_target)) == is_supported if isinstance(x86_feature, str): # check for feature via the ffi llvm api (with explicit target, no context target) - assert _ffi_api.llvm_x86_has_feature(x86_feature, Target(tvm_target)) == is_supported + assert _ffi_api.target_has_feature(x86_feature, Target(tvm_target)) == is_supported ## ## with context @@ -166,11 +170,8 @@ def test_x86_target_features(min_llvm_version, tvm_target, x86_feature, is_suppo assert target_has_features(x86_feature) == is_supported # check for feature via the python api (with explicit target) assert target_has_features(x86_feature, Target(tvm_target)) == is_supported - if isinstance(x86_feature, str): - # check for feature via the ffi llvm api (current context target) - assert _ffi_api.llvm_x86_has_feature(x86_feature, None) == is_supported - # check for feature via the ffi llvm api (with explicit target) - assert _ffi_api.llvm_x86_has_feature(x86_feature, Target(tvm_target)) == is_supported - # check for feature in target's llvm full x86 CPU feature list - if not Target(tvm_target).mattr: - assert (x86_feature in codegen.llvm_x86_get_features(mcpu)) == is_supported + # check for feature via the ffi llvm api (current context target) + (sum(_ffi_api.target_has_feature(feat, None) for feat in x86_feature) > 0) == is_supported + # check for feature in target's llvm full x86 CPU feature list + if (not Target(tvm_target).mattr) and isinstance(x86_feature, str): + assert (x86_feature in codegen.llvm_get_cpu_features()) == is_supported diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 6a2f5573b2..f1316ae3ce 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -946,7 +946,7 @@ def test_llvm_target_attributes(): xo, xi = s[C].split(C.op.axis[0], nparts=2) s[C].parallel(xo) - target_llvm = "llvm -mcpu=skylake -mattr=+avx512f" + target_llvm = "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake -mattr=+avx512f" target = tvm.target.Target(target_llvm, host=target_llvm) module = tvm.build(s, [A, B, C, n], target=target, name="test_func") From 0d683284b0a5596369e2a7acf8179d02aa89d893 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 27 Sep 2023 14:51:06 -0500 Subject: [PATCH 07/13] [Unittest][Metal] Add minimal metal functionality test to CI (#15756) * [Unittest][Metal] Add minimal metal functionality test to CI Prior to this commit, the CI compiled TVM with `USE_METAL=ON` on OSX, as defined in `conda/recipe/build.sh`, but did not validate the execution of any generated metal kernels. As a result, breakage could occur without being caught by the CI, such as found following https://github.com/apache/tvm/pull/15103. This commit adds the execution of a single metal kernel as a minimal functionality test of the metal backend. * CI testing, attempt a compile-only test case * CI testing, moved intentional failure from test-case to contrib.xcode * Move intentional failure point into codegen * ci bump * Removing the intentional failure during metallib compilation --- .github/workflows/main.yml | 8 +++++ tests/python/unittest/test_allreduce.py | 44 +++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index a4a30fe19a..7b4b8d826f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -69,6 +69,14 @@ jobs: shell: bash -l {0} run: >- python -m pytest -v tests/python/all-platform-minimal-test + - name: Minimal Metal Compile-Only + shell: bash -l {0} + run: >- + python -m pytest -v -s 'tests/python/unittest/test_allreduce.py::test_allreduce_sum_compile' + - name: Minimal Metal Compile-and-Run + shell: bash -l {0} + run: >- + python -m pytest -v -s 'tests/python/unittest/test_allreduce.py::test_allreduce_sum[dims0-metal]' - name: Test iOS RPC shell: bash -l {0} run: >- diff --git a/tests/python/unittest/test_allreduce.py b/tests/python/unittest/test_allreduce.py index 708384daf0..fed4e4c04d 100644 --- a/tests/python/unittest/test_allreduce.py +++ b/tests/python/unittest/test_allreduce.py @@ -19,6 +19,8 @@ import numpy as np from tvm.script import tir as T +import pytest + @T.prim_func def reduce(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> None: @@ -82,6 +84,48 @@ def test_allreduce_sum(dims, target, dev): tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6) +define_metal_compile_callback = tvm.testing.parameter(True, False) + + +@pytest.fixture +def optional_metal_compile_callback(define_metal_compile_callback): + name = "tvm_callback_metal_compile" + cached = tvm.get_global_func(name, allow_missing=True) + + if define_metal_compile_callback: + + @tvm.register_func(name, override=True) + def compile_metal(src, target): + return tvm.contrib.xcode.compile_metal(src, sdk="macosx") + + yield + + if define_metal_compile_callback: + if cached is None: + tvm._ffi.registry.remove_global_func(name) + else: + tvm.register_func(name, cached, override=True) + + +@tvm.testing.requires_metal(support_required="compile-only") +def test_allreduce_sum_compile(optional_metal_compile_callback): + # Disable the parametrization over dims, at least for now + dims = (1, 1, 2) + target = "metal" + + d1, d2, d3 = dims + _, _, _d1, _d2, _d3 = reduce.params + mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3}) + sch = tvm.tir.Schedule(mod) + blk = sch.get_block("reduce") + i, j, k, l = sch.get_loops(blk) + sch.bind(i, "blockIdx.x") + sch.bind(j, "threadIdx.z") + sch.bind(k, "threadIdx.y") + sch.bind(l, "threadIdx.x") + tvm.build(sch.mod["main"], target=target) + + @tvm.testing.parametrize_targets("cuda", "metal") def test_allreduce_max(dims, target, dev): d1, d2, d3 = dims From 73e7909a71f03dbdc0c8e8ad250a77597c994508 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 27 Sep 2023 14:57:01 -0500 Subject: [PATCH 08/13] [TVMScript] Preserve traceback across TVMScript parsing (#15824) Prior to this commit, exceptions raised during the parsing of TVMScript would be caught and replaced with a new exception. While this does allow the TVMScript location of the error to be included in the exception, it also removes the stack trace of the original error. This commit updates the `Parser.report_error` function to provide the original stack trace alongside the updated exception object. --- python/tvm/script/parser/core/parser.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 7b7dd066c3..ae79eef126 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -522,7 +522,18 @@ def report_error( msg = "KeyError: " + str(err) else: msg = str(err) - self.diag.error(node, msg) + + try: + self.diag.error(node, msg) + except Exception as diag_err: + # Calling self.diag.error is guaranteed to throw an + # exception. When shown to a user, this error should + # reference the point of error within the provided + # TVMScript. However, when caught in pdb, the full + # traceback should be available for debugging. + if isinstance(err, Exception): + diag_err = diag_err.with_traceback(err.__traceback__) + raise diag_err def visit(self, node: doc.AST) -> None: """The general visiting method. From cf081d992992b191ebad4256f7a915b2c7368185 Mon Sep 17 00:00:00 2001 From: HinGwenWoong Date: Thu, 28 Sep 2023 09:28:40 +0800 Subject: [PATCH 09/13] [BugFix][CPP] Fix cpp deploy bug (#15773) ## Motivation Fix bug when using `apps/howto_deploy/run_example.sh`, it will cause `core dumped` ![image](https://github.com/apache/tvm/assets/25873202/a55b40ba-0f28-4c75-b216-591c6008734f) After fixed, all test pass. ![image](https://github.com/apache/tvm/assets/25873202/61db50eb-77a2-43b4-9cf4-1f24aba49b67) ## Modification Add runtime for `tvm.build` --- apps/howto_deploy/prepare_test_libs.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/apps/howto_deploy/prepare_test_libs.py b/apps/howto_deploy/prepare_test_libs.py index f5afc5bf67..45d7c0abd9 100644 --- a/apps/howto_deploy/prepare_test_libs.py +++ b/apps/howto_deploy/prepare_test_libs.py @@ -33,7 +33,13 @@ def prepare_test_libs(base_path): fadd_dylib.export_library(dylib_path) # Compile library in system library mode - fadd_syslib = tvm.build(s, [A, B], "llvm", name="addonesys") + fadd_syslib = tvm.build( + s, + [A, B], + "llvm", + name="addonesys", + runtime=relay.backend.Runtime("cpp", {"system-lib": True}), + ) syslib_path = os.path.join(base_path, "test_addone_sys.o") fadd_syslib.save(syslib_path) From 9d8e6fda50bce14bf597de1f87711230e6001e4e Mon Sep 17 00:00:00 2001 From: Siva Date: Thu, 28 Sep 2023 15:38:30 +0530 Subject: [PATCH 10/13] [ADRENO] Minor changes for Adreno docs and help scripts (#15830) [ADRENO] Minor changes for Adreno docs and help scripts NCHW is mandatory layout for CLML offload. Updated the docs. CI scripts will keep OpenCL enbaled as fallback always. Enable configurable device bind ports. Helps in multi user environments. --- docs/how_to/deploy/adreno.rst | 2 +- tests/scripts/setup-adreno-env.sh | 18 +++++++++++++----- tests/scripts/task_config_build_adreno.sh | 2 ++ 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index ed016a3ff7..f0b8c6f757 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -432,7 +432,7 @@ as the OpenCL path is fallback option for any operator didn't go through OpenCLM python3 -m tvm.driver.tvmc compile \ --cross-compiler ${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang \ - --target="opencl, clml, llvm" --target-llvm-mtriple aarch64-linux-gnu --target-opencl-device adreno \ + --target="opencl, clml, llvm" --desired-layout NCHW --target-llvm-mtriple aarch64-linux-gnu --target-opencl-device adreno \ --tuning-records keras-resnet50.log -o keras-resnet50.tar resnet50.h5 On successful compilation, above command produce ``keras-resnet50.tar``. diff --git a/tests/scripts/setup-adreno-env.sh b/tests/scripts/setup-adreno-env.sh index 55a92c5f61..15c124a0f0 100755 --- a/tests/scripts/setup-adreno-env.sh +++ b/tests/scripts/setup-adreno-env.sh @@ -20,12 +20,13 @@ ENVIRONMENT="" RPC_PORT="" ADB_SERIAL="" +LISTEN_PORT=5000 function usage() { echo "Helper script to setup the environment for Tracker, RPC Device and for application" echo "Usage (Help) : source setup-adreno-env.sh -h" echo "Usage (Tracker): source setup-adreno-env.sh -e tracker -p " - echo "Usage (Device): source setup-adreno-env.sh -e device -p -d " + echo "Usage (Device): source setup-adreno-env.sh -e device -p -d [-l ]" echo "Usage (Query): source setup-adreno-env.sh -e query -p " } @@ -46,6 +47,11 @@ while [[ $# -gt 0 ]]; do shift # past argument shift # past value ;; + -l|--listen-port) + LISTEN_PORT="$2" + shift # past argument + shift # past value + ;; -h|--help) usage return 0 @@ -62,6 +68,7 @@ done echo "ENVIRONMENT = ${ENVIRONMENT}" echo "RPC_PORT = ${RPC_PORT}" echo "ADB_SERIAL = ${ADB_SERIAL}" +echo "DEVICE LISTEN POPRT = ${LISTEN_PORT}" function def_environment() { @@ -100,10 +107,11 @@ case ${ENVIRONMENT} in fi adb reverse tcp:${TVM_TRACKER_PORT} tcp:${TVM_TRACKER_PORT} - adb forward tcp:5000 tcp:5000 - adb forward tcp:5001 tcp:5001 - adb forward tcp:5002 tcp:5002 - adb shell "cd ${TARGET_FOLDER}; killall -9 tvm_rpc-${USER}; sleep 2; LD_LIBRARY_PATH=${TARGET_FOLDER}/ ./tvm_rpc-${USER} server --host=0.0.0.0 --port=5000 --port-end=5010 --tracker=127.0.0.1:${TVM_TRACKER_PORT} --key=${RPC_DEVICE_KEY}" + adb forward tcp:${LISTEN_PORT} tcp:${LISTEN_PORT} + adb forward tcp:$((LISTEN_PORT + 1)) tcp:$((LISTEN_PORT + 1)) + adb forward tcp:$((LISTEN_PORT + 2)) tcp:$((LISTEN_PORT + 2)) + adb forward tcp:$((LISTEN_PORT + 3)) tcp:$((LISTEN_PORT + 3)) + adb shell "cd ${TARGET_FOLDER}; killall -9 tvm_rpc-${USER}; sleep 2; LD_LIBRARY_PATH=${TARGET_FOLDER}/ ./tvm_rpc-${USER} server --host=0.0.0.0 --port=${LISTEN_PORT} --port-end=$((LISTEN_PORT + 10)) --tracker=127.0.0.1:${TVM_TRACKER_PORT} --key=${RPC_DEVICE_KEY}" ;; "query") diff --git a/tests/scripts/task_config_build_adreno.sh b/tests/scripts/task_config_build_adreno.sh index 62e6ffecbc..1b6750f165 100755 --- a/tests/scripts/task_config_build_adreno.sh +++ b/tests/scripts/task_config_build_adreno.sh @@ -25,6 +25,8 @@ cp ../cmake/config.cmake . if [ -f "${ADRENO_OPENCL}/CL/cl_qcom_ml_ops.h" ] ; then echo set\(USE_CLML ${ADRENO_OPENCL}\) >> config.cmake +else +echo set\(USE_OPENCL ON\) >> config.cmake fi echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake From 8b40f5d028632da82bd6cbf83865041d4186b068 Mon Sep 17 00:00:00 2001 From: Siva Date: Fri, 29 Sep 2023 10:29:00 +0530 Subject: [PATCH 11/13] [FRONTEND] Fix unnecessary pylint errors (#15838) Handle unnecessary pylint errors from these frontends --- tests/python/frontend/keras/test_forward.py | 2 +- tests/python/frontend/oneflow/test_forward.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 9d33b15a91..ba3880e186 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -28,11 +28,11 @@ # prevent Keras from using up all gpu memory import keras +import pytest import tvm from tvm import relay from tvm.contrib import graph_executor import tvm.testing -import pytest if tf.executing_eagerly(): GPUS = tf.config.experimental.list_physical_devices("GPU") diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py index 7ddc347e86..fda5f1b723 100644 --- a/tests/python/frontend/oneflow/test_forward.py +++ b/tests/python/frontend/oneflow/test_forward.py @@ -20,11 +20,11 @@ import numpy as np import oneflow as flow +from packaging import version as package_version import tvm import tvm.testing import tvm.topi.testing from tvm import relay -from packaging import version as package_version MODEL_HOME = "test_model" From def551dfd50bfff4e9d50108dc4e8027b553b8ec Mon Sep 17 00:00:00 2001 From: Siva Date: Fri, 29 Sep 2023 10:30:20 +0530 Subject: [PATCH 12/13] [CLI TOOLS][RTVM] Improve rtvm tool with new options to measure native performance (#15818) * [RTVM] Improve rtvm tool with new options to measure native performance Few fixes and enhancements that affects model loading times New options to measure performance. * * review comments * * review comments --- apps/cpp_rtvm/README.md | 22 ++++ apps/cpp_rtvm/main.cc | 199 +++++++++++++++++++++++++++++++----- apps/cpp_rtvm/tvm_runner.cc | 129 +++++++++++++++++------ apps/cpp_rtvm/tvm_runner.h | 24 ++++- 4 files changed, 316 insertions(+), 58 deletions(-) diff --git a/apps/cpp_rtvm/README.md b/apps/cpp_rtvm/README.md index c60a7b0e12..652d46eb58 100644 --- a/apps/cpp_rtvm/README.md +++ b/apps/cpp_rtvm/README.md @@ -122,6 +122,11 @@ Command line usage --input - Numpy file for the model input (optional and we use random of not given) --output - Numpy file name to dump the model output as numpy --dump-meta - Dump model meta information +--pre-compiled - The file name of a file where pre-compiled programs should be stored +--profile - Profile over all execution +--dry-run - Profile after given dry runs, default 10 +--run-count - Profile for given runs, default 50 +--zero-copy - Profile with zero copy api Example ./rtvm --model=keras-resnet50 --device="opencl" --dump-meta @@ -366,3 +371,20 @@ stored. If the pre-compiled file name was passed to the `rtvm` then After method `Load`, method `UsePreCompiledProgram` is called. This method loads pre-compiled programs if the file exists. In opposite case the file will be created and pre-compiled programs will be saved to this file. + +# Performnace Profiling Options +The tool has added few options to measure wall clock performance of the given model on Target natively. +--profile : Can turn on the profiling +--dry-run : The number of times dry run the model before mearuring the performance. Default value os 10 +--run-count : The number times to run the model and take an average. Default value is 50. +--zero-copy: This option enables graph runtime zero copy to be used for input and output than byte copy to DLTensor. + +Performance profile options dumps information summary as given below. + Module Load :27 ms + Graph Runtime Create :11 ms + Params Read :15 ms + Params Set :41 ms + Pre Compiled Progs Load :24 ms +Total Load Time :118 ms +Average ExecTime :27 ms +Unload Time :35.9236 ms diff --git a/apps/cpp_rtvm/main.cc b/apps/cpp_rtvm/main.cc index c38a5f62bd..dc3cf1c414 100644 --- a/apps/cpp_rtvm/main.cc +++ b/apps/cpp_rtvm/main.cc @@ -29,6 +29,7 @@ #endif #include +#include #include #include #include @@ -54,7 +55,11 @@ static const string kUsage = "--input - Numpy file for the model input (optional and we use random of not given)\n" "--output - Numpy file name to dump the model output as numpy\n" "--dump-meta - Dump model meta information\n" - "--pre-compiled - The file name of a file where pre-compiled programs should be stored" + "--pre-compiled - The file name of a file where pre-compiled programs should be stored\n" + "--profile - Profile over all execution\n" + "--dry-run - Profile after given dry runs, default 10\n" + "--run-count - Profile for given runs, default 50\n" + "--zero-copy - Profile with zero copy api\n" "\n" " Example\n" " ./rtvm --model=keras-resnet50 --device=\"opencl\" --dump-meta\n" @@ -68,6 +73,7 @@ static const string kUsage = * \arg input Numpy file for the model input * \arg output Numpy file name to dump the model output as numpy * \arg pre_compiled File name where pre-compiled programs should be stored + * \arg profile Do we profile overall execution */ struct ToolArgs { string model; @@ -75,7 +81,11 @@ struct ToolArgs { string input; string output; string pre_compiled; - bool dump_meta = false; + bool dump_meta{false}; + bool profile{false}; + int dry_run{10}; + int run_count{50}; + bool zero_copy{false}; }; /*! @@ -89,6 +99,10 @@ void PrintArgs(const ToolArgs& args) { LOG(INFO) << "Output = " << args.output; LOG(INFO) << "Pre-compiled = " << args.pre_compiled; LOG(INFO) << "Dump Metadata = " << ((args.dump_meta) ? ("True") : ("False")); + LOG(INFO) << "Profile = " << ((args.profile) ? ("True") : ("False")); + LOG(INFO) << "Dry Run = " << args.dry_run; + LOG(INFO) << "Run Count = " << args.run_count; + LOG(INFO) << "Zero Copy = " << ((args.zero_copy) ? ("True") : ("False")); } #if defined(__linux__) || defined(__ANDROID__) @@ -178,6 +192,26 @@ void ParseCmdArgs(int argc, char* argv[], struct ToolArgs& args) { } args.pre_compiled = GetCmdOption(argc, argv, "--pre-compiled="); + + const string pprofile = GetCmdOption(argc, argv, "--profile", true); + if (!pprofile.empty()) { + args.profile = true; + } + + const string pdry_run = GetCmdOption(argc, argv, "--dry-run="); + if (!pdry_run.empty()) { + args.dry_run = stoi(pdry_run); + } + + const string prun = GetCmdOption(argc, argv, "--run-count="); + if (!prun.empty()) { + args.run_count = stoi(prun); + } + + const string pzcopy = GetCmdOption(argc, argv, "--zero-copy", true); + if (!pzcopy.empty()) { + args.zero_copy = true; + } } /*! @@ -192,59 +226,174 @@ int ExecuteModel(ToolArgs& args) { #endif // Initialize TVM Runner - TVMRunner runner = TVMRunner(args.model, args.device); + auto runner = new TVMRunner(args.model, args.device); // Load the model - runner.Load(); + runner->Load(); if (!args.pre_compiled.empty()) { - runner.UsePreCompiledPrograms(args.pre_compiled); + runner->UsePreCompiledPrograms(args.pre_compiled); } // Query Model meta Information - TVMMetaInfo mInfo = runner.GetMetaInfo(); + TVMMetaInfo mInfo = runner->GetMetaInfo(); // Print Meta Information - if (args.dump_meta) runner.PrintMetaInfo(); + if (args.dump_meta) runner->PrintMetaInfo(); + + int total_exec_time = 0; + + if (args.profile) { + if (args.dry_run) { + for (int ii = 0; ii < args.dry_run; ++ii) { + runner->Run(); + } + TVMSynchronize(GetTVMDevice(args.device), 0, nullptr); + } + int total_time = 0; + std::map input_data_even, input_data_odd; + std::map output_data_even, output_data_odd; + + std::map input_data; + std::map output_data; + + // Alloc / populate and keep input data ready + for (auto& elem : mInfo.input_info) { + if (args.zero_copy) { + auto ndarr = + NDArray::Empty(elem.second.first, tvm::runtime::String2DLDataType(elem.second.second), + DLDevice{GetTVMDevice(args.device), 0}); + input_data_even.insert({elem.first, ndarr}); + + ndarr = + NDArray::Empty(elem.second.first, tvm::runtime::String2DLDataType(elem.second.second), + DLDevice{GetTVMDevice(args.device), 0}); + input_data_odd.insert({elem.first, ndarr}); + } else { + char* data = (char*)malloc(runner->GetInputMemSize(elem.first)); + input_data.insert({elem.first, data}); + } + } + + // Alloc and keep output bufers ready + for (auto& elem : mInfo.output_info) { + if (args.zero_copy) { + auto ndarr = + NDArray::Empty(elem.second.first, tvm::runtime::String2DLDataType(elem.second.second), + DLDevice{GetTVMDevice(args.device), 0}); + output_data_even.insert({elem.first, ndarr}); + + ndarr = + NDArray::Empty(elem.second.first, tvm::runtime::String2DLDataType(elem.second.second), + DLDevice{GetTVMDevice(args.device), 0}); + output_data_odd.insert({elem.first, ndarr}); + } else { + char* data = (char*)malloc(runner->GetOutputMemSize(elem.first)); + output_data.insert({elem.first, data}); + } + } + + for (int ii = 0; ii < args.run_count; ++ii) { + // Timer start + auto tstart = std::chrono::high_resolution_clock::now(); + // Set random input for all input + for (auto& elem : mInfo.input_info) { + if (args.zero_copy) { + if (ii % 2) { + runner->SetInput(elem.first, input_data_even[elem.first]); + } else { + runner->SetInput(elem.first, input_data_odd[elem.first]); + } + } else { + runner->SetInput(elem.first, input_data[elem.first]); + } + } + + if (args.zero_copy) { + // With zero copy set the result NDArray up front + for (auto& elem : mInfo.output_info) { + if (ii % 2) { + runner->SetOutput(elem.first, output_data_even[elem.first]); + } else { + runner->SetOutput(elem.first, output_data_odd[elem.first]); + } + } + } - if (args.input.empty() || args.output.empty()) { + // Run the model + runner->Run(); + + if (!args.zero_copy) { + // W/o zero copy we need to invoke explicite data copy + for (auto& elem : mInfo.output_info) { + runner->GetOutput(elem.first, output_data[elem.first]); + } + } else { + // Just wait for the run to complete. + TVMSynchronize(GetTVMDevice(args.device), 0, nullptr); + } + + // Timer end + auto tend = std::chrono::high_resolution_clock::now(); + LOG(INFO) << "Exec Time:" << static_cast((tend - tstart).count()) / 1e6; + total_exec_time += static_cast((tend - tstart).count()) / 1e6; + } + + // Free input bufers + for (auto& elem : mInfo.input_info) { + free(input_data[elem.first]); + } + + // Free output bufers + for (auto& elem : mInfo.output_info) { + free(output_data[elem.first]); + } + } else if (!args.input.empty() && !args.output.empty()) { + LOG(INFO) << "Executing with Input:" << args.input << " Output:" << args.output; + // Set Input from Numpy Input + runner->SetInput(args.input); + // Run the model + runner->Run(); + // Get Output as Numpy dump + runner->GetOutput(args.output); + } else { LOG(INFO) << "Executing dry run ... "; // Set random input for all inputs for (auto& elem : mInfo.input_info) { LOG(INFO) << "Set Random Input for :" << elem.first; auto shape = elem.second.first; - size_t ssize = runner.GetInputMemSize(elem.first); + size_t ssize = runner->GetInputMemSize(elem.first); char* data = (char*)malloc(ssize); LOG(INFO) << "Random Input Size:" << ssize << " bytes"; - runner.SetInput(elem.first, data); + runner->SetInput(elem.first, data); free(data); } - // Run the model - runner.Run(); - + runner->Run(); // Get Output and dump few values for (auto& elem : mInfo.output_info) { LOG(INFO) << "Get Output for :" << elem.first; auto shape = elem.second.first; - size_t ssize = runner.GetOutputMemSize(elem.first); + size_t ssize = runner->GetOutputMemSize(elem.first); char* data = (char*)malloc(ssize); - runner.GetOutput(elem.first, data); + runner->GetOutput(elem.first, data); LOG(INFO) << "Output Size:" << ssize << " bytes"; free(data); } - } else { - LOG(INFO) << "Executing with Input:" << args.input << " Output:" << args.output; - - // Set Input from Numpy Input - runner.SetInput(args.input); - - // Run the model - runner.Run(); + } - // Get Output as Numpy dump - runner.GetOutput(args.output); + if (args.profile) { + // Print Stats + runner->PrintStats(); } + auto tstart = std::chrono::high_resolution_clock::now(); + delete runner; + auto tend = std::chrono::high_resolution_clock::now(); + if (args.profile) { + LOG(INFO) << "Average ExecTime :" << total_exec_time / args.run_count << " ms"; + LOG(INFO) << "Unload Time :" << static_cast((tend - tstart).count()) / 1e6 + << " ms"; + } return 0; } diff --git a/apps/cpp_rtvm/tvm_runner.cc b/apps/cpp_rtvm/tvm_runner.cc index 2fd4f2281e..7d6dbc23ee 100644 --- a/apps/cpp_rtvm/tvm_runner.cc +++ b/apps/cpp_rtvm/tvm_runner.cc @@ -26,10 +26,12 @@ #include +#include #include #include #include #include +#include namespace tvm { namespace runtime { @@ -39,25 +41,25 @@ namespace runtime { * \param device the target device in string format. * \return dl_device corresponding to the device string. */ -int GetTVMDevice(std::string device) { +DLDeviceType GetTVMDevice(std::string device) { if (!device.compare("cpu")) { - return static_cast(kDLCPU); + return kDLCPU; } else if (!device.compare("llvm")) { - return static_cast(kDLCPU); + return kDLCPU; } else if (!device.compare("cuda")) { - return static_cast(kDLCUDA); + return kDLCUDA; } else if (!device.compare("opencl")) { - return static_cast(kDLOpenCL); + return kDLOpenCL; } else if (!device.compare("vulkan")) { - return static_cast(kDLVulkan); + return kDLVulkan; } else if (!device.compare("metal")) { - return static_cast(kDLMetal); + return kDLMetal; } else if (!device.compare("vpi")) { - return static_cast(kDLVPI); + return kDLVPI; } else if (!device.compare("rocm")) { - return static_cast(kDLROCM); + return kDLROCM; } else if (!device.compare("oneapi")) { - return static_cast(kDLOneAPI); + return kDLOneAPI; } else { LOG(FATAL) << "TVMRunner : Unsupported device :" << device; } @@ -80,34 +82,59 @@ TVMRunner::TVMRunner(std::string path, std::string device) int TVMRunner::Load(void) { LOG(INFO) << "TVMRunner Load:" << r_model_path; // Load the lib file + auto tstart = std::chrono::high_resolution_clock::now(); + r_mod_handle = Module::LoadFromFile((r_model_path + "/mod.so").c_str(), "so"); + auto tend = std::chrono::high_resolution_clock::now(); + r_module_load_ms = static_cast((tend - tstart).count()) / 1e6; + tstart = std::chrono::high_resolution_clock::now(); // Read model json file std::ifstream json_reader((r_model_path + "/mod.json").c_str()); CHECK(!json_reader.fail()) << "Failed to open json file:" << (r_model_path + "/mod.json").c_str(); - std::string json_str((std::istreambuf_iterator(json_reader)), - std::istreambuf_iterator()); + json_reader.seekg(0, std::ios_base::end); + std::size_t json_size = json_reader.tellg(); + json_reader.seekg(0, std::ios_base::beg); + std::string json_data; + json_data.reserve(json_size); + json_reader.read((char*)json_data.c_str(), json_size); json_reader.close(); // Get ref to graph exeutor auto f_handle = tvm::runtime::Registry::Get("tvm.graph_executor.create"); // Greate graph runtime - r_graph_handle = (*f_handle)(json_str, r_mod_handle, GetTVMDevice(r_device), 0); + r_graph_handle = + (*f_handle)(json_data, r_mod_handle, static_cast(GetTVMDevice(r_device)), 0); + + tend = std::chrono::high_resolution_clock::now(); + r_graph_load_ms = static_cast((tend - tstart).count()) / 1e6; // Read params binary file + tstart = std::chrono::high_resolution_clock::now(); std::ifstream params_reader((r_model_path + "/mod.params").c_str(), std::ios::binary); CHECK(!params_reader.fail()) << "Failed to open json file:" << (r_model_path + "/mod.params").c_str(); - const std::string params_str((std::istreambuf_iterator(params_reader)), - std::istreambuf_iterator()); + + params_reader.seekg(0, std::ios_base::end); + std::size_t param_size = params_reader.tellg(); + params_reader.seekg(0, std::ios_base::beg); + std::vector param_data(param_size / sizeof(char)); + params_reader.read((char*)¶m_data[0], param_size); params_reader.close(); + TVMByteArray params_arr; - params_arr.data = params_str.c_str(); - params_arr.size = params_str.length(); + params_arr.data = (char*)¶m_data[0]; + params_arr.size = param_size; + + tend = std::chrono::high_resolution_clock::now(); + r_param_read_ms = static_cast((tend - tstart).count()) / 1e6; // Load parameters + tstart = std::chrono::high_resolution_clock::now(); r_graph_handle.GetFunction("load_params")(params_arr); + tend = std::chrono::high_resolution_clock::now(); + r_param_load_ms = static_cast((tend - tstart).count()) / 1e6; return 0; } @@ -117,6 +144,7 @@ int TVMRunner::Load(void) { * \param file_name File name where pre-compiled programs should be stored. */ void TVMRunner::UsePreCompiledPrograms(std::string file_name) { + auto tstart = std::chrono::high_resolution_clock::now(); if (r_run_was_called) { LOG(INFO) << "TVMRunner UsePreCompiledPrograms: should be called before first run"; return; @@ -130,10 +158,19 @@ void TVMRunner::UsePreCompiledPrograms(std::string file_name) { std::ofstream fs(file_name, std::ofstream::binary); fs.write(bytes.c_str(), bytes.size()); } else { - std::string bytes((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); - f_set(String(bytes)); + ifs.seekg(0, std::ios_base::end); + std::size_t blob_size = ifs.tellg(); + ifs.seekg(0, std::ios_base::beg); + std::string blob_data; + blob_data.reserve(blob_size); + blob_data.resize(blob_size); + ifs.read((char*)blob_data.c_str(), blob_size); + ifs.close(); + f_set(String(blob_data)); } } + auto tend = std::chrono::high_resolution_clock::now(); + r_pre_compiled_load_ms = static_cast((tend - tstart).count()) / 1e6; } /*! @@ -156,8 +193,6 @@ inline size_t GetMemSize(NDArray& narr) { * \return The memory size. */ size_t TVMRunner::GetInputMemSize(std::string input_id) { - LOG(INFO) << "TVMRunner::GetInputMemSize:" << input_id; - NDArray in_arr = r_graph_handle.GetFunction("get_input")(input_id); auto ssize = GetMemSize(in_arr); @@ -170,8 +205,6 @@ size_t TVMRunner::GetInputMemSize(std::string input_id) { * \return The memory size. */ size_t TVMRunner::GetOutputMemSize(std::string output_id) { - LOG(INFO) << "TVMRunner::GetOutputMemSize:" << output_id; - NDArray out_arr = r_graph_handle.GetFunction("get_output")(output_id); auto ssize = GetMemSize(out_arr); @@ -209,13 +242,23 @@ int TVMRunner::SetInput(std::string inputfile) { * \param 0 on success else error code. */ int TVMRunner::SetInput(std::string input_id, char* raw_input) { - LOG(INFO) << "TVMRunner::SetInput (Raw)"; NDArray in_arr = r_graph_handle.GetFunction("get_input")(input_id); auto ssize = GetMemSize(in_arr); in_arr.CopyFromBytes(raw_input, ssize); return 0; } +/*! + * \brief Set the model input from given NDArray with zero copy. + * \param input_id input node name. + * \param ndarr NDArray. + * \param 0 on success else error code. + */ +int TVMRunner::SetInput(std::string input_id, NDArray& ndarr) { + r_graph_handle.GetFunction("set_input_zero_copy")(input_id, ndarr); + return 0; +} + /*! * \brief Get the model outputs and dump them to npz file. * \param outputfile the npz file to where we dump the output data. @@ -255,21 +298,29 @@ int TVMRunner::GetOutput(std::string outputfile) { * \param 0 on success else error code. */ int TVMRunner::GetOutput(std::string output_id, char* raw_output) { - LOG(INFO) << "TVMRunner::GetOutput (Raw)"; NDArray out_arr = r_graph_handle.GetFunction("get_output")(output_id); auto ssize = GetMemSize(out_arr); out_arr.CopyToBytes(raw_output, ssize); return 0; } +/*! + * \brief Set the model output from given NDArray with zero copy. + * \param output_id output node name. + * \param ndarr NDArray. + * \param 0 on success else error code. + */ +int TVMRunner::SetOutput(std::string output_id, NDArray& ndarr) { + r_graph_handle.GetFunction("set_output_zero_copy")(output_id, ndarr); + return 0; +} + /*! * \brief Call one cycle of execution for the model. * \param 0 on success else error code. */ int TVMRunner::Run(void) { - LOG(INFO) << "TVMRunner::Run"; r_run_was_called = true; - r_graph_handle.GetFunction("run")(); return 0; } @@ -289,10 +340,10 @@ TVMMetaInfo TVMRunner::GetMetaInfo(void) { auto dtype_info = GetRef>(tvm_input_info["dtype"].as()); for (const auto& kv : shape_info) { auto stuple = GetRef(kv.second.as()); - std::vector vshape; + std::vector vshape; vshape.assign(stuple.begin(), stuple.end()); auto dtype = GetRef(dtype_info[kv.first].as()); - std::pair, std::string> value = std::make_pair(vshape, dtype); + std::pair, std::string> value = std::make_pair(vshape, dtype); mInfo.input_info.insert({kv.first, value}); } @@ -301,10 +352,10 @@ TVMMetaInfo TVMRunner::GetMetaInfo(void) { dtype_info = GetRef>(tvm_input_info["dtype"].as()); for (const auto& kv : shape_info) { auto stuple = GetRef(kv.second.as()); - std::vector vshape; + std::vector vshape; vshape.assign(stuple.begin(), stuple.end()); auto dtype = GetRef(dtype_info[kv.first].as()); - std::pair, std::string> value = std::make_pair(vshape, dtype); + std::pair, std::string> value = std::make_pair(vshape, dtype); mInfo.output_info.insert({kv.first, value}); } @@ -343,5 +394,21 @@ void TVMRunner::PrintMetaInfo(void) { } } +/*! + * \brief Print stats information. + */ +void TVMRunner::PrintStats(void) { + LOG(INFO) << "Performance Stats:" << r_model_path; + LOG(INFO) << " Module Load :" << r_module_load_ms << " ms"; + LOG(INFO) << " Graph Runtime Create :" << r_graph_load_ms << " ms"; + LOG(INFO) << " Params Read :" << r_param_read_ms << " ms"; + LOG(INFO) << " Params Set :" << r_param_load_ms << " ms"; + LOG(INFO) << " Pre Compiled Progs Load :" << r_pre_compiled_load_ms << " ms"; + LOG(INFO) << "Total Load Time :" + << r_module_load_ms + r_graph_load_ms + r_param_read_ms + r_param_load_ms + + r_pre_compiled_load_ms + << " ms"; +} + } // namespace runtime } // namespace tvm diff --git a/apps/cpp_rtvm/tvm_runner.h b/apps/cpp_rtvm/tvm_runner.h index 47717c3ecf..e93b63ae85 100644 --- a/apps/cpp_rtvm/tvm_runner.h +++ b/apps/cpp_rtvm/tvm_runner.h @@ -41,8 +41,8 @@ namespace runtime { typedef struct _TVMMetaInfo { int n_inputs; int n_outputs; - std::map, std::string>> input_info; - std::map, std::string>> output_info; + std::map, std::string>> input_info; + std::map, std::string>> output_info; } TVMMetaInfo; /*! @@ -63,10 +63,14 @@ class TVMRunner { int SetInput(std::string); /*! \brief To set the input from binary data */ int SetInput(std::string, char*); + /*! \brief To set the input from NDArray */ + int SetInput(std::string, NDArray& ndarr); /*! \brief Save the model output into given npz file */ int GetOutput(std::string); /*! \brief Get the model output in binary format */ int GetOutput(std::string, char*); + /*! \brief Swap output NDArray with given one */ + int SetOutput(std::string, NDArray& ndarr); /*! \brief To get the input mem size */ size_t GetInputMemSize(std::string); /*! \brief To get the output mem size */ @@ -76,6 +80,21 @@ class TVMRunner { /*! \brief Print function to show all meta information */ void PrintMetaInfo(void); + /*! \brief Print function to show all stats information */ + void PrintStats(void); + + // Public profiling information + /*! Module load time */ + int r_module_load_ms{0}; + /*! Graph runtime creatint time */ + int r_graph_load_ms{0}; + /*! Params read time */ + int r_param_read_ms{0}; + /*! Params load time */ + int r_param_load_ms{0}; + /*! Pre compiled programs load time */ + int r_pre_compiled_load_ms{0}; + private: /*! \brief Module handle for the shared object */ Module r_mod_handle; @@ -91,6 +110,7 @@ class TVMRunner { bool r_run_was_called; }; +DLDeviceType GetTVMDevice(std::string device); } // namespace runtime } // namespace tvm #endif // TVM_APPS_CPP_RTVM_RUNNER_H_ From 28908998e0c55025a89e8e2bd26a3fe3e6c84356 Mon Sep 17 00:00:00 2001 From: Qingchao Shen Date: Fri, 29 Sep 2023 15:54:23 +0800 Subject: [PATCH 13/13] [Relay][Keras][Bugfix] fix the converters of GRU and SimpleRNN about the go_backwards attribute (#15829) * fix bug in gru and simpleRNN about go_backwards * Update test_forward.py * Update keras.py --- python/tvm/relay/frontend/keras.py | 4 ++++ tests/python/frontend/keras/test_forward.py | 12 ++++++++++++ 2 files changed, 16 insertions(+) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 9e09cb400a..6c82ebb427 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -1062,6 +1062,8 @@ def _convert_simple_rnn( in_bias = etab.new_const(weightList[2]) assert len(in_data.type_annotation.shape) == 3 timeDim = in_data.type_annotation.shape[1].value + if keras_layer.go_backwards: + in_data = _op.reverse(in_data, axis=1) in_data_split = _op.split(in_data, indices_or_sections=timeDim, axis=1) for i in range(len(in_data_split)): in_data_split_i = _op.nn.batch_flatten(in_data_split[i]) @@ -1090,6 +1092,8 @@ def _convert_gru( recurrent_weight = etab.new_const(weightList[1].transpose([1, 0])) if keras_layer.use_bias: in_bias = etab.new_const(weightList[2]) + if keras_layer.go_backwards: + in_data = _op.reverse(in_data, axis=1) units = list(weightList[0].shape)[1] assert units > 0, "The value of units must be a positive integer" in_data = _op.nn.batch_flatten(in_data) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index ba3880e186..8c5b578060 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -568,12 +568,23 @@ def test_forward_rnn(self, keras_mod): keras_mod.layers.SimpleRNN( units=16, return_state=False, activation="tanh", use_bias=False ), + keras_mod.layers.SimpleRNN( + units=16, return_state=False, activation="tanh", go_backwards=True + ), + keras_mod.layers.GRU( + units=16, + return_state=False, + recurrent_activation="sigmoid", + activation="tanh", + reset_after=False, + ), keras_mod.layers.GRU( units=16, return_state=False, recurrent_activation="sigmoid", activation="tanh", reset_after=False, + use_bias=False, ), keras_mod.layers.GRU( units=16, @@ -582,6 +593,7 @@ def test_forward_rnn(self, keras_mod): activation="tanh", reset_after=False, use_bias=False, + go_backwards=True, ), ] for rnn_func in rnn_funcs: