18
18
19
19
from collections import deque
20
20
from statistics import stdev , mean
21
-
21
+ from modules . shared import shared_instance
22
22
23
23
optimizer_dict = {optim_name : cls_obj for optim_name , cls_obj in inspect .getmembers (torch .optim , inspect .isclass ) if optim_name != "Optimizer" }
24
24
@@ -525,7 +525,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
525
525
526
526
pin_memory = shared .opts .pin_memory
527
527
528
- ds = modules .textual_inversion .dataset .PersonalizedBase (data_root = data_root , width = training_width , height = training_height , repeats = shared .opts .training_image_repeats_per_epoch , placeholder_token = hypernetwork_name , model = shared .sd_model , cond_model = shared .sd_model .cond_stage_model , device = devices .device , template_file = template_file , include_cond = True , batch_size = batch_size , gradient_step = gradient_step , shuffle_tags = shuffle_tags , tag_drop_out = tag_drop_out , latent_sampling_method = latent_sampling_method , varsize = varsize , use_weight = use_weight )
528
+ #ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
529
+ ds = modules .textual_inversion .dataset .PersonalizedBase (data_root = data_root , width = training_width , height = training_height , repeats = shared .opts .training_image_repeats_per_epoch , placeholder_token = hypernetwork_name , model = shared_instance .sd_model , cond_model = shared_instance .sd_model .cond_stage_model , device = devices .device , template_file = template_file , include_cond = True , batch_size = batch_size , gradient_step = gradient_step , shuffle_tags = shuffle_tags , tag_drop_out = tag_drop_out , latent_sampling_method = latent_sampling_method , varsize = varsize , use_weight = use_weight )
529
530
530
531
if shared .opts .save_training_settings_to_txt :
531
532
saved_params = dict (
@@ -542,8 +543,10 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
542
543
543
544
if unload :
544
545
shared .parallel_processing_allowed = False
545
- shared .sd_model .cond_stage_model .to (devices .cpu )
546
- shared .sd_model .first_stage_model .to (devices .cpu )
546
+ #shared.sd_model.cond_stage_model.to(devices.cpu)
547
+ #shared.sd_model.first_stage_model.to(devices.cpu)
548
+ shared_instance .sd_model .cond_stage_model .to (devices .cpu )
549
+ shared_instance .sd_model .first_stage_model .to (devices .cpu )
547
550
548
551
weights = hypernetwork .weights ()
549
552
hypernetwork .train ()
@@ -614,16 +617,21 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
614
617
if use_weight :
615
618
w = batch .weight .to (devices .device , non_blocking = pin_memory )
616
619
if tag_drop_out != 0 or shuffle_tags :
617
- shared .sd_model .cond_stage_model .to (devices .device )
618
- c = shared .sd_model .cond_stage_model (batch .cond_text ).to (devices .device , non_blocking = pin_memory )
619
- shared .sd_model .cond_stage_model .to (devices .cpu )
620
+ #shared.sd_model.cond_stage_model.to(devices.device)
621
+ shared_instance .sd_model .cond_stage_model .to (devices .device )
622
+ c = shared_instance .sd_model .cond_stage_model (batch .cond_text ).to (devices .device , non_blocking = pin_memory )
623
+ shared_instance .sd_model .cond_stage_model .to (devices .cpu )
624
+ #c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
625
+ #shared.sd_model.cond_stage_model.to(devices.cpu)
620
626
else :
621
627
c = stack_conds (batch .cond ).to (devices .device , non_blocking = pin_memory )
622
628
if use_weight :
623
- loss = shared .sd_model .weighted_forward (x , c , w )[0 ] / gradient_step
629
+ loss = shared_instance .sd_model .weighted_forward (x , c , w )[0 ] / gradient_step
630
+ #loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
624
631
del w
625
632
else :
626
- loss = shared .sd_model .forward (x , c )[0 ] / gradient_step
633
+ #loss = shared.sd_model.forward(x, c)[0] / gradient_step
634
+ loss = shared_instance .sd_model .forward (x , c )[0 ] / gradient_step
627
635
del x
628
636
del c
629
637
@@ -683,11 +691,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
683
691
cuda_rng_state = None
684
692
if torch .cuda .is_available ():
685
693
cuda_rng_state = torch .cuda .get_rng_state_all ()
686
- shared .sd_model .cond_stage_model .to (devices .device )
687
- shared .sd_model .first_stage_model .to (devices .device )
694
+ #shared.sd_model.cond_stage_model.to(devices.device)
695
+ #shared.sd_model.first_stage_model.to(devices.device)
696
+ shared_instance .sd_model .cond_stage_model .to (devices .device )
697
+ shared_instance .sd_model .first_stage_model .to (devices .device )
688
698
689
699
p = processing .StableDiffusionProcessingTxt2Img (
690
- sd_model = shared .sd_model ,
700
+ #sd_model=shared.sd_model,
701
+ sd_model = shared_instance .sd_model ,
691
702
do_not_save_grid = True ,
692
703
do_not_save_samples = True ,
693
704
)
@@ -716,8 +727,10 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
716
727
image = processed .images [0 ] if len (processed .images ) > 0 else None
717
728
718
729
if unload :
719
- shared .sd_model .cond_stage_model .to (devices .cpu )
720
- shared .sd_model .first_stage_model .to (devices .cpu )
730
+ #shared.sd_model.cond_stage_model.to(devices.cpu)
731
+ #shared.sd_model.first_stage_model.to(devices.cpu)
732
+ shared_instance .sd_model .cond_stage_model .to (devices .cpu )
733
+ shared_instance .sd_model .first_stage_model .to (devices .cpu )
721
734
torch .set_rng_state (rng_state )
722
735
if torch .cuda .is_available ():
723
736
torch .cuda .set_rng_state_all (cuda_rng_state )
@@ -760,8 +773,10 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
760
773
761
774
del optimizer
762
775
hypernetwork .optimizer_state_dict = None # dereference it after saving, to save memory.
763
- shared .sd_model .cond_stage_model .to (devices .device )
764
- shared .sd_model .first_stage_model .to (devices .device )
776
+ #shared.sd_model.cond_stage_model.to(devices.device)
777
+ #shared.sd_model.first_stage_model.to(devices.device)
778
+ shared_instance .sd_model .cond_stage_model .to (devices .device )
779
+ shared_instance .sd_model .first_stage_model .to (devices .device )
765
780
shared .parallel_processing_allowed = old_parallel_processing_allowed
766
781
767
782
return hypernetwork , filename
0 commit comments