Fix hf config loading (#702)
This commit is contained in:
@@ -4,19 +4,26 @@ import functools
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import AbstractSet, Collection, Literal, Optional, Union
|
||||
from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
|
||||
|
||||
from sglang.srt.utils import is_multimodal_model
|
||||
|
||||
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
||||
ChatGLMConfig.model_type: ChatGLMConfig,
|
||||
DbrxConfig.model_type: DbrxConfig,
|
||||
}
|
||||
|
||||
|
||||
def download_from_hf(model_path: str):
|
||||
if os.path.exists(model_path):
|
||||
@@ -40,6 +47,9 @@ def get_config(
|
||||
config = AutoConfig.from_pretrained(
|
||||
model, trust_remote_code=trust_remote_code, revision=revision
|
||||
)
|
||||
if config.model_type in _CONFIG_REGISTRY:
|
||||
config_class = _CONFIG_REGISTRY[config.model_type]
|
||||
config = config_class.from_pretrained(model, revision=revision)
|
||||
if model_overide_args:
|
||||
config.update(model_overide_args)
|
||||
return config
|
||||
|
||||
Reference in New Issue
Block a user