refactor: minor refactors regarding multimodal processing (#6187)
@@ -22,7 +22,11 @@ from typing import List, Optional, Set, Union
 import torch
 from transformers import PretrainedConfig

-from sglang.srt.hf_transformers_utils import get_config, get_context_length
+from sglang.srt.hf_transformers_utils import (
+    get_config,
+    get_context_length,
+    get_hf_text_config,
+)
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_bool_env_var, is_hip
@@ -209,7 +213,13 @@ class ModelConfig:

         # Cache attributes
         self.hf_eos_token_id = self.get_hf_eos_token_id()
-        self.image_token_id = getattr(self.hf_config, "image_token_id", None)
+
+        config = self.hf_config
+
+        # multimodal
+        self.image_token_id = getattr(config, "image_token_id", None) or getattr(
+            config, "image_token_index", None
+        )

     @staticmethod
     def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
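The `or` fallback exists because HF configs are not consistent about this field name: some families (e.g. Qwen2-VL) expose `image_token_id`, while others (e.g. LLaVA-style configs) expose `image_token_index`. A minimal standalone sketch of the lookup, using stand-in objects rather than real `PretrainedConfig` instances:

    from types import SimpleNamespace
    from typing import Optional

    def resolve_image_token_id(config) -> Optional[int]:
        # Prefer `image_token_id`, then fall back to `image_token_index`.
        # Caveat: `or` would also skip over an id of 0, though in practice
        # no image token uses id 0.
        return getattr(config, "image_token_id", None) or getattr(
            config, "image_token_index", None
        )

    # LLaVA-style config: only `image_token_index` is present.
    assert resolve_image_token_id(SimpleNamespace(image_token_index=32000)) == 32000
    # Qwen2-VL-style config: `image_token_id` is present.
    assert resolve_image_token_id(SimpleNamespace(image_token_id=151655)) == 151655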
@@ -423,31 +433,6 @@ class ModelConfig:
             self.model_path = client.get_local_dir()


-def get_hf_text_config(config: PretrainedConfig):
-    """Get the "sub" config relevant to llm for multi modal models.
-    No op for pure text models.
-    """
-    class_name = config.architectures[0]
-    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
-        # We support non-hf version of llava models, so we do not want to
-        # read the wrong values from the unused default text_config.
-        # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
-        # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
-        setattr(config, "torch_dtype", torch.float16)
-        return config
-
-    if hasattr(config, "text_config"):
-        # The code operates under the assumption that text_config should have
-        # `num_attention_heads` (among others). Assert here to fail early
-        # if transformers config doesn't align with this assumption.
-        assert hasattr(config.text_config, "num_attention_heads")
-        return config.text_config
-    if hasattr(config, "language_config"):
-        return config.language_config
-    else:
-        return config
-
-
 # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
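`get_hf_text_config` is deleted here because it moved into `sglang.srt.hf_transformers_utils` and is now imported at the top of the file (first hunk); its behavior is unchanged. A short usage sketch, assuming `sglang` and `transformers` are installed (fetching the config needs network access, and the checkpoint name is only an example):

    from transformers import AutoConfig

    from sglang.srt.hf_transformers_utils import get_hf_text_config

    # Multimodal configs nest the LLM settings under `text_config`;
    # the helper unwraps them so callers can read fields such as
    # `num_attention_heads` uniformly. Pure text configs pass through.
    config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
    text_config = get_hf_text_config(config)
    print(text_config.num_attention_heads)  # read from the nested text config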
@@ -537,6 +522,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal


 multimodal_model_archs = [
+    "CLIPModel",
     "DeepseekVL2ForCausalLM",
     "Gemma3ForConditionalGeneration",
     "Grok1VForCausalLM",
@@ -554,7 +540,6 @@ multimodal_model_archs = [
     "MllamaForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
-    "CLIPModel",
     "KimiVLForConditionalGeneration",
     "InternVLChatModel",
 ]
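The two `"CLIPModel"` edits simply move the entry to its alphabetical position at the head of the list. This registry is what lets the runtime branch into the multimodal path for a checkpoint; a sketch of the kind of membership check it supports — the helper name `is_multimodal_model` is illustrative, not necessarily the exact API:

    from typing import List

    multimodal_model_archs = [
        "CLIPModel",
        "DeepseekVL2ForCausalLM",
        # ... remaining entries from the list above ...
        "InternVLChatModel",
    ]

    def is_multimodal_model(model_architectures: List[str]) -> bool:
        # A model is treated as multimodal if any architecture declared
        # in its config appears in the registry.
        return any(arch in multimodal_model_archs for arch in model_architectures)

    assert is_multimodal_model(["InternVLChatModel"])
    assert not is_multimodal_model(["LlamaForCausalLM"])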