Merge pull request #38 from KimMeen/dev

add GPT2 and BERT for Time-LLM as general framework
MetaKing 2024-03-18 20:00:06 +08:00 committed by GitHub
commit af55253c88
8 changed files with 149 additions and 63 deletions

View File

@@ -6,7 +6,9 @@ data_dict = {
'ETTh2': Dataset_ETT_hour,
'ETTm1': Dataset_ETT_minute,
'ETTm2': Dataset_ETT_minute,
'custom': Dataset_Custom,
'ECL': Dataset_Custom,
'Traffic': Dataset_Custom,
'Weather': Dataset_Custom,
'm4': Dataset_M4,
}
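
The three new keys are aliases: 'ECL', 'Traffic', and 'Weather' all resolve to the existing Dataset_Custom loader, so only the lookup key changes, not the CSV handling. A minimal sketch of the lookup, assuming only the data_dict shown above (resolve_dataset is illustrative, not a function in the repo):

def resolve_dataset(name):
    # 'ECL', 'Traffic', 'Weather', and 'custom' all return Dataset_Custom;
    # an unknown key fails with a readable error instead of a bare KeyError.
    try:
        return data_dict[name]
    except KeyError:
        raise ValueError(f'unknown --data value: {name!r}')

Data = resolve_dataset('ECL')  # -> Dataset_Custom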

View File

@@ -3,11 +3,10 @@ from math import sqrt
import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel, LlamaTokenizer
from transformers import LlamaConfig, LlamaModel, LlamaTokenizer, GPT2Config, GPT2Model, GPT2Tokenizer, BertConfig, \
BertModel, BertTokenizer
from layers.Embed import PatchEmbedding
import transformers
from layers.StandardNorm import Normalize
transformers.logging.set_verbosity_error()
@@ -37,49 +36,122 @@ class Model(nn.Module):
self.seq_len = configs.seq_len
self.d_ff = configs.d_ff
self.top_k = 5
self.d_llm = 4096
self.d_llm = configs.llm_dim
self.patch_len = configs.patch_len
self.stride = configs.stride
self.llama_config = LlamaConfig.from_pretrained('/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/')
# self.llama_config = LlamaConfig.from_pretrained('huggyllama/llama-7b')
self.llama_config.num_hidden_layers = configs.llm_layers
self.llama_config.output_attentions = True
self.llama_config.output_hidden_states = True
try:
self.llama = LlamaModel.from_pretrained(
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
'huggyllama/llama-7b',
trust_remote_code=True,
local_files_only=True,
config=self.llama_config,
# load_in_4bit=True
)
except EnvironmentError:  # downloads the model from HF if not already done
print("Local model files not found. Attempting to download...")
self.llama = LlamaModel.from_pretrained(
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
'huggyllama/llama-7b',
trust_remote_code=True,
local_files_only=False,
config=self.llama_config,
# load_in_4bit=True
)
try:
self.tokenizer = LlamaTokenizer.from_pretrained(
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/tokenizer.model",
'huggyllama/llama-7b',
trust_remote_code=True,
local_files_only=True
)
except EnvironmentError: # downloads the tokenizer from HF if not already done
print("Local tokenizer files not found. Atempting to download them..")
self.tokenizer = LlamaTokenizer.from_pretrained(
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/tokenizer.model",
'huggyllama/llama-7b',
trust_remote_code=True,
local_files_only=False
)
if configs.llm_model == 'LLAMA':
# self.llama_config = LlamaConfig.from_pretrained('/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/')
self.llama_config = LlamaConfig.from_pretrained('huggyllama/llama-7b')
self.llama_config.num_hidden_layers = configs.llm_layers
self.llama_config.output_attentions = True
self.llama_config.output_hidden_states = True
try:
self.llm_model = LlamaModel.from_pretrained(
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
'huggyllama/llama-7b',
trust_remote_code=True,
local_files_only=True,
config=self.llama_config,
# load_in_4bit=True
)
except EnvironmentError:  # downloads the model from HF if not already done
print("Local model files not found. Attempting to download...")
self.llm_model = LlamaModel.from_pretrained(
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
'huggyllama/llama-7b',
trust_remote_code=True,
local_files_only=False,
config=self.llama_config,
# load_in_4bit=True
)
try:
self.tokenizer = LlamaTokenizer.from_pretrained(
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/tokenizer.model",
'huggyllama/llama-7b',
trust_remote_code=True,
local_files_only=True
)
except EnvironmentError: # downloads the tokenizer from HF if not already done
print("Local tokenizer files not found. Atempting to download them..")
self.tokenizer = LlamaTokenizer.from_pretrained(
# "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/tokenizer.model",
'huggyllama/llama-7b',
trust_remote_code=True,
local_files_only=False
)
elif configs.llm_model == 'GPT2':
self.gpt2_config = GPT2Config.from_pretrained('openai-community/gpt2')
self.gpt2_config.num_hidden_layers = configs.llm_layers
self.gpt2_config.output_attentions = True
self.gpt2_config.output_hidden_states = True
try:
self.llm_model = GPT2Model.from_pretrained(
'openai-community/gpt2',
trust_remote_code=True,
local_files_only=True,
config=self.gpt2_config,
)
except EnvironmentError:  # downloads the model from HF if not already done
print("Local model files not found. Attempting to download...")
self.llm_model = GPT2Model.from_pretrained(
'openai-community/gpt2',
trust_remote_code=True,
local_files_only=False,
config=self.gpt2_config,
)
try:
self.tokenizer = GPT2Tokenizer.from_pretrained(
'openai-community/gpt2',
trust_remote_code=True,
local_files_only=True
)
except EnvironmentError: # downloads the tokenizer from HF if not already done
print("Local tokenizer files not found. Atempting to download them..")
self.tokenizer = GPT2Tokenizer.from_pretrained(
'openai-community/gpt2',
trust_remote_code=True,
local_files_only=False
)
elif configs.llm_model == 'BERT':
self.bert_config = BertConfig.from_pretrained('google-bert/bert-base-uncased')
self.bert_config.num_hidden_layers = configs.llm_layers
self.bert_config.output_attentions = True
self.bert_config.output_hidden_states = True
try:
self.llm_model = BertModel.from_pretrained(
'google-bert/bert-base-uncased',
trust_remote_code=True,
local_files_only=True,
config=self.bert_config,
)
except EnvironmentError:  # downloads the model from HF if not already done
print("Local model files not found. Attempting to download...")
self.llm_model = BertModel.from_pretrained(
'google-bert/bert-base-uncased',
trust_remote_code=True,
local_files_only=False,
config=self.bert_config,
)
try:
self.tokenizer = BertTokenizer.from_pretrained(
'google-bert/bert-base-uncased',
trust_remote_code=True,
local_files_only=True
)
except EnvironmentError: # downloads the tokenizer from HF if not already done
print("Local tokenizer files not found. Atempting to download them..")
self.tokenizer = BertTokenizer.from_pretrained(
'google-bert/bert-base-uncased',
trust_remote_code=True,
local_files_only=False
)
else:
raise Exception('LLM model is not defined')
if self.tokenizer.eos_token:
self.tokenizer.pad_token = self.tokenizer.eos_token
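
Whichever branch runs, the chosen backbone lands in self.llm_model and its tokenizer in self.tokenizer, so the rest of the class is backbone-agnostic. A hedged construction sketch; the field names mirror the --llm_model, --llm_dim, and --llm_layers flags added below, and every other hyperparameter is elided:

from argparse import Namespace

configs = Namespace(
    llm_model='GPT2',  # one of 'LLAMA', 'GPT2', 'BERT'
    llm_dim=768,       # must match the backbone: LLaMA-7B 4096, GPT-2 small / BERT-base 768
    llm_layers=6,      # pretrained transformer layers to keep
    # ... remaining Time-LLM hyperparameters (seq_len, pred_len, patch_len, stride, ...)
)
# model = Model(configs)  # Model as defined in this file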
@@ -88,15 +160,20 @@ class Model(nn.Module):
self.tokenizer.add_special_tokens({'pad_token': pad_token})
self.tokenizer.pad_token = pad_token
for param in self.llama.parameters():
for param in self.llm_model.parameters():
param.requires_grad = False
if configs.prompt_domain:
self.description = configs.content
else:
self.description = 'The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment.'
self.dropout = nn.Dropout(configs.dropout)
self.patch_embedding = PatchEmbedding(
configs.d_model, self.patch_len, self.stride, configs.dropout)
self.word_embeddings = self.llama.get_input_embeddings().weight
self.word_embeddings = self.llm_model.get_input_embeddings().weight
self.vocab_size = self.word_embeddings.shape[0]
self.num_tokens = 1000
self.mapping_layer = nn.Linear(self.vocab_size, self.num_tokens)
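
The mapping layer compresses the backbone's full vocabulary of word embeddings into num_tokens = 1000 text prototypes, which forecast() below uses as the source for reprogramming. A shape sketch with illustrative sizes (32000 matches the LLaMA-7B vocabulary; d_llm is the chosen backbone width):

import torch
import torch.nn as nn

vocab_size, num_tokens, d_llm = 32000, 1000, 4096
word_embeddings = torch.randn(vocab_size, d_llm)  # (vocab, d_llm)
mapping_layer = nn.Linear(vocab_size, num_tokens)
# transpose so the Linear mixes across the vocabulary axis, then restore:
source_embeddings = mapping_layer(word_embeddings.permute(1, 0)).permute(1, 0)
assert source_embeddings.shape == (num_tokens, d_llm)  # 1000 prototypes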
@@ -140,7 +217,7 @@ class Model(nn.Module):
median_values_str = str(medians[b].tolist()[0])
lags_values_str = str(lags[b].tolist())
prompt_ = (
f"<|start_prompt|>Dataset description: The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment."
f"<|start_prompt|>Dataset description: {self.description}"
f"Task description: forecast the next {str(self.pred_len)} steps given the previous {str(self.seq_len)} steps information; "
"Input statistics: "
f"min value {min_values_str}, "
@@ -155,7 +232,7 @@ class Model(nn.Module):
x_enc = x_enc.reshape(B, N, T).permute(0, 2, 1).contiguous()
prompt = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).input_ids
prompt_embeddings = self.llama.get_input_embeddings()(prompt.to(x_enc.device)) # (batch, prompt_token, dim)
prompt_embeddings = self.llm_model.get_input_embeddings()(prompt.to(x_enc.device)) # (batch, prompt_token, dim)
source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
@@ -163,7 +240,7 @@ class Model(nn.Module):
enc_out, n_vars = self.patch_embedding(x_enc.to(torch.bfloat16))
enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
llama_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1)
dec_out = self.llama(inputs_embeds=llama_enc_out).last_hidden_state
dec_out = self.llm_model(inputs_embeds=llama_enc_out).last_hidden_state
dec_out = dec_out[:, :, :self.d_ff]
dec_out = torch.reshape(

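The fusion above is a plain sequence concatenation: prompt-token embeddings and reprogrammed patch embeddings share the backbone's hidden width, so the frozen LLM reads them as one sequence. A hedged shape check with illustrative batch size and lengths:

import torch

d_llm = 768                                    # GPT-2 small / BERT-base width
prompt_embeddings = torch.randn(4, 32, d_llm)  # (batch * n_vars, prompt tokens, d_llm)
enc_out = torch.randn(4, 64, d_llm)            # (batch * n_vars, patches, d_llm)
llm_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1)
assert llm_enc_out.shape == (4, 32 + 64, d_llm)  # one fused sequence
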
View File

@@ -80,7 +80,9 @@ parser.add_argument('--activation', type=str, default='gelu', help='activation')
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in encoder')
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
parser.add_argument('--stride', type=int, default=8, help='stride')
parser.add_argument('--prompt_domain', type=int, default=0, help='stride')
parser.add_argument('--prompt_domain', type=int, default=0, help='whether to use the domain-specific prompt description in configs.content')
parser.add_argument('--llm_model', type=str, default='LLAMA', help='LLM model')  # LLAMA, GPT2, BERT
parser.add_argument('--llm_dim', type=int, default=4096, help='LLM model dimension')  # LLaMA-7B: 4096; GPT-2 small: 768; BERT-base: 768
# optimization
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
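
A hedged sanity check of the two new flags in isolation; demo_parser is a standalone stand-in for the script's full parser, and argparse coerces the command-line string through type=int:

import argparse

demo_parser = argparse.ArgumentParser()
demo_parser.add_argument('--llm_model', type=str, default='LLAMA')  # LLAMA, GPT2, BERT
demo_parser.add_argument('--llm_dim', type=int, default=4096)       # LLaMA-7B: 4096; GPT-2 small / BERT-base: 768
args = demo_parser.parse_args(['--llm_model', 'GPT2', '--llm_dim', '768'])
assert (args.llm_model, args.llm_dim) == ('GPT2', 768)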

View File

@@ -77,6 +77,9 @@ parser.add_argument('--output_attention', action='store_true', help='whether to
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
parser.add_argument('--stride', type=int, default=8, help='stride')
parser.add_argument('--prompt_domain', type=int, default=0, help='whether to use the domain-specific prompt description in configs.content')
parser.add_argument('--llm_model', type=str, default='LLAMA', help='LLM model')  # LLAMA, GPT2, BERT
parser.add_argument('--llm_dim', type=int, default=4096, help='LLM model dimension')  # LLaMA-7B: 4096; GPT-2 small: 768; BERT-base: 768
# optimization
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')

View File

@@ -77,7 +77,9 @@ parser.add_argument('--activation', type=str, default='gelu', help='activation')
parser.add_argument('--output_attention', action='store_true', help='whether to output attention in encoder')
parser.add_argument('--patch_len', type=int, default=16, help='patch length')
parser.add_argument('--stride', type=int, default=8, help='stride')
parser.add_argument('--prompt_domain', type=int, default=0, help='stride')
parser.add_argument('--prompt_domain', type=int, default=0, help='whether to use the domain-specific prompt description in configs.content')
parser.add_argument('--llm_model', type=str, default='LLAMA', help='LLM model')  # LLAMA, GPT2, BERT
parser.add_argument('--llm_dim', type=int, default=4096, help='LLM model dimension')  # LLaMA-7B: 4096; GPT-2 small: 768; BERT-base: 768
# optimization
parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')

View File

@@ -18,7 +18,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path electricity.csv \
--model_id ECL_512_96 \
--model $model_name \
--data custom \
--data ECL \
--features M \
--seq_len 512 \
--label_len 48 \
@@ -42,7 +42,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path electricity.csv \
--model_id ECL_512_192 \
--model $model_name \
--data custom \
--data ECL \
--features M \
--seq_len 512 \
--label_len 48 \
@@ -66,7 +66,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path electricity.csv \
--model_id ECL_512_336 \
--model $model_name \
--data custom \
--data ECL \
--features M \
--seq_len 512 \
--label_len 48 \
@@ -90,7 +90,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path electricity.csv \
--model_id ECL_512_720 \
--model $model_name \
--data custom \
--data ECL \
--features M \
--seq_len 512 \
--label_len 48 \

View File

@@ -18,7 +18,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path traffic.csv \
--model_id traffic_512_96 \
--model $model_name \
--data custom \
--data Traffic \
--features M \
--seq_len 512 \
--label_len 48 \
@@ -42,7 +42,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path traffic.csv \
--model_id traffic_512_96 \
--model $model_name \
--data custom \
--data Traffic \
--features M \
--seq_len 512 \
--label_len 48 \
@@ -66,7 +66,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path traffic.csv \
--model_id traffic_512_96 \
--model $model_name \
--data custom \
--data Traffic \
--features M \
--seq_len 512 \
--label_len 48 \
@@ -90,7 +90,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path traffic.csv \
--model_id traffic_512_96 \
--model $model_name \
--data custom \
--data Traffic \
--features M \
--seq_len 512 \
--label_len 720 \

View File

@@ -18,7 +18,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path weather.csv \
--model_id weather_512_96 \
--model $model_name \
--data custom \
--data Weather \
--features M \
--seq_len 512 \
--label_len 48 \
@@ -44,7 +44,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path weather.csv \
--model_id weather_512_192 \
--model $model_name \
--data custom \
--data Weather \
--features M \
--seq_len 512 \
--label_len 48 \
@@ -70,7 +70,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path weather.csv \
--model_id weather_512_336 \
--model $model_name \
--data custom \
--data Weather \
--features M \
--seq_len 512 \
--label_len 48 \
@@ -96,7 +96,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
--data_path weather.csv \
--model_id weather_512_720 \
--model $model_name \
--data custom \
--data Weather \
--features M \
--seq_len 512 \
--label_len 48 \