-
Notifications
You must be signed in to change notification settings - Fork 72
/
Copy pathcomposable_lora.py
165 lines (131 loc) · 6.78 KB
/
composable_lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from typing import List, Dict
import re
import torch
from modules import extra_networks, shared
# Splits a prompt on the standalone, upper-case keyword ``AND`` (word
# boundaries on both sides), so e.g. "ANDROID" is not split.
re_AND = re.compile(r"\bAND\b")


def load_prompt_loras(prompt: str):
    """Rebuild ``prompt_loras`` from the LoRA tags of each ``AND`` sub-prompt.

    ``prompt`` is split with ``re_AND``. For each sub-prompt, the extra-network
    data returned by ``extra_networks.parse_prompt`` is scanned for ``'lora'``
    entries. Each entry contributes ``{name: multiplier}``. The name is the
    entry's first item. The multiplier is its second item parsed as ``float``,
    or ``1.0`` when there is no second item. A later tag for the same name
    inside one sub-prompt overwrites an earlier one.

    The module-level ``prompt_loras`` list is cleared, then extended with the
    per-sub-prompt dicts repeated ``num_batches`` times. The same list object
    is kept throughout.

    Raises:
        ValueError: if a multiplier item cannot be parsed as ``float``.
    """
    prompt_loras.clear()

    per_subprompt: List[Dict[str, float]] = []
    for subprompt in re_AND.split(prompt):
        _, extra_network_data = extra_networks.parse_prompt(subprompt)
        loras: Dict[str, float] = {}
        # .get() so a sub-prompt with no <lora:...> tags simply yields {},
        # whatever mapping type the parser returns.
        for params in extra_network_data.get('lora', []):
            name = params.items[0]
            multiplier = float(params.items[1]) if len(params.items) > 1 else 1.0
            loras[name] = multiplier
        per_subprompt.append(loras)

    # The per-sub-prompt list is repeated once per batch entry; the dicts are
    # shared between repetitions and only read afterwards.
    prompt_loras.extend(per_subprompt * num_batches)
def reset_counters():
    """Put both layer-position counters back at their starting values.

    ``text_model_encoder_counter`` becomes the ``-1`` sentinel. On its next
    call, ``lora_forward`` replaces that with ``len(prompt_loras) * num_loras``,
    which is the "uc" head. ``diffusion_model_counter`` restarts at ``0``.
    """
    global text_model_encoder_counter, diffusion_model_counter
    text_model_encoder_counter, diffusion_model_counter = -1, 0
def lora_forward(compvis_module, input, res):
    """Add the patches of all loaded LoRAs to a layer's output and return it.

    ``res`` is the output the wrapped layer already produced for ``input``
    (see ``lora_Linear_forward`` / ``lora_Conv2d_forward``). It is updated in
    place with ``+=`` and returned. ``res`` is returned untouched when
    ``lora.loaded_loras`` is empty or when ``compvis_module`` has no
    ``lora_layer_name`` attribute.

    Each loaded LoRA that has a module for this layer name gets a patch
    ``module.up(module.down(x))``:

    * ``x`` is ``res`` when ``shared.opts.lora_apply_to_outputs`` is set and
      ``res`` has the same shape as ``input``; otherwise it is ``input``.
    * The patch is scaled by ``module.alpha / module.up.weight.shape[1]``, or
      by ``1.0`` when ``alpha`` is falsy.

    The weight applied to the patch then depends on the layer:

    * ``enabled`` is False, or the name starts with neither prefix below:
      the LoRA's own ``lora.multiplier``, added to the whole output.
    * ``transformer_*`` layers: the multiplier stored for this LoRA in
      ``prompt_loras[text_model_encoder_counter // num_loras]`` when that
      index is in range ("c"). Otherwise ("uc") ``lora.multiplier`` is used,
      and only if ``opt_uc_text_model_encoder`` is set.
    * ``diffusion_model_*`` layers: if ``res.shape[0]`` equals
      ``num_batches * len(prompt_loras) + num_batches``, the leading rows get
      per-prompt multipliers. The trailing ``num_batches`` rows get
      ``lora.multiplier``, and only if ``opt_uc_diffusion_model`` is set.
      Any other row count maps rows to ``prompt_loras`` through
      ``diffusion_model_counter``.

    Side effects: reads the module globals ``enabled``, ``opt_uc_*``,
    ``num_batches`` and ``prompt_loras``. Advances the module globals
    ``text_model_encoder_counter`` and ``diffusion_model_counter``.
    """
    global text_model_encoder_counter
    global diffusion_model_counter

    # Imported at call time rather than at the top of the module.
    import lora

    # Nothing to add when no LoRA is loaded.
    if len(lora.loaded_loras) == 0:
        return res

    # Only layers tagged with a lora_layer_name can be matched against a
    # LoRA's modules; every other layer passes through unchanged.
    lora_layer_name: str | None = getattr(compvis_module, 'lora_layer_name', None)
    if lora_layer_name is None:
        return res

    num_loras = len(lora.loaded_loras)
    # -1 is the sentinel set at module load and by reset_counters(). Start at
    # the "uc" head, the first index past the per-prompt ("c") entries.
    if text_model_encoder_counter == -1:
        text_model_encoder_counter = len(prompt_loras) * num_loras

    # print(f"lora.forward lora_layer_name={lora_layer_name} in.shape={input.shape} res.shape={res.shape} num_batches={num_batches} num_prompts={num_prompts}")

    # NOTE(review): the loop variable rebinds `lora`, shadowing the module
    # imported above. This works only because `lora.loaded_loras` is evaluated
    # once, before the first iteration, and the module is not used inside the
    # loop. Renaming one of the two would be clearer.
    for lora in lora.loaded_loras:
        module = lora.modules.get(lora_layer_name, None)
        if module is None:
            # This LoRA has no weights for the current layer.
            continue

        # Low-rank patch up(down(x)). It is fed from the layer output when
        # lora_apply_to_outputs is set and the output has the input's shape,
        # otherwise from the layer input.
        if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
            patch = module.up(module.down(res))
        else:
            patch = module.up(module.down(input))

        # Patch scale: alpha / up.weight.shape[1], or 1.0 when alpha is falsy.
        alpha = module.alpha / module.up.weight.shape[1] if module.alpha else 1.0

        num_prompts = len(prompt_loras)

        # print(f"lora.name={lora.name} lora.mul={lora.multiplier} alpha={alpha} pat.shape={patch.shape}")

        if enabled:
            if lora_layer_name.startswith("transformer_"):  # "transformer_text_model_encoder_"
                # The counter advances once per LoRA on the marker layer below,
                # so counter // num_loras indexes the current text-encoder pass.
                if 0 <= text_model_encoder_counter // num_loras < len(prompt_loras):
                    # c: the multiplier tagged for this LoRA in that sub-prompt.
                    # A missing tag gives 0.0, which adds nothing.
                    loras = prompt_loras[text_model_encoder_counter // num_loras]
                    multiplier = loras.get(lora.name, 0.0)
                    if multiplier != 0.0:
                        # print(f"c #{text_model_encoder_counter // num_loras} lora.name={lora.name} mul={multiplier}")
                        res += multiplier * alpha * patch
                else:
                    # uc: the global multiplier, only when opted in.
                    if opt_uc_text_model_encoder and lora.multiplier != 0.0:
                        # print(f"uc #{text_model_encoder_counter // num_loras} lora.name={lora.name} lora.mul={lora.multiplier}")
                        res += lora.multiplier * alpha * patch

                # NOTE(review): this runs only for LoRAs that have a module for
                # the marker layer (others hit `continue` above), so such LoRAs
                # do not advance the counter. Confirm every LoRA covers it.
                if lora_layer_name.endswith("_11_mlp_fc2"):  # last lora_layer_name of text_model_encoder
                    text_model_encoder_counter += 1
                    # c1 c1 c2 c2 .. .. uc uc
                    # Wrap to 0 after the prompt entries plus num_batches "uc" passes.
                    if text_model_encoder_counter == (len(prompt_loras) + num_batches) * num_loras:
                        text_model_encoder_counter = 0

            elif lora_layer_name.startswith("diffusion_model_"):  # "diffusion_model_"
                # NOTE(review): load_prompt_loras already repeats the per-sub-prompt
                # list num_batches times, and the loop below walks prompt_loras once
                # per batch. This branch therefore matches only when
                # res.shape[0] == num_batches * len(prompt_loras) + num_batches.
                # Verify the intended row layout for batch sizes > 1.
                if res.shape[0] == num_batches * num_prompts + num_batches:
                    # tensor.shape[1] == uncond.shape[1]
                    # Conditional rows come first; the unconditional rows start
                    # at num_batches * num_prompts.
                    tensor_off = 0
                    uncond_off = num_batches * num_prompts
                    for b in range(num_batches):
                        # c: one row per prompt entry, each with its own multiplier.
                        for p, loras in enumerate(prompt_loras):
                            multiplier = loras.get(lora.name, 0.0)
                            if multiplier != 0.0:
                                # print(f"tensor #{b}.{p} lora.name={lora.name} mul={multiplier}")
                                res[tensor_off] += multiplier * alpha * patch[tensor_off]
                            tensor_off += 1

                        # uc: one row per batch entry, only when opted in.
                        if opt_uc_diffusion_model and lora.multiplier != 0.0:
                            # print(f"uncond lora.name={lora.name} lora.mul={lora.multiplier}")
                            res[uncond_off] += lora.multiplier * alpha * patch[uncond_off]
                        uncond_off += 1
                else:
                    # tensor.shape[1] != uncond.shape[1]
                    # The cond and uncond rows arrive in separate calls. The counter
                    # advances by the row count once per LoRA on the marker layer, so
                    # this yields the first prompt_loras index for this call.
                    cur_num_prompts = res.shape[0]
                    base = (diffusion_model_counter // cur_num_prompts) // num_loras * cur_num_prompts
                    # NOTE(review): only `base` is range-checked; `base + off` below can
                    # exceed len(prompt_loras) if the row counts don't line up. TODO confirm.
                    if 0 <= base < len(prompt_loras):
                        # c
                        for off in range(cur_num_prompts):
                            loras = prompt_loras[base + off]
                            multiplier = loras.get(lora.name, 0.0)
                            if multiplier != 0.0:
                                # print(f"c #{base + off} lora.name={lora.name} mul={multiplier}", lora_layer_name=lora_layer_name)
                                res[off] += multiplier * alpha * patch[off]
                    else:
                        # uc: the whole tensor, only when opted in.
                        if opt_uc_diffusion_model and lora.multiplier != 0.0:
                            # print(f"uc {lora_layer_name} lora.name={lora.name} lora.mul={lora.multiplier}")
                            res += lora.multiplier * alpha * patch

                    if lora_layer_name.endswith("_11_1_proj_out"):  # last lora_layer_name of diffusion_model
                        diffusion_model_counter += cur_num_prompts
                        # c1 c2 .. uc
                        if diffusion_model_counter >= (len(prompt_loras) + num_batches) * num_loras:
                            diffusion_model_counter = 0
            else:
                # default: a layer outside both families gets the plain multiplier.
                if lora.multiplier != 0.0:
                    # print(f"default {lora_layer_name} lora.name={lora.name} lora.mul={lora.multiplier}")
                    res += lora.multiplier * alpha * patch
        else:
            # default: composable mode is off, so the plain multiplier is used.
            if lora.multiplier != 0.0:
                # print(f"DEFAULT {lora_layer_name} lora.name={lora.name} lora.mul={lora.multiplier}")
                res += lora.multiplier * alpha * patch

    return res
def lora_Linear_forward(self, input):
    """Forward replacement for ``torch.nn.Linear`` that applies LoRA patches.

    Calls the function stored as ``torch.nn.Linear_forward_before_lora``. It is
    set elsewhere, presumably to the unpatched ``Linear.forward``. The result
    is then handed, together with ``self`` and ``input``, to ``lora_forward``.
    """
    base_output = torch.nn.Linear_forward_before_lora(self, input)
    return lora_forward(self, input, base_output)
def lora_Conv2d_forward(self, input):
    """Forward replacement for ``torch.nn.Conv2d`` that applies LoRA patches.

    Calls the function stored as ``torch.nn.Conv2d_forward_before_lora``. It is
    set elsewhere, presumably to the unpatched ``Conv2d.forward``. The result
    is then handed, together with ``self`` and ``input``, to ``lora_forward``.
    """
    base_output = torch.nn.Conv2d_forward_before_lora(self, input)
    return lora_forward(self, input, base_output)
# Switches read by lora_forward (presumably assigned by the controlling script).
# enabled: turns on per-prompt ("composable") multipliers; when False, every
# LoRA is applied with its own global multiplier.
enabled: bool = False
# When True, the "uc" (unconditional) text-encoder pass also receives LoRAs.
opt_uc_text_model_encoder: bool = False
# When True, the "uc" (unconditional) diffusion rows also receive LoRAs.
opt_uc_diffusion_model: bool = False
# Verbosity flag; nothing in this module reads it.
verbose: bool = True

# Per-generation state filled by load_prompt_loras.
# num_batches: how many times the per-sub-prompt list is repeated.
num_batches: int = 0
# One {lora_name: multiplier} dict per (repeated) AND sub-prompt.
prompt_loras: List[Dict[str, float]] = []

# Layer-position counters advanced by lora_forward and reset by reset_counters;
# -1 is the "not yet initialised" sentinel for the text-encoder counter.
text_model_encoder_counter: int = -1
diffusion_model_counter: int = 0