Add T5 encoder bfloat16 support (#614)
Adds various tests for verifying bfloat16 execution.

The bfloat16 eager execution matches Flux's output. One test verifies
this against golden values from the Flux pipeline.

The eager bfloat16 numerical error is still quite high. My initial idea was
to compare the numerical error of IREE bfloat16 against eager float32; it
should match the error profile we get when comparing eager bfloat16 against
eager float32. The problem is that eager bfloat16 itself has a high error
that needs further investigation.
Because of this, some tests are marked xfail, since we don't have a good
metric for evaluating the IREE results.
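
As an illustration of the comparison described above (a sketch of one
possible metric, not necessarily the exact one used by the new tests):

import torch

def relative_error(actual: torch.Tensor, expected: torch.Tensor) -> float:
    # Norm of the difference, relative to the norm of the float32 reference.
    expected = expected.to(torch.float32)
    diff = actual.to(torch.float32) - expected
    return (torch.linalg.norm(diff) / torch.linalg.norm(expected)).item()

# If eager bfloat16 were well behaved, the IREE bfloat16 error vs. eager
# float32 should sit in the same ballpark as the eager-vs-eager baseline
# (tensor names below are hypothetical):
# baseline = relative_error(eager_bf16_out, eager_f32_out)
# iree_err = relative_error(iree_bf16_out, eager_f32_out)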
sogartar authored Nov 27, 2024
1 parent 26bf8ce commit 10cd58f
Showing 13 changed files with 615 additions and 114 deletions.
41 changes: 29 additions & 12 deletions sharktank/conftest.py
@@ -153,16 +153,28 @@ def pytest_addoption(parser):
     #   --outtype=f32 \
     #   t5-v1_1-small
     parser.addoption(
-        "--google-t5-v1-1-small-fp32-model-path",
+        "--google-t5-v1-1-small-f32-model-path",
         type=Path,
-        default="/data/t5/small/google__t5-v1_1-small_fp32.gguf",
-        help="Google T5 v1.1 small fp32 model path",
+        default="/data/t5/small/google__t5-v1_1-small_f32.gguf",
+        help="Google T5 v1.1 small float32 model path",
     )
     parser.addoption(
-        "--google-t5-v1-1-xxl-fp32-model-path",
+        "--google-t5-v1-1-small-bf16-model-path",
         type=Path,
-        default="/data/t5/xxl/google__t5-v1_1-xxl_fp32.gguf",
-        help="Google T5 v1.1 XXL fp32 model path",
+        default="/data/t5/small/google__t5-v1_1-small_bf16.gguf",
+        help="Google T5 v1.1 small bfloat16 model path",
+    )
+    parser.addoption(
+        "--google-t5-v1-1-xxl-f32-model-path",
+        type=Path,
+        default="/data/t5/xxl/google__t5-v1_1-xxl_f32.gguf",
+        help="Google T5 v1.1 XXL float32 model path",
+    )
+    parser.addoption(
+        "--google-t5-v1-1-xxl-bf16-model-path",
+        type=Path,
+        default="/data/t5/xxl/google__t5-v1_1-xxl_bf16.gguf",
+        help="Google T5 v1.1 XXL bfloat16 model path",
     )
 
     parser.addoption(
@@ -288,15 +300,20 @@ def get_model_artifacts(request: FixtureRequest):
     model_path["llama3_405b_fp8_model_path"] = set_fixture_from_cli_option(
         request, "--llama3-405b-fp8-model-path", "llama3_405b_fp8_model"
     )
-    model_path["google__t5_v1_1_small_fp32_model_path"] = set_fixture_from_cli_option(
+    model_path["google__t5_v1_1_small_f32_model_path"] = set_fixture_from_cli_option(
         request,
-        "--google-t5-v1-1-small-fp32-model-path",
-        "google__t5_v1_1_small_fp32_model",
+        "--google-t5-v1-1-small-f32-model-path",
+        "google__t5_v1_1_small_f32_model",
+    )
+    model_path["google__t5_v1_1_small_bf16_model_path"] = set_fixture_from_cli_option(
+        request,
+        "--google-t5-v1-1-small-bf16-model-path",
+        "google__t5_v1_1_small_bf16_model",
     )
-    model_path["google__t5_v1_1_xxl_fp32_model_path"] = set_fixture_from_cli_option(
+    model_path["google__t5_v1_1_xxl_f32_model_path"] = set_fixture_from_cli_option(
         request,
-        "--google-t5-v1-1-xxl-fp32-model-path",
-        "google__t5_v1_1_xxl_fp32_model",
+        "--google-t5-v1-1-xxl-f32-model-path",
+        "google__t5_v1_1_xxl_f32_model",
     )
     return model_path

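A minimal sketch of the fixture pattern these options feed, using the generic
pytest idiom rather than the repository's actual set_fixture_from_cli_option
helper (the fixture name here is illustrative):

import pytest
from pathlib import Path


@pytest.fixture(scope="session")
def t5_small_bf16_model_path(request) -> Path:
    # pytest falls back to the option's declared default when the flag is
    # not passed on the command line.
    return request.config.getoption("--google-t5-v1-1-small-bf16-model-path")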
2 changes: 1 addition & 1 deletion sharktank/requirements.txt
@@ -1,7 +1,7 @@
 iree-turbine
 
 # Runtime deps.
-gguf==0.6.0
+gguf==0.10.0
 numpy<2.0
 
 # Needed for newer gguf versions (TODO: remove when gguf package includes this)
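The gguf bump to 0.10.0 presumably brings BF16 tensor-type support to the
Python reader. A quick way to inspect which tensor types a file contains,
using the standard gguf-py reader API (path hypothetical):

from gguf import GGUFReader

reader = GGUFReader("/data/t5/small/google__t5-v1_1-small_bf16.gguf")
for tensor in reader.tensors:
    # tensor_type is a GGMLQuantizationType enum; bfloat16 tensors report "BF16".
    print(tensor.name, tensor.tensor_type.name, tensor.shape)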
17 changes: 15 additions & 2 deletions sharktank/sharktank/layers/configs/llm_configs.py
@@ -227,6 +227,8 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
         == properties["t5.attention.layer_norm_rms_epsilon"]
     )
 
+    all_kwargs = {"vocab_size": None, "feed_forward_proj": None}
+
     gguf_to_config_names_map = {
         "t5.context_length": ["context_length"],
         "t5.embedding_length": ["d_model"],
@@ -236,18 +238,29 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
         "t5.attention.key_length": ["d_kv"],
         "t5.attention.layer_norm_epsilon": ["layer_norm_epsilon"],
         "t5.attention.relative_buckets_count": ["relative_attention_num_buckets"],
-        "t5.decoder_start_token_id": ["decoder_start_token_id"],
         "tokenizer.ggml.eos_token_id": ["eos_token_id"],
         "tokenizer.ggml.padding_token_id": ["pad_token_id"],
     }
-    all_kwargs = {"vocab_size": None, "feed_forward_proj": None}
     all_kwargs.update(
         {
             config_name: properties[gguf_name]
             for gguf_name, config_names in gguf_to_config_names_map.items()
             for config_name in config_names
         }
     )
 
+    gguf_to_optional_config_names_map = {
+        "t5.decoder_start_token_id": ["decoder_start_token_id"],
+    }
+    all_kwargs.update(
+        {
+            config_name: properties[gguf_name]
+            for gguf_name, config_names in gguf_to_optional_config_names_map.items()
+            for config_name in config_names
+            if gguf_name in properties
+        }
+    )
+
     if "tokenizer.ggml.tokens" in properties:
         all_kwargs["vocab_size"] = len(properties["tokenizer.ggml.tokens"])
     all_kwargs.update(kwargs)
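A toy illustration of the required-vs-optional split above: a missing required
key raises KeyError, while a missing optional key such as
t5.decoder_start_token_id is silently skipped (the properties dict is
hypothetical and heavily truncated):

properties = {"t5.context_length": 512}  # no "t5.decoder_start_token_id"

required = {"t5.context_length": ["context_length"]}
optional = {"t5.decoder_start_token_id": ["decoder_start_token_id"]}

kwargs = {
    name: properties[key] for key, names in required.items() for name in names
}
kwargs.update(
    {
        name: properties[key]
        for key, names in optional.items()
        for name in names
        if key in properties  # the only difference: skip absent keys
    }
)
assert kwargs == {"context_length": 512}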
3 changes: 2 additions & 1 deletion sharktank/sharktank/layers/token_embedding.py
@@ -5,6 +5,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 import torch
+from typing import Optional
 
 from .. import ops
 from .base import Theta, ThetaLayer
@@ -16,7 +17,7 @@ def __init__(
         theta: Theta,
         *,
         weight_name: str = "weight",
-        dtype: torch.dtype = torch.float32,
+        dtype: Optional[torch.dtype] = torch.float32,
     ):
         super().__init__(theta)
         self.weight = self.theta_tensor(weight_name)
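The widened annotation implies dtype=None is now legal, presumably meaning
"do not cast; keep the weight's own dtype". A minimal sketch of that assumed
semantics (the real lookup goes through sharktank's ops):

import torch
from typing import Optional


def embedding_lookup(
    ids: torch.Tensor, weight: torch.Tensor, dtype: Optional[torch.dtype]
) -> torch.Tensor:
    rows = weight[ids]  # gather rows in the weight's native dtype
    return rows if dtype is None else rows.to(dtype)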
20 changes: 17 additions & 3 deletions sharktank/sharktank/models/t5/export.py
@@ -4,12 +4,15 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Union
+import functools
+from typing import Optional, Union
 from pathlib import Path
 import torch
+from copy import copy
 
 from .t5 import T5Config, T5Encoder
 from ...types import Dataset
+from ...transforms.dataset import set_float_dtype
 from iree.turbine.aot import FxProgramsBuilder, export
 
 __all__ = [
@@ -91,7 +94,18 @@ def prune_decoder_parameters(dataset: Dataset):
     pass
 
 
-def export_encoder_iree_parameters(model_path: str, output_path: str):
-    dataset = Dataset.load(model_path)
+def export_encoder_iree_parameters(
+    model_path_or_dataset: str | Dataset,
+    output_path: str,
+    dtype: Optional[torch.dtype] = None,
+):
+    if isinstance(model_path_or_dataset, Dataset):
+        dataset = copy(model_path_or_dataset)
+    else:
+        dataset = Dataset.load(model_path_or_dataset)
+    if dtype:
+        dataset.root_theta = dataset.root_theta.transform(
+            functools.partial(set_float_dtype, dtype=dtype)
+        )
     prune_decoder_parameters(dataset)
     dataset.save(output_path)
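With the new signature, re-casting a float32 GGUF to bfloat16 IREE parameters
becomes a single call (paths are hypothetical):

import torch

export_encoder_iree_parameters(
    "/data/t5/small/google__t5-v1_1-small_f32.gguf",
    "/tmp/google__t5-v1_1-small_bf16.irpa",
    dtype=torch.bfloat16,
)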
8 changes: 6 additions & 2 deletions sharktank/sharktank/models/t5/t5.py
@@ -684,7 +684,9 @@ def __init__(self, theta: Theta, config: T5Config, embed_tokens=None):
         self.add_module(
             "final_layer_norm",
             RMSNormLayer(
-                theta(f"{theta_prefix}.output_norm"), epsilon=config.layer_norm_epsilon
+                theta(f"{theta_prefix}.output_norm"),
+                epsilon=config.layer_norm_epsilon,
+                dtype=config.activation_dtype,
             ),
         )
 
@@ -1046,7 +1048,9 @@ def __init__(self, theta: Theta, config: T5Config):
         super().__init__()
         self.add_module(
             "token_embedding",
-            TokenEmbeddingLayer(theta("token_embd"), dtype=config.activation_dtype),
+            TokenEmbeddingLayer(
+                theta("token_embd"), dtype=theta("token_embd").tensor("weight").dtype
+            ),
         )
 
         encoder_config = copy.deepcopy(config)
1 change: 1 addition & 0 deletions sharktank/sharktank/transforms/dataset/__init__.py
@@ -5,3 +5,4 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from .sharding import *
+from .dataset import *
19 changes: 19 additions & 0 deletions sharktank/sharktank/transforms/dataset/dataset.py
@@ -0,0 +1,19 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import torch
+
+from ...types.tensors import InferenceTensor, PrimitiveTensor, DefaultPrimitiveTensor
+from ... import ops
+
+
+def set_float_dtype(tensor: InferenceTensor, dtype: torch.dtype) -> InferenceTensor:
+    if isinstance(tensor, PrimitiveTensor) and tensor.dtype.is_floating_point:
+        return DefaultPrimitiveTensor(
+            name=tensor.name, data=ops.to(tensor, dtype=dtype)
+        )
+
+    return tensor
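Usage mirrors what export.py above now does: re-cast every floating-point
tensor in a dataset's theta, leaving integer tensors (e.g. token ids)
untouched. A sketch with a hypothetical input path:

import functools

import torch

from sharktank.types import Dataset
from sharktank.transforms.dataset import set_float_dtype

dataset = Dataset.load("/data/t5/small/google__t5-v1_1-small_f32.gguf")
dataset.root_theta = dataset.root_theta.transform(
    functools.partial(set_float_dtype, dtype=torch.bfloat16)
)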
9 changes: 9 additions & 0 deletions sharktank/sharktank/types/gguf_interop/base.py
@@ -118,6 +118,15 @@ def _wrap_tensor(
             name=name, data=_externalize_tensor(name, data, logical_shape)
         )
 
+    if type_name == "BF16":
+        assert data.dtype == np.uint8
+        return DefaultPrimitiveTensor(
+            name=name,
+            data=_externalize_tensor(name, data.view(np.int16), logical_shape).view(
+                dtype=torch.bfloat16
+            ),
+        )
+
     quantized_type = _quantized_types.get(type_name)
     if quantized_type is not None:
         return quantized_type(
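The uint8 -> int16 -> bfloat16 chain works because bfloat16 values are plain
2-byte bit patterns and numpy has no bfloat16 dtype, so an integer type of
matching width carries the bits until torch can reinterpret them. A small
round trip demonstrating the idea:

import numpy as np
import torch

original = torch.tensor([1.5, -2.0, 3.25], dtype=torch.bfloat16)

# Simulate what a gguf file hands us: a flat byte buffer.
raw = original.view(torch.int16).numpy().view(np.uint8)

# uint8 -> int16 (matching the 2-byte element size) -> view as bfloat16.
recovered = torch.from_numpy(raw.view(np.int16).copy()).view(torch.bfloat16)
assert torch.equal(recovered, original)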
40 changes: 40 additions & 0 deletions sharktank/sharktank/types/tensors.py
@@ -41,6 +41,7 @@
     "AnyTensor",
     "DefaultPrimitiveTensor",
     "dtype_to_serialized_name",
+    "dtype_to_serialized_short_name",
     "flatten_tensor_tree",
     "InferenceTensor",
     "MetaDataValueType",
@@ -51,6 +52,7 @@
     "register_quantized_layout",
     "ReplicatedTensor",
     "serialized_name_to_dtype",
+    "serialized_short_name_to_dtype",
     "ShardedTensor",
     "SplitPrimitiveTensor",
     "torch_tree_flatten",
@@ -1286,6 +1288,15 @@ def dtype_to_serialized_name(dtype: torch.dtype) -> str:
     ) from e
 
 
+def dtype_to_serialized_short_name(dtype: torch.dtype) -> str:
+    try:
+        return _DTYPE_TO_SHORT_NAME[dtype]
+    except KeyError as e:
+        raise KeyError(
+            f"Missing mapping for dtype {dtype}. Please add to the _SHORT_NAME_TO_DTYPE dict"
+        ) from e
+
+
 def serialized_name_to_dtype(dtype_name: str) -> torch.dtype:
     try:
         return _NAME_TO_DTYPE[dtype_name]
@@ -1295,6 +1306,15 @@ def serialized_name_to_dtype(dtype_name: str) -> torch.dtype:
     ) from e
 
 
+def serialized_short_name_to_dtype(dtype_name: str) -> torch.dtype:
+    try:
+        return _SHORT_NAME_TO_DTYPE[dtype_name]
+    except KeyError as e:
+        raise KeyError(
+            f"Missing mapping for dtype '{dtype_name}'. Please add to the _SHORT_NAME_TO_DTYPE dict"
+        ) from e
+
+
 _NAME_TO_DTYPE: dict[str, torch.dtype] = {
     "float32": torch.float32,
     "float64": torch.float64,
@@ -1338,6 +1358,26 @@ def _maybe_dtype(*names: str):
 
 _DTYPE_TO_NAME: dict[torch.dtype, str] = {v: k for k, v in _NAME_TO_DTYPE.items()}
 
+_SHORT_NAME_TO_DTYPE: dict[str, torch.dtype] = {
+    "f32": torch.float32,
+    "f64": torch.float64,
+    "c64": torch.complex64,
+    "c128": torch.complex128,
+    "f16": torch.float16,
+    "bf16": torch.bfloat16,
+    "ui8": torch.uint8,
+    "i8": torch.int8,
+    "i16": torch.int16,
+    "i32": torch.int32,
+    "i64": torch.int64,
+    "b": torch.bool,
+    "f8_e4m3fnuz": torch.float8_e4m3fnuz,
+}
+
+_DTYPE_TO_SHORT_NAME: dict[torch.dtype, str] = {
+    v: k for k, v in _SHORT_NAME_TO_DTYPE.items()
+}
+
 AnyTensor = Union[torch.Tensor, InferenceTensor]
 
 ########################################################################################
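The short names follow the compiler-style spellings (f32, bf16, ...) and
round-trip through the two new helpers; they are presumably used to build
dtype-suffixed artifact and test names:

import torch

assert dtype_to_serialized_short_name(torch.bfloat16) == "bf16"
assert serialized_short_name_to_dtype("bf16") is torch.bfloat16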
(Diffs for the remaining 3 changed files are not shown.)
