|
|
|
import os
|
|
|
|
import re
|
|
|
|
from tqdm import tqdm
|
|
|
|
import torch
|
|
|
|
from transformers import AutoTokenizer, AutoModel, AutoConfig
|
|
|
|
from evaluators.evaluator import Evaluator
|
|
|
|
from evaluators.chatglm_mixin import ChatGLMMixin
|
|
|
|
from peft import PeftModel
|
|
|
|
|
|
|
|
class ChatGLM2_Evaluator(Evaluator, ChatGLMMixin):
    """Evaluator that loads THUDM/chatglm2-6b, optionally with fine-tuned weights.

    Depending on ``finetune_method`` the constructor loads:

    * ``"lora"``    -- the base model wrapped in a PEFT adapter read from
      ``"lora/glm2/" + finetune``;
    * ``"ptuning"`` -- the base model configured with a prefix encoder whose
      weights are read from ``ptuning/glm2/<finetune>/pytorch_model.bin``;
    * anything else -- the plain base model.
    """

    def __init__(self, choices, k, model_name, device, finetune=None, finetune_method=None,
                 pre_seq_len=128):
        """Load the tokenizer and the (optionally fine-tuned) ChatGLM2 model.

        Args:
            choices: Forwarded unchanged to ``Evaluator.__init__``.
            k: Forwarded unchanged to ``Evaluator.__init__``.
            model_name: Forwarded unchanged to ``Evaluator.__init__``.
            device: Target passed to ``.to(device)`` once the model is in half
                precision.
            finetune: Name of the fine-tuned checkpoint directory; appended to
                ``"lora/glm2/"`` or ``"ptuning/glm2/"``. Required when
                ``finetune_method`` is ``"lora"`` or ``"ptuning"``.
            finetune_method: ``"lora"``, ``"ptuning"``, or any other value
                (including ``None``) to use the base model as is.
            pre_seq_len: Prefix length written into the model config for the
                ``"ptuning"`` branch. Defaults to 128, the value that was
                previously hard-coded.

        Raises:
            ValueError: If a fine-tune method is requested without ``finetune``.
        """
        super(ChatGLM2_Evaluator, self).__init__(choices, model_name, k)

        # Fail fast: without this check the error only surfaced as an opaque
        # string-concatenation TypeError, in the LoRA case after the whole base
        # model had already been loaded.
        if finetune_method in ("lora", "ptuning") and finetune is None:
            raise ValueError(
                "finetune must name a checkpoint when finetune_method is "
                f"{finetune_method!r}"
            )

        self.finetune_method = finetune_method
        self.finetune_name = finetune

        base_model = "THUDM/chatglm2-6b"
        self.tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True, mirror="tuna")

        if finetune_method == "lora":
            self.model = AutoModel.from_pretrained(
                base_model,
                trust_remote_code=True,
                mirror="tuna",
                resume_download=True,
            ).half().to(device)
            peft_model_id = "lora/glm2/" + finetune
            # Wrap the already-placed base model with the LoRA adapter weights.
            self.model = PeftModel.from_pretrained(self.model, peft_model_id)
            print("Model loaded! use GLM2" + finetune)
        elif finetune_method == "ptuning":
            checkpoint_path = "ptuning/glm2/" + finetune
            # pre_seq_len in the config makes the remote model code build a
            # prefix encoder (accessed below as transformer.prefix_encoder).
            config = AutoConfig.from_pretrained(base_model, trust_remote_code=True, pre_seq_len=pre_seq_len)
            self.model = AutoModel.from_pretrained(base_model, config=config, trust_remote_code=True)

            # NOTE(security): torch.load unpickles the file; only point this at
            # checkpoints you trust.
            # map_location="cpu" keeps loading independent of the device the
            # checkpoint was saved from; the tensors are only copied into the
            # prefix encoder, which is moved to `device` afterwards.
            prefix_state_dict = torch.load(
                os.path.join(checkpoint_path, "pytorch_model.bin"),
                map_location="cpu",
            )

            # Keep only the prefix-encoder entries and strip their common
            # prefix so the keys match the sub-module's own state dict.
            prefix = "transformer.prefix_encoder."
            new_prefix_state_dict = {
                key[len(prefix):]: value
                for key, value in prefix_state_dict.items()
                if key.startswith(prefix)
            }
            self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

            self.model = self.model.half().to(device)
            # The prefix encoder is cast back to float32 after the half() above.
            self.model.transformer.prefix_encoder.float()
            print("Model loaded! use GLM2 + " + finetune)
        else:
            self.model = AutoModel.from_pretrained(
                base_model,
                trust_remote_code=True,
                mirror="tuna",
                resume_download=True,
            ).half().to(device)
            print("Model loaded!(GLM2)")

        # NOTE(review): the explicit eval() call below was disabled in the
        # original and is left that way; confirm the loaders above already
        # return the model in evaluation mode.
        # self.model = self.model.eval()
|