Update TSLANet_Forecasting.py

This commit is contained in:
Emadeldeen Eldele 2024-04-18 16:36:04 +08:00 committed by GitHub
parent 5f33cb3895
commit 9d191631f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -14,7 +14,7 @@ from timm.models.layers import trunc_normal_
from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
from data_factory import data_provider
from utils import save_copy_of_files, random_masking_3D, str2bool, random_masking
from utils import save_copy_of_files, random_masking_3D, str2bool
class ICB(L.LightningModule):
@ -151,7 +151,6 @@ class TSLANet(nn.Module):
# Parameters/Embeddings
self.out_layer = nn.Linear(args.emb_dim * num_patches, args.pred_len)
# self.out_layer_mask = nn.Linear(args.emb_dim, self.patch_size)
def pretrain(self, x_in):
x = rearrange(x_in, 'b l m -> b m l')
@ -168,33 +167,6 @@ class TSLANet(nn.Module):
return xb_mask, self.input_layer(x_patched)
def pretrain_(self, x_in):
B, L, M = x_in.shape
x = rearrange(x_in, 'b l m -> b m l')
x_patched = x.unfold(dimension=-1, size=self.patch_size, step=self.stride)
# x_patched = rearrange(x_patched, 'b m n p -> (b m) n p')
x_masked, _, self.mask, _ = random_masking(x_patched, mask_ratio=args.mask_ratio)
self.mask = self.mask.bool() # mask: [bs x num_patch]
x_masked = rearrange(x_masked, 'b m n p -> (b m) n p')
x_patched = rearrange(x_patched, 'b m n p -> (b m) n p')
xT_masked = self.input_layer(x_masked)
xT_patched = self.input_layer(x_patched)
for gf_blk in self.gf_blocks:
xT_masked = gf_blk(xT_masked)
for gf_blk in self.gf_blocks:
xT_patched = gf_blk(xT_patched)
xT_masked = rearrange(xT_masked, '(b m) n p -> b m n p', b=B)
xT_patched = rearrange(xT_patched, '(b m) n p -> b m n p', b=B)
return xT_masked, xT_patched
def forward(self, x):
B, L, M = x.shape
@ -398,7 +370,7 @@ if __name__ == '__main__':
# Data args...
parser.add_argument('--data', type=str, default='ETTh1', help='dataset type')
parser.add_argument('--root_path', type=str, default='C:/Emad/datasets/Forecasting/ETT-small',
parser.add_argument('--root_path', type=str, default='data/ETT-small',
help='root path of the data file')
parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
parser.add_argument('--embed', type=str, default='timeF',
@ -411,7 +383,7 @@ if __name__ == '__main__':
help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
# forecasting lengths
parser.add_argument('--seq_len', type=int, default=64, help='input sequence length')
parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
parser.add_argument('--label_len', type=int, default=48, help='start token length')
parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
parser.add_argument('--seasonal_patterns', type=str, default='Monthly', help='subset for M4')