From 0a6ebb7a6e8860d28ebf6b86bf23199cf5d003f1 Mon Sep 17 00:00:00 2001 From: Ming Jin Date: Mon, 1 Apr 2024 17:45:07 +1100 Subject: [PATCH] Update tools.py --- utils/tools.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/utils/tools.py b/utils/tools.py index fe98a72..b9ee19b 100644 --- a/utils/tools.py +++ b/utils/tools.py @@ -133,6 +133,7 @@ def cal_accuracy(y_pred, y_true): def del_files(dir_path): shutil.rmtree(dir_path) + def vali(args, accelerator, model, vali_data, vali_loader, criterion, mae_metric): total_loss = [] total_mae_loss = [] @@ -161,7 +162,9 @@ def vali(args, accelerator, model, vali_data, vali_loader, criterion, mae_metric outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0] else: outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark) - # self.accelerator.wait_for_everyone() + + outputs, batch_y = accelerator.gather_for_metrics((outputs, batch_y)) + f_dim = -1 if args.features == 'MS' else 0 outputs = outputs[:, -args.pred_len:, f_dim:] batch_y = batch_y[:, -args.pred_len:, f_dim:].to(accelerator.device) @@ -205,11 +208,15 @@ def test(args, accelerator, model, train_loader, vali_loader, criterion): None ) accelerator.wait_for_everyone() + outputs = accelerator.gather_for_metrics(outputs) f_dim = -1 if args.features == 'MS' else 0 outputs = outputs[:, -args.pred_len:, f_dim:] pred = outputs true = torch.from_numpy(np.array(y)).to(accelerator.device) batch_y_mark = torch.ones(true.shape).to(accelerator.device) + true = accelerator.gather_for_metrics(true) + batch_y_mark = accelerator.gather_for_metrics(batch_y_mark) + loss = criterion(x[:, :, 0], args.frequency_map, pred[:, :, 0], true, batch_y_mark) model.train()