import os
import re
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModel, AutoConfig
from evaluators.evaluator import Evaluator
from evaluators.chatglm_mixin import ChatGLMMixin
from pathlib import Path
from typing import Union, Tuple
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
AutoModelForCausalLM,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from peft import PeftModel, PeftConfig
# NOTE(review): assigning a plain Python variable named CUDA_VISIBLE_DEVICES does
# NOT restrict CUDA devices -- that requires os.environ["CUDA_VISIBLE_DEVICES"]
# before torch initializes CUDA. Verify whether this was the intent.
CUDA_VISIBLE_DEVICES = 0
# A loaded model is either a plain HF pretrained model or a PEFT-wrapped causal LM.
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
# Tokenizer may be the slow (Python) or fast (Rust-backed) implementation.
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
def _resolve_path(path: Union[str, Path]) -> Path:
return Path(path).expanduser().resolve()
def load_model_and_tokenizer(model_dir: Union[str, Path], device) -> Tuple[ModelType, TokenizerType]:
    """Load a ChatGLM3 model (optionally with a PEFT adapter) and its tokenizer.

    If *model_dir* contains an ``adapter_config.json`` it is treated as a PEFT
    (LoRA/QLoRA) adapter directory: the THUDM/chatglm3-6b base model is loaded
    and the adapter weights are applied on top of it. Otherwise *model_dir* is
    loaded directly as a full causal-LM checkpoint.

    Args:
        model_dir: Path to a full checkpoint or a PEFT adapter directory.
        device: Device the base model is moved to in the adapter branch; the
            full-checkpoint branch uses ``device_map='auto'`` instead.

    Returns:
        Tuple of ``(model, tokenizer)``.
    """
    model_dir = _resolve_path(model_dir)
    if (model_dir / 'adapter_config.json').exists():
        # NOTE(review): the base model is hard-coded to THUDM/chatglm3-6b rather
        # than read from the adapter's PeftConfig.base_model_name_or_path --
        # confirm every adapter used here was trained on that base.
        base_model = AutoModel.from_pretrained(
            "THUDM/chatglm3-6b", trust_remote_code=True, mirror="tuna",
            resume_download=True
        ).to(device)
        model = PeftModel.from_pretrained(base_model, model_dir)
        # The adapter records which base model it was trained on; use that
        # repo/path to fetch the matching tokenizer.
        tokenizer_dir = model.peft_config['default'].base_model_name_or_path
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir, trust_remote_code=True, device_map='auto'
        )
        tokenizer_dir = model_dir
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir, trust_remote_code=True
    )
    return model, tokenizer
class ChatGLM3_Evaluator(Evaluator, ChatGLMMixin):
    """Evaluator backed by ChatGLM3-6B, optionally with a QLoRA adapter.

    When ``finetune_method == "qlora"`` the model and tokenizer are loaded from
    the adapter directory ``qlora/glm3/<finetune>``; otherwise the stock
    THUDM/chatglm3-6b checkpoint is downloaded. Either way the model ends up in
    half precision, on *device*, and in eval mode.
    """

    def __init__(self, choices, k, model_name, device, finetune=None, finetune_method=None):
        """
        Args:
            choices: Answer-choice labels forwarded to the base Evaluator.
            k: Few-shot example count forwarded to the base Evaluator.
            model_name: Display/base name forwarded to the base Evaluator.
            device: torch device the model is moved to.
            finetune: Adapter name under ``qlora/glm3/`` (qlora mode only).
            finetune_method: ``"qlora"`` to load an adapter, anything else for
                the stock model.
        """
        super().__init__(choices, model_name, k)
        self.finetune_method = finetune_method
        self.finetune_name = finetune
        if finetune_method == "qlora":
            # Adapter path convention: qlora/glm3/<finetune-name>.
            model_dir = 'qlora/glm3/' + finetune
            # Returns both the adapter-wrapped model and the tokenizer of the
            # adapter's base model (the original loaded a default tokenizer
            # first and immediately discarded it here -- removed as redundant).
            self.model, self.tokenizer = load_model_and_tokenizer(model_dir, device)
            self.model = self.model.half().to(device)
            print("Model loaded! use GLM3 " + finetune)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                "THUDM/chatglm3-6b", trust_remote_code=True, mirror="tuna")
            self.model = AutoModel.from_pretrained(
                "THUDM/chatglm3-6b", trust_remote_code=True, mirror="tuna",
                resume_download=True).half().to(device)
            print("Model loaded! (GLM3)")
        self.model = self.model.eval()