Update TSLANet_classification.py

This commit is contained in:
Emadeldeen Eldele 2024-04-19 09:28:21 +08:00 committed by GitHub
parent ee60bee26b
commit cc528c563d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -161,21 +161,23 @@ class TSLANet_layer(L.LightningModule):
def __init__(self, dim, mlp_ratio=3., drop=0., drop_path=0., norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.filter = Adaptive_Spectral_Block(dim)
self.asb = Adaptive_Spectral_Block(dim)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.icb = ICB(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
def forward(self, x):
if not args.ICB:
# Check if both ASB and ICB are true
if args.ICB and args.ASB:
x = x + self.drop_path(self.icb(self.norm2(self.asb(self.norm1(x)))))
# If only ICB is true
elif args.ICB:
x = x + self.drop_path(self.icb(self.norm2(x)))
return x
if not args.ASB:
x = x + self.drop_path(self.filter(self.norm1(x)))
return x
x = x + self.drop_path(self.icb(self.norm2(self.filter(self.norm1(x)))))
# If only ASB is true
elif args.ASB:
x = x + self.drop_path(self.asb(self.norm1(x)))
# If neither is true, just pass x through
return x