mirror of
https://github.com/KimMeen/Time-LLM.git
synced 2024-12-15 08:50:00 +08:00
61 lines
2.3 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class Inception_Block_V1(nn.Module):
    """Inception-style block of parallel square 2-D convolutions.

    Branch ``k`` (for ``k`` in ``0 .. num_kernels - 1``) is an ``nn.Conv2d`` with a
    ``(2k + 1) x (2k + 1)`` kernel and padding ``k``. With Conv2d's default stride
    of 1 this keeps the input's spatial size, so all branch outputs can be stacked
    and averaged element-wise.

    Args:
        in_channels: number of input channels seen by every branch.
        out_channels: number of output channels produced by every branch.
        num_kernels: number of parallel branches.
        init_weight: if True, re-initialize conv weights with Kaiming-normal
            (``mode='fan_out'``, ``nonlinearity='relu'``) and zero the biases.
    """

    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        # kernel size 2k+1 with padding k preserves height and width
        branches = [
            nn.Conv2d(in_channels, out_channels, kernel_size=2 * k + 1, padding=k)
            for k in range(self.num_kernels)
        ]
        self.kernels = nn.ModuleList(branches)
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for every Conv2d weight; biases set to zero."""
        convs = (mod for mod in self.modules() if isinstance(mod, nn.Conv2d))
        for conv in convs:
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
            if conv.bias is not None:
                nn.init.constant_(conv.bias, 0)

    def forward(self, x):
        """Run all ``num_kernels`` branches on ``x`` and return their mean.

        ``x`` is whatever ``nn.Conv2d`` accepts with ``in_channels`` channels
        (e.g. ``(N, in_channels, H, W)``); the result has ``out_channels``
        channels and the same spatial size.
        """
        outputs = [self.kernels[k](x) for k in range(self.num_kernels)]
        stacked = torch.stack(outputs, dim=-1)
        return stacked.mean(-1)
|
|
|
|
|
|
class Inception_Block_V2(nn.Module):
    """Inception-style block of parallel 1xk / kx1 2-D convolutions.

    For each ``i`` in ``range(num_kernels // 2)`` two branches are built:

    * a ``1 x (2i + 3)`` conv with padding ``(0, i + 1)``, and
    * a ``(2i + 3) x 1`` conv with padding ``(i + 1, 0)``.

    A final ``1 x 1`` branch is appended, so the block holds
    ``2 * (num_kernels // 2) + 1`` branches. With Conv2d's default stride of 1
    every branch keeps the input's spatial size, and ``forward`` returns the
    element-wise mean of all branch outputs.

    Args:
        in_channels: number of input channels seen by every branch.
        out_channels: number of output channels produced by every branch.
        num_kernels: controls the number of 1xk / kx1 pairs (``num_kernels // 2``).
            For an odd value the remainder does not add a branch.
        init_weight: if True, re-initialize conv weights with Kaiming-normal
            (``mode='fan_out'``, ``nonlinearity='relu'``) and zero the biases.
    """

    def __init__(self, in_channels, out_channels, num_kernels=6, init_weight=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_kernels = num_kernels
        kernels = []
        for i in range(self.num_kernels // 2):
            # size 2i+3 with padding i+1 along one axis preserves that dimension
            size, pad = 2 * i + 3, i + 1
            kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[1, size], padding=[0, pad]))
            kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=[size, 1], padding=[pad, 0]))
        kernels.append(nn.Conv2d(in_channels, out_channels, kernel_size=1))
        self.kernels = nn.ModuleList(kernels)
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for every Conv2d weight; biases set to zero."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run every branch on ``x`` and return their element-wise mean.

        ``x`` is whatever ``nn.Conv2d`` accepts with ``in_channels`` channels
        (e.g. ``(N, in_channels, H, W)``); the result has ``out_channels``
        channels and the same spatial size.
        """
        # Iterate the actual branches: there are 2 * (num_kernels // 2) + 1 of
        # them, which is fewer than num_kernels + 1 when num_kernels is odd
        # (indexing range(num_kernels + 1) raised IndexError in that case).
        res_list = [conv(x) for conv in self.kernels]
        return torch.stack(res_list, dim=-1).mean(-1)
|