mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2024-11-21 03:11:40 +08:00
Holy $hit.
Yep. Fix gfpgan_model_arch requirement(s). Add Upscaler base class, move from images. Add a lot of methods to Upscaler. Re-work all the child upscalers to be proper classes. Add BSRGAN scaler. Add ldsr_model_arch class, removing the dependency for another repo that just uses regular latent-diffusion stuff. Add one universal method that will always find and load new upscaler models without having to add new "setup_model" calls. Still need to add command line params, but that could probably be automated. Add a "self.scale" property to all Upscalers so the scalers themselves can do "things" in response to the requested upscaling size. Ensure LDSR doesn't get stuck in a longer loop of "upscale/downscale/upscale" as we try to reach the target upscale size. Add typehints for IDE sanity. PEP-8 improvements. Moar.
This commit is contained in:
parent
31ad536c33
commit
0dce0df1ee
11
launch.py
11
launch.py
@ -1,5 +1,5 @@
|
||||
# this scripts installs necessary requirements and launches main program in webui.py
|
||||
|
||||
import shutil
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
@ -22,7 +22,6 @@ stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "6
|
||||
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
|
||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||
ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH',"abf33e7002d59d9085081bce93ec798dcabd49af")
|
||||
|
||||
args = shlex.split(commandline_args)
|
||||
|
||||
@ -122,9 +121,11 @@ git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-di
|
||||
git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
|
||||
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||
# Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier
|
||||
git_clone("https://github.com/Hafiidz/latent-diffusion", repo_dir('latent-diffusion'), "LDSR", ldsr_commit_hash)
|
||||
|
||||
if os.path.isdir(repo_dir('latent-diffusion')):
|
||||
try:
|
||||
shutil.rmtree(repo_dir('latent-diffusion'))
|
||||
except:
|
||||
pass
|
||||
if not is_installed("lpips"):
|
||||
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
|
||||
|
||||
|
79
modules/bsrgan_model.py
Normal file
79
modules/bsrgan_model.py
Normal file
@ -0,0 +1,79 @@
|
||||
import os.path
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import PIL.Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.upscaler
|
||||
from modules import shared, modelloader
|
||||
from modules.bsrgan_model_arch import RRDBNet
|
||||
from modules.paths import models_path
|
||||
|
||||
|
||||
class UpscalerBSRGAN(modules.upscaler.Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "BSRGAN"
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
self.model_name = "BSRGAN 4x"
|
||||
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
|
||||
self.user_path = dirname
|
||||
super().__init__()
|
||||
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||
scalers = []
|
||||
if len(model_paths) == 0:
|
||||
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
for file in model_paths:
|
||||
if "http" in file:
|
||||
name = self.model_name
|
||||
else:
|
||||
name = modelloader.friendly_name(file)
|
||||
try:
|
||||
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
except Exception:
|
||||
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
self.scalers = scalers
|
||||
|
||||
def do_upscale(self, img: PIL.Image, selected_file):
|
||||
torch.cuda.empty_cache()
|
||||
model = self.load_model(selected_file)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(shared.device)
|
||||
torch.cuda.empty_cache()
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
img = np.moveaxis(img, 2, 0) / 255
|
||||
img = torch.from_numpy(img).float()
|
||||
img = img.unsqueeze(0).to(shared.device)
|
||||
with torch.no_grad():
|
||||
output = model(img)
|
||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output = 255. * np.moveaxis(output, 0, 2)
|
||||
output = output.astype(np.uint8)
|
||||
output = output[:, :, ::-1]
|
||||
torch.cuda.empty_cache()
|
||||
return PIL.Image.fromarray(output, 'RGB')
|
||||
|
||||
def load_model(self, path: str):
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
|
||||
progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print("Unable to load %s from %s" % (self.model_dir, filename))
|
||||
return None
|
||||
print("Loading %s from %s" % (self.model_dir, filename))
|
||||
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2) # define network
|
||||
model.load_state_dict(torch.load(filename), strict=True)
|
||||
model.eval()
|
||||
for k, v in model.named_parameters():
|
||||
v.requires_grad = False
|
||||
return model
|
||||
|
103
modules/bsrgan_model_arch.py
Normal file
103
modules/bsrgan_model_arch.py
Normal file
@ -0,0 +1,103 @@
|
||||
import functools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
|
||||
|
||||
def initialize_weights(net_l, scale=1):
|
||||
if not isinstance(net_l, list):
|
||||
net_l = [net_l]
|
||||
for net in net_l:
|
||||
for m in net.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale # for residual block
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.Linear):
|
||||
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
|
||||
m.weight.data *= scale
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
init.constant_(m.weight, 1)
|
||||
init.constant_(m.bias.data, 0.0)
|
||||
|
||||
|
||||
def make_layer(block, n_layers):
|
||||
layers = []
|
||||
for _ in range(n_layers):
|
||||
layers.append(block())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResidualDenseBlock_5C(nn.Module):
|
||||
def __init__(self, nf=64, gc=32, bias=True):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
# gc: growth channel, i.e. intermediate channels
|
||||
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
|
||||
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
|
||||
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
# initialization
|
||||
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.lrelu(self.conv1(x))
|
||||
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
|
||||
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
|
||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
return x5 * 0.2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
'''Residual in Residual Dense Block'''
|
||||
|
||||
def __init__(self, nf, gc=32):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
|
||||
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.RDB1(x)
|
||||
out = self.RDB2(out)
|
||||
out = self.RDB3(out)
|
||||
return out * 0.2 + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
|
||||
super(RRDBNet, self).__init__()
|
||||
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
|
||||
self.sf = sf
|
||||
print([in_nc, out_nc, nf, nb, gc, sf])
|
||||
|
||||
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
|
||||
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
#### upsampling
|
||||
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
if self.sf==4:
|
||||
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
fea = fea + trunk
|
||||
|
||||
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
if self.sf==4:
|
||||
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
return out
|
@ -1,6 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -8,93 +6,119 @@ from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.esrgam_model_arch as arch
|
||||
import modules.images
|
||||
from modules import shared
|
||||
from modules import shared, modelloader
|
||||
from modules import shared, modelloader, images
|
||||
from modules.devices import has_mps
|
||||
from modules.paths import models_path
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.shared import opts
|
||||
|
||||
model_dir = "ESRGAN"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
|
||||
model_name = "ESRGAN_x4"
|
||||
|
||||
class UpscalerESRGAN(Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "ESRGAN"
|
||||
self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
|
||||
self.model_name = "ESRGAN 4x"
|
||||
self.scalers = []
|
||||
self.user_path = dirname
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
super().__init__()
|
||||
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
|
||||
scalers = []
|
||||
if len(model_paths) == 0:
|
||||
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
|
||||
scalers.append(scaler_data)
|
||||
for file in model_paths:
|
||||
print(f"File: {file}")
|
||||
if "http" in file:
|
||||
name = self.model_name
|
||||
else:
|
||||
name = modelloader.friendly_name(file)
|
||||
|
||||
def load_model(path: str, name: str):
|
||||
global model_path
|
||||
global model_url
|
||||
global model_dir
|
||||
global model_name
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=model_url, model_dir=model_path, file_name="%s.pth" % model_name, progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print("Unable to load %s from %s" % (model_dir, filename))
|
||||
return None
|
||||
print("Loading %s from %s" % (model_dir, filename))
|
||||
# this code is adapted from https://github.com/xinntao/ESRGAN
|
||||
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
|
||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||
scaler_data = UpscalerData(name, file, self, 4)
|
||||
print(f"ESRGAN: Adding scaler {name}")
|
||||
self.scalers.append(scaler_data)
|
||||
|
||||
if 'conv_first.weight' in pretrained_net:
|
||||
crt_model.load_state_dict(pretrained_net)
|
||||
def do_upscale(self, img, selected_model):
|
||||
model = self.load_model(selected_model)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(shared.device)
|
||||
img = esrgan_upscale(model, img)
|
||||
return img
|
||||
|
||||
def load_model(self, path: str):
|
||||
if "http" in path:
|
||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||
file_name="%s.pth" % self.model_name,
|
||||
progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print("Unable to load %s from %s" % (self.model_path, filename))
|
||||
return None
|
||||
# this code is adapted from https://github.com/xinntao/ESRGAN
|
||||
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
|
||||
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
|
||||
|
||||
if 'conv_first.weight' in pretrained_net:
|
||||
crt_model.load_state_dict(pretrained_net)
|
||||
return crt_model
|
||||
|
||||
if 'model.0.weight' not in pretrained_net:
|
||||
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net[
|
||||
"params_ema"]
|
||||
if is_realesrgan:
|
||||
raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
|
||||
else:
|
||||
raise Exception("The file is not a ESRGAN model.")
|
||||
|
||||
crt_net = crt_model.state_dict()
|
||||
load_net_clean = {}
|
||||
for k, v in pretrained_net.items():
|
||||
if k.startswith('module.'):
|
||||
load_net_clean[k[7:]] = v
|
||||
else:
|
||||
load_net_clean[k] = v
|
||||
pretrained_net = load_net_clean
|
||||
|
||||
tbd = []
|
||||
for k, v in crt_net.items():
|
||||
tbd.append(k)
|
||||
|
||||
# directly copy
|
||||
for k, v in crt_net.items():
|
||||
if k in pretrained_net and pretrained_net[k].size() == v.size():
|
||||
crt_net[k] = pretrained_net[k]
|
||||
tbd.remove(k)
|
||||
|
||||
crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
|
||||
crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
|
||||
|
||||
for k in tbd.copy():
|
||||
if 'RDB' in k:
|
||||
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
|
||||
if '.weight' in k:
|
||||
ori_k = ori_k.replace('.weight', '.0.weight')
|
||||
elif '.bias' in k:
|
||||
ori_k = ori_k.replace('.bias', '.0.bias')
|
||||
crt_net[k] = pretrained_net[ori_k]
|
||||
tbd.remove(k)
|
||||
|
||||
crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
|
||||
crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
|
||||
crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
|
||||
crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
|
||||
crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
|
||||
crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
|
||||
crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
|
||||
crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
|
||||
crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
|
||||
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
|
||||
|
||||
crt_model.load_state_dict(crt_net)
|
||||
crt_model.eval()
|
||||
return crt_model
|
||||
|
||||
if 'model.0.weight' not in pretrained_net:
|
||||
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
|
||||
if is_realesrgan:
|
||||
raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
|
||||
else:
|
||||
raise Exception("The file is not a ESRGAN model.")
|
||||
|
||||
crt_net = crt_model.state_dict()
|
||||
load_net_clean = {}
|
||||
for k, v in pretrained_net.items():
|
||||
if k.startswith('module.'):
|
||||
load_net_clean[k[7:]] = v
|
||||
else:
|
||||
load_net_clean[k] = v
|
||||
pretrained_net = load_net_clean
|
||||
|
||||
tbd = []
|
||||
for k, v in crt_net.items():
|
||||
tbd.append(k)
|
||||
|
||||
# directly copy
|
||||
for k, v in crt_net.items():
|
||||
if k in pretrained_net and pretrained_net[k].size() == v.size():
|
||||
crt_net[k] = pretrained_net[k]
|
||||
tbd.remove(k)
|
||||
|
||||
crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
|
||||
crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
|
||||
|
||||
for k in tbd.copy():
|
||||
if 'RDB' in k:
|
||||
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
|
||||
if '.weight' in k:
|
||||
ori_k = ori_k.replace('.weight', '.0.weight')
|
||||
elif '.bias' in k:
|
||||
ori_k = ori_k.replace('.bias', '.0.bias')
|
||||
crt_net[k] = pretrained_net[ori_k]
|
||||
tbd.remove(k)
|
||||
|
||||
crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
|
||||
crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
|
||||
crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
|
||||
crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
|
||||
crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
|
||||
crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
|
||||
crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
|
||||
crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
|
||||
crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
|
||||
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
|
||||
|
||||
crt_model.load_state_dict(crt_net)
|
||||
crt_model.eval()
|
||||
return crt_model
|
||||
|
||||
def upscale_without_tiling(model, img):
|
||||
img = np.array(img)
|
||||
@ -115,7 +139,7 @@ def esrgan_upscale(model, img):
|
||||
if opts.ESRGAN_tile == 0:
|
||||
return upscale_without_tiling(model, img)
|
||||
|
||||
grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
|
||||
grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
|
||||
newtiles = []
|
||||
scale_factor = 1
|
||||
|
||||
@ -130,38 +154,7 @@ def esrgan_upscale(model, img):
|
||||
newrow.append([x * scale_factor, w * scale_factor, output])
|
||||
newtiles.append([y * scale_factor, h * scale_factor, newrow])
|
||||
|
||||
newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
|
||||
output = modules.images.combine_grid(newgrid)
|
||||
newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor,
|
||||
grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
|
||||
output = images.combine_grid(newgrid)
|
||||
return output
|
||||
|
||||
|
||||
class UpscalerESRGAN(modules.images.Upscaler):
|
||||
def __init__(self, filename, title):
|
||||
self.name = title
|
||||
self.filename = filename
|
||||
|
||||
def do_upscale(self, img):
|
||||
model = load_model(self.filename, self.name)
|
||||
if model is None:
|
||||
return img
|
||||
model.to(shared.device)
|
||||
img = esrgan_upscale(model, img)
|
||||
return img
|
||||
|
||||
|
||||
def setup_model(dirname):
|
||||
global model_path
|
||||
global model_name
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
|
||||
model_paths = modelloader.load_models(model_path, command_path=dirname, ext_filter=[".pt", ".pth"])
|
||||
if len(model_paths) == 0:
|
||||
modules.shared.sd_upscalers.append(UpscalerESRGAN(model_url, model_name))
|
||||
for file in model_paths:
|
||||
name = modelloader.friendly_name(file)
|
||||
try:
|
||||
modules.shared.sd_upscalers.append(UpscalerESRGAN(file, name))
|
||||
except Exception:
|
||||
print(f"Error loading ESRGAN model: {file}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
@ -66,29 +66,28 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
|
||||
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
|
||||
image = res
|
||||
|
||||
if upscaling_resize != 1.0:
|
||||
def upscale(image, scaler_index, resize):
|
||||
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
|
||||
pixels = tuple(np.array(small).flatten().tolist())
|
||||
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
|
||||
def upscale(image, scaler_index, resize):
|
||||
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
|
||||
pixels = tuple(np.array(small).flatten().tolist())
|
||||
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
|
||||
|
||||
c = cached_images.get(key)
|
||||
if c is None:
|
||||
upscaler = shared.sd_upscalers[scaler_index]
|
||||
c = upscaler.upscale(image, image.width * resize, image.height * resize)
|
||||
cached_images[key] = c
|
||||
c = cached_images.get(key)
|
||||
if c is None:
|
||||
upscaler = shared.sd_upscalers[scaler_index]
|
||||
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
|
||||
cached_images[key] = c
|
||||
|
||||
return c
|
||||
return c
|
||||
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
|
||||
res = upscale(image, extras_upscaler_1, upscaling_resize)
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
|
||||
res = upscale(image, extras_upscaler_1, upscaling_resize)
|
||||
|
||||
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
|
||||
res2 = upscale(image, extras_upscaler_2, upscaling_resize)
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
|
||||
res = Image.blend(res, res2, extras_upscaler_2_visibility)
|
||||
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
|
||||
res2 = upscale(image, extras_upscaler_2, upscaling_resize)
|
||||
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
|
||||
res = Image.blend(res, res2, extras_upscaler_2_visibility)
|
||||
|
||||
image = res
|
||||
image = res
|
||||
|
||||
while len(cached_images) > 2:
|
||||
del cached_images[next(iter(cached_images.keys()))]
|
||||
|
@ -1,24 +1,23 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from glob import glob
|
||||
|
||||
from modules import shared, devices
|
||||
from modules.shared import cmd_opts
|
||||
from modules.paths import script_path
|
||||
import facexlib
|
||||
import gfpgan
|
||||
|
||||
import modules.face_restoration
|
||||
from modules import shared, devices, modelloader
|
||||
from modules.paths import models_path
|
||||
|
||||
model_dir = "GFPGAN"
|
||||
cmd_dir = None
|
||||
user_path = None
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
||||
|
||||
have_gfpgan = False
|
||||
loaded_gfpgan_model = None
|
||||
|
||||
|
||||
def gfpgan():
|
||||
def gfpgann():
|
||||
global loaded_gfpgan_model
|
||||
global model_path
|
||||
if loaded_gfpgan_model is not None:
|
||||
@ -28,14 +27,16 @@ def gfpgan():
|
||||
if gfpgan_constructor is None:
|
||||
return None
|
||||
|
||||
models = modelloader.load_models(model_path, model_url, cmd_dir)
|
||||
if len(models) != 0:
|
||||
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
|
||||
if len(models) == 1 and "http" in models[0]:
|
||||
model_file = models[0]
|
||||
elif len(models) != 0:
|
||||
latest_file = max(models, key=os.path.getctime)
|
||||
model_file = latest_file
|
||||
else:
|
||||
print("Unable to load gfpgan model!")
|
||||
return None
|
||||
model = gfpgan_constructor(model_path=model_file, model_dir=model_path, upscale=1, arch='clean', channel_multiplier=2,
|
||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2,
|
||||
bg_upsampler=None)
|
||||
model.gfpgan.to(shared.device)
|
||||
loaded_gfpgan_model = model
|
||||
@ -44,11 +45,12 @@ def gfpgan():
|
||||
|
||||
|
||||
def gfpgan_fix_faces(np_image):
|
||||
model = gfpgan()
|
||||
model = gfpgann()
|
||||
if model is None:
|
||||
return np_image
|
||||
np_image_bgr = np_image[:, :, ::-1]
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False,
|
||||
only_center_face=False, paste_back=True)
|
||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||
|
||||
if shared.opts.face_restoration_unload:
|
||||
@ -57,7 +59,6 @@ def gfpgan_fix_faces(np_image):
|
||||
return np_image
|
||||
|
||||
|
||||
have_gfpgan = False
|
||||
gfpgan_constructor = None
|
||||
|
||||
|
||||
@ -67,14 +68,33 @@ def setup_model(dirname):
|
||||
os.makedirs(model_path)
|
||||
|
||||
try:
|
||||
from modules.gfpgan_model_arch import GFPGANerr
|
||||
global cmd_dir
|
||||
from gfpgan import GFPGANer
|
||||
from facexlib import detection, parsing
|
||||
global user_path
|
||||
global have_gfpgan
|
||||
global gfpgan_constructor
|
||||
|
||||
cmd_dir = dirname
|
||||
load_file_from_url_orig = gfpgan.utils.load_file_from_url
|
||||
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
|
||||
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
|
||||
|
||||
def my_load_file_from_url(**kwargs):
|
||||
print("Setting model_dir to " + model_path)
|
||||
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
|
||||
|
||||
def facex_load_file_from_url(**kwargs):
|
||||
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
|
||||
|
||||
def facex_load_file_from_url2(**kwargs):
|
||||
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
|
||||
|
||||
gfpgan.utils.load_file_from_url = my_load_file_from_url
|
||||
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
||||
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
|
||||
user_path = dirname
|
||||
print("Have gfpgan should be true?")
|
||||
have_gfpgan = True
|
||||
gfpgan_constructor = GFPGANerr
|
||||
gfpgan_constructor = GFPGANer
|
||||
|
||||
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
|
||||
def name(self):
|
||||
@ -82,7 +102,9 @@ def setup_model(dirname):
|
||||
|
||||
def restore(self, np_image):
|
||||
np_image_bgr = np_image[:, :, ::-1]
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False,
|
||||
only_center_face=False,
|
||||
paste_back=True)
|
||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||
|
||||
return np_image
|
||||
|
@ -1,150 +0,0 @@
|
||||
# GFPGAN likes to download stuff "wherever", and we're trying to fix that, so this is a copy of the original...
|
||||
|
||||
import cv2
|
||||
import os
|
||||
import torch
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from torchvision.transforms.functional import normalize
|
||||
|
||||
from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
|
||||
from gfpgan.archs.gfpganv1_arch import GFPGANv1
|
||||
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
|
||||
|
||||
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
class GFPGANerr():
|
||||
"""Helper for restoration with GFPGAN.
|
||||
|
||||
It will detect and crop faces, and then resize the faces to 512x512.
|
||||
GFPGAN is used to restored the resized faces.
|
||||
The background is upsampled with the bg_upsampler.
|
||||
Finally, the faces will be pasted back to the upsample background image.
|
||||
|
||||
Args:
|
||||
model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
|
||||
upscale (float): The upscale of the final output. Default: 2.
|
||||
arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
|
||||
channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
|
||||
bg_upsampler (nn.Module): The upsampler for the background. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, model_path, model_dir, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
|
||||
self.upscale = upscale
|
||||
self.bg_upsampler = bg_upsampler
|
||||
|
||||
# initialize model
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
||||
# initialize the GFP-GAN
|
||||
if arch == 'clean':
|
||||
self.gfpgan = GFPGANv1Clean(
|
||||
out_size=512,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=channel_multiplier,
|
||||
decoder_load_path=None,
|
||||
fix_decoder=False,
|
||||
num_mlp=8,
|
||||
input_is_latent=True,
|
||||
different_w=True,
|
||||
narrow=1,
|
||||
sft_half=True)
|
||||
elif arch == 'bilinear':
|
||||
self.gfpgan = GFPGANBilinear(
|
||||
out_size=512,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=channel_multiplier,
|
||||
decoder_load_path=None,
|
||||
fix_decoder=False,
|
||||
num_mlp=8,
|
||||
input_is_latent=True,
|
||||
different_w=True,
|
||||
narrow=1,
|
||||
sft_half=True)
|
||||
elif arch == 'original':
|
||||
self.gfpgan = GFPGANv1(
|
||||
out_size=512,
|
||||
num_style_feat=512,
|
||||
channel_multiplier=channel_multiplier,
|
||||
decoder_load_path=None,
|
||||
fix_decoder=True,
|
||||
num_mlp=8,
|
||||
input_is_latent=True,
|
||||
different_w=True,
|
||||
narrow=1,
|
||||
sft_half=True)
|
||||
elif arch == 'RestoreFormer':
|
||||
from gfpgan.archs.restoreformer_arch import RestoreFormer
|
||||
self.gfpgan = RestoreFormer()
|
||||
# initialize face helper
|
||||
self.face_helper = FaceRestoreHelper(
|
||||
upscale,
|
||||
face_size=512,
|
||||
crop_ratio=(1, 1),
|
||||
det_model='retinaface_resnet50',
|
||||
save_ext='png',
|
||||
use_parse=True,
|
||||
device=self.device,
|
||||
model_rootpath=model_dir)
|
||||
|
||||
if model_path.startswith('https://'):
|
||||
model_path = load_file_from_url(
|
||||
url=model_path, model_dir=model_dir, progress=True, file_name=None)
|
||||
loadnet = torch.load(model_path)
|
||||
if 'params_ema' in loadnet:
|
||||
keyname = 'params_ema'
|
||||
else:
|
||||
keyname = 'params'
|
||||
self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
|
||||
self.gfpgan.eval()
|
||||
self.gfpgan = self.gfpgan.to(self.device)
|
||||
|
||||
@torch.no_grad()
|
||||
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
|
||||
self.face_helper.clean_all()
|
||||
|
||||
if has_aligned: # the inputs are already aligned
|
||||
img = cv2.resize(img, (512, 512))
|
||||
self.face_helper.cropped_faces = [img]
|
||||
else:
|
||||
self.face_helper.read_image(img)
|
||||
# get face landmarks for each face
|
||||
self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
|
||||
# eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
|
||||
# TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
|
||||
# align and warp each face
|
||||
self.face_helper.align_warp_face()
|
||||
|
||||
# face restoration
|
||||
for cropped_face in self.face_helper.cropped_faces:
|
||||
# prepare data
|
||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
|
||||
|
||||
try:
|
||||
output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
|
||||
# convert to image
|
||||
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
|
||||
except RuntimeError as error:
|
||||
print(f'\tFailed inference for GFPGAN: {error}.')
|
||||
restored_face = cropped_face
|
||||
|
||||
restored_face = restored_face.astype('uint8')
|
||||
self.face_helper.add_restored_face(restored_face)
|
||||
|
||||
if not has_aligned and paste_back:
|
||||
# upsample the background
|
||||
if self.bg_upsampler is not None:
|
||||
# Now only support RealESRGAN for upsampling background
|
||||
bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
|
||||
else:
|
||||
bg_img = None
|
||||
|
||||
self.face_helper.get_inverse_affine(None)
|
||||
# paste each restored face to the input image
|
||||
restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
|
||||
return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
|
||||
else:
|
||||
return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
|
@ -11,7 +11,6 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
||||
from fonts.ttf import Roboto
|
||||
import string
|
||||
|
||||
import modules.shared
|
||||
from modules import sd_samplers, shared
|
||||
from modules.shared import opts, cmd_opts
|
||||
|
||||
@ -52,8 +51,8 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
||||
cols = math.ceil((w - overlap) / non_overlap_width)
|
||||
rows = math.ceil((h - overlap) / non_overlap_height)
|
||||
|
||||
dx = (w - tile_w) / (cols-1) if cols > 1 else 0
|
||||
dy = (h - tile_h) / (rows-1) if rows > 1 else 0
|
||||
dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
|
||||
dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
|
||||
|
||||
grid = Grid([], tile_w, tile_h, w, h, overlap)
|
||||
for row in range(rows):
|
||||
@ -67,7 +66,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
|
||||
for col in range(cols):
|
||||
x = int(col * dx)
|
||||
|
||||
if x+tile_w >= w:
|
||||
if x + tile_w >= w:
|
||||
x = w - tile_w
|
||||
|
||||
tile = image.crop((x, y, x + tile_w, y + tile_h))
|
||||
@ -85,8 +84,10 @@ def combine_grid(grid):
|
||||
r = r.astype(np.uint8)
|
||||
return Image.fromarray(r, 'L')
|
||||
|
||||
mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
|
||||
mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
|
||||
mask_w = make_mask_image(
|
||||
np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
|
||||
mask_h = make_mask_image(
|
||||
np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
|
||||
|
||||
combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
|
||||
for y, h, row in grid.tiles:
|
||||
@ -129,10 +130,12 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
||||
|
||||
def draw_texts(drawing, draw_x, draw_y, lines):
|
||||
for i, line in enumerate(lines):
|
||||
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
|
||||
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt,
|
||||
fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
|
||||
|
||||
if not line.is_active:
|
||||
drawing.line((draw_x - line.size[0]//2, draw_y + line.size[1]//2, draw_x + line.size[0]//2, draw_y + line.size[1]//2), fill=color_inactive, width=4)
|
||||
drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2,
|
||||
draw_y + line.size[1] // 2), fill=color_inactive, width=4)
|
||||
|
||||
draw_y += line.size[1] + line_spacing
|
||||
|
||||
@ -171,7 +174,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
|
||||
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
|
||||
|
||||
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
|
||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
|
||||
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
|
||||
ver_texts]
|
||||
|
||||
pad_top = max(hor_text_heights) + line_spacing * 2
|
||||
|
||||
@ -202,8 +206,10 @@ def draw_prompt_matrix(im, width, height, all_prompts):
|
||||
prompts_horiz = prompts[:boundary]
|
||||
prompts_vert = prompts[boundary:]
|
||||
|
||||
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
|
||||
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
|
||||
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in
|
||||
range(1 << len(prompts_horiz))]
|
||||
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in
|
||||
range(1 << len(prompts_vert))]
|
||||
|
||||
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
|
||||
|
||||
@ -214,7 +220,8 @@ def resize_image(resize_mode, im, width, height):
|
||||
return im.resize((w, h), resample=LANCZOS)
|
||||
|
||||
upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0]
|
||||
return upscaler.upscale(im, w, h)
|
||||
scale = w / im.width
|
||||
return upscaler.scaler.upscale(im, scale)
|
||||
|
||||
if resize_mode == 0:
|
||||
res = resize(im, width, height)
|
||||
@ -244,11 +251,13 @@ def resize_image(resize_mode, im, width, height):
|
||||
if ratio < src_ratio:
|
||||
fill_height = height // 2 - src_h // 2
|
||||
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
|
||||
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
|
||||
res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
|
||||
box=(0, fill_height + src_h))
|
||||
elif ratio > src_ratio:
|
||||
fill_width = width // 2 - src_w // 2
|
||||
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
|
||||
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
|
||||
res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
|
||||
box=(fill_width + src_w, 0))
|
||||
|
||||
return res
|
||||
|
||||
@ -256,7 +265,7 @@ def resize_image(resize_mode, im, width, height):
|
||||
invalid_filename_chars = '<>:"/\\|?*\n'
|
||||
invalid_filename_prefix = ' '
|
||||
invalid_filename_postfix = ' .'
|
||||
re_nonletters = re.compile(r'[\s'+string.punctuation+']+')
|
||||
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
||||
max_filename_part_length = 128
|
||||
|
||||
|
||||
@ -283,7 +292,8 @@ def apply_filename_pattern(x, p, seed, prompt):
|
||||
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
|
||||
if len(words) == 0:
|
||||
words = ["empty"]
|
||||
x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
|
||||
x = x.replace("[prompt_words]",
|
||||
sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
|
||||
|
||||
if p is not None:
|
||||
x = x.replace("[steps]", str(p.steps))
|
||||
@ -291,7 +301,8 @@ def apply_filename_pattern(x, p, seed, prompt):
|
||||
x = x.replace("[width]", str(p.width))
|
||||
x = x.replace("[height]", str(p.height))
|
||||
x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False))
|
||||
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
||||
x = x.replace("[sampler]",
|
||||
sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
|
||||
|
||||
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
|
||||
x = x.replace("[date]", datetime.date.today().isoformat())
|
||||
@ -303,6 +314,7 @@ def apply_filename_pattern(x, p, seed, prompt):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def get_next_sequence_number(path, basename):
|
||||
"""
|
||||
Determines and returns the next sequence number to use when saving an image in the specified directory.
|
||||
@ -316,7 +328,8 @@ def get_next_sequence_number(path, basename):
|
||||
prefix_length = len(basename)
|
||||
for p in os.listdir(path):
|
||||
if p.startswith(basename):
|
||||
l = os.path.splitext(p[prefix_length:])[0].split('-') #splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
||||
l = os.path.splitext(p[prefix_length:])[0].split(
|
||||
'-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
||||
try:
|
||||
result = max(int(l[0]), result)
|
||||
except ValueError:
|
||||
@ -324,7 +337,10 @@ def get_next_sequence_number(path, basename):
|
||||
|
||||
return result + 1
|
||||
|
||||
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""):
|
||||
|
||||
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False,
|
||||
no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None,
|
||||
forced_filename=None, suffix=""):
|
||||
if short_filename or prompt is None or seed is None:
|
||||
file_decoration = ""
|
||||
elif opts.save_to_dirs:
|
||||
@ -361,7 +377,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
fullfn = "a.png"
|
||||
fullfn_without_extension = "a"
|
||||
for i in range(500):
|
||||
fn = f"{basecount+i:05}" if basename == '' else f"{basename}-{basecount+i:04}"
|
||||
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
|
||||
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
|
||||
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
|
||||
if not os.path.exists(fullfn):
|
||||
@ -403,31 +419,3 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
||||
file.write(info + "\n")
|
||||
|
||||
|
||||
class Upscaler:
|
||||
name = "Lanczos"
|
||||
|
||||
def do_upscale(self, img):
|
||||
return img
|
||||
|
||||
def upscale(self, img, w, h):
|
||||
for i in range(3):
|
||||
if img.width >= w and img.height >= h:
|
||||
break
|
||||
|
||||
img = self.do_upscale(img)
|
||||
|
||||
if img.width != w or img.height != h:
|
||||
img = img.resize((int(w), int(h)), resample=LANCZOS)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class UpscalerNone(Upscaler):
|
||||
name = "None"
|
||||
|
||||
def upscale(self, img, w, h):
|
||||
return img
|
||||
|
||||
|
||||
modules.shared.sd_upscalers.append(UpscalerNone())
|
||||
modules.shared.sd_upscalers.append(Upscaler())
|
||||
|
@ -1,74 +1,45 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
|
||||
from modules import shared, images, modelloader, paths
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.ldsr_model_arch import LDSR
|
||||
from modules import shared
|
||||
from modules.paths import models_path
|
||||
|
||||
model_dir = "LDSR"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
cmd_path = None
|
||||
model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
||||
yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
||||
|
||||
LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"])
|
||||
|
||||
ldsr_models = []
|
||||
have_ldsr = False
|
||||
LDSR_obj = None
|
||||
|
||||
|
||||
class UpscalerLDSR(images.Upscaler):
|
||||
def __init__(self, steps):
|
||||
self.steps = steps
|
||||
class UpscalerLDSR(Upscaler):
|
||||
def __init__(self, user_path):
|
||||
self.name = "LDSR"
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
self.user_path = user_path
|
||||
self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
|
||||
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
|
||||
super().__init__()
|
||||
scaler_data = UpscalerData("LDSR", None, self)
|
||||
self.scalers = [scaler_data]
|
||||
|
||||
def do_upscale(self, img):
|
||||
return upscale_with_ldsr(img)
|
||||
def load_model(self, path: str):
|
||||
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||
file_name="model.pth", progress=True)
|
||||
yaml = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
||||
file_name="project.yaml", progress=True)
|
||||
|
||||
try:
|
||||
return LDSR(model, yaml)
|
||||
|
||||
def setup_model(dirname):
|
||||
global cmd_path
|
||||
global model_path
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
cmd_path = dirname
|
||||
shared.sd_upscalers.append(UpscalerLDSR(100))
|
||||
except Exception:
|
||||
print("Error importing LDSR:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return None
|
||||
|
||||
|
||||
def prepare_ldsr():
|
||||
path = paths.paths.get("LDSR", None)
|
||||
if path is None:
|
||||
return
|
||||
global have_ldsr
|
||||
global LDSR_obj
|
||||
try:
|
||||
from LDSR import LDSR
|
||||
model_files = modelloader.load_models(model_path, model_url, cmd_path, dl_name="model.ckpt", ext_filter=[".ckpt"])
|
||||
yaml_files = modelloader.load_models(model_path, yaml_url, cmd_path, dl_name="project.yaml", ext_filter=[".yaml"])
|
||||
if len(model_files) != 0 and len(yaml_files) != 0:
|
||||
model_file = model_files[0]
|
||||
yaml_file = yaml_files[0]
|
||||
have_ldsr = True
|
||||
LDSR_obj = LDSR(model_file, yaml_file)
|
||||
else:
|
||||
return
|
||||
|
||||
except Exception:
|
||||
print("Error importing LDSR:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
have_ldsr = False
|
||||
|
||||
|
||||
def upscale_with_ldsr(image):
|
||||
prepare_ldsr()
|
||||
if not have_ldsr or LDSR_obj is None:
|
||||
return image
|
||||
|
||||
ddim_steps = shared.opts.ldsr_steps
|
||||
pre_scale = shared.opts.ldsr_pre_down
|
||||
post_scale = shared.opts.ldsr_post_down
|
||||
|
||||
image = LDSR_obj.super_resolution(image, ddim_steps, pre_scale, post_scale)
|
||||
return image
|
||||
def do_upscale(self, img, path):
|
||||
ldsr = self.load_model(path)
|
||||
if ldsr is None:
|
||||
print("NO LDSR!")
|
||||
return img
|
||||
ddim_steps = shared.opts.ldsr_steps
|
||||
pre_scale = shared.opts.ldsr_pre_down
|
||||
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
||||
|
223
modules/ldsr_model_arch.py
Normal file
223
modules/ldsr_model_arch.py
Normal file
@ -0,0 +1,223 @@
|
||||
import gc
|
||||
import time
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from einops import rearrange, repeat
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.util import instantiate_from_config, ismap
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
|
||||
|
||||
# Create LDSR Class
|
||||
class LDSR:
|
||||
def load_model_from_config(self, half_attention):
|
||||
print(f"Loading model from {self.modelPath}")
|
||||
pl_sd = torch.load(self.modelPath, map_location="cpu")
|
||||
sd = pl_sd["state_dict"]
|
||||
config = OmegaConf.load(self.yamlPath)
|
||||
model = instantiate_from_config(config.model)
|
||||
model.load_state_dict(sd, strict=False)
|
||||
model.cuda()
|
||||
if half_attention:
|
||||
model = model.half()
|
||||
|
||||
model.eval()
|
||||
return {"model": model}
|
||||
|
||||
def __init__(self, model_path, yaml_path):
|
||||
self.modelPath = model_path
|
||||
self.yamlPath = yaml_path
|
||||
|
||||
@staticmethod
|
||||
def run(model, selected_path, custom_steps, eta):
|
||||
example = get_cond(selected_path)
|
||||
|
||||
n_runs = 1
|
||||
guider = None
|
||||
ckwargs = None
|
||||
ddim_use_x0_pred = False
|
||||
temperature = 1.
|
||||
eta = eta
|
||||
custom_shape = None
|
||||
|
||||
height, width = example["image"].shape[1:3]
|
||||
split_input = height >= 128 and width >= 128
|
||||
|
||||
if split_input:
|
||||
ks = 128
|
||||
stride = 64
|
||||
vqf = 4 #
|
||||
model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
|
||||
"vqf": vqf,
|
||||
"patch_distributed_vq": True,
|
||||
"tie_braker": False,
|
||||
"clip_max_weight": 0.5,
|
||||
"clip_min_weight": 0.01,
|
||||
"clip_max_tie_weight": 0.5,
|
||||
"clip_min_tie_weight": 0.01}
|
||||
else:
|
||||
if hasattr(model, "split_input_params"):
|
||||
delattr(model, "split_input_params")
|
||||
|
||||
x_t = None
|
||||
logs = None
|
||||
for n in range(n_runs):
|
||||
if custom_shape is not None:
|
||||
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
|
||||
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
|
||||
|
||||
logs = make_convolutional_sample(example, model,
|
||||
custom_steps=custom_steps,
|
||||
eta=eta, quantize_x0=False,
|
||||
custom_shape=custom_shape,
|
||||
temperature=temperature, noise_dropout=0.,
|
||||
corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
|
||||
ddim_use_x0_pred=ddim_use_x0_pred
|
||||
)
|
||||
return logs
|
||||
|
||||
def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
|
||||
model = self.load_model_from_config(half_attention)
|
||||
|
||||
# Run settings
|
||||
diffusion_steps = int(steps)
|
||||
eta = 1.0
|
||||
|
||||
down_sample_method = 'Lanczos'
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
im_og = image
|
||||
width_og, height_og = im_og.size
|
||||
# If we can adjust the max upscale size, then the 4 below should be our variable
|
||||
print("Foo")
|
||||
down_sample_rate = target_scale / 4
|
||||
print(f"Downsample rate is {down_sample_rate}")
|
||||
width_downsampled_pre = width_og * down_sample_rate
|
||||
height_downsampled_pre = height_og * down_sample_method
|
||||
|
||||
if down_sample_rate != 1:
|
||||
print(
|
||||
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
|
||||
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
||||
else:
|
||||
print(f"Down sample rate is 1 from {target_scale} / 4")
|
||||
logs = self.run(model["model"], im_og, diffusion_steps, eta)
|
||||
|
||||
sample = logs["sample"]
|
||||
sample = sample.detach().cpu()
|
||||
sample = torch.clamp(sample, -1., 1.)
|
||||
sample = (sample + 1.) / 2. * 255
|
||||
sample = sample.numpy().astype(np.uint8)
|
||||
sample = np.transpose(sample, (0, 2, 3, 1))
|
||||
a = Image.fromarray(sample[0])
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
print(f'Processing finished!')
|
||||
return a
|
||||
|
||||
|
||||
def get_cond(selected_path):
|
||||
example = dict()
|
||||
up_f = 4
|
||||
c = selected_path.convert('RGB')
|
||||
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
||||
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
|
||||
antialias=True)
|
||||
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
|
||||
c = rearrange(c, '1 c h w -> 1 h w c')
|
||||
c = 2. * c - 1.
|
||||
|
||||
c = c.to(torch.device("cuda"))
|
||||
example["LR_image"] = c
|
||||
example["image"] = c_up
|
||||
|
||||
return example
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
|
||||
mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
|
||||
corrector_kwargs=None, x_t=None
|
||||
):
|
||||
ddim = DDIMSampler(model)
|
||||
bs = shape[0]
|
||||
shape = shape[1:]
|
||||
print(f"Sampling with eta = {eta}; steps: {steps}")
|
||||
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
|
||||
normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
|
||||
mask=mask, x0=x0, temperature=temperature, verbose=False,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs, x_t=x_t)
|
||||
|
||||
return samples, intermediates
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
|
||||
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
|
||||
log = dict()
|
||||
|
||||
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
|
||||
return_first_stage_outputs=True,
|
||||
force_c_encode=not (hasattr(model, 'split_input_params')
|
||||
and model.cond_stage_key == 'coordinates_bbox'),
|
||||
return_original_cond=True)
|
||||
|
||||
if custom_shape is not None:
|
||||
z = torch.randn(custom_shape)
|
||||
print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
|
||||
|
||||
z0 = None
|
||||
|
||||
log["input"] = x
|
||||
log["reconstruction"] = xrec
|
||||
|
||||
if ismap(xc):
|
||||
log["original_conditioning"] = model.to_rgb(xc)
|
||||
if hasattr(model, 'cond_stage_key'):
|
||||
log[model.cond_stage_key] = model.to_rgb(xc)
|
||||
|
||||
else:
|
||||
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
|
||||
if model.cond_stage_model:
|
||||
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
|
||||
if model.cond_stage_key == 'class_label':
|
||||
log[model.cond_stage_key] = xc[model.cond_stage_key]
|
||||
|
||||
with model.ema_scope("Plotting"):
|
||||
t0 = time.time()
|
||||
|
||||
sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
|
||||
eta=eta,
|
||||
quantize_x0=quantize_x0, mask=None, x0=z0,
|
||||
temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
|
||||
x_t=x_T)
|
||||
t1 = time.time()
|
||||
|
||||
if ddim_use_x0_pred:
|
||||
sample = intermediates['pred_x0'][-1]
|
||||
|
||||
x_sample = model.decode_first_stage(sample)
|
||||
|
||||
try:
|
||||
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
||||
log["sample_noquant"] = x_sample_noquant
|
||||
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
||||
except:
|
||||
pass
|
||||
|
||||
log["sample"] = x_sample
|
||||
log["time"] = t1 - t0
|
||||
|
||||
return log
|
@ -1,34 +1,36 @@
|
||||
import os
|
||||
import shutil
|
||||
import importlib
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
from modules import shared
|
||||
from modules.upscaler import Upscaler
|
||||
from modules.paths import script_path, models_path
|
||||
|
||||
|
||||
def load_models(model_path: str, model_url: str = None, command_path: str = None, dl_name: str = None, existing=None,
|
||||
ext_filter=None) -> list:
|
||||
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
|
||||
"""
|
||||
A one-and done loader to try finding the desired models in specified directories.
|
||||
|
||||
@param dl_name: The file name to use for downloading a model. If not specified, it will be used from the URL.
|
||||
@param model_url: If specified, attempt to download model from the given URL.
|
||||
@param download_name: Specify to download from model_url immediately.
|
||||
@param model_url: If no other models are found, this will be downloaded on upscale.
|
||||
@param model_path: The location to store/find models in.
|
||||
@param command_path: A command-line argument to search for models in first.
|
||||
@param existing: An array of existing model paths.
|
||||
@param ext_filter: An optional list of filename extensions to filter by
|
||||
@return: A list of paths containing the desired model(s)
|
||||
"""
|
||||
output = []
|
||||
|
||||
if ext_filter is None:
|
||||
ext_filter = []
|
||||
if existing is None:
|
||||
existing = []
|
||||
try:
|
||||
places = []
|
||||
if command_path is not None and command_path != model_path:
|
||||
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
|
||||
if os.path.exists(pretrained_path):
|
||||
print(f"Appending path: {pretrained_path}")
|
||||
places.append(pretrained_path)
|
||||
elif os.path.exists(command_path):
|
||||
places.append(command_path)
|
||||
@ -36,26 +38,24 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
||||
for place in places:
|
||||
if os.path.exists(place):
|
||||
for file in os.listdir(place):
|
||||
if os.path.isdir(file):
|
||||
full_path = os.path.join(place, file)
|
||||
if os.path.isdir(full_path):
|
||||
continue
|
||||
if len(ext_filter) != 0:
|
||||
model_name, extension = os.path.splitext(file)
|
||||
if extension not in ext_filter:
|
||||
continue
|
||||
if file not in existing:
|
||||
path = os.path.join(place, file)
|
||||
existing.append(path)
|
||||
if model_url is not None and len(existing) == 0:
|
||||
if dl_name is not None:
|
||||
model_file = load_file_from_url(url=model_url, model_dir=model_path, file_name=dl_name, progress=True)
|
||||
if file not in output:
|
||||
output.append(full_path)
|
||||
if model_url is not None and len(output) == 0:
|
||||
if download_name is not None:
|
||||
dl = load_file_from_url(model_url, model_path, True, download_name)
|
||||
output.append(dl)
|
||||
else:
|
||||
model_file = load_file_from_url(url=model_url, model_dir=model_path, progress=True)
|
||||
|
||||
if os.path.exists(model_file) and os.path.isfile(model_file) and model_file not in existing:
|
||||
existing.append(model_file)
|
||||
output.append(model_url)
|
||||
except:
|
||||
pass
|
||||
return existing
|
||||
return output
|
||||
|
||||
|
||||
def friendly_name(file: str):
|
||||
@ -110,4 +110,38 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
||||
print(f"Removing empty folder: {src_path}")
|
||||
shutil.rmtree(src_path, True)
|
||||
except:
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def load_upscalers():
|
||||
datas = []
|
||||
for cls in Upscaler.__subclasses__():
|
||||
name = cls.__name__
|
||||
module_name = cls.__module__
|
||||
print(f"Class: {name} and {module_name}")
|
||||
module = importlib.import_module(module_name)
|
||||
class_ = getattr(module, name)
|
||||
cmd_name = f"{name.lower().replace('upscaler', '')}-models-path"
|
||||
print(f"CMD Name: {cmd_name}")
|
||||
opt_string = None
|
||||
try:
|
||||
opt_string = shared.opts.__getattr__(cmd_name)
|
||||
except:
|
||||
pass
|
||||
scaler = class_(opt_string)
|
||||
for child in scaler.scalers:
|
||||
print(f"Appending {child.name}")
|
||||
datas.append(child)
|
||||
|
||||
shared.sd_upscalers = datas
|
||||
|
||||
# for scaler in subclasses:
|
||||
# print(f"Found scaler: {type(scaler).__name__}")
|
||||
# try:
|
||||
# scaler = scaler()
|
||||
# for child in scaler.scalers:
|
||||
# print(f"Appending {child.name}")
|
||||
# datas.append[child]
|
||||
# except:
|
||||
# pass
|
||||
# shared.sd_upscalers = datas
|
||||
|
@ -1,64 +1,135 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from collections import namedtuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
import modules.images
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.paths import models_path
|
||||
from modules.shared import cmd_opts, opts
|
||||
|
||||
model_dir = "RealESRGAN"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
cmd_dir = None
|
||||
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
|
||||
realesrgan_models = []
|
||||
have_realesrgan = False
|
||||
|
||||
class UpscalerRealESRGAN(Upscaler):
|
||||
def __init__(self, path):
|
||||
self.name = "RealESRGAN"
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
self.user_path = path
|
||||
super().__init__()
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
self.enable = True
|
||||
self.scalers = []
|
||||
scalers = self.load_models(path)
|
||||
for scaler in scalers:
|
||||
if scaler.name in opts.realesrgan_enabled_models:
|
||||
self.scalers.append(scaler)
|
||||
|
||||
except Exception:
|
||||
print("Error importing Real-ESRGAN:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
self.enable = False
|
||||
self.scalers = []
|
||||
|
||||
def do_upscale(self, img, path):
|
||||
if not self.enable:
|
||||
return img
|
||||
|
||||
info = self.load_model(path)
|
||||
if not os.path.exists(info.data_path):
|
||||
print("Unable to load RealESRGAN model: %s" % info.name)
|
||||
return img
|
||||
|
||||
upsampler = RealESRGANer(
|
||||
scale=info.scale,
|
||||
model_path=info.data_path,
|
||||
model=info.model(),
|
||||
half=not cmd_opts.no_half,
|
||||
tile=opts.ESRGAN_tile,
|
||||
tile_pad=opts.ESRGAN_tile_overlap,
|
||||
)
|
||||
|
||||
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
|
||||
|
||||
image = Image.fromarray(upsampled)
|
||||
return image
|
||||
|
||||
def load_model(self, path):
|
||||
try:
|
||||
info = None
|
||||
for scaler in self.scalers:
|
||||
if scaler.data_path == path:
|
||||
info = scaler
|
||||
|
||||
if info is None:
|
||||
print(f"Unable to find model info: {path}")
|
||||
return None
|
||||
|
||||
model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
|
||||
info.data_path = model_file
|
||||
return info
|
||||
except Exception as e:
|
||||
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return None
|
||||
|
||||
def load_models(self, _):
|
||||
return get_realesrgan_models(self)
|
||||
|
||||
|
||||
def get_realesrgan_models():
|
||||
def get_realesrgan_models(scaler):
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
models = [
|
||||
RealesrganModelInfo(
|
||||
name="Real-ESRGAN General x4x3",
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
|
||||
netscale=4,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
UpscalerData(
|
||||
name="R-ESRGAN General 4xV3",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3"
|
||||
".pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4,
|
||||
act_type='prelu')
|
||||
),
|
||||
RealesrganModelInfo(
|
||||
name="Real-ESRGAN General WDN x4x3",
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
|
||||
netscale=4,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
UpscalerData(
|
||||
name="R-ESRGAN General WDN 4xV3",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4,
|
||||
act_type='prelu')
|
||||
),
|
||||
RealesrganModelInfo(
|
||||
name="Real-ESRGAN AnimeVideo",
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
|
||||
netscale=4,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
|
||||
UpscalerData(
|
||||
name="R-ESRGAN AnimeVideo",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4,
|
||||
act_type='prelu')
|
||||
),
|
||||
RealesrganModelInfo(
|
||||
name="Real-ESRGAN 4x plus",
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
netscale=4,
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 4x+",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
||||
),
|
||||
RealesrganModelInfo(
|
||||
name="Real-ESRGAN 4x plus anime 6B",
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
netscale=4,
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 4x+ Anime6B",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
|
||||
scale=4,
|
||||
upscaler=scaler,
|
||||
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||
),
|
||||
RealesrganModelInfo(
|
||||
name="Real-ESRGAN 2x plus",
|
||||
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
netscale=2,
|
||||
UpscalerData(
|
||||
name="R-ESRGAN 2x+",
|
||||
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||
scale=2,
|
||||
upscaler=scaler,
|
||||
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
|
||||
),
|
||||
]
|
||||
@ -66,69 +137,3 @@ def get_realesrgan_models():
|
||||
except Exception as e:
|
||||
print("Error making Real-ESRGAN models list:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
class UpscalerRealESRGAN(modules.images.Upscaler):
|
||||
def __init__(self, upscaling, model_index):
|
||||
self.upscaling = upscaling
|
||||
self.model_index = model_index
|
||||
self.name = realesrgan_models[model_index].name
|
||||
|
||||
def do_upscale(self, img):
|
||||
return upscale_with_realesrgan(img, self.upscaling, self.model_index)
|
||||
|
||||
|
||||
def setup_model(dirname):
|
||||
global model_path
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
|
||||
global realesrgan_models
|
||||
global have_realesrgan
|
||||
if model_path != dirname:
|
||||
model_path = dirname
|
||||
try:
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan import RealESRGANer
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
|
||||
realesrgan_models = get_realesrgan_models()
|
||||
have_realesrgan = True
|
||||
|
||||
for i, model in enumerate(realesrgan_models):
|
||||
if model.name in opts.realesrgan_enabled_models:
|
||||
modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i))
|
||||
|
||||
except Exception:
|
||||
print("Error importing Real-ESRGAN:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
realesrgan_models = [RealesrganModelInfo('None', '', 0, None)]
|
||||
have_realesrgan = False
|
||||
|
||||
|
||||
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
|
||||
if not have_realesrgan:
|
||||
return image
|
||||
|
||||
info = realesrgan_models[RealESRGAN_model_index]
|
||||
|
||||
model = info.model()
|
||||
model_file = load_file_from_url(url=info.location, model_dir=model_path, progress=True)
|
||||
if not os.path.exists(model_file):
|
||||
print("Unable to load RealESRGAN model: %s" % info.name)
|
||||
return image
|
||||
|
||||
upsampler = RealESRGANer(
|
||||
scale=info.netscale,
|
||||
model_path=info.location,
|
||||
model=model,
|
||||
half=not cmd_opts.no_half,
|
||||
tile=opts.ESRGAN_tile,
|
||||
tile_pad=opts.ESRGAN_tile_overlap,
|
||||
)
|
||||
|
||||
upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
|
||||
|
||||
image = Image.fromarray(upsampled)
|
||||
return image
|
||||
|
@ -50,7 +50,7 @@ def setup_model(dirname):
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
checkpoints_list.clear()
|
||||
model_list = modelloader.load_models(model_path, model_url, dirname, model_name, ext_filter=".ckpt")
|
||||
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=dirname, download_name=model_name, ext_filter=".ckpt")
|
||||
|
||||
cmd_ckpt = shared.cmd_opts.ckpt
|
||||
if os.path.exists(cmd_ckpt):
|
||||
@ -68,6 +68,7 @@ def setup_model(dirname):
|
||||
|
||||
def model_hash(filename):
|
||||
try:
|
||||
print(f"Opening: {filename}")
|
||||
with open(filename, "rb") as file:
|
||||
import hashlib
|
||||
m = hashlib.sha256()
|
||||
|
@ -154,9 +154,9 @@ class VanillaStableDiffusionSampler:
|
||||
|
||||
# existing code fails with cetin step counts, like 9
|
||||
try:
|
||||
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta)
|
||||
samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_t=x, eta=p.ddim_eta)
|
||||
except Exception:
|
||||
samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta)
|
||||
samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_t=x, eta=p.ddim_eta)
|
||||
|
||||
return samples_ddim
|
||||
|
||||
|
@ -1,18 +1,19 @@
|
||||
import sys
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import gradio as gr
|
||||
import tqdm
|
||||
import datetime
|
||||
|
||||
import modules.artists
|
||||
from modules.paths import script_path, sd_path
|
||||
from modules.devices import get_optimal_device
|
||||
import modules.styles
|
||||
import modules.interrogate
|
||||
import modules.memmon
|
||||
import modules.sd_models
|
||||
import modules.styles
|
||||
from modules.devices import get_optimal_device
|
||||
from modules.paths import script_path, sd_path
|
||||
|
||||
sd_model_file = os.path.join(script_path, 'model.ckpt')
|
||||
default_sd_model_file = sd_model_file
|
||||
@ -38,6 +39,7 @@ parser.add_argument("--share", action='store_true', help="use share=True for gra
|
||||
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer'))
|
||||
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN'))
|
||||
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
|
||||
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
|
||||
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
|
||||
parser.add_argument("--stablediffusion-models-path", type=str, help="Path to directory with Stable-diffusion checkpoints.", default=os.path.join(model_path, 'SwinIR'))
|
||||
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
|
||||
@ -111,7 +113,7 @@ face_restorers = []
|
||||
|
||||
def realesrgan_models_names():
|
||||
import modules.realesrgan_model
|
||||
return [x.name for x in modules.realesrgan_model.get_realesrgan_models()]
|
||||
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
|
||||
|
||||
|
||||
class OptionInfo:
|
||||
@ -176,13 +178,11 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
|
||||
options_templates.update(options_section(('upscaling', "Upscaling"), {
|
||||
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"realesrgan_enabled_models": OptionInfo(["Real-ESRGAN 4x plus", "Real-ESRGAN 4x plus anime 6B"], "Select which RealESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
|
||||
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
|
||||
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
|
||||
"ldsr_pre_down": OptionInfo(1, "LDSR Pre-process downssample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
|
||||
"ldsr_post_down": OptionInfo(1, "LDSR Post-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
|
||||
|
||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||
}))
|
||||
|
||||
|
@ -1,92 +1,91 @@
|
||||
import contextlib
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
|
||||
import modules.images
|
||||
from modules import modelloader
|
||||
from modules.paths import models_path
|
||||
from modules.shared import cmd_opts, opts, device
|
||||
from modules.swinir_model_arch import SwinIR as net
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
|
||||
model_dir = "SwinIR"
|
||||
model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
|
||||
model_name = "SwinIR x4"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
cmd_path = ""
|
||||
precision_scope = (
|
||||
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
|
||||
)
|
||||
|
||||
|
||||
def load_model(path, scale=4):
|
||||
global model_path
|
||||
global model_name
|
||||
if "http" in path:
|
||||
dl_name = "%s%s" % (model_name.replace(" ", "_"), ".pth")
|
||||
filename = load_file_from_url(url=path, model_dir=model_path, file_name=dl_name, progress=True)
|
||||
else:
|
||||
filename = path
|
||||
if filename is None or not os.path.exists(filename):
|
||||
return None
|
||||
model = net(
|
||||
upscale=scale,
|
||||
in_chans=3,
|
||||
img_size=64,
|
||||
window_size=8,
|
||||
img_range=1.0,
|
||||
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
||||
embed_dim=240,
|
||||
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
||||
mlp_ratio=2,
|
||||
upsampler="nearest+conv",
|
||||
resi_connection="3conv",
|
||||
)
|
||||
class UpscalerSwinIR(Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self.name = "SwinIR"
|
||||
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
|
||||
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
|
||||
"-L_x4_GAN.pth "
|
||||
self.model_name = "SwinIR 4x"
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
self.user_path = dirname
|
||||
super().__init__()
|
||||
scalers = []
|
||||
model_files = self.find_models(ext_filter=[".pt", ".pth"])
|
||||
for model in model_files:
|
||||
if "http" in model:
|
||||
name = self.model_name
|
||||
else:
|
||||
name = modelloader.friendly_name(model)
|
||||
model_data = UpscalerData(name, model, self)
|
||||
scalers.append(model_data)
|
||||
self.scalers = scalers
|
||||
|
||||
pretrained_model = torch.load(filename)
|
||||
model.load_state_dict(pretrained_model["params_ema"], strict=True)
|
||||
if not cmd_opts.no_half:
|
||||
model = model.half()
|
||||
return model
|
||||
def do_upscale(self, img, model_file):
|
||||
model = self.load_model(model_file)
|
||||
if model is None:
|
||||
return img
|
||||
model = model.to(device)
|
||||
img = upscale(img, model)
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
return img
|
||||
|
||||
|
||||
def setup_model(dirname):
|
||||
global model_path
|
||||
global model_name
|
||||
global cmd_path
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(model_path)
|
||||
cmd_path = dirname
|
||||
model_file = ""
|
||||
try:
|
||||
models = modelloader.load_models(model_path, ext_filter=[".pt", ".pth"], command_path=cmd_path)
|
||||
|
||||
if len(models) != 0:
|
||||
model_file = models[0]
|
||||
name = modelloader.friendly_name(model_file)
|
||||
def load_model(self, path, scale=4):
|
||||
if "http" in path:
|
||||
dl_name = "%s%s" % (self.name.replace(" ", "_"), ".pth")
|
||||
filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
|
||||
else:
|
||||
# Add the "default" model if none are found.
|
||||
model_file = model_url
|
||||
name = model_name
|
||||
filename = path
|
||||
if filename is None or not os.path.exists(filename):
|
||||
return None
|
||||
model = net(
|
||||
upscale=scale,
|
||||
in_chans=3,
|
||||
img_size=64,
|
||||
window_size=8,
|
||||
img_range=1.0,
|
||||
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
|
||||
embed_dim=240,
|
||||
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
|
||||
mlp_ratio=2,
|
||||
upsampler="nearest+conv",
|
||||
resi_connection="3conv",
|
||||
)
|
||||
|
||||
modules.shared.sd_upscalers.append(UpscalerSwin(model_file, name))
|
||||
except Exception:
|
||||
print(f"Error loading SwinIR model: {model_file}", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
pretrained_model = torch.load(filename)
|
||||
model.load_state_dict(pretrained_model["params_ema"], strict=True)
|
||||
if not cmd_opts.no_half:
|
||||
model = model.half()
|
||||
return model
|
||||
|
||||
|
||||
def upscale(
|
||||
img,
|
||||
model,
|
||||
tile=opts.SWIN_tile,
|
||||
tile_overlap=opts.SWIN_tile_overlap,
|
||||
window_size=8,
|
||||
scale=4,
|
||||
img,
|
||||
model,
|
||||
tile=opts.SWIN_tile,
|
||||
tile_overlap=opts.SWIN_tile_overlap,
|
||||
window_size=8,
|
||||
scale=4,
|
||||
):
|
||||
img = np.array(img)
|
||||
img = img[:, :, ::-1]
|
||||
@ -125,34 +124,16 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
||||
|
||||
for h_idx in h_idx_list:
|
||||
for w_idx in w_idx_list:
|
||||
in_patch = img[..., h_idx : h_idx + tile, w_idx : w_idx + tile]
|
||||
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
||||
out_patch = model(in_patch)
|
||||
out_patch_mask = torch.ones_like(out_patch)
|
||||
|
||||
E[
|
||||
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
|
||||
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||
].add_(out_patch)
|
||||
W[
|
||||
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
|
||||
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||
].add_(out_patch_mask)
|
||||
output = E.div_(W)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class UpscalerSwin(modules.images.Upscaler):
|
||||
def __init__(self, filename, title):
|
||||
self.name = title
|
||||
self.filename = filename
|
||||
|
||||
def do_upscale(self, img):
|
||||
model = load_model(self.filename)
|
||||
if model is None:
|
||||
return img
|
||||
model = model.to(device)
|
||||
img = upscale(img, model)
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
return img
|
121
modules/upscaler.py
Normal file
121
modules/upscaler.py
Normal file
@ -0,0 +1,121 @@
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import modules.shared
|
||||
from modules import modelloader, shared
|
||||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
from modules.paths import models_path
|
||||
|
||||
|
||||
class Upscaler:
|
||||
name = None
|
||||
model_path = None
|
||||
model_name = None
|
||||
model_url = None
|
||||
enable = True
|
||||
filter = None
|
||||
model = None
|
||||
user_path = None
|
||||
scalers: []
|
||||
tile = True
|
||||
|
||||
def __init__(self, create_dirs=False):
|
||||
self.mod_pad_h = None
|
||||
self.tile_size = modules.shared.opts.ESRGAN_tile
|
||||
self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
|
||||
self.device = modules.shared.device
|
||||
self.img = None
|
||||
self.output = None
|
||||
self.scale = 1
|
||||
self.half = not modules.shared.cmd_opts.no_half
|
||||
self.pre_pad = 0
|
||||
self.mod_scale = None
|
||||
if self.name is not None and create_dirs:
|
||||
self.model_path = os.path.join(models_path, self.name)
|
||||
if not os.path.exists(self.model_path):
|
||||
os.makedirs(self.model_path)
|
||||
|
||||
try:
|
||||
import cv2
|
||||
self.can_tile = True
|
||||
except:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def do_upscale(self, img: PIL.Image, selected_model: str):
|
||||
return img
|
||||
|
||||
def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
|
||||
self.scale = scale
|
||||
dest_w = img.width * scale
|
||||
dest_h = img.height * scale
|
||||
for i in range(3):
|
||||
if img.width >= dest_w and img.height >= dest_h:
|
||||
break
|
||||
img = self.do_upscale(img, selected_model)
|
||||
if img.width != dest_w or img.height != dest_h:
|
||||
img = img.resize(dest_w, dest_h, resample=LANCZOS)
|
||||
|
||||
return img
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, path: str):
|
||||
pass
|
||||
|
||||
def find_models(self, ext_filter=None) -> list:
|
||||
return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
|
||||
|
||||
def update_status(self, prompt):
|
||||
print(f"\nextras: {prompt}", file=shared.progress_print_out)
|
||||
|
||||
|
||||
class UpscalerData:
|
||||
name = None
|
||||
data_path = None
|
||||
scale: int = 4
|
||||
scaler: Upscaler = None
|
||||
model: None
|
||||
|
||||
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
|
||||
self.name = name
|
||||
self.data_path = path
|
||||
self.scaler = upscaler
|
||||
self.scale = scale
|
||||
self.model = model
|
||||
|
||||
|
||||
class UpscalerNone(Upscaler):
|
||||
name = "None"
|
||||
scalers = []
|
||||
|
||||
def load_model(self, path):
|
||||
pass
|
||||
|
||||
def do_upscale(self, img, selected_model=None):
|
||||
return img
|
||||
|
||||
def __init__(self, dirname=None):
|
||||
super().__init__(False)
|
||||
self.scalers = [UpscalerData("None", None, self)]
|
||||
|
||||
|
||||
class UpscalerLanczos(Upscaler):
|
||||
scalers = []
|
||||
|
||||
def do_upscale(self, img, selected_model=None):
|
||||
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
|
||||
|
||||
def load_model(self, _):
|
||||
pass
|
||||
|
||||
def __init__(self, dirname=None):
|
||||
super().__init__(False)
|
||||
self.name = "Lanczos"
|
||||
self.scalers = [UpscalerData("Lanczos", None, self)]
|
||||
|
9
webui.py
9
webui.py
@ -1,9 +1,10 @@
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
|
||||
import modules.paths
|
||||
import modules.codeformer_model as codeformer
|
||||
import modules.esrgan_model as esrgan
|
||||
import modules.bsrgan_model as bsrgan
|
||||
import modules.extras
|
||||
import modules.face_restoration
|
||||
import modules.gfpgan_model as gfpgan
|
||||
@ -27,11 +28,7 @@ modules.sd_models.setup_model(cmd_opts.stablediffusion_models_path)
|
||||
codeformer.setup_model(cmd_opts.codeformer_models_path)
|
||||
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
|
||||
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
|
||||
|
||||
esrgan.setup_model(cmd_opts.esrgan_models_path)
|
||||
swinir.setup_model(cmd_opts.swinir_models_path)
|
||||
realesrgan.setup_model(cmd_opts.realesrgan_models_path)
|
||||
ldsr.setup_model(cmd_opts.ldsr_models_path)
|
||||
modelloader.load_upscalers()
|
||||
queue_lock = threading.Lock()
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user