You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

100 lines
3.6 KiB
Python

import os
import argparse
import pandas as pd
import torch
from evaluators.chatgpt import ChatGPT_Evaluator
from evaluators.chatglm import ChatGLM_Evaluator
from evaluators.chatglm2 import ChatGLM_Evaluator as ChatGLM2_Evaluator
import time
# Option labels handed to every evaluator as ``choices`` (presumably the
# answer letters of each four-option question -- confirm against evaluators).
choices = list("ABCD")
def main(args):
if "turbo" in args.model_name or "gpt-4" in args.model_name:
# print("Not supported yet")
# return -1
evaluator=ChatGPT_Evaluator(
choices=choices,
k=args.ntrain,
api_key=args.openai_key,
model_name=args.model_name
)
elif "chatglm2" in args.model_name:
if args.cuda_device:
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_device
device = torch.device("cuda")
if args.finetune:
fine_tune_model = args.finetune
else:
fine_tune_model = None
evaluator = ChatGLM2_Evaluator(
choices=choices,
k=args.ntrain,
model_name=args.model_name,
device=device,
finetune=fine_tune_model
)
elif "chatglm" in args.model_name:
if args.cuda_device:
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_device
device = torch.device("cuda")
if args.finetune:
fine_tune_model = args.finetune
else:
fine_tune_model = None
evaluator = ChatGLM_Evaluator(
choices=choices,
k=args.ntrain,
model_name=args.model_name,
device=device,
finetune=fine_tune_model
)
else:
print("Unknown model name")
return -1
if not os.path.exists(r"logs"):
os.mkdir(r"logs")
run_date = time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime(time.time()))
if args.finetune:
fine_tune_model_name = args.finetune
else:
fine_tune_model_name = 'original'
save_result_dir = os.path.join(r"logs", f"{args.model_name}_{fine_tune_model_name}_{run_date}")
os.mkdir(save_result_dir)
subject_list = ['computer_architecture', 'car_knowledge', 'car_use', 'car_market']
for subject_name in subject_list:
print(subject_name)
# subject_name=args.subject
val_file_path = os.path.join('data/val', f'{subject_name}_val.csv')
val_df = pd.read_csv(val_file_path)
if args.few_shot:
dev_file_path = os.path.join('data/dev', f'{subject_name}_dev.csv')
dev_df = pd.read_csv(dev_file_path)
correct_ratio = evaluator.eval_subject(subject_name, val_df, dev_df, few_shot=args.few_shot,
save_result_dir=save_result_dir, cot=args.cot)
else:
correct_ratio = evaluator.eval_subject(subject_name, val_df, few_shot=args.few_shot,
save_result_dir=save_result_dir)
print("Acc:", correct_ratio)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate a model on the configured subject CSV files."
    )
    parser.add_argument("--ntrain", "-k", type=int, default=5,
                        help="value passed to the evaluator as k")
    # Fall back to the conventional environment variable so the key does not
    # have to appear on the command line (and thus in shell history).
    parser.add_argument("--openai_key", type=str,
                        default=os.environ.get("OPENAI_API_KEY", "xxx"),
                        help="API key for turbo/gpt-4 models "
                             "(default: $OPENAI_API_KEY)")
    # The two MiniMax options are accepted for CLI compatibility but are not
    # read by main() in this file.
    parser.add_argument("--minimax_group_id", type=str, default="xxx")
    parser.add_argument("--minimax_key", type=str, default="xxx")
    parser.add_argument("--few_shot", action="store_true",
                        help="also load data/dev/<subject>_dev.csv for prompts")
    # Required: main() dispatches on this value and cannot run without it.
    parser.add_argument("--model_name", type=str, required=True,
                        help="model identifier; must contain 'turbo', 'gpt-4', "
                             "'chatglm2' or 'chatglm'")
    parser.add_argument("--cot", action="store_true",
                        help="forwarded to eval_subject in few-shot mode")
    parser.add_argument("--cuda_device", type=str,
                        help="value for CUDA_VISIBLE_DEVICES (ChatGLM models)")
    parser.add_argument("--finetune", type=str,
                        help="fine-tuned weights for ChatGLM models")
    args = parser.parse_args()
    # main() returns -1 for an unrecognised model; report that through the
    # exit status so calling scripts can detect the failure.
    if main(args) == -1:
        raise SystemExit(1)