-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add T5 encoder bfloat16 support (#614)
Adds various tests for verifying bfloat16 execution. The bfloat16 eager execution matches Flux's output. One test verifies this against golden values from the Flux pipeline. I would say that the eager numerical error is still quite high. My initial idea was to compare the numerical error in IREE bfloat16 vs eager float32. It should match the error profile as the one we get between eager bfloat16 compared against eager float32. The problem is that the bfloat16 eager has a high error that needs further investigation. Due to this some tests are marked as xfail as we don't have a good metric to evaluate the IREE results.
- Loading branch information
Showing
13 changed files
with
615 additions
and
114 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,4 @@ | |
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from .sharding import * | ||
from .dataset import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright 2024 Advanced Micro Devices, Inc. | ||
# | ||
# Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
# See https://llvm.org/LICENSE.txt for license information. | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import torch | ||
|
||
from ...types.tensors import InferenceTensor, PrimitiveTensor, DefaultPrimitiveTensor | ||
from ... import ops | ||
|
||
|
||
def set_float_dtype(tensor: InferenceTensor, dtype: torch.dtype) -> InferenceTensor: | ||
if isinstance(tensor, PrimitiveTensor) and tensor.dtype.is_floating_point: | ||
return DefaultPrimitiveTensor( | ||
name=tensor.name, data=ops.to(tensor, dtype=dtype) | ||
) | ||
|
||
return tensor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.