From 17e846150c49395e44550e18cbe5120fcf64c173 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Tue, 28 May 2024 19:35:35 -0400 Subject: [PATCH] Add process_before_every_sampling hook --- modules/processing.py | 24 ++++++++++++++++++++++++ modules/scripts.py | 15 +++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/modules/processing.py b/modules/processing.py index 91cb94db1..79a3f0a72 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1330,6 +1330,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): # here we generate an image normally x = self.rng.next() + if self.scripts is not None: + self.scripts.process_before_every_sampling( + p=self, + x=x, + noise=x, + c=conditioning, + uc=unconditional_conditioning + ) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) del x @@ -1430,6 +1439,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.scripts is not None: self.scripts.before_hr(self) + self.scripts.process_before_every_sampling( + p=self, + x=samples, + noise=noise, + c=self.hr_c, + uc=self.hr_uc, + ) samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) @@ -1743,6 +1759,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier x *= self.initial_noise_multiplier + if self.scripts is not None: + self.scripts.process_before_every_sampling( + p=self, + x=self.init_latent, + noise=x, + c=conditioning, + uc=unconditional_conditioning + ) samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) if self.mask is not None: diff --git a/modules/scripts.py b/modules/scripts.py index 70ccfbe46..8eca396b1 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -187,6 +187,13 @@ class Script: """ pass + def process_before_every_sampling(self, p, *args, **kwargs): + """ + Similar to process(), called before every sampling. + If you use high-res fix, this will be called two times. + """ + pass + def process_batch(self, p, *args, **kwargs): """ Same as process(), but called for every batch. @@ -826,6 +833,14 @@ class ScriptRunner: except Exception: errors.report(f"Error running process: {script.filename}", exc_info=True) + def process_before_every_sampling(self, p, **kwargs): + for script in self.ordered_scripts('process_before_every_sampling'): + try: + script_args = p.script_args[script.args_from:script.args_to] + script.process_before_every_sampling(p, *script_args, **kwargs) + except Exception: + errors.report(f"Error running process_before_every_sampling: {script.filename}", exc_info=True) + def before_process_batch(self, p, **kwargs): for script in self.ordered_scripts('before_process_batch'): try: