mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-01-18 14:55:09 +08:00
Unify CodeFormer and GFPGAN restoration backends, use Spandrel for GFPGAN
This commit is contained in:
parent
b0f5934234
commit
b621a63cf6
8
.github/workflows/run_tests.yaml
vendored
8
.github/workflows/run_tests.yaml
vendored
@ -20,6 +20,12 @@ jobs:
|
||||
cache-dependency-path: |
|
||||
**/requirements*txt
|
||||
launch.py
|
||||
- name: Cache models
|
||||
id: cache-models
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: models
|
||||
key: "2023-12-30"
|
||||
- name: Install test dependencies
|
||||
run: pip install wait-for-it -r requirements-test.txt
|
||||
env:
|
||||
@ -33,6 +39,8 @@ jobs:
|
||||
TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
|
||||
WEBUI_LAUNCH_LIVE_OUTPUT: "1"
|
||||
PYTHONUNBUFFERED: "1"
|
||||
- name: Print installed packages
|
||||
run: pip freeze
|
||||
- name: Start test server
|
||||
run: >
|
||||
python -m coverage run
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -37,3 +37,4 @@ notification.mp3
|
||||
/node_modules
|
||||
/package-lock.json
|
||||
/.coverage*
|
||||
/test/test_outputs
|
||||
|
@ -1,140 +1,62 @@
|
||||
import os
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
import modules.face_restoration
|
||||
import modules.shared
|
||||
from modules import shared, devices, modelloader, errors
|
||||
from modules.paths import models_path
|
||||
from modules import (
|
||||
devices,
|
||||
errors,
|
||||
face_restoration,
|
||||
face_restoration_utils,
|
||||
modelloader,
|
||||
shared,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
model_dir = "Codeformer"
|
||||
model_path = os.path.join(models_path, model_dir)
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_download_name = 'codeformer-v0.1.0.pth'
|
||||
|
||||
codeformer = None
|
||||
# used by e.g. postprocessing_codeformer.py
|
||||
codeformer: face_restoration.FaceRestoration | None = None
|
||||
|
||||
|
||||
class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
|
||||
class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
|
||||
def name(self):
|
||||
return "CodeFormer"
|
||||
|
||||
def __init__(self, dirname):
|
||||
self.net = None
|
||||
self.face_helper = None
|
||||
self.cmd_dir = dirname
|
||||
|
||||
def create_models(self):
|
||||
from facexlib.detection import retinaface
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
|
||||
if self.net is not None and self.face_helper is not None:
|
||||
self.net.to(devices.device_codeformer)
|
||||
return self.net, self.face_helper
|
||||
model_paths = modelloader.load_models(
|
||||
model_path,
|
||||
model_url,
|
||||
self.cmd_dir,
|
||||
download_name='codeformer-v0.1.0.pth',
|
||||
def load_net(self) -> torch.Module:
|
||||
for model_path in modelloader.load_models(
|
||||
model_path=self.model_path,
|
||||
model_url=model_url,
|
||||
command_path=self.model_path,
|
||||
download_name=model_download_name,
|
||||
ext_filter=['.pth'],
|
||||
)
|
||||
):
|
||||
return modelloader.load_spandrel_model(
|
||||
model_path,
|
||||
device=devices.device_codeformer,
|
||||
).model
|
||||
raise ValueError("No codeformer model found")
|
||||
|
||||
if len(model_paths) != 0:
|
||||
ckpt_path = model_paths[0]
|
||||
else:
|
||||
print("Unable to load codeformer model.")
|
||||
return None, None
|
||||
net = modelloader.load_spandrel_model(ckpt_path, device=devices.device_codeformer)
|
||||
def get_device(self):
|
||||
return devices.device_codeformer
|
||||
|
||||
if hasattr(retinaface, 'device'):
|
||||
retinaface.device = devices.device_codeformer
|
||||
def restore(self, np_image, w: float | None = None):
|
||||
if w is None:
|
||||
w = getattr(shared.opts, "code_former_weight", 0.5)
|
||||
|
||||
face_helper = FaceRestoreHelper(
|
||||
upscale_factor=1,
|
||||
face_size=512,
|
||||
crop_ratio=(1, 1),
|
||||
det_model='retinaface_resnet50',
|
||||
save_ext='png',
|
||||
use_parse=True,
|
||||
device=devices.device_codeformer,
|
||||
)
|
||||
def restore_face(cropped_face_t):
|
||||
assert self.net is not None
|
||||
return self.net(cropped_face_t, w=w, adain=True)[0]
|
||||
|
||||
self.net = net
|
||||
self.face_helper = face_helper
|
||||
|
||||
def send_model_to(self, device):
|
||||
self.net.to(device)
|
||||
self.face_helper.face_det.to(device)
|
||||
self.face_helper.face_parse.to(device)
|
||||
|
||||
def restore(self, np_image, w=None):
|
||||
from torchvision.transforms.functional import normalize
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
np_image = np_image[:, :, ::-1]
|
||||
|
||||
original_resolution = np_image.shape[0:2]
|
||||
|
||||
self.create_models()
|
||||
if self.net is None or self.face_helper is None:
|
||||
return np_image
|
||||
|
||||
self.send_model_to(devices.device_codeformer)
|
||||
|
||||
self.face_helper.clean_all()
|
||||
self.face_helper.read_image(np_image)
|
||||
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||
self.face_helper.align_warp_face()
|
||||
|
||||
for cropped_face in self.face_helper.cropped_faces:
|
||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
res = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)
|
||||
if isinstance(res, tuple):
|
||||
output = res[0]
|
||||
else:
|
||||
output = res
|
||||
if not isinstance(res, torch.Tensor):
|
||||
raise TypeError(f"Expected torch.Tensor, got {type(res)}")
|
||||
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
|
||||
del output
|
||||
devices.torch_gc()
|
||||
except Exception:
|
||||
errors.report('Failed inference for CodeFormer', exc_info=True)
|
||||
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
||||
|
||||
restored_face = restored_face.astype('uint8')
|
||||
self.face_helper.add_restored_face(restored_face)
|
||||
|
||||
self.face_helper.get_inverse_affine(None)
|
||||
|
||||
restored_img = self.face_helper.paste_faces_to_input_image()
|
||||
restored_img = restored_img[:, :, ::-1]
|
||||
|
||||
if original_resolution != restored_img.shape[0:2]:
|
||||
restored_img = cv2.resize(
|
||||
restored_img,
|
||||
(0, 0),
|
||||
fx=original_resolution[1]/restored_img.shape[1],
|
||||
fy=original_resolution[0]/restored_img.shape[0],
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
|
||||
self.face_helper.clean_all()
|
||||
|
||||
if shared.opts.face_restoration_unload:
|
||||
self.send_model_to(devices.cpu)
|
||||
|
||||
return restored_img
|
||||
return self.restore_with_helper(np_image, restore_face)
|
||||
|
||||
|
||||
def setup_model(dirname):
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
def setup_model(dirname: str) -> None:
|
||||
global codeformer
|
||||
try:
|
||||
global codeformer
|
||||
codeformer = FaceRestorerCodeFormer(dirname)
|
||||
shared.face_restorers.append(codeformer)
|
||||
except Exception:
|
||||
|
163
modules/face_restoration_utils.py
Normal file
163
modules/face_restoration_utils.py
Normal file
@ -0,0 +1,163 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modules import devices, errors, face_restoration, shared
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_face_helper(device) -> FaceRestoreHelper:
|
||||
from facexlib.detection import retinaface
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
if hasattr(retinaface, 'device'):
|
||||
retinaface.device = device
|
||||
return FaceRestoreHelper(
|
||||
upscale_factor=1,
|
||||
face_size=512,
|
||||
crop_ratio=(1, 1),
|
||||
det_model='retinaface_resnet50',
|
||||
save_ext='png',
|
||||
use_parse=True,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
def restore_with_face_helper(
|
||||
np_image: np.ndarray,
|
||||
face_helper: FaceRestoreHelper,
|
||||
restore_face: Callable[[np.ndarray], np.ndarray],
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
|
||||
|
||||
`restore_face` should take a cropped face image and return a restored face image.
|
||||
"""
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
from torchvision.transforms.functional import normalize
|
||||
np_image = np_image[:, :, ::-1]
|
||||
original_resolution = np_image.shape[0:2]
|
||||
|
||||
try:
|
||||
logger.debug("Detecting faces...")
|
||||
face_helper.clean_all()
|
||||
face_helper.read_image(np_image)
|
||||
face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||
face_helper.align_warp_face()
|
||||
logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
|
||||
for cropped_face in face_helper.cropped_faces:
|
||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
restored_face = tensor2img(
|
||||
restore_face(cropped_face_t),
|
||||
rgb2bgr=True,
|
||||
min_max=(-1, 1),
|
||||
)
|
||||
devices.torch_gc()
|
||||
except Exception:
|
||||
errors.report('Failed face-restoration inference', exc_info=True)
|
||||
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
|
||||
|
||||
restored_face = restored_face.astype('uint8')
|
||||
face_helper.add_restored_face(restored_face)
|
||||
|
||||
logger.debug("Merging restored faces into image")
|
||||
face_helper.get_inverse_affine(None)
|
||||
img = face_helper.paste_faces_to_input_image()
|
||||
img = img[:, :, ::-1]
|
||||
if original_resolution != img.shape[0:2]:
|
||||
img = cv2.resize(
|
||||
img,
|
||||
(0, 0),
|
||||
fx=original_resolution[1] / img.shape[1],
|
||||
fy=original_resolution[0] / img.shape[0],
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
logger.debug("Face restoration complete")
|
||||
finally:
|
||||
face_helper.clean_all()
|
||||
return img
|
||||
|
||||
|
||||
class CommonFaceRestoration(face_restoration.FaceRestoration):
|
||||
net: torch.Module | None
|
||||
model_url: str
|
||||
model_download_name: str
|
||||
|
||||
def __init__(self, model_path: str):
|
||||
super().__init__()
|
||||
self.net = None
|
||||
self.model_path = model_path
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
|
||||
@cached_property
|
||||
def face_helper(self) -> FaceRestoreHelper:
|
||||
return create_face_helper(self.get_device())
|
||||
|
||||
def send_model_to(self, device):
|
||||
if self.net:
|
||||
logger.debug("Sending %s to %s", self.net, device)
|
||||
self.net.to(device)
|
||||
if self.face_helper:
|
||||
logger.debug("Sending face helper to %s", device)
|
||||
self.face_helper.face_det.to(device)
|
||||
self.face_helper.face_parse.to(device)
|
||||
|
||||
def get_device(self):
|
||||
raise NotImplementedError("get_device must be implemented by subclasses")
|
||||
|
||||
def load_net(self) -> torch.Module:
|
||||
raise NotImplementedError("load_net must be implemented by subclasses")
|
||||
|
||||
def restore_with_helper(
|
||||
self,
|
||||
np_image: np.ndarray,
|
||||
restore_face: Callable[[np.ndarray], np.ndarray],
|
||||
) -> np.ndarray:
|
||||
try:
|
||||
if self.net is None:
|
||||
self.net = self.load_net()
|
||||
except Exception:
|
||||
logger.warning("Unable to load face-restoration model", exc_info=True)
|
||||
return np_image
|
||||
|
||||
try:
|
||||
self.send_model_to(self.get_device())
|
||||
return restore_with_face_helper(np_image, self.face_helper, restore_face)
|
||||
finally:
|
||||
if shared.opts.face_restoration_unload:
|
||||
self.send_model_to(devices.cpu)
|
||||
|
||||
|
||||
def patch_facexlib(dirname: str) -> None:
|
||||
import facexlib.detection
|
||||
import facexlib.parsing
|
||||
|
||||
det_facex_load_file_from_url = facexlib.detection.load_file_from_url
|
||||
par_facex_load_file_from_url = facexlib.parsing.load_file_from_url
|
||||
|
||||
def update_kwargs(kwargs):
|
||||
return dict(kwargs, save_dir=dirname, model_dir=None)
|
||||
|
||||
def facex_load_file_from_url(**kwargs):
|
||||
return det_facex_load_file_from_url(**update_kwargs(kwargs))
|
||||
|
||||
def facex_load_file_from_url2(**kwargs):
|
||||
return par_facex_load_file_from_url(**update_kwargs(kwargs))
|
||||
|
||||
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
||||
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
|
@ -1,126 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import modules.face_restoration
|
||||
from modules import paths, shared, devices, modelloader, errors
|
||||
from modules import (
|
||||
devices,
|
||||
errors,
|
||||
face_restoration,
|
||||
face_restoration_utils,
|
||||
modelloader,
|
||||
shared,
|
||||
)
|
||||
|
||||
model_dir = "GFPGAN"
|
||||
user_path = None
|
||||
model_path = os.path.join(paths.models_path, model_dir)
|
||||
model_file_path = None
|
||||
logger = logging.getLogger(__name__)
|
||||
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
|
||||
have_gfpgan = False
|
||||
loaded_gfpgan_model = None
|
||||
model_download_name = "GFPGANv1.4.pth"
|
||||
gfpgan_face_restorer: face_restoration.FaceRestoration | None = None
|
||||
|
||||
|
||||
def gfpgann():
|
||||
global loaded_gfpgan_model
|
||||
global model_path
|
||||
global model_file_path
|
||||
if loaded_gfpgan_model is not None:
|
||||
loaded_gfpgan_model.gfpgan.to(devices.device_gfpgan)
|
||||
return loaded_gfpgan_model
|
||||
class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):
|
||||
def name(self):
|
||||
return "GFPGAN"
|
||||
|
||||
if gfpgan_constructor is None:
|
||||
return None
|
||||
def get_device(self):
|
||||
return devices.device_gfpgan
|
||||
|
||||
models = modelloader.load_models(model_path, model_url, user_path, ext_filter=['.pth'])
|
||||
def load_net(self) -> None:
|
||||
for model_path in modelloader.load_models(
|
||||
model_path=self.model_path,
|
||||
model_url=model_url,
|
||||
command_path=self.model_path,
|
||||
download_name=model_download_name,
|
||||
ext_filter=['.pth'],
|
||||
):
|
||||
if 'GFPGAN' in os.path.basename(model_path):
|
||||
net = modelloader.load_spandrel_model(
|
||||
model_path,
|
||||
device=self.get_device(),
|
||||
).model
|
||||
net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
|
||||
return net
|
||||
raise ValueError("No GFPGAN model found")
|
||||
|
||||
if len(models) == 1 and models[0].startswith("http"):
|
||||
model_file = models[0]
|
||||
elif len(models) != 0:
|
||||
gfp_models = []
|
||||
for item in models:
|
||||
if 'GFPGAN' in os.path.basename(item):
|
||||
gfp_models.append(item)
|
||||
latest_file = max(gfp_models, key=os.path.getctime)
|
||||
model_file = latest_file
|
||||
else:
|
||||
print("Unable to load gfpgan model!")
|
||||
return None
|
||||
def restore(self, np_image):
|
||||
def restore_face(cropped_face_t):
|
||||
assert self.net is not None
|
||||
return self.net(cropped_face_t, return_rgb=False)[0]
|
||||
|
||||
import facexlib.detection.retinaface
|
||||
|
||||
if hasattr(facexlib.detection.retinaface, 'device'):
|
||||
facexlib.detection.retinaface.device = devices.device_gfpgan
|
||||
model_file_path = model_file
|
||||
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=devices.device_gfpgan)
|
||||
loaded_gfpgan_model = model
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def send_model_to(model, device):
|
||||
model.gfpgan.to(device)
|
||||
model.face_helper.face_det.to(device)
|
||||
model.face_helper.face_parse.to(device)
|
||||
return self.restore_with_helper(np_image, restore_face)
|
||||
|
||||
|
||||
def gfpgan_fix_faces(np_image):
|
||||
model = gfpgann()
|
||||
if model is None:
|
||||
return np_image
|
||||
|
||||
send_model_to(model, devices.device_gfpgan)
|
||||
|
||||
np_image_bgr = np_image[:, :, ::-1]
|
||||
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
|
||||
np_image = gfpgan_output_bgr[:, :, ::-1]
|
||||
|
||||
model.face_helper.clean_all()
|
||||
|
||||
if shared.opts.face_restoration_unload:
|
||||
send_model_to(model, devices.cpu)
|
||||
|
||||
if gfpgan_face_restorer:
|
||||
return gfpgan_face_restorer.restore(np_image)
|
||||
logger.warning("GFPGAN face restorer not set up")
|
||||
return np_image
|
||||
|
||||
|
||||
gfpgan_constructor = None
|
||||
def setup_model(dirname: str) -> None:
|
||||
global gfpgan_face_restorer
|
||||
|
||||
|
||||
def setup_model(dirname):
|
||||
try:
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
import gfpgan
|
||||
import facexlib.detection
|
||||
import facexlib.parsing
|
||||
|
||||
global user_path
|
||||
global have_gfpgan
|
||||
global gfpgan_constructor
|
||||
global model_file_path
|
||||
|
||||
facexlib_path = model_path
|
||||
|
||||
if dirname is not None:
|
||||
facexlib_path = dirname
|
||||
|
||||
load_file_from_url_orig = gfpgan.utils.load_file_from_url
|
||||
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
|
||||
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
|
||||
|
||||
def my_load_file_from_url(**kwargs):
|
||||
return load_file_from_url_orig(**dict(kwargs, model_dir=model_file_path))
|
||||
|
||||
def facex_load_file_from_url(**kwargs):
|
||||
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
|
||||
|
||||
def facex_load_file_from_url2(**kwargs):
|
||||
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=facexlib_path, model_dir=None))
|
||||
|
||||
gfpgan.utils.load_file_from_url = my_load_file_from_url
|
||||
facexlib.detection.load_file_from_url = facex_load_file_from_url
|
||||
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
|
||||
user_path = dirname
|
||||
have_gfpgan = True
|
||||
gfpgan_constructor = gfpgan.GFPGANer
|
||||
|
||||
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
|
||||
def name(self):
|
||||
return "GFPGAN"
|
||||
|
||||
def restore(self, np_image):
|
||||
return gfpgan_fix_faces(np_image)
|
||||
|
||||
shared.face_restorers.append(FaceRestorerGFPGAN())
|
||||
face_restoration_utils.patch_facexlib(dirname)
|
||||
gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)
|
||||
shared.face_restorers.append(gfpgan_face_restorer)
|
||||
except Exception:
|
||||
errors.report("Error setting up GFPGAN", exc_info=True)
|
||||
|
@ -8,7 +8,6 @@ clean-fid
|
||||
einops
|
||||
facexlib
|
||||
fastapi>=0.90.1
|
||||
gfpgan
|
||||
gradio==3.41.2
|
||||
inflection
|
||||
jsonmerge
|
||||
|
@ -7,7 +7,6 @@ clean-fid==0.1.35
|
||||
einops==0.4.1
|
||||
facexlib==0.3.0
|
||||
fastapi==0.94.0
|
||||
gfpgan==1.3.8
|
||||
gradio==3.41.2
|
||||
httpcore==0.15
|
||||
inflection==0.5.1
|
||||
|
@ -1,10 +1,16 @@
|
||||
import base64
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import base64
|
||||
|
||||
|
||||
test_files_path = os.path.dirname(__file__) + "/test_files"
|
||||
test_outputs_path = os.path.dirname(__file__) + "/test_outputs"
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
# We don't want to fail on Py.test command line arguments being
|
||||
# parsed by webui:
|
||||
os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1")
|
||||
|
||||
|
||||
def file_to_base64(filename):
|
||||
@ -23,3 +29,8 @@ def img2img_basic_image_base64() -> str:
|
||||
@pytest.fixture(scope="session") # session so we don't read this over and over
|
||||
def mask_basic_image_base64() -> str:
|
||||
return file_to_base64(os.path.join(test_files_path, "mask_basic.png"))
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def initialize() -> None:
|
||||
import webui # noqa: F401
|
||||
|
29
test/test_face_restorers.py
Normal file
29
test/test_face_restorers.py
Normal file
@ -0,0 +1,29 @@
|
||||
import os
|
||||
from test.conftest import test_files_path, test_outputs_path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("initialize")
|
||||
@pytest.mark.parametrize("restorer_name", ["gfpgan", "codeformer"])
|
||||
def test_face_restorers(restorer_name):
|
||||
from modules import shared
|
||||
|
||||
if restorer_name == "gfpgan":
|
||||
from modules import gfpgan_model
|
||||
gfpgan_model.setup_model(shared.cmd_opts.gfpgan_models_path)
|
||||
restorer = gfpgan_model.gfpgan_fix_faces
|
||||
elif restorer_name == "codeformer":
|
||||
from modules import codeformer_model
|
||||
codeformer_model.setup_model(shared.cmd_opts.codeformer_models_path)
|
||||
restorer = codeformer_model.codeformer.restore
|
||||
else:
|
||||
raise NotImplementedError("...")
|
||||
img = Image.open(os.path.join(test_files_path, "two-faces.jpg"))
|
||||
np_img = np.array(img, dtype=np.uint8)
|
||||
fixed_image = restorer(np_img)
|
||||
assert fixed_image.shape == np_img.shape
|
||||
assert not np.allclose(fixed_image, np_img) # should have visibly changed
|
||||
Image.fromarray(fixed_image).save(os.path.join(test_outputs_path, f"{restorer_name}.png"))
|
BIN
test/test_files/two-faces.jpg
Normal file
BIN
test/test_files/two-faces.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 14 KiB |
0
test/test_outputs/.gitkeep
Normal file
0
test/test_outputs/.gitkeep
Normal file
Loading…
Reference in New Issue
Block a user