You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

38 lines
1.8 KiB
Python

from scoring.gpt_scorer import GPTScorer
from scoring.rogue_scorer import get_rouge_score
import pandas as pd
import time
from tqdm import tqdm
class AssessmentEngine:
    """Score a model's QA results with ROUGE-1 and a GPT-based accuracy judge.

    Results are read from and written to ``logs/<save_result_dir>/``.

    Attributes:
        save_result_dir: name of the sub-directory of ``logs/`` that holds the
            QA result CSVs and receives the scored output CSV.
        gpt_scorer: ``GPTScorer`` instance used to obtain the accuracy score.
    """

    def __init__(self, save_result_dir, api_key):
        """Create the engine.

        Args:
            save_result_dir: directory name under ``logs/`` to read from and
                write to.
            api_key: API key passed straight through to ``GPTScorer``.
        """
        self.save_result_dir = save_result_dir
        self.gpt_scorer = GPTScorer(api_key)

    def _log_path(self, file_name):
        """Return the path ``logs/<save_result_dir>/<file_name>``.

        Built by plain string concatenation (not ``os.path.join``) so the
        resulting path is exactly what the original code produced.
        """
        return 'logs/' + self.save_result_dir + '/' + file_name

    def eval_subject(self, subject_name, csv_file_name):
        """Score every QA row of one subject and write the scored CSV.

        The input CSV must contain the columns ``question``, ``model_output``
        and ``answer``. For each row this adds:

        * ``rouge_score`` - ROUGE-1 F score of ``model_output`` vs ``answer``.
        * ``gpt_score_acc`` / ``gpt_response_acc`` - score and raw response
          returned by ``GPTScorer.score_with_chatgpt`` in "accuracy" mode.

        The result is written to
        ``logs/<save_result_dir>/<subject_name>_qa_test_score_<mean>.csv``,
        where ``<mean>`` is the average ROUGE-1 F score over all rows (0.0
        when the input has no rows).

        Args:
            subject_name: prefix used for the output file name.
            csv_file_name: name of the input CSV inside the results directory.
        """
        qa_result_df = pd.read_csv(self._log_path(csv_file_name))
        row_count = len(qa_result_df)

        start_time = time.time()
        rouge_score_sum = 0
        for row_index, row in tqdm(qa_result_df.iterrows(), total=row_count):
            test_question = row['question']
            model_response = row['model_output']
            reference_answer = row['answer']

            # Lexical-overlap metric: keep only the ROUGE-1 F score.
            rouge_1_f_score = get_rouge_score(model_response, reference_answer)['rouge-1']['f']
            rouge_score_sum += rouge_1_f_score
            qa_result_df.loc[row_index, 'rouge_score'] = rouge_1_f_score

            # NOTE(review): the mode is (re)set on every row as in the original,
            # in case score_with_chatgpt changes it; hoist if GPTScorer
            # guarantees the mode is sticky.
            self.gpt_scorer.mode("accuracy")
            gpt_score_acc, gpt_response_acc = self.gpt_scorer.score_with_chatgpt(
                test_question, model_response, reference_answer)
            qa_result_df.loc[row_index, 'gpt_score_acc'] = gpt_score_acc
            qa_result_df.loc[row_index, 'gpt_response_acc'] = gpt_response_acc

        elapsed_time = time.time() - start_time
        # "Evaluated <n> results in total, total time elapsed:" <seconds>
        print("共评估结果" + str(row_count) + "条,总共用时:", elapsed_time, "")

        # An empty input CSV previously raised ZeroDivisionError here and the
        # output file was never written; report a mean of 0.0 instead.
        synthesis_score = rouge_score_sum / row_count if row_count else 0.0
        qa_result_df.to_csv(
            self._log_path(subject_name + '_qa_test_score_' + str(synthesis_score) + '.csv'),
            index=False)