
Optimize DoRA in eval and no dropout #2122

Merged: 17 commits merged into huggingface:main on Oct 16, 2024

Conversation

ariG23498 (Contributor):

Fixes #2107

Comment on lines 588 to 592
if isinstance(dropout, nn.Identity):
    print("no dropout, optimize here")
else:
    print("dropout, same ops")

Contributor Author:

@BenjaminBossan did you envision something like this?

My intuition was:

  1. Figure out whether there is dropout or not
  2. Use a flag for dropout
  3. Pass the flag to the forward or DoRA layers -- where I would need to skip the alignment step and reuse x (the base model outputs)

Let me know if I am on the right track.

Note: I could not figure out a way to detect whether the model is in eval mode. How would you have done it?

Member:

Yes, I think the dropout check is valid as is. Regarding eval mode, I think that checking self.training should work.

On how to proceed, my thinking was: if we find that we can make this optimization, we pass the base result as an additional argument to DoRA's forward (with None as the default), and there we use the base result if it's given; if not, we calculate it like we currently do. Could be that I'm missing something, but that's my idea.

The good news is that since we have a working implementation, we can compare the results of both approaches and they should be identical (except, of course, when there is dropout).


Neat solution!
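
A minimal sketch of the idea discussed above (illustrative only, not the final PR code; the flag logic and the `base_result` argument name follow the discussion, everything else is assumed):

# Inside the LoRA Linear forward (sketch): decide whether the base output can be reused.
if isinstance(dropout, nn.Identity) or not self.training:
    base_result = result      # no dropout is applied, so the base output was computed on the same x
else:
    x = dropout(x)
    base_result = None        # dropout changes x, so DoRA has to recompute the base result

result = result + self.lora_magnitude_vector[active_adapter](
    x,
    lora_A=lora_A,
    lora_B=lora_B,
    scaling=scaling,
    base_layer=self.get_base_layer(),
    base_result=base_result,  # None -> DoRA falls back to F.linear as before
)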


ariG23498 marked this pull request as ready for review on October 2, 2024, 10:23
BenjaminBossan (Member) left a comment:

Thanks for implementing this task. I have a few comments:

I don't think we need the `do_optimize` argument. When `result` is passed, let's use it; otherwise, don't.

Also, let's refactor this a bit and avoid an early return. Instead, let's do something like:

if result is not None:
    # let's also add a comment to explain
    base_result = ...
else:
    base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))

Then in this line:

https://github.com/huggingface/peft/pull/2122/files#diff-bcb4d7d165949d2eb3eac3203b9067589261e52a1192b7a31f101a3ec98855acR99

replace `F.linear(x, transpose(weight, self.fan_in_fan_out))` with `base_result`.

A small caveat I see with this approach is that we assume that we can just remove the bias to get the base result. This will work for normal nn.Linear layers but may not work for other types. But let's maybe not worry about that for now.

We should also run an example with and without the optimization to ensure that we get the same results and also a speedup and memory improvement.
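
For the comparison, a rough standalone harness along these lines could work (a sketch; `run_training_step` is a hypothetical stand-in for one optimizer step of whichever example script is used):

import time

import torch


def benchmark(run_training_step, steps=50):
    """Measure wall time and peak GPU memory around an arbitrary training callable."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(steps):
        run_training_step()
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    peak_mb = torch.cuda.max_memory_allocated() / 1024**2
    return elapsed, peak_mb


# Run once on PEFT main and once on this branch (with lora_dropout=0) and compare
# the returned runtime and peak memory; outputs and losses should match up to rounding.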

BenjaminBossan (Member) left a comment:

Thanks for the update. I left a comment where I think the calculation is not quite right.

Also, it would be great if you could check a DoRA example to see the changes caused by this PR. Probably one of the existing examples could be used. We should ensure that:

  • The results are the same (assuming dropout is 0)
  • Training should be faster
  • Memory usage should be lower

bias = base_layer.bias
if bias is not None:
    base_result = base_result - bias
result_dora = mag_norm_scale * base_result + mag_norm_scale * lora_result * scaling
Member:

Hmm, wait, should this not be the exact same calculation as in line 103? I.e., we should leave the conditional after calculating base_result and then do the same dora_result calculation in both cases.

Contributor Author:

I am not sure that I follow. With the base_result in place:

  1. We first subtract the bias
  2. Compute the dora_result, where we scale the base_result by mag_norm_scale

But without the base_result:

  1. We compute the base_result with the linear forward
  2. Compute the dora_result, where we scale the base_result by (mag_norm_scale - 1)

Aren't they going to be different for each case?

Member:

Okay, so I'm a bit confused, let's try to resolve this.

In the old code, we basically have:

dora_result = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * lora_scale

(variable names slightly changed for clarity)

My thinking is that the base_result is either calculated right there (old code) or we use the base_result that is being passed as an argument, but the basic equation stays the same.

Of course, as you correctly noted, the bias needs to be subtracted first and then added back in the latter case.

In the currently proposed code, in one case we calculate mag_norm_scale * base_result and in the other (mag_norm_scale - 1) * base_result. This looks inconsistent to me.
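
For intuition: the `- 1` term is there because `result_dora` is added back onto the base output by the calling LoRA layer, so the total works out to `mag_norm_scale * base_result + mag_norm_scale * lora_result * scaling`, i.e. the DoRA scaling applied to both terms. A quick sanity check with random tensors standing in for the real activations (illustrative only):

import torch

torch.manual_seed(0)
base_result = torch.randn(4, 8)      # stands in for x @ W^T (bias already removed)
lora_result = torch.randn(4, 8)      # stands in for lora_B(lora_A(x))
mag_norm_scale = torch.rand(1, 8)    # stands in for magnitude / weight_norm
scaling = 0.5

result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling
total = base_result + result_dora    # the caller adds result_dora onto the base output
expected = mag_norm_scale * (base_result + lora_result * scaling)
assert torch.allclose(total, expected, atol=1e-6)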

Contributor Author:

Hi @BenjaminBossan

I have made the changes as suggested.

Member:

I think the change is still not correct. Here is my suggestion:

        bias = None
        if base_result is not None:
            bias = base_layer.bias
            if bias is not None:
                base_result = base_result - bias
        else:
            base_result = F.linear(x, transpose(weight, self.fan_in_fan_out))

        result_dora = (mag_norm_scale - 1) * base_result + mag_norm_scale * lora_result * scaling

        if bias is not None:
            result_dora = result_dora + bias

This way, if base_result = None, the computation is exactly the same as it was previously.

I believe the confusion may stem from my comment:

        # result = mag_norm_scale * result + mag_norm_scale * lora_B(lora_A(x)) * scaling

This comment should have been:

        # result = (mag_norm_scale - 1) * result + mag_norm_scale * lora_B(lora_A(x)) * scaling

Does that make sense?

Contributor Author:

That makes sense!

About the unit test -- do you want me to add a new test file? Or add a test somewhere?

ariG23498 (Contributor Author):

@BenjaminBossan could you help me with an initial idea for an example that demonstrates the efficiency of this change? As you have said, you would like to see lower memory usage and faster training. I am not sure how I could demonstrate the memory-usage improvement in a Colab notebook. Any resource would be very helpful.

BenjaminBossan (Member) left a comment:

Thanks for the further work.

> As you have said, you would like to see lower memory usage and faster training. I am not sure how I could demonstrate the memory-usage improvement in a Colab notebook. Any resource would be very helpful.

I don't think we need to be overly strict here: we could just run an example or two, once before and once after the changes, and monitor memory usage and runtime. I can also do that once we're ready to test.

For unit tests, we should just ensure that they cover this new code path, measuring memory and runtime is not reliable enough for unit testing.
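
One shape such a test could take (a sketch under the assumption that `lora_dropout=0` routes through the new path in both train and eval mode; `TinyModel` and the config values are illustrative, not PEFT's actual test):

import torch
from torch import nn

from peft import LoraConfig, get_peft_model


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


def test_dora_no_dropout_same_output_in_train_and_eval():
    torch.manual_seed(0)
    config = LoraConfig(target_modules=["linear"], use_dora=True, lora_dropout=0.0, init_lora_weights=False)
    model = get_peft_model(TinyModel(), config)
    x = torch.randn(4, 16)

    model.train()
    with torch.no_grad():
        out_train = model(x)   # dropout is nn.Identity, so the optimized path is taken
    model.eval()
    with torch.no_grad():
        out_eval = model(x)    # eval mode also takes the optimized path

    assert torch.allclose(out_train, out_eval, atol=1e-6)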


BenjaminBossan (Member) left a comment:

There are still some lingering issues, as witnessed by the failing tests. Could you check my suggestion please?

(Resolved review comments on src/peft/tuners/lora/layer.py and src/peft/tuners/lora/dora.py, now outdated.)
BenjaminBossan (Member):

@ariG23498 Could you please run `make style` so that the CI can pass?

nbasyl commented on Oct 14, 2024:

Hi @ariG23498 and @BenjaminBossan, thanks for the effort, I have reviewed the code and found no issue with it. Please let me know or tag me if you need my help, thanks!

BenjaminBossan (Member):

Thanks @ariG23498 for the latest fixes and @nbasyl for the review.

I did a small test using this DoRA script by calling:

CUDA_VISIBLE_DEVICES=0 time python dora_finetuning.py --quantize --lora_dropout 0 --use_dora

(I changed grad acc steps from 16 to 2). For this to work, I had to propagate the DoRA changes from this PR to the bitsandbytes layers.

What I found is:

  • PEFT main:
    {'train_runtime': 14.0129, 'train_samples_per_second': 0.714, 'train_steps_per_second': 0.357, 'train_loss': 10.531291198730468, 'epoch': 0.0}
  • This PR:
    {'train_runtime': 11.8011, 'train_samples_per_second': 0.847, 'train_steps_per_second': 0.424, 'train_loss': 10.531893920898437, 'epoch': 0.0}

I also monitored memory and it went down from 7557MiB to 7325MiB.

So the final losses are not 100% identical, but I think it's within rounding error. Runtime was improved and memory usage slightly decreased with this PR.

Overall, I believe these are nice results and we can continue with this PR. @ariG23498 could you please propagate the changes to the quantized LoRA layer types that support it. We could probably also document this to let users know that they should consider disabling dropout for DoRA training to benefit from this optimization, with some numbers to underline this.
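
On the docs side, the user-facing takeaway is just a config choice; something like the following could illustrate it (the hyperparameters here are placeholders):

from peft import LoraConfig

config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    use_dora=True,
    lora_dropout=0.0,  # no dropout lets DoRA reuse the base layer output (faster, less memory)
)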

ariG23498 (Contributor Author):

Thanks for the detailed reply @BenjaminBossan
I am glad that the memory usage went down and the runtime also improved.

@ariG23498 could you please propagate the changes to the quantized LoRA layers types that support it.

Do you mean all the variants found here? Also I think it would be better to have the current change made to DoRA only, and then create another PR for the rest of the layers, WDYT?

charchit7:

@ariG23498 @BenjaminBossan very nice PR. I learned a lot. @ariG23498 let me know if I can be of help to propagate the changes, maybe in a separate PR.

BenjaminBossan (Member):

> Do you mean all the variants found here? Also I think it would be better to have the current change made to DoRA only, and then create another PR for the rest of the layers, WDYT?

Yes, so what I mean is that e.g. in lora/bnb.py, the DoRA call has to be adjusted in the same way as in lora/layer.py. It should be quite straightforward, because you can reuse the same code everywhere. I'd prefer this to be consistent before merging the PR.
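
Concretely, the DoRA call site in lora/bnb.py could mirror the lora/layer.py change with the same few lines (a sketch; exact variable names in the bnb code may differ):

# Sketch of the adjusted DoRA call in a quantized (bnb) LoRA layer:
if isinstance(dropout, nn.Identity) or not self.training:
    base_result = result      # quantized base output can be reused as-is
else:
    x = dropout(x)
    base_result = None        # fall back to recomputing the base result inside DoRA

output = self.lora_magnitude_vector[active_adapter](
    x,
    lora_A=lora_A,
    lora_B=lora_B,
    scaling=scaling,
    base_layer=self.get_base_layer(),
    base_result=base_result,
)
result = result + output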

Other than that, let's add a mention of this optimization in the docs. Ideally, we can add some numbers. I can re-run the experiment mentioned above and give some definitive numbers if you want.

> let me know if I can be of help to propagate the changes, maybe in a separate PR.

Thanks for the offer. As mentioned, let's try to get it into this PR. @ariG23498 up to you if/how you want to split up the work.

ariG23498 (Contributor Author):

@BenjaminBossan I have made the changes.

@charchit7 thank you for the offer, but as this is a redundant piece of code, I thought it was better to make the changes myself. Please feel free to take up other issues and comment for collaboration 🤗

BenjaminBossan (Member):

Thanks for the update. Let's also add it here:

x = dropout(x)
result = result + self.lora_magnitude_vector[active_adapter](
    x,
    lora_A=lora_A,
    lora_B=lora_B,
    scaling=scaling,
    base_layer=self.get_base_layer(),
)

The other layer types don't seem to properly implement DoRA yet, so we can keep those to a separate PR.

Would you also be so kind as to add this to the docs?

charchit7:

> @BenjaminBossan I have made the changes.
>
> @charchit7 thank you for the offer, but as this is a redundant piece of code, I thought it was better to make the changes myself. Please feel free to take up other issues and comment for collaboration 🤗

Yes, I completely understand.

Thank you, yes, will do :)

BenjaminBossan (Member):

Thanks for the updates. I'll re-run the script later, as the first test was only very short, to get some final numbers to report.

BenjaminBossan (Member):

Update: So I re-ran the script for a while longer and with a higher batch size (before it was 1), using:

CUDA_VISIBLE_DEVICES=0 time python examples/dora_finetuning/dora_finetuning.py --quantize --lora_dropout 0 --batch_size 16 --eval_step 2 --use_dora

I also set gradient_accumulation_steps=2 and max_steps=20.

What I found is that training was 20% (wall time) to 23% (transformers-reported time) faster. However, there was no longer any memory advantage; I'm not quite sure what was different the first time around.

  • before:
    {'train_runtime': 359.7298, 'train_samples_per_second': 1.779, 'train_steps_per_second': 0.056, 'total_flos': 1.303253870444544e+16, 'train_loss': 9.653419399261475, 'epoch': 0.06493506493506493, 'step': 20}
  • after:
    {'train_runtime': 279.2676, 'train_samples_per_second': 2.292, 'train_steps_per_second': 0.072, 'total_flos': 1.303253870444544e+16, 'train_loss': 9.643538236618042, 'epoch': 0.06493506493506493, 'step': 20}


Losses aligned quite nicely, with only a rounding error level of difference:


@ariG23498 Could you please update the docs accordingly (no need to mention the loss, as this is expected).

ariG23498 (Contributor Author):

@BenjaminBossan the docs have been updated! The benchmark results are crazy.

Comment on lines 141 to 142
DoRA is optimized (computes faster and takes less memory) for models in evaluation mode, or when dropout is set to 0. We reuse the base result at those times to get the speedup. Running [dora finetuning](https://github.com/huggingface/peft/blob/main/examples/dora_finetuning/dora_finetuning.py) with `CUDA_VISIBLE_DEVICES=0 time python dora_finetuning.py --quantize --lora_dropout 0 --use_dora`, these were the observations:
Member:

Could you please update the command to reflect the last one I used (including the batch size and grad acc)? Also, no need for line breaks in the docs. You could also mention that this was tested on a 4090.

Contributor Author:

Updated!

BenjaminBossan (Member) left a comment:

Thanks for implementing this DoRA speed improvement and applying it to all relevant layers (+ some other minor fixes you did in this PR). LGTM.

Let's keep in mind that we can still roll this out to more layer types (see #2153) and not all quants implement DoRA yet.

BenjaminBossan merged commit 338aeff into huggingface:main on Oct 16, 2024. 14 checks passed.
sirluk pushed a commit to sirluk/peft that referenced this pull request Oct 19, 2024
yaswanth19 pushed a commit to yaswanth19/peft that referenced this pull request Oct 20, 2024
yaswanth19 pushed a commit to yaswanth19/peft that referenced this pull request Oct 20, 2024

Successfully merging this pull request may close these issues.

Optimize DoRA computation when there is no dropout
5 participants