# You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

# 192 lines
# 8.8 KiB
# Python

import os
import re
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from evaluators.evaluator import Evaluator
class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Logits processor that repairs score tensors containing NaN or Inf.

    When any entry of ``scores`` is NaN or infinite, the whole tensor is
    zeroed in place and index 5 of the last dimension is set to ``5e4``, so
    that position dominates the distribution. Otherwise ``scores`` is
    returned untouched.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # isfinite() is False exactly for NaN and +/-Inf entries.
        all_finite = bool(torch.isfinite(scores).all())
        if not all_finite:
            scores.zero_()
            # NOTE(review): index 5 is presumably a safe fallback token id in
            # the ChatGLM vocabulary -- confirm against the tokenizer.
            scores[..., 5] = 5e4
        return scores
class ChatGLM_Evaluator(Evaluator):
    """Evaluate ChatGLM-6B on multiple-choice subjects and free-form QA.

    The methods below read ``self.choices`` (the option labels) and ``self.k``
    (number of few-shot examples, ``-1`` meaning the whole dev set).
    NOTE(review): these are presumably stored by ``Evaluator.__init__``, which
    is defined outside this file -- confirm.
    """

    # Regexes tried in order when the canonical chain-of-thought marker is
    # absent; the first one that matches decides the answer.
    _ANSWER_PATTERNS = (
        r'([ABCD])是正确的',
        r'选项([ABCD])正确',
        r'答案为([ABCD])',
        r'答案是([ABCD])',
        r'答案([ABCD])',
        r'选择([ABCD])',
        r'答案:([ABCD])',
        r'选择答案([ABCD])',
        r'正确答案是([ABCD])',
    )

    # Vocabulary positions read from the first-step logits in generate_dist,
    # in the same order as self.choices.
    # NOTE(review): presumably the ChatGLM-6B token ids of "A", "B", "C", "D"
    # -- verify against the tokenizer before reusing with another model.
    _CHOICE_TOKEN_IDS = (167, 333, 251, 416)

    def __init__(self, choices, k, model_name, device, finetune=None, finetune_method=None):
        """Load the tokenizer and model, optionally with a P-Tuning prefix.

        Args:
            choices: Option labels, forwarded to the base class.
            k: Number of few-shot examples, forwarded to the base class.
            model_name: Model name, forwarded to the base class.
            device: Device the model is moved to.
            finetune: Checkpoint directory name under ``ptuning/glm1/``;
                required when ``finetune_method == "ptuning"``.
            finetune_method: ``"ptuning"`` to load a prefix-encoder
                checkpoint; any other value loads the original weights.

        Raises:
            ValueError: If ``finetune_method == "ptuning"`` but ``finetune``
                is ``None``.
        """
        super(ChatGLM_Evaluator, self).__init__(choices, model_name, k)
        # try adding 'mirror="tuna"' and 'resume_download=True' if facing the 'read timed out' problem
        # or directly clone the model
        self.tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, mirror="tuna")
        if finetune_method == "ptuning":
            if finetune is None:
                raise ValueError('finetune must name a checkpoint directory when finetune_method == "ptuning"')
            checkpoint_path = "ptuning/glm1/" + finetune
            config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
            self.model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
            # Load onto CPU first so a checkpoint saved on a GPU can be read on
            # any host; the model is moved to `device` right after.
            prefix_state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"),
                                           map_location="cpu")
            # Keep only the prefix-encoder weights and strip their key prefix
            # so they match the sub-module's own parameter names.
            key_prefix = "transformer.prefix_encoder."
            new_prefix_state_dict = {
                name[len(key_prefix):]: tensor
                for name, tensor in prefix_state_dict.items()
                if name.startswith(key_prefix)
            }
            self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
            self.model = self.model.half().to(device)
            # The prefix encoder is cast back to float32 after the half() above.
            self.model.transformer.prefix_encoder.float()
            print("Model loaded! use GLM + " + finetune)
        else:
            self.model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, mirror="tuna",
                                                   resume_download=True).half().to(device)
            print("Model loaded! (GLM original)")

    def eval_subject(self, subject_name, test_df, dev_df=None, few_shot=False, cot=False, save_result_dir=None):
        """Evaluate one subject and return the accuracy in percent.

        Args:
            subject_name: Subject name; used in the few-shot prompt and in the
                result file name.
            test_df: Questions to answer; each row is passed to
                ``format_example`` and must carry an ``answer`` column.
            dev_df: Examples for the few-shot history (used when ``few_shot``).
            few_shot: Prepend a few-shot history built from ``dev_df``.
            cot: Use chain-of-thought formatting for the examples.
            save_result_dir: If truthy, result columns are added to ``test_df``
                in place and it is written as CSV into this directory.

        Returns:
            ``100 * correct / len(test_df)``, or ``0.0`` for an empty frame.
        """
        correct_num = 0
        result = []
        score = []
        answer_list = []
        history = self.generate_few_shot_prompt(subject_name, dev_df, cot=cot) if few_shot else []
        # Zero-shot prompts are capped at max_length=300; few-shot prompts use
        # the model's default limit.
        chat_kwargs = {} if few_shot else {"max_length": 300}
        for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
            question = self.format_example(row, include_answer=False, cot=cot)
            response, _ = self.model.chat(self.tokenizer, question, do_sample=False, history=history,
                                          **chat_kwargs)
            response = response.strip()
            # For ChatGLM, we use answer extraction in answer-only mode too.
            # (generate_dist is an alternative zero-shot path that reads the
            # answer from the first-step logits; it is not used here.)
            ans, _ = self.extract_cot_answer(row, response)
            # Compare against the row's own answer: the iterrows() label is
            # not a position, so indexing a list with it is wrong for frames
            # whose index is not 0..n-1.
            correct = 1 if ans == row['answer'] else 0
            correct_num += correct
            if save_result_dir:
                # NOTE(review): the raw response is only recorded in few-shot
                # mode, as in the original code -- confirm this is intended.
                if few_shot:
                    result.append(response)
                score.append(correct)
                answer_list.append(ans)
        total = len(test_df)
        correct_ratio = 100 * correct_num / total if total else 0.0
        if save_result_dir:
            if few_shot:
                test_df['model_output'] = result
            test_df['correctness'] = score
            test_df['model_answer'] = answer_list
            result_file_name = f'{subject_name}_{correct_ratio}_test.csv'
            if few_shot:
                result_file_name = f'{subject_name}_{correct_ratio}_few_shot_test.csv'
            test_df.to_csv(os.path.join(save_result_dir, result_file_name))
        return correct_ratio

    def eval_qa(self, subject_name, qa_df, save_result_dir=None):
        """Answer every ``question`` in ``qa_df`` and store the replies.

        Each response is written into the ``model_output`` column of ``qa_df``
        in place. If ``save_result_dir`` is truthy the frame is also saved as
        ``{subject_name}_qa_test_result.csv`` in that directory.
        """
        history = []
        for row_index, row in tqdm(qa_df.iterrows(), total=len(qa_df)):
            question = row['question']
            response, _ = self.model.chat(self.tokenizer, question, do_sample=False, history=history)
            qa_df.loc[row_index, 'model_output'] = response.strip()
        if save_result_dir:
            result_file_name = f'{subject_name}_qa_test_result.csv'
            qa_df.to_csv(os.path.join(save_result_dir, result_file_name))

    def generate_few_shot_prompt(self, subject, dev_df, cot=False):
        """Build the few-shot chat history from the dev set.

        Returns a list of ``(question_text, answer_text)`` tuples as produced
        by ``format_example``. The first example carries the task instruction.
        ``self.k == -1`` uses every dev row; a ``k`` larger than the dev set
        is clamped to its size, and an empty dev set yields an empty history.
        """
        num_rows = dev_df.shape[0]
        if num_rows == 0:
            return []
        k = num_rows if self.k == -1 else min(self.k, num_rows)
        # The first example is always included, even when k < 1.
        message = [
            self.format_example(dev_df.iloc[0, :], cot=cot,
                                add_prompt=f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n")
        ]
        message.extend(self.format_example(dev_df.iloc[i, :], cot=cot) for i in range(1, k))
        return message

    def format_example(self, line, include_answer=True, cot=False, add_prompt=''):
        """Render one question row as prompt text.

        Args:
            line: Mapping with ``question`` and one entry per choice label
                (plus ``answer``, and ``explanation`` when ``cot`` is set, if
                ``include_answer`` is true).
            include_answer: Return ``(prompt, answer)`` instead of the prompt.
            cot: Make the answer a chain-of-thought explanation.
            add_prompt: Text placed before the question.

        Returns:
            The prompt string, or a ``(prompt, answer)`` tuple when
            ``include_answer`` is true.
        """
        options = ''.join(f'\n{choice}. {line[str(choice)]}' for choice in self.choices)
        example = add_prompt + line['question'] + options + '\n答案:'
        if not include_answer:
            return example
        if cot:
            ans = "让我们一步一步思考,\n" + line["explanation"] + f"\n所以答案是{line['answer']}"
        else:
            ans = line["answer"]
        return example, ans

    def extract_cot_answer(self, line, gen_ans):
        """Extract the chosen option from a model response.

        Strategies are tried in order: the canonical chain-of-thought marker,
        the fallback regexes, a response containing exactly one of the letters
        A-D, and finally a response quoting the text of exactly one choice.

        Returns:
            ``(answer, direct)`` where ``direct`` is ``True`` only when the
            canonical marker matched; ``('-', False)`` if nothing matched.
        """
        marked = re.findall(r'所以答案是(.+?)。', gen_ans, re.M)
        if marked and marked[-1] in self.choices:
            return marked[-1], True
        # RE extraction
        for answer_pattern in self._ANSWER_PATTERNS:
            found = re.search(answer_pattern, gen_ans, re.M)
            if found:
                return found.group(1), False
        # only containing one choice-character
        letters = re.findall(r'[ABCD]', gen_ans, re.M)
        if len(letters) == 1:
            return letters[0], False
        # only containing one choice-context
        quoted = [c for c in self.choices if str(line[f'{c}']) in gen_ans]
        if len(quoted) == 1:
            return quoted[0], False
        return '-', False

    def generate_dist(self, model, tokenizer, query, history, num_beams=1, max_length=2048,
                      do_sample=False, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
        """Pick the choice whose token scores highest at the first generated step.

        Args:
            model: Model exposing ``generate`` and ``device``.
            tokenizer: Tokenizer used to encode the prompt.
            query: Current question text.
            history: List of ``(query, response)`` pairs, or ``None``/empty.
            num_beams, max_length, do_sample, top_p, temperature: Forwarded to
                ``model.generate``.
            logits_processor: Optional processor list; an
                ``InvalidScoreLogitsProcessor`` is appended to it.
            **kwargs: Extra generation arguments; they override the above.

        Returns:
            The element of ``self.choices`` with the highest first-step score
            (the earliest one on ties).
        """
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        # max_length now honours the caller's argument (it was hard-coded to 2048).
        gen_kwargs = {"num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, "max_length": max_length,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        if not history:
            prompt = query
        else:
            rounds = [
                "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
                for i, (old_query, response) in enumerate(history)
            ]
            rounds.append("[Round {}]\n问:{}\n答:".format(len(history), query))
            prompt = "".join(rounds)
        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(model.device)
        outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, **gen_kwargs)
        # Scores of the first generated step for the first (only) sequence.
        score = outputs.scores[0][0].tolist()
        choice_score = [score[token_id] for token_id in self._CHOICE_TOKEN_IDS]
        # max() returns the first maximal index, matching a stable
        # descending sort.
        best_index = max(range(len(choice_score)), key=choice_score.__getitem__)
        return self.choices[best_index]