
Gradient checkpointing and Activated LoRA models #2826

@kgreenewald


System Info

NA

Who can help?

I can potentially implement changes in a PR, but would like guidance on how the maintainers would prefer this to be fixed. I previously worked with @githubnemo on building the Activated LoRA feature for PEFT.

Reproduction

Activated LoRA in PEFT works by registering forward hooks on each layer, which pass an `alora_offsets` argument from `PeftModelForCausalLM.forward` into the forward of the corresponding LoRA variant.
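To make the mechanism concrete, here is a minimal sketch assuming plain PyTorch forward pre-hooks registered with `with_kwargs=True`; PEFT's actual hook code may differ, and `ToyLayer` and `make_offsets_hook` are hypothetical names. The caller never passes `alora_offsets`; the hook injects it into the layer's `forward` kwargs.

```python
import torch
from torch import nn


class ToyLayer(nn.Module):
    """Hypothetical stand-in for a LoRA-wrapped layer with an optional offsets kwarg."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x, alora_offsets=None):
        # A real aLoRA layer would only apply the adapter after the offset;
        # this toy just records what arrived.
        self.last_offsets = alora_offsets
        return self.linear(x)


def make_offsets_hook(offsets):
    # With with_kwargs=True the pre-hook receives (module, args, kwargs)
    # and may return a modified (args, kwargs) pair.
    def hook(module, args, kwargs):
        kwargs = dict(kwargs)
        kwargs["alora_offsets"] = offsets
        return args, kwargs

    return hook


layer = ToyLayer()
handle = layer.register_forward_pre_hook(make_offsets_hook([3, 5]), with_kwargs=True)
layer(torch.randn(2, 4))  # no alora_offsets passed by the caller
handle.remove()
print(layer.last_offsets)  # [3, 5]
```

Once the handle is removed, calling `layer(...)` again leaves `last_offsets` as `None`, which is the situation that matters for gradient checkpointing.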

I've found that with gradient checkpointing enabled (at least in the simple DPO example below), the backward pass neither saves the `alora_offsets` input nor re-runs the forward hooks to recreate it, so the recomputed forward inside `.backward()` effectively runs with `alora_offsets = None`. To be clear, no error is raised; the computed gradients are simply zero. Setting `gradient_checkpointing=False` fixes the issue, as one would expect.
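The effect can be reproduced in isolation with a toy layer. This sketch assumes non-reentrant `torch.utils.checkpoint.checkpoint` and a hook that is only registered for the duration of the forward call (mimicking a context manager); `RecordingLayer` and `inject_offsets` are hypothetical. The layer records every `alora_offsets` value its `forward` receives.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class RecordingLayer(nn.Module):
    """Hypothetical layer that logs each alora_offsets value seen by forward."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(4))
        self.seen = []

    def forward(self, x, alora_offsets=None):
        self.seen.append(alora_offsets)
        return x * self.weight


def inject_offsets(module, args, kwargs):
    # Forward pre-hook that adds alora_offsets to the call's kwargs
    return args, {**kwargs, "alora_offsets": [3]}


layer = RecordingLayer()
x = torch.randn(2, 4)

# The hook is active only while the forward call runs.
handle = layer.register_forward_pre_hook(inject_offsets, with_kwargs=True)
out = checkpoint(layer, x, use_reentrant=False)
handle.remove()

# backward() recomputes layer.forward, but the hook no longer exists.
out.sum().backward()
print(layer.seen)  # [[3], None]
```

The second entry comes from the recomputation during `backward()`, which sees `None` instead of the offsets used in the original forward pass.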

There are probably several ways to fix this while keeping support for gradient checkpointing. The simplest fix (though a fairly unsafe one) would be to edit `ALoraLinearVariant.forward` in `tuners/lora/variants.py` by adding the following lines:

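```python
alora_offsets = kwargs.get("alora_offsets", None)
if alora_offsets is not None:
    # Regular forward pass: cache the offsets supplied by the forward hooks
    module.alora_offsets = alora_offsets
else:
    # Recomputation during backward: fall back to the cached value
    # (getattr avoids an AttributeError if nothing was cached yet)
    alora_offsets = getattr(module, "alora_offsets", None)
```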

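Here, the initial forward pass in the training loop uses the `alora_offsets` supplied by the forward hooks and caches it on the module. During the gradient-checkpointed recomputation in the backward pass, `alora_offsets` arrives as `None`, and the cached value is recovered and used instead. This does seem to work (I tried it), but it feels unsafe to me: the cached value persists on the module, so any later forward call that arrives without offsets, not just a checkpoint recomputation, would silently reuse the offsets from an earlier batch.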
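A pure-Python sketch of the proposed fallback shows why it is risky. `resolve_alora_offsets` is a hypothetical helper mirroring the patch above, and `SimpleNamespace` stands in for the module; the point is that the helper cannot tell a checkpoint recomputation apart from any other call that lacks offsets.

```python
from types import SimpleNamespace


def resolve_alora_offsets(module, kwargs):
    """Hypothetical helper mirroring the proposed module-attribute cache."""
    alora_offsets = kwargs.get("alora_offsets", None)
    if alora_offsets is not None:
        module.alora_offsets = alora_offsets  # cache for a later recomputation
    else:
        alora_offsets = getattr(module, "alora_offsets", None)
    return alora_offsets


module = SimpleNamespace()
# Regular training forward: the hook supplies offsets, which get cached.
print(resolve_alora_offsets(module, {"alora_offsets": [7]}))  # [7]
# Checkpoint recomputation: no offsets passed, cached value recovered (intended).
print(resolve_alora_offsets(module, {}))  # [7]
# Any unrelated later call without offsets also gets [7] (unintended).
print(resolve_alora_offsets(module, {}))  # [7]
```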

A safer approach might key the cached offsets on the layer input and the active adapter, e.g. by hashing `x` together with the adapter name and storing `alora_offsets` in an offsets table under that key. The saved `alora_offsets` would then only be used if the hash of the current input matches an existing key, so the fallback should only trigger in a `.backward()` recomputation where the true `alora_offsets` were dropped.
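A rough sketch of such an offsets table follows, assuming the recomputed layer input is bitwise identical to the original one (checkpointing restores RNG state by default, but nondeterministic kernels could break this). `OffsetsTable` and `resolve_offsets` are hypothetical names, and hashing a CPU copy of `x` in every layer on every step would be expensive; the table would also need clearing after each optimizer step to avoid unbounded growth.

```python
import hashlib

import torch


class OffsetsTable:
    """Hypothetical cache mapping (adapter name, input fingerprint) -> alora_offsets."""

    def __init__(self):
        self._table = {}

    @staticmethod
    def _key(x, adapter_name):
        # Fingerprint the input by its shape and a hash of its raw values.
        data = x.detach().to(torch.float32).cpu().contiguous().numpy().tobytes()
        return (adapter_name, tuple(x.shape), hashlib.sha256(data).hexdigest())

    def store(self, x, adapter_name, offsets):
        self._table[self._key(x, adapter_name)] = offsets

    def lookup(self, x, adapter_name):
        return self._table.get(self._key(x, adapter_name))

    def clear(self):
        self._table.clear()


def resolve_offsets(table, x, adapter_name, kwargs):
    offsets = kwargs.get("alora_offsets")
    if offsets is not None:
        table.store(x, adapter_name, offsets)
        return offsets
    # Only reuse cached offsets if this exact input was seen before.
    return table.lookup(x, adapter_name)


table = OffsetsTable()
x = torch.arange(8.0).reshape(2, 4)
first = resolve_offsets(table, x, "default", {"alora_offsets": [3, 5]})
recomputed = resolve_offsets(table, x.clone(), "default", {})
unrelated = resolve_offsets(table, x + 1, "default", {})
print(first, recomputed, unrelated)  # [3, 5] [3, 5] None
```

Unlike the module-attribute cache, an unrelated input falls through to `None` instead of silently reusing old offsets.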

What do you think (e.g., @githubnemo)? Happy to clarify anything and/or discuss further.

Example training script in which all gradients are 0 (the problem goes away when gradient checkpointing is turned off):

```python
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct", device_map="cuda", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct", padding_side="left", trust_remote_code=True
)
# The pad token id belongs on the model config, not on the model object itself
model.config.pad_token_id = tokenizer.pad_token_id

train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
test_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="test")

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    # This token is from the chat template, i.e. the generation prompt
    alora_invocation_tokens=[151644],
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

training_args = DPOConfig(
    output_dir="Qwen2-0.5B-DPO",
    report_to="wandb",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=5e-6,
    gradient_checkpointing=True,  # with False, the gradients are non-zero
    gradient_accumulation_steps=2,
    eval_strategy="steps",
    eval_steps=500,
    logging_steps=50,
    max_length=2048,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    peft_config=peft_config,
)
trainer.train()
# DPOTrainer wraps the model in a PeftModel; save that to store the adapter
trainer.model.save_pretrained("./models/dpo/alora")
```
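To confirm the "all gradients are 0" symptom directly, one option is to inspect the LoRA parameter gradients after `backward()` and before the optimizer step clears them. The sketch below assumes PEFT's naming convention in which LoRA parameter names contain `lora_` (e.g. `lora_A`, `lora_B`); `lora_grad_norms` and `all_zero` are hypothetical helpers, and the toy `ModuleDict` only stands in for the PEFT model.

```python
import torch
from torch import nn


def lora_grad_norms(model):
    """Return {param_name: grad L2 norm} for trainable params whose name contains 'lora_'."""
    norms = {}
    for name, param in model.named_parameters():
        if "lora_" in name and param.requires_grad:
            # A missing gradient is reported as 0.0
            norms[name] = 0.0 if param.grad is None else param.grad.norm().item()
    return norms


def all_zero(norms):
    return all(value == 0.0 for value in norms.values())


# Toy stand-in: one "LoRA" submodule and one base submodule
toy = nn.ModuleDict({"lora_A": nn.Linear(4, 2), "base": nn.Linear(4, 2)})
loss = toy["lora_A"](torch.ones(1, 4)).sum()
loss.backward()

norms = lora_grad_norms(toy)
print(sorted(norms))  # ['lora_A.bias', 'lora_A.weight']
print(all_zero(norms))  # False
```

With the aLoRA model and `gradient_checkpointing=True`, the report above suggests `all_zero` would return `True` at that point in the step.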

Expected behavior

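With gradient checkpointing enabled, the gradients should be non-zero, just as they are with `gradient_checkpointing=False` (see above).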
