
Commit f545712

ENH: Add tests, docs, types for scaling methods (#2526)
For the LoRA methods

- set_scale
- scale_layer
- unscale_layer

unit tests, docstrings, and type annotations were added.

Co-authored-by: Sayak Paul <[email protected]>
1 parent 6c8c3c3 commit f545712
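
For orientation, a minimal usage sketch of the three methods named in the commit message, assuming a tiny causal LM and illustrative factor values; the per-module loop mirrors the helper functions in the test file below:

from peft import LoraConfig, get_peft_model
from peft.tuners.lora import LoraLayer
from transformers import AutoModelForCausalLM

# tiny causal LM used purely for illustration
base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
model = get_peft_model(base, LoraConfig(r=8, lora_alpha=16, target_modules=["k_proj"]))

for module in model.modules():
    if isinstance(module, LoraLayer):
        module.scale_layer(2.0)           # multiply the current scale of all active adapters by 2
        module.unscale_layer(2.0)         # divide it by 2 again
        module.set_scale("default", 3.0)  # 3x the initial scale (lora_alpha / r), regardless of the current value
        module.unscale_layer(None)        # reset all active adapters back to the initial scale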

File tree

2 files changed: +279 −3 lines changed

src/peft/tuners/lora/layer.py

+14 −3
@@ -450,13 +450,18 @@ def _cache_pop(self, key: str) -> Any:
         value = self._caches.pop(key)
         return value
 
-    def set_scale(self, adapter, scale):
+    def set_scale(self, adapter: str, scale: float | int) -> None:
+        """Set the scale of the given adapter to the initial scale multiplied by the provided factor
+
+        The initial scale is determined by the configured `r` (rank) and `lora_alpha`.
+        """
         if adapter not in self.scaling:
             # Ignore the case where the adapter is not in the layer
             return
         self.scaling[adapter] = scale * self.lora_alpha[adapter] / self.r[adapter]
 
-    def scale_layer(self, scale: float) -> None:
+    def scale_layer(self, scale: float | int) -> None:
+        """Multiply the current scale of all active adapters by the provided factor"""
         if scale == 1:
             return
 
@@ -466,7 +471,13 @@ def scale_layer(self, scale: float) -> None:
 
             self.scaling[active_adapter] *= scale
 
-    def unscale_layer(self, scale=None) -> None:
+    def unscale_layer(self, scale: Optional[float | int] = None) -> None:
+        """Divide the current scale of all active adapters by the provided factor. If `scale=None` is passed, reset to
+        initial scale
+
+        The initial scale is determined by the configured `r` (rank) and `lora_alpha`.
+
+        """
         for active_adapter in self.active_adapters:
             if active_adapter not in self.lora_A.keys():
                 continue
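
As a worked sketch of the arithmetic behind these docstrings, using the r=8 and lora_alpha=16 values that also appear in the tests below:

r, lora_alpha = 8, 16
scaling = lora_alpha / r      # initial scale: 2.0
scaling *= 2                  # scale_layer(2)         -> 4.0
scaling /= 2                  # unscale_layer(2)       -> 2.0
scaling = 3 * lora_alpha / r  # set_scale(adapter, 3)  -> 6.0, relative to the initial scale, not the current one
scaling = lora_alpha / r      # unscale_layer(None)    -> back to 2.0 for every active adapter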

tests/test_initialization.py

+265
@@ -3492,3 +3492,268 @@ def test_import_peft_type_to_model_mapping_deprecation_warning(recwarn):
     # check that there is a warning with this message after importing the variable
     warnings = (w.message.args[0] for w in recwarn.list)
     assert any(w.startswith(expected) for w in warnings)
+
+
+class TestScaling:
+    """Tests for scaling and unscaling
+
+    Those methods are currently only implemented for LoRA and were added for use in diffusers.
+    """
+
+    @pytest.fixture
+    def model(self):
+        # tiny opt with 5 attention layers
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        return AutoModelForCausalLM.from_pretrained(model_id)
+
+    def get_scalings(self, model, adapter_name="default"):
+        # helper function, returns the scalings of the 5 attention layers
+        return [m.scaling[adapter_name] for m in model.modules() if isinstance(m, LoraLayer)]
+
+    def set_scale(self, model, adapter_name, scale):
+        for module in model.modules():
+            if isinstance(module, LoraLayer):
+                module.set_scale(adapter_name, scale)
+
+    def scale_layer(self, model, scale):
+        for module in model.modules():
+            if isinstance(module, LoraLayer):
+                module.scale_layer(scale)
+
+    def unscale_layer(self, model, scale):
+        for module in model.modules():
+            if isinstance(module, LoraLayer):
+                module.unscale_layer(scale)
+
+    def test_scaling_simple(self, model):
+        n_layers = 5
+        rank, lora_alpha = 8, 16
+        config = LoraConfig(
+            r=rank,
+            lora_alpha=lora_alpha,
+            target_modules=["k_proj"],
+        )
+        model = get_peft_model(model, config)
+        scalings = self.get_scalings(model)
+        expected = [lora_alpha / rank] * n_layers
+        assert scalings == expected
+
+        # double
+        self.scale_layer(model, 2)
+        scalings = self.get_scalings(model)
+        expected = [4.0] * n_layers
+        assert scalings == expected
+
+        # back to original
+        self.unscale_layer(model, None)
+        scalings = self.get_scalings(model)
+        expected = [2.0] * n_layers
+        assert scalings == expected
+
+        # triple
+        self.set_scale(model, "default", 3)
+        scalings = self.get_scalings(model)
+        expected = [6.0] * n_layers
+        assert scalings == expected
+
+        # back to original
+        self.unscale_layer(model, 3)
+        scalings = self.get_scalings(model)
+        expected = [2.0] * n_layers
+        assert scalings == expected
+
+    def test_scaling_rank_pattern_alpha_pattern(self, model):
+        # layer 0: 8 / 8
+        # layer 1: 8 / 16
+        # layer 2: 4 / 32
+        # layer 3: 16 / 8
+        # layer 4: 8 / 8
+        config = LoraConfig(
+            r=8,
+            lora_alpha=8,
+            target_modules=["k_proj"],
+            rank_pattern={"layers.1.self_attn.k_proj": 16, "layers.2.self_attn.k_proj": 32},
+            alpha_pattern={"layers.2.self_attn.k_proj": 4, "layers.3.self_attn.k_proj": 16},
+        )
+        model = get_peft_model(model, config)
+        scalings = self.get_scalings(model)
+        expected = [1.0, 0.5, 0.125, 2.0, 1.0]
+        assert scalings == expected
+
+        # double
+        self.scale_layer(model, 2)
+        scalings = self.get_scalings(model)
+        expected = [2.0, 1.0, 0.25, 4.0, 2.0]
+        assert scalings == expected
+
+        # back to original
+        self.unscale_layer(model, None)
+        scalings = self.get_scalings(model)
+        expected = [1.0, 0.5, 0.125, 2.0, 1.0]
+        assert scalings == expected
+
+        # triple
+        self.set_scale(model, "default", 3)
+        scalings = self.get_scalings(model)
+        expected = [3.0, 1.5, 0.375, 6.0, 3.0]
+        assert scalings == expected
+
+        # back to original
+        self.unscale_layer(model, 3)
+        scalings = self.get_scalings(model)
+        expected = [1.0, 0.5, 0.125, 2.0, 1.0]
+        assert scalings == expected
+
+    def test_scaling_multiple_times(self, model):
+        # same as previous test, but scale and unscale multiple times in a row
+        # layer 0: 8 / 8
+        # layer 1: 8 / 16
+        # layer 2: 4 / 32
+        # layer 3: 16 / 8
+        # layer 4: 8 / 8
+        config = LoraConfig(
+            r=8,
+            lora_alpha=8,
+            target_modules=["k_proj"],
+            rank_pattern={"layers.1.self_attn.k_proj": 16, "layers.2.self_attn.k_proj": 32},
+            alpha_pattern={"layers.2.self_attn.k_proj": 4, "layers.3.self_attn.k_proj": 16},
+        )
+        model = get_peft_model(model, config)
+        scalings = self.get_scalings(model)
+        expected = [1.0, 0.5, 0.125, 2.0, 1.0]
+        assert scalings == expected
+
+        # scale of 1 makes no difference
+        self.scale_layer(model, 1)
+        scalings = self.get_scalings(model)
+        expected = [1.0, 0.5, 0.125, 2.0, 1.0]
+
+        # double
+        self.scale_layer(model, 2)
+        scalings = self.get_scalings(model)
+        expected = [2.0, 1.0, 0.25, 4.0, 2.0]
+        assert scalings == expected
+
+        # triple, on top of previous double
+        self.scale_layer(model, 3)
+        scalings = self.get_scalings(model)
+        expected = [6.0, 3.0, 0.75, 12.0, 6.0]
+        assert scalings == expected
+
+        # half
+        self.unscale_layer(model, 2)
+        scalings = self.get_scalings(model)
+        expected = [3.0, 1.5, 0.375, 6.0, 3.0]
+        assert scalings == expected
+
+        # divide by 3, on top of previous half
+        self.unscale_layer(model, 3)
+        scalings = self.get_scalings(model)
+        expected = [1.0, 0.5, 0.125, 2.0, 1.0]
+        assert scalings == expected
+
+        # set scale to 2
+        self.set_scale(model, "default", 2)
+        scalings = self.get_scalings(model)
+        expected = [2.0, 1.0, 0.25, 4.0, 2.0]
+        assert scalings == expected
+
+        # set scale to 3, it is cumulative but based on the initial scaling, so factor 3, not 6
+        self.set_scale(model, "default", 3)
+        scalings = self.get_scalings(model)
+        expected = [3.0, 1.5, 0.375, 6.0, 3.0]
+        assert scalings == expected
+
+        # back to original
+        self.unscale_layer(model, None)
+        scalings = self.get_scalings(model)
+        expected = [1.0, 0.5, 0.125, 2.0, 1.0]
+        assert scalings == expected
+
+        # back to original again
+        self.unscale_layer(model, None)
+        scalings = self.get_scalings(model)
+        expected = [1.0, 0.5, 0.125, 2.0, 1.0]
+        assert scalings == expected
+
+    def test_scaling_multiple_adapters(self, model):
+        # ensure that scaling works with multiple adapters
+        n_layers = 5
+        rank0, lora_alpha0 = 8, 16
+        config0 = LoraConfig(
+            r=rank0,
+            lora_alpha=lora_alpha0,
+            target_modules=["k_proj"],
+        )
+        rank1, lora_alpha1 = 16, 8
+        config1 = LoraConfig(
+            r=rank1,
+            lora_alpha=lora_alpha1,
+            target_modules=["k_proj"],
+        )
+        model = get_peft_model(model, config0)
+        model.add_adapter("other", config1)
+
+        scalings_default = self.get_scalings(model, "default")
+        scalings_other = self.get_scalings(model, "other")
+        expected_default = [lora_alpha0 / rank0] * n_layers
+        expected_other = [lora_alpha1 / rank1] * n_layers
+        assert scalings_default == expected_default
+        assert scalings_other == expected_other
+
+        # double the scale for other
+        self.set_scale(model, "other", 2)
+        scalings_default = self.get_scalings(model, "default")
+        scalings_other = self.get_scalings(model, "other")
+        expected_default = [lora_alpha0 / rank0] * n_layers
+        expected_other = [2 * lora_alpha1 / rank1] * n_layers
+        assert scalings_default == expected_default
+        assert scalings_other == expected_other
+
+        # quarter the scale for default
+        self.set_scale(model, "default", 0.25)
+        scalings_default = self.get_scalings(model, "default")
+        scalings_other = self.get_scalings(model, "other")
+        expected_default = [lora_alpha0 / rank0 / 4] * n_layers
+        expected_other = [2 * lora_alpha1 / rank1] * n_layers
+        assert scalings_default == expected_default
+        assert scalings_other == expected_other
+
+        # unscale resets for all *active* adapters
+        self.unscale_layer(model, None)
+        scalings_default = self.get_scalings(model, "default")
+        scalings_other = self.get_scalings(model, "other")
+        expected_default = [lora_alpha0 / rank0] * n_layers
+        expected_other = [2 * lora_alpha1 / rank1] * n_layers  # stays the same as 'other' is not active
+        assert scalings_default == expected_default
+        assert scalings_other == expected_other
+
+        # scale all *active* adapters by 2
+        self.scale_layer(model, 2)
+        scalings_default = self.get_scalings(model, "default")
+        scalings_other = self.get_scalings(model, "other")
+        expected_default = [2 * lora_alpha0 / rank0] * n_layers
+        expected_other = [2 * lora_alpha1 / rank1] * n_layers  # stays the same as 'other' is not active
+        assert scalings_default == expected_default
+        assert scalings_other == expected_other
+
+        # switch to 'other'
+        model.set_adapter("other")
+
+        # unscale, this time 'other'
+        self.unscale_layer(model, None)
+        scalings_default = self.get_scalings(model, "default")
+        scalings_other = self.get_scalings(model, "other")
+        expected_default = [2 * lora_alpha0 / rank0] * n_layers  # stays the same as 'default' is not active
+        expected_other = [lora_alpha1 / rank1] * n_layers
+        assert scalings_default == expected_default
+        assert scalings_other == expected_other
+
+        # scale all *active* adapters by 3
+        self.scale_layer(model, 3)
+        scalings_default = self.get_scalings(model, "default")
+        scalings_other = self.get_scalings(model, "other")
+        expected_default = [2 * lora_alpha0 / rank0] * n_layers  # stays the same as 'default' is not active
+        expected_other = [3 * lora_alpha1 / rank1] * n_layers
+        assert scalings_default == expected_default
+        assert scalings_other == expected_other
