Fix hf config loading (#702)

This commit is contained in:
Ke Bao
2024-07-24 02:39:08 +08:00
committed by GitHub
parent cf99eab7d5
commit 824a77d04d

View File

@@ -4,19 +4,26 @@ import functools
 import json
 import os
 import warnings
-from typing import AbstractSet, Collection, Literal, Optional, Union
+from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union

 from huggingface_hub import snapshot_download
 from transformers import (
     AutoConfig,
     AutoProcessor,
     AutoTokenizer,
+    PretrainedConfig,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )

+from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
+
 from sglang.srt.utils import is_multimodal_model

+_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
+    ChatGLMConfig.model_type: ChatGLMConfig,
+    DbrxConfig.model_type: DbrxConfig,
+}
+

 def download_from_hf(model_path: str):
     if os.path.exists(model_path):
@@ -40,6 +47,9 @@ def get_config(
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision
     )
+    if config.model_type in _CONFIG_REGISTRY:
+        config_class = _CONFIG_REGISTRY[config.model_type]
+        config = config_class.from_pretrained(model, revision=revision)
     if model_overide_args:
         config.update(model_overide_args)
     return config