FIX Several bugs in LoKr (#2180)
- Added rank_dropout_scale parameter
- Fixed scale-related computation
- Added LyCORIS-style weight initialization
yaswanth19 authored Nov 5, 2024
1 parent 13fb29f commit b1fd97d
Showing 5 changed files with 125 additions and 9 deletions.
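
For context before the diff, here is a minimal usage sketch of the two options this commit touches (`rank_dropout_scale` and `init_weights="lycoris"`). The base model and `target_modules` names are placeholders chosen for illustration and are not part of the commit.

```python
from transformers import AutoModelForCausalLM  # placeholder base model, not part of this commit
from peft import LoKrConfig, get_peft_model

# Any model whose submodule names match target_modules works; "facebook/opt-125m" is just an example.
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = LoKrConfig(
    r=8,
    alpha=8,
    rank_dropout=0.1,
    rank_dropout_scale=True,              # new: rescale the rank-dropout mask so its mean stays 1
    init_weights="lycoris",               # new: LyCORIS-style initialization of the adapter factors
    target_modules=["q_proj", "v_proj"],  # assumed module names for the example model
)
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()
```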
15 changes: 9 additions & 6 deletions src/peft/tuners/lokr/config.py
@@ -14,7 +14,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional, Union
from typing import Literal, Optional, Union

from peft.tuners.lycoris_utils import LycorisConfig
from peft.utils import PeftType
@@ -40,6 +40,8 @@ class LoKrConfig(LycorisConfig):
Perform rank decomposition of left kronecker product matrix.
decompose_factor (`int`):
Kronecker product decomposition factor.
rank_dropout_scale (`bool`):
Whether to scale the rank dropout while training, defaults to `False`.
target_modules (`Optional[Union[List[str], str]]`):
The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
names will be replaced. When passing a string, a regex match will be performed. When passing a list of
@@ -53,8 +55,8 @@ class LoKrConfig(LycorisConfig):
When passing a list of strings, either an exact match will be performed or it is checked if the name of the
module ends with any of the passed strings.
init_weights (`bool`):
Whether to perform initialization of adapter weights. This defaults to `True`, passing `False` is
discouraged.
Whether to perform initialization of adapter weights. This defaults to `True`. Use "lycoris" to initialize
weights in the style of the LYCORIS repository. Passing `False` is discouraged.
layers_to_transform (`Union[List[int], int]`):
The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices
that are specified in this list. If a single integer is passed, it will apply the transformations on the
@@ -91,6 +93,7 @@ class LoKrConfig(LycorisConfig):
metadata={"help": "Perform rank decomposition of left kronecker product matrix."},
)
decompose_factor: int = field(default=-1, metadata={"help": "Kronecker product decomposition factor."})
rank_dropout_scale: bool = field(default=False, metadata={"help": "Rank dropout scale"})
target_modules: Optional[Union[list[str], str]] = field(
default=None,
metadata={
@@ -103,12 +106,12 @@ class LoKrConfig(LycorisConfig):
default=None,
metadata={"help": "List of module names or regex expression of the module names to exclude from LoKr."},
)
init_weights: bool = field(
init_weights: Union[bool, Literal["lycoris"]] = field(
default=True,
metadata={
"help": (
"Whether to initialize the weights of the LoKr layers with their default initialization. Don't change "
"this setting, except if you know exactly what you're doing."
"Whether to initialize the weights of the LoKr layers with their default initialization. Can be True, False or 'lycoris'."
"Default is True. Don't change this setting to False, except if you know exactly what you're doing."
),
},
)
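
A note on the `init_weights="lycoris"` option documented above: as with the default initialization, one of the two Kronecker factors starts at zero, so the adapter's initial delta weight is zero and the wrapped layer behaves exactly like the base layer until training updates it (the new tests at the bottom of this diff assert this for both schemes). A minimal standalone check of that property, not PEFT code:

```python
import torch

w1 = torch.randn(4, 4)   # Kaiming-initialized factor (sketch)
w2 = torch.zeros(8, 8)   # zero-initialized factor (sketch)

# A Kronecker product with a zero factor is zero, so the initial delta weight vanishes.
delta = torch.kron(w1, w2)
assert torch.count_nonzero(delta) == 0
```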
28 changes: 25 additions & 3 deletions src/peft/tuners/lokr/layer.py
@@ -126,6 +126,23 @@ def reset_adapter_parameters_random(self, adapter_name: str):
if adapter_name in self.lokr_t2:
nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5))

# Initializes the weight matrices in the same way as the LyCORIS repository.
def reset_adapter_parameters_lycoris_way(self, adapter_name):
if adapter_name in self.lokr_w1:
nn.init.kaiming_uniform_(self.lokr_w1[adapter_name], a=math.sqrt(5))
else:
nn.init.kaiming_uniform_(self.lokr_w1_a[adapter_name], a=math.sqrt(5))
nn.init.kaiming_uniform_(self.lokr_w1_b[adapter_name], a=math.sqrt(5))

if adapter_name in self.lokr_w2:
nn.init.zeros_(self.lokr_w2[adapter_name])
else:
nn.init.zeros_(self.lokr_w2_b[adapter_name])
nn.init.kaiming_uniform_(self.lokr_w2_a[adapter_name], a=math.sqrt(5))

if adapter_name in self.lokr_t2:
nn.init.kaiming_uniform_(self.lokr_t2[adapter_name], a=math.sqrt(5))

def update_layer(
self,
adapter_name: str,
@@ -160,6 +177,7 @@ def update_layer(
self.scaling[adapter_name] = alpha / r
self.rank_dropout[adapter_name] = rank_dropout
self.module_dropout[adapter_name] = module_dropout
self.rank_dropout_scale[adapter_name] = kwargs["rank_dropout_scale"]
base_layer = self.get_base_layer()

# Determine shape of LoKr weights
@@ -192,7 +210,10 @@

# Initialize weights
if init_weights:
self.reset_adapter_parameters(adapter_name)
if init_weights == "lycoris":
self.reset_adapter_parameters_lycoris_way(adapter_name)
else:
self.reset_adapter_parameters(adapter_name)
else:
self.reset_adapter_parameters_random(adapter_name)

@@ -215,15 +236,16 @@ def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
w2 = self.lokr_w2_a[adapter_name] @ self.lokr_w2_b[adapter_name]

# Make weights with Kronecker product
weight = make_kron(w1, w2)
weight = make_kron(w1, w2, self.scaling[adapter_name])
weight = weight.reshape(self.get_base_layer().weight.shape)

# Perform rank dropout during training - drop rows of addition weights
rank_dropout = self.rank_dropout[adapter_name]
if self.training and rank_dropout:
drop = (torch.rand(weight.size(0)) > rank_dropout).float()
drop = drop.view(-1, *[1] * len(weight.shape[1:])).to(weight.device)
drop /= drop.mean()
if self.rank_dropout_scale[adapter_name]:
drop /= drop.mean()
weight *= drop

return weight
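
To make the `layer.py` changes concrete, below is a standalone sketch of what `get_delta_weight` now computes. It is not PEFT code: the shapes are toy values, and the scaling is applied after the Kronecker product, which is mathematically equivalent to passing it into `make_kron` as the commit does.

```python
import torch

def delta_weight_sketch(w1, w2, scaling, rank_dropout, rank_dropout_scale, training=True):
    # Delta weight as a scaled Kronecker product of the two LoKr factors.
    weight = torch.kron(w1, w2) * scaling

    # Rank dropout: randomly drop rows of the delta weight during training.
    if training and rank_dropout:
        drop = (torch.rand(weight.size(0)) > rank_dropout).float()
        drop = drop.view(-1, *[1] * (weight.dim() - 1)).to(weight.device)
        if rank_dropout_scale:
            # Only rescale when the new flag is set, so the kept rows compensate
            # for the dropped ones and the expected magnitude stays unchanged.
            drop /= drop.mean()
        weight *= drop
    return weight

w1 = torch.randn(4, 4)
w2 = torch.randn(8, 8)
delta = delta_weight_sketch(w1, w2, scaling=1.0, rank_dropout=0.1, rank_dropout_scale=True)
print(delta.shape)  # torch.Size([32, 32])
```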
1 change: 1 addition & 0 deletions src/peft/tuners/lokr/model.py
@@ -109,6 +109,7 @@ def _create_and_replace(
kwargs = config.to_dict()
kwargs["r"] = config.rank_pattern.get(target_name_key, config.r)
kwargs["alpha"] = config.alpha_pattern.get(target_name_key, config.alpha)
kwargs["rank_dropout_scale"] = config.rank_dropout_scale

if isinstance(target, LoKrLayer):
target.update_layer(adapter_name, **kwargs)
1 change: 1 addition & 0 deletions src/peft/tuners/lycoris_utils.py
@@ -71,6 +71,7 @@ def __init__(self, base_layer: nn.Module) -> None:
self.alpha = {}
self.scaling = {}
self.rank_dropout = {}
self.rank_dropout_scale = {}
self.module_dropout = {}

# Tuner info
89 changes: 89 additions & 0 deletions tests/test_initialization.py
@@ -32,6 +32,7 @@
from peft import (
AdaLoraConfig,
IA3Config,
LoKrConfig,
LoraConfig,
PeftMixedModel,
PeftModel,
@@ -1120,6 +1121,94 @@ def test_lora_use_dora_with_megatron_core_raises(self):
LoraConfig(target_modules=["linear"], use_dora=True, megatron_config=megatron_config)


class TestLokrInitialization:
torch_device = infer_device()

def get_model(self):
class MyModule(nn.Module):
def __init__(self):
super().__init__()
# Choose a large weight so that averages are close to expected values.
self.linear = nn.Linear(1000, 1000)
self.conv2d = nn.Conv2d(100, 100, 3)

def forward(self, x):
x_4d = x.flatten().reshape(1, 100, 10, 10)
return self.linear(x), self.conv2d(x_4d)

return MyModule().eval().to(self.torch_device)

@pytest.fixture
def data(self):
return torch.rand(10, 1000).to(self.torch_device)

def test_lokr_linear_init_default(self, data):
torch.manual_seed(0)

model = self.get_model()
output_before = model(data)[0]
config = LoKrConfig(target_modules=["linear"])
model = get_peft_model(model, config)
output_after = model(data)[0]

assert torch.allclose(output_before, output_after)

def test_lokr_linear_init_false(self, data):
torch.manual_seed(0)

model = self.get_model()
output_before = model(data)[0]
config = LoKrConfig(target_modules=["linear"], init_weights=False)
model = get_peft_model(model, config)
output_after = model(data)[0]

assert not torch.allclose(output_before, output_after)

def test_lokr_linear_init_lycoris(self, data):
torch.manual_seed(0)

model = self.get_model()
output_before = model(data)[0]
config = LoKrConfig(target_modules=["linear"], init_weights="lycoris")
model = get_peft_model(model, config)
output_after = model(data)[0]

assert torch.allclose(output_before, output_after)

def test_lokr_conv2d_init_default(self, data):
torch.manual_seed(0)

model = self.get_model()
output_before = model(data)[1]
config = LoKrConfig(target_modules=["conv2d"])
model = get_peft_model(model, config)
output_after = model(data)[1]

assert torch.allclose(output_before, output_after)

def test_lokr_conv2d_init_false(self, data):
torch.manual_seed(0)

model = self.get_model()
output_before = model(data)[1]
config = LoKrConfig(target_modules=["conv2d"], init_weights=False)
model = get_peft_model(model, config)
output_after = model(data)[1]

assert not torch.allclose(output_before, output_after)

def test_lokr_conv2d_init_lycoris(self, data):
torch.manual_seed(0)

model = self.get_model()
output_before = model(data)[1]
config = LoKrConfig(target_modules=["conv2d"], init_weights="lycoris")
model = get_peft_model(model, config)
output_after = model(data)[1]

assert torch.allclose(output_before, output_after)


class TestAdaLoraInitialization:
torch_device = infer_device()

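The new tests can be run in isolation with something along the lines of `pytest tests/test_initialization.py -k Lokr` (command assumed, not part of the commit); they check that the default and "lycoris" initializations leave the model's outputs unchanged, while `init_weights=False` does not.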
