You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

29 lines
1.3 KiB
Python

7 months ago
import pandas as pd
import os
def merge_df(save_result_dir, df1_name, df2_name, df1_row_name, df1_row_new_name, df2_row_name, df2_row_new_name, file_type='csv'):
    """Merge one column of a second result file into a first and save as CSV.

    Both input files are read from ``save_result_dir`` with the reader chosen
    by ``file_type``. In the first table, column ``df1_row_name`` is renamed to
    ``df1_row_new_name``. From the second table only column ``df2_row_name``
    is kept; it is renamed to ``df2_row_new_name`` and left-joined onto the
    first table. The join aligns rows by index label. When both readers
    produce a default RangeIndex, this pairs rows by position. Rows of the
    first table with no matching label get NaN.

    The merged table (index included) is written to
    ``<save_result_dir>/<df1_name>_<df2_name>_merge.csv``.

    Args:
        save_result_dir: Directory holding both input files; the output is
            written here too.
        df1_name: File name of the first (base) table.
        df2_name: File name of the second table.
        df1_row_name: Column of the first table to rename.
        df1_row_new_name: New name for ``df1_row_name``.
        df2_row_name: Column of the second table to merge in.
        df2_row_new_name: Name given to the merged-in column.
        file_type: ``'csv'`` or ``'json'``; applies to both input files.

    Returns:
        Path of the written CSV file.

    Raises:
        ValueError: If ``file_type`` is not ``'csv'`` or ``'json'``, or if the
            merged-in column name collides with an existing column of the
            first table (raised by ``DataFrame.join``).
        KeyError: If ``df1_row_name`` or ``df2_row_name`` is missing from its
            table.
    """
    readers = {'csv': pd.read_csv, 'json': pd.read_json}
    if file_type not in readers:
        raise ValueError("Invalid file type. Please choose either 'csv' or 'json'.")
    read = readers[file_type]

    df1 = read(os.path.join(save_result_dir, df1_name))
    df2 = read(os.path.join(save_result_dir, df2_name))

    # errors='raise' makes a missing/misspelled source column fail loudly;
    # the default silently skips it and writes a file lacking the column.
    df1 = df1.rename(columns={df1_row_name: df1_row_new_name}, errors='raise')
    df2 = df2.rename(columns={df2_row_name: df2_row_new_name}, errors='raise')

    # Left join on the index: only the selected column of df2 is brought over.
    merged = df1.join(df2[df2_row_new_name])

    result_path = os.path.join(save_result_dir, f'{df1_name}_{df2_name}_merge.csv')
    merged.to_csv(result_path)
    return result_path
if __name__ == "__main__":
    # Merge the 'Predict' column of two JSON result files in logs/pt_sft into
    # a single CSV, naming the two copies predict_finetune / predict_origin.
    shared_column = 'Predict'
    merge_df(
        save_result_dir='logs/pt_sft',
        df1_name='output-pt-sft-1-0.95-0.5-1.2.json',
        df2_name='output-npt-sft-1-0.95-0.5-1.2.json',
        df1_row_name=shared_column,
        df1_row_new_name='predict_finetune',
        df2_row_name=shared_column,
        df2_row_new_name='predict_origin',
        file_type='json',
    )