mirror of
https://github.com/KimMeen/Time-LLM.git
synced 2024-11-21 03:13:47 +08:00
commit=add GPT2 and BERT for Time-LLM as general framework &update the prompt_bank(weather, ECL,Traffic)
This commit is contained in:
parent
980a9024ff
commit
e022670182
@ -6,7 +6,9 @@ data_dict = {
|
||||
'ETTh2': Dataset_ETT_hour,
|
||||
'ETTm1': Dataset_ETT_minute,
|
||||
'ETTm2': Dataset_ETT_minute,
|
||||
'custom': Dataset_Custom,
|
||||
'ECL': Dataset_Custom,
|
||||
'Traffic': Dataset_Custom,
|
||||
'Weather': Dataset_Custom,
|
||||
'm4': Dataset_M4,
|
||||
}
|
||||
|
||||
|
@ -3,11 +3,10 @@ from math import sqrt
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers import LlamaConfig, LlamaModel, LlamaTokenizer
|
||||
from transformers import LlamaConfig, LlamaModel, LlamaTokenizer, GPT2Config, GPT2Model, GPT2Tokenizer, BertConfig, \
|
||||
BertModel, BertTokenizer
|
||||
from layers.Embed import PatchEmbedding
|
||||
|
||||
import transformers
|
||||
|
||||
from layers.StandardNorm import Normalize
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
@ -37,17 +36,18 @@ class Model(nn.Module):
|
||||
self.seq_len = configs.seq_len
|
||||
self.d_ff = configs.d_ff
|
||||
self.top_k = 5
|
||||
self.d_llm = 4096
|
||||
self.d_llm = configs.llm_dim
|
||||
self.patch_len = configs.patch_len
|
||||
self.stride = configs.stride
|
||||
|
||||
self.llama_config = LlamaConfig.from_pretrained('/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/')
|
||||
# self.llama_config = LlamaConfig.from_pretrained('huggyllama/llama-7b')
|
||||
if configs.llm_model == 'LLAMA':
|
||||
# self.llama_config = LlamaConfig.from_pretrained('/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/')
|
||||
self.llama_config = LlamaConfig.from_pretrained('huggyllama/llama-7b')
|
||||
self.llama_config.num_hidden_layers = configs.llm_layers
|
||||
self.llama_config.output_attentions = True
|
||||
self.llama_config.output_hidden_states = True
|
||||
try:
|
||||
self.llama = LlamaModel.from_pretrained(
|
||||
self.llm_model = LlamaModel.from_pretrained(
|
||||
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
|
||||
'huggyllama/llama-7b',
|
||||
trust_remote_code=True,
|
||||
@ -57,7 +57,7 @@ class Model(nn.Module):
|
||||
)
|
||||
except EnvironmentError: # downloads model from HF is not already done
|
||||
print("Local model files not found. Attempting to download...")
|
||||
self.llama = LlamaModel.from_pretrained(
|
||||
self.llm_model = LlamaModel.from_pretrained(
|
||||
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
|
||||
'huggyllama/llama-7b',
|
||||
trust_remote_code=True,
|
||||
@ -80,6 +80,78 @@ class Model(nn.Module):
|
||||
trust_remote_code=True,
|
||||
local_files_only=False
|
||||
)
|
||||
elif configs.llm_model == 'GPT2':
|
||||
self.gpt2_config = GPT2Config.from_pretrained('openai-community/gpt2')
|
||||
|
||||
self.gpt2_config.num_hidden_layers = configs.llm_layers
|
||||
self.gpt2_config.output_attentions = True
|
||||
self.gpt2_config.output_hidden_states = True
|
||||
try:
|
||||
self.llm_model = GPT2Model.from_pretrained(
|
||||
'openai-community/gpt2',
|
||||
trust_remote_code=True,
|
||||
local_files_only=True,
|
||||
config=self.gpt2_config,
|
||||
)
|
||||
except EnvironmentError: # downloads model from HF is not already done
|
||||
print("Local model files not found. Attempting to download...")
|
||||
self.llm_model = GPT2Model.from_pretrained(
|
||||
'openai-community/gpt2',
|
||||
trust_remote_code=True,
|
||||
local_files_only=False,
|
||||
config=self.gpt2_config,
|
||||
)
|
||||
|
||||
try:
|
||||
self.tokenizer = GPT2Tokenizer.from_pretrained(
|
||||
'openai-community/gpt2',
|
||||
trust_remote_code=True,
|
||||
local_files_only=True
|
||||
)
|
||||
except EnvironmentError: # downloads the tokenizer from HF if not already done
|
||||
print("Local tokenizer files not found. Atempting to download them..")
|
||||
self.tokenizer = GPT2Tokenizer.from_pretrained(
|
||||
'openai-community/gpt2',
|
||||
trust_remote_code=True,
|
||||
local_files_only=False
|
||||
)
|
||||
elif configs.llm_model == 'BERT':
|
||||
self.bert_config = BertConfig.from_pretrained('google-bert/bert-base-uncased')
|
||||
|
||||
self.bert_config.num_hidden_layers = configs.llm_layers
|
||||
self.bert_config.output_attentions = True
|
||||
self.bert_config.output_hidden_states = True
|
||||
try:
|
||||
self.llm_model = BertModel.from_pretrained(
|
||||
'google-bert/bert-base-uncased',
|
||||
trust_remote_code=True,
|
||||
local_files_only=True,
|
||||
config=self.bert_config,
|
||||
)
|
||||
except EnvironmentError: # downloads model from HF is not already done
|
||||
print("Local model files not found. Attempting to download...")
|
||||
self.llm_model = BertModel.from_pretrained(
|
||||
'google-bert/bert-base-uncased',
|
||||
trust_remote_code=True,
|
||||
local_files_only=False,
|
||||
config=self.gpt2_config,
|
||||
)
|
||||
|
||||
try:
|
||||
self.tokenizer = BertTokenizer.from_pretrained(
|
||||
'google-bert/bert-base-uncased',
|
||||
trust_remote_code=True,
|
||||
local_files_only=True
|
||||
)
|
||||
except EnvironmentError: # downloads the tokenizer from HF if not already done
|
||||
print("Local tokenizer files not found. Atempting to download them..")
|
||||
self.tokenizer = BertTokenizer.from_pretrained(
|
||||
'google-bert/bert-base-uncased',
|
||||
trust_remote_code=True,
|
||||
local_files_only=False
|
||||
)
|
||||
else:
|
||||
raise Exception('LLM model is not defined')
|
||||
|
||||
if self.tokenizer.eos_token:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
@ -88,15 +160,20 @@ class Model(nn.Module):
|
||||
self.tokenizer.add_special_tokens({'pad_token': pad_token})
|
||||
self.tokenizer.pad_token = pad_token
|
||||
|
||||
for param in self.llama.parameters():
|
||||
for param in self.llm_model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
if configs.prompt_domain:
|
||||
self.description = configs.content
|
||||
else:
|
||||
self.description = 'The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment.'
|
||||
|
||||
self.dropout = nn.Dropout(configs.dropout)
|
||||
|
||||
self.patch_embedding = PatchEmbedding(
|
||||
configs.d_model, self.patch_len, self.stride, configs.dropout)
|
||||
|
||||
self.word_embeddings = self.llama.get_input_embeddings().weight
|
||||
self.word_embeddings = self.llm_model.get_input_embeddings().weight
|
||||
self.vocab_size = self.word_embeddings.shape[0]
|
||||
self.num_tokens = 1000
|
||||
self.mapping_layer = nn.Linear(self.vocab_size, self.num_tokens)
|
||||
@ -140,7 +217,7 @@ class Model(nn.Module):
|
||||
median_values_str = str(medians[b].tolist()[0])
|
||||
lags_values_str = str(lags[b].tolist())
|
||||
prompt_ = (
|
||||
f"<|start_prompt|>Dataset description: The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment."
|
||||
f"<|start_prompt|>Dataset description: {self.description}"
|
||||
f"Task description: forecast the next {str(self.pred_len)} steps given the previous {str(self.seq_len)} steps information; "
|
||||
"Input statistics: "
|
||||
f"min value {min_values_str}, "
|
||||
@ -155,7 +232,7 @@ class Model(nn.Module):
|
||||
x_enc = x_enc.reshape(B, N, T).permute(0, 2, 1).contiguous()
|
||||
|
||||
prompt = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).input_ids
|
||||
prompt_embeddings = self.llama.get_input_embeddings()(prompt.to(x_enc.device)) # (batch, prompt_token, dim)
|
||||
prompt_embeddings = self.llm_model.get_input_embeddings()(prompt.to(x_enc.device)) # (batch, prompt_token, dim)
|
||||
|
||||
source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
|
||||
|
||||
@ -163,7 +240,7 @@ class Model(nn.Module):
|
||||
enc_out, n_vars = self.patch_embedding(x_enc.to(torch.bfloat16))
|
||||
enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
|
||||
llama_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1)
|
||||
dec_out = self.llama(inputs_embeds=llama_enc_out).last_hidden_state
|
||||
dec_out = self.llm_model(inputs_embeds=llama_enc_out).last_hidden_state
|
||||
dec_out = dec_out[:, :, :self.d_ff]
|
||||
|
||||
dec_out = torch.reshape(
|
||||
|
@ -80,7 +80,9 @@ parser.add_argument('--activation', type=str, default='gelu', help='activation')
|
||||
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
|
||||
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
|
||||
parser.add_argument('--stride', type=int, default=8, help='stride')
|
||||
parser.add_argument('--prompt_domain', type=int, default=0, help='stride')
|
||||
parser.add_argument('--prompt_domain', type=int, default=0, help='')
|
||||
parser.add_argument('--llm_model', type=str, default='LLAMA', help='LLM model')
|
||||
parser.add_argument('--llm_dim', type=int, default='4096', help='LLM model dimension')
|
||||
|
||||
# optimization
|
||||
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
|
||||
|
@ -77,6 +77,9 @@ parser.add_argument('--output_attention', action='store_true', help='whether to
|
||||
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
|
||||
parser.add_argument('--stride', type=int, default=8, help='stride')
|
||||
parser.add_argument('--prompt_domain', type=int, default=0, help='')
|
||||
parser.add_argument('--llm_model', type=str, default='LLAMA', help='LLM model')
|
||||
parser.add_argument('--llm_dim', type=int, default='4096', help='LLM model dimension')
|
||||
|
||||
|
||||
# optimization
|
||||
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
|
||||
|
@ -77,7 +77,9 @@ parser.add_argument('--activation', type=str, default='gelu', help='activation')
|
||||
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
|
||||
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
|
||||
parser.add_argument('--stride', type=int, default=8, help='stride')
|
||||
parser.add_argument('--prompt_domain', type=int, default=0, help='stride')
|
||||
parser.add_argument('--prompt_domain', type=int, default=0, help='')
|
||||
parser.add_argument('--llm_model', type=str, default='LLAMA', help='LLM model')
|
||||
parser.add_argument('--llm_dim', type=int, default='4096', help='LLM model dimension')
|
||||
|
||||
# optimization
|
||||
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
|
||||
|
@ -18,7 +18,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
||||
--data_path electricity.csv \
|
||||
--model_id ECL_512_96 \
|
||||
--model $model_name \
|
||||
--data custom \
|
||||
--data ECL \
|
||||
--features M \
|
||||
--seq_len 512 \
|
||||
--label_len 48 \
|
||||
@ -42,7 +42,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
||||
--data_path electricity.csv \
|
||||
--model_id ECL_512_192 \
|
||||
--model $model_name \
|
||||
--data custom \
|
||||
--data ECL \
|
||||
--features M \
|
||||
--seq_len 512 \
|
||||
--label_len 48 \
|
||||
@ -66,7 +66,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
||||
--data_path electricity.csv \
|
||||
--model_id ECL_512_336 \
|
||||
--model $model_name \
|
||||
--data custom \
|
||||
--data ECL \
|
||||
--features M \
|
||||
--seq_len 512 \
|
||||
--label_len 48 \
|
||||
@ -90,7 +90,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
||||
--data_path electricity.csv \
|
||||
--model_id ECL_512_720 \
|
||||
--model $model_name \
|
||||
--data custom \
|
||||
--data ECL \
|
||||
--features M \
|
||||
--seq_len 512 \
|
||||
--label_len 48 \
|
||||
|
@ -18,7 +18,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
||||
--data_path traffic.csv \
|
||||
--model_id traffic_512_96 \
|
||||
--model $model_name \
|
||||
--data custom \
|
||||
--data Traffic \
|
||||
--features M \
|
||||
--seq_len 512 \
|
||||
--label_len 48 \
|
||||
@ -42,7 +42,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
||||
--data_path traffic.csv \
|
||||
--model_id traffic_512_96 \
|
||||
--model $model_name \
|
||||
--data custom \
|
||||
--data Traffic \
|
||||
--features M \
|
||||
--seq_len 512 \
|
||||
--label_len 48 \
|
||||
@ -66,7 +66,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
||||
--data_path traffic.csv \
|
||||
--model_id traffic_512_96 \
|
||||
--model $model_name \
|
||||
--data custom \
|
||||
--data Traffic \
|
||||
--features M \
|
||||
--seq_len 512 \
|
||||
--label_len 48 \
|
||||
@ -90,7 +90,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
||||
--data_path traffic.csv \
|
||||
--model_id traffic_512_96 \
|
||||
--model $model_name \
|
||||
--data custom \
|
||||
--data Traffic \
|
||||
--features M \
|
||||
--seq_len 512 \
|
||||
--label_len 720 \
|
||||
|
@ -18,7 +18,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
||||
--data_path weather.csv \
|
||||
--model_id weather_512_96 \
|
||||
--model $model_name \
|
||||
--data custom \
|
||||
--data Weather \
|
||||
--features M \
|
||||
--seq_len 512 \
|
||||
--label_len 48 \
|
||||
@ -44,7 +44,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
||||
--data_path weather.csv \
|
||||
--model_id weather_512_192 \
|
||||
--model $model_name \
|
||||
--data custom \
|
||||
--data Weather \
|
||||
--features M \
|
||||
--seq_len 512 \
|
||||
--label_len 48 \
|
||||
@ -70,7 +70,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
||||
--data_path weather.csv \
|
||||
--model_id weather_512_336 \
|
||||
--model $model_name \
|
||||
--data custom \
|
||||
--data Weather \
|
||||
--features M \
|
||||
--seq_len 512 \
|
||||
--label_len 48 \
|
||||
@ -96,7 +96,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
|
||||
--data_path weather.csv \
|
||||
--model_id weather_512_720 \
|
||||
--model $model_name \
|
||||
--data custom \
|
||||
--data Weather \
|
||||
--features M \
|
||||
--seq_len 512 \
|
||||
--label_len 48 \
|
||||
|
Loading…
Reference in New Issue
Block a user