Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: Support quantization for VeRA using bitsandbytes (#2070) #2076

Open
wants to merge 26 commits into
base: main
Choose a base branch
from

Conversation

ZiadHelal
Copy link

@ZiadHelal ZiadHelal commented Sep 18, 2024

This PR introduces support for 4-bit and 8-bit quantization in the VeRA method, leveraging bitsandbytes.

Addresses #2070

Changes made:

  • Created bnb.py for both 8-bit & 4-bit linear layers
  • Updated model.py for quantization handling and refactored _find_dim
  • Modified init.py for new imports
  • Ensured VeRA-specific logic in quantized setting

@ZiadHelal ZiadHelal marked this pull request as draft September 18, 2024 16:09
@ZiadHelal
Copy link
Author

Hey @BenjaminBossan could you take a look at this draft PR? I will remove the separate (quantization) config from config.py in vera folder later on, I just want to know if I'm going in the right direction or not.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for working on adding 8-bit quantization to VeRA. Super nice!

I haven't done an in-depth review, only a high level check. Here are some observations:

  1. The most important next step IMHO is to add one or a few tests, similar to what we have for LoRA + bnb (see tests/test_common_gpu.py). Let's validate quickly that the approach works.
  2. I see you copied the dispatch_bnb_8bit similar to what we have in LoRA. Note, however, that as of now, LoRA is the only PEFT method that uses that dispatch approach. This is because there are now so many LoRA "backends" that doing otherwise became unwieldy. For VeRA, we're not at that stage yet, so I would remove it. Here is how the code looked like before the dispatch refactor. I would also be fine with adding the dispatch refactor in this PR too, but then we'd need the full refactor, i.e. including dispatch_default.
  3. Do you plan on adding 4 bit as well?
  4. Regarding what you mentioned, yes, I agree:

I will remove the separate (quantization) config from config.py in vera folder later on

@ZiadHelal
Copy link
Author

Hey @BenjaminBossan,

I've finished the 8-bit quantization and now it works with all tests passed, however, with 4-bit it's a bit tricky due to bnb packing of weights implementation. I've added a work-around (which is not correct) in the forward method and it can now train any model but again I think it's not correct, also it fails for the test_vera_bnb_quantization_from_pretrained_safetensors test in 4-bit.

I would very much appreciate it if you could direct me in which way to proceed for the 4-bit implementation.

@ZiadHelal ZiadHelal marked this pull request as ready for review September 22, 2024 08:11
@BenjaminBossan
Copy link
Member

Thanks for the updates.

I've added a work-around (which is not correct) in the forward method and it can now train any model but again I think it's not correct

Could you give me a pointer what lines of code exactly you mean here?

also it fails for the test_vera_bnb_quantization_from_pretrained_safetensors test in 4-bit.

For me the test fails too, albeit already during initialization, not in the forward pass.

Also pinging @vvvm23 and @dkopi for awareness.

@ZiadHelal ZiadHelal changed the title (#2070) Draft PR for 8-bit quantization support for VeRA using bitsandbytes FEAT: Support quantization for VeRA using bitsandbytes (#2070) Sep 23, 2024
@ZiadHelal
Copy link
Author

Hey @BenjaminBossan, it now works on my side! thanks for your help.

Could you check if it works now with you?

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan
Copy link
Member

Thanks @ZiadHelal I can confirm that the tests are now passing on my machine. I think what would be great is if we could take one of the VeRA examples and verify that it works with 4bit and 8bit bnb. Of course, results won't be exactly the same, but we should expect roughly similar outcomes, probably slightly worse. This would be a nice confirmation that the implementation is correct. Is that something you would be willing to tackle?

Apart from that, please run make style on your PR, so that our linter is happy and tests can be run.

@ZiadHelal
Copy link
Author

Hi @BenjaminBossan, sorry for my late reply!

I've run make style and the code should now be good for the linter. Regarding the tests, I've added several tests primarily for CausalLM but I would be happy to add further for the audio and seq2seq models or if you have any other tests in mind, please let me know.

@ZiadHelal
Copy link
Author

I've run again the make style command now. Hope it works!

@BenjaminBossan
Copy link
Member

Ouch, a bunch of tests are failing. Could you please investigate? Please LMK if you need help.

@ZiadHelal
Copy link
Author

I've run the tests that are failing on my machine specifically test_initialization.py, test_feature_extraction_models.py, test_decoder_models.py, and test_custom_models.py. It appears that the refactored lines that you suggested for the function _find_dim in model.py in VeRA are causing these errors. Idk how to proceed next whether are these lines are correct and we need to adjust the tests accordingly or implementing a work around for _find_dim to get the dimensions of vera_A & vera_B based on whether quantization is used or not.

@ZiadHelal
Copy link
Author

This approach is what I have in mind (I haven't implemented it yet but should work ig).

    def _find_dim(self, config) -> tuple[int, int]:
        """
        Finds the largest input and output dimensions across linear layers that have been wrapped with VeRA.

        This will be used for determining the size of the shared vera_A and vera_B matrices.
        """
        model_config = self.get_model_config(self.model)

        peft_config = self._prepare_adapter_config(config, model_config)
        peft_config = _maybe_include_all_linear_layers(peft_config, self.model)

        loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False)

        largest_shape = None
        for key, module in self.model.named_modules():
            if not self._check_target_module_exists(peft_config, key):
                continue

            if loaded_in_4bit:
                if isinstance(module, nn.Linear):
                    module_shape = module.in_features, module.out_features
                elif isinstance(module, Conv1D):
                    module_shape = module.weight.ds_shape if hasattr(module.weight, "ds_shape") else module.weight.shape
                else:
                    continue
            else:
                if isinstance(module, (nn.Linear, Conv1D)):
                    module_shape = tuple(module.weight.shape)
                    if isinstance(module, Conv1D):
                        module_shape = module_shape[::-1]
                else:
                    continue

            if largest_shape is None:
                largest_shape = module_shape
                continue

            if module_shape != largest_shape:
                largest_shape = tuple(max(a, b) for a, b in zip(largest_shape, module_shape))

        if largest_shape is None:
            msg = "No layers types compatible with VeRA were found. Please check `peft_config.target_modules`."
            raise ValueError(msg)

        return largest_shape

@BenjaminBossan
Copy link
Member

I took a closer look and the code that I suggested had a simple error, I was returning the shapes in the wrong order. So the correct code should be:

            if isinstance(module, nn.Linear):
                module_shape = module.out_features, module.in_features
            elif isinstance(module, Conv1D):
                module_shape = module.weight.ds_shape if hasattr(module.weight, "ds_shape") else module.weight.shape
                module_shape = module_shape[::-1]

As to your suggestion: Yes, possibly there needs to be some special handling for quantized weights. I haven't checked that yet.

@ZiadHelal
Copy link
Author

I added your line suggestions and it now works for all the tests except (I think) one test test_decoder_models.py which is complaining about something not related to VeRA. Maybe it won't fail in the pipeline. Can you run the tests again and see if it relates to VeRA, if so then I will apply my approach of adding the special handling for quantized weights.

@BenjaminBossan
Copy link
Member

one test test_decoder_models.py which is complaining about something not related to VeRA.

Which one is it?

@ZiadHelal
Copy link
Author

test_generate_half_prec this is the one but with several precisions

@ZiadHelal
Copy link
Author

ZiadHelal commented Sep 26, 2024

@BenjaminBossan, my bad. it should pass this test, sorry for the confusion. you can run the workflow now.

@BenjaminBossan
Copy link
Member

That would be strange, as VeRA tests are not run:

if config_cls not in (IA3Config, LoraConfig, PrefixTuningConfig):
return pytest.skip(f"Test not applicable for {config_cls}")

Anyway, I'll start the CI, let's see.

@ZiadHelal
Copy link
Author

Are the 11 failing tests related to VeRA?

@BenjaminBossan
Copy link
Member

No, I don't think that they're related. This could possibly be caused by the latest transformers release, not sure. I'll investigate tomorrow.

@ZiadHelal
Copy link
Author

Okay, thanks!

@BenjaminBossan
Copy link
Member

Small update, it is indeed unrelated, the tests started breaking due to a recent change in transformers. I'm trying to get to the bottom of it.

@ZiadHelal
Copy link
Author

Ok got it, thanks for the update!

@BenjaminBossan
Copy link
Member

The fix is now merged. Once you merge with/rebase on main, the tests should pass.

@ZiadHelal
Copy link
Author

Synced with main upstream.

Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the latest updates. I somehow forgot about this PR, sorry for my silence, don't hesitate to ping me if you don't hear back for a few days.

Overall, this looks really good and is ready to be merged except for two small issues. One I have commented on. The other is that the docs need updates to reflect that VeRA now supports bnb quantization:

- Quantized layers are not supported.

Besides that, I think we should also extend quantization.md. I'm not 100% sure what the best way is to update the document, but maybe the easiest is to add a new section before the "Next steps" section called "Other supported PEFT methods". There, you could mention that besides, LoRA, bnb also works with VeRA.

@@ -0,0 +1,408 @@
# Copyright 2023-present the HuggingFace Inc. team.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's update the date to 2024.

@BenjaminBossan
Copy link
Member

Ah, forgot to mention, I ran a test modified from the VeRA.ipynb notebook comparing no quantization vs 8bit vs 4bit. Overall, the results were quite similar, with a bit of degradation for 4bit, but that's expected, so I think the implementation works as expected.

@ZiadHelal
Copy link
Author

Sorry for late reply. I did the requests you mentioned and added also for the docs the other peft methods that support quantization which weren't added at the first place.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants