You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
179 lines
7.4 KiB
Python
179 lines
7.4 KiB
Python
import os
|
|
import re
|
|
from tqdm import tqdm
|
|
import torch
|
|
from transformers import AutoTokenizer, AutoModel, AutoConfig
|
|
from transformers.generation.logits_process import LogitsProcessor
|
|
from transformers.generation.utils import LogitsProcessorList
|
|
from evaluators.evaluator import Evaluator
|
|
|
|
|
|
class ChatGLMMixin:
    """Evaluation helpers shared by ChatGLM-family evaluators.

    The concrete evaluator is expected to populate the attributes set in
    ``__init__`` before calling any evaluation method:

    * ``tokenizer`` / ``model``: ``self.model.chat(self.tokenizer, ...)`` is
      called for every question and must return a 2-tuple whose first item
      is the generated text.
    * ``model_name``: ``'chatglm3'`` selects role/content message dicts for
      the chat history; any other value selects ``(prompt, answer)`` tuples.
    * ``k``: number of few-shot examples (``-1`` means every dev row).
    * ``choices``: option column names. The fallback extraction regexes only
      recognise A-D, so presumably ``['A', 'B', 'C', 'D']`` (TODO confirm
      against the subclass that sets it).
    * ``finetune_name``: QA results are only saved when this is not None.
    """

    # Fallback answer regexes used by extract_cot_answer. They are tried in
    # this order and the first pattern that matches anywhere in the text wins.
    _ANSWER_PATTERNS = (
        r'([ABCD])是正确的',
        r'选项([ABCD])正确',
        r'答案为([ABCD])',
        r'答案是([ABCD])',
        r'答案([ABCD])',
        r'选择([ABCD])',
        r'答案:([ABCD])',
        r'选择答案([ABCD])',
        r'正确答案是([ABCD])',
    )

    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.model_name = None
        self.k = None
        self.choices = None
        self.finetune_name = None

    def eval_subject(self, subject_name, test_df, dev_df=None, few_shot=False, cot=False, save_result_dir=None):
        """Evaluate the model on one subject's multiple-choice questions.

        Args:
            subject_name: Subject label, used in the few-shot instruction and
                in the result file name.
            test_df: DataFrame with a ``question`` column, one column per
                entry of ``self.choices`` and an ``answer`` column. When
                ``save_result_dir`` is set, ``model_output``, ``correctness``
                and ``model_answer`` columns are added to it in place.
            dev_df: Source of solved examples; only read when ``few_shot``.
            few_shot: Prime the chat history with ``self.k`` dev examples.
            cot: Use the chain-of-thought example format.
            save_result_dir: If given, the annotated ``test_df`` is written
                there as CSV.

        Returns:
            Accuracy in percent, or 0.0 when ``test_df`` is empty.
        """
        if few_shot:
            history = self.generate_few_shot_prompt(subject_name, dev_df, cot=cot)
        else:
            history = self.generate_zero_shot_prompt(is_choice_question=True)

        correct_num = 0
        result = []
        score = []
        answer_list = []
        for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
            question = self.format_example(row, include_answer=False, cot=cot)
            # Pass a copy in case chat() appends to the history it is given;
            # every question must start from the same priming history.
            response, _ = self.model.chat(self.tokenizer, question, max_length=2000,
                                          do_sample=False, history=history.copy())
            response = response.strip()
            # Zero-shot and few-shot answers are both parsed from the text.
            ans, _ = self.extract_cot_answer(row, response)
            # Compare against this row's own label. The index yielded by
            # iterrows() is a label, not a position, so it cannot be used to
            # index a positional list of answers.
            correct = 1 if ans == row['answer'] else 0
            correct_num += correct
            if save_result_dir:
                result.append(response)
            score.append(correct)
            answer_list.append(ans)

        total = len(test_df)
        correct_ratio = 100 * correct_num / total if total else 0.0

        if save_result_dir:
            test_df['model_output'] = result
            test_df['correctness'] = score
            test_df['model_answer'] = answer_list
            suffix = 'few_shot_test' if few_shot else 'test'
            result_file_name = f'{subject_name}_{correct_ratio}_{suffix}.csv'
            test_df.to_csv(os.path.join(save_result_dir, result_file_name))

        return correct_ratio

    def eval_qa(self, subject_name, qa_df, save_result_dir=None):
        """Generate a free-form answer for every row of ``qa_df``.

        Each response (stripped) is written into ``qa_df['model_output']``
        in place.

        Args:
            subject_name: Used in the result file name.
            qa_df: DataFrame with a ``question`` column.
            save_result_dir: Output directory. Results are only written when
                this is set AND ``self.finetune_name`` is not None.

        Returns:
            The same ``qa_df`` object, annotated.
        """
        history = self.generate_zero_shot_prompt(is_choice_question=False)
        for row_index, row in tqdm(qa_df.iterrows(), total=len(qa_df)):
            # Copy so each question starts from the same priming history.
            response, _ = self.model.chat(self.tokenizer, row['question'], max_length=2000,
                                          do_sample=False, history=history.copy())
            qa_df.loc[row_index, 'model_output'] = response.strip()

        if save_result_dir and self.finetune_name is not None:
            result_file_name = f'{subject_name}_qa_test_result.csv'
            qa_df.to_csv(os.path.join(save_result_dir, result_file_name))

        return qa_df

    def generate_few_shot_prompt(self, subject, dev_df, cot=False):
        """Build a chat history from the first ``self.k`` rows of ``dev_df``.

        ``self.k == -1`` uses every row. A ``k`` larger than ``dev_df`` is
        clamped to its length. Row 0 is always included because it carries
        the subject instruction, even when ``k <= 1``.

        Returns:
            For ``chatglm3``, a flat list of role/content dicts (two per
            example). Otherwise, a list of ``(prompt, answer)`` tuples.

        Raises:
            IndexError: if ``dev_df`` has no rows.
        """
        n_rows = dev_df.shape[0]
        k = n_rows if self.k == -1 else min(self.k, n_rows)
        instruction = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n"

        message = []
        for i in range(max(k, 1)):
            example = self.format_example(dev_df.iloc[i, :], cot=cot,
                                          add_prompt=instruction if i == 0 else '')
            if isinstance(example, list):
                message.extend(example)
            else:
                message.append(example)
        return message

    def generate_zero_shot_prompt(self, is_choice_question=True):
        """Return the priming history for zero-shot prompting.

        Only ``chatglm3`` gets a priming exchange: a user instruction (for
        multiple-choice or free-form QA) followed by an assistant
        acknowledgement. Every other model starts from an empty history.
        """
        if self.model_name != 'chatglm3':
            return []
        if is_choice_question:
            instruction = '接下来会提供给你一些选择题,请选出正确的答案,给出正确的选项即可。'
        else:
            # NOTE(review): the prompt repeats '一些'. It is kept verbatim
            # because editing it changes what is sent to the model.
            instruction = '接下来会给你一些一些汽车领域相关问题,请回答。'
        return [{'role': 'user', 'content': instruction},
                {'role': 'assistant', 'content': '好的,我会尽力解答。'}]

    def format_example(self, line, include_answer=True, cot=False, add_prompt=''):
        """Render one question row as a prompt, optionally with its answer.

        Args:
            line: Row providing ``question``, one value per ``self.choices``
                entry and, when ``include_answer``, ``answer`` (plus
                ``explanation`` when ``cot``).
            include_answer: Return a solved example instead of a bare prompt.
            cot: Use the chain-of-thought answer template.
            add_prompt: Text prepended to the question.

        Returns:
            The prompt string when ``include_answer`` is false. Otherwise, a
            ``[user, assistant]`` message list for ``chatglm3``, or a
            ``(prompt, answer)`` tuple for other models.
        """
        options = ''.join(f'\n{choice}. {line[f"{choice}"]}' for choice in self.choices)
        example = add_prompt + line['question'] + options + '\n答案:'
        if not include_answer:
            return example

        if cot:
            ans = "让我们一步一步思考,\n" + line["explanation"] + f"\n所以答案是{line['answer']}。"
        else:
            ans = line["answer"]

        if self.model_name == 'chatglm3':
            return [{'role': 'user', 'content': example},
                    {'role': 'assistant', 'content': ans}]
        return (example, ans)

    def extract_cot_answer(self, line, gen_ans):
        """Extract the chosen option from generated text.

        Strategies, in order:
          1. The last ``所以答案是X。`` whose ``X`` is in ``self.choices``
             (the chain-of-thought template used by ``format_example``).
          2. The first entry of ``_ANSWER_PATTERNS`` that matches.
          3. The text contains exactly one A-D character occurrence.
          4. Exactly one option's text (``line[choice]``) appears in it.

        Returns:
            ``(answer, direct_extract)``. ``direct_extract`` is True only for
            strategy 1. ``answer`` is ``'-'`` when nothing was extracted.
        """
        m = re.findall(r'所以答案是(.+?)。', gen_ans, re.M)
        if m and m[-1] in self.choices:
            return m[-1], True

        for answer_pattern in self._ANSWER_PATTERNS:
            match = re.search(answer_pattern, gen_ans, re.M)
            if match:
                return match.group(1), False

        # Counts occurrences, not distinct letters: "A ... A" does not qualify.
        letters = re.findall(r'[ABCD]', gen_ans, re.M)
        if len(letters) == 1:
            return letters[0], False

        matched = [c for c in self.choices if str(line[f'{c}']) in gen_ans]
        if len(matched) == 1:
            return matched[0], False

        return '-', False
|