2022-08-28 02:32:28 +08:00
import argparse
import os
import sys
2022-08-26 04:31:44 +08:00
from collections import namedtuple
2022-08-22 22:15:46 +08:00
import torch
import torch . nn as nn
import numpy as np
import gradio as gr
from omegaconf import OmegaConf
2022-08-24 22:57:49 +08:00
from PIL import Image , ImageFont , ImageDraw , PngImagePlugin
2022-08-22 22:15:46 +08:00
from torch import autocast
import mimetypes
import random
2022-08-23 01:08:32 +08:00
import math
2022-08-24 23:47:23 +08:00
import html
import time
2022-08-26 02:52:05 +08:00
import json
import traceback
2022-08-22 22:15:46 +08:00
2022-08-26 04:31:44 +08:00
import k_diffusion . sampling
2022-08-22 22:15:46 +08:00
from ldm . util import instantiate_from_config
from ldm . models . diffusion . ddim import DDIMSampler
from ldm . models . diffusion . plms import PLMSSampler
2022-08-23 16:58:50 +08:00
# Quiet the "Some weights of the model checkpoint were not used when initializing..." message
# that transformers prints at startup. This is best-effort: any failure here is ignored.
try:
    from transformers import logging
    logging.set_verbosity_error()
except Exception:
    pass

# Windows fix: unless the type is registered, .js files are served with a text/html
# content-type and the browser does not show any UI.
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')

# Fixed model constants, deliberately kept out of the user options because changing them
# breaks the model. (Presumably latent channel count and downsampling factor, as in the
# upstream SD scripts; confirm.)
opt_C = 4
opt_f = 8

# Newer Pillow exposes resampling filters under Image.Resampling; older versions only
# have the module-level constant.
if hasattr(Image, 'Resampling'):
    LANCZOS = Image.Resampling.LANCZOS
else:
    LANCZOS = Image.LANCZOS

# Characters removed from prompt text before it is used in filenames.
invalid_filename_chars = '<>:"/\\|?*\n'

# NOTE(review): presumably the JSON file that Options.save/load use; confirm at the call site.
config_filename = "config.json"
2022-08-23 05:34:49 +08:00
2022-08-22 22:15:46 +08:00
# Command-line options, parsed once at import time into cmd_opts.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    type=str,
    default="configs/stable-diffusion/v1-inference.yaml",
    help="path to config which constructs model",
)
parser.add_argument(
    "--ckpt",
    type=str,
    default="models/ldm/stable-diffusion-v1/model.ckpt",
    help="path to checkpoint of model",
)
parser.add_argument(
    "--gfpgan-dir",
    type=str,
    help="GFPGAN directory",
    # Prefer a checkout under ./src/gfpgan when it exists; otherwise fall back to ./GFPGAN.
    default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'),
)
parser.add_argument(
    "--no-half",
    action='store_true',
    help="do not switch the model to 16-bit floats",
)
parser.add_argument(
    "--no-progressbar-hiding",
    action='store_true',
    help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)",
)
parser.add_argument(
    "--max-batch-count",
    type=int,
    default=16,
    help="maximum batch count value for the UI",
)
parser.add_argument(
    "--embeddings-dir",
    type=str,
    default='embeddings',
    help="embeddings dirtectory for textual inversion (default: embeddings)",
)

cmd_opts = parser.parse_args()

# CSS that hides gradio's progress indicators (see --no-progressbar-hiding) and shows a plain
# "Loading..." label instead of the spinner.
css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
"""
2022-08-22 22:15:46 +08:00
2022-08-26 04:31:44 +08:00
# A sampler offered in the UI: its display name and a zero-argument factory. The sampler object
# is only built when the factory is called.
SamplerData = namedtuple('SamplerData', ['name', 'constructor'])

# k-diffusion samplers as (UI label, function name in k_diffusion.sampling). An entry is only
# offered when the installed k_diffusion provides that function.
samplers = [
    *[
        SamplerData(label, lambda funcname=funcname: KDiffusionSampler(funcname))
        for label, funcname in [
            ('LMS', 'sample_lms'),
            ('Heun', 'sample_heun'),
            ('Euler', 'sample_euler'),
            ('Euler ancestral', 'sample_euler_ancestral'),
            ('DPM 2', 'sample_dpm_2'),
            ('DPM 2 Ancestral', 'sample_dpm_2_ancestral'),
        ]
        if hasattr(k_diffusion.sampling, funcname)
    ],
    # The CompVis samplers that ship with ldm.
    SamplerData('DDIM', lambda: VanillaStableDiffusionSampler(DDIMSampler)),
    SamplerData('PLMS', lambda: VanillaStableDiffusionSampler(PLMSSampler)),
]

# img2img offers every sampler except DDIM and PLMS.
samplers_for_img2img = [sampler for sampler in samplers if sampler.name not in ('DDIM', 'PLMS')]
2022-08-26 02:52:05 +08:00
2022-08-26 16:16:57 +08:00
# One selectable Real-ESRGAN upscaler. `location` is the URL of its weights, `model` is a
# zero-argument factory building the network, and `netscale` is the network's upscale factor.
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])

# Real-ESRGAN is optional: if its packages cannot be imported, report the error and fall back
# to a single placeholder entry, with have_realesrgan = False.
try:
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from realesrgan import RealESRGANer
    from realesrgan.archs.srvgg_arch import SRVGGNetCompact

    realesrgan_models = [
        RealesrganModelInfo(
            name="Real-ESRGAN 4x plus",
            location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
            netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
        ),
        RealesrganModelInfo(
            name="Real-ESRGAN 4x plus anime 6B",
            location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
            netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        ),
        RealesrganModelInfo(
            name="Real-ESRGAN 2x plus",
            location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
            netscale=2, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
        ),
    ]
    have_realesrgan = True
except Exception:
    print("Error loading Real-ESRGAN:", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)

    # Placeholder entry so the list is never empty. Fields are passed by keyword: the old
    # positional call put 0 into `model` and None into `netscale`.
    realesrgan_models = [RealesrganModelInfo(name='None', location='', model=None, netscale=0)]
    have_realesrgan = False
2022-08-26 02:52:05 +08:00
class Options:
    """User-editable settings, read and written as attributes (e.g. ``opts.jpeg_quality``).

    ``data_labels`` declares every known option: its default, UI label and optional gradio
    component. ``data`` holds the current values and is what save()/load() persist as JSON.
    """

    class OptionInfo:
        """Declaration of one option: default value, UI label, and optional gradio component + kwargs."""

        def __init__(self, default=None, label="", component=None, component_args=None):
            self.default = default
            self.label = label
            self.component = component
            self.component_args = component_args

    data = None
    data_labels = {
        "outdir": OptionInfo("", "Output dictectory; if empty, defaults to 'outputs/*'"),
        "samples_save": OptionInfo(True, "Save indiviual samples"),
        "samples_format": OptionInfo('png', 'File format for indiviual samples'),
        "grid_save": OptionInfo(True, "Save image grids"),
        "grid_format": OptionInfo('png', 'File format for grids'),
        "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
        "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
        "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
        "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
        "prompt_matrix_add_to_start": OptionInfo(True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
        "sd_upscale_overlap": OptionInfo(64, "Overlap for tiles for SD upscale. The smaller it is, the less smooth transition from one tile to another", gr.Slider, {"minimum": 0, "maximum": 256, "step": 16}),
    }

    def __init__(self):
        self.data = {k: v.default for k, v in self.data_labels.items()}

    def __setattr__(self, key, value):
        # Values for keys already in `data` are mirrored there so save() persists them.
        if self.data is not None:
            if key in self.data:
                self.data[key] = value

        return super(Options, self).__setattr__(key, value)

    def __getattr__(self, item):
        # Only called when normal attribute lookup fails: serve current values from `data`,
        # then declared defaults.
        if self.data is not None:
            if item in self.data:
                return self.data[item]

        if item in self.data_labels:
            return self.data_labels[item].default

        return super(Options, self).__getattribute__(item)

    def save(self, filename):
        """Write the current option values to ``filename`` as JSON."""
        with open(filename, "w", encoding="utf8") as file:
            json.dump(self.data, file)

    def load(self, filename):
        """Load option values from the JSON file ``filename``.

        Values from the file are laid over the declared defaults. Options missing from an
        older file therefore stay tracked in ``data``, so later assignments to them are saved.
        Keys in the file that are not declared are kept as-is.
        """
        with open(filename, "r", encoding="utf8") as file:
            loaded = json.load(file)

        data = {k: v.default for k, v in self.data_labels.items()}
        data.update(loaded)
        self.data = data
2022-08-22 22:15:46 +08:00
def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate ``config.model`` and load the weights stored in checkpoint ``ckpt``.

    The checkpoint is read onto the CPU, and its ``"state_dict"`` is applied with
    ``strict=False``. With ``verbose``, missing and unexpected keys are printed. The model is
    then moved to CUDA and switched to eval mode before being returned.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]

    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    if verbose:
        if missing:
            print("missing keys:")
            print(missing)
        if unexpected:
            print("unexpected keys:")
            print(unexpected)

    model.cuda()
    model.eval()
    return model
2022-08-24 03:42:43 +08:00
def create_random_tensors(shape, seeds):
    """Return one standard-normal tensor of ``shape`` per seed, stacked along a new first dim.

    The global torch RNG is reseeded before each draw, so the result for a given seed does not
    depend on its position in ``seeds``.
    """
    noise = []
    for seed in seeds:
        torch.manual_seed(seed)
        # randn output depends on the device: the same seed gives different results on GPU
        # and CPU. Drawing on the CPU would make results portable, but the original script
        # draws on `device`, and switching would break everyone's existing seeds.
        noise.append(torch.randn(shape, device=device))
    return torch.stack(noise)
2022-08-24 21:12:33 +08:00
def torch_gc():
    """Return cached, unused CUDA memory to the driver and collect CUDA IPC memory.

    Does nothing when CUDA is unavailable. Without this check, ``torch.cuda.ipc_collect()``
    would try to initialise CUDA and raise on CPU-only machines.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
2022-08-23 19:07:37 +08:00
2022-08-24 22:41:37 +08:00
2022-08-24 22:57:49 +08:00
def save_image(image, path, basename, seed, prompt, extension, info=None, short_filename=False):
    """Save ``image`` into directory ``path``.

    The filename is ``{basename}.{extension}`` when ``short_filename`` is set. Otherwise it is
    ``{basename}-{seed}-{prompt}.{extension}``, using the sanitized prompt (at most 128 chars).
    For PNG output with ``opts.enable_pnginfo`` on and ``info`` given, ``info`` is embedded as a
    "parameters" text chunk. ``opts.jpeg_quality`` is always passed to ``image.save``.
    """
    safe_prompt = sanitize_filename_part(prompt)
    if short_filename:
        filename = f"{basename}.{extension}"
    else:
        filename = f"{basename}-{seed}-{safe_prompt[:128]}.{extension}"

    pnginfo = None
    if extension == 'png' and opts.enable_pnginfo and info is not None:
        pnginfo = PngImagePlugin.PngInfo()
        pnginfo.add_text("parameters", info)

    image.save(os.path.join(path, filename), quality=opts.jpeg_quality, pnginfo=pnginfo)
2022-08-24 22:41:37 +08:00
2022-08-26 16:16:57 +08:00
def sanitize_filename_part(text):
    """Make ``text`` usable inside a filename.

    Spaces become underscores, every character in ``invalid_filename_chars`` is removed, and
    the result is cut to 128 characters.
    """
    deletions = str.maketrans('', '', invalid_filename_chars)
    underscored = text.replace(' ', '_')
    return underscored.translate(deletions)[:128]
2022-08-24 23:47:23 +08:00
def plaintext_to_html(text):
    """Render plain text as HTML: one HTML-escaped ``<p>`` paragraph per line of ``text``."""
    paragraphs = (f"<p>{html.escape(line)}</p>\n" for line in text.split('\n'))
    return "".join(paragraphs)
2022-08-28 02:32:28 +08:00
def load_gfpgan():
    """Build a GFPGANer face restorer from the GFPGAN checkout at ``cmd_opts.gfpgan_dir``.

    Raises Exception if the GFPGANv1.3 weights are not present under
    ``experiments/pretrained_models``. The checkout is appended to ``sys.path`` so that its
    ``gfpgan`` package can be imported.
    """
    weights_name = 'GFPGANv1.3'
    model_path = os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models', weights_name + '.pth')
    if not os.path.isfile(model_path):
        raise Exception(f"GFPGAN model not found at path {model_path}")

    sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
    from gfpgan import GFPGANer

    return GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
2022-08-26 23:04:00 +08:00
def image_grid(imgs, batch_size, force_n_rows=None):
    """Paste ``imgs`` row by row into one black-background grid image.

    The row count is chosen in this order: ``force_n_rows`` if given; else ``opts.n_rows`` if
    positive; ``batch_size`` if ``opts.n_rows`` is 0; otherwise round(sqrt(len(imgs))). The
    column count follows from it. Every cell has the size of ``imgs[0]``.
    """
    if force_n_rows is not None:
        n_rows = force_n_rows
    elif opts.n_rows > 0:
        n_rows = opts.n_rows
    elif opts.n_rows == 0:
        n_rows = batch_size
    else:
        n_rows = math.sqrt(len(imgs))
    n_rows = round(n_rows)
    n_cols = math.ceil(len(imgs) / n_rows)

    cell_w, cell_h = imgs[0].size
    grid = Image.new('RGB', size=(n_cols * cell_w, n_rows * cell_h), color='black')

    for index, img in enumerate(imgs):
        row, col = divmod(index, n_cols)
        grid.paste(img, box=(col * cell_w, row * cell_h))
    return grid
2022-08-23 01:08:32 +08:00
2022-08-27 21:13:33 +08:00
# Result of split_grid. `tiles` is a list of [y, tile_h, row] entries, where each row is a list
# of [x, tile_w, tile_image] entries; the remaining fields record the tile size, the original
# image size and the overlap used.
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])


def split_grid(image, tile_w=512, tile_h=512, overlap=64):
    """Cut ``image`` into a grid of ``tile_w`` x ``tile_h`` tiles overlapping by ``overlap`` px.

    Tiles start every ``tile - overlap`` pixels. A tile that would reach or pass the right or
    bottom edge is shifted back so it ends exactly on that edge.
    """
    img_w, img_h = image.width, image.height
    step_x = tile_w - overlap
    step_y = tile_h - overlap
    n_cols = math.ceil((img_w - overlap) / step_x)
    n_rows = math.ceil((img_h - overlap) / step_y)

    def clamp(origin, size, limit):
        # Pull a tile back inside the image when it would touch or cross the far edge.
        return limit - size if origin + size >= limit else origin

    grid = Grid([], tile_w, tile_h, img_w, img_h, overlap)
    for r in range(n_rows):
        top = clamp(r * step_y, tile_h, img_h)
        row_tiles = []
        for c in range(n_cols):
            left = clamp(c * step_x, tile_w, img_w)
            row_tiles.append([left, tile_w, image.crop((left, top, left + tile_w, top + tile_h))])
        grid.tiles.append([top, tile_h, row_tiles])
    return grid
def combine_grid(grid):
    """Reassemble a Grid (as produced by split_grid) into one RGB image.

    Within a row, each tile after the first is blended into its left neighbour across the
    ``grid.overlap`` pixels they share, using a linear 0..255 mask ramp. Whole rows are then
    blended into each other vertically the same way. The layout split_grid produces is
    expected: the first tile of each row at x == 0 and the first row at y == 0.
    """
    def make_mask_image(r):
        # Scale a 0..overlap-1 ramp to 0..255 and wrap it as an 8-bit PIL mask.
        r = r * 255 / grid.overlap
        r = r.astype(np.uint8)
        return Image.fromarray(r, 'L')

    if grid.overlap <= 0:
        # No overlap band to blend (the UI slider allows 0): paste tiles directly in order.
        # This also avoids building zero-width mask images.
        combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
        for y, h, row in grid.tiles:
            for x, w, tile in row:
                combined_image.paste(tile, (x, y))
        return combined_image

    # dtype=float rather than np.float: that alias was removed in NumPy 1.24 and always meant
    # the builtin float (float64).
    mask_w = make_mask_image(np.arange(grid.overlap, dtype=float).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
    mask_h = make_mask_image(np.arange(grid.overlap, dtype=float).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))

    combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        combined_row = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0:
                combined_row.paste(tile, (0, 0))
                continue

            # Blend the overlapping left strip over what is already there, then paste the rest.
            combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
            combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))

        if y == 0:
            combined_image.paste(combined_row, (0, 0))
            continue

        combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
        combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))

    return combined_image
2022-08-23 23:04:13 +08:00
def draw_prompt_matrix(im, width, height, all_prompts):
    """Add labels to a prompt-matrix grid image showing which prompt parts each cell uses.

    ``im`` is a grid of cells, each ``width`` x ``height`` pixels. ``all_prompts[0]`` is not
    drawn. The remaining parts are split in two: the first ceil(n/2) label the columns in a
    band added above the grid, and the rest label the rows in a band added to its left. For
    column (or row) index ``pos``, part ``i`` of that group is drawn black when bit ``i`` of
    ``pos`` is set; otherwise it is drawn grey and struck through.

    Returns a new white image with ``im`` pasted after the top/left label bands.

    NOTE(review): ImageFont.truetype("arial.ttf", ...) needs Pillow to find an Arial font file;
    on systems without one this call fails. Confirm whether a fallback is wanted.
    """
    def wrap(text, font, line_length):
        # Greedy word wrap: keep adding words to the current line while it measures no wider
        # than line_length pixels with `font`. A first word that is already too wide leaves
        # an empty first line.
        lines = ['']
        for word in text.split():
            line = f'{lines[-1]} {word}'.strip()
            if d.textlength(line, font=font) <= line_length:
                lines[-1] = line
            else:
                lines.append(word)
        return '\n'.join(lines)

    def draw_texts(pos, draw_x, draw_y, texts, sizes):
        # Draw `texts` stacked vertically, centered on draw_x, starting at draw_y.
        for i, (text, size) in enumerate(zip(texts, sizes)):
            # Part i is "on" for this column/row when bit i of its index is set.
            active = pos & (1 << i) != 0

            if not active:
                # Strike through by putting U+0336 COMBINING LONG STROKE OVERLAY after each char.
                text = '\u0336'.join(text) + '\u0336'

            d.multiline_text((draw_x, draw_y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")

            draw_y += size[1] + line_spacing

    fontsize = (width + height) // 25
    line_spacing = fontsize // 2
    fnt = ImageFont.truetype("arial.ttf", fontsize)
    color_active = (0, 0, 0)
    color_inactive = (153, 153, 153)

    # The left band only exists when there are enough parts to need row labels.
    pad_top = height // 4
    pad_left = width * 3 // 4 if len(all_prompts) > 2 else 0

    cols = im.width // width
    rows = im.height // height

    prompts = all_prompts[1:]

    result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
    result.paste(im, (pad_left, pad_top))

    # `d` is also used by the nested helpers above; they are only called after this point.
    d = ImageDraw.Draw(result)

    boundary = math.ceil(len(prompts) / 2)
    prompts_horiz = [wrap(x, fnt, width) for x in prompts[:boundary]]
    prompts_vert = [wrap(x, fnt, pad_left) for x in prompts[boundary:]]

    # (width, height) of each wrapped label, used to center the stacked labels in their band.
    sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]]
    sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]]
    hor_text_height = sum([x[1] + line_spacing for x in sizes_hor]) - line_spacing
    ver_text_height = sum([x[1] + line_spacing for x in sizes_ver]) - line_spacing

    for col in range(cols):
        x = pad_left + width * col + width / 2
        y = pad_top / 2 - hor_text_height / 2

        draw_texts(col, x, y, prompts_horiz, sizes_hor)

    for row in range(rows):
        x = pad_left / 2
        y = pad_top + height * row + height / 2 - ver_text_height / 2

        draw_texts(row, x, y, prompts_vert, sizes_ver)

    return result
2022-08-24 15:52:41 +08:00
def resize_image(resize_mode, im, width, height):
    """Return ``im`` resized to exactly ``width`` x ``height``.

    resize_mode:
        0 -- stretch to the target size, ignoring aspect ratio.
        1 -- keep the aspect ratio, scale so the image covers the target, and center-crop.
        any other value -- keep the aspect ratio, scale so the image fits inside the target,
             center it, and fill the leftover bands by stretching the image's edge rows or
             columns.
    """
    if resize_mode == 0:
        return im.resize((width, height), resample=LANCZOS)

    ratio = width / height
    src_ratio = im.width / im.height

    if resize_mode == 1:
        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width
    else:
        src_w = width if ratio < src_ratio else im.width * height // im.height
        src_h = height if ratio >= src_ratio else im.height * width // im.width

    resized = im.resize((src_w, src_h), resample=LANCZOS)
    res = Image.new("RGB", (width, height))
    res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

    if resize_mode != 1:
        # Fill the empty bands. A zero-extent source box stretches the outermost row/column.
        # The band can be 0 px (e.g. odd target sizes), and resizing to a 0-sized image raises,
        # so skip empty bands.
        if ratio < src_ratio:
            fill_height = height // 2 - src_h // 2
            if fill_height > 0:
                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
                res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
        elif ratio > src_ratio:
            fill_width = width // 2 - src_w // 2
            if fill_width > 0:
                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
                res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))

    return res
2022-08-24 23:47:23 +08:00
def wrap_gradio_call(func):
    """Wrap a gradio handler so its HTML output also reports how long the call took.

    ``func`` must return a sequence whose last item is an HTML string. The wrapper returns the
    same items as a tuple, with a "Time taken" paragraph appended to that last item. The
    wrapper keeps ``func``'s name, docstring and ``__wrapped__`` via ``functools.wraps``.
    """
    import functools

    @functools.wraps(func)
    def f(*p1, **p2):
        t = time.perf_counter()
        res = list(func(*p1, **p2))
        elapsed = time.perf_counter() - t

        # last item is always HTML
        res[-1] = res[-1] + f"<p class='performance'>Time taken: {elapsed:.2f}s</p>"
        return tuple(res)

    return f
2022-08-26 02:52:05 +08:00
# Optional GFPGAN face restorer. It stays None unless the GFPGAN directory exists and loads
# cleanly; a loading failure is reported on stderr instead of aborting startup.
GFPGAN = None
if os.path.exists(cmd_opts.gfpgan_dir):
    try:
        GFPGAN = load_gfpgan()
        print("Loaded GFPGAN")
    except Exception:
        print("Error loading GFPGAN:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
2022-08-27 16:17:55 +08:00
class StableDiffuionModelHijack:
    """Holds textual-inversion embeddings and hooks them into a Stable Diffusion model.

    load_textual_inversion_embeddings() reads embedding files from a directory. hijack()
    replaces the model's CLIP text encoder and its token-embedding layer with wrappers that
    refer back to this object.
    """
    # NOTE(review): these containers are class attributes, so all instances share them.
    ids_lookup = {}  # first token id of an embedding name -> list of (token ids, name)
    word_embeddings = {}  # embedding name -> embedding tensor reshaped to 768 values
    word_embeddings_checksums = {}  # embedding name -> 4-hex-digit checksum shown to the user

    fixes = None  # per-prompt lists of (position, name) substitutions for the current batch
    comments = None  # messages produced while encoding the current batch of prompts

    dir_mtime = None  # mtime of the embeddings directory when it was last loaded

    def load_textual_inversion_embeddings(self, dirname, model):
        """(Re)load every embedding file in ``dirname``, tokenizing names with ``model``'s tokenizer.

        The load is skipped when the directory's mtime has not advanced since the last load.
        Files that fail to load are reported on stderr and skipped.
        """
        mt = os.path.getmtime(dirname)
        if self.dir_mtime is not None and mt <= self.dir_mtime:
            return

        self.dir_mtime = mt
        self.ids_lookup.clear()
        self.word_embeddings.clear()
        # Also drop old checksums so embeddings removed from the directory do not linger.
        self.word_embeddings_checksums.clear()

        tokenizer = model.cond_stage_model.tokenizer

        def const_hash(a):
            # Deterministic 32-bit hash over the (int-truncated) values, used only for display.
            r = 0
            for v in a:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        def process_file(path, filename):
            name = os.path.splitext(filename)[0]

            # NOTE: torch.load unpickles the file; only put trusted files in this directory.
            data = torch.load(path)
            param_dict = data['string_to_param']
            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
            emb = next(iter(param_dict.items()))[1].reshape(768)
            self.word_embeddings[name] = emb
            self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'

            ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]

            # Index by first token so the text encoder can find candidate matches quickly.
            first_id = ids[0]
            if first_id not in self.ids_lookup:
                self.ids_lookup[first_id] = []
            self.ids_lookup[first_id].append((ids, name))

        for fn in os.listdir(dirname):
            try:
                process_file(os.path.join(dirname, fn), fn)
            except Exception:
                print(f"Error loading emedding {fn}:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                continue

        print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")

    def hijack(self, m):
        """Wrap ``m``'s token embedding and text encoder so they use this object's embeddings."""
        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
2022-08-26 02:52:05 +08:00
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
    """Wrapper around the model's CLIP text encoder that adds emphasis and custom words.

    Emphasis: tokens containing '(' or ')' change a running weight. Each '(' multiplies the
    weight of the following tokens by 1.1 and each ')' divides it; '[' and ']' do the opposite.
    These bracket tokens are removed from the sequence.

    Custom words: a token run matching a loaded textual-inversion name is replaced by a
    placeholder token (777), and its (position, name) is recorded in ``hijack.fixes``.
    """

    def __init__(self, wrapped, hijack):
        super().__init__()
        self.wrapped = wrapped
        self.hijack = hijack
        self.tokenizer = wrapped.tokenizer
        self.max_length = wrapped.max_length

        # Precompute, for every vocabulary token containing brackets, its weight multiplier.
        self.token_mults = {}
        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
                if c == '[':
                    mult /= 1.1
                if c == ']':
                    mult *= 1.1
                if c == '(':
                    mult *= 1.1
                if c == ')':
                    mult /= 1.1

            if mult != 1.0:
                self.token_mults[ident] = mult

    def forward(self, text):
        """Encode a list of prompt strings and return the (emphasis-weighted) hidden states.

        Resets and refills ``hijack.fixes`` (one list per prompt) and ``hijack.comments``
        (truncation warnings and the custom terms used).
        """
        self.hijack.fixes = []
        self.hijack.comments = []
        remade_batch_tokens = []
        id_start = self.wrapped.tokenizer.bos_token_id
        id_end = self.wrapped.tokenizer.eos_token_id
        # NOTE(review): maxlen already subtracts 2 and the slices below subtract 2 again, so
        # encoded sequences are max_length - 2 tokens long. Left as-is because changing it would
        # likely change existing outputs.
        maxlen = self.wrapped.max_length - 2
        used_custom_terms = []

        cache = {}
        batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
        batch_multipliers = []
        for tokens in batch_tokens:
            tuple_tokens = tuple(tokens)

            if tuple_tokens in cache:
                remade_tokens, fixes, multipliers = cache[tuple_tokens]
            else:
                fixes = []
                remade_tokens = []
                multipliers = []
                mult = 1.0

                i = 0
                while i < len(tokens):
                    token = tokens[i]

                    possible_matches = self.hijack.ids_lookup.get(token, None)

                    mult_change = self.token_mults.get(token)
                    if mult_change is not None:
                        # Bracket token: adjust the running weight and drop the token.
                        mult *= mult_change
                    elif possible_matches is None:
                        remade_tokens.append(token)
                        multipliers.append(mult)
                    else:
                        found = False
                        for ids, word in possible_matches:
                            if tokens[i:i+len(ids)] == ids:
                                # Collapse the name's tokens into one placeholder and remember
                                # where it went so its embedding can be substituted there.
                                fixes.append((len(remade_tokens), word))
                                remade_tokens.append(777)
                                multipliers.append(mult)
                                i += len(ids) - 1
                                found = True
                                used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
                                break

                        if not found:
                            remade_tokens.append(token)
                            multipliers.append(mult)

                    i += 1

                if len(remade_tokens) > maxlen - 2:
                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
                    ovf = remade_tokens[maxlen - 2:]
                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))

                    self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

                # Pad with end tokens, truncate, then add the start and end tokens.
                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
                remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)

            # Pad/truncate the weights to line up with the tokens; start/end tokens get 1.0.
            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

            remade_batch_tokens.append(remade_tokens)
            self.hijack.fixes.append(fixes)
            batch_multipliers.append(multipliers)

        if len(used_custom_terms) > 0:
            self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

        tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
        outputs = self.wrapped.transformer(input_ids=tokens)
        z = outputs.last_hidden_state

        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        # The weights go to z's own device (not a module-level global) so the in-place
        # multiply always has both operands on the same device.
        batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(z.device)
        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean

        return z
class EmbeddingsWithFixes(nn.Module):
    """Token-embedding wrapper that overwrites chosen positions with custom embedding vectors.

    ``embeddings`` is the object holding ``fixes`` (per-sequence lists of (position, name)) and
    ``word_embeddings`` (name -> vector). Pending fixes are consumed on each forward call:
    they are read and reset to None before the wrapped embedding runs.
    """

    def __init__(self, wrapped, embeddings):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings

    def forward(self, input_ids):
        pending = self.embeddings.fixes
        self.embeddings.fixes = None

        embedded = self.wrapped(input_ids)
        if pending is None:
            return embedded

        for row_fixes, row in zip(pending, embedded):
            for position, term in row_fixes:
                row[position] = self.embeddings.word_embeddings[term]
        return embedded
2022-08-26 02:52:05 +08:00
2022-08-28 02:32:28 +08:00
class StableDiffusionProcessing:
    """Parameters of one generation job, as consumed by process_images.

    The base class's init() does nothing and sample() raises NotImplementedError; both are
    meant to be overridden.
    """

    def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_grid=False, extra_generation_params=None):
        # Where results go.
        self.outpath: str = outpath

        # What to generate.
        self.prompt: str = prompt
        self.seed: int = seed
        self.prompt_matrix: bool = prompt_matrix

        # How to sample.
        self.sampler_index: int = sampler_index
        self.steps: int = steps
        self.cfg_scale: float = cfg_scale

        # Batching and output geometry.
        self.batch_size: int = batch_size
        self.n_iter: int = n_iter
        self.width: int = width
        self.height: int = height

        # Post-processing, saving and extra metadata.
        self.use_GFPGAN: bool = use_GFPGAN
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params

    def init(self):
        """Hook for subclasses; the base implementation does nothing."""
        pass

    def sample(self, x, conditioning, unconditional_conditioning):
        """Produce samples from noise ``x``; subclasses must implement this."""
        raise NotImplementedError()
class VanillaStableDiffusionSampler:
    """Adapter giving a CompVis sampler class (e.g. DDIMSampler, PLMSSampler) the sample(p, ...) interface."""

    def __init__(self, constructor):
        # The sampler class is instantiated against the global sd_model.
        self.sampler = constructor(sd_model)

    def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
        """Run ``p.steps`` steps with guidance scale ``p.cfg_scale``, starting from noise ``x``."""
        samples, _ = self.sampler.sample(
            S=p.steps,
            conditioning=conditioning,
            batch_size=int(x.shape[0]),
            shape=x[0].shape,
            verbose=False,
            unconditional_guidance_scale=p.cfg_scale,
            unconditional_conditioning=unconditional_conditioning,
            x_T=x,
        )
        return samples
class CFGDenoiser(nn.Module):
    """Classifier-free guidance wrapper around a denoiser model.

    The unconditional and conditional predictions are computed in one doubled batch and
    combined as ``uncond + (cond - uncond) * cond_scale``.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        doubled_x = torch.cat((x, x))
        doubled_sigma = torch.cat((sigma, sigma))
        stacked_cond = torch.cat((uncond, cond))

        denoised = self.inner_model(doubled_x, doubled_sigma, cond=stacked_cond)
        uncond_out, cond_out = denoised.chunk(2)
        return uncond_out + (cond_out - uncond_out) * cond_scale
class KDiffusionSampler:
    """Adapter running a k_diffusion.sampling function, looked up by name, with CFG guidance."""

    def __init__(self, funcname):
        self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
        self.funcname = funcname
        self.func = getattr(k_diffusion.sampling, self.funcname)
        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)

    def sample(self, p: StableDiffusionProcessing, x, conditioning, unconditional_conditioning):
        """Sample over ``p.steps`` sigmas, starting from ``x`` scaled by the first sigma."""
        sigmas = self.model_wrap.get_sigmas(p.steps)
        x = x * sigmas[0]
        guidance_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}
        return self.func(self.model_wrap_cfg, x, sigmas, extra_args=guidance_args, disable=False)
def process_images(p: StableDiffusionProcessing):
    """Run the generation job described by ``p`` and return ``(images, seed, infotext)``.

    This is the main loop shared by txt2img and img2img: it calls ``p.init()``
    once inside the no_grad/autocast/ema scopes and ``p.sample()`` once per batch.

    Every generated image (after optional GFPGAN face restoration) is returned.
    Individual samples are written to ``<p.outpath>/samples`` only when
    ``opts.samples_save`` is set. A grid is written to ``p.outpath`` when
    ``opts.grid_save`` or ``p.prompt_matrix`` is set, unless
    ``p.do_not_save_grid``. In prompt-matrix mode the grid is also prepended
    to the returned images.

    Side effects: in prompt-matrix mode ``p.n_iter`` is overwritten; ``p.init()``
    may change ``p.batch_size`` (the img2img processing class does).

    Returns:
        (output_images, seed, infotext): ``seed`` is the first seed used (a random
        one when ``p.seed == -1``); ``infotext`` is the prompt plus generation
        parameters, the same text passed as ``info`` to save_image.
    """
    prompt = p.prompt
    model = sd_model

    assert p.prompt is not None
    torch_gc()

    # -1 means "pick a random seed"
    seed = int(random.randrange(4294967294) if p.seed == -1 else p.seed)

    os.makedirs(p.outpath, exist_ok=True)

    sample_path = os.path.join(p.outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    # outpath also contains the "samples" directory itself, hence the -1
    grid_count = len(os.listdir(p.outpath)) - 1

    comments = []

    prompt_matrix_parts = []
    if p.prompt_matrix:
        all_prompts = []
        prompt_matrix_parts = prompt.split("|")
        # the first part is always used; each subset of the remaining parts yields one prompt
        combination_count = 2 ** (len(prompt_matrix_parts) - 1)
        for combination_num in range(combination_count):
            selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]

            if opts.prompt_matrix_add_to_start:
                selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
            else:
                selected_prompts = [prompt_matrix_parts[0]] + selected_prompts

            all_prompts.append(", ".join(selected_prompts))

        p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
        # every cell of the matrix shares one seed so only the prompt varies
        all_seeds = len(all_prompts) * [seed]

        print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
    else:
        all_prompts = p.batch_size * p.n_iter * [prompt]
        all_seeds = [seed + x for x in range(len(all_prompts))]

    generation_params = {
        "Steps": p.steps,
        "Sampler": samplers[p.sampler_index].name,
        "CFG scale": p.cfg_scale,
        "Seed": seed,
        "GFPGAN": ("GFPGAN" if p.use_GFPGAN and GFPGAN is not None else None)
    }

    if p.extra_generation_params is not None:
        generation_params.update(p.extra_generation_params)

    # entries whose value equals the key (e.g. "GFPGAN") are printed as a bare flag; None values are omitted
    generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])

    def infotext():
        return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])

    if os.path.exists(cmd_opts.embeddings_dir):
        model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)

    output_images = []
    with torch.no_grad(), autocast("cuda"), model.ema_scope():
        p.init()

        for n in range(p.n_iter):
            prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]

            uc = model.get_learned_conditioning(len(prompts) * [""])
            c = model.get_learned_conditioning(prompts)

            if len(model_hijack.comments) > 0:
                comments += model_hijack.comments

            # we manually generate all input noises because each one should have a specific seed
            x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)

            samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)

            x_samples_ddim = model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.use_GFPGAN and GFPGAN is not None:
                    torch_gc()
                    _, _, restored_img = GFPGAN.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
                    x_sample = restored_img

                image = Image.fromarray(x_sample)

                # every image is returned to the caller; it is written to disk
                # only when saving individual samples is enabled
                if opts.samples_save:
                    save_image(image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opts.samples_format, info=infotext())
                    base_count += 1

                output_images.append(image)

        if (p.prompt_matrix or opts.grid_save) and not p.do_not_save_grid:
            if p.prompt_matrix:
                grid = image_grid(output_images, p.batch_size, force_n_rows=1 << ((len(prompt_matrix_parts) - 1) // 2))

                try:
                    grid = draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
                except Exception:
                    # the labels are cosmetic: report the failure and keep the unlabeled grid
                    print("Error creating prompt_matrix text:", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)

                output_images.insert(0, grid)
            else:
                grid = image_grid(output_images, p.batch_size)

            save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)

    torch_gc()
    return output_images, seed, infotext()
2022-08-25 02:20:36 +08:00
2022-08-24 03:42:43 +08:00
2022-08-28 02:32:28 +08:00
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    """txt2img job: samples from pure noise using the sampler picked by ``sampler_index``."""
    sampler = None

    def init(self):
        chosen = samplers[self.sampler_index]
        self.sampler = chosen.constructor()

    def sample(self, x, conditioning, unconditional_conditioning):
        # the sampler adapters take the processing object to read steps / cfg_scale
        return self.sampler.sample(self, x, conditioning, unconditional_conditioning)
2022-08-28 02:32:28 +08:00
def txt2img(prompt: str, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
    """Gradio handler for the txt2img tab.

    Returns (images, seed actually used, generation info rendered as HTML).
    """
    processing = StableDiffusionProcessingTxt2Img(
        outpath=opts.outdir or "outputs/txt2img-samples",
        prompt=prompt,
        seed=seed,
        sampler_index=sampler_index,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=ddim_steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        prompt_matrix=prompt_matrix,
        use_GFPGAN=use_GFPGAN,
    )

    images, used_seed, info_text = process_images(processing)
    return images, used_seed, plaintext_to_html(info_text)
2022-08-24 03:42:43 +08:00
2022-08-23 05:34:49 +08:00
class Flagging(gr.FlaggingCallback):
    """Flagging callback for the txt2img tab.

    Appends one row per flag to ``log/log.csv`` and decodes the flagged gallery
    images (base64 PNG data) into ``log/images/``.
    """

    def setup(self, components, flagging_dir: str):
        pass

    def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
        import base64
        import csv
        import time

        os.makedirs("log/images", exist_ok=True)

        # must match the inputs of the "txt2img" function followed by the outputs
        # of txt2img_interface (gallery, seed, html): 11 + 3 values
        prompt, ddim_steps, sampler_index, use_gfpgan, prompt_matrix, n_iter, n_samples, cfg_scale, request_seed, height, width, images, seed, comment = flag_data
        images = images or []

        filenames = []

        with open("log/log.csv", "a", encoding="utf8", newline='') as file:
            at_start = file.tell() == 0
            writer = csv.writer(file)
            if at_start:
                writer.writerow(["prompt", "seed", "width", "height", "cfgs", "steps", "filename"])

            filename_base = str(int(time.time() * 1000))
            for i, filedata in enumerate(images):
                filename = "log/images/" + filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + ".png"

                if filedata.startswith("data:image/png;base64,"):
                    filedata = filedata[len("data:image/png;base64,"):]

                with open(filename, "wb") as imgfile:
                    imgfile.write(base64.decodebytes(filedata.encode('utf-8')))

                filenames.append(filename)

            # an empty gallery is logged with an empty filename instead of crashing
            first_filename = filenames[0] if filenames else ""
            writer.writerow([prompt, seed, width, height, cfg_scale, ddim_steps, first_filename])

        print("Logged:", first_filename)
2022-08-22 22:15:46 +08:00
2022-08-24 03:42:43 +08:00
# txt2img tab; the order of `inputs` must match the parameters of txt2img(),
# and inputs + outputs must match the unpacking in Flagging.flag().
txt2img_interface = gr.Interface(
    fn=wrap_gradio_call(txt2img),
    inputs=[
        gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
        gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
        gr.Radio(label='Sampling method', choices=[sampler.name for sampler in samplers], value=samplers[0].name, type="index"),
        gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
        gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
        gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
        gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
        gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
        gr.Number(label='Seed', value=-1),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
    ],
    outputs=[
        gr.Gallery(label="Images"),
        gr.Number(label='Seed'),
        gr.HTML(),
    ],
    title="Stable Diffusion Text-to-Image",
    flagging_callback=Flagging(),
)
2022-08-28 02:32:28 +08:00
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    """img2img job: encodes the init images to latents, noises them, then denoises part of the schedule.

    Args (in addition to the base class keyword arguments):
        init_images: list of PIL images; each is converted to RGB and resized
            with ``resize_image(resize_mode, ...)`` to width x height.
        resize_mode: index passed to ``resize_image``.
        denoising_strength: fraction of the sampling schedule to run, 0.0..1.0.
    """
    sampler = None

    def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, **kwargs):
        super().__init__(**kwargs)

        self.init_images = init_images
        self.resize_mode: int = resize_mode
        self.denoising_strength: float = denoising_strength
        self.init_latent = None

    def init(self):
        """Create the sampler and encode the init images into ``self.init_latent``.

        A single init image is repeated to fill the batch. Several images
        shrink ``self.batch_size`` to their count.

        Raises:
            RuntimeError: no init image was given, or more images than batch_size.
        """
        self.sampler = samplers_for_img2img[self.sampler_index].constructor()

        if not self.init_images or any(img is None for img in self.init_images):
            raise RuntimeError("img2img requires an init image")

        imgs = []
        for img in self.init_images:
            image = img.convert("RGB")
            image = resize_image(self.resize_mode, image, self.width, self.height)
            image = np.array(image).astype(np.float32) / 255.0
            # HWC -> CHW
            image = np.moveaxis(image, 2, 0)
            imgs.append(image)

        if len(imgs) == 1:
            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
        elif len(imgs) <= self.batch_size:
            self.batch_size = len(imgs)
            batch_images = np.array(imgs)
        else:
            raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")

        image = torch.from_numpy(batch_images)
        # map pixel values from [0, 1] to [-1, 1] before encoding
        image = 2. * image - 1.
        image = image.to(device)

        self.init_latent = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image))

    def sample(self, x, conditioning, unconditional_conditioning):
        """Noise ``init_latent`` to the level implied by denoising_strength and denoise from there."""
        # At strength 1.0, int(strength * steps) == steps, which made the index below
        # -1. Python wraps that to the LAST sigma of the schedule, so the image was
        # barely changed instead of fully re-noised. Clamp so the index stays >= 0.
        t_enc = min(int(self.denoising_strength * self.steps), self.steps - 1)

        sigmas = self.sampler.model_wrap.get_sigmas(self.steps)
        start_index = self.steps - t_enc - 1
        noise = x * sigmas[start_index]

        xi = self.init_latent + noise
        sigma_sched = sigmas[start_index:]
        samples_ddim = self.sampler.func(self.sampler.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': self.cfg_scale}, disable=False)
        return samples_ddim
2022-08-28 02:32:28 +08:00
def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
    """Gradio handler for the img2img tab.

    Modes:
      * loopback: n_iter single-image passes. Each pass uses the previous pass's
        output as its init image, increments the seed, and multiplies
        denoising_strength by 0.95 (clamped to at least 0.1). A one-row grid of
        all passes is saved to outpath.
      * sd_upscale: init_img is upscaled 2x with Real-ESRGAN model 0 and split
        into width x height tiles (overlap ``opts.sd_upscale_overlap``). Every
        tile is run through img2img in batches of batch_size, and the results
        are recombined with combine_grid and saved to outpath.
      * otherwise: a single process_images() call.

    Returns:
        (output_images, seed, info_html) for the gallery, seed and HTML outputs.
    """
    outpath = opts.outdir or "outputs/img2img-samples"

    assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

    p = StableDiffusionProcessingImg2Img(
        outpath=outpath,
        prompt=prompt,
        seed=seed,
        sampler_index=sampler_index,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=ddim_steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        prompt_matrix=prompt_matrix,
        use_GFPGAN=use_GFPGAN,
        init_images=[init_img],
        resize_mode=resize_mode,
        denoising_strength=denoising_strength,
        extra_generation_params={"Denoising Strength": denoising_strength}
    )

    if loopback:
        output_images, info = None, None
        history = []
        initial_seed = None

        for _ in range(n_iter):
            p.n_iter = 1
            p.batch_size = 1
            p.do_not_save_grid = True

            output_images, seed, info = process_images(p)

            if initial_seed is None:
                initial_seed = seed

            # feed this pass's result back in; the processing class reads init_images
            p.init_images = [output_images[0]]
            p.seed = seed + 1
            p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
            history.append(output_images[0])

        grid_count = len(os.listdir(outpath)) - 1
        grid = image_grid(history, batch_size, force_n_rows=1)

        save_image(grid, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename)

        output_images = history
        seed = initial_seed

    elif sd_upscale:
        initial_seed = None
        initial_info = None

        img = upscale_with_realesrgan(init_img, RealESRGAN_upscaling=2, RealESRGAN_model_index=0)

        torch_gc()

        grid = split_grid(img, tile_w=width, tile_h=height, overlap=opts.sd_upscale_overlap)

        p.n_iter = 1
        p.do_not_save_grid = True

        work = []
        work_results = []

        for y, h, row in grid.tiles:
            for tiledata in row:
                work.append(tiledata[2])

        tile_batch_size = p.batch_size
        batch_count = math.ceil(len(work) / tile_batch_size)
        print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} in a total of {batch_count} batches.")

        for i in range(batch_count):
            p.init_images = work[i * tile_batch_size:(i + 1) * tile_batch_size]
            # size the batch to the actual tile count so a short final chunk
            # (a single tile in particular) is not repeated to fill a whole batch
            p.batch_size = len(p.init_images)

            output_images, seed, info = process_images(p)

            if initial_seed is None:
                initial_seed = seed
                initial_info = info

            p.seed = seed + 1
            work_results += output_images

        # results come back in the same order the tiles were queued
        image_index = 0
        for y, h, row in grid.tiles:
            for tiledata in row:
                tiledata[2] = work_results[image_index]
                image_index += 1

        combined_image = combine_grid(grid)

        grid_count = len(os.listdir(outpath)) - 1
        save_image(combined_image, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename)

        output_images = [combined_image]
        seed = initial_seed
        info = initial_info

    else:
        output_images, seed, info = process_images(p)

    return output_images, seed, plaintext_to_html(info)
2022-08-22 22:15:46 +08:00
2022-08-24 14:24:32 +08:00
# Default image for the img2img tab, used only when the bundled asset exists.
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
if not os.path.exists(sample_img2img):
    sample_img2img = None
2022-08-22 22:15:46 +08:00
# img2img tab; the order of `inputs` must match the parameters of img2img().
img2img_interface = gr.Interface(
    fn=wrap_gradio_call(img2img),
    inputs=[
        gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
        gr.Image(value=sample_img2img, source="upload", interactive=True, type="pil"),
        gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
        gr.Radio(label='Sampling method', choices=[sampler.name for sampler in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index"),
        gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
        gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
        gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),
        gr.Checkbox(label='Stable Diffusion upscale', value=False),
        gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
        gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
        gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),
        gr.Number(label='Seed', value=-1),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
        gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize"),
    ],
    outputs=[
        gr.Gallery(),
        gr.Number(label='Seed'),
        gr.HTML(),
    ],
    allow_flagging="never",
)
2022-08-26 04:31:44 +08:00
2022-08-27 21:13:33 +08:00
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
    """Upscale ``image`` by ``RealESRGAN_upscaling`` with entry ``RealESRGAN_model_index`` of realesrgan_models.

    Returns a new PIL image built from the upsampler's output array.
    """
    model_info = realesrgan_models[RealESRGAN_model_index]
    network = model_info.model()

    upsampler = RealESRGANer(
        scale=model_info.netscale,
        model_path=model_info.location,
        model=network,
        half=True,
    )

    # enhance() returns a sequence whose first element is the upscaled array
    enhanced = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)
    return Image.fromarray(enhanced[0])
2022-08-26 16:16:57 +08:00
def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index):
    """Gradio handler for the Extras tab: optional GFPGAN face restoration, then optional Real-ESRGAN upscale.

    The result is saved to the extras output directory. The return value is
    (image, 0, '') to fill the image, hidden seed and HTML outputs.
    """
    torch_gc()

    image = image.convert("RGB")
    outpath = opts.outdir or "outputs/extras-samples"

    if GFPGAN is not None and GFPGAN_strength > 0:
        _, _, restored = GFPGAN.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
        restored_image = Image.fromarray(restored)
        # partial strength blends the restored faces over the original image
        image = restored_image if GFPGAN_strength >= 1.0 else Image.blend(image, restored_image, GFPGAN_strength)

    if have_realesrgan and RealESRGAN_upscaling != 1.0:
        image = upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)

    os.makedirs(outpath, exist_ok=True)
    file_index = len(os.listdir(outpath))
    save_image(image, outpath, f"{file_index:05}", None, '', opts.samples_format, short_filename=True)

    return image, 0, ''
2022-08-23 01:08:32 +08:00
2022-08-26 16:16:57 +08:00
# Extras tab; the order of `inputs` must match the parameters of run_extras().
extras_interface = gr.Interface(
    fn=wrap_gradio_call(run_extras),
    inputs=[
        gr.Image(label="Source", source="upload", interactive=True, type="pil"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=GFPGAN is not None),
        gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Real-ESRGAN upscaling", value=2, interactive=have_realesrgan),
        gr.Radio(label='Real-ESRGAN model', choices=[model_info.name for model_info in realesrgan_models], value=realesrgan_models[0].name, type="index", interactive=have_realesrgan),
    ],
    outputs=[
        gr.Image(label="Result"),
        gr.Number(label='Seed', visible=False),
        gr.HTML(),
    ],
    allow_flagging="never",
)
# Global settings object. Values saved by run_settings() via opts.save(config_filename)
# are loaded back from that file at startup when it exists.
opts = Options()
if os.path.exists(config_filename):
    opts.load(config_filename)
def run_settings(*args):
    """Gradio handler for the Settings tab.

    ``args`` holds one value per settings component, in ``opts.data_labels`` order
    (the same order used to build the settings inputs). Each value is stored in
    ``opts.data`` and the options are written to ``config_filename``.

    Returns:
        ('Settings saved.', '') for the result textbox and HTML outputs.
    """
    # the previous version also built a list of unused component.update() dicts;
    # nothing consumed it, so only the assignment into opts.data is kept
    for key, value in zip(opts.data_labels.keys(), args):
        opts.data[key] = value

    opts.save(config_filename)

    return 'Settings saved.', ''
def create_setting_component(key):
    """Build the gradio input component for settings entry ``key``.

    The component's value is a callable, so gradio reads the current setting
    (``opts.data[key]`` or the declared default) lazily. A custom
    ``info.component`` takes precedence. Otherwise the component type is chosen
    by the exact type of the default (str, int or bool).

    Raises:
        Exception: the default's type has no matching component.
    """
    def current_value():
        if key in opts.data:
            return opts.data[key]
        return opts.data_labels[key].default

    info = opts.data_labels[key]
    default_type = type(info.default)

    if info.component is not None:
        return info.component(label=info.label, value=current_value, **(info.component_args or {}))

    # exact-type dispatch: type(True) is bool, so booleans never fall into the int entry
    builders = {
        str: lambda: gr.Textbox(label=info.label, value=current_value, lines=1),
        int: lambda: gr.Number(label=info.label, value=current_value),
        bool: lambda: gr.Checkbox(label=info.label, value=current_value),
    }

    builder = builders.get(default_type)
    if builder is None:
        raise Exception(f'bad options item type: {str(default_type)} for key {key}')

    return builder()
# Settings tab: one input component per opts.data_labels entry; submitting calls run_settings().
settings_interface = gr.Interface(
    fn=run_settings,
    inputs=list(map(create_setting_component, opts.data_labels.keys())),
    outputs=[
        gr.Textbox(label='Result'),
        gr.HTML(),
    ],
    title=None,
    description=None,
    allow_flagging="never",
)
# (interface, tab title) pairs, in tab order; consumed by gr.TabbedInterface when building `demo`.
interfaces = [
    (txt2img_interface, "txt2img"),
    (img2img_interface, "img2img"),
    (extras_interface, "Extras"),
    (settings_interface, "Settings"),
]
2022-08-28 02:32:28 +08:00
# Build the model from the --config YAML and load weights from the checkpoint given in cmd_opts.ckpt.
sd_config = OmegaConf.load(cmd_opts.config)
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# half precision unless cmd_opts.no_half is set
if not cmd_opts.no_half:
    sd_model = sd_model.half()
sd_model = sd_model.to(device)

model_hijack = StableDiffuionModelHijack()
model_hijack.hijack(sd_model)
2022-08-23 01:08:32 +08:00
2022-08-24 14:06:36 +08:00
# Assemble all tabs into one gradio app and start it.
demo = gr.TabbedInterface(
    interface_list=[x[0] for x in interfaces],
    tab_names=[x[1] for x in interfaces],
    # css_hide_progressbar is prepended unless cmd_opts.no_progressbar_hiding is set
    css=("" if cmd_opts.no_progressbar_hiding else css_hide_progressbar) + """
.output-html p {margin: 0 0.5em;}
.performance { font-size: 0.85em; color: #444; }
"""
)

demo.launch()