import os
import argparse
import pandas as pd
import torch
from evaluators.chatgpt import ChatGPT_Evaluator
from evaluators.chatglm import ChatGLM_Evaluator
from evaluators.chatglm2 import ChatGLM2_Evaluator
from evaluators.chatglm3 import ChatGLM3_Evaluator
import time
from scoring.assessment_engine import AssessmentEngine
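
# Candidate answer labels for the multiple-choice questions; evaluation defaults to CPU
# unless --cuda_device selects a GPU.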
choices = ["A", "B", "C", "D"]
device = torch.device("cpu")


def main(args):
    global device
    evaluator_class = None
    if args.cuda_device:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_device
        device = torch.device("cuda")
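
    # Choose the evaluator implementation that matches the requested model name.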
if "turbo" in args.model_name or "gpt-4" in args.model_name:
evaluator = ChatGPT_Evaluator(
choices=choices,
k=args.ntrain,
api_key=args.openai_key,
model_name=args.model_name
)
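    # Note: "chatglm3" must be checked before "chatglm2" and "chatglm" because these are substring matches.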
elif "chatglm3" in args.model_name:
if args.finetune:
fine_tune_model = args.finetune
evaluator_class = ChatGLM3_Evaluator
else:
fine_tune_model = None
evaluator = ChatGLM3_Evaluator(
choices=choices,
k=args.ntrain,
model_name=args.model_name,
device=device,
finetune=fine_tune_model,
finetune_method=args.finetune_method
)
elif "chatglm2" in args.model_name:
if args.finetune:
fine_tune_model = args.finetune
evaluator_class = ChatGLM2_Evaluator
else:
fine_tune_model = None
evaluator = ChatGLM2_Evaluator(
choices=choices,
k=args.ntrain,
model_name=args.model_name,
device=device,
finetune=fine_tune_model,
finetune_method=args.finetune_method
)
elif "chatglm" in args.model_name:
if args.finetune:
fine_tune_model = args.finetune
evaluator_class = ChatGLM_Evaluator
else:
fine_tune_model = None
evaluator = ChatGLM_Evaluator(
choices=choices,
k=args.ntrain,
model_name=args.model_name,
device=device,
finetune=fine_tune_model,
finetune_method=args.finetune_method
)
else:
print("Unknown model name")
return -1
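
    # Results for this run are written to logs/<model_name>_<finetune-or-original>/<timestamp>.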
    if not os.path.exists(r"logs"):
        os.mkdir(r"logs")
    run_date = time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime(time.time()))
    if args.finetune:
        fine_tune_model_name = args.finetune
    else:
        fine_tune_model_name = 'original'
    save_result_dir = os.path.join(r"logs", f"{args.model_name}_{fine_tune_model_name}/{run_date}")
    os.makedirs(save_result_dir)
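
    # Multiple-choice subjects to evaluate; the *_in_train splits appear to reuse items seen during fine-tuning.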
    subject_list = ['computer_architecture', 'car_knowledge', 'car_use', 'car_market']
    subject_list.extend(['car_knowledge_in_train', 'car_use_in_train', 'car_market_in_train'])
    # qa_subject_list = ['car_knowledge', 'car_use', 'car_market']
    qa_subject_list = ['car_market']
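
    # Evaluate each subject on its validation split, adding dev-set exemplars when --few_shot is set.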
    for subject_name in subject_list:
        print("Now testing: " + subject_name)
        # subject_name = args.subject
        val_file_path = os.path.join('data/val', f'{subject_name}_val.csv')
        val_df = pd.read_csv(val_file_path)
        if args.few_shot:
            dev_file_path = os.path.join('data/dev', f'{subject_name}_dev.csv')
            dev_df = pd.read_csv(dev_file_path)
            correct_ratio = evaluator.eval_subject(subject_name, val_df, dev_df, few_shot=args.few_shot,
                                                   save_result_dir=save_result_dir, cot=args.cot)
        else:
            correct_ratio = evaluator.eval_subject(subject_name, val_df, few_shot=args.few_shot,
                                                   save_result_dir=save_result_dir)
        print("Acc:", correct_ratio)
    result_list = []
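
    # The disabled block below compares QA outputs of the fine-tuned and original models,
    # writes per-subject comparison CSVs, and scores the differences with AssessmentEngine.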
    # for subject_name in qa_subject_list:
    #     print("Now testing: " + subject_name)
    #     qa_file_path = os.path.join('data/qa', f'{subject_name}_qa.csv')
    #     qa_df = pd.read_csv(qa_file_path)
    #     result_list.append(evaluator.eval_qa(subject_name, qa_df, save_result_dir=save_result_dir))
    # if evaluator_class is not None:
    #     del evaluator
    #     evaluator = evaluator_class(
    #         choices=choices,
    #         k=args.ntrain,
    #         model_name=args.model_name,
    #         device=device
    #     )
    # for index, subject_name in enumerate(qa_subject_list):
    #     print("Now testing (origin): " + subject_name)
    #     qa_file_path = os.path.join('data/qa', f'{subject_name}_qa.csv')
    #     qa_df = pd.read_csv(qa_file_path)
    #     origin_result = evaluator.eval_qa(subject_name, qa_df, save_result_dir=save_result_dir)
    #     origin_result = origin_result.rename(columns={"model_output": "predict_origin"})
    #     result_df = result_list[index].rename(columns={"model_output": "predict_finetune"}).join(origin_result["predict_origin"])
    #     result_file_name = f'{subject_name}_qa_compare_result.csv'
    #     result_df.to_csv(os.path.join(save_result_dir, result_file_name))
    # assessment_engine = AssessmentEngine(save_result_dir, args.api_key)
    # for subject_name in qa_subject_list:
    #     assessment_engine.eval_result_diff(f'{subject_name}_qa_compare_result.csv')
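

# Example invocation (the script name and flag values below are illustrative, not prescribed by this repo):
#   python run_eval.py --model_name chatglm3-6b --cuda_device 0 --few_shot --ntrain 5 \
#       --finetune <finetuned-checkpoint> --finetune_method lora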
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ntrain", "-k", type=int, default=5)
    parser.add_argument("--openai_key", type=str, default="xxx")
    parser.add_argument("--api_key", type=str, default="xxx")
    parser.add_argument("--llm_engine", type=str, default="gemini")
    parser.add_argument("--minimax_group_id", type=str, default="xxx")
    parser.add_argument("--minimax_key", type=str, default="xxx")
    parser.add_argument("--few_shot", action="store_true")
    parser.add_argument("--model_name", type=str)
    parser.add_argument("--cot", action="store_true")
    # parser.add_argument("--subject", "-s", type=str, default="operating_system")
    parser.add_argument("--cuda_device", type=str)
    parser.add_argument("--finetune", type=str)
    parser.add_argument("--finetune_method", type=str)
    user_args = parser.parse_args()
    main(user_args)