Skip to content

Commit fb6a050

Browse files
committed
clear GenerationParametersList before batch
clears any generation parameters that are with the attribute to_be_clear_before_batch = True prevent buildup of some parameters
1 parent 0250802 commit fb6a050

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

modules/processing.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -457,14 +457,20 @@ def cached_params(self, required_prompts, steps, extra_network_data, hires_steps
457457
opts.emphasis,
458458
)
459459

460-
def apply_generation_params_states(self, generation_params_states):
460+
def apply_generation_params_list(self, generation_params_states):
461461
"""add and apply generation_params_states to self.extra_generation_params"""
462462
for key, value in generation_params_states.items():
463463
if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList):
464464
self.extra_generation_params[key] = current_value + value
465465
else:
466466
self.extra_generation_params[key] = value
467467

468+
def clear_marked_generation_params(self):
469+
"""clears any generation parameters that are with the attribute to_be_clear_before_batch = True"""
470+
for key, value in list(self.extra_generation_params.items()):
471+
if getattr(value, 'to_be_clear_before_batch', False):
472+
self.extra_generation_params.pop(key)
473+
468474
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
469475
"""
470476
Returns the result of calling function(shared.sd_model, required_prompts, steps)
@@ -491,7 +497,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
491497
if len(cache) == 3:
492498
generation_params_states, cached_cached_params = cache[2]
493499
if cached_params == cached_cached_params:
494-
self.apply_generation_params_states(generation_params_states)
500+
self.apply_generation_params_list(generation_params_states)
495501
return cache[1]
496502

497503
cache = caches[0]
@@ -500,7 +506,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
500506
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
501507

502508
generation_params_states = model_hijack.extract_generation_params_states()
503-
self.apply_generation_params_states(generation_params_states)
509+
self.apply_generation_params_list(generation_params_states)
504510
if len(cache) == 2:
505511
cache.append((generation_params_states, cached_params))
506512
else:
@@ -959,6 +965,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
959965
if state.interrupted or state.stopping_generation:
960966
break
961967

968+
p.clear_marked_generation_params() # clean up some generation params are tagged to be cleared before batch
962969
sd_models.reload_model_weights() # model can be changed for example by refiner
963970

964971
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]

modules/util.py

+8
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,17 @@ class GenerationParametersList(list):
308308
if return str, the value will be written to infotext, if return None will be ignored.
309309
"""
310310

311+
def __init__(self, *args, to_be_clear_before_batch=True, **kwargs):
312+
super().__init__(*args, **kwargs)
313+
self._to_be_clear_before_batch = to_be_clear_before_batch
314+
311315
def __call__(self, *args, **kwargs):
312316
return ', '.join(sorted(set(self), key=natural_sort_key))
313317

318+
@property
319+
def to_be_clear_before_batch(self):
320+
return self.to_be_clear_before_batch
321+
314322
def __add__(self, other):
315323
if isinstance(other, GenerationParametersList):
316324
return self.__class__([*self, *other])

0 commit comments

Comments
 (0)