From 4e6b64fd207d19c43ea480d4a0a78ee0ab1615e6 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 8 Nov 2024 16:02:25 -0800 Subject: [PATCH 01/59] [shortfin] NFC change that makes compilation succeed on clang-14. (#461) --- shortfin/src/shortfin/array/storage.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/shortfin/src/shortfin/array/storage.h b/shortfin/src/shortfin/array/storage.h index b1d7eb6ad..2ea8f5aef 100644 --- a/shortfin/src/shortfin/array/storage.h +++ b/shortfin/src/shortfin/array/storage.h @@ -232,14 +232,14 @@ class typed_mapping { span_type span() { return span_type(data(), size()); } const_span_type span() const { return const_span_type(data(), size()); } - span_type::iterator begin() { return span().begin(); } - span_type::iterator end() { return span().end(); } + typename span_type::iterator begin() { return span().begin(); } + typename span_type::iterator end() { return span().end(); } - const_span_type::iterator begin() const { return span().begin(); } - const_span_type::iterator end() const { return span().end(); } + typename const_span_type::iterator begin() const { return span().begin(); } + typename const_span_type::iterator end() const { return span().end(); } - const_span_type::iterator cbegin() const { return span().begin(); } - const_span_type::iterator cend() const { return span().end(); } + typename const_span_type::iterator cbegin() const { return span().begin(); } + typename const_span_type::iterator cend() const { return span().end(); } private: mapping untyped_mapping_; From bdff9e344431be056f963eac6832c72ab23d610e Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 8 Nov 2024 18:02:59 -0600 Subject: [PATCH 02/59] [shortfin] Update dev_me.py for new machine. * Add checkes for setuptools version since it was found that very old versions are quite broken. * Add check for wheel and packaging packages. * Make the clang version check a warning. --- shortfin/dev_me.py | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/shortfin/dev_me.py b/shortfin/dev_me.py index ca6916767..9125c8713 100755 --- a/shortfin/dev_me.py +++ b/shortfin/dev_me.py @@ -31,8 +31,8 @@ # Otherwise, the shortfin build will download a pinned IREE source tree. import argparse +import importlib import os -from packaging.version import Version from pathlib import Path import re import subprocess @@ -40,10 +40,19 @@ import sys import sysconfig +try: + from packaging.version import Version +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"'packaging' package not installed and required: Install with:\n" + f" {sys.executable} -m pip install packaging" + ) + CMAKE_REQUIRED_VERSION = Version("3.29") PYTHON_REQUIRED_VERSION = Version("3.12") CLANG_REQUIRED_VERSION = Version("16") +SETUPTOOLS_REQUIRED_VERSION = Version("61.0") class EnvInfo: @@ -58,6 +67,8 @@ def __init__(self, args): self.ninja_exe = shutil.which("ninja") self.clang_exe, self.clang_version = self.find_clang(args) self.iree_dir = self.find_iree(args) + self.setuptools_version = self.find_package_version("setuptools") + self.wheel_version = self.find_package_version("wheel") self.configured_dirs = [] self.add_configured(self.this_dir / "build" / "cmake" / "default") @@ -116,6 +127,13 @@ def find_iree(self, args): sys.exit(1) return str(iree_dir) + def find_package_version(self, package_name: str) -> Version | None: + try: + m = importlib.import_module(package_name) + except ModuleNotFoundError: + return None + return Version(m.__version__) + def check_prereqs(self, args): if self.cmake_version is None or self.cmake_version < CMAKE_REQUIRED_VERSION: print( @@ -131,7 +149,7 @@ def check_prereqs(self, args): ) sys.exit(1) if self.clang_exe and self.clang_version < CLANG_REQUIRED_VERSION: - print(f"ERROR: clang version too old: {self.clang_exe}") + print(f"WARNING: clang version too old: {self.clang_exe}") print(f" REQUIRED: {CLANG_REQUIRED_VERSION}, Found {self.clang_version}") elif not self.clang_exe: print(f"WARNING: Building the project with clang is highly recommended") @@ -143,6 +161,19 @@ def check_prereqs(self, args): ) sys.exit(1) + if ( + self.setuptools_version is None + or self.setuptools_version < SETUPTOOLS_REQUIRED_VERSION + ): + print( + f"ERROR: 'setuptools' packaging is not installed or too old. " + f"Found {self.setuptools_version}, Need {SETUPTOOLS_REQUIRED_VERSION}" + ) + sys.exit(1) + if self.wheel_version is None: + print(f"'wheel' package is not installed") + sys.exit(1) + def __repr__(self): report = [ f"python: {self.python_exe}", @@ -153,6 +184,8 @@ def __repr__(self): f"ninja: {self.ninja_exe}", f"clang: {self.clang_exe} ({self.clang_version})", f"iree: {self.iree_dir}", + f"setuptools: {self.setuptools_version}", + f"wheel: {self.wheel_version}", ] return "\n".join(report) From 86ff97dabb499ce6615b32968758c2fe2ad65d78 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 8 Nov 2024 18:25:18 -0600 Subject: [PATCH 03/59] [shortfin] Fix build error keeping amdgpu Python bindings from being included. This fixes forward an issue introduced in #434. --- shortfin/src/CMakeLists.txt | 5 +++++ shortfin/src/shortfin/local/systems/CMakeLists.txt | 2 ++ 2 files changed, 7 insertions(+) diff --git a/shortfin/src/CMakeLists.txt b/shortfin/src/CMakeLists.txt index 9a955f742..53d801f36 100644 --- a/shortfin/src/CMakeLists.txt +++ b/shortfin/src/CMakeLists.txt @@ -4,6 +4,10 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Any definitions which must be reflected on the public library must be added +# to this library. +add_library(shortfin_public_defs INTERFACE) + add_subdirectory(shortfin) # Common definitions exported from both static and dynamic libraries. @@ -28,6 +32,7 @@ shortfin_public_library( shortfin_systems_factory ${_SHORTFIN_LIB_OPTIONAL_COMPONENTS} USAGE_DEPS + shortfin_public_defs spdlog::spdlog fmt::fmt xtensor diff --git a/shortfin/src/shortfin/local/systems/CMakeLists.txt b/shortfin/src/shortfin/local/systems/CMakeLists.txt index b2bcbef23..b1c9d8b44 100644 --- a/shortfin/src/shortfin/local/systems/CMakeLists.txt +++ b/shortfin/src/shortfin/local/systems/CMakeLists.txt @@ -29,6 +29,7 @@ shortfin_cc_component( iree_task_task ) list(APPEND _SYSTEM_COMPONENTS shortfin_systems_host) +target_compile_definitions(shortfin_public_defs INTERFACE SHORTFIN_HAVE_HOSTCPU) if(SHORTFIN_SYSTEMS_AMDGPU) shortfin_cc_component( @@ -47,6 +48,7 @@ if(SHORTFIN_SYSTEMS_AMDGPU) iree_hal_drivers_hip_hip ) list(APPEND _SYSTEM_COMPONENTS shortfin_systems_amdgpu) + target_compile_definitions(shortfin_public_defs INTERFACE SHORTFIN_HAVE_AMDGPU) endif() shortfin_cc_component( From 2cbf76866fd7371cd5adf27a1f59bfc5d62c9bff Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 8 Nov 2024 17:03:16 -0800 Subject: [PATCH 04/59] [shortfin] Fix some issues blocking operation on Python 3.10. (#462) With this unit tests appear to pass (I'm on a system with other issues but appear unrelated). --- shortfin/python/_shortfin/asyncio_bridge.py | 20 ++++++++++++++++++-- shortfin/python/shortfin_apps/sd/server.py | 2 +- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/shortfin/python/_shortfin/asyncio_bridge.py b/shortfin/python/_shortfin/asyncio_bridge.py index 4cb54449c..28264e9e3 100644 --- a/shortfin/python/_shortfin/asyncio_bridge.py +++ b/shortfin/python/_shortfin/asyncio_bridge.py @@ -5,10 +5,19 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import asyncio +import inspect from . import lib as sfl +# Feature detect some versions where signatures changes. +if "context" in inspect.signature(asyncio.Task).parameters: + # Python > 3.10 + _ASYNCIO_TASK_HAS_CONTEXT = True +else: + _ASYNCIO_TASK_HAS_CONTEXT = False + + class PyWorkerEventLoop(asyncio.AbstractEventLoop): def __init__(self, worker: sfl.local.Worker): self._worker = worker @@ -17,8 +26,15 @@ def get_debug(self): # Requirement of asyncio. return False - def create_task(self, coro, *, name=None, context=None): - return asyncio.Task(coro, loop=self, name=name, context=context) + if _ASYNCIO_TASK_HAS_CONTEXT: + + def create_task(self, coro, *, name=None, context=None): + return asyncio.Task(coro, loop=self, name=name, context=context) + + else: + + def create_task(self, coro, *, name=None): + return asyncio.Task(coro, loop=self, name=name) def create_future(self): return asyncio.Future(loop=self) diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py index 849337900..eba8b490b 100644 --- a/shortfin/python/shortfin_apps/sd/server.py +++ b/shortfin/python/shortfin_apps/sd/server.py @@ -151,7 +151,7 @@ def get_modules(args): f"--model={modelname}", f"--iree-hal-target-device={args.device}", f"--iree-hip-target={args.target}", - f"--iree-compile-extra-args={" ".join(ireec_args)}", + f"--iree-compile-extra-args={' '.join(ireec_args)}", ] print("BUILDER INPUT:\n", " \ \n ".join(builder_args)) output = subprocess.check_output(builder_args).decode() From eefc3531388cd74f4fa331f027bb756d48ddbdf1 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 8 Nov 2024 20:58:06 -0800 Subject: [PATCH 05/59] [shortfin] Add heuristics for adjusting file descriptor limits on Linux. (#465) Without this, on very large systems (i.e. 64 GPU / 192 Core), it was not possible to open all devices without manual tweaks to file handle descriptor limits. The result were various forms of RESOURCE_EXHAUSTED errors. This may require more tweaking in the future, and for fully robust setups, production installations should explicitly configure high limits. However, these heuristics remove a significant barrier to entry and provide some feedback in terms of logs. Progress on #463 --- shortfin/src/shortfin/local/systems/amdgpu.cc | 17 +++++ shortfin/src/shortfin/local/systems/host.cc | 18 ++++++ shortfin/src/shortfin/support/CMakeLists.txt | 2 + shortfin/src/shortfin/support/sysconfig.cc | 63 +++++++++++++++++++ shortfin/src/shortfin/support/sysconfig.h | 25 ++++++++ 5 files changed, 125 insertions(+) create mode 100644 shortfin/src/shortfin/support/sysconfig.cc create mode 100644 shortfin/src/shortfin/support/sysconfig.h diff --git a/shortfin/src/shortfin/local/systems/amdgpu.cc b/shortfin/src/shortfin/local/systems/amdgpu.cc index 2625e8325..78efad709 100644 --- a/shortfin/src/shortfin/local/systems/amdgpu.cc +++ b/shortfin/src/shortfin/local/systems/amdgpu.cc @@ -7,6 +7,7 @@ #include "shortfin/local/systems/amdgpu.h" #include "shortfin/support/logging.h" +#include "shortfin/support/sysconfig.h" namespace shortfin::local::systems { @@ -190,6 +191,22 @@ SystemPtr AMDGPUSystemBuilder::CreateSystem() { } } + // Estimate the resource requirements for the requested number of devices. + // As of 2024-11-08, the number of file handles required to open 64 device + // partitions was 31 times the number to open one device. Because it is not + // good to run near the limit, we conservatively round that up to 64 above + // an arbitrary baseline of 768. This means that on a small, four device + // system, we will not request to raise limits for the Linux default of + // 1024 file handles, but we will raise for everything larger (which tends + // to be where the problems are). + size_t expected_device_count = + used_device_ids.size() * logical_devices_per_physical_device_; + if (!sysconfig::EnsureFileLimit(expected_device_count * 64 + 768)) { + logging::error( + "Could not ensure sufficient file handles for minimum operations: " + "Suggest setting explicit limits with `ulimit -n` and system settings"); + } + // Initialize all used GPU devices. for (size_t instance_ordinal = 0; instance_ordinal < used_device_ids.size(); ++instance_ordinal) { diff --git a/shortfin/src/shortfin/local/systems/host.cc b/shortfin/src/shortfin/local/systems/host.cc index 5629979e4..440a3ff51 100644 --- a/shortfin/src/shortfin/local/systems/host.cc +++ b/shortfin/src/shortfin/local/systems/host.cc @@ -11,6 +11,7 @@ #include "iree/hal/local/loaders/registration/init.h" #include "shortfin/support/iree_helpers.h" #include "shortfin/support/logging.h" +#include "shortfin/support/sysconfig.h" namespace shortfin::local::systems { @@ -149,6 +150,8 @@ iree_hal_driver_t *HostCPUSystemBuilder::InitializeHostCPUDriver(System &lsys) { } // Create one queue executor per node. + unsigned total_needed_file_handles = 512; + bool has_issued_limit_error = false; std::vector queue_executors; queue_executors.reserve(selected_nodes.size()); queue_node_ids_.reserve(selected_nodes.size()); @@ -162,6 +165,21 @@ iree_hal_driver_t *HostCPUSystemBuilder::InitializeHostCPUDriver(System &lsys) { node_id, iree_task_topology_group_count(&topology.topology)); queue_executors.push_back({}); auto &executor = queue_executors.back(); + // As of 2024-11-8, it took approximately 32 file handles per node-group. + // To be conservative because file handle limits are basically free, we + // round up to 64 and assume a floor of 512. This allows small, default + // 8 group, single node configs to require no limit increase for Linux + // 1024 default cases. + total_needed_file_handles += 64 * topology.topology.group_count; + if (!sysconfig::EnsureFileLimit(total_needed_file_handles) && + !has_issued_limit_error) { + logging::error( + "Could not ensure sufficient file handles for minimum operations: " + "Suggest setting explicit limits with `ulimit -n` and system " + "settings"); + has_issued_limit_error = true; + } + SHORTFIN_THROW_IF_ERROR(iree_task_executor_create( host_cpu_deps_.task_executor_options, &topology.topology, host_allocator(), executor.for_output())); diff --git a/shortfin/src/shortfin/support/CMakeLists.txt b/shortfin/src/shortfin/support/CMakeLists.txt index cbf171894..ea8572466 100644 --- a/shortfin/src/shortfin/support/CMakeLists.txt +++ b/shortfin/src/shortfin/support/CMakeLists.txt @@ -16,12 +16,14 @@ shortfin_cc_component( iree_concurrency.h logging.h stl_extras.h + sysconfig.h SRCS blocking_executor.cc config.cc globals.cc iree_helpers.cc logging.cc + sysconfig.cc DEPS iree_base_base # TODO: Maybe reclassify some of these low level, shared support entities diff --git a/shortfin/src/shortfin/support/sysconfig.cc b/shortfin/src/shortfin/support/sysconfig.cc new file mode 100644 index 000000000..486f5ffc4 --- /dev/null +++ b/shortfin/src/shortfin/support/sysconfig.cc @@ -0,0 +1,63 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "shortfin/support/sysconfig.h" + +#include "shortfin/support/logging.h" + +#ifdef __linux__ +#include +#endif + +namespace shortfin::sysconfig { + +// ----------------------------------------------------------------------------- +// File handle limits +// ----------------------------------------------------------------------------- + +#ifdef __linux__ + +bool EnsureFileLimit(unsigned needed_limit) { + struct rlimit limit; + if (getrlimit(RLIMIT_NOFILE, &limit) != 0) { + return {}; + } + + if (limit.rlim_cur >= needed_limit) return true; + unsigned requested_limit = needed_limit; + if (limit.rlim_max >= needed_limit) { + logging::debug( + "Estimated number of open file handles ({}) < current limit ({}) but " + "within max limit ({}): Increasing limit", + needed_limit, limit.rlim_cur, limit.rlim_max); + } else if (limit.rlim_max > limit.rlim_cur) { + logging::warn( + "Esimated number of open file handles ({}) < current ({}) and max ({}) " + "limit: Increasing to max", + needed_limit, limit.rlim_cur, limit.rlim_max); + requested_limit = limit.rlim_max; + } else { + logging::warn("Esimated number of open file handles ({}) < max ({})", + needed_limit, limit.rlim_max); + return false; + } + + limit.rlim_cur = requested_limit; + if (setrlimit(RLIMIT_NOFILE, &limit) != 0) { + logging::error("Could not set open file handle limit to {}", + requested_limit); + return false; + } + + return limit.rlim_cur >= needed_limit; +} + +#else +// Fallback implementation. +bool EnsureFileLimit(unsigned needed_limit) { return true; } +#endif + +} // namespace shortfin::sysconfig diff --git a/shortfin/src/shortfin/support/sysconfig.h b/shortfin/src/shortfin/support/sysconfig.h new file mode 100644 index 000000000..864405efc --- /dev/null +++ b/shortfin/src/shortfin/support/sysconfig.h @@ -0,0 +1,25 @@ +// Copyright 2024 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef SHORTFIN_SUPPORT_SYSCONFIG_H +#define SHORTFIN_SUPPORT_SYSCONFIG_H + +#include +#include + +namespace shortfin::sysconfig { + +// Attempts to ensure that the given number of file descriptors can be created. +// If the system does not support such a thing (i.e. GetOpenFileLimit() returns +// nothing), then nothing is done and true is returned. If the system does +// support it and heuristics say this should be allowed, then true will return. +// Otherwise, a warning will be logged and false returned. +// This is a best effort attempt. +bool EnsureFileLimit(unsigned needed_limit); + +} // namespace shortfin::sysconfig + +#endif // SHORTFIN_SUPPORT_SYSCONFIG_H From 029d35e125368163e71006632e1e9fc36b0ec750 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Sat, 9 Nov 2024 11:10:20 -0600 Subject: [PATCH 06/59] (shortfin-sd) Program initialization and logging improvements (#444) Fixes program initialization per worker and systembuilder usage/options --- .../python/shortfin/support/logging_setup.py | 11 +- .../shortfin_apps/sd/components/builders.py | 6 +- .../shortfin_apps/sd/components/manager.py | 51 ++++--- .../shortfin_apps/sd/components/service.py | 126 ++++++++++-------- .../sd/examples/sdxl_request_bs32.json | 55 ++++++++ .../shortfin_apps/sd/examples/send_request.py | 22 ++- shortfin/python/shortfin_apps/sd/server.py | 29 ++-- 7 files changed, 196 insertions(+), 104 deletions(-) create mode 100644 shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json diff --git a/shortfin/python/shortfin/support/logging_setup.py b/shortfin/python/shortfin/support/logging_setup.py index 39cf3cf75..5585e6a82 100644 --- a/shortfin/python/shortfin/support/logging_setup.py +++ b/shortfin/python/shortfin/support/logging_setup.py @@ -38,7 +38,7 @@ def __init__(self): native_handler.setFormatter(NativeFormatter()) # TODO: Source from env vars. -logger.setLevel(logging.DEBUG) +logger.setLevel(logging.INFO) logger.addHandler(native_handler) @@ -47,7 +47,10 @@ def configure_main_logger(module_suffix: str = "__main__") -> logging.Logger: Returns a logger that can be used for the main module itself. """ - logging.root.addHandler(native_handler) - logging.root.setLevel(logging.DEBUG) # TODO: source from env vars main_module = sys.modules["__main__"] - return logging.getLogger(f"{main_module.__package__}.{module_suffix}") + logging.root.setLevel(logging.INFO) + logger = logging.getLogger(f"{main_module.__package__}.{module_suffix}") + logger.setLevel(logging.INFO) + logger.addHandler(native_handler) + + return logger diff --git a/shortfin/python/shortfin_apps/sd/components/builders.py b/shortfin/python/shortfin_apps/sd/components/builders.py index 1f9d0c2ee..81203e713 100644 --- a/shortfin/python/shortfin_apps/sd/components/builders.py +++ b/shortfin/python/shortfin_apps/sd/components/builders.py @@ -159,9 +159,9 @@ def needs_file(filename, ctx, namespace=FileNamespace.GEN): if os.path.exists(out_file): needed = False else: - name_path = "bin" if namespace == FileNamespace.BIN else "" - if name_path: - filename = os.path.join(name_path, filename) + # name_path = "bin" if namespace == FileNamespace.BIN else "" + # if name_path: + # filename = os.path.join(name_path, filename) filekey = os.path.join(ctx.path, filename) ctx.executor.all[filekey] = None needed = True diff --git a/shortfin/python/shortfin_apps/sd/components/manager.py b/shortfin/python/shortfin_apps/sd/components/manager.py index c52cf62f7..846c4ced6 100644 --- a/shortfin/python/shortfin_apps/sd/components/manager.py +++ b/shortfin/python/shortfin_apps/sd/components/manager.py @@ -12,29 +12,42 @@ logger = logging.getLogger(__name__) +def get_selected_devices(sb: sf.SystemBuilder, device_ids=None): + available = sb.available_devices + selected = [] + if device_ids is not None: + if len(device_ids) >= len(available): + raise ValueError( + f"Requested more device ids ({device_ids}) than available ({available})." + ) + for did in device_ids: + if isinstance(did, str): + try: + did = int(did) + except ValueError: + did = did + if did in available: + selected.append(did) + elif isinstance(did, int): + selected.append(available[did]) + else: + raise ValueError(f"Device id {did} could not be parsed.") + else: + selected = available + return selected + + class SystemManager: - def __init__(self, device="local-task", device_ids=None): + def __init__(self, device="local-task", device_ids=None, async_allocs=True): if any(x in device for x in ["local-task", "cpu"]): self.ls = sf.host.CPUSystemBuilder().create_system() elif any(x in device for x in ["hip", "amdgpu"]): - sc_query = sf.amdgpu.SystemBuilder() - available = sc_query.available_devices - selected = [] - if device_ids is not None: - if len(device_ids) >= len(available): - raise ValueError( - f"Requested more device ids ({device_ids}) than available ({available})." - ) - for did in device_ids: - if did in available: - selected.append(did) - elif isinstance(did, int): - selected.append(available[did]) - else: - raise ValueError(f"Device id {did} could not be parsed.") - else: - selected = available - sb = sf.amdgpu.SystemBuilder(amdgpu_visible_devices=";".join(selected)) + sb = sf.SystemBuilder( + system_type="amdgpu", amdgpu_async_allocations=async_allocs + ) + if device_ids: + sb.visible_devices = sb.available_devices + sb.visible_devices = get_selected_devices(sb, device_ids) self.ls = sb.create_system() logger.info(f"Created local system with {self.ls.device_names} devices") # TODO: Come up with an easier bootstrap thing than manually diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index af8423a11..a64013db0 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -62,12 +62,20 @@ def __init__( self.inference_parameters: dict[str, list[sf.BaseProgramParameters]] = {} self.inference_modules: dict[str, sf.ProgramModule] = {} self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {} - self.inference_programs: dict[str, sf.Program] = {} + self.inference_programs: dict[int, dict[str, sf.Program]] = {} self.trace_execution = trace_execution self.show_progress = show_progress + + self.prog_isolation = prog_isolations[prog_isolation] + self.workers_per_device = workers_per_device self.fibers_per_device = fibers_per_device - self.prog_isolation = prog_isolations[prog_isolation] + if fibers_per_device % workers_per_device != 0: + raise ValueError( + "Currently, fibers_per_device must be divisible by workers_per_device" + ) + self.fibers_per_worker = int(fibers_per_device / workers_per_device) + self.workers = [] self.fibers = [] self.fiber_status = [] @@ -81,7 +89,9 @@ def __init__( ) self.fibers.append(fiber) self.fiber_status.append(0) - + for idx in range(len(self.workers)): + self.inference_programs[idx] = {} + self.inference_functions[idx] = {} # Scope dependent objects. self.batcher = BatcherProcess(self) @@ -108,52 +118,59 @@ def load_inference_parameters( self.inference_parameters[component].append(p) def start(self): - for fiber in self.fibers: - for component in self.inference_modules: - component_modules = [ - sf.ProgramModule.parameter_provider( - self.sysman.ls, *self.inference_parameters.get(component, []) - ), - *self.inference_modules[component], - ] - self.inference_programs[component] = sf.Program( + # Initialize programs. + # This can work if we only initialize one set of programs per service, as our programs + # in SDXL are stateless and + for component in self.inference_modules: + component_modules = [ + sf.ProgramModule.parameter_provider( + self.sysman.ls, *self.inference_parameters.get(component, []) + ), + *self.inference_modules[component], + ] + for worker_idx, worker in enumerate(self.workers): + worker_devices = self.fibers[ + worker_idx * (self.fibers_per_worker) + ].raw_devices + + self.inference_programs[worker_idx][component] = sf.Program( modules=component_modules, - devices=fiber.raw_devices, + devices=worker_devices, isolation=self.prog_isolation, trace_execution=self.trace_execution, ) - - # TODO: export vmfbs with multiple batch size entrypoints - - self.inference_functions["encode"] = {} - for bs in self.model_params.clip_batch_sizes: - self.inference_functions["encode"][bs] = self.inference_programs["clip"][ - f"{self.model_params.clip_module_name}.encode_prompts" - ] - - self.inference_functions["denoise"] = {} - for bs in self.model_params.unet_batch_sizes: - self.inference_functions["denoise"][bs] = { - "unet": self.inference_programs["unet"][ - f"{self.model_params.unet_module_name}.{self.model_params.unet_fn_name}" - ], - "init": self.inference_programs["scheduler"][ - f"{self.model_params.scheduler_module_name}.run_initialize" - ], - "scale": self.inference_programs["scheduler"][ - f"{self.model_params.scheduler_module_name}.run_scale" - ], - "step": self.inference_programs["scheduler"][ - f"{self.model_params.scheduler_module_name}.run_step" - ], - } - - self.inference_functions["decode"] = {} - for bs in self.model_params.vae_batch_sizes: - self.inference_functions["decode"][bs] = self.inference_programs["vae"][ - f"{self.model_params.vae_module_name}.decode" - ] - + for worker_idx, worker in enumerate(self.workers): + self.inference_functions[worker_idx]["encode"] = {} + for bs in self.model_params.clip_batch_sizes: + self.inference_functions[worker_idx]["encode"][ + bs + ] = self.inference_programs[worker_idx]["clip"][ + f"{self.model_params.clip_module_name}.encode_prompts" + ] + self.inference_functions[worker_idx]["denoise"] = {} + for bs in self.model_params.unet_batch_sizes: + self.inference_functions[worker_idx]["denoise"][bs] = { + "unet": self.inference_programs[worker_idx]["unet"][ + f"{self.model_params.unet_module_name}.{self.model_params.unet_fn_name}" + ], + "init": self.inference_programs[worker_idx]["scheduler"][ + f"{self.model_params.scheduler_module_name}.run_initialize" + ], + "scale": self.inference_programs[worker_idx]["scheduler"][ + f"{self.model_params.scheduler_module_name}.run_scale" + ], + "step": self.inference_programs[worker_idx]["scheduler"][ + f"{self.model_params.scheduler_module_name}.run_step" + ], + } + self.inference_functions[worker_idx]["decode"] = {} + for bs in self.model_params.vae_batch_sizes: + self.inference_functions[worker_idx]["decode"][ + bs + ] = self.inference_programs[worker_idx]["vae"][ + f"{self.model_params.vae_module_name}.decode" + ] + # breakpoint() self.batcher.launch() def shutdown(self): @@ -320,7 +337,11 @@ def __init__( ): super().__init__(fiber=service.fibers[index]) self.service = service - self.worker_index = index + self.fiber_index = index + self.worker_index = int( + (index - index % self.service.fibers_per_worker) + / self.service.fibers_per_worker + ) self.exec_requests: list[InferenceExecRequest] = [] @measure(type="exec", task="inference process") @@ -335,7 +356,7 @@ async def run(self): phases = self.exec_requests[0].phases req_count = len(self.exec_requests) - device0 = self.service.fibers[self.worker_index].device(0) + device0 = self.service.fibers[self.fiber_index].device(0) if phases[InferencePhase.PREPARE]["required"]: await self._prepare(device=device0, requests=self.exec_requests) if phases[InferencePhase.ENCODE]["required"]: @@ -346,11 +367,11 @@ async def run(self): await self._decode(device=device0, requests=self.exec_requests) if phases[InferencePhase.POSTPROCESS]["required"]: await self._postprocess(device=device0, requests=self.exec_requests) - + await device0 for i in range(req_count): req = self.exec_requests[i] req.done.set_success() - self.service.fiber_status[self.worker_index] = 0 + self.service.fiber_status[self.fiber_index] = 0 except Exception: logger.exception("Fatal error in image generation") @@ -400,8 +421,7 @@ async def _prepare(self, device, requests): async def _encode(self, device, requests): req_bs = len(requests) - - entrypoints = self.service.inference_functions["encode"] + entrypoints = self.service.inference_functions[self.worker_index]["encode"] for bs, fn in entrypoints.items(): if bs >= req_bs: break @@ -454,7 +474,7 @@ async def _denoise(self, device, requests): step_count = requests[0].steps cfg_mult = 2 if self.service.model_params.cfg_mode else 1 # Produce denoised latents - entrypoints = self.service.inference_functions["denoise"] + entrypoints = self.service.inference_functions[self.worker_index]["denoise"] for bs, fns in entrypoints.items(): if bs >= req_bs: break @@ -590,7 +610,7 @@ async def _denoise(self, device, requests): async def _decode(self, device, requests): req_bs = len(requests) # Decode latents to images - entrypoints = self.service.inference_functions["decode"] + entrypoints = self.service.inference_functions[self.worker_index]["decode"] for bs, fn in entrypoints.items(): if bs >= req_bs: break diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json new file mode 100644 index 000000000..192a2be61 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json @@ -0,0 +1,55 @@ +{ + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, amateur photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, wide shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, amateur photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, wide shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, amateur photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, wide shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo" + ], + "neg_prompt": [ + "Watermark, blurry, oversaturated, low resolution, pollution" + ], + "height": [ + 1024 + ], + "width": [ + 1024 + ], + "steps": [ + 20 + ], + "guidance_scale": [ + 7.5 + ], + "seed": [ + 0 + ], + "output_type": [ + "base64" + ] +} diff --git a/shortfin/python/shortfin_apps/sd/examples/send_request.py b/shortfin/python/shortfin_apps/sd/examples/send_request.py index 94fae9659..dd2226e70 100644 --- a/shortfin/python/shortfin_apps/sd/examples/send_request.py +++ b/shortfin/python/shortfin_apps/sd/examples/send_request.py @@ -30,29 +30,35 @@ def bytes_to_img(bytes, idx=0, width=1024, height=1024): print(f"Saved to shortfin_sd_output_{timestamp}_{idx}.png") -def send_json_file(file_path): +def send_json_file(args): # Read the JSON file try: - if file_path == "default": + if args.file == "default": data = sample_request else: - with open(file_path, "r") as json_file: + with open(args.file, "r") as json_file: data = json.load(json_file) except Exception as e: print(f"Error reading the JSON file: {e}") return - + data["prompt"] = ( + [data["prompt"]] + if isinstance(data["prompt"], str) + else data["prompt"] * args.reps + ) # Send the data to the /generate endpoint try: response = requests.post("http://0.0.0.0:8000/generate", json=data) response.raise_for_status() # Raise an error for bad responses - print("Saving response as image...") timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") request = json.loads(response.request.body.decode("utf-8")) for idx, item in enumerate(response.json()["images"]): width = get_batched(request, "width", idx) height = get_batched(request, "height", idx) - bytes_to_img(item.encode("utf-8"), idx, width, height) + if args.save: + print("Saving response as image...") + bytes_to_img(item.encode("utf-8"), idx, width, height) + print("Responses processed.") except requests.exceptions.RequestException as e: print(f"Error sending the request: {e}") @@ -72,5 +78,7 @@ def get_batched(request, arg, idx): if __name__ == "__main__": p = argparse.ArgumentParser() p.add_argument("--file", type=str, default="default") + p.add_argument("--reps", type=int, default=1) + p.add_argument("--save", type=argparse.BooleanOptionalAction, help="save images") args = p.parse_args() - send_json_file(args.file) + send_json_file(args) diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py index eba8b490b..7ace4d407 100644 --- a/shortfin/python/shortfin_apps/sd/server.py +++ b/shortfin/python/shortfin_apps/sd/server.py @@ -17,8 +17,6 @@ from iree.build import * -import uvicorn.logging - # Import first as it does dep checking and reporting. from shortfin.interop.fastapi import FastAPIResponder @@ -27,7 +25,6 @@ from fastapi import FastAPI, Request, Response import uvicorn - from .components.generate import ClientGenerateBatchProcess from .components.config_struct import ModelParams from .components.io_struct import GenerateReqInput @@ -36,7 +33,6 @@ from .components.tokenizer import Tokenizer from .components.builders import sdxl - from shortfin.support.logging_setup import configure_main_logger logger = configure_main_logger("server") @@ -88,7 +84,7 @@ async def generate_request(gen_req: GenerateReqInput, request: Request): def configure(args) -> SystemManager: # Setup system (configure devices, etc). - sysman = SystemManager(args.device, args.device_ids) + sysman = SystemManager(args.device, args.device_ids, args.amdgpu_async_allocations) # Setup each service we are hosting. tokenizers = [] @@ -119,6 +115,7 @@ def configure(args) -> SystemManager: def get_modules(args): + # TODO: Move this out of server entrypoint vmfbs = {"clip": [], "unet": [], "vae": [], "scheduler": []} params = {"clip": [], "unet": [], "vae": []} model_flags = copy.deepcopy(vmfbs) @@ -153,9 +150,7 @@ def get_modules(args): f"--iree-hip-target={args.target}", f"--iree-compile-extra-args={' '.join(ireec_args)}", ] - print("BUILDER INPUT:\n", " \ \n ".join(builder_args)) output = subprocess.check_output(builder_args).decode() - print("OUTPUT:", output) output_paths = output.splitlines() filenames.extend(output_paths) @@ -199,7 +194,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): ) parser.add_argument( "--device_ids", - type=int, + type=str, nargs="*", default=None, help="Device IDs visible to the system builder. Defaults to None (full visibility). Can be an index or a sf device id like amdgpu:0:0@0", @@ -239,9 +234,6 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): choices=["per_fiber", "per_call", "none"], help="Concurrency control -- How to isolate programs.", ) - parser.add_argument( - "--log_level", type=str, default="error", choices=["info", "debug", "error"] - ) parser.add_argument( "--show_progress", action="store_true", @@ -252,6 +244,11 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): action="store_true", help="Enable tracing of program modules.", ) + parser.add_argument( + "--amdgpu_async_allocations", + action="store_true", + help="Enable asynchronous allocations for amdgpu device contexts.", + ) parser.add_argument( "--splat", action="store_true", @@ -282,16 +279,12 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): default="", help="Path to local artifacts cache.", ) - log_levels = { - "info": logging.INFO, - "debug": logging.DEBUG, - "error": logging.ERROR, - } args = parser.parse_args(argv) - log_level = log_levels[args.log_level] - logger.setLevel(log_level) + log_level = logging.INFO + + logging.root.setLevel(log_level) logger.addHandler(logging.FileHandler("shortfin_sd.log")) global sysman sysman = configure(args) From d619cb5e488018b2c1e6151cea70da11451903ff Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 11 Nov 2024 08:04:26 +0100 Subject: [PATCH 07/59] [shortfin] Include patch level in so (#455) --- shortfin/build_tools/cmake/shortfin_library.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shortfin/build_tools/cmake/shortfin_library.cmake b/shortfin/build_tools/cmake/shortfin_library.cmake index 23755fb9d..aaa97a6c1 100644 --- a/shortfin/build_tools/cmake/shortfin_library.cmake +++ b/shortfin/build_tools/cmake/shortfin_library.cmake @@ -80,7 +80,7 @@ function(shortfin_public_library) PRIVATE ${_DYLIB_COMPONENTS} ) set_target_properties("${_RULE_NAME}" PROPERTIES - VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR} + VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}.${PROJECT_VERSION_PATCH} SOVERSION ${SOVERSION} ) endif() From 6c3a5b2692e1a11f971387939be5e2f6766cd7f6 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 11 Nov 2024 08:40:07 +0100 Subject: [PATCH 08/59] Bump pyenv and Python version (#470) --- .github/workflows/ci_linux_x64_asan-libshortfin.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci_linux_x64_asan-libshortfin.yml b/.github/workflows/ci_linux_x64_asan-libshortfin.yml index d9eee2576..b61536218 100644 --- a/.github/workflows/ci_linux_x64_asan-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_asan-libshortfin.yml @@ -32,8 +32,8 @@ concurrency: env: PYENV_ROOT: ${{ github.workspace }}/pyenv - PYENV_REF: 9ecd803bffaffb949fbdd8c70cb086227f6a3202 # v2.4.10 - PYTHON_VER: 3.12.3 + PYENV_REF: 96b3fb2fc3bee85650cb22e2cb06c83c24509a6d # v2.4.17 + PYTHON_VER: 3.12.7 CACHE_ASAN_VER: 2 CACHE_DEPS_VER: 1 IREE_SOURCE_DIR: ${{ github.workspace }}/iree From a50741d78a2fad2a718155836c75c5597534d586 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 11 Nov 2024 09:12:17 +0100 Subject: [PATCH 09/59] [shortfin] Build for and test with Python 3.10 (#469) --- .github/workflows/build_packages.yml | 4 ++++ .github/workflows/ci_linux_x64-libshortfin.yml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build_packages.yml b/.github/workflows/build_packages.yml index 0ef36bc09..16b42b5c7 100644 --- a/.github/workflows/build_packages.yml +++ b/.github/workflows/build_packages.yml @@ -60,6 +60,10 @@ jobs: platform: linux-x86_64 package: sharktank python-version: cp311-cp311 # Ignored (generic wheel), set for workflow naming + - runs-on: ubuntu-24.04 + platform: linux-x86_64 + package: shortfin + python-version: cp310-cp310 - runs-on: ubuntu-24.04 platform: linux-x86_64 package: shortfin diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml index ad154748f..6e9925010 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci_linux_x64-libshortfin.yml @@ -40,7 +40,7 @@ jobs: runs-on: ubuntu-24.04 strategy: matrix: - python-version: ["3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] steps: - name: Install dependencies From 91200e599cb0af351a9bd9e522ae61741c6687c9 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 11 Nov 2024 18:46:30 +0100 Subject: [PATCH 10/59] [shortfin] Add missing build time dependency (#474) Adds `typing-extensions` as a build time dependency as it required by nanobind, see https://github.com/nod-ai/SHARK-Platform/actions/runs/11777270868/job/32801349583#step:5:744 --- shortfin/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/shortfin/pyproject.toml b/shortfin/pyproject.toml index eb54c835b..47cde6775 100644 --- a/shortfin/pyproject.toml +++ b/shortfin/pyproject.toml @@ -4,6 +4,7 @@ requires = [ "setuptools>=61.0", "wheel", "ninja", + 'typing-extensions ; python_version == "3.10" ', ] build-backend = "setuptools.build_meta" From 8d54823c5a2af376f3d6b8b5b949bc144a527839 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 11 Nov 2024 19:00:42 +0100 Subject: [PATCH 11/59] Pin (and update) actions (#471) Updates actions and pin as suggested by OpenSSF Scorecard, see https://github.com/ossf/scorecard/blob/main/docs/checks.md#pinned-dependencies. --- .github/workflows/build_packages.yml | 2 +- .github/workflows/ci-llama-large-tests.yaml | 10 +++++----- .github/workflows/ci-llama-quick-tests.yaml | 6 +++--- .github/workflows/ci-sdxl.yaml | 2 +- .github/workflows/ci-shark-platform.yml | 6 +++--- .github/workflows/ci-sharktank.yml | 6 +++--- .github/workflows/ci-tuner.yml | 2 +- .github/workflows/ci_eval.yaml | 12 ++++++------ .github/workflows/ci_linux_x64-libshortfin.yml | 2 +- .github/workflows/ci_windows_x64-libshortfin.yml | 2 +- .github/workflows/pre-commit.yaml | 6 +++--- 11 files changed, 28 insertions(+), 28 deletions(-) diff --git a/.github/workflows/build_packages.yml b/.github/workflows/build_packages.yml index 16b42b5c7..4a332b6f5 100644 --- a/.github/workflows/build_packages.yml +++ b/.github/workflows/build_packages.yml @@ -26,7 +26,7 @@ jobs: with: submodules: false - name: Setup Python - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.3 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: 3.12 cache: "pip" diff --git a/.github/workflows/ci-llama-large-tests.yaml b/.github/workflows/ci-llama-large-tests.yaml index 394cba93a..d79031b8c 100644 --- a/.github/workflows/ci-llama-large-tests.yaml +++ b/.github/workflows/ci-llama-large-tests.yaml @@ -41,15 +41,15 @@ jobs: - name: "Setting up Python" id: setup_python - uses: actions/setup-python@v3 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - name: "Checkout Code" - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Cache Pip Packages - uses: actions/cache@v4 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 id: cache-pip with: path: ${{ env.PIP_CACHE_DIR }} @@ -78,13 +78,13 @@ jobs: run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-all-llama --iree-hip-target=gfx942 --html=out/index.html - name: Deploy to GitHub Pages - uses: peaceiris/actions-gh-pages@v3 + uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 with: github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }} publish_dir: ./out - name: Upload llama executable files - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: name: llama-files path: ${{ github.workspace }}/${{ steps.date.outputs.date }} diff --git a/.github/workflows/ci-llama-quick-tests.yaml b/.github/workflows/ci-llama-quick-tests.yaml index ce55f81f8..decd0aa96 100644 --- a/.github/workflows/ci-llama-quick-tests.yaml +++ b/.github/workflows/ci-llama-quick-tests.yaml @@ -42,15 +42,15 @@ jobs: - name: "Setting up Python" id: setup_python - uses: actions/setup-python@v3 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - name: "Checkout Code" - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Cache Pip Packages - uses: actions/cache@v4 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 id: cache-pip with: path: ${{ env.PIP_CACHE_DIR }} diff --git a/.github/workflows/ci-sdxl.yaml b/.github/workflows/ci-sdxl.yaml index 373bc9319..9c5776c4c 100644 --- a/.github/workflows/ci-sdxl.yaml +++ b/.github/workflows/ci-sdxl.yaml @@ -76,7 +76,7 @@ jobs: git submodule update --init --depth 1 -- third_party/hip-build-deps/ - name: Setup Python - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.3 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: "3.12" cache: "pip" diff --git a/.github/workflows/ci-shark-platform.yml b/.github/workflows/ci-shark-platform.yml index 708fed66f..445e2e448 100644 --- a/.github/workflows/ci-shark-platform.yml +++ b/.github/workflows/ci-shark-platform.yml @@ -37,15 +37,15 @@ jobs: steps: - name: "Setting up Python" id: setup_python - uses: actions/setup-python@v3 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - name: "Checkout Code" - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Cache Pip Packages - uses: actions/cache@v4 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 id: cache-pip with: path: ${{ env.PIP_CACHE_DIR }} diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index eadc33501..6f359077a 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -38,15 +38,15 @@ jobs: steps: - name: "Setting up Python" id: setup_python - uses: actions/setup-python@v3 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - name: "Checkout Code" - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Cache Pip Packages - uses: actions/cache@v4 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 id: cache-pip with: path: ${{ env.PIP_CACHE_DIR }} diff --git a/.github/workflows/ci-tuner.yml b/.github/workflows/ci-tuner.yml index f1fcebfdc..cd9a48d5e 100644 --- a/.github/workflows/ci-tuner.yml +++ b/.github/workflows/ci-tuner.yml @@ -35,7 +35,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Set up Python - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.3 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: '3.10.12' diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index 4c98bf79b..7afaeb1fe 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -39,15 +39,15 @@ jobs: steps: - name: "Setting up Python" id: setup_python - uses: actions/setup-python@v3 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - name: "Checkout Code" - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Cache Pip Packages - uses: actions/cache@v4 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 id: cache-pip with: path: ${{ env.PIP_CACHE_DIR }} @@ -95,15 +95,15 @@ jobs: steps: - name: "Setting up Python" id: setup_python - uses: actions/setup-python@v3 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{matrix.version}} - name: "Checkout Code" - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Cache Pip Packages - uses: actions/cache@v4 + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 id: cache-pip with: path: ${{ env.PIP_CACHE_DIR }} diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml index 6e9925010..c1b039da3 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci_linux_x64-libshortfin.yml @@ -71,7 +71,7 @@ jobs: git submodule update --init --depth 1 -- third_party/hip-build-deps/ - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.3 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: ${{ matrix.python-version }} cache: "pip" diff --git a/.github/workflows/ci_windows_x64-libshortfin.yml b/.github/workflows/ci_windows_x64-libshortfin.yml index 4bbef8f12..929244af4 100644 --- a/.github/workflows/ci_windows_x64-libshortfin.yml +++ b/.github/workflows/ci_windows_x64-libshortfin.yml @@ -66,7 +66,7 @@ jobs: git submodule update --init --depth 1 -- third_party/hip-build-deps/ - name: Setup Python - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.3 + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 with: python-version: "3.12" cache: "pip" diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml index 2b11178bf..8ec1e8d55 100644 --- a/.github/workflows/pre-commit.yaml +++ b/.github/workflows/pre-commit.yaml @@ -9,6 +9,6 @@ jobs: pre-commit: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v3 - - uses: pre-commit/action@v3.0.1 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 From 7ae6064ca38119e4dbc8e9ee84b9828751f15c03 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Mon, 11 Nov 2024 12:09:59 -0800 Subject: [PATCH 12/59] Add script to drop `rcYYYYMMDD` suffix from packages for PyPI deploy. (#477) Progress on https://github.com/nod-ai/SHARK-Platform/issues/400 This lets us take a release from https://github.com/nod-ai/SHARK-Platform/releases/tag/dev-wheels and "promote it" for publishing to https://pypi.org/ by turning the version from `X.Y.ZrcYYYYMMDD` to simply `X.Y.Z`. The script is forked from the one in IREE I added last week: https://github.com/iree-org/iree/pull/19067. More steps like https://iree.dev/developers/general/release-management/#promoting-a-candidate-to-stable and scripting like https://github.com/iree-org/iree/blob/main/build_tools/python_deploy/pypi_deploy.sh will follow. --- build_tools/promote_whl_from_rc_to_final.py | 69 +++++++++++++++++++++ build_tools/requirements-pypi-deploy.txt | 4 ++ 2 files changed, 73 insertions(+) create mode 100755 build_tools/promote_whl_from_rc_to_final.py create mode 100644 build_tools/requirements-pypi-deploy.txt diff --git a/build_tools/promote_whl_from_rc_to_final.py b/build_tools/promote_whl_from_rc_to_final.py new file mode 100755 index 000000000..061dd933b --- /dev/null +++ b/build_tools/promote_whl_from_rc_to_final.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This scripts takes a file like 'sharktank-2.9.0rc20241110-py3-none-any.whl' +# with embedded version '2.9.0rc20241110' as input and then drops the +# 'rcYYYYMMDD' suffix from both the embedded version and file name. +# +# Typical usage: +# pip install -r requirements-pypi-deploy.txt +# ./promote_whl_from_rc_to_final.py /path/to/file.whl --delete-old-wheel + +import argparse +from change_wheel_version import change_wheel_version +from packaging.version import Version +from pathlib import Path +from pkginfo import Wheel + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "input_file", + help="Path to the input .whl file to promote", + type=Path, + ) + parser.add_argument( + "--delete-old-wheel", + help="Deletes the original wheel after successfully promoting it", + action="store_true", + default=False, + ) + return parser.parse_args() + + +def main(args): + original_wheel_path = args.input_file + print(f"Promoting whl from rc to final: '{original_wheel_path}'") + + original_wheel = Wheel(original_wheel_path) + original_version = Version(original_wheel.version) + base_version = original_version.base_version + print( + f" Original wheel version is '{original_version}' with base '{base_version}'" + ) + + if str(base_version) == str(original_version): + print(" Version is already a release version, skipping") + return + + print(f" Changing to base version: '{base_version}'") + new_wheel_path = change_wheel_version(original_wheel_path, str(base_version), None) + print(f" New wheel path is '{new_wheel_path}'") + + new_wheel = Wheel(new_wheel_path) + new_version = Version(new_wheel.version) + print(f" New wheel version is '{new_version}'") + + if args.delete_old_wheel: + print(" Deleting original wheel") + original_wheel_path.unlink() + + +if __name__ == "__main__": + main(parse_arguments()) diff --git a/build_tools/requirements-pypi-deploy.txt b/build_tools/requirements-pypi-deploy.txt new file mode 100644 index 000000000..dcc32d47a --- /dev/null +++ b/build_tools/requirements-pypi-deploy.txt @@ -0,0 +1,4 @@ +change_wheel_version +packaging +pkginfo +twine From c2a1488141ef781dbc0684f2d708b29050434859 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 11 Nov 2024 23:25:00 +0100 Subject: [PATCH 13/59] Add a shark-ai meta package (#475) Co-authored-by: Scott Todd --- .../python_deploy/compute_common_version.py | 75 ++++++++++++++++ .../python_deploy/write_requirements.py | 89 +++++++++++++++++++ shark-ai/.gitignore | 3 + shark-ai/README.md | 3 + shark-ai/pyproject.toml | 38 ++++++++ shark-ai/setup.py | 33 +++++++ 6 files changed, 241 insertions(+) create mode 100644 build_tools/python_deploy/compute_common_version.py create mode 100644 build_tools/python_deploy/write_requirements.py create mode 100644 shark-ai/.gitignore create mode 100644 shark-ai/README.md create mode 100644 shark-ai/pyproject.toml create mode 100644 shark-ai/setup.py diff --git a/build_tools/python_deploy/compute_common_version.py b/build_tools/python_deploy/compute_common_version.py new file mode 100644 index 000000000..accdc3d28 --- /dev/null +++ b/build_tools/python_deploy/compute_common_version.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This scripts grabs the `X.Y.Z[.dev]` version identifier from the +# sharktank and shortfin version files and computes the version +# for the meta package. + +import argparse +from pathlib import Path +import json +from datetime import datetime +import sys + +from packaging.version import Version + + +parser = argparse.ArgumentParser() +parser.add_argument("--write-json", action="store_true") + +release_type = parser.add_mutually_exclusive_group() +release_type.add_argument("-stable", "--stable-release", action="store_true") # default +release_type.add_argument("-rc", "--nightly-release", action="store_true") + + +args = parser.parse_args() + +if not (args.stable_release or args.nightly_release): + parser.print_usage(sys.stderr) + sys.stderr.write("error: A release type is required\n") + sys.exit(1) + +THIS_DIR = Path(__file__).parent.resolve() +REPO_ROOT = THIS_DIR.parent.parent + +VERSION_FILE_SHARKTANK = REPO_ROOT / "sharktank/version_info.json" +VERSION_FILE_SHORTFIN = REPO_ROOT / "shortfin/version_info.json" +VERSION_FILE_LOCAL = REPO_ROOT / "packaging/shark-ai/version_local.json" + + +def load_version_info(version_file): + with open(version_file, "rt") as f: + return json.load(f) + + +def write_version_info(): + with open(VERSION_FILE_LOCAL, "w") as f: + json.dump(version_local, f, indent=2) + f.write("\n") + + +sharktank_version = load_version_info(VERSION_FILE_SHARKTANK) +SHARKTANK_PACKAGE_VERSION = sharktank_version.get("package-version") +SHARKTANK_BASE_VERSION = Version(SHARKTANK_PACKAGE_VERSION).base_version + +shortfin_version = load_version_info(VERSION_FILE_SHORTFIN) +SHORTFIN_PACKAGE_VERSION = shortfin_version.get("package-version") +SHORTFIN_BASE_VERSION = Version(SHORTFIN_PACKAGE_VERSION).base_version + +if SHARKTANK_BASE_VERSION > SHORTFIN_BASE_VERSION: + COMMON_VERSION = SHARKTANK_BASE_VERSION +else: + COMMON_VERSION = SHORTFIN_BASE_VERSION + +if args.nightly_release: + COMMON_VERSION += "rc" + datetime.today().strftime("%Y%m%d") + +if args.write_json: + version_local = {"package-version": COMMON_VERSION} + write_version_info() + +print(COMMON_VERSION) diff --git a/build_tools/python_deploy/write_requirements.py b/build_tools/python_deploy/write_requirements.py new file mode 100644 index 000000000..346d46ae1 --- /dev/null +++ b/build_tools/python_deploy/write_requirements.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This script writes the `packaging/shark-ai/requirements.txt` file and pins +# the versions of the dependencies accordingly. For nighly releases, +# * sharktank +# * shortfin +# get pinned to the corresponding nighly version. For stable releases, +# * iree-base-compiler +# * iree-base-runtime +# * iree-turbine +# * sharktank +# * shortfin +# get pinned to the corresponding `X.Y.*` version. + +import argparse +from pathlib import Path +import json + +from packaging.version import Version + + +parser = argparse.ArgumentParser() +parser.add_argument("--version-suffix", action="store", type=str) + +args = parser.parse_args() + + +THIS_DIR = Path(__file__).parent.resolve() +REPO_ROOT = THIS_DIR.parent.parent + +VERSION_FILE_SHARKTANK = REPO_ROOT / "sharktank/version_info.json" +VERSION_FILE_SHORTFIN = REPO_ROOT / "shortfin/version_info.json" +VERSION_FILE_LOCAL = REPO_ROOT / "packaging/shark-ai/version_local.json" +REQUIREMENTS_TXT = REPO_ROOT / "packaging/shark-ai/requirements.txt" + + +def load_version_info(version_file): + with open(version_file, "rt") as f: + return json.load(f) + + +def write_requirements(package_list, package_version): + with open(REQUIREMENTS_TXT, "w") as f: + for package in package_list: + PINNED_PACKAGE = package + "==" + package_version + f.write("%s\n" % PINNED_PACKAGE) + + +def append_requirements(package_list, package_version): + with open(REQUIREMENTS_TXT, "a") as f: + for package in package_list: + PINNED_PACKAGE = package + "==" + package_version + f.write("%s\n" % PINNED_PACKAGE) + + +metapackage_version = load_version_info(VERSION_FILE_LOCAL) +PACKAGE_VERSION = metapackage_version.get("package-version") + +sharktank_version = load_version_info(VERSION_FILE_SHARKTANK) +SHARKTANK_PACKAGE_VERSION = sharktank_version.get("package-version") + +shortfin_version = load_version_info(VERSION_FILE_SHORTFIN) +SHORTFIN_PACKAGE_VERSION = shortfin_version.get("package-version") + +stable_packages_list = ["iree-base-compiler", "iree-base-runtime", "iree-turbine"] + +if Version(PACKAGE_VERSION).is_prerelease: + write_requirements( + ["sharktank"], + Version(SHARKTANK_PACKAGE_VERSION).base_version + "rc" + args.version_suffix, + ) + append_requirements( + ["shortfin"], + Version(SHORTFIN_PACKAGE_VERSION).base_version + "rc" + args.version_suffix, + ) +else: + MAJOR_VERSION = Version(PACKAGE_VERSION).major + MINOR_VERSION = Version(PACKAGE_VERSION).minor + + write_requirements( + stable_packages_list, str(MAJOR_VERSION) + "." + str(MINOR_VERSION) + ".*" + ) + append_requirements(["sharktank"], Version(SHARKTANK_PACKAGE_VERSION).base_version) + append_requirements(["shortfin"], Version(SHORTFIN_PACKAGE_VERSION).base_version) diff --git a/shark-ai/.gitignore b/shark-ai/.gitignore new file mode 100644 index 000000000..8e68ab1b5 --- /dev/null +++ b/shark-ai/.gitignore @@ -0,0 +1,3 @@ +# Local-only config options +version_local.json +requirements.txt diff --git a/shark-ai/README.md b/shark-ai/README.md new file mode 100644 index 000000000..93bdfd671 --- /dev/null +++ b/shark-ai/README.md @@ -0,0 +1,3 @@ +# SHARK AI meta package + +Meta package to install `sharktank` and `shortfin`. diff --git a/shark-ai/pyproject.toml b/shark-ai/pyproject.toml new file mode 100644 index 000000000..5a7493ec7 --- /dev/null +++ b/shark-ai/pyproject.toml @@ -0,0 +1,38 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "shark-ai" +authors = [ + {name = "SHARK Authors"}, +] +description = "SHARK AI meta package" +readme = "README.md" +license = {text = "Apache-2.0"} +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +requires-python = ">= 3.10" +# Version is set via the `setup.py` and requirements are set via files below. +dynamic = ["version", "dependencies"] + +[project.urls] +Repository = "https://github.com/nod-ai/SHARK-Platform" + +[project.optional-dependencies] +onnx = [ + "iree-base-compiler[onnx]", +] + +[tool.setuptools] +packages = [] + +[tool.setuptools.dynamic] +dependencies = {file = ["requirements.txt"]} diff --git a/shark-ai/setup.py b/shark-ai/setup.py new file mode 100644 index 000000000..5ceac55bd --- /dev/null +++ b/shark-ai/setup.py @@ -0,0 +1,33 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import json +import os +from pathlib import Path + +from setuptools import setup + +THIS_DIR = Path(__file__).parent.resolve() + +# Setup and get version information. +# The `version_local.json` is generated by calling: +# `build_tools/python_deploy/compute_common_version.py -stable --write-json` +VERSION_FILE_LOCAL = os.path.join(THIS_DIR, "version_local.json") + + +def load_version_info(version_file): + with open(version_file, "rt") as f: + return json.load(f) + + +version_info = load_version_info(VERSION_FILE_LOCAL) + +PACKAGE_VERSION = version_info.get("package-version") +print(f"Using PACKAGE_VERSION: '{PACKAGE_VERSION}'") + +setup( + version=f"{PACKAGE_VERSION}", +) From 033cfb5616c94122b9c7b5587b2a2667d2607c3f Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Mon, 11 Nov 2024 23:52:15 +0100 Subject: [PATCH 14/59] Fix path to `shark-ai` package (#478) --- build_tools/python_deploy/compute_common_version.py | 2 +- build_tools/python_deploy/write_requirements.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/build_tools/python_deploy/compute_common_version.py b/build_tools/python_deploy/compute_common_version.py index accdc3d28..ba5e653fb 100644 --- a/build_tools/python_deploy/compute_common_version.py +++ b/build_tools/python_deploy/compute_common_version.py @@ -38,7 +38,7 @@ VERSION_FILE_SHARKTANK = REPO_ROOT / "sharktank/version_info.json" VERSION_FILE_SHORTFIN = REPO_ROOT / "shortfin/version_info.json" -VERSION_FILE_LOCAL = REPO_ROOT / "packaging/shark-ai/version_local.json" +VERSION_FILE_LOCAL = REPO_ROOT / "shark-ai/version_local.json" def load_version_info(version_file): diff --git a/build_tools/python_deploy/write_requirements.py b/build_tools/python_deploy/write_requirements.py index 346d46ae1..6ad7c10f5 100644 --- a/build_tools/python_deploy/write_requirements.py +++ b/build_tools/python_deploy/write_requirements.py @@ -35,8 +35,8 @@ VERSION_FILE_SHARKTANK = REPO_ROOT / "sharktank/version_info.json" VERSION_FILE_SHORTFIN = REPO_ROOT / "shortfin/version_info.json" -VERSION_FILE_LOCAL = REPO_ROOT / "packaging/shark-ai/version_local.json" -REQUIREMENTS_TXT = REPO_ROOT / "packaging/shark-ai/requirements.txt" +VERSION_FILE_LOCAL = REPO_ROOT / "shark-ai/version_local.json" +REQUIREMENTS_TXT = REPO_ROOT / "shark-ai/requirements.txt" def load_version_info(version_file): From 074bc6674f42e73aa364a46c8847e8dd338de092 Mon Sep 17 00:00:00 2001 From: Archana Ramalingam <98564406+archana-ramalingam@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:04:03 -0800 Subject: [PATCH 15/59] Pin actions-gh-pages to latest hash Co-authored-by: Marius Brehler --- .github/workflows/ci_eval.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index 8f08a68aa..5a0e7537d 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -133,7 +133,7 @@ jobs: run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=perplexity/perplexity_torch.html - name: Deploy to GitHub Pages - uses: peaceiris/actions-gh-pages@v3 + uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 with: github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }} publish_dir: ./perplexity From f95023fc97d67740ac7706899347de8e98f3fbfa Mon Sep 17 00:00:00 2001 From: Archana Ramalingam <98564406+archana-ramalingam@users.noreply.github.com> Date: Mon, 11 Nov 2024 15:04:25 -0800 Subject: [PATCH 16/59] Pin actions-gh-pages to latest hash Co-authored-by: Marius Brehler --- .github/workflows/ci_eval.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index 5a0e7537d..6a1ca40ef 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -79,7 +79,7 @@ jobs: run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --longrun --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=perplexity/perplexity_iree.html - name: Deploy to GitHub Pages - uses: peaceiris/actions-gh-pages@v3 + uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 with: github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }} publish_dir: ./perplexity From feb9c4bac9fb2d2a8f754f565e548c7cd045c09e Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 11 Nov 2024 15:21:10 -0800 Subject: [PATCH 17/59] Rework embedding table to not use complex numbers (#432) Interleaved writes are problematic for fusions. Reworked the embedding update to better work with non-complex writes. This should better support fusions. --- .../layers/paged_llama_attention_block.py | 8 +- .../sharktank/layers/rotary_embedding.py | 171 +++++++----------- sharktank/sharktank/types/tensors.py | 23 ++- .../layers/sharded_rotary_embedding_test.py | 6 +- 4 files changed, 91 insertions(+), 117 deletions(-) diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 796e8224a..29aaa8d7c 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -104,11 +104,11 @@ def forward( # Fast path to start_index based embedding lookup if available. # Falls back to a slower position based index lookup. if start_index is not None: - xq, xk = embedding.forward(xq=xq, xk=xk, start_index=start_index) + xq = embedding.forward(xt=xq, start_index=start_index) + xk = embedding.forward(xt=xk, start_index=start_index) else: - xq, xk = embedding.apply_batched_mask( - xq=xq, xk=xk, mask=embedding_batch_mask - ) + xq = embedding.apply_batched_mask(xt=xq, mask=embedding_batch_mask) + xk = embedding.apply_batched_mask(xt=xk, mask=embedding_batch_mask) # Full sequence length. kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index 18a95aba3..c11a2d126 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -53,49 +53,38 @@ def rotary_embed_table(self): return self.static_rotary_embed_table return self._create_rotary_embed_table() - if self.tensor_parallelism_size == 1: - return None - - nt = namedtuple("replicated_tensor", ["shards"]) - return nt([None] * self.tensor_parallelism_size) + return None def forward( self, *, - xq: Union[torch.Tensor, SplitPrimitiveTensor], - xk: Union[torch.Tensor, SplitPrimitiveTensor], + xt: Union[torch.Tensor, SplitPrimitiveTensor], start_index: int, ): - if isinstance(xq, SplitPrimitiveTensor): - assert ( - isinstance(xk, SplitPrimitiveTensor) - and xq.shard_count == xk.shard_count - and xk.shard_dim == xq.shard_dim - ) - assert ( - isinstance(self.rotary_embed_table, ReplicatedTensor) - and xq.shard_count == self.rotary_embed_table.shard_count - ) - xqk_shards = [ + if isinstance(xt, SplitPrimitiveTensor): + rotary_shards = [None] * xt.shard_count + if self.rotary_embed_table is not None: + assert ( + isinstance(self.rotary_embed_table, ReplicatedTensor) + and xt.shard_count == self.rotary_embed_table.shard_count + ) + rotary_shards = [ + unbox_tensor(shard) for shard in self.rotary_embed_table.shards + ] + + xt_shards = [ self.forward_unsharded( - xq=unbox_tensor(xq_shard), - xk=unbox_tensor(xk_shard), + xt=unbox_tensor(xt_shard), start_index=start_index, - rotary_embed_table=unbox_tensor(rotary_embed_table_shard), - ) - for xq_shard, xk_shard, rotary_embed_table_shard in zip( - xq.shards, xk.shards, self.rotary_embed_table.shards + rotary_embed_table=rotary_shard, ) + for xt_shard, rotary_shard in zip(xt.shards, rotary_shards) ] - xq_shards = [xqk[0] for xqk in xqk_shards] - xk_shards = [xqk[1] for xqk in xqk_shards] - xq = SplitPrimitiveTensor(ts=xq_shards, shard_dim=xq.shard_dim) - xk = SplitPrimitiveTensor(ts=xk_shards, shard_dim=xk.shard_dim) - return xq, xk + xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim) + return xt else: return self.forward_unsharded( - xq=xq, - xk=xk, + xt=xt, start_index=start_index, rotary_embed_table=self.rotary_embed_table, ) @@ -103,8 +92,7 @@ def forward( def forward_unsharded( self, *, - xq: torch.Tensor, - xk: torch.Tensor, + xt: torch.Tensor, start_index: int, rotary_embed_table: Optional[torch.Tensor], ): @@ -149,44 +137,39 @@ def create_ordering_tensor(dim): return order_tensor if self.use_hf: - xq = xq[..., create_interleaved_tensor(xq.shape[-1])] - xk = xk[..., create_interleaved_tensor(xq.shape[-1])] - - xq_ = torch.view_as_complex(xq.unflatten(-1, (-1, 2))) - xk_ = torch.view_as_complex(xk.unflatten(-1, (-1, 2))) - _, sl, _, dim = xq_.shape + xt = xt[..., create_interleaved_tensor(xt.shape[-1])] + xt_ = xt.unflatten(-1, (-1, 2)) + _, sl, _, dim, _ = xt_.shape # Offset the table based on starting position. if self.use_table: - freqs_cis = rotary_embed_table[start_index : start_index + sl, :] + freqs_cis = rotary_embed_table[:, start_index : start_index + sl, :] else: - freqs_cis = torch.arange(start_index, start_index + sl, device=xq.device) + freqs_cis = torch.arange(start_index, start_index + sl, device=xt.device) freqs_cis = self._compute_rotary_embed_table(freqs_cis) - freqs_cis = self._replicate(freqs_cis) assert freqs_cis.shape[-1] == dim assert ( - freqs_cis.shape[0] >= sl + freqs_cis.shape[1] >= sl ), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})" - broadcast_freqs_cis = freqs_cis[None, 0:sl, None, :] + broadcast_freqs_cis = freqs_cis[:, None, 0:sl, None, :] - if self.use_hf: - xq_out = torch.view_as_real( - self.complex_multiply(xq_, broadcast_freqs_cis) - ).flatten(3) - xk_out = torch.view_as_real( - self.complex_multiply(xk_, broadcast_freqs_cis) - ).flatten(3) + cos = broadcast_freqs_cis[0] + sin = broadcast_freqs_cis[1] + xt_r = xt_[..., 0] + xt_i = xt_[..., 1] + + xt_out_r = xt_r * cos - xt_i * sin + xt_out_i = xt_i * cos + xt_r * sin - xq_out = xq_out[..., create_ordering_tensor(xq_out.shape[-1])] - xk_out = xk_out[..., create_ordering_tensor(xq_out.shape[-1])] + xt_out = torch.concatenate((xt_out_r, xt_out_i), dim=-1) - return xq_out.type_as(xq), xk_out.type_as(xk) + if self.use_hf: + xt_out = xt_out[..., create_ordering_tensor(xt_out.shape[-1])] + return xt_out.type_as(xt) - xq_out = torch.view_as_real(xq_ * broadcast_freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * broadcast_freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) + return xt_out.type_as(xt) def complex_multiply(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """Function for elementwise-multiplication of two complex torch tensors. @@ -224,11 +207,11 @@ def compute_batch_mask( self.trace_tensor("rope.positions_seq", positions_seq) if self.use_table: - freqs_cis = self.rotary_embed_table[positions_seq] + freqs_cis = self.rotary_embed_table[:, positions_seq] else: shape = positions_seq.shape freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten()) - freqs_cis = freqs_cis.unflatten(0, shape) + freqs_cis = freqs_cis.unflatten(1, shape) # Unsqueeze a unit dim for attention heads. broadcast_freqs_cis = freqs_cis.unsqueeze(2) @@ -237,41 +220,24 @@ def compute_batch_mask( def apply_batched_mask( self, *, - xq: Union[torch.Tensor, SplitPrimitiveTensor], - xk: Union[torch.Tensor, SplitPrimitiveTensor], + xt: Union[torch.Tensor, SplitPrimitiveTensor], mask: Union[torch.Tensor, ReplicatedTensor], ): - if isinstance(xq, SplitPrimitiveTensor): - assert ( - isinstance(xk, SplitPrimitiveTensor) - and xq.shard_count == xk.shard_count - and xk.shard_dim == xq.shard_dim - ) - assert ( - isinstance(mask, ReplicatedTensor) - and mask.shard_count == xq.shard_count + if not isinstance(xt, SplitPrimitiveTensor): + return self.apply_batched_mask_unsharded(xt=xt, mask=mask) + + assert isinstance(mask, ReplicatedTensor) and mask.shard_count == xt.shard_count + xt_shards = [ + self.apply_batched_mask_unsharded( + xt=unbox_tensor(xt_shard), + mask=unbox_tensor(mask_shard), ) - xqk_shards = [ - self.apply_batched_mask_unsharded( - xq=unbox_tensor(xq_shard), - xk=unbox_tensor(xk_shard), - mask=unbox_tensor(mask_shard), - ) - for xq_shard, xk_shard, mask_shard in zip( - xq.shards, xk.shards, mask.shards - ) - ] - xq_shards = [xqk[0] for xqk in xqk_shards] - xk_shards = [xqk[1] for xqk in xqk_shards] - xq = SplitPrimitiveTensor(ts=xq_shards, shard_dim=xq.shard_dim) - xk = SplitPrimitiveTensor(ts=xk_shards, shard_dim=xk.shard_dim) - return xq, xk - else: - return self.apply_batched_mask_unsharded(xq=xq, xk=xk, mask=mask) + for xt_shard, mask_shard in zip(xt.shards, mask.shards) + ] + xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim) + return xt - def apply_batched_mask_unsharded( - self, *, xq: torch.Tensor, xk: torch.Tensor, mask: torch.Tensor - ): + def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor): """Applies the embedding to a ragged batch of queries and keys. This does a more complicated indexing operation for cases when the each @@ -281,13 +247,17 @@ def apply_batched_mask_unsharded( """ # xq_, xk_ shape: bs, sl, _, dim # freqs_cis shape: max_sl, dim - xq_ = torch.view_as_complex(xq.unflatten(-1, (-1, 2))) - xk_ = torch.view_as_complex(xk.unflatten(-1, (-1, 2))) - _, sl, _, dim = xq_.shape + cos = mask[0] + sin = mask[1] - xq_out = torch.view_as_real(xq_ * mask).flatten(3) - xk_out = torch.view_as_real(xk_ * mask).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) + xt_ = xt.unflatten(-1, (-1, 2)) + xt_r = xt_[..., 0] + xt_i = xt_[..., 1] + + xt_out_r = xt_r * cos - xt_i * sin + xt_out_i = xt_r * sin + xt_i * cos + xt_out = torch.concatenate((xt_out_r, xt_out_i), dim=-1) + return xt_out.type_as(xt) def _compute_rotary_embed_table(self, t): dim = self.rope_dimension_count @@ -297,13 +267,10 @@ def _compute_rotary_embed_table(self, t): ) freqs = torch.outer(t, freqs).float() - freqs_cis = ( - torch.complex(torch.cos(freqs), torch.sin(freqs)) - if self.use_hf - else torch.polar(torch.ones_like(freqs), freqs) - ) + cos = torch.cos(freqs).unsqueeze(0) + sin = torch.sin(freqs).unsqueeze(0) - return freqs_cis + return torch.concatenate((cos, sin), dim=0) def _create_rotary_embed_table(self): t = torch.arange(self.max_seqlen, device=self.device) diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index df8df075b..87a40fb7b 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -543,9 +543,10 @@ def _clone_with_globals( ) -> "InferenceTensor": return DefaultPrimitiveTensor(name=self.name, data=new_globals[self.name]) - def __getitem__(self, keys): - if not isinstance(keys, list) and not isinstance(keys, tuple): - keys = [keys] + def __getitem__(self, key): + keys = [key] + if isinstance(key, tuple) or isinstance(key, list): + keys = key keys = [ unbox_tensor(key) if isinstance(key, PrimitiveTensor) else key @@ -1188,15 +1189,19 @@ def create( raise IOError(f"Missing component tensor '' in {raw_tensors.keys()}") from e return cls(name=name, ts=ts) - def __getitem__(self, keys): - if not isinstance(keys, list) and not isinstance(keys, tuple): - keys = [keys] + def __getitem__(self, key): + keys = [key] + if isinstance(keys, tuple) or isinstance(keys, list): + keys = key shards = [] for i, shard in enumerate(self.shards): - shard_keys = [ - k.shards[i] if isinstance(k, ReplicatedTensor) else k for k in keys - ] + shard_keys = [] + for k in keys: + if isinstance(k, ReplicatedTensor): + shard_keys.append(k.shards[i]) + else: + shard_keys.append(k) shards.append(shard[*shard_keys]) return ReplicatedTensor(ts=shards) diff --git a/sharktank/tests/layers/sharded_rotary_embedding_test.py b/sharktank/tests/layers/sharded_rotary_embedding_test.py index 963b9b432..f24b8313a 100644 --- a/sharktank/tests/layers/sharded_rotary_embedding_test.py +++ b/sharktank/tests/layers/sharded_rotary_embedding_test.py @@ -35,7 +35,8 @@ def test_sharded_rotary_table(): max_seqlen=max_seqlen, rope_freq_base=rope_freq_base, ) - oq, ok = default_layer(xq=xq, xk=xk, start_index=0) + oq = default_layer(xt=xq, start_index=0) + ok = default_layer(xt=xk, start_index=0) # Then we can shard the same inputs and layer xq = SplitPrimitiveTensor(ts=xq, shard_dim=2, shard_count=4) @@ -46,7 +47,8 @@ def test_sharded_rotary_table(): rope_freq_base=rope_freq_base, tensor_parallelism_size=4, ) - sq, sk = shard_layer(xq=xq, xk=xk, start_index=0) + sq = shard_layer(xt=xq, start_index=0) + sk = shard_layer(xt=xk, start_index=0) # Gathering and unboxing should yield the same results sq = ops.unshard(sq) From fdf5dd9186966ff3582f3395bb901c58608261c6 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Tue, 12 Nov 2024 00:25:35 +0100 Subject: [PATCH 18/59] Increment package version to 2.9.1 (#479) To match with X.Y of `iree-{base-compiler,base-runtime,turbine}`, the patch level is increased instead of the minor version. --- sharktank/version_info.json | 2 +- shortfin/version_info.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sharktank/version_info.json b/sharktank/version_info.json index ca3c0ed0b..794a2de28 100644 --- a/sharktank/version_info.json +++ b/sharktank/version_info.json @@ -1,3 +1,3 @@ { - "package-version": "2.9.0.dev" + "package-version": "2.9.1.dev" } diff --git a/shortfin/version_info.json b/shortfin/version_info.json index ca3c0ed0b..794a2de28 100644 --- a/shortfin/version_info.json +++ b/shortfin/version_info.json @@ -1,3 +1,3 @@ { - "package-version": "2.9.0.dev" + "package-version": "2.9.1.dev" } From 35bc60d321af76cef4b82d8383456300018b65dc Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 11 Nov 2024 18:12:26 -0800 Subject: [PATCH 19/59] [shortfin] Adds C++ tracing scopes. (#480) * Manually chose a number of things that would be interesting to see in a trace and added annotations. * QC'd with some unit test traces. --- shortfin/python/array_binding.cc | 9 +++++++++ shortfin/python/array_host_ops.cc | 2 ++ shortfin/python/lib_ext.cc | 10 ++++++++++ shortfin/src/shortfin/array/array.cc | 3 +++ shortfin/src/shortfin/array/storage.cc | 6 ++++++ shortfin/src/shortfin/array/xtensor_bridge.cc | 2 ++ shortfin/src/shortfin/local/program.cc | 7 +++++++ shortfin/src/shortfin/local/scheduler.cc | 4 ++++ shortfin/src/shortfin/local/system.cc | 3 +++ shortfin/src/shortfin/local/systems/amdgpu.cc | 2 ++ shortfin/src/shortfin/local/systems/host.cc | 3 +++ shortfin/src/shortfin/local/worker.cc | 1 + shortfin/src/shortfin/support/logging.h | 8 ++++++++ 13 files changed, 60 insertions(+) diff --git a/shortfin/python/array_binding.cc b/shortfin/python/array_binding.cc index da7197b14..a05232674 100644 --- a/shortfin/python/array_binding.cc +++ b/shortfin/python/array_binding.cc @@ -7,6 +7,7 @@ #include "./lib_ext.h" #include "./utils.h" #include "shortfin/array/api.h" +#include "shortfin/support/logging.h" using namespace shortfin::array; @@ -223,6 +224,7 @@ class PyMapping { } void FillFromScalar(Refs *refs, py::handle value) { + SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::FillFromScalar"); if (!dtype()) { throw std::invalid_argument( "The `fill` method is only valid for typed mappings but " @@ -242,6 +244,7 @@ class PyMapping { } void FillFromBuffer(py::handle buffer) { + SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::FillFromBuffer"); Py_buffer py_view; int flags = PyBUF_FORMAT | PyBUF_ND; // C-Contiguous ND. if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) { @@ -286,6 +289,7 @@ class PyMapping { } py::object GetItems(py::handle self_obj, Refs *refs) { + SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::GetItems"); if (!dtype()) { throw std::invalid_argument( "The `items` property is only valid for typed mappings but " @@ -306,6 +310,7 @@ class PyMapping { } void SetItems(Refs *refs, py::handle initializer) { + SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::SetItems"); if (!dtype()) { throw std::invalid_argument( "The `items` property is only valid for typed mappings but " @@ -410,6 +415,7 @@ void BindArray(py::module_ &m) { .def( "map", [](storage &self, bool read, bool write, bool discard) { + SHORTFIN_TRACE_SCOPE_NAMED("PyStorage::map"); int access = 0; if (read) access |= IREE_HAL_MEMORY_ACCESS_READ; if (write || discard) access |= IREE_HAL_MEMORY_ACCESS_WRITE; @@ -565,6 +571,7 @@ void BindArray(py::module_ &m) { .def( "map", [](device_array &self, bool read, bool write, bool discard) { + SHORTFIN_TRACE_SCOPE_NAMED("PyArray::map"); int access = 0; if (read) access |= IREE_HAL_MEMORY_ACCESS_READ; if (write || discard) access |= IREE_HAL_MEMORY_ACCESS_WRITE; @@ -586,6 +593,7 @@ void BindArray(py::module_ &m) { .def_prop_rw( "items", [refs](device_array &self) { + SHORTFIN_TRACE_SCOPE_NAMED("PyArray::items"); PyMapping *mapping; py::object mapping_obj = CreateMappingObject(&mapping); mapping->set_dtype(self.dtype()); @@ -606,6 +614,7 @@ void BindArray(py::module_ &m) { .def_prop_ro( "__array_interface__", [refs](device_array &self) { + SHORTFIN_TRACE_SCOPE_NAMED("PyArray::__array_interface__"); py::dict interface; interface["version"] = 3; interface["strides"] = py::none(); diff --git a/shortfin/python/array_host_ops.cc b/shortfin/python/array_host_ops.cc index 8c4af0070..8eaf4ddd8 100644 --- a/shortfin/python/array_host_ops.cc +++ b/shortfin/python/array_host_ops.cc @@ -95,6 +95,7 @@ void BindArrayHostOps(py::module_ &m) { "argmax", [](device_array &input, int axis, std::optional out, bool keepdims, bool device_visible) { + SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::argmax"); if (axis < 0) axis += input.shape().size(); if (axis < 0 || axis >= input.shape().size()) { throw std::invalid_argument( @@ -139,6 +140,7 @@ void BindArrayHostOps(py::module_ &m) { m.def( "fill_randn", [](device_array out, std::optional gen) { + SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::fill_randn"); if (!gen) gen = &PyRandomGenerator::get_default(); auto compute = [&]() { auto result = xt::random::randn(out.shape_container(), /*mean=*/0.0, diff --git a/shortfin/python/lib_ext.cc b/shortfin/python/lib_ext.cc index 0bfb67588..d17606b4b 100644 --- a/shortfin/python/lib_ext.cc +++ b/shortfin/python/lib_ext.cc @@ -173,6 +173,7 @@ class PyWorkerExtension : public local::Worker::Extension { py::handle loop() { return loop_; } void OnThreadStart() noexcept override { + SHORTFIN_TRACE_SCOPE_NAMED("PyWorker::OnThreadStart"); // Python threading initialization. // If our own thread, teach Python about it. Not done for donated. if (worker().options().owned_thread) { @@ -187,6 +188,7 @@ class PyWorkerExtension : public local::Worker::Extension { } void OnThreadStop() noexcept override { + SHORTFIN_TRACE_SCOPE_NAMED("PyWorker::OnThreadStop"); { // Do Python level thread cleanup. py::gil_scoped_acquire g; @@ -253,6 +255,7 @@ class PyProcess : public local::detail::BaseProcess { std::bind(&PyProcess::RunOnWorker, self_object)); } static void RunOnWorker(py::handle self_handle) { + SHORTFIN_TRACE_SCOPE_NAMED("PyProcess:RunOnWorker"); py::gil_scoped_acquire g; // Steal the reference back from ScheduleOnWorker. Important: this is // very likely the last reference to the process. So self must not be @@ -342,6 +345,7 @@ py::object PyRehydrateRef(local::ProgramInvocation *inv, py::object RunInForeground(std::shared_ptr refs, local::System &self, py::object coro) { + SHORTFIN_TRACE_SCOPE_NAMED("CoroRunInForeground"); bool is_main_thread = refs->threading_current_thread().is(refs->threading_main_thread()); @@ -936,6 +940,7 @@ void BindLocal(py::module_ &m) { callable.inc_ref(); // Stolen within the callback. auto thunk = +[](void *user_data, iree_loop_t loop, iree_status_t status) noexcept -> iree_status_t { + SHORTFIN_TRACE_SCOPE_NAMED("PyWorker::Callback"); py::gil_scoped_acquire g; py::object user_callable = py::steal(static_cast(user_data)); @@ -955,6 +960,7 @@ void BindLocal(py::module_ &m) { callable.inc_ref(); // Stolen within the callback. auto thunk = +[](void *user_data, iree_loop_t loop, iree_status_t status) noexcept -> iree_status_t { + SHORTFIN_TRACE_SCOPE_NAMED("PyWorker::DelayCallback"); py::gil_scoped_acquire g; py::object user_callable = py::steal(static_cast(user_data)); @@ -1030,6 +1036,7 @@ void BindLocal(py::module_ &m) { py::class_(m, "CompletionEvent") .def(py::init<>()) .def("__await__", [](py::handle self_obj) { + SHORTFIN_TRACE_SCOPE_NAMED("PyCompletionEvent::__await__"); auto &worker_ext = PyWorkerExtension::GetCurrent(); auto &self = py::cast(self_obj); py::object future = worker_ext.loop().attr("create_future")(); @@ -1051,6 +1058,7 @@ void BindLocal(py::module_ &m) { self, iree_infinite_timeout(), +[](void *future_vp, iree_loop_t loop, iree_status_t status) noexcept -> iree_status_t { + SHORTFIN_TRACE_SCOPE_NAMED("PyCompletionEvent::OnComplete"); py::gil_scoped_acquire g; py::object future = py::steal(static_cast(future_vp)); try { @@ -1145,6 +1153,7 @@ void BindLocal(py::module_ &m) { return py::none(); }) .def("__await__", [](py::handle self_obj) { + SHORTFIN_TRACE_SCOPE_NAMED("PyFuture::__await__"); // TODO: We should make our C++ future able to be used directly // vs needing to bridge it like this. auto &worker_ext = PyWorkerExtension::GetCurrent(); @@ -1166,6 +1175,7 @@ void BindLocal(py::module_ &m) { self.AddCallback( [py_future_vp = static_cast(future.release().ptr())]( local::Future &sf_future) { + SHORTFIN_TRACE_SCOPE_NAMED("PyFuture::OnComplete"); py::gil_scoped_acquire g; py::object py_future = py::steal(static_cast(py_future_vp)); diff --git a/shortfin/src/shortfin/array/array.cc b/shortfin/src/shortfin/array/array.cc index 11961b449..882e4ef39 100644 --- a/shortfin/src/shortfin/array/array.cc +++ b/shortfin/src/shortfin/array/array.cc @@ -64,6 +64,7 @@ mapping device_array::data_rw() { return storage_.map_read_write(); } mapping device_array::data_w() { return storage_.map_write_discard(); } std::optional device_array::map_memory_for_xtensor() { + SHORTFIN_TRACE_SCOPE_NAMED("PyDeviceArray::map_memory_for_xtensor"); if (storage_.is_mappable_for_read_write()) { return storage_.map_read_write(); } else if (storage_.is_mappable_for_read()) { @@ -97,6 +98,7 @@ std::string device_array::to_s() const { void device_array::AddAsInvocationArgument( local::ProgramInvocation *inv, local::ProgramResourceBarrier barrier) { + SHORTFIN_TRACE_SCOPE_NAMED("PyDeviceArray::AddAsInvocationArgument"); auto dims_span = shape(); iree_hal_buffer_view_t *buffer_view; SHORTFIN_THROW_IF_ERROR(iree_hal_buffer_view_create( @@ -117,6 +119,7 @@ iree_vm_ref_type_t device_array::invocation_marshalable_type() { device_array device_array::CreateFromInvocationResultRef( local::ProgramInvocation *inv, iree::vm_opaque_ref ref) { + SHORTFIN_TRACE_SCOPE_NAMED("PyDeviceArray::CreateFromInvocationResultRef"); // We don't retain the buffer view in the device array, so just deref it // vs stealing the ref. iree_hal_buffer_view_t *bv = iree_hal_buffer_view_deref(*ref.get()); diff --git a/shortfin/src/shortfin/array/storage.cc b/shortfin/src/shortfin/array/storage.cc index a30dbf450..ffbbd9ba2 100644 --- a/shortfin/src/shortfin/array/storage.cc +++ b/shortfin/src/shortfin/array/storage.cc @@ -43,6 +43,7 @@ storage storage::import_buffer(local::ScopedDevice &device, storage storage::allocate_device(ScopedDevice &device, iree_device_size_t allocation_size) { + SHORTFIN_TRACE_SCOPE_NAMED("storage::allocate_device"); if (!device.raw_device()) { throw std::invalid_argument("Cannot allocate with a null device affinity"); } @@ -63,6 +64,7 @@ storage storage::allocate_device(ScopedDevice &device, storage storage::allocate_host(ScopedDevice &device, iree_device_size_t allocation_size, bool device_visible) { + SHORTFIN_TRACE_SCOPE_NAMED("storage::allocate_host"); if (!device.raw_device()) { throw std::invalid_argument("Cannot allocate with a null device affinity"); } @@ -207,6 +209,7 @@ std::string storage::formatted_buffer_usage() const { void storage::AddAsInvocationArgument(local::ProgramInvocation *inv, local::ProgramResourceBarrier barrier) { + SHORTFIN_TRACE_SCOPE_NAMED("storage::AddAsInvocationArgument"); iree::vm_opaque_ref ref; *(&ref) = iree_hal_buffer_retain_ref(buffer_); inv->AddArg(std::move(ref)); @@ -220,6 +223,7 @@ iree_vm_ref_type_t storage::invocation_marshalable_type() { storage storage::CreateFromInvocationResultRef(local::ProgramInvocation *inv, iree::vm_opaque_ref ref) { + SHORTFIN_TRACE_SCOPE_NAMED("storage::CreateFromInvocationResultRef"); // Steal the ref to one of our smart pointers. // TODO: Should have an opaque_ref::release(). iree::hal_buffer_ptr buffer = @@ -230,6 +234,7 @@ storage storage::CreateFromInvocationResultRef(local::ProgramInvocation *inv, storage storage::ImportInvocationResultStorage(local::ProgramInvocation *inv, iree::hal_buffer_ptr buffer) { + SHORTFIN_TRACE_SCOPE_NAMED("storage::ImportInvocationResultStorage"); local::ScopedDevice device = local::ScopedDevice(*inv->fiber(), inv->device_selection()); auto imported_storage = storage::import_buffer(device, std::move(buffer)); @@ -251,6 +256,7 @@ storage storage::ImportInvocationResultStorage(local::ProgramInvocation *inv, void storage::AddInvocationArgBarrier(local::ProgramInvocation *inv, local::ProgramResourceBarrier barrier) { + SHORTFIN_TRACE_SCOPE_NAMED("storage::AddInvocationArgBarrier"); switch (barrier) { case ProgramResourceBarrier::DEFAULT: case ProgramResourceBarrier::READ: diff --git a/shortfin/src/shortfin/array/xtensor_bridge.cc b/shortfin/src/shortfin/array/xtensor_bridge.cc index bd3753331..da350b71a 100644 --- a/shortfin/src/shortfin/array/xtensor_bridge.cc +++ b/shortfin/src/shortfin/array/xtensor_bridge.cc @@ -8,6 +8,7 @@ #include +#include "shortfin/support/logging.h" #include "xtl/xhalf_float.hpp" namespace shortfin::array { @@ -56,6 +57,7 @@ class typed_xt_methods final : public poly_xt_methods { bool poly_xt_methods::inplace_new(uint8_t *inst_storage, DType dtype, void *array_memory, size_t array_memory_size, Dims &dims) { + SHORTFIN_TRACE_SCOPE_NAMED("array_xtensor_cast"); #define POLY_XT_CASE(et, cpp_type) \ case et: \ typed_xt_methods::concrete_inplace_new( \ diff --git a/shortfin/src/shortfin/local/program.cc b/shortfin/src/shortfin/local/program.cc index 3fd41d87b..6ab1f47ae 100644 --- a/shortfin/src/shortfin/local/program.cc +++ b/shortfin/src/shortfin/local/program.cc @@ -75,6 +75,7 @@ std::string_view ProgramFunction::calling_convention() const { ProgramInvocation::Ptr ProgramFunction::CreateInvocation( std::shared_ptr fiber, std::optional isolation) { + SHORTFIN_TRACE_SCOPE_NAMED("ProgramFunction::CreateInvocation"); ProgramIsolation actual_isolation = isolation ? *isolation : isolation_; // Low-overhead NONE isolation handling (saves some ref-count twiddling). if (actual_isolation == ProgramIsolation::NONE) { @@ -101,6 +102,7 @@ std::string ProgramFunction::to_s() const { ProgramModule ProgramModule::Load(System &system, const std::filesystem::path &path, bool mmap) { + SHORTFIN_TRACE_SCOPE_NAMED("ProgramModule::Load"); iree::file_contents_ptr contents; iree_file_read_flags_t flags = mmap ? IREE_FILE_READ_FLAG_MMAP : IREE_FILE_READ_FLAG_PRELOAD; @@ -171,6 +173,7 @@ std::vector ProgramModule::exports() const { Program Program::Load(std::span modules, Options &&options) { + SHORTFIN_TRACE_SCOPE_NAMED("Program::Load"); std::vector all_modules; std::vector raw_devices; @@ -451,6 +454,7 @@ iree_status_t ProgramInvocation::FinalizeCallingConvention( ProgramInvocation::Future ProgramInvocation::Invoke( ProgramInvocation::Ptr invocation) { + SHORTFIN_TRACE_SCOPE_NAMED("ProgramInvocation::Invoke"); invocation->CheckNotScheduled(); Worker &worker = invocation->fiber_->worker(); @@ -462,9 +466,11 @@ ProgramInvocation::Future ProgramInvocation::Invoke( iree_vm_function_t function, ProgramInvocationModel invocation_model, std::optional failure_future) { + SHORTFIN_TRACE_SCOPE_NAMED("ProgramInvocation::InvokeAsync"); auto complete_callback = [](void *user_data, iree_loop_t loop, iree_status_t status, iree_vm_list_t *outputs) noexcept -> iree_status_t { + SHORTFIN_TRACE_SCOPE_NAMED("ProgramInvocation::Complete"); // Async invocation helpfully gives us a retained reference to the // outputs, but we already have one statically on the // ProgramInvocation. So release this one, which makes it safe to @@ -620,6 +626,7 @@ StaticProgramParameters::StaticProgramParameters( void StaticProgramParameters::Load(std::filesystem::path file_path, LoadOptions options) { + SHORTFIN_TRACE_SCOPE_NAMED("StaticProgramParameters::Load"); // Default format from extension. if (options.format.empty()) { options.format = file_path.extension().string(); diff --git a/shortfin/src/shortfin/local/scheduler.cc b/shortfin/src/shortfin/local/scheduler.cc index 3b82ded20..883951a20 100644 --- a/shortfin/src/shortfin/local/scheduler.cc +++ b/shortfin/src/shortfin/local/scheduler.cc @@ -61,6 +61,7 @@ void Account::active_deps_extend(iree_hal_semaphore_list_t sem_list) { } VoidFuture Account::OnSync() { + SHORTFIN_TRACE_SCOPE_NAMED("Account::OnSync"); // TODO: Burn this path with fire! No attempt has been made to make this // particularly good: the backend is being implemented now to export // HAL semaphores via iree_hal_semaphore_await, and that should be used @@ -133,6 +134,7 @@ Scheduler::~Scheduler() { void Scheduler::Initialize( std::span> devices) { + SHORTFIN_TRACE_SCOPE_NAMED("Scheduler::Initialize"); for (auto &it : devices) { accounts_.emplace_back(*this, it.second); } @@ -165,6 +167,7 @@ Account &Scheduler::GetDefaultAccount(ScopedDevice &device) { void Scheduler::AppendCommandBuffer(ScopedDevice &device, TransactionType tx_type, std::function callback) { + SHORTFIN_TRACE_SCOPE_NAMED("Scheduler::AppendCommandBuffer"); Account &account = GetDefaultAccount(device); auto needed_affinity_bits = device.affinity().queue_affinity(); SHORTFIN_SCHED_LOG( @@ -242,6 +245,7 @@ void Scheduler::AppendCommandBuffer(ScopedDevice &device, } iree_status_t Scheduler::FlushWithStatus() noexcept { + SHORTFIN_TRACE_SCOPE_NAMED("Scheduler::FlushWithStatus"); // This loop is optimized for a small number of accounts, where it is // fine to just linearly probe. If this ever becomes cumbersome, we can // maintain a dirty list which is appended to when an account transitions diff --git a/shortfin/src/shortfin/local/system.cc b/shortfin/src/shortfin/local/system.cc index f5012c626..ef31bb001 100644 --- a/shortfin/src/shortfin/local/system.cc +++ b/shortfin/src/shortfin/local/system.cc @@ -20,6 +20,7 @@ namespace shortfin::local { System::System(iree_allocator_t host_allocator) : host_allocator_(host_allocator) { + SHORTFIN_TRACE_SCOPE_NAMED("System::System"); logging::construct("System", this); SHORTFIN_THROW_IF_ERROR(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT, host_allocator_, @@ -29,6 +30,7 @@ System::System(iree_allocator_t host_allocator) } System::~System() { + SHORTFIN_TRACE_SCOPE_NAMED("System::~System"); logging::destruct("System", this); bool needs_shutdown = false; { @@ -61,6 +63,7 @@ System::~System() { } void System::Shutdown() { + SHORTFIN_TRACE_SCOPE_NAMED("System::Shutdown"); // Stop workers. std::vector local_workers; { diff --git a/shortfin/src/shortfin/local/systems/amdgpu.cc b/shortfin/src/shortfin/local/systems/amdgpu.cc index 78efad709..cecedd1a0 100644 --- a/shortfin/src/shortfin/local/systems/amdgpu.cc +++ b/shortfin/src/shortfin/local/systems/amdgpu.cc @@ -87,6 +87,7 @@ void AMDGPUSystemBuilder::InitializeDefaultSettings() { void AMDGPUSystemBuilder::Enumerate() { if (hip_hal_driver_) return; + SHORTFIN_TRACE_SCOPE_NAMED("AMDGPUSystemBuilder::Enumerate"); iree_hal_hip_driver_options_t driver_options; iree_hal_hip_driver_options_initialize(&driver_options); @@ -127,6 +128,7 @@ std::vector AMDGPUSystemBuilder::GetAvailableDeviceIds() { } SystemPtr AMDGPUSystemBuilder::CreateSystem() { + SHORTFIN_TRACE_SCOPE_NAMED("AMDGPUSystemBuilder::CreateSystem"); auto lsys = std::make_shared(host_allocator()); Enumerate(); diff --git a/shortfin/src/shortfin/local/systems/host.cc b/shortfin/src/shortfin/local/systems/host.cc index 440a3ff51..1da4b2af1 100644 --- a/shortfin/src/shortfin/local/systems/host.cc +++ b/shortfin/src/shortfin/local/systems/host.cc @@ -125,6 +125,7 @@ HostCPUSystemBuilder::SelectHostCPUNodesFromOptions() { } SystemPtr HostCPUSystemBuilder::CreateSystem() { + SHORTFIN_TRACE_SCOPE_NAMED("HostCPUSystemBuilder::CreateSystem"); auto lsys = std::make_shared(host_allocator()); // TODO: Real NUMA awareness. lsys->InitializeNodes(1); @@ -136,6 +137,7 @@ SystemPtr HostCPUSystemBuilder::CreateSystem() { } iree_hal_driver_t *HostCPUSystemBuilder::InitializeHostCPUDriver(System &lsys) { + SHORTFIN_TRACE_SCOPE_NAMED("HostCPUSystemBuilder::InitializeHostCPUDriver"); // TODO: Kill these flag variants in favor of settings on the config // object. SHORTFIN_THROW_IF_ERROR(iree_task_executor_options_initialize_from_flags( @@ -206,6 +208,7 @@ iree_hal_driver_t *HostCPUSystemBuilder::InitializeHostCPUDriver(System &lsys) { void HostCPUSystemBuilder::InitializeHostCPUDevices(System &lsys, iree_hal_driver_t *driver) { + SHORTFIN_TRACE_SCOPE_NAMED("HostCPUSystemBuilder::InitializeHostCPUDevices"); iree_host_size_t device_info_count = 0; iree::allocated_ptr device_infos(host_allocator()); SHORTFIN_THROW_IF_ERROR(iree_hal_driver_query_available_devices( diff --git a/shortfin/src/shortfin/local/worker.cc b/shortfin/src/shortfin/local/worker.cc index 09207e5e4..d5ffafdbe 100644 --- a/shortfin/src/shortfin/local/worker.cc +++ b/shortfin/src/shortfin/local/worker.cc @@ -109,6 +109,7 @@ iree_status_t Worker::TransactLoop(iree_status_t signal_status) { for (auto& next_thunk : next_thunks_) { // TODO: Make thunks have to return a status, propagate, and handle // exceptions. + SHORTFIN_TRACE_SCOPE_NAMED("Worker::ThreadsafeCallback"); next_thunk(); } next_thunks_.clear(); diff --git a/shortfin/src/shortfin/support/logging.h b/shortfin/src/shortfin/support/logging.h index 7bc9e130d..e70c54e99 100644 --- a/shortfin/src/shortfin/support/logging.h +++ b/shortfin/src/shortfin/support/logging.h @@ -23,6 +23,14 @@ #define SHORTFIN_SCHED_LOG(...) #endif +// Tracing macros. These are currently just aliases of the underlying IREE +// macros, but we maintain the ability to redirect them in the future (i.e. +// for certain kinds of library builds, etc). +#define SHORTFIN_TRACE_SCOPE IREE_TRACE_SCOPE +#define SHORTFIN_TRACE_SCOPE_NAMED(name_literal) \ + IREE_TRACE_SCOPE_NAMED(name_literal) +#define SHORTFIN_TRACE_SCOPE_ID IREE_TRACE_SCOPE_ID + namespace shortfin::logging { SHORTFIN_API void InitializeFromEnv(); From 4759dbc9fc794a5c18115b309b727e8114ecd967 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 11 Nov 2024 20:46:50 -0800 Subject: [PATCH 20/59] [shortfin] Add conversion host ops. (#482) Ops added: `convert`, `round`, `ceil`, `floor`, `trunc` All ops were implemented to the same pattern, supporting fused conversion and output array. Fixes #315 --- shortfin/python/array_host_ops.cc | 314 ++++++++++++++++++++- shortfin/python/shortfin/array/__init__.py | 10 + shortfin/tests/api/array_ops_test.py | 101 +++++++ 3 files changed, 417 insertions(+), 8 deletions(-) diff --git a/shortfin/python/array_host_ops.cc b/shortfin/python/array_host_ops.cc index 8eaf4ddd8..86385cfee 100644 --- a/shortfin/python/array_host_ops.cc +++ b/shortfin/python/array_host_ops.cc @@ -38,6 +38,34 @@ Implemented for dtypes: float16, float32. A device_array of dtype=int64, allocated on the host and not visible to the device. )"; +static const char DOCSTRING_CONVERT[] = + R"(Does an elementwise conversion from one dtype to another. + +The same behavior exists for several conversion ops: + +* `convert` : element-wise conversion like a static cast. +* `round` : element-wise nearest integer to the input, rounding halfway cases + away from zero. +* `ceil` : element-wise smallest integer value not less than the input. +* `floor` : element-wise smallest integer value not greater than the input. +* `trunc` : element-wise nearest integer not greater in magnitude than the input. + +For nearest-integer conversions (round, ceil, floor, trunc), the input dtype +must be a floating point array, and the output must be a byte-aligned integer +type between 8 and 32 bits. + +Args: + input: An input array of a floating point dtype. + dtype: If given, then this is the explicit output dtype. + out: If given, then the results are written to this array. This implies the + output dtype. + device_visible: Whether to make the result array visible to devices. Defaults to + False. + +Returns: + A device_array of the requested dtype, or the input dtype if not specified. +)"; + static const char DOCSTRING_FILL_RANDN[] = R"(Fills an array with numbers sampled from the standard ormal distribution. @@ -63,7 +91,14 @@ static const char DOCSTRING_RANDOM_GENERATOR[] = fixed number. )"; -} // namespace +#define SF_UNARY_FUNCTION_CASE(dtype_name, cpp_type) \ + case DType::dtype_name(): \ + return compute.template operator()() + +#define SF_UNARY_THUNK_CASE(dtype_name, cpp_type) \ + case DType::dtype_name(): \ + compute.template operator()(); \ + break struct PyRandomGenerator { public: @@ -85,9 +120,261 @@ struct PyRandomGenerator { xt::random::default_engine_type engine_; }; -#define SF_UNARY_COMPUTE_CASE(dtype_name, cpp_type) \ - case DType::dtype_name(): \ - return compute.template operator()() +// Generic conversion templates, split into a bindable template and functors +// that operate on pre-allocated outputs. +template +device_array GenericElementwiseConvert(device_array &input, + std::optional dtype, + std::optional out, + bool device_visible) { + // Argument check and output allocation. + if (!dtype) { + dtype = out ? out->dtype() : input.dtype(); + } else { + if (out && out->dtype() != dtype) { + throw std::invalid_argument( + "if both dtype and out are specified, they must match"); + } + } + if (!out) { + out.emplace(device_array::for_host(input.device(), input.shape(), *dtype, + device_visible)); + } + + ConvertFunc::Invoke(input, *dtype, *out); + return *out; +} + +// Generic elementwise conversion functor +struct ConvertFunctor { + static void Invoke(device_array &input, DType dtype, device_array &out) { + SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::convert"); + auto compute = [&]() -> void { + auto input_t = input.map_xtensor(); + // Casted output. +#define SF_STORE_CASE(dtype_name, cpp_type) \ + case DType::dtype_name(): { \ + auto out_t = out.map_xtensor_w(); \ + *out_t = xt::cast(*input_t); \ + break; \ + } + switch (dtype) { + SF_STORE_CASE(float16, half_float::half); + SF_STORE_CASE(float32, float); + SF_STORE_CASE(float64, double); + SF_STORE_CASE(uint8, uint8_t); + SF_STORE_CASE(int8, int8_t); + SF_STORE_CASE(uint16, uint16_t); + SF_STORE_CASE(int16, int16_t); + SF_STORE_CASE(uint32, uint32_t); + SF_STORE_CASE(int32, int32_t); + SF_STORE_CASE(uint64, uint64_t); + SF_STORE_CASE(int64, int64_t); + default: + throw std::invalid_argument("Invalid output dtype for convert op"); + } + +#undef SF_STORE_CASE + }; + + switch (input.dtype()) { + SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(float32, float); + SF_UNARY_THUNK_CASE(float64, double); + SF_UNARY_THUNK_CASE(uint8, uint8_t); + SF_UNARY_THUNK_CASE(int8, int8_t); + SF_UNARY_THUNK_CASE(uint16, uint16_t); + SF_UNARY_THUNK_CASE(int16, int16_t); + SF_UNARY_THUNK_CASE(uint32, uint32_t); + SF_UNARY_THUNK_CASE(int32, uint32_t); + SF_UNARY_THUNK_CASE(uint64, uint64_t); + SF_UNARY_THUNK_CASE(int64, int64_t); + default: + throw std::invalid_argument(fmt::format( + "Unsupported dtype({}) for converting nearest integer op", + dtype.name())); + } + } +}; + +// Converting round functor. +struct ConvertRoundFunctor { + static void Invoke(device_array &input, DType dtype, device_array &out) { + SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::round"); + auto compute = [&]() -> void { + auto input_t = input.map_xtensor(); + auto rounded = xt::round(*input_t); + if (input.dtype() == dtype) { + // Same type output. + auto out_t = out.map_xtensor_w(); + *out_t = rounded; + } else { + // Casted output. +#define SF_STORE_CASE(dtype_name, cpp_type) \ + case DType::dtype_name(): { \ + auto out_t = out.map_xtensor_w(); \ + *out_t = xt::cast(rounded); \ + break; \ + } + switch (dtype) { + SF_STORE_CASE(uint8, uint8_t); + SF_STORE_CASE(int8, int8_t); + SF_STORE_CASE(uint16, uint16_t); + SF_STORE_CASE(int16, int16_t); + SF_STORE_CASE(uint32, uint32_t); + SF_STORE_CASE(int32, int32_t); + default: + throw std::invalid_argument( + "Invalid output dtype for converting nearest integer op"); + } + } +#undef SF_STORE_CASE + }; + + switch (input.dtype()) { + SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(float32, float); + default: + throw std::invalid_argument(fmt::format( + "Unsupported dtype({}) for converting nearest integer op", + dtype.name())); + } + } +}; + +struct ConvertCeilFunctor { + static void Invoke(device_array &input, DType dtype, device_array &out) { + SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::ceil"); + auto compute = [&]() -> void { + auto input_t = input.map_xtensor(); + auto rounded = xt::ceil(*input_t); + if (input.dtype() == dtype) { + // Same type output. + auto out_t = out.map_xtensor_w(); + *out_t = rounded; + } else { + // Casted output. +#define SF_STORE_CASE(dtype_name, cpp_type) \ + case DType::dtype_name(): { \ + auto out_t = out.map_xtensor_w(); \ + *out_t = xt::cast(rounded); \ + break; \ + } + switch (dtype) { + SF_STORE_CASE(uint8, uint8_t); + SF_STORE_CASE(int8, int8_t); + SF_STORE_CASE(uint16, uint16_t); + SF_STORE_CASE(int16, int16_t); + SF_STORE_CASE(uint32, uint32_t); + SF_STORE_CASE(int32, int32_t); + default: + throw std::invalid_argument( + "Invalid output dtype for converting nearest integer op"); + } + } +#undef SF_STORE_CASE + }; + + switch (input.dtype()) { + SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(float32, float); + default: + throw std::invalid_argument(fmt::format( + "Unsupported dtype({}) for converting nearest integer op", + dtype.name())); + } + } +}; + +struct ConvertFloorFunctor { + static void Invoke(device_array &input, DType dtype, device_array &out) { + SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::floor"); + auto compute = [&]() -> void { + auto input_t = input.map_xtensor(); + auto rounded = xt::floor(*input_t); + if (input.dtype() == dtype) { + // Same type output. + auto out_t = out.map_xtensor_w(); + *out_t = rounded; + } else { + // Casted output. +#define SF_STORE_CASE(dtype_name, cpp_type) \ + case DType::dtype_name(): { \ + auto out_t = out.map_xtensor_w(); \ + *out_t = xt::cast(rounded); \ + break; \ + } + switch (dtype) { + SF_STORE_CASE(uint8, uint8_t); + SF_STORE_CASE(int8, int8_t); + SF_STORE_CASE(uint16, uint16_t); + SF_STORE_CASE(int16, int16_t); + SF_STORE_CASE(uint32, uint32_t); + SF_STORE_CASE(int32, int32_t); + default: + throw std::invalid_argument( + "Invalid output dtype for converting nearest integer op"); + } + } +#undef SF_STORE_CASE + }; + + switch (input.dtype()) { + SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(float32, float); + default: + throw std::invalid_argument(fmt::format( + "Unsupported dtype({}) for converting nearest integer op", + dtype.name())); + } + } +}; + +struct ConvertTruncFunctor { + static void Invoke(device_array &input, DType dtype, device_array &out) { + SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::trunc"); + auto compute = [&]() -> void { + auto input_t = input.map_xtensor(); + auto rounded = xt::trunc(*input_t); + if (input.dtype() == dtype) { + // Same type output. + auto out_t = out.map_xtensor_w(); + *out_t = rounded; + } else { + // Casted output. +#define SF_STORE_CASE(dtype_name, cpp_type) \ + case DType::dtype_name(): { \ + auto out_t = out.map_xtensor_w(); \ + *out_t = xt::cast(rounded); \ + break; \ + } + switch (dtype) { + SF_STORE_CASE(uint8, uint8_t); + SF_STORE_CASE(int8, int8_t); + SF_STORE_CASE(uint16, uint16_t); + SF_STORE_CASE(int16, int16_t); + SF_STORE_CASE(uint32, uint32_t); + SF_STORE_CASE(int32, int32_t); + default: + throw std::invalid_argument( + "Invalid output dtype for converting nearest integer op"); + } + } +#undef SF_STORE_CASE + }; + + switch (input.dtype()) { + SF_UNARY_THUNK_CASE(float16, half_float::half); + SF_UNARY_THUNK_CASE(float32, float); + default: + throw std::invalid_argument(fmt::format( + "Unsupported dtype({}) for converting nearest integer op", + dtype.name())); + } + } +}; + +} // namespace void BindArrayHostOps(py::module_ &m) { // Simple op definitions. @@ -121,8 +408,8 @@ void BindArrayHostOps(py::module_ &m) { }; switch (input.dtype()) { - SF_UNARY_COMPUTE_CASE(float16, half_float::half); - SF_UNARY_COMPUTE_CASE(float32, float); + SF_UNARY_FUNCTION_CASE(float16, half_float::half); + SF_UNARY_FUNCTION_CASE(float32, float); default: throw std::invalid_argument( fmt::format("Unsupported dtype({}) for operator argmax", @@ -150,8 +437,8 @@ void BindArrayHostOps(py::module_ &m) { }; switch (out.dtype()) { - SF_UNARY_COMPUTE_CASE(float16, half_float::half); - SF_UNARY_COMPUTE_CASE(float32, float); + SF_UNARY_FUNCTION_CASE(float16, half_float::half); + SF_UNARY_FUNCTION_CASE(float32, float); default: throw std::invalid_argument( fmt::format("Unsupported dtype({}) for operator randn", @@ -159,6 +446,17 @@ void BindArrayHostOps(py::module_ &m) { } }, py::arg("out"), py::arg("generator") = py::none(), DOCSTRING_FILL_RANDN); + +// Data-type conversion and rounding. +#define SF_DEF_CONVERT(py_name, target) \ + m.def(py_name, target, py::arg("input"), py::kw_only(), \ + py::arg("dtype") = py::none(), py::arg("out") = py::none(), \ + py::arg("device_visible") = false, DOCSTRING_CONVERT) + SF_DEF_CONVERT("convert", GenericElementwiseConvert); + SF_DEF_CONVERT("ceil", GenericElementwiseConvert); + SF_DEF_CONVERT("floor", GenericElementwiseConvert); + SF_DEF_CONVERT("round", GenericElementwiseConvert); + SF_DEF_CONVERT("trunc", GenericElementwiseConvert); } } // namespace shortfin::python diff --git a/shortfin/python/shortfin/array/__init__.py b/shortfin/python/shortfin/array/__init__.py index 3a4d28877..6079541c8 100644 --- a/shortfin/python/shortfin/array/__init__.py +++ b/shortfin/python/shortfin/array/__init__.py @@ -44,7 +44,12 @@ # Ops. argmax = _sfl.array.argmax +ceil = _sfl.array.ceil +convert = _sfl.array.convert fill_randn = _sfl.array.fill_randn +floor = _sfl.array.floor +round = _sfl.array.round +trunc = _sfl.array.trunc RandomGenerator = _sfl.array.RandomGenerator __all__ = [ @@ -82,7 +87,12 @@ "DType", # Ops. "argmax", + "ceil", + "convert", "fill_randn", + "floor", + "round", + "trunc", "RandomGenerator", ] diff --git a/shortfin/tests/api/array_ops_test.py b/shortfin/tests/api/array_ops_test.py index 69d21e929..7c792d92b 100644 --- a/shortfin/tests/api/array_ops_test.py +++ b/shortfin/tests/api/array_ops_test.py @@ -167,3 +167,104 @@ def test_fill_randn_explicit_generator(device, dtype): assert contents1 == contents2 # And not be zero. assert contents1 != bytes(mz) + + +@pytest.mark.parametrize( + "dtype", + [ + sfnp.uint8, + sfnp.uint16, + sfnp.uint32, + sfnp.uint64, + sfnp.int8, + sfnp.int16, + sfnp.int32, + sfnp.int64, + sfnp.float16, + sfnp.float32, + sfnp.float64, + ], +) +def test_convert(device, dtype): + input_array = sfnp.device_array(device, [2, 3], dtype=sfnp.int32) + with input_array.map(write=True) as m: + m.fill(16) + intermediate = sfnp.convert(input_array, dtype=dtype) + with input_array.map(write=True) as m: + m.fill(0) + sfnp.convert(intermediate, out=input_array) + assert list(input_array.items) == 6 * [16] + + +def round_half_up(n): + return math.floor(n + 0.5) + + +def round_half_away_from_zero(n): + rounded_abs = round_half_up(abs(n)) + return math.copysign(rounded_abs, n) + + +@pytest.mark.parametrize( + "dtype,sfnp_func,ref_round_func", + [ + (sfnp.float16, sfnp.round, round_half_away_from_zero), + (sfnp.float32, sfnp.round, round_half_away_from_zero), + (sfnp.float16, sfnp.ceil, math.ceil), + (sfnp.float32, sfnp.ceil, math.ceil), + (sfnp.float16, sfnp.floor, math.floor), + (sfnp.float32, sfnp.floor, math.floor), + (sfnp.float16, sfnp.trunc, math.trunc), + (sfnp.float32, sfnp.trunc, math.trunc), + ], +) +def test_nearest_int_no_conversion(device, dtype, sfnp_func, ref_round_func): + input = sfnp.device_array(device, [2, 3], dtype=dtype) + sfnp.fill_randn(input) + ref_rounded = [ + ref_round_func(n) for n in sfnp.convert(input, dtype=sfnp.float32).items + ] + output = sfnp_func(input) + assert output.dtype == dtype + output_items = sfnp.convert(output, dtype=sfnp.float32).items + print(output_items) + for ref, actual in zip(ref_rounded, output_items): + assert ref == pytest.approx(actual) + + +@pytest.mark.parametrize( + "dtype,out_dtype,sfnp_func,ref_round_func", + [ + # Round + (sfnp.float16, sfnp.int8, sfnp.round, round_half_away_from_zero), + (sfnp.float32, sfnp.int8, sfnp.round, round_half_away_from_zero), + (sfnp.float32, sfnp.int16, sfnp.round, round_half_away_from_zero), + (sfnp.float32, sfnp.int32, sfnp.round, round_half_away_from_zero), + # Note that we do not test unsigned conversion with random data. + # Ceil + (sfnp.float16, sfnp.int8, sfnp.ceil, math.ceil), + (sfnp.float32, sfnp.int8, sfnp.ceil, math.ceil), + (sfnp.float32, sfnp.int16, sfnp.ceil, math.ceil), + (sfnp.float32, sfnp.int32, sfnp.ceil, math.ceil), + # Floor + (sfnp.float16, sfnp.int8, sfnp.floor, math.floor), + (sfnp.float32, sfnp.int8, sfnp.floor, math.floor), + (sfnp.float32, sfnp.int16, sfnp.floor, math.floor), + (sfnp.float32, sfnp.int32, sfnp.floor, math.floor), + # Trunc + (sfnp.float16, sfnp.int8, sfnp.trunc, math.trunc), + (sfnp.float32, sfnp.int8, sfnp.trunc, math.trunc), + (sfnp.float32, sfnp.int16, sfnp.trunc, math.trunc), + (sfnp.float32, sfnp.int32, sfnp.trunc, math.trunc), + ], +) +def test_nearest_int_conversion(device, dtype, out_dtype, sfnp_func, ref_round_func): + input = sfnp.device_array(device, [2, 3], dtype=dtype) + sfnp.fill_randn(input) + ref_rounded = [ + int(ref_round_func(n)) for n in sfnp.convert(input, dtype=sfnp.float32).items + ] + output = sfnp_func(input, dtype=out_dtype) + assert output.dtype == out_dtype + for ref, actual in zip(ref_rounded, output.items): + assert ref == int(actual) From c197e169d8ac62eb0725e127b7dbadbc6792d330 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Tue, 12 Nov 2024 16:34:12 +0100 Subject: [PATCH 21/59] Add copyright headers (#484) --- build_tools/integration_tests/llm/conftest.py | 6 ++++++ shortfin/python/shortfin_apps/llm/client.py | 6 ++++++ shortfin/python/shortfin_apps/sd/components/builders.py | 6 ++++++ .../python/shortfin_apps/sd/components/config_struct.py | 6 ++++++ shortfin/python/shortfin_apps/sd/components/metrics.py | 6 ++++++ shortfin/python/shortfin_apps/sd/examples/send_request.py | 6 ++++++ shortfin/tests/apps/sd/e2e_test.py | 6 ++++++ 7 files changed, 42 insertions(+) diff --git a/build_tools/integration_tests/llm/conftest.py b/build_tools/integration_tests/llm/conftest.py index 9b93a5d96..1103065bc 100644 --- a/build_tools/integration_tests/llm/conftest.py +++ b/build_tools/integration_tests/llm/conftest.py @@ -1,3 +1,9 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + import json import logging import os diff --git a/shortfin/python/shortfin_apps/llm/client.py b/shortfin/python/shortfin_apps/llm/client.py index 63cff7bee..f4e104a9f 100644 --- a/shortfin/python/shortfin_apps/llm/client.py +++ b/shortfin/python/shortfin_apps/llm/client.py @@ -1,3 +1,9 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + import requests import json import uuid diff --git a/shortfin/python/shortfin_apps/sd/components/builders.py b/shortfin/python/shortfin_apps/sd/components/builders.py index 81203e713..ed948bee9 100644 --- a/shortfin/python/shortfin_apps/sd/components/builders.py +++ b/shortfin/python/shortfin_apps/sd/components/builders.py @@ -1,3 +1,9 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + from iree.build import * from iree.build.executor import FileNamespace import itertools diff --git a/shortfin/python/shortfin_apps/sd/components/config_struct.py b/shortfin/python/shortfin_apps/sd/components/config_struct.py index 3dda6edfc..478d03ad8 100644 --- a/shortfin/python/shortfin_apps/sd/components/config_struct.py +++ b/shortfin/python/shortfin_apps/sd/components/config_struct.py @@ -1,3 +1,9 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """Configuration objects. Parameters that are intrinsic to a specific model. diff --git a/shortfin/python/shortfin_apps/sd/components/metrics.py b/shortfin/python/shortfin_apps/sd/components/metrics.py index 6d3c1aa8b..f8fd30876 100644 --- a/shortfin/python/shortfin_apps/sd/components/metrics.py +++ b/shortfin/python/shortfin_apps/sd/components/metrics.py @@ -1,3 +1,9 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + import logging import time import asyncio diff --git a/shortfin/python/shortfin_apps/sd/examples/send_request.py b/shortfin/python/shortfin_apps/sd/examples/send_request.py index dd2226e70..9fce890d6 100644 --- a/shortfin/python/shortfin_apps/sd/examples/send_request.py +++ b/shortfin/python/shortfin_apps/sd/examples/send_request.py @@ -1,3 +1,9 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + import json import requests import argparse diff --git a/shortfin/tests/apps/sd/e2e_test.py b/shortfin/tests/apps/sd/e2e_test.py index cab8ecab2..26c2e30f6 100644 --- a/shortfin/tests/apps/sd/e2e_test.py +++ b/shortfin/tests/apps/sd/e2e_test.py @@ -1,3 +1,9 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + import json import requests import time From 903d3c1ac63c4e9b8a013a29eeca9a8ed7047e9f Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Tue, 12 Nov 2024 19:45:28 +0100 Subject: [PATCH 22/59] Refactor release related files and scripts (#483) * Renames and moves `gen_version_info_rc.py`. * Harmonizes the naming of version files: In IREE a `version.json` file is used to compute a `version_local.json`. With this patch, the same is applied here to reduce maintenance efforts. * Refactors `write_requirements.py`: Refactors the script to generate the requirements file in memory and write only once. --- .github/workflows/build_packages.yml | 16 +++--- .gitignore | 3 ++ .../python_deploy/compute_common_version.py | 4 +- .../compute_local_version.py} | 23 +++++---- .../python_deploy/write_requirements.py | 51 ++++++++++--------- shark-ai/.gitignore | 1 - sharktank/setup.py | 10 ++-- sharktank/{version_info.json => version.json} | 0 shortfin/CMakeLists.txt | 2 +- shortfin/setup.py | 10 ++-- shortfin/{version_info.json => version.json} | 0 11 files changed, 65 insertions(+), 55 deletions(-) rename build_tools/{gen_version_info_rc.py => python_deploy/compute_local_version.py} (59%) rename sharktank/{version_info.json => version.json} (100%) rename shortfin/{version_info.json => version.json} (100%) diff --git a/.github/workflows/build_packages.yml b/.github/workflows/build_packages.yml index 4a332b6f5..8f138d973 100644 --- a/.github/workflows/build_packages.yml +++ b/.github/workflows/build_packages.yml @@ -37,15 +37,15 @@ jobs: - name: Generate release candidate versions id: version_rc run: | - sharktank_package_version=$(python3 build_tools/gen_version_info_rc.py sharktank) - shortfin_package_version=$(python3 build_tools/gen_version_info_rc.py shortfin) - - name: Upload version_info_rc.json + sharktank_package_version=$(python3 build_tools/python_deploy/compute_local_version.py sharktank) + shortfin_package_version=$(python3 build_tools/python_deploy/compute_local_version.py shortfin) + - name: Upload version_local.json uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 with: - name: version_info_rc + name: version_local path: | - sharktank/version_info_rc.json - shortfin/version_info_rc.json + sharktank/version_local.json + shortfin/version_local.json build_packages: name: "${{ matrix.package }} :: ${{ matrix.platform }} :: ${{ matrix.python-version }}" @@ -91,10 +91,10 @@ jobs: path: "c" # Windows can hit path length limits, so use a short path. submodules: false - - name: Download version_info_rc.json + - name: Download version_local.json uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 with: - name: version_info_rc + name: version_local path: ./c/ merge-multiple: true diff --git a/.gitignore b/.gitignore index 6474e6a8c..bdb0b5387 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,9 @@ wheelhouse *.whl *.venv +# Local-only config options +version_local.json + #Model artifacts *.pt *.safetensors diff --git a/build_tools/python_deploy/compute_common_version.py b/build_tools/python_deploy/compute_common_version.py index ba5e653fb..6aea7f254 100644 --- a/build_tools/python_deploy/compute_common_version.py +++ b/build_tools/python_deploy/compute_common_version.py @@ -36,8 +36,8 @@ THIS_DIR = Path(__file__).parent.resolve() REPO_ROOT = THIS_DIR.parent.parent -VERSION_FILE_SHARKTANK = REPO_ROOT / "sharktank/version_info.json" -VERSION_FILE_SHORTFIN = REPO_ROOT / "shortfin/version_info.json" +VERSION_FILE_SHARKTANK = REPO_ROOT / "sharktank/version.json" +VERSION_FILE_SHORTFIN = REPO_ROOT / "shortfin/version.json" VERSION_FILE_LOCAL = REPO_ROOT / "shark-ai/version_local.json" diff --git a/build_tools/gen_version_info_rc.py b/build_tools/python_deploy/compute_local_version.py similarity index 59% rename from build_tools/gen_version_info_rc.py rename to build_tools/python_deploy/compute_local_version.py index 9399053b0..46d18d0ed 100644 --- a/build_tools/gen_version_info_rc.py +++ b/build_tools/python_deploy/compute_local_version.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 # Copyright 2024 Advanced Micro Devices, Inc. # # Licensed under the Apache License v2.0 with LLVM Exceptions. @@ -5,8 +6,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # This scripts grabs the X.Y.Z[.dev]` version identifier from a -# `version_info.json` and writes the corresponding -# `X.Y.ZrcYYYYMMDD` version identifier to `version_rc_info.json`. +# `version.json` and writes the corresponding +# `X.Y.ZrcYYYYMMDD` version identifier to `version_local.json`. import argparse from pathlib import Path @@ -20,18 +21,18 @@ parser.add_argument("path", type=Path) args = parser.parse_args() -VERSION_INFO_FILE = args.path / "version_info.json" -VERSION_INFO_RC_FILE = args.path / "version_info_rc.json" +VERSION_FILE = args.path / "version.json" +VERSION_FILE_LOCAL = args.path / "version_local.json" def load_version_info(): - with open(VERSION_INFO_FILE, "rt") as f: + with open(VERSION_FILE, "rt") as f: return json.load(f) def write_version_info(): - with open(VERSION_INFO_RC_FILE, "w") as f: - json.dump(version_info_rc, f, indent=2) + with open(VERSION_FILE_LOCAL, "w") as f: + json.dump(version_local, f, indent=2) f.write("\n") @@ -39,10 +40,12 @@ def write_version_info(): PACKAGE_VERSION = version_info.get("package-version") PACKAGE_BASE_VERSION = Version(PACKAGE_VERSION).base_version -PACKAGE_RC_VERSION = PACKAGE_BASE_VERSION + "rc" + datetime.today().strftime("%Y%m%d") +PACKAGE_LOCAL_VERSION = ( + PACKAGE_BASE_VERSION + "rc" + datetime.today().strftime("%Y%m%d") +) -version_info_rc = {"package-version": PACKAGE_RC_VERSION} +version_local = {"package-version": PACKAGE_LOCAL_VERSION} write_version_info() -print(PACKAGE_RC_VERSION) +print(PACKAGE_LOCAL_VERSION) diff --git a/build_tools/python_deploy/write_requirements.py b/build_tools/python_deploy/write_requirements.py index 6ad7c10f5..a89b74dfe 100644 --- a/build_tools/python_deploy/write_requirements.py +++ b/build_tools/python_deploy/write_requirements.py @@ -33,8 +33,8 @@ THIS_DIR = Path(__file__).parent.resolve() REPO_ROOT = THIS_DIR.parent.parent -VERSION_FILE_SHARKTANK = REPO_ROOT / "sharktank/version_info.json" -VERSION_FILE_SHORTFIN = REPO_ROOT / "shortfin/version_info.json" +VERSION_FILE_SHARKTANK = REPO_ROOT / "sharktank/version_local.json" +VERSION_FILE_SHORTFIN = REPO_ROOT / "shortfin/version_local.json" VERSION_FILE_LOCAL = REPO_ROOT / "shark-ai/version_local.json" REQUIREMENTS_TXT = REPO_ROOT / "shark-ai/requirements.txt" @@ -44,18 +44,9 @@ def load_version_info(version_file): return json.load(f) -def write_requirements(package_list, package_version): +def write_requirements(requirements): with open(REQUIREMENTS_TXT, "w") as f: - for package in package_list: - PINNED_PACKAGE = package + "==" + package_version - f.write("%s\n" % PINNED_PACKAGE) - - -def append_requirements(package_list, package_version): - with open(REQUIREMENTS_TXT, "a") as f: - for package in package_list: - PINNED_PACKAGE = package + "==" + package_version - f.write("%s\n" % PINNED_PACKAGE) + f.write("%s\n" % requirements) metapackage_version = load_version_info(VERSION_FILE_LOCAL) @@ -70,20 +61,34 @@ def append_requirements(package_list, package_version): stable_packages_list = ["iree-base-compiler", "iree-base-runtime", "iree-turbine"] if Version(PACKAGE_VERSION).is_prerelease: - write_requirements( - ["sharktank"], - Version(SHARKTANK_PACKAGE_VERSION).base_version + "rc" + args.version_suffix, + requirements = ( + "sharktank==" + + Version(SHARKTANK_PACKAGE_VERSION).base_version + + "rc" + + args.version_suffix + + "\n" ) - append_requirements( - ["shortfin"], - Version(SHORTFIN_PACKAGE_VERSION).base_version + "rc" + args.version_suffix, + requirements += ( + "shortfin==" + + Version(SHORTFIN_PACKAGE_VERSION).base_version + + "rc" + + args.version_suffix ) + + write_requirements(requirements) + else: MAJOR_VERSION = Version(PACKAGE_VERSION).major MINOR_VERSION = Version(PACKAGE_VERSION).minor - write_requirements( - stable_packages_list, str(MAJOR_VERSION) + "." + str(MINOR_VERSION) + ".*" + STABLE_VERSION_TO_PIN = str(MAJOR_VERSION) + "." + str(MINOR_VERSION) + ".*" + + requirements = "" + for package in stable_packages_list: + requirements += package + "==" + STABLE_VERSION_TO_PIN + "\n" + requirements += ( + "sharktank==" + Version(SHARKTANK_PACKAGE_VERSION).base_version + "\n" ) - append_requirements(["sharktank"], Version(SHARKTANK_PACKAGE_VERSION).base_version) - append_requirements(["shortfin"], Version(SHORTFIN_PACKAGE_VERSION).base_version) + requirements += "shortfin==" + Version(SHORTFIN_PACKAGE_VERSION).base_version + + write_requirements(requirements) diff --git a/shark-ai/.gitignore b/shark-ai/.gitignore index 8e68ab1b5..80bf001b8 100644 --- a/shark-ai/.gitignore +++ b/shark-ai/.gitignore @@ -1,3 +1,2 @@ # Local-only config options -version_local.json requirements.txt diff --git a/sharktank/setup.py b/sharktank/setup.py index aca5c63d0..182f94abc 100644 --- a/sharktank/setup.py +++ b/sharktank/setup.py @@ -13,8 +13,8 @@ SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__)) # Setup and get version information. -VERSION_INFO_FILE = os.path.join(SETUPPY_DIR, "version_info.json") -VERSION_INFO_RC_FILE = os.path.join(SETUPPY_DIR, "version_info_rc.json") +VERSION_FILE = os.path.join(SETUPPY_DIR, "version.json") +VERSION_FILE_LOCAL = os.path.join(SETUPPY_DIR, "version_local.json") def load_version_info(version_file): @@ -23,10 +23,10 @@ def load_version_info(version_file): try: - version_info = load_version_info(VERSION_INFO_RC_FILE) + version_info = load_version_info(VERSION_FILE_LOCAL) except FileNotFoundError: - print("version_info_rc.json not found. Default to dev build") - version_info = load_version_info(VERSION_INFO_FILE) + print("version_local.json not found. Default to dev build") + version_info = load_version_info(VERSION_FILE) PACKAGE_VERSION = version_info.get("package-version") print(f"Using PACKAGE_VERSION: '{PACKAGE_VERSION}'") diff --git a/sharktank/version_info.json b/sharktank/version.json similarity index 100% rename from sharktank/version_info.json rename to sharktank/version.json diff --git a/shortfin/CMakeLists.txt b/shortfin/CMakeLists.txt index 85113ce00..11982202d 100644 --- a/shortfin/CMakeLists.txt +++ b/shortfin/CMakeLists.txt @@ -14,7 +14,7 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_BINARY_DIR) endif() # Get version number from file -file(READ ${CMAKE_CURRENT_SOURCE_DIR}/version_info.json VERSION_JSON_STRING) +file(READ ${CMAKE_CURRENT_SOURCE_DIR}/version.json VERSION_JSON_STRING) string(JSON PACKAGE_VERSION GET ${VERSION_JSON_STRING} package-version) string(REGEX MATCH "(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*" BASE_VERSION ${PACKAGE_VERSION}) diff --git a/shortfin/setup.py b/shortfin/setup.py index 94aae4a55..2830f9f35 100644 --- a/shortfin/setup.py +++ b/shortfin/setup.py @@ -141,8 +141,8 @@ def copy_extensions_to_source(self, *args, **kwargs): # Setup and get version information. -VERSION_INFO_FILE = os.path.join(REL_SOURCE_DIR, "version_info.json") -VERSION_INFO_RC_FILE = os.path.join(REL_SOURCE_DIR, "version_info_rc.json") +VERSION_FILE = os.path.join(REL_SOURCE_DIR, "version.json") +VERSION_FILE_LOCAL = os.path.join(REL_SOURCE_DIR, "version_local.json") def load_version_info(version_file): @@ -151,10 +151,10 @@ def load_version_info(version_file): try: - version_info = load_version_info(VERSION_INFO_RC_FILE) + version_info = load_version_info(VERSION_FILE_LOCAL) except FileNotFoundError: - print("version_info_rc.json not found. Default to dev build") - version_info = load_version_info(VERSION_INFO_FILE) + print("version_local.json not found. Default to dev build") + version_info = load_version_info(VERSION_FILE) PACKAGE_VERSION = version_info.get("package-version") print(f"Using PACKAGE_VERSION: '{PACKAGE_VERSION}'") diff --git a/shortfin/version_info.json b/shortfin/version.json similarity index 100% rename from shortfin/version_info.json rename to shortfin/version.json From ce6ccf88e8f106a6a88b1946076139c011a8f5e8 Mon Sep 17 00:00:00 2001 From: Stephen Baione <109226581+stbaione@users.noreply.github.com> Date: Wed, 13 Nov 2024 10:15:31 -0600 Subject: [PATCH 23/59] Shortfin LLM Deviceid Support (#493) # Description Add the ability to specify device_ids that you want Shortfin LLM Server to run with. The setup is essentially 1-1 with how SD server sets device_ids support up. Created a new `shortfin/interop/support/device_setup.py` module and moved the `get_selected_devices` function there to be shared across `managers`. ## Example ```bash python -m shortfin_apps.llm.server --tokenizer_json=/data/llama3.1/8b/tokenizer.json --model_config=./export/edited_config.json --vmfb=./export/model.vmfb --parameters=/data/llama3.1/8b/llama8b_f16.irpa --device=amdgpu --device_ids=0 ``` --- .../shortfin/interop/support/device_setup.py | 26 +++++++++++++++++++ .../shortfin_apps/llm/components/manager.py | 15 ++++++++--- shortfin/python/shortfin_apps/llm/server.py | 21 +++++++++++++-- .../shortfin_apps/sd/components/manager.py | 26 +------------------ 4 files changed, 57 insertions(+), 31 deletions(-) create mode 100644 shortfin/python/shortfin/interop/support/device_setup.py diff --git a/shortfin/python/shortfin/interop/support/device_setup.py b/shortfin/python/shortfin/interop/support/device_setup.py new file mode 100644 index 000000000..afe6ca695 --- /dev/null +++ b/shortfin/python/shortfin/interop/support/device_setup.py @@ -0,0 +1,26 @@ +import shortfin as sf + + +def get_selected_devices(sb: sf.SystemBuilder, device_ids=None): + available = sb.available_devices + selected = [] + if device_ids is not None: + if len(device_ids) > len(available): + raise ValueError( + f"Requested more device ids ({device_ids}) than available ({available})." + ) + for did in device_ids: + if isinstance(did, str): + try: + did = int(did) + except ValueError: + did = did + if did in available: + selected.append(did) + elif isinstance(did, int): + selected.append(available[did]) + else: + raise ValueError(f"Device id {did} could not be parsed.") + else: + selected = available + return selected diff --git a/shortfin/python/shortfin_apps/llm/components/manager.py b/shortfin/python/shortfin_apps/llm/components/manager.py index e3057de22..b44116b39 100644 --- a/shortfin/python/shortfin_apps/llm/components/manager.py +++ b/shortfin/python/shortfin_apps/llm/components/manager.py @@ -8,16 +8,23 @@ import threading import shortfin as sf +from shortfin.interop.support.device_setup import get_selected_devices logger = logging.getLogger(__name__) class SystemManager: - def __init__(self, device="local-task"): - if device == "local-task": + def __init__(self, device="local-task", device_ids=None, async_allocs=True): + if any(x in device for x in ["local-task", "cpu"]): self.ls = sf.host.CPUSystemBuilder().create_system() - elif device == "hip": - self.ls = sf.amdgpu.SystemBuilder().create_system() + elif any(x in device for x in ["hip", "amdgpu"]): + sb = sf.SystemBuilder( + system_type="amdgpu", amdgpu_async_allocations=async_allocs + ) + if device_ids: + sb.visible_devices = sb.available_devices + sb.visible_devices = get_selected_devices(sb, device_ids) + self.ls = sb.create_system() logger.info(f"Created local system with {self.ls.device_names} devices") # TODO: Come up with an easier bootstrap thing than manually # running a thread. diff --git a/shortfin/python/shortfin_apps/llm/server.py b/shortfin/python/shortfin_apps/llm/server.py index 5b51a9a7f..2ab7a1b96 100644 --- a/shortfin/python/shortfin_apps/llm/server.py +++ b/shortfin/python/shortfin_apps/llm/server.py @@ -86,7 +86,11 @@ def get_eos_from_tokenizer_config(json_path): def configure(args) -> SystemManager: # Setup system (configure devices, etc). - sysman = SystemManager(device=args.device) + sysman = SystemManager( + device=args.device, + device_ids=args.device_ids, + async_allocs=args.amdgpu_async_allocations, + ) # Setup each service we are hosting. eos_token = get_eos_from_tokenizer_config(args.tokenizer_config_json) @@ -155,9 +159,17 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): parser.add_argument( "--device", type=str, - default="local-task", + required=True, + choices=["local-task", "hip", "amdgpu"], help="Device to serve on; e.g. local-task, hip. Same options as `iree-run-module --device` ", ) + parser.add_argument( + "--device_ids", + type=str, + nargs="*", + default=None, + help="Device IDs visible to the system builder. Defaults to None (full visibility). Can be an index or a sf device id like amdgpu:0:0@0", + ) parser.add_argument( "--isolation", type=str, @@ -165,6 +177,11 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): choices=[isolation.name.lower() for isolation in ProgramIsolation], help="Concurrency control -- How to isolate programs.", ) + parser.add_argument( + "--amdgpu_async_allocations", + action="store_true", + help="Enable asynchronous allocations for amdgpu device contexts.", + ) args = parser.parse_args(argv) if args.tokenizer_config_json is None: diff --git a/shortfin/python/shortfin_apps/sd/components/manager.py b/shortfin/python/shortfin_apps/sd/components/manager.py index 846c4ced6..b44116b39 100644 --- a/shortfin/python/shortfin_apps/sd/components/manager.py +++ b/shortfin/python/shortfin_apps/sd/components/manager.py @@ -8,35 +8,11 @@ import threading import shortfin as sf +from shortfin.interop.support.device_setup import get_selected_devices logger = logging.getLogger(__name__) -def get_selected_devices(sb: sf.SystemBuilder, device_ids=None): - available = sb.available_devices - selected = [] - if device_ids is not None: - if len(device_ids) >= len(available): - raise ValueError( - f"Requested more device ids ({device_ids}) than available ({available})." - ) - for did in device_ids: - if isinstance(did, str): - try: - did = int(did) - except ValueError: - did = did - if did in available: - selected.append(did) - elif isinstance(did, int): - selected.append(available[did]) - else: - raise ValueError(f"Device id {did} could not be parsed.") - else: - selected = available - return selected - - class SystemManager: def __init__(self, device="local-task", device_ids=None, async_allocs=True): if any(x in device for x in ["local-task", "cpu"]): From 23bed174837e9ebe60e9b9f4ced3b9ccf9932c9c Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Wed, 13 Nov 2024 17:31:17 +0100 Subject: [PATCH 24/59] Adapt to package rename (#494) The packages were recently renamed from `iree-compiler` and `iree-runtime` to `iree-base-compiler` and `iree-base-compiler`, respectively. --- .github/workflows/ci-sharktank.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index 6f359077a..4c660e6ee 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -63,7 +63,7 @@ jobs: # Update to the latest iree packages. pip install -f https://iree.dev/pip-release-links.html --upgrade \ - iree-compiler iree-runtime --src deps \ + iree-base-compiler iree-base-runtime --src deps \ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" - name: Run sharktank tests From 7bd325388e72a187f7ecb07e95b15da0cf4fb384 Mon Sep 17 00:00:00 2001 From: Andrew Woloszyn Date: Wed, 13 Nov 2024 12:49:12 -0500 Subject: [PATCH 25/59] [shortfin] Fix the f-string for python 3.11 (#499) --- shortfin/python/shortfin_apps/sd/components/service.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index a64013db0..1ee11569a 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -183,6 +183,8 @@ def __repr__(self): params = [ f" {key} : {value}" for key, value in self.inference_parameters.items() ] + # For python 3.11 since we can't have \ in the f"" expression. + new_line = "\n" return ( f"ServiceManager(" f"\n INFERENCE DEVICES : \n" @@ -193,9 +195,9 @@ def __repr__(self): f" fibers per device : {self.fibers_per_device}\n" f" program isolation mode : {self.prog_isolation}\n" f"\n INFERENCE MODULES : \n" - f"{'\n'.join(modules)}\n" + f"{new_line.join(modules)}\n" f"\n INFERENCE PARAMETERS : \n" - f"{'\n'.join(params)}\n" + f"{new_line.join(params)}\n" f")" ) From 9c5c8ccae4219ae22a10ba93efbdedefce28b9bd Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Wed, 13 Nov 2024 19:38:03 +0100 Subject: [PATCH 26/59] Add metadata to shortfin Python package (#492) Similar to #489, this adds mertadata to the shortfin Python package but does so by extending the `pyproject.toml` file instead of adding metadata to the `setup.py` file. --- shortfin/README.md | 2 +- shortfin/pyproject.toml | 26 ++++++++++++++++++++++++++ shortfin/setup.py | 3 --- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/shortfin/README.md b/shortfin/README.md index 3e7901342..9818e05d3 100644 --- a/shortfin/README.md +++ b/shortfin/README.md @@ -1,4 +1,4 @@ -# shortfin - SHARK C++ inference library +# shortfin - SHARK inference library and serving engine ## Simple User Installation diff --git a/shortfin/pyproject.toml b/shortfin/pyproject.toml index 47cde6775..15bd68732 100644 --- a/shortfin/pyproject.toml +++ b/shortfin/pyproject.toml @@ -8,6 +8,32 @@ requires = [ ] build-backend = "setuptools.build_meta" +[project] +name = "shortfin" +authors = [ + {name = "SHARK Authors"}, +] +description = "SHARK inference library and serving engine" +readme = "README.md" +license = {text = "Apache-2.0"} +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +requires-python = ">= 3.10" + +# Version is set via the `setup.py`. +dynamic = ["version"] + +[project.urls] +Repository = "https://github.com/nod-ai/SHARK-Platform" +Documentation = "https://shortfin.readthedocs.io/en/latest/" + [tool.pytest.ini_options] addopts = [ "-ra", diff --git a/shortfin/setup.py b/shortfin/setup.py index 2830f9f35..cf3762950 100644 --- a/shortfin/setup.py +++ b/shortfin/setup.py @@ -359,10 +359,7 @@ def populate_built_package(abs_dir): print(f"Found shortfin packages: {packages}") setup( - name="shortfin", version=f"{PACKAGE_VERSION}", - description="Shortfin native library implementation", - author="SHARK Authors", packages=packages, zip_safe=False, package_dir=combine_dicts( From 2fa1f926d44859539f8da03a619f27c6f35d5995 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Wed, 13 Nov 2024 16:15:48 -0600 Subject: [PATCH 27/59] (shortfin-sd) Usability and logging improvements. (#491) - Port configuration fixes - Reduce need for topology footguns by implementing simple topology config artifacts for 4 setups (cpx:single/multi, spx:single/multi). These set server/device topologies to "known good" configurations. - Fix example client script (send_request.py -> simple_client.py) and include in python package, arg problems fixed as well (--save) - Remove need for source code - Safer failures for invalid output dims - Don't report server startup under uvicorn.error - Updates sd README with new example CLI inputs - Adds help for client CLI args --- .../python/shortfin/support/logging_setup.py | 12 +- shortfin/python/shortfin_apps/sd/README.md | 11 +- .../shortfin_apps/sd/components/builders.py | 10 +- .../sd/components/config_artifacts.py | 123 ++++++++++ .../shortfin_apps/sd/components/generate.py | 2 +- .../shortfin_apps/sd/components/io_struct.py | 7 + .../shortfin_apps/sd/components/manager.py | 4 +- .../shortfin_apps/sd/components/messages.py | 2 +- .../shortfin_apps/sd/components/metrics.py | 2 +- .../shortfin_apps/sd/components/service.py | 37 ++- .../sd/examples/sdxl_request_bs32.json | 2 + .../shortfin_apps/sd/examples/send_request.py | 90 ------- shortfin/python/shortfin_apps/sd/server.py | 149 ++++++++++-- .../python/shortfin_apps/sd/simple_client.py | 229 ++++++++++++++++++ 14 files changed, 536 insertions(+), 144 deletions(-) create mode 100644 shortfin/python/shortfin_apps/sd/components/config_artifacts.py delete mode 100644 shortfin/python/shortfin_apps/sd/examples/send_request.py create mode 100644 shortfin/python/shortfin_apps/sd/simple_client.py diff --git a/shortfin/python/shortfin/support/logging_setup.py b/shortfin/python/shortfin/support/logging_setup.py index 5585e6a82..3cb373f1e 100644 --- a/shortfin/python/shortfin/support/logging_setup.py +++ b/shortfin/python/shortfin/support/logging_setup.py @@ -38,19 +38,15 @@ def __init__(self): native_handler.setFormatter(NativeFormatter()) # TODO: Source from env vars. -logger.setLevel(logging.INFO) +logger.setLevel(logging.DEBUG) logger.addHandler(native_handler) def configure_main_logger(module_suffix: str = "__main__") -> logging.Logger: """Configures logging from a main entrypoint. - Returns a logger that can be used for the main module itself. """ + logging.root.addHandler(native_handler) + logging.root.setLevel(logging.WARNING) # TODO: source from env vars main_module = sys.modules["__main__"] - logging.root.setLevel(logging.INFO) - logger = logging.getLogger(f"{main_module.__package__}.{module_suffix}") - logger.setLevel(logging.INFO) - logger.addHandler(native_handler) - - return logger + return logging.getLogger(f"{main_module.__package__}.{module_suffix}") diff --git a/shortfin/python/shortfin_apps/sd/README.md b/shortfin/python/shortfin_apps/sd/README.md index 4808cad08..6dd701c62 100644 --- a/shortfin/python/shortfin_apps/sd/README.md +++ b/shortfin/python/shortfin_apps/sd/README.md @@ -38,13 +38,10 @@ cd shortfin/ The server will prepare runtime artifacts for you. ``` -python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --flagfile=./python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt --build_preference=compile +python -m shortfin_apps.sd.server --device=amdgpu --device_ids=0 --build_preference=precompiled --topology="spx_single" ``` - - Run with splat(empty) weights: -``` -python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --splat --flagfile=./python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt --build_preference=compile -``` - - Run a request in a separate shell: + + - Run a CLI client in a separate shell: ``` -python shortfin/python/shortfin_apps/sd/examples/send_request.py --file=shortfin/python/shortfin_apps/sd/examples/sdxl_request.json +python -m shortfin_apps.sd.simple_client --interactive --save ``` diff --git a/shortfin/python/shortfin_apps/sd/components/builders.py b/shortfin/python/shortfin_apps/sd/components/builders.py index ed948bee9..f23922dd6 100644 --- a/shortfin/python/shortfin_apps/sd/components/builders.py +++ b/shortfin/python/shortfin_apps/sd/components/builders.py @@ -24,7 +24,7 @@ sfnp.bfloat16: "bf16", } -ARTIFACT_VERSION = "11022024" +ARTIFACT_VERSION = "11132024" SDXL_BUCKET = ( f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/" ) @@ -51,7 +51,9 @@ def get_mlir_filenames(model_params: ModelParams, model=None): return filter_by_model(mlir_filenames, model) -def get_vmfb_filenames(model_params: ModelParams, model=None, target: str = "gfx942"): +def get_vmfb_filenames( + model_params: ModelParams, model=None, target: str = "amdgpu-gfx942" +): vmfb_filenames = [] file_stems = get_file_stems(model_params) for stem in file_stems: @@ -216,6 +218,8 @@ def sdxl( mlir_bucket = SDXL_BUCKET + "mlir/" vmfb_bucket = SDXL_BUCKET + "vmfbs/" + if "gfx" in target: + target = "amdgpu-" + target mlir_filenames = get_mlir_filenames(model_params, model) mlir_urls = get_url_map(mlir_filenames, mlir_bucket) @@ -247,7 +251,7 @@ def sdxl( params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET) for f, url in params_urls.items(): out_file = os.path.join(ctx.executor.output_dir, f) - if update or needs_file(f, ctx): + if needs_file(f, ctx): fetch_http(name=f, url=url) filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames] return filenames diff --git a/shortfin/python/shortfin_apps/sd/components/config_artifacts.py b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py new file mode 100644 index 000000000..b5a1d682b --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py @@ -0,0 +1,123 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from iree.build import * +from iree.build.executor import FileNamespace +import itertools +import os +import shortfin.array as sfnp +import copy + +from shortfin_apps.sd.components.config_struct import ModelParams + +this_dir = os.path.dirname(os.path.abspath(__file__)) +parent = os.path.dirname(this_dir) + +dtype_to_filetag = { + sfnp.float16: "fp16", + sfnp.float32: "fp32", + sfnp.int8: "i8", + sfnp.bfloat16: "bf16", +} + +ARTIFACT_VERSION = "11132024" +SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/configs/" + + +def get_url_map(filenames: list[str], bucket: str): + file_map = {} + for filename in filenames: + file_map[filename] = f"{bucket}{filename}" + return file_map + + +def needs_update(ctx): + stamp = ctx.allocate_file("version.txt") + stamp_path = stamp.get_fs_path() + if os.path.exists(stamp_path): + with open(stamp_path, "r") as s: + ver = s.read() + if ver != ARTIFACT_VERSION: + return True + else: + with open(stamp_path, "w") as s: + s.write(ARTIFACT_VERSION) + return True + return False + + +def needs_file(filename, ctx, namespace=FileNamespace.GEN): + out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path() + if os.path.exists(out_file): + needed = False + else: + # name_path = "bin" if namespace == FileNamespace.BIN else "" + # if name_path: + # filename = os.path.join(name_path, filename) + filekey = os.path.join(ctx.path, filename) + ctx.executor.all[filekey] = None + needed = True + return needed + + +@entrypoint(description="Retreives a set of SDXL configuration files.") +def sdxlconfig( + target=cl_arg( + "target", + default="gfx942", + help="IREE target architecture.", + ), + model=cl_arg("model", type=str, default="sdxl", help="Model architecture"), + topology=cl_arg( + "topology", + type=str, + default="spx_single", + help="System topology configfile keyword", + ), +): + ctx = executor.BuildContext.current() + update = needs_update(ctx) + + model_config_filenames = [f"{model}_config_i8.json"] + model_config_urls = get_url_map(model_config_filenames, SDXL_CONFIG_BUCKET) + for f, url in model_config_urls.items(): + out_file = os.path.join(ctx.executor.output_dir, f) + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + + topology_config_filenames = [f"topology_config_{topology}.txt"] + topology_config_urls = get_url_map(topology_config_filenames, SDXL_CONFIG_BUCKET) + for f, url in topology_config_urls.items(): + out_file = os.path.join(ctx.executor.output_dir, f) + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + + flagfile_filenames = [f"{model}_flagfile_{target}.txt"] + flagfile_urls = get_url_map(flagfile_filenames, SDXL_CONFIG_BUCKET) + for f, url in flagfile_urls.items(): + out_file = os.path.join(ctx.executor.output_dir, f) + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + + tuning_filenames = ( + [f"attention_and_matmul_spec_{target}.mlir"] if target == "gfx942" else [] + ) + tuning_urls = get_url_map(tuning_filenames, SDXL_CONFIG_BUCKET) + for f, url in tuning_urls.items(): + out_file = os.path.join(ctx.executor.output_dir, f) + if update or needs_file(f, ctx): + fetch_http(name=f, url=url) + filenames = [ + *model_config_filenames, + *topology_config_filenames, + *flagfile_filenames, + *tuning_filenames, + ] + return filenames + + +if __name__ == "__main__": + iree_build_main() diff --git a/shortfin/python/shortfin_apps/sd/components/generate.py b/shortfin/python/shortfin_apps/sd/components/generate.py index ebb5ea08a..1afa73d5e 100644 --- a/shortfin/python/shortfin_apps/sd/components/generate.py +++ b/shortfin/python/shortfin_apps/sd/components/generate.py @@ -20,7 +20,7 @@ from .service import GenerateService from .metrics import measure -logger = logging.getLogger(__name__) +logger = logging.getLogger("shortfin-sd.generate") class GenerateImageProcess(sf.Process): diff --git a/shortfin/python/shortfin_apps/sd/components/io_struct.py b/shortfin/python/shortfin_apps/sd/components/io_struct.py index d2952a818..d1d9cf41a 100644 --- a/shortfin/python/shortfin_apps/sd/components/io_struct.py +++ b/shortfin/python/shortfin_apps/sd/components/io_struct.py @@ -72,3 +72,10 @@ def post_init(self): raise ValueError("The rid should be a list.") if self.output_type is None: self.output_type = ["base64"] * self.num_output_images + # Temporary restrictions + heights = [self.height] if not isinstance(self.height, list) else self.height + widths = [self.width] if not isinstance(self.width, list) else self.width + if any(dim != 1024 for dim in [*heights, *widths]): + raise ValueError( + "Currently, only 1024x1024 output image size is supported." + ) diff --git a/shortfin/python/shortfin_apps/sd/components/manager.py b/shortfin/python/shortfin_apps/sd/components/manager.py index b44116b39..ea29b69a4 100644 --- a/shortfin/python/shortfin_apps/sd/components/manager.py +++ b/shortfin/python/shortfin_apps/sd/components/manager.py @@ -10,7 +10,7 @@ import shortfin as sf from shortfin.interop.support.device_setup import get_selected_devices -logger = logging.getLogger(__name__) +logger = logging.getLogger("shortfin-sd.manager") class SystemManager: @@ -25,7 +25,7 @@ def __init__(self, device="local-task", device_ids=None, async_allocs=True): sb.visible_devices = sb.available_devices sb.visible_devices = get_selected_devices(sb, device_ids) self.ls = sb.create_system() - logger.info(f"Created local system with {self.ls.device_names} devices") + logging.info(f"Created local system with {self.ls.device_names} devices") # TODO: Come up with an easier bootstrap thing than manually # running a thread. self.t = threading.Thread(target=lambda: self.ls.run(self.run())) diff --git a/shortfin/python/shortfin_apps/sd/components/messages.py b/shortfin/python/shortfin_apps/sd/components/messages.py index 88eb28ff4..6ae716bad 100644 --- a/shortfin/python/shortfin_apps/sd/components/messages.py +++ b/shortfin/python/shortfin_apps/sd/components/messages.py @@ -13,7 +13,7 @@ from .io_struct import GenerateReqInput -logger = logging.getLogger(__name__) +logger = logging.getLogger("shortfin-sd.messages") class InferencePhase(Enum): diff --git a/shortfin/python/shortfin_apps/sd/components/metrics.py b/shortfin/python/shortfin_apps/sd/components/metrics.py index f8fd30876..a1811beea 100644 --- a/shortfin/python/shortfin_apps/sd/components/metrics.py +++ b/shortfin/python/shortfin_apps/sd/components/metrics.py @@ -10,7 +10,7 @@ from typing import Callable, Any import functools -logger = logging.getLogger(__name__) +logger = logging.getLogger("shortfin-sd.metrics") def measure(fn=None, type="exec", task=None, num_items=None, freq=1, label="items"): diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py index 1ee11569a..ad3fd9404 100644 --- a/shortfin/python/shortfin_apps/sd/components/service.py +++ b/shortfin/python/shortfin_apps/sd/components/service.py @@ -24,7 +24,8 @@ from .metrics import measure -logger = logging.getLogger(__name__) +logger = logging.getLogger("shortfin-sd.service") +logger.setLevel(logging.DEBUG) prog_isolations = { "none": sf.ProgramIsolation.NONE, @@ -119,8 +120,6 @@ def load_inference_parameters( def start(self): # Initialize programs. - # This can work if we only initialize one set of programs per service, as our programs - # in SDXL are stateless and for component in self.inference_modules: component_modules = [ sf.ProgramModule.parameter_provider( @@ -128,17 +127,22 @@ def start(self): ), *self.inference_modules[component], ] + for worker_idx, worker in enumerate(self.workers): worker_devices = self.fibers[ worker_idx * (self.fibers_per_worker) ].raw_devices - + logger.info( + f"Loading inference program: {component}, worker index: {worker_idx}, device: {worker_devices}" + ) self.inference_programs[worker_idx][component] = sf.Program( modules=component_modules, devices=worker_devices, isolation=self.prog_isolation, trace_execution=self.trace_execution, ) + logger.info("Program loaded.") + for worker_idx, worker in enumerate(self.workers): self.inference_functions[worker_idx]["encode"] = {} for bs in self.model_params.clip_batch_sizes: @@ -170,7 +174,6 @@ def start(self): ] = self.inference_programs[worker_idx]["vae"][ f"{self.model_params.vae_module_name}.decode" ] - # breakpoint() self.batcher.launch() def shutdown(self): @@ -212,8 +215,8 @@ class BatcherProcess(sf.Process): into batches. """ - STROBE_SHORT_DELAY = 0.1 - STROBE_LONG_DELAY = 0.25 + STROBE_SHORT_DELAY = 0.5 + STROBE_LONG_DELAY = 1 def __init__(self, service: GenerateService): super().__init__(fiber=service.fibers[0]) @@ -356,7 +359,6 @@ async def run(self): logger.error("Executor process recieved disjoint batch.") phase = req.phase phases = self.exec_requests[0].phases - req_count = len(self.exec_requests) device0 = self.service.fibers[self.fiber_index].device(0) if phases[InferencePhase.PREPARE]["required"]: @@ -424,8 +426,12 @@ async def _prepare(self, device, requests): async def _encode(self, device, requests): req_bs = len(requests) entrypoints = self.service.inference_functions[self.worker_index]["encode"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._encode(device, [request]) + return for bs, fn in entrypoints.items(): - if bs >= req_bs: + if bs == req_bs: break # Prepare tokenized input ids for CLIP inference @@ -462,6 +468,7 @@ async def _encode(self, device, requests): fn, "".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]), ) + await device pe, te = await fn(*clip_inputs, fiber=self.fiber) for i in range(req_bs): @@ -477,8 +484,12 @@ async def _denoise(self, device, requests): cfg_mult = 2 if self.service.model_params.cfg_mode else 1 # Produce denoised latents entrypoints = self.service.inference_functions[self.worker_index]["denoise"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._denoise(device, [request]) + return for bs, fns in entrypoints.items(): - if bs >= req_bs: + if bs == req_bs: break # Get shape of batched latents. @@ -613,8 +624,12 @@ async def _decode(self, device, requests): req_bs = len(requests) # Decode latents to images entrypoints = self.service.inference_functions[self.worker_index]["decode"] + if req_bs not in list(entrypoints.keys()): + for request in requests: + await self._decode(device, [request]) + return for bs, fn in entrypoints.items(): - if bs >= req_bs: + if bs == req_bs: break latents_shape = [ diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json index 192a2be61..002f43f0e 100644 --- a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json +++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json @@ -29,6 +29,8 @@ " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal", " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal", " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal", + " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal", " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo" ], "neg_prompt": [ diff --git a/shortfin/python/shortfin_apps/sd/examples/send_request.py b/shortfin/python/shortfin_apps/sd/examples/send_request.py deleted file mode 100644 index 9fce890d6..000000000 --- a/shortfin/python/shortfin_apps/sd/examples/send_request.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import json -import requests -import argparse -import base64 - -from datetime import datetime as dt -from PIL import Image - -sample_request = { - "prompt": [ - " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", - ], - "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"], - "height": [1024], - "width": [1024], - "steps": [20], - "guidance_scale": [7.5], - "seed": [0], - "output_type": ["base64"], - "rid": ["string"], -} - - -def bytes_to_img(bytes, idx=0, width=1024, height=1024): - timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") - image = Image.frombytes( - mode="RGB", size=(width, height), data=base64.b64decode(bytes) - ) - image.save(f"shortfin_sd_output_{timestamp}_{idx}.png") - print(f"Saved to shortfin_sd_output_{timestamp}_{idx}.png") - - -def send_json_file(args): - # Read the JSON file - try: - if args.file == "default": - data = sample_request - else: - with open(args.file, "r") as json_file: - data = json.load(json_file) - except Exception as e: - print(f"Error reading the JSON file: {e}") - return - data["prompt"] = ( - [data["prompt"]] - if isinstance(data["prompt"], str) - else data["prompt"] * args.reps - ) - # Send the data to the /generate endpoint - try: - response = requests.post("http://0.0.0.0:8000/generate", json=data) - response.raise_for_status() # Raise an error for bad responses - timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") - request = json.loads(response.request.body.decode("utf-8")) - for idx, item in enumerate(response.json()["images"]): - width = get_batched(request, "width", idx) - height = get_batched(request, "height", idx) - if args.save: - print("Saving response as image...") - bytes_to_img(item.encode("utf-8"), idx, width, height) - print("Responses processed.") - - except requests.exceptions.RequestException as e: - print(f"Error sending the request: {e}") - - -def get_batched(request, arg, idx): - if isinstance(request[arg], list): - if len(request[arg]) == 1: - indexed = request[arg][0] - else: - indexed = request[arg][idx] - else: - indexed = request[arg] - return indexed - - -if __name__ == "__main__": - p = argparse.ArgumentParser() - p.add_argument("--file", type=str, default="default") - p.add_argument("--reps", type=int, default=1) - p.add_argument("--save", type=argparse.BooleanOptionalAction, help="save images") - args = p.parse_args() - send_json_file(args) diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py index 7ace4d407..9ee81d1c4 100644 --- a/shortfin/python/shortfin_apps/sd/server.py +++ b/shortfin/python/shortfin_apps/sd/server.py @@ -15,8 +15,6 @@ import copy import subprocess -from iree.build import * - # Import first as it does dep checking and reporting. from shortfin.interop.fastapi import FastAPIResponder @@ -33,9 +31,12 @@ from .components.tokenizer import Tokenizer from .components.builders import sdxl -from shortfin.support.logging_setup import configure_main_logger +from shortfin.support.logging_setup import native_handler, configure_main_logger -logger = configure_main_logger("server") +logger = logging.getLogger("shortfin-sd") +logger.addHandler(native_handler) +logger.setLevel(logging.INFO) +logger.propagate = False THIS_DIR = Path(__file__).resolve().parent @@ -84,6 +85,7 @@ async def generate_request(gen_req: GenerateReqInput, request: Request): def configure(args) -> SystemManager: # Setup system (configure devices, etc). + model_config, topology_config, flagfile, tuning_spec, args = get_configs(args) sysman = SystemManager(args.device, args.device_ids, args.amdgpu_async_allocations) # Setup each service we are hosting. @@ -92,7 +94,9 @@ def configure(args) -> SystemManager: subfolder = f"tokenizer_{idx + 1}" if idx > 0 else "tokenizer" tokenizers.append(Tokenizer.from_pretrained(tok_name, subfolder)) - model_params = ModelParams.load_json(args.model_config) + model_params = ModelParams.load_json(model_config) + vmfbs, params = get_modules(args, model_config, flagfile, tuning_spec) + sm = GenerateService( name="sd", sysman=sysman, @@ -104,7 +108,6 @@ def configure(args) -> SystemManager: show_progress=args.show_progress, trace_execution=args.trace_execution, ) - vmfbs, params = get_modules(args) for key, vmfblist in vmfbs.items(): for vmfb in vmfblist: sm.load_inference_module(vmfb, component=key) @@ -114,15 +117,80 @@ def configure(args) -> SystemManager: return sysman -def get_modules(args): +def get_configs(args): + # Returns one set of config artifacts. + modelname = "sdxl" + model_config = args.model_config if args.model_config else None + topology_config = None + tuning_spec = None + flagfile = args.flagfile if args.flagfile else None + topology_inp = args.topology if args.topology else "spx_single" + cfg_builder_args = [ + sys.executable, + "-m", + "iree.build", + os.path.join(THIS_DIR, "components", "config_artifacts.py"), + f"--target={args.target}", + f"--output-dir={args.artifacts_dir}", + f"--model={modelname}", + f"--topology={topology_inp}", + ] + outs = subprocess.check_output(cfg_builder_args).decode() + outs_paths = outs.splitlines() + for i in outs_paths: + if "sdxl_config" in i and not args.model_config: + model_config = i + elif "topology" in i and args.topology: + topology_config = i + elif "flagfile" in i and not args.flagfile: + flagfile = i + elif "attention_and_matmul_spec" in i and args.use_tuned: + tuning_spec = i + + if args.use_tuned and args.tuning_spec: + tuning_spec = os.path.abspath(args.tuning_spec) + + if topology_config: + with open(topology_config, "r") as f: + contents = [line.rstrip() for line in f] + for spec in contents: + if "--" in spec: + arglist = spec.strip("--").split("=") + arg = arglist[0] + if len(arglist) > 2: + value = arglist[1:] + for val in value: + try: + val = int(val) + except ValueError: + continue + elif len(arglist) == 2: + value = arglist[-1] + try: + value = int(value) + except ValueError: + continue + else: + # It's a boolean arg. + value = True + setattr(args, arg, value) + else: + # It's an env var. + arglist = spec.split("=") + os.environ[arglist[0]] = arglist[1] + + return model_config, topology_config, flagfile, tuning_spec, args + + +def get_modules(args, model_config, flagfile, td_spec): # TODO: Move this out of server entrypoint vmfbs = {"clip": [], "unet": [], "vae": [], "scheduler": []} params = {"clip": [], "unet": [], "vae": []} model_flags = copy.deepcopy(vmfbs) model_flags["all"] = args.compile_flags - if args.flagfile: - with open(args.flagfile, "r") as f: + if flagfile: + with open(flagfile, "r") as f: contents = [line.rstrip() for line in f] flagged_model = "all" for elem in contents: @@ -131,6 +199,10 @@ def get_modules(args): flagged_model = elem else: model_flags[flagged_model].extend([elem]) + if td_spec: + model_flags["unet"].extend( + [f"--iree-codegen-transform-dialect-library={td_spec}"] + ) filenames = [] for modelname in vmfbs.keys(): @@ -140,7 +212,7 @@ def get_modules(args): "-m", "iree.build", os.path.join(THIS_DIR, "components", "builders.py"), - f"--model-json={args.model_config}", + f"--model-json={model_config}", f"--target={args.target}", f"--splat={args.splat}", f"--build-preference={args.build_preference}", @@ -165,6 +237,8 @@ def get_modules(args): def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): + from pathlib import Path + parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) parser.add_argument("--port", type=int, default=8000) @@ -212,8 +286,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): parser.add_argument( "--model_config", type=Path, - required=True, - help="Path to the model config file", + help="Path to the model config file. If None, defaults to i8 punet, batch size 1", ) parser.add_argument( "--workers_per_device", @@ -275,17 +348,36 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): ) parser.add_argument( "--artifacts_dir", + type=Path, + default=None, + help="Path to local artifacts cache.", + ) + parser.add_argument( + "--tuning_spec", type=str, default="", - help="Path to local artifacts cache.", + help="Path to transform dialect spec if compiling an executable with tunings.", + ) + parser.add_argument( + "--topology", + type=str, + default=None, + choices=["spx_single", "cpx_single", "spx_multi", "cpx_multi"], + help="Use one of four known performant preconfigured device/fiber topologies.", + ) + parser.add_argument( + "--use_tuned", + type=int, + default=1, + help="Use tunings for attention and matmul ops. 0 to disable.", ) args = parser.parse_args(argv) + if not args.artifacts_dir: + home = Path.home() + artdir = home / ".cache" / "shark" + args.artifacts_dir = str(artdir) - log_level = logging.INFO - - logging.root.setLevel(log_level) - logger.addHandler(logging.FileHandler("shortfin_sd.log")) global sysman sysman = configure(args) uvicorn.run( @@ -298,14 +390,31 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): if __name__ == "__main__": + logging.root.setLevel(logging.INFO) main( sys.argv[1:], # Make logging defer to the default shortfin logging config. log_config={ "version": 1, "disable_existing_loggers": False, - "formatters": {}, - "handlers": {}, - "loggers": {}, + "formatters": { + "default": { + "format": "%(asctime)s - %(levelname)s - %(message)s", + "datefmt": "%Y-%m-%d %H:%M:%S", + }, + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "default", + }, + }, + "loggers": { + "uvicorn": { + "handlers": ["console"], + "level": "INFO", + "propagate": False, + }, + }, }, ) diff --git a/shortfin/python/shortfin_apps/sd/simple_client.py b/shortfin/python/shortfin_apps/sd/simple_client.py new file mode 100644 index 000000000..f8aabd8e7 --- /dev/null +++ b/shortfin/python/shortfin_apps/sd/simple_client.py @@ -0,0 +1,229 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import json +import requests +import argparse +import base64 +import time +import asyncio +import aiohttp +import sys +import os + +from datetime import datetime as dt +from PIL import Image + +sample_request = { + "prompt": [ + " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + ], + "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"], + "height": [1024], + "width": [1024], + "steps": [20], + "guidance_scale": [7.5], + "seed": [0], + "output_type": ["base64"], + "rid": ["string"], +} + + +def bytes_to_img(bytes, idx=0, width=1024, height=1024, outputdir="./gen_imgs"): + timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + image = Image.frombytes( + mode="RGB", size=(width, height), data=base64.b64decode(bytes) + ) + if not os.path.isdir(outputdir): + os.mkdir(outputdir) + im_path = os.path.join(outputdir, f"shortfin_sd_output_{timestamp}_{idx}.png") + image.save(im_path) + print(f"Saved to {im_path}") + + +def get_batched(request, arg, idx): + if isinstance(request[arg], list): + if len(request[arg]) == 1: + indexed = request[arg][0] + else: + indexed = request[arg][idx] + else: + indexed = request[arg] + return indexed + + +async def send_request(session, rep, args, data): + try: + print("Sending request batch #", rep) + url = f"http://0.0.0.0:{args.port}/generate" + start = time.time() + async with session.post(url, json=data) as response: + end = time.time() + # Check if the response was successful + if response.status == 200: + response.raise_for_status() # Raise an error for bad responses + timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + res_json = await response.json(content_type=None) + if args.save: + for idx, item in enumerate(res_json["images"]): + width = get_batched(data, "width", idx) + height = get_batched(data, "height", idx) + print("Saving response as image...") + bytes_to_img( + item.encode("utf-8"), idx, width, height, args.outputdir + ) + latency = end - start + print("Responses processed.") + return latency, len(data["prompt"]) + else: + print(f"Error: Received {response.status} from server") + raise Exception + except Exception as e: + print(f"Request failed: {e}") + raise Exception + + +async def static(args): + # Create an aiohttp session for sending requests + async with aiohttp.ClientSession() as session: + pending = [] + latencies = [] + sample_counts = [] + # Read the JSON file if supplied. Otherwise, get user input. + try: + if args.file == "default": + data = sample_request + else: + with open(args.file, "r") as json_file: + data = json.load(json_file) + except Exception as e: + print(f"Error reading the JSON file: {e}") + return + data["prompt"] = ( + [data["prompt"]] if isinstance(data["prompt"], str) else data["prompt"] + ) + start = time.time() + + async for i in async_range(args.reps): + pending.append(asyncio.create_task(send_request(session, i, args, data))) + await asyncio.sleep(1) # Wait for 1 second before sending the next request + while pending: + done, pending = await asyncio.wait( + pending, return_when=asyncio.ALL_COMPLETED + ) + for task in done: + latency, num_samples = await task + latencies.append(latency) + sample_counts.append(num_samples) + end = time.time() + if not any([i is None for i in [latencies, sample_counts]]): + total_num_samples = sum(sample_counts) + sps = str(total_num_samples / (end - start)) + print(f"Average throughput: {sps} samples per second") + else: + raise ValueError("Received error response from server.") + + +async def interactive(args): + # Create an aiohttp session for sending requests + async with aiohttp.ClientSession() as session: + pending = [] + latencies = [] + sample_counts = [] + # Read the JSON file if supplied. Otherwise, get user input. + try: + if args.file == "default": + data = sample_request + else: + with open(args.file, "r") as json_file: + data = json.load(json_file) + except Exception as e: + print(f"Error reading the JSON file: {e}") + return + data["prompt"] = ( + [data["prompt"]] if isinstance(data["prompt"], str) else data["prompt"] + ) + while True: + prompt = await ainput("Enter a prompt: ") + data["prompt"] = [prompt] + data["steps"] = [args.steps] + print("Sending request with prompt: ", data["prompt"]) + + async for i in async_range(args.reps): + pending.append( + asyncio.create_task(send_request(session, i, args, data)) + ) + await asyncio.sleep( + 1 + ) # Wait for 1 second before sending the next request + while pending: + done, pending = await asyncio.wait( + pending, return_when=asyncio.ALL_COMPLETED + ) + for task in done: + latency, num_samples = await task + pending = [] + if any([i is None for i in [latencies, sample_counts]]): + raise ValueError("Received error response from server.") + + +async def ainput(prompt: str) -> str: + return await asyncio.to_thread(input, f"{prompt} ") + + +async def async_range(count): + for i in range(count): + yield (i) + await asyncio.sleep(0.0) + + +def main(argv): + p = argparse.ArgumentParser() + p.add_argument( + "--file", + type=str, + default="default", + help="A non-default request to send to the server.", + ) + p.add_argument( + "--reps", + type=int, + default=1, + help="Number of times to duplicate each request in one second intervals.", + ) + p.add_argument( + "--save", + action=argparse.BooleanOptionalAction, + default=True, + help="Save images. To disable, use --no-save", + ) + p.add_argument( + "--outputdir", + type=str, + default="gen_imgs", + help="Directory to which images get saved.", + ) + p.add_argument("--port", type=str, default="8000", help="Server port") + p.add_argument( + "--steps", + type=int, + default="20", + help="Number of inference steps. More steps usually means a better image. Interactive only.", + ) + p.add_argument( + "--interactive", + action="store_true", + help="Start as an example CLI client instead of sending static requests.", + ) + args = p.parse_args() + if args.interactive: + asyncio.run(interactive(args)) + else: + asyncio.run(static(args)) + + +if __name__ == "__main__": + main(sys.argv) From 51cf2f45139f67e15434990173b49375998db9f3 Mon Sep 17 00:00:00 2001 From: Stephen Baione <109226581+stbaione@users.noreply.github.com> Date: Wed, 13 Nov 2024 16:40:58 -0600 Subject: [PATCH 28/59] Shortfin LLM Docs (#481) # Description The following docs outline how to export, and compile a Llama 8b f16 decomposed model, then run the Shortfin LLM Server with the the compiled model. It includes docs for both a `developer` flow and a `user` flow. There are a couple `TODOs` that can be updated/fixed as we make patches in shortfin and/or sharktank. --- .../llm/developer/e2e_llama8b_mi300x.md | 242 +++++++++++++++ docs/shortfin/llm/user/e2e_llama8b_mi300x.md | 278 ++++++++++++++++++ shortfin/python/shortfin_apps/llm/client.py | 16 +- 3 files changed, 531 insertions(+), 5 deletions(-) create mode 100644 docs/shortfin/llm/developer/e2e_llama8b_mi300x.md create mode 100644 docs/shortfin/llm/user/e2e_llama8b_mi300x.md diff --git a/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md b/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md new file mode 100644 index 000000000..e3150ed5c --- /dev/null +++ b/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md @@ -0,0 +1,242 @@ +# LLama 8b GPU Instructions on MI300X + +**NOTE: This was ran on the `mi300x-3` system** + +## Setup + +We will use an example with `llama_8b_f16_decomposed` in order to describe the +process of exporting a model for use in the shortfin llm server with an MI300 GPU. + +### Pre-Requisites + +- Python >= 3.11 is recommended for this flow + - You can check out [pyenv](https://github.com/pyenv/pyenv) as a good tool + to be able to manage multiple versions of python on the same system. + +### Setting Up Environment + +Follow the `Development Getting Started` docs +[here](https://github.com/nod-ai/SHARK-Platform/blob/main/README.md#development-getting-started) +to setup your environment for development. + +We will use an example with `llama_8b_f16_decomposed` in order to describe the +process of exporting a model for use in the shortfin llm server with an MI300 GPU. + +### Define a directory for export files + +Create a new directory for us to export files like `model.mlir`, `model.vmfb`, etc. + +```bash +mkdir $PWD/export +export EXPORT_DIR=$PWD/exportd +``` + +### Define environment variables + +Define the following environment variables to make running this example a bit easier: + +#### Model/Tokenizer vars + +This example uses the `llama8b_f16.irpa` and `tokenizer.json` files that are +pre-existing on the MI300X-3 system. +You may need to change the paths for your own system. + +```bash +export MODEL_PARAMS_PATH=/data/llama3.1/8b/llama8b_f16.irpa # Path to existing .irpa file, may need to change w/ system +export TOKENIZER_PATH=/data/llama3.1/8b/tokenizer.json # Path to existing tokenizer.json, may need to change w/ system +``` + +#### General env vars + +The following env vars can be copy + pasted directly: + +```bash +export MLIR_PATH=$EXPORT_DIR/model.mlir # Path to export model.mlir file +export OUTPUT_CONFIG_PATH=$EXPORT_DIR/config.json # Path to export config.json file +export EDITED_CONFIG_PATH=$EXPORT_DIR/edited_config.json # Path to export config.json file +export VMFB_PATH=$EXPORT_DIR/model.vmfb # Path to export model.vmfb file +export BS=1,4 # Batch size for kvcache +export ROCR_VISIBLE_DEVICES=1 # NOTE: This is temporary, until multi-device is fixed +``` + +### Export to MLIR + +We will now use the `sharktank.examples.export_paged_llm_v1` script to export +our model to `.mlir` format. + +```bash +python -m sharktank.examples.export_paged_llm_v1 \ + --irpa-file=$MODEL_PARAMS_PATH \ + --output-mlir=$MLIR_PATH \ + --output-config=$OUTPUT_CONFIG_PATH \ + --bs=$BS +``` + +## Compiling to `.vmfb` + +Now that we have generated a `model.mlir` file, we can compile it to `.vmfb` +format, which is required for running the `shortfin` LLM server. + +We will use the [iree-compile](https://iree.dev/developers/general/developer-overview/#iree-compile) +tool for compiling our model. + +### Compile for MI300 + +**NOTE: This command is specific to MI300 GPUs. +For other `--iree-hip-target` GPU options, +look [here](https://iree.dev/guides/deployment-configurations/gpu-rocm/#compile-a-program)** + +```bash +iree-compile $MLIR_PATH \ + --iree-hal-target-backends=rocm \ + --iree-hip-target=gfx942 \ + -o $VMFB_PATH +``` + +## Write an edited config + +We need to write a config for our model with a slightly edited structure +to run with shortfin. This will work for the example in our docs. +You may need to modify some of the parameters for a specific model. + +### Write edited config + +```bash +cat > $EDITED_CONFIG_PATH << EOF +{ + "module_name": "module", + "module_abi_version": 1, + "max_seq_len": 131072, + "attn_head_count": 8, + "attn_head_dim": 128, + "prefill_batch_sizes": [ + $BS + ], + "decode_batch_sizes": [ + $BS + ], + "transformer_block_count": 32, + "paged_kv_cache": { + "block_seq_stride": 16, + "device_block_count": 256 + } +} +EOF +``` + +## Running the `shortfin` LLM server + +We should now have all of the files that we need to run the shortfin LLM server. + +Verify that you have the following in your specified directory ($EXPORT_DIR): + +```bash +ls $EXPORT_DIR +``` + +- edited_config.json +- model.vmfb + +### Launch server: + +#### Set the target device + + + +#### Run the shortfin server + +Run the following command to launch the Shortfin LLM Server in the background: + +> **Note** +> By default, our server will start at `http://localhost:8000`. +> You can specify the `--host` and/or `--port` arguments, to run at a different address. +> +> If you receive an error similar to the following: +> +> `[errno 98] address already in use` +> +> Then, you can confirm the port is in use with `ss -ntl | grep 8000` +> and either kill the process running at that port, +> or start the shortfin server at a different port. + +```bash +python -m shortfin_apps.llm.server \ + --tokenizer_json=$TOKENIZER_PATH \ + --model_config=$EDITED_CONFIG_PATH \ + --vmfb=$VMFB_PATH \ + --parameters=$MODEL_PARAMS_PATH \ + --device=hip > shortfin_llm_server.log 2>&1 & +shortfin_process=$! +``` + +You can verify your command has launched successfully when you see the following + logs outputted to terminal: + +```bash +cat shortfin_llm_server.log +``` + +#### Expected output + +```text +[2024-10-24 15:40:27.440] [info] [on.py:62] Application startup complete. +[2024-10-24 15:40:27.444] [info] [server.py:214] Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) +``` + +## Verify server + +### Client script + +We can test the LLM server, by running our client script: + +```bash +python shortfin/python/shortfin_apps/llm/client.py --port 8000 +``` + +### Simple request + +Or by sending a simple request: + +### Open python shell + +```bash +python +``` + +### Send request + +```python +import requests + +import os + +port = 8000 # Change if running at a different port + +generate_url = f"http://localhost:{port}/generate" + +def generation_request(): + payload = {"text": "What is the capital of the United States?", "sampling_params": {"max_completion_tokens": 50}} + try: + resp = requests.post(generate_url, json=payload) + resp.raise_for_status() # Raises an HTTPError for bad responses + print(resp.text) + except requests.exceptions.RequestException as e: + print(f"An error occurred: {e}") + +generation_request() +``` + +After you receive the request, you can exit the python shell: + +```bash +quit() +``` + +## Cleanup + +When done, you can kill the shortfin_llm_server by killing the process: + +```bash +kill -9 $shortfin_process +``` diff --git a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md new file mode 100644 index 000000000..985e55c13 --- /dev/null +++ b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md @@ -0,0 +1,278 @@ +# LLama 8b GPU instructions on MI300X + +## Setup + +We will use an example with `llama_8b_f16` in order to describe the +process of exporting a model for use in the shortfin llm server with an +MI300 GPU. + +### Pre-Requisites + +- Python >= 3.11 is recommended for this flow + - You can check out [pyenv](https://github.com/pyenv/pyenv) + as a good tool to be able to manage multiple versions of python + on the same system. + +### Create virtual environment + +To start, create a new virtual environment: + +```bash +python -m venv --prompt shark-ai .venv +source .venv/bin/activate +``` + +### Install `shark-ai` + +You can install either the `latest stable` version of `shark-ai` +or the `nightly` version: + +#### Stable + +```bash +pip install shark-ai +``` + +#### Nightly + +```bash +pip install sharktank -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels +pip install shortfin -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels +``` + +#### Install dataclasses-json + + + +```bash +pip install dataclasses-json +``` + +### Define a directory for export files + +Create a new directory for us to export files like +`model.mlir`, `model.vmfb`, etc. + +```bash +mkdir $PWD/export +export EXPORT_DIR=$PWD/export +``` + +### Download llama3_8b_fp16.gguf + +We will use the `hf_datasets` module in `sharktank` to download a +LLama3.1 8b f16 model. + +```bash +python -m sharktank.utils.hf_datasets amd-shark/llama3.1-8B --local-dir $EXPORT_DIR +``` + +### Define environment variables + +Define the following environment variables to make running +this example a bit easier: + +#### Model/Tokenizer vars + +This example uses the `llama8b_f16.gguf` and `tokenizer.json` files +that were downloaded in the previous step. + +```bash +export MODEL_PARAMS_PATH=$EXPORT_DIR/llama3.1-8b/llama8b_f16.gguf +export TOKENIZER_PATH=$EXPORT_DIR/llama3.1-8b/tokenizer.json +``` + +#### General env vars + +The following env vars can be copy + pasted directly: + +```bash +# Path to export model.mlir file +export MLIR_PATH=$EXPORT_DIR/model.mlir +# Path to export config.json file +export OUTPUT_CONFIG_PATH=$EXPORT_DIR/config.json +# Path to export edited_config.json file +export EDITED_CONFIG_PATH=$EXPORT_DIR/edited_config.json +# Path to export model.vmfb file +export VMFB_PATH=$EXPORT_DIR/model.vmfb +# Batch size for kvcache +export BS=1,4 +# NOTE: This is temporary, until multi-device is fixed +export ROCR_VISIBLE_DEVICES=1 +``` + +## Export to MLIR + +We will now use the `sharktank.examples.export_paged_llm_v1` script +to export our model to `.mlir` format. + +```bash +python -m sharktank.examples.export_paged_llm_v1 \ + --irpa-file=$MODEL_PARAMS_PATH \ + --output-mlir=$MLIR_PATH \ + --output-config=$OUTPUT_CONFIG_PATH \ + --bs=$BS +``` + +## Compiling to `.vmfb` + +Now that we have generated a `model.mlir` file, +we can compile it to `.vmfb` format, which is required for running +the `shortfin` LLM server. + +We will use the +[iree-compile](https://iree.dev/developers/general/developer-overview/#iree-compile) +tool for compiling our model. + +### Compile for MI300 + +**NOTE: This command is specific to MI300 GPUs. +For other `--iree-hip-target` GPU options, +look [here](https://iree.dev/guides/deployment-configurations/gpu-rocm/#compile-a-program)** + +```bash +iree-compile $MLIR_PATH \ + --iree-hal-target-backends=rocm \ + --iree-hip-target=gfx942 \ + -o $VMFB_PATH +``` + +## Write an edited config + +We need to write a config for our model with a slightly edited structure +to run with shortfin. This will work for the example in our docs. +You may need to modify some of the parameters for a specific model. + +### Write edited config + +```bash +cat > $EDITED_CONFIG_PATH << EOF +{ + "module_name": "module", + "module_abi_version": 1, + "max_seq_len": 131072, + "attn_head_count": 8, + "attn_head_dim": 128, + "prefill_batch_sizes": [ + $BS + ], + "decode_batch_sizes": [ + $BS + ], + "transformer_block_count": 32, + "paged_kv_cache": { + "block_seq_stride": 16, + "device_block_count": 256 + } +} +EOF +``` + +## Running the `shortfin` LLM server + +We should now have all of the files that we need to run the shortfin LLM server. + +Verify that you have the following in your specified directory ($EXPORT_DIR): + +```bash +ls $EXPORT_DIR +``` + +- edited_config.json +- model.vmfb + +### Launch server: + + + +#### Run the shortfin server + +Now that we are finished with setup, we can start the Shortfin LLM Server. + +Run the following command to launch the Shortfin LLM Server in the background: + +> **Note** +> By default, our server will start at `http://localhost:8000`. +> You can specify the `--host` and/or `--port` arguments, to run at a different address. +> +> If you receive an error similar to the following: +> +> `[errno 98] address already in use` +> +> Then, you can confirm the port is in use with `ss -ntl | grep 8000` +> and either kill the process running at that port, +> or start the shortfin server at a different port. + +```bash +python -m shortfin_apps.llm.server \ + --tokenizer_json=$TOKENIZER_PATH \ + --model_config=$EDITED_CONFIG_PATH \ + --vmfb=$VMFB_PATH \ + --parameters=$MODEL_PARAMS_PATH \ + --device=hip > shortfin_llm_server.log 2>&1 & +shortfin_process=$! +``` + +You can verify your command has launched successfully +when you see the following logs outputted to terminal: + +```bash +cat shortfin_llm_server.log +``` + +#### Expected output + +```text +[2024-10-24 15:40:27.440] [info] [on.py:62] Application startup complete. +[2024-10-24 15:40:27.444] [info] [server.py:214] Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) +``` + +## Verify server + +We can now verify our LLM server by sending a simple request: + +### Open python shell + +```bash +python +``` + +### Send request + +```python +import requests + +import os + +port = 8000 # Change if running on a different port + +generate_url = f"http://localhost:{port}/generate" + +def generation_request(): + payload = {"text": "What is the capital of the United States?", "sampling_params": {"max_completion_tokens": 50}} + try: + resp = requests.post(generate_url, json=payload) + resp.raise_for_status() # Raises an HTTPError for bad responses + print(resp.text) + except requests.exceptions.RequestException as e: + print(f"An error occurred: {e}") + +generation_request() +``` + +After you receive the request, you can exit the python shell: + +```bash +quit() +``` + +## Cleanup + +When done, you can kill the shortfin_llm_server by killing the process: + +```bash +kill -9 $shortfin_process +``` diff --git a/shortfin/python/shortfin_apps/llm/client.py b/shortfin/python/shortfin_apps/llm/client.py index f4e104a9f..e3ff3ec39 100644 --- a/shortfin/python/shortfin_apps/llm/client.py +++ b/shortfin/python/shortfin_apps/llm/client.py @@ -11,8 +11,6 @@ import time from typing import Dict, Any -BASE_URL = "http://localhost:8000" - def main() -> None: parser = argparse.ArgumentParser(description="Test LLM server") @@ -26,8 +24,16 @@ def main() -> None: parser.add_argument( "--stream", action="store_true", help="Enable response streaming" ) + parser.add_argument( + "--port", + type=str, + default="8000", + help="Port that shortfin server is running on", + ) args = parser.parse_args() + base_url = f"http://localhost:{args.port}" + data = { "text": args.text, "sampling_params": { @@ -42,13 +48,13 @@ def main() -> None: "stream": args.stream, } - print(f"Testing LLM server at {BASE_URL}") + print(f"Testing LLM server at {base_url}") # Health check with exponential backoff backoff = 1 while True: try: - requests.get(f"{BASE_URL}/health").raise_for_status() + requests.get(f"{base_url}/health").raise_for_status() break except requests.exceptions.RequestException as e: if backoff > 16: @@ -62,7 +68,7 @@ def main() -> None: try: print("Prompt text:", data["text"]) headers = {"Content-Type": "application/json"} - response = requests.post(f"{BASE_URL}/generate", headers=headers, json=data) + response = requests.post(f"{base_url}/generate", headers=headers, json=data) response.raise_for_status() if response.text.startswith("data: "): From 5dd512a164532d23d9ad4fb2323cfc4868739c61 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Wed, 13 Nov 2024 21:23:28 -0600 Subject: [PATCH 29/59] (shortfin-sd) readme / client simplifications (#504) --- .../python/shortfin/support/logging_setup.py | 2 +- shortfin/python/shortfin_apps/sd/README.md | 25 ++------ .../sd/components/config_artifacts.py | 12 ---- shortfin/python/shortfin_apps/sd/server.py | 2 +- .../python/shortfin_apps/sd/simple_client.py | 61 +++++++++---------- 5 files changed, 37 insertions(+), 65 deletions(-) diff --git a/shortfin/python/shortfin/support/logging_setup.py b/shortfin/python/shortfin/support/logging_setup.py index 3cb373f1e..849d65bf3 100644 --- a/shortfin/python/shortfin/support/logging_setup.py +++ b/shortfin/python/shortfin/support/logging_setup.py @@ -38,7 +38,7 @@ def __init__(self): native_handler.setFormatter(NativeFormatter()) # TODO: Source from env vars. -logger.setLevel(logging.DEBUG) +logger.setLevel(logging.WARNING) logger.addHandler(native_handler) diff --git a/shortfin/python/shortfin_apps/sd/README.md b/shortfin/python/shortfin_apps/sd/README.md index 6dd701c62..30002ec40 100644 --- a/shortfin/python/shortfin_apps/sd/README.md +++ b/shortfin/python/shortfin_apps/sd/README.md @@ -10,32 +10,19 @@ In your shortfin environment, pip install transformers pip install dataclasses-json pip install pillow +pip install shark-ai ``` ``` python -m shortfin_apps.sd.server --help ``` -## Run tests - - - From SHARK-Platform/shortfin: - ``` - pytest --system=amdgpu -k "sd" - ``` - The tests run with splat weights. - - -## Run on MI300x - - - Follow quick start +# Run on MI300x +The server will prepare runtime artifacts for you. - - Navigate to shortfin/ (only necessary if you're using following CLI exactly.) -``` -cd shortfin/ -``` - - Run CLI server interface (you can find `sdxl_config_i8.json` in shortfin_apps/sd/examples): +By default, the port is set to 8000. If you would like to change this, use `--port` in each of the following commands. -The server will prepare runtime artifacts for you. +You can check if this (or any) port is in use on Linux with `ss -ntl | grep 8000`. ``` python -m shortfin_apps.sd.server --device=amdgpu --device_ids=0 --build_preference=precompiled --topology="spx_single" @@ -43,5 +30,5 @@ python -m shortfin_apps.sd.server --device=amdgpu --device_ids=0 --build_prefere - Run a CLI client in a separate shell: ``` -python -m shortfin_apps.sd.simple_client --interactive --save +python -m shortfin_apps.sd.simple_client --interactive ``` diff --git a/shortfin/python/shortfin_apps/sd/components/config_artifacts.py b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py index b5a1d682b..f3502f22e 100644 --- a/shortfin/python/shortfin_apps/sd/components/config_artifacts.py +++ b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py @@ -11,18 +11,6 @@ import shortfin.array as sfnp import copy -from shortfin_apps.sd.components.config_struct import ModelParams - -this_dir = os.path.dirname(os.path.abspath(__file__)) -parent = os.path.dirname(this_dir) - -dtype_to_filetag = { - sfnp.float16: "fp16", - sfnp.float32: "fp32", - sfnp.int8: "i8", - sfnp.bfloat16: "bf16", -} - ARTIFACT_VERSION = "11132024" SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/configs/" diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py index 9ee81d1c4..2b7a93a91 100644 --- a/shortfin/python/shortfin_apps/sd/server.py +++ b/shortfin/python/shortfin_apps/sd/server.py @@ -355,7 +355,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): parser.add_argument( "--tuning_spec", type=str, - default="", + default=None, help="Path to transform dialect spec if compiling an executable with tunings.", ) parser.add_argument( diff --git a/shortfin/python/shortfin_apps/sd/simple_client.py b/shortfin/python/shortfin_apps/sd/simple_client.py index f8aabd8e7..550fd7c60 100644 --- a/shortfin/python/shortfin_apps/sd/simple_client.py +++ b/shortfin/python/shortfin_apps/sd/simple_client.py @@ -32,7 +32,7 @@ } -def bytes_to_img(bytes, idx=0, width=1024, height=1024, outputdir="./gen_imgs"): +def bytes_to_img(bytes, outputdir, idx=0, width=1024, height=1024): timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") image = Image.frombytes( mode="RGB", size=(width, height), data=base64.b64decode(bytes) @@ -46,6 +46,7 @@ def bytes_to_img(bytes, idx=0, width=1024, height=1024, outputdir="./gen_imgs"): def get_batched(request, arg, idx): if isinstance(request[arg], list): + # some args are broadcasted to each prompt, hence overriding idx for single-item entries if len(request[arg]) == 1: indexed = request[arg][0] else: @@ -56,34 +57,30 @@ def get_batched(request, arg, idx): async def send_request(session, rep, args, data): - try: - print("Sending request batch #", rep) - url = f"http://0.0.0.0:{args.port}/generate" - start = time.time() - async with session.post(url, json=data) as response: - end = time.time() - # Check if the response was successful - if response.status == 200: - response.raise_for_status() # Raise an error for bad responses - timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") - res_json = await response.json(content_type=None) - if args.save: - for idx, item in enumerate(res_json["images"]): - width = get_batched(data, "width", idx) - height = get_batched(data, "height", idx) - print("Saving response as image...") - bytes_to_img( - item.encode("utf-8"), idx, width, height, args.outputdir - ) - latency = end - start - print("Responses processed.") - return latency, len(data["prompt"]) - else: - print(f"Error: Received {response.status} from server") - raise Exception - except Exception as e: - print(f"Request failed: {e}") - raise Exception + print("Sending request batch #", rep) + url = f"http://0.0.0.0:{args.port}/generate" + start = time.time() + async with session.post(url, json=data) as response: + end = time.time() + # Check if the response was successful + if response.status == 200: + response.raise_for_status() # Raise an error for bad responses + timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + res_json = await response.json(content_type=None) + if args.save: + for idx, item in enumerate(res_json["images"]): + width = get_batched(data, "width", idx) + height = get_batched(data, "height", idx) + print("Saving response as image...") + bytes_to_img( + item.encode("utf-8"), args.outputdir, idx, width, height + ) + latency = end - start + print("Responses processed.") + return latency, len(data["prompt"]) + else: + print(f"Error: Received {response.status} from server") + raise Exception async def static(args): @@ -94,7 +91,7 @@ async def static(args): sample_counts = [] # Read the JSON file if supplied. Otherwise, get user input. try: - if args.file == "default": + if not args.file: data = sample_request else: with open(args.file, "r") as json_file: @@ -135,7 +132,7 @@ async def interactive(args): sample_counts = [] # Read the JSON file if supplied. Otherwise, get user input. try: - if args.file == "default": + if not args.file: data = sample_request else: with open(args.file, "r") as json_file: @@ -185,7 +182,7 @@ def main(argv): p.add_argument( "--file", type=str, - default="default", + default=None, help="A non-default request to send to the server.", ) p.add_argument( From d080ce1a1530f062f85cc1c7f5e5d1fb124b7e0e Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 14 Nov 2024 10:23:27 -0500 Subject: [PATCH 30/59] [tuner] Clean up candidate generation code (#508) * Pass context explicitly to the parse function * Fix mypy typing violations * Add pyproject.toml --- tuner/pyproject.toml | 24 +++++++ tuner/setup.py | 35 ++++++++++ tuner/tuner/candidate_gen.py | 108 +++++++++++++++--------------- tuner/tuner/candidate_gen_test.py | 102 ++++++++++++++-------------- tuner/tuner/py.typed | 0 tuner/version.json | 3 + 6 files changed, 167 insertions(+), 105 deletions(-) create mode 100644 tuner/pyproject.toml create mode 100644 tuner/setup.py create mode 100644 tuner/tuner/py.typed create mode 100644 tuner/version.json diff --git a/tuner/pyproject.toml b/tuner/pyproject.toml new file mode 100644 index 000000000..1661a7744 --- /dev/null +++ b/tuner/pyproject.toml @@ -0,0 +1,24 @@ +[project] +name = "SHARK Tuner" +authors = [ + {name = "SHARK Authors"}, +] +description = "IREE Dispatch Tuner" +readme = "README.md" +license = {text = "Apache-2.0"} +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +requires-python = ">= 3.10" + +# Version is set via the `setup.py`. +dynamic = ["version"] + +[project.urls] +Repository = "https://github.com/nod-ai/SHARK-Platform" diff --git a/tuner/setup.py b/tuner/setup.py new file mode 100644 index 000000000..aa450eaee --- /dev/null +++ b/tuner/setup.py @@ -0,0 +1,35 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import json +import os + +from setuptools import setup + +SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__)) + +# Setup and get version information. +VERSION_FILE = os.path.join(SETUPPY_DIR, "version.json") +VERSION_FILE_LOCAL = os.path.join(SETUPPY_DIR, "version_local.json") + + +def load_version_info(version_file): + with open(version_file, "rt") as f: + return json.load(f) + + +try: + version_info = load_version_info(VERSION_FILE_LOCAL) +except FileNotFoundError: + print("version_local.json not found. Default to dev build") + version_info = load_version_info(VERSION_FILE) + +PACKAGE_VERSION = version_info.get("package-version") +print(f"Using PACKAGE_VERSION: '{PACKAGE_VERSION}'") + +setup( + version=f"{PACKAGE_VERSION}", +) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 40eb27a82..96bfc7146 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -23,18 +23,15 @@ import math import pickle import re -import z3 +import z3 # type: ignore from dataclasses import astuple, dataclass from enum import Enum -from os import mkdir, path, makedirs +from os import path, makedirs from typing import Optional from textwrap import indent from abc import ABC, abstractmethod -import iree.compiler as ireec -from iree.compiler import ir -from iree.compiler.dialects import _linalg_ops_gen, _util_ops_gen - +from iree.compiler import ir # type: ignore tune_logger = logging.getLogger("tune") @@ -520,15 +517,14 @@ def get_default_output_dir() -> str: return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M") -def parse_mlir(mlir_text: str) -> ir.Module: +def parse_mlir(mlir_text: str, ctx: ir.Context) -> ir.Module: mlir_module = None - with ireec.ir.Context() as context: - try: - mlir_module = ireec.ir.Module.parse(mlir_text) - tune_logger.info("MLIR parsing successful!") - except ireec.ir.MLIRError as e: - tune_logger.error(f"Error parsing MLIR: {e}") - raise RuntimeError(f"Error parsing MLIR: {e}") + try: + mlir_module = ir.Module.parse(mlir_text) + tune_logger.info("MLIR parsing successful!") + except ir.MLIRError as e: + tune_logger.error(f"Error parsing MLIR: {e}") + raise RuntimeError(f"Error parsing MLIR: {e}") return mlir_module @@ -537,7 +533,7 @@ def parse_mlir(mlir_text: str) -> ir.Module: class MLIRTransformation: """Transformation of MLIR context""" - template: str + template: list[str] modified: str embeddable: str @@ -550,7 +546,7 @@ def supports(self, op_name: str) -> bool: @abstractmethod def get_shapes(self, template: list[str]) -> ProblemSize: - """Extract problem size of thge operation.""" + """Extract problem size of the operation.""" pass @abstractmethod @@ -645,7 +641,7 @@ def get_shapes(self, template: list[str]) -> ProblemSize: dispatch_kind=DispatchKind.mmt, ) assert mmt_re - assert dps, f"'{mmt_re}' not found in given context" + assert False, f"'{mmt_re}' not found in given context" def get_transform_function_mmt( self, problem_size: ProblemSize, functionName: str, configuration: Configuration @@ -1353,45 +1349,47 @@ def tune( mlir_template = read_input_mlir(input_file) mlir_text = "".join(mlir_template) - mlir_module = parse_mlir(mlir_text) - # Save the input file as the first candidate. - with open(path.join(output, f"0.mlir"), "w") as f: - f.write(mlir_text) - - dispatch_tuner_registry = DispatchTunerRegistry() - dispatch_tuner_registry.register( - [ - MmtTuner(), - ConvTuner(), - ContractionTuner(lhs_dims, rhs_dims, tile_dims), - BatchMmtTuner(), - BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), - ] - ) - - walk_result = walk_mlir_op(mlir_module, dispatch_tuner_registry) - - dispatch_tuner = walk_result.dispatch_tuner - problem_size = dispatch_tuner.get_shapes(mlir_template) - tune_logger.debug(str(problem_size)) - configs = [] - for i, config in enumerate(generate_solutions(problem_size, num_subgroups)): - if i >= limit: - break - tune_logger.info(f"Solution #{i+1}: {config}") - configs.append(config) - tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config) - - with open(path.join(output, f"{i+1}.mlir"), "w") as f: - f.write(tf_mlir.modified) - with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: - f.write(tf_mlir.embeddable) - - with open(path.join(output, "configs.pkl"), "wb") as file: - pickle.dump(configs, file) + with ir.Context() as ctx: + mlir_module: ir.Module = parse_mlir(mlir_text, ctx) + # Save the input file as the first candidate. + with open(path.join(output, f"0.mlir"), "w") as f: + f.write(mlir_text) + + dispatch_tuner_registry = DispatchTunerRegistry() + dispatch_tuner_registry.register( + [ + MmtTuner(), + ConvTuner(), + ContractionTuner(lhs_dims, rhs_dims, tile_dims), + BatchMmtTuner(), + BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), + ] + ) - tune_logger.info(f"Generated {len(configs)} candidates") - tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl") + walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry) + + dispatch_tuner = walk_result.dispatch_tuner + assert dispatch_tuner, "No suitable dispatch tuner found" + problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template) + tune_logger.debug(str(problem_size)) + configs = [] + for i, config in enumerate(generate_solutions(problem_size, num_subgroups)): + if i >= limit: + break + tune_logger.info(f"Solution #{i+1}: {config}") + configs.append(config) + tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config) + + with open(path.join(output, f"{i+1}.mlir"), "w") as f: + f.write(tf_mlir.modified) + with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: + f.write(tf_mlir.embeddable) + + with open(path.join(output, "configs.pkl"), "wb") as file: + pickle.dump(configs, file) + + tune_logger.info(f"Generated {len(configs)} candidates") + tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl") def main(): diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 2924db75b..a1a3a3e49 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -11,8 +11,11 @@ import pytest from . import candidate_gen +from iree.compiler import ir # type: ignore +from iree.compiler.dialects import func # type: ignore -def test_get_shaped_type_element_bitwidth(): + +def test_get_shaped_type_element_bitwidth() -> None: assert ( candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth == 8 @@ -31,7 +34,7 @@ def test_get_shaped_type_element_bitwidth(): ) -def test_get_shaped_type_to_str(): +def test_get_shaped_type_to_str() -> None: assert ( str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8)) == "1024x2048xi8" @@ -50,7 +53,7 @@ def test_get_shaped_type_to_str(): ) -def test_parse_tensor_type(): +def test_parse_tensor_type() -> None: assert candidate_gen.parse_tensor_type( "tensor<1x2x3xf32>" ) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32) @@ -59,11 +62,11 @@ def test_parse_tensor_type(): ) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8) -def test_get_mmt_tile_sizes(): +def test_get_mmt_tile_sizes() -> None: config = candidate_gen.Configuration( subgroup_size=0, workgroup_size=[], - intrinsic="", + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), tile_sizes=[128, 320, 32], subgroup_m_count=0, subgroup_n_count=0, @@ -73,11 +76,11 @@ def test_get_mmt_tile_sizes(): assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32] -def test_get_conv_tile_sizes(): +def test_get_conv_tile_sizes() -> None: config = candidate_gen.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic="#iree_gpu.mma_layout", + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), tile_sizes=[464, 320, 16], subgroup_m_count=1, subgroup_n_count=4, @@ -95,7 +98,7 @@ def test_get_conv_tile_sizes(): ] -def test_gpu_pipeline_options(): +def test_gpu_pipeline_options() -> None: options = candidate_gen.GpuPipelineOptions() assert options.all_default() assert str(options) == "#iree_gpu.pipeline_options<>" @@ -121,32 +124,32 @@ def test_gpu_pipeline_options(): ) -def test_get_contract_tile_sizes(): +def test_get_contract_tile_sizes() -> None: config = candidate_gen.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], - intrinsic="", + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), tile_sizes=[4, 8, 16], subgroup_m_count=1, subgroup_n_count=1, gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), waves_per_eu=2, ) - assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16] - assert candidate_gen.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16] - assert candidate_gen.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4] - assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [ + assert candidate_gen.get_contract_tile_sizes(config, "mnk") == [4, 8, 16] + assert candidate_gen.get_contract_tile_sizes(config, "nmk") == [8, 4, 16] + assert candidate_gen.get_contract_tile_sizes(config, "knm") == [16, 8, 4] + assert candidate_gen.get_contract_tile_sizes(config, "kkk") == [ 16, 16, 16, ] -def test_get_pipeline_config(): +def test_get_pipeline_config() -> None: config = candidate_gen.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], - intrinsic="", + intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), tile_sizes=[4, 8, 16], subgroup_m_count=1, subgroup_n_count=1, @@ -168,7 +171,7 @@ def test_get_pipeline_config(): ) -def test_get_shapes_mmt(): +def test_get_shapes_mmt() -> None: template = [ r"%18 = tensor.empty() : tensor<2048x1280xf32>", r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", @@ -184,7 +187,7 @@ def test_get_shapes_mmt(): ) -def test_get_shapes_conv(): +def test_get_shapes_conv() -> None: template = [ r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", @@ -199,7 +202,7 @@ def test_get_shapes_conv(): ) -def test_get_shapes_contract(): +def test_get_shapes_contract() -> None: template = [ r"%18 = tensor.empty() : tensor<2048x1280xf32>", r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", @@ -217,7 +220,7 @@ def test_get_shapes_contract(): ) -def test_get_shapes_batch_matmul(): +def test_get_shapes_batch_matmul() -> None: template = [ "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", @@ -234,7 +237,7 @@ def test_get_shapes_batch_matmul(): ) -def test_get_shapes_batch_mmt(): +def test_get_shapes_batch_mmt() -> None: template = [ r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', @@ -251,7 +254,7 @@ def test_get_shapes_batch_mmt(): ) -def test_mfma_intrinsic_to_str(): +def test_mfma_intrinsic_to_str() -> None: assert ( str(candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16()) == "MFMA_F32_16x16x16_F16" @@ -262,7 +265,7 @@ def test_mfma_intrinsic_to_str(): ) -def test_get_compatible_mfma_intrinsics(): +def test_get_compatible_mfma_intrinsics() -> None: assert candidate_gen.get_compatible_mfma_intrinsics( candidate_gen.ProblemSize( candidate_gen.MatmulSize(2048, 1280, 1280), @@ -303,7 +306,7 @@ def test_get_compatible_mfma_intrinsics(): ] -def test_generate_solutions(): +def test_generate_solutions() -> None: matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280) lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16) rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16) @@ -315,7 +318,7 @@ def test_generate_solutions(): assert configs is not None -def test_calculate_shared_memory_usage_in_bytes(): +def test_calculate_shared_memory_usage_in_bytes() -> None: matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) @@ -347,7 +350,7 @@ def test_calculate_shared_memory_usage_in_bytes(): ) -def test_generate_constraints_valid_input(): +def test_generate_constraints_valid_input() -> None: matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) @@ -392,7 +395,7 @@ def test_generate_constraints_valid_input(): assert solver.check() == candidate_gen.z3.sat -def test_generate_constraints_invalid_input(): +def test_generate_constraints_invalid_input() -> None: # Define input parameters that should lead to unsatisfiable constraints matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) @@ -444,7 +447,7 @@ def remove_comments(mlir: str) -> str: ) -def test_apply_params_mmt(): +def test_apply_params_mmt() -> None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", " None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", " None: mlir_template = [ ", subgroup_m_count = 2, subgroup_n_count = 2>}>", " None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: mlir_lines = [ r"%18 = tensor.empty() : tensor<2x1024x10240xi32>", r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>", @@ -861,18 +864,17 @@ def test_detect_broadcast_rhs_mmt(): ) -def test_parse_mlir(): - mlir_str = r""" - builtin.module { - func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> - return %0 : tensor<4xf32> - } - } - """ - mlir_module = candidate_gen.parse_mlir(mlir_str) - assert mlir_module != None - assert isinstance(mlir_module, candidate_gen.ireec._mlir_libs._mlir.ir.Module) - assert isinstance( - mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp - ) +def test_parse_mlir() -> None: + with ir.Context() as ctx: + mlir_str = r""" + builtin.module { + func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> + return %0 : tensor<4xf32> + } + } + """ + mlir_module = candidate_gen.parse_mlir(mlir_str, ctx) + assert mlir_module is not None + assert isinstance(mlir_module, ir.Module) + assert isinstance(mlir_module.body.operations[0], func.FuncOp) diff --git a/tuner/tuner/py.typed b/tuner/tuner/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/tuner/version.json b/tuner/version.json new file mode 100644 index 000000000..794a2de28 --- /dev/null +++ b/tuner/version.json @@ -0,0 +1,3 @@ +{ + "package-version": "2.9.1.dev" +} From 9d45921db43e8d2f8cef0a5473a082e5d0d8d92d Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Thu, 14 Nov 2024 08:09:49 -0800 Subject: [PATCH 31/59] Iterate on root README.md. (#505) Progress on https://github.com/nod-ai/SHARK-Platform/issues/359 * Relax warning at the top of the file about project status then add a new warning about sharktank status * Swap shortfin and sharktank ordering under sub-projects (no longer alphabetical, but lead with what is most user-facing) * Reword developer instructions * Fix nightly IREE install instructions to use `--pre` --- README.md | 88 +++++++++++++++++++++++----------------- docs/nightly_releases.md | 10 ++--- 2 files changed, 55 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index aa4c46bdc..d187c23a0 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,28 @@ # SHARK Modeling and Serving Libraries -**WARNING: This is an early preview that is in progress. It is not ready for -general use.** +> [!IMPORTANT] +> Development is still in progress for several project components. See the +> notes below for which workflows are best supported. ![GitHub License](https://img.shields.io/github/license/nod-ai/SHARK-Platform) - [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit)](https://github.com/pre-commit/pre-commit) +[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit)](https://github.com/pre-commit/pre-commit) ## Sub-projects +### [`shortfin/`](./shortfin/) + + + +[![PyPI version](https://badge.fury.io/py/shortfin.svg)](https://badge.fury.io/py/shortfin) [![CI - shortfin](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci_linux_x64-libshortfin.yml/badge.svg?event=push)](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci_linux_x64-libshortfin.yml?query=event%3Apush) + +The shortfin sub-project is SHARK's high performance inference library and +serving engine. + +* API documentation for shortfin is available on + [readthedocs](https://shortfin.readthedocs.io/en/latest/). + ### [`sharktank/`](./sharktank/) [![PyPI version](https://badge.fury.io/py/sharktank.svg)](https://badge.fury.io/py/sharktank) [![CI - sharktank](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci-sharktank.yml/badge.svg?event=push)](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci-sharktank.yml?query=event%3Apush) @@ -17,6 +30,11 @@ general use.** The SHARK Tank sub-project contains a collection of model recipes and conversion tools to produce inference-optimized programs. +> [!WARNING] +> SHARK Tank is still under development. Experienced users may want to try it +> out, but we currently recommend most users download pre-exported or +> pre-compiled model files for serving with shortfin. + * See the [SHARK Tank Programming Guide](./docs/programming_guide.md) for @@ -25,18 +43,6 @@ conversion tools to produce inference-optimized programs. * See [Direct Quantization with SHARK Tank](./docs/quantization.md) for information about quantization support. -### [`shortfin/`](./shortfin/) - - - -[![PyPI version](https://badge.fury.io/py/shortfin.svg)](https://badge.fury.io/py/shortfin) [![CI - shortfin](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci_linux_x64-libshortfin.yml/badge.svg?event=push)](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci_linux_x64-libshortfin.yml?query=event%3Apush) - -The shortfin sub-project is SHARK's high performance inference library and -serving engine. - -* API documentation for shortfin is available on - [readthedocs](https://shortfin.readthedocs.io/en/latest/). - ### [`tuner/`](./tuner/) [![CI - Tuner](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci-tuner.yml/badge.svg?event=push)](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci-tuner.yml?query=event%3Apush) @@ -55,20 +61,19 @@ Model name | Model recipes | Serving apps SDXL | [`sharktank/sharktank/models/punet/`](https://github.com/nod-ai/SHARK-Platform/tree/main/sharktank/sharktank/models/punet) | [`shortfin/python/shortfin_apps/sd/`](https://github.com/nod-ai/SHARK-Platform/tree/main/shortfin/python/shortfin_apps/sd) llama | [`sharktank/sharktank/models/llama/`](https://github.com/nod-ai/SHARK-Platform/tree/main/sharktank/sharktank/models/llama) | [`shortfin/python/shortfin_apps/llm/`](https://github.com/nod-ai/SHARK-Platform/tree/main/shortfin/python/shortfin_apps/llm) -## Development getting started +## Development tips - - -Use this as a guide to get started developing the project using pinned, -pre-release dependencies. You are welcome to deviate as you see fit, but -these canonical directions mirror what the CI does. +Each sub-project has its own developer guide. If you would like to work across +projects, these instructions should help you get started: ### Setup a venv -We recommend setting up a virtual environment (venv). The project is configured -to ignore `.venv` directories, and editors like VSCode pick them up by default. +We recommend setting up a Python +[virtual environment (venv)](https://docs.python.org/3/library/venv.html). +The project is configured to ignore `.venv` directories, and editors like +VSCode pick them up by default. -``` +```bash python -m venv .venv source .venv/bin/activate ``` @@ -76,35 +81,42 @@ source .venv/bin/activate ### Install PyTorch for your system If no explicit action is taken, the default PyTorch version will be installed. -This will give you a current CUDA-based version. Install a different variant -by doing so explicitly first: +This will give you a current CUDA-based version, which takes longer to download +and includes other dependencies that SHARK does not require. To install a +different variant, run one of these commands first: -*CPU:* +* *CPU:* -``` -pip install -r pytorch-cpu-requirements.txt -``` + ```bash + pip install -r pytorch-cpu-requirements.txt + ``` -*ROCM:* +* *ROCM:* -``` -pip install -r pytorch-rocm-requirements.txt -``` + ```bash + pip install -r pytorch-rocm-requirements.txt + ``` + +* *Other:* see instructions at . ### Install development packages -``` +```bash # Install editable local projects. pip install -r requirements.txt -e sharktank/ shortfin/ -# Optionally clone and install editable iree-turbine dep in deps/ -pip install -f https://iree.dev/pip-release-links.html --src deps \ +# Optionally clone and install the latest editable iree-turbine dep in deps/, +# along with nightly versions of iree-base-compiler and iree-base-runtime. +pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler iree-base-runtime --src deps \ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" ``` +See also: [`docs/nightly_releases.md`](./docs/nightly_releases.md). + ### Running tests -``` +```bash pytest sharktank pytest shortfin ``` diff --git a/docs/nightly_releases.md b/docs/nightly_releases.md index 706d6e755..819e22f61 100644 --- a/docs/nightly_releases.md +++ b/docs/nightly_releases.md @@ -67,7 +67,7 @@ python3.11 -m venv 3.11.venv source 3.11.venv/bin/activate # Install 'sharktank' package from nightly releases. -python -m pip install sharktank -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels +pip install sharktank -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels # Test the installation. python -c "from sharktank import ops; print('Sanity check passed')" @@ -84,7 +84,7 @@ python3.11 -m venv 3.11.venv source 3.11.venv/bin/activate # Install 'shortfin' package from nightly releases. -python -m pip install shortfin -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels +pip install shortfin -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels # Test the installation. python -c "import shortfin as sf; print('Sanity check passed')" @@ -98,7 +98,7 @@ deactivate To install the `iree-turbine` package from the latest source: ```bash -python -m pip install --src deps \ +pip install --src deps \ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" ``` @@ -106,14 +106,14 @@ To install the `iree-base-compiler` and `iree-base-runtime` packages from nightly releases: ```bash -python -m pip install -f https://iree.dev/pip-release-links.html --upgrade \ +pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ iree-base-compiler iree-base-runtime ``` To install all three packages together: ```bash -python -m pip install -f https://iree.dev/pip-release-links.html --upgrade \ +pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ iree-base-compiler iree-base-runtime --src deps \ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" ``` From f429c91b4247c73f30585c9453d9444812d9bd51 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Thu, 14 Nov 2024 08:18:26 -0800 Subject: [PATCH 32/59] Fix workflows installing nightly IREE packages to use `--pre`. (#506) Pip will only install "pre-release" versions if `--pre` is set or the version is set explicitly: https://pip.pypa.io/en/stable/cli/pip_install/#pre-release-versions. The `rc` part of our `rcYYYYMMDD` suffix used in the nightly IREE packages is considered a "Pre-release segment" in the version string: https://packaging.python.org/en/latest/specifications/version-specifiers/#public-version-identifiers. This should change workflow jobs like https://github.com/nod-ai/SHARK-Platform/actions/runs/11827078971/job/32954462017#step:5:163 that have been downloading `iree-base-compiler-2.9.0 iree-base-runtime-2.9.0` to downloading the latest nightly versions instead, as they intend: https://github.com/nod-ai/SHARK-Platform/blob/51cf2f45139f67e15434990173b49375998db9f3/.github/workflows/ci-sharktank.yml#L64-L67 Co-authored-by: Marius Brehler --- .github/workflows/ci-shark-platform.yml | 2 +- .github/workflows/ci-sharktank.yml | 2 +- .github/workflows/ci-tuner.yml | 2 +- .github/workflows/ci_eval.yaml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci-shark-platform.yml b/.github/workflows/ci-shark-platform.yml index 445e2e448..06a05e3cf 100644 --- a/.github/workflows/ci-shark-platform.yml +++ b/.github/workflows/ci-shark-platform.yml @@ -67,7 +67,7 @@ jobs: # Try with the latest IREE nightly releases, not what iree-turbine pins. # We could also pin to a known working or stable version. # This should eventually stabilize. Do the best we can for now. - pip install -f https://iree.dev/pip-release-links.html --upgrade \ + pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ iree-base-compiler \ iree-base-runtime \ "numpy<2.0" diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml index 4c660e6ee..7d6a7b7f1 100644 --- a/.github/workflows/ci-sharktank.yml +++ b/.github/workflows/ci-sharktank.yml @@ -62,7 +62,7 @@ jobs: pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/ # Update to the latest iree packages. - pip install -f https://iree.dev/pip-release-links.html --upgrade \ + pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ iree-base-compiler iree-base-runtime --src deps \ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" diff --git a/.github/workflows/ci-tuner.yml b/.github/workflows/ci-tuner.yml index cd9a48d5e..dad766d8e 100644 --- a/.github/workflows/ci-tuner.yml +++ b/.github/workflows/ci-tuner.yml @@ -49,7 +49,7 @@ jobs: pip install -r tuner/requirements-tuner.txt python -m pip install \ --find-links https://iree.dev/pip-release-links.html \ - --upgrade \ + --upgrade --pre \ iree-base-compiler iree-base-runtime - name: Run tuner tests diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index 7afaeb1fe..105f53ad8 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -69,7 +69,7 @@ jobs: # Try with the latest IREE nightly releases, not what iree-turbine pins. # We could also pin to a known working or stable version. # This should eventually stabilize. Do the best we can for now. - pip install -f https://iree.dev/pip-release-links.html --upgrade \ + pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ iree-base-compiler \ iree-base-runtime \ "numpy<2.0" From fa122befeb935ddf3f5d552aee298e889172b3a5 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Thu, 14 Nov 2024 10:45:22 -0600 Subject: [PATCH 33/59] (shortfin-sd) Cleanup args and instructions. (#511) --- shortfin/python/shortfin_apps/sd/README.md | 13 ++++++++++++- shortfin/python/shortfin_apps/sd/server.py | 6 ------ shortfin/python/shortfin_apps/sd/simple_client.py | 4 +++- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/shortfin/python/shortfin_apps/sd/README.md b/shortfin/python/shortfin_apps/sd/README.md index 30002ec40..0bb5db511 100644 --- a/shortfin/python/shortfin_apps/sd/README.md +++ b/shortfin/python/shortfin_apps/sd/README.md @@ -13,6 +13,13 @@ pip install pillow pip install shark-ai ``` + +Temporarily, you may need an update to your `shortfin` install. +Install the latest pre-release with: +``` +pip install shortfin --upgrade --pre -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels +``` + ``` python -m shortfin_apps.sd.server --help ``` @@ -27,7 +34,11 @@ You can check if this (or any) port is in use on Linux with `ss -ntl | grep 8000 ``` python -m shortfin_apps.sd.server --device=amdgpu --device_ids=0 --build_preference=precompiled --topology="spx_single" ``` - + - Wait until your server outputs: +``` +INFO - Application startup complete. +INFO - Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) +``` - Run a CLI client in a separate shell: ``` python -m shortfin_apps.sd.simple_client --interactive diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py index 2b7a93a91..9cd624241 100644 --- a/shortfin/python/shortfin_apps/sd/server.py +++ b/shortfin/python/shortfin_apps/sd/server.py @@ -242,12 +242,6 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) parser.add_argument("--port", type=int, default=8000) - parser.add_argument( - "--root-path", - type=str, - default=None, - help="Root path to use for installing behind path based proxy.", - ) parser.add_argument( "--timeout-keep-alive", type=int, default=5, help="Keep alive timeout" ) diff --git a/shortfin/python/shortfin_apps/sd/simple_client.py b/shortfin/python/shortfin_apps/sd/simple_client.py index 550fd7c60..bc0f10655 100644 --- a/shortfin/python/shortfin_apps/sd/simple_client.py +++ b/shortfin/python/shortfin_apps/sd/simple_client.py @@ -119,7 +119,9 @@ async def static(args): if not any([i is None for i in [latencies, sample_counts]]): total_num_samples = sum(sample_counts) sps = str(total_num_samples / (end - start)) - print(f"Average throughput: {sps} samples per second") + # Until we have better measurements, don't report the throughput that includes saving images. + if not args.save: + print(f"Average throughput: {sps} samples per second") else: raise ValueError("Received error response from server.") From 3e7429544fd81938e2b1c0d8fbb098b1c871ecf7 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Thu, 14 Nov 2024 09:12:02 -0800 Subject: [PATCH 34/59] Iterate on shortfin/README.md. (#501) Progress on https://github.com/nod-ai/SHARK-Platform/issues/359. This page will appear at https://pypi.org/project/shortfin/, so include enough context for users. I'm still deciding where it makes sense to highlight current support status for platforms, operating systems, models, etc. --------- Co-authored-by: Marius Brehler --- shortfin/README.md | 90 +++++++++++++++++++++++++++++----------------- 1 file changed, 58 insertions(+), 32 deletions(-) diff --git a/shortfin/README.md b/shortfin/README.md index 9818e05d3..13ee20966 100644 --- a/shortfin/README.md +++ b/shortfin/README.md @@ -1,20 +1,52 @@ # shortfin - SHARK inference library and serving engine -## Simple User Installation +The shortfin project is SHARK's open source, high performance inference library +and serving engine. Shortfin consists of these major components: -Install: +* The "libshortfin" inference library written in C/C++ and built on + [IREE](https://github.com/iree-org/iree) +* Python bindings for the underlying inference library +* Example applications in + ['shortfin_apps'](https://github.com/nod-ai/SHARK-Platform/tree/main/shortfin/python/shortfin_apps) + built using the python bindings +## Prerequisites + +* Python 3.11+ + +## Simple user installation + +Install the latest stable version: + +```bash +pip install shortfin ``` -python -m pip install . + +## Developer guides + +### Quick start: install local packages and run tests + +After cloning this repository, from the `shortfin/` directory: + +```bash +pip install -e . ``` -Run tests: +Install test requirements: +```bash +pip install -r requirements-tests.txt ``` -python -m pytest -s tests/ + +Run tests: + +```bash +pytest -s tests/ ``` -## Simple Dev Setup +### Simple dev setup + +We recommend this development setup for core contributors: 1. Check out this repository as a sibling to [IREE](https://github.com/iree-org/iree) if you already have an IREE source checkout. Otherwise, a pinned version will @@ -36,7 +68,7 @@ python -m pytest -s tests/ Refer to the advanced build options below for other scenarios. -## Advanced Build Options +### Advanced build options 1. Native C++ build 2. Local Python release build @@ -48,7 +80,7 @@ Prerequisites * A modern C/C++ compiler, such as clang 18 or gcc 12 * A modern Python, such as Python 3.12 -### Native C++ Builds +#### Native C++ builds ```bash cmake -GNinja -S. -Bbuild \ @@ -61,13 +93,7 @@ If Python bindings are enabled in this mode (`-DSHORTFIN_BUILD_PYTHON_BINDINGS=O then `pip install -e build/` will install from the build dir (and support build/continue). -### Local Python Release Builds - -```bash -pip install -v -e . -``` - -### Package Python Release Builds +#### Package Python release builds * To build wheels for Linux using a manylinux Docker container: @@ -86,7 +112,7 @@ pip install -v -e . python3 -m pip install dist/*.whl ``` -### Python Dev Builds +#### Python dev builds ```bash # Install build system pre-reqs (since we are building in dev mode, this @@ -124,7 +150,7 @@ Several optional environment variables can be used with setup.py: * `SHORTFIN_RUN_CTESTS=ON` : Runs `ctest` as part of the build. Useful for CI as it uses the version of ctest installed in the pip venv. -### Running Tests +### Running tests The project uses a combination of ctest for native C++ tests and pytest. Much of the functionality is only tested via the Python tests, using the @@ -156,7 +182,7 @@ pytest tests/ --system amdgpu \ --compile-flags="--iree-hal-target-backends=rocm --iree-hip-target=gfx1100" ``` -# Production Library Building +## Production library building In order to build a production library, additional build steps are typically recommended: @@ -167,23 +193,23 @@ recommended: * Enable LTO builds of libshortfin * Set flags to enable symbol versioning -# Miscellaneous Build Topics +## Miscellaneous build topics -## Free-threaded Python +### Free-threaded Python Support for free-threaded Python builds (aka. "nogil") is in progress. It -is currently being tested via dev builds of CPython 3.13 with the -`--disable-gil` option set. There are multiple ways to acquire such an -environment. If using `pyenv`, here is a way: +is currently being tested via CPython 3.13 with the `--disable-gil` option set. +There are multiple ways to acquire such an environment: -``` -# Build a free-threaded 3.13 version. -pyenv install --debug 3.13t-dev +* Generally, see the documentation at + +* If using `pyenv`: -# Test (should print "False"). -pyenv shell 3.13t-dev -python -c 'import sys; print(sys._is_gil_enabled())' -``` + ```bash + # Install a free-threaded 3.13 version. + pyenv install 3.13t -Further ways of installing a free-threaded CPython interpreter are documented at -[py-free-threading.github.io](https://py-free-threading.github.io/installing_cpython/). + # Test (should print "False"). + pyenv shell 3.13t + python -c 'import sys; print(sys._is_gil_enabled())' + ``` From 4fcf77b7ed9c5e72fa302152e24406510924c47e Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Thu, 14 Nov 2024 18:26:43 +0100 Subject: [PATCH 35/59] [sharktank] Unpin NumPy on non-Windows (#496) --- .github/workflows/ci-llama-large-tests.yaml | 3 +-- .github/workflows/ci-llama-quick-tests.yaml | 3 +-- .github/workflows/ci-shark-platform.yml | 3 +-- .github/workflows/ci_eval.yaml | 3 +-- sharktank/requirements.txt | 3 ++- 5 files changed, 6 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci-llama-large-tests.yaml b/.github/workflows/ci-llama-large-tests.yaml index d79031b8c..5645efd8a 100644 --- a/.github/workflows/ci-llama-large-tests.yaml +++ b/.github/workflows/ci-llama-large-tests.yaml @@ -71,8 +71,7 @@ jobs: # Test with pinned nightly releases, not what iree-turbine uses. pip install -f https://iree.dev/pip-release-links.html --upgrade \ iree-base-compiler==2.9.0rc20241108 \ - iree-base-runtime==2.9.0rc20241108 \ - "numpy<2.0" + iree-base-runtime==2.9.0rc20241108 - name: Run llama tests run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-all-llama --iree-hip-target=gfx942 --html=out/index.html diff --git a/.github/workflows/ci-llama-quick-tests.yaml b/.github/workflows/ci-llama-quick-tests.yaml index decd0aa96..585a759ac 100644 --- a/.github/workflows/ci-llama-quick-tests.yaml +++ b/.github/workflows/ci-llama-quick-tests.yaml @@ -72,8 +72,7 @@ jobs: # Test with pinned nightly releases, not what iree-turbine uses. pip install -f https://iree.dev/pip-release-links.html --upgrade \ iree-base-compiler==2.9.0rc20241108 \ - iree-base-runtime==2.9.0rc20241108 \ - "numpy<2.0" + iree-base-runtime==2.9.0rc20241108 - name: Run llama 8b tests run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --iree-hip-target=gfx942 --run-8b-llama diff --git a/.github/workflows/ci-shark-platform.yml b/.github/workflows/ci-shark-platform.yml index 06a05e3cf..6741f7ea0 100644 --- a/.github/workflows/ci-shark-platform.yml +++ b/.github/workflows/ci-shark-platform.yml @@ -69,8 +69,7 @@ jobs: # This should eventually stabilize. Do the best we can for now. pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ iree-base-compiler \ - iree-base-runtime \ - "numpy<2.0" + iree-base-runtime - name: Run LLM Integration Tests run: pytest -v build_tools/integration_tests/llm --log-cli-level=INFO diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml index 105f53ad8..856d37c40 100644 --- a/.github/workflows/ci_eval.yaml +++ b/.github/workflows/ci_eval.yaml @@ -71,8 +71,7 @@ jobs: # This should eventually stabilize. Do the best we can for now. pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ iree-base-compiler \ - iree-base-runtime \ - "numpy<2.0" + iree-base-runtime - name: Run perplexity test with vmfb run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_vmfb_test.py --longrun --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json diff --git a/sharktank/requirements.txt b/sharktank/requirements.txt index dd8f14fb6..19e48f825 100644 --- a/sharktank/requirements.txt +++ b/sharktank/requirements.txt @@ -2,7 +2,8 @@ iree-turbine # Runtime deps. gguf==0.6.0 -numpy==1.26.3 +numpy==1.26.3; sys_platform == 'win32' +numpy; sys_platform != 'win32' # Needed for newer gguf versions (TODO: remove when gguf package includes this) # sentencepiece>=0.1.98,<=0.2.0 From 7cc1db3338985722abeb4abcc52ba3101f47811a Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Thu, 14 Nov 2024 19:00:27 +0100 Subject: [PATCH 36/59] Specify optional dependencies for shortfin_apps (#513) --- shark-ai/pyproject.toml | 3 +++ shortfin/pyproject.toml | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/shark-ai/pyproject.toml b/shark-ai/pyproject.toml index 5a7493ec7..f78a1641f 100644 --- a/shark-ai/pyproject.toml +++ b/shark-ai/pyproject.toml @@ -30,6 +30,9 @@ Repository = "https://github.com/nod-ai/SHARK-Platform" onnx = [ "iree-base-compiler[onnx]", ] +apps = [ + "shortfin[apps]", +] [tool.setuptools] packages = [] diff --git a/shortfin/pyproject.toml b/shortfin/pyproject.toml index 15bd68732..67483fb05 100644 --- a/shortfin/pyproject.toml +++ b/shortfin/pyproject.toml @@ -34,6 +34,13 @@ dynamic = ["version"] Repository = "https://github.com/nod-ai/SHARK-Platform" Documentation = "https://shortfin.readthedocs.io/en/latest/" +[project.optional-dependencies] +apps = [ + "transformers", + "dataclasses-json", + "pillow", +] + [tool.pytest.ini_options] addopts = [ "-ra", From a786114dac996cb21f59bfd383ba284ee9b9803f Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Thu, 14 Nov 2024 19:07:39 +0100 Subject: [PATCH 37/59] [shark-ai] Temporarily drop sharktank dependency (#514) This drops the strict dependency to sharktank which otherwise always gets pulled in when installing the shark-ai package. --- .../python_deploy/write_requirements.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/build_tools/python_deploy/write_requirements.py b/build_tools/python_deploy/write_requirements.py index a89b74dfe..38ae5d2b3 100644 --- a/build_tools/python_deploy/write_requirements.py +++ b/build_tools/python_deploy/write_requirements.py @@ -61,13 +61,14 @@ def write_requirements(requirements): stable_packages_list = ["iree-base-compiler", "iree-base-runtime", "iree-turbine"] if Version(PACKAGE_VERSION).is_prerelease: - requirements = ( - "sharktank==" - + Version(SHARKTANK_PACKAGE_VERSION).base_version - + "rc" - + args.version_suffix - + "\n" - ) + # TODO: Include sharktank as a dependencies of future releases + # requirements = ( + # "sharktank==" + # + Version(SHARKTANK_PACKAGE_VERSION).base_version + # + "rc" + # + args.version_suffix + # + "\n" + # ) requirements += ( "shortfin==" + Version(SHORTFIN_PACKAGE_VERSION).base_version @@ -86,9 +87,10 @@ def write_requirements(requirements): requirements = "" for package in stable_packages_list: requirements += package + "==" + STABLE_VERSION_TO_PIN + "\n" - requirements += ( - "sharktank==" + Version(SHARKTANK_PACKAGE_VERSION).base_version + "\n" - ) + # TODO: Include sharktank as a dependencies of future releases + # requirements += ( + # "sharktank==" + Version(SHARKTANK_PACKAGE_VERSION).base_version + "\n" + # ) requirements += "shortfin==" + Version(SHORTFIN_PACKAGE_VERSION).base_version write_requirements(requirements) From bfc7738a8895d88a2edadf7dcadc7e87cd265d70 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Thu, 14 Nov 2024 14:16:48 -0500 Subject: [PATCH 38/59] Revert "Rework kv cache to avoid needless reshaping (#450)" (#509) This reverts commit b36ddf7842e7124926b5f99d89e504bcbef93753. The beginning of the fix to the brand new November KV cache issue --- .../sharktank/examples/export_paged_llm_v1.py | 4 - sharktank/sharktank/layers/kv_cache.py | 110 +++++++++++++----- .../layers/paged_llama_attention_block.py | 2 +- sharktank/tests/layers/kv_cache_test.py | 20 ++-- .../layers/sharded_paged_kv_cache_test.py | 16 +-- 5 files changed, 98 insertions(+), 54 deletions(-) diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index 7bf76a2ce..f22b2ccbd 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -214,8 +214,6 @@ def _(model, tokens, seq_lens, seq_block_ids, cs): cache_tensors = repack_cache(cs, cache_shard_dim) - cache_tensors = [model.cache.unflatten_page_table(cache_tensors)] - logits = model.prefill( tokens, attention_mask=attention_mask, @@ -302,8 +300,6 @@ def _( cache_state = repack_cache(cache_state, cache_shard_dim) - cache_state = [model.cache.unflatten_page_table(cache_state)] - logits = model.decode( tokens, attention_mask=attention_mask, diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py index d9ed05f79..c73b7a8f4 100644 --- a/sharktank/sharktank/layers/kv_cache.py +++ b/sharktank/sharktank/layers/kv_cache.py @@ -141,7 +141,7 @@ def read( self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]], *, - dest_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], transformer_block_index: int, seq_len: int, page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None, @@ -150,7 +150,7 @@ def read( Args: state: State struct as returned from allocate(). - dest_partitions: List of cache partitions to read into in-place. + read_into_partitions: List of cache partitions to read into in-place. transformer_block_index: The index of the transformer block accessing the cache. page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids @@ -161,7 +161,7 @@ def read( materializing linearly may not be terribly efficient unless if the compiler can fuse the gather. """ - read_count = len(dest_partitions) + read_count = len(read_into_partitions) reads = [] for i in range(read_count): reads.append( @@ -284,10 +284,6 @@ def unflatten_page_table( """Unflattens the 2D page table to a 6D tensor.""" assert len(state) == 1, f"Expected 1-element state. Got: {len(state)}" page_slab = state[0] - - if len(page_slab.shape) == 6: - return page_slab - if self.shard_count == 1: assert not isinstance(page_slab, SplitPrimitiveTensor) return page_slab.unflatten(1, self.sub_page_dims) @@ -356,7 +352,7 @@ def read( self, state: list[Union[torch.Tensor, SplitPrimitiveTensor]], *, - dest_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], + read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]], transformer_block_index: int, seq_len: int, page_ids: Union[torch.Tensor, ReplicatedTensor], @@ -365,7 +361,7 @@ def read( Args: state: State struct as returned from allocate(). - dest_partitions: List of cache partitions to read into in-place. + read_into_partitions: List of cache partitions to read into in-place. transformer_block_index: The index of the transformer block accessing the cache. page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids @@ -378,9 +374,36 @@ def read( """ page_table = self.unflatten_page_table(state) # 6D - def read_cache_partitions( - into_partitions: List[Union[torch.Tensor, SplitPrimitiveTensor]] + bs, block_seq_len, *_ = page_ids.shape + # Blocks dim 1,2 according to the configured block stride. + blocked_shape = [ + bs, + block_seq_len, + self.block_seq_stride, + self.attn_head_count // self.shard_count, + self.attn_head_dim, + ] + + # Reshape the page cache into sub-blocks so that we can index at the + # granularity of the transformer_block and cache partition. + # This requires us to recompute indices to the sub-block reference + # frame. + # The subblock slab is organized as: + # [page, attn_layer, cache_partition] + # Where the cache line can be 0 (k) or 1 (v). + subblock_table = page_table.flatten(start_dim=0, end_dim=2) + page_stride = self.transformer_block_count * self.cache_partition_count + transformer_block_stride = self.cache_partition_count + base_subblock_ids = page_ids * page_stride + ( + transformer_block_index * transformer_block_stride + ) + + def read_cache_partition( + index: int, into_partition: Union[torch.Tensor, SplitPrimitiveTensor] ): + subblock_ids = ( + (base_subblock_ids + index) if index > 0 else base_subblock_ids + ) # TODO: Potentially clamp all page 0 indices to the mask value. # Or even better, require that the ids are replicated such that access is # legal. @@ -389,16 +412,18 @@ def read_cache_partitions( # copy of the sub-blocks by collapsing the first two dims so we have # a linear list. # TODO: Can be rewritten into inplace with out= on index_select. + selected = ( + ops.index_select(subblock_table, 0, subblock_ids.flatten(0, 1)) + .unflatten(0, blocked_shape[0:2]) + .flatten(1, 2) + ) + # trace_tensor("kv.selected", selected) + into_partition[...] = selected - for i, into_partition in enumerate(into_partitions): - selected = page_table[ - page_ids.flatten(0, 1), transformer_block_index, i - ] - selected = selected.unflatten(0, page_ids.shape).flatten(1, 2) - into_partition[...] = selected + for index, read_into_partition in enumerate(read_into_partitions): + read_cache_partition(index, read_into_partition) - read_cache_partitions(dest_partitions) - return tuple([p[:, :seq_len, :] for p in dest_partitions]) + return tuple([p[:, :seq_len, :] for p in read_into_partitions]) def write_timestep( self, @@ -463,25 +488,46 @@ def write( in-place scatter cannot be fused. """ page_table = self.unflatten_page_table(state) # 6D - _, block_seq_len, *_ = page_ids.shape + + bs, block_seq_len, *_ = page_ids.shape + # Blocks dim 1,2 according to the configured block stride. + blocked_shape = [ + bs, + block_seq_len, + self.block_seq_stride, + self.attn_head_count, + self.attn_head_dim, + ] + + # Reshape the page cache into sub-blocks so that we can index at the + # granularity of the transformer_block and cache partition. + # This requires us to recompute indices to the sub-block reference + # frame. + # The subblock slab is organized as: + # [page, attn_layer, cache_partition] + # Where the cache line can be 0 (k) or 1 (v). + subblock_table = page_table.flatten(start_dim=0, end_dim=2) + page_stride = self.transformer_block_count * self.cache_partition_count + transformer_block_stride = self.cache_partition_count + base_subblock_ids = page_ids * page_stride + ( + transformer_block_index * transformer_block_stride + ) part_block_views = [] - for partition in cache_partitions: + subblock_ids_kv = [] + for index, partition in enumerate(cache_partitions): part_block_view = partition.unflatten( 1, (block_seq_len, self.block_seq_stride) ) - part_block_view = part_block_view.unsqueeze(2) + part_block_view = part_block_view.flatten(0, 1) part_block_views.append(part_block_view) - part_block_view = ops.cat(part_block_views, dim=2) + subblock_ids = ( + (base_subblock_ids + index) if index > 0 else base_subblock_ids + ).flatten(0, 1) + subblock_ids_kv.append(subblock_ids) - page_ids = page_ids.flatten(0, 1) - part_block_view = part_block_view.flatten(0, 1) + subblock_ids = ops.cat(subblock_ids_kv) + part_block_view = ops.cat(part_block_views, dim=0) - page_table.index_put_( - ( - page_ids, - torch.full(page_ids.shape, transformer_block_index, dtype=torch.int64), - ), - part_block_view, - ) + subblock_table.index_copy_(0, subblock_ids, part_block_view) diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 29aaa8d7c..6b460d81b 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -241,7 +241,7 @@ def transact_cache( # Restore from the cache. xk, xv = cache.read( cache_state, - dest_partitions=[ + read_into_partitions=[ xk_temp[:, 0:kv_seq_len, ...], xv_temp[:, 0:kv_seq_len, ...], ], diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py index ab66dd97b..65b42c986 100644 --- a/sharktank/tests/layers/kv_cache_test.py +++ b/sharktank/tests/layers/kv_cache_test.py @@ -56,7 +56,7 @@ def test_direct(): ] read_back = cache.read( allocation, - dest_partitions=read_empty, + read_into_partitions=read_empty, transformer_block_index=1, seq_len=write_seq_length, ) @@ -79,7 +79,7 @@ def test_direct(): ] read_ones = cache.read( allocation, - dest_partitions=read_ones, + read_into_partitions=read_ones, transformer_block_index=i, seq_len=write_seq_length, ) @@ -113,7 +113,7 @@ def test_direct(): ] read_back = cache.read( allocation, - dest_partitions=read_empty, + read_into_partitions=read_empty, transformer_block_index=1, seq_len=write_seq_length + 1, ) @@ -184,7 +184,7 @@ def test_sharded_direct(): ] read_back = cache.read( allocation, - dest_partitions=read_empty, + read_into_partitions=read_empty, transformer_block_index=1, seq_len=write_seq_length, ) @@ -225,7 +225,7 @@ def test_sharded_direct(): ] read_back = cache.read( allocation, - dest_partitions=read_empty, + read_into_partitions=read_empty, transformer_block_index=1, seq_len=write_seq_length + 1, ) @@ -288,7 +288,7 @@ def test_paged(): ] read_back = cache.read( allocation, - dest_partitions=read_empty, + read_into_partitions=read_empty, transformer_block_index=1, seq_len=write_seq_length, page_ids=write_page_ids, @@ -312,7 +312,7 @@ def test_paged(): ] read_ones = cache.read( allocation, - dest_partitions=read_ones, + read_into_partitions=read_ones, transformer_block_index=i, seq_len=write_seq_length, page_ids=write_page_ids, @@ -348,7 +348,7 @@ def test_paged(): ] read_back = cache.read( allocation, - dest_partitions=read_empty, + read_into_partitions=read_empty, transformer_block_index=1, seq_len=write_seq_length + 1, page_ids=page_ids, @@ -436,7 +436,7 @@ def test_sharded_paged(): read_back = cache.read( allocation, - dest_partitions=read_empty, + read_into_partitions=read_empty, transformer_block_index=1, seq_len=write_seq_length, page_ids=write_page_ids, @@ -489,7 +489,7 @@ def test_sharded_paged(): read_back = cache.read( allocation, - dest_partitions=[empty_k, empty_v], + read_into_partitions=[empty_k, empty_v], transformer_block_index=1, seq_len=write_seq_length + 1, page_ids=page_ids, diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py index 766c4e804..d7b6a0b33 100644 --- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py +++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py @@ -104,7 +104,7 @@ def testRead(self): sharded_cache_state, ) = self.make_unsharded_and_sharded_equal_cache_states() - dest_partitions_snapshot = [ + read_into_partitions_snapshot = [ torch.rand( self.batch_size, self.block_seq_len * self.block_seq_stride, @@ -113,33 +113,35 @@ def testRead(self): ) for _ in range(self.cache_partition_count) ] - dest_partitions = deepcopy(dest_partitions_snapshot) + read_into_partitions = deepcopy(read_into_partitions_snapshot) transformer_block_index = 1 page_ids = torch.randint( low=0, high=self.page_count, size=[self.batch_size, self.block_seq_len] ).reshape([self.batch_size, self.block_seq_len]) self.cache.read( state=cache_state, - dest_partitions=dest_partitions, + read_into_partitions=read_into_partitions, transformer_block_index=transformer_block_index, page_ids=page_ids, seq_len=self.block_seq_len * self.block_seq_stride, ) - sharded_dest_partitions = deepcopy( + sharded_read_into_partitions = deepcopy( [ ops.reshard_split(t, dim=2, count=self.shard_count) - for t in dest_partitions_snapshot + for t in read_into_partitions_snapshot ] ) sharded_page_ids = ops.replicate(page_ids, count=self.shard_count) self.sharded_cache.read( state=sharded_cache_state, - dest_partitions=sharded_dest_partitions, + read_into_partitions=sharded_read_into_partitions, transformer_block_index=transformer_block_index, page_ids=sharded_page_ids, seq_len=self.block_seq_len * self.block_seq_stride, ) - for unsharded, sharded in zip(dest_partitions, sharded_dest_partitions): + for unsharded, sharded in zip( + read_into_partitions, sharded_read_into_partitions + ): assert ops.equal(unsharded, ops.unshard(sharded)) def testWriteTimestep(self): From e381e871adf03658efae32684f4fbb6025ea6e76 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 14 Nov 2024 12:15:14 -0800 Subject: [PATCH 39/59] Restore RotaryEmbedding use complex numbers (#517) It appears the change away from complex numbers triggered a downstream iree failure. An inflight work to use `flow.tensor.bitcast` and restore to complex numbers appears to fix the issue. Fixing forward as we wanted to revert part of this change anyway. --- sharktank/sharktank/examples/paged_llm_v1.py | 8 + sharktank/sharktank/kernels/__init__.py | 1 + sharktank/sharktank/kernels/bitcast.py | 138 ++++++++++++++++++ .../sharktank/layers/rotary_embedding.py | 81 ++++------ sharktank/sharktank/models/llama/llama.py | 23 --- sharktank/sharktank/ops/custom_impls.py | 19 ++- sharktank/sharktank/ops/default_impls.py | 10 ++ sharktank/sharktank/ops/sharded_impls.py | 24 +++ sharktank/sharktank/ops/signatures.py | 36 +++++ sharktank/sharktank/types/tensors.py | 2 +- 10 files changed, 263 insertions(+), 79 deletions(-) create mode 100644 sharktank/sharktank/kernels/bitcast.py diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index 10c76e644..6d0bfd14c 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -196,6 +196,14 @@ def decode(self): trace_tensor("decode.start_positions", start_positions) trace_tensor("decode.seq_block_ids", seq_block_ids_tensor) trace_tensor("decode.attention_mask", decode_attention_mask) + + if model.config.tensor_parallelism_size != 1: + tp = model.config.tensor_parallelism_size + self.next_tokens = replicate(self.next_tokens, tp) + start_positions = replicate(start_positions, tp) + seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp) + decode_attention_mask = replicate(decode_attention_mask, tp) + logits = model.decode( self.next_tokens, attention_mask=decode_attention_mask, diff --git a/sharktank/sharktank/kernels/__init__.py b/sharktank/sharktank/kernels/__init__.py index beb7e90a2..445f44852 100644 --- a/sharktank/sharktank/kernels/__init__.py +++ b/sharktank/sharktank/kernels/__init__.py @@ -14,3 +14,4 @@ from .conv_2d_nchw_fchw import * from .pooling_nchw_sum import * from .base import * +from .bitcast import * diff --git a/sharktank/sharktank/kernels/bitcast.py b/sharktank/sharktank/kernels/bitcast.py new file mode 100644 index 000000000..66850008f --- /dev/null +++ b/sharktank/sharktank/kernels/bitcast.py @@ -0,0 +1,138 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from sharktank.kernels.base import * + +import torch + +from iree.turbine.support.ir_imports import ( + ComplexType, + F16Type, + F32Type, + RankedTensorType, + ShapedType, + Value, + flow_d, + tensor_d, +) + +from iree.turbine.runtime.op_reg import ( + CustomOp, + KernelBuilder, + KernelSelection, +) + +__all__ = [ + "bitcast_to_complex", + "bitcast_to_real", +] + +_ftype_to_ctype_table = { + torch.float16: torch.complex32, + torch.float32: torch.complex64, +} + +_ctype_to_ftype_table = { + torch.complex32: torch.float16, + torch.complex64: torch.float32, +} + +_type_to_irtype_table = { + torch.float16: lambda: F16Type.get(), + torch.float32: lambda: F32Type.get(), + torch.complex32: lambda: ComplexType.get(F16Type.get()), + torch.complex64: lambda: ComplexType.get(F32Type.get()), +} + + +@CustomOp.register(library=LIBRARY) +class bitcast_to_complex(CustomOp): + + signature = "bitcast_to_complex(Tensor q) -> (Tensor)" + + def select(self, ksel: KernelSelection): + ta = ksel.arg_tensor(0) + + torch._check(ta.t.dtype in _ftype_to_ctype_table) + torch._check(isinstance(ta.t.shape[-1], int)) + + new_shape = [i for i in ta.t.shape] + new_shape[-1] = new_shape[-1] // 2 + + ctype = _ftype_to_ctype_table[ta.t.dtype] + ret = ksel.return_new_tensor(new_shape, dtype=ctype) + specialize_all_known_dims(ta) + specialize_all_known_dims(ret) + + def eager_execute(self, tensor): + return torch.view_as_complex(tensor.unflatten(-1, (-1, 2))) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + t = kb.arg_bindings[0] + result_desc = ksel.result_descs[0] + result_shape = [ + d if isinstance(d, int) else RankedTensorType.get_dynamic_size() + for d in result_desc.t.shape + ] + + dynamic_dims: list[Value] = [] + _append_dynamic_dims(kb, dynamic_dims, t) + + c64 = _type_to_irtype_table[result_desc.t.dtype]() + rtt = RankedTensorType.get(result_shape, c64) + result = flow_d.TensorBitCastOp(rtt, t, dynamic_dims, dynamic_dims).result + kb.yield_results(result) + + +@CustomOp.register(library=LIBRARY) +class bitcast_to_real(CustomOp): + + signature = "bitcast_to_real(Tensor q) -> (Tensor)" + + def select(self, ksel: KernelSelection): + ta = ksel.arg_tensor(0) + + torch._check(ta.t.dtype in _ctype_to_ftype_table) + torch._check(isinstance(ta.t.shape[-1], int)) + + new_shape = [i for i in ta.t.shape] + new_shape[-1] = new_shape[-1] * 2 + + ftype = _ctype_to_ftype_table[ta.t.dtype] + ret = ksel.return_new_tensor(new_shape, dtype=ftype) + specialize_all_known_dims(ta) + specialize_all_known_dims(ret) + + def eager_execute(self, tensor): + return torch.view_as_real(tensor).flatten(-2, -1) + + def generate(self, ksel: KernelSelection, kb: KernelBuilder): + t = kb.arg_bindings[0] + result_desc = ksel.result_descs[0] + result_shape = [ + d if isinstance(d, int) else RankedTensorType.get_dynamic_size() + for d in result_desc.t.shape + ] + + dynamic_dims: list[Value] = [] + _append_dynamic_dims(kb, dynamic_dims, t) + + ftype = _type_to_irtype_table[result_desc.t.dtype]() + rtt = RankedTensorType.get(result_shape, ftype) + result = flow_d.TensorBitCastOp(rtt, t, dynamic_dims, dynamic_dims).result + kb.yield_results(result) + + +################################################################################ +# Emission utilities +################################################################################ + + +def _append_dynamic_dims(kb: KernelBuilder, dynamic_dims: list[Value], tensor: Value): + rtt = RankedTensorType(tensor.type) + for i in range(rtt.rank): + if rtt.is_dynamic_dim(i): + dynamic_dims.append(tensor_d.dim(tensor, kb.constant_index(i))) diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py index c11a2d126..0664a9a46 100644 --- a/sharktank/sharktank/layers/rotary_embedding.py +++ b/sharktank/sharktank/layers/rotary_embedding.py @@ -138,54 +138,29 @@ def create_ordering_tensor(dim): if self.use_hf: xt = xt[..., create_interleaved_tensor(xt.shape[-1])] - xt_ = xt.unflatten(-1, (-1, 2)) - _, sl, _, dim, _ = xt_.shape + xt_ = xt + _, sl, _, _ = xt_.shape # Offset the table based on starting position. if self.use_table: - freqs_cis = rotary_embed_table[:, start_index : start_index + sl, :] + freqs_cis = rotary_embed_table[start_index : start_index + sl, :] + freqs_cis = freqs_cis[None, 0:sl, None, :] else: - freqs_cis = torch.arange(start_index, start_index + sl, device=xt.device) - freqs_cis = self._compute_rotary_embed_table(freqs_cis) + freqs_cis = torch.arange(sl, device=xt.device) + start_index + freqs_cis = self._compute_rotary_embed_table(freqs_cis)[None, :, None, :] - assert freqs_cis.shape[-1] == dim assert ( freqs_cis.shape[1] >= sl ), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})" - broadcast_freqs_cis = freqs_cis[:, None, 0:sl, None, :] - - cos = broadcast_freqs_cis[0] - sin = broadcast_freqs_cis[1] - xt_r = xt_[..., 0] - xt_i = xt_[..., 1] - - xt_out_r = xt_r * cos - xt_i * sin - xt_out_i = xt_i * cos + xt_r * sin - - xt_out = torch.concatenate((xt_out_r, xt_out_i), dim=-1) + xt_ = ops.view_as_complex(xt_) + xt_ = xt_ * freqs_cis + xt_out = ops.view_as_real(xt_) if self.use_hf: xt_out = xt_out[..., create_ordering_tensor(xt_out.shape[-1])] - return xt_out.type_as(xt) - return xt_out.type_as(xt) - - def complex_multiply(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - """Function for elementwise-multiplication of two complex torch tensors. - Functionally similar to a*b, but numerically accurate for HuggingFace - LLaMa implementation. - - Args: - a: First torch tensor operand - b: Second torch tensor operand - Returns: - Tensor of same size to a, b whose elements is product of corresponding - elements in a, b - """ - return torch.complex( - a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real - ) + return ops.to(xt_out, xt.dtype) def compute_batch_mask( self, start_positions: Union[torch.Tensor, ReplicatedTensor], batch_seq_len: int @@ -207,11 +182,18 @@ def compute_batch_mask( self.trace_tensor("rope.positions_seq", positions_seq) if self.use_table: - freqs_cis = self.rotary_embed_table[:, positions_seq] + freqs_cis = self.rotary_embed_table[positions_seq] else: shape = positions_seq.shape - freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten()) - freqs_cis = freqs_cis.unflatten(1, shape) + if isinstance(positions_seq, ReplicatedTensor): + ts = [ + self._compute_rotary_embed_table(s.flatten()).unflatten(0, shape) + for s in positions_seq.shards + ] + freqs_cis = ReplicatedTensor(ts=ts) + else: + freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten()) + freqs_cis = freqs_cis.unflatten(0, shape) # Unsqueeze a unit dim for attention heads. broadcast_freqs_cis = freqs_cis.unsqueeze(2) @@ -247,30 +229,23 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor): """ # xq_, xk_ shape: bs, sl, _, dim # freqs_cis shape: max_sl, dim - cos = mask[0] - sin = mask[1] - - xt_ = xt.unflatten(-1, (-1, 2)) - xt_r = xt_[..., 0] - xt_i = xt_[..., 1] + xt_ = ops.view_as_complex(xt) + xt_ = xt_ * mask + xt_out = ops.view_as_real(xt_) - xt_out_r = xt_r * cos - xt_i * sin - xt_out_i = xt_r * sin + xt_i * cos - xt_out = torch.concatenate((xt_out_r, xt_out_i), dim=-1) return xt_out.type_as(xt) def _compute_rotary_embed_table(self, t): dim = self.rope_dimension_count freqs = 1.0 / ( - self.rope_freq_base - ** (torch.arange(0, dim, 2, device=t.device)[: (dim // 2)].float() / dim) + self.rope_freq_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) ) freqs = torch.outer(t, freqs).float() - cos = torch.cos(freqs).unsqueeze(0) - sin = torch.sin(freqs).unsqueeze(0) - - return torch.concatenate((cos, sin), dim=0) + cos = torch.cos(freqs) + sin = torch.sin(freqs) + complex = torch.complex(cos, sin) + return complex def _create_rotary_embed_table(self): t = torch.arange(self.max_seqlen, device=self.device) diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index ef3c4800d..2ec25e171 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -186,29 +186,6 @@ def decode( self._assert_device(start_positions) self._assert_device(*cache_state, dtype=self.activation_dtype) - if self.config.tensor_parallelism_size > 1: - if not isinstance(tokens, ReplicatedTensor): - tokens = ops.replicate( - tokens, count=self.config.tensor_parallelism_size - ) - if not isinstance(attention_mask, ReplicatedTensor): - attention_mask = ops.replicate( - attention_mask, count=self.config.tensor_parallelism_size - ) - if not isinstance(start_positions, ReplicatedTensor): - start_positions = ops.replicate( - start_positions, count=self.config.tensor_parallelism_size - ) - if not isinstance(seq_block_ids, ReplicatedTensor): - seq_block_ids = ops.replicate( - seq_block_ids, count=self.config.tensor_parallelism_size - ) - # If the user provided unsharded arguments they probably want - # an unsharded result as well. - unshard_result = True - else: - unshard_result = False - bs, _ = tokens.shape # Precompute a position based mask for computing rope embeddings # as it is the same for all blocks. diff --git a/sharktank/sharktank/ops/custom_impls.py b/sharktank/sharktank/ops/custom_impls.py index c5079f6d4..8f6654a8e 100644 --- a/sharktank/sharktank/ops/custom_impls.py +++ b/sharktank/sharktank/ops/custom_impls.py @@ -5,21 +5,24 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import torch + from torch import Tensor, dtype +from typing import Union + import torch.nn.functional as F from ..kernels import ( einsum_2args_q4, mmt_block_scaled_offset_q4_unsigned, mmt_block_scaled_q8, - mmtfp, mmt_super_block_scaled_offset_q4_unsigned, + bitcast_to_complex, + bitcast_to_real, ) from ..types import ( BlockScaledLayout, BlockScaledI4Layout, - InferenceTensor, PrimitiveTensor, QuantizedTensor, SuperBlockOffsetScaled_4_6_Layout, @@ -123,3 +126,15 @@ def matmul_generic_tensor_super_block_offset_scaled_4_6_i4( sb_mins_low, rhs_unpacked.qs_bit_packed, ) + + +@view_as_complex.override(Union[Tensor, PrimitiveTensor]) +def view_as_complex(t): + t = unbox_tensor(t) + return bitcast_to_complex(t) + + +@view_as_real.override(Union[Tensor, PrimitiveTensor]) +def view_as_real(t): + t = unbox_tensor(t) + return bitcast_to_real(t) diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index 08a9c896b..40384b21e 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -503,3 +503,13 @@ def view_QuantizedTensor(tensor: QuantizedTensor, shape): new_m = unpacked.m.view(shape[:-1] + [shape[-1] // 32, 1]) layout = BlockScaledI4Layout(shape=shape, d=new_d, qs=new_qs, m=new_m) return PlanarQuantizedTensor(shape=shape, layout=layout) + + +@view_as_complex.override(Tensor) +def view_as_complex_default(tensor: Union[Tensor, PrimitiveTensor]) -> Tensor: + return torch.view_as_complex(unbox_tensor(tensor)) + + +@view_as_real.override(Tensor) +def view_as_real_default(tensor: Union[Tensor, PrimitiveTensor]) -> Tensor: + return torch.view_as_real(unbox_tensor(tensor)) diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py index 87592c6fd..4aa473e08 100644 --- a/sharktank/sharktank/ops/sharded_impls.py +++ b/sharktank/sharktank/ops/sharded_impls.py @@ -1303,3 +1303,27 @@ def view_split(tensor: SplitPrimitiveTensor, shape: List[int]) -> SplitPrimitive res = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards) assert math.prod(res.shape) == math.prod(tensor.shape) return res + + +@view_as_complex.override(SplitPrimitiveTensor) +def view_as_complex_split(tensor: SplitPrimitiveTensor) -> SplitPrimitiveTensor: + shards = [view_as_complex(shard) for shard in tensor.shards] + return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + + +@view_as_complex.override(ReplicatedTensor) +def view_as_complex_rep(tensor: ReplicatedTensor) -> ReplicatedTensor: + shards = [view_as_complex(shard) for shard in tensor.shards] + return ReplicatedTensor(ts=shards) + + +@view_as_real.override(SplitPrimitiveTensor) +def view_as_real_split(tensor: SplitPrimitiveTensor) -> SplitPrimitiveTensor: + shards = [view_as_real(shard) for shard in tensor.shards] + return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim) + + +@view_as_real.override(ReplicatedTensor) +def view_as_real_rep(tensor: ReplicatedTensor) -> ReplicatedTensor: + shards = [view_as_real(shard) for shard in tensor.shards] + return ReplicatedTensor(ts=shards) diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py index d9002ce37..762b99896 100644 --- a/sharktank/sharktank/ops/signatures.py +++ b/sharktank/sharktank/ops/signatures.py @@ -59,6 +59,8 @@ "unshard", "unsqueeze", "view", + "view_as_complex", + "view_as_real", ] IntOrSequenceInt = Union[int, Sequence[int]] @@ -1087,3 +1089,37 @@ def _view_trampoline( return override, result else: d.fail(tensors) + + +@overridable +def view_as_complex(tensor: AnyTensor, shape: List[int]) -> AnyTensor: + """See torch.Tensor.view_as_complex""" + ... + + +@view_as_complex.trampoline +def _view_as_complex_trampoline(d: SignatureDispatcher, tensor: AnyTensor) -> AnyTensor: + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) + + +@overridable +def view_as_real(tensor: AnyTensor, shape: List[int]) -> AnyTensor: + """See torch.Tensor.view_as_complex""" + ... + + +@view_as_real.trampoline +def _view_as_real_trampoline(d: SignatureDispatcher, tensor: AnyTensor) -> AnyTensor: + tensors = (tensor,) + for override in d.find_overrides(tensors): + result = override(tensor) + if result is not NotImplemented: + return override, result + else: + d.fail(tensors) diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 87a40fb7b..f870aa101 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -1191,7 +1191,7 @@ def create( def __getitem__(self, key): keys = [key] - if isinstance(keys, tuple) or isinstance(keys, list): + if isinstance(key, tuple) or isinstance(key, list): keys = key shards = [] From 754e168992e75a61cfb7caf128b2fc71255af947 Mon Sep 17 00:00:00 2001 From: Stephen Baione <109226581+stbaione@users.noreply.github.com> Date: Thu, 14 Nov 2024 14:44:53 -0600 Subject: [PATCH 40/59] Restore CPU LLM Server integration test (#520) Remove `xfail` from CPU LLM Server Integration test. Shortfin server is back in a good spot, and this should be left on anyways --- build_tools/integration_tests/llm/cpu_llm_server_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/build_tools/integration_tests/llm/cpu_llm_server_test.py b/build_tools/integration_tests/llm/cpu_llm_server_test.py index 638bce7ee..4d4ec5540 100644 --- a/build_tools/integration_tests/llm/cpu_llm_server_test.py +++ b/build_tools/integration_tests/llm/cpu_llm_server_test.py @@ -78,7 +78,6 @@ def do_generate(prompt, port): ], indirect=True, ) -@pytest.mark.xfail(raises=AccuracyValidationException) def test_llm_server(llm_server, available_port): # Here you would typically make requests to your server # and assert on the responses @@ -86,7 +85,6 @@ def test_llm_server(llm_server, available_port): output = do_generate("1 2 3 4 5 ", available_port) logger.info(output) expected_output_prefix = "6 7 8" - # TODO(#437): Remove when accuracy issue from latest iree-compiler RC is resolved. if not output.startswith(expected_output_prefix): raise AccuracyValidationException( f"Expected '{output}' to start with '{expected_output_prefix}'" From e9ba3efe314bd2d83ec11e55f6f176693a220563 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Thu, 14 Nov 2024 13:07:37 -0800 Subject: [PATCH 41/59] Update package versions to 3.0.0. (#521) We just published version 2.9.1 to PyPI, so update the version to 3.0.0 ahead of the next nightly build. Note: updating the version and rebuilding a release is relatively cheap in this project, since the release build action only takes 5 minutes. We can reduce the version if we want, or update it after a nightly build. That is _not_ the case in https://github.com/iree-org/iree, where the release build takes over 4 hours. --- sharktank/version.json | 2 +- shortfin/version.json | 2 +- tuner/version.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sharktank/version.json b/sharktank/version.json index 794a2de28..85afb41ed 100644 --- a/sharktank/version.json +++ b/sharktank/version.json @@ -1,3 +1,3 @@ { - "package-version": "2.9.1.dev" + "package-version": "3.0.0.dev" } diff --git a/shortfin/version.json b/shortfin/version.json index 794a2de28..85afb41ed 100644 --- a/shortfin/version.json +++ b/shortfin/version.json @@ -1,3 +1,3 @@ { - "package-version": "2.9.1.dev" + "package-version": "3.0.0.dev" } diff --git a/tuner/version.json b/tuner/version.json index 794a2de28..85afb41ed 100644 --- a/tuner/version.json +++ b/tuner/version.json @@ -1,3 +1,3 @@ { - "package-version": "2.9.1.dev" + "package-version": "3.0.0.dev" } From 4271115f51609ea8105feb21e9047875a6443340 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Thu, 14 Nov 2024 18:04:45 -0500 Subject: [PATCH 42/59] (shortfin) Fix dev_me.py: --clang arg and mishandling of a pathlib.PosixPath (#523) --- shortfin/dev_me.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/shortfin/dev_me.py b/shortfin/dev_me.py index 9125c8713..8eacca274 100755 --- a/shortfin/dev_me.py +++ b/shortfin/dev_me.py @@ -105,12 +105,10 @@ def find_clang(self, args): clang_exe = shutil.which("clang") if not clang_exe: return None, None - try: - clang_output = subprocess.check_output( - [clang_exe, "--version"] - ).decode() - except: - return None, None + try: + clang_output = subprocess.check_output([clang_exe, "--version"]).decode() + except: + return None, None if m := re.search(r"clang version ([0-9\.]+)", clang_output): return clang_exe, Version(m.group(1)) return None, None @@ -244,7 +242,7 @@ def configure_mode(env_info: EnvInfo, args): "-e", str(env_info.this_dir), ] - print(f"{' '.join('='.join(kv) for kv in env_vars.items())} \\") + print(f"{' '.join('='.join(str(kv)) for kv in env_vars.items())} \\") print(f" {' '.join(setup_args)}") actual_env_vars = dict(os.environ) actual_env_vars.update(env_vars) From 46420ecdebc9d0589ba25713aad22ac8ecc196a2 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 14 Nov 2024 18:07:23 -0500 Subject: [PATCH 43/59] [tuner] Fix typing issues in libtuner. NFC. (#526) Make libtuner and its test type-check with mypy. --- tuner/tuner/libtuner.py | 16 +++++--- tuner/tuner/libtuner_test.py | 78 +++++++++++++++++++++--------------- 2 files changed, 56 insertions(+), 38 deletions(-) diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py index 91c7b417a..3aa932dc4 100644 --- a/tuner/tuner/libtuner.py +++ b/tuner/tuner/libtuner.py @@ -38,7 +38,7 @@ import random import json from abc import ABC, abstractmethod -import iree.runtime as ireert +import iree.runtime as ireert # type: ignore from . import candidate_gen @@ -250,10 +250,11 @@ def get_mean_time_us(self) -> Optional[float]: mean_benchmark = self.find_mean_benchmark(self.result_json) if mean_benchmark: - real_time = mean_benchmark.get("real_time") - time_unit = mean_benchmark.get("time_unit") + real_time: float | None = mean_benchmark.get("real_time") + time_unit: str | None = mean_benchmark.get("time_unit") if real_time is not None: + assert time_unit is not None return self.unit_to_microseconds(real_time, time_unit) return None @@ -549,7 +550,7 @@ def create_worker_context_queue(device_ids: list[int]) -> queue.Queue[tuple[int, return worker_contexts_queue -def run_command(run_pack: RunPack) -> TaskResult: +def run_command(run_pack: RunPack) -> RunResult: command = run_pack.command check = run_pack.check timeout_seconds = run_pack.timeout_seconds @@ -946,6 +947,7 @@ def parse_dispatch_benchmark_results( continue res_json = extract_benchmark_from_run_result(benchmark_result.run_result) + assert res_json is not None res = IREEBenchmarkResult(candidate_id, res_json) benchmark_time = res.get_mean_time_us() assert benchmark_time is not None @@ -985,7 +987,10 @@ def generate_sample_task_result( stdout=stdout, returncode=0, ) - return TaskResult(result=res, candidate_id=candidate_id, device_id=device_id) + run_result = RunResult(res, False) + return TaskResult( + run_result=run_result, candidate_id=candidate_id, device_id=device_id + ) def generate_dryrun_dispatch_benchmark_results( @@ -1235,6 +1240,7 @@ def parse_model_benchmark_results( continue result_json = extract_benchmark_from_run_result(task_result.run_result) + assert result_json is not None res = IREEBenchmarkResult(candidate_id, result_json) benchmark_time = res.get_mean_time_us() assert benchmark_time is not None diff --git a/tuner/tuner/libtuner_test.py b/tuner/tuner/libtuner_test.py index 36bda3bd5..11af59af4 100644 --- a/tuner/tuner/libtuner_test.py +++ b/tuner/tuner/libtuner_test.py @@ -7,6 +7,7 @@ import argparse import pytest import json +from subprocess import CompletedProcess from unittest.mock import call, patch, MagicMock from . import libtuner @@ -15,15 +16,15 @@ """ -def test_group_benchmark_results_by_device_id(): +def test_group_benchmark_results_by_device_id() -> None: # Create mock TaskResult objects with device_id attributes - task_result_1 = MagicMock() + task_result_1: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult) task_result_1.device_id = "device_1" - task_result_2 = MagicMock() + task_result_2: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult) task_result_2.device_id = "device_2" - task_result_3 = MagicMock() + task_result_3: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult) task_result_3.device_id = "device_1" benchmark_results = [task_result_1, task_result_2, task_result_3] @@ -40,7 +41,7 @@ def test_group_benchmark_results_by_device_id(): assert grouped_results[1][0].device_id == "device_2" -def test_find_collisions(): +def test_find_collisions() -> None: input = [(1, "abc"), (2, "def"), (3, "abc")] assert libtuner.find_collisions(input) == (True, [("abc", [1, 3]), ("def", [2])]) input = [(1, "abc"), (2, "def"), (3, "hig")] @@ -50,14 +51,14 @@ def test_find_collisions(): ) -def test_collision_handler(): +def test_collision_handler() -> None: input = [(1, "abc"), (2, "def"), (3, "abc"), (4, "def"), (5, "hig")] assert libtuner.collision_handler(input) == (True, [1, 2, 5]) input = [(1, "abc"), (2, "def"), (3, "hig")] assert libtuner.collision_handler(input) == (False, []) -def test_IREEBenchmarkResult_get(): +def test_IREEBenchmarkResult_get() -> None: # Time is int in us int_json = [{"aggregate_name": "mean", "real_time": 1, "time_unit": "us"}] @@ -108,7 +109,7 @@ def test_IREEBenchmarkResult_get(): assert res.get_mean_time_us() == None # Invalid json: empty dictionary - res = libtuner.IREEBenchmarkResult(candidate_id=8, result_json={}) + res = libtuner.IREEBenchmarkResult(candidate_id=8, result_json=[]) assert res.get_mean_time_us() is None # Invalid json: invalid time unit @@ -131,7 +132,7 @@ def test_IREEBenchmarkResult_get(): assert res.get_mean_time_us() is None -def test_generate_display_BR(): +def test_generate_display_BR() -> None: output = libtuner.generate_display_DBR(1, 3.14) expected = f"1\tMean Time: 3.1" assert output == expected, "DispatchBenchmarkResult generates invalid sample string" @@ -147,29 +148,38 @@ def test_generate_display_BR(): assert output == expected, "ModelBenchmarkResult generates invalid sample string" -def test_parse_dispatch_benchmark_results(): +def make_mock_task_result() -> libtuner.TaskResult: + process: CompletedProcess = MagicMock(spec=CompletedProcess) + run_result = libtuner.RunResult(process, False) + task_result = libtuner.TaskResult(run_result, 0, "") + return task_result + + +def test_parse_dispatch_benchmark_results() -> None: base_path = libtuner.Path("/mock/base/dir") spec_dir = base_path / "specs" path_config = libtuner.PathConfig() object.__setattr__(path_config, "specs_dir", spec_dir) - mock_result_1 = MagicMock() + mock_result_1 = make_mock_task_result() mock_json_1 = { "benchmarks": [ {"aggregate_name": "mean", "real_time": 100.0, "time_unit": "us"} ] } + assert mock_result_1.run_result.process_res is not None mock_result_1.run_result.process_res.stdout = json.dumps(mock_json_1) mock_result_1.candidate_id = 1 - mock_result_2 = MagicMock() + mock_result_2 = make_mock_task_result() mock_json_2 = { "benchmarks": [ {"aggregate_name": "mean", "real_time": 200.0, "time_unit": "us"} ] } + assert mock_result_2.run_result.process_res is not None mock_result_2.run_result.process_res.stdout = json.dumps(mock_json_2) mock_result_2.candidate_id = 2 - mock_result_3 = MagicMock() + mock_result_3 = make_mock_task_result() mock_json_3 = { "benchmarks": [ { @@ -179,11 +189,11 @@ def test_parse_dispatch_benchmark_results(): } ] } + assert mock_result_3.run_result.process_res is not None mock_result_3.run_result.process_res.stdout = json.dumps(mock_json_3) mock_result_3.candidate_id = 3 - mock_result_4 = MagicMock() - mock_result_4.run_result.process_res = None # Incomplete result - mock_result_4.candidate_id = 4 + # Incomplete result. + mock_result_4 = libtuner.TaskResult(libtuner.RunResult(None, True), 4, "4") benchmark_results = [mock_result_1, mock_result_2, mock_result_3, mock_result_4] candidate_trackers = [] @@ -239,7 +249,7 @@ def test_parse_dispatch_benchmark_results(): ) -def test_parse_model_benchmark_results(): +def test_parse_model_benchmark_results() -> None: # Setup mock data for candidate_trackers tracker0 = libtuner.CandidateTracker(0) tracker0.compiled_model_path = libtuner.Path("/path/to/baseline.vmfb") @@ -256,38 +266,40 @@ def test_parse_model_benchmark_results(): candidate_trackers = [tracker0, tracker1, tracker2, tracker3] # Setup mock data for task results - result1 = MagicMock() + result1 = make_mock_task_result() result_json_1 = {"benchmarks": [{"real_time": 1.23}]} + assert result1.run_result.process_res is not None result1.run_result.process_res.stdout = json.dumps(result_json_1) result1.candidate_id = 1 result1.device_id = "device1" - result2 = MagicMock() + result2 = make_mock_task_result() result_json_2 = {"benchmarks": [{"real_time": 4.56}]} + assert result2.run_result.process_res is not None result2.run_result.process_res.stdout = json.dumps(result_json_2) result2.candidate_id = 2 result2.device_id = "device2" - result3 = MagicMock() + result3 = make_mock_task_result() result_json_3 = {"benchmarks": [{"real_time": 0.98}]} + assert result3.run_result.process_res is not None result3.run_result.process_res.stdout = json.dumps(result_json_3) result3.candidate_id = 0 result3.device_id = "device1" - result4 = MagicMock() + result4 = make_mock_task_result() result_json_4 = {"benchmarks": [{"real_time": 4.13}]} + assert result4.run_result.process_res is not None result4.run_result.process_res.stdout = json.dumps(result_json_4) result4.candidate_id = 0 result4.device_id = "device2" # Incomplete baseline on device3 - result5 = MagicMock() - result5.run_result.process_res = None - result5.candidate_id = 0 - result5.device_id = "device3" + result5 = libtuner.TaskResult(libtuner.RunResult(None, True), 0, "device3") - result6 = MagicMock() + result6 = make_mock_task_result() result_json_6 = {"benchmarks": [{"real_time": 3.38}]} + assert result6.run_result.process_res is not None result6.run_result.process_res.stdout = json.dumps(result_json_6) result6.candidate_id = 3 result6.device_id = "device3" @@ -347,14 +359,14 @@ def mock_get_mean_time_us(self): ) -def test_extract_driver_names(): +def test_extract_driver_names() -> None: user_devices = ["hip://0", "local-sync://default", "cuda://default"] expected_output = {"hip", "local-sync", "cuda"} assert libtuner.extract_driver_names(user_devices) == expected_output -def test_fetch_available_devices_success(): +def test_fetch_available_devices_success() -> None: drivers = ["hip", "local-sync", "cuda"] mock_devices = { "hip": [{"path": "ABCD", "device_id": 1}], @@ -384,7 +396,7 @@ def get_mock_driver(name): assert actual_output == expected_output -def test_fetch_available_devices_failure(): +def test_fetch_available_devices_failure() -> None: drivers = ["hip", "local-sync", "cuda"] mock_devices = { "hip": [{"path": "ABCD", "device_id": 1}], @@ -421,7 +433,7 @@ def get_mock_driver(name): ) -def test_parse_devices(): +def test_parse_devices() -> None: user_devices_str = "hip://0, local-sync://default, cuda://default" expected_output = ["hip://0", "local-sync://default", "cuda://default"] @@ -432,7 +444,7 @@ def test_parse_devices(): mock_handle_error.assert_not_called() -def test_parse_devices_with_invalid_input(): +def test_parse_devices_with_invalid_input() -> None: user_devices_str = "hip://0, local-sync://default, invalid_device, cuda://default" expected_output = [ "hip://0", @@ -452,7 +464,7 @@ def test_parse_devices_with_invalid_input(): ) -def test_validate_devices(): +def test_validate_devices() -> None: user_devices = ["hip://0", "local-sync://default"] user_drivers = {"hip", "local-sync"} @@ -469,7 +481,7 @@ def test_validate_devices(): ) -def test_validate_devices_with_invalid_device(): +def test_validate_devices_with_invalid_device() -> None: user_devices = ["hip://0", "local-sync://default", "cuda://default"] user_drivers = {"hip", "local-sync", "cuda"} From 86bd384f163b6ce7ab808b53e1659840591109e8 Mon Sep 17 00:00:00 2001 From: Stephen Baione <109226581+stbaione@users.noreply.github.com> Date: Thu, 14 Nov 2024 17:28:51 -0600 Subject: [PATCH 44/59] Sglang benchmark test (#476) # Description Create a nightly workflow for SGLang Benchmark test that enables running a Shortfin server and benchmarking from SGLang, using the `bench_serving` script. ## `bench_serving` Invocations The bench_serving script is ran with various `request-rate` arguments: - python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer= `--request-rate 1` --output-file /shortfin_10_1.jsonl - python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer= `--request-rate 2` --output-file /shortfin_10_1.jsonl - python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer= `--request-rate 4` --output-file /shortfin_10_1.jsonl - python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer= `--request-rate 8` --output-file /shortfin_10_1.jsonl - python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer= `--request-rate 16` --output-file /shortfin_10_1.jsonl - python -m sglang.bench_serving --backend shortfin --num-prompt 10 --base-url http://localhost:8000 --tokenizer= `--request-rate 32` --output-file /shortfin_10_1.jsonl After the test is finished running, we upload the html output from pytest to gh-pages. The subdirectory is set to `./llm/sglang`, so the results should be accessible from the browser at `/llm/sglang/index.html` in gh-pages. This also includes a refactor of the existing integration test. Most of the methods for downloading a model/tokenizer, exporting to mlir, compiling to vmfb, and starting a shortfin server have been moved to a `utils.py` file. --- .github/workflows/ci-sglang-benchmark.yml | 88 ++++++++ .github/workflows/ci-shark-platform.yml | 2 +- app_tests/__init__.py | 0 app_tests/benchmark_tests/__init__.py | 0 app_tests/benchmark_tests/llm/conftest.py | 47 ++++ .../llm/sglang_benchmark_test.py | 108 +++++++++ app_tests/benchmark_tests/llm/utils.py | 55 +++++ app_tests/integration_tests/__init__.py | 0 app_tests/integration_tests/llm/__init__.py | 0 app_tests/integration_tests/llm/conftest.py | 135 +++++++++++ .../llm/cpu_llm_server_test.py | 2 +- app_tests/integration_tests/llm/utils.py | 180 +++++++++++++++ build_tools/integration_tests/llm/conftest.py | 212 ------------------ build_tools/integration_tests/llm/utils.py | 9 - 14 files changed, 615 insertions(+), 223 deletions(-) create mode 100644 .github/workflows/ci-sglang-benchmark.yml create mode 100644 app_tests/__init__.py create mode 100644 app_tests/benchmark_tests/__init__.py create mode 100644 app_tests/benchmark_tests/llm/conftest.py create mode 100644 app_tests/benchmark_tests/llm/sglang_benchmark_test.py create mode 100644 app_tests/benchmark_tests/llm/utils.py create mode 100644 app_tests/integration_tests/__init__.py create mode 100644 app_tests/integration_tests/llm/__init__.py create mode 100644 app_tests/integration_tests/llm/conftest.py rename {build_tools => app_tests}/integration_tests/llm/cpu_llm_server_test.py (98%) create mode 100644 app_tests/integration_tests/llm/utils.py delete mode 100644 build_tools/integration_tests/llm/conftest.py delete mode 100644 build_tools/integration_tests/llm/utils.py diff --git a/.github/workflows/ci-sglang-benchmark.yml b/.github/workflows/ci-sglang-benchmark.yml new file mode 100644 index 000000000..d890d972c --- /dev/null +++ b/.github/workflows/ci-sglang-benchmark.yml @@ -0,0 +1,88 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +name: SGLang Llama Benchmarking Tests + +on: + workflow_dispatch: + schedule: + # Weekdays at 4:00 AM UTC = 9:00 PM PST. + - cron: "0 4 * * 1-5" + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + sglang_bench_serve: + name: "SGLang Serving Benchmark Tests" + strategy: + matrix: + version: [3.11] + fail-fast: false + runs-on: llama-mi300x-3 + defaults: + run: + shell: bash + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + steps: + - name: Get Current Date + id: date + run: echo "::set-output name=date::$(date +'%Y-%m-%d')" + + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Cache Pip Packages + uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + + - name: Install pip deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-cpu-requirements.txt + pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" + pip install --no-compile -r requirements.txt -e sharktank/ shortfin/ + + # Try with the latest nightly releases, not what iree-turbine pins. + # We could also pin to a known working or stable version. + # This should eventually stabilize. Do the best we can for now. + pip install -f https://iree.dev/pip-release-links.html --upgrade \ + iree-base-compiler==2.9.0rc20241108 \ + iree-base-runtime==2.9.0rc20241108 \ + "numpy<2.0" + + - name: Install SGLang + run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python" + + - name: Launch Shortfin Server + run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html + + - name: Deploy to GitHub Pages + uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 + with: + github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }} + publish_dir: ./out/llm/sglang + destination_dir: ./llm/sglang + keep_files: true diff --git a/.github/workflows/ci-shark-platform.yml b/.github/workflows/ci-shark-platform.yml index 6741f7ea0..dc2f4646a 100644 --- a/.github/workflows/ci-shark-platform.yml +++ b/.github/workflows/ci-shark-platform.yml @@ -72,4 +72,4 @@ jobs: iree-base-runtime - name: Run LLM Integration Tests - run: pytest -v build_tools/integration_tests/llm --log-cli-level=INFO + run: pytest -v app_tests/integration_tests/llm --log-cli-level=INFO diff --git a/app_tests/__init__.py b/app_tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app_tests/benchmark_tests/__init__.py b/app_tests/benchmark_tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app_tests/benchmark_tests/llm/conftest.py b/app_tests/benchmark_tests/llm/conftest.py new file mode 100644 index 000000000..aac66ca0f --- /dev/null +++ b/app_tests/benchmark_tests/llm/conftest.py @@ -0,0 +1,47 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import json +import os +import pytest +import sys + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from integration_tests.llm.utils import compile_model, export_paged_llm_v1 + + +@pytest.fixture(scope="module") +def pre_process_model(request, tmp_path_factory): + tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test") + + model_path = request.param["model_path"] + settings = request.param["settings"] + batch_sizes = request.param["batch_sizes"] + + tmp_dir = tmp_path_factory.mktemp("llm_benchmark_test") + mlir_path = tmp_dir / "model.mlir" + config_path = tmp_dir / "config.json" + vmfb_path = tmp_dir / "model.vmfb" + + export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes) + + config = { + "module_name": "module", + "module_abi_version": 1, + "max_seq_len": 131072, + "attn_head_count": 8, + "attn_head_dim": 128, + "prefill_batch_sizes": batch_sizes, + "decode_batch_sizes": batch_sizes, + "transformer_block_count": 32, + "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, + } + with open(config_path, "w") as file: + json.dump(config, file) + + compile_model(mlir_path, vmfb_path, settings) + + return tmp_dir diff --git a/app_tests/benchmark_tests/llm/sglang_benchmark_test.py b/app_tests/benchmark_tests/llm/sglang_benchmark_test.py new file mode 100644 index 000000000..8027fcea7 --- /dev/null +++ b/app_tests/benchmark_tests/llm/sglang_benchmark_test.py @@ -0,0 +1,108 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import json +import logging +import multiprocessing +import os +from pathlib import Path +import pytest +import time +from unittest.mock import patch + +pytest.importorskip("sglang") +from sglang import bench_serving + +from utils import SGLangBenchmarkArgs + +from integration_tests.llm.utils import ( + find_available_port, + start_llm_server, +) + +logger = logging.getLogger("__name__") + +device_settings = { + "device_flags": [ + "--iree-hal-target-backends=rocm", + "--iree-hip-target=gfx942", + ], + "device": "hip", +} + +# TODO: Download on demand instead of assuming files exist at this path +MODEL_PATH = Path("/data/llama3.1/8b/llama8b_f16.irpa") +TOKENIZER_DIR = Path("/data/llama3.1/8b/") + + +@pytest.mark.parametrize("request_rate", [1, 2, 4, 8, 16, 32]) +@pytest.mark.parametrize( + "pre_process_model", + [ + ( + { + "model_path": MODEL_PATH, + "settings": device_settings, + "batch_sizes": [1, 4], + } + ) + ], + indirect=True, +) +def test_sglang_benchmark_server(request_rate, pre_process_model): + # TODO: Remove when multi-device is fixed + os.environ["ROCR_VISIBLE_DEVICES"] = "1" + + tmp_dir = pre_process_model + + config_path = tmp_dir / "config.json" + vmfb_path = tmp_dir / "model.vmfb" + tokenizer_path = TOKENIZER_DIR / "tokenizer.json" + + # Start shortfin llm server + port = find_available_port() + server_process = start_llm_server( + port, + tokenizer_path, + config_path, + vmfb_path, + MODEL_PATH, + device_settings, + timeout=30, + ) + + # Run and collect SGLang Serving Benchmark + benchmark_args = SGLangBenchmarkArgs( + backend="shortfin", + num_prompt=10, + base_url=f"http://localhost:{port}", + tokenizer=TOKENIZER_DIR, + request_rate=request_rate, + ) + output_file = ( + tmp_dir + / f"{benchmark_args.backend}_{benchmark_args.num_prompt}_{benchmark_args.request_rate}.jsonl" + ) + benchmark_args.output_file = output_file + + logger.info("Running SGLang Benchmark with the following args:") + logger.info(benchmark_args) + try: + start = time.time() + with patch.object(bench_serving, "print", side_effect=logger.info): + benchmark_process = multiprocessing.Process( + target=bench_serving.run_benchmark, + args=(benchmark_args.as_namespace(),), + ) + benchmark_process.start() + benchmark_process.join() + + logger.info(f"Benchmark run completed in {str(time.time() - start)} seconds") + except Exception as e: + logger.info(e) + + server_process.terminate() + server_process.wait() diff --git a/app_tests/benchmark_tests/llm/utils.py b/app_tests/benchmark_tests/llm/utils.py new file mode 100644 index 000000000..c217720cb --- /dev/null +++ b/app_tests/benchmark_tests/llm/utils.py @@ -0,0 +1,55 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from argparse import Namespace +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class SGLangBenchmarkArgs: + base_url: str + num_prompt: int + request_rate: int + tokenizer: str | Path + + seed: int = 1 + extra_request_body: str | None = None + output_file: str | Path | None = None + port: int = 8000 + backend: str = "shortfin" + + def as_namespace(self) -> Namespace: + return Namespace( + num_prompts=self.num_prompt, + base_url=self.base_url, + tokenizer=str(self.tokenizer), + request_rate=self.request_rate, + backend=self.backend, + output_file=self.output_file, + seed=self.seed, + extra_request_body=self.extra_request_body, + port=8000, + model=None, + dataset_name="sharegpt", + random_input_len=None, + random_output_len=None, + dataset_path="", + sharegpt_output_len=None, + multi=False, + disable_tqdm=False, + disable_stream=False, + disable_ignore_eos=False, + ) + + def __repr__(self): + return ( + f"Backend: {self.backend}\n" + f"Base URL: {self.base_url}\n" + f"Num Prompt: {self.num_prompt}\n" + f"Tokenizer: {self.tokenizer}\n" + f"Request Rate: {self.request_rate}" + ) diff --git a/app_tests/integration_tests/__init__.py b/app_tests/integration_tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app_tests/integration_tests/llm/__init__.py b/app_tests/integration_tests/llm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/app_tests/integration_tests/llm/conftest.py b/app_tests/integration_tests/llm/conftest.py new file mode 100644 index 000000000..17cdf1def --- /dev/null +++ b/app_tests/integration_tests/llm/conftest.py @@ -0,0 +1,135 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import json +import logging +import os +from pathlib import Path +import pytest +import shutil + +pytest.importorskip("transformers") +from .utils import ( + download_huggingface_model, + download_tokenizer, + export_paged_llm_v1, + compile_model, + find_available_port, + start_llm_server, +) + +logger = logging.getLogger(__name__) + + +@pytest.fixture(scope="module") +def model_test_dir(request, tmp_path_factory): + """Prepare model artifacts for starting the LLM server. + + Args: + request (FixtureRequest): The following params are accepted: + - repo_id (str): The Hugging Face repo ID. + - model_file (str): The model file to download. + - tokenizer_id (str): The tokenizer ID to download. + - settings (dict): The settings for sharktank export. + - batch_sizes (list): The batch sizes to use for the model. + tmp_path_factory (TempPathFactory): Temp dir to save artifacts to. + + Yields: + Tuple[Path, Path]: The paths to the Hugging Face home and the temp dir. + """ + logger.info("Preparing model artifacts...") + + repo_id = request.param["repo_id"] + model_file = request.param["model_file"] + tokenizer_id = request.param["tokenizer_id"] + settings = request.param["settings"] + batch_sizes = request.param["batch_sizes"] + + tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test") + hf_home = os.environ.get("HF_HOME", None) + hf_home = Path(hf_home) if hf_home is not None else tmp_dir + try: + # Download model if it doesn't exist + model_path = hf_home / model_file + download_huggingface_model(hf_home, repo_id, model_file) + + # Set up tokenizer if it doesn't exist + download_tokenizer(hf_home, tokenizer_id) + + # Export model + mlir_path = tmp_dir / "model.mlir" + config_path = tmp_dir / "config.json" + export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes) + + # Compile model + vmfb_path = tmp_dir / "model.vmfb" + compile_model(mlir_path, vmfb_path, settings) + + # Write config + edited_config_path = tmp_dir / "edited_config.json" + config = { + "module_name": "module", + "module_abi_version": 1, + "max_seq_len": 2048, + "attn_head_count": 32, + "attn_head_dim": 100, + "prefill_batch_sizes": batch_sizes, + "decode_batch_sizes": batch_sizes, + "transformer_block_count": 26, + "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, + } + logger.info(f"Saving edited config to: {edited_config_path}\n") + logger.info(f"Config: {json.dumps(config, indent=2)}") + with open(edited_config_path, "w") as f: + json.dump(config, f) + logger.info("Model artifacts setup successfully") + yield hf_home, tmp_dir + finally: + shutil.rmtree(tmp_dir) + + +@pytest.fixture(scope="module") +def available_port(): + return find_available_port() + + +@pytest.fixture(scope="module") +def llm_server(request, model_test_dir, available_port): + """Start the LLM server. + + Args: + request (FixtureRequest): The following params are accepted: + - model_file (str): The model file to download. + - settings (dict): The settings for starting the server. + model_test_dir (Tuple[Path, Path]): The paths to the Hugging Face home and the temp dir. + available_port (int): The available port to start the server on. + + Yields: + subprocess.Popen: The server process that was started. + """ + logger.info("Starting LLM server...") + hf_home, tmp_dir = model_test_dir + model_file = request.param["model_file"] + settings = request.param["settings"] + + tokenizer_path = hf_home / "tokenizer.json" + config_path = tmp_dir / "edited_config.json" + vmfb_path = tmp_dir / "model.vmfb" + parameters_path = hf_home / model_file + + # Start llm server + server_process = start_llm_server( + available_port, + tokenizer_path, + config_path, + vmfb_path, + parameters_path, + settings, + ) + yield server_process + # Teardown: kill the server + server_process.terminate() + server_process.wait() diff --git a/build_tools/integration_tests/llm/cpu_llm_server_test.py b/app_tests/integration_tests/llm/cpu_llm_server_test.py similarity index 98% rename from build_tools/integration_tests/llm/cpu_llm_server_test.py rename to app_tests/integration_tests/llm/cpu_llm_server_test.py index 4d4ec5540..e7d0792d8 100644 --- a/build_tools/integration_tests/llm/cpu_llm_server_test.py +++ b/app_tests/integration_tests/llm/cpu_llm_server_test.py @@ -10,7 +10,7 @@ import requests import uuid -from utils import AccuracyValidationException +from .utils import AccuracyValidationException logger = logging.getLogger(__name__) diff --git a/app_tests/integration_tests/llm/utils.py b/app_tests/integration_tests/llm/utils.py new file mode 100644 index 000000000..b8b5ae60f --- /dev/null +++ b/app_tests/integration_tests/llm/utils.py @@ -0,0 +1,180 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import multiprocessing +import os +import subprocess +import sys +import time + +import requests +from transformers import AutoTokenizer + +logger = logging.getLogger("__name__") + + +class AccuracyValidationException(RuntimeError): + pass + + +def download_huggingface_model(local_dir, repo_id, model_file): + model_path = local_dir / model_file + logger.info(f"Preparing model_path: {model_path}..") + if not os.path.exists(model_path): + logger.info(f"Downloading model {repo_id} {model_file} from Hugging Face...") + subprocess.run( + f"huggingface-cli download --local-dir {local_dir} {repo_id} {model_file}", + shell=True, + check=True, + ) + logger.info(f"Model downloaded to {model_path}") + else: + logger.info("Using cached model") + + +def download_tokenizer(local_dir, tokenizer_id): + # Set up tokenizer if it doesn't exist + tokenizer_path = local_dir / "tokenizer.json" + logger.info(f"Preparing tokenizer_path: {tokenizer_path}...") + if not os.path.exists(tokenizer_path): + logger.info(f"Downloading tokenizer {tokenizer_id} from Hugging Face...") + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_id, + ) + tokenizer.save_pretrained(local_dir) + logger.info(f"Tokenizer saved to {tokenizer_path}") + else: + logger.info("Using cached tokenizer") + + +def export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes): + bs_string = ",".join(map(str, batch_sizes)) + logger.info( + "Exporting model with following settings:\n" + f" MLIR Path: {mlir_path}\n" + f" Config Path: {config_path}\n" + f" Batch Sizes: {bs_string}" + ) + subprocess.run( + [ + "python", + "-m", + "sharktank.examples.export_paged_llm_v1", + f"--{model_path.suffix.strip('.')}-file={model_path}", + f"--output-mlir={mlir_path}", + f"--output-config={config_path}", + f"--bs={bs_string}", + ], + check=True, + ) + logger.info(f"Model successfully exported to {mlir_path}") + + +def compile_model(mlir_path, vmfb_path, device_settings): + logger.info(f"Compiling model to {vmfb_path}") + subprocess.run( + [ + "iree-compile", + mlir_path, + "-o", + vmfb_path, + ] + + device_settings["device_flags"], + check=True, + ) + logger.info(f"Model successfully compiled to {vmfb_path}") + + +def find_available_port(): + import socket + from contextlib import closing + + logger.info(f"Finding available port...") + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + port = s.getsockname()[1] + logger.info(f"Found available port: {port}") + return port + + +def wait_for_server(url, timeout=10): + logger.info(f"Waiting for server to start at {url}...") + start = time.time() + while time.time() - start < timeout: + try: + requests.get(f"{url}/health") + logger.info("Server successfully started") + return + except requests.exceptions.ConnectionError: + time.sleep(1) + raise TimeoutError(f"Server did not start within {timeout} seconds") + + +def _start_llm_server_args( + tokenizer_path, + model_config_path, + vmfb_path, + parameters_path, + settings, + port, +): + return [ + sys.executable, + "-m", + "shortfin_apps.llm.server", + f"--tokenizer_json={tokenizer_path}", + f"--model_config={model_config_path}", + f"--vmfb={vmfb_path}", + f"--parameters={parameters_path}", + f"--device={settings['device']}", + f"--port={port}", + ] + + +def start_llm_server( + port, + tokenizer_path, + model_config_path, + vmfb_path, + parameters_path, + settings, + timeout=10, + multi=False, +): + logger.info("Starting LLM server...") + if multi: + server_process = multiprocessing.Process( + target=subprocess.Popen( + _start_llm_server_args( + tokenizer_path, + model_config_path, + vmfb_path, + parameters_path, + settings, + port, + ), + ) + ) + server_process.start() + + else: + # Start the server + server_process = subprocess.Popen( + _start_llm_server_args( + tokenizer_path, + model_config_path, + vmfb_path, + parameters_path, + settings, + port, + ) + ) + logger.info("Process started... waiting for server") + # Wait for server to start + wait_for_server(f"http://localhost:{port}", timeout) + return server_process diff --git a/build_tools/integration_tests/llm/conftest.py b/build_tools/integration_tests/llm/conftest.py deleted file mode 100644 index 1103065bc..000000000 --- a/build_tools/integration_tests/llm/conftest.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import json -import logging -import os -from pathlib import Path -import pytest -import requests -import shutil -import subprocess -import time - -pytest.importorskip("transformers") -from transformers import AutoTokenizer - -logger = logging.getLogger(__name__) - - -@pytest.fixture(scope="module") -def model_test_dir(request, tmp_path_factory): - """Prepare model artifacts for starting the LLM server. - - Args: - request (FixtureRequest): The following params are accepted: - - repo_id (str): The Hugging Face repo ID. - - model_file (str): The model file to download. - - tokenizer_id (str): The tokenizer ID to download. - - settings (dict): The settings for sharktank export. - - batch_sizes (list): The batch sizes to use for the model. - tmp_path_factory (TempPathFactory): Temp dir to save artifacts to. - - Yields: - Tuple[Path, Path]: The paths to the Hugging Face home and the temp dir. - """ - logger.info("Preparing model artifacts...") - - repo_id = request.param["repo_id"] - model_file = request.param["model_file"] - tokenizer_id = request.param["tokenizer_id"] - settings = request.param["settings"] - batch_sizes = request.param["batch_sizes"] - - tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test") - hf_home = os.environ.get("HF_HOME", None) - hf_home = Path(hf_home) if hf_home is not None else tmp_dir - try: - # Download model if it doesn't exist - model_path = hf_home / model_file - logger.info(f"Preparing model_path: {model_path}..") - if not os.path.exists(model_path): - logger.info( - f"Downloading model {repo_id} {model_file} from Hugging Face..." - ) - subprocess.run( - f"huggingface-cli download --local-dir {hf_home} {repo_id} {model_file}", - shell=True, - check=True, - ) - logger.info(f"Model downloaded to {model_path}") - else: - logger.info("Using cached model") - - # Set up tokenizer if it doesn't exist - tokenizer_path = hf_home / "tokenizer.json" - logger.info(f"Preparing tokenizer_path: {tokenizer_path}...") - if not os.path.exists(tokenizer_path): - logger.info(f"Downloading tokenizer {tokenizer_id} from Hugging Face...") - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_id, - ) - tokenizer.save_pretrained(hf_home) - logger.info(f"Tokenizer saved to {tokenizer_path}") - else: - logger.info("Using cached tokenizer") - - # Export model - mlir_path = tmp_dir / "model.mlir" - config_path = tmp_dir / "config.json" - bs_string = ",".join(map(str, batch_sizes)) - logger.info( - "Exporting model with following settings:\n" - f" MLIR Path: {mlir_path}\n" - f" Config Path: {config_path}\n" - f" Batch Sizes: {bs_string}" - ) - subprocess.run( - [ - "python", - "-m", - "sharktank.examples.export_paged_llm_v1", - f"--gguf-file={model_path}", - f"--output-mlir={mlir_path}", - f"--output-config={config_path}", - f"--bs={bs_string}", - ], - check=True, - ) - logger.info(f"Model successfully exported to {mlir_path}") - - # Compile model - vmfb_path = tmp_dir / "model.vmfb" - logger.info(f"Compiling model to {vmfb_path}") - subprocess.run( - [ - "iree-compile", - mlir_path, - "-o", - vmfb_path, - ] - + settings["device_flags"], - check=True, - ) - logger.info(f"Model successfully compiled to {vmfb_path}") - - # Write config if it doesn't exist - edited_config_path = tmp_dir / "edited_config.json" - config = { - "module_name": "module", - "module_abi_version": 1, - "max_seq_len": 2048, - "attn_head_count": 32, - "attn_head_dim": 100, - "prefill_batch_sizes": batch_sizes, - "decode_batch_sizes": batch_sizes, - "transformer_block_count": 26, - "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256}, - } - logger.info(f"Saving edited config to: {edited_config_path}\n") - logger.info(f"Config: {json.dumps(config, indent=2)}") - with open(edited_config_path, "w") as f: - json.dump(config, f) - logger.info("Model artifacts setup successfully") - yield hf_home, tmp_dir - finally: - shutil.rmtree(tmp_dir) - - -@pytest.fixture(scope="module") -def available_port(port=8000, max_port=8100): - import socket - - logger.info(f"Finding available port in range {port}-{max_port}...") - - starting_port = port - - while port < max_port: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("localhost", port)) - s.close() - logger.info(f"Found available port: {port}") - return port - except socket.error: - port += 1 - - raise IOError(f"No available ports found within range {starting_port}-{max_port}") - - -def wait_for_server(url, timeout=10): - logger.info(f"Waiting for server to start at {url}...") - start = time.time() - while time.time() - start < timeout: - try: - requests.get(f"{url}/health") - logger.info("Server successfully started") - return - except requests.exceptions.ConnectionError: - time.sleep(1) - raise TimeoutError(f"Server did not start within {timeout} seconds") - - -@pytest.fixture(scope="module") -def llm_server(request, model_test_dir, available_port): - """Start the LLM server. - - Args: - request (FixtureRequest): The following params are accepted: - - model_file (str): The model file to download. - - settings (dict): The settings for starting the server. - model_test_dir (Tuple[Path, Path]): The paths to the Hugging Face home and the temp dir. - available_port (int): The available port to start the server on. - - Yields: - subprocess.Popen: The server process that was started. - """ - logger.info("Starting LLM server...") - # Start the server - hf_home, tmp_dir = model_test_dir - model_file = request.param["model_file"] - settings = request.param["settings"] - server_process = subprocess.Popen( - [ - "python", - "-m", - "shortfin_apps.llm.server", - f"--tokenizer_json={hf_home / 'tokenizer.json'}", - f"--model_config={tmp_dir / 'edited_config.json'}", - f"--vmfb={tmp_dir / 'model.vmfb'}", - f"--parameters={hf_home / model_file}", - f"--device={settings['device']}", - ] - ) - # Wait for server to start - wait_for_server(f"http://localhost:{available_port}") - yield server_process - # Teardown: kill the server - server_process.terminate() - server_process.wait() diff --git a/build_tools/integration_tests/llm/utils.py b/build_tools/integration_tests/llm/utils.py deleted file mode 100644 index b31a3e416..000000000 --- a/build_tools/integration_tests/llm/utils.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - - -class AccuracyValidationException(RuntimeError): - pass From 778b567aa111b558fb7cf6c3fff98c79753f37ef Mon Sep 17 00:00:00 2001 From: Chris Sosa Date: Thu, 14 Nov 2024 17:14:33 -0800 Subject: [PATCH 45/59] Update pyproject to include missing SDXL app deps (#529) Update pyproject to include missing SDXL app deps. These were discovered while testing shortfin apps flow by only pre-installing the shark meta pkg --- shortfin/pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/shortfin/pyproject.toml b/shortfin/pyproject.toml index 67483fb05..7c4ed8a33 100644 --- a/shortfin/pyproject.toml +++ b/shortfin/pyproject.toml @@ -39,6 +39,9 @@ apps = [ "transformers", "dataclasses-json", "pillow", + "fastapi", + "uvicorn", + "aiohttp", ] [tool.pytest.ini_options] From 7fd047fe23e41c88be464cf957dcb27f3e74a69f Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 14 Nov 2024 20:41:21 -0500 Subject: [PATCH 46/59] [tuner][CI] Enable mypy type checking (#530) For now only typecheck the tuner implementation. In the future, we should extend it to examples. --- .github/workflows/ci-tuner.yml | 3 +++ tuner/requirements-dev.txt | 1 + 2 files changed, 4 insertions(+) diff --git a/.github/workflows/ci-tuner.yml b/.github/workflows/ci-tuner.yml index dad766d8e..81b920e31 100644 --- a/.github/workflows/ci-tuner.yml +++ b/.github/workflows/ci-tuner.yml @@ -54,3 +54,6 @@ jobs: - name: Run tuner tests run: pytest tuner/ + + - name: Run mypy type checker + run: mypy tuner/tuner diff --git a/tuner/requirements-dev.txt b/tuner/requirements-dev.txt index 51d5b9ba0..747b28508 100644 --- a/tuner/requirements-dev.txt +++ b/tuner/requirements-dev.txt @@ -1,2 +1,3 @@ +mypy==1.8.0 pre-commit==3.8.0 virtualenv==20.13.0 From 8de8524d55b7538ac0a7bc94220fd91d000602e4 Mon Sep 17 00:00:00 2001 From: Chris Sosa Date: Thu, 14 Nov 2024 20:12:34 -0800 Subject: [PATCH 47/59] Add SHARK user guide to root of docs directory (#528) Progress on https://github.com/nod-ai/SHARK-Platform/issues/458 This PR adds a SHARK user guide to root of docs directory and does some basic information re-architecture to point installation paths of current main readmes to the new user guide. One new landing page and removal of duplicate installation paths in SD folder to point to both nightly / new release page depending on use case. Incorporated a mini guide on supported options for the SD Server / Client in the user guide. Changed root readme to include a path for users so that anyone who lands on the main SHARK readme can quickly get started as a non-developer. --- README.md | 65 +----------- docs/developer_guide.md | 65 ++++++++++++ docs/user_guide.md | 115 +++++++++++++++++++++ shortfin/python/shortfin_apps/sd/README.md | 31 ++---- 4 files changed, 192 insertions(+), 84 deletions(-) create mode 100644 docs/developer_guide.md create mode 100644 docs/user_guide.md diff --git a/README.md b/README.md index d187c23a0..f5a255c84 100644 --- a/README.md +++ b/README.md @@ -61,68 +61,11 @@ Model name | Model recipes | Serving apps SDXL | [`sharktank/sharktank/models/punet/`](https://github.com/nod-ai/SHARK-Platform/tree/main/sharktank/sharktank/models/punet) | [`shortfin/python/shortfin_apps/sd/`](https://github.com/nod-ai/SHARK-Platform/tree/main/shortfin/python/shortfin_apps/sd) llama | [`sharktank/sharktank/models/llama/`](https://github.com/nod-ai/SHARK-Platform/tree/main/sharktank/sharktank/models/llama) | [`shortfin/python/shortfin_apps/llm/`](https://github.com/nod-ai/SHARK-Platform/tree/main/shortfin/python/shortfin_apps/llm) -## Development tips -Each sub-project has its own developer guide. If you would like to work across -projects, these instructions should help you get started: +## SHARK Users -### Setup a venv +If you're looking to use SHARK check out our [User Guide](docs/user_guide.md). -We recommend setting up a Python -[virtual environment (venv)](https://docs.python.org/3/library/venv.html). -The project is configured to ignore `.venv` directories, and editors like -VSCode pick them up by default. +## SHARK Developers -```bash -python -m venv .venv -source .venv/bin/activate -``` - -### Install PyTorch for your system - -If no explicit action is taken, the default PyTorch version will be installed. -This will give you a current CUDA-based version, which takes longer to download -and includes other dependencies that SHARK does not require. To install a -different variant, run one of these commands first: - -* *CPU:* - - ```bash - pip install -r pytorch-cpu-requirements.txt - ``` - -* *ROCM:* - - ```bash - pip install -r pytorch-rocm-requirements.txt - ``` - -* *Other:* see instructions at . - -### Install development packages - -```bash -# Install editable local projects. -pip install -r requirements.txt -e sharktank/ shortfin/ - -# Optionally clone and install the latest editable iree-turbine dep in deps/, -# along with nightly versions of iree-base-compiler and iree-base-runtime. -pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ - iree-base-compiler iree-base-runtime --src deps \ - -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" -``` - -See also: [`docs/nightly_releases.md`](./docs/nightly_releases.md). - -### Running tests - -```bash -pytest sharktank -pytest shortfin -``` - -### Optional: pre-commits and developer settings - -This project is set up to use the `pre-commit` tooling. To install it in -your local repo, run: `pre-commit install`. After this point, when making -commits locally, hooks will run. See https://pre-commit.com/ +If you're looking to develop SHARK, check out our [Developer Guide](docs/developer_guide.md). diff --git a/docs/developer_guide.md b/docs/developer_guide.md new file mode 100644 index 000000000..832466688 --- /dev/null +++ b/docs/developer_guide.md @@ -0,0 +1,65 @@ +# SHARK Developer Guide + +Each sub-project has its own developer guide. If you would like to work across +projects, these instructions should help you get started: + +### Setup a venv + +We recommend setting up a Python +[virtual environment (venv)](https://docs.python.org/3/library/venv.html). +The project is configured to ignore `.venv` directories, and editors like +VSCode pick them up by default. + +```bash +python -m venv .venv +source .venv/bin/activate +``` + +### Install PyTorch for your system + +If no explicit action is taken, the default PyTorch version will be installed. +This will give you a current CUDA-based version, which takes longer to download +and includes other dependencies that SHARK does not require. To install a +different variant, run one of these commands first: + +* *CPU:* + + ```bash + pip install -r pytorch-cpu-requirements.txt + ``` + +* *ROCM:* + + ```bash + pip install -r pytorch-rocm-requirements.txt + ``` + +* *Other:* see instructions at . + +### Install development packages + +```bash +# Install editable local projects. +pip install -r requirements.txt -e sharktank/ shortfin/ + +# Optionally clone and install the latest editable iree-turbine dep in deps/, +# along with nightly versions of iree-base-compiler and iree-base-runtime. +pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \ + iree-base-compiler iree-base-runtime --src deps \ + -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine" +``` + +See also: [nightly_releases.md](nightly_releases.md). + +### Running tests + +```bash +pytest sharktank +pytest shortfin +``` + +### Optional: pre-commits and developer settings + +This project is set up to use the `pre-commit` tooling. To install it in +your local repo, run: `pre-commit install`. After this point, when making +commits locally, hooks will run. See https://pre-commit.com/ diff --git a/docs/user_guide.md b/docs/user_guide.md new file mode 100644 index 000000000..b7a530583 --- /dev/null +++ b/docs/user_guide.md @@ -0,0 +1,115 @@ +# SHARK User Guide + +> [!WARNING] +> This is still pre-release so the artifacts listed here may be broken +> + +These instructions cover the usage of the latest stable release of SHARK. For a more bleeding edge release please install the [nightly releases](nightly_releases.md). + +## Prerequisites + +Our current user guide requires that you have: +- Access to a computer with an installed AMD Instinctâ„¢ MI300x Series Accelerator +- Installed a compatible version of Linux and ROCm on the computer (see the [ROCm compatability matrix](https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html)) + + +## Set up Environment + +This section will help you install Python and set up a Python environment with venv. + +Officially we support Python versions: 3.11, 3.12, 3.13, 3.13t + +The rest of this guide assumes you are using Python 3.11. + +### Install Python +To install Python 3.11 on Ubuntu: + +```bash +sudo apt install python3.11 python3.11-dev python3.11-venv + +which python3.11 +# /usr/bin/python3.11 +``` + +### Create a Python Environment + +Setup your Python environment with the following commands: + +```bash +# Set up a virtual environment to isolate packages from other envs. +python3.11 -m venv 3.11.venv +source 3.11.venv/bin/activate +``` + +## Install SHARK and its dependencies + +```bash +pip install shark-ai[apps] +``` + +Temporarily, you may need an update to your `shortfin` install. +Install the latest pre-release with: +``` +pip install shortfin --upgrade --pre -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels +``` + +### Test the installation. + +``` +python -m shortfin_apps.sd.server --help +``` + +## Quickstart + +### Run the SDXL Server + +Run the [SDXL Server](../shortfin/python/shortfin_apps/sd/README.md#Start-SDXL-Server) + +### Run the SDXL Client + +``` +python -m shortfin_apps.sd.simple_client --interactive +``` + +Congratulations!!! At this point you can play around with the server and client based on your usage. + +### Update flags + +Please see --help for both the server and client for usage instructions. Here's a quick snapshot. + +#### Update server options: + +| Flags | options | +|---|---| +|--host HOST | +|--port PORT | server port | +|--root-path ROOT_PATH | +|--timeout-keep-alive | +|--device | local-task,hip,amdgpu | amdgpu only supported in this release +|--target | gfx942,gfx1100 | gfx942 only supported in this release +|--device_ids | +|--tokenizers | +|--model_config | +| --workers_per_device | +| --fibers_per_device | +| --isolation | per_fiber, per_call, none | +| --show_progress | +| --trace_execution | +| --amdgpu_async_allocations | +| --splat | +| --build_preference | compile,precompiled | +| --compile_flags | +| --flagfile FLAGFILE | +| --artifacts_dir ARTIFACTS_DIR | Where to store cached artifacts from the Cloud | + +#### Update client with different options: + +| Flags |options| +|---|--- +|--file | +|--reps | +|--save | Whether to save image generated by the server | +|--outputdir| output directory to store images generated by SDXL | +|--steps | +|--interactive | +|--port| port to interact with server | diff --git a/shortfin/python/shortfin_apps/sd/README.md b/shortfin/python/shortfin_apps/sd/README.md index 0bb5db511..3397be6cf 100644 --- a/shortfin/python/shortfin_apps/sd/README.md +++ b/shortfin/python/shortfin_apps/sd/README.md @@ -1,30 +1,13 @@ -# SD Server and CLI +# SDXL Server and CLI -This directory contains a SD inference server, CLI and support components. +This directory contains a [SDXL](https://stablediffusionxl.com/) inference server, CLI and support components. More information about SDXL on [huggingface](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0). +## Install -## Quick start +For [nightly releases](../../../../docs/nightly_releases.md) +For our [stable release](../../../../docs/user_guide.md) -In your shortfin environment, -``` -pip install transformers -pip install dataclasses-json -pip install pillow -pip install shark-ai - -``` - -Temporarily, you may need an update to your `shortfin` install. -Install the latest pre-release with: -``` -pip install shortfin --upgrade --pre -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels -``` - -``` -python -m shortfin_apps.sd.server --help -``` - -# Run on MI300x +## Start SDXL Server The server will prepare runtime artifacts for you. By default, the port is set to 8000. If you would like to change this, use `--port` in each of the following commands. @@ -39,6 +22,8 @@ python -m shortfin_apps.sd.server --device=amdgpu --device_ids=0 --build_prefere INFO - Application startup complete. INFO - Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) ``` +## Run the SDXL Client + - Run a CLI client in a separate shell: ``` python -m shortfin_apps.sd.simple_client --interactive From f1bf28236a36a6f8dcb56a655b0756cb7d84b6b2 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 15 Nov 2024 09:41:47 -0500 Subject: [PATCH 48/59] [tuner] Make candidate_gen more modular. NFC. (#531) Separate out common utilities and parsing code into their own files. In the future, I also plan to pull constraint generation into its own file. --- tuner/tuner/candidate_gen.py | 691 +--------------------------- tuner/tuner/candidate_gen_test.py | 504 ++++---------------- tuner/tuner/common.py | 264 +++++++++++ tuner/tuner/common_test.py | 131 ++++++ tuner/tuner/dispatch_parser.py | 435 +++++++++++++++++ tuner/tuner/dispatch_parser_test.py | 176 +++++++ 6 files changed, 1122 insertions(+), 1079 deletions(-) create mode 100644 tuner/tuner/common.py create mode 100644 tuner/tuner/common_test.py create mode 100644 tuner/tuner/dispatch_parser.py create mode 100644 tuner/tuner/dispatch_parser_test.py diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 96bfc7146..06ccae0e3 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -20,251 +20,21 @@ import argparse import logging -import math import pickle import re import z3 # type: ignore -from dataclasses import astuple, dataclass -from enum import Enum +from dataclasses import dataclass from os import path, makedirs from typing import Optional from textwrap import indent -from abc import ABC, abstractmethod +from abc import abstractmethod from iree.compiler import ir # type: ignore -tune_logger = logging.getLogger("tune") - - -class DispatchKind(Enum): - conv = 1 - mmt = 2 - contraction = 3 - batch_mmt = 4 - batch_matmul = 5 - broadcast_rhs_mmt = 6 - - -class ElementType(Enum): - i8 = 1 - i32 = 2 - f8 = 3 - f16 = 4 - f32 = 5 - - @property - def bitwidth(self) -> int: - match self: - case ElementType.i8 | ElementType.f8: - return 8 - case ElementType.f16: - return 16 - case ElementType.i32 | ElementType.f32: - return 32 - case _: - assert False, "unhandled case" - - def __str__(self) -> str: - return self.name - - -@dataclass -class ShapedType: - shape: list[int] - element_type: ElementType - - def rank(self) -> int: - return len(self.shape) - - @property - def bitwidth(self) -> int: - return self.element_type.bitwidth - - def __str__(self) -> str: - dim_to_str = lambda dim: str(dim) if dim != -1 else "?" - return "x".join(map(dim_to_str, self.shape)) + "x" + str(self.element_type) - - -@dataclass -class MatmulSize: - M: int - N: int - K: int - B: int = 1 - - -@dataclass -class ProblemSize: - matmul_size: MatmulSize - lhs_type: ShapedType - rhs_type: ShapedType - res_type: ShapedType - dispatch_kind: DispatchKind - - @property - def MNK(self) -> tuple[int, int, int]: - return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K) - - -@dataclass -class MfmaIntrinsic: - output_type: ElementType - m: int - n: int - k: int - input_type: ElementType - - def __str__(self) -> str: - input = str(self.input_type).upper() - output = str(self.output_type).upper() - return f"MFMA_{output}_{self.m}x{self.n}x{self.k}_{input}" - - @staticmethod - def mfma_f32_16x16x16_f16(): - return MfmaIntrinsic(ElementType.f32, 16, 16, 16, ElementType.f16) - - @staticmethod - def mfma_f32_32x32x8_f16(): - return MfmaIntrinsic(ElementType.f32, 32, 32, 8, ElementType.f16) - - @staticmethod - def mfma_i32_16x16x32_i8(): - return MfmaIntrinsic(ElementType.i32, 16, 16, 32, ElementType.i8) - - @staticmethod - def mfma_i32_32x32x16_i8(): - return MfmaIntrinsic(ElementType.i32, 32, 32, 16, ElementType.i8) - - @staticmethod - def all(): - return [ - MfmaIntrinsic.mfma_f32_16x16x16_f16(), - MfmaIntrinsic.mfma_f32_32x32x8_f16(), - MfmaIntrinsic.mfma_i32_16x16x32_i8(), - MfmaIntrinsic.mfma_i32_32x32x16_i8(), - ] - - -class ReorderWorkgroupsStrategy(Enum): - NONE = 0 - SWIZZLE = 1 - TRANSPOSE = 2 - - def __str__(self) -> str: - return self.name.title() - - -@dataclass -class GpuPipelineOptions: - """Represents the `iree_gpu.pipeline_options` attribute""" - - prefetch_shared_memory: Optional[bool] = None - no_reduce_shared_memory_bank_conflicts: Optional[bool] = None - reorder_workgroups_strategy: Optional[ReorderWorkgroupsStrategy] = None +from .common import * +from .dispatch_parser import * - def all_default(self) -> bool: - return all(x is None for x in astuple(self)) - - def __str__(self) -> str: - options: list[str] = [] - if self.prefetch_shared_memory is not None: - options.append( - f"prefetch_shared_memory = {str(self.prefetch_shared_memory).lower()}" - ) - if self.no_reduce_shared_memory_bank_conflicts is not None: - options.append( - f"no_reduce_shared_memory_bank_conflicts = {str(self.no_reduce_shared_memory_bank_conflicts).lower()}" - ) - if self.reorder_workgroups_strategy is not None: - options.append( - f"reorder_workgroups_strategy = {self.reorder_workgroups_strategy}" - ) - - return f"#iree_gpu.pipeline_options<{', '.join(options)}>" - - -@dataclass -class Configuration: - subgroup_size: int - workgroup_size: list[int] - intrinsic: MfmaIntrinsic - tile_sizes: list[int] - subgroup_m_count: int - subgroup_n_count: int - gpu_pipeline_options: GpuPipelineOptions - waves_per_eu: int - - -class MlirRegex(Enum): - ssa_value = r"%[a-zA-Z0-9-_]+" - tensor_type = r"tensor<(([0-9]+x)+((f|i)[0-9]+))>" - - def __str__(self) -> str: - return self.value - - @staticmethod - def dps_ins_two_args() -> str: - return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type}), (?P{MlirRegex.tensor_type})\)" - - @staticmethod - def dps_outs_one_arg() -> str: - return rf"outs\({MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type})\)" - - -def read_input_mlir(filename: str) -> list[str]: - with open(filename, "r") as f: - return f.readlines() - - -def get_mmt_tile_sizes(configuration: Configuration): - return configuration.tile_sizes - - -@dataclass -class ConvDimInfo: - n: int - oh: int - ow: int - oc: int - fh: int - fw: int - ic: int - - @staticmethod - def from_rhs_res(rhs_shaped_type: ShapedType, res_shaped_type: ShapedType): - n, oh, ow, oc = res_shaped_type.shape - fh, fw, ic, _ = rhs_shaped_type.shape - return ConvDimInfo(n, oh, ow, oc, fh, fw, ic) - - @staticmethod - def from_problem_size(problem_size: ProblemSize): - return ConvDimInfo.from_rhs_res(problem_size.rhs_type, problem_size.res_type) - - -def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: - m, n, k = configuration.tile_sizes - tile_size = [1] * len(tile_dims) - for idx, dim in enumerate(tile_dims): - if dim == "m": - tile_size[idx] = m - if dim == "n": - tile_size[idx] = n - if dim == "k": - tile_size[idx] = k - return tile_size - - -def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: - return [1] + configuration.tile_sizes - - -def get_pipeline_config(configuration: Configuration) -> str: - extra_config = "" - if not configuration.gpu_pipeline_options.all_default(): - extra_config += f", gpu_pipeline_options = {configuration.gpu_pipeline_options}" - if configuration.waves_per_eu != 2: - extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}' - return extra_config +tune_logger = logging.getLogger("tune") def apply_configuration( @@ -303,32 +73,6 @@ def apply_configuration( return new_mlir -def parse_tensor_type(tensor_type: str) -> ShapedType: - shape_match = re.search(str(MlirRegex.tensor_type), tensor_type) - assert shape_match - - shape_str = shape_match.group(1) - dims_and_elem = shape_str.split("x") - dims = [int(x) for x in dims_and_elem[:-1]] - elem = dims_and_elem[-1] - str_to_elem_ty = {x.name: x for x in ElementType} - return ShapedType(dims, str_to_elem_ty[elem]) - - -def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]: - def is_compatible(intrinsic: MfmaIntrinsic) -> bool: - if problem_size.res_type.element_type != intrinsic.output_type: - return False - if problem_size.dispatch_kind != DispatchKind.batch_matmul: - if problem_size.lhs_type.element_type != intrinsic.input_type: - return False - if problem_size.rhs_type.element_type != intrinsic.input_type: - return False - return True - - return list(filter(is_compatible, MfmaIntrinsic.all())) - - def get_mfma_intrinsic_constraints( problem_size: ProblemSize, intrinsic_m: z3.ArithRef, @@ -517,38 +261,8 @@ def get_default_output_dir() -> str: return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M") -def parse_mlir(mlir_text: str, ctx: ir.Context) -> ir.Module: - mlir_module = None - try: - mlir_module = ir.Module.parse(mlir_text) - tune_logger.info("MLIR parsing successful!") - except ir.MLIRError as e: - tune_logger.error(f"Error parsing MLIR: {e}") - raise RuntimeError(f"Error parsing MLIR: {e}") - - return mlir_module - - -@dataclass -class MLIRTransformation: - """Transformation of MLIR context""" - - template: list[str] - modified: str - embeddable: str - - -class DispatchTuner(ABC): - @abstractmethod - def supports(self, op_name: str) -> bool: - """Check if the tuner can handle the type of operation represented by the input string.""" - pass - - @abstractmethod - def get_shapes(self, template: list[str]) -> ProblemSize: - """Extract problem size of the operation.""" - pass - +class DispatchTuner(DispatchParser): + # TODO(https://github.com/nod-ai/SHARK-Platform/issues/453): Remove this in favor of configuring using transform dialect. @abstractmethod def apply_params( self, @@ -560,12 +274,6 @@ def apply_params( pass -@dataclass -class OpWalkResult: - was_interrupted: bool = False - dispatch_tuner: Optional[DispatchTuner] = None - - class DispatchTunerRegistry: def __init__(self): self.registry = set() @@ -589,60 +297,7 @@ def find_handler(self, op_name: str) -> DispatchTuner: assert False, "Dispatch kind not supported" -class MmtTuner(DispatchTuner): - def supports(self, op_name: str) -> bool: - return "matmul_transpose_b" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - mmt_re = None - dps = None - for line in template: - if "linalg.generic" not in line: - continue - if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line: - continue - # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) - mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - dps = re.search(mmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 2 - lhs_M, lhs_K = lhs_shaped_type.shape - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - rhs_N, rhs_K = rhs_shaped_type.shape - - assert lhs_shaped_type.element_type == rhs_shaped_type.element_type - assert lhs_K == rhs_K - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 2 - res_M, res_N = res_shaped_type.shape - - assert lhs_M == res_M - assert rhs_N == res_N - - matmul_size = MatmulSize( - lhs_shaped_type.shape[0], - rhs_shaped_type.shape[0], - lhs_shaped_type.shape[1], - ) - return ProblemSize( - matmul_size, - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.mmt, - ) - assert mmt_re - assert False, f"'{mmt_re}' not found in given context" - +class MmtTuner(DispatchTuner, MmtParser): def get_transform_function_mmt( self, problem_size: ProblemSize, functionName: str, configuration: Configuration ) -> str: @@ -694,71 +349,7 @@ def apply_params( return MLIRTransformation(template, modified, embeddable) -class ConvTuner(DispatchTuner): - def supports(self, op_name: str) -> bool: - return "conv_2d_nhwc_hwcf" in op_name - - def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]: - m, n, k = configuration.tile_sizes - batch = 1 - fh = 1 - fw = 1 - - oh = 1 - - oc = n - ow = m - ic = k - return [batch, oh, ow, oc, fh, fw, ic] - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.conv_2d_nhwc_hwcf" not in line: - continue - - # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>) - conv_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(conv_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 4 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 4 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 4 - - # int64_t n = outputShape[0]; - # int64_t oh = outputShape[1]; - # int64_t ow = outputShape[2]; - # int64_t oc = outputShape[3]; - # int64_t fh = filterShape[0]; - # int64_t fw = filterShape[1]; - # int64_t ic = filterShape[2]; - dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) - return ProblemSize( - MatmulSize( - M=dim_info.oh * dim_info.ow, - N=dim_info.oc, - K=dim_info.fh * dim_info.fw * dim_info.ic, - B=dim_info.n, - ), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.conv, - ) - - assert False, "Shape not found" - +class ConvTuner(DispatchTuner, ConvParser): # int64_t n = outputShape[0]; # int64_t oh = outputShape[1]; # int64_t ow = outputShape[2]; @@ -833,135 +424,7 @@ def apply_params( return MLIRTransformation(template, modified, embeddable) -class ContractionTuner(DispatchTuner): - def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): - self.lhs_dims = lhs_dims - self.rhs_dims = rhs_dims - self.tile_dims = tile_dims - - def supports(self, op_name: str) -> bool: - return "matmul_like" in op_name - - def is_broadcast_rhs_mmt_op(self, line: str) -> bool: - if "linalg.generic" not in line: - return False - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - return False - if ( - r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>" - not in line - ): - return False - return True - - def is_broadcast_rhs_mmt(self, template: list[str]) -> bool: - return any(self.is_broadcast_rhs_mmt_op(line) for line in template) - - def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize: - for line in template: - if not self.is_broadcast_rhs_mmt_op(line): - continue - - # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) - bmmt_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 2 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.broadcast_rhs_mmt, - ) - - assert False, "Shape not found" - - def get_shapes(self, template: list[str]) -> ProblemSize: - if self.is_broadcast_rhs_mmt(template): - return self.get_shapes_broadcast_rhs_mmt(template) - - for line in template: - if "linalg.generic" not in line: - continue - if "lowering_config =" not in line: - continue - if '"reduction"' not in line: - continue - - # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) - cont_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(self.lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == len(self.rhs_dims) - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() >= 2 - - M = math.prod( - val if dim == "m" else 1 - for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) - ) - N = math.prod( - val if dim == "n" else 1 - for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) - ) - K0 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) - ) - K1 = math.prod( - val if dim == "k" else 1 - for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) - ) - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.contraction, - ) - - assert False, "Shape not found" - +class ContractionTuner(DispatchTuner, ContractionParser): def get_transform_function_broadcast_rhs_mmt( self, problem_size: ProblemSize, @@ -1045,57 +508,7 @@ def apply_params( ) -class BatchMmtTuner(DispatchTuner): - def supports(self, op_name: str) -> bool: - return "batch_matmul_transpose_b" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.generic" not in line: - continue - if ( - r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' - not in line - ): - continue - # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) - bmmt_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(bmmt_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == 3 - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == 3 - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == 3 - - B0, M0, K0 = lhs_shaped_type.shape - B1, N1, K1 = rhs_shaped_type.shape - B2, M2, N2 = res_shaped_type.shape - assert B0 == B1 - assert B0 == B2 - assert M0 == M2 - assert N1 == N2 - assert K0 == K1 - return ProblemSize( - MatmulSize(M0, N1, K0, B0), - lhs_shaped_type, - rhs_shaped_type, - res_shaped_type, - DispatchKind.batch_mmt, - ) - - assert False, "Shape not found" - +class BatchMmtTuner(DispatchTuner, BatchMmtParser): def get_transform_function_batch_mmt( self, problem_size: ProblemSize, @@ -1154,78 +567,7 @@ def apply_params( return MLIRTransformation(template, modified, embeddable) -class BatchMatmulTuner(DispatchTuner): - def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): - self.lhs_dims = lhs_dims - self.rhs_dims = rhs_dims - self.tile_dims = tile_dims - - def supports(self, op_name: str) -> bool: - return "batch_matmul" in op_name - - def get_shapes(self, template: list[str]) -> ProblemSize: - for line in template: - if "linalg.batch_matmul" not in line: - continue - # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) - # outs(%12 : tensor<64x72x1280xf32>) - cont_re = ( - rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" - ) - dps = re.search(cont_re, line) - if dps is None: - continue - - lhs_tensor_type = dps.group("LHS") - rhs_tensor_type = dps.group("RHS") - lhs_shaped_type = parse_tensor_type(lhs_tensor_type) - assert lhs_shaped_type.rank() == len(self.lhs_dims) - - rhs_shaped_type = parse_tensor_type(rhs_tensor_type) - assert rhs_shaped_type.rank() == len(self.rhs_dims) - - res_tensor_type = dps.group("RES") - res_shaped_type = parse_tensor_type(res_tensor_type) - assert res_shaped_type.rank() == lhs_shaped_type.rank() - - LHS = lhs_shaped_type.shape - RHS = rhs_shaped_type.shape - RES = res_shaped_type.shape - - B = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - B0 = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS) - ) - B1 = math.prod( - val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES) - ) - M = math.prod( - val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - N = math.prod( - val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS) - ) - K0 = math.prod( - val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS) - ) - K1 = math.prod( - val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS) - ) - assert B == B0 and B == B1 - assert K0 == K1 - - return ProblemSize( - MatmulSize(M, N, K0, B), - lhs_type=lhs_shaped_type, - rhs_type=rhs_shaped_type, - res_type=res_shaped_type, - dispatch_kind=DispatchKind.batch_matmul, - ) - - assert False, "Shape not found" - +class BatchMatmulTuner(DispatchTuner, BatchMatmulParser): def get_transform_function_batch_matmul( self, problem_size: ProblemSize, @@ -1297,6 +639,12 @@ def apply_params( return MLIRTransformation(template, modified, embeddable) +@dataclass +class OpWalkResult: + was_interrupted: bool = False + dispatch_tuner: Optional[DispatchTuner] = None + + def walk_callback_get_fn( op: ir.Operation, walk_result: OpWalkResult, @@ -1350,7 +698,8 @@ def tune( mlir_text = "".join(mlir_template) with ir.Context() as ctx: - mlir_module: ir.Module = parse_mlir(mlir_text, ctx) + tuner_context = TunerContext(ctx, tune_logger) + mlir_module: ir.Module = parse_mlir(mlir_text, tuner_context) # Save the input file as the first candidate. with open(path.join(output, f"0.mlir"), "w") as f: f.write(mlir_text) diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index a1a3a3e49..63e8441d1 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -9,340 +9,48 @@ """ import pytest -from . import candidate_gen - -from iree.compiler import ir # type: ignore -from iree.compiler.dialects import func # type: ignore - - -def test_get_shaped_type_element_bitwidth() -> None: - assert ( - candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth - == 8 - ) - assert ( - candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32 - ) - assert ( - candidate_gen.ShapedType( - [2048, 512, 384], candidate_gen.ElementType.f8 - ).bitwidth - == 8 - ) - assert ( - candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16 - ) - - -def test_get_shaped_type_to_str() -> None: - assert ( - str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8)) - == "1024x2048xi8" - ) - assert ( - str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32)) - == "1024xf32" - ) - assert ( - str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16)) - == "1x2x3xf16" - ) - assert ( - str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16)) - == "?x2x3xf16" - ) - - -def test_parse_tensor_type() -> None: - assert candidate_gen.parse_tensor_type( - "tensor<1x2x3xf32>" - ) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32) - assert candidate_gen.parse_tensor_type( - "tensor<123xi8>" - ) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8) - - -def test_get_mmt_tile_sizes() -> None: - config = candidate_gen.Configuration( - subgroup_size=0, - workgroup_size=[], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - tile_sizes=[128, 320, 32], - subgroup_m_count=0, - subgroup_n_count=0, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), - waves_per_eu=0, - ) - assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32] - - -def test_get_conv_tile_sizes() -> None: - config = candidate_gen.Configuration( - subgroup_size=64, - workgroup_size=[256, 1, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - tile_sizes=[464, 320, 16], - subgroup_m_count=1, - subgroup_n_count=4, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), - waves_per_eu=1, - ) - assert candidate_gen.ConvTuner().get_conv_tile_sizes(config) == [ - 1, - 1, - 464, - 320, - 1, - 1, - 16, - ] - - -def test_gpu_pipeline_options() -> None: - options = candidate_gen.GpuPipelineOptions() - assert options.all_default() - assert str(options) == "#iree_gpu.pipeline_options<>" - - options.prefetch_shared_memory = True - assert not options.all_default() - assert str(options) == "#iree_gpu.pipeline_options" - - options.no_reduce_shared_memory_bank_conflicts = False - assert ( - str(options) - == "#iree_gpu.pipeline_options" - ) - - options = candidate_gen.GpuPipelineOptions() - options.reorder_workgroups_strategy = ( - candidate_gen.ReorderWorkgroupsStrategy.TRANSPOSE - ) - assert not options.all_default() - assert ( - str(options) - == "#iree_gpu.pipeline_options" - ) - -def test_get_contract_tile_sizes() -> None: - config = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), - waves_per_eu=2, - ) - assert candidate_gen.get_contract_tile_sizes(config, "mnk") == [4, 8, 16] - assert candidate_gen.get_contract_tile_sizes(config, "nmk") == [8, 4, 16] - assert candidate_gen.get_contract_tile_sizes(config, "knm") == [16, 8, 4] - assert candidate_gen.get_contract_tile_sizes(config, "kkk") == [ - 16, - 16, - 16, - ] - - -def test_get_pipeline_config() -> None: - config = candidate_gen.Configuration( - subgroup_size=32, - workgroup_size=[16, 16, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - tile_sizes=[4, 8, 16], - subgroup_m_count=1, - subgroup_n_count=1, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), - waves_per_eu=2, - ) - config1_str: str = candidate_gen.get_pipeline_config(config) - assert config1_str == "" - - config.waves_per_eu = 4 - config2_str: str = candidate_gen.get_pipeline_config(config) - assert config2_str == ', llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' - - config.gpu_pipeline_options.prefetch_shared_memory = True - config3_str = candidate_gen.get_pipeline_config(config) - assert ( - config3_str - == ', gpu_pipeline_options = #iree_gpu.pipeline_options, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' - ) - - -def test_get_shapes_mmt() -> None: - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert candidate_gen.MmtTuner().get_shapes(template) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, - ) - - -def test_get_shapes_conv() -> None: - template = [ - r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", - r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.ConvTuner().get_shapes(template) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(32, 256, 11520), - candidate_gen.ShapedType([1, 3, 34, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 3, 1280, 256], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1, 1, 32, 256], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.conv, - ) - - -def test_get_shapes_contract() -> None: - template = [ - r"%18 = tensor.empty() : tensor<2048x1280xf32>", - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"^bb0(%in: f16, %in_0: f16, %out: f32):", - ] - assert candidate_gen.ContractionTuner("mk", "nk", "mnk").get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.contraction, - ) - - -def test_get_shapes_batch_matmul() -> None: - template = [ - "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", - "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.BatchMatmulTuner("bmk", "bkn", "mnk").get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(32, 32, 1024, 1), - candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([1, 1024, 32], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([1, 32, 32], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, - ) - - -def test_get_shapes_batch_mmt() -> None: - template = [ - r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", - r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', - r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", - ] - assert candidate_gen.BatchMmtTuner().get_shapes( - template - ) == candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.batch_mmt, - ) - - -def test_mfma_intrinsic_to_str() -> None: - assert ( - str(candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16()) - == "MFMA_F32_16x16x16_F16" - ) - assert ( - str(candidate_gen.MfmaIntrinsic.mfma_i32_32x32x16_i8()) - == "MFMA_I32_32x32x16_I8" - ) - - -def test_get_compatible_mfma_intrinsics() -> None: - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - candidate_gen.MfmaIntrinsic.mfma_f32_32x32x8_f16(), - ] - - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 1280, 1280), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.mmt, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_i32_16x16x32_i8(), - candidate_gen.MfmaIntrinsic.mfma_i32_32x32x16_i8(), - ] - - assert candidate_gen.get_compatible_mfma_intrinsics( - candidate_gen.ProblemSize( - candidate_gen.MatmulSize(968, 320, 640, 64), - candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f32), - candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, - ) - ) == [ - candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - candidate_gen.MfmaIntrinsic.mfma_f32_32x32x8_f16(), - ] +from . import candidate_gen +from . import common def test_generate_solutions() -> None: - matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280) - lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([2048, 3840], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + matmul_size = common.MatmulSize(2048, 3840, 1280) + lhs_type = common.ShapedType([2048, 1280], common.ElementType.f16) + rhs_type = common.ShapedType([3840, 1280], common.ElementType.f16) + res_type = common.ShapedType([2048, 3840], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) configs = candidate_gen.generate_solutions(problem_size, 4) assert configs is not None def test_calculate_shared_memory_usage_in_bytes() -> None: - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) assert ( candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) == 147456 ) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i8) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + lhs_type = common.ShapedType([1024, 1024], common.ElementType.i8) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) assert ( candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) == 81920 ) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + rhs_type = common.ShapedType([1024, 1024], common.ElementType.i32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) assert ( candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) @@ -351,12 +59,12 @@ def test_calculate_shared_memory_usage_in_bytes() -> None: def test_generate_constraints_valid_input() -> None: - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) # Define input parameters as z3 Ints m, n, k = ( @@ -397,12 +105,12 @@ def test_generate_constraints_valid_input() -> None: def test_generate_constraints_invalid_input() -> None: # Define input parameters that should lead to unsatisfiable constraints - matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024) - lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16) - res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32) - problem_size = candidate_gen.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) m, n, k = ( candidate_gen.z3.Int("m"), @@ -458,25 +166,23 @@ def test_apply_params_mmt() -> None: M, N, K = 2048, 1280, 1280 - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=16, workgroup_size=[16, 16, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), tile_sizes=[8, 8, 8], subgroup_m_count=16, subgroup_n_count=16, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions( - prefetch_shared_memory=True - ), + gpu_pipeline_options=common.GpuPipelineOptions(prefetch_shared_memory=True), waves_per_eu=8, ) - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(M, N, K), - candidate_gen.ShapedType([M, K], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([N, K], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.mmt, + problem_size = common.ProblemSize( + common.MatmulSize(M, N, K), + common.ShapedType([M, K], common.ElementType.f16), + common.ShapedType([N, K], common.ElementType.f16), + common.ShapedType([M, N], common.ElementType.f32), + common.DispatchKind.mmt, ) tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config) @@ -512,27 +218,25 @@ def test_apply_params_conv() -> None: n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), tile_sizes=[464, 320, 16], subgroup_m_count=1, subgroup_n_count=4, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions( - reorder_workgroups_strategy=candidate_gen.ReorderWorkgroupsStrategy.TRANSPOSE + gpu_pipeline_options=common.GpuPipelineOptions( + reorder_workgroups_strategy=common.ReorderWorkgroupsStrategy.TRANSPOSE ), waves_per_eu=2, ) - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(oh * ow, oc, fh * fw * ic), - candidate_gen.ShapedType( - [n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16 - ), - candidate_gen.ShapedType([fh, fw, ic, oc], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.conv, + problem_size = common.ProblemSize( + common.MatmulSize(oh * ow, oc, fh * fw * ic), + common.ShapedType([n, oh + 2, ow + 2, oc], common.ElementType.f16), + common.ShapedType([fh, fw, ic, oc], common.ElementType.f16), + common.ShapedType([n, oh, ow, oc], common.ElementType.f32), + common.DispatchKind.conv, ) tf_mlir = candidate_gen.ConvTuner().apply_params( problem_size, mlir_template, config @@ -570,22 +274,22 @@ def test_apply_params_contract() -> None: ] tile_dims = "*mnk" - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(2048, 3840, 1280), - candidate_gen.ShapedType([2, 1024, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 20, 64, 1280], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([3, 2, 20, 1024, 64], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.contraction, + problem_size = common.ProblemSize( + common.MatmulSize(2048, 3840, 1280), + common.ShapedType([2, 1024, 1280], common.ElementType.f16), + common.ShapedType([3, 20, 64, 1280], common.ElementType.f16), + common.ShapedType([3, 2, 20, 1024, 64], common.ElementType.f32), + common.DispatchKind.contraction, ) - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), tile_sizes=[480, 384, 32], subgroup_m_count=1, subgroup_n_count=4, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), + gpu_pipeline_options=common.GpuPipelineOptions(), waves_per_eu=2, ) @@ -617,22 +321,22 @@ def test_apply_params_batch_matmul() -> None: ] tile_dims = "bmnk" - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(968, 320, 640, 64), - candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_matmul, + problem_size = common.ProblemSize( + common.MatmulSize(968, 320, 640, 64), + common.ShapedType([64, 968, 640], common.ElementType.f16), + common.ShapedType([64, 640, 320], common.ElementType.f16), + common.ShapedType([64, 968, 320], common.ElementType.f32), + common.DispatchKind.batch_matmul, ) - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), tile_sizes=[416, 320, 128], subgroup_m_count=2, subgroup_n_count=2, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), + gpu_pipeline_options=common.GpuPipelineOptions(), waves_per_eu=2, ) @@ -667,22 +371,22 @@ def test_apply_params_batch_mmt_float() -> None: '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.f16), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.f32), - candidate_gen.DispatchKind.batch_mmt, + problem_size = common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], common.ElementType.f16), + common.ShapedType([2, 640, 640], common.ElementType.f16), + common.ShapedType([2, 4096, 640], common.ElementType.f32), + common.DispatchKind.batch_mmt, ) - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), + gpu_pipeline_options=common.GpuPipelineOptions(), waves_per_eu=2, ) @@ -715,22 +419,22 @@ def test_apply_params_batch_mmt_int() -> None: '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.batch_mmt, + problem_size = common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], common.ElementType.i8), + common.ShapedType([2, 640, 640], common.ElementType.i8), + common.ShapedType([2, 4096, 640], common.ElementType.i32), + common.DispatchKind.batch_mmt, ) - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), + gpu_pipeline_options=common.GpuPipelineOptions(), waves_per_eu=4, ) @@ -786,22 +490,22 @@ def test_apply_params_broadcast_rhs_mmt() -> None: '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "1"}', ] - problem_size = candidate_gen.ProblemSize( - candidate_gen.MatmulSize(4096, 640, 640, 2), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([640, 640], candidate_gen.ElementType.i8), - candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32), - candidate_gen.DispatchKind.broadcast_rhs_mmt, + problem_size = common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], common.ElementType.i8), + common.ShapedType([640, 640], common.ElementType.i8), + common.ShapedType([2, 4096, 640], common.ElementType.i32), + common.DispatchKind.broadcast_rhs_mmt, ) - config = candidate_gen.Configuration( + config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=candidate_gen.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, - gpu_pipeline_options=candidate_gen.GpuPipelineOptions(), + gpu_pipeline_options=common.GpuPipelineOptions(), waves_per_eu=4, ) @@ -862,19 +566,3 @@ def test_detect_broadcast_rhs_mmt() -> None: assert candidate_gen.ContractionTuner("mk", "nk", "mnk").is_broadcast_rhs_mmt( mlir_lines ) - - -def test_parse_mlir() -> None: - with ir.Context() as ctx: - mlir_str = r""" - builtin.module { - func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> - return %0 : tensor<4xf32> - } - } - """ - mlir_module = candidate_gen.parse_mlir(mlir_str, ctx) - assert mlir_module is not None - assert isinstance(mlir_module, ir.Module) - assert isinstance(mlir_module.body.operations[0], func.FuncOp) diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py new file mode 100644 index 000000000..7b295cdb0 --- /dev/null +++ b/tuner/tuner/common.py @@ -0,0 +1,264 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import re +import logging +from dataclasses import astuple, dataclass +from enum import Enum +from typing import Optional + +from iree.compiler import ir # type: ignore + + +class TunerContext: + def __init__(self, mlir_ctx: ir.Context, logger: logging.Logger): + self.mlir_ctx = mlir_ctx + self.logger = logger + + +class DispatchKind(Enum): + conv = 1 + mmt = 2 + contraction = 3 + batch_mmt = 4 + batch_matmul = 5 + broadcast_rhs_mmt = 6 + + +class ElementType(Enum): + i8 = 1 + i32 = 2 + f8 = 3 + f16 = 4 + f32 = 5 + + @property + def bitwidth(self) -> int: + match self: + case ElementType.i8 | ElementType.f8: + return 8 + case ElementType.f16: + return 16 + case ElementType.i32 | ElementType.f32: + return 32 + case _: + assert False, "unhandled case" + + def __str__(self) -> str: + return self.name + + +@dataclass +class ShapedType: + shape: list[int] + element_type: ElementType + + def rank(self) -> int: + return len(self.shape) + + @property + def bitwidth(self) -> int: + return self.element_type.bitwidth + + def __str__(self) -> str: + dim_to_str = lambda dim: str(dim) if dim != -1 else "?" + return "x".join(map(dim_to_str, self.shape)) + "x" + str(self.element_type) + + +@dataclass +class MatmulSize: + M: int + N: int + K: int + B: int = 1 + + +@dataclass +class ProblemSize: + matmul_size: MatmulSize + lhs_type: ShapedType + rhs_type: ShapedType + res_type: ShapedType + dispatch_kind: DispatchKind + + @property + def MNK(self) -> tuple[int, int, int]: + return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K) + + +@dataclass +class MfmaIntrinsic: + output_type: ElementType + m: int + n: int + k: int + input_type: ElementType + + def __str__(self) -> str: + input = str(self.input_type).upper() + output = str(self.output_type).upper() + return f"MFMA_{output}_{self.m}x{self.n}x{self.k}_{input}" + + @staticmethod + def mfma_f32_16x16x16_f16(): + return MfmaIntrinsic(ElementType.f32, 16, 16, 16, ElementType.f16) + + @staticmethod + def mfma_f32_32x32x8_f16(): + return MfmaIntrinsic(ElementType.f32, 32, 32, 8, ElementType.f16) + + @staticmethod + def mfma_i32_16x16x32_i8(): + return MfmaIntrinsic(ElementType.i32, 16, 16, 32, ElementType.i8) + + @staticmethod + def mfma_i32_32x32x16_i8(): + return MfmaIntrinsic(ElementType.i32, 32, 32, 16, ElementType.i8) + + @staticmethod + def all(): + return [ + MfmaIntrinsic.mfma_f32_16x16x16_f16(), + MfmaIntrinsic.mfma_f32_32x32x8_f16(), + MfmaIntrinsic.mfma_i32_16x16x32_i8(), + MfmaIntrinsic.mfma_i32_32x32x16_i8(), + ] + + +def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]: + def is_compatible(intrinsic: MfmaIntrinsic) -> bool: + if problem_size.res_type.element_type != intrinsic.output_type: + return False + if problem_size.dispatch_kind != DispatchKind.batch_matmul: + if problem_size.lhs_type.element_type != intrinsic.input_type: + return False + if problem_size.rhs_type.element_type != intrinsic.input_type: + return False + return True + + return list(filter(is_compatible, MfmaIntrinsic.all())) + + +class ReorderWorkgroupsStrategy(Enum): + NONE = 0 + SWIZZLE = 1 + TRANSPOSE = 2 + + def __str__(self) -> str: + return self.name.title() + + +@dataclass +class GpuPipelineOptions: + """Represents the `iree_gpu.pipeline_options` attribute""" + + prefetch_shared_memory: Optional[bool] = None + no_reduce_shared_memory_bank_conflicts: Optional[bool] = None + reorder_workgroups_strategy: Optional[ReorderWorkgroupsStrategy] = None + + def all_default(self) -> bool: + return all(x is None for x in astuple(self)) + + def __str__(self) -> str: + options: list[str] = [] + if self.prefetch_shared_memory is not None: + options.append( + f"prefetch_shared_memory = {str(self.prefetch_shared_memory).lower()}" + ) + if self.no_reduce_shared_memory_bank_conflicts is not None: + options.append( + f"no_reduce_shared_memory_bank_conflicts = {str(self.no_reduce_shared_memory_bank_conflicts).lower()}" + ) + if self.reorder_workgroups_strategy is not None: + options.append( + f"reorder_workgroups_strategy = {self.reorder_workgroups_strategy}" + ) + + return f"#iree_gpu.pipeline_options<{', '.join(options)}>" + + +@dataclass +class Configuration: + subgroup_size: int + workgroup_size: list[int] + intrinsic: MfmaIntrinsic + tile_sizes: list[int] + subgroup_m_count: int + subgroup_n_count: int + gpu_pipeline_options: GpuPipelineOptions + waves_per_eu: int + + +def get_pipeline_config(configuration: Configuration) -> str: + extra_config = "" + if not configuration.gpu_pipeline_options.all_default(): + extra_config += f", gpu_pipeline_options = {configuration.gpu_pipeline_options}" + if configuration.waves_per_eu != 2: + extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}' + return extra_config + + +class MlirRegex(Enum): + ssa_value = r"%[a-zA-Z0-9-_]+" + tensor_type = r"tensor<(([0-9]+x)+((f|i)[0-9]+))>" + + def __str__(self) -> str: + return self.value + + @staticmethod + def dps_ins_two_args() -> str: + return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type}), (?P{MlirRegex.tensor_type})\)" + + @staticmethod + def dps_outs_one_arg() -> str: + return rf"outs\({MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type})\)" + + +def read_input_mlir(filename: str) -> list[str]: + with open(filename, "r") as f: + return f.readlines() + + +@dataclass +class ConvDimInfo: + n: int + oh: int + ow: int + oc: int + fh: int + fw: int + ic: int + + @staticmethod + def from_rhs_res(rhs_shaped_type: ShapedType, res_shaped_type: ShapedType): + n, oh, ow, oc = res_shaped_type.shape + fh, fw, ic, _ = rhs_shaped_type.shape + return ConvDimInfo(n, oh, ow, oc, fh, fw, ic) + + @staticmethod + def from_problem_size(problem_size: ProblemSize): + return ConvDimInfo.from_rhs_res(problem_size.rhs_type, problem_size.res_type) + + +def parse_tensor_type(tensor_type: str) -> ShapedType: + shape_match = re.search(str(MlirRegex.tensor_type), tensor_type) + assert shape_match + + shape_str = shape_match.group(1) + dims_and_elem = shape_str.split("x") + dims = [int(x) for x in dims_and_elem[:-1]] + elem = dims_and_elem[-1] + str_to_elem_ty = {x.name: x for x in ElementType} + return ShapedType(dims, str_to_elem_ty[elem]) + + +@dataclass +class MLIRTransformation: + """Transformation of MLIR context""" + + template: list[str] + modified: str + embeddable: str diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py new file mode 100644 index 000000000..858d593c9 --- /dev/null +++ b/tuner/tuner/common_test.py @@ -0,0 +1,131 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Usage: python -m pytest candidate_gen_test.py +""" + +import pytest +from . import common + + +def test_get_shaped_type_element_bitwidth() -> None: + assert common.ShapedType([1024, 2048], common.ElementType.i8).bitwidth == 8 + assert common.ShapedType([2048], common.ElementType.i32).bitwidth == 32 + assert common.ShapedType([2048, 512, 384], common.ElementType.f8).bitwidth == 8 + assert common.ShapedType([1, 1], common.ElementType.f16).bitwidth == 16 + + +def test_get_shaped_type_to_str() -> None: + assert str(common.ShapedType([1024, 2048], common.ElementType.i8)) == "1024x2048xi8" + assert str(common.ShapedType([1024], common.ElementType.f32)) == "1024xf32" + assert str(common.ShapedType([1, 2, 3], common.ElementType.f16)) == "1x2x3xf16" + assert str(common.ShapedType([-1, 2, 3], common.ElementType.f16)) == "?x2x3xf16" + + +def test_parse_tensor_type() -> None: + assert common.parse_tensor_type("tensor<1x2x3xf32>") == common.ShapedType( + [1, 2, 3], common.ElementType.f32 + ) + assert common.parse_tensor_type("tensor<123xi8>") == common.ShapedType( + [123], common.ElementType.i8 + ) + + +def test_gpu_pipeline_options() -> None: + options = common.GpuPipelineOptions() + assert options.all_default() + assert str(options) == "#iree_gpu.pipeline_options<>" + + options.prefetch_shared_memory = True + assert not options.all_default() + assert str(options) == "#iree_gpu.pipeline_options" + + options.no_reduce_shared_memory_bank_conflicts = False + assert ( + str(options) + == "#iree_gpu.pipeline_options" + ) + + options = common.GpuPipelineOptions() + options.reorder_workgroups_strategy = common.ReorderWorkgroupsStrategy.TRANSPOSE + assert not options.all_default() + assert ( + str(options) + == "#iree_gpu.pipeline_options" + ) + + +def test_get_pipeline_config() -> None: + config = common.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=2, + ) + config1_str: str = common.get_pipeline_config(config) + assert config1_str == "" + + config.waves_per_eu = 4 + config2_str: str = common.get_pipeline_config(config) + assert config2_str == ', llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' + + config.gpu_pipeline_options.prefetch_shared_memory = True + config3_str = common.get_pipeline_config(config) + assert ( + config3_str + == ', gpu_pipeline_options = #iree_gpu.pipeline_options, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}' + ) + + +def test_mfma_intrinsic_to_str() -> None: + assert str(common.MfmaIntrinsic.mfma_f32_16x16x16_f16()) == "MFMA_F32_16x16x16_F16" + assert str(common.MfmaIntrinsic.mfma_i32_32x32x16_i8()) == "MFMA_I32_32x32x16_I8" + + +def test_get_compatible_mfma_intrinsics() -> None: + assert common.get_compatible_mfma_intrinsics( + common.ProblemSize( + common.MatmulSize(2048, 1280, 1280), + common.ShapedType([2048, 1280], common.ElementType.f16), + common.ShapedType([1280, 1280], common.ElementType.f16), + common.ShapedType([2048, 1280], common.ElementType.f32), + common.DispatchKind.mmt, + ) + ) == [ + common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + ] + + assert common.get_compatible_mfma_intrinsics( + common.ProblemSize( + common.MatmulSize(2048, 1280, 1280), + common.ShapedType([2048, 1280], common.ElementType.i8), + common.ShapedType([1280, 1280], common.ElementType.i8), + common.ShapedType([2048, 1280], common.ElementType.i32), + common.DispatchKind.mmt, + ) + ) == [ + common.MfmaIntrinsic.mfma_i32_16x16x32_i8(), + common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + ] + + assert common.get_compatible_mfma_intrinsics( + common.ProblemSize( + common.MatmulSize(968, 320, 640, 64), + common.ShapedType([64, 968, 640], common.ElementType.f32), + common.ShapedType([64, 640, 320], common.ElementType.f32), + common.ShapedType([64, 968, 320], common.ElementType.f32), + common.DispatchKind.batch_matmul, + ) + ) == [ + common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + ] diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py new file mode 100644 index 000000000..670f8c3f7 --- /dev/null +++ b/tuner/tuner/dispatch_parser.py @@ -0,0 +1,435 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Given an input dispatch, this code modifies the hyperparameters +# in the code and runs it. + +import math +import re +from abc import ABCMeta, abstractmethod + +from .common import * + + +def get_mmt_tile_sizes(configuration: Configuration): + return configuration.tile_sizes + + +def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]: + m, n, k = configuration.tile_sizes + tile_size = [1] * len(tile_dims) + for idx, dim in enumerate(tile_dims): + if dim == "m": + tile_size[idx] = m + if dim == "n": + tile_size[idx] = n + if dim == "k": + tile_size[idx] = k + return tile_size + + +def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]: + return [1] + configuration.tile_sizes + + +def parse_mlir(mlir_text: str, ctx: TunerContext) -> ir.Module: + mlir_module = None + try: + mlir_module = ir.Module.parse(mlir_text, ctx.mlir_ctx) + ctx.logger.info("MLIR parsing successful!") + except ir.MLIRError as e: + ctx.logger.error(f"Error parsing MLIR: {e}") + raise RuntimeError(f"Error parsing MLIR: {e}") + + return mlir_module + + +class DispatchParser(metaclass=ABCMeta): + @abstractmethod + def supports(self, op_name: str) -> bool: + """Check if the tuner can handle the type of operation represented by the input string.""" + pass + + @abstractmethod + def get_shapes(self, template: list[str]) -> ProblemSize: + """Extract problem size of the operation.""" + pass + + +class MmtParser(DispatchParser): + def supports(self, op_name: str) -> bool: + return "matmul_transpose_b" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + mmt_re = None + dps = None + for line in template: + if "linalg.generic" not in line: + continue + if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line: + continue + # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) + mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + dps = re.search(mmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 2 + lhs_M, lhs_K = lhs_shaped_type.shape + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 2 + rhs_N, rhs_K = rhs_shaped_type.shape + + assert lhs_shaped_type.element_type == rhs_shaped_type.element_type + assert lhs_K == rhs_K + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 2 + res_M, res_N = res_shaped_type.shape + + assert lhs_M == res_M + assert rhs_N == res_N + + matmul_size = MatmulSize( + lhs_shaped_type.shape[0], + rhs_shaped_type.shape[0], + lhs_shaped_type.shape[1], + ) + return ProblemSize( + matmul_size, + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.mmt, + ) + assert mmt_re + assert False, f"'{mmt_re}' not found in given context" + + +class ConvParser(DispatchParser): + def supports(self, op_name: str) -> bool: + return "conv_2d_nhwc_hwcf" in op_name + + def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]: + m, n, k = configuration.tile_sizes + batch = 1 + fh = 1 + fw = 1 + + oh = 1 + + oc = n + ow = m + ic = k + return [batch, oh, ow, oc, fh, fw, ic] + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.conv_2d_nhwc_hwcf" not in line: + continue + + # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>) + conv_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(conv_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 4 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 4 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 4 + + # int64_t n = outputShape[0]; + # int64_t oh = outputShape[1]; + # int64_t ow = outputShape[2]; + # int64_t oc = outputShape[3]; + # int64_t fh = filterShape[0]; + # int64_t fw = filterShape[1]; + # int64_t ic = filterShape[2]; + dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type) + return ProblemSize( + MatmulSize( + M=dim_info.oh * dim_info.ow, + N=dim_info.oc, + K=dim_info.fh * dim_info.fw * dim_info.ic, + B=dim_info.n, + ), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.conv, + ) + + assert False, "Shape not found" + + +class ContractionParser(DispatchParser): + def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): + self.lhs_dims = lhs_dims + self.rhs_dims = rhs_dims + self.tile_dims = tile_dims + + def supports(self, op_name: str) -> bool: + return "matmul_like" in op_name + + def is_broadcast_rhs_mmt_op(self, line: str) -> bool: + if "linalg.generic" not in line: + return False + if ( + r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' + not in line + ): + return False + if ( + r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>" + not in line + ): + return False + return True + + def is_broadcast_rhs_mmt(self, template: list[str]) -> bool: + return any(self.is_broadcast_rhs_mmt_op(line) for line in template) + + def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize: + for line in template: + if not self.is_broadcast_rhs_mmt_op(line): + continue + + # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>) + bmmt_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 2 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B2 + assert M0 == M2 + assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.broadcast_rhs_mmt, + ) + + assert False, "Shape not found" + + def get_shapes(self, template: list[str]) -> ProblemSize: + if self.is_broadcast_rhs_mmt(template): + return self.get_shapes_broadcast_rhs_mmt(template) + + for line in template: + if "linalg.generic" not in line: + continue + if "lowering_config =" not in line: + continue + if '"reduction"' not in line: + continue + + # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>) + cont_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(self.lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(self.rhs_dims) + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() >= 2 + + M = math.prod( + val if dim == "m" else 1 + for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) + ) + N = math.prod( + val if dim == "n" else 1 + for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) + ) + K0 = math.prod( + val if dim == "k" else 1 + for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape) + ) + K1 = math.prod( + val if dim == "k" else 1 + for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape) + ) + assert K0 == K1 + + return ProblemSize( + MatmulSize(M, N, K0), + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.contraction, + ) + + assert False, "Shape not found" + + +class BatchMmtParser(DispatchParser): + def supports(self, op_name: str) -> bool: + return "batch_matmul_transpose_b" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.generic" not in line: + continue + if ( + r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]' + not in line + ): + continue + # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) + bmmt_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(bmmt_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == 3 + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == 3 + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == 3 + + B0, M0, K0 = lhs_shaped_type.shape + B1, N1, K1 = rhs_shaped_type.shape + B2, M2, N2 = res_shaped_type.shape + assert B0 == B1 + assert B0 == B2 + assert M0 == M2 + assert N1 == N2 + assert K0 == K1 + return ProblemSize( + MatmulSize(M0, N1, K0, B0), + lhs_shaped_type, + rhs_shaped_type, + res_shaped_type, + DispatchKind.batch_mmt, + ) + + assert False, "Shape not found" + + +class BatchMatmulParser(DispatchParser): + def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str): + self.lhs_dims = lhs_dims + self.rhs_dims = rhs_dims + self.tile_dims = tile_dims + + def supports(self, op_name: str) -> bool: + return "batch_matmul" in op_name + + def get_shapes(self, template: list[str]) -> ProblemSize: + for line in template: + if "linalg.batch_matmul" not in line: + continue + # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>) + # outs(%12 : tensor<64x72x1280xf32>) + cont_re = ( + rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}" + ) + dps = re.search(cont_re, line) + if dps is None: + continue + + lhs_tensor_type = dps.group("LHS") + rhs_tensor_type = dps.group("RHS") + lhs_shaped_type = parse_tensor_type(lhs_tensor_type) + assert lhs_shaped_type.rank() == len(self.lhs_dims) + + rhs_shaped_type = parse_tensor_type(rhs_tensor_type) + assert rhs_shaped_type.rank() == len(self.rhs_dims) + + res_tensor_type = dps.group("RES") + res_shaped_type = parse_tensor_type(res_tensor_type) + assert res_shaped_type.rank() == lhs_shaped_type.rank() + + LHS = lhs_shaped_type.shape + RHS = rhs_shaped_type.shape + RES = res_shaped_type.shape + + B = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + B0 = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS) + ) + B1 = math.prod( + val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES) + ) + M = math.prod( + val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + N = math.prod( + val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS) + ) + K0 = math.prod( + val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS) + ) + K1 = math.prod( + val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS) + ) + assert B == B0 and B == B1 + assert K0 == K1 + + return ProblemSize( + MatmulSize(M, N, K0, B), + lhs_type=lhs_shaped_type, + rhs_type=rhs_shaped_type, + res_type=res_shaped_type, + dispatch_kind=DispatchKind.batch_matmul, + ) + + assert False, "Shape not found" diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py new file mode 100644 index 000000000..bcdee240c --- /dev/null +++ b/tuner/tuner/dispatch_parser_test.py @@ -0,0 +1,176 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Usage: python -m pytest candidate_gen_test.py +""" + +import pytest + +from logging import Logger +from unittest.mock import MagicMock + +from iree.compiler import ir # type: ignore +from iree.compiler.dialects import func # type: ignore + +from . import common +from . import dispatch_parser + + +def test_get_mmt_tile_sizes() -> None: + config = dispatch_parser.Configuration( + subgroup_size=0, + workgroup_size=[], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[128, 320, 32], + subgroup_m_count=0, + subgroup_n_count=0, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=0, + ) + assert dispatch_parser.get_mmt_tile_sizes(config) == [128, 320, 32] + + +def test_get_conv_tile_sizes() -> None: + config = dispatch_parser.Configuration( + subgroup_size=64, + workgroup_size=[256, 1, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[464, 320, 16], + subgroup_m_count=1, + subgroup_n_count=4, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=1, + ) + assert dispatch_parser.ConvParser().get_conv_tile_sizes(config) == [ + 1, + 1, + 464, + 320, + 1, + 1, + 16, + ] + + +def test_get_contract_tile_sizes() -> None: + config = dispatch_parser.Configuration( + subgroup_size=32, + workgroup_size=[16, 16, 1], + intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + tile_sizes=[4, 8, 16], + subgroup_m_count=1, + subgroup_n_count=1, + gpu_pipeline_options=common.GpuPipelineOptions(), + waves_per_eu=2, + ) + assert dispatch_parser.get_contract_tile_sizes(config, "mnk") == [4, 8, 16] + assert dispatch_parser.get_contract_tile_sizes(config, "nmk") == [8, 4, 16] + assert dispatch_parser.get_contract_tile_sizes(config, "knm") == [16, 8, 4] + assert dispatch_parser.get_contract_tile_sizes(config, "kkk") == [ + 16, + 16, + 16, + ] + + +def test_get_shapes_mmt() -> None: + template = [ + r"%18 = tensor.empty() : tensor<2048x1280xf32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"^bb0(%in: f16, %in_0: f16, %out: f32):", + ] + assert dispatch_parser.MmtParser().get_shapes(template) == common.ProblemSize( + common.MatmulSize(2048, 1280, 1280), + common.ShapedType([2048, 1280], common.ElementType.f16), + common.ShapedType([1280, 1280], common.ElementType.f16), + common.ShapedType([2048, 1280], common.ElementType.f32), + dispatch_parser.DispatchKind.mmt, + ) + + +def test_get_shapes_conv() -> None: + template = [ + r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", + r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", + r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>", + ] + assert dispatch_parser.ConvParser().get_shapes(template) == common.ProblemSize( + common.MatmulSize(32, 256, 11520), + common.ShapedType([1, 3, 34, 1280], common.ElementType.f16), + common.ShapedType([3, 3, 1280, 256], common.ElementType.f16), + common.ShapedType([1, 1, 32, 256], common.ElementType.f32), + dispatch_parser.DispatchKind.conv, + ) + + +def test_get_shapes_contract() -> None: + template = [ + r"%18 = tensor.empty() : tensor<2048x1280xf32>", + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"^bb0(%in: f16, %in_0: f16, %out: f32):", + ] + assert dispatch_parser.ContractionParser("mk", "nk", "mnk").get_shapes( + template + ) == common.ProblemSize( + common.MatmulSize(2048, 1280, 1280), + common.ShapedType([2048, 1280], common.ElementType.f16), + common.ShapedType([1280, 1280], common.ElementType.f16), + common.ShapedType([2048, 1280], common.ElementType.f32), + dispatch_parser.DispatchKind.contraction, + ) + + +def test_get_shapes_batch_matmul() -> None: + template = [ + "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", + "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", + "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>", + ] + assert dispatch_parser.BatchMatmulParser("bmk", "bkn", "mnk").get_shapes( + template + ) == common.ProblemSize( + common.MatmulSize(32, 32, 1024, 1), + common.ShapedType([1, 32, 1024], common.ElementType.f32), + common.ShapedType([1, 1024, 32], common.ElementType.f32), + common.ShapedType([1, 32, 32], common.ElementType.f32), + dispatch_parser.DispatchKind.batch_matmul, + ) + + +def test_get_shapes_batch_mmt() -> None: + template = [ + r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", + r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', + r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>", + ] + assert dispatch_parser.BatchMmtParser().get_shapes(template) == common.ProblemSize( + common.MatmulSize(4096, 640, 640, 2), + common.ShapedType([2, 4096, 640], common.ElementType.i8), + common.ShapedType([2, 640, 640], common.ElementType.i8), + common.ShapedType([2, 4096, 640], common.ElementType.i32), + dispatch_parser.DispatchKind.batch_mmt, + ) + + +def test_parse_mlir() -> None: + with ir.Context() as ctx: + mlir_str = r""" + builtin.module { + func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> + return %0 : tensor<4xf32> + } + } + """ + logger: Logger = MagicMock(spec=Logger) + tuner_context = common.TunerContext(ctx, logger) + mlir_module = dispatch_parser.parse_mlir(mlir_str, tuner_context) + assert mlir_module is not None + assert isinstance(mlir_module, ir.Module) + assert isinstance(mlir_module.body.operations[0], func.FuncOp) From a5c152723a8e5261e9736d298d870779f81fada3 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 15 Nov 2024 11:09:46 -0500 Subject: [PATCH 49/59] [tuner] Move constraint generation out of canddiate_gen. NFC. (#539) This is just code motion to make the code more modular. --- tuner/tuner/candidate_gen.py | 200 ++--------------------- tuner/tuner/candidate_gen_test.py | 135 --------------- tuner/tuner/dispatch_constraints.py | 197 ++++++++++++++++++++++ tuner/tuner/dispatch_constraints_test.py | 161 ++++++++++++++++++ 4 files changed, 368 insertions(+), 325 deletions(-) create mode 100644 tuner/tuner/dispatch_constraints.py create mode 100644 tuner/tuner/dispatch_constraints_test.py diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 06ccae0e3..2f21520f0 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -22,7 +22,6 @@ import logging import pickle import re -import z3 # type: ignore from dataclasses import dataclass from os import path, makedirs from typing import Optional @@ -32,6 +31,7 @@ from iree.compiler import ir # type: ignore from .common import * +from .dispatch_constraints import * from .dispatch_parser import * tune_logger = logging.getLogger("tune") @@ -73,194 +73,6 @@ def apply_configuration( return new_mlir -def get_mfma_intrinsic_constraints( - problem_size: ProblemSize, - intrinsic_m: z3.ArithRef, - intrinsic_n: z3.ArithRef, - intrinsic_k: z3.ArithRef, -) -> z3.BoolRef: - compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size) - assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" - return z3.Or( - *( - z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k) - for mfma in compatible_intrinsics - ) - ) - - -def get_dispatch_constraints( - problem_size: ProblemSize, - tile_m: z3.ArithRef, - tile_n: z3.ArithRef, - tile_k: z3.ArithRef, -) -> list[z3.BoolRef]: - if problem_size.dispatch_kind != DispatchKind.conv: - return [] - - dim_info = ConvDimInfo.from_problem_size(problem_size) - conv_constraints = [] - # WARNING: This sometimes makes the constraints UNSAT for some reason. - conv_constraints += [tile_m <= dim_info.ow] - conv_constraints += [tile_n <= dim_info.oc] - conv_constraints += [tile_k <= dim_info.ic] - return conv_constraints - - -def calculate_shared_memory_usage_in_bytes( - problem_size: ProblemSize, - m: int | z3.ArithRef, - n: int | z3.ArithRef, - k: int | z3.ArithRef, -) -> int | z3.ArithRef: - lhs_memory = m * k * (problem_size.lhs_type.bitwidth // 8) - rhs_memory = k * n * (problem_size.rhs_type.bitwidth // 8) - return lhs_memory + rhs_memory - - -def generate_constraints( - problem_size: ProblemSize, - tile_sizes, - num_subgroups, - subgroup_size, - intrinsic_size, - workgroup_size, - subgroup_m_count, - subgroup_n_count, - waves_per_eu, -): - M, N, K = ( - problem_size.matmul_size.M, - problem_size.matmul_size.N, - problem_size.matmul_size.K, - ) - m, n, k = tile_sizes - intrinsic_mn, intrinsic_k = intrinsic_size - wg_x, wg_y, wg_z = workgroup_size - wg_threads = z3.Int("wg_threads") - constraints = [wg_threads == wg_x * wg_y * wg_z] - constraints += [subgroup_size == 64, wg_threads <= 1024] - constraints += [ - get_mfma_intrinsic_constraints( - problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k - ) - ] - subgroup_k_count = 1 - constraints += [ - m >= intrinsic_mn, - m <= 512, - m <= M, - ] - constraints += [n >= intrinsic_mn, n <= 512, n <= N, N % n == 0] - constraints += [k >= intrinsic_k, k <= 512, k <= K, K % k == 0] - for x in (subgroup_m_count, subgroup_n_count): - constraints += [x >= 1, x <= 32] - - subgroup_m_tile_count = z3.Int("sg_m_tcnt") - subgroup_n_tile_count = z3.Int("sg_n_tcnt") - subgroup_k_tile_count = z3.Int("sg_k_tcnt") - for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count): - constraints += [x >= 1, x <= 32] - - constraints += [m == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn] - constraints += [n == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn] - constraints += [k == subgroup_k_count * subgroup_k_tile_count * intrinsic_k] - constraints += [wg_x == subgroup_size * subgroup_n_count] - constraints += [wg_y == subgroup_m_count] - constraints += [wg_z == subgroup_k_count] - constraints += [z3.Or(wg_x <= n, wg_x <= m)] - constraints += [k % intrinsic_mn == 0] - constraints += [(k * n) % wg_threads == 0] - constraints += [(k * m) % wg_threads == 0] - subgroups = subgroup_m_count * subgroup_n_count - if num_subgroups > 0: - constraints += [subgroups == num_subgroups] - else: - constraints += [subgroups >= 1, subgroups <= 10] - - constraints += [waves_per_eu == 2] - # constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)] - - shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k) - constraints += [shared_memory <= 65536] - - constraints += get_dispatch_constraints(problem_size, m, n, k) - - return constraints - - -def generate_solutions(problem_size: ProblemSize, num_subgrups: int): - M, N, K = problem_size.MNK - tune_logger.info(f"{M},{N},{K}") - m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") - subgroup_size = z3.Int("subgroup_size") - intrinsic_mn = z3.Int("intrinsic_mn") - intrinsic_k = z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z") - sg_m_cnt = z3.Int("sg_m_cnt") - sg_n_cnt = z3.Int("sg_n_cnt") - waves_per_eu = z3.Int("waves_per_eu") - all_vars = [ - m, - n, - k, - subgroup_size, - intrinsic_mn, - intrinsic_k, - wg_x, - wg_y, - wg_z, - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ] - - solver = z3.Solver() - constraints = generate_constraints( - problem_size, - [m, n, k], - num_subgrups, - subgroup_size, - [intrinsic_mn, intrinsic_k], - [wg_x, wg_y, wg_z], - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ) - solver.add(z3.simplify(z3.And(constraints))) - tune_logger.debug(f"Initial constraints: {solver}") - i = 0 - while solver.check() == z3.sat: - model = solver.model() - lookup = lambda var: model[var].as_long() - - config = Configuration( - lookup(subgroup_size), - [lookup(wg_x), lookup(wg_y), lookup(wg_z)], - MfmaIntrinsic( - problem_size.res_type.element_type, - lookup(intrinsic_mn), - lookup(intrinsic_mn), - lookup(intrinsic_k), - problem_size.lhs_type.element_type, - ), - [lookup(m), lookup(n), lookup(k)], - lookup(sg_m_cnt), - lookup(sg_n_cnt), - GpuPipelineOptions(), - lookup(waves_per_eu), - ) - solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars))))) - i += 1 - yield config - - -def get_default_output_dir() -> str: - from datetime import datetime - - return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M") - - class DispatchTuner(DispatchParser): # TODO(https://github.com/nod-ai/SHARK-Platform/issues/453): Remove this in favor of configuring using transform dialect. @abstractmethod @@ -675,6 +487,12 @@ def walk_mlir_op( return walk_result +def get_default_output_dir() -> str: + from datetime import datetime + + return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M") + + def tune( input: str, # Path to the mlir file to be tuned output: str = "", # Path to the output directory, auto creates one if not given @@ -722,7 +540,9 @@ def tune( problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template) tune_logger.debug(str(problem_size)) configs = [] - for i, config in enumerate(generate_solutions(problem_size, num_subgroups)): + for i, config in enumerate( + generate_solutions(tuner_context, problem_size, num_subgroups) + ): if i >= limit: break tune_logger.info(f"Solution #{i+1}: {config}") diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 63e8441d1..47e351fc7 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -14,141 +14,6 @@ from . import common -def test_generate_solutions() -> None: - matmul_size = common.MatmulSize(2048, 3840, 1280) - lhs_type = common.ShapedType([2048, 1280], common.ElementType.f16) - rhs_type = common.ShapedType([3840, 1280], common.ElementType.f16) - res_type = common.ShapedType([2048, 3840], common.ElementType.f32) - problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt - ) - configs = candidate_gen.generate_solutions(problem_size, 4) - assert configs is not None - - -def test_calculate_shared_memory_usage_in_bytes() -> None: - matmul_size = common.MatmulSize(1024, 1024, 1024) - lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - res_type = common.ShapedType([1024, 1024], common.ElementType.f32) - problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) - == 147456 - ) - - lhs_type = common.ShapedType([1024, 1024], common.ElementType.i8) - problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) - == 81920 - ) - - rhs_type = common.ShapedType([1024, 1024], common.ElementType.i32) - problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt - ) - assert ( - candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) - == 12288 - ) - - -def test_generate_constraints_valid_input() -> None: - matmul_size = common.MatmulSize(1024, 1024, 1024) - lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - res_type = common.ShapedType([1024, 1024], common.ElementType.f32) - problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt - ) - # Define input parameters as z3 Ints - m, n, k = ( - candidate_gen.z3.Int("m"), - candidate_gen.z3.Int("n"), - candidate_gen.z3.Int("k"), - ) - subgroup_size = candidate_gen.z3.Int("subgroup_size") - intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") - intrinsic_k = candidate_gen.z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = ( - candidate_gen.z3.Int("wg_x"), - candidate_gen.z3.Int("wg_y"), - candidate_gen.z3.Int("wg_z"), - ) - sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") - sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") - waves_per_eu = candidate_gen.z3.Int("waves_per_eu") - - constraints = candidate_gen.generate_constraints( - problem_size, - [m, n, k], - 4, - subgroup_size, - [intrinsic_mn, intrinsic_k], - [wg_x, wg_y, wg_z], - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ) - - solver = candidate_gen.z3.Solver() - solver.add(constraints) - - # Check if the constraints are satisfiable - assert solver.check() == candidate_gen.z3.sat - - -def test_generate_constraints_invalid_input() -> None: - # Define input parameters that should lead to unsatisfiable constraints - matmul_size = common.MatmulSize(1024, 1024, 1024) - lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - res_type = common.ShapedType([1024, 1024], common.ElementType.f32) - problem_size = common.ProblemSize( - matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt - ) - m, n, k = ( - candidate_gen.z3.Int("m"), - candidate_gen.z3.Int("n"), - candidate_gen.z3.Int("k"), - ) - subgroup_size = candidate_gen.z3.Int("subgroup_size") - intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn") - intrinsic_k = candidate_gen.z3.Int("intrinsic_k") - wg_x, wg_y, wg_z = ( - candidate_gen.z3.Int("wg_x"), - candidate_gen.z3.Int("wg_y"), - candidate_gen.z3.Int("wg_z"), - ) - sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt") - sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt") - waves_per_eu = candidate_gen.z3.Int("waves_per_eu") - - constraints = candidate_gen.generate_constraints( - problem_size, - [m, n, k], - 4, - subgroup_size, - [intrinsic_mn, intrinsic_k], - [wg_x, wg_y, wg_z], - sg_m_cnt, - sg_n_cnt, - waves_per_eu, - ) - constraints.append(m > 1000) # Adding an additional unsatisfiable constraint - - solver = candidate_gen.z3.Solver() - solver.add(constraints) - - # Check if the constraints are unsatisfiable - assert solver.check() == candidate_gen.z3.unsat - - def remove_comments(mlir: str) -> str: return "\n".join( filter(lambda x: not x.lstrip().startswith("//"), mlir.splitlines()) diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py new file mode 100644 index 000000000..ac46d8edd --- /dev/null +++ b/tuner/tuner/dispatch_constraints.py @@ -0,0 +1,197 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Given an input dispatch, this code modifies the hyperparameters +# in the code and runs it. + +import z3 # type: ignore +from typing import Iterator + +from .common import * + + +def get_mfma_intrinsic_constraints( + problem_size: ProblemSize, + intrinsic_m: z3.ArithRef, + intrinsic_n: z3.ArithRef, + intrinsic_k: z3.ArithRef, +) -> z3.BoolRef: + compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size) + assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" + return z3.Or( + *( + z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k) + for mfma in compatible_intrinsics + ) + ) + + +def get_dispatch_constraints( + problem_size: ProblemSize, + tile_m: z3.ArithRef, + tile_n: z3.ArithRef, + tile_k: z3.ArithRef, +) -> list[z3.BoolRef]: + if problem_size.dispatch_kind != DispatchKind.conv: + return [] + + dim_info = ConvDimInfo.from_problem_size(problem_size) + conv_constraints = [] + # WARNING: This sometimes makes the constraints UNSAT for some reason. + conv_constraints += [tile_m <= dim_info.ow] + conv_constraints += [tile_n <= dim_info.oc] + conv_constraints += [tile_k <= dim_info.ic] + return conv_constraints + + +def calculate_shared_memory_usage_in_bytes( + problem_size: ProblemSize, + m: int | z3.ArithRef, + n: int | z3.ArithRef, + k: int | z3.ArithRef, +) -> int | z3.ArithRef: + lhs_memory = m * k * (problem_size.lhs_type.bitwidth // 8) + rhs_memory = k * n * (problem_size.rhs_type.bitwidth // 8) + return lhs_memory + rhs_memory + + +def generate_constraints( + problem_size: ProblemSize, + tile_sizes, + num_subgroups, + subgroup_size, + intrinsic_size, + workgroup_size, + subgroup_m_count, + subgroup_n_count, + waves_per_eu, +): + M, N, K = ( + problem_size.matmul_size.M, + problem_size.matmul_size.N, + problem_size.matmul_size.K, + ) + m, n, k = tile_sizes + intrinsic_mn, intrinsic_k = intrinsic_size + wg_x, wg_y, wg_z = workgroup_size + wg_threads = z3.Int("wg_threads") + constraints = [wg_threads == wg_x * wg_y * wg_z] + constraints += [subgroup_size == 64, wg_threads <= 1024] + constraints += [ + get_mfma_intrinsic_constraints( + problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k + ) + ] + subgroup_k_count = 1 + constraints += [ + m >= intrinsic_mn, + m <= 512, + m <= M, + ] + constraints += [n >= intrinsic_mn, n <= 512, n <= N, N % n == 0] + constraints += [k >= intrinsic_k, k <= 512, k <= K, K % k == 0] + for x in (subgroup_m_count, subgroup_n_count): + constraints += [x >= 1, x <= 32] + + subgroup_m_tile_count = z3.Int("sg_m_tcnt") + subgroup_n_tile_count = z3.Int("sg_n_tcnt") + subgroup_k_tile_count = z3.Int("sg_k_tcnt") + for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count): + constraints += [x >= 1, x <= 32] + + constraints += [m == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn] + constraints += [n == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn] + constraints += [k == subgroup_k_count * subgroup_k_tile_count * intrinsic_k] + constraints += [wg_x == subgroup_size * subgroup_n_count] + constraints += [wg_y == subgroup_m_count] + constraints += [wg_z == subgroup_k_count] + constraints += [z3.Or(wg_x <= n, wg_x <= m)] + constraints += [k % intrinsic_mn == 0] + constraints += [(k * n) % wg_threads == 0] + constraints += [(k * m) % wg_threads == 0] + subgroups = subgroup_m_count * subgroup_n_count + if num_subgroups > 0: + constraints += [subgroups == num_subgroups] + else: + constraints += [subgroups >= 1, subgroups <= 10] + + constraints += [waves_per_eu == 2] + # constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)] + + shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k) + constraints += [shared_memory <= 65536] + + constraints += get_dispatch_constraints(problem_size, m, n, k) + + return constraints + + +def generate_solutions( + ctx: TunerContext, problem_size: ProblemSize, num_subgrups: int +) -> Iterator[Configuration]: + M, N, K = problem_size.MNK + ctx.logger.info(f"{M},{N},{K}") + m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") + subgroup_size = z3.Int("subgroup_size") + intrinsic_mn = z3.Int("intrinsic_mn") + intrinsic_k = z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z") + sg_m_cnt = z3.Int("sg_m_cnt") + sg_n_cnt = z3.Int("sg_n_cnt") + waves_per_eu = z3.Int("waves_per_eu") + all_vars = [ + m, + n, + k, + subgroup_size, + intrinsic_mn, + intrinsic_k, + wg_x, + wg_y, + wg_z, + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ] + + solver = z3.Solver() + constraints = generate_constraints( + problem_size, + [m, n, k], + num_subgrups, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ) + solver.add(z3.simplify(z3.And(constraints))) + ctx.logger.debug(f"Initial constraints: {solver}") + i = 0 + while solver.check() == z3.sat: + model = solver.model() + lookup = lambda var: model[var].as_long() + + config = Configuration( + lookup(subgroup_size), + [lookup(wg_x), lookup(wg_y), lookup(wg_z)], + MfmaIntrinsic( + problem_size.res_type.element_type, + lookup(intrinsic_mn), + lookup(intrinsic_mn), + lookup(intrinsic_k), + problem_size.lhs_type.element_type, + ), + [lookup(m), lookup(n), lookup(k)], + lookup(sg_m_cnt), + lookup(sg_n_cnt), + GpuPipelineOptions(), + lookup(waves_per_eu), + ) + solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars))))) + i += 1 + yield config diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py new file mode 100644 index 000000000..55f3a8c43 --- /dev/null +++ b/tuner/tuner/dispatch_constraints_test.py @@ -0,0 +1,161 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Usage: python -m pytest candidate_gen_test.py +""" + +import pytest +import z3 # type: ignore + +from logging import Logger +from unittest.mock import MagicMock + +from . import common +from . import dispatch_constraints + + +def test_generate_solutions() -> None: + matmul_size = common.MatmulSize(2048, 3840, 1280) + lhs_type = common.ShapedType([2048, 1280], common.ElementType.f16) + rhs_type = common.ShapedType([3840, 1280], common.ElementType.f16) + res_type = common.ShapedType([2048, 3840], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + logger: Logger = MagicMock(spec=Logger) + ctx = common.TunerContext(None, logger) + configs = dispatch_constraints.generate_solutions(ctx, problem_size, 4) + assert configs is not None + + +def test_calculate_shared_memory_usage_in_bytes() -> None: + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + assert ( + dispatch_constraints.calculate_shared_memory_usage_in_bytes( + problem_size, 512, 64, 128 + ) + == 147456 + ) + + lhs_type = common.ShapedType([1024, 1024], common.ElementType.i8) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + assert ( + dispatch_constraints.calculate_shared_memory_usage_in_bytes( + problem_size, 512, 64, 128 + ) + == 81920 + ) + + rhs_type = common.ShapedType([1024, 1024], common.ElementType.i32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + assert ( + dispatch_constraints.calculate_shared_memory_usage_in_bytes( + problem_size, 128, 64, 32 + ) + == 12288 + ) + + +def test_generate_constraints_valid_input() -> None: + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + # Define input parameters as z3 Ints + m, n, k = ( + dispatch_constraints.z3.Int("m"), + z3.Int("n"), + z3.Int("k"), + ) + subgroup_size = z3.Int("subgroup_size") + intrinsic_mn = z3.Int("intrinsic_mn") + intrinsic_k = z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + z3.Int("wg_x"), + z3.Int("wg_y"), + z3.Int("wg_z"), + ) + sg_m_cnt = z3.Int("sg_m_cnt") + sg_n_cnt = z3.Int("sg_n_cnt") + waves_per_eu = z3.Int("waves_per_eu") + + constraints = dispatch_constraints.generate_constraints( + problem_size, + [m, n, k], + 4, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ) + + solver = z3.Solver() + solver.add(constraints) + + # Check if the constraints are satisfiable + assert solver.check() == z3.sat + + +def test_generate_constraints_invalid_input() -> None: + # Define input parameters that should lead to unsatisfiable constraints + matmul_size = common.MatmulSize(1024, 1024, 1024) + lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) + res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + problem_size = common.ProblemSize( + matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt + ) + m, n, k = ( + z3.Int("m"), + z3.Int("n"), + z3.Int("k"), + ) + subgroup_size = z3.Int("subgroup_size") + intrinsic_mn = z3.Int("intrinsic_mn") + intrinsic_k = z3.Int("intrinsic_k") + wg_x, wg_y, wg_z = ( + z3.Int("wg_x"), + z3.Int("wg_y"), + z3.Int("wg_z"), + ) + sg_m_cnt = z3.Int("sg_m_cnt") + sg_n_cnt = z3.Int("sg_n_cnt") + waves_per_eu = z3.Int("waves_per_eu") + + constraints = dispatch_constraints.generate_constraints( + problem_size, + [m, n, k], + 4, + subgroup_size, + [intrinsic_mn, intrinsic_k], + [wg_x, wg_y, wg_z], + sg_m_cnt, + sg_n_cnt, + waves_per_eu, + ) + constraints.append(m > 1000) # Adding an additional unsatisfiable constraint + + solver = z3.Solver() + solver.add(constraints) + + # Check if the constraints are unsatisfiable + assert solver.check() == z3.unsat From 7ff0534fe54dcb03a57bae6bfda1bc753cf95654 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 15 Nov 2024 11:17:48 -0500 Subject: [PATCH 50/59] [tuner] Update README (#532) Add a warning about the tuner being in early development. Update the tuner readme and remove obsolete instructions. --- README.md | 5 +++++ tuner/README.md | 53 ++++++++++--------------------------------------- 2 files changed, 16 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index f5a255c84..517980838 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,11 @@ conversion tools to produce inference-optimized programs. The Tuner sub-project assists with tuning program performance by searching for optimal parameter configurations to use during model compilation. +> [!WARNING] +> SHARK Tuner is still in early development. Interested users may want +> to try it out, but the tuner is not ready for general use yet. Check out +> [the readme](tuner/README.md) for more details. + ## Support matrix diff --git a/tuner/README.md b/tuner/README.md index 47156779c..e6a515729 100644 --- a/tuner/README.md +++ b/tuner/README.md @@ -1,5 +1,8 @@ # IREE dispatch auto-tuning scripts -`libtuner.py` is the core Python script that provides the fundamental functions for the tuning loop. It imports `candidate_gen.py` for candidate generation. To implement the full tuning loop, `libtuner.py` requires a separate Python script that uses the provided `TuningClient` API from `libtuner.py`. +`libtuner.py` is the core Python script that provides the fundamental functions +for the tuning loop. It imports `candidate_gen.py` for candidate generation. To +implement the full tuning loop, `libtuner.py` requires a separate Python script +that uses the provided `TuningClient` API from `libtuner.py`. ## Prerequisites [Optional] Using virtual environments: @@ -22,47 +25,13 @@ Using the IREE's Python bindings: - Set environment ```shell source ../iree-build/.env && export PYTHONPATH + export PATH="$(realpath ../iree-build/tools):$PATH" ``` -For more information, refer to the [IREE documentation](https://iree.dev/building-from-source/getting-started/#python-bindings) +For more information, refer to the [IREE +documentation](https://iree.dev/building-from-source/getting-started/#python-bindings). -### Overall flow +## Examples -1. Symlink all scripts and mlir/irpa files in your build dir. - - Symlink `iree-build-dir/tools` inside `tuning`. - - Symlink ML model MLIR and weights based on `unet.sh`. - -2. Copy the attention/matmul spec as `config.mlir` in the tuning dir. - -3. Temporarily comment out all the existing configs in `config.mlir`. - - Example: - ```mlir - // , @match_mmt_2048x10240x1280 -> @apply_op_config - // , @match_mmt_2048x1280x5120 -> @apply_op_config - // , @match_mmt_2048x1280x1280 -> @apply_op_config - ``` - -4. Compile a baseline unet -```shell -./unet.sh winograd unet.mlir -o unet_baseline.vmfb --iree-hal-dump-executable-files-to=dump-winograd -``` - -5. Find the matmul to tune and copy the `*_benchmark.mlir` file to the build dir. -```shell -cp dump-winograd/*_141_*benchmark.mlir ./141.mlir -``` - -6. Run the tuning script. - - Example: - ```shell - python -m examples.punet 141.mlir --devices=hip://GPU-0,hip://GPU-4 --num-candidates=1024 - ``` - -7. Check the winner candidate in `result_summary.log`, find and copy the transform spec. - -8. Paste the transform spec into the `config.mlir` and uncomment them. - -9. Add the match function to the entry point in `config.mlir` - - Example: - ```mlir - @match_something -> @apply_op_config - ``` +Check the `examples` directory for sample tuners implemented with `libtuner`. +The [`dispatch` example](https://github.com/nod-ai/SHARK-Platform/tree/main/tuner/examples/dispatch) +should be a good starting point for most users. From 61a211f9625d7756ecf12e323f35f8e210a62085 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 15 Nov 2024 18:00:30 +0100 Subject: [PATCH 51/59] [shark-ai] Update README (#536) `sharktank` is currently not pulled in via the meta package. --- shark-ai/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shark-ai/README.md b/shark-ai/README.md index 93bdfd671..0bb1abafd 100644 --- a/shark-ai/README.md +++ b/shark-ai/README.md @@ -1,3 +1,3 @@ # SHARK AI meta package -Meta package to install `sharktank` and `shortfin`. +Meta package to install `shortfin` and compatible IREE packages. From 292e3759040edf7dfacfbed4344508f7c2137f12 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 15 Nov 2024 09:06:35 -0800 Subject: [PATCH 52/59] Add scripting and documentation for pushing releases to PyPI. (#519) Progress on https://github.com/nod-ai/SHARK-Platform/issues/400. This scripting allows us to publish .whl files from https://github.com/nod-ai/SHARK-Platform/releases/tag/dev-wheels to [PyPI](https://pypi.org/). Here are the basic steps: 1. Download wheels for a specific pre-release (e.g. `2.9.1rc20241114`) 2. Edit the versions in the downloaded wheels to remove the `rcYYYYMMDD` suffix 3. Build the `shark-ai` meta package using the versions in those whls (NOTE: currently this uses the versions _in the source tree_, see below) 4. Upload all wheels to PyPI Logs of this running: https://gist.github.com/ScottTodd/9c7418d5bbbebc8aea72a39bc2ac37b0 (to push 2.9.1 to PyPI) The new `pypi_deploy.sh` script is based on the similar script we maintain in IREE: https://github.com/iree-org/iree/blob/main/build_tools/python_deploy/pypi_deploy.sh . The `README.md` file is also forked from IREE. Known sharp edges with the publishing process: * This currently mixes local information (versions from the source tree where this is running) with remote information (versions from the nightly release packages). Publishing nightly meta packages will let us simplify and make this safer to run from arbitrary working trees * The local build of the shark-ai meta package copies _all_ files in `shark-ai/build_tools/wheelhouse/` to the working directory that gets sent to `twine upload *`. If there are preexisting files in that directory they will be published. * The step that downloads releases from GitHub uses `pip download` so it can leverage PyPI's version resolution logic, but I couldn't figure out a way to download 3.13t wheels without running a 3.13t python interpreter. I happen to have 3.13t installed on my system, but we shouldn't require that in the scripting for all release engineers. We could try using the `gh` tool as in IREE, if we properly filter the download to just the version we want: https://github.com/iree-org/iree/blob/9eaa4ef7d6b439d8c444b533beddd82146578e25/build_tools/python_deploy/pypi_deploy.sh#L69-L72 (that has one release per nightly, while here we have a single release shared between all nightlies) --------- Co-authored-by: Marius Brehler --- build_tools/python_deploy/README.md | 48 +++++++ .../python_deploy/compute_common_version.py | 8 +- .../python_deploy/compute_local_version.py | 0 .../promote_whl_from_rc_to_final.py | 0 build_tools/python_deploy/pypi_deploy.sh | 126 ++++++++++++++++++ .../requirements-pypi-deploy.txt | 0 .../python_deploy/write_requirements.py | 0 shark-ai/build_tools/build_linux_package.sh | 25 ++++ 8 files changed, 205 insertions(+), 2 deletions(-) create mode 100644 build_tools/python_deploy/README.md mode change 100644 => 100755 build_tools/python_deploy/compute_common_version.py mode change 100644 => 100755 build_tools/python_deploy/compute_local_version.py rename build_tools/{ => python_deploy}/promote_whl_from_rc_to_final.py (100%) create mode 100755 build_tools/python_deploy/pypi_deploy.sh rename build_tools/{ => python_deploy}/requirements-pypi-deploy.txt (100%) mode change 100644 => 100755 build_tools/python_deploy/write_requirements.py create mode 100755 shark-ai/build_tools/build_linux_package.sh diff --git a/build_tools/python_deploy/README.md b/build_tools/python_deploy/README.md new file mode 100644 index 000000000..d36545a9c --- /dev/null +++ b/build_tools/python_deploy/README.md @@ -0,0 +1,48 @@ +# Python Deployment + +These scripts assist with building Python packages and pushing them to +[PyPI (the Python Package Index)](https://pypi.org/). See also + +* The Python Packaging User Guide: + +## Overview + +See comments in scripts for canonical usage. This page includes additional +notes. + +### Package building + +These scripts build packages: + +* [`/shark-ai/build_tools/build_linux_package.sh`](/shark-ai/build_tools/build_linux_package.sh) +* [`/sharktank/build_tools/build_linux_package.sh`](/sharktank/build_tools/build_linux_package.sh) +* [`/shortfin/build_tools/build_linux_package.sh`](/shortfin/build_tools/build_linux_package.sh) + +### Version management + +These scripts handle versioning across packages, including considerations like +major, minor, and patch levels (`X.Y.Z`), as well as suffixes like +`rc20241107`: + +* [`compute_common_version.py`](./compute_common_version.py) +* [`compute_local_version.py`](./compute_local_version.py) +* [`promote_whl_from_rc_to_final.py`](./promote_whl_from_rc_to_final.py) +* [`write_requirements.py`](./write_requirements.py) + +### PyPI deployment + +These scripts handle promoting nightly releases packages to stable and pushing +to PyPI: + +* [`promote_whl_from_rc_to_final.py`](./promote_whl_from_rc_to_final.py) +* [`pypi_deploy.sh`](./pypi_deploy.sh) + +Both of these scripts expect to have the dependencies from +[`requirements-pypi-deploy.txt`](./requirements-pypi-deploy.txt) installed. +This can be easily managed by using a Python virtual environment: + +```bash +python -m venv .venv +source .venv/bin/activate +python -m pip install -r ./requirements-pypi-deploy.txt +``` diff --git a/build_tools/python_deploy/compute_common_version.py b/build_tools/python_deploy/compute_common_version.py old mode 100644 new mode 100755 index 6aea7f254..ed6f8c708 --- a/build_tools/python_deploy/compute_common_version.py +++ b/build_tools/python_deploy/compute_common_version.py @@ -6,8 +6,12 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # This scripts grabs the `X.Y.Z[.dev]` version identifier from the -# sharktank and shortfin version files and computes the version -# for the meta package. +# 'sharktank' and 'shortfin' version files and computes the version +# for the meta 'shark-ai' package. +# +# Usage: +# ./compute_common_version.py --stable-release --write-json +# cat ../../shark-ai/version_local.json import argparse from pathlib import Path diff --git a/build_tools/python_deploy/compute_local_version.py b/build_tools/python_deploy/compute_local_version.py old mode 100644 new mode 100755 diff --git a/build_tools/promote_whl_from_rc_to_final.py b/build_tools/python_deploy/promote_whl_from_rc_to_final.py similarity index 100% rename from build_tools/promote_whl_from_rc_to_final.py rename to build_tools/python_deploy/promote_whl_from_rc_to_final.py diff --git a/build_tools/python_deploy/pypi_deploy.sh b/build_tools/python_deploy/pypi_deploy.sh new file mode 100755 index 000000000..c141aea4f --- /dev/null +++ b/build_tools/python_deploy/pypi_deploy.sh @@ -0,0 +1,126 @@ +#!/bin/bash + +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This script promotes Python packages from nightly releases to PyPI. +# +# Prerequisites: +# * You will need to have PyPI credentials set up. See +# https://packaging.python.org/en/latest/tutorials/packaging-projects/#uploading-the-distribution-archives +# * Install requirements, e.g. in a Python virtual environment (venv): +# `pip install -r requirements-pypi-deploy.txt` +# * Install python3.13t and install pip. On Ubuntu: +# ```bash +# sudo add-apt-repository ppa:deadsnakes +# sudo apt-get update +# sudo apt-get install python3.13-nogil +# python3.13t -m ensurepip --upgrade +# ``` +# * Choose a release candidate to promote from +# https://github.com/nod-ai/SHARK-Platform/releases/tag/dev-wheels +# +# Usage: +# ./pypi_deploy.sh 2.9.0rc20241108 + +set -euo pipefail + +RELEASE="$1" + +SCRIPT_DIR="$(dirname -- "$( readlink -f -- "$0"; )")"; +REPO_ROOT="$(cd "$SCRIPT_DIR"/../../ && pwd)" +TMPDIR="$(mktemp --directory --tmpdir shark_platform_pypi_wheels.XXXXX)" +ASSETS_PAGE="https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels" + +# TODO: rewrite in Python? + +function download_wheels() { + echo "" + echo "Downloading wheels for '${RELEASE}'..." + + # sharktank + python -m pip download sharktank==${RELEASE} \ + --no-deps --python-version 3.11 -f ${ASSETS_PAGE} + + # shortfin + python -m pip download shortfin==${RELEASE} \ + --no-deps --python-version 3.11 -f ${ASSETS_PAGE} + python -m pip download shortfin==${RELEASE} \ + --no-deps --python-version 3.12 -f ${ASSETS_PAGE} + python -m pip download shortfin==${RELEASE} \ + --no-deps --python-version 3.13 -f ${ASSETS_PAGE} + python -m pip download shortfin==${RELEASE} \ + --no-deps --python-version 3.13 -f ${ASSETS_PAGE} + # TODO: fetch 3.13t using the same `python` somehow + # * https://pip.pypa.io/en/stable/cli/pip_download/ + # * https://py-free-threading.github.io/installing_cpython/ + # * https://pip.pypa.io/en/stable/installation/ + python3.13t -m pip download shortfin==${RELEASE} --no-deps -f ${ASSETS_PAGE} + + # TODO: shark-ai meta package when it is published to nightlies + + echo "" + echo "Downloaded wheels:" + ls +} + +function edit_release_versions() { + echo "" + echo "Editing release versions..." + for file in * + do + ${SCRIPT_DIR}/promote_whl_from_rc_to_final.py ${file} --delete-old-wheel + done + + echo "Edited wheels:" + ls +} + +function upload_wheels() { + # TODO: list packages that would be uploaded, pause, prompt to continue + echo "" + echo "Uploading wheels:" + ls + twine upload --verbose * +} + +function build_shark_ai_meta_package() { + # TODO: download meta package from nightly releases instead of this + # Be aware that nightly releases pin other dependencies via the + # generated `requirements.txt` compared to stable releases. + echo "" + + # TODO: rework `write_requirements.py` to use the versions from the downloaded whls? + echo "Computing local versions for sharktank and shortfin..." + ${SCRIPT_DIR}/compute_local_version.py ${REPO_ROOT}/sharktank + ${SCRIPT_DIR}/compute_local_version.py ${REPO_ROOT}/shortfin + + echo "Computing common version for shark-ai meta package..." + ${SCRIPT_DIR}/compute_common_version.py --stable-release --write-json + + echo "Writing requirements for shark-ai meta package..." + ${SCRIPT_DIR}/write_requirements.py + + echo "Building shark-ai meta package..." + ${REPO_ROOT}/shark-ai/build_tools/build_linux_package.sh + + # TODO: This is error-prone. We only want to publish the whl for this release. + # Copy instead? Specify exact file name? Clear directory before building? + mv ${REPO_ROOT}/shark-ai/build_tools/wheelhouse/* . +} + +function main() { + echo "Changing into ${TMPDIR}" + cd "${TMPDIR}" + # TODO: check_requirements (using pip) + + download_wheels + edit_release_versions + build_shark_ai_meta_package + upload_wheels +} + +main diff --git a/build_tools/requirements-pypi-deploy.txt b/build_tools/python_deploy/requirements-pypi-deploy.txt similarity index 100% rename from build_tools/requirements-pypi-deploy.txt rename to build_tools/python_deploy/requirements-pypi-deploy.txt diff --git a/build_tools/python_deploy/write_requirements.py b/build_tools/python_deploy/write_requirements.py old mode 100644 new mode 100755 diff --git a/shark-ai/build_tools/build_linux_package.sh b/shark-ai/build_tools/build_linux_package.sh new file mode 100755 index 000000000..d16f339b1 --- /dev/null +++ b/shark-ai/build_tools/build_linux_package.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# build_linux_package.sh +# +# Builds shark-ai Python package for Linux. +# +# Usage: +# ./build_tools/build_linux_package.sh + +set -xeu -o errtrace + +THIS_DIR="$(cd $(dirname $0) && pwd)" +REPO_ROOT="$(cd "$THIS_DIR"/../../ && pwd)" +OUTPUT_DIR="${OUTPUT_DIR:-${THIS_DIR}/wheelhouse}" + +python -m pip wheel --disable-pip-version-check --no-deps -v -w "${OUTPUT_DIR}" "${REPO_ROOT}/shark-ai" + +wheel_output="$(echo "${OUTPUT_DIR}/shark_ai-"*".whl")" +ls "${wheel_output}" From 8664abebcbde7542efaa85e339c0752bfad1f54c Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 15 Nov 2024 18:21:34 +0100 Subject: [PATCH 53/59] Bump IREE version to 3.0.0rc20241115 (#537) Bumps the IREE version (which is among other things used to build shortfin) to the latest release candidate / nightly version. --- .github/workflows/ci-llama-large-tests.yaml | 4 ++-- .github/workflows/ci-llama-quick-tests.yaml | 4 ++-- .github/workflows/ci-sdxl.yaml | 2 +- .github/workflows/ci-sglang-benchmark.yml | 4 ++-- .github/workflows/ci_linux_x64-libshortfin.yml | 2 +- .github/workflows/ci_linux_x64_asan-libshortfin.yml | 2 +- .github/workflows/ci_linux_x64_nogil-libshortfin.yml | 2 +- .github/workflows/ci_windows_x64-libshortfin.yml | 2 +- shortfin/CMakeLists.txt | 2 +- shortfin/requirements-iree-compiler.txt | 4 ++-- 10 files changed, 14 insertions(+), 14 deletions(-) diff --git a/.github/workflows/ci-llama-large-tests.yaml b/.github/workflows/ci-llama-large-tests.yaml index 5645efd8a..41ad5af6b 100644 --- a/.github/workflows/ci-llama-large-tests.yaml +++ b/.github/workflows/ci-llama-large-tests.yaml @@ -70,8 +70,8 @@ jobs: # Test with pinned nightly releases, not what iree-turbine uses. pip install -f https://iree.dev/pip-release-links.html --upgrade \ - iree-base-compiler==2.9.0rc20241108 \ - iree-base-runtime==2.9.0rc20241108 + iree-base-compiler==3.0.0rc20241115 \ + iree-base-runtime==3.0.0rc20241115 - name: Run llama tests run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-all-llama --iree-hip-target=gfx942 --html=out/index.html diff --git a/.github/workflows/ci-llama-quick-tests.yaml b/.github/workflows/ci-llama-quick-tests.yaml index 585a759ac..63637e9b9 100644 --- a/.github/workflows/ci-llama-quick-tests.yaml +++ b/.github/workflows/ci-llama-quick-tests.yaml @@ -71,8 +71,8 @@ jobs: # Test with pinned nightly releases, not what iree-turbine uses. pip install -f https://iree.dev/pip-release-links.html --upgrade \ - iree-base-compiler==2.9.0rc20241108 \ - iree-base-runtime==2.9.0rc20241108 + iree-base-compiler==3.0.0rc20241115 \ + iree-base-runtime==3.0.0rc20241115 - name: Run llama 8b tests run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --iree-hip-target=gfx942 --run-8b-llama diff --git a/.github/workflows/ci-sdxl.yaml b/.github/workflows/ci-sdxl.yaml index 9c5776c4c..31218d25f 100644 --- a/.github/workflows/ci-sdxl.yaml +++ b/.github/workflows/ci-sdxl.yaml @@ -64,7 +64,7 @@ jobs: repository: iree-org/iree path: ${{ env.IREE_REPO_DIR }} submodules: false - ref: iree-2.9.0rc20241108 + ref: iree-3.0.0rc20241115 - name: Initalize IREE submodules working-directory: ${{ env.IREE_REPO_DIR }} diff --git a/.github/workflows/ci-sglang-benchmark.yml b/.github/workflows/ci-sglang-benchmark.yml index d890d972c..6a5fa4112 100644 --- a/.github/workflows/ci-sglang-benchmark.yml +++ b/.github/workflows/ci-sglang-benchmark.yml @@ -69,8 +69,8 @@ jobs: # We could also pin to a known working or stable version. # This should eventually stabilize. Do the best we can for now. pip install -f https://iree.dev/pip-release-links.html --upgrade \ - iree-base-compiler==2.9.0rc20241108 \ - iree-base-runtime==2.9.0rc20241108 \ + iree-base-compiler==3.0.0rc20241115 \ + iree-base-runtime==3.0.0rc20241115 \ "numpy<2.0" - name: Install SGLang diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml index c1b039da3..45ddfe90d 100644 --- a/.github/workflows/ci_linux_x64-libshortfin.yml +++ b/.github/workflows/ci_linux_x64-libshortfin.yml @@ -59,7 +59,7 @@ jobs: repository: iree-org/iree path: ${{ env.IREE_REPO_DIR }} submodules: false - ref: iree-2.9.0rc20241108 + ref: iree-3.0.0rc20241115 - name: Initalize IREE submodules working-directory: ${{ env.IREE_REPO_DIR }} diff --git a/.github/workflows/ci_linux_x64_asan-libshortfin.yml b/.github/workflows/ci_linux_x64_asan-libshortfin.yml index b61536218..5692a8336 100644 --- a/.github/workflows/ci_linux_x64_asan-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_asan-libshortfin.yml @@ -109,7 +109,7 @@ jobs: repository: iree-org/iree path: ${{ env.IREE_SOURCE_DIR }} submodules: false - ref: iree-2.9.0rc20241108 + ref: iree-3.0.0rc20241115 - name: Initalize IREE submodules working-directory: ${{ env.IREE_SOURCE_DIR }} diff --git a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml index c80b40c03..c382edbf4 100644 --- a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml +++ b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml @@ -57,7 +57,7 @@ jobs: repository: iree-org/iree path: ${{ env.IREE_REPO_DIR }} submodules: false - ref: iree-2.9.0rc20241108 + ref: iree-3.0.0rc20241115 - name: Initalize IREE submodules working-directory: ${{ env.IREE_REPO_DIR }} diff --git a/.github/workflows/ci_windows_x64-libshortfin.yml b/.github/workflows/ci_windows_x64-libshortfin.yml index 929244af4..544b45c76 100644 --- a/.github/workflows/ci_windows_x64-libshortfin.yml +++ b/.github/workflows/ci_windows_x64-libshortfin.yml @@ -54,7 +54,7 @@ jobs: repository: iree-org/iree path: ${{ env.IREE_REPO_DIR }} submodules: false - ref: iree-2.9.0rc20241108 + ref: iree-3.0.0rc20241115 - name: Initalize IREE submodules working-directory: ${{ env.IREE_REPO_DIR }} diff --git a/shortfin/CMakeLists.txt b/shortfin/CMakeLists.txt index 11982202d..93ee63594 100644 --- a/shortfin/CMakeLists.txt +++ b/shortfin/CMakeLists.txt @@ -40,7 +40,7 @@ if(NOT WIN32) endif() # Pins -set(SHORTFIN_IREE_GIT_TAG "iree-2.9.0rc20241108") +set(SHORTFIN_IREE_GIT_TAG "iree-3.0.0rc20241115") # build options option(SHORTFIN_BUILD_PYTHON_BINDINGS "Builds Python Bindings" OFF) diff --git a/shortfin/requirements-iree-compiler.txt b/shortfin/requirements-iree-compiler.txt index 7aea80277..ec033c57c 100644 --- a/shortfin/requirements-iree-compiler.txt +++ b/shortfin/requirements-iree-compiler.txt @@ -1,4 +1,4 @@ # Keep in sync with "ref: iree-" in .github/workflows/* and GIT_TAG in CMakeLists.txt -f https://iree.dev/pip-release-links.html -iree-base-compiler==2.9.0rc20241108 -iree-base-runtime==2.9.0rc20241108 +iree-base-compiler==3.0.0rc20241115 +iree-base-runtime==3.0.0rc20241115 From d83ab9d83c9c7352edc5e8bd34da691cfe230502 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 15 Nov 2024 10:24:45 -0800 Subject: [PATCH 54/59] Adapt to repository rename from SHARK-Platform to shark-ai. (#545) --- .github/workflows/build_packages.yml | 2 +- ...{ci-shark-platform.yml => ci-shark-ai.yml} | 2 +- README.md | 14 +++--- build_tools/python_deploy/pypi_deploy.sh | 4 +- docs/model_cookbook.md | 2 +- docs/nightly_releases.md | 16 +++--- docs/quantization.md | 50 +++++++++---------- .../llm/developer/e2e_llama8b_mi300x.md | 2 +- docs/shortfin/llm/user/e2e_llama8b_mi300x.md | 4 +- docs/user_guide.md | 2 +- shark-ai/pyproject.toml | 2 +- sharktank/README.md | 2 +- sharktank/pyproject.toml | 2 +- sharktank/sharktank/ops/custom_impls.py | 2 +- sharktank/sharktank/utils/export_artifacts.py | 2 +- sharktank/tests/ops/ops_test.py | 6 +-- shortfin/README.md | 2 +- shortfin/pyproject.toml | 2 +- tuner/README.md | 2 +- tuner/pyproject.toml | 2 +- tuner/tuner/candidate_gen.py | 2 +- 21 files changed, 62 insertions(+), 62 deletions(-) rename .github/workflows/{ci-shark-platform.yml => ci-shark-ai.yml} (99%) diff --git a/.github/workflows/build_packages.yml b/.github/workflows/build_packages.yml index 8f138d973..1200234c4 100644 --- a/.github/workflows/build_packages.yml +++ b/.github/workflows/build_packages.yml @@ -129,7 +129,7 @@ jobs: token: "${{ secrets.RELEASE_PUBLISH_ACCESS_TOKEN }}" tag: "dev-wheels" name: "dev-wheels" - body: "Automatic snapshot release of SHARK-Platform python wheels." + body: "Automatic snapshot release of shark-ai python wheels." removeArtifacts: false allowUpdates: true replacesArtifacts: true diff --git a/.github/workflows/ci-shark-platform.yml b/.github/workflows/ci-shark-ai.yml similarity index 99% rename from .github/workflows/ci-shark-platform.yml rename to .github/workflows/ci-shark-ai.yml index dc2f4646a..28e2bc883 100644 --- a/.github/workflows/ci-shark-platform.yml +++ b/.github/workflows/ci-shark-ai.yml @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -name: CI - shark-platform +name: CI - shark-ai on: workflow_dispatch: diff --git a/README.md b/README.md index 517980838..77f4a0d75 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ -# SHARK Modeling and Serving Libraries +# shark-ai: SHARK Modeling and Serving Libraries > [!IMPORTANT] > Development is still in progress for several project components. See the > notes below for which workflows are best supported. -![GitHub License](https://img.shields.io/github/license/nod-ai/SHARK-Platform) +![GitHub License](https://img.shields.io/github/license/nod-ai/shark-ai) [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit)](https://github.com/pre-commit/pre-commit) @@ -15,7 +15,7 @@ -[![PyPI version](https://badge.fury.io/py/shortfin.svg)](https://badge.fury.io/py/shortfin) [![CI - shortfin](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci_linux_x64-libshortfin.yml/badge.svg?event=push)](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci_linux_x64-libshortfin.yml?query=event%3Apush) +[![PyPI version](https://badge.fury.io/py/shortfin.svg)](https://badge.fury.io/py/shortfin) [![CI - shortfin](https://github.com/nod-ai/shark-ai/actions/workflows/ci_linux_x64-libshortfin.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci_linux_x64-libshortfin.yml?query=event%3Apush) The shortfin sub-project is SHARK's high performance inference library and serving engine. @@ -25,7 +25,7 @@ serving engine. ### [`sharktank/`](./sharktank/) -[![PyPI version](https://badge.fury.io/py/sharktank.svg)](https://badge.fury.io/py/sharktank) [![CI - sharktank](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci-sharktank.yml/badge.svg?event=push)](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci-sharktank.yml?query=event%3Apush) +[![PyPI version](https://badge.fury.io/py/sharktank.svg)](https://badge.fury.io/py/sharktank) [![CI - sharktank](https://github.com/nod-ai/shark-ai/actions/workflows/ci-sharktank.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci-sharktank.yml?query=event%3Apush) The SHARK Tank sub-project contains a collection of model recipes and conversion tools to produce inference-optimized programs. @@ -45,7 +45,7 @@ conversion tools to produce inference-optimized programs. ### [`tuner/`](./tuner/) -[![CI - Tuner](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci-tuner.yml/badge.svg?event=push)](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci-tuner.yml?query=event%3Apush) +[![CI - Tuner](https://github.com/nod-ai/shark-ai/actions/workflows/ci-tuner.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci-tuner.yml?query=event%3Apush) The Tuner sub-project assists with tuning program performance by searching for optimal parameter configurations to use during model compilation. @@ -63,8 +63,8 @@ optimal parameter configurations to use during model compilation. Model name | Model recipes | Serving apps ---------- | ------------- | ------------ -SDXL | [`sharktank/sharktank/models/punet/`](https://github.com/nod-ai/SHARK-Platform/tree/main/sharktank/sharktank/models/punet) | [`shortfin/python/shortfin_apps/sd/`](https://github.com/nod-ai/SHARK-Platform/tree/main/shortfin/python/shortfin_apps/sd) -llama | [`sharktank/sharktank/models/llama/`](https://github.com/nod-ai/SHARK-Platform/tree/main/sharktank/sharktank/models/llama) | [`shortfin/python/shortfin_apps/llm/`](https://github.com/nod-ai/SHARK-Platform/tree/main/shortfin/python/shortfin_apps/llm) +SDXL | [`sharktank/sharktank/models/punet/`](https://github.com/nod-ai/shark-ai/tree/main/sharktank/sharktank/models/punet) | [`shortfin/python/shortfin_apps/sd/`](https://github.com/nod-ai/shark-ai/tree/main/shortfin/python/shortfin_apps/sd) +llama | [`sharktank/sharktank/models/llama/`](https://github.com/nod-ai/shark-ai/tree/main/sharktank/sharktank/models/llama) | [`shortfin/python/shortfin_apps/llm/`](https://github.com/nod-ai/shark-ai/tree/main/shortfin/python/shortfin_apps/llm) ## SHARK Users diff --git a/build_tools/python_deploy/pypi_deploy.sh b/build_tools/python_deploy/pypi_deploy.sh index c141aea4f..63f123ac0 100755 --- a/build_tools/python_deploy/pypi_deploy.sh +++ b/build_tools/python_deploy/pypi_deploy.sh @@ -21,7 +21,7 @@ # python3.13t -m ensurepip --upgrade # ``` # * Choose a release candidate to promote from -# https://github.com/nod-ai/SHARK-Platform/releases/tag/dev-wheels +# https://github.com/nod-ai/shark-ai/releases/tag/dev-wheels # # Usage: # ./pypi_deploy.sh 2.9.0rc20241108 @@ -33,7 +33,7 @@ RELEASE="$1" SCRIPT_DIR="$(dirname -- "$( readlink -f -- "$0"; )")"; REPO_ROOT="$(cd "$SCRIPT_DIR"/../../ && pwd)" TMPDIR="$(mktemp --directory --tmpdir shark_platform_pypi_wheels.XXXXX)" -ASSETS_PAGE="https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels" +ASSETS_PAGE="https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels" # TODO: rewrite in Python? diff --git a/docs/model_cookbook.md b/docs/model_cookbook.md index fdf4c7ede..ddc0cb3bb 100644 --- a/docs/model_cookbook.md +++ b/docs/model_cookbook.md @@ -1,6 +1,6 @@ # Model cookbook -Note: These are early notes and commands that the SHARK-Platform team is using +Note: These are early notes and commands that the shark-ai team is using and will turn into proper docs later. ## Diagrams diff --git a/docs/nightly_releases.md b/docs/nightly_releases.md index 819e22f61..545cdd4f5 100644 --- a/docs/nightly_releases.md +++ b/docs/nightly_releases.md @@ -2,19 +2,19 @@ > [!WARNING] > This is still under development! See -> https://github.com/nod-ai/SHARK-Platform/issues/400. +> https://github.com/nod-ai/shark-ai/issues/400. > > These instructions will be converted into a user guide once stable packages -> are published to PyPI: . +> are published to PyPI: . Nightly releases are uploaded to -https://github.com/nod-ai/SHARK-Platform/releases/tag/dev-wheels. +https://github.com/nod-ai/shark-ai/releases/tag/dev-wheels. * The "expanded_assets" version of a release page is compatible with the `-f, --find-links ` options of `pip install` ([docs here](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-f)). For the "dev-wheels" release above, that page is: - + * These releases are generated using [`.github/workflows/build_package.yml`](../.github/workflows/build_packages.yml) * That workflow runs the @@ -23,7 +23,7 @@ https://github.com/nod-ai/SHARK-Platform/releases/tag/dev-wheels. [`shortfin/build_tools/build_linux_package.sh`](../shortfin/build_tools/build_linux_package.sh) scripts * Workflow history can be viewed at - + ## Prerequisites @@ -38,7 +38,7 @@ source builds. You will need a recent version of Python. * As of Nov 1, 2024, sharktank is compatible with Python 3.11. See - https://github.com/nod-ai/SHARK-Platform/issues/349 for Python 3.12 support. + https://github.com/nod-ai/shark-ai/issues/349 for Python 3.12 support. * As of Nov 4, 2024, shortfin publishes packages for Python 3.11, 3.12, 3.13, and 3.13t @@ -67,7 +67,7 @@ python3.11 -m venv 3.11.venv source 3.11.venv/bin/activate # Install 'sharktank' package from nightly releases. -pip install sharktank -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels +pip install sharktank -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels # Test the installation. python -c "from sharktank import ops; print('Sanity check passed')" @@ -84,7 +84,7 @@ python3.11 -m venv 3.11.venv source 3.11.venv/bin/activate # Install 'shortfin' package from nightly releases. -pip install shortfin -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels +pip install shortfin -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels # Test the installation. python -c "import shortfin as sf; print('Sanity check passed')" diff --git a/docs/quantization.md b/docs/quantization.md index fcc8961b0..25bfc9f8d 100644 --- a/docs/quantization.md +++ b/docs/quantization.md @@ -64,11 +64,11 @@ amount of Python code implementing direct math and packing schemes. PyTorch modules like `Linear` and `Conv2D`. 2. Types/Ops: The `nn.Module` implementations we provide are built in terms of SHARK Tank custom - [`InferenceTensor`](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/types/tensors.py#L153) - and [polymorphic functional ops library](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/ops/signatures.py). + [`InferenceTensor`](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/types/tensors.py#L153) + and [polymorphic functional ops library](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/signatures.py). 3. Op specializations for optimized subsets of op type signatures and features (for example, [an optimized affine quantized linear specialization for - supported combinations of `TensorScaledLayout` arguments](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/ops/qlinear_impls.py)). + supported combinations of `TensorScaledLayout` arguments](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/qlinear_impls.py)). (TODO: good place for a diagram) @@ -78,18 +78,18 @@ amount of Python code implementing direct math and packing schemes. Available modules that support direct quantization (TODO: refactor to use torch "Module" terminology and naming schemes consistently): -* [`LinearLayer`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/layers/linear.py) -* [convolution layers](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/layers/conv.py) +* [`LinearLayer`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/layers/linear.py) +* [convolution layers](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/layers/conv.py) Note that most sharktank modules extend -[`ThetaLayer`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/layers/base.py#L63), +[`ThetaLayer`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/layers/base.py#L63), which calls for a bit of explanation. Traditional PyTorch Modules directly instantiate their backing parameters in their constructor. For dataset-heavy and polymorphic implementations like we commonly see in quantization and distribution, however, it can be beneficial to separate these concerns. The `ThetaLayer` simply takes a -[`Theta` object](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/theta.py#L74), +[`Theta` object](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/theta.py#L74), which is a tree-structured bag of native `torch.Tensor` or `InferenceTensor` instances, and it adopts the tensors in the bag as its own vs creating them. For those familiar with the concept, this is a form of dependency-injection @@ -114,7 +114,7 @@ tree to a specific Module instance. We've already met the `Theta` object above, which holds a tree of something called an -[`InferenceTensor`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/tensors.py#L153). +[`InferenceTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L153). Now we describe what this is. Note that presently, `InferenceTensor` is not a `torch.Tensor` but its own `ABC` type that: @@ -140,11 +140,11 @@ pipelines. There is a growing list of `InferenceTensor` sub-types, many of which are related to quantization: -* [`PrimitiveTensor`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/tensors.py#L286): +* [`PrimitiveTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L286): A simple composition of a single `torch.Tensor`. This is often used interchangeably with a `torch.Tensor` but is present for completeness of the type hierarchy and to be able to type select on. -* [`QuantizedTensor`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/tensors.py#L372): +* [`QuantizedTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L372): Abstract base class of all quantized tensors, providing two primary operations: * `unpack`: Accesses the backing `QuantizedLayout` of the tensor, which is @@ -154,12 +154,12 @@ related to quantization: layout, this explodes it into a canonical representation of individual tensors which can be algebraically implemented individually/generically). -* [`PlanarQuantizedTensor`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408): +* [`PlanarQuantizedTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408): Concrete implementation for all non-packed quantized tensors that can be losslessly represented by a layout based on individual tensor components. All `QuantizedTensor` instances can be converted to a `PlanarQuantizedTensor`. -* [`QuantizerTensor`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408): +* [`QuantizerTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408): (note the "r" in the name) An abstract `InferenceTensor` that exposes a `quantize(torch.Tensor | InferenceTensor) -> QuantizedTensor` operation used to transform an arbitrary tensor to a quantized form. There are a handful @@ -178,7 +178,7 @@ manipulate tensor contents via `QuantizedLayout`, but we haven't yet defined that. The *Tensor types are structural and exist to give identity, but the `QuantizedLayout` is where the "magic happens". -[`QuantizedLayout`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/tensors.py#L44) +[`QuantizedLayout`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L44) is an `ABC`, supporting: * Serialization/interop with parameter archives. @@ -193,7 +193,7 @@ is an `ABC`, supporting: There are a number of implementations, as every quantization scheme typically needs at least one concrete `QuantizedLayout`. Simple schemes like affine quantization can be fully defined in terms of a single -[`TensorScaledLayout`](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/types/layouts.py#L43). +[`TensorScaledLayout`](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/types/layouts.py#L43). Whereas packed schemes like we find in inference engines like GGML and XNNPACK optimally require both a packed layout and a planar layout. @@ -224,7 +224,7 @@ interpreting/transforming using their natively defined forms. Previously, we found a rich type system defining all manner of layouts and quantization schemes, but what can be done with it? That is where the sharktank functional op library comes in. These -[logical ops](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/ops/signatures.py) +[logical ops](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/signatures.py) provide the building blocks to implement built-in and custom `nn.Module` implementations operating on `InferenceTensor` (and torch.Tensor) types. @@ -239,12 +239,12 @@ implementation at any needed level of granularity: structures and preserve it when computing (when combined with a fusing compiler, this alone provides decent fallback implementations for a variety of "weight compression" oriented techniques). See - [some examples](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/ops/custom_impls.py#L51). + [some examples](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/custom_impls.py#L51). * Pure-Torch decompositions for algebraic techniques like affine quantization (when combined with a fusing compiler, this alone is sufficient for optimization). See - [qlinear](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/ops/qlinear_impls.py) and - [qconv](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/ops/qconv_impls.py) + [qlinear](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/qlinear_impls.py) and + [qconv](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/qconv_impls.py) implementations of actual affine quantized decompositions. * Completely custom packed/optimized implementation. These can be written to activate on any level of detail of the type hierarchy. The implementation @@ -280,8 +280,8 @@ level. Some examples: [tensor trace/print](https://github.com/iree-org/iree-turbine/blob/main/iree.turbine/ops/iree.py#L52) * [Simple linalg based template expansion](https://github.com/iree-org/iree-turbine/blob/main/iree.turbine/ops/_jinja_test_ops.py#L28) (see backing example [jinja template](https://github.com/iree-org/iree-turbine/blob/main/iree.turbine/ops/templates/test_add_jinja.mlir)). -* Optimal linalg-based [8-bit block scaled mmt for weight compression](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/kernels/mmt_block_scaled_q8.py) - (see backing [jinja template](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/kernels/templates/mmt_block_scaled_q8_3d.mlir)). +* Optimal linalg-based [8-bit block scaled mmt for weight compression](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/mmt_block_scaled_q8.py) + (see backing [jinja template](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/templates/mmt_block_scaled_q8_3d.mlir)). * DSL based [like this fused attention kernel](https://github.com/iree-org/iree-turbine/blob/main/tests/kernel/fused_attention_test.py#L20) (note that in this case, the DSL exports to the unerlying IR-based registration mechanism used in the previous examples). @@ -292,8 +292,8 @@ Since all of these types of custom kernels are just defined with simple Python tooling, they are really fast to iterate on. The linalg based kernels specifically tend to be highly portable, and we don't hesitate to write one of those when we need something specific that PyTorch doesn't provide out of the box -(i.e. [proper mixed-precision integer conv](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py) -([template](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/kernels/templates/conv_2d_nchw_fchw.mlir))). +(i.e. [proper mixed-precision integer conv](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py) +([template](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/templates/conv_2d_nchw_fchw.mlir))). ## Dataset transformation @@ -307,7 +307,7 @@ We take a practical approach to this, writing implementation specific converters where needed, and taking advantage of industry-standard consolidation points where available (like GGUF) in order to cover a wider surface area. -Behind both is the notion of a [`Dataset`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/theta.py#L263), +Behind both is the notion of a [`Dataset`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/theta.py#L263), which combines some set of hyper-parameters with a root `Theta` object (typically representing the layer-tree of frozen tensors). Datasets can be losslessly persisted to IREE IRPA files, which can then be loaded by either @@ -321,9 +321,9 @@ transform, shard, etc. See some examples: -* [models/punet/tools/import_hf_dataset.py](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_hf_dataset.py) : +* [models/punet/tools/import_hf_dataset.py](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_hf_dataset.py) : Creating a `Dataset` object from an HF diffusers safetensors file and config.json. -* [models/punet/tools/import_brevitas_dataset.py](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py) : +* [models/punet/tools/import_brevitas_dataset.py](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py) : Creates a quantized `Dataset` by combining: * HF diffusers `config.json` diff --git a/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md b/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md index e3150ed5c..1ce2d1e8d 100644 --- a/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md +++ b/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md @@ -16,7 +16,7 @@ process of exporting a model for use in the shortfin llm server with an MI300 GP ### Setting Up Environment Follow the `Development Getting Started` docs -[here](https://github.com/nod-ai/SHARK-Platform/blob/main/README.md#development-getting-started) +[here](https://github.com/nod-ai/shark-ai/blob/main/README.md#development-getting-started) to setup your environment for development. We will use an example with `llama_8b_f16_decomposed` in order to describe the diff --git a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md index 985e55c13..5e0749546 100644 --- a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md +++ b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md @@ -36,8 +36,8 @@ pip install shark-ai #### Nightly ```bash -pip install sharktank -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels -pip install shortfin -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels +pip install sharktank -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels +pip install shortfin -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels ``` #### Install dataclasses-json diff --git a/docs/user_guide.md b/docs/user_guide.md index b7a530583..c3da1f4f5 100644 --- a/docs/user_guide.md +++ b/docs/user_guide.md @@ -50,7 +50,7 @@ pip install shark-ai[apps] Temporarily, you may need an update to your `shortfin` install. Install the latest pre-release with: ``` -pip install shortfin --upgrade --pre -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels +pip install shortfin --upgrade --pre -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels ``` ### Test the installation. diff --git a/shark-ai/pyproject.toml b/shark-ai/pyproject.toml index f78a1641f..133026d13 100644 --- a/shark-ai/pyproject.toml +++ b/shark-ai/pyproject.toml @@ -24,7 +24,7 @@ requires-python = ">= 3.10" dynamic = ["version", "dependencies"] [project.urls] -Repository = "https://github.com/nod-ai/SHARK-Platform" +Repository = "https://github.com/nod-ai/shark-ai" [project.optional-dependencies] onnx = [ diff --git a/sharktank/README.md b/sharktank/README.md index c36cdd055..7770595ed 100644 --- a/sharktank/README.md +++ b/sharktank/README.md @@ -12,7 +12,7 @@ tooling. ## Project Status -[![CI - Perplexity](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci_eval.yaml/badge.svg?branch=main&event=schedule)](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci_eval.yaml) +[![CI - Perplexity](https://github.com/nod-ai/shark-ai/actions/workflows/ci_eval.yaml/badge.svg?branch=main&event=schedule)](https://github.com/nod-ai/shark-ai/actions/workflows/ci_eval.yaml) ## Examples diff --git a/sharktank/pyproject.toml b/sharktank/pyproject.toml index 65f264d16..e5f9972a3 100644 --- a/sharktank/pyproject.toml +++ b/sharktank/pyproject.toml @@ -24,7 +24,7 @@ requires-python = ">= 3.11" dynamic = ["version", "dependencies", "optional-dependencies"] [project.urls] -Repository = "https://github.com/nod-ai/SHARK-Platform" +Repository = "https://github.com/nod-ai/shark-ai" [tool.setuptools.packages.find] where = ["."] diff --git a/sharktank/sharktank/ops/custom_impls.py b/sharktank/sharktank/ops/custom_impls.py index 8f6654a8e..9acc7c562 100644 --- a/sharktank/sharktank/ops/custom_impls.py +++ b/sharktank/sharktank/ops/custom_impls.py @@ -33,7 +33,7 @@ # Fused FP matmul. -# Disabled: See https://github.com/nod-ai/SHARK-Platform/issues/44 +# Disabled: See https://github.com/nod-ai/shark-ai/issues/44 # @matmul.override(Tensor, Tensor) # def matmul_mmtfp_tensor_tensor(lhs, rhs, *, transpose_rhs: bool): # lhs = unbox_tensor(lhs) diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py index 057d3b664..9deade56c 100644 --- a/sharktank/sharktank/utils/export_artifacts.py +++ b/sharktank/sharktank/utils/export_artifacts.py @@ -25,7 +25,7 @@ class ExportMlirException(Exception): - """SHARK-Platform export MLIR exception that preserves the command line and error output.""" + """shark-ai export MLIR exception that preserves the command line and error output.""" def __init__(self, process: subprocess.CompletedProcess, cwd: str): try: diff --git a/sharktank/tests/ops/ops_test.py b/sharktank/tests/ops/ops_test.py index 8b37525e5..ad6759ce6 100644 --- a/sharktank/tests/ops/ops_test.py +++ b/sharktank/tests/ops/ops_test.py @@ -136,7 +136,7 @@ def testMatchFail(self): ): ops.matmul(1, 2) - @unittest.skip("https://github.com/nod-ai/SHARK-Platform/issues/44") + @unittest.skip("https://github.com/nod-ai/shark-ai/issues/44") def testTorchImplTransposedRHS(self): ops._registry._test_enable_last_op_dispatch(True) t1 = torch.rand(32, 16, dtype=torch.float32) @@ -149,7 +149,7 @@ def testTorchImplTransposedRHS(self): ops.custom_impls.matmul_mmtfp_tensor_tensor, ) - @unittest.skip("https://github.com/nod-ai/SHARK-Platform/issues/44") + @unittest.skip("https://github.com/nod-ai/shark-ai/issues/44") def testTorchImplNonTransposedRHS(self): ops._registry._test_enable_last_op_dispatch(True) t1 = torch.rand(32, 16, dtype=torch.float32) @@ -162,7 +162,7 @@ def testTorchImplNonTransposedRHS(self): ops.custom_impls.matmul_mmtfp_tensor_tensor, ) - @unittest.skip("https://github.com/nod-ai/SHARK-Platform/issues/44") + @unittest.skip("https://github.com/nod-ai/shark-ai/issues/44") def testTorchImplTransposedPrimitiveRHS(self): ops._registry._test_enable_last_op_dispatch(True) t1 = torch.rand(32, 16, dtype=torch.float32) diff --git a/shortfin/README.md b/shortfin/README.md index 13ee20966..6269ca702 100644 --- a/shortfin/README.md +++ b/shortfin/README.md @@ -7,7 +7,7 @@ and serving engine. Shortfin consists of these major components: [IREE](https://github.com/iree-org/iree) * Python bindings for the underlying inference library * Example applications in - ['shortfin_apps'](https://github.com/nod-ai/SHARK-Platform/tree/main/shortfin/python/shortfin_apps) + ['shortfin_apps'](https://github.com/nod-ai/shark-ai/tree/main/shortfin/python/shortfin_apps) built using the python bindings ## Prerequisites diff --git a/shortfin/pyproject.toml b/shortfin/pyproject.toml index 7c4ed8a33..1abb49ef6 100644 --- a/shortfin/pyproject.toml +++ b/shortfin/pyproject.toml @@ -31,7 +31,7 @@ requires-python = ">= 3.10" dynamic = ["version"] [project.urls] -Repository = "https://github.com/nod-ai/SHARK-Platform" +Repository = "https://github.com/nod-ai/shark-ai" Documentation = "https://shortfin.readthedocs.io/en/latest/" [project.optional-dependencies] diff --git a/tuner/README.md b/tuner/README.md index e6a515729..3737f6bdf 100644 --- a/tuner/README.md +++ b/tuner/README.md @@ -33,5 +33,5 @@ documentation](https://iree.dev/building-from-source/getting-started/#python-bin ## Examples Check the `examples` directory for sample tuners implemented with `libtuner`. -The [`dispatch` example](https://github.com/nod-ai/SHARK-Platform/tree/main/tuner/examples/dispatch) +The [`dispatch` example](https://github.com/nod-ai/shark-ai/tree/main/tuner/examples/dispatch) should be a good starting point for most users. diff --git a/tuner/pyproject.toml b/tuner/pyproject.toml index 1661a7744..c36326bf7 100644 --- a/tuner/pyproject.toml +++ b/tuner/pyproject.toml @@ -21,4 +21,4 @@ requires-python = ">= 3.10" dynamic = ["version"] [project.urls] -Repository = "https://github.com/nod-ai/SHARK-Platform" +Repository = "https://github.com/nod-ai/shark-ai" diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 2f21520f0..b50df12d5 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -74,7 +74,7 @@ def apply_configuration( class DispatchTuner(DispatchParser): - # TODO(https://github.com/nod-ai/SHARK-Platform/issues/453): Remove this in favor of configuring using transform dialect. + # TODO(https://github.com/nod-ai/shark-ai/issues/453): Remove this in favor of configuring using transform dialect. @abstractmethod def apply_params( self, From e417abd25201c47360cdb2e9a2d29996a170c7c4 Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Fri, 15 Nov 2024 12:30:29 -0600 Subject: [PATCH 55/59] Fp8 llama3 support (#320) Makes changes to the importer to allow for newer quark format Adds flag `--no-fake-quant` to enable exporting to mlir in native fp8 Adds support for KV cache quant/dequant Adds support to QuantizerTensor to allow dequanting from a torch tensor using previous parameters. This may not be the best way to do this, but it seems better than pulling the quant parameters out into the wild. This kind of exposes a missing abstraction in our workflow as most things assume we are not using unpack().qs with the intention of doing something with that value and then later dequantizing it. (The kv cache being the relevant example) --- .../sharktank/examples/export_paged_llm_v1.py | 12 ++-- sharktank/sharktank/examples/paged_llm_v1.py | 7 ++- sharktank/sharktank/layers/causal_llm.py | 2 + .../sharktank/layers/configs/llm_configs.py | 5 +- sharktank/sharktank/layers/linear.py | 26 +++++--- .../layers/paged_llama_attention_block.py | 59 +++++++++++++++++-- sharktank/sharktank/models/llama/llama.py | 4 ++ .../llama/tools/import_quark_dataset.py | 48 +++++++-------- sharktank/sharktank/ops/default_impls.py | 1 - sharktank/sharktank/ops/qlinear_impls.py | 10 ++-- sharktank/sharktank/types/quantizers.py | 19 ++++++ sharktank/sharktank/utils/cli.py | 18 ++++++ sharktank/tests/layers/linear_test.py | 2 +- .../tests/models/llama/attention_test.py | 2 +- .../models/llama/benchmark_amdgpu_test.py | 2 +- 15 files changed, 160 insertions(+), 57 deletions(-) diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py index f22b2ccbd..a740f0bff 100644 --- a/sharktank/sharktank/examples/export_paged_llm_v1.py +++ b/sharktank/sharktank/examples/export_paged_llm_v1.py @@ -54,24 +54,19 @@ def main(): help="Enables strictness during export", action="store_true", ) - parser.add_argument( - "--attention-kernel", - type=str, - default="decomposed", - choices=["decomposed", "torch"], - ) - + cli.add_quantization_options(parser) + cli.add_model_options(parser) args = cli.parse(parser) dataset_type = cli.get_input_data_files(args) dataset_type = "irpa" if "irpa" in dataset_type else "gguf" dataset = cli.get_input_dataset(args) - hp = configs.LlamaHParams.from_gguf_props(dataset.properties) tensor_parallelism_size = ( dataset.properties["tensor_parallelism_size"] if "tensor_parallelism_size" in dataset.properties else 1 ) + llama_config = LlamaModelConfig( hp, tensor_parallelism_size=tensor_parallelism_size, @@ -80,6 +75,7 @@ def main(): kv_cache_type="direct" if args.bs == [1] else "paged", attention_kernel=args.attention_kernel, ) + llama_config.fake_quant = args.fake_quant if llama_config.hp.expert_count: if llama_config.hp.model_arch == "grok": diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py index 6d0bfd14c..b30acc026 100644 --- a/sharktank/sharktank/examples/paged_llm_v1.py +++ b/sharktank/sharktank/examples/paged_llm_v1.py @@ -258,16 +258,15 @@ def main(): ) cli.add_input_dataset_options(parser) cli.add_tokenizer_options(parser) + cli.add_quantization_options(parser) + cli.add_model_options(parser) args = cli.parse(parser) - device = torch.device(args.device) if args.device else None activation_dtype = getattr(torch, args.activation_dtype) assert isinstance(activation_dtype, torch.dtype) - dataset = cli.get_input_dataset(args) tokenizer = cli.get_tokenizer(args) prompts = args.prompt - config = LlamaModelConfig( hp=configs.LlamaHParams.from_gguf_props(dataset.properties), block_seq_stride=16, @@ -275,8 +274,10 @@ def main(): device=device, activation_dtype=activation_dtype, attention_dtype=activation_dtype, + attention_kernel=args.attention_kernel, use_hf=args.use_hf, tensor_parallelism_size=args.tensor_parallelism_size, + fake_quant=args.fake_quant, ) if config.tensor_parallelism_size > 1: dataset.root_theta = shard_theta(dataset.root_theta, config) diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py index 7a09995a8..8ace77981 100644 --- a/sharktank/sharktank/layers/causal_llm.py +++ b/sharktank/sharktank/layers/causal_llm.py @@ -33,12 +33,14 @@ def __init__( device: Optional[torch.device] = None, activation_dtype: torch.dtype = torch.float32, attention_dtype: torch.dtype = torch.float32, + fake_quant: bool = True, ): super().__init__(theta) self.device = device self.activation_dtype = activation_dtype self.attention_dtype = attention_dtype self.context_length = context_length + self.fake_quant = fake_quant if static_tables: self.register_buffer( diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py index 6dbe6fc52..c440ad441 100644 --- a/sharktank/sharktank/layers/configs/llm_configs.py +++ b/sharktank/sharktank/layers/configs/llm_configs.py @@ -44,7 +44,7 @@ class LlamaHParams: @staticmethod def from_gguf_props(p: dict[str, Any]): - name_prefix = p["general.architecture"] + name_prefix = p.get("general.architecture", "llama") default_expert_count = 0 default_expert_used_count = 0 default_rope_freq_base = 10000.0 @@ -156,6 +156,9 @@ class LlamaModelConfig: # Dtype to use for attention. attention_dtype: torch.dtype = torch.float16 + # fake quant determines the mode the Layer Thetas operate w.r.t quantized tensors. + fake_quant: bool = True + # How many devices are involved for tensor parallel sharding. # If greater than 1, the model will expect sharded model parameters and function # arguments. diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py index c5e2ea330..b679dccde 100644 --- a/sharktank/sharktank/layers/linear.py +++ b/sharktank/sharktank/layers/linear.py @@ -15,6 +15,7 @@ QuantizerTensor, StaticScaledQuantizer, TensorScaledLayout, + PlanarQuantizedTensor, ) __all__ = [ @@ -29,6 +30,10 @@ class LinearLayer(ThetaLayer): if premul_input is not None: x = x * premul_input matmul(x, weight.T) + bias + + fake_quant exists to allow export without adding dequant ops. + when fake_quant is True, the op will in quant dequant fashion. + When false, it will keep quantized types. ``` """ @@ -38,11 +43,13 @@ def __init__( *, weight_name: str = "weight", bias_name: str = "bias", + fake_quant: bool = True, ): super().__init__(theta) self._simulate_native_quant = True self.weight = self.theta_tensor(weight_name) self.bias = None + self.fake_quant = fake_quant if bias_name in self.theta.keys: self.bias = self.theta_tensor(bias_name) @@ -65,18 +72,23 @@ def forward(self, x): if q_input is not None: x = q_input.quantize(x) - elif qdq_input is not None: - # TODO: probably need a way to only do q_input if exporting. + if self.fake_quant: + x = x.unpack().dequant() + elif qdq_input is not None and self.fake_quant: x = qdq_input.quantize(x).unpack().dequant() y = ops.linear(x, weight, bias) # Unconditionally dequantize. - # TODO: Support a q_output specifier that signals the layer to let - # the QuantizedTensor escape. - if isinstance(y, QuantizedTensor): + if isinstance(y, QuantizedTensor) and not self.fake_quant: y = y.unpack().dequant() - if qdq_output is not None: - # TODO: same as above. + # Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32. + # We can truncate to fp16 in iree, so we do a cast here + # to account for this in the IR. This is may not be the right + # level to do this, but for now its here. + if not self.fake_quant and y.dtype == torch.float8_e4m3fnuz: + y = ops.to(y, torch.float16) + return y + if qdq_output is not None and self.fake_quant: y = qdq_output.quantize(y).unpack().dequant() return y diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py index 6b460d81b..22647bf49 100644 --- a/sharktank/sharktank/layers/paged_llama_attention_block.py +++ b/sharktank/sharktank/layers/paged_llama_attention_block.py @@ -10,7 +10,7 @@ import torch import torch.nn.functional as F - +from ..types import QuantizerTensor from .base import Theta, ThetaLayer from .linear import LinearLayer from .norm import RMSNormLayer @@ -40,6 +40,7 @@ def __init__( attention_kernel: str = "decomposed", attention_scale: Optional[float] = None, softcap: Optional[float] = None, + fake_quant: Optional[bool] = True, ): super().__init__(theta) @@ -51,14 +52,28 @@ def __init__( self.attention_kernel = attention_kernel self.attention_scale = attention_scale self.softcap = softcap + self.fake_quant = fake_quant self.add_module( "attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon) ) - self.add_module("attn_q", LinearLayer(theta("attn_q"))) - self.add_module("attn_k", LinearLayer(theta("attn_k"))) - self.add_module("attn_v", LinearLayer(theta("attn_v"))) - self.add_module("attn_output", LinearLayer(theta("attn_output"))) + self.add_module( + "attn_q", LinearLayer(theta("attn_q"), fake_quant=self.fake_quant) + ) + self.add_module( + "attn_k", LinearLayer(theta("attn_k"), fake_quant=self.fake_quant) + ) + self.add_module( + "attn_v", LinearLayer(theta("attn_v"), fake_quant=self.fake_quant) + ) + self.add_module( + "attn_output", LinearLayer(theta("attn_output"), fake_quant=self.fake_quant) + ) + self.cache_quantizer = None + if "kv_cache" in theta.keys: + self.cache_quantizer: Optional[QuantizerTensor] = theta.optional_tensor( + "kv_cache.quantizer" + ) if theta.optional_tensor("attn_output_norm") is None: self.add_module( @@ -113,6 +128,29 @@ def forward( # Full sequence length. kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride + # Used by fp8_e4m3fnuz model + if self.cache_quantizer is not None: + # For fake quant, store the fp16 qdq value in the cache + if self.fake_quant: + xk = ( + self.cache_quantizer.quantize(xk) + .unpack() + .dequant() + .to(torch.float16) + ) + xv = ( + self.cache_quantizer.quantize(xv) + .unpack() + .dequant() + .to(torch.float16) + ) + # For real quant, store the quantized fp8 value in the cache + else: + # TODO: this seems like a bastardization of our quantized tensor api + # Probably want to add support for using quantized tensors more directly + xk = self.cache_quantizer.quantize(xk).unpack().qs + xv = self.cache_quantizer.quantize(xv).unpack().qs + xk, xv = self.transact_cache( xk_cache_update=xk, xv_cache_update=xv, @@ -138,6 +176,14 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: xk = repeat_kv(xk) xv = repeat_kv(xv) + # Fake quant is already dequantized when stored in the cache. + if self.cache_quantizer and not self.fake_quant: + xk = self.cache_quantizer.dequantize_raw_tensor( + xk, torch.float16, name="xk_deq" + ) + xv = self.cache_quantizer.dequantize_raw_tensor( + xv, torch.float16, name="xv_deq" + ) # Transpose into [bs, heads, sl, dim] xq = xq.transpose(1, 2) keys = xk.transpose(1, 2) @@ -170,7 +216,8 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor: attn_weights, values ) # (bs, heads, slen, head_dim) else: - is_causal = attention_mask is None and batch_seq_len == 1 + is_causal = True + attention_mask = None attn_output = ops.scaled_dot_product_attention( q=xq, # [bs, ..., sl, dim] k=keys, # [bs, ..., sl, dim] diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py index 2ec25e171..0a9a6f1c3 100644 --- a/sharktank/sharktank/models/llama/llama.py +++ b/sharktank/sharktank/models/llama/llama.py @@ -71,6 +71,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): device=config.device, activation_dtype=config.activation_dtype, attention_dtype=config.attention_dtype, + fake_quant=config.fake_quant, ) self.config = config self.hp = hp @@ -113,6 +114,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig): head_count_kv=hp.attention_head_count_kv, rms_epsilon=hp.attention_layer_norm_rms_epsilon, attention_kernel=self.attention_kernel, + fake_quant=self.fake_quant, ) for n in range(hp.block_count) ] @@ -284,6 +286,7 @@ def __init__( head_count_kv: int, rms_epsilon: float, attention_kernel: str = "decomposed", + fake_quant: bool = True, ): super().__init__(theta) self.add_module( @@ -297,6 +300,7 @@ def __init__( head_count_kv=head_count_kv, rms_epsilon=rms_epsilon, attention_kernel=attention_kernel, + fake_quant=fake_quant, ), ) self.add_module( diff --git a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py index 0d869932e..052593748 100644 --- a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py +++ b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py @@ -107,25 +107,21 @@ def apply_per_layer_quant( layer_theta = root_theta(layer_name) - weight_quant_scale = layer_theta.tensor("weight_quant_scale").as_torch() + weight_quant_scale = layer_theta.tensor("weight_scale").as_torch() weight = layer_theta.tensor("weight").as_torch() # It looks dumb but, this step is required for numerical correctness against quark. - weight = weight.view(torch.float8_e4m3fn) + # weight = weight.view(torch.float8_e4m3fn) weight = (weight.to(torch.float64) * weight_quant_scale).to(torch.float16) - weight_quant_zero_point = layer_theta.optional_tensor("weight_quant_zero_point") + weight_quant_zero_point = layer_theta.optional_tensor("weight_zero_point") if weight_quant_zero_point == None: weight_quant_zero_point = torch.zeros(1, dtype=torch.float32) else: weight_quant_zero_point = weight_quant_zero_point.as_torch() - input_quant_scale = as_torch_or_none( - layer_theta.optional_tensor("input_quant_scale") - ) - output_quant_scale = as_torch_or_none( - layer_theta.optional_tensor("output_quant_scale") - ) + input_quant_scale = as_torch_or_none(layer_theta.optional_tensor("input_scale")) + output_quant_scale = as_torch_or_none(layer_theta.optional_tensor("output_scale")) if weight_quant_scale is None: print("weight quant scale not found for layer ", layer_name) @@ -190,11 +186,11 @@ def quantize_weight( reciprocal_scale=output_quant_scale * 2.0, dtype=torch.float8_e4m3fnuz, ) - names = [f"{i}.qdq_input" for i in [q_name, k_name, v_name]] + names = [f"{i}.q_input" for i in [q_name, k_name, v_name]] for name in names: updated_tensors[name] = StaticScaledQuantizer( name=name, - scale=1.0 / input_quant_scale * 2.0, + scale=1.0 / (input_quant_scale * 2.0), reciprocal_scale=input_quant_scale * 2.0, dtype=torch.float8_e4m3fnuz, ) @@ -214,18 +210,18 @@ def quantize_weight( ) # we explicitly provide the reciprocal scale because converting from float16 to float8 after doing 1/scale results in significant numerical differences if input_quant_scale is not None: - updated_tensors[new_layer_name + ".qdq_input"] = StaticScaledQuantizer( - name=new_layer_name + ".qdq_input", - scale=1.0 / input_quant_scale, - reciprocal_scale=input_quant_scale, - dtype=torch.float8_e4m3fn, + updated_tensors[new_layer_name + ".q_input"] = StaticScaledQuantizer( + name=new_layer_name + ".q_input", + scale=1.0 / (input_quant_scale * 2.0), + reciprocal_scale=input_quant_scale * 2.0, + dtype=torch.float8_e4m3fnuz, ) if output_quant_scale is not None: updated_tensors[new_layer_name + ".qdq_output"] = StaticScaledQuantizer( name=new_layer_name + ".qdq_output", scale=1.0 / output_quant_scale, reciprocal_scale=output_quant_scale, - dtype=torch.float8_e4m3fn, + dtype=torch.float8_e4m3fnuz, ) # Remove the updated tensor from the original tree. @@ -261,15 +257,15 @@ def update_norm_layer( sub_name = layer_name + "." + sub new_name = hf_to_gguf(sub_name) + ".weight" single_replace(quant_theta, sub_name, new_name, updated_tensors) - kv_cache_scale = ( - quant_theta(layer_name).tensor("kv_cache_scaling_factor").as_torch() - ) + kv_cache_scale = quant_theta(layer_name, "self_attn").tensor("kv_scale").as_torch() layer_idx = layer_name.split(".")[-1] new_name = f"blk.{layer_idx}.kv_cache" - kv_cache_scale = DefaultPrimitiveTensor( - name=new_name + ".kv_cache_scaling_factor", data=kv_cache_scale + updated_tensors[new_name] = StaticScaledQuantizer( + name=new_name + ".quantizer", + scale=1.0 / (kv_cache_scale * 2.0), + reciprocal_scale=kv_cache_scale * 2.0, + dtype=torch.float8_e4m3fnuz, ) - updated_tensors[new_name] = kv_cache_scale def single_replace( @@ -279,6 +275,8 @@ def single_replace( updated_tensors: dict[str, InferenceTensor], ): data = quant_theta(layer_name).tensor("weight").as_torch() + if data.dtype == torch.bfloat16: + data = data.to(torch.float32) updated_tensors[gguf_name] = DefaultPrimitiveTensor(name=gguf_name, data=data) @@ -330,7 +328,9 @@ def main(argv): "mlp.down_proj", "mlp.up_proj", "self_attn.o_proj", - "self_attn.qkv", + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", ] for layer in model_layers: for sub in sub_layers: diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py index 40384b21e..b155fdaa3 100644 --- a/sharktank/sharktank/ops/default_impls.py +++ b/sharktank/sharktank/ops/default_impls.py @@ -355,7 +355,6 @@ def matmul_default(lhs, rhs, *, transpose_rhs: bool) -> Tensor: rhs = unbox_tensor(rhs) if transpose_rhs: rhs = rhs.mT - rhs = rhs.to(lhs.dtype) if len(lhs.shape) > 2 and len(rhs.shape) < 3: diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py index f4f7ac0ca..b66d3be1d 100644 --- a/sharktank/sharktank/ops/qlinear_impls.py +++ b/sharktank/sharktank/ops/qlinear_impls.py @@ -50,10 +50,12 @@ def qlinear_tensor_scaled( # Handle only integer and fp8 quantizations. if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point: - if ( - x_layout.qs.dtype != torch.float8_e4m3fnuz - or weight_layout.qs.dtype != torch.float8_e4m3fnuz - ): + if x_layout.qs.dtype == torch.float8_e4m3fnuz: + # assume quark + return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True).to( + torch.float16 + ) + else: return NotImplemented # Bias. diff --git a/sharktank/sharktank/types/quantizers.py b/sharktank/sharktank/types/quantizers.py index 575c969de..d3c093b85 100644 --- a/sharktank/sharktank/types/quantizers.py +++ b/sharktank/sharktank/types/quantizers.py @@ -131,6 +131,25 @@ def __init__( else: assert len(self._scale.shape) == 0, "Expected per-tensor scale to be 0D" + def dequantize_raw_tensor( + self, t: torch.Tensor, to: torch.dtype, *, name: str + ) -> torch.Tensor: + return ( + PlanarQuantizedTensor( + shape=t.shape, + name=t.name, + layout=TensorScaledLayout( + shape=t.shape, + d=self._reciprocal_scale, + qs=t, + m=self.offset, + dtype=to, + ), + ) + .unpack() + .dequant() + ) + def _quantize_raw_tensor(self, t: torch.Tensor, *, name: str) -> QuantizedTensor: """Performs a quantizing transformation on t, returning a QuantizeTensor.""" shape = list(t.shape) diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py index 396c74363..84ee741bf 100644 --- a/sharktank/sharktank/utils/cli.py +++ b/sharktank/sharktank/utils/cli.py @@ -61,6 +61,24 @@ def add_output_dataset_options(parser: argparse.ArgumentParser): ) +def add_model_options(parser: argparse.ArgumentParser): + """Adds model config options not exclusive to export or eager""" + parser.add_argument( + "--attention-kernel", + type=str, + default="decomposed", + choices=["decomposed", "torch"], + ) + + +def add_quantization_options(parser: argparse.ArgumentParser): + parser.add_argument( + "--fake-quant", + action=argparse.BooleanOptionalAction, + help="whether or not to run/export the model in fake quant mode. Note, running eagerly without fake quant is dependent on torch types supporting operations. YMMV", + ) + + def add_tokenizer_options(parser: argparse.ArgumentParser): """Adds options for specifying a tokenizer. diff --git a/sharktank/tests/layers/linear_test.py b/sharktank/tests/layers/linear_test.py index e2d038f72..ad657889d 100644 --- a/sharktank/tests/layers/linear_test.py +++ b/sharktank/tests/layers/linear_test.py @@ -84,7 +84,7 @@ def testNativeQuant_SymPerTensor_AsymPerAxis0_Dynamic(self): bias_quant, ] ) - linear = LinearLayer(theta) + linear = LinearLayer(theta, fake_quant=False) output = linear(lhs) output_ref = torch.matmul(lhs, rhs.T) + bias diff --git a/sharktank/tests/models/llama/attention_test.py b/sharktank/tests/models/llama/attention_test.py index daeefd93b..211fab5a0 100644 --- a/sharktank/tests/models/llama/attention_test.py +++ b/sharktank/tests/models/llama/attention_test.py @@ -59,7 +59,7 @@ def test(self): head_dim=head_dim, head_count_kv=head_count_kv, rms_epsilon=rms_epsilon, - attention_kernel="torch", + attention_kernel="decomposed", ) attention_embedding = RotaryEmbeddingLayer( rope_dimension_count=rope_dimension_count, diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py index adbfeaf7e..f70607832 100644 --- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py +++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py @@ -255,7 +255,7 @@ def testBenchmark8B_fp8_Decomposed(self): ) @pytest.mark.xfail( - reason="Test not yet implemented", strict=True, raises=ExportMlirException + reason="Compile failure", strict=True, raises=ExportMlirException ) def testBenchmark8B_fp8_Non_Decomposed(self): output_file_name = self.dir_path_8b / "fp8_torch" From b57464919b50c5930aea412170dc5d189d320d41 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 15 Nov 2024 10:44:01 -0800 Subject: [PATCH 56/59] Downgrade version to 2.9.2 to push a patch release. (#547) Still working out the details for release process / branching. Planned next steps: 1. Merge this 2. Trigger https://github.com/nod-ai/shark-ai/actions/workflows/build_packages.yml to build 2.9.2 3. Delete 3.0.0 wheels from https://github.com/nod-ai/shark-ai/releases/tag/dev-wheels 4. Push 2.9.2 to PyPI 5. Revert this PR to get the version back to 3.0.0 --- sharktank/version.json | 2 +- shortfin/version.json | 2 +- tuner/version.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sharktank/version.json b/sharktank/version.json index 85afb41ed..f09f61d2a 100644 --- a/sharktank/version.json +++ b/sharktank/version.json @@ -1,3 +1,3 @@ { - "package-version": "3.0.0.dev" + "package-version": "2.9.2.dev" } diff --git a/shortfin/version.json b/shortfin/version.json index 85afb41ed..f09f61d2a 100644 --- a/shortfin/version.json +++ b/shortfin/version.json @@ -1,3 +1,3 @@ { - "package-version": "3.0.0.dev" + "package-version": "2.9.2.dev" } diff --git a/tuner/version.json b/tuner/version.json index 85afb41ed..f09f61d2a 100644 --- a/tuner/version.json +++ b/tuner/version.json @@ -1,3 +1,3 @@ { - "package-version": "3.0.0.dev" + "package-version": "2.9.2.dev" } From 5ccfc87ecc42ce0cf49d36641b5bf77c928ed75f Mon Sep 17 00:00:00 2001 From: Stephen Baione <109226581+stbaione@users.noreply.github.com> Date: Fri, 15 Nov 2024 12:47:00 -0600 Subject: [PATCH 57/59] Set upper Numpy Version (#540) # Description Found a bug when walking through the shortfin llm docs using latest `nightly` sharktank. gguf is currently incompatible with numpy > 2. This breaks `sharktank.examples.export_paged_llm_v1` on linux. The gguf issue is filed [here](https://github.com/ggerganov/llama.cpp/issues/9021). It was closed from inactivity, but isn't actually solved and has a PR open for the fix. ## Repro Steps On linux: ### Before re-pinning Create a virtual environment: ```bash python -m venv --prompt sharktank .venv souce .venv/bin/activate ``` Install depencies and sharktank: ```bash pip install -r pytorch-cpu-requirements.txt pip install -r requirements.txt -e sharktank/ ``` Show numpy version (before re-pinning): ```bash pip show numpy | grep Version Version: 2.1.3 ``` Try running `export_paged_llm_v1`: ```bash python -m sharktank.examples.export_paged_llm_v1 --gguf-file=$PATH_TO_GGUF --output-mlir=./temp/model.mlir --output-config=./temp/config.json --bs=1,4 ``` You'll see this error: ```text Traceback (most recent call last): File "", line 198, in _run_module_as_main File "", line 88, in _run_code File "/home/stbaione/repos/SHARK-Platform/sharktank/sharktank/examples/export_paged_llm_v1.py", line 336, in main() File "/home/stbaione/repos/SHARK-Platform/sharktank/sharktank/examples/export_paged_llm_v1.py", line 67, in main dataset = cli.get_input_dataset(args) ^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/stbaione/repos/SHARK-Platform/sharktank/sharktank/utils/cli.py", line 104, in get_input_dataset return Dataset.load(data_files["gguf"], file_type="gguf") ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/stbaione/repos/SHARK-Platform/sharktank/sharktank/types/theta.py", line 347, in load ds = _dataset_load_helper(path, file_type=file_type, mmap=mmap) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/stbaione/repos/SHARK-Platform/sharktank/sharktank/types/theta.py", line 536, in _dataset_load_helper return gguf_interop.load_file(path) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/stbaione/repos/SHARK-Platform/sharktank/sharktank/types/gguf_interop/base.py", line 117, in load_file reader = GGUFReader(gguf_path) ^^^^^^^^^^^^^^^^^^^^^ File "/home/stbaione/repos/SHARK-Platform/.venv_2/lib/python3.12/site-packages/gguf/gguf_reader.py", line 87, in __init__ if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/stbaione/repos/SHARK-Platform/.venv_2/lib/python3.12/site-packages/gguf/gguf_reader.py", line 137, in _get .newbyteorder(override_order or self.byte_order) ^^^^^^^^^^^^ AttributeError: `newbyteorder` was removed from the ndarray class in NumPy 2.0. Use `arr.view(arr.dtype.newbyteorder(order))` instead. ``` ## After re-pinning Create a virtual environment: ```bash python -m venv --prompt sharktank .venv souce .venv/bin/activate ``` Install depencies and sharktank: ```bash pip install -r pytorch-cpu-requirements.txt pip install -r requirements.txt -e sharktank/ ``` Show numpy version: ```bash pip show numpy | grep Version Version: 1.26.3 ``` Run `export_paged_llm_v1`: ```bash python -m sharktank.examples.export_paged_llm_v1 --gguf-file=$PATH_TO_GGUF --output-mlir=./temp/model.mlir --output-config=./temp/config.json --bs=1,4 ``` With re-pinning we get desired output: ```text Exporting decode_bs1 Exporting prefill_bs4 Exporting decode_bs4 GENERATED! Exporting Saving to './temp/model.mlir' ``` --------- Co-authored-by: Marius Brehler --- sharktank/requirements.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sharktank/requirements.txt b/sharktank/requirements.txt index 19e48f825..6b533d977 100644 --- a/sharktank/requirements.txt +++ b/sharktank/requirements.txt @@ -2,8 +2,7 @@ iree-turbine # Runtime deps. gguf==0.6.0 -numpy==1.26.3; sys_platform == 'win32' -numpy; sys_platform != 'win32' +numpy<2.0 # Needed for newer gguf versions (TODO: remove when gguf package includes this) # sentencepiece>=0.1.98,<=0.2.0 From 3ba0bcf39fe567c4fcb231ce9139c4320a53ca09 Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 15 Nov 2024 19:51:13 +0100 Subject: [PATCH 58/59] [shark-ai] Make it a general Python 3 package (#542) --- shark-ai/pyproject.toml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/shark-ai/pyproject.toml b/shark-ai/pyproject.toml index 133026d13..3f7e4a1da 100644 --- a/shark-ai/pyproject.toml +++ b/shark-ai/pyproject.toml @@ -14,12 +14,7 @@ classifiers = [ "Development Status :: 3 - Alpha", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", ] -requires-python = ">= 3.10" # Version is set via the `setup.py` and requirements are set via files below. dynamic = ["version", "dependencies"] From 8d9a923f1aa149b13ae7aa85fe57fc6286ac822d Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Fri, 15 Nov 2024 19:51:29 +0100 Subject: [PATCH 59/59] [sharktank] Make it a general Python 3 package (#543) --- sharktank/pyproject.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sharktank/pyproject.toml b/sharktank/pyproject.toml index e5f9972a3..01cad409b 100644 --- a/sharktank/pyproject.toml +++ b/sharktank/pyproject.toml @@ -14,11 +14,7 @@ classifiers = [ "Development Status :: 3 - Alpha", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", ] -requires-python = ">= 3.11" # Version is set via the `setup.py` and requirements are set via files below. dynamic = ["version", "dependencies", "optional-dependencies"]