Intel® Extension for PyTorch* v2.3.0+cpu Release Notes
We are excited to announce the release of Intel® Extension for PyTorch* 2.3.0+cpu, which accompanies PyTorch 2.3. This release mainly brings a new Large Language Model (LLM) feature, the module-level LLM optimization API. It provides module-level optimizations for commonly used LLM modules and functionalities, and aims to optimize customized LLM modeling in scenarios such as private models, self-customized models, and LLM serving frameworks. This release also broadens the list of optimized LLM models and includes a set of bug fixes and small optimizations. We sincerely thank our dedicated community for your contributions. As always, we encourage you to try this release and share feedback to help us further improve this product.
Highlights
- Large Language Model (LLM) optimization
Intel® Extension for PyTorch* introduces a new feature, the module-level LLM optimization API, which offers optimizations for commonly used LLM modules and functionalities. LLM creators can use this API set to replace the corresponding parts of their models themselves and reach peak performance.
The module-level LLM optimization APIs fall into three general categories:
- Linear post-op APIs
    # using module init and forward
    ipex.llm.modules.linearMul
    ipex.llm.modules.linearGelu
    ipex.llm.modules.linearNewGelu
    ipex.llm.modules.linearAdd
    ipex.llm.modules.linearAddAdd
    ipex.llm.modules.linearSilu
    ipex.llm.modules.linearSiluMul
    ipex.llm.modules.linear2SiluMul
    ipex.llm.modules.linearRelu
- Attention related APIs
    # using module init and forward
    ipex.llm.modules.RotaryEmbedding
    ipex.llm.modules.RMSNorm
    ipex.llm.modules.FastLayerNorm
    ipex.llm.modules.VarlenAttention
    ipex.llm.modules.PagedAttention
    ipex.llm.modules.IndirectAccessKVCacheAttention

    # using as functions
    ipex.llm.functional.rotary_embedding
    ipex.llm.functional.rms_norm
    ipex.llm.functional.fast_layer_norm
    ipex.llm.functional.indirect_access_kv_cache_attention
    ipex.llm.functional.varlen_attention
- Generation related APIs
    # used to optimize Hugging Face generation APIs with prompt sharing
    ipex.llm.generation.hf_beam_sample
    ipex.llm.generation.hf_beam_search
    ipex.llm.generation.hf_greedy_search
    ipex.llm.generation.hf_sample
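To make the linear post-op category more concrete, here is a minimal plain-Python sketch of the unfused math these fusions are generally understood to compute (for example, `linearSiluMul` as `silu(linear(x)) * y` and `linear2SiluMul` as `silu(linear_1(x)) * linear_2(x)`). All helper names below are hypothetical and this is not the IPEX API; the operand order and argument conventions of the `ipex.llm.modules` classes are assumptions here, so check the API documentation before relying on them.

```python
import math


# Hypothetical reference helpers (NOT Intel Extension for PyTorch APIs).
# They spell out, unfused, the math a linear post-op fusion is assumed to compute.

def linear(x, weight, bias=None):
    """x @ W^T + b for one input vector; weight has shape [out_features][in_features]."""
    out = [sum(w * v for w, v in zip(row, x)) for row in weight]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


def silu(v):
    return v / (1.0 + math.exp(-v))


def gelu(v):
    # Exact (erf-based) GELU.
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def new_gelu(v):
    # Tanh approximation of GELU, often called "new GELU" in GPT-2 style models.
    return 0.5 * v * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v ** 3)))


def linear_add(x, weight, y, bias=None):
    # Assumed linearAdd semantics: linear(x) + y, e.g. a residual connection.
    return [h + r for h, r in zip(linear(x, weight, bias), y)]


def linear_silu_mul(x, weight, y, bias=None):
    # Assumed linearSiluMul semantics: silu(linear(x)) * y.
    return [silu(h) * m for h, m in zip(linear(x, weight, bias), y)]


def linear2_silu_mul(x, weight_gate, weight_up):
    # Assumed linear2SiluMul semantics: silu(gate(x)) * up(x), the SwiGLU-style MLP gate.
    return linear_silu_mul(x, weight_gate, linear(x, weight_up))


if __name__ == "__main__":
    identity = [[1.0, 0.0], [0.0, 1.0]]
    print(linear_add([1.0, 2.0], identity, [0.5, 0.5]))  # [1.5, 2.5]
```

The sketch only shows the numerics. A fused kernel applies the post-op while the linear output is still hot in registers or cache, which typically saves a separate pass over the intermediate tensor.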
A more detailed introduction to applying this API set, along with example code that walks you through it, can be found here.
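For the normalization and positional-embedding building blocks in the attention-related category (`RMSNorm` / `rms_norm` and `RotaryEmbedding` / `rotary_embedding`), the following plain-Python sketch writes out the standard math for a single vector. It is a reference under stated assumptions, not IPEX's implementation or signature: the function names are hypothetical, and rotary embedding comes in more than one layout (interleaved pairs as in GPT-J versus split halves as in GPT-NeoX and LLaMA). This sketch assumes the interleaved layout.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    # RMSNorm: y_i = x_i / sqrt(mean(x^2) + eps) * weight_i (no mean subtraction, no bias).
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def rotary_embedding(x, position, base=10000.0):
    # Interleaved RoPE (assumed layout): rotate each pair (x[2i], x[2i+1])
    # by the angle position * base**(-2i/d), where d is the head dimension.
    d = len(x)
    if d % 2 != 0:
        raise ValueError("head dimension must be even")
    out = [0.0] * d
    for i in range(d // 2):
        theta = position * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


if __name__ == "__main__":
    q, k = [0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]
    # The query-key score depends only on the relative offset (here 2) between positions,
    # so these two lines print (approximately) the same value.
    print(dot(rotary_embedding(q, 5), rotary_embedding(k, 3)))
    print(dot(rotary_embedding(q, 9), rotary_embedding(k, 7)))
```

RMSNorm drops LayerNorm's mean subtraction and bias, so it needs only one reduction (the sum of squares) per vector; this is one reason it is a common fusion target.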
- Bug fixes and other optimizations
- Optimized LLM performance #2561 #2584 #2617 #2663 #2733
- Added support for GPTQ Act Order #2550 #2568
- Improved warning and logging messages for a better user experience #2641 #2675
- Added a TorchServe CPU example #2613
- Upgraded oneDNN to v3.4.1 #2747
- Miscellaneous fixes and enhancements #2468 #2627 #2631 #2704
Full Changelog: v2.2.0+cpu...v2.3.0+cpu