
Conversation

Che-Xu
Contributor

Che-Xu commented Sep 21, 2025

Resolves #2786

Description

This PR addresses two issues identified while using X-LoRA with the Qwen2-VL-7B model.

Issue 1: Internal Scaling Storage Problem

Location: _enable_peft_forward_hooks() in src/peft/tuners/xlora/model.py
Problem: After calling enable_scalings_logging(), a subsequent call to get_latest_scalings() returned None because the computed xlora_scalings were never stored for later retrieval.
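
To illustrate the symptom, here is a minimal reproduction sketch (model and tokenizer are placeholders for an X-LoRA-wrapped model and its tokenizer; the prompt and generation settings are arbitrary):

# `model` is assumed to be a PeftModel carrying an X-LoRA adapter and
# `tokenizer` its tokenizer; both are placeholders for this sketch.
model.enable_scalings_logging()

inputs = tokenizer("Describe this image.", return_tensors="pt").to(model.device)
model.generate(**inputs, max_new_tokens=1)

# Before this PR: returns None, because the scalings computed inside
# _enable_peft_forward_hooks() were never stored on the model.
# After this PR: returns the scalings tensor from the last real forward pass.
print(model.get_latest_scalings())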

Issue 2: Incorrect Probability Normalization

Location: get_maybe_topk_scalings() in src/peft/tuners/xlora/layer.py
Problem: The previous implementation normalized the expert probabilities so that they summed to 1 across all tokens, rather than summing to 1 for each token.
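
As a toy illustration (not part of the PR diff; the expert-dimension-last layout is assumed only for this example):

import torch

# Toy scalings tensor: (batch=1, seq_len=3, n_experts=4)
scalings = torch.rand(1, 3, 4)

# Buggy behavior: normalizing over the token dimension, so each expert's
# weights sum to 1 across the sequence rather than per token.
buggy = torch.softmax(scalings, dim=1)
print(buggy.sum(dim=-1))   # per-token sums are generally != 1

# Intended behavior: normalize over the expert dimension, so each token's
# expert weights form a probability distribution.
fixed = torch.softmax(scalings, dim=-1)
print(fixed.sum(dim=-1))   # all ones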

Implementation

  • src/peft/tuners/xlora/model.py: Added storage of computed scalings in _enable_peft_forward_hooks()
xlora_scalings = self.internal_xlora_classifier(result=base_output, *args_real, **kwargs_real)
# Store computed scalings to fix get_latest_scalings() returning None
self.internal_xlora_scalings = xlora_scalings
  • src/peft/tuners/xlora/layer.py: Fixed normalization logic in get_maybe_topk_scalings()
# Apply per-token normalization to the xLoRA scaling factors using a softmax
if self.config.enable_softmax_topk:
    nonzero_mask = xlora_scalings != 0
    full = xlora_scalings.masked_fill(~nonzero_mask, float("-inf"))
    new_scalings = torch.softmax(full, dim=-1)
    xlora_scalings = new_scalings.masked_fill(~nonzero_mask, 0.0)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan
Member

BenjaminBossan commented Sep 23, 2025

Thanks a lot for creating this PR to fix the issues you identified. At a glance, the changes look good. Ideally, we should also have unit tests to check for these bugs. I think it shouldn't be too hard to add them by extending test_xlora.py. Would you be interested in adding these tests @Che-Xu?

PS: The failing CI is unrelated and can be ignored.

Member

EricLBuehler left a comment


Looks great 👍!

@BenjaminBossan
Member

@Che-Xu Do you still plan to work on this?

@Che-Xu
Contributor Author

Che-Xu commented Oct 6, 2025

Hi @BenjaminBossan,

Thank you for the reminder and my sincere apologies for the delayed response. I've been occupied with other projects over the past two weeks.

I'm still committed to this and will submit the unit tests within the next two days. I understand the importance of completing this and appreciate you checking in.

Again, sorry for the delay and thank you for your patience. I'll prioritize this task and keep you updated.

@BenjaminBossan
Member

@Che-Xu No worries, take the time you need. My ping was just a reminder, as sometimes people forget or miss notifications. Feel free to let me know if you have any questions.

@Che-Xu
Contributor Author

Che-Xu commented Oct 8, 2025

@BenjaminBossan

Thanks for the feedback! I've extended test_xlora.py with two new unit tests:

  1. test_scalings_storage() (sketched below)
  • Verifies that scaling values are properly stored after generation
  • Validates that get_latest_scalings() returns a non-None torch.Tensor
  • Ensures all scaling values are finite and numerically stable
  2. test_per_token_normalization_with_softmax_topk()
  • Uses monkey-patching to hook into the forward pass of the XLoraLinear layers
  • Captures scaling values during the second forward pass and validates the behavior of get_maybe_topk_scalings()
  • Validates both the existence and mathematical correctness of the normalized scalings
  • Includes shape validation, batch-processing checks, and numerical stability assertions
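
Roughly, the first test looks like this (a sketch; fixture names and self.torch_device follow the existing conventions in test_xlora.py, and the submitted code may differ in detail):

def test_scalings_storage(self, tokenizer, model):
    model.enable_scalings_logging()
    inputs = tokenizer.encode("Test scalings storage", add_special_tokens=False, return_tensors="pt")
    model.generate(input_ids=inputs.to(self.torch_device), max_new_tokens=1)

    latest_scalings = model.get_latest_scalings()
    # Before the fix this was None, because the computed scalings were never stored.
    assert latest_scalings is not None
    assert isinstance(latest_scalings, torch.Tensor)
    assert torch.isfinite(latest_scalings).all(), "Scalings should contain finite values"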

The tests have been added to the existing PR. Please let me know if you'd like me to adjust anything in the test coverage!

Member

BenjaminBossan left a comment


Thanks for adding the unit tests. The test_scalings_storage test looks good, but I believe the test_per_token_normalization_with_softmax_topk test can be greatly simplified if we focus on the main aspect we want to test. Please check my proposal.

Comment on lines 479 to 481

if normalized_scalings is None:
    assert normalized_scalings is not None, (
        f"Missing normalized_scalings in layer {data['layer']} {data['projection']}"

The if check can be removed, right?

continue

if hasattr(normalized_scalings, "cpu"):
    scalings_np = normalized_scalings.cpu().detach().numpy()

Why is it necessary to move the array to numpy?

for t in range(seq_len):
    weights = scalings_np[b, t, :]
    weight_sum = weights.sum()
    assert np.isclose(weight_sum, 1.0, atol=1e-5), (

I think this is the essential part of the test. I would focus on this assert, no need to check the other stuff and also no need to report the layer, batch, and token in detail.

assert torch.isfinite(latest_scalings).all(), "Scalings should contain finite values"

def test_per_token_normalization_with_softmax_topk(self, tokenizer, model):
    captured_data = []

I think the whole test can be greatly simplified if we are content with only logging the scalings. I think that should be enough, as the other logged data is just needed for a nicer error message and I believe we can do without that.

Here is my proposal:

from peft.tuners.xlora.layer import XLoraLayer

...
    def test_per_token_normalization_with_softmax_topk(self, tokenizer, model, monkeypatch):
        orig_get_maybe_topk_scalings = XLoraLayer.get_maybe_topk_scalings

        captured_data = []
        def mock_get_maybe_topk_scalings(*args, **kwargs):
            result = orig_get_maybe_topk_scalings(*args, **kwargs)
            captured_data.append(result)
            return result

        monkeypatch.setattr(XLoraLayer, "get_maybe_topk_scalings", mock_get_maybe_topk_scalings)

        model.enable_scalings_logging()
        inputs = tokenizer.encode("Test per token normalization", add_special_tokens=False, return_tensors="pt")
        outputs = model.generate(
            input_ids=inputs.to(self.torch_device),
            max_new_tokens=1,
        )
        for scaling in captured_data:
            assert ...
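
For the final assert, the essential per-token check could look like this (a sketch that reuses the atol from the earlier np.isclose check; not necessarily the merged code):

for scaling in captured_data:
    per_token_sums = scaling.sum(dim=-1)
    # with enable_softmax_topk, each token's expert weights should sum to 1
    assert torch.allclose(per_token_sums, torch.ones_like(per_token_sums), atol=1e-5)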

@Che-Xu
Contributor Author

Che-Xu commented Oct 9, 2025

@BenjaminBossan,

Thank you for your suggestion! The revised version is much cleaner and clearer, and I have learned a lot from your approach.

One small addition: since X-LoRA performs two forward passes (a dummy pass and a real pass) and we only want to capture the scalings from the real pass, I included the check if getattr(model, "internal_xlora_scalings", None) is not None: in the mock function. This ensures we only record the normalized scalings that are actually used in the real forward pass.
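
A sketch of the adjusted mock with that check (names follow the proposal above; the merged code may differ slightly):

def mock_get_maybe_topk_scalings(*args, **kwargs):
    result = orig_get_maybe_topk_scalings(*args, **kwargs)
    # Only record scalings from the real pass; during the dummy pass
    # internal_xlora_scalings has not been populated yet.
    if getattr(model, "internal_xlora_scalings", None) is not None:
        captured_data.append(result)
    return result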

I have already made these changes in the existing PR. Thank you again for your guidance!

Member

BenjaminBossan left a comment


Thanks for identifying and fixing these two issues with X-LoRA, the changes LGTM. Failing tests are unrelated.

BenjaminBossan merged commit e9f5707 into huggingface:main on Oct 9, 2025
5 of 13 checks passed


Development

Successfully merging this pull request may close these issues.

Two issues with [X-LoRA] Implementation: Scalings Logging and Top-K Softmax Normalization

4 participants