From 01dd39bac1b5e73f609e59edf7f107d1ca044edc Mon Sep 17 00:00:00 2001 From: Mick Date: Sun, 18 May 2025 13:53:20 +0800 Subject: [PATCH] refactor: minor refactors regarding multimodal processing (#6187) --- python/sglang/srt/configs/model_config.py | 41 +++++---------- python/sglang/srt/hf_transformers_utils.py | 52 ++++++++++++++++--- python/sglang/srt/layers/attention/vision.py | 2 +- python/sglang/srt/managers/io_struct.py | 10 +++- python/sglang/srt/managers/mm_utils.py | 23 ++++++-- .../multimodal_processors/base_processor.py | 34 ++++++------ .../managers/multimodal_processors/minicpm.py | 28 ---------- python/sglang/srt/managers/schedule_batch.py | 6 ++- .../sglang/srt/managers/tokenizer_manager.py | 18 ++++--- python/sglang/srt/mm_utils.py | 10 ++++ .../sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/models/minicpmo.py | 7 ++- python/sglang/srt/models/mllama.py | 1 - python/sglang/srt/models/qwen2_5_vl.py | 2 + test/srt/test_vision_chunked_prefill.py | 2 +- 15 files changed, 140 insertions(+), 98 deletions(-) diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index a3c9e4ed9..043ef7a0f 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -22,7 +22,11 @@ from typing import List, Optional, Set, Union import torch from transformers import PretrainedConfig -from sglang.srt.hf_transformers_utils import get_config, get_context_length +from sglang.srt.hf_transformers_utils import ( + get_config, + get_context_length, + get_hf_text_config, +) from sglang.srt.layers.quantization import QUANTIZATION_METHODS from sglang.srt.server_args import ServerArgs from sglang.srt.utils import get_bool_env_var, is_hip @@ -209,7 +213,13 @@ class ModelConfig: # Cache attributes self.hf_eos_token_id = self.get_hf_eos_token_id() - self.image_token_id = getattr(self.hf_config, "image_token_id", None) + + config = self.hf_config + + # multimodal + self.image_token_id = getattr(config, "image_token_id", None) or getattr( + config, "image_token_index", None + ) @staticmethod def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs): @@ -423,31 +433,6 @@ class ModelConfig: self.model_path = client.get_local_dir() -def get_hf_text_config(config: PretrainedConfig): - """Get the "sub" config relevant to llm for multi modal models. - No op for pure text models. - """ - class_name = config.architectures[0] - if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"): - # We support non-hf version of llava models, so we do not want to - # read the wrong values from the unused default text_config. - # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as - # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`. - setattr(config, "torch_dtype", torch.float16) - return config - - if hasattr(config, "text_config"): - # The code operates under the assumption that text_config should have - # `num_attention_heads` (among others). Assert here to fail early - # if transformers config doesn't align with this assumption. - assert hasattr(config.text_config, "num_attention_heads") - return config.text_config - if hasattr(config, "language_config"): - return config.language_config - else: - return config - - # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, @@ -537,6 +522,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal multimodal_model_archs = [ + "CLIPModel", "DeepseekVL2ForCausalLM", "Gemma3ForConditionalGeneration", "Grok1VForCausalLM", @@ -554,7 +540,6 @@ multimodal_model_archs = [ "MllamaForConditionalGeneration", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration", - "CLIPModel", "KimiVLForConditionalGeneration", "InternVLChatModel", ] diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 48fa3d56d..33ec9ce09 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -19,6 +19,7 @@ import warnings from pathlib import Path from typing import Dict, Optional, Type, Union +import torch from huggingface_hub import snapshot_download from transformers import ( AutoConfig, @@ -65,6 +66,43 @@ def download_from_hf(model_path: str): return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"]) +def get_hf_text_config(config: PretrainedConfig): + """Get the "sub" config relevant to llm for multi modal models. + No op for pure text models. + """ + if config.architectures is not None: + class_name = config.architectures[0] + if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"): + # We support non-hf version of llava models, so we do not want to + # read the wrong values from the unused default text_config. + # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as + # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`. + setattr(config, "torch_dtype", torch.float16) + return config + + if hasattr(config, "text_config"): + # The code operates under the assumption that text_config should have + # `num_attention_heads` (among others). Assert here to fail early + # if transformers config doesn't align with this assumption. + assert hasattr(config.text_config, "num_attention_heads") + return config.text_config + if hasattr(config, "language_config"): + return config.language_config + if hasattr(config, "thinker_config"): + # qwen2.5 omni + thinker_config = config.thinker_config + if hasattr(thinker_config, "text_config"): + setattr( + thinker_config.text_config, + "torch_dtype", + getattr(thinker_config, "torch_dtype", None), + ) + return thinker_config.text_config + return thinker_config + else: + return config + + def get_config( model: str, trust_remote_code: bool, @@ -80,13 +118,12 @@ def get_config( config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) + text_config = get_hf_text_config(config=config) - # FIXME: Pour contents of janus-pro's langauge_config to first-level - if isinstance(model, str) and model.lower().startswith("deepseek-ai/janus-pro"): - assert hasattr(config, "language_config") - for key, val in config.language_config.__dict__.items(): - setattr(config, key, val) - setattr(config, "architectures", ["MultiModalityCausalLM"]) + if isinstance(model, str) and text_config is not None: + for key, val in text_config.__dict__.items(): + if not hasattr(config, key) and getattr(text_config, key, None) is not None: + setattr(config, key, val) if config.model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[config.model_type] @@ -99,6 +136,9 @@ def get_config( if not hasattr(config, key): setattr(config, key, val) + if config.model_type == "multi_modality": + config.update({"architectures": ["MultiModalityCausalLM"]}) + if model_override_args: config.update(model_override_args) diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 429787ec8..f1f45e27a 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -120,7 +120,7 @@ class VisionSdpaAttention(nn.Module): flatten_batch: bool = False, ) -> Optional[torch.Tensor]: r""" - Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, s, s)`. + Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`. Args: s: sequence length cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index dfb3b6eb2..5734cd95c 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -22,13 +22,15 @@ from dataclasses import dataclass, field from enum import Enum from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union +from sglang.srt.mm_utils import has_valid_data + # handle serialization of Image for pydantic if TYPE_CHECKING: from PIL.Image import Image else: Image = Any -from sglang.srt.managers.schedule_batch import BaseFinishReason +from sglang.srt.managers.schedule_batch import BaseFinishReason, flatten_nested_list from sglang.srt.sampling.sampling_params import SamplingParams @@ -104,6 +106,9 @@ class GenerateReqInput: bootstrap_port: Optional[Union[List[int], int]] = None bootstrap_room: Optional[Union[List[int], int]] = None + def contains_mm_input(self) -> bool: + return has_valid_data(self.image_data) or has_valid_data(self.audio_data) + def normalize_batch_and_arguments(self): """ Normalize the batch size and arguments for the request. @@ -487,6 +492,9 @@ class EmbeddingReqInput: # The modalities of the image data [image, multi-images, video] modalities: Optional[List[str]] = None + def contains_mm_input(self) -> bool: + return has_valid_data(self.image_data) or has_valid_data(self.audio_data) + def normalize_batch_and_arguments(self): # at least one of text, input_ids, or image should be provided if self.text is None and self.input_ids is None and self.image_data is None: diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 2c8cad5ac..5a3392661 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -2,6 +2,7 @@ Multi-modality utils """ +import dataclasses import logging from abc import abstractmethod from typing import Callable, List, Optional, Tuple @@ -41,11 +42,26 @@ class MultiModalityDataPaddingPattern: class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern): """In this pattern, data tokens should be enclosed by special token pairs (e.g. ..., data_token_pairs) + The padded value in a region enclosed by a token pair with be the same one, as the MultimodalDataItem's pad value + This strategy should be applied when data content is marked by start/end token pairs in the input sequence. """ - def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None: + def __init__( + self, + data_token_pairs: Optional[List[Tuple[int, int]]], + data_start_token_ids: Optional[List[int]] = None, + ) -> None: + """ + + Args: + data_start_token_ids marks the start of a single multimodal data + See Minicpmo's slice_start_id for example + """ self.data_token_id_pairs = data_token_pairs + self.data_start_token_ids = data_start_token_ids or [ + s for s, _e in data_token_pairs + ] def pad_input_tokens( self, input_ids: List[int], mm_inputs: MultimodalInputs @@ -79,7 +95,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern) for start_idx, end_idx in zip(start_indices, end_indices): padded_ids.extend(input_ids[last_idx : start_idx + 1]) - if input_ids[start_idx] in start_token_ids: + if input_ids[start_idx] in self.data_start_token_ids: data_idx += 1 mm_inputs.data_offsets += [start_idx] @@ -170,7 +186,6 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa output_ids_tensor[start_idx:end_idx] = pad_value else: logger.warning(f"Skipping region {i} due to None pad_value.") - return output_ids_tensor.tolist() @@ -202,7 +217,7 @@ def get_embedding_and_mask( num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item() if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding: logger.warning( - f"Number of tokens in multimodal embedding does not match those in the input text." + f"Number of tokens in multimodal embedding does not match those in the input text. " f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} " "tokens from multimodal embeddings." ) diff --git a/python/sglang/srt/managers/multimodal_processors/base_processor.py b/python/sglang/srt/managers/multimodal_processors/base_processor.py index a6070cc0f..b957adf4b 100644 --- a/python/sglang/srt/managers/multimodal_processors/base_processor.py +++ b/python/sglang/srt/managers/multimodal_processors/base_processor.py @@ -36,9 +36,21 @@ class BaseMultiModalProcessorOutput: @dataclasses.dataclass class MultimodalSpecialTokens: - image_token: Optional[str] = None - video_token: Optional[str] = None - audio_token: Optional[str] = None + image_token: Optional[Union[int, str, List[str]]] = None + video_token: Optional[Union[int, str, List[str]]] = None + audio_token: Optional[Union[int, str, List[str]]] = None + + def convert_to_str(self, token: Union[str, int], processor) -> str: + if token is None: + return token + if isinstance(token, str): + return token + return processor.tokenizer.convert_ids_to_tokens([token])[0] + + def convert_to_strs(self, processor): + self.image_token = self.convert_to_str(self.image_token, processor) + self.video_token = self.convert_to_str(self.video_token, processor) + self.audio_token = self.convert_to_str(self.audio_token, processor) image_token_regex: Optional[re.Pattern] = None video_token_regex: Optional[re.Pattern] = None @@ -74,6 +86,7 @@ class BaseMultimodalProcessor(ABC): def __init__(self, hf_config, server_args, _processor): self.hf_config = hf_config self._processor = _processor + self.arch = hf_config.architectures[0] self.server_args = server_args # FIXME: not accurate, model and image specific self.NUM_TOKEN_PER_FRAME = 330 @@ -260,19 +273,10 @@ class BaseMultimodalProcessor(ABC): """ if not return_text: raise NotImplementedError() - if image_data is None: image_data = [] - if isinstance(multimodal_tokens.image_token, int): - multimodal_tokens.image_token = re.compile( - re.escape( - self._processor.tokenizer.convert_ids_to_tokens( - multimodal_tokens.image_token - ) - ) - ) - else: - multimodal_tokens.image_token = multimodal_tokens.image_token + + multimodal_tokens.convert_to_strs(self._processor) multimodal_tokens_pattern = multimodal_tokens.collect() if isinstance(prompt, list) and return_text: @@ -332,9 +336,9 @@ class BaseMultimodalProcessor(ABC): new_text += text_part out = BaseMultiModalProcessorOutput( + input_text=new_text, images=images, audios=audios, - input_text=new_text, ) out.normalize() return out diff --git a/python/sglang/srt/managers/multimodal_processors/minicpm.py b/python/sglang/srt/managers/multimodal_processors/minicpm.py index 35b41bab4..f6611ac79 100644 --- a/python/sglang/srt/managers/multimodal_processors/minicpm.py +++ b/python/sglang/srt/managers/multimodal_processors/minicpm.py @@ -1,7 +1,6 @@ from typing import List, Union import torch -from transformers import BaseImageProcessorFast from sglang.srt.managers.multimodal_processors.base_processor import ( BaseMultimodalProcessor, @@ -21,33 +20,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor): self.image_token = "(./)" self.audio_token = "()" - def process_data_task(self, input_text, images=None, audios=None): - - if isinstance(images, list) and len(images) == 0: - images = None - if isinstance(audios, list) and len(audios) == 0: - audios = None - processor = self._processor - args = {} - if isinstance(processor, BaseImageProcessorFast): - args["device"] = "cuda" - result = self._processor.__call__( - text=input_text, - images=images, - audios=audios, - return_tensors="pt", - chunk_input=True, - **args, - ) - return { - "input_ids": result.input_ids, - "pixel_values": getattr(result, "pixel_values", None), - "tgt_sizes": getattr(result, "tgt_sizes", None), - "audio_features": getattr(result, "audio_features", None), - "audio_feature_lens": getattr(result, "audio_feature_lens", None), - "audio_bounds": getattr(result, "audio_bounds", None), - } - async def process_mm_data_async( self, image_data: List[Union[str, bytes]], diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 83479cd59..836335136 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -324,8 +324,9 @@ class MultimodalInputs: video_token_id: Optional[int] = None # audio - audio_start_id: Optional[torch.Tensor] = None - audio_end_id: Optional[torch.Tensor] = None + audio_token_id: Optional[int] = None + audio_start_id: Optional[int] = None + audio_end_id: Optional[int] = None @staticmethod def from_dict(obj: dict): @@ -349,6 +350,7 @@ class MultimodalInputs: "slice_end_id", "audio_start_id", "audio_end_id", + "audio_token_id", ] for arg in optional_args: if arg in obj: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 53e422718..1a81f498b 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -459,14 +459,16 @@ class TokenizerManager: ) input_ids = self.tokenizer.encode(input_text) - image_inputs: Dict = await self.mm_processor.process_mm_data_async( - image_data=obj.image_data, - input_text=input_text or input_ids, - request_obj=obj, - max_req_input_len=self.max_req_input_len, - ) - if image_inputs and "input_ids" in image_inputs: - input_ids = image_inputs["input_ids"] + image_inputs: Optional[Dict] = None + if obj.contains_mm_input(): + image_inputs = await self.mm_processor.process_mm_data_async( + image_data=obj.image_data, + input_text=input_text or input_ids, + request_obj=obj, + max_req_input_len=self.max_req_input_len, + ) + if image_inputs and "input_ids" in image_inputs: + input_ids = image_inputs["input_ids"] self._validate_token_len(obj, input_ids) return self._create_tokenized_object( diff --git a/python/sglang/srt/mm_utils.py b/python/sglang/srt/mm_utils.py index 040764906..9c05c1859 100644 --- a/python/sglang/srt/mm_utils.py +++ b/python/sglang/srt/mm_utils.py @@ -36,6 +36,16 @@ from io import BytesIO import numpy as np from PIL import Image +from sglang.srt.utils import flatten_nested_list + + +def has_valid_data(data) -> bool: + if data is None: + return False + if isinstance(data, list): + return any(has_valid_data(item) for item in flatten_nested_list(data)) + return True + def select_best_resolution(original_size, possible_resolutions): """ diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a034cbbcd..656fc86eb 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1165,7 +1165,7 @@ class ModelRunner: def model_is_mrope(self) -> bool: """Detect if the model has "mrope" rope_scaling type. mrope requires keep "rope_deltas" between prompt and decoding phases.""" - rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {}) + rope_scaling = getattr(self.model_config.hf_text_config, "rope_scaling", {}) if rope_scaling is None: return False is_mrope_enabled = "mrope_section" in rope_scaling diff --git a/python/sglang/srt/models/minicpmo.py b/python/sglang/srt/models/minicpmo.py index 75df1dfae..7199da4f1 100644 --- a/python/sglang/srt/models/minicpmo.py +++ b/python/sglang/srt/models/minicpmo.py @@ -1520,12 +1520,15 @@ class MiniCPMO(MiniCPMBaseModel): slice_start_id: int = mm_input.slice_start_id slice_end_id: int = mm_input.slice_end_id - media_token_pairs = [ + data_token_pairs = [ (im_start_id, im_end_id), (slice_start_id, slice_end_id), (mm_input.audio_start_id, mm_input.audio_end_id), ] - pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) + data_start_token_ids = [im_start_id, mm_input.audio_start_id] + pattern = MultiModalityDataPaddingPatternTokenPairs( + data_token_pairs=data_token_pairs, data_start_token_ids=data_start_token_ids + ) return pattern.pad_input_tokens(input_ids, mm_input) diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 8d63d9cfc..6439f9327 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -865,7 +865,6 @@ class MllamaForConditionalGeneration(nn.Module): pixel_values = torch.cat( [item.pixel_values for item in mm_input.mm_items], dim=0 ) - # max_num_images = max(max_num_images, sum(1 if item.is_image() else 0 for item in mm_input.items)) max_num_images = max(max_num_images, pixel_values.shape[1]) max_num_tiles = max(max_num_tiles, pixel_values.shape[2]) diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 7cc24f182..420216c7b 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -146,6 +146,8 @@ class Qwen2_5_VisionBlock(nn.Module): num_heads=num_heads, projection_size=dim, use_qkv_parallel=True, + rotary_embed="normal", + proj_bias=True, qkv_backend=qkv_backend, softmax_in_single_precision=softmax_in_single_precision, flatten_batch=flatten_batch, diff --git a/test/srt/test_vision_chunked_prefill.py b/test/srt/test_vision_chunked_prefill.py index 7c8f21107..e41759e7b 100644 --- a/test/srt/test_vision_chunked_prefill.py +++ b/test/srt/test_vision_chunked_prefill.py @@ -147,8 +147,8 @@ class TestVisionChunkedPrefill(CustomTestCase): def _test_chunked_prefill(self, batches, num_frames): # Chunked + chunked_server_pid = self.launch_server(chunked_prefill_size=1024) try: - chunked_server_pid = self.launch_server(chunked_prefill_size=1024) outputs_chunked = [] for batch, num_frame in zip(batches, num_frames): output_chunked = self.generate_for_video(