mirror of
https://github.com/emadeldeen24/TSLANet.git
synced 2025-02-17 10:49:29 +08:00
Update TSLANet_Forecasting.py
This commit is contained in:
parent
5f33cb3895
commit
9d191631f9
@ -14,7 +14,7 @@ from timm.models.layers import trunc_normal_
|
||||
from torchmetrics.regression import MeanSquaredError, MeanAbsoluteError
|
||||
|
||||
from data_factory import data_provider
|
||||
from utils import save_copy_of_files, random_masking_3D, str2bool, random_masking
|
||||
from utils import save_copy_of_files, random_masking_3D, str2bool
|
||||
|
||||
|
||||
class ICB(L.LightningModule):
|
||||
@ -151,7 +151,6 @@ class TSLANet(nn.Module):
|
||||
|
||||
# Parameters/Embeddings
|
||||
self.out_layer = nn.Linear(args.emb_dim * num_patches, args.pred_len)
|
||||
# self.out_layer_mask = nn.Linear(args.emb_dim, self.patch_size)
|
||||
|
||||
def pretrain(self, x_in):
|
||||
x = rearrange(x_in, 'b l m -> b m l')
|
||||
@ -168,33 +167,6 @@ class TSLANet(nn.Module):
|
||||
return xb_mask, self.input_layer(x_patched)
|
||||
|
||||
|
||||
def pretrain_(self, x_in):
|
||||
B, L, M = x_in.shape
|
||||
x = rearrange(x_in, 'b l m -> b m l')
|
||||
x_patched = x.unfold(dimension=-1, size=self.patch_size, step=self.stride)
|
||||
# x_patched = rearrange(x_patched, 'b m n p -> (b m) n p')
|
||||
|
||||
x_masked, _, self.mask, _ = random_masking(x_patched, mask_ratio=args.mask_ratio)
|
||||
self.mask = self.mask.bool() # mask: [bs x num_patch]
|
||||
|
||||
x_masked = rearrange(x_masked, 'b m n p -> (b m) n p')
|
||||
x_patched = rearrange(x_patched, 'b m n p -> (b m) n p')
|
||||
|
||||
xT_masked = self.input_layer(x_masked)
|
||||
xT_patched = self.input_layer(x_patched)
|
||||
|
||||
for gf_blk in self.gf_blocks:
|
||||
xT_masked = gf_blk(xT_masked)
|
||||
|
||||
for gf_blk in self.gf_blocks:
|
||||
xT_patched = gf_blk(xT_patched)
|
||||
|
||||
xT_masked = rearrange(xT_masked, '(b m) n p -> b m n p', b=B)
|
||||
xT_patched = rearrange(xT_patched, '(b m) n p -> b m n p', b=B)
|
||||
|
||||
return xT_masked, xT_patched
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
B, L, M = x.shape
|
||||
|
||||
@ -398,7 +370,7 @@ if __name__ == '__main__':
|
||||
|
||||
# Data args...
|
||||
parser.add_argument('--data', type=str, default='ETTh1', help='dataset type')
|
||||
parser.add_argument('--root_path', type=str, default='C:/Emad/datasets/Forecasting/ETT-small',
|
||||
parser.add_argument('--root_path', type=str, default='data/ETT-small',
|
||||
help='root path of the data file')
|
||||
parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file')
|
||||
parser.add_argument('--embed', type=str, default='timeF',
|
||||
@ -411,7 +383,7 @@ if __name__ == '__main__':
|
||||
help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
|
||||
|
||||
# forecasting lengths
|
||||
parser.add_argument('--seq_len', type=int, default=64, help='input sequence length')
|
||||
parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
|
||||
parser.add_argument('--label_len', type=int, default=48, help='start token length')
|
||||
parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
|
||||
parser.add_argument('--seasonal_patterns', type=str, default='Monthly', help='subset for M4')
|
||||
|
Loading…
Reference in New Issue
Block a user