增加评分代码
parent
c06f9a3684
commit
148a9e1de0
@ -1,5 +1,3 @@
|
||||
# rogue
|
||||
from rogue import get_rouge_score
|
||||
|
||||
|
||||
|
||||
|
@ -0,0 +1,37 @@
|
||||
from scoring.gpt_scorer import GPTScorer
|
||||
from scoring.rogue_scorer import get_rouge_score
|
||||
import pandas as pd
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class AssessmentEngine:
    """Score saved question-answering results with ROUGE and a GPT scorer.

    Result CSV files are read from ``logs/<save_result_dir>/`` and the scored
    copy is written back into that same directory.
    """

    def __init__(self, save_result_dir, api_key):
        """Store the result directory name and build the GPT scorer.

        Args:
            save_result_dir: Name of the sub-directory of ``logs/`` that holds
                the result CSV files.
            api_key: Credential passed unchanged to ``GPTScorer``.
        """
        self.save_result_dir = save_result_dir
        self.gpt_scorer = GPTScorer(api_key)

    @staticmethod
    def _average_score(score_sum, row_count):
        """Return ``score_sum / row_count``, or ``0.0`` when there are no rows."""
        if row_count == 0:
            # A CSV with a header but no data rows must not abort the whole
            # evaluation run with ZeroDivisionError.
            return 0.0
        return score_sum / row_count

    def eval_subject(self, subject_name, csv_file_name):
        """Score every row of one result CSV and save the annotated table.

        For each row the ``question``, ``model_output`` and ``answer`` columns
        are read; the ROUGE-1 F value and the GPT scorer's score/response are
        stored in the new columns ``rouge_score``, ``gpt_score_acc`` and
        ``gpt_response_acc``. The table is then written to
        ``logs/<save_result_dir>/<subject_name>_qa_test_score_<mean>.csv``,
        where ``<mean>`` is the mean ROUGE-1 F value over all rows (``0.0``
        for a table without rows).

        Args:
            subject_name: Prefix used for the output file name.
            csv_file_name: Name of the input CSV inside the result directory.
        """
        result_dir = 'logs/' + self.save_result_dir + '/'
        qa_result_df = pd.read_csv(result_dir + csv_file_name)

        start_time = time.time()
        row_count = 0
        rouge_score_sum = 0
        for row_index, row in tqdm(qa_result_df.iterrows(), total=len(qa_result_df)):
            row_count += 1
            test_question = row['question']
            model_response = row['model_output']
            reference_answer = row['answer']

            # Lexical overlap between the model output and the reference.
            rouge_score = get_rouge_score(model_response, reference_answer)
            rouge_1_f_score = rouge_score['rouge-1']['f']
            rouge_score_sum += rouge_1_f_score
            qa_result_df.loc[row_index, 'rouge_score'] = rouge_1_f_score

            # Kept inside the loop as in the original; presumably selects the
            # scorer's "accuracy" mode — confirm against GPTScorer.
            self.gpt_scorer.mode("accuracy")
            gpt_score_acc, gpt_response_acc = self.gpt_scorer.score_with_chatgpt(
                test_question, model_response, reference_answer)
            qa_result_df.loc[row_index, 'gpt_score_acc'] = gpt_score_acc
            qa_result_df.loc[row_index, 'gpt_response_acc'] = gpt_response_acc

        elapsed_time = time.time() - start_time
        print("共评估结果" + str(row_count) + "条,总共用时:", elapsed_time, "秒")

        # The mean ROUGE-1 F value is embedded in the output file name.
        synthesis_score = self._average_score(rouge_score_sum, row_count)
        qa_result_df.to_csv(result_dir + subject_name + '_qa_test_score_'
                            + str(synthesis_score) + '.csv', index=False)
|
@ -0,0 +1,7 @@
|
||||
import os

from scoring.assessment_engine import AssessmentEngine

# SECURITY: the API key used to be hard-coded here. Secrets must not live in
# source control, so it is now read from the environment. Any key that was
# ever committed should be treated as leaked and revoked.
# NOTE(review): the variable name OPENAI_API_KEY is an assumption — confirm
# it matches how the deployment provides the credential.
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    raise SystemExit(
        "Set the OPENAI_API_KEY environment variable before running the scoring script."
    )

# Score the saved QA results of one model run, one subject after another.
assessment_engine = AssessmentEngine("chatglm2_glm2_pt1_2024-03-08_11-24-47", api_key)
for subject_name in ("car_knowledge", "car_use", "car_market"):
    assessment_engine.eval_subject(subject_name, subject_name + "_qa_test_result.csv")
|
Loading…
Reference in New Issue