@@ -457,14 +457,20 @@ def cached_params(self, required_prompts, steps, extra_network_data, hires_steps
457
457
opts .emphasis ,
458
458
)
459
459
460
- def apply_generation_params_states (self , generation_params_states ):
460
+ def apply_generation_params_list (self , generation_params_states ):
461
461
"""add and apply generation_params_states to self.extra_generation_params"""
462
462
for key , value in generation_params_states .items ():
463
463
if key in self .extra_generation_params and isinstance (current_value := self .extra_generation_params [key ], util .GenerationParametersList ):
464
464
self .extra_generation_params [key ] = current_value + value
465
465
else :
466
466
self .extra_generation_params [key ] = value
467
467
468
+ def clear_marked_generation_params (self ):
469
+ """clears any generation parameters that are with the attribute to_be_clear_before_batch = True"""
470
+ for key , value in list (self .extra_generation_params .items ()):
471
+ if getattr (value , 'to_be_clear_before_batch' , False ):
472
+ self .extra_generation_params .pop (key )
473
+
468
474
def get_conds_with_caching (self , function , required_prompts , steps , caches , extra_network_data , hires_steps = None ):
469
475
"""
470
476
Returns the result of calling function(shared.sd_model, required_prompts, steps)
@@ -491,7 +497,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
491
497
if len (cache ) == 3 :
492
498
generation_params_states , cached_cached_params = cache [2 ]
493
499
if cached_params == cached_cached_params :
494
- self .apply_generation_params_states (generation_params_states )
500
+ self .apply_generation_params_list (generation_params_states )
495
501
return cache [1 ]
496
502
497
503
cache = caches [0 ]
@@ -500,7 +506,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
500
506
cache [1 ] = function (shared .sd_model , required_prompts , steps , hires_steps , shared .opts .use_old_scheduling )
501
507
502
508
generation_params_states = model_hijack .extract_generation_params_states ()
503
- self .apply_generation_params_states (generation_params_states )
509
+ self .apply_generation_params_list (generation_params_states )
504
510
if len (cache ) == 2 :
505
511
cache .append ((generation_params_states , cached_params ))
506
512
else :
@@ -959,6 +965,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
959
965
if state .interrupted or state .stopping_generation :
960
966
break
961
967
968
+ p .clear_marked_generation_params () # clean up some generation params are tagged to be cleared before batch
962
969
sd_models .reload_model_weights () # model can be changed for example by refiner
963
970
964
971
p .prompts = p .all_prompts [n * p .batch_size :(n + 1 ) * p .batch_size ]
0 commit comments