
Commit 0d997a9

Merge pull request #1 from parmanu-lcs2/parmanu-lcs2/main
Implemented MonteCLoRA into LoRA variants
2 parents 1371d40 + 57d764c

File tree

17 files changed: +1281 −158 lines
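In brief, this merge wires MonteCLoRA into PEFT's LoRA stack as a LoRA variant rather than a standalone tuner: `LoraConfig` gains `use_monteclora` and `monteclora_config` fields, the variant is resolved inside the LoRA layers, and a `MonteCLoRATrainerMixin` is exported from `peft` to add the variational loss during training. The sketch below shows how the pieces from the diffs that follow are meant to fit together, assuming this branch is installed; `base_model`, `training_args`, and `train_dataset` are placeholders.

```python
from transformers import Trainer

from peft import LoraConfig, MonteCLoraConfig, MonteCLoRATrainerMixin, get_peft_model


# Trainer subclass that adds the MonteCLoRA variational loss on top of the task loss
class MonteCLoRATrainer(MonteCLoRATrainerMixin, Trainer):
    pass


lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    use_monteclora=True,
    monteclora_config=MonteCLoraConfig(monteclora_n=8, sample_scaler=1e-4, kl_loss_weight=1e-5),
)

model = get_peft_model(base_model, lora_config)  # base_model: any supported transformers model
trainer = MonteCLoRATrainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()
```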

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@ __pycache__/
 *.so
 
 # Distribution / packaging
+test_imp/
 .Python
 build/
 develop-eggs/

src/peft/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -25,6 +25,7 @@
     AutoPeftModelForTokenClassification,
 )
 from .config import PeftConfig, PromptLearningConfig
+from .helpers import MonteCLoRATrainerMixin
 from .mapping import (
     PEFT_TYPE_TO_CONFIG_MAPPING,
     PEFT_TYPE_TO_MIXED_MODEL_MAPPING,
@@ -87,7 +88,6 @@
     MissConfig,
     MissModel,
     MonteCLoraConfig,
-    MonteCLoraModel,
     MultitaskPromptTuningConfig,
     MultitaskPromptTuningInit,
     OFTConfig,
@@ -202,8 +202,8 @@
     "LoraRuntimeConfig",
     "MissConfig",
     "MissModel",
+    "MonteCLoRATrainerMixin",
     "MonteCLoraConfig",
-    "MonteCLoraModel",
     "MultitaskPromptTuningConfig",
     "MultitaskPromptTuningInit",
     "OFTConfig",

src/peft/helpers.py

Lines changed: 87 additions & 0 deletions

@@ -251,6 +251,93 @@ def disable_input_dtype_casting(model: nn.Module, active: bool = True):
         module.cast_input_dtype_enabled = original_values[name]
 
 
+class MonteCLoRATrainerMixin:
+    """
+    Mixin class for adding the MonteCLoRA variational loss to the Trainer's compute_loss method.
+
+    This mixin can be used with any Trainer class (e.g., Trainer, SFTTrainer) to add support for
+    MonteCLoRA's variational regularization during training.
+
+    Example:
+
+    ```python
+    from transformers import Trainer
+
+    from peft import get_peft_model, LoraConfig
+    from peft.helpers import MonteCLoRATrainerMixin
+    from peft.tuners.monteclora import MonteCLoraConfig
+
+    # Custom trainer that supports MonteCLoRA
+    class MonteCLoRATrainer(MonteCLoRATrainerMixin, Trainer):
+        pass
+
+    # Configure LoRA with MonteCLoRA
+    monteclora_config = MonteCLoraConfig(
+        monteclora_n=8,
+        sample_scaler=1e-4,
+        kl_loss_weight=1e-5,
+    )
+    lora_config = LoraConfig(
+        r=16,
+        lora_alpha=32,
+        target_modules=["q_proj", "v_proj"],
+        use_monteclora=True,
+        monteclora_config=monteclora_config,
+    )
+
+    # Get the PEFT model and train
+    model = get_peft_model(base_model, lora_config)
+    trainer = MonteCLoRATrainer(model=model, args=training_args, ...)
+    trainer.train()
+    ```
+    """
+
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        """
+        Compute the loss with MonteCLoRA variational regularization.
+
+        This method extends the standard compute_loss by adding the variational loss
+        (KL divergence + entropy) from the MonteCLoRA samplers to the task loss.
+
+        Args:
+            model: The model being trained.
+            inputs: The input batch.
+            return_outputs: Whether to return the model outputs along with the loss.
+            **kwargs: Additional arguments forwarded to the parent compute_loss.
+
+        Returns:
+            The loss, or (loss, outputs) if return_outputs is True.
+        """
+        # 1. Compute the standard task loss
+        if return_outputs:
+            task_loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)
+        else:
+            task_loss = super().compute_loss(model, inputs, return_outputs=False, **kwargs)
+            outputs = None
+
+        # 2. Calculate the variational loss (KL divergence + entropy) from the MonteCLoRA samplers
+        var_loss_sum = 0.0
+        num_monte_layers = 0
+
+        # Iterate through modules to find MonteCLoRA samplers
+        for name, module in model.named_modules():
+            # Identify a MonteCLoRASampler by its class name and its get_variational_loss method
+            if hasattr(module, "get_variational_loss") and module.__class__.__name__ == "MonteCLoRASampler":
+                try:
+                    kl_loss, entropy_loss = module.get_variational_loss()
+                    var_loss_sum += kl_loss + entropy_loss
+                    num_monte_layers += 1
+                except Exception:
+                    # Silently ignore if get_variational_loss fails
+                    pass
+
+        # 3. Normalize the variational loss over the MonteCLoRA layers
+        regularization_loss = 0.0
+        if num_monte_layers > 0:
+            regularization_loss = var_loss_sum / num_monte_layers
+
+        # 4. Combine losses
+        total_loss = task_loss + regularization_loss
+
+        return (total_loss, outputs) if return_outputs else total_loss
 class DoraCaching:
     """Context manager to enable DoRA caching, which improves speed of DoRA inference at the expense of memory.
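The mixin above duck-types against modules whose class is named `MonteCLoRASampler` and which expose `get_variational_loss()` returning a `(kl_loss, entropy_loss)` pair. The real sampler lives in the MonteCLoRA variant code, which is not part of this excerpt; the toy module below is only a sketch of that contract, showing how the per-layer terms get averaged into a single regularizer.

```python
import math

import torch
import torch.nn as nn


class MonteCLoRASampler(nn.Module):
    """Toy stand-in exposing the interface the trainer mixin scans for; not the real sampler."""

    def __init__(self):
        super().__init__()
        # Illustrative variational parameters; the real sampler's internals differ.
        self.mu = nn.Parameter(torch.zeros(4))
        self.log_var = nn.Parameter(torch.zeros(4))

    def get_variational_loss(self):
        # Illustrative terms: KL(N(mu, sigma^2) || N(0, 1)) and a Gaussian entropy term.
        # Sign and scaling conventions belong to the actual MonteCLoRA implementation.
        kl = 0.5 * (self.log_var.exp() + self.mu**2 - 1.0 - self.log_var).sum()
        entropy = -0.5 * (self.log_var + 1.0 + math.log(2 * math.pi)).sum()
        return kl, entropy


# The same aggregation the mixin performs over model.named_modules():
model = nn.ModuleDict({"layer0": MonteCLoRASampler(), "layer1": MonteCLoRASampler()})
losses = [m.get_variational_loss() for m in model.modules() if isinstance(m, MonteCLoRASampler)]
regularizer = sum(kl + ent for kl, ent in losses) / len(losses)
print(regularizer)
```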

src/peft/tuners/__init__.py

Lines changed: 1 addition & 2 deletions

@@ -45,7 +45,7 @@
 )
 from .miss import MissConfig, MissModel
 from .mixed import MixedModel
-from .monteclora import MonteCLoraConfig, MonteCLoraModel
+from .monteclora import MonteCLoraConfig
 from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit
 from .oft import OFTConfig, OFTModel
 from .osf import OSFConfig, OSFModel
@@ -106,7 +106,6 @@
     "MissModel",
     "MixedModel",
     "MonteCLoraConfig",
-    "MonteCLoraModel",
     "MultitaskPromptEmbedding",
     "MultitaskPromptTuningConfig",
     "MultitaskPromptTuningInit",

src/peft/tuners/lora/config.py

Lines changed: 34 additions & 0 deletions

@@ -693,6 +693,26 @@ class LoraConfig(PeftConfig):
             )
         },
     )
+    use_monteclora: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Enable MonteCLoRA (Monte Carlo Low-Rank Adaptation). This technique introduces variational "
+                "inference into LoRA by adding Monte Carlo sampling to the adapter weights during training. "
+                "This can improve model performance and uncertainty estimation. When enabled, you should also "
+                "provide `monteclora_config` with the MonteCLoRA hyperparameters."
+            )
+        },
+    )
+    monteclora_config: Optional[MonteCLoraConfig] = field(  # noqa: F821
+        default=None,
+        metadata={
+            "help": (
+                "The configuration of MonteCLoRA. If this is passed along with `use_monteclora=True`, then "
+                "MonteCLoRA will be used to add variational sampling to the LoRA adapters."
+            )
+        },
+    )
     # Enables replicating layers in a model to expand it to a larger model.
     layer_replication: Optional[list[tuple[int, int]]] = field(
         default=None,
@@ -832,6 +852,20 @@ def __post_init__(self):
         elif self.init_lora_weights != "corda" and self.corda_config is not None:
             warnings.warn("`corda_config` specified but will be ignored when `init_lora_weights` is not 'corda'.")
 
+        # Handle MonteCLoRA configuration
+        if self.use_monteclora:
+            from peft.tuners.monteclora.config import MonteCLoraConfig
+
+            if self.monteclora_config is None:
+                warnings.warn(
+                    "`use_monteclora=True` but `monteclora_config` is not specified. Using default MonteCLoRA config."
+                )
+                self.monteclora_config = MonteCLoraConfig()
+            elif isinstance(self.monteclora_config, dict):
+                self.monteclora_config = MonteCLoraConfig(**self.monteclora_config)
+        elif self.monteclora_config is not None:
+            warnings.warn("`monteclora_config` specified but will be ignored when `use_monteclora=False`.")
+
         if self.lora_bias:
             if self.init_lora_weights not in (True, False):
                 raise ValueError(
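The `__post_init__` handling above accepts `monteclora_config` either as a `MonteCLoraConfig` instance or as a plain dict, and falls back to a default config (with a warning) when only `use_monteclora=True` is set. A small usage sketch, assuming this branch of PEFT is installed and that `MonteCLoraConfig` exposes the `monteclora_n`, `sample_scaler`, and `kl_loss_weight` fields shown in the helper docstring above:

```python
from peft import LoraConfig
from peft.tuners.monteclora import MonteCLoraConfig

# Passing the MonteCLoRA hyperparameters as a plain dict; __post_init__ coerces it
# into a MonteCLoraConfig instance.
config = LoraConfig(
    r=8,
    target_modules=["q_proj", "v_proj"],
    use_monteclora=True,
    monteclora_config={"monteclora_n": 8, "sample_scaler": 1e-4, "kl_loss_weight": 1e-5},
)
assert isinstance(config.monteclora_config, MonteCLoraConfig)

# With use_monteclora=True and no config, a default MonteCLoraConfig is created
# (a warning is emitted).
default_config = LoraConfig(r=8, target_modules=["q_proj"], use_monteclora=True)
assert isinstance(default_config.monteclora_config, MonteCLoraConfig)
```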

src/peft/tuners/lora/layer.py

Lines changed: 49 additions & 10 deletions

@@ -147,6 +147,21 @@ def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[Lora
 
     def update_layer(
         self,
+        adapter_name,
+        r,
+        lora_alpha,
+        lora_dropout,
+        init_lora_weights,
+        use_rslora,
+        use_dora: bool = False,
+        use_alora: bool = False,
+        use_qalora: bool = False,
+        use_monteclora: bool = False,
+        lora_bias: bool = False,
+        arrow_config: ArrowConfig = None,
+        monteclora_config=None,
+        qalora_group_size: int = 32,
+        inference_mode: bool = False,
         adapter_name: str,
         r: int,
         lora_alpha: int,
@@ -174,7 +189,15 @@ def update_layer(
             PeftWarning,
         )
 
-        lora_variant = self.resolve_lora_variant(config=config)
+        lora_variant = self.resolve_lora_variant(
+            use_dora=use_dora,
+            use_alora=use_alora,
+            use_qalora=use_qalora,
+            use_monteclora=use_monteclora,
+            qalora_group_size=qalora_group_size,
+            arrow_config=arrow_config,
+            monteclora_config=monteclora_config,
+        )
         if lora_variant is not None:
             self.lora_variant[adapter_name] = lora_variant
 
@@ -732,6 +755,14 @@ def __init__(
         r: int = 0,
         lora_alpha: int = 1,
         is_target_conv_1d_layer: bool = False,
+        init_lora_weights: Union[bool, str] = True,
+        use_rslora: bool = False,
+        use_dora: bool = False,
+        use_alora: bool = False,
+        use_monteclora: bool = False,
+        arrow_config: ArrowConfig = None,
+        monteclora_config=None,
+        lora_bias: bool = False,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -743,24 +774,32 @@
             adapter_name,
             r,
             lora_alpha=lora_alpha,
-            config=config,
-            **kwargs,
+            lora_dropout=lora_dropout,
+            init_lora_weights=init_lora_weights,
+            use_rslora=use_rslora,
+            use_dora=use_dora,
+            use_alora=use_alora,
+            use_monteclora=use_monteclora,
+            lora_bias=lora_bias,
+            arrow_config=arrow_config,
+            monteclora_config=monteclora_config,
         )
         self.is_target_conv_1d_layer = is_target_conv_1d_layer
 
-    def resolve_lora_variant(self, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
-        if config.arrow_config is not None:
+    def resolve_lora_variant(
+        self, *, arrow_config: ArrowConfig, use_dora: bool, use_alora: bool, use_monteclora: bool = False, **kwargs
+    ) -> Optional[LoraVariant]:
+        if arrow_config is not None:
             from .variants import ArrowLinearVariant
 
             return ArrowLinearVariant()
 
-        if config.use_bdlora is not None:
-            from .variants import BdLoraLinearVariant
+        if use_monteclora:
+            from peft.tuners.monteclora.variant import MonteCLoraLinearVariant
 
-            return BdLoraLinearVariant()
+            return MonteCLoraLinearVariant()
 
-        use_alora = config.alora_invocation_tokens is not None
-        if not config.use_dora and not use_alora:
+        if not use_dora and not use_alora:
             return None
 
         from .variants import ALoraLinearVariant, DoraLinearVariant
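After this change, `resolve_lora_variant` selects the LoRA variant from explicit keyword flags rather than the whole config object, with Arrow taking precedence over MonteCLoRA, which in turn takes precedence over DoRA/aLoRA. A minimal standalone sketch of that dispatch order, using placeholder variant classes rather than the real PEFT ones (the final DoRA-vs-aLoRA choice is an assumption, since that part of the method is not shown in the diff):

```python
from typing import Optional


# Placeholders standing in for ArrowLinearVariant, MonteCLoraLinearVariant,
# DoraLinearVariant, and ALoraLinearVariant from the diff above.
class ArrowVariant: ...
class MonteCLoRAVariant: ...
class DoraVariant: ...
class ALoraVariant: ...


def resolve_variant(
    *, arrow_config=None, use_monteclora=False, use_dora=False, use_alora=False, **kwargs
) -> Optional[object]:
    """Mirror the keyword-driven dispatch order used by resolve_lora_variant."""
    if arrow_config is not None:  # Arrow wins over everything else
        return ArrowVariant()
    if use_monteclora:  # MonteCLoRA replaces the removed BD-LoRA branch
        return MonteCLoRAVariant()
    if not use_dora and not use_alora:  # plain LoRA: no variant needed
        return None
    # Assumption: the real code then picks between the DoRA and aLoRA variants.
    return DoraVariant() if use_dora else ALoraVariant()


assert isinstance(resolve_variant(use_monteclora=True, use_dora=True), MonteCLoRAVariant)
assert resolve_variant() is None  # vanilla LoRA
```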

src/peft/tuners/lora/model.py

Lines changed: 22 additions & 3 deletions

@@ -206,7 +206,19 @@ def _create_and_replace(
         kwargs = {
             "r": r,
             "lora_alpha": alpha,
-            "target_name": current_key,
+            "lora_dropout": lora_config.lora_dropout,
+            "fan_in_fan_out": lora_config.fan_in_fan_out,
+            "init_lora_weights": lora_config.init_lora_weights,
+            "use_rslora": lora_config.use_rslora,
+            "use_dora": lora_config.use_dora,
+            "use_alora": lora_config.alora_invocation_tokens is not None,
+            "use_qalora": lora_config.use_qalora,
+            "use_monteclora": lora_config.use_monteclora,
+            "qalora_group_size": lora_config.qalora_group_size,
+            "monteclora_config": lora_config.monteclora_config,
+            "ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload,
+            "lora_bias": lora_config.lora_bias,
+            "arrow_config": lora_config.arrow_config,
             "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
             "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
             "parameter_name": parameter_name,
@@ -236,8 +248,15 @@
                 adapter_name,
                 r,
                 lora_alpha=alpha,
-                target_name=current_key,
-                config=lora_config,
+                lora_dropout=lora_config.lora_dropout,
+                init_lora_weights=lora_config.init_lora_weights,
+                use_rslora=lora_config.use_rslora,
+                use_dora=lora_config.use_dora,
+                use_monteclora=lora_config.use_monteclora,
+                lora_bias=lora_config.lora_bias,
+                arrow_config=lora_config.arrow_config,
+                monteclora_config=lora_config.monteclora_config,
+                inference_mode=lora_config.inference_mode,
             )
         else:
             if isinstance(target, ParamWrapper) and (parameter_name == target.parameter_name):
src/peft/tuners/monteclora/__init__.py

Lines changed: 3 additions & 30 deletions

@@ -1,4 +1,4 @@
-# Copyright 2023-present the HuggingFace Inc. team.
+# Copyright 2026-present the HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,35 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available
-from peft.utils import register_peft_method
-
 from .config import MonteCLoraConfig
-from .layer import MonteCLoraLinear
-from .model import MonteCLoraModel
-
-
-__all__ = ["MonteCLoraConfig", "MonteCLoraLinear", "MonteCLoraModel"]
-
-register_peft_method(
-    name="monteclora", prefix="lora_", config_cls=MonteCLoraConfig, model_cls=MonteCLoraModel, is_mixed_compatible=True
-)
-
-
-def __getattr__(name):
-    if (name == "Linear8bitLt") and is_bnb_available():
-        from peft.tuners.lora.bnb import Linear8bitLt
-
-        return Linear8bitLt
-
-    if (name == "Linear4bit") and is_bnb_4bit_available():
-        from peft.tuners.lora.bnb import Linear4bit
-
-        return Linear4bit
-
-    if (name == "EetqLoraLinear") and is_eetq_available():
-        from peft.tuners.lora.eetq import EetqLoraLinear
+from .variant import MonteCLoraLinearVariant
 
-        return EetqLoraLinear
 
-    raise AttributeError(f"module {__name__} has no attribute {name}")
+__all__ = ["MonteCLoraConfig", "MonteCLoraLinearVariant"]
