# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for Huggingface Transformers."""

import contextlib
import json
import os
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union

import torch
from huggingface_hub import snapshot_download
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    PretrainedConfig,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

from sglang.srt.configs import (
    ChatGLMConfig,
    DbrxConfig,
    DeepseekVL2Config,
    DotsVLMConfig,
    ExaoneConfig,
    KimiVLConfig,
    LongcatFlashConfig,
    MultiModalityConfig,
    Qwen3NextConfig,
    Step3VLConfig,
)
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset

_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    ChatGLMConfig.model_type: ChatGLMConfig,
    DbrxConfig.model_type: DbrxConfig,
    ExaoneConfig.model_type: ExaoneConfig,
    DeepseekVL2Config.model_type: DeepseekVL2Config,
    MultiModalityConfig.model_type: MultiModalityConfig,
    KimiVLConfig.model_type: KimiVLConfig,
    InternVLChatConfig.model_type: InternVLChatConfig,
    Step3VLConfig.model_type: Step3VLConfig,
    LongcatFlashConfig.model_type: LongcatFlashConfig,
    Qwen3NextConfig.model_type: Qwen3NextConfig,
    DotsVLMConfig.model_type: DotsVLMConfig,
}

# Register the custom configs with AutoConfig; ignore the ValueError raised
# for model types that transformers already knows about.
for name, cls in _CONFIG_REGISTRY.items():
    with contextlib.suppress(ValueError):
        AutoConfig.register(name, cls)


def download_from_hf(
    model_path: str,
    allow_patterns: Optional[Union[str, list]] = None,
):
    if os.path.exists(model_path):
        return model_path

    if not allow_patterns:
        allow_patterns = ["*.json", "*.bin", "*.model"]

    return snapshot_download(model_path, allow_patterns=allow_patterns)


def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to the LLM for multimodal models.

    No-op for pure text models.
    """
    if config.architectures is not None:
        class_name = config.architectures[0]
        if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
            # We support the non-hf version of llava models, so we do not want to
            # read the wrong values from the unused default text_config.
            # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
            # `torch.float16` is used by default for image features in `python/sglang/srt/models/llava.py`.
            setattr(config, "torch_dtype", torch.float16)
            return config

    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if the transformers config doesn't align with this assumption.
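        # For example, a multimodal config such as transformers' Gemma3Config exposes
        # a nested text_config (a Gemma3TextConfig) that does define `num_attention_heads`.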
        assert hasattr(config.text_config, "num_attention_heads")

        return config.text_config
    if hasattr(config, "language_config"):
        return config.language_config
    if hasattr(config, "thinker_config"):
        # qwen2.5 omni
        thinker_config = config.thinker_config
        if hasattr(thinker_config, "text_config"):
            setattr(
                thinker_config.text_config,
                "torch_dtype",
                getattr(thinker_config, "torch_dtype", None),
            )
            return thinker_config.text_config
        return thinker_config
    else:
        return config


@lru_cache_frozenset(maxsize=32)
def get_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
    model_override_args: Optional[dict] = None,
    **kwargs,
):
    is_gguf = check_gguf_file(model)
    if is_gguf:
        kwargs["gguf_file"] = model
        model = Path(model).parent

    if is_remote_url(model):
        # BaseConnector implements __del__() to clean up the local dir.
        # Since the config files need to exist all the time, we DO NOT use a
        # `with` statement, which would close the client.
        client = create_remote_connector(model)
        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
        model = client.get_local_dir()

    config = AutoConfig.from_pretrained(
        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
    )

    if (
        config.architectures is not None
        and config.architectures[0] == "Phi4MMForCausalLM"
    ):
        # Phi4MMForCausalLM uses a hard-coded vision_config. See:
        # https://github.com/vllm-project/vllm/blob/6071e989df1531b59ef35568f83f7351afb0b51e/vllm/model_executor/models/phi4mm.py#L71
        # We set it here to support cases where num_attention_heads is not divisible by the TP size.
        from transformers import SiglipVisionConfig

        vision_config = {
            "hidden_size": 1152,
            "image_size": 448,
            "intermediate_size": 4304,
            "model_type": "siglip_vision_model",
            "num_attention_heads": 16,
            "num_hidden_layers": 26,  # The model is originally 27-layer; we only need the first 26 layers for feature extraction.
            "patch_size": 14,
        }
        config.vision_config = SiglipVisionConfig(**vision_config)

    text_config = get_hf_text_config(config=config)

    if isinstance(model, str) and text_config is not None:
        # Mirror text_config values onto the top-level config so callers can
        # read them without knowing whether the model is multimodal.
        for key, val in text_config.__dict__.items():
            if not hasattr(config, key) and getattr(text_config, key, None) is not None:
                setattr(config, key, val)

    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model, revision=revision)
        # NOTE(HandH1998): Qwen2VL requires the `_name_or_path` attribute in `config`.
setattr(config, "_name_or_path", model) if isinstance(model, str) and config.model_type == "internvl_chat": for key, val in config.llm_config.__dict__.items(): if not hasattr(config, key): setattr(config, key, val) if config.model_type == "multi_modality": config.update({"architectures": ["MultiModalityCausalLM"]}) if model_override_args: config.update(model_override_args) # Special architecture mapping check for GGUF models if is_gguf: if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: raise RuntimeError(f"Can't get gguf config for {config.model_type}.") model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] config.update({"architectures": [model_type]}) return config @lru_cache_frozenset(maxsize=32) def get_generation_config( model: str, trust_remote_code: bool, revision: Optional[str] = None, **kwargs, ): try: return GenerationConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) except OSError as e: return None # Qwen-1M related def get_sparse_attention_config( model: str, sparse_attention_config_filename: str = "sparse_attention_config.json", ) -> Dict[str, Any]: is_local = os.path.isdir(model) if not is_local: # Download the config files. model = download_from_hf(model, allow_patterns=["*.json"]) config_file = os.path.join(model, sparse_attention_config_filename) if not os.path.exists(config_file): return {} # Load the sparse attention config. with open(config_file) as f: config = json.load(f) return config # Models don't use the same configuration key for determining the maximum # context length. Store them here so we can sanely check them. # NOTE: The ordering here is important. Some models have two of these and we # have a preference for which value gets used. CONTEXT_LENGTH_KEYS = [ "max_sequence_length", "seq_length", "max_seq_len", "model_max_length", "max_position_embeddings", ] def get_context_length(config): """Get the context length of a model from a huggingface model configs.""" text_config = config rope_scaling = getattr(text_config, "rope_scaling", None) if rope_scaling: rope_scaling_factor = rope_scaling.get("factor", 1) if "original_max_position_embeddings" in rope_scaling: rope_scaling_factor = 1 if rope_scaling.get("rope_type", None) == "llama3": rope_scaling_factor = 1 else: rope_scaling_factor = 1 for key in CONTEXT_LENGTH_KEYS: val = getattr(text_config, key, None) if val is not None: return int(rope_scaling_factor * val) return 2048 # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. 
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" def get_tokenizer( tokenizer_name: str, *args, tokenizer_mode: str = "auto", trust_remote_code: bool = False, tokenizer_revision: Optional[str] = None, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: """Gets a tokenizer for the given model name via Huggingface.""" if tokenizer_name.endswith(".json"): from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer return TiktokenTokenizer(tokenizer_name) if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False # TODO(Xinyuan): Remove this once we have a proper tokenizer for Devstral if tokenizer_name == "mistralai/Devstral-Small-2505": tokenizer_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" is_gguf = check_gguf_file(tokenizer_name) if is_gguf: kwargs["gguf_file"] = tokenizer_name tokenizer_name = Path(tokenizer_name).parent if is_remote_url(tokenizer_name): # BaseConnector implements __del__() to clean up the local dir. # Since config files need to exist all the time, so we DO NOT use # with statement to avoid closing the client. client = create_remote_connector(tokenizer_name) client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) tokenizer_name = client.get_local_dir() try: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, *args, trust_remote_code=trust_remote_code, tokenizer_revision=tokenizer_revision, clean_up_tokenization_spaces=False, **kwargs, ) except TypeError as e: # The LLaMA tokenizer causes a protobuf error in some environments. err_msg = ( "Failed to load the tokenizer. If you are using a LLaMA V1 model " f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the " "original tokenizer." ) raise RuntimeError(err_msg) from e except ValueError as e: # If the error pertains to the tokenizer class not existing or not # currently being imported, suggest using the --trust-remote-code flag. if not trust_remote_code and ( "does not exist or is not currently imported." in str(e) or "requires you to execute the tokenizer file" in str(e) ): err_msg = ( "Failed to load the tokenizer. If the tokenizer is a custom " "tokenizer not yet available in the HuggingFace transformers " "library, consider setting `trust_remote_code=True` in LLM " "or using the `--trust-remote-code` flag in the CLI." ) raise RuntimeError(err_msg) from e else: raise e if not isinstance(tokenizer, PreTrainedTokenizerFast): warnings.warn( "Using a slow tokenizer. This might cause a significant " "slowdown. Consider using a fast tokenizer instead." ) attach_additional_stop_token_ids(tokenizer) return tokenizer # Some models doesn't have an available processor, e.g.: InternVL def get_tokenizer_from_processor(processor): if isinstance(processor, PreTrainedTokenizerBase): return processor return processor.tokenizer def get_processor( tokenizer_name: str, *args, tokenizer_mode: str = "auto", trust_remote_code: bool = False, tokenizer_revision: Optional[str] = None, use_fast: Optional[bool] = True, **kwargs, ): # pop 'revision' from kwargs if present. revision = kwargs.pop("revision", tokenizer_revision) config = AutoConfig.from_pretrained( tokenizer_name, trust_remote_code=trust_remote_code, revision=revision, **kwargs, ) # fix: for Qwen2-VL and Sarashina2Vision models, inject default 'size' if not provided. 
if config.model_type in {"qwen2_vl", "sarashina2_vision"}: if "size" not in kwargs: kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520} if config.model_type not in {"llava", "clip"}: kwargs["use_fast"] = use_fast try: if "InternVL3_5" in tokenizer_name: processor = AutoTokenizer.from_pretrained( tokenizer_name, *args, trust_remote_code=trust_remote_code, revision=revision, **kwargs, ) else: processor = AutoProcessor.from_pretrained( tokenizer_name, *args, trust_remote_code=trust_remote_code, revision=revision, **kwargs, ) except ValueError as e: error_message = str(e) if "does not have a slow version" in error_message: logger.info( f"Processor {tokenizer_name} does not have a slow version. Automatically use fast version" ) kwargs["use_fast"] = True processor = AutoProcessor.from_pretrained( tokenizer_name, *args, trust_remote_code=trust_remote_code, revision=revision, **kwargs, ) else: raise e tokenizer = get_tokenizer_from_processor(processor) attach_additional_stop_token_ids(tokenizer) return processor def attach_additional_stop_token_ids(tokenizer): # Special handling for stop token <|eom_id|> generated by llama 3 tool use. if "<|eom_id|>" in tokenizer.get_added_vocab(): tokenizer.additional_stop_token_ids = set( [tokenizer.get_added_vocab()["<|eom_id|>"]] ) else: tokenizer.additional_stop_token_ids = None def check_gguf_file(model: Union[str, os.PathLike]) -> bool: """Check if the file is a GGUF model.""" model = Path(model) if not model.is_file(): return False elif model.suffix == ".gguf": return True with open(model, "rb") as f: header = f.read(4) return header == b"GGUF"