mirror of
https://github.com/emadeldeen24/TSLANet.git
synced 2024-11-27 00:50:10 +08:00
Update TSLANet_classification.py
This commit is contained in:
parent
b6d2f886b6
commit
f57845266f
@ -67,7 +67,7 @@ class Adaptive_Spectral_Block(nn.Module):
|
||||
|
||||
trunc_normal_(self.complex_weight_high, std=.02)
|
||||
trunc_normal_(self.complex_weight, std=.02)
|
||||
self.threshold_param = nn.Parameter(torch.rand(1) * 0.5)
|
||||
self.threshold_param = nn.Parameter(torch.rand(1)) # * 0.5)
|
||||
|
||||
def create_adaptive_high_freq_mask(self, x_fft):
|
||||
B, _, _ = x_fft.shape
|
||||
@ -84,12 +84,8 @@ class Adaptive_Spectral_Block(nn.Module):
|
||||
epsilon = 1e-6 # Small constant to avoid division by zero
|
||||
normalized_energy = energy / (median_energy + epsilon)
|
||||
|
||||
threshold = torch.quantile(normalized_energy, self.threshold_param)
|
||||
dominant_frequencies = normalized_energy > threshold
|
||||
|
||||
# Initialize adaptive mask
|
||||
adaptive_mask = torch.zeros_like(x_fft, device=x_fft.device)
|
||||
adaptive_mask[dominant_frequencies] = 1
|
||||
adaptive_mask = ((normalized_energy > self.threshold_param).float() - self.threshold_param).detach() + self.threshold_param
|
||||
adaptive_mask = adaptive_mask.unsqueeze(-1)
|
||||
|
||||
return adaptive_mask
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user