import torch
import os
import shutil
import inspect
import argparse


def save_copy_of_files(checkpoint_callback):
    # Get the frame of the caller of this function
    caller_frame = inspect.currentframe().f_back

    # Get the filename of the caller
    caller_filename = caller_frame.f_globals["__file__"]

    # Get the absolute path of the caller script
    caller_script_path = os.path.abspath(caller_filename)

    # Destination directory (PyTorch Lightning saving directory)
    destination_directory = checkpoint_callback.dirpath

    # Ensure the destination directory exists
    os.makedirs(destination_directory, exist_ok=True)

    # Copy the caller script to the destination directory
    shutil.copy(caller_script_path, destination_directory)
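

# Usage sketch (illustrative, not part of the original file; the
# "checkpoints/run1" path is hypothetical): pass the PyTorch Lightning
# ModelCheckpoint callback so the running training script is archived
# alongside its checkpoints for reproducibility.
#
#   from pytorch_lightning.callbacks import ModelCheckpoint
#   checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/run1")
#   save_copy_of_files(checkpoint_callback)  # copies the calling script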


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
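

# Usage sketch (illustrative; the "--use_gpu" flag is hypothetical): str2bool
# serves as an argparse `type=` converter, so flags accept flexible boolean
# spellings such as "yes", "t", or "0" instead of only store_true/store_false.
#
#   parser = argparse.ArgumentParser()
#   parser.add_argument("--use_gpu", type=str2bool, default=True)
#   args = parser.parse_args(["--use_gpu", "yes"])  # args.use_gpu == True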


def random_masking_3D(xb, mask_ratio):
    # xb: [bs x num_patch x dim]
    bs, L, D = xb.shape
    x = xb.clone()

    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(bs, L, device=xb.device)  # noise in [0, 1], bs x L

    # sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # ids_restore: [bs x L]

    # keep the first subset
    ids_keep = ids_shuffle[:, :len_keep]  # ids_keep: [bs x len_keep]
    x_kept = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))  # x_kept: [bs x len_keep x dim]

    # removed x
    x_removed = torch.zeros(bs, L - len_keep, D, device=xb.device)  # x_removed: [bs x (L-len_keep) x dim]
    x_ = torch.cat([x_kept, x_removed], dim=1)  # x_: [bs x L x dim]

    # combine the kept part and the removed one
    x_masked = torch.gather(x_, dim=1,
                            index=ids_restore.unsqueeze(-1).repeat(1, 1, D))  # x_masked: [bs x num_patch x dim]

    # generate the binary mask: 0 is keep, 1 is remove
    mask = torch.ones([bs, L], device=x.device)  # mask: [bs x num_patch]
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)  # [bs x num_patch]

    return x_masked, x_kept, mask, ids_restore
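

# Minimal sanity check (illustrative, not part of the original file): mask 40%
# of 16 patches and verify the returned shapes and that removed positions are
# zeroed out in x_masked.
if __name__ == "__main__":
    xb = torch.randn(2, 16, 8)  # [bs x num_patch x dim]
    x_masked, x_kept, mask, ids_restore = random_masking_3D(xb, mask_ratio=0.4)
    assert x_masked.shape == (2, 16, 8)
    assert x_kept.shape == (2, 9, 8)  # len_keep = int(16 * 0.6) = 9
    assert mask.shape == (2, 16) and ids_restore.shape == (2, 16)
    assert torch.all(x_masked[mask.bool()] == 0)  # 1 in mask marks removed patches
    print("random_masking_3D sanity check passed")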