diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index a4a30fe19a..7b4b8d826f 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -69,6 +69,14 @@ jobs: shell: bash -l {0} run: >- python -m pytest -v tests/python/all-platform-minimal-test + - name: Minimal Metal Compile-Only + shell: bash -l {0} + run: >- + python -m pytest -v -s 'tests/python/unittest/test_allreduce.py::test_allreduce_sum_compile' + - name: Minimal Metal Compile-and-Run + shell: bash -l {0} + run: >- + python -m pytest -v -s 'tests/python/unittest/test_allreduce.py::test_allreduce_sum[dims0-metal]' - name: Test iOS RPC shell: bash -l {0} run: >- diff --git a/apps/cpp_rtvm/README.md b/apps/cpp_rtvm/README.md index c60a7b0e12..652d46eb58 100644 --- a/apps/cpp_rtvm/README.md +++ b/apps/cpp_rtvm/README.md @@ -122,6 +122,11 @@ Command line usage --input - Numpy file for the model input (optional and we use random of not given) --output - Numpy file name to dump the model output as numpy --dump-meta - Dump model meta information +--pre-compiled - The file name of a file where pre-compiled programs should be stored +--profile - Profile over all execution +--dry-run - Profile after given dry runs, default 10 +--run-count - Profile for given runs, default 50 +--zero-copy - Profile with zero copy api Example ./rtvm --model=keras-resnet50 --device="opencl" --dump-meta @@ -366,3 +371,20 @@ stored. If the pre-compiled file name was passed to the `rtvm` then After method `Load`, method `UsePreCompiledProgram` is called. This method loads pre-compiled programs if the file exists. In opposite case the file will be created and pre-compiled programs will be saved to this file. + +# Performnace Profiling Options +The tool has added few options to measure wall clock performance of the given model on Target natively. +--profile : Can turn on the profiling +--dry-run : The number of times dry run the model before mearuring the performance. Default value os 10 +--run-count : The number times to run the model and take an average. Default value is 50. +--zero-copy: This option enables graph runtime zero copy to be used for input and output than byte copy to DLTensor. + +Performance profile options dumps information summary as given below. + Module Load :27 ms + Graph Runtime Create :11 ms + Params Read :15 ms + Params Set :41 ms + Pre Compiled Progs Load :24 ms +Total Load Time :118 ms +Average ExecTime :27 ms +Unload Time :35.9236 ms diff --git a/apps/cpp_rtvm/main.cc b/apps/cpp_rtvm/main.cc index c38a5f62bd..dc3cf1c414 100644 --- a/apps/cpp_rtvm/main.cc +++ b/apps/cpp_rtvm/main.cc @@ -29,6 +29,7 @@ #endif #include +#include #include #include #include @@ -54,7 +55,11 @@ static const string kUsage = "--input - Numpy file for the model input (optional and we use random of not given)\n" "--output - Numpy file name to dump the model output as numpy\n" "--dump-meta - Dump model meta information\n" - "--pre-compiled - The file name of a file where pre-compiled programs should be stored" + "--pre-compiled - The file name of a file where pre-compiled programs should be stored\n" + "--profile - Profile over all execution\n" + "--dry-run - Profile after given dry runs, default 10\n" + "--run-count - Profile for given runs, default 50\n" + "--zero-copy - Profile with zero copy api\n" "\n" " Example\n" " ./rtvm --model=keras-resnet50 --device=\"opencl\" --dump-meta\n" @@ -68,6 +73,7 @@ static const string kUsage = * \arg input Numpy file for the model input * \arg output Numpy file name to dump the model output as numpy * \arg pre_compiled File name where pre-compiled programs should be stored + * \arg profile Do we profile overall execution */ struct ToolArgs { string model; @@ -75,7 +81,11 @@ struct ToolArgs { string input; string output; string pre_compiled; - bool dump_meta = false; + bool dump_meta{false}; + bool profile{false}; + int dry_run{10}; + int run_count{50}; + bool zero_copy{false}; }; /*! @@ -89,6 +99,10 @@ void PrintArgs(const ToolArgs& args) { LOG(INFO) << "Output = " << args.output; LOG(INFO) << "Pre-compiled = " << args.pre_compiled; LOG(INFO) << "Dump Metadata = " << ((args.dump_meta) ? ("True") : ("False")); + LOG(INFO) << "Profile = " << ((args.profile) ? ("True") : ("False")); + LOG(INFO) << "Dry Run = " << args.dry_run; + LOG(INFO) << "Run Count = " << args.run_count; + LOG(INFO) << "Zero Copy = " << ((args.zero_copy) ? ("True") : ("False")); } #if defined(__linux__) || defined(__ANDROID__) @@ -178,6 +192,26 @@ void ParseCmdArgs(int argc, char* argv[], struct ToolArgs& args) { } args.pre_compiled = GetCmdOption(argc, argv, "--pre-compiled="); + + const string pprofile = GetCmdOption(argc, argv, "--profile", true); + if (!pprofile.empty()) { + args.profile = true; + } + + const string pdry_run = GetCmdOption(argc, argv, "--dry-run="); + if (!pdry_run.empty()) { + args.dry_run = stoi(pdry_run); + } + + const string prun = GetCmdOption(argc, argv, "--run-count="); + if (!prun.empty()) { + args.run_count = stoi(prun); + } + + const string pzcopy = GetCmdOption(argc, argv, "--zero-copy", true); + if (!pzcopy.empty()) { + args.zero_copy = true; + } } /*! @@ -192,59 +226,174 @@ int ExecuteModel(ToolArgs& args) { #endif // Initialize TVM Runner - TVMRunner runner = TVMRunner(args.model, args.device); + auto runner = new TVMRunner(args.model, args.device); // Load the model - runner.Load(); + runner->Load(); if (!args.pre_compiled.empty()) { - runner.UsePreCompiledPrograms(args.pre_compiled); + runner->UsePreCompiledPrograms(args.pre_compiled); } // Query Model meta Information - TVMMetaInfo mInfo = runner.GetMetaInfo(); + TVMMetaInfo mInfo = runner->GetMetaInfo(); // Print Meta Information - if (args.dump_meta) runner.PrintMetaInfo(); + if (args.dump_meta) runner->PrintMetaInfo(); + + int total_exec_time = 0; + + if (args.profile) { + if (args.dry_run) { + for (int ii = 0; ii < args.dry_run; ++ii) { + runner->Run(); + } + TVMSynchronize(GetTVMDevice(args.device), 0, nullptr); + } + int total_time = 0; + std::map input_data_even, input_data_odd; + std::map output_data_even, output_data_odd; + + std::map input_data; + std::map output_data; + + // Alloc / populate and keep input data ready + for (auto& elem : mInfo.input_info) { + if (args.zero_copy) { + auto ndarr = + NDArray::Empty(elem.second.first, tvm::runtime::String2DLDataType(elem.second.second), + DLDevice{GetTVMDevice(args.device), 0}); + input_data_even.insert({elem.first, ndarr}); + + ndarr = + NDArray::Empty(elem.second.first, tvm::runtime::String2DLDataType(elem.second.second), + DLDevice{GetTVMDevice(args.device), 0}); + input_data_odd.insert({elem.first, ndarr}); + } else { + char* data = (char*)malloc(runner->GetInputMemSize(elem.first)); + input_data.insert({elem.first, data}); + } + } + + // Alloc and keep output bufers ready + for (auto& elem : mInfo.output_info) { + if (args.zero_copy) { + auto ndarr = + NDArray::Empty(elem.second.first, tvm::runtime::String2DLDataType(elem.second.second), + DLDevice{GetTVMDevice(args.device), 0}); + output_data_even.insert({elem.first, ndarr}); + + ndarr = + NDArray::Empty(elem.second.first, tvm::runtime::String2DLDataType(elem.second.second), + DLDevice{GetTVMDevice(args.device), 0}); + output_data_odd.insert({elem.first, ndarr}); + } else { + char* data = (char*)malloc(runner->GetOutputMemSize(elem.first)); + output_data.insert({elem.first, data}); + } + } + + for (int ii = 0; ii < args.run_count; ++ii) { + // Timer start + auto tstart = std::chrono::high_resolution_clock::now(); + // Set random input for all input + for (auto& elem : mInfo.input_info) { + if (args.zero_copy) { + if (ii % 2) { + runner->SetInput(elem.first, input_data_even[elem.first]); + } else { + runner->SetInput(elem.first, input_data_odd[elem.first]); + } + } else { + runner->SetInput(elem.first, input_data[elem.first]); + } + } + + if (args.zero_copy) { + // With zero copy set the result NDArray up front + for (auto& elem : mInfo.output_info) { + if (ii % 2) { + runner->SetOutput(elem.first, output_data_even[elem.first]); + } else { + runner->SetOutput(elem.first, output_data_odd[elem.first]); + } + } + } - if (args.input.empty() || args.output.empty()) { + // Run the model + runner->Run(); + + if (!args.zero_copy) { + // W/o zero copy we need to invoke explicite data copy + for (auto& elem : mInfo.output_info) { + runner->GetOutput(elem.first, output_data[elem.first]); + } + } else { + // Just wait for the run to complete. + TVMSynchronize(GetTVMDevice(args.device), 0, nullptr); + } + + // Timer end + auto tend = std::chrono::high_resolution_clock::now(); + LOG(INFO) << "Exec Time:" << static_cast((tend - tstart).count()) / 1e6; + total_exec_time += static_cast((tend - tstart).count()) / 1e6; + } + + // Free input bufers + for (auto& elem : mInfo.input_info) { + free(input_data[elem.first]); + } + + // Free output bufers + for (auto& elem : mInfo.output_info) { + free(output_data[elem.first]); + } + } else if (!args.input.empty() && !args.output.empty()) { + LOG(INFO) << "Executing with Input:" << args.input << " Output:" << args.output; + // Set Input from Numpy Input + runner->SetInput(args.input); + // Run the model + runner->Run(); + // Get Output as Numpy dump + runner->GetOutput(args.output); + } else { LOG(INFO) << "Executing dry run ... "; // Set random input for all inputs for (auto& elem : mInfo.input_info) { LOG(INFO) << "Set Random Input for :" << elem.first; auto shape = elem.second.first; - size_t ssize = runner.GetInputMemSize(elem.first); + size_t ssize = runner->GetInputMemSize(elem.first); char* data = (char*)malloc(ssize); LOG(INFO) << "Random Input Size:" << ssize << " bytes"; - runner.SetInput(elem.first, data); + runner->SetInput(elem.first, data); free(data); } - // Run the model - runner.Run(); - + runner->Run(); // Get Output and dump few values for (auto& elem : mInfo.output_info) { LOG(INFO) << "Get Output for :" << elem.first; auto shape = elem.second.first; - size_t ssize = runner.GetOutputMemSize(elem.first); + size_t ssize = runner->GetOutputMemSize(elem.first); char* data = (char*)malloc(ssize); - runner.GetOutput(elem.first, data); + runner->GetOutput(elem.first, data); LOG(INFO) << "Output Size:" << ssize << " bytes"; free(data); } - } else { - LOG(INFO) << "Executing with Input:" << args.input << " Output:" << args.output; - - // Set Input from Numpy Input - runner.SetInput(args.input); - - // Run the model - runner.Run(); + } - // Get Output as Numpy dump - runner.GetOutput(args.output); + if (args.profile) { + // Print Stats + runner->PrintStats(); } + auto tstart = std::chrono::high_resolution_clock::now(); + delete runner; + auto tend = std::chrono::high_resolution_clock::now(); + if (args.profile) { + LOG(INFO) << "Average ExecTime :" << total_exec_time / args.run_count << " ms"; + LOG(INFO) << "Unload Time :" << static_cast((tend - tstart).count()) / 1e6 + << " ms"; + } return 0; } diff --git a/apps/cpp_rtvm/tvm_runner.cc b/apps/cpp_rtvm/tvm_runner.cc index 2fd4f2281e..7d6dbc23ee 100644 --- a/apps/cpp_rtvm/tvm_runner.cc +++ b/apps/cpp_rtvm/tvm_runner.cc @@ -26,10 +26,12 @@ #include +#include #include #include #include #include +#include namespace tvm { namespace runtime { @@ -39,25 +41,25 @@ namespace runtime { * \param device the target device in string format. * \return dl_device corresponding to the device string. */ -int GetTVMDevice(std::string device) { +DLDeviceType GetTVMDevice(std::string device) { if (!device.compare("cpu")) { - return static_cast(kDLCPU); + return kDLCPU; } else if (!device.compare("llvm")) { - return static_cast(kDLCPU); + return kDLCPU; } else if (!device.compare("cuda")) { - return static_cast(kDLCUDA); + return kDLCUDA; } else if (!device.compare("opencl")) { - return static_cast(kDLOpenCL); + return kDLOpenCL; } else if (!device.compare("vulkan")) { - return static_cast(kDLVulkan); + return kDLVulkan; } else if (!device.compare("metal")) { - return static_cast(kDLMetal); + return kDLMetal; } else if (!device.compare("vpi")) { - return static_cast(kDLVPI); + return kDLVPI; } else if (!device.compare("rocm")) { - return static_cast(kDLROCM); + return kDLROCM; } else if (!device.compare("oneapi")) { - return static_cast(kDLOneAPI); + return kDLOneAPI; } else { LOG(FATAL) << "TVMRunner : Unsupported device :" << device; } @@ -80,34 +82,59 @@ TVMRunner::TVMRunner(std::string path, std::string device) int TVMRunner::Load(void) { LOG(INFO) << "TVMRunner Load:" << r_model_path; // Load the lib file + auto tstart = std::chrono::high_resolution_clock::now(); + r_mod_handle = Module::LoadFromFile((r_model_path + "/mod.so").c_str(), "so"); + auto tend = std::chrono::high_resolution_clock::now(); + r_module_load_ms = static_cast((tend - tstart).count()) / 1e6; + tstart = std::chrono::high_resolution_clock::now(); // Read model json file std::ifstream json_reader((r_model_path + "/mod.json").c_str()); CHECK(!json_reader.fail()) << "Failed to open json file:" << (r_model_path + "/mod.json").c_str(); - std::string json_str((std::istreambuf_iterator(json_reader)), - std::istreambuf_iterator()); + json_reader.seekg(0, std::ios_base::end); + std::size_t json_size = json_reader.tellg(); + json_reader.seekg(0, std::ios_base::beg); + std::string json_data; + json_data.reserve(json_size); + json_reader.read((char*)json_data.c_str(), json_size); json_reader.close(); // Get ref to graph exeutor auto f_handle = tvm::runtime::Registry::Get("tvm.graph_executor.create"); // Greate graph runtime - r_graph_handle = (*f_handle)(json_str, r_mod_handle, GetTVMDevice(r_device), 0); + r_graph_handle = + (*f_handle)(json_data, r_mod_handle, static_cast(GetTVMDevice(r_device)), 0); + + tend = std::chrono::high_resolution_clock::now(); + r_graph_load_ms = static_cast((tend - tstart).count()) / 1e6; // Read params binary file + tstart = std::chrono::high_resolution_clock::now(); std::ifstream params_reader((r_model_path + "/mod.params").c_str(), std::ios::binary); CHECK(!params_reader.fail()) << "Failed to open json file:" << (r_model_path + "/mod.params").c_str(); - const std::string params_str((std::istreambuf_iterator(params_reader)), - std::istreambuf_iterator()); + + params_reader.seekg(0, std::ios_base::end); + std::size_t param_size = params_reader.tellg(); + params_reader.seekg(0, std::ios_base::beg); + std::vector param_data(param_size / sizeof(char)); + params_reader.read((char*)¶m_data[0], param_size); params_reader.close(); + TVMByteArray params_arr; - params_arr.data = params_str.c_str(); - params_arr.size = params_str.length(); + params_arr.data = (char*)¶m_data[0]; + params_arr.size = param_size; + + tend = std::chrono::high_resolution_clock::now(); + r_param_read_ms = static_cast((tend - tstart).count()) / 1e6; // Load parameters + tstart = std::chrono::high_resolution_clock::now(); r_graph_handle.GetFunction("load_params")(params_arr); + tend = std::chrono::high_resolution_clock::now(); + r_param_load_ms = static_cast((tend - tstart).count()) / 1e6; return 0; } @@ -117,6 +144,7 @@ int TVMRunner::Load(void) { * \param file_name File name where pre-compiled programs should be stored. */ void TVMRunner::UsePreCompiledPrograms(std::string file_name) { + auto tstart = std::chrono::high_resolution_clock::now(); if (r_run_was_called) { LOG(INFO) << "TVMRunner UsePreCompiledPrograms: should be called before first run"; return; @@ -130,10 +158,19 @@ void TVMRunner::UsePreCompiledPrograms(std::string file_name) { std::ofstream fs(file_name, std::ofstream::binary); fs.write(bytes.c_str(), bytes.size()); } else { - std::string bytes((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); - f_set(String(bytes)); + ifs.seekg(0, std::ios_base::end); + std::size_t blob_size = ifs.tellg(); + ifs.seekg(0, std::ios_base::beg); + std::string blob_data; + blob_data.reserve(blob_size); + blob_data.resize(blob_size); + ifs.read((char*)blob_data.c_str(), blob_size); + ifs.close(); + f_set(String(blob_data)); } } + auto tend = std::chrono::high_resolution_clock::now(); + r_pre_compiled_load_ms = static_cast((tend - tstart).count()) / 1e6; } /*! @@ -156,8 +193,6 @@ inline size_t GetMemSize(NDArray& narr) { * \return The memory size. */ size_t TVMRunner::GetInputMemSize(std::string input_id) { - LOG(INFO) << "TVMRunner::GetInputMemSize:" << input_id; - NDArray in_arr = r_graph_handle.GetFunction("get_input")(input_id); auto ssize = GetMemSize(in_arr); @@ -170,8 +205,6 @@ size_t TVMRunner::GetInputMemSize(std::string input_id) { * \return The memory size. */ size_t TVMRunner::GetOutputMemSize(std::string output_id) { - LOG(INFO) << "TVMRunner::GetOutputMemSize:" << output_id; - NDArray out_arr = r_graph_handle.GetFunction("get_output")(output_id); auto ssize = GetMemSize(out_arr); @@ -209,13 +242,23 @@ int TVMRunner::SetInput(std::string inputfile) { * \param 0 on success else error code. */ int TVMRunner::SetInput(std::string input_id, char* raw_input) { - LOG(INFO) << "TVMRunner::SetInput (Raw)"; NDArray in_arr = r_graph_handle.GetFunction("get_input")(input_id); auto ssize = GetMemSize(in_arr); in_arr.CopyFromBytes(raw_input, ssize); return 0; } +/*! + * \brief Set the model input from given NDArray with zero copy. + * \param input_id input node name. + * \param ndarr NDArray. + * \param 0 on success else error code. + */ +int TVMRunner::SetInput(std::string input_id, NDArray& ndarr) { + r_graph_handle.GetFunction("set_input_zero_copy")(input_id, ndarr); + return 0; +} + /*! * \brief Get the model outputs and dump them to npz file. * \param outputfile the npz file to where we dump the output data. @@ -255,21 +298,29 @@ int TVMRunner::GetOutput(std::string outputfile) { * \param 0 on success else error code. */ int TVMRunner::GetOutput(std::string output_id, char* raw_output) { - LOG(INFO) << "TVMRunner::GetOutput (Raw)"; NDArray out_arr = r_graph_handle.GetFunction("get_output")(output_id); auto ssize = GetMemSize(out_arr); out_arr.CopyToBytes(raw_output, ssize); return 0; } +/*! + * \brief Set the model output from given NDArray with zero copy. + * \param output_id output node name. + * \param ndarr NDArray. + * \param 0 on success else error code. + */ +int TVMRunner::SetOutput(std::string output_id, NDArray& ndarr) { + r_graph_handle.GetFunction("set_output_zero_copy")(output_id, ndarr); + return 0; +} + /*! * \brief Call one cycle of execution for the model. * \param 0 on success else error code. */ int TVMRunner::Run(void) { - LOG(INFO) << "TVMRunner::Run"; r_run_was_called = true; - r_graph_handle.GetFunction("run")(); return 0; } @@ -289,10 +340,10 @@ TVMMetaInfo TVMRunner::GetMetaInfo(void) { auto dtype_info = GetRef>(tvm_input_info["dtype"].as()); for (const auto& kv : shape_info) { auto stuple = GetRef(kv.second.as()); - std::vector vshape; + std::vector vshape; vshape.assign(stuple.begin(), stuple.end()); auto dtype = GetRef(dtype_info[kv.first].as()); - std::pair, std::string> value = std::make_pair(vshape, dtype); + std::pair, std::string> value = std::make_pair(vshape, dtype); mInfo.input_info.insert({kv.first, value}); } @@ -301,10 +352,10 @@ TVMMetaInfo TVMRunner::GetMetaInfo(void) { dtype_info = GetRef>(tvm_input_info["dtype"].as()); for (const auto& kv : shape_info) { auto stuple = GetRef(kv.second.as()); - std::vector vshape; + std::vector vshape; vshape.assign(stuple.begin(), stuple.end()); auto dtype = GetRef(dtype_info[kv.first].as()); - std::pair, std::string> value = std::make_pair(vshape, dtype); + std::pair, std::string> value = std::make_pair(vshape, dtype); mInfo.output_info.insert({kv.first, value}); } @@ -343,5 +394,21 @@ void TVMRunner::PrintMetaInfo(void) { } } +/*! + * \brief Print stats information. + */ +void TVMRunner::PrintStats(void) { + LOG(INFO) << "Performance Stats:" << r_model_path; + LOG(INFO) << " Module Load :" << r_module_load_ms << " ms"; + LOG(INFO) << " Graph Runtime Create :" << r_graph_load_ms << " ms"; + LOG(INFO) << " Params Read :" << r_param_read_ms << " ms"; + LOG(INFO) << " Params Set :" << r_param_load_ms << " ms"; + LOG(INFO) << " Pre Compiled Progs Load :" << r_pre_compiled_load_ms << " ms"; + LOG(INFO) << "Total Load Time :" + << r_module_load_ms + r_graph_load_ms + r_param_read_ms + r_param_load_ms + + r_pre_compiled_load_ms + << " ms"; +} + } // namespace runtime } // namespace tvm diff --git a/apps/cpp_rtvm/tvm_runner.h b/apps/cpp_rtvm/tvm_runner.h index 47717c3ecf..e93b63ae85 100644 --- a/apps/cpp_rtvm/tvm_runner.h +++ b/apps/cpp_rtvm/tvm_runner.h @@ -41,8 +41,8 @@ namespace runtime { typedef struct _TVMMetaInfo { int n_inputs; int n_outputs; - std::map, std::string>> input_info; - std::map, std::string>> output_info; + std::map, std::string>> input_info; + std::map, std::string>> output_info; } TVMMetaInfo; /*! @@ -63,10 +63,14 @@ class TVMRunner { int SetInput(std::string); /*! \brief To set the input from binary data */ int SetInput(std::string, char*); + /*! \brief To set the input from NDArray */ + int SetInput(std::string, NDArray& ndarr); /*! \brief Save the model output into given npz file */ int GetOutput(std::string); /*! \brief Get the model output in binary format */ int GetOutput(std::string, char*); + /*! \brief Swap output NDArray with given one */ + int SetOutput(std::string, NDArray& ndarr); /*! \brief To get the input mem size */ size_t GetInputMemSize(std::string); /*! \brief To get the output mem size */ @@ -76,6 +80,21 @@ class TVMRunner { /*! \brief Print function to show all meta information */ void PrintMetaInfo(void); + /*! \brief Print function to show all stats information */ + void PrintStats(void); + + // Public profiling information + /*! Module load time */ + int r_module_load_ms{0}; + /*! Graph runtime creatint time */ + int r_graph_load_ms{0}; + /*! Params read time */ + int r_param_read_ms{0}; + /*! Params load time */ + int r_param_load_ms{0}; + /*! Pre compiled programs load time */ + int r_pre_compiled_load_ms{0}; + private: /*! \brief Module handle for the shared object */ Module r_mod_handle; @@ -91,6 +110,7 @@ class TVMRunner { bool r_run_was_called; }; +DLDeviceType GetTVMDevice(std::string device); } // namespace runtime } // namespace tvm #endif // TVM_APPS_CPP_RTVM_RUNNER_H_ diff --git a/apps/howto_deploy/prepare_test_libs.py b/apps/howto_deploy/prepare_test_libs.py index f5afc5bf67..45d7c0abd9 100644 --- a/apps/howto_deploy/prepare_test_libs.py +++ b/apps/howto_deploy/prepare_test_libs.py @@ -33,7 +33,13 @@ def prepare_test_libs(base_path): fadd_dylib.export_library(dylib_path) # Compile library in system library mode - fadd_syslib = tvm.build(s, [A, B], "llvm", name="addonesys") + fadd_syslib = tvm.build( + s, + [A, B], + "llvm", + name="addonesys", + runtime=relay.backend.Runtime("cpp", {"system-lib": True}), + ) syslib_path = os.path.join(base_path, "test_addone_sys.o") fadd_syslib.save(syslib_path) diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake index 6c21356ae8..6fb74fc1ef 100644 --- a/cmake/modules/LLVM.cmake +++ b/cmake/modules/LLVM.cmake @@ -29,6 +29,9 @@ add_definitions(-DDMLC_USE_FOPEN64=0 -DNDEBUG=1) # It may be a boolean or a string if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN}) find_llvm(${USE_LLVM}) + if (${TVM_LLVM_VERSION} LESS 60) + message(FATAL_ERROR "LLVM version 6.0 or greater is required.") + endif() include_directories(SYSTEM ${LLVM_INCLUDE_DIRS}) add_definitions(${LLVM_DEFINITIONS}) message(STATUS "Build with LLVM " ${LLVM_PACKAGE_VERSION}) diff --git a/docker/install/ubuntu_install_oneflow.sh b/docker/install/ubuntu_install_oneflow.sh index 3eb6b7d89b..04fccd5b9c 100755 --- a/docker/install/ubuntu_install_oneflow.sh +++ b/docker/install/ubuntu_install_oneflow.sh @@ -22,4 +22,4 @@ set -o pipefail pip3 install flowvision==0.1.0 -python3 -m pip install -f https://release.oneflow.info oneflow==0.7.0+cpu +python3 -m pip install oneflow==0.7.0 diff --git a/docs/how_to/deploy/adreno.rst b/docs/how_to/deploy/adreno.rst index ed016a3ff7..f0b8c6f757 100644 --- a/docs/how_to/deploy/adreno.rst +++ b/docs/how_to/deploy/adreno.rst @@ -432,7 +432,7 @@ as the OpenCL path is fallback option for any operator didn't go through OpenCLM python3 -m tvm.driver.tvmc compile \ --cross-compiler ${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang \ - --target="opencl, clml, llvm" --target-llvm-mtriple aarch64-linux-gnu --target-opencl-device adreno \ + --target="opencl, clml, llvm" --desired-layout NCHW --target-llvm-mtriple aarch64-linux-gnu --target-opencl-device adreno \ --tuning-records keras-resnet50.log -o keras-resnet50.tar resnet50.h5 On successful compilation, above command produce ``keras-resnet50.tar``. diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 3496136470..4273674180 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -445,6 +445,8 @@ def compile_model( # TODO lib.get_source call have inconsistent behavior for unsupported # formats (@leandron). dumps[source_type] = lib.get_source(source_type) + for smod in lib.imported_modules: + dumps[smod.type_key] = smod.get_source() # Create a new tvmc model package object from the graph definition. package_path = tvmc_model.export_package( diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 9e09cb400a..6c82ebb427 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -1062,6 +1062,8 @@ def _convert_simple_rnn( in_bias = etab.new_const(weightList[2]) assert len(in_data.type_annotation.shape) == 3 timeDim = in_data.type_annotation.shape[1].value + if keras_layer.go_backwards: + in_data = _op.reverse(in_data, axis=1) in_data_split = _op.split(in_data, indices_or_sections=timeDim, axis=1) for i in range(len(in_data_split)): in_data_split_i = _op.nn.batch_flatten(in_data_split[i]) @@ -1090,6 +1092,8 @@ def _convert_gru( recurrent_weight = etab.new_const(weightList[1].transpose([1, 0])) if keras_layer.use_bias: in_bias = etab.new_const(weightList[2]) + if keras_layer.go_backwards: + in_data = _op.reverse(in_data, axis=1) units = list(weightList[0].shape)[1] assert units > 0, "The value of units must be a positive integer" in_data = _op.nn.batch_flatten(in_data) diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 00db06bf3c..260d0ead9d 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Configuration of TVMScript printer""" +import os from typing import Dict, List, Optional, Sequence from tvm._ffi import get_global_func, register_object @@ -261,7 +262,7 @@ def _relax_script( def show( self, style: Optional[str] = None, - black_format: bool = False, + black_format: Optional[bool] = None, *, name: Optional[str] = None, show_meta: bool = False, @@ -290,8 +291,26 @@ def show( style : str, optional Pygmentize printing style, auto-detected if None. See `tvm.script.highlight.cprint` for more details. - black_format: bool - If true, use the formatter Black to format the TVMScript + + black_format: Optional[bool] + + If true, use the formatter Black to format the TVMScript. + If false, do not apply the auto-formatter. + + If None (default), determine the behavior based on the + environment variable "TVM_BLACK_FORMAT". If this + environment variable is unset, set to the empty string, or + set to the integer zero, black auto-formatting will be + disabled. If the environment variable is set to a + non-zero integer, black auto-formatting will be enabled. + + Note that the "TVM_BLACK_FORMAT" environment variable only + applies to the `.show()` method, and not the underlying + `.script()` method. The `.show()` method is intended for + human-readable output based on individual user + preferences, while the `.script()` method is intended to + provided a consistent output regardless of environment. + name : Optional[str] = None The name of the object show_meta : bool = False @@ -331,11 +350,16 @@ def show( Object to be underlined obj_to_annotate : Optional[Dict[Object, str]] = None Object to be annotated + """ from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel cprint, ) + if black_format is None: + env = os.environ.get("TVM_BLACK_FORMAT") + black_format = env and int(env) + cprint( self.script( name=name, diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index ff19bfda91..b7819d5b7f 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -532,7 +532,18 @@ def report_error( msg = "KeyError: " + str(err) else: msg = str(err) - self.diag.error(node, msg) + + try: + self.diag.error(node, msg) + except Exception as diag_err: + # Calling self.diag.error is guaranteed to throw an + # exception. When shown to a user, this error should + # reference the point of error within the provided + # TVMScript. However, when caught in pdb, the full + # traceback should be available for debugging. + if isinstance(err, Exception): + diag_err = diag_err.with_traceback(err.__traceback__) + raise diag_err def visit(self, node: doc.AST) -> None: """The general visiting method. diff --git a/python/tvm/target/codegen.py b/python/tvm/target/codegen.py index 5d43f4ae24..1a2efd4efa 100644 --- a/python/tvm/target/codegen.py +++ b/python/tvm/target/codegen.py @@ -17,6 +17,7 @@ """Code generation related functions.""" from . import _ffi_api from .target import Target +from ..ir.container import Array def build_module(mod, target): @@ -39,6 +40,30 @@ def build_module(mod, target): return _ffi_api.Build(mod, target) +def target_has_features(cpu_features, target=None): + """Check CPU features for the target's `-mtriple` and `-mcpu` and `-mattr`. + + Parameters + ---------- + target : Target + The TVM target. + cpu_features : str or Array + CPU Feature(s) to check. + + Returns + ------- + has_features : bool + True if target has the feature(s). + """ + assert isinstance(target, Target) or target is None + assert isinstance(cpu_features, (Array, list, tuple, str)) + has_feats = True + cpu_features = [cpu_features] if isinstance(cpu_features, str) else cpu_features + for feat in cpu_features: + has_feats &= _ffi_api.target_has_feature(feat, target) + return has_feats + + def llvm_lookup_intrinsic_id(name): """Lookup LLVM intrinsic id by name. @@ -71,36 +96,76 @@ def llvm_get_intrinsic_name(intrin_id: int) -> str: return _ffi_api.llvm_get_intrinsic_name(intrin_id) -def llvm_x86_get_archlist(only64bit=False): - """Get X86 CPU name list. +def llvm_get_targets(): + """Get LLVM target list. + + Parameters + ---------- + + Returns + ------- + llvm_targets : list[str] + List of available LLVM targets. + """ + return _ffi_api.llvm_get_targets() + + +def llvm_get_cpu_archlist(target=None): + """Get CPU architectures for the target's `-mtriple`. + + Parameters + ---------- + target : Target + The TVM target. + + Returns + ------- + cpu_archlist : list[str] + List of available CPU architectures. + """ + assert isinstance(target, Target) or target is None + return _ffi_api.llvm_get_cpu_archlist(target) + + +def llvm_get_cpu_features(target=None): + """Get CPU features for the target's `-mtriple` and `-mcpu` and considering `-mattr`. Parameters ---------- - only64bit : bool - Filter 64bit architectures. + target : Target + The TVM target. Returns ------- - features : list[str] - String list of X86 architectures. + cpu_features : list[str] + List of available CPU features. """ - return _ffi_api.llvm_x86_get_archlist(only64bit) + assert isinstance(target, Target) or target is None + return _ffi_api.llvm_get_cpu_features(target) -def llvm_x86_get_features(cpu_name): - """Get X86 CPU features. +def llvm_cpu_has_features(cpu_features, target=None): + """Check CPU features for the target's `-mtriple` and `-mcpu` and considering `-mattr`. Parameters ---------- - cpu_name : string - X86 CPU name (e.g. "skylake"). + target : Target + The TVM target. + cpu_features : str or Array + CPU Feature(s) to check. Returns ------- - features : list[str] - String list of X86 CPU features. + has_features : bool + True if target CPU has the feature(s). """ - return _ffi_api.llvm_x86_get_features(cpu_name) + assert isinstance(target, Target) or target is None + assert isinstance(cpu_features, (Array, list, tuple, str)) + has_feats = True + cpu_features = [cpu_features] if isinstance(cpu_features, str) else cpu_features + for feat in cpu_features: + has_feats &= _ffi_api.llvm_cpu_has_feature(feat, target) + return has_feats def llvm_version_major(allow_none=False): diff --git a/python/tvm/target/x86.py b/python/tvm/target/x86.py index a3dcb62e8a..c040eface8 100644 --- a/python/tvm/target/x86.py +++ b/python/tvm/target/x86.py @@ -16,30 +16,7 @@ # under the License. """Common x86 related utilities""" from .._ffi import register_func -from . import _ffi_api -from ..ir.container import Array - - -@register_func("tvm.target.x86.target_has_features") -def target_has_features(features, target=None): - """Check X86 CPU features. - Parameters - ---------- - features : str or Array - Feature(s) to check. - target : Target - Optional TVM target, default `None` use the global context target. - Returns - ------- - has_feats : bool - True if feature(s) are in the target arch. - """ - has_feats = True - assert isinstance(features, (Array, str)) - features = [features] if isinstance(features, str) else features - for feat in features: - has_feats &= _ffi_api.llvm_x86_has_feature(feat, target) - return has_feats +from .codegen import target_has_features @register_func("tvm.topi.x86.utils.get_simd_32bit_lanes") @@ -53,9 +30,6 @@ def get_simd_32bit_lanes(): The optimal vector length of CPU from the global context target. """ vec_len = 4 - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" if target_has_features(["avx512bw", "avx512f"]): vec_len = 16 elif target_has_features("avx2"): diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 85d1f19bba..5df3a486ca 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """The TensorIR schedule class""" +import inspect from typing import Callable, Dict, List, Optional, Tuple, Union from tvm._ffi import register_object as _register_object @@ -268,26 +269,24 @@ def fork_seed(self) -> int: """ return _ffi_api.ScheduleForkSeed(self) # type: ignore # pylint: disable=no-member - def show(self, style: Optional[str] = None, black_format: bool = False) -> None: + def show(self, *args, **kwargs) -> None: """A sugar for print highlighted TVM script. - Parameters - ---------- - style : str, optional - - Pygmentize printing style, auto-detected if None. See - `tvm.script.highlight.cprint` for more details. - - black_format: bool - - If true, use the formatter Black to format the TVMScript + All parameters are forwarded to the underlying `Module.show` + and `Trace.show` methods. """ mod = self.mod if mod is not None: - mod.show(style=style, black_format=black_format) + mod.show(*args, **kwargs) + trace = self.trace if trace is not None: - trace.show(style=style, black_format=black_format) + # Trace.show only supports the style and black_format arguments + param_binding = inspect.signature(mod.show).bind(*args, **kwargs) + param_binding.apply_defaults() + bound_args = param_binding.arguments + + trace.show(style=bound_args["style"], black_format=bound_args["black_format"]) ########## Lookup ########## diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py index e317304993..cb8d5ce997 100644 --- a/python/tvm/tir/schedule/trace.py +++ b/python/tvm/tir/schedule/trace.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """An execution trace of a scheduling program""" +import os from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from tvm._ffi import register_object as _register_object @@ -274,10 +275,16 @@ def show(self, style: Optional[str] = None, black_format: bool = False) -> None: black_format: bool - If true, use the formatter Black to format the TVMScript + If true, use the formatter Black to format the TVMScript. + If None, determine based on the "TVM_BLACK_FORMAT" environment + variable. """ from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel cprint, ) + if black_format is None: + env = os.environ.get("TVM_BLACK_FORMAT") + black_format = bool(env and int(env)) + cprint(str(self), style=style, black_format=black_format) diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index 85accab87b..e103133230 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -21,7 +21,7 @@ from tvm import autotvm, te from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cblas, mkl -from tvm.target.x86 import target_has_features +from tvm.target.codegen import target_has_features from .. import generic, nn from ..transform import layout_transform @@ -38,9 +38,6 @@ def batch_matmul_int8_compute(cfg, x, y, *_): packed_y = layout_transform(y, "BNK", packed_y_layout) _, n_o, _, n_i, _ = packed_y.shape ak = te.reduce_axis((0, k), name="k") - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" if target_has_features(["avx512bw", "avx512f"]): attrs_info = {"schedule_rule": "batch_matmul_int8"} else: @@ -241,9 +238,6 @@ def _callback(op): layout_trans = op.input_tensors[1] if target_has_features("amx-int8"): batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0], layout_trans) - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" elif target_has_features(["avx512bw", "avx512f"]): batch_matmul_int8_schedule(cfg, s, op.output(0), outs[0], layout_trans) diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 2437b1a695..4151ea0b70 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -23,7 +23,8 @@ from tvm import autotvm, te from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cblas, dnnl, mkl -from tvm.target.x86 import get_simd_32bit_lanes, target_has_features +from tvm.target.x86 import get_simd_32bit_lanes +from tvm.target.codegen import target_has_features from .. import generic, tag from ..utils import get_const_tuple, traverse_inline @@ -303,9 +304,6 @@ def _callback(op): if "dense_int8" in op.tag: if target_has_features("amx-int8"): dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" elif target_has_features(["avx512bw", "avx512f"]): dense_int8_schedule(cfg, s, op.output(0), outs[0]) @@ -318,9 +316,6 @@ def dense_int8_compute(cfg, X, packed_w, bias=None): m, k = X.shape n_o, _, n_i, _ = packed_w.shape ak = te.reduce_axis((0, k), name="k") - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" if target_has_features(["avx512bw", "avx512f"]): target_attr = {"schedule_rule": "meta_schedule.x86.dense_int8"} else: diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index ef6df7dd2c..0e9b1f7b65 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -19,7 +19,7 @@ import tvm from tvm import autotvm, relay, te -from tvm.target.x86 import target_has_features +from tvm.target.codegen import target_has_features from .. import nn from ..nn import dense_alter_layout @@ -28,9 +28,6 @@ def check_int8_applicable(x, y, allow_padding=False): - # avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - # avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - # + llvm.x86.avx512.pmaddw.d.512" simd_avai = target_has_features(["avx512bw", "avx512f"]) simd_avai |= target_has_features("amx-int8") # TODO(vvchernov): may be also target_has_features("avx2") or lower? diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 4657f962f3..73df303f72 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -24,20 +24,17 @@ namespace meta_schedule { String GetRuleKindFromTarget(const Target& target) { if (target->kind->name == "llvm") { - static const PackedFunc* llvm_x86_has_feature_fn_ptr = - runtime::Registry::Get("target.llvm_x86_has_feature"); - ICHECK(llvm_x86_has_feature_fn_ptr != nullptr) - << "The `target.llvm_x86_has_feature` func is not in tvm registry."; - bool have_avx512vnni = (*llvm_x86_has_feature_fn_ptr)("avx512vnni", target); - bool have_avxvnni = (*llvm_x86_has_feature_fn_ptr)("avxvnni", target); + static const PackedFunc* target_has_feature_fn_ptr = + runtime::Registry::Get("target.target_has_feature"); + ICHECK(target_has_feature_fn_ptr != nullptr) + << "The `target.target_has_feature` func is not in tvm registry."; + bool have_avx512vnni = (*target_has_feature_fn_ptr)("avx512vnni", target); + bool have_avxvnni = (*target_has_feature_fn_ptr)("avxvnni", target); if (have_avx512vnni || have_avxvnni) { return "vnni"; } else { - // avx512f: llvm.x86.avx512.addpd.w.512 (LLVM auto, added) - // avx512bw: llvm.x86.avx512.pmaddubs.w.512" (TVM required) - // + llvm.x86.avx512.pmaddw.d.512" - bool have_avx512f = (*llvm_x86_has_feature_fn_ptr)("avx512f", target); - bool have_avx512bw = (*llvm_x86_has_feature_fn_ptr)("avx512bw", target); + bool have_avx512f = (*target_has_feature_fn_ptr)("avx512f", target); + bool have_avx512bw = (*target_has_feature_fn_ptr)("avx512bw", target); if (have_avx512bw && have_avx512f) { return "avx512"; } diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index b57710b266..2dd74e1321 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -121,9 +121,9 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs, } bool has_current_target_sse41_support() { - auto llvm_x86_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.llvm_x86_has_feature"); - ICHECK(llvm_x86_has_feature_fn_ptr) << "Function target.llvm_x86_has_feature not found"; - return (*llvm_x86_has_feature_fn_ptr)("sse4.1", Target::Current(true)); + auto target_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.target_has_feature"); + ICHECK(target_has_feature_fn_ptr) << "Function target.target_has_feature not found"; + return (*target_has_feature_fn_ptr)("sse4.1", Target::Current(true)); } /* diff --git a/src/relay/qnn/op/requantize_config.h b/src/relay/qnn/op/requantize_config.h index 956bc3533b..a4238fa498 100644 --- a/src/relay/qnn/op/requantize_config.h +++ b/src/relay/qnn/op/requantize_config.h @@ -61,10 +61,10 @@ class RequantizeConfigNode : public Object { // For the x86 architecture, the float32 computation is expected to give significant speedup, // with little loss in the accuracy of the requantize operation. auto target = Target::Current(true); - auto llvm_x86_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.llvm_x86_has_feature"); - ICHECK(llvm_x86_has_feature_fn_ptr) << "Function target.llvm_x86_has_feature not found"; + auto target_has_feature_fn_ptr = tvm::runtime::Registry::Get("target.target_has_feature"); + ICHECK(target_has_feature_fn_ptr) << "Function target.target_has_feature not found"; if (target.defined() && target->kind->name == "llvm") { - if ((*llvm_x86_has_feature_fn_ptr)("sse4.1", target)) { + if ((*target_has_feature_fn_ptr)("sse4.1", target)) { return "float32"; } } diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index a303141357..8c1607c4e5 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -220,7 +220,7 @@ struct BufferDescriptor; class OpenCLWorkspace : public DeviceAPI { public: // type key - std::string type_key; + std::string type_key{"opencl"}; // available platforms std::vector platform_ids; // map platform to its context @@ -253,7 +253,7 @@ class OpenCLWorkspace : public DeviceAPI { // Initialize the device. void Init(const std::string& type_key, const std::string& device_type, const std::string& platform_name = ""); - virtual void Init() { Init("opencl", "gpu"); } + virtual void Init() { Init(this->type_key, "gpu"); } // Check whether the context is OpenCL or not. virtual bool IsOpenCLDevice(Device dev) { return dev.device_type == kDLOpenCL; } // get the queue of the device @@ -465,6 +465,8 @@ class OpenCLModuleNode : public OpenCLModuleNodeBase { : OpenCLModuleNodeBase(fmap), data_(data), fmt_(fmt), source_(source) {} PackedFunc GetFunction(const String& name, const ObjectPtr& sptr_to_self) final; + // Return true if OpenCL program for the requested function and device was created + bool IsProgramCreated(const std::string& func_name, int device_id); void SaveToFile(const String& file_name, const String& format) final; void SaveToBinary(dmlc::Stream* stream) final; void SetPreCompiledPrograms(const std::string& bytes); diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 35e77eb6d1..fb9adc2757 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -111,6 +111,7 @@ OpenCLWorkspace* OpenCLWorkspace::Global() { } cl_device_id OpenCLWorkspace::GetCLDeviceID(int device_id) { + this->Init(); ICHECK_LT(device_id, devices.size()) << "Invalid device id " << device_id << ". " << GetError(); return devices[device_id]; } @@ -210,6 +211,7 @@ void OpenCLWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) void* OpenCLWorkspace::CreateHostPtrIfEnabled(cl::BufferDescriptor* desc, Device dev, size_t size) { #if defined(OPENCL_ENABLE_HOST_PTR) + this->Init(); cl_int err_code; desc->host_ptr = reinterpret_cast( clEnqueueMapBuffer(this->GetQueue(dev), desc->buffer, CL_TRUE, CL_MAP_WRITE, 0, @@ -300,6 +302,7 @@ void OpenCLWorkspace::FreeTextureWorkspace(Device dev, void* ptr) { } void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { + this->Init(); size_t nbytes = GetDataSize(*from); ICHECK_EQ(nbytes, GetDataSize(*to)); ICHECK(IsContiguous(*from) && IsContiguous(*to)) @@ -379,6 +382,7 @@ void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHand } void OpenCLWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { + this->Init(); ICHECK(stream == nullptr); OPENCL_CALL(clFinish(this->GetQueue(dev))); } diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 6829d46d43..567b7ad88a 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -185,7 +185,6 @@ String OpenCLModuleNode::GetSource(const String& format) { void OpenCLModuleNode::Init() { workspace_ = GetGlobalWorkspace(); - workspace_->Init(); // initialize the kernel id, need to lock global table. std::lock_guard lock(workspace_->mu); for (const auto& kv : fmap_) { @@ -208,10 +207,17 @@ void OpenCLModuleNode::Init() { << "delimiter was found."; ICHECK_EQ(fmap_.size(), parsed_kernels_.size()) << "The number of parsed kernel sources does not match the number of kernel functions"; +} + +bool OpenCLModuleNode::IsProgramCreated(const std::string& func_name, int device_id) { + auto size = programs_[func_name].size(); + if (size > 0 && programs_[func_name][device_id] != nullptr) return true; + auto dev_size = GetGlobalWorkspace()->devices.size(); + ICHECK(device_id < static_cast(dev_size)) + << "Device id " << device_id << " is bigger than number of available devices"; // zero initialize cl_program pointers for each device kernel - for (auto& kv : parsed_kernels_) { - programs_.insert({kv.first, std::vector(workspace_->devices.size(), nullptr)}); - } + if (size == 0) programs_[func_name].resize(dev_size, nullptr); + return false; } cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t, @@ -220,7 +226,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre int device_id = t->device.device_id; auto did = w->GetCLDeviceID(device_id); auto platform = w->device_to_platform[did]; - if (programs_[func_name][device_id] == nullptr) { + if (!IsProgramCreated(func_name, device_id)) { // create program if (fmt_ == "cl") { const char* s = parsed_kernels_[func_name].c_str(); @@ -268,6 +274,7 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre } void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) { + workspace_->Init(); std::string data = bytes; dmlc::MemoryStringStream reader(&data); dmlc::Stream* strm = &reader; @@ -280,7 +287,7 @@ void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) { std::vector bin_vector; strm->Read(&name); strm->Read(&bin_vector); - if (programs_[name][device_id] == nullptr) { + if (!IsProgramCreated(name, device_id)) { cl_int err = 0; cl_int binaryStatus; size_t binarySize = bin_vector.size(); @@ -310,6 +317,7 @@ void OpenCLModuleNode::SetPreCompiledPrograms(const std::string& bytes) { } std::string OpenCLModuleNode::GetPreCompiledPrograms() { + workspace_->Init(); std::string data; dmlc::MemoryStringStream writer(&data); dmlc::Stream* strm = &writer; @@ -319,7 +327,7 @@ std::string OpenCLModuleNode::GetPreCompiledPrograms() { cl::OpenCLThreadEntry* t = workspace_->GetThreadEntry(); int device_id = t->device.device_id; t->kernel_table.resize(workspace_->num_registered_kernels); - if (programs_[std::string(name)][device_id] == nullptr) { + if (!IsProgramCreated(name, device_id)) { InstallKernel(workspace_, t, name, kid_map_[name]); } size_t size; diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 834f53510e..22fc119e03 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -42,6 +42,7 @@ namespace runtime { * \param data The module data. * \param fmt The format of the data, can be "clbin", "cl" * \param fmap The map function information map of each function. + * \param source Generated OpenCL kernels. */ Module OpenCLModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string source); diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index efe15c5c4a..1872d64d71 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -29,9 +29,7 @@ #if TVM_LLVM_VERSION >= 100 #include #endif -#include #include -#include #include #include @@ -43,38 +41,6 @@ namespace tvm { namespace codegen { -namespace { -bool TargetHasFeature(const llvm::TargetMachine& tm, const std::string& feature) { - // MCSubTargetInfo::checkFeatures was added in LLVM 6.0 -#if TVM_LLVM_VERSION >= 60 - const auto* MCInfo = tm.getMCSubtargetInfo(); - return MCInfo->checkFeatures(std::string("+") + feature); -#else - return false; - // TODO(tulloch) - enable this block, need to figure out how to reimplement - // this given visibility constraints, similar to - // https://github.com/rust-lang/rust/pull/31709 - - // Copied from - // https://github.com/llvm-mirror/llvm/blob/5136df4/lib/MC/MCSubtargetInfo.cpp#L78-L88. - - // auto checkFeatures = [&](const std::string FS) { - // llvm::SubtargetFeatures T(FS); - // llvm::FeatureBitset Set, All; - // for (std::string F : T.getFeatures()) { - // llvm::SubtargetFeatures::ApplyFeatureFlag(Set, F, MCInfo->ProcFeatures); - // if (F[0] == '-') { - // F[0] = '+'; - // } - // llvm::SubtargetFeatures::ApplyFeatureFlag(All, F, MCInfo->ProcFeatures); - // } - // return (MCInfo->getFeatureBits() & All) == Set; - // }; - // return checkFeatures(MCInfo, std::string("+") + feature); -#endif -} -} // namespace - class CodeGenX86_64 final : public CodeGenCPU { public: llvm::Value* VisitExpr_(const CastNode* op) override; @@ -92,9 +58,8 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { const auto to = op->dtype; if (from.is_float() && to.is_float() && from.bits() == 16 && to.bits() == 32) { ICHECK_EQ(from.lanes(), to.lanes()); - llvm::TargetMachine* tm = llvm_target_->GetOrCreateTargetMachine(); - const auto has_avx512 = TargetHasFeature(*tm, "avx512f"); + const auto has_avx512 = llvm_target_->TargetHasCPUFeature("avx512f"); if (from.lanes() >= 16 && has_avx512) { return CallVectorIntrin( @@ -111,7 +76,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { #if TVM_LLVM_VERSION <= 100 // The intrinsic x86_vcvtph2ps_256 was removed in LLVM 11. - const auto has_f16c = TargetHasFeature(*tm, "f16c"); + const auto has_f16c = llvm_target_->TargetHasCPUFeature("f16c"); if (from.lanes() >= 8 && has_f16c) { return CallVectorIntrin(llvm::Intrinsic::x86_vcvtph2ps_256, 8, diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index 2aa190ad70..e270a9b66c 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -33,6 +33,7 @@ #include #include #include +#include #if TVM_LLVM_VERSION >= 140 #include #else @@ -66,6 +67,27 @@ #include #include +#if TVM_LLVM_VERSION < 180 +namespace llvm { +#if TVM_LLVM_VERSION < 170 +// SubtargetSubTypeKV view +template MCSubtargetInfo::*Member> +struct ArchViewer { + friend ArrayRef& archViewer(MCSubtargetInfo Obj) { return Obj.*Member; } +}; +template struct ArchViewer<&MCSubtargetInfo::ProcDesc>; +ArrayRef& archViewer(MCSubtargetInfo); +#endif +// SubtargetFeatureKV view +template MCSubtargetInfo::*Member> +struct FeatViewer { + friend ArrayRef& featViewer(MCSubtargetInfo Obj) { return Obj.*Member; } +}; +template struct FeatViewer<&MCSubtargetInfo::ProcFeatures>; +ArrayRef& featViewer(MCSubtargetInfo); +} // namespace llvm +#endif + namespace tvm { namespace codegen { @@ -175,6 +197,17 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) { attrs_.push_back(s); } } + // llvm module target + if (target->kind->name == "llvm") { + // legalize -mcpu with the target -mtriple + auto arches = GetAllLLVMTargetArches(); + bool has_arch = + std::any_of(arches.begin(), arches.end(), [&](const auto& var) { return var == cpu_; }); + if (!has_arch) { + LOG(FATAL) << "LLVM cpu architecture `-mcpu=" << cpu_ + << "` is not valid in `-mtriple=" << triple_ << "`"; + } + } if (const Optional>& v = target->GetAttr>("cl-opt")) { llvm::StringMap& options = llvm::cl::getRegisteredOptions(); @@ -288,19 +321,59 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& scope, const std::string& target_st LLVMTargetInfo::~LLVMTargetInfo() = default; +static const llvm::Target* CreateLLVMTargetInstance(const std::string triple, + const bool allow_missing = true) { + std::string error; + // create LLVM instance + // required mimimum: llvm::InitializeAllTargets() + const llvm::Target* llvm_instance = llvm::TargetRegistry::lookupTarget(triple, error); + if (!allow_missing && !llvm_instance) { + ICHECK(llvm_instance) << "LLVM instance error: `" << error << "`"; + } + + return llvm_instance; +} + +static llvm::TargetMachine* CreateLLVMTargetMachine( + const llvm::Target* llvm_instance, const std::string& triple, const std::string& cpu, + const std::string& features, const llvm::TargetOptions& target_options, + const llvm::Reloc::Model& reloc_model, const llvm::CodeModel::Model& code_model, + const llvm::CodeGenOpt::Level& opt_level) { + llvm::TargetMachine* tm = llvm_instance->createTargetMachine( + triple, cpu, features, target_options, reloc_model, code_model, opt_level); + ICHECK(tm != nullptr); + + return tm; +} + +static const llvm::MCSubtargetInfo* GetLLVMSubtargetInfo(const std::string& triple, + const std::string& cpu_name, + const std::string& feats) { + // create a LLVM instance + auto llvm_instance = CreateLLVMTargetInstance(triple, true); + // create a target machine + // required minimum: llvm::InitializeAllTargetMCs() + llvm::TargetOptions target_options; + auto tm = CreateLLVMTargetMachine(llvm_instance, triple, cpu_name, feats, target_options, + llvm::Reloc::Static, llvm::CodeModel::Small, + llvm::CodeGenOpt::Level(0)); + // create subtarget info module + const llvm::MCSubtargetInfo* MCInfo = tm->getMCSubtargetInfo(); + + return MCInfo; +} + llvm::TargetMachine* LLVMTargetInfo::GetOrCreateTargetMachine(bool allow_missing) { if (target_machine_) return target_machine_.get(); std::string error; - if (const llvm::Target* llvm_instance = llvm::TargetRegistry::lookupTarget(triple_, error)) { + if (const llvm::Target* llvm_instance = CreateLLVMTargetInstance(triple_, allow_missing)) { llvm::TargetMachine* tm = - llvm_instance->createTargetMachine(triple_, cpu_, GetTargetFeatureString(), target_options_, - reloc_model_, code_model_, opt_level_); + CreateLLVMTargetMachine(llvm_instance, triple_, cpu_, GetTargetFeatureString(), + target_options_, reloc_model_, code_model_, opt_level_); target_machine_ = std::unique_ptr(tm); } - if (!allow_missing) { - ICHECK(target_machine_ != nullptr) << error; - } + ICHECK(target_machine_ != nullptr); return target_machine_.get(); } @@ -662,6 +735,75 @@ void LLVMTargetInfo::GetOptionValue(LLVMTargetInfo::Option* opt) const { } } +const Array LLVMTargetInfo::GetAllLLVMTargets() const { + Array llvm_targets; + // iterate all archtypes + for (auto a = llvm::Triple::ArchType(llvm::Triple::ArchType::UnknownArch + 1); + a < llvm::Triple::ArchType::LastArchType; a = llvm::Triple::ArchType(a + 1)) { + std::string target_name = llvm::Triple::getArchTypeName(a).str(); + // get valid target + if (CreateLLVMTargetInstance(target_name + "--", true)) { + llvm_targets.push_back(target_name); + } + } + + return llvm_targets; +} + +const Array LLVMTargetInfo::GetAllLLVMTargetArches() const { + Array cpu_arches; + // get the subtarget info module + const auto MCInfo = GetLLVMSubtargetInfo(triple_, "", ""); + if (!MCInfo) { + return cpu_arches; + } + // get all arches + llvm::ArrayRef llvm_arches = +#if TVM_LLVM_VERSION < 170 + llvm::archViewer(*(llvm::MCSubtargetInfo*)MCInfo); +#else + MCInfo->getAllProcessorDescriptions(); +#endif + for (const auto& arch : llvm_arches) { + cpu_arches.push_back(arch.Key); + } + + return cpu_arches; +} + +const Array LLVMTargetInfo::GetAllLLVMCpuFeatures() const { + std::string feats = ""; + for (const auto& attr : attrs_) { + feats += feats.empty() ? attr : ("," + attr); + } + // get the subtarget info module + const auto MCInfo = GetLLVMSubtargetInfo(triple_, cpu_.c_str(), feats); + // get all features for CPU + llvm::ArrayRef llvm_features = +#if TVM_LLVM_VERSION < 180 + llvm::featViewer(*(llvm::MCSubtargetInfo*)MCInfo); +#else + MCInfo->getAllProcessorFeatures(); +#endif + Array cpu_features; + for (const auto& feat : llvm_features) { + if (MCInfo->checkFeatures("+" + std::string(feat.Key))) { + cpu_features.push_back(feat.Key); + } + } + + return cpu_features; +} + +const bool LLVMTargetInfo::TargetHasCPUFeature(const std::string& feature) const { + // lookup features for `-mcpu` + auto feats = GetAllLLVMCpuFeatures(); + bool has_feature = + std::any_of(feats.begin(), feats.end(), [&](const auto& var) { return var == feature; }); + + return has_feature; +} + // LLVMTarget bool LLVMTarget::modified_llvm_state_ = false; diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h index 217db63aad..ac08008b80 100644 --- a/src/target/llvm/llvm_instance.h +++ b/src/target/llvm/llvm_instance.h @@ -266,6 +266,36 @@ class LLVMTargetInfo { */ bool MatchesGlobalState() const; + /*! + * \brief Get all supported targets from the LLVM backend + * \return list with all valid targets + */ + const Array GetAllLLVMTargets() const; + + /*! + * \brief Get all CPU arches from target + * \return list with all valid cpu architectures + * \note The arches are fetched from the LLVM backend using the target `-mtriple`. + */ + const Array GetAllLLVMTargetArches() const; + + /*! + * \brief Get all CPU features from target + * \return list with all valid cpu features + * \note The features are fetched from the LLVM backend using the target `-mtriple` + * and the `-mcpu` architecture, but also consider the `-mattr` attributes. + */ + const Array GetAllLLVMCpuFeatures() const; + + /*! + * \brief Check the target if has a specific cpu feature + * \param feature string with the feature to check + * \return true or false + * \note The feature is checked in the LLVM backend for the target `-mtriple` + * and `-mcpu` architecture, but also consider the `-mattr` attributes. + */ + const bool TargetHasCPUFeature(const std::string& feature) const; + protected: /*! * \brief Get the current value of given LLVM option diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 168163c416..05a7df230f 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -45,16 +45,6 @@ #include #include #include -#if TVM_LLVM_VERSION < 110 -#include -#include -#else -#if TVM_LLVM_VERSION < 170 -#include -#else -#include -#endif -#endif #include #include #include @@ -87,25 +77,6 @@ #include "codegen_llvm.h" #include "llvm_instance.h" -#if TVM_LLVM_VERSION < 110 -namespace llvm { -// SubtargetSubTypeKV view -template MCSubtargetInfo::*Member> -struct ArchViewer { - friend ArrayRef& archViewer(MCSubtargetInfo Obj) { return Obj.*Member; } -}; -template struct ArchViewer<&MCSubtargetInfo::ProcDesc>; -ArrayRef& archViewer(MCSubtargetInfo); -// SubtargetFeatureKV view -template MCSubtargetInfo::*Member> -struct FeatViewer { - friend ArrayRef& featViewer(MCSubtargetInfo Obj) { return Obj.*Member; } -}; -template struct FeatViewer<&MCSubtargetInfo::ProcFeatures>; -ArrayRef& featViewer(MCSubtargetInfo); -} // namespace llvm -#endif - namespace tvm { namespace codegen { @@ -514,131 +485,69 @@ TVM_REGISTER_GLOBAL("target.llvm_get_intrinsic_name").set_body_typed([](int64_t #endif }); -#if TVM_LLVM_VERSION < 110 -static const llvm::MCSubtargetInfo* llvm_compat_get_subtargetinfo(const std::string triple, - const std::string cpu_name) { - std::string error; - llvm::InitializeAllTargets(); - llvm::InitializeAllTargetMCs(); - // create a LLVM x86 instance - auto* llvm_instance = llvm::TargetRegistry::lookupTarget(triple, error); - // create a target machine - llvm::TargetOptions target_options; - auto RM = llvm::Optional(); - auto* tm = llvm_instance->createTargetMachine(triple, cpu_name.c_str(), "", target_options, RM); - // create subtarget info module - const llvm::MCSubtargetInfo* MCInfo = tm->getMCSubtargetInfo(); - - return MCInfo; -} - -static const Array llvm_compat_get_archlist(const std::string triple) { - // get the subtarget info module - const auto* MCInfo = llvm_compat_get_subtargetinfo(triple, ""); - // get all X86 arches - llvm::ArrayRef x86_arches = - llvm::archViewer(*(llvm::MCSubtargetInfo*)MCInfo); - Array cpu_arches; - for (auto& arch : x86_arches) { - cpu_arches.push_back(arch.Key); - } - return cpu_arches; -} - -static const Array llvm_compat_get_features(const std::string triple, - const std::string cpu_name) { - // get the subtarget info module - const auto* MCInfo = llvm_compat_get_subtargetinfo(triple, cpu_name.c_str()); - // get all features - llvm::ArrayRef x86_features = - llvm::featViewer(*(llvm::MCSubtargetInfo*)MCInfo); - // only targeted CPU features - Array cpu_features; - for (auto& feat : x86_features) { - if (MCInfo->checkFeatures("+" + std::string(feat.Key))) { - cpu_features.push_back(feat.Key); - } - } - return cpu_features; -} -#endif +TVM_REGISTER_GLOBAL("target.llvm_get_targets").set_body_typed([]() -> Array { + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, "llvm"); + return llvm_backend.GetAllLLVMTargets(); +}); -TVM_REGISTER_GLOBAL("target.llvm_x86_get_archlist") - .set_body_typed([](bool only64bit) -> Array { - Array cpu_arches; -#if TVM_LLVM_VERSION < 110 - cpu_arches = llvm_compat_get_archlist("x86_64--"); -#else - llvm::SmallVector x86_arches; - llvm::X86::fillValidCPUArchList(x86_arches, only64bit); - for (auto& arch : x86_arches) { - cpu_arches.push_back(arch.str()); +TVM_REGISTER_GLOBAL("target.llvm_get_cpu_archlist") + .set_body_typed([](const Target& target) -> Array { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return Array{}; + } } -#endif - return cpu_arches; + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, use_target); + return llvm_backend.GetAllLLVMTargetArches(); }); -TVM_REGISTER_GLOBAL("target.llvm_x86_get_features") - .set_body_typed([](std::string cpu_name) -> Array { - Array cpu_features; -#if TVM_LLVM_VERSION < 110 - cpu_features = llvm_compat_get_features("x86_64--", cpu_name); -#else - llvm::SmallVector x86_features; - llvm::X86::getFeaturesForCPU(cpu_name, x86_features); - for (auto& feat : x86_features) { - cpu_features.push_back(feat.str()); +TVM_REGISTER_GLOBAL("target.llvm_get_cpu_features") + .set_body_typed([](const Target& target) -> Array { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return Array{}; + } } -#endif - return cpu_features; + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, use_target); + return llvm_backend.GetAllLLVMCpuFeatures(); }); -TVM_REGISTER_GLOBAL("target.llvm_x86_has_feature") - .set_body_typed([](String feature, const Target& target) -> bool { - // target argument is optional (nullptr or None) - // if not explicit then use the current context target - Optional mcpu = target.defined() ? target->GetAttr("mcpu") - : Target::Current(false)->GetAttr("mcpu"); - Optional> mattr = target.defined() - ? target->GetAttr>("mattr") - : Target::Current(false)->GetAttr>("mattr"); - String name = target.defined() ? target->kind->name : Target::Current(false)->kind->name; - // lookup only for `llvm` targets having -mcpu - if ((name != "llvm") || !mcpu) { - return false; - } - // lookup in -mattr flags - bool is_in_mattr = - !mattr ? false - : std::any_of(mattr.value().begin(), mattr.value().end(), - [&](const String& var) { return var == ("+" + feature); }); -#if TVM_LLVM_VERSION < 110 - auto x86_arches = llvm_compat_get_archlist("x86_64--"); - // decline on invalid arch (avoid llvm assertion) - if (!std::any_of(x86_arches.begin(), x86_arches.end(), - [&](const String& var) { return var == mcpu.value(); })) { - return false; +TVM_REGISTER_GLOBAL("target.llvm_cpu_has_feature") + .set_body_typed([](const String feature, const Target& target) -> bool { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return false; + } } - // lookup in -mcpu llvm architecture flags - auto cpu_features = llvm_compat_get_features("x86_64--", mcpu.value()); + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_backend(*llvm_instance, use_target); + auto cpu_features = llvm_backend.GetAllLLVMCpuFeatures(); bool has_feature = std::any_of(cpu_features.begin(), cpu_features.end(), - [&](const String& var) { return var == feature; }); -#else - llvm::SmallVector x86_arches; - llvm::X86::fillValidCPUArchList(x86_arches, false); - // decline on invalid arch (avoid llvm assertion) - if (!std::any_of(x86_arches.begin(), x86_arches.end(), - [&](const llvm::StringRef& var) { return var == mcpu.value().c_str(); })) { - return false; + [&](auto& var) { return var == feature; }); + return has_feature; + }); + +TVM_REGISTER_GLOBAL("target.target_has_feature") + .set_body_typed([](const String feature, const Target& target) -> bool { + auto use_target = target.defined() ? target : Target::Current(false); + // ignore non "llvm" target + if (target.defined()) { + if (target->kind->name != "llvm") { + return false; + } } - // lookup in -mcpu llvm architecture flags - llvm::SmallVector x86_features; - llvm::X86::getFeaturesForCPU(mcpu.value().c_str(), x86_features); - bool has_feature = - std::any_of(x86_features.begin(), x86_features.end(), - [&](const llvm::StringRef& var) { return var == feature.c_str(); }); -#endif - return has_feature || is_in_mattr; + auto llvm_instance = std::make_unique(); + LLVMTargetInfo llvm_target(*llvm_instance, use_target); + return llvm_target.TargetHasCPUFeature(feature); }); TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int { diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 7d1fe9f8dd..1c15f95826 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -135,15 +135,8 @@ void StmtVisitor::VisitStmt_(const BlockNode* op) { VisitArray(op->reads, fvisit_buffer_region); VisitArray(op->writes, fvisit_buffer_region); VisitArray(op->match_buffers, - [this, fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) { + [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) { fvisit_buffer_region(match_buffer_region->source); - this->VisitExpr(match_buffer_region->buffer->elem_offset); - VisitArray(match_buffer_region->buffer->strides, - [this](const PrimExpr& e) { this->VisitExpr(e); }); - VisitArray(match_buffer_region->buffer->shape, - [this](const PrimExpr& e) { this->VisitExpr(e); }); - VisitArray(match_buffer_region->buffer->axis_separators, - [this](const IntImm& e) { this->VisitExpr(e); }); }); if (op->init.defined()) { this->VisitStmt(op->init.value()); @@ -245,28 +238,11 @@ class StmtMutator::Internal { static Array Mutate(StmtMutator* self, const Array& arr) { auto fmutate = [self](const MatchBufferRegion& match_buffer_region) { - const Buffer& buffer = match_buffer_region->buffer; Array region = Mutate(self, match_buffer_region->source->region); - PrimExpr elem_offset = self->VisitExpr(buffer->elem_offset); - Array strides = Mutate(self, buffer->strides); - Array shape = Mutate(self, buffer->shape); - Array axis_separators = - MutateArray(self, buffer->axis_separators, - [self](const IntImm& e) { return Downcast(self->VisitExpr(e)); }); - - if (elem_offset.same_as(buffer->elem_offset) && strides.same_as(buffer->strides) && - shape.same_as(buffer->shape) && axis_separators.same_as(buffer->axis_separators)) { - if (region.same_as(match_buffer_region->source->region)) { - return match_buffer_region; - } else { - return MatchBufferRegion(buffer, - BufferRegion(match_buffer_region->source->buffer, region)); - } + if (region.same_as(match_buffer_region->source->region)) { + return match_buffer_region; } else { - Buffer new_buffer(buffer->data, buffer->dtype, shape, strides, elem_offset, buffer->name, - buffer->data_alignment, buffer->offset_factor, buffer->buffer_type, - axis_separators, buffer->span); - return MatchBufferRegion(new_buffer, + return MatchBufferRegion(match_buffer_region->buffer, BufferRegion(match_buffer_region->source->buffer, region)); } }; diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 6165099558..ca56172cb7 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -798,6 +798,7 @@ def test_compile_opencl(tflite_mobilenet_v1_0_25_128): assert type(tvmc_package.lib_path) is str assert type(tvmc_package.params) is bytearray assert os.path.exists(dumps_path) + assert path.exists("{}.{}".format(tvmc_package.package_path, "opencl")) @tvm.testing.requires_cmsisnn diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index 9d33b15a91..8c5b578060 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -28,11 +28,11 @@ # prevent Keras from using up all gpu memory import keras +import pytest import tvm from tvm import relay from tvm.contrib import graph_executor import tvm.testing -import pytest if tf.executing_eagerly(): GPUS = tf.config.experimental.list_physical_devices("GPU") @@ -568,12 +568,23 @@ def test_forward_rnn(self, keras_mod): keras_mod.layers.SimpleRNN( units=16, return_state=False, activation="tanh", use_bias=False ), + keras_mod.layers.SimpleRNN( + units=16, return_state=False, activation="tanh", go_backwards=True + ), + keras_mod.layers.GRU( + units=16, + return_state=False, + recurrent_activation="sigmoid", + activation="tanh", + reset_after=False, + ), keras_mod.layers.GRU( units=16, return_state=False, recurrent_activation="sigmoid", activation="tanh", reset_after=False, + use_bias=False, ), keras_mod.layers.GRU( units=16, @@ -582,6 +593,7 @@ def test_forward_rnn(self, keras_mod): activation="tanh", reset_after=False, use_bias=False, + go_backwards=True, ), ] for rnn_func in rnn_funcs: diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py index 7ddc347e86..fda5f1b723 100644 --- a/tests/python/frontend/oneflow/test_forward.py +++ b/tests/python/frontend/oneflow/test_forward.py @@ -20,11 +20,11 @@ import numpy as np import oneflow as flow +from packaging import version as package_version import tvm import tvm.testing import tvm.topi.testing from tvm import relay -from packaging import version as package_version MODEL_HOME = "test_model" diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 0a0ae561ab..bd984d32e6 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1837,7 +1837,10 @@ def test_depthwise_conv2d_int8(): wdata = np.random.rand(*kernel_shape) * 10 parameters = {"weight": tvm.nd.array(wdata.astype(weight_dtype))} - targets = ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"] + targets = [ + "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512", + "llvm -mtriple=x86_64-linux-gnu -mcpu=cascadelake", + ] llvm_version = tvm.target.codegen.llvm_version_major() for target in targets: if llvm_version >= 8: diff --git a/tests/python/relay/test_op_qnn_conv2_transpose.py b/tests/python/relay/test_op_qnn_conv2_transpose.py index ec273eb2f7..18ad68df9e 100644 --- a/tests/python/relay/test_op_qnn_conv2_transpose.py +++ b/tests/python/relay/test_op_qnn_conv2_transpose.py @@ -644,7 +644,7 @@ def test_broadcast_layout(): func = relay.Function(relay.analysis.free_vars(func), func) mod = tvm.IRModule.from_expr(func) with tvm.transform.PassContext(opt_level=3): - libs = relay.build(mod, "llvm -mcpu=skylake-avx512") + libs = relay.build(mod, "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512") def test_non_scalar_input_scale_zp(): diff --git a/tests/python/relay/test_op_qnn_conv2d.py b/tests/python/relay/test_op_qnn_conv2d.py index 3736350cbf..e7d2c8941b 100644 --- a/tests/python/relay/test_op_qnn_conv2d.py +++ b/tests/python/relay/test_op_qnn_conv2d.py @@ -948,7 +948,9 @@ def test_broadcast_layout(): func = relay.Function(relay.analysis.free_vars(func), func) mod = tvm.IRModule.from_expr(func) with tvm.transform.PassContext(opt_level=3): - graph, lib, params = relay.build(mod, "llvm -mcpu=skylake-avx512") + graph, lib, params = relay.build( + mod, "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512" + ) def test_depthwise_depth_multiplier(): diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 829c1d6ae4..87065b2d27 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -886,7 +886,7 @@ def before(): from tvm import topi def alter_conv2d(attrs, inputs, tinfos, out_type): - with tvm.target.Target("llvm -mcpu=core-avx2"): + with tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=core-avx2"): return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, out_type) def expected(): @@ -1373,7 +1373,7 @@ def expected(): y = relay.Function(analysis.free_vars(y), y) return y - target = "llvm -mcpu=core-avx2" + target = "llvm -mtriple=x86_64-linux-gnu -mcpu=core-avx2" with tvm.target.Target(target): with TempOpAttr( "nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout @@ -1441,7 +1441,7 @@ def expected(): ) return relay.Function(analysis.free_vars(dense), dense) - with tvm.target.Target("llvm -mcpu=core-avx2"): + with tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=core-avx2"): with TempOpAttr( "nn.dense", "FTVMAlterOpLayout", topi.x86.dense_alter_op._alter_dense_layout ): diff --git a/tests/python/relay/test_pass_qnn_legalize.py b/tests/python/relay/test_pass_qnn_legalize.py index 73ba9c2208..a32100ea20 100644 --- a/tests/python/relay/test_pass_qnn_legalize.py +++ b/tests/python/relay/test_pass_qnn_legalize.py @@ -138,7 +138,10 @@ def _get_mod(data_dtype, kernel_dtype): # Check transformations for platforms with fast Int8 support. ############################################################# # Check that Intel AVX512 (with or w/o VNNI) gets picked up. - for target in ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]: + for target in [ + "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512", + "llvm -mtriple=x86_64-linux-gnu -mcpu=cascadelake", + ]: with tvm.target.Target(target): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) @@ -170,7 +173,7 @@ def _get_mod(data_dtype, kernel_dtype): # Check transformations for platforms with fast Int8 support. ############################################################# # Check no transformation for Intel AVX512. - with tvm.target.Target("llvm -mcpu=skylake-avx512"): + with tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512"): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) assert tvm.ir.structural_equal(mod, legalized_mod) @@ -232,7 +235,10 @@ def _get_mod(data_dtype, kernel_dtype): # Check transformations for platforms with fast Int8 support. ############################################################# # Check that Intel AVX512 (with or w/o VNNI) gets picked up. - for target in ["llvm -mcpu=skylake-avx512", "llvm -mcpu=cascadelake"]: + for target in [ + "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512", + "llvm -mtriple=x86_64-linux-gnu -mcpu=cascadelake", + ]: with tvm.target.Target(target): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) @@ -264,7 +270,7 @@ def _get_mod(data_dtype, kernel_dtype): # Check transformations for platforms with fast Int8 support. ############################################################# # Check no transformation for Intel AVX512. - with tvm.target.Target("llvm -mcpu=skylake-avx512"): + with tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=skylake-avx512"): mod = relay.transform.InferType()(mod) legalized_mod = relay.qnn.transform.Legalize()(mod) assert tvm.ir.structural_equal(mod, legalized_mod) diff --git a/tests/python/target/test_llvm_features_info.py b/tests/python/target/test_llvm_features_info.py new file mode 100644 index 0000000000..1be71331dd --- /dev/null +++ b/tests/python/target/test_llvm_features_info.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm.target import _ffi_api, codegen, Target + +LLVM_VERSION = codegen.llvm_version_major() + + +def test_llvm_targets(): + + ## + ## check LLVM backend + ## + + # check blank results + assert len(codegen.llvm_get_targets()) + # check ffi vs python + assert str(codegen.llvm_get_targets()) == str(_ffi_api.llvm_get_targets()) + + # check LLVM target -mcpu legality + try: + tvm.target.codegen.llvm_get_cpu_features( + tvm.target.Target("llvm -mtriple=x86_64-linux-gnu -mcpu=dummy") + ) + assert False + except tvm.error.TVMError as e: + msg = str(e) + assert ( + msg.find( + "TVMError: LLVM cpu architecture `-mcpu=dummy` is not valid in `-mtriple=x86_64-linux-gnu`" + ) + != -1 + ) + + +min_llvm_version, llvm_target, cpu_arch, cpu_features, is_supported = tvm.testing.parameters( + (-1, "x86_64", "sandybridge", "sse4.1", True), + (-1, "x86_64", "ivybridge", ["sse4.1", "ssse3"], True), + (-1, "x86_64", "ivybridge", ["sse4.1", "ssse3", "avx512bw"], False), + # 32bit vs 64bit + (-1, "aarch64", "cortex-a55", "neon", True), + (-1, "aarch64", "cortex-a55", "dotprod", True), + (-1, "aarch64", "cortex-a55", "dsp", False), + (-1, "arm", "cortex-a55", "dsp", True), + (-1, "aarch64", "cortex-a55", ["neon", "dotprod"], True), + (-1, "aarch64", "cortex-a55", ["neon", "dotprod", "dsp"], False), + (-1, "arm", "cortex-a55", ["neon", "dotprod"], True), + (-1, "aarch64", "cortex-a55", ["neon", "dotprod", "dsp"], False), + (-1, "arm", "cortex-a55", ["neon", "dotprod", "dsp"], True), +) + + +def test_target_features(min_llvm_version, llvm_target, cpu_arch, cpu_features, is_supported): + + target = Target("llvm -mtriple=%s-- -mcpu=%s" % (llvm_target, cpu_arch)) + + ## + ## legalize llvm_target + ## + + assert llvm_target in codegen.llvm_get_targets() + + ## + ## legalize cpu_arch + ## + + ### with context + with target: + assert cpu_arch in codegen.llvm_get_cpu_archlist() + ### no context but with expicit target + assert cpu_arch in codegen.llvm_get_cpu_archlist(target) + # check ffi vs python + assert str(codegen.llvm_get_cpu_archlist(target)) == str(_ffi_api.llvm_get_cpu_archlist(target)) + + ## + ## check has_features + ## + + ### with context + with target: + assert codegen.llvm_cpu_has_features(cpu_features) == is_supported + ### no context but with expicit target + assert codegen.llvm_cpu_has_features(cpu_features, target) == is_supported + # check ffi vs python + for feat in cpu_features: + assert str(codegen.llvm_cpu_has_features(feat, target)) == str( + _ffi_api.llvm_cpu_has_feature(feat, target) + ) diff --git a/tests/python/target/test_x86_features.py b/tests/python/target/test_x86_features.py index 31a823b504..ef1ab9b423 100644 --- a/tests/python/target/test_x86_features.py +++ b/tests/python/target/test_x86_features.py @@ -18,85 +18,89 @@ import tvm from tvm.target import _ffi_api, codegen, Target -from tvm.target.x86 import target_has_features +from tvm.target.codegen import target_has_features LLVM_VERSION = codegen.llvm_version_major() min_llvm_version, tvm_target, x86_feature, is_supported = tvm.testing.parameters( # sse4.1 - (-1, "llvm -mcpu=btver2", "sse4a", True), - (-1, "llvm -mcpu=penryn", "sse4.1", True), - (-1, "llvm -mcpu=silvermont", "sse4.2", True), - (11, "llvm -mcpu=slm", "sse4.2", True), - (-1, "llvm -mcpu=goldmont", "sse4.2", True), - (-1, "llvm -mcpu=goldmont-plus", "sse4.2", True), - (-1, "llvm -mcpu=tremont", "sse4.2", True), - (-1, "llvm -mcpu=nehalem", "sse4.2", True), - (11, "llvm -mcpu=corei7", "sse4.2", True), - (-1, "llvm -mcpu=westmere", "sse4.2", True), - (-1, "llvm -mcpu=bdver1", "sse4.2", True), - (-1, "llvm -mcpu=bdver2", "sse4.2", True), - (-1, "llvm -mcpu=bdver3", "sse4.2", True), - (11, "llvm -mcpu=x86-64-v2", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=btver2", "sse4a", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=penryn", "sse4.1", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=silvermont", "sse4.2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=slm", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=goldmont", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=goldmont-plus", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=tremont", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=nehalem", "sse4.2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=corei7", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=westmere", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=bdver1", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=bdver2", "sse4.2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=bdver3", "sse4.2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=x86-64-v2", "sse4.2", True), # avx - (-1, "llvm -mcpu=sandybridge", "avx", True), - (11, "llvm -mcpu=corei7-avx", "avx", True), - (-1, "llvm -mcpu=ivybridge", "avx", True), - (11, "llvm -mcpu=core-avx-i", "avx", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=sandybridge", "avx", True), + (11, "llvm -mtriple=x86_64-- -mcpu=corei7-avx", "avx", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=ivybridge", "avx", True), + (11, "llvm -mtriple=x86_64-- -mcpu=core-avx-i", "avx", True), # avx2 - (-1, "llvm -mcpu=haswell", "avx2", True), - (11, "llvm -mcpu=core-avx2", "avx2", True), - (-1, "llvm -mcpu=broadwell", "avx2", True), - (-1, "llvm -mcpu=skylake", "avx2", True), - (-1, "llvm -mcpu=bdver4", "avx2", True), - (-1, "llvm -mcpu=znver1", "avx2", True), - (-1, "llvm -mcpu=znver2", "avx2", True), - (11, "llvm -mcpu=znver3", "avx2", True), - (11, "llvm -mcpu=x86-64-v3", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=haswell", "avx2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=core-avx2", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=broadwell", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=skylake", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=bdver4", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=znver1", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=znver2", "avx2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=znver3", "avx2", True), + (11, "llvm -mtriple=x86_64-- -mcpu=x86-64-v3", "avx2", True), # avx512bw - (-1, "llvm -mcpu=skylake-avx512", "avx512bw", True), - (11, "llvm -mcpu=skx", "avx512bw", True), - (11, "llvm -mcpu=knl", "avx512bw", False), - (-1, "llvm -mcpu=knl", "avx512f", True), - (11, "llvm -mcpu=knl", ["avx512bw", "avx512f"], False), - (11, "llvm -mcpu=knl", ("avx512bw", "avx512f"), False), - (-1, "llvm -mcpu=knl", "avx512cd", True), - (11, "llvm -mcpu=knl", ["avx512cd", "avx512f"], True), - (11, "llvm -mcpu=knl", ("avx512cd", "avx512f"), True), - (-1, "llvm -mcpu=knl", "avx512er", True), - (-1, "llvm -mcpu=knl", "avx512pf", True), - (11, "llvm -mcpu=knm", "avx512bw", False), - (-1, "llvm -mcpu=knm", "avx512f", True), - (-1, "llvm -mcpu=knm", "avx512cd", True), - (-1, "llvm -mcpu=knm", "avx512er", True), - (-1, "llvm -mcpu=knm", "avx512pf", True), - (11, "llvm -mcpu=x86-64-v4", "avx512bw", True), - (-1, "llvm -mcpu=cannonlake", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=skylake-avx512", "avx512bw", True), + (11, "llvm -mtriple=x86_64-- -mcpu=skx", "avx512bw", True), + (11, "llvm -mtriple=x86_64-- -mcpu=knl", "avx512bw", False), + (-1, "llvm -mtriple=x86_64-- -mcpu=knl", "avx512f", True), + (11, "llvm -mtriple=x86_64-- -mcpu=knl", ["avx512bw", "avx512f"], False), + (11, "llvm -mtriple=x86_64-- -mcpu=knl", ("avx512bw", "avx512f"), False), + (-1, "llvm -mtriple=x86_64-- -mcpu=knl", "avx512cd", True), + (11, "llvm -mtriple=x86_64-- -mcpu=knl", ["avx512cd", "avx512f"], True), + (11, "llvm -mtriple=x86_64-- -mcpu=knl", ("avx512cd", "avx512f"), True), + (-1, "llvm -mtriple=x86_64-- -mcpu=knl", "avx512er", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=knl", "avx512pf", True), + (11, "llvm -mtriple=x86_64-- -mcpu=knm", "avx512bw", False), + (-1, "llvm -mtriple=x86_64-- -mcpu=knm", "avx512f", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=knm", "avx512cd", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=knm", "avx512er", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=knm", "avx512pf", True), + (11, "llvm -mtriple=x86_64-- -mcpu=x86-64-v4", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=cannonlake", "avx512bw", True), # explicit enumeration of VNNI capable due to collision with alderlake - (11, "llvm -mcpu=alderlake", "avx512bw", False), - (-1, "llvm -mcpu=cascadelake", "avx512bw", True), - (-1, "llvm -mcpu=icelake-client", "avx512bw", True), - (-1, "llvm -mcpu=icelake-server", "avx512bw", True), - (11, "llvm -mcpu=rocketlake", "avx512bw", True), - (-1, "llvm -mcpu=tigerlake", "avx512bw", True), - (-1, "llvm -mcpu=cooperlake", "avx512bw", True), - (11, "llvm -mcpu=sapphirerapids", "avx512bw", True), + (11, "llvm -mtriple=x86_64-- -mcpu=alderlake", "avx512bw", False), + (-1, "llvm -mtriple=x86_64-- -mcpu=cascadelake", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=icelake-client", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=icelake-server", "avx512bw", True), + (11, "llvm -mtriple=x86_64-- -mcpu=rocketlake", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=tigerlake", "avx512bw", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=cooperlake", "avx512bw", True), + (11, "llvm -mtriple=x86_64-- -mcpu=sapphirerapids", "avx512bw", True), # avx512vnni - (11, "llvm -mcpu=alderlake", "avx512vnni", False), - (11, "llvm -mcpu=alderlake", "avxvnni", True), - (-1, "llvm -mcpu=cascadelake", "avx512vnni", True), - (-1, "llvm -mcpu=icelake-client", "avx512vnni", True), - (-1, "llvm -mcpu=icelake-server", "avx512vnni", True), - (11, "llvm -mcpu=rocketlake", "avx512vnni", True), - (-1, "llvm -mcpu=tigerlake", "avx512vnni", True), - (-1, "llvm -mcpu=cooperlake", "avx512vnni", True), - (11, "llvm -mcpu=sapphirerapids", "avx512vnni", True), + (11, "llvm -mtriple=x86_64-- -mcpu=alderlake", "avx512vnni", False), + (11, "llvm -mtriple=x86_64-- -mcpu=alderlake", "avxvnni", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=cascadelake", "avx512vnni", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=icelake-client", "avx512vnni", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=icelake-server", "avx512vnni", True), + (11, "llvm -mtriple=x86_64-- -mcpu=rocketlake", "avx512vnni", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=tigerlake", "avx512vnni", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=cooperlake", "avx512vnni", True), + (11, "llvm -mtriple=x86_64-- -mcpu=sapphirerapids", "avx512vnni", True), # amx-int8 - (11, "llvm -mcpu=sapphirerapids", "amx-int8", True), + (11, "llvm -mtriple=x86_64-- -mcpu=sapphirerapids", "amx-int8", True), # generic CPU (no features) but with extra -mattr - (-1, "llvm -mcpu=x86-64 -mattr=+sse4.1,+avx2", "avx2", True), - (-1, "llvm -mcpu=x86-64 -mattr=+sse4.1,+avx2", "sse4.1", True), - (-1, "llvm -mcpu=x86-64 -mattr=+sse4.1,+avx2", "ssse3", False), + (-1, "llvm -mtriple=x86_64-- -mcpu=x86-64 -mattr=+sse4.1,+avx2", "avx2", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=x86-64 -mattr=+sse4.1,+avx2", "sse4.1", True), + # enabling +sse4.1 implies ssse3 presence in LLVM + (-1, "llvm -mtriple=x86_64-- -mcpu=x86-64 -mattr=+sse4.1,+avx2", "ssse3", True), + (-1, "llvm -mtriple=x86_64-- -mcpu=ivybridge -mattr=-ssse3", "ssse3", False), + # disabling avx512f (foundation) also disables avx512bw + (-1, "llvm -mtriple=x86_64-- -mcpu=cascadelake -mattr=-avx512f", "avx512bw", False), ) @@ -135,7 +139,7 @@ def test_x86_target_features(min_llvm_version, tvm_target, x86_feature, is_suppo if isinstance(x86_feature, str): # check for feature via the ffi llvm api (no explicit target, no context target) try: - assert _ffi_api.llvm_x86_has_feature(x86_feature, None) == is_supported + assert _ffi_api.target_has_feature(x86_feature, None) == is_supported assert False except tvm.error.InternalError as e: msg = str(e) @@ -154,7 +158,7 @@ def test_x86_target_features(min_llvm_version, tvm_target, x86_feature, is_suppo assert target_has_features(x86_feature, Target(tvm_target)) == is_supported if isinstance(x86_feature, str): # check for feature via the ffi llvm api (with explicit target, no context target) - assert _ffi_api.llvm_x86_has_feature(x86_feature, Target(tvm_target)) == is_supported + assert _ffi_api.target_has_feature(x86_feature, Target(tvm_target)) == is_supported ## ## with context @@ -166,11 +170,8 @@ def test_x86_target_features(min_llvm_version, tvm_target, x86_feature, is_suppo assert target_has_features(x86_feature) == is_supported # check for feature via the python api (with explicit target) assert target_has_features(x86_feature, Target(tvm_target)) == is_supported - if isinstance(x86_feature, str): - # check for feature via the ffi llvm api (current context target) - assert _ffi_api.llvm_x86_has_feature(x86_feature, None) == is_supported - # check for feature via the ffi llvm api (with explicit target) - assert _ffi_api.llvm_x86_has_feature(x86_feature, Target(tvm_target)) == is_supported - # check for feature in target's llvm full x86 CPU feature list - if not Target(tvm_target).mattr: - assert (x86_feature in codegen.llvm_x86_get_features(mcpu)) == is_supported + # check for feature via the ffi llvm api (current context target) + (sum(_ffi_api.target_has_feature(feat, None) for feat in x86_feature) > 0) == is_supported + # check for feature in target's llvm full x86 CPU feature list + if (not Target(tvm_target).mattr) and isinstance(x86_feature, str): + assert (x86_feature in codegen.llvm_get_cpu_features()) == is_supported diff --git a/tests/python/unittest/test_allreduce.py b/tests/python/unittest/test_allreduce.py index 708384daf0..fed4e4c04d 100644 --- a/tests/python/unittest/test_allreduce.py +++ b/tests/python/unittest/test_allreduce.py @@ -19,6 +19,8 @@ import numpy as np from tvm.script import tir as T +import pytest + @T.prim_func def reduce(a: T.handle, b: T.handle, d1: T.int32, d2: T.int32, d3: T.int32) -> None: @@ -82,6 +84,48 @@ def test_allreduce_sum(dims, target, dev): tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-6, atol=1e-6) +define_metal_compile_callback = tvm.testing.parameter(True, False) + + +@pytest.fixture +def optional_metal_compile_callback(define_metal_compile_callback): + name = "tvm_callback_metal_compile" + cached = tvm.get_global_func(name, allow_missing=True) + + if define_metal_compile_callback: + + @tvm.register_func(name, override=True) + def compile_metal(src, target): + return tvm.contrib.xcode.compile_metal(src, sdk="macosx") + + yield + + if define_metal_compile_callback: + if cached is None: + tvm._ffi.registry.remove_global_func(name) + else: + tvm.register_func(name, cached, override=True) + + +@tvm.testing.requires_metal(support_required="compile-only") +def test_allreduce_sum_compile(optional_metal_compile_callback): + # Disable the parametrization over dims, at least for now + dims = (1, 1, 2) + target = "metal" + + d1, d2, d3 = dims + _, _, _d1, _d2, _d3 = reduce.params + mod = reduce.specialize({_d1: d1, _d2: d2, _d3: d3}) + sch = tvm.tir.Schedule(mod) + blk = sch.get_block("reduce") + i, j, k, l = sch.get_loops(blk) + sch.bind(i, "blockIdx.x") + sch.bind(j, "threadIdx.z") + sch.bind(k, "threadIdx.y") + sch.bind(l, "threadIdx.x") + tvm.build(sch.mod["main"], target=target) + + @tvm.testing.parametrize_targets("cuda", "metal") def test_allreduce_max(dims, target, dev): d1, d2, d3 = dims diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 6a2f5573b2..f1316ae3ce 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -946,7 +946,7 @@ def test_llvm_target_attributes(): xo, xi = s[C].split(C.op.axis[0], nparts=2) s[C].parallel(xo) - target_llvm = "llvm -mcpu=skylake -mattr=+avx512f" + target_llvm = "llvm -mtriple=x86_64-linux-gnu -mcpu=skylake -mattr=+avx512f" target = tvm.target.Target(target_llvm, host=target_llvm) module = tvm.build(s, [A, B, C, n], target=target, name="test_func") diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py b/tests/python/unittest/test_tir_transform_unify_thread_binding.py index d42adfcee4..9ee8643312 100644 --- a/tests/python/unittest/test_tir_transform_unify_thread_binding.py +++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py @@ -258,45 +258,6 @@ def unified_element_wise_implicit_block(a: T.handle, b: T.handle, c: T.handle) - ) -@T.prim_func -def match_buffer_with_elem_offset( - A: T.Buffer((8, 10, 8), "float32"), I: T.Buffer((4,), "int32"), offset: T.int32 -) -> None: - for i in T.thread_binding(0, 4, "blockIdx.x"): - for j in range(2): - with T.block(): - T.writes(A[I[i], offset, j * 4 : j * 4 + 4]) - sub_A = T.match_buffer( - A[I[i], offset, j * 4 : j * 4 + 4], - (4), - elem_offset=I[i] * 80 + offset * 8 + j * 4, - ) - for ji in range(0, 4): - sub_A[j * 4 + ji] = 1 - - -@T.prim_func -def unified_match_buffer_with_elem_offset( - A: T.Buffer((8, 10, 8), "float32"), I: T.Buffer((4,), "int32"), offset: T.int32 -) -> None: - for blockIdx_x in T.thread_binding(4, thread="blockIdx.x"): - for j in range(2): - with T.block(""): - T.reads(I[blockIdx_x]) - T.writes(A[I[blockIdx_x], offset, j * 4 : j * 4 + 4]) - sub_A = T.match_buffer( - A[I[blockIdx_x], offset, j * 4 : j * 4 + 4], - (4,), - elem_offset=I[blockIdx_x] * 80 + offset * 8 + j * 4, - ) - for ji in range(4): - i = T.int32() - sub_A_1 = T.Buffer( - (4,), data=sub_A.data, elem_offset=I[i] * 80 + offset * 8 + j * 4 - ) - sub_A_1[j * 4 + ji] = T.float32(1) - - def test_thread_x(): _check(element_wise_thread_x, unified_element_wise_thread_x) @@ -327,10 +288,6 @@ def test_implicit_block(): _check(element_wise_implicit_block, unified_element_wise_implicit_block) -def test_match_buffer_with_elem_offset(): - _check(match_buffer_with_elem_offset, unified_match_buffer_with_elem_offset) - - def test_inner_binding_with_annotation(): @T.prim_func def inner_binding_with_annotation(A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")): diff --git a/tests/scripts/setup-adreno-env.sh b/tests/scripts/setup-adreno-env.sh index 55a92c5f61..15c124a0f0 100755 --- a/tests/scripts/setup-adreno-env.sh +++ b/tests/scripts/setup-adreno-env.sh @@ -20,12 +20,13 @@ ENVIRONMENT="" RPC_PORT="" ADB_SERIAL="" +LISTEN_PORT=5000 function usage() { echo "Helper script to setup the environment for Tracker, RPC Device and for application" echo "Usage (Help) : source setup-adreno-env.sh -h" echo "Usage (Tracker): source setup-adreno-env.sh -e tracker -p " - echo "Usage (Device): source setup-adreno-env.sh -e device -p -d " + echo "Usage (Device): source setup-adreno-env.sh -e device -p -d [-l ]" echo "Usage (Query): source setup-adreno-env.sh -e query -p " } @@ -46,6 +47,11 @@ while [[ $# -gt 0 ]]; do shift # past argument shift # past value ;; + -l|--listen-port) + LISTEN_PORT="$2" + shift # past argument + shift # past value + ;; -h|--help) usage return 0 @@ -62,6 +68,7 @@ done echo "ENVIRONMENT = ${ENVIRONMENT}" echo "RPC_PORT = ${RPC_PORT}" echo "ADB_SERIAL = ${ADB_SERIAL}" +echo "DEVICE LISTEN POPRT = ${LISTEN_PORT}" function def_environment() { @@ -100,10 +107,11 @@ case ${ENVIRONMENT} in fi adb reverse tcp:${TVM_TRACKER_PORT} tcp:${TVM_TRACKER_PORT} - adb forward tcp:5000 tcp:5000 - adb forward tcp:5001 tcp:5001 - adb forward tcp:5002 tcp:5002 - adb shell "cd ${TARGET_FOLDER}; killall -9 tvm_rpc-${USER}; sleep 2; LD_LIBRARY_PATH=${TARGET_FOLDER}/ ./tvm_rpc-${USER} server --host=0.0.0.0 --port=5000 --port-end=5010 --tracker=127.0.0.1:${TVM_TRACKER_PORT} --key=${RPC_DEVICE_KEY}" + adb forward tcp:${LISTEN_PORT} tcp:${LISTEN_PORT} + adb forward tcp:$((LISTEN_PORT + 1)) tcp:$((LISTEN_PORT + 1)) + adb forward tcp:$((LISTEN_PORT + 2)) tcp:$((LISTEN_PORT + 2)) + adb forward tcp:$((LISTEN_PORT + 3)) tcp:$((LISTEN_PORT + 3)) + adb shell "cd ${TARGET_FOLDER}; killall -9 tvm_rpc-${USER}; sleep 2; LD_LIBRARY_PATH=${TARGET_FOLDER}/ ./tvm_rpc-${USER} server --host=0.0.0.0 --port=${LISTEN_PORT} --port-end=$((LISTEN_PORT + 10)) --tracker=127.0.0.1:${TVM_TRACKER_PORT} --key=${RPC_DEVICE_KEY}" ;; "query") diff --git a/tests/scripts/task_config_build_adreno.sh b/tests/scripts/task_config_build_adreno.sh index 62e6ffecbc..1b6750f165 100755 --- a/tests/scripts/task_config_build_adreno.sh +++ b/tests/scripts/task_config_build_adreno.sh @@ -25,6 +25,8 @@ cp ../cmake/config.cmake . if [ -f "${ADRENO_OPENCL}/CL/cl_qcom_ml_ops.h" ] ; then echo set\(USE_CLML ${ADRENO_OPENCL}\) >> config.cmake +else +echo set\(USE_OPENCL ON\) >> config.cmake fi echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake