You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
29 lines
1.3 KiB
Python
29 lines
1.3 KiB
Python
7 months ago
|
import pandas as pd
|
||
|
import os
|
||
|
|
||
|
def merge_df(save_result_dir, df1_name, df2_name, df1_row_name, df1_row_new_name, df2_row_name, df2_row_new_name, file_type='csv'):
    """Append one renamed column of a second table to a first table and save as CSV.

    Both input files are read from ``save_result_dir``. One column in each
    table is renamed, then the renamed column of the second table is joined
    onto the first table by row index (the default of ``DataFrame.join``).
    The result is written next to the inputs as
    ``{df1_name}_{df2_name}_merge.csv`` (index column included).

    Note: despite the ``*_row_*`` parameter names, these are column labels.

    Args:
        save_result_dir: Directory holding both inputs; also receives the output.
        df1_name: File name of the first (base) table.
        df2_name: File name of the second table.
        df1_row_name: Column in the first table to rename.
        df1_row_new_name: New label for that column.
        df2_row_name: Column in the second table to rename and copy over.
        df2_row_new_name: New label for that column in the merged output.
        file_type: Input format, either ``'csv'`` or ``'json'``.

    Returns:
        Path of the written CSV file.

    Raises:
        ValueError: If ``file_type`` is neither ``'csv'`` nor ``'json'``.
        KeyError: If a table contains neither the column to rename nor a
            column already carrying the new label.
    """
    # Resolve the reader first so an unsupported type fails before any I/O.
    if file_type == 'csv':
        read_table = pd.read_csv
    elif file_type == 'json':
        read_table = pd.read_json
    else:
        raise ValueError("Invalid file type. Please choose either 'csv' or 'json'.")

    df1 = read_table(os.path.join(save_result_dir, df1_name))
    df2 = read_table(os.path.join(save_result_dir, df2_name))

    df1 = df1.rename(columns={df1_row_name: df1_row_new_name})
    df2 = df2.rename(columns={df2_row_name: df2_row_new_name})

    # DataFrame.rename silently ignores unknown labels; without this check a
    # wrong column name would produce an output file missing the column.
    for file_name, frame, old_label, new_label in (
        (df1_name, df1, df1_row_name, df1_row_new_name),
        (df2_name, df2, df2_row_name, df2_row_new_name),
    ):
        if new_label not in frame.columns:
            raise KeyError(
                f"Column {old_label!r} not found in {file_name!r}; "
                f"cannot provide {new_label!r}."
            )

    merged = df1.join(df2[df2_row_new_name])

    result_file_name = f'{df1_name}_{df2_name}_merge.csv'
    result_path = os.path.join(save_result_dir, result_file_name)
    merged.to_csv(result_path)
    return result_path
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Both input files use the same source column label; each one is given a
    # distinct label so the two columns can sit side by side in the output.
    shared_column = 'Predict'
    merge_df(
        'logs/pt_sft',
        'output-pt-sft-1-0.95-0.5-1.2.json',
        'output-npt-sft-1-0.95-0.5-1.2.json',
        df1_row_name=shared_column,
        df1_row_new_name='predict_finetune',
        df2_row_name=shared_column,
        df2_row_new_name='predict_origin',
        file_type='json',
    )
|