diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 58efcad23..5f91f801a 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -3,7 +3,7 @@ from collections import namedtuple import numpy as np import torch from PIL import Image -from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models +from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, sd_vae_consistency, shared, sd_models from modules.shared import opts, state import k_diffusion.sampling @@ -31,7 +31,7 @@ def setup_img2img_steps(p, steps=None): return steps, t_enc -approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3} +approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3, "Consistency Decoder": 4} def samples_to_images_tensor(sample, approximation=None, model=None): @@ -51,6 +51,13 @@ def samples_to_images_tensor(sample, approximation=None, model=None): elif approximation == 3: x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach() x_sample = x_sample * 2 - 1 + elif approximation == 4: + with devices.autocast(), torch.no_grad(): + x_sample = sd_vae_consistency.decoder_model()( + sample.to(devices.device, devices.dtype)/0.18215, + schedule=[1.0] + ) + sd_vae_consistency.unload() else: if model is None: model = shared.sd_model diff --git a/modules/sd_vae_consistency.py b/modules/sd_vae_consistency.py new file mode 100644 index 000000000..47e86e6b1 --- /dev/null +++ b/modules/sd_vae_consistency.py @@ -0,0 +1,35 @@ +""" +Consistency Decoder +Improved decoding for stable diffusion vaes. + +https://github.com/openai/consistencydecoder +""" +import os +import torch +import torch.nn as nn + +from modules import devices, paths_internal, shared +from consistencydecoder import ConsistencyDecoder + + +sd_vae_consistency_models = None +model_path = os.path.join(paths_internal.models_path, 'consistencydecoder') + + +def decoder_model(): + global sd_vae_consistency_models + if getattr(shared.sd_model, 'is_sdxl', False): + raise NotImplementedError("SDXL is not supported for consistency decoder") + if sd_vae_consistency_models is not None: + sd_vae_consistency_models.ckpt.to(devices.device) + return sd_vae_consistency_models + + loaded_model = ConsistencyDecoder(devices.device, model_path) + sd_vae_consistency_models = loaded_model + return loaded_model + + +def unload(): + global sd_vae_consistency_models + if sd_vae_consistency_models is not None: + sd_vae_consistency_models.ckpt.to('cpu') \ No newline at end of file diff --git a/modules/shared_options.py b/modules/shared_options.py index a9964fcbb..002a75454 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -172,7 +172,7 @@ For img2img, VAE is used to process user's input image before the sampling, and "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"), "auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"), - "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"), + "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD", "Consistency Decoder"]}, infotext='VAE Decoder').info("method to decode latent to image"), })) options_templates.update(options_section(('img2img', "img2img"), { diff --git a/requirements.txt b/requirements.txt index 80b438455..a6ba6a7a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,3 +32,5 @@ torch torchdiffeq torchsde transformers==4.30.2 + +git+https://github.com/openai/consistencydecoder.git diff --git a/requirements_versions.txt b/requirements_versions.txt index cb7403a9d..990c7da0f 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -30,3 +30,4 @@ torchdiffeq==0.2.3 torchsde==0.2.6 transformers==4.30.2 httpx==0.24.1 +git+https://github.com/openai/consistencydecoder.git