diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 7d3e4131e..982cae8dd 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -116,6 +116,10 @@ class ModelConfig: self.is_audio_model = enable_multimodal and is_audio_model( self.hf_config.architectures ) + self.is_multimodal_chunked_prefill_supported = ( + enable_multimodal + and is_multimodal_chunked_prefill_supported(self.hf_config.architectures) + ) self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) @@ -574,6 +578,21 @@ def is_encoder_decoder_model(model_architectures: List[str]): return "MllamaForConditionalGeneration" in model_architectures +def is_multimodal_chunked_prefill_supported(model_architectures: List[str]): + """Check if chunked prefill is supported for a MultiModal model.""" + unsupported = [ + "Grok1VForCausalLM", + "Grok1AForCausalLM", + "LlavaLlamaForCausalLM", + "MllamaForConditionalGeneration", + "CLIPModel", + ] + if any(multi_model_arch in unsupported for multi_model_arch in model_architectures): + return False + else: + return True + + def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: if scale <= 1: return 1.0 diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 5a3392661..f39f730e3 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -16,10 +16,15 @@ from sglang.srt.managers.schedule_batch import ( MultimodalInputs, global_server_args_dict, ) +from sglang.srt.mem_cache.multimodal_cache import MultiModalCache from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import flatten_nested_list, print_warning_once +from sglang.utils import logger -logger = logging.getLogger(__name__) +# NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger +# to ensure consistent logging behavior across the codebase. This prevents issues with log +# propagation that can cause some log messages (like 'server is fired up') to not appear +# in the console when multimodal support is enabled. class MultiModalityDataPaddingPattern: @@ -189,26 +194,137 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa return output_ids_tensor.tolist() +embedding_cache = None + + +def init_embedding_cache(max_size: int): + global embedding_cache + embedding_cache = MultiModalCache(max_size) + + +def get_embedding_hash(embedding_items: List[MultimodalDataItem]) -> int: + hash_list = [item.hash for item in embedding_items] + return hash(tuple(hash_list)) + + +def get_embedding_chunk( + embedding: torch.Tensor, + extend_prefix_len: int, + extend_seq_len: int, + items_offset: List[Tuple[int, int]], +) -> Tuple[torch.Tensor, int, int]: + """ + Extract a chunk of embeddings based on the specified prefix length, sequence length, and offset ranges. 
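+
+    For example (hypothetical numbers): with items_offset=[(4, 9)], a call with
+    extend_prefix_len=0 and extend_seq_len=8 covers placeholder positions 4..7
+    and returns embedding[0:4] with start_index=0 and end_index=4; the next
+    chunked-prefill pass (extend_prefix_len=8, extend_seq_len=4) returns the
+    remaining rows embedding[4:6].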
+ + Args: + embedding: The full embedding tensor to extract a chunk from + extend_prefix_len: The starting position (prefix length) for extraction + extend_seq_len: The number of tokens to extract + items_offset: List of [start, end] offset ranges for multimodal items in the input sequence + + Returns: + A tuple containing: + - The extracted embedding chunk as a tensor + - The start index used for extraction + - The end index used for extraction + + Note: + If there's no overlap between the requested range and the offset ranges, + an empty tensor is returned with zeros for start and end indices. + """ + start_index, end_index = 0, 0 + extend_start_index = extend_prefix_len + extend_end_index = extend_prefix_len + extend_seq_len - 1 + + for start, end in items_offset: + if extend_start_index >= start and extend_start_index <= end: + start_index += extend_start_index - start + elif extend_start_index > end: + start_index += end - start + 1 + + if extend_end_index >= start and extend_end_index <= end: + end_index += extend_end_index - start + 1 + elif extend_end_index > end: + end_index += end - start + 1 + # some models embedding is 3-dim, reshape it to 2-dim + embedding = embedding.reshape(-1, embedding.shape[-1]) + embedding_chunk = embedding[start_index:end_index] + return embedding_chunk, start_index, end_index + + def get_embedding_and_mask( data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor], embedding_items: List[MultimodalDataItem], placeholder_tensor: torch.Tensor, input_ids: torch.Tensor, -): + items_size: List[int], + prefix_length: List[int], + extend_length: List[int], + items_offset_list: List[List[Tuple[int, int]]], +) -> Tuple[torch.Tensor, torch.Tensor]: """ - Get the multimodal embedding and its mask from input_ids + Generate multimodal embeddings and create a mask for identifying their positions in the input sequence. + Args: + data_embedding_func: Function that generates embeddings for multimodal items + embedding_items: List of multimodal items to embed + placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content + input_ids: The input token IDs tensor + items_size: Cumulative sizes of multimodal items per request + prefix_length: Prefix lengths for each request + extend_length: Sequence lengths for each request + items_offset_list: List of offset ranges for multimodal items in each request + + Returns: + A tuple containing: + - The generated embeddings tensor + - A boolean mask tensor indicating where these embeddings should be placed + + Raises: + AssertionError: If the number of multimodal tokens in input_ids doesn't match + the number of tokens in the generated embeddings """ # 1. 
Get the embedding - embedding = data_embedding_func(embedding_items) + # Calculate embedding for each request, try to get it from cache to avoid repeated calculation + embedding_list = [] + for i in range(len(items_size) - 1): + if items_size[i] == items_size[i + 1]: + continue + embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]] + items_offset = items_offset_list[i] + embedding_items_hash = get_embedding_hash(embedding_items_per_req) + # if all items has been prefixed, we do not need to calculate embedding + if all([offset_end < prefix_length[i] for _, offset_end in items_offset]): + continue + embedding_per_req = embedding_cache.get(embedding_items_hash) + if embedding_per_req is None: + embedding_per_req = data_embedding_func(embedding_items_per_req) + if not embedding_cache.put(embedding_items_hash, embedding_per_req): + print_warning_once( + "Multimodal embedding cache is full. Consider increasing the " + "`SGLANG_VLM_CACHE_SIZE_MB` environment variable." + ) + embedding_per_req_chunk, _, end_index = get_embedding_chunk( + embedding=embedding_per_req, + extend_prefix_len=prefix_length[i], + extend_seq_len=extend_length[i], + items_offset=items_offset, + ) + # remove this item from cache if chunk reaches to the end + embedding_per_req_length = ( + embedding_per_req.shape[0] + if embedding_per_req.dim() == 2 + else embedding_per_req.shape[0] * embedding_per_req.shape[1] + ) + if end_index == embedding_per_req_length: + embedding_cache.free(embedding_items_hash) + embedding_list.append(embedding_per_req_chunk) + if len(embedding_list) == 0: + return None, None + embedding = torch.concat(embedding_list, dim=0) # 2. Check the embedding - if embedding.dim() == 2: - num_mm_tokens_in_embedding = embedding.shape[0] - else: - num_mm_tokens_in_embedding = embedding.shape[0] * embedding.shape[1] - - # the mask of multimodal tokens from input_ids + num_mm_tokens_in_embedding = embedding.shape[0] special_multimodal_mask = torch.isin( input_ids, placeholder_tensor, @@ -222,9 +338,6 @@ def get_embedding_and_mask( "tokens from multimodal embeddings." ) if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding: - # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding - # a fix may be cache the unfinished multimodal embedding for future reuse, determine the tokens to embed with - # extend_start_loc and extend_seq_lens chunked_prefill_size = global_server_args_dict["chunked_prefill_size"] if chunked_prefill_size != -1: logger.warning( @@ -245,7 +358,9 @@ def get_embedding_and_mask( def embed_mm_inputs( - mm_inputs: MultimodalInputs, + mm_inputs_list: List[MultimodalInputs], + extend_prefix_lens: List[int], + extend_seq_lens: List[int], input_ids: torch.Tensor, input_embedding: nn.Embedding, image_data_embedding_func: Callable[ @@ -257,125 +372,133 @@ def embed_mm_inputs( placeholder_tokens: dict[Modality, List[int]] = None, ) -> Optional[torch.Tensor]: """ - Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations + Embed multimodal inputs and integrate them with text token embeddings. - Args: - placeholder_tokens: denoting the token of multimodal data in input_ids. 
- If none, the pad_values of multimodal items are used + Args: + mm_inputs_list: List of multimodal inputs to process + extend_prefix_lens: Prefix lengths for each request + extend_seq_lens: Sequence lengths for each request + input_ids: Input token IDs tensor + input_embedding: Embedding layer for text tokens + image_data_embedding_func: Function to embed image data + audio_data_embedding_func: Function to embed audio data + placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None) - Returns: - final embedding: Optional[torch.Tensor] + Returns: + Combined embedding tensor with multimodal content integrated """ - if mm_inputs is None: + if mm_inputs_list is None: return None # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values # we assume that multimodal data are represented with its pad_values in input_ids - # See `pad_input_ids` for more detail + item_flatten_list = [] + for mm_inputs in mm_inputs_list: + item_flatten_list += [item for item in mm_inputs.mm_items if item is not None] - # if placeholder_tokens is specified - if placeholder_tokens is not None: - placeholder_token_ids = flatten_nested_list( - [placeholder_token for placeholder_token in placeholder_tokens.values()] + embeddings, masks = [], [] + + # 2. Get multimodal embedding separately + # TODO: make this more generic + # Try get image embedding if any + if ( + any(True for item in item_flatten_list if item.is_image()) + and image_data_embedding_func + ): + items = [item for item in item_flatten_list if item.is_image()] + placeholder_tensor = torch.tensor( + [item.pad_value for item in items], + device=input_ids.device, ) - else: - placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items] - - assert isinstance(placeholder_token_ids[0], int) - - placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device) - - placeholder_masks = torch.isin(input_ids, placeholder_tensor) - - appearing_pad_values = torch.unique( - input_ids[placeholder_masks], return_counts=False - ) - - if appearing_pad_values.numel() == 0: - # all been prefixed - inputs_embeds = input_embedding(input_ids) - else: - appearing_items = [ - item - for item in mm_inputs.mm_items - if item.pad_value is not None and item.pad_value in appearing_pad_values - ] - - using_all_items = False - if len(appearing_items) == 0: - # This happens mostly when arg placeholder_token_ids is passed - logger.warning( - "No multimodal data item's pad value exist in placeholder ids. Using all items" + # calculate per request items length offset + items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int) + items_offsets = [] + for i, mm_inputs in enumerate(mm_inputs_list): + image_items = [item for item in mm_inputs.mm_items if item.is_image()] + items_size[i + 1] = len(image_items) + items_offsets.append( + flatten_nested_list( + [ + item.image_offsets + for item in mm_inputs.mm_items + if item.is_image() + ] + ) ) - using_all_items = True - appearing_items = mm_inputs.mm_items + items_size = torch.cumsum(items_size, dim=0).tolist() - embeddings, masks = [], [] + embedding, mask = get_embedding_and_mask( + data_embedding_func=image_data_embedding_func, + embedding_items=items, + placeholder_tensor=placeholder_tensor, + input_ids=input_ids, + items_size=items_size, + prefix_length=extend_prefix_lens, + extend_length=extend_seq_lens, + items_offset_list=items_offsets, + ) + embeddings += [embedding] + masks += [mask] - # 2. 
Get multimodal embedding separately - # TODO: make this more generic - # Try get image embedding if any - if ( - any(True for item in appearing_items if item.is_image()) - and image_data_embedding_func - ): - items = [item for item in appearing_items if item.is_image()] - embedding, mask = get_embedding_and_mask( - data_embedding_func=image_data_embedding_func, - embedding_items=items, - placeholder_tensor=( - # use the specified modality token to identify the location to embed - placeholder_tokens[Modality.IMAGE] - if using_all_items - else torch.tensor( - [item.pad_value for item in items], - device=input_ids.device, - ) - ), - input_ids=input_ids, + # Try get audio embedding if any + if ( + any(True for item in item_flatten_list if item.is_audio()) + and audio_data_embedding_func + ): + items = [item for item in item_flatten_list if item.is_audio()] + placeholder_tensor = torch.tensor( + [item.pad_value for item in items], + device=input_ids.device, + ) + items_offsets = [] + # calculate per request items length offset + items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int) + for i, mm_inputs in enumerate(mm_inputs_list): + audio_items = [item for item in mm_inputs.mm_items if item.is_audio()] + items_size[i + 1] = len(audio_items) + items_offsets.append( + flatten_nested_list( + [ + item.audio_offsets + for item in mm_inputs.mm_items + if item.is_audio() + ] + ) ) - embeddings += [embedding] - masks += [mask] + items_size = torch.cumsum(items_size, dim=0) - # Try get audio embedding if any - if ( - any(True for item in appearing_items if item.is_audio()) - and audio_data_embedding_func - ): - items = [item for item in appearing_items if item.is_audio()] - embedding, mask = get_embedding_and_mask( - data_embedding_func=audio_data_embedding_func, - embedding_items=items, - placeholder_tensor=( - placeholder_tokens[Modality.AUDIO] - if using_all_items - else torch.tensor( - [item.pad_value for item in items], - device=input_ids.device, - ) - ), - input_ids=input_ids, - ) - embeddings += [embedding] - masks += [mask] + embedding, mask = get_embedding_and_mask( + data_embedding_func=audio_data_embedding_func, + embedding_items=items, + placeholder_tensor=placeholder_tensor, + input_ids=input_ids, + items_size=items_size, + prefix_length=extend_prefix_lens, + extend_length=extend_seq_lens, + items_offset_list=items_offsets, + ) + embeddings += [embedding] + masks += [mask] - # 3. Get input embeddings - vocab_size = input_embedding.num_embeddings - # Important: clamp after getting original multimodal regions - # Clamp input ids. This is because the input_ids for the multimodal tokens are - # filled with the hash values of the multimodal for the prefix matching in the radix attention. - # There values are useless because their embeddings will be replaced by vision embeddings anyway. - input_ids.clamp_(min=0, max=vocab_size - 1) - inputs_embeds = input_embedding(input_ids) + # 3. Get input embeddings + vocab_size = input_embedding.num_embeddings + # Important: clamp after getting original multimodal regions + # Clamp input ids. This is because the input_ids for the multimodal tokens are + # filled with the hash values of the multimodal for the prefix matching in the radix attention. + # There values are useless because their embeddings will be replaced by vision embeddings anyway. + input_ids.clamp_(min=0, max=vocab_size - 1) + inputs_embeds = input_embedding(input_ids) - # 4. 
Scatter embeddings into input embedding - for embedding, mask in zip(embeddings, masks): - mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device) - inputs_embeds = inputs_embeds.masked_scatter( - mask, - embedding.to(inputs_embeds.device, inputs_embeds.dtype), - ) + # 4. scatter embeddings into input embedding + for embedding, mask in zip(embeddings, masks): + if embedding is None or mask is None: + continue + mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter( + mask, + embedding.to(inputs_embeds.device, inputs_embeds.dtype), + ) return inputs_embeds @@ -393,16 +516,19 @@ def general_mm_embed_routine( **kwargs, ) -> torch.Tensor: """ - A general wrapper function to get final input embeds from multimodal models with a language model as causal model + Process multimodal inputs and forward through language model. - Args: - placeholder_token_ids (List[int]): the ids of mm data placeholder tokens - image_data_embedding_func : the function returning the image embedding - audio_data_embedding_func : the function returning the image embedding - - Returns: - forwarded hidden states + Args: + input_ids: Input token IDs tensor + forward_batch: Batch information for model forward pass + language_model: Base language model to use + image_data_embedding_func: Function to embed image data + audio_data_embedding_func: Function to embed audio data + placeholder_tokens: Token IDs for multimodal placeholders + **kwargs: Additional arguments passed to language model + Returns: + Hidden states from language model forward pass """ assert hasattr(language_model, "get_input_embeddings") embed_tokens = language_model.get_input_embeddings() @@ -410,9 +536,23 @@ def general_mm_embed_routine( not forward_batch.forward_mode.is_decode() and forward_batch.contains_mm_inputs() ): - mm_input = forward_batch.merge_mm_inputs() + mm_inputs_list = [ + mm_input for mm_input in forward_batch.mm_inputs if mm_input is not None + ] + extend_prefix_lens = [ + prefix_len + for i, prefix_len in enumerate(forward_batch.extend_prefix_lens_cpu) + if forward_batch.mm_inputs[i] is not None + ] + extend_seq_lens = [ + seq_len + for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu) + if forward_batch.mm_inputs[i] is not None + ] inputs_embeds = embed_mm_inputs( - mm_inputs=mm_input, + mm_inputs_list=mm_inputs_list, + extend_prefix_lens=extend_prefix_lens, + extend_seq_lens=extend_seq_lens, input_ids=input_ids, input_embedding=embed_tokens, image_data_embedding_func=image_data_embedding_func, diff --git a/python/sglang/srt/managers/multimodal_processors/base_processor.py b/python/sglang/srt/managers/multimodal_processors/base_processor.py index b957adf4b..a293d4be4 100644 --- a/python/sglang/srt/managers/multimodal_processors/base_processor.py +++ b/python/sglang/srt/managers/multimodal_processors/base_processor.py @@ -5,7 +5,7 @@ import multiprocessing as mp import os import re from abc import ABC, abstractmethod -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -343,6 +343,33 @@ class BaseMultimodalProcessor(ABC): out.normalize() return out + @staticmethod + def get_mm_items_offset( + input_ids: torch.Tensor, mm_token_id: int + ) -> List[Tuple[int, int]]: + """ + Get a set of range for mm_items from input_ids + Example: + input_ids = [1, 2, 3, 3, 3, 4, 3, 3] + mm_token_id = 3 + return result = [(2,4),(6,7)] + """ + mask = input_ids == mm_token_id + + start_positions = (mask & 
~torch.roll(mask, 1)).nonzero(as_tuple=True)[0] + end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0] + + return list(zip(start_positions.tolist(), end_positions.tolist())) + + @staticmethod + def get_mm_items_offset_by_pair( + input_ids: torch.Tensor, mm_start_id: int, mm_end_id: int + ) -> List[Tuple[int, int]]: + indices_start = (input_ids == mm_start_id).nonzero(as_tuple=True)[0] + 1 + indices_end = (input_ids == mm_end_id).nonzero(as_tuple=True)[0] - 1 + + return list(zip(indices_start.tolist(), indices_end.tolist())) + def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]): """Returns true if all images are preprocessed, false if all are not, and error otherwise.""" if not mm_inputs: diff --git a/python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py b/python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py index 188778781..8b31f7304 100644 --- a/python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +++ b/python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py @@ -70,8 +70,13 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor): batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0) items = [] + input_ids = res["input_ids"] + image_offsets = self.get_mm_items_offset( + input_ids=input_ids, mm_token_id=self._processor.image_token_id + ) item = MultimodalDataItem( pixel_values=res["images"], + image_offsets=image_offsets, modality=Modality.IMAGE, image_emb_mask=images_seq_mask, image_spatial_crop=batched_images_spatial_crop, @@ -80,6 +85,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor): return { "mm_items": items, - "input_ids": res["input_ids"].tolist(), + "input_ids": input_ids.tolist(), "im_token_id": self._processor.image_token_id, } diff --git a/python/sglang/srt/managers/multimodal_processors/gemma3.py b/python/sglang/srt/managers/multimodal_processors/gemma3.py index 481a31718..1f7846ba9 100644 --- a/python/sglang/srt/managers/multimodal_processors/gemma3.py +++ b/python/sglang/srt/managers/multimodal_processors/gemma3.py @@ -61,6 +61,11 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor): ) items = [] + input_ids = ret["input_ids"].flatten() + image_offsets = self.get_mm_items_offset( + input_ids=input_ids, + mm_token_id=self.hf_config.image_token_index, + ) for i, image in enumerate(base_output.images): if images_are_preprocessed: pixel_values = image.pixel_values @@ -73,12 +78,13 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor): pixel_values=pixel_values, precomputed_features=precomputed_features, modality=Modality.IMAGE, + image_offsets=image_offsets[i], ) items += [item] return { "mm_items": items, - "input_ids": ret["input_ids"].flatten().tolist(), + "input_ids": input_ids.tolist(), "im_start_id": self.IM_START_TOKEN_ID, "im_end_id": self.IM_END_TOKEN_ID, } diff --git a/python/sglang/srt/managers/multimodal_processors/internvl.py b/python/sglang/srt/managers/multimodal_processors/internvl.py index 44143b501..6d7c14c4f 100644 --- a/python/sglang/srt/managers/multimodal_processors/internvl.py +++ b/python/sglang/srt/managers/multimodal_processors/internvl.py @@ -209,7 +209,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor): return None pixel_values = torch.cat(pixel_values, dim=0) - items = [MultimodalDataItem(pixel_values=pixel_values, modality=Modality.IMAGE)] for idx, num_patches in enumerate(num_patches_list): image_tokens = ( @@ -220,10 +219,21 @@ class InternVLImageProcessor(BaseMultimodalProcessor): input_text = 
input_text.replace("", image_tokens, 1) tokenizer = self._processor + input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].flatten() + image_offsets = self.get_mm_items_offset( + input_ids=input_ids, + mm_token_id=self.img_context_token_id, + ) + items = [ + MultimodalDataItem( + pixel_values=pixel_values, + modality=Modality.IMAGE, + image_offsets=image_offsets, + ) + ] + return { - "input_ids": tokenizer(input_text, return_tensors="pt")["input_ids"] - .flatten() - .tolist(), + "input_ids": input_ids.tolist(), "mm_items": items, "im_start_id": self.img_start_token_id, "im_end_id": self.img_end_token_id, diff --git a/python/sglang/srt/managers/multimodal_processors/janus_pro.py b/python/sglang/srt/managers/multimodal_processors/janus_pro.py index cf68703f8..c06aedbc5 100644 --- a/python/sglang/srt/managers/multimodal_processors/janus_pro.py +++ b/python/sglang/srt/managers/multimodal_processors/janus_pro.py @@ -45,15 +45,21 @@ class JanusProImageProcessor(BaseMultimodalProcessor): prompt=base_out.input_text, images=images, ) + + input_ids = res["input_ids"].flatten() + image_offsets = self.get_mm_items_offset( + input_ids=input_ids, mm_token_id=processor.image_id + ) return { "mm_items": [ MultimodalDataItem( pixel_values=res["pixel_values"], image_emb_mask=res["images_emb_mask"], + image_offsets=image_offsets, modality=Modality.IMAGE, ) ], - "input_ids": res["input_ids"].flatten().tolist(), + "input_ids": input_ids.tolist(), "im_start_id": processor.image_start_id, "im_end_id": processor.image_end_id, "im_token_id": processor.image_id, diff --git a/python/sglang/srt/managers/multimodal_processors/kimi_vl.py b/python/sglang/srt/managers/multimodal_processors/kimi_vl.py index 4d596941b..0a276f7ce 100644 --- a/python/sglang/srt/managers/multimodal_processors/kimi_vl.py +++ b/python/sglang/srt/managers/multimodal_processors/kimi_vl.py @@ -1,10 +1,5 @@ -import asyncio -import math from typing import List, Union -import torch -from PIL import Image - from sglang.srt.managers.multimodal_processors.base_processor import ( BaseMultimodalProcessor as SGLangBaseProcessor, ) @@ -57,13 +52,19 @@ class KimiVLImageProcessor(SGLangBaseProcessor): input_text=base_output.input_text, images=base_output.images, ) + input_ids = ret["input_ids"].flatten() + image_offsets = self.get_mm_items_offset( + input_ids=input_ids, + mm_token_id=self.im_token_id, + ) return { - "input_ids": ret["input_ids"].flatten().tolist(), + "input_ids": input_ids.tolist(), "mm_items": [ MultimodalDataItem( pixel_values=ret["pixel_values"], image_grid_thws=ret["image_grid_hws"], modality=Modality.IMAGE, + image_offsets=image_offsets, ) ], "im_token_id": self.im_token_id, diff --git a/python/sglang/srt/managers/multimodal_processors/llava.py b/python/sglang/srt/managers/multimodal_processors/llava.py index 5da52c11e..15968e16d 100644 --- a/python/sglang/srt/managers/multimodal_processors/llava.py +++ b/python/sglang/srt/managers/multimodal_processors/llava.py @@ -1,5 +1,4 @@ import asyncio -import importlib from typing import List, Optional, Union import numpy as np diff --git a/python/sglang/srt/managers/multimodal_processors/minicpm.py b/python/sglang/srt/managers/multimodal_processors/minicpm.py index f6611ac79..dba7245e8 100644 --- a/python/sglang/srt/managers/multimodal_processors/minicpm.py +++ b/python/sglang/srt/managers/multimodal_processors/minicpm.py @@ -69,6 +69,8 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor): audio_start_id = tokenizer.audio_start_id audio_end_id = tokenizer.audio_end_id + 
im_start_id = tokenizer.im_start_id + im_end_id = tokenizer.im_end_id im_token_id = tokenizer.unk_id pixel_values = res["pixel_values"] tgt_sizes = res["tgt_sizes"] @@ -104,9 +106,20 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor): pixel_values = pixel_values_flat items = [] + input_ids = res["input_ids"].flatten() + image_offsets = self.get_mm_items_offset_by_pair( + input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id + ) + slice_offsets = self.get_mm_items_offset_by_pair( + input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id + ) + image_offsets.extend(slice_offsets) + image_offsets = sorted(image_offsets) + if len(pixel_values) != 0: item = MultimodalDataItem( pixel_values=pixel_values, + image_offsets=image_offsets, tgt_size=tgt_sizes_flat, modality=Modality.IMAGE, ) @@ -117,21 +130,30 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor): and res["audio_features"] is not None and len(res["audio_features"]) != 0 ): + if audio_start_id is not None and audio_end_id is not None: + audio_offsets = self.get_mm_items_offset_by_pair( + input_ids=input_ids, + mm_start_id=audio_start_id, + mm_end_id=audio_end_id, + ) + else: + audio_offsets = None item = MultimodalDataItem( audio_features=[res["audio_features"]], audio_feature_lens=res["audio_feature_lens"], + audio_offsets=audio_offsets, modality=Modality.AUDIO, ) items += [item] return { "mm_items": items, - "input_ids": res["input_ids"].flatten().tolist(), + "input_ids": input_ids.tolist(), "audio_start_id": audio_start_id, "audio_end_id": audio_end_id, "im_token_id": im_token_id, - "im_start_id": tokenizer.im_start_id, - "im_end_id": tokenizer.im_end_id, + "im_start_id": im_start_id, + "im_end_id": im_end_id, "slice_start_id": slice_start_id, "slice_end_id": slice_end_id, } diff --git a/python/sglang/srt/managers/multimodal_processors/mllama4.py b/python/sglang/srt/managers/multimodal_processors/mllama4.py index 0d3c289b5..4dadca6f8 100644 --- a/python/sglang/srt/managers/multimodal_processors/mllama4.py +++ b/python/sglang/srt/managers/multimodal_processors/mllama4.py @@ -135,11 +135,17 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor): processor_output["im_end_id"] = self.eoi_token_index processor_output["im_token_id"] = self.image_token_index + image_offsets = self.get_mm_items_offset( + input_ids=torch.tensor(processor_output["input_ids"]), + mm_token_id=self.image_token_index, + ) + # Add metadata for image processing processor_output["mm_items"] = [ MultimodalDataItem( pixel_values=processor_output["pixel_values"], modality=Modality.IMAGE, + image_offsets=image_offsets, ) ] diff --git a/python/sglang/srt/managers/multimodal_processors/pixtral.py b/python/sglang/srt/managers/multimodal_processors/pixtral.py index 07a772cdf..638938097 100644 --- a/python/sglang/srt/managers/multimodal_processors/pixtral.py +++ b/python/sglang/srt/managers/multimodal_processors/pixtral.py @@ -1,9 +1,7 @@ import asyncio import math -from typing import List, Optional, Union +from typing import List, Union -import numpy as np -from transformers import PretrainedConfig from transformers.models.pixtral.image_processing_pixtral import ( _num_image_tokens as _get_pixtral_hf_num_image_tokens, ) @@ -12,11 +10,7 @@ from sglang.srt.managers.multimodal_processors.base_processor import ( BaseMultimodalProcessor, MultimodalSpecialTokens, ) -from sglang.srt.managers.schedule_batch import ( - Modality, - MultimodalDataItem, - MultimodalInputs, -) +from sglang.srt.managers.schedule_batch import Modality, 
MultimodalDataItem from sglang.srt.models.pixtral import PixtralVisionModel @@ -108,15 +102,21 @@ class PixtralProcessor(BaseMultimodalProcessor): ) if "pixel_values" in processor_output: + input_ids = processor_output["input_ids"].view(-1) + image_offsets = self.get_mm_items_offset( + input_ids=input_ids, + mm_token_id=self.image_token_id, + ) mm_items = [ MultimodalDataItem( pixel_values=processor_output["pixel_values"], image_sizes=processor_output["image_sizes"], modality=Modality.IMAGE, + image_offsets=image_offsets, ) ] - input_ids = processor_output["input_ids"].view(-1).tolist() + input_ids = input_ids.tolist() processor_output.update( input_ids=input_ids, mm_items=mm_items, diff --git a/python/sglang/srt/managers/multimodal_processors/qwen_vl.py b/python/sglang/srt/managers/multimodal_processors/qwen_vl.py index ef7ed44b3..76c3c546f 100644 --- a/python/sglang/srt/managers/multimodal_processors/qwen_vl.py +++ b/python/sglang/srt/managers/multimodal_processors/qwen_vl.py @@ -135,6 +135,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): images=None if images_are_preprocessed else base_output.images, ) input_ids = ret["input_ids"].flatten().tolist() + image_offsets = self.get_mm_items_offset( + input_ids=ret["input_ids"].flatten(), mm_token_id=self.image_token_id + ) image_grid_thw = None video_grid_thw = None # TODO items = [] @@ -175,6 +178,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): image_grid_thws=image_grid_thw, video_grid_thws=video_grid_thw, precomputed_features=precomputed_features, + image_offsets=image_offsets, modality=Modality.IMAGE, ) ] diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index abc466fa9..4f86ac5dd 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -197,6 +197,7 @@ class MultimodalDataItem: audio_features: Union[torch.Tensor, np.ndarray] = None audio_feature_lens: Optional[List[torch.Tensor]] = None + audio_offsets: Optional[List[Tuple[int, int]]] = None precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None @@ -1097,7 +1098,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): else: self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc) - assert len(self.out_cache_loc) == self.extend_num_tokens + assert ( + len(self.out_cache_loc) == self.extend_num_tokens + ), f"Expected {len(self.out_cache_loc)}, got {self.extend_num_tokens}" def prepare_for_extend(self): self.forward_mode = ForwardMode.EXTEND diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 33e599208..60f39b1a5 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -102,6 +102,7 @@ from sglang.srt.managers.io_struct import ( UpdateWeightsFromTensorReqInput, UpdateWeightsFromTensorReqOutput, ) +from sglang.srt.managers.mm_utils import init_embedding_cache from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, MultimodalInputs, @@ -2282,6 +2283,10 @@ def run_scheduler_process( if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id) + embedding_cache_size = 100 + if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ: + embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"]) + init_embedding_cache(embedding_cache_size * 1024 * 1024) # Create a scheduler and run the event loop try: scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank) diff --git 
a/python/sglang/srt/mem_cache/multimodal_cache.py b/python/sglang/srt/mem_cache/multimodal_cache.py new file mode 100644 index 000000000..985fd32eb --- /dev/null +++ b/python/sglang/srt/mem_cache/multimodal_cache.py @@ -0,0 +1,45 @@ +from typing import Dict + +import torch + + +class MultiModalCache: + """MultiModalCache is used to store vlm encoder results""" + + def __init__( + self, + max_size: int, + ): + self.max_size = max_size + self.mm_cache: Dict[int, torch.Tensor] = {} + self.current_size = 0 + + def put(self, mm_hash: int, embedding: torch.Tensor) -> bool: + if mm_hash in self.mm_cache: + return True + data_size = self._get_tensor_size(embedding) + if self.current_size + data_size > self.max_size: + return False + self.mm_cache[mm_hash] = embedding + self.current_size += data_size + return True + + def get(self, mm_hash: int) -> torch.Tensor: + return self.mm_cache.get(mm_hash) + + def free(self, mm_hash: int) -> bool: + if mm_hash not in self.mm_cache: + return False + old_embedding = self.mm_cache.pop(mm_hash) + self.current_size -= self._get_tensor_size(old_embedding) + return True + + def clear(self): + self.mm_cache.clear() + self.current_size = 0 + + def _get_tensor_size(self, embedding: torch.Tensor): + return embedding.element_size() * embedding.numel() + + def __len__(self): + return len(self.mm_cache) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8aed6399f..916768d60 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -166,6 +166,9 @@ class ModelRunner: self.is_draft_worker = is_draft_worker self.is_generation = model_config.is_generation self.is_multimodal = model_config.is_multimodal + self.is_multimodal_chunked_prefill_supported = ( + model_config.is_multimodal_chunked_prefill_supported + ) self.spec_algorithm = SpeculativeAlgorithm.from_string( server_args.speculative_algorithm ) @@ -389,12 +392,15 @@ class ModelRunner: if self.is_multimodal: self.mem_fraction_static *= 0.90 logger.info( - f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} because this is a multimodal model." - ) - server_args.chunked_prefill_size = -1 - logger.info( - "Automatically turn off --chunked-prefill-size for multimodal model." + f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} " + f"because this is a multimodal model." 
) + if not self.is_multimodal_chunked_prefill_supported: + server_args.chunked_prefill_size = -1 + logger.info( + f"Automatically turn of --chunked-prefill-size as it is not supported for " + f"{self.model_config.hf_config.model_type}" + ) if not self.use_mla_backend: server_args.disable_chunked_prefix_cache = True diff --git a/python/sglang/srt/models/minicpmo.py b/python/sglang/srt/models/minicpmo.py index 7199da4f1..24d983f1e 100644 --- a/python/sglang/srt/models/minicpmo.py +++ b/python/sglang/srt/models/minicpmo.py @@ -1826,22 +1826,12 @@ class MiniCPMO(MiniCPMBaseModel): **kwargs: Any, ) -> torch.Tensor: - mm_input = forward_batch.merge_mm_inputs() - placeholder_token_ids = ( - ([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items]) - if forward_batch.contains_mm_inputs() - else [] - ) hidden_states = general_mm_embed_routine( input_ids=input_ids, forward_batch=forward_batch, language_model=self.llm, image_data_embedding_func=self.get_image_feature, audio_data_embedding_func=self.get_audio_feature, - placeholder_tokens={ - Modality.IMAGE: placeholder_token_ids, - Modality.AUDIO: placeholder_token_ids, - }, positions=positions, ) return hidden_states diff --git a/test/srt/test_vision_openai_server_common.py b/test/srt/test_vision_openai_server_common.py index a10605ae5..eda29f056 100644 --- a/test/srt/test_vision_openai_server_common.py +++ b/test/srt/test_vision_openai_server_common.py @@ -294,20 +294,24 @@ class TestOpenAIVisionServer(CustomTestCase): print("-" * 30) # Add assertions to validate the video response - assert "iPod" in video_response or "device" in video_response, video_response + assert ( + "iPod" in video_response or "device" in video_response + ), f"video_response: {video_response}, should contain 'iPod' or 'device'" assert ( "man" in video_response or "person" in video_response or "individual" in video_response or "speaker" in video_response - ), video_response + ), f"video_response: {video_response}, should either have 'man' in video_response, or 'person' in video_response, or 'individual' in video_response or 'speaker' in video_response" assert ( "present" in video_response or "examine" in video_response or "display" in video_response or "hold" in video_response - ) - assert "black" in video_response or "dark" in video_response + ), f"video_response: {video_response}, should contain 'present', 'examine', 'display', or 'hold'" + assert ( + "black" in video_response or "dark" in video_response + ), f"video_response: {video_response}, should contain 'black' or 'dark'" self.assertIsNotNone(video_response) self.assertGreater(len(video_response), 0) diff --git a/test/srt/test_vlm_accuracy.py b/test/srt/test_vlm_accuracy.py index 49805230a..9ad17dcb7 100644 --- a/test/srt/test_vlm_accuracy.py +++ b/test/srt/test_vlm_accuracy.py @@ -21,7 +21,10 @@ from transformers import ( from sglang import Engine from sglang.srt.configs.model_config import ModelConfig from sglang.srt.conversation import generate_chat_conv -from sglang.srt.managers.mm_utils import embed_mm_inputs +from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache +from sglang.srt.managers.multimodal_processors.base_processor import ( + BaseMultimodalProcessor, +) from sglang.srt.managers.schedule_batch import ( Modality, MultimodalDataItem, @@ -188,6 +191,7 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase): .eval() .to(cls.device) ) + init_embedding_cache(0) async def test_vlm_embedding_output(self): """ @@ -226,17 +230,41 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase): 
for pixel_n, tgt_n in zip(pixel_b, tgt_b): pixel_values_flat += [pixel_n] tgt_sizes_flat += [tgt_n] + + im_start_id, im_end_id = ( + self.tokenizer.im_start_id, + self.tokenizer.im_end_id, + ) + slice_start_id, slice_end_id = ( + self.tokenizer.slice_start_id, + self.tokenizer.slice_end_id, + ) + + image_offsets = BaseMultimodalProcessor.get_mm_items_offset_by_pair( + input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id + ) + slice_offsets = BaseMultimodalProcessor.get_mm_items_offset_by_pair( + input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id + ) + image_offsets.extend(slice_offsets) + image_offsets = sorted(image_offsets) + sglang_output = embed_mm_inputs( - mm_inputs=MultimodalInputs( - mm_items=[ - MultimodalDataItem( - pixel_values=pixel_values_flat, - tgt_size=tgt_sizes_flat, - modality=Modality.IMAGE, - pad_value=self.processor.tokenizer.unk_token_id, - ) - ] - ), + mm_inputs_list=[ + MultimodalInputs( + mm_items=[ + MultimodalDataItem( + pixel_values=pixel_values_flat, + image_offsets=image_offsets, + tgt_size=tgt_sizes_flat, + modality=Modality.IMAGE, + pad_value=self.processor.tokenizer.unk_token_id, + ) + ] + ), + ], + extend_prefix_lens=[0], + extend_seq_lens=[input_ids.shape[0]], input_ids=input_ids, input_embedding=model.get_input_embeddings(), image_data_embedding_func=model.get_image_feature,
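
For reference, here is a minimal, self-contained sketch of how the pieces added by this patch fit together (get_mm_items_offset, get_embedding_chunk, and MultiModalCache). It assumes the patch is applied and sglang is importable; the placeholder id 999, the 12-token prompt, and the 6x16 embedding are made-up toy values, not the output of any real processor or vision encoder.

```python
import torch

from sglang.srt.managers.mm_utils import get_embedding_chunk
from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
)
from sglang.srt.mem_cache.multimodal_cache import MultiModalCache

# Toy request: 12 tokens, one image whose placeholder id 999 occupies positions 4..9.
IMAGE_TOKEN_ID = 999
input_ids = torch.tensor([1, 2, 3, 4, 999, 999, 999, 999, 999, 999, 5, 6])

# Inclusive (start, end) ranges of the placeholder runs -> [(4, 9)]
items_offset = BaseMultimodalProcessor.get_mm_items_offset(
    input_ids=input_ids, mm_token_id=IMAGE_TOKEN_ID
)

# Pretend the vision encoder produced one embedding row per placeholder token.
full_embedding = torch.randn(6, 16)

# Cache the encoder output so later prefill chunks can reuse it.
cache = MultiModalCache(max_size=1 << 20)  # budget in bytes
mm_hash = hash(("toy-image",))
assert cache.put(mm_hash, full_embedding)

# Chunked prefill, pass 1: tokens [0, 8) overlap placeholder positions 4..7 -> rows 0..3.
chunk1, _, _ = get_embedding_chunk(
    embedding=cache.get(mm_hash),
    extend_prefix_len=0,
    extend_seq_len=8,
    items_offset=items_offset,
)

# Chunked prefill, pass 2: tokens [8, 12) overlap placeholder positions 8..9 -> rows 4..5.
chunk2, _, end2 = get_embedding_chunk(
    embedding=cache.get(mm_hash),
    extend_prefix_len=8,
    extend_seq_len=4,
    items_offset=items_offset,
)

# The two chunks together reproduce the full embedding; the entry can be freed
# once the last chunk has been consumed (the end index reaches the embedding length).
assert torch.equal(torch.cat([chunk1, chunk2]), full_embedding)
assert end2 == full_embedding.shape[0]
cache.free(mm_hash)
```

In the real path, the scheduler sizes this cache from SGLANG_VLM_CACHE_SIZE_MB (default 100 MB), get_embedding_and_mask keys entries by the hash of each request's multimodal items so later chunked-prefill passes reuse the encoder output computed in an earlier pass, and it frees an entry once the extracted chunk's end index reaches the cached embedding's length.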