2022-12-10 14:14:30 +08:00
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
class TorchHijackForUnet:
    """
    A stand-in for the ``torch`` module whose ``cat`` resizes mismatched tensors.

    Every attribute other than ``cat`` is looked up on ``torch`` unchanged.
    ``cat`` resizes the first of two 4-D tensors to the size of the second
    tensor's last two dimensions when those differ. This makes it possible to
    create pictures with dimensions that are multiples of 8 rather than 64.
    """

    def __getattr__(self, item):
        """Forward lookups the instance cannot satisfy itself to ``torch``.

        Raises:
            AttributeError: if ``torch`` has no attribute named *item*.
        """
        # Normal attribute lookup already finds the ``cat`` method, so this
        # branch is only reached when ``__getattr__`` is called directly.
        if item == 'cat':
            return self.cat

        try:
            return getattr(torch, item)
        except AttributeError:
            raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item)) from None

    def cat(self, tensors, *args, **kwargs):
        """Concatenate *tensors* like ``torch.cat``, fixing a spatial size mismatch first.

        The first tensor is resized only when all of these hold:

        - exactly two tensors are given;
        - both are 4-D;
        - their last two dimensions differ;
        - the concatenation dimension is not one of those last two.

        The resize uses nearest-neighbour interpolation to match the second
        tensor's last two dimensions. Every other call goes to ``torch.cat``
        unchanged.

        Args:
            tensors: sequence of tensors, as accepted by ``torch.cat``.
            *args, **kwargs: forwarded to ``torch.cat`` (``dim``, ``out``, ...).

        Returns:
            The concatenated tensor produced by ``torch.cat``.
        """
        if len(tensors) == 2:
            a, b = tensors
            # interpolate() with a two-element size only accepts 4-D input, so
            # lower-rank tensors are passed through for torch.cat to handle.
            if a.ndim == 4 and b.ndim == 4 and a.shape[-2:] != b.shape[-2:]:
                dim = args[0] if args else kwargs.get("dim", kwargs.get("axis", 0))
                # Resizing along the concatenation axis would alter the result
                # instead of repairing a mismatch. Non-int dims (e.g. dimension
                # names) keep the previous always-resize behaviour.
                along_spatial_axis = isinstance(dim, int) and dim % 4 >= 2
                if not along_spatial_axis:
                    a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")

            tensors = (a, b)

        return torch.cat(tensors, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level instance of TorchHijackForUnet. Code that uses this object in
# place of the `torch` module gets the resizing `cat`, while every other
# attribute is still looked up on `torch`.
# NOTE(review): the call sites that swap this in are not visible here; confirm
# where it is installed.
th = TorchHijackForUnet()
|