change formatting to match the main program in devices.py

This commit is contained in:
AUTOMATIC 2022-11-12 10:00:49 +03:00
parent c62d17aee3
commit 0ab0a50f9a

View File

@ -3,23 +3,27 @@ import contextlib
import torch import torch
from modules import errors from modules import errors
# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
# check `getattr` and try it for compatibility # check `getattr` and try it for compatibility
def has_mps() -> bool: def has_mps() -> bool:
if not getattr(torch, 'has_mps', False): return False if not getattr(torch, 'has_mps', False):
return False
try: try:
torch.zeros(1).to(torch.device("mps")) torch.zeros(1).to(torch.device("mps"))
return True return True
except Exception: except Exception:
return False return False
cpu = torch.device("cpu")
def extract_device_id(args, name): def extract_device_id(args, name):
for x in range(len(args)): for x in range(len(args)):
if name in args[x]: return args[x+1] if name in args[x]:
return args[x + 1]
return None return None
def get_optimal_device(): def get_optimal_device():
if torch.cuda.is_available(): if torch.cuda.is_available():
from modules import shared from modules import shared
@ -52,10 +56,12 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32") errors.run(enable_tf32, "Enabling TF32")
cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
dtype = torch.float16 dtype = torch.float16
dtype_vae = torch.float16 dtype_vae = torch.float16
def randn(seed, shape): def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
if device.type == 'mps': if device.type == 'mps':
@ -89,6 +95,11 @@ def autocast(disable=False):
return torch.autocast("cuda") return torch.autocast("cuda")
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor def mps_contiguous(input_tensor, device):
def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device) return input_tensor.contiguous() if device.type == 'mps' else input_tensor
def mps_contiguous_to(input_tensor, device):
return mps_contiguous(input_tensor, device).to(device)