import os
from pathlib import Path
from typing import Tuple, Union

import torch
from peft import PeftConfig, PeftModel, PeftModelForCausalLM
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

from evaluators.evaluator import Evaluator
from evaluators.chatglm_mixin import ChatGLMMixin

# Restrict the process to GPU 0 before any CUDA context is created; a bare
# `CUDA_VISIBLE_DEVICES = 0` assignment only binds an unused Python variable,
# so the environment variable is set explicitly instead.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

def _resolve_path(path: Union[str, Path]) -> Path:
    # Normalize user-supplied paths ('~', relative segments) to an absolute Path.
    return Path(path).expanduser().resolve()

def load_model_and_tokenizer(model_dir: Union[str, Path], device) -> Tuple[ModelType, TokenizerType]:
    """Load either a PEFT (QLoRA) adapter checkpoint or a plain HF checkpoint from model_dir."""
    model_dir = _resolve_path(model_dir)
    if (model_dir / 'adapter_config.json').exists():
        # A PEFT adapter lives here: load the ChatGLM3 base model first,
        # then attach the adapter weights on top of it.
        config = PeftConfig.from_pretrained(str(model_dir))
        base_model = AutoModel.from_pretrained(
            "THUDM/chatglm3-6b", trust_remote_code=True, mirror="tuna",
            resume_download=True,
        ).to(device)
        # Alternative: resolve the base model from the adapter config rather
        # than hard-coding it:
        # base_model = AutoModelForCausalLM.from_pretrained(
        #     config.base_model_name_or_path, trust_remote_code=True, device_map='auto'
        # )
        model = PeftModel.from_pretrained(base_model, model_dir)
        tokenizer_dir = model.peft_config['default'].base_model_name_or_path
    else:
        # No adapter config: treat model_dir as a full (merged) checkpoint.
        model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=True, device_map='auto'
        )
        tokenizer_dir = model_dir
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, trust_remote_code=True)
    return model, tokenizer

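# Usage sketch (hypothetical adapter path; assumes a QLoRA run was saved
# under qlora/glm3/<run-name> with an adapter_config.json inside):
#
#     model, tokenizer = load_model_and_tokenizer('qlora/glm3/my-run', 'cuda:0')
#     model = model.half().eval()

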
class ChatGLM3_Evaluator(Evaluator, ChatGLMMixin):
    def __init__(self, choices, k, model_name, device, finetune=None, finetune_method=None):
        super(ChatGLM3_Evaluator, self).__init__(choices, model_name, k)
        self.finetune_method = finetune_method
        self.finetune_name = finetune
        self.tokenizer = AutoTokenizer.from_pretrained(
            "THUDM/chatglm3-6b", trust_remote_code=True, mirror="tuna"
        )
        if finetune_method == "qlora":
            # Load the QLoRA adapter (and its matching tokenizer) from disk.
            model_dir = 'qlora/glm3/' + finetune
            self.model, self.tokenizer = load_model_and_tokenizer(model_dir, device)
            self.model = self.model.half().to(device)
            print("Model loaded! (GLM3 + " + finetune + ")")
        else:
            # Fall back to the vanilla ChatGLM3-6B checkpoint from the hub.
            self.model = AutoModel.from_pretrained(
                "THUDM/chatglm3-6b", trust_remote_code=True, mirror="tuna",
                resume_download=True,
            ).half().to(device)
            print("Model loaded! (GLM3)")
        self.model = self.model.eval()
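

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the evaluator pipeline above.
    # The constructor arguments mirror __init__; the concrete values
    # (choices, k, model_name) are illustrative placeholders only.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    evaluator = ChatGLM3_Evaluator(
        choices=['A', 'B', 'C', 'D'],  # candidate answer labels (assumed)
        k=5,                           # few-shot example count (assumed)
        model_name='chatglm3-6b',
        device=device,
    )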