diff --git a/.github/workflows/build_packages.yml b/.github/workflows/build_packages.yml
index 0ef36bc09..1200234c4 100644
--- a/.github/workflows/build_packages.yml
+++ b/.github/workflows/build_packages.yml
@@ -26,7 +26,7 @@ jobs:
with:
submodules: false
- name: Setup Python
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.3
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: 3.12
cache: "pip"
@@ -37,15 +37,15 @@ jobs:
- name: Generate release candidate versions
id: version_rc
run: |
- sharktank_package_version=$(python3 build_tools/gen_version_info_rc.py sharktank)
- shortfin_package_version=$(python3 build_tools/gen_version_info_rc.py shortfin)
- - name: Upload version_info_rc.json
+ sharktank_package_version=$(python3 build_tools/python_deploy/compute_local_version.py sharktank)
+ shortfin_package_version=$(python3 build_tools/python_deploy/compute_local_version.py shortfin)
+ - name: Upload version_local.json
uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
with:
- name: version_info_rc
+ name: version_local
path: |
- sharktank/version_info_rc.json
- shortfin/version_info_rc.json
+ sharktank/version_local.json
+ shortfin/version_local.json
build_packages:
name: "${{ matrix.package }} :: ${{ matrix.platform }} :: ${{ matrix.python-version }}"
@@ -60,6 +60,10 @@ jobs:
platform: linux-x86_64
package: sharktank
python-version: cp311-cp311 # Ignored (generic wheel), set for workflow naming
+ - runs-on: ubuntu-24.04
+ platform: linux-x86_64
+ package: shortfin
+ python-version: cp310-cp310
- runs-on: ubuntu-24.04
platform: linux-x86_64
package: shortfin
@@ -87,10 +91,10 @@ jobs:
path: "c" # Windows can hit path length limits, so use a short path.
submodules: false
- - name: Download version_info_rc.json
+ - name: Download version_local.json
uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8
with:
- name: version_info_rc
+ name: version_local
path: ./c/
merge-multiple: true
@@ -125,7 +129,7 @@ jobs:
token: "${{ secrets.RELEASE_PUBLISH_ACCESS_TOKEN }}"
tag: "dev-wheels"
name: "dev-wheels"
- body: "Automatic snapshot release of SHARK-Platform python wheels."
+ body: "Automatic snapshot release of shark-ai python wheels."
removeArtifacts: false
allowUpdates: true
replacesArtifacts: true
diff --git a/.github/workflows/ci-llama-large-tests.yaml b/.github/workflows/ci-llama-large-tests.yaml
index 394cba93a..41ad5af6b 100644
--- a/.github/workflows/ci-llama-large-tests.yaml
+++ b/.github/workflows/ci-llama-large-tests.yaml
@@ -41,15 +41,15 @@ jobs:
- name: "Setting up Python"
id: setup_python
- uses: actions/setup-python@v3
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{matrix.version}}
- name: "Checkout Code"
- uses: actions/checkout@v3
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Cache Pip Packages
- uses: actions/cache@v4
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
@@ -70,21 +70,20 @@ jobs:
# Test with pinned nightly releases, not what iree-turbine uses.
pip install -f https://iree.dev/pip-release-links.html --upgrade \
- iree-base-compiler==2.9.0rc20241108 \
- iree-base-runtime==2.9.0rc20241108 \
- "numpy<2.0"
+ iree-base-compiler==3.0.0rc20241115 \
+ iree-base-runtime==3.0.0rc20241115
- name: Run llama tests
run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --run-all-llama --iree-hip-target=gfx942 --html=out/index.html
- name: Deploy to GitHub Pages
- uses: peaceiris/actions-gh-pages@v3
+ uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
with:
github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
publish_dir: ./out
- name: Upload llama executable files
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
with:
name: llama-files
path: ${{ github.workspace }}/${{ steps.date.outputs.date }}
diff --git a/.github/workflows/ci-llama-quick-tests.yaml b/.github/workflows/ci-llama-quick-tests.yaml
index ce55f81f8..63637e9b9 100644
--- a/.github/workflows/ci-llama-quick-tests.yaml
+++ b/.github/workflows/ci-llama-quick-tests.yaml
@@ -42,15 +42,15 @@ jobs:
- name: "Setting up Python"
id: setup_python
- uses: actions/setup-python@v3
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{matrix.version}}
- name: "Checkout Code"
- uses: actions/checkout@v3
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Cache Pip Packages
- uses: actions/cache@v4
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
@@ -71,9 +71,8 @@ jobs:
# Test with pinned nightly releases, not what iree-turbine uses.
pip install -f https://iree.dev/pip-release-links.html --upgrade \
- iree-base-compiler==2.9.0rc20241108 \
- iree-base-runtime==2.9.0rc20241108 \
- "numpy<2.0"
+ iree-base-compiler==3.0.0rc20241115 \
+ iree-base-runtime==3.0.0rc20241115
- name: Run llama 8b tests
run: pytest sharktank/tests/models/llama/benchmark_amdgpu_test.py -v -s --iree-hip-target=gfx942 --run-8b-llama
diff --git a/.github/workflows/ci-sdxl.yaml b/.github/workflows/ci-sdxl.yaml
index 373bc9319..31218d25f 100644
--- a/.github/workflows/ci-sdxl.yaml
+++ b/.github/workflows/ci-sdxl.yaml
@@ -64,7 +64,7 @@ jobs:
repository: iree-org/iree
path: ${{ env.IREE_REPO_DIR }}
submodules: false
- ref: iree-2.9.0rc20241108
+ ref: iree-3.0.0rc20241115
- name: Initalize IREE submodules
working-directory: ${{ env.IREE_REPO_DIR }}
@@ -76,7 +76,7 @@ jobs:
git submodule update --init --depth 1 -- third_party/hip-build-deps/
- name: Setup Python
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.3
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: "3.12"
cache: "pip"
diff --git a/.github/workflows/ci-sglang-benchmark.yml b/.github/workflows/ci-sglang-benchmark.yml
new file mode 100644
index 000000000..6a5fa4112
--- /dev/null
+++ b/.github/workflows/ci-sglang-benchmark.yml
@@ -0,0 +1,88 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+name: SGLang Llama Benchmarking Tests
+
+on:
+ workflow_dispatch:
+ schedule:
+ # Weekdays at 4:00 AM UTC = 9:00 PM PST.
+ - cron: "0 4 * * 1-5"
+
+concurrency:
+ # A PR number if a pull request and otherwise the commit hash. This cancels
+ # queued and in-progress runs for the same PR (presubmit) or commit
+ # (postsubmit). The workflow name is prepended to avoid conflicts between
+ # different workflows.
+ group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
+ cancel-in-progress: true
+
+jobs:
+ sglang_bench_serve:
+ name: "SGLang Serving Benchmark Tests"
+ strategy:
+ matrix:
+ version: [3.11]
+ fail-fast: false
+ runs-on: llama-mi300x-3
+ defaults:
+ run:
+ shell: bash
+ env:
+ PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
+ steps:
+ - name: Get Current Date
+ id: date
+ run: echo "::set-output name=date::$(date +'%Y-%m-%d')"
+
+ - name: "Setting up Python"
+ id: setup_python
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ with:
+ python-version: ${{matrix.version}}
+
+ - name: "Checkout Code"
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+ - name: Cache Pip Packages
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
+ id: cache-pip
+ with:
+ path: ${{ env.PIP_CACHE_DIR }}
+ key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}
+
+ - name: Install pip deps
+ run: |
+ python -m pip install --no-compile --upgrade pip
+ # Note: We install in three steps in order to satisfy requirements
+ # from non default locations first. Installing the PyTorch CPU
+ # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+ pip install --no-compile -r pytorch-cpu-requirements.txt
+ pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+ pip install --no-compile -r requirements.txt -e sharktank/ shortfin/
+
+ # Try with the latest nightly releases, not what iree-turbine pins.
+ # We could also pin to a known working or stable version.
+ # This should eventually stabilize. Do the best we can for now.
+ pip install -f https://iree.dev/pip-release-links.html --upgrade \
+ iree-base-compiler==3.0.0rc20241115 \
+ iree-base-runtime==3.0.0rc20241115 \
+ "numpy<2.0"
+
+ - name: Install SGLang
+ run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"
+
+ - name: Launch Shortfin Server
+ run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html
+
+ - name: Deploy to GitHub Pages
+ uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
+ with:
+ github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
+ publish_dir: ./out/llm/sglang
+ destination_dir: ./llm/sglang
+ keep_files: true
diff --git a/.github/workflows/ci-shark-platform.yml b/.github/workflows/ci-shark-ai.yml
similarity index 86%
rename from .github/workflows/ci-shark-platform.yml
rename to .github/workflows/ci-shark-ai.yml
index 708fed66f..28e2bc883 100644
--- a/.github/workflows/ci-shark-platform.yml
+++ b/.github/workflows/ci-shark-ai.yml
@@ -4,7 +4,7 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-name: CI - shark-platform
+name: CI - shark-ai
on:
workflow_dispatch:
@@ -37,15 +37,15 @@ jobs:
steps:
- name: "Setting up Python"
id: setup_python
- uses: actions/setup-python@v3
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{matrix.version}}
- name: "Checkout Code"
- uses: actions/checkout@v3
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Cache Pip Packages
- uses: actions/cache@v4
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
@@ -67,10 +67,9 @@ jobs:
# Try with the latest IREE nightly releases, not what iree-turbine pins.
# We could also pin to a known working or stable version.
# This should eventually stabilize. Do the best we can for now.
- pip install -f https://iree.dev/pip-release-links.html --upgrade \
+ pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
iree-base-compiler \
- iree-base-runtime \
- "numpy<2.0"
+ iree-base-runtime
- name: Run LLM Integration Tests
- run: pytest -v build_tools/integration_tests/llm --log-cli-level=INFO
+ run: pytest -v app_tests/integration_tests/llm --log-cli-level=INFO
diff --git a/.github/workflows/ci-sharktank.yml b/.github/workflows/ci-sharktank.yml
index eadc33501..7d6a7b7f1 100644
--- a/.github/workflows/ci-sharktank.yml
+++ b/.github/workflows/ci-sharktank.yml
@@ -38,15 +38,15 @@ jobs:
steps:
- name: "Setting up Python"
id: setup_python
- uses: actions/setup-python@v3
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{matrix.version}}
- name: "Checkout Code"
- uses: actions/checkout@v3
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Cache Pip Packages
- uses: actions/cache@v4
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
@@ -62,8 +62,8 @@ jobs:
pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
# Update to the latest iree packages.
- pip install -f https://iree.dev/pip-release-links.html --upgrade \
- iree-compiler iree-runtime --src deps \
+ pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+ iree-base-compiler iree-base-runtime --src deps \
-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
- name: Run sharktank tests
diff --git a/.github/workflows/ci-tuner.yml b/.github/workflows/ci-tuner.yml
index f1fcebfdc..81b920e31 100644
--- a/.github/workflows/ci-tuner.yml
+++ b/.github/workflows/ci-tuner.yml
@@ -35,7 +35,7 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.3
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: '3.10.12'
@@ -49,8 +49,11 @@ jobs:
pip install -r tuner/requirements-tuner.txt
python -m pip install \
--find-links https://iree.dev/pip-release-links.html \
- --upgrade \
+ --upgrade --pre \
iree-base-compiler iree-base-runtime
- name: Run tuner tests
run: pytest tuner/
+
+ - name: Run mypy type checker
+ run: mypy tuner/tuner
diff --git a/.github/workflows/ci_eval.yaml b/.github/workflows/ci_eval.yaml
index cf18212df..999d921c7 100644
--- a/.github/workflows/ci_eval.yaml
+++ b/.github/workflows/ci_eval.yaml
@@ -39,15 +39,15 @@ jobs:
steps:
- name: "Setting up Python"
id: setup_python
- uses: actions/setup-python@v3
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{matrix.version}}
- name: "Checkout Code"
- uses: actions/checkout@v3
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Cache Pip Packages
- uses: actions/cache@v4
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
@@ -69,16 +69,15 @@ jobs:
# Try with the latest IREE nightly releases, not what iree-turbine pins.
# We could also pin to a known working or stable version.
# This should eventually stabilize. Do the best we can for now.
- pip install -f https://iree.dev/pip-release-links.html --upgrade \
+ pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
iree-base-compiler \
- iree-base-runtime \
- "numpy<2.0"
+ iree-base-runtime
- name: Run perplexity test with IREE
run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py --longrun --iree-device='hip://7' --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=perplexity/perplexity_iree.html
- name: Deploy to GitHub Pages
- uses: peaceiris/actions-gh-pages@v3
+ uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
with:
github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
publish_dir: ./perplexity
@@ -101,15 +100,15 @@ jobs:
steps:
- name: "Setting up Python"
id: setup_python
- uses: actions/setup-python@v3
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{matrix.version}}
- name: "Checkout Code"
- uses: actions/checkout@v3
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Cache Pip Packages
- uses: actions/cache@v4
+ uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
id: cache-pip
with:
path: ${{ env.PIP_CACHE_DIR }}
@@ -132,7 +131,7 @@ jobs:
run: pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py --longrun --llama3-8b-f16-model-path=/data/llama3.1/8b/llama8b_f16.irpa --llama3-8b-tokenizer-path=/data/llama3.1/8b/tokenizer_config.json --html=perplexity/perplexity_torch.html
- name: Deploy to GitHub Pages
- uses: peaceiris/actions-gh-pages@v3
+ uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
with:
github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
publish_dir: ./perplexity
diff --git a/.github/workflows/ci_linux_x64-libshortfin.yml b/.github/workflows/ci_linux_x64-libshortfin.yml
index ad154748f..45ddfe90d 100644
--- a/.github/workflows/ci_linux_x64-libshortfin.yml
+++ b/.github/workflows/ci_linux_x64-libshortfin.yml
@@ -40,7 +40,7 @@ jobs:
runs-on: ubuntu-24.04
strategy:
matrix:
- python-version: ["3.11", "3.12"]
+ python-version: ["3.10", "3.11", "3.12"]
steps:
- name: Install dependencies
@@ -59,7 +59,7 @@ jobs:
repository: iree-org/iree
path: ${{ env.IREE_REPO_DIR }}
submodules: false
- ref: iree-2.9.0rc20241108
+ ref: iree-3.0.0rc20241115
- name: Initalize IREE submodules
working-directory: ${{ env.IREE_REPO_DIR }}
@@ -71,7 +71,7 @@ jobs:
git submodule update --init --depth 1 -- third_party/hip-build-deps/
- name: Setup Python ${{ matrix.python-version }}
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.3
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
diff --git a/.github/workflows/ci_linux_x64_asan-libshortfin.yml b/.github/workflows/ci_linux_x64_asan-libshortfin.yml
index d9eee2576..5692a8336 100644
--- a/.github/workflows/ci_linux_x64_asan-libshortfin.yml
+++ b/.github/workflows/ci_linux_x64_asan-libshortfin.yml
@@ -32,8 +32,8 @@ concurrency:
env:
PYENV_ROOT: ${{ github.workspace }}/pyenv
- PYENV_REF: 9ecd803bffaffb949fbdd8c70cb086227f6a3202 # v2.4.10
- PYTHON_VER: 3.12.3
+ PYENV_REF: 96b3fb2fc3bee85650cb22e2cb06c83c24509a6d # v2.4.17
+ PYTHON_VER: 3.12.7
CACHE_ASAN_VER: 2
CACHE_DEPS_VER: 1
IREE_SOURCE_DIR: ${{ github.workspace }}/iree
@@ -109,7 +109,7 @@ jobs:
repository: iree-org/iree
path: ${{ env.IREE_SOURCE_DIR }}
submodules: false
- ref: iree-2.9.0rc20241108
+ ref: iree-3.0.0rc20241115
- name: Initalize IREE submodules
working-directory: ${{ env.IREE_SOURCE_DIR }}
diff --git a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml
index c80b40c03..c382edbf4 100644
--- a/.github/workflows/ci_linux_x64_nogil-libshortfin.yml
+++ b/.github/workflows/ci_linux_x64_nogil-libshortfin.yml
@@ -57,7 +57,7 @@ jobs:
repository: iree-org/iree
path: ${{ env.IREE_REPO_DIR }}
submodules: false
- ref: iree-2.9.0rc20241108
+ ref: iree-3.0.0rc20241115
- name: Initalize IREE submodules
working-directory: ${{ env.IREE_REPO_DIR }}
diff --git a/.github/workflows/ci_windows_x64-libshortfin.yml b/.github/workflows/ci_windows_x64-libshortfin.yml
index 4bbef8f12..544b45c76 100644
--- a/.github/workflows/ci_windows_x64-libshortfin.yml
+++ b/.github/workflows/ci_windows_x64-libshortfin.yml
@@ -54,7 +54,7 @@ jobs:
repository: iree-org/iree
path: ${{ env.IREE_REPO_DIR }}
submodules: false
- ref: iree-2.9.0rc20241108
+ ref: iree-3.0.0rc20241115
- name: Initalize IREE submodules
working-directory: ${{ env.IREE_REPO_DIR }}
@@ -66,7 +66,7 @@ jobs:
git submodule update --init --depth 1 -- third_party/hip-build-deps/
- name: Setup Python
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.3
+ uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: "3.12"
cache: "pip"
diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml
index 2b11178bf..8ec1e8d55 100644
--- a/.github/workflows/pre-commit.yaml
+++ b/.github/workflows/pre-commit.yaml
@@ -9,6 +9,6 @@ jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v3
- - uses: actions/setup-python@v3
- - uses: pre-commit/action@v3.0.1
+ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
diff --git a/.gitignore b/.gitignore
index 6474e6a8c..bdb0b5387 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,6 +30,9 @@ wheelhouse
*.whl
*.venv
+# Local-only config options
+version_local.json
+
#Model artifacts
*.pt
*.safetensors
diff --git a/README.md b/README.md
index aa4c46bdc..77f4a0d75 100644
--- a/README.md
+++ b/README.md
@@ -1,22 +1,40 @@
-# SHARK Modeling and Serving Libraries
+# shark-ai: SHARK Modeling and Serving Libraries
-**WARNING: This is an early preview that is in progress. It is not ready for
-general use.**
+> [!IMPORTANT]
+> Development is still in progress for several project components. See the
+> notes below for which workflows are best supported.
-![GitHub License](https://img.shields.io/github/license/nod-ai/SHARK-Platform)
- [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit)](https://github.com/pre-commit/pre-commit)
+![GitHub License](https://img.shields.io/github/license/nod-ai/shark-ai)
+[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit)](https://github.com/pre-commit/pre-commit)
## Sub-projects
+### [`shortfin/`](./shortfin/)
+
+
+
+[![PyPI version](https://badge.fury.io/py/shortfin.svg)](https://badge.fury.io/py/shortfin) [![CI - shortfin](https://github.com/nod-ai/shark-ai/actions/workflows/ci_linux_x64-libshortfin.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci_linux_x64-libshortfin.yml?query=event%3Apush)
+
+The shortfin sub-project is SHARK's high performance inference library and
+serving engine.
+
+* API documentation for shortfin is available on
+ [readthedocs](https://shortfin.readthedocs.io/en/latest/).
+
### [`sharktank/`](./sharktank/)
-[![PyPI version](https://badge.fury.io/py/sharktank.svg)](https://badge.fury.io/py/sharktank) [![CI - sharktank](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci-sharktank.yml/badge.svg?event=push)](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci-sharktank.yml?query=event%3Apush)
+[![PyPI version](https://badge.fury.io/py/sharktank.svg)](https://badge.fury.io/py/sharktank) [![CI - sharktank](https://github.com/nod-ai/shark-ai/actions/workflows/ci-sharktank.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci-sharktank.yml?query=event%3Apush)
The SHARK Tank sub-project contains a collection of model recipes and
conversion tools to produce inference-optimized programs.
+> [!WARNING]
+> SHARK Tank is still under development. Experienced users may want to try it
+> out, but we currently recommend most users download pre-exported or
+> pre-compiled model files for serving with shortfin.
+
* See the [SHARK Tank Programming Guide](./docs/programming_guide.md) for
@@ -25,25 +43,18 @@ conversion tools to produce inference-optimized programs.
* See [Direct Quantization with SHARK Tank](./docs/quantization.md)
for information about quantization support.
-### [`shortfin/`](./shortfin/)
-
-
-
-[![PyPI version](https://badge.fury.io/py/shortfin.svg)](https://badge.fury.io/py/shortfin) [![CI - shortfin](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci_linux_x64-libshortfin.yml/badge.svg?event=push)](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci_linux_x64-libshortfin.yml?query=event%3Apush)
-
-The shortfin sub-project is SHARK's high performance inference library and
-serving engine.
-
-* API documentation for shortfin is available on
- [readthedocs](https://shortfin.readthedocs.io/en/latest/).
-
### [`tuner/`](./tuner/)
-[![CI - Tuner](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci-tuner.yml/badge.svg?event=push)](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci-tuner.yml?query=event%3Apush)
+[![CI - Tuner](https://github.com/nod-ai/shark-ai/actions/workflows/ci-tuner.yml/badge.svg?event=push)](https://github.com/nod-ai/shark-ai/actions/workflows/ci-tuner.yml?query=event%3Apush)
The Tuner sub-project assists with tuning program performance by searching for
optimal parameter configurations to use during model compilation.
+> [!WARNING]
+> SHARK Tuner is still in early development. Interested users may want
+> to try it out, but the tuner is not ready for general use yet. Check out
+> [the readme](tuner/README.md) for more details.
+
## Support matrix
@@ -52,65 +63,14 @@ optimal parameter configurations to use during model compilation.
Model name | Model recipes | Serving apps
---------- | ------------- | ------------
-SDXL | [`sharktank/sharktank/models/punet/`](https://github.com/nod-ai/SHARK-Platform/tree/main/sharktank/sharktank/models/punet) | [`shortfin/python/shortfin_apps/sd/`](https://github.com/nod-ai/SHARK-Platform/tree/main/shortfin/python/shortfin_apps/sd)
-llama | [`sharktank/sharktank/models/llama/`](https://github.com/nod-ai/SHARK-Platform/tree/main/sharktank/sharktank/models/llama) | [`shortfin/python/shortfin_apps/llm/`](https://github.com/nod-ai/SHARK-Platform/tree/main/shortfin/python/shortfin_apps/llm)
-
-## Development getting started
-
-
-
-Use this as a guide to get started developing the project using pinned,
-pre-release dependencies. You are welcome to deviate as you see fit, but
-these canonical directions mirror what the CI does.
-
-### Setup a venv
-
-We recommend setting up a virtual environment (venv). The project is configured
-to ignore `.venv` directories, and editors like VSCode pick them up by default.
-
-```
-python -m venv .venv
-source .venv/bin/activate
-```
-
-### Install PyTorch for your system
-
-If no explicit action is taken, the default PyTorch version will be installed.
-This will give you a current CUDA-based version. Install a different variant
-by doing so explicitly first:
-
-*CPU:*
-
-```
-pip install -r pytorch-cpu-requirements.txt
-```
-
-*ROCM:*
-
-```
-pip install -r pytorch-rocm-requirements.txt
-```
-
-### Install development packages
-
-```
-# Install editable local projects.
-pip install -r requirements.txt -e sharktank/ shortfin/
+SDXL | [`sharktank/sharktank/models/punet/`](https://github.com/nod-ai/shark-ai/tree/main/sharktank/sharktank/models/punet) | [`shortfin/python/shortfin_apps/sd/`](https://github.com/nod-ai/shark-ai/tree/main/shortfin/python/shortfin_apps/sd)
+llama | [`sharktank/sharktank/models/llama/`](https://github.com/nod-ai/shark-ai/tree/main/sharktank/sharktank/models/llama) | [`shortfin/python/shortfin_apps/llm/`](https://github.com/nod-ai/shark-ai/tree/main/shortfin/python/shortfin_apps/llm)
-# Optionally clone and install editable iree-turbine dep in deps/
-pip install -f https://iree.dev/pip-release-links.html --src deps \
- -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
-```
-### Running tests
+## SHARK Users
-```
-pytest sharktank
-pytest shortfin
-```
+If you're looking to use SHARK check out our [User Guide](docs/user_guide.md).
-### Optional: pre-commits and developer settings
+## SHARK Developers
-This project is set up to use the `pre-commit` tooling. To install it in
-your local repo, run: `pre-commit install`. After this point, when making
-commits locally, hooks will run. See https://pre-commit.com/
+If you're looking to develop SHARK, check out our [Developer Guide](docs/developer_guide.md).
diff --git a/app_tests/__init__.py b/app_tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/app_tests/benchmark_tests/__init__.py b/app_tests/benchmark_tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/app_tests/benchmark_tests/llm/conftest.py b/app_tests/benchmark_tests/llm/conftest.py
new file mode 100644
index 000000000..aac66ca0f
--- /dev/null
+++ b/app_tests/benchmark_tests/llm/conftest.py
@@ -0,0 +1,47 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import os
+import pytest
+import sys
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
+from integration_tests.llm.utils import compile_model, export_paged_llm_v1
+
+
+@pytest.fixture(scope="module")
+def pre_process_model(request, tmp_path_factory):
+ tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test")
+
+ model_path = request.param["model_path"]
+ settings = request.param["settings"]
+ batch_sizes = request.param["batch_sizes"]
+
+ tmp_dir = tmp_path_factory.mktemp("llm_benchmark_test")
+ mlir_path = tmp_dir / "model.mlir"
+ config_path = tmp_dir / "config.json"
+ vmfb_path = tmp_dir / "model.vmfb"
+
+ export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)
+
+ config = {
+ "module_name": "module",
+ "module_abi_version": 1,
+ "max_seq_len": 131072,
+ "attn_head_count": 8,
+ "attn_head_dim": 128,
+ "prefill_batch_sizes": batch_sizes,
+ "decode_batch_sizes": batch_sizes,
+ "transformer_block_count": 32,
+ "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
+ }
+ with open(config_path, "w") as file:
+ json.dump(config, file)
+
+ compile_model(mlir_path, vmfb_path, settings)
+
+ return tmp_dir
diff --git a/app_tests/benchmark_tests/llm/sglang_benchmark_test.py b/app_tests/benchmark_tests/llm/sglang_benchmark_test.py
new file mode 100644
index 000000000..8027fcea7
--- /dev/null
+++ b/app_tests/benchmark_tests/llm/sglang_benchmark_test.py
@@ -0,0 +1,108 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import logging
+import multiprocessing
+import os
+from pathlib import Path
+import pytest
+import time
+from unittest.mock import patch
+
+pytest.importorskip("sglang")
+from sglang import bench_serving
+
+from utils import SGLangBenchmarkArgs
+
+from integration_tests.llm.utils import (
+ find_available_port,
+ start_llm_server,
+)
+
+logger = logging.getLogger("__name__")
+
+device_settings = {
+ "device_flags": [
+ "--iree-hal-target-backends=rocm",
+ "--iree-hip-target=gfx942",
+ ],
+ "device": "hip",
+}
+
+# TODO: Download on demand instead of assuming files exist at this path
+MODEL_PATH = Path("/data/llama3.1/8b/llama8b_f16.irpa")
+TOKENIZER_DIR = Path("/data/llama3.1/8b/")
+
+
+@pytest.mark.parametrize("request_rate", [1, 2, 4, 8, 16, 32])
+@pytest.mark.parametrize(
+ "pre_process_model",
+ [
+ (
+ {
+ "model_path": MODEL_PATH,
+ "settings": device_settings,
+ "batch_sizes": [1, 4],
+ }
+ )
+ ],
+ indirect=True,
+)
+def test_sglang_benchmark_server(request_rate, pre_process_model):
+ # TODO: Remove when multi-device is fixed
+ os.environ["ROCR_VISIBLE_DEVICES"] = "1"
+
+ tmp_dir = pre_process_model
+
+ config_path = tmp_dir / "config.json"
+ vmfb_path = tmp_dir / "model.vmfb"
+ tokenizer_path = TOKENIZER_DIR / "tokenizer.json"
+
+ # Start shortfin llm server
+ port = find_available_port()
+ server_process = start_llm_server(
+ port,
+ tokenizer_path,
+ config_path,
+ vmfb_path,
+ MODEL_PATH,
+ device_settings,
+ timeout=30,
+ )
+
+ # Run and collect SGLang Serving Benchmark
+ benchmark_args = SGLangBenchmarkArgs(
+ backend="shortfin",
+ num_prompt=10,
+ base_url=f"http://localhost:{port}",
+ tokenizer=TOKENIZER_DIR,
+ request_rate=request_rate,
+ )
+ output_file = (
+ tmp_dir
+ / f"{benchmark_args.backend}_{benchmark_args.num_prompt}_{benchmark_args.request_rate}.jsonl"
+ )
+ benchmark_args.output_file = output_file
+
+ logger.info("Running SGLang Benchmark with the following args:")
+ logger.info(benchmark_args)
+ try:
+ start = time.time()
+ with patch.object(bench_serving, "print", side_effect=logger.info):
+ benchmark_process = multiprocessing.Process(
+ target=bench_serving.run_benchmark,
+ args=(benchmark_args.as_namespace(),),
+ )
+ benchmark_process.start()
+ benchmark_process.join()
+
+ logger.info(f"Benchmark run completed in {str(time.time() - start)} seconds")
+ except Exception as e:
+ logger.info(e)
+
+ server_process.terminate()
+ server_process.wait()
diff --git a/app_tests/benchmark_tests/llm/utils.py b/app_tests/benchmark_tests/llm/utils.py
new file mode 100644
index 000000000..c217720cb
--- /dev/null
+++ b/app_tests/benchmark_tests/llm/utils.py
@@ -0,0 +1,55 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from argparse import Namespace
+from dataclasses import dataclass
+from pathlib import Path
+
+
+@dataclass
+class SGLangBenchmarkArgs:
+ base_url: str
+ num_prompt: int
+ request_rate: int
+ tokenizer: str | Path
+
+ seed: int = 1
+ extra_request_body: str | None = None
+ output_file: str | Path | None = None
+ port: int = 8000
+ backend: str = "shortfin"
+
+ def as_namespace(self) -> Namespace:
+ return Namespace(
+ num_prompts=self.num_prompt,
+ base_url=self.base_url,
+ tokenizer=str(self.tokenizer),
+ request_rate=self.request_rate,
+ backend=self.backend,
+ output_file=self.output_file,
+ seed=self.seed,
+ extra_request_body=self.extra_request_body,
+ port=8000,
+ model=None,
+ dataset_name="sharegpt",
+ random_input_len=None,
+ random_output_len=None,
+ dataset_path="",
+ sharegpt_output_len=None,
+ multi=False,
+ disable_tqdm=False,
+ disable_stream=False,
+ disable_ignore_eos=False,
+ )
+
+ def __repr__(self):
+ return (
+ f"Backend: {self.backend}\n"
+ f"Base URL: {self.base_url}\n"
+ f"Num Prompt: {self.num_prompt}\n"
+ f"Tokenizer: {self.tokenizer}\n"
+ f"Request Rate: {self.request_rate}"
+ )
diff --git a/app_tests/integration_tests/__init__.py b/app_tests/integration_tests/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/app_tests/integration_tests/llm/__init__.py b/app_tests/integration_tests/llm/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/app_tests/integration_tests/llm/conftest.py b/app_tests/integration_tests/llm/conftest.py
new file mode 100644
index 000000000..17cdf1def
--- /dev/null
+++ b/app_tests/integration_tests/llm/conftest.py
@@ -0,0 +1,135 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import logging
+import os
+from pathlib import Path
+import pytest
+import shutil
+
+pytest.importorskip("transformers")
+from .utils import (
+ download_huggingface_model,
+ download_tokenizer,
+ export_paged_llm_v1,
+ compile_model,
+ find_available_port,
+ start_llm_server,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture(scope="module")
+def model_test_dir(request, tmp_path_factory):
+ """Prepare model artifacts for starting the LLM server.
+
+ Args:
+ request (FixtureRequest): The following params are accepted:
+ - repo_id (str): The Hugging Face repo ID.
+ - model_file (str): The model file to download.
+ - tokenizer_id (str): The tokenizer ID to download.
+ - settings (dict): The settings for sharktank export.
+ - batch_sizes (list): The batch sizes to use for the model.
+ tmp_path_factory (TempPathFactory): Temp dir to save artifacts to.
+
+ Yields:
+ Tuple[Path, Path]: The paths to the Hugging Face home and the temp dir.
+ """
+ logger.info("Preparing model artifacts...")
+
+ repo_id = request.param["repo_id"]
+ model_file = request.param["model_file"]
+ tokenizer_id = request.param["tokenizer_id"]
+ settings = request.param["settings"]
+ batch_sizes = request.param["batch_sizes"]
+
+ tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test")
+ hf_home = os.environ.get("HF_HOME", None)
+ hf_home = Path(hf_home) if hf_home is not None else tmp_dir
+ try:
+ # Download model if it doesn't exist
+ model_path = hf_home / model_file
+ download_huggingface_model(hf_home, repo_id, model_file)
+
+ # Set up tokenizer if it doesn't exist
+ download_tokenizer(hf_home, tokenizer_id)
+
+ # Export model
+ mlir_path = tmp_dir / "model.mlir"
+ config_path = tmp_dir / "config.json"
+ export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)
+
+ # Compile model
+ vmfb_path = tmp_dir / "model.vmfb"
+ compile_model(mlir_path, vmfb_path, settings)
+
+ # Write config
+ edited_config_path = tmp_dir / "edited_config.json"
+ config = {
+ "module_name": "module",
+ "module_abi_version": 1,
+ "max_seq_len": 2048,
+ "attn_head_count": 32,
+ "attn_head_dim": 100,
+ "prefill_batch_sizes": batch_sizes,
+ "decode_batch_sizes": batch_sizes,
+ "transformer_block_count": 26,
+ "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
+ }
+ logger.info(f"Saving edited config to: {edited_config_path}\n")
+ logger.info(f"Config: {json.dumps(config, indent=2)}")
+ with open(edited_config_path, "w") as f:
+ json.dump(config, f)
+ logger.info("Model artifacts setup successfully")
+ yield hf_home, tmp_dir
+ finally:
+ shutil.rmtree(tmp_dir)
+
+
+@pytest.fixture(scope="module")
+def available_port():
+ return find_available_port()
+
+
+@pytest.fixture(scope="module")
+def llm_server(request, model_test_dir, available_port):
+ """Start the LLM server.
+
+ Args:
+ request (FixtureRequest): The following params are accepted:
+ - model_file (str): The model file to download.
+ - settings (dict): The settings for starting the server.
+ model_test_dir (Tuple[Path, Path]): The paths to the Hugging Face home and the temp dir.
+ available_port (int): The available port to start the server on.
+
+ Yields:
+ subprocess.Popen: The server process that was started.
+ """
+ logger.info("Starting LLM server...")
+ hf_home, tmp_dir = model_test_dir
+ model_file = request.param["model_file"]
+ settings = request.param["settings"]
+
+ tokenizer_path = hf_home / "tokenizer.json"
+ config_path = tmp_dir / "edited_config.json"
+ vmfb_path = tmp_dir / "model.vmfb"
+ parameters_path = hf_home / model_file
+
+ # Start llm server
+ server_process = start_llm_server(
+ available_port,
+ tokenizer_path,
+ config_path,
+ vmfb_path,
+ parameters_path,
+ settings,
+ )
+ yield server_process
+ # Teardown: kill the server
+ server_process.terminate()
+ server_process.wait()
diff --git a/build_tools/integration_tests/llm/cpu_llm_server_test.py b/app_tests/integration_tests/llm/cpu_llm_server_test.py
similarity index 93%
rename from build_tools/integration_tests/llm/cpu_llm_server_test.py
rename to app_tests/integration_tests/llm/cpu_llm_server_test.py
index 638bce7ee..e7d0792d8 100644
--- a/build_tools/integration_tests/llm/cpu_llm_server_test.py
+++ b/app_tests/integration_tests/llm/cpu_llm_server_test.py
@@ -10,7 +10,7 @@
import requests
import uuid
-from utils import AccuracyValidationException
+from .utils import AccuracyValidationException
logger = logging.getLogger(__name__)
@@ -78,7 +78,6 @@ def do_generate(prompt, port):
],
indirect=True,
)
-@pytest.mark.xfail(raises=AccuracyValidationException)
def test_llm_server(llm_server, available_port):
# Here you would typically make requests to your server
# and assert on the responses
@@ -86,7 +85,6 @@ def test_llm_server(llm_server, available_port):
output = do_generate("1 2 3 4 5 ", available_port)
logger.info(output)
expected_output_prefix = "6 7 8"
- # TODO(#437): Remove when accuracy issue from latest iree-compiler RC is resolved.
if not output.startswith(expected_output_prefix):
raise AccuracyValidationException(
f"Expected '{output}' to start with '{expected_output_prefix}'"
diff --git a/app_tests/integration_tests/llm/utils.py b/app_tests/integration_tests/llm/utils.py
new file mode 100644
index 000000000..b8b5ae60f
--- /dev/null
+++ b/app_tests/integration_tests/llm/utils.py
@@ -0,0 +1,180 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import multiprocessing
+import os
+import subprocess
+import sys
+import time
+
+import requests
+from transformers import AutoTokenizer
+
+logger = logging.getLogger("__name__")
+
+
+class AccuracyValidationException(RuntimeError):
+ pass
+
+
+def download_huggingface_model(local_dir, repo_id, model_file):
+ model_path = local_dir / model_file
+ logger.info(f"Preparing model_path: {model_path}..")
+ if not os.path.exists(model_path):
+ logger.info(f"Downloading model {repo_id} {model_file} from Hugging Face...")
+ subprocess.run(
+ f"huggingface-cli download --local-dir {local_dir} {repo_id} {model_file}",
+ shell=True,
+ check=True,
+ )
+ logger.info(f"Model downloaded to {model_path}")
+ else:
+ logger.info("Using cached model")
+
+
+def download_tokenizer(local_dir, tokenizer_id):
+ # Set up tokenizer if it doesn't exist
+ tokenizer_path = local_dir / "tokenizer.json"
+ logger.info(f"Preparing tokenizer_path: {tokenizer_path}...")
+ if not os.path.exists(tokenizer_path):
+ logger.info(f"Downloading tokenizer {tokenizer_id} from Hugging Face...")
+ tokenizer = AutoTokenizer.from_pretrained(
+ tokenizer_id,
+ )
+ tokenizer.save_pretrained(local_dir)
+ logger.info(f"Tokenizer saved to {tokenizer_path}")
+ else:
+ logger.info("Using cached tokenizer")
+
+
+def export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes):
+ bs_string = ",".join(map(str, batch_sizes))
+ logger.info(
+ "Exporting model with following settings:\n"
+ f" MLIR Path: {mlir_path}\n"
+ f" Config Path: {config_path}\n"
+ f" Batch Sizes: {bs_string}"
+ )
+ subprocess.run(
+ [
+ "python",
+ "-m",
+ "sharktank.examples.export_paged_llm_v1",
+ f"--{model_path.suffix.strip('.')}-file={model_path}",
+ f"--output-mlir={mlir_path}",
+ f"--output-config={config_path}",
+ f"--bs={bs_string}",
+ ],
+ check=True,
+ )
+ logger.info(f"Model successfully exported to {mlir_path}")
+
+
+def compile_model(mlir_path, vmfb_path, device_settings):
+ logger.info(f"Compiling model to {vmfb_path}")
+ subprocess.run(
+ [
+ "iree-compile",
+ mlir_path,
+ "-o",
+ vmfb_path,
+ ]
+ + device_settings["device_flags"],
+ check=True,
+ )
+ logger.info(f"Model successfully compiled to {vmfb_path}")
+
+
+def find_available_port():
+ import socket
+ from contextlib import closing
+
+ logger.info(f"Finding available port...")
+ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+ s.bind(("", 0))
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ port = s.getsockname()[1]
+ logger.info(f"Found available port: {port}")
+ return port
+
+
+def wait_for_server(url, timeout=10):
+ logger.info(f"Waiting for server to start at {url}...")
+ start = time.time()
+ while time.time() - start < timeout:
+ try:
+ requests.get(f"{url}/health")
+ logger.info("Server successfully started")
+ return
+ except requests.exceptions.ConnectionError:
+ time.sleep(1)
+ raise TimeoutError(f"Server did not start within {timeout} seconds")
+
+
+def _start_llm_server_args(
+ tokenizer_path,
+ model_config_path,
+ vmfb_path,
+ parameters_path,
+ settings,
+ port,
+):
+ return [
+ sys.executable,
+ "-m",
+ "shortfin_apps.llm.server",
+ f"--tokenizer_json={tokenizer_path}",
+ f"--model_config={model_config_path}",
+ f"--vmfb={vmfb_path}",
+ f"--parameters={parameters_path}",
+ f"--device={settings['device']}",
+ f"--port={port}",
+ ]
+
+
+def start_llm_server(
+ port,
+ tokenizer_path,
+ model_config_path,
+ vmfb_path,
+ parameters_path,
+ settings,
+ timeout=10,
+ multi=False,
+):
+ logger.info("Starting LLM server...")
+ if multi:
+ server_process = multiprocessing.Process(
+ target=subprocess.Popen(
+ _start_llm_server_args(
+ tokenizer_path,
+ model_config_path,
+ vmfb_path,
+ parameters_path,
+ settings,
+ port,
+ ),
+ )
+ )
+ server_process.start()
+
+ else:
+ # Start the server
+ server_process = subprocess.Popen(
+ _start_llm_server_args(
+ tokenizer_path,
+ model_config_path,
+ vmfb_path,
+ parameters_path,
+ settings,
+ port,
+ )
+ )
+ logger.info("Process started... waiting for server")
+ # Wait for server to start
+ wait_for_server(f"http://localhost:{port}", timeout)
+ return server_process
diff --git a/build_tools/integration_tests/llm/conftest.py b/build_tools/integration_tests/llm/conftest.py
deleted file mode 100644
index 9b93a5d96..000000000
--- a/build_tools/integration_tests/llm/conftest.py
+++ /dev/null
@@ -1,206 +0,0 @@
-import json
-import logging
-import os
-from pathlib import Path
-import pytest
-import requests
-import shutil
-import subprocess
-import time
-
-pytest.importorskip("transformers")
-from transformers import AutoTokenizer
-
-logger = logging.getLogger(__name__)
-
-
-@pytest.fixture(scope="module")
-def model_test_dir(request, tmp_path_factory):
- """Prepare model artifacts for starting the LLM server.
-
- Args:
- request (FixtureRequest): The following params are accepted:
- - repo_id (str): The Hugging Face repo ID.
- - model_file (str): The model file to download.
- - tokenizer_id (str): The tokenizer ID to download.
- - settings (dict): The settings for sharktank export.
- - batch_sizes (list): The batch sizes to use for the model.
- tmp_path_factory (TempPathFactory): Temp dir to save artifacts to.
-
- Yields:
- Tuple[Path, Path]: The paths to the Hugging Face home and the temp dir.
- """
- logger.info("Preparing model artifacts...")
-
- repo_id = request.param["repo_id"]
- model_file = request.param["model_file"]
- tokenizer_id = request.param["tokenizer_id"]
- settings = request.param["settings"]
- batch_sizes = request.param["batch_sizes"]
-
- tmp_dir = tmp_path_factory.mktemp("cpu_llm_server_test")
- hf_home = os.environ.get("HF_HOME", None)
- hf_home = Path(hf_home) if hf_home is not None else tmp_dir
- try:
- # Download model if it doesn't exist
- model_path = hf_home / model_file
- logger.info(f"Preparing model_path: {model_path}..")
- if not os.path.exists(model_path):
- logger.info(
- f"Downloading model {repo_id} {model_file} from Hugging Face..."
- )
- subprocess.run(
- f"huggingface-cli download --local-dir {hf_home} {repo_id} {model_file}",
- shell=True,
- check=True,
- )
- logger.info(f"Model downloaded to {model_path}")
- else:
- logger.info("Using cached model")
-
- # Set up tokenizer if it doesn't exist
- tokenizer_path = hf_home / "tokenizer.json"
- logger.info(f"Preparing tokenizer_path: {tokenizer_path}...")
- if not os.path.exists(tokenizer_path):
- logger.info(f"Downloading tokenizer {tokenizer_id} from Hugging Face...")
- tokenizer = AutoTokenizer.from_pretrained(
- tokenizer_id,
- )
- tokenizer.save_pretrained(hf_home)
- logger.info(f"Tokenizer saved to {tokenizer_path}")
- else:
- logger.info("Using cached tokenizer")
-
- # Export model
- mlir_path = tmp_dir / "model.mlir"
- config_path = tmp_dir / "config.json"
- bs_string = ",".join(map(str, batch_sizes))
- logger.info(
- "Exporting model with following settings:\n"
- f" MLIR Path: {mlir_path}\n"
- f" Config Path: {config_path}\n"
- f" Batch Sizes: {bs_string}"
- )
- subprocess.run(
- [
- "python",
- "-m",
- "sharktank.examples.export_paged_llm_v1",
- f"--gguf-file={model_path}",
- f"--output-mlir={mlir_path}",
- f"--output-config={config_path}",
- f"--bs={bs_string}",
- ],
- check=True,
- )
- logger.info(f"Model successfully exported to {mlir_path}")
-
- # Compile model
- vmfb_path = tmp_dir / "model.vmfb"
- logger.info(f"Compiling model to {vmfb_path}")
- subprocess.run(
- [
- "iree-compile",
- mlir_path,
- "-o",
- vmfb_path,
- ]
- + settings["device_flags"],
- check=True,
- )
- logger.info(f"Model successfully compiled to {vmfb_path}")
-
- # Write config if it doesn't exist
- edited_config_path = tmp_dir / "edited_config.json"
- config = {
- "module_name": "module",
- "module_abi_version": 1,
- "max_seq_len": 2048,
- "attn_head_count": 32,
- "attn_head_dim": 100,
- "prefill_batch_sizes": batch_sizes,
- "decode_batch_sizes": batch_sizes,
- "transformer_block_count": 26,
- "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
- }
- logger.info(f"Saving edited config to: {edited_config_path}\n")
- logger.info(f"Config: {json.dumps(config, indent=2)}")
- with open(edited_config_path, "w") as f:
- json.dump(config, f)
- logger.info("Model artifacts setup successfully")
- yield hf_home, tmp_dir
- finally:
- shutil.rmtree(tmp_dir)
-
-
-@pytest.fixture(scope="module")
-def available_port(port=8000, max_port=8100):
- import socket
-
- logger.info(f"Finding available port in range {port}-{max_port}...")
-
- starting_port = port
-
- while port < max_port:
- try:
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(("localhost", port))
- s.close()
- logger.info(f"Found available port: {port}")
- return port
- except socket.error:
- port += 1
-
- raise IOError(f"No available ports found within range {starting_port}-{max_port}")
-
-
-def wait_for_server(url, timeout=10):
- logger.info(f"Waiting for server to start at {url}...")
- start = time.time()
- while time.time() - start < timeout:
- try:
- requests.get(f"{url}/health")
- logger.info("Server successfully started")
- return
- except requests.exceptions.ConnectionError:
- time.sleep(1)
- raise TimeoutError(f"Server did not start within {timeout} seconds")
-
-
-@pytest.fixture(scope="module")
-def llm_server(request, model_test_dir, available_port):
- """Start the LLM server.
-
- Args:
- request (FixtureRequest): The following params are accepted:
- - model_file (str): The model file to download.
- - settings (dict): The settings for starting the server.
- model_test_dir (Tuple[Path, Path]): The paths to the Hugging Face home and the temp dir.
- available_port (int): The available port to start the server on.
-
- Yields:
- subprocess.Popen: The server process that was started.
- """
- logger.info("Starting LLM server...")
- # Start the server
- hf_home, tmp_dir = model_test_dir
- model_file = request.param["model_file"]
- settings = request.param["settings"]
- server_process = subprocess.Popen(
- [
- "python",
- "-m",
- "shortfin_apps.llm.server",
- f"--tokenizer_json={hf_home / 'tokenizer.json'}",
- f"--model_config={tmp_dir / 'edited_config.json'}",
- f"--vmfb={tmp_dir / 'model.vmfb'}",
- f"--parameters={hf_home / model_file}",
- f"--device={settings['device']}",
- ]
- )
- # Wait for server to start
- wait_for_server(f"http://localhost:{available_port}")
- yield server_process
- # Teardown: kill the server
- server_process.terminate()
- server_process.wait()
diff --git a/build_tools/integration_tests/llm/utils.py b/build_tools/integration_tests/llm/utils.py
deleted file mode 100644
index b31a3e416..000000000
--- a/build_tools/integration_tests/llm/utils.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# Copyright 2024 Advanced Micro Devices, Inc.
-#
-# Licensed under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-
-class AccuracyValidationException(RuntimeError):
- pass
diff --git a/build_tools/python_deploy/README.md b/build_tools/python_deploy/README.md
new file mode 100644
index 000000000..d36545a9c
--- /dev/null
+++ b/build_tools/python_deploy/README.md
@@ -0,0 +1,48 @@
+# Python Deployment
+
+These scripts assist with building Python packages and pushing them to
+[PyPI (the Python Package Index)](https://pypi.org/). See also
+
+* The Python Packaging User Guide:
+
+## Overview
+
+See comments in scripts for canonical usage. This page includes additional
+notes.
+
+### Package building
+
+These scripts build packages:
+
+* [`/shark-ai/build_tools/build_linux_package.sh`](/shark-ai/build_tools/build_linux_package.sh)
+* [`/sharktank/build_tools/build_linux_package.sh`](/sharktank/build_tools/build_linux_package.sh)
+* [`/shortfin/build_tools/build_linux_package.sh`](/shortfin/build_tools/build_linux_package.sh)
+
+### Version management
+
+These scripts handle versioning across packages, including considerations like
+major, minor, and patch levels (`X.Y.Z`), as well as suffixes like
+`rc20241107`:
+
+* [`compute_common_version.py`](./compute_common_version.py)
+* [`compute_local_version.py`](./compute_local_version.py)
+* [`promote_whl_from_rc_to_final.py`](./promote_whl_from_rc_to_final.py)
+* [`write_requirements.py`](./write_requirements.py)
+
+### PyPI deployment
+
+These scripts handle promoting nightly releases packages to stable and pushing
+to PyPI:
+
+* [`promote_whl_from_rc_to_final.py`](./promote_whl_from_rc_to_final.py)
+* [`pypi_deploy.sh`](./pypi_deploy.sh)
+
+Both of these scripts expect to have the dependencies from
+[`requirements-pypi-deploy.txt`](./requirements-pypi-deploy.txt) installed.
+This can be easily managed by using a Python virtual environment:
+
+```bash
+python -m venv .venv
+source .venv/bin/activate
+python -m pip install -r ./requirements-pypi-deploy.txt
+```
diff --git a/build_tools/python_deploy/compute_common_version.py b/build_tools/python_deploy/compute_common_version.py
new file mode 100755
index 000000000..ed6f8c708
--- /dev/null
+++ b/build_tools/python_deploy/compute_common_version.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python3
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This scripts grabs the `X.Y.Z[.dev]` version identifier from the
+# 'sharktank' and 'shortfin' version files and computes the version
+# for the meta 'shark-ai' package.
+#
+# Usage:
+# ./compute_common_version.py --stable-release --write-json
+# cat ../../shark-ai/version_local.json
+
+import argparse
+from pathlib import Path
+import json
+from datetime import datetime
+import sys
+
+from packaging.version import Version
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--write-json", action="store_true")
+
+release_type = parser.add_mutually_exclusive_group()
+release_type.add_argument("-stable", "--stable-release", action="store_true") # default
+release_type.add_argument("-rc", "--nightly-release", action="store_true")
+
+
+args = parser.parse_args()
+
+if not (args.stable_release or args.nightly_release):
+ parser.print_usage(sys.stderr)
+ sys.stderr.write("error: A release type is required\n")
+ sys.exit(1)
+
+THIS_DIR = Path(__file__).parent.resolve()
+REPO_ROOT = THIS_DIR.parent.parent
+
+VERSION_FILE_SHARKTANK = REPO_ROOT / "sharktank/version.json"
+VERSION_FILE_SHORTFIN = REPO_ROOT / "shortfin/version.json"
+VERSION_FILE_LOCAL = REPO_ROOT / "shark-ai/version_local.json"
+
+
+def load_version_info(version_file):
+ with open(version_file, "rt") as f:
+ return json.load(f)
+
+
+def write_version_info():
+ with open(VERSION_FILE_LOCAL, "w") as f:
+ json.dump(version_local, f, indent=2)
+ f.write("\n")
+
+
+sharktank_version = load_version_info(VERSION_FILE_SHARKTANK)
+SHARKTANK_PACKAGE_VERSION = sharktank_version.get("package-version")
+SHARKTANK_BASE_VERSION = Version(SHARKTANK_PACKAGE_VERSION).base_version
+
+shortfin_version = load_version_info(VERSION_FILE_SHORTFIN)
+SHORTFIN_PACKAGE_VERSION = shortfin_version.get("package-version")
+SHORTFIN_BASE_VERSION = Version(SHORTFIN_PACKAGE_VERSION).base_version
+
+if SHARKTANK_BASE_VERSION > SHORTFIN_BASE_VERSION:
+ COMMON_VERSION = SHARKTANK_BASE_VERSION
+else:
+ COMMON_VERSION = SHORTFIN_BASE_VERSION
+
+if args.nightly_release:
+ COMMON_VERSION += "rc" + datetime.today().strftime("%Y%m%d")
+
+if args.write_json:
+ version_local = {"package-version": COMMON_VERSION}
+ write_version_info()
+
+print(COMMON_VERSION)
diff --git a/build_tools/gen_version_info_rc.py b/build_tools/python_deploy/compute_local_version.py
old mode 100644
new mode 100755
similarity index 59%
rename from build_tools/gen_version_info_rc.py
rename to build_tools/python_deploy/compute_local_version.py
index 9399053b0..46d18d0ed
--- a/build_tools/gen_version_info_rc.py
+++ b/build_tools/python_deploy/compute_local_version.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
@@ -5,8 +6,8 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# This scripts grabs the X.Y.Z[.dev]` version identifier from a
-# `version_info.json` and writes the corresponding
-# `X.Y.ZrcYYYYMMDD` version identifier to `version_rc_info.json`.
+# `version.json` and writes the corresponding
+# `X.Y.ZrcYYYYMMDD` version identifier to `version_local.json`.
import argparse
from pathlib import Path
@@ -20,18 +21,18 @@
parser.add_argument("path", type=Path)
args = parser.parse_args()
-VERSION_INFO_FILE = args.path / "version_info.json"
-VERSION_INFO_RC_FILE = args.path / "version_info_rc.json"
+VERSION_FILE = args.path / "version.json"
+VERSION_FILE_LOCAL = args.path / "version_local.json"
def load_version_info():
- with open(VERSION_INFO_FILE, "rt") as f:
+ with open(VERSION_FILE, "rt") as f:
return json.load(f)
def write_version_info():
- with open(VERSION_INFO_RC_FILE, "w") as f:
- json.dump(version_info_rc, f, indent=2)
+ with open(VERSION_FILE_LOCAL, "w") as f:
+ json.dump(version_local, f, indent=2)
f.write("\n")
@@ -39,10 +40,12 @@ def write_version_info():
PACKAGE_VERSION = version_info.get("package-version")
PACKAGE_BASE_VERSION = Version(PACKAGE_VERSION).base_version
-PACKAGE_RC_VERSION = PACKAGE_BASE_VERSION + "rc" + datetime.today().strftime("%Y%m%d")
+PACKAGE_LOCAL_VERSION = (
+ PACKAGE_BASE_VERSION + "rc" + datetime.today().strftime("%Y%m%d")
+)
-version_info_rc = {"package-version": PACKAGE_RC_VERSION}
+version_local = {"package-version": PACKAGE_LOCAL_VERSION}
write_version_info()
-print(PACKAGE_RC_VERSION)
+print(PACKAGE_LOCAL_VERSION)
diff --git a/build_tools/python_deploy/promote_whl_from_rc_to_final.py b/build_tools/python_deploy/promote_whl_from_rc_to_final.py
new file mode 100755
index 000000000..061dd933b
--- /dev/null
+++ b/build_tools/python_deploy/promote_whl_from_rc_to_final.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This scripts takes a file like 'sharktank-2.9.0rc20241110-py3-none-any.whl'
+# with embedded version '2.9.0rc20241110' as input and then drops the
+# 'rcYYYYMMDD' suffix from both the embedded version and file name.
+#
+# Typical usage:
+# pip install -r requirements-pypi-deploy.txt
+# ./promote_whl_from_rc_to_final.py /path/to/file.whl --delete-old-wheel
+
+import argparse
+from change_wheel_version import change_wheel_version
+from packaging.version import Version
+from pathlib import Path
+from pkginfo import Wheel
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "input_file",
+ help="Path to the input .whl file to promote",
+ type=Path,
+ )
+ parser.add_argument(
+ "--delete-old-wheel",
+ help="Deletes the original wheel after successfully promoting it",
+ action="store_true",
+ default=False,
+ )
+ return parser.parse_args()
+
+
+def main(args):
+ original_wheel_path = args.input_file
+ print(f"Promoting whl from rc to final: '{original_wheel_path}'")
+
+ original_wheel = Wheel(original_wheel_path)
+ original_version = Version(original_wheel.version)
+ base_version = original_version.base_version
+ print(
+ f" Original wheel version is '{original_version}' with base '{base_version}'"
+ )
+
+ if str(base_version) == str(original_version):
+ print(" Version is already a release version, skipping")
+ return
+
+ print(f" Changing to base version: '{base_version}'")
+ new_wheel_path = change_wheel_version(original_wheel_path, str(base_version), None)
+ print(f" New wheel path is '{new_wheel_path}'")
+
+ new_wheel = Wheel(new_wheel_path)
+ new_version = Version(new_wheel.version)
+ print(f" New wheel version is '{new_version}'")
+
+ if args.delete_old_wheel:
+ print(" Deleting original wheel")
+ original_wheel_path.unlink()
+
+
+if __name__ == "__main__":
+ main(parse_arguments())
diff --git a/build_tools/python_deploy/pypi_deploy.sh b/build_tools/python_deploy/pypi_deploy.sh
new file mode 100755
index 000000000..63f123ac0
--- /dev/null
+++ b/build_tools/python_deploy/pypi_deploy.sh
@@ -0,0 +1,126 @@
+#!/bin/bash
+
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This script promotes Python packages from nightly releases to PyPI.
+#
+# Prerequisites:
+# * You will need to have PyPI credentials set up. See
+# https://packaging.python.org/en/latest/tutorials/packaging-projects/#uploading-the-distribution-archives
+# * Install requirements, e.g. in a Python virtual environment (venv):
+# `pip install -r requirements-pypi-deploy.txt`
+# * Install python3.13t and install pip. On Ubuntu:
+# ```bash
+# sudo add-apt-repository ppa:deadsnakes
+# sudo apt-get update
+# sudo apt-get install python3.13-nogil
+# python3.13t -m ensurepip --upgrade
+# ```
+# * Choose a release candidate to promote from
+# https://github.com/nod-ai/shark-ai/releases/tag/dev-wheels
+#
+# Usage:
+# ./pypi_deploy.sh 2.9.0rc20241108
+
+set -euo pipefail
+
+RELEASE="$1"
+
+SCRIPT_DIR="$(dirname -- "$( readlink -f -- "$0"; )")";
+REPO_ROOT="$(cd "$SCRIPT_DIR"/../../ && pwd)"
+TMPDIR="$(mktemp --directory --tmpdir shark_platform_pypi_wheels.XXXXX)"
+ASSETS_PAGE="https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels"
+
+# TODO: rewrite in Python?
+
+function download_wheels() {
+ echo ""
+ echo "Downloading wheels for '${RELEASE}'..."
+
+ # sharktank
+ python -m pip download sharktank==${RELEASE} \
+ --no-deps --python-version 3.11 -f ${ASSETS_PAGE}
+
+ # shortfin
+ python -m pip download shortfin==${RELEASE} \
+ --no-deps --python-version 3.11 -f ${ASSETS_PAGE}
+ python -m pip download shortfin==${RELEASE} \
+ --no-deps --python-version 3.12 -f ${ASSETS_PAGE}
+ python -m pip download shortfin==${RELEASE} \
+ --no-deps --python-version 3.13 -f ${ASSETS_PAGE}
+ python -m pip download shortfin==${RELEASE} \
+ --no-deps --python-version 3.13 -f ${ASSETS_PAGE}
+ # TODO: fetch 3.13t using the same `python` somehow
+ # * https://pip.pypa.io/en/stable/cli/pip_download/
+ # * https://py-free-threading.github.io/installing_cpython/
+ # * https://pip.pypa.io/en/stable/installation/
+ python3.13t -m pip download shortfin==${RELEASE} --no-deps -f ${ASSETS_PAGE}
+
+ # TODO: shark-ai meta package when it is published to nightlies
+
+ echo ""
+ echo "Downloaded wheels:"
+ ls
+}
+
+function edit_release_versions() {
+ echo ""
+ echo "Editing release versions..."
+ for file in *
+ do
+ ${SCRIPT_DIR}/promote_whl_from_rc_to_final.py ${file} --delete-old-wheel
+ done
+
+ echo "Edited wheels:"
+ ls
+}
+
+function upload_wheels() {
+ # TODO: list packages that would be uploaded, pause, prompt to continue
+ echo ""
+ echo "Uploading wheels:"
+ ls
+ twine upload --verbose *
+}
+
+function build_shark_ai_meta_package() {
+ # TODO: download meta package from nightly releases instead of this
+ # Be aware that nightly releases pin other dependencies via the
+ # generated `requirements.txt` compared to stable releases.
+ echo ""
+
+ # TODO: rework `write_requirements.py` to use the versions from the downloaded whls?
+ echo "Computing local versions for sharktank and shortfin..."
+ ${SCRIPT_DIR}/compute_local_version.py ${REPO_ROOT}/sharktank
+ ${SCRIPT_DIR}/compute_local_version.py ${REPO_ROOT}/shortfin
+
+ echo "Computing common version for shark-ai meta package..."
+ ${SCRIPT_DIR}/compute_common_version.py --stable-release --write-json
+
+ echo "Writing requirements for shark-ai meta package..."
+ ${SCRIPT_DIR}/write_requirements.py
+
+ echo "Building shark-ai meta package..."
+ ${REPO_ROOT}/shark-ai/build_tools/build_linux_package.sh
+
+ # TODO: This is error-prone. We only want to publish the whl for this release.
+ # Copy instead? Specify exact file name? Clear directory before building?
+ mv ${REPO_ROOT}/shark-ai/build_tools/wheelhouse/* .
+}
+
+function main() {
+ echo "Changing into ${TMPDIR}"
+ cd "${TMPDIR}"
+ # TODO: check_requirements (using pip)
+
+ download_wheels
+ edit_release_versions
+ build_shark_ai_meta_package
+ upload_wheels
+}
+
+main
diff --git a/build_tools/python_deploy/requirements-pypi-deploy.txt b/build_tools/python_deploy/requirements-pypi-deploy.txt
new file mode 100644
index 000000000..dcc32d47a
--- /dev/null
+++ b/build_tools/python_deploy/requirements-pypi-deploy.txt
@@ -0,0 +1,4 @@
+change_wheel_version
+packaging
+pkginfo
+twine
diff --git a/build_tools/python_deploy/write_requirements.py b/build_tools/python_deploy/write_requirements.py
new file mode 100755
index 000000000..38ae5d2b3
--- /dev/null
+++ b/build_tools/python_deploy/write_requirements.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This script writes the `packaging/shark-ai/requirements.txt` file and pins
+# the versions of the dependencies accordingly. For nighly releases,
+# * sharktank
+# * shortfin
+# get pinned to the corresponding nighly version. For stable releases,
+# * iree-base-compiler
+# * iree-base-runtime
+# * iree-turbine
+# * sharktank
+# * shortfin
+# get pinned to the corresponding `X.Y.*` version.
+
+import argparse
+from pathlib import Path
+import json
+
+from packaging.version import Version
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--version-suffix", action="store", type=str)
+
+args = parser.parse_args()
+
+
+THIS_DIR = Path(__file__).parent.resolve()
+REPO_ROOT = THIS_DIR.parent.parent
+
+VERSION_FILE_SHARKTANK = REPO_ROOT / "sharktank/version_local.json"
+VERSION_FILE_SHORTFIN = REPO_ROOT / "shortfin/version_local.json"
+VERSION_FILE_LOCAL = REPO_ROOT / "shark-ai/version_local.json"
+REQUIREMENTS_TXT = REPO_ROOT / "shark-ai/requirements.txt"
+
+
+def load_version_info(version_file):
+ with open(version_file, "rt") as f:
+ return json.load(f)
+
+
+def write_requirements(requirements):
+ with open(REQUIREMENTS_TXT, "w") as f:
+ f.write("%s\n" % requirements)
+
+
+metapackage_version = load_version_info(VERSION_FILE_LOCAL)
+PACKAGE_VERSION = metapackage_version.get("package-version")
+
+sharktank_version = load_version_info(VERSION_FILE_SHARKTANK)
+SHARKTANK_PACKAGE_VERSION = sharktank_version.get("package-version")
+
+shortfin_version = load_version_info(VERSION_FILE_SHORTFIN)
+SHORTFIN_PACKAGE_VERSION = shortfin_version.get("package-version")
+
+stable_packages_list = ["iree-base-compiler", "iree-base-runtime", "iree-turbine"]
+
+if Version(PACKAGE_VERSION).is_prerelease:
+ # TODO: Include sharktank as a dependencies of future releases
+ # requirements = (
+ # "sharktank=="
+ # + Version(SHARKTANK_PACKAGE_VERSION).base_version
+ # + "rc"
+ # + args.version_suffix
+ # + "\n"
+ # )
+ requirements += (
+ "shortfin=="
+ + Version(SHORTFIN_PACKAGE_VERSION).base_version
+ + "rc"
+ + args.version_suffix
+ )
+
+ write_requirements(requirements)
+
+else:
+ MAJOR_VERSION = Version(PACKAGE_VERSION).major
+ MINOR_VERSION = Version(PACKAGE_VERSION).minor
+
+ STABLE_VERSION_TO_PIN = str(MAJOR_VERSION) + "." + str(MINOR_VERSION) + ".*"
+
+ requirements = ""
+ for package in stable_packages_list:
+ requirements += package + "==" + STABLE_VERSION_TO_PIN + "\n"
+ # TODO: Include sharktank as a dependencies of future releases
+ # requirements += (
+ # "sharktank==" + Version(SHARKTANK_PACKAGE_VERSION).base_version + "\n"
+ # )
+ requirements += "shortfin==" + Version(SHORTFIN_PACKAGE_VERSION).base_version
+
+ write_requirements(requirements)
diff --git a/docs/developer_guide.md b/docs/developer_guide.md
new file mode 100644
index 000000000..832466688
--- /dev/null
+++ b/docs/developer_guide.md
@@ -0,0 +1,65 @@
+# SHARK Developer Guide
+
+Each sub-project has its own developer guide. If you would like to work across
+projects, these instructions should help you get started:
+
+### Setup a venv
+
+We recommend setting up a Python
+[virtual environment (venv)](https://docs.python.org/3/library/venv.html).
+The project is configured to ignore `.venv` directories, and editors like
+VSCode pick them up by default.
+
+```bash
+python -m venv .venv
+source .venv/bin/activate
+```
+
+### Install PyTorch for your system
+
+If no explicit action is taken, the default PyTorch version will be installed.
+This will give you a current CUDA-based version, which takes longer to download
+and includes other dependencies that SHARK does not require. To install a
+different variant, run one of these commands first:
+
+* *CPU:*
+
+ ```bash
+ pip install -r pytorch-cpu-requirements.txt
+ ```
+
+* *ROCM:*
+
+ ```bash
+ pip install -r pytorch-rocm-requirements.txt
+ ```
+
+* *Other:* see instructions at .
+
+### Install development packages
+
+```bash
+# Install editable local projects.
+pip install -r requirements.txt -e sharktank/ shortfin/
+
+# Optionally clone and install the latest editable iree-turbine dep in deps/,
+# along with nightly versions of iree-base-compiler and iree-base-runtime.
+pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+ iree-base-compiler iree-base-runtime --src deps \
+ -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+```
+
+See also: [nightly_releases.md](nightly_releases.md).
+
+### Running tests
+
+```bash
+pytest sharktank
+pytest shortfin
+```
+
+### Optional: pre-commits and developer settings
+
+This project is set up to use the `pre-commit` tooling. To install it in
+your local repo, run: `pre-commit install`. After this point, when making
+commits locally, hooks will run. See https://pre-commit.com/
diff --git a/docs/model_cookbook.md b/docs/model_cookbook.md
index fdf4c7ede..ddc0cb3bb 100644
--- a/docs/model_cookbook.md
+++ b/docs/model_cookbook.md
@@ -1,6 +1,6 @@
# Model cookbook
-Note: These are early notes and commands that the SHARK-Platform team is using
+Note: These are early notes and commands that the shark-ai team is using
and will turn into proper docs later.
## Diagrams
diff --git a/docs/nightly_releases.md b/docs/nightly_releases.md
index 706d6e755..545cdd4f5 100644
--- a/docs/nightly_releases.md
+++ b/docs/nightly_releases.md
@@ -2,19 +2,19 @@
> [!WARNING]
> This is still under development! See
-> https://github.com/nod-ai/SHARK-Platform/issues/400.
+> https://github.com/nod-ai/shark-ai/issues/400.
>
> These instructions will be converted into a user guide once stable packages
-> are published to PyPI: .
+> are published to PyPI: .
Nightly releases are uploaded to
-https://github.com/nod-ai/SHARK-Platform/releases/tag/dev-wheels.
+https://github.com/nod-ai/shark-ai/releases/tag/dev-wheels.
* The "expanded_assets" version of a release page is compatible with the
`-f, --find-links ` options of `pip install`
([docs here](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-f)).
For the "dev-wheels" release above, that page is:
-
+
* These releases are generated using
[`.github/workflows/build_package.yml`](../.github/workflows/build_packages.yml)
* That workflow runs the
@@ -23,7 +23,7 @@ https://github.com/nod-ai/SHARK-Platform/releases/tag/dev-wheels.
[`shortfin/build_tools/build_linux_package.sh`](../shortfin/build_tools/build_linux_package.sh)
scripts
* Workflow history can be viewed at
-
+
## Prerequisites
@@ -38,7 +38,7 @@ source builds.
You will need a recent version of Python.
* As of Nov 1, 2024, sharktank is compatible with Python 3.11. See
- https://github.com/nod-ai/SHARK-Platform/issues/349 for Python 3.12 support.
+ https://github.com/nod-ai/shark-ai/issues/349 for Python 3.12 support.
* As of Nov 4, 2024, shortfin publishes packages for Python 3.11, 3.12, 3.13,
and 3.13t
@@ -67,7 +67,7 @@ python3.11 -m venv 3.11.venv
source 3.11.venv/bin/activate
# Install 'sharktank' package from nightly releases.
-python -m pip install sharktank -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels
+pip install sharktank -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels
# Test the installation.
python -c "from sharktank import ops; print('Sanity check passed')"
@@ -84,7 +84,7 @@ python3.11 -m venv 3.11.venv
source 3.11.venv/bin/activate
# Install 'shortfin' package from nightly releases.
-python -m pip install shortfin -f https://github.com/nod-ai/SHARK-Platform/releases/expanded_assets/dev-wheels
+pip install shortfin -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels
# Test the installation.
python -c "import shortfin as sf; print('Sanity check passed')"
@@ -98,7 +98,7 @@ deactivate
To install the `iree-turbine` package from the latest source:
```bash
-python -m pip install --src deps \
+pip install --src deps \
-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
```
@@ -106,14 +106,14 @@ To install the `iree-base-compiler` and `iree-base-runtime` packages from
nightly releases:
```bash
-python -m pip install -f https://iree.dev/pip-release-links.html --upgrade \
+pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
iree-base-compiler iree-base-runtime
```
To install all three packages together:
```bash
-python -m pip install -f https://iree.dev/pip-release-links.html --upgrade \
+pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
iree-base-compiler iree-base-runtime --src deps \
-e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
```
diff --git a/docs/quantization.md b/docs/quantization.md
index fcc8961b0..25bfc9f8d 100644
--- a/docs/quantization.md
+++ b/docs/quantization.md
@@ -64,11 +64,11 @@ amount of Python code implementing direct math and packing schemes.
PyTorch modules like `Linear` and `Conv2D`.
2. Types/Ops: The `nn.Module` implementations we provide are built in terms
of SHARK Tank custom
- [`InferenceTensor`](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/types/tensors.py#L153)
- and [polymorphic functional ops library](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/ops/signatures.py).
+ [`InferenceTensor`](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/types/tensors.py#L153)
+ and [polymorphic functional ops library](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/signatures.py).
3. Op specializations for optimized subsets of op type signatures and features
(for example, [an optimized affine quantized linear specialization for
- supported combinations of `TensorScaledLayout` arguments](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/ops/qlinear_impls.py)).
+ supported combinations of `TensorScaledLayout` arguments](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/qlinear_impls.py)).
(TODO: good place for a diagram)
@@ -78,18 +78,18 @@ amount of Python code implementing direct math and packing schemes.
Available modules that support direct quantization (TODO: refactor to use
torch "Module" terminology and naming schemes consistently):
-* [`LinearLayer`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/layers/linear.py)
-* [convolution layers](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/layers/conv.py)
+* [`LinearLayer`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/layers/linear.py)
+* [convolution layers](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/layers/conv.py)
Note that most sharktank modules extend
-[`ThetaLayer`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/layers/base.py#L63),
+[`ThetaLayer`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/layers/base.py#L63),
which calls for a bit of explanation. Traditional PyTorch Modules directly
instantiate their backing parameters in their constructor. For dataset-heavy
and polymorphic implementations like we commonly see in quantization and
distribution, however, it can be beneficial to separate these concerns.
The `ThetaLayer` simply takes a
-[`Theta` object](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/theta.py#L74),
+[`Theta` object](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/theta.py#L74),
which is a tree-structured bag of native `torch.Tensor` or `InferenceTensor`
instances, and it adopts the tensors in the bag as its own vs creating them.
For those familiar with the concept, this is a form of dependency-injection
@@ -114,7 +114,7 @@ tree to a specific Module instance.
We've already met the `Theta` object above, which holds a tree of something
called an
-[`InferenceTensor`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/tensors.py#L153).
+[`InferenceTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L153).
Now we describe what this is. Note that presently, `InferenceTensor` is not a
`torch.Tensor` but its own `ABC` type that:
@@ -140,11 +140,11 @@ pipelines.
There is a growing list of `InferenceTensor` sub-types, many of which are
related to quantization:
-* [`PrimitiveTensor`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/tensors.py#L286):
+* [`PrimitiveTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L286):
A simple composition of a single `torch.Tensor`. This is often used
interchangeably with a `torch.Tensor` but is present for completeness of
the type hierarchy and to be able to type select on.
-* [`QuantizedTensor`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/tensors.py#L372):
+* [`QuantizedTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L372):
Abstract base class of all quantized tensors, providing two primary operations:
* `unpack`: Accesses the backing `QuantizedLayout` of the tensor, which is
@@ -154,12 +154,12 @@ related to quantization:
layout, this explodes it into a canonical representation of individual
tensors which can be algebraically implemented individually/generically).
-* [`PlanarQuantizedTensor`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408):
+* [`PlanarQuantizedTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408):
Concrete implementation for all non-packed quantized tensors that can be
losslessly represented by a layout based on individual tensor components.
All `QuantizedTensor` instances can be converted to a `PlanarQuantizedTensor`.
-* [`QuantizerTensor`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408):
+* [`QuantizerTensor`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L408):
(note the "r" in the name) An abstract `InferenceTensor` that exposes a
`quantize(torch.Tensor | InferenceTensor) -> QuantizedTensor` operation used
to transform an arbitrary tensor to a quantized form. There are a handful
@@ -178,7 +178,7 @@ manipulate tensor contents via `QuantizedLayout`, but we haven't yet defined
that. The *Tensor types are structural and exist to give identity, but the
`QuantizedLayout` is where the "magic happens".
-[`QuantizedLayout`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/tensors.py#L44)
+[`QuantizedLayout`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/tensors.py#L44)
is an `ABC`, supporting:
* Serialization/interop with parameter archives.
@@ -193,7 +193,7 @@ is an `ABC`, supporting:
There are a number of implementations, as every quantization scheme typically
needs at least one concrete `QuantizedLayout`. Simple schemes like affine
quantization can be fully defined in terms of a single
-[`TensorScaledLayout`](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/types/layouts.py#L43).
+[`TensorScaledLayout`](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/types/layouts.py#L43).
Whereas packed schemes like we find in inference engines like GGML and XNNPACK
optimally require both a packed layout and a planar layout.
@@ -224,7 +224,7 @@ interpreting/transforming using their natively defined forms.
Previously, we found a rich type system defining all manner of layouts and
quantization schemes, but what can be done with it? That is where the
sharktank functional op library comes in. These
-[logical ops](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/ops/signatures.py)
+[logical ops](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/signatures.py)
provide the building blocks to implement built-in and custom `nn.Module`
implementations operating on `InferenceTensor` (and torch.Tensor) types.
@@ -239,12 +239,12 @@ implementation at any needed level of granularity:
structures and preserve it when computing (when combined with a
fusing compiler, this alone provides decent fallback implementations for a
variety of "weight compression" oriented techniques). See
- [some examples](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/ops/custom_impls.py#L51).
+ [some examples](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/custom_impls.py#L51).
* Pure-Torch decompositions for algebraic techniques like affine quantization
(when combined with a fusing compiler, this alone is sufficient for
optimization). See
- [qlinear](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/ops/qlinear_impls.py) and
- [qconv](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/ops/qconv_impls.py)
+ [qlinear](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/qlinear_impls.py) and
+ [qconv](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/ops/qconv_impls.py)
implementations of actual affine quantized decompositions.
* Completely custom packed/optimized implementation. These can be written to
activate on any level of detail of the type hierarchy. The implementation
@@ -280,8 +280,8 @@ level. Some examples:
[tensor trace/print](https://github.com/iree-org/iree-turbine/blob/main/iree.turbine/ops/iree.py#L52)
* [Simple linalg based template expansion](https://github.com/iree-org/iree-turbine/blob/main/iree.turbine/ops/_jinja_test_ops.py#L28)
(see backing example [jinja template](https://github.com/iree-org/iree-turbine/blob/main/iree.turbine/ops/templates/test_add_jinja.mlir)).
-* Optimal linalg-based [8-bit block scaled mmt for weight compression](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/kernels/mmt_block_scaled_q8.py)
- (see backing [jinja template](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/kernels/templates/mmt_block_scaled_q8_3d.mlir)).
+* Optimal linalg-based [8-bit block scaled mmt for weight compression](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/mmt_block_scaled_q8.py)
+ (see backing [jinja template](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/templates/mmt_block_scaled_q8_3d.mlir)).
* DSL based [like this fused attention kernel](https://github.com/iree-org/iree-turbine/blob/main/tests/kernel/fused_attention_test.py#L20)
(note that in this case, the DSL exports to the unerlying IR-based registration
mechanism used in the previous examples).
@@ -292,8 +292,8 @@ Since all of these types of custom kernels are just defined with simple Python
tooling, they are really fast to iterate on. The linalg based kernels specifically
tend to be highly portable, and we don't hesitate to write one of those when
we need something specific that PyTorch doesn't provide out of the box
-(i.e. [proper mixed-precision integer conv](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py)
-([template](https://github.com/nod-ai/SHARK-Platform/blob/main/sharktank/sharktank/kernels/templates/conv_2d_nchw_fchw.mlir))).
+(i.e. [proper mixed-precision integer conv](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/conv_2d_nchw_fchw.py)
+([template](https://github.com/nod-ai/shark-ai/blob/main/sharktank/sharktank/kernels/templates/conv_2d_nchw_fchw.mlir))).
## Dataset transformation
@@ -307,7 +307,7 @@ We take a practical approach to this, writing implementation specific converters
where needed, and taking advantage of industry-standard consolidation points
where available (like GGUF) in order to cover a wider surface area.
-Behind both is the notion of a [`Dataset`](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/types/theta.py#L263),
+Behind both is the notion of a [`Dataset`](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/types/theta.py#L263),
which combines some set of hyper-parameters with a root `Theta` object
(typically representing the layer-tree of frozen tensors). Datasets can be
losslessly persisted to IREE IRPA files, which can then be loaded by either
@@ -321,9 +321,9 @@ transform, shard, etc.
See some examples:
-* [models/punet/tools/import_hf_dataset.py](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_hf_dataset.py) :
+* [models/punet/tools/import_hf_dataset.py](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_hf_dataset.py) :
Creating a `Dataset` object from an HF diffusers safetensors file and config.json.
-* [models/punet/tools/import_brevitas_dataset.py](https://github.com/nod-ai/SHARK-Platform/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py) :
+* [models/punet/tools/import_brevitas_dataset.py](https://github.com/nod-ai/shark-ai/blob/quant_docs/sharktank/sharktank/models/punet/tools/import_brevitas_dataset.py) :
Creates a quantized `Dataset` by combining:
* HF diffusers `config.json`
diff --git a/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md b/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md
new file mode 100644
index 000000000..1ce2d1e8d
--- /dev/null
+++ b/docs/shortfin/llm/developer/e2e_llama8b_mi300x.md
@@ -0,0 +1,242 @@
+# LLama 8b GPU Instructions on MI300X
+
+**NOTE: This was ran on the `mi300x-3` system**
+
+## Setup
+
+We will use an example with `llama_8b_f16_decomposed` in order to describe the
+process of exporting a model for use in the shortfin llm server with an MI300 GPU.
+
+### Pre-Requisites
+
+- Python >= 3.11 is recommended for this flow
+ - You can check out [pyenv](https://github.com/pyenv/pyenv) as a good tool
+ to be able to manage multiple versions of python on the same system.
+
+### Setting Up Environment
+
+Follow the `Development Getting Started` docs
+[here](https://github.com/nod-ai/shark-ai/blob/main/README.md#development-getting-started)
+to setup your environment for development.
+
+We will use an example with `llama_8b_f16_decomposed` in order to describe the
+process of exporting a model for use in the shortfin llm server with an MI300 GPU.
+
+### Define a directory for export files
+
+Create a new directory for us to export files like `model.mlir`, `model.vmfb`, etc.
+
+```bash
+mkdir $PWD/export
+export EXPORT_DIR=$PWD/exportd
+```
+
+### Define environment variables
+
+Define the following environment variables to make running this example a bit easier:
+
+#### Model/Tokenizer vars
+
+This example uses the `llama8b_f16.irpa` and `tokenizer.json` files that are
+pre-existing on the MI300X-3 system.
+You may need to change the paths for your own system.
+
+```bash
+export MODEL_PARAMS_PATH=/data/llama3.1/8b/llama8b_f16.irpa # Path to existing .irpa file, may need to change w/ system
+export TOKENIZER_PATH=/data/llama3.1/8b/tokenizer.json # Path to existing tokenizer.json, may need to change w/ system
+```
+
+#### General env vars
+
+The following env vars can be copy + pasted directly:
+
+```bash
+export MLIR_PATH=$EXPORT_DIR/model.mlir # Path to export model.mlir file
+export OUTPUT_CONFIG_PATH=$EXPORT_DIR/config.json # Path to export config.json file
+export EDITED_CONFIG_PATH=$EXPORT_DIR/edited_config.json # Path to export config.json file
+export VMFB_PATH=$EXPORT_DIR/model.vmfb # Path to export model.vmfb file
+export BS=1,4 # Batch size for kvcache
+export ROCR_VISIBLE_DEVICES=1 # NOTE: This is temporary, until multi-device is fixed
+```
+
+### Export to MLIR
+
+We will now use the `sharktank.examples.export_paged_llm_v1` script to export
+our model to `.mlir` format.
+
+```bash
+python -m sharktank.examples.export_paged_llm_v1 \
+ --irpa-file=$MODEL_PARAMS_PATH \
+ --output-mlir=$MLIR_PATH \
+ --output-config=$OUTPUT_CONFIG_PATH \
+ --bs=$BS
+```
+
+## Compiling to `.vmfb`
+
+Now that we have generated a `model.mlir` file, we can compile it to `.vmfb`
+format, which is required for running the `shortfin` LLM server.
+
+We will use the [iree-compile](https://iree.dev/developers/general/developer-overview/#iree-compile)
+tool for compiling our model.
+
+### Compile for MI300
+
+**NOTE: This command is specific to MI300 GPUs.
+For other `--iree-hip-target` GPU options,
+look [here](https://iree.dev/guides/deployment-configurations/gpu-rocm/#compile-a-program)**
+
+```bash
+iree-compile $MLIR_PATH \
+ --iree-hal-target-backends=rocm \
+ --iree-hip-target=gfx942 \
+ -o $VMFB_PATH
+```
+
+## Write an edited config
+
+We need to write a config for our model with a slightly edited structure
+to run with shortfin. This will work for the example in our docs.
+You may need to modify some of the parameters for a specific model.
+
+### Write edited config
+
+```bash
+cat > $EDITED_CONFIG_PATH << EOF
+{
+ "module_name": "module",
+ "module_abi_version": 1,
+ "max_seq_len": 131072,
+ "attn_head_count": 8,
+ "attn_head_dim": 128,
+ "prefill_batch_sizes": [
+ $BS
+ ],
+ "decode_batch_sizes": [
+ $BS
+ ],
+ "transformer_block_count": 32,
+ "paged_kv_cache": {
+ "block_seq_stride": 16,
+ "device_block_count": 256
+ }
+}
+EOF
+```
+
+## Running the `shortfin` LLM server
+
+We should now have all of the files that we need to run the shortfin LLM server.
+
+Verify that you have the following in your specified directory ($EXPORT_DIR):
+
+```bash
+ls $EXPORT_DIR
+```
+
+- edited_config.json
+- model.vmfb
+
+### Launch server:
+
+#### Set the target device
+
+
+
+#### Run the shortfin server
+
+Run the following command to launch the Shortfin LLM Server in the background:
+
+> **Note**
+> By default, our server will start at `http://localhost:8000`.
+> You can specify the `--host` and/or `--port` arguments, to run at a different address.
+>
+> If you receive an error similar to the following:
+>
+> `[errno 98] address already in use`
+>
+> Then, you can confirm the port is in use with `ss -ntl | grep 8000`
+> and either kill the process running at that port,
+> or start the shortfin server at a different port.
+
+```bash
+python -m shortfin_apps.llm.server \
+ --tokenizer_json=$TOKENIZER_PATH \
+ --model_config=$EDITED_CONFIG_PATH \
+ --vmfb=$VMFB_PATH \
+ --parameters=$MODEL_PARAMS_PATH \
+ --device=hip > shortfin_llm_server.log 2>&1 &
+shortfin_process=$!
+```
+
+You can verify your command has launched successfully when you see the following
+ logs outputted to terminal:
+
+```bash
+cat shortfin_llm_server.log
+```
+
+#### Expected output
+
+```text
+[2024-10-24 15:40:27.440] [info] [on.py:62] Application startup complete.
+[2024-10-24 15:40:27.444] [info] [server.py:214] Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
+```
+
+## Verify server
+
+### Client script
+
+We can test the LLM server, by running our client script:
+
+```bash
+python shortfin/python/shortfin_apps/llm/client.py --port 8000
+```
+
+### Simple request
+
+Or by sending a simple request:
+
+### Open python shell
+
+```bash
+python
+```
+
+### Send request
+
+```python
+import requests
+
+import os
+
+port = 8000 # Change if running at a different port
+
+generate_url = f"http://localhost:{port}/generate"
+
+def generation_request():
+ payload = {"text": "What is the capital of the United States?", "sampling_params": {"max_completion_tokens": 50}}
+ try:
+ resp = requests.post(generate_url, json=payload)
+ resp.raise_for_status() # Raises an HTTPError for bad responses
+ print(resp.text)
+ except requests.exceptions.RequestException as e:
+ print(f"An error occurred: {e}")
+
+generation_request()
+```
+
+After you receive the request, you can exit the python shell:
+
+```bash
+quit()
+```
+
+## Cleanup
+
+When done, you can kill the shortfin_llm_server by killing the process:
+
+```bash
+kill -9 $shortfin_process
+```
diff --git a/docs/shortfin/llm/user/e2e_llama8b_mi300x.md b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md
new file mode 100644
index 000000000..5e0749546
--- /dev/null
+++ b/docs/shortfin/llm/user/e2e_llama8b_mi300x.md
@@ -0,0 +1,278 @@
+# LLama 8b GPU instructions on MI300X
+
+## Setup
+
+We will use an example with `llama_8b_f16` in order to describe the
+process of exporting a model for use in the shortfin llm server with an
+MI300 GPU.
+
+### Pre-Requisites
+
+- Python >= 3.11 is recommended for this flow
+ - You can check out [pyenv](https://github.com/pyenv/pyenv)
+ as a good tool to be able to manage multiple versions of python
+ on the same system.
+
+### Create virtual environment
+
+To start, create a new virtual environment:
+
+```bash
+python -m venv --prompt shark-ai .venv
+source .venv/bin/activate
+```
+
+### Install `shark-ai`
+
+You can install either the `latest stable` version of `shark-ai`
+or the `nightly` version:
+
+#### Stable
+
+```bash
+pip install shark-ai
+```
+
+#### Nightly
+
+```bash
+pip install sharktank -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels
+pip install shortfin -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels
+```
+
+#### Install dataclasses-json
+
+
+
+```bash
+pip install dataclasses-json
+```
+
+### Define a directory for export files
+
+Create a new directory for us to export files like
+`model.mlir`, `model.vmfb`, etc.
+
+```bash
+mkdir $PWD/export
+export EXPORT_DIR=$PWD/export
+```
+
+### Download llama3_8b_fp16.gguf
+
+We will use the `hf_datasets` module in `sharktank` to download a
+LLama3.1 8b f16 model.
+
+```bash
+python -m sharktank.utils.hf_datasets amd-shark/llama3.1-8B --local-dir $EXPORT_DIR
+```
+
+### Define environment variables
+
+Define the following environment variables to make running
+this example a bit easier:
+
+#### Model/Tokenizer vars
+
+This example uses the `llama8b_f16.gguf` and `tokenizer.json` files
+that were downloaded in the previous step.
+
+```bash
+export MODEL_PARAMS_PATH=$EXPORT_DIR/llama3.1-8b/llama8b_f16.gguf
+export TOKENIZER_PATH=$EXPORT_DIR/llama3.1-8b/tokenizer.json
+```
+
+#### General env vars
+
+The following env vars can be copy + pasted directly:
+
+```bash
+# Path to export model.mlir file
+export MLIR_PATH=$EXPORT_DIR/model.mlir
+# Path to export config.json file
+export OUTPUT_CONFIG_PATH=$EXPORT_DIR/config.json
+# Path to export edited_config.json file
+export EDITED_CONFIG_PATH=$EXPORT_DIR/edited_config.json
+# Path to export model.vmfb file
+export VMFB_PATH=$EXPORT_DIR/model.vmfb
+# Batch size for kvcache
+export BS=1,4
+# NOTE: This is temporary, until multi-device is fixed
+export ROCR_VISIBLE_DEVICES=1
+```
+
+## Export to MLIR
+
+We will now use the `sharktank.examples.export_paged_llm_v1` script
+to export our model to `.mlir` format.
+
+```bash
+python -m sharktank.examples.export_paged_llm_v1 \
+ --irpa-file=$MODEL_PARAMS_PATH \
+ --output-mlir=$MLIR_PATH \
+ --output-config=$OUTPUT_CONFIG_PATH \
+ --bs=$BS
+```
+
+## Compiling to `.vmfb`
+
+Now that we have generated a `model.mlir` file,
+we can compile it to `.vmfb` format, which is required for running
+the `shortfin` LLM server.
+
+We will use the
+[iree-compile](https://iree.dev/developers/general/developer-overview/#iree-compile)
+tool for compiling our model.
+
+### Compile for MI300
+
+**NOTE: This command is specific to MI300 GPUs.
+For other `--iree-hip-target` GPU options,
+look [here](https://iree.dev/guides/deployment-configurations/gpu-rocm/#compile-a-program)**
+
+```bash
+iree-compile $MLIR_PATH \
+ --iree-hal-target-backends=rocm \
+ --iree-hip-target=gfx942 \
+ -o $VMFB_PATH
+```
+
+## Write an edited config
+
+We need to write a config for our model with a slightly edited structure
+to run with shortfin. This will work for the example in our docs.
+You may need to modify some of the parameters for a specific model.
+
+### Write edited config
+
+```bash
+cat > $EDITED_CONFIG_PATH << EOF
+{
+ "module_name": "module",
+ "module_abi_version": 1,
+ "max_seq_len": 131072,
+ "attn_head_count": 8,
+ "attn_head_dim": 128,
+ "prefill_batch_sizes": [
+ $BS
+ ],
+ "decode_batch_sizes": [
+ $BS
+ ],
+ "transformer_block_count": 32,
+ "paged_kv_cache": {
+ "block_seq_stride": 16,
+ "device_block_count": 256
+ }
+}
+EOF
+```
+
+## Running the `shortfin` LLM server
+
+We should now have all of the files that we need to run the shortfin LLM server.
+
+Verify that you have the following in your specified directory ($EXPORT_DIR):
+
+```bash
+ls $EXPORT_DIR
+```
+
+- edited_config.json
+- model.vmfb
+
+### Launch server:
+
+
+
+#### Run the shortfin server
+
+Now that we are finished with setup, we can start the Shortfin LLM Server.
+
+Run the following command to launch the Shortfin LLM Server in the background:
+
+> **Note**
+> By default, our server will start at `http://localhost:8000`.
+> You can specify the `--host` and/or `--port` arguments, to run at a different address.
+>
+> If you receive an error similar to the following:
+>
+> `[errno 98] address already in use`
+>
+> Then, you can confirm the port is in use with `ss -ntl | grep 8000`
+> and either kill the process running at that port,
+> or start the shortfin server at a different port.
+
+```bash
+python -m shortfin_apps.llm.server \
+ --tokenizer_json=$TOKENIZER_PATH \
+ --model_config=$EDITED_CONFIG_PATH \
+ --vmfb=$VMFB_PATH \
+ --parameters=$MODEL_PARAMS_PATH \
+ --device=hip > shortfin_llm_server.log 2>&1 &
+shortfin_process=$!
+```
+
+You can verify your command has launched successfully
+when you see the following logs outputted to terminal:
+
+```bash
+cat shortfin_llm_server.log
+```
+
+#### Expected output
+
+```text
+[2024-10-24 15:40:27.440] [info] [on.py:62] Application startup complete.
+[2024-10-24 15:40:27.444] [info] [server.py:214] Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
+```
+
+## Verify server
+
+We can now verify our LLM server by sending a simple request:
+
+### Open python shell
+
+```bash
+python
+```
+
+### Send request
+
+```python
+import requests
+
+import os
+
+port = 8000 # Change if running on a different port
+
+generate_url = f"http://localhost:{port}/generate"
+
+def generation_request():
+ payload = {"text": "What is the capital of the United States?", "sampling_params": {"max_completion_tokens": 50}}
+ try:
+ resp = requests.post(generate_url, json=payload)
+ resp.raise_for_status() # Raises an HTTPError for bad responses
+ print(resp.text)
+ except requests.exceptions.RequestException as e:
+ print(f"An error occurred: {e}")
+
+generation_request()
+```
+
+After you receive the request, you can exit the python shell:
+
+```bash
+quit()
+```
+
+## Cleanup
+
+When done, you can kill the shortfin_llm_server by killing the process:
+
+```bash
+kill -9 $shortfin_process
+```
diff --git a/docs/user_guide.md b/docs/user_guide.md
new file mode 100644
index 000000000..c3da1f4f5
--- /dev/null
+++ b/docs/user_guide.md
@@ -0,0 +1,115 @@
+# SHARK User Guide
+
+> [!WARNING]
+> This is still pre-release so the artifacts listed here may be broken
+>
+
+These instructions cover the usage of the latest stable release of SHARK. For a more bleeding edge release please install the [nightly releases](nightly_releases.md).
+
+## Prerequisites
+
+Our current user guide requires that you have:
+- Access to a computer with an installed AMD Instinctâ„¢ MI300x Series Accelerator
+- Installed a compatible version of Linux and ROCm on the computer (see the [ROCm compatability matrix](https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html))
+
+
+## Set up Environment
+
+This section will help you install Python and set up a Python environment with venv.
+
+Officially we support Python versions: 3.11, 3.12, 3.13, 3.13t
+
+The rest of this guide assumes you are using Python 3.11.
+
+### Install Python
+To install Python 3.11 on Ubuntu:
+
+```bash
+sudo apt install python3.11 python3.11-dev python3.11-venv
+
+which python3.11
+# /usr/bin/python3.11
+```
+
+### Create a Python Environment
+
+Setup your Python environment with the following commands:
+
+```bash
+# Set up a virtual environment to isolate packages from other envs.
+python3.11 -m venv 3.11.venv
+source 3.11.venv/bin/activate
+```
+
+## Install SHARK and its dependencies
+
+```bash
+pip install shark-ai[apps]
+```
+
+Temporarily, you may need an update to your `shortfin` install.
+Install the latest pre-release with:
+```
+pip install shortfin --upgrade --pre -f https://github.com/nod-ai/shark-ai/releases/expanded_assets/dev-wheels
+```
+
+### Test the installation.
+
+```
+python -m shortfin_apps.sd.server --help
+```
+
+## Quickstart
+
+### Run the SDXL Server
+
+Run the [SDXL Server](../shortfin/python/shortfin_apps/sd/README.md#Start-SDXL-Server)
+
+### Run the SDXL Client
+
+```
+python -m shortfin_apps.sd.simple_client --interactive
+```
+
+Congratulations!!! At this point you can play around with the server and client based on your usage.
+
+### Update flags
+
+Please see --help for both the server and client for usage instructions. Here's a quick snapshot.
+
+#### Update server options:
+
+| Flags | options |
+|---|---|
+|--host HOST |
+|--port PORT | server port |
+|--root-path ROOT_PATH |
+|--timeout-keep-alive |
+|--device | local-task,hip,amdgpu | amdgpu only supported in this release
+|--target | gfx942,gfx1100 | gfx942 only supported in this release
+|--device_ids |
+|--tokenizers |
+|--model_config |
+| --workers_per_device |
+| --fibers_per_device |
+| --isolation | per_fiber, per_call, none |
+| --show_progress |
+| --trace_execution |
+| --amdgpu_async_allocations |
+| --splat |
+| --build_preference | compile,precompiled |
+| --compile_flags |
+| --flagfile FLAGFILE |
+| --artifacts_dir ARTIFACTS_DIR | Where to store cached artifacts from the Cloud |
+
+#### Update client with different options:
+
+| Flags |options|
+|---|---
+|--file |
+|--reps |
+|--save | Whether to save image generated by the server |
+|--outputdir| output directory to store images generated by SDXL |
+|--steps |
+|--interactive |
+|--port| port to interact with server |
diff --git a/shark-ai/.gitignore b/shark-ai/.gitignore
new file mode 100644
index 000000000..80bf001b8
--- /dev/null
+++ b/shark-ai/.gitignore
@@ -0,0 +1,2 @@
+# Local-only config options
+requirements.txt
diff --git a/shark-ai/README.md b/shark-ai/README.md
new file mode 100644
index 000000000..0bb1abafd
--- /dev/null
+++ b/shark-ai/README.md
@@ -0,0 +1,3 @@
+# SHARK AI meta package
+
+Meta package to install `shortfin` and compatible IREE packages.
diff --git a/shark-ai/build_tools/build_linux_package.sh b/shark-ai/build_tools/build_linux_package.sh
new file mode 100755
index 000000000..d16f339b1
--- /dev/null
+++ b/shark-ai/build_tools/build_linux_package.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# build_linux_package.sh
+#
+# Builds shark-ai Python package for Linux.
+#
+# Usage:
+# ./build_tools/build_linux_package.sh
+
+set -xeu -o errtrace
+
+THIS_DIR="$(cd $(dirname $0) && pwd)"
+REPO_ROOT="$(cd "$THIS_DIR"/../../ && pwd)"
+OUTPUT_DIR="${OUTPUT_DIR:-${THIS_DIR}/wheelhouse}"
+
+python -m pip wheel --disable-pip-version-check --no-deps -v -w "${OUTPUT_DIR}" "${REPO_ROOT}/shark-ai"
+
+wheel_output="$(echo "${OUTPUT_DIR}/shark_ai-"*".whl")"
+ls "${wheel_output}"
diff --git a/shark-ai/pyproject.toml b/shark-ai/pyproject.toml
new file mode 100644
index 000000000..3f7e4a1da
--- /dev/null
+++ b/shark-ai/pyproject.toml
@@ -0,0 +1,36 @@
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "shark-ai"
+authors = [
+ {name = "SHARK Authors"},
+]
+description = "SHARK AI meta package"
+readme = "README.md"
+license = {text = "Apache-2.0"}
+classifiers = [
+ "Development Status :: 3 - Alpha",
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+]
+# Version is set via the `setup.py` and requirements are set via files below.
+dynamic = ["version", "dependencies"]
+
+[project.urls]
+Repository = "https://github.com/nod-ai/shark-ai"
+
+[project.optional-dependencies]
+onnx = [
+ "iree-base-compiler[onnx]",
+]
+apps = [
+ "shortfin[apps]",
+]
+
+[tool.setuptools]
+packages = []
+
+[tool.setuptools.dynamic]
+dependencies = {file = ["requirements.txt"]}
diff --git a/shark-ai/setup.py b/shark-ai/setup.py
new file mode 100644
index 000000000..5ceac55bd
--- /dev/null
+++ b/shark-ai/setup.py
@@ -0,0 +1,33 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import os
+from pathlib import Path
+
+from setuptools import setup
+
+THIS_DIR = Path(__file__).parent.resolve()
+
+# Setup and get version information.
+# The `version_local.json` is generated by calling:
+# `build_tools/python_deploy/compute_common_version.py -stable --write-json`
+VERSION_FILE_LOCAL = os.path.join(THIS_DIR, "version_local.json")
+
+
+def load_version_info(version_file):
+ with open(version_file, "rt") as f:
+ return json.load(f)
+
+
+version_info = load_version_info(VERSION_FILE_LOCAL)
+
+PACKAGE_VERSION = version_info.get("package-version")
+print(f"Using PACKAGE_VERSION: '{PACKAGE_VERSION}'")
+
+setup(
+ version=f"{PACKAGE_VERSION}",
+)
diff --git a/sharktank/README.md b/sharktank/README.md
index c36cdd055..7770595ed 100644
--- a/sharktank/README.md
+++ b/sharktank/README.md
@@ -12,7 +12,7 @@ tooling.
## Project Status
-[![CI - Perplexity](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci_eval.yaml/badge.svg?branch=main&event=schedule)](https://github.com/nod-ai/SHARK-Platform/actions/workflows/ci_eval.yaml)
+[![CI - Perplexity](https://github.com/nod-ai/shark-ai/actions/workflows/ci_eval.yaml/badge.svg?branch=main&event=schedule)](https://github.com/nod-ai/shark-ai/actions/workflows/ci_eval.yaml)
## Examples
diff --git a/sharktank/pyproject.toml b/sharktank/pyproject.toml
index 65f264d16..01cad409b 100644
--- a/sharktank/pyproject.toml
+++ b/sharktank/pyproject.toml
@@ -14,17 +14,13 @@ classifiers = [
"Development Status :: 3 - Alpha",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.11",
- "Programming Language :: Python :: 3.12",
- "Programming Language :: Python :: 3.13",
]
-requires-python = ">= 3.11"
# Version is set via the `setup.py` and requirements are set via files below.
dynamic = ["version", "dependencies", "optional-dependencies"]
[project.urls]
-Repository = "https://github.com/nod-ai/SHARK-Platform"
+Repository = "https://github.com/nod-ai/shark-ai"
[tool.setuptools.packages.find]
where = ["."]
diff --git a/sharktank/requirements.txt b/sharktank/requirements.txt
index dd8f14fb6..6b533d977 100644
--- a/sharktank/requirements.txt
+++ b/sharktank/requirements.txt
@@ -2,7 +2,7 @@ iree-turbine
# Runtime deps.
gguf==0.6.0
-numpy==1.26.3
+numpy<2.0
# Needed for newer gguf versions (TODO: remove when gguf package includes this)
# sentencepiece>=0.1.98,<=0.2.0
diff --git a/sharktank/setup.py b/sharktank/setup.py
index aca5c63d0..182f94abc 100644
--- a/sharktank/setup.py
+++ b/sharktank/setup.py
@@ -13,8 +13,8 @@
SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__))
# Setup and get version information.
-VERSION_INFO_FILE = os.path.join(SETUPPY_DIR, "version_info.json")
-VERSION_INFO_RC_FILE = os.path.join(SETUPPY_DIR, "version_info_rc.json")
+VERSION_FILE = os.path.join(SETUPPY_DIR, "version.json")
+VERSION_FILE_LOCAL = os.path.join(SETUPPY_DIR, "version_local.json")
def load_version_info(version_file):
@@ -23,10 +23,10 @@ def load_version_info(version_file):
try:
- version_info = load_version_info(VERSION_INFO_RC_FILE)
+ version_info = load_version_info(VERSION_FILE_LOCAL)
except FileNotFoundError:
- print("version_info_rc.json not found. Default to dev build")
- version_info = load_version_info(VERSION_INFO_FILE)
+ print("version_local.json not found. Default to dev build")
+ version_info = load_version_info(VERSION_FILE)
PACKAGE_VERSION = version_info.get("package-version")
print(f"Using PACKAGE_VERSION: '{PACKAGE_VERSION}'")
diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 7bf76a2ce..a740f0bff 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -54,24 +54,19 @@ def main():
help="Enables strictness during export",
action="store_true",
)
- parser.add_argument(
- "--attention-kernel",
- type=str,
- default="decomposed",
- choices=["decomposed", "torch"],
- )
-
+ cli.add_quantization_options(parser)
+ cli.add_model_options(parser)
args = cli.parse(parser)
dataset_type = cli.get_input_data_files(args)
dataset_type = "irpa" if "irpa" in dataset_type else "gguf"
dataset = cli.get_input_dataset(args)
-
hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
tensor_parallelism_size = (
dataset.properties["tensor_parallelism_size"]
if "tensor_parallelism_size" in dataset.properties
else 1
)
+
llama_config = LlamaModelConfig(
hp,
tensor_parallelism_size=tensor_parallelism_size,
@@ -80,6 +75,7 @@ def main():
kv_cache_type="direct" if args.bs == [1] else "paged",
attention_kernel=args.attention_kernel,
)
+ llama_config.fake_quant = args.fake_quant
if llama_config.hp.expert_count:
if llama_config.hp.model_arch == "grok":
@@ -214,8 +210,6 @@ def _(model, tokens, seq_lens, seq_block_ids, cs):
cache_tensors = repack_cache(cs, cache_shard_dim)
- cache_tensors = [model.cache.unflatten_page_table(cache_tensors)]
-
logits = model.prefill(
tokens,
attention_mask=attention_mask,
@@ -302,8 +296,6 @@ def _(
cache_state = repack_cache(cache_state, cache_shard_dim)
- cache_state = [model.cache.unflatten_page_table(cache_state)]
-
logits = model.decode(
tokens,
attention_mask=attention_mask,
diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py
index 10c76e644..b30acc026 100644
--- a/sharktank/sharktank/examples/paged_llm_v1.py
+++ b/sharktank/sharktank/examples/paged_llm_v1.py
@@ -196,6 +196,14 @@ def decode(self):
trace_tensor("decode.start_positions", start_positions)
trace_tensor("decode.seq_block_ids", seq_block_ids_tensor)
trace_tensor("decode.attention_mask", decode_attention_mask)
+
+ if model.config.tensor_parallelism_size != 1:
+ tp = model.config.tensor_parallelism_size
+ self.next_tokens = replicate(self.next_tokens, tp)
+ start_positions = replicate(start_positions, tp)
+ seq_block_ids_tensor = replicate(seq_block_ids_tensor, tp)
+ decode_attention_mask = replicate(decode_attention_mask, tp)
+
logits = model.decode(
self.next_tokens,
attention_mask=decode_attention_mask,
@@ -250,16 +258,15 @@ def main():
)
cli.add_input_dataset_options(parser)
cli.add_tokenizer_options(parser)
+ cli.add_quantization_options(parser)
+ cli.add_model_options(parser)
args = cli.parse(parser)
-
device = torch.device(args.device) if args.device else None
activation_dtype = getattr(torch, args.activation_dtype)
assert isinstance(activation_dtype, torch.dtype)
-
dataset = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)
prompts = args.prompt
-
config = LlamaModelConfig(
hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
block_seq_stride=16,
@@ -267,8 +274,10 @@ def main():
device=device,
activation_dtype=activation_dtype,
attention_dtype=activation_dtype,
+ attention_kernel=args.attention_kernel,
use_hf=args.use_hf,
tensor_parallelism_size=args.tensor_parallelism_size,
+ fake_quant=args.fake_quant,
)
if config.tensor_parallelism_size > 1:
dataset.root_theta = shard_theta(dataset.root_theta, config)
diff --git a/sharktank/sharktank/kernels/__init__.py b/sharktank/sharktank/kernels/__init__.py
index beb7e90a2..445f44852 100644
--- a/sharktank/sharktank/kernels/__init__.py
+++ b/sharktank/sharktank/kernels/__init__.py
@@ -14,3 +14,4 @@
from .conv_2d_nchw_fchw import *
from .pooling_nchw_sum import *
from .base import *
+from .bitcast import *
diff --git a/sharktank/sharktank/kernels/bitcast.py b/sharktank/sharktank/kernels/bitcast.py
new file mode 100644
index 000000000..66850008f
--- /dev/null
+++ b/sharktank/sharktank/kernels/bitcast.py
@@ -0,0 +1,138 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from sharktank.kernels.base import *
+
+import torch
+
+from iree.turbine.support.ir_imports import (
+ ComplexType,
+ F16Type,
+ F32Type,
+ RankedTensorType,
+ ShapedType,
+ Value,
+ flow_d,
+ tensor_d,
+)
+
+from iree.turbine.runtime.op_reg import (
+ CustomOp,
+ KernelBuilder,
+ KernelSelection,
+)
+
+__all__ = [
+ "bitcast_to_complex",
+ "bitcast_to_real",
+]
+
+_ftype_to_ctype_table = {
+ torch.float16: torch.complex32,
+ torch.float32: torch.complex64,
+}
+
+_ctype_to_ftype_table = {
+ torch.complex32: torch.float16,
+ torch.complex64: torch.float32,
+}
+
+_type_to_irtype_table = {
+ torch.float16: lambda: F16Type.get(),
+ torch.float32: lambda: F32Type.get(),
+ torch.complex32: lambda: ComplexType.get(F16Type.get()),
+ torch.complex64: lambda: ComplexType.get(F32Type.get()),
+}
+
+
+@CustomOp.register(library=LIBRARY)
+class bitcast_to_complex(CustomOp):
+
+ signature = "bitcast_to_complex(Tensor q) -> (Tensor)"
+
+ def select(self, ksel: KernelSelection):
+ ta = ksel.arg_tensor(0)
+
+ torch._check(ta.t.dtype in _ftype_to_ctype_table)
+ torch._check(isinstance(ta.t.shape[-1], int))
+
+ new_shape = [i for i in ta.t.shape]
+ new_shape[-1] = new_shape[-1] // 2
+
+ ctype = _ftype_to_ctype_table[ta.t.dtype]
+ ret = ksel.return_new_tensor(new_shape, dtype=ctype)
+ specialize_all_known_dims(ta)
+ specialize_all_known_dims(ret)
+
+ def eager_execute(self, tensor):
+ return torch.view_as_complex(tensor.unflatten(-1, (-1, 2)))
+
+ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
+ t = kb.arg_bindings[0]
+ result_desc = ksel.result_descs[0]
+ result_shape = [
+ d if isinstance(d, int) else RankedTensorType.get_dynamic_size()
+ for d in result_desc.t.shape
+ ]
+
+ dynamic_dims: list[Value] = []
+ _append_dynamic_dims(kb, dynamic_dims, t)
+
+ c64 = _type_to_irtype_table[result_desc.t.dtype]()
+ rtt = RankedTensorType.get(result_shape, c64)
+ result = flow_d.TensorBitCastOp(rtt, t, dynamic_dims, dynamic_dims).result
+ kb.yield_results(result)
+
+
+@CustomOp.register(library=LIBRARY)
+class bitcast_to_real(CustomOp):
+
+ signature = "bitcast_to_real(Tensor q) -> (Tensor)"
+
+ def select(self, ksel: KernelSelection):
+ ta = ksel.arg_tensor(0)
+
+ torch._check(ta.t.dtype in _ctype_to_ftype_table)
+ torch._check(isinstance(ta.t.shape[-1], int))
+
+ new_shape = [i for i in ta.t.shape]
+ new_shape[-1] = new_shape[-1] * 2
+
+ ftype = _ctype_to_ftype_table[ta.t.dtype]
+ ret = ksel.return_new_tensor(new_shape, dtype=ftype)
+ specialize_all_known_dims(ta)
+ specialize_all_known_dims(ret)
+
+ def eager_execute(self, tensor):
+ return torch.view_as_real(tensor).flatten(-2, -1)
+
+ def generate(self, ksel: KernelSelection, kb: KernelBuilder):
+ t = kb.arg_bindings[0]
+ result_desc = ksel.result_descs[0]
+ result_shape = [
+ d if isinstance(d, int) else RankedTensorType.get_dynamic_size()
+ for d in result_desc.t.shape
+ ]
+
+ dynamic_dims: list[Value] = []
+ _append_dynamic_dims(kb, dynamic_dims, t)
+
+ ftype = _type_to_irtype_table[result_desc.t.dtype]()
+ rtt = RankedTensorType.get(result_shape, ftype)
+ result = flow_d.TensorBitCastOp(rtt, t, dynamic_dims, dynamic_dims).result
+ kb.yield_results(result)
+
+
+################################################################################
+# Emission utilities
+################################################################################
+
+
+def _append_dynamic_dims(kb: KernelBuilder, dynamic_dims: list[Value], tensor: Value):
+ rtt = RankedTensorType(tensor.type)
+ for i in range(rtt.rank):
+ if rtt.is_dynamic_dim(i):
+ dynamic_dims.append(tensor_d.dim(tensor, kb.constant_index(i)))
diff --git a/sharktank/sharktank/layers/causal_llm.py b/sharktank/sharktank/layers/causal_llm.py
index 7a09995a8..8ace77981 100644
--- a/sharktank/sharktank/layers/causal_llm.py
+++ b/sharktank/sharktank/layers/causal_llm.py
@@ -33,12 +33,14 @@ def __init__(
device: Optional[torch.device] = None,
activation_dtype: torch.dtype = torch.float32,
attention_dtype: torch.dtype = torch.float32,
+ fake_quant: bool = True,
):
super().__init__(theta)
self.device = device
self.activation_dtype = activation_dtype
self.attention_dtype = attention_dtype
self.context_length = context_length
+ self.fake_quant = fake_quant
if static_tables:
self.register_buffer(
diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py
index 6dbe6fc52..c440ad441 100644
--- a/sharktank/sharktank/layers/configs/llm_configs.py
+++ b/sharktank/sharktank/layers/configs/llm_configs.py
@@ -44,7 +44,7 @@ class LlamaHParams:
@staticmethod
def from_gguf_props(p: dict[str, Any]):
- name_prefix = p["general.architecture"]
+ name_prefix = p.get("general.architecture", "llama")
default_expert_count = 0
default_expert_used_count = 0
default_rope_freq_base = 10000.0
@@ -156,6 +156,9 @@ class LlamaModelConfig:
# Dtype to use for attention.
attention_dtype: torch.dtype = torch.float16
+ # fake quant determines the mode the Layer Thetas operate w.r.t quantized tensors.
+ fake_quant: bool = True
+
# How many devices are involved for tensor parallel sharding.
# If greater than 1, the model will expect sharded model parameters and function
# arguments.
diff --git a/sharktank/sharktank/layers/kv_cache.py b/sharktank/sharktank/layers/kv_cache.py
index d9ed05f79..c73b7a8f4 100644
--- a/sharktank/sharktank/layers/kv_cache.py
+++ b/sharktank/sharktank/layers/kv_cache.py
@@ -141,7 +141,7 @@ def read(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
- dest_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+ read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
transformer_block_index: int,
seq_len: int,
page_ids: Optional[Union[torch.Tensor, ReplicatedTensor]] = None,
@@ -150,7 +150,7 @@ def read(
Args:
state: State struct as returned from allocate().
- dest_partitions: List of cache partitions to read into in-place.
+ read_into_partitions: List of cache partitions to read into in-place.
transformer_block_index: The index of the transformer block accessing
the cache.
page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids
@@ -161,7 +161,7 @@ def read(
materializing linearly may not be terribly efficient unless if the
compiler can fuse the gather.
"""
- read_count = len(dest_partitions)
+ read_count = len(read_into_partitions)
reads = []
for i in range(read_count):
reads.append(
@@ -284,10 +284,6 @@ def unflatten_page_table(
"""Unflattens the 2D page table to a 6D tensor."""
assert len(state) == 1, f"Expected 1-element state. Got: {len(state)}"
page_slab = state[0]
-
- if len(page_slab.shape) == 6:
- return page_slab
-
if self.shard_count == 1:
assert not isinstance(page_slab, SplitPrimitiveTensor)
return page_slab.unflatten(1, self.sub_page_dims)
@@ -356,7 +352,7 @@ def read(
self,
state: list[Union[torch.Tensor, SplitPrimitiveTensor]],
*,
- dest_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
+ read_into_partitions: list[Union[torch.Tensor, SplitPrimitiveTensor]],
transformer_block_index: int,
seq_len: int,
page_ids: Union[torch.Tensor, ReplicatedTensor],
@@ -365,7 +361,7 @@ def read(
Args:
state: State struct as returned from allocate().
- dest_partitions: List of cache partitions to read into in-place.
+ read_into_partitions: List of cache partitions to read into in-place.
transformer_block_index: The index of the transformer block accessing
the cache.
page_ids: Tensor of [bs, max_seqlen // block_pos_stride] of page ids
@@ -378,9 +374,36 @@ def read(
"""
page_table = self.unflatten_page_table(state) # 6D
- def read_cache_partitions(
- into_partitions: List[Union[torch.Tensor, SplitPrimitiveTensor]]
+ bs, block_seq_len, *_ = page_ids.shape
+ # Blocks dim 1,2 according to the configured block stride.
+ blocked_shape = [
+ bs,
+ block_seq_len,
+ self.block_seq_stride,
+ self.attn_head_count // self.shard_count,
+ self.attn_head_dim,
+ ]
+
+ # Reshape the page cache into sub-blocks so that we can index at the
+ # granularity of the transformer_block and cache partition.
+ # This requires us to recompute indices to the sub-block reference
+ # frame.
+ # The subblock slab is organized as:
+ # [page, attn_layer, cache_partition]
+ # Where the cache line can be 0 (k) or 1 (v).
+ subblock_table = page_table.flatten(start_dim=0, end_dim=2)
+ page_stride = self.transformer_block_count * self.cache_partition_count
+ transformer_block_stride = self.cache_partition_count
+ base_subblock_ids = page_ids * page_stride + (
+ transformer_block_index * transformer_block_stride
+ )
+
+ def read_cache_partition(
+ index: int, into_partition: Union[torch.Tensor, SplitPrimitiveTensor]
):
+ subblock_ids = (
+ (base_subblock_ids + index) if index > 0 else base_subblock_ids
+ )
# TODO: Potentially clamp all page 0 indices to the mask value.
# Or even better, require that the ids are replicated such that access is
# legal.
@@ -389,16 +412,18 @@ def read_cache_partitions(
# copy of the sub-blocks by collapsing the first two dims so we have
# a linear list.
# TODO: Can be rewritten into inplace with out= on index_select.
+ selected = (
+ ops.index_select(subblock_table, 0, subblock_ids.flatten(0, 1))
+ .unflatten(0, blocked_shape[0:2])
+ .flatten(1, 2)
+ )
+ # trace_tensor("kv.selected", selected)
+ into_partition[...] = selected
- for i, into_partition in enumerate(into_partitions):
- selected = page_table[
- page_ids.flatten(0, 1), transformer_block_index, i
- ]
- selected = selected.unflatten(0, page_ids.shape).flatten(1, 2)
- into_partition[...] = selected
+ for index, read_into_partition in enumerate(read_into_partitions):
+ read_cache_partition(index, read_into_partition)
- read_cache_partitions(dest_partitions)
- return tuple([p[:, :seq_len, :] for p in dest_partitions])
+ return tuple([p[:, :seq_len, :] for p in read_into_partitions])
def write_timestep(
self,
@@ -463,25 +488,46 @@ def write(
in-place scatter cannot be fused.
"""
page_table = self.unflatten_page_table(state) # 6D
- _, block_seq_len, *_ = page_ids.shape
+
+ bs, block_seq_len, *_ = page_ids.shape
+ # Blocks dim 1,2 according to the configured block stride.
+ blocked_shape = [
+ bs,
+ block_seq_len,
+ self.block_seq_stride,
+ self.attn_head_count,
+ self.attn_head_dim,
+ ]
+
+ # Reshape the page cache into sub-blocks so that we can index at the
+ # granularity of the transformer_block and cache partition.
+ # This requires us to recompute indices to the sub-block reference
+ # frame.
+ # The subblock slab is organized as:
+ # [page, attn_layer, cache_partition]
+ # Where the cache line can be 0 (k) or 1 (v).
+ subblock_table = page_table.flatten(start_dim=0, end_dim=2)
+ page_stride = self.transformer_block_count * self.cache_partition_count
+ transformer_block_stride = self.cache_partition_count
+ base_subblock_ids = page_ids * page_stride + (
+ transformer_block_index * transformer_block_stride
+ )
part_block_views = []
- for partition in cache_partitions:
+ subblock_ids_kv = []
+ for index, partition in enumerate(cache_partitions):
part_block_view = partition.unflatten(
1, (block_seq_len, self.block_seq_stride)
)
- part_block_view = part_block_view.unsqueeze(2)
+ part_block_view = part_block_view.flatten(0, 1)
part_block_views.append(part_block_view)
- part_block_view = ops.cat(part_block_views, dim=2)
+ subblock_ids = (
+ (base_subblock_ids + index) if index > 0 else base_subblock_ids
+ ).flatten(0, 1)
+ subblock_ids_kv.append(subblock_ids)
- page_ids = page_ids.flatten(0, 1)
- part_block_view = part_block_view.flatten(0, 1)
+ subblock_ids = ops.cat(subblock_ids_kv)
+ part_block_view = ops.cat(part_block_views, dim=0)
- page_table.index_put_(
- (
- page_ids,
- torch.full(page_ids.shape, transformer_block_index, dtype=torch.int64),
- ),
- part_block_view,
- )
+ subblock_table.index_copy_(0, subblock_ids, part_block_view)
diff --git a/sharktank/sharktank/layers/linear.py b/sharktank/sharktank/layers/linear.py
index c5e2ea330..b679dccde 100644
--- a/sharktank/sharktank/layers/linear.py
+++ b/sharktank/sharktank/layers/linear.py
@@ -15,6 +15,7 @@
QuantizerTensor,
StaticScaledQuantizer,
TensorScaledLayout,
+ PlanarQuantizedTensor,
)
__all__ = [
@@ -29,6 +30,10 @@ class LinearLayer(ThetaLayer):
if premul_input is not None:
x = x * premul_input
matmul(x, weight.T) + bias
+
+ fake_quant exists to allow export without adding dequant ops.
+ when fake_quant is True, the op will in quant dequant fashion.
+ When false, it will keep quantized types.
```
"""
@@ -38,11 +43,13 @@ def __init__(
*,
weight_name: str = "weight",
bias_name: str = "bias",
+ fake_quant: bool = True,
):
super().__init__(theta)
self._simulate_native_quant = True
self.weight = self.theta_tensor(weight_name)
self.bias = None
+ self.fake_quant = fake_quant
if bias_name in self.theta.keys:
self.bias = self.theta_tensor(bias_name)
@@ -65,18 +72,23 @@ def forward(self, x):
if q_input is not None:
x = q_input.quantize(x)
- elif qdq_input is not None:
- # TODO: probably need a way to only do q_input if exporting.
+ if self.fake_quant:
+ x = x.unpack().dequant()
+ elif qdq_input is not None and self.fake_quant:
x = qdq_input.quantize(x).unpack().dequant()
y = ops.linear(x, weight, bias)
# Unconditionally dequantize.
- # TODO: Support a q_output specifier that signals the layer to let
- # the QuantizedTensor escape.
- if isinstance(y, QuantizedTensor):
+ if isinstance(y, QuantizedTensor) and not self.fake_quant:
y = y.unpack().dequant()
- if qdq_output is not None:
- # TODO: same as above.
+ # Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32.
+ # We can truncate to fp16 in iree, so we do a cast here
+ # to account for this in the IR. This is may not be the right
+ # level to do this, but for now its here.
+ if not self.fake_quant and y.dtype == torch.float8_e4m3fnuz:
+ y = ops.to(y, torch.float16)
+ return y
+ if qdq_output is not None and self.fake_quant:
y = qdq_output.quantize(y).unpack().dequant()
return y
diff --git a/sharktank/sharktank/layers/paged_llama_attention_block.py b/sharktank/sharktank/layers/paged_llama_attention_block.py
index 796e8224a..22647bf49 100644
--- a/sharktank/sharktank/layers/paged_llama_attention_block.py
+++ b/sharktank/sharktank/layers/paged_llama_attention_block.py
@@ -10,7 +10,7 @@
import torch
import torch.nn.functional as F
-
+from ..types import QuantizerTensor
from .base import Theta, ThetaLayer
from .linear import LinearLayer
from .norm import RMSNormLayer
@@ -40,6 +40,7 @@ def __init__(
attention_kernel: str = "decomposed",
attention_scale: Optional[float] = None,
softcap: Optional[float] = None,
+ fake_quant: Optional[bool] = True,
):
super().__init__(theta)
@@ -51,14 +52,28 @@ def __init__(
self.attention_kernel = attention_kernel
self.attention_scale = attention_scale
self.softcap = softcap
+ self.fake_quant = fake_quant
self.add_module(
"attn_norm", RMSNormLayer(theta("attn_norm"), epsilon=rms_epsilon)
)
- self.add_module("attn_q", LinearLayer(theta("attn_q")))
- self.add_module("attn_k", LinearLayer(theta("attn_k")))
- self.add_module("attn_v", LinearLayer(theta("attn_v")))
- self.add_module("attn_output", LinearLayer(theta("attn_output")))
+ self.add_module(
+ "attn_q", LinearLayer(theta("attn_q"), fake_quant=self.fake_quant)
+ )
+ self.add_module(
+ "attn_k", LinearLayer(theta("attn_k"), fake_quant=self.fake_quant)
+ )
+ self.add_module(
+ "attn_v", LinearLayer(theta("attn_v"), fake_quant=self.fake_quant)
+ )
+ self.add_module(
+ "attn_output", LinearLayer(theta("attn_output"), fake_quant=self.fake_quant)
+ )
+ self.cache_quantizer = None
+ if "kv_cache" in theta.keys:
+ self.cache_quantizer: Optional[QuantizerTensor] = theta.optional_tensor(
+ "kv_cache.quantizer"
+ )
if theta.optional_tensor("attn_output_norm") is None:
self.add_module(
@@ -104,15 +119,38 @@ def forward(
# Fast path to start_index based embedding lookup if available.
# Falls back to a slower position based index lookup.
if start_index is not None:
- xq, xk = embedding.forward(xq=xq, xk=xk, start_index=start_index)
+ xq = embedding.forward(xt=xq, start_index=start_index)
+ xk = embedding.forward(xt=xk, start_index=start_index)
else:
- xq, xk = embedding.apply_batched_mask(
- xq=xq, xk=xk, mask=embedding_batch_mask
- )
+ xq = embedding.apply_batched_mask(xt=xq, mask=embedding_batch_mask)
+ xk = embedding.apply_batched_mask(xt=xk, mask=embedding_batch_mask)
# Full sequence length.
kv_seq_len = seq_block_ids.shape[1] * self.cache.block_seq_stride
+ # Used by fp8_e4m3fnuz model
+ if self.cache_quantizer is not None:
+ # For fake quant, store the fp16 qdq value in the cache
+ if self.fake_quant:
+ xk = (
+ self.cache_quantizer.quantize(xk)
+ .unpack()
+ .dequant()
+ .to(torch.float16)
+ )
+ xv = (
+ self.cache_quantizer.quantize(xv)
+ .unpack()
+ .dequant()
+ .to(torch.float16)
+ )
+ # For real quant, store the quantized fp8 value in the cache
+ else:
+ # TODO: this seems like a bastardization of our quantized tensor api
+ # Probably want to add support for using quantized tensors more directly
+ xk = self.cache_quantizer.quantize(xk).unpack().qs
+ xv = self.cache_quantizer.quantize(xv).unpack().qs
+
xk, xv = self.transact_cache(
xk_cache_update=xk,
xv_cache_update=xv,
@@ -138,6 +176,14 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
xk = repeat_kv(xk)
xv = repeat_kv(xv)
+ # Fake quant is already dequantized when stored in the cache.
+ if self.cache_quantizer and not self.fake_quant:
+ xk = self.cache_quantizer.dequantize_raw_tensor(
+ xk, torch.float16, name="xk_deq"
+ )
+ xv = self.cache_quantizer.dequantize_raw_tensor(
+ xv, torch.float16, name="xv_deq"
+ )
# Transpose into [bs, heads, sl, dim]
xq = xq.transpose(1, 2)
keys = xk.transpose(1, 2)
@@ -170,7 +216,8 @@ def repeat_kv(x: torch.Tensor) -> torch.Tensor:
attn_weights, values
) # (bs, heads, slen, head_dim)
else:
- is_causal = attention_mask is None and batch_seq_len == 1
+ is_causal = True
+ attention_mask = None
attn_output = ops.scaled_dot_product_attention(
q=xq, # [bs, ..., sl, dim]
k=keys, # [bs, ..., sl, dim]
@@ -241,7 +288,7 @@ def transact_cache(
# Restore from the cache.
xk, xv = cache.read(
cache_state,
- dest_partitions=[
+ read_into_partitions=[
xk_temp[:, 0:kv_seq_len, ...],
xv_temp[:, 0:kv_seq_len, ...],
],
diff --git a/sharktank/sharktank/layers/rotary_embedding.py b/sharktank/sharktank/layers/rotary_embedding.py
index 18a95aba3..0664a9a46 100644
--- a/sharktank/sharktank/layers/rotary_embedding.py
+++ b/sharktank/sharktank/layers/rotary_embedding.py
@@ -53,49 +53,38 @@ def rotary_embed_table(self):
return self.static_rotary_embed_table
return self._create_rotary_embed_table()
- if self.tensor_parallelism_size == 1:
- return None
-
- nt = namedtuple("replicated_tensor", ["shards"])
- return nt([None] * self.tensor_parallelism_size)
+ return None
def forward(
self,
*,
- xq: Union[torch.Tensor, SplitPrimitiveTensor],
- xk: Union[torch.Tensor, SplitPrimitiveTensor],
+ xt: Union[torch.Tensor, SplitPrimitiveTensor],
start_index: int,
):
- if isinstance(xq, SplitPrimitiveTensor):
- assert (
- isinstance(xk, SplitPrimitiveTensor)
- and xq.shard_count == xk.shard_count
- and xk.shard_dim == xq.shard_dim
- )
- assert (
- isinstance(self.rotary_embed_table, ReplicatedTensor)
- and xq.shard_count == self.rotary_embed_table.shard_count
- )
- xqk_shards = [
+ if isinstance(xt, SplitPrimitiveTensor):
+ rotary_shards = [None] * xt.shard_count
+ if self.rotary_embed_table is not None:
+ assert (
+ isinstance(self.rotary_embed_table, ReplicatedTensor)
+ and xt.shard_count == self.rotary_embed_table.shard_count
+ )
+ rotary_shards = [
+ unbox_tensor(shard) for shard in self.rotary_embed_table.shards
+ ]
+
+ xt_shards = [
self.forward_unsharded(
- xq=unbox_tensor(xq_shard),
- xk=unbox_tensor(xk_shard),
+ xt=unbox_tensor(xt_shard),
start_index=start_index,
- rotary_embed_table=unbox_tensor(rotary_embed_table_shard),
- )
- for xq_shard, xk_shard, rotary_embed_table_shard in zip(
- xq.shards, xk.shards, self.rotary_embed_table.shards
+ rotary_embed_table=rotary_shard,
)
+ for xt_shard, rotary_shard in zip(xt.shards, rotary_shards)
]
- xq_shards = [xqk[0] for xqk in xqk_shards]
- xk_shards = [xqk[1] for xqk in xqk_shards]
- xq = SplitPrimitiveTensor(ts=xq_shards, shard_dim=xq.shard_dim)
- xk = SplitPrimitiveTensor(ts=xk_shards, shard_dim=xk.shard_dim)
- return xq, xk
+ xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
+ return xt
else:
return self.forward_unsharded(
- xq=xq,
- xk=xk,
+ xt=xt,
start_index=start_index,
rotary_embed_table=self.rotary_embed_table,
)
@@ -103,8 +92,7 @@ def forward(
def forward_unsharded(
self,
*,
- xq: torch.Tensor,
- xk: torch.Tensor,
+ xt: torch.Tensor,
start_index: int,
rotary_embed_table: Optional[torch.Tensor],
):
@@ -149,60 +137,30 @@ def create_ordering_tensor(dim):
return order_tensor
if self.use_hf:
- xq = xq[..., create_interleaved_tensor(xq.shape[-1])]
- xk = xk[..., create_interleaved_tensor(xq.shape[-1])]
-
- xq_ = torch.view_as_complex(xq.unflatten(-1, (-1, 2)))
- xk_ = torch.view_as_complex(xk.unflatten(-1, (-1, 2)))
- _, sl, _, dim = xq_.shape
+ xt = xt[..., create_interleaved_tensor(xt.shape[-1])]
+ xt_ = xt
+ _, sl, _, _ = xt_.shape
# Offset the table based on starting position.
if self.use_table:
freqs_cis = rotary_embed_table[start_index : start_index + sl, :]
+ freqs_cis = freqs_cis[None, 0:sl, None, :]
else:
- freqs_cis = torch.arange(start_index, start_index + sl, device=xq.device)
- freqs_cis = self._compute_rotary_embed_table(freqs_cis)
- freqs_cis = self._replicate(freqs_cis)
+ freqs_cis = torch.arange(sl, device=xt.device) + start_index
+ freqs_cis = self._compute_rotary_embed_table(freqs_cis)[None, :, None, :]
- assert freqs_cis.shape[-1] == dim
assert (
- freqs_cis.shape[0] >= sl
+ freqs_cis.shape[1] >= sl
), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})"
- broadcast_freqs_cis = freqs_cis[None, 0:sl, None, :]
+ xt_ = ops.view_as_complex(xt_)
+ xt_ = xt_ * freqs_cis
+ xt_out = ops.view_as_real(xt_)
if self.use_hf:
- xq_out = torch.view_as_real(
- self.complex_multiply(xq_, broadcast_freqs_cis)
- ).flatten(3)
- xk_out = torch.view_as_real(
- self.complex_multiply(xk_, broadcast_freqs_cis)
- ).flatten(3)
-
- xq_out = xq_out[..., create_ordering_tensor(xq_out.shape[-1])]
- xk_out = xk_out[..., create_ordering_tensor(xq_out.shape[-1])]
-
- return xq_out.type_as(xq), xk_out.type_as(xk)
+ xt_out = xt_out[..., create_ordering_tensor(xt_out.shape[-1])]
- xq_out = torch.view_as_real(xq_ * broadcast_freqs_cis).flatten(3)
- xk_out = torch.view_as_real(xk_ * broadcast_freqs_cis).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
-
- def complex_multiply(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
- """Function for elementwise-multiplication of two complex torch tensors.
- Functionally similar to a*b, but numerically accurate for HuggingFace
- LLaMa implementation.
-
- Args:
- a: First torch tensor operand
- b: Second torch tensor operand
- Returns:
- Tensor of same size to a, b whose elements is product of corresponding
- elements in a, b
- """
- return torch.complex(
- a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real
- )
+ return ops.to(xt_out, xt.dtype)
def compute_batch_mask(
self, start_positions: Union[torch.Tensor, ReplicatedTensor], batch_seq_len: int
@@ -227,8 +185,15 @@ def compute_batch_mask(
freqs_cis = self.rotary_embed_table[positions_seq]
else:
shape = positions_seq.shape
- freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())
- freqs_cis = freqs_cis.unflatten(0, shape)
+ if isinstance(positions_seq, ReplicatedTensor):
+ ts = [
+ self._compute_rotary_embed_table(s.flatten()).unflatten(0, shape)
+ for s in positions_seq.shards
+ ]
+ freqs_cis = ReplicatedTensor(ts=ts)
+ else:
+ freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())
+ freqs_cis = freqs_cis.unflatten(0, shape)
# Unsqueeze a unit dim for attention heads.
broadcast_freqs_cis = freqs_cis.unsqueeze(2)
@@ -237,41 +202,24 @@ def compute_batch_mask(
def apply_batched_mask(
self,
*,
- xq: Union[torch.Tensor, SplitPrimitiveTensor],
- xk: Union[torch.Tensor, SplitPrimitiveTensor],
+ xt: Union[torch.Tensor, SplitPrimitiveTensor],
mask: Union[torch.Tensor, ReplicatedTensor],
):
- if isinstance(xq, SplitPrimitiveTensor):
- assert (
- isinstance(xk, SplitPrimitiveTensor)
- and xq.shard_count == xk.shard_count
- and xk.shard_dim == xq.shard_dim
+ if not isinstance(xt, SplitPrimitiveTensor):
+ return self.apply_batched_mask_unsharded(xt=xt, mask=mask)
+
+ assert isinstance(mask, ReplicatedTensor) and mask.shard_count == xt.shard_count
+ xt_shards = [
+ self.apply_batched_mask_unsharded(
+ xt=unbox_tensor(xt_shard),
+ mask=unbox_tensor(mask_shard),
)
- assert (
- isinstance(mask, ReplicatedTensor)
- and mask.shard_count == xq.shard_count
- )
- xqk_shards = [
- self.apply_batched_mask_unsharded(
- xq=unbox_tensor(xq_shard),
- xk=unbox_tensor(xk_shard),
- mask=unbox_tensor(mask_shard),
- )
- for xq_shard, xk_shard, mask_shard in zip(
- xq.shards, xk.shards, mask.shards
- )
- ]
- xq_shards = [xqk[0] for xqk in xqk_shards]
- xk_shards = [xqk[1] for xqk in xqk_shards]
- xq = SplitPrimitiveTensor(ts=xq_shards, shard_dim=xq.shard_dim)
- xk = SplitPrimitiveTensor(ts=xk_shards, shard_dim=xk.shard_dim)
- return xq, xk
- else:
- return self.apply_batched_mask_unsharded(xq=xq, xk=xk, mask=mask)
+ for xt_shard, mask_shard in zip(xt.shards, mask.shards)
+ ]
+ xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
+ return xt
- def apply_batched_mask_unsharded(
- self, *, xq: torch.Tensor, xk: torch.Tensor, mask: torch.Tensor
- ):
+ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
"""Applies the embedding to a ragged batch of queries and keys.
This does a more complicated indexing operation for cases when the each
@@ -281,29 +229,23 @@ def apply_batched_mask_unsharded(
"""
# xq_, xk_ shape: bs, sl, _, dim
# freqs_cis shape: max_sl, dim
- xq_ = torch.view_as_complex(xq.unflatten(-1, (-1, 2)))
- xk_ = torch.view_as_complex(xk.unflatten(-1, (-1, 2)))
- _, sl, _, dim = xq_.shape
+ xt_ = ops.view_as_complex(xt)
+ xt_ = xt_ * mask
+ xt_out = ops.view_as_real(xt_)
- xq_out = torch.view_as_real(xq_ * mask).flatten(3)
- xk_out = torch.view_as_real(xk_ * mask).flatten(3)
- return xq_out.type_as(xq), xk_out.type_as(xk)
+ return xt_out.type_as(xt)
def _compute_rotary_embed_table(self, t):
dim = self.rope_dimension_count
freqs = 1.0 / (
- self.rope_freq_base
- ** (torch.arange(0, dim, 2, device=t.device)[: (dim // 2)].float() / dim)
+ self.rope_freq_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
)
freqs = torch.outer(t, freqs).float()
- freqs_cis = (
- torch.complex(torch.cos(freqs), torch.sin(freqs))
- if self.use_hf
- else torch.polar(torch.ones_like(freqs), freqs)
- )
-
- return freqs_cis
+ cos = torch.cos(freqs)
+ sin = torch.sin(freqs)
+ complex = torch.complex(cos, sin)
+ return complex
def _create_rotary_embed_table(self):
t = torch.arange(self.max_seqlen, device=self.device)
diff --git a/sharktank/sharktank/models/llama/llama.py b/sharktank/sharktank/models/llama/llama.py
index ef3c4800d..0a9a6f1c3 100644
--- a/sharktank/sharktank/models/llama/llama.py
+++ b/sharktank/sharktank/models/llama/llama.py
@@ -71,6 +71,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
device=config.device,
activation_dtype=config.activation_dtype,
attention_dtype=config.attention_dtype,
+ fake_quant=config.fake_quant,
)
self.config = config
self.hp = hp
@@ -113,6 +114,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
head_count_kv=hp.attention_head_count_kv,
rms_epsilon=hp.attention_layer_norm_rms_epsilon,
attention_kernel=self.attention_kernel,
+ fake_quant=self.fake_quant,
)
for n in range(hp.block_count)
]
@@ -186,29 +188,6 @@ def decode(
self._assert_device(start_positions)
self._assert_device(*cache_state, dtype=self.activation_dtype)
- if self.config.tensor_parallelism_size > 1:
- if not isinstance(tokens, ReplicatedTensor):
- tokens = ops.replicate(
- tokens, count=self.config.tensor_parallelism_size
- )
- if not isinstance(attention_mask, ReplicatedTensor):
- attention_mask = ops.replicate(
- attention_mask, count=self.config.tensor_parallelism_size
- )
- if not isinstance(start_positions, ReplicatedTensor):
- start_positions = ops.replicate(
- start_positions, count=self.config.tensor_parallelism_size
- )
- if not isinstance(seq_block_ids, ReplicatedTensor):
- seq_block_ids = ops.replicate(
- seq_block_ids, count=self.config.tensor_parallelism_size
- )
- # If the user provided unsharded arguments they probably want
- # an unsharded result as well.
- unshard_result = True
- else:
- unshard_result = False
-
bs, _ = tokens.shape
# Precompute a position based mask for computing rope embeddings
# as it is the same for all blocks.
@@ -307,6 +286,7 @@ def __init__(
head_count_kv: int,
rms_epsilon: float,
attention_kernel: str = "decomposed",
+ fake_quant: bool = True,
):
super().__init__(theta)
self.add_module(
@@ -320,6 +300,7 @@ def __init__(
head_count_kv=head_count_kv,
rms_epsilon=rms_epsilon,
attention_kernel=attention_kernel,
+ fake_quant=fake_quant,
),
)
self.add_module(
diff --git a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py
index 0d869932e..052593748 100644
--- a/sharktank/sharktank/models/llama/tools/import_quark_dataset.py
+++ b/sharktank/sharktank/models/llama/tools/import_quark_dataset.py
@@ -107,25 +107,21 @@ def apply_per_layer_quant(
layer_theta = root_theta(layer_name)
- weight_quant_scale = layer_theta.tensor("weight_quant_scale").as_torch()
+ weight_quant_scale = layer_theta.tensor("weight_scale").as_torch()
weight = layer_theta.tensor("weight").as_torch()
# It looks dumb but, this step is required for numerical correctness against quark.
- weight = weight.view(torch.float8_e4m3fn)
+ # weight = weight.view(torch.float8_e4m3fn)
weight = (weight.to(torch.float64) * weight_quant_scale).to(torch.float16)
- weight_quant_zero_point = layer_theta.optional_tensor("weight_quant_zero_point")
+ weight_quant_zero_point = layer_theta.optional_tensor("weight_zero_point")
if weight_quant_zero_point == None:
weight_quant_zero_point = torch.zeros(1, dtype=torch.float32)
else:
weight_quant_zero_point = weight_quant_zero_point.as_torch()
- input_quant_scale = as_torch_or_none(
- layer_theta.optional_tensor("input_quant_scale")
- )
- output_quant_scale = as_torch_or_none(
- layer_theta.optional_tensor("output_quant_scale")
- )
+ input_quant_scale = as_torch_or_none(layer_theta.optional_tensor("input_scale"))
+ output_quant_scale = as_torch_or_none(layer_theta.optional_tensor("output_scale"))
if weight_quant_scale is None:
print("weight quant scale not found for layer ", layer_name)
@@ -190,11 +186,11 @@ def quantize_weight(
reciprocal_scale=output_quant_scale * 2.0,
dtype=torch.float8_e4m3fnuz,
)
- names = [f"{i}.qdq_input" for i in [q_name, k_name, v_name]]
+ names = [f"{i}.q_input" for i in [q_name, k_name, v_name]]
for name in names:
updated_tensors[name] = StaticScaledQuantizer(
name=name,
- scale=1.0 / input_quant_scale * 2.0,
+ scale=1.0 / (input_quant_scale * 2.0),
reciprocal_scale=input_quant_scale * 2.0,
dtype=torch.float8_e4m3fnuz,
)
@@ -214,18 +210,18 @@ def quantize_weight(
)
# we explicitly provide the reciprocal scale because converting from float16 to float8 after doing 1/scale results in significant numerical differences
if input_quant_scale is not None:
- updated_tensors[new_layer_name + ".qdq_input"] = StaticScaledQuantizer(
- name=new_layer_name + ".qdq_input",
- scale=1.0 / input_quant_scale,
- reciprocal_scale=input_quant_scale,
- dtype=torch.float8_e4m3fn,
+ updated_tensors[new_layer_name + ".q_input"] = StaticScaledQuantizer(
+ name=new_layer_name + ".q_input",
+ scale=1.0 / (input_quant_scale * 2.0),
+ reciprocal_scale=input_quant_scale * 2.0,
+ dtype=torch.float8_e4m3fnuz,
)
if output_quant_scale is not None:
updated_tensors[new_layer_name + ".qdq_output"] = StaticScaledQuantizer(
name=new_layer_name + ".qdq_output",
scale=1.0 / output_quant_scale,
reciprocal_scale=output_quant_scale,
- dtype=torch.float8_e4m3fn,
+ dtype=torch.float8_e4m3fnuz,
)
# Remove the updated tensor from the original tree.
@@ -261,15 +257,15 @@ def update_norm_layer(
sub_name = layer_name + "." + sub
new_name = hf_to_gguf(sub_name) + ".weight"
single_replace(quant_theta, sub_name, new_name, updated_tensors)
- kv_cache_scale = (
- quant_theta(layer_name).tensor("kv_cache_scaling_factor").as_torch()
- )
+ kv_cache_scale = quant_theta(layer_name, "self_attn").tensor("kv_scale").as_torch()
layer_idx = layer_name.split(".")[-1]
new_name = f"blk.{layer_idx}.kv_cache"
- kv_cache_scale = DefaultPrimitiveTensor(
- name=new_name + ".kv_cache_scaling_factor", data=kv_cache_scale
+ updated_tensors[new_name] = StaticScaledQuantizer(
+ name=new_name + ".quantizer",
+ scale=1.0 / (kv_cache_scale * 2.0),
+ reciprocal_scale=kv_cache_scale * 2.0,
+ dtype=torch.float8_e4m3fnuz,
)
- updated_tensors[new_name] = kv_cache_scale
def single_replace(
@@ -279,6 +275,8 @@ def single_replace(
updated_tensors: dict[str, InferenceTensor],
):
data = quant_theta(layer_name).tensor("weight").as_torch()
+ if data.dtype == torch.bfloat16:
+ data = data.to(torch.float32)
updated_tensors[gguf_name] = DefaultPrimitiveTensor(name=gguf_name, data=data)
@@ -330,7 +328,9 @@ def main(argv):
"mlp.down_proj",
"mlp.up_proj",
"self_attn.o_proj",
- "self_attn.qkv",
+ "self_attn.q_proj",
+ "self_attn.k_proj",
+ "self_attn.v_proj",
]
for layer in model_layers:
for sub in sub_layers:
diff --git a/sharktank/sharktank/ops/custom_impls.py b/sharktank/sharktank/ops/custom_impls.py
index c5079f6d4..9acc7c562 100644
--- a/sharktank/sharktank/ops/custom_impls.py
+++ b/sharktank/sharktank/ops/custom_impls.py
@@ -5,21 +5,24 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import torch
+
from torch import Tensor, dtype
+from typing import Union
+
import torch.nn.functional as F
from ..kernels import (
einsum_2args_q4,
mmt_block_scaled_offset_q4_unsigned,
mmt_block_scaled_q8,
- mmtfp,
mmt_super_block_scaled_offset_q4_unsigned,
+ bitcast_to_complex,
+ bitcast_to_real,
)
from ..types import (
BlockScaledLayout,
BlockScaledI4Layout,
- InferenceTensor,
PrimitiveTensor,
QuantizedTensor,
SuperBlockOffsetScaled_4_6_Layout,
@@ -30,7 +33,7 @@
# Fused FP matmul.
-# Disabled: See https://github.com/nod-ai/SHARK-Platform/issues/44
+# Disabled: See https://github.com/nod-ai/shark-ai/issues/44
# @matmul.override(Tensor, Tensor)
# def matmul_mmtfp_tensor_tensor(lhs, rhs, *, transpose_rhs: bool):
# lhs = unbox_tensor(lhs)
@@ -123,3 +126,15 @@ def matmul_generic_tensor_super_block_offset_scaled_4_6_i4(
sb_mins_low,
rhs_unpacked.qs_bit_packed,
)
+
+
+@view_as_complex.override(Union[Tensor, PrimitiveTensor])
+def view_as_complex(t):
+ t = unbox_tensor(t)
+ return bitcast_to_complex(t)
+
+
+@view_as_real.override(Union[Tensor, PrimitiveTensor])
+def view_as_real(t):
+ t = unbox_tensor(t)
+ return bitcast_to_real(t)
diff --git a/sharktank/sharktank/ops/default_impls.py b/sharktank/sharktank/ops/default_impls.py
index 08a9c896b..b155fdaa3 100644
--- a/sharktank/sharktank/ops/default_impls.py
+++ b/sharktank/sharktank/ops/default_impls.py
@@ -355,7 +355,6 @@ def matmul_default(lhs, rhs, *, transpose_rhs: bool) -> Tensor:
rhs = unbox_tensor(rhs)
if transpose_rhs:
rhs = rhs.mT
-
rhs = rhs.to(lhs.dtype)
if len(lhs.shape) > 2 and len(rhs.shape) < 3:
@@ -503,3 +502,13 @@ def view_QuantizedTensor(tensor: QuantizedTensor, shape):
new_m = unpacked.m.view(shape[:-1] + [shape[-1] // 32, 1])
layout = BlockScaledI4Layout(shape=shape, d=new_d, qs=new_qs, m=new_m)
return PlanarQuantizedTensor(shape=shape, layout=layout)
+
+
+@view_as_complex.override(Tensor)
+def view_as_complex_default(tensor: Union[Tensor, PrimitiveTensor]) -> Tensor:
+ return torch.view_as_complex(unbox_tensor(tensor))
+
+
+@view_as_real.override(Tensor)
+def view_as_real_default(tensor: Union[Tensor, PrimitiveTensor]) -> Tensor:
+ return torch.view_as_real(unbox_tensor(tensor))
diff --git a/sharktank/sharktank/ops/qlinear_impls.py b/sharktank/sharktank/ops/qlinear_impls.py
index f4f7ac0ca..b66d3be1d 100644
--- a/sharktank/sharktank/ops/qlinear_impls.py
+++ b/sharktank/sharktank/ops/qlinear_impls.py
@@ -50,10 +50,12 @@ def qlinear_tensor_scaled(
# Handle only integer and fp8 quantizations.
if x_layout.qs.dtype.is_floating_point or weight_layout.qs.dtype.is_floating_point:
- if (
- x_layout.qs.dtype != torch.float8_e4m3fnuz
- or weight_layout.qs.dtype != torch.float8_e4m3fnuz
- ):
+ if x_layout.qs.dtype == torch.float8_e4m3fnuz:
+ # assume quark
+ return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True).to(
+ torch.float16
+ )
+ else:
return NotImplemented
# Bias.
diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py
index 87592c6fd..4aa473e08 100644
--- a/sharktank/sharktank/ops/sharded_impls.py
+++ b/sharktank/sharktank/ops/sharded_impls.py
@@ -1303,3 +1303,27 @@ def view_split(tensor: SplitPrimitiveTensor, shape: List[int]) -> SplitPrimitive
res = SplitPrimitiveTensor(shard_dim=shard_dim, ts=shards)
assert math.prod(res.shape) == math.prod(tensor.shape)
return res
+
+
+@view_as_complex.override(SplitPrimitiveTensor)
+def view_as_complex_split(tensor: SplitPrimitiveTensor) -> SplitPrimitiveTensor:
+ shards = [view_as_complex(shard) for shard in tensor.shards]
+ return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim)
+
+
+@view_as_complex.override(ReplicatedTensor)
+def view_as_complex_rep(tensor: ReplicatedTensor) -> ReplicatedTensor:
+ shards = [view_as_complex(shard) for shard in tensor.shards]
+ return ReplicatedTensor(ts=shards)
+
+
+@view_as_real.override(SplitPrimitiveTensor)
+def view_as_real_split(tensor: SplitPrimitiveTensor) -> SplitPrimitiveTensor:
+ shards = [view_as_real(shard) for shard in tensor.shards]
+ return SplitPrimitiveTensor(ts=shards, shard_dim=tensor.shard_dim)
+
+
+@view_as_real.override(ReplicatedTensor)
+def view_as_real_rep(tensor: ReplicatedTensor) -> ReplicatedTensor:
+ shards = [view_as_real(shard) for shard in tensor.shards]
+ return ReplicatedTensor(ts=shards)
diff --git a/sharktank/sharktank/ops/signatures.py b/sharktank/sharktank/ops/signatures.py
index d9002ce37..762b99896 100644
--- a/sharktank/sharktank/ops/signatures.py
+++ b/sharktank/sharktank/ops/signatures.py
@@ -59,6 +59,8 @@
"unshard",
"unsqueeze",
"view",
+ "view_as_complex",
+ "view_as_real",
]
IntOrSequenceInt = Union[int, Sequence[int]]
@@ -1087,3 +1089,37 @@ def _view_trampoline(
return override, result
else:
d.fail(tensors)
+
+
+@overridable
+def view_as_complex(tensor: AnyTensor, shape: List[int]) -> AnyTensor:
+ """See torch.Tensor.view_as_complex"""
+ ...
+
+
+@view_as_complex.trampoline
+def _view_as_complex_trampoline(d: SignatureDispatcher, tensor: AnyTensor) -> AnyTensor:
+ tensors = (tensor,)
+ for override in d.find_overrides(tensors):
+ result = override(tensor)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
+
+
+@overridable
+def view_as_real(tensor: AnyTensor, shape: List[int]) -> AnyTensor:
+ """See torch.Tensor.view_as_complex"""
+ ...
+
+
+@view_as_real.trampoline
+def _view_as_real_trampoline(d: SignatureDispatcher, tensor: AnyTensor) -> AnyTensor:
+ tensors = (tensor,)
+ for override in d.find_overrides(tensors):
+ result = override(tensor)
+ if result is not NotImplemented:
+ return override, result
+ else:
+ d.fail(tensors)
diff --git a/sharktank/sharktank/types/quantizers.py b/sharktank/sharktank/types/quantizers.py
index 575c969de..d3c093b85 100644
--- a/sharktank/sharktank/types/quantizers.py
+++ b/sharktank/sharktank/types/quantizers.py
@@ -131,6 +131,25 @@ def __init__(
else:
assert len(self._scale.shape) == 0, "Expected per-tensor scale to be 0D"
+ def dequantize_raw_tensor(
+ self, t: torch.Tensor, to: torch.dtype, *, name: str
+ ) -> torch.Tensor:
+ return (
+ PlanarQuantizedTensor(
+ shape=t.shape,
+ name=t.name,
+ layout=TensorScaledLayout(
+ shape=t.shape,
+ d=self._reciprocal_scale,
+ qs=t,
+ m=self.offset,
+ dtype=to,
+ ),
+ )
+ .unpack()
+ .dequant()
+ )
+
def _quantize_raw_tensor(self, t: torch.Tensor, *, name: str) -> QuantizedTensor:
"""Performs a quantizing transformation on t, returning a QuantizeTensor."""
shape = list(t.shape)
diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py
index df8df075b..f870aa101 100644
--- a/sharktank/sharktank/types/tensors.py
+++ b/sharktank/sharktank/types/tensors.py
@@ -543,9 +543,10 @@ def _clone_with_globals(
) -> "InferenceTensor":
return DefaultPrimitiveTensor(name=self.name, data=new_globals[self.name])
- def __getitem__(self, keys):
- if not isinstance(keys, list) and not isinstance(keys, tuple):
- keys = [keys]
+ def __getitem__(self, key):
+ keys = [key]
+ if isinstance(key, tuple) or isinstance(key, list):
+ keys = key
keys = [
unbox_tensor(key) if isinstance(key, PrimitiveTensor) else key
@@ -1188,15 +1189,19 @@ def create(
raise IOError(f"Missing component tensor '' in {raw_tensors.keys()}") from e
return cls(name=name, ts=ts)
- def __getitem__(self, keys):
- if not isinstance(keys, list) and not isinstance(keys, tuple):
- keys = [keys]
+ def __getitem__(self, key):
+ keys = [key]
+ if isinstance(key, tuple) or isinstance(key, list):
+ keys = key
shards = []
for i, shard in enumerate(self.shards):
- shard_keys = [
- k.shards[i] if isinstance(k, ReplicatedTensor) else k for k in keys
- ]
+ shard_keys = []
+ for k in keys:
+ if isinstance(k, ReplicatedTensor):
+ shard_keys.append(k.shards[i])
+ else:
+ shard_keys.append(k)
shards.append(shard[*shard_keys])
return ReplicatedTensor(ts=shards)
diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py
index 396c74363..84ee741bf 100644
--- a/sharktank/sharktank/utils/cli.py
+++ b/sharktank/sharktank/utils/cli.py
@@ -61,6 +61,24 @@ def add_output_dataset_options(parser: argparse.ArgumentParser):
)
+def add_model_options(parser: argparse.ArgumentParser):
+ """Adds model config options not exclusive to export or eager"""
+ parser.add_argument(
+ "--attention-kernel",
+ type=str,
+ default="decomposed",
+ choices=["decomposed", "torch"],
+ )
+
+
+def add_quantization_options(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--fake-quant",
+ action=argparse.BooleanOptionalAction,
+ help="whether or not to run/export the model in fake quant mode. Note, running eagerly without fake quant is dependent on torch types supporting operations. YMMV",
+ )
+
+
def add_tokenizer_options(parser: argparse.ArgumentParser):
"""Adds options for specifying a tokenizer.
diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py
index 057d3b664..9deade56c 100644
--- a/sharktank/sharktank/utils/export_artifacts.py
+++ b/sharktank/sharktank/utils/export_artifacts.py
@@ -25,7 +25,7 @@
class ExportMlirException(Exception):
- """SHARK-Platform export MLIR exception that preserves the command line and error output."""
+ """shark-ai export MLIR exception that preserves the command line and error output."""
def __init__(self, process: subprocess.CompletedProcess, cwd: str):
try:
diff --git a/sharktank/tests/layers/kv_cache_test.py b/sharktank/tests/layers/kv_cache_test.py
index ab66dd97b..65b42c986 100644
--- a/sharktank/tests/layers/kv_cache_test.py
+++ b/sharktank/tests/layers/kv_cache_test.py
@@ -56,7 +56,7 @@ def test_direct():
]
read_back = cache.read(
allocation,
- dest_partitions=read_empty,
+ read_into_partitions=read_empty,
transformer_block_index=1,
seq_len=write_seq_length,
)
@@ -79,7 +79,7 @@ def test_direct():
]
read_ones = cache.read(
allocation,
- dest_partitions=read_ones,
+ read_into_partitions=read_ones,
transformer_block_index=i,
seq_len=write_seq_length,
)
@@ -113,7 +113,7 @@ def test_direct():
]
read_back = cache.read(
allocation,
- dest_partitions=read_empty,
+ read_into_partitions=read_empty,
transformer_block_index=1,
seq_len=write_seq_length + 1,
)
@@ -184,7 +184,7 @@ def test_sharded_direct():
]
read_back = cache.read(
allocation,
- dest_partitions=read_empty,
+ read_into_partitions=read_empty,
transformer_block_index=1,
seq_len=write_seq_length,
)
@@ -225,7 +225,7 @@ def test_sharded_direct():
]
read_back = cache.read(
allocation,
- dest_partitions=read_empty,
+ read_into_partitions=read_empty,
transformer_block_index=1,
seq_len=write_seq_length + 1,
)
@@ -288,7 +288,7 @@ def test_paged():
]
read_back = cache.read(
allocation,
- dest_partitions=read_empty,
+ read_into_partitions=read_empty,
transformer_block_index=1,
seq_len=write_seq_length,
page_ids=write_page_ids,
@@ -312,7 +312,7 @@ def test_paged():
]
read_ones = cache.read(
allocation,
- dest_partitions=read_ones,
+ read_into_partitions=read_ones,
transformer_block_index=i,
seq_len=write_seq_length,
page_ids=write_page_ids,
@@ -348,7 +348,7 @@ def test_paged():
]
read_back = cache.read(
allocation,
- dest_partitions=read_empty,
+ read_into_partitions=read_empty,
transformer_block_index=1,
seq_len=write_seq_length + 1,
page_ids=page_ids,
@@ -436,7 +436,7 @@ def test_sharded_paged():
read_back = cache.read(
allocation,
- dest_partitions=read_empty,
+ read_into_partitions=read_empty,
transformer_block_index=1,
seq_len=write_seq_length,
page_ids=write_page_ids,
@@ -489,7 +489,7 @@ def test_sharded_paged():
read_back = cache.read(
allocation,
- dest_partitions=[empty_k, empty_v],
+ read_into_partitions=[empty_k, empty_v],
transformer_block_index=1,
seq_len=write_seq_length + 1,
page_ids=page_ids,
diff --git a/sharktank/tests/layers/linear_test.py b/sharktank/tests/layers/linear_test.py
index e2d038f72..ad657889d 100644
--- a/sharktank/tests/layers/linear_test.py
+++ b/sharktank/tests/layers/linear_test.py
@@ -84,7 +84,7 @@ def testNativeQuant_SymPerTensor_AsymPerAxis0_Dynamic(self):
bias_quant,
]
)
- linear = LinearLayer(theta)
+ linear = LinearLayer(theta, fake_quant=False)
output = linear(lhs)
output_ref = torch.matmul(lhs, rhs.T) + bias
diff --git a/sharktank/tests/layers/sharded_paged_kv_cache_test.py b/sharktank/tests/layers/sharded_paged_kv_cache_test.py
index 766c4e804..d7b6a0b33 100644
--- a/sharktank/tests/layers/sharded_paged_kv_cache_test.py
+++ b/sharktank/tests/layers/sharded_paged_kv_cache_test.py
@@ -104,7 +104,7 @@ def testRead(self):
sharded_cache_state,
) = self.make_unsharded_and_sharded_equal_cache_states()
- dest_partitions_snapshot = [
+ read_into_partitions_snapshot = [
torch.rand(
self.batch_size,
self.block_seq_len * self.block_seq_stride,
@@ -113,33 +113,35 @@ def testRead(self):
)
for _ in range(self.cache_partition_count)
]
- dest_partitions = deepcopy(dest_partitions_snapshot)
+ read_into_partitions = deepcopy(read_into_partitions_snapshot)
transformer_block_index = 1
page_ids = torch.randint(
low=0, high=self.page_count, size=[self.batch_size, self.block_seq_len]
).reshape([self.batch_size, self.block_seq_len])
self.cache.read(
state=cache_state,
- dest_partitions=dest_partitions,
+ read_into_partitions=read_into_partitions,
transformer_block_index=transformer_block_index,
page_ids=page_ids,
seq_len=self.block_seq_len * self.block_seq_stride,
)
- sharded_dest_partitions = deepcopy(
+ sharded_read_into_partitions = deepcopy(
[
ops.reshard_split(t, dim=2, count=self.shard_count)
- for t in dest_partitions_snapshot
+ for t in read_into_partitions_snapshot
]
)
sharded_page_ids = ops.replicate(page_ids, count=self.shard_count)
self.sharded_cache.read(
state=sharded_cache_state,
- dest_partitions=sharded_dest_partitions,
+ read_into_partitions=sharded_read_into_partitions,
transformer_block_index=transformer_block_index,
page_ids=sharded_page_ids,
seq_len=self.block_seq_len * self.block_seq_stride,
)
- for unsharded, sharded in zip(dest_partitions, sharded_dest_partitions):
+ for unsharded, sharded in zip(
+ read_into_partitions, sharded_read_into_partitions
+ ):
assert ops.equal(unsharded, ops.unshard(sharded))
def testWriteTimestep(self):
diff --git a/sharktank/tests/layers/sharded_rotary_embedding_test.py b/sharktank/tests/layers/sharded_rotary_embedding_test.py
index 963b9b432..f24b8313a 100644
--- a/sharktank/tests/layers/sharded_rotary_embedding_test.py
+++ b/sharktank/tests/layers/sharded_rotary_embedding_test.py
@@ -35,7 +35,8 @@ def test_sharded_rotary_table():
max_seqlen=max_seqlen,
rope_freq_base=rope_freq_base,
)
- oq, ok = default_layer(xq=xq, xk=xk, start_index=0)
+ oq = default_layer(xt=xq, start_index=0)
+ ok = default_layer(xt=xk, start_index=0)
# Then we can shard the same inputs and layer
xq = SplitPrimitiveTensor(ts=xq, shard_dim=2, shard_count=4)
@@ -46,7 +47,8 @@ def test_sharded_rotary_table():
rope_freq_base=rope_freq_base,
tensor_parallelism_size=4,
)
- sq, sk = shard_layer(xq=xq, xk=xk, start_index=0)
+ sq = shard_layer(xt=xq, start_index=0)
+ sk = shard_layer(xt=xk, start_index=0)
# Gathering and unboxing should yield the same results
sq = ops.unshard(sq)
diff --git a/sharktank/tests/models/llama/attention_test.py b/sharktank/tests/models/llama/attention_test.py
index daeefd93b..211fab5a0 100644
--- a/sharktank/tests/models/llama/attention_test.py
+++ b/sharktank/tests/models/llama/attention_test.py
@@ -59,7 +59,7 @@ def test(self):
head_dim=head_dim,
head_count_kv=head_count_kv,
rms_epsilon=rms_epsilon,
- attention_kernel="torch",
+ attention_kernel="decomposed",
)
attention_embedding = RotaryEmbeddingLayer(
rope_dimension_count=rope_dimension_count,
diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
index adbfeaf7e..f70607832 100644
--- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py
+++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
@@ -255,7 +255,7 @@ def testBenchmark8B_fp8_Decomposed(self):
)
@pytest.mark.xfail(
- reason="Test not yet implemented", strict=True, raises=ExportMlirException
+ reason="Compile failure", strict=True, raises=ExportMlirException
)
def testBenchmark8B_fp8_Non_Decomposed(self):
output_file_name = self.dir_path_8b / "fp8_torch"
diff --git a/sharktank/tests/ops/ops_test.py b/sharktank/tests/ops/ops_test.py
index 8b37525e5..ad6759ce6 100644
--- a/sharktank/tests/ops/ops_test.py
+++ b/sharktank/tests/ops/ops_test.py
@@ -136,7 +136,7 @@ def testMatchFail(self):
):
ops.matmul(1, 2)
- @unittest.skip("https://github.com/nod-ai/SHARK-Platform/issues/44")
+ @unittest.skip("https://github.com/nod-ai/shark-ai/issues/44")
def testTorchImplTransposedRHS(self):
ops._registry._test_enable_last_op_dispatch(True)
t1 = torch.rand(32, 16, dtype=torch.float32)
@@ -149,7 +149,7 @@ def testTorchImplTransposedRHS(self):
ops.custom_impls.matmul_mmtfp_tensor_tensor,
)
- @unittest.skip("https://github.com/nod-ai/SHARK-Platform/issues/44")
+ @unittest.skip("https://github.com/nod-ai/shark-ai/issues/44")
def testTorchImplNonTransposedRHS(self):
ops._registry._test_enable_last_op_dispatch(True)
t1 = torch.rand(32, 16, dtype=torch.float32)
@@ -162,7 +162,7 @@ def testTorchImplNonTransposedRHS(self):
ops.custom_impls.matmul_mmtfp_tensor_tensor,
)
- @unittest.skip("https://github.com/nod-ai/SHARK-Platform/issues/44")
+ @unittest.skip("https://github.com/nod-ai/shark-ai/issues/44")
def testTorchImplTransposedPrimitiveRHS(self):
ops._registry._test_enable_last_op_dispatch(True)
t1 = torch.rand(32, 16, dtype=torch.float32)
diff --git a/sharktank/version.json b/sharktank/version.json
new file mode 100644
index 000000000..f09f61d2a
--- /dev/null
+++ b/sharktank/version.json
@@ -0,0 +1,3 @@
+{
+ "package-version": "2.9.2.dev"
+}
diff --git a/sharktank/version_info.json b/sharktank/version_info.json
deleted file mode 100644
index ca3c0ed0b..000000000
--- a/sharktank/version_info.json
+++ /dev/null
@@ -1,3 +0,0 @@
-{
- "package-version": "2.9.0.dev"
-}
diff --git a/shortfin/CMakeLists.txt b/shortfin/CMakeLists.txt
index 85113ce00..93ee63594 100644
--- a/shortfin/CMakeLists.txt
+++ b/shortfin/CMakeLists.txt
@@ -14,7 +14,7 @@ if(CMAKE_SOURCE_DIR STREQUAL CMAKE_BINARY_DIR)
endif()
# Get version number from file
-file(READ ${CMAKE_CURRENT_SOURCE_DIR}/version_info.json VERSION_JSON_STRING)
+file(READ ${CMAKE_CURRENT_SOURCE_DIR}/version.json VERSION_JSON_STRING)
string(JSON PACKAGE_VERSION GET ${VERSION_JSON_STRING} package-version)
string(REGEX MATCH "(0|[1-9][0-9]*)(\.(0|[1-9][0-9]*))*" BASE_VERSION ${PACKAGE_VERSION})
@@ -40,7 +40,7 @@ if(NOT WIN32)
endif()
# Pins
-set(SHORTFIN_IREE_GIT_TAG "iree-2.9.0rc20241108")
+set(SHORTFIN_IREE_GIT_TAG "iree-3.0.0rc20241115")
# build options
option(SHORTFIN_BUILD_PYTHON_BINDINGS "Builds Python Bindings" OFF)
diff --git a/shortfin/README.md b/shortfin/README.md
index 3e7901342..6269ca702 100644
--- a/shortfin/README.md
+++ b/shortfin/README.md
@@ -1,20 +1,52 @@
-# shortfin - SHARK C++ inference library
+# shortfin - SHARK inference library and serving engine
-## Simple User Installation
+The shortfin project is SHARK's open source, high performance inference library
+and serving engine. Shortfin consists of these major components:
-Install:
+* The "libshortfin" inference library written in C/C++ and built on
+ [IREE](https://github.com/iree-org/iree)
+* Python bindings for the underlying inference library
+* Example applications in
+ ['shortfin_apps'](https://github.com/nod-ai/shark-ai/tree/main/shortfin/python/shortfin_apps)
+ built using the python bindings
+## Prerequisites
+
+* Python 3.11+
+
+## Simple user installation
+
+Install the latest stable version:
+
+```bash
+pip install shortfin
```
-python -m pip install .
+
+## Developer guides
+
+### Quick start: install local packages and run tests
+
+After cloning this repository, from the `shortfin/` directory:
+
+```bash
+pip install -e .
```
-Run tests:
+Install test requirements:
+```bash
+pip install -r requirements-tests.txt
```
-python -m pytest -s tests/
+
+Run tests:
+
+```bash
+pytest -s tests/
```
-## Simple Dev Setup
+### Simple dev setup
+
+We recommend this development setup for core contributors:
1. Check out this repository as a sibling to [IREE](https://github.com/iree-org/iree)
if you already have an IREE source checkout. Otherwise, a pinned version will
@@ -36,7 +68,7 @@ python -m pytest -s tests/
Refer to the advanced build options below for other scenarios.
-## Advanced Build Options
+### Advanced build options
1. Native C++ build
2. Local Python release build
@@ -48,7 +80,7 @@ Prerequisites
* A modern C/C++ compiler, such as clang 18 or gcc 12
* A modern Python, such as Python 3.12
-### Native C++ Builds
+#### Native C++ builds
```bash
cmake -GNinja -S. -Bbuild \
@@ -61,13 +93,7 @@ If Python bindings are enabled in this mode (`-DSHORTFIN_BUILD_PYTHON_BINDINGS=O
then `pip install -e build/` will install from the build dir (and support
build/continue).
-### Local Python Release Builds
-
-```bash
-pip install -v -e .
-```
-
-### Package Python Release Builds
+#### Package Python release builds
* To build wheels for Linux using a manylinux Docker container:
@@ -86,7 +112,7 @@ pip install -v -e .
python3 -m pip install dist/*.whl
```
-### Python Dev Builds
+#### Python dev builds
```bash
# Install build system pre-reqs (since we are building in dev mode, this
@@ -124,7 +150,7 @@ Several optional environment variables can be used with setup.py:
* `SHORTFIN_RUN_CTESTS=ON` : Runs `ctest` as part of the build. Useful for CI
as it uses the version of ctest installed in the pip venv.
-### Running Tests
+### Running tests
The project uses a combination of ctest for native C++ tests and pytest. Much
of the functionality is only tested via the Python tests, using the
@@ -156,7 +182,7 @@ pytest tests/ --system amdgpu \
--compile-flags="--iree-hal-target-backends=rocm --iree-hip-target=gfx1100"
```
-# Production Library Building
+## Production library building
In order to build a production library, additional build steps are typically
recommended:
@@ -167,23 +193,23 @@ recommended:
* Enable LTO builds of libshortfin
* Set flags to enable symbol versioning
-# Miscellaneous Build Topics
+## Miscellaneous build topics
-## Free-threaded Python
+### Free-threaded Python
Support for free-threaded Python builds (aka. "nogil") is in progress. It
-is currently being tested via dev builds of CPython 3.13 with the
-`--disable-gil` option set. There are multiple ways to acquire such an
-environment. If using `pyenv`, here is a way:
+is currently being tested via CPython 3.13 with the `--disable-gil` option set.
+There are multiple ways to acquire such an environment:
-```
-# Build a free-threaded 3.13 version.
-pyenv install --debug 3.13t-dev
+* Generally, see the documentation at
+
+* If using `pyenv`:
-# Test (should print "False").
-pyenv shell 3.13t-dev
-python -c 'import sys; print(sys._is_gil_enabled())'
-```
+ ```bash
+ # Install a free-threaded 3.13 version.
+ pyenv install 3.13t
-Further ways of installing a free-threaded CPython interpreter are documented at
-[py-free-threading.github.io](https://py-free-threading.github.io/installing_cpython/).
+ # Test (should print "False").
+ pyenv shell 3.13t
+ python -c 'import sys; print(sys._is_gil_enabled())'
+ ```
diff --git a/shortfin/build_tools/cmake/shortfin_library.cmake b/shortfin/build_tools/cmake/shortfin_library.cmake
index 23755fb9d..aaa97a6c1 100644
--- a/shortfin/build_tools/cmake/shortfin_library.cmake
+++ b/shortfin/build_tools/cmake/shortfin_library.cmake
@@ -80,7 +80,7 @@ function(shortfin_public_library)
PRIVATE ${_DYLIB_COMPONENTS}
)
set_target_properties("${_RULE_NAME}" PROPERTIES
- VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}
+ VERSION ${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}.${PROJECT_VERSION_PATCH}
SOVERSION ${SOVERSION}
)
endif()
diff --git a/shortfin/dev_me.py b/shortfin/dev_me.py
index ca6916767..8eacca274 100755
--- a/shortfin/dev_me.py
+++ b/shortfin/dev_me.py
@@ -31,8 +31,8 @@
# Otherwise, the shortfin build will download a pinned IREE source tree.
import argparse
+import importlib
import os
-from packaging.version import Version
from pathlib import Path
import re
import subprocess
@@ -40,10 +40,19 @@
import sys
import sysconfig
+try:
+ from packaging.version import Version
+except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(
+ f"'packaging' package not installed and required: Install with:\n"
+ f" {sys.executable} -m pip install packaging"
+ )
+
CMAKE_REQUIRED_VERSION = Version("3.29")
PYTHON_REQUIRED_VERSION = Version("3.12")
CLANG_REQUIRED_VERSION = Version("16")
+SETUPTOOLS_REQUIRED_VERSION = Version("61.0")
class EnvInfo:
@@ -58,6 +67,8 @@ def __init__(self, args):
self.ninja_exe = shutil.which("ninja")
self.clang_exe, self.clang_version = self.find_clang(args)
self.iree_dir = self.find_iree(args)
+ self.setuptools_version = self.find_package_version("setuptools")
+ self.wheel_version = self.find_package_version("wheel")
self.configured_dirs = []
self.add_configured(self.this_dir / "build" / "cmake" / "default")
@@ -94,12 +105,10 @@ def find_clang(self, args):
clang_exe = shutil.which("clang")
if not clang_exe:
return None, None
- try:
- clang_output = subprocess.check_output(
- [clang_exe, "--version"]
- ).decode()
- except:
- return None, None
+ try:
+ clang_output = subprocess.check_output([clang_exe, "--version"]).decode()
+ except:
+ return None, None
if m := re.search(r"clang version ([0-9\.]+)", clang_output):
return clang_exe, Version(m.group(1))
return None, None
@@ -116,6 +125,13 @@ def find_iree(self, args):
sys.exit(1)
return str(iree_dir)
+ def find_package_version(self, package_name: str) -> Version | None:
+ try:
+ m = importlib.import_module(package_name)
+ except ModuleNotFoundError:
+ return None
+ return Version(m.__version__)
+
def check_prereqs(self, args):
if self.cmake_version is None or self.cmake_version < CMAKE_REQUIRED_VERSION:
print(
@@ -131,7 +147,7 @@ def check_prereqs(self, args):
)
sys.exit(1)
if self.clang_exe and self.clang_version < CLANG_REQUIRED_VERSION:
- print(f"ERROR: clang version too old: {self.clang_exe}")
+ print(f"WARNING: clang version too old: {self.clang_exe}")
print(f" REQUIRED: {CLANG_REQUIRED_VERSION}, Found {self.clang_version}")
elif not self.clang_exe:
print(f"WARNING: Building the project with clang is highly recommended")
@@ -143,6 +159,19 @@ def check_prereqs(self, args):
)
sys.exit(1)
+ if (
+ self.setuptools_version is None
+ or self.setuptools_version < SETUPTOOLS_REQUIRED_VERSION
+ ):
+ print(
+ f"ERROR: 'setuptools' packaging is not installed or too old. "
+ f"Found {self.setuptools_version}, Need {SETUPTOOLS_REQUIRED_VERSION}"
+ )
+ sys.exit(1)
+ if self.wheel_version is None:
+ print(f"'wheel' package is not installed")
+ sys.exit(1)
+
def __repr__(self):
report = [
f"python: {self.python_exe}",
@@ -153,6 +182,8 @@ def __repr__(self):
f"ninja: {self.ninja_exe}",
f"clang: {self.clang_exe} ({self.clang_version})",
f"iree: {self.iree_dir}",
+ f"setuptools: {self.setuptools_version}",
+ f"wheel: {self.wheel_version}",
]
return "\n".join(report)
@@ -211,7 +242,7 @@ def configure_mode(env_info: EnvInfo, args):
"-e",
str(env_info.this_dir),
]
- print(f"{' '.join('='.join(kv) for kv in env_vars.items())} \\")
+ print(f"{' '.join('='.join(str(kv)) for kv in env_vars.items())} \\")
print(f" {' '.join(setup_args)}")
actual_env_vars = dict(os.environ)
actual_env_vars.update(env_vars)
diff --git a/shortfin/pyproject.toml b/shortfin/pyproject.toml
index eb54c835b..1abb49ef6 100644
--- a/shortfin/pyproject.toml
+++ b/shortfin/pyproject.toml
@@ -4,9 +4,46 @@ requires = [
"setuptools>=61.0",
"wheel",
"ninja",
+ 'typing-extensions ; python_version == "3.10" ',
]
build-backend = "setuptools.build_meta"
+[project]
+name = "shortfin"
+authors = [
+ {name = "SHARK Authors"},
+]
+description = "SHARK inference library and serving engine"
+readme = "README.md"
+license = {text = "Apache-2.0"}
+classifiers = [
+ "Development Status :: 3 - Alpha",
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+]
+requires-python = ">= 3.10"
+
+# Version is set via the `setup.py`.
+dynamic = ["version"]
+
+[project.urls]
+Repository = "https://github.com/nod-ai/shark-ai"
+Documentation = "https://shortfin.readthedocs.io/en/latest/"
+
+[project.optional-dependencies]
+apps = [
+ "transformers",
+ "dataclasses-json",
+ "pillow",
+ "fastapi",
+ "uvicorn",
+ "aiohttp",
+]
+
[tool.pytest.ini_options]
addopts = [
"-ra",
diff --git a/shortfin/python/_shortfin/asyncio_bridge.py b/shortfin/python/_shortfin/asyncio_bridge.py
index 4cb54449c..28264e9e3 100644
--- a/shortfin/python/_shortfin/asyncio_bridge.py
+++ b/shortfin/python/_shortfin/asyncio_bridge.py
@@ -5,10 +5,19 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import asyncio
+import inspect
from . import lib as sfl
+# Feature detect some versions where signatures changes.
+if "context" in inspect.signature(asyncio.Task).parameters:
+ # Python > 3.10
+ _ASYNCIO_TASK_HAS_CONTEXT = True
+else:
+ _ASYNCIO_TASK_HAS_CONTEXT = False
+
+
class PyWorkerEventLoop(asyncio.AbstractEventLoop):
def __init__(self, worker: sfl.local.Worker):
self._worker = worker
@@ -17,8 +26,15 @@ def get_debug(self):
# Requirement of asyncio.
return False
- def create_task(self, coro, *, name=None, context=None):
- return asyncio.Task(coro, loop=self, name=name, context=context)
+ if _ASYNCIO_TASK_HAS_CONTEXT:
+
+ def create_task(self, coro, *, name=None, context=None):
+ return asyncio.Task(coro, loop=self, name=name, context=context)
+
+ else:
+
+ def create_task(self, coro, *, name=None):
+ return asyncio.Task(coro, loop=self, name=name)
def create_future(self):
return asyncio.Future(loop=self)
diff --git a/shortfin/python/array_binding.cc b/shortfin/python/array_binding.cc
index da7197b14..a05232674 100644
--- a/shortfin/python/array_binding.cc
+++ b/shortfin/python/array_binding.cc
@@ -7,6 +7,7 @@
#include "./lib_ext.h"
#include "./utils.h"
#include "shortfin/array/api.h"
+#include "shortfin/support/logging.h"
using namespace shortfin::array;
@@ -223,6 +224,7 @@ class PyMapping {
}
void FillFromScalar(Refs *refs, py::handle value) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::FillFromScalar");
if (!dtype()) {
throw std::invalid_argument(
"The `fill` method is only valid for typed mappings but "
@@ -242,6 +244,7 @@ class PyMapping {
}
void FillFromBuffer(py::handle buffer) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::FillFromBuffer");
Py_buffer py_view;
int flags = PyBUF_FORMAT | PyBUF_ND; // C-Contiguous ND.
if (PyObject_GetBuffer(buffer.ptr(), &py_view, flags) != 0) {
@@ -286,6 +289,7 @@ class PyMapping {
}
py::object GetItems(py::handle self_obj, Refs *refs) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::GetItems");
if (!dtype()) {
throw std::invalid_argument(
"The `items` property is only valid for typed mappings but "
@@ -306,6 +310,7 @@ class PyMapping {
}
void SetItems(Refs *refs, py::handle initializer) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyMapping::SetItems");
if (!dtype()) {
throw std::invalid_argument(
"The `items` property is only valid for typed mappings but "
@@ -410,6 +415,7 @@ void BindArray(py::module_ &m) {
.def(
"map",
[](storage &self, bool read, bool write, bool discard) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyStorage::map");
int access = 0;
if (read) access |= IREE_HAL_MEMORY_ACCESS_READ;
if (write || discard) access |= IREE_HAL_MEMORY_ACCESS_WRITE;
@@ -565,6 +571,7 @@ void BindArray(py::module_ &m) {
.def(
"map",
[](device_array &self, bool read, bool write, bool discard) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyArray::map");
int access = 0;
if (read) access |= IREE_HAL_MEMORY_ACCESS_READ;
if (write || discard) access |= IREE_HAL_MEMORY_ACCESS_WRITE;
@@ -586,6 +593,7 @@ void BindArray(py::module_ &m) {
.def_prop_rw(
"items",
[refs](device_array &self) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyArray::items");
PyMapping *mapping;
py::object mapping_obj = CreateMappingObject(&mapping);
mapping->set_dtype(self.dtype());
@@ -606,6 +614,7 @@ void BindArray(py::module_ &m) {
.def_prop_ro(
"__array_interface__",
[refs](device_array &self) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyArray::__array_interface__");
py::dict interface;
interface["version"] = 3;
interface["strides"] = py::none();
diff --git a/shortfin/python/array_host_ops.cc b/shortfin/python/array_host_ops.cc
index 8c4af0070..86385cfee 100644
--- a/shortfin/python/array_host_ops.cc
+++ b/shortfin/python/array_host_ops.cc
@@ -38,6 +38,34 @@ Implemented for dtypes: float16, float32.
A device_array of dtype=int64, allocated on the host and not visible to the device.
)";
+static const char DOCSTRING_CONVERT[] =
+ R"(Does an elementwise conversion from one dtype to another.
+
+The same behavior exists for several conversion ops:
+
+* `convert` : element-wise conversion like a static cast.
+* `round` : element-wise nearest integer to the input, rounding halfway cases
+ away from zero.
+* `ceil` : element-wise smallest integer value not less than the input.
+* `floor` : element-wise smallest integer value not greater than the input.
+* `trunc` : element-wise nearest integer not greater in magnitude than the input.
+
+For nearest-integer conversions (round, ceil, floor, trunc), the input dtype
+must be a floating point array, and the output must be a byte-aligned integer
+type between 8 and 32 bits.
+
+Args:
+ input: An input array of a floating point dtype.
+ dtype: If given, then this is the explicit output dtype.
+ out: If given, then the results are written to this array. This implies the
+ output dtype.
+ device_visible: Whether to make the result array visible to devices. Defaults to
+ False.
+
+Returns:
+ A device_array of the requested dtype, or the input dtype if not specified.
+)";
+
static const char DOCSTRING_FILL_RANDN[] =
R"(Fills an array with numbers sampled from the standard ormal distribution.
@@ -63,7 +91,14 @@ static const char DOCSTRING_RANDOM_GENERATOR[] =
fixed number.
)";
-} // namespace
+#define SF_UNARY_FUNCTION_CASE(dtype_name, cpp_type) \
+ case DType::dtype_name(): \
+ return compute.template operator()()
+
+#define SF_UNARY_THUNK_CASE(dtype_name, cpp_type) \
+ case DType::dtype_name(): \
+ compute.template operator()(); \
+ break
struct PyRandomGenerator {
public:
@@ -85,9 +120,261 @@ struct PyRandomGenerator {
xt::random::default_engine_type engine_;
};
-#define SF_UNARY_COMPUTE_CASE(dtype_name, cpp_type) \
- case DType::dtype_name(): \
- return compute.template operator()()
+// Generic conversion templates, split into a bindable template and functors
+// that operate on pre-allocated outputs.
+template
+device_array GenericElementwiseConvert(device_array &input,
+ std::optional dtype,
+ std::optional out,
+ bool device_visible) {
+ // Argument check and output allocation.
+ if (!dtype) {
+ dtype = out ? out->dtype() : input.dtype();
+ } else {
+ if (out && out->dtype() != dtype) {
+ throw std::invalid_argument(
+ "if both dtype and out are specified, they must match");
+ }
+ }
+ if (!out) {
+ out.emplace(device_array::for_host(input.device(), input.shape(), *dtype,
+ device_visible));
+ }
+
+ ConvertFunc::Invoke(input, *dtype, *out);
+ return *out;
+}
+
+// Generic elementwise conversion functor
+struct ConvertFunctor {
+ static void Invoke(device_array &input, DType dtype, device_array &out) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::convert");
+ auto compute = [&]() -> void {
+ auto input_t = input.map_xtensor();
+ // Casted output.
+#define SF_STORE_CASE(dtype_name, cpp_type) \
+ case DType::dtype_name(): { \
+ auto out_t = out.map_xtensor_w(); \
+ *out_t = xt::cast(*input_t); \
+ break; \
+ }
+ switch (dtype) {
+ SF_STORE_CASE(float16, half_float::half);
+ SF_STORE_CASE(float32, float);
+ SF_STORE_CASE(float64, double);
+ SF_STORE_CASE(uint8, uint8_t);
+ SF_STORE_CASE(int8, int8_t);
+ SF_STORE_CASE(uint16, uint16_t);
+ SF_STORE_CASE(int16, int16_t);
+ SF_STORE_CASE(uint32, uint32_t);
+ SF_STORE_CASE(int32, int32_t);
+ SF_STORE_CASE(uint64, uint64_t);
+ SF_STORE_CASE(int64, int64_t);
+ default:
+ throw std::invalid_argument("Invalid output dtype for convert op");
+ }
+
+#undef SF_STORE_CASE
+ };
+
+ switch (input.dtype()) {
+ SF_UNARY_THUNK_CASE(float16, half_float::half);
+ SF_UNARY_THUNK_CASE(float32, float);
+ SF_UNARY_THUNK_CASE(float64, double);
+ SF_UNARY_THUNK_CASE(uint8, uint8_t);
+ SF_UNARY_THUNK_CASE(int8, int8_t);
+ SF_UNARY_THUNK_CASE(uint16, uint16_t);
+ SF_UNARY_THUNK_CASE(int16, int16_t);
+ SF_UNARY_THUNK_CASE(uint32, uint32_t);
+ SF_UNARY_THUNK_CASE(int32, uint32_t);
+ SF_UNARY_THUNK_CASE(uint64, uint64_t);
+ SF_UNARY_THUNK_CASE(int64, int64_t);
+ default:
+ throw std::invalid_argument(fmt::format(
+ "Unsupported dtype({}) for converting nearest integer op",
+ dtype.name()));
+ }
+ }
+};
+
+// Converting round functor.
+struct ConvertRoundFunctor {
+ static void Invoke(device_array &input, DType dtype, device_array &out) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::round");
+ auto compute = [&]() -> void {
+ auto input_t = input.map_xtensor();
+ auto rounded = xt::round(*input_t);
+ if (input.dtype() == dtype) {
+ // Same type output.
+ auto out_t = out.map_xtensor_w();
+ *out_t = rounded;
+ } else {
+ // Casted output.
+#define SF_STORE_CASE(dtype_name, cpp_type) \
+ case DType::dtype_name(): { \
+ auto out_t = out.map_xtensor_w(); \
+ *out_t = xt::cast(rounded); \
+ break; \
+ }
+ switch (dtype) {
+ SF_STORE_CASE(uint8, uint8_t);
+ SF_STORE_CASE(int8, int8_t);
+ SF_STORE_CASE(uint16, uint16_t);
+ SF_STORE_CASE(int16, int16_t);
+ SF_STORE_CASE(uint32, uint32_t);
+ SF_STORE_CASE(int32, int32_t);
+ default:
+ throw std::invalid_argument(
+ "Invalid output dtype for converting nearest integer op");
+ }
+ }
+#undef SF_STORE_CASE
+ };
+
+ switch (input.dtype()) {
+ SF_UNARY_THUNK_CASE(float16, half_float::half);
+ SF_UNARY_THUNK_CASE(float32, float);
+ default:
+ throw std::invalid_argument(fmt::format(
+ "Unsupported dtype({}) for converting nearest integer op",
+ dtype.name()));
+ }
+ }
+};
+
+struct ConvertCeilFunctor {
+ static void Invoke(device_array &input, DType dtype, device_array &out) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::ceil");
+ auto compute = [&]() -> void {
+ auto input_t = input.map_xtensor();
+ auto rounded = xt::ceil(*input_t);
+ if (input.dtype() == dtype) {
+ // Same type output.
+ auto out_t = out.map_xtensor_w();
+ *out_t = rounded;
+ } else {
+ // Casted output.
+#define SF_STORE_CASE(dtype_name, cpp_type) \
+ case DType::dtype_name(): { \
+ auto out_t = out.map_xtensor_w(); \
+ *out_t = xt::cast(rounded); \
+ break; \
+ }
+ switch (dtype) {
+ SF_STORE_CASE(uint8, uint8_t);
+ SF_STORE_CASE(int8, int8_t);
+ SF_STORE_CASE(uint16, uint16_t);
+ SF_STORE_CASE(int16, int16_t);
+ SF_STORE_CASE(uint32, uint32_t);
+ SF_STORE_CASE(int32, int32_t);
+ default:
+ throw std::invalid_argument(
+ "Invalid output dtype for converting nearest integer op");
+ }
+ }
+#undef SF_STORE_CASE
+ };
+
+ switch (input.dtype()) {
+ SF_UNARY_THUNK_CASE(float16, half_float::half);
+ SF_UNARY_THUNK_CASE(float32, float);
+ default:
+ throw std::invalid_argument(fmt::format(
+ "Unsupported dtype({}) for converting nearest integer op",
+ dtype.name()));
+ }
+ }
+};
+
+struct ConvertFloorFunctor {
+ static void Invoke(device_array &input, DType dtype, device_array &out) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::floor");
+ auto compute = [&]() -> void {
+ auto input_t = input.map_xtensor();
+ auto rounded = xt::floor(*input_t);
+ if (input.dtype() == dtype) {
+ // Same type output.
+ auto out_t = out.map_xtensor_w();
+ *out_t = rounded;
+ } else {
+ // Casted output.
+#define SF_STORE_CASE(dtype_name, cpp_type) \
+ case DType::dtype_name(): { \
+ auto out_t = out.map_xtensor_w(); \
+ *out_t = xt::cast(rounded); \
+ break; \
+ }
+ switch (dtype) {
+ SF_STORE_CASE(uint8, uint8_t);
+ SF_STORE_CASE(int8, int8_t);
+ SF_STORE_CASE(uint16, uint16_t);
+ SF_STORE_CASE(int16, int16_t);
+ SF_STORE_CASE(uint32, uint32_t);
+ SF_STORE_CASE(int32, int32_t);
+ default:
+ throw std::invalid_argument(
+ "Invalid output dtype for converting nearest integer op");
+ }
+ }
+#undef SF_STORE_CASE
+ };
+
+ switch (input.dtype()) {
+ SF_UNARY_THUNK_CASE(float16, half_float::half);
+ SF_UNARY_THUNK_CASE(float32, float);
+ default:
+ throw std::invalid_argument(fmt::format(
+ "Unsupported dtype({}) for converting nearest integer op",
+ dtype.name()));
+ }
+ }
+};
+
+struct ConvertTruncFunctor {
+ static void Invoke(device_array &input, DType dtype, device_array &out) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::trunc");
+ auto compute = [&]() -> void {
+ auto input_t = input.map_xtensor();
+ auto rounded = xt::trunc(*input_t);
+ if (input.dtype() == dtype) {
+ // Same type output.
+ auto out_t = out.map_xtensor_w();
+ *out_t = rounded;
+ } else {
+ // Casted output.
+#define SF_STORE_CASE(dtype_name, cpp_type) \
+ case DType::dtype_name(): { \
+ auto out_t = out.map_xtensor_w(); \
+ *out_t = xt::cast(rounded); \
+ break; \
+ }
+ switch (dtype) {
+ SF_STORE_CASE(uint8, uint8_t);
+ SF_STORE_CASE(int8, int8_t);
+ SF_STORE_CASE(uint16, uint16_t);
+ SF_STORE_CASE(int16, int16_t);
+ SF_STORE_CASE(uint32, uint32_t);
+ SF_STORE_CASE(int32, int32_t);
+ default:
+ throw std::invalid_argument(
+ "Invalid output dtype for converting nearest integer op");
+ }
+ }
+#undef SF_STORE_CASE
+ };
+
+ switch (input.dtype()) {
+ SF_UNARY_THUNK_CASE(float16, half_float::half);
+ SF_UNARY_THUNK_CASE(float32, float);
+ default:
+ throw std::invalid_argument(fmt::format(
+ "Unsupported dtype({}) for converting nearest integer op",
+ dtype.name()));
+ }
+ }
+};
+
+} // namespace
void BindArrayHostOps(py::module_ &m) {
// Simple op definitions.
@@ -95,6 +382,7 @@ void BindArrayHostOps(py::module_ &m) {
"argmax",
[](device_array &input, int axis, std::optional out,
bool keepdims, bool device_visible) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::argmax");
if (axis < 0) axis += input.shape().size();
if (axis < 0 || axis >= input.shape().size()) {
throw std::invalid_argument(
@@ -120,8 +408,8 @@ void BindArrayHostOps(py::module_ &m) {
};
switch (input.dtype()) {
- SF_UNARY_COMPUTE_CASE(float16, half_float::half);
- SF_UNARY_COMPUTE_CASE(float32, float);
+ SF_UNARY_FUNCTION_CASE(float16, half_float::half);
+ SF_UNARY_FUNCTION_CASE(float32, float);
default:
throw std::invalid_argument(
fmt::format("Unsupported dtype({}) for operator argmax",
@@ -139,6 +427,7 @@ void BindArrayHostOps(py::module_ &m) {
m.def(
"fill_randn",
[](device_array out, std::optional gen) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyHostOp::fill_randn");
if (!gen) gen = &PyRandomGenerator::get_default();
auto compute = [&]() {
auto result = xt::random::randn(out.shape_container(), /*mean=*/0.0,
@@ -148,8 +437,8 @@ void BindArrayHostOps(py::module_ &m) {
};
switch (out.dtype()) {
- SF_UNARY_COMPUTE_CASE(float16, half_float::half);
- SF_UNARY_COMPUTE_CASE(float32, float);
+ SF_UNARY_FUNCTION_CASE(float16, half_float::half);
+ SF_UNARY_FUNCTION_CASE(float32, float);
default:
throw std::invalid_argument(
fmt::format("Unsupported dtype({}) for operator randn",
@@ -157,6 +446,17 @@ void BindArrayHostOps(py::module_ &m) {
}
},
py::arg("out"), py::arg("generator") = py::none(), DOCSTRING_FILL_RANDN);
+
+// Data-type conversion and rounding.
+#define SF_DEF_CONVERT(py_name, target) \
+ m.def(py_name, target, py::arg("input"), py::kw_only(), \
+ py::arg("dtype") = py::none(), py::arg("out") = py::none(), \
+ py::arg("device_visible") = false, DOCSTRING_CONVERT)
+ SF_DEF_CONVERT("convert", GenericElementwiseConvert);
+ SF_DEF_CONVERT("ceil", GenericElementwiseConvert);
+ SF_DEF_CONVERT("floor", GenericElementwiseConvert);
+ SF_DEF_CONVERT("round", GenericElementwiseConvert);
+ SF_DEF_CONVERT("trunc", GenericElementwiseConvert);
}
} // namespace shortfin::python
diff --git a/shortfin/python/lib_ext.cc b/shortfin/python/lib_ext.cc
index 0bfb67588..d17606b4b 100644
--- a/shortfin/python/lib_ext.cc
+++ b/shortfin/python/lib_ext.cc
@@ -173,6 +173,7 @@ class PyWorkerExtension : public local::Worker::Extension {
py::handle loop() { return loop_; }
void OnThreadStart() noexcept override {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyWorker::OnThreadStart");
// Python threading initialization.
// If our own thread, teach Python about it. Not done for donated.
if (worker().options().owned_thread) {
@@ -187,6 +188,7 @@ class PyWorkerExtension : public local::Worker::Extension {
}
void OnThreadStop() noexcept override {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyWorker::OnThreadStop");
{
// Do Python level thread cleanup.
py::gil_scoped_acquire g;
@@ -253,6 +255,7 @@ class PyProcess : public local::detail::BaseProcess {
std::bind(&PyProcess::RunOnWorker, self_object));
}
static void RunOnWorker(py::handle self_handle) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyProcess:RunOnWorker");
py::gil_scoped_acquire g;
// Steal the reference back from ScheduleOnWorker. Important: this is
// very likely the last reference to the process. So self must not be
@@ -342,6 +345,7 @@ py::object PyRehydrateRef(local::ProgramInvocation *inv,
py::object RunInForeground(std::shared_ptr refs, local::System &self,
py::object coro) {
+ SHORTFIN_TRACE_SCOPE_NAMED("CoroRunInForeground");
bool is_main_thread =
refs->threading_current_thread().is(refs->threading_main_thread());
@@ -936,6 +940,7 @@ void BindLocal(py::module_ &m) {
callable.inc_ref(); // Stolen within the callback.
auto thunk = +[](void *user_data, iree_loop_t loop,
iree_status_t status) noexcept -> iree_status_t {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyWorker::Callback");
py::gil_scoped_acquire g;
py::object user_callable =
py::steal(static_cast(user_data));
@@ -955,6 +960,7 @@ void BindLocal(py::module_ &m) {
callable.inc_ref(); // Stolen within the callback.
auto thunk = +[](void *user_data, iree_loop_t loop,
iree_status_t status) noexcept -> iree_status_t {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyWorker::DelayCallback");
py::gil_scoped_acquire g;
py::object user_callable =
py::steal(static_cast(user_data));
@@ -1030,6 +1036,7 @@ void BindLocal(py::module_ &m) {
py::class_(m, "CompletionEvent")
.def(py::init<>())
.def("__await__", [](py::handle self_obj) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyCompletionEvent::__await__");
auto &worker_ext = PyWorkerExtension::GetCurrent();
auto &self = py::cast(self_obj);
py::object future = worker_ext.loop().attr("create_future")();
@@ -1051,6 +1058,7 @@ void BindLocal(py::module_ &m) {
self, iree_infinite_timeout(),
+[](void *future_vp, iree_loop_t loop,
iree_status_t status) noexcept -> iree_status_t {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyCompletionEvent::OnComplete");
py::gil_scoped_acquire g;
py::object future = py::steal(static_cast(future_vp));
try {
@@ -1145,6 +1153,7 @@ void BindLocal(py::module_ &m) {
return py::none();
})
.def("__await__", [](py::handle self_obj) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyFuture::__await__");
// TODO: We should make our C++ future able to be used directly
// vs needing to bridge it like this.
auto &worker_ext = PyWorkerExtension::GetCurrent();
@@ -1166,6 +1175,7 @@ void BindLocal(py::module_ &m) {
self.AddCallback(
[py_future_vp = static_cast(future.release().ptr())](
local::Future &sf_future) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyFuture::OnComplete");
py::gil_scoped_acquire g;
py::object py_future =
py::steal(static_cast(py_future_vp));
diff --git a/shortfin/python/shortfin/array/__init__.py b/shortfin/python/shortfin/array/__init__.py
index 3a4d28877..6079541c8 100644
--- a/shortfin/python/shortfin/array/__init__.py
+++ b/shortfin/python/shortfin/array/__init__.py
@@ -44,7 +44,12 @@
# Ops.
argmax = _sfl.array.argmax
+ceil = _sfl.array.ceil
+convert = _sfl.array.convert
fill_randn = _sfl.array.fill_randn
+floor = _sfl.array.floor
+round = _sfl.array.round
+trunc = _sfl.array.trunc
RandomGenerator = _sfl.array.RandomGenerator
__all__ = [
@@ -82,7 +87,12 @@
"DType",
# Ops.
"argmax",
+ "ceil",
+ "convert",
"fill_randn",
+ "floor",
+ "round",
+ "trunc",
"RandomGenerator",
]
diff --git a/shortfin/python/shortfin/interop/support/device_setup.py b/shortfin/python/shortfin/interop/support/device_setup.py
new file mode 100644
index 000000000..afe6ca695
--- /dev/null
+++ b/shortfin/python/shortfin/interop/support/device_setup.py
@@ -0,0 +1,26 @@
+import shortfin as sf
+
+
+def get_selected_devices(sb: sf.SystemBuilder, device_ids=None):
+ available = sb.available_devices
+ selected = []
+ if device_ids is not None:
+ if len(device_ids) > len(available):
+ raise ValueError(
+ f"Requested more device ids ({device_ids}) than available ({available})."
+ )
+ for did in device_ids:
+ if isinstance(did, str):
+ try:
+ did = int(did)
+ except ValueError:
+ did = did
+ if did in available:
+ selected.append(did)
+ elif isinstance(did, int):
+ selected.append(available[did])
+ else:
+ raise ValueError(f"Device id {did} could not be parsed.")
+ else:
+ selected = available
+ return selected
diff --git a/shortfin/python/shortfin/support/logging_setup.py b/shortfin/python/shortfin/support/logging_setup.py
index 39cf3cf75..849d65bf3 100644
--- a/shortfin/python/shortfin/support/logging_setup.py
+++ b/shortfin/python/shortfin/support/logging_setup.py
@@ -38,16 +38,15 @@ def __init__(self):
native_handler.setFormatter(NativeFormatter())
# TODO: Source from env vars.
-logger.setLevel(logging.DEBUG)
+logger.setLevel(logging.WARNING)
logger.addHandler(native_handler)
def configure_main_logger(module_suffix: str = "__main__") -> logging.Logger:
"""Configures logging from a main entrypoint.
-
Returns a logger that can be used for the main module itself.
"""
logging.root.addHandler(native_handler)
- logging.root.setLevel(logging.DEBUG) # TODO: source from env vars
+ logging.root.setLevel(logging.WARNING) # TODO: source from env vars
main_module = sys.modules["__main__"]
return logging.getLogger(f"{main_module.__package__}.{module_suffix}")
diff --git a/shortfin/python/shortfin_apps/llm/client.py b/shortfin/python/shortfin_apps/llm/client.py
index 63cff7bee..e3ff3ec39 100644
--- a/shortfin/python/shortfin_apps/llm/client.py
+++ b/shortfin/python/shortfin_apps/llm/client.py
@@ -1,3 +1,9 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
import requests
import json
import uuid
@@ -5,8 +11,6 @@
import time
from typing import Dict, Any
-BASE_URL = "http://localhost:8000"
-
def main() -> None:
parser = argparse.ArgumentParser(description="Test LLM server")
@@ -20,8 +24,16 @@ def main() -> None:
parser.add_argument(
"--stream", action="store_true", help="Enable response streaming"
)
+ parser.add_argument(
+ "--port",
+ type=str,
+ default="8000",
+ help="Port that shortfin server is running on",
+ )
args = parser.parse_args()
+ base_url = f"http://localhost:{args.port}"
+
data = {
"text": args.text,
"sampling_params": {
@@ -36,13 +48,13 @@ def main() -> None:
"stream": args.stream,
}
- print(f"Testing LLM server at {BASE_URL}")
+ print(f"Testing LLM server at {base_url}")
# Health check with exponential backoff
backoff = 1
while True:
try:
- requests.get(f"{BASE_URL}/health").raise_for_status()
+ requests.get(f"{base_url}/health").raise_for_status()
break
except requests.exceptions.RequestException as e:
if backoff > 16:
@@ -56,7 +68,7 @@ def main() -> None:
try:
print("Prompt text:", data["text"])
headers = {"Content-Type": "application/json"}
- response = requests.post(f"{BASE_URL}/generate", headers=headers, json=data)
+ response = requests.post(f"{base_url}/generate", headers=headers, json=data)
response.raise_for_status()
if response.text.startswith("data: "):
diff --git a/shortfin/python/shortfin_apps/llm/components/manager.py b/shortfin/python/shortfin_apps/llm/components/manager.py
index e3057de22..b44116b39 100644
--- a/shortfin/python/shortfin_apps/llm/components/manager.py
+++ b/shortfin/python/shortfin_apps/llm/components/manager.py
@@ -8,16 +8,23 @@
import threading
import shortfin as sf
+from shortfin.interop.support.device_setup import get_selected_devices
logger = logging.getLogger(__name__)
class SystemManager:
- def __init__(self, device="local-task"):
- if device == "local-task":
+ def __init__(self, device="local-task", device_ids=None, async_allocs=True):
+ if any(x in device for x in ["local-task", "cpu"]):
self.ls = sf.host.CPUSystemBuilder().create_system()
- elif device == "hip":
- self.ls = sf.amdgpu.SystemBuilder().create_system()
+ elif any(x in device for x in ["hip", "amdgpu"]):
+ sb = sf.SystemBuilder(
+ system_type="amdgpu", amdgpu_async_allocations=async_allocs
+ )
+ if device_ids:
+ sb.visible_devices = sb.available_devices
+ sb.visible_devices = get_selected_devices(sb, device_ids)
+ self.ls = sb.create_system()
logger.info(f"Created local system with {self.ls.device_names} devices")
# TODO: Come up with an easier bootstrap thing than manually
# running a thread.
diff --git a/shortfin/python/shortfin_apps/llm/server.py b/shortfin/python/shortfin_apps/llm/server.py
index 5b51a9a7f..2ab7a1b96 100644
--- a/shortfin/python/shortfin_apps/llm/server.py
+++ b/shortfin/python/shortfin_apps/llm/server.py
@@ -86,7 +86,11 @@ def get_eos_from_tokenizer_config(json_path):
def configure(args) -> SystemManager:
# Setup system (configure devices, etc).
- sysman = SystemManager(device=args.device)
+ sysman = SystemManager(
+ device=args.device,
+ device_ids=args.device_ids,
+ async_allocs=args.amdgpu_async_allocations,
+ )
# Setup each service we are hosting.
eos_token = get_eos_from_tokenizer_config(args.tokenizer_config_json)
@@ -155,9 +159,17 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
parser.add_argument(
"--device",
type=str,
- default="local-task",
+ required=True,
+ choices=["local-task", "hip", "amdgpu"],
help="Device to serve on; e.g. local-task, hip. Same options as `iree-run-module --device` ",
)
+ parser.add_argument(
+ "--device_ids",
+ type=str,
+ nargs="*",
+ default=None,
+ help="Device IDs visible to the system builder. Defaults to None (full visibility). Can be an index or a sf device id like amdgpu:0:0@0",
+ )
parser.add_argument(
"--isolation",
type=str,
@@ -165,6 +177,11 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
choices=[isolation.name.lower() for isolation in ProgramIsolation],
help="Concurrency control -- How to isolate programs.",
)
+ parser.add_argument(
+ "--amdgpu_async_allocations",
+ action="store_true",
+ help="Enable asynchronous allocations for amdgpu device contexts.",
+ )
args = parser.parse_args(argv)
if args.tokenizer_config_json is None:
diff --git a/shortfin/python/shortfin_apps/sd/README.md b/shortfin/python/shortfin_apps/sd/README.md
index 4808cad08..3397be6cf 100644
--- a/shortfin/python/shortfin_apps/sd/README.md
+++ b/shortfin/python/shortfin_apps/sd/README.md
@@ -1,50 +1,30 @@
-# SD Server and CLI
+# SDXL Server and CLI
-This directory contains a SD inference server, CLI and support components.
+This directory contains a [SDXL](https://stablediffusionxl.com/) inference server, CLI and support components. More information about SDXL on [huggingface](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0).
+## Install
-## Quick start
-
-In your shortfin environment,
-```
-pip install transformers
-pip install dataclasses-json
-pip install pillow
-
-```
-```
-python -m shortfin_apps.sd.server --help
-```
-
-## Run tests
-
- - From SHARK-Platform/shortfin:
- ```
- pytest --system=amdgpu -k "sd"
- ```
- The tests run with splat weights.
+For [nightly releases](../../../../docs/nightly_releases.md)
+For our [stable release](../../../../docs/user_guide.md)
+## Start SDXL Server
+The server will prepare runtime artifacts for you.
-## Run on MI300x
+By default, the port is set to 8000. If you would like to change this, use `--port` in each of the following commands.
- - Follow quick start
+You can check if this (or any) port is in use on Linux with `ss -ntl | grep 8000`.
- - Navigate to shortfin/ (only necessary if you're using following CLI exactly.)
```
-cd shortfin/
+python -m shortfin_apps.sd.server --device=amdgpu --device_ids=0 --build_preference=precompiled --topology="spx_single"
```
- - Run CLI server interface (you can find `sdxl_config_i8.json` in shortfin_apps/sd/examples):
-
-The server will prepare runtime artifacts for you.
-
+ - Wait until your server outputs:
```
-python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --flagfile=./python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt --build_preference=compile
+INFO - Application startup complete.
+INFO - Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
```
- - Run with splat(empty) weights:
-```
-python -m shortfin_apps.sd.server --model_config=./python/shortfin_apps/sd/examples/sdxl_config_i8.json --device=amdgpu --device_ids=0 --splat --flagfile=./python/shortfin_apps/sd/examples/sdxl_flags_gfx942.txt --build_preference=compile
-```
- - Run a request in a separate shell:
+## Run the SDXL Client
+
+ - Run a CLI client in a separate shell:
```
-python shortfin/python/shortfin_apps/sd/examples/send_request.py --file=shortfin/python/shortfin_apps/sd/examples/sdxl_request.json
+python -m shortfin_apps.sd.simple_client --interactive
```
diff --git a/shortfin/python/shortfin_apps/sd/components/builders.py b/shortfin/python/shortfin_apps/sd/components/builders.py
index 1f9d0c2ee..f23922dd6 100644
--- a/shortfin/python/shortfin_apps/sd/components/builders.py
+++ b/shortfin/python/shortfin_apps/sd/components/builders.py
@@ -1,3 +1,9 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
from iree.build import *
from iree.build.executor import FileNamespace
import itertools
@@ -18,7 +24,7 @@
sfnp.bfloat16: "bf16",
}
-ARTIFACT_VERSION = "11022024"
+ARTIFACT_VERSION = "11132024"
SDXL_BUCKET = (
f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/"
)
@@ -45,7 +51,9 @@ def get_mlir_filenames(model_params: ModelParams, model=None):
return filter_by_model(mlir_filenames, model)
-def get_vmfb_filenames(model_params: ModelParams, model=None, target: str = "gfx942"):
+def get_vmfb_filenames(
+ model_params: ModelParams, model=None, target: str = "amdgpu-gfx942"
+):
vmfb_filenames = []
file_stems = get_file_stems(model_params)
for stem in file_stems:
@@ -159,9 +167,9 @@ def needs_file(filename, ctx, namespace=FileNamespace.GEN):
if os.path.exists(out_file):
needed = False
else:
- name_path = "bin" if namespace == FileNamespace.BIN else ""
- if name_path:
- filename = os.path.join(name_path, filename)
+ # name_path = "bin" if namespace == FileNamespace.BIN else ""
+ # if name_path:
+ # filename = os.path.join(name_path, filename)
filekey = os.path.join(ctx.path, filename)
ctx.executor.all[filekey] = None
needed = True
@@ -210,6 +218,8 @@ def sdxl(
mlir_bucket = SDXL_BUCKET + "mlir/"
vmfb_bucket = SDXL_BUCKET + "vmfbs/"
+ if "gfx" in target:
+ target = "amdgpu-" + target
mlir_filenames = get_mlir_filenames(model_params, model)
mlir_urls = get_url_map(mlir_filenames, mlir_bucket)
@@ -241,7 +251,7 @@ def sdxl(
params_urls = get_url_map(params_filenames, SDXL_WEIGHTS_BUCKET)
for f, url in params_urls.items():
out_file = os.path.join(ctx.executor.output_dir, f)
- if update or needs_file(f, ctx):
+ if needs_file(f, ctx):
fetch_http(name=f, url=url)
filenames = [*vmfb_filenames, *params_filenames, *mlir_filenames]
return filenames
diff --git a/shortfin/python/shortfin_apps/sd/components/config_artifacts.py b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py
new file mode 100644
index 000000000..f3502f22e
--- /dev/null
+++ b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py
@@ -0,0 +1,111 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from iree.build import *
+from iree.build.executor import FileNamespace
+import itertools
+import os
+import shortfin.array as sfnp
+import copy
+
+ARTIFACT_VERSION = "11132024"
+SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/configs/"
+
+
+def get_url_map(filenames: list[str], bucket: str):
+ file_map = {}
+ for filename in filenames:
+ file_map[filename] = f"{bucket}{filename}"
+ return file_map
+
+
+def needs_update(ctx):
+ stamp = ctx.allocate_file("version.txt")
+ stamp_path = stamp.get_fs_path()
+ if os.path.exists(stamp_path):
+ with open(stamp_path, "r") as s:
+ ver = s.read()
+ if ver != ARTIFACT_VERSION:
+ return True
+ else:
+ with open(stamp_path, "w") as s:
+ s.write(ARTIFACT_VERSION)
+ return True
+ return False
+
+
+def needs_file(filename, ctx, namespace=FileNamespace.GEN):
+ out_file = ctx.allocate_file(filename, namespace=namespace).get_fs_path()
+ if os.path.exists(out_file):
+ needed = False
+ else:
+ # name_path = "bin" if namespace == FileNamespace.BIN else ""
+ # if name_path:
+ # filename = os.path.join(name_path, filename)
+ filekey = os.path.join(ctx.path, filename)
+ ctx.executor.all[filekey] = None
+ needed = True
+ return needed
+
+
+@entrypoint(description="Retreives a set of SDXL configuration files.")
+def sdxlconfig(
+ target=cl_arg(
+ "target",
+ default="gfx942",
+ help="IREE target architecture.",
+ ),
+ model=cl_arg("model", type=str, default="sdxl", help="Model architecture"),
+ topology=cl_arg(
+ "topology",
+ type=str,
+ default="spx_single",
+ help="System topology configfile keyword",
+ ),
+):
+ ctx = executor.BuildContext.current()
+ update = needs_update(ctx)
+
+ model_config_filenames = [f"{model}_config_i8.json"]
+ model_config_urls = get_url_map(model_config_filenames, SDXL_CONFIG_BUCKET)
+ for f, url in model_config_urls.items():
+ out_file = os.path.join(ctx.executor.output_dir, f)
+ if update or needs_file(f, ctx):
+ fetch_http(name=f, url=url)
+
+ topology_config_filenames = [f"topology_config_{topology}.txt"]
+ topology_config_urls = get_url_map(topology_config_filenames, SDXL_CONFIG_BUCKET)
+ for f, url in topology_config_urls.items():
+ out_file = os.path.join(ctx.executor.output_dir, f)
+ if update or needs_file(f, ctx):
+ fetch_http(name=f, url=url)
+
+ flagfile_filenames = [f"{model}_flagfile_{target}.txt"]
+ flagfile_urls = get_url_map(flagfile_filenames, SDXL_CONFIG_BUCKET)
+ for f, url in flagfile_urls.items():
+ out_file = os.path.join(ctx.executor.output_dir, f)
+ if update or needs_file(f, ctx):
+ fetch_http(name=f, url=url)
+
+ tuning_filenames = (
+ [f"attention_and_matmul_spec_{target}.mlir"] if target == "gfx942" else []
+ )
+ tuning_urls = get_url_map(tuning_filenames, SDXL_CONFIG_BUCKET)
+ for f, url in tuning_urls.items():
+ out_file = os.path.join(ctx.executor.output_dir, f)
+ if update or needs_file(f, ctx):
+ fetch_http(name=f, url=url)
+ filenames = [
+ *model_config_filenames,
+ *topology_config_filenames,
+ *flagfile_filenames,
+ *tuning_filenames,
+ ]
+ return filenames
+
+
+if __name__ == "__main__":
+ iree_build_main()
diff --git a/shortfin/python/shortfin_apps/sd/components/config_struct.py b/shortfin/python/shortfin_apps/sd/components/config_struct.py
index 3dda6edfc..478d03ad8 100644
--- a/shortfin/python/shortfin_apps/sd/components/config_struct.py
+++ b/shortfin/python/shortfin_apps/sd/components/config_struct.py
@@ -1,3 +1,9 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
"""Configuration objects.
Parameters that are intrinsic to a specific model.
diff --git a/shortfin/python/shortfin_apps/sd/components/generate.py b/shortfin/python/shortfin_apps/sd/components/generate.py
index ebb5ea08a..1afa73d5e 100644
--- a/shortfin/python/shortfin_apps/sd/components/generate.py
+++ b/shortfin/python/shortfin_apps/sd/components/generate.py
@@ -20,7 +20,7 @@
from .service import GenerateService
from .metrics import measure
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("shortfin-sd.generate")
class GenerateImageProcess(sf.Process):
diff --git a/shortfin/python/shortfin_apps/sd/components/io_struct.py b/shortfin/python/shortfin_apps/sd/components/io_struct.py
index d2952a818..d1d9cf41a 100644
--- a/shortfin/python/shortfin_apps/sd/components/io_struct.py
+++ b/shortfin/python/shortfin_apps/sd/components/io_struct.py
@@ -72,3 +72,10 @@ def post_init(self):
raise ValueError("The rid should be a list.")
if self.output_type is None:
self.output_type = ["base64"] * self.num_output_images
+ # Temporary restrictions
+ heights = [self.height] if not isinstance(self.height, list) else self.height
+ widths = [self.width] if not isinstance(self.width, list) else self.width
+ if any(dim != 1024 for dim in [*heights, *widths]):
+ raise ValueError(
+ "Currently, only 1024x1024 output image size is supported."
+ )
diff --git a/shortfin/python/shortfin_apps/sd/components/manager.py b/shortfin/python/shortfin_apps/sd/components/manager.py
index c52cf62f7..ea29b69a4 100644
--- a/shortfin/python/shortfin_apps/sd/components/manager.py
+++ b/shortfin/python/shortfin_apps/sd/components/manager.py
@@ -8,35 +8,24 @@
import threading
import shortfin as sf
+from shortfin.interop.support.device_setup import get_selected_devices
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("shortfin-sd.manager")
class SystemManager:
- def __init__(self, device="local-task", device_ids=None):
+ def __init__(self, device="local-task", device_ids=None, async_allocs=True):
if any(x in device for x in ["local-task", "cpu"]):
self.ls = sf.host.CPUSystemBuilder().create_system()
elif any(x in device for x in ["hip", "amdgpu"]):
- sc_query = sf.amdgpu.SystemBuilder()
- available = sc_query.available_devices
- selected = []
- if device_ids is not None:
- if len(device_ids) >= len(available):
- raise ValueError(
- f"Requested more device ids ({device_ids}) than available ({available})."
- )
- for did in device_ids:
- if did in available:
- selected.append(did)
- elif isinstance(did, int):
- selected.append(available[did])
- else:
- raise ValueError(f"Device id {did} could not be parsed.")
- else:
- selected = available
- sb = sf.amdgpu.SystemBuilder(amdgpu_visible_devices=";".join(selected))
+ sb = sf.SystemBuilder(
+ system_type="amdgpu", amdgpu_async_allocations=async_allocs
+ )
+ if device_ids:
+ sb.visible_devices = sb.available_devices
+ sb.visible_devices = get_selected_devices(sb, device_ids)
self.ls = sb.create_system()
- logger.info(f"Created local system with {self.ls.device_names} devices")
+ logging.info(f"Created local system with {self.ls.device_names} devices")
# TODO: Come up with an easier bootstrap thing than manually
# running a thread.
self.t = threading.Thread(target=lambda: self.ls.run(self.run()))
diff --git a/shortfin/python/shortfin_apps/sd/components/messages.py b/shortfin/python/shortfin_apps/sd/components/messages.py
index 88eb28ff4..6ae716bad 100644
--- a/shortfin/python/shortfin_apps/sd/components/messages.py
+++ b/shortfin/python/shortfin_apps/sd/components/messages.py
@@ -13,7 +13,7 @@
from .io_struct import GenerateReqInput
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("shortfin-sd.messages")
class InferencePhase(Enum):
diff --git a/shortfin/python/shortfin_apps/sd/components/metrics.py b/shortfin/python/shortfin_apps/sd/components/metrics.py
index 6d3c1aa8b..a1811beea 100644
--- a/shortfin/python/shortfin_apps/sd/components/metrics.py
+++ b/shortfin/python/shortfin_apps/sd/components/metrics.py
@@ -1,10 +1,16 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
import logging
import time
import asyncio
from typing import Callable, Any
import functools
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("shortfin-sd.metrics")
def measure(fn=None, type="exec", task=None, num_items=None, freq=1, label="items"):
diff --git a/shortfin/python/shortfin_apps/sd/components/service.py b/shortfin/python/shortfin_apps/sd/components/service.py
index af8423a11..ad3fd9404 100644
--- a/shortfin/python/shortfin_apps/sd/components/service.py
+++ b/shortfin/python/shortfin_apps/sd/components/service.py
@@ -24,7 +24,8 @@
from .metrics import measure
-logger = logging.getLogger(__name__)
+logger = logging.getLogger("shortfin-sd.service")
+logger.setLevel(logging.DEBUG)
prog_isolations = {
"none": sf.ProgramIsolation.NONE,
@@ -62,12 +63,20 @@ def __init__(
self.inference_parameters: dict[str, list[sf.BaseProgramParameters]] = {}
self.inference_modules: dict[str, sf.ProgramModule] = {}
self.inference_functions: dict[str, dict[str, sf.ProgramFunction]] = {}
- self.inference_programs: dict[str, sf.Program] = {}
+ self.inference_programs: dict[int, dict[str, sf.Program]] = {}
self.trace_execution = trace_execution
self.show_progress = show_progress
+
+ self.prog_isolation = prog_isolations[prog_isolation]
+
self.workers_per_device = workers_per_device
self.fibers_per_device = fibers_per_device
- self.prog_isolation = prog_isolations[prog_isolation]
+ if fibers_per_device % workers_per_device != 0:
+ raise ValueError(
+ "Currently, fibers_per_device must be divisible by workers_per_device"
+ )
+ self.fibers_per_worker = int(fibers_per_device / workers_per_device)
+
self.workers = []
self.fibers = []
self.fiber_status = []
@@ -81,7 +90,9 @@ def __init__(
)
self.fibers.append(fiber)
self.fiber_status.append(0)
-
+ for idx in range(len(self.workers)):
+ self.inference_programs[idx] = {}
+ self.inference_functions[idx] = {}
# Scope dependent objects.
self.batcher = BatcherProcess(self)
@@ -108,52 +119,61 @@ def load_inference_parameters(
self.inference_parameters[component].append(p)
def start(self):
- for fiber in self.fibers:
- for component in self.inference_modules:
- component_modules = [
- sf.ProgramModule.parameter_provider(
- self.sysman.ls, *self.inference_parameters.get(component, [])
- ),
- *self.inference_modules[component],
- ]
- self.inference_programs[component] = sf.Program(
+ # Initialize programs.
+ for component in self.inference_modules:
+ component_modules = [
+ sf.ProgramModule.parameter_provider(
+ self.sysman.ls, *self.inference_parameters.get(component, [])
+ ),
+ *self.inference_modules[component],
+ ]
+
+ for worker_idx, worker in enumerate(self.workers):
+ worker_devices = self.fibers[
+ worker_idx * (self.fibers_per_worker)
+ ].raw_devices
+ logger.info(
+ f"Loading inference program: {component}, worker index: {worker_idx}, device: {worker_devices}"
+ )
+ self.inference_programs[worker_idx][component] = sf.Program(
modules=component_modules,
- devices=fiber.raw_devices,
+ devices=worker_devices,
isolation=self.prog_isolation,
trace_execution=self.trace_execution,
)
-
- # TODO: export vmfbs with multiple batch size entrypoints
-
- self.inference_functions["encode"] = {}
- for bs in self.model_params.clip_batch_sizes:
- self.inference_functions["encode"][bs] = self.inference_programs["clip"][
- f"{self.model_params.clip_module_name}.encode_prompts"
- ]
-
- self.inference_functions["denoise"] = {}
- for bs in self.model_params.unet_batch_sizes:
- self.inference_functions["denoise"][bs] = {
- "unet": self.inference_programs["unet"][
- f"{self.model_params.unet_module_name}.{self.model_params.unet_fn_name}"
- ],
- "init": self.inference_programs["scheduler"][
- f"{self.model_params.scheduler_module_name}.run_initialize"
- ],
- "scale": self.inference_programs["scheduler"][
- f"{self.model_params.scheduler_module_name}.run_scale"
- ],
- "step": self.inference_programs["scheduler"][
- f"{self.model_params.scheduler_module_name}.run_step"
- ],
- }
-
- self.inference_functions["decode"] = {}
- for bs in self.model_params.vae_batch_sizes:
- self.inference_functions["decode"][bs] = self.inference_programs["vae"][
- f"{self.model_params.vae_module_name}.decode"
- ]
-
+ logger.info("Program loaded.")
+
+ for worker_idx, worker in enumerate(self.workers):
+ self.inference_functions[worker_idx]["encode"] = {}
+ for bs in self.model_params.clip_batch_sizes:
+ self.inference_functions[worker_idx]["encode"][
+ bs
+ ] = self.inference_programs[worker_idx]["clip"][
+ f"{self.model_params.clip_module_name}.encode_prompts"
+ ]
+ self.inference_functions[worker_idx]["denoise"] = {}
+ for bs in self.model_params.unet_batch_sizes:
+ self.inference_functions[worker_idx]["denoise"][bs] = {
+ "unet": self.inference_programs[worker_idx]["unet"][
+ f"{self.model_params.unet_module_name}.{self.model_params.unet_fn_name}"
+ ],
+ "init": self.inference_programs[worker_idx]["scheduler"][
+ f"{self.model_params.scheduler_module_name}.run_initialize"
+ ],
+ "scale": self.inference_programs[worker_idx]["scheduler"][
+ f"{self.model_params.scheduler_module_name}.run_scale"
+ ],
+ "step": self.inference_programs[worker_idx]["scheduler"][
+ f"{self.model_params.scheduler_module_name}.run_step"
+ ],
+ }
+ self.inference_functions[worker_idx]["decode"] = {}
+ for bs in self.model_params.vae_batch_sizes:
+ self.inference_functions[worker_idx]["decode"][
+ bs
+ ] = self.inference_programs[worker_idx]["vae"][
+ f"{self.model_params.vae_module_name}.decode"
+ ]
self.batcher.launch()
def shutdown(self):
@@ -166,6 +186,8 @@ def __repr__(self):
params = [
f" {key} : {value}" for key, value in self.inference_parameters.items()
]
+ # For python 3.11 since we can't have \ in the f"" expression.
+ new_line = "\n"
return (
f"ServiceManager("
f"\n INFERENCE DEVICES : \n"
@@ -176,9 +198,9 @@ def __repr__(self):
f" fibers per device : {self.fibers_per_device}\n"
f" program isolation mode : {self.prog_isolation}\n"
f"\n INFERENCE MODULES : \n"
- f"{'\n'.join(modules)}\n"
+ f"{new_line.join(modules)}\n"
f"\n INFERENCE PARAMETERS : \n"
- f"{'\n'.join(params)}\n"
+ f"{new_line.join(params)}\n"
f")"
)
@@ -193,8 +215,8 @@ class BatcherProcess(sf.Process):
into batches.
"""
- STROBE_SHORT_DELAY = 0.1
- STROBE_LONG_DELAY = 0.25
+ STROBE_SHORT_DELAY = 0.5
+ STROBE_LONG_DELAY = 1
def __init__(self, service: GenerateService):
super().__init__(fiber=service.fibers[0])
@@ -320,7 +342,11 @@ def __init__(
):
super().__init__(fiber=service.fibers[index])
self.service = service
- self.worker_index = index
+ self.fiber_index = index
+ self.worker_index = int(
+ (index - index % self.service.fibers_per_worker)
+ / self.service.fibers_per_worker
+ )
self.exec_requests: list[InferenceExecRequest] = []
@measure(type="exec", task="inference process")
@@ -333,9 +359,8 @@ async def run(self):
logger.error("Executor process recieved disjoint batch.")
phase = req.phase
phases = self.exec_requests[0].phases
-
req_count = len(self.exec_requests)
- device0 = self.service.fibers[self.worker_index].device(0)
+ device0 = self.service.fibers[self.fiber_index].device(0)
if phases[InferencePhase.PREPARE]["required"]:
await self._prepare(device=device0, requests=self.exec_requests)
if phases[InferencePhase.ENCODE]["required"]:
@@ -346,11 +371,11 @@ async def run(self):
await self._decode(device=device0, requests=self.exec_requests)
if phases[InferencePhase.POSTPROCESS]["required"]:
await self._postprocess(device=device0, requests=self.exec_requests)
-
+ await device0
for i in range(req_count):
req = self.exec_requests[i]
req.done.set_success()
- self.service.fiber_status[self.worker_index] = 0
+ self.service.fiber_status[self.fiber_index] = 0
except Exception:
logger.exception("Fatal error in image generation")
@@ -400,10 +425,13 @@ async def _prepare(self, device, requests):
async def _encode(self, device, requests):
req_bs = len(requests)
-
- entrypoints = self.service.inference_functions["encode"]
+ entrypoints = self.service.inference_functions[self.worker_index]["encode"]
+ if req_bs not in list(entrypoints.keys()):
+ for request in requests:
+ await self._encode(device, [request])
+ return
for bs, fn in entrypoints.items():
- if bs >= req_bs:
+ if bs == req_bs:
break
# Prepare tokenized input ids for CLIP inference
@@ -440,6 +468,7 @@ async def _encode(self, device, requests):
fn,
"".join([f"\n {i}: {ary.shape}" for i, ary in enumerate(clip_inputs)]),
)
+ await device
pe, te = await fn(*clip_inputs, fiber=self.fiber)
for i in range(req_bs):
@@ -454,9 +483,13 @@ async def _denoise(self, device, requests):
step_count = requests[0].steps
cfg_mult = 2 if self.service.model_params.cfg_mode else 1
# Produce denoised latents
- entrypoints = self.service.inference_functions["denoise"]
+ entrypoints = self.service.inference_functions[self.worker_index]["denoise"]
+ if req_bs not in list(entrypoints.keys()):
+ for request in requests:
+ await self._denoise(device, [request])
+ return
for bs, fns in entrypoints.items():
- if bs >= req_bs:
+ if bs == req_bs:
break
# Get shape of batched latents.
@@ -590,9 +623,13 @@ async def _denoise(self, device, requests):
async def _decode(self, device, requests):
req_bs = len(requests)
# Decode latents to images
- entrypoints = self.service.inference_functions["decode"]
+ entrypoints = self.service.inference_functions[self.worker_index]["decode"]
+ if req_bs not in list(entrypoints.keys()):
+ for request in requests:
+ await self._decode(device, [request])
+ return
for bs, fn in entrypoints.items():
- if bs >= req_bs:
+ if bs == req_bs:
break
latents_shape = [
diff --git a/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json
new file mode 100644
index 000000000..002f43f0e
--- /dev/null
+++ b/shortfin/python/shortfin_apps/sd/examples/sdxl_request_bs32.json
@@ -0,0 +1,57 @@
+{
+ "prompt": [
+ " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, amateur photo, animal",
+ " a cat under the snow with blue eyes, covered by snow, cinematic style, wide shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo",
+ " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, amateur photo, animal",
+ " a cat under the snow with blue eyes, covered by snow, cinematic style, wide shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo",
+ " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, amateur photo, animal",
+ " a cat under the snow with blue eyes, covered by snow, cinematic style, wide shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo",
+ " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with green eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with red eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by ice, cinematic style, medium shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by snow, cartoon style, medium shot, professional photo, animal",
+ " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo"
+ ],
+ "neg_prompt": [
+ "Watermark, blurry, oversaturated, low resolution, pollution"
+ ],
+ "height": [
+ 1024
+ ],
+ "width": [
+ 1024
+ ],
+ "steps": [
+ 20
+ ],
+ "guidance_scale": [
+ 7.5
+ ],
+ "seed": [
+ 0
+ ],
+ "output_type": [
+ "base64"
+ ]
+}
diff --git a/shortfin/python/shortfin_apps/sd/examples/send_request.py b/shortfin/python/shortfin_apps/sd/examples/send_request.py
deleted file mode 100644
index 94fae9659..000000000
--- a/shortfin/python/shortfin_apps/sd/examples/send_request.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import json
-import requests
-import argparse
-import base64
-
-from datetime import datetime as dt
-from PIL import Image
-
-sample_request = {
- "prompt": [
- " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
- ],
- "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"],
- "height": [1024],
- "width": [1024],
- "steps": [20],
- "guidance_scale": [7.5],
- "seed": [0],
- "output_type": ["base64"],
- "rid": ["string"],
-}
-
-
-def bytes_to_img(bytes, idx=0, width=1024, height=1024):
- timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
- image = Image.frombytes(
- mode="RGB", size=(width, height), data=base64.b64decode(bytes)
- )
- image.save(f"shortfin_sd_output_{timestamp}_{idx}.png")
- print(f"Saved to shortfin_sd_output_{timestamp}_{idx}.png")
-
-
-def send_json_file(file_path):
- # Read the JSON file
- try:
- if file_path == "default":
- data = sample_request
- else:
- with open(file_path, "r") as json_file:
- data = json.load(json_file)
- except Exception as e:
- print(f"Error reading the JSON file: {e}")
- return
-
- # Send the data to the /generate endpoint
- try:
- response = requests.post("http://0.0.0.0:8000/generate", json=data)
- response.raise_for_status() # Raise an error for bad responses
- print("Saving response as image...")
- timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
- request = json.loads(response.request.body.decode("utf-8"))
- for idx, item in enumerate(response.json()["images"]):
- width = get_batched(request, "width", idx)
- height = get_batched(request, "height", idx)
- bytes_to_img(item.encode("utf-8"), idx, width, height)
-
- except requests.exceptions.RequestException as e:
- print(f"Error sending the request: {e}")
-
-
-def get_batched(request, arg, idx):
- if isinstance(request[arg], list):
- if len(request[arg]) == 1:
- indexed = request[arg][0]
- else:
- indexed = request[arg][idx]
- else:
- indexed = request[arg]
- return indexed
-
-
-if __name__ == "__main__":
- p = argparse.ArgumentParser()
- p.add_argument("--file", type=str, default="default")
- args = p.parse_args()
- send_json_file(args.file)
diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py
index 849337900..9cd624241 100644
--- a/shortfin/python/shortfin_apps/sd/server.py
+++ b/shortfin/python/shortfin_apps/sd/server.py
@@ -15,10 +15,6 @@
import copy
import subprocess
-from iree.build import *
-
-import uvicorn.logging
-
# Import first as it does dep checking and reporting.
from shortfin.interop.fastapi import FastAPIResponder
@@ -27,7 +23,6 @@
from fastapi import FastAPI, Request, Response
import uvicorn
-
from .components.generate import ClientGenerateBatchProcess
from .components.config_struct import ModelParams
from .components.io_struct import GenerateReqInput
@@ -36,10 +31,12 @@
from .components.tokenizer import Tokenizer
from .components.builders import sdxl
+from shortfin.support.logging_setup import native_handler, configure_main_logger
-from shortfin.support.logging_setup import configure_main_logger
-
-logger = configure_main_logger("server")
+logger = logging.getLogger("shortfin-sd")
+logger.addHandler(native_handler)
+logger.setLevel(logging.INFO)
+logger.propagate = False
THIS_DIR = Path(__file__).resolve().parent
@@ -88,7 +85,8 @@ async def generate_request(gen_req: GenerateReqInput, request: Request):
def configure(args) -> SystemManager:
# Setup system (configure devices, etc).
- sysman = SystemManager(args.device, args.device_ids)
+ model_config, topology_config, flagfile, tuning_spec, args = get_configs(args)
+ sysman = SystemManager(args.device, args.device_ids, args.amdgpu_async_allocations)
# Setup each service we are hosting.
tokenizers = []
@@ -96,7 +94,9 @@ def configure(args) -> SystemManager:
subfolder = f"tokenizer_{idx + 1}" if idx > 0 else "tokenizer"
tokenizers.append(Tokenizer.from_pretrained(tok_name, subfolder))
- model_params = ModelParams.load_json(args.model_config)
+ model_params = ModelParams.load_json(model_config)
+ vmfbs, params = get_modules(args, model_config, flagfile, tuning_spec)
+
sm = GenerateService(
name="sd",
sysman=sysman,
@@ -108,7 +108,6 @@ def configure(args) -> SystemManager:
show_progress=args.show_progress,
trace_execution=args.trace_execution,
)
- vmfbs, params = get_modules(args)
for key, vmfblist in vmfbs.items():
for vmfb in vmfblist:
sm.load_inference_module(vmfb, component=key)
@@ -118,14 +117,80 @@ def configure(args) -> SystemManager:
return sysman
-def get_modules(args):
+def get_configs(args):
+ # Returns one set of config artifacts.
+ modelname = "sdxl"
+ model_config = args.model_config if args.model_config else None
+ topology_config = None
+ tuning_spec = None
+ flagfile = args.flagfile if args.flagfile else None
+ topology_inp = args.topology if args.topology else "spx_single"
+ cfg_builder_args = [
+ sys.executable,
+ "-m",
+ "iree.build",
+ os.path.join(THIS_DIR, "components", "config_artifacts.py"),
+ f"--target={args.target}",
+ f"--output-dir={args.artifacts_dir}",
+ f"--model={modelname}",
+ f"--topology={topology_inp}",
+ ]
+ outs = subprocess.check_output(cfg_builder_args).decode()
+ outs_paths = outs.splitlines()
+ for i in outs_paths:
+ if "sdxl_config" in i and not args.model_config:
+ model_config = i
+ elif "topology" in i and args.topology:
+ topology_config = i
+ elif "flagfile" in i and not args.flagfile:
+ flagfile = i
+ elif "attention_and_matmul_spec" in i and args.use_tuned:
+ tuning_spec = i
+
+ if args.use_tuned and args.tuning_spec:
+ tuning_spec = os.path.abspath(args.tuning_spec)
+
+ if topology_config:
+ with open(topology_config, "r") as f:
+ contents = [line.rstrip() for line in f]
+ for spec in contents:
+ if "--" in spec:
+ arglist = spec.strip("--").split("=")
+ arg = arglist[0]
+ if len(arglist) > 2:
+ value = arglist[1:]
+ for val in value:
+ try:
+ val = int(val)
+ except ValueError:
+ continue
+ elif len(arglist) == 2:
+ value = arglist[-1]
+ try:
+ value = int(value)
+ except ValueError:
+ continue
+ else:
+ # It's a boolean arg.
+ value = True
+ setattr(args, arg, value)
+ else:
+ # It's an env var.
+ arglist = spec.split("=")
+ os.environ[arglist[0]] = arglist[1]
+
+ return model_config, topology_config, flagfile, tuning_spec, args
+
+
+def get_modules(args, model_config, flagfile, td_spec):
+ # TODO: Move this out of server entrypoint
vmfbs = {"clip": [], "unet": [], "vae": [], "scheduler": []}
params = {"clip": [], "unet": [], "vae": []}
model_flags = copy.deepcopy(vmfbs)
model_flags["all"] = args.compile_flags
- if args.flagfile:
- with open(args.flagfile, "r") as f:
+ if flagfile:
+ with open(flagfile, "r") as f:
contents = [line.rstrip() for line in f]
flagged_model = "all"
for elem in contents:
@@ -134,6 +199,10 @@ def get_modules(args):
flagged_model = elem
else:
model_flags[flagged_model].extend([elem])
+ if td_spec:
+ model_flags["unet"].extend(
+ [f"--iree-codegen-transform-dialect-library={td_spec}"]
+ )
filenames = []
for modelname in vmfbs.keys():
@@ -143,7 +212,7 @@ def get_modules(args):
"-m",
"iree.build",
os.path.join(THIS_DIR, "components", "builders.py"),
- f"--model-json={args.model_config}",
+ f"--model-json={model_config}",
f"--target={args.target}",
f"--splat={args.splat}",
f"--build-preference={args.build_preference}",
@@ -151,11 +220,9 @@ def get_modules(args):
f"--model={modelname}",
f"--iree-hal-target-device={args.device}",
f"--iree-hip-target={args.target}",
- f"--iree-compile-extra-args={" ".join(ireec_args)}",
+ f"--iree-compile-extra-args={' '.join(ireec_args)}",
]
- print("BUILDER INPUT:\n", " \ \n ".join(builder_args))
output = subprocess.check_output(builder_args).decode()
- print("OUTPUT:", output)
output_paths = output.splitlines()
filenames.extend(output_paths)
@@ -170,15 +237,11 @@ def get_modules(args):
def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
+ from pathlib import Path
+
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
- parser.add_argument(
- "--root-path",
- type=str,
- default=None,
- help="Root path to use for installing behind path based proxy.",
- )
parser.add_argument(
"--timeout-keep-alive", type=int, default=5, help="Keep alive timeout"
)
@@ -199,7 +262,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
)
parser.add_argument(
"--device_ids",
- type=int,
+ type=str,
nargs="*",
default=None,
help="Device IDs visible to the system builder. Defaults to None (full visibility). Can be an index or a sf device id like amdgpu:0:0@0",
@@ -217,8 +280,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
parser.add_argument(
"--model_config",
type=Path,
- required=True,
- help="Path to the model config file",
+ help="Path to the model config file. If None, defaults to i8 punet, batch size 1",
)
parser.add_argument(
"--workers_per_device",
@@ -239,9 +301,6 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
choices=["per_fiber", "per_call", "none"],
help="Concurrency control -- How to isolate programs.",
)
- parser.add_argument(
- "--log_level", type=str, default="error", choices=["info", "debug", "error"]
- )
parser.add_argument(
"--show_progress",
action="store_true",
@@ -252,6 +311,11 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
action="store_true",
help="Enable tracing of program modules.",
)
+ parser.add_argument(
+ "--amdgpu_async_allocations",
+ action="store_true",
+ help="Enable asynchronous allocations for amdgpu device contexts.",
+ )
parser.add_argument(
"--splat",
action="store_true",
@@ -278,21 +342,36 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
)
parser.add_argument(
"--artifacts_dir",
- type=str,
- default="",
+ type=Path,
+ default=None,
help="Path to local artifacts cache.",
)
- log_levels = {
- "info": logging.INFO,
- "debug": logging.DEBUG,
- "error": logging.ERROR,
- }
+ parser.add_argument(
+ "--tuning_spec",
+ type=str,
+ default=None,
+ help="Path to transform dialect spec if compiling an executable with tunings.",
+ )
+ parser.add_argument(
+ "--topology",
+ type=str,
+ default=None,
+ choices=["spx_single", "cpx_single", "spx_multi", "cpx_multi"],
+ help="Use one of four known performant preconfigured device/fiber topologies.",
+ )
+ parser.add_argument(
+ "--use_tuned",
+ type=int,
+ default=1,
+ help="Use tunings for attention and matmul ops. 0 to disable.",
+ )
args = parser.parse_args(argv)
+ if not args.artifacts_dir:
+ home = Path.home()
+ artdir = home / ".cache" / "shark"
+ args.artifacts_dir = str(artdir)
- log_level = log_levels[args.log_level]
- logger.setLevel(log_level)
- logger.addHandler(logging.FileHandler("shortfin_sd.log"))
global sysman
sysman = configure(args)
uvicorn.run(
@@ -305,14 +384,31 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
if __name__ == "__main__":
+ logging.root.setLevel(logging.INFO)
main(
sys.argv[1:],
# Make logging defer to the default shortfin logging config.
log_config={
"version": 1,
"disable_existing_loggers": False,
- "formatters": {},
- "handlers": {},
- "loggers": {},
+ "formatters": {
+ "default": {
+ "format": "%(asctime)s - %(levelname)s - %(message)s",
+ "datefmt": "%Y-%m-%d %H:%M:%S",
+ },
+ },
+ "handlers": {
+ "console": {
+ "class": "logging.StreamHandler",
+ "formatter": "default",
+ },
+ },
+ "loggers": {
+ "uvicorn": {
+ "handlers": ["console"],
+ "level": "INFO",
+ "propagate": False,
+ },
+ },
},
)
diff --git a/shortfin/python/shortfin_apps/sd/simple_client.py b/shortfin/python/shortfin_apps/sd/simple_client.py
new file mode 100644
index 000000000..bc0f10655
--- /dev/null
+++ b/shortfin/python/shortfin_apps/sd/simple_client.py
@@ -0,0 +1,228 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import requests
+import argparse
+import base64
+import time
+import asyncio
+import aiohttp
+import sys
+import os
+
+from datetime import datetime as dt
+from PIL import Image
+
+sample_request = {
+ "prompt": [
+ " a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal",
+ ],
+ "neg_prompt": ["Watermark, blurry, oversaturated, low resolution, pollution"],
+ "height": [1024],
+ "width": [1024],
+ "steps": [20],
+ "guidance_scale": [7.5],
+ "seed": [0],
+ "output_type": ["base64"],
+ "rid": ["string"],
+}
+
+
+def bytes_to_img(bytes, outputdir, idx=0, width=1024, height=1024):
+ timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
+ image = Image.frombytes(
+ mode="RGB", size=(width, height), data=base64.b64decode(bytes)
+ )
+ if not os.path.isdir(outputdir):
+ os.mkdir(outputdir)
+ im_path = os.path.join(outputdir, f"shortfin_sd_output_{timestamp}_{idx}.png")
+ image.save(im_path)
+ print(f"Saved to {im_path}")
+
+
+def get_batched(request, arg, idx):
+ if isinstance(request[arg], list):
+ # some args are broadcasted to each prompt, hence overriding idx for single-item entries
+ if len(request[arg]) == 1:
+ indexed = request[arg][0]
+ else:
+ indexed = request[arg][idx]
+ else:
+ indexed = request[arg]
+ return indexed
+
+
+async def send_request(session, rep, args, data):
+ print("Sending request batch #", rep)
+ url = f"http://0.0.0.0:{args.port}/generate"
+ start = time.time()
+ async with session.post(url, json=data) as response:
+ end = time.time()
+ # Check if the response was successful
+ if response.status == 200:
+ response.raise_for_status() # Raise an error for bad responses
+ timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
+ res_json = await response.json(content_type=None)
+ if args.save:
+ for idx, item in enumerate(res_json["images"]):
+ width = get_batched(data, "width", idx)
+ height = get_batched(data, "height", idx)
+ print("Saving response as image...")
+ bytes_to_img(
+ item.encode("utf-8"), args.outputdir, idx, width, height
+ )
+ latency = end - start
+ print("Responses processed.")
+ return latency, len(data["prompt"])
+ else:
+ print(f"Error: Received {response.status} from server")
+ raise Exception
+
+
+async def static(args):
+ # Create an aiohttp session for sending requests
+ async with aiohttp.ClientSession() as session:
+ pending = []
+ latencies = []
+ sample_counts = []
+ # Read the JSON file if supplied. Otherwise, get user input.
+ try:
+ if not args.file:
+ data = sample_request
+ else:
+ with open(args.file, "r") as json_file:
+ data = json.load(json_file)
+ except Exception as e:
+ print(f"Error reading the JSON file: {e}")
+ return
+ data["prompt"] = (
+ [data["prompt"]] if isinstance(data["prompt"], str) else data["prompt"]
+ )
+ start = time.time()
+
+ async for i in async_range(args.reps):
+ pending.append(asyncio.create_task(send_request(session, i, args, data)))
+ await asyncio.sleep(1) # Wait for 1 second before sending the next request
+ while pending:
+ done, pending = await asyncio.wait(
+ pending, return_when=asyncio.ALL_COMPLETED
+ )
+ for task in done:
+ latency, num_samples = await task
+ latencies.append(latency)
+ sample_counts.append(num_samples)
+ end = time.time()
+ if not any([i is None for i in [latencies, sample_counts]]):
+ total_num_samples = sum(sample_counts)
+ sps = str(total_num_samples / (end - start))
+ # Until we have better measurements, don't report the throughput that includes saving images.
+ if not args.save:
+ print(f"Average throughput: {sps} samples per second")
+ else:
+ raise ValueError("Received error response from server.")
+
+
+async def interactive(args):
+ # Create an aiohttp session for sending requests
+ async with aiohttp.ClientSession() as session:
+ pending = []
+ latencies = []
+ sample_counts = []
+ # Read the JSON file if supplied. Otherwise, get user input.
+ try:
+ if not args.file:
+ data = sample_request
+ else:
+ with open(args.file, "r") as json_file:
+ data = json.load(json_file)
+ except Exception as e:
+ print(f"Error reading the JSON file: {e}")
+ return
+ data["prompt"] = (
+ [data["prompt"]] if isinstance(data["prompt"], str) else data["prompt"]
+ )
+ while True:
+ prompt = await ainput("Enter a prompt: ")
+ data["prompt"] = [prompt]
+ data["steps"] = [args.steps]
+ print("Sending request with prompt: ", data["prompt"])
+
+ async for i in async_range(args.reps):
+ pending.append(
+ asyncio.create_task(send_request(session, i, args, data))
+ )
+ await asyncio.sleep(
+ 1
+ ) # Wait for 1 second before sending the next request
+ while pending:
+ done, pending = await asyncio.wait(
+ pending, return_when=asyncio.ALL_COMPLETED
+ )
+ for task in done:
+ latency, num_samples = await task
+ pending = []
+ if any([i is None for i in [latencies, sample_counts]]):
+ raise ValueError("Received error response from server.")
+
+
+async def ainput(prompt: str) -> str:
+ return await asyncio.to_thread(input, f"{prompt} ")
+
+
+async def async_range(count):
+ for i in range(count):
+ yield (i)
+ await asyncio.sleep(0.0)
+
+
+def main(argv):
+ p = argparse.ArgumentParser()
+ p.add_argument(
+ "--file",
+ type=str,
+ default=None,
+ help="A non-default request to send to the server.",
+ )
+ p.add_argument(
+ "--reps",
+ type=int,
+ default=1,
+ help="Number of times to duplicate each request in one second intervals.",
+ )
+ p.add_argument(
+ "--save",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help="Save images. To disable, use --no-save",
+ )
+ p.add_argument(
+ "--outputdir",
+ type=str,
+ default="gen_imgs",
+ help="Directory to which images get saved.",
+ )
+ p.add_argument("--port", type=str, default="8000", help="Server port")
+ p.add_argument(
+ "--steps",
+ type=int,
+ default="20",
+ help="Number of inference steps. More steps usually means a better image. Interactive only.",
+ )
+ p.add_argument(
+ "--interactive",
+ action="store_true",
+ help="Start as an example CLI client instead of sending static requests.",
+ )
+ args = p.parse_args()
+ if args.interactive:
+ asyncio.run(interactive(args))
+ else:
+ asyncio.run(static(args))
+
+
+if __name__ == "__main__":
+ main(sys.argv)
diff --git a/shortfin/requirements-iree-compiler.txt b/shortfin/requirements-iree-compiler.txt
index 7aea80277..ec033c57c 100644
--- a/shortfin/requirements-iree-compiler.txt
+++ b/shortfin/requirements-iree-compiler.txt
@@ -1,4 +1,4 @@
# Keep in sync with "ref: iree-" in .github/workflows/* and GIT_TAG in CMakeLists.txt
-f https://iree.dev/pip-release-links.html
-iree-base-compiler==2.9.0rc20241108
-iree-base-runtime==2.9.0rc20241108
+iree-base-compiler==3.0.0rc20241115
+iree-base-runtime==3.0.0rc20241115
diff --git a/shortfin/setup.py b/shortfin/setup.py
index 94aae4a55..cf3762950 100644
--- a/shortfin/setup.py
+++ b/shortfin/setup.py
@@ -141,8 +141,8 @@ def copy_extensions_to_source(self, *args, **kwargs):
# Setup and get version information.
-VERSION_INFO_FILE = os.path.join(REL_SOURCE_DIR, "version_info.json")
-VERSION_INFO_RC_FILE = os.path.join(REL_SOURCE_DIR, "version_info_rc.json")
+VERSION_FILE = os.path.join(REL_SOURCE_DIR, "version.json")
+VERSION_FILE_LOCAL = os.path.join(REL_SOURCE_DIR, "version_local.json")
def load_version_info(version_file):
@@ -151,10 +151,10 @@ def load_version_info(version_file):
try:
- version_info = load_version_info(VERSION_INFO_RC_FILE)
+ version_info = load_version_info(VERSION_FILE_LOCAL)
except FileNotFoundError:
- print("version_info_rc.json not found. Default to dev build")
- version_info = load_version_info(VERSION_INFO_FILE)
+ print("version_local.json not found. Default to dev build")
+ version_info = load_version_info(VERSION_FILE)
PACKAGE_VERSION = version_info.get("package-version")
print(f"Using PACKAGE_VERSION: '{PACKAGE_VERSION}'")
@@ -359,10 +359,7 @@ def populate_built_package(abs_dir):
print(f"Found shortfin packages: {packages}")
setup(
- name="shortfin",
version=f"{PACKAGE_VERSION}",
- description="Shortfin native library implementation",
- author="SHARK Authors",
packages=packages,
zip_safe=False,
package_dir=combine_dicts(
diff --git a/shortfin/src/CMakeLists.txt b/shortfin/src/CMakeLists.txt
index 9a955f742..53d801f36 100644
--- a/shortfin/src/CMakeLists.txt
+++ b/shortfin/src/CMakeLists.txt
@@ -4,6 +4,10 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Any definitions which must be reflected on the public library must be added
+# to this library.
+add_library(shortfin_public_defs INTERFACE)
+
add_subdirectory(shortfin)
# Common definitions exported from both static and dynamic libraries.
@@ -28,6 +32,7 @@ shortfin_public_library(
shortfin_systems_factory
${_SHORTFIN_LIB_OPTIONAL_COMPONENTS}
USAGE_DEPS
+ shortfin_public_defs
spdlog::spdlog
fmt::fmt
xtensor
diff --git a/shortfin/src/shortfin/array/array.cc b/shortfin/src/shortfin/array/array.cc
index 11961b449..882e4ef39 100644
--- a/shortfin/src/shortfin/array/array.cc
+++ b/shortfin/src/shortfin/array/array.cc
@@ -64,6 +64,7 @@ mapping device_array::data_rw() { return storage_.map_read_write(); }
mapping device_array::data_w() { return storage_.map_write_discard(); }
std::optional device_array::map_memory_for_xtensor() {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyDeviceArray::map_memory_for_xtensor");
if (storage_.is_mappable_for_read_write()) {
return storage_.map_read_write();
} else if (storage_.is_mappable_for_read()) {
@@ -97,6 +98,7 @@ std::string device_array::to_s() const {
void device_array::AddAsInvocationArgument(
local::ProgramInvocation *inv, local::ProgramResourceBarrier barrier) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyDeviceArray::AddAsInvocationArgument");
auto dims_span = shape();
iree_hal_buffer_view_t *buffer_view;
SHORTFIN_THROW_IF_ERROR(iree_hal_buffer_view_create(
@@ -117,6 +119,7 @@ iree_vm_ref_type_t device_array::invocation_marshalable_type() {
device_array device_array::CreateFromInvocationResultRef(
local::ProgramInvocation *inv, iree::vm_opaque_ref ref) {
+ SHORTFIN_TRACE_SCOPE_NAMED("PyDeviceArray::CreateFromInvocationResultRef");
// We don't retain the buffer view in the device array, so just deref it
// vs stealing the ref.
iree_hal_buffer_view_t *bv = iree_hal_buffer_view_deref(*ref.get());
diff --git a/shortfin/src/shortfin/array/storage.cc b/shortfin/src/shortfin/array/storage.cc
index a30dbf450..ffbbd9ba2 100644
--- a/shortfin/src/shortfin/array/storage.cc
+++ b/shortfin/src/shortfin/array/storage.cc
@@ -43,6 +43,7 @@ storage storage::import_buffer(local::ScopedDevice &device,
storage storage::allocate_device(ScopedDevice &device,
iree_device_size_t allocation_size) {
+ SHORTFIN_TRACE_SCOPE_NAMED("storage::allocate_device");
if (!device.raw_device()) {
throw std::invalid_argument("Cannot allocate with a null device affinity");
}
@@ -63,6 +64,7 @@ storage storage::allocate_device(ScopedDevice &device,
storage storage::allocate_host(ScopedDevice &device,
iree_device_size_t allocation_size,
bool device_visible) {
+ SHORTFIN_TRACE_SCOPE_NAMED("storage::allocate_host");
if (!device.raw_device()) {
throw std::invalid_argument("Cannot allocate with a null device affinity");
}
@@ -207,6 +209,7 @@ std::string storage::formatted_buffer_usage() const {
void storage::AddAsInvocationArgument(local::ProgramInvocation *inv,
local::ProgramResourceBarrier barrier) {
+ SHORTFIN_TRACE_SCOPE_NAMED("storage::AddAsInvocationArgument");
iree::vm_opaque_ref ref;
*(&ref) = iree_hal_buffer_retain_ref(buffer_);
inv->AddArg(std::move(ref));
@@ -220,6 +223,7 @@ iree_vm_ref_type_t storage::invocation_marshalable_type() {
storage storage::CreateFromInvocationResultRef(local::ProgramInvocation *inv,
iree::vm_opaque_ref ref) {
+ SHORTFIN_TRACE_SCOPE_NAMED("storage::CreateFromInvocationResultRef");
// Steal the ref to one of our smart pointers.
// TODO: Should have an opaque_ref::release().
iree::hal_buffer_ptr buffer =
@@ -230,6 +234,7 @@ storage storage::CreateFromInvocationResultRef(local::ProgramInvocation *inv,
storage storage::ImportInvocationResultStorage(local::ProgramInvocation *inv,
iree::hal_buffer_ptr buffer) {
+ SHORTFIN_TRACE_SCOPE_NAMED("storage::ImportInvocationResultStorage");
local::ScopedDevice device =
local::ScopedDevice(*inv->fiber(), inv->device_selection());
auto imported_storage = storage::import_buffer(device, std::move(buffer));
@@ -251,6 +256,7 @@ storage storage::ImportInvocationResultStorage(local::ProgramInvocation *inv,
void storage::AddInvocationArgBarrier(local::ProgramInvocation *inv,
local::ProgramResourceBarrier barrier) {
+ SHORTFIN_TRACE_SCOPE_NAMED("storage::AddInvocationArgBarrier");
switch (barrier) {
case ProgramResourceBarrier::DEFAULT:
case ProgramResourceBarrier::READ:
diff --git a/shortfin/src/shortfin/array/storage.h b/shortfin/src/shortfin/array/storage.h
index b1d7eb6ad..2ea8f5aef 100644
--- a/shortfin/src/shortfin/array/storage.h
+++ b/shortfin/src/shortfin/array/storage.h
@@ -232,14 +232,14 @@ class typed_mapping {
span_type span() { return span_type(data(), size()); }
const_span_type span() const { return const_span_type(data(), size()); }
- span_type::iterator begin() { return span().begin(); }
- span_type::iterator end() { return span().end(); }
+ typename span_type::iterator begin() { return span().begin(); }
+ typename span_type::iterator end() { return span().end(); }
- const_span_type::iterator begin() const { return span().begin(); }
- const_span_type::iterator end() const { return span().end(); }
+ typename const_span_type::iterator begin() const { return span().begin(); }
+ typename const_span_type::iterator end() const { return span().end(); }
- const_span_type::iterator cbegin() const { return span().begin(); }
- const_span_type::iterator cend() const { return span().end(); }
+ typename const_span_type::iterator cbegin() const { return span().begin(); }
+ typename const_span_type::iterator cend() const { return span().end(); }
private:
mapping untyped_mapping_;
diff --git a/shortfin/src/shortfin/array/xtensor_bridge.cc b/shortfin/src/shortfin/array/xtensor_bridge.cc
index bd3753331..da350b71a 100644
--- a/shortfin/src/shortfin/array/xtensor_bridge.cc
+++ b/shortfin/src/shortfin/array/xtensor_bridge.cc
@@ -8,6 +8,7 @@
#include
+#include "shortfin/support/logging.h"
#include "xtl/xhalf_float.hpp"
namespace shortfin::array {
@@ -56,6 +57,7 @@ class typed_xt_methods final : public poly_xt_methods {
bool poly_xt_methods::inplace_new(uint8_t *inst_storage, DType dtype,
void *array_memory, size_t array_memory_size,
Dims &dims) {
+ SHORTFIN_TRACE_SCOPE_NAMED("array_xtensor_cast");
#define POLY_XT_CASE(et, cpp_type) \
case et: \
typed_xt_methods::concrete_inplace_new( \
diff --git a/shortfin/src/shortfin/local/program.cc b/shortfin/src/shortfin/local/program.cc
index 3fd41d87b..6ab1f47ae 100644
--- a/shortfin/src/shortfin/local/program.cc
+++ b/shortfin/src/shortfin/local/program.cc
@@ -75,6 +75,7 @@ std::string_view ProgramFunction::calling_convention() const {
ProgramInvocation::Ptr ProgramFunction::CreateInvocation(
std::shared_ptr fiber, std::optional isolation) {
+ SHORTFIN_TRACE_SCOPE_NAMED("ProgramFunction::CreateInvocation");
ProgramIsolation actual_isolation = isolation ? *isolation : isolation_;
// Low-overhead NONE isolation handling (saves some ref-count twiddling).
if (actual_isolation == ProgramIsolation::NONE) {
@@ -101,6 +102,7 @@ std::string ProgramFunction::to_s() const {
ProgramModule ProgramModule::Load(System &system,
const std::filesystem::path &path,
bool mmap) {
+ SHORTFIN_TRACE_SCOPE_NAMED("ProgramModule::Load");
iree::file_contents_ptr contents;
iree_file_read_flags_t flags =
mmap ? IREE_FILE_READ_FLAG_MMAP : IREE_FILE_READ_FLAG_PRELOAD;
@@ -171,6 +173,7 @@ std::vector ProgramModule::exports() const {
Program Program::Load(std::span modules,
Options &&options) {
+ SHORTFIN_TRACE_SCOPE_NAMED("Program::Load");
std::vector all_modules;
std::vector raw_devices;
@@ -451,6 +454,7 @@ iree_status_t ProgramInvocation::FinalizeCallingConvention(
ProgramInvocation::Future ProgramInvocation::Invoke(
ProgramInvocation::Ptr invocation) {
+ SHORTFIN_TRACE_SCOPE_NAMED("ProgramInvocation::Invoke");
invocation->CheckNotScheduled();
Worker &worker = invocation->fiber_->worker();
@@ -462,9 +466,11 @@ ProgramInvocation::Future ProgramInvocation::Invoke(
iree_vm_function_t function,
ProgramInvocationModel invocation_model,
std::optional failure_future) {
+ SHORTFIN_TRACE_SCOPE_NAMED("ProgramInvocation::InvokeAsync");
auto complete_callback =
[](void *user_data, iree_loop_t loop, iree_status_t status,
iree_vm_list_t *outputs) noexcept -> iree_status_t {
+ SHORTFIN_TRACE_SCOPE_NAMED("ProgramInvocation::Complete");
// Async invocation helpfully gives us a retained reference to the
// outputs, but we already have one statically on the
// ProgramInvocation. So release this one, which makes it safe to
@@ -620,6 +626,7 @@ StaticProgramParameters::StaticProgramParameters(
void StaticProgramParameters::Load(std::filesystem::path file_path,
LoadOptions options) {
+ SHORTFIN_TRACE_SCOPE_NAMED("StaticProgramParameters::Load");
// Default format from extension.
if (options.format.empty()) {
options.format = file_path.extension().string();
diff --git a/shortfin/src/shortfin/local/scheduler.cc b/shortfin/src/shortfin/local/scheduler.cc
index 3b82ded20..883951a20 100644
--- a/shortfin/src/shortfin/local/scheduler.cc
+++ b/shortfin/src/shortfin/local/scheduler.cc
@@ -61,6 +61,7 @@ void Account::active_deps_extend(iree_hal_semaphore_list_t sem_list) {
}
VoidFuture Account::OnSync() {
+ SHORTFIN_TRACE_SCOPE_NAMED("Account::OnSync");
// TODO: Burn this path with fire! No attempt has been made to make this
// particularly good: the backend is being implemented now to export
// HAL semaphores via iree_hal_semaphore_await, and that should be used
@@ -133,6 +134,7 @@ Scheduler::~Scheduler() {
void Scheduler::Initialize(
std::span> devices) {
+ SHORTFIN_TRACE_SCOPE_NAMED("Scheduler::Initialize");
for (auto &it : devices) {
accounts_.emplace_back(*this, it.second);
}
@@ -165,6 +167,7 @@ Account &Scheduler::GetDefaultAccount(ScopedDevice &device) {
void Scheduler::AppendCommandBuffer(ScopedDevice &device,
TransactionType tx_type,
std::function callback) {
+ SHORTFIN_TRACE_SCOPE_NAMED("Scheduler::AppendCommandBuffer");
Account &account = GetDefaultAccount(device);
auto needed_affinity_bits = device.affinity().queue_affinity();
SHORTFIN_SCHED_LOG(
@@ -242,6 +245,7 @@ void Scheduler::AppendCommandBuffer(ScopedDevice &device,
}
iree_status_t Scheduler::FlushWithStatus() noexcept {
+ SHORTFIN_TRACE_SCOPE_NAMED("Scheduler::FlushWithStatus");
// This loop is optimized for a small number of accounts, where it is
// fine to just linearly probe. If this ever becomes cumbersome, we can
// maintain a dirty list which is appended to when an account transitions
diff --git a/shortfin/src/shortfin/local/system.cc b/shortfin/src/shortfin/local/system.cc
index f5012c626..ef31bb001 100644
--- a/shortfin/src/shortfin/local/system.cc
+++ b/shortfin/src/shortfin/local/system.cc
@@ -20,6 +20,7 @@ namespace shortfin::local {
System::System(iree_allocator_t host_allocator)
: host_allocator_(host_allocator) {
+ SHORTFIN_TRACE_SCOPE_NAMED("System::System");
logging::construct("System", this);
SHORTFIN_THROW_IF_ERROR(iree_vm_instance_create(IREE_VM_TYPE_CAPACITY_DEFAULT,
host_allocator_,
@@ -29,6 +30,7 @@ System::System(iree_allocator_t host_allocator)
}
System::~System() {
+ SHORTFIN_TRACE_SCOPE_NAMED("System::~System");
logging::destruct("System", this);
bool needs_shutdown = false;
{
@@ -61,6 +63,7 @@ System::~System() {
}
void System::Shutdown() {
+ SHORTFIN_TRACE_SCOPE_NAMED("System::Shutdown");
// Stop workers.
std::vector local_workers;
{
diff --git a/shortfin/src/shortfin/local/systems/CMakeLists.txt b/shortfin/src/shortfin/local/systems/CMakeLists.txt
index b2bcbef23..b1c9d8b44 100644
--- a/shortfin/src/shortfin/local/systems/CMakeLists.txt
+++ b/shortfin/src/shortfin/local/systems/CMakeLists.txt
@@ -29,6 +29,7 @@ shortfin_cc_component(
iree_task_task
)
list(APPEND _SYSTEM_COMPONENTS shortfin_systems_host)
+target_compile_definitions(shortfin_public_defs INTERFACE SHORTFIN_HAVE_HOSTCPU)
if(SHORTFIN_SYSTEMS_AMDGPU)
shortfin_cc_component(
@@ -47,6 +48,7 @@ if(SHORTFIN_SYSTEMS_AMDGPU)
iree_hal_drivers_hip_hip
)
list(APPEND _SYSTEM_COMPONENTS shortfin_systems_amdgpu)
+ target_compile_definitions(shortfin_public_defs INTERFACE SHORTFIN_HAVE_AMDGPU)
endif()
shortfin_cc_component(
diff --git a/shortfin/src/shortfin/local/systems/amdgpu.cc b/shortfin/src/shortfin/local/systems/amdgpu.cc
index 2625e8325..cecedd1a0 100644
--- a/shortfin/src/shortfin/local/systems/amdgpu.cc
+++ b/shortfin/src/shortfin/local/systems/amdgpu.cc
@@ -7,6 +7,7 @@
#include "shortfin/local/systems/amdgpu.h"
#include "shortfin/support/logging.h"
+#include "shortfin/support/sysconfig.h"
namespace shortfin::local::systems {
@@ -86,6 +87,7 @@ void AMDGPUSystemBuilder::InitializeDefaultSettings() {
void AMDGPUSystemBuilder::Enumerate() {
if (hip_hal_driver_) return;
+ SHORTFIN_TRACE_SCOPE_NAMED("AMDGPUSystemBuilder::Enumerate");
iree_hal_hip_driver_options_t driver_options;
iree_hal_hip_driver_options_initialize(&driver_options);
@@ -126,6 +128,7 @@ std::vector AMDGPUSystemBuilder::GetAvailableDeviceIds() {
}
SystemPtr AMDGPUSystemBuilder::CreateSystem() {
+ SHORTFIN_TRACE_SCOPE_NAMED("AMDGPUSystemBuilder::CreateSystem");
auto lsys = std::make_shared(host_allocator());
Enumerate();
@@ -190,6 +193,22 @@ SystemPtr AMDGPUSystemBuilder::CreateSystem() {
}
}
+ // Estimate the resource requirements for the requested number of devices.
+ // As of 2024-11-08, the number of file handles required to open 64 device
+ // partitions was 31 times the number to open one device. Because it is not
+ // good to run near the limit, we conservatively round that up to 64 above
+ // an arbitrary baseline of 768. This means that on a small, four device
+ // system, we will not request to raise limits for the Linux default of
+ // 1024 file handles, but we will raise for everything larger (which tends
+ // to be where the problems are).
+ size_t expected_device_count =
+ used_device_ids.size() * logical_devices_per_physical_device_;
+ if (!sysconfig::EnsureFileLimit(expected_device_count * 64 + 768)) {
+ logging::error(
+ "Could not ensure sufficient file handles for minimum operations: "
+ "Suggest setting explicit limits with `ulimit -n` and system settings");
+ }
+
// Initialize all used GPU devices.
for (size_t instance_ordinal = 0; instance_ordinal < used_device_ids.size();
++instance_ordinal) {
diff --git a/shortfin/src/shortfin/local/systems/host.cc b/shortfin/src/shortfin/local/systems/host.cc
index 5629979e4..1da4b2af1 100644
--- a/shortfin/src/shortfin/local/systems/host.cc
+++ b/shortfin/src/shortfin/local/systems/host.cc
@@ -11,6 +11,7 @@
#include "iree/hal/local/loaders/registration/init.h"
#include "shortfin/support/iree_helpers.h"
#include "shortfin/support/logging.h"
+#include "shortfin/support/sysconfig.h"
namespace shortfin::local::systems {
@@ -124,6 +125,7 @@ HostCPUSystemBuilder::SelectHostCPUNodesFromOptions() {
}
SystemPtr HostCPUSystemBuilder::CreateSystem() {
+ SHORTFIN_TRACE_SCOPE_NAMED("HostCPUSystemBuilder::CreateSystem");
auto lsys = std::make_shared(host_allocator());
// TODO: Real NUMA awareness.
lsys->InitializeNodes(1);
@@ -135,6 +137,7 @@ SystemPtr HostCPUSystemBuilder::CreateSystem() {
}
iree_hal_driver_t *HostCPUSystemBuilder::InitializeHostCPUDriver(System &lsys) {
+ SHORTFIN_TRACE_SCOPE_NAMED("HostCPUSystemBuilder::InitializeHostCPUDriver");
// TODO: Kill these flag variants in favor of settings on the config
// object.
SHORTFIN_THROW_IF_ERROR(iree_task_executor_options_initialize_from_flags(
@@ -149,6 +152,8 @@ iree_hal_driver_t *HostCPUSystemBuilder::InitializeHostCPUDriver(System &lsys) {
}
// Create one queue executor per node.
+ unsigned total_needed_file_handles = 512;
+ bool has_issued_limit_error = false;
std::vector queue_executors;
queue_executors.reserve(selected_nodes.size());
queue_node_ids_.reserve(selected_nodes.size());
@@ -162,6 +167,21 @@ iree_hal_driver_t *HostCPUSystemBuilder::InitializeHostCPUDriver(System &lsys) {
node_id, iree_task_topology_group_count(&topology.topology));
queue_executors.push_back({});
auto &executor = queue_executors.back();
+ // As of 2024-11-8, it took approximately 32 file handles per node-group.
+ // To be conservative because file handle limits are basically free, we
+ // round up to 64 and assume a floor of 512. This allows small, default
+ // 8 group, single node configs to require no limit increase for Linux
+ // 1024 default cases.
+ total_needed_file_handles += 64 * topology.topology.group_count;
+ if (!sysconfig::EnsureFileLimit(total_needed_file_handles) &&
+ !has_issued_limit_error) {
+ logging::error(
+ "Could not ensure sufficient file handles for minimum operations: "
+ "Suggest setting explicit limits with `ulimit -n` and system "
+ "settings");
+ has_issued_limit_error = true;
+ }
+
SHORTFIN_THROW_IF_ERROR(iree_task_executor_create(
host_cpu_deps_.task_executor_options, &topology.topology,
host_allocator(), executor.for_output()));
@@ -188,6 +208,7 @@ iree_hal_driver_t *HostCPUSystemBuilder::InitializeHostCPUDriver(System &lsys) {
void HostCPUSystemBuilder::InitializeHostCPUDevices(System &lsys,
iree_hal_driver_t *driver) {
+ SHORTFIN_TRACE_SCOPE_NAMED("HostCPUSystemBuilder::InitializeHostCPUDevices");
iree_host_size_t device_info_count = 0;
iree::allocated_ptr device_infos(host_allocator());
SHORTFIN_THROW_IF_ERROR(iree_hal_driver_query_available_devices(
diff --git a/shortfin/src/shortfin/local/worker.cc b/shortfin/src/shortfin/local/worker.cc
index 09207e5e4..d5ffafdbe 100644
--- a/shortfin/src/shortfin/local/worker.cc
+++ b/shortfin/src/shortfin/local/worker.cc
@@ -109,6 +109,7 @@ iree_status_t Worker::TransactLoop(iree_status_t signal_status) {
for (auto& next_thunk : next_thunks_) {
// TODO: Make thunks have to return a status, propagate, and handle
// exceptions.
+ SHORTFIN_TRACE_SCOPE_NAMED("Worker::ThreadsafeCallback");
next_thunk();
}
next_thunks_.clear();
diff --git a/shortfin/src/shortfin/support/CMakeLists.txt b/shortfin/src/shortfin/support/CMakeLists.txt
index cbf171894..ea8572466 100644
--- a/shortfin/src/shortfin/support/CMakeLists.txt
+++ b/shortfin/src/shortfin/support/CMakeLists.txt
@@ -16,12 +16,14 @@ shortfin_cc_component(
iree_concurrency.h
logging.h
stl_extras.h
+ sysconfig.h
SRCS
blocking_executor.cc
config.cc
globals.cc
iree_helpers.cc
logging.cc
+ sysconfig.cc
DEPS
iree_base_base
# TODO: Maybe reclassify some of these low level, shared support entities
diff --git a/shortfin/src/shortfin/support/logging.h b/shortfin/src/shortfin/support/logging.h
index 7bc9e130d..e70c54e99 100644
--- a/shortfin/src/shortfin/support/logging.h
+++ b/shortfin/src/shortfin/support/logging.h
@@ -23,6 +23,14 @@
#define SHORTFIN_SCHED_LOG(...)
#endif
+// Tracing macros. These are currently just aliases of the underlying IREE
+// macros, but we maintain the ability to redirect them in the future (i.e.
+// for certain kinds of library builds, etc).
+#define SHORTFIN_TRACE_SCOPE IREE_TRACE_SCOPE
+#define SHORTFIN_TRACE_SCOPE_NAMED(name_literal) \
+ IREE_TRACE_SCOPE_NAMED(name_literal)
+#define SHORTFIN_TRACE_SCOPE_ID IREE_TRACE_SCOPE_ID
+
namespace shortfin::logging {
SHORTFIN_API void InitializeFromEnv();
diff --git a/shortfin/src/shortfin/support/sysconfig.cc b/shortfin/src/shortfin/support/sysconfig.cc
new file mode 100644
index 000000000..486f5ffc4
--- /dev/null
+++ b/shortfin/src/shortfin/support/sysconfig.cc
@@ -0,0 +1,63 @@
+// Copyright 2024 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "shortfin/support/sysconfig.h"
+
+#include "shortfin/support/logging.h"
+
+#ifdef __linux__
+#include
+#endif
+
+namespace shortfin::sysconfig {
+
+// -----------------------------------------------------------------------------
+// File handle limits
+// -----------------------------------------------------------------------------
+
+#ifdef __linux__
+
+bool EnsureFileLimit(unsigned needed_limit) {
+ struct rlimit limit;
+ if (getrlimit(RLIMIT_NOFILE, &limit) != 0) {
+ return {};
+ }
+
+ if (limit.rlim_cur >= needed_limit) return true;
+ unsigned requested_limit = needed_limit;
+ if (limit.rlim_max >= needed_limit) {
+ logging::debug(
+ "Estimated number of open file handles ({}) < current limit ({}) but "
+ "within max limit ({}): Increasing limit",
+ needed_limit, limit.rlim_cur, limit.rlim_max);
+ } else if (limit.rlim_max > limit.rlim_cur) {
+ logging::warn(
+ "Esimated number of open file handles ({}) < current ({}) and max ({}) "
+ "limit: Increasing to max",
+ needed_limit, limit.rlim_cur, limit.rlim_max);
+ requested_limit = limit.rlim_max;
+ } else {
+ logging::warn("Esimated number of open file handles ({}) < max ({})",
+ needed_limit, limit.rlim_max);
+ return false;
+ }
+
+ limit.rlim_cur = requested_limit;
+ if (setrlimit(RLIMIT_NOFILE, &limit) != 0) {
+ logging::error("Could not set open file handle limit to {}",
+ requested_limit);
+ return false;
+ }
+
+ return limit.rlim_cur >= needed_limit;
+}
+
+#else
+// Fallback implementation.
+bool EnsureFileLimit(unsigned needed_limit) { return true; }
+#endif
+
+} // namespace shortfin::sysconfig
diff --git a/shortfin/src/shortfin/support/sysconfig.h b/shortfin/src/shortfin/support/sysconfig.h
new file mode 100644
index 000000000..864405efc
--- /dev/null
+++ b/shortfin/src/shortfin/support/sysconfig.h
@@ -0,0 +1,25 @@
+// Copyright 2024 Advanced Micro Devices, Inc.
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef SHORTFIN_SUPPORT_SYSCONFIG_H
+#define SHORTFIN_SUPPORT_SYSCONFIG_H
+
+#include
+#include
+
+namespace shortfin::sysconfig {
+
+// Attempts to ensure that the given number of file descriptors can be created.
+// If the system does not support such a thing (i.e. GetOpenFileLimit() returns
+// nothing), then nothing is done and true is returned. If the system does
+// support it and heuristics say this should be allowed, then true will return.
+// Otherwise, a warning will be logged and false returned.
+// This is a best effort attempt.
+bool EnsureFileLimit(unsigned needed_limit);
+
+} // namespace shortfin::sysconfig
+
+#endif // SHORTFIN_SUPPORT_SYSCONFIG_H
diff --git a/shortfin/tests/api/array_ops_test.py b/shortfin/tests/api/array_ops_test.py
index 69d21e929..7c792d92b 100644
--- a/shortfin/tests/api/array_ops_test.py
+++ b/shortfin/tests/api/array_ops_test.py
@@ -167,3 +167,104 @@ def test_fill_randn_explicit_generator(device, dtype):
assert contents1 == contents2
# And not be zero.
assert contents1 != bytes(mz)
+
+
+@pytest.mark.parametrize(
+ "dtype",
+ [
+ sfnp.uint8,
+ sfnp.uint16,
+ sfnp.uint32,
+ sfnp.uint64,
+ sfnp.int8,
+ sfnp.int16,
+ sfnp.int32,
+ sfnp.int64,
+ sfnp.float16,
+ sfnp.float32,
+ sfnp.float64,
+ ],
+)
+def test_convert(device, dtype):
+ input_array = sfnp.device_array(device, [2, 3], dtype=sfnp.int32)
+ with input_array.map(write=True) as m:
+ m.fill(16)
+ intermediate = sfnp.convert(input_array, dtype=dtype)
+ with input_array.map(write=True) as m:
+ m.fill(0)
+ sfnp.convert(intermediate, out=input_array)
+ assert list(input_array.items) == 6 * [16]
+
+
+def round_half_up(n):
+ return math.floor(n + 0.5)
+
+
+def round_half_away_from_zero(n):
+ rounded_abs = round_half_up(abs(n))
+ return math.copysign(rounded_abs, n)
+
+
+@pytest.mark.parametrize(
+ "dtype,sfnp_func,ref_round_func",
+ [
+ (sfnp.float16, sfnp.round, round_half_away_from_zero),
+ (sfnp.float32, sfnp.round, round_half_away_from_zero),
+ (sfnp.float16, sfnp.ceil, math.ceil),
+ (sfnp.float32, sfnp.ceil, math.ceil),
+ (sfnp.float16, sfnp.floor, math.floor),
+ (sfnp.float32, sfnp.floor, math.floor),
+ (sfnp.float16, sfnp.trunc, math.trunc),
+ (sfnp.float32, sfnp.trunc, math.trunc),
+ ],
+)
+def test_nearest_int_no_conversion(device, dtype, sfnp_func, ref_round_func):
+ input = sfnp.device_array(device, [2, 3], dtype=dtype)
+ sfnp.fill_randn(input)
+ ref_rounded = [
+ ref_round_func(n) for n in sfnp.convert(input, dtype=sfnp.float32).items
+ ]
+ output = sfnp_func(input)
+ assert output.dtype == dtype
+ output_items = sfnp.convert(output, dtype=sfnp.float32).items
+ print(output_items)
+ for ref, actual in zip(ref_rounded, output_items):
+ assert ref == pytest.approx(actual)
+
+
+@pytest.mark.parametrize(
+ "dtype,out_dtype,sfnp_func,ref_round_func",
+ [
+ # Round
+ (sfnp.float16, sfnp.int8, sfnp.round, round_half_away_from_zero),
+ (sfnp.float32, sfnp.int8, sfnp.round, round_half_away_from_zero),
+ (sfnp.float32, sfnp.int16, sfnp.round, round_half_away_from_zero),
+ (sfnp.float32, sfnp.int32, sfnp.round, round_half_away_from_zero),
+ # Note that we do not test unsigned conversion with random data.
+ # Ceil
+ (sfnp.float16, sfnp.int8, sfnp.ceil, math.ceil),
+ (sfnp.float32, sfnp.int8, sfnp.ceil, math.ceil),
+ (sfnp.float32, sfnp.int16, sfnp.ceil, math.ceil),
+ (sfnp.float32, sfnp.int32, sfnp.ceil, math.ceil),
+ # Floor
+ (sfnp.float16, sfnp.int8, sfnp.floor, math.floor),
+ (sfnp.float32, sfnp.int8, sfnp.floor, math.floor),
+ (sfnp.float32, sfnp.int16, sfnp.floor, math.floor),
+ (sfnp.float32, sfnp.int32, sfnp.floor, math.floor),
+ # Trunc
+ (sfnp.float16, sfnp.int8, sfnp.trunc, math.trunc),
+ (sfnp.float32, sfnp.int8, sfnp.trunc, math.trunc),
+ (sfnp.float32, sfnp.int16, sfnp.trunc, math.trunc),
+ (sfnp.float32, sfnp.int32, sfnp.trunc, math.trunc),
+ ],
+)
+def test_nearest_int_conversion(device, dtype, out_dtype, sfnp_func, ref_round_func):
+ input = sfnp.device_array(device, [2, 3], dtype=dtype)
+ sfnp.fill_randn(input)
+ ref_rounded = [
+ int(ref_round_func(n)) for n in sfnp.convert(input, dtype=sfnp.float32).items
+ ]
+ output = sfnp_func(input, dtype=out_dtype)
+ assert output.dtype == out_dtype
+ for ref, actual in zip(ref_rounded, output.items):
+ assert ref == int(actual)
diff --git a/shortfin/tests/apps/sd/e2e_test.py b/shortfin/tests/apps/sd/e2e_test.py
index cab8ecab2..26c2e30f6 100644
--- a/shortfin/tests/apps/sd/e2e_test.py
+++ b/shortfin/tests/apps/sd/e2e_test.py
@@ -1,3 +1,9 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
import json
import requests
import time
diff --git a/shortfin/version.json b/shortfin/version.json
new file mode 100644
index 000000000..f09f61d2a
--- /dev/null
+++ b/shortfin/version.json
@@ -0,0 +1,3 @@
+{
+ "package-version": "2.9.2.dev"
+}
diff --git a/shortfin/version_info.json b/shortfin/version_info.json
deleted file mode 100644
index ca3c0ed0b..000000000
--- a/shortfin/version_info.json
+++ /dev/null
@@ -1,3 +0,0 @@
-{
- "package-version": "2.9.0.dev"
-}
diff --git a/tuner/README.md b/tuner/README.md
index 47156779c..3737f6bdf 100644
--- a/tuner/README.md
+++ b/tuner/README.md
@@ -1,5 +1,8 @@
# IREE dispatch auto-tuning scripts
-`libtuner.py` is the core Python script that provides the fundamental functions for the tuning loop. It imports `candidate_gen.py` for candidate generation. To implement the full tuning loop, `libtuner.py` requires a separate Python script that uses the provided `TuningClient` API from `libtuner.py`.
+`libtuner.py` is the core Python script that provides the fundamental functions
+for the tuning loop. It imports `candidate_gen.py` for candidate generation. To
+implement the full tuning loop, `libtuner.py` requires a separate Python script
+that uses the provided `TuningClient` API from `libtuner.py`.
## Prerequisites
[Optional] Using virtual environments:
@@ -22,47 +25,13 @@ Using the IREE's Python bindings:
- Set environment
```shell
source ../iree-build/.env && export PYTHONPATH
+ export PATH="$(realpath ../iree-build/tools):$PATH"
```
-For more information, refer to the [IREE documentation](https://iree.dev/building-from-source/getting-started/#python-bindings)
+For more information, refer to the [IREE
+documentation](https://iree.dev/building-from-source/getting-started/#python-bindings).
-### Overall flow
+## Examples
-1. Symlink all scripts and mlir/irpa files in your build dir.
- - Symlink `iree-build-dir/tools` inside `tuning`.
- - Symlink ML model MLIR and weights based on `unet.sh`.
-
-2. Copy the attention/matmul spec as `config.mlir` in the tuning dir.
-
-3. Temporarily comment out all the existing configs in `config.mlir`.
- - Example:
- ```mlir
- // , @match_mmt_2048x10240x1280 -> @apply_op_config
- // , @match_mmt_2048x1280x5120 -> @apply_op_config
- // , @match_mmt_2048x1280x1280 -> @apply_op_config
- ```
-
-4. Compile a baseline unet
-```shell
-./unet.sh winograd unet.mlir -o unet_baseline.vmfb --iree-hal-dump-executable-files-to=dump-winograd
-```
-
-5. Find the matmul to tune and copy the `*_benchmark.mlir` file to the build dir.
-```shell
-cp dump-winograd/*_141_*benchmark.mlir ./141.mlir
-```
-
-6. Run the tuning script.
- - Example:
- ```shell
- python -m examples.punet 141.mlir --devices=hip://GPU-0,hip://GPU-4 --num-candidates=1024
- ```
-
-7. Check the winner candidate in `result_summary.log`, find and copy the transform spec.
-
-8. Paste the transform spec into the `config.mlir` and uncomment them.
-
-9. Add the match function to the entry point in `config.mlir`
- - Example:
- ```mlir
- @match_something -> @apply_op_config
- ```
+Check the `examples` directory for sample tuners implemented with `libtuner`.
+The [`dispatch` example](https://github.com/nod-ai/shark-ai/tree/main/tuner/examples/dispatch)
+should be a good starting point for most users.
diff --git a/tuner/pyproject.toml b/tuner/pyproject.toml
new file mode 100644
index 000000000..c36326bf7
--- /dev/null
+++ b/tuner/pyproject.toml
@@ -0,0 +1,24 @@
+[project]
+name = "SHARK Tuner"
+authors = [
+ {name = "SHARK Authors"},
+]
+description = "IREE Dispatch Tuner"
+readme = "README.md"
+license = {text = "Apache-2.0"}
+classifiers = [
+ "Development Status :: 3 - Alpha",
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
+]
+requires-python = ">= 3.10"
+
+# Version is set via the `setup.py`.
+dynamic = ["version"]
+
+[project.urls]
+Repository = "https://github.com/nod-ai/shark-ai"
diff --git a/tuner/requirements-dev.txt b/tuner/requirements-dev.txt
index 51d5b9ba0..747b28508 100644
--- a/tuner/requirements-dev.txt
+++ b/tuner/requirements-dev.txt
@@ -1,2 +1,3 @@
+mypy==1.8.0
pre-commit==3.8.0
virtualenv==20.13.0
diff --git a/tuner/setup.py b/tuner/setup.py
new file mode 100644
index 000000000..aa450eaee
--- /dev/null
+++ b/tuner/setup.py
@@ -0,0 +1,35 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import json
+import os
+
+from setuptools import setup
+
+SETUPPY_DIR = os.path.realpath(os.path.dirname(__file__))
+
+# Setup and get version information.
+VERSION_FILE = os.path.join(SETUPPY_DIR, "version.json")
+VERSION_FILE_LOCAL = os.path.join(SETUPPY_DIR, "version_local.json")
+
+
+def load_version_info(version_file):
+ with open(version_file, "rt") as f:
+ return json.load(f)
+
+
+try:
+ version_info = load_version_info(VERSION_FILE_LOCAL)
+except FileNotFoundError:
+ print("version_local.json not found. Default to dev build")
+ version_info = load_version_info(VERSION_FILE)
+
+PACKAGE_VERSION = version_info.get("package-version")
+print(f"Using PACKAGE_VERSION: '{PACKAGE_VERSION}'")
+
+setup(
+ version=f"{PACKAGE_VERSION}",
+)
diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py
index 40eb27a82..b50df12d5 100644
--- a/tuner/tuner/candidate_gen.py
+++ b/tuner/tuner/candidate_gen.py
@@ -20,256 +20,23 @@
import argparse
import logging
-import math
import pickle
import re
-import z3
-from dataclasses import astuple, dataclass
-from enum import Enum
-from os import mkdir, path, makedirs
+from dataclasses import dataclass
+from os import path, makedirs
from typing import Optional
from textwrap import indent
-from abc import ABC, abstractmethod
+from abc import abstractmethod
-import iree.compiler as ireec
-from iree.compiler import ir
-from iree.compiler.dialects import _linalg_ops_gen, _util_ops_gen
+from iree.compiler import ir # type: ignore
+from .common import *
+from .dispatch_constraints import *
+from .dispatch_parser import *
tune_logger = logging.getLogger("tune")
-class DispatchKind(Enum):
- conv = 1
- mmt = 2
- contraction = 3
- batch_mmt = 4
- batch_matmul = 5
- broadcast_rhs_mmt = 6
-
-
-class ElementType(Enum):
- i8 = 1
- i32 = 2
- f8 = 3
- f16 = 4
- f32 = 5
-
- @property
- def bitwidth(self) -> int:
- match self:
- case ElementType.i8 | ElementType.f8:
- return 8
- case ElementType.f16:
- return 16
- case ElementType.i32 | ElementType.f32:
- return 32
- case _:
- assert False, "unhandled case"
-
- def __str__(self) -> str:
- return self.name
-
-
-@dataclass
-class ShapedType:
- shape: list[int]
- element_type: ElementType
-
- def rank(self) -> int:
- return len(self.shape)
-
- @property
- def bitwidth(self) -> int:
- return self.element_type.bitwidth
-
- def __str__(self) -> str:
- dim_to_str = lambda dim: str(dim) if dim != -1 else "?"
- return "x".join(map(dim_to_str, self.shape)) + "x" + str(self.element_type)
-
-
-@dataclass
-class MatmulSize:
- M: int
- N: int
- K: int
- B: int = 1
-
-
-@dataclass
-class ProblemSize:
- matmul_size: MatmulSize
- lhs_type: ShapedType
- rhs_type: ShapedType
- res_type: ShapedType
- dispatch_kind: DispatchKind
-
- @property
- def MNK(self) -> tuple[int, int, int]:
- return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K)
-
-
-@dataclass
-class MfmaIntrinsic:
- output_type: ElementType
- m: int
- n: int
- k: int
- input_type: ElementType
-
- def __str__(self) -> str:
- input = str(self.input_type).upper()
- output = str(self.output_type).upper()
- return f"MFMA_{output}_{self.m}x{self.n}x{self.k}_{input}"
-
- @staticmethod
- def mfma_f32_16x16x16_f16():
- return MfmaIntrinsic(ElementType.f32, 16, 16, 16, ElementType.f16)
-
- @staticmethod
- def mfma_f32_32x32x8_f16():
- return MfmaIntrinsic(ElementType.f32, 32, 32, 8, ElementType.f16)
-
- @staticmethod
- def mfma_i32_16x16x32_i8():
- return MfmaIntrinsic(ElementType.i32, 16, 16, 32, ElementType.i8)
-
- @staticmethod
- def mfma_i32_32x32x16_i8():
- return MfmaIntrinsic(ElementType.i32, 32, 32, 16, ElementType.i8)
-
- @staticmethod
- def all():
- return [
- MfmaIntrinsic.mfma_f32_16x16x16_f16(),
- MfmaIntrinsic.mfma_f32_32x32x8_f16(),
- MfmaIntrinsic.mfma_i32_16x16x32_i8(),
- MfmaIntrinsic.mfma_i32_32x32x16_i8(),
- ]
-
-
-class ReorderWorkgroupsStrategy(Enum):
- NONE = 0
- SWIZZLE = 1
- TRANSPOSE = 2
-
- def __str__(self) -> str:
- return self.name.title()
-
-
-@dataclass
-class GpuPipelineOptions:
- """Represents the `iree_gpu.pipeline_options` attribute"""
-
- prefetch_shared_memory: Optional[bool] = None
- no_reduce_shared_memory_bank_conflicts: Optional[bool] = None
- reorder_workgroups_strategy: Optional[ReorderWorkgroupsStrategy] = None
-
- def all_default(self) -> bool:
- return all(x is None for x in astuple(self))
-
- def __str__(self) -> str:
- options: list[str] = []
- if self.prefetch_shared_memory is not None:
- options.append(
- f"prefetch_shared_memory = {str(self.prefetch_shared_memory).lower()}"
- )
- if self.no_reduce_shared_memory_bank_conflicts is not None:
- options.append(
- f"no_reduce_shared_memory_bank_conflicts = {str(self.no_reduce_shared_memory_bank_conflicts).lower()}"
- )
- if self.reorder_workgroups_strategy is not None:
- options.append(
- f"reorder_workgroups_strategy = {self.reorder_workgroups_strategy}"
- )
-
- return f"#iree_gpu.pipeline_options<{', '.join(options)}>"
-
-
-@dataclass
-class Configuration:
- subgroup_size: int
- workgroup_size: list[int]
- intrinsic: MfmaIntrinsic
- tile_sizes: list[int]
- subgroup_m_count: int
- subgroup_n_count: int
- gpu_pipeline_options: GpuPipelineOptions
- waves_per_eu: int
-
-
-class MlirRegex(Enum):
- ssa_value = r"%[a-zA-Z0-9-_]+"
- tensor_type = r"tensor<(([0-9]+x)+((f|i)[0-9]+))>"
-
- def __str__(self) -> str:
- return self.value
-
- @staticmethod
- def dps_ins_two_args() -> str:
- return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type}), (?P{MlirRegex.tensor_type})\)"
-
- @staticmethod
- def dps_outs_one_arg() -> str:
- return rf"outs\({MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type})\)"
-
-
-def read_input_mlir(filename: str) -> list[str]:
- with open(filename, "r") as f:
- return f.readlines()
-
-
-def get_mmt_tile_sizes(configuration: Configuration):
- return configuration.tile_sizes
-
-
-@dataclass
-class ConvDimInfo:
- n: int
- oh: int
- ow: int
- oc: int
- fh: int
- fw: int
- ic: int
-
- @staticmethod
- def from_rhs_res(rhs_shaped_type: ShapedType, res_shaped_type: ShapedType):
- n, oh, ow, oc = res_shaped_type.shape
- fh, fw, ic, _ = rhs_shaped_type.shape
- return ConvDimInfo(n, oh, ow, oc, fh, fw, ic)
-
- @staticmethod
- def from_problem_size(problem_size: ProblemSize):
- return ConvDimInfo.from_rhs_res(problem_size.rhs_type, problem_size.res_type)
-
-
-def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]:
- m, n, k = configuration.tile_sizes
- tile_size = [1] * len(tile_dims)
- for idx, dim in enumerate(tile_dims):
- if dim == "m":
- tile_size[idx] = m
- if dim == "n":
- tile_size[idx] = n
- if dim == "k":
- tile_size[idx] = k
- return tile_size
-
-
-def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]:
- return [1] + configuration.tile_sizes
-
-
-def get_pipeline_config(configuration: Configuration) -> str:
- extra_config = ""
- if not configuration.gpu_pipeline_options.all_default():
- extra_config += f", gpu_pipeline_options = {configuration.gpu_pipeline_options}"
- if configuration.waves_per_eu != 2:
- extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}'
- return extra_config
-
-
def apply_configuration(
template: list[str], configuration: Configuration, tile_sizes: list[int]
) -> str:
@@ -306,253 +73,8 @@ def apply_configuration(
return new_mlir
-def parse_tensor_type(tensor_type: str) -> ShapedType:
- shape_match = re.search(str(MlirRegex.tensor_type), tensor_type)
- assert shape_match
-
- shape_str = shape_match.group(1)
- dims_and_elem = shape_str.split("x")
- dims = [int(x) for x in dims_and_elem[:-1]]
- elem = dims_and_elem[-1]
- str_to_elem_ty = {x.name: x for x in ElementType}
- return ShapedType(dims, str_to_elem_ty[elem])
-
-
-def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]:
- def is_compatible(intrinsic: MfmaIntrinsic) -> bool:
- if problem_size.res_type.element_type != intrinsic.output_type:
- return False
- if problem_size.dispatch_kind != DispatchKind.batch_matmul:
- if problem_size.lhs_type.element_type != intrinsic.input_type:
- return False
- if problem_size.rhs_type.element_type != intrinsic.input_type:
- return False
- return True
-
- return list(filter(is_compatible, MfmaIntrinsic.all()))
-
-
-def get_mfma_intrinsic_constraints(
- problem_size: ProblemSize,
- intrinsic_m: z3.ArithRef,
- intrinsic_n: z3.ArithRef,
- intrinsic_k: z3.ArithRef,
-) -> z3.BoolRef:
- compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size)
- assert len(compatible_intrinsics) > 0, "No compatible intrinsics found"
- return z3.Or(
- *(
- z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k)
- for mfma in compatible_intrinsics
- )
- )
-
-
-def get_dispatch_constraints(
- problem_size: ProblemSize,
- tile_m: z3.ArithRef,
- tile_n: z3.ArithRef,
- tile_k: z3.ArithRef,
-) -> list[z3.BoolRef]:
- if problem_size.dispatch_kind != DispatchKind.conv:
- return []
-
- dim_info = ConvDimInfo.from_problem_size(problem_size)
- conv_constraints = []
- # WARNING: This sometimes makes the constraints UNSAT for some reason.
- conv_constraints += [tile_m <= dim_info.ow]
- conv_constraints += [tile_n <= dim_info.oc]
- conv_constraints += [tile_k <= dim_info.ic]
- return conv_constraints
-
-
-def calculate_shared_memory_usage_in_bytes(
- problem_size: ProblemSize,
- m: int | z3.ArithRef,
- n: int | z3.ArithRef,
- k: int | z3.ArithRef,
-) -> int | z3.ArithRef:
- lhs_memory = m * k * (problem_size.lhs_type.bitwidth // 8)
- rhs_memory = k * n * (problem_size.rhs_type.bitwidth // 8)
- return lhs_memory + rhs_memory
-
-
-def generate_constraints(
- problem_size: ProblemSize,
- tile_sizes,
- num_subgroups,
- subgroup_size,
- intrinsic_size,
- workgroup_size,
- subgroup_m_count,
- subgroup_n_count,
- waves_per_eu,
-):
- M, N, K = (
- problem_size.matmul_size.M,
- problem_size.matmul_size.N,
- problem_size.matmul_size.K,
- )
- m, n, k = tile_sizes
- intrinsic_mn, intrinsic_k = intrinsic_size
- wg_x, wg_y, wg_z = workgroup_size
- wg_threads = z3.Int("wg_threads")
- constraints = [wg_threads == wg_x * wg_y * wg_z]
- constraints += [subgroup_size == 64, wg_threads <= 1024]
- constraints += [
- get_mfma_intrinsic_constraints(
- problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k
- )
- ]
- subgroup_k_count = 1
- constraints += [
- m >= intrinsic_mn,
- m <= 512,
- m <= M,
- ]
- constraints += [n >= intrinsic_mn, n <= 512, n <= N, N % n == 0]
- constraints += [k >= intrinsic_k, k <= 512, k <= K, K % k == 0]
- for x in (subgroup_m_count, subgroup_n_count):
- constraints += [x >= 1, x <= 32]
-
- subgroup_m_tile_count = z3.Int("sg_m_tcnt")
- subgroup_n_tile_count = z3.Int("sg_n_tcnt")
- subgroup_k_tile_count = z3.Int("sg_k_tcnt")
- for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count):
- constraints += [x >= 1, x <= 32]
-
- constraints += [m == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn]
- constraints += [n == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn]
- constraints += [k == subgroup_k_count * subgroup_k_tile_count * intrinsic_k]
- constraints += [wg_x == subgroup_size * subgroup_n_count]
- constraints += [wg_y == subgroup_m_count]
- constraints += [wg_z == subgroup_k_count]
- constraints += [z3.Or(wg_x <= n, wg_x <= m)]
- constraints += [k % intrinsic_mn == 0]
- constraints += [(k * n) % wg_threads == 0]
- constraints += [(k * m) % wg_threads == 0]
- subgroups = subgroup_m_count * subgroup_n_count
- if num_subgroups > 0:
- constraints += [subgroups == num_subgroups]
- else:
- constraints += [subgroups >= 1, subgroups <= 10]
-
- constraints += [waves_per_eu == 2]
- # constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)]
-
- shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k)
- constraints += [shared_memory <= 65536]
-
- constraints += get_dispatch_constraints(problem_size, m, n, k)
-
- return constraints
-
-
-def generate_solutions(problem_size: ProblemSize, num_subgrups: int):
- M, N, K = problem_size.MNK
- tune_logger.info(f"{M},{N},{K}")
- m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k")
- subgroup_size = z3.Int("subgroup_size")
- intrinsic_mn = z3.Int("intrinsic_mn")
- intrinsic_k = z3.Int("intrinsic_k")
- wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z")
- sg_m_cnt = z3.Int("sg_m_cnt")
- sg_n_cnt = z3.Int("sg_n_cnt")
- waves_per_eu = z3.Int("waves_per_eu")
- all_vars = [
- m,
- n,
- k,
- subgroup_size,
- intrinsic_mn,
- intrinsic_k,
- wg_x,
- wg_y,
- wg_z,
- sg_m_cnt,
- sg_n_cnt,
- waves_per_eu,
- ]
-
- solver = z3.Solver()
- constraints = generate_constraints(
- problem_size,
- [m, n, k],
- num_subgrups,
- subgroup_size,
- [intrinsic_mn, intrinsic_k],
- [wg_x, wg_y, wg_z],
- sg_m_cnt,
- sg_n_cnt,
- waves_per_eu,
- )
- solver.add(z3.simplify(z3.And(constraints)))
- tune_logger.debug(f"Initial constraints: {solver}")
- i = 0
- while solver.check() == z3.sat:
- model = solver.model()
- lookup = lambda var: model[var].as_long()
-
- config = Configuration(
- lookup(subgroup_size),
- [lookup(wg_x), lookup(wg_y), lookup(wg_z)],
- MfmaIntrinsic(
- problem_size.res_type.element_type,
- lookup(intrinsic_mn),
- lookup(intrinsic_mn),
- lookup(intrinsic_k),
- problem_size.lhs_type.element_type,
- ),
- [lookup(m), lookup(n), lookup(k)],
- lookup(sg_m_cnt),
- lookup(sg_n_cnt),
- GpuPipelineOptions(),
- lookup(waves_per_eu),
- )
- solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars)))))
- i += 1
- yield config
-
-
-def get_default_output_dir() -> str:
- from datetime import datetime
-
- return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M")
-
-
-def parse_mlir(mlir_text: str) -> ir.Module:
- mlir_module = None
- with ireec.ir.Context() as context:
- try:
- mlir_module = ireec.ir.Module.parse(mlir_text)
- tune_logger.info("MLIR parsing successful!")
- except ireec.ir.MLIRError as e:
- tune_logger.error(f"Error parsing MLIR: {e}")
- raise RuntimeError(f"Error parsing MLIR: {e}")
-
- return mlir_module
-
-
-@dataclass
-class MLIRTransformation:
- """Transformation of MLIR context"""
-
- template: str
- modified: str
- embeddable: str
-
-
-class DispatchTuner(ABC):
- @abstractmethod
- def supports(self, op_name: str) -> bool:
- """Check if the tuner can handle the type of operation represented by the input string."""
- pass
-
- @abstractmethod
- def get_shapes(self, template: list[str]) -> ProblemSize:
- """Extract problem size of thge operation."""
- pass
-
+class DispatchTuner(DispatchParser):
+ # TODO(https://github.com/nod-ai/shark-ai/issues/453): Remove this in favor of configuring using transform dialect.
@abstractmethod
def apply_params(
self,
@@ -564,12 +86,6 @@ def apply_params(
pass
-@dataclass
-class OpWalkResult:
- was_interrupted: bool = False
- dispatch_tuner: Optional[DispatchTuner] = None
-
-
class DispatchTunerRegistry:
def __init__(self):
self.registry = set()
@@ -593,60 +109,7 @@ def find_handler(self, op_name: str) -> DispatchTuner:
assert False, "Dispatch kind not supported"
-class MmtTuner(DispatchTuner):
- def supports(self, op_name: str) -> bool:
- return "matmul_transpose_b" in op_name
-
- def get_shapes(self, template: list[str]) -> ProblemSize:
- mmt_re = None
- dps = None
- for line in template:
- if "linalg.generic" not in line:
- continue
- if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line:
- continue
- # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>)
- mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
- dps = re.search(mmt_re, line)
- if dps is None:
- continue
-
- lhs_tensor_type = dps.group("LHS")
- rhs_tensor_type = dps.group("RHS")
- lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
- assert lhs_shaped_type.rank() == 2
- lhs_M, lhs_K = lhs_shaped_type.shape
-
- rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
- assert rhs_shaped_type.rank() == 2
- rhs_N, rhs_K = rhs_shaped_type.shape
-
- assert lhs_shaped_type.element_type == rhs_shaped_type.element_type
- assert lhs_K == rhs_K
-
- res_tensor_type = dps.group("RES")
- res_shaped_type = parse_tensor_type(res_tensor_type)
- assert res_shaped_type.rank() == 2
- res_M, res_N = res_shaped_type.shape
-
- assert lhs_M == res_M
- assert rhs_N == res_N
-
- matmul_size = MatmulSize(
- lhs_shaped_type.shape[0],
- rhs_shaped_type.shape[0],
- lhs_shaped_type.shape[1],
- )
- return ProblemSize(
- matmul_size,
- lhs_type=lhs_shaped_type,
- rhs_type=rhs_shaped_type,
- res_type=res_shaped_type,
- dispatch_kind=DispatchKind.mmt,
- )
- assert mmt_re
- assert dps, f"'{mmt_re}' not found in given context"
-
+class MmtTuner(DispatchTuner, MmtParser):
def get_transform_function_mmt(
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
) -> str:
@@ -698,71 +161,7 @@ def apply_params(
return MLIRTransformation(template, modified, embeddable)
-class ConvTuner(DispatchTuner):
- def supports(self, op_name: str) -> bool:
- return "conv_2d_nhwc_hwcf" in op_name
-
- def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]:
- m, n, k = configuration.tile_sizes
- batch = 1
- fh = 1
- fw = 1
-
- oh = 1
-
- oc = n
- ow = m
- ic = k
- return [batch, oh, ow, oc, fh, fw, ic]
-
- def get_shapes(self, template: list[str]) -> ProblemSize:
- for line in template:
- if "linalg.conv_2d_nhwc_hwcf" not in line:
- continue
-
- # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>)
- conv_re = (
- rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
- )
- dps = re.search(conv_re, line)
- if dps is None:
- continue
-
- lhs_tensor_type = dps.group("LHS")
- rhs_tensor_type = dps.group("RHS")
- lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
- assert lhs_shaped_type.rank() == 4
-
- rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
- assert rhs_shaped_type.rank() == 4
-
- res_tensor_type = dps.group("RES")
- res_shaped_type = parse_tensor_type(res_tensor_type)
- assert res_shaped_type.rank() == 4
-
- # int64_t n = outputShape[0];
- # int64_t oh = outputShape[1];
- # int64_t ow = outputShape[2];
- # int64_t oc = outputShape[3];
- # int64_t fh = filterShape[0];
- # int64_t fw = filterShape[1];
- # int64_t ic = filterShape[2];
- dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type)
- return ProblemSize(
- MatmulSize(
- M=dim_info.oh * dim_info.ow,
- N=dim_info.oc,
- K=dim_info.fh * dim_info.fw * dim_info.ic,
- B=dim_info.n,
- ),
- lhs_shaped_type,
- rhs_shaped_type,
- res_shaped_type,
- DispatchKind.conv,
- )
-
- assert False, "Shape not found"
-
+class ConvTuner(DispatchTuner, ConvParser):
# int64_t n = outputShape[0];
# int64_t oh = outputShape[1];
# int64_t ow = outputShape[2];
@@ -837,135 +236,7 @@ def apply_params(
return MLIRTransformation(template, modified, embeddable)
-class ContractionTuner(DispatchTuner):
- def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str):
- self.lhs_dims = lhs_dims
- self.rhs_dims = rhs_dims
- self.tile_dims = tile_dims
-
- def supports(self, op_name: str) -> bool:
- return "matmul_like" in op_name
-
- def is_broadcast_rhs_mmt_op(self, line: str) -> bool:
- if "linalg.generic" not in line:
- return False
- if (
- r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]'
- not in line
- ):
- return False
- if (
- r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>"
- not in line
- ):
- return False
- return True
-
- def is_broadcast_rhs_mmt(self, template: list[str]) -> bool:
- return any(self.is_broadcast_rhs_mmt_op(line) for line in template)
-
- def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize:
- for line in template:
- if not self.is_broadcast_rhs_mmt_op(line):
- continue
-
- # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>)
- bmmt_re = (
- rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
- )
- dps = re.search(bmmt_re, line)
- if dps is None:
- continue
-
- lhs_tensor_type = dps.group("LHS")
- rhs_tensor_type = dps.group("RHS")
- lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
- assert lhs_shaped_type.rank() == 3
-
- rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
- assert rhs_shaped_type.rank() == 2
-
- res_tensor_type = dps.group("RES")
- res_shaped_type = parse_tensor_type(res_tensor_type)
- assert res_shaped_type.rank() == 3
-
- B0, M0, K0 = lhs_shaped_type.shape
- N1, K1 = rhs_shaped_type.shape
- B2, M2, N2 = res_shaped_type.shape
- assert B0 == B2
- assert M0 == M2
- assert N1 == N2
- assert K0 == K1
- return ProblemSize(
- MatmulSize(M0, N1, K0, B0),
- lhs_shaped_type,
- rhs_shaped_type,
- res_shaped_type,
- DispatchKind.broadcast_rhs_mmt,
- )
-
- assert False, "Shape not found"
-
- def get_shapes(self, template: list[str]) -> ProblemSize:
- if self.is_broadcast_rhs_mmt(template):
- return self.get_shapes_broadcast_rhs_mmt(template)
-
- for line in template:
- if "linalg.generic" not in line:
- continue
- if "lowering_config =" not in line:
- continue
- if '"reduction"' not in line:
- continue
-
- # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>)
- cont_re = (
- rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
- )
- dps = re.search(cont_re, line)
- if dps is None:
- continue
-
- lhs_tensor_type = dps.group("LHS")
- rhs_tensor_type = dps.group("RHS")
- lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
- assert lhs_shaped_type.rank() == len(self.lhs_dims)
-
- rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
- assert rhs_shaped_type.rank() == len(self.rhs_dims)
-
- res_tensor_type = dps.group("RES")
- res_shaped_type = parse_tensor_type(res_tensor_type)
- assert res_shaped_type.rank() >= 2
-
- M = math.prod(
- val if dim == "m" else 1
- for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape)
- )
- N = math.prod(
- val if dim == "n" else 1
- for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape)
- )
- K0 = math.prod(
- val if dim == "k" else 1
- for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape)
- )
- K1 = math.prod(
- val if dim == "k" else 1
- for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape)
- )
- assert K0 == K1
-
- return ProblemSize(
- MatmulSize(M, N, K0),
- lhs_type=lhs_shaped_type,
- rhs_type=rhs_shaped_type,
- res_type=res_shaped_type,
- dispatch_kind=DispatchKind.contraction,
- )
-
- assert False, "Shape not found"
-
+class ContractionTuner(DispatchTuner, ContractionParser):
def get_transform_function_broadcast_rhs_mmt(
self,
problem_size: ProblemSize,
@@ -1049,57 +320,7 @@ def apply_params(
)
-class BatchMmtTuner(DispatchTuner):
- def supports(self, op_name: str) -> bool:
- return "batch_matmul_transpose_b" in op_name
-
- def get_shapes(self, template: list[str]) -> ProblemSize:
- for line in template:
- if "linalg.generic" not in line:
- continue
- if (
- r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]'
- not in line
- ):
- continue
- # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>)
- bmmt_re = (
- rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
- )
- dps = re.search(bmmt_re, line)
- if dps is None:
- continue
-
- lhs_tensor_type = dps.group("LHS")
- rhs_tensor_type = dps.group("RHS")
- lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
- assert lhs_shaped_type.rank() == 3
-
- rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
- assert rhs_shaped_type.rank() == 3
-
- res_tensor_type = dps.group("RES")
- res_shaped_type = parse_tensor_type(res_tensor_type)
- assert res_shaped_type.rank() == 3
-
- B0, M0, K0 = lhs_shaped_type.shape
- B1, N1, K1 = rhs_shaped_type.shape
- B2, M2, N2 = res_shaped_type.shape
- assert B0 == B1
- assert B0 == B2
- assert M0 == M2
- assert N1 == N2
- assert K0 == K1
- return ProblemSize(
- MatmulSize(M0, N1, K0, B0),
- lhs_shaped_type,
- rhs_shaped_type,
- res_shaped_type,
- DispatchKind.batch_mmt,
- )
-
- assert False, "Shape not found"
-
+class BatchMmtTuner(DispatchTuner, BatchMmtParser):
def get_transform_function_batch_mmt(
self,
problem_size: ProblemSize,
@@ -1158,78 +379,7 @@ def apply_params(
return MLIRTransformation(template, modified, embeddable)
-class BatchMatmulTuner(DispatchTuner):
- def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str):
- self.lhs_dims = lhs_dims
- self.rhs_dims = rhs_dims
- self.tile_dims = tile_dims
-
- def supports(self, op_name: str) -> bool:
- return "batch_matmul" in op_name
-
- def get_shapes(self, template: list[str]) -> ProblemSize:
- for line in template:
- if "linalg.batch_matmul" not in line:
- continue
- # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>)
- # outs(%12 : tensor<64x72x1280xf32>)
- cont_re = (
- rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
- )
- dps = re.search(cont_re, line)
- if dps is None:
- continue
-
- lhs_tensor_type = dps.group("LHS")
- rhs_tensor_type = dps.group("RHS")
- lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
- assert lhs_shaped_type.rank() == len(self.lhs_dims)
-
- rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
- assert rhs_shaped_type.rank() == len(self.rhs_dims)
-
- res_tensor_type = dps.group("RES")
- res_shaped_type = parse_tensor_type(res_tensor_type)
- assert res_shaped_type.rank() == lhs_shaped_type.rank()
-
- LHS = lhs_shaped_type.shape
- RHS = rhs_shaped_type.shape
- RES = res_shaped_type.shape
-
- B = math.prod(
- val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS)
- )
- B0 = math.prod(
- val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS)
- )
- B1 = math.prod(
- val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES)
- )
- M = math.prod(
- val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS)
- )
- N = math.prod(
- val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS)
- )
- K0 = math.prod(
- val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS)
- )
- K1 = math.prod(
- val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS)
- )
- assert B == B0 and B == B1
- assert K0 == K1
-
- return ProblemSize(
- MatmulSize(M, N, K0, B),
- lhs_type=lhs_shaped_type,
- rhs_type=rhs_shaped_type,
- res_type=res_shaped_type,
- dispatch_kind=DispatchKind.batch_matmul,
- )
-
- assert False, "Shape not found"
-
+class BatchMatmulTuner(DispatchTuner, BatchMatmulParser):
def get_transform_function_batch_matmul(
self,
problem_size: ProblemSize,
@@ -1301,6 +451,12 @@ def apply_params(
return MLIRTransformation(template, modified, embeddable)
+@dataclass
+class OpWalkResult:
+ was_interrupted: bool = False
+ dispatch_tuner: Optional[DispatchTuner] = None
+
+
def walk_callback_get_fn(
op: ir.Operation,
walk_result: OpWalkResult,
@@ -1331,6 +487,12 @@ def walk_mlir_op(
return walk_result
+def get_default_output_dir() -> str:
+ from datetime import datetime
+
+ return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M")
+
+
def tune(
input: str, # Path to the mlir file to be tuned
output: str = "", # Path to the output directory, auto creates one if not given
@@ -1353,45 +515,50 @@ def tune(
mlir_template = read_input_mlir(input_file)
mlir_text = "".join(mlir_template)
- mlir_module = parse_mlir(mlir_text)
- # Save the input file as the first candidate.
- with open(path.join(output, f"0.mlir"), "w") as f:
- f.write(mlir_text)
-
- dispatch_tuner_registry = DispatchTunerRegistry()
- dispatch_tuner_registry.register(
- [
- MmtTuner(),
- ConvTuner(),
- ContractionTuner(lhs_dims, rhs_dims, tile_dims),
- BatchMmtTuner(),
- BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims),
- ]
- )
-
- walk_result = walk_mlir_op(mlir_module, dispatch_tuner_registry)
-
- dispatch_tuner = walk_result.dispatch_tuner
- problem_size = dispatch_tuner.get_shapes(mlir_template)
- tune_logger.debug(str(problem_size))
- configs = []
- for i, config in enumerate(generate_solutions(problem_size, num_subgroups)):
- if i >= limit:
- break
- tune_logger.info(f"Solution #{i+1}: {config}")
- configs.append(config)
- tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config)
-
- with open(path.join(output, f"{i+1}.mlir"), "w") as f:
- f.write(tf_mlir.modified)
- with open(path.join(output, f"{i+1}_config.mlir"), "w") as f:
- f.write(tf_mlir.embeddable)
+ with ir.Context() as ctx:
+ tuner_context = TunerContext(ctx, tune_logger)
+ mlir_module: ir.Module = parse_mlir(mlir_text, tuner_context)
+ # Save the input file as the first candidate.
+ with open(path.join(output, f"0.mlir"), "w") as f:
+ f.write(mlir_text)
+
+ dispatch_tuner_registry = DispatchTunerRegistry()
+ dispatch_tuner_registry.register(
+ [
+ MmtTuner(),
+ ConvTuner(),
+ ContractionTuner(lhs_dims, rhs_dims, tile_dims),
+ BatchMmtTuner(),
+ BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims),
+ ]
+ )
- with open(path.join(output, "configs.pkl"), "wb") as file:
- pickle.dump(configs, file)
+ walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry)
- tune_logger.info(f"Generated {len(configs)} candidates")
- tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl")
+ dispatch_tuner = walk_result.dispatch_tuner
+ assert dispatch_tuner, "No suitable dispatch tuner found"
+ problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template)
+ tune_logger.debug(str(problem_size))
+ configs = []
+ for i, config in enumerate(
+ generate_solutions(tuner_context, problem_size, num_subgroups)
+ ):
+ if i >= limit:
+ break
+ tune_logger.info(f"Solution #{i+1}: {config}")
+ configs.append(config)
+ tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config)
+
+ with open(path.join(output, f"{i+1}.mlir"), "w") as f:
+ f.write(tf_mlir.modified)
+ with open(path.join(output, f"{i+1}_config.mlir"), "w") as f:
+ f.write(tf_mlir.embeddable)
+
+ with open(path.join(output, "configs.pkl"), "wb") as file:
+ pickle.dump(configs, file)
+
+ tune_logger.info(f"Generated {len(configs)} candidates")
+ tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl")
def main():
diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py
index 2924db75b..47e351fc7 100644
--- a/tuner/tuner/candidate_gen_test.py
+++ b/tuner/tuner/candidate_gen_test.py
@@ -9,433 +9,9 @@
"""
import pytest
-from . import candidate_gen
-
-
-def test_get_shaped_type_element_bitwidth():
- assert (
- candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth
- == 8
- )
- assert (
- candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32
- )
- assert (
- candidate_gen.ShapedType(
- [2048, 512, 384], candidate_gen.ElementType.f8
- ).bitwidth
- == 8
- )
- assert (
- candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16
- )
-
-
-def test_get_shaped_type_to_str():
- assert (
- str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8))
- == "1024x2048xi8"
- )
- assert (
- str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32))
- == "1024xf32"
- )
- assert (
- str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16))
- == "1x2x3xf16"
- )
- assert (
- str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16))
- == "?x2x3xf16"
- )
-
-
-def test_parse_tensor_type():
- assert candidate_gen.parse_tensor_type(
- "tensor<1x2x3xf32>"
- ) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32)
- assert candidate_gen.parse_tensor_type(
- "tensor<123xi8>"
- ) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8)
-
-
-def test_get_mmt_tile_sizes():
- config = candidate_gen.Configuration(
- subgroup_size=0,
- workgroup_size=[],
- intrinsic="",
- tile_sizes=[128, 320, 32],
- subgroup_m_count=0,
- subgroup_n_count=0,
- gpu_pipeline_options=candidate_gen.GpuPipelineOptions(),
- waves_per_eu=0,
- )
- assert candidate_gen.get_mmt_tile_sizes(config) == [128, 320, 32]
-
-def test_get_conv_tile_sizes():
- config = candidate_gen.Configuration(
- subgroup_size=64,
- workgroup_size=[256, 1, 1],
- intrinsic="#iree_gpu.mma_layout",
- tile_sizes=[464, 320, 16],
- subgroup_m_count=1,
- subgroup_n_count=4,
- gpu_pipeline_options=candidate_gen.GpuPipelineOptions(),
- waves_per_eu=1,
- )
- assert candidate_gen.ConvTuner().get_conv_tile_sizes(config) == [
- 1,
- 1,
- 464,
- 320,
- 1,
- 1,
- 16,
- ]
-
-
-def test_gpu_pipeline_options():
- options = candidate_gen.GpuPipelineOptions()
- assert options.all_default()
- assert str(options) == "#iree_gpu.pipeline_options<>"
-
- options.prefetch_shared_memory = True
- assert not options.all_default()
- assert str(options) == "#iree_gpu.pipeline_options"
-
- options.no_reduce_shared_memory_bank_conflicts = False
- assert (
- str(options)
- == "#iree_gpu.pipeline_options"
- )
-
- options = candidate_gen.GpuPipelineOptions()
- options.reorder_workgroups_strategy = (
- candidate_gen.ReorderWorkgroupsStrategy.TRANSPOSE
- )
- assert not options.all_default()
- assert (
- str(options)
- == "#iree_gpu.pipeline_options"
- )
-
-
-def test_get_contract_tile_sizes():
- config = candidate_gen.Configuration(
- subgroup_size=32,
- workgroup_size=[16, 16, 1],
- intrinsic="",
- tile_sizes=[4, 8, 16],
- subgroup_m_count=1,
- subgroup_n_count=1,
- gpu_pipeline_options=candidate_gen.GpuPipelineOptions(),
- waves_per_eu=2,
- )
- assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16]
- assert candidate_gen.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16]
- assert candidate_gen.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4]
- assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [
- 16,
- 16,
- 16,
- ]
-
-
-def test_get_pipeline_config():
- config = candidate_gen.Configuration(
- subgroup_size=32,
- workgroup_size=[16, 16, 1],
- intrinsic="",
- tile_sizes=[4, 8, 16],
- subgroup_m_count=1,
- subgroup_n_count=1,
- gpu_pipeline_options=candidate_gen.GpuPipelineOptions(),
- waves_per_eu=2,
- )
- config1_str: str = candidate_gen.get_pipeline_config(config)
- assert config1_str == ""
-
- config.waves_per_eu = 4
- config2_str: str = candidate_gen.get_pipeline_config(config)
- assert config2_str == ', llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}'
-
- config.gpu_pipeline_options.prefetch_shared_memory = True
- config3_str = candidate_gen.get_pipeline_config(config)
- assert (
- config3_str
- == ', gpu_pipeline_options = #iree_gpu.pipeline_options, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}'
- )
-
-
-def test_get_shapes_mmt():
- template = [
- r"%18 = tensor.empty() : tensor<2048x1280xf32>",
- r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>",
- r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
- r"^bb0(%in: f16, %in_0: f16, %out: f32):",
- ]
- assert candidate_gen.MmtTuner().get_shapes(template) == candidate_gen.ProblemSize(
- candidate_gen.MatmulSize(2048, 1280, 1280),
- candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16),
- candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16),
- candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32),
- candidate_gen.DispatchKind.mmt,
- )
-
-
-def test_get_shapes_conv():
- template = [
- r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>",
- r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>",
- r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>",
- ]
- assert candidate_gen.ConvTuner().get_shapes(template) == candidate_gen.ProblemSize(
- candidate_gen.MatmulSize(32, 256, 11520),
- candidate_gen.ShapedType([1, 3, 34, 1280], candidate_gen.ElementType.f16),
- candidate_gen.ShapedType([3, 3, 1280, 256], candidate_gen.ElementType.f16),
- candidate_gen.ShapedType([1, 1, 32, 256], candidate_gen.ElementType.f32),
- candidate_gen.DispatchKind.conv,
- )
-
-
-def test_get_shapes_contract():
- template = [
- r"%18 = tensor.empty() : tensor<2048x1280xf32>",
- r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>",
- r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
- r"^bb0(%in: f16, %in_0: f16, %out: f32):",
- ]
- assert candidate_gen.ContractionTuner("mk", "nk", "mnk").get_shapes(
- template
- ) == candidate_gen.ProblemSize(
- candidate_gen.MatmulSize(2048, 1280, 1280),
- candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16),
- candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16),
- candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32),
- candidate_gen.DispatchKind.contraction,
- )
-
-
-def test_get_shapes_batch_matmul():
- template = [
- "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>",
- "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>",
- "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>",
- ]
- assert candidate_gen.BatchMatmulTuner("bmk", "bkn", "mnk").get_shapes(
- template
- ) == candidate_gen.ProblemSize(
- candidate_gen.MatmulSize(32, 32, 1024, 1),
- candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32),
- candidate_gen.ShapedType([1, 1024, 32], candidate_gen.ElementType.f32),
- candidate_gen.ShapedType([1, 32, 32], candidate_gen.ElementType.f32),
- candidate_gen.DispatchKind.batch_matmul,
- )
-
-
-def test_get_shapes_batch_mmt():
- template = [
- r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>",
- r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
- r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>",
- ]
- assert candidate_gen.BatchMmtTuner().get_shapes(
- template
- ) == candidate_gen.ProblemSize(
- candidate_gen.MatmulSize(4096, 640, 640, 2),
- candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i8),
- candidate_gen.ShapedType([2, 640, 640], candidate_gen.ElementType.i8),
- candidate_gen.ShapedType([2, 4096, 640], candidate_gen.ElementType.i32),
- candidate_gen.DispatchKind.batch_mmt,
- )
-
-
-def test_mfma_intrinsic_to_str():
- assert (
- str(candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16())
- == "MFMA_F32_16x16x16_F16"
- )
- assert (
- str(candidate_gen.MfmaIntrinsic.mfma_i32_32x32x16_i8())
- == "MFMA_I32_32x32x16_I8"
- )
-
-
-def test_get_compatible_mfma_intrinsics():
- assert candidate_gen.get_compatible_mfma_intrinsics(
- candidate_gen.ProblemSize(
- candidate_gen.MatmulSize(2048, 1280, 1280),
- candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16),
- candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16),
- candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f32),
- candidate_gen.DispatchKind.mmt,
- )
- ) == [
- candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
- candidate_gen.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
- ]
-
- assert candidate_gen.get_compatible_mfma_intrinsics(
- candidate_gen.ProblemSize(
- candidate_gen.MatmulSize(2048, 1280, 1280),
- candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i8),
- candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.i8),
- candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.i32),
- candidate_gen.DispatchKind.mmt,
- )
- ) == [
- candidate_gen.MfmaIntrinsic.mfma_i32_16x16x32_i8(),
- candidate_gen.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
- ]
-
- assert candidate_gen.get_compatible_mfma_intrinsics(
- candidate_gen.ProblemSize(
- candidate_gen.MatmulSize(968, 320, 640, 64),
- candidate_gen.ShapedType([64, 968, 640], candidate_gen.ElementType.f32),
- candidate_gen.ShapedType([64, 640, 320], candidate_gen.ElementType.f32),
- candidate_gen.ShapedType([64, 968, 320], candidate_gen.ElementType.f32),
- candidate_gen.DispatchKind.batch_matmul,
- )
- ) == [
- candidate_gen.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
- candidate_gen.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
- ]
-
-
-def test_generate_solutions():
- matmul_size = candidate_gen.MatmulSize(2048, 3840, 1280)
- lhs_type = candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16)
- rhs_type = candidate_gen.ShapedType([3840, 1280], candidate_gen.ElementType.f16)
- res_type = candidate_gen.ShapedType([2048, 3840], candidate_gen.ElementType.f32)
- problem_size = candidate_gen.ProblemSize(
- matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
- )
- configs = candidate_gen.generate_solutions(problem_size, 4)
- assert configs is not None
-
-
-def test_calculate_shared_memory_usage_in_bytes():
- matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024)
- lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
- rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
- res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32)
- problem_size = candidate_gen.ProblemSize(
- matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
- )
- assert (
- candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128)
- == 147456
- )
-
- lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i8)
- problem_size = candidate_gen.ProblemSize(
- matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
- )
- assert (
- candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128)
- == 81920
- )
-
- rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i32)
- problem_size = candidate_gen.ProblemSize(
- matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
- )
- assert (
- candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32)
- == 12288
- )
-
-
-def test_generate_constraints_valid_input():
- matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024)
- lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
- rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
- res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32)
- problem_size = candidate_gen.ProblemSize(
- matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
- )
- # Define input parameters as z3 Ints
- m, n, k = (
- candidate_gen.z3.Int("m"),
- candidate_gen.z3.Int("n"),
- candidate_gen.z3.Int("k"),
- )
- subgroup_size = candidate_gen.z3.Int("subgroup_size")
- intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn")
- intrinsic_k = candidate_gen.z3.Int("intrinsic_k")
- wg_x, wg_y, wg_z = (
- candidate_gen.z3.Int("wg_x"),
- candidate_gen.z3.Int("wg_y"),
- candidate_gen.z3.Int("wg_z"),
- )
- sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt")
- sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt")
- waves_per_eu = candidate_gen.z3.Int("waves_per_eu")
-
- constraints = candidate_gen.generate_constraints(
- problem_size,
- [m, n, k],
- 4,
- subgroup_size,
- [intrinsic_mn, intrinsic_k],
- [wg_x, wg_y, wg_z],
- sg_m_cnt,
- sg_n_cnt,
- waves_per_eu,
- )
-
- solver = candidate_gen.z3.Solver()
- solver.add(constraints)
-
- # Check if the constraints are satisfiable
- assert solver.check() == candidate_gen.z3.sat
-
-
-def test_generate_constraints_invalid_input():
- # Define input parameters that should lead to unsatisfiable constraints
- matmul_size = candidate_gen.MatmulSize(1024, 1024, 1024)
- lhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
- rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f16)
- res_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.f32)
- problem_size = candidate_gen.ProblemSize(
- matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
- )
- m, n, k = (
- candidate_gen.z3.Int("m"),
- candidate_gen.z3.Int("n"),
- candidate_gen.z3.Int("k"),
- )
- subgroup_size = candidate_gen.z3.Int("subgroup_size")
- intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn")
- intrinsic_k = candidate_gen.z3.Int("intrinsic_k")
- wg_x, wg_y, wg_z = (
- candidate_gen.z3.Int("wg_x"),
- candidate_gen.z3.Int("wg_y"),
- candidate_gen.z3.Int("wg_z"),
- )
- sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt")
- sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt")
- waves_per_eu = candidate_gen.z3.Int("waves_per_eu")
-
- constraints = candidate_gen.generate_constraints(
- problem_size,
- [m, n, k],
- 4,
- subgroup_size,
- [intrinsic_mn, intrinsic_k],
- [wg_x, wg_y, wg_z],
- sg_m_cnt,
- sg_n_cnt,
- waves_per_eu,
- )
- constraints.append(m > 1000) # Adding an additional unsatisfiable constraint
-
- solver = candidate_gen.z3.Solver()
- solver.add(constraints)
-
- # Check if the constraints are unsatisfiable
- assert solver.check() == candidate_gen.z3.unsat
+from . import candidate_gen
+from . import common
def remove_comments(mlir: str) -> str:
@@ -444,7 +20,7 @@ def remove_comments(mlir: str) -> str:
)
-def test_apply_params_mmt():
+def test_apply_params_mmt() -> None:
mlir_template = [
", subgroup_m_count = 16, subgroup_n_count = 16>",
" None:
mlir_template = [
", subgroup_m_count = 16, subgroup_n_count = 16>",
" None:
mlir_template = [
", subgroup_m_count = 2, subgroup_n_count = 2>}>",
" None:
mlir_template = [
", subgroup_m_count = 4, subgroup_n_count = 1>}>",
" None:
mlir_template = [
", subgroup_m_count = 4, subgroup_n_count = 1>}>",
" None:
mlir_template = [
", subgroup_m_count = 4, subgroup_n_count = 1>}>",
" None:
mlir_template = [
", subgroup_m_count = 4, subgroup_n_count = 1>}>",
" None:
mlir_lines = [
r"%18 = tensor.empty() : tensor<2x1024x10240xi32>",
r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>",
@@ -859,20 +431,3 @@ def test_detect_broadcast_rhs_mmt():
assert candidate_gen.ContractionTuner("mk", "nk", "mnk").is_broadcast_rhs_mmt(
mlir_lines
)
-
-
-def test_parse_mlir():
- mlir_str = r"""
- builtin.module {
- func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
- %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
- return %0 : tensor<4xf32>
- }
- }
- """
- mlir_module = candidate_gen.parse_mlir(mlir_str)
- assert mlir_module != None
- assert isinstance(mlir_module, candidate_gen.ireec._mlir_libs._mlir.ir.Module)
- assert isinstance(
- mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp
- )
diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py
new file mode 100644
index 000000000..7b295cdb0
--- /dev/null
+++ b/tuner/tuner/common.py
@@ -0,0 +1,264 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import re
+import logging
+from dataclasses import astuple, dataclass
+from enum import Enum
+from typing import Optional
+
+from iree.compiler import ir # type: ignore
+
+
+class TunerContext:
+ def __init__(self, mlir_ctx: ir.Context, logger: logging.Logger):
+ self.mlir_ctx = mlir_ctx
+ self.logger = logger
+
+
+class DispatchKind(Enum):
+ conv = 1
+ mmt = 2
+ contraction = 3
+ batch_mmt = 4
+ batch_matmul = 5
+ broadcast_rhs_mmt = 6
+
+
+class ElementType(Enum):
+ i8 = 1
+ i32 = 2
+ f8 = 3
+ f16 = 4
+ f32 = 5
+
+ @property
+ def bitwidth(self) -> int:
+ match self:
+ case ElementType.i8 | ElementType.f8:
+ return 8
+ case ElementType.f16:
+ return 16
+ case ElementType.i32 | ElementType.f32:
+ return 32
+ case _:
+ assert False, "unhandled case"
+
+ def __str__(self) -> str:
+ return self.name
+
+
+@dataclass
+class ShapedType:
+ shape: list[int]
+ element_type: ElementType
+
+ def rank(self) -> int:
+ return len(self.shape)
+
+ @property
+ def bitwidth(self) -> int:
+ return self.element_type.bitwidth
+
+ def __str__(self) -> str:
+ dim_to_str = lambda dim: str(dim) if dim != -1 else "?"
+ return "x".join(map(dim_to_str, self.shape)) + "x" + str(self.element_type)
+
+
+@dataclass
+class MatmulSize:
+ M: int
+ N: int
+ K: int
+ B: int = 1
+
+
+@dataclass
+class ProblemSize:
+ matmul_size: MatmulSize
+ lhs_type: ShapedType
+ rhs_type: ShapedType
+ res_type: ShapedType
+ dispatch_kind: DispatchKind
+
+ @property
+ def MNK(self) -> tuple[int, int, int]:
+ return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K)
+
+
+@dataclass
+class MfmaIntrinsic:
+ output_type: ElementType
+ m: int
+ n: int
+ k: int
+ input_type: ElementType
+
+ def __str__(self) -> str:
+ input = str(self.input_type).upper()
+ output = str(self.output_type).upper()
+ return f"MFMA_{output}_{self.m}x{self.n}x{self.k}_{input}"
+
+ @staticmethod
+ def mfma_f32_16x16x16_f16():
+ return MfmaIntrinsic(ElementType.f32, 16, 16, 16, ElementType.f16)
+
+ @staticmethod
+ def mfma_f32_32x32x8_f16():
+ return MfmaIntrinsic(ElementType.f32, 32, 32, 8, ElementType.f16)
+
+ @staticmethod
+ def mfma_i32_16x16x32_i8():
+ return MfmaIntrinsic(ElementType.i32, 16, 16, 32, ElementType.i8)
+
+ @staticmethod
+ def mfma_i32_32x32x16_i8():
+ return MfmaIntrinsic(ElementType.i32, 32, 32, 16, ElementType.i8)
+
+ @staticmethod
+ def all():
+ return [
+ MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+ MfmaIntrinsic.mfma_f32_32x32x8_f16(),
+ MfmaIntrinsic.mfma_i32_16x16x32_i8(),
+ MfmaIntrinsic.mfma_i32_32x32x16_i8(),
+ ]
+
+
+def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]:
+ def is_compatible(intrinsic: MfmaIntrinsic) -> bool:
+ if problem_size.res_type.element_type != intrinsic.output_type:
+ return False
+ if problem_size.dispatch_kind != DispatchKind.batch_matmul:
+ if problem_size.lhs_type.element_type != intrinsic.input_type:
+ return False
+ if problem_size.rhs_type.element_type != intrinsic.input_type:
+ return False
+ return True
+
+ return list(filter(is_compatible, MfmaIntrinsic.all()))
+
+
+class ReorderWorkgroupsStrategy(Enum):
+ NONE = 0
+ SWIZZLE = 1
+ TRANSPOSE = 2
+
+ def __str__(self) -> str:
+ return self.name.title()
+
+
+@dataclass
+class GpuPipelineOptions:
+ """Represents the `iree_gpu.pipeline_options` attribute"""
+
+ prefetch_shared_memory: Optional[bool] = None
+ no_reduce_shared_memory_bank_conflicts: Optional[bool] = None
+ reorder_workgroups_strategy: Optional[ReorderWorkgroupsStrategy] = None
+
+ def all_default(self) -> bool:
+ return all(x is None for x in astuple(self))
+
+ def __str__(self) -> str:
+ options: list[str] = []
+ if self.prefetch_shared_memory is not None:
+ options.append(
+ f"prefetch_shared_memory = {str(self.prefetch_shared_memory).lower()}"
+ )
+ if self.no_reduce_shared_memory_bank_conflicts is not None:
+ options.append(
+ f"no_reduce_shared_memory_bank_conflicts = {str(self.no_reduce_shared_memory_bank_conflicts).lower()}"
+ )
+ if self.reorder_workgroups_strategy is not None:
+ options.append(
+ f"reorder_workgroups_strategy = {self.reorder_workgroups_strategy}"
+ )
+
+ return f"#iree_gpu.pipeline_options<{', '.join(options)}>"
+
+
+@dataclass
+class Configuration:
+ subgroup_size: int
+ workgroup_size: list[int]
+ intrinsic: MfmaIntrinsic
+ tile_sizes: list[int]
+ subgroup_m_count: int
+ subgroup_n_count: int
+ gpu_pipeline_options: GpuPipelineOptions
+ waves_per_eu: int
+
+
+def get_pipeline_config(configuration: Configuration) -> str:
+ extra_config = ""
+ if not configuration.gpu_pipeline_options.all_default():
+ extra_config += f", gpu_pipeline_options = {configuration.gpu_pipeline_options}"
+ if configuration.waves_per_eu != 2:
+ extra_config += f', llvm_func_attrs = {{"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"}}'
+ return extra_config
+
+
+class MlirRegex(Enum):
+ ssa_value = r"%[a-zA-Z0-9-_]+"
+ tensor_type = r"tensor<(([0-9]+x)+((f|i)[0-9]+))>"
+
+ def __str__(self) -> str:
+ return self.value
+
+ @staticmethod
+ def dps_ins_two_args() -> str:
+ return rf"ins\({MlirRegex.ssa_value}, {MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type}), (?P{MlirRegex.tensor_type})\)"
+
+ @staticmethod
+ def dps_outs_one_arg() -> str:
+ return rf"outs\({MlirRegex.ssa_value} : (?P{MlirRegex.tensor_type})\)"
+
+
+def read_input_mlir(filename: str) -> list[str]:
+ with open(filename, "r") as f:
+ return f.readlines()
+
+
+@dataclass
+class ConvDimInfo:
+ n: int
+ oh: int
+ ow: int
+ oc: int
+ fh: int
+ fw: int
+ ic: int
+
+ @staticmethod
+ def from_rhs_res(rhs_shaped_type: ShapedType, res_shaped_type: ShapedType):
+ n, oh, ow, oc = res_shaped_type.shape
+ fh, fw, ic, _ = rhs_shaped_type.shape
+ return ConvDimInfo(n, oh, ow, oc, fh, fw, ic)
+
+ @staticmethod
+ def from_problem_size(problem_size: ProblemSize):
+ return ConvDimInfo.from_rhs_res(problem_size.rhs_type, problem_size.res_type)
+
+
+def parse_tensor_type(tensor_type: str) -> ShapedType:
+ shape_match = re.search(str(MlirRegex.tensor_type), tensor_type)
+ assert shape_match
+
+ shape_str = shape_match.group(1)
+ dims_and_elem = shape_str.split("x")
+ dims = [int(x) for x in dims_and_elem[:-1]]
+ elem = dims_and_elem[-1]
+ str_to_elem_ty = {x.name: x for x in ElementType}
+ return ShapedType(dims, str_to_elem_ty[elem])
+
+
+@dataclass
+class MLIRTransformation:
+ """Transformation of MLIR context"""
+
+ template: list[str]
+ modified: str
+ embeddable: str
diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py
new file mode 100644
index 000000000..858d593c9
--- /dev/null
+++ b/tuner/tuner/common_test.py
@@ -0,0 +1,131 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""
+Usage: python -m pytest candidate_gen_test.py
+"""
+
+import pytest
+from . import common
+
+
+def test_get_shaped_type_element_bitwidth() -> None:
+ assert common.ShapedType([1024, 2048], common.ElementType.i8).bitwidth == 8
+ assert common.ShapedType([2048], common.ElementType.i32).bitwidth == 32
+ assert common.ShapedType([2048, 512, 384], common.ElementType.f8).bitwidth == 8
+ assert common.ShapedType([1, 1], common.ElementType.f16).bitwidth == 16
+
+
+def test_get_shaped_type_to_str() -> None:
+ assert str(common.ShapedType([1024, 2048], common.ElementType.i8)) == "1024x2048xi8"
+ assert str(common.ShapedType([1024], common.ElementType.f32)) == "1024xf32"
+ assert str(common.ShapedType([1, 2, 3], common.ElementType.f16)) == "1x2x3xf16"
+ assert str(common.ShapedType([-1, 2, 3], common.ElementType.f16)) == "?x2x3xf16"
+
+
+def test_parse_tensor_type() -> None:
+ assert common.parse_tensor_type("tensor<1x2x3xf32>") == common.ShapedType(
+ [1, 2, 3], common.ElementType.f32
+ )
+ assert common.parse_tensor_type("tensor<123xi8>") == common.ShapedType(
+ [123], common.ElementType.i8
+ )
+
+
+def test_gpu_pipeline_options() -> None:
+ options = common.GpuPipelineOptions()
+ assert options.all_default()
+ assert str(options) == "#iree_gpu.pipeline_options<>"
+
+ options.prefetch_shared_memory = True
+ assert not options.all_default()
+ assert str(options) == "#iree_gpu.pipeline_options"
+
+ options.no_reduce_shared_memory_bank_conflicts = False
+ assert (
+ str(options)
+ == "#iree_gpu.pipeline_options"
+ )
+
+ options = common.GpuPipelineOptions()
+ options.reorder_workgroups_strategy = common.ReorderWorkgroupsStrategy.TRANSPOSE
+ assert not options.all_default()
+ assert (
+ str(options)
+ == "#iree_gpu.pipeline_options"
+ )
+
+
+def test_get_pipeline_config() -> None:
+ config = common.Configuration(
+ subgroup_size=32,
+ workgroup_size=[16, 16, 1],
+ intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+ tile_sizes=[4, 8, 16],
+ subgroup_m_count=1,
+ subgroup_n_count=1,
+ gpu_pipeline_options=common.GpuPipelineOptions(),
+ waves_per_eu=2,
+ )
+ config1_str: str = common.get_pipeline_config(config)
+ assert config1_str == ""
+
+ config.waves_per_eu = 4
+ config2_str: str = common.get_pipeline_config(config)
+ assert config2_str == ', llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}'
+
+ config.gpu_pipeline_options.prefetch_shared_memory = True
+ config3_str = common.get_pipeline_config(config)
+ assert (
+ config3_str
+ == ', gpu_pipeline_options = #iree_gpu.pipeline_options, llvm_func_attrs = {"amdgpu-waves-per-eu" = "4"}'
+ )
+
+
+def test_mfma_intrinsic_to_str() -> None:
+ assert str(common.MfmaIntrinsic.mfma_f32_16x16x16_f16()) == "MFMA_F32_16x16x16_F16"
+ assert str(common.MfmaIntrinsic.mfma_i32_32x32x16_i8()) == "MFMA_I32_32x32x16_I8"
+
+
+def test_get_compatible_mfma_intrinsics() -> None:
+ assert common.get_compatible_mfma_intrinsics(
+ common.ProblemSize(
+ common.MatmulSize(2048, 1280, 1280),
+ common.ShapedType([2048, 1280], common.ElementType.f16),
+ common.ShapedType([1280, 1280], common.ElementType.f16),
+ common.ShapedType([2048, 1280], common.ElementType.f32),
+ common.DispatchKind.mmt,
+ )
+ ) == [
+ common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+ common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
+ ]
+
+ assert common.get_compatible_mfma_intrinsics(
+ common.ProblemSize(
+ common.MatmulSize(2048, 1280, 1280),
+ common.ShapedType([2048, 1280], common.ElementType.i8),
+ common.ShapedType([1280, 1280], common.ElementType.i8),
+ common.ShapedType([2048, 1280], common.ElementType.i32),
+ common.DispatchKind.mmt,
+ )
+ ) == [
+ common.MfmaIntrinsic.mfma_i32_16x16x32_i8(),
+ common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
+ ]
+
+ assert common.get_compatible_mfma_intrinsics(
+ common.ProblemSize(
+ common.MatmulSize(968, 320, 640, 64),
+ common.ShapedType([64, 968, 640], common.ElementType.f32),
+ common.ShapedType([64, 640, 320], common.ElementType.f32),
+ common.ShapedType([64, 968, 320], common.ElementType.f32),
+ common.DispatchKind.batch_matmul,
+ )
+ ) == [
+ common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+ common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
+ ]
diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py
new file mode 100644
index 000000000..ac46d8edd
--- /dev/null
+++ b/tuner/tuner/dispatch_constraints.py
@@ -0,0 +1,197 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Given an input dispatch, this code modifies the hyperparameters
+# in the code and runs it.
+
+import z3 # type: ignore
+from typing import Iterator
+
+from .common import *
+
+
+def get_mfma_intrinsic_constraints(
+ problem_size: ProblemSize,
+ intrinsic_m: z3.ArithRef,
+ intrinsic_n: z3.ArithRef,
+ intrinsic_k: z3.ArithRef,
+) -> z3.BoolRef:
+ compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size)
+ assert len(compatible_intrinsics) > 0, "No compatible intrinsics found"
+ return z3.Or(
+ *(
+ z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k)
+ for mfma in compatible_intrinsics
+ )
+ )
+
+
+def get_dispatch_constraints(
+ problem_size: ProblemSize,
+ tile_m: z3.ArithRef,
+ tile_n: z3.ArithRef,
+ tile_k: z3.ArithRef,
+) -> list[z3.BoolRef]:
+ if problem_size.dispatch_kind != DispatchKind.conv:
+ return []
+
+ dim_info = ConvDimInfo.from_problem_size(problem_size)
+ conv_constraints = []
+ # WARNING: This sometimes makes the constraints UNSAT for some reason.
+ conv_constraints += [tile_m <= dim_info.ow]
+ conv_constraints += [tile_n <= dim_info.oc]
+ conv_constraints += [tile_k <= dim_info.ic]
+ return conv_constraints
+
+
+def calculate_shared_memory_usage_in_bytes(
+ problem_size: ProblemSize,
+ m: int | z3.ArithRef,
+ n: int | z3.ArithRef,
+ k: int | z3.ArithRef,
+) -> int | z3.ArithRef:
+ lhs_memory = m * k * (problem_size.lhs_type.bitwidth // 8)
+ rhs_memory = k * n * (problem_size.rhs_type.bitwidth // 8)
+ return lhs_memory + rhs_memory
+
+
+def generate_constraints(
+ problem_size: ProblemSize,
+ tile_sizes,
+ num_subgroups,
+ subgroup_size,
+ intrinsic_size,
+ workgroup_size,
+ subgroup_m_count,
+ subgroup_n_count,
+ waves_per_eu,
+):
+ M, N, K = (
+ problem_size.matmul_size.M,
+ problem_size.matmul_size.N,
+ problem_size.matmul_size.K,
+ )
+ m, n, k = tile_sizes
+ intrinsic_mn, intrinsic_k = intrinsic_size
+ wg_x, wg_y, wg_z = workgroup_size
+ wg_threads = z3.Int("wg_threads")
+ constraints = [wg_threads == wg_x * wg_y * wg_z]
+ constraints += [subgroup_size == 64, wg_threads <= 1024]
+ constraints += [
+ get_mfma_intrinsic_constraints(
+ problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k
+ )
+ ]
+ subgroup_k_count = 1
+ constraints += [
+ m >= intrinsic_mn,
+ m <= 512,
+ m <= M,
+ ]
+ constraints += [n >= intrinsic_mn, n <= 512, n <= N, N % n == 0]
+ constraints += [k >= intrinsic_k, k <= 512, k <= K, K % k == 0]
+ for x in (subgroup_m_count, subgroup_n_count):
+ constraints += [x >= 1, x <= 32]
+
+ subgroup_m_tile_count = z3.Int("sg_m_tcnt")
+ subgroup_n_tile_count = z3.Int("sg_n_tcnt")
+ subgroup_k_tile_count = z3.Int("sg_k_tcnt")
+ for x in (subgroup_m_tile_count, subgroup_n_tile_count, subgroup_k_tile_count):
+ constraints += [x >= 1, x <= 32]
+
+ constraints += [m == subgroup_m_count * subgroup_m_tile_count * intrinsic_mn]
+ constraints += [n == subgroup_n_count * subgroup_n_tile_count * intrinsic_mn]
+ constraints += [k == subgroup_k_count * subgroup_k_tile_count * intrinsic_k]
+ constraints += [wg_x == subgroup_size * subgroup_n_count]
+ constraints += [wg_y == subgroup_m_count]
+ constraints += [wg_z == subgroup_k_count]
+ constraints += [z3.Or(wg_x <= n, wg_x <= m)]
+ constraints += [k % intrinsic_mn == 0]
+ constraints += [(k * n) % wg_threads == 0]
+ constraints += [(k * m) % wg_threads == 0]
+ subgroups = subgroup_m_count * subgroup_n_count
+ if num_subgroups > 0:
+ constraints += [subgroups == num_subgroups]
+ else:
+ constraints += [subgroups >= 1, subgroups <= 10]
+
+ constraints += [waves_per_eu == 2]
+ # constraints += [z3.Or(waves_per_eu == 2, waves_per_eu == 3, waves_per_eu == 4)]
+
+ shared_memory = calculate_shared_memory_usage_in_bytes(problem_size, m, n, k)
+ constraints += [shared_memory <= 65536]
+
+ constraints += get_dispatch_constraints(problem_size, m, n, k)
+
+ return constraints
+
+
+def generate_solutions(
+ ctx: TunerContext, problem_size: ProblemSize, num_subgrups: int
+) -> Iterator[Configuration]:
+ M, N, K = problem_size.MNK
+ ctx.logger.info(f"{M},{N},{K}")
+ m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k")
+ subgroup_size = z3.Int("subgroup_size")
+ intrinsic_mn = z3.Int("intrinsic_mn")
+ intrinsic_k = z3.Int("intrinsic_k")
+ wg_x, wg_y, wg_z = z3.Int("wg_x"), z3.Int("wg_y"), z3.Int("wg_z")
+ sg_m_cnt = z3.Int("sg_m_cnt")
+ sg_n_cnt = z3.Int("sg_n_cnt")
+ waves_per_eu = z3.Int("waves_per_eu")
+ all_vars = [
+ m,
+ n,
+ k,
+ subgroup_size,
+ intrinsic_mn,
+ intrinsic_k,
+ wg_x,
+ wg_y,
+ wg_z,
+ sg_m_cnt,
+ sg_n_cnt,
+ waves_per_eu,
+ ]
+
+ solver = z3.Solver()
+ constraints = generate_constraints(
+ problem_size,
+ [m, n, k],
+ num_subgrups,
+ subgroup_size,
+ [intrinsic_mn, intrinsic_k],
+ [wg_x, wg_y, wg_z],
+ sg_m_cnt,
+ sg_n_cnt,
+ waves_per_eu,
+ )
+ solver.add(z3.simplify(z3.And(constraints)))
+ ctx.logger.debug(f"Initial constraints: {solver}")
+ i = 0
+ while solver.check() == z3.sat:
+ model = solver.model()
+ lookup = lambda var: model[var].as_long()
+
+ config = Configuration(
+ lookup(subgroup_size),
+ [lookup(wg_x), lookup(wg_y), lookup(wg_z)],
+ MfmaIntrinsic(
+ problem_size.res_type.element_type,
+ lookup(intrinsic_mn),
+ lookup(intrinsic_mn),
+ lookup(intrinsic_k),
+ problem_size.lhs_type.element_type,
+ ),
+ [lookup(m), lookup(n), lookup(k)],
+ lookup(sg_m_cnt),
+ lookup(sg_n_cnt),
+ GpuPipelineOptions(),
+ lookup(waves_per_eu),
+ )
+ solver.add(z3.simplify(z3.Not(z3.And(list(x == model[x] for x in all_vars)))))
+ i += 1
+ yield config
diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py
new file mode 100644
index 000000000..55f3a8c43
--- /dev/null
+++ b/tuner/tuner/dispatch_constraints_test.py
@@ -0,0 +1,161 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""
+Usage: python -m pytest candidate_gen_test.py
+"""
+
+import pytest
+import z3 # type: ignore
+
+from logging import Logger
+from unittest.mock import MagicMock
+
+from . import common
+from . import dispatch_constraints
+
+
+def test_generate_solutions() -> None:
+ matmul_size = common.MatmulSize(2048, 3840, 1280)
+ lhs_type = common.ShapedType([2048, 1280], common.ElementType.f16)
+ rhs_type = common.ShapedType([3840, 1280], common.ElementType.f16)
+ res_type = common.ShapedType([2048, 3840], common.ElementType.f32)
+ problem_size = common.ProblemSize(
+ matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
+ )
+ logger: Logger = MagicMock(spec=Logger)
+ ctx = common.TunerContext(None, logger)
+ configs = dispatch_constraints.generate_solutions(ctx, problem_size, 4)
+ assert configs is not None
+
+
+def test_calculate_shared_memory_usage_in_bytes() -> None:
+ matmul_size = common.MatmulSize(1024, 1024, 1024)
+ lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
+ rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
+ res_type = common.ShapedType([1024, 1024], common.ElementType.f32)
+ problem_size = common.ProblemSize(
+ matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
+ )
+ assert (
+ dispatch_constraints.calculate_shared_memory_usage_in_bytes(
+ problem_size, 512, 64, 128
+ )
+ == 147456
+ )
+
+ lhs_type = common.ShapedType([1024, 1024], common.ElementType.i8)
+ problem_size = common.ProblemSize(
+ matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
+ )
+ assert (
+ dispatch_constraints.calculate_shared_memory_usage_in_bytes(
+ problem_size, 512, 64, 128
+ )
+ == 81920
+ )
+
+ rhs_type = common.ShapedType([1024, 1024], common.ElementType.i32)
+ problem_size = common.ProblemSize(
+ matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
+ )
+ assert (
+ dispatch_constraints.calculate_shared_memory_usage_in_bytes(
+ problem_size, 128, 64, 32
+ )
+ == 12288
+ )
+
+
+def test_generate_constraints_valid_input() -> None:
+ matmul_size = common.MatmulSize(1024, 1024, 1024)
+ lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
+ rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
+ res_type = common.ShapedType([1024, 1024], common.ElementType.f32)
+ problem_size = common.ProblemSize(
+ matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
+ )
+ # Define input parameters as z3 Ints
+ m, n, k = (
+ dispatch_constraints.z3.Int("m"),
+ z3.Int("n"),
+ z3.Int("k"),
+ )
+ subgroup_size = z3.Int("subgroup_size")
+ intrinsic_mn = z3.Int("intrinsic_mn")
+ intrinsic_k = z3.Int("intrinsic_k")
+ wg_x, wg_y, wg_z = (
+ z3.Int("wg_x"),
+ z3.Int("wg_y"),
+ z3.Int("wg_z"),
+ )
+ sg_m_cnt = z3.Int("sg_m_cnt")
+ sg_n_cnt = z3.Int("sg_n_cnt")
+ waves_per_eu = z3.Int("waves_per_eu")
+
+ constraints = dispatch_constraints.generate_constraints(
+ problem_size,
+ [m, n, k],
+ 4,
+ subgroup_size,
+ [intrinsic_mn, intrinsic_k],
+ [wg_x, wg_y, wg_z],
+ sg_m_cnt,
+ sg_n_cnt,
+ waves_per_eu,
+ )
+
+ solver = z3.Solver()
+ solver.add(constraints)
+
+ # Check if the constraints are satisfiable
+ assert solver.check() == z3.sat
+
+
+def test_generate_constraints_invalid_input() -> None:
+ # Define input parameters that should lead to unsatisfiable constraints
+ matmul_size = common.MatmulSize(1024, 1024, 1024)
+ lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
+ rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16)
+ res_type = common.ShapedType([1024, 1024], common.ElementType.f32)
+ problem_size = common.ProblemSize(
+ matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
+ )
+ m, n, k = (
+ z3.Int("m"),
+ z3.Int("n"),
+ z3.Int("k"),
+ )
+ subgroup_size = z3.Int("subgroup_size")
+ intrinsic_mn = z3.Int("intrinsic_mn")
+ intrinsic_k = z3.Int("intrinsic_k")
+ wg_x, wg_y, wg_z = (
+ z3.Int("wg_x"),
+ z3.Int("wg_y"),
+ z3.Int("wg_z"),
+ )
+ sg_m_cnt = z3.Int("sg_m_cnt")
+ sg_n_cnt = z3.Int("sg_n_cnt")
+ waves_per_eu = z3.Int("waves_per_eu")
+
+ constraints = dispatch_constraints.generate_constraints(
+ problem_size,
+ [m, n, k],
+ 4,
+ subgroup_size,
+ [intrinsic_mn, intrinsic_k],
+ [wg_x, wg_y, wg_z],
+ sg_m_cnt,
+ sg_n_cnt,
+ waves_per_eu,
+ )
+ constraints.append(m > 1000) # Adding an additional unsatisfiable constraint
+
+ solver = z3.Solver()
+ solver.add(constraints)
+
+ # Check if the constraints are unsatisfiable
+ assert solver.check() == z3.unsat
diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py
new file mode 100644
index 000000000..670f8c3f7
--- /dev/null
+++ b/tuner/tuner/dispatch_parser.py
@@ -0,0 +1,435 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Given an input dispatch, this code modifies the hyperparameters
+# in the code and runs it.
+
+import math
+import re
+from abc import ABCMeta, abstractmethod
+
+from .common import *
+
+
+def get_mmt_tile_sizes(configuration: Configuration):
+ return configuration.tile_sizes
+
+
+def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]:
+ m, n, k = configuration.tile_sizes
+ tile_size = [1] * len(tile_dims)
+ for idx, dim in enumerate(tile_dims):
+ if dim == "m":
+ tile_size[idx] = m
+ if dim == "n":
+ tile_size[idx] = n
+ if dim == "k":
+ tile_size[idx] = k
+ return tile_size
+
+
+def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]:
+ return [1] + configuration.tile_sizes
+
+
+def parse_mlir(mlir_text: str, ctx: TunerContext) -> ir.Module:
+ mlir_module = None
+ try:
+ mlir_module = ir.Module.parse(mlir_text, ctx.mlir_ctx)
+ ctx.logger.info("MLIR parsing successful!")
+ except ir.MLIRError as e:
+ ctx.logger.error(f"Error parsing MLIR: {e}")
+ raise RuntimeError(f"Error parsing MLIR: {e}")
+
+ return mlir_module
+
+
+class DispatchParser(metaclass=ABCMeta):
+ @abstractmethod
+ def supports(self, op_name: str) -> bool:
+ """Check if the tuner can handle the type of operation represented by the input string."""
+ pass
+
+ @abstractmethod
+ def get_shapes(self, template: list[str]) -> ProblemSize:
+ """Extract problem size of the operation."""
+ pass
+
+
+class MmtParser(DispatchParser):
+ def supports(self, op_name: str) -> bool:
+ return "matmul_transpose_b" in op_name
+
+ def get_shapes(self, template: list[str]) -> ProblemSize:
+ mmt_re = None
+ dps = None
+ for line in template:
+ if "linalg.generic" not in line:
+ continue
+ if r'iterator_types = ["parallel", "parallel", "reduction"]' not in line:
+ continue
+ # ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>)
+ mmt_re = rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
+ dps = re.search(mmt_re, line)
+ if dps is None:
+ continue
+
+ lhs_tensor_type = dps.group("LHS")
+ rhs_tensor_type = dps.group("RHS")
+ lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
+ assert lhs_shaped_type.rank() == 2
+ lhs_M, lhs_K = lhs_shaped_type.shape
+
+ rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
+ assert rhs_shaped_type.rank() == 2
+ rhs_N, rhs_K = rhs_shaped_type.shape
+
+ assert lhs_shaped_type.element_type == rhs_shaped_type.element_type
+ assert lhs_K == rhs_K
+
+ res_tensor_type = dps.group("RES")
+ res_shaped_type = parse_tensor_type(res_tensor_type)
+ assert res_shaped_type.rank() == 2
+ res_M, res_N = res_shaped_type.shape
+
+ assert lhs_M == res_M
+ assert rhs_N == res_N
+
+ matmul_size = MatmulSize(
+ lhs_shaped_type.shape[0],
+ rhs_shaped_type.shape[0],
+ lhs_shaped_type.shape[1],
+ )
+ return ProblemSize(
+ matmul_size,
+ lhs_type=lhs_shaped_type,
+ rhs_type=rhs_shaped_type,
+ res_type=res_shaped_type,
+ dispatch_kind=DispatchKind.mmt,
+ )
+ assert mmt_re
+ assert False, f"'{mmt_re}' not found in given context"
+
+
+class ConvParser(DispatchParser):
+ def supports(self, op_name: str) -> bool:
+ return "conv_2d_nhwc_hwcf" in op_name
+
+ def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]:
+ m, n, k = configuration.tile_sizes
+ batch = 1
+ fh = 1
+ fw = 1
+
+ oh = 1
+
+ oc = n
+ ow = m
+ ic = k
+ return [batch, oh, ow, oc, fh, fw, ic]
+
+ def get_shapes(self, template: list[str]) -> ProblemSize:
+ for line in template:
+ if "linalg.conv_2d_nhwc_hwcf" not in line:
+ continue
+
+ # ins(%19, %20 : tensor<2x34x34x1280xf16>, tensor<3x3x1280x1280xf16>) outs (%27 : tensor<2x32x32x1280xf32>)
+ conv_re = (
+ rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
+ )
+ dps = re.search(conv_re, line)
+ if dps is None:
+ continue
+
+ lhs_tensor_type = dps.group("LHS")
+ rhs_tensor_type = dps.group("RHS")
+ lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
+ assert lhs_shaped_type.rank() == 4
+
+ rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
+ assert rhs_shaped_type.rank() == 4
+
+ res_tensor_type = dps.group("RES")
+ res_shaped_type = parse_tensor_type(res_tensor_type)
+ assert res_shaped_type.rank() == 4
+
+ # int64_t n = outputShape[0];
+ # int64_t oh = outputShape[1];
+ # int64_t ow = outputShape[2];
+ # int64_t oc = outputShape[3];
+ # int64_t fh = filterShape[0];
+ # int64_t fw = filterShape[1];
+ # int64_t ic = filterShape[2];
+ dim_info = ConvDimInfo.from_rhs_res(rhs_shaped_type, res_shaped_type)
+ return ProblemSize(
+ MatmulSize(
+ M=dim_info.oh * dim_info.ow,
+ N=dim_info.oc,
+ K=dim_info.fh * dim_info.fw * dim_info.ic,
+ B=dim_info.n,
+ ),
+ lhs_shaped_type,
+ rhs_shaped_type,
+ res_shaped_type,
+ DispatchKind.conv,
+ )
+
+ assert False, "Shape not found"
+
+
+class ContractionParser(DispatchParser):
+ def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str):
+ self.lhs_dims = lhs_dims
+ self.rhs_dims = rhs_dims
+ self.tile_dims = tile_dims
+
+ def supports(self, op_name: str) -> bool:
+ return "matmul_like" in op_name
+
+ def is_broadcast_rhs_mmt_op(self, line: str) -> bool:
+ if "linalg.generic" not in line:
+ return False
+ if (
+ r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]'
+ not in line
+ ):
+ return False
+ if (
+ r"indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>"
+ not in line
+ ):
+ return False
+ return True
+
+ def is_broadcast_rhs_mmt(self, template: list[str]) -> bool:
+ return any(self.is_broadcast_rhs_mmt_op(line) for line in template)
+
+ def get_shapes_broadcast_rhs_mmt(self, template: list[str]) -> ProblemSize:
+ for line in template:
+ if not self.is_broadcast_rhs_mmt_op(line):
+ continue
+
+ # ins(%11, %12 : tensor<2x1024x1280xi8>, tensor<10240x1280xi8>) outs(%19 : tensor<2x1024x10240xi32>)
+ bmmt_re = (
+ rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
+ )
+ dps = re.search(bmmt_re, line)
+ if dps is None:
+ continue
+
+ lhs_tensor_type = dps.group("LHS")
+ rhs_tensor_type = dps.group("RHS")
+ lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
+ assert lhs_shaped_type.rank() == 3
+
+ rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
+ assert rhs_shaped_type.rank() == 2
+
+ res_tensor_type = dps.group("RES")
+ res_shaped_type = parse_tensor_type(res_tensor_type)
+ assert res_shaped_type.rank() == 3
+
+ B0, M0, K0 = lhs_shaped_type.shape
+ N1, K1 = rhs_shaped_type.shape
+ B2, M2, N2 = res_shaped_type.shape
+ assert B0 == B2
+ assert M0 == M2
+ assert N1 == N2
+ assert K0 == K1
+ return ProblemSize(
+ MatmulSize(M0, N1, K0, B0),
+ lhs_shaped_type,
+ rhs_shaped_type,
+ res_shaped_type,
+ DispatchKind.broadcast_rhs_mmt,
+ )
+
+ assert False, "Shape not found"
+
+ def get_shapes(self, template: list[str]) -> ProblemSize:
+ if self.is_broadcast_rhs_mmt(template):
+ return self.get_shapes_broadcast_rhs_mmt(template)
+
+ for line in template:
+ if "linalg.generic" not in line:
+ continue
+ if "lowering_config =" not in line:
+ continue
+ if '"reduction"' not in line:
+ continue
+
+ # ins(%7, %8 : tensor<2x1024x1280xf16>, tensor<20x64x1280xf16>)
+ cont_re = (
+ rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
+ )
+ dps = re.search(cont_re, line)
+ if dps is None:
+ continue
+
+ lhs_tensor_type = dps.group("LHS")
+ rhs_tensor_type = dps.group("RHS")
+ lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
+ assert lhs_shaped_type.rank() == len(self.lhs_dims)
+
+ rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
+ assert rhs_shaped_type.rank() == len(self.rhs_dims)
+
+ res_tensor_type = dps.group("RES")
+ res_shaped_type = parse_tensor_type(res_tensor_type)
+ assert res_shaped_type.rank() >= 2
+
+ M = math.prod(
+ val if dim == "m" else 1
+ for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape)
+ )
+ N = math.prod(
+ val if dim == "n" else 1
+ for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape)
+ )
+ K0 = math.prod(
+ val if dim == "k" else 1
+ for dim, val in zip(self.lhs_dims, lhs_shaped_type.shape)
+ )
+ K1 = math.prod(
+ val if dim == "k" else 1
+ for dim, val in zip(self.rhs_dims, rhs_shaped_type.shape)
+ )
+ assert K0 == K1
+
+ return ProblemSize(
+ MatmulSize(M, N, K0),
+ lhs_type=lhs_shaped_type,
+ rhs_type=rhs_shaped_type,
+ res_type=res_shaped_type,
+ dispatch_kind=DispatchKind.contraction,
+ )
+
+ assert False, "Shape not found"
+
+
+class BatchMmtParser(DispatchParser):
+ def supports(self, op_name: str) -> bool:
+ return "batch_matmul_transpose_b" in op_name
+
+ def get_shapes(self, template: list[str]) -> ProblemSize:
+ for line in template:
+ if "linalg.generic" not in line:
+ continue
+ if (
+ r'iterator_types = ["parallel", "parallel", "parallel", "reduction"]'
+ not in line
+ ):
+ continue
+ # ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>)
+ bmmt_re = (
+ rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
+ )
+ dps = re.search(bmmt_re, line)
+ if dps is None:
+ continue
+
+ lhs_tensor_type = dps.group("LHS")
+ rhs_tensor_type = dps.group("RHS")
+ lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
+ assert lhs_shaped_type.rank() == 3
+
+ rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
+ assert rhs_shaped_type.rank() == 3
+
+ res_tensor_type = dps.group("RES")
+ res_shaped_type = parse_tensor_type(res_tensor_type)
+ assert res_shaped_type.rank() == 3
+
+ B0, M0, K0 = lhs_shaped_type.shape
+ B1, N1, K1 = rhs_shaped_type.shape
+ B2, M2, N2 = res_shaped_type.shape
+ assert B0 == B1
+ assert B0 == B2
+ assert M0 == M2
+ assert N1 == N2
+ assert K0 == K1
+ return ProblemSize(
+ MatmulSize(M0, N1, K0, B0),
+ lhs_shaped_type,
+ rhs_shaped_type,
+ res_shaped_type,
+ DispatchKind.batch_mmt,
+ )
+
+ assert False, "Shape not found"
+
+
+class BatchMatmulParser(DispatchParser):
+ def __init__(self, lhs_dims: str, rhs_dims: str, tile_dims: str):
+ self.lhs_dims = lhs_dims
+ self.rhs_dims = rhs_dims
+ self.tile_dims = tile_dims
+
+ def supports(self, op_name: str) -> bool:
+ return "batch_matmul" in op_name
+
+ def get_shapes(self, template: list[str]) -> ProblemSize:
+ for line in template:
+ if "linalg.batch_matmul" not in line:
+ continue
+ # ins(%9, %10 : tensor<64x72x1280xf16>, tensor<64x1280x1280xf16>)
+ # outs(%12 : tensor<64x72x1280xf32>)
+ cont_re = (
+ rf"{MlirRegex.dps_ins_two_args()}\s+{MlirRegex.dps_outs_one_arg()}"
+ )
+ dps = re.search(cont_re, line)
+ if dps is None:
+ continue
+
+ lhs_tensor_type = dps.group("LHS")
+ rhs_tensor_type = dps.group("RHS")
+ lhs_shaped_type = parse_tensor_type(lhs_tensor_type)
+ assert lhs_shaped_type.rank() == len(self.lhs_dims)
+
+ rhs_shaped_type = parse_tensor_type(rhs_tensor_type)
+ assert rhs_shaped_type.rank() == len(self.rhs_dims)
+
+ res_tensor_type = dps.group("RES")
+ res_shaped_type = parse_tensor_type(res_tensor_type)
+ assert res_shaped_type.rank() == lhs_shaped_type.rank()
+
+ LHS = lhs_shaped_type.shape
+ RHS = rhs_shaped_type.shape
+ RES = res_shaped_type.shape
+
+ B = math.prod(
+ val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, LHS)
+ )
+ B0 = math.prod(
+ val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RHS)
+ )
+ B1 = math.prod(
+ val if dim == "b" else 1 for dim, val in zip(self.lhs_dims, RES)
+ )
+ M = math.prod(
+ val if dim == "m" else 1 for dim, val in zip(self.lhs_dims, LHS)
+ )
+ N = math.prod(
+ val if dim == "n" else 1 for dim, val in zip(self.rhs_dims, RHS)
+ )
+ K0 = math.prod(
+ val if dim == "k" else 1 for dim, val in zip(self.lhs_dims, LHS)
+ )
+ K1 = math.prod(
+ val if dim == "k" else 1 for dim, val in zip(self.rhs_dims, RHS)
+ )
+ assert B == B0 and B == B1
+ assert K0 == K1
+
+ return ProblemSize(
+ MatmulSize(M, N, K0, B),
+ lhs_type=lhs_shaped_type,
+ rhs_type=rhs_shaped_type,
+ res_type=res_shaped_type,
+ dispatch_kind=DispatchKind.batch_matmul,
+ )
+
+ assert False, "Shape not found"
diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py
new file mode 100644
index 000000000..bcdee240c
--- /dev/null
+++ b/tuner/tuner/dispatch_parser_test.py
@@ -0,0 +1,176 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+"""
+Usage: python -m pytest candidate_gen_test.py
+"""
+
+import pytest
+
+from logging import Logger
+from unittest.mock import MagicMock
+
+from iree.compiler import ir # type: ignore
+from iree.compiler.dialects import func # type: ignore
+
+from . import common
+from . import dispatch_parser
+
+
+def test_get_mmt_tile_sizes() -> None:
+ config = dispatch_parser.Configuration(
+ subgroup_size=0,
+ workgroup_size=[],
+ intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+ tile_sizes=[128, 320, 32],
+ subgroup_m_count=0,
+ subgroup_n_count=0,
+ gpu_pipeline_options=common.GpuPipelineOptions(),
+ waves_per_eu=0,
+ )
+ assert dispatch_parser.get_mmt_tile_sizes(config) == [128, 320, 32]
+
+
+def test_get_conv_tile_sizes() -> None:
+ config = dispatch_parser.Configuration(
+ subgroup_size=64,
+ workgroup_size=[256, 1, 1],
+ intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+ tile_sizes=[464, 320, 16],
+ subgroup_m_count=1,
+ subgroup_n_count=4,
+ gpu_pipeline_options=common.GpuPipelineOptions(),
+ waves_per_eu=1,
+ )
+ assert dispatch_parser.ConvParser().get_conv_tile_sizes(config) == [
+ 1,
+ 1,
+ 464,
+ 320,
+ 1,
+ 1,
+ 16,
+ ]
+
+
+def test_get_contract_tile_sizes() -> None:
+ config = dispatch_parser.Configuration(
+ subgroup_size=32,
+ workgroup_size=[16, 16, 1],
+ intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+ tile_sizes=[4, 8, 16],
+ subgroup_m_count=1,
+ subgroup_n_count=1,
+ gpu_pipeline_options=common.GpuPipelineOptions(),
+ waves_per_eu=2,
+ )
+ assert dispatch_parser.get_contract_tile_sizes(config, "mnk") == [4, 8, 16]
+ assert dispatch_parser.get_contract_tile_sizes(config, "nmk") == [8, 4, 16]
+ assert dispatch_parser.get_contract_tile_sizes(config, "knm") == [16, 8, 4]
+ assert dispatch_parser.get_contract_tile_sizes(config, "kkk") == [
+ 16,
+ 16,
+ 16,
+ ]
+
+
+def test_get_shapes_mmt() -> None:
+ template = [
+ r"%18 = tensor.empty() : tensor<2048x1280xf32>",
+ r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>",
+ r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
+ r"^bb0(%in: f16, %in_0: f16, %out: f32):",
+ ]
+ assert dispatch_parser.MmtParser().get_shapes(template) == common.ProblemSize(
+ common.MatmulSize(2048, 1280, 1280),
+ common.ShapedType([2048, 1280], common.ElementType.f16),
+ common.ShapedType([1280, 1280], common.ElementType.f16),
+ common.ShapedType([2048, 1280], common.ElementType.f32),
+ dispatch_parser.DispatchKind.mmt,
+ )
+
+
+def test_get_shapes_conv() -> None:
+ template = [
+ r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>",
+ r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>",
+ r"flow.dispatch.tensor.store %8, %2, offsets = [%workgroup_id_z, %workgroup_id_y, 0, %3], sizes = [1, 1, 32, 256], strides = [1, 1, 1, 1] : tensor<1x1x32x256xf32> -> !flow.dispatch.tensor>",
+ ]
+ assert dispatch_parser.ConvParser().get_shapes(template) == common.ProblemSize(
+ common.MatmulSize(32, 256, 11520),
+ common.ShapedType([1, 3, 34, 1280], common.ElementType.f16),
+ common.ShapedType([3, 3, 1280, 256], common.ElementType.f16),
+ common.ShapedType([1, 1, 32, 256], common.ElementType.f32),
+ dispatch_parser.DispatchKind.conv,
+ )
+
+
+def test_get_shapes_contract() -> None:
+ template = [
+ r"%18 = tensor.empty() : tensor<2048x1280xf32>",
+ r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>",
+ r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
+ r"^bb0(%in: f16, %in_0: f16, %out: f32):",
+ ]
+ assert dispatch_parser.ContractionParser("mk", "nk", "mnk").get_shapes(
+ template
+ ) == common.ProblemSize(
+ common.MatmulSize(2048, 1280, 1280),
+ common.ShapedType([2048, 1280], common.ElementType.f16),
+ common.ShapedType([1280, 1280], common.ElementType.f16),
+ common.ShapedType([2048, 1280], common.ElementType.f32),
+ dispatch_parser.DispatchKind.contraction,
+ )
+
+
+def test_get_shapes_batch_matmul() -> None:
+ template = [
+ "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>",
+ "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>",
+ "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>",
+ ]
+ assert dispatch_parser.BatchMatmulParser("bmk", "bkn", "mnk").get_shapes(
+ template
+ ) == common.ProblemSize(
+ common.MatmulSize(32, 32, 1024, 1),
+ common.ShapedType([1, 32, 1024], common.ElementType.f32),
+ common.ShapedType([1, 1024, 32], common.ElementType.f32),
+ common.ShapedType([1, 32, 32], common.ElementType.f32),
+ dispatch_parser.DispatchKind.batch_matmul,
+ )
+
+
+def test_get_shapes_batch_mmt() -> None:
+ template = [
+ r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>",
+ r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
+ r"flow.dispatch.tensor.store %21, %10, offsets = [0, 0, 0], sizes = [2, 4096, 640], strides = [1, 1, 1] : tensor<2x4096x640xf16> -> !flow.dispatch.tensor>",
+ ]
+ assert dispatch_parser.BatchMmtParser().get_shapes(template) == common.ProblemSize(
+ common.MatmulSize(4096, 640, 640, 2),
+ common.ShapedType([2, 4096, 640], common.ElementType.i8),
+ common.ShapedType([2, 640, 640], common.ElementType.i8),
+ common.ShapedType([2, 4096, 640], common.ElementType.i32),
+ dispatch_parser.DispatchKind.batch_mmt,
+ )
+
+
+def test_parse_mlir() -> None:
+ with ir.Context() as ctx:
+ mlir_str = r"""
+ builtin.module {
+ func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+ %0 = arith.mulf %arg0, %arg1 : tensor<4xf32>
+ return %0 : tensor<4xf32>
+ }
+ }
+ """
+ logger: Logger = MagicMock(spec=Logger)
+ tuner_context = common.TunerContext(ctx, logger)
+ mlir_module = dispatch_parser.parse_mlir(mlir_str, tuner_context)
+ assert mlir_module is not None
+ assert isinstance(mlir_module, ir.Module)
+ assert isinstance(mlir_module.body.operations[0], func.FuncOp)
diff --git a/tuner/tuner/libtuner.py b/tuner/tuner/libtuner.py
index 91c7b417a..3aa932dc4 100644
--- a/tuner/tuner/libtuner.py
+++ b/tuner/tuner/libtuner.py
@@ -38,7 +38,7 @@
import random
import json
from abc import ABC, abstractmethod
-import iree.runtime as ireert
+import iree.runtime as ireert # type: ignore
from . import candidate_gen
@@ -250,10 +250,11 @@ def get_mean_time_us(self) -> Optional[float]:
mean_benchmark = self.find_mean_benchmark(self.result_json)
if mean_benchmark:
- real_time = mean_benchmark.get("real_time")
- time_unit = mean_benchmark.get("time_unit")
+ real_time: float | None = mean_benchmark.get("real_time")
+ time_unit: str | None = mean_benchmark.get("time_unit")
if real_time is not None:
+ assert time_unit is not None
return self.unit_to_microseconds(real_time, time_unit)
return None
@@ -549,7 +550,7 @@ def create_worker_context_queue(device_ids: list[int]) -> queue.Queue[tuple[int,
return worker_contexts_queue
-def run_command(run_pack: RunPack) -> TaskResult:
+def run_command(run_pack: RunPack) -> RunResult:
command = run_pack.command
check = run_pack.check
timeout_seconds = run_pack.timeout_seconds
@@ -946,6 +947,7 @@ def parse_dispatch_benchmark_results(
continue
res_json = extract_benchmark_from_run_result(benchmark_result.run_result)
+ assert res_json is not None
res = IREEBenchmarkResult(candidate_id, res_json)
benchmark_time = res.get_mean_time_us()
assert benchmark_time is not None
@@ -985,7 +987,10 @@ def generate_sample_task_result(
stdout=stdout,
returncode=0,
)
- return TaskResult(result=res, candidate_id=candidate_id, device_id=device_id)
+ run_result = RunResult(res, False)
+ return TaskResult(
+ run_result=run_result, candidate_id=candidate_id, device_id=device_id
+ )
def generate_dryrun_dispatch_benchmark_results(
@@ -1235,6 +1240,7 @@ def parse_model_benchmark_results(
continue
result_json = extract_benchmark_from_run_result(task_result.run_result)
+ assert result_json is not None
res = IREEBenchmarkResult(candidate_id, result_json)
benchmark_time = res.get_mean_time_us()
assert benchmark_time is not None
diff --git a/tuner/tuner/libtuner_test.py b/tuner/tuner/libtuner_test.py
index 36bda3bd5..11af59af4 100644
--- a/tuner/tuner/libtuner_test.py
+++ b/tuner/tuner/libtuner_test.py
@@ -7,6 +7,7 @@
import argparse
import pytest
import json
+from subprocess import CompletedProcess
from unittest.mock import call, patch, MagicMock
from . import libtuner
@@ -15,15 +16,15 @@
"""
-def test_group_benchmark_results_by_device_id():
+def test_group_benchmark_results_by_device_id() -> None:
# Create mock TaskResult objects with device_id attributes
- task_result_1 = MagicMock()
+ task_result_1: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult)
task_result_1.device_id = "device_1"
- task_result_2 = MagicMock()
+ task_result_2: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult)
task_result_2.device_id = "device_2"
- task_result_3 = MagicMock()
+ task_result_3: libtuner.TaskResult = MagicMock(spec=libtuner.TaskResult)
task_result_3.device_id = "device_1"
benchmark_results = [task_result_1, task_result_2, task_result_3]
@@ -40,7 +41,7 @@ def test_group_benchmark_results_by_device_id():
assert grouped_results[1][0].device_id == "device_2"
-def test_find_collisions():
+def test_find_collisions() -> None:
input = [(1, "abc"), (2, "def"), (3, "abc")]
assert libtuner.find_collisions(input) == (True, [("abc", [1, 3]), ("def", [2])])
input = [(1, "abc"), (2, "def"), (3, "hig")]
@@ -50,14 +51,14 @@ def test_find_collisions():
)
-def test_collision_handler():
+def test_collision_handler() -> None:
input = [(1, "abc"), (2, "def"), (3, "abc"), (4, "def"), (5, "hig")]
assert libtuner.collision_handler(input) == (True, [1, 2, 5])
input = [(1, "abc"), (2, "def"), (3, "hig")]
assert libtuner.collision_handler(input) == (False, [])
-def test_IREEBenchmarkResult_get():
+def test_IREEBenchmarkResult_get() -> None:
# Time is int in us
int_json = [{"aggregate_name": "mean", "real_time": 1, "time_unit": "us"}]
@@ -108,7 +109,7 @@ def test_IREEBenchmarkResult_get():
assert res.get_mean_time_us() == None
# Invalid json: empty dictionary
- res = libtuner.IREEBenchmarkResult(candidate_id=8, result_json={})
+ res = libtuner.IREEBenchmarkResult(candidate_id=8, result_json=[])
assert res.get_mean_time_us() is None
# Invalid json: invalid time unit
@@ -131,7 +132,7 @@ def test_IREEBenchmarkResult_get():
assert res.get_mean_time_us() is None
-def test_generate_display_BR():
+def test_generate_display_BR() -> None:
output = libtuner.generate_display_DBR(1, 3.14)
expected = f"1\tMean Time: 3.1"
assert output == expected, "DispatchBenchmarkResult generates invalid sample string"
@@ -147,29 +148,38 @@ def test_generate_display_BR():
assert output == expected, "ModelBenchmarkResult generates invalid sample string"
-def test_parse_dispatch_benchmark_results():
+def make_mock_task_result() -> libtuner.TaskResult:
+ process: CompletedProcess = MagicMock(spec=CompletedProcess)
+ run_result = libtuner.RunResult(process, False)
+ task_result = libtuner.TaskResult(run_result, 0, "")
+ return task_result
+
+
+def test_parse_dispatch_benchmark_results() -> None:
base_path = libtuner.Path("/mock/base/dir")
spec_dir = base_path / "specs"
path_config = libtuner.PathConfig()
object.__setattr__(path_config, "specs_dir", spec_dir)
- mock_result_1 = MagicMock()
+ mock_result_1 = make_mock_task_result()
mock_json_1 = {
"benchmarks": [
{"aggregate_name": "mean", "real_time": 100.0, "time_unit": "us"}
]
}
+ assert mock_result_1.run_result.process_res is not None
mock_result_1.run_result.process_res.stdout = json.dumps(mock_json_1)
mock_result_1.candidate_id = 1
- mock_result_2 = MagicMock()
+ mock_result_2 = make_mock_task_result()
mock_json_2 = {
"benchmarks": [
{"aggregate_name": "mean", "real_time": 200.0, "time_unit": "us"}
]
}
+ assert mock_result_2.run_result.process_res is not None
mock_result_2.run_result.process_res.stdout = json.dumps(mock_json_2)
mock_result_2.candidate_id = 2
- mock_result_3 = MagicMock()
+ mock_result_3 = make_mock_task_result()
mock_json_3 = {
"benchmarks": [
{
@@ -179,11 +189,11 @@ def test_parse_dispatch_benchmark_results():
}
]
}
+ assert mock_result_3.run_result.process_res is not None
mock_result_3.run_result.process_res.stdout = json.dumps(mock_json_3)
mock_result_3.candidate_id = 3
- mock_result_4 = MagicMock()
- mock_result_4.run_result.process_res = None # Incomplete result
- mock_result_4.candidate_id = 4
+ # Incomplete result.
+ mock_result_4 = libtuner.TaskResult(libtuner.RunResult(None, True), 4, "4")
benchmark_results = [mock_result_1, mock_result_2, mock_result_3, mock_result_4]
candidate_trackers = []
@@ -239,7 +249,7 @@ def test_parse_dispatch_benchmark_results():
)
-def test_parse_model_benchmark_results():
+def test_parse_model_benchmark_results() -> None:
# Setup mock data for candidate_trackers
tracker0 = libtuner.CandidateTracker(0)
tracker0.compiled_model_path = libtuner.Path("/path/to/baseline.vmfb")
@@ -256,38 +266,40 @@ def test_parse_model_benchmark_results():
candidate_trackers = [tracker0, tracker1, tracker2, tracker3]
# Setup mock data for task results
- result1 = MagicMock()
+ result1 = make_mock_task_result()
result_json_1 = {"benchmarks": [{"real_time": 1.23}]}
+ assert result1.run_result.process_res is not None
result1.run_result.process_res.stdout = json.dumps(result_json_1)
result1.candidate_id = 1
result1.device_id = "device1"
- result2 = MagicMock()
+ result2 = make_mock_task_result()
result_json_2 = {"benchmarks": [{"real_time": 4.56}]}
+ assert result2.run_result.process_res is not None
result2.run_result.process_res.stdout = json.dumps(result_json_2)
result2.candidate_id = 2
result2.device_id = "device2"
- result3 = MagicMock()
+ result3 = make_mock_task_result()
result_json_3 = {"benchmarks": [{"real_time": 0.98}]}
+ assert result3.run_result.process_res is not None
result3.run_result.process_res.stdout = json.dumps(result_json_3)
result3.candidate_id = 0
result3.device_id = "device1"
- result4 = MagicMock()
+ result4 = make_mock_task_result()
result_json_4 = {"benchmarks": [{"real_time": 4.13}]}
+ assert result4.run_result.process_res is not None
result4.run_result.process_res.stdout = json.dumps(result_json_4)
result4.candidate_id = 0
result4.device_id = "device2"
# Incomplete baseline on device3
- result5 = MagicMock()
- result5.run_result.process_res = None
- result5.candidate_id = 0
- result5.device_id = "device3"
+ result5 = libtuner.TaskResult(libtuner.RunResult(None, True), 0, "device3")
- result6 = MagicMock()
+ result6 = make_mock_task_result()
result_json_6 = {"benchmarks": [{"real_time": 3.38}]}
+ assert result6.run_result.process_res is not None
result6.run_result.process_res.stdout = json.dumps(result_json_6)
result6.candidate_id = 3
result6.device_id = "device3"
@@ -347,14 +359,14 @@ def mock_get_mean_time_us(self):
)
-def test_extract_driver_names():
+def test_extract_driver_names() -> None:
user_devices = ["hip://0", "local-sync://default", "cuda://default"]
expected_output = {"hip", "local-sync", "cuda"}
assert libtuner.extract_driver_names(user_devices) == expected_output
-def test_fetch_available_devices_success():
+def test_fetch_available_devices_success() -> None:
drivers = ["hip", "local-sync", "cuda"]
mock_devices = {
"hip": [{"path": "ABCD", "device_id": 1}],
@@ -384,7 +396,7 @@ def get_mock_driver(name):
assert actual_output == expected_output
-def test_fetch_available_devices_failure():
+def test_fetch_available_devices_failure() -> None:
drivers = ["hip", "local-sync", "cuda"]
mock_devices = {
"hip": [{"path": "ABCD", "device_id": 1}],
@@ -421,7 +433,7 @@ def get_mock_driver(name):
)
-def test_parse_devices():
+def test_parse_devices() -> None:
user_devices_str = "hip://0, local-sync://default, cuda://default"
expected_output = ["hip://0", "local-sync://default", "cuda://default"]
@@ -432,7 +444,7 @@ def test_parse_devices():
mock_handle_error.assert_not_called()
-def test_parse_devices_with_invalid_input():
+def test_parse_devices_with_invalid_input() -> None:
user_devices_str = "hip://0, local-sync://default, invalid_device, cuda://default"
expected_output = [
"hip://0",
@@ -452,7 +464,7 @@ def test_parse_devices_with_invalid_input():
)
-def test_validate_devices():
+def test_validate_devices() -> None:
user_devices = ["hip://0", "local-sync://default"]
user_drivers = {"hip", "local-sync"}
@@ -469,7 +481,7 @@ def test_validate_devices():
)
-def test_validate_devices_with_invalid_device():
+def test_validate_devices_with_invalid_device() -> None:
user_devices = ["hip://0", "local-sync://default", "cuda://default"]
user_drivers = {"hip", "local-sync", "cuda"}
diff --git a/tuner/tuner/py.typed b/tuner/tuner/py.typed
new file mode 100644
index 000000000..e69de29bb
diff --git a/tuner/version.json b/tuner/version.json
new file mode 100644
index 000000000..f09f61d2a
--- /dev/null
+++ b/tuner/version.json
@@ -0,0 +1,3 @@
+{
+ "package-version": "2.9.2.dev"
+}