diff --git a/modules/processing.py b/modules/processing.py old mode 100755 new mode 100644 index e60cc92b0..066351c12 --- a/modules/processing.py +++ b/modules/processing.py @@ -407,12 +407,14 @@ class StableDiffusionProcessing: self.main_prompt = self.all_prompts[0] self.main_negative_prompt = self.all_negative_prompts[0] - def cached_params(self, required_prompts, steps, extra_network_data): + def cached_params(self, required_prompts, steps, extra_network_data, hires_steps=None, use_old_scheduling=False): """Returns parameters that invalidate the cond cache if changed""" return ( required_prompts, steps, + hires_steps, + use_old_scheduling, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data, @@ -422,7 +424,7 @@ class StableDiffusionProcessing: self.height, ) - def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data): + def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None): """ Returns the result of calling function(shared.sd_model, required_prompts, steps) using a cache to store the result if the same arguments have been used before. @@ -435,7 +437,7 @@ class StableDiffusionProcessing: caches is a list with items described above. """ - cached_params = self.cached_params(required_prompts, steps, extra_network_data) + cached_params = self.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling) for cache in caches: if cache[0] is not None and cached_params == cache[0]: @@ -444,7 +446,7 @@ class StableDiffusionProcessing: cache = caches[0] with devices.autocast(): - cache[1] = function(shared.sd_model, required_prompts, steps) + cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling) cache[0] = cached_params return cache[1] @@ -456,6 +458,8 @@ class StableDiffusionProcessing: sampler_config = sd_samplers.find_sampler_config(self.sampler_name) total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps self.step_multiplier = total_steps // self.steps + self.firstpass_steps = total_steps + self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data) self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data) @@ -1292,8 +1296,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): steps = self.hr_second_pass_steps or self.steps total_steps = sampler_config.total_steps(steps) if sampler_config else steps - self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, total_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) - self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, total_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) + self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps) + self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps) def setup_conds(self): if self.is_hr_pass: diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index e811ae99d..334efeef3 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -26,7 +26,7 @@ plain: /([^\\\[\]():|]|\\.)+/ %import common.SIGNED_NUMBER -> NUMBER """) -def get_learned_conditioning_prompt_schedules(prompts, steps): +def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False): """ >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0] >>> g("test") @@ -57,18 +57,39 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']] >>> g("[fe|||]male") [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']] + >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0] + >>> g("a [b:.5] c") + [[10, 'a b c']] + >>> g("a [b:1.5] c") + [[5, 'a c'], [10, 'a b c']] """ + if hires_steps is None or use_old_scheduling: + int_offset = 0 + flt_offset = 0 + steps = base_steps + else: + int_offset = base_steps + flt_offset = 1.0 + steps = hires_steps + def collect_steps(steps, tree): res = [steps] class CollectSteps(lark.Visitor): def scheduled(self, tree): - tree.children[-2] = float(tree.children[-2]) - if tree.children[-2] < 1: - tree.children[-2] *= steps - tree.children[-2] = min(steps, int(tree.children[-2])) - res.append(tree.children[-2]) + s = tree.children[-2] + v = float(s) + if use_old_scheduling: + v = v*steps if v<1 else v + else: + if "." in s: + v = (v - flt_offset) * steps + else: + v = (v - int_offset) + tree.children[-2] = min(steps, int(v)) + if tree.children[-2] >= 1: + res.append(tree.children[-2]) def alternate(self, tree): res.extend(range(1, steps+1)) @@ -134,7 +155,7 @@ class SdConditioning(list): -def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): +def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False): """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), and the sampling step at which this condition is to be replaced by the next one. @@ -154,7 +175,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): """ res = [] - prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) + prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling) cache = {} for prompt, prompt_schedule in zip(prompts, prompt_schedules): @@ -229,7 +250,7 @@ class MulticondLearnedConditioning: self.batch: List[List[ComposableScheduledPromptConditioning]] = batch -def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: +def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning: """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. For each prompt, the list is obtained by splitting the prompt using the AND separator. @@ -238,7 +259,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts) - learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps) + learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling) res = [] for indexes in res_indexes: diff --git a/modules/shared_options.py b/modules/shared_options.py index 88f6b334c..d13898380 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -203,6 +203,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), + "use_old_scheduling": OptionInfo(False, "Use old prompt where first pass and hires both used the same timeline, and < 1 meant relative and >= 1 meant absolute"), })) options_templates.update(options_section(('interrogate', "Interrogate"), {