import os
import re
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList
from evaluators.evaluator import Evaluator


class ChatGLMMixin:
    """Shared evaluation helpers for ChatGLM-family chat models; concrete
    evaluators fill in these attributes before calling the eval_* methods."""

    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.model_name = None
        self.k = None
        self.choices = None
        self.finetune_name = None

    def eval_subject(self, subject_name, test_df, dev_df=None, few_shot=False, cot=False, save_result_dir=None):
        """Evaluate the model on one multiple-choice subject and return the accuracy in percent."""
        correct_num = 0
        result = []
        score = []
        answer_list = []
        if few_shot:
            history = self.generate_few_shot_prompt(subject_name, dev_df, cot=cot)
        else:
            history = self.generate_zero_shot_prompt(is_choice_question=True)
        answers = list(test_df['answer'])
        for row_index, row in tqdm(test_df.iterrows(), total=len(test_df)):
            question = self.format_example(row, include_answer=False, cot=cot)
            if few_shot:
                response, _ = self.model.chat(self.tokenizer, question, max_length=2000,
                                              do_sample=False, history=history)
                response = response.strip()
                ans, direct_extract = self.extract_cot_answer(row, response)
            else:  # zero-shot; the answer is still parsed from the generated text
                response, _ = self.model.chat(self.tokenizer, question, max_length=2000,
                                              do_sample=False, history=history)
                response = response.strip()
                ans, direct_extract = self.extract_cot_answer(row, response)
            if ans == answers[row_index]:
                correct_num += 1
                correct = 1
            else:
                correct = 0
            if save_result_dir:
                result.append(response)
                score.append(correct)
                answer_list.append(ans)
        correct_ratio = 100 * correct_num / len(answers)
        if save_result_dir:
            test_df['model_output'] = result
            test_df['correctness'] = score
            test_df['model_answer'] = answer_list
            result_file_name = f'{subject_name}_{correct_ratio}_test.csv'
            if few_shot:
                result_file_name = f'{subject_name}_{correct_ratio}_few_shot_test.csv'
            test_df.to_csv(os.path.join(save_result_dir, result_file_name))
        return correct_ratio

    def eval_qa(self, subject_name, qa_df, save_result_dir=None):
        """Run open-ended QA over qa_df, storing each response in the 'model_output' column."""
        history = self.generate_zero_shot_prompt(is_choice_question=False)
        for row_index, row in tqdm(qa_df.iterrows(), total=len(qa_df)):
            question = row['question']
            response, _ = self.model.chat(self.tokenizer, question, max_length=2000,
                                          do_sample=False, history=history)
            response = response.strip()
            qa_df.loc[row_index, 'model_output'] = response
            # Streaming variant kept for reference:
            # current_length = 0
            # response = ""
            # for resp, _ in self.model.stream_chat(self.tokenizer, question, max_length=300,
            #                                       do_sample=False, history=history):
            #     print(resp[current_length:], end="", flush=True)
            #     current_length = len(resp)
            #     response = resp
            # print('')
        if save_result_dir and self.finetune_name is not None:
            result_file_name = f'{subject_name}_qa_test_result.csv'
            qa_df.to_csv(os.path.join(save_result_dir, result_file_name))
        return qa_df

    def generate_few_shot_prompt(self, subject, dev_df, cot=False):
        message = []
        k = self.k
        if self.k == -1:
            k = dev_df.shape[0]
        init_example = self.format_example(dev_df.iloc[0, :], cot=cot,
                                           add_prompt=f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n")
        if isinstance(init_example, list):
            message.extend(init_example)
        else:
            message.append(init_example)
        for i in range(1, k):
            example = self.format_example(dev_df.iloc[i, :], cot=cot)
            if isinstance(example, list):
                message.extend(example)
            else:
                message.append(example)
        return message

    def generate_zero_shot_prompt(self, is_choice_question=True):
        if self.model_name == 'chatglm3' and is_choice_question:
            return [{'role': 'user',
                     'content': '接下来会提供给你一些选择题,请选出正确的答案,给出正确的选项即可。'},
                    {'role': 'assistant',
                     'content': '好的,我会尽力解答。'}]
        elif self.model_name == 'chatglm3' and not is_choice_question:
            return [{'role': 'user',
                     'content': '接下来会给你一些汽车领域相关问题,请回答。'},
                    {'role': 'assistant',
                     'content': '好的,我会尽力解答。'}]
        else:
            return []

    def format_example(self, line, include_answer=True, cot=False, add_prompt=''):
        example = add_prompt + line['question']
        # print(example)
        for choice in self.choices:
            example += f'\n{choice}. {line[choice]}'
        example += '\n答案:'
        if include_answer:
            if cot:
                ans = "让我们一步一步思考,\n" + line["explanation"] + f"\n所以答案是{line['answer']}"
            else:
                ans = line["answer"]
            if self.model_name == 'chatglm3':
                m = [{
                    'role': 'user',
                    'content': example
                }, {
                    'role': 'assistant',
                    'content': ans
                }]
            else:
                m = (example, ans)
            return m
        return example

    def extract_cot_answer(self, line, gen_ans):
        """Extract the chosen option from a generated answer; returns (answer, direct_extract)."""
        # Direct extraction from the chain-of-thought conclusion "所以答案是X。".
        m = re.findall(r'所以答案是(.+?)。', gen_ans, re.M)
        if len(m) > 0 and m[-1] in self.choices:
            return m[-1], True
        answer_patterns = [
            r'([ABCD])是正确的',
            r'选项([ABCD])正确',
            r'答案为([ABCD])',
            r'答案是([ABCD])',
            r'答案([ABCD])',
            r'选择([ABCD])',
            r'答案:([ABCD])',
            r'选择答案([ABCD])',
            r'正确答案是([ABCD])'
        ]
        # Regex extraction of the option letter from common answer phrasings.
        for answer_pattern in answer_patterns:
            m = re.search(answer_pattern, gen_ans, re.M)
            if m:
                answer = m.group(1)
                return answer, False
        # Fallback: the output contains exactly one option letter.
        m = re.findall(r'[ABCD]', gen_ans, re.M)
        if len(m) == 1:
            answer = m[0]
            return answer, False
        answer_word_counter = 0
        # Fallback: the text of exactly one option appears in the output.
        for c in self.choices:
            if str(line[c]) in gen_ans:
                answer = c
                answer_word_counter += 1
        if answer_word_counter == 1:
            return answer, False
        return '-', False
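

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original module: the repo normally
    # wires this mixin into concrete Evaluator subclasses, but here it is
    # instantiated directly to keep the example self-contained. The checkpoint
    # id 'THUDM/chatglm3-6b', the tiny in-memory DataFrames, and the example
    # questions are illustrative assumptions; running this requires a CUDA GPU
    # and a downloaded checkpoint.
    import pandas as pd

    evaluator = ChatGLMMixin()
    evaluator.model_name = 'chatglm3'
    evaluator.choices = ['A', 'B', 'C', 'D']
    evaluator.k = -1  # use every row of dev_df as a few-shot example
    evaluator.tokenizer = AutoTokenizer.from_pretrained('THUDM/chatglm3-6b',
                                                        trust_remote_code=True)
    evaluator.model = AutoModel.from_pretrained('THUDM/chatglm3-6b',
                                                trust_remote_code=True).half().cuda().eval()

    columns = ['question', 'A', 'B', 'C', 'D', 'answer', 'explanation']
    dev_df = pd.DataFrame([['发动机机油的主要作用是什么?',
                            '润滑', '制动', '转向', '照明', 'A',
                            '机油在发动机内部形成油膜,减少零件之间的摩擦。']],
                          columns=columns)
    test_df = pd.DataFrame([['轮胎气压长期过低最可能导致什么后果?',
                             '油耗升高', '灯光变暗', '喇叭失灵', '雨刮异响', 'A', '']],
                           columns=columns)

    accuracy = evaluator.eval_subject('汽车维修', test_df, dev_df,
                                      few_shot=True, cot=True,
                                      save_result_dir=None)
    print(f'accuracy: {accuracy:.1f}%')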