Fix optimum.quanto quantization call in cache_utils #34606

Open
wants to merge 1 commit into
base: main

@w3rew w3rew commented Nov 4, 2024

What does this PR do?

Fixes the call to optimum.quanto.quantize_weight in QuantoQuantizedCache, which currently omits the scale and shift parameters and therefore fails. I think this was introduced by cac4a48 when migrating to optimum.quanto.
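
For readers unfamiliar with the two missing parameters, here is a minimal, generic sketch of affine quantization in plain Python. It is not the optimum-quanto implementation; the helper names (`affine_scale_shift`, `quantize`, `dequantize`) and the min/max strategy are assumptions chosen only to show what a scale and a shift do.

```python
from typing import List, Tuple


def affine_scale_shift(values: List[float], bits: int = 4) -> Tuple[float, float]:
    """Derive a scale and shift from the min/max of the values (hypothetical helper)."""
    lo, hi = min(values), max(values)
    levels = 2 ** bits - 1  # highest integer code, e.g. 15 for 4 bits
    scale = (hi - lo) / levels if hi > lo else 1.0  # avoid a zero scale
    return scale, lo  # the shift is the smallest value


def quantize(values: List[float], scale: float, shift: float, bits: int = 4) -> List[int]:
    """Map floats to integer codes in [0, 2**bits - 1]."""
    levels = 2 ** bits - 1
    return [min(levels, max(0, round((v - shift) / scale))) for v in values]


def dequantize(codes: List[int], scale: float, shift: float) -> List[float]:
    """Map integer codes back to approximate floats."""
    return [c * scale + shift for c in codes]


values = [-1.0, 0.0, 1.0, 2.0]
scale, shift = affine_scale_shift(values)
codes = quantize(values, scale, shift)
restored = dequantize(codes, scale, shift)
```

Without a scale and shift the integer codes cannot be computed or mapped back to floats, which is why a quantization call that requires them fails when they are left out.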

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@SunMarc

@@ -813,7 +813,8 @@ def _quantize(self, tensor, axis):
         if is_optimum_quanto_available():
             from optimum.quanto import quantize_weight

             qtensor = quantize_weight(tensor, self.qtype, axis, self.q_group_size)
Member

Can you have a look, @zucchini-nlp, as you made the change in this PR? Looking at the optimum-quanto source code, quantize_weight does require a scale argument.

Member

Looking at it again, I think this is related to the version of optimum-quanto. I see that before v0.2.4 there was no scale parameter:

https://github.com/huggingface/optimum-quanto/blob/832f7f5c3926c91fe4f923aaaf037a780ac3e6c3/optimum/quanto/tensor/qweight.py#L32-L39

But after v0.2.4 we have to pass scale before the group size:

https://github.com/huggingface/optimum-quanto/blob/f3c400e9b5b28b499f87c30325f8628d50417eef/optimum/quanto/tensor/weights/quantization.py#L27-L36

@SunMarc if that is correct, we probably need to check the version as well. It seems like a lot of checks, but since the old quanto should be removed in the next v4.47 release, it could serve as a workaround. WDYT?
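
A rough sketch of what such a version check could look like, in plain Python so it runs without optimum-quanto installed. Everything here is an assumption for illustration: the helper names `parse_version` and `quantize_weight_args` are hypothetical, the default threshold `"0.2.5"` is taken from the versions mentioned in this thread, and the positional order (scale and shift before the group size) follows the linked signature as described above.

```python
import re
from typing import Any, Optional, Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Turn '0.2.5' or '0.2.5.dev0' into (0, 2, 5) (hypothetical helper)."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component such as 'dev0'
        parts.append(int(match.group()))
    return tuple(parts)


def quantize_weight_args(
    installed: str,
    tensor: Any,
    qtype: Any,
    axis: int,
    group_size: int,
    scale: Optional[Any] = None,
    shift: Optional[Any] = None,
    scale_required_from: str = "0.2.5",  # assumption based on this thread
) -> Tuple[Any, ...]:
    """Build the positional arguments for the quantization call per version."""
    if parse_version(installed) >= parse_version(scale_required_from):
        if scale is None:
            raise ValueError("this version needs an explicit scale")
        # Newer signature: scale and shift come before the group size.
        return (tensor, qtype, axis, scale, shift, group_size)
    # Older signature: no scale or shift.
    return (tensor, qtype, axis, group_size)


old_args = quantize_weight_args("0.2.4", "t", "q", 0, 64)
new_args = quantize_weight_args("0.2.5", "t", "q", 0, 64, scale=1.0, shift=0.0)
```

In real code the installed version would come from something like `importlib.metadata.version("optimum-quanto")`, and the returned tuple would be unpacked into the actual `quantize_weight` call.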

Member

Sounds good to me! Also, what was the issue with the prior implementation? Was the change just meant to simplify the code a bit?

Member

You mean the previous PR I merged? It only made the code compatible with optimum-quanto v0.2.4, but I forgot there could be older versions. As for why optimum-quanto is changing its code, I have no idea. But it would be nice if they didn't change it drastically anymore 😅

I guess you'll have more info about future maintenance plans in quanto :)

Member

@w3rew thanks for opening the PR, can you please update with suggested changes?

Author

Sure! Will look into it shortly.

Member

cc @dacorvo for visibility

Contributor

@dacorvo dacorvo Nov 5, 2024

@zucchini-nlp the optimum-quanto 0.2.5 release was synchronized with the switch from quanto to optimum-quanto in transformers at the beginning of October, and the code in cache_utils.py was correct.
It is your pull request to align with 0.2.4 (a version that was never supported by transformers) that was actually incorrect.
