@@ -7,6 +7,7 @@ from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList
 
 from evaluators.evaluator import Evaluator
 
 class ChatGLMMixin:
     def __init__(self):
         self.tokenizer = None
@@ -28,14 +29,15 @@ class ChatGLMMixin:
         answers = list(test_df['answer'])
         for row_index, row in tqdm(test_df.iterrows(), total=len(test_df)):
             question = self.format_example(row, include_answer=False, cot=cot)
+            history_tmp = history.copy()
             if few_shot:
                 response, _ = self.model.chat(self.tokenizer, question, max_length=2000,
-                                              do_sample=False, history=history)
+                                              do_sample=False, history=history_tmp)
                 response = response.strip()
                 ans, direct_extract = self.extract_cot_answer(row, response)
             else:  # zero-shot by extracting answer from distribution
                 response, _ = self.model.chat(self.tokenizer, question, max_length=2000,
-                                              do_sample=False, history=history)
+                                              do_sample=False, history=history_tmp)
                 response = response.strip()
                 ans, direct_extract = self.extract_cot_answer(row, response)
             if ans == answers[row_index]:
@@ -64,8 +66,9 @@ class ChatGLMMixin:
         history = self.generate_zero_shot_prompt(is_choice_question=False)
         for row_index, row in tqdm(qa_df.iterrows(), total=len(qa_df)):
             question = row['question']
+            history_tmp = history.copy()
             response, _ = self.model.chat(self.tokenizer, question, max_length=2000,
-                                          do_sample=False, history=history)
+                                          do_sample=False, history=history_tmp)
             response = response.strip()
             qa_df.loc[row_index, 'model_output'] = response
             # current_length = 0
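
Why the per-row `history.copy()` matters: if the underlying `chat` call mutates the history list it receives (some ChatGLM checkpoints append the new turn in place), reusing one shared few-shot `history` across rows makes every earlier test question leak into the next prompt. A minimal, self-contained sketch of that failure mode and the fix follows; the `chat` stub here is a hypothetical stand-in, not the real model API.

# Hypothetical stand-in for a chat method that appends turns to the
# caller's history list in place.
def chat(tokenizer, query, history=None, **kwargs):
    history = history if history is not None else []
    history.append((query, "<response>"))  # in-place mutation
    return "<response>", history

# Without a per-row copy, the shared few-shot prompt grows each iteration.
few_shot_history = [("example question", "example answer")]
for q in ["q1", "q2"]:
    chat(None, q, history=few_shot_history)
print(len(few_shot_history))  # 3 -- two test questions leaked in

# With the copy, each row sees only the original few-shot examples.
few_shot_history = [("example question", "example answer")]
for q in ["q1", "q2"]:
    history_tmp = few_shot_history.copy()
    chat(None, q, history=history_tmp)
print(len(few_shot_history))  # 1 -- shared prompt left untouched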