mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-11-21 03:11:40 +08:00
Add consistency decoder
This commit is contained in:
parent
9c1c0da026
commit
64fd916334
@ -3,7 +3,7 @@ from collections import namedtuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
|
||||
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, sd_vae_consistency, shared, sd_models
|
||||
from modules.shared import opts, state
|
||||
import k_diffusion.sampling
|
||||
|
||||
@ -31,7 +31,7 @@ def setup_img2img_steps(p, steps=None):
|
||||
return steps, t_enc
|
||||
|
||||
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
|
||||
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3, "Consistency Decoder": 4}
|
||||
|
||||
|
||||
def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||
@ -51,6 +51,13 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
|
||||
elif approximation == 3:
|
||||
x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
|
||||
x_sample = x_sample * 2 - 1
|
||||
elif approximation == 4:
|
||||
with devices.autocast(), torch.no_grad():
|
||||
x_sample = sd_vae_consistency.decoder_model()(
|
||||
sample.to(devices.device, devices.dtype)/0.18215,
|
||||
schedule=[1.0]
|
||||
)
|
||||
sd_vae_consistency.unload()
|
||||
else:
|
||||
if model is None:
|
||||
model = shared.sd_model
|
||||
|
35
modules/sd_vae_consistency.py
Normal file
35
modules/sd_vae_consistency.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""
|
||||
Consistency Decoder
|
||||
Improved decoding for stable diffusion vaes.
|
||||
|
||||
https://github.com/openai/consistencydecoder
|
||||
"""
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modules import devices, paths_internal, shared
|
||||
from consistencydecoder import ConsistencyDecoder
|
||||
|
||||
|
||||
sd_vae_consistency_models = None
|
||||
model_path = os.path.join(paths_internal.models_path, 'consistencydecoder')
|
||||
|
||||
|
||||
def decoder_model():
|
||||
global sd_vae_consistency_models
|
||||
if getattr(shared.sd_model, 'is_sdxl', False):
|
||||
raise NotImplementedError("SDXL is not supported for consistency decoder")
|
||||
if sd_vae_consistency_models is not None:
|
||||
sd_vae_consistency_models.ckpt.to(devices.device)
|
||||
return sd_vae_consistency_models
|
||||
|
||||
loaded_model = ConsistencyDecoder(devices.device, model_path)
|
||||
sd_vae_consistency_models = loaded_model
|
||||
return loaded_model
|
||||
|
||||
|
||||
def unload():
|
||||
global sd_vae_consistency_models
|
||||
if sd_vae_consistency_models is not None:
|
||||
sd_vae_consistency_models.ckpt.to('cpu')
|
@ -172,7 +172,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
|
||||
"sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
|
||||
"auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
|
||||
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
|
||||
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
|
||||
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD", "Consistency Decoder"]}, infotext='VAE Decoder').info("method to decode latent to image"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('img2img', "img2img"), {
|
||||
|
@ -32,3 +32,5 @@ torch
|
||||
torchdiffeq
|
||||
torchsde
|
||||
transformers==4.30.2
|
||||
|
||||
git+https://github.com/openai/consistencydecoder.git
|
||||
|
@ -30,3 +30,4 @@ torchdiffeq==0.2.3
|
||||
torchsde==0.2.6
|
||||
transformers==4.30.2
|
||||
httpx==0.24.1
|
||||
git+https://github.com/openai/consistencydecoder.git
|
||||
|
Loading…
Reference in New Issue
Block a user