TSLANet/Forecasting/utils.py
Emadeldeen Eldele ee60bee26b
Update utils.py
2024-04-18 17:03:44 +08:00

69 lines
2.4 KiB
Python

import torch
import os
import shutil
import inspect
import argparse
def save_copy_of_files(checkpoint_callback):
    """Copy the calling script into the checkpoint callback's directory.

    The script is located through the ``__file__`` global of the frame that
    called this function, so the file that gets copied is whichever module
    invoked ``save_copy_of_files``.

    Args:
        checkpoint_callback: Object exposing a ``dirpath`` attribute that
            names the destination directory (the original comments describe
            it as the PyTorch Lightning saving directory).

    Raises:
        KeyError: If the calling frame's globals do not define ``__file__``.
    """
    # Resolve the invoking script from its module-level ``__file__``.
    invoking_globals = inspect.currentframe().f_back.f_globals
    script_path = os.path.abspath(invoking_globals["__file__"])

    # Create the target directory on demand, then place the copy inside it.
    target_dir = checkpoint_callback.dirpath
    os.makedirs(target_dir, exist_ok=True)
    shutil.copy(script_path, target_dir)
def str2bool(v):
    """Convert a command-line style token into a ``bool``.

    Args:
        v: Either an actual ``bool`` (returned unchanged) or an object with a
            ``lower()`` method, typically a string, compared
            case-insensitively against the accepted tokens.

    Returns:
        ``True`` for 'yes', 'true', 't', 'y', '1'; ``False`` for 'no',
        'false', 'f', 'n', '0'.

    Raises:
        argparse.ArgumentTypeError: If the token is not one of the above.
    """
    # Real booleans pass straight through without any parsing.
    if isinstance(v, bool):
        return v

    truthy_tokens = ('yes', 'true', 't', 'y', '1')
    falsy_tokens = ('no', 'false', 'f', 'n', '0')

    token = v.lower()
    if token in truthy_tokens:
        return True
    if token in falsy_tokens:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def random_masking_3D(xb, mask_ratio):
    """Randomly zero out a fraction of the patches of every sample.

    For each sample an independent random permutation of the patch positions
    is drawn. The first ``len_keep`` positions of that permutation are kept
    and the remaining ones are replaced with zeros.

    Args:
        xb: Tensor of shape ``[bs x num_patch x dim]``.
        mask_ratio: Fraction of patches to remove, in ``[0, 1]``. The number
            of kept patches is ``int(num_patch * (1 - mask_ratio))``, i.e.
            truncated toward zero.

    Returns:
        A tuple ``(x_masked, x_kept, mask, ids_restore)``:

        * ``x_masked`` -- ``[bs x num_patch x dim]``, same dtype as ``xb``,
          with removed patches set to zero.
        * ``x_kept`` -- ``[bs x len_keep x dim]``, the kept patches in
          shuffled order.
        * ``mask`` -- ``[bs x num_patch]`` with 0 for kept and 1 for removed
          patches.
        * ``ids_restore`` -- ``[bs x num_patch]`` indices that undo the
          shuffle.

    Raises:
        ValueError: If ``mask_ratio`` lies outside ``[0, 1]``.
    """
    # A ratio above 1 would give a negative len_keep and silently mask the
    # wrong number of patches through negative slicing, so reject it early.
    if not 0 <= mask_ratio <= 1:
        raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")

    bs, L, D = xb.shape
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(bs, L, device=xb.device)  # noise in [0, 1], bs x L

    # Sort noise for each sample; ascending: small is keep, large is remove.
    ids_shuffle = torch.argsort(noise, dim=1)
    # Inverse permutation, used to put patches back in their original order.
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # [bs x L]

    # Keep the first subset. gather returns a new tensor, so xb itself is
    # never modified and no defensive clone is needed.
    ids_keep = ids_shuffle[:, :len_keep]  # [bs x len_keep]
    x_kept = torch.gather(
        xb, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
    )  # [bs x len_keep x dim]

    # Zero padding for the removed patches. It carries xb's dtype so that the
    # concatenation below does not promote (or reject) non-float32 inputs.
    x_removed = torch.zeros(
        bs, L - len_keep, D, dtype=xb.dtype, device=xb.device
    )  # [bs x (L-len_keep) x dim]
    x_ = torch.cat([x_kept, x_removed], dim=1)  # [bs x L x dim]

    # Combine the kept part and the removed one in the original patch order.
    x_masked = torch.gather(
        x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, D)
    )  # [bs x num_patch x dim]

    # Generate the binary mask: 0 is keep, 1 is remove.
    mask = torch.ones([bs, L], device=xb.device)  # [bs x num_patch]
    mask[:, :len_keep] = 0
    # Unshuffle to get the binary mask aligned with the original positions.
    mask = torch.gather(mask, dim=1, index=ids_restore)  # [bs x num_patch]

    return x_masked, x_kept, mask, ids_restore