Update TSLANet_classification.py

This commit is contained in:
Emadeldeen Eldele 2024-06-24 11:59:06 +08:00 committed by GitHub
parent b6d2f886b6
commit f57845266f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -67,7 +67,7 @@ class Adaptive_Spectral_Block(nn.Module):
trunc_normal_(self.complex_weight_high, std=.02)
trunc_normal_(self.complex_weight, std=.02)
self.threshold_param = nn.Parameter(torch.rand(1) * 0.5)
self.threshold_param = nn.Parameter(torch.rand(1)) # * 0.5)
def create_adaptive_high_freq_mask(self, x_fft):
B, _, _ = x_fft.shape
@ -84,12 +84,8 @@ class Adaptive_Spectral_Block(nn.Module):
epsilon = 1e-6 # Small constant to avoid division by zero
normalized_energy = energy / (median_energy + epsilon)
threshold = torch.quantile(normalized_energy, self.threshold_param)
dominant_frequencies = normalized_energy > threshold
# Initialize adaptive mask
adaptive_mask = torch.zeros_like(x_fft, device=x_fft.device)
adaptive_mask[dominant_frequencies] = 1
adaptive_mask = ((normalized_energy > self.threshold_param).float() - self.threshold_param).detach() + self.threshold_param
adaptive_mask = adaptive_mask.unsqueeze(-1)
return adaptive_mask