
amd fp8 rowwise gemm prefill shape tuning #3607

Closed
wants to merge 1 commit

Conversation

mxz297
Contributor

@mxz297 mxz297 commented Jan 23, 2025

Summary:
This diff adds a more robust FP8 rowwise GEMM heuristic for LLMs, especially for prefill cases. Consider an input of shape [M, K] and a weight of shape [N, K].

For LLMs, N and K are fixed across different prefill/decode lengths. The new heuristic therefore first looks up the (N, K) pair and then performs a range-based lookup on M.

For each combination of N and K, offline tuning is performed for many values of M; the results look like:

```
5280, 8192, 3584, 0.318272, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5312, 8192, 3584, 0.322179, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5344, 8192, 3584, 0.320632, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5376, 8192, 3584, 0.317728, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5408, 8192, 3584, 0.338742, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5440, 8192, 3584, 0.341432, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5472, 8192, 3584, 0.3436, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5536, 8192, 3584, 0.341703, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5568, 8192, 3584, 0.342054, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5632, 8192, 3584, 0.347904, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5664, 8192, 3584, 0.345129, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
```

A clear pattern is that a single instance is the top choice across a large range of M, which justifies the M-range-based heuristic.

The full tuning log is parsed and converted into a std::map for range-based lookup.
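
As a rough sketch of what such a structure could look like (the type aliases and table contents below are illustrative, not the actual FBGEMM code), the parsed log becomes an ordered map keyed by the tuned M values, nested under a map keyed by (N, K):

```
#include <cstdint>
#include <map>
#include <string>
#include <utility>

// Illustrative sketch only (not the actual FBGEMM implementation): the tuning
// log for each (N, K) pair is turned into an ordered map keyed by the tuned M.
using KernelName = std::string;                 // e.g. "256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1"
using MLookup = std::map<int64_t, KernelName>;  // tuned M -> best instance
using NKLookup = std::map<std::pair<int64_t, int64_t>, MLookup>;

// A tiny excerpt of the tuning data shown above, for (N, K) = (8192, 3584).
static const NKLookup kTuningTable = {
    {{8192, 3584},
     {
         {5376, "256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1"},
         {5408, "256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1"},
     }},
};
```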

One key question is which instance to use for M values right at a boundary where the best instance changes. For example:

```
5376, 8192, 3584, 0.317728, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5408, 8192, 3584, 0.338742, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
```

Should we use 256x256x192x128 or 256x224x256x128 for M = 5377 to 5407? The implementation uses the tuning entry for the larger value (so 256x224x256x128 here). The rationale is that using the smaller entry may increase the number of thread blocks and hurt performance, whereas using the larger entry should in theory perform the same as the larger tuned M itself. Empirically, using the smaller entry led to degraded performance for some untuned values.
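
A minimal sketch of that boundary rule, continuing the illustrative table above: std::map::lower_bound(M) returns the first entry whose key is greater than or equal to M, which gives exactly the "use the larger entry" behavior described.

```
// Illustrative only: pick the instance tuned for the next larger M.
// lower_bound(M) returns the first key >= M, so M = 5377..5407 selects
// the 5408 entry (256x224x256x128) rather than the 5376 entry.
inline KernelName select_kernel(const NKLookup& table, int64_t M, int64_t N, int64_t K) {
  const MLookup& by_m = table.at({N, K});
  auto it = by_m.lower_bound(M);
  if (it == by_m.end()) {
    --it;  // M is larger than any tuned value; fall back to the last entry.
  }
  return it->second;
}
```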

Differential Revision: D68521662


netlify bot commented Jan 23, 2025

Deploy Preview for pytorch-fbgemm-docs ready!

🔨 Latest commit: 9f84364
🔍 Latest deploy log: https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/6792ddb372044900088f48da
😎 Deploy Preview: https://deploy-preview-3607--pytorch-fbgemm-docs.netlify.app

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D68521662

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D68521662

mxz297 added a commit to mxz297/FBGEMM that referenced this pull request Jan 23, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D68521662

mxz297 added a commit to mxz297/FBGEMM that referenced this pull request Jan 24, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D68521662

@facebook-github-bot
Contributor

This pull request has been merged in 74490d6.
