
Dev quant squeeze #941

Closed
wants to merge 17 commits

Conversation

@ScXfjiang commented on Apr 21, 2024

This PR adds the Squeeze/Unsqueeze operations to QuantTensor. Issue: #891

Changes

  • Add Squeeze/Unsqueeze operations to QuantTensor.
  • For metadata-sensitive operations in QuantTensor (e.g., reshape, flatten, transpose, permute, squeeze, unsqueeze), add per-channel granularity test cases (a sketch of such a test follows below).

See further discussion in #891 and #728.
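
For illustration, a rough sketch of what such a per-channel test could look like (pytest-style; the plain value/scale tensors standing in for QuantTensor and its metadata are assumptions, not the PR's actual test code):

import torch

def test_squeeze_keeps_per_channel_scale_aligned():
    # Stand-in for a per-channel-quantized tensor: a value plus a scale whose
    # rank matches the value, with the channel dim carrying the granularity.
    value = torch.randn(1, 8, 1, 4)
    scale = torch.rand(1, 8, 1, 1)

    # Squeeze the same singleton dims (0 and 2 in the original layout) out of
    # both tensors; after squeezing dim 0, the old dim 2 becomes dim 1.
    squeezed_value = value.squeeze(0).squeeze(1)  # shape (8, 4)
    squeezed_scale = scale.squeeze(0).squeeze(1)  # shape (8, 1)

    assert squeezed_value.shape == (8, 4)
    assert squeezed_scale.shape == (8, 1)
    # The dequantized view must be unchanged by the squeeze.
    assert torch.allclose((value * scale).squeeze(), squeezed_value * squeezed_scale)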

@ScXfjiang (Author)

I have formatted the code and added a quick fix. @Giuseppe5

@ScXfjiang (Author)

Hi @Giuseppe5,

I've updated the code to be compatible with PyTorch < 2.0, and it should now pass the CI checks.

In PyTorch < 2.0, torch.squeeze() doesn't support multiple dims.

From the PyTorch 2.0 changelog:
"Update torch.squeeze to allow squeezing multiple dimensions at once" (pytorch/pytorch#89017)

To make Brevitas flexible, QuantTensor.squeeze() always supports multiple dims, even when using PyTorch < 2.0.
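
As a minimal sketch (not the PR's actual code) of how multiple dims can still be squeezed on PyTorch < 2.0, one dim at a time; the helper name is made up and non-negative dims are assumed:

import torch

def multi_dim_squeeze(t, dims):
    # Squeeze from the highest dim down so earlier squeezes do not shift the
    # indices of the dims that are still to be squeezed.
    for d in sorted(dims, reverse=True):
        if t.shape[d] == 1:
            t = t.squeeze(d)
    return t

x = torch.randn(1, 3, 1, 5)
print(multi_dim_squeeze(x, (0, 2)).shape)  # torch.Size([3, 5])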

The method is more complex than the other metadata-sensitive methods because it must guarantee that the value tensor and the metadata tensors squeeze the same dims.

tensor_meta = {
    'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width}
for k, tm in tensor_meta.items():
    if tm is not None and len(value.shape) == len(tm.shape) - len(sorted_target_dims):
Collaborator

The metadata fields cannot be None anymore.
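
For context, a hedged sketch of the dim alignment being discussed: since the metadata fields are never None, every metadata tensor whose rank matches the value (i.e., per-channel metadata) can simply be squeezed over the same dims. The function name, argument layout, and rank-matching heuristic below are assumptions, not the PR's implementation:

import torch

def squeeze_with_metadata(value, scale, zero_point, dims):
    # Normalize and sort dims high-to-low so earlier squeezes do not shift the
    # remaining indices (this also works on PyTorch < 2.0, where torch.squeeze()
    # takes a single dim at a time).
    dims = sorted((d % value.dim() for d in dims), reverse=True)
    out = {'value': value, 'scale': scale, 'zero_point': zero_point}
    for name, t in list(out.items()):
        if t.dim() == value.dim():  # only rank-matching (per-channel) metadata
            for d in dims:
                t = t.squeeze(d)
            out[name] = t
    return out['value'], out['scale'], out['zero_point']

v = torch.randn(1, 8, 1, 4)
s = torch.rand(1, 8, 1, 1)      # per-channel scale, same rank as the value
zp = torch.zeros(1, 8, 1, 1)
v2, s2, zp2 = squeeze_with_metadata(v, s, zp, (0, 2))
print(v2.shape, s2.shape)       # torch.Size([8, 4]) torch.Size([8, 1])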

@@ -46,6 +46,21 @@ def transpose_handler(inp, *args, **kwargs):
    return inp.transpose(*args, **kwargs)


@implements(torch.permute)
def permute_handler(inp, *args, **kwargs):
    return inp.permute(*args, **kwargs)
Collaborator

Don't we need a permute implementation that also permutes the metadata of the quant tensor?
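
A sketch of what the question is pointing at: when the value is permuted, per-channel metadata (rank-matching scale/zero_point) has to be permuted with the same dims, or it will no longer line up with the value's channels. The function below is a standalone illustration with made-up names, not Brevitas' actual handler:

import torch

def permute_quant(value, scale, zero_point, dims):
    def maybe_permute(t):
        # Per-tensor (scalar or lower-rank) metadata is left untouched;
        # per-channel metadata with the value's rank follows the permutation.
        return t.permute(dims) if t.dim() == value.dim() else t
    return value.permute(dims), maybe_permute(scale), maybe_permute(zero_point)

# Example: NCHW -> NHWC with a per-channel scale of shape (1, C, 1, 1)
v = torch.randn(2, 8, 4, 4)
s = torch.rand(1, 8, 1, 1)
v2, s2, _ = permute_quant(v, s, torch.zeros(1, 8, 1, 1), (0, 2, 3, 1))
print(v2.shape, s2.shape)  # torch.Size([2, 4, 4, 8]) torch.Size([1, 1, 1, 8])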

@ScXfjiang (Author)

Sorry for the late reply. I will update my code next week : )

@Giuseppe5 (Collaborator)

As you can see, in the meantime we merged another PR adding support for FloatQuantTensor.
The two have different metadata, so they would need slightly adjusted implementations. Feel free to keep this PR focused on IntQuantTensor, and I will open an issue to track the same problem for FloatQuantTensor.

The extension to that should be trivial, but I don't want to expand the scope too much.

@ScXfjiang closed this on Jul 22, 2024