mirror of
https://github.com/KimMeen/Time-LLM.git
synced 2024-11-21 03:13:47 +08:00
Update tools.py
This commit is contained in:
parent
d3fa8694a5
commit
0a6ebb7a6e
@ -133,6 +133,7 @@ def cal_accuracy(y_pred, y_true):
|
|||||||
def del_files(dir_path):
|
def del_files(dir_path):
|
||||||
shutil.rmtree(dir_path)
|
shutil.rmtree(dir_path)
|
||||||
|
|
||||||
|
|
||||||
def vali(args, accelerator, model, vali_data, vali_loader, criterion, mae_metric):
|
def vali(args, accelerator, model, vali_data, vali_loader, criterion, mae_metric):
|
||||||
total_loss = []
|
total_loss = []
|
||||||
total_mae_loss = []
|
total_mae_loss = []
|
||||||
@ -161,7 +162,9 @@ def vali(args, accelerator, model, vali_data, vali_loader, criterion, mae_metric
|
|||||||
outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
|
outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
|
||||||
else:
|
else:
|
||||||
outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
|
outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
|
||||||
# self.accelerator.wait_for_everyone()
|
|
||||||
|
outputs, batch_y = accelerator.gather_for_metrics((outputs, batch_y))
|
||||||
|
|
||||||
f_dim = -1 if args.features == 'MS' else 0
|
f_dim = -1 if args.features == 'MS' else 0
|
||||||
outputs = outputs[:, -args.pred_len:, f_dim:]
|
outputs = outputs[:, -args.pred_len:, f_dim:]
|
||||||
batch_y = batch_y[:, -args.pred_len:, f_dim:].to(accelerator.device)
|
batch_y = batch_y[:, -args.pred_len:, f_dim:].to(accelerator.device)
|
||||||
@ -205,11 +208,15 @@ def test(args, accelerator, model, train_loader, vali_loader, criterion):
|
|||||||
None
|
None
|
||||||
)
|
)
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
outputs = accelerator.gather_for_metrics(outputs)
|
||||||
f_dim = -1 if args.features == 'MS' else 0
|
f_dim = -1 if args.features == 'MS' else 0
|
||||||
outputs = outputs[:, -args.pred_len:, f_dim:]
|
outputs = outputs[:, -args.pred_len:, f_dim:]
|
||||||
pred = outputs
|
pred = outputs
|
||||||
true = torch.from_numpy(np.array(y)).to(accelerator.device)
|
true = torch.from_numpy(np.array(y)).to(accelerator.device)
|
||||||
batch_y_mark = torch.ones(true.shape).to(accelerator.device)
|
batch_y_mark = torch.ones(true.shape).to(accelerator.device)
|
||||||
|
true = accelerator.gather_for_metrics(true)
|
||||||
|
batch_y_mark = accelerator.gather_for_metrics(batch_y_mark)
|
||||||
|
|
||||||
loss = criterion(x[:, :, 0], args.frequency_map, pred[:, :, 0], true, batch_y_mark)
|
loss = criterion(x[:, :, 0], args.frequency_map, pred[:, :, 0], true, batch_y_mark)
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
|
Loading…
Reference in New Issue
Block a user