mirror of
https://github.com/KimMeen/Time-LLM.git
synced 2024-11-21 03:13:47 +08:00
Update tools.py
This commit is contained in:
parent
d3fa8694a5
commit
0a6ebb7a6e
@ -133,6 +133,7 @@ def cal_accuracy(y_pred, y_true):
|
||||
def del_files(dir_path):
|
||||
shutil.rmtree(dir_path)
|
||||
|
||||
|
||||
def vali(args, accelerator, model, vali_data, vali_loader, criterion, mae_metric):
|
||||
total_loss = []
|
||||
total_mae_loss = []
|
||||
@ -161,7 +162,9 @@ def vali(args, accelerator, model, vali_data, vali_loader, criterion, mae_metric
|
||||
outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
|
||||
else:
|
||||
outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
|
||||
# self.accelerator.wait_for_everyone()
|
||||
|
||||
outputs, batch_y = accelerator.gather_for_metrics((outputs, batch_y))
|
||||
|
||||
f_dim = -1 if args.features == 'MS' else 0
|
||||
outputs = outputs[:, -args.pred_len:, f_dim:]
|
||||
batch_y = batch_y[:, -args.pred_len:, f_dim:].to(accelerator.device)
|
||||
@ -205,11 +208,15 @@ def test(args, accelerator, model, train_loader, vali_loader, criterion):
|
||||
None
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
outputs = accelerator.gather_for_metrics(outputs)
|
||||
f_dim = -1 if args.features == 'MS' else 0
|
||||
outputs = outputs[:, -args.pred_len:, f_dim:]
|
||||
pred = outputs
|
||||
true = torch.from_numpy(np.array(y)).to(accelerator.device)
|
||||
batch_y_mark = torch.ones(true.shape).to(accelerator.device)
|
||||
true = accelerator.gather_for_metrics(true)
|
||||
batch_y_mark = accelerator.gather_for_metrics(batch_y_mark)
|
||||
|
||||
loss = criterion(x[:, :, 0], args.frequency_map, pred[:, :, 0], true, batch_y_mark)
|
||||
|
||||
model.train()
|
||||
|
Loading…
Reference in New Issue
Block a user