mirror of
https://github.com/emadeldeen24/TSLANet.git
synced 2024-11-27 00:50:10 +08:00
Update TSLANet_Forecasting.py
This commit is contained in:
parent
f57845266f
commit
562e3f2d38
@ -52,7 +52,7 @@ class Adaptive_Spectral_Block(nn.Module):
|
||||
|
||||
trunc_normal_(self.complex_weight_high, std=.02)
|
||||
trunc_normal_(self.complex_weight, std=.02)
|
||||
self.threshold_param = nn.Parameter(torch.rand(1) * 0.5)
|
||||
self.threshold_param = nn.Parameter(torch.rand(1)) # * 0.5)
|
||||
|
||||
def create_adaptive_high_freq_mask(self, x_fft):
|
||||
B, _, _ = x_fft.shape
|
||||
@ -68,12 +68,8 @@ class Adaptive_Spectral_Block(nn.Module):
|
||||
# Normalize energy
|
||||
normalized_energy = energy / (median_energy + 1e-6)
|
||||
|
||||
threshold = torch.quantile(normalized_energy, self.threshold_param)
|
||||
dominant_frequencies = normalized_energy > threshold
|
||||
|
||||
# Initialize adaptive mask
|
||||
adaptive_mask = torch.zeros_like(x_fft, device=x_fft.device)
|
||||
adaptive_mask[dominant_frequencies] = 1
|
||||
adaptive_mask = ((normalized_energy > self.threshold_param).float() - self.threshold_param).detach() + self.threshold_param
|
||||
adaptive_mask = adaptive_mask.unsqueeze(-1)
|
||||
|
||||
return adaptive_mask
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user