From 824a77d04d90662eeb3864d3f36e9f2458d4b9f6 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Wed, 24 Jul 2024 02:39:08 +0800 Subject: [PATCH] Fix hf config loading (#702) --- python/sglang/srt/hf_transformers_utils.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 218af433c..850f3ffc2 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -4,19 +4,26 @@ import functools import json import os import warnings -from typing import AbstractSet, Collection, Literal, Optional, Union +from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union from huggingface_hub import snapshot_download from transformers import ( AutoConfig, AutoProcessor, AutoTokenizer, + PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, ) +from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig from sglang.srt.utils import is_multimodal_model +_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { + ChatGLMConfig.model_type: ChatGLMConfig, + DbrxConfig.model_type: DbrxConfig, +} + def download_from_hf(model_path: str): if os.path.exists(model_path): @@ -40,6 +47,9 @@ def get_config( config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision ) + if config.model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[config.model_type] + config = config_class.from_pretrained(model, revision=revision) if model_overide_args: config.update(model_overide_args) return config