Fix linear (#613)
Fixes a punet regression from linear.py.
I went a little crazy with the fake-quant arg and the logic was
ugly. Now the fake_quant arg only affects q_input. The purpose of having
it at all is to allow exporting and fake-quant eager execution to use the same irpa file.

---------

Co-authored-by: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com>
dan-garvey and nithinsubbiah authored Nov 27, 2024
1 parent 10cd58f commit 9e1e121
Showing 3 changed files with 54 additions and 13 deletions.
39 changes: 39 additions & 0 deletions .github/workflows/ci-sharktank.yml
@@ -22,6 +22,45 @@ concurrency:
   cancel-in-progress: true
 
 jobs:
+  test_punet:
+    name: "Integration Tests - punet"
+    runs-on: nodai-amdgpu-mi250-x86-64
+    env:
+      PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
+    steps:
+      - name: "Setting up Python"
+        id: setup_python
+        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+        with:
+          python-version: 3.11
+
+      - name: "Checkout Code"
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+      - name: Cache Pip Packages
+        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
+        id: cache-pip
+        with:
+          path: ${{ env.PIP_CACHE_DIR }}
+          key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','sharktank/requirements*.txt') }}
+
+      - name: Install pip deps
+        run: |
+          python -m pip install --no-compile --upgrade pip
+          # Note: We install in three steps in order to satisfy requirements
+          # from non default locations first. Installing the PyTorch CPU
+          # wheels saves multiple minutes and a lot of bandwidth on runner setup.
+          pip install --no-compile -r pytorch-cpu-requirements.txt
+          pip install --no-compile -r requirements.txt -r sharktank/requirements-tests.txt -e sharktank/
+          # Update to the latest iree packages.
+          pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
+            iree-base-compiler iree-base-runtime --src deps \
+            -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
+      - name: Run punet tests
+        run: |
+          pytest -v sharktank/ -m model_punet
   test:
     name: "Unit Tests and Type Checking"
     strategy:
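The new job's test step can be reproduced locally. A minimal sketch, assuming the pip deps from the workflow above are already installed; `pytest.main` mirrors the CLI invocation used in CI:

```python
# Run the punet integration tests locally, mirroring the CI step above.
import pytest

# Equivalent to: pytest -v sharktank/ -m model_punet
exit_code = pytest.main(["-v", "sharktank/", "-m", "model_punet"])
raise SystemExit(exit_code)
```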
7 changes: 4 additions & 3 deletions sharktank/integration/models/punet/integration_test.py
@@ -89,12 +89,13 @@ def sdxl_fp16_dataset(sdxl_fp16_base_files, temp_dir):
 def sdxl_int8_base_files():
     from huggingface_hub import hf_hub_download
 
-    REPO_ID = "amd-shark/sdxl-quant-models"
-    REVISION = "942e771bf0c2657a8b33380103d04747a75dfa4a"
+    REPO_ID = "amd-shark/sdxl-quant-int8"
+    SUBFOLDER = "mi300_all_sym_8_step14_fp32"
+    REVISION = "efda8afb35fd72c1769e02370b320b1011622958"
 
     def download(filename):
         return hf_hub_download(
-            repo_id=REPO_ID, subfolder="unet/int8", filename=filename, revision=REVISION
+            repo_id=REPO_ID, subfolder=SUBFOLDER, filename=filename, revision=REVISION
        )
 
     return {
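For reference, a standalone sketch of what the updated fixture's download helper does for a single file. The filename here is hypothetical; the actual fixture requests the punet config and weight files from its return dict:

```python
# Sketch of the updated download path (filename is hypothetical).
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="amd-shark/sdxl-quant-int8",
    subfolder="mi300_all_sym_8_step14_fp32",
    filename="config.json",  # hypothetical; see the fixture's return dict
    revision="efda8afb35fd72c1769e02370b320b1011622958",  # pinned commit
)
print(path)  # local cache path of the downloaded file
```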
21 changes: 11 additions & 10 deletions sharktank/sharktank/layers/linear.py
@@ -31,9 +31,8 @@ class LinearLayer(ThetaLayer):
     x = x * premul_input
     matmul(x, weight.T) + bias
-    fake_quant exists to allow export without adding dequant ops.
-    when fake_quant is True, the op will in quant dequant fashion.
-    When false, it will keep quantized types.
+    fake quant only exists in order to allow for q_input to act as qdq.
+    when fake quant is false, q_input will quantize normally.
     ```
     """
 
@@ -43,7 +42,7 @@ def __init__(
         *,
         weight_name: str = "weight",
         bias_name: str = "bias",
-        fake_quant: bool = True,
+        fake_quant: bool = False,
     ):
         super().__init__(theta)
         self._simulate_native_quant = True
 
@@ -74,21 +73,23 @@ def forward(self, x):
             x = q_input.quantize(x)
             if self.fake_quant:
                 x = x.unpack().dequant()
-        elif qdq_input is not None and self.fake_quant:
+
+        elif qdq_input is not None:
             x = qdq_input.quantize(x).unpack().dequant()
 
         y = ops.linear(x, weight, bias)
 
         # Unconditionally dequantize.
-        if isinstance(y, QuantizedTensor) and not self.fake_quant:
+        if isinstance(y, QuantizedTensor):
             y = y.unpack().dequant()
         # Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32.
         # We can truncate to fp16 in iree, so we do a cast here
         # to account for this in the IR. This is may not be the right
         # level to do this, but for now its here.
-        if not self.fake_quant and y.dtype == torch.float8_e4m3fnuz:
-            y = ops.to(y, torch.float16)
-            return y
-        if qdq_output is not None and self.fake_quant:
+        if not isinstance(y, QuantizedTensor):
+            if y.dtype == torch.float8_e4m3fnuz:
+                y = ops.to(y, torch.float16)
+                return y
+        if qdq_output is not None:
             y = qdq_output.quantize(y).unpack().dequant()
         return y
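To make the new fake_quant contract concrete, a small usage sketch. The irpa path is hypothetical, and in practice you would pass the sub-theta for the specific linear layer rather than the root theta:

```python
# Sketch: one irpa file drives both export and eager execution.
from sharktank.types import Dataset
from sharktank.layers import LinearLayer

dataset = Dataset.load("/path/to/punet_int8.irpa")  # hypothetical path
theta = dataset.root_theta  # illustrative; use the layer's sub-theta in practice

# Same parameters, two modes:
export_layer = LinearLayer(theta, fake_quant=True)   # q_input acts as quant->dequant, so the export carries QDQ ops
eager_layer = LinearLayer(theta, fake_quant=False)   # q_input quantizes normally, keeping quantized types
```

With the default now False, eager execution keeps quantized types unless a caller explicitly opts into the QDQ behavior for export.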
