From f00454cc16d6e4efb07dd08e09fe711e62ae54c5 Mon Sep 17 00:00:00 2001 From: Emadeldeen Eldele <37911596+emadeldeen24@users.noreply.github.com> Date: Fri, 19 Apr 2024 09:28:47 +0800 Subject: [PATCH] Update TSLANet_Forecasting.py --- Forecasting/TSLANet_Forecasting.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/Forecasting/TSLANet_Forecasting.py b/Forecasting/TSLANet_Forecasting.py index c6199ab..642ab48 100644 --- a/Forecasting/TSLANet_Forecasting.py +++ b/Forecasting/TSLANet_Forecasting.py @@ -108,25 +108,26 @@ class Adaptive_Spectral_Block(nn.Module): class TSLANet_layer(L.LightningModule): - def __init__(self, dim, mlp_ratio=3., drop=0., drop_path=0., norm_layer=nn.LayerNorm): super().__init__() self.norm1 = norm_layer(dim) - self.filter = Adaptive_Spectral_Block(dim) + self.asb = Adaptive_Spectral_Block(dim) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.icb = ICB(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop) def forward(self, x): - if not args.ICB: + # Check if both ASB and ICB are true + if args.ICB and args.ASB: + x = x + self.drop_path(self.icb(self.norm2(self.asb(self.norm1(x))))) + # If only ICB is true + elif args.ICB: x = x + self.drop_path(self.icb(self.norm2(x))) - return x - if not args.ASB: - x = x + self.drop_path(self.filter(self.norm1(x))) - return x - - x = x + self.drop_path(self.icb(self.norm2(self.filter(self.norm1(x))))) + # If only ASB is true + elif args.ASB: + x = x + self.drop_path(self.asb(self.norm1(x))) + # If neither is true, just pass x through return x