EXAONE 3.0 Model Support (#1258)

Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
김종곤
2024-08-30 17:08:28 +09:00
committed by GitHub
parent f414352ae6
commit b7f8341014
4 changed files with 609 additions and 2 deletions

View File

@@ -15,6 +15,7 @@ limitations under the License.
"""Utilities for Huggingface Transformers."""
import contextlib
import functools
import json
import os
@@ -34,14 +35,21 @@ from transformers import (
try:
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
from sglang.srt.configs import ExaoneConfig
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
ChatGLMConfig.model_type: ChatGLMConfig,
DbrxConfig.model_type: DbrxConfig,
ExaoneConfig.model_type: ExaoneConfig,
}
except ImportError:
# We want this file to run without vllm dependency
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {}
for name, cls in _CONFIG_REGISTRY.items():
with contextlib.suppress(ValueError):
AutoConfig.register(name, cls)
from sglang.srt.utils import is_multimodal_model
@@ -53,7 +61,7 @@ def download_from_hf(model_path: str):
def get_config_json(model_path: str):
with open(os.path.join(model_path, "config.json")) as f:
config = json.load(f)
return config
@@ -89,7 +97,7 @@ CONTEXT_LENGTH_KEYS = [
def get_context_length(config):
"""Get the context length of a model from a huggingface model config."""
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling:
rope_scaling_factor = config.rope_scaling["factor"]