Skip to content

Commit 72b1119

Browse files
committed
fully working sd webui api on ray serve
fully working sd webui api on ray serve
1 parent 6df93bc commit 72b1119

29 files changed

+267
-140
lines changed

extensions-builtin/Lora/networks.py

+26-12
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import Union
1616

1717
from modules import shared, devices, sd_models, errors, scripts, sd_hijack
18+
from modules.shared import shared_instance
1819

1920
module_types = [
2021
network_lora.ModuleTypeLora(),
@@ -112,8 +113,10 @@ def match(match_list, regex_text):
112113
def assign_network_names_to_compvis_modules(sd_model):
113114
network_layer_mapping = {}
114115

115-
if shared.sd_model.is_sdxl:
116-
for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
116+
#if shared.sd_model.is_sdxl:
117+
if shared_instance.sd_model.is_sdxl:
118+
#for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
119+
for i, embedder in enumerate(shared_instance.sd_model.conditioner.embedders):
117120
if not hasattr(embedder, 'wrapped'):
118121
continue
119122

@@ -122,12 +125,14 @@ def assign_network_names_to_compvis_modules(sd_model):
122125
network_layer_mapping[network_name] = module
123126
module.network_layer_name = network_name
124127
else:
125-
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
128+
#for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
129+
for name, module in shared_instance.sd_model.cond_stage_model.wrapped.named_modules():
126130
network_name = name.replace(".", "_")
127131
network_layer_mapping[network_name] = module
128132
module.network_layer_name = network_name
129133

130-
for name, module in shared.sd_model.model.named_modules():
134+
#for name, module in shared.sd_model.model.named_modules():
135+
for name, module in shared_instance.sd_model.model.named_modules():
131136
network_name = name.replace(".", "_")
132137
network_layer_mapping[network_name] = module
133138
module.network_layer_name = network_name
@@ -142,37 +147,46 @@ def load_network(name, network_on_disk):
142147
sd = sd_models.read_state_dict(network_on_disk.filename)
143148

144149
# this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
145-
if not hasattr(shared.sd_model, 'network_layer_mapping'):
146-
assign_network_names_to_compvis_modules(shared.sd_model)
150+
#if not hasattr(shared.sd_model, 'network_layer_mapping'):
151+
#assign_network_names_to_compvis_modules(shared.sd_model)
152+
if not hasattr(shared_instance.sd_model, 'network_layer_mapping'):
153+
assign_network_names_to_compvis_modules(shared_instance.sd_model)
147154

148155
keys_failed_to_match = {}
149-
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
156+
#is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
157+
is_sd2 = 'model_transformer_resblocks' in shared_instance.sd_model.network_layer_mapping
150158

151159
matched_networks = {}
152160

153161
for key_network, weight in sd.items():
154162
key_network_without_network_parts, network_part = key_network.split(".", 1)
155163

156164
key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
157-
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
165+
#sd_module = shared.sd_model.network_layer_mapping.get(key, None)
166+
sd_module = shared_instance.sd_model.network_layer_mapping.get(key, None)
158167

159168
if sd_module is None:
160169
m = re_x_proj.match(key)
161170
if m:
162-
sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
171+
#sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)
172+
sd_module = shared_instance.sd_model.network_layer_mapping.get(m.group(1), None)
173+
163174

164175
# SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
165176
if sd_module is None and "lora_unet" in key_network_without_network_parts:
166177
key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
167-
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
178+
#sd_module = shared.sd_model.network_layer_mapping.get(key, None)
179+
sd_module = shared_instance.sd_model.network_layer_mapping.get(key, None)
168180
elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
169181
key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
170-
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
182+
#sd_module = shared.sd_model.network_layer_mapping.get(key, None)
183+
sd_module = shared_instance.sd_model.network_layer_mapping.get(key, None)
171184

172185
# some SD1 Loras also have correct compvis keys
173186
if sd_module is None:
174187
key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
175-
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
188+
#sd_module = shared.sd_model.network_layer_mapping.get(key, None)
189+
sd_module = shared_instance.sd_model.network_layer_mapping.get(key, None)
176190

177191
if sd_module is None:
178192
keys_failed_to_match[key_network] = key

extensions-builtin/Lora/ui_extra_networks_lora.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from modules import shared, ui_extra_networks
77
from modules.ui_extra_networks import quote_js
88
from ui_edit_user_metadata import LoraUserMetadataEditor
9-
9+
from modules.shared import shared_instance
1010

1111
class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
1212
def __init__(self):
@@ -52,17 +52,26 @@ def create_item(self, name, index=None, enable_filter=True):
5252

5353
if shared.opts.lora_show_all or not enable_filter:
5454
pass
55+
# elif sd_version == network.SdVersion.Unknown:
56+
# model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1
57+
# if model_version.name in shared.opts.lora_hide_unknown_for_versions:
58+
# return None
59+
# elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:
60+
# return None
61+
# elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2:
62+
# return None
63+
# elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1:
64+
# return None
5565
elif sd_version == network.SdVersion.Unknown:
56-
model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1
66+
model_version = network.SdVersion.SDXL if shared_instance.sd_model.is_sdxl else network.SdVersion.SD2 if shared_instance.sd_model.is_sd2 else network.SdVersion.SD1
5767
if model_version.name in shared.opts.lora_hide_unknown_for_versions:
5868
return None
59-
elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:
69+
elif shared_instance.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:
6070
return None
61-
elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2:
71+
elif shared_instance.sd_model.is_sd2 and sd_version != network.SdVersion.SD2:
6272
return None
63-
elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1:
73+
elif shared_instance.sd_model.is_sd1 and sd_version != network.SdVersion.SD1:
6474
return None
65-
6675
return item
6776

6877
def list_items(self):

modules/api/api.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import piexif
3434
import piexif.helper
3535
from contextlib import closing
36-
36+
from modules.shared import shared_instance
3737

3838
def script_name_to_index(name, scripts):
3939
try:
@@ -364,7 +364,8 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
364364
args.pop('save_images', None)
365365

366366
with self.queue_lock:
367-
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
367+
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared_instance.sd_model, **args)) as p:
368+
#with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
368369
p.is_api = True
369370
p.scripts = script_runner
370371
p.outpath_grids = opts.outdir_txt2img_grids
@@ -424,7 +425,8 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
424425
args.pop('save_images', None)
425426

426427
with self.queue_lock:
427-
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
428+
#with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
429+
with closing(StableDiffusionProcessingImg2Img(sd_model=shared_instance.sd_model, **args)) as p:
428430
p.init_images = [decode_base64_to_image(x) for x in init_images]
429431
p.is_api = True
430432
p.scripts = script_runner
@@ -724,8 +726,10 @@ def train_hypernetwork(self, args: dict):
724726
except Exception as e:
725727
error = e
726728
finally:
727-
shared.sd_model.cond_stage_model.to(devices.device)
728-
shared.sd_model.first_stage_model.to(devices.device)
729+
#shared.sd_model.cond_stage_model.to(devices.device)
730+
#shared.sd_model.first_stage_model.to(devices.device)
731+
shared_instance.sd_model.cond_stage_model.to(devices.device)
732+
shared_instance.sd_model.first_stage_model.to(devices.device)
729733
if not apply_optimizations:
730734
sd_hijack.apply_optimizations()
731735
shared.state.end()

modules/api/raypi.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from modules import initialize_util
4141
from modules import script_callbacks
4242
import os
43-
43+
from modules.shared import shared_instance
4444
import launch
4545
from ray import serve
4646

@@ -349,8 +349,8 @@ def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
349349
send_images = args.pop('send_images', True)
350350
args.pop('save_images', None)
351351

352-
353-
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
352+
with closing(StableDiffusionProcessingTxt2Img(sd_model=shared_instance.sd_model, **args)) as p:
353+
#with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:
354354
p.is_api = True
355355
p.scripts = script_runner
356356
p.outpath_grids = opts.outdir_txt2img_grids
@@ -409,8 +409,8 @@ def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
409409
send_images = args.pop('send_images', True)
410410
args.pop('save_images', None)
411411

412-
413-
with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
412+
with closing(StableDiffusionProcessingImg2Img(sd_model=shared_instance.sd_model, **args)) as p:
413+
#with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:
414414
p.init_images = [decode_base64_to_image(x) for x in init_images]
415415
p.is_api = True
416416
p.scripts = script_runner
@@ -737,8 +737,10 @@ def train_hypernetwork(self, args: dict):
737737
except Exception as e:
738738
error = e
739739
finally:
740-
shared.sd_model.cond_stage_model.to(devices.device)
741-
shared.sd_model.first_stage_model.to(devices.device)
740+
#shared.sd_model.cond_stage_model.to(devices.device)
741+
#shared.sd_model.first_stage_model.to(devices.device)
742+
shared_instance.sd_model.cond_stage_model.to(devices.device)
743+
shared_instance.sd_model.first_stage_model.to(devices.device)
742744
if not apply_optimizations:
743745
sd_hijack.apply_optimizations()
744746
shared.state.end()

modules/hypernetworks/hypernetwork.py

+31-16
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from collections import deque
2020
from statistics import stdev, mean
21-
21+
from modules.shared import shared_instance
2222

2323
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
2424

@@ -525,7 +525,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
525525

526526
pin_memory = shared.opts.pin_memory
527527

528-
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
528+
#ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
529+
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared_instance.sd_model, cond_model=shared_instance.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
529530

530531
if shared.opts.save_training_settings_to_txt:
531532
saved_params = dict(
@@ -542,8 +543,10 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
542543

543544
if unload:
544545
shared.parallel_processing_allowed = False
545-
shared.sd_model.cond_stage_model.to(devices.cpu)
546-
shared.sd_model.first_stage_model.to(devices.cpu)
546+
#shared.sd_model.cond_stage_model.to(devices.cpu)
547+
#shared.sd_model.first_stage_model.to(devices.cpu)
548+
shared_instance.sd_model.cond_stage_model.to(devices.cpu)
549+
shared_instance.sd_model.first_stage_model.to(devices.cpu)
547550

548551
weights = hypernetwork.weights()
549552
hypernetwork.train()
@@ -614,16 +617,21 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
614617
if use_weight:
615618
w = batch.weight.to(devices.device, non_blocking=pin_memory)
616619
if tag_drop_out != 0 or shuffle_tags:
617-
shared.sd_model.cond_stage_model.to(devices.device)
618-
c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
619-
shared.sd_model.cond_stage_model.to(devices.cpu)
620+
#shared.sd_model.cond_stage_model.to(devices.device)
621+
shared_instance.sd_model.cond_stage_model.to(devices.device)
622+
c = shared_instance.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
623+
shared_instance.sd_model.cond_stage_model.to(devices.cpu)
624+
#c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
625+
#shared.sd_model.cond_stage_model.to(devices.cpu)
620626
else:
621627
c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
622628
if use_weight:
623-
loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
629+
loss = shared_instance.sd_model.weighted_forward(x, c, w)[0] / gradient_step
630+
#loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
624631
del w
625632
else:
626-
loss = shared.sd_model.forward(x, c)[0] / gradient_step
633+
#loss = shared.sd_model.forward(x, c)[0] / gradient_step
634+
loss = shared_instance.sd_model.forward(x, c)[0] / gradient_step
627635
del x
628636
del c
629637

@@ -683,11 +691,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
683691
cuda_rng_state = None
684692
if torch.cuda.is_available():
685693
cuda_rng_state = torch.cuda.get_rng_state_all()
686-
shared.sd_model.cond_stage_model.to(devices.device)
687-
shared.sd_model.first_stage_model.to(devices.device)
694+
#shared.sd_model.cond_stage_model.to(devices.device)
695+
#shared.sd_model.first_stage_model.to(devices.device)
696+
shared_instance.sd_model.cond_stage_model.to(devices.device)
697+
shared_instance.sd_model.first_stage_model.to(devices.device)
688698

689699
p = processing.StableDiffusionProcessingTxt2Img(
690-
sd_model=shared.sd_model,
700+
#sd_model=shared.sd_model,
701+
sd_model=shared_instance.sd_model,
691702
do_not_save_grid=True,
692703
do_not_save_samples=True,
693704
)
@@ -716,8 +727,10 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
716727
image = processed.images[0] if len(processed.images) > 0 else None
717728

718729
if unload:
719-
shared.sd_model.cond_stage_model.to(devices.cpu)
720-
shared.sd_model.first_stage_model.to(devices.cpu)
730+
#shared.sd_model.cond_stage_model.to(devices.cpu)
731+
#shared.sd_model.first_stage_model.to(devices.cpu)
732+
shared_instance.sd_model.cond_stage_model.to(devices.cpu)
733+
shared_instance.sd_model.first_stage_model.to(devices.cpu)
721734
torch.set_rng_state(rng_state)
722735
if torch.cuda.is_available():
723736
torch.cuda.set_rng_state_all(cuda_rng_state)
@@ -760,8 +773,10 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
760773

761774
del optimizer
762775
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
763-
shared.sd_model.cond_stage_model.to(devices.device)
764-
shared.sd_model.first_stage_model.to(devices.device)
776+
#shared.sd_model.cond_stage_model.to(devices.device)
777+
#shared.sd_model.first_stage_model.to(devices.device)
778+
shared_instance.sd_model.cond_stage_model.to(devices.device)
779+
shared_instance.sd_model.first_stage_model.to(devices.device)
765780
shared.parallel_processing_allowed = old_parallel_processing_allowed
766781

767782
return hypernetwork, filename

modules/hypernetworks/ui.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import gradio as gr
44
import modules.hypernetworks.hypernetwork
55
from modules import devices, sd_hijack, shared
6-
6+
from modules.shared import shared_instance
77
not_available = ["hardswish", "multiheadattention"]
88
keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
99

@@ -32,7 +32,9 @@ def train_hypernetwork(*args):
3232
except Exception:
3333
raise
3434
finally:
35-
shared.sd_model.cond_stage_model.to(devices.device)
36-
shared.sd_model.first_stage_model.to(devices.device)
35+
#shared.sd_model.cond_stage_model.to(devices.device)
36+
#shared.sd_model.first_stage_model.to(devices.device)
37+
shared_instance.sd_model.cond_stage_model.to(devices.device)
38+
shared_instance.sd_model.first_stage_model.to(devices.device)
3739
sd_hijack.apply_optimizations()
3840

0 commit comments

Comments
 (0)