amd fp8 rowwise gemm prefill shape tuning #3607
Conversation
✅ Deploy Preview for pytorch-fbgemm-docs ready!
This pull request was exported from Phabricator. Differential Revision: D68521662
Force-pushed from 20e71d5 to f4c6b0e.
Force-pushed from f4c6b0e to 9acfa9a.
Force-pushed from 9acfa9a to 9f84364.
This pull request has been merged in 74490d6.
Summary:
X-link: facebookresearch/FBGEMM#685

Pull Request resolved: pytorch#3607
This diff adds more robust FP8 rowwise GEMM heuristics for LLMs, especially for prefill cases. Consider an input of shape [M, K] and a weight of shape [N, K].
For LLMs, N and K are fixed across different prefill/decode lengths, so the new heuristic first looks up the (N, K) pair and then performs a range-based lookup on M, as sketched below.
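As a rough sketch of that two-level lookup (the type aliases below are illustrative, not the actual FBGEMM symbols), the structure can be pictured as an outer map keyed by (N, K) whose values are per-shape tables keyed by the tuned M values:

```cpp
#include <cstdint>
#include <map>
#include <string>
#include <utility>

// Per-(N, K) table: tuned M value -> best kernel instance for that M.
// std::map keeps the keys sorted, which is what enables range-based lookup.
using MLookup = std::map<int64_t, std::string>;

// Outer lookup: the fixed (N, K) of an LLM weight -> its per-M table.
using ShapeLookup = std::map<std::pair<int64_t, int64_t>, MLookup>;
```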
For each combination of N and K, there is an offline tuning sweep over many values of M, which looks like:

```
5280, 8192, 3584, 0.318272, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5312, 8192, 3584, 0.322179, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5344, 8192, 3584, 0.320632, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5376, 8192, 3584, 0.317728, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5408, 8192, 3584, 0.338742, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5440, 8192, 3584, 0.341432, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5472, 8192, 3584, 0.3436, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5536, 8192, 3584, 0.341703, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5568, 8192, 3584, 0.342054, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5632, 8192, 3584, 0.347904, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5664, 8192, 3584, 0.345129, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
```
A clear pattern is that a single instance is the top choice over a large range of M, which justifies the range-based heuristic on M.
The full tuning log is parsed and converted into a std::map for range-based lookup, as illustrated below.
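Illustratively, reusing the MLookup alias from the sketch above, the converted table for the (N, K) pair in the log excerpt might look like this (the table name is made up; the entries are copied from the excerpt):

```cpp
// Hypothetical converted table for N = 8192, K = 3584. Only the tuned M
// values appear as keys; untuned M values are handled by the range lookup.
static const MLookup kDispatch_N8192_K3584 = {
    {5376, "256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1"},
    {5408, "256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1"},
    {5440, "256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1"},
    // ... one entry per tuned M from the offline sweep ...
};
```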
One key question is which instance to use right around the point where the best instance changes. For example:

```
5376, 8192, 3584, 0.317728, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5408, 8192, 3584, 0.338742, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
```
Should we use 256x256x192x128 or 256x224x256x128 for M = 5377 to 5407? The implementation uses the tuning entry for the larger value, i.e. 256x224x256x128. The rationale is that if we use the smaller entry, the slightly larger M may require additional thread blocks beyond the tuned point and cause bad performance; in contrast, if we use the larger entry, the performance should in theory be the same as for the larger entry itself. Empirically, using the smaller entry led to some degraded performance for untuned values.
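A minimal sketch of that policy, continuing the hypothetical table above: std::map::lower_bound returns the first tuned M that is greater than or equal to the runtime M, i.e. the larger neighboring entry, which is exactly the choice described.

```cpp
#include <iterator>

// Select the instance for a runtime M. For an untuned M in 5377..5407,
// lower_bound lands on the 5408 entry (256x224x256x128), implementing the
// "use the larger entry" policy. Assumes the table is non-empty.
inline const std::string& select_instance(const MLookup& table, int64_t M) {
  auto it = table.lower_bound(M);
  if (it == table.end()) {
    // M exceeds the largest tuned value; fall back to the largest entry.
    it = std::prev(it);
  }
  return it->second;
}
```

For example, select_instance(kDispatch_N8192_K3584, 5377) would return the 256x224x256x128 instance, while M = 5376 itself still resolves to 256x256x192x128.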
Reviewed By: jwfromm

Differential Revision: D68521662