import os

import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from evaluators.evaluator import Evaluator
from evaluators.chatglm_mixin import ChatGLMMixin
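

# ChatGLM-6B evaluator: loads the tokenizer and either the original fp16
# model or a base model patched with a P-Tuning v2 prefix-encoder checkpoint.
# The evaluation loop itself comes from Evaluator / ChatGLMMixin, not this file.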
class ChatGLM_Evaluator(Evaluator, ChatGLMMixin):
    def __init__(self, choices, k, model_name, device='cpu', finetune=None, finetune_method=None):
        super().__init__(choices, model_name, k)
        self.finetune_method = finetune_method
        self.finetune_name = finetune
        self.tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, mirror="tuna")
        if finetune_method == "ptuning":
            # Load the base model, then restore the trained prefix-encoder
            # weights from the P-Tuning v2 checkpoint.
            checkpoint_path = os.path.join("ptuning/glm1", finetune)
            config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
            self.model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
            prefix_state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
            # Keep only the prefix-encoder weights and strip their
            # "transformer.prefix_encoder." key prefix so they match the
            # sub-module's own state dict. (The loop variable is deliberately
            # not named `k`, which would shadow the k-shot argument above.)
            new_prefix_state_dict = {}
            for name, tensor in prefix_state_dict.items():
                if name.startswith("transformer.prefix_encoder."):
                    new_prefix_state_dict[name[len("transformer.prefix_encoder."):]] = tensor
            self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
            # Run the transformer in fp16, but keep the prefix encoder in
            # fp32 for numerical stability.
            self.model = self.model.half().to(device)
            self.model.transformer.prefix_encoder.float()
            print("Model loaded! (GLM + " + finetune + ")")
        else:
            self.model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, mirror="tuna",
                                                   resume_download=True).half().to(device)
            print("Model loaded! (GLM original)")
        # self.model = self.model.eval()
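

# Usage sketch (illustrative only: the `choices` format and the evaluation
# entry points are defined by Evaluator / ChatGLMMixin, which are not shown
# in this file, and "my-ptuning-ckpt" is a hypothetical checkpoint directory
# expected under ptuning/glm1/):
#
#   evaluator = ChatGLM_Evaluator(
#       choices=["A", "B", "C", "D"], k=5, model_name="chatglm-6b",
#       device="cuda", finetune="my-ptuning-ckpt", finetune_method="ptuning",
#   )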