From bdd634deeb708240ca59c03c812a33360a7a3a01 Mon Sep 17 00:00:00 2001
From: Hamza Ba-mohammed <62408114+HamBa-m@users.noreply.github.com>
Date: Wed, 6 Mar 2024 05:31:03 +0100
Subject: [PATCH] Update TimeLLM.py

Add a try/except block that downloads the model and tokenizer from the
Hugging Face Hub if they are not already available locally.
---
 models/TimeLLM.py | 37 ++++++++++++++++++++++++++++---------
 1 file changed, 28 insertions(+), 9 deletions(-)

diff --git a/models/TimeLLM.py b/models/TimeLLM.py
index 7c312bc..46999df 100644
--- a/models/TimeLLM.py
+++ b/models/TimeLLM.py
@@ -46,21 +46,40 @@ class Model(nn.Module):
         self.llama_config.num_hidden_layers = configs.llm_layers
         self.llama_config.output_attentions = True
         self.llama_config.output_hidden_states = True
-        self.llama = LlamaModel.from_pretrained(
-            "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
-            # 'huggyllama/llama-7b',
+        try:
+            self.llama = LlamaModel.from_pretrained(
+                # "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
+                'huggyllama/llama-7b',
             trust_remote_code=True,
             local_files_only=True,
             config=self.llama_config,
-            load_in_4bit=True
+                # load_in_4bit=True
         )
-
-        self.tokenizer = LlamaTokenizer.from_pretrained(
-            "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/tokenizer.model",
-            # 'huggyllama/llama-7b',
+        except EnvironmentError:  # downloads the model from HF if not already available locally
+            print("Local model files not found. Attempting to download...")
+            self.llama = LlamaModel.from_pretrained(
+                # "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/",
+                'huggyllama/llama-7b',
             trust_remote_code=True,
-            local_files_only=True
+                local_files_only=False,
+                config=self.llama_config,
+                # load_in_4bit=True
         )
+        try:
+            self.tokenizer = LlamaTokenizer.from_pretrained(
+                # "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/tokenizer.model",
+                'huggyllama/llama-7b',
+                trust_remote_code=True,
+                local_files_only=True
+            )
+        except EnvironmentError:  # downloads the tokenizer from HF if not already available locally
+            print("Local tokenizer files not found. Attempting to download them...")
+            self.tokenizer = LlamaTokenizer.from_pretrained(
+                # "/mnt/alps/modelhub/pretrained_model/LLaMA/7B_hf/tokenizer.model",
+                'huggyllama/llama-7b',
+                trust_remote_code=True,
+                local_files_only=False
+            )

         if self.tokenizer.eos_token:
             self.tokenizer.pad_token = self.tokenizer.eos_token
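
For reference, a minimal standalone sketch of the local-first / download-fallback pattern this patch introduces, using the same transformers classes and keyword arguments that appear in TimeLLM.py. The helper name load_llama_with_fallback, its signature, and the llm_layers default are illustrative only and do not appear in the repository.

    from transformers import LlamaConfig, LlamaModel, LlamaTokenizer

    def load_llama_with_fallback(model_name='huggyllama/llama-7b', llm_layers=8):
        """Load LLaMA weights and tokenizer, preferring the local HF cache.

        from_pretrained raises EnvironmentError (an alias of OSError) when
        local_files_only=True and the files are not cached, so the except
        branch retries with local_files_only=False to download them.
        """
        # Hypothetical config setup mirroring the surrounding TimeLLM.py code.
        config = LlamaConfig.from_pretrained(model_name)
        config.num_hidden_layers = llm_layers
        config.output_attentions = True
        config.output_hidden_states = True

        try:
            # First attempt: use only files already present in the local cache.
            model = LlamaModel.from_pretrained(
                model_name, config=config, local_files_only=True)
        except EnvironmentError:
            print("Local model files not found. Attempting to download...")
            model = LlamaModel.from_pretrained(
                model_name, config=config, local_files_only=False)

        try:
            tokenizer = LlamaTokenizer.from_pretrained(
                model_name, local_files_only=True)
        except EnvironmentError:
            print("Local tokenizer files not found. Attempting to download them...")
            tokenizer = LlamaTokenizer.from_pretrained(
                model_name, local_files_only=False)

        # Same padding setup as the code following the hunk above.
        if tokenizer.eos_token:
            tokenizer.pad_token = tokenizer.eos_token
        return model, tokenizer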