Description
System Info
peft = 0.13.2
python = 3.12.7
transformers = 4.45.2
Who can help?
I am using inject_adapter_in_model(...) to finetune a model from OpenCLIP with LoRA layers. I am able to finetune the model by adapting Linear() layers and other supported types as expected. However, the model I am currently training has an attention output projection called "out_proj" whose layer type is NonDynamicallyQuantizableLinear(Linear). I may be mistaken, but from my reading of the source code for NonDynamicallyQuantizableLinear (https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/linear.py#L136), I should be able to treat it as a typical torch.nn.Linear layer for my purposes. However, I always get the following error: "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn". The LoRA layers themselves are added as expected:
(transformer): Transformer(
  (resblocks): ModuleList(
    (0-11): 12 x ResidualAttentionBlock(
      (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): lora.Linear(
          (base_layer): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          (lora_dropout): ModuleDict(
            (default): Dropout(p=0.1, inplace=False)
          )
          (lora_A): ModuleDict(
            (default): Linear(in_features=512, out_features=32, bias=False)
          )
          (lora_B): ModuleDict(
            (default): Linear(in_features=32, out_features=512, bias=False)
          )
          (lora_embedding_A): ParameterDict()
          (lora_embedding_B): ParameterDict()
          (lora_magnitude_vector): ModuleDict()
        )
      )
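For reference, here is a minimal standalone sketch of the setup I have in mind, with a plain nn.MultiheadAttention standing in for the OpenCLIP attention block (I have not distilled my actual run down to this, so whether this exact snippet reproduces the error is an assumption on my part):

import torch
import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model

class TinyBlock(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        # same module types as the OpenCLIP ResidualAttentionBlock: an attention
        # layer with an out_proj of type NonDynamicallyQuantizableLinear, plus a
        # plain nn.Linear for comparison
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.fc1 = nn.Linear(dim, dim)

    def forward(self, x):
        out, _ = self.attn(x, x, x, need_weights=False)
        return self.fc1(out)

model = TinyBlock()
config = LoraConfig(target_modules=["out_proj"], r=32, lora_alpha=32,
                    lora_dropout=0.1, bias="none")
model = inject_adapter_in_model(config, model)

x = torch.randn(2, 10, 64)
loss = model(x).sum()
loss.backward()  # in my actual setup, backward() is where the quoted RuntimeError appears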
The layers are successfully added when I target them via target_modules, and also when I use register_custom_modules with the mapping torch.nn.modules.linear.NonDynamicallyQuantizableLinear -> peft.tuners.lora.layer.Linear. However, neither case trains. Furthermore, the model does train when I include any other layer, e.g. a fully-connected one of type torch.nn.Linear.
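The registration I tried looked roughly like this (I am going from memory here, so the exact name of the registration hook may be slightly off):

from peft import LoraConfig
from peft.tuners.lora.layer import Linear as LoraLinear
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear

lora_config = LoraConfig(target_modules=["out_proj"], r=32, lora_alpha=32,
                         lora_dropout=0.1, bias="none")
# map the attention projection type onto the standard LoRA Linear layer
lora_config._register_custom_module({NonDynamicallyQuantizableLinear: LoraLinear})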
target_modules =
- [out_proj] doesn't train
- [fc1] trains
- [out_proj, fc1] trains (I have to conclude that the attention layers aren't really being trained in this case, and that this config is effectively equivalent to the one immediately above; a quick way to double-check this is sketched after this list)
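The check I have in mind is to inspect the LoRA gradients right after a single call to train_step (lora_model as injected in the Reproduction section below; I have not run exactly this snippet):

for name, param in lora_model.named_parameters():
    if param.requires_grad:
        grad_norm = None if param.grad is None else param.grad.norm().item()
        print(name, grad_norm)
# if the attention layers are indeed bypassed, the out_proj LoRA parameters
# should show grad=None while the fc1 LoRA parameters receive real gradients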
Any idea why this may be the case? Your help would be truly appreciated.
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder
- My own task or dataset (give details below)
Reproduction
Train step:
def train_step(model, batch, loss_fn, device, trainer_cfg):
    images, tokenized_texts = batch
    images, tokenized_texts = images.to(device), tokenized_texts.to(device)

    # Forward pass: get embeddings for images and texts  # (bs, 512), (bs, 512), scalar
    image_features, text_features, scale_exp = model(images, tokenized_texts)

    # Compute logits as dot products between image and text features  # (bs, bs)
    logits_per_image = (image_features @ text_features.T) / scale_exp
    logits_per_text = logits_per_image.T

    # Create labels (diagonal is the correct match)
    labels = torch.arange(images.shape[0], device=device)

    # Compute loss (bs,)
    loss = (loss_fn(logits_per_image, labels) +
            loss_fn(logits_per_text, labels)) / 2

    # If gradient accumulation is used, normalize the loss
    accumulation_steps = trainer_cfg.get('accumulation_steps', 1)
    if accumulation_steps > 1:
        loss = loss / accumulation_steps

    # Backward pass
    loss.backward()

    # Apply gradient clipping if necessary
    max_grad_norm = trainer_cfg.get('max_grad_norm', None)
    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    return loss.item()
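The surrounding loop is the usual pattern and not shown above; roughly the following (train_loader and the AdamW settings are placeholders, and gradient accumulation is omitted here for brevity):

import torch

# optimizer only over the trainable (LoRA) parameters
optimizer = torch.optim.AdamW(
    (p for p in lora_model.parameters() if p.requires_grad), lr=1e-4)

lora_model.train()
for batch in train_loader:
    optimizer.zero_grad()
    loss = train_step(lora_model, batch, loss_fn, device, trainer_cfg)
    optimizer.step()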
Model structure near a layer of interest:
('transformer.resblocks.11.attn', <class 'torch.nn.modules.activation.MultiheadAttention'>)
('transformer.resblocks.11.attn.out_proj', <class 'torch.nn.modules.linear.NonDynamicallyQuantizableLinear'>)
('transformer.resblocks.11.ls_1', <class 'torch.nn.modules.linear.Identity'>)
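(The listing above comes from iterating over the model's modules, roughly:)

for name, module in model.named_modules():
    print((name, type(module)))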
Injection code:
model_path = cfg.training.model_name
# Load pretrained model, tokenizer, and image processor
model, preprocess_train, preprocess_val = create_model_and_transforms(model_path)
tokenizer = get_tokenizer(model_path)
print("Before adapting..")
total_params, trainable_params, trainable_percent = count_parameters(model)
lora_config = LoraConfig(**cfg.lora_config.kwargs)
# kwargs:
# target_modules: ["out_proj"]
# r: 32 # Rank of the LoRA low-rank matrices
# lora_alpha: 32 # Scaling factor for LoRA updates
# lora_dropout: 0.1 # Dropout for LoRA layers to avoid overfitting
# bias: 'none' # Whether to use bias in LoRA layers ['none', 'all', 'lora_only']
#
lora_model = inject_adapter_in_model(lora_config, model)
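count_parameters is a small helper along these lines (not the exact implementation):

def count_parameters(model):
    # totals over all parameters vs. those with requires_grad=True
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable, 100.0 * trainable / total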
Expected behavior
I would expect it to begin training. Instead, the loss stays flat from the very first steps; here are the first few printouts of a typical run:
[2024-11-05 20:01:55,364][utils.loggers][INFO] - Total Parameters: 154,406,529
[2024-11-05 20:01:55,366][utils.loggers][INFO] - Trainable Parameters: 2,887,680 (1.87%)
[2024-11-05 20:01:55,371][utils.loggers][INFO] - Patience scaled to 10 validation steps
Epoch 1/105: 0it [00:00, ?it/s]c:\Users\spet4299\Anaconda3\envs\tee_clip_env\Lib\site-packages\torch\nn\functional.py:5476: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at ..\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:263.)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
[2024-11-05 20:01:57,719][utils.loggers][INFO] - Epoch 1 | step 1 | Train Loss = 2.7726
Epoch 1/105: 1it [00:02, 2.34s/it][2024-11-05 20:01:58,755][utils.loggers][INFO] - Epoch 1 | step 2 | Train Loss = 2.7726
Epoch 1/105: 2it [00:03, 1.57s/it][2024-11-05 20:01:59,659][utils.loggers][INFO] - Epoch 1 | step 3 | Train Loss = 2.7726