Integrate Marlin Kernels for Int4 GPTQ inference #2497
Conversation
Confirmed things are working properly with TP=2 on llama-2-7b and TP=4 on llama-2-13b
Hi @rib-2, you are very efficient. My original plan was to wait for AutoGPTQ/AutoGPTQ#514 to be merged and then submit a PR based on it. Our ideas coincide, and it's nice to see your PR.
It might be better if you could also add performance testing with different batch sizes, plus output-diff verification against the existing GPTQ kernels.
Two comments on column-wise and act-order:
On the CUDA graph problem, perhaps this is related to CUDA streams, which are currently not handled explicitly by the Marlin kernel launch?
@efrantar I think the act-order case is fairly rare in practice. For instance, I cannot find a Mistral variant on the HF hub that has it enabled.
Also, here are a few models I converted to Marlin format:
I think a simple fix would be to just do the perm in PyTorch. In general, I think one should be a bit more careful with this setting: it is really only helpful on a small number of outlier models (I know only of Llama1-7B and OPT-66B) and has a nonzero impact on inference speed.
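For context, a minimal sketch of what "doing the perm in PyTorch" could look like; the names here are illustrative placeholders, not the actual kernel API:

import torch

def act_order_matmul(x: torch.Tensor, perm: torch.Tensor, marlin_matmul):
    # Illustrative only: `perm` is the input-channel permutation implied by the
    # GPTQ act-order (g_idx) reordering, and `marlin_matmul` stands in for the
    # fused quantized matmul. Reordering the activations in PyTorch keeps the
    # kernel itself unchanged, at the cost of an extra gather per layer.
    x_permuted = x[..., perm]
    return marlin_matmul(x_permuted)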
I think it's due to the fact that the output buffer is allocated with a dynamic shape. Note: I haven't really seen much of a performance delta from running with or without CUDA graphs for fp16.
Thanks, will look into this as a starting point.
Regarding CUDA graphs, there are two things that need to be fixed.
// Allocate the lock buffer through ATen so it is zero-initialized and lives on the intended CUDA device.
auto new_lock = at::zeros({256}, at::device({at::kCUDA, dev}).dtype(torch::kInt32));
int* locks = (int*)new_lock.data_ptr();
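For background, a minimal sketch (not this PR's code) of the general PyTorch CUDA-graph pattern: every tensor the captured kernels touch, including a workspace like the one above, is allocated once up front and reused on replay.

import torch

# Buffers are created before capture and kept alive for every replay.
workspace = torch.zeros(256, dtype=torch.int32, device="cuda")
static_x = torch.randn(8, 4096, dtype=torch.half, device="cuda")

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    # The captured work uses the pre-allocated tensors; replay re-executes
    # exactly these kernels on exactly this memory.
    static_y = static_x * 2.0  # placeholder for the quantized matmul

static_x.copy_(torch.randn_like(static_x))
graph.replay()  # static_y now holds results for the new static_x contents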
Pushed changes to support CUDA graphs. @chu-tianxiang thanks for the input on what needs to be done!
Hi, I fixed the workspace-related bug mentioned above, and I also think I did the necessary changes for CUDA streams in the most recent commit of the Marlin repo. Now the workspace must always be zero on entry, and the last write in the kernel will reset it; this way we avoid having to re-zero it between calls.
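In vLLM terms, that presumably means the layer can hold a single zero-initialized workspace tensor and pass it to every call without resetting it (a sketch; the buffer size and the marlin_gemm signature below are placeholders, not the actual binding):

import torch

class MarlinWorkspace:
    def __init__(self, device="cuda", size=512):
        # Allocated once, zero-initialized; the kernel's final write is
        # expected to return it to all zeros, so it never needs re-zeroing.
        self.buf = torch.zeros(size, dtype=torch.int32, device=device)

workspace = MarlinWorkspace()
# Later, per forward call (illustrative signature only):
# out = ops.marlin_gemm(x, qweight, scales, workspace.buf, m, n, k)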
@alexm-nm can you take a look at adding Elias's changes?
Yeah, adding them now.
@simon-mo this is ready for your review. I have added some tests using a couple of models that I created and formatted in the Marlin structure. The tests basically compare the output of the current GPTQ kernels against the output of the Marlin kernels. A couple of notes about this.
Additionally, I see that the build is red. It looks like this has nothing to do with the Marlin code (e.g. the regression test is red because we cannot download the model from the HF hub). Is there anything I can do to resolve this? https://buildkite.com/vllm/ci/builds/1383#018dbdb4-2348-4f49-abfe-c6f0da837c19
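For reference, a rough sketch of what such a comparison looks like with vLLM's Python API. The model names are placeholders, exact-string equality is a simplification (the real tests presumably tolerate small numerical differences), and the two engines may need to run in separate processes, which relates to the re-init issue mentioned below.

from vllm import LLM, SamplingParams

PROMPTS = ["The capital of France is", "vLLM is a library for"]
GREEDY = SamplingParams(temperature=0.0, max_tokens=32)

def generate(model_name: str):
    llm = LLM(model=model_name, enforce_eager=True)
    return [o.outputs[0].text for o in llm.generate(PROMPTS, GREEDY)]

# Placeholder checkpoints: the same weights in GPTQ and Marlin formats.
gptq_out = generate("org/llama-2-7b-gptq")      # served by the GPTQ kernels
marlin_out = generate("org/llama-2-7b-marlin")  # served by the Marlin kernels
assert gptq_out == marlin_out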
LGTM overall. This is also a clean PR. @WoosukKwon and @zhuohan123 PTAL
@simon-mo Is there any way I can update the buildkite CI config for how the marlin tests are run? I'm getting a torch re-init error since I create the vLLM runner twice in these tests.
@robertgshaw2-neuralmagic LGTM. Thanks for submitting the PR! The code is clean and its speedup looks very promising. Hope we can use the kernel for act_order=True as well.
We will merge the PR for the new release. Please address my minor comments in the next PR.
@@ -48,11 +48,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    &rotary_embedding,
    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

  // Quantization ops
  // Quantization ops
nit:
  // Quantization ops
        return [torch.half]

    @classmethod
    # Need to figure it out
Suggested change: remove the "# Need to figure it out" comment.
# Validate output_size_per_partition
if output_size_per_partition % self.quant_config.min_n_threads != 0:
    raise ValueError(
        f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}."
    )
if output_size_per_partition % self.quant_config.pack_factor != 0:
    raise ValueError(
        f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}."
    )

# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_k_threads != 0:
    raise ValueError(
        f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}."
    )
if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0:
    raise ValueError(
        f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}."
    )
Please break the lines that exceed the maximum length (80).
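For instance, the first check could be wrapped to stay within 80 columns while keeping the message intact (one possible formatting, shown only as an example of the requested style fix):

if output_size_per_partition % self.quant_config.min_n_threads != 0:
    raise ValueError(
        f"Weight output_size_per_partition = {output_size_per_partition} "
        f"is not divisible by min_n_threads = "
        f"{self.quant_config.min_n_threads}.")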
)

# Determine if channelwise or not
input_groups = 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size
ditto
and "is_marlin_format" in hf_quant_config
and hf_quant_config["is_marlin_format"]):
nit: Can we simplify this to
and getattr(hf_quant_config, "is_marlin_format", False)
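Since the quoted code indexes hf_quant_config with brackets, it looks like a plain dict rather than an object, so the equivalent simplification would presumably use dict.get rather than getattr:

and hf_quant_config.get("is_marlin_format", False)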
Also, we need to disable this for GPUs that do not support the Marlin kernel.
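A minimal sketch of such a guard, assuming the Marlin kernels require compute capability 8.0 or newer; the exact threshold and where the check lives are up to the config class:

import torch

major, minor = torch.cuda.get_device_capability()
if major * 10 + minor < 80:
    raise ValueError(
        "The Marlin kernels are not supported on this GPU "
        f"(compute capability {major}.{minor} < 8.0).")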
Co-authored-by: Robert Shaw <114415538+rib-2@users.noreply.github.com> Co-authored-by: alexm <alexm@neuralmagic.com>
Draft PR to integrate IST-A's new Marlin kernels for running GPTQ models into vLLM.

Current limitations (see the usage sketch below):
- Requires enforce_eager=True
- TP=1 only

Todos:
- Make desc_act=true work (current plan is adding some torch-level permutations, and ultimately fusing this into the kernel)
- Make the bias parameter work
- Make channelwise quantization work (it is supported by the kernels)
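For anyone wanting to try the draft, loading a Marlin-format checkpoint presumably looks like a normal vLLM load; the model name below is a placeholder, and the settings reflect the limitations listed above:

from vllm import LLM, SamplingParams

llm = LLM(
    model="org/llama-2-7b-marlin",  # placeholder Marlin-format checkpoint
    enforce_eager=True,             # CUDA graphs not yet supported in the draft
    tensor_parallel_size=1,         # TP=1 per the draft; TP>1 is reported working later in the thread
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)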