From f57845266fb238a8448f52cfed7119b3b0359cbf Mon Sep 17 00:00:00 2001 From: Emadeldeen Eldele <37911596+emadeldeen24@users.noreply.github.com> Date: Mon, 24 Jun 2024 11:59:06 +0800 Subject: [PATCH] Update TSLANet_classification.py --- Classification/TSLANet_classification.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/Classification/TSLANet_classification.py b/Classification/TSLANet_classification.py index 82b3120..b2f7684 100644 --- a/Classification/TSLANet_classification.py +++ b/Classification/TSLANet_classification.py @@ -67,7 +67,7 @@ class Adaptive_Spectral_Block(nn.Module): trunc_normal_(self.complex_weight_high, std=.02) trunc_normal_(self.complex_weight, std=.02) - self.threshold_param = nn.Parameter(torch.rand(1) * 0.5) + self.threshold_param = nn.Parameter(torch.rand(1)) # * 0.5) def create_adaptive_high_freq_mask(self, x_fft): B, _, _ = x_fft.shape @@ -84,12 +84,8 @@ class Adaptive_Spectral_Block(nn.Module): epsilon = 1e-6 # Small constant to avoid division by zero normalized_energy = energy / (median_energy + epsilon) - threshold = torch.quantile(normalized_energy, self.threshold_param) - dominant_frequencies = normalized_energy > threshold - - # Initialize adaptive mask - adaptive_mask = torch.zeros_like(x_fft, device=x_fft.device) - adaptive_mask[dominant_frequencies] = 1 + adaptive_mask = ((normalized_energy > self.threshold_param).float() - self.threshold_param).detach() + self.threshold_param + adaptive_mask = adaptive_mask.unsqueeze(-1) return adaptive_mask