mirror of
https://github.com/emadeldeen24/TSLANet.git
synced 2024-11-21 00:40:52 +08:00
Update TSLANet_classification.py
This commit is contained in:
parent
ee60bee26b
commit
cc528c563d
@ -161,21 +161,23 @@ class TSLANet_layer(L.LightningModule):
|
||||
def __init__(self, dim, mlp_ratio=3., drop=0., drop_path=0., norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.filter = Adaptive_Spectral_Block(dim)
|
||||
self.asb = Adaptive_Spectral_Block(dim)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.icb = ICB(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)
|
||||
|
||||
def forward(self, x):
|
||||
if not args.ICB:
|
||||
# Check if both ASB and ICB are true
|
||||
if args.ICB and args.ASB:
|
||||
x = x + self.drop_path(self.icb(self.norm2(self.asb(self.norm1(x)))))
|
||||
# If only ICB is true
|
||||
elif args.ICB:
|
||||
x = x + self.drop_path(self.icb(self.norm2(x)))
|
||||
return x
|
||||
if not args.ASB:
|
||||
x = x + self.drop_path(self.filter(self.norm1(x)))
|
||||
return x
|
||||
|
||||
x = x + self.drop_path(self.icb(self.norm2(self.filter(self.norm1(x)))))
|
||||
# If only ASB is true
|
||||
elif args.ASB:
|
||||
x = x + self.drop_path(self.asb(self.norm1(x)))
|
||||
# If neither is true, just pass x through
|
||||
return x
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user