From e0226701820f54b36d187679d6446f75c59b407b Mon Sep 17 00:00:00 2001
From: "weiming.wsy"
Date: Mon, 18 Mar 2024 14:58:49 +0800
Subject: [PATCH] add GPT2 and BERT for Time-LLM as a general framework &
 update the prompt_bank (Weather, ECL, Traffic)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
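The new --llm_dim flag is consumed as self.d_llm, the width of the frozen
backbone's token embeddings, so it has to match the hidden size of the
checkpoint selected by --llm_model. Expected pairings, assuming the stock
Hugging Face configs of the checkpoints referenced in this patch:

    --llm_model LLAMA --llm_dim 4096    # huggyllama/llama-7b
    --llm_model GPT2  --llm_dim 768     # openai-community/gpt2
    --llm_model BERT  --llm_dim 768     # google-bert/bert-base-uncased
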
 data_provider/data_factory.py |   4 +-
 models/TimeLLM.py             | 173 ++++++++++++++++++++++++----------
 run_m4.py                     |   4 +-
 run_main.py                   |   3 +
 run_pretrain.py               |   4 +-
 scripts/TimeLLM_ECL.sh        |   8 +-
 scripts/TimeLLM_Traffic.sh    |   8 +-
 scripts/TimeLLM_Weather.sh    |   8 +-
 8 files changed, 149 insertions(+), 63 deletions(-)

diff --git a/data_provider/data_factory.py b/data_provider/data_factory.py
index 1e51851..c046f05 100644
--- a/data_provider/data_factory.py
+++ b/data_provider/data_factory.py
@@ -6,7 +6,9 @@ data_dict = {
     'ETTh2': Dataset_ETT_hour,
     'ETTm1': Dataset_ETT_minute,
     'ETTm2': Dataset_ETT_minute,
-    'custom': Dataset_Custom,
+    'ECL': Dataset_Custom,
+    'Traffic': Dataset_Custom,
+    'Weather': Dataset_Custom,
     'm4': Dataset_M4,
 }
 
diff --git a/models/TimeLLM.py b/models/TimeLLM.py
index 46999df..8cbebf6 100644
--- a/models/TimeLLM.py
+++ b/models/TimeLLM.py
@@ -3,11 +3,10 @@ from math import sqrt
 
 import torch
 import torch.nn as nn
-from transformers import LlamaConfig, LlamaModel, LlamaTokenizer
+from transformers import LlamaConfig, LlamaModel, LlamaTokenizer, GPT2Config, GPT2Model, GPT2Tokenizer, BertConfig, \
+    BertModel, BertTokenizer
 from layers.Embed import PatchEmbedding
-
 import transformers
-
 from layers.StandardNorm import Normalize
 
 transformers.logging.set_verbosity_error()
@@ -37,49 +36,122 @@ class Model(nn.Module):
         self.seq_len = configs.seq_len
         self.d_ff = configs.d_ff
         self.top_k = 5
-        self.d_llm = 4096
+        self.d_llm = configs.llm_dim
         self.patch_len = configs.patch_len
         self.stride = configs.stride
 
-        self.llama_config = LlamaConfig.from_pretrained('/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/')
-        # self.llama_config = LlamaConfig.from_pretrained('huggyllama/llama-7b')
-        self.llama_config.num_hidden_layers = configs.llm_layers
-        self.llama_config.output_attentions = True
-        self.llama_config.output_hidden_states = True
-        try:
-            self.llama = LlamaModel.from_pretrained(
-                # "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
-                'huggyllama/llama-7b',
-                trust_remote_code=True,
-                local_files_only=True,
-                config=self.llama_config,
-                # load_in_4bit=True
-            )
-        except EnvironmentError:  # downloads model from HF is not already done
-            print("Local model files not found. Attempting to download...")
-            self.llama = LlamaModel.from_pretrained(
-                # "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
-                'huggyllama/llama-7b',
-                trust_remote_code=True,
-                local_files_only=False,
-                config=self.llama_config,
-                # load_in_4bit=True
-            )
-        try:
-            self.tokenizer = LlamaTokenizer.from_pretrained(
-                # "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/tokenizer.model",
-                'huggyllama/llama-7b',
-                trust_remote_code=True,
-                local_files_only=True
-            )
-        except EnvironmentError:  # downloads the tokenizer from HF if not already done
-            print("Local tokenizer files not found. Atempting to download them..")
-            self.tokenizer = LlamaTokenizer.from_pretrained(
-                # "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/tokenizer.model",
-                'huggyllama/llama-7b',
-                trust_remote_code=True,
-                local_files_only=False
-            )
+        if configs.llm_model == 'LLAMA':
+            # self.llama_config = LlamaConfig.from_pretrained('/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/')
+            self.llama_config = LlamaConfig.from_pretrained('huggyllama/llama-7b')
+            self.llama_config.num_hidden_layers = configs.llm_layers
+            self.llama_config.output_attentions = True
+            self.llama_config.output_hidden_states = True
+            try:
+                self.llm_model = LlamaModel.from_pretrained(
+                    # "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
+                    'huggyllama/llama-7b',
+                    trust_remote_code=True,
+                    local_files_only=True,
+                    config=self.llama_config,
+                    # load_in_4bit=True
+                )
+            except EnvironmentError:  # downloads the model from HF if not already done
+                print("Local model files not found. Attempting to download...")
+                self.llm_model = LlamaModel.from_pretrained(
+                    # "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
+                    'huggyllama/llama-7b',
+                    trust_remote_code=True,
+                    local_files_only=False,
+                    config=self.llama_config,
+                    # load_in_4bit=True
+                )
+            try:
+                self.tokenizer = LlamaTokenizer.from_pretrained(
+                    # "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/tokenizer.model",
+                    'huggyllama/llama-7b',
+                    trust_remote_code=True,
+                    local_files_only=True
+                )
+            except EnvironmentError:  # downloads the tokenizer from HF if not already done
+                print("Local tokenizer files not found. Attempting to download them...")
+                self.tokenizer = LlamaTokenizer.from_pretrained(
+                    # "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/tokenizer.model",
+                    'huggyllama/llama-7b',
+                    trust_remote_code=True,
+                    local_files_only=False
+                )
+        elif configs.llm_model == 'GPT2':
+            self.gpt2_config = GPT2Config.from_pretrained('openai-community/gpt2')
+
+            self.gpt2_config.num_hidden_layers = configs.llm_layers
+            self.gpt2_config.output_attentions = True
+            self.gpt2_config.output_hidden_states = True
+            try:
+                self.llm_model = GPT2Model.from_pretrained(
+                    'openai-community/gpt2',
+                    trust_remote_code=True,
+                    local_files_only=True,
+                    config=self.gpt2_config,
+                )
+            except EnvironmentError:  # downloads the model from HF if not already done
+                print("Local model files not found. Attempting to download...")
+                self.llm_model = GPT2Model.from_pretrained(
+                    'openai-community/gpt2',
+                    trust_remote_code=True,
+                    local_files_only=False,
+                    config=self.gpt2_config,
+                )
+
+            try:
+                self.tokenizer = GPT2Tokenizer.from_pretrained(
+                    'openai-community/gpt2',
+                    trust_remote_code=True,
+                    local_files_only=True
+                )
+            except EnvironmentError:  # downloads the tokenizer from HF if not already done
+                print("Local tokenizer files not found. Attempting to download them...")
+                self.tokenizer = GPT2Tokenizer.from_pretrained(
+                    'openai-community/gpt2',
+                    trust_remote_code=True,
+                    local_files_only=False
+                )
+        elif configs.llm_model == 'BERT':
+            self.bert_config = BertConfig.from_pretrained('google-bert/bert-base-uncased')
+
+            self.bert_config.num_hidden_layers = configs.llm_layers
+            self.bert_config.output_attentions = True
+            self.bert_config.output_hidden_states = True
+            try:
+                self.llm_model = BertModel.from_pretrained(
+                    'google-bert/bert-base-uncased',
+                    trust_remote_code=True,
+                    local_files_only=True,
+                    config=self.bert_config,
+                )
+            except EnvironmentError:  # downloads the model from HF if not already done
+                print("Local model files not found. Attempting to download...")
+                self.llm_model = BertModel.from_pretrained(
+                    'google-bert/bert-base-uncased',
+                    trust_remote_code=True,
+                    local_files_only=False,
+                    config=self.bert_config,
+                )
+
+            try:
+                self.tokenizer = BertTokenizer.from_pretrained(
+                    'google-bert/bert-base-uncased',
+                    trust_remote_code=True,
+                    local_files_only=True
+                )
+            except EnvironmentError:  # downloads the tokenizer from HF if not already done
+                print("Local tokenizer files not found. Attempting to download them...")
+                self.tokenizer = BertTokenizer.from_pretrained(
+                    'google-bert/bert-base-uncased',
+                    trust_remote_code=True,
+                    local_files_only=False
+                )
+        else:
+            raise Exception('LLM model is not defined')
 
         if self.tokenizer.eos_token:
             self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -88,15 +160,20 @@ class Model(nn.Module):
             self.tokenizer.add_special_tokens({'pad_token': pad_token})
             self.tokenizer.pad_token = pad_token
 
-        for param in self.llama.parameters():
+        for param in self.llm_model.parameters():
             param.requires_grad = False
 
+        if configs.prompt_domain:
+            self.description = configs.content
+        else:
+            self.description = 'The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment.'
+
         self.dropout = nn.Dropout(configs.dropout)
 
         self.patch_embedding = PatchEmbedding(
             configs.d_model, self.patch_len, self.stride, configs.dropout)
 
-        self.word_embeddings = self.llama.get_input_embeddings().weight
+        self.word_embeddings = self.llm_model.get_input_embeddings().weight
         self.vocab_size = self.word_embeddings.shape[0]
         self.num_tokens = 1000
         self.mapping_layer = nn.Linear(self.vocab_size, self.num_tokens)
@@ -140,7 +217,7 @@ class Model(nn.Module):
             median_values_str = str(medians[b].tolist()[0])
             lags_values_str = str(lags[b].tolist())
             prompt_ = (
-                f"<|start_prompt|>Dataset description: The Electricity Transformer Temperature (ETT) is a crucial indicator in the electric power long-term deployment."
+ f"<|start_prompt|>Dataset description: {self.description}" f"Task description: forecast the next {str(self.pred_len)} steps given the previous {str(self.seq_len)} steps information; " "Input statistics: " f"min value {min_values_str}, " @@ -155,7 +232,7 @@ class Model(nn.Module): x_enc = x_enc.reshape(B, N, T).permute(0, 2, 1).contiguous() prompt = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).input_ids - prompt_embeddings = self.llama.get_input_embeddings()(prompt.to(x_enc.device)) # (batch, prompt_token, dim) + prompt_embeddings = self.llm_model.get_input_embeddings()(prompt.to(x_enc.device)) # (batch, prompt_token, dim) source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0) @@ -163,7 +240,7 @@ class Model(nn.Module): enc_out, n_vars = self.patch_embedding(x_enc.to(torch.bfloat16)) enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings) llama_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1) - dec_out = self.llama(inputs_embeds=llama_enc_out).last_hidden_state + dec_out = self.llm_model(inputs_embeds=llama_enc_out).last_hidden_state dec_out = dec_out[:, :, :self.d_ff] dec_out = torch.reshape( diff --git a/run_m4.py b/run_m4.py index 33d430b..48ce2e3 100644 --- a/run_m4.py +++ b/run_m4.py @@ -80,7 +80,9 @@ parser.add_argument('--activation', type=str, default='gelu', help='activation') parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder') parser.add_argument('--patch_len', type=int, default=16, help='patch length') parser.add_argument('--stride', type=int, default=8, help='stride') -parser.add_argument('--prompt_domain', type=int, default=0, help='stride') +parser.add_argument('--prompt_domain', type=int, default=0, help='') +parser.add_argument('--llm_model', type=str, default='LLAMA', help='LLM model') +parser.add_argument('--llm_dim', type=int, default='4096', help='LLM model dimension') # optimization parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers') diff --git a/run_main.py b/run_main.py index 04759a4..c364424 100644 --- a/run_main.py +++ b/run_main.py @@ -77,6 +77,9 @@ parser.add_argument('--output_attention', action='store_true', help='whether to parser.add_argument('--patch_len', type=int, default=16, help='patch length') parser.add_argument('--stride', type=int, default=8, help='stride') parser.add_argument('--prompt_domain', type=int, default=0, help='') +parser.add_argument('--llm_model', type=str, default='LLAMA', help='LLM model') +parser.add_argument('--llm_dim', type=int, default='4096', help='LLM model dimension') + # optimization parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers') diff --git a/run_pretrain.py b/run_pretrain.py index e4284a8..1baef4f 100644 --- a/run_pretrain.py +++ b/run_pretrain.py @@ -77,7 +77,9 @@ parser.add_argument('--activation', type=str, default='gelu', help='activation') parser.add_argument('--output_attention', action='store_true', help='whether to output attention in ecoder') parser.add_argument('--patch_len', type=int, default=16, help='patch length') parser.add_argument('--stride', type=int, default=8, help='stride') -parser.add_argument('--prompt_domain', type=int, default=0, help='stride') +parser.add_argument('--prompt_domain', type=int, default=0, help='') +parser.add_argument('--llm_model', type=str, default='LLAMA', help='LLM model') +parser.add_argument('--llm_dim', type=int, 
 
 # optimization
 parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
diff --git a/scripts/TimeLLM_ECL.sh b/scripts/TimeLLM_ECL.sh
index 36abf15..9db39fa 100644
--- a/scripts/TimeLLM_ECL.sh
+++ b/scripts/TimeLLM_ECL.sh
@@ -18,7 +18,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
   --data_path electricity.csv \
   --model_id ECL_512_96 \
   --model $model_name \
-  --data custom \
+  --data ECL \
   --features M \
   --seq_len 512 \
   --label_len 48 \
@@ -42,7 +42,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
   --data_path electricity.csv \
   --model_id ECL_512_192 \
   --model $model_name \
-  --data custom \
+  --data ECL \
   --features M \
   --seq_len 512 \
   --label_len 48 \
@@ -66,7 +66,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
   --data_path electricity.csv \
   --model_id ECL_512_336 \
   --model $model_name \
-  --data custom \
+  --data ECL \
   --features M \
   --seq_len 512 \
   --label_len 48 \
@@ -90,7 +90,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
   --data_path electricity.csv \
   --model_id ECL_512_720 \
   --model $model_name \
-  --data custom \
+  --data ECL \
   --features M \
   --seq_len 512 \
   --label_len 48 \
diff --git a/scripts/TimeLLM_Traffic.sh b/scripts/TimeLLM_Traffic.sh
index e7ec8f5..fb9dca0 100644
--- a/scripts/TimeLLM_Traffic.sh
+++ b/scripts/TimeLLM_Traffic.sh
@@ -18,7 +18,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
   --data_path traffic.csv \
   --model_id traffic_512_96 \
   --model $model_name \
-  --data custom \
+  --data Traffic \
   --features M \
   --seq_len 512 \
   --label_len 48 \
@@ -42,7 +42,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
   --data_path traffic.csv \
   --model_id traffic_512_96 \
   --model $model_name \
-  --data custom \
+  --data Traffic \
   --features M \
   --seq_len 512 \
   --label_len 48 \
@@ -66,7 +66,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
   --data_path traffic.csv \
   --model_id traffic_512_96 \
   --model $model_name \
-  --data custom \
+  --data Traffic \
   --features M \
   --seq_len 512 \
   --label_len 48 \
@@ -90,7 +90,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
   --data_path traffic.csv \
   --model_id traffic_512_96 \
   --model $model_name \
-  --data custom \
+  --data Traffic \
   --features M \
   --seq_len 512 \
   --label_len 720 \
diff --git a/scripts/TimeLLM_Weather.sh b/scripts/TimeLLM_Weather.sh
index b555ce3..046d8b3 100644
--- a/scripts/TimeLLM_Weather.sh
+++ b/scripts/TimeLLM_Weather.sh
@@ -18,7 +18,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
   --data_path weather.csv \
   --model_id weather_512_96 \
   --model $model_name \
-  --data custom \
+  --data Weather \
   --features M \
   --seq_len 512 \
   --label_len 48 \
@@ -44,7 +44,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
   --data_path weather.csv \
   --model_id weather_512_192 \
   --model $model_name \
-  --data custom \
+  --data Weather \
   --features M \
   --seq_len 512 \
   --label_len 48 \
@@ -70,7 +70,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
   --data_path weather.csv \
   --model_id weather_512_336 \
   --model $model_name \
-  --data custom \
+  --data Weather \
   --features M \
   --seq_len 512 \
   --label_len 48 \
@@ -96,7 +96,7 @@ accelerate launch --multi_gpu --mixed_precision bf16 --num_processes $num_proces
   --data_path weather.csv \
   --model_id weather_512_720 \
   --model $model_name \
-  --data custom \
+  --data Weather \
   --features M \
   --seq_len 512 \
   --label_len 48 \
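
The updated scripts still run the LLaMA backbone, since --llm_model defaults
to 'LLAMA'. As a usage sketch (hypothetical invocation, not part of the
patch): switching a Weather run onto GPT-2 needs only the two new flags with
a matching dimension; the entry point and argument values below follow
scripts/TimeLLM_Weather.sh, --llm_layers 6 is an illustrative choice, and
--prompt_domain 1 makes the model read its dataset description from
configs.content (the updated prompt_bank) instead of the ETT default:

accelerate launch --multi_gpu --mixed_precision bf16 --num_processes 8 run_main.py \
  --data_path weather.csv \
  --model_id weather_512_96 \
  --model TimeLLM \
  --data Weather \
  --features M \
  --seq_len 512 \
  --label_len 48 \
  --pred_len 96 \
  --llm_model GPT2 \
  --llm_dim 768 \
  --llm_layers 6 \
  --prompt_domain 1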