Update tools.py

This commit is contained in:
Ming Jin 2024-04-01 17:45:07 +11:00
parent d3fa8694a5
commit 0a6ebb7a6e

View File

@ -133,6 +133,7 @@ def cal_accuracy(y_pred, y_true):
def del_files(dir_path):
shutil.rmtree(dir_path)
def vali(args, accelerator, model, vali_data, vali_loader, criterion, mae_metric):
total_loss = []
total_mae_loss = []
@ -161,7 +162,9 @@ def vali(args, accelerator, model, vali_data, vali_loader, criterion, mae_metric
outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
else:
outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
# self.accelerator.wait_for_everyone()
outputs, batch_y = accelerator.gather_for_metrics((outputs, batch_y))
f_dim = -1 if args.features == 'MS' else 0
outputs = outputs[:, -args.pred_len:, f_dim:]
batch_y = batch_y[:, -args.pred_len:, f_dim:].to(accelerator.device)
@ -205,11 +208,15 @@ def test(args, accelerator, model, train_loader, vali_loader, criterion):
None
)
accelerator.wait_for_everyone()
outputs = accelerator.gather_for_metrics(outputs)
f_dim = -1 if args.features == 'MS' else 0
outputs = outputs[:, -args.pred_len:, f_dim:]
pred = outputs
true = torch.from_numpy(np.array(y)).to(accelerator.device)
batch_y_mark = torch.ones(true.shape).to(accelerator.device)
true = accelerator.gather_for_metrics(true)
batch_y_mark = accelerator.gather_for_metrics(batch_y_mark)
loss = criterion(x[:, :, 0], args.frequency_map, pred[:, :, 0], true, batch_y_mark)
model.train()