Update TSLANet_Forecasting.py

This commit is contained in:
Emadeldeen Eldele 2024-06-24 12:00:02 +08:00 committed by GitHub
parent f57845266f
commit 562e3f2d38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -52,7 +52,7 @@ class Adaptive_Spectral_Block(nn.Module):
trunc_normal_(self.complex_weight_high, std=.02)
trunc_normal_(self.complex_weight, std=.02)
self.threshold_param = nn.Parameter(torch.rand(1) * 0.5)
self.threshold_param = nn.Parameter(torch.rand(1)) # * 0.5)
def create_adaptive_high_freq_mask(self, x_fft):
B, _, _ = x_fft.shape
@ -68,12 +68,8 @@ class Adaptive_Spectral_Block(nn.Module):
# Normalize energy
normalized_energy = energy / (median_energy + 1e-6)
threshold = torch.quantile(normalized_energy, self.threshold_param)
dominant_frequencies = normalized_energy > threshold
# Initialize adaptive mask
adaptive_mask = torch.zeros_like(x_fft, device=x_fft.device)
adaptive_mask[dominant_frequencies] = 1
adaptive_mask = ((normalized_energy > self.threshold_param).float() - self.threshold_param).detach() + self.threshold_param
adaptive_mask = adaptive_mask.unsqueeze(-1)
return adaptive_mask