sd-scripts/networks/lora_flux.py, lines 1086 to 1134 in f5d44fd:
```python
def state_dict(self, destination=None, prefix="", keep_vars=False):
    if not self.split_qkv:
        return super().state_dict(destination, prefix, keep_vars)

    # merge qkv
    state_dict = super().state_dict(destination, prefix, keep_vars)
    new_state_dict = {}
    for key in list(state_dict.keys()):
        if "double" in key and "qkv" in key:
            split_dims = [3072] * 3
        elif "single" in key and "linear1" in key:
            split_dims = [3072] * 3 + [12288]
        else:
            new_state_dict[key] = state_dict[key]
            continue

        if key not in state_dict:
            continue  # already merged

        lora_name = key.split(".")[0]

        # (rank, in_dim) * 3
        down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))]
        # (split dim, rank) * 3
        up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]

        alpha = state_dict.pop(f"{lora_name}.alpha")

        # merge down weight
        down_weight = torch.cat(down_weights, dim=0)  # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)

        # merge up weight (sum of split_dim, rank*3)
        rank = up_weights[0].size(1)
        up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
        i = 0
        for j in range(len(split_dims)):
            up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j]
            i += split_dims[j]

        new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
        new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
        new_state_dict[f"{lora_name}.alpha"] = alpha

        # print(
        #     f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
        # )
        print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha")

    return new_state_dict
```
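For reference, here is a minimal standalone sketch (using made-up small dimensions instead of the real 3072/12288 ones) of the merge that `state_dict()` performs: the per-split down weights are concatenated, the up weights are placed block-diagonally, and the fused product reproduces the stacked per-split products. It also shows that the merged down weight's first dimension is `n_splits * rank`, while alpha was computed for the original `rank`:

```python
import torch

# Hypothetical small dimensions for illustration only.
rank, in_dim, split_dims = 4, 16, [8, 8, 8]
n = len(split_dims)

# Per-split LoRA factors, as stored before merging.
downs = [torch.randn(rank, in_dim) for _ in range(n)]        # (rank, in_dim) each
ups = [torch.randn(d, rank) for d in split_dims]             # (split_dim, rank) each

# Merge as in state_dict(): concatenate downs, place ups block-diagonally.
down_weight = torch.cat(downs, dim=0)                        # (n*rank, in_dim)
up_weight = torch.zeros(sum(split_dims), n * rank)
row = 0
for j, d in enumerate(split_dims):
    up_weight[row : row + d, j * rank : (j + 1) * rank] = ups[j]
    row += d

# The fused product equals the per-split products stacked along the output dim.
fused = up_weight @ down_weight                              # (sum(split_dims), in_dim)
reference = torch.cat([u @ dwn for u, dwn in zip(ups, downs)], dim=0)
assert torch.allclose(fused, reference, atol=1e-6)

# down_weight.shape[0] is n * rank, but alpha still refers to `rank`.
print(down_weight.shape[0], rank)  # e.g. 12 vs 4
```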
With split_qkv, the per-split ModuleList weights are merged into single fused lora_down/lora_up tensors when saving. The first dimension of the merged lora_down weight is three or four times the original rank, while alpha is left unchanged. If inference code does not take this into account and computes the scale with the usual formula scale = alpha / lora_down.shape[0], the contribution of these layers is weakened by a factor of 3-4.
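A minimal sketch of a possible workaround on the inference side (the key-name checks and the helper name are assumptions for illustration, not code from sd-scripts or any loader): detect the fused FLUX layers and use the per-split rank when computing the scale.

```python
import torch

def lora_scale(key: str, alpha: float, lora_down: torch.Tensor) -> float:
    """Return alpha / effective_rank, accounting for the split_qkv fusion described above."""
    merged_rank = lora_down.shape[0]
    # Assumed detection: fused double-block qkv holds 3 splits,
    # fused single-block linear1 holds 3 (q/k/v) + 1 (MLP) = 4 splits.
    if "double" in key and "qkv" in key:
        n_splits = 3
    elif "single" in key and "linear1" in key:
        n_splits = 4
    else:
        n_splits = 1
    return alpha / (merged_rank // n_splits)
```

Equivalently, the naive scale alpha / lora_down.shape[0] could simply be multiplied by the number of splits for these layers.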