import os

# Pin the GPU before torch/transformers initialise CUDA. Note: assigning a
# bare Python variable named CUDA_VISIBLE_DEVICES (as the original code did)
# has no effect on the runtime; the environment variable must be set instead.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

from pathlib import Path
from typing import Tuple, Union

from peft import PeftConfig, PeftModel, PeftModelForCausalLM
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

from evaluators.chatglm_mixin import ChatGLMMixin
from evaluators.evaluator import Evaluator

ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


def _resolve_path(path: Union[str, Path]) -> Path:
    return Path(path).expanduser().resolve()


def load_model_and_tokenizer(model_dir: Union[str, Path], device) -> Tuple[ModelType, TokenizerType]:
    """Load either a PEFT adapter checkpoint or a plain HF checkpoint from model_dir."""
    model_dir = _resolve_path(model_dir)
    if (model_dir / 'adapter_config.json').exists():
        # An adapter_config.json marks a PEFT (e.g. QLoRA) checkpoint: load the
        # ChatGLM3 base model first, then attach the adapter weights on top.
        config = PeftConfig.from_pretrained(str(model_dir))
        base_model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True, mirror="tuna",
                                               resume_download=True).to(device)
        # base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path,
        #                                                   trust_remote_code=True, device_map='auto')
        model = PeftModel.from_pretrained(base_model, model_dir)
        tokenizer_dir = model.peft_config['default'].base_model_name_or_path
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=True, device_map='auto'
        )
        tokenizer_dir = model_dir
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir, trust_remote_code=True
    )
    return model, tokenizer
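

# Example (hypothetical paths): if qlora/glm3/my-run/adapter_config.json
# exists, the adapter branch is taken; otherwise the directory is loaded as a
# full checkpoint:
#   model, tokenizer = load_model_and_tokenizer('qlora/glm3/my-run', 'cuda:0')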


class ChatGLM3_Evaluator(Evaluator, ChatGLMMixin):
    def __init__(self, choices, k, model_name, device, finetune=None, finetune_method=None):
        super(ChatGLM3_Evaluator, self).__init__(choices, model_name, k)
        self.finetune_method = finetune_method
        self.finetune_name = finetune
        self.tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True, mirror="tuna")
        if finetune_method == "qlora":
            # QLoRA checkpoints live under qlora/glm3/<finetune>; loading one
            # also replaces the default tokenizer set above.
            model_dir = 'qlora/glm3/' + finetune
            self.model, self.tokenizer = load_model_and_tokenizer(model_dir, device)
            self.model = self.model.half().to(device)
            print("Model loaded! (GLM3 + QLoRA: " + finetune + ")")
        else:
            self.model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True, mirror="tuna",
                                                   resume_download=True).half().to(device)
            print("Model loaded! (GLM3)")
        self.model = self.model.eval()
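

# --- Usage sketch (illustrative, not part of the evaluator module) ----------
# A minimal way to drive ChatGLM3_Evaluator, assuming the Evaluator base class
# takes (choices, model_name, k) as constructed above. The choice labels and
# few-shot count below are hypothetical examples, not values from this repo.
if __name__ == "__main__":
    evaluator = ChatGLM3_Evaluator(
        choices=["A", "B", "C", "D"],  # hypothetical multiple-choice labels
        k=5,                           # hypothetical few-shot example count
        model_name="chatglm3-6b",
        device="cuda:0",
        finetune=None,                 # or a QLoRA run name with finetune_method="qlora"
    )
    print(type(evaluator.model), type(evaluator.tokenizer))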