import os

import torch
from peft import PeftModel
from transformers import AutoConfig, AutoModel, AutoTokenizer

from evaluators.chatglm_mixin import ChatGLMMixin
from evaluators.evaluator import Evaluator


class ChatGLM2_Evaluator(Evaluator, ChatGLMMixin):
    def __init__(self, choices, k, model_name, device, finetune=None, finetune_method=None):
        super().__init__(choices, model_name, k)
        self.finetune_method = finetune_method
        self.finetune_name = finetune
        self.tokenizer = AutoTokenizer.from_pretrained(
            "THUDM/chatglm2-6b", trust_remote_code=True, mirror="tuna"
        )

        if finetune_method == "lora":
            # Load the base model in fp16, then attach the LoRA adapter weights.
            self.model = AutoModel.from_pretrained(
                "THUDM/chatglm2-6b", trust_remote_code=True, mirror="tuna",
                resume_download=True
            ).half().to(device)
            peft_model_id = "lora/glm2/" + finetune
            self.model = PeftModel.from_pretrained(self.model, peft_model_id)
            print("Model loaded! use GLM2 + " + finetune)
        elif finetune_method == "ptuning":
            # Load the base model configured with a prefix encoder, then restore
            # the P-Tuning v2 prefix weights from the checkpoint.
            checkpoint_path = "ptuning/glm2/" + finetune
            config = AutoConfig.from_pretrained(
                "THUDM/chatglm2-6b", trust_remote_code=True, pre_seq_len=128
            )
            self.model = AutoModel.from_pretrained(
                "THUDM/chatglm2-6b", config=config, trust_remote_code=True
            )
            prefix_state_dict = torch.load(
                os.path.join(checkpoint_path, "pytorch_model.bin")
            )
            # Keep only the prefix-encoder tensors and strip their module prefix
            # so they match the keys the prefix encoder expects.
            new_prefix_state_dict = {}
            for key, value in prefix_state_dict.items():
                if key.startswith("transformer.prefix_encoder."):
                    new_prefix_state_dict[key[len("transformer.prefix_encoder."):]] = value
            self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
            self.model = self.model.half().to(device)
            # Keep the prefix encoder in fp32 while the rest of the model runs in fp16.
            self.model.transformer.prefix_encoder.float()
            print("Model loaded! use GLM2 + " + finetune)
        else:
            self.model = AutoModel.from_pretrained(
                "THUDM/chatglm2-6b", trust_remote_code=True, mirror="tuna",
                resume_download=True
            ).half().to(device)
            print("Model loaded! (GLM2)")
        # self.model = self.model.eval()
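
# A minimal usage sketch (added for illustration; the argument values below are
# assumptions, not part of the original repo): instantiate the evaluator with
# answer choices and a few-shot count k, optionally pointing it at a finetuned
# checkpoint directory under lora/glm2/ or ptuning/glm2/.
if __name__ == "__main__":
    evaluator = ChatGLM2_Evaluator(
        choices=["A", "B", "C", "D"],  # assumed multiple-choice labels
        k=5,                           # assumed number of few-shot examples
        model_name="chatglm2-6b",
        device="cuda:0",
        finetune=None,                 # e.g. a checkpoint dir name when finetuned
        finetune_method=None,          # "lora", "ptuning", or None for the base model
    )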