You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

40 lines
2.2 KiB
Python

import os
import re
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from evaluators.evaluator import Evaluator
from evaluators.chatglm_mixin import ChatGLMMixin
from peft import PeftModel
class ChatGLM2_Evaluator(Evaluator, ChatGLMMixin):
    """Evaluator backed by THUDM/chatglm2-6b, optionally stacking a LoRA or
    P-Tuning v2 fine-tuned checkpoint on top of the base model."""

    def __init__(self, choices, k, model_name, device, finetune=None, finetune_method=None):
        """Load tokenizer and model, applying the requested fine-tune.

        Args:
            choices: answer choices, forwarded to the base ``Evaluator``.
            k: few-shot example count, forwarded to the base ``Evaluator``.
            model_name: model identifier, forwarded to the base ``Evaluator``.
            device: torch device the (half-precision) model is moved to.
            finetune: checkpoint directory name under ``lora/glm2/`` or
                ``ptuning/glm2/``; required when ``finetune_method`` is set.
            finetune_method: ``"lora"``, ``"ptuning"``, or ``None`` for the
                plain base model.
        """
        super(ChatGLM2_Evaluator, self).__init__(choices, model_name, k)
        self.finetune_method = finetune_method
        self.finetune_name = finetune
        self.tokenizer = AutoTokenizer.from_pretrained(
            "THUDM/chatglm2-6b", trust_remote_code=True, mirror="tuna")
        if finetune_method == "lora":
            self.model = AutoModel.from_pretrained(
                "THUDM/chatglm2-6b", trust_remote_code=True, mirror="tuna",
                resume_download=True).half().to(device)
            peft_model_id = "lora/glm2/" + finetune
            self.model = PeftModel.from_pretrained(self.model, peft_model_id)
            # Separator added for consistency with the ptuning branch's message.
            print("Model loaded! use GLM2 + " + finetune)
        elif finetune_method == "ptuning":
            CHECKPOINT_PATH = "ptuning/glm2/" + finetune
            config = AutoConfig.from_pretrained(
                "THUDM/chatglm2-6b", trust_remote_code=True, pre_seq_len=128)
            self.model = AutoModel.from_pretrained(
                "THUDM/chatglm2-6b", config=config, trust_remote_code=True)
            # map_location="cpu" so loading works on CUDA-less hosts even if the
            # checkpoint was saved from GPU; the model is moved to `device` below.
            prefix_state_dict = torch.load(
                os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"),
                map_location="cpu")
            # Keep only the learned prefix-encoder weights, with their scope
            # prefix stripped.  NOTE: loop names deliberately avoid `k`, which
            # the original code shadowed (`k` is this method's parameter).
            scope = "transformer.prefix_encoder."
            new_prefix_state_dict = {
                name[len(scope):]: tensor
                for name, tensor in prefix_state_dict.items()
                if name.startswith(scope)
            }
            self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
            self.model = self.model.half().to(device)
            # Prefix encoder is cast back to fp32 — presumably for numerical
            # stability of the learned prompts; matches the official ChatGLM2
            # P-Tuning loading recipe.
            self.model.transformer.prefix_encoder.float()
            print("Model loaded! use GLM2 + " + finetune)
        else:
            self.model = AutoModel.from_pretrained(
                "THUDM/chatglm2-6b", trust_remote_code=True, mirror="tuna",
                resume_download=True).half().to(device)
            print("Model loaded!(GLM2)")
        # self.model = self.model.eval()