Skip to content

Is alpha calculated correctly? #2204

@iqddd

Description

@iqddd

def state_dict(self, destination=None, prefix="", keep_vars=False):
if not self.split_qkv:
return super().state_dict(destination, prefix, keep_vars)
# merge qkv
state_dict = super().state_dict(destination, prefix, keep_vars)
new_state_dict = {}
for key in list(state_dict.keys()):
if "double" in key and "qkv" in key:
split_dims = [3072] * 3
elif "single" in key and "linear1" in key:
split_dims = [3072] * 3 + [12288]
else:
new_state_dict[key] = state_dict[key]
continue
if key not in state_dict:
continue # already merged
lora_name = key.split(".")[0]
# (rank, in_dim) * 3
down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))]
# (split dim, rank) * 3
up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]
alpha = state_dict.pop(f"{lora_name}.alpha")
# merge down weight
down_weight = torch.cat(down_weights, dim=0) # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)
# merge up weight (sum of split_dim, rank*3)
rank = up_weights[0].size(1)
up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
i = 0
for j in range(len(split_dims)):
up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j]
i += split_dims[j]
new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
new_state_dict[f"{lora_name}.alpha"] = alpha
# print(
# f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
# )
print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha")
return new_state_dict

With split_qkv, we merge ModuleList into a single fused tensor. The rank of this tensor will be three or four times larger than regular. At the same time, we leave alpha unchanged. If the inference code does not take this feature into account and calculates scale using the usual formula: scale=alpha/lora_down.shape[0], then the contribution of these layers is weakened by 3-4 times.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bugSomething isn't working

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions