commit=add GPT2 and BERT for Time-LLM as general framework &update the prompt_bank(weather, ECL,Traffic)

This commit is contained in:
weiming.wsy 2024-03-18 14:58:49 +08:00
parent 980a9024ff
commit e022670182
8 changed files with 149 additions and 63 deletions

View File

@ -6,7 +6,9 @@ data_dict = {
'ETTh2': Dataset_ETT_hour, 'ETTh2': Dataset_ETT_hour,
'ETTm1': Dataset_ETT_minute, 'ETTm1': Dataset_ETT_minute,
'ETTm2': Dataset_ETT_minute, 'ETTm2': Dataset_ETT_minute,
'custom': Dataset_Custom, 'ECL': Dataset_Custom,
'Traffic': Dataset_Custom,
'Weather': Dataset_Custom,
'm4': Dataset_M4, 'm4': Dataset_M4,
} }

View File

@ -3,11 +3,10 @@ from math import sqrt
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import LlamaConfig, LlamaModel, LlamaTokenizer from transformers import LlamaConfig, LlamaModel, LlamaTokenizer, GPT2Config, GPT2Model, GPT2Tokenizer, BertConfig, \
BertModel, BertTokenizer
from layers.Embed import PatchEmbedding from layers.Embed import PatchEmbedding
import transformers import transformers
from layers.StandardNorm import Normalize from layers.StandardNorm import Normalize
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
@ -37,17 +36,18 @@ class Model(nn.Module):
self.seq_len = configs.seq_len self.seq_len = configs.seq_len
self.d_ff = configs.d_ff self.d_ff = configs.d_ff
self.top_k = 5 self.top_k = 5
self.d_llm = 4096 self.d_llm = configs.llm_dim
self.patch_len = configs.patch_len self.patch_len = configs.patch_len
self.stride = configs.stride self.stride = configs.stride
self.llama_config = LlamaConfig.from_pretrained('/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/') if configs.llm_model == 'LLAMA':
# self.llama_config = LlamaConfig.from_pretrained('huggyllama/llama-7b') # self.llama_config = LlamaConfig.from_pretrained('/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/')
self.llama_config = LlamaConfig.from_pretrained('huggyllama/llama-7b')
self.llama_config.num_hidden_layers = configs.llm_layers self.llama_config.num_hidden_layers = configs.llm_layers
self.llama_config.output_attentions = True self.llama_config.output_attentions = True
self.llama_config.output_hidden_states = True self.llama_config.output_hidden_states = True
try: try:
self.llama = LlamaModel.from_pretrained( self.llm_model = LlamaModel.from_pretrained(
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/", # "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
'huggyllama/llama-7b', 'huggyllama/llama-7b',
trust_remote_code=True, trust_remote_code=True,
@ -57,7 +57,7 @@ class Model(nn.Module):
) )
except EnvironmentError: # downloads model from HF is not already done except EnvironmentError: # downloads model from HF is not already done
print("Local model files not found. Attempting to download...") print("Local model files not found. Attempting to download...")
self.llama = LlamaModel.from_pretrained( self.llm_model = LlamaModel.from_pretrained(
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/", # "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
'huggyllama/llama-7b', 'huggyllama/llama-7b',
trust_remote_code=True, trust_remote_code=True,
@ -80,6 +80,78 @@ class Model(nn.Module):
trust_remote_code=True, trust_remote_code=True,
local_files_only=False local_files_only=False
) )
elif configs.llm_model == 'GPT2':
self.gpt2_config = GPT2Config.from_pretrained('openai-community/gpt2')
self.gpt2_config.num_hidden_layers = configs.llm_layers
self.gpt2_config.output_attentions = True
self.gpt2_config.output_hidden_states = True
try:
self.llm_model = GPT2Model.from_pretrained(
'openai-community/gpt2',
trust_remote_code=True,
local_files_only=True,
config=self.gpt2_config,
)
except EnvironmentError: # downloads model from HF is not already done
print("Local model files not found. Attempting to download...")
self.llm_model = GPT2Model.from_pretrained(
'openai-community/gpt2',
trust_remote_code=True,
local_files_only=False,
config=self.gpt2_config,
)
try:
self.tokenizer = GPT2Tokenizer.from_pretrained(
'openai-community/gpt2',
trust_remote_code=True,
local_files_only=True
)
except EnvironmentError: # downloads the tokenizer from HF if not already done
print("Local tokenizer files not found. Atempting to download them..")
self.tokenizer = GPT2Tokenizer.from_pretrained(
'openai-community/gpt2',
trust_remote_code=True,
local_files_only=False
)
elif configs.llm_model == 'BERT':
self.bert_config = BertConfig.from_pretrained('google-bert/bert-base-uncased')
self.bert_config.num_hidden_layers = configs.llm_layers
self.bert_config.output_attentions = True
self.bert_config.output_hidden_states = True
try:
self.llm_model = BertModel.from_pretrained(
'google-bert/bert-base-uncased',
trust_remote_code=True,
local_files_only=True,
config=self.bert_config,
)
except EnvironmentError: # downloads model from HF is not already done
print("Local model files not found. Attempting to download...")
self.llm_model = BertModel.from_pretrained(
'google-bert/bert-base-uncased',
trust_remote_code=True,
local_files_only=False,
config=self.gpt2_config,
)
try:
self.tokenizer = BertTokenizer.from_pretrained(
'google-bert/bert-base-uncased',
trust_remote_code=True,
local_files_only=True
)
except EnvironmentError: # downloads the tokenizer from HF if not already done
print("Local tokenizer files not found. Atempting to download them..")
self.tokenizer = BertTokenizer.from_pretrained(
'google-bert/bert-base-uncased',
trust_remote_code=True,
local_files_only=False
)
else:
raise Exception('LLM model is not defined')
if self.tokenizer.eos_token: if self.tokenizer.eos_token:
self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token = self.tokenizer.eos_token
@ -88,15 +160,20 @@ class Model(nn.Module):
self.tokenizer.add_special_tokens({'pad_token': pad_token}) self.tokenizer.add_special_tokens({'pad_token': pad_token})
self.tokenizer.pad_token = pad_token self.tokenizer.pad_token = pad_token
for param in self.llama.parameters(): for param in self.llm_model.parameters():
param.requires_grad = False param.requires_grad = False
if configs.prompt_domain:
self.description = configs.content
else:
self.description = 'The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment.'
self.dropout = nn.Dropout(configs.dropout) self.dropout = nn.Dropout(configs.dropout)
self.patch_embedding = PatchEmbedding( self.patch_embedding = PatchEmbedding(
configs.d_model, self.patch_len, self.stride, configs.dropout) configs.d_model, self.patch_len, self.stride, configs.dropout)
self.word_embeddings = self.llama.get_input_embeddings().weight self.word_embeddings = self.llm_model.get_input_embeddings().weight
self.vocab_size = self.word_embeddings.shape[0] self.vocab_size = self.word_embeddings.shape[0]
self.num_tokens = 1000 self.num_tokens = 1000
self.mapping_layer = nn.Linear(self.vocab_size, self.num_tokens) self.mapping_layer = nn.Linear(self.vocab_size, self.num_tokens)
@ -140,7 +217,7 @@ class Model(nn.Module):
median_values_str = str(medians[b].tolist()[0]) median_values_str = str(medians[b].tolist()[0])
lags_values_str = str(lags[b].tolist()) lags_values_str = str(lags[b].tolist())
prompt_ = ( prompt_ = (
f"<|start_prompt|>Dataset description: The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment." f"<|start_prompt|>Dataset description: {self.description}"
f"Task description: forecast the next {str(self.pred_len)} steps given the previous {str(self.seq_len)} steps information; " f"Task description: forecast the next {str(self.pred_len)} steps given the previous {str(self.seq_len)} steps information; "
"Input statistics: " "Input statistics: "
f"min value {min_values_str}, " f"min value {min_values_str}, "
@ -155,7 +232,7 @@ class Model(nn.Module):
x_enc = x_enc.reshape(B, N, T).permute(0, 2, 1).contiguous() x_enc = x_enc.reshape(B, N, T).permute(0, 2, 1).contiguous()
prompt = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).input_ids prompt = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).input_ids
prompt_embeddings = self.llama.get_input_embeddings()(prompt.to(x_enc.device)) # (batch, prompt_token, dim) prompt_embeddings = self.llm_model.get_input_embeddings()(prompt.to(x_enc.device)) # (batch, prompt_token, dim)
source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0) source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
@ -163,7 +240,7 @@ class Model(nn.Module):
enc_out, n_vars = self.patch_embedding(x_enc.to(torch.bfloat16)) enc_out, n_vars = self.patch_embedding(x_enc.to(torch.bfloat16))
enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings) enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
llama_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1) llama_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1)
dec_out = self.llama(inputs_embeds=llama_enc_out).last_hidden_state dec_out = self.llm_model(inputs_embeds=llama_enc_out).last_hidden_state
dec_out = dec_out[:, :, :self.d_ff] dec_out = dec_out[:, :, :self.d_ff]
dec_out = torch.reshape( dec_out = torch.reshape(

View File

@ -80,7 +80,9 @@ parser.add_argument('--activation', type=str, default='gelu', help='activation')
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder') parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
parser.add_argument('--patch_len', type=int, default=16, help='patch length') parser.add_argument('--patch_len', type=int, default=16, help='patch length')
parser.add_argument('--stride', type=int, default=8, help='stride') parser.add_argument('--stride', type=int, default=8, help='stride')
parser.add_argument('--prompt_domain', type=int, default=0, help='stride') parser.add_argument('--prompt_domain', type=int, default=0, help='')
parser.add_argument('--llm_model', type=str, default='LLAMA', help='LLM model')
parser.add_argument('--llm_dim', type=int, default='4096', help='LLM model dimension')
# optimization # optimization
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers') parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')

View File

@ -77,6 +77,9 @@ parser.add_argument('--output_attention', action='store_true', help='whether to
parser.add_argument('--patch_len', type=int, default=16, help='patch length') parser.add_argument('--patch_len', type=int, default=16, help='patch length')
parser.add_argument('--stride', type=int, default=8, help='stride') parser.add_argument('--stride', type=int, default=8, help='stride')
parser.add_argument('--prompt_domain', type=int, default=0, help='') parser.add_argument('--prompt_domain', type=int, default=0, help='')
parser.add_argument('--llm_model', type=str, default='LLAMA', help='LLM model')
parser.add_argument('--llm_dim', type=int, default='4096', help='LLM model dimension')
# optimization # optimization
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers') parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')

View File

@ -77,7 +77,9 @@ parser.add_argument('--activation', type=str, default='gelu', help='activation')
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder') parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder')
parser.add_argument('--patch_len', type=int, default=16, help='patch length') parser.add_argument('--patch_len', type=int, default=16, help='patch length')
parser.add_argument('--stride', type=int, default=8, help='stride') parser.add_argument('--stride', type=int, default=8, help='stride')
parser.add_argument('--prompt_domain', type=int, default=0, help='stride') parser.add_argument('--prompt_domain', type=int, default=0, help='')
parser.add_argument('--llm_model', type=str, default='LLAMA', help='LLM model')
parser.add_argument('--llm_dim', type=int, default='4096', help='LLM model dimension')
# optimization # optimization
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers') parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')

View File

@ -18,7 +18,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path electricity.csv \ --data_path electricity.csv \
--model_id ECL_512_96 \ --model_id ECL_512_96 \
--model $model_name \ --model $model_name \
--data custom \ --data ECL \
--features M \ --features M \
--seq_len 512 \ --seq_len 512 \
--label_len 48 \ --label_len 48 \
@ -42,7 +42,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path electricity.csv \ --data_path electricity.csv \
--model_id ECL_512_192 \ --model_id ECL_512_192 \
--model $model_name \ --model $model_name \
--data custom \ --data ECL \
--features M \ --features M \
--seq_len 512 \ --seq_len 512 \
--label_len 48 \ --label_len 48 \
@ -66,7 +66,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path electricity.csv \ --data_path electricity.csv \
--model_id ECL_512_336 \ --model_id ECL_512_336 \
--model $model_name \ --model $model_name \
--data custom \ --data ECL \
--features M \ --features M \
--seq_len 512 \ --seq_len 512 \
--label_len 48 \ --label_len 48 \
@ -90,7 +90,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path electricity.csv \ --data_path electricity.csv \
--model_id ECL_512_720 \ --model_id ECL_512_720 \
--model $model_name \ --model $model_name \
--data custom \ --data ECL \
--features M \ --features M \
--seq_len 512 \ --seq_len 512 \
--label_len 48 \ --label_len 48 \

View File

@ -18,7 +18,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path traffic.csv \ --data_path traffic.csv \
--model_id traffic_512_96 \ --model_id traffic_512_96 \
--model $model_name \ --model $model_name \
--data custom \ --data Traffic \
--features M \ --features M \
--seq_len 512 \ --seq_len 512 \
--label_len 48 \ --label_len 48 \
@ -42,7 +42,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path traffic.csv \ --data_path traffic.csv \
--model_id traffic_512_96 \ --model_id traffic_512_96 \
--model $model_name \ --model $model_name \
--data custom \ --data Traffic \
--features M \ --features M \
--seq_len 512 \ --seq_len 512 \
--label_len 48 \ --label_len 48 \
@ -66,7 +66,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path traffic.csv \ --data_path traffic.csv \
--model_id traffic_512_96 \ --model_id traffic_512_96 \
--model $model_name \ --model $model_name \
--data custom \ --data Traffic \
--features M \ --features M \
--seq_len 512 \ --seq_len 512 \
--label_len 48 \ --label_len 48 \
@ -90,7 +90,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path traffic.csv \ --data_path traffic.csv \
--model_id traffic_512_96 \ --model_id traffic_512_96 \
--model $model_name \ --model $model_name \
--data custom \ --data Traffic \
--features M \ --features M \
--seq_len 512 \ --seq_len 512 \
--label_len 720 \ --label_len 720 \

View File

@ -18,7 +18,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path weather.csv \ --data_path weather.csv \
--model_id weather_512_96 \ --model_id weather_512_96 \
--model $model_name \ --model $model_name \
--data custom \ --data Weather \
--features M \ --features M \
--seq_len 512 \ --seq_len 512 \
--label_len 48 \ --label_len 48 \
@ -44,7 +44,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path weather.csv \ --data_path weather.csv \
--model_id weather_512_192 \ --model_id weather_512_192 \
--model $model_name \ --model $model_name \
--data custom \ --data Weather \
--features M \ --features M \
--seq_len 512 \ --seq_len 512 \
--label_len 48 \ --label_len 48 \
@ -70,7 +70,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path weather.csv \ --data_path weather.csv \
--model_id weather_512_336 \ --model_id weather_512_336 \
--model $model_name \ --model $model_name \
--data custom \ --data Weather \
--features M \ --features M \
--seq_len 512 \ --seq_len 512 \
--label_len 48 \ --label_len 48 \
@ -96,7 +96,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path weather.csv \ --data_path weather.csv \
--model_id weather_512_720 \ --model_id weather_512_720 \
--model $model_name \ --model $model_name \
--data custom \ --data Weather \
--features M \ --features M \
--seq_len 512 \ --seq_len 512 \
--label_len 48 \ --label_len 48 \