You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
48 lines
1.5 KiB
Python
48 lines
1.5 KiB
Python
import re
|
|
import string
|
|
class Evaluator:
    """Base helper for building multiple-choice exam prompts and scoring answers.

    Holds the option labels, model name and few-shot count. It provides prompt
    formatting and answer-normalization utilities. ``eval_subject`` is a no-op
    here; presumably model-specific subclasses override it (confirm against
    subclasses).
    """

    def __init__(self, choices, model_name, k=-1):
        """Store evaluation settings.

        Args:
            choices: Iterable of option labels. Each label is printed before its
                option text and is also used (via ``f"{label}"``) as the key that
                looks up that option's text in a question row.
            model_name: Name of the model under evaluation (only stored here).
            k: Number of few-shot examples taken from ``dev_df`` in
                ``generate_few_shot_prompt``; ``-1`` means "use every row".
        """
        self.choices = choices
        self.model_name = model_name
        self.k = k
        # ASCII punctuation characters removed by normalize_answer.
        self.puncs = list(string.punctuation)

    def format_example(self, line, include_answer=True):
        """Render one question row as prompt text.

        Args:
            line: Row-like object indexable by ``'question'``, by each choice
                label, and (when ``include_answer``) by ``'answer'``.
            include_answer: When True, append the answer followed by a blank
                line so the text can serve as a solved few-shot example.

        Returns:
            The question, one ``"\\n<label>. <text>"`` line per choice, then the
            answer cue, optionally followed by the answer.
        """
        parts = [line['question']]
        parts.extend(f'\n{choice}. {line[f"{choice}"]}' for choice in self.choices)
        parts.append('\n答案:')
        if include_answer:
            parts.append(f'{line["answer"]}\n\n')
        return ''.join(parts)

    def generate_few_shot_prompt(self, subject, dev_df):
        """Build the instruction header plus up to ``self.k`` solved examples.

        Args:
            subject: Subject name interpolated into the instruction header.
            dev_df: Table of solved examples, read via ``.shape[0]`` and
                ``.iloc[i, :]`` (presumably a pandas DataFrame).

        Returns:
            The header followed by the first ``k`` rows of ``dev_df`` formatted
            with their answers. ``self.k == -1`` uses every row. A ``k`` larger
            than the table is clamped to the table size instead of raising
            ``IndexError``.
        """
        prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n"
        n_rows = dev_df.shape[0]
        # -1 means "all rows"; otherwise clamp so iloc never goes out of range.
        k = n_rows if self.k == -1 else min(self.k, n_rows)
        for i in range(k):
            prompt += self.format_example(dev_df.iloc[i, :])
        return prompt

    def eval_subject(self, subject_name, test_df, dev_df=None, few_shot=False, save_result_dir=None):
        """Evaluate one subject; a no-op returning None in this base class.

        NOTE(review): presumably overridden by model-specific subclasses.
        """
        pass

    def normalize_answer(self, s):
        """Normalize an answer string for comparison.

        Lower-cases ``s``, removes every character in ``self.puncs``, and
        collapses runs of whitespace into single spaces (also trimming the ends).
        """
        exclude = set(self.puncs)
        text = s.lower()
        text = ''.join(ch for ch in text if ch not in exclude)
        return ' '.join(text.split())

    def exact_match(self, pred, target):
        """Return True when ``pred`` and ``target`` are equal after normalize_answer."""
        return self.normalize_answer(pred) == self.normalize_answer(target)
|