
Integrate Marlin Kernels for Int4 GPTQ inference #2497

Merged: 52 commits into vllm-project:main on Mar 1, 2024

Conversation

@robertgshaw2-neuralmagic (Collaborator) commented Jan 18, 2024

Draft PR to integrate IST-A's new Marlin kernels for running GPTQ models into vLLM

[Screenshot attached: 2024-01-18, 2:45 PM]

Currently:

  • Have end-to-end inference working with enforce_eager=True for TP=1 (a minimal usage sketch follows after the todos)

Todos:

  • Figure out why gibberish is generated when CUDA graphs are active (Alex suspects a temporary allocation)
  • Benchmark performance
  • Figure out if we can support desc_act=True (current plan: add some torch-level permutations, and ultimately fuse the permutation into the kernel)
  • Ensure models with bias parameter work
  • Make channelwise quantization work (it's supported by the kernels)
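
A minimal usage sketch of the current state described above (the checkpoint name is a placeholder for any GPTQ model already serialized in the Marlin format, not a real model):

```python
from vllm import LLM, SamplingParams

# enforce_eager=True disables CUDA graphs, which currently produce gibberish.
llm = LLM(
    model="org/llama-2-7b-gptq-marlin",  # hypothetical Marlin-format checkpoint
    quantization="marlin",
    enforce_eager=True,
    tensor_parallel_size=1,
)
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```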

@robertgshaw2-neuralmagic (Collaborator, Author)

Confirmed things are working properly with TP=2 on llama-2-7b and TP=4 on llama-2-13b.

  • Beyond these sizes, the matrix dimensions get too small (the row-parallel pieces no longer satisfy the mod-256 divisibility the kernel needs).
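
For concreteness, a rough divisibility check (the 128 value stands in for the kernel's min_k_threads constraint and is an assumption here; llama-2-7b's intermediate size of 11008 is used as the row-parallel input dimension):

```python
def row_parallel_fits_marlin(input_size: int, tp_size: int,
                             min_k_threads: int = 128) -> bool:
    """Rough check: does the per-partition K dim still satisfy the kernel?"""
    input_per_partition = input_size // tp_size
    return input_per_partition % min_k_threads == 0

# llama-2-7b down_proj: input_size == intermediate_size == 11008
print(row_parallel_fits_marlin(11008, 2))  # True  -> TP=2 works
print(row_parallel_fits_marlin(11008, 4))  # False -> shard becomes too small
```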

@zhyncs (Contributor) commented Jan 19, 2024

Hi @rib-2, you are so efficient. My original plan was to wait for AutoGPTQ/AutoGPTQ#514 to be merged and then submit a PR based on it. Our ideas coincide; nice to see your PR.

csrc/ops.h: outdated review comment (resolved)
setup.py: outdated review comment (resolved)
@zhyncs (Contributor) commented Jan 19, 2024

It might be better if you could add performance testing with different batch sizes, plus output-diff verification against huggingface/transformers on multiple models. Cheers.

@efrantar

Two comments on column-wise and act-order:

  • The former is actually supported with groupsize=-1 and is slightly faster (it reaches 4x rather than 3.87x on ideal matrix shape benchmarks).
  • The easiest way to implement act-order would be with a separate kernel that permutes the input before launching Marlin (and correspondingly keep the quantized weights in activation order); I would expect that this could be done with only a modest performance hit.

On the CUDA graph problem, perhaps this is related to CUDA streams, which are currently not handled explicitly by the Marlin kernel launch?

@robertgshaw2-neuralmagic (Collaborator, Author)

@efrantar I think the desc_act=True is pretty important for supporting the existing gptq models.

For instance, I cannot find a Mistral variant on the HF hub that has desc_act=False

@efrantar

> @efrantar I think the desc_act=True is pretty important for supporting the existing gptq models.
>
> For instance, I cannot find a Mistral variant on the HF hub that has desc_act=False.

I think a simple fix would be to just do the perm in PyTorch with A[:, perm]; from preliminary testing this yields between 5-20% overhead depending on matrix and batch size, which isn't great and can certainly be improved quite a bit (e.g., by fusing into the previous op and making sure it stays in L2 cache for the next kernel launch), but should be workable as a start for compatibility.

In general, I think one should be a bit more careful with this setting: it is really only helpful on a small number of outlier models (I know only of Llama1-7B and OPT-66B) and has a nonzero impact on inference speed.
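
A minimal PyTorch sketch of that compatibility path (function and argument names here are illustrative, not the kernel's actual API): permute the activations to match weights stored in activation order, then call the group-size Marlin matmul.

```python
import torch

def act_order_marlin_matmul(a: torch.Tensor,
                            qweight: torch.Tensor,
                            scales: torch.Tensor,
                            perm: torch.Tensor,
                            marlin_gemm):
    # a: [num_tokens, in_features] fp16 activations
    # perm: GPTQ act-order column permutation (desc_act=True)
    # marlin_gemm: placeholder for the bound Marlin matmul op
    a_permuted = a[:, perm]  # the A[:, perm] step described above
    return marlin_gemm(a_permuted, qweight, scales)
```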

@robertgshaw2-neuralmagic (Collaborator, Author) commented Jan 19, 2024

> On the CUDA graph problem, perhaps this is related to CUDA streams, which are currently not handled explicitly by the Marlin kernel launch?

I think it's due to the fact that the output buffer is allocated with a dynamic shape in apply_weights. @alexm-nm is looking into this.

Note: I haven't really seen much of a performance delta between running with or without CUDA graphs for fp16.
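
If the dynamic allocation is indeed the problem, one workaround (just a sketch, not the code in this PR) is to allocate the output buffer once at the maximum token count and hand out slices, so graph capture never observes a fresh allocation:

```python
import torch

class ReusableOutputBuffer:
    """Pre-allocate once so CUDA graph capture sees a stable output pointer."""

    def __init__(self, max_tokens: int, out_features: int, device: str = "cuda"):
        self._buf = torch.empty(max_tokens, out_features,
                                dtype=torch.half, device=device)

    def get(self, num_tokens: int) -> torch.Tensor:
        # Slicing returns a view of the same storage; no new allocation
        # happens at capture or replay time.
        return self._buf[:num_tokens]
```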

@robertgshaw2-neuralmagic (Collaborator, Author)

> > @efrantar I think the desc_act=True is pretty important for supporting the existing gptq models.
> > For instance, I cannot find a Mistral variant on the HF hub that has desc_act=False.
>
> I think a simple fix would be to just do the perm in PyTorch with A[:, perm]; from preliminary testing this yields between 5-20% overhead depending on matrix and batch size, which isn't great and can certainly be improved quite a bit (e.g., by fusing into the previous op and making sure it stays in L2 cache for the next kernel launch), but should be workable as a start for compatibility.
>
> In general, I think one should be a bit more careful with this setting: it is really only helpful on a small number of outlier models (I know only of Llama1-7B and OPT-66B) and has a nonzero impact on inference speed.

Thanks, will look into this as a starting point

@chu-tianxiang (Contributor)

Regarding CUDA graph, there are two things that need to be fixed.

  1. In order to capture the graph, you have to get the current stream to launch the kernel; see this PR.
  2. cudaMemset is a host-side call and cannot be captured. I tested creating a new tensor for the locks on each call and had no problem running in CUDA graph mode, though it's not a good solution:

// Allocate a fresh, zero-initialized lock buffer instead of cudaMemset-ing the old one.
auto new_lock = at::zeros({256}, at::device({at::kCUDA, dev}).dtype(torch::kInt32));
int* locks = (int*) new_lock.data_ptr();

@qwopqwop200 commented Jan 19, 2024

@efrantar
I tested Marlin. It works properly with llama, but with OPT (125m and 2.7b), NaN is output after the matmul. Since this is the result before adding the bias, it does not seem to have anything to do with whether or not there is a bias. Is this a bug?

Edit: this was just a dtype issue with the input.

@alexm-neuralmagic (Collaborator)

Pushed changes to support CUDA graphs. @chu-tianxiang thanks for the input on what needs to be done!

@efrantar commented Jan 19, 2024

Hi, I fixed the workspace-related bug mentioned above and also think I made the necessary changes for CUDA streams in the most recent commit of the Marlin repo. Now the workspace must always be 0, and the last write in the kernel will reset it; this way we avoid the memset() at no extra cost (actually, it seems a very tiny bit faster now).
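
With that convention, the caller can allocate the workspace once as zeros and keep reusing it across calls, since the kernel's last write restores it to zero; a sketch (the buffer size and the op name are assumptions, not the actual binding):

```python
import torch

# Allocated once per layer; the kernel resets it to zero before returning, so
# no cudaMemset (and nothing CUDA-graph-unfriendly) is needed between calls.
workspace = torch.zeros(512, dtype=torch.int32, device="cuda")

def marlin_matmul(marlin_gemm, a, qweight, scales):
    # marlin_gemm: placeholder for the bound Marlin op taking a workspace arg.
    return marlin_gemm(a, qweight, scales, workspace)
```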

@robertgshaw2-neuralmagic (Collaborator, Author)

@alexm-nm can you take a look at adding Elias's changes?

@alexm-neuralmagic (Collaborator)

Yeah adding them now

@robertgshaw2-neuralmagic (Collaborator, Author)

@simon-mo - this is ready for your review

I have added some tests using a couple models that I created and formatted in the Marlin structure. The tests basically compare the output of the current GPTQ kernels against the output of the Marlin kernels.

A couple notes about this.

  • Marlin does not guarantee bitwise correctness vs the existing exllama kernels. As a result, in this test, we just confirm that the top 5 selected tokens of the Marlin model are in the top 5 selected tokens of the GPTQ model.

  • Marlin internally uses locks to synchronize its threads. This can result in very slight nondeterminism for Marlin, so we re-run the comparison up to 3 times before failing (a sketch of this check follows below).
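
A minimal sketch of that comparison (the helper names are hypothetical; the real test drives the two vLLM runners and collects per-position top-5 token ids):

```python
def top5_tokens_agree(marlin_top5_ids, gptq_top5_ids):
    # Both inputs: one length-5 list of token ids per generated position.
    # Per the description above, every Marlin top-5 token must also appear
    # in the GPTQ top-5 for that position.
    return all(set(m) <= set(g)
               for m, g in zip(marlin_top5_ids, gptq_top5_ids))

def run_with_retries(compare_once, attempts: int = 3) -> bool:
    # Marlin's lock-based reduction is slightly nondeterministic, so the
    # comparison is retried a few times before the test is declared failed.
    return any(compare_once() for _ in range(attempts))
```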

Additionally, I noticed that the code in conftest.py for reading the example prompts only reads the first line. I fixed this in the PR; I can split it out into a separate PR if you want.

Note: I see that the build is red. It looks like this has nothing to do with the Marlin code (e.g., the regression test is red because we cannot download the model from the HF hub). Is there anything I can do to resolve this?

https://buildkite.com/vllm/ci/builds/1383#018dbdb4-2348-4f49-abfe-c6f0da837c19

@simon-mo (Collaborator) left a comment

LGTM overall. This is also a clean PR. @WoosukKwon and @zhuohan123 PTAL

@robertgshaw2-neuralmagic (Collaborator, Author)

@simon-mo Is there any way I can update the buildkite CI to not run the tests for marlin with --forked?

I'm getting a torch re-init error since I create the vLLM runner twice in these tests.

@WoosukKwon self-requested a review, March 1, 2024 20:10
@WoosukKwon (Collaborator) left a comment

@robertgshaw2-neuralmagic LGTM. Thanks for submitting the PR! The code is clean and its speedup looks very promising. Hope we can use the kernel for act_order=True as well.

We will merge the PR for the new release. Please address my minor comments in the next PR.

@@ -48,11 +48,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

// Quantization ops
// Quantization ops
Review comment (Collaborator):

nit:

Suggested change: drop the duplicated line so only one // Quantization ops comment remains.

return [torch.half]

@classmethod
# Need to figure it out
Review comment (Collaborator):

Suggested change: remove the # Need to figure it out comment.

Comment on lines +102 to +120
# Validate output_size_per_partition
if output_size_per_partition % self.quant_config.min_n_threads != 0:
    raise ValueError(
        f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}."
    )
if output_size_per_partition % self.quant_config.pack_factor != 0:
    raise ValueError(
        f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}."
    )

# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_k_threads != 0:
    raise ValueError(
        f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}."
    )
if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0:
    raise ValueError(
        f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}."
    )
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please break the lines that exceed the maximum length (80).

)

# Determine if channelwise or not
input_groups = 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size
Review comment (Collaborator):

ditto

@WoosukKwon merged commit c0c2335 into vllm-project:main on Mar 1, 2024
19 of 21 checks passed
Comment on lines +170 to +171
and "is_marlin_format" in hf_quant_config
and hf_quant_config["is_marlin_format"]):
Review comment (Collaborator):

nit: Can we simplify this to

and getattr(hf_quant_config, "is_marlin_format", False)

Review comment (Collaborator):

Also, we need to disable this for GPUs that do not support the Marlin kernel.
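
A sketch of such a guard (Marlin targets Ampere-class GPUs; the compute-capability threshold of 8.0 is stated here as an assumption, and the function name is illustrative):

```python
import torch

def device_supports_marlin(device_id: int = 0) -> bool:
    # Marlin's kernels rely on features of compute capability 8.0 (Ampere)
    # and newer; fall back to the existing GPTQ kernels otherwise.
    major, minor = torch.cuda.get_device_capability(device_id)
    return (major, minor) >= (8, 0)
```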

xjpang pushed a commit to xjpang/vllm that referenced this pull request Mar 4, 2024
Co-authored-by: Robert Shaw <114415538+rib-2@users.noreply.github.com>
Co-authored-by: alexm <alexm@neuralmagic.com>
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
Co-authored-by: Robert Shaw <114415538+rib-2@users.noreply.github.com>
Co-authored-by: alexm <alexm@neuralmagic.com>