From 1a0353675de8b2f4d2ce784a37fe4d6121307131 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 23 Sep 2022 17:37:47 +0300 Subject: [PATCH] Option to use advanced upscalers with normal img2img --- modules/images.py | 15 ++++++++++++--- modules/processing.py | 4 ++-- modules/shared.py | 2 +- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/modules/images.py b/modules/images.py index 6cf56ddbe..c88ae4751 100644 --- a/modules/images.py +++ b/modules/images.py @@ -209,8 +209,16 @@ def draw_prompt_matrix(im, width, height, all_prompts): def resize_image(resize_mode, im, width, height): + def resize(im, w, h): + if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": + return im.resize((w, h), resample=LANCZOS) + + upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0] + return upscaler.upscale(im, w, h) + if resize_mode == 0: - res = im.resize((width, height), resample=LANCZOS) + res = resize(im, width, height) + elif resize_mode == 1: ratio = width / height src_ratio = im.width / im.height @@ -218,9 +226,10 @@ def resize_image(resize_mode, im, width, height): src_w = width if ratio > src_ratio else im.width * height // im.height src_h = height if ratio <= src_ratio else im.height * width // im.width - resized = im.resize((src_w, src_h), resample=LANCZOS) + resized = resize(im, src_w, src_h) res = Image.new("RGB", (width, height)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) + else: ratio = width / height src_ratio = im.width / im.height @@ -228,7 +237,7 @@ def resize_image(resize_mode, im, width, height): src_w = width if ratio < src_ratio else im.width * height // im.height src_h = height if ratio >= src_ratio else im.height * width // im.width - resized = im.resize((src_w, src_h), resample=LANCZOS) + resized = resize(im, src_w, src_h) res = Image.new("RGB", (width, height)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) diff --git a/modules/processing.py b/modules/processing.py index d27d86e99..79a159a2e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -462,7 +462,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: decoded_samples = self.sd_model.decode_first_stage(samples) - if opts.upscaler_for_hires_fix is None or opts.upscaler_for_hires_fix == "None": + if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") else: lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) @@ -472,7 +472,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) image = Image.fromarray(x_sample) - upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_hires_fix][0] + upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0] image = upscaler.upscale(image, self.width, self.height) image = np.array(image).astype(np.float32) / 255.0 image = np.moveaxis(image, 2, 0) diff --git a/modules/shared.py b/modules/shared.py index 0d7b8623a..667c3441f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -168,7 +168,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "ldsr_pre_down": OptionInfo(1, "LDSR Pre-process downssample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}), "ldsr_post_down": OptionInfo(1, "LDSR Post-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}), - "upscaler_for_hires_fix": OptionInfo(None, "Upscaler for highres. fix", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}), + "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}), })) options_templates.update(options_section(('face-restoration', "Face restoration"), {