mirror of
https://github.com/KimMeen/Time-LLM.git
synced 2024-11-27 07:49:53 +08:00
commit=Add GPT2 and BERT to Time-LLM as a general framework & update the prompt_bank (Weather, ECL, Traffic)
This commit is contained in:
parent
980a9024ff
commit
e022670182
@ -6,7 +6,9 @@ data_dict = {
|
|||||||
'ETTh2': Dataset_ETT_hour,
|
'ETTh2': Dataset_ETT_hour,
|
||||||
'ETTm1': Dataset_ETT_minute,
|
'ETTm1': Dataset_ETT_minute,
|
||||||
'ETTm2': Dataset_ETT_minute,
|
'ETTm2': Dataset_ETT_minute,
|
||||||
'custom': Dataset_Custom,
|
'ECL': Dataset_Custom,
|
||||||
|
'Traffic': Dataset_Custom,
|
||||||
|
'Weather': Dataset_Custom,
|
||||||
'm4': Dataset_M4,
|
'm4': Dataset_M4,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,11 +3,10 @@ from math import sqrt
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from transformers import LlamaConfig, LlamaModel, LlamaTokenizer
|
from transformers import LlamaConfig, LlamaModel, LlamaTokenizer, GPT2Config, GPT2Model, GPT2Tokenizer, BertConfig, \
|
||||||
|
BertModel, BertTokenizer
|
||||||
from layers.Embed import PatchEmbedding
|
from layers.Embed import PatchEmbedding
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
from layers.StandardNorm import Normalize
|
from layers.StandardNorm import Normalize
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
@ -37,17 +36,18 @@ class Model(nn.Module):
|
|||||||
self.seq_len = configs.seq_len
|
self.seq_len = configs.seq_len
|
||||||
self.d_ff = configs.d_ff
|
self.d_ff = configs.d_ff
|
||||||
self.top_k = 5
|
self.top_k = 5
|
||||||
self.d_llm = 4096
|
self.d_llm = configs.llm_dim
|
||||||
self.patch_len = configs.patch_len
|
self.patch_len = configs.patch_len
|
||||||
self.stride = configs.stride
|
self.stride = configs.stride
|
||||||
|
|
||||||
self.llama_config = LlamaConfig.from_pretrained('/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/')
|
if configs.llm_model == 'LLAMA':
|
||||||
# self.llama_config = LlamaConfig.from_pretrained('huggyllama/llama-7b')
|
# self.llama_config = LlamaConfig.from_pretrained('/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/')
|
||||||
|
self.llama_config = LlamaConfig.from_pretrained('huggyllama/llama-7b')
|
||||||
self.llama_config.num_hidden_layers = configs.llm_layers
|
self.llama_config.num_hidden_layers = configs.llm_layers
|
||||||
self.llama_config.output_attentions = True
|
self.llama_config.output_attentions = True
|
||||||
self.llama_config.output_hidden_states = True
|
self.llama_config.output_hidden_states = True
|
||||||
try:
|
try:
|
||||||
self.llama = LlamaModel.from_pretrained(
|
self.llm_model = LlamaModel.from_pretrained(
|
||||||
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
|
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
|
||||||
'huggyllama/llama-7b',
|
'huggyllama/llama-7b',
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
@ -57,7 +57,7 @@ class Model(nn.Module):
|
|||||||
)
|
)
|
||||||
except EnvironmentError: # downloads model from HF is not already done
|
except EnvironmentError: # downloads model from HF is not already done
|
||||||
print("Local model files not found. Attempting to download...")
|
print("Local model files not found. Attempting to download...")
|
||||||
self.llama = LlamaModel.from_pretrained(
|
self.llm_model = LlamaModel.from_pretrained(
|
||||||
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
|
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
|
||||||
'huggyllama/llama-7b',
|
'huggyllama/llama-7b',
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
@ -80,6 +80,78 @@ class Model(nn.Module):
|
|||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
local_files_only=False
|
local_files_only=False
|
||||||
)
|
)
|
||||||
|
elif configs.llm_model == 'GPT2':
|
||||||
|
self.gpt2_config = GPT2Config.from_pretrained('openai-community/gpt2')
|
||||||
|
|
||||||
|
self.gpt2_config.num_hidden_layers = configs.llm_layers
|
||||||
|
self.gpt2_config.output_attentions = True
|
||||||
|
self.gpt2_config.output_hidden_states = True
|
||||||
|
try:
|
||||||
|
self.llm_model = GPT2Model.from_pretrained(
|
||||||
|
'openai-community/gpt2',
|
||||||
|
trust_remote_code=True,
|
||||||
|
local_files_only=True,
|
||||||
|
config=self.gpt2_config,
|
||||||
|
)
|
||||||
|
except EnvironmentError: # downloads model from HF is not already done
|
||||||
|
print("Local model files not found. Attempting to download...")
|
||||||
|
self.llm_model = GPT2Model.from_pretrained(
|
||||||
|
'openai-community/gpt2',
|
||||||
|
trust_remote_code=True,
|
||||||
|
local_files_only=False,
|
||||||
|
config=self.gpt2_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.tokenizer = GPT2Tokenizer.from_pretrained(
|
||||||
|
'openai-community/gpt2',
|
||||||
|
trust_remote_code=True,
|
||||||
|
local_files_only=True
|
||||||
|
)
|
||||||
|
except EnvironmentError: # downloads the tokenizer from HF if not already done
|
||||||
|
print("Local tokenizer files not found. Atempting to download them..")
|
||||||
|
self.tokenizer = GPT2Tokenizer.from_pretrained(
|
||||||
|
'openai-community/gpt2',
|
||||||
|
trust_remote_code=True,
|
||||||
|
local_files_only=False
|
||||||
|
)
|
||||||
|
elif configs.llm_model == 'BERT':
|
||||||
|
self.bert_config = BertConfig.from_pretrained('google-bert/bert-base-uncased')
|
||||||
|
|
||||||
|
self.bert_config.num_hidden_layers = configs.llm_layers
|
||||||
|
self.bert_config.output_attentions = True
|
||||||
|
self.bert_config.output_hidden_states = True
|
||||||
|
try:
|
||||||
|
self.llm_model = BertModel.from_pretrained(
|
||||||
|
'google-bert/bert-base-uncased',
|
||||||
|
trust_remote_code=True,
|
||||||
|
local_files_only=True,
|
||||||
|
config=self.bert_config,
|
||||||
|
)
|
||||||
|
except EnvironmentError: # downloads model from HF is not already done
|
||||||
|
print("Local model files not found. Attempting to download...")
|
||||||
|
self.llm_model = BertModel.from_pretrained(
|
||||||
|
'google-bert/bert-base-uncased',
|
||||||
|
trust_remote_code=True,
|
||||||
|
local_files_only=False,
|
||||||
|
config=self.gpt2_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.tokenizer = BertTokenizer.from_pretrained(
|
||||||
|
'google-bert/bert-base-uncased',
|
||||||
|
trust_remote_code=True,
|
||||||
|
local_files_only=True
|
||||||
|
)
|
||||||
|
except EnvironmentError: # downloads the tokenizer from HF if not already done
|
||||||
|
print("Local tokenizer files not found. Atempting to download them..")
|
||||||
|
self.tokenizer = BertTokenizer.from_pretrained(
|
||||||
|
'google-bert/bert-base-uncased',
|
||||||
|
trust_remote_code=True,
|
||||||
|
local_files_only=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception('LLM model is not defined')
|
||||||
|
|
||||||
if self.tokenizer.eos_token:
|
if self.tokenizer.eos_token:
|
||||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
@ -88,15 +160,20 @@ class Model(nn.Module):
|
|||||||
self.tokenizer.add_special_tokens({'pad_token': pad_token})
|
self.tokenizer.add_special_tokens({'pad_token': pad_token})
|
||||||
self.tokenizer.pad_token = pad_token
|
self.tokenizer.pad_token = pad_token
|
||||||
|
|
||||||
for param in self.llama.parameters():
|
for param in self.llm_model.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
|
if configs.prompt_domain:
|
||||||
|
self.description = configs.content
|
||||||
|
else:
|
||||||
|
self.description = 'The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment.'
|
||||||
|
|
||||||
self.dropout = nn.Dropout(configs.dropout)
|
self.dropout = nn.Dropout(configs.dropout)
|
||||||
|
|
||||||
self.patch_embedding = PatchEmbedding(
|
self.patch_embedding = PatchEmbedding(
|
||||||
configs.d_model, self.patch_len, self.stride, configs.dropout)
|
configs.d_model, self.patch_len, self.stride, configs.dropout)
|
||||||
|
|
||||||
self.word_embeddings = self.llama.get_input_embeddings().weight
|
self.word_embeddings = self.llm_model.get_input_embeddings().weight
|
||||||
self.vocab_size = self.word_embeddings.shape[0]
|
self.vocab_size = self.word_embeddings.shape[0]
|
||||||
self.num_tokens = 1000
|
self.num_tokens = 1000
|
||||||
self.mapping_layer = nn.Linear(self.vocab_size, self.num_tokens)
|
self.mapping_layer = nn.Linear(self.vocab_size, self.num_tokens)
|
||||||
@ -140,7 +217,7 @@ class Model(nn.Module):
|
|||||||
median_values_str = str(medians[b].tolist()[0])
|
median_values_str = str(medians[b].tolist()[0])
|
||||||
lags_values_str = str(lags[b].tolist())
|
lags_values_str = str(lags[b].tolist())
|
||||||
prompt_ = (
|
prompt_ = (
|
||||||
f"<|start_prompt|>Dataset description: The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment."
|
f"<|start_prompt|>Dataset description: {self.description}"
|
||||||
f"Task description: forecast the next {str(self.pred_len)} steps given the previous {str(self.seq_len)} steps information; "
|
f"Task description: forecast the next {str(self.pred_len)} steps given the previous {str(self.seq_len)} steps information; "
|
||||||
"Input statistics: "
|
"Input statistics: "
|
||||||
f"min value {min_values_str}, "
|
f"min value {min_values_str}, "
|
||||||
@ -155,7 +232,7 @@ class Model(nn.Module):
|
|||||||
x_enc = x_enc.reshape(B, N, T).permute(0, 2, 1).contiguous()
|
x_enc = x_enc.reshape(B, N, T).permute(0, 2, 1).contiguous()
|
||||||
|
|
||||||
prompt = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).input_ids
|
prompt = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).input_ids
|
||||||
prompt_embeddings = self.llama.get_input_embeddings()(prompt.to(x_enc.device)) # (batch, prompt_token, dim)
|
prompt_embeddings = self.llm_model.get_input_embeddings()(prompt.to(x_enc.device)) # (batch, prompt_token, dim)
|
||||||
|
|
||||||
source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
|
source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
|
||||||
|
|
||||||
@ -163,7 +240,7 @@ class Model(nn.Module):
|
|||||||
enc_out, n_vars = self.patch_embedding(x_enc.to(torch.bfloat16))
|
enc_out, n_vars = self.patch_embedding(x_enc.to(torch.bfloat16))
|
||||||
enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
|
enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
|
||||||
llama_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1)
|
llama_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1)
|
||||||
dec_out = self.llama(inputs_embeds=llama_enc_out).last_hidden_state
|
dec_out = self.llm_model(inputs_embeds=llama_enc_out).last_hidden_state
|
||||||
dec_out = dec_out[:, :, :self.d_ff]
|
dec_out = dec_out[:, :, :self.d_ff]
|
||||||
|
|
||||||
dec_out = torch.reshape(
|
dec_out = torch.reshape(
|
||||||
|
@ -80,7 +80,9 @@ parser.add_argument('--activation', type=str, default='gelu', help='activation')
|
|||||||
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
|
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
|
||||||
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
|
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
|
||||||
parser.add_argument('--stride', type=int, default=8, help='stride')
|
parser.add_argument('--stride', type=int, default=8, help='stride')
|
||||||
parser.add_argument('--prompt_domain', type=int, default=0, help='stride')
|
parser.add_argument('--prompt_domain', type=int, default=0, help='')
|
||||||
|
parser.add_argument('--llm_model', type=str, default='LLAMA', help='LLM model')
|
||||||
|
parser.add_argument('--llm_dim', type=int, default='4096', help='LLM model dimension')
|
||||||
|
|
||||||
# optimization
|
# optimization
|
||||||
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
|
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
|
||||||
|
@ -77,6 +77,9 @@ parser.add_argument('--output_attention', action='store_true', help='whether to
|
|||||||
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
|
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
|
||||||
parser.add_argument('--stride', type=int, default=8, help='stride')
|
parser.add_argument('--stride', type=int, default=8, help='stride')
|
||||||
parser.add_argument('--prompt_domain', type=int, default=0, help='')
|
parser.add_argument('--prompt_domain', type=int, default=0, help='')
|
||||||
|
parser.add_argument('--llm_model', type=str, default='LLAMA', help='LLM model')
|
||||||
|
parser.add_argument('--llm_dim', type=int, default='4096', help='LLM model dimension')
|
||||||
|
|
||||||
|
|
||||||
# optimization
|
# optimization
|
||||||
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
|
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
|
||||||
|
@ -77,7 +77,9 @@ parser.add_argument('--activation', type=str, default='gelu', help='activation')
|
|||||||
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
|
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
|
||||||
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
|
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
|
||||||
parser.add_argument('--stride', type=int, default=8, help='stride')
|
parser.add_argument('--stride', type=int, default=8, help='stride')
|
||||||
parser.add_argument('--prompt_domain', type=int, default=0, help='stride')
|
parser.add_argument('--prompt_domain', type=int, default=0, help='')
|
||||||
|
parser.add_argument('--llm_model', type=str, default='LLAMA', help='LLM model')
|
||||||
|
parser.add_argument('--llm_dim', type=int, default='4096', help='LLM model dimension')
|
||||||
|
|
||||||
# optimization
|
# optimization
|
||||||
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
|
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
|
||||||
|
@ -18,7 +18,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
|||||||
--data_path electricity.csv \
|
--data_path electricity.csv \
|
||||||
--model_id ECL_512_96 \
|
--model_id ECL_512_96 \
|
||||||
--model $model_name \
|
--model $model_name \
|
||||||
--data custom \
|
--data ECL \
|
||||||
--features M \
|
--features M \
|
||||||
--seq_len 512 \
|
--seq_len 512 \
|
||||||
--label_len 48 \
|
--label_len 48 \
|
||||||
@ -42,7 +42,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
|||||||
--data_path electricity.csv \
|
--data_path electricity.csv \
|
||||||
--model_id ECL_512_192 \
|
--model_id ECL_512_192 \
|
||||||
--model $model_name \
|
--model $model_name \
|
||||||
--data custom \
|
--data ECL \
|
||||||
--features M \
|
--features M \
|
||||||
--seq_len 512 \
|
--seq_len 512 \
|
||||||
--label_len 48 \
|
--label_len 48 \
|
||||||
@ -66,7 +66,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
|||||||
--data_path electricity.csv \
|
--data_path electricity.csv \
|
||||||
--model_id ECL_512_336 \
|
--model_id ECL_512_336 \
|
||||||
--model $model_name \
|
--model $model_name \
|
||||||
--data custom \
|
--data ECL \
|
||||||
--features M \
|
--features M \
|
||||||
--seq_len 512 \
|
--seq_len 512 \
|
||||||
--label_len 48 \
|
--label_len 48 \
|
||||||
@ -90,7 +90,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
|||||||
--data_path electricity.csv \
|
--data_path electricity.csv \
|
||||||
--model_id ECL_512_720 \
|
--model_id ECL_512_720 \
|
||||||
--model $model_name \
|
--model $model_name \
|
||||||
--data custom \
|
--data ECL \
|
||||||
--features M \
|
--features M \
|
||||||
--seq_len 512 \
|
--seq_len 512 \
|
||||||
--label_len 48 \
|
--label_len 48 \
|
||||||
|
@ -18,7 +18,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
|||||||
--data_path traffic.csv \
|
--data_path traffic.csv \
|
||||||
--model_id traffic_512_96 \
|
--model_id traffic_512_96 \
|
||||||
--model $model_name \
|
--model $model_name \
|
||||||
--data custom \
|
--data Traffic \
|
||||||
--features M \
|
--features M \
|
||||||
--seq_len 512 \
|
--seq_len 512 \
|
||||||
--label_len 48 \
|
--label_len 48 \
|
||||||
@ -42,7 +42,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
|||||||
--data_path traffic.csv \
|
--data_path traffic.csv \
|
||||||
--model_id traffic_512_96 \
|
--model_id traffic_512_96 \
|
||||||
--model $model_name \
|
--model $model_name \
|
||||||
--data custom \
|
--data Traffic \
|
||||||
--features M \
|
--features M \
|
||||||
--seq_len 512 \
|
--seq_len 512 \
|
||||||
--label_len 48 \
|
--label_len 48 \
|
||||||
@ -66,7 +66,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
|||||||
--data_path traffic.csv \
|
--data_path traffic.csv \
|
||||||
--model_id traffic_512_96 \
|
--model_id traffic_512_96 \
|
||||||
--model $model_name \
|
--model $model_name \
|
||||||
--data custom \
|
--data Traffic \
|
||||||
--features M \
|
--features M \
|
||||||
--seq_len 512 \
|
--seq_len 512 \
|
||||||
--label_len 48 \
|
--label_len 48 \
|
||||||
@ -90,7 +90,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
|||||||
--data_path traffic.csv \
|
--data_path traffic.csv \
|
||||||
--model_id traffic_512_96 \
|
--model_id traffic_512_96 \
|
||||||
--model $model_name \
|
--model $model_name \
|
||||||
--data custom \
|
--data Traffic \
|
||||||
--features M \
|
--features M \
|
||||||
--seq_len 512 \
|
--seq_len 512 \
|
||||||
--label_len 720 \
|
--label_len 720 \
|
||||||
|
@ -18,7 +18,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
|||||||
--data_path weather.csv \
|
--data_path weather.csv \
|
||||||
--model_id weather_512_96 \
|
--model_id weather_512_96 \
|
||||||
--model $model_name \
|
--model $model_name \
|
||||||
--data custom \
|
--data Weather \
|
||||||
--features M \
|
--features M \
|
||||||
--seq_len 512 \
|
--seq_len 512 \
|
||||||
--label_len 48 \
|
--label_len 48 \
|
||||||
@ -44,7 +44,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
|||||||
--data_path weather.csv \
|
--data_path weather.csv \
|
||||||
--model_id weather_512_192 \
|
--model_id weather_512_192 \
|
||||||
--model $model_name \
|
--model $model_name \
|
||||||
--data custom \
|
--data Weather \
|
||||||
--features M \
|
--features M \
|
||||||
--seq_len 512 \
|
--seq_len 512 \
|
||||||
--label_len 48 \
|
--label_len 48 \
|
||||||
@ -70,7 +70,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
|||||||
--data_path weather.csv \
|
--data_path weather.csv \
|
||||||
--model_id weather_512_336 \
|
--model_id weather_512_336 \
|
||||||
--model $model_name \
|
--model $model_name \
|
||||||
--data custom \
|
--data Weather \
|
||||||
--features M \
|
--features M \
|
||||||
--seq_len 512 \
|
--seq_len 512 \
|
||||||
--label_len 48 \
|
--label_len 48 \
|
||||||
@ -96,7 +96,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
|||||||
--data_path weather.csv \
|
--data_path weather.csv \
|
||||||
--model_id weather_512_720 \
|
--model_id weather_512_720 \
|
||||||
--model $model_name \
|
--model $model_name \
|
||||||
--data custom \
|
--data Weather \
|
||||||
--features M \
|
--features M \
|
||||||
--seq_len 512 \
|
--seq_len 512 \
|
||||||
--label_len 48 \
|
--label_len 48 \
|
||||||
|
Loading…
Reference in New Issue
Block a user