Fix hf config loading (#702)
This commit is contained in:
@@ -4,19 +4,26 @@ import functools
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import AbstractSet, Collection, Literal, Optional, Union
|
from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoProcessor,
|
AutoProcessor,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
|
PretrainedConfig,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
PreTrainedTokenizerFast,
|
PreTrainedTokenizerFast,
|
||||||
)
|
)
|
||||||
|
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
|
||||||
|
|
||||||
from sglang.srt.utils import is_multimodal_model
|
from sglang.srt.utils import is_multimodal_model
|
||||||
|
|
||||||
|
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
||||||
|
ChatGLMConfig.model_type: ChatGLMConfig,
|
||||||
|
DbrxConfig.model_type: DbrxConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def download_from_hf(model_path: str):
|
def download_from_hf(model_path: str):
|
||||||
if os.path.exists(model_path):
|
if os.path.exists(model_path):
|
||||||
@@ -40,6 +47,9 @@ def get_config(
|
|||||||
config = AutoConfig.from_pretrained(
|
config = AutoConfig.from_pretrained(
|
||||||
model, trust_remote_code=trust_remote_code, revision=revision
|
model, trust_remote_code=trust_remote_code, revision=revision
|
||||||
)
|
)
|
||||||
|
if config.model_type in _CONFIG_REGISTRY:
|
||||||
|
config_class = _CONFIG_REGISTRY[config.model_type]
|
||||||
|
config = config_class.from_pretrained(model, revision=revision)
|
||||||
if model_overide_args:
|
if model_overide_args:
|
||||||
config.update(model_overide_args)
|
config.update(model_overide_args)
|
||||||
return config
|
return config
|
||||||
|
|||||||
Reference in New Issue
Block a user