"""
|
|
Copyright 2023-2024 SGLang Team
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
"""Utilities for Huggingface Transformers."""
|
|
|
|
import contextlib
|
|
import os
|
|
import warnings
|
|
from typing import Dict, Optional, Type, Union
|
|
|
|
from huggingface_hub import snapshot_download
|
|
from transformers import (
|
|
AutoConfig,
|
|
AutoProcessor,
|
|
AutoTokenizer,
|
|
PretrainedConfig,
|
|
PreTrainedTokenizer,
|
|
PreTrainedTokenizerFast,
|
|
)
|
|
|
|
try:
    from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig

    from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig

    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
        ChatGLMConfig.model_type: ChatGLMConfig,
        DbrxConfig.model_type: DbrxConfig,
        ExaoneConfig.model_type: ExaoneConfig,
        Qwen2VLConfig.model_type: Qwen2VLConfig,
    }
except ImportError:
    # We want this file to run without the vllm dependency.
    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {}


for name, cls in _CONFIG_REGISTRY.items():
    with contextlib.suppress(ValueError):
        AutoConfig.register(name, cls)
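
# Note: after the registration loop above, `AutoConfig.from_pretrained` can
# resolve checkpoints whose config.json declares one of the registered
# `model_type` strings (e.g. "chatglm" -> ChatGLMConfig) without needing
# remote code. The `contextlib.suppress(ValueError)` guard makes re-importing
# this module safe, since `AutoConfig.register` raises on duplicate keys.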


def download_from_hf(model_path: str):
    if os.path.exists(model_path):
        return model_path

    return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
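
# Example (hypothetical repo id): a local path is returned unchanged, while a
# hub id is snapshot-downloaded, restricted to config/weight/tokenizer files
# by the allow_patterns above:
#
#   path = download_from_hf("org/some-model")       # fetches *.json, *.bin, *.model
#   path = download_from_hf("/data/my-checkpoint")  # returned as-is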


def get_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
    model_override_args: Optional[dict] = None,
):
    config = AutoConfig.from_pretrained(
        model, trust_remote_code=trust_remote_code, revision=revision
    )
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model, revision=revision)
    if model_override_args:
        config.update(model_override_args)
    return config
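
# Example (hypothetical model id): override args are applied last, so they win
# over both the hub config and any registry-specific config class:
#
#   config = get_config(
#       "org/some-model",
#       trust_remote_code=False,
#       model_override_args={"max_position_embeddings": 8192},
#   )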


# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
# have a preference for which value gets used.
CONTEXT_LENGTH_KEYS = [
    "max_sequence_length",
    "seq_length",
    "max_position_embeddings",
    "max_seq_len",
    "model_max_length",
]


def get_context_length(config):
    """Get the context length of a model from a Hugging Face model config."""
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling:
        rope_scaling_factor = config.rope_scaling.get("factor", 1)
        if "original_max_position_embeddings" in rope_scaling:
            rope_scaling_factor = 1
        if config.rope_scaling.get("rope_type", None) == "llama3":
            rope_scaling_factor = 1
    else:
        rope_scaling_factor = 1

    for key in CONTEXT_LENGTH_KEYS:
        val = getattr(config, key, None)
        if val is not None:
            return int(rope_scaling_factor * val)
    return 2048
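
# Worked example (hypothetical values): a config with
# `max_position_embeddings = 4096` and `rope_scaling = {"factor": 4.0}`
# yields int(4.0 * 4096) = 16384. If the rope_scaling dict also carries
# "original_max_position_embeddings" or declares rope_type "llama3", the
# factor is reset to 1 and the raw 4096 is returned instead. With no
# matching key at all, the function falls back to 2048.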


# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"


def get_tokenizer(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Gets a tokenizer for the given model name via Huggingface."""
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    try:
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name,
            *args,
            trust_remote_code=trust_remote_code,
            tokenizer_revision=tokenizer_revision,
            clean_up_tokenization_spaces=False,
            **kwargs,
        )
    except TypeError as e:
        # The LLaMA tokenizer causes a protobuf error in some environments.
        err_msg = (
            "Failed to load the tokenizer. If you are using a LLaMA V1 model "
            f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the "
            "original tokenizer."
        )
        raise RuntimeError(err_msg) from e
    except ValueError as e:
        # If the error pertains to the tokenizer class not existing or not
        # currently being imported, suggest using the --trust-remote-code flag.
        if not trust_remote_code and (
            "does not exist or is not currently imported." in str(e)
            or "requires you to execute the tokenizer file" in str(e)
        ):
            err_msg = (
                "Failed to load the tokenizer. If the tokenizer is a custom "
                "tokenizer not yet available in the HuggingFace transformers "
                "library, consider setting `trust_remote_code=True` in LLM "
                "or using the `--trust-remote-code` flag in the CLI."
            )
            raise RuntimeError(err_msg) from e
        else:
            raise e

    if not isinstance(tokenizer, PreTrainedTokenizerFast):
        warnings.warn(
            "Using a slow tokenizer. This might cause a significant "
            "slowdown. Consider using a fast tokenizer instead."
        )

    # Special handling for the stop token <|eom_id|> generated by llama 3 tool use.
    if "<|eom_id|>" in tokenizer.get_added_vocab():
        tokenizer.additional_stop_token_ids = set(
            [tokenizer.get_added_vocab()["<|eom_id|>"]]
        )
    else:
        tokenizer.additional_stop_token_ids = None

    return tokenizer
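
# Example (hypothetical model id): the default "auto" mode prefers the fast
# (Rust-backed) tokenizer; forcing the slow path sets use_fast=False and will
# trigger the slowdown warning above:
#
#   tokenizer = get_tokenizer("org/some-model")                         # fast if available
#   tokenizer = get_tokenizer("org/some-model", tokenizer_mode="slow")  # use_fast=False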


def get_processor(
    tokenizer_name: str,
    *args,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    tokenizer_revision: Optional[str] = None,
    **kwargs,
):
    processor = AutoProcessor.from_pretrained(
        tokenizer_name,
        *args,
        trust_remote_code=trust_remote_code,
        tokenizer_revision=tokenizer_revision,
        **kwargs,
    )
    return processor
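
# Example (hypothetical multimodal model id): an AutoProcessor bundles the
# tokenizer with image/audio preprocessing for multimodal models:
#
#   processor = get_processor("org/some-vl-model", trust_remote_code=True)
#   tokenizer = processor.tokenizer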