refactor: minor refactors regarding multimodal processing (#6187)

This commit is contained in:
Mick
2025-05-18 13:53:20 +08:00
committed by GitHub
parent b3f3d610fd
commit 01dd39bac1
15 changed files with 140 additions and 98 deletions

View File

@@ -22,7 +22,11 @@ from typing import List, Optional, Set, Union
import torch
from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import get_config, get_context_length
from sglang.srt.hf_transformers_utils import (
get_config,
get_context_length,
get_hf_text_config,
)
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var, is_hip
@@ -209,7 +213,13 @@ class ModelConfig:
# Cache attributes
self.hf_eos_token_id = self.get_hf_eos_token_id()
self.image_token_id = getattr(self.hf_config, "image_token_id", None)
config = self.hf_config
# multimodal
self.image_token_id = getattr(config, "image_token_id", None) or getattr(
config, "image_token_index", None
)
@staticmethod
def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
@@ -423,31 +433,6 @@ class ModelConfig:
self.model_path = client.get_local_dir()
def get_hf_text_config(config: PretrainedConfig):
    """Return the text ("language") sub-config of a multi-modal HF config.

    For pure text models this is a no-op and the config itself is returned.
    """
    arch = config.architectures[0]
    if arch.startswith("Llava") and arch.endswith("ForCausalLM"):
        # Non-hf llava variants are supported as well, so avoid reading the
        # wrong values from the unused default text_config.
        # NOTE(HandH1998): torch_dtype is forced to torch.float16 for the
        # weights, since torch.float16 is the default used for image features
        # in python/sglang/srt/models/llava.py.
        config.torch_dtype = torch.float16
        return config
    if hasattr(config, "text_config"):
        # Downstream code assumes text_config exposes num_attention_heads
        # (among others); assert here to fail early if the transformers
        # config ever diverges from that assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config
    # Some configs name the text sub-config "language_config" instead.
    return getattr(config, "language_config", config)
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
@@ -537,6 +522,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
multimodal_model_archs = [
"CLIPModel",
"DeepseekVL2ForCausalLM",
"Gemma3ForConditionalGeneration",
"Grok1VForCausalLM",
@@ -554,7 +540,6 @@ multimodal_model_archs = [
"MllamaForConditionalGeneration",
"Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
"CLIPModel",
"KimiVLForConditionalGeneration",
"InternVLChatModel",
]